"git@developer.sourcefind.cn:OpenDAS/bitsandbytes.git" did not exist on "7facedda38da928843e9ed0de1810d45ce1b9224"
Unverified Commit 85e660cb authored by Zihao Ye's avatar Zihao Ye Committed by GitHub
Browse files

[hotfix] Fix #1682 (#1683)

* upd

* upd to_hetero
parent 421763fb
...@@ -683,11 +683,6 @@ def to_hetero(G, ntypes, etypes, ntype_field=NTYPE, etype_field=ETYPE, ...@@ -683,11 +683,6 @@ def to_hetero(G, ntypes, etypes, ntype_field=NTYPE, etype_field=ETYPE,
ntypes, ntype_count)}) ntypes, ntype_count)})
ntype2ngrp = {ntype : node_groups[ntid] for ntid, ntype in enumerate(ntypes)} ntype2ngrp = {ntype : node_groups[ntid] for ntid, ntype in enumerate(ntypes)}
for ntid, ntype in enumerate(hg.ntypes):
hg._node_frames[ntid][NID] = F.tensor(ntype2ngrp[ntype])
for etid in range(len(hg.canonical_etypes)):
hg._edge_frames[etid][EID] = F.tensor(edge_groups[etid])
# features # features
for key, data in G.ndata.items(): for key, data in G.ndata.items():
...@@ -699,6 +694,12 @@ def to_hetero(G, ntypes, etypes, ntype_field=NTYPE, etype_field=ETYPE, ...@@ -699,6 +694,12 @@ def to_hetero(G, ntypes, etypes, ntype_field=NTYPE, etype_field=ETYPE,
rows = F.copy_to(F.tensor(edge_groups[etid]), F.context(data)) rows = F.copy_to(F.tensor(edge_groups[etid]), F.context(data))
hg._edge_frames[etid][key] = F.gather_row(data, rows) hg._edge_frames[etid][key] = F.gather_row(data, rows)
for ntid, ntype in enumerate(hg.ntypes):
hg._node_frames[ntid][NID] = F.tensor(ntype2ngrp[ntype])
for etid in range(len(hg.canonical_etypes)):
hg._edge_frames[etid][EID] = F.tensor(edge_groups[etid])
return hg return hg
def to_homo(G): def to_homo(G):
...@@ -766,12 +767,8 @@ def to_homo(G): ...@@ -766,12 +767,8 @@ def to_homo(G):
retg = graph((F.cat(srcs, 0), F.cat(dsts, 0)), num_nodes=total_num_nodes, retg = graph((F.cat(srcs, 0), F.cat(dsts, 0)), num_nodes=total_num_nodes,
validate=False, index_dtype=G._idtype_str) validate=False, index_dtype=G._idtype_str)
retg.ndata[NTYPE] = F.cat(ntype_ids, 0)
retg.ndata[NID] = F.cat(nids, 0)
retg.edata[ETYPE] = F.cat(etype_ids, 0)
retg.edata[EID] = F.cat(eids, 0)
# features # copy features
comb_nf = combine_frames(G._node_frames, range(len(G.ntypes))) comb_nf = combine_frames(G._node_frames, range(len(G.ntypes)))
comb_ef = combine_frames(G._edge_frames, range(len(G.etypes))) comb_ef = combine_frames(G._edge_frames, range(len(G.etypes)))
if comb_nf is not None: if comb_nf is not None:
...@@ -779,6 +776,12 @@ def to_homo(G): ...@@ -779,6 +776,12 @@ def to_homo(G):
if comb_ef is not None: if comb_ef is not None:
retg.edata.update(comb_ef) retg.edata.update(comb_ef)
# assign node type and id mapping field.
retg.ndata[NTYPE] = F.cat(ntype_ids, 0)
retg.ndata[NID] = F.cat(nids, 0)
retg.edata[ETYPE] = F.cat(etype_ids, 0)
retg.edata[EID] = F.cat(eids, 0)
return retg return retg
############################################################ ############################################################
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment