"git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "535aa3d33abe407bbccc6955cc4915ad8d9396cb"
Unverified commit 02d31974 authored by Da Zheng, committed by GitHub

[Distributed] Pytorch example of distributed GraphSage. (#1495)



* add train_dist.

* Fix sampling example.

* use distributed sampler.

* fix a bug in DistTensor.

* fix distributed training example.

* add graph partition.

* add command

* disable pytorch parallel.

* shutdown correctly.

* load diff graphs.

* add ip_config.txt.

* record timing for each step.

* use ogb

* add profiler.

* fix a bug.

* add IPs of the cluster.

* fix exit.

* support multiple clients.

* balance node types and edges.

* move code.

* remove run.sh

* Revert "support multiple clients."

* fix.

* update train_sampling.

* fix.

* fix

* remove run.sh

* update readme.

* update readme.

* use pytorch distributed.

* ensure all trainers run the same number of steps.

* Update README.md
Co-authored-by: Ubuntu <ubuntu@ip-172-31-16-250.us-west-2.compute.internal>
parent 04d4680d
## Distributed training
This is an example of training GraphSage in a distributed fashion. Training consists of four steps:
### Step 1: partition the graph
The example provides a script to partition some built-in graphs, such as Reddit and the OGB products graph.
If we want to train GraphSage on 4 machines, we need to partition the graph into 4 parts.
The partitioning script loads some functions from the parent directory, so add it to `PYTHONPATH` first:
```bash
export PYTHONPATH=$PYTHONPATH:..
```
In this example, we partition the OGB products graph into 4 parts with METIS. The partitions are balanced with respect to
the number of nodes, the number of edges, and the number of labeled nodes.
```bash
# partition graph
python3 partition_graph.py --dataset ogb-product --num_parts 4 --balance_train --balance_edges
```
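The script is a thin wrapper around `dgl.distributed.partition_graph` (see `partition_graph.py` further below). As a rough sketch, the core call it issues for the command above looks like this:
```python
import dgl
from load_graph import load_ogb

# Load ogbn-products with features, labels and train/val/test masks attached.
g, _ = load_ogb('ogbn-products')

# Partition into 4 parts with METIS. --balance_train balances the training
# nodes across partitions and --balance_edges balances the edge counts.
# The partitions and the config file data/ogb-product.json are written to 'data'.
dgl.distributed.partition_graph(g, 'ogb-product', 4, 'data',
                                balance_ntypes=g.ndata['train_mask'],
                                balance_edges=True)
```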
### Step 2: copy the partitioned data to the cluster
When copying data to the cluster, we recommend placing the partitioned data on NFS so that all worker machines
can access it.
### Step 3: run servers
We need to run one server on each machine. Before running the servers, update `ip_config.txt` (shown below) with the IP addresses of the machines in your cluster.
```bash
# run server on machine 0
python3 train_dist.py --server --graph-name ogb-product --id 0 --num-client 4 --conf_path data/ogb-product.json --ip_config ip_config.txt
# run server on machine 1
python3 train_dist.py --server --graph-name ogb-product --id 1 --num-client 4 --conf_path data/ogb-product.json --ip_config ip_config.txt
# run server on machine 2
python3 train_dist.py --server --graph-name ogb-product --id 2 --num-client 4 --conf_path data/ogb-product.json --ip_config ip_config.txt
# run server on machine 3
python3 train_dist.py --server --graph-name ogb-product --id 3 --num-client 4 --conf_path data/ogb-product.json --ip_config ip_config.txt
```
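Each server command ends up in `start_server()` in `train_dist.py` (included below). A minimal sketch of what server 0 does, with the argument values from the command above filled in:
```python
import dgl

# Server 0 loads its partition (described by data/ogb-product.json) and serves
# graph structure and node data to the 4 trainer clients listed in ip_config.txt.
serv = dgl.distributed.DistGraphServer(0, 'ip_config.txt', 4,
                                       'ogb-product', 'data/ogb-product.json')
serv.start()  # blocks and serves requests until the trainers shut the servers down
```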
### Step 4: run trainers
We run one trainer process on each machine. Here we use PyTorch distributed training, so each trainer process is started with `torch.distributed.launch`.
PyTorch distributed requires one of the trainer processes to act as the master; here we run the master process on the first machine.
```bash
# run client on machine 0
python3 -m torch.distributed.launch --nproc_per_node=1 --nnodes=4 --node_rank=0 --master_addr="172.31.16.250" --master_port=1234 train_dist.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 3 --num-client 4 --batch-size 1000 --lr 0.1
# run client on machine 1
python3 -m torch.distributed.launch --nproc_per_node=1 --nnodes=4 --node_rank=1 --master_addr="172.31.16.250" --master_port=1234 train_dist.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 3 --num-client 4 --batch-size 1000 --lr 0.1
# run client on machine 2
python3 -m torch.distributed.launch --nproc_per_node=1 --nnodes=4 --node_rank=2 --master_addr="172.31.16.250" --master_port=1234 train_dist.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 3 --num-client 4 --batch-size 1000 --lr 0.1
# run client on machine 3
python3 -m torch.distributed.launch --nproc_per_node=1 --nnodes=4 --node_rank=3 --master_addr="172.31.16.250" --master_port=1234 train_dist.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 3 --num-client 4 --batch-size 1000 --lr 0.1
```
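On the trainer side, each launched process joins the PyTorch process group, connects to the graph servers, and picks up its own share of training nodes. A minimal sketch mirroring `main()` in `train_dist.py` below:
```python
import dgl
import torch as th

# Join the process group created by torch.distributed.launch.
th.distributed.init_process_group(backend='gloo')

# Connect to the graph servers listed in ip_config.txt.
g = dgl.distributed.DistGraph('ip_config.txt', 'ogb-product')

# Each trainer works on a disjoint subset of the training nodes,
# taken from the partition it is co-located with.
train_nid = dgl.distributed.node_split(g.ndata['train_mask'],
                                       g.get_partition_book(), g.rank())
```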
ip_config.txt:
172.31.16.250 5555 1
172.31.30.135 5555 1
172.31.27.41 5555 1
172.31.30.149 5555 1
partition_graph.py:
import dgl
import numpy as np
import torch as th
import argparse
import time

from load_graph import load_reddit, load_ogb

if __name__ == '__main__':
    argparser = argparse.ArgumentParser("Partition builtin graphs")
    argparser.add_argument('--dataset', type=str, default='reddit',
                           help='datasets: reddit, ogb-product, ogb-paper100M')
    argparser.add_argument('--num_parts', type=int, default=4,
                           help='number of partitions')
    argparser.add_argument('--balance_train', action='store_true',
                           help='balance the training size in each partition.')
    argparser.add_argument('--balance_edges', action='store_true',
                           help='balance the number of edges in each partition.')
    args = argparser.parse_args()

    start = time.time()
    if args.dataset == 'reddit':
        g, _ = load_reddit()
    elif args.dataset == 'ogb-product':
        g, _ = load_ogb('ogbn-products')
    elif args.dataset == 'ogb-paper100M':
        g, _ = load_ogb('ogbn-papers100M')
    print('load {} takes {:.3f} seconds'.format(args.dataset, time.time() - start))
    print('|V|={}, |E|={}'.format(g.number_of_nodes(), g.number_of_edges()))
    print('train: {}, valid: {}, test: {}'.format(th.sum(g.ndata['train_mask']),
                                                  th.sum(g.ndata['val_mask']),
                                                  th.sum(g.ndata['test_mask'])))

    if args.balance_train:
        balance_ntypes = g.ndata['train_mask']
    else:
        balance_ntypes = None

    dgl.distributed.partition_graph(g, args.dataset, args.num_parts, 'data',
                                    balance_ntypes=balance_ntypes,
                                    balance_edges=args.balance_edges)
train_dist.py:
import os
os.environ['DGLBACKEND'] = 'pytorch'
from multiprocessing import Process
import argparse, time, math
import numpy as np
from functools import wraps
import tqdm

import dgl
from dgl import DGLGraph
from dgl.data import register_data_args, load_data
from dgl.data.utils import load_graphs
import dgl.function as fn
import dgl.nn.pytorch as dglnn

import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.multiprocessing as mp
from torch.utils.data import DataLoader
from pyinstrument import Profiler

from train_sampling import run, NeighborSampler, SAGE, compute_acc, evaluate, load_subtensor

def start_server(args):
    serv = dgl.distributed.DistGraphServer(args.id, args.ip_config, args.num_client,
                                           args.graph_name, args.conf_path)
    serv.start()

def run(args, device, data):
    # Unpack data
    train_nid, val_nid, in_feats, n_classes, g = data
    # Create sampler
    sampler = NeighborSampler(g, [int(fanout) for fanout in args.fan_out.split(',')],
                              dgl.distributed.sample_neighbors)

    # Create PyTorch DataLoader for constructing blocks
    dataloader = DataLoader(
        dataset=train_nid.numpy(),
        batch_size=args.batch_size,
        collate_fn=sampler.sample_blocks,
        shuffle=True,
        drop_last=False,
        num_workers=args.num_workers)

    # Define model and optimizer
    model = SAGE(in_feats, args.num_hidden, n_classes, args.num_layers, F.relu, args.dropout)
    model = model.to(device)
    model = th.nn.parallel.DistributedDataParallel(model)
    loss_fcn = nn.CrossEntropyLoss()
    loss_fcn = loss_fcn.to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    train_size = th.sum(g.ndata['train_mask'][0:g.number_of_nodes()])
    num_steps = int(args.num_epochs * train_size / args.batch_size / args.num_client)

    # Training loop
    iter_tput = []
    profiler = Profiler()
    profiler.start()
    epoch = 0
    while num_steps > 0:
        tic = time.time()

        sample_time = 0
        copy_time = 0
        forward_time = 0
        backward_time = 0
        update_time = 0
        num_seeds = 0
        num_inputs = 0
        start = time.time()
        # Loop over the dataloader to sample the computation dependency graph as a list of
        # blocks.
        step_time = []
        for step, blocks in enumerate(dataloader):
            tic_step = time.time()
            sample_time += tic_step - start

            # The nodes for input lies at the LHS side of the first block.
            # The nodes for output lies at the RHS side of the last block.
            input_nodes = blocks[0].srcdata[dgl.NID]
            seeds = blocks[-1].dstdata[dgl.NID]

            # Load the input features as well as output labels
            start = time.time()
            batch_inputs, batch_labels = load_subtensor(g, seeds, input_nodes, device)
            copy_time += time.time() - start

            num_seeds += len(blocks[-1].dstdata[dgl.NID])
            num_inputs += len(blocks[0].srcdata[dgl.NID])
            # Compute loss and prediction
            start = time.time()
            batch_pred = model(blocks, batch_inputs)
            loss = loss_fcn(batch_pred, batch_labels)
            forward_end = time.time()
            optimizer.zero_grad()
            loss.backward()
            compute_end = time.time()
            forward_time += forward_end - start
            backward_time += compute_end - forward_end

            # Aggregate gradients in multiple nodes.
            for param in model.parameters():
                if param.requires_grad and param.grad is not None:
                    th.distributed.all_reduce(param.grad.data,
                                              op=th.distributed.ReduceOp.SUM)
                    param.grad.data /= args.num_client

            optimizer.step()
            update_time += time.time() - compute_end

            step_t = time.time() - tic_step
            step_time.append(step_t)
            iter_tput.append(num_seeds / (step_t))
            if step % args.log_every == 0:
                acc = compute_acc(batch_pred, batch_labels)
                gpu_mem_alloc = th.cuda.max_memory_allocated() / 1000000 if th.cuda.is_available() else 0
                print('Epoch {:05d} | Step {:05d} | Loss {:.4f} | Train Acc {:.4f} | Speed (samples/sec) {:.4f} | GPU {:.1f} MiB | time {:.3f} s'.format(
                    epoch, step, loss.item(), acc.item(), np.mean(iter_tput[3:]), gpu_mem_alloc, np.sum(step_time[-args.log_every:])))
            start = time.time()

            num_steps -= 1
            # We have to ensure all trainer processes run the same number of steps.
            if num_steps == 0:
                break

        toc = time.time()
        print('Epoch Time(s): {:.4f}, sample: {:.4f}, data copy: {:.4f}, forward: {:.4f}, backward: {:.4f}, update: {:.4f}, #seeds: {}, #inputs: {}'.format(
            toc - tic, sample_time, copy_time, forward_time, backward_time, update_time, num_seeds, num_inputs))
        epoch += 1

    toc = time.time()
    print('Epoch Time(s): {:.4f}'.format(toc - tic))
    #if epoch % args.eval_every == 0 and epoch != 0:
    #    eval_acc = evaluate(model, g, g.ndata['features'], g.ndata['labels'], val_nid, args.batch_size, device)
    #    print('Eval Acc {:.4f}'.format(eval_acc))

    profiler.stop()
    print(profiler.output_text(unicode=True, color=True))
    # clean up
    g._client.barrier()
    dgl.distributed.shutdown_servers()
    dgl.distributed.finalize_client()

def main(args):
    th.distributed.init_process_group(backend='gloo')
    g = dgl.distributed.DistGraph(args.ip_config, args.graph_name)
    train_nid = dgl.distributed.node_split(g.ndata['train_mask'], g.get_partition_book(), g.rank())
    val_nid = dgl.distributed.node_split(g.ndata['val_mask'], g.get_partition_book(), g.rank())
    test_nid = dgl.distributed.node_split(g.ndata['test_mask'], g.get_partition_book(), g.rank())
    print('part {}, train: {}, val: {}, test: {}'.format(g.rank(), len(train_nid),
                                                         len(val_nid), len(test_nid)))
    device = th.device('cpu')
    n_classes = len(th.unique(g.ndata['labels'][np.arange(g.number_of_nodes())]))

    # Pack data
    in_feats = g.ndata['features'].shape[1]
    data = train_nid, val_nid, in_feats, n_classes, g
    run(args, device, data)
    print("parent ends")

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='GCN')
    register_data_args(parser)
    parser.add_argument('--server', action='store_true',
                        help='whether this is a server.')
    parser.add_argument('--graph-name', type=str, help='graph name')
    parser.add_argument('--id', type=int, help='the partition id')
    parser.add_argument('--ip_config', type=str, help='The file for IP configuration')
    parser.add_argument('--conf_path', type=str, help='The path to the partition config file')
    parser.add_argument('--num-client', type=int, help='The number of clients')
    parser.add_argument('--n-classes', type=int, help='the number of classes')
    parser.add_argument('--gpu', type=int, default=0,
                        help="GPU device ID. Use -1 for CPU training")
    parser.add_argument('--num-epochs', type=int, default=20)
    parser.add_argument('--num-hidden', type=int, default=16)
    parser.add_argument('--num-layers', type=int, default=2)
    parser.add_argument('--fan-out', type=str, default='10,25')
    parser.add_argument('--batch-size', type=int, default=1000)
    parser.add_argument('--log-every', type=int, default=20)
    parser.add_argument('--eval-every', type=int, default=5)
    parser.add_argument('--lr', type=float, default=0.003)
    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument('--num-workers', type=int, default=0,
                        help="Number of sampling processes. Use 0 for no extra process.")
    parser.add_argument('--local_rank', type=int, help='get rank of the process')

    args = parser.parse_args()
    print(args)
    if args.server:
        start_server(args)
    else:
        main(args)
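The training loop above averages gradients across trainers explicitly with `all_reduce` before each optimizer step. The pattern in isolation, as a minimal sketch that assumes the process group is already initialized and `num_trainers` equals the number of trainer processes:
```python
import torch as th

def average_gradients(model, num_trainers):
    # Sum each parameter's gradient over all trainer processes,
    # then divide by the number of trainers to get the average.
    for param in model.parameters():
        if param.requires_grad and param.grad is not None:
            th.distributed.all_reduce(param.grad.data,
                                      op=th.distributed.ReduceOp.SUM)
            param.grad.data /= num_trainers
```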
load_graph.py:
import dgl
import torch as th

def load_reddit():
    from dgl.data import RedditDataset

    # load reddit data
    data = RedditDataset(self_loop=True)
    train_mask = data.train_mask
    val_mask = data.val_mask
    features = th.Tensor(data.features)
    labels = th.LongTensor(data.labels)

    # Construct graph
    g = data.graph
    g.ndata['features'] = features
    g.ndata['labels'] = labels
    g.ndata['train_mask'] = th.LongTensor(data.train_mask)
    g.ndata['val_mask'] = th.LongTensor(data.val_mask)
    g.ndata['test_mask'] = th.LongTensor(data.test_mask)
    return g, data.num_labels

def load_ogb(name):
    from ogb.nodeproppred import DglNodePropPredDataset

    data = DglNodePropPredDataset(name=name)
    splitted_idx = data.get_idx_split()
    graph, labels = data[0]
    labels = labels[:, 0]

    graph.ndata['features'] = graph.ndata['feat']
    graph.ndata['labels'] = labels
    in_feats = graph.ndata['features'].shape[1]
    num_labels = len(th.unique(labels))

    # Find the node IDs in the training, validation, and test set.
    train_nid, val_nid, test_nid = splitted_idx['train'], splitted_idx['valid'], splitted_idx['test']
    train_mask = th.zeros((graph.number_of_nodes(),), dtype=th.int64)
    train_mask[train_nid] = 1
    val_mask = th.zeros((graph.number_of_nodes(),), dtype=th.int64)
    val_mask[val_nid] = 1
    test_mask = th.zeros((graph.number_of_nodes(),), dtype=th.int64)
    test_mask[test_nid] = 1
    graph.ndata['train_mask'] = train_mask
    graph.ndata['val_mask'] = val_mask
    graph.ndata['test_mask'] = test_mask
    return graph, len(th.unique(graph.ndata['labels']))
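Both loaders return a graph with `'features'`, `'labels'` and the three split masks attached as node data, which is what `partition_graph.py` above expects. A small usage sketch:
```python
import torch as th
from load_graph import load_ogb

# Load ogbn-products; the split masks are stored as 0/1 int64 node data.
g, n_classes = load_ogb('ogbn-products')
print('|V|={}, |E|={}, classes={}'.format(
    g.number_of_nodes(), g.number_of_edges(), n_classes))
print('train/val/test sizes:',
      int(th.sum(g.ndata['train_mask'])),
      int(th.sum(g.ndata['val_mask'])),
      int(th.sum(g.ndata['test_mask'])))
```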
train_sampling.py (diff):
@@ -16,19 +16,22 @@ from dgl.data import RedditDataset
 import tqdm
 import traceback
+from load_graph import load_reddit, load_ogb
+
 
 #### Neighbor sampler
 
 class NeighborSampler(object):
-    def __init__(self, g, fanouts):
+    def __init__(self, g, fanouts, sample_neighbors):
         self.g = g
         self.fanouts = fanouts
+        self.sample_neighbors = sample_neighbors
 
     def sample_blocks(self, seeds):
         seeds = th.LongTensor(np.asarray(seeds))
         blocks = []
         for fanout in self.fanouts:
             # For each seed node, sample ``fanout`` neighbors.
-            frontier = dgl.sampling.sample_neighbors(self.g, seeds, fanout, replace=True)
+            frontier = self.sample_neighbors(self.g, seeds, fanout, replace=True)
             # Then we compact the frontier into a bipartite graph for message passing.
             block = dgl.to_block(frontier, seeds)
             # Obtain the seed nodes for next layer.
@@ -78,6 +81,7 @@ class SAGE(nn.Module):
         Inference with the GraphSAGE model on full neighbors (i.e. without neighbor sampling).
         g : the entire graph.
         x : the input of entire node set.
+
         The inference code is written in a fashion that it could handle any number of nodes and
         layers.
         """
@@ -113,6 +117,7 @@ def prepare_mp(g):
     Explicitly materialize the CSR, CSC and COO representation of the given graph
     so that they could be shared via copy-on-write to sampler workers and GPU
     trainers.
+
     This is a workaround before full shared memory support on heterogeneous graphs.
     """
     g.in_degree(0)
@@ -125,13 +130,13 @@ def compute_acc(pred, labels):
     """
     return (th.argmax(pred, dim=1) == labels).float().sum() / len(pred)
 
-def evaluate(model, g, inputs, labels, val_mask, batch_size, device):
+def evaluate(model, g, inputs, labels, val_nid, batch_size, device):
     """
-    Evaluate the model on the validation set specified by ``val_mask``.
+    Evaluate the model on the validation set specified by ``val_nid``.
     g : The entire graph.
     inputs : The features of all the nodes.
     labels : The labels of all the nodes.
-    val_mask : A 0-1 mask indicating which nodes do we actually compute the accuracy for.
+    val_nid : the node Ids for validation.
     batch_size : Number of nodes to compute at the same time.
     device : The GPU device to evaluate on.
     """
@@ -139,27 +144,26 @@ def evaluate(model, g, inputs, labels, val_mask, batch_size, device):
     with th.no_grad():
         pred = model.inference(g, inputs, batch_size, device)
     model.train()
-    return compute_acc(pred[val_mask], labels[val_mask])
+    return compute_acc(pred[val_nid], labels[val_nid])
 
-def load_subtensor(g, labels, seeds, input_nodes, device):
+def load_subtensor(g, seeds, input_nodes, device):
     """
     Copys features and labels of a set of nodes onto GPU.
     """
     batch_inputs = g.ndata['features'][input_nodes].to(device)
-    batch_labels = labels[seeds].to(device)
+    batch_labels = g.ndata['labels'][seeds].to(device)
     return batch_inputs, batch_labels
 
 #### Entry point
 def run(args, device, data):
     # Unpack data
-    train_mask, val_mask, in_feats, labels, n_classes, g = data
-    train_nid = th.LongTensor(np.nonzero(train_mask)[0])
-    val_nid = th.LongTensor(np.nonzero(val_mask)[0])
-    train_mask = th.BoolTensor(train_mask)
-    val_mask = th.BoolTensor(val_mask)
+    train_mask, val_mask, in_feats, n_classes, g = data
+    train_nid = th.nonzero(train_mask, as_tuple=True)[0]
+    val_nid = th.nonzero(val_mask, as_tuple=True)[0]
 
     # Create sampler
-    sampler = NeighborSampler(g, [int(fanout) for fanout in args.fan_out.split(',')])
+    sampler = NeighborSampler(g, [int(fanout) for fanout in args.fan_out.split(',')],
+                              dgl.sampling.sample_neighbors)
 
     # Create PyTorch DataLoader for constructing blocks
     dataloader = DataLoader(
@@ -194,7 +198,7 @@ def run(args, device, data):
         seeds = blocks[-1].dstdata[dgl.NID]
 
         # Load the input features as well as output labels
-        batch_inputs, batch_labels = load_subtensor(g, labels, seeds, input_nodes, device)
+        batch_inputs, batch_labels = load_subtensor(g, seeds, input_nodes, device)
 
         # Compute loss and prediction
         batch_pred = model(blocks, batch_inputs)
@@ -215,7 +219,7 @@ def run(args, device, data):
         if epoch >= 5:
             avg += toc - tic
         if epoch % args.eval_every == 0 and epoch != 0:
-            eval_acc = evaluate(model, g, g.ndata['features'], labels, val_mask, args.batch_size, device)
+            eval_acc = evaluate(model, g, g.ndata['features'], g.ndata['labels'], val_nid, args.batch_size, device)
             print('Eval Acc {:.4f}'.format(eval_acc))
 
     print('Avg epoch time: {}'.format(avg / (epoch - 4)))
@@ -224,6 +228,7 @@ if __name__ == '__main__':
     argparser = argparse.ArgumentParser("multi-gpu training")
     argparser.add_argument('--gpu', type=int, default=0,
                            help="GPU device ID. Use -1 for CPU training")
+    argparser.add_argument('--dataset', type=str, default='reddit')
     argparser.add_argument('--num-epochs', type=int, default=20)
     argparser.add_argument('--num-hidden', type=int, default=16)
     argparser.add_argument('--num-layers', type=int, default=2)
@@ -242,19 +247,18 @@ if __name__ == '__main__':
     else:
         device = th.device('cpu')
 
-    # load reddit data
-    data = RedditDataset(self_loop=True)
-    train_mask = data.train_mask
-    val_mask = data.val_mask
-    features = th.Tensor(data.features)
-    in_feats = features.shape[1]
-    labels = th.LongTensor(data.labels)
-    n_classes = data.num_labels
-    # Construct graph
-    g = dgl.graph(data.graph.all_edges())
-    g.ndata['features'] = features
+    if args.dataset == 'reddit':
+        g, n_classes = load_reddit()
+    elif args.dataset == 'ogb-product':
+        g, n_classes = load_ogb('ogbn-products')
+    else:
+        raise Exception('unknown dataset')
+    g = dgl.as_heterograph(g)
+    in_feats = g.ndata['features'].shape[1]
+    train_mask = g.ndata['train_mask']
+    val_mask = g.ndata['val_mask']
     prepare_mp(g)
     # Pack data
-    data = train_mask, val_mask, in_feats, labels, n_classes, g
+    data = train_mask, val_mask, in_feats, n_classes, g
 
     run(args, device, data)
DistTensor / DistGraph (dgl.distributed) diff:
@@ -126,9 +126,13 @@ class DistTensor:
         self._dtype = dtype
 
     def __getitem__(self, idx):
+        idx = utils.toindex(idx)
+        idx = idx.tousertensor()
         return self.kvstore.pull(name=self.name, id_tensor=idx)
 
     def __setitem__(self, idx, val):
+        idx = utils.toindex(idx)
+        idx = idx.tousertensor()
         # TODO(zhengda) how do we want to support broadcast (e.g., G.ndata['h'][idx] = 1).
         self.kvstore.push(name=self.name, id_tensor=idx, data_tensor=val)
@@ -490,7 +494,10 @@ class DistGraph:
         int
             The rank of the current graph store.
         '''
-        return self._client.client_id
+        if self._g is None:
+            return self._client.client_id
+        else:
+            return self._gpb.partid
 
     def get_partition_book(self):
         """Get the partition information.
......