Unverified Commit bcb5be4a authored by caojy1998's avatar caojy1998 Committed by GitHub
Browse files

[Bug fix] Fix the bug in creating unibipartite heterogenous graph (#6093)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-21-37.ap-northeast-1.compute.internal>
parent 34a2407c
...@@ -211,9 +211,10 @@ def data_dict_to_list(graph, data_dict, func, target): ...@@ -211,9 +211,10 @@ def data_dict_to_list(graph, data_dict, func, target):
------------- -------------
graph : DGLGraph graph : DGLGraph
The input graph. The input graph.
data_dict : dict[str, Tensor] or dict[(str, str, str), Tensor]] data_dict : dict[str, Tensor] or dict[(str, str, str), Tensor]] or Tensor
Node or edge data stored in DGLGraph. The key of the dictionary Node or edge data stored in DGLGraph. The key of the dictionary
is the node type name or edge type name. is the node type name or edge type name. If there is only single source
node type, data_dict is the value of feature(a Tensor) not a dict.
func : dgl.function.BaseMessageFunction func : dgl.function.BaseMessageFunction
Built-in message function. Built-in message function.
target : 'u', 'v' or 'e' target : 'u', 'v' or 'e'
...@@ -228,13 +229,22 @@ def data_dict_to_list(graph, data_dict, func, target): ...@@ -228,13 +229,22 @@ def data_dict_to_list(graph, data_dict, func, target):
if isinstance(func, fn.BinaryMessageFunction): if isinstance(func, fn.BinaryMessageFunction):
if target in ["u", "v"]: if target in ["u", "v"]:
output_list = [None] * graph._graph.number_of_ntypes() output_list = [None] * graph._graph.number_of_ntypes()
for srctype, _, dsttype in graph.canonical_etypes: # If there is only single source node type, data_dict should be the value of
# feature, namely, a tensor.
if not isinstance(data_dict, dict):
src_id, dst_id = graph._graph.metagraph.find_edge(0)
if target == "u": if target == "u":
src_id = graph.get_ntype_id(srctype) output_list[src_id] = data_dict
output_list[src_id] = data_dict[srctype]
else: else:
dst_id = graph.get_ntype_id(dsttype) output_list[dst_id] = data_dict
output_list[dst_id] = data_dict[dsttype] else:
for srctype, _, dsttype in graph.canonical_etypes:
if target == "u":
src_id = graph.get_ntype_id(srctype)
output_list[src_id] = data_dict[srctype]
else:
dst_id = graph.get_ntype_id(dsttype)
output_list[dst_id] = data_dict[dsttype]
else: # target == 'e' else: # target == 'e'
output_list = [None] * graph._graph.number_of_etypes() output_list = [None] * graph._graph.number_of_etypes()
for rel in graph.canonical_etypes: for rel in graph.canonical_etypes:
......
...@@ -10,7 +10,9 @@ import dgl.function as fn ...@@ -10,7 +10,9 @@ import dgl.function as fn
import networkx as nx import networkx as nx
import numpy as np import numpy as np
import pytest import pytest
import scipy.sparse as ssp import scipy.sparse as spsp
import torch
from dgl import DGLError from dgl import DGLError
from scipy.sparse import rand from scipy.sparse import rand
from utils import get_cases, parametrize_idtype from utils import get_cases, parametrize_idtype
...@@ -47,6 +49,23 @@ def create_test_heterograph(idtype): ...@@ -47,6 +49,23 @@ def create_test_heterograph(idtype):
return g return g
def create_random_hetero_with_single_source_node_type(idtype):
num_nodes = {"n1": 5, "n2": 10, "n3": 15}
etypes = [("n1", "r1", "n2"), ("n1", "r2", "n3"), ("n1", "r3", "n2")]
edges = {}
for etype in etypes:
src_ntype, _, dst_ntype = etype
arr = spsp.random(
num_nodes[src_ntype],
num_nodes[dst_ntype],
density=1,
format="coo",
random_state=100,
)
edges[etype] = (arr.row, arr.col)
return dgl.heterograph(edges, idtype=idtype, device=F.ctx())
@parametrize_idtype @parametrize_idtype
def test_unary_copy_u(idtype): def test_unary_copy_u(idtype):
def _test(mfunc): def _test(mfunc):
...@@ -260,6 +279,23 @@ def test_binary_op(idtype): ...@@ -260,6 +279,23 @@ def test_binary_op(idtype):
_test(lhs, rhs, binary_op) _test(lhs, rhs, binary_op)
# Here we test heterograph with only single source node type because the format
# of node feature is a tensor.
@unittest.skipIf(
dgl.backend.backend_name != "pytorch", reason="Only support PyTorch for now"
)
@parametrize_idtype
def test_heterograph_with_single_source_node_type_apply_edges(idtype):
hg = create_random_hetero_with_single_source_node_type(idtype)
hg.nodes["n1"].data["h"] = F.randn((hg.num_nodes("n1"), 1))
hg.nodes["n2"].data["h"] = F.randn((hg.num_nodes("n2"), 1))
hg.nodes["n3"].data["h"] = F.randn((hg.num_nodes("n3"), 1))
assert type(hg.srcdata["h"]) == torch.Tensor
hg.apply_edges(fn.u_add_v("h", "h", "x"))
if __name__ == "__main__": if __name__ == "__main__":
test_unary_copy_u() test_unary_copy_u()
test_unary_copy_e() test_unary_copy_e()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment