Commit b41c68b3 authored by Thor Johnsen

Optional inplace halo exchange

parent 778808eb
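This commit makes every left_right_halo_exchange method accept optional pre-allocated left_input_halo / right_input_halo tensors: when they are passed, the received halos are written into them in place and nothing is returned; when they are omitted, the method keeps the old behaviour of allocating and returning fresh tensors. The snippet below is a minimal, self-contained sketch of that calling convention; DummyExchanger and the tensor shapes are illustrative stand-ins that mirror the wrap-around HaloExchangerNoComm behaviour, not code from this commit.

import torch

# Hypothetical stand-in mimicking the optional in-place convention
# introduced here (single-rank wrap-around exchange).
class DummyExchanger:
    def left_right_halo_exchange(self, left_output_halo, right_output_halo,
                                 left_input_halo=None, right_input_halo=None):
        if left_input_halo is None:
            # old behaviour: allocate and return the received halos
            return right_output_halo.clone(), left_output_halo.clone()
        # new optional behaviour: write into caller-provided buffers
        left_input_halo.copy_(right_output_halo)
        right_input_halo.copy_(left_output_halo)

ex = DummyExchanger()
left_out = torch.randn(1, 2, 8, 4)   # N, half_halo, W, C (explicit NHWC)
right_out = torch.randn(1, 2, 8, 4)

# 1) Return-style call (unchanged API).
left_in, right_in = ex.left_right_halo_exchange(left_out, right_out)

# 2) In-place call: results land in pre-allocated tensors,
#    e.g. slices of an already padded activation.
left_buf = torch.empty_like(right_out)
right_buf = torch.empty_like(left_out)
ex.left_right_halo_exchange(left_out, right_out, left_buf, right_buf)
assert torch.equal(left_buf, left_in) and torch.equal(right_buf, right_in)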
@@ -16,8 +16,12 @@ class HaloExchangerNoComm(HaloExchanger):
def __init__(self, world_size, spatial_group_size, rank, comm):
super(HaloExchangerNoComm, self).__init__()
def left_right_halo_exchange(self, left_output_halo, right_output_halo):
return right_output_halo, left_output_halo
def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None):
if left_input_halo is None:
return right_output_halo, left_output_halo
else:
left_input_halo.copy_(right_output_halo)
right_input_halo.copy_(left_output_halo)
class HaloExchangerAllGather(HaloExchanger):
def __init__(self, world_size, spatial_group_size, rank, comm):
@@ -26,7 +30,7 @@ class HaloExchangerAllGather(HaloExchanger):
self.local_rank = rank % spatial_group_size
self.comm = comm
def left_right_halo_exchange(self, left_output_halo, right_output_halo):
def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None):
N,Hh,W,C = list(left_output_halo.shape)
send_halos = torch.empty((N,2*Hh,W,C),dtype=left_output_halo.dtype,device=left_output_halo.device)
send_halos[:,:Hh,:,:].copy_(left_output_halo)
@@ -34,9 +38,13 @@ class HaloExchangerAllGather(HaloExchanger):
all_halos = torch.empty((N,2*Hh*self.spatial_group_size,W,C),dtype=left_output_halo.dtype,device=left_output_halo.device)
all_halos = [all_halos[:,i*2*Hh:(i+1)*2*Hh,:,:] for i in range(self.spatial_group_size)]
torch.distributed.all_gather(all_halos,send_halos,group=self.comm,no_copy=True)
left_input_halo = all_halos[(self.spatial_group_size+self.local_rank-1)%self.spatial_group_size][:,Hh:,:,:]
right_input_halo = all_halos[(self.local_rank+1)%self.spatial_group_size][:,:Hh,:,:]
return left_input_halo, right_input_halo
ag_left_input_halo = all_halos[(self.spatial_group_size+self.local_rank-1)%self.spatial_group_size][:,Hh:,:,:]
ag_right_input_halo = all_halos[(self.local_rank+1)%self.spatial_group_size][:,:Hh,:,:]
if left_input_halo is None:
return ag_left_input_halo, ag_right_input_halo
else:
left_input_halo.copy_(ag_left_input_halo)
right_input_halo.copy_(ag_right_input_halo)
class HaloExchangerSendRecv(HaloExchanger):
def __init__(self, world_size, spatial_group_size, rank, comm):
@@ -48,6 +56,90 @@ class HaloExchangerSendRecv(HaloExchanger):
nccl_id = nccl_id.cpu()
self.handle = inc.init_nccl_comm(nccl_id, rank, world_size)
def left_right_halo_exchange(self, left_output_halo, right_output_halo):
left_input_halo, right_input_halo = inc.left_right_halo_exchange(self.handle, left_output_halo, right_output_halo, self.spatial_group_size)
return left_input_halo, right_input_halo
def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None):
if left_input_halo is None:
left_input_halo, right_input_halo = inc.left_right_halo_exchange(self.handle, left_output_halo, right_output_halo, self.spatial_group_size)
return left_input_halo, right_input_halo
else:
inc.left_right_halo_exchange_inplace(self.handle, left_output_halo, right_output_halo, left_input_halo, right_input_halo, self.spatial_group_size)
class HaloExchangerPeer(HaloExchanger):
def __init__(self, world_size, spatial_group_size, rank, comm, peer_pool, explicit_nhwc, numSM=1):
super(HaloExchangerPeer, self).__init__()
self.diagnostics = False
self.spatial_group_size = spatial_group_size
self.peer_rank = rank % spatial_group_size
self.left_neighbor = (self.peer_rank + self.spatial_group_size - 1) % self.spatial_group_size
self.right_neighbor = (self.peer_rank + 1) % self.spatial_group_size
self.peer_pool = peer_pool
self.signals = peer_pool.allocate_peer_tensors([2,4], torch.int32, False, False)
self.signals[self.peer_rank].zero_()
self.explicit_nhwc = explicit_nhwc
self.numSM = numSM
def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None):
channels_last = left_output_halo.is_contiguous(memory_format=torch.channels_last) and not self.explicit_nhwc
left_tx = self.peer_pool.allocate_peer_tensors(list(left_output_halo.shape), left_output_halo.dtype, channels_last, True)
right_tx = self.peer_pool.allocate_peer_tensors(list(right_output_halo.shape), right_output_halo.dtype, channels_last, True)
pm.push_pull_halos_1d(
self.diagnostics, self.explicit_nhwc, self.numSM,
left_output_halo, left_tx[self.peer_rank], right_tx[self.left_neighbor], left_input_halo,
right_output_halo, right_tx[self.peer_rank], left_tx[self.right_neighbor], right_input_halo,
self.signals[self.left_neighbor], self.signals[self.right_neighbor], self.signals[self.peer_rank]
)
# Class that combines input volume with halos from neighbors (1d).
class HaloPadder:
def __init__(self, halo_ex):
self.halo_ex = halo_ex
self.stream1 = torch.cuda.Stream()
self.stream2 = torch.cuda.Stream()
def __call__(self, y, half_halo, explicit_nhwc, H_split):
channels_last = not explicit_nhwc and y.is_contiguous(memory_format=torch.channels_last)
if explicit_nhwc:
N,H,W,C = list(y.shape)
if H_split:
padded_shape = [N,H+2*half_halo,W,C]
ypad = torch.empty(padded_shape, dtype=y.dtype, device=y.device, memory_format=torch.contiguous_format)
yleft = ypad[:,:half_halo,:,:]
ymid = ypad[:,half_halo:H+half_halo,:,:]
yright = ypad[:,H+half_halo:H+2*half_halo,:,:]
oleft = y[:,:half_halo,:,:]
oright = y[:,H-half_halo:,:,:]
else:
padded_shape = [N,H,W+2*half_halo,C]
ypad = torch.empty(padded_shape, dtype=y.dtype, device=y.device, memory_format=torch.contiguous_format)
yleft = ypad[:,:,:half_halo,:]
ymid = ypad[:,:,half_halo:W+half_halo,:]
yright = ypad[:,:,W+half_halo:W+2*half_halo,:]
oleft = y[:,:,:half_halo,:]
oright = y[:,:,W-half_halo:,:]
else:
N,C,H,W = list(y.shape)
if H_split:
padded_shape = [N,C,H+2*half_halo,W]
ypad = torch.empty(padded_shape, dtype=y.dtype, device=y.device, memory_format=torch.channels_last)
yleft = ypad[:,:,:half_halo,:]
ymid = ypad[:,:,half_halo:H+half_halo,:]
yright = ypad[:,:,H+half_halo:H+2*half_halo,:]
oleft = y[:,:,:half_halo,:]
oright = y[:,:,H-half_halo:,:]
else:
padded_shape = [N,C,H,W+2*half_halo]
ypad = torch.empty(padded_shape, dtype=y.dtype, device=y.device, memory_format=torch.channels_last)
yleft = ypad[:,:,:,:half_halo]
ymid = ypad[:,:,:,half_halo:W+half_halo]
yright = ypad[:,:,:,W+half_halo:W+2*half_halo]
oleft = y[:,:,:,:half_halo]
oright = y[:,:,:,W-half_halo:]
with torch.cuda.stream(self.stream1):
self.halo_ex.left_right_halo_exchange(oleft, oright, yleft, yright)
with torch.cuda.stream(self.stream2):
ymid.copy_(y)
return ypad
def wait(self):
current_stream = torch.cuda.current_stream()
current_stream.wait_stream(self.stream1)
current_stream.wait_stream(self.stream2)
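For reference, the tensor HaloPadder returns is simply the input volume with half_halo slices of neighbor data attached to both ends of the split dimension, filled by the in-place halo exchange on one stream while the bulk of the input is copied on a second stream. Below is a small CPU-only sketch of the equivalent result, under the assumption of explicit NHWC layout, an H split, and the wrap-around (no-comm) exchanger; halo_pad_wraparound is a hypothetical helper used only for illustration.

import torch

def halo_pad_wraparound(y, half_halo):
    # Illustrative equivalent of HaloPadder with the wrap-around exchanger:
    # explicit NHWC input, split along H.
    N, H, W, C = y.shape
    left_halo = y[:, H - half_halo:, :, :]    # what the "left neighbor" would send
    right_halo = y[:, :half_halo, :, :]       # what the "right neighbor" would send
    # ypad = [left_input_halo | y | right_input_halo] along H
    return torch.cat([left_halo, y, right_halo], dim=1)

y = torch.arange(1 * 6 * 2 * 1, dtype=torch.float32).reshape(1, 6, 2, 1)
ypad = halo_pad_wraparound(y, half_halo=1)
print(ypad.shape)  # torch.Size([1, 8, 2, 1])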
@@ -21,6 +21,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("init_nccl_comm", &apex::contrib::nccl_p2p::init_nccl_comm, "init_nccl_comm");
m.def("nccl_send", &apex::contrib::nccl_p2p::nccl_send, "nccl_send");
m.def("nccl_recv", &apex::contrib::nccl_p2p::nccl_recv, "nccl_recv");
m.def("left_right_halo_exchange_inplace", &apex::contrib::nccl_p2p::left_right_halo_exchange_inplace, "left_right_halo_exchange_inplace");
m.def("left_right_halo_exchange", &apex::contrib::nccl_p2p::left_right_halo_exchange, "left_right_halo_exchange");
m.def("add_delay", &apex::contrib::nccl_p2p::add_delay, "add_delay");
}
@@ -100,14 +100,9 @@ class NcclCommWrapper
});
}
std::vector<at::Tensor> left_right_halo_exchange(at::Tensor left_output_halo, at::Tensor right_output_halo, int group_size)
{
// after halo exchange:
// left_output_halo of rank+1 ends up in right_input_halo of rank
// right_output_halo of rank-1 ends up in left_input_halo of rank
void left_right_halo_exchange_inplace(at::Tensor left_output_halo, at::Tensor right_output_halo, at::Tensor left_input_halo, at::Tensor right_input_halo, int group_size)
{
auto stream = at::cuda::getCurrentCUDAStream();
auto right_input_halo = torch::empty_like(left_output_halo);
auto left_input_halo = torch::empty_like(right_output_halo);
ncclGroupStart();
ncclDataType_t ncclType = get_nccl_type(left_output_halo);
// we use wrap-around ranks, so left_input_halo of rank 0 has right_output_halo of rank world_size-1 after exchange etc.
@@ -137,7 +132,17 @@ class NcclCommWrapper
});
}
ncclGroupEnd();
return {left_input_halo, right_input_halo};
}
std::vector<at::Tensor> left_right_halo_exchange(at::Tensor left_output_halo, at::Tensor right_output_halo, int group_size)
{
// after halo exchange:
// left_output_halo of rank+1 ends up in right_input_halo of rank
// right_output_halo of rank-1 ends up in left_input_halo of rank
auto right_input_halo = torch::empty_like(left_output_halo);
auto left_input_halo = torch::empty_like(right_output_halo);
left_right_halo_exchange_inplace(left_output_halo, right_output_halo, left_input_halo, right_input_halo, group_size);
return {left_input_halo, right_input_halo};
}
};
@@ -190,6 +195,13 @@ void nccl_recv(int handle, at::Tensor input, int sender)
communicator.recv(input, sender);
}
void left_right_halo_exchange_inplace(int handle, at::Tensor left_output_halo, at::Tensor right_output_halo, at::Tensor left_input_halo, at::Tensor right_input_halo, int group_size)
{
assert(handle >= 0 && handle < nccl_comms.size());
class NcclCommWrapper& communicator = nccl_comms[handle];
return communicator.left_right_halo_exchange_inplace(left_output_halo, right_output_halo, left_input_halo, right_input_halo, group_size);
}
std::vector<at::Tensor> left_right_halo_exchange(int handle, at::Tensor left_output_halo, at::Tensor right_output_halo, int group_size)
{
assert(handle >= 0 && handle < nccl_comms.size());
......
@@ -36,6 +36,12 @@ void nccl_recv(
at::Tensor input,
int sender
);
void left_right_halo_exchange_inplace(
int handle,
at::Tensor left_output_halo,
at::Tensor right_output_halo,
at::Tensor left_input_halo,
at::Tensor right_input_halo,
int group_size);
std::vector<at::Tensor> left_right_halo_exchange(
int handle,
at::Tensor left_output_halo,
......