Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-harmonics
Commits
68e7d0fa
Commit
68e7d0fa
authored
Jun 16, 2025
by
Max Rietmann
Browse files
Clang format
parent
cb79c766
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
69 additions
and
78 deletions
+69
-78
torch_harmonics/csrc/attention/attention_bwd_cuda.cu
torch_harmonics/csrc/attention/attention_bwd_cuda.cu
+69
-78
No files found.
torch_harmonics/csrc/attention/attention_bwd_cuda.cu
View file @
68e7d0fa
...
...
@@ -2,7 +2,7 @@
//
// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
...
...
@@ -51,28 +51,32 @@
#define THREADS (64)
#endif
#ifndef DIV_UP
#define DIV_UP(a,b) (((a)
+
((b)
-1))/
(b))
#define DIV_UP(a,
b) (((a)
+
((b)
- 1)) /
(b))
#endif
#ifndef CHECK_CUDA
#define CHECK_CUDA(call) { \
cudaError_t err = call; \
if( cudaSuccess != err) { \
fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\\n", \
__FILE__, __LINE__, cudaGetErrorString( err) ); \
exit(EXIT_FAILURE); \
}}
#define CHECK_CUDA(call) \
{ \
cudaError_t err = call; \
if (cudaSuccess != err) { \
fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\\n", __FILE__, __LINE__, cudaGetErrorString(err)); \
exit(EXIT_FAILURE); \
} \
}
#endif
#include <iostream>
#include <chrono>
#include <string>
class
ScopeTimer
{
class
ScopeTimer
{
public:
explicit
ScopeTimer
(
const
std
::
string
&
label
=
""
)
:
label_
(
label
),
start_
(
std
::
chrono
::
high_resolution_clock
::
now
())
{}
explicit
ScopeTimer
(
const
std
::
string
&
label
=
""
)
:
label_
(
label
),
start_
(
std
::
chrono
::
high_resolution_clock
::
now
())
{
}
~
ScopeTimer
()
{
~
ScopeTimer
()
{
auto
end
=
std
::
chrono
::
high_resolution_clock
::
now
();
auto
elapsed
=
std
::
chrono
::
duration_cast
<
std
::
chrono
::
milliseconds
>
(
end
-
start_
);
std
::
cout
<<
label_
<<
"Elapsed time: "
<<
elapsed
.
count
()
<<
" ms"
<<
std
::
endl
;
...
...
@@ -83,20 +87,19 @@ private:
std
::
chrono
::
high_resolution_clock
::
time_point
start_
;
};
static
__device__
float
__warp_sum
(
float
val
)
{
static
__device__
float
__warp_sum
(
float
val
)
{
#pragma unroll
for
(
int
i
=
WARP_SIZE
/
2
;
i
;
i
/=
2
)
{
val
+=
__shfl_xor_sync
(
FULL_MASK
,
val
,
i
);
}
for
(
int
i
=
WARP_SIZE
/
2
;
i
;
i
/=
2
)
{
val
+=
__shfl_xor_sync
(
FULL_MASK
,
val
,
i
);
}
return
val
;
}
// easier to understand version of manual shfl_xor_sync, performance appears similar
static
__device__
float
__warp_sum_cub
(
float
val
)
{
static
__device__
float
__warp_sum_cub
(
float
val
)
{
// use cub to reduce within a warp
__shared__
typename
cub
::
WarpReduce
<
float
>::
TempStorage
temp_storage
;
// 1. Compute sum (initially only in lane 0)
float
sum
=
cub
::
WarpReduce
<
float
>
(
temp_storage
).
Sum
(
val
);
// 2. Broadcast sum to all threads
...
...
@@ -108,31 +111,27 @@ static __device__ float __warp_sum_cub(float val) {
// shared memory as a cache and one warp per output point, warp-parallel over
// channels, which should be layed out in the fastest dimension for coalesced
// memory access.
template
<
int
BDIM_X
>
__global__
__launch_bounds__
(
BDIM_X
)
void
s2_attention_bwd_dkvq_kernel
(
int
num_channels
,
int
nlon_in
,
int
nlat_out
,
int
nlon_out
,
const
torch
::
PackedTensorAccessor32
<
float
,
4
,
torch
::
RestrictPtrTraits
>
kx
,
const
torch
::
PackedTensorAccessor32
<
float
,
4
,
torch
::
RestrictPtrTraits
>
vx
,
const
torch
::
PackedTensorAccessor32
<
float
,
4
,
torch
::
RestrictPtrTraits
>
qy
,
const
torch
::
PackedTensorAccessor32
<
float
,
4
,
torch
::
RestrictPtrTraits
>
dy
,
torch
::
PackedTensorAccessor32
<
float
,
4
,
torch
::
RestrictPtrTraits
>
dydk
,
torch
::
PackedTensorAccessor32
<
float
,
4
,
torch
::
RestrictPtrTraits
>
dydv
,
torch
::
PackedTensorAccessor32
<
float
,
4
,
torch
::
RestrictPtrTraits
>
dydq
,
const
torch
::
PackedTensorAccessor64
<
int64_t
,
1
,
torch
::
RestrictPtrTraits
>
psi_col_idx
,
const
torch
::
PackedTensorAccessor64
<
int64_t
,
1
,
torch
::
RestrictPtrTraits
>
psi_row_offset
,
const
torch
::
PackedTensorAccessor32
<
float
,
1
,
torch
::
RestrictPtrTraits
>
quad_weights
)
{
template
<
int
BDIM_X
>
__global__
__launch_bounds__
(
BDIM_X
)
void
s2_attention_bwd_dkvq_kernel
(
int
num_channels
,
int
nlon_in
,
int
nlat_out
,
int
nlon_out
,
const
torch
::
PackedTensorAccessor32
<
float
,
4
,
torch
::
RestrictPtrTraits
>
kx
,
const
torch
::
PackedTensorAccessor32
<
float
,
4
,
torch
::
RestrictPtrTraits
>
vx
,
const
torch
::
PackedTensorAccessor32
<
float
,
4
,
torch
::
RestrictPtrTraits
>
qy
,
const
torch
::
PackedTensorAccessor32
<
float
,
4
,
torch
::
RestrictPtrTraits
>
dy
,
torch
::
PackedTensorAccessor32
<
float
,
4
,
torch
::
RestrictPtrTraits
>
dydk
,
torch
::
PackedTensorAccessor32
<
float
,
4
,
torch
::
RestrictPtrTraits
>
dydv
,
torch
::
PackedTensorAccessor32
<
float
,
4
,
torch
::
RestrictPtrTraits
>
dydq
,
const
torch
::
PackedTensorAccessor64
<
int64_t
,
1
,
torch
::
RestrictPtrTraits
>
psi_col_idx
,
const
torch
::
PackedTensorAccessor64
<
int64_t
,
1
,
torch
::
RestrictPtrTraits
>
psi_row_offset
,
const
torch
::
PackedTensorAccessor32
<
float
,
1
,
torch
::
RestrictPtrTraits
>
quad_weights
)
{
extern
__shared__
float
sh
[];
float
*
sh_alpha_k
=
sh
+
threadIdx
.
y
*
num_channels
*
5
;
float
*
sh_alpha_vw
=
sh_alpha_k
+
num_channels
;
float
*
sh_alpha_kvw
=
sh_alpha_vw
+
num_channels
;
float
*
sh_alpha_k
=
sh
+
threadIdx
.
y
*
num_channels
*
5
;
float
*
sh_alpha_vw
=
sh_alpha_k
+
num_channels
;
float
*
sh_alpha_kvw
=
sh_alpha_vw
+
num_channels
;
float
*
sh_dy
=
sh_alpha_kvw
+
num_channels
;
float
*
sh_qy
=
sh_dy
+
num_channels
;
float
*
sh_qy
=
sh_dy
+
num_channels
;
// (optionally, could use more shared memory for other intermediates)
const
uint64_t
batchId
=
blockIdx
.
y
;
...
...
@@ -156,7 +155,7 @@ __launch_bounds__(BDIM_X)
__syncthreads
();
const
int64_t
rbeg
=
psi_row_offset
[
ho
];
const
int64_t
rend
=
psi_row_offset
[
ho
+
1
];
const
int64_t
rend
=
psi_row_offset
[
ho
+
1
];
const
int
rlen
=
rend
-
rbeg
;
// 1st pass: accumulate alpha_sum, integral, and shared stats, along with a progressively computed qdotk_max.
...
...
@@ -176,7 +175,7 @@ __launch_bounds__(BDIM_X)
float
alpha_inz
=
expf
(
qdotk
-
qdotk_max_tmp
)
*
quad_weights
[
hi
];
float
max_correction
=
expf
(
qdotk_max
-
qdotk_max_tmp
);
alpha_sum
=
alpha_sum
*
max_correction
+
alpha_inz
;
integral
=
integral
*
max_correction
+
alpha_inz
*
gdotv
;
integral
=
integral
*
max_correction
+
alpha_inz
*
gdotv
;
for
(
int
chan
=
tidx
;
chan
<
num_channels
;
chan
+=
WARP_SIZE
)
{
float
kxval
=
kx
[
batchId
][
chan
][
hi
][
wip
];
sh_alpha_k
[
chan
]
=
sh_alpha_k
[
chan
]
*
max_correction
+
alpha_inz
*
kxval
;
...
...
@@ -190,7 +189,8 @@ __launch_bounds__(BDIM_X)
// Write dydq
for
(
int
chan
=
tidx
;
chan
<
num_channels
;
chan
+=
WARP_SIZE
)
{
dydq
[
batchId
][
chan
][
ho
][
wo
]
=
(
sh_alpha_kvw
[
chan
]
*
alpha_sum
-
sh_alpha_vw
[
chan
]
*
sh_alpha_k
[
chan
])
/
(
alpha_sum
*
alpha_sum
);
dydq
[
batchId
][
chan
][
ho
][
wo
]
=
(
sh_alpha_kvw
[
chan
]
*
alpha_sum
-
sh_alpha_vw
[
chan
]
*
sh_alpha_k
[
chan
])
/
(
alpha_sum
*
alpha_sum
);
}
// Third pass: accumulate gradients for k and v
...
...
@@ -216,16 +216,11 @@ __launch_bounds__(BDIM_X)
}
}
std
::
tuple
<
at
::
Tensor
,
at
::
Tensor
,
at
::
Tensor
>
s2_attention_bwd_dkvq_cuda
(
at
::
Tensor
kx
,
at
::
Tensor
vx
,
at
::
Tensor
qy
,
at
::
Tensor
dy
,
at
::
Tensor
quad_weights
,
at
::
Tensor
psi_col_idx
,
at
::
Tensor
psi_row_off
,
int
nlon_in
,
int
nlat_out
,
int
nlon_out
)
{
std
::
tuple
<
at
::
Tensor
,
at
::
Tensor
,
at
::
Tensor
>
s2_attention_bwd_dkvq_cuda
(
at
::
Tensor
kx
,
at
::
Tensor
vx
,
at
::
Tensor
qy
,
at
::
Tensor
dy
,
at
::
Tensor
quad_weights
,
at
::
Tensor
psi_col_idx
,
at
::
Tensor
psi_row_off
,
int
nlon_in
,
int
nlat_out
,
int
nlon_out
)
{
CHECK_CUDA_TENSOR
(
kx
);
CHECK_CUDA_TENSOR
(
vx
);
...
...
@@ -246,7 +241,7 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
nvtxRangePush
(
"s2_attention_bwd_dkvq_kernel_mbT permute inputs"
);
// auto* permute_timer = new ScopeTimer("permute inputs");
//Permute kx,vx,qy,dy to [batch, ho, wo, channel] in memory layout, but keep the original shape [batch, channel, ho, wo]
//
Permute kx,vx,qy,dy to [batch, ho, wo, channel] in memory layout, but keep the original shape [batch, channel, ho, wo]
auto
kxP
=
at
::
Tensor
();
if
(
!
k_channel_first
)
{
// printf("Permuting kx from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
...
...
@@ -289,8 +284,8 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
size_t
uo_num_channels
=
kx
.
size
(
1
);
const
int
batch_size
=
kx
.
size
(
0
);
dim3
block
(
WARP_SIZE
,
THREADS
/
WARP_SIZE
);
dim3
grid
(
DIV_UP
(
nlat_out
*
nlon_out
,
block
.
y
),
batch_size
);
dim3
block
(
WARP_SIZE
,
THREADS
/
WARP_SIZE
);
dim3
grid
(
DIV_UP
(
nlat_out
*
nlon_out
,
block
.
y
),
batch_size
);
size_t
shared_size
=
sizeof
(
float
)
*
uo_num_channels
*
5
*
block
.
y
;
// 4 arrays per warp
cudaEvent_t
start
,
stop
;
...
...
@@ -299,20 +294,18 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
CHECK_CUDA
(
cudaEventCreate
(
&
stop
));
CHECK_CUDA
(
cudaEventRecord
(
start
,
stream
));
s2_attention_bwd_dkvq_kernel
<
THREADS
><<<
grid
,
block
,
shared_size
,
stream
>>>
(
uo_num_channels
,
nlon_in
,
nlat_out
,
nlon_out
,
kxP
.
packed_accessor32
<
float
,
4
,
torch
::
RestrictPtrTraits
>
(),
vxP
.
packed_accessor32
<
float
,
4
,
torch
::
RestrictPtrTraits
>
(),
qyP
.
packed_accessor32
<
float
,
4
,
torch
::
RestrictPtrTraits
>
(),
dyP
.
packed_accessor32
<
float
,
4
,
torch
::
RestrictPtrTraits
>
(),
dydk
.
packed_accessor32
<
float
,
4
,
torch
::
RestrictPtrTraits
>
(),
dydv
.
packed_accessor32
<
float
,
4
,
torch
::
RestrictPtrTraits
>
(),
dydq
.
packed_accessor32
<
float
,
4
,
torch
::
RestrictPtrTraits
>
(),
psi_col_idx
.
packed_accessor64
<
int64_t
,
1
,
torch
::
RestrictPtrTraits
>
(),
psi_row_off
.
packed_accessor64
<
int64_t
,
1
,
torch
::
RestrictPtrTraits
>
(),
quad_weights
.
packed_accessor32
<
float
,
1
,
torch
::
RestrictPtrTraits
>
());
s2_attention_bwd_dkvq_kernel
<
THREADS
><<<
grid
,
block
,
shared_size
,
stream
>>>
(
uo_num_channels
,
nlon_in
,
nlat_out
,
nlon_out
,
kxP
.
packed_accessor32
<
float
,
4
,
torch
::
RestrictPtrTraits
>
(),
vxP
.
packed_accessor32
<
float
,
4
,
torch
::
RestrictPtrTraits
>
(),
qyP
.
packed_accessor32
<
float
,
4
,
torch
::
RestrictPtrTraits
>
(),
dyP
.
packed_accessor32
<
float
,
4
,
torch
::
RestrictPtrTraits
>
(),
dydk
.
packed_accessor32
<
float
,
4
,
torch
::
RestrictPtrTraits
>
(),
dydv
.
packed_accessor32
<
float
,
4
,
torch
::
RestrictPtrTraits
>
(),
dydq
.
packed_accessor32
<
float
,
4
,
torch
::
RestrictPtrTraits
>
(),
psi_col_idx
.
packed_accessor64
<
int64_t
,
1
,
torch
::
RestrictPtrTraits
>
(),
psi_row_off
.
packed_accessor64
<
int64_t
,
1
,
torch
::
RestrictPtrTraits
>
(),
quad_weights
.
packed_accessor32
<
float
,
1
,
torch
::
RestrictPtrTraits
>
());
CHECK_CUDA
(
cudaEventRecord
(
stop
,
stream
));
CHECK_CUDA
(
cudaEventSynchronize
(
stop
));
CHECK_CUDA
(
cudaEventElapsedTime
(
&
milliseconds
,
start
,
stop
));
...
...
@@ -328,15 +321,15 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
printf
(
"s2_attention_bwd_kernel execution time: %f ms
\n
"
,
milliseconds
);
CHECK_CUDA
(
cudaEventDestroy
(
start
));
CHECK_CUDA
(
cudaEventDestroy
(
stop
));
C10_CUDA_KERNEL_LAUNCH_CHECK
();
// Permute outputs back to memory layout given by input. if input had channels
// first, leave it in that layout, otherwise permute layout back to [batch,
// channel, ho, wo]
if
(
!
k_channel_first
)
dydk
=
dydk
.
contiguous
();
if
(
!
v_channel_first
)
dydv
=
dydv
.
contiguous
();
if
(
!
q_channel_first
)
dydq
=
dydq
.
contiguous
();
if
(
!
k_channel_first
)
dydk
=
dydk
.
contiguous
();
if
(
!
v_channel_first
)
dydv
=
dydv
.
contiguous
();
if
(
!
q_channel_first
)
dydq
=
dydq
.
contiguous
();
// printf("dydk strides:[");
// for(auto& stride : dydk.strides()) {
...
...
@@ -347,6 +340,4 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
// delete permute_output_timer;
// nvtxRangePop();
return
std
::
make_tuple
(
dydk
,
dydv
,
dydq
);
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment