Unverified Commit cc73c60c authored by xiang song(charlie.song)'s avatar xiang song(charlie.song) Committed by GitHub
Browse files

[Bug Fix] Fix Heterogenous graph save crash when some edge type names are identical. (#1751)



* Fix hetero serialize

* Add test case
y
Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-51-214.ec2.internal>
Co-authored-by: default avatarJinjing Zhou <VoVAllen@users.noreply.github.com>
parent feabcabd
......@@ -35,7 +35,7 @@ class HeteroGraphData(ObjectBase):
def create(g):
edata_list = []
ndata_list = []
for etype in g.etypes:
for etype in g.canonical_etypes:
edata_list.append(tensor_dict_to_ndarray_dict(g.edges[etype].data))
for ntype in g.ntypes:
ndata_list.append(tensor_dict_to_ndarray_dict(g.nodes[ntype].data))
......
......@@ -223,6 +223,17 @@ def create_heterographs(index_dtype):
g = dgl.hetero_from_relations([g_x, g_y])
return [g, g_x, g_y]
def create_heterographs2(index_dtype):
g_x = dgl.graph(([0, 1, 2], [1, 2, 3]), 'user',
'follows', index_dtype=index_dtype, restrict_format='any')
g_y = dgl.graph(([0, 2], [2, 3]), 'user', 'knows', index_dtype=index_dtype, restrict_format='csr')
g_z = dgl.bipartite(([0, 1, 3], [2, 3, 4]), 'user', 'knows', 'knowledge', index_dtype=index_dtype)
g_x.nodes['user'].data['h'] = F.randn((4, 3))
g_x.edges['follows'].data['w'] = F.randn((3, 2))
g_y.nodes['user'].data['hh'] = F.ones((4, 5))
g_y.edges['knows'].data['ww'] = F.randn((2, 10))
g = dgl.hetero_from_relations([g_x, g_y, g_z])
return [g, g_x, g_y]
def test_deserialize_old_heterograph_file():
path = os.path.join(
......@@ -251,11 +262,15 @@ def test_serialize_heterograph():
f = tempfile.NamedTemporaryFile(delete=False)
path = f.name
f.close()
g_list0 = create_heterographs("int64") + create_heterographs("int32")
g_list0 = create_heterographs2("int64") + create_heterographs2("int32")
save_graphs(path, g_list0)
g_list, _ = load_graphs(path)
assert g_list[0].idtype == F.int64
assert len(g_list[0].canonical_etypes) == 3
for i in range(len(g_list0)):
for j, etypes in enumerate(g_list0[i].canonical_etypes):
assert g_list[i].canonical_etypes[j] == etypes
assert g_list[1].restrict_format() == 'any'
assert g_list[2].restrict_format() == 'csr'
assert g_list[3].idtype == F.int32
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment