Unverified commit 1eed5673, authored by Thorsten Kurth, committed by GitHub

Merge pull request #85 from azrael417/tkurth/attention_bwd_layout_stuff

using torch tools to change the memory layout in the bwd pass
parents 49a61eee 191ba149
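The diff below replaces a hand-rolled permute/contiguous round-trip with PyTorch's native channels-last memory format. As a minimal, self-contained libtorch sketch of that idea (this is not code from the repository; the tensor shape is illustrative), both approaches yield a tensor that keeps the logical [batch, channel, ho, wo] shape while the data is stored channels-last:

#include <torch/torch.h>
#include <iostream>

int main() {
    // [batch, channel, ho, wo] input, NCHW-contiguous
    auto x = torch::randn({2, 8, 16, 32});

    // Old approach: copy the data into [batch, ho, wo, channel] order, then
    // permute the view back so the logical shape stays [batch, channel, ho, wo].
    auto x_old = x.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});

    // New approach: let torch perform the same layout change directly.
    auto x_new = x.to(at::MemoryFormat::ChannelsLast);

    // Both have the same logical shape and the same channels-last strides.
    std::cout << x_old.strides() << " vs " << x_new.strides() << std::endl;
    std::cout << x_new.is_contiguous(at::MemoryFormat::ChannelsLast) << std::endl;
    return 0;
}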
@@ -51,7 +51,7 @@
 #define THREADS (64)
 #endif
 #ifndef DIV_UP
-#define DIV_UP(a, b) (((a) + ((b)-1)) / (b))
+#define DIV_UP(a, b) (((a) + ((b) - 1)) / (b))
 #endif
 #ifndef CHECK_CUDA
 #define CHECK_CUDA(call) \
@@ -233,44 +233,28 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda(at::Te
     auto stream = at::cuda::getCurrentCUDAStream().stream();
 
-    auto k_channel_first = kx.strides()[1] == 1;
-    auto v_channel_first = vx.strides()[1] == 1;
-    auto q_channel_first = qy.strides()[1] == 1;
-    auto dy_channel_first = dy.strides()[1] == 1;
-
     // Transpose to [batch, ho, wo, channel]
     nvtxRangePush("s2_attention_bwd_dkvq_kernel_mbT permute inputs");
     // auto* permute_timer = new ScopeTimer("permute inputs");
-    // Permute kx,vx,qy,dy to [batch, ho, wo, channel] in memory layout, but keep the original shape [batch, channel, ho, wo]
-    auto kxP = at::Tensor();
-    if (!k_channel_first) {
-        // printf("Permuting kx from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
-        kxP = kx.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
-    } else {
-        kxP = kx;
-    }
-    auto vxP = at::Tensor();
-    if (!v_channel_first) {
-        // printf("Permuting vx from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
-        vxP = vx.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
-    } else {
-        vxP = vx;
-    }
-    auto qyP = at::Tensor();
-    if (!q_channel_first) {
-        // printf("Permuting qy from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
-        qyP = qy.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
-    } else {
-        qyP = qy;
-    }
-    auto dyP = at::Tensor();
-    if (!dy_channel_first) {
-        // printf("Permuting dy from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
-        dyP = dy.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
-    } else {
-        dyP = dy;
-    }
+    // extract dtype
+    auto kx_type = kx.dtype();
+    auto vx_type = vx.dtype();
+    auto qy_type = qy.dtype();
+    auto dy_type = dy.dtype();
+
+    // extract memory format
+    auto kx_is_channels_last = kx.is_contiguous(at::MemoryFormat::ChannelsLast);
+    auto vx_is_channels_last = vx.is_contiguous(at::MemoryFormat::ChannelsLast);
+    auto qy_is_channels_last = qy.is_contiguous(at::MemoryFormat::ChannelsLast);
+    auto dy_is_channels_last = dy.is_contiguous(at::MemoryFormat::ChannelsLast);
+
+    // convert to channels-last
+    auto kxP = kx.to(torch::kFloat32).to(at::MemoryFormat::ChannelsLast);
+    auto vxP = vx.to(torch::kFloat32).to(at::MemoryFormat::ChannelsLast);
+    auto qyP = qy.to(torch::kFloat32).to(at::MemoryFormat::ChannelsLast);
+    auto dyP = dy.to(torch::kFloat32).to(at::MemoryFormat::ChannelsLast);
 
     // cudaDeviceSynchronize();
     // delete permute_timer;
     nvtxRangePop();
@@ -312,10 +296,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda(at::Te
     CHECK_CUDA(cudaEventElapsedTime(&milliseconds, start, stop));
 
     // [1, 256, 1, (721, 1440), (721, 1440), "equiangular", "equiangular", 1e-5, 1e-5],
-    // s2_attention_bwd_kernel execution time: 50.724865 ms
-    // [1, 256, 1, (361, 720), (361, 720), "equiangular", "equiangular", 1e-5, 1e-5],
-    // s2_attention_bwd_kernel execution time: 11.679744 ms
-    // printf("s2_attention_bwd_kernel execution time: %f ms\n", milliseconds);
+    // s2_attention_bwd_kernel_mbT execution time: 63.280128 ms
+    // printf("s2_attention_bwd_kernel_mbT execution time: %f ms\n", milliseconds);
 
     CHECK_CUDA(cudaEventDestroy(start));
     CHECK_CUDA(cudaEventDestroy(stop));
@@ -324,11 +306,30 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda(at::Te
     // Permute outputs back to memory layout given by input. if input had channels
     // first, leave it in that layout, otherwise permute layout back to [batch,
     // channel, ho, wo]
-    if (!k_channel_first) dydk = dydk.contiguous();
-    if (!v_channel_first) dydv = dydv.contiguous();
-    if (!q_channel_first) dydq = dydq.contiguous();
-    // printf("dydk strides:[");
+
+    // convert back to original dtype
+    dydk = dydk.to(kx_type);
+    dydv = dydv.to(vx_type);
+    dydq = dydq.to(qy_type);
+
+    // permute back to original layout
+    if (!kx_is_channels_last) {
+        dydk = dydk.to(kx_type).to(at::MemoryFormat::Contiguous);
+    } else {
+        dydk = dydk.to(kx_type);
+    }
+    if (!vx_is_channels_last) {
+        dydv = dydv.to(vx_type).to(at::MemoryFormat::Contiguous);
+    } else {
+        dydv = dydv.to(vx_type);
+    }
+    if (!qy_is_channels_last) {
+        dydq = dydq.to(qy_type).to(at::MemoryFormat::Contiguous);
+    } else {
+        dydq = dydq.to(qy_type);
+    }
+
+    // printf("dydk strides: [");
     // for(auto& stride : dydk.strides()) {
     //     printf("%ld,", stride);
     // }
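For context, the backward pass in this diff records each input's dtype and memory format up front, runs the kernel on float32 channels-last copies, and converts the gradients back afterwards. The sketch below mirrors that save/restore pattern outside the kernel code; it is an illustration only, and the helper name restore_like is hypothetical, not part of the repository:

#include <torch/torch.h>
#include <iostream>

// Hypothetical helper: cast a float32, channels-last gradient back to the
// dtype of "ref", and to NCHW-contiguous layout if "ref" was not channels-last.
static at::Tensor restore_like(const at::Tensor& grad, const at::Tensor& ref) {
    auto out = grad.to(ref.scalar_type());
    if (!ref.is_contiguous(at::MemoryFormat::ChannelsLast)) {
        out = out.to(at::MemoryFormat::Contiguous);
    }
    return out;
}

int main() {
    // stand-in for an input such as kx: half precision, NCHW-contiguous
    auto kx = torch::randn({2, 8, 16, 32}).to(torch::kFloat16);

    // record dtype/layout implicitly via kx, compute in float32 channels-last
    auto kxP = kx.to(torch::kFloat32).to(at::MemoryFormat::ChannelsLast);
    auto dydk = torch::ones_like(kxP);  // stand-in for the gradient the kernel writes

    // hand the gradient back in the caller's dtype and layout
    auto dydk_out = restore_like(dydk, kx);
    std::cout << dydk_out.dtype() << " " << dydk_out.is_contiguous() << std::endl;
    return 0;
}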