Unverified Commit d45eafd4 authored by Venzino.Han's avatar Venzino.Han Committed by GitHub
Browse files

[NN] add gatconv edge_weight (#5503)


Co-authored-by: default avatarvenzino-han <venzino.han@buzzvil.com>
Co-authored-by: default avatarMufei Li <mufeili1996@gmail.com>
parent 2b5921e7
......@@ -227,7 +227,7 @@ class GATConv(nn.Module):
"""
self._allow_zero_in_degree = set_value
def forward(self, graph, feat, get_attention=False):
def forward(self, graph, feat, edge_weight=None, get_attention=False):
r"""
Description
......@@ -243,6 +243,8 @@ class GATConv(nn.Module):
:math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
If a pair of torch.Tensor is given, the pair must contain two tensors of shape
:math:`(N_{in}, *, D_{in_{src}})` and :math:`(N_{out}, *, D_{in_{dst}})`.
edge_weight : torch.Tensor, optional
A 1D tensor of edge weight values. Shape: :math:`(|E|,)`.
get_attention : bool, optional
Whether to return the attention values. Default to False.
......@@ -327,6 +329,10 @@ class GATConv(nn.Module):
e = self.leaky_relu(graph.edata.pop("e"))
# compute softmax
graph.edata["a"] = self.attn_drop(edge_softmax(graph, e))
if edge_weight is not None:
graph.edata["a"] = graph.edata["a"] * edge_weight.tile(
1, self._num_heads, 1
).transpose(0, 2)
# message passing
graph.update_all(fn.u_mul_e("ft", "a", "m"), fn.sum("m", "ft"))
rst = graph.dstdata["ft"]
......
......@@ -540,8 +540,8 @@ def test_rgcn_default_nbasis(idtype, O):
@pytest.mark.parametrize("out_dim", [1, 5])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_gat_conv(g, idtype, out_dim, num_heads):
g = g.astype(idtype).to(F.ctx())
ctx = F.ctx()
g = g.astype(idtype).to(ctx)
gat = nn.GATConv(5, out_dim, num_heads)
feat = F.randn((g.number_of_src_nodes(), 5))
gat = gat.to(ctx)
......@@ -565,8 +565,8 @@ def test_gat_conv(g, idtype, out_dim, num_heads):
@pytest.mark.parametrize("out_dim", [1, 2])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_gat_conv_bi(g, idtype, out_dim, num_heads):
g = g.astype(idtype).to(F.ctx())
ctx = F.ctx()
g = g.astype(idtype).to(ctx)
gat = nn.GATConv(5, out_dim, num_heads)
feat = (
F.randn((g.number_of_src_nodes(), 5)),
......@@ -579,6 +579,27 @@ def test_gat_conv_bi(g, idtype, out_dim, num_heads):
assert a.shape == (g.num_edges(), num_heads, 1)
@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_dim", [1, 2])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_gat_conv_edge_weight(g, idtype, out_dim, num_heads):
ctx = F.ctx()
g = g.astype(idtype).to(ctx)
gat = nn.GATConv(5, out_dim, num_heads)
feat = (
F.randn((g.number_of_src_nodes(), 5)),
F.randn((g.number_of_dst_nodes(), 5)),
)
gat = gat.to(ctx)
ew = F.randn((g.num_edges(),))
h = gat(g, feat, edge_weight=ew)
assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
_, a = gat(g, feat, get_attention=True)
assert a.shape[0] == ew.shape[0]
assert a.shape == (g.num_edges(), num_heads, 1)
@parametrize_idtype
@pytest.mark.parametrize(
"g", get_cases(["homo", "block-bipartite"], exclude=["zero-degree"])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment