Unverified Commit 17141dd3 authored by Jinjing Zhou's avatar Jinjing Zhou Committed by GitHub
Browse files

[Fix] Fix to_homo for graph with zero nodes ntype (#3011)

* fix #2870

* lint

* fix
parent 8b64ae59
...@@ -5985,6 +5985,7 @@ def combine_frames(frames, ids, col_names=None): ...@@ -5985,6 +5985,7 @@ def combine_frames(frames, ids, col_names=None):
schemes = {key: frames[ids[0]].schemes[key] for key in col_names} schemes = {key: frames[ids[0]].schemes[key] for key in col_names}
for frame_id in ids: for frame_id in ids:
frame = frames[frame_id] frame = frames[frame_id]
if frame.num_rows != 0:
for key, scheme in list(schemes.items()): for key, scheme in list(schemes.items()):
if key in frame.schemes: if key in frame.schemes:
if frame.schemes[key] != scheme: if frame.schemes[key] != scheme:
......
...@@ -1082,6 +1082,19 @@ def test_convert(idtype): ...@@ -1082,6 +1082,19 @@ def test_convert(idtype):
assert hg.device == g.device assert hg.device == g.device
assert g.number_of_nodes() == 5 assert g.number_of_nodes() == 5
@unittest.skipIf(F._default_context_str == 'gpu', reason="Test on cpu is enough")
@parametrize_dtype
def test_to_homo_zero_nodes(idtype):
# Fix gihub issue #2870
g = dgl.heterograph({
('A', 'AB', 'B'): (np.random.randint(0, 200, (1000,)), np.random.randint(0, 200, (1000,))),
('B', 'BA', 'A'): (np.random.randint(0, 200, (1000,)), np.random.randint(0, 200, (1000,))),
}, num_nodes_dict={'A': 200, 'B': 200, 'C': 0}, idtype=idtype)
g.nodes['A'].data['x'] = F.randn((200, 3))
g.nodes['B'].data['x'] = F.randn((200, 3))
gg = dgl.to_homogeneous(g, ['x'])
assert 'x' in gg.ndata
@parametrize_dtype @parametrize_dtype
def test_to_homo2(idtype): def test_to_homo2(idtype):
# test the result homogeneous graph has nodes and edges sorted by their types # test the result homogeneous graph has nodes and edges sorted by their types
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment