OpenDAS / torch-harmonics · Commit 1eed5673 (unverified)

Authored Jul 04, 2025 by Thorsten Kurth; committed by GitHub, Jul 04, 2025

Merge pull request #85 from azrael417/tkurth/attention_bwd_layout_stuff

using torch tools to change layout in bwd pass

Parents: 49a61eee 191ba149
Changes: 1 changed file with 44 additions and 43 deletions

torch_harmonics/csrc/attention/attention_bwd_cuda.cu  (+44 −43)
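In short, the backward pass no longer hand-permutes kx, vx, qy, and dy into channels-last order with stride checks and permute()/contiguous() round-trips; it records each input's dtype and memory format, converts everything with Tensor::to, and maps the gradients back at the end. A minimal standalone sketch of the two patterns for a [batch, channel, h, w] tensor (hypothetical helper names, not code from this commit):

    #include <torch/torch.h>

    // Old pattern: force NHWC memory order by permuting to [b, h, w, c],
    // materializing, then permuting the view back to [b, c, h, w].
    torch::Tensor channels_last_manual(const torch::Tensor& x) {
        return x.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
    }

    // New pattern: ask torch for the dtype and layout conversion directly,
    // as the commit does for all four inputs.
    torch::Tensor channels_last_via_to(const torch::Tensor& x) {
        return x.to(torch::kFloat32).to(at::MemoryFormat::ChannelsLast);
    }

Both produce channels-last strides; the second form also pins the compute dtype to float32 and lets ATen pick the copy routine.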
torch_harmonics/csrc/attention/attention_bwd_cuda.cu  (view file @ 1eed5673)
@@ -51,7 +51,7 @@
 #define THREADS (64)
 #endif
 #ifndef DIV_UP
 #define DIV_UP(a, b) (((a) + ((b) - 1)) / (b))
 #endif
 #ifndef CHECK_CUDA
 #define CHECK_CUDA(call) \
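DIV_UP is the usual ceiling division used to size CUDA grids so a partial tail block still gets launched. For illustration only, a hypothetical elementwise launch covering n elements with THREADS-wide blocks (not a kernel from this file):

    // Hypothetical kernel: one thread per element.
    __global__ void scale_kernel(float* data, float alpha, int n) {
        int idx = blockIdx.x * blockDim.x + threadIdx.x;
        if (idx < n) data[idx] *= alpha;  // last block may be partially idle
    }

    // DIV_UP(n, THREADS) rounds up so all n elements are covered.
    void launch_scale(float* data, float alpha, int n, cudaStream_t stream) {
        scale_kernel<<<DIV_UP(n, THREADS), THREADS, 0, stream>>>(data, alpha, n);
    }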
@@ -233,44 +233,28 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda(at::Te
     auto stream = at::cuda::getCurrentCUDAStream().stream();

-    auto k_channel_first = kx.strides()[1] == 1;
-    auto v_channel_first = vx.strides()[1] == 1;
-    auto q_channel_first = qy.strides()[1] == 1;
-    auto dy_channel_first = dy.strides()[1] == 1;
-
     // Transpose to [batch, ho, wo, channel]
     nvtxRangePush("s2_attention_bwd_dkvq_kernel_mbT permute inputs");
     // auto* permute_timer = new ScopeTimer("permute inputs");
-    // Permute kx,vx,qy,dy to [batch, ho, wo, channel] in memory layout, but keep the original shape [batch, channel, ho, wo]
-    auto kxP = at::Tensor();
-    if (!k_channel_first) {
-        // printf("Permuting kx from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
-        kxP = kx.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
-    } else {
-        kxP = kx;
-    }
-    auto vxP = at::Tensor();
-    if (!v_channel_first) {
-        // printf("Permuting vx from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
-        vxP = vx.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
-    } else {
-        vxP = vx;
-    }
-    auto qyP = at::Tensor();
-    if (!q_channel_first) {
-        // printf("Permuting qy from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
-        qyP = qy.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
-    } else {
-        qyP = qy;
-    }
-    auto dyP = at::Tensor();
-    if (!dy_channel_first) {
-        // printf("Permuting dy from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
-        dyP = dy.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
-    } else {
-        dyP = dy;
-    }
+    // extract dtype
+    auto kx_type = kx.dtype();
+    auto vx_type = vx.dtype();
+    auto qy_type = qy.dtype();
+    auto dy_type = dy.dtype();
+
+    // exract memory format
+    auto kx_is_channels_last = kx.is_contiguous(at::MemoryFormat::ChannelsLast);
+    auto vx_is_channels_last = vx.is_contiguous(at::MemoryFormat::ChannelsLast);
+    auto qy_is_channels_last = qy.is_contiguous(at::MemoryFormat::ChannelsLast);
+    auto dy_is_channels_last = dy.is_contiguous(at::MemoryFormat::ChannelsLast);
+
+    // convert to channels-last
+    auto kxP = kx.to(torch::kFloat32).to(at::MemoryFormat::ChannelsLast);
+    auto vxP = vx.to(torch::kFloat32).to(at::MemoryFormat::ChannelsLast);
+    auto qyP = qy.to(torch::kFloat32).to(at::MemoryFormat::ChannelsLast);
+    auto dyP = dy.to(torch::kFloat32).to(at::MemoryFormat::ChannelsLast);

     // cudaDeviceSynchronize();
     // delete permute_timer;
     nvtxRangePop();
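The removed stride test (strides()[1] == 1, i.e. the channel dimension is innermost) and the new is_contiguous(at::MemoryFormat::ChannelsLast) query answer nearly the same question, with the latter being the canonical and stricter check. A small standalone illustration, not from the commit, assuming a 4-D NCHW-shaped tensor:

    #include <torch/torch.h>
    #include <cstdio>

    int main() {
        auto x   = torch::randn({2, 8, 4, 4});              // NCHW, row-major
        auto xcl = x.to(at::MemoryFormat::ChannelsLast);    // same shape, NHWC order

        // Old heuristic from the removed code: channel stride == 1.
        printf("stride check: %d %d\n",
               x.strides()[1] == 1, xcl.strides()[1] == 1);             // 0 1

        // New canonical query used by the commit.
        printf("is_contiguous(ChannelsLast): %d %d\n",
               x.is_contiguous(at::MemoryFormat::ChannelsLast),
               xcl.is_contiguous(at::MemoryFormat::ChannelsLast));      // 0 1
        return 0;
    }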
@@ -312,10 +296,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda(at::Te
     CHECK_CUDA(cudaEventElapsedTime(&milliseconds, start, stop));

     // [1, 256, 1, (721, 1440), (721, 1440), "equiangular", "equiangular", 1e-5, 1e-5],
-    // s2_attention_bwd_kernel execution time: 50.724865 ms
-    // [1, 256, 1, (361, 720), (361, 720), "equiangular", "equiangular", 1e-5, 1e-5],
-    // s2_attention_bwd_kernel execution time: 11.679744 ms
-    // printf("s2_attention_bwd_kernel execution time: %f ms\n", milliseconds);
+    // s2_attention_bwd_kernel_mbT execution time: 63.280128 ms
+    // printf("s2_attention_bwd_kernel_mbT execution time: %f ms\n", milliseconds);

     CHECK_CUDA(cudaEventDestroy(start));
     CHECK_CUDA(cudaEventDestroy(stop));
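For context, the unchanged code around this hunk follows the standard CUDA event idiom for timing stream work. A minimal sketch of that idiom, assuming the file's CHECK_CUDA error-checking macro is in scope and with a hypothetical launch() standing in for the real kernel:

    #include <cuda_runtime.h>

    // Times an arbitrary chunk of GPU work enqueued on `stream`.
    float time_on_stream(cudaStream_t stream, void (*launch)(cudaStream_t)) {
        cudaEvent_t start, stop;
        CHECK_CUDA(cudaEventCreate(&start));
        CHECK_CUDA(cudaEventCreate(&stop));

        CHECK_CUDA(cudaEventRecord(start, stream));
        launch(stream);                           // enqueue the work
        CHECK_CUDA(cudaEventRecord(stop, stream));
        CHECK_CUDA(cudaEventSynchronize(stop));   // wait for `stop` to land

        float milliseconds = 0.0f;
        CHECK_CUDA(cudaEventElapsedTime(&milliseconds, start, stop));
        CHECK_CUDA(cudaEventDestroy(start));
        CHECK_CUDA(cudaEventDestroy(stop));
        return milliseconds;
    }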
@@ -324,11 +306,30 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda(at::Te
     // Permute outputs back to memory layout given by input. if input had channels
     // first, leave it in that layout, otherwise permute layout back to [batch,
     // channel, ho, wo]
-    if (!k_channel_first) dydk = dydk.contiguous();
-    if (!v_channel_first) dydv = dydv.contiguous();
-    if (!q_channel_first) dydq = dydq.contiguous();
-
-    // printf("dydk strides:[");
+    // convert back to original dtype
+    dydk = dydk.to(kx_type);
+    dydv = dydv.to(vx_type);
+    dydq = dydq.to(qy_type);
+
+    // permute back to original layout
+    if (!kx_is_channels_last) {
+        dydk = dydk.to(kx_type).to(at::MemoryFormat::Contiguous);
+    } else {
+        dydk = dydk.to(kx_type);
+    }
+    if (!vx_is_channels_last) {
+        dydv = dydv.to(vx_type).to(at::MemoryFormat::Contiguous);
+    } else {
+        dydv = dydv.to(vx_type);
+    }
+    if (!qy_is_channels_last) {
+        dydq = dydq.to(qy_type).to(at::MemoryFormat::Contiguous);
+    } else {
+        dydq = dydq.to(qy_type);
+    }
+
+    // printf("dydk strides: [");
     // for(auto& stride : dydk.strides()) {
     //     printf("%ld,", stride);
     // }
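The restore step mirrors the input normalization: gradients leave the kernel in float32/channels-last and are mapped back to each input's recorded dtype and memory format. As a standalone sketch of that round trip (hypothetical names, not the file's code):

    #include <torch/torch.h>

    // Remembers how to undo the normalization applied to one input.
    struct LayoutGuard {
        caffe2::TypeMeta dtype;
        bool channels_last;
    };

    // Normalize an input for the kernel, recording its original convention.
    torch::Tensor normalize(const torch::Tensor& x, LayoutGuard& g) {
        g.dtype         = x.dtype();
        g.channels_last = x.is_contiguous(at::MemoryFormat::ChannelsLast);
        return x.to(torch::kFloat32).to(at::MemoryFormat::ChannelsLast);
    }

    // Map a float32/channels-last gradient back to the caller's convention.
    torch::Tensor restore(const torch::Tensor& grad, const LayoutGuard& g) {
        auto out = grad.to(g.dtype);
        return g.channels_last ? out : out.to(at::MemoryFormat::Contiguous);
    }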