Unverified Commit a6505e86 authored by Rhett Ying, committed by GitHub

[DistGB] revert toindex() but refine tests (#7197)

parent 97b66318
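Context: the diff below swaps the short-lived totensor() helper back to DGL's utils.toindex(...).tousertensor() idiom throughout the distributed package. A minimal standalone sketch of that idiom, not part of the commit, assuming a PyTorch-backend install of DGL where dgl.utils.toindex is available:

# Sketch of the idiom this commit restores: toindex() wraps index-like input
# (list, numpy array, tensor) in an Index object; tousertensor() returns it
# as an int64 backend tensor.
import numpy as np
import torch
from dgl.utils import toindex

for raw in ([0, 2, 5], np.array([0, 2, 5], dtype="int64"), torch.tensor([0, 2, 5])):
    idx = toindex(raw).tousertensor()
    print(type(raw).__name__, "->", idx.dtype, idx)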
@@ -2,13 +2,12 @@
 import os

-from .. import backend as F
+from .. import backend as F, utils
 from .dist_context import is_initialized
 from .kvstore import get_kvstore
 from .role import get_role
 from .rpc import get_group_id
-from .utils import totensor


 def _default_init_data(shape, dtype):
@@ -201,11 +200,13 @@ class DistTensor:
         self.kvstore.delete_data(self._name)

     def __getitem__(self, idx):
-        idx = totensor(idx)
+        idx = utils.toindex(idx)
+        idx = idx.tousertensor()
         return self.kvstore.pull(name=self._name, id_tensor=idx)

     def __setitem__(self, idx, val):
-        idx = totensor(idx)
+        idx = utils.toindex(idx)
+        idx = idx.tousertensor()
         # TODO(zhengda) how do we want to support broadcast (e.g., G.ndata['h'][idx] = 1).
         self.kvstore.push(name=self._name, id_tensor=idx, data_tensor=val)
......
@@ -4,9 +4,8 @@ import pickle
 from abc import ABC

 import numpy as np
-import torch

-from .. import backend as F
+from .. import backend as F, utils
 from .._ffi.ndarray import empty_shared_mem
 from ..base import DGLError
 from ..ndarray import exist_shared_mem_array
@@ -762,6 +761,7 @@ class RangePartitionBook(GraphPartitionBook):
     def map_to_homo_nid(self, ids, ntype):
         """Map per-node-type IDs to global node IDs in the homogeneous format."""
+        ids = utils.toindex(ids).tousertensor()
         partids = self.nid2partid(ids, ntype)
         typed_max_nids = F.zerocopy_from_numpy(self._typed_max_node_ids[ntype])
         end_diff = F.gather_row(typed_max_nids, partids) - ids
@@ -772,6 +772,7 @@ class RangePartitionBook(GraphPartitionBook):
     def map_to_homo_eid(self, ids, etype):
         """Map per-edge-type IDs to global edge IDs in the homoenegeous format."""
+        ids = utils.toindex(ids).tousertensor()
         c_etype = self.to_canonical_etype(etype)
         partids = self.eid2partid(ids, c_etype)
         typed_max_eids = F.zerocopy_from_numpy(
@@ -785,28 +786,32 @@ class RangePartitionBook(GraphPartitionBook):
     def nid2partid(self, nids, ntype=DEFAULT_NTYPE):
         """From global node IDs to partition IDs"""
-        # [TODO][Rui] replace numpy with torch.
-        nids = nids.numpy()
+        nids = utils.toindex(nids)
         if ntype == DEFAULT_NTYPE:
-            ret = np.searchsorted(self._max_node_ids, nids, side="right")
+            ret = np.searchsorted(
+                self._max_node_ids, nids.tonumpy(), side="right"
+            )
         else:
             ret = np.searchsorted(
-                self._typed_max_node_ids[ntype], nids, side="right"
+                self._typed_max_node_ids[ntype], nids.tonumpy(), side="right"
             )
-        return torch.from_numpy(ret)
+        ret = utils.toindex(ret)
+        return ret.tousertensor()

     def eid2partid(self, eids, etype=DEFAULT_ETYPE):
         """From global edge IDs to partition IDs"""
-        # [TODO][Rui] replace numpy with torch.
-        eids = eids.numpy()
+        eids = utils.toindex(eids)
         if etype in (DEFAULT_ETYPE, DEFAULT_ETYPE[1]):
-            ret = np.searchsorted(self._max_edge_ids, eids, side="right")
+            ret = np.searchsorted(
+                self._max_edge_ids, eids.tonumpy(), side="right"
+            )
         else:
             c_etype = self.to_canonical_etype(etype)
             ret = np.searchsorted(
-                self._typed_max_edge_ids[c_etype], eids, side="right"
+                self._typed_max_edge_ids[c_etype], eids.tonumpy(), side="right"
             )
-        return torch.from_numpy(ret)
+        ret = utils.toindex(ret)
+        return ret.tousertensor()

     def partid2nids(self, partid, ntype=DEFAULT_NTYPE):
         """From partition ID to global node IDs"""
@@ -847,6 +852,8 @@ class RangePartitionBook(GraphPartitionBook):
                 getting remote tensor of nid2localnid."
             )
+        nids = utils.toindex(nids)
+        nids = nids.tousertensor()
         if ntype == DEFAULT_NTYPE:
             start = self._max_node_ids[partid - 1] if partid > 0 else 0
         else:
@@ -863,6 +870,8 @@ class RangePartitionBook(GraphPartitionBook):
                 getting remote tensor of eid2localeid."
             )
+        eids = utils.toindex(eids)
+        eids = eids.tousertensor()
         if etype in (DEFAULT_ETYPE, DEFAULT_ETYPE[1]):
             start = self._max_edge_ids[partid - 1] if partid > 0 else 0
         else:
......
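The restored nid2partid/eid2partid keep the numpy-based range lookup: each partition of a RangePartitionBook owns a contiguous ID range, so the owning partition of an ID is a right-sided binary search over the per-partition maximum IDs, as in np.searchsorted above. A self-contained illustration; the boundary values below are made up:

# Hypothetical cumulative upper bounds: partition 0 owns IDs [0, 100),
# partition 1 owns [100, 250), partition 2 owns [250, 400).
import numpy as np

max_node_ids = np.array([100, 250, 400])
nids = np.array([5, 99, 100, 399])
print(np.searchsorted(max_node_ids, nids, side="right"))  # [0 0 1 2]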
@@ -14,6 +14,7 @@ from ..sampling import (
     sample_neighbors as local_sample_neighbors,
 )
 from ..subgraph import in_subgraph as local_in_subgraph
+from ..utils import toindex
 from .rpc import (
     recv_responses,
     register_service,
@@ -707,6 +708,8 @@ def _distributed_access(g, nodes, issue_remote_req, local_access):
     """
     req_list = []
     partition_book = g.get_partition_book()
+    if not isinstance(nodes, torch.Tensor):
+        nodes = toindex(nodes).tousertensor()
     partition_id = partition_book.nid2partid(nodes)
     local_nids = None
     for pid in range(partition_book.num_partitions()):
@@ -900,7 +903,11 @@ def sample_etype_neighbors(
             ), "The sampled node type {} does not exist in the input graph".format(
                 ntype
             )
-            homo_nids.append(gpb.map_to_homo_nid(nodes[ntype], ntype))
+            if F.is_tensor(nodes[ntype]):
+                typed_nodes = nodes[ntype]
+            else:
+                typed_nodes = toindex(nodes[ntype]).tousertensor()
+            homo_nids.append(gpb.map_to_homo_nid(typed_nodes, ntype))
         nodes = F.cat(homo_nids, 0)

     def issue_remote_req(node_ids):
@@ -1025,7 +1032,11 @@ def sample_neighbors(
             assert (
                 ntype in g.ntypes
             ), "The sampled node type does not exist in the input graph"
-            homo_nids.append(gpb.map_to_homo_nid(nodes[ntype], ntype))
+            if F.is_tensor(nodes[ntype]):
+                typed_nodes = nodes[ntype]
+            else:
+                typed_nodes = toindex(nodes[ntype]).tousertensor()
+            homo_nids.append(gpb.map_to_homo_nid(typed_nodes, ntype))
         nodes = F.cat(homo_nids, 0)
     elif isinstance(nodes, dict):
         assert len(nodes) == 1
@@ -1095,6 +1106,7 @@ def _distributed_edge_access(g, edges, issue_remote_req, local_access):
     """
     req_list = []
     partition_book = g.get_partition_book()
+    edges = toindex(edges).tousertensor()
     partition_id = partition_book.eid2partid(edges)
     local_eids = None
     reorder_idx = []
@@ -1212,6 +1224,7 @@ def in_subgraph(g, nodes):
 def _distributed_get_node_property(g, n, issue_remote_req, local_access):
     req_list = []
     partition_book = g.get_partition_book()
+    n = toindex(n).tousertensor()
     partition_id = partition_book.nid2partid(n)
     local_nids = None
     reorder_idx = []
@@ -1256,21 +1269,7 @@ def _distributed_get_node_property(g, n, issue_remote_req, local_access):
 def in_degrees(g, v):
-    """Get in-degrees
-
-    Parameters
-    ----------
-    g : DistGraph
-        The distributed graph.
-    v : tensor
-        The node ID array.
-
-    Returns
-    -------
-    tensor
-        The in-degree array.
-    """
+    """Get in-degrees"""

     def issue_remote_req(v, order_id):
         return InDegreeRequest(v, order_id)
@@ -1282,21 +1281,7 @@ def in_degrees(g, v):
 def out_degrees(g, u):
-    """Get out-degrees
-
-    Parameters
-    ----------
-    g : DistGraph
-        The distributed graph.
-    u : tensor
-        The node ID array.
-
-    Returns
-    -------
-    tensor
-        The out-degree array.
-    """
+    """Get out-degrees"""

     def issue_remote_req(u, order_id):
         return OutDegreeRequest(u, order_id)
......
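In _distributed_access the conversion is guarded so that tensors pass through untouched and only non-tensor inputs go through toindex(). A sketch of that guard outside the distributed machinery; the helper name _normalize_nodes is made up for illustration, and dgl.utils.toindex is assumed as above:

import torch
from dgl.utils import toindex

def _normalize_nodes(nodes):
    # Tensors are returned as-is; lists/arrays are converted to int64 tensors.
    if not isinstance(nodes, torch.Tensor):
        nodes = toindex(nodes).tousertensor()
    return nodes

print(_normalize_nodes([1, 2, 3]))
print(_normalize_nodes(torch.arange(3)))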
@@ -4,7 +4,7 @@ import os
 import numpy as np

-from .. import backend as F
+from .. import backend as F, utils
 from .._ffi.ndarray import empty_shared_mem
 from . import rpc
@@ -1376,6 +1376,8 @@ class KVClient(object):
             a vector storing the global data ID
         """
         assert len(name) > 0, "name cannot be empty."
+        id_tensor = utils.toindex(id_tensor)
+        id_tensor = id_tensor.tousertensor()
         assert F.ndim(id_tensor) == 1, "ID must be a vector."
         # partition data
         machine_id = self._part_policy[name].to_partid(id_tensor)
@@ -1397,6 +1399,8 @@ class KVClient(object):
             a tensor with the same row size of data ID
         """
         assert len(name) > 0, "name cannot be empty."
+        id_tensor = utils.toindex(id_tensor)
+        id_tensor = id_tensor.tousertensor()
         assert F.ndim(id_tensor) == 1, "ID must be a vector."
         assert (
             F.shape(id_tensor)[0] == F.shape(data_tensor)[0]
@@ -1448,6 +1452,8 @@ class KVClient(object):
             a data tensor with the same row size of id_tensor.
         """
         assert len(name) > 0, "name cannot be empty."
+        id_tensor = utils.toindex(id_tensor)
+        id_tensor = id_tensor.tousertensor()
         assert F.ndim(id_tensor) == 1, "ID must be a vector."
         if self._pull_handlers[name] is default_pull_handler:  # Use fast-pull
             part_id = self._part_policy[name].to_partid(id_tensor)
......
@@ -2,7 +2,7 @@
 import torch as th

-from .... import backend as F
+from .... import backend as F, utils
 from ...dist_tensor import DistTensor
@@ -99,7 +99,7 @@ class DistEmbedding:
     def __call__(self, idx, device=th.device("cpu")):
         """
-        idx : th.tensor
+        node_ids : th.tensor
             Index of the embeddings to collect.
         device : th.device
             Target device to put the collected embeddings.
@@ -109,6 +109,7 @@ class DistEmbedding:
         Tensor
             The requested node embeddings
         """
+        idx = utils.toindex(idx).tousertensor()
         emb = self._tensor[idx].to(device, non_blocking=True)
         if F.is_recording():
             emb = F.attach_grad(emb)
......
""" Utility functions for distributed training."""
import torch
from ..utils import toindex
def totensor(data):
"""Convert the given data to a tensor.
Parameters
----------
data : tensor, array, list or slice
Data to be converted.
Returns
-------
Tensor
Converted tensor.
"""
if isinstance(data, torch.Tensor):
return data
return toindex(data).tousertensor()
@@ -57,6 +57,11 @@ def check_binary_op(key1, key2, key3, op):
             dist_g.edata[key3][i:i_end],
             op(dist_g.edata[key1][i:i_end], dist_g.edata[key2][i:i_end]),
         )
+    # Test with different index dtypes. int32 is not supported.
+    with pytest.raises(
+        dgl.utils.internal.InconsistentDtypeException,
+        match="DGL now requires the input tensor to have",
+    ):
         _ = dist_g.edata[key3][F.tensor([100, 20, 10], F.int32)]
     _ = dist_g.edata[key3][F.tensor([100, 20, 10], F.int64)]
......
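The refined test expects int32 indices to be rejected while int64 indices keep working. A rough standalone approximation, assuming the same dtype check fires when toindex is called directly, without the distributed graph the test uses:

import pytest
import torch
import dgl
from dgl.utils import toindex

# int32 indices should raise InconsistentDtypeException; int64 passes through.
with pytest.raises(
    dgl.utils.internal.InconsistentDtypeException,
    match="DGL now requires the input tensor to have",
):
    toindex(torch.tensor([100, 20, 10], dtype=torch.int32))

print(toindex(torch.tensor([100, 20, 10], dtype=torch.int64)).tousertensor())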