"tests/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "4af3f8bcc859276d2e40d2dad69fdec6bf373f72"
Unverified Commit 7380d61e authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[NN] Support Unidirectional Bipartite Graphs in CFConv (#2674)



* Update

* update

* Update
Co-authored-by: default avatarJinjing Zhou <VoVAllen@users.noreply.github.com>
parent 97bdae9e
...@@ -123,15 +123,23 @@ class CFConv(nn.Module): ...@@ -123,15 +123,23 @@ class CFConv(nn.Module):
---------- ----------
g : DGLGraph g : DGLGraph
The graph. The graph.
node_feats : float32 tensor of shape (V, node_in_feats) node_feats : torch.Tensor or pair of torch.Tensor
Input node features, V for the number of nodes. The input node features. If a torch.Tensor is given, it represents the input
edge_feats : float32 tensor of shape (E, edge_in_feats) node feature of shape :math:`(N, D_{in})` where :math:`D_{in}` is size of
Input edge features, E for the number of edges. input feature, :math:`N` is the number of nodes.
If a pair of torch.Tensor is given, which is the case for bipartite graph,
the pair must contain two tensors of shape :math:`(N_{src}, D_{in_{src}})` and
:math:`(N_{dst}, D_{in_{dst}})` separately for the source and destination nodes.
edge_feats : torch.Tensor
The input edge feature of shape :math:`(E, edge_in_feats)`
where :math:`E` is the number of edges.
Returns Returns
------- -------
float32 tensor of shape (V, out_feats) torch.Tensor
Updated node representations. The output node feature of shape :math:`(N_{out}, out_feats)`
where :math:`N_{out}` is the number of destination nodes.
""" """
with g.local_scope(): with g.local_scope():
if isinstance(node_feats, tuple): if isinstance(node_feats, tuple):
......
...@@ -923,7 +923,7 @@ def test_atomic_conv(g, idtype): ...@@ -923,7 +923,7 @@ def test_atomic_conv(g, idtype):
assert h.shape[-1] == 4 assert h.shape[-1] == 4
@parametrize_dtype @parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree'])) @pytest.mark.parametrize('g', get_cases(['homo', 'bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 3]) @pytest.mark.parametrize('out_dim', [1, 3])
def test_cf_conv(g, idtype, out_dim): def test_cf_conv(g, idtype, out_dim):
g = g.astype(idtype).to(F.ctx()) g = g.astype(idtype).to(F.ctx())
...@@ -936,9 +936,15 @@ def test_cf_conv(g, idtype, out_dim): ...@@ -936,9 +936,15 @@ def test_cf_conv(g, idtype, out_dim):
if F.gpu_ctx(): if F.gpu_ctx():
cfconv = cfconv.to(ctx) cfconv = cfconv.to(ctx)
node_feats = F.randn((g.number_of_nodes(), 2)) src_feats = F.randn((g.number_of_src_nodes(), 2))
edge_feats = F.randn((g.number_of_edges(), 3)) edge_feats = F.randn((g.number_of_edges(), 3))
h = cfconv(g, node_feats, edge_feats) h = cfconv(g, src_feats, edge_feats)
# current we only do shape check
assert h.shape[-1] == out_dim
# case for bipartite graphs
dst_feats = F.randn((g.number_of_dst_nodes(), 3))
h = cfconv(g, (src_feats, dst_feats), edge_feats)
# current we only do shape check # current we only do shape check
assert h.shape[-1] == out_dim assert h.shape[-1] == out_dim
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment