Unverified Commit 045beeba authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[DistGB] remove toindex() as torch tensor is always expected (#7146)

parent 63541c88
......@@ -2,12 +2,13 @@
import os
from .. import backend as F, utils
from .. import backend as F
from .dist_context import is_initialized
from .kvstore import get_kvstore
from .role import get_role
from .rpc import get_group_id
from .utils import totensor
def _default_init_data(shape, dtype):
......@@ -200,13 +201,11 @@ class DistTensor:
self.kvstore.delete_data(self._name)
def __getitem__(self, idx):
idx = utils.toindex(idx)
idx = idx.tousertensor()
idx = totensor(idx)
return self.kvstore.pull(name=self._name, id_tensor=idx)
def __setitem__(self, idx, val):
idx = utils.toindex(idx)
idx = idx.tousertensor()
idx = totensor(idx)
# TODO(zhengda) how do we want to support broadcast (e.g., G.ndata['h'][idx] = 1).
self.kvstore.push(name=self._name, id_tensor=idx, data_tensor=val)
......
......@@ -4,8 +4,9 @@ import pickle
from abc import ABC
import numpy as np
import torch
from .. import backend as F, utils
from .. import backend as F
from .._ffi.ndarray import empty_shared_mem
from ..base import DGLError
from ..ndarray import exist_shared_mem_array
......@@ -761,7 +762,6 @@ class RangePartitionBook(GraphPartitionBook):
def map_to_homo_nid(self, ids, ntype):
"""Map per-node-type IDs to global node IDs in the homogeneous format."""
ids = utils.toindex(ids).tousertensor()
partids = self.nid2partid(ids, ntype)
typed_max_nids = F.zerocopy_from_numpy(self._typed_max_node_ids[ntype])
end_diff = F.gather_row(typed_max_nids, partids) - ids
......@@ -772,7 +772,6 @@ class RangePartitionBook(GraphPartitionBook):
def map_to_homo_eid(self, ids, etype):
"""Map per-edge-type IDs to global edge IDs in the homoenegeous format."""
ids = utils.toindex(ids).tousertensor()
c_etype = self.to_canonical_etype(etype)
partids = self.eid2partid(ids, c_etype)
typed_max_eids = F.zerocopy_from_numpy(
......@@ -786,32 +785,28 @@ class RangePartitionBook(GraphPartitionBook):
def nid2partid(self, nids, ntype=DEFAULT_NTYPE):
"""From global node IDs to partition IDs"""
nids = utils.toindex(nids)
# [TODO][Rui] replace numpy with torch.
nids = nids.numpy()
if ntype == DEFAULT_NTYPE:
ret = np.searchsorted(
self._max_node_ids, nids.tonumpy(), side="right"
)
ret = np.searchsorted(self._max_node_ids, nids, side="right")
else:
ret = np.searchsorted(
self._typed_max_node_ids[ntype], nids.tonumpy(), side="right"
self._typed_max_node_ids[ntype], nids, side="right"
)
ret = utils.toindex(ret)
return ret.tousertensor()
return torch.from_numpy(ret)
def eid2partid(self, eids, etype=DEFAULT_ETYPE):
"""From global edge IDs to partition IDs"""
eids = utils.toindex(eids)
# [TODO][Rui] replace numpy with torch.
eids = eids.numpy()
if etype in (DEFAULT_ETYPE, DEFAULT_ETYPE[1]):
ret = np.searchsorted(
self._max_edge_ids, eids.tonumpy(), side="right"
)
ret = np.searchsorted(self._max_edge_ids, eids, side="right")
else:
c_etype = self.to_canonical_etype(etype)
ret = np.searchsorted(
self._typed_max_edge_ids[c_etype], eids.tonumpy(), side="right"
self._typed_max_edge_ids[c_etype], eids, side="right"
)
ret = utils.toindex(ret)
return ret.tousertensor()
return torch.from_numpy(ret)
def partid2nids(self, partid, ntype=DEFAULT_NTYPE):
"""From partition ID to global node IDs"""
......@@ -852,8 +847,6 @@ class RangePartitionBook(GraphPartitionBook):
getting remote tensor of nid2localnid."
)
nids = utils.toindex(nids)
nids = nids.tousertensor()
if ntype == DEFAULT_NTYPE:
start = self._max_node_ids[partid - 1] if partid > 0 else 0
else:
......@@ -870,8 +863,6 @@ class RangePartitionBook(GraphPartitionBook):
getting remote tensor of eid2localeid."
)
eids = utils.toindex(eids)
eids = eids.tousertensor()
if etype in (DEFAULT_ETYPE, DEFAULT_ETYPE[1]):
start = self._max_edge_ids[partid - 1] if partid > 0 else 0
else:
......
......@@ -14,7 +14,6 @@ from ..sampling import (
sample_neighbors as local_sample_neighbors,
)
from ..subgraph import in_subgraph as local_in_subgraph
from ..utils import toindex
from .rpc import (
recv_responses,
register_service,
......@@ -705,8 +704,6 @@ def _distributed_access(g, nodes, issue_remote_req, local_access):
"""
req_list = []
partition_book = g.get_partition_book()
if not isinstance(nodes, torch.Tensor):
nodes = toindex(nodes).tousertensor()
partition_id = partition_book.nid2partid(nodes)
local_nids = None
for pid in range(partition_book.num_partitions()):
......@@ -900,11 +897,7 @@ def sample_etype_neighbors(
), "The sampled node type {} does not exist in the input graph".format(
ntype
)
if F.is_tensor(nodes[ntype]):
typed_nodes = nodes[ntype]
else:
typed_nodes = toindex(nodes[ntype]).tousertensor()
homo_nids.append(gpb.map_to_homo_nid(typed_nodes, ntype))
homo_nids.append(gpb.map_to_homo_nid(nodes[ntype], ntype))
nodes = F.cat(homo_nids, 0)
def issue_remote_req(node_ids):
......@@ -1029,11 +1022,7 @@ def sample_neighbors(
assert (
ntype in g.ntypes
), "The sampled node type does not exist in the input graph"
if F.is_tensor(nodes[ntype]):
typed_nodes = nodes[ntype]
else:
typed_nodes = toindex(nodes[ntype]).tousertensor()
homo_nids.append(gpb.map_to_homo_nid(typed_nodes, ntype))
homo_nids.append(gpb.map_to_homo_nid(nodes[ntype], ntype))
nodes = F.cat(homo_nids, 0)
elif isinstance(nodes, dict):
assert len(nodes) == 1
......@@ -1103,7 +1092,6 @@ def _distributed_edge_access(g, edges, issue_remote_req, local_access):
"""
req_list = []
partition_book = g.get_partition_book()
edges = toindex(edges).tousertensor()
partition_id = partition_book.eid2partid(edges)
local_eids = None
reorder_idx = []
......@@ -1221,7 +1209,6 @@ def in_subgraph(g, nodes):
def _distributed_get_node_property(g, n, issue_remote_req, local_access):
req_list = []
partition_book = g.get_partition_book()
n = toindex(n).tousertensor()
partition_id = partition_book.nid2partid(n)
local_nids = None
reorder_idx = []
......@@ -1266,7 +1253,21 @@ def _distributed_get_node_property(g, n, issue_remote_req, local_access):
def in_degrees(g, v):
"""Get in-degrees"""
"""Get in-degrees
Parameters
----------
g : DistGraph
The distributed graph.
v : tensor
The node ID array.
Returns
-------
tensor
The in-degree array.
"""
def issue_remote_req(v, order_id):
return InDegreeRequest(v, order_id)
......@@ -1278,7 +1279,21 @@ def in_degrees(g, v):
def out_degrees(g, u):
"""Get out-degrees"""
"""Get out-degrees
Parameters
----------
g : DistGraph
The distributed graph.
u : tensor
The node ID array.
Returns
-------
tensor
The out-degree array.
"""
def issue_remote_req(u, order_id):
return OutDegreeRequest(u, order_id)
......
......@@ -4,7 +4,7 @@ import os
import numpy as np
from .. import backend as F, utils
from .. import backend as F
from .._ffi.ndarray import empty_shared_mem
from . import rpc
......@@ -1376,8 +1376,6 @@ class KVClient(object):
a vector storing the global data ID
"""
assert len(name) > 0, "name cannot be empty."
id_tensor = utils.toindex(id_tensor)
id_tensor = id_tensor.tousertensor()
assert F.ndim(id_tensor) == 1, "ID must be a vector."
# partition data
machine_id = self._part_policy[name].to_partid(id_tensor)
......@@ -1399,8 +1397,6 @@ class KVClient(object):
a tensor with the same row size of data ID
"""
assert len(name) > 0, "name cannot be empty."
id_tensor = utils.toindex(id_tensor)
id_tensor = id_tensor.tousertensor()
assert F.ndim(id_tensor) == 1, "ID must be a vector."
assert (
F.shape(id_tensor)[0] == F.shape(data_tensor)[0]
......@@ -1452,8 +1448,6 @@ class KVClient(object):
a data tensor with the same row size of id_tensor.
"""
assert len(name) > 0, "name cannot be empty."
id_tensor = utils.toindex(id_tensor)
id_tensor = id_tensor.tousertensor()
assert F.ndim(id_tensor) == 1, "ID must be a vector."
if self._pull_handlers[name] is default_pull_handler: # Use fast-pull
part_id = self._part_policy[name].to_partid(id_tensor)
......
......@@ -2,7 +2,7 @@
import torch as th
from .... import backend as F, utils
from .... import backend as F
from ...dist_tensor import DistTensor
......@@ -99,7 +99,7 @@ class DistEmbedding:
def __call__(self, idx, device=th.device("cpu")):
"""
node_ids : th.tensor
idx : th.tensor
Index of the embeddings to collect.
device : th.device
Target device to put the collected embeddings.
......@@ -109,7 +109,6 @@ class DistEmbedding:
Tensor
The requested node embeddings
"""
idx = utils.toindex(idx).tousertensor()
emb = self._tensor[idx].to(device, non_blocking=True)
if F.is_recording():
emb = F.attach_grad(emb)
......
""" Utility functions for distributed training."""
import torch
from ..utils import toindex
def totensor(data):
    """Convert *data* to a torch tensor.

    Tensors pass through unchanged; any other index-like input (array,
    list or slice) is routed through DGL's ``toindex`` helper and then
    materialized as a user tensor.

    Parameters
    ----------
    data : tensor, array, list or slice
        Data to be converted.

    Returns
    -------
    Tensor
        Converted tensor.
    """
    if not isinstance(data, torch.Tensor):
        data = toindex(data).tousertensor()
    return data
......@@ -57,6 +57,8 @@ def check_binary_op(key1, key2, key3, op):
dist_g.edata[key3][i:i_end],
op(dist_g.edata[key1][i:i_end], dist_g.edata[key2][i:i_end]),
)
_ = dist_g.edata[key3][F.tensor([100, 20, 10], F.int32)]
_ = dist_g.edata[key3][F.tensor([100, 20, 10], F.int64)]
@unittest.skipIf(
......
......@@ -11,6 +11,7 @@ import backend as F
import dgl
import numpy as np
import pytest
import torch
from dgl.data import CitationGraphDataset, WN18Dataset
from dgl.distributed import (
DistGraph,
......@@ -56,7 +57,9 @@ def start_sample_client(rank, tmpdir, disable_shared_mem):
dist_graph = DistGraph("test_sampling", gpb=gpb)
try:
sampled_graph = sample_neighbors(
dist_graph, [0, 10, 99, 66, 1024, 2008], 3
dist_graph,
torch.tensor([0, 10, 99, 66, 1024, 2008], dtype=dist_graph.idtype),
3,
)
except Exception as e:
print(traceback.format_exc())
......@@ -86,7 +89,10 @@ def start_sample_client_shuffle(
dgl.distributed.initialize("rpc_ip_config.txt")
dist_graph = DistGraph("test_sampling", gpb=gpb)
sampled_graph = sample_neighbors(
dist_graph, [0, 10, 99, 66, 1024, 2008], 3, use_graphbolt=use_graphbolt
dist_graph,
torch.tensor([0, 10, 99, 66, 1024, 2008], dtype=dist_graph.idtype),
3,
use_graphbolt=use_graphbolt,
)
assert (
......@@ -502,7 +508,7 @@ def start_hetero_etype_sample_client(
tmpdir,
disable_shared_mem,
fanout=3,
nodes={"n3": [0, 10, 99, 66, 124, 208]},
nodes=None,
etype_sorted=False,
use_graphbolt=False,
return_eids=False,
......@@ -590,11 +596,12 @@ def check_rpc_hetero_sampling_shuffle(
time.sleep(1)
pserver_list.append(p)
nodes = {"n3": torch.tensor([0, 10, 99, 66, 124, 208], dtype=g.idtype)}
block, gpb = start_hetero_sample_client(
0,
tmpdir,
num_server > 1,
nodes={"n3": [0, 10, 99, 66, 124, 208]},
nodes=nodes,
use_graphbolt=use_graphbolt,
return_eids=return_eids,
)
......@@ -742,12 +749,13 @@ def check_rpc_hetero_etype_sampling_shuffle(
etype_sorted = False
if graph_formats is not None:
etype_sorted = "csc" in graph_formats or "csr" in graph_formats
nodes = {"n3": torch.tensor([0, 10, 99, 66, 124, 208], dtype=g.idtype)}
block, gpb = start_hetero_etype_sample_client(
0,
tmpdir,
num_server > 1,
fanout,
nodes={"n3": [0, 10, 99, 66, 124, 208]},
nodes=nodes,
etype_sorted=etype_sorted,
use_graphbolt=use_graphbolt,
return_eids=return_eids,
......@@ -972,11 +980,12 @@ def check_rpc_bipartite_sampling_empty(
deg = get_degrees(g, orig_nids["game"], "game")
empty_nids = F.nonzero_1d(deg == 0)
nodes = {"game": empty_nids, "user": torch.tensor([1], dtype=g.idtype)}
block, _ = start_bipartite_sample_client(
0,
tmpdir,
num_server > 1,
nodes={"game": empty_nids, "user": [1]},
nodes=nodes,
use_graphbolt=use_graphbolt,
return_eids=return_eids,
)
......@@ -1032,11 +1041,12 @@ def check_rpc_bipartite_sampling_shuffle(
deg = get_degrees(g, orig_nid_map["game"], "game")
nids = F.nonzero_1d(deg > 0)
nodes = {"game": nids, "user": torch.tensor([0], dtype=g.idtype)}
block, gpb = start_bipartite_sample_client(
0,
tmpdir,
num_server > 1,
nodes={"game": nids, "user": [0]},
nodes=nodes,
use_graphbolt=use_graphbolt,
return_eids=return_eids,
)
......@@ -1111,11 +1121,12 @@ def check_rpc_bipartite_etype_sampling_empty(
deg = get_degrees(g, orig_nids["game"], "game")
empty_nids = F.nonzero_1d(deg == 0)
nodes = {"game": empty_nids, "user": torch.tensor([1], dtype=g.idtype)}
block, _ = start_bipartite_etype_sample_client(
0,
tmpdir,
num_server > 1,
nodes={"game": empty_nids, "user": [1]},
nodes=nodes,
use_graphbolt=use_graphbolt,
return_eids=return_eids,
)
......@@ -1173,12 +1184,13 @@ def check_rpc_bipartite_etype_sampling_shuffle(
fanout = 3
deg = get_degrees(g, orig_nid_map["game"], "game")
nids = F.nonzero_1d(deg > 0)
nodes = {"game": nids, "user": torch.tensor([0], dtype=g.idtype)}
block, gpb = start_bipartite_etype_sample_client(
0,
tmpdir,
num_server > 1,
fanout,
nodes={"game": nids, "user": [0]},
nodes=nodes,
use_graphbolt=use_graphbolt,
return_eids=return_eids,
)
......@@ -1383,7 +1395,11 @@ def check_standalone_sampling(tmpdir):
dist_graph = DistGraph(
"test_sampling", part_config=tmpdir / "test_sampling.json"
)
sampled_graph = sample_neighbors(dist_graph, [0, 10, 99, 66, 1024, 2008], 3)
sampled_graph = sample_neighbors(
dist_graph,
torch.tensor([0, 10, 99, 66, 1024, 2008], dtype=dist_graph.idtype),
3,
)
src, dst = sampled_graph.edges()
assert sampled_graph.num_nodes() == g.num_nodes()
......@@ -1394,13 +1410,19 @@ def check_standalone_sampling(tmpdir):
)
sampled_graph = sample_neighbors(
dist_graph, [0, 10, 99, 66, 1024, 2008], 3, prob="mask"
dist_graph,
torch.tensor([0, 10, 99, 66, 1024, 2008], dtype=dist_graph.idtype),
3,
prob="mask",
)
eid = F.asnumpy(sampled_graph.edata[dgl.EID])
assert mask[eid].all()
sampled_graph = sample_neighbors(
dist_graph, [0, 10, 99, 66, 1024, 2008], 3, prob="prob"
dist_graph,
torch.tensor([0, 10, 99, 66, 1024, 2008], dtype=dist_graph.idtype),
3,
prob="prob",
)
eid = F.asnumpy(sampled_graph.edata[dgl.EID])
assert (prob[eid] > 0).all()
......@@ -1429,7 +1451,11 @@ def check_standalone_etype_sampling(tmpdir):
dist_graph = DistGraph(
"test_sampling", part_config=tmpdir / "test_sampling.json"
)
sampled_graph = sample_etype_neighbors(dist_graph, [0, 10, 99, 66, 1023], 3)
sampled_graph = sample_etype_neighbors(
dist_graph,
torch.tensor([0, 10, 99, 66, 1023], dtype=dist_graph.idtype),
3,
)
src, dst = sampled_graph.edges()
assert sampled_graph.num_nodes() == hg.num_nodes()
......@@ -1440,13 +1466,19 @@ def check_standalone_etype_sampling(tmpdir):
)
sampled_graph = sample_etype_neighbors(
dist_graph, [0, 10, 99, 66, 1023], 3, prob="mask"
dist_graph,
torch.tensor([0, 10, 99, 66, 1023], dtype=dist_graph.idtype),
3,
prob="mask",
)
eid = F.asnumpy(sampled_graph.edata[dgl.EID])
assert mask[eid].all()
sampled_graph = sample_etype_neighbors(
dist_graph, [0, 10, 99, 66, 1023], 3, prob="prob"
dist_graph,
torch.tensor([0, 10, 99, 66, 1023], dtype=dist_graph.idtype),
3,
prob="prob",
)
eid = F.asnumpy(sampled_graph.edata[dgl.EID])
assert (prob[eid] > 0).all()
......@@ -1479,7 +1511,12 @@ def check_standalone_etype_sampling_heterograph(tmpdir):
"test_hetero_sampling", part_config=tmpdir / "test_hetero_sampling.json"
)
sampled_graph = sample_etype_neighbors(
dist_graph, [0, 1, 2, 10, 99, 66, 1023, 1024, 2700, 2701], 1
dist_graph,
torch.tensor(
[0, 1, 2, 10, 99, 66, 1023, 1024, 2700, 2701],
dtype=dist_graph.idtype,
),
1,
)
src, dst = sampled_graph.edges(etype=("paper", "cite", "paper"))
assert len(src) == 10
......@@ -1547,7 +1584,7 @@ def check_rpc_in_subgraph_shuffle(tmpdir, num_server):
time.sleep(1)
pserver_list.append(p)
nodes = [0, 10, 99, 66, 1024, 2008]
nodes = torch.tensor([0, 10, 99, 66, 1024, 2008], dtype=g.idtype)
sampled_graph = start_in_subgraph_client(0, tmpdir, num_server > 1, nodes)
for p in pserver_list:
p.join()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment