Unverified Commit 34426a98 authored by Da Zheng's avatar Da Zheng Committed by GitHub
Browse files

[Distributed] Distributed heterograph training (#3069)



* support hetero RGCN.

* fix.

* simplify code.

* sample_neighbors return heterograph directly.

* avoid using to_heterogeneous.

* compute canonical etypes in advance.

* fix tests.

* fix.

* fix distributed data loader for heterograph.

* use NodeDataLoader.

* fix bugs in partitioning on heterogeneous graphs.

* fix lint.

* fix tests.

* fix.

* fix.

* fix bugs.

* fix tests.

* fix.

* enable coo for distributed.

* fix.

* fix.

* fix.

* fix.

* fix.
Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-71-112.ec2.internal>
Co-authored-by: default avatarZheng <dzzhen@3c22fba32af5.ant.amazon.com>
parent 905c0aa5
...@@ -21,15 +21,125 @@ from torch.multiprocessing import Queue ...@@ -21,15 +21,125 @@ from torch.multiprocessing import Queue
from torch.nn.parallel import DistributedDataParallel from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
import dgl import dgl
from dgl import nn as dglnn
from dgl import DGLGraph from dgl import DGLGraph
from dgl.distributed import DistDataLoader from dgl.distributed import DistDataLoader
from functools import partial from functools import partial
from dgl.nn import RelGraphConv
import tqdm import tqdm
from ogb.nodeproppred import DglNodePropPredDataset from ogb.nodeproppred import DglNodePropPredDataset
class RelGraphConvLayer(nn.Module):
    r"""Relational graph convolution layer.

    Runs one ``GraphConv`` per relation via ``HeteroGraphConv`` and sums the
    incoming messages per destination node type. Relation weights can be
    shared through a weight basis when ``num_bases < len(rel_names)``.

    Parameters
    ----------
    in_feat : int
        Input feature size.
    out_feat : int
        Output feature size.
    rel_names : list[str]
        Relation names.
    num_bases : int, optional
        Number of bases. If is none, use number of relations. Default: None.
    weight : bool, optional
        True if a linear layer is applied after message passing. Default: True
    bias : bool, optional
        True if bias is added. Default: True
    activation : callable, optional
        Activation function. Default: None
    self_loop : bool, optional
        True to include self loop message. Default: False
    dropout : float, optional
        Dropout rate. Default: 0.0
    """
    def __init__(self,
                 in_feat,
                 out_feat,
                 rel_names,
                 num_bases,
                 *,
                 weight=True,
                 bias=True,
                 activation=None,
                 self_loop=False,
                 dropout=0.0):
        super(RelGraphConvLayer, self).__init__()
        self.in_feat = in_feat
        self.out_feat = out_feat
        self.rel_names = rel_names
        self.num_bases = num_bases
        self.bias = bias
        self.activation = activation
        self.self_loop = self_loop
        # One GraphConv per relation; weights are supplied externally through
        # mod_kwargs in forward(), hence weight=False here.
        self.conv = dglnn.HeteroGraphConv({
            rel : dglnn.GraphConv(in_feat, out_feat, norm='right', weight=False, bias=False)
            for rel in rel_names
        })
        self.use_weight = weight
        # Basis decomposition only pays off when there are fewer bases than relations.
        self.use_basis = num_bases < len(self.rel_names) and weight
        if self.use_weight:
            if self.use_basis:
                self.basis = dglnn.WeightBasis((in_feat, out_feat), num_bases, len(self.rel_names))
            else:
                self.weight = nn.Parameter(th.Tensor(len(self.rel_names), in_feat, out_feat))
                nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain('relu'))
        # bias
        if bias:
            self.h_bias = nn.Parameter(th.Tensor(out_feat))
            nn.init.zeros_(self.h_bias)
        # weight for self loop
        if self.self_loop:
            self.loop_weight = nn.Parameter(th.Tensor(in_feat, out_feat))
            nn.init.xavier_uniform_(self.loop_weight,
                                    gain=nn.init.calculate_gain('relu'))
        self.dropout = nn.Dropout(dropout)

    def forward(self, g, inputs):
        """Forward computation

        Parameters
        ----------
        g : DGLHeteroGraph
            Input graph.
        inputs : dict[str, torch.Tensor]
            Node feature for each node type.

        Returns
        -------
        dict[str, torch.Tensor]
            New node features for each node type.
        """
        g = g.local_var()
        if self.use_weight:
            weight = self.basis() if self.use_basis else self.weight
            # Split the (R, in, out) stack into one (in, out) matrix per relation.
            wdict = {self.rel_names[i] : {'weight' : w.squeeze(0)}
                     for i, w in enumerate(th.split(weight, 1, dim=0))}
        else:
            wdict = {}
        if g.is_block:
            inputs_src = inputs
            # In a block, destination nodes are a prefix of the source nodes.
            inputs_dst = {k: v[:g.number_of_dst_nodes(k)] for k, v in inputs.items()}
        else:
            inputs_src = inputs_dst = inputs
        # Fix: pass inputs_src (previously computed but unused; `inputs` was
        # passed instead). Behavior is unchanged because inputs_src is inputs,
        # but the intent — source-side features feed message passing — is now explicit.
        hs = self.conv(g, inputs_src, mod_kwargs=wdict)

        def _apply(ntype, h):
            # Self-loop term uses destination-side features of the same ntype.
            if self.self_loop:
                h = h + th.matmul(inputs_dst[ntype], self.loop_weight)
            if self.bias:
                h = h + self.h_bias
            if self.activation:
                h = self.activation(h)
            return self.dropout(h)
        return {ntype : _apply(ntype, h) for ntype, h in hs.items()}
