"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "5f0532daea1f4bdf5e2c511361b484cdb002b50b"
Unverified commit 56bbf9cb authored by xiang song (charlie.song), committed by GitHub

[Distributed][Example] Distributed training for unsupervised graphsage (#1853)



* Standalone can run

* fix

* Add save

* Fix

* Fix

* Fix

* Fix

* debug

* test

* test

* Fix

* Fix

* log

* Fix

* fix

* Profile

* auto sync grad

* update

* add test for unsupervised dist training

* upd

* Fix lr

* Fix update

* sync

* fix

* Revert "fix"

This reverts commit d5caa7398b36125f6d6e2c742a95c6ff4298c9e9.

* Fix

* unsupervised

* Fix

* remove debug

* Add test case for dist_graph find_edges()

* Fix

* skip tensorflow test for find_edges

* Update readme

* remove some test

* upd

* Update partition_graph.py
Co-authored-by: Ubuntu <ubuntu@ip-172-31-68-185.ec2.internal>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-24-210.ec2.internal>
Co-authored-by: Da Zheng <zhengda1936@gmail.com>
parent 34a067ea
@@ -15,7 +15,6 @@ export PYTHONPATH=$PYTHONPATH:..
In this example, we partition the OGB product graph into 4 parts with Metis. The partitions are balanced with respect to
the number of nodes, the number of edges and the number of labelled nodes.
```bash
# partition graph
python3 partition_graph.py --dataset ogb-product --num_parts 4 --balance_train --balance_edges
```
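Under the hood, `partition_graph.py` wraps DGL's partitioning API. A rough sketch of the equivalent call (this mirrors the `dgl.distributed.partition_graph` usage in the tests below; the `--balance_train`/`--balance_edges` flags correspond to its `balance_ntypes`/`balance_edges` arguments, and the OGB loading shown here is an assumption about what the script does internally):
```python
import dgl
import torch as th
from ogb.nodeproppred import DglNodePropPredDataset

# Load ogbn-products and mark the training nodes (partition_graph.py
# derives this mask from the official OGB split in a similar way).
data = DglNodePropPredDataset(name='ogbn-products')
g, _ = data[0]
train_mask = th.zeros(g.number_of_nodes(), dtype=th.bool)
train_mask[data.get_idx_split()['train']] = True
g.ndata['train_mask'] = train_mask

# Partition into 4 parts with Metis, balancing the partitions with respect
# to the training nodes and to the number of edges per partition.
dgl.distributed.partition_graph(g, 'ogb-product', 4, 'data',
                                part_method='metis',
                                balance_ntypes=g.ndata['train_mask'],
                                balance_edges=True)
```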
@@ -50,6 +49,17 @@ python3 ~/dgl/tools/launch.py \
"python3 train_dist.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 30 --batch-size 1000"
```
To run unsupervised training:
```bash
python3 ~/dgl/tools/launch.py \
--workspace ~/dgl/examples/pytorch/graphsage/experimental \
--num_client 4 \
--conf_path data/ogb-product.json \
--ip_config ip_config.txt \
"python3 train_dist_unsupervised.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 3 --batch-size 1000 --num-client 4"
```
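The `ip_config.txt` passed to `launch.py` lists the machines of the cluster, one per line. A hedged sketch of writing such a file programmatically, following the `'{} 1\n'` pattern the tests in this PR use for their `rpc_ip_config.txt` (the addresses are placeholders, and the exact per-line fields expected by `launch.py` should be checked against the launch script itself):
```python
# Placeholder addresses; replace them with the real machines of the cluster.
machines = ['172.31.0.1', '172.31.0.2', '172.31.0.3', '172.31.0.4']
with open('ip_config.txt', 'w') as f:
    for addr in machines:
        # One machine per line: its address followed by the number of
        # servers to run on it (1 here), mirroring the test files below.
        f.write('{} 1\n'.format(addr))
```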
## Distributed code runs in the standalone mode
The standalone mode is mainly used for development and testing. The procedure to run the code is much simpler.
@@ -63,8 +73,16 @@ python3 partition_graph.py --dataset ogb-product --num_parts 1
### Step 2: run the training script
To run supervised training:
```bash
python3 train_dist.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 3 --batch-size 1000 --conf_path data/ogb-product.json --standalone
```
To run unsupervised training:
```bash
python3 train_dist_unsupervised.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 3 --batch-size 1000 --conf_path data/ogb-product.json --standalone
```
Note: please ensure that all environment variables shown above are unset if they were set for testing distributed training.
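Either way, the script saves the learned node embeddings to `emb.pt` (see the `run()` function below). A minimal sketch of evaluating them offline with the same logistic-regression protocol as the script's `compute_acc` (the `.npy` paths are hypothetical placeholders for wherever the labels and splits are stored):
```python
import numpy as np
import torch as th
import sklearn.linear_model as lm
import sklearn.metrics as skm

# Load the embeddings saved by train_dist_unsupervised.py.
emb = th.load('emb.pt').numpy()
# Standardize each embedding dimension, as compute_acc does before fitting.
emb = (emb - emb.mean(0, keepdims=True)) / emb.std(0, keepdims=True)

# Hypothetical paths: labels and split node IDs saved as numpy arrays.
labels = np.load('labels.npy')
train_nids = np.load('train_nids.npy')
test_nids = np.load('test_nids.npy')

# Fit a logistic-regression classifier on the training nodes and report test accuracy.
lr = lm.LogisticRegression(multi_class='multinomial', max_iter=10000)
lr.fit(emb[train_nids], labels[train_nids])
pred = lr.predict(emb)
print('test acc: {:.4f}'.format(skm.accuracy_score(labels[test_nids], pred[test_nids])))
```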
import os
os.environ['DGLBACKEND'] = 'pytorch'
from multiprocessing import Process
import argparse, time, math
import numpy as np
from functools import wraps
import tqdm
import sklearn.linear_model as lm
import sklearn.metrics as skm
import dgl
from dgl import DGLGraph
from dgl.data import register_data_args, load_data
from dgl.data.utils import load_graphs
import dgl.function as fn
import dgl.nn.pytorch as dglnn
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.multiprocessing as mp
from torch.utils.data import DataLoader
#from pyinstrument import Profiler
from train_sampling import SAGE
class NegativeSampler(object):
    def __init__(self, g, neg_nseeds):
        self.neg_nseeds = neg_nseeds

    def __call__(self, num_samples):
        # select local neg nodes as seeds
        return self.neg_nseeds[th.randint(self.neg_nseeds.shape[0], (num_samples,))]
class NeighborSampler(object):
    def __init__(self, g, fanouts, neg_nseeds, sample_neighbors, num_negs, remove_edge):
        self.g = g
        self.fanouts = fanouts
        self.sample_neighbors = sample_neighbors
        self.neg_sampler = NegativeSampler(g, neg_nseeds)
        self.num_negs = num_negs
        self.remove_edge = remove_edge

    def sample_blocks(self, seed_edges):
        n_edges = len(seed_edges)
        seed_edges = th.LongTensor(np.asarray(seed_edges))
        heads, tails = self.g.find_edges(seed_edges)
        # Sample num_negs negative tails per edge and repeat each head accordingly.
        neg_tails = self.neg_sampler(self.num_negs * n_edges)
        neg_heads = heads.view(-1, 1).expand(n_edges, self.num_negs).flatten()

        # Maintain the correspondence between heads, tails and negative tails as two
        # graphs.
        # pos_graph contains the correspondence between each head and its positive tail.
        # neg_graph contains the correspondence between each head and its negative tails.
        # Both pos_graph and neg_graph are first constructed with the same node space as
        # the original graph. Then they are compacted together with dgl.compact_graphs.
        pos_graph = dgl.graph((heads, tails), num_nodes=self.g.number_of_nodes())
        neg_graph = dgl.graph((neg_heads, neg_tails), num_nodes=self.g.number_of_nodes())
        pos_graph, neg_graph = dgl.compact_graphs([pos_graph, neg_graph])

        seeds = pos_graph.ndata[dgl.NID]
        blocks = []
        for fanout in self.fanouts:
            # For each seed node, sample ``fanout`` neighbors.
            frontier = self.sample_neighbors(self.g, seeds, fanout, replace=True)
            if self.remove_edge:
                # Remove all edges between heads and tails, as well as heads and neg_tails.
                _, _, edge_ids = frontier.edge_ids(
                    th.cat([heads, tails, neg_heads, neg_tails]),
                    th.cat([tails, heads, neg_tails, neg_heads]),
                    return_uv=True)
                frontier = dgl.remove_edges(frontier, edge_ids)
            # Then we compact the frontier into a bipartite graph for message passing.
            block = dgl.to_block(frontier, seeds)
            # Obtain the seed nodes for the next layer.
            seeds = block.srcdata[dgl.NID]
            blocks.insert(0, block)
        # Pre-generate the CSR format so that it can be used in training directly.
        return pos_graph, neg_graph, blocks
class PosNeighborSampler(object):
    def __init__(self, g, fanouts, sample_neighbors):
        self.g = g
        self.fanouts = fanouts
        self.sample_neighbors = sample_neighbors

    def sample_blocks(self, seeds):
        seeds = th.LongTensor(np.asarray(seeds))
        blocks = []
        for fanout in self.fanouts:
            # For each seed node, sample ``fanout`` neighbors.
            frontier = self.sample_neighbors(self.g, seeds, fanout, replace=True)
            # Then we compact the frontier into a bipartite graph for message passing.
            block = dgl.to_block(frontier, seeds)
            # Obtain the seed nodes for the next layer.
            seeds = block.srcdata[dgl.NID]
            blocks.insert(0, block)
        return blocks
class DistSAGE(SAGE):
    def __init__(self, in_feats, n_hidden, n_classes, n_layers,
                 activation, dropout):
        super(DistSAGE, self).__init__(in_feats, n_hidden, n_classes, n_layers,
                                       activation, dropout)

    def inference(self, g, x, batch_size, device):
        """
        Inference with the GraphSAGE model on full neighbors (i.e. without neighbor sampling).
        g : the entire graph.
        x : the input of the entire node set.
        The inference code is written in a fashion that it could handle any number of nodes and
        layers.
        """
        # During inference with sampling, multi-layer blocks are very inefficient because
        # lots of computations in the first few layers are repeated.
        # Therefore, we compute the representation of all nodes layer by layer. The nodes
        # on each layer are of course split in batches.
        # TODO: can we standardize this?
        nodes = dgl.distributed.node_split(np.arange(g.number_of_nodes()),
                                           g.get_partition_book(), force_even=True)
        y = dgl.distributed.DistTensor(g, (g.number_of_nodes(), self.n_hidden), th.float32, 'h',
                                       persistent=True)
        for l, layer in enumerate(self.layers):
            if l == len(self.layers) - 1:
                y = dgl.distributed.DistTensor(g, (g.number_of_nodes(), self.n_classes),
                                               th.float32, 'h_last', persistent=True)

            sampler = PosNeighborSampler(g, [-1], dgl.distributed.sample_neighbors)
            print('|V|={}, eval batch size: {}'.format(g.number_of_nodes(), batch_size))
            # Create PyTorch DataLoader for constructing blocks
            # (``args`` is the global namespace parsed in __main__).
            dataloader = DataLoader(
                dataset=nodes,
                batch_size=batch_size,
                collate_fn=sampler.sample_blocks,
                shuffle=False,
                drop_last=False,
                num_workers=args.num_workers)

            for blocks in tqdm.tqdm(dataloader):
                block = blocks[0]
                input_nodes = block.srcdata[dgl.NID]
                output_nodes = block.dstdata[dgl.NID]
                h = x[input_nodes].to(device)
                h_dst = h[:block.number_of_dst_nodes()]
                h = layer(block, (h, h_dst))
                if l != len(self.layers) - 1:
                    h = self.activation(h)
                    h = self.dropout(h)
                y[output_nodes] = h.cpu()

            x = y
        g.barrier()
        return y
def load_subtensor(g, input_nodes, device):
    """
    Copies the features of a set of nodes onto the given device.
    """
    batch_inputs = g.ndata['features'][input_nodes].to(device)
    return batch_inputs
class CrossEntropyLoss(nn.Module):
    def forward(self, block_outputs, pos_graph, neg_graph):
        with pos_graph.local_scope():
            pos_graph.ndata['h'] = block_outputs
            pos_graph.apply_edges(fn.u_dot_v('h', 'h', 'score'))
            pos_score = pos_graph.edata['score']
        with neg_graph.local_scope():
            neg_graph.ndata['h'] = block_outputs
            neg_graph.apply_edges(fn.u_dot_v('h', 'h', 'score'))
            neg_score = neg_graph.edata['score']

        score = th.cat([pos_score, neg_score])
        label = th.cat([th.ones_like(pos_score), th.zeros_like(neg_score)]).long()
        loss = F.binary_cross_entropy_with_logits(score, label.float())
        return loss
def generate_emb(model, g, inputs, batch_size, device):
    """
    Generate embeddings for each node
    g : The entire graph.
    inputs : The features of all the nodes.
    batch_size : Number of nodes to compute at the same time.
    device : The GPU device to evaluate on.
    """
    model.eval()
    with th.no_grad():
        pred = model.inference(g, inputs, batch_size, device)
    return pred
def compute_acc(emb, labels, train_nids, val_nids, test_nids):
    """
    Compute the accuracy of the predictions given the labels.
    We first train a LogisticRegression model on the learned embeddings using
    the training set; the validation and test sets are used for evaluation.
    The final result is predicted by the LogisticRegression model.
    emb: The pretrained embeddings
    labels: The ground truth
    train_nids: The training set node ids
    val_nids: The validation set node ids
    test_nids: The test set node ids
    """
    emb = emb[np.arange(labels.shape[0])].cpu().numpy()
    train_nids = train_nids.cpu().numpy()
    val_nids = val_nids.cpu().numpy()
    test_nids = test_nids.cpu().numpy()
    labels = labels.cpu().numpy()

    # Standardize each embedding dimension before fitting.
    emb = (emb - emb.mean(0, keepdims=True)) / emb.std(0, keepdims=True)
    lr = lm.LogisticRegression(multi_class='multinomial', max_iter=10000)
    lr.fit(emb[train_nids], labels[train_nids])

    pred = lr.predict(emb)
    eval_acc = skm.accuracy_score(labels[val_nids], pred[val_nids])
    test_acc = skm.accuracy_score(labels[test_nids], pred[test_nids])
    return eval_acc, test_acc
def run(args, device, data):
    # Unpack data
    train_eids, train_nids, in_feats, g, global_train_nid, global_valid_nid, global_test_nid, labels = data
    # Create sampler
    sampler = NeighborSampler(g, [int(fanout) for fanout in args.fan_out.split(',')], train_nids,
                              dgl.distributed.sample_neighbors, args.num_negs, args.remove_edge)

    # Create PyTorch DataLoader for constructing blocks
    dataloader = DataLoader(
        dataset=train_eids.numpy(),
        batch_size=args.batch_size,
        collate_fn=sampler.sample_blocks,
        shuffle=True,
        drop_last=False,
        num_workers=args.num_workers)

    # Define model and optimizer
    model = DistSAGE(in_feats, args.num_hidden, args.num_hidden, args.num_layers, F.relu, args.dropout)
    model = model.to(device)
    if not args.standalone:
        model = th.nn.parallel.DistributedDataParallel(model)
    loss_fcn = CrossEntropyLoss()
    loss_fcn = loss_fcn.to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Training loop
    #profiler = Profiler()
    #profiler.start()
    epoch = 0
    for epoch in range(args.num_epochs):
        sample_time = 0
        copy_time = 0
        forward_time = 0
        backward_time = 0
        update_time = 0
        num_seeds = 0
        num_inputs = 0
        step_time = []
        iter_t = []
        sample_t = []
        feat_copy_t = []
        forward_t = []
        backward_t = []
        update_t = []
        iter_tput = []

        start = time.time()
        # Loop over the dataloader to sample the computation dependency graph as a list of
        # blocks.
        for step, (pos_graph, neg_graph, blocks) in enumerate(dataloader):
            tic_step = time.time()
            sample_t.append(tic_step - start)

            # The input nodes lie on the LHS of the first block;
            # the output nodes lie on the RHS of the last block.
            input_nodes = blocks[0].srcdata[dgl.NID]

            # Load the input features.
            batch_inputs = load_subtensor(g, input_nodes, device)
            copy_time = time.time()
            feat_copy_t.append(copy_time - tic_step)

            # Compute loss and prediction
            batch_pred = model(blocks, batch_inputs)
            loss = loss_fcn(batch_pred, pos_graph, neg_graph)
            forward_end = time.time()
            optimizer.zero_grad()
            loss.backward()
            compute_end = time.time()
            forward_t.append(forward_end - copy_time)
            backward_t.append(compute_end - forward_end)

            # Aggregate gradients across multiple machines.
            optimizer.step()
            update_t.append(time.time() - compute_end)

            pos_edges = pos_graph.number_of_edges()
            neg_edges = neg_graph.number_of_edges()
            step_t = time.time() - start
            step_time.append(step_t)
            iter_tput.append(pos_edges / step_t)
            num_seeds += pos_edges
            if step % args.log_every == 0:
                print('[{}] Epoch {:05d} | Step {:05d} | Loss {:.4f} | Speed (samples/sec) {:.4f} | time {:.3f} s '
                      '| sample {:.3f} | copy {:.3f} | forward {:.3f} | backward {:.3f} | update {:.3f}'.format(
                          g.rank(), epoch, step, loss.item(), np.mean(iter_tput[3:]), np.sum(step_time[-args.log_every:]),
                          np.sum(sample_t[-args.log_every:]), np.sum(feat_copy_t[-args.log_every:]), np.sum(forward_t[-args.log_every:]),
                          np.sum(backward_t[-args.log_every:]), np.sum(update_t[-args.log_every:])))
            start = time.time()

        print('[{}]Epoch Time(s): {:.4f}, sample: {:.4f}, data copy: {:.4f}, forward: {:.4f}, backward: {:.4f}, update: {:.4f}, #seeds: {}, #inputs: {}'.format(
            g.rank(), np.sum(step_time), np.sum(sample_t), np.sum(feat_copy_t), np.sum(forward_t), np.sum(backward_t), np.sum(update_t), num_seeds, num_inputs))
        epoch += 1

    # Evaluate the embeddings with LogisticRegression.
    if args.standalone:
        pred = generate_emb(model, g, g.ndata['features'], args.batch_size_eval, device)
    else:
        pred = generate_emb(model.module, g, g.ndata['features'], args.batch_size_eval, device)
    if g.rank() == 0:
        eval_acc, test_acc = compute_acc(pred, labels, global_train_nid, global_valid_nid, global_test_nid)
        print('eval acc {:.4f}; test acc {:.4f}'.format(eval_acc, test_acc))

    # Synchronize for evaluation and testing.
    if not args.standalone:
        th.distributed.barrier()

    # Save the embeddings into a file.
    if not args.standalone:
        g._client.barrier()
        if g.rank() == 0:
            th.save(pred, 'emb.pt')
    else:
        feat = g.ndata['features']
        th.save(pred, 'emb.pt')
def main(args):
    if not args.standalone:
        th.distributed.init_process_group(backend='gloo')
    g = dgl.distributed.DistGraph(args.ip_config, args.graph_name, conf_file=args.conf_path)
    print('rank:', g.rank())
    print('number of edges', g.number_of_edges())

    train_eids = dgl.distributed.edge_split(th.ones((g.number_of_edges(),), dtype=th.bool),
                                            g.get_partition_book(), force_even=True)
    train_nids = dgl.distributed.node_split(th.ones((g.number_of_nodes(),), dtype=th.bool),
                                            g.get_partition_book())
    global_train_nid = th.LongTensor(np.nonzero(g.ndata['train_mask'][np.arange(g.number_of_nodes())]))
    global_valid_nid = th.LongTensor(np.nonzero(g.ndata['val_mask'][np.arange(g.number_of_nodes())]))
    global_test_nid = th.LongTensor(np.nonzero(g.ndata['test_mask'][np.arange(g.number_of_nodes())]))
    labels = g.ndata['labels'][np.arange(g.number_of_nodes())]
    device = th.device('cpu')

    # Pack data
    in_feats = g.ndata['features'].shape[1]
    global_train_nid = global_train_nid.squeeze()
    global_valid_nid = global_valid_nid.squeeze()
    global_test_nid = global_test_nid.squeeze()
    print("number of train {}".format(global_train_nid.shape[0]))
    print("number of valid {}".format(global_valid_nid.shape[0]))
    print("number of test {}".format(global_test_nid.shape[0]))
    data = train_eids, train_nids, in_feats, g, global_train_nid, global_valid_nid, global_test_nid, labels
    run(args, device, data)
    print("parent ends")
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='GCN')
    register_data_args(parser)
    parser.add_argument('--graph-name', type=str, help='graph name')
    parser.add_argument('--id', type=int, help='the partition id')
    parser.add_argument('--ip_config', type=str, help='The file for IP configuration')
    parser.add_argument('--conf_path', type=str, help='The path to the partition config file')
    parser.add_argument('--num-client', type=int, help='The number of clients')
    parser.add_argument('--n-classes', type=int, help='the number of classes')
    parser.add_argument('--gpu', type=int, default=0,
                        help="GPU device ID. Use -1 for CPU training")
    parser.add_argument('--num-epochs', type=int, default=20)
    parser.add_argument('--num-hidden', type=int, default=16)
    parser.add_argument('--num-layers', type=int, default=2)
    parser.add_argument('--fan-out', type=str, default='10,25')
    parser.add_argument('--batch-size', type=int, default=1000)
    parser.add_argument('--batch-size-eval', type=int, default=100000)
    parser.add_argument('--log-every', type=int, default=20)
    parser.add_argument('--eval-every', type=int, default=5)
    parser.add_argument('--lr', type=float, default=0.003)
    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument('--num-workers', type=int, default=0,
                        help="Number of sampling processes. Use 0 for no extra process.")
    parser.add_argument('--local_rank', type=int, help='get rank of the process')
    parser.add_argument('--standalone', action='store_true', help='run in the standalone mode')
    parser.add_argument('--num-negs', type=int, default=1)
    parser.add_argument('--neg-share', default=False, action='store_true',
                        help="sharing neg nodes for positive nodes")
    parser.add_argument('--remove-edge', default=False, action='store_true',
                        help="whether to remove edges during sampling")
    args = parser.parse_args()

    print(args)
    main(args)
@@ -64,7 +64,7 @@ class NeighborSampler(object):
        neg_graph = dgl.graph((neg_heads, neg_tails), num_nodes=self.g.number_of_nodes())
        pos_graph, neg_graph = dgl.compact_graphs([pos_graph, neg_graph])
        # Obtain the node IDs being used in either pos_graph or neg_graph. Since they
        # are compacted together, pos_graph and neg_graph share the same compacted node
        # space.
        seeds = pos_graph.ndata[dgl.NID]
@@ -381,7 +381,7 @@ if __name__ == '__main__':
    argparser.add_argument('--num-workers', type=int, default=0,
                           help="Number of sampling processes. Use 0 for no extra process.")
    args = argparser.parse_args()
    devices = list(map(int, args.gpu.split(',')))
    main(args, devices)
@@ -12,7 +12,7 @@ from .rpc_server import start_server
from .rpc_client import connect_to_server, exit_client
from .kvstore import KVServer, KVClient
from .server_state import ServerState
-from .graph_services import sample_neighbors, in_subgraph
+from .graph_services import sample_neighbors, in_subgraph, find_edges

if os.environ.get('DGL_ROLE', 'client') == 'server':
    assert os.environ.get('DGL_SERVER_ID') is not None, \
...
@@ -20,6 +20,7 @@ from . import rpc
from .rpc_client import connect_to_server
from .server_state import ServerState
from .rpc_server import start_server
from .graph_services import find_edges as dist_find_edges
from .dist_tensor import DistTensor, _get_data_name

def _copy_graph_to_shared_mem(g, graph_name):
@@ -443,6 +444,25 @@ class DistGraph:
        client_id_in_part = rpc.get_rank() % num_client_per_part
        return int(self._gpb.partid * num_client_per_part + client_id_in_part)
    def find_edges(self, edges):
        """ Given an edge ID array, return the source
        and destination node ID arrays ``s`` and ``d``. ``s[i]`` and ``d[i]``
        are the source and destination node IDs of edge ``eid[i]``.

        Parameters
        ----------
        edges : tensor
            The edge ID array.

        Returns
        -------
        tensor
            The source node ID array.
        tensor
            The destination node ID array.
        """
        return dist_find_edges(self, edges)
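    # A hypothetical usage sketch (not part of this diff): given an initialized
    # DistGraph ``g`` and a tensor of edge IDs, fetch the endpoints of those
    # edges, e.g. ``src, dst = g.find_edges(some_edge_ids)``. This is what the
    # unsupervised sampler in this PR relies on to map seed edge IDs back to
    # head/tail node pairs.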
    def get_partition_book(self):
        """Get the partition information.
...
@@ -4,16 +4,17 @@ from collections import namedtuple
from .rpc import Request, Response, send_requests_to_machine, recv_responses
from ..sampling import sample_neighbors as local_sample_neighbors
from ..transform import in_subgraph as local_in_subgraph
-from . import register_service
+from .rpc import register_service
from ..convert import graph
from ..base import NID, EID
from ..utils import toindex
from .. import backend as F

-__all__ = ['sample_neighbors', 'in_subgraph']
+__all__ = ['sample_neighbors', 'in_subgraph', 'find_edges']

SAMPLING_SERVICE_ID = 6657
INSUBGRAPH_SERVICE_ID = 6658
EDGES_SERVICE_ID = 6659

class SubgraphResponse(Response):
    """The response for sampling and in_subgraph"""
@@ -29,6 +30,19 @@ class SubgraphResponse(Response):
    def __getstate__(self):
        return self.global_src, self.global_dst, self.global_eids
class FindEdgeResponse(Response):
    """The response for find_edges"""
    def __init__(self, global_src, global_dst, order_id):
        self.global_src = global_src
        self.global_dst = global_dst
        self.order_id = order_id

    def __setstate__(self, state):
        self.global_src, self.global_dst, self.order_id = state

    def __getstate__(self):
        return self.global_src, self.global_dst, self.order_id
def _sample_neighbors(local_g, partition_book, seed_nodes, fan_out, edge_dir, prob, replace):
    """ Sample from local partition.
@@ -49,6 +63,17 @@ def _sample_neighbors(local_g, partition_book, seed_nodes, fan_out, edge_dir, pr
    global_eids = F.gather_row(local_g.edata[EID], sampled_graph.edata[EID])
    return global_src, global_dst, global_eids
def _find_edges(local_g, partition_book, seed_edges):
    """Given an edge ID array, return the source
    and destination node ID arrays ``s`` and ``d`` in the local partition.
    """
    local_eids = partition_book.eid2localeid(seed_edges, partition_book.partid)
    local_eids = F.astype(local_eids, local_g.idtype)
    local_src, local_dst = local_g.find_edges(local_eids)
    # Map the local node IDs back to global node IDs.
    global_nid_mapping = local_g.ndata[NID]
    global_src = global_nid_mapping[local_src]
    global_dst = global_nid_mapping[local_dst]
    return global_src, global_dst
def _in_subgraph(local_g, partition_book, seed_nodes):
    """ Get in subgraph from local partition.
@@ -94,6 +119,25 @@ class SamplingRequest(Request):
                                                       self.prob, self.replace)
        return SubgraphResponse(global_src, global_dst, global_eids)
class EdgesRequest(Request):
    """Edges Request"""
    def __init__(self, edge_ids, order_id):
        self.edge_ids = edge_ids
        self.order_id = order_id

    def __setstate__(self, state):
        self.edge_ids, self.order_id = state

    def __getstate__(self):
        return self.edge_ids, self.order_id

    def process_request(self, server_state):
        local_g = server_state.graph
        partition_book = server_state.partition_book
        global_src, global_dst = _find_edges(local_g, partition_book, self.edge_ids)
        return FindEdgeResponse(global_src, global_dst, self.order_id)
class InSubgraphRequest(Request):
    """InSubgraph Request"""
@@ -249,6 +293,98 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False):
                               fanout, edge_dir, prob, replace)
    return _distributed_access(g, nodes, issue_remote_req, local_access)
def _distributed_edge_access(g, edges, issue_remote_req, local_access):
    """A routine that fetches the endpoints of edges from a distributed graph.

    The source and destination nodes of local edges are stored on the local
    machine, while the others are stored on remote machines. This code issues
    the remote access requests first, then fetches the data from the local
    machine, and finally combines the local and remote results.

    Parameters
    ----------
    g : DistGraph
        The distributed graph
    edges : tensor
        The edges to find their source and destination nodes.
    issue_remote_req : callable
        The function that issues requests to access remote data.
    local_access : callable
        The function that reads data on the local machine.

    Returns
    -------
    tensor
        The source node ID array.
    tensor
        The destination node ID array.
    """
    req_list = []
    partition_book = g.get_partition_book()
    edges = toindex(edges).tousertensor()
    partition_id = partition_book.eid2partid(edges)
    local_eids = None
    reorder_idx = []
    for pid in range(partition_book.num_partitions()):
        # Pick out the edges that fall in partition pid and remember their
        # positions so the results can be scattered back in input order.
        mask = (partition_id == pid)
        edge_id = F.boolean_mask(edges, mask)
        reorder_idx.append(F.nonzero_1d(mask))
        if pid == partition_book.partid and g.local_partition is not None:
            assert local_eids is None
            local_eids = edge_id
        elif len(edge_id) != 0:
            req = issue_remote_req(edge_id, pid)
            req_list.append((pid, req))

    # Send requests to the remote machines.
    msgseq2pos = None
    if len(req_list) > 0:
        msgseq2pos = send_requests_to_machine(req_list)

    # Handle edges in the local partition.
    src_ids = F.zeros_like(edges)
    dst_ids = F.zeros_like(edges)
    if local_eids is not None:
        src, dst = local_access(g.local_partition, partition_book, local_eids)
        src_ids = F.scatter_row(src_ids, reorder_idx[partition_book.partid], src)
        dst_ids = F.scatter_row(dst_ids, reorder_idx[partition_book.partid], dst)

    # Receive responses from the remote machines.
    if msgseq2pos is not None:
        results = recv_responses(msgseq2pos)
        for result in results:
            src = result.global_src
            dst = result.global_dst
            src_ids = F.scatter_row(src_ids, reorder_idx[result.order_id], src)
            dst_ids = F.scatter_row(dst_ids, reorder_idx[result.order_id], dst)
    return src_ids, dst_ids
def find_edges(g, edge_ids):
    """ Given an edge ID array, return the source and destination
    node ID arrays ``s`` and ``d`` from a distributed graph.
    ``s[i]`` and ``d[i]`` are the source and destination node IDs
    of edge ``edge_ids[i]``.

    Parameters
    ----------
    g : DistGraph
        The distributed graph.
    edge_ids : tensor
        The edge ID array.

    Returns
    -------
    tensor
        The source node ID array.
    tensor
        The destination node ID array.
    """
    def issue_remote_req(edge_ids, order_id):
        return EdgesRequest(edge_ids, order_id)

    def local_access(local_g, partition_book, edge_ids):
        return _find_edges(local_g, partition_book, edge_ids)

    return _distributed_edge_access(g, edge_ids, issue_remote_req, local_access)
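# A hypothetical usage sketch (not part of this diff): ``find_edges`` preserves
# the input order, so the i-th returned pair corresponds to edge_ids[i], e.g.
#
#     src, dst = find_edges(g, edge_ids)  # g: DistGraph, edge_ids: tensor
#
# The ordering is guaranteed by _distributed_edge_access, which scatters each
# partition's results back into place with ``reorder_idx``.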
def in_subgraph(g, nodes):
    """Extract the subgraph containing only the in edges of the given nodes.
@@ -280,4 +416,5 @@ def in_subgraph(g, nodes):
    return _distributed_access(g, nodes, issue_remote_req, local_access)

register_service(SAMPLING_SERVICE_ID, SamplingRequest, SubgraphResponse)
register_service(EDGES_SERVICE_ID, EdgesRequest, FindEdgeResponse)
register_service(INSUBGRAPH_SERVICE_ID, InSubgraphRequest, SubgraphResponse)
@@ -2,7 +2,7 @@ import dgl
import unittest
import os
from dgl.data import CitationGraphDataset
-from dgl.distributed import sample_neighbors
+from dgl.distributed import sample_neighbors, find_edges
from dgl.distributed import partition_graph, load_partition, load_partition_book
import sys
import multiprocessing as mp
@@ -30,6 +30,14 @@ def start_sample_client(rank, tmpdir, disable_shared_mem):
    dgl.distributed.exit_client()
    return sampled_graph
def start_find_edges_client(rank, tmpdir, disable_shared_mem, eids):
    gpb = None
    if disable_shared_mem:
        _, _, _, gpb, _ = load_partition(tmpdir / 'test_find_edges.json', rank)
    dist_graph = DistGraph("rpc_ip_config.txt", "test_find_edges", gpb=gpb)
    u, v = find_edges(dist_graph, eids)
    dgl.distributed.exit_client()
    return u, v
def check_rpc_sampling(tmpdir, num_server):
    ip_config = open("rpc_ip_config.txt", "w")
@@ -67,6 +75,34 @@ def check_rpc_sampling(tmpdir, num_server):
    assert np.array_equal(
        F.asnumpy(sampled_graph.edata[dgl.EID]), F.asnumpy(eids))
def check_rpc_find_edges(tmpdir, num_server):
    ip_config = open("rpc_ip_config.txt", "w")
    for _ in range(num_server):
        ip_config.write('{} 1\n'.format(get_local_usable_addr()))
    ip_config.close()

    g = CitationGraphDataset("cora")[0]
    g.readonly()
    num_parts = num_server

    partition_graph(g, 'test_find_edges', num_parts, tmpdir,
                    num_hops=1, part_method='metis', reshuffle=False)

    pserver_list = []
    ctx = mp.get_context('spawn')
    for i in range(num_server):
        p = ctx.Process(target=start_server, args=(i, tmpdir, num_server > 1, 'test_find_edges'))
        p.start()
        time.sleep(1)
        pserver_list.append(p)
    time.sleep(3)

    # Compare the distributed find_edges against the local graph's answer.
    eids = F.tensor(np.random.randint(g.number_of_edges(), size=100))
    u, v = g.find_edges(eids)
    du, dv = start_find_edges_client(0, tmpdir, num_server > 1, eids)
    assert F.array_equal(u, du)
    assert F.array_equal(v, dv)
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
def test_rpc_sampling():
@@ -80,7 +116,7 @@ def check_rpc_sampling_shuffle(tmpdir, num_server):
    for _ in range(num_server):
        ip_config.write('{} 1\n'.format(get_local_usable_addr()))
    ip_config.close()

    g = CitationGraphDataset("cora")[0]
    g.readonly()
    num_parts = num_server
@@ -220,3 +256,5 @@ if __name__ == "__main__":
        check_rpc_sampling_shuffle(Path(tmpdirname), 2)
        check_rpc_sampling(Path(tmpdirname), 2)
        check_rpc_sampling(Path(tmpdirname), 1)
        check_rpc_find_edges(Path(tmpdirname), 2)
        check_rpc_find_edges(Path(tmpdirname), 1)
\ No newline at end of file