Unverified Commit a6505e86 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[DistGB] revert toindex() but refine tests (#7197)

parent 97b66318
......@@ -2,13 +2,12 @@
import os
from .. import backend as F
from .. import backend as F, utils
from .dist_context import is_initialized
from .kvstore import get_kvstore
from .role import get_role
from .rpc import get_group_id
from .utils import totensor
def _default_init_data(shape, dtype):
......@@ -201,11 +200,13 @@ class DistTensor:
self.kvstore.delete_data(self._name)
def __getitem__(self, idx):
    """Pull the rows of the distributed tensor at the given indices.

    Parameters
    ----------
    idx : tensor, array, list or slice
        Row indices to fetch; normalized via ``utils.toindex`` so
        non-tensor index inputs are accepted as well.

    Returns
    -------
    Tensor
        The rows stored in the KVStore for ``idx``.
    """
    # Normalize arbitrary index input to a framework tensor before the
    # KVStore pull (the stripped diff also showed the pre-revert
    # `totensor(idx)` call; this keeps only the post-revert path).
    idx = utils.toindex(idx)
    idx = idx.tousertensor()
    return self.kvstore.pull(name=self._name, id_tensor=idx)
def __setitem__(self, idx, val):
    """Push ``val`` into the rows of the distributed tensor at ``idx``.

    Parameters
    ----------
    idx : tensor, array, list or slice
        Row indices to write; normalized via ``utils.toindex``.
    val : Tensor
        Data to write; must have the same number of rows as ``idx``.
    """
    # Normalize arbitrary index input to a framework tensor (keeps only
    # the post-revert `utils.toindex` path; the stripped diff also
    # carried the removed `totensor(idx)` line).
    idx = utils.toindex(idx)
    idx = idx.tousertensor()
    # TODO(zhengda) how do we want to support broadcast (e.g., G.ndata['h'][idx] = 1).
    self.kvstore.push(name=self._name, id_tensor=idx, data_tensor=val)
......
......@@ -4,9 +4,8 @@ import pickle
from abc import ABC
import numpy as np
import torch
from .. import backend as F
from .. import backend as F, utils
from .._ffi.ndarray import empty_shared_mem
from ..base import DGLError
from ..ndarray import exist_shared_mem_array
......@@ -762,6 +761,7 @@ class RangePartitionBook(GraphPartitionBook):
def map_to_homo_nid(self, ids, ntype):
"""Map per-node-type IDs to global node IDs in the homogeneous format."""
ids = utils.toindex(ids).tousertensor()
partids = self.nid2partid(ids, ntype)
typed_max_nids = F.zerocopy_from_numpy(self._typed_max_node_ids[ntype])
end_diff = F.gather_row(typed_max_nids, partids) - ids
......@@ -772,6 +772,7 @@ class RangePartitionBook(GraphPartitionBook):
def map_to_homo_eid(self, ids, etype):
"""Map per-edge-type IDs to global edge IDs in the homoenegeous format."""
ids = utils.toindex(ids).tousertensor()
c_etype = self.to_canonical_etype(etype)
partids = self.eid2partid(ids, c_etype)
typed_max_eids = F.zerocopy_from_numpy(
......@@ -785,28 +786,32 @@ class RangePartitionBook(GraphPartitionBook):
def nid2partid(self, nids, ntype=DEFAULT_NTYPE):
    """Map global node IDs to the IDs of the partitions that hold them.

    Parameters
    ----------
    nids : tensor, array or list
        Global node IDs; converted via ``utils.toindex`` so non-tensor
        inputs are accepted as well.
    ntype : str, optional
        Node type of ``nids``. Defaults to the homogeneous node type.

    Returns
    -------
    Tensor
        Partition ID for each input node ID.
    """
    # [TODO][Rui] replace numpy with torch.
    # The stripped diff interleaved the pre-revert lines
    # (`nids = nids.numpy()` / `torch.from_numpy(ret)`) with the
    # post-revert ones; only the post-revert code is kept here.
    nids = utils.toindex(nids)
    if ntype == DEFAULT_NTYPE:
        # searchsorted over the cumulative per-partition max IDs gives
        # the index of the first partition whose range contains the ID.
        ret = np.searchsorted(
            self._max_node_ids, nids.tonumpy(), side="right"
        )
    else:
        ret = np.searchsorted(
            self._typed_max_node_ids[ntype], nids.tonumpy(), side="right"
        )
    # Convert the numpy result back to the user's tensor framework.
    ret = utils.toindex(ret)
    return ret.tousertensor()
def eid2partid(self, eids, etype=DEFAULT_ETYPE):
    """Map global edge IDs to the IDs of the partitions that hold them.

    Parameters
    ----------
    eids : tensor, array or list
        Global edge IDs; converted via ``utils.toindex`` so non-tensor
        inputs are accepted as well.
    etype : str or tuple of str, optional
        Edge type of ``eids``. Defaults to the homogeneous edge type.

    Returns
    -------
    Tensor
        Partition ID for each input edge ID.
    """
    # [TODO][Rui] replace numpy with torch.
    # The stripped diff interleaved the pre-revert lines
    # (`eids = eids.numpy()` / `torch.from_numpy(ret)`) with the
    # post-revert ones; only the post-revert code is kept here.
    eids = utils.toindex(eids)
    if etype in (DEFAULT_ETYPE, DEFAULT_ETYPE[1]):
        # searchsorted over the cumulative per-partition max IDs gives
        # the index of the first partition whose range contains the ID.
        ret = np.searchsorted(
            self._max_edge_ids, eids.tonumpy(), side="right"
        )
    else:
        c_etype = self.to_canonical_etype(etype)
        ret = np.searchsorted(
            self._typed_max_edge_ids[c_etype], eids.tonumpy(), side="right"
        )
    # Convert the numpy result back to the user's tensor framework.
    ret = utils.toindex(ret)
    return ret.tousertensor()
def partid2nids(self, partid, ntype=DEFAULT_NTYPE):
"""From partition ID to global node IDs"""
......@@ -847,6 +852,8 @@ class RangePartitionBook(GraphPartitionBook):
getting remote tensor of nid2localnid."
)
nids = utils.toindex(nids)
nids = nids.tousertensor()
if ntype == DEFAULT_NTYPE:
start = self._max_node_ids[partid - 1] if partid > 0 else 0
else:
......@@ -863,6 +870,8 @@ class RangePartitionBook(GraphPartitionBook):
getting remote tensor of eid2localeid."
)
eids = utils.toindex(eids)
eids = eids.tousertensor()
if etype in (DEFAULT_ETYPE, DEFAULT_ETYPE[1]):
start = self._max_edge_ids[partid - 1] if partid > 0 else 0
else:
......
......@@ -14,6 +14,7 @@ from ..sampling import (
sample_neighbors as local_sample_neighbors,
)
from ..subgraph import in_subgraph as local_in_subgraph
from ..utils import toindex
from .rpc import (
recv_responses,
register_service,
......@@ -707,6 +708,8 @@ def _distributed_access(g, nodes, issue_remote_req, local_access):
"""
req_list = []
partition_book = g.get_partition_book()
if not isinstance(nodes, torch.Tensor):
nodes = toindex(nodes).tousertensor()
partition_id = partition_book.nid2partid(nodes)
local_nids = None
for pid in range(partition_book.num_partitions()):
......@@ -900,7 +903,11 @@ def sample_etype_neighbors(
), "The sampled node type {} does not exist in the input graph".format(
ntype
)
homo_nids.append(gpb.map_to_homo_nid(nodes[ntype], ntype))
if F.is_tensor(nodes[ntype]):
typed_nodes = nodes[ntype]
else:
typed_nodes = toindex(nodes[ntype]).tousertensor()
homo_nids.append(gpb.map_to_homo_nid(typed_nodes, ntype))
nodes = F.cat(homo_nids, 0)
def issue_remote_req(node_ids):
......@@ -1025,7 +1032,11 @@ def sample_neighbors(
assert (
ntype in g.ntypes
), "The sampled node type does not exist in the input graph"
homo_nids.append(gpb.map_to_homo_nid(nodes[ntype], ntype))
if F.is_tensor(nodes[ntype]):
typed_nodes = nodes[ntype]
else:
typed_nodes = toindex(nodes[ntype]).tousertensor()
homo_nids.append(gpb.map_to_homo_nid(typed_nodes, ntype))
nodes = F.cat(homo_nids, 0)
elif isinstance(nodes, dict):
assert len(nodes) == 1
......@@ -1095,6 +1106,7 @@ def _distributed_edge_access(g, edges, issue_remote_req, local_access):
"""
req_list = []
partition_book = g.get_partition_book()
edges = toindex(edges).tousertensor()
partition_id = partition_book.eid2partid(edges)
local_eids = None
reorder_idx = []
......@@ -1212,6 +1224,7 @@ def in_subgraph(g, nodes):
def _distributed_get_node_property(g, n, issue_remote_req, local_access):
req_list = []
partition_book = g.get_partition_book()
n = toindex(n).tousertensor()
partition_id = partition_book.nid2partid(n)
local_nids = None
reorder_idx = []
......@@ -1256,21 +1269,7 @@ def _distributed_get_node_property(g, n, issue_remote_req, local_access):
def in_degrees(g, v):
"""Get in-degrees
Parameters
----------
g : DistGraph
The distributed graph.
v : tensor
The node ID array.
Returns
-------
tensor
The in-degree array.
"""
"""Get in-degrees"""
def issue_remote_req(v, order_id):
return InDegreeRequest(v, order_id)
......@@ -1282,21 +1281,7 @@ def in_degrees(g, v):
def out_degrees(g, u):
"""Get out-degrees
Parameters
----------
g : DistGraph
The distributed graph.
u : tensor
The node ID array.
Returns
-------
tensor
The out-degree array.
"""
"""Get out-degrees"""
def issue_remote_req(u, order_id):
return OutDegreeRequest(u, order_id)
......
......@@ -4,7 +4,7 @@ import os
import numpy as np
from .. import backend as F
from .. import backend as F, utils
from .._ffi.ndarray import empty_shared_mem
from . import rpc
......@@ -1376,6 +1376,8 @@ class KVClient(object):
a vector storing the global data ID
"""
assert len(name) > 0, "name cannot be empty."
id_tensor = utils.toindex(id_tensor)
id_tensor = id_tensor.tousertensor()
assert F.ndim(id_tensor) == 1, "ID must be a vector."
# partition data
machine_id = self._part_policy[name].to_partid(id_tensor)
......@@ -1397,6 +1399,8 @@ class KVClient(object):
a tensor with the same row size of data ID
"""
assert len(name) > 0, "name cannot be empty."
id_tensor = utils.toindex(id_tensor)
id_tensor = id_tensor.tousertensor()
assert F.ndim(id_tensor) == 1, "ID must be a vector."
assert (
F.shape(id_tensor)[0] == F.shape(data_tensor)[0]
......@@ -1448,6 +1452,8 @@ class KVClient(object):
a data tensor with the same row size of id_tensor.
"""
assert len(name) > 0, "name cannot be empty."
id_tensor = utils.toindex(id_tensor)
id_tensor = id_tensor.tousertensor()
assert F.ndim(id_tensor) == 1, "ID must be a vector."
if self._pull_handlers[name] is default_pull_handler: # Use fast-pull
part_id = self._part_policy[name].to_partid(id_tensor)
......
......@@ -2,7 +2,7 @@
import torch as th
from .... import backend as F
from .... import backend as F, utils
from ...dist_tensor import DistTensor
......@@ -99,7 +99,7 @@ class DistEmbedding:
def __call__(self, idx, device=th.device("cpu")):
"""
idx : th.tensor
node_ids : th.tensor
Index of the embeddings to collect.
device : th.device
Target device to put the collected embeddings.
......@@ -109,6 +109,7 @@ class DistEmbedding:
Tensor
The requested node embeddings
"""
idx = utils.toindex(idx).tousertensor()
emb = self._tensor[idx].to(device, non_blocking=True)
if F.is_recording():
emb = F.attach_grad(emb)
......
""" Utility functions for distributed training."""
import torch
from ..utils import toindex
def totensor(data):
    """Convert the given data to a tensor.

    Parameters
    ----------
    data : tensor, array, list or slice
        Data to be converted.

    Returns
    -------
    Tensor
        Converted tensor. A ``torch.Tensor`` input is returned
        unchanged; anything else is routed through ``toindex``.
    """
    if not isinstance(data, torch.Tensor):
        data = toindex(data).tousertensor()
    return data
......@@ -57,6 +57,11 @@ def check_binary_op(key1, key2, key3, op):
dist_g.edata[key3][i:i_end],
op(dist_g.edata[key1][i:i_end], dist_g.edata[key2][i:i_end]),
)
# Test with different index dtypes. int32 is not supported.
with pytest.raises(
dgl.utils.internal.InconsistentDtypeException,
match="DGL now requires the input tensor to have",
):
_ = dist_g.edata[key3][F.tensor([100, 20, 10], F.int32)]
_ = dist_g.edata[key3][F.tensor([100, 20, 10], F.int64)]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment