You need to sign in or sign up before continuing.
Unverified Commit e00dfd95 authored by Achal Dixit's avatar Achal Dixit Committed by GitHub
Browse files

[test] Added disable_checkpointing unit test (#779)

* [test] Added disable_checkpointing unit test

* [test] Added disable_checkpointing unit test (Clean-up)

* [test] Added disable_checkpointing unit test (Clean-up)
parent 4f7f0853
...@@ -10,7 +10,7 @@ import torch ...@@ -10,7 +10,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from torch.utils.checkpoint import checkpoint as torch_checkpoint_wrapper from torch.utils.checkpoint import checkpoint as torch_checkpoint_wrapper
from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper, disable_checkpointing
from fairscale.nn.misc import checkpoint_wrapper as deprecated_checkpoint_wrapper from fairscale.nn.misc import checkpoint_wrapper as deprecated_checkpoint_wrapper
from fairscale.utils import torch_version from fairscale.utils import torch_version
from fairscale.utils.testing import skip_if_no_cuda from fairscale.utils.testing import skip_if_no_cuda
...@@ -303,3 +303,39 @@ def test_list_input(): ...@@ -303,3 +303,39 @@ def test_list_input():
# Backward. Adds 1 more forward call due to checkpoint. # Backward. Adds 1 more forward call due to checkpoint.
loss.backward() loss.backward()
assert count == 3, f"Incorrect count {count}" assert count == 3, f"Incorrect count {count}"
def test_checkpoint_disabling():
""" Test to check new disable_checkpoint() API added to checkpoint_wrapper."""
class TestModel(nn.Module):
def __init__(self):
super().__init__()
self.cnt = 0
self.linear = nn.Linear(2, 2)
def forward(self, x):
self.cnt += 1
y = []
for i in x:
y.append(self.linear(i))
return y
x = torch.rand(4, 2)
model1 = checkpoint_wrapper(TestModel())
model2 = checkpoint_wrapper(TestModel())
# Forward. cnt += 1
y = model1(x)
y = sum(i.sum() for i in y)
# Backward. cnt += 1
y.backward()
assert model1.cnt == 2
with disable_checkpointing():
# Forward. cnt += 1
y = model2(x)
y = sum(i.sum() for i in y)
# Backward. cnt remains same as checkpointing is disabled
y.backward()
assert model2.cnt == 1
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment