OpenDAS / torch-harmonics

Commit 3d06f4da
authored Jun 11, 2025 by Max Rietmann
Removed unnecessary code in fwd and bwd kernels.
Also: Made fwd kernel use modified memory layout with standard shape
parent 6512d042

Showing 2 changed files with 68 additions and 176 deletions (+68 −176)
torch_harmonics/csrc/attention/attention_bwd_cuda.cu   +10 −26
torch_harmonics/csrc/attention/attention_fwd_cuda.cu   +58 −150
torch_harmonics/csrc/attention/attention_bwd_cuda.cu
@@ -83,18 +83,6 @@ private:
     std::chrono::high_resolution_clock::time_point start_;
 };
 
-__device__ static float atomicMax(float* address, float val)
-{
-    int* address_as_i = (int*)address;
-    int old = *address_as_i, assumed;
-    do {
-        assumed = old;
-        old = ::atomicCAS(address_as_i, assumed,
-                          __float_as_int(::fmaxf(val, __int_as_float(assumed))));
-    } while (assumed != old);
-    return __int_as_float(old);
-}
-
 static __device__ float __warp_sum(float val) {
 #pragma unroll
     for (int i = WARP_SIZE / 2; i; i /= 2) {
@@ -105,7 +93,7 @@ static __device__ float __warp_sum(float val) {
 }
 
 // easier to understand version of manual shfl_xor_sync, performance appears similar
-__device__ float __warp_sum_cub(float val) {
+static __device__ float __warp_sum_cub(float val) {
     // use cub to reduce within a warp
     __shared__ typename cub::WarpReduce<float>::TempStorage temp_storage;
@@ -303,9 +291,9 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
     nvtxRangePop();
 
     nvtxRangePush("s2_attention_bwd_dkvq_kernel_mbT output allocation & zero");
-    auto dydkP = torch::zeros_like(qyP);
-    auto dydvP = torch::zeros_like(qyP);
-    auto dydqP = torch::zeros_like(qyP);
+    auto dydk = torch::zeros_like(qyP);
+    auto dydv = torch::zeros_like(qyP);
+    auto dydq = torch::zeros_like(qyP);
     // print strdie of dydkP, dydvP, dydqP
     nvtxRangePop();
@@ -329,9 +317,9 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
         vxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
         qyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
         dyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
-        dydkP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
-        dydvP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
-        dydqP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
+        dydk.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
+        dydv.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
+        dydq.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
         psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
         psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
         quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>());
@@ -351,13 +339,9 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
     // Permute outputs back to memory layout given by input. if input had channels
     // first, leave it in that layout, otherwise permute layout back to [batch,
     // channel, ho, wo]
-    at::Tensor dydk, dydv, dydq;
-    if(!k_channel_first) dydk = dydkP.contiguous();
-    else dydk = dydkP;
-    if(!v_channel_first) dydv = dydvP.contiguous();
-    else dydv = dydvP;
-    if(!q_channel_first) dydq = dydqP.contiguous();
-    else dydq = dydqP;
+    if(!k_channel_first) dydk = dydk.contiguous();
+    if(!v_channel_first) dydv = dydv.contiguous();
+    if(!q_channel_first) dydq = dydq.contiguous();
 
     // printf("dydk strides:[");
     // for(auto& stride : dydk.strides()) {
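A note on why the dropped else-branches were redundant: torch::zeros_like with its default preserve_format keeps the strides of a dense input, so gradients allocated from the channel-contiguous qyP already carry that layout and only need a .contiguous() call when the caller passed plain NCHW tensors. A hypothetical, self-contained check (not from this commit; sizes are made up):

    // Sketch only: zeros_like preserves the strides of a dense (permuted) input.
    #include <torch/torch.h>

    int main() {
        auto qy  = torch::randn({2, 8, 4, 6});                                   // caller's NCHW tensor
        auto qyP = qy.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});  // channel stride 1
        bool q_channel_first = qy.strides()[1] == 1;                             // false for plain NCHW

        auto dydq = torch::zeros_like(qyP);              // same strides as qyP (preserve_format)
        if (!q_channel_first) dydq = dydq.contiguous();  // hand back plain NCHW memory

        TORCH_CHECK(dydq.strides() == qy.strides());     // output layout matches the input layout
        return 0;
    }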
torch_harmonics/csrc/attention/attention_fwd_cuda.cu
@@ -65,137 +65,6 @@ using BlockReduceFloat512 = cub::BlockReduce<float, 512>;
         exit(EXIT_FAILURE); \
     }}
 
-__device__ static float atomicMax(float* address, float val)
-{
-    int* address_as_i = (int*)address;
-    int old = *address_as_i, assumed;
-    do {
-        assumed = old;
-        old = ::atomicCAS(address_as_i, assumed,
-                          __float_as_int(::fmaxf(val, __int_as_float(assumed))));
-    } while (assumed != old);
-    return __int_as_float(old);
-}
-
-__global__ void s2_attention_kernel(int num_channels, int nlon_in, int nlat_out, int nlon_out,
-                                    const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx,
-                                    const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx,
-                                    const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> qy,
-                                    torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> y,
-                                    const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
-                                    const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset,
-                                    const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights)
-{
-    // shared memory
-    extern __shared__ float sharedMem[];
-    float* sh_alpha_sum = (float*)&sharedMem;
-    float* sh_qdotk_max = (float*)&sharedMem[1];
-    float* sh_qy_ho_wo  = (float*)&sharedMem[2];
-
-    if (threadIdx.x == 0) {
-        sh_qdotk_max[0] = std::numeric_limits<float>::lowest();
-        sh_alpha_sum[0] = 0.0;
-    }
-    __syncthreads();
-
-    int ho = blockIdx.x;
-    int wo = blockIdx.y;
-    int batch_b = blockIdx.z;
-
-    // load qy channels into shared memory
-    for (int channel_block_i = 0; channel_block_i < (num_channels / blockDim.x) + 1; channel_block_i++) {
-        int channel_idx = channel_block_i * blockDim.x + threadIdx.x;
-        if (channel_idx >= num_channels) break;
-        sh_qy_ho_wo[channel_idx] = qy[batch_b][channel_idx][ho][wo];
-        y[batch_b][channel_idx][ho][wo] = 0.0;
-    }
-    __syncthreads();
-
-    int psi_offset = psi_row_offset[ho];
-    int psi_nnz_ho = psi_row_offset[ho + 1] - psi_offset;
-
-    float qdotk_max = std::numeric_limits<float>::lowest();
-    for (int psi_block = 0; psi_block < (psi_nnz_ho / blockDim.x) + 1; psi_block++) {
-        int idz = psi_block * blockDim.x + threadIdx.x;
-        // skip if index >= length of psi_idx because last loop iteration will have extra threads
-        if (idz >= psi_nnz_ho) break;
-
-        int nz_col_idx = psi_col_idx[psi_offset + idz];
-
-        // compute input indices from psi datastructure
-        int hi = nz_col_idx / nlon_in;
-        // account for output shift and ensure positive index due to circular condition
-        // int wi = (nz_col_idx % nlon_in - wo) % nlon_in;
-        int wi = nz_col_idx % nlon_in;
-        int wip = (wi + wo) % nlon_in;
-
-        // correlation Q&K (dot-product Q.K)
-        float qdotk = 0.0;
-        for (int channel_idx = 0; channel_idx < num_channels; channel_idx++) {
-            qdotk += sh_qy_ho_wo[channel_idx] * kx[batch_b][channel_idx][hi][wip];
-        }
-        qdotk_max = std::max(qdotk_max, qdotk);
-    }
-    // collect thread-local qdotk max
-    atomicMax(&sh_qdotk_max[0], qdotk_max);
-    __syncthreads();
-    // "broadcast" qdotk_max back into all thread-local registers
-    qdotk_max = sh_qdotk_max[0];
-
-    float alpha_sum = 0.0f;
-    for (int psi_block = 0; psi_block < (psi_nnz_ho / blockDim.x) + 1; psi_block++) {
-        int idz = psi_block * blockDim.x + threadIdx.x;
-        float alpha_inz = 0.0;
-        // skip if index >= length of psi_idx because last loop iteration will have extra threads
-        if (idz < psi_nnz_ho) {
-            int nz_col_idx = psi_col_idx[psi_offset + idz];
-
-            // compute input indices from psi datastructure
-            int hi = nz_col_idx / nlon_in;
-            // account for output shift and ensure positive index due to circular condition
-            // int wi = (nz_col_idx % nlon_in - wo) % nlon_in;
-            int wi = nz_col_idx % nlon_in;
-            int wip = (wi + wo) % nlon_in;
-
-            // softmax numerator with minus qdotk_max to avoid numerical overflow.
-            // Because qdotk_max is in both numerator and denominator (due to
-            // alpha_sum), it doesn't effect the solution other than removing overflow
-
-            // correlation Q&K (dot-product Q.K)
-            float qdotk = 0.0;
-            for (int channel_idx = 0; channel_idx < num_channels; channel_idx++) {
-                qdotk += sh_qy_ho_wo[channel_idx] * kx[batch_b][channel_idx][hi][wip];
-            }
-
-            alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];
-            // thread-local sum alpha
-            alpha_sum += alpha_inz;
-
-            // multiply alpha, vx, and quadrature weights
-            for (int channel_idx = 0; channel_idx < num_channels; channel_idx++) {
-                atomicAdd(&y[batch_b][channel_idx][ho][wo], alpha_inz * vx[batch_b][channel_idx][hi][wip]);
-            }
-        }
-    }
-    // collect all alpha_sum across threads
-    atomicAdd(&sh_alpha_sum[0], alpha_sum);
-    __syncthreads();
-    // rebroadcast sum to all threads
-    alpha_sum = sh_alpha_sum[0];
-
-    // divide output by alpha_sum
-    for (int channel_block_i = 0; channel_block_i < (num_channels / blockDim.x) + 1; channel_block_i++) {
-        int channel_idx = channel_block_i * blockDim.x + threadIdx.x;
-        if (channel_idx >= num_channels) break;
-        y[batch_b][channel_idx][ho][wo] /= alpha_sum;
-    }
-}
-
 static __device__ float __warp_sum(float val) {
 #pragma unroll
     for (int i = WARP_SIZE / 2; i; i /= 2) {
@@ -204,11 +73,24 @@ static __device__ float __warp_sum(float val) {
     return val;
 }
 
+// easier to understand version of manual shfl_xor_sync, performance appears similar
+static __device__ float __warp_sum_cub(float val) {
+    // use cub to reduce within a warp
+    __shared__ typename cub::WarpReduce<float>::TempStorage temp_storage;
+
+    // 1. Compute sum (initially only in lane 0)
+    float sum = cub::WarpReduce<float>(temp_storage).Sum(val);
+
+    // 2. Broadcast sum to all threads
+    sum = __shfl_sync(0xFFFFFFFF, sum, 0);
+
+    return sum;
+}
+
 // one warp per (ho,wo)
 template<int BDIM_X>
 __global__
 __launch_bounds__(BDIM_X)
-void s2_attention_kernel_mbT(int num_channels,
+void s2_attention_kernel(int num_channels,
                              int nlon_in,
                              int nlat_out,
                              int nlon_out,
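The __warp_sum_cub helper added above uses a single shared TempStorage, which works when each block is a single 32-thread warp (consistent with the "// one warp per (ho,wo)" comment). A standalone usage sketch of the same pattern (illustrative only, not part of the commit; warp_sum_demo and its buffers are made up):

    // Sketch only: one warp per block sums 32 floats; cub leaves the result in lane 0
    // and __shfl_sync broadcasts it to every lane, the same two steps used above.
    #include <cub/cub.cuh>

    __global__ void warp_sum_demo(const float* in, float* out)
    {
        __shared__ typename cub::WarpReduce<float>::TempStorage temp_storage;

        float val = in[blockIdx.x * 32 + threadIdx.x];
        float sum = cub::WarpReduce<float>(temp_storage).Sum(val);  // valid in lane 0 only
        sum = __shfl_sync(0xFFFFFFFF, sum, 0);                      // broadcast to all lanes

        if (threadIdx.x == 0) out[blockIdx.x] = sum;
    }
    // launch: warp_sum_demo<<<num_blocks, 32>>>(d_in, d_out);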
@@ -263,10 +145,10 @@ __launch_bounds__(BDIM_X)
             float qdotk = 0.0f;
             for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
-                qdotk += qy[batchId][ho][wo][chan] * kx[batchId][hi][wip][chan];
+                qdotk += qy[batchId][chan][ho][wo] * kx[batchId][chan][hi][wip];
             }
-            qdotk = __warp_sum(qdotk);
+            qdotk = __warp_sum_cub(qdotk);
 
             float qdotk_max_tmp;
             float alpha;
@@ -279,7 +161,7 @@ __launch_bounds__(BDIM_X)
             alpha_sum = alpha + alpha_sum * exp_save;
 
             for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
-                shy[chan] = shy[chan] * exp_save + vx[batchId][hi][wip][chan] * alpha;
+                shy[chan] = shy[chan] * exp_save + vx[batchId][chan][hi][wip] * alpha;
             }
             qdotk_max = qdotk_max_tmp;
         }
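The exp_save rescaling in this hunk is the standard online (streaming) softmax update: whenever a new maximum appears, previously accumulated numerator and denominator terms are scaled by exp(old_max − new_max) so nothing overflows. A host-side sketch of the same recurrence (illustrative only; the function and variable names are made up, and weight stands in for quad_weights[hi]):

    // Sketch only: running-max softmax accumulation, mirroring alpha / exp_save / alpha_sum / shy above.
    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <limits>
    #include <vector>

    void online_softmax(const std::vector<float>& qdotk,
                        const std::vector<float>& weight,          // plays the role of quad_weights[hi]
                        const std::vector<std::vector<float>>& v,  // one value vector per nonzero
                        std::vector<float>& y)                     // accumulated output (num_channels)
    {
        float qdotk_max = std::numeric_limits<float>::lowest();
        float alpha_sum = 0.0f;
        std::fill(y.begin(), y.end(), 0.0f);

        for (std::size_t i = 0; i < qdotk.size(); ++i) {
            float new_max  = std::max(qdotk_max, qdotk[i]);
            float exp_save = std::exp(qdotk_max - new_max);        // rescales everything accumulated so far
            float alpha    = std::exp(qdotk[i] - new_max) * weight[i];

            alpha_sum = alpha + alpha_sum * exp_save;
            for (std::size_t c = 0; c < y.size(); ++c)
                y[c] = y[c] * exp_save + v[i][c] * alpha;          // same form as the shy[chan] update

            qdotk_max = new_max;
        }
        for (std::size_t c = 0; c < y.size(); ++c)
            y[c] /= alpha_sum;                                     // final normalization
    }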
@@ -317,12 +199,35 @@ torch::Tensor s2_attention_fwd_cuda(at::Tensor kx,
     const int batch_size = kx.size(0);
 
+    auto k_channel_first = kx.strides()[1] == 1;
+    auto v_channel_first = vx.strides()[1] == 1;
+    auto q_channel_first = qy.strides()[1] == 1;
+
     // transpose inputs so that channels are in the last dimension, allowing for
     // coalesced memory access
     nvtxRangePush("s2_attention_fwd_kernel_mbT permute inputs");
-    torch::Tensor kxP = kx.permute({0, 2, 3, 1}).contiguous();
-    torch::Tensor vxP = vx.permute({0, 2, 3, 1}).contiguous();
-    torch::Tensor qyP = qy.permute({0, 2, 3, 1}).contiguous();
+    //Permute kx,vx,qy,dy to [batch, ho, wo, channel] in memory layout, but keep the original shape [batch, channel, ho, wo]
+    auto kxP = at::Tensor();
+    if(!k_channel_first) {
+        // printf("Permuting kx from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
+        kxP = kx.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
+    } else {
+        kxP = kx;
+    }
+    auto vxP = at::Tensor();
+    if(!v_channel_first) {
+        // printf("Permuting vx from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
+        vxP = vx.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
+    } else {
+        vxP = vx;
+    }
+    auto qyP = at::Tensor();
+    if(!q_channel_first) {
+        // printf("Permuting qy from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
+        qyP = qy.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
+    } else {
+        qyP = qy;
+    }
     cudaDeviceSynchronize();
     nvtxRangePop();
 
     torch::Tensor y = torch::empty_like(qy);
@@ -338,7 +243,7 @@ torch::Tensor s2_attention_fwd_cuda(at::Tensor kx,
     CHECK_CUDA(cudaEventCreate(&stop));
 
     CHECK_CUDA(cudaEventRecord(start, stream));
-    s2_attention_kernel_mbT<THREADS>
+    s2_attention_kernel<THREADS>
         <<<grid, block, shared_size, stream>>>(uo_num_channels, nlon_in, nlat_out, nlon_out,
             kxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
             vxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
@@ -355,6 +260,9 @@ torch::Tensor s2_attention_fwd_cuda(at::Tensor kx,
     CHECK_CUDA(cudaEventDestroy(start));
     CHECK_CUDA(cudaEventDestroy(stop));
 
+    // match output layout to input layout
+    if(!q_channel_first) y = y.contiguous();
+
     C10_CUDA_KERNEL_LAUNCH_CHECK();
 
     return y;