Unverified Commit a25a14f2 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[Bug Fix] Fix A Bug Related to GroupRevRes (#4181)


Co-authored-by: default avatarMinjie Wang <wmjlyjemaine@gmail.com>
Co-authored-by: default avatarXin Yao <xiny@nvidia.com>
parent 15188611
......@@ -72,7 +72,8 @@ class InvertibleCheckpoint(torch.autograd.Function):
detached_inputs = tuple(detached_inputs)
temp_output = ctx.fn(*detached_inputs)
filtered_detached_inputs = tuple(filter(lambda x: x.requires_grad, detached_inputs))
filtered_detached_inputs = tuple(filter(lambda x: getattr(x, 'requires_grad', False),
detached_inputs))
gradients = torch.autograd.grad(outputs=(temp_output,),
inputs=filtered_detached_inputs + ctx.weights,
grad_outputs=grad_outputs)
......
......@@ -1473,7 +1473,8 @@ def test_group_rev_res(idtype):
h = th.randn(num_nodes, feats).to(dev)
conv = nn.GraphConv(feats // groups, feats // groups)
model = nn.GroupRevRes(conv, groups).to(dev)
model(g, h)
result = model(g, h)
result.sum().backward()
@pytest.mark.parametrize('in_size', [16, 32])
@pytest.mark.parametrize('hidden_size', [16, 32])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment