"...sampling/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "391f513e58a05d60f2ca177280ccc79d8216b69a"
Unverified Commit e5b10b10 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Performance] fixing performance on heterogeneous graph creation in specific cases (#1506)

* [Performance] fixing performance on heterogeneous graph creation in specific cases

* fix for mxnet
parent 9b34b1c2
...@@ -448,15 +448,19 @@ def heterograph(data_dict, num_nodes_dict=None, index_dtype='int64'): ...@@ -448,15 +448,19 @@ def heterograph(data_dict, num_nodes_dict=None, index_dtype='int64'):
num_nodes_dict = defaultdict(int) num_nodes_dict = defaultdict(int)
for (srctype, etype, dsttype), data in data_dict.items(): for (srctype, etype, dsttype), data in data_dict.items():
if isinstance(data, tuple): if isinstance(data, tuple):
nsrc = (max(data[0]) + 1) if len(data[0]) > 0 else 0 src = utils.toindex(data[0]).tonumpy()
ndst = (max(data[1]) + 1) if len(data[1]) > 0 else 0 dst = utils.toindex(data[1]).tonumpy()
nsrc = (src.max() + 1) if len(src) > 0 else 0
ndst = (dst.max() + 1) if len(dst) > 0 else 0
elif isinstance(data, list): elif isinstance(data, list):
if len(data) == 0: if len(data) == 0:
nsrc = ndst = 0 nsrc = ndst = 0
else: else:
src, dst = zip(*data) src, dst = zip(*data)
nsrc = max(src) + 1 src = utils.toindex(src).tonumpy()
ndst = max(dst) + 1 dst = utils.toindex(dst).tonumpy()
nsrc = src.max() + 1
ndst = dst.max() + 1
elif isinstance(data, sp.sparse.spmatrix): elif isinstance(data, sp.sparse.spmatrix):
nsrc = data.shape[0] nsrc = data.shape[0]
ndst = data.shape[1] ndst = data.shape[1]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment