Commit 4096e64b authored by Max Rietmann

Optimizations for backward kernel: moved qy to shared memory, fixed memory layout

Detect the memory layout of the (B, C, H, W) inputs: the stride of the channel dimension should be 1; if it is not, permute the tensor so that it is.

This ensures that the backward kernel is fast.
parent b62c420f
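The layout check at the heart of this change is visible in the C++ hunks below: an input counts as channels-last when strides()[1] == 1, and anything else is re-laid-out with a permute/contiguous/permute round trip that changes the memory layout while keeping the logical (B, C, H, W) shape. A minimal PyTorch sketch of the same logic (the helper name is mine; the test sizes match the comment in the diff):

import torch

def ensure_channel_stride_one(t: torch.Tensor) -> torch.Tensor:
    """Return a (B, C, H, W) tensor whose channel stride is 1, copying only if needed."""
    # stride(1) == 1 means the channel dimension is unit-stride in memory,
    # i.e. the data is already laid out channels-last.
    if t.stride(1) == 1:
        return t
    # (B,C,H,W) -> (B,H,W,C) contiguous copy -> permute back: same logical
    # shape, but the channel dimension is now unit-stride.
    return t.permute(0, 2, 3, 1).contiguous().permute(0, 3, 1, 2)

x = torch.randn(1, 256, 721, 1440)
y = ensure_channel_stride_one(x)
assert y.shape == x.shape and y.stride(1) == 1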
@@ -247,7 +247,7 @@ class TestNeighborhoodAttentionS2(unittest.TestCase):
                              grid_in=grid_in, grid_out=grid_out, bias=True).to("cuda:0")
         time_layer_setup_end.record()
         torch.cuda.synchronize()
-        print(f"Layer setup: {time_layer_setup_start.elapsed_time(time_layer_setup_end)} ms")
+        # print(f"Layer setup: {time_layer_setup_start.elapsed_time(time_layer_setup_end)} ms")

         # random weights
         with torch.no_grad():
@@ -268,7 +268,9 @@ class TestNeighborhoodAttentionS2(unittest.TestCase):
         out_gpu = att_gpu(q_gpu, k_gpu, v_gpu)
         time_forward_end.record()
         torch.cuda.synchronize()
-        print(f"Forward execution: {time_forward_start.elapsed_time(time_forward_end)} ms")
+        elapsed_time = time_forward_start.elapsed_time(time_forward_end)
+        assert elapsed_time < 150, "Forward pass took much too long, there must be a performance regression!"

         # sync weights:
         with torch.no_grad():
@@ -279,11 +281,11 @@ class TestNeighborhoodAttentionS2(unittest.TestCase):
             att_gpu.k_bias.copy_(att_gpu.k_bias)
             att_gpu.v_bias.copy_(att_gpu.v_bias)

-        q_gpu = q_gpu.detach().clone().to(self.device, memory_format=torch.channels_last)
+        q_gpu = q_gpu.detach().clone().to(self.device)#, memory_format=torch.channels_last)
         q_gpu.requires_grad = True
-        k_gpu = k_gpu.detach().clone().to(self.device, memory_format=torch.channels_last)
+        k_gpu = k_gpu.detach().clone().to(self.device)#, memory_format=torch.channels_last)
         k_gpu.requires_grad = True
-        v_gpu = v_gpu.detach().clone().to(self.device, memory_format=torch.channels_last)
+        v_gpu = v_gpu.detach().clone().to(self.device)#, memory_format=torch.channels_last)
         v_gpu.requires_grad = True

         out_gpu = att_gpu(q_gpu, k_gpu, v_gpu)
@@ -291,18 +293,18 @@ class TestNeighborhoodAttentionS2(unittest.TestCase):
         time_backward_start = torch.cuda.Event(enable_timing=True)
         time_backward_end = torch.cuda.Event(enable_timing=True)
+        print("q_gpu_stride=",q_gpu.stride())
         for i in range(2):
             # warmup
             out_gpu.backward(out_grad, retain_graph=True)
-        print("out_grad_stride=",out_grad.stride())
+        # print("out_grad_stride=",out_grad.stride())
         time_backward_start.record()
         out_gpu.backward(out_grad)
         time_backward_end.record()
         torch.cuda.synchronize()
-        print(f"Backward execution: {time_backward_start.elapsed_time(time_backward_end)} ms")
+        # print(f"Backward execution: {time_backward_start.elapsed_time(time_backward_end)} ms")
+        elapsed_time = time_backward_start.elapsed_time(time_backward_end)
+        assert elapsed_time < 400, "Backward pass took much too long, there must be a performance regression!"

 if __name__ == "__main__":
...
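The new regression guards replace ad-hoc prints with the standard CUDA-event timing pattern: warm up, record events around the measured call, synchronize, then assert against a budget. A self-contained sketch of that pattern (the helper name is mine; the 400 ms budget comes from the backward assert above):

import torch

def assert_backward_within_budget(out_gpu, out_grad, budget_ms=400):
    """Time a single backward() with CUDA events and fail loudly on regression."""
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    for _ in range(2):
        # warmup; retain_graph keeps the autograd graph alive for the timed run
        out_gpu.backward(out_grad, retain_graph=True)
    start.record()
    out_gpu.backward(out_grad, retain_graph=True)
    end.record()
    # elapsed_time() is only valid once both events have completed
    torch.cuda.synchronize()
    elapsed_ms = start.elapsed_time(end)
    assert elapsed_ms < budget_ms, (
        f"Backward pass took {elapsed_ms:.1f} ms (budget {budget_ms} ms), "
        "there must be a performance regression!")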
@@ -1000,7 +1000,7 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
   auto version = HOWO_WARP_VERSION;
   // auto version = OLD_VERSION;
   if (version == OLD_VERSION) {
-    printf("old version\n");
+    // printf("old version\n");
     torch::Tensor dydk = torch::zeros_like(qy);
     torch::Tensor dydv = torch::zeros_like(qy);
     torch::Tensor dydq = torch::zeros_like(qy);
@@ -1062,34 +1062,56 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
     return std::make_tuple(dydk, dydv, dydq);
   } else if (version == HOWO_WARP_VERSION) {
-    ScopeTimer timer("Full s2_attention_bwd_dkvq_kernel_mbT");
+    // ScopeTimer timer("Full s2_attention_bwd_dkvq_kernel_mbT");
     // Time this function via C++
-    time_t start_time, end_time;
-    start_time = clock();
+    auto k_channel_first = kx.strides()[1] == 1;
+    auto v_channel_first = vx.strides()[1] == 1;
+    auto q_channel_first = qy.strides()[1] == 1;
+    auto dy_channel_first = dy.strides()[1] == 1;
     // Transpose to [batch, ho, wo, channel]
-    // nvtxRangePush("s2_attention_bwd_dkvq_kernel_mbT permute inputs");
+    nvtxRangePush("s2_attention_bwd_dkvq_kernel_mbT permute inputs");
     // auto* permute_timer = new ScopeTimer("permute inputs");
+    // auto kxP = kx.permute({0,2,3,1}).contiguous().permute({0,3,1,2});
+    // auto vxP = vx.permute({0,2,3,1}).contiguous().permute({0,3,1,2});
+    // auto qyP = qy.permute({0,2,3,1}).contiguous().permute({0,3,1,2});
+    // auto dyP = dy.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
+    // Permute kx,vx,qy,dy to [batch, ho, wo, channel] in memory layout, but keep the original shape [batch, channel, ho, wo]
+    auto kxP = at::Tensor();
+    if (!k_channel_first) {
+      // printf("Permuting kx from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
+      kxP = kx.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
+    } else {
+      kxP = kx;
+    }
+    auto vxP = at::Tensor();
+    if (!v_channel_first) {
+      // printf("Permuting vx from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
+      vxP = vx.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
+    } else {
+      vxP = vx;
+    }
+    auto qyP = at::Tensor();
+    if (!q_channel_first) {
+      // printf("Permuting qy from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
+      qyP = qy.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
+    } else {
+      qyP = qy;
+    }
+    auto dyP = at::Tensor();
+    if (!dy_channel_first) {
+      // printf("Permuting dy from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
+      dyP = dy.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
+    } else {
+      dyP = dy;
+    }
     // cudaDeviceSynchronize();
     // delete permute_timer;
-    // nvtxRangePop();
+    nvtxRangePop();
     nvtxRangePush("s2_attention_bwd_dkvq_kernel_mbT output allocation & zero");
-    auto dydkP = torch::zeros_like(qy);
-    auto dydvP = torch::zeros_like(qy);
-    auto dydqP = torch::zeros_like(qy);
+    auto dydkP = torch::zeros_like(qyP);
+    auto dydvP = torch::zeros_like(qyP);
+    auto dydqP = torch::zeros_like(qyP);
     // print stride of dydkP, dydvP, dydqP
-    printf("dydkP strides: ");
-    for(auto& stride_i :dydkP.strides()) {
-      printf("%ld ", stride_i);
-    }
-    printf("\n");
-    cudaDeviceSynchronize();
     nvtxRangePop();

     size_t uo_num_channels = kx.size(1);
@@ -1108,10 +1130,10 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
     s2_attention_bwd_dkvq_kernel_mbT<THREADS><<<
       grid, block, shared_size, stream>>>(
       uo_num_channels, nlon_in, nlat_out, nlon_out,
-      kx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
-      vx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
-      qy.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
-      dy.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
+      kxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
+      vxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
+      qyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
+      dyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
       dydkP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
       dydvP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
       dydqP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
@@ -1125,7 +1147,7 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
     // [1, 256, 1, (721, 1440), (721, 1440), "equiangular", "equiangular", 1e-5, 1e-5],
     // s2_attention_bwd_kernel_mbT execution time: 63.280128 ms
-    printf("s2_attention_bwd_kernel_mbT execution time: %f ms\n", milliseconds);
+    // printf("s2_attention_bwd_kernel_mbT execution time: %f ms\n", milliseconds);
     CHECK_CUDA(cudaEventDestroy(start));
     CHECK_CUDA(cudaEventDestroy(stop));
...
@@ -319,9 +319,12 @@ torch::Tensor s2_attention_fwd_cuda(at::Tensor kx,
   // transpose inputs so that channels are in the last dimension, allowing for
   // coalesced memory access
+  nvtxRangePush("s2_attention_fwd_kernel_mbT permute inputs");
   torch::Tensor kxP = kx.permute({0,2,3,1}).contiguous();
   torch::Tensor vxP = vx.permute({0,2,3,1}).contiguous();
-  torch::Tensor qyP = qy.permute({0,2,3,1}).contiguous();
+  torch::Tensor qyP = qy.permute({0, 2, 3, 1}).contiguous();
+  cudaDeviceSynchronize();
+  nvtxRangePop();

   torch::Tensor y = torch::empty_like(qy);

   dim3 block(WARP_SIZE, THREADS/WARP_SIZE);
...
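The forward-path hunk brackets the permutes with NVTX ranges and a device synchronize so the copies show up as a named span in an Nsight Systems trace. The same annotation is available from Python if you want to profile the test harness itself; a small sketch (range label and sizes illustrative):

import torch

kx = torch.randn(1, 256, 721, 1440, device="cuda")

# Open a named NVTX range, mirroring nvtxRangePush/nvtxRangePop in the extension.
torch.cuda.nvtx.range_push("permute inputs")
kxP = kx.permute(0, 2, 3, 1).contiguous()
torch.cuda.synchronize()  # ensure the range spans the device-side copy
torch.cuda.nvtx.range_pop()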