"git@developer.sourcefind.cn:OpenDAS/bitsandbytes.git" did not exist on "19fe95ac44f02c9657ff465ac4f9565804087928"
Unverified Commit a25a14f2 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[Bug Fix] Fix A Bug Related to GroupRevRes (#4181)


Co-authored-by: default avatarMinjie Wang <wmjlyjemaine@gmail.com>
Co-authored-by: default avatarXin Yao <xiny@nvidia.com>
parent 15188611
...@@ -72,7 +72,8 @@ class InvertibleCheckpoint(torch.autograd.Function): ...@@ -72,7 +72,8 @@ class InvertibleCheckpoint(torch.autograd.Function):
detached_inputs = tuple(detached_inputs) detached_inputs = tuple(detached_inputs)
temp_output = ctx.fn(*detached_inputs) temp_output = ctx.fn(*detached_inputs)
filtered_detached_inputs = tuple(filter(lambda x: x.requires_grad, detached_inputs)) filtered_detached_inputs = tuple(filter(lambda x: getattr(x, 'requires_grad', False),
detached_inputs))
gradients = torch.autograd.grad(outputs=(temp_output,), gradients = torch.autograd.grad(outputs=(temp_output,),
inputs=filtered_detached_inputs + ctx.weights, inputs=filtered_detached_inputs + ctx.weights,
grad_outputs=grad_outputs) grad_outputs=grad_outputs)
......
...@@ -508,14 +508,14 @@ def test_egat_conv(g, idtype, out_node_feats, out_edge_feats, num_heads): ...@@ -508,14 +508,14 @@ def test_egat_conv(g, idtype, out_node_feats, out_edge_feats, num_heads):
efeat = F.randn((g.number_of_edges(), 5)) efeat = F.randn((g.number_of_edges(), 5))
egat = egat.to(ctx) egat = egat.to(ctx)
h, f = egat(g, nfeat, efeat) h, f = egat(g, nfeat, efeat)
th.save(egat, tmp_buffer) th.save(egat, tmp_buffer)
assert h.shape == (g.number_of_nodes(), num_heads, out_node_feats) assert h.shape == (g.number_of_nodes(), num_heads, out_node_feats)
assert f.shape == (g.number_of_edges(), num_heads, out_edge_feats) assert f.shape == (g.number_of_edges(), num_heads, out_edge_feats)
_, _, attn = egat(g, nfeat, efeat, True) _, _, attn = egat(g, nfeat, efeat, True)
assert attn.shape == (g.number_of_edges(), num_heads, 1) assert attn.shape == (g.number_of_edges(), num_heads, 1)
@parametrize_idtype @parametrize_idtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree'])) @pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_node_feats', [1, 5]) @pytest.mark.parametrize('out_node_feats', [1, 5])
...@@ -533,7 +533,7 @@ def test_egat_conv_bi(g, idtype, out_node_feats, out_edge_feats, num_heads): ...@@ -533,7 +533,7 @@ def test_egat_conv_bi(g, idtype, out_node_feats, out_edge_feats, num_heads):
efeat = F.randn((g.number_of_edges(), 7)) efeat = F.randn((g.number_of_edges(), 7))
egat = egat.to(ctx) egat = egat.to(ctx)
h, f = egat(g, nfeat, efeat) h, f = egat(g, nfeat, efeat)
th.save(egat, tmp_buffer) th.save(egat, tmp_buffer)
assert h.shape == (g.number_of_dst_nodes(), num_heads, out_node_feats) assert h.shape == (g.number_of_dst_nodes(), num_heads, out_node_feats)
...@@ -1473,7 +1473,8 @@ def test_group_rev_res(idtype): ...@@ -1473,7 +1473,8 @@ def test_group_rev_res(idtype):
h = th.randn(num_nodes, feats).to(dev) h = th.randn(num_nodes, feats).to(dev)
conv = nn.GraphConv(feats // groups, feats // groups) conv = nn.GraphConv(feats // groups, feats // groups)
model = nn.GroupRevRes(conv, groups).to(dev) model = nn.GroupRevRes(conv, groups).to(dev)
model(g, h) result = model(g, h)
result.sum().backward()
@pytest.mark.parametrize('in_size', [16, 32]) @pytest.mark.parametrize('in_size', [16, 32])
@pytest.mark.parametrize('hidden_size', [16, 32]) @pytest.mark.parametrize('hidden_size', [16, 32])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment