Unverified Commit cd3fa030 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[Dist] init node/edge data store for Node/EdgeDataView in appropriate place (#4906)

* [Dist] instantiate NodeDataView in lazy mode

* fix test failure

* init node/edge data store at the very beginning

* fix test failures

* refine comment

* add more tests
parent 16eba6e8
......@@ -13,7 +13,7 @@ from ..convert import graph as dgl_graph
from ..transforms import compact_graphs
from .. import heterograph_index
from .. import backend as F
from ..base import NID, EID, ETYPE, ALL, is_all
from ..base import NID, EID, ETYPE, ALL, is_all, DGLError
from .kvstore import KVServer, get_kvstore
from .._ffi.ndarray import empty_shared_mem
from ..ndarray import exist_shared_mem_array
......@@ -176,24 +176,12 @@ class NodeDataView(MutableMapping):
def __init__(self, g, ntype=None):
self._graph = g
# When this is created, the server may already load node data. We need to
# initialize the node data in advance.
names = g._get_ndata_names(ntype)
if ntype is None:
if ntype is None or len(g.ntypes) == 1:
self._data = g._ndata_store
else:
if ntype in g._ndata_store:
self._data = g._ndata_store[ntype]
else:
self._data = {}
g._ndata_store[ntype] = self._data
for name in names:
assert name.is_node()
policy = PartitionPolicy(name.policy_str, g.get_partition_book())
dtype, shape, _ = g._client.get_data_meta(str(name))
# We create a wrapper on the existing tensor in the kvstore.
self._data[name.get_name()] = DistTensor(shape, dtype, name.get_name(),
part_policy=policy, attach=False)
if ntype not in g.ntypes:
raise DGLError(f"Node type {ntype} does not exist.")
self._data = g._ndata_store[ntype]
def _get_names(self):
return list(self._data.keys())
......@@ -230,24 +218,11 @@ class EdgeDataView(MutableMapping):
def __init__(self, g, etype=None):
self._graph = g
# When this is created, the server may already load edge data. We need to
# initialize the edge data in advance.
names = g._get_edata_names(etype)
if etype is None:
if etype is None or len(g.canonical_etypes) == 1:
self._data = g._edata_store
else:
if etype in g._edata_store:
self._data = g._edata_store[etype]
else:
self._data = {}
g._edata_store[etype] = self._data
for name in names:
assert name.is_edge()
policy = PartitionPolicy(name.policy_str, g.get_partition_book())
dtype, shape, _ = g._client.get_data_meta(str(name))
# We create a wrapper on the existing tensor in the kvstore.
self._data[name.get_name()] = DistTensor(shape, dtype, name.get_name(),
part_policy=policy, attach=False)
c_etype = g.to_canonical_etype(etype)
self._data = g._edata_store[c_etype]
def _get_names(self):
return list(self._data.keys())
......@@ -520,10 +495,8 @@ class DistGraph:
rpc.recv_response()
self._client.barrier()
self._ndata_store = {}
self._edata_store = {}
self._ndata = NodeDataView(self)
self._edata = EdgeDataView(self)
self._init_ndata_store()
self._init_edata_store()
self._num_nodes = 0
self._num_edges = 0
......@@ -545,6 +518,48 @@ class DistGraph:
self._gpb = gpb
self._client.map_shared_data(self._gpb)
def _init_ndata_store(self):
'''Initialize node data store.'''
self._ndata_store = {}
for ntype in self.ntypes:
names = self._get_ndata_names(ntype)
data = {}
for name in names:
assert name.is_node()
policy = PartitionPolicy(name.policy_str,
self.get_partition_book()
)
dtype, shape, _ = self._client.get_data_meta(str(name))
# We create a wrapper on the existing tensor in the kvstore.
data[name.get_name()] = DistTensor(shape, dtype,
name.get_name(), part_policy=policy, attach=False
)
if len(self.ntypes) == 1:
self._ndata_store = data
else:
self._ndata_store[ntype] = data
def _init_edata_store(self):
'''Initialize edge data store.'''
self._edata_store = {}
for etype in self.canonical_etypes:
names = self._get_edata_names(etype)
data = {}
for name in names:
assert name.is_edge()
policy = PartitionPolicy(name.policy_str,
self.get_partition_book()
)
dtype, shape, _ = self._client.get_data_meta(str(name))
# We create a wrapper on the existing tensor in the kvstore.
data[name.get_name()] = DistTensor(shape, dtype,
name.get_name(), part_policy=policy, attach=False
)
if len(self.canonical_etypes) == 1:
self._edata_store = data
else:
self._edata_store[etype] = data
def __getstate__(self):
return self.graph_name, self._gpb
......@@ -552,10 +567,8 @@ class DistGraph:
self.graph_name, gpb = state
self._init(gpb)
self._ndata_store = {}
self._edata_store = {}
self._ndata = NodeDataView(self)
self._edata = EdgeDataView(self)
self._init_ndata_store()
self._init_edata_store()
self._num_nodes = 0
self._num_edges = 0
for part_md in self._gpb.metadata():
......@@ -600,7 +613,7 @@ class DistGraph:
The data view in the distributed graph storage.
"""
assert len(self.ntypes) == 1, "ndata only works for a graph with one node type."
return self._ndata
return NodeDataView(self)
@property
def edata(self):
......@@ -612,7 +625,7 @@ class DistGraph:
The data view in the distributed graph storage.
"""
assert len(self.etypes) == 1, "edata only works for a graph with one edge type."
return self._edata
return EdgeDataView(self)
@property
def idtype(self):
......
......@@ -794,7 +794,7 @@ class BasicPartitionBook(GraphPartitionBook):
DEFAULT_ETYPE,
DEFAULT_ETYPE[1],
), "Base partition book only supports homogeneous graph."
return self.canonical_etypes
return self.canonical_etypes[0]
class RangePartitionBook(GraphPartitionBook):
......
......@@ -58,6 +58,7 @@ class KVClient(object):
def delete_data(self, name):
'''delete the data'''
del self._data[name]
self._gdata_name_list.remove(name)
def data_name_list(self):
'''get the names of all data'''
......
......@@ -398,7 +398,12 @@ def check_dist_graph(g, num_clients, num_nodes, num_edges):
test3 = dgl.distributed.DistTensor(
new_shape, F.float32, "test3", init_func=rand_init
)
test3_name = test3.kvstore_key
assert test3_name in g._client.data_name_list()
assert test3_name in g._client.gdata_name_list()
del test3
assert test3_name not in g._client.data_name_list()
assert test3_name not in g._client.gdata_name_list()
test3 = dgl.distributed.DistTensor(
(g.number_of_nodes(), 3), F.float32, "test3"
)
......@@ -697,12 +702,20 @@ def create_random_hetero():
)
edges[etype] = (arr.row, arr.col)
g = dgl.heterograph(edges, num_nodes)
g.nodes["n1"].data["feat"] = F.unsqueeze(
F.arange(0, g.number_of_nodes("n1")), 1
)
g.edges["r1"].data["feat"] = F.unsqueeze(
F.arange(0, g.number_of_edges("r1")), 1
)
# assign ndata & edata.
# data with same name as ntype/etype is assigned on purpose to verify
# such same names can be correctly handled in DistGraph. See more details
# in issue #4887 and #4463 on github.
ntype = 'n1'
for name in ['feat', ntype]:
g.nodes[ntype].data[name] = F.unsqueeze(
F.arange(0, g.num_nodes(ntype)), 1
)
etype = 'r1'
for name in ['feat', etype]:
g.edges[etype].data[name] = F.unsqueeze(
F.arange(0, g.num_edges(etype)), 1
)
return g
......@@ -723,22 +736,40 @@ def check_dist_graph_hetero(g, num_clients, num_nodes, num_edges):
assert g.number_of_edges() == sum([num_edges[etype] for etype in num_edges])
# Test reading node data
nids = F.arange(0, int(g.number_of_nodes("n1") / 2))
feats1 = g.nodes["n1"].data["feat"][nids]
feats = F.squeeze(feats1, 1)
assert np.all(F.asnumpy(feats == nids))
ntype = 'n1'
nids = F.arange(0, g.num_nodes(ntype) // 2)
for name in ['feat', ntype]:
data = g.nodes[ntype].data[name][nids]
data = F.squeeze(data, 1)
assert np.all(F.asnumpy(data == nids))
assert len(g.nodes['n2'].data) == 0
expect_except = False
try:
g.nodes['xxx'].data['x']
except dgl.DGLError:
expect_except = True
assert expect_except
# Test reading edge data
eids = F.arange(0, int(g.number_of_edges("r1") / 2))
# access via etype
feats = g.edges["r1"].data["feat"][eids]
feats = F.squeeze(feats, 1)
assert np.all(F.asnumpy(feats == eids))
# access via canonical etype
c_etype = g.to_canonical_etype("r1")
feats = g.edges[c_etype].data["feat"][eids]
feats = F.squeeze(feats, 1)
assert np.all(F.asnumpy(feats == eids))
etype = 'r1'
eids = F.arange(0, g.num_edges(etype) // 2)
for name in ['feat', etype]:
# access via etype
data = g.edges[etype].data[name][eids]
data = F.squeeze(data, 1)
assert np.all(F.asnumpy(data == eids))
# access via canonical etype
c_etype = g.to_canonical_etype(etype)
data = g.edges[c_etype].data[name][eids]
data = F.squeeze(data, 1)
assert np.all(F.asnumpy(data == eids))
assert len(g.edges['r2'].data) == 0
expect_except = False
try:
g.edges['xxx'].data['x']
except dgl.DGLError:
expect_except = True
assert expect_except
# Test edge_subgraph
sg = g.edge_subgraph({"r1": eids})
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment