Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-harmonics
Commits
51200bda
Commit
51200bda
authored
Jul 11, 2025
by
Mauro Bisson
Committed by
Thorsten Kurth
Jul 15, 2025
Browse files
Removed stale comments.
parent
07fa44d6
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
6 additions
and
14 deletions
+6
-14
torch_harmonics/csrc/attention/attention_bwd_cuda.cu
torch_harmonics/csrc/attention/attention_bwd_cuda.cu
+6
-14
No files found.
torch_harmonics/csrc/attention/attention_bwd_cuda.cu
View file @
51200bda
...
...
@@ -1007,9 +1007,9 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda(at::Te
if
(
!
dy_channel_first
)
{
dyP
=
permute_4D_floatT_to0231
(
dy
,
stream
);
}
torch
::
Tensor
dkxP
=
torch
::
zeros_like
(
kxP
);
// dkx: [batch][hi][wi][chan]
torch
::
Tensor
dvxP
=
torch
::
zeros_like
(
vxP
);
// dvx: [batch][hi][wi][chan]
torch
::
Tensor
dqyP
=
torch
::
zeros_like
(
qyP
);
// dqy: [batch][ho][wo][chan]
torch
::
Tensor
dkxP
=
torch
::
zeros_like
(
kxP
);
torch
::
Tensor
dvxP
=
torch
::
zeros_like
(
vxP
);
torch
::
Tensor
dqyP
=
torch
::
zeros_like
(
qyP
);
s2_attn_bwd_dispatch
(
batch_size
,
uo_num_channels
,
...
...
@@ -1023,22 +1023,14 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda(at::Te
dkxP
,
dvxP
,
dqyP
,
// out tensors
stream
);
torch
::
Tensor
dkx
=
dkxP
;
// dkx: [batch][hi][wi][chan]
torch
::
Tensor
dvx
=
dvxP
;
// dvx: [batch][hi][wi][chan]
torch
::
Tensor
dqy
=
dqyP
;
// dqy: [batch][ho][wo][chan]
torch
::
Tensor
dkx
=
dkxP
;
torch
::
Tensor
dvx
=
dvxP
;
torch
::
Tensor
dqy
=
dqyP
;
if
(
!
kx_channel_first
)
{
dkx
=
permute_4D_floatT_to0312
(
dkxP
,
stream
);
}
if
(
!
vx_channel_first
)
{
dvx
=
permute_4D_floatT_to0312
(
dvxP
,
stream
);
}
if
(
!
qy_channel_first
)
{
dqy
=
permute_4D_floatT_to0312
(
dqyP
,
stream
);
}
// printf("dydk strides:[");
// for(auto& stride : dydk.strides()) {
// printf("%ld,", stride);
// }
// printf("]\n");
// cudaDeviceSynchronize();
// delete permute_output_timer;
// nvtxRangePop();
return
std
::
make_tuple
(
dkx
,
dvx
,
dqy
);
#endif
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment