"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "d1d0b8afce2fb07134a243aa42136424a3a632a2"
Unverified Commit f1b0a079 authored by Venzino.Han's avatar Venzino.Han Committed by GitHub
Browse files

[NN] add egatconv edge_weight (#5539)


Co-authored-by: default avatarMufei Li <mufeili1996@gmail.com>
parent 7cd6257f
......@@ -159,7 +159,9 @@ class EGATConv(nn.Module):
init.xavier_normal_(self.attn, gain=gain)
init.constant_(self.bias, 0)
def forward(self, graph, nfeats, efeats, get_attention=False):
def forward(
self, graph, nfeats, efeats, edge_weight=None, get_attention=False
):
r"""
Compute new node and edge features.
......@@ -180,6 +182,8 @@ class EGATConv(nn.Module):
where:
:math:`F_{in}` is size of input node feature,
:math:`E` is the number of edges.
edge_weight : torch.Tensor, optional
A 1D tensor of edge weight values. Shape: :math:`(|E|,)`.
get_attention : bool, optional
Whether to return the attention values. Default to False.
......@@ -235,6 +239,10 @@ class EGATConv(nn.Module):
# compute attention factor
e = (f_out * self.attn).sum(dim=-1).unsqueeze(-1)
graph.edata["a"] = edge_softmax(graph, e)
if edge_weight is not None:
graph.edata["a"] = graph.edata["a"] * edge_weight.tile(
1, self._num_heads, 1
).transpose(0, 2)
graph.srcdata["h_out"] = self.fc_node_src(nfeats_src).view(
-1, self._num_heads, self._out_node_feats
)
......
......@@ -652,8 +652,8 @@ def test_gatv2_conv_bi(g, idtype, out_dim, num_heads):
@pytest.mark.parametrize("out_edge_feats", [1, 5])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_egat_conv(g, idtype, out_node_feats, out_edge_feats, num_heads):
g = g.astype(idtype).to(F.ctx())
ctx = F.ctx()
g = g.astype(idtype).to(ctx)
egat = nn.EGATConv(
in_node_feats=10,
in_edge_feats=5,
......@@ -670,7 +670,7 @@ def test_egat_conv(g, idtype, out_node_feats, out_edge_feats, num_heads):
assert h.shape == (g.num_nodes(), num_heads, out_node_feats)
assert f.shape == (g.num_edges(), num_heads, out_edge_feats)
_, _, attn = egat(g, nfeat, efeat, True)
_, _, attn = egat(g, nfeat, efeat, get_attention=True)
assert attn.shape == (g.num_edges(), num_heads, 1)
......@@ -680,8 +680,8 @@ def test_egat_conv(g, idtype, out_node_feats, out_edge_feats, num_heads):
@pytest.mark.parametrize("out_edge_feats", [1, 5])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_egat_conv_bi(g, idtype, out_node_feats, out_edge_feats, num_heads):
g = g.astype(idtype).to(F.ctx())
ctx = F.ctx()
g = g.astype(idtype).to(ctx)
egat = nn.EGATConv(
in_node_feats=(10, 15),
in_edge_feats=7,
......@@ -701,7 +701,36 @@ def test_egat_conv_bi(g, idtype, out_node_feats, out_edge_feats, num_heads):
assert h.shape == (g.number_of_dst_nodes(), num_heads, out_node_feats)
assert f.shape == (g.num_edges(), num_heads, out_edge_feats)
_, _, attn = egat(g, nfeat, efeat, True)
_, _, attn = egat(g, nfeat, efeat, get_attention=True)
assert attn.shape == (g.num_edges(), num_heads, 1)
@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_node_feats", [1, 5])
@pytest.mark.parametrize("out_edge_feats", [1, 5])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_egat_conv_edge_weight(
g, idtype, out_node_feats, out_edge_feats, num_heads
):
ctx = F.ctx()
g = g.astype(idtype).to(ctx)
egat = nn.EGATConv(
in_node_feats=10,
in_edge_feats=5,
out_node_feats=out_node_feats,
out_edge_feats=out_edge_feats,
num_heads=num_heads,
)
egat = egat.to(ctx)
nfeat = F.randn((g.num_nodes(), 10))
efeat = F.randn((g.num_edges(), 5))
ew = F.randn((g.num_edges(),))
h, f, attn = egat(g, nfeat, efeat, edge_weight=ew, get_attention=True)
assert h.shape == (g.num_nodes(), num_heads, out_node_feats)
assert f.shape == (g.num_edges(), num_heads, out_edge_feats)
assert attn.shape == (g.num_edges(), num_heads, 1)
......@@ -710,8 +739,8 @@ def test_egat_conv_bi(g, idtype, out_node_feats, out_edge_feats, num_heads):
@pytest.mark.parametrize("out_feats", [1, 5])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_edgegat_conv(g, idtype, out_feats, num_heads):
g = g.astype(idtype).to(F.ctx())
ctx = F.ctx()
g = g.astype(idtype).to(ctx)
edgegat = nn.EdgeGATConv(
in_feats=10, edge_feats=5, out_feats=out_feats, num_heads=num_heads
)
......@@ -732,8 +761,8 @@ def test_edgegat_conv(g, idtype, out_feats, num_heads):
@pytest.mark.parametrize("out_feats", [1, 5])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_edgegat_conv_bi(g, idtype, out_feats, num_heads):
g = g.astype(idtype).to(F.ctx())
ctx = F.ctx()
g = g.astype(idtype).to(ctx)
edgegat = nn.EdgeGATConv(
in_feats=(10, 15),
edge_feats=7,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment