Unverified Commit 2478a9ad authored by Min Xu, committed by GitHub

[test] checkpoint: multiple input and output model test (#425)

parent 3b0717eb
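
For context, the pattern the new test exercises is fairscale's checkpoint_wrapper applied to individual submodules of a model that takes several inputs and may return several outputs. A minimal sketch of that pattern, not part of this commit (the import path is an assumption: checkpoint_wrapper has lived under fairscale.nn.misc in older releases and fairscale.nn in newer ones):

import torch
import torch.nn as nn

from fairscale.nn import checkpoint_wrapper  # assumed import path; older releases used fairscale.nn.misc


class TwoBranch(nn.Module):  # hypothetical example module, not part of this commit
    def __init__(self):
        super().__init__()
        # Each wrapped branch drops its activations in forward and recomputes
        # them during backward, trading compute for memory.
        self.branch1 = checkpoint_wrapper(nn.Sequential(nn.Linear(8, 4), nn.ReLU()))
        self.branch2 = checkpoint_wrapper(nn.Sequential(nn.Linear(8, 4), nn.ReLU()))

    def forward(self, x1, x2):
        return self.branch1(x1), self.branch2(x2)  # multiple inputs in, a tuple out


model = TwoBranch()
x1 = torch.rand(2, 8, requires_grad=True)  # inputs require grad so the recompute path is exercised
x2 = torch.rand(2, 8, requires_grad=True)
out1, out2 = model(x1, x2)
(out1.sum() + out2.sum()).backward()  # parameter grads should match an unwrapped run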
@@ -195,3 +195,60 @@ def test_offload_memory():
    # Use print to collect all debugging info.
    print(base, cpt, offload)
    assert 0


class MultiinMultioutModel(nn.Module):
    """Model used to check different inputs and outputs"""

    def __init__(self, multiout=False, checkpoint_config=0):
        super().__init__()
        torch.manual_seed(0)  # make sure weights are deterministic.
        self.multiout = multiout
        self.conv1 = nn.Sequential(nn.Conv2d(1, 5, 3), nn.ReLU(), nn.Conv2d(5, 5, 3))
        self.conv2 = nn.Sequential(nn.Conv2d(3, 5, 3), nn.ReLU(), nn.Conv2d(5, 5, 3))
        # checkpoint_config is a 2-bit mask: bit 0 wraps conv1, bit 1 wraps conv2.
        assert 0 <= checkpoint_config <= 3
        if checkpoint_config & 1:
            self.conv1 = checkpoint_wrapper(self.conv1)
        if checkpoint_config & (1 << 1):
            self.conv2 = checkpoint_wrapper(self.conv2)

    def forward(self, x1, x2=None):
        out1 = self.conv1(x1)
        out2 = self.conv2(x2)
        if self.multiout:
            return out1, out2
        return out1 + out2
@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("multiout", [True, False])
@pytest.mark.parametrize("checkpoint_config", [1, 2, 3])
def test_multiin_multiout(device, multiout, checkpoint_config):
if "cuda" in device and not torch.cuda.is_available():
pytest.skip("test requires a GPU")
def train(model, in1, in2):
out = model(in1, x2=in2)
if isinstance(out, tuple):
out = torch.cat(out)
loss = out.sum()
loss.backward()
gnorm = torch.norm(torch.stack([torch.norm(p.grad.detach()) for p in model.parameters()]))
return {"loss": loss.item(), "gnorm": gnorm.item()}
in1 = torch.rand(4, 1, 32, 32).requires_grad_(True)
in2 = torch.rand(4, 3, 32, 32).requires_grad_(True)
model = MultiinMultioutModel(multiout, 0).to(device)
no_cpt = train(model, in1.to(device), in2.to(device))
model = MultiinMultioutModel(multiout, checkpoint_config).to(device)
cpt = train(model, in1.to(device), in2.to(device))
for key in ["loss", "gnorm"]:
if no_cpt[key] != cpt[key]:
print(no_cpt, cpt)
assert 0
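
The new cases can be run on their own with pytest's -k filter; a small driver, assuming the hunk lands in the checkpoint-activations test file (the diff itself does not name the file it patches):

import pytest

# Path is an assumption; substitute the file this hunk actually patches.
pytest.main(["-k", "test_multiin_multiout", "tests/nn/misc/test_checkpoint_activations.py"])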