"git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "0855d255013e067fb59af9466bdd2dab683c5f19"
Unverified Commit a25a14f2 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[Bug Fix] Fix A Bug Related to GroupRevRes (#4181)


Co-authored-by: default avatarMinjie Wang <wmjlyjemaine@gmail.com>
Co-authored-by: default avatarXin Yao <xiny@nvidia.com>
parent 15188611
...@@ -72,7 +72,8 @@ class InvertibleCheckpoint(torch.autograd.Function): ...@@ -72,7 +72,8 @@ class InvertibleCheckpoint(torch.autograd.Function):
detached_inputs = tuple(detached_inputs) detached_inputs = tuple(detached_inputs)
temp_output = ctx.fn(*detached_inputs) temp_output = ctx.fn(*detached_inputs)
filtered_detached_inputs = tuple(filter(lambda x: x.requires_grad, detached_inputs)) filtered_detached_inputs = tuple(filter(lambda x: getattr(x, 'requires_grad', False),
detached_inputs))
gradients = torch.autograd.grad(outputs=(temp_output,), gradients = torch.autograd.grad(outputs=(temp_output,),
inputs=filtered_detached_inputs + ctx.weights, inputs=filtered_detached_inputs + ctx.weights,
grad_outputs=grad_outputs) grad_outputs=grad_outputs)
......
...@@ -1473,7 +1473,8 @@ def test_group_rev_res(idtype): ...@@ -1473,7 +1473,8 @@ def test_group_rev_res(idtype):
h = th.randn(num_nodes, feats).to(dev) h = th.randn(num_nodes, feats).to(dev)
conv = nn.GraphConv(feats // groups, feats // groups) conv = nn.GraphConv(feats // groups, feats // groups)
model = nn.GroupRevRes(conv, groups).to(dev) model = nn.GroupRevRes(conv, groups).to(dev)
model(g, h) result = model(g, h)
result.sum().backward()
@pytest.mark.parametrize('in_size', [16, 32]) @pytest.mark.parametrize('in_size', [16, 32])
@pytest.mark.parametrize('hidden_size', [16, 32]) @pytest.mark.parametrize('hidden_size', [16, 32])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment