Commit 51200bda authored by Mauro Bisson's avatar Mauro Bisson Committed by Thorsten Kurth
Browse files

Removed stale comments.

parent 07fa44d6
...@@ -1007,9 +1007,9 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda(at::Te ...@@ -1007,9 +1007,9 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda(at::Te
if (!dy_channel_first) { dyP = permute_4D_floatT_to0231(dy, stream); } if (!dy_channel_first) { dyP = permute_4D_floatT_to0231(dy, stream); }
torch::Tensor dkxP = torch::zeros_like(kxP); // dkx: [batch][hi][wi][chan] torch::Tensor dkxP = torch::zeros_like(kxP);
torch::Tensor dvxP = torch::zeros_like(vxP); // dvx: [batch][hi][wi][chan] torch::Tensor dvxP = torch::zeros_like(vxP);
torch::Tensor dqyP = torch::zeros_like(qyP); // dqy: [batch][ho][wo][chan] torch::Tensor dqyP = torch::zeros_like(qyP);
s2_attn_bwd_dispatch(batch_size, s2_attn_bwd_dispatch(batch_size,
uo_num_channels, uo_num_channels,
...@@ -1023,22 +1023,14 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda(at::Te ...@@ -1023,22 +1023,14 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda(at::Te
dkxP, dvxP, dqyP, // out tensors dkxP, dvxP, dqyP, // out tensors
stream); stream);
torch::Tensor dkx = dkxP; // dkx: [batch][hi][wi][chan] torch::Tensor dkx = dkxP;
torch::Tensor dvx = dvxP; // dvx: [batch][hi][wi][chan] torch::Tensor dvx = dvxP;
torch::Tensor dqy = dqyP; // dqy: [batch][ho][wo][chan] torch::Tensor dqy = dqyP;
if (!kx_channel_first) { dkx = permute_4D_floatT_to0312(dkxP, stream); } if (!kx_channel_first) { dkx = permute_4D_floatT_to0312(dkxP, stream); }
if (!vx_channel_first) { dvx = permute_4D_floatT_to0312(dvxP, stream); } if (!vx_channel_first) { dvx = permute_4D_floatT_to0312(dvxP, stream); }
if (!qy_channel_first) { dqy = permute_4D_floatT_to0312(dqyP, stream); } if (!qy_channel_first) { dqy = permute_4D_floatT_to0312(dqyP, stream); }
// printf("dydk strides:[");
// for(auto& stride : dydk.strides()) {
// printf("%ld,", stride);
// }
// printf("]\n");
// cudaDeviceSynchronize();
// delete permute_output_timer;
// nvtxRangePop();
return std::make_tuple(dkx, dvx, dqy); return std::make_tuple(dkx, dvx, dqy);
#endif #endif
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment