Unverified Commit c558be6b authored by Ross Wightman's avatar Ross Wightman Committed by GitHub
Browse files

Fix #2221, DenseNet issue with gradient checkpoints (#2236)

* Fix #2221, DenseNet issue with gradient checkpoints (memory_efficient=True)

* Add grad/param count test for mem_efficient densenet
parent 12b551e7
......@@ -201,9 +201,11 @@ class ModelTester(TestCase):
for name in ['densenet121', 'densenet169', 'densenet201', 'densenet161']:
model1 = models.__dict__[name](num_classes=50, memory_efficient=True)
params = model1.state_dict()
num_params = sum([x.numel() for x in model1.parameters()])
model1.eval()
out1 = model1(x)
out1.sum().backward()
num_grad = sum([x.grad.numel() for x in model1.parameters() if x.grad is not None])
model2 = models.__dict__[name](num_classes=50, memory_efficient=False)
model2.load_state_dict(params)
......@@ -212,6 +214,7 @@ class ModelTester(TestCase):
max_diff = (out1 - out2).abs().max()
self.assertTrue(num_params == num_grad)
self.assertTrue(max_diff < 1e-5)
def test_resnet_dilation(self):
......
......@@ -53,9 +53,9 @@ class _DenseLayer(nn.Module):
def call_checkpoint_bottleneck(self, input):
    # type: (List[Tensor]) -> Tensor
    """Run the bottleneck function under gradient checkpointing.

    Fix for the memory_efficient=True gradient bug (#2221): the tensors in
    ``input`` must be passed to ``cp.checkpoint`` individually (``*input``),
    not as one list object, so autograd registers each tensor as an input to
    the checkpointed segment and gradients flow back through it. The closure
    correspondingly re-packs the unpacked tensors into a tuple, since
    ``bn_function`` expects a single sequence-of-tensors argument.
    """
    def closure(*inputs):
        # inputs is the tuple of tensors checkpoint passed through;
        # bn_function takes the whole sequence as one argument
        return self.bn_function(inputs)

    # Unpack so every tensor is an explicit checkpoint input (see docstring).
    return cp.checkpoint(closure, *input)
@torch.jit._overload_method # noqa: F811
def forward(self, input):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment