Unverified Commit 32d1f3ac authored by Rhett-Ying's avatar Rhett-Ying Committed by GitHub
Browse files

[BugFix] skip frames whose num_rows is zero in dgl.heterograph.combine_frames() (#3110)


Co-authored-by: default avatarMinjie Wang <wmjlyjemaine@gmail.com>
parent 751685a9
...@@ -6081,20 +6081,23 @@ def combine_frames(frames, ids, col_names=None): ...@@ -6081,20 +6081,23 @@ def combine_frames(frames, ids, col_names=None):
The resulting frame The resulting frame
""" """
# find common columns and check if their schemes match # find common columns and check if their schemes match
if col_names is None: schemes = None
schemes = {key: scheme for key, scheme in frames[ids[0]].schemes.items()}
else:
schemes = {key: frames[ids[0]].schemes[key] for key in col_names}
for frame_id in ids: for frame_id in ids:
frame = frames[frame_id] frame = frames[frame_id]
if frame.num_rows != 0: if frame.num_rows == 0:
for key, scheme in list(schemes.items()): continue
if key in frame.schemes: if schemes is None:
if frame.schemes[key] != scheme: schemes = frame.schemes
raise DGLError('Cannot concatenate column %s with shape %s and shape %s' % if col_names is not None:
(key, frame.schemes[key], scheme)) schemes = {key: frame.schemes[key] for key in col_names}
else: continue
del schemes[key] for key, scheme in list(schemes.items()):
if key in frame.schemes:
if frame.schemes[key] != scheme:
raise DGLError('Cannot concatenate column %s with shape %s and shape %s' %
(key, frame.schemes[key], scheme))
else:
del schemes[key]
if len(schemes) == 0: if len(schemes) == 0:
return None return None
......
...@@ -1082,6 +1082,22 @@ def test_convert(idtype): ...@@ -1082,6 +1082,22 @@ def test_convert(idtype):
assert hg.device == g.device assert hg.device == g.device
assert g.number_of_nodes() == 5 assert g.number_of_nodes() == 5
# hetero_to_subgraph_to_homo
hg = dgl.heterograph({
('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 2, 1]),
('user', 'follows', 'user'): ([0, 1, 1], [1, 2, 2])
}, idtype=idtype, device=F.ctx())
hg.nodes['user'].data['h'] = F.copy_to(
F.tensor([[1, 0], [0, 1], [1, 1]], dtype=idtype), ctx=F.ctx())
sg = dgl.node_subgraph(hg, {'user': [1, 2]})
assert len(sg.ntypes) == 2
assert len(sg.etypes) == 2
assert sg.num_nodes('user') == 2
assert sg.num_nodes('game') == 0
g = dgl.to_homogeneous(sg, ndata=['h'])
assert 'h' in g.ndata.keys()
assert g.num_nodes() == 2
@unittest.skipIf(F._default_context_str == 'gpu', reason="Test on cpu is enough") @unittest.skipIf(F._default_context_str == 'gpu', reason="Test on cpu is enough")
@parametrize_dtype @parametrize_dtype
def test_to_homo_zero_nodes(idtype): def test_to_homo_zero_nodes(idtype):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment