"src/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "b8f905f18b29cdcaba2a4758ac29bf44b9b7dff9"
Unverified Commit f1b0a079, authored by Venzino.Han and committed by GitHub

[NN] add egatconv edge_weight (#5539)


Co-authored-by: Mufei Li <mufeili1996@gmail.com>
parent 7cd6257f
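For context, here is a minimal usage sketch of the new `edge_weight` argument. It is not part of the commit; the toy graph, feature sizes, and printed shapes are hypothetical, and a PyTorch backend for DGL is assumed:

```python
import dgl
import torch
from dgl.nn import EGATConv

g = dgl.graph(([0, 1, 2], [1, 2, 0]))  # hypothetical 3-node directed cycle
nfeat = torch.randn(g.num_nodes(), 10)  # node features, F_in = 10
efeat = torch.randn(g.num_edges(), 5)   # edge features, F_e = 5
ew = torch.rand(g.num_edges())          # 1D edge weights, shape (|E|,)

conv = EGATConv(
    in_node_feats=10,
    in_edge_feats=5,
    out_node_feats=8,
    out_edge_feats=8,
    num_heads=4,
)
# edge_weight is the keyword added by this commit.
h, f = conv(g, nfeat, efeat, edge_weight=ew)
print(h.shape, f.shape)  # torch.Size([3, 4, 8]) torch.Size([3, 4, 8])
```

The weight multiplies the softmax-normalized attention, so it rescales how much each edge contributes to its destination node's aggregation.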
@@ -159,7 +159,9 @@ class EGATConv(nn.Module):
         init.xavier_normal_(self.attn, gain=gain)
         init.constant_(self.bias, 0)
 
-    def forward(self, graph, nfeats, efeats, get_attention=False):
+    def forward(
+        self, graph, nfeats, efeats, edge_weight=None, get_attention=False
+    ):
         r"""
         Compute new node and edge features.
 
@@ -180,6 +182,8 @@ class EGATConv(nn.Module):
         where:
             :math:`F_{in}` is size of input node feature,
             :math:`E` is the number of edges.
+        edge_weight : torch.Tensor, optional
+            A 1D tensor of edge weight values. Shape: :math:`(|E|,)`.
         get_attention : bool, optional
             Whether to return the attention values. Default to False.
 
@@ -235,6 +239,10 @@ class EGATConv(nn.Module):
         # compute attention factor
         e = (f_out * self.attn).sum(dim=-1).unsqueeze(-1)
         graph.edata["a"] = edge_softmax(graph, e)
+        if edge_weight is not None:
+            graph.edata["a"] = graph.edata["a"] * edge_weight.tile(
+                1, self._num_heads, 1
+            ).transpose(0, 2)
         graph.srcdata["h_out"] = self.fc_node_src(nfeats_src).view(
             -1, self._num_heads, self._out_node_feats
         )
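The `tile(...).transpose(0, 2)` idiom above reshapes the 1D weight so it broadcasts against the per-head attention tensor. A standalone sketch of the shape arithmetic, in plain PyTorch with toy sizes (no DGL required):

```python
import torch

E, H = 6, 4              # number of edges and attention heads (toy sizes)
a = torch.rand(E, H, 1)  # softmax-normalized attention, like edata["a"]
ew = torch.rand(E)       # 1D edge weights, shape (E,)

# tile() left-pads the 1D tensor to (1, 1, E), then repeats the middle
# dim H times -> (1, H, E); transpose(0, 2) yields (E, H, 1).
ew3 = ew.tile(1, H, 1).transpose(0, 2)
assert ew3.shape == (E, H, 1)

# A plain broadcast view produces the same values:
assert torch.equal(ew3, ew.view(E, 1, 1).expand(E, H, 1))

weighted = a * ew3  # rescales each edge's per-head attention
```

An equivalent `view(E, 1, 1)` would broadcast directly; the tiled form materializes the head dimension explicitly but computes the same product.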
@@ -652,8 +652,8 @@ def test_gatv2_conv_bi(g, idtype, out_dim, num_heads):
 @pytest.mark.parametrize("out_edge_feats", [1, 5])
 @pytest.mark.parametrize("num_heads", [1, 4])
 def test_egat_conv(g, idtype, out_node_feats, out_edge_feats, num_heads):
-    g = g.astype(idtype).to(F.ctx())
     ctx = F.ctx()
+    g = g.astype(idtype).to(ctx)
     egat = nn.EGATConv(
         in_node_feats=10,
         in_edge_feats=5,
@@ -670,7 +670,7 @@ def test_egat_conv(g, idtype, out_node_feats, out_edge_feats, num_heads):
     assert h.shape == (g.num_nodes(), num_heads, out_node_feats)
     assert f.shape == (g.num_edges(), num_heads, out_edge_feats)
-    _, _, attn = egat(g, nfeat, efeat, True)
+    _, _, attn = egat(g, nfeat, efeat, get_attention=True)
     assert attn.shape == (g.num_edges(), num_heads, 1)
@@ -680,8 +680,8 @@ def test_egat_conv(g, idtype, out_node_feats, out_edge_feats, num_heads):
 @pytest.mark.parametrize("out_edge_feats", [1, 5])
 @pytest.mark.parametrize("num_heads", [1, 4])
 def test_egat_conv_bi(g, idtype, out_node_feats, out_edge_feats, num_heads):
-    g = g.astype(idtype).to(F.ctx())
     ctx = F.ctx()
+    g = g.astype(idtype).to(ctx)
     egat = nn.EGATConv(
         in_node_feats=(10, 15),
         in_edge_feats=7,
@@ -701,7 +701,36 @@ def test_egat_conv_bi(g, idtype, out_node_feats, out_edge_feats, num_heads):
     assert h.shape == (g.number_of_dst_nodes(), num_heads, out_node_feats)
     assert f.shape == (g.num_edges(), num_heads, out_edge_feats)
-    _, _, attn = egat(g, nfeat, efeat, True)
+    _, _, attn = egat(g, nfeat, efeat, get_attention=True)
+    assert attn.shape == (g.num_edges(), num_heads, 1)
+
+
+@parametrize_idtype
+@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
+@pytest.mark.parametrize("out_node_feats", [1, 5])
+@pytest.mark.parametrize("out_edge_feats", [1, 5])
+@pytest.mark.parametrize("num_heads", [1, 4])
+def test_egat_conv_edge_weight(
+    g, idtype, out_node_feats, out_edge_feats, num_heads
+):
+    ctx = F.ctx()
+    g = g.astype(idtype).to(ctx)
+    egat = nn.EGATConv(
+        in_node_feats=10,
+        in_edge_feats=5,
+        out_node_feats=out_node_feats,
+        out_edge_feats=out_edge_feats,
+        num_heads=num_heads,
+    )
+    egat = egat.to(ctx)
+    nfeat = F.randn((g.num_nodes(), 10))
+    efeat = F.randn((g.num_edges(), 5))
+    ew = F.randn((g.num_edges(),))
+    h, f, attn = egat(g, nfeat, efeat, edge_weight=ew, get_attention=True)
+    assert h.shape == (g.num_nodes(), num_heads, out_node_feats)
+    assert f.shape == (g.num_edges(), num_heads, out_edge_feats)
     assert attn.shape == (g.num_edges(), num_heads, 1)
@@ -710,8 +739,8 @@ def test_egat_conv_bi(g, idtype, out_node_feats, out_edge_feats, num_heads):
 @pytest.mark.parametrize("out_feats", [1, 5])
 @pytest.mark.parametrize("num_heads", [1, 4])
 def test_edgegat_conv(g, idtype, out_feats, num_heads):
-    g = g.astype(idtype).to(F.ctx())
     ctx = F.ctx()
+    g = g.astype(idtype).to(ctx)
     edgegat = nn.EdgeGATConv(
         in_feats=10, edge_feats=5, out_feats=out_feats, num_heads=num_heads
     )
@@ -732,8 +761,8 @@ def test_edgegat_conv(g, idtype, out_feats, num_heads):
 @pytest.mark.parametrize("out_feats", [1, 5])
 @pytest.mark.parametrize("num_heads", [1, 4])
 def test_edgegat_conv_bi(g, idtype, out_feats, num_heads):
-    g = g.astype(idtype).to(F.ctx())
     ctx = F.ctx()
+    g = g.astype(idtype).to(ctx)
     edgegat = nn.EdgeGATConv(
         in_feats=(10, 15),
         edge_feats=7,
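One property the new test does not assert, but which follows from the diff, is that an all-ones weight should leave the outputs unchanged, since the multiplication happens after `edge_softmax`. A sanity-check sketch under the same assumptions as above (hypothetical toy graph and sizes, and assuming the forward pass is deterministic):

```python
import dgl
import torch
from dgl.nn import EGATConv

g = dgl.graph(([0, 1, 2], [1, 2, 0]))  # hypothetical toy graph
nfeat = torch.randn(g.num_nodes(), 10)
efeat = torch.randn(g.num_edges(), 5)

conv = EGATConv(
    in_node_feats=10,
    in_edge_feats=5,
    out_node_feats=8,
    out_edge_feats=8,
    num_heads=4,
)
h0, f0 = conv(g, nfeat, efeat)  # unweighted
h1, f1 = conv(g, nfeat, efeat, edge_weight=torch.ones(g.num_edges()))
assert torch.allclose(h0, h1) and torch.allclose(f0, f1)
```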