Unverified Commit fed20d2a authored by Thor Johnsen, committed by GitHub

Merge pull request #1340 from NVIDIA/peer_memory

Peer memory halo exchange
parents d89f5e66 5698eeeb
from .bottleneck import Bottleneck, SpatialBottleneck
from .halo_exchangers import HaloExchangerNoComm, HaloExchangerAllGather, HaloExchangerSendRecv, HaloExchangerPeer
import functools as func
import torch
import torch.distributed as dist
from torch import nn
import fast_bottleneck
import nccl_p2p_cuda as inc
def kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'):
weight_tensor_nchw = tensor
nn.init.kaiming_uniform_(weight_tensor_nchw, a=a, mode=mode, nonlinearity=nonlinearity)
def compute_scale_bias_one(nhwc, weight, bias, running_mean, running_var, w_scale, w_bias):
scale = weight * running_var.rsqrt()
bias = bias - running_mean * scale
w_scale.copy_(scale)
w_bias.copy_(bias)
def compute_scale_bias_method(nhwc, args):
for arg in args:
# arg is tuple of (weight, bias, running_mean, running_var, w_scale, w_bias)
compute_scale_bias_one(nhwc, *arg)
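# Illustrative sketch (not part of the patch): how a single frozen batch-norm is wired into
# compute_scale_bias_method. The bn variable below is a hypothetical FrozenBatchNorm2d instance.
#
#   s = torch.empty_like(bn.weight)
#   b = torch.empty_like(bn.weight)
#   compute_scale_bias_method(True, [(bn.weight, bn.bias, bn.running_mean, bn.running_var, s, b)])
#   # s and b now hold the folded scale and bias; the caller reshapes them for NHWC or NCHW.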
class FrozenBatchNorm2d(torch.jit.ScriptModule):
"""
BatchNorm2d where the batch statistics and the affine parameters are fixed
"""
@@ -18,7 +31,9 @@ class FrozenBatchNorm2d(torch.nn.Module):
self.register_buffer("running_mean", torch.zeros(n))
self.register_buffer("running_var", torch.ones(n))
@torch.jit.script_method
def get_scale_bias(self, nhwc):
# type: (bool) -> List[torch.Tensor]
scale = self.weight * self.running_var.rsqrt()
bias = self.bias - self.running_mean * scale
if nhwc:
@@ -29,11 +44,11 @@ class FrozenBatchNorm2d(torch.nn.Module):
bias = bias.reshape(1, -1, 1, 1)
return scale, bias
@torch.jit.script_method
def forward(self, x):
scale, bias = self.get_scale_bias(False)
return x * scale + bias
@torch.jit.script
def drelu_dscale1(grad_o, output, scale1):
relu_mask = (output>0)
@@ -147,6 +162,7 @@ class Bottleneck(torch.nn.Module):
self.bn1 = norm_func(bottleneck_channels)
self.bn2 = norm_func(bottleneck_channels)
self.bn3 = norm_func(out_channels)
self.w_scale = None
self.use_cudnn = use_cudnn
@@ -170,10 +186,33 @@ class Bottleneck(torch.nn.Module):
for p in self.parameters():
with torch.no_grad():
p.data = p.data.permute(0,2,3,1).contiguous()
return
# Returns single callable that recomputes scale and bias for all frozen batch-norms.
# This method must be called before cuda graphing.
# The callable it returns can be called anytime.
# Calling this method will prevent these from being computed every forward call.
def get_scale_bias_callable(self):
self.w_scale, self.w_bias, args = [], [], []
batch_norms = [self.bn1, self.bn2, self.bn3]
if self.downsample is not None:
batch_norms.append(self.downsample[1])
for bn in batch_norms:
s = torch.empty_like(bn.weight)
b = torch.empty_like(s)
args.append( (bn.weight, bn.bias, bn.running_mean, bn.running_var, s, b) )
if self.explicit_nhwc:
self.w_scale.append( s.reshape(1, 1, 1, -1) )
self.w_bias.append( b.reshape(1, 1, 1, -1) )
else:
self.w_scale.append( s.reshape(1, -1, 1, 1) )
self.w_bias.append( b.reshape(1, -1, 1, 1) )
return func.partial(compute_scale_bias_method, self.explicit_nhwc, args)
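# Usage sketch (illustrative only, assuming a half-precision explicit-NHWC module and an input
# tensor x that already lives on the GPU):
#
#   bottleneck = Bottleneck(64, 64, 256, use_cudnn=True, explicit_nhwc=True)
#   bottleneck.to(dtype=torch.float16, device='cuda')
#   recompute_scale_bias = bottleneck.get_scale_bias_callable()
#   recompute_scale_bias()                                   # fill w_scale/w_bias before capture
#   graphed = torch.cuda.make_graphed_callables(bottleneck, (x,))
#   recompute_scale_bias()                                   # call again whenever BN buffers change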
def forward(self, x):
if self.use_cudnn:
if self.w_scale is None:
# calculate scale/bias from registered buffers
# TODO: make this better
s1, b1 = self.bn1.get_scale_bias(self.explicit_nhwc)
@@ -185,8 +224,9 @@ class Bottleneck(torch.nn.Module):
s4, b4 = self.downsample[1].get_scale_bias(self.explicit_nhwc)
w_scale.append(s4)
w_bias.append(b4)
out = bottleneck_function(self.explicit_nhwc, self.stride, w_scale, w_bias, x, *self.w_conv)
else:
out = bottleneck_function(self.explicit_nhwc, self.stride, self.w_scale, self.w_bias, x, *self.w_conv)
return out
if self.explicit_nhwc:
@@ -217,7 +257,12 @@ class Bottleneck(torch.nn.Module):
class SpatialBottleneckFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, spatial_group_size, spatial_group_rank, spatial_communicator, spatial_halo_exchanger, spatial_method, use_delay_kernel, explicit_nhwc, stride_1x1, scale, bias, thresholdTop, thresholdBottom, x, *conv):
if spatial_group_size > 1:
stream1 = spatial_halo_exchanger.stream1
stream2 = spatial_halo_exchanger.stream2
stream3 = spatial_halo_exchanger.stream3
# TODO: clean up order of tensors
args = [x, *conv[0:3], *scale[0:3], *bias[0:3]]
ctx.downsample = len(conv) > 3
@@ -226,59 +271,152 @@ class SpatialBottleneckFunction(torch.autograd.Function):
args.append(scale[3])
args.append(bias[3])
# weight buffers are always in explicit_nhwc while shape can be explicit_nhwc or channels_last
# here we pass in flag and let c++ handle it
# alternatively, we can put all sizes into a fixed format and pass it in
outputs = fast_bottleneck.forward_init(explicit_nhwc, stride_1x1, args)
fast_bottleneck.forward_out1(explicit_nhwc, stride_1x1, args, outputs)
# do halo exchange for outputs[0] (out1)
# compute halo cells for outputs[1]
if spatial_group_size > 1:
out1 = outputs[0]
if explicit_nhwc:
N,Hs,W,C = list(out1.shape)
memory_format = torch.contiguous_format
out1_pad = torch.empty([N,Hs+2,W,C], dtype=out1.dtype, device='cuda')
else:
N,C,Hs,W = list(out1.shape)
memory_format = torch.channels_last if out1.is_contiguous(memory_format=torch.channels_last) else torch.contiguous_format
out1_pad = torch.empty([N,C,Hs+2,W], dtype=out1.dtype, device='cuda', memory_format=memory_format)
stream1.wait_stream(torch.cuda.current_stream())
if spatial_method != 2: stream3.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(stream1):
if explicit_nhwc:
top_out1_halo = out1_pad[:,:1,:,:]
btm_out1_halo = out1_pad[:,Hs+1:Hs+2,:,:]
spatial_halo_exchanger.left_right_halo_exchange(out1[:,:1,:,:], out1[:,Hs-1:,:,:], top_out1_halo, btm_out1_halo)
else:
top_out1_halo = out1_pad[:,:,:1,:]
btm_out1_halo = out1_pad[:,:,Hs+1:Hs+2,:]
spatial_halo_exchanger.left_right_halo_exchange(out1[:,:,:1,:], out1[:,:,Hs-1:,:], top_out1_halo, btm_out1_halo)
if spatial_method == 1:
# overlap mid convolution with halo transfer
if spatial_group_rank < spatial_group_size-1:
stream2.wait_stream(stream1)
with torch.cuda.stream(stream2):
if explicit_nhwc:
btm_fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
btm_fat_halo[:,0:2,:,:].copy_(out1[:,Hs-2:,:,:])
btm_fat_halo[:,2:,:,:].copy_(btm_out1_halo)
else:
btm_fat_halo = torch.empty((N,C,3,W),dtype=out1.dtype,device=out1.device)
btm_fat_halo[:,:,0:2,:].copy_(out1[:,:,Hs-2:,:])
btm_fat_halo[:,:,2:,:].copy_(btm_out1_halo)
btm_out2 = fast_bottleneck.forward_out2_halo(explicit_nhwc, btm_fat_halo, args)
if spatial_group_rank > 0:
with torch.cuda.stream(stream1):
if explicit_nhwc:
top_fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
top_fat_halo[:,:1,:,:].copy_(top_out1_halo)
top_fat_halo[:,1:3,:,:].copy_(out1[:,:2,:,:])
else:
top_fat_halo = torch.empty((N,C,3,W),dtype=out1.dtype,device=out1.device)
top_fat_halo[:,:,:1,:].copy_(top_out1_halo)
top_fat_halo[:,:,1:3,:].copy_(out1[:,:,:2,:])
top_out2 = fast_bottleneck.forward_out2_halo(explicit_nhwc, top_fat_halo, args)
if use_delay_kernel: inc.add_delay(10)
elif spatial_method != 2 and spatial_method != 3:
assert(False), "spatial_method must be 1, 2 or 3"
if spatial_group_size <= 1:
fast_bottleneck.forward_out2(explicit_nhwc, stride_1x1, args, outputs)
elif spatial_method == 1:
fast_bottleneck.forward_out2(explicit_nhwc, stride_1x1, args, outputs)
with torch.cuda.stream(stream3):
if explicit_nhwc:
out1_pad[:,1:Hs+1,:,:].copy_(out1)
else:
out1_pad[:,:,1:Hs+1,:].copy_(out1)
elif spatial_method == 2:
# wait for halo transfer to finish before doing a full convolution of padded x
if explicit_nhwc:
out1_pad[:,1:Hs+1,:,:].copy_(out1)
else:
out1_pad[:,:,1:Hs+1,:].copy_(out1)
torch.cuda.current_stream().wait_stream(stream1)
fast_bottleneck.forward_out2_pad(explicit_nhwc, stride_1x1, args, outputs, out1_pad)
elif spatial_method == 3:
fast_bottleneck.forward_out2_mask(explicit_nhwc, stride_1x1, args, outputs, thresholdTop, thresholdBottom)
with torch.cuda.stream(stream3):
if explicit_nhwc:
out1_pad[:,1:Hs+1,:,:].copy_(out1)
else:
out1_pad[:,:,1:Hs+1,:].copy_(out1)
# compute halo cells for outputs[1] (out2)
if spatial_group_size > 1:
out2 = outputs[1]
if explicit_nhwc:
top_out2_halo = out2[:,:1,:,:]
btm_out2_halo = out2[:,Hs-1:,:,:]
else:
top_out2_halo = out2[:,:,:1,:]
btm_out2_halo = out2[:,:,Hs-1:,:]
if spatial_method == 1:
if spatial_group_rank > 0:
torch.cuda.current_stream().wait_stream(stream1)
top_out2_halo.copy_(top_out2)
if spatial_group_rank < spatial_group_size-1:
torch.cuda.current_stream().wait_stream(stream2)
btm_out2_halo.copy_(btm_out2)
elif spatial_method == 3:
# Note
# out2 halo correction cannot overlap with anything since it has
# to wait for out2_mask to finish, but itself has to finish before
# the first kernel of _forward_rest can launch.
# At least we can overlap the two halo correction kernels.
if spatial_group_rank < spatial_group_size-1:
stream2.wait_stream(stream1) # wait for halo transfers to finish
stream2.wait_stream(torch.cuda.current_stream()) # wait for *_out2_mask to finish
with torch.cuda.stream(stream2):
w1by3 = args[2][:,2:3,:,:].clone()
btm_out1_halo = btm_out1_halo.clone()
btm_out2 = fast_bottleneck.forward_out2_halo_corr(explicit_nhwc, btm_out1_halo, args, w1by3, btm_out2_halo.clone())
btm_out2_halo.copy_(btm_out2)
if spatial_group_rank > 0:
stream1.wait_stream(torch.cuda.current_stream()) # wait for *_out2_mask to finish
with torch.cuda.stream(stream1):
w1by3 = args[2][:,:1,:,:].clone()
top_out1_halo = top_out1_halo.clone()
top_out2 = fast_bottleneck.forward_out2_halo_corr(explicit_nhwc, top_out1_halo, args, w1by3, top_out2_halo.clone())
top_out2_halo.copy_(top_out2)
if spatial_group_rank < spatial_group_size-1:
torch.cuda.current_stream().wait_stream(stream2)
if spatial_group_rank > 0:
torch.cuda.current_stream().wait_stream(stream1)
fast_bottleneck.forward_rest(explicit_nhwc, stride_1x1, args, outputs)
# save halos for backward pass
if spatial_group_size > 1:
if spatial_method != 2:
# make sure copy of mid-section of out1 into out1_pad is done before exiting
torch.cuda.current_stream().wait_stream(stream3)
ctx.save_for_backward(*(args+outputs+[out1_pad,]))
else:
ctx.save_for_backward(*(args+outputs))
# save relu outputs for drelu
ctx.explicit_nhwc = explicit_nhwc
ctx.stride_1x1 = stride_1x1
ctx.spatial_group_size = spatial_group_size
if spatial_group_size > 1:
ctx.spatial_group_rank = spatial_group_rank
ctx.spatial_halo_exchanger = spatial_halo_exchanger
ctx.spatial_method = spatial_method
ctx.use_delay_kernel = use_delay_kernel
ctx.thresholdTop = thresholdTop
ctx.thresholdBottom = thresholdBottom
ctx.stream1 = stream1
ctx.stream2 = stream2
ctx.stream3 = stream3
return outputs[2]
# backward relu is not exposed, MUL with mask used now
@@ -286,9 +424,8 @@ class SpatialBottleneckFunction(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_o):
if ctx.spatial_group_size > 1:
out1_pad = ctx.saved_tensors[-1]
outputs = ctx.saved_tensors[-4:-1]
else:
outputs = ctx.saved_tensors[-3:]
@@ -310,58 +447,79 @@ class SpatialBottleneckFunction(torch.autograd.Function):
if ctx.downsample:
t_list.append(ctx.saved_tensors[10])
grads = fast_bottleneck.backward_init(ctx.explicit_nhwc, ctx.stride_1x1, t_list)
wgrad3_stream = torch.cuda.Stream()
wgrad3_stream.wait_stream(torch.cuda.current_stream())
grad_out2 = fast_bottleneck.backward_grad_out2(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads)
wgrad2_stream = torch.cuda.Stream()
wgrad2_stream.wait_stream(torch.cuda.current_stream())
# do halo exchange of grad_out2 here
# compute halo cells for grad_out1
if ctx.spatial_group_size > 1:
if ctx.explicit_nhwc:
N,Hs,W,C = list(grad_out2.shape)
else:
N,C,Hs,W = list(grad_out2.shape)
relu1 = t_list[12]
ctx.stream1.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(ctx.stream1):
top_halo, btm_halo = ctx.spatial_halo_exchanger.left_right_halo_exchange(grad_out2[:,:1,:,:], grad_out2[:,Hs-1:,:,:])
if ctx.spatial_method == 1 or ctx.spatial_method == 2:
# 1 -> halo recompute approach
# 2 -> wait for concatenated halos, then do single conv on full input (not implemented yet for bprop)
if ctx.spatial_group_rank < ctx.spatial_group_size-1:
ctx.stream2.wait_stream(ctx.stream1)
with torch.cuda.stream(ctx.stream2):
if ctx.explicit_nhwc:
btm_fat_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
btm_fat_halo[:,:2,:,:].copy_(grad_out2[:,Hs-2:,:,:])
btm_fat_halo[:,2:,:,:].copy_(btm_halo)
btm_fat_relu_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
btm_fat_relu_halo[:,:2,:,:].copy_(relu1[:,Hs-2:,:,:])
btm_fat_relu_halo[:,2:,:,:].zero_()
else:
btm_fat_halo = torch.empty((N,C,3,W),dtype=grad_out2.dtype,device=grad_out2.device)
btm_fat_halo[:,:,:2,:].copy_(grad_out2[:,:,Hs-2:,:])
btm_fat_halo[:,:,2:,:].copy_(btm_halo)
btm_fat_relu_halo = torch.empty((N,C,3,W),dtype=grad_out2.dtype,device=grad_out2.device)
btm_fat_relu_halo[:,:,:2,:].copy_(relu1[:,:,Hs-2:,:])
btm_fat_relu_halo[:,:,2:,:].zero_()
btm_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, btm_fat_halo, btm_fat_relu_halo)
if ctx.explicit_nhwc:
btm_grad_out1_halo = btm_grad_out1_halo[:,1:2,:,:]
else:
btm_grad_out1_halo = btm_grad_out1_halo[:,:,1:2,:]
if ctx.spatial_group_rank > 0:
with torch.cuda.stream(ctx.stream1):
if ctx.explicit_nhwc:
top_fat_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
top_fat_halo[:,:1,:,:].copy_(top_halo)
top_fat_halo[:,1:,:,:].copy_(grad_out2[:,:2,:,:])
top_fat_relu_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
top_fat_relu_halo[:,:1,:,:].zero_()
top_fat_relu_halo[:,1:,:,:].copy_(relu1[:,:2,:,:])
else:
top_fat_halo = torch.empty((N,C,3,W),dtype=grad_out2.dtype,device=grad_out2.device)
top_fat_halo[:,:,:1,:].copy_(top_halo)
top_fat_halo[:,:,1:,:].copy_(grad_out2[:,:,:2,:])
top_fat_relu_halo = torch.empty((N,C,3,W),dtype=grad_out2.dtype,device=grad_out2.device)
top_fat_relu_halo[:,:,:1,:].zero_()
top_fat_relu_halo[:,:,1:,:].copy_(relu1[:,:,:2,:])
top_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, top_fat_halo, top_fat_relu_halo)
if ctx.explicit_nhwc:
top_grad_out1_halo = top_grad_out1_halo[:,1:2,:,:]
else:
top_grad_out1_halo = top_grad_out1_halo[:,:,1:2,:]
if ctx.use_delay_kernel: inc.add_delay(10)
elif ctx.spatial_method != 3:
assert(False), "spatial_method must be 1, 2 or 3"
# compute grad_out1 for internal cells
if ctx.spatial_group_size <= 1 or ctx.spatial_method == 1 or ctx.spatial_method == 2:
grad_out1 = fast_bottleneck.backward_grad_out1(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
elif ctx.spatial_group_size > 1 and ctx.spatial_method == 3:
grad_out1 = fast_bottleneck.backward_grad_out1_mask(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2, ctx.thresholdTop, ctx.thresholdBottom)
# apply halo cells to grad_out1
if ctx.spatial_group_size > 1:
@@ -369,17 +527,69 @@ class SpatialBottleneckFunction(torch.autograd.Function):
z = t_list[4]
relu1 = t_list[12]
#print("w.shape = %s, z.shape = %s, relu1.shape = %s" % (str(list(w.shape)), str(list(z.shape)), str(list(relu1.shape))))
if ctx.spatial_method == 1 or ctx.spatial_method == 2:
if ctx.spatial_group_rank < ctx.spatial_group_size-1:
torch.cuda.current_stream().wait_stream(ctx.stream2)
if ctx.explicit_nhwc:
grad_out1[:,Hs-1:,:,:].copy_(btm_grad_out1_halo)
else:
grad_out1[:,:,Hs-1:,:].copy_(btm_grad_out1_halo)
#print("ctx.spatial_group_rank = %d, apply grad_out1 btm halo (grad_out1.shape = %s)" % (ctx.spatial_group_rank, str(list(grad_out1.shape))))
if ctx.spatial_group_rank > 0:
torch.cuda.current_stream().wait_stream(ctx.stream1)
if ctx.explicit_nhwc:
grad_out1[:,:1,:,:].copy_(top_grad_out1_halo)
else:
grad_out1[:,:,:1,:].copy_(top_grad_out1_halo)
#print("ctx.spatial_group_rank = %d, apply grad_out1 top halo (grad_out1.shape = %s)" % (ctx.spatial_group_rank, str(list(grad_out1.shape))))
elif ctx.spatial_method == 3:
if ctx.spatial_group_rank < ctx.spatial_group_size-1:
if ctx.explicit_nhwc:
btm_relu_halo = relu1[:,Hs-1:,:,:].clone()
btm_grad_out1 = grad_out1[:,Hs-1:,:,:]
else:
btm_relu_halo = relu1[:,:,Hs-1:,:].clone()
btm_grad_out1 = grad_out1[:,:,Hs-1:,:]
w1by3 = w[:,:1,:,:].clone()
ctx.stream2.wait_stream(ctx.stream1) # wait for halo transfers to finish
ctx.stream2.wait_stream(torch.cuda.current_stream()) # wait for backward_grad_out1_mask to finish before launching halo correction kernel
with torch.cuda.stream(ctx.stream2):
btm_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo_corr(ctx.explicit_nhwc, ctx.stride_1x1, t_list, w1by3, grads, btm_halo, btm_relu_halo, btm_grad_out1.clone())
btm_grad_out1.copy_(btm_grad_out1_halo)
if ctx.spatial_group_rank > 0:
if ctx.explicit_nhwc:
top_relu_halo = relu1[:,:1,:,:].clone()
top_grad_out1 = grad_out1[:,:1,:,:]
else:
top_relu_halo = relu1[:,:,:1,:].clone()
top_grad_out1 = grad_out1[:,:,:1,:]
w1by3 = w[:,2:,:,:].clone()
ctx.stream1.wait_stream(torch.cuda.current_stream()) # wait for backward_grad_out1_mask to finish before launching halo correction kernel
with torch.cuda.stream(ctx.stream1):
top_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo_corr(ctx.explicit_nhwc, ctx.stride_1x1, t_list, w1by3, grads, top_halo, top_relu_halo, top_grad_out1.clone())
top_grad_out1.copy_(top_grad_out1_halo)
if ctx.spatial_group_rank < ctx.spatial_group_size-1:
torch.cuda.current_stream().wait_stream(ctx.stream2) # wait for halo correction to finish
if ctx.spatial_group_rank > 0:
torch.cuda.current_stream().wait_stream(ctx.stream1)
wgrad1_stream = torch.cuda.Stream()
wgrad1_stream.wait_stream(torch.cuda.current_stream())
fast_bottleneck.backward_rest(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2, grad_out1)
with torch.cuda.stream(wgrad3_stream):
fast_bottleneck.backward_wgrad3(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads)
with torch.cuda.stream(wgrad2_stream):
if ctx.spatial_group_size > 1:
fast_bottleneck.backward_wgrad2_pad(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, out1_pad, grad_out2)
else:
fast_bottleneck.backward_wgrad2(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
with torch.cuda.stream(wgrad1_stream):
fast_bottleneck.backward_wgrad1(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out1)
torch.cuda.current_stream().wait_stream(wgrad3_stream)
torch.cuda.current_stream().wait_stream(wgrad2_stream)
torch.cuda.current_stream().wait_stream(wgrad1_stream)
return (None, None, None, None, None, None, None, None, None, None, None, None, *grads)
spatial_bottleneck_function = SpatialBottleneckFunction.apply
@@ -393,7 +603,7 @@ class SpatialBottleneck(torch.nn.Module):
def __init__(self, in_channels, bottleneck_channels, out_channels, stride=1, groups=1,
dilation=1, norm_func=None, use_cudnn=False, explicit_nhwc=False,
spatial_parallel_args=None):
super(SpatialBottleneck, self).__init__()
if groups != 1:
raise RuntimeError('Only support groups == 1')
@@ -422,6 +632,7 @@ class SpatialBottleneck(torch.nn.Module):
self.bn1 = norm_func(bottleneck_channels)
self.bn2 = norm_func(bottleneck_channels)
self.bn3 = norm_func(out_channels)
self.w_scale = None
self.use_cudnn = use_cudnn
@@ -434,6 +645,8 @@ class SpatialBottleneck(torch.nn.Module):
for w in self.w_conv:
kaiming_uniform_(w, a=1)
self.thresholdTop, self.thresholdBottom = None, None
# TODO: prevent unsupported case usage
# support cases
# native cudnn
@@ -447,30 +660,45 @@ class SpatialBottleneck(torch.nn.Module):
p.data = p.data.permute(0,2,3,1).contiguous()
# spatial communicator
if spatial_parallel_args is None:
self.spatial_parallel_args = (1, 0, None, None, 0, False)
else:
self.spatial_parallel_args = spatial_parallel_args
return
# Returns single callable that recomputes scale and bias for all frozen batch-norms.
# This method must be called before cuda graphing.
# The callable it returns can be called anytime.
# Calling this method will prevent these from being computed every forward call.
def get_scale_bias_callable(self):
self.w_scale, self.w_bias, args = [], [], []
batch_norms = [self.bn1, self.bn2, self.bn3]
if self.downsample is not None:
batch_norms.append(self.downsample[1])
for bn in batch_norms:
s = torch.empty_like(bn.weight)
b = torch.empty_like(s)
args.append( (bn.weight, bn.bias, bn.running_mean, bn.running_var, s, b) )
if self.explicit_nhwc:
self.w_scale.append( s.reshape(1, 1, 1, -1) )
self.w_bias.append( b.reshape(1, 1, 1, -1) )
else:
self.w_scale.append( s.reshape(1, -1, 1, 1) )
self.w_bias.append( b.reshape(1, -1, 1, 1) )
return func.partial(compute_scale_bias_method, self.explicit_nhwc, args)
def forward(self, x):
if self.use_cudnn:
if self.thresholdTop is None:
spatial_group_size, spatial_group_rank, _, _, _, _ = self.spatial_parallel_args
if self.explicit_nhwc:
N,H,W,C = list(x.shape)
else:
N,C,H,W = list(x.shape)
self.thresholdTop = torch.tensor([1 if spatial_group_rank > 0 else 0], dtype=torch.int32, device='cuda')
self.thresholdBottom = torch.tensor([H-2 if spatial_group_rank < spatial_group_size - 1 else H-1], dtype=torch.int32, device='cuda')
if self.w_scale is None:
# calculate scale/bias from registered buffers
# TODO: make this better
s1, b1 = self.bn1.get_scale_bias(self.explicit_nhwc)
@@ -482,8 +710,9 @@ class SpatialBottleneck(torch.nn.Module):
s4, b4 = self.downsample[1].get_scale_bias(self.explicit_nhwc)
w_scale.append(s4)
w_bias.append(b4)
out = spatial_bottleneck_function(*self.spatial_parallel_args, self.explicit_nhwc, self.stride, w_scale, w_bias, self.thresholdTop, self.thresholdBottom, x, *self.w_conv)
else:
out = spatial_bottleneck_function(*self.spatial_parallel_args, self.explicit_nhwc, self.stride, self.w_scale, self.w_bias, self.thresholdTop, self.thresholdBottom, x, *self.w_conv)
return out
if self.explicit_nhwc:
@@ -510,3 +739,4 @@ class SpatialBottleneck(torch.nn.Module):
out = self.relu(out)
return out
import torch
from apex.contrib.bottleneck import Bottleneck, SpatialBottleneck
from apex.contrib.bottleneck import HaloExchangerNoComm, HaloExchangerAllGather, HaloExchangerSendRecv, HaloExchangerPeer
from apex.contrib.peer_memory import PeerMemoryPool
def ground_truth_bottleneck(C, dtype, explicit_nhwc):
bottleneck = Bottleneck(C,C,C,use_cudnn=True,explicit_nhwc=explicit_nhwc)
bottleneck.to(dtype=dtype, device='cuda')
for p in bottleneck.parameters():
torch.distributed.broadcast(p, 0)
for b in bottleneck.buffers():
torch.distributed.broadcast(b, 0)
return bottleneck
def print_bottleneck_p_and_b(bottleneck):
with torch.no_grad():
for n,p in bottleneck.named_parameters():
print("%s :: %s" % (n, str(p.norm(p=2,dtype=torch.float32))))
for n,p in bottleneck.named_buffers():
print("%s :: %s" % (n, str(p.norm(p=2,dtype=torch.float32))))
def has_nan(x):
if isinstance(x, list) or isinstance(x, tuple):
for xx in x:
if torch.any(torch.isnan(xx)):
return True
return False
elif isinstance(x, dict):
for k,v in x.items():
if torch.any(torch.isnan(v)):
return True
else:
return torch.any(torch.isnan(x))
def rel_diff_t(xx1, xx2):
return ((xx1 - xx2).norm(p=2,dtype=torch.float32) / (xx1 + xx2).norm(p=2,dtype=torch.float32)).item()
def rel_diff(x1, x2):
if isinstance(x1, list) or isinstance(x1, tuple):
return [rel_diff_t(xx1,xx2) for xx1,xx2 in zip(x1,x2)]
elif isinstance(x1, dict):
return [rel_diff_t(xx1, xx2) for (k1,xx1), (k2,xx2) in zip(x1.items(),x2.items())]
else:
return rel_diff_t(x1,x2)
bottleneck = Bottleneck( return rel_diff_t(x1,x2)
in_channels,
bottleneck_channels,
out_channels, def graph_it(bottleneck, x):
num_groups, print("Graphing")
stride_in_1x1,
stride,
dilation,
norm_func,
nhwc,
spatial_group_size)
bottleneck = bottleneck.to(dtype=numtype,device=device)
weights = dict(bottleneck.named_parameters())
if ref is not None:
ref_x, _, ref_weights = ref
Hs,H = x.shape[1], ref_x.shape[1]
assert(Hs*spatial_group_size == H), "Hs not a multiple of H"
ref_x = ref_x[:,rank*Hs:(rank+1)*Hs,:,:]
x.copy_(ref_x)
assert(len(weights) == len(ref_weights)), "Reference weights and weights don't match"
for k in weights.keys():
weights[k].copy_(ref_weights[k])
# forward
out = bottleneck(x)
# gradient output
with torch.no_grad(): with torch.no_grad():
grad_out = torch.randn_like(out) x = x.clone()
if ref is not None: x.grad = None
_, ref_grad_out, _ = ref x.requires_grad = True
Hs,H = grad_out.shape[1], ref_grad_out.shape[1] return torch.cuda.make_graphed_callables(bottleneck, (x,))
assert(Hs*spatial_group_size == H), "Hs not a multiple of H"
ref_grad_out = ref_grad_out[:,rank*Hs:(rank+1)*Hs,:,:]
grad_out.copy_(ref_grad_out)
# backward
out.backward(grad_out)
def clone_inputs(bottleneck, x, dy=None):
with torch.no_grad(): with torch.no_grad():
dgrad = x.grad.detach() x = x.clone()
x.grad = None
x.requires_grad = True
if dy is None:
y = bottleneck(x)
dy = torch.randn_like(y) / 1e2
torch.distributed.broadcast(dy, 0)
return x, dy
def fprop_and_bprop(bottleneck, x, dy):
y = bottleneck(x)
y.backward(dy)
dgrad = x.grad.detach()
wgrad = {}
for n,p in bottleneck.named_parameters():
wgrad[n] = p.grad.detach()
return x, y, dy, dgrad, wgrad
def ground_truth(N, C, H, W, dtype, memory_format, bottleneck):
if memory_format == 1:
# 1 -> explicit nhwc
explicit_nhwc = True
with torch.no_grad():
x = torch.randn([N,H,W,C], dtype=dtype, device='cuda')
torch.distributed.broadcast(x, 0)
x, dy = clone_inputs(bottleneck, x)
return fprop_and_bprop(bottleneck, x, dy)
else:
# 2 -> native nhwc
# 3 -> nchw
explicit_nhwc = False
assert(False), "Not implemented yet"
def print_ground_truth(gt):
x, y, dy, dgrad, wgrad = gt
if has_nan(y) or has_nan(dgrad) or has_nan(wgrad):
print("Error! Ground truth has NAN")
else:
print("Ok! No NAN found in ground truth")
def apply_to_different_bottleneck(gt, bottleneck):
with torch.no_grad():
x, _, dy, _, _ = gt
x, dy = clone_inputs(bottleneck, x, dy)
return fprop_and_bprop(bottleneck, x, dy)
def compare_single_field(results, f1, f2, l0, l1, l2):
if has_nan(f1) and has_nan(f2):
results[l0] = "both NAN"
elif has_nan(f1):
results[l0] = "%s.%s NAN" % (l1, l0)
elif has_nan(f2):
results[l0] = "%s.%s NAN" % (l2, l0)
else:
results[l0] = "%s" % (str(rel_diff(f1,f2)))
def compare(gt, bt):
x1, y1, dy1, dgrad1, wgrad1 = gt
x2, y2, dy2, dgrad2, wgrad2 = bt
results = {}
compare_single_field(results, y1, y2, "y", "gt", "bt")
compare_single_field(results, dy1, dy2, "dy", "gt", "bt")
compare_single_field(results, dgrad1, dgrad2, "dgrad", "gt", "bt")
compare_single_field(results, wgrad1, wgrad2, "wgrad", "gt", "bt")
for i in range(torch.distributed.get_world_size()):
if i == torch.distributed.get_rank():
print(i,results)
torch.distributed.barrier()
def spatial_parallel_bottleneck(C, dtype, explicit_nhwc, gt_bottleneck, spatial_parallel_args):
spatial_bottleneck = SpatialBottleneck(C,C,C,use_cudnn=True,explicit_nhwc=explicit_nhwc,spatial_parallel_args=spatial_parallel_args)
spatial_bottleneck.to(dtype=dtype, device='cuda')
with torch.no_grad():
sp = {}
for n,p in spatial_bottleneck.named_parameters():
sp[n] = p
for n,p in gt_bottleneck.named_parameters():
sp[n].copy_(p)
sb = {}
for n,b in spatial_bottleneck.named_buffers():
sb[n] = b
for n,b in gt_bottleneck.named_buffers():
sb[n].copy_(b)
return spatial_bottleneck
#class HaloExchangerNoComm(HaloExchanger):
# def __init__(self, world_size, spatial_group_size, rank, comm):
#class HaloExchangerAllGather(HaloExchanger):
# def __init__(self, world_size, spatial_group_size, rank, comm):
#class HaloExchangerSendRecv(HaloExchanger):
# def __init__(self, world_size, spatial_group_size, rank, comm):
#class HaloExchangerPeer(HaloExchanger):
# def __init__(self, world_size, spatial_group_size, rank, comm, peer_pool, explicit_nhwc, numSM=1):
def n_way_spatial(halex, gt_bottleneck, gt, explicit_nhwc, world_size, rank, fp32_reduce=False):
assert(explicit_nhwc), "Only tested for explicit nhwc"
x, _, dy, _, _ = gt
N, H, W, C = list(x.shape) # Tensor is already shaped properly for n-way parallel
dtype = x.dtype
spatial_group_size = world_size
spatial_group_rank = rank
spatial_communicator = None
spatial_halo_exchanger = halex
spatial_method = 1 # 1 -> overlap halo and main conv, 2 -> wait for halo, conv on padded x
use_delay_kernel = False
spatial_parallel_args = (spatial_group_size, spatial_group_rank, spatial_communicator, spatial_halo_exchanger, spatial_method, use_delay_kernel)
spatial_bottleneck = spatial_parallel_bottleneck(C, dtype, explicit_nhwc, gt_bottleneck, spatial_parallel_args)
with torch.no_grad():
Hs = H // spatial_group_size
xs = x[:,spatial_group_rank*Hs:(spatial_group_rank+1)*Hs,:,:].clone()
dys = dy[:,spatial_group_rank*Hs:(spatial_group_rank+1)*Hs,:,:].clone()
xs.requires_grad = True
spatial_bottleneck = graph_it(spatial_bottleneck, xs)
_, y, _, dgrad, wgrad = fprop_and_bprop(spatial_bottleneck, xs, dys)
# gather output pieces
for n,p in wgrad.items():
if fp32_reduce:
p32 = p.float()
torch.distributed.all_reduce(p32)
p.copy_(p32.half())
else:
torch.distributed.all_reduce(p)
ys = [torch.empty_like(y) for _ in range(spatial_group_size)]
torch.distributed.all_gather(ys,y)
y = torch.cat(ys,dim=1)
dgrads = [torch.empty_like(dgrad) for _ in range(spatial_group_size)]
torch.distributed.all_gather(dgrads,dgrad)
dgrad = torch.cat(dgrads,dim=1)
return x, y, dy, dgrad, wgrad
def main():
torch.use_deterministic_algorithms(True)
torch.distributed.init_process_group("nccl")
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
torch.cuda.set_device(rank)
explicit_nhwc = True
dtype = torch.float16
N, C, H, W = 1, 64, 200, 336
Hs = ((H+8*world_size-1) // (8*world_size)) * 8
H = Hs*world_size
gt_bottleneck = ground_truth_bottleneck(C, dtype, explicit_nhwc)
gt = ground_truth(N, C, H, W, dtype, 1, gt_bottleneck)
# verify that spatial bottleneck with group_size 1 produces same results as ground truth bottleneck
spatial_bottleneck = spatial_parallel_bottleneck(C, dtype, explicit_nhwc, gt_bottleneck, None)
bt = apply_to_different_bottleneck(gt, spatial_bottleneck)
compare(gt, bt)
#print_bottleneck_p_and_b(gt_bottleneck)
#print_bottleneck_p_and_b(spatial_bottleneck)
spatial_group_size = world_size
spatial_communicator = None
peer_pool = PeerMemoryPool(rank, world_size, spatial_group_size, 64*1024*1024, 2*1024*1024)
#halex = HaloExchangerAllGather(world_size, spatial_group_size, rank, spatial_communicator)
#halex = HaloExchangerSendRecv(world_size, spatial_group_size, rank, spatial_communicator)
halex = HaloExchangerPeer(world_size, spatial_group_size, rank, spatial_communicator, peer_pool, explicit_nhwc, numSM=1)
#print("halex.signals = %s" % (str(halex.signals)))
# Make sure peer memory halo exchanger has finished initializing flags on all ranks before proceeding
#torch.cuda.synchronize()
#torch.distributed.barrier()
bt2 = n_way_spatial(halex, gt_bottleneck, gt, explicit_nhwc, world_size, rank, fp32_reduce=True)
compare(gt, bt2)
if __name__ == "__main__":
......
import torch
import torch.distributed as dist
from torch import nn
import nccl_p2p_cuda as inc
import peer_memory_cuda as pm
# Communication-free halo exchanger.
# NB! This halo exchanger does not exchange halos with neighbors as it should; it merely swaps its inputs
# (see the usage sketch after HaloExchangerNoComm below).
# NB! This is only useful for performance testing.
# NB! Do not use for actual production runs.
class HaloExchanger(object):
def __init__(self, spatial_group_size, rank):
self.stream1 = torch.cuda.Stream()
self.stream2 = torch.cuda.Stream()
self.stream3 = torch.cuda.Stream()
spatial_rank = rank % spatial_group_size
self.left_zero = True if spatial_rank == 0 else False
self.right_zero = True if spatial_rank == spatial_group_size - 1 else False
class HaloExchangerNoComm(HaloExchanger):
def __init__(self, world_size, spatial_group_size, rank, comm):
super(HaloExchangerNoComm, self).__init__(spatial_group_size, rank)
def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None):
if left_input_halo is None:
return right_output_halo, left_output_halo
else:
left_input_halo.copy_(right_output_halo)
right_input_halo.copy_(left_output_halo)
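# Minimal usage sketch (illustrative, not part of the module): all exchangers share this call
# signature. With HaloExchangerNoComm the "received" halos are simply the local output halos
# swapped, which is why it is only suitable for performance testing.
#
#   halex = HaloExchangerNoComm(world_size, spatial_group_size, rank, comm=None)
#   top_in, btm_in = halex.left_right_halo_exchange(left_output_halo, right_output_halo)
#   # or in-place into preallocated buffers:
#   halex.left_right_halo_exchange(left_output_halo, right_output_halo, left_input_buf, right_input_buf)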
class HaloExchangerAllGather(HaloExchanger):
def __init__(self, world_size, spatial_group_size, rank, comm):
super(HaloExchangerAllGather, self).__init__(spatial_group_size, rank)
self.spatial_group_size = spatial_group_size
self.local_rank = rank % spatial_group_size
self.comm = comm
def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None):
N,Hh,W,C = list(left_output_halo.shape)
send_halos = torch.empty((N,2*Hh,W,C),dtype=left_output_halo.dtype,device=left_output_halo.device)
send_halos[:,:Hh,:,:].copy_(left_output_halo)
send_halos[:,Hh:,:,:].copy_(right_output_halo)
all_halos = torch.empty((N,2*Hh*self.spatial_group_size,W,C),dtype=left_output_halo.dtype,device=left_output_halo.device)
all_halos = [all_halos[:,i*2*Hh:(i+1)*2*Hh,:,:] for i in range(self.spatial_group_size)]
torch.distributed.all_gather(all_halos,send_halos,group=self.comm,no_copy=True)
ag_left_input_halo = all_halos[(self.spatial_group_size+self.local_rank-1)%self.spatial_group_size][:,Hh:,:,:]
ag_right_input_halo = all_halos[(self.local_rank+1)%self.spatial_group_size][:,:Hh,:,:]
if left_input_halo is None:
if self.left_zero:
ag_left_input_halo.zero_()
if self.right_zero:
ag_right_input_halo.zero_()
return ag_left_input_halo, ag_right_input_halo
else:
if self.left_zero:
left_input_halo.zero_()
else:
left_input_halo.copy_(ag_left_input_halo)
if self.right_zero:
right_input_halo.zero_()
else:
right_input_halo.copy_(ag_right_input_halo)
class HaloExchangerSendRecv(HaloExchanger):
def __init__(self, world_size, spatial_group_size, rank, comm):
super(HaloExchangerSendRecv, self).__init__(spatial_group_size, rank)
self.world_size = world_size
self.spatial_group_size = spatial_group_size
nccl_id = inc.get_unique_nccl_id(1).cuda()
torch.distributed.broadcast(nccl_id, 0)
nccl_id = nccl_id.cpu()
self.handle = inc.init_nccl_comm(nccl_id, rank, world_size)
def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None):
if left_input_halo is None:
left_input_halo, right_input_halo = inc.left_right_halo_exchange(self.handle, self.left_zero, self.right_zero, left_output_halo, right_output_halo, self.spatial_group_size)
return left_input_halo, right_input_halo
else:
inc.left_right_halo_exchange_inplace(self.handle, self.left_zero, self.right_zero, left_output_halo, right_output_halo, left_input_halo, right_input_halo, self.spatial_group_size)
class HaloExchangerPeer(HaloExchanger):
def __init__(self, world_size, spatial_group_size, rank, comm, peer_pool, explicit_nhwc, numSM=1):
super(HaloExchangerPeer, self).__init__(spatial_group_size, rank)
self.diagnostics = False
self.spatial_group_size = spatial_group_size
self.peer_rank = rank % spatial_group_size
self.left_neighbor = (self.peer_rank + self.spatial_group_size - 1) % self.spatial_group_size
self.right_neighbor = (self.peer_rank + 1) % self.spatial_group_size
self.peer_pool = peer_pool
self.signals = peer_pool.allocate_peer_tensors([2,4], torch.int32, False, False)
self.signals[self.peer_rank].zero_()
self.explicit_nhwc = explicit_nhwc
self.numSM = numSM
def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None):
inplace = False if left_input_halo is None and right_input_halo is None else True
if not inplace:
left_input_halo = torch.empty_like(right_output_halo)
right_input_halo = torch.empty_like(left_output_halo)
channels_last = left_output_halo.is_contiguous(memory_format=torch.channels_last) and not self.explicit_nhwc
left_tx = self.peer_pool.allocate_peer_tensors(list(left_output_halo.shape), left_output_halo.dtype, channels_last, True)
right_tx = self.peer_pool.allocate_peer_tensors(list(right_output_halo.shape), right_output_halo.dtype, channels_last, True)
pm.push_pull_halos_1d(
self.diagnostics, self.explicit_nhwc, self.numSM,
left_output_halo, left_tx[self.peer_rank], right_tx[self.left_neighbor], left_input_halo,
right_output_halo, right_tx[self.peer_rank], left_tx[self.right_neighbor], right_input_halo,
self.signals[self.left_neighbor], self.signals[self.right_neighbor], self.signals[self.peer_rank]
)
# TODO: Add to push_pull_halos_1d kernel
if self.left_zero:
left_input_halo.zero_()
if self.right_zero:
right_input_halo.zero_()
if not inplace:
return left_input_halo, right_input_halo
# Class that combines input volume with halos from neighbors (1d).
class HaloPadder:
def __init__(self, halo_ex):
self.halo_ex = halo_ex
self.stream1 = torch.cuda.Stream()
self.stream2 = torch.cuda.Stream()
def __call__(self, y, half_halo, explicit_nhwc, H_split):
channels_last = not explicit_nhwc and y.is_contiguous(memory_format=torch.channels_last)
if explicit_nhwc:
N,H,W,C = list(y.shape)
if H_split:
padded_shape = [N,H+2*half_halo,W,C]
ypad = torch.empty(shape=padded_shape, dtype=y.dtype, device=y.device, memory_format=torch.contiguous_format)
yleft = ypad[:,:half_halo,:,:]
ymid = ypad[:,half_halo:H+half_halo,:,:]
yright = ypad[:,H+half_halo:H+2*half_halo,:,:]
oleft = y[:,:half_halo,:,:]
oright = y[:,H-half_halo:,:,:]
else:
padded_shape = [N,H,W+2*half_halo,C]
ypad = torch.empty(shape=padded_shape, dtype=y.dtype, device=y.device, memory_format=torch.contiguous_format)
yleft = ypad[:,:,:half_halo,:]
ymid = ypad[:,:,half_halo:W+half_halo,:]
yright = ypad[:,:,W+half_halo:W+2*half_halo,:]
oleft = y[:,:,:half_halo,:]
oright = y[:,:,W-half_halo:,:]
else:
N,C,H,W = list(y.shape)
if H_split:
padded_shape = [N,C,H+2*half_halo,W]
ypad = torch.empty(shape=padded_shape, dtype=y.dtype, device=y.device, memory_format=torch.channels_last)
yleft = ypad[:,:,:half_halo,:]
ymid = ypad[:,:,half_halo:H+half_halo,:]
yright = ypad[:,:,H+half_halo:H+2*half_halo,:]
oleft = y[:,:,:half_halo,:]
oright = y[:,:,H-half_halo:,:]
else:
padded_shape = [N,C,H,W+2*half_halo]
ypad = torch.empty(shape=padded_shape, dtype=y.dtype, device=y.device, memory_format=torch.channels_last)
yleft = ypad[:,:,:,:half_halo]
ymid = ypad[:,:,:,half_halo:W+half_halo]
yright = ypad[:,:,:,W+half_halo:W+2*half_halo]
oleft = y[:,:,:,:half_halo]
oright = y[:,:,:,W-half_halo:]
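# Overlap communication with computation: the halo exchange runs on stream1 while the bulk
# of y is copied into the middle of the padded tensor on stream2.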
with torch.cuda.stream(self.stream1):
self.halo_ex(oleft, oright, yleft, yright)
with torch.cuda.stream(self.stream2):
ymid.copy_(y)
return ypad
def wait(self):
current_stream = torch.cuda.current_stream()
current_stream.wait_stream(self.stream1)
current_stream.wait_stream(self.stream2)
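# Minimal usage sketch (illustrative only; the process-group setup, peer memory pool and
# tensor shapes below are assumptions, not part of this file):
#
#   halo_ex = HaloExchangerPeer(world_size, spatial_group_size, rank, comm, peer_pool, explicit_nhwc=False)
#   padder = HaloPadder(halo_ex)
#   ypad = padder(y, half_halo=1, explicit_nhwc=False, H_split=True)
#   padder.wait()  # make the current stream wait for both side streams before using ypad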
@@ -1608,18 +1608,1318 @@ std::vector<at::Tensor> bottleneck_backward(bool explicit_nhwc, int stride_1X1,
namespace {
enum {
X_TENSOR,
Y_TENSOR,
W_TENSOR,
Z_TENSOR,
B_TENSOR,
AFTERADD_TENSOR,
AFTERBIAS_TENSOR,
AFTERCONV_TENSOR,
OPTIONAL,
AFTEROPT_TENSOR,
AFTERACT_TENSOR,
GEN_INDEX_TENSOR,
MASK_TOP_TENSOR,
MASK_BOTTOM_TENSOR,
MASK_TENSOR,
THRESHOLD_TOP_TENSOR,
THRESHOLD_BOTTOM_TENSOR,
};
using masked_convbias_descriptors = std::tuple<cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor>;
masked_convbias_descriptors
create_conv_bias_add_act_mask_descriptors(int64_t* x_dim_padded,
int64_t* padA,
int64_t* convstrideA,
int64_t* dilationA,
int64_t* w_dim_padded,
int64_t* y_dim_padded,
int64_t* threshold_dim,
cudnnDataType_t dataType) {
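// Beyond the plain fused-conv descriptor set, this one also declares the tensors used for
// row masking: an int32 gen-index tensor, boolean top/bottom/combined mask tensors, and the
// two int32 threshold tensors bound to ids 't' and 'u'.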
const int convDim = 2;
int64_t b_dim_padded[4];
b_dim_padded[0] = 1;
b_dim_padded[1] = y_dim_padded[1];
b_dim_padded[2] = 1;
b_dim_padded[3] = 1;
int64_t x_stride_padded[4];
int64_t y_stride_padded[4];
int64_t w_stride_padded[4];
int64_t b_stride_padded[4];
int64_t threshold_stride[4];
generateStrides(w_dim_padded, w_stride_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(x_dim_padded, x_stride_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(y_dim_padded, y_stride_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(b_dim_padded, b_stride_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(threshold_dim, threshold_stride, 4, CUDNN_TENSOR_NHWC);
return masked_convbias_descriptors(cudnn_frontend::TensorBuilder()
.setDim(4, x_dim_padded)
.setStrides(4, x_stride_padded)
.setId('x')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('y')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, w_dim_padded)
.setStrides(4, w_stride_padded)
.setId('w')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, b_dim_padded)
.setStrides(4, b_stride_padded)
.setId('z')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, b_dim_padded)
.setStrides(4, b_stride_padded)
.setId('b')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setVirtual()
.setId('A') // after add
.setAlignment(16)
.setDataType(CUDNN_DATA_FLOAT)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setVirtual()
.setId('B') // after bias
.setAlignment(16)
.setDataType(CUDNN_DATA_FLOAT)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('C') // after conv
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_FLOAT)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('i')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('D') // after optional add
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_FLOAT)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('E') // after act for masked
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_FLOAT)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('I') // output of the gen index operation
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_INT32)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('m') // top half of the mask created after the less than
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_BOOLEAN)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('n') // bottom half of the mask
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_BOOLEAN)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('M') // OR of the top and bottom masks
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_BOOLEAN)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, threshold_dim)
.setStrides(4, threshold_stride)
.setId('t') // threshold for creating the top mask
.setAlignment(16)
.setDataType(CUDNN_DATA_INT32)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, threshold_dim)
.setStrides(4, threshold_stride)
.setId('u') // threshold for creating the bottom mask
.setAlignment(16)
.setDataType(CUDNN_DATA_INT32)
.build());
}
// tensor descriptors used for dgrad
enum {
X_OR_DX_TENSOR,
DY_TENSOR,
W_OR_DW_TENSOR,
SCALE_TENSOR,
RELU_TENSOR,
AFTER_DCONV_TENSOR,
AFTER_DRELU_TENSOR,
DGRAD_INPUT_TENSOR,
DGRAD_OPTIONAL_TENSOR,
DGRAD_GEN_INDEX_TENSOR,
DGRAD_MASK_TOP_TENSOR,
DGRAD_MASK_BOTTOM_TENSOR,
DGRAD_MASK_TENSOR,
DGRAD_THRESHOLD_TOP_TENSOR,
DGRAD_THRESHOLD_BOTTOM_TENSOR,
};
using dconv_add_descriptors = std::tuple<cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor>;
dconv_add_descriptors
create_dconv_add_descriptors(int64_t* x_dim_padded,
int64_t* padA,
int64_t* convstrideA,
int64_t* dilationA,
int64_t* w_dim_padded,
int64_t* y_dim_padded,
cudnnDataType_t dataType) {
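// Descriptor set for the dgrad fusion with an extra gradient input: x/dx ('x'), dy ('y'),
// weights ('w'), per-channel scale ('s'), saved forward ReLU input ('r'), virtual
// after-dconv/after-drelu tensors, the extra gradient input ('i') and its virtual sum ('D').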
const int convDim = 2;
int64_t b_dim_padded[4];
b_dim_padded[0] = 1;
b_dim_padded[1] = x_dim_padded[1];
b_dim_padded[2] = 1;
b_dim_padded[3] = 1;
int64_t x_stride_padded[4];
int64_t y_stride_padded[4];
int64_t w_stride_padded[4];
int64_t b_stride_padded[4];
generateStrides(w_dim_padded, w_stride_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(x_dim_padded, x_stride_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(y_dim_padded, y_stride_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(b_dim_padded, b_stride_padded, 4, CUDNN_TENSOR_NHWC);
return dconv_add_descriptors(cudnn_frontend::TensorBuilder()
.setDim(4, x_dim_padded)
.setStrides(4, x_stride_padded)
.setId('x')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('y')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, w_dim_padded)
.setStrides(4, w_stride_padded)
.setId('w')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, b_dim_padded)
.setStrides(4, b_stride_padded)
.setId('s')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, x_dim_padded)
.setStrides(4, x_stride_padded)
.setId('r')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, x_dim_padded)
.setStrides(4, x_stride_padded)
.setVirtual()
.setId('A') // after dconv
.setAlignment(16)
.setDataType(CUDNN_DATA_FLOAT)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, x_dim_padded)
.setStrides(4, x_stride_padded)
.setVirtual()
.setId('B') // after drelu
.setAlignment(16)
.setDataType(CUDNN_DATA_FLOAT)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('i')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('D') // after optional add
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_FLOAT)
.build());
}
using dconv_mask_descriptors = std::tuple<cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor>;
dconv_mask_descriptors
create_dconv_mask_descriptors(int64_t* x_dim_padded,
int64_t* padA,
int64_t* convstrideA,
int64_t* dilationA,
int64_t* w_dim_padded,
int64_t* y_dim_padded,
int64_t* threshold_dim,
cudnnDataType_t dataType) {
const int convDim = 2;
int64_t b_dim_padded[4];
b_dim_padded[0] = 1;
b_dim_padded[1] = x_dim_padded[1];
b_dim_padded[2] = 1;
b_dim_padded[3] = 1;
int64_t x_stride_padded[4];
int64_t y_stride_padded[4];
int64_t w_stride_padded[4];
int64_t b_stride_padded[4];
int64_t threshold_stride[4];
generateStrides(w_dim_padded, w_stride_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(x_dim_padded, x_stride_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(y_dim_padded, y_stride_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(b_dim_padded, b_stride_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(threshold_dim, threshold_stride, 4, CUDNN_TENSOR_NHWC);
return dconv_mask_descriptors(cudnn_frontend::TensorBuilder()
.setDim(4, x_dim_padded)
.setStrides(4, x_stride_padded)
.setId('x')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('y')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, w_dim_padded)
.setStrides(4, w_stride_padded)
.setId('w')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, b_dim_padded)
.setStrides(4, b_stride_padded)
.setId('s')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, x_dim_padded)
.setStrides(4, x_stride_padded)
.setId('r')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, x_dim_padded)
.setStrides(4, x_stride_padded)
.setVirtual()
.setId('A') // after dconv
.setAlignment(16)
.setDataType(CUDNN_DATA_FLOAT)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, x_dim_padded)
.setStrides(4, x_stride_padded)
.setVirtual()
.setId('B') // after drelu
.setAlignment(16)
.setDataType(CUDNN_DATA_FLOAT)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('i')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('D') // after optional add
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_FLOAT)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('I') // output of the gen index operation
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_INT32)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('m') // top half of the mask created after the less than
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_BOOLEAN)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('n') // bottom half of the mask
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_BOOLEAN)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('M') // OR of the top and bottom masks
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_BOOLEAN)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, threshold_dim)
.setStrides(4, threshold_stride)
.setId('t') // threshold for creating the top mask
.setAlignment(16)
.setDataType(CUDNN_DATA_INT32)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, threshold_dim)
.setStrides(4, threshold_stride)
.setId('u') // threshold for creating the bottom mask
.setAlignment(16)
.setDataType(CUDNN_DATA_INT32)
.build());
}
void
run_conv_add_scale_bias_activation(int64_t* x_dim_padded,
int64_t* pad,
int64_t* convstride,
int64_t* dilation,
int64_t* w_dim_padded,
int64_t* y_dim_padded,
cudnnDataType_t dataType,
at::Half* devPtrX,
at::Half* devPtrW,
at::Half* devPtrY,
at::Half* devPtrZ,
at::Half* devPtrB,
at::Half* devPtrI) {
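// Fused forward graph built below: conv(x,w) -> add the extra input 'i' -> multiply by the
// per-channel scale 'z' -> add the per-channel bias 'b' -> ReLU, with the result written to 'y'.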
cudnnHandle_t handle_ = torch::native::getCudnnHandle();
std::stringstream log_buf;
try {
int convDim = 2;
// Creates the necessary tensor descriptors
common_convbias_descriptors tensors = create_conv_bias_add_act_descriptors(
x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType);
DEBUG_CUDNN_MSG(log_buf, std::get<X_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<Y_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<W_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<Z_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<B_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTERADD_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTERBIAS_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTERCONV_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<OPTIONAL>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTEROPT_TENSOR>(tensors).describe());
// Define the scale operation
auto scaleDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_MUL)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe());
// Define the bias operation
auto biasDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_ADD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, biasDesc.describe());
// optional add
auto addDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_ADD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, addDesc.describe());
// Define the activation operation
auto actDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_RELU_FWD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, actDesc.describe());
// Define the convolution problem
auto convDesc = cudnn_frontend::ConvDescBuilder()
.setDataType(CUDNN_DATA_FLOAT)
.setMathMode(CUDNN_CROSS_CORRELATION)
.setNDims(convDim)
.setStrides(convDim, convstride)
.setPrePadding(convDim, pad)
.setPostPadding(convDim, pad)
.setDilation(convDim, dilation)
.build();
DEBUG_CUDNN_MSG(log_buf, convDesc.describe());
float alpha = 1.0f;
float beta = 0.0f;
// Create a convolution Node
auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)
.setxDesc(std::get<X_TENSOR>(tensors))
.setwDesc(std::get<W_TENSOR>(tensors))
.setyDesc(std::get<AFTERCONV_TENSOR>(tensors))
.setcDesc(convDesc)
.setAlpha(alpha)
.setBeta(beta)
.build();
DEBUG_CUDNN_MSG(log_buf, conv_op.describe());
// create an add node.
auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(conv_op.getOutputTensor())
.setbDesc(std::get<OPTIONAL>(tensors))
.setyDesc(std::get<AFTEROPT_TENSOR>(tensors))
.setpwDesc(addDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, add_op.describe());
// Create a Add Node with scaling parameters.
auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(add_op.getOutputTensor())
.setbDesc(std::get<Z_TENSOR>(tensors))
.setyDesc(std::get<AFTERADD_TENSOR>(tensors))
.setpwDesc(scaleDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, scale_op.describe());
// Create a Bias Node.
auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(scale_op.getOutputTensor())
.setbDesc(std::get<B_TENSOR>(tensors))
.setyDesc(std::get<AFTERBIAS_TENSOR>(tensors))
.setpwDesc(biasDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, bias_op.describe());
// Create an Activation Node.
auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(bias_op.getOutputTensor())
.setyDesc(std::get<Y_TENSOR>(tensors))
.setpwDesc(actDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, act_op.describe());
// Create an Operation Graph. In this case it is convolution add scale bias activation
std::array<cudnn_frontend::Operation const*, 5> ops = {&conv_op, &add_op, &scale_op, &bias_op, &act_op};
auto opGraph = cudnn_frontend::OperationGraphBuilder()
.setHandle(handle_)
.setOperationGraph(ops.size(), ops.data())
.build();
// Create string encoding for plan caching
auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());
DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
auto workspace_size = plan.getWorkspaceSize();
DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
void* workspace_ptr = nullptr;
auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
if (workspace_size > 0) {
workspace_ptr = workspace_tensor.data_ptr<float>();
}
void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrB, devPtrI};
int64_t uids[] = {'x', 'y', 'w', 'z', 'b', 'i'};
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace_ptr)
.setDataPointers(6, data_ptrs)
.setUids(6, uids)
.build();
DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
checkCudnnErr(status);
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
} catch (cudnn_frontend::cudnnException e) {
std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
}
}
void
run_conv_scale_bias_add_activation_mask(int64_t* x_dim_padded,
int64_t* pad,
int64_t* convstride,
int64_t* dilation,
int64_t* w_dim_padded,
int64_t* y_dim_padded,
int64_t* threshold_dim,
cudnnDataType_t dataType,
at::Half* devPtrX,
at::Half* devPtrW,
at::Half* devPtrY,
at::Half* devPtrZ,
at::Half* devPtrB,
at::Half* devPtrI,
int* devPtrT,
int* devPtrU,
int axis) {
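// Same conv -> scale -> bias (-> optional add of 'i') -> ReLU fusion as elsewhere in this file,
// extended with a row-mask chain: gen-index along `axis`, compare against the top ('t') and
// bottom ('u') thresholds, OR the two boolean masks, then binary-select between the raw conv
// output and the activated output so rows outside the threshold range are treated differently.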
cudnnHandle_t handle_ = torch::native::getCudnnHandle();
std::stringstream log_buf;
try {
int convDim = 2;
// Creates the necessary tensor descriptors
masked_convbias_descriptors tensors = create_conv_bias_add_act_mask_descriptors(
x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, threshold_dim, dataType);
DEBUG_CUDNN_MSG(log_buf, std::get<X_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<Y_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<W_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<Z_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<B_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTERADD_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTERBIAS_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTERCONV_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<OPTIONAL>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTERACT_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<GEN_INDEX_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<MASK_TOP_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<MASK_BOTTOM_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<MASK_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<THRESHOLD_TOP_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<THRESHOLD_BOTTOM_TENSOR>(tensors).describe());
// Define the scale operation
auto scaleDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_MUL)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe());
// Define the bias operation
auto biasDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_ADD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, biasDesc.describe());
// optional add
auto addDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_ADD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, addDesc.describe());
// Define the activation operation
auto actDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_RELU_FWD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, actDesc.describe());
// Define the convolution problem
auto convDesc = cudnn_frontend::ConvDescBuilder()
.setDataType(CUDNN_DATA_FLOAT)
.setMathMode(CUDNN_CROSS_CORRELATION)
.setNDims(convDim)
.setStrides(convDim, convstride)
.setPrePadding(convDim, pad)
.setPostPadding(convDim, pad)
.setDilation(convDim, dilation)
.build();
DEBUG_CUDNN_MSG(log_buf, convDesc.describe());
// Define the genIndex descriptor
auto genIndexDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_GEN_INDEX)
.setMathPrecision(CUDNN_DATA_FLOAT)
.setAxis(axis)
.build();
DEBUG_CUDNN_MSG(log_buf, genIndexDesc.describe());
// Define the lessThan descriptor
auto lessThanDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_CMP_LT)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, lessThanDesc.describe());
// Define the greaterThan descriptor
auto greaterThanDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_CMP_GT)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, greaterThanDesc.describe());
// Define the logical_or descriptor
auto logicalOrDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_LOGICAL_OR)
.setMathPrecision(CUDNN_DATA_BOOLEAN)
.build();
DEBUG_CUDNN_MSG(log_buf, logicalOrDesc.describe());
// Define the binary_selection descriptor
auto selectionDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_BINARY_SELECT)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, selectionDesc.describe());
float alpha = 1.0f;
float beta = 0.0f;
// Create a convolution Node
auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)
.setxDesc(std::get<X_TENSOR>(tensors))
.setwDesc(std::get<W_TENSOR>(tensors))
.setyDesc(std::get<AFTERCONV_TENSOR>(tensors))
.setcDesc(convDesc)
.setAlpha(alpha)
.setBeta(beta)
.build();
DEBUG_CUDNN_MSG(log_buf, conv_op.describe());
// Create a Add Node with scaling parameters.
auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(conv_op.getOutputTensor())
.setbDesc(std::get<Z_TENSOR>(tensors))
.setyDesc(std::get<AFTERADD_TENSOR>(tensors))
.setpwDesc(scaleDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, scale_op.describe());
// Create a Bias Node.
auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(scale_op.getOutputTensor())
.setbDesc(std::get<B_TENSOR>(tensors))
.setyDesc(std::get<AFTERBIAS_TENSOR>(tensors))
.setpwDesc(biasDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, bias_op.describe());
// Create a optional add Node.
auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(bias_op.getOutputTensor())
.setbDesc(std::get<OPTIONAL>(tensors))
.setyDesc(std::get<AFTEROPT_TENSOR>(tensors))
.setpwDesc(addDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, add_op.describe());
// Create an Activation Node.
auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(devPtrI ? add_op.getOutputTensor() : bias_op.getOutputTensor())
.setyDesc(std::get<AFTERACT_TENSOR>(tensors))
.setpwDesc(actDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, act_op.describe());
// Create a Gen_Index Node.
auto genIndex_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<AFTERACT_TENSOR>(tensors))
.setyDesc(std::get<GEN_INDEX_TENSOR>(tensors))
.setpwDesc(genIndexDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, genIndex_op.describe());
// Create a LessThan Node.
auto lessThan_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<GEN_INDEX_TENSOR>(tensors))
.setbDesc(std::get<THRESHOLD_TOP_TENSOR>(tensors))
.setyDesc(std::get<MASK_TOP_TENSOR>(tensors))
.setpwDesc(lessThanDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, lessThan_op.describe());
// Create a GreaterThan Node.
auto greaterThan_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<GEN_INDEX_TENSOR>(tensors))
.setbDesc(std::get<THRESHOLD_BOTTOM_TENSOR>(tensors))
.setyDesc(std::get<MASK_BOTTOM_TENSOR>(tensors))
.setpwDesc(greaterThanDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, greaterThan_op.describe());
// Create a LogicalOr Node.
auto logicalOr_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<MASK_TOP_TENSOR>(tensors))
.setbDesc(std::get<MASK_BOTTOM_TENSOR>(tensors))
.setyDesc(std::get<MASK_TENSOR>(tensors))
.setpwDesc(logicalOrDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, logicalOr_op.describe());
// Create a Binary_Selection Node.
auto selection_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<AFTERCONV_TENSOR>(tensors))
.setbDesc(std::get<AFTERACT_TENSOR>(tensors))
.settDesc(std::get<MASK_TENSOR>(tensors))
.setyDesc(std::get<Y_TENSOR>(tensors))
.setpwDesc(selectionDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, selection_op.describe());
// Create an Operation Graph. In this case it is convolution scale bias add activation with row masking
if (devPtrI) {
std::array<cudnn_frontend::Operation const*, 10> ops = {&conv_op, &scale_op, &bias_op, &add_op, &act_op, &genIndex_op, &lessThan_op, &greaterThan_op, &logicalOr_op, &selection_op};
auto opGraph = cudnn_frontend::OperationGraphBuilder()
.setHandle(handle_)
.setOperationGraph(ops.size(), ops.data())
.build();
// Create string encoding for plan caching
auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());
DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
auto workspace_size = plan.getWorkspaceSize();
DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
void* workspace_ptr = nullptr;
auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
if (workspace_size > 0) {
workspace_ptr = workspace_tensor.data_ptr<float>();
}
void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrB, devPtrI, devPtrT, devPtrU};
int64_t uids[] = {'x', 'y', 'w', 'z', 'b', 'i', 't', 'u'};
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace_ptr)
.setDataPointers(8, data_ptrs)
.setUids(8, uids)
.build();
DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
checkCudnnErr(status);
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
} else {
std::array<cudnn_frontend::Operation const*, 9> ops = {&conv_op, &scale_op, &bias_op, &act_op, &genIndex_op, &lessThan_op, &greaterThan_op, &logicalOr_op, &selection_op};
auto opGraph = cudnn_frontend::OperationGraphBuilder()
.setHandle(handle_)
.setOperationGraph(ops.size(), ops.data())
.build();
// Create string encoding for plan caching
auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());
DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
auto workspace_size = plan.getWorkspaceSize();
DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
void* workspace_ptr = nullptr;
auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
if (workspace_size > 0) {
workspace_ptr = workspace_tensor.data_ptr<float>();
}
void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrB, devPtrT, devPtrU};
int64_t uids[] = {'x', 'y', 'w', 'z', 'b', 't', 'u'};
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace_ptr)
.setDataPointers(7, data_ptrs)
.setUids(7, uids)
.build();
DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
checkCudnnErr(status);
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
}
} catch (cudnn_frontend::cudnnException e) {
std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
}
}
void
run_dconv_add_drelu_dscale(int64_t* x_dim_padded,
int64_t* pad,
int64_t* convstride,
int64_t* dilation,
int64_t* w_dim_padded,
int64_t* y_dim_padded,
cudnnDataType_t dataType,
at::Half* devPtrX,
at::Half* devPtrW,
at::Half* devPtrY,
at::Half* devPtrZ,
at::Half* devPtrR,
at::Half* devPtrI) {
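// Backward-data fusion: dgrad of dy with w, add the side gradient 'i', ReLU backward against
// the saved forward input 'r', then multiply by the per-channel scale 's' to produce dx ('x').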
cudnnHandle_t handle_ = torch::native::getCudnnHandle();
std::stringstream log_buf;
try {
int convDim = 2;
// Creates the necessary tensor descriptors
dconv_add_descriptors tensors = create_dconv_add_descriptors(
x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType);
DEBUG_CUDNN_MSG(log_buf, std::get<X_OR_DX_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DY_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<W_OR_DW_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<SCALE_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<RELU_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DCONV_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DRELU_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DGRAD_INPUT_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DGRAD_OPTIONAL_TENSOR>(tensors).describe());
// Define the convolution problem
auto convDesc = cudnn_frontend::ConvDescBuilder()
.setDataType(CUDNN_DATA_FLOAT)
.setMathMode(CUDNN_CROSS_CORRELATION)
.setNDims(convDim)
.setStrides(convDim, convstride)
.setPrePadding(convDim, pad)
.setPostPadding(convDim, pad)
.setDilation(convDim, dilation)
.build();
DEBUG_CUDNN_MSG(log_buf, convDesc.describe());
// optional add
auto addDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_ADD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, addDesc.describe());
// Define the activation backward operation
auto actDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_RELU_BWD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, actDesc.describe());
// Define the scale backward operation
auto scaleDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_MUL)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe());
float alpha = 1.0f;
float beta = 0.0f;
// Create a convolution Node
auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR)
.setdxDesc(std::get<AFTER_DCONV_TENSOR>(tensors))
.setwDesc(std::get<W_OR_DW_TENSOR>(tensors))
.setdyDesc(std::get<DY_TENSOR>(tensors))
.setcDesc(convDesc)
.setAlpha(alpha)
.setBeta(beta)
.build();
DEBUG_CUDNN_MSG(log_buf, conv_op.describe());
// Create add Node.
auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<AFTER_DCONV_TENSOR>(tensors))
.setbDesc(std::get<DGRAD_INPUT_TENSOR>(tensors))
.setyDesc(std::get<DGRAD_OPTIONAL_TENSOR>(tensors))
.setpwDesc(addDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, add_op.describe());
// TODO: do we need getOutputTensor(), and what it returns in backward case?
// Create an relu backward Node.
auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setdyDesc(std::get<DGRAD_OPTIONAL_TENSOR>(tensors))
.setxDesc(std::get<RELU_TENSOR>(tensors))
.setdxDesc(std::get<AFTER_DRELU_TENSOR>(tensors))
.setpwDesc(actDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, act_op.describe());
// Create a Scale Node.
auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<AFTER_DRELU_TENSOR>(tensors))
.setbDesc(std::get<SCALE_TENSOR>(tensors))
.setyDesc(std::get<X_OR_DX_TENSOR>(tensors))
.setpwDesc(scaleDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, scale_op.describe());
// Create an Operation Graph. In this case it is dgrad add drelu dscale
std::array<cudnn_frontend::Operation const*, 4> ops = {&conv_op, &add_op, &act_op, &scale_op};
auto opGraph = cudnn_frontend::OperationGraphBuilder()
.setHandle(handle_)
.setOperationGraph(ops.size(), ops.data())
.build();
// Create string encoding for plan caching
auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());
DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
auto workspace_size = plan.getWorkspaceSize();
DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
void* workspace_ptr = nullptr;
auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
if (workspace_size > 0) {
workspace_ptr = workspace_tensor.data_ptr<float>();
}
void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrR, devPtrI};
int64_t uids[] = {'x', 'y', 'w', 's', 'r', 'i'};
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace_ptr)
.setDataPointers(6, data_ptrs)
.setUids(6, uids)
.build();
DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
checkCudnnErr(status);
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
} catch (cudnn_frontend::cudnnException e) {
std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
}
}
void
run_dconv_drelu_dscale_mask(int64_t* x_dim_padded,
int64_t* pad,
int64_t* convstride,
int64_t* dilation,
int64_t* w_dim_padded,
int64_t* y_dim_padded,
int64_t* threshold_dim,
cudnnDataType_t dataType,
at::Half* devPtrX,
at::Half* devPtrW,
at::Half* devPtrY,
at::Half* devPtrZ,
at::Half* devPtrR,
int* devPtrT,
int* devPtrU,
int axis) {
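// Masked variant of the dgrad fusion: dgrad conv, ReLU backward, scale, then the same
// gen-index/threshold/logical-or/binary-select chain along `axis` chooses per row between
// the raw dgrad result and the scaled result.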
cudnnHandle_t handle_ = torch::native::getCudnnHandle();
std::stringstream log_buf;
try {
int convDim = 2;
// Creates the necessary tensor descriptors
dconv_mask_descriptors tensors = create_dconv_mask_descriptors(
x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, threshold_dim, dataType);
DEBUG_CUDNN_MSG(log_buf, std::get<X_OR_DX_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DY_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<W_OR_DW_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<SCALE_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<RELU_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DCONV_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DRELU_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DGRAD_OPTIONAL_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DGRAD_GEN_INDEX_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DGRAD_MASK_TOP_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DGRAD_MASK_BOTTOM_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DGRAD_MASK_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DGRAD_THRESHOLD_TOP_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DGRAD_THRESHOLD_BOTTOM_TENSOR>(tensors).describe());
// Define the convolution problem
auto convDesc = cudnn_frontend::ConvDescBuilder()
.setDataType(CUDNN_DATA_FLOAT)
.setMathMode(CUDNN_CROSS_CORRELATION)
.setNDims(convDim)
.setStrides(convDim, convstride)
.setPrePadding(convDim, pad)
.setPostPadding(convDim, pad)
.setDilation(convDim, dilation)
.build();
DEBUG_CUDNN_MSG(log_buf, convDesc.describe());
// Define the activation backward operation
auto actDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_RELU_BWD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, actDesc.describe());
// Define the scale backward operation
auto scaleDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_MUL)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe());
// Define the genIndex descriptor
auto genIndexDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_GEN_INDEX)
.setMathPrecision(CUDNN_DATA_FLOAT)
.setAxis(axis)
.build();
DEBUG_CUDNN_MSG(log_buf, genIndexDesc.describe());
// Define the lessThan descriptor
auto lessThanDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_CMP_LT)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, lessThanDesc.describe());
// Define the greaterThan descriptor
auto greaterThanDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_CMP_GT)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, greaterThanDesc.describe());
// Define the logical_or descriptor
auto logicalOrDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_LOGICAL_OR)
.setMathPrecision(CUDNN_DATA_BOOLEAN)
.build();
DEBUG_CUDNN_MSG(log_buf, logicalOrDesc.describe());
// Define the binary_selection descriptor
auto selectionDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_BINARY_SELECT)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, selectionDesc.describe());
float alpha = 1.0f;
float beta = 0.0f;
// Create a convolution Node
auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR)
.setdxDesc(std::get<AFTER_DCONV_TENSOR>(tensors))
.setwDesc(std::get<W_OR_DW_TENSOR>(tensors))
.setdyDesc(std::get<DY_TENSOR>(tensors))
.setcDesc(convDesc)
.setAlpha(alpha)
.setBeta(beta)
.build();
DEBUG_CUDNN_MSG(log_buf, conv_op.describe());
// TODO: do we need getOutputTensor(), and what it returns in backward case?
// Create an relu backward Node.
auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setdyDesc(std::get<AFTER_DCONV_TENSOR>(tensors))
.setxDesc(std::get<RELU_TENSOR>(tensors))
.setdxDesc(std::get<AFTER_DRELU_TENSOR>(tensors))
.setpwDesc(actDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, act_op.describe());
// Create a Scale Node.
auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<AFTER_DRELU_TENSOR>(tensors))
.setbDesc(std::get<SCALE_TENSOR>(tensors))
.setyDesc(std::get<DGRAD_OPTIONAL_TENSOR>(tensors))
.setpwDesc(scaleDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, scale_op.describe());
// Create a Gen_Index Node.
auto genIndex_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<DGRAD_OPTIONAL_TENSOR>(tensors))
.setyDesc(std::get<DGRAD_GEN_INDEX_TENSOR>(tensors))
.setpwDesc(genIndexDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, genIndex_op.describe());
// Create a LessThan Node.
auto lessThan_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<DGRAD_GEN_INDEX_TENSOR>(tensors))
.setbDesc(std::get<DGRAD_THRESHOLD_TOP_TENSOR>(tensors))
.setyDesc(std::get<DGRAD_MASK_TOP_TENSOR>(tensors))
.setpwDesc(lessThanDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, lessThan_op.describe());
// Create a GreaterThan Node.
auto greaterThan_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<DGRAD_GEN_INDEX_TENSOR>(tensors))
.setbDesc(std::get<DGRAD_THRESHOLD_BOTTOM_TENSOR>(tensors))
.setyDesc(std::get<DGRAD_MASK_BOTTOM_TENSOR>(tensors))
.setpwDesc(greaterThanDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, greaterThan_op.describe());
// Create a LogicalOr Node.
auto logicalOr_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<DGRAD_MASK_TOP_TENSOR>(tensors))
.setbDesc(std::get<DGRAD_MASK_BOTTOM_TENSOR>(tensors))
.setyDesc(std::get<DGRAD_MASK_TENSOR>(tensors))
.setpwDesc(logicalOrDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, logicalOr_op.describe());
// Create a Binary_Selection Node.
auto selection_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<AFTER_DCONV_TENSOR>(tensors))
.setbDesc(std::get<DGRAD_OPTIONAL_TENSOR>(tensors))
.settDesc(std::get<DGRAD_MASK_TENSOR>(tensors))
.setyDesc(std::get<X_OR_DX_TENSOR>(tensors))
.setpwDesc(selectionDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, selection_op.describe());
// Create an Operation Graph. In this case it is dgrad drelu dscale with row masking
std::array<cudnn_frontend::Operation const*, 8> ops = {&conv_op, &act_op, &scale_op, &genIndex_op, &lessThan_op, &greaterThan_op, &logicalOr_op, &selection_op};
auto opGraph = cudnn_frontend::OperationGraphBuilder()
.setHandle(handle_)
.setOperationGraph(ops.size(), ops.data())
.build();
// Create string encoding for plan caching
auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());
DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
auto workspace_size = plan.getWorkspaceSize();
DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
void* workspace_ptr = nullptr;
auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
if (workspace_size > 0) {
workspace_ptr = workspace_tensor.data_ptr<float>();
}
void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrR, devPtrT, devPtrU};
int64_t uids[] = {'x', 'y', 'w', 's', 'r', 't', 'u'};
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace_ptr)
.setDataPointers(7, data_ptrs)
.setUids(7, uids)
.build();
DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
checkCudnnErr(status);
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
} catch (cudnn_frontend::cudnnException e) {
std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
}
}
struct bottleneck_forward_status {
int64_t dimA[4];
int64_t filterdimA1[4];
int64_t filterdimA2[4];
int64_t filterdimA2hh[4];
int64_t filterdimA3[4];
int64_t filterdimA4[4];
int64_t threshdim[4];
int axis[4];
int64_t outdimA0[4];
int64_t outdimA1[4];
int64_t outdimA1b[4]; // out1_pad
int64_t outdimA2[4];
int64_t outdimA3[4];
int64_t outdimA4[4];
@@ -1633,6 +2933,7 @@ struct bottleneck_forward_status {
int64_t outdim0[4]; // halo input shape
int64_t outdim1[4];
int64_t outdim1b[4];
int64_t outdim2[4];
int64_t outdim3[4];
int64_t outdim4[4]; // halo output shape
@@ -1641,8 +2942,10 @@ struct bottleneck_forward_status {
dimA[0] = dimA[1] = dimA[2] = dimA[3] = 0;
filterdimA1[0] = filterdimA1[1] = filterdimA1[2] = filterdimA1[3] = 0;
filterdimA2[0] = filterdimA2[1] = filterdimA2[2] = filterdimA2[3] = 0;
filterdimA2hh[0] = filterdimA2hh[1] = filterdimA2hh[2] = filterdimA2hh[3] = 0;
filterdimA3[0] = filterdimA3[1] = filterdimA3[2] = filterdimA3[3] = 0;
filterdimA4[0] = filterdimA4[1] = filterdimA4[2] = filterdimA4[3] = 0;
threshdim[0] = threshdim[1] = threshdim[2] = threshdim[3] = 1;
// All dim calculation after this order of n,c,h,w
if (explicit_nhwc) {
@@ -1668,10 +2971,18 @@ struct bottleneck_forward_status {
filterdimA4[dim] = inputs[10].size(axis[dim]);
}
}
for (int dim=0;dim<4;dim++) {
if (dim == 2) {
filterdimA2hh[dim] = 1;
} else {
filterdimA2hh[dim] = filterdimA2[dim];
}
}
// output dim in n,c,h,w used by backend
outdimA0[0] = outdimA0[1] = outdimA0[2] = outdimA0[3] = 0;
outdimA1[0] = outdimA1[1] = outdimA1[2] = outdimA1[3] = 0;
outdimA1b[0] = outdimA1b[1] = outdimA1b[2] = outdimA1b[3] = 0;
outdimA2[0] = outdimA2[1] = outdimA2[2] = outdimA2[3] = 0;
outdimA3[0] = outdimA3[1] = outdimA3[2] = outdimA3[3] = 0;
outdimA4[0] = outdimA4[1] = outdimA4[2] = outdimA4[3] = 0;
@@ -1690,6 +3001,13 @@ struct bottleneck_forward_status {
for (int dim = 0; dim < 2; dim++) {
outdimA1[dim + 2] = getFwdConvOutputDim(dimA[dim + 2], padA[dim], filterdimA1[dim + 2], convstride1X1[dim], dilationA[dim]);
}
for (int dim = 0; dim < 4; dim++) {
if (dim == 2) {
outdimA1b[dim] = outdimA1[dim] + 2;
} else {
outdimA1b[dim] = outdimA1[dim];
}
}
outdimA2[0] = outdimA1[0];
outdimA2[1] = filterdimA2[0];
@@ -1715,6 +3033,7 @@ struct bottleneck_forward_status {
// Create output tensor in the correct shape in pytorch's view
outdim1[0] = outdim1[1] = outdim1[2] = outdim1[3] = 0;
outdim1b[0] = outdim1b[1] = outdim1b[2] = outdim1b[3] = 0;
outdim2[0] = outdim2[1] = outdim2[2] = outdim2[3] = 0;
outdim3[0] = outdim3[1] = outdim3[2] = outdim3[3] = 0;
if (explicit_nhwc) {
@@ -1726,6 +3045,7 @@ struct bottleneck_forward_status {
for (int dim=0;dim<4;dim++) {
outdim0[dim] = outdimA0[axis[dim]];
outdim1[dim] = outdimA1[axis[dim]];
outdim1b[dim] = outdimA1b[axis[dim]];
outdim2[dim] = outdimA2[axis[dim]];
outdim3[dim] = outdimA3[axis[dim]];
outdim4[dim] = outdimA4[axis[dim]];
@@ -1821,6 +3141,41 @@ at::Tensor bottleneck_forward_out2_halo(bool explicit_nhwc, at::Tensor fat_halo_
return halo_y2;
}
// compute halo correction term (top or bottom) from slim halo input (N,C,1,W).
// slim halo input is 1 pixel wide in H.
at::Tensor bottleneck_forward_out2_halo_corr(bool explicit_nhwc, at::Tensor slim_halo_y1, std::vector<at::Tensor> inputs, at::Tensor w1by3, at::Tensor out2_part_halo) {
auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
// run
at::Half* w = w1by3.data_ptr<at::Half>(); // C,C,1,3
at::Half* z = inputs[5].data_ptr<at::Half>();
at::Half* b = inputs[8].data_ptr<at::Half>();
at::Half* y1 = slim_halo_y1.data_ptr<at::Half>();
at::Half* prev_out2 = out2_part_halo.data_ptr<at::Half>();
auto halo_y2 = at::empty(forward_state.outdim4, inputs[0].type(), output_format);
at::Half* y2 = halo_y2.data_ptr<at::Half>();
run_conv_add_scale_bias_activation(forward_state.outdimA4,
forward_state.padA2,
forward_state.convstrideA,
forward_state.dilationA,
forward_state.filterdimA2hh,
forward_state.outdimA4,
CUDNN_DATA_HALF,
y1,
w,
y2,
z,
b,
prev_out2);
return halo_y2;
}
void bottleneck_forward_out2(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs) {
std::cout << std::fixed;
@@ -1859,6 +3214,86 @@ void bottleneck_forward_out2(bool explicit_nhwc, int stride_1X1, std::vector<at:
DEBUG_MSG("[DEBUG] new relu2 : " << out2.to(at::kFloat).sum().item<float>());
}
void bottleneck_forward_out2_mask(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor thresholdTop, at::Tensor thresholdBottom) {
std::cout << std::fixed;
// from _out1 method
at::Half* x = inputs[0].data_ptr<at::Half>();
auto out1 = outputs[0];
at::Half* y1 = out1.data_ptr<at::Half>();
// run
at::Half* w = inputs[2].data_ptr<at::Half>();
at::Half* z = inputs[5].data_ptr<at::Half>();
at::Half* b = inputs[8].data_ptr<at::Half>();
auto out2 = outputs[1];
at::Half* y2 = out2.data_ptr<at::Half>();
//printf("forward_state.outdimA1 = {%d,%d,%d,%d}\n",forward_state.outdimA1[0],forward_state.outdimA1[1],forward_state.outdimA1[2],forward_state.outdimA1[3]);
//printf("forward_state.padA1 = {%d,%d}\n",forward_state.padA1[0],forward_state.padA1[1]);
//printf("forward_state.convstrideA = {%d,%d}\n",forward_state.convstrideA[0],forward_state.convstrideA[1]);
//printf("forward_state.dilationA = {%d,%d}\n",forward_state.dilationA[0],forward_state.dilationA[1]);
//printf("forward_state.filterdimA2 = {%d,%d,%d,%d}\n",forward_state.filterdimA2[0],forward_state.filterdimA2[1],forward_state.filterdimA2[2],forward_state.filterdimA2[3]);
//printf("forward_state.outdimA2 = {%d,%d,%d,%d}\n",forward_state.outdimA2[0],forward_state.outdimA2[1],forward_state.outdimA2[2],forward_state.outdimA2[3]);
run_conv_scale_bias_add_activation_mask(forward_state.outdimA1,
forward_state.padA1,
forward_state.convstrideA,
forward_state.dilationA,
forward_state.filterdimA2,
forward_state.outdimA2,
forward_state.threshdim,
CUDNN_DATA_HALF,
y1,
w,
y2,
z,
b,
nullptr,
thresholdTop.data_ptr<int>(),
thresholdBottom.data_ptr<int>(),
2); // axis == 1 -> Does this assume explicit NHWC?
DEBUG_MSG("[DEBUG] new relu2 : " << out2.to(at::kFloat).sum().item<float>());
}
void bottleneck_forward_out2_pad(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor out1_pad) {
std::cout << std::fixed;
// from _out1 method
at::Half* x = inputs[0].data_ptr<at::Half>();
auto out1 = outputs[0];
at::Half* y1 = out1_pad.data_ptr<at::Half>();
// run
at::Half* w = inputs[2].data_ptr<at::Half>();
at::Half* z = inputs[5].data_ptr<at::Half>();
at::Half* b = inputs[8].data_ptr<at::Half>();
auto out2 = outputs[1];
at::Half* y2 = out2.data_ptr<at::Half>();
//printf("forward_state.outdimA1 = {%d,%d,%d,%d}\n",forward_state.outdimA1[0],forward_state.outdimA1[1],forward_state.outdimA1[2],forward_state.outdimA1[3]);
//printf("forward_state.padA1 = {%d,%d}\n",forward_state.padA1[0],forward_state.padA1[1]);
//printf("forward_state.convstrideA = {%d,%d}\n",forward_state.convstrideA[0],forward_state.convstrideA[1]);
//printf("forward_state.dilationA = {%d,%d}\n",forward_state.dilationA[0],forward_state.dilationA[1]);
//printf("forward_state.filterdimA2 = {%d,%d,%d,%d}\n",forward_state.filterdimA2[0],forward_state.filterdimA2[1],forward_state.filterdimA2[2],forward_state.filterdimA2[3]);
//printf("forward_state.outdimA2 = {%d,%d,%d,%d}\n",forward_state.outdimA2[0],forward_state.outdimA2[1],forward_state.outdimA2[2],forward_state.outdimA2[3]);
run_conv_scale_bias_add_activation(forward_state.outdimA1b,
forward_state.padA2,
forward_state.convstrideA,
forward_state.dilationA,
forward_state.filterdimA2,
forward_state.outdimA2,
CUDNN_DATA_HALF,
y1,
w,
y2,
z,
b,
nullptr);
DEBUG_MSG("[DEBUG] new relu2 : " << out2.to(at::kFloat).sum().item<float>());
}
void bottleneck_forward_rest(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs) {
std::cout << std::fixed;
@@ -1932,10 +3367,12 @@ struct bottleneck_backward_state {
int64_t filterdimA3[4];
int64_t filterdimA4[4];
int64_t filterdimA2hh[4]; // Cin,Cout,1,3
int64_t threshdim[4];
int axis[4];
int64_t outdimA1[4]; // grad_out1
int64_t outdimA1b[4]; // out1_pad
int64_t outdimA2[4]; // grad_out2
int64_t outdimA3[4];
int64_t outdimA1h[4]; // output: grad_out1 halo (H=3)
@@ -1953,9 +3390,11 @@ struct bottleneck_backward_state {
int64_t filterdim2hh[4]; // Cin,1,3,Cout
int64_t outdim1[4];
int64_t outdim1b[4];
int64_t outdim2[4];
int64_t outdim3[4];
int64_t outdim1h[4];
int64_t outdim1hh[4];
void init(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs) {
// setup dimensions
@@ -1965,6 +3404,7 @@ struct bottleneck_backward_state {
filterdimA3[0] = filterdimA3[1] = filterdimA3[2] = filterdimA3[3] = 0;
filterdimA4[0] = filterdimA4[1] = filterdimA4[2] = filterdimA4[3] = 0;
filterdimA2hh[0] = filterdimA2hh[1] = filterdimA2hh[2] = filterdimA2hh[3] = 0;
threshdim[0] = threshdim[1] = threshdim[2] = threshdim[3] = 1;
// All dim calculation after this order of n,c,h,w
if (explicit_nhwc) {
@@ -2001,6 +3441,7 @@ struct bottleneck_backward_state {
// output dim in n,c,h,w used by backend
outdimA1[0] = outdimA1[1] = outdimA1[2] = outdimA1[3] = 0;
outdimA1b[0] = outdimA1b[1] = outdimA1b[2] = outdimA1b[3] = 0;
outdimA2[0] = outdimA2[1] = outdimA2[2] = outdimA2[3] = 0;
outdimA3[0] = outdimA3[1] = outdimA3[2] = outdimA3[3] = 0;
outdimA1h[0] = outdimA1h[1] = outdimA1h[2] = outdimA1h[3] = 0;
@@ -2022,6 +3463,13 @@ struct bottleneck_backward_state {
for (int dim = 0; dim < 2; dim++) {
outdimA1[dim + 2] = getFwdConvOutputDim(dimA[dim + 2], padA[dim], filterdimA1[dim + 2], convstride1X1[dim], dilationA[dim]);
}
for (int dim = 0; dim < 4; dim++) {
if (dim == 2) {
outdimA1b[dim] = outdimA1[dim] + 2;
} else {
outdimA1b[dim] = outdimA1[dim];
}
}
outdimA2[0] = outdimA1[0];
outdimA2[1] = filterdimA2[0];
...@@ -2051,9 +3499,11 @@ struct bottleneck_backward_state {
// Create output tensor in the correct shape in pytorch's view
outdim1[0] = outdim1[1] = outdim1[2] = outdim1[3] = 0;
outdim1b[0] = outdim1b[1] = outdim1b[2] = outdim1b[3] = 0;
outdim2[0] = outdim2[1] = outdim2[2] = outdim2[3] = 0;
outdim3[0] = outdim3[1] = outdim3[2] = outdim3[3] = 0;
outdim1h[0] = outdim1h[1] = outdim1h[2] = outdim1h[3] = 0;
outdim1hh[0] = outdim1hh[1] = outdim1hh[2] = outdim1hh[3] = 0;
filterdim2hh[0] = filterdim2hh[1] = filterdim2hh[2] = filterdim2hh[3] = 0;
if (explicit_nhwc) {
axis[0] = 0;
...@@ -2063,9 +3513,11 @@ struct bottleneck_backward_state {
}
for (int dim=0;dim<4;dim++) {
outdim1[dim] = outdimA1[axis[dim]];
outdim1b[dim] = outdimA1b[axis[dim]];
outdim2[dim] = outdimA2[axis[dim]];
outdim3[dim] = outdimA3[axis[dim]];
outdim1h[dim] = outdimA1h[axis[dim]];
outdim1hh[dim] = outdimA1hh[axis[dim]];
filterdim2hh[dim] = filterdimA2hh[axis[dim]];
}
}
...@@ -2102,19 +3554,12 @@ std::vector<at::Tensor> bottleneck_backward_init(bool explicit_nhwc, int stride_
return outputs;
}
void bottleneck_backward_wgrad3(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs) {
bool requires_grad = inputs[0].requires_grad();
std::cout << std::fixed;
auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
// dconv3+drelu2+dscale2
at::Half* conv_in = inputs[13].data_ptr<at::Half>();
at::Half* dy3 = inputs[10].data_ptr<at::Half>();
DEBUG_MSG("[DEBUG] new dconv3 : " << inputs[10].to(at::kFloat).sum().item<float>());
// wgrad
auto wgrad3 = outputs[3];
at::Half* dw3 = wgrad3.data_ptr<at::Half>();
...@@ -2129,6 +3574,22 @@ at::Tensor bottleneck_backward_grad_out2(bool explicit_nhwc, int stride_1X1, std
dw3,
dy3,
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
DEBUG_MSG("[DEBUG] new wgrad3 : " << wgrad3.to(at::kFloat).sum().item<float>());
}
at::Tensor bottleneck_backward_grad_out2(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs) {
bool requires_grad = inputs[0].requires_grad();
std::cout << std::fixed;
auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
// dconv3+drelu2+dscale2
at::Half* conv_in = inputs[13].data_ptr<at::Half>();
at::Half* dy3 = inputs[10].data_ptr<at::Half>();
DEBUG_MSG("[DEBUG] new dconv3 : " << inputs[10].to(at::kFloat).sum().item<float>());
// dgrad
auto grad_out2 = at::empty(backward_state.outdim2, inputs[0].type(), output_format);
...@@ -2178,6 +3639,7 @@ at::Tensor bottleneck_backward_grad_out1(bool explicit_nhwc, int stride_1X1, std
//printf("relu.shape = [%d,%d,%d,%d]\n",inputs[12].size(0),inputs[12].size(1),inputs[12].size(2),inputs[12].size(3));
// fused dgrad
//printf("backward_state.outdim1 = {%d,%d,%d,%d}\n",backward_state.outdim1[0],backward_state.outdim1[1],backward_state.outdim1[2],backward_state.outdim1[3]);
run_dconv_drelu_dscale(backward_state.outdimA1,
backward_state.padA1,
backward_state.convstrideA,
...@@ -2194,6 +3656,88 @@ at::Tensor bottleneck_backward_grad_out1(bool explicit_nhwc, int stride_1X1, std
return grad_out1;
}
at::Tensor bottleneck_backward_grad_out1_mask(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2, at::Tensor thresholdTop, at::Tensor thresholdBottom) {
bool requires_grad = inputs[0].requires_grad();
std::cout << std::fixed;
auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
// dgrad
at::Half* dy2 = grad_out2.data_ptr<at::Half>();
// dgrad
auto grad_out1 = at::empty(backward_state.outdim1, inputs[0].type(), output_format);
at::Half* dy1 = grad_out1.data_ptr<at::Half>();
at::Half* w = inputs[2].data_ptr<at::Half>();
at::Half* z = inputs[4].data_ptr<at::Half>();
at::Half* relu1 = inputs[12].data_ptr<at::Half>();
//printf("relu.shape = [%d,%d,%d,%d]\n",inputs[12].size(0),inputs[12].size(1),inputs[12].size(2),inputs[12].size(3));
// fused dgrad
run_dconv_drelu_dscale_mask(backward_state.outdimA1,
backward_state.padA1,
backward_state.convstrideA,
backward_state.dilationA,
backward_state.filterdimA2,
backward_state.outdimA2,
backward_state.threshdim,
CUDNN_DATA_HALF,
dy1,
w,
dy2,
z,
relu1,
thresholdTop.data_ptr<int>(),
thresholdBottom.data_ptr<int>(),
2);
return grad_out1;
}
// perform backward data 1x3 convolution (grad_out * w_rot180) on grad_out2 input of shape [N,1,W,C] with padding=(0,1) to produce output of shape [N,1,W,C]
at::Tensor bottleneck_backward_grad_out1_halo_corr(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, at::Tensor w1by3, std::vector<at::Tensor> outputs, at::Tensor grad_out2_halo, at::Tensor relu1_halo, at::Tensor part_grad_out1) {
bool requires_grad = inputs[0].requires_grad();
std::cout << std::fixed;
auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
// dgrad
at::Half* dy2h = grad_out2_halo.data_ptr<at::Half>();
// dgrad
auto grad_out1_halo = at::empty(backward_state.outdim1hh, inputs[0].type(), output_format);
at::Half* dy1h = grad_out1_halo.data_ptr<at::Half>();
//at::Half* w = inputs[2].data_ptr<at::Half>(); // use w1by3 instead, which is a sliced version of inputs[2]
at::Half* w = w1by3.data_ptr<at::Half>();
at::Half* z = inputs[4].data_ptr<at::Half>();
at::Half* relu1h = relu1_halo.data_ptr<at::Half>();
at::Half* pdy1h = part_grad_out1.data_ptr<at::Half>();
//printf("relu.shape = [%d,%d,%d,%d]\n",relu1_halo.size(0),relu1_halo.size(1),relu1_halo.size(2),relu1_halo.size(3));
// fused dgrad
//printf("backward_state.outdimA1h = {%d,%d,%d,%d}\n",backward_state.outdimA1h[0],backward_state.outdimA1h[1],backward_state.outdimA1h[2],backward_state.outdimA1h[3]);
//printf("backward_state.outdimA2h = {%d,%d,%d,%d}\n",backward_state.outdimA2h[0],backward_state.outdimA2h[1],backward_state.outdimA2h[2],backward_state.outdimA2h[3]);
//printf("backward_state.filterdimA2 = {%d,%d,%d,%d}\n",backward_state.filterdimA2[0],backward_state.filterdimA2[1],backward_state.filterdimA2[2],backward_state.filterdimA2[3]);
run_dconv_add_drelu_dscale(backward_state.outdimA1hh,
backward_state.padA2, // 0,1
backward_state.convstrideA,
backward_state.dilationA,
backward_state.filterdimA2hh, // C,1,3,C
backward_state.outdimA2hh,
CUDNN_DATA_HALF,
dy1h,
w,
dy2h,
z,
relu1h,
pdy1h);
return grad_out1_halo;
}
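# Illustrative sketch (plain PyTorch, not part of this extension): the halo
# backward comments above and below describe the data-gradient pass as
# "grad_out * w_rot180". For a stride-1 3x3 convolution with padding=1 that
# identity can be checked directly; all tensor names here are hypothetical.
import torch
import torch.nn.functional as F

x = torch.randn(1, 8, 6, 6, requires_grad=True)
w = torch.randn(8, 8, 3, 3)
y = F.conv2d(x, w, padding=1)
grad_out = torch.randn_like(y)
(dx_autograd,) = torch.autograd.grad(y, x, grad_out)
# rot180 in H,W plus swapping the in/out channel axes gives the backward-data filter
dx_manual = F.conv2d(grad_out, w.flip(2, 3).transpose(0, 1), padding=1)
assert torch.allclose(dx_autograd, dx_manual, atol=1e-4)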
// perform backward data 3x3 convolution (grad_out * w_rot180) on grad_out2 input of shape [N,3,W,C] with padding=(1,1) to produce output of shape [N,3,W,C]
at::Tensor bottleneck_backward_grad_out1_halo(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2_halo, at::Tensor relu1_halo) {
...@@ -2233,7 +3777,38 @@ at::Tensor bottleneck_backward_grad_out1_halo(bool explicit_nhwc, int stride_1X1
return grad_out1_halo;
}
void bottleneck_backward_wgrad2_pad(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor input, at::Tensor grad_out2) {
std::cout << std::fixed;
auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
// dgrad
at::Half* dy2 = grad_out2.data_ptr<at::Half>();
// dconv2+drelu1+dscale1
at::Half* conv_in = input.data_ptr<at::Half>();
// wgrad
auto wgrad2 = outputs[2];
at::Half* dw2 = wgrad2.data_ptr<at::Half>();
//printf("outdimA1b = (%d,%d,%d,%d)\n",backward_state.outdimA1b[0],backward_state.outdimA1b[1],backward_state.outdimA1b[2],backward_state.outdimA1b[3]);
//printf("backward_state.padA2 = {%d,%d}\n",backward_state.padA2[0],backward_state.padA2[1]);
run_dconv(backward_state.outdimA1b, // conv_in.shape (including H halos)
backward_state.padA2, // 0, 1
backward_state.convstrideA,
backward_state.dilationA,
backward_state.filterdimA2, // dw2.shape
backward_state.outdimA2, // dy2.shape
CUDNN_DATA_HALF,
conv_in,
dw2,
dy2,
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
DEBUG_MSG("[DEBUG] new wgrad2 : " << wgrad2.to(at::kFloat).sum().item<float>());
}
void bottleneck_backward_wgrad2(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2) {
bool requires_grad = inputs[0].requires_grad();
...@@ -2262,8 +3837,7 @@ at::Tensor bottleneck_backward_wgrad2(bool explicit_nhwc, int stride_1X1, std::v
dw2,
dy2,
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
DEBUG_MSG("[DEBUG] new wgrad2 : " << wgrad2.to(at::kFloat).sum().item<float>());
return wgrad2;
}
// compute halo cells for input volume of dimension [N,1,W,C] with padding=(0,1) to produce output volume of dimension [N,1,W,C]
...@@ -2306,7 +3880,30 @@ at::Tensor bottleneck_backward_wgrad2_halo(bool explicit_nhwc, int stride_1X1, s
return wgrad2_halo;
}
void bottleneck_backward_wgrad1(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out1) {
at::Half* x = inputs[0].data_ptr<at::Half>();
at::Half* dy1 = grad_out1.data_ptr<at::Half>();
// dconv1+add
// wgrad
auto wgrad1 = outputs[1];
at::Half* dw1 = wgrad1.data_ptr<at::Half>();
run_dconv(backward_state.dimA,
backward_state.padA,
backward_state.convstride1X1,
backward_state.dilationA,
backward_state.filterdimA1,
backward_state.outdimA1,
CUDNN_DATA_HALF,
x,
dw1,
dy1,
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
}
void bottleneck_backward_rest(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2, at::Tensor grad_out1) {
bool requires_grad = inputs[0].requires_grad();
...@@ -2404,22 +4001,6 @@ void bottleneck_backward_rest(bool explicit_nhwc, int stride_1X1, std::vector<at
dx_conv4 = inputs[11].data_ptr<at::Half>();
}
// dconv1+add
// wgrad
auto wgrad1 = outputs[1];
at::Half* dw1 = wgrad1.data_ptr<at::Half>();
run_dconv(backward_state.dimA,
backward_state.padA,
backward_state.convstride1X1,
backward_state.dilationA,
backward_state.filterdimA1,
backward_state.outdimA1,
CUDNN_DATA_HALF,
x,
dw1,
dy1,
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
// dgrad
w = inputs[1].data_ptr<at::Half>();
auto grad_x = outputs[0];
...@@ -2460,8 +4041,6 @@ void bottleneck_backward_rest(bool explicit_nhwc, int stride_1X1, std::vector<at
DEBUG_MSG("[DEBUG] new dx : " << grad_x.to(at::kFloat).sum().item<float>());
DEBUG_MSG("[DEBUG] new wgrad1 : " << wgrad1.to(at::kFloat).sum().item<float>());
DEBUG_MSG("[DEBUG] new wgrad2 : " << wgrad2.to(at::kFloat).sum().item<float>());
DEBUG_MSG("[DEBUG] new wgrad3 : " << wgrad3.to(at::kFloat).sum().item<float>());
if (stride_1X1 != 1 || backward_state.filterdimA3[0] != backward_state.dimA[1]) {
DEBUG_MSG("[DEBUG] new wgrad4 : " << wgrad4.to(at::kFloat).sum().item<float>());
...@@ -2474,13 +4053,21 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward_init", &bottleneck_forward_init, "Bottleneck block init");
m.def("forward_out1", &bottleneck_forward_out1, "Bottleneck block forward");
m.def("forward_out2", &bottleneck_forward_out2, "Bottleneck block forward");
m.def("forward_out2_mask", &bottleneck_forward_out2_mask, "Bottleneck block forward");
m.def("forward_out2_halo", &bottleneck_forward_out2_halo, "Bottleneck block forward"); m.def("forward_out2_halo", &bottleneck_forward_out2_halo, "Bottleneck block forward");
m.def("forward_out2_halo_corr", &bottleneck_forward_out2_halo_corr, "Bottleneck block forward");
m.def("forward_out2_pad", &bottleneck_forward_out2_pad, "Bottleneck block forward");
m.def("forward_rest", &bottleneck_forward_rest, "Bottleneck block forward"); m.def("forward_rest", &bottleneck_forward_rest, "Bottleneck block forward");
m.def("backward_init", &bottleneck_backward_init, "Bottleneck block backward init"); m.def("backward_init", &bottleneck_backward_init, "Bottleneck block backward init");
m.def("backward_grad_out2", &bottleneck_backward_grad_out2, "Bottleneck block backward"); m.def("backward_grad_out2", &bottleneck_backward_grad_out2, "Bottleneck block backward");
m.def("backward_grad_out1", &bottleneck_backward_grad_out1, "Bottleneck block backward"); m.def("backward_grad_out1", &bottleneck_backward_grad_out1, "Bottleneck block backward");
m.def("backward_grad_out1_mask", &bottleneck_backward_grad_out1_mask, "Bottleneck block backward");
m.def("backward_grad_out1_halo", &bottleneck_backward_grad_out1_halo, "Bottleneck block backward"); m.def("backward_grad_out1_halo", &bottleneck_backward_grad_out1_halo, "Bottleneck block backward");
m.def("backward_grad_out1_halo_corr", &bottleneck_backward_grad_out1_halo_corr, "Bottleneck block backward");
m.def("backward_wgrad2_pad", &bottleneck_backward_wgrad2_pad, "Bottleneck block backward");
m.def("backward_wgrad2", &bottleneck_backward_wgrad2, "Bottleneck block backward"); m.def("backward_wgrad2", &bottleneck_backward_wgrad2, "Bottleneck block backward");
m.def("backward_wgrad2_halo", &bottleneck_backward_wgrad2_halo, "Bottleneck block backward"); m.def("backward_wgrad2_halo", &bottleneck_backward_wgrad2_halo, "Bottleneck block backward");
m.def("backward_wgrad3", &bottleneck_backward_wgrad3, "Bottleneck block backward");
m.def("backward_wgrad1", &bottleneck_backward_wgrad1, "Bottleneck block backward");
m.def("backward_rest", &bottleneck_backward_rest, "Bottleneck block backward"); m.def("backward_rest", &bottleneck_backward_rest, "Bottleneck block backward");
} }
/**
* Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nccl_p2p_cuda.cuh"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("get_unique_nccl_id", &apex::contrib::nccl_p2p::get_unique_nccl_id, "get_unique_nccl_id");
m.def("init_nccl_comm", &apex::contrib::nccl_p2p::init_nccl_comm, "init_nccl_comm");
m.def("nccl_send", &apex::contrib::nccl_p2p::nccl_send, "nccl_send");
m.def("nccl_recv", &apex::contrib::nccl_p2p::nccl_recv, "nccl_recv");
m.def("left_right_halo_exchange_inplace", &apex::contrib::nccl_p2p::left_right_halo_exchange_inplace, "left_right_halo_exchange_inplace");
m.def("left_right_halo_exchange", &apex::contrib::nccl_p2p::left_right_halo_exchange, "left_right_halo_exchange");
m.def("add_delay", &apex::contrib::nccl_p2p::add_delay, "add_delay");
}
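# Hedged usage sketch for the NCCL p2p bindings registered above. The module
# import name (nccl_p2p_cuda) and the torchrun launch are assumptions; the
# function names and argument orders follow the pybind11 definitions and the
# declarations in nccl_p2p_cuda.cuh.
import torch
import torch.distributed as dist
import nccl_p2p_cuda as inc  # assumed extension name

dist.init_process_group("nccl")
rank, world_size = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank)

nccl_id = inc.get_unique_nccl_id(1)      # CPU uint8 tensor holding one ncclUniqueId
nccl_id_cuda = nccl_id.cuda()
dist.broadcast(nccl_id_cuda, 0)          # every rank must use rank 0's id
handle = inc.init_nccl_comm(nccl_id_cuda.cpu(), rank, world_size)

# exchange one-row halos with the wrap-around left/right neighbors;
# the two bools zero the unexchanged halo on the edge ranks
left_out = torch.full((1, 4, 1, 8), float(rank), dtype=torch.float16, device="cuda")
right_out = torch.full((1, 4, 1, 8), float(rank), dtype=torch.float16, device="cuda")
left_in, right_in = inc.left_right_halo_exchange(
    handle, rank == 0, rank == world_size - 1, left_out, right_out, world_size)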
#include <torch/extension.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <ATen/cuda/CUDAContext.h>
#include <list>
#include <cstdio>
#include <ctime>
#include <cassert>
#include "nccl.h"
/*
* This file implements a crude but effective mechanism for copying data between tensors owned by different ranks
* on the same machine using NCCL point-to-point (ncclSend/ncclRecv) transfers.
*/
namespace {
__global__ void AddDelay_kernel(const int delay, int* counter) {
if (blockIdx.x == 0 && threadIdx.x == 0) {
// waste time while doing something the compiler can't predict, thus preventing it from optimizing away this code.
int new_counter = 0;
double elapsed = 0;
clock_t start = clock();
do {
clock_t now = clock();
elapsed = (double)(now - start)*1e9 / CLOCKS_PER_SEC;
++new_counter;
} while (elapsed < (double)delay);
*counter = new_counter;
}
}
class NcclCommWrapper
{
private:
ncclComm_t comm;
int rank, world_size;
ncclDataType_t get_nccl_type(at::Tensor input)
{
switch (input.scalar_type())
{
case at::ScalarType::Half:
return ncclFloat16;
case at::ScalarType::Float:
return ncclFloat32;
case at::ScalarType::Double:
return ncclFloat64;
case at::ScalarType::Byte:
return ncclUint8;
case at::ScalarType::Char:
return ncclInt8;
case at::ScalarType::Int:
return ncclInt32;
case at::ScalarType::Long:
return ncclInt64;
case at::ScalarType::BFloat16:
return ncclBfloat16;
default:
assert(false);
}
}
public:
NcclCommWrapper()
{
memset(&comm, 0, sizeof(ncclComm_t));
rank = 0;
world_size = 0;
}
NcclCommWrapper(ncclUniqueId id, int my_rank, int num_ranks)
{
ncclCommInitRank(&comm, num_ranks, id, my_rank);
rank = my_rank;
world_size = num_ranks;
}
~NcclCommWrapper()
{
printf("ncclCommDestroy()\n");
ncclCommDestroy(comm);
}
void send(at::Tensor input, int destination)
{
ncclDataType_t ncclType = get_nccl_type(input);
AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, input.scalar_type(), "nccl_send", [&]() {
size_t count = sizeof(scalar_t) * torch::numel(input);
auto input_ptr = input.data_ptr<scalar_t>();
ncclSend(input_ptr, count, ncclType, destination, comm, at::cuda::getCurrentCUDAStream());
});
}
void recv(at::Tensor input, int sender)
{
ncclDataType_t ncclType = get_nccl_type(input);
AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, input.scalar_type(), "nccl_send", [&]() {
size_t count = sizeof(scalar_t) * torch::numel(input);
auto input_ptr = input.data_ptr<scalar_t>();
ncclRecv(input_ptr, count, ncclType, sender, comm, at::cuda::getCurrentCUDAStream());
});
}
void left_right_halo_exchange_inplace(bool left_zero, bool right_zero, at::Tensor left_output_halo, at::Tensor right_output_halo, at::Tensor left_input_halo, at::Tensor right_input_halo, int group_size)
{
auto stream = at::cuda::getCurrentCUDAStream();
ncclGroupStart();
ncclDataType_t ncclType = get_nccl_type(left_output_halo);
// we use wrap-around ranks, so left_input_halo of rank 0 has right_output_halo of rank world_size-1 after exchange etc.
// this is technically speaking wasteful, but there is no benefit in having the edge ranks do less work than internal ranks.
int group_rank = rank % group_size;
int group_index = rank / group_size;
int prev_rank = (group_rank + group_size - 1) % group_size;
int next_rank = (group_rank + 1) % group_size;
prev_rank = prev_rank + group_index * group_size;
next_rank = next_rank + group_index * group_size;
size_t left_n = torch::numel(left_output_halo);
size_t right_n = torch::numel(right_output_halo);
if (group_rank > 0) {
AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, left_output_halo.scalar_type(), "left_halo_exch", [&]() {
// send left (to my_rank - 1)
ncclSend(left_output_halo.data_ptr<scalar_t>(), left_n, ncclType, prev_rank, comm, stream);
// receive left (from my_rank - 1)
ncclRecv(left_input_halo.data_ptr<scalar_t>(), right_n, ncclType, prev_rank, comm, stream);
});
}
if (group_rank < group_size-1) {
AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, right_output_halo.scalar_type(), "right_halo_exch", [&]() {
// send right (to my_rank + 1 )
ncclSend(right_output_halo.data_ptr<scalar_t>(), right_n, ncclType, next_rank, comm, stream);
// receive right (from my_rank + 1)
ncclRecv(right_input_halo.data_ptr<scalar_t>(), left_n, ncclType, next_rank, comm, stream);
});
}
ncclGroupEnd();
if (left_zero) left_input_halo.zero_();
if (right_zero) right_input_halo.zero_();
}
std::vector<at::Tensor> left_right_halo_exchange(bool left_zero, bool right_zero, at::Tensor left_output_halo, at::Tensor right_output_halo, int group_size)
{
// after halo exchange:
// left_output_halo of rank+1 ends up in right_input_halo of rank
// right_output_halo of rank-1 ends up in left_input_halo of rank
auto right_input_halo = torch::empty_like(left_output_halo);
auto left_input_halo = torch::empty_like(right_output_halo);
left_right_halo_exchange_inplace(left_zero, right_zero, left_output_halo, right_output_halo, left_input_halo, right_input_halo, group_size);
return {left_input_halo, right_input_halo};
}
};
std::vector<NcclCommWrapper> nccl_comms;
} // end anonymous namespace
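# Small illustrative sketch (plain Python) of the wrap-around neighbor indexing
# used by left_right_halo_exchange_inplace above: ranks are partitioned into
# groups of group_size and each group forms its own ring.
def ring_neighbors(rank, group_size):
    group_rank = rank % group_size
    group_index = rank // group_size
    prev_rank = (group_rank + group_size - 1) % group_size + group_index * group_size
    next_rank = (group_rank + 1) % group_size + group_index * group_size
    return prev_rank, next_rank

assert ring_neighbors(0, 4) == (3, 1)   # first group wraps around on itself
assert ring_neighbors(4, 4) == (7, 5)   # second group forms a separate ring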
namespace apex { namespace contrib { namespace nccl_p2p {
at::Tensor get_unique_nccl_id(int n)
{
ncclUniqueId id;
ncclGetUniqueId(&id);
auto id_tensor = torch::empty({n*(int)sizeof(ncclUniqueId)}, torch::dtype(torch::kUInt8).device(torch::kCPU).requires_grad(false));
auto id_ptr = id_tensor.data_ptr<uint8_t>();
size_t offset = 0;
for (int i = 0; i < n; ++i)
{
ncclUniqueId id;
ncclGetUniqueId(&id);
memcpy(id_ptr+offset, &id, sizeof(ncclUniqueId));
offset += sizeof(ncclUniqueId);
}
return id_tensor;
}
int init_nccl_comm(at::Tensor unique_nccl_id, int my_rank, int num_ranks)
{
ncclUniqueId id;
auto unique_nccl_id_ptr = unique_nccl_id.data_ptr<uint8_t>();
memcpy(&id, unique_nccl_id_ptr, sizeof(ncclUniqueId));
NcclCommWrapper* comm = new NcclCommWrapper(id, my_rank, num_ranks);
int handle = nccl_comms.size();
nccl_comms.push_back(*comm);
comm = 0L;
return handle;
}
void nccl_send(int handle, at::Tensor input, int destination)
{
assert(handle >= 0 && handle < nccl_comms.size());
class NcclCommWrapper communicator = nccl_comms[handle];
communicator.send(input, destination);
}
void nccl_recv(int handle, at::Tensor input, int sender)
{
assert(handle >= 0 && handle < nccl_comms.size());
class NcclCommWrapper communicator = nccl_comms[handle];
communicator.recv(input, sender);
}
void left_right_halo_exchange_inplace(int handle, bool left_zero, bool right_zero, at::Tensor left_output_halo, at::Tensor right_output_halo, at::Tensor left_input_halo, at::Tensor right_input_halo, int group_size)
{
assert(handle >= 0 && handle < nccl_comms.size());
class NcclCommWrapper& communicator = nccl_comms[handle];
return communicator.left_right_halo_exchange_inplace(left_zero, right_zero, left_output_halo, right_output_halo, left_input_halo, right_input_halo, group_size);
}
std::vector<at::Tensor> left_right_halo_exchange(int handle, bool left_zero, bool right_zero, at::Tensor left_output_halo, at::Tensor right_output_halo, int group_size)
{
assert(handle >= 0 && handle < nccl_comms.size());
class NcclCommWrapper& communicator = nccl_comms[handle];
return communicator.left_right_halo_exchange(left_zero, right_zero, left_output_halo, right_output_halo, group_size);
}
void add_delay(int delay)
{
auto stream = at::cuda::getCurrentCUDAStream();
auto t = torch::empty({1}, torch::dtype(torch::kInt32).device(torch::kCUDA));
AddDelay_kernel<<<1,1,0,stream>>>(delay, t.data_ptr<int>());
}
}}}
/**
* Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <torch/extension.h>
#ifndef _nccl_p2p_h_
#define _nccl_p2p_h_
namespace apex { namespace contrib { namespace nccl_p2p {
at::Tensor get_unique_nccl_id(int n);
int init_nccl_comm(
at::Tensor unique_nccl_id,
int my_rank,
int num_ranks
);
void nccl_send(
int handle,
at::Tensor input,
int destination
);
void nccl_recv(
int handle,
at::Tensor input,
int sender
);
void left_right_halo_exchange_inplace(
int handle,
bool left_zero,
bool right_zero,
at::Tensor left_output_halo,
at::Tensor right_output_halo,
at::Tensor left_input_halo,
at::Tensor right_input_halo,
int group_size);
std::vector<at::Tensor> left_right_halo_exchange(
int handle,
bool left_zero,
bool right_zero,
at::Tensor left_output_halo,
at::Tensor right_output_halo,
int group_size
);
void add_delay(int delay);
}}}
#endif
/**
* Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "peer_memory_cuda.cuh"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("allocate_raw", &apex::contrib::peer_memory::allocate_raw, "allocate_raw");
m.def("free_raw", &apex::contrib::peer_memory::free_raw, "free_raw");
m.def("zero", &apex::contrib::peer_memory::zero, "zero");
m.def("get_raw_ipc_address", &apex::contrib::peer_memory::get_raw_ipc_address, "get_raw_ipc_address");
m.def("get_raw_peers", &apex::contrib::peer_memory::get_raw_peers, "get_raw_peers");
m.def("blob_view_half", &apex::contrib::peer_memory::blob_view_half, "blob_view_half");
m.def("blob_view_float", &apex::contrib::peer_memory::blob_view_float, "blob_view_float");
m.def("blob_view_int", &apex::contrib::peer_memory::blob_view_int, "blob_view_int");
m.def("push_pull_halos_1d", &apex::contrib::peer_memory::push_pull_halos_1d, "push_pull_halos_1d");
}
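# Hedged usage sketch for the peer_memory bindings registered above. The import
# name peer_memory_cuda matches the test script later in this change; the
# all_gather-based exchange of IPC handles is an assumption about how the
# handles reach the other ranks, not a prescription.
import torch
import torch.distributed as dist
import peer_memory_cuda as pm

dist.init_process_group("nccl")
rank, world_size = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank)

raw = pm.allocate_raw(2 * 1024 * 1024)            # raw device pointer returned as int64
ipc = pm.get_raw_ipc_address(raw)                 # CPU uint8 tensor holding a cudaIpcMemHandle_t
gathered = [torch.empty_like(ipc).cuda() for _ in range(world_size)]
dist.all_gather(gathered, ipc.cuda())
ipc_addresses = torch.stack([t.cpu() for t in gathered])   # [world_size, sizeof(handle)]
raw_peers = pm.get_raw_peers(ipc_addresses, rank, raw)     # one int64 pointer per rank

# view part of the top neighbor's buffer as a half tensor (channels_last=False here)
top = (rank + world_size - 1) % world_size
neighbor_halo = pm.blob_view_half(raw_peers[top], [1, 64, 1, 32], False)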
#include <torch/extension.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <ATen/cuda/CUDAContext.h>
#include <list>
#include <cstdio>
#include <cassert>
#include <cuda_runtime_api.h>
#include <cooperative_groups.h>
#include "nccl.h"
namespace cg = cooperative_groups;
#define CUDACHECK(cmd) do { \
cudaError_t err = cmd; \
if( err != cudaSuccess ) { \
char hostname[1024]; \
gethostname(hostname, 1024); \
printf("%s: CUDA failure %s:%d '%s'\n", \
hostname, \
__FILE__,__LINE__,cudaGetErrorString(err)); \
} \
} while(0)
namespace {
/* Basic deleter function for from_blob function.
void deleter(void* ptr)
{
printf("deleter(ptr=%p)\n",ptr);
cudaFree(ptr);
}
*/
template<class T>
at::Tensor blob_view(T* raw_ptr, std::vector<int64_t> shape, const at::TensorOptions& options, bool channels_last)
{
size_t size = 1;
std::vector<int64_t> strides(shape.size());
if (channels_last) {
assert(shape.size() == 4);
strides[0] = shape[1]*shape[2]*shape[3];
strides[1] = 1;
strides[2] = shape[1]*shape[3];
strides[3] = shape[1];
} else {
int idx = strides.size();
for (auto it = shape.rbegin(); it != shape.rend(); ++it)
{
strides[--idx] = size;
size *= *it;
}
}
size *= sizeof(T);
// TODO: Implement dynamic reuse of pooled peer memory.
// We provide no deleter function because all peer memory allocations are static in this implementation.
return torch::from_blob((void*)raw_ptr, shape, strides, 0L, options);
}
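# Side note (illustrative, plain PyTorch): for an NCHW-shaped view over NHWC
# storage, blob_view above computes strides N:C*H*W, C:1, H:W*C, W:C, which is
# exactly what torch reports for a channels_last tensor.
import torch
n, c, h, w = 2, 3, 4, 5
x = torch.empty(n, c, h, w).contiguous(memory_format=torch.channels_last)
assert x.stride() == (c * h * w, 1, w * c, c)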
void tensor_shape(at::Tensor t, bool explicit_nhwc, int& N, int& C, int& H, int& W)
{
if (t.dim() == 3) {
N = 1;
if (explicit_nhwc) {
C = t.size(2);
H = t.size(0);
W = t.size(1);
} else {
C = t.size(0);
H = t.size(1);
W = t.size(2);
}
} else if (t.dim() == 4) {
if (explicit_nhwc) {
N = t.size(0);
C = t.size(3);
H = t.size(1);
W = t.size(2);
} else {
N = t.size(0);
C = t.size(1);
H = t.size(2);
W = t.size(3);
}
} else {
printf("%s;%d - t.dim() must be either 3 or 4 (was %d)\n",__FILE__,__LINE__,t.dim());
assert(t.dim() == 3 || t.dim() == 4);
}
}
void tensor_strides(at::Tensor t, bool explicit_nhwc, int& stride_N, int& stride_C, int& stride_H, int& stride_W)
{
if (t.dim() == 3) {
if (explicit_nhwc) {
stride_C = t.stride(2);
stride_H = t.stride(0);
stride_W = t.stride(1);
} else {
stride_C = t.stride(0);
stride_H = t.stride(1);
stride_W = t.stride(2);
}
stride_N = t.size(0)*t.size(1)*t.size(2);
} else if (t.dim() == 4) {
if (explicit_nhwc) {
stride_N = t.stride(0);
stride_C = t.stride(3);
stride_H = t.stride(1);
stride_W = t.stride(2);
} else {
stride_N = t.stride(0);
stride_C = t.stride(1);
stride_H = t.stride(2);
stride_W = t.stride(3);
}
} else {
printf("%s;%d - t.dim() must be either 3 or 4 (was %d)\n",__FILE__,__LINE__,t.dim());
assert(t.dim() == 3 || t.dim() == 4);
}
}
template<class T, bool is_HWC>
__device__ void strided_copy_kernel(
T* dst, const int dst_stride_C, const int dst_stride_H, const int dst_stride_W,
const T* src, const int src_stride_C, const int src_stride_H, const int src_stride_W,
const int NC, const int NH, const int NW
)
{
size_t tot_num_threads = gridDim.x * blockDim.x;
size_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;
const size_t count = NC*NH*NW;
for (size_t i = thread_id; i < count; i += tot_num_threads)
{
size_t c,h,w;
if (is_HWC) {
c = i % NC;
w = i / NC;
h = w / NW;
w = w % NW;
}
else {
w = i % NW;
h = i / NW;
c = h / NH;
h = h % NH;
}
size_t dst_off = c*dst_stride_C + h*dst_stride_H + w*dst_stride_W;
size_t src_off = c*src_stride_C + h*src_stride_H + w*src_stride_W;
dst[dst_off] = src[src_off];
}
}
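# Illustrative check (plain Python) of the flat-index decomposition used by
# strided_copy_kernel above: each flat index i in [0, NC*NH*NW) maps to exactly
# one (c, h, w) triple, whether the fast axis is C (is_HWC) or W (NCHW).
def decompose_hwc(i, NC, NH, NW):
    c = i % NC
    w = i // NC
    h = w // NW
    w = w % NW
    return c, h, w

def decompose_chw(i, NC, NH, NW):
    w = i % NW
    h = i // NW
    c = h // NH
    h = h % NH
    return c, h, w

NC, NH, NW = 4, 3, 5
all_triples = {(c, h, w) for c in range(NC) for h in range(NH) for w in range(NW)}
assert {decompose_hwc(i, NC, NH, NW) for i in range(NC * NH * NW)} == all_triples
assert {decompose_chw(i, NC, NH, NW) for i in range(NC * NH * NW)} == all_triples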
__device__ void checked_signal(
volatile int* signal1_flag, volatile int* signal2_flag,
const int v1, const int v2, const int v3, const int v4
)
{
cg::this_grid().sync();
bool is_main_thread = (blockIdx.x == 0 && threadIdx.x == 0) ? true : false;
if (is_main_thread) {
// flush all writes to global memory
__threadfence_system();
// wait for top or bottom neighbor to clear signal
register int r1, r2, r3, r4;
bool top_zeroed=false, btm_zeroed=false, top_done=false, btm_done=false;
do {
do {
if (!top_zeroed) {
asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal1_flag) : "memory");
if (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4) top_zeroed = true;
}
if (!btm_zeroed) {
asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal2_flag) : "memory");
if (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4) btm_zeroed = true;
}
} while((top_zeroed == top_done) && (btm_zeroed == btm_done));
if (!top_done && top_zeroed) {
// signal to top neighbor my output is ready
asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal1_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
top_done = true;
}
if (!btm_done && btm_zeroed) {
// signal to bottom neighbor my output is ready
asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal2_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
btm_done = true;
}
} while (!top_done || !btm_done);
}
}
__device__ void wait_for(
volatile int* wait_flag,
const int v1, const int v2, const int v3, const int v4
)
{
bool is_main_thread = (blockIdx.x == 0 && threadIdx.x == 0) ? true : false;
if (is_main_thread) {
register int r1, r2, r3, r4;
// wait for senders to signal their output is ready
do {
asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(wait_flag) : "memory");
} while (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4);
}
cg::this_grid().sync(); // all threads wait for main
}
__device__ void clear_flag(
volatile int* wait_flag
)
{
cg::this_grid().sync(); // wait for all threads in kernel to finish
bool is_main_thread = (blockIdx.x == 0 && threadIdx.x == 0) ? true : false;
if (is_main_thread) {
register int r1, r2, r3, r4;
r1 = 0; r2 = 0; r3 = 0; r4 = 0;
asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(wait_flag), "r"(r1), "r"(r2), "r"(r3), "r"(r4) : "memory");
}
}
template<class T, bool is_HWC>
#if __CUDA_ARCH__ >= 700
__launch_bounds__(128, 16)
#endif
__global__ void push_pull_halos_1d_kernel(
// top halo,
const T* toh, int toh_stride_C, int toh_stride_H, int toh_stride_W, // top output halo
T* tox, int tox_stride_C, int tox_stride_H, int tox_stride_W, // top output tx buffer
T* tix, int tix_stride_C, int tix_stride_H, int tix_stride_W, // top input tx buffer
T* tih, int tih_stride_C, int tih_stride_H, int tih_stride_W, // top input halo
// btm halo
const T* boh, int boh_stride_C, int boh_stride_H, int boh_stride_W, // btm output halo
T* box, int box_stride_C, int box_stride_H, int box_stride_W, // btm output tx buffer
T* bix, int bix_stride_C, int bix_stride_H, int bix_stride_W, // btm input tx buffer
T* bih, int bih_stride_C, int bih_stride_H, int bih_stride_W, // btm input halo
// dimensions
int NC, int NH, int NW,
// signals
int* signal1_flag,
int* signal2_flag,
int* wait1_flag,
int* wait2_flag
)
{
// push top output halo to transfer buffer
strided_copy_kernel<T,is_HWC>(tox, tox_stride_C, tox_stride_H, tox_stride_W, toh, toh_stride_C, toh_stride_H, toh_stride_W, NC, NH, NW);
// push btm output halo to transfer buffer
strided_copy_kernel<T,is_HWC>(box, box_stride_C, box_stride_H, box_stride_W, boh, boh_stride_C, boh_stride_H, boh_stride_W, NC, NH, NW);
// signal to top and btm neighbors that output halos are ready to be read
// the choice of values for v1-v4 is arbitrary and does not matter, as long as all ranks use the same values
checked_signal(signal1_flag, signal2_flag, -987751720, 840868300, -225529332, 281513358);
// pull top halo from transfer buffer in peer memory to input
wait_for(wait1_flag, -987751720, 840868300, -225529332, 281513358);
strided_copy_kernel<T,is_HWC>(tih, tih_stride_C, tih_stride_H, tih_stride_W, tix, tix_stride_C, tix_stride_H, tix_stride_W, NC, NH, NW);
clear_flag(wait1_flag);
// pull btm halo from transfer buffer in peer memory to input
wait_for(wait2_flag, -987751720, 840868300, -225529332, 281513358);
strided_copy_kernel<T,is_HWC>(bih, bih_stride_C, bih_stride_H, bih_stride_W, bix, bix_stride_C, bix_stride_H, bix_stride_W, NC, NH, NW);
clear_flag(wait2_flag);
}
__global__ void delay_kernel(int delay_nanoseconds, int* counter)
{
if (blockIdx.x == 0 && threadIdx.x == 0) {
// waste time while doing something the compiler can't predict, thus preventing it from optimizing away this code.
int new_counter = 0;
double elapsed = 0;
clock_t start = clock();
do {
clock_t now = clock();
elapsed = (double)(now - start)*1e9 / CLOCKS_PER_SEC;
++new_counter;
} while (elapsed < (double)delay_nanoseconds);
*counter = new_counter;
}
}
}
namespace apex { namespace contrib { namespace peer_memory {
int64_t allocate_raw(int64_t size)
{
float* ptr = 0L;
cudaMalloc(&ptr, size);
cudaMemset(ptr, 0, size);
return (int64_t)ptr;
}
void free_raw(int64_t raw)
{
cudaFree((void*)raw);
}
void zero(int64_t raw, int64_t size)
{
cudaMemset((void*)raw, 0, size);
}
at::Tensor get_raw_ipc_address(int64_t raw)
{
cudaIpcMemHandle_t mem_handle;
CUDACHECK( cudaIpcGetMemHandle(&mem_handle, (void*)raw) );
const int n = sizeof(cudaIpcMemHandle_t);
auto address_tensor = torch::empty({n}, torch::dtype(torch::kUInt8));
auto address_tensor_p = address_tensor.data_ptr<uint8_t>();
memcpy(address_tensor_p, (uint8_t*)&mem_handle, n);
return address_tensor;
}
std::vector<int64_t> get_raw_peers(at::Tensor ipc_addresses, int peer_rank, int64_t raw)
{
int peer_group_size = ipc_addresses.size(0);
std::vector<int64_t> results(peer_group_size);
for (int i = 0; i < peer_group_size; ++i) {
if (i != peer_rank) {
cudaIpcMemHandle_t mem_handle;
memcpy(&mem_handle, ipc_addresses.index({i}).data_ptr<uint8_t>(), sizeof(cudaIpcMemHandle_t));
void* p = 0L;
CUDACHECK( cudaIpcOpenMemHandle((void**)&p, mem_handle, cudaIpcMemLazyEnablePeerAccess) );
results[i] = (int64_t)p;
} else {
results[i] = (int64_t)raw;
}
}
return results;
}
at::Tensor blob_view_half(int64_t raw, std::vector<int64_t> shape, bool channels_last)
{
return blob_view<at::Half>((at::Half*)raw, shape, torch::dtype(torch::kFloat16).device(torch::kCUDA), channels_last);
}
at::Tensor blob_view_float(int64_t raw, std::vector<int64_t> shape, bool channels_last)
{
return blob_view<float>((float*)raw, shape, torch::dtype(torch::kFloat32).device(torch::kCUDA), channels_last);
}
at::Tensor blob_view_int(int64_t raw, std::vector<int64_t> shape, bool channels_last)
{
return blob_view<int>((int*)raw, shape, torch::dtype(torch::kInt32).device(torch::kCUDA), channels_last);
}
void push_pull_halos_1d(
bool diagnostics,
bool explicit_nhwc,
int numSM, // number of SMs to use
at::Tensor top_out_halo, // top output halo in sender device memory
at::Tensor top_out_tx, // top output transfer buffer in sender peer pool memory
at::Tensor top_inp_tx, // top input transfer buffer in top neighbor peer pool memory
at::Tensor top_inp_halo, // top input halo in receiver device memory
at::Tensor btm_out_halo, // btm output halo in sender device memory
at::Tensor btm_out_tx, // btm output transfer buffer in sender peer pool memory
at::Tensor btm_inp_tx, // btm input transfer buffer in btm neighbor peer pool memory
at::Tensor btm_inp_halo, // btm input halo in receiver device memory
at::Tensor top_signal, // top input signal in receiver device memory
at::Tensor btm_signal, // btm input signal in receiver device memory
at::Tensor waits // top and btm signals for this rank
)
{
// basic checks of inputs
TORCH_CHECK(top_out_halo.is_cuda());
TORCH_CHECK(top_out_tx.is_cuda());
TORCH_CHECK(top_inp_tx.is_cuda());
TORCH_CHECK(top_inp_halo.is_cuda());
TORCH_CHECK(btm_out_halo.is_cuda());
TORCH_CHECK(btm_out_tx.is_cuda());
TORCH_CHECK(btm_inp_tx.is_cuda());
TORCH_CHECK(btm_inp_halo.is_cuda());
TORCH_CHECK(top_signal.is_cuda());
TORCH_CHECK(btm_signal.is_cuda());
TORCH_CHECK(waits.is_cuda());
// shapes and strides
int toh_N, toh_C, toh_H, toh_W;
tensor_shape(top_out_halo, explicit_nhwc, toh_N, toh_C, toh_H, toh_W);
int tox_N, tox_C, tox_H, tox_W;
tensor_shape(top_out_tx, explicit_nhwc, tox_N, tox_C, tox_H, tox_W);
int tix_N, tix_C, tix_H, tix_W;
tensor_shape(top_inp_tx, explicit_nhwc, tix_N, tix_C, tix_H, tix_W);
int tih_N, tih_C, tih_H, tih_W;
tensor_shape(top_inp_halo, explicit_nhwc, tih_N, tih_C, tih_H, tih_W);
TORCH_CHECK(
(toh_N == tox_N && tox_N == tix_N && tix_N == tih_N) &&
(toh_C == tox_C && tox_C == tix_C && tix_C == tih_C) &&
(toh_H == tox_H && tox_H == tix_H && tix_H == tih_H) &&
(toh_W == tox_W && tox_W == tix_W && tix_W == tih_W));
int boh_N, boh_C, boh_H, boh_W;
tensor_shape(btm_out_halo, explicit_nhwc, boh_N, boh_C, boh_H, boh_W);
int box_N, box_C, box_H, box_W;
tensor_shape(btm_out_tx, explicit_nhwc, box_N, box_C, box_H, box_W);
int bix_N, bix_C, bix_H, bix_W;
tensor_shape(btm_inp_tx, explicit_nhwc, bix_N, bix_C, bix_H, bix_W);
int bih_N, bih_C, bih_H, bih_W;
tensor_shape(btm_inp_halo, explicit_nhwc, bih_N, bih_C, bih_H, bih_W);
TORCH_CHECK(
(boh_N == box_N && box_N == bix_N && bix_N == bih_N) &&
(boh_C == box_C && box_C == bix_C && bix_C == bih_C) &&
(boh_H == box_H && box_H == bix_H && bix_H == bih_H) &&
(boh_W == box_W && box_W == bix_W && bix_W == bih_W));
TORCH_CHECK(
(toh_N == boh_N) &&
(toh_C == boh_C) &&
(toh_H == boh_H) &&
(toh_W == boh_W));
int NC=toh_C, NH=toh_H, NW=toh_W;
if (diagnostics) printf("NC=%d, NH=%d, NW=%d\n",NC,NH,NW);
int toh_stride_N, toh_stride_C, toh_stride_H, toh_stride_W;
tensor_strides(top_out_halo, explicit_nhwc, toh_stride_N, toh_stride_C, toh_stride_H, toh_stride_W);
int tox_stride_N, tox_stride_C, tox_stride_H, tox_stride_W;
tensor_strides(top_out_tx, explicit_nhwc, tox_stride_N, tox_stride_C, tox_stride_H, tox_stride_W);
int tix_stride_N, tix_stride_C, tix_stride_H, tix_stride_W;
tensor_strides(top_inp_tx, explicit_nhwc, tix_stride_N, tix_stride_C, tix_stride_H, tix_stride_W);
int tih_stride_N, tih_stride_C, tih_stride_H, tih_stride_W;
tensor_strides(top_inp_halo, explicit_nhwc, tih_stride_N, tih_stride_C, tih_stride_H, tih_stride_W);
int boh_stride_N, boh_stride_C, boh_stride_H, boh_stride_W;
tensor_strides(btm_out_halo, explicit_nhwc, boh_stride_N, boh_stride_C, boh_stride_H, boh_stride_W);
int box_stride_N, box_stride_C, box_stride_H, box_stride_W;
tensor_strides(btm_out_tx, explicit_nhwc, box_stride_N, box_stride_C, box_stride_H, box_stride_W);
int bix_stride_N, bix_stride_C, bix_stride_H, bix_stride_W;
tensor_strides(btm_inp_tx, explicit_nhwc, bix_stride_N, bix_stride_C, bix_stride_H, bix_stride_W);
int bih_stride_N, bih_stride_C, bih_stride_H, bih_stride_W;
tensor_strides(btm_inp_halo, explicit_nhwc, bih_stride_N, bih_stride_C, bih_stride_H, bih_stride_W);
// determine if nhwc
auto is_nhwc = (toh_stride_C == 1) ? true : false;
if (diagnostics) printf("is_nhwc = %s\n",is_nhwc?"true":"false");
// figure out launch parameters
int device;
cudaGetDevice(&device);
cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, device);
assert(numSM > 0 && numSM <= prop.multiProcessorCount);
auto current_stream = at::cuda::getCurrentCUDAStream();
const int numThreads = 128;
dim3 block(numThreads,1,1);
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, top_out_halo.scalar_type(), "push_pull_halos_1d_kernel", [&]{
if (diagnostics) printf("size(scalar_t) = %ld\n",sizeof(scalar_t));
scalar_t* toh_p = top_out_halo.data_ptr<scalar_t>();
scalar_t* tox_p = top_out_tx.data_ptr<scalar_t>();
scalar_t* tix_p = top_inp_tx.data_ptr<scalar_t>();
scalar_t* tih_p = top_inp_halo.data_ptr<scalar_t>();
scalar_t* boh_p = btm_out_halo.data_ptr<scalar_t>();
scalar_t* box_p = btm_out_tx.data_ptr<scalar_t>();
scalar_t* bix_p = btm_inp_tx.data_ptr<scalar_t>();
scalar_t* bih_p = btm_inp_halo.data_ptr<scalar_t>();
if (diagnostics) printf("waypoint1\n");
int* top_signal_p = top_signal.data_ptr<int>() + 4;
int* btm_signal_p = btm_signal.data_ptr<int>();
int* top_wait_p = waits.data_ptr<int>();
int* btm_wait_p = waits.data_ptr<int>() + 4;
if (diagnostics) printf("waypoint2\n");
// do int4 vector loads if channel count permits
int elem_size_in_bytes = toh_C * sizeof(scalar_t);
int elem_size_in_int4 = (elem_size_in_bytes / 16);
if (diagnostics) printf("elem_size_in_bytes = %d, elem_size_in_int4 = %d\n",elem_size_in_bytes,elem_size_in_int4);
if (is_nhwc && elem_size_in_int4*16 == elem_size_in_bytes) {
// can do int4 transfers
int divisor = toh_C / elem_size_in_int4;
if (diagnostics) printf("CAN DO INT4 :: divisor = %d\n",divisor);
toh_stride_N /= divisor; toh_stride_H /= divisor; toh_stride_W /= divisor;
tox_stride_N /= divisor; tox_stride_H /= divisor; tox_stride_W /= divisor;
tix_stride_N /= divisor; tix_stride_H /= divisor; tix_stride_W /= divisor;
tih_stride_N /= divisor; tih_stride_H /= divisor; tih_stride_W /= divisor;
boh_stride_N /= divisor; boh_stride_H /= divisor; boh_stride_W /= divisor;
box_stride_N /= divisor; box_stride_H /= divisor; box_stride_W /= divisor;
bix_stride_N /= divisor; bix_stride_H /= divisor; bix_stride_W /= divisor;
bih_stride_N /= divisor; bih_stride_H /= divisor; bih_stride_W /= divisor;
NC /= divisor;
if (diagnostics) {
printf("divisor=%d\n",divisor);
printf("toh_stride :: N=%d, C=%d, H=%d, W=%d\n",toh_stride_N,toh_stride_C,toh_stride_H,toh_stride_W);
printf("tox_stride :: N=%d, C=%d, H=%d, W=%d\n",tox_stride_N,tox_stride_C,tox_stride_H,tox_stride_W);
printf("tix_stride :: N=%d, C=%d, H=%d, W=%d\n",tix_stride_N,tix_stride_C,tix_stride_H,tix_stride_W);
printf("tih_stride :: N=%d, C=%d, H=%d, W=%d\n",tih_stride_N,tih_stride_C,tih_stride_H,tih_stride_W);
printf("boh_stride :: N=%d, C=%d, H=%d, W=%d\n",boh_stride_N,boh_stride_C,boh_stride_H,boh_stride_W);
printf("box_stride :: N=%d, C=%d, H=%d, W=%d\n",box_stride_N,box_stride_C,box_stride_H,box_stride_W);
printf("bix_stride :: N=%d, C=%d, H=%d, W=%d\n",bix_stride_N,bix_stride_C,bix_stride_H,bix_stride_W);
printf("bih_stride :: N=%d, C=%d, H=%d, W=%d\n",bih_stride_N,bih_stride_C,bih_stride_H,bih_stride_W);
printf("NC=%d, NH=%d, NW=%d\n",NC,NH,NW);
}
void *kernelArgs[] = {
(int4**)&toh_p, &toh_stride_C, &toh_stride_H, &toh_stride_W,
(int4**)&tox_p, &tox_stride_C, &tox_stride_H, &tox_stride_W,
(int4**)&tix_p, &tix_stride_C, &tix_stride_H, &tix_stride_W,
(int4**)&tih_p, &tih_stride_C, &tih_stride_H, &tih_stride_W,
(int4**)&boh_p, &boh_stride_C, &boh_stride_H, &boh_stride_W,
(int4**)&box_p, &box_stride_C, &box_stride_H, &box_stride_W,
(int4**)&bix_p, &bix_stride_C, &bix_stride_H, &bix_stride_W,
(int4**)&bih_p, &bih_stride_C, &bih_stride_H, &bih_stride_W,
&NC, &NH, &NW,
&top_signal_p, &btm_signal_p, &top_wait_p, &btm_wait_p
};
int numBlocksPerSm;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<int4,true>, numThreads, 0);
dim3 grid(numSM*numBlocksPerSm,1,1);
cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4,true>, grid, block, kernelArgs, 0, current_stream);
} else {
// cannot do int4 transfers
if (diagnostics) printf("CAN NOT DO INT4\n");
void *kernelArgs[] = {
&toh_p, &toh_stride_C, &toh_stride_H, &toh_stride_W,
&tox_p, &tox_stride_C, &tox_stride_H, &tox_stride_W,
&tix_p, &tix_stride_C, &tix_stride_H, &tix_stride_W,
&tih_p, &tih_stride_C, &tih_stride_H, &tih_stride_W,
&boh_p, &boh_stride_C, &boh_stride_H, &boh_stride_W,
&box_p, &box_stride_C, &box_stride_H, &box_stride_W,
&bix_p, &bix_stride_C, &bix_stride_H, &bix_stride_W,
&bih_p, &bih_stride_C, &bih_stride_H, &bih_stride_W,
&NC, &NH, &NW,
&top_signal_p, &btm_signal_p, &top_wait_p, &btm_wait_p
};
int numBlocksPerSm;
if (is_nhwc) {
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t,true>, numThreads, 0);
dim3 grid(numSM*numBlocksPerSm,1,1);
cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,true>, grid, block, kernelArgs, 0, current_stream);
} else {
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t,false>, numThreads, 0);
dim3 grid(numSM*numBlocksPerSm,1,1);
cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,false>, grid, block, kernelArgs, 0, current_stream);
}
}
} );
}
} } }
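# Illustrative sketch of the int4 fast path chosen in push_pull_halos_1d above:
# when the per-pixel channel footprint (C * sizeof(dtype)) is a multiple of 16
# bytes and the layout is NHWC, the kernel reinterprets each group of `divisor`
# channels as one int4 and scales NC and the N/H/W strides down by that factor.
# Plain Python arithmetic only; element sizes are taken from torch.
import torch

def int4_divisor(C, dtype):
    elem_size_in_bytes = C * torch.tensor([], dtype=dtype).element_size()
    elem_size_in_int4 = elem_size_in_bytes // 16
    if elem_size_in_int4 * 16 != elem_size_in_bytes:
        return 1                      # cannot vectorize, fall back to scalar copies
    return C // elem_size_in_int4     # channels folded into each int4 load

assert int4_divisor(64, torch.float16) == 8    # 64 * 2B = 128B -> 8 halves per int4
assert int4_divisor(32, torch.float32) == 4    # 32 * 4B = 128B -> 4 floats per int4
assert int4_divisor(3, torch.float16) == 1     # 6B is not a multiple of 16B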
/**
* Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <torch/extension.h>
#ifndef _peer_memory_h_
#define _peer_memory_h_
namespace apex { namespace contrib { namespace peer_memory {
int64_t allocate_raw(int64_t size);
void free_raw(int64_t raw);
void zero(int64_t raw, int64_t size);
at::Tensor get_raw_ipc_address(int64_t raw);
std::vector<int64_t> get_raw_peers(at::Tensor ipc_addresses, int peer_rank, int64_t raw);
at::Tensor blob_view_half(int64_t raw, std::vector<int64_t> shape, bool channels_last);
at::Tensor blob_view_float(int64_t raw, std::vector<int64_t> shape, bool channels_last);
at::Tensor blob_view_int(int64_t raw, std::vector<int64_t> shape, bool channels_last);
void push_pull_halos_1d(
bool diagnostics,
bool explicit_nhwc,
int numSM, // number of SMs to use
at::Tensor top_out_halo, // top output halo in sender device memory
at::Tensor top_out_tx, // top output transfer buffer in sender peer pool memory
at::Tensor top_inp_tx, // top input transfer buffer in top neighbor peer pool memory
at::Tensor top_inp_halo, // top input halo in receiver device memory
at::Tensor btm_out_halo, // btm output halo in sender device memory
at::Tensor btm_out_tx, // btm output transfer buffer in sender peer pool memory
at::Tensor btm_inp_tx, // btm input transfer buffer in btm neighbor peer pool memory
at::Tensor btm_inp_halo, // btm input halo in receiver device memory
at::Tensor top_signal, // top input signal in receiver device memory
at::Tensor btm_signal, // btm input signal in receiver device memory
at::Tensor waits // top and btm signals for this rank
);
} } }
#endif
from .peer_memory import PeerMemoryPool
from .peer_halo_exchanger_1d import PeerHaloExchanger1d
import torch
from apex.contrib.peer_memory import PeerMemoryPool, PeerHaloExchanger1d
import peer_memory_cuda as pm
# How to run:
# torchrun --nproc_per_node <num-GPU> <this-python-prog>
# <num-GPU> must be a power of 2 greater than 1.
# Output of this function is used as ground truth in module tests.
def nccl_halo_ex(peer_rank, peer_group_size, y, half_halo, explicit_nhwc, H_split):
if explicit_nhwc:
if H_split:
_, Hp, _, _ = list(y.shape)
H = Hp - 2*half_halo
top_out_halo = y[:,half_halo:2*half_halo,:,:]
top_inp_halo = y[:,:half_halo,:,:]
btm_out_halo = y[:,H:H+half_halo,:,:]
btm_inp_halo = y[:,H+half_halo:H+2*half_halo,:,:]
else:
_, _, Wp, _ = list(y.shape)
W = Wp - 2*half_halo
top_out_halo = y[:,:,half_halo:2*half_halo,:]
top_inp_halo = y[:,:,:half_halo,:]
btm_out_halo = y[:,:,W:W+half_halo,:]
btm_inp_halo = y[:,:,W+half_halo:W+2*half_halo,:]
else:
if H_split:
_, _, Hp, _ = list(y.shape)
H = Hp - 2*half_halo
top_out_halo = y[:,:,half_halo:2*half_halo,:]
top_inp_halo = y[:,:,:half_halo,:]
btm_out_halo = y[:,:,H:H+half_halo,:]
btm_inp_halo = y[:,:,H+half_halo:H+2*half_halo,:]
else:
_, _, _, Wp = list(y.shape)
W = Wp - 2*half_halo
top_out_halo = y[:,:,:,half_halo:2*half_halo]
top_inp_halo = y[:,:,:,:half_halo]
btm_out_halo = y[:,:,:,W:W+half_halo]
btm_inp_halo = y[:,:,:,W+half_halo:W+2*half_halo]
top_out_halo = top_out_halo.clone(memory_format=torch.preserve_format)
btm_out_halo = btm_out_halo.clone(memory_format=torch.preserve_format)
top_inp_halos = [torch.empty_like(top_out_halo) for _ in range(peer_group_size)]
torch.distributed.all_gather(top_inp_halos, top_out_halo)
btm_inp_halos = [torch.empty_like(btm_out_halo) for _ in range(peer_group_size)]
torch.distributed.all_gather(btm_inp_halos, btm_out_halo)
top_rank = (peer_rank + peer_group_size - 1) % peer_group_size
btm_rank = (peer_rank + 1) % peer_group_size
top_inp_halo.copy_(btm_inp_halos[top_rank])
btm_inp_halo.copy_(top_inp_halos[btm_rank])
def single_test(peer_rank, peer_group_size, halo_ex, C, H, W, half_halo, dtype, memory_format, H_split, num_steps, numSM=1):
if memory_format == 1:
# 1 -> explicit nhwc
explicit_nhwc = True
if H_split:
y = torch.randn([1,H+2*half_halo,W,C], dtype=dtype, device='cuda')
ym = y[:,half_halo:H+half_halo,:,:]
else:
y = torch.randn([1,H,W+2*half_halo,C], dtype=dtype, device='cuda')
ym = y[:,:,half_halo:W+half_halo,:]
else:
# 2 -> native nhwc
# 3 -> nchw
explicit_nhwc = False
if H_split:
y = torch.randn([1,C,H+2*half_halo,W], dtype=dtype, device='cuda')
if memory_format == 2:
y = y.to(memory_format=torch.channels_last)
ym = y[:,:,half_halo:H+half_halo,:]
else:
y = torch.randn([1,C,H,W+2*half_halo], dtype=dtype, device='cuda')
if memory_format == 2:
y = y.to(memory_format=torch.channels_last)
ym = y[:,:,:,half_halo:W+half_halo]
y3 = y.clone()
list_y = []
for step in range(num_steps):
halo_ex(y, H_split, explicit_nhwc, numSM)
list_y.append(y.clone())
y.copy_(y3)
halo_ex.peer_pool.reset()
torch.distributed.barrier()
y2 = y3.clone()
list_y2 = []
for step in range(num_steps):
nccl_halo_ex(peer_rank, peer_group_size, y2, half_halo, explicit_nhwc, H_split)
list_y2.append(y2.clone())
y2.copy_(y3)
is_equal = [torch.all(torch.eq(yy,yy2)) for yy,yy2 in zip(list_y,list_y2)]
is_equal = torch.tensor(is_equal, dtype=torch.bool)
is_equal = torch.all(is_equal)
if peer_rank == 0:
if memory_format == 1:
memory_format_str = "explicit_nhwc"
elif memory_format == 2:
memory_format_str = "native nhwc"
elif memory_format == 3:
memory_format_str = "nchw"
else:
memory_format_str = "???"
if is_equal:
print("SUCCESS : N,C,H,W = 1,%d,%d,%d, half_halo=%d, %s, %s, %s" % (C,H,W,half_halo,str(dtype),memory_format_str,"H-split" if H_split else "W-split"))
else:
print("FAILURE : N,C,H,W = 1,%d,%d,%d, half_halo=%d, %s, %s, %s" % (C,H,W,half_halo,str(dtype),memory_format_str,"H-split" if H_split else "W-split"))
# peer memory flag sync relies on there being at least one barrier per step
torch.distributed.barrier()
def H_split_tests(N, C, H, W, half_halo, rank, world_size, halo_ex, num_steps):
Hr = 8*world_size
Hp = ((H + Hr - 1) // Hr) * 8
for i in range(4):
div = int(pow(2,i))
single_test(rank, world_size, halo_ex, C*div, Hp//div, W//div, half_halo, torch.float16, 1, True, num_steps)
single_test(rank, world_size, halo_ex, C*div, Hp//div, W//div, half_halo, torch.float16, 2, True, num_steps)
single_test(rank, world_size, halo_ex, C*div, Hp//div, W//div, half_halo, torch.float16, 3, True, num_steps)
def W_split_tests(N, C, H, W, half_halo, rank, world_size, halo_ex, num_steps):
Wr = 8*world_size
Wp = ((W + Wr - 1) // Wr) * 8
for i in range(4):
div = int(pow(2,i))
single_test(rank, world_size, halo_ex, C*div, H//div, Wp//div, half_halo, torch.float16, 1, False, num_steps)
single_test(rank, world_size, halo_ex, C*div, H//div, Wp//div, half_halo, torch.float16, 2, False, num_steps)
single_test(rank, world_size, halo_ex, C*div, H//div, Wp//div, half_halo, torch.float16, 3, False, num_steps)
def main():
# for this simple example, peer_rank == rank and peer_group_size == world_size (all ranks form a single peer group)
torch.distributed.init_process_group("nccl")
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
torch.cuda.set_device(rank)
pool = PeerMemoryPool(rank, world_size, world_size, 64*1024, 2*1024*1024)
num_steps = 100
half_halo = 1
halo_ex = PeerHaloExchanger1d(rank, world_size, pool, half_halo)
H_split_tests(1,64,336,200, half_halo,rank,world_size,halo_ex,num_steps)
W_split_tests(1,64,200,336, half_halo,rank,world_size,halo_ex,num_steps)
if __name__ == "__main__":
main()
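# Hypothetical launch sketch (script name and GPU count are assumptions, not part of this test):
#   torchrun --nproc_per_node=<num_p2p_gpus> peer_halo_exchange_module_test.py
# Rank 0 prints one SUCCESS/FAILURE line per (C, H, W, dtype, memory format, split) combination.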
import torch
from apex.contrib.peer_memory import PeerMemoryPool
import peer_memory_cuda as pm
class PeerHaloExchanger1d:
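    """1D halo exchange between neighboring ranks over direct CUDA peer-to-peer copies.

    Each rank stages its top and bottom output halos in transfer buffers drawn from a
    shared PeerMemoryPool; the fused kernel then pushes them to the two ring neighbors
    and pulls the neighbors' halos into this rank's input halo regions. A small int32
    signal tensor per rank is used to synchronize the peers.
    """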
def __init__(self, rank, peer_group_size, peer_pool, half_halo):
self.peer_group_size = peer_group_size
self.peer_rank = rank % peer_group_size
self.peer_pool = peer_pool
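        # Synchronization flags are allocated from the static pool (dynamic=False) so
        # they survive peer_pool.reset() calls between exchanges.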
self.signals = peer_pool.allocate_peer_tensors([2,4], torch.int32, False, False)
self.signals[self.peer_rank].zero_()
self.half_halo = half_halo
def __call__(self, y, H_split=True, explicit_nhwc=False, numSM=1, diagnostics=False):
channels_last = y.is_contiguous(memory_format=torch.channels_last) and not explicit_nhwc
if H_split:
if explicit_nhwc:
_, Hs, _, _ = list(y.shape)
H = Hs - 2*self.half_halo
top_out_halo = y[:,self.half_halo:2*self.half_halo,:,:]
top_tx = self.peer_pool.allocate_peer_tensors(list(top_out_halo.shape), top_out_halo.dtype, False, True)
top_inp_halo = y[:,:self.half_halo,:,:]
btm_out_halo = y[:,H:H+self.half_halo,:,:]
btm_tx = self.peer_pool.allocate_peer_tensors(list(btm_out_halo.shape), btm_out_halo.dtype, False, True)
btm_inp_halo = y[:,H+self.half_halo:H+2*self.half_halo,:,:]
else:
_, _, Hs, _ = list(y.shape)
H = Hs - 2*self.half_halo
top_out_halo = y[:,:,self.half_halo:2*self.half_halo,:]
top_tx = self.peer_pool.allocate_peer_tensors(list(top_out_halo.shape), top_out_halo.dtype, channels_last, True)
top_inp_halo = y[:,:,:self.half_halo,:]
btm_out_halo = y[:,:,H:H+self.half_halo,:]
btm_tx = self.peer_pool.allocate_peer_tensors(list(btm_out_halo.shape), btm_out_halo.dtype, channels_last, True)
btm_inp_halo = y[:,:,H+self.half_halo:H+2*self.half_halo,:]
else:
if explicit_nhwc:
_, _, Ws, _ = list(y.shape)
W = Ws - 2*self.half_halo
top_out_halo = y[:,:,self.half_halo:2*self.half_halo,:]
top_tx = self.peer_pool.allocate_peer_tensors(list(top_out_halo.shape), top_out_halo.dtype, False, True)
top_inp_halo = y[:,:,:self.half_halo,:]
btm_out_halo = y[:,:,W:W+self.half_halo,:]
btm_tx = self.peer_pool.allocate_peer_tensors(list(btm_out_halo.shape), btm_out_halo.dtype, False, True)
btm_inp_halo = y[:,:,W+self.half_halo:W+2*self.half_halo,:]
else:
_, _, _, Ws = list(y.shape)
W = Ws - 2*self.half_halo
top_out_halo = y[:,:,:,self.half_halo:2*self.half_halo]
top_tx = self.peer_pool.allocate_peer_tensors(list(top_out_halo.shape), top_out_halo.dtype, channels_last, True)
top_inp_halo = y[:,:,:,:self.half_halo]
btm_out_halo = y[:,:,:,W:W+self.half_halo]
btm_tx = self.peer_pool.allocate_peer_tensors(list(btm_out_halo.shape), btm_out_halo.dtype, channels_last, True)
btm_inp_halo = y[:,:,:,W+self.half_halo:W+2*self.half_halo]
top_neighbor = (self.peer_rank + self.peer_group_size - 1) % self.peer_group_size
btm_neighbor = (self.peer_rank + 1) % self.peer_group_size
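        # Push this rank's output halos into its transfer buffers and pull the ring
        # neighbors' halos into this rank's input halo slices; the signal tensors
        # coordinate when each peer's data is ready to be read.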
pm.push_pull_halos_1d(
diagnostics, explicit_nhwc, numSM,
top_out_halo, top_tx[self.peer_rank], btm_tx[top_neighbor], top_inp_halo,
btm_out_halo, btm_tx[self.peer_rank], top_tx[btm_neighbor], btm_inp_halo,
self.signals[top_neighbor], self.signals[btm_neighbor], self.signals[self.peer_rank]
)
import torch
import numpy as np
import peer_memory_cuda as pm
class PeerMemoryPool(object):
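    """A slab of CUDA device memory that is IPC-mapped across all ranks of a peer group.

    The slab is split into a static region (long-lived allocations such as
    synchronization flags) and a dynamic region that can be recycled with reset().
    allocate_peer_tensors returns one tensor view per peer rank, all at the same
    offset, so each rank can address the corresponding buffer on every peer.
    """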
def __init__(self, rank, world_size, peer_group_size, static_size, dynamic_size):
self.peer_group = rank // peer_group_size
self.peer_rank = rank % peer_group_size
self.peer_group_size = peer_group_size
self.alignment = 256
self.static_size = ((static_size + self.alignment - 1) // self.alignment) * self.alignment
self.dynamic_size = ((dynamic_size + self.alignment - 1) // self.alignment) * self.alignment
# allocate giant pool of device memory
self.raw = pm.allocate_raw(self.static_size+self.dynamic_size)
# exchange peer pointers with nccl
raw_ipc = pm.get_raw_ipc_address(self.raw).cuda()
peer_raw_ipcs = [torch.empty_like(raw_ipc) for _ in range(world_size)]
torch.distributed.all_gather(peer_raw_ipcs, raw_ipc)
peer_raw_ipcs = torch.stack(peer_raw_ipcs).cpu()
self.peer_raw = pm.get_raw_peers(peer_raw_ipcs, self.peer_rank, self.raw)
self.static_offset = 0
self.dynamic_offset = 0
def __del__(self):
pm.free_raw(self.raw)
def reset(self):
self.dynamic_offset = 0
def allocate_peer_tensors(self, shape, dtype, channels_last, dynamic):
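        # Reserve nels * elem_size bytes at the next aligned offset of either the static
        # or the dynamic region and return one blob view per peer rank.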
nels = np.prod(shape)
if dtype == torch.float16:
elem_size = 2
if dynamic:
start = ((self.dynamic_offset + self.alignment - 1) // self.alignment) * self.alignment
self.dynamic_offset = start + nels * elem_size
assert(self.dynamic_offset < self.dynamic_size), "Dynamic peer memory pool exhausted"
return [pm.blob_view_half(pr + self.static_size + start, shape, channels_last) for pr in self.peer_raw]
else:
start = ((self.static_offset + self.alignment - 1) // self.alignment) * self.alignment
self.static_offset = start + nels * elem_size
assert(self.static_offset < self.static_size), "Static peer memory pool exhausted"
return [pm.blob_view_half(pr + start, shape, channels_last) for pr in self.peer_raw]
elif dtype == torch.float32:
elem_size = 4
if dynamic:
start = ((self.dynamic_offset + self.alignment - 1) // self.alignment) * self.alignment
self.dynamic_offset = start + nels * elem_size
assert(self.dynamic_offset < self.dynamic_size), "Dynamic peer memory pool exhausted"
return [pm.blob_view_float(pr + self.static_size + start, shape, channels_last) for pr in self.peer_raw]
else:
start = ((self.static_offset + self.alignment - 1) // self.alignment) * self.alignment
self.static_offset = start + nels * elem_size
assert(self.static_offset < self.static_size), "Static peer memory pool exhausted"
return [pm.blob_view_float(pr + start, shape, channels_last) for pr in self.peer_raw]
elif dtype == torch.int32:
elem_size = 4
if dynamic:
start = ((self.dynamic_offset + self.alignment - 1) // self.alignment) * self.alignment
self.dynamic_offset = start + nels * elem_size
assert(self.dynamic_offset < self.dynamic_size), "Dynamic peer memory pool exhausted"
return [pm.blob_view_int(pr + self.static_size + start, shape, channels_last) for pr in self.peer_raw]
else:
start = ((self.static_offset + self.alignment - 1) // self.alignment) * self.alignment
self.static_offset = start + nels * elem_size
assert(self.static_offset < self.static_size), "Static peer memory pool exhausted"
return [pm.blob_view_int(pr + start, shape, channels_last) for pr in self.peer_raw]
else:
assert(False), "dtype %s not supported" % (str(dtype))
...@@ -652,6 +652,34 @@ if "--fast_bottleneck" in sys.argv:
)
)
if "--peer_memory" in sys.argv:
sys.argv.remove("--peer_memory")
raise_if_cuda_home_none("--peer_memory")
ext_modules.append(
CUDAExtension(
name="peer_memory_cuda",
sources=[
"apex/contrib/csrc/peer_memory/peer_memory_cuda.cu",
"apex/contrib/csrc/peer_memory/peer_memory.cpp",
],
extra_compile_args={"cxx": ["-O3"] + version_dependent_macros + generator_flag},
)
)
if "--nccl_p2p" in sys.argv:
sys.argv.remove("--nccl_p2p")
raise_if_cuda_home_none("--nccl_p2p")
ext_modules.append(
CUDAExtension(
name="nccl_p2p_cuda",
sources=[
"apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu",
"apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp",
],
extra_compile_args={"cxx": ["-O3"] + version_dependent_macros + generator_flag},
)
)
if "--fused_conv_bias_relu" in sys.argv: if "--fused_conv_bias_relu" in sys.argv:
sys.argv.remove("--fused_conv_bias_relu") sys.argv.remove("--fused_conv_bias_relu")
......
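# Hypothetical install sketch (the exact pip invocation is an assumption, following the
# pattern used for other apex contrib extensions rather than anything stated in this commit):
#   pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" \
#       --global-option="--peer_memory" --global-option="--nccl_p2p" .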