Unverified Commit 9959a7a6 authored by Thorsten Kurth, committed by GitHub

removing some comments and nvtx annotations (#88)

parent 744d2269
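For context on what is being removed: nvtxRangePush/nvtxRangePop bracket a code region so it shows up as a named range in Nsight Systems timelines. Below is a minimal standalone sketch of that pattern; the function name, range label, and build setup are illustrative (it assumes the legacy NVTX C API and linking with -lnvToolsExt), not the actual code in this file.

#include <nvToolsExt.h> // legacy NVTX C API (NVTX3 also ships nvtx3/nvToolsExt.h)

static void annotated_region()
{
    nvtxRangePush("permute inputs"); // open a named range; it appears on the profiler timeline
    // ... work attributed to this range, e.g. dtype casts and memory-format permutes ...
    nvtxRangePop();                  // close the innermost open range
}

int main()
{
    annotated_region();
    return 0;
}

Dropping the ranges only affects profiler annotations; the surrounding computation is unchanged.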
@@ -233,10 +233,6 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda(at::Te
     auto stream = at::cuda::getCurrentCUDAStream().stream();

-    // Transpose to [batch, ho, wo, channel]
-    nvtxRangePush("s2_attention_bwd_dkvq_kernel_mbT permute inputs");
-    // auto* permute_timer = new ScopeTimer("permute inputs");

     // extract dtype
     auto kx_type = kx.dtype();
     auto vx_type = vx.dtype();
@@ -255,16 +251,10 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda(at::Te
     auto qyP = qy.to(torch::kFloat32).to(at::MemoryFormat::ChannelsLast);
     auto dyP = dy.to(torch::kFloat32).to(at::MemoryFormat::ChannelsLast);

-    // cudaDeviceSynchronize();
-    // delete permute_timer;
-    nvtxRangePop();
-    nvtxRangePush("s2_attention_bwd_dkvq_kernel_mbT output allocation & zero");
+    // create output arrays
     auto dydk = torch::zeros_like(qyP);
     auto dydv = torch::zeros_like(qyP);
     auto dydq = torch::zeros_like(qyP);
-    // print strdie of dydkP, dydvP, dydqP
-    nvtxRangePop();

     size_t uo_num_channels = kx.size(1);
     const int batch_size = kx.size(0);
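A note on the permute lines kept as context in the hunk above: ATen's ChannelsLast memory format keeps the logical [batch, channel, ho, wo] sizes but re-strides the storage so the channel dimension is innermost, which is what the removed "Transpose to [batch, ho, wo, channel]" comment described, and torch::zeros_like preserves that layout for the gradient outputs. A small sketch with toy sizes (the real tensors are 256 channels on a 721x1440 grid):

#include <torch/torch.h>
#include <iostream>

int main()
{
    // logical layout [batch, channel, ho, wo], toy sizes for illustration
    auto qy = torch::randn({1, 8, 4, 6});

    // cast and re-stride to channels-last: sizes stay NCHW, strides become NHWC
    auto qyP = qy.to(torch::kFloat32).to(at::MemoryFormat::ChannelsLast);

    // zeros_like preserves the dtype and the channels-last strides of qyP
    auto dydk = torch::zeros_like(qyP);

    std::cout << dydk.sizes() << " " << dydk.strides() << std::endl;
    // prints: [1, 8, 4, 6] [192, 1, 48, 8]
    return 0;
}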
@@ -297,7 +287,6 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda(at::Te
     // [1, 256, 1, (721, 1440), (721, 1440), "equiangular", "equiangular", 1e-5, 1e-5],
     // s2_attention_bwd_kernel_mbT execution time: 63.280128 ms
-    // printf("s2_attention_bwd_kernel_mbT execution time: %f ms\n", milliseconds);
     CHECK_CUDA(cudaEventDestroy(start));
     CHECK_CUDA(cudaEventDestroy(stop));
@@ -329,13 +318,5 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda(at::Te
         dydq = dydq.to(qy_type);
     }
-    // printf("dydk strides: [");
-    // for(auto& stride : dydk.strides()) {
-    //     printf("%ld,", stride);
-    // }
-    // printf("]\n");
-    // cudaDeviceSynchronize();
-    // delete permute_output_timer;
-    // nvtxRangePop();
     return std::make_tuple(dydk, dydv, dydq);
 }
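The printf removed in the third hunk belonged to a cudaEvent-based timing block; the start/stop events it reads and the destroy calls remain in the context lines. A self-contained sketch of that timing pattern follows; the CHECK_CUDA macro here is a stand-in for whatever the file defines, and dummy_kernel stands in for the attention backward kernel.

#include <cuda_runtime.h>
#include <cstdio>

// stand-in error-checking macro; the real file defines its own CHECK_CUDA
#define CHECK_CUDA(call)                                                   \
    do {                                                                   \
        cudaError_t err_ = (call);                                         \
        if (err_ != cudaSuccess) {                                         \
            fprintf(stderr, "CUDA error %s at %s:%d\n",                    \
                    cudaGetErrorString(err_), __FILE__, __LINE__);         \
        }                                                                  \
    } while (0)

__global__ void dummy_kernel() {}

int main()
{
    cudaEvent_t start, stop;
    CHECK_CUDA(cudaEventCreate(&start));
    CHECK_CUDA(cudaEventCreate(&stop));

    CHECK_CUDA(cudaEventRecord(start, 0));
    dummy_kernel<<<1, 1>>>();               // stand-in for the attention backward kernel
    CHECK_CUDA(cudaEventRecord(stop, 0));
    CHECK_CUDA(cudaEventSynchronize(stop)); // wait until the stop event has executed

    float milliseconds = 0.0f;
    CHECK_CUDA(cudaEventElapsedTime(&milliseconds, start, stop));
    printf("kernel execution time: %f ms\n", milliseconds);

    CHECK_CUDA(cudaEventDestroy(start));
    CHECK_CUDA(cudaEventDestroy(stop));
    return 0;
}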