Unverified Commit c8bc3e62 authored by mcarilli, committed by GitHub

Merge pull request #132 from NVIDIA/testing_cache_fix

Fix + tests for the eval->training caching issue
parents 06e11bd3 a88c09cf
@@ -86,13 +86,30 @@ def cached_cast(cast_fn, x, cache):
         return type(x)([cached_cast(y) for y in x])
     if x in cache:
         cached_x = cache[x]
-        # During eval, it's possible to end up caching casted weights
-        # with requires_grad == False. This is then a problem when they
-        # get reused on the next train iter. So we ensure that cached
-        # weights have same requires_grad flag of most recent request.
-        if x.requires_grad != cached_x.requires_grad:
-            cached_x.requires_grad_(x.requires_grad)
-        return cache[x]
+        if x.requires_grad and cached_x.requires_grad:
+            # Make sure x is actually cached_x's autograd parent.
+            if cached_x.grad_fn.next_functions[1][0].variable is not x:
+                raise RuntimeError("x and cache[x] both require grad, but x is not "
+                                   "cache[x]'s parent.  This is likely an error.")
+        # During eval, it's possible to end up caching casted weights with
+        # requires_grad=False.  On the next training iter, if cached_x is found
+        # and reused from the cache, it will not actually have x as its parent.
+        # Therefore, we choose to invalidate the cache (and force refreshing the cast)
+        # if x.requires_grad and cached_x.requires_grad do not match.
+        #
+        # During eval (i.e. running under with torch.no_grad()) the invalidation
+        # check would cause the cached value to be dropped every time, because
+        # cached_x would always be created with requires_grad=False, while x would
+        # still have requires_grad=True.  This would render the cache effectively
+        # useless during eval.  Therefore, if we are running under the no_grad()
+        # context manager (torch.is_grad_enabled=False) we elide the invalidation
+        # check, and use the cached value even though its requires_grad flag doesn't
+        # match.  During eval, we don't care that there's no autograd-graph
+        # connection between x and cached_x.
+        if torch.is_grad_enabled() and x.requires_grad != cached_x.requires_grad:
+            del cache[x]
+        else:
+            return cached_x
 
     casted_x = cast_fn(x)
     cache[x] = casted_x
...
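The comment block added above is the heart of the fix: a cache hit whose requires_grad flag disagrees with the incoming tensor is only trusted under torch.no_grad(); during training it is discarded so the cast is redone with a live autograd connection. Below is a minimal sketch of that rule, not the apex implementation — the real cached_cast also handles nested containers, takes the cast function and the cache as arguments, and verifies the autograd parent. The names toy_cached_cast and _cache are illustrative only, and the cast is hard-coded to .half() on a CPU tensor.

import torch

_cache = {}

def toy_cached_cast(x):
    # Memoize an fp32 -> fp16 cast per input tensor, following the invalidation
    # rule from the hunk above: a hit whose requires_grad flag disagrees with
    # the request is only honored under no_grad(); otherwise the stale entry
    # is dropped and the cast is redone.
    if x in _cache:
        cached_x = _cache[x]
        if torch.is_grad_enabled() and x.requires_grad != cached_x.requires_grad:
            del _cache[x]   # entry was created during eval; force a fresh cast
        else:
            return cached_x
    casted_x = x.half()
    _cache[x] = casted_x
    return casted_x

w = torch.ones(4, requires_grad=True)    # stands in for an fp32 master weight

with torch.no_grad():                    # "eval": result is cached with requires_grad=False
    _ = toy_cached_cast(w)

y = toy_cached_cast(w)                   # "training": stale entry is invalidated and re-cast
print(y.requires_grad, y.grad_fn is not None)   # True True -- gradients reach w again

Without the torch.is_grad_enabled() guard, the same mismatch check would also fire on every eval-mode hit (x keeps requires_grad=True while a cast made under no_grad never does), which is exactly the situation the comment explains the guard is there to avoid.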
@@ -292,7 +292,8 @@ class DistributedDataParallel(Module):
         # Sanity checks that all the buckets were kicked off
         if self.next_bucket != self.num_buckets:
-            raise RuntimeError("In epilogue, next_bucket != num_buckets.  "
-                               "This probably indicates some buckets were not allreduced.")
+            raise RuntimeError("In epilogue, next_bucket ({}) != num_buckets ({}).  "
+                               "This probably indicates some buckets were not allreduced."
+                               .format(self.next_bucket, self.num_buckets))
         for actual, expected in zip(self.buckets_ready_size, self.bucket_sizes):
...
import unittest

import functools as ft
import itertools as it

from apex import amp
import torch
from torch import nn
import torch.nn.functional as F

from utils import common_init, HALF, FLOAT,\
    ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT


def get_reference_grad(i, w, ops):
    # Creating new tensors ensures, among other things, that the new tensors are not in the cache.
    # In fact, they are guaranteed not to use the cache because they are not torch.nn.Parameters.
    fp32_i = i.detach().clone().float()
    fp32_w = w.detach().clone().float().requires_grad_()
    loss = ops(fp32_i, fp32_w)
    loss.backward()
    return fp32_w.grad


class WhitelistModule(torch.nn.Module):
    def __init__(self, dtype):
        super(WhitelistModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.arange(8*8, device='cuda', dtype=dtype).view(8,8))

    @staticmethod
    def ops(input, weight):
        return (input.mm(weight)).mm(weight).sum()

    def forward(self, input):
        return self.ops(input, self.weight)


class BlacklistModule(torch.nn.Module):
    def __init__(self, dtype):
        super(BlacklistModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.arange(2*8, device='cuda', dtype=dtype).view(2,8))

    @staticmethod
    def ops(input, weight):
        return (input + torch.pow(weight, 2) + torch.pow(weight, 2)).sum()

    def forward(self, input):
        return self.ops(input, self.weight)


class PromoteModule(torch.nn.Module):
    def __init__(self, dtype):
        super(PromoteModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.arange(2*8, device='cuda', dtype=dtype).view(2,8))

    @staticmethod
    def ops(input, weight):
        return ((input*weight)*weight).sum()

    def forward(self, input):
        return self.ops(input, self.weight)


class TestCache(unittest.TestCase):
    def setUp(self):
        self.handle = amp.init(enabled=True)
        self.x = torch.ones((2, 8), device='cuda', dtype=torch.float32)
        common_init(self)

    def tearDown(self):
        self.handle._deactivate()

    def train_eval_train_test(self, module, t):
        model = module(t).cuda()
        dummy_optimizer = torch.optim.SGD(model.parameters(), lr=1.0)

        def training_step():
            for param in model.parameters():
                param.grad = None

            loss = model(self.x).sum()
            self.handle._default_scaler._loss_scale = 1.0
            with self.handle.scale_loss(loss, dummy_optimizer) as scaled_loss:
                scaled_loss.backward()

            self.assertEqual(len([p.grad for p in model.parameters() if p.grad is not None]), 1)
            self.assertEqual(model.weight.grad.type(), model.weight.type())

            reference_grad = get_reference_grad(self.x, model.weight, model.ops)

            # Currently there's no difference in the allclose calls, so no need for branching,
            # but I'm keeping this in case we want different tolerances for fp16 and fp32 checks.
            if model.weight.grad.type() == "torch.cuda.HalfTensor":
                self.assertTrue(torch.allclose(model.weight.grad.float(), reference_grad))
            elif model.weight.grad.type() == "torch.cuda.FloatTensor":
                self.assertTrue(torch.allclose(model.weight.grad.float(), reference_grad))
            else:
                raise RuntimeError("model.weight.grad.type = {}".format(model.weight.grad.type()))

            model.weight.data -= 1.

        # Simulates first epoch
        training_step()

        # Simulates eval
        with torch.no_grad():
            loss = model(self.x).sum()

        # Simulates resuming training after eval
        training_step()

    # I could easily have these as a set of for loops in a single test,
    # instead of going for granularity.
    def test_whitelist_module_fp16_weight(self):
        self.train_eval_train_test(WhitelistModule, torch.float16)

    def test_whitelist_module_fp32_weight(self):
        self.train_eval_train_test(WhitelistModule, torch.float32)

    def test_blacklist_module_fp16_weight(self):
        self.train_eval_train_test(BlacklistModule, torch.float16)

    def test_blacklist_module_fp32_weight(self):
        self.train_eval_train_test(BlacklistModule, torch.float32)

    def test_promote_module_fp16_weight(self):
        self.train_eval_train_test(PromoteModule, torch.float16)

    def test_promote_module_fp32_weight(self):
        self.train_eval_train_test(PromoteModule, torch.float32)


if __name__ == '__main__':
    unittest.main()
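Because the file ends with the unittest.main() guard, the suite can be run directly with Python or picked up by the project's usual test runner; apex must be installed and a CUDA device available, since the modules place their weights on 'cuda'. As a small, hedged convenience sketch for selecting just this test case programmatically (the module name test_cache is an assumption — the diff does not show the file's path):

import unittest

# Hypothetical module name; adjust to wherever this test file actually lives.
suite = unittest.defaultTestLoader.loadTestsFromName("test_cache.TestCache")
unittest.TextTestRunner(verbosity=2).run(suite)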