Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-harmonics
Commits
3dd35b45
Commit
3dd35b45
authored
Jul 03, 2025
by
Max Rietmann
Browse files
Fixed compile errors for ChannelsLast C++ code, unfortunately also format-on-save
parent
e1338191
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
265 additions
and
273 deletions
+265
-273
torch_harmonics/csrc/attention/attention_bwd_cuda.cu
torch_harmonics/csrc/attention/attention_bwd_cuda.cu
+265
-273
No files found.
torch_harmonics/csrc/attention/attention_bwd_cuda.cu
View file @
3dd35b45
...
...
@@ -51,49 +51,53 @@
#define THREADS (64)
#endif
#ifndef DIV_UP
#define DIV_UP(a,b) (((a)
+
((b)
-1))/
(b))
#define DIV_UP(a,
b) (((a)
+
((b)
- 1)) /
(b))
#endif
#ifndef CHECK_CUDA
#define CHECK_CUDA(call) { \
#define CHECK_CUDA(call) \
{ \
cudaError_t err = call; \
if( cudaSuccess != err) { \
fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\\n", \
__FILE__, __LINE__, cudaGetErrorString( err) ); \
if (cudaSuccess != err) { \
fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\\n", __FILE__, __LINE__, cudaGetErrorString(err)); \
exit(EXIT_FAILURE); \
}}
} \
}
#endif
#include <iostream>
#include <chrono>
#include <string>
class
ScopeTimer
{
public:
explicit
ScopeTimer
(
const
std
::
string
&
label
=
""
)
:
label_
(
label
),
start_
(
std
::
chrono
::
high_resolution_clock
::
now
())
{}
class
ScopeTimer
{
public:
explicit
ScopeTimer
(
const
std
::
string
&
label
=
""
)
:
label_
(
label
),
start_
(
std
::
chrono
::
high_resolution_clock
::
now
())
{
}
~
ScopeTimer
()
{
~
ScopeTimer
()
{
auto
end
=
std
::
chrono
::
high_resolution_clock
::
now
();
auto
elapsed
=
std
::
chrono
::
duration_cast
<
std
::
chrono
::
milliseconds
>
(
end
-
start_
);
std
::
cout
<<
label_
<<
"Elapsed time: "
<<
elapsed
.
count
()
<<
" ms"
<<
std
::
endl
;
}
private:
private:
std
::
string
label_
;
std
::
chrono
::
high_resolution_clock
::
time_point
start_
;
};
static
__device__
float
__warp_sum
(
float
val
)
{
static
__device__
float
__warp_sum
(
float
val
)
{
#pragma unroll
for
(
int
i
=
WARP_SIZE
/
2
;
i
;
i
/=
2
)
{
val
+=
__shfl_xor_sync
(
FULL_MASK
,
val
,
i
);
}
for
(
int
i
=
WARP_SIZE
/
2
;
i
;
i
/=
2
)
{
val
+=
__shfl_xor_sync
(
FULL_MASK
,
val
,
i
);
}
return
val
;
}
// easier to understand version of manual shfl_xor_sync, performance appears similar
static
__device__
float
__warp_sum_cub
(
float
val
)
{
static
__device__
float
__warp_sum_cub
(
float
val
)
{
// use cub to reduce within a warp
__shared__
typename
cub
::
WarpReduce
<
float
>::
TempStorage
temp_storage
;
...
...
@@ -108,14 +112,9 @@ static __device__ float __warp_sum_cub(float val) {
// shared memory as a cache and one warp per output point, warp-parallel over
// channels, which should be laid out in the fastest dimension for coalesced
// memory access.
template
<
int
BDIM_X
>
__global__
__launch_bounds__
(
BDIM_X
)
void
s2_attention_bwd_dkvq_kernel
(
int
num_channels
,
int
nlon_in
,
int
nlat_out
,
int
nlon_out
,
template
<
int
BDIM_X
>
__global__
__launch_bounds__
(
BDIM_X
)
void
s2_attention_bwd_dkvq_kernel
(
int
num_channels
,
int
nlon_in
,
int
nlat_out
,
int
nlon_out
,
const
torch
::
PackedTensorAccessor32
<
float
,
4
,
torch
::
RestrictPtrTraits
>
kx
,
const
torch
::
PackedTensorAccessor32
<
float
,
4
,
torch
::
RestrictPtrTraits
>
vx
,
const
torch
::
PackedTensorAccessor32
<
float
,
4
,
torch
::
RestrictPtrTraits
>
qy
,
...
...
@@ -125,14 +124,15 @@ __launch_bounds__(BDIM_X)
torch
::
PackedTensorAccessor32
<
float
,
4
,
torch
::
RestrictPtrTraits
>
dydq
,
const
torch
::
PackedTensorAccessor64
<
int64_t
,
1
,
torch
::
RestrictPtrTraits
>
psi_col_idx
,
const
torch
::
PackedTensorAccessor64
<
int64_t
,
1
,
torch
::
RestrictPtrTraits
>
psi_row_offset
,
const
torch
::
PackedTensorAccessor32
<
float
,
1
,
torch
::
RestrictPtrTraits
>
quad_weights
)
{
const
torch
::
PackedTensorAccessor32
<
float
,
1
,
torch
::
RestrictPtrTraits
>
quad_weights
)
{
extern
__shared__
float
sh
[];
float
*
sh_alpha_k
=
sh
+
threadIdx
.
y
*
num_channels
*
5
;
float
*
sh_alpha_vw
=
sh_alpha_k
+
num_channels
;
float
*
sh_alpha_kvw
=
sh_alpha_vw
+
num_channels
;
float
*
sh_alpha_k
=
sh
+
threadIdx
.
y
*
num_channels
*
5
;
float
*
sh_alpha_vw
=
sh_alpha_k
+
num_channels
;
float
*
sh_alpha_kvw
=
sh_alpha_vw
+
num_channels
;
float
*
sh_dy
=
sh_alpha_kvw
+
num_channels
;
float
*
sh_qy
=
sh_dy
+
num_channels
;
float
*
sh_qy
=
sh_dy
+
num_channels
;
// (optionally, could use more shared memory for other intermediates)
const
uint64_t
batchId
=
blockIdx
.
y
;
...
...
@@ -156,7 +156,7 @@ __launch_bounds__(BDIM_X)
__syncthreads
();
const
int64_t
rbeg
=
psi_row_offset
[
ho
];
const
int64_t
rend
=
psi_row_offset
[
ho
+
1
];
const
int64_t
rend
=
psi_row_offset
[
ho
+
1
];
const
int
rlen
=
rend
-
rbeg
;
// First pass: find qdotk_max
...
...
@@ -201,7 +201,8 @@ __launch_bounds__(BDIM_X)
// Write dydq
for
(
int
chan
=
tidx
;
chan
<
num_channels
;
chan
+=
WARP_SIZE
)
{
dydq
[
batchId
][
chan
][
ho
][
wo
]
=
(
sh_alpha_kvw
[
chan
]
*
alpha_sum
-
sh_alpha_vw
[
chan
]
*
sh_alpha_k
[
chan
])
/
(
alpha_sum
*
alpha_sum
);
dydq
[
batchId
][
chan
][
ho
][
wo
]
=
(
sh_alpha_kvw
[
chan
]
*
alpha_sum
-
sh_alpha_vw
[
chan
]
*
sh_alpha_k
[
chan
])
/
(
alpha_sum
*
alpha_sum
);
}
// Third pass: accumulate gradients for k and v
...
...
@@ -227,16 +228,11 @@ __launch_bounds__(BDIM_X)
}
}
std
::
tuple
<
at
::
Tensor
,
at
::
Tensor
,
at
::
Tensor
>
s2_attention_bwd_dkvq_cuda
(
at
::
Tensor
kx
,
at
::
Tensor
vx
,
at
::
Tensor
qy
,
at
::
Tensor
dy
,
at
::
Tensor
quad_weights
,
at
::
Tensor
psi_col_idx
,
at
::
Tensor
psi_row_off
,
int
nlon_in
,
int
nlat_out
,
int
nlon_out
)
{
std
::
tuple
<
at
::
Tensor
,
at
::
Tensor
,
at
::
Tensor
>
s2_attention_bwd_dkvq_cuda
(
at
::
Tensor
kx
,
at
::
Tensor
vx
,
at
::
Tensor
qy
,
at
::
Tensor
dy
,
at
::
Tensor
quad_weights
,
at
::
Tensor
psi_col_idx
,
at
::
Tensor
psi_row_off
,
int
nlon_in
,
int
nlat_out
,
int
nlon_out
)
{
CHECK_CUDA_TENSOR
(
kx
);
CHECK_CUDA_TENSOR
(
vx
);
...
...
@@ -259,16 +255,16 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
auto
dy_type
=
dy
.
dtype
();
// extract memory format
auto
kx_is_channels_last
=
kx
.
is_contiguous
(
at
::
MemoryFormat
::
Channels
_l
ast
);
auto
vx_is_channels_last
=
vx
.
is_contiguous
(
at
::
MemoryFormat
::
Channels
_l
ast
);
auto
qy_is_channels_last
=
qy
.
is_contiguous
(
at
::
MemoryFormat
::
Channels
_l
ast
);
auto
dy_is_channels_last
=
dy
.
is_contiguous
(
at
::
MemoryFormat
::
Channels
_l
ast
);
auto
kx_is_channels_last
=
kx
.
is_contiguous
(
at
::
MemoryFormat
::
Channels
L
ast
);
auto
vx_is_channels_last
=
vx
.
is_contiguous
(
at
::
MemoryFormat
::
Channels
L
ast
);
auto
qy_is_channels_last
=
qy
.
is_contiguous
(
at
::
MemoryFormat
::
Channels
L
ast
);
auto
dy_is_channels_last
=
dy
.
is_contiguous
(
at
::
MemoryFormat
::
Channels
L
ast
);
// convert to channels-last
auto
kxP
=
kx
.
to
(
torch
::
kFloat32
,
at
::
MemoryFormat
::
ChannelsLast
);
auto
vxP
=
vx
.
to
(
torch
::
kFloat32
,
at
::
MemoryFormat
::
ChannelsLast
);
auto
qyP
=
qy
.
to
(
torch
::
kFloat32
,
at
::
MemoryFormat
::
ChannelsLast
);
auto
dyP
=
dy
.
to
(
torch
::
kFloat32
,
at
::
MemoryFormat
::
ChannelsLast
);
auto
kxP
=
kx
.
to
(
torch
::
kFloat32
).
to
(
at
::
MemoryFormat
::
ChannelsLast
);
auto
vxP
=
vx
.
to
(
torch
::
kFloat32
).
to
(
at
::
MemoryFormat
::
ChannelsLast
);
auto
qyP
=
qy
.
to
(
torch
::
kFloat32
).
to
(
at
::
MemoryFormat
::
ChannelsLast
);
auto
dyP
=
dy
.
to
(
torch
::
kFloat32
).
to
(
at
::
MemoryFormat
::
ChannelsLast
);
// cudaDeviceSynchronize();
// delete permute_timer;
...
...
@@ -284,8 +280,8 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
size_t
uo_num_channels
=
kx
.
size
(
1
);
const
int
batch_size
=
kx
.
size
(
0
);
dim3
block
(
WARP_SIZE
,
THREADS
/
WARP_SIZE
);
dim3
grid
(
DIV_UP
(
nlat_out
*
nlon_out
,
block
.
y
),
batch_size
);
dim3
block
(
WARP_SIZE
,
THREADS
/
WARP_SIZE
);
dim3
grid
(
DIV_UP
(
nlat_out
*
nlon_out
,
block
.
y
),
batch_size
);
size_t
shared_size
=
sizeof
(
float
)
*
uo_num_channels
*
5
*
block
.
y
;
// 5 arrays per warp (alpha_k, alpha_vw, alpha_kvw, dy, qy)
cudaEvent_t
start
,
stop
;
...
...
@@ -294,10 +290,8 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
CHECK_CUDA
(
cudaEventCreate
(
&
stop
));
CHECK_CUDA
(
cudaEventRecord
(
start
,
stream
));
s2_attention_bwd_dkvq_kernel
<
THREADS
><<<
grid
,
block
,
shared_size
,
stream
>>>
(
uo_num_channels
,
nlon_in
,
nlat_out
,
nlon_out
,
kxP
.
packed_accessor32
<
float
,
4
,
torch
::
RestrictPtrTraits
>
(),
s2_attention_bwd_dkvq_kernel
<
THREADS
><<<
grid
,
block
,
shared_size
,
stream
>>>
(
uo_num_channels
,
nlon_in
,
nlat_out
,
nlon_out
,
kxP
.
packed_accessor32
<
float
,
4
,
torch
::
RestrictPtrTraits
>
(),
vxP
.
packed_accessor32
<
float
,
4
,
torch
::
RestrictPtrTraits
>
(),
qyP
.
packed_accessor32
<
float
,
4
,
torch
::
RestrictPtrTraits
>
(),
dyP
.
packed_accessor32
<
float
,
4
,
torch
::
RestrictPtrTraits
>
(),
...
...
@@ -330,20 +324,20 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
dydq
=
dydq
.
to
(
qy_type
);
// permute back to original layout
if
(
!
kx_is_channels_last
){
dydk
=
dydk
.
to
(
kx_type
,
at
::
MemoryFormat
::
Contiguous
);
if
(
!
kx_is_channels_last
)
{
dydk
=
dydk
.
to
(
kx_type
).
to
(
at
::
MemoryFormat
::
Contiguous
);
}
else
{
dydk
=
dydk
.
to
(
kx_type
);
}
if
(
!
vx_is_channels_last
){
dydv
=
dydv
.
to
(
vx_type
,
at
::
MemoryFormat
::
Contiguous
);
if
(
!
vx_is_channels_last
)
{
dydv
=
dydv
.
to
(
vx_type
).
to
(
at
::
MemoryFormat
::
Contiguous
);
}
else
{
dydv
=
dydv
.
to
(
vx_type
);
}
if
(
!
qy_is_channels_last
)
{
dydq
=
dydq
.
to
(
qy_type
,
at
::
MemoryFormat
::
Contiguous
);
if
(
!
qy_is_channels_last
)
{
dydq
=
dydq
.
to
(
qy_type
).
to
(
at
::
MemoryFormat
::
Contiguous
);
}
else
{
dydq
=
dydq
.
to
(
qy_type
)
dydq
=
dydq
.
to
(
qy_type
)
;
}
// printf("dydk strides: [");
...
...
@@ -355,6 +349,4 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
// delete permute_output_timer;
// nvtxRangePop();
return
std
::
make_tuple
(
dydk
,
dydv
,
dydq
);
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment