Unverified commit 208d9670, authored by Thor Johnsen, committed by GitHub

Merge pull request #1429 from NVIDIA/update_spatial_bottleneck

Bug fixes, perf improvements
parents a29a698f f687e7fa
@@ -152,16 +152,6 @@ def spatial_parallel_bottleneck(C, dtype, explicit_nhwc, gt_bottleneck, spatial_
         sb[n].copy_(b)
     return spatial_bottleneck

-#class HaloExchangerNoComm(HaloExchanger):
-#    def __init__(self, world_size, spatial_group_size, rank, comm):
-#class HaloExchangerAllGather(HaloExchanger):
-#    def __init__(self, world_size, spatial_group_size, rank, comm):
-#class HaloExchangerSendRecv(HaloExchanger):
-#    def __init__(self, world_size, spatial_group_size, rank, comm):
-#class HaloExchangerPeer(HaloExchanger):
-#    def __init__(self, world_size, spatial_group_size, rank, comm, peer_pool, explicit_nhwc, numSM=1):
 def n_way_spatial(halex, gt_bottleneck, gt, explicit_nhwc, world_size, rank, fp32_reduce=False):
     assert(explicit_nhwc), "Only tested for explicit nhwc"
@@ -228,15 +218,29 @@ def main():
     #print_bottleneck_p_and_b(gt_bottleneck)
     #print_bottleneck_p_and_b(spatial_bottleneck)
+    group_size = world_size
+    group = rank // group_size
+    ranks = [group*group_size+i for i in range(group_size)]
+    rank_in_group = rank % group_size
     spatial_group_size = world_size
     spatial_communicator = None
-    peer_pool = PeerMemoryPool(rank, world_size, spatial_group_size, 64*1024*1024, 2*1024*1024)
+    peer_pool = PeerMemoryPool(64*1024*1024, 2*1024*1024, ranks)
+    #class HaloExchangerNoComm(HaloExchanger):
+    #    def __init__(self, ranks, rank_in_group):
+    #class HaloExchangerAllGather(HaloExchanger):
+    #    def __init__(self, ranks, rank_in_group, comm):
+    #class HaloExchangerSendRecv(HaloExchanger):
+    #    def __init__(self, ranks, rank_in_group):
+    #class HaloExchangerPeer(HaloExchanger):
+    #    def __init__(self, ranks, rank_in_group, peer_pool, explicit_nhwc, numSM=1):
-    #halex = HaloExchangerAllGather(world_size, spatial_group_size, rank, spatial_communicator)
+    #halex = HaloExchangerAllGather(ranks, rank_in_group)
-    #halex = HaloExchangerSendRecv(world_size, spatial_group_size, rank, spatial_communicator)
+    #halex = HaloExchangerSendRecv(ranks, rank_in_group)
-    halex = HaloExchangerPeer(world_size, spatial_group_size, rank, spatial_communicator, peer_pool, explicit_nhwc, numSM=1)
+    halex = HaloExchangerPeer(ranks, rank_in_group, peer_pool, explicit_nhwc, numSM=1)
     #print("halex.signals = %s" % (str(halex.signals)))
     # Make sure peer memory halo exchanger has finished initializing flags on all ranks before proceeding
     #torch.cuda.synchronize()
...
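For reference, a minimal sketch of how the new group-centric arguments are put together, mirroring the driver code above. The import paths and the helper name make_peer_halo_exchanger are assumptions for illustration; only the constructor signatures come from this commit.

# Hedged sketch, not part of the commit. Import paths below are assumed.
import torch.distributed as dist
from apex.contrib.peer_memory import PeerMemoryPool                      # assumed path
from apex.contrib.bottleneck.halo_exchangers import HaloExchangerPeer    # assumed path

def make_peer_halo_exchanger(spatial_group_size, explicit_nhwc=True):
    # Carve the world into contiguous spatial groups and describe this rank's group.
    rank = dist.get_rank()
    group = rank // spatial_group_size
    ranks = [group * spatial_group_size + i for i in range(spatial_group_size)]
    rank_in_group = rank % spatial_group_size
    # New PeerMemoryPool signature: sizes first, then the explicit rank list.
    peer_pool = PeerMemoryPool(64*1024*1024, 2*1024*1024, ranks)
    # New exchanger signature: (ranks, rank_in_group, ...) replaces (world_size, spatial_group_size, rank, comm, ...).
    return HaloExchangerPeer(ranks, rank_in_group, peer_pool, explicit_nhwc, numSM=1)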
@@ -9,17 +9,23 @@ import peer_memory_cuda as pm
 # NB! This is only useful for performance testing.
 # NB! Do not use for actual production runs
 class HaloExchanger(object):
-    def __init__(self, spatial_group_size, rank):
+    def __init__(self, ranks, rank_in_group):
         self.stream1 = torch.cuda.Stream()
         self.stream2 = torch.cuda.Stream()
         self.stream3 = torch.cuda.Stream()
-        spatial_rank = rank % spatial_group_size
-        self.left_zero = True if spatial_rank == 0 else False
-        self.right_zero = True if spatial_rank == spatial_group_size - 1 else False
+        self.group_size = len(ranks)
+        self.ranks = ranks
+        self.rank_in_group = rank_in_group
+        self.wrap_around_left_rank_in_group = (rank_in_group + self.group_size - 1) % self.group_size
+        self.wrap_around_right_rank_in_group = (rank_in_group + 1) % self.group_size
+        self.left_rank = ranks[rank_in_group-1] if rank_in_group > 0 else -1
+        self.left_zero = True if rank_in_group == 0 else False
+        self.right_rank = ranks[rank_in_group+1] if rank_in_group < self.group_size - 1 else -1
+        self.right_zero = True if rank_in_group == self.group_size - 1 else False

 class HaloExchangerNoComm(HaloExchanger):
-    def __init__(self, world_size, spatial_group_size, rank, comm):
-        super(HaloExchangerNoComm, self).__init__(spatial_group_size, rank)
+    def __init__(self, ranks, rank_in_group):
+        super(HaloExchangerNoComm, self).__init__(ranks, rank_in_group)
     def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None):
         if left_input_halo is None:
@@ -29,10 +35,9 @@ class HaloExchangerNoComm(HaloExchanger):
             right_input_halo.copy_(left_output_halo)

 class HaloExchangerAllGather(HaloExchanger):
-    def __init__(self, world_size, spatial_group_size, rank, comm):
-        super(HaloExchangerAllGather, self).__init__(spatial_group_size, rank)
-        self.spatial_group_size = spatial_group_size
-        self.local_rank = rank % spatial_group_size
+    def __init__(self, ranks, rank_in_group, comm):
+        super(HaloExchangerAllGather, self).__init__(ranks, rank_in_group)
+        # self.comm must be NCCL process_group created with torch.distributed.new_group(ranks=ranks)
         self.comm = comm
     def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None):
@@ -40,11 +45,11 @@ class HaloExchangerAllGather(HaloExchanger):
         send_halos = torch.empty((N,2*Hh,W,C),dtype=left_output_halo.dtype,device=left_output_halo.device)
         send_halos[:,:Hh,:,:].copy_(left_output_halo)
         send_halos[:,Hh:,:,:].copy_(right_output_halo)
-        all_halos = torch.empty((N,2*Hh*self.spatial_group_size,W,C),dtype=left_output_halo.dtype,device=left_output_halo.device)
-        all_halos = [all_halos[:,i*2*Hh:(i+1)*2*Hh,:,:] for i in range(self.spatial_group_size)]
+        all_halos = torch.empty((N,2*Hh*self.group_size,W,C),dtype=left_output_halo.dtype,device=left_output_halo.device)
+        all_halos = [all_halos[:,i*2*Hh:(i+1)*2*Hh,:,:] for i in range(self.group_size)]
         torch.distributed.all_gather(all_halos,send_halos,group=self.comm,no_copy=True)
-        ag_left_input_halo = all_halos[(self.spatial_group_size+self.local_rank-1)%self.spatial_group_size][:,Hh:,:,:]
-        ag_right_input_halo = all_halos[(self.local_rank+1)%self.spatial_group_size][:,:Hh,:,:]
+        ag_left_input_halo = all_halos[self.wrap_around_left_rank_in_group][:,Hh:,:,:]
+        ag_right_input_halo = all_halos[self.wrap_around_right_rank_in_group][:,:Hh,:,:]
         if left_input_halo is None:
             if self.left_zero:
                 ag_left_input_halo.zero_()
@@ -62,35 +67,35 @@ class HaloExchangerAllGather(HaloExchanger):
             right_input_halo.copy_(ag_right_input_halo)

 class HaloExchangerSendRecv(HaloExchanger):
-    def __init__(self, world_size, spatial_group_size, rank, comm):
-        super(HaloExchangerSendRecv, self).__init__(spatial_group_size, rank)
-        self.world_size = world_size
-        self.spatial_group_size = spatial_group_size
+    def __init__(self, ranks, rank_in_group):
+        super(HaloExchangerSendRecv, self).__init__(ranks, rank_in_group)
         nccl_id = inc.get_unique_nccl_id(1).cuda()
         torch.distributed.broadcast(nccl_id, 0)
         nccl_id = nccl_id.cpu()
-        self.handle = inc.init_nccl_comm(nccl_id, rank, world_size)
+        print("%d :: nccl_id = %s" % (torch.distributed.get_rank(), str(nccl_id)))
+        # Create another global nccl communicator in addition to the one created by torch.distributed.init_process_group("nccl")
+        # This is unavoidable because the underlying NCCL communicator torch.distributed creates is a protected variable, hence
+        # it cannot be accessed from another class.
+        # TODO: Figure out a way to avoid creating a second global communicator
+        assert(torch.distributed.get_rank() == self.ranks[self.rank_in_group]), "ranks[%d](%d) != torch.distributed.get_rank()(%d)" % (self.rank_in_group, self.ranks[self.rank_in_group], torch.distributed.get_rank())
+        self.handle = inc.init_nccl_comm(nccl_id, torch.distributed.get_rank(), torch.distributed.get_world_size())
     def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None):
         if left_input_halo is None:
-            left_input_halo, right_input_halo = inc.left_right_halo_exchange(self.handle, self.left_zero, self.right_zero, left_output_halo, right_output_halo, self.spatial_group_size)
+            left_input_halo, right_input_halo = inc.left_right_halo_exchange(self.handle, self.left_rank, self.right_rank, left_output_halo, right_output_halo)
             return left_input_halo, right_input_halo
         else:
-            inc.left_right_halo_exchange_inplace(self.handle, self.left_zero, self.right_zero, left_output_halo, right_output_halo, left_input_halo, right_input_halo, self.spatial_group_size)
+            inc.left_right_halo_exchange_inplace(self.handle, self.left_rank, self.right_rank, left_output_halo, right_output_halo, left_input_halo, right_input_halo)

 class HaloExchangerPeer(HaloExchanger):
-    def __init__(self, world_size, spatial_group_size, rank, comm, peer_pool, explicit_nhwc, numSM=1):
-        super(HaloExchangerPeer, self).__init__(spatial_group_size, rank)
+    def __init__(self, ranks, rank_in_group, peer_pool, explicit_nhwc, numSM=1):
+        super(HaloExchangerPeer, self).__init__(ranks, rank_in_group)
         self.diagnostics = False
-        self.spatial_group_size = spatial_group_size
-        self.peer_rank = rank % spatial_group_size
-        self.left_neighbor = (self.peer_rank + self.spatial_group_size - 1) % self.spatial_group_size
-        self.right_neighbor = (self.peer_rank + 1) % self.spatial_group_size
-        self.peer_pool = peer_pool
-        self.signals = peer_pool.allocate_peer_tensors([2,4], torch.int32, False, False)
-        self.signals[self.peer_rank].zero_()
         self.explicit_nhwc = explicit_nhwc
         self.numSM = numSM
+        self.peer_pool = peer_pool
+        self.signals = peer_pool.allocate_peer_tensors([2,4], torch.int32, False, False)
+        self.signals[self.rank_in_group].zero_()
     def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None):
         inplace = False if left_input_halo is None and right_input_halo is None else True
@@ -102,9 +107,9 @@ class HaloExchangerPeer(HaloExchanger):
         right_tx = self.peer_pool.allocate_peer_tensors(list(right_output_halo.shape), right_output_halo.dtype, channels_last, True)
         pm.push_pull_halos_1d(
                 self.diagnostics, self.explicit_nhwc, self.numSM,
-                left_output_halo, left_tx[self.peer_rank], right_tx[self.left_neighbor], left_input_halo,
-                right_output_halo, right_tx[self.peer_rank], left_tx[self.right_neighbor], right_input_halo,
-                self.signals[self.left_neighbor], self.signals[self.right_neighbor], self.signals[self.peer_rank]
+                left_output_halo, left_tx[self.rank_in_group], right_tx[self.wrap_around_left_rank_in_group], left_input_halo,
+                right_output_halo, right_tx[self.rank_in_group], left_tx[self.wrap_around_right_rank_in_group], right_input_halo,
+                self.signals[self.wrap_around_left_rank_in_group], self.signals[self.wrap_around_right_rank_in_group], self.signals[self.rank_in_group]
                 )
         # TODO: Add to push_pull_halos_1d kernel
         if self.left_zero:
...
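The exchanger call signature itself is unchanged; below is a hedged, single-process usage sketch of the in-place form, using the NoComm variant (which, per the note at the top of the file, exists only for performance testing). Tensor shapes follow the explicit-NHWC (N, Hh, W, C) halo layout used by the AllGather path; the standalone setup is illustrative only.

# Hedged usage sketch (single process; CUDA required for the exchanger's streams).
# Assumes HaloExchangerNoComm is imported from the module shown above.
import torch

halex = HaloExchangerNoComm(ranks=[0], rank_in_group=0)
left_out  = torch.randn(1, 1, 8, 64, device="cuda", dtype=torch.float16)   # (N, Hh, W, C)
right_out = torch.randn(1, 1, 8, 64, device="cuda", dtype=torch.float16)

# In-place form: the caller supplies the buffers that receive the neighbour halos.
left_in  = torch.empty_like(right_out)
right_in = torch.empty_like(left_out)
halex.left_right_halo_exchange(left_out, right_out, left_in, right_in)
# NoComm simply mirrors the outgoing halos back (right_in receives left_out).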
@@ -19,8 +19,6 @@
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.def("get_unique_nccl_id", &apex::contrib::nccl_p2p::get_unique_nccl_id, "get_unique_nccl_id");
     m.def("init_nccl_comm", &apex::contrib::nccl_p2p::init_nccl_comm, "init_nccl_comm");
-    m.def("nccl_send", &apex::contrib::nccl_p2p::nccl_send, "nccl_send");
-    m.def("nccl_recv", &apex::contrib::nccl_p2p::nccl_recv, "nccl_recv");
     m.def("left_right_halo_exchange_inplace", &apex::contrib::nccl_p2p::left_right_halo_exchange_inplace, "left_right_halo_exchange_inplace");
     m.def("left_right_halo_exchange", &apex::contrib::nccl_p2p::left_right_halo_exchange, "left_right_halo_exchange");
     m.def("add_delay", &apex::contrib::nccl_p2p::add_delay, "add_delay");
...
@@ -80,75 +80,82 @@ class NcclCommWrapper
         ncclCommDestroy(comm);
     }

-    void send(at::Tensor input, int destination)
-    {
-        ncclDataType_t ncclType = get_nccl_type(input);
-        AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, input.scalar_type(), "nccl_send", [&]() {
-            size_t count = sizeof(scalar_t) * torch::numel(input);
-            auto input_ptr = input.data_ptr<scalar_t>();
-            ncclSend(input_ptr, count, ncclType, destination, comm, at::cuda::getCurrentCUDAStream());
-        });
-    }
-
-    void recv(at::Tensor input, int sender)
-    {
-        ncclDataType_t ncclType = get_nccl_type(input);
-        AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, input.scalar_type(), "nccl_send", [&]() {
-            size_t count = sizeof(scalar_t) * torch::numel(input);
-            auto input_ptr = input.data_ptr<scalar_t>();
-            ncclRecv(input_ptr, count, ncclType, sender, comm, at::cuda::getCurrentCUDAStream());
-        });
-    }
-
-    void left_right_halo_exchange_inplace(bool left_zero, bool right_zero, at::Tensor left_output_halo, at::Tensor right_output_halo, at::Tensor left_input_halo, at::Tensor right_input_halo, int group_size)
+    void left_right_halo_exchange_inplace(int left_rank, int right_rank, at::Tensor left_output_halo, at::Tensor right_output_halo, at::Tensor left_input_halo, at::Tensor right_input_halo)
     {
         auto stream = at::cuda::getCurrentCUDAStream();
         ncclGroupStart();
         ncclDataType_t ncclType = get_nccl_type(left_output_halo);
-        // we use wrap-around ranks, so left_input_halo of rank 0 has right_output_halo of rank world_size-1 after exchange etc.
-        // this is technically speaking wasteful, but there is no benefit in having the edge ranks do less work than internal ranks.
-        int group_rank = rank % group_size;
-        int group_index = rank / group_size;
-        int prev_rank = (group_rank + group_size - 1) % group_size;
-        int next_rank = (group_rank + 1) % group_size;
-        prev_rank = prev_rank + group_index * group_size;
-        next_rank = next_rank + group_index * group_size;
+        bool left_zero = (left_rank < 0);
+        bool right_zero = (right_rank < 0);
         size_t left_n = torch::numel(left_output_halo);
         size_t right_n = torch::numel(right_output_halo);
-        if (group_rank > 0) {
+        assert(left_n > 0 && left_n == right_n);
+        if (left_zero) {
+            left_input_halo.zero_();
+        } else {
            AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, left_output_halo.scalar_type(), "left_halo_exch", [&]() {
                // send left (to my_rank - 1)
-                ncclSend(left_output_halo.data_ptr<scalar_t>(), left_n, ncclType, prev_rank, comm, stream);
+                ncclSend(left_output_halo.data_ptr<scalar_t>(), left_n, ncclType, left_rank, comm, stream);
                // receive left (from my_rank - 1)
-                ncclRecv(left_input_halo.data_ptr<scalar_t>(), right_n, ncclType, prev_rank, comm, stream);
+                ncclRecv(left_input_halo.data_ptr<scalar_t>(), right_n, ncclType, left_rank, comm, stream);
            });
        }
-        if (group_rank < group_size-1) {
+        if (right_zero) {
+            right_input_halo.zero_();
+        } else {
            AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, right_output_halo.scalar_type(), "right_halo_exch", [&]() {
                // send right (to my_rank + 1)
-                ncclSend(right_output_halo.data_ptr<scalar_t>(), right_n, ncclType, next_rank, comm, stream);
+                ncclSend(right_output_halo.data_ptr<scalar_t>(), right_n, ncclType, right_rank, comm, stream);
                // receive right (from my_rank + 1)
-                ncclRecv(right_input_halo.data_ptr<scalar_t>(), left_n, ncclType, next_rank, comm, stream);
+                ncclRecv(right_input_halo.data_ptr<scalar_t>(), left_n, ncclType, right_rank, comm, stream);
            });
        }
        ncclGroupEnd();
-        if (left_zero) left_input_halo.zero_();
-        if (right_zero) right_input_halo.zero_();
     }
-    std::vector<at::Tensor> left_right_halo_exchange(bool left_zero, bool right_zero, at::Tensor left_output_halo, at::Tensor right_output_halo, int group_size)
+    std::vector<at::Tensor> left_right_halo_exchange(int left_rank, int right_rank, at::Tensor left_output_halo, at::Tensor right_output_halo)
     {
         // after halo exchange:
         // left_output_halo of rank+1 ends up in right_input_halo of rank
         // right_output_halo of rank-1 ends up in left_input_halo of rank
         auto right_input_halo = torch::empty_like(left_output_halo);
         auto left_input_halo = torch::empty_like(right_output_halo);
-        left_right_halo_exchange_inplace(left_zero, right_zero, left_output_halo, right_output_halo, left_input_halo, right_input_halo, group_size);
+        left_right_halo_exchange_inplace(left_rank, right_rank, left_output_halo, right_output_halo, left_input_halo, right_input_halo);
         return {left_input_halo, right_input_halo};
     }
 };

-std::vector<NcclCommWrapper> nccl_comms;
+class ManagedObjects
+{
+public:
+    ManagedObjects()
+    {
+    }
+    ~ManagedObjects()
+    {
+        for (auto it = _nccl_comms.begin(); it != _nccl_comms.end(); ++it)
+        {
+            delete *it;
+        }
+    }
+    int add_comm(NcclCommWrapper* comm)
+    {
+        int handle = _nccl_comms.size();
+        _nccl_comms.push_back(comm);
+        return handle;
+    }
+    NcclCommWrapper& get_comm(int handle)
+    {
+        assert(handle >= 0 && handle < _nccl_comms.size());
+        return *_nccl_comms[handle];
+    }
+
+private:
+    std::vector<NcclCommWrapper*> _nccl_comms;
+};
+class ManagedObjects mo;

 } // end anonymous namespace
@@ -158,7 +165,7 @@ at::Tensor get_unique_nccl_id(int n)
 {
     ncclUniqueId id;
     ncclGetUniqueId(&id);
-    auto id_tensor = torch::empty({n*(int)sizeof(ncclUniqueId)}, torch::dtype(torch::kUInt8).device(torch::kCPU).requires_grad(false));
+    auto id_tensor = torch::empty({n,(int)sizeof(ncclUniqueId)}, torch::dtype(torch::kUInt8).device(torch::kCPU).requires_grad(false));
     auto id_ptr = id_tensor.data_ptr<uint8_t>();
     size_t offset = 0;
     for (int i = 0; i < n; ++i)
@@ -177,38 +184,21 @@ int init_nccl_comm(at::Tensor unique_nccl_id, int my_rank, int num_ranks)
     auto unique_nccl_id_ptr = unique_nccl_id.data_ptr<uint8_t>();
     memcpy(&id, unique_nccl_id_ptr, sizeof(ncclUniqueId));
     NcclCommWrapper* comm = new NcclCommWrapper(id, my_rank, num_ranks);
-    int handle = nccl_comms.size();
-    nccl_comms.push_back(*comm);
+    int handle = mo.add_comm(comm);
     comm = 0L;
     return handle;
 }

-void nccl_send(int handle, at::Tensor input, int destination)
-{
-    assert(handle >= 0 && handle < nccl_comms.size());
-    class NcclCommWrapper communicator = nccl_comms[handle];
-    communicator.send(input, destination);
-}
-
-void nccl_recv(int handle, at::Tensor input, int sender)
-{
-    assert(handle >= 0 && handle < nccl_comms.size());
-    class NcclCommWrapper communicator = nccl_comms[handle];
-    communicator.recv(input, sender);
-}
-
-void left_right_halo_exchange_inplace(int handle, bool left_zero, bool right_zero, at::Tensor left_output_halo, at::Tensor right_output_halo, at::Tensor left_input_halo, at::Tensor right_input_halo, int group_size)
+void left_right_halo_exchange_inplace(int handle, int left_rank, int right_rank, at::Tensor left_output_halo, at::Tensor right_output_halo, at::Tensor left_input_halo, at::Tensor right_input_halo)
 {
-    assert(handle >= 0 && handle < nccl_comms.size());
-    class NcclCommWrapper& communicator = nccl_comms[handle];
-    return communicator.left_right_halo_exchange_inplace(left_zero, right_zero, left_output_halo, right_output_halo, left_input_halo, right_input_halo, group_size);
+    class NcclCommWrapper& communicator = mo.get_comm(handle);
+    return communicator.left_right_halo_exchange_inplace(left_rank, right_rank, left_output_halo, right_output_halo, left_input_halo, right_input_halo);
 }

-std::vector<at::Tensor> left_right_halo_exchange(int handle, bool left_zero, bool right_zero, at::Tensor left_output_halo, at::Tensor right_output_halo, int group_size)
+std::vector<at::Tensor> left_right_halo_exchange(int handle, int left_rank, int right_rank, at::Tensor left_output_halo, at::Tensor right_output_halo)
 {
-    assert(handle >= 0 && handle < nccl_comms.size());
-    class NcclCommWrapper& communicator = nccl_comms[handle];
-    return communicator.left_right_halo_exchange(left_zero, right_zero, left_output_halo, right_output_halo, group_size);
+    class NcclCommWrapper& communicator = mo.get_comm(handle);
+    return communicator.left_right_halo_exchange(left_rank, right_rank, left_output_halo, right_output_halo);
 }

 void add_delay(int delay)
...
@@ -26,33 +26,20 @@ int init_nccl_comm(
     int my_rank,
     int num_ranks
     );

-void nccl_send(
-    int handle,
-    at::Tensor input,
-    int destination
-    );
-
-void nccl_recv(
-    int handle,
-    at::Tensor input,
-    int sender
-    );
-
 void left_right_halo_exchange_inplace(
     int handle,
-    bool left_zero,
-    bool right_zero,
+    int left_rank,
+    int right_rank,
     at::Tensor left_output_halo,
     at::Tensor right_output_halo,
     at::Tensor left_input_halo,
-    at::Tensor right_input_halo,
-    int group_size);
+    at::Tensor right_input_halo);

 std::vector<at::Tensor> left_right_halo_exchange(
     int handle,
-    bool left_zero,
-    bool right_zero,
+    int left_rank,
+    int right_rank,
     at::Tensor left_output_halo,
-    at::Tensor right_output_halo,
-    int group_size
-    );
+    at::Tensor right_output_halo);

 void add_delay(int delay);
 }}}
 #endif
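One consequence of this interface change: the bool left_zero/right_zero flags and the group_size argument are gone, and edge handling is encoded in the rank arguments instead, with a negative rank meaning "no neighbour, zero that input halo" (see the left_zero = (left_rank < 0) logic above). A small illustration of the bookkeeping, matching the HaloExchanger base class:

# Hedged illustration of the neighbour/sentinel bookkeeping used by the new
# nccl_p2p interface: -1 marks an edge, which the extension zeroes instead of receiving.
def left_right_ranks(ranks, rank_in_group):
    group_size = len(ranks)
    left_rank = ranks[rank_in_group - 1] if rank_in_group > 0 else -1
    right_rank = ranks[rank_in_group + 1] if rank_in_group < group_size - 1 else -1
    return left_rank, right_rank

assert left_right_ranks([4, 5, 6, 7], 0) == (-1, 5)   # leftmost rank: left halo is zeroed
assert left_right_ranks([4, 5, 6, 7], 2) == (5, 7)    # interior rank: both neighbours are real
assert left_right_ranks([4, 5, 6, 7], 3) == (6, -1)   # rightmost rank: right halo is zeroed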
@@ -4,23 +4,40 @@ import peer_memory_cuda as pm
 class PeerMemoryPool(object):
-    def __init__(self, rank, world_size, peer_group_size, static_size, dynamic_size):
-        self.peer_group = rank // peer_group_size
-        self.peer_rank = rank % peer_group_size
-        self.peer_group_size = peer_group_size
+    def __init__(self, static_size, dynamic_size, peer_ranks=None):
+        rank = torch.distributed.get_rank()
+        world_size = torch.distributed.get_world_size()
+        ngpus = min(torch.cuda.device_count(), world_size)
+        peer_group_size = ngpus
+        peer_group = rank // ngpus
+        peer_rank_base = peer_group * ngpus
+        peer_rank = rank - peer_rank_base
+        if peer_ranks is None:
+            peer_ranks = [i+peer_rank_base for i in range(peer_group_size)]
+        peer_rank_start = peer_rank_base
+        peer_rank_end = peer_rank_start + peer_group_size - 1
+        for pr in peer_ranks:
+            assert(pr >= peer_rank_start and pr <= peer_rank_end), "%d :: peer_rank %d not on same node (ranks=[%d,%d])" % (rank, pr, peer_rank_start, peer_rank_end)
         self.alignment = 256
         self.static_size = ((static_size + self.alignment - 1) // self.alignment) * self.alignment
         self.dynamic_size = ((dynamic_size + self.alignment - 1) // self.alignment) * self.alignment
         # allocate giant pool of device memory
         self.raw = pm.allocate_raw(self.static_size+self.dynamic_size)
         # exchange peer pointers with nccl
         raw_ipc = pm.get_raw_ipc_address(self.raw).cuda()
         peer_raw_ipcs = [torch.empty_like(raw_ipc) for _ in range(world_size)]
         torch.distributed.all_gather(peer_raw_ipcs, raw_ipc)
         peer_raw_ipcs = torch.stack(peer_raw_ipcs).cpu()
-        self.peer_raw = pm.get_raw_peers(peer_raw_ipcs, self.peer_rank, self.raw)
+        # extract IPC pointers for ranks on same node
+        peer_raw = pm.get_raw_peers(peer_raw_ipcs[peer_rank_base:peer_rank_base+ngpus], peer_rank, self.raw)
+        self.peer_raw = [peer_raw[peer_rank-peer_rank_base] for peer_rank in peer_ranks]
         self.static_offset = 0
         self.dynamic_offset = 0
+        self.peer_ranks = peer_ranks
     def __del__(self):
         pm.free_raw(self.raw)
...
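The reworked PeerMemoryPool constructor infers node-local placement from torch.cuda.device_count() and rejects peer ranks that fall outside the caller's node. A hedged sketch of that window check, assuming 8 GPUs per node for the example numbers:

# Hedged sketch of the node-local window PeerMemoryPool now enforces
# (see the constructor above). The 8-GPU node is an assumption for illustration.
def node_local_window(rank, ngpus=8):
    peer_group = rank // ngpus
    peer_rank_base = peer_group * ngpus
    return peer_rank_base, peer_rank_base + ngpus - 1

start, end = node_local_window(rank=11, ngpus=8)
assert (start, end) == (8, 15)
# peer_ranks=[8, 9, 10, 11] is accepted for rank 11; a list containing 16 would
# trip the "peer_rank %d not on same node" assertion in the constructor.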