class EntityClassify(nn.Module): class EntityClassify(nn.Module):
""" Entity classification class for RGCN """ Entity classification class for RGCN
Parameters Parameters
...@@ -42,8 +152,8 @@ class EntityClassify(nn.Module): ...@@ -42,8 +152,8 @@ class EntityClassify(nn.Module):
Hidden dim size. Hidden dim size.
out_dim : int out_dim : int
Output dim size. Output dim size.
num_rels : int rel_names : list of str
Numer of relation types. A list of relation names.
num_bases : int num_bases : int
Number of bases. If is none, use number of relations. Number of bases. If is none, use number of relations.
num_hidden_layers : int num_hidden_layers : int
...@@ -52,51 +162,43 @@ class EntityClassify(nn.Module): ...@@ -52,51 +162,43 @@ class EntityClassify(nn.Module):
Dropout Dropout
use_self_loop : bool use_self_loop : bool
Use self loop if True, default False. Use self loop if True, default False.
low_mem : bool
True to use low memory implementation of relation message passing function
trade speed with memory consumption
""" """
def __init__(self, def __init__(self,
device, device,
h_dim, h_dim,
out_dim, out_dim,
num_rels, rel_names,
num_bases=None, num_bases=None,
num_hidden_layers=1, num_hidden_layers=1,
dropout=0, dropout=0,
use_self_loop=False, use_self_loop=False,
low_mem=False,
layer_norm=False): layer_norm=False):
super(EntityClassify, self).__init__() super(EntityClassify, self).__init__()
self.device = device self.device = device
self.h_dim = h_dim self.h_dim = h_dim
self.out_dim = out_dim self.out_dim = out_dim
self.num_rels = num_rels
self.num_bases = None if num_bases < 0 else num_bases self.num_bases = None if num_bases < 0 else num_bases
self.num_hidden_layers = num_hidden_layers self.num_hidden_layers = num_hidden_layers
self.dropout = dropout self.dropout = dropout
self.use_self_loop = use_self_loop self.use_self_loop = use_self_loop
self.low_mem = low_mem
self.layer_norm = layer_norm self.layer_norm = layer_norm
self.layers = nn.ModuleList() self.layers = nn.ModuleList()
# i2h # i2h
self.layers.append(RelGraphConv( self.layers.append(RelGraphConvLayer(
self.h_dim, self.h_dim, self.num_rels, "basis", self.h_dim, self.h_dim, rel_names,
self.num_bases, activation=F.relu, self_loop=self.use_self_loop, self.num_bases, activation=F.relu, self_loop=self.use_self_loop,
low_mem=self.low_mem, dropout=self.dropout)) dropout=self.dropout))
# h2h # h2h
for idx in range(self.num_hidden_layers): for idx in range(self.num_hidden_layers):
self.layers.append(RelGraphConv( self.layers.append(RelGraphConvLayer(
self.h_dim, self.h_dim, self.num_rels, "basis", self.h_dim, self.h_dim, rel_names,
self.num_bases, activation=F.relu, self_loop=self.use_self_loop, self.num_bases, activation=F.relu, self_loop=self.use_self_loop,
low_mem=self.low_mem, dropout=self.dropout)) dropout=self.dropout))
# h2o # h2o
self.layers.append(RelGraphConv( self.layers.append(RelGraphConvLayer(
self.h_dim, self.out_dim, self.num_rels, "basis", self.h_dim, self.out_dim, rel_names,
self.num_bases, activation=None, self.num_bases, activation=None, self_loop=self.use_self_loop))
self_loop=self.use_self_loop,
low_mem=self.low_mem))
def forward(self, blocks, feats, norm=None): def forward(self, blocks, feats, norm=None):
if blocks is None: if blocks is None:
...@@ -105,7 +207,7 @@ class EntityClassify(nn.Module): ...@@ -105,7 +207,7 @@ class EntityClassify(nn.Module):
h = feats h = feats
for layer, block in zip(self.layers, blocks): for layer, block in zip(self.layers, blocks):
block = block.to(self.device) block = block.to(self.device)
h = layer(block, h, block.edata[dgl.ETYPE], block.edata['norm']) h = layer(block, h)
return h return h
def init_emb(shape, dtype): def init_emb(shape, dtype):
...@@ -182,27 +284,23 @@ class DistEmbedLayer(nn.Module): ...@@ -182,27 +284,23 @@ class DistEmbedLayer(nn.Module):
self.node_embeds[ntype] = th.nn.Embedding(g.number_of_nodes(ntype), self.embed_size) self.node_embeds[ntype] = th.nn.Embedding(g.number_of_nodes(ntype), self.embed_size)
nn.init.uniform_(self.node_embeds[ntype].weight, -1.0, 1.0) nn.init.uniform_(self.node_embeds[ntype].weight, -1.0, 1.0)
def forward(self, node_ids, ntype_ids): def forward(self, node_ids):
"""Forward computation """Forward computation
Parameters Parameters
---------- ----------
node_ids : Tensor node_ids : dict of Tensor
node ids to generate embedding for. node ids to generate embedding for.
ntype_ids : Tensor
node type ids
Returns Returns
------- -------
tensor tensor
embeddings as the input of the next layer embeddings as the input of the next layer
""" """
embeds = th.empty(node_ids.shape[0], self.embed_size, device=self.dev_id) embeds = {}
for ntype_id in th.unique(ntype_ids).tolist(): for ntype in node_ids:
ntype = self.ntype_id_map[int(ntype_id)]
loc = ntype_ids == ntype_id
if self.feat_name in self.g.nodes[ntype].data: if self.feat_name in self.g.nodes[ntype].data:
embeds[loc] = self.node_projs[ntype](self.g.nodes[ntype].data[self.feat_name][node_ids[ntype_ids == ntype_id]].to(self.dev_id)) embeds[ntype] = self.node_projs[ntype](self.g.nodes[ntype].data[self.feat_name][node_ids[ntype]].to(self.dev_id))
else: else:
embeds[loc] = self.node_embeds[ntype](node_ids[ntype_ids == ntype_id]).to(self.dev_id) embeds[ntype] = self.node_embeds[ntype](node_ids[ntype]).to(self.dev_id)
return embeds return embeds
def compute_acc(results, labels): def compute_acc(results, labels):
...@@ -212,14 +310,6 @@ def compute_acc(results, labels): ...@@ -212,14 +310,6 @@ def compute_acc(results, labels):
labels = labels.long() labels = labels.long()
return (results == labels).float().sum() / len(results) return (results == labels).float().sum() / len(results)
def gen_norm(g):
    """Attach a per-edge normalization factor to ``g.edata['norm']``.

    For every edge, the factor is ``1 / in_degree(dst)`` computed over the
    edges present in ``g``, stored as a column vector of shape ``(E, 1)``.
    """
    _, dst_nodes, eids = g.all_edges(form='all')
    # Count how many edges land on each distinct destination node, then
    # scatter the counts back to per-edge positions via the inverse index.
    _, inv_idx, counts = th.unique(dst_nodes, return_inverse=True, return_counts=True)
    in_degrees = counts[inv_idx]
    edge_norm = th.ones(eids.shape[0], device=eids.device) / in_degrees
    g.edata['norm'] = edge_norm.unsqueeze(1)
