Commit 88914a50 authored by Thor Johnsen

Add halo correction kernel for bprop

parent 705aa35d
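Context for the change: the spatial-parallel bottleneck splits the activation along H across ranks, so the 3x3 convolution at a partition boundary is missing one input row that lives on the neighboring rank. The snippet below is a minimal, self-contained illustration in plain PyTorch (not the apex/fast_bottleneck API; all names and sizes are made up) of why that one-row halo is needed and why the fix-up can be expressed as a 1x3 "halo correction" convolution added to the boundary row.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, C, H, W = 1, 4, 8, 8
x = torch.randn(N, C, H, W)
w = torch.randn(C, C, 3, 3)

# Reference: 3x3 convolution over the full, unsplit input.
ref = F.conv2d(x, w, padding=1)

# "Rank 0" owns the top half of H, "rank 1" the bottom half.
x0, x1 = x[:, :, :H//2, :], x[:, :, H//2:, :]
y0 = F.conv2d(x0, w, padding=1)      # last row of y0 is wrong: it saw a zero halo
halo = x1[:, :, :1, :]               # first row owned by the neighbor below

# The missing contribution enters only through the last row of the 3x3 kernel,
# so the correction is a 1x3 convolution of the halo row added to the boundary row.
corr = F.conv2d(halo, w[:, :, 2:, :], padding=(0, 1))
y0[:, :, -1:, :] += corr
print(torch.allclose(y0, ref[:, :, :H//2, :], atol=1e-5))   # True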
@@ -268,17 +268,6 @@ class SpatialBottleneckFunction(torch.autograd.Function):
spatial_halo_exchanger.left_right_halo_exchange(out1[:,:,:1,:], out1[:,:,Hs-1:,:], top_out1_halo, btm_out1_halo)
if spatial_method == 1:
# overlap mid convolution with halo transfer
if spatial_group_rank > 0:
with torch.cuda.stream(stream1):
if explicit_nhwc:
top_fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
top_fat_halo[:,:1,:,:].copy_(top_out1_halo)
top_fat_halo[:,1:3,:,:].copy_(out1[:,:2,:,:])
else:
top_fat_halo = torch.empty((N,C,3,W),dtype=out1.dtype,device=out1.device)
top_fat_halo[:,:,:1,:].copy_(top_out1_halo)
top_fat_halo[:,:,1:3,:].copy_(out1[:,:,:2,:])
top_out2 = fast_bottleneck.forward_out2_halo(explicit_nhwc, top_fat_halo, args)
if spatial_group_rank < spatial_group_size-1:
stream2.wait_stream(stream1)
with torch.cuda.stream(stream2):
@@ -291,6 +280,17 @@ class SpatialBottleneckFunction(torch.autograd.Function):
btm_fat_halo[:,:,0:2,:].copy_(out1[:,:,Hs-2:,:])
btm_fat_halo[:,:,2:,:].copy_(btm_out1_halo)
btm_out2 = fast_bottleneck.forward_out2_halo(explicit_nhwc, btm_fat_halo, args)
if spatial_group_rank > 0:
with torch.cuda.stream(stream1):
if explicit_nhwc:
top_fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
top_fat_halo[:,:1,:,:].copy_(top_out1_halo)
top_fat_halo[:,1:3,:,:].copy_(out1[:,:2,:,:])
else:
top_fat_halo = torch.empty((N,C,3,W),dtype=out1.dtype,device=out1.device)
top_fat_halo[:,:,:1,:].copy_(top_out1_halo)
top_fat_halo[:,:,1:3,:].copy_(out1[:,:,:2,:])
top_out2 = fast_bottleneck.forward_out2_halo(explicit_nhwc, top_fat_halo, args)
inc.add_delay(10)
elif spatial_method != 2 and spatial_method != 3:
assert(False), "spatial_method must be 1, 2 or 3"
@@ -329,13 +329,6 @@ class SpatialBottleneckFunction(torch.autograd.Function):
# to wait for out2_mask to finish, but itself has to finish before
# the first kernel of _forward_rest can launch.
# At least we can overlap the two halo correction kernels.
if spatial_group_rank > 0:
stream1.wait_stream(torch.cuda.current_stream()) # wait for *_out2_mask to finish
with torch.cuda.stream(stream1):
w1by3 = args[2][:,:1,:,:].clone()
top_out1_halo = top_out1_halo.clone()
top_out2 = fast_bottleneck.forward_out2_halo_corr(explicit_nhwc, top_out1_halo, args, w1by3, top_out2_halo.clone())
top_out2_halo.copy_(top_out2)
if spatial_group_rank < spatial_group_size-1:
stream2.wait_stream(torch.cuda.current_stream()) # wait for *_out2_mask to finish
with torch.cuda.stream(stream2):
@@ -344,9 +337,16 @@ class SpatialBottleneckFunction(torch.autograd.Function):
btm_out2 = fast_bottleneck.forward_out2_halo_corr(explicit_nhwc, btm_out1_halo, args, w1by3, btm_out2_halo.clone())
btm_out2_halo.copy_(btm_out2)
if spatial_group_rank > 0:
stream1.wait_stream(torch.cuda.current_stream()) # wait for *_out2_mask to finish
with torch.cuda.stream(stream1):
w1by3 = args[2][:,:1,:,:].clone()
top_out1_halo = top_out1_halo.clone()
top_out2 = fast_bottleneck.forward_out2_halo_corr(explicit_nhwc, top_out1_halo, args, w1by3, top_out2_halo.clone())
top_out2_halo.copy_(top_out2)
if spatial_group_rank < spatial_group_size-1:
torch.cuda.current_stream().wait_stream(stream2)
if spatial_group_rank > 0:
torch.cuda.current_stream().wait_stream(stream1)
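The stream choreography above (wait for the masked main conv, launch the two halo corrections on side streams, rejoin before forward_rest) is a standard producer/consumer pattern. A small sketch of that pattern in plain PyTorch, independent of the bottleneck code (requires a CUDA device; the workloads are placeholders):

import torch

if torch.cuda.is_available():
    main = torch.cuda.current_stream()
    side = torch.cuda.Stream()

    a = torch.randn(1024, 1024, device='cuda')
    b = a @ a                         # main-stream work (stands in for *_out2_mask)

    side.wait_stream(main)            # the correction must see b
    with torch.cuda.stream(side):
        c = b.relu()                  # overlapped side work (stands in for the halo correction)

    main.wait_stream(side)            # rejoin before the result is consumed
    c.record_stream(main)             # tell the caching allocator c is also used on main
    print(c.sum().item())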
fast_bottleneck.forward_rest(explicit_nhwc, stride_1x1, args, outputs)
# save halos for backward pass
@@ -365,6 +365,8 @@ class SpatialBottleneckFunction(torch.autograd.Function):
ctx.spatial_group_rank = spatial_group_rank
ctx.spatial_halo_exchanger = spatial_halo_exchanger
ctx.spatial_method = spatial_method
ctx.thresholdTop = thresholdTop
ctx.thresholdBottom = thresholdBottom
ctx.stream1 = stream1
ctx.stream2 = stream2
ctx.stream3 = stream3
@@ -414,50 +416,55 @@ class SpatialBottleneckFunction(torch.autograd.Function):
with torch.cuda.stream(ctx.stream1):
top_halo, btm_halo = ctx.spatial_halo_exchanger.left_right_halo_exchange(grad_out2[:,:1,:,:], grad_out2[:,Hs-1:,:,:])
# copy halos to send buffer
if ctx.spatial_method == 1 or ctx.spatial_method == 2:
# 1 -> halo recompute approach
# 2 -> wait for concatenated halos, then do single conv on full input (not implemented yet for bprop)
if ctx.spatial_group_rank < ctx.spatial_group_size-1:
ctx.stream2.wait_stream(ctx.stream1)
with torch.cuda.stream(ctx.stream2):
if ctx.explicit_nhwc:
btm_fat_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
btm_fat_halo[:,:2,:,:].copy_(grad_out2[:,Hs-2:,:,:])
btm_fat_halo[:,2:,:,:].copy_(btm_halo)
btm_fat_relu_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
btm_fat_relu_halo[:,:2,:,:].copy_(relu1[:,Hs-2:,:,:])
btm_fat_relu_halo[:,2:,:,:].zero_()
else:
btm_fat_halo = torch.empty((N,C,3,W),dtype=grad_out2.dtype,device=grad_out2.device)
btm_fat_halo[:,:,:2,:].copy_(grad_out2[:,:,Hs-2:,:])
btm_fat_halo[:,:,2:,:].copy_(btm_halo)
btm_fat_relu_halo = torch.empty((N,C,3,W),dtype=grad_out2.dtype,device=grad_out2.device)
btm_fat_relu_halo[:,:,:2,:].copy_(relu1[:,:,Hs-2:,:])
btm_fat_relu_halo[:,:,2:,:].zero_()
btm_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, btm_fat_halo, btm_fat_relu_halo)
if ctx.explicit_nhwc:
btm_grad_out1_halo = btm_grad_out1_halo[:,1:2,:,:]
else:
btm_grad_out1_halo = btm_grad_out1_halo[:,:,1:2,:]
if ctx.spatial_group_rank > 0:
with torch.cuda.stream(ctx.stream1):
if ctx.explicit_nhwc:
top_fat_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
top_fat_halo[:,:1,:,:].copy_(top_halo)
top_fat_halo[:,1:,:,:].copy_(grad_out2[:,:2,:,:])
top_fat_relu_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device)
top_fat_relu_halo[:,:1,:,:].zero_()
top_fat_relu_halo[:,1:,:,:].copy_(relu1[:,:2,:,:])
else:
top_fat_halo = torch.empty((N,C,3,W),dtype=grad_out2.dtype,device=grad_out2.device)
top_fat_halo[:,:,:1,:].copy_(top_halo)
top_fat_halo[:,:,1:,:].copy_(grad_out2[:,:,:2,:])
top_fat_relu_halo = torch.empty((N,C,3,W),dtype=grad_out2.dtype,device=grad_out2.device)
top_fat_relu_halo[:,:,:1,:].zero_()
top_fat_relu_halo[:,:,1:,:].copy_(relu1[:,:,:2,:])
top_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, top_fat_halo, top_fat_relu_halo)
if ctx.explicit_nhwc:
top_grad_out1_halo = top_grad_out1_halo[:,1:2,:,:]
else:
top_grad_out1_halo = top_grad_out1_halo[:,:,1:2,:]
inc.add_delay(10)
elif ctx.spatial_method != 3:
assert(False), "spatial_method must be 1, 2 or 3"
with torch.cuda.stream(wgrad2_stream):
if ctx.spatial_group_size > 1:
@@ -466,7 +473,10 @@ class SpatialBottleneckFunction(torch.autograd.Function):
wgrad2 = fast_bottleneck.backward_wgrad2(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
# compute grad_out1 for internal cells
if ctx.spatial_group_size <= 1 or ctx.spatial_method == 1 or ctx.spatial_method == 2:
grad_out1 = fast_bottleneck.backward_grad_out1(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
elif ctx.spatial_group_size > 1 and ctx.spatial_method == 3:
grad_out1 = fast_bottleneck.backward_grad_out1_mask(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2, ctx.thresholdTop, ctx.thresholdBottom)
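backward_grad_out1_mask is the method-3 counterpart of backward_grad_out1: it fuses the same dgrad + ReLU-backward + scale, but additionally builds a per-row mask from a generated H index and two integer thresholds, so halo-affected boundary rows can be treated specially and fixed up by the halo-correction kernel below. A plain-PyTorch sketch of just the mask construction (the threshold values here are illustrative; the real ones come from ctx.thresholdTop/ctx.thresholdBottom, which are set up outside this diff):

import torch

N, C, Hs, W = 1, 4, 8, 8
thresholdTop, thresholdBottom = 1, Hs - 2                       # illustrative values
idx = torch.arange(Hs).view(1, 1, Hs, 1).expand(N, C, Hs, W)    # gen-index along H
mask = (idx < thresholdTop) | (idx > thresholdBottom)           # CMP_LT, CMP_GT, then logical OR
print(mask[0, 0, :, 0])
# tensor([ True, False, False, False, False, False, False,  True])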
# apply halo cells to grad_out1
if ctx.spatial_group_size > 1:
@@ -474,20 +484,51 @@ class SpatialBottleneckFunction(torch.autograd.Function):
z = t_list[4]
relu1 = t_list[12]
#print("w.shape = %s, z.shape = %s, relu1.shape = %s" % (str(list(w.shape)), str(list(z.shape)), str(list(relu1.shape))))
if ctx.spatial_method == 1 or ctx.spatial_method == 2:
if ctx.spatial_group_rank < ctx.spatial_group_size-1:
torch.cuda.current_stream().wait_stream(ctx.stream2)
if ctx.explicit_nhwc:
grad_out1[:,Hs-1:,:,:].copy_(btm_grad_out1_halo)
else:
grad_out1[:,:,Hs-1:,:].copy_(btm_grad_out1_halo)
#print("ctx.spatial_group_rank = %d, apply grad_out1 btm halo (grad_out1.shape = %s)" % (ctx.spatial_group_rank, str(list(grad_out1.shape))))
if ctx.spatial_group_rank > 0:
torch.cuda.current_stream().wait_stream(ctx.stream1)
if ctx.explicit_nhwc:
grad_out1[:,:1,:,:].copy_(top_grad_out1_halo)
else:
grad_out1[:,:,:1,:].copy_(top_grad_out1_halo)
#print("ctx.spatial_group_rank = %d, apply grad_out1 top halo (grad_out1.shape = %s)" % (ctx.spatial_group_rank, str(list(grad_out1.shape))))
elif ctx.spatial_method == 3:
if ctx.spatial_group_rank < ctx.spatial_group_size-1:
if ctx.explicit_nhwc:
btm_relu_halo = relu1[:,Hs-1:,:,:].clone()
btm_grad_out1 = grad_out1[:,Hs-1:,:,:]
else:
btm_relu_halo = relu1[:,:,Hs-1:,:].clone()
btm_grad_out1 = grad_out1[:,:,Hs-1:,:]
w1by3 = w[:,:1,:,:].clone()
ctx.stream1.wait_stream(ctx.stream2) # wait for halo transfers to finish
ctx.stream2.wait_stream(torch.cuda.current_stream()) # wait for backward_grad_out1_mask to finish before launching halo correction kernel
with torch.cuda.stream(ctx.stream1):
btm_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo_corr(ctx.explicit_nhwc, ctx.stride_1x1, t_list, w1by3, grads, btm_halo, btm_relu_halo, btm_grad_out1.clone())
btm_grad_out1.copy_(btm_grad_out1_halo)
if ctx.spatial_group_rank > 0:
if ctx.explicit_nhwc:
top_relu_halo = relu1[:,:1,:,:].clone()
top_grad_out1 = grad_out1[:,:1,:,:]
else:
top_relu_halo = relu1[:,:,:1,:].clone()
top_grad_out1 = grad_out1[:,:,:1,:]
w1by3 = w[:,2:,:,:].clone()
ctx.stream1.wait_stream(torch.cuda.current_stream()) # wait for backward_grad_out1_mask to finish before launching halo correction kernel
with torch.cuda.stream(ctx.stream1):
top_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo_corr(ctx.explicit_nhwc, ctx.stride_1x1, t_list, w1by3, grads, top_halo, top_relu_halo, top_grad_out1.clone())
top_grad_out1.copy_(top_grad_out1_halo)
if ctx.spatial_group_rank < ctx.spatial_group_size-1:
torch.cuda.current_stream().wait_stream(ctx.stream2) # wait for halo correction to finish
if ctx.spatial_group_rank > 0:
torch.cuda.current_stream().wait_stream(ctx.stream1)
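The method-3 halo correction above relies on the fact that the neighbor's grad_out2 row reaches the local boundary grad_out1 row through exactly one row of the 3x3 kernel (the first kernel row for the bottom boundary, the last one for the top, which is why w1by3 is sliced differently in the two branches), so the missing term can be added as a 1x3 backward-data convolution on top of the partial result before ReLU-backward and scale are applied. A plain-PyTorch check of that decomposition for the bottom boundary (illustrative shapes, NCHW for simplicity):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, C, H, W = 1, 4, 8, 8
w = torch.randn(C, C, 3, 3)
grad_out2 = torch.randn(N, C, H, W)
btm_halo = torch.randn(N, C, 1, W)    # first grad_out2 row of the rank below

# Boundary row of the full dgrad, as a reference.
ref = F.conv_transpose2d(torch.cat([grad_out2, btm_halo], dim=2), w, padding=1)[:, :, H-1:H, :]

# Partial dgrad computed with a zero halo (what the masked kernel produced) ...
part = F.conv_transpose2d(grad_out2, w, padding=1)[:, :, H-1:H, :]
# ... plus the halo's contribution through the first kernel row, as a 1x3 dgrad.
corr = F.conv_transpose2d(btm_halo, w[:, :, :1, :], padding=(0, 1))
print(torch.allclose(part + corr, ref, atol=1e-5))   # True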
fast_bottleneck.backward_rest(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2, grad_out1, wgrad2)
torch.cuda.current_stream().wait_stream(wgrad2_stream)
...
@@ -161,7 +161,7 @@ def n_way_spatial(halex, gt_bottleneck, gt, explicit_nhwc, world_size, rank, fp3
spatial_group_rank = rank
spatial_communicator = None
spatial_halo_exchanger = halex
spatial_method = 3 # 1 -> overlap halo and main conv, 2 -> wait for halo, conv on padded x, 3 -> masked main conv plus separate halo correction kernels
spatial_parallel_args = (spatial_group_size, spatial_group_rank, spatial_communicator, spatial_halo_exchanger, spatial_method)
spatial_bottleneck = spatial_parallel_bottleneck(C, dtype, explicit_nhwc, gt_bottleneck, spatial_parallel_args)
...
@@ -466,6 +466,14 @@ enum {
RELU_TENSOR,
AFTER_DCONV_TENSOR,
AFTER_DRELU_TENSOR,
DGRAD_INPUT_TENSOR,
DGRAD_OPTIONAL_TENSOR,
DGRAD_GEN_INDEX_TENSOR,
DGRAD_MASK_TOP_TENSOR,
DGRAD_MASK_BOTTOM_TENSOR,
DGRAD_MASK_TENSOR,
DGRAD_THRESHOLD_TOP_TENSOR,
DGRAD_THRESHOLD_BOTTOM_TENSOR,
};
using dconv_descriptors = std::tuple<cudnn_frontend::Tensor,
@@ -474,6 +482,8 @@ using dconv_descriptors = std::tuple<cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor>;
dconv_descriptors
@@ -552,7 +562,181 @@ create_dconv_descriptors(int64_t* x_dim_padded,
.setId('B') // after drelu
.setAlignment(16)
.setDataType(CUDNN_DATA_FLOAT)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('i')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('D') // after optional add
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_FLOAT)
.build());
}
using dconv_mask_descriptors = std::tuple<cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor>;
dconv_mask_descriptors
create_dconv_mask_descriptors(int64_t* x_dim_padded,
int64_t* padA,
int64_t* convstrideA,
int64_t* dilationA,
int64_t* w_dim_padded,
int64_t* y_dim_padded,
int64_t* threshold_dim,
cudnnDataType_t dataType) {
const int convDim = 2;
int64_t b_dim_padded[4];
b_dim_padded[0] = 1;
b_dim_padded[1] = x_dim_padded[1];
b_dim_padded[2] = 1;
b_dim_padded[3] = 1;
int64_t x_stride_padded[4];
int64_t y_stride_padded[4];
int64_t w_stride_padded[4];
int64_t b_stride_padded[4];
int64_t threshold_stride[4];
generateStrides(w_dim_padded, w_stride_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(x_dim_padded, x_stride_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(y_dim_padded, y_stride_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(b_dim_padded, b_stride_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(threshold_dim, threshold_stride, 4, CUDNN_TENSOR_NHWC);
return dconv_mask_descriptors(cudnn_frontend::TensorBuilder()
.setDim(4, x_dim_padded)
.setStrides(4, x_stride_padded)
.setId('x')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('y')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, w_dim_padded)
.setStrides(4, w_stride_padded)
.setId('w')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, b_dim_padded)
.setStrides(4, b_stride_padded)
.setId('s')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, x_dim_padded)
.setStrides(4, x_stride_padded)
.setId('r')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, x_dim_padded)
.setStrides(4, x_stride_padded)
.setVirtual()
.setId('A') // after dconv
.setAlignment(16)
.setDataType(CUDNN_DATA_FLOAT)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, x_dim_padded)
.setStrides(4, x_stride_padded)
.setVirtual()
.setId('B') // after drelu
.setAlignment(16)
.setDataType(CUDNN_DATA_FLOAT)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('i')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('D') // after optional add
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_FLOAT)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('I') // output of the gen index operation
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_INT32)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('m') // top half of the mask created after the less than
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_BOOLEAN)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('n') // bottom half of the mask
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_BOOLEAN)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('M') // OR of the top and bottom masks
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_BOOLEAN)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, threshold_dim)
.setStrides(4, threshold_stride)
.setId('t') // threshold for creating the top mask
.setAlignment(16)
.setDataType(CUDNN_DATA_INT32)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, threshold_dim)
.setStrides(4, threshold_stride)
.setId('u') // threshold for creating the bottom mask
.setAlignment(16)
.setDataType(CUDNN_DATA_INT32)
.build());
}
// create a cache for plan
@@ -1520,6 +1704,371 @@ run_dconv_drelu_dscale(int64_t* x_dim_padded,
}
}
void
run_dconv_add_drelu_dscale(int64_t* x_dim_padded,
int64_t* pad,
int64_t* convstride,
int64_t* dilation,
int64_t* w_dim_padded,
int64_t* y_dim_padded,
cudnnDataType_t dataType,
at::Half* devPtrX,
at::Half* devPtrW,
at::Half* devPtrY,
at::Half* devPtrZ,
at::Half* devPtrR,
at::Half* devPtrI) {
cudnnHandle_t handle_ = torch::native::getCudnnHandle();
std::stringstream log_buf;
try {
int convDim = 2;
// Creates the necessary tensor descriptors
dconv_descriptors tensors = create_dconv_descriptors(
x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType);
DEBUG_CUDNN_MSG(log_buf, std::get<X_OR_DX_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DY_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<W_OR_DW_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<SCALE_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<RELU_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DCONV_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DRELU_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DGRAD_INPUT_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DGRAD_OPTIONAL_TENSOR>(tensors).describe());
// Define the convolution problem
auto convDesc = cudnn_frontend::ConvDescBuilder()
.setDataType(CUDNN_DATA_FLOAT)
.setMathMode(CUDNN_CROSS_CORRELATION)
.setNDims(convDim)
.setStrides(convDim, convstride)
.setPrePadding(convDim, pad)
.setPostPadding(convDim, pad)
.setDilation(convDim, dilation)
.build();
DEBUG_CUDNN_MSG(log_buf, convDesc.describe());
// optional add
auto addDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_ADD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, addDesc.describe());
// Define the activation backward operation
auto actDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_RELU_BWD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, actDesc.describe());
// Define the scale backward operation
auto scaleDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_MUL)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe());
float alpha = 1.0f;
float beta = 0.0f;
// Create a convolution Node
auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR)
.setdxDesc(std::get<AFTER_DCONV_TENSOR>(tensors))
.setwDesc(std::get<W_OR_DW_TENSOR>(tensors))
.setdyDesc(std::get<DY_TENSOR>(tensors))
.setcDesc(convDesc)
.setAlpha(alpha)
.setBeta(beta)
.build();
DEBUG_CUDNN_MSG(log_buf, conv_op.describe());
// Create add Node.
auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<AFTER_DCONV_TENSOR>(tensors))
.setbDesc(std::get<DGRAD_INPUT_TENSOR>(tensors))
.setyDesc(std::get<DGRAD_OPTIONAL_TENSOR>(tensors))
.setpwDesc(addDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, add_op.describe());
// TODO: do we need getOutputTensor(), and what it returns in backward case?
// Create an relu backward Node.
auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setdyDesc(std::get<DGRAD_OPTIONAL_TENSOR>(tensors))
.setxDesc(std::get<RELU_TENSOR>(tensors))
.setdxDesc(std::get<AFTER_DRELU_TENSOR>(tensors))
.setpwDesc(actDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, act_op.describe());
// Create a Scale Node.
auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<AFTER_DRELU_TENSOR>(tensors))
.setbDesc(std::get<SCALE_TENSOR>(tensors))
.setyDesc(std::get<X_OR_DX_TENSOR>(tensors))
.setpwDesc(scaleDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, scale_op.describe());
// Create an Operation Graph. In this case it is convolution add bias activation
std::array<cudnn_frontend::Operation const*, 4> ops = {&conv_op, &add_op, &act_op, &scale_op};
auto opGraph = cudnn_frontend::OperationGraphBuilder()
.setHandle(handle_)
.setOperationGraph(ops.size(), ops.data())
.build();
// Create string encoding for plan caching
auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());
DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
auto workspace_size = plan.getWorkspaceSize();
DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
void* workspace_ptr = nullptr;
auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
if (workspace_size > 0) {
workspace_ptr = workspace_tensor.data_ptr<float>();
}
void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrR, devPtrI};
int64_t uids[] = {'x', 'y', 'w', 's', 'r', 'i'};
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace_ptr)
.setDataPointers(6, data_ptrs)
.setUids(6, uids)
.build();
DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
checkCudnnErr(status);
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
} catch (cudnn_frontend::cudnnException e) {
std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
}
}
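In plain PyTorch terms, the operation graph built above computes roughly dx = drelu(dgrad(dy, w) + partial) * scale, where "partial" is the already-computed boundary gradient passed in through the new 'i' tensor. A hedged, eager-mode sketch (this is not the cuDNN frontend API; the ReLU-backward is approximated by masking with the sign of the saved relu1 activation):

import torch
import torch.nn.functional as F

def dconv_add_drelu_dscale(dy, w, partial, relu_in, scale, padding=(1, 1)):
    after_dconv = F.conv_transpose2d(dy, w, padding=padding)   # backward-data conv
    after_add = after_dconv + partial                          # add of the 'i' tensor
    after_drelu = after_add * (relu_in > 0)                    # ReLU backward vs. saved relu1
    return after_drelu * scale                                 # elementwise scale by 's'

# One-row halo correction shaped like the bottleneck use case.
N, C, W = 1, 4, 8
dy = torch.randn(N, C, 1, W)          # grad_out2 halo row
w1by3 = torch.randn(C, C, 1, 3)       # 1x3 slice of the 3x3 weight
partial = torch.randn(N, C, 1, W)     # grad_out1 row computed with a zero halo
relu1 = torch.randn(N, C, 1, W)
scale = torch.randn(1, C, 1, 1)
print(dconv_add_drelu_dscale(dy, w1by3, partial, relu1, scale, padding=(0, 1)).shape)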
void
run_dconv_drelu_dscale_mask(int64_t* x_dim_padded,
int64_t* pad,
int64_t* convstride,
int64_t* dilation,
int64_t* w_dim_padded,
int64_t* y_dim_padded,
int64_t* threshold_dim,
cudnnDataType_t dataType,
at::Half* devPtrX,
at::Half* devPtrW,
at::Half* devPtrY,
at::Half* devPtrZ,
at::Half* devPtrR,
int* devPtrT,
int* devPtrU,
int axis) {
cudnnHandle_t handle_ = torch::native::getCudnnHandle();
std::stringstream log_buf;
try {
int convDim = 2;
// Creates the necessary tensor descriptors
dconv_mask_descriptors tensors = create_dconv_mask_descriptors(
x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, threshold_dim, dataType);
DEBUG_CUDNN_MSG(log_buf, std::get<X_OR_DX_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DY_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<W_OR_DW_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<SCALE_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<RELU_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DCONV_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DRELU_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DGRAD_OPTIONAL_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DGRAD_GEN_INDEX_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DGRAD_MASK_TOP_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DGRAD_MASK_BOTTOM_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DGRAD_MASK_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DGRAD_THRESHOLD_TOP_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DGRAD_THRESHOLD_BOTTOM_TENSOR>(tensors).describe());
// Define the convolution problem
auto convDesc = cudnn_frontend::ConvDescBuilder()
.setDataType(CUDNN_DATA_FLOAT)
.setMathMode(CUDNN_CROSS_CORRELATION)
.setNDims(convDim)
.setStrides(convDim, convstride)
.setPrePadding(convDim, pad)
.setPostPadding(convDim, pad)
.setDilation(convDim, dilation)
.build();
DEBUG_CUDNN_MSG(log_buf, convDesc.describe());
// Define the activation backward operation
auto actDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_RELU_BWD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, actDesc.describe());
// Define the scale backward operation
auto scaleDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_MUL)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe());
// Define the genIndex descriptor
auto genIndexDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_GEN_INDEX)
.setMathPrecision(CUDNN_DATA_FLOAT)
.setAxis(axis)
.build();
DEBUG_CUDNN_MSG(log_buf, genIndexDesc.describe());
// Define the lessThan descriptor
auto lessThanDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_CMP_LT)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, lessThanDesc.describe());
// Define the greaterThan descriptor
auto greaterThanDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_CMP_GT)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, greaterThanDesc.describe());
// Define the logical_or descriptor
auto logicalOrDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_LOGICAL_OR)
.setMathPrecision(CUDNN_DATA_BOOLEAN)
.build();
DEBUG_CUDNN_MSG(log_buf, logicalOrDesc.describe());
// Define the binary_selection descriptor
auto selectionDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_BINARY_SELECT)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, selectionDesc.describe());
float alpha = 1.0f;
float beta = 0.0f;
// Create a convolution Node
auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR)
.setdxDesc(std::get<AFTER_DCONV_TENSOR>(tensors))
.setwDesc(std::get<W_OR_DW_TENSOR>(tensors))
.setdyDesc(std::get<DY_TENSOR>(tensors))
.setcDesc(convDesc)
.setAlpha(alpha)
.setBeta(beta)
.build();
DEBUG_CUDNN_MSG(log_buf, conv_op.describe());
// TODO: do we need getOutputTensor(), and what it returns in backward case?
// Create an relu backward Node.
auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setdyDesc(std::get<AFTER_DCONV_TENSOR>(tensors))
.setxDesc(std::get<RELU_TENSOR>(tensors))
.setdxDesc(std::get<AFTER_DRELU_TENSOR>(tensors))
.setpwDesc(actDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, act_op.describe());
// Create a Scale Node.
auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<AFTER_DRELU_TENSOR>(tensors))
.setbDesc(std::get<SCALE_TENSOR>(tensors))
.setyDesc(std::get<DGRAD_OPTIONAL_TENSOR>(tensors))
.setpwDesc(scaleDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, scale_op.describe());
// Create a Gen_Index Node.
auto genIndex_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<DGRAD_OPTIONAL_TENSOR>(tensors))
.setyDesc(std::get<DGRAD_GEN_INDEX_TENSOR>(tensors))
.setpwDesc(genIndexDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, genIndex_op.describe());
// Create a LessThan Node.
auto lessThan_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<DGRAD_GEN_INDEX_TENSOR>(tensors))
.setbDesc(std::get<DGRAD_THRESHOLD_TOP_TENSOR>(tensors))
.setyDesc(std::get<DGRAD_MASK_TOP_TENSOR>(tensors))
.setpwDesc(lessThanDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, lessThan_op.describe());
// Create a GreaterThan Node.
auto greaterThan_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<DGRAD_GEN_INDEX_TENSOR>(tensors))
.setbDesc(std::get<DGRAD_THRESHOLD_BOTTOM_TENSOR>(tensors))
.setyDesc(std::get<DGRAD_MASK_BOTTOM_TENSOR>(tensors))
.setpwDesc(greaterThanDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, greaterThan_op.describe());
// Create a LogicalOr Node.
auto logicalOr_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<DGRAD_MASK_TOP_TENSOR>(tensors))
.setbDesc(std::get<DGRAD_MASK_BOTTOM_TENSOR>(tensors))
.setyDesc(std::get<DGRAD_MASK_TENSOR>(tensors))
.setpwDesc(logicalOrDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, logicalOr_op.describe());
// Create a Binary_Selection Node.
auto selection_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<AFTER_DCONV_TENSOR>(tensors))
.setbDesc(std::get<DGRAD_OPTIONAL_TENSOR>(tensors))
.settDesc(std::get<DGRAD_MASK_TENSOR>(tensors))
.setyDesc(std::get<X_OR_DX_TENSOR>(tensors))
.setpwDesc(selectionDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, selection_op.describe());
// Create an Operation Graph. In this case it is convolution add bias activation
std::array<cudnn_frontend::Operation const*, 8> ops = {&conv_op, &act_op, &scale_op, &genIndex_op, &lessThan_op, &greaterThan_op, &logicalOr_op, &selection_op};
auto opGraph = cudnn_frontend::OperationGraphBuilder()
.setHandle(handle_)
.setOperationGraph(ops.size(), ops.data())
.build();
// Create string encoding for plan caching
auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());
DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
auto workspace_size = plan.getWorkspaceSize();
DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
void* workspace_ptr = nullptr;
auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
if (workspace_size > 0) {
workspace_ptr = workspace_tensor.data_ptr<float>();
}
void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrR, devPtrT, devPtrU};
int64_t uids[] = {'x', 'y', 'w', 's', 'r', 't', 'u'};
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace_ptr)
.setDataPointers(7, data_ptrs)
.setUids(7, uids)
.build();
DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
checkCudnnErr(status);
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
} catch (cudnn_frontend::cudnnException e) {
std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
}
}
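Putting the mask pieces together, run_dconv_drelu_dscale_mask computes the fully fused dgrad + ReLU-backward + scale for interior rows, while rows whose generated H index falls outside the two thresholds are selected differently so the halo-correction kernel can finish them. A hedged eager-mode sketch; it assumes the binary select keeps the raw dconv value on the boundary rows, which is what run_dconv_add_drelu_dscale (add first, then ReLU-backward and scale) expects as its "partial" input:

import torch
import torch.nn.functional as F

def dconv_drelu_dscale_mask(dy, w, relu_in, scale, t_top, t_bottom):
    after_dconv = F.conv_transpose2d(dy, w, padding=1)             # backward-data conv
    full = after_dconv * (relu_in > 0) * scale                     # drelu + dscale
    H = full.shape[2]
    idx = torch.arange(H, device=full.device).view(1, 1, H, 1)     # gen-index on axis 2
    boundary = (idx < t_top) | (idx > t_bottom)                    # CMP_LT | CMP_GT -> logical OR
    return torch.where(boundary, after_dconv, full)                # binary select

N, C, H, W = 1, 4, 8, 8
out = dconv_drelu_dscale_mask(torch.randn(N, C, H, W), torch.randn(C, C, 3, 3),
                              torch.randn(N, C, H, W), torch.randn(1, C, 1, 1),
                              t_top=1, t_bottom=H - 2)
print(out.shape)   # torch.Size([1, 4, 8, 8])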
void
run_dconv(int64_t* x_dim_padded,
int64_t* pad,
@@ -2708,6 +3257,7 @@ struct bottleneck_backward_state {
int64_t filterdimA3[4];
int64_t filterdimA4[4];
int64_t filterdimA2hh[4]; // Cin,Cout,1,3
int64_t threshdim[4];
int axis[4];
@@ -2734,6 +3284,7 @@ struct bottleneck_backward_state {
int64_t outdim2[4];
int64_t outdim3[4];
int64_t outdim1h[4];
int64_t outdim1hh[4];
void init(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs) {
// setup dimensions
@@ -2743,6 +3294,7 @@ struct bottleneck_backward_state {
filterdimA3[0] = filterdimA3[1] = filterdimA3[2] = filterdimA3[3] = 0;
filterdimA4[0] = filterdimA4[1] = filterdimA4[2] = filterdimA4[3] = 0;
filterdimA2hh[0] = filterdimA2hh[1] = filterdimA2hh[2] = filterdimA2hh[3] = 0;
threshdim[0] = threshdim[1] = threshdim[2] = threshdim[3] = 1;
// All dim calculation after this order of n,c,h,w
if (explicit_nhwc) {
@@ -2841,6 +3393,7 @@ struct bottleneck_backward_state {
outdim2[0] = outdim2[1] = outdim2[2] = outdim2[3] = 0;
outdim3[0] = outdim3[1] = outdim3[2] = outdim3[3] = 0;
outdim1h[0] = outdim1h[1] = outdim1h[2] = outdim1h[3] = 0;
outdim1hh[0] = outdim1hh[1] = outdim1hh[2] = outdim1hh[3] = 0;
filterdim2hh[0] = filterdim2hh[1] = filterdim2hh[2] = filterdim2hh[3] = 0;
if (explicit_nhwc) {
axis[0] = 0;
@@ -2854,6 +3407,7 @@ struct bottleneck_backward_state {
outdim2[dim] = outdimA2[axis[dim]];
outdim3[dim] = outdimA3[axis[dim]];
outdim1h[dim] = outdimA1h[axis[dim]];
outdim1hh[dim] = outdimA1hh[axis[dim]];
filterdim2hh[dim] = filterdimA2hh[axis[dim]];
}
}
@@ -2967,6 +3521,7 @@ at::Tensor bottleneck_backward_grad_out1(bool explicit_nhwc, int stride_1X1, std
//printf("relu.shape = [%d,%d,%d,%d]\n",inputs[12].size(0),inputs[12].size(1),inputs[12].size(2),inputs[12].size(3));
// fused dgrad
//printf("backward_state.outdim1 = {%d,%d,%d,%d}\n",backward_state.outdim1[0],backward_state.outdim1[1],backward_state.outdim1[2],backward_state.outdim1[3]);
run_dconv_drelu_dscale(backward_state.outdimA1,
backward_state.padA1,
backward_state.convstrideA,
@@ -2983,6 +3538,88 @@ at::Tensor bottleneck_backward_grad_out1(bool explicit_nhwc, int stride_1X1, std
return grad_out1;
}
at::Tensor bottleneck_backward_grad_out1_mask(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2, at::Tensor thresholdTop, at::Tensor thresholdBottom) {
bool requires_grad = inputs[0].requires_grad();
std::cout << std::fixed;
auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
// dgrad
at::Half* dy2 = grad_out2.data_ptr<at::Half>();
// dgrad
auto grad_out1 = at::empty(backward_state.outdim1, inputs[0].type(), output_format);
at::Half* dy1 = grad_out1.data_ptr<at::Half>();
at::Half* w = inputs[2].data_ptr<at::Half>();
at::Half* z = inputs[4].data_ptr<at::Half>();
at::Half* relu1 = inputs[12].data_ptr<at::Half>();
//printf("relu.shape = [%d,%d,%d,%d]\n",inputs[12].size(0),inputs[12].size(1),inputs[12].size(2),inputs[12].size(3));
// fused dgrad
run_dconv_drelu_dscale_mask(backward_state.outdimA1,
backward_state.padA1,
backward_state.convstrideA,
backward_state.dilationA,
backward_state.filterdimA2,
backward_state.outdimA2,
backward_state.threshdim,
CUDNN_DATA_HALF,
dy1,
w,
dy2,
z,
relu1,
thresholdTop.data_ptr<int>(),
thresholdBottom.data_ptr<int>(),
2);
return grad_out1;
}
// perform backward data 1x3 convolution (grad_out * w_rot180) on grad_out2 input of shape [N,1,W,C] with padding=(0,1) to produce output of shape [N,1,W,C]
at::Tensor bottleneck_backward_grad_out1_halo_corr(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, at::Tensor w1by3, std::vector<at::Tensor> outputs, at::Tensor grad_out2_halo, at::Tensor relu1_halo, at::Tensor part_grad_out1) {
bool requires_grad = inputs[0].requires_grad();
std::cout << std::fixed;
auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
// dgrad
at::Half* dy2h = grad_out2_halo.data_ptr<at::Half>();
// dgrad
auto grad_out1_halo = at::empty(backward_state.outdim1hh, inputs[0].type(), output_format);
at::Half* dy1h = grad_out1_halo.data_ptr<at::Half>();
//at::Half* w = inputs[2].data_ptr<at::Half>(); // use w1by3 instead, which is a sliced version of inputs[2]
at::Half* w = w1by3.data_ptr<at::Half>();
at::Half* z = inputs[4].data_ptr<at::Half>();
at::Half* relu1h = relu1_halo.data_ptr<at::Half>();
at::Half* pdy1h = part_grad_out1.data_ptr<at::Half>();
//printf("relu.shape = [%d,%d,%d,%d]\n",relu1_halo.size(0),relu1_halo.size(1),relu1_halo.size(2),relu1_halo.size(3));
// fused dgrad
//printf("backward_state.outdimA1h = {%d,%d,%d,%d}\n",backward_state.outdimA1h[0],backward_state.outdimA1h[1],backward_state.outdimA1h[2],backward_state.outdimA1h[3]);
//printf("backward_state.outdimA2h = {%d,%d,%d,%d}\n",backward_state.outdimA2h[0],backward_state.outdimA2h[1],backward_state.outdimA2h[2],backward_state.outdimA2h[3]);
//printf("backward_state.filterdimA2 = {%d,%d,%d,%d}\n",backward_state.filterdimA2[0],backward_state.filterdimA2[1],backward_state.filterdimA2[2],backward_state.filterdimA2[3]);
run_dconv_add_drelu_dscale(backward_state.outdimA1hh,
backward_state.padA2, // 0,1
backward_state.convstrideA,
backward_state.dilationA,
backward_state.filterdimA2hh, // C,1,3,C
backward_state.outdimA2hh,
CUDNN_DATA_HALF,
dy1h,
w,
dy2h,
z,
relu1h,
pdy1h);
return grad_out1_halo;
}
// perform backward data 3x3 convolution (grad_out * w_rot180) on grad_out2 input of shape [N,3,W,C] with padding=(1,1) to produce output of shape [N,3,W,C]
at::Tensor bottleneck_backward_grad_out1_halo(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2_halo, at::Tensor relu1_halo) {
@@ -3303,7 +3940,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("backward_init", &bottleneck_backward_init, "Bottleneck block backward init");
m.def("backward_grad_out2", &bottleneck_backward_grad_out2, "Bottleneck block backward");
m.def("backward_grad_out1", &bottleneck_backward_grad_out1, "Bottleneck block backward");
m.def("backward_grad_out1_mask", &bottleneck_backward_grad_out1_mask, "Bottleneck block backward");
m.def("backward_grad_out1_halo", &bottleneck_backward_grad_out1_halo, "Bottleneck block backward"); m.def("backward_grad_out1_halo", &bottleneck_backward_grad_out1_halo, "Bottleneck block backward");
m.def("backward_grad_out1_halo_corr", &bottleneck_backward_grad_out1_halo_corr, "Bottleneck block backward");
m.def("backward_wgrad2_pad", &bottleneck_backward_wgrad2_pad, "Bottleneck block backward"); m.def("backward_wgrad2_pad", &bottleneck_backward_wgrad2_pad, "Bottleneck block backward");
m.def("backward_wgrad2", &bottleneck_backward_wgrad2, "Bottleneck block backward"); m.def("backward_wgrad2", &bottleneck_backward_wgrad2, "Bottleneck block backward");
m.def("backward_wgrad2_halo", &bottleneck_backward_wgrad2_halo, "Bottleneck block backward"); m.def("backward_wgrad2_halo", &bottleneck_backward_wgrad2_halo, "Bottleneck block backward");
......