Commit 3ade5b26 authored by Thor Johnsen

Add bottleneck block

parent b48898fb
 import torch
 import torch.distributed as dist
 from torch import nn
-import fast_bottleneck
+from maskrcnn_benchmark.utils.registry import Registry
+import maskrcnn_benchmark.SpatialBottleneck as fast_bottleneck
+import nccl_p2p as inc
 
 def kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'):
     weight_tensor_nchw = tensor
     nn.init.kaiming_uniform_(weight_tensor_nchw, a=a, mode=mode, nonlinearity=nonlinearity)
 
-class FrozenBatchNorm2d(torch.nn.Module):
+class FrozenBatchNorm2d(torch.jit.ScriptModule):
     """
     BatchNorm2d where the batch statistics and the affine parameters are fixed
     """
@@ -18,7 +20,9 @@ class FrozenBatchNorm2d(torch.nn.Module):
         self.register_buffer("running_mean", torch.zeros(n))
         self.register_buffer("running_var", torch.ones(n))
 
-    def get_scale_bias(self, nhwc=False):
+    @torch.jit.script_method
+    def get_scale_bias(self, nhwc):
+        # type: (bool) -> List[torch.Tensor]
         scale = self.weight * self.running_var.rsqrt()
         bias = self.bias - self.running_mean * scale
         if nhwc:
@@ -29,11 +33,11 @@ class FrozenBatchNorm2d(torch.nn.Module):
             bias = bias.reshape(1, -1, 1, 1)
         return scale, bias
 
+    @torch.jit.script_method
     def forward(self, x):
-        scale, bias = self.get_scale_bias()
+        scale, bias = self.get_scale_bias(False)
         return x * scale + bias
 
 @torch.jit.script
 def drelu_dscale1(grad_o, output, scale1):
     relu_mask = (output>0)
@@ -217,7 +221,11 @@ class Bottleneck(torch.nn.Module):
 
 class SpatialBottleneckFunction(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, spatial_group_size, local_rank, comm, stream1, nhwc, stride_1x1, scale, bias, x, *conv):
+    def forward(ctx, spatial_group_size, spatial_group_rank, spatial_communicator, spatial_halo_exchanger, spatial_stream, nhwc, stride_1x1, scale, bias, x, *conv):
+        if spatial_group_size > 1:
+            stream1 = spatial_halo_exchanger.stream1
+            stream2 = spatial_halo_exchanger.stream2
         # TODO: clean up order of tensors
         args = [x, *conv[0:3], *scale[0:3], *bias[0:3]]
         ctx.downsample = len(conv) > 3
@@ -232,38 +240,38 @@ class SpatialBottleneckFunction(torch.autograd.Function):
         outputs = fast_bottleneck.forward_init(nhwc, stride_1x1, args)
         fast_bottleneck.forward_out1(nhwc, stride_1x1, args, outputs)
-        fast_bottleneck.forward_out2(nhwc, stride_1x1, args, outputs)
 
         # do halo exchange for outputs[0] (out1)
-        # compute halo cells for outputs[1]
         if spatial_group_size > 1:
             out1 = outputs[0]
             N,Hs,W,C = list(out1.shape)
             stream1.wait_stream(torch.cuda.current_stream())
             with torch.cuda.stream(stream1):
-                # copy halos to send buffer
-                send_halos = torch.empty((N,2,W,C),dtype=out1.dtype,device=out1.device)
-                send_halos[:,:1,:,:].copy_(out1[:,:1,:,:])
-                send_halos[:,1:,:,:].copy_(out1[:,Hs-1:,:,:])
-                all_halos = torch.empty((N,2*spatial_group_size,W,C),dtype=out1.dtype,device=out1.device)
-                all_halos = [all_halos[:,i*2:(i+1)*2,:,:] for i in range(spatial_group_size)]
-                dist.all_gather(all_halos,send_halos,group=comm)
-                fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
-                top_out1_halo = all_halos[(spatial_group_size+local_rank-1)%spatial_group_size][:,1:,:,:]
-                if local_rank > 0:
-                    fat_halo[:,:1,:,:].copy_(top_out1_halo)
-                    fat_halo[:,1:3,:,:].copy_(out1[:,:2,:,:])
-                    top_out2 = fast_bottleneck.forward_out2_halo(nhwc, fat_halo, args)
-                btm_out1_halo = all_halos[(local_rank+1)%spatial_group_size][:,:1,:,:]
-                if local_rank < spatial_group_size-1:
-                    fat_halo[:,0:2,:,:].copy_(out1[:,Hs-2:,:,:])
-                    fat_halo[:,2:,:,:].copy_(btm_out1_halo)
-                    btm_out2 = fast_bottleneck.forward_out2_halo(nhwc, fat_halo, args)
-            torch.cuda.current_stream().wait_stream(stream1)
-            out2 = outputs[1]
-            if local_rank > 0:
-                out2[:,:1,:,:].copy_(top_out2)
-            if local_rank < spatial_group_size-1:
-                out2[:,Hs-1:,:,:].copy_(btm_out2)
+                top_out1_halo, btm_out1_halo = spatial_halo_exchanger.left_right_halo_exchange(out1[:,:1,:,:], out1[:,Hs-1:,:,:])
+            if spatial_group_rank < spatial_group_size-1:
+                stream2.wait_stream(stream1)
+                with torch.cuda.stream(stream2):
+                    btm_fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
+                    btm_fat_halo[:,0:2,:,:].copy_(out1[:,Hs-2:,:,:])
+                    btm_fat_halo[:,2:,:,:].copy_(btm_out1_halo)
+                    btm_out2 = fast_bottleneck.forward_out2_halo(nhwc, btm_fat_halo, args)
+            if spatial_group_rank > 0:
+                with torch.cuda.stream(stream1):
+                    top_fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
+                    top_fat_halo[:,:1,:,:].copy_(top_out1_halo)
+                    top_fat_halo[:,1:3,:,:].copy_(out1[:,:2,:,:])
+                    top_out2 = fast_bottleneck.forward_out2_halo(nhwc, top_fat_halo, args)
+            inc.add_delay(10)
+
+        fast_bottleneck.forward_out2(nhwc, stride_1x1, args, outputs)
+
+        # compute halo cells for outputs[1] (out2)
+        if spatial_group_size > 1:
+            out2 = outputs[1]
+            if spatial_group_rank > 0:
+                torch.cuda.current_stream().wait_stream(stream1)
+                out2[:,:1,:,:].copy_(top_out2)
+            if spatial_group_rank < spatial_group_size-1:
+                torch.cuda.current_stream().wait_stream(stream2)
+                out2[:,Hs-1:,:,:].copy_(btm_out2)
 
         fast_bottleneck.forward_rest(nhwc, stride_1x1, args, outputs)
@@ -276,9 +284,11 @@ class SpatialBottleneckFunction(torch.autograd.Function):
         ctx.nhwc = nhwc
         ctx.stride_1x1 = stride_1x1
         ctx.spatial_group_size = spatial_group_size
-        ctx.local_rank = local_rank
-        ctx.comm = comm
-        ctx.stream1 = stream1
+        if spatial_group_size > 1:
+            ctx.spatial_group_rank = spatial_group_rank
+            ctx.spatial_halo_exchanger = spatial_halo_exchanger
+            ctx.stream1 = stream1
+            ctx.stream2 = stream2
         return outputs[2]
     # backward relu is not exposed, MUL with mask used now
@@ -312,54 +322,52 @@ class SpatialBottleneckFunction(torch.autograd.Function):
         grads = fast_bottleneck.backward_init(ctx.nhwc, ctx.stride_1x1, t_list)
         grad_out2 = fast_bottleneck.backward_grad_out2(ctx.nhwc, ctx.stride_1x1, t_list, grads)
+        # do halo exchange of grad_out2 here
+        # compute halo cells for grad_out1
+        if ctx.spatial_group_size > 1:
+            N,Hs,W,C = list(grad_out2.shape)
+            relu1 = t_list[12]
+            ctx.stream1.wait_stream(torch.cuda.current_stream())
+            with torch.cuda.stream(ctx.stream1):
+                top_halo, btm_halo = ctx.spatial_halo_exchanger.left_right_halo_exchange(grad_out2[:,:1,:,:], grad_out2[:,Hs-1:,:,:])
+                # copy halos to send buffer
+            if ctx.spatial_group_rank < ctx.spatial_group_size-1:
+                ctx.stream2.wait_stream(ctx.stream1)
+                with torch.cuda.stream(ctx.stream2):
+                    btm_fat_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
+                    btm_relu_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
+                    btm_fat_halo[:,:2,:,:].copy_(grad_out2[:,Hs-2:,:,:])
+                    btm_fat_halo[:,2:,:,:].copy_(btm_halo)
+                    btm_relu_halo[:,:2,:,:].copy_(relu1[:,Hs-2:,:,:])
+                    btm_relu_halo[:,2:,:,:].zero_()
+                    btm_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, btm_fat_halo, btm_relu_halo)
+                    btm_grad_out1_halo = btm_grad_out1_halo[:,1:2,:,:]
+            if ctx.spatial_group_rank > 0:
+                with torch.cuda.stream(ctx.stream1):
+                    top_fat_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
+                    top_relu_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
+                    top_fat_halo[:,:1,:,:].copy_(top_halo)
+                    top_fat_halo[:,1:,:,:].copy_(grad_out2[:,:2,:,:])
+                    top_relu_halo[:,:1,:,:].zero_()
+                    top_relu_halo[:,1:,:,:].copy_(relu1[:,:2,:,:])
+                    top_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, top_fat_halo, top_relu_halo)
+                    top_grad_out1_halo = top_grad_out1_halo[:,1:2,:,:]
+            inc.add_delay(10)
         # compute wgrad2 for internal cells
         wgrad2 = fast_bottleneck.backward_wgrad2(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
 
         # apply wgrad2 halos
         if ctx.spatial_group_size > 1:
-            if ctx.local_rank > 0:
+            if ctx.spatial_group_rank > 0:
                 top_grad2_halo = grad_out2[:,:1,:,:]
                 top_wgrad2_halo = fast_bottleneck.backward_wgrad2_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, top_out1_halo, top_grad2_halo)
                 wgrad2[:,:1,:,:].add_(top_wgrad2_halo)
-            if ctx.local_rank < ctx.spatial_group_size-1:
+            if ctx.spatial_group_rank < ctx.spatial_group_size-1:
                 btm_grad2_halo = grad_out2[:,-1:,:,:]
                 btm_wgrad2_halo = fast_bottleneck.backward_wgrad2_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, btm_out1_halo, btm_grad2_halo)
                 wgrad2[:,-1:,:,:].add_(btm_wgrad2_halo)
-        # do halo exchange of grad_out2 here
-        # compute halo cells for grad_out1
-        if ctx.spatial_group_size > 1:
-            N,Hs,W,C = list(grad_out2.shape)
-            ctx.stream1.wait_stream(torch.cuda.current_stream())
-            with torch.cuda.stream(ctx.stream1):
-                # copy halos to send buffer
-                send_halos = torch.empty((N,2,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
-                send_halos[:,:1,:,:].copy_(grad_out2[:,:1,:,:])
-                send_halos[:,1:,:,:].copy_(grad_out2[:,Hs-1:,:,:])
-                all_halos = torch.empty((N,2*ctx.spatial_group_size,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
-                all_halos = [all_halos[:,i*2:(i+1)*2,:,:] for i in range(ctx.spatial_group_size)]
-                dist.all_gather(all_halos,send_halos,group=ctx.comm)
-                relu1 = t_list[12]
-                fat_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
-                relu_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
-                if ctx.local_rank > 0:
-                    top_halo = all_halos[ctx.local_rank-1][:,1:,:,:]
-                    fat_halo[:,:1,:,:].copy_(top_halo)
-                    fat_halo[:,1:,:,:].copy_(grad_out2[:,:2,:,:])
-                    relu_halo[:,:1,:,:].zero_()
-                    relu_halo[:,1:,:,:].copy_(relu1[:,:2,:,:])
-                    top_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, fat_halo, relu_halo)
-                    top_grad_out1_halo = top_grad_out1_halo[:,1:2,:,:]
-                if ctx.local_rank < ctx.spatial_group_size-1:
-                    btm_halo = all_halos[ctx.local_rank+1][:,:1,:,:]
-                    fat_halo[:,:2,:,:].copy_(grad_out2[:,Hs-2:,:,:])
-                    fat_halo[:,2:,:,:].copy_(btm_halo)
-                    relu_halo[:,:2,:,:].copy_(relu1[:,Hs-2:,:,:])
-                    relu_halo[:,2:,:,:].zero_()
-                    btm_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, fat_halo, relu_halo)
-                    btm_grad_out1_halo = btm_grad_out1_halo[:,1:2,:,:]
         # compute grad_out1 for internal cells
         grad_out1 = fast_bottleneck.backward_grad_out1(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
@@ -369,20 +377,70 @@ class SpatialBottleneckFunction(torch.autograd.Function):
             z = t_list[4]
             relu1 = t_list[12]
             #print("w.shape = %s, z.shape = %s, relu1.shape = %s" % (str(list(w.shape)), str(list(z.shape)), str(list(relu1.shape))))
-            torch.cuda.current_stream().wait_stream(ctx.stream1)
-            if ctx.local_rank > 0:
+            if ctx.spatial_group_rank > 0:
+                torch.cuda.current_stream().wait_stream(ctx.stream1)
                 grad_out1[:,:1,:,:].copy_(top_grad_out1_halo)
-                #print("ctx.local_rank = %d, apply grad_out1 top halo (grad_out1.shape = %s)" % (ctx.local_rank, str(list(grad_out1.shape))))
-            if ctx.local_rank < ctx.spatial_group_size-1:
+                #print("ctx.spatial_group_rank = %d, apply grad_out1 top halo (grad_out1.shape = %s)" % (ctx.spatial_group_rank, str(list(grad_out1.shape))))
+            if ctx.spatial_group_rank < ctx.spatial_group_size-1:
+                torch.cuda.current_stream().wait_stream(ctx.stream2)
                 grad_out1[:,Hs-1:,:,:].copy_(btm_grad_out1_halo)
-                #print("ctx.local_rank = %d, apply grad_out1 btm halo (grad_out1.shape = %s)" % (ctx.local_rank, str(list(grad_out1.shape))))
+                #print("ctx.spatial_group_rank = %d, apply grad_out1 btm halo (grad_out1.shape = %s)" % (ctx.spatial_group_rank, str(list(grad_out1.shape))))
 
         fast_bottleneck.backward_rest(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2, grad_out1, wgrad2)
 
-        return (None, None, None, None, None, None, None, None, *grads)
+        return (None, None, None, None, None, None, None, None, None, *grads)
 
 spatial_bottleneck_function = SpatialBottleneckFunction.apply
+
+# Communication free halo exchanger.
+# NB! This halo exchanger does not exchange halos with neighbors as it should, it merely swaps the inputs
+# NB! This is only useful for performance testing.
+# NB! Do not use for actual production runs
+class HaloExchanger(object):
+    def __init__(self):
+        self.stream1 = torch.cuda.Stream()
+        self.stream2 = torch.cuda.Stream()
+
+class HaloExchangerNoComm(HaloExchanger):
+    def __init__(self, world_size, spatial_group_size, rank, comm):
+        super(HaloExchangerNoComm, self).__init__()
+    def left_right_halo_exchange(self, left_output_halo, right_output_halo):
+        return right_output_halo, left_output_halo
+
+class HaloExchangerAllGather(HaloExchanger):
+    def __init__(self, world_size, spatial_group_size, rank, comm):
+        super(HaloExchangerAllGather, self).__init__()
+        self.spatial_group_size = spatial_group_size
+        self.local_rank = rank % spatial_group_size
+        self.comm = comm
+    def left_right_halo_exchange(self, left_output_halo, right_output_halo):
+        N,Hh,W,C = list(left_output_halo.shape)
+        send_halos = torch.empty((N,2*Hh,W,C),dtype=left_output_halo.dtype,device=left_output_halo.device)
+        send_halos[:,:Hh,:,:].copy_(left_output_halo)
+        send_halos[:,Hh:,:,:].copy_(right_output_halo)
+        all_halos = torch.empty((N,2*Hh*self.spatial_group_size,W,C),dtype=left_output_halo.dtype,device=left_output_halo.device)
+        all_halos = [all_halos[:,i*2*Hh:(i+1)*2*Hh,:,:] for i in range(self.spatial_group_size)]
+        torch.distributed.all_gather(all_halos,send_halos,group=self.comm,no_copy=True)
+        left_input_halo = all_halos[(self.spatial_group_size+self.local_rank-1)%self.spatial_group_size][:,Hh:,:,:]
+        right_input_halo = all_halos[(self.local_rank+1)%self.spatial_group_size][:,:Hh,:,:]
+        return left_input_halo, right_input_halo
+
+class HaloExchangerSendRecv(HaloExchanger):
+    def __init__(self, world_size, spatial_group_size, rank, comm):
+        super(HaloExchangerSendRecv, self).__init__()
+        self.world_size = world_size
+        self.spatial_group_size = spatial_group_size
+        nccl_id = inc.get_unique_nccl_id(1).cuda()
+        torch.distributed.broadcast(nccl_id, 0)
+        nccl_id = nccl_id.cpu()
+        self.handle = inc.init_nccl_comm(nccl_id, rank, world_size)
+    def left_right_halo_exchange(self, left_output_halo, right_output_halo):
+        left_input_halo, right_input_halo = inc.left_right_halo_exchange(self.handle, left_output_halo, right_output_halo, self.spatial_group_size)
+        return left_input_halo, right_input_halo
+
 class SpatialBottleneck(torch.nn.Module):
     # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
     # while original implementation places the stride at the first 1x1 convolution(self.conv1)
@@ -393,7 +451,7 @@ class SpatialBottleneck(torch.nn.Module):
     def __init__(self, in_channels, bottleneck_channels, out_channels, stride=1, groups=1,
                  dilation=1, norm_func=None, use_cudnn=False, explicit_nhwc=False,
-                 spatial_group_size=1, communicator=None):
+                 spatial_parallel_args=None):
         super(SpatialBottleneck, self).__init__()
         if groups != 1:
             raise RuntimeError('Only support groups == 1')
@@ -447,26 +505,10 @@ class SpatialBottleneck(torch.nn.Module):
                 p.data = p.data.permute(0,2,3,1).contiguous()
 
         # spatial communicator
-        self.spatial_group_size = spatial_group_size
-        if spatial_group_size > 1:
-            world_size = dist.get_world_size()
-            num_groups = world_size // spatial_group_size
-            assert(num_groups*spatial_group_size == world_size), "torch.distributed.get_world_size() must be multiple of group_size"
-            rank = dist.get_rank()
-            self.local_rank = rank % spatial_group_size
-            if communicator is None:
-                for group in range(num_groups):
-                    ranks = list(range(group*spatial_group_size,(group+1)*spatial_group_size))
-                    comm = torch.distributed.new_group(ranks=ranks)
-                    if rank in ranks:
-                        self.communicator = comm
-            else:
-                self.communicator = communicator
-            self.stream1 = torch.cuda.Stream()
-            self.spatial_args = self.spatial_group_size, self.local_rank, self.communicator, self.stream1
+        if spatial_parallel_args is None:
+            self.spatial_parallel_args = (1, 0, None, None, None)
         else:
-            self.spatial_args = 1, 0, None, None
+            self.spatial_parallel_args = spatial_parallel_args
         return
     def forward(self, x):
@@ -483,7 +525,7 @@ class SpatialBottleneck(torch.nn.Module):
             w_scale.append(s4)
             w_bias.append(b4)
-            out = spatial_bottleneck_function(*self.spatial_args, self.explicit_nhwc, self.stride, w_scale, w_bias, x, *self.w_conv)
+            out = spatial_bottleneck_function(*self.spatial_parallel_args, self.explicit_nhwc, self.stride, w_scale, w_bias, x, *self.w_conv)
             return out
 
         if self.explicit_nhwc:
@@ -510,3 +552,10 @@ class SpatialBottleneck(torch.nn.Module):
         out = self.relu(out)
         return out
+
+_HALO_EXCHANGERS = Registry({
+    "HaloExchangerNoComm": HaloExchangerNoComm,
+    "HaloExchangerAllGather": HaloExchangerAllGather,
+    "HaloExchangerSendRecv": HaloExchangerSendRecv,
+})
@@ -17,12 +17,12 @@
 #include "peer_memory_cuda.cuh"
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("allocate_raw", &apex::peer_memory::allocate_raw, "allocate_raw");
-  m.def("free_raw", &apex::peer_memory::free_raw, "free_raw");
-  m.def("get_raw_ipc_address", &apex::peer_memory::get_raw_ipc_address, "get_raw_ipc_address");
-  m.def("get_raw_peers", &apex::peer_memory::get_raw_peers, "get_raw_peers");
-  m.def("blob_view_half", &apex::peer_memory::blob_view_half, "blob_view_half");
-  m.def("blob_view_float", &apex::peer_memory::blob_view_float, "blob_view_float");
-  m.def("blob_view_int", &apex::peer_memory::blob_view_int, "blob_view_int");
-  m.def("push_pull_halos_1d", &apex::peer_memory::push_pull_halos_1d, "push_pull_halos_1d");
+  m.def("allocate_raw", &apex::contrib::peer_memory::allocate_raw, "allocate_raw");
+  m.def("free_raw", &apex::contrib::peer_memory::free_raw, "free_raw");
+  m.def("get_raw_ipc_address", &apex::contrib::peer_memory::get_raw_ipc_address, "get_raw_ipc_address");
+  m.def("get_raw_peers", &apex::contrib::peer_memory::get_raw_peers, "get_raw_peers");
+  m.def("blob_view_half", &apex::contrib::peer_memory::blob_view_half, "blob_view_half");
+  m.def("blob_view_float", &apex::contrib::peer_memory::blob_view_float, "blob_view_float");
+  m.def("blob_view_int", &apex::contrib::peer_memory::blob_view_int, "blob_view_int");
+  m.def("push_pull_halos_1d", &apex::contrib::peer_memory::push_pull_halos_1d, "push_pull_halos_1d");
 }
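Note (reviewer aside, not in the diff): this hunk only moves the C++ symbols under the extra contrib namespace; the Python-facing function names are unchanged, so callers are unaffected. A purely illustrative sketch of the raw-memory calls follows; the module name peer_memory_cuda and the workflow are assumptions, and only functions declared in the peer_memory.h hunk further down are used.

# Illustrative only: module name and workflow are assumptions, not shown in this commit.
import peer_memory_cuda as pm

raw = pm.allocate_raw(1 << 20)        # opaque int64 handle to a 1 MiB device allocation
ipc = pm.get_raw_ipc_address(raw)     # tensor wrapping the CUDA IPC handle for this allocation
# ... exchange `ipc` across ranks, map peer buffers via pm.get_raw_peers(...),
# then build typed views with pm.blob_view_half/float/int and call pm.push_pull_halos_1d ...
pm.free_raw(raw)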
@@ -6,6 +6,7 @@
 #include <cassert>
 #include <cuda_runtime_api.h>
 #include <cooperative_groups.h>
+#include "nccl.h"
 
 namespace cg = cooperative_groups;
 
 #define CUDACHECK(cmd) do { \
@@ -214,9 +215,25 @@ __global__ void push_pull_halos_1d_kernel(
     strided_copy_kernel<T,is_HWC>(bih, bih_stride_C, bih_stride_H, bih_stride_W, bix, bix_stride_C, bix_stride_H, bix_stride_W, NC, NH, NW);
 }
 
+__global__ void delay_kernel(int delay_nanoseconds, int* counter)
+{
+    if (blockIdx.x == 0 && threadIdx.x == 0) {
+        // waste time while doing something compiler can't predict, thus preventing it from optimizing away this code.
+        int new_counter = 0;
+        double elapsed = 0;
+        clock_t start = clock();
+        do {
+            clock_t now = clock();
+            elapsed = (double)(now - start)*1e9 / CLOCKS_PER_SEC;
+            ++new_counter;
+        } while (elapsed < (double)delay_nanoseconds);
+        *counter = new_counter;
+    }
+}
+
 }
 
-namespace apex { namespace peer_memory {
+namespace apex { namespace contrib { namespace peer_memory {
 
 int64_t allocate_raw(int64_t size)
 {
@@ -460,5 +477,5 @@ void push_pull_halos_1d(
     } );
 }
 
-} }
+} } }
@@ -19,7 +19,7 @@
 #ifndef _peer_memory_h_
 #define _peer_memory_h_
 
-namespace apex { namespace peer_memory {
+namespace apex { namespace contrib { namespace peer_memory {
 
 int64_t allocate_raw(int64_t size);
 void free_raw(int64_t raw);
 at::Tensor get_raw_ipc_address(int64_t raw);
@@ -43,5 +43,5 @@ namespace apex { namespace peer_memory {
     at::Tensor btm_signal,     // btm input signal in receiver device memory
     at::Tensor waits           // top and btm signals for this rank
     );
 
-} }
+} } }
 
 #endif
@@ -641,6 +641,20 @@ if "--peer_memory" in sys.argv:
         )
     )
 
+if "--nccl_p2p" in sys.argv:
+    sys.argv.remove("--nccl_p2p")
+    raise_if_cuda_home_none("--nccl_p2p")
+    ext_modules.append(
+        CUDAExtension(
+            name="nccl_p2p",
+            sources=[
+                "apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu",
+                "apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp",
+            ],
+            extra_compile_args={"cxx": ["-O3"] + version_dependent_macros + generator_flag},
+        )
+    )
+
 setup(
     name="apex",
 ...
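Since nccl_p2p (like peer_memory) is an optional extension gated on a setup.py flag, downstream code generally has to tolerate its absence. A small, hypothetical availability guard, not part of this commit:

# Hypothetical guard: the module below only exists if apex/this tree was built
# with the matching setup.py flag ("--nccl_p2p").
try:
    import nccl_p2p as inc
    HAVE_NCCL_P2P = True
except ImportError:
    inc = None
    HAVE_NCCL_P2P = False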