Unverified commit c1e01b1d authored by kylasa, committed by GitHub

[Distributed] use alltoall fix to bypass gloo - alltoallv bug in distributed partitioning (#4311)

* Alltoall fix to bypass the gloo alltoallv bug that was preventing further testing

1. Replaced the alltoallv gloo wrapper call with an alltoall message.
2. All messages are padded to the same length (see the sketch after this list).
3. The receiving side unpads the messages and continues processing.
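The core of the workaround is to make every per-rank message the same size so a fixed-size alltoall can carry it, and to strip the padding again on the receiving side. Below is a minimal, standalone sketch of just that padding/unpadding step; `pad_to_max` and `unpad` are hypothetical helper names used for illustration only, not functions added by this PR:

```python
import torch

def pad_to_max(tensors):
    """Pad each tensor along dim 0 to the longest length in the list.

    Returns the padded tensors (now all the same shape, so a plain
    fixed-size alltoall can move them) and the original lengths needed
    to strip the padding on the receiving side.
    """
    lengths = [t.shape[0] for t in tensors]
    max_len = max(lengths)
    padded = [
        torch.cat((t, torch.zeros((max_len - t.shape[0],) + tuple(t.shape[1:]), dtype=t.dtype)))
        for t in tensors
    ]
    return padded, lengths

def unpad(tensors, lengths):
    """Drop the padded rows using the lengths exchanged alongside the data."""
    return [t[:n] for t, n in zip(tensors, lengths)]

if __name__ == "__main__":
    msgs = [torch.arange(6).reshape(3, 2), torch.arange(2).reshape(1, 2)]
    padded, lengths = pad_to_max(msgs)
    assert all(p.shape == padded[0].shape for p in padded)  # equal-sized messages
    restored = unpad(padded, lengths)
    assert all(torch.equal(a, b) for a, b in zip(msgs, restored))
```

In the PR itself, the per-rank sizes are exchanged first via a small fixed-size alltoall so that each receiver can allocate buffers of the right size before the padded payload arrives.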

* Code changes to address CI comments

1. Removed unused functions from gloo_wrapper.py.
2. Changed the function signature of alltoallv_cpu_data as suggested.
3. Added a docstring describing the functionality inside alltoallv_cpu_data in more detail, and included more asserts to validate the assumptions.

* Changed the function name appropriately

Changed the function name from "alltoallv_cpu_data" to "alltoallv_cpu", which I believe is appropriate because the underlying functionality provides alltoallv, which is essentially alltoall_cpu plus padding. A rough usage sketch follows.
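For context, here is a rough sketch of how a caller might drive the renamed wrapper, mirroring the call shape used in this PR: a list of per-rank tensors in, a list of received tensors out, with `None` entries for peers that sent nothing. This assumes the repo's tools/distpartitioning directory is on PYTHONPATH so `gloo_wrapper` is importable; the rendezvous address/port is arbitrary:

```python
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from gloo_wrapper import alltoallv_cpu  # wrapper renamed in this PR

def worker(rank, world_size):
    dist.init_process_group(
        backend="gloo",
        init_method="tcp://127.0.0.1:29501",  # arbitrary local rendezvous for this sketch
        rank=rank,
        world_size=world_size,
    )
    # One (possibly empty) 2-D tensor destined for every rank, ragged along dim 0.
    input_list = [
        torch.full((rank + dst, 2), rank, dtype=torch.int64) for dst in range(world_size)
    ]
    output_list = alltoallv_cpu(rank, world_size, input_list)
    # Entries can be None when the sender had nothing for us, so filter before use.
    received = torch.cat([t for t in output_list if t is not None])
    print(f"rank {rank} received {received.shape[0]} rows")
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)
```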

* Added code and text to address the review comments.

1. Changed the function name to indicate the local use of this function.
2. Changed the docstring to state the assumptions made by the alltoallv_cpu function.

* Removed unused function from import statement

Removed the now-unused function from the import statements.
parent 7f8e1cf2
@@ -15,8 +15,8 @@ from utils import read_ntype_partition_files, read_json, get_node_types, \
     augment_edge_data, get_gnid_range_map, \
     write_dgl_objects, write_metadata_json, get_ntype_featnames, \
     get_idranges
-from gloo_wrapper import alltoall_cpu_object_lst, alltoallv_cpu, \
-    alltoall_cpu, allgather_sizes, gather_metadata_json
+from gloo_wrapper import allgather_sizes, gather_metadata_json,\
+    alltoallv_cpu
 from globalids import assign_shuffle_global_nids_nodes, \
     assign_shuffle_global_nids_edges, \
     get_shuffle_global_nids_edges
@@ -140,8 +140,6 @@ def exchange_edge_data(rank, world_size, edge_data):
     """
     input_list = []
-    send_sizes = []
-    recv_sizes = []
     start = timer()
     for i in np.arange(world_size):
         send_idx = (edge_data[constants.OWNER_PROCESS] == i)
@@ -152,23 +150,13 @@ def exchange_edge_data(rank, world_size, edge_data):
                     edge_data[constants.ETYPE_ID][send_idx == 1], \
                     edge_data[constants.GLOBAL_EID][send_idx == 1]))
         if(filt_data.shape[0] <= 0):
-            input_list.append(torch.empty((0,), dtype=torch.int64))
-            send_sizes.append(torch.empty((0,), dtype=torch.int64))
+            input_list.append(torch.empty((0,5), dtype=torch.int64))
         else:
             input_list.append(torch.from_numpy(filt_data))
-            send_sizes.append(torch.tensor(filt_data.shape, dtype=torch.int64))
-        recv_sizes.append(torch.zeros((2,), dtype=torch.int64))
     end = timer()
     dist.barrier ()
-    start = timer()
-    alltoall_cpu(rank, world_size, recv_sizes, send_sizes)
-    output_list = []
-    for s in recv_sizes:
-        output_list.append(torch.zeros(s.tolist(), dtype=torch.int64))
-    dist.barrier ()
-    alltoallv_cpu(rank, world_size, output_list, input_list)
+    output_list = alltoallv_cpu(rank, world_size, input_list)
     end = timer()
     print('[Rank: ', rank, '] Time to send/rcv edge data: ', timedelta(seconds=end-start))
@@ -197,6 +185,9 @@ def exchange_node_features(rank, world_size, node_feature_tids, ntype_gnid_map,
        retrieved.
     d. After receiving the corresponding shuffle_global_nids these ids are added to the
        node_data and edge_data dictionaries
+    This pipeline assumes all the input data in numpy format, except node/edge features which
+    are maintained as tensors throughout the various stages of the pipeline execution.
 
     Parameters:
     -----------
@@ -246,6 +237,7 @@ def exchange_node_features(rank, world_size, node_feature_tids, ntype_gnid_map,
             global_nid_per_rank = []
             feat_name = feat_info[0]
             feat_key = ntype_name+'/'+feat_name
+            print('[Rank: ', rank, '] processing node feature: ', feat_key)
 
             #compute the global_nid range for this node features
             type_nid_start = int(feat_info[1])
@@ -273,29 +265,20 @@ def exchange_node_features(rank, world_size, node_feature_tids, ntype_gnid_map,
                 local_idx_partid = local_idx[cond]
                 if (gnids_per_partid.shape[0] == 0):
-                    node_feats_per_rank.append({feat_key : torch.empty((0,), dtype=torch.float)})
-                    global_nid_per_rank.append({feat_key : torch.empty((0,), dtype=torch.int64)})
+                    node_feats_per_rank.append(torch.empty((0,1), dtype=torch.float))
+                    global_nid_per_rank.append(np.empty((0,1), dtype=np.int64))
                 else:
-                    node_feats_per_rank.append({feat_key : node_feats[local_idx_partid]})
-                    global_nid_per_rank.append({feat_key : gnids_per_partid})
+                    node_feats_per_rank.append(node_feats[local_idx_partid])
+                    global_nid_per_rank.append(torch.from_numpy(gnids_per_partid).type(torch.int64))
             #features (and global nids) per rank to be sent out are ready
             #for transmission, perform alltoallv here.
-            output_feat_list = alltoall_cpu_object_lst(rank, world_size, node_feats_per_rank)
-            output_feat_list[rank] = node_feats_per_rank[rank]
-            output_nid_list = alltoall_cpu_object_lst(rank, world_size, global_nid_per_rank)
-            output_nid_list[rank] = global_nid_per_rank[rank]
+            output_feat_list = alltoallv_cpu(rank, world_size, node_feats_per_rank)
+            output_nid_list = alltoallv_cpu(rank, world_size, global_nid_per_rank)
             #stitch node_features together to form one large feature tensor
-            own_node_features[feat_key] = []
-            own_global_nids[feat_key] = []
-            for idx, x in enumerate(output_feat_list):
-                own_node_features[feat_key].append(x[feat_key])
-                own_global_nids[feat_key].append(output_nid_list[idx][feat_key])
-    for k in own_node_features.keys():
-        own_node_features[k] = torch.cat(own_node_features[k])
-        own_global_nids[k] = np.concatenate(own_global_nids[k])
+            own_node_features[feat_key] = torch.cat(output_feat_list)
+            own_global_nids[feat_key] = torch.cat(output_nid_list).numpy()
     end = timer()
     print('[Rank: ', rank, '] Total time for node feature exchange: ', timedelta(seconds = end - start))
@@ -3,7 +3,7 @@ import torch
 import operator
 import itertools
 import constants
-from gloo_wrapper import allgather_sizes, alltoall_cpu, alltoallv_cpu
+from gloo_wrapper import allgather_sizes, alltoallv_cpu
 
 def get_shuffle_global_nids(rank, world_size, global_nids_ranks, node_data):
     """
@@ -29,53 +29,27 @@ def get_shuffle_global_nids(rank, world_size, global_nids_ranks, node_data):
        from other processes.
     """
     #build a list of sizes (lengths of lists)
-    sizes = [len(x) for x in global_nids_ranks]
-    #compute total_nodes whose mappings should be resolved, between orig-node-id <-> global-id
-    total_nodes = np.sum(sizes)
-    if (total_nodes == 0):
-        print('Rank: ', rank, ' -- All mappings are present locally... No need for to send any info.')
-        return None
-    #determine the no. of global_node_ids to send and receive and perform alltoall
-    send_counts = list(torch.Tensor(sizes).type(dtype=torch.int64).chunk(world_size))
-    recv_counts = list(torch.zeros([world_size], dtype=torch.int64).chunk(world_size))
-    alltoall_cpu(rank, world_size, recv_counts, send_counts)
-    #allocate buffers to receive node-ids
-    recv_nodes = []
-    for i in recv_counts:
-        recv_nodes.append(torch.zeros(i.tolist(), dtype=torch.int64))
-    #form the outgoing message
-    send_nodes = []
-    for i in range(world_size):
-        send_nodes.append(torch.from_numpy(global_nids_ranks[i]).type(dtype=torch.int64))
-    #send-recieve messages
-    alltoallv_cpu(rank, world_size, recv_nodes, send_nodes)
-    # allocate buffers to receive global-ids
-    recv_shuffle_global_nids = []
-    for i in sizes:
-        recv_shuffle_global_nids.append(torch.zeros((i), dtype=torch.int64))
+    global_nids_ranks = [torch.from_numpy(x) for x in global_nids_ranks]
+    recv_nodes = alltoallv_cpu(rank, world_size, global_nids_ranks)
     # Use node_data to lookup global id to send over.
     send_nodes = []
     for proc_i_nodes in recv_nodes:
         #list of node-ids to lookup
-        global_nids = proc_i_nodes.numpy()
-        if (len(global_nids) != 0):
-            common, ind1, ind2 = np.intersect1d(node_data[constants.GLOBAL_NID], global_nids, return_indices=True)
-            shuffle_global_nids = node_data[constants.SHUFFLE_GLOBAL_NID][ind1]
-            send_nodes.append(torch.from_numpy(shuffle_global_nids).type(dtype=torch.int64))
+        if proc_i_nodes is not None:
+            global_nids = proc_i_nodes.numpy()
+            if(len(global_nids) != 0):
+                common, ind1, ind2 = np.intersect1d(node_data[constants.GLOBAL_NID], global_nids, return_indices=True)
+                shuffle_global_nids = node_data[constants.SHUFFLE_GLOBAL_NID][ind1]
+                send_nodes.append(torch.from_numpy(shuffle_global_nids).type(dtype=torch.int64))
+            else:
+                send_nodes.append(torch.empty((0), dtype=torch.int64))
         else:
             send_nodes.append(torch.empty((0), dtype=torch.int64))
     #send receive global-ids
-    alltoallv_cpu(rank, world_size, recv_shuffle_global_nids, send_nodes)
-    shuffle_global_nids = np.concatenate([x.numpy() for x in recv_shuffle_global_nids])
+    recv_shuffle_global_nids = alltoallv_cpu(rank, world_size, send_nodes)
+    shuffle_global_nids = np.concatenate([x.numpy() if x is not None else [] for x in recv_shuffle_global_nids])
     global_nids = np.concatenate([x for x in global_nids_ranks])
     ret_val = np.column_stack([global_nids, shuffle_global_nids])
     return ret_val
@@ -111,13 +85,13 @@ def get_shuffle_global_nids_edges(rank, world_size, edge_data, node_part_ids, no
     global_nids_ranks = []
     for i in range(world_size):
         if (i == rank):
-            global_nids_ranks.append(np.empty(shape=(0)))
+            global_nids_ranks.append(np.empty(shape=(0), dtype=np.int64))
             continue
         #not_owned_nodes = part_ids[:,0][part_ids[:,1] == i]
         not_owned_node_ids = np.where(part_ids == i)[0]
         if not_owned_node_ids.shape[0] == 0:
-            not_owned_nodes = np.empty(shape=(0))
+            not_owned_nodes = np.empty(shape=(0), dtype=np.int64)
         else:
             not_owned_nodes = global_nids[not_owned_node_ids]
         global_nids_ranks.append(not_owned_nodes)
@@ -39,7 +39,7 @@ def allgather_sizes(send_data, world_size):
     return rank_sizes
 
-def alltoall_cpu(rank, world_size, output_tensor_list, input_tensor_list):
+def __alltoall_cpu(rank, world_size, output_tensor_list, input_tensor_list):
     """
     Each process scatters list of input tensors to all processes in a cluster
     and return gathered list of tensors in output list. The tensors should have the same shape.
@@ -59,13 +59,22 @@ def alltoall_cpu(rank, world_size, output_tensor_list, input_tensor_list):
     for i in range(world_size):
         dist.scatter(output_tensor_list[i], input_tensor_list if i == rank else [], src=i)
 
-def alltoall_cpu_object_lst(rank, world_size, input_list):
+def alltoallv_cpu(rank, world_size, input_tensor_list):
     """
-    Each process scatters list of input objects to all processes in a cluster
-    and return gathered list of objects in output list.
-
-    Parameters
-    ----------
+    Wrapper function to providing the alltoallv functionality by using underlying alltoall
+    messaging primitive. This function, in its current implementation, supports exchanging
+    messages of arbitrary dimensions and is not tied to the user of this function.
+
+    This function pads all input tensors, except one, so that all the messages are of the same
+    size. Once the messages are padded, It first sends a vector whose first two elements are
+    1) actual message size along first dimension, and 2) Message size along first dimension
+    which is used for communication. The rest of the dimensions are assumed to be same across
+    all the input tensors. After receiving the message sizes, the receiving end will create buffers
+    of appropriate sizes. And then slices the received messages to remove the added padding, if any,
+    and returns to the caller.
+
+    Parameters:
+    -----------
     rank : int
         The rank of current worker
     world_size : int
@@ -73,57 +82,62 @@ def alltoall_cpu_object_lst(rank, world_size, input_list):
     input_tensor_list : List of tensor
         The tensors to exchange
 
-    Returns
-    -------
-    list: list of objects are received from other processes
-        This is the list of objects which are sent to the current process by
-        other processes as part of this exchange
-    """
-    rcv_list = []
-    output_list = [None] * world_size
-    for i in range(world_size):
-        rcv_list.clear()
-        rcv_list.append(None)
-        if (i == rank):
-            dist.scatter_object_list(rcv_list, input_list, src = rank)
-        else:
-            send_list = [None] * world_size
-            dist.scatter_object_list(rcv_list, send_list, src = i)
-        output_list[i] = rcv_list[0]
-    return output_list
-
-def alltoallv_cpu(rank, world_size, output_tensor_list, input_tensor_list):
-    """
-    Each process scatters list of input tensors to all processes in a cluster
-    and return gathered list of tensors in output list.
-
-    Parameters
-    ----------
-    rank : int
-        The rank of current worker
-    world_size : int
-        The size of the entire
-    output_tensor_list : List of tensor
-        The received tensors
-    input_tensor_list : List of tensor
-        The tensors to exchange
-    """
-    # send tensor to each target trainer using torch.distributed.isend
-    # isend is async
-    senders = []
-    for i in range(world_size):
-        if i == rank:
-            output_tensor_list[i] = input_tensor_list[i].to(torch.device('cpu'))
-        else:
-            sender = dist.isend(input_tensor_list[i].to(torch.device('cpu')), dst=i)
-            senders.append(sender)
-    for i in range(world_size):
-        if i != rank:
-            dist.recv(output_tensor_list[i], src=i)
-    torch.distributed.barrier()
+    Returns:
+    --------
+    list :
+        list of tensors received from other processes during alltoall message
+    """
+    #ensure len of input_tensor_list is same as the world_size.
+    assert input_tensor_list != None
+    assert len(input_tensor_list) == world_size
+
+    #ensure that all the tensors in the input_tensor_list are of same size.
+    sizes = [list(x.size()) for x in input_tensor_list]
+    for idx in range(1,len(sizes)):
+        assert len(sizes[idx-1]) == len(sizes[idx]) #no. of dimensions should be same
+        assert input_tensor_list[idx-1].dtype == input_tensor_list[idx].dtype # dtype should be same
+        assert sizes[idx-1][1:] == sizes[idx][1:] #except first dimension remaining dimensions should all be the same
+
+    #decide how much to pad.
+    #always use the first-dimension for padding.
+    ll = [ x[0] for x in sizes ]
+
+    #dims of the padding needed, if any
+    #these dims are used for padding purposes.
+    diff_dims = [ [np.amax(ll) - l[0]] + l[1:] for l in sizes ]
+
+    #pad the actual message
+    input_tensor_list = [torch.cat((x, torch.zeros(diff_dims[idx]).type(x.dtype))) for idx, x in enumerate(input_tensor_list)]
+
+    #send useful message sizes to all
+    send_counts = []
+    recv_counts = []
+    for idx in range(world_size):
+        #send a vector, of atleast 3 elements, [a, b, ....] where
+        #a = useful message dim, b = actual message outgoing message size along the first dimension
+        #and remaining elements are the remaining dimensions of the tensor
+        send_counts.append(torch.from_numpy(np.array([sizes[idx][0]] + [np.amax(ll)] + sizes[idx][1:] )).type(torch.int64))
+        recv_counts.append(torch.zeros((1 + len(sizes[idx])), dtype=torch.int64))
+    __alltoall_cpu(rank, world_size, recv_counts, send_counts)
+
+    #allocate buffers for receiving message
+    output_tensor_list = []
+    recv_counts = [ tsize.numpy() for tsize in recv_counts]
+    for idx, tsize in enumerate(recv_counts):
+        output_tensor_list.append(torch.zeros(tuple(tsize[1:])).type(input_tensor_list[idx].dtype))
+
+    #send actual message itself.
+    __alltoall_cpu(rank, world_size, output_tensor_list, input_tensor_list)
+
+    #extract un-padded message from the output_tensor_list and return it
+    return_vals = []
+    for s, t in zip(recv_counts, output_tensor_list):
+        if s[0] == 0:
+            return_vals.append(None)
+        else:
+            return_vals.append(t[0:s[0]])
+    return return_vals
 
 def gather_metadata_json(metadata, rank, world_size):
     """