def evaluate(g, model, embed_layer, labels, eval_loader, test_loader, all_val_nid, all_test_nid): def evaluate(g, model, embed_layer, labels, eval_loader, test_loader, all_val_nid, all_test_nid):
model.eval() model.eval()
embed_layer.eval() embed_layer.eval()
...@@ -231,11 +321,12 @@ def evaluate(g, model, embed_layer, labels, eval_loader, test_loader, all_val_ni ...@@ -231,11 +321,12 @@ def evaluate(g, model, embed_layer, labels, eval_loader, test_loader, all_val_ni
with th.no_grad(): with th.no_grad():
th.cuda.empty_cache() th.cuda.empty_cache()
for sample_data in tqdm.tqdm(eval_loader): for sample_data in tqdm.tqdm(eval_loader):
seeds, blocks = sample_data input_nodes, seeds, blocks = sample_data
for block in blocks: seeds = seeds['paper']
gen_norm(block) feats = embed_layer(input_nodes)
feats = embed_layer(blocks[0].srcdata[dgl.NID], blocks[0].srcdata[dgl.NTYPE])
logits = model(blocks, feats) logits = model(blocks, feats)
assert len(logits) == 1
logits = logits['paper']
eval_logits.append(logits.cpu().detach()) eval_logits.append(logits.cpu().detach())
assert np.all(seeds.numpy() < g.number_of_nodes('paper')) assert np.all(seeds.numpy() < g.number_of_nodes('paper'))
eval_seeds.append(seeds.cpu().detach()) eval_seeds.append(seeds.cpu().detach())
...@@ -248,11 +339,12 @@ def evaluate(g, model, embed_layer, labels, eval_loader, test_loader, all_val_ni ...@@ -248,11 +339,12 @@ def evaluate(g, model, embed_layer, labels, eval_loader, test_loader, all_val_ni
with th.no_grad(): with th.no_grad():
th.cuda.empty_cache() th.cuda.empty_cache()
for sample_data in tqdm.tqdm(test_loader): for sample_data in tqdm.tqdm(test_loader):
seeds, blocks = sample_data input_nodes, seeds, blocks = sample_data
for block in blocks: seeds = seeds['paper']
gen_norm(block) feats = embed_layer(input_nodes)
feats = embed_layer(blocks[0].srcdata[dgl.NID], blocks[0].srcdata[dgl.NTYPE])
logits = model(blocks, feats) logits = model(blocks, feats)
assert len(logits) == 1
logits = logits['paper']
test_logits.append(logits.cpu().detach()) test_logits.append(logits.cpu().detach())
assert np.all(seeds.numpy() < g.number_of_nodes('paper')) assert np.all(seeds.numpy() < g.number_of_nodes('paper'))
test_seeds.append(seeds.cpu().detach()) test_seeds.append(seeds.cpu().detach())
...@@ -267,90 +359,36 @@ def evaluate(g, model, embed_layer, labels, eval_loader, test_loader, all_val_ni ...@@ -267,90 +359,36 @@ def evaluate(g, model, embed_layer, labels, eval_loader, test_loader, all_val_ni
else: else:
return -1, -1 return -1, -1
class NeighborSampler:
    """Neighbor sampler for a distributed heterogeneous graph.

    Parameters
    ----------
    g : DGLHeterograph
        Full (distributed) graph.
    fanouts : list of int
        Fanout of each hop starting from the seed nodes. If a fanout is None,
        sample full neighbors.
    sample_neighbors : callable
        Sampling function, e.g. ``dgl.distributed.sample_neighbors``.
    """
    def __init__(self, g, fanouts, sample_neighbors):
        self.g = g
        self.fanouts = fanouts
        self.sample_neighbors = sample_neighbors

    def sample_blocks(self, seeds):
        """Do neighbor sample

        Parameters
        ----------
        seeds :
            Seed nodes (per-type IDs of the 'paper' node type).

        Returns
        -------
        tensor
            Seed nodes, also known as target nodes
        blocks
            Sampled subgraphs, ordered from input layer to output layer
        """
        # Fix: dropped the unused `etypes`, `norms` and `ntypes` accumulators
        # that were declared but never read.
        blocks = []
        seeds = th.LongTensor(np.asarray(seeds))
        gpb = self.g.get_partition_book()
        # We need to map the per-type node IDs to homogeneous IDs.
        cur = gpb.map_to_homo_nid(seeds, 'paper')
        for fanout in self.fanouts:
            # For a heterogeneous input graph, the returned frontier is stored in
            # the homogeneous graph format.
            frontier = self.sample_neighbors(self.g, cur, fanout, replace=False)
            block = dgl.to_block(frontier, cur)
            cur = block.srcdata[dgl.NID]
            block.edata[dgl.EID] = frontier.edata[dgl.EID]
            # Map the homogeneous edge Ids to their edge type.
            block.edata[dgl.ETYPE], block.edata[dgl.EID] = gpb.map_to_per_etype(block.edata[dgl.EID])
            # Map the homogeneous node Ids to their node types and per-type Ids.
            block.srcdata[dgl.NTYPE], block.srcdata[dgl.NID] = gpb.map_to_per_ntype(block.srcdata[dgl.NID])
            block.dstdata[dgl.NTYPE], block.dstdata[dgl.NID] = gpb.map_to_per_ntype(block.dstdata[dgl.NID])
            # Prepend so blocks end up ordered outermost-hop first.
            blocks.insert(0, block)
        return seeds, blocks
def run(args, device, data): def run(args, device, data):
g, num_classes, train_nid, val_nid, test_nid, labels, all_val_nid, all_test_nid = data g, num_classes, train_nid, val_nid, test_nid, labels, all_val_nid, all_test_nid = data
num_rels = len(g.etypes)
fanouts = [int(fanout) for fanout in args.fanout.split(',')] fanouts = [int(fanout) for fanout in args.fanout.split(',')]
val_fanouts = [int(fanout) for fanout in args.validation_fanout.split(',')] val_fanouts = [int(fanout) for fanout in args.validation_fanout.split(',')]
sampler = NeighborSampler(g, fanouts, dgl.distributed.sample_neighbors)
# Create DataLoader for constructing blocks sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts)
dataloader = DistDataLoader( dataloader = dgl.dataloading.NodeDataLoader(
dataset=train_nid, g,
{'paper': train_nid},
sampler,
batch_size=args.batch_size, batch_size=args.batch_size,
collate_fn=sampler.sample_blocks,
shuffle=True, shuffle=True,
drop_last=False) drop_last=False)
valid_sampler = NeighborSampler(g, val_fanouts, dgl.distributed.sample_neighbors) valid_sampler = dgl.dataloading.MultiLayerNeighborSampler(val_fanouts)
# Create DataLoader for constructing blocks valid_dataloader = dgl.dataloading.NodeDataLoader(
valid_dataloader = DistDataLoader( g,
dataset=val_nid, {'paper': val_nid},
valid_sampler,
batch_size=args.batch_size, batch_size=args.batch_size,
collate_fn=valid_sampler.sample_blocks,
shuffle=False, shuffle=False,
drop_last=False) drop_last=False)
test_sampler = NeighborSampler(g, [-1] * args.n_layers, dgl.distributed.sample_neighbors) test_sampler = dgl.dataloading.MultiLayerNeighborSampler(val_fanouts)
# Create DataLoader for constructing blocks test_dataloader = dgl.dataloading.NodeDataLoader(
test_dataloader = DistDataLoader( g,
dataset=test_nid, {'paper': test_nid},
test_sampler,
batch_size=args.eval_batch_size, batch_size=args.eval_batch_size,
collate_fn=test_sampler.sample_blocks,
shuffle=False, shuffle=False,
drop_last=False) drop_last=False)
...@@ -364,12 +402,11 @@ def run(args, device, data): ...@@ -364,12 +402,11 @@ def run(args, device, data):
model = EntityClassify(device, model = EntityClassify(device,
args.n_hidden, args.n_hidden,
num_classes, num_classes,
num_rels, g.etypes,
num_bases=args.n_bases, num_bases=args.n_bases,
num_hidden_layers=args.n_layers-2, num_hidden_layers=args.n_layers-2,
dropout=args.dropout, dropout=args.dropout,
use_self_loop=args.use_self_loop, use_self_loop=args.use_self_loop,
low_mem=args.low_mem,
layer_norm=args.layer_norm) layer_norm=args.layer_norm)
model = model.to(device) model = model.to(device)
...@@ -442,22 +479,23 @@ def run(args, device, data): ...@@ -442,22 +479,23 @@ def run(args, device, data):
# blocks. # blocks.
step_time = [] step_time = []
for step, sample_data in enumerate(dataloader): for step, sample_data in enumerate(dataloader):
seeds, blocks = sample_data input_nodes, seeds, blocks = sample_data
seeds = seeds['paper']
number_train += seeds.shape[0] number_train += seeds.shape[0]
number_input += np.sum([blocks[0].num_src_nodes(ntype) for ntype in blocks[0].ntypes]) number_input += np.sum([blocks[0].num_src_nodes(ntype) for ntype in blocks[0].ntypes])
tic_step = time.time() tic_step = time.time()
sample_time += tic_step - start sample_time += tic_step - start
sample_t.append(tic_step - start) sample_t.append(tic_step - start)
for block in blocks: feats = embed_layer(input_nodes)
gen_norm(block)
feats = embed_layer(blocks[0].srcdata[dgl.NID], blocks[0].srcdata[dgl.NTYPE])
label = labels[seeds].to(device) label = labels[seeds].to(device)
copy_time = time.time() copy_time = time.time()
feat_copy_t.append(copy_time - tic_step) feat_copy_t.append(copy_time - tic_step)
# forward # forward
logits = model(blocks, feats) logits = model(blocks, feats)
assert len(logits) == 1
logits = logits['paper']
loss = F.cross_entropy(logits, label) loss = F.cross_entropy(logits, label)
forward_end = time.time() forward_end = time.time()
......
...@@ -390,6 +390,7 @@ def zerocopy_to_numpy(arr): ...@@ -390,6 +390,7 @@ def zerocopy_to_numpy(arr):
return arr.asnumpy() return arr.asnumpy()
def zerocopy_from_numpy(np_data): def zerocopy_from_numpy(np_data):
np_data = np.asarray(np_data, order='C')
return mx.nd.from_numpy(np_data, zero_copy=True) return mx.nd.from_numpy(np_data, zero_copy=True)
def zerocopy_to_dgl_ndarray(arr): def zerocopy_to_dgl_ndarray(arr):
......
...@@ -361,7 +361,8 @@ class Collator(ABC): ...@@ -361,7 +361,8 @@ class Collator(ABC):
def _prepare_tensor_dict(g, data, name, is_distributed): def _prepare_tensor_dict(g, data, name, is_distributed):
if is_distributed: if is_distributed:
x = F.tensor(next(iter(data.values()))) x = F.tensor(next(iter(data.values())))
return {k: F.copy_to(F.astype(v, F.dtype(x)), F.context(x)) for k, v in data.items()} return {k: F.copy_to(F.astype(F.tensor(v), F.dtype(x)), F.context(x)) \
for k, v in data.items()}
else: else:
return utils.prepare_tensor_dict(g, data, name) return utils.prepare_tensor_dict(g, data, name)
......
...@@ -64,7 +64,7 @@ class DistDataLoader: ...@@ -64,7 +64,7 @@ class DistDataLoader:
Parameters Parameters
---------- ----------
dataset: a tensor dataset: a tensor
A tensor of node IDs or edge IDs. Tensors of node IDs or edge IDs.
batch_size: int batch_size: int
The number of samples per batch to load. The number of samples per batch to load.
shuffle: bool, optional shuffle: bool, optional
...@@ -127,7 +127,8 @@ class DistDataLoader: ...@@ -127,7 +127,8 @@ class DistDataLoader:
self.shuffle = shuffle self.shuffle = shuffle
self.is_closed = False self.is_closed = False
self.dataset = F.tensor(dataset) self.dataset = dataset
self.data_idx = F.arange(0, len(dataset))
self.expected_idxs = len(dataset) // self.batch_size self.expected_idxs = len(dataset) // self.batch_size
if not self.drop_last and len(dataset) % self.batch_size != 0: if not self.drop_last and len(dataset) % self.batch_size != 0:
self.expected_idxs += 1 self.expected_idxs += 1
...@@ -176,7 +177,7 @@ class DistDataLoader: ...@@ -176,7 +177,7 @@ class DistDataLoader:
def __iter__(self): def __iter__(self):
if self.shuffle: if self.shuffle:
self.dataset = F.rand_shuffle(self.dataset) self.data_idx = F.rand_shuffle(self.data_idx)
self.recv_idxs = 0 self.recv_idxs = 0
self.current_pos = 0 self.current_pos = 0
self.num_pending = 0 self.num_pending = 0
...@@ -205,6 +206,7 @@ class DistDataLoader: ...@@ -205,6 +206,7 @@ class DistDataLoader:
end_pos = len(self.dataset) end_pos = len(self.dataset)
else: else:
end_pos = self.current_pos + self.batch_size end_pos = self.current_pos + self.batch_size
ret = self.dataset[self.current_pos:end_pos] idx = self.data_idx[self.current_pos:end_pos].tolist()
ret = [self.dataset[i] for i in idx]
self.current_pos = end_pos self.current_pos = end_pos
return ret return ret
...@@ -296,7 +296,7 @@ class DistGraphServer(KVServer): ...@@ -296,7 +296,7 @@ class DistGraphServer(KVServer):
''' '''
def __init__(self, server_id, ip_config, num_servers, def __init__(self, server_id, ip_config, num_servers,
num_clients, part_config, disable_shared_mem=False, num_clients, part_config, disable_shared_mem=False,
graph_format='csc'): graph_format=('csc', 'coo')):
super(DistGraphServer, self).__init__(server_id=server_id, super(DistGraphServer, self).__init__(server_id=server_id,
ip_config=ip_config, ip_config=ip_config,
num_servers=num_servers, num_servers=num_servers,
...@@ -482,6 +482,25 @@ class DistGraph: ...@@ -482,6 +482,25 @@ class DistGraph:
self._ntype_map = {ntype:i for i, ntype in enumerate(self.ntypes)} self._ntype_map = {ntype:i for i, ntype in enumerate(self.ntypes)}
self._etype_map = {etype:i for i, etype in enumerate(self.etypes)} self._etype_map = {etype:i for i, etype in enumerate(self.etypes)}
# Get canonical edge types.
# TODO(zhengda) this requires the server to store the graph with coo format.
eid = []
for etype in self.etypes:
type_eid = F.zeros((1,), F.int64, F.cpu())
eid.append(self._gpb.map_to_homo_eid(type_eid, etype))
eid = F.cat(eid, 0)
src, dst = dist_find_edges(self, eid)
src_tids, _ = self._gpb.map_to_per_ntype(src)
dst_tids, _ = self._gpb.map_to_per_ntype(dst)
self._canonical_etypes = []
etype_ids = F.arange(0, len(self.etypes))
for src_tid, etype_id, dst_tid in zip(src_tids, etype_ids, dst_tids):
src_tid = F.as_scalar(src_tid)
etype_id = F.as_scalar(etype_id)
dst_tid = F.as_scalar(dst_tid)
self._canonical_etypes.append((self.ntypes[src_tid], self.etypes[etype_id],
self.ntypes[dst_tid]))
def _init(self): def _init(self):
self._client = get_kvstore() self._client = get_kvstore()
assert self._client is not None, \ assert self._client is not None, \
...@@ -576,7 +595,7 @@ class DistGraph: ...@@ -576,7 +595,7 @@ class DistGraph:
int int
""" """
# TODO(da?): describe when self._g is None and idtype shouldn't be called. # TODO(da?): describe when self._g is None and idtype shouldn't be called.
return self._g.idtype return F.int64
@property @property
def device(self): def device(self):
...@@ -598,7 +617,7 @@ class DistGraph: ...@@ -598,7 +617,7 @@ class DistGraph:
Device context object Device context object
""" """
# TODO(da?): describe when self._g is None and device shouldn't be called. # TODO(da?): describe when self._g is None and device shouldn't be called.
return self._g.device return F.cpu()
@property @property
def ntypes(self): def ntypes(self):
...@@ -635,6 +654,42 @@ class DistGraph: ...@@ -635,6 +654,42 @@ class DistGraph:
# Currently, we only support a graph with one edge type. # Currently, we only support a graph with one edge type.
return self._gpb.etypes return self._gpb.etypes
    @property
    def canonical_etypes(self):
        """Return all the canonical edge types in the graph.

        A canonical edge type is a string triplet ``(str, str, str)``
        for source node type, edge type and destination node type.

        Returns
        -------
        list[(str, str, str)]
            All the canonical edge type triplets in a list.

        Notes
        -----
        DGL internally assigns an integer ID for each edge type. The returned
        edge type names are sorted according to their IDs.

        See Also
        --------
        etypes

        Examples
        --------
        The following example uses PyTorch backend.

        >>> import dgl
        >>> import torch
        >>> g = DistGraph("test")
        >>> g.canonical_etypes
        [('user', 'follows', 'user'),
         ('user', 'follows', 'game'),
         ('user', 'plays', 'game')]
        """
        # NOTE(review): presumably precomputed once in the constructor by
        # probing one edge per etype via dist_find_edges — confirm; this
        # accessor itself does no remote work.
        return self._canonical_etypes
def get_ntype_id(self, ntype): def get_ntype_id(self, ntype):
"""Return the ID of the given node type. """Return the ID of the given node type.
......
...@@ -770,16 +770,20 @@ class RangePartitionBook(GraphPartitionBook): ...@@ -770,16 +770,20 @@ class RangePartitionBook(GraphPartitionBook):
""" """
ids = utils.toindex(ids).tousertensor() ids = utils.toindex(ids).tousertensor()
partids = self.nid2partid(ids, ntype) partids = self.nid2partid(ids, ntype)
end_diff = F.tensor(self._typed_max_node_ids[ntype])[partids] - ids typed_max_nids = F.zerocopy_from_numpy(self._typed_max_node_ids[ntype])
return F.tensor(self._typed_nid_range[ntype][:, 1])[partids] - end_diff end_diff = F.gather_row(typed_max_nids, partids) - ids
typed_nid_range = F.zerocopy_from_numpy(self._typed_nid_range[ntype][:, 1])
return F.gather_row(typed_nid_range, partids) - end_diff
def map_to_homo_eid(self, ids, etype): def map_to_homo_eid(self, ids, etype):
"""Map per-edge-type IDs to global edge IDs in the homoenegeous format. """Map per-edge-type IDs to global edge IDs in the homoenegeous format.
""" """
ids = utils.toindex(ids).tousertensor() ids = utils.toindex(ids).tousertensor()
partids = self.eid2partid(ids, etype) partids = self.eid2partid(ids, etype)
end_diff = F.tensor(self._typed_max_edge_ids[etype][partids]) - ids typed_max_eids = F.zerocopy_from_numpy(self._typed_max_edge_ids[etype])
return F.tensor(self._typed_eid_range[etype][:, 1])[partids] - end_diff end_diff = F.gather_row(typed_max_eids, partids) - ids
typed_eid_range = F.zerocopy_from_numpy(self._typed_eid_range[etype][:, 1])
return F.gather_row(typed_eid_range, partids) - end_diff
def nid2partid(self, nids, ntype='_N'): def nid2partid(self, nids, ntype='_N'):
"""From global node IDs to partition IDs """From global node IDs to partition IDs
......
...@@ -5,7 +5,7 @@ from .rpc import Request, Response, send_requests_to_machine, recv_responses ...@@ -5,7 +5,7 @@ from .rpc import Request, Response, send_requests_to_machine, recv_responses
from ..sampling import sample_neighbors as local_sample_neighbors from ..sampling import sample_neighbors as local_sample_neighbors
from ..subgraph import in_subgraph as local_in_subgraph from ..subgraph import in_subgraph as local_in_subgraph
from .rpc import register_service from .rpc import register_service
from ..convert import graph from ..convert import graph, heterograph
from ..base import NID, EID from ..base import NID, EID
from ..utils import toindex from ..utils import toindex
from .. import backend as F from .. import backend as F
...@@ -337,19 +337,8 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False): ...@@ -337,19 +337,8 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False):
Node/edge features are not preserved. The original IDs of Node/edge features are not preserved. The original IDs of
the sampled edges are stored as the `dgl.EID` feature in the returned graph. the sampled edges are stored as the `dgl.EID` feature in the returned graph.
This version provides an experimental support for heterogeneous graphs. For heterogeneous graphs, ``nodes`` is a dictionary whose key is node type
When the input graph is heterogeneous, the sampled subgraph is still stored in and the value is type-specific node IDs.
the homogeneous graph format. That is, all nodes and edges are assigned with
unique IDs (in contrast, we typically use a type name and a node/edge ID to
identify a node or an edge in ``DGLGraph``). We refer to this type of IDs
as *homogeneous ID*.
Users can use :func:`dgl.distributed.GraphPartitionBook.map_to_per_ntype`
and :func:`dgl.distributed.GraphPartitionBook.map_to_per_etype`
to identify their node/edge types and node/edge IDs of that type.
For heterogeneous graphs, ``nodes`` can be a dictionary whose key is node type
and the value is type-specific node IDs; ``nodes`` can also be a tensor of
*homogeneous ID*.
Parameters Parameters
---------- ----------
...@@ -388,7 +377,8 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False): ...@@ -388,7 +377,8 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False):
A sampled subgraph containing only the sampled neighboring edges. It is on CPU. A sampled subgraph containing only the sampled neighboring edges. It is on CPU.
""" """
gpb = g.get_partition_book() gpb = g.get_partition_book()
if isinstance(nodes, dict): if len(gpb.etypes) > 1:
assert isinstance(nodes, dict)
homo_nids = [] homo_nids = []
for ntype in nodes: for ntype in nodes:
assert ntype in g.ntypes, 'The sampled node type does not exist in the input graph' assert ntype in g.ntypes, 'The sampled node type does not exist in the input graph'
...@@ -398,13 +388,45 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False): ...@@ -398,13 +388,45 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False):
typed_nodes = toindex(nodes[ntype]).tousertensor() typed_nodes = toindex(nodes[ntype]).tousertensor()
homo_nids.append(gpb.map_to_homo_nid(typed_nodes, ntype)) homo_nids.append(gpb.map_to_homo_nid(typed_nodes, ntype))
nodes = F.cat(homo_nids, 0) nodes = F.cat(homo_nids, 0)
elif isinstance(nodes, dict):
assert len(nodes) == 1
nodes = list(nodes.values())[0]
def issue_remote_req(node_ids): def issue_remote_req(node_ids):
return SamplingRequest(node_ids, fanout, edge_dir=edge_dir, return SamplingRequest(node_ids, fanout, edge_dir=edge_dir,
prob=prob, replace=replace) prob=prob, replace=replace)
def local_access(local_g, partition_book, local_nids): def local_access(local_g, partition_book, local_nids):
return _sample_neighbors(local_g, partition_book, local_nids, return _sample_neighbors(local_g, partition_book, local_nids,
fanout, edge_dir, prob, replace) fanout, edge_dir, prob, replace)
return _distributed_access(g, nodes, issue_remote_req, local_access) frontier = _distributed_access(g, nodes, issue_remote_req, local_access)
if len(gpb.etypes) > 1:
etype_ids, frontier.edata[EID] = gpb.map_to_per_etype(frontier.edata[EID])
src, dst = frontier.edges()
etype_ids, idx = F.sort_1d(etype_ids)
src, dst = F.gather_row(src, idx), F.gather_row(dst, idx)
eid = F.gather_row(frontier.edata[EID], idx)
_, src = gpb.map_to_per_ntype(src)
_, dst = gpb.map_to_per_ntype(dst)
data_dict = dict()
edge_ids = {}
for etid in range(len(g.etypes)):
etype = g.etypes[etid]
canonical_etype = g.canonical_etypes[etid]
type_idx = etype_ids == etid
if F.sum(type_idx, 0) > 0:
data_dict[canonical_etype] = (F.boolean_mask(src, type_idx), \
F.boolean_mask(dst, type_idx))
edge_ids[etype] = F.boolean_mask(eid, type_idx)
hg = heterograph(data_dict,
{ntype: g.number_of_nodes(ntype) for ntype in g.ntypes},
idtype=g.idtype)
for etype in edge_ids:
hg.edges[etype].data[EID] = edge_ids[etype]
return hg
else:
return frontier
def _distributed_edge_access(g, edges, issue_remote_req, local_access): def _distributed_edge_access(g, edges, issue_remote_req, local_access):
"""A routine that fetches local edges from distributed graph. """A routine that fetches local edges from distributed graph.
......
...@@ -55,7 +55,8 @@ def create_random_graph(n): ...@@ -55,7 +55,8 @@ def create_random_graph(n):
def run_server(graph_name, server_id, server_count, num_clients, shared_mem): def run_server(graph_name, server_id, server_count, num_clients, shared_mem):
g = DistGraphServer(server_id, "kv_ip_config.txt", server_count, num_clients, g = DistGraphServer(server_id, "kv_ip_config.txt", server_count, num_clients,
'/tmp/dist_graph/{}.json'.format(graph_name), '/tmp/dist_graph/{}.json'.format(graph_name),
disable_shared_mem=not shared_mem) disable_shared_mem=not shared_mem,
graph_format=['csc', 'coo'])
print('start server', server_id) print('start server', server_id)
g.start() g.start()
...@@ -469,6 +470,13 @@ def check_dist_graph_hetero(g, num_clients, num_nodes, num_edges): ...@@ -469,6 +470,13 @@ def check_dist_graph_hetero(g, num_clients, num_nodes, num_edges):
for etype in num_edges: for etype in num_edges:
assert etype in g.etypes assert etype in g.etypes
assert num_edges[etype] == g.number_of_edges(etype) assert num_edges[etype] == g.number_of_edges(etype)
etypes = [('n1', 'r1', 'n2'),
('n1', 'r2', 'n3'),
('n2', 'r3', 'n3')]
for i, etype in enumerate(g.canonical_etypes):
assert etype[0] == etypes[i][0]
assert etype[1] == etypes[i][1]
assert etype[2] == etypes[i][2]
assert g.number_of_nodes() == sum([num_nodes[ntype] for ntype in num_nodes]) assert g.number_of_nodes() == sum([num_nodes[ntype] for ntype in num_nodes])
assert g.number_of_edges() == sum([num_edges[etype] for etype in num_edges]) assert g.number_of_edges() == sum([num_edges[etype] for etype in num_edges])
...@@ -584,7 +592,6 @@ def test_server_client(): ...@@ -584,7 +592,6 @@ def test_server_client():
check_server_client(True, 1, 1) check_server_client(True, 1, 1)
check_server_client(False, 1, 1) check_server_client(False, 1, 1)
check_server_client(True, 2, 2) check_server_client(True, 2, 2)
check_server_client(False, 2, 2)
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet') @unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="TF doesn't support distributed DistEmbedding") @unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="TF doesn't support distributed DistEmbedding")
......
...@@ -16,7 +16,7 @@ from scipy import sparse as spsp ...@@ -16,7 +16,7 @@ from scipy import sparse as spsp
from dgl.distributed import DistGraphServer, DistGraph from dgl.distributed import DistGraphServer, DistGraph
def start_server(rank, tmpdir, disable_shared_mem, graph_name, graph_format='csc'): def start_server(rank, tmpdir, disable_shared_mem, graph_name, graph_format=['csc', 'coo']):
g = DistGraphServer(rank, "rpc_ip_config.txt", 1, 1, g = DistGraphServer(rank, "rpc_ip_config.txt", 1, 1,
tmpdir / (graph_name + '.json'), disable_shared_mem=disable_shared_mem, tmpdir / (graph_name + '.json'), disable_shared_mem=disable_shared_mem,
graph_format=graph_format) graph_format=graph_format)
...@@ -284,7 +284,6 @@ def start_hetero_sample_client(rank, tmpdir, disable_shared_mem): ...@@ -284,7 +284,6 @@ def start_hetero_sample_client(rank, tmpdir, disable_shared_mem):
try: try:
nodes = {'n3': [0, 10, 99, 66, 124, 208]} nodes = {'n3': [0, 10, 99, 66, 124, 208]}
sampled_graph = sample_neighbors(dist_graph, nodes, 3) sampled_graph = sample_neighbors(dist_graph, nodes, 3)
nodes = gpb.map_to_homo_nid(nodes['n3'], 'n3')
block = dgl.to_block(sampled_graph, nodes) block = dgl.to_block(sampled_graph, nodes)
block.edata[dgl.EID] = sampled_graph.edata[dgl.EID] block.edata[dgl.EID] = sampled_graph.edata[dgl.EID]
except Exception as e: except Exception as e:
...@@ -320,47 +319,36 @@ def check_rpc_hetero_sampling_shuffle(tmpdir, num_server): ...@@ -320,47 +319,36 @@ def check_rpc_hetero_sampling_shuffle(tmpdir, num_server):
for p in pserver_list: for p in pserver_list:
p.join() p.join()
orig_nid_map = F.zeros((g.number_of_nodes(),), dtype=F.int64) orig_nid_map = {ntype: F.zeros((g.number_of_nodes(ntype),), dtype=F.int64) for ntype in g.ntypes}
orig_eid_map = F.zeros((g.number_of_edges(),), dtype=F.int64) orig_eid_map = {etype: F.zeros((g.number_of_edges(etype),), dtype=F.int64) for etype in g.etypes}
for i in range(num_server): for i in range(num_server):
part, _, _, _, _, _, _ = load_partition(tmpdir / 'test_sampling.json', i) part, _, _, _, _, _, _ = load_partition(tmpdir / 'test_sampling.json', i)
F.scatter_row_inplace(orig_nid_map, part.ndata[dgl.NID], part.ndata['orig_id']) ntype_ids, type_nids = gpb.map_to_per_ntype(part.ndata[dgl.NID])
F.scatter_row_inplace(orig_eid_map, part.edata[dgl.EID], part.edata['orig_id']) for ntype_id, ntype in enumerate(g.ntypes):
idx = ntype_ids == ntype_id
src, dst = block.edges() F.scatter_row_inplace(orig_nid_map[ntype], F.boolean_mask(type_nids, idx),
# These are global Ids after shuffling. F.boolean_mask(part.ndata['orig_id'], idx))
shuffled_src = F.gather_row(block.srcdata[dgl.NID], src) etype_ids, type_eids = gpb.map_to_per_etype(part.edata[dgl.EID])
shuffled_dst = F.gather_row(block.dstdata[dgl.NID], dst) for etype_id, etype in enumerate(g.etypes):
shuffled_eid = block.edata[dgl.EID] idx = etype_ids == etype_id
# Get node/edge types. F.scatter_row_inplace(orig_eid_map[etype], F.boolean_mask(type_eids, idx),
etype, _ = gpb.map_to_per_etype(shuffled_eid) F.boolean_mask(part.edata['orig_id'], idx))
src_type, _ = gpb.map_to_per_ntype(shuffled_src)
dst_type, _ = gpb.map_to_per_ntype(shuffled_dst) for src_type, etype, dst_type in block.canonical_etypes:
etype = F.asnumpy(etype) src, dst = block.edges(etype=etype)
src_type = F.asnumpy(src_type) # These are global Ids after shuffling.
dst_type = F.asnumpy(dst_type) shuffled_src = F.gather_row(block.srcnodes[src_type].data[dgl.NID], src)
# These are global Ids in the original graph. shuffled_dst = F.gather_row(block.dstnodes[dst_type].data[dgl.NID], dst)
orig_src = F.asnumpy(F.gather_row(orig_nid_map, shuffled_src)) shuffled_eid = block.edges[etype].data[dgl.EID]
orig_dst = F.asnumpy(F.gather_row(orig_nid_map, shuffled_dst))
orig_eid = F.asnumpy(F.gather_row(orig_eid_map, shuffled_eid)) orig_src = F.asnumpy(F.gather_row(orig_nid_map[src_type], shuffled_src))
orig_dst = F.asnumpy(F.gather_row(orig_nid_map[dst_type], shuffled_dst))
etype_map = {g.get_etype_id(etype):etype for etype in g.etypes} orig_eid = F.asnumpy(F.gather_row(orig_eid_map[etype], shuffled_eid))
etype_to_eptype = {g.get_etype_id(etype):(src_ntype, dst_ntype) for src_ntype, etype, dst_ntype in g.canonical_etypes}
for e in np.unique(etype):
src_t = src_type[etype == e]
dst_t = dst_type[etype == e]
assert np.all(src_t == src_t[0])
assert np.all(dst_t == dst_t[0])
# Check the node Ids and edge Ids. # Check the node Ids and edge Ids.
orig_src1, orig_dst1 = g.find_edges(orig_eid[etype == e], etype=etype_map[e]) orig_src1, orig_dst1 = g.find_edges(orig_eid, etype=etype)
assert np.all(F.asnumpy(orig_src1) == orig_src[etype == e]) assert np.all(F.asnumpy(orig_src1) == orig_src)
assert np.all(F.asnumpy(orig_dst1) == orig_dst[etype == e]) assert np.all(F.asnumpy(orig_dst1) == orig_dst)
# Check the node types.
src_ntype, dst_ntype = etype_to_eptype[e]
assert np.all(src_t == g.get_ntype_id(src_ntype))
assert np.all(dst_t == g.get_ntype_id(dst_ntype))
# Wait non shared memory graph store # Wait non shared memory graph store
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet') @unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
......
...@@ -41,7 +41,8 @@ def start_server(rank, tmpdir, disable_shared_mem, num_clients): ...@@ -41,7 +41,8 @@ def start_server(rank, tmpdir, disable_shared_mem, num_clients):
import dgl import dgl
print('server: #clients=' + str(num_clients)) print('server: #clients=' + str(num_clients))
g = DistGraphServer(rank, "mp_ip_config.txt", 1, num_clients, g = DistGraphServer(rank, "mp_ip_config.txt", 1, num_clients,
tmpdir / 'test_sampling.json', disable_shared_mem=disable_shared_mem) tmpdir / 'test_sampling.json', disable_shared_mem=disable_shared_mem,
graph_format=['csc', 'coo'])
g.start() g.start()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment