Unverified Commit cd2cf606 authored by xiang song(charlie.song)'s avatar xiang song(charlie.song) Committed by GitHub
Browse files

[Bugfix] Fix bug introduced by https://github.com/dmlc/dgl/pull/3131 (#3234)

* Fix bug

* Fix

* Fix

* upd

* trigger
parent 799c091e
...@@ -15,6 +15,7 @@ from .. import backend as F ...@@ -15,6 +15,7 @@ from .. import backend as F
from ..base import NID, EID, NTYPE, ETYPE, ALL, is_all from ..base import NID, EID, NTYPE, ETYPE, ALL, is_all
from .kvstore import KVServer, get_kvstore from .kvstore import KVServer, get_kvstore
from .._ffi.ndarray import empty_shared_mem from .._ffi.ndarray import empty_shared_mem
from ..ndarray import exist_shared_mem_array
from ..frame import infer_scheme from ..frame import infer_scheme
from .partition import load_partition, load_partition_book from .partition import load_partition, load_partition_book
from .graph_partition_book import PartitionPolicy, get_shared_mem_partition_book from .graph_partition_book import PartitionPolicy, get_shared_mem_partition_book
...@@ -76,7 +77,10 @@ def _copy_graph_to_shared_mem(g, graph_name, graph_format): ...@@ -76,7 +77,10 @@ def _copy_graph_to_shared_mem(g, graph_name, graph_format):
new_g.edata['inner_edge'] = _to_shared_mem(g.edata['inner_edge'], new_g.edata['inner_edge'] = _to_shared_mem(g.edata['inner_edge'],
_get_edata_path(graph_name, 'inner_edge')) _get_edata_path(graph_name, 'inner_edge'))
new_g.edata[EID] = _to_shared_mem(g.edata[EID], _get_edata_path(graph_name, EID)) new_g.edata[EID] = _to_shared_mem(g.edata[EID], _get_edata_path(graph_name, EID))
new_g.edata[ETYPE] = _to_shared_mem(g.edata[ETYPE], _get_edata_path(graph_name, ETYPE)) # for heterogeneous graph, we need to put ETYPE into KVStore
# for homogeneous graph, ETYPE does not exist
if ETYPE in g.edata:
new_g.edata[ETYPE] = _to_shared_mem(g.edata[ETYPE], _get_edata_path(graph_name, ETYPE))
return new_g return new_g
FIELD_DICT = {'inner_node': F.int32, # A flag indicates whether the node is inside a partition. FIELD_DICT = {'inner_node': F.int32, # A flag indicates whether the node is inside a partition.
...@@ -112,6 +116,9 @@ def _get_shared_mem_edata(g, graph_name, name): ...@@ -112,6 +116,9 @@ def _get_shared_mem_edata(g, graph_name, name):
dlpack = data.to_dlpack() dlpack = data.to_dlpack()
return F.zerocopy_from_dlpack(dlpack) return F.zerocopy_from_dlpack(dlpack)
def _exist_shared_mem_array(graph_name, name):
return exist_shared_mem_array(_get_edata_path(graph_name, name))
def _get_graph_from_shared_mem(graph_name): def _get_graph_from_shared_mem(graph_name):
''' Get the graph from the DistGraph server. ''' Get the graph from the DistGraph server.
...@@ -129,7 +136,10 @@ def _get_graph_from_shared_mem(graph_name): ...@@ -129,7 +136,10 @@ def _get_graph_from_shared_mem(graph_name):
g.edata['inner_edge'] = _get_shared_mem_edata(g, graph_name, 'inner_edge') g.edata['inner_edge'] = _get_shared_mem_edata(g, graph_name, 'inner_edge')
g.edata[EID] = _get_shared_mem_edata(g, graph_name, EID) g.edata[EID] = _get_shared_mem_edata(g, graph_name, EID)
g.edata[ETYPE] = _get_shared_mem_edata(g, graph_name, ETYPE)
# heterogeneous graph has ETYPE
if _exist_shared_mem_array(graph_name, ETYPE):
g.edata[ETYPE] = _get_shared_mem_edata(g, graph_name, ETYPE)
return g return g
NodeSpace = namedtuple('NodeSpace', ['data']) NodeSpace = namedtuple('NodeSpace', ['data'])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment