Unverified commit fed20d2a authored by Thor Johnsen, committed by GitHub

Merge pull request #1340 from NVIDIA/peer_memory

Peer memory halo exchange
parents d89f5e66 5698eeeb
from .bottleneck import Bottleneck, SpatialBottleneck
from .halo_exchangers import HaloExchangerNoComm, HaloExchangerAllGather, HaloExchangerSendRecv, HaloExchangerPeer
import functools as func
import torch
import torch.distributed as dist
from torch import nn
import fast_bottleneck
import nccl_p2p_cuda as inc
def kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'):
weight_tensor_nchw = tensor
nn.init.kaiming_uniform_(weight_tensor_nchw, a=a, mode=mode, nonlinearity=nonlinearity)
class FrozenBatchNorm2d(torch.nn.Module):
def compute_scale_bias_one(nhwc, weight, bias, running_mean, running_var, w_scale, w_bias):
scale = weight * running_var.rsqrt()
bias = bias - running_mean * scale
w_scale.copy_(scale)
w_bias.copy_(bias)
def compute_scale_bias_method(nhwc, args):
for arg in args:
# arg is tuple of (weight, bias, running_mean, running_var, w_scale, w_bias)
compute_scale_bias_one(nhwc, *arg)
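Note: compute_scale_bias_one above folds each frozen batch-norm's running statistics into a single per-channel scale and shift, so inference reduces to one fused multiply-add. A minimal, CPU-only sketch of the algebra it relies on; the tensor names below are illustrative and not part of the diff:
# Illustrative check (not part of the diff) of the scale/bias folding used above.
import torch

C = 8
x = torch.randn(2, C, 4, 4)
weight, bias = torch.randn(C), torch.randn(C)
running_mean, running_var = torch.randn(C), torch.rand(C) + 0.5

# Reference frozen batch-norm (no eps, matching compute_scale_bias_one).
ref = (x - running_mean.view(1, -1, 1, 1)) * running_var.rsqrt().view(1, -1, 1, 1) \
      * weight.view(1, -1, 1, 1) + bias.view(1, -1, 1, 1)

# Folded form: y = x * scale + (bias - running_mean * scale).
scale = weight * running_var.rsqrt()
shift = bias - running_mean * scale
out = x * scale.view(1, -1, 1, 1) + shift.view(1, -1, 1, 1)

assert torch.allclose(ref, out, atol=1e-6)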
class FrozenBatchNorm2d(torch.jit.ScriptModule):
"""
BatchNorm2d where the batch statistics and the affine parameters are fixed
"""
@@ -18,7 +31,9 @@ class FrozenBatchNorm2d(torch.nn.Module):
self.register_buffer("running_mean", torch.zeros(n))
self.register_buffer("running_var", torch.ones(n))
def get_scale_bias(self, nhwc=False):
@torch.jit.script_method
def get_scale_bias(self, nhwc):
# type: (bool) -> List[torch.Tensor]
scale = self.weight * self.running_var.rsqrt()
bias = self.bias - self.running_mean * scale
if nhwc:
@@ -29,11 +44,11 @@ class FrozenBatchNorm2d(torch.nn.Module):
bias = bias.reshape(1, -1, 1, 1)
return scale, bias
@torch.jit.script_method
def forward(self, x):
scale, bias = self.get_scale_bias()
scale, bias = self.get_scale_bias(False)
return x * scale + bias
@torch.jit.script
def drelu_dscale1(grad_o, output, scale1):
relu_mask = (output>0)
@@ -147,6 +162,7 @@ class Bottleneck(torch.nn.Module):
self.bn1 = norm_func(bottleneck_channels)
self.bn2 = norm_func(bottleneck_channels)
self.bn3 = norm_func(out_channels)
self.w_scale = None
self.use_cudnn = use_cudnn
@@ -170,23 +186,47 @@ class Bottleneck(torch.nn.Module):
for p in self.parameters():
with torch.no_grad():
p.data = p.data.permute(0,2,3,1).contiguous()
return
# Returns a single callable that recomputes scale and bias for all frozen batch-norms.
# This method must be called before CUDA graph capture.
# The callable it returns can be called at any time.
# Calling this method prevents scale and bias from being recomputed on every forward call.
def get_scale_bias_callable(self):
self.w_scale, self.w_bias, args = [], [], []
batch_norms = [self.bn1, self.bn2, self.bn3]
if self.downsample is not None:
batch_norms.append(self.downsample[1])
for bn in batch_norms:
s = torch.empty_like(bn.weight)
b = torch.empty_like(s)
args.append( (bn.weight, bn.bias, bn.running_mean, bn.running_var, s, b) )
if self.explicit_nhwc:
self.w_scale.append( s.reshape(1, 1, 1, -1) )
self.w_bias.append( b.reshape(1, 1, 1, -1) )
else:
self.w_scale.append( s.reshape(1, -1, 1, 1) )
self.w_bias.append( b.reshape(1, -1, 1, 1) )
return func.partial(compute_scale_bias_method, self.explicit_nhwc, args)
def forward(self, x):
if self.use_cudnn:
# calculate scale/bias from registered buffers
# TODO: make this better
s1, b1 = self.bn1.get_scale_bias(self.explicit_nhwc)
s2, b2 = self.bn2.get_scale_bias(self.explicit_nhwc)
s3, b3 = self.bn3.get_scale_bias(self.explicit_nhwc)
w_scale = [s1, s2, s3]
w_bias = [b1, b2, b3]
if self.downsample is not None:
s4, b4 = self.downsample[1].get_scale_bias(self.explicit_nhwc)
w_scale.append(s4)
w_bias.append(b4)
out = bottleneck_function(self.explicit_nhwc, self.stride, w_scale, w_bias, x, *self.w_conv)
if self.w_scale is None:
# calculate scale/bias from registered buffers
# TODO: make this better
s1, b1 = self.bn1.get_scale_bias(self.explicit_nhwc)
s2, b2 = self.bn2.get_scale_bias(self.explicit_nhwc)
s3, b3 = self.bn3.get_scale_bias(self.explicit_nhwc)
w_scale = [s1, s2, s3]
w_bias = [b1, b2, b3]
if self.downsample is not None:
s4, b4 = self.downsample[1].get_scale_bias(self.explicit_nhwc)
w_scale.append(s4)
w_bias.append(b4)
out = bottleneck_function(self.explicit_nhwc, self.stride, w_scale, w_bias, x, *self.w_conv)
else:
out = bottleneck_function(self.explicit_nhwc, self.stride, self.w_scale, self.w_bias, x, *self.w_conv)
return out
if self.explicit_nhwc:
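For context, the w_scale/w_bias caching added above is what makes the block CUDA-graph friendly: get_scale_bias_callable allocates fixed buffers, and the callable it returns refreshes them outside the captured graph. A hedged usage sketch, assuming a CUDA build of apex with the contrib bottleneck extension; channel counts and shapes below are made up:
# Hedged usage sketch (not part of the diff); shapes and channel counts are made up.
import torch
from apex.contrib.bottleneck import Bottleneck

block = Bottleneck(64, 64, 64, use_cudnn=True, explicit_nhwc=True)
block = block.to(dtype=torch.float16, device='cuda')
x = torch.randn(1, 200, 336, 64, dtype=torch.float16, device='cuda', requires_grad=True)

recompute_scale_bias = block.get_scale_bias_callable()  # must be called before graphing
recompute_scale_bias()                                   # fill the cached buffers once
graphed = torch.cuda.make_graphed_callables(block, (x,))

# If the frozen batch-norm buffers ever change (e.g. a checkpoint is loaded),
# refresh the cached scale/bias without re-capturing the graph.
recompute_scale_bias()
y = graphed(x)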
@@ -217,7 +257,12 @@ class Bottleneck(torch.nn.Module):
class SpatialBottleneckFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, spatial_group_size, local_rank, comm, stream1, nhwc, stride_1x1, scale, bias, x, *conv):
def forward(ctx, spatial_group_size, spatial_group_rank, spatial_communicator, spatial_halo_exchanger, spatial_method, use_delay_kernel, explicit_nhwc, stride_1x1, scale, bias, thresholdTop, thresholdBottom, x, *conv):
if spatial_group_size > 1:
stream1 = spatial_halo_exchanger.stream1
stream2 = spatial_halo_exchanger.stream2
stream3 = spatial_halo_exchanger.stream3
# TODO: clean up order of tensors
args = [x, *conv[0:3], *scale[0:3], *bias[0:3]]
ctx.downsample = len(conv) > 3
@@ -226,59 +271,152 @@ class SpatialBottleneckFunction(torch.autograd.Function):
args.append(scale[3])
args.append(bias[3])
# weight buffers are always in nhwc while shape can be nhwc or channels_last
# weight buffers are always in explicit_nhwc while shape can be explicit_nhwc or channels_last
# here we pass in a flag and let the C++ side handle it
# alternatively, we can put all sizes into a fixed format and pass it in
outputs = fast_bottleneck.forward_init(nhwc, stride_1x1, args)
fast_bottleneck.forward_out1(nhwc, stride_1x1, args, outputs)
fast_bottleneck.forward_out2(nhwc, stride_1x1, args, outputs)
outputs = fast_bottleneck.forward_init(explicit_nhwc, stride_1x1, args)
fast_bottleneck.forward_out1(explicit_nhwc, stride_1x1, args, outputs)
# do halo exchange for outputs[0] (out1)
# compute halo cells for outputs[1]
if spatial_group_size > 1:
out1 = outputs[0]
N,Hs,W,C = list(out1.shape)
if explicit_nhwc:
N,Hs,W,C = list(out1.shape)
memory_format = torch.contiguous_format
out1_pad = torch.empty([N,Hs+2,W,C], dtype=out1.dtype, device='cuda')
else:
N,C,Hs,W = list(out1.shape)
memory_format = torch.channels_last if out1.is_contiguous(memory_format=torch.channels_last) else torch.contiguous_format
out1_pad = torch.empty([N,C,Hs+2,W], dtype=out1.dtype, device='cuda', memory_format=memory_format)
stream1.wait_stream(torch.cuda.current_stream())
if spatial_method != 2: stream3.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(stream1):
# copy halos to send buffer
send_halos = torch.empty((N,2,W,C),dtype=out1.dtype,device=out1.device)
send_halos[:,:1,:,:].copy_(out1[:,:1,:,:])
send_halos[:,1:,:,:].copy_(out1[:,Hs-1:,:,:])
all_halos = torch.empty((N,2*spatial_group_size,W,C),dtype=out1.dtype,device=out1.device)
all_halos = [all_halos[:,i*2:(i+1)*2,:,:] for i in range(spatial_group_size)]
dist.all_gather(all_halos,send_halos,group=comm)
fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
top_out1_halo = all_halos[(spatial_group_size+local_rank-1)%spatial_group_size][:,1:,:,:]
if local_rank > 0:
fat_halo[:,:1,:,:].copy_(top_out1_halo)
fat_halo[:,1:3,:,:].copy_(out1[:,:2,:,:])
top_out2 = fast_bottleneck.forward_out2_halo(nhwc, fat_halo, args)
btm_out1_halo = all_halos[(local_rank+1)%spatial_group_size][:,:1,:,:]
if local_rank < spatial_group_size-1:
fat_halo[:,0:2,:,:].copy_(out1[:,Hs-2:,:,:])
fat_halo[:,2:,:,:].copy_(btm_out1_halo)
btm_out2 = fast_bottleneck.forward_out2_halo(nhwc, fat_halo, args)
if explicit_nhwc:
top_out1_halo = out1_pad[:,:1,:,:]
btm_out1_halo = out1_pad[:,Hs+1:Hs+2,:,:]
spatial_halo_exchanger.left_right_halo_exchange(out1[:,:1,:,:], out1[:,Hs-1:,:,:], top_out1_halo, btm_out1_halo)
else:
top_out1_halo = out1_pad[:,:,:1,:]
btm_out1_halo = out1_pad[:,:,Hs+1:Hs+2,:]
spatial_halo_exchanger.left_right_halo_exchange(out1[:,:,:1,:], out1[:,:,Hs-1:,:], top_out1_halo, btm_out1_halo)
if spatial_method == 1:
# overlap mid convolution with halo transfer
if spatial_group_rank < spatial_group_size-1:
stream2.wait_stream(stream1)
with torch.cuda.stream(stream2):
if explicit_nhwc:
btm_fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
btm_fat_halo[:,0:2,:,:].copy_(out1[:,Hs-2:,:,:])
btm_fat_halo[:,2:,:,:].copy_(btm_out1_halo)
else:
btm_fat_halo = torch.empty((N,C,3,W),dtype=out1.dtype,device=out1.device)
btm_fat_halo[:,:,0:2,:].copy_(out1[:,:,Hs-2:,:])
btm_fat_halo[:,:,2:,:].copy_(btm_out1_halo)
btm_out2 = fast_bottleneck.forward_out2_halo(explicit_nhwc, btm_fat_halo, args)
if spatial_group_rank > 0:
with torch.cuda.stream(stream1):
if explicit_nhwc:
top_fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
top_fat_halo[:,:1,:,:].copy_(top_out1_halo)
top_fat_halo[:,1:3,:,:].copy_(out1[:,:2,:,:])
else:
top_fat_halo = torch.empty((N,C,3,W),dtype=out1.dtype,device=out1.device)
top_fat_halo[:,:,:1,:].copy_(top_out1_halo)
top_fat_halo[:,:,1:3,:].copy_(out1[:,:,:2,:])
top_out2 = fast_bottleneck.forward_out2_halo(explicit_nhwc, top_fat_halo, args)
if use_delay_kernel: inc.add_delay(10)
elif spatial_method != 2 and spatial_method != 3:
assert(False), "spatial_method must be 1, 2 or 3"
if spatial_group_size <= 1:
fast_bottleneck.forward_out2(explicit_nhwc, stride_1x1, args, outputs)
elif spatial_method == 1:
fast_bottleneck.forward_out2(explicit_nhwc, stride_1x1, args, outputs)
with torch.cuda.stream(stream3):
if explicit_nhwc:
out1_pad[:,1:Hs+1,:,:].copy_(out1)
else:
out1_pad[:,:,1:Hs+1,:].copy_(out1)
elif spatial_method == 2:
# wait for halo transfer to finish before doing a full convolution of padded x
if explicit_nhwc:
out1_pad[:,1:Hs+1,:,:].copy_(out1)
else:
out1_pad[:,:,1:Hs+1,:].copy_(out1)
torch.cuda.current_stream().wait_stream(stream1)
fast_bottleneck.forward_out2_pad(explicit_nhwc, stride_1x1, args, outputs, out1_pad)
elif spatial_method == 3:
fast_bottleneck.forward_out2_mask(explicit_nhwc, stride_1x1, args, outputs, thresholdTop, thresholdBottom)
with torch.cuda.stream(stream3):
if explicit_nhwc:
out1_pad[:,1:Hs+1,:,:].copy_(out1)
else:
out1_pad[:,:,1:Hs+1,:].copy_(out1)
# compute halo cells for outputs[1] (out2)
if spatial_group_size > 1:
out2 = outputs[1]
if local_rank > 0:
out2[:,:1,:,:].copy_(top_out2)
if local_rank < spatial_group_size-1:
out2[:,Hs-1:,:,:].copy_(btm_out2)
if explicit_nhwc:
top_out2_halo = out2[:,:1,:,:]
btm_out2_halo = out2[:,Hs-1:,:,:]
else:
top_out2_halo = out2[:,:,:1,:]
btm_out2_halo = out2[:,:,Hs-1:,:]
if spatial_method == 1:
if spatial_group_rank > 0:
torch.cuda.current_stream().wait_stream(stream1)
top_out2_halo.copy_(top_out2)
if spatial_group_rank < spatial_group_size-1:
torch.cuda.current_stream().wait_stream(stream2)
btm_out2_halo.copy_(btm_out2)
elif spatial_method == 3:
# Note
# out2 halo correction cannot overlap with anything since it has
# to wait for out2_mask to finish, but itself has to finish before
# the first kernel of _forward_rest can launch.
# At least we can overlap the two halo correction kernels.
if spatial_group_rank < spatial_group_size-1:
stream2.wait_stream(stream1) # wait for halo transfers to finish
stream2.wait_stream(torch.cuda.current_stream()) # wait for *_out2_mask to finish
with torch.cuda.stream(stream2):
w1by3 = args[2][:,2:3,:,:].clone()
btm_out1_halo = btm_out1_halo.clone()
btm_out2 = fast_bottleneck.forward_out2_halo_corr(explicit_nhwc, btm_out1_halo, args, w1by3, btm_out2_halo.clone())
btm_out2_halo.copy_(btm_out2)
if spatial_group_rank > 0:
stream1.wait_stream(torch.cuda.current_stream()) # wait for *_out2_mask to finish
with torch.cuda.stream(stream1):
w1by3 = args[2][:,:1,:,:].clone()
top_out1_halo = top_out1_halo.clone()
top_out2 = fast_bottleneck.forward_out2_halo_corr(explicit_nhwc, top_out1_halo, args, w1by3, top_out2_halo.clone())
top_out2_halo.copy_(top_out2)
if spatial_group_rank < spatial_group_size-1:
torch.cuda.current_stream().wait_stream(stream2)
if spatial_group_rank > 0:
torch.cuda.current_stream().wait_stream(stream1)
fast_bottleneck.forward_rest(nhwc, stride_1x1, args, outputs)
fast_bottleneck.forward_rest(explicit_nhwc, stride_1x1, args, outputs)
# save halos for backward pass
if spatial_group_size > 1:
ctx.save_for_backward(*(args+outputs+[top_out1_halo,btm_out1_halo]))
if spatial_method != 2:
# make sure copy of mid-section of out1 into out1_pad is done before exiting
torch.cuda.current_stream().wait_stream(stream3)
ctx.save_for_backward(*(args+outputs+[out1_pad,]))
else:
ctx.save_for_backward(*(args+outputs))
# save relu outputs for drelu
ctx.nhwc = nhwc
ctx.explicit_nhwc = explicit_nhwc
ctx.stride_1x1 = stride_1x1
ctx.spatial_group_size = spatial_group_size
ctx.local_rank = local_rank
ctx.comm = comm
ctx.stream1 = stream1
if spatial_group_size > 1:
ctx.spatial_group_rank = spatial_group_rank
ctx.spatial_halo_exchanger = spatial_halo_exchanger
ctx.spatial_method = spatial_method
ctx.use_delay_kernel = use_delay_kernel
ctx.thresholdTop = thresholdTop
ctx.thresholdBottom = thresholdBottom
ctx.stream1 = stream1
ctx.stream2 = stream2
ctx.stream3 = stream3
return outputs[2]
# backward relu is not exposed; an elementwise MUL with the relu mask is used instead
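A note on the structure of the forward pass above: the halo convolutions run on stream1/stream2 while the bulk convolution stays on the current stream, and wait_stream re-joins them before the halo rows are written back into out2. The same producer/consumer pattern in isolation, with generic matmuls standing in for the real kernels:
# Stripped-down sketch of the stream-overlap pattern used in forward above.
import torch

x = torch.randn(1024, 1024, device='cuda')
halo = torch.empty(2, 2, device='cuda')                 # result buffer, like out1_pad above
side_stream = torch.cuda.Stream()

side_stream.wait_stream(torch.cuda.current_stream())    # halo work must see x
with torch.cuda.stream(side_stream):
    halo.copy_(x[:2].mm(x.t()[:, :2]))                  # stand-in for forward_out2_halo

main = x.mm(x.t())                                      # bulk work overlaps with side_stream

torch.cuda.current_stream().wait_stream(side_stream)    # re-join before consuming halo
main[:2, :2].add_(halo)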
@@ -286,9 +424,8 @@ class SpatialBottleneckFunction(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_o):
if ctx.spatial_group_size > 1:
top_out1_halo = ctx.saved_tensors[-2]
btm_out1_halo = ctx.saved_tensors[-1]
outputs = ctx.saved_tensors[-5:-2]
out1_pad = ctx.saved_tensors[-1]
outputs = ctx.saved_tensors[-4:-1]
else:
outputs = ctx.saved_tensors[-3:]
@@ -310,58 +447,79 @@ class SpatialBottleneckFunction(torch.autograd.Function):
if ctx.downsample:
t_list.append(ctx.saved_tensors[10])
grads = fast_bottleneck.backward_init(ctx.nhwc, ctx.stride_1x1, t_list)
grad_out2 = fast_bottleneck.backward_grad_out2(ctx.nhwc, ctx.stride_1x1, t_list, grads)
# compute wgrad2 for internal cells
wgrad2 = fast_bottleneck.backward_wgrad2(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
# apply wgrad2 halos
if ctx.spatial_group_size > 1:
if ctx.local_rank > 0:
top_grad2_halo = grad_out2[:,:1,:,:]
top_wgrad2_halo = fast_bottleneck.backward_wgrad2_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, top_out1_halo, top_grad2_halo)
wgrad2[:,:1,:,:].add_(top_wgrad2_halo)
if ctx.local_rank < ctx.spatial_group_size-1:
btm_grad2_halo = grad_out2[:,-1:,:,:]
btm_wgrad2_halo = fast_bottleneck.backward_wgrad2_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, btm_out1_halo, btm_grad2_halo)
wgrad2[:,-1:,:,:].add_(btm_wgrad2_halo)
grads = fast_bottleneck.backward_init(ctx.explicit_nhwc, ctx.stride_1x1, t_list)
wgrad3_stream = torch.cuda.Stream()
wgrad3_stream.wait_stream(torch.cuda.current_stream())
grad_out2 = fast_bottleneck.backward_grad_out2(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads)
wgrad2_stream = torch.cuda.Stream()
wgrad2_stream.wait_stream(torch.cuda.current_stream())
# do halo exchange of grad_out2 here
# compute halo cells for grad_out1
if ctx.spatial_group_size > 1:
N,Hs,W,C = list(grad_out2.shape)
if ctx.explicit_nhwc:
N,Hs,W,C = list(grad_out2.shape)
else:
N,C,Hs,W = list(grad_out2.shape)
relu1 = t_list[12]
ctx.stream1.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(ctx.stream1):
top_halo, btm_halo = ctx.spatial_halo_exchanger.left_right_halo_exchange(grad_out2[:,:1,:,:], grad_out2[:,Hs-1:,:,:])
# copy halos to send buffer
send_halos = torch.empty((N,2,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
send_halos[:,:1,:,:].copy_(grad_out2[:,:1,:,:])
send_halos[:,1:,:,:].copy_(grad_out2[:,Hs-1:,:,:])
all_halos = torch.empty((N,2*ctx.spatial_group_size,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
all_halos = [all_halos[:,i*2:(i+1)*2,:,:] for i in range(ctx.spatial_group_size)]
dist.all_gather(all_halos,send_halos,group=ctx.comm)
relu1 = t_list[12]
fat_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
relu_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
if ctx.local_rank > 0:
top_halo = all_halos[ctx.local_rank-1][:,1:,:,:]
fat_halo[:,:1,:,:].copy_(top_halo)
fat_halo[:,1:,:,:].copy_(grad_out2[:,:2,:,:])
relu_halo[:,:1,:,:].zero_()
relu_halo[:,1:,:,:].copy_(relu1[:,:2,:,:])
top_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, fat_halo, relu_halo)
top_grad_out1_halo = top_grad_out1_halo[:,1:2,:,:]
if ctx.local_rank < ctx.spatial_group_size-1:
btm_halo = all_halos[ctx.local_rank+1][:,:1,:,:]
fat_halo[:,:2,:,:].copy_(grad_out2[:,Hs-2:,:,:])
fat_halo[:,2:,:,:].copy_(btm_halo)
relu_halo[:,:2,:,:].copy_(relu1[:,Hs-2:,:,:])
relu_halo[:,2:,:,:].zero_()
btm_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, fat_halo, relu_halo)
btm_grad_out1_halo = btm_grad_out1_halo[:,1:2,:,:]
if ctx.spatial_method == 1 or ctx.spatial_method == 2:
# 1 -> halo recompute approach
# 2 -> wait for concatenated halos, then do single conv on full input (not implemented yet for bprop)
if ctx.spatial_group_rank < ctx.spatial_group_size-1:
ctx.stream2.wait_stream(ctx.stream1)
with torch.cuda.stream(ctx.stream2):
if ctx.explicit_nhwc:
btm_fat_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
btm_fat_halo[:,:2,:,:].copy_(grad_out2[:,Hs-2:,:,:])
btm_fat_halo[:,2:,:,:].copy_(btm_halo)
btm_fat_relu_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
btm_fat_relu_halo[:,:2,:,:].copy_(relu1[:,Hs-2:,:,:])
btm_fat_relu_halo[:,2:,:,:].zero_()
else:
btm_fat_halo = torch.empty((N,C,3,W),dtype=grad_out2.dtype,device=grad_out2.device)
btm_fat_halo[:,:,:2,:].copy_(grad_out2[:,:,Hs-2:,:])
btm_fat_halo[:,:,2:,:].copy_(btm_halo)
btm_fat_relu_halo = torch.empty((N,C,3,W),dtype=grad_out2.dtype,device=grad_out2.device)
btm_fat_relu_halo[:,:,:2,:].copy_(relu1[:,:,Hs-2:,:])
btm_fat_relu_halo[:,:,2:,:].zero_()
btm_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, btm_fat_halo, btm_fat_relu_halo)
if ctx.explicit_nhwc:
btm_grad_out1_halo = btm_grad_out1_halo[:,1:2,:,:]
else:
btm_grad_out1_halo = btm_grad_out1_halo[:,:,1:2,:]
if ctx.spatial_group_rank > 0:
with torch.cuda.stream(ctx.stream1):
if ctx.explicit_nhwc:
top_fat_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
top_fat_halo[:,:1,:,:].copy_(top_halo)
top_fat_halo[:,1:,:,:].copy_(grad_out2[:,:2,:,:])
top_fat_relu_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
top_fat_relu_halo[:,:1,:,:].zero_()
top_fat_relu_halo[:,1:,:,:].copy_(relu1[:,:2,:,:])
else:
top_fat_halo = torch.empty((N,C,3,W),dtype=grad_out2.dtype,device=grad_out2.device)
top_fat_halo[:,:,:1,:].copy_(top_halo)
top_fat_halo[:,:,1:,:].copy_(grad_out2[:,:,:2,:])
top_fat_relu_halo = torch.empty((N,C,3,W),dtype=grad_out2.dtype,device=grad_out2.device)
top_fat_relu_halo[:,:,:1,:].zero_()
top_fat_relu_halo[:,:,1:,:].copy_(relu1[:,:,:2,:])
top_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, top_fat_halo, top_fat_relu_halo)
if ctx.explicit_nhwc:
top_grad_out1_halo = top_grad_out1_halo[:,1:2,:,:]
else:
top_grad_out1_halo = top_grad_out1_halo[:,:,1:2,:]
if ctx.use_delay_kernel: inc.add_delay(10)
elif ctx.spatial_method != 3:
assert(False), "spatial_method must be 1, 2 or 3"
# compute grad_out1 for internal cells
grad_out1 = fast_bottleneck.backward_grad_out1(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
if ctx.spatial_group_size <= 1 or ctx.spatial_method == 1 or ctx.spatial_method == 2:
grad_out1 = fast_bottleneck.backward_grad_out1(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
elif ctx.spatial_group_size > 1 and ctx.spatial_method == 3:
grad_out1 = fast_bottleneck.backward_grad_out1_mask(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2, ctx.thresholdTop, ctx.thresholdBottom)
# apply halo cells to grad_out1
if ctx.spatial_group_size > 1:
@@ -369,17 +527,69 @@ class SpatialBottleneckFunction(torch.autograd.Function):
z = t_list[4]
relu1 = t_list[12]
#print("w.shape = %s, z.shape = %s, relu1.shape = %s" % (str(list(w.shape)), str(list(z.shape)), str(list(relu1.shape))))
torch.cuda.current_stream().wait_stream(ctx.stream1)
if ctx.local_rank > 0:
grad_out1[:,:1,:,:].copy_(top_grad_out1_halo)
#print("ctx.local_rank = %d, apply grad_out1 top halo (grad_out1.shape = %s)" % (ctx.local_rank, str(list(grad_out1.shape))))
if ctx.local_rank < ctx.spatial_group_size-1:
grad_out1[:,Hs-1:,:,:].copy_(btm_grad_out1_halo)
#print("ctx.local_rank = %d, apply grad_out1 btm halo (grad_out1.shape = %s)" % (ctx.local_rank, str(list(grad_out1.shape))))
fast_bottleneck.backward_rest(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2, grad_out1, wgrad2)
if ctx.spatial_method == 1 or ctx.spatial_method == 2:
if ctx.spatial_group_rank < ctx.spatial_group_size-1:
torch.cuda.current_stream().wait_stream(ctx.stream2)
if ctx.explicit_nhwc:
grad_out1[:,Hs-1:,:,:].copy_(btm_grad_out1_halo)
else:
grad_out1[:,:,Hs-1:,:].copy_(btm_grad_out1_halo)
#print("ctx.spatial_group_rank = %d, apply grad_out1 btm halo (grad_out1.shape = %s)" % (ctx.spatial_group_rank, str(list(grad_out1.shape))))
if ctx.spatial_group_rank > 0:
torch.cuda.current_stream().wait_stream(ctx.stream1)
if ctx.explicit_nhwc:
grad_out1[:,:1,:,:].copy_(top_grad_out1_halo)
else:
grad_out1[:,:,:1,:].copy_(top_grad_out1_halo)
#print("ctx.spatial_group_rank = %d, apply grad_out1 top halo (grad_out1.shape = %s)" % (ctx.spatial_group_rank, str(list(grad_out1.shape))))
elif ctx.spatial_method == 3:
if ctx.spatial_group_rank < ctx.spatial_group_size-1:
if ctx.explicit_nhwc:
btm_relu_halo = relu1[:,Hs-1:,:,:].clone()
btm_grad_out1 = grad_out1[:,Hs-1:,:,:]
else:
btm_relu_halo = relu1[:,:,Hs-1:,:].clone()
btm_grad_out1 = grad_out1[:,:,Hs-1:,:]
w1by3 = w[:,:1,:,:].clone()
ctx.stream2.wait_stream(ctx.stream1) # wait for halo transfers to finish
ctx.stream2.wait_stream(torch.cuda.current_stream()) # wait for backward_grad_out1_mask to finish before launching halo correction kernel
with torch.cuda.stream(ctx.stream2):
btm_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo_corr(ctx.explicit_nhwc, ctx.stride_1x1, t_list, w1by3, grads, btm_halo, btm_relu_halo, btm_grad_out1.clone())
btm_grad_out1.copy_(btm_grad_out1_halo)
if ctx.spatial_group_rank > 0:
if ctx.explicit_nhwc:
top_relu_halo = relu1[:,:1,:,:].clone()
top_grad_out1 = grad_out1[:,:1,:,:]
else:
top_relu_halo = relu1[:,:,:1,:].clone()
top_grad_out1 = grad_out1[:,:,:1,:]
w1by3 = w[:,2:,:,:].clone()
ctx.stream1.wait_stream(torch.cuda.current_stream()) # wait for backward_grad_out1_mask to finish before launching halo correction kernel
with torch.cuda.stream(ctx.stream1):
top_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo_corr(ctx.explicit_nhwc, ctx.stride_1x1, t_list, w1by3, grads, top_halo, top_relu_halo, top_grad_out1.clone())
top_grad_out1.copy_(top_grad_out1_halo)
if ctx.spatial_group_rank < ctx.spatial_group_size-1:
torch.cuda.current_stream().wait_stream(ctx.stream2) # wait for halo correction to finish
if ctx.spatial_group_rank > 0:
torch.cuda.current_stream().wait_stream(ctx.stream1)
wgrad1_stream = torch.cuda.Stream()
wgrad1_stream.wait_stream(torch.cuda.current_stream())
fast_bottleneck.backward_rest(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2, grad_out1)
with torch.cuda.stream(wgrad3_stream):
fast_bottleneck.backward_wgrad3(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads)
with torch.cuda.stream(wgrad2_stream):
if ctx.spatial_group_size > 1:
fast_bottleneck.backward_wgrad2_pad(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, out1_pad, grad_out2)
else:
fast_bottleneck.backward_wgrad2(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
with torch.cuda.stream(wgrad1_stream):
fast_bottleneck.backward_wgrad1(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out1)
torch.cuda.current_stream().wait_stream(wgrad3_stream)
torch.cuda.current_stream().wait_stream(wgrad2_stream)
torch.cuda.current_stream().wait_stream(wgrad1_stream)
return (None, None, None, None, None, None, None, None, *grads)
return (None, None, None, None, None, None, None, None, None, None, None, None, *grads)
spatial_bottleneck_function = SpatialBottleneckFunction.apply
@@ -393,7 +603,7 @@ class SpatialBottleneck(torch.nn.Module):
def __init__(self, in_channels, bottleneck_channels, out_channels, stride=1, groups=1,
dilation=1, norm_func=None, use_cudnn=False, explicit_nhwc=False,
spatial_group_size=1, communicator=None):
spatial_parallel_args=None):
super(SpatialBottleneck, self).__init__()
if groups != 1:
raise RuntimeError('Only support groups == 1')
@@ -422,6 +632,7 @@ class SpatialBottleneck(torch.nn.Module):
self.bn1 = norm_func(bottleneck_channels)
self.bn2 = norm_func(bottleneck_channels)
self.bn3 = norm_func(out_channels)
self.w_scale = None
self.use_cudnn = use_cudnn
@@ -434,6 +645,8 @@ class SpatialBottleneck(torch.nn.Module):
for w in self.w_conv:
kaiming_uniform_(w, a=1)
self.thresholdTop, self.thresholdBottom = None, None
# TODO: prevent unsupported case usage
# supported cases
# native cudnn
@@ -447,43 +660,59 @@ class SpatialBottleneck(torch.nn.Module):
p.data = p.data.permute(0,2,3,1).contiguous()
# spatial communicator
self.spatial_group_size = spatial_group_size
if spatial_group_size > 1:
world_size = dist.get_world_size()
num_groups = world_size // spatial_group_size
assert(num_groups*spatial_group_size == world_size), "torch.distributed.get_world_size() must be a multiple of spatial_group_size"
rank = dist.get_rank()
self.local_rank = rank % spatial_group_size
if communicator is None:
for group in range(num_groups):
ranks = list(range(group*spatial_group_size,(group+1)*spatial_group_size))
comm = torch.distributed.new_group(ranks=ranks)
if rank in ranks:
self.communicator = comm
else:
self.communicator = communicator
self.stream1 = torch.cuda.Stream()
self.spatial_args = self.spatial_group_size, self.local_rank, self.communicator, self.stream1
if spatial_parallel_args is None:
self.spatial_parallel_args = (1, 0, None, None, 0, False)
else:
self.spatial_args = 1, 0, None, None
self.spatial_parallel_args = spatial_parallel_args
return
# Returns a single callable that recomputes scale and bias for all frozen batch-norms.
# This method must be called before CUDA graph capture.
# The callable it returns can be called at any time.
# Calling this method prevents scale and bias from being recomputed on every forward call.
def get_scale_bias_callable(self):
self.w_scale, self.w_bias, args = [], [], []
batch_norms = [self.bn1, self.bn2, self.bn3]
if self.downsample is not None:
batch_norms.append(self.downsample[1])
for bn in batch_norms:
s = torch.empty_like(bn.weight)
b = torch.empty_like(s)
args.append( (bn.weight, bn.bias, bn.running_mean, bn.running_var, s, b) )
if self.explicit_nhwc:
self.w_scale.append( s.reshape(1, 1, 1, -1) )
self.w_bias.append( b.reshape(1, 1, 1, -1) )
else:
self.w_scale.append( s.reshape(1, -1, 1, 1) )
self.w_bias.append( b.reshape(1, -1, 1, 1) )
return func.partial(compute_scale_bias_method, self.explicit_nhwc, args)
def forward(self, x):
if self.use_cudnn:
# calculate scale/bias from registered buffers
# TODO: make this better
s1, b1 = self.bn1.get_scale_bias(self.explicit_nhwc)
s2, b2 = self.bn2.get_scale_bias(self.explicit_nhwc)
s3, b3 = self.bn3.get_scale_bias(self.explicit_nhwc)
w_scale = [s1, s2, s3]
w_bias = [b1, b2, b3]
if self.downsample is not None:
s4, b4 = self.downsample[1].get_scale_bias(self.explicit_nhwc)
w_scale.append(s4)
w_bias.append(b4)
out = spatial_bottleneck_function(*self.spatial_args, self.explicit_nhwc, self.stride, w_scale, w_bias, x, *self.w_conv)
if self.thresholdTop is None:
spatial_group_size, spatial_group_rank, _, _, _, _ = self.spatial_parallel_args
if self.explicit_nhwc:
N,H,W,C = list(x.shape)
else:
N,C,H,W = list(x.shape)
self.thresholdTop = torch.tensor([1 if spatial_group_rank > 0 else 0], dtype=torch.int32, device='cuda')
self.thresholdBottom = torch.tensor([H-2 if spatial_group_rank < spatial_group_size - 1 else H-1], dtype=torch.int32, device='cuda')
if self.w_scale is None:
# calculate scale/bias from registered buffers
# TODO: make this better
s1, b1 = self.bn1.get_scale_bias(self.explicit_nhwc)
s2, b2 = self.bn2.get_scale_bias(self.explicit_nhwc)
s3, b3 = self.bn3.get_scale_bias(self.explicit_nhwc)
w_scale = [s1, s2, s3]
w_bias = [b1, b2, b3]
if self.downsample is not None:
s4, b4 = self.downsample[1].get_scale_bias(self.explicit_nhwc)
w_scale.append(s4)
w_bias.append(b4)
out = spatial_bottleneck_function(*self.spatial_parallel_args, self.explicit_nhwc, self.stride, w_scale, w_bias, self.thresholdTop, self.thresholdBottom, x, *self.w_conv)
else:
out = spatial_bottleneck_function(*self.spatial_parallel_args, self.explicit_nhwc, self.stride, self.w_scale, self.w_bias, self.thresholdTop, self.thresholdBottom, x, *self.w_conv)
return out
if self.explicit_nhwc:
@@ -510,3 +739,4 @@ class SpatialBottleneck(torch.nn.Module):
out = self.relu(out)
return out
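For reference, the six-element spatial_parallel_args tuple consumed above is (spatial_group_size, spatial_group_rank, spatial_communicator, spatial_halo_exchanger, spatial_method, use_delay_kernel), matching the forward signature of SpatialBottleneckFunction. A hedged wiring sketch, assuming torch.distributed is already initialized with NCCL and the apex contrib extensions (fast_bottleneck, nccl_p2p) are built; the exchanger choice and channel counts are examples only:
# Hedged sketch (not part of the diff): wiring a SpatialBottleneck for an H-split
# spatial group spanning all ranks. Values below are examples.
import torch
import torch.distributed as dist
from apex.contrib.bottleneck import SpatialBottleneck, HaloExchangerSendRecv

rank, world_size = dist.get_rank(), dist.get_world_size()
spatial_group_size = world_size
halex = HaloExchangerSendRecv(world_size, spatial_group_size, rank, None)

spatial_parallel_args = (
    spatial_group_size,          # how many ranks split the H dimension
    rank % spatial_group_size,   # this rank's position inside the split
    None,                        # spatial communicator (unused by this exchanger)
    halex,                       # halo exchanger instance
    1,                           # spatial_method: 1 = overlap halo conv with main conv
    False,                       # use_delay_kernel
)
block = SpatialBottleneck(64, 64, 64, use_cudnn=True, explicit_nhwc=True,
                          spatial_parallel_args=spatial_parallel_args)
block = block.to(dtype=torch.float16, device='cuda')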
import os
import torch
from maskrcnn_benchmark.modeling.backbone.resnet import Bottleneck
from maskrcnn_benchmark.layers.nhwc import nhwc_to_nchw_transform, nchw_to_nhwc_transform
from maskrcnn_benchmark.layers.nhwc.batch_norm import FrozenBatchNorm2d_NHWC
from apex.contrib.bottleneck import Bottleneck as FastBottleneck
from apex.contrib.bottleneck import SpatialBottleneck
from apex.contrib.bottleneck import Bottleneck, SpatialBottleneck
from apex.contrib.bottleneck import HaloExchangerNoComm, HaloExchangerAllGather, HaloExchangerSendRecv, HaloExchangerPeer
from apex.contrib.peer_memory import PeerMemoryPool
def single_module_test(ref, rank, world_size, numtype, device, shape, fast, spatial_group_size, in_channels, bottleneck_channels, out_channels, num_groups, stride_in_1x1, stride, dilation, norm_func, nhwc):
# inputs + modules
def ground_truth_bottleneck(C, dtype, explicit_nhwc):
bottleneck = Bottleneck(C,C,C,use_cudnn=True,explicit_nhwc=explicit_nhwc)
bottleneck.to(dtype=dtype, device='cuda')
for p in bottleneck.parameters():
torch.distributed.broadcast(p, 0)
for b in bottleneck.buffers():
torch.distributed.broadcast(b, 0)
return bottleneck
def print_bottleneck_p_and_b(bottleneck):
with torch.no_grad():
for n,p in bottleneck.named_parameters():
print("%s :: %s" % (n, str(p.norm(p=2,dtype=torch.float32))))
for n,p in bottleneck.named_buffers():
print("%s :: %s" % (n, str(p.norm(p=2,dtype=torch.float32))))
def has_nan(x):
if isinstance(x, list) or isinstance(x, tuple):
for xx in x:
if torch.any(torch.isnan(xx)):
return True
return False
elif isinstance(x, dict):
for k,v in x.items():
if torch.any(torch.isnan(v)):
return True
else:
return torch.any(torch.isnan(x))
def rel_diff_t(xx1, xx2):
return ((xx1 - xx2).norm(p=2,dtype=torch.float32) / (xx1 + xx2).norm(p=2,dtype=torch.float32)).item()
def rel_diff(x1, x2):
if isinstance(x1, list) or isinstance(x1, tuple):
return [rel_diff_t(xx1,xx2) for xx1,xx2 in zip(x1,x2)]
elif isinstance(x1, dict):
return [rel_diff_t(xx1, xx2) for (k1,xx1), (k2,xx2) in zip(x1.items(),x2.items())]
else:
return rel_diff_t(x1,x2)
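rel_diff_t above reports ||a - b||2 / ||a + b||2, so identical tensors score 0.0 and the metric blows up when the inputs nearly cancel. A quick worked check with made-up values:
# Quick check of the relative-difference metric above (made-up values).
import torch
a = torch.tensor([3.0, 4.0])
b = a * 1.01
print(((a - b).norm(p=2) / (a + b).norm(p=2)).item())   # ~0.01 / 2.01 ≈ 0.005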
def graph_it(bottleneck, x):
print("Graphing")
with torch.no_grad():
input_shape = [1, in_channels] + list(shape)
x = torch.randn(input_shape, dtype=numtype, device=device)
if nhwc:
x = nchw_to_nhwc_transform(x).contiguous()
x = x.clone()
x.grad = None
x.requires_grad = True
print(x.shape, x.stride())
#if spatial_group_size > 1:
# fast = False # hack so fast bottleneck can be run against distributed bottleneck
#if spatial_group_size == 1:
# fast = False
if fast:
if spatial_group_size == 1:
bottleneck = FastBottleneck(
in_channels=in_channels,
bottleneck_channels=bottleneck_channels,
out_channels=out_channels,
stride=stride,
dilation=dilation,
explicit_nhwc=nhwc,
use_cudnn=True)
else:
bottleneck = SpatialBottleneck(
in_channels=in_channels,
bottleneck_channels=bottleneck_channels,
out_channels=out_channels,
stride=stride,
dilation=dilation,
explicit_nhwc=nhwc,
use_cudnn=True,
spatial_group_size=spatial_group_size)
else:
bottleneck = Bottleneck(
in_channels,
bottleneck_channels,
out_channels,
num_groups,
stride_in_1x1,
stride,
dilation,
norm_func,
nhwc,
spatial_group_size)
bottleneck = bottleneck.to(dtype=numtype,device=device)
weights = dict(bottleneck.named_parameters())
if ref is not None:
ref_x, _, ref_weights = ref
Hs,H = x.shape[1], ref_x.shape[1]
assert(Hs*spatial_group_size == H), "H must equal Hs*spatial_group_size"
ref_x = ref_x[:,rank*Hs:(rank+1)*Hs,:,:]
x.copy_(ref_x)
assert(len(weights) == len(ref_weights)), "Reference weights and weights don't match"
for k in weights.keys():
weights[k].copy_(ref_weights[k])
# forward
out = bottleneck(x)
# gradient output
return torch.cuda.make_graphed_callables(bottleneck, (x,))
def clone_inputs(bottleneck, x, dy=None):
with torch.no_grad():
grad_out = torch.randn_like(out)
if ref is not None:
_, ref_grad_out, _ = ref
Hs,H = grad_out.shape[1], ref_grad_out.shape[1]
assert(Hs*spatial_group_size == H), "H must equal Hs*spatial_group_size"
ref_grad_out = ref_grad_out[:,rank*Hs:(rank+1)*Hs,:,:]
grad_out.copy_(ref_grad_out)
x = x.clone()
x.grad = None
x.requires_grad = True
if dy is None:
y = bottleneck(x)
dy = torch.randn_like(y) / 1e2
torch.distributed.broadcast(dy, 0)
return x, dy
def fprop_and_bprop(bottleneck, x, dy):
y = bottleneck(x)
y.backward(dy)
dgrad = x.grad.detach()
wgrad = {}
for n,p in bottleneck.named_parameters():
wgrad[n] = p.grad.detach()
return x, y, dy, dgrad, wgrad
def ground_truth(N, C, H, W, dtype, memory_format, bottleneck):
if memory_format == 1:
# 1 -> explicit nhwc
explicit_nhwc = True
with torch.no_grad():
x = torch.randn([N,H,W,C], dtype=dtype, device='cuda')
torch.distributed.broadcast(x, 0)
x, dy = clone_inputs(bottleneck, x)
return fprop_and_bprop(bottleneck, x, dy)
else:
# 2 -> native nhwc
# 3 -> nchw
explicit_nhwc = False
assert(False), "Not implemented yet"
def print_ground_truth(gt):
x, y, dy, dgrad, wgrad = gt
if has_nan(y) or has_nan(dgrad) or has_nan(wgrad):
print("Error! Ground truth has NAN")
else:
print("Ok! No NAN found in ground truth")
# backward
out.backward(grad_out)
def apply_to_different_bottleneck(gt, bottleneck):
with torch.no_grad():
dgrad = x.grad.detach()
wgrad = {}
for n,p in bottleneck.named_parameters():
wgrad[n] = p.grad.detach()
if world_size > 1:
if spatial_group_size == 1:
# broadcast x, grad_out and weights from rank 0
with torch.no_grad():
torch.distributed.broadcast(x,0)
torch.distributed.broadcast(grad_out,0)
for k in weights.keys():
torch.distributed.broadcast(weights[k],0)
x, _, dy, _, _ = gt
x, dy = clone_inputs(bottleneck, x, dy)
return fprop_and_bprop(bottleneck, x, dy)
def compare_single_field(results, f1, f2, l0, l1, l2):
if has_nan(f1) and has_nan(f2):
results[l0] = "both NAN"
elif has_nan(f1):
results[l0] = "%s.%s NAN" % (l1, l0)
elif has_nan(f2):
results[l0] = "%s.%s NAN" % (l2, l0)
else:
results[l0] = "%s" % (str(rel_diff(f1,f2)))
def compare(gt, bt):
x1, y1, dy1, dgrad1, wgrad1 = gt
x2, y2, dy2, dgrad2, wgrad2 = bt
results = {}
compare_single_field(results, y1, y2, "y", "gt", "bt")
compare_single_field(results, dy1, dy2, "dy", "gt", "bt")
compare_single_field(results, dgrad1, dgrad2, "dgrad", "gt", "bt")
compare_single_field(results, wgrad1, wgrad2, "wgrad", "gt", "bt")
for i in range(torch.distributed.get_world_size()):
if i == torch.distributed.get_rank():
print(i,results)
torch.distributed.barrier()
def spatial_parallel_bottleneck(C, dtype, explicit_nhwc, gt_bottleneck, spatial_parallel_args):
spatial_bottleneck = SpatialBottleneck(C,C,C,use_cudnn=True,explicit_nhwc=explicit_nhwc,spatial_parallel_args=spatial_parallel_args)
spatial_bottleneck.to(dtype=dtype, device='cuda')
with torch.no_grad():
sp = {}
for n,p in spatial_bottleneck.named_parameters():
sp[n] = p
for n,p in gt_bottleneck.named_parameters():
sp[n].copy_(p)
sb = {}
for n,b in spatial_bottleneck.named_buffers():
sb[n] = b
for n,b in gt_bottleneck.named_buffers():
sb[n].copy_(b)
return spatial_bottleneck
#class HaloExchangerNoComm(HaloExchanger):
# def __init__(self, world_size, spatial_group_size, rank, comm):
#class HaloExchangerAllGather(HaloExchanger):
# def __init__(self, world_size, spatial_group_size, rank, comm):
#class HaloExchangerSendRecv(HaloExchanger):
# def __init__(self, world_size, spatial_group_size, rank, comm):
#class HaloExchangerPeer(HaloExchanger):
# def __init__(self, world_size, spatial_group_size, rank, comm, peer_pool, explicit_nhwc, numSM=1):
def n_way_spatial(halex, gt_bottleneck, gt, explicit_nhwc, world_size, rank, fp32_reduce=False):
assert(explicit_nhwc), "Only tested for explicit nhwc"
x, _, dy, _, _ = gt
N, H, W, C = list(x.shape) # Tensor is already shaped properly for n-way parallel
dtype = x.dtype
spatial_group_size = world_size
spatial_group_rank = rank
spatial_communicator = None
spatial_halo_exchanger = halex
spatial_method = 1 # 1 -> overlap halo and main conv, 2 -> wait for halo, conv on padded x, 3 -> masked main conv + halo correction
use_delay_kernel = False
spatial_parallel_args = (spatial_group_size, spatial_group_rank, spatial_communicator, spatial_halo_exchanger, spatial_method, use_delay_kernel)
spatial_bottleneck = spatial_parallel_bottleneck(C, dtype, explicit_nhwc, gt_bottleneck, spatial_parallel_args)
with torch.no_grad():
Hs = H // spatial_group_size
xs = x[:,spatial_group_rank*Hs:(spatial_group_rank+1)*Hs,:,:].clone()
dys = dy[:,spatial_group_rank*Hs:(spatial_group_rank+1)*Hs,:,:].clone()
xs.requires_grad = True
spatial_bottleneck = graph_it(spatial_bottleneck, xs)
_, y, _, dgrad, wgrad = fprop_and_bprop(spatial_bottleneck, xs, dys)
# gather output pieces
for n,p in wgrad.items():
if fp32_reduce:
p32 = p.float()
torch.distributed.all_reduce(p32)
p.copy_(p32.half())
else:
# gather dgrad (x.grad), sum wgrad (weights) and out
N,Hs,W,C = dgrad.shape
H = Hs * spatial_group_size
dgrad_gathered = torch.empty((N,H,W,C),dtype=dgrad.dtype,device=dgrad.device)
dgrad_tensors = [dgrad_gathered[:,i*Hs:(i+1)*Hs,:,:] for i in range(spatial_group_size)]
torch.distributed.all_gather(dgrad_tensors, dgrad)
dgrad = dgrad_gathered
N,Hs,W,C = list(out.shape)
H = Hs * spatial_group_size
out_gathered = torch.empty((N,H,W,C),dtype=dgrad.dtype,device=dgrad.device)
out_tensors= [out_gathered[:,i*Hs:(i+1)*Hs,:,:] for i in range(spatial_group_size)]
torch.distributed.all_gather(out_tensors, out)
out = out_gathered
for k in wgrad.keys():
w = wgrad[k].to(dtype=torch.float64)
torch.distributed.all_reduce(w)
wgrad[k].copy_(w.to(dtype=wgrad[k].dtype))
#torch.distributed.all_reduce(wgrad[k])
return x, out, grad_out, weights, dgrad, wgrad
def module_tests(rank, world_size, numtype, device, fast, spatial_group_sizes, init_args):
r = []
for ia in init_args:
shape = ia[0:4]
args = ia[4:]
rr = []
ref = None
for spatial_group_size in spatial_group_sizes:
N,H,W,C = shape
H = H//spatial_group_size
x, out, grad_out, weights, dgrad, wgrad = single_module_test(ref, rank, world_size, numtype, device, [H,W], fast, spatial_group_size, *args)
if ref is None:
assert(spatial_group_size == 1), "Wrong reference weights"
ref = x, grad_out, weights
if rank == 0:
rr.append( (out, dgrad, wgrad) )
if world_size > 1: torch.distributed.barrier()
r.append(rr)
return r
torch.distributed.all_reduce(p)
ys = [torch.empty_like(y) for _ in range(spatial_group_size)]
torch.distributed.all_gather(ys,y)
y = torch.cat(ys,dim=1)
dgrads = [torch.empty_like(dgrad) for _ in range(spatial_group_size)]
torch.distributed.all_gather(dgrads,dgrad)
dgrad = torch.cat(dgrads,dim=1)
return x, y, dy, dgrad, wgrad
def main():
total_num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
distributed = total_num_gpus > 1
ngpus = torch.cuda.device_count()
if distributed:
torch.distributed.init_process_group("nccl")
rank, world_size = torch.distributed.get_rank(), torch.distributed.get_world_size()
is_master = True if rank == 0 else False
local_rank = rank % ngpus
torch.cuda.set_device(local_rank)
spatial_group_size = total_num_gpus
else:
rank, local_rank, is_master, world_size, spatial_group_size = 0, 0, True, 1, 1
torch.use_deterministic_algorithms(True)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
norm_func = FrozenBatchNorm2d_NHWC
init_args = [
(1, 200, 336, 64, 64, 64, 256, 1, True, 1, 1, norm_func, True),
(1, 200, 336, 256, 256, 64, 256, 1, True, 1, 1, norm_func, True),
(1, 200, 336, 256, 256, 128, 512, 1, True, 2, 1, norm_func, True),
(1, 100, 168, 512, 512, 128, 512, 1, True, 1, 1, norm_func, True),
(1, 100, 168, 512, 512, 256, 1024, 1, True, 2, 1, norm_func, True),
(1, 50, 84, 1024, 1024, 256, 1024, 1, True, 1, 1, norm_func, True),
(1, 50, 84, 1024, 1024, 512, 2048, 1, True, 2, 1, norm_func, True),
(1, 25, 42, 2048, 2048, 512, 2048, 1, True, 1, 1, norm_func, True),
(1, 336, 200, 64, 64, 64, 256, 1, True, 1, 1, norm_func, True),
(1, 336, 200, 256, 256, 64, 256, 1, True, 1, 1, norm_func, True),
(1, 336, 200, 256, 256, 128, 512, 1, True, 2, 1, norm_func, True),
(1, 168, 100, 512, 512, 128, 512, 1, True, 1, 1, norm_func, True),
(1, 168, 100, 512, 512, 256, 1024, 1, True, 2, 1, norm_func, True),
(1, 84, 50, 1024, 1024, 256, 1024, 1, True, 1, 1, norm_func, True),
(1, 84, 50, 1024, 1024, 512, 2048, 1, True, 2, 1, norm_func, True),
(1, 42, 25, 2048, 2048, 512, 2048, 1, True, 1, 1, norm_func, True),
]
init_args = init_args[0:1]
# pad H to account for spatial distribution
padded_init_args = []
for ia in init_args:
N,H,W,C = ia[0:4]
m = spatial_group_size * H // (25 if H < W else 42)
H = ((H + m - 1) // m) * m
args = tuple( [N,H,W,C] + list(ia[4:]) )
padded_init_args.append(args)
init_args = padded_init_args
if rank == 0:
for ia in init_args:
print(ia)
spatial_group_sizes = [1]
if spatial_group_size > 1:
spatial_group_sizes.append(spatial_group_size)
numtype, device, fast = torch.float16, 'cuda', True
r = module_tests(rank, world_size, numtype, device, fast, spatial_group_sizes, init_args)
if world_size > 1: torch.distributed.barrier()
if rank == 0:
for rr in r:
print("***")
for out, dgrad, wgrad in rr:
gr = [("out",out.norm(p=2,dtype=torch.float64).item())]
gr = gr + [("dgrad",dgrad.norm(p=2,dtype=torch.float64).item())]
gr = gr + [(k+".wgrad",wgrad[k].norm(p=2,dtype=torch.float64).item()) for k in wgrad.keys()]
print(gr)
if len(rr) == 2:
out1, dgrad1, wgrad1 = rr[0]
out2, dgrad2, wgrad2 = rr[1]
rtol = 1e-1
out_atol = out1.abs().max().item() * rtol
dgrad_atol = dgrad1.abs().max().item() * rtol
wgrad_atol = {}
for k in wgrad1.keys():
wgrad_atol[k] = wgrad1[k].abs().max().item() * rtol
gr = [("out",torch.allclose(out1,out2,rtol,out_atol,equal_nan=True))]
gr = gr + [("dgrad",torch.allclose(dgrad1,dgrad2,rtol,dgrad_atol,equal_nan=True))]
gr = gr + [(k+".wgrad",torch.allclose(wgrad1[k],wgrad2[k],rtol,wgrad_atol[k],equal_nan=True)) for k in wgrad1.keys()]
print(gr)
gr = [("out",(out1-out2).norm(p=2,dtype=torch.float64).item())]
gr = gr + [("dgrad",(dgrad1-dgrad2).norm(p=2,dtype=torch.float64).item())]
gr = gr + [(k+".wgrad",(wgrad1[k]-wgrad2[k]).norm(p=2,dtype=torch.float64).item()) for k in wgrad1.keys()]
print(gr)
N,H,W,C = out1.shape
Hs = H // spatial_group_size
Ht = Hs-2
print("out1@%d:%d=%s" % (Ht,H,str(out1[0,Ht,:8,:5])))
print("out2@%d:%d=%s" % (Ht,H,str(out2[0,Ht,:8,:5])))
Ht = Hs-1
print("out1@%d:%d=%s" % (Ht,H,str(out1[0,Ht,:8,:5])))
print("out2@%d:%d=%s" % (Ht,H,str(out2[0,Ht,:8,:5])))
Ht = Hs
print("out1@%d:%d=%s" % (Ht,H,str(out1[0,Ht,:8,:5])))
print("out2@%d:%d=%s" % (Ht,H,str(out2[0,Ht,:8,:5])))
Ht = Hs+1
print("out1@%d:%d=%s" % (Ht,H,str(out1[0,Ht,:8,:5])))
print("out2@%d:%d=%s" % (Ht,H,str(out2[0,Ht,:8,:5])))
N,H,W,C = dgrad1.shape
Hs = H // spatial_group_size
Ht = Hs-2
print("dgrad1@%d:%d=%s" % (Ht,H,str(dgrad1[0,Ht,:8,:5])))
print("dgrad2@%d:%d=%s" % (Ht,H,str(dgrad2[0,Ht,:8,:5])))
Ht = Hs-1
print("dgrad1@%d:%d=%s" % (Ht,H,str(dgrad1[0,Ht,:8,:5])))
print("dgrad2@%d:%d=%s" % (Ht,H,str(dgrad2[0,Ht,:8,:5])))
Ht = Hs
print("dgrad1@%d:%d=%s" % (Ht,H,str(dgrad1[0,Ht,:8,:5])))
print("dgrad2@%d:%d=%s" % (Ht,H,str(dgrad2[0,Ht,:8,:5])))
Ht = Hs+1
print("dgrad1@%d:%d=%s" % (Ht,H,str(dgrad1[0,Ht,:8,:5])))
print("dgrad2@%d:%d=%s" % (Ht,H,str(dgrad2[0,Ht,:8,:5])))
if world_size > 1: torch.distributed.barrier()
torch.distributed.init_process_group("nccl")
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
torch.cuda.set_device(rank)
explicit_nhwc = True
dtype = torch.float16
N, C, H, W = 1, 64, 200, 336
Hs = ((H+8*world_size-1) // (8*world_size)) * 8
H = Hs*world_size
gt_bottleneck = ground_truth_bottleneck(C, dtype, explicit_nhwc)
gt = ground_truth(N, C, H, W, dtype, 1, gt_bottleneck)
# verify that spatial bottleneck with group_size 1 produces same results as ground truth bottleneck
spatial_bottleneck = spatial_parallel_bottleneck(C, dtype, explicit_nhwc, gt_bottleneck, None)
bt = apply_to_different_bottleneck(gt, spatial_bottleneck)
compare(gt, bt)
#print_bottleneck_p_and_b(gt_bottleneck)
#print_bottleneck_p_and_b(spatial_bottleneck)
spatial_group_size = world_size
spatial_communicator = None
peer_pool = PeerMemoryPool(rank, world_size, spatial_group_size, 64*1024*1024, 2*1024*1024)
#halex = HaloExchangerAllGather(world_size, spatial_group_size, rank, spatial_communicator)
#halex = HaloExchangerSendRecv(world_size, spatial_group_size, rank, spatial_communicator)
halex = HaloExchangerPeer(world_size, spatial_group_size, rank, spatial_communicator, peer_pool, explicit_nhwc, numSM=1)
#print("halex.signals = %s" % (str(halex.signals)))
# Make sure peer memory halo exchanger has finished initializing flags on all ranks before proceeding
#torch.cuda.synchronize()
#torch.distributed.barrier()
bt2 = n_way_spatial(halex, gt_bottleneck, gt, explicit_nhwc, world_size, rank, fp32_reduce=True)
compare(gt, bt2)
if __name__ == "__main__":
......
import torch
import torch.distributed as dist
from torch import nn
import nccl_p2p_cuda as inc
import peer_memory_cuda as pm
# Communication-free halo exchanger.
# NB! This halo exchanger does not exchange halos with neighbors as it should; it merely swaps the inputs.
# NB! It is only useful for performance testing.
# NB! Do not use for actual production runs.
class HaloExchanger(object):
def __init__(self, spatial_group_size, rank):
self.stream1 = torch.cuda.Stream()
self.stream2 = torch.cuda.Stream()
self.stream3 = torch.cuda.Stream()
spatial_rank = rank % spatial_group_size
self.left_zero = True if spatial_rank == 0 else False
self.right_zero = True if spatial_rank == spatial_group_size - 1 else False
class HaloExchangerNoComm(HaloExchanger):
def __init__(self, world_size, spatial_group_size, rank, comm):
super(HaloExchangerNoComm, self).__init__(spatial_group_size, rank)
def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None):
if left_input_halo is None:
return right_output_halo, left_output_halo
else:
left_input_halo.copy_(right_output_halo)
right_input_halo.copy_(left_output_halo)
class HaloExchangerAllGather(HaloExchanger):
def __init__(self, world_size, spatial_group_size, rank, comm):
super(HaloExchangerAllGather, self).__init__(spatial_group_size, rank)
self.spatial_group_size = spatial_group_size
self.local_rank = rank % spatial_group_size
self.comm = comm
def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None):
N,Hh,W,C = list(left_output_halo.shape)
send_halos = torch.empty((N,2*Hh,W,C),dtype=left_output_halo.dtype,device=left_output_halo.device)
send_halos[:,:Hh,:,:].copy_(left_output_halo)
send_halos[:,Hh:,:,:].copy_(right_output_halo)
all_halos = torch.empty((N,2*Hh*self.spatial_group_size,W,C),dtype=left_output_halo.dtype,device=left_output_halo.device)
all_halos = [all_halos[:,i*2*Hh:(i+1)*2*Hh,:,:] for i in range(self.spatial_group_size)]
torch.distributed.all_gather(all_halos,send_halos,group=self.comm,no_copy=True)
ag_left_input_halo = all_halos[(self.spatial_group_size+self.local_rank-1)%self.spatial_group_size][:,Hh:,:,:]
ag_right_input_halo = all_halos[(self.local_rank+1)%self.spatial_group_size][:,:Hh,:,:]
if left_input_halo is None:
if self.left_zero:
ag_left_input_halo.zero_()
if self.right_zero:
ag_right_input_halo.zero_()
return ag_left_input_halo, ag_right_input_halo
else:
if self.left_zero:
left_input_halo.zero_()
else:
left_input_halo.copy_(ag_left_input_halo)
if self.right_zero:
right_input_halo.zero_()
else:
right_input_halo.copy_(ag_right_input_halo)
class HaloExchangerSendRecv(HaloExchanger):
def __init__(self, world_size, spatial_group_size, rank, comm):
super(HaloExchangerSendRecv, self).__init__(spatial_group_size, rank)
self.world_size = world_size
self.spatial_group_size = spatial_group_size
nccl_id = inc.get_unique_nccl_id(1).cuda()
torch.distributed.broadcast(nccl_id, 0)
nccl_id = nccl_id.cpu()
self.handle = inc.init_nccl_comm(nccl_id, rank, world_size)
def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None):
if left_input_halo is None:
left_input_halo, right_input_halo = inc.left_right_halo_exchange(self.handle, self.left_zero, self.right_zero, left_output_halo, right_output_halo, self.spatial_group_size)
return left_input_halo, right_input_halo
else:
inc.left_right_halo_exchange_inplace(self.handle, self.left_zero, self.right_zero, left_output_halo, right_output_halo, left_input_halo, right_input_halo, self.spatial_group_size)
class HaloExchangerPeer(HaloExchanger):
def __init__(self, world_size, spatial_group_size, rank, comm, peer_pool, explicit_nhwc, numSM=1):
super(HaloExchangerPeer, self).__init__(spatial_group_size, rank)
self.diagnostics = False
self.spatial_group_size = spatial_group_size
self.peer_rank = rank % spatial_group_size
self.left_neighbor = (self.peer_rank + self.spatial_group_size - 1) % self.spatial_group_size
self.right_neighbor = (self.peer_rank + 1) % self.spatial_group_size
self.peer_pool = peer_pool
self.signals = peer_pool.allocate_peer_tensors([2,4], torch.int32, False, False)
self.signals[self.peer_rank].zero_()
self.explicit_nhwc = explicit_nhwc
self.numSM = numSM
def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None):
inplace = False if left_input_halo is None and right_input_halo is None else True
if not inplace:
left_input_halo = torch.empty_like(right_output_halo)
right_input_halo = torch.empty_like(left_output_halo)
channels_last = left_output_halo.is_contiguous(memory_format=torch.channels_last) and not self.explicit_nhwc
left_tx = self.peer_pool.allocate_peer_tensors(list(left_output_halo.shape), left_output_halo.dtype, channels_last, True)
right_tx = self.peer_pool.allocate_peer_tensors(list(right_output_halo.shape), right_output_halo.dtype, channels_last, True)
pm.push_pull_halos_1d(
self.diagnostics, self.explicit_nhwc, self.numSM,
left_output_halo, left_tx[self.peer_rank], right_tx[self.left_neighbor], left_input_halo,
right_output_halo, right_tx[self.peer_rank], left_tx[self.right_neighbor], right_input_halo,
self.signals[self.left_neighbor], self.signals[self.right_neighbor], self.signals[self.peer_rank]
)
# TODO: Add to push_pull_halos_1d kernel
if self.left_zero:
left_input_halo.zero_()
if self.right_zero:
right_input_halo.zero_()
if not inplace:
return left_input_halo, right_input_halo
# Class that combines input volume with halos from neighbors (1d).
class HaloPadder:
def __init__(self, halo_ex):
self.halo_ex = halo_ex
self.stream1 = torch.cuda.Stream()
self.stream2 = torch.cuda.Stream()
def __call__(self, y, half_halo, explicit_nhwc, H_split):
channels_last = not explicit_nhwc and y.is_contiguous(memory_format=torch.channels_last)
if explicit_nhwc:
N,H,W,C = list(y.shape)
if H_split:
padded_shape = [N,H+2*half_halo,W,C]
ypad = torch.empty(padded_shape, dtype=y.dtype, device=y.device, memory_format=torch.contiguous_format)
yleft = ypad[:,:half_halo,:,:]
ymid = ypad[:,half_halo:H+half_halo,:,:]
yright = ypad[:,H+half_halo:H+2*half_halo,:,:]
oleft = y[:,:half_halo,:,:]
oright = y[:,H-half_halo:,:,:]
else:
padded_shape = [N,H,W+2*half_halo,C]
ypad = torch.empty(padded_shape, dtype=y.dtype, device=y.device, memory_format=torch.contiguous_format)
yleft = ypad[:,:,:half_halo,:]
ymid = ypad[:,:,half_halo:W+half_halo,:]
yright = ypad[:,:,W+half_halo:W+2*half_halo,:]
oleft = y[:,:,:half_halo,:]
oright = y[:,:,W-half_halo:,:]
else:
N,C,H,W = list(y.shape)
if H_split:
padded_shape = [N,C,H+2*half_halo,W]
ypad = torch.empty(padded_shape, dtype=y.dtype, device=y.device, memory_format=torch.channels_last)
yleft = ypad[:,:,:half_halo,:]
ymid = ypad[:,:,half_halo:H+half_halo,:]
yright = ypad[:,:,H+half_halo:H+2*half_halo,:]
oleft = y[:,:,:half_halo,:]
oright = y[:,:,H-half_halo:,:]
else:
padded_shape = [N,C,H,W+2*half_halo]
ypad = torch.empty(padded_shape, dtype=y.dtype, device=y.device, memory_format=torch.channels_last)
yleft = ypad[:,:,:,:half_halo]
ymid = ypad[:,:,:,half_halo:W+half_halo]
yright = ypad[:,:,:,W+half_halo:W+2*half_halo]
oleft = y[:,:,:,:half_halo]
oright = y[:,:,:,W-half_halo:]
with torch.cuda.stream(self.stream1):
self.halo_ex(oleft, oright, yleft, yright)
with torch.cuda.stream(self.stream2):
ymid.copy_(y)
return ypad
def wait(self):
current_stream = torch.cuda.current_stream()
current_stream.wait_stream(self.stream1)
current_stream.wait_stream(self.stream2)
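All exchangers above share the left_right_halo_exchange(left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None) contract: each rank sends its first and last rows and receives its neighbors' boundary rows, either into newly allocated tensors or in place into the caller's padded buffer. A single-GPU sketch of the in-place form with the communication-free variant (which, as its comment says, merely swaps the two halos and is only meant for performance testing):
# Hedged, single-process sketch of the left_right_halo_exchange contract using
# the communication-free exchanger above (no real neighbor traffic; it swaps halos).
import torch
from apex.contrib.bottleneck import HaloExchangerNoComm

halex = HaloExchangerNoComm(1, 1, 0, None)       # world_size, spatial_group_size, rank, comm

N, H, W, C = 1, 8, 336, 64                       # explicit NHWC, one-row halos
out1 = torch.randn(N, H, W, C, device='cuda', dtype=torch.float16)
out1_pad = torch.empty(N, H + 2, W, C, device='cuda', dtype=torch.float16)

# In-place form, mirroring how SpatialBottleneckFunction fills out1_pad above.
halex.left_right_halo_exchange(out1[:, :1], out1[:, H-1:],
                               out1_pad[:, :1], out1_pad[:, H+1:])
out1_pad[:, 1:H+1].copy_(out1)                   # interior rows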
@@ -1608,18 +1608,1318 @@ std::vector<at::Tensor> bottleneck_backward(bool explicit_nhwc, int stride_1X1,
namespace {
enum {
X_TENSOR,
Y_TENSOR,
W_TENSOR,
Z_TENSOR,
B_TENSOR,
AFTERADD_TENSOR,
AFTERBIAS_TENSOR,
AFTERCONV_TENSOR,
OPTIONAL,
AFTEROPT_TENSOR,
AFTERACT_TENSOR,
GEN_INDEX_TENSOR,
MASK_TOP_TENSOR,
MASK_BOTTOM_TENSOR,
MASK_TENSOR,
THRESHOLD_TOP_TENSOR,
THRESHOLD_BOTTOM_TENSOR,
};
using masked_convbias_descriptors = std::tuple<cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor>;
masked_convbias_descriptors
create_conv_bias_add_act_mask_descriptors(int64_t* x_dim_padded,
int64_t* padA,
int64_t* convstrideA,
int64_t* dilationA,
int64_t* w_dim_padded,
int64_t* y_dim_padded,
int64_t* threshold_dim,
cudnnDataType_t dataType) {
const int convDim = 2;
int64_t b_dim_padded[4];
b_dim_padded[0] = 1;
b_dim_padded[1] = y_dim_padded[1];
b_dim_padded[2] = 1;
b_dim_padded[3] = 1;
int64_t x_stride_padded[4];
int64_t y_stride_padded[4];
int64_t w_stride_padded[4];
int64_t b_stride_padded[4];
int64_t threshold_stride[4];
generateStrides(w_dim_padded, w_stride_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(x_dim_padded, x_stride_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(y_dim_padded, y_stride_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(b_dim_padded, b_stride_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(threshold_dim, threshold_stride, 4, CUDNN_TENSOR_NHWC);
return masked_convbias_descriptors(cudnn_frontend::TensorBuilder()
.setDim(4, x_dim_padded)
.setStrides(4, x_stride_padded)
.setId('x')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('y')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, w_dim_padded)
.setStrides(4, w_stride_padded)
.setId('w')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, b_dim_padded)
.setStrides(4, b_stride_padded)
.setId('z')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, b_dim_padded)
.setStrides(4, b_stride_padded)
.setId('b')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setVirtual()
.setId('A') // after add
.setAlignment(16)
.setDataType(CUDNN_DATA_FLOAT)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setVirtual()
.setId('B') // after bias
.setAlignment(16)
.setDataType(CUDNN_DATA_FLOAT)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('C') // after conv
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_FLOAT)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('i')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('D') // after optional add
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_FLOAT)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('E') // after act for masked
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_FLOAT)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('I') // output of the gen index operation
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_INT32)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('m') // top half of the mask created after the less than
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_BOOLEAN)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('n') // bottom half of the mask
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_BOOLEAN)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('M') // OR of the top and bottom masks
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_BOOLEAN)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, threshold_dim)
.setStrides(4, threshold_stride)
.setId('t') // threshold for creating the top mask
.setAlignment(16)
.setDataType(CUDNN_DATA_INT32)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, threshold_dim)
.setStrides(4, threshold_stride)
.setId('u') // threshold for creating the bottom mask
.setAlignment(16)
.setDataType(CUDNN_DATA_INT32)
.build());
}
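// Note on the masked descriptor set above (explanatory only, not part of the build):
// 'I', 'm', 'n' and 'M' are virtual intermediates for the halo-mask chain, while 't' and 'u'
// are real INT32 threshold tensors of shape threshold_dim. Conceptually, with h the row index
// produced by gen_index along the chosen axis,
//   M[h] = (h < t) || (h > u)
// and the final 'y' is a per-row select between the raw convolution output and the fused
// scale+bias+activation result, as wired up in run_conv_scale_bias_add_activation_mask below.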
// tensor descriptors used for dgrad
enum {
X_OR_DX_TENSOR,
DY_TENSOR,
W_OR_DW_TENSOR,
SCALE_TENSOR,
RELU_TENSOR,
AFTER_DCONV_TENSOR,
AFTER_DRELU_TENSOR,
DGRAD_INPUT_TENSOR,
DGRAD_OPTIONAL_TENSOR,
DGRAD_GEN_INDEX_TENSOR,
DGRAD_MASK_TOP_TENSOR,
DGRAD_MASK_BOTTOM_TENSOR,
DGRAD_MASK_TENSOR,
DGRAD_THRESHOLD_TOP_TENSOR,
DGRAD_THRESHOLD_BOTTOM_TENSOR,
};
using dconv_add_descriptors = std::tuple<cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor>;
dconv_add_descriptors
create_dconv_add_descriptors(int64_t* x_dim_padded,
int64_t* padA,
int64_t* convstrideA,
int64_t* dilationA,
int64_t* w_dim_padded,
int64_t* y_dim_padded,
cudnnDataType_t dataType) {
const int convDim = 2;
int64_t b_dim_padded[4];
b_dim_padded[0] = 1;
b_dim_padded[1] = x_dim_padded[1];
b_dim_padded[2] = 1;
b_dim_padded[3] = 1;
int64_t x_stride_padded[4];
int64_t y_stride_padded[4];
int64_t w_stride_padded[4];
int64_t b_stride_padded[4];
generateStrides(w_dim_padded, w_stride_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(x_dim_padded, x_stride_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(y_dim_padded, y_stride_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(b_dim_padded, b_stride_padded, 4, CUDNN_TENSOR_NHWC);
return dconv_add_descriptors(cudnn_frontend::TensorBuilder()
.setDim(4, x_dim_padded)
.setStrides(4, x_stride_padded)
.setId('x')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('y')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, w_dim_padded)
.setStrides(4, w_stride_padded)
.setId('w')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, b_dim_padded)
.setStrides(4, b_stride_padded)
.setId('s')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, x_dim_padded)
.setStrides(4, x_stride_padded)
.setId('r')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, x_dim_padded)
.setStrides(4, x_stride_padded)
.setVirtual()
.setId('A') // after dconv
.setAlignment(16)
.setDataType(CUDNN_DATA_FLOAT)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, x_dim_padded)
.setStrides(4, x_stride_padded)
.setVirtual()
.setId('B') // after drelu
.setAlignment(16)
.setDataType(CUDNN_DATA_FLOAT)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('i')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('D') // after optional add
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_FLOAT)
.build());
}
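// Note on the dgrad descriptor set above (explanatory only): 's' is the per-channel scale,
// shaped [1,C,1,1] via b_dim_padded; 'r' is the saved forward ReLU output consumed by
// CUDNN_POINTWISE_RELU_BWD; 'i' is the optional side input added to the dgrad result and 'D'
// its virtual sum; 'A' and 'B' are virtual intermediates (after dconv and after drelu).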
using dconv_mask_descriptors = std::tuple<cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor>;
dconv_mask_descriptors
create_dconv_mask_descriptors(int64_t* x_dim_padded,
int64_t* padA,
int64_t* convstrideA,
int64_t* dilationA,
int64_t* w_dim_padded,
int64_t* y_dim_padded,
int64_t* threshold_dim,
cudnnDataType_t dataType) {
const int convDim = 2;
int64_t b_dim_padded[4];
b_dim_padded[0] = 1;
b_dim_padded[1] = x_dim_padded[1];
b_dim_padded[2] = 1;
b_dim_padded[3] = 1;
int64_t x_stride_padded[4];
int64_t y_stride_padded[4];
int64_t w_stride_padded[4];
int64_t b_stride_padded[4];
int64_t threshold_stride[4];
generateStrides(w_dim_padded, w_stride_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(x_dim_padded, x_stride_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(y_dim_padded, y_stride_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(b_dim_padded, b_stride_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(threshold_dim, threshold_stride, 4, CUDNN_TENSOR_NHWC);
return dconv_mask_descriptors(cudnn_frontend::TensorBuilder()
.setDim(4, x_dim_padded)
.setStrides(4, x_stride_padded)
.setId('x')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('y')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, w_dim_padded)
.setStrides(4, w_stride_padded)
.setId('w')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, b_dim_padded)
.setStrides(4, b_stride_padded)
.setId('s')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, x_dim_padded)
.setStrides(4, x_stride_padded)
.setId('r')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, x_dim_padded)
.setStrides(4, x_stride_padded)
.setVirtual()
.setId('A') // after dconv
.setAlignment(16)
.setDataType(CUDNN_DATA_FLOAT)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, x_dim_padded)
.setStrides(4, x_stride_padded)
.setVirtual()
.setId('B') // after drelu
.setAlignment(16)
.setDataType(CUDNN_DATA_FLOAT)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('i')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('D') // after optional add
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_FLOAT)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('I') // output of the gen index operation
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_INT32)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('m') // top half of the mask created after the less than
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_BOOLEAN)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('n') // bottom half of the mask
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_BOOLEAN)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('M') // OR of the top and bottom masks
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_BOOLEAN)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, threshold_dim)
.setStrides(4, threshold_stride)
.setId('t') // threshold for creating the top mask
.setAlignment(16)
.setDataType(CUDNN_DATA_INT32)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, threshold_dim)
.setStrides(4, threshold_stride)
.setId('u') // threshold for creating the bottom mask
.setAlignment(16)
.setDataType(CUDNN_DATA_INT32)
.build());
}
void
run_conv_add_scale_bias_activation(int64_t* x_dim_padded,
int64_t* pad,
int64_t* convstride,
int64_t* dilation,
int64_t* w_dim_padded,
int64_t* y_dim_padded,
cudnnDataType_t dataType,
at::Half* devPtrX,
at::Half* devPtrW,
at::Half* devPtrY,
at::Half* devPtrZ,
at::Half* devPtrB,
at::Half* devPtrI) {
cudnnHandle_t handle_ = torch::native::getCudnnHandle();
std::stringstream log_buf;
try {
int convDim = 2;
// Creates the necessary tensor descriptors
common_convbias_descriptors tensors = create_conv_bias_add_act_descriptors(
x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType);
DEBUG_CUDNN_MSG(log_buf, std::get<X_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<Y_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<W_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<Z_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<B_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTERADD_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTERBIAS_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTERCONV_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<OPTIONAL>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTEROPT_TENSOR>(tensors).describe());
// Define the scale (pointwise multiply) operation
auto scaleDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_MUL)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe());
// Define the bias operation
auto biasDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_ADD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, biasDesc.describe());
// optional add
auto addDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_ADD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, addDesc.describe());
// Define the activation operation
auto actDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_RELU_FWD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, actDesc.describe());
// Define the convolution problem
auto convDesc = cudnn_frontend::ConvDescBuilder()
.setDataType(CUDNN_DATA_FLOAT)
.setMathMode(CUDNN_CROSS_CORRELATION)
.setNDims(convDim)
.setStrides(convDim, convstride)
.setPrePadding(convDim, pad)
.setPostPadding(convDim, pad)
.setDilation(convDim, dilation)
.build();
DEBUG_CUDNN_MSG(log_buf, convDesc.describe());
float alpha = 1.0f;
float beta = 0.0f;
// Create a convolution Node
auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)
.setxDesc(std::get<X_TENSOR>(tensors))
.setwDesc(std::get<W_TENSOR>(tensors))
.setyDesc(std::get<AFTERCONV_TENSOR>(tensors))
.setcDesc(convDesc)
.setAlpha(alpha)
.setBeta(beta)
.build();
DEBUG_CUDNN_MSG(log_buf, conv_op.describe());
// Create an add Node for the optional side input.
auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(conv_op.getOutputTensor())
.setbDesc(std::get<OPTIONAL>(tensors))
.setyDesc(std::get<AFTEROPT_TENSOR>(tensors))
.setpwDesc(addDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, add_op.describe());
// Create a Scale Node (pointwise multiply by z).
auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(add_op.getOutputTensor())
.setbDesc(std::get<Z_TENSOR>(tensors))
.setyDesc(std::get<AFTERADD_TENSOR>(tensors))
.setpwDesc(scaleDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, scale_op.describe());
// Create a Bias Node.
auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(scale_op.getOutputTensor())
.setbDesc(std::get<B_TENSOR>(tensors))
.setyDesc(std::get<AFTERBIAS_TENSOR>(tensors))
.setpwDesc(biasDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, bias_op.describe());
// Create an Activation Node.
auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(bias_op.getOutputTensor())
.setyDesc(std::get<Y_TENSOR>(tensors))
.setpwDesc(actDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, act_op.describe());
// Create an Operation Graph. In this case it is convolution add scale bias activation
std::array<cudnn_frontend::Operation const*, 5> ops = {&conv_op, &add_op, &scale_op, &bias_op, &act_op};
auto opGraph = cudnn_frontend::OperationGraphBuilder()
.setHandle(handle_)
.setOperationGraph(ops.size(), ops.data())
.build();
// Create string encoding for plan caching
auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());
DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
auto workspace_size = plan.getWorkspaceSize();
DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
void* workspace_ptr = nullptr;
auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
if (workspace_size > 0) {
workspace_ptr = workspace_tensor.data_ptr<float>();
}
void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrB, devPtrI};
int64_t uids[] = {'x', 'y', 'w', 'z', 'b', 'i'};
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace_ptr)
.setDataPointers(6, data_ptrs)
.setUids(6, uids)
.build();
DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
checkCudnnErr(status);
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
} catch (cudnn_frontend::cudnnException &e) {
std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
}
}
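// Summary of the fused graph executed above (illustrative sketch only; the real computation is
// the single cuDNN plan). In ATen-like pseudocode, with per-channel scale z and bias b and the
// side input i:
//   y = relu((conv2d(x, w, /*bias=*/{}, convstride, pad, dilation) + i) * z + b)
// i.e. the side input is added to the raw convolution output before the frozen batch-norm
// scale/bias and the ReLU are applied.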
void
run_conv_scale_bias_add_activation_mask(int64_t* x_dim_padded,
int64_t* pad,
int64_t* convstride,
int64_t* dilation,
int64_t* w_dim_padded,
int64_t* y_dim_padded,
int64_t* threshold_dim,
cudnnDataType_t dataType,
at::Half* devPtrX,
at::Half* devPtrW,
at::Half* devPtrY,
at::Half* devPtrZ,
at::Half* devPtrB,
at::Half* devPtrI,
int* devPtrT,
int* devPtrU,
int axis) {
cudnnHandle_t handle_ = torch::native::getCudnnHandle();
std::stringstream log_buf;
try {
int convDim = 2;
// Creates the necessary tensor descriptors
masked_convbias_descriptors tensors = create_conv_bias_add_act_mask_descriptors(
x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, threshold_dim, dataType);
DEBUG_CUDNN_MSG(log_buf, std::get<X_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<Y_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<W_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<Z_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<B_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTERADD_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTERBIAS_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTERCONV_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<OPTIONAL>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTERACT_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<GEN_INDEX_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<MASK_TOP_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<MASK_BOTTOM_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<MASK_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<THRESHOLD_TOP_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<THRESHOLD_BOTTOM_TENSOR>(tensors).describe());
// Define the scale (pointwise multiply) operation
auto scaleDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_MUL)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe());
// Define the bias operation
auto biasDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_ADD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, biasDesc.describe());
// optional add
auto addDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_ADD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, addDesc.describe());
// Define the activation operation
auto actDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_RELU_FWD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, actDesc.describe());
// Define the convolution problem
auto convDesc = cudnn_frontend::ConvDescBuilder()
.setDataType(CUDNN_DATA_FLOAT)
.setMathMode(CUDNN_CROSS_CORRELATION)
.setNDims(convDim)
.setStrides(convDim, convstride)
.setPrePadding(convDim, pad)
.setPostPadding(convDim, pad)
.setDilation(convDim, dilation)
.build();
DEBUG_CUDNN_MSG(log_buf, convDesc.describe());
// Define the genIndex descriptor
auto genIndexDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_GEN_INDEX)
.setMathPrecision(CUDNN_DATA_FLOAT)
.setAxis(axis)
.build();
DEBUG_CUDNN_MSG(log_buf, genIndexDesc.describe());
// Define the lessThan descriptor
auto lessThanDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_CMP_LT)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, lessThanDesc.describe());
// Define the greaterThan descriptor
auto greaterThanDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_CMP_GT)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, greaterThanDesc.describe());
// Define the logical_or descriptor
auto logicalOrDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_LOGICAL_OR)
.setMathPrecision(CUDNN_DATA_BOOLEAN)
.build();
DEBUG_CUDNN_MSG(log_buf, logicalOrDesc.describe());
// Define the binary_selection descriptor
auto selectionDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_BINARY_SELECT)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, selectionDesc.describe());
float alpha = 1.0f;
float beta = 0.0f;
// Create a convolution Node
auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)
.setxDesc(std::get<X_TENSOR>(tensors))
.setwDesc(std::get<W_TENSOR>(tensors))
.setyDesc(std::get<AFTERCONV_TENSOR>(tensors))
.setcDesc(convDesc)
.setAlpha(alpha)
.setBeta(beta)
.build();
DEBUG_CUDNN_MSG(log_buf, conv_op.describe());
// Create a Scale Node (pointwise multiply by z).
auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(conv_op.getOutputTensor())
.setbDesc(std::get<Z_TENSOR>(tensors))
.setyDesc(std::get<AFTERADD_TENSOR>(tensors))
.setpwDesc(scaleDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, scale_op.describe());
// Create a Bias Node.
auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(scale_op.getOutputTensor())
.setbDesc(std::get<B_TENSOR>(tensors))
.setyDesc(std::get<AFTERBIAS_TENSOR>(tensors))
.setpwDesc(biasDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, bias_op.describe());
// Create an optional add Node.
auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(bias_op.getOutputTensor())
.setbDesc(std::get<OPTIONAL>(tensors))
.setyDesc(std::get<AFTEROPT_TENSOR>(tensors))
.setpwDesc(addDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, add_op.describe());
// Create an Activation Node.
auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(devPtrI ? add_op.getOutputTensor() : bias_op.getOutputTensor())
.setyDesc(std::get<AFTERACT_TENSOR>(tensors))
.setpwDesc(actDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, act_op.describe());
// Create a Gen_Index Node.
auto genIndex_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<AFTERACT_TENSOR>(tensors))
.setyDesc(std::get<GEN_INDEX_TENSOR>(tensors))
.setpwDesc(genIndexDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, genIndex_op.describe());
// Create a LessThan Node.
auto lessThan_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<GEN_INDEX_TENSOR>(tensors))
.setbDesc(std::get<THRESHOLD_TOP_TENSOR>(tensors))
.setyDesc(std::get<MASK_TOP_TENSOR>(tensors))
.setpwDesc(lessThanDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, lessThan_op.describe());
// Create a GreaterThan Node.
auto greaterThan_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<GEN_INDEX_TENSOR>(tensors))
.setbDesc(std::get<THRESHOLD_BOTTOM_TENSOR>(tensors))
.setyDesc(std::get<MASK_BOTTOM_TENSOR>(tensors))
.setpwDesc(greaterThanDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, greaterThan_op.describe());
// Create a LogicalOr Node.
auto logicalOr_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<MASK_TOP_TENSOR>(tensors))
.setbDesc(std::get<MASK_BOTTOM_TENSOR>(tensors))
.setyDesc(std::get<MASK_TENSOR>(tensors))
.setpwDesc(logicalOrDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, logicalOr_op.describe());
// Create a Binary_Selection Node.
auto selection_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<AFTERCONV_TENSOR>(tensors))
.setbDesc(std::get<AFTERACT_TENSOR>(tensors))
.settDesc(std::get<MASK_TENSOR>(tensors))
.setyDesc(std::get<Y_TENSOR>(tensors))
.setpwDesc(selectionDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, selection_op.describe());
// Create an Operation Graph. In this case it is convolution scale bias (optional add) activation with halo mask
if (devPtrI) {
std::array<cudnn_frontend::Operation const*, 10> ops = {&conv_op, &scale_op, &bias_op, &add_op, &act_op, &genIndex_op, &lessThan_op, &greaterThan_op, &logicalOr_op, &selection_op};
auto opGraph = cudnn_frontend::OperationGraphBuilder()
.setHandle(handle_)
.setOperationGraph(ops.size(), ops.data())
.build();
// Create string encoding for plan caching
auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());
DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
auto workspace_size = plan.getWorkspaceSize();
DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
void* workspace_ptr = nullptr;
auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
if (workspace_size > 0) {
workspace_ptr = workspace_tensor.data_ptr<float>();
}
void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrB, devPtrI, devPtrT, devPtrU};
int64_t uids[] = {'x', 'y', 'w', 'z', 'b', 'i', 't', 'u'};
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace_ptr)
.setDataPointers(8, data_ptrs)
.setUids(8, uids)
.build();
DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
checkCudnnErr(status);
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
} else {
std::array<cudnn_frontend::Operation const*, 9> ops = {&conv_op, &scale_op, &bias_op, &act_op, &genIndex_op, &lessThan_op, &greaterThan_op, &logicalOr_op, &selection_op};
auto opGraph = cudnn_frontend::OperationGraphBuilder()
.setHandle(handle_)
.setOperationGraph(ops.size(), ops.data())
.build();
// Create string encoding for plan caching
auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());
DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
auto workspace_size = plan.getWorkspaceSize();
DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
void* workspace_ptr = nullptr;
auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
if (workspace_size > 0) {
workspace_ptr = workspace_tensor.data_ptr<float>();
}
void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrB, devPtrT, devPtrU};
int64_t uids[] = {'x', 'y', 'w', 'z', 'b', 't', 'u'};
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace_ptr)
.setDataPointers(7, data_ptrs)
.setUids(7, uids)
.build();
DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
checkCudnnErr(status);
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
}
} catch (cudnn_frontend::cudnnException &e) {
std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
}
}
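// Summary of the masked fused graph above (illustrative sketch only). With per-channel scale z
// and bias b, optional side input i, and row thresholds t/u along the gen_index axis:
//   act  = relu(conv2d(x, w) * z + b [+ i])
//   M[h] = (h < t) || (h > u)
//   y    = binary_select(conv_out, act, M)   // per-row choice, polarity per CUDNN_POINTWISE_BINARY_SELECT
// The intent, presumably, is that rows flagged by M (the halo rows whose contributions are
// exchanged separately) are resolved differently from interior rows.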
void
run_dconv_add_drelu_dscale(int64_t* x_dim_padded,
int64_t* pad,
int64_t* convstride,
int64_t* dilation,
int64_t* w_dim_padded,
int64_t* y_dim_padded,
cudnnDataType_t dataType,
at::Half* devPtrX,
at::Half* devPtrW,
at::Half* devPtrY,
at::Half* devPtrZ,
at::Half* devPtrR,
at::Half* devPtrI) {
cudnnHandle_t handle_ = torch::native::getCudnnHandle();
std::stringstream log_buf;
try {
int convDim = 2;
// Creates the necessary tensor descriptors
dconv_add_descriptors tensors = create_dconv_add_descriptors(
x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType);
DEBUG_CUDNN_MSG(log_buf, std::get<X_OR_DX_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DY_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<W_OR_DW_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<SCALE_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<RELU_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DCONV_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DRELU_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DGRAD_INPUT_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DGRAD_OPTIONAL_TENSOR>(tensors).describe());
// Define the convolution problem
auto convDesc = cudnn_frontend::ConvDescBuilder()
.setDataType(CUDNN_DATA_FLOAT)
.setMathMode(CUDNN_CROSS_CORRELATION)
.setNDims(convDim)
.setStrides(convDim, convstride)
.setPrePadding(convDim, pad)
.setPostPadding(convDim, pad)
.setDilation(convDim, dilation)
.build();
DEBUG_CUDNN_MSG(log_buf, convDesc.describe());
// optional add
auto addDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_ADD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, addDesc.describe());
// Define the activation backward operation
auto actDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_RELU_BWD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, actDesc.describe());
// Define the scale backward operation
auto scaleDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_MUL)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe());
float alpha = 1.0f;
float beta = 0.0f;
// Create a dgrad (convolution backward data) Node
auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR)
.setdxDesc(std::get<AFTER_DCONV_TENSOR>(tensors))
.setwDesc(std::get<W_OR_DW_TENSOR>(tensors))
.setdyDesc(std::get<DY_TENSOR>(tensors))
.setcDesc(convDesc)
.setAlpha(alpha)
.setBeta(beta)
.build();
DEBUG_CUDNN_MSG(log_buf, conv_op.describe());
// Create an add Node for the side gradient input.
auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<AFTER_DCONV_TENSOR>(tensors))
.setbDesc(std::get<DGRAD_INPUT_TENSOR>(tensors))
.setyDesc(std::get<DGRAD_OPTIONAL_TENSOR>(tensors))
.setpwDesc(addDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, add_op.describe());
// TODO: do we need getOutputTensor(), and what does it return in the backward case?
// Create a relu backward Node.
auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setdyDesc(std::get<DGRAD_OPTIONAL_TENSOR>(tensors))
.setxDesc(std::get<RELU_TENSOR>(tensors))
.setdxDesc(std::get<AFTER_DRELU_TENSOR>(tensors))
.setpwDesc(actDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, act_op.describe());
// Create a Scale Node.
auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<AFTER_DRELU_TENSOR>(tensors))
.setbDesc(std::get<SCALE_TENSOR>(tensors))
.setyDesc(std::get<X_OR_DX_TENSOR>(tensors))
.setpwDesc(scaleDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, scale_op.describe());
// Create an Operation Graph. In this case it is dgrad add drelu dscale
std::array<cudnn_frontend::Operation const*, 4> ops = {&conv_op, &add_op, &act_op, &scale_op};
auto opGraph = cudnn_frontend::OperationGraphBuilder()
.setHandle(handle_)
.setOperationGraph(ops.size(), ops.data())
.build();
// Create string encoding for plan caching
auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());
DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
auto workspace_size = plan.getWorkspaceSize();
DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
void* workspace_ptr = nullptr;
auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
if (workspace_size > 0) {
workspace_ptr = workspace_tensor.data_ptr<float>();
}
void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrR, devPtrI};
int64_t uids[] = {'x', 'y', 'w', 's', 'r', 'i'};
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace_ptr)
.setDataPointers(6, data_ptrs)
.setUids(6, uids)
.build();
DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
checkCudnnErr(status);
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
} catch (cudnn_frontend::cudnnException &e) {
std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
}
}
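// Summary of the fused dgrad graph above (illustrative sketch only). With saved forward ReLU
// output r, per-channel scale s, and side input i:
//   dx = relu_bwd(conv_dgrad(dy, w) + i, r) * s
// i.e. the side input is accumulated into the data gradient before the ReLU backward mask and
// the frozen batch-norm scale are applied.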
void
run_dconv_drelu_dscale_mask(int64_t* x_dim_padded,
int64_t* pad,
int64_t* convstride,
int64_t* dilation,
int64_t* w_dim_padded,
int64_t* y_dim_padded,
int64_t* threshold_dim,
cudnnDataType_t dataType,
at::Half* devPtrX,
at::Half* devPtrW,
at::Half* devPtrY,
at::Half* devPtrZ,
at::Half* devPtrR,
int* devPtrT,
int* devPtrU,
int axis) {
cudnnHandle_t handle_ = torch::native::getCudnnHandle();
std::stringstream log_buf;
try {
int convDim = 2;
// Creates the necessary tensor descriptors
dconv_mask_descriptors tensors = create_dconv_mask_descriptors(
x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, threshold_dim, dataType);
DEBUG_CUDNN_MSG(log_buf, std::get<X_OR_DX_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DY_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<W_OR_DW_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<SCALE_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<RELU_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DCONV_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DRELU_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DGRAD_OPTIONAL_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DGRAD_GEN_INDEX_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DGRAD_MASK_TOP_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DGRAD_MASK_BOTTOM_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DGRAD_MASK_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DGRAD_THRESHOLD_TOP_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DGRAD_THRESHOLD_BOTTOM_TENSOR>(tensors).describe());
// Define the convolution problem
auto convDesc = cudnn_frontend::ConvDescBuilder()
.setDataType(CUDNN_DATA_FLOAT)
.setMathMode(CUDNN_CROSS_CORRELATION)
.setNDims(convDim)
.setStrides(convDim, convstride)
.setPrePadding(convDim, pad)
.setPostPadding(convDim, pad)
.setDilation(convDim, dilation)
.build();
DEBUG_CUDNN_MSG(log_buf, convDesc.describe());
// Define the activation backward operation
auto actDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_RELU_BWD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, actDesc.describe());
// Define the scale backward operation
auto scaleDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_MUL)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe());
// Define the genIndex descriptor
auto genIndexDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_GEN_INDEX)
.setMathPrecision(CUDNN_DATA_FLOAT)
.setAxis(axis)
.build();
DEBUG_CUDNN_MSG(log_buf, genIndexDesc.describe());
// Define the lessThan descriptor
auto lessThanDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_CMP_LT)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, lessThanDesc.describe());
// Define the greaterThan descriptor
auto greaterThanDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_CMP_GT)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, greaterThanDesc.describe());
// Define the logical_or descriptor
auto logicalOrDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_LOGICAL_OR)
.setMathPrecision(CUDNN_DATA_BOOLEAN)
.build();
DEBUG_CUDNN_MSG(log_buf, logicalOrDesc.describe());
// Define the binary_selection descriptor
auto selectionDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_BINARY_SELECT)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, selectionDesc.describe());
float alpha = 1.0f;
float beta = 0.0f;
// Create a dgrad (convolution backward data) Node
auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR)
.setdxDesc(std::get<AFTER_DCONV_TENSOR>(tensors))
.setwDesc(std::get<W_OR_DW_TENSOR>(tensors))
.setdyDesc(std::get<DY_TENSOR>(tensors))
.setcDesc(convDesc)
.setAlpha(alpha)
.setBeta(beta)
.build();
DEBUG_CUDNN_MSG(log_buf, conv_op.describe());
// TODO: do we need getOutputTensor(), and what does it return in the backward case?
// Create a relu backward Node.
auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setdyDesc(std::get<AFTER_DCONV_TENSOR>(tensors))
.setxDesc(std::get<RELU_TENSOR>(tensors))
.setdxDesc(std::get<AFTER_DRELU_TENSOR>(tensors))
.setpwDesc(actDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, act_op.describe());
// Create a Scale Node.
auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<AFTER_DRELU_TENSOR>(tensors))
.setbDesc(std::get<SCALE_TENSOR>(tensors))
.setyDesc(std::get<DGRAD_OPTIONAL_TENSOR>(tensors))
.setpwDesc(scaleDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, scale_op.describe());
// Create a Gen_Index Node.
auto genIndex_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<DGRAD_OPTIONAL_TENSOR>(tensors))
.setyDesc(std::get<DGRAD_GEN_INDEX_TENSOR>(tensors))
.setpwDesc(genIndexDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, genIndex_op.describe());
// Create a LessThan Node.
auto lessThan_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<DGRAD_GEN_INDEX_TENSOR>(tensors))
.setbDesc(std::get<DGRAD_THRESHOLD_TOP_TENSOR>(tensors))
.setyDesc(std::get<DGRAD_MASK_TOP_TENSOR>(tensors))
.setpwDesc(lessThanDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, lessThan_op.describe());
// Create a GreaterThan Node.
auto greaterThan_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<DGRAD_GEN_INDEX_TENSOR>(tensors))
.setbDesc(std::get<DGRAD_THRESHOLD_BOTTOM_TENSOR>(tensors))
.setyDesc(std::get<DGRAD_MASK_BOTTOM_TENSOR>(tensors))
.setpwDesc(greaterThanDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, greaterThan_op.describe());
// Create a LogicalOr Node.
auto logicalOr_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<DGRAD_MASK_TOP_TENSOR>(tensors))
.setbDesc(std::get<DGRAD_MASK_BOTTOM_TENSOR>(tensors))
.setyDesc(std::get<DGRAD_MASK_TENSOR>(tensors))
.setpwDesc(logicalOrDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, logicalOr_op.describe());
// Create a Binary_Selection Node.
auto selection_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<AFTER_DCONV_TENSOR>(tensors))
.setbDesc(std::get<DGRAD_OPTIONAL_TENSOR>(tensors))
.settDesc(std::get<DGRAD_MASK_TENSOR>(tensors))
.setyDesc(std::get<X_OR_DX_TENSOR>(tensors))
.setpwDesc(selectionDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, selection_op.describe());
// Create an Operation Graph. In this case it is dgrad drelu dscale with halo mask
std::array<cudnn_frontend::Operation const*, 8> ops = {&conv_op, &act_op, &scale_op, &genIndex_op, &lessThan_op, &greaterThan_op, &logicalOr_op, &selection_op};
auto opGraph = cudnn_frontend::OperationGraphBuilder()
.setHandle(handle_)
.setOperationGraph(ops.size(), ops.data())
.build();
// Create string encoding for plan caching
auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());
DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
auto workspace_size = plan.getWorkspaceSize();
DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
void* workspace_ptr = nullptr;
auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
if (workspace_size > 0) {
workspace_ptr = workspace_tensor.data_ptr<float>();
}
void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrR, devPtrT, devPtrU};
int64_t uids[] = {'x', 'y', 'w', 's', 'r', 't', 'u'};
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace_ptr)
.setDataPointers(7, data_ptrs)
.setUids(7, uids)
.build();
DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
checkCudnnErr(status);
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
} catch (cudnn_frontend::cudnnException &e) {
std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
}
}
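// Summary of the masked dgrad graph above (illustrative sketch only):
//   d    = relu_bwd(conv_dgrad(dy, w), r) * s
//   M[h] = (h < t) || (h > u)
//   dx   = binary_select(dconv_out, d, M)
// analogous to the forward mask path: interior rows and halo rows are resolved per-row by the
// gen_index/threshold mask.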
struct bottleneck_forward_status {
int64_t dimA[4];
int64_t filterdimA1[4];
int64_t filterdimA2[4];
int64_t filterdimA2hh[4];
int64_t filterdimA3[4];
int64_t filterdimA4[4];
int64_t threshdim[4];
int axis[4];
int64_t outdimA0[4];
int64_t outdimA1[4];
int64_t outdimA1b[4]; // out1_pad
int64_t outdimA2[4];
int64_t outdimA3[4];
int64_t outdimA4[4];
......@@ -1633,6 +2933,7 @@ struct bottleneck_forward_status {
int64_t outdim0[4]; // halo input shape
int64_t outdim1[4];
int64_t outdim1b[4];
int64_t outdim2[4];
int64_t outdim3[4];
int64_t outdim4[4]; // halo output shape
......@@ -1641,8 +2942,10 @@ struct bottleneck_forward_status {
dimA[0] = dimA[1] = dimA[2] = dimA[3] = 0;
filterdimA1[0] = filterdimA1[1] = filterdimA1[2] = filterdimA1[3] = 0;
filterdimA2[0] = filterdimA2[1] = filterdimA2[2] = filterdimA2[3] = 0;
filterdimA2hh[0] = filterdimA2hh[1] = filterdimA2hh[2] = filterdimA2hh[3] = 0;
filterdimA3[0] = filterdimA3[1] = filterdimA3[2] = filterdimA3[3] = 0;
filterdimA4[0] = filterdimA4[1] = filterdimA4[2] = filterdimA4[3] = 0;
threshdim[0] = threshdim[1] = threshdim[2] = threshdim[3] = 1;
// All dim calculation after this order of n,c,h,w
if (explicit_nhwc) {
......@@ -1668,10 +2971,18 @@ struct bottleneck_forward_status {
filterdimA4[dim] = inputs[10].size(axis[dim]);
}
}
for (int dim=0;dim<4;dim++) {
if (dim == 2) {
filterdimA2hh[dim] = 1;
} else {
filterdimA2hh[dim] = filterdimA2[dim];
}
}
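// filterdimA2hh is filterdimA2 with H collapsed to 1, i.e. the 1x3 slice of the 3x3 conv2
// filter used by the halo-correction path (bottleneck_forward_out2_halo_corr below).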
// output dim in n,c,h,w used by backend
outdimA0[0] = outdimA0[1] = outdimA0[2] = outdimA0[3] = 0;
outdimA1[0] = outdimA1[1] = outdimA1[2] = outdimA1[3] = 0;
outdimA1b[0] = outdimA1b[1] = outdimA1b[2] = outdimA1b[3] = 0;
outdimA2[0] = outdimA2[1] = outdimA2[2] = outdimA2[3] = 0;
outdimA3[0] = outdimA3[1] = outdimA3[2] = outdimA3[3] = 0;
outdimA4[0] = outdimA4[1] = outdimA4[2] = outdimA4[3] = 0;
......@@ -1690,6 +3001,13 @@ struct bottleneck_forward_status {
for (int dim = 0; dim < 2; dim++) {
outdimA1[dim + 2] = getFwdConvOutputDim(dimA[dim + 2], padA[dim], filterdimA1[dim + 2], convstride1X1[dim], dilationA[dim]);
}
for (int dim = 0; dim < 4; dim++) {
if (dim == 2) {
outdimA1b[dim] = outdimA1[dim] + 2;
} else {
outdimA1b[dim] = outdimA1[dim];
}
}
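// outdimA1b is outdimA1 grown by 2 in H: the shape of out1 with one halo row on each side,
// consumed by the *_pad variants below, which run conv2 with the H padding already
// materialized in the input.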
outdimA2[0] = outdimA1[0];
outdimA2[1] = filterdimA2[0];
......@@ -1715,6 +3033,7 @@ struct bottleneck_forward_status {
// Create output tensor in the correct shape in pytorch's view
outdim1[0] = outdim1[1] = outdim1[2] = outdim1[3] = 0;
outdim1b[0] = outdim1b[1] = outdim1b[2] = outdim1b[3] = 0;
outdim2[0] = outdim2[1] = outdim2[2] = outdim2[3] = 0;
outdim3[0] = outdim3[1] = outdim3[2] = outdim3[3] = 0;
if (explicit_nhwc) {
......@@ -1726,6 +3045,7 @@ struct bottleneck_forward_status {
for (int dim=0;dim<4;dim++) {
outdim0[dim] = outdimA0[axis[dim]];
outdim1[dim] = outdimA1[axis[dim]];
outdim1b[dim] = outdimA1b[axis[dim]];
outdim2[dim] = outdimA2[axis[dim]];
outdim3[dim] = outdimA3[axis[dim]];
outdim4[dim] = outdimA4[axis[dim]];
......@@ -1821,6 +3141,41 @@ at::Tensor bottleneck_forward_out2_halo(bool explicit_nhwc, at::Tensor fat_halo_
return halo_y2;
}
// Compute the halo correction term (top or bottom) from a slim halo input of shape (N,C,1,W),
// i.e. a single row in H.
at::Tensor bottleneck_forward_out2_halo_corr(bool explicit_nhwc, at::Tensor slim_halo_y1, std::vector<at::Tensor> inputs, at::Tensor w1by3, at::Tensor out2_part_halo) {
auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
// run
at::Half* w = w1by3.data_ptr<at::Half>(); // C,C,1,3
at::Half* z = inputs[5].data_ptr<at::Half>();
at::Half* b = inputs[8].data_ptr<at::Half>();
at::Half* y1 = slim_halo_y1.data_ptr<at::Half>();
at::Half* prev_out2 = out2_part_halo.data_ptr<at::Half>();
auto halo_y2 = at::empty(forward_state.outdim4, inputs[0].type(), output_format);
at::Half* y2 = halo_y2.data_ptr<at::Half>();
run_conv_add_scale_bias_activation(forward_state.outdimA4,
forward_state.padA2,
forward_state.convstrideA,
forward_state.dilationA,
forward_state.filterdimA2hh,
forward_state.outdimA4,
CUDNN_DATA_HALF,
y1,
w,
y2,
z,
b,
prev_out2);
return halo_y2;
}
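// Rough intent of the correction above (an assumption based on the argument shapes, not stated
// elsewhere in this file): w1by3 is a single-row (H == 1) slice of the 3x3 conv2 weights, so the
// call computes
//   halo_y2 = relu((conv1x3(slim_halo_y1, w1by3) + out2_part_halo) * z + b)
// i.e. the neighbor's boundary row supplies the term that the locally computed partial halo
// output (out2_part_halo) could not include.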
void bottleneck_forward_out2(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs) {
std::cout << std::fixed;
......@@ -1859,6 +3214,86 @@ void bottleneck_forward_out2(bool explicit_nhwc, int stride_1X1, std::vector<at:
DEBUG_MSG("[DEBUG] new relu2 : " << out2.to(at::kFloat).sum().item<float>());
}
void bottleneck_forward_out2_mask(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor thresholdTop, at::Tensor thresholdBottom) {
std::cout << std::fixed;
// from _out1 method
at::Half* x = inputs[0].data_ptr<at::Half>();
auto out1 = outputs[0];
at::Half* y1 = out1.data_ptr<at::Half>();
// run
at::Half* w = inputs[2].data_ptr<at::Half>();
at::Half* z = inputs[5].data_ptr<at::Half>();
at::Half* b = inputs[8].data_ptr<at::Half>();
auto out2 = outputs[1];
at::Half* y2 = out2.data_ptr<at::Half>();
//printf("forward_state.outdimA1 = {%d,%d,%d,%d}\n",forward_state.outdimA1[0],forward_state.outdimA1[1],forward_state.outdimA1[2],forward_state.outdimA1[3]);
//printf("forward_state.padA1 = {%d,%d}\n",forward_state.padA1[0],forward_state.padA1[1]);
//printf("forward_state.convstrideA = {%d,%d}\n",forward_state.convstrideA[0],forward_state.convstrideA[1]);
//printf("forward_state.dilationA = {%d,%d}\n",forward_state.dilationA[0],forward_state.dilationA[1]);
//printf("forward_state.filterdimA2 = {%d,%d,%d,%d}\n",forward_state.filterdimA2[0],forward_state.filterdimA2[1],forward_state.filterdimA2[2],forward_state.filterdimA2[3]);
//printf("forward_state.outdimA2 = {%d,%d,%d,%d}\n",forward_state.outdimA2[0],forward_state.outdimA2[1],forward_state.outdimA2[2],forward_state.outdimA2[3]);
run_conv_scale_bias_add_activation_mask(forward_state.outdimA1,
forward_state.padA1,
forward_state.convstrideA,
forward_state.dilationA,
forward_state.filterdimA2,
forward_state.outdimA2,
forward_state.threshdim,
CUDNN_DATA_HALF,
y1,
w,
y2,
z,
b,
nullptr,
thresholdTop.data_ptr<int>(),
thresholdBottom.data_ptr<int>(),
2); // axis 2 == H in the backend's n,c,h,w dim order -> does this assume explicit NHWC?
DEBUG_MSG("[DEBUG] new relu2 : " << out2.to(at::kFloat).sum().item<float>());
}
void bottleneck_forward_out2_pad(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor out1_pad) {
std::cout << std::fixed;
// from _out1 method
at::Half* x = inputs[0].data_ptr<at::Half>();
auto out1 = outputs[0];
at::Half* y1 = out1_pad.data_ptr<at::Half>();
// run
at::Half* w = inputs[2].data_ptr<at::Half>();
at::Half* z = inputs[5].data_ptr<at::Half>();
at::Half* b = inputs[8].data_ptr<at::Half>();
auto out2 = outputs[1];
at::Half* y2 = out2.data_ptr<at::Half>();
//printf("forward_state.outdimA1 = {%d,%d,%d,%d}\n",forward_state.outdimA1[0],forward_state.outdimA1[1],forward_state.outdimA1[2],forward_state.outdimA1[3]);
//printf("forward_state.padA1 = {%d,%d}\n",forward_state.padA1[0],forward_state.padA1[1]);
//printf("forward_state.convstrideA = {%d,%d}\n",forward_state.convstrideA[0],forward_state.convstrideA[1]);
//printf("forward_state.dilationA = {%d,%d}\n",forward_state.dilationA[0],forward_state.dilationA[1]);
//printf("forward_state.filterdimA2 = {%d,%d,%d,%d}\n",forward_state.filterdimA2[0],forward_state.filterdimA2[1],forward_state.filterdimA2[2],forward_state.filterdimA2[3]);
//printf("forward_state.outdimA2 = {%d,%d,%d,%d}\n",forward_state.outdimA2[0],forward_state.outdimA2[1],forward_state.outdimA2[2],forward_state.outdimA2[3]);
run_conv_scale_bias_add_activation(forward_state.outdimA1b,
forward_state.padA2,
forward_state.convstrideA,
forward_state.dilationA,
forward_state.filterdimA2,
forward_state.outdimA2,
CUDNN_DATA_HALF,
y1,
w,
y2,
z,
b,
nullptr);
DEBUG_MSG("[DEBUG] new relu2 : " << out2.to(at::kFloat).sum().item<float>());
}
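// The _pad variant above differs from the plain out2 path (cf. the _mask variant, which uses
// outdimA1/padA1) in that it reads conv2's input from out1_pad (shape outdimA1b, i.e. out1 with
// halo rows already filled in along H) and uses padA2, so no implicit zero padding is applied
// in H.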
void bottleneck_forward_rest(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs) {
std::cout << std::fixed;
......@@ -1932,10 +3367,12 @@ struct bottleneck_backward_state {
int64_t filterdimA3[4];
int64_t filterdimA4[4];
int64_t filterdimA2hh[4]; // Cin,Cout,1,3
int64_t threshdim[4];
int axis[4];
int64_t outdimA1[4]; // grad_out1
int64_t outdimA1b[4]; // out1_pad
int64_t outdimA2[4]; // grad_out2
int64_t outdimA3[4];
int64_t outdimA1h[4]; // output: grad_out1 halo (H=3)
......@@ -1953,9 +3390,11 @@ struct bottleneck_backward_state {
int64_t filterdim2hh[4]; // Cin,1,3,Cout
int64_t outdim1[4];
int64_t outdim1b[4];
int64_t outdim2[4];
int64_t outdim3[4];
int64_t outdim1h[4];
int64_t outdim1hh[4];
void init(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs) {
// setup dimensions
......@@ -1965,6 +3404,7 @@ struct bottleneck_backward_state {
filterdimA3[0] = filterdimA3[1] = filterdimA3[2] = filterdimA3[3] = 0;
filterdimA4[0] = filterdimA4[1] = filterdimA4[2] = filterdimA4[3] = 0;
filterdimA2hh[0] = filterdimA2hh[1] = filterdimA2hh[2] = filterdimA2hh[3] = 0;
threshdim[0] = threshdim[1] = threshdim[2] = threshdim[3] = 1;
// All dim calculation after this order of n,c,h,w
if (explicit_nhwc) {
......@@ -2001,6 +3441,7 @@ struct bottleneck_backward_state {
// output dim in n,c,h,w used by backend
outdimA1[0] = outdimA1[1] = outdimA1[2] = outdimA1[3] = 0;
outdimA1b[0] = outdimA1b[1] = outdimA1b[2] = outdimA1b[3] = 0;
outdimA2[0] = outdimA2[1] = outdimA2[2] = outdimA2[3] = 0;
outdimA3[0] = outdimA3[1] = outdimA3[2] = outdimA3[3] = 0;
outdimA1h[0] = outdimA1h[1] = outdimA1h[2] = outdimA1h[3] = 0;
......@@ -2022,6 +3463,13 @@ struct bottleneck_backward_state {
for (int dim = 0; dim < 2; dim++) {
outdimA1[dim + 2] = getFwdConvOutputDim(dimA[dim + 2], padA[dim], filterdimA1[dim + 2], convstride1X1[dim], dilationA[dim]);
}
for (int dim = 0; dim < 4; dim++) {
if (dim == 2) {
outdimA1b[dim] = outdimA1[dim] + 2;
} else {
outdimA1b[dim] = outdimA1[dim];
}
}
outdimA2[0] = outdimA1[0];
outdimA2[1] = filterdimA2[0];
......@@ -2051,9 +3499,11 @@ struct bottleneck_backward_state {
// Create output tensor in the correct shape in pytorch's view
outdim1[0] = outdim1[1] = outdim1[2] = outdim1[3] = 0;
outdim1b[0] = outdim1b[1] = outdim1b[2] = outdim1b[3] = 0;
outdim2[0] = outdim2[1] = outdim2[2] = outdim2[3] = 0;
outdim3[0] = outdim3[1] = outdim3[2] = outdim3[3] = 0;
outdim1h[0] = outdim1h[1] = outdim1h[2] = outdim1h[3] = 0;
outdim1hh[0] = outdim1hh[1] = outdim1hh[2] = outdim1hh[3] = 0;
filterdim2hh[0] = filterdim2hh[1] = filterdim2hh[2] = filterdim2hh[3] = 0;
if (explicit_nhwc) {
axis[0] = 0;
......@@ -2063,9 +3513,11 @@ struct bottleneck_backward_state {
}
for (int dim=0;dim<4;dim++) {
outdim1[dim] = outdimA1[axis[dim]];
outdim1b[dim] = outdimA1b[axis[dim]];
outdim2[dim] = outdimA2[axis[dim]];
outdim3[dim] = outdimA3[axis[dim]];
outdim1h[dim] = outdimA1h[axis[dim]];
outdim1hh[dim] = outdimA1hh[axis[dim]];
filterdim2hh[dim] = filterdimA2hh[axis[dim]];
}
}
......@@ -2102,19 +3554,12 @@ std::vector<at::Tensor> bottleneck_backward_init(bool explicit_nhwc, int stride_
return outputs;
}
at::Tensor bottleneck_backward_grad_out2(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs) {
bool requires_grad = inputs[0].requires_grad();
std::cout << std::fixed;
auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
void bottleneck_backward_wgrad3(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs) {
// dconv3+drelu2+dscale2
at::Half* conv_in = inputs[13].data_ptr<at::Half>();
at::Half* dy3 = inputs[10].data_ptr<at::Half>();
DEBUG_MSG("[DEBUG] new dconv3 : " << inputs[10].to(at::kFloat).sum().item<float>());
// wgrad
auto wgrad3 = outputs[3];
at::Half* dw3 = wgrad3.data_ptr<at::Half>();
......@@ -2129,6 +3574,22 @@ at::Tensor bottleneck_backward_grad_out2(bool explicit_nhwc, int stride_1X1, std
dw3,
dy3,
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
DEBUG_MSG("[DEBUG] new wgrad3 : " << wgrad3.to(at::kFloat).sum().item<float>());
}
at::Tensor bottleneck_backward_grad_out2(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs) {
bool requires_grad = inputs[0].requires_grad();
std::cout << std::fixed;
auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
// dconv3+drelu2+dscale2
at::Half* conv_in = inputs[13].data_ptr<at::Half>();
at::Half* dy3 = inputs[10].data_ptr<at::Half>();
DEBUG_MSG("[DEBUG] new dconv3 : " << inputs[10].to(at::kFloat).sum().item<float>());
// dgrad
auto grad_out2 = at::empty(backward_state.outdim2, inputs[0].type(), output_format);
......@@ -2178,6 +3639,7 @@ at::Tensor bottleneck_backward_grad_out1(bool explicit_nhwc, int stride_1X1, std
//printf("relu.shape = [%d,%d,%d,%d]\n",inputs[12].size(0),inputs[12].size(1),inputs[12].size(2),inputs[12].size(3));
// fused dgrad
//printf("backward_state.outdim1 = {%d,%d,%d,%d}\n",backward_state.outdim1[0],backward_state.outdim1[1],backward_state.outdim1[2],backward_state.outdim1[3]);
run_dconv_drelu_dscale(backward_state.outdimA1,
backward_state.padA1,
backward_state.convstrideA,
......@@ -2194,6 +3656,88 @@ at::Tensor bottleneck_backward_grad_out1(bool explicit_nhwc, int stride_1X1, std
return grad_out1;
}
at::Tensor bottleneck_backward_grad_out1_mask(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2, at::Tensor thresholdTop, at::Tensor thresholdBottom) {
bool requires_grad = inputs[0].requires_grad();
std::cout << std::fixed;
auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
// dgrad
at::Half* dy2 = grad_out2.data_ptr<at::Half>();
// dgrad
auto grad_out1 = at::empty(backward_state.outdim1, inputs[0].type(), output_format);
at::Half* dy1 = grad_out1.data_ptr<at::Half>();
at::Half* w = inputs[2].data_ptr<at::Half>();
at::Half* z = inputs[4].data_ptr<at::Half>();
at::Half* relu1 = inputs[12].data_ptr<at::Half>();
//printf("relu.shape = [%d,%d,%d,%d]\n",inputs[12].size(0),inputs[12].size(1),inputs[12].size(2),inputs[12].size(3));
// fused dgrad
run_dconv_drelu_dscale_mask(backward_state.outdimA1,
backward_state.padA1,
backward_state.convstrideA,
backward_state.dilationA,
backward_state.filterdimA2,
backward_state.outdimA2,
backward_state.threshdim,
CUDNN_DATA_HALF,
dy1,
w,
dy2,
z,
relu1,
thresholdTop.data_ptr<int>(),
thresholdBottom.data_ptr<int>(),
2);
return grad_out1;
}
// perform backward data 1x3 convolution (grad_out * w_rot180) on grad_out2 input of shape [N,1,W,C] with padding=(0,1) to produce output of shape [N,1,W,C]
at::Tensor bottleneck_backward_grad_out1_halo_corr(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, at::Tensor w1by3, std::vector<at::Tensor> outputs, at::Tensor grad_out2_halo, at::Tensor relu1_halo, at::Tensor part_grad_out1) {
bool requires_grad = inputs[0].requires_grad();
std::cout << std::fixed;
auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
// dgrad
at::Half* dy2h = grad_out2_halo.data_ptr<at::Half>();
// dgrad
auto grad_out1_halo = at::empty(backward_state.outdim1hh, inputs[0].type(), output_format);
at::Half* dy1h = grad_out1_halo.data_ptr<at::Half>();
//at::Half* w = inputs[2].data_ptr<at::Half>(); // use w1by3 instead, which is a sliced version of inputs[2]
at::Half* w = w1by3.data_ptr<at::Half>();
at::Half* z = inputs[4].data_ptr<at::Half>();
at::Half* relu1h = relu1_halo.data_ptr<at::Half>();
at::Half* pdy1h = part_grad_out1.data_ptr<at::Half>();
//printf("relu.shape = [%d,%d,%d,%d]\n",relu1_halo.size(0),relu1_halo.size(1),relu1_halo.size(2),relu1_halo.size(3));
// fused dgrad
//printf("backward_state.outdimA1h = {%d,%d,%d,%d}\n",backward_state.outdimA1h[0],backward_state.outdimA1h[1],backward_state.outdimA1h[2],backward_state.outdimA1h[3]);
//printf("backward_state.outdimA2h = {%d,%d,%d,%d}\n",backward_state.outdimA2h[0],backward_state.outdimA2h[1],backward_state.outdimA2h[2],backward_state.outdimA2h[3]);
//printf("backward_state.filterdimA2 = {%d,%d,%d,%d}\n",backward_state.filterdimA2[0],backward_state.filterdimA2[1],backward_state.filterdimA2[2],backward_state.filterdimA2[3]);
run_dconv_add_drelu_dscale(backward_state.outdimA1hh,
backward_state.padA2, // 0,1
backward_state.convstrideA,
backward_state.dilationA,
backward_state.filterdimA2hh, // C,1,3,C
backward_state.outdimA2hh,
CUDNN_DATA_HALF,
dy1h,
w,
dy2h,
z,
relu1h,
pdy1h);
return grad_out1_halo;
}
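// Rough intent of the backward correction above (an assumption based on the shapes involved):
//   grad_out1_halo = relu_bwd(dgrad1x3(grad_out2_halo, w1by3) + part_grad_out1, relu1_halo) * z
// mirroring the forward halo correction: the neighbor's boundary gradient row supplies the term
// missing from the locally computed partial grad_out1.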
// perform backward data 3x3 convolution (grad_out * w_rot180) on grad_out2 input of shape [N,3,W,C] with padding=(1,1) to produce output of shape [N,3,W,C]
at::Tensor bottleneck_backward_grad_out1_halo(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2_halo, at::Tensor relu1_halo) {
......@@ -2233,7 +3777,38 @@ at::Tensor bottleneck_backward_grad_out1_halo(bool explicit_nhwc, int stride_1X1
return grad_out1_halo;
}
at::Tensor bottleneck_backward_wgrad2(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2) {
void bottleneck_backward_wgrad2_pad(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor input, at::Tensor grad_out2) {
std::cout << std::fixed;
auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
// dgrad
at::Half* dy2 = grad_out2.data_ptr<at::Half>();
// dconv2+drelu1+dscale1
at::Half* conv_in = input.data_ptr<at::Half>();
// wgrad
auto wgrad2 = outputs[2];
at::Half* dw2 = wgrad2.data_ptr<at::Half>();
//printf("outdimA1b = (%d,%d,%d,%d)\n",backward_state.outdimA1b[0],backward_state.outdimA1b[1],backward_state.outdimA1b[2],backward_state.outdimA1b[3]);
//printf("backward_state.padA2 = {%d,%d}\n",backward_state.padA2[0],backward_state.padA2[1]);
run_dconv(backward_state.outdimA1b, // conv_in.shape (including H halos)
backward_state.padA2, // 0, 1
backward_state.convstrideA,
backward_state.dilationA,
backward_state.filterdimA2, // dw2.shape
backward_state.outdimA2, // dy2.shape
CUDNN_DATA_HALF,
conv_in,
dw2,
dy2,
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
DEBUG_MSG("[DEBUG] new wgrad2 : " << wgrad2.to(at::kFloat).sum().item<float>());
}
void bottleneck_backward_wgrad2(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2) {
bool requires_grad = inputs[0].requires_grad();
......@@ -2262,8 +3837,7 @@ at::Tensor bottleneck_backward_wgrad2(bool explicit_nhwc, int stride_1X1, std::v
dw2,
dy2,
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
return wgrad2;
DEBUG_MSG("[DEBUG] new wgrad2 : " << wgrad2.to(at::kFloat).sum().item<float>());
}
// compute halo cells for input volume of dimension [N,1,W,C] with padding=(0,1) to produce output volume of dimension [N,1,W,C]
......@@ -2306,7 +3880,30 @@ at::Tensor bottleneck_backward_wgrad2_halo(bool explicit_nhwc, int stride_1X1, s
return wgrad2_halo;
}
void bottleneck_backward_rest(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2, at::Tensor grad_out1, at::Tensor wgrad2) {
void bottleneck_backward_wgrad1(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out1) {
at::Half* x = inputs[0].data_ptr<at::Half>();
at::Half* dy1 = grad_out1.data_ptr<at::Half>();
// dconv1+add
// wgrad
auto wgrad1 = outputs[1];
at::Half* dw1 = wgrad1.data_ptr<at::Half>();
run_dconv(backward_state.dimA,
backward_state.padA,
backward_state.convstride1X1,
backward_state.dilationA,
backward_state.filterdimA1,
backward_state.outdimA1,
CUDNN_DATA_HALF,
x,
dw1,
dy1,
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
}
void bottleneck_backward_rest(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2, at::Tensor grad_out1) {
bool requires_grad = inputs[0].requires_grad();
......@@ -2404,22 +4001,6 @@ void bottleneck_backward_rest(bool explicit_nhwc, int stride_1X1, std::vector<at
dx_conv4 = inputs[11].data_ptr<at::Half>();
}
// dconv1+add
// wgrad
auto wgrad1 = outputs[1];
at::Half* dw1 = wgrad1.data_ptr<at::Half>();
run_dconv(backward_state.dimA,
backward_state.padA,
backward_state.convstride1X1,
backward_state.dilationA,
backward_state.filterdimA1,
backward_state.outdimA1,
CUDNN_DATA_HALF,
x,
dw1,
dy1,
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
// dgrad
w = inputs[1].data_ptr<at::Half>();
auto grad_x = outputs[0];
......@@ -2460,8 +4041,6 @@ void bottleneck_backward_rest(bool explicit_nhwc, int stride_1X1, std::vector<at
DEBUG_MSG("[DEBUG] new dx : " << grad_x.to(at::kFloat).sum().item<float>());
DEBUG_MSG("[DEBUG] new wgrad1 : " << wgrad1.to(at::kFloat).sum().item<float>());
DEBUG_MSG("[DEBUG] new wgrad2 : " << wgrad2.to(at::kFloat).sum().item<float>());
DEBUG_MSG("[DEBUG] new wgrad3 : " << wgrad3.to(at::kFloat).sum().item<float>());
if (stride_1X1 != 1 || backward_state.filterdimA3[0] != backward_state.dimA[1]) {
DEBUG_MSG("[DEBUG] new wgrad4 : " << wgrad4.to(at::kFloat).sum().item<float>());
......@@ -2474,13 +4053,21 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward_init", &bottleneck_forward_init, "Bottleneck block init");
m.def("forward_out1", &bottleneck_forward_out1, "Bottleneck block forward");
m.def("forward_out2", &bottleneck_forward_out2, "Bottleneck block forward");
m.def("forward_out2_mask", &bottleneck_forward_out2_mask, "Bottleneck block forward");
m.def("forward_out2_halo", &bottleneck_forward_out2_halo, "Bottleneck block forward");
m.def("forward_out2_halo_corr", &bottleneck_forward_out2_halo_corr, "Bottleneck block forward");
m.def("forward_out2_pad", &bottleneck_forward_out2_pad, "Bottleneck block forward");
m.def("forward_rest", &bottleneck_forward_rest, "Bottleneck block forward");
m.def("backward_init", &bottleneck_backward_init, "Bottleneck block backward init");
m.def("backward_grad_out2", &bottleneck_backward_grad_out2, "Bottleneck block backward");
m.def("backward_grad_out1", &bottleneck_backward_grad_out1, "Bottleneck block backward");
m.def("backward_grad_out1_mask", &bottleneck_backward_grad_out1_mask, "Bottleneck block backward");
m.def("backward_grad_out1_halo", &bottleneck_backward_grad_out1_halo, "Bottleneck block backward");
m.def("backward_grad_out1_halo_corr", &bottleneck_backward_grad_out1_halo_corr, "Bottleneck block backward");
m.def("backward_wgrad2_pad", &bottleneck_backward_wgrad2_pad, "Bottleneck block backward");
m.def("backward_wgrad2", &bottleneck_backward_wgrad2, "Bottleneck block backward");
m.def("backward_wgrad2_halo", &bottleneck_backward_wgrad2_halo, "Bottleneck block backward");
m.def("backward_wgrad3", &bottleneck_backward_wgrad3, "Bottleneck block backward");
m.def("backward_wgrad1", &bottleneck_backward_wgrad1, "Bottleneck block backward");
m.def("backward_rest", &bottleneck_backward_rest, "Bottleneck block backward");
}
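For reference, the dgrad portion of bottleneck_backward_grad_out1_halo_corr above can be written with plain PyTorch ops. This is a hedged sketch only, not the fused cuDNN path: it assumes NCHW tensors, stride 1, and a pre-sliced weight of shape [C_out, C_in, 1, 3] (the w1by3 argument), and it omits the fused drelu/dscale epilogue and the part_grad_out1 accumulation.
import torch
import torch.nn.functional as F

def dgrad_1by3_reference(grad_out2_halo: torch.Tensor, w1by3: torch.Tensor) -> torch.Tensor:
    # grad_out2_halo: [N, C_out, 1, W] (NCHW); w1by3: [C_out, C_in, 1, 3].
    # Backward-data of a stride-1 convolution is a transposed convolution with the
    # same weight, i.e. correlation with the 180-degree rotated filter.
    return F.conv_transpose2d(grad_out2_halo, w1by3, stride=1, padding=(0, 1))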
/**
* Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nccl_p2p_cuda.cuh"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("get_unique_nccl_id", &apex::contrib::nccl_p2p::get_unique_nccl_id, "get_unique_nccl_id");
m.def("init_nccl_comm", &apex::contrib::nccl_p2p::init_nccl_comm, "init_nccl_comm");
m.def("nccl_send", &apex::contrib::nccl_p2p::nccl_send, "nccl_send");
m.def("nccl_recv", &apex::contrib::nccl_p2p::nccl_recv, "nccl_recv");
m.def("left_right_halo_exchange_inplace", &apex::contrib::nccl_p2p::left_right_halo_exchange_inplace, "left_right_halo_exchange_inplace");
m.def("left_right_halo_exchange", &apex::contrib::nccl_p2p::left_right_halo_exchange, "left_right_halo_exchange");
m.def("add_delay", &apex::contrib::nccl_p2p::add_delay, "add_delay");
}
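A minimal usage sketch for the bindings above, assuming one process per GPU launched via torchrun; the tensor shapes and the broadcast of the unique id from rank 0 are illustrative rather than prescribed by the extension.
import torch
import torch.distributed as dist
import nccl_p2p_cuda as inc

dist.init_process_group("nccl")
rank, world_size = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank)

# every rank must build its communicator from the same unique id bytes
unique_id = inc.get_unique_nccl_id(1).cuda()
dist.broadcast(unique_id, 0)
handle = inc.init_nccl_comm(unique_id.cpu(), rank, world_size)

# exchange halos with the previous/next rank in the group; left_zero/right_zero
# control whether the unexchanged side is zero-filled at the ends of the group
left_out = torch.randn(1, 1, 64, 128, dtype=torch.float16, device="cuda")
right_out = torch.randn_like(left_out)
left_in, right_in = inc.left_right_halo_exchange(handle, False, False, left_out, right_out, world_size)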
#include <torch/extension.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <ATen/cuda/CUDAContext.h>
#include <list>
#include <cstdio>
#include <ctime>
#include <cassert>
#include "nccl.h"
/*
* This file implements a crude but effective mechanism for copying data between tensors owned by different ranks
* on the same machine using NCCL point-to-point (ncclSend/ncclRecv) transfers.
*/
namespace {
__global__ void AddDelay_kernel(const int delay, int* counter) {
if (blockIdx.x == 0 && threadIdx.x == 0) {
// waste time while doing something the compiler can't predict, preventing it from optimizing this code away.
int new_counter = 0;
double elapsed = 0;
clock_t start = clock();
do {
clock_t now = clock();
elapsed = (double)(now - start)*1e9 / CLOCKS_PER_SEC;
++new_counter;
} while (elapsed < (double)delay);
*counter = new_counter;
}
}
class NcclCommWrapper
{
private:
ncclComm_t comm;
int rank, world_size;
ncclDataType_t get_nccl_type(at::Tensor input)
{
switch (input.scalar_type())
{
case at::ScalarType::Half:
return ncclFloat16;
case at::ScalarType::Float:
return ncclFloat32;
case at::ScalarType::Double:
return ncclFloat64;
case at::ScalarType::Byte:
return ncclUint8;
case at::ScalarType::Char:
return ncclInt8;
case at::ScalarType::Int:
return ncclInt32;
case at::ScalarType::Long:
return ncclInt64;
case at::ScalarType::BFloat16:
return ncclBfloat16;
default:
assert(false);
}
}
public:
NcclCommWrapper()
{
memset(&comm, 0, sizeof(ncclComm_t));
rank = 0;
world_size = 0;
}
NcclCommWrapper(ncclUniqueId id, int my_rank, int num_ranks)
{
ncclCommInitRank(&comm, num_ranks, id, my_rank);
rank = my_rank;
world_size = num_ranks;
}
~NcclCommWrapper()
{
printf("ncclCommDestroy()\n");
ncclCommDestroy(comm);
}
void send(at::Tensor input, int destination)
{
ncclDataType_t ncclType = get_nccl_type(input);
AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, input.scalar_type(), "nccl_send", [&]() {
size_t count = torch::numel(input); // NCCL counts are in elements of ncclType, not bytes
auto input_ptr = input.data_ptr<scalar_t>();
ncclSend(input_ptr, count, ncclType, destination, comm, at::cuda::getCurrentCUDAStream());
});
}
void recv(at::Tensor input, int sender)
{
ncclDataType_t ncclType = get_nccl_type(input);
AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, input.scalar_type(), "nccl_recv", [&]() {
size_t count = torch::numel(input); // NCCL counts are in elements of ncclType, not bytes
auto input_ptr = input.data_ptr<scalar_t>();
ncclRecv(input_ptr, count, ncclType, sender, comm, at::cuda::getCurrentCUDAStream());
});
}
void left_right_halo_exchange_inplace(bool left_zero, bool right_zero, at::Tensor left_output_halo, at::Tensor right_output_halo, at::Tensor left_input_halo, at::Tensor right_input_halo, int group_size)
{
auto stream = at::cuda::getCurrentCUDAStream();
ncclGroupStart();
ncclDataType_t ncclType = get_nccl_type(left_output_halo);
// neighbor ranks are computed with wrap-around, but the first rank in a group skips the left
// exchange and the last rank skips the right exchange; the corresponding input halo can be
// zero-filled instead via left_zero / right_zero.
int group_rank = rank % group_size;
int group_index = rank / group_size;
int prev_rank = (group_rank + group_size - 1) % group_size;
int next_rank = (group_rank + 1) % group_size;
prev_rank = prev_rank + group_index * group_size;
next_rank = next_rank + group_index * group_size;
size_t left_n = torch::numel(left_output_halo);
size_t right_n = torch::numel(right_output_halo);
if (group_rank > 0) {
AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, left_output_halo.scalar_type(), "left_halo_exch", [&]() {
// send left (to my_rank - 1)
ncclSend(left_output_halo.data_ptr<scalar_t>(), left_n, ncclType, prev_rank, comm, stream);
// receive left (from my_rank - 1)
ncclRecv(left_input_halo.data_ptr<scalar_t>(), right_n, ncclType, prev_rank, comm, stream);
});
}
if (group_rank < group_size-1) {
AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, right_output_halo.scalar_type(), "right_halo_exch", [&]() {
// send right (to my_rank + 1 )
ncclSend(right_output_halo.data_ptr<scalar_t>(), right_n, ncclType, next_rank, comm, stream);
// receive right (from my_rank + 1)
ncclRecv(right_input_halo.data_ptr<scalar_t>(), left_n, ncclType, next_rank, comm, stream);
});
}
ncclGroupEnd();
if (left_zero) left_input_halo.zero_();
if (right_zero) right_input_halo.zero_();
}
std::vector<at::Tensor> left_right_halo_exchange(bool left_zero, bool right_zero, at::Tensor left_output_halo, at::Tensor right_output_halo, int group_size)
{
// after halo exchange:
// left_output_halo of rank+1 ends up in right_input_halo of rank
// right_output_halo of rank-1 ends up in left_input_halo of rank
auto right_input_halo = torch::empty_like(left_output_halo);
auto left_input_halo = torch::empty_like(right_output_halo);
left_right_halo_exchange_inplace(left_zero, right_zero, left_output_halo, right_output_halo, left_input_halo, right_input_halo, group_size);
return {left_input_halo, right_input_halo};
}
};
std::vector<NcclCommWrapper> nccl_comms;
} // end anonymous namespace
namespace apex { namespace contrib { namespace nccl_p2p {
at::Tensor get_unique_nccl_id(int n)
{
ncclUniqueId id;
ncclGetUniqueId(&id);
auto id_tensor = torch::empty({n*(int)sizeof(ncclUniqueId)}, torch::dtype(torch::kUInt8).device(torch::kCPU).requires_grad(false));
auto id_ptr = id_tensor.data_ptr<uint8_t>();
size_t offset = 0;
for (int i = 0; i < n; ++i)
{
ncclUniqueId id;
ncclGetUniqueId(&id);
memcpy(id_ptr+offset, &id, sizeof(ncclUniqueId));
offset += sizeof(ncclUniqueId);
}
return id_tensor;
}
int init_nccl_comm(at::Tensor unique_nccl_id, int my_rank, int num_ranks)
{
ncclUniqueId id;
auto unique_nccl_id_ptr = unique_nccl_id.data_ptr<uint8_t>();
memcpy(&id, unique_nccl_id_ptr, sizeof(ncclUniqueId));
NcclCommWrapper* comm = new NcclCommWrapper(id, my_rank, num_ranks);
int handle = nccl_comms.size();
nccl_comms.push_back(*comm);
comm = 0L;
return handle;
}
void nccl_send(int handle, at::Tensor input, int destination)
{
assert(handle >= 0 && handle < nccl_comms.size());
class NcclCommWrapper& communicator = nccl_comms[handle]; // take a reference: a by-value copy would destroy the shared communicator in its destructor
communicator.send(input, destination);
}
void nccl_recv(int handle, at::Tensor input, int sender)
{
assert(handle >= 0 && handle < nccl_comms.size());
class NcclCommWrapper& communicator = nccl_comms[handle]; // take a reference: a by-value copy would destroy the shared communicator in its destructor
communicator.recv(input, sender);
}
void left_right_halo_exchange_inplace(int handle, bool left_zero, bool right_zero, at::Tensor left_output_halo, at::Tensor right_output_halo, at::Tensor left_input_halo, at::Tensor right_input_halo, int group_size)
{
assert(handle >= 0 && handle < nccl_comms.size());
class NcclCommWrapper& communicator = nccl_comms[handle];
return communicator.left_right_halo_exchange_inplace(left_zero, right_zero, left_output_halo, right_output_halo, left_input_halo, right_input_halo, group_size);
}
std::vector<at::Tensor> left_right_halo_exchange(int handle, bool left_zero, bool right_zero, at::Tensor left_output_halo, at::Tensor right_output_halo, int group_size)
{
assert(handle >= 0 && handle < nccl_comms.size());
class NcclCommWrapper& communicator = nccl_comms[handle];
return communicator.left_right_halo_exchange(left_zero, right_zero, left_output_halo, right_output_halo, group_size);
}
void add_delay(int delay)
{
auto stream = at::cuda::getCurrentCUDAStream();
auto t = torch::empty({1}, torch::dtype(torch::kInt32).device(torch::kCUDA));
AddDelay_kernel<<<1,1,0,stream>>>(delay, t.data_ptr<int>());
}
}}}
/**
* Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <torch/extension.h>
#ifndef _nccl_p2p_h_
#define _nccl_p2p_h_
namespace apex { namespace contrib { namespace nccl_p2p {
at::Tensor get_unique_nccl_id(int n);
int init_nccl_comm(
at::Tensor unique_nccl_id,
int my_rank,
int num_ranks
);
void nccl_send(
int handle,
at::Tensor input,
int destination
);
void nccl_recv(
int handle,
at::Tensor input,
int sender
);
void left_right_halo_exchange_inplace(
int handle,
bool left_zero,
bool right_zero,
at::Tensor left_output_halo,
at::Tensor right_output_halo,
at::Tensor left_input_halo,
at::Tensor right_input_halo,
int group_size);
std::vector<at::Tensor> left_right_halo_exchange(
int handle,
bool left_zero,
bool right_zero,
at::Tensor left_output_halo,
at::Tensor right_output_halo,
int group_size
);
void add_delay(int delay);
}}}
#endif
/**
* Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "peer_memory_cuda.cuh"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("allocate_raw", &apex::contrib::peer_memory::allocate_raw, "allocate_raw");
m.def("free_raw", &apex::contrib::peer_memory::free_raw, "free_raw");
m.def("zero", &apex::contrib::peer_memory::zero, "zero");
m.def("get_raw_ipc_address", &apex::contrib::peer_memory::get_raw_ipc_address, "get_raw_ipc_address");
m.def("get_raw_peers", &apex::contrib::peer_memory::get_raw_peers, "get_raw_peers");
m.def("blob_view_half", &apex::contrib::peer_memory::blob_view_half, "blob_view_half");
m.def("blob_view_float", &apex::contrib::peer_memory::blob_view_float, "blob_view_float");
m.def("blob_view_int", &apex::contrib::peer_memory::blob_view_int, "blob_view_int");
m.def("push_pull_halos_1d", &apex::contrib::peer_memory::push_pull_halos_1d, "push_pull_halos_1d");
}
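A hedged sketch of how the raw primitives above fit together; it mirrors what PeerMemoryPool does later in this change and assumes an initialized torch.distributed process group with one process per GPU on a single node (sizes and shapes are illustrative).
import torch
import torch.distributed as dist
import peer_memory_cuda as pm

rank, world_size = dist.get_rank(), dist.get_world_size()

raw = pm.allocate_raw(2 * 1024 * 1024)   # zero-initialized device pool, returned as int64

# exchange IPC handles so every rank can map every other rank's pool
ipc = pm.get_raw_ipc_address(raw).cuda()
all_ipc = [torch.empty_like(ipc) for _ in range(world_size)]
dist.all_gather(all_ipc, ipc)
peer_raw = pm.get_raw_peers(torch.stack(all_ipc).cpu(), rank, raw)

# view the start of each rank's pool as a half-precision tensor (channels_last=False)
views = [pm.blob_view_half(p, [1, 64, 4, 32], False) for p in peer_raw]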
#include <torch/extension.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <ATen/cuda/CUDAContext.h>
#include <list>
#include <cstdio>
#include <cassert>
#include <cuda_runtime_api.h>
#include <cooperative_groups.h>
#include "nccl.h"
namespace cg = cooperative_groups;
#define CUDACHECK(cmd) do { \
cudaError_t err = cmd; \
if( err != cudaSuccess ) { \
char hostname[1024]; \
gethostname(hostname, 1024); \
printf("%s: CUDA failure %s:%d '%s'\n", \
hostname, \
__FILE__,__LINE__,cudaGetErrorString(err)); \
} \
} while(0)
namespace {
/* Basic deleter function for from_blob function.
void deleter(void* ptr)
{
printf("deleter(ptr=%p)\n",ptr);
cudaFree(ptr);
}
*/
template<class T>
at::Tensor blob_view(T* raw_ptr, std::vector<int64_t> shape, const at::TensorOptions& options, bool channels_last)
{
size_t size = 1;
std::vector<int64_t> strides(shape.size());
if (channels_last) {
assert(shape.size() == 4);
strides[0] = shape[1]*shape[2]*shape[3];
strides[1] = 1;
strides[2] = shape[1]*shape[3];
strides[3] = shape[1];
} else {
int idx = strides.size();
for (auto it = shape.rbegin(); it != shape.rend(); ++it)
{
strides[--idx] = size;
size *= *it;
}
}
size *= sizeof(T);
// TODO: Implement dynamic reuse of pooled peer memory.
// We provide no deleter function because all peer memory allocations are static in this implementation.
return torch::from_blob((void*)raw_ptr, shape, strides, 0L, options);
}
void tensor_shape(at::Tensor t, bool explicit_nhwc, int& N, int& C, int& H, int& W)
{
if (t.dim() == 3) {
N = 1;
if (explicit_nhwc) {
C = t.size(2);
H = t.size(0);
W = t.size(1);
} else {
C = t.size(0);
H = t.size(1);
W = t.size(2);
}
} else if (t.dim() == 4) {
if (explicit_nhwc) {
N = t.size(0);
C = t.size(3);
H = t.size(1);
W = t.size(2);
} else {
N = t.size(0);
C = t.size(1);
H = t.size(2);
W = t.size(3);
}
} else {
printf("%s;%d - t.dim() must be either 3 or 4 (was %d)\n",__FILE__,__LINE__,t.dim());
assert(t.dim() == 3 || t.dim() == 4);
}
}
void tensor_strides(at::Tensor t, bool explicit_nhwc, int& stride_N, int& stride_C, int& stride_H, int& stride_W)
{
if (t.dim() == 3) {
if (explicit_nhwc) {
stride_C = t.stride(2);
stride_H = t.stride(0);
stride_W = t.stride(1);
} else {
stride_C = t.stride(0);
stride_H = t.stride(1);
stride_W = t.stride(2);
}
stride_N = t.size(0)*t.size(1)*t.size(2);
} else if (t.dim() == 4) {
if (explicit_nhwc) {
stride_N = t.stride(0);
stride_C = t.stride(3);
stride_H = t.stride(1);
stride_W = t.stride(2);
} else {
stride_N = t.stride(0);
stride_C = t.stride(1);
stride_H = t.stride(2);
stride_W = t.stride(3);
}
} else {
printf("%s;%d - t.dim() must be either 3 or 4 (was %d)\n",__FILE__,__LINE__,t.dim());
assert(t.dim() == 3 || t.dim() == 4);
}
}
template<class T, bool is_HWC>
__device__ void strided_copy_kernel(
T* dst, const int dst_stride_C, const int dst_stride_H, const int dst_stride_W,
const T* src, const int src_stride_C, const int src_stride_H, const int src_stride_W,
const int NC, const int NH, const int NW
)
{
size_t tot_num_threads = gridDim.x * blockDim.x;
size_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;
const size_t count = NC*NH*NW;
for (size_t i = thread_id; i < count; i += tot_num_threads)
{
size_t c,h,w;
if (is_HWC) {
c = i % NC;
w = i / NC;
h = w / NW;
w = w % NW;
}
else {
w = i % NW;
h = i / NW;
c = h / NH;
h = h % NH;
}
size_t dst_off = c*dst_stride_C + h*dst_stride_H + w*dst_stride_W;
size_t src_off = c*src_stride_C + h*src_stride_H + w*src_stride_W;
dst[dst_off] = src[src_off];
}
}
__device__ void checked_signal(
volatile int* signal1_flag, volatile int* signal2_flag,
const int v1, const int v2, const int v3, const int v4
)
{
cg::this_grid().sync();
bool is_main_thread = (blockIdx.x == 0 && threadIdx.x == 0) ? true : false;
if (is_main_thread) {
// flush all writes to global memory
__threadfence_system();
// wait for top or bottom neighbor to clear signal
register int r1, r2, r3, r4;
bool top_zeroed=false, btm_zeroed=false, top_done=false, btm_done=false;
do {
do {
if (!top_zeroed) {
asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal1_flag) : "memory");
if (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4) top_zeroed = true;
}
if (!btm_zeroed) {
asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal2_flag) : "memory");
if (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4) btm_zeroed = true;
}
} while((top_zeroed == top_done) && (btm_zeroed == btm_done));
if (!top_done && top_zeroed) {
// signal to top neighbor my output is ready
asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal1_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
top_done = true;
}
if (!btm_done && btm_zeroed) {
// signal to bottom neighbor my output is ready
asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal2_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
btm_done = true;
}
} while (!top_done || !btm_done);
}
}
__device__ void wait_for(
volatile int* wait_flag,
const int v1, const int v2, const int v3, const int v4
)
{
bool is_main_thread = (blockIdx.x == 0 && threadIdx.x == 0) ? true : false;
if (is_main_thread) {
register int r1, r2, r3, r4;
// wait for senders to signal their output is ready
do {
asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(wait_flag) : "memory");
} while (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4);
}
cg::this_grid().sync(); // all threads wait for main
}
__device__ void clear_flag(
volatile int* wait_flag
)
{
cg::this_grid().sync(); // wait for all threads in kernel to finish
bool is_main_thread = (blockIdx.x == 0 && threadIdx.x == 0) ? true : false;
if (is_main_thread) {
register int r1, r2, r3, r4;
r1 = 0; r2 = 0; r3 = 0; r4 = 0;
asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(wait_flag), "r"(r1), "r"(r2), "r"(r3), "r"(r4) : "memory");
}
}
template<class T, bool is_HWC>
#if __CUDA_ARCH__ >= 700
__launch_bounds__(128, 16)
#endif
__global__ void push_pull_halos_1d_kernel(
// top halo,
const T* toh, int toh_stride_C, int toh_stride_H, int toh_stride_W, // top output halo
T* tox, int tox_stride_C, int tox_stride_H, int tox_stride_W, // top output tx buffer
T* tix, int tix_stride_C, int tix_stride_H, int tix_stride_W, // top input tx buffer
T* tih, int tih_stride_C, int tih_stride_H, int tih_stride_W, // top input halo
// btm halo
const T* boh, int boh_stride_C, int boh_stride_H, int boh_stride_W, // btm output halo
T* box, int box_stride_C, int box_stride_H, int box_stride_W, // btm output tx buffer
T* bix, int bix_stride_C, int bix_stride_H, int bix_stride_W, // btm input tx buffer
T* bih, int bih_stride_C, int bih_stride_H, int bih_stride_W, // btm input halo
// dimensions
int NC, int NH, int NW,
// signals
int* signal1_flag,
int* signal2_flag,
int* wait1_flag,
int* wait2_flag
)
{
// push top output halo to transfer buffer
strided_copy_kernel<T,is_HWC>(tox, tox_stride_C, tox_stride_H, tox_stride_W, toh, toh_stride_C, toh_stride_H, toh_stride_W, NC, NH, NW);
// push btm output halo to transfer buffer
strided_copy_kernel<T,is_HWC>(box, box_stride_C, box_stride_H, box_stride_W, boh, boh_stride_C, boh_stride_H, boh_stride_W, NC, NH, NW);
// signal to top and btm neighbors that output halos are ready to be read
// the choice of values for v1-v4 is arbitrary and does not matter, as long as all ranks use the same values
checked_signal(signal1_flag, signal2_flag, -987751720, 840868300, -225529332, 281513358);
// pull top halo from transfer buffer in peer memory to input
wait_for(wait1_flag, -987751720, 840868300, -225529332, 281513358);
strided_copy_kernel<T,is_HWC>(tih, tih_stride_C, tih_stride_H, tih_stride_W, tix, tix_stride_C, tix_stride_H, tix_stride_W, NC, NH, NW);
clear_flag(wait1_flag);
// pull btm halo from transfer buffer in peer memory to input
wait_for(wait2_flag, -987751720, 840868300, -225529332, 281513358);
strided_copy_kernel<T,is_HWC>(bih, bih_stride_C, bih_stride_H, bih_stride_W, bix, bix_stride_C, bix_stride_H, bix_stride_W, NC, NH, NW);
clear_flag(wait2_flag);
}
__global__ void delay_kernel(int delay_nanoseconds, int* counter)
{
if (blockIdx.x == 0 && threadIdx.x == 0) {
// waste time while doing something the compiler can't predict, preventing it from optimizing this code away.
int new_counter = 0;
double elapsed = 0;
clock_t start = clock();
do {
clock_t now = clock();
elapsed = (double)(now - start)*1e9 / CLOCKS_PER_SEC;
++new_counter;
} while (elapsed < (double)delay_nanoseconds);
*counter = new_counter;
}
}
}
namespace apex { namespace contrib { namespace peer_memory {
int64_t allocate_raw(int64_t size)
{
float* ptr = 0L;
cudaMalloc(&ptr, size);
cudaMemset(ptr, 0, size);
return (int64_t)ptr;
}
void free_raw(int64_t raw)
{
cudaFree((void*)raw);
}
void zero(int64_t raw, int64_t size)
{
cudaMemset((void*)raw, 0, size);
}
at::Tensor get_raw_ipc_address(int64_t raw)
{
cudaIpcMemHandle_t mem_handle;
CUDACHECK( cudaIpcGetMemHandle(&mem_handle, (void*)raw) );
const int n = sizeof(cudaIpcMemHandle_t);
auto address_tensor = torch::empty({n}, torch::dtype(torch::kUInt8));
auto address_tensor_p = address_tensor.data_ptr<uint8_t>();
memcpy(address_tensor_p, (uint8_t*)&mem_handle, n);
return address_tensor;
}
std::vector<int64_t> get_raw_peers(at::Tensor ipc_addresses, int peer_rank, int64_t raw)
{
int peer_group_size = ipc_addresses.size(0);
std::vector<int64_t> results(peer_group_size);
for (int i = 0; i < peer_group_size; ++i) {
if (i != peer_rank) {
cudaIpcMemHandle_t mem_handle;
memcpy(&mem_handle, ipc_addresses.index({i}).data_ptr<uint8_t>(), sizeof(cudaIpcMemHandle_t));
void* p = 0L;
CUDACHECK( cudaIpcOpenMemHandle((void**)&p, mem_handle, cudaIpcMemLazyEnablePeerAccess) );
results[i] = (int64_t)p;
} else {
results[i] = (int64_t)raw;
}
}
return results;
}
at::Tensor blob_view_half(int64_t raw, std::vector<int64_t> shape, bool channels_last)
{
return blob_view<at::Half>((at::Half*)raw, shape, torch::dtype(torch::kFloat16).device(torch::kCUDA), channels_last);
}
at::Tensor blob_view_float(int64_t raw, std::vector<int64_t> shape, bool channels_last)
{
return blob_view<float>((float*)raw, shape, torch::dtype(torch::kFloat32).device(torch::kCUDA), channels_last);
}
at::Tensor blob_view_int(int64_t raw, std::vector<int64_t> shape, bool channels_last)
{
return blob_view<int>((int*)raw, shape, torch::dtype(torch::kInt32).device(torch::kCUDA), channels_last);
}
void push_pull_halos_1d(
bool diagnostics,
bool explicit_nhwc,
int numSM, // number of SMs to use
at::Tensor top_out_halo, // top output halo in sender device memory
at::Tensor top_out_tx, // top output transfer buffer in sender peer pool memory
at::Tensor top_inp_tx, // top input transfer buffer in top neighbor peer pool memory
at::Tensor top_inp_halo, // top input halo in receiver device memory
at::Tensor btm_out_halo, // btm output halo in sender device memory
at::Tensor btm_out_tx, // btm output transfer buffer in sender peer pool memory
at::Tensor btm_inp_tx, // btm input transfer buffer in btm neighbor peer pool memory
at::Tensor btm_inp_halo, // btm input halo in receiver device memory
at::Tensor top_signal, // top input signal in receiver device memory
at::Tensor btm_signal, // btm input signal in receiver device memory
at::Tensor waits // top and btm signals for this rank
)
{
// basic checks of inputs
TORCH_CHECK(top_out_halo.is_cuda());
TORCH_CHECK(top_out_tx.is_cuda());
TORCH_CHECK(top_inp_tx.is_cuda());
TORCH_CHECK(top_inp_halo.is_cuda());
TORCH_CHECK(btm_out_halo.is_cuda());
TORCH_CHECK(btm_out_tx.is_cuda());
TORCH_CHECK(btm_inp_tx.is_cuda());
TORCH_CHECK(btm_inp_halo.is_cuda());
TORCH_CHECK(top_signal.is_cuda());
TORCH_CHECK(btm_signal.is_cuda());
TORCH_CHECK(waits.is_cuda());
// shapes and strides
int toh_N, toh_C, toh_H, toh_W;
tensor_shape(top_out_halo, explicit_nhwc, toh_N, toh_C, toh_H, toh_W);
int tox_N, tox_C, tox_H, tox_W;
tensor_shape(top_out_tx, explicit_nhwc, tox_N, tox_C, tox_H, tox_W);
int tix_N, tix_C, tix_H, tix_W;
tensor_shape(top_inp_tx, explicit_nhwc, tix_N, tix_C, tix_H, tix_W);
int tih_N, tih_C, tih_H, tih_W;
tensor_shape(top_inp_halo, explicit_nhwc, tih_N, tih_C, tih_H, tih_W);
TORCH_CHECK(
(toh_N == tox_N && tox_N == tix_N && tix_N == tih_N) &&
(toh_C == tox_C && tox_C == tix_C && tix_C == tih_C) &&
(toh_H == tox_H && tox_H == tix_H && tix_H == tih_H) &&
(toh_W == tox_W && tox_W == tix_W && tix_W == tih_W));
int boh_N, boh_C, boh_H, boh_W;
tensor_shape(btm_out_halo, explicit_nhwc, boh_N, boh_C, boh_H, boh_W);
int box_N, box_C, box_H, box_W;
tensor_shape(btm_out_tx, explicit_nhwc, box_N, box_C, box_H, box_W);
int bix_N, bix_C, bix_H, bix_W;
tensor_shape(btm_inp_tx, explicit_nhwc, bix_N, bix_C, bix_H, bix_W);
int bih_N, bih_C, bih_H, bih_W;
tensor_shape(btm_inp_halo, explicit_nhwc, bih_N, bih_C, bih_H, bih_W);
TORCH_CHECK(
(boh_N == box_N && box_N == bix_N && bix_N == bih_N) &&
(boh_C == box_C && box_C == bix_C && bix_C == bih_C) &&
(boh_H == box_H && box_H == bix_H && bix_H == bih_H) &&
(boh_W == box_W && box_W == bix_W && bix_W == bih_W));
TORCH_CHECK(
(toh_N == boh_N) &&
(toh_C == boh_C) &&
(toh_H == boh_H) &&
(toh_W == boh_W));
int NC=toh_C, NH=toh_H, NW=toh_W;
if (diagnostics) printf("NC=%d, NH=%d, NW=%d\n",NC,NH,NW);
int toh_stride_N, toh_stride_C, toh_stride_H, toh_stride_W;
tensor_strides(top_out_halo, explicit_nhwc, toh_stride_N, toh_stride_C, toh_stride_H, toh_stride_W);
int tox_stride_N, tox_stride_C, tox_stride_H, tox_stride_W;
tensor_strides(top_out_tx, explicit_nhwc, tox_stride_N, tox_stride_C, tox_stride_H, tox_stride_W);
int tix_stride_N, tix_stride_C, tix_stride_H, tix_stride_W;
tensor_strides(top_inp_tx, explicit_nhwc, tix_stride_N, tix_stride_C, tix_stride_H, tix_stride_W);
int tih_stride_N, tih_stride_C, tih_stride_H, tih_stride_W;
tensor_strides(top_inp_halo, explicit_nhwc, tih_stride_N, tih_stride_C, tih_stride_H, tih_stride_W);
int boh_stride_N, boh_stride_C, boh_stride_H, boh_stride_W;
tensor_strides(btm_out_halo, explicit_nhwc, boh_stride_N, boh_stride_C, boh_stride_H, boh_stride_W);
int box_stride_N, box_stride_C, box_stride_H, box_stride_W;
tensor_strides(btm_out_tx, explicit_nhwc, box_stride_N, box_stride_C, box_stride_H, box_stride_W);
int bix_stride_N, bix_stride_C, bix_stride_H, bix_stride_W;
tensor_strides(btm_inp_tx, explicit_nhwc, bix_stride_N, bix_stride_C, bix_stride_H, bix_stride_W);
int bih_stride_N, bih_stride_C, bih_stride_H, bih_stride_W;
tensor_strides(btm_inp_halo, explicit_nhwc, bih_stride_N, bih_stride_C, bih_stride_H, bih_stride_W);
// determine if nhwc
auto is_nhwc = (toh_stride_C == 1) ? true : false;
if (diagnostics) printf("is_nhwc = %s\n",is_nhwc?"true":"false");
// figure out launch parameters
int device;
cudaGetDevice(&device);
cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, device);
assert(numSM > 0 && numSM <= prop.multiProcessorCount);
auto current_stream = at::cuda::getCurrentCUDAStream();
const int numThreads = 128;
dim3 block(numThreads,1,1);
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, top_out_halo.scalar_type(), "push_pull_halos_1d_kernel", [&]{
if (diagnostics) printf("size(scalar_t) = %ld\n",sizeof(scalar_t));
scalar_t* toh_p = top_out_halo.data_ptr<scalar_t>();
scalar_t* tox_p = top_out_tx.data_ptr<scalar_t>();
scalar_t* tix_p = top_inp_tx.data_ptr<scalar_t>();
scalar_t* tih_p = top_inp_halo.data_ptr<scalar_t>();
scalar_t* boh_p = btm_out_halo.data_ptr<scalar_t>();
scalar_t* box_p = btm_out_tx.data_ptr<scalar_t>();
scalar_t* bix_p = btm_inp_tx.data_ptr<scalar_t>();
scalar_t* bih_p = btm_inp_halo.data_ptr<scalar_t>();
if (diagnostics) printf("waypoint1\n");
int* top_signal_p = top_signal.data_ptr<int>() + 4;
int* btm_signal_p = btm_signal.data_ptr<int>();
int* top_wait_p = waits.data_ptr<int>();
int* btm_wait_p = waits.data_ptr<int>() + 4;
if (diagnostics) printf("waypoint2\n");
// do int4 vector loads if channel count permits
int elem_size_in_bytes = toh_C * sizeof(scalar_t);
int elem_size_in_int4 = (elem_size_in_bytes / 16);
if (diagnostics) printf("elem_size_in_bytes = %d, elem_size_in_int4 = %d\n",elem_size_in_bytes,elem_size_in_int4);
if (is_nhwc && elem_size_in_int4*16 == elem_size_in_bytes) {
// can do int4 transfers
int divisor = toh_C / elem_size_in_int4;
if (diagnostics) printf("CAN DO INT4 :: divisor = %d\n",divisor);
toh_stride_N /= divisor; toh_stride_H /= divisor; toh_stride_W /= divisor;
tox_stride_N /= divisor; tox_stride_H /= divisor; tox_stride_W /= divisor;
tix_stride_N /= divisor; tix_stride_H /= divisor; tix_stride_W /= divisor;
tih_stride_N /= divisor; tih_stride_H /= divisor; tih_stride_W /= divisor;
boh_stride_N /= divisor; boh_stride_H /= divisor; boh_stride_W /= divisor;
box_stride_N /= divisor; box_stride_H /= divisor; box_stride_W /= divisor;
bix_stride_N /= divisor; bix_stride_H /= divisor; bix_stride_W /= divisor;
bih_stride_N /= divisor; bih_stride_H /= divisor; bih_stride_W /= divisor;
NC /= divisor;
if (diagnostics) {
printf("divisor=%d\n",divisor);
printf("toh_stride :: N=%d, C=%d, H=%d, W=%d\n",toh_stride_N,toh_stride_C,toh_stride_H,toh_stride_W);
printf("tox_stride :: N=%d, C=%d, H=%d, W=%d\n",tox_stride_N,tox_stride_C,tox_stride_H,tox_stride_W);
printf("tix_stride :: N=%d, C=%d, H=%d, W=%d\n",tix_stride_N,tix_stride_C,tix_stride_H,tix_stride_W);
printf("tih_stride :: N=%d, C=%d, H=%d, W=%d\n",tih_stride_N,tih_stride_C,tih_stride_H,tih_stride_W);
printf("boh_stride :: N=%d, C=%d, H=%d, W=%d\n",boh_stride_N,boh_stride_C,boh_stride_H,boh_stride_W);
printf("box_stride :: N=%d, C=%d, H=%d, W=%d\n",box_stride_N,box_stride_C,box_stride_H,box_stride_W);
printf("bix_stride :: N=%d, C=%d, H=%d, W=%d\n",bix_stride_N,bix_stride_C,bix_stride_H,bix_stride_W);
printf("bih_stride :: N=%d, C=%d, H=%d, W=%d\n",bih_stride_N,bih_stride_C,bih_stride_H,bih_stride_W);
printf("NC=%d, NH=%d, NW=%d\n",NC,NH,NW);
}
void *kernelArgs[] = {
(int4**)&toh_p, &toh_stride_C, &toh_stride_H, &toh_stride_W,
(int4**)&tox_p, &tox_stride_C, &tox_stride_H, &tox_stride_W,
(int4**)&tix_p, &tix_stride_C, &tix_stride_H, &tix_stride_W,
(int4**)&tih_p, &tih_stride_C, &tih_stride_H, &tih_stride_W,
(int4**)&boh_p, &boh_stride_C, &boh_stride_H, &boh_stride_W,
(int4**)&box_p, &box_stride_C, &box_stride_H, &box_stride_W,
(int4**)&bix_p, &bix_stride_C, &bix_stride_H, &bix_stride_W,
(int4**)&bih_p, &bih_stride_C, &bih_stride_H, &bih_stride_W,
&NC, &NH, &NW,
&top_signal_p, &btm_signal_p, &top_wait_p, &btm_wait_p
};
int numBlocksPerSm;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<int4,true>, numThreads, 0);
dim3 grid(numSM*numBlocksPerSm,1,1);
cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4,true>, grid, block, kernelArgs, 0, current_stream);
} else {
// cannot do int4 transfers
if (diagnostics) printf("CAN NOT DO INT4\n");
void *kernelArgs[] = {
&toh_p, &toh_stride_C, &toh_stride_H, &toh_stride_W,
&tox_p, &tox_stride_C, &tox_stride_H, &tox_stride_W,
&tix_p, &tix_stride_C, &tix_stride_H, &tix_stride_W,
&tih_p, &tih_stride_C, &tih_stride_H, &tih_stride_W,
&boh_p, &boh_stride_C, &boh_stride_H, &boh_stride_W,
&box_p, &box_stride_C, &box_stride_H, &box_stride_W,
&bix_p, &bix_stride_C, &bix_stride_H, &bix_stride_W,
&bih_p, &bih_stride_C, &bih_stride_H, &bih_stride_W,
&NC, &NH, &NW,
&top_signal_p, &btm_signal_p, &top_wait_p, &btm_wait_p
};
int numBlocksPerSm;
if (is_nhwc) {
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t,true>, numThreads, 0);
dim3 grid(numSM*numBlocksPerSm,1,1);
cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,true>, grid, block, kernelArgs, 0, current_stream);
} else {
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel<scalar_t,false>, numThreads, 0);
dim3 grid(numSM*numBlocksPerSm,1,1);
cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<scalar_t,false>, grid, block, kernelArgs, 0, current_stream);
}
}
} );
}
} } }
/**
* Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <torch/extension.h>
#ifndef _peer_memory_h_
#define _peer_memory_h_
namespace apex { namespace contrib { namespace peer_memory {
int64_t allocate_raw(int64_t size);
void free_raw(int64_t raw);
void zero(int64_t raw, int64_t size);
at::Tensor get_raw_ipc_address(int64_t raw);
std::vector<int64_t> get_raw_peers(at::Tensor ipc_addresses, int peer_rank, int64_t raw);
at::Tensor blob_view_half(int64_t raw, std::vector<int64_t> shape, bool channels_last);
at::Tensor blob_view_float(int64_t raw, std::vector<int64_t> shape, bool channels_last);
at::Tensor blob_view_int(int64_t raw, std::vector<int64_t> shape, bool channels_last);
void push_pull_halos_1d(
bool diagnostics,
bool explicit_nhwc,
int numSM, // number of SMs to use
at::Tensor top_out_halo, // top output halo in sender device memory
at::Tensor top_out_tx, // top output transfer buffer in sender peer pool memory
at::Tensor top_inp_tx, // top input transfer buffer in top neighbor peer pool memory
at::Tensor top_inp_halo, // top input halo in receiver device memory
at::Tensor btm_out_halo, // btm output halo in sender device memory
at::Tensor btm_out_tx, // btm output transfer buffer in sender peer pool memory
at::Tensor btm_inp_tx, // btm input transfer buffer in btm neighbor peer pool memory
at::Tensor btm_inp_halo, // btm input halo in receiver device memory
at::Tensor top_signal, // top input signal in receiver device memory
at::Tensor btm_signal, // btm input signal in receiver device memory
at::Tensor waits // top and btm signals for this rank
);
} } }
#endif
from .peer_memory import PeerMemoryPool
from .peer_halo_exchanger_1d import PeerHaloExchanger1d
import torch
from apex.contrib.peer_memory import PeerMemoryPool, PeerHaloExchanger1d
import peer_memory_cuda as pm
# How to run:
# torchrun --nproc_per_node <num-GPU> <this-python-prog>
# <num-GPU> must be a power of 2 greater than 1.
# Output of this function is used as ground truth in module tests.
def nccl_halo_ex(peer_rank, peer_group_size, y, half_halo, explicit_nhwc, H_split):
if explicit_nhwc:
if H_split:
_, Hp, _, _ = list(y.shape)
H = Hp - 2*half_halo
top_out_halo = y[:,half_halo:2*half_halo,:,:]
top_inp_halo = y[:,:half_halo,:,:]
btm_out_halo = y[:,H:H+half_halo,:,:]
btm_inp_halo = y[:,H+half_halo:H+2*half_halo,:,:]
else:
_, _, Wp, _ = list(y.shape)
W = Wp - 2*half_halo
top_out_halo = y[:,:,half_halo:2*half_halo,:]
top_inp_halo = y[:,:,:half_halo,:]
btm_out_halo = y[:,:,W:W+half_halo,:]
btm_inp_halo = y[:,:,W+half_halo:W+2*half_halo,:]
else:
if H_split:
_, _, Hp, _ = list(y.shape)
H = Hp - 2*half_halo
top_out_halo = y[:,:,half_halo:2*half_halo,:]
top_inp_halo = y[:,:,:half_halo,:]
btm_out_halo = y[:,:,H:H+half_halo,:]
btm_inp_halo = y[:,:,H+half_halo:H+2*half_halo,:]
else:
_, _, _, Wp = list(y.shape)
W = Wp - 2*half_halo
top_out_halo = y[:,:,:,half_halo:2*half_halo]
top_inp_halo = y[:,:,:,:half_halo]
btm_out_halo = y[:,:,:,W:W+half_halo]
btm_inp_halo = y[:,:,:,W+half_halo:W+2*half_halo]
top_out_halo = top_out_halo.clone(memory_format=torch.preserve_format)
btm_out_halo = btm_out_halo.clone(memory_format=torch.preserve_format)
top_inp_halos = [torch.empty_like(top_out_halo) for _ in range(peer_group_size)]
torch.distributed.all_gather(top_inp_halos, top_out_halo)
btm_inp_halos = [torch.empty_like(btm_out_halo) for _ in range(peer_group_size)]
torch.distributed.all_gather(btm_inp_halos, btm_out_halo)
top_rank = (peer_rank + peer_group_size - 1) % peer_group_size
btm_rank = (peer_rank + 1) % peer_group_size
top_inp_halo.copy_(btm_inp_halos[top_rank])
btm_inp_halo.copy_(top_inp_halos[btm_rank])
def single_test(peer_rank, peer_group_size, halo_ex, C, H, W, half_halo, dtype, memory_format, H_split, num_steps, numSM=1):
if memory_format == 1:
# 1 -> explicit nhwc
explicit_nhwc = True
if H_split:
y = torch.randn([1,H+2*half_halo,W,C], dtype=dtype, device='cuda')
ym = y[:,half_halo:H+half_halo,:,:]
else:
y = torch.randn([1,H,W+2*half_halo,C], dtype=dtype, device='cuda')
ym = y[:,:,half_halo:W+half_halo,:]
else:
# 2 -> native nhwc
# 3 -> nchw
explicit_nhwc = False
if H_split:
y = torch.randn([1,C,H+2*half_halo,W], dtype=dtype, device='cuda')
if memory_format == 2:
y = y.to(memory_format=torch.channels_last)
ym = y[:,:,half_halo:H+half_halo,:]
else:
y = torch.randn([1,C,H,W+2*half_halo], dtype=dtype, device='cuda')
if memory_format == 2:
y = y.to(memory_format=torch.channels_last)
ym = y[:,:,:,half_halo:W+half_halo]
y3 = y.clone()
list_y = []
for step in range(num_steps):
halo_ex(y, H_split, explicit_nhwc, numSM)
list_y.append(y.clone())
y.copy_(y3)
halo_ex.peer_pool.reset()
torch.distributed.barrier()
y2 = y3.clone()
list_y2 = []
for step in range(num_steps):
nccl_halo_ex(peer_rank, peer_group_size, y2, half_halo, explicit_nhwc, H_split)
list_y2.append(y2.clone())
y2.copy_(y3)
is_equal = [torch.all(torch.eq(yy,yy2)) for yy,yy2 in zip(list_y,list_y2)]
is_equal = torch.tensor(is_equal, dtype=torch.bool)
is_equal = torch.all(is_equal)
if peer_rank == 0:
if memory_format == 1:
memory_format_str = "explicit_nhwc"
elif memory_format == 2:
memory_format_str = "native nhwc"
elif memory_format == 3:
memory_format_str = "nchw"
else:
memory_format_str = "???"
if is_equal:
print("SUCCESS : N,C,H,W = 1,%d,%d,%d, half_halo=%d, %s, %s, %s" % (C,H,W,half_halo,str(dtype),memory_format_str,"H-split" if H_split else "W-split"))
else:
print("FAILURE : N,C,H,W = 1,%d,%d,%d, half_halo=%d, %s, %s, %s" % (C,H,W,half_halo,str(dtype),memory_format_str,"H-split" if H_split else "W-split"))
# peer memory flag sync relies on there being at least one barrier per step
torch.distributed.barrier()
def H_split_tests(N, C, H, W, half_halo, rank, world_size, halo_ex, num_steps):
Hr = 8*world_size
Hp = ((H + Hr - 1) // Hr) * 8
for i in range(4):
div = int(pow(2,i))
single_test(rank, world_size, halo_ex, C*div, Hp//div, W//div, half_halo, torch.float16, 1, True, num_steps)
single_test(rank, world_size, halo_ex, C*div, Hp//div, W//div, half_halo, torch.float16, 2, True, num_steps)
single_test(rank, world_size, halo_ex, C*div, Hp//div, W//div, half_halo, torch.float16, 3, True, num_steps)
def W_split_tests(N, C, H, W, half_halo, rank, world_size, halo_ex, num_steps):
Wr = 8*world_size
Wp = ((W + Wr - 1) // Wr) * 8
for i in range(4):
div = int(pow(2,i))
single_test(rank, world_size, halo_ex, C*div, H//div, Wp//div, half_halo, torch.float16, 1, False, num_steps)
single_test(rank, world_size, halo_ex, C*div, H//div, Wp//div, half_halo, torch.float16, 2, False, num_steps)
single_test(rank, world_size, halo_ex, C*div, H//div, Wp//div, half_halo, torch.float16, 3, False, num_steps)
def main():
# for this trivial example peer_rank == rank and peer_group_size == world_size
torch.distributed.init_process_group("nccl")
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
torch.cuda.set_device(rank)
pool = PeerMemoryPool(rank, world_size, world_size, 64*1024, 2*1024*1024)
num_steps = 100
half_halo = 1
halo_ex = PeerHaloExchanger1d(rank, world_size, pool, half_halo)
H_split_tests(1,64,336,200, half_halo,rank,world_size,halo_ex,num_steps)
W_split_tests(1,64,200,336, half_halo,rank,world_size,halo_ex,num_steps)
if __name__ == "__main__":
main()
import torch
from apex.contrib.peer_memory import PeerMemoryPool
import peer_memory_cuda as pm
class PeerHaloExchanger1d:
def __init__(self, rank, peer_group_size, peer_pool, half_halo):
self.peer_group_size = peer_group_size
self.peer_rank = rank % peer_group_size
self.peer_pool = peer_pool
self.signals = peer_pool.allocate_peer_tensors([2,4], torch.int32, False, False)
self.signals[self.peer_rank].zero_()
self.half_halo = half_halo
def __call__(self, y, H_split=True, explicit_nhwc=False, numSM=1, diagnostics=False):
channels_last = y.is_contiguous(memory_format=torch.channels_last) and not explicit_nhwc
if H_split:
if explicit_nhwc:
_, Hs, _, _ = list(y.shape)
H = Hs - 2*self.half_halo
top_out_halo = y[:,self.half_halo:2*self.half_halo,:,:]
top_tx = self.peer_pool.allocate_peer_tensors(list(top_out_halo.shape), top_out_halo.dtype, False, True)
top_inp_halo = y[:,:self.half_halo,:,:]
btm_out_halo = y[:,H:H+self.half_halo,:,:]
btm_tx = self.peer_pool.allocate_peer_tensors(list(btm_out_halo.shape), btm_out_halo.dtype, False, True)
btm_inp_halo = y[:,H+self.half_halo:H+2*self.half_halo,:,:]
else:
_, _, Hs, _ = list(y.shape)
H = Hs - 2*self.half_halo
top_out_halo = y[:,:,self.half_halo:2*self.half_halo,:]
top_tx = self.peer_pool.allocate_peer_tensors(list(top_out_halo.shape), top_out_halo.dtype, channels_last, True)
top_inp_halo = y[:,:,:self.half_halo,:]
btm_out_halo = y[:,:,H:H+self.half_halo,:]
btm_tx = self.peer_pool.allocate_peer_tensors(list(btm_out_halo.shape), btm_out_halo.dtype, channels_last, True)
btm_inp_halo = y[:,:,H+self.half_halo:H+2*self.half_halo,:]
else:
if explicit_nhwc:
_, _, Ws, _ = list(y.shape)
W = Ws - 2*self.half_halo
top_out_halo = y[:,:,self.half_halo:2*self.half_halo,:]
top_tx = self.peer_pool.allocate_peer_tensors(list(top_out_halo.shape), top_out_halo.dtype, False, True)
top_inp_halo = y[:,:,:self.half_halo,:]
btm_out_halo = y[:,:,W:W+self.half_halo,:]
btm_tx = self.peer_pool.allocate_peer_tensors(list(btm_out_halo.shape), btm_out_halo.dtype, False, True)
btm_inp_halo = y[:,:,W+self.half_halo:W+2*self.half_halo,:]
else:
_, _, _, Ws = list(y.shape)
W = Ws - 2*self.half_halo
top_out_halo = y[:,:,:,self.half_halo:2*self.half_halo]
top_tx = self.peer_pool.allocate_peer_tensors(list(top_out_halo.shape), top_out_halo.dtype, channels_last, True)
top_inp_halo = y[:,:,:,:self.half_halo]
btm_out_halo = y[:,:,:,W:W+self.half_halo]
btm_tx = self.peer_pool.allocate_peer_tensors(list(btm_out_halo.shape), btm_out_halo.dtype, channels_last, True)
btm_inp_halo = y[:,:,:,W+self.half_halo:W+2*self.half_halo]
top_neighbor = (self.peer_rank + self.peer_group_size - 1) % self.peer_group_size
btm_neighbor = (self.peer_rank + 1) % self.peer_group_size
pm.push_pull_halos_1d(
diagnostics, explicit_nhwc, numSM,
top_out_halo, top_tx[self.peer_rank], btm_tx[top_neighbor], top_inp_halo,
btm_out_halo, btm_tx[self.peer_rank], top_tx[btm_neighbor], btm_inp_halo,
self.signals[top_neighbor], self.signals[btm_neighbor], self.signals[self.peer_rank]
)
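A hedged usage sketch for the exchanger above, modelled on the test driver earlier in this change: one process per GPU, H-split spatial parallelism, explicit NHWC activations, and half_halo padding rows on each side of the local slab (shapes are illustrative).
import torch
import torch.distributed as dist
from apex.contrib.peer_memory import PeerMemoryPool, PeerHaloExchanger1d

dist.init_process_group("nccl")
rank, world_size = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank)

half_halo = 1
pool = PeerMemoryPool(rank, world_size, world_size, 64 * 1024, 2 * 1024 * 1024)
halo_ex = PeerHaloExchanger1d(rank, world_size, pool, half_halo)

# local H-slab plus half_halo padding rows on each side, explicit NHWC layout
y = torch.randn(1, 32 + 2 * half_halo, 64, 128, dtype=torch.float16, device="cuda")
halo_ex(y, H_split=True, explicit_nhwc=True, numSM=1)   # fills the padding rows in place
As the test does, the dynamic part of the pool is reclaimed with pool.reset() between iterations, and the peer-flag protocol relies on at least one torch.distributed.barrier() per step.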
import torch
import numpy as np
import peer_memory_cuda as pm
class PeerMemoryPool(object):
def __init__(self, rank, world_size, peer_group_size, static_size, dynamic_size):
self.peer_group = rank // peer_group_size
self.peer_rank = rank % peer_group_size
self.peer_group_size = peer_group_size
self.alignment = 256
self.static_size = ((static_size + self.alignment - 1) // self.alignment) * self.alignment
self.dynamic_size = ((dynamic_size + self.alignment - 1) // self.alignment) * self.alignment
# allocate giant pool of device memory
self.raw = pm.allocate_raw(self.static_size+self.dynamic_size)
# exchange peer pointers with nccl
raw_ipc = pm.get_raw_ipc_address(self.raw).cuda()
peer_raw_ipcs = [torch.empty_like(raw_ipc) for _ in range(world_size)]
torch.distributed.all_gather(peer_raw_ipcs, raw_ipc)
peer_raw_ipcs = torch.stack(peer_raw_ipcs).cpu()
self.peer_raw = pm.get_raw_peers(peer_raw_ipcs, self.peer_rank, self.raw)
self.static_offset = 0
self.dynamic_offset = 0
def __del__(self):
pm.free_raw(self.raw)
def reset(self):
self.dynamic_offset = 0
def allocate_peer_tensors(self, shape, dtype, channels_last, dynamic):
nels = np.prod(shape)
if dtype == torch.float16:
elem_size = 2
if dynamic:
start = ((self.dynamic_offset + self.alignment - 1) // self.alignment) * self.alignment
self.dynamic_offset = start + nels * elem_size
assert(self.dynamic_offset < self.dynamic_size), "Dynamic peer memory pool exhausted"
return [pm.blob_view_half(pr + self.static_size + start, shape, channels_last) for pr in self.peer_raw]
else:
start = ((self.static_offset + self.alignment - 1) // self.alignment) * self.alignment
self.static_offset = start + nels * elem_size
assert(self.static_offset < self.static_size), "Static peer memory pool exhausted"
return [pm.blob_view_half(pr + start, shape, channels_last) for pr in self.peer_raw]
if dtype == torch.float32:
elem_size = 4
if dynamic:
start = ((self.dynamic_offset + self.alignment - 1) // self.alignment) * self.alignment
self.dynamic_offset = start + nels * elem_size
assert(self.dynamic_offset < self.dynamic_size), "Dynamic peer memory pool exhausted"
return [pm.blob_view_float(pr + self.static_size + start, shape, channels_last) for pr in self.peer_raw]
else:
start = ((self.static_offset + self.alignment - 1) // self.alignment) * self.alignment
self.static_offset = start + nels * elem_size
assert(self.static_offset < self.static_size), "Static peer memory pool exhausted"
return [pm.blob_view_float(pr + start, shape, channels_last) for pr in self.peer_raw]
if dtype == torch.int32:
elem_size = 4
if dynamic:
start = ((self.dynamic_offset + self.alignment - 1) // self.alignment) * self.alignment
self.dynamic_offset = start + nels * elem_size
assert(self.dynamic_offset < self.dynamic_size), "Dynamic peer memory pool exhausted"
return [pm.blob_view_int(pr + self.static_size + start, shape, channels_last) for pr in self.peer_raw]
else:
start = ((self.static_offset + self.alignment - 1) // self.alignment) * self.alignment
self.static_offset = start + nels * elem_size
assert(self.static_offset < self.static_size), "Static peer memory pool exhausted"
return [pm.blob_view_int(pr + start, shape, channels_last) for pr in self.peer_raw]
else:
assert(False), "dtype %s not supported" % (str(dtype))
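A hedged sketch of the two pool regions, mirroring how PeerHaloExchanger1d uses them: static allocations (dynamic=False) live for the lifetime of the pool, while dynamic allocations (dynamic=True) are handed out bump-allocator style and reclaimed wholesale by reset(); rank, world_size and the shapes here are illustrative.
import torch
import torch.distributed as dist
from apex.contrib.peer_memory import PeerMemoryPool

rank, world_size = dist.get_rank(), dist.get_world_size()
pool = PeerMemoryPool(rank, world_size, world_size, 64 * 1024, 2 * 1024 * 1024)

# static region: long-lived signal flags, one [2, 4] int32 view per peer rank
signals = pool.allocate_peer_tensors([2, 4], torch.int32, False, False)
signals[pool.peer_rank].zero_()

# dynamic region: per-iteration transfer buffers; reset() makes the space reusable
top_tx = pool.allocate_peer_tensors([1, 1, 64, 128], torch.float16, False, True)
pool.reset()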
......@@ -652,6 +652,34 @@ if "--fast_bottleneck" in sys.argv:
)
)
if "--peer_memory" in sys.argv:
sys.argv.remove("--peer_memory")
raise_if_cuda_home_none("--peer_memory")
ext_modules.append(
CUDAExtension(
name="peer_memory_cuda",
sources=[
"apex/contrib/csrc/peer_memory/peer_memory_cuda.cu",
"apex/contrib/csrc/peer_memory/peer_memory.cpp",
],
extra_compile_args={"cxx": ["-O3"] + version_dependent_macros + generator_flag},
)
)
if "--nccl_p2p" in sys.argv:
sys.argv.remove("--nccl_p2p")
raise_if_cuda_home_none("--nccl_p2p")
ext_modules.append(
CUDAExtension(
name="nccl_p2p_cuda",
sources=[
"apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu",
"apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp",
],
extra_compile_args={"cxx": ["-O3"] + version_dependent_macros + generator_flag},
)
)
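# Hedged build note: as the sys.argv checks above show, these optional extensions are only
# compiled when their flags are passed through to setup.py, e.g.
#   python setup.py install --peer_memory --nccl_p2p
# (any pip --global-option plumbing around this invocation is assumed rather than shown here).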
if "--fused_conv_bias_relu" in sys.argv:
sys.argv.remove("--fused_conv_bias_relu")
......