OpenDAS / torch-harmonics / Commits / 6512d042

Commit 6512d042, authored Jun 11, 2025 by Max Rietmann

Removed all stale backwards kernel code

Also match the gradient output to the input, in terms of memory layout.

parent 4096e64b
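The layout note in the commit message is about how the incoming gradient for the backward pass is constructed. As a rough illustration (a standalone PyTorch sketch with assumed sizes, not code from this repository), the difference between a channels_last gradient and one that simply matches the output's default layout shows up in the strides:

```python
import torch

# Illustrative sizes only (batch, channel, nlat, nlon); the test uses its own configuration.
out = torch.randn(1, 256, 721, 1440, device="cuda:0")

# What the test used to build: a gradient forced into channels_last memory format.
grad_channels_last = torch.randn(out.shape, dtype=torch.float32,
                                 device="cuda:0").to(memory_format=torch.channels_last)

# What it builds now: a gradient in the same default (channel-first, contiguous) layout as the output.
grad_matching = torch.randn(out.shape, dtype=torch.float32, device="cuda:0")

print(out.stride())                 # (265789440, 1038240, 1440, 1)
print(grad_channels_last.stride())  # channel stride is 1, i.e. channels are fastest in memory
print(grad_matching.stride())       # same strides as `out`
```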
Changes: showing 4 changed files with 107 additions and 935 deletions (+107 −935)
tests/test_attention.py                                  +1    −1
torch_harmonics/csrc/attention/attention.cuh             +0    −27
torch_harmonics/csrc/attention/attention_bwd_cuda.cu     +106  −903
torch_harmonics/csrc/attention/attention_interface.cu    +0    −4
tests/test_attention.py

```diff
@@ -289,7 +289,7 @@ class TestNeighborhoodAttentionS2(unittest.TestCase):
         v_gpu.requires_grad = True

         out_gpu = att_gpu(q_gpu, k_gpu, v_gpu)
-        out_grad = torch.randn(out_gpu.shape, dtype=torch.float32, device="cuda:0").to(memory_format=torch.channels_last)
+        out_grad = torch.randn(out_gpu.shape, dtype=torch.float32, device="cuda:0")

         time_backward_start = torch.cuda.Event(enable_timing=True)
         time_backward_end = torch.cuda.Event(enable_timing=True)
```
torch_harmonics/csrc/attention/attention.cuh

```diff
@@ -49,30 +49,3 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
                                                                         at::Tensor psi_col_idx, at::Tensor psi_row_off,
                                                                         int nlon_in, int nlat_out, int nlon_out);
-torch::Tensor s2_attention_bwd_dq_cuda(at::Tensor kx, at::Tensor vx, at::Tensor qy, at::Tensor dy, at::Tensor quad_weights,
-                                       at::Tensor psi_col_idx, at::Tensor psi_row_off, int nlon_in, int nlat_out, int nlon_out);
-torch::Tensor s2_attention_bwd_dk_cuda(at::Tensor kx, at::Tensor vx, at::Tensor qy, at::Tensor dy, at::Tensor quad_weights,
-                                       at::Tensor psi_col_idx, at::Tensor psi_row_off, int nlon_in, int nlat_out, int nlon_out);
-torch::Tensor s2_attention_bwd_dv_cuda(at::Tensor kx, at::Tensor vx, at::Tensor qy, at::Tensor dy, at::Tensor quad_weights,
-                                       at::Tensor psi_col_idx, at::Tensor psi_row_off, int nlon_in, int nlat_out, int nlon_out);
```
torch_harmonics/csrc/attention/attention_bwd_cuda.cu

```diff
@@ -116,634 +116,14 @@ __device__ float __warp_sum_cub(float val) {
    return sum;
}

__global__ void s2_attention_bwd_dv_kernel(int num_channels, int nlon_in, int nlat_out, int nlon_out,
                                           const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx,
                                           const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx,
                                           const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> qy,
                                           const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dy,
                                           torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydv,
                                           const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
                                           const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset,
                                           const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights)
{
    // shared memory
    extern __shared__ float sharedMem[];
    float* sh_alpha_sum = (float*)&sharedMem;     // 1
    float* sh_qdotk_max = (float*)&sharedMem[1];  // 1
    float* sh_qy_ho_wo = (float*)&sharedMem[2];   // num_channels

    if (threadIdx.x == 0) {
        sh_qdotk_max[0] = std::numeric_limits<float>::lowest();
        sh_alpha_sum[0] = 0.0;
    }
    __syncthreads();

    int ho = blockIdx.x;
    int wo = blockIdx.y;
    int batch_b = blockIdx.z;

    // load qy channels into shared memory
    for (int channel_block_i = 0; channel_block_i < (num_channels / blockDim.x) + 1; channel_block_i++) {
        int channel_idx = channel_block_i * blockDim.x + threadIdx.x;
        if (channel_idx >= num_channels) break;
        sh_qy_ho_wo[channel_idx] = qy[batch_b][channel_idx][ho][wo];
    }
    __syncthreads();

    int psi_offset = psi_row_offset[ho];
    int psi_nnz_ho = psi_row_offset[ho + 1] - psi_offset;

    float qdotk_max = std::numeric_limits<float>::lowest();
    for (int psi_block = 0; psi_block < (psi_nnz_ho / blockDim.x) + 1; psi_block++) {
        int idz = psi_block * blockDim.x + threadIdx.x;
        // skip if index >= length of psi_idx because last loop iteration will have extra threads
        if (idz >= psi_nnz_ho) break;
        int nz_col_idx = psi_col_idx[psi_offset + idz];

        // compute input indices from psi datastructure
        int hi = nz_col_idx / nlon_in;
        // account for output shift and ensure positive index due to circular condition
        // int wi = (nz_col_idx % nlon_in - wo) % nlon_in;
        int wi = nz_col_idx % nlon_in;
        int wip = (wi + wo) % nlon_in;

        // correlation Q&K (dot-product Q.K)
        float qdotk = 0.0;
        for (int channel_idx = 0; channel_idx < num_channels; channel_idx++) {
            qdotk += sh_qy_ho_wo[channel_idx] * kx[batch_b][channel_idx][hi][wip];
        }
        qdotk_max = std::max(qdotk, qdotk_max);
    }

    // collect thread-local qdotk max
    atomicMax(&sh_qdotk_max[0], qdotk_max);
    __syncthreads();
    // "broadcast" qdotk_max back into all thread-local registers
    qdotk_max = sh_qdotk_max[0];

    // form alpha & sum alpha
    float alpha_sum = 0.0;
    for (int psi_block = 0; psi_block < (psi_nnz_ho / blockDim.x) + 1; psi_block++) {
        int idz = psi_block * blockDim.x + threadIdx.x;
        // skip if index >= length of psi_idx because last loop iteration will have extra threads
        if (idz >= psi_nnz_ho) break;
        int nz_col_idx = psi_col_idx[psi_offset + idz];

        // compute input indices from psi datastructure
        int hi = nz_col_idx / nlon_in;
        // account for output shift and ensure positive index due to circular condition
        // int wi = (nz_col_idx % nlon_in - wo) % nlon_in;
        int wi = nz_col_idx % nlon_in;
        int wip = (wi + wo) % nlon_in;

        // softmax numerator
        float qdotk = 0.0;
        for (int channel_idx = 0; channel_idx < num_channels; channel_idx++) {
            qdotk += sh_qy_ho_wo[channel_idx] * kx[batch_b][channel_idx][hi][wip];
        }
        float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];
        // sum alpha
        alpha_sum += alpha_inz;
    }

    // collect thread-local alpha_sum
    atomicAdd(&sh_alpha_sum[0], alpha_sum);
    __syncthreads();
    // "broadcast" alpha sum back to thread-local registers
    alpha_sum = sh_alpha_sum[0];

    // alpha * dy * omega / alpha_sum
    for (int psi_block = 0; psi_block < (psi_nnz_ho / blockDim.x) + 1; psi_block++) {
        int idz = psi_block * blockDim.x + threadIdx.x;
        // skip if index >= length of psi_idx because last loop iteration will have extra threads
        if (idz >= psi_nnz_ho) break;
        int nz_col_idx = psi_col_idx[psi_offset + idz];

        // compute input indices from psi datastructure
        int hi = nz_col_idx / nlon_in;
        // account for output shift and ensure positive index due to circular condition
        int wi = nz_col_idx % nlon_in;
        int wip = (wi + wo) % nlon_in;

        float qdotk = 0.0;
        for (int channel_idx = 0; channel_idx < num_channels; channel_idx++) {
            qdotk += sh_qy_ho_wo[channel_idx] * kx[batch_b][channel_idx][hi][wip];
        }
        float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];

        // multiply alpha/sum_alpha, dy, and quadrature weights
        for (int channel_idx = 0; channel_idx < num_channels; channel_idx++) {
            atomicAdd(&dydv[batch_b][channel_idx][hi][wip],
                      (alpha_inz / alpha_sum) * dy[batch_b][channel_idx][ho][wo]);
        }
    }
}

at::Tensor s2_attention_bwd_dv_cuda(at::Tensor kx, at::Tensor vx, at::Tensor qy, at::Tensor dy,
                                    at::Tensor quad_weights, at::Tensor psi_col_idx, at::Tensor psi_row_off,
                                    int nlon_in, int nlat_out, int nlon_out)
{
    CHECK_CUDA_TENSOR(kx);
    CHECK_CUDA_TENSOR(vx);
    CHECK_CUDA_TENSOR(qy);
    CHECK_CUDA_TENSOR(quad_weights);
    CHECK_CUDA_TENSOR(psi_col_idx);
    CHECK_CUDA_TENSOR(psi_row_off);
    CHECK_CUDA_TENSOR(dy);

    auto stream = at::cuda::getCurrentCUDAStream().stream();

    torch::Tensor dydv = torch::zeros_like(vx);

    size_t uo_num_channels = kx.size(1);
    size_t sharedMemSize = (uo_num_channels + 2) * sizeof(float);
    const int batch_size = kx.size(0);

    // cuda grid y,z size limitations
    assert(nlon_out < 65535);
    assert(batch_size < 65535);

    // block-parallel over output points and batches
    dim3 gridDim(nlat_out, nlon_out, batch_size);
    // threads compute "blocks" of neighborhood and also "blocks" of channels
    dim3 blockDim(256, 1, 1);

    s2_attention_bwd_dv_kernel<<<gridDim, blockDim, sharedMemSize, stream>>>(
        uo_num_channels, nlon_in, nlat_out, nlon_out,
        kx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        vx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        qy.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        dy.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        dydv.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
        psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
        quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>());

    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return dydv;
}
```
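As a readability aid for the kernels in this file, the sketch below mirrors what s2_attention_bwd_dv_kernel computes for a single output point (ho, wo) of one batch element, assuming (as the kernels do) that psi_row_off/psi_col_idx form a CSR description of each output point's neighborhood. It is a plain NumPy restatement for orientation, not a drop-in replacement for the CUDA code:

```python
import numpy as np

def bwd_dv_single_point(kx, vx, qy, dy, quad_weights, psi_col_idx, psi_row_off,
                        nlon_in, b, ho, wo, dydv):
    """Accumulate the dL/dv contributions of output point (ho, wo); mirrors the dv kernel."""
    q = qy[b, :, ho, wo]                        # query vector the kernel caches in shared memory
    cols = psi_col_idx[psi_row_off[ho]:psi_row_off[ho + 1]]
    hi = cols // nlon_in                        # input latitude index
    wip = (cols % nlon_in + wo) % nlon_in       # input longitude, shifted by the output longitude

    qdotk = np.array([q @ kx[b, :, h, w] for h, w in zip(hi, wip)])
    alpha = np.exp(qdotk - qdotk.max()) * quad_weights[hi]   # quadrature-weighted softmax numerator
    alpha /= alpha.sum()

    # dydv[b, :, hi, wip] += alpha * dy[b, :, ho, wo]  (done with atomicAdd in the kernel)
    for a, h, w in zip(alpha, hi, wip):
        dydv[b, :, h, w] += a * dy[b, :, ho, wo]
```

The kernel performs the same computation with one thread block per (ho, wo, batch) triple, threads striding over the neighborhood, and atomicAdd for the scatter into dydv.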
```diff
__global__ void s2_attention_bwd_dk_kernel(int num_channels, int nlon_in, int nlat_out, int nlon_out,
                                           const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx,
                                           const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx,
                                           const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> qy,
                                           const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dy,
                                           torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydk,
                                           const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
                                           const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset,
                                           const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights)
{
    // shared memory
    extern __shared__ float sharedMem[];
    float* sh_alpha_sum = (float*)&sharedMem;
    float* sh_qy_ho_wo = (float*)&sharedMem[1];
    float* sh_integral = (float*)&sharedMem[1 + num_channels];
    float* sh_dy_ho_wo = (float*)&sharedMem[2 + num_channels];
    float* sh_qdotk_max = (float*)&sharedMem[2 + 2 * num_channels];

    if (threadIdx.x == 0) {
        sh_alpha_sum[0] = 0.0;
        sh_integral[0] = 0.0;
        sh_qdotk_max[0] = std::numeric_limits<float>::lowest();
    }
    __syncthreads();

    int ho = blockIdx.x;
    int wo = blockIdx.y;
    int batch_b = blockIdx.z;

    // load qy channels into shared memory
    for (int channel_block_i = 0; channel_block_i < (num_channels / blockDim.x) + 1; channel_block_i++) {
        int channel_idx = channel_block_i * blockDim.x + threadIdx.x;
        if (channel_idx >= num_channels) break;
        sh_qy_ho_wo[channel_idx] = qy[batch_b][channel_idx][ho][wo];
        sh_dy_ho_wo[channel_idx] = dy[batch_b][channel_idx][ho][wo];
    }
    __syncthreads();

    int psi_offset = psi_row_offset[ho];
    int psi_nnz_ho = psi_row_offset[ho + 1] - psi_offset;

    float qdotk_max = std::numeric_limits<float>::lowest();
    for (int psi_block = 0; psi_block < (psi_nnz_ho / blockDim.x) + 1; psi_block++) {
        int idz = psi_block * blockDim.x + threadIdx.x;
        // skip if index >= length of psi_idx because last loop iteration will have extra threads
        if (idz >= psi_nnz_ho) break;
        int nz_col_idx = psi_col_idx[psi_offset + idz];

        // compute input indices from psi datastructure
        int hi = nz_col_idx / nlon_in;
        // account for output shift and ensure positive index due to circular condition
        // int wi = (nz_col_idx % nlon_in - wo) % nlon_in;
        int wi = nz_col_idx % nlon_in;
        int wip = (wi + wo) % nlon_in;

        // correlation Q&K (dot-product Q.K)
        float qdotk = 0.0;
        for (int channel_idx = 0; channel_idx < num_channels; channel_idx++) {
            qdotk += sh_qy_ho_wo[channel_idx] * kx[batch_b][channel_idx][hi][wip];
        }
        qdotk_max = max(qdotk_max, qdotk);
    }

    // compute max over all threads
    atomicMax(&sh_qdotk_max[0], qdotk_max);
    __syncthreads();
    // "broadcast" qdotk_max back into all thread-local registers
    qdotk_max = sh_qdotk_max[0];

    float alpha_sum = 0.0;
    float integral = 0.0;
    for (int psi_block = 0; psi_block < (psi_nnz_ho / blockDim.x) + 1; psi_block++) {
        int idz = psi_block * blockDim.x + threadIdx.x;
        // skip if index >= length of psi_idx because last loop iteration will have extra threads
        if (idz >= psi_nnz_ho) break;
        int nz_col_idx = psi_col_idx[psi_offset + idz];

        // compute input indices from psi datastructure
        int hi = nz_col_idx / nlon_in;
        // account for output shift and ensure positive index due to circular condition
        // int wi = (nz_col_idx % nlon_in - wo) % nlon_in;
        int wi = nz_col_idx % nlon_in;
        int wip = (wi + wo) % nlon_in;

        // correlation Q&K (dot-product Q.K)
        float gdotv = 0.0;
        float qdotk = 0.0;
        for (int channel_idx = 0; channel_idx < num_channels; channel_idx++) {
            gdotv += sh_dy_ho_wo[channel_idx] * vx[batch_b][channel_idx][hi][wip];
            qdotk += sh_qy_ho_wo[channel_idx] * kx[batch_b][channel_idx][hi][wip];
        }
        // softmax numerator
        float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];
        // sum alpha & integral
        alpha_sum += alpha_inz;
        integral += alpha_inz * gdotv;
    }

    // block sum thread-local alpha_sum and integral
    atomicAdd(&sh_alpha_sum[0], alpha_sum);
    atomicAdd(&sh_integral[0], integral);
    __syncthreads();

    // finish integral computation
    if (threadIdx.x == 0) sh_integral[0] /= sh_alpha_sum[0];
    __syncthreads();

    // broadcast sum and integral back to thread-local registers
    integral = sh_integral[0];
    alpha_sum = sh_alpha_sum[0];

    // divide output by alpha_sum
    for (int psi_block = 0; psi_block < (psi_nnz_ho / blockDim.x) + 1; psi_block++) {
        int idz = psi_block * blockDim.x + threadIdx.x;
        // skip if index >= length of psi_idx because last loop iteration will have extra threads
        if (idz >= psi_nnz_ho) break;
        int nz_col_idx = psi_col_idx[psi_offset + idz];

        // compute input indices from psi datastructure
        int hi = nz_col_idx / nlon_in;
        // account for output shift and ensure positive index due to circular condition
        // int wi = (nz_col_idx % nlon_in - wo) % nlon_in;
        int wi = nz_col_idx % nlon_in;
        int wip = (wi + wo) % nlon_in;

        float gdotv = 0.0;
        float qdotk = 0.0;
        for (int channel_idx = 0; channel_idx < num_channels; channel_idx++) {
            gdotv += sh_dy_ho_wo[channel_idx] * vx[batch_b][channel_idx][hi][wip];
            qdotk += sh_qy_ho_wo[channel_idx] * kx[batch_b][channel_idx][hi][wip];
        }
        float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];

        // multiply alpha/sum_alpha, vx, and quadrature weights
        for (int channel_idx = 0; channel_idx < num_channels; channel_idx++) {
            atomicAdd(&dydk[batch_b][channel_idx][hi][wip],
                      sh_qy_ho_wo[channel_idx] * (alpha_inz / alpha_sum) * (gdotv - integral));
        }
    }
    __syncthreads();
}

__global__ void s2_attention_bwd_dq_kernel(int num_channels, int nlon_in, int nlat_out, int nlon_out,
                                           const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx,
                                           const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx,
                                           const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> qy,
                                           const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dy,
                                           torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydq,
                                           const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
                                           const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset,
                                           const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights)
{
    // shared memory
    extern __shared__ float sharedMem[];
    float* sh_alpha_sum = (float*)&sharedMem;
    float* sh_qy_ho_wo = (float*)&sharedMem[1];
    float* sh_alpha_k = (float*)&sharedMem[1 + num_channels];
    float* sh_alpha_vw = (float*)&sharedMem[1 + 2 * num_channels];
    float* sh_alpha_kvw = (float*)&sharedMem[1 + 3 * num_channels];
    float* sh_dy_ho_wo = (float*)&sharedMem[1 + 4 * num_channels];
    float* sh_qdotk_max = (float*)&sharedMem[1 + 5 * num_channels];

    if (threadIdx.x == 0) {
        sh_alpha_sum[0] = 0.0;
        sh_qdotk_max[0] = std::numeric_limits<float>::lowest();
    }
    __syncthreads();

    int ho = blockIdx.x;
    int wo = blockIdx.y;
    int batch_b = blockIdx.z;

    // load qy channels into shared memory and zero temporary variables
    for (int channel_block_i = 0; channel_block_i < (num_channels / blockDim.x) + 1; channel_block_i++) {
        int channel_idx = channel_block_i * blockDim.x + threadIdx.x;
        if (channel_idx >= num_channels) break;
        sh_qy_ho_wo[channel_idx] = qy[batch_b][channel_idx][ho][wo];
        sh_dy_ho_wo[channel_idx] = dy[batch_b][channel_idx][ho][wo];
        sh_alpha_k[channel_idx] = 0.0f;
        sh_alpha_vw[channel_idx] = 0.0f;
        sh_alpha_kvw[channel_idx] = 0.0f;
    }
    __syncthreads();

    int psi_offset = psi_row_offset[ho];
    int psi_nnz_ho = psi_row_offset[ho + 1] - psi_offset;

    float qdotk_max = std::numeric_limits<float>::lowest();
    for (int psi_block = 0; psi_block < (psi_nnz_ho / blockDim.x) + 1; psi_block++) {
        int idz = psi_block * blockDim.x + threadIdx.x;
        // skip if index >= length of psi_idx because last loop iteration will have extra threads
        if (idz >= psi_nnz_ho) break;
        int nz_col_idx = psi_col_idx[psi_offset + idz];

        // compute input indices from psi datastructure
        int hi = nz_col_idx / nlon_in;
        // account for output shift and ensure positive index due to circular condition
        // int wi = (nz_col_idx % nlon_in - wo) % nlon_in;
        int wi = nz_col_idx % nlon_in;
        int wip = (wi + wo) % nlon_in;

        // correlation Q&K (dot-product Q.K)
        float qdotk = 0.0f;
        for (int channel_idx = 0; channel_idx < num_channels; channel_idx++) {
            qdotk += sh_qy_ho_wo[channel_idx] * kx[batch_b][channel_idx][hi][wip];
        }
        qdotk_max = std::max(qdotk, qdotk_max);
    }

    atomicMax(&sh_qdotk_max[0], qdotk_max);
    __syncthreads();
    // "broadcast" qdotk_max back into all thread-local registers
    qdotk_max = sh_qdotk_max[0];

    float alpha_sum = 0.0;
    for (int psi_block = 0; psi_block < (psi_nnz_ho / blockDim.x) + 1; psi_block++) {
        int idz = psi_block * blockDim.x + threadIdx.x;
        // skip if index >= length of psi_idx because last loop iteration will have extra threads
        if (idz >= psi_nnz_ho) break;
        int nz_col_idx = psi_col_idx[psi_offset + idz];

        // compute input indices from psi datastructure
        int hi = nz_col_idx / nlon_in;
        // account for output shift and ensure positive index due to circular condition
        // int wi = (nz_col_idx % nlon_in - wo) % nlon_in;
        int wi = nz_col_idx % nlon_in;
        int wip = (wi + wo) % nlon_in;

        // correlation Q&K (dot-product Q.K)
        float qdotk = 0.0f;
        float gdotv = 0.0f;
        for (int channel_idx = 0; channel_idx < num_channels; channel_idx++) {
            gdotv += sh_dy_ho_wo[channel_idx] * vx[batch_b][channel_idx][hi][wip];
            qdotk += sh_qy_ho_wo[channel_idx] * kx[batch_b][channel_idx][hi][wip];
        }
        // softmax numerator
        float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];
        // sum alpha
        alpha_sum += alpha_inz;
        for (int channel_idx = 0; channel_idx < num_channels; channel_idx++) {
            atomicAdd(&sh_alpha_k[channel_idx], alpha_inz * kx[batch_b][channel_idx][hi][wip]);
            atomicAdd(&sh_alpha_vw[channel_idx], alpha_inz * gdotv);
            atomicAdd(&sh_alpha_kvw[channel_idx], alpha_inz * kx[batch_b][channel_idx][hi][wip] * gdotv);
        }
    }

    // sum thread-local alpha_sums across block
    atomicAdd(&sh_alpha_sum[0], alpha_sum);
    __syncthreads();
    // "broadcast" alpha sum back to thread-local registers
    alpha_sum = sh_alpha_sum[0];

    for (int channel_block_i = 0; channel_block_i < (num_channels / blockDim.x) + 1; channel_block_i++) {
        int channel_idx = channel_block_i * blockDim.x + threadIdx.x;
        if (channel_idx >= num_channels) break;
        dydq[batch_b][channel_idx][ho][wo] =
            (sh_alpha_kvw[channel_idx] * sh_alpha_sum[0] - sh_alpha_vw[channel_idx] * sh_alpha_k[channel_idx]) /
            (alpha_sum * alpha_sum);
    }
}

__global__ void s2_attention_bwd_dkvq_kernel(int num_channels, int nlon_in, int nlat_out, int nlon_out,
                                             const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx,
                                             const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx,
                                             const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> qy,
                                             const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dy,
                                             torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydk,
                                             torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydv,
                                             torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydq,
                                             const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
                                             const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset,
                                             const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights)
{
    // shared memory
    extern __shared__ float sharedMem[];
    float* sh_alpha_sum = (float*)&sharedMem;
    float* sh_integral = (float*)&sharedMem[1];
    float* sh_qy_ho_wo = (float*)&sharedMem[2];
    float* sh_alpha_k = (float*)&sharedMem[2 + num_channels];
    float* sh_alpha_vw = (float*)&sharedMem[2 + 2 * num_channels];
    float* sh_alpha_kvw = (float*)&sharedMem[2 + 3 * num_channels];
    float* sh_dy_ho_wo = (float*)&sharedMem[2 + 4 * num_channels];
    float* sh_qdotk_max = (float*)&sharedMem[2 + 5 * num_channels];

    if (threadIdx.x == 0) {
        sh_alpha_sum[0] = 0.0;
        sh_integral[0] = 0.0;
        sh_qdotk_max[0] = std::numeric_limits<float>::lowest();
    }
    __syncthreads();

    int ho = blockIdx.x;
    int wo = blockIdx.y;
    int batch_b = blockIdx.z;

    // load qy channels into shared memory and zero temporary variables
    for (int channel_block_i = 0; channel_block_i < (num_channels / blockDim.x) + 1; channel_block_i++) {
        int channel_idx = channel_block_i * blockDim.x + threadIdx.x;
        if (channel_idx >= num_channels) break;
        sh_qy_ho_wo[channel_idx] = qy[batch_b][channel_idx][ho][wo];
        sh_dy_ho_wo[channel_idx] = dy[batch_b][channel_idx][ho][wo];
        sh_alpha_k[channel_idx] = 0.0f;
        sh_alpha_vw[channel_idx] = 0.0f;
        sh_alpha_kvw[channel_idx] = 0.0f;
    }
    __syncthreads();

    int psi_offset = psi_row_offset[ho];
    int psi_nnz_ho = psi_row_offset[ho + 1] - psi_offset;

    float qdotk_max = std::numeric_limits<float>::lowest();
    for (int psi_block = 0; psi_block < (psi_nnz_ho / blockDim.x) + 1; psi_block++) {
        int idz = psi_block * blockDim.x + threadIdx.x;
        // skip if index >= length of psi_idx because last loop iteration will have extra threads
        if (idz >= psi_nnz_ho) break;
        int nz_col_idx = psi_col_idx[psi_offset + idz];

        // compute input indices from psi datastructure
        int hi = nz_col_idx / nlon_in;
        // account for output shift and ensure positive index due to circular condition
        // int wi = (nz_col_idx % nlon_in - wo) % nlon_in;
        int wi = nz_col_idx % nlon_in;
        int wip = (wi + wo) % nlon_in;

        // correlation Q&K (dot-product Q.K)
        float qdotk = 0.0f;
        for (int channel_idx = 0; channel_idx < num_channels; channel_idx++) {
            qdotk += sh_qy_ho_wo[channel_idx] * kx[batch_b][channel_idx][hi][wip];
        }
        qdotk_max = std::max(qdotk, qdotk_max);
    }

    atomicMax(&sh_qdotk_max[0], qdotk_max);
    __syncthreads();
    // "broadcast" qdotk_max back into all thread-local registers
    qdotk_max = sh_qdotk_max[0];

    float alpha_sum = 0.0;
    float integral = 0.0;
    for (int psi_block = 0; psi_block < (psi_nnz_ho / blockDim.x) + 1; psi_block++) {
        int idz = psi_block * blockDim.x + threadIdx.x;
        // skip if index >= length of psi_idx because last loop iteration will have extra threads
        if (idz >= psi_nnz_ho) break;
        int nz_col_idx = psi_col_idx[psi_offset + idz];

        // compute input indices from psi datastructure
        int hi = nz_col_idx / nlon_in;
        // account for output shift and ensure positive index due to circular condition
        // int wi = (nz_col_idx % nlon_in - wo) % nlon_in;
        int wi = nz_col_idx % nlon_in;
        int wip = (wi + wo) % nlon_in;

        // correlation Q&K (dot-product Q.K)
        float qdotk = 0.0f;
        float gdotv = 0.0f;
        for (int channel_idx = 0; channel_idx < num_channels; channel_idx++) {
            gdotv += sh_dy_ho_wo[channel_idx] * vx[batch_b][channel_idx][hi][wip];
            qdotk += sh_qy_ho_wo[channel_idx] * kx[batch_b][channel_idx][hi][wip];
        }
        // softmax numerator
        float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];
        // sum alpha
        alpha_sum += alpha_inz;
        for (int channel_idx = 0; channel_idx < num_channels; channel_idx++) {
            atomicAdd(&sh_alpha_k[channel_idx], alpha_inz * kx[batch_b][channel_idx][hi][wip]);
            atomicAdd(&sh_alpha_vw[channel_idx], alpha_inz * gdotv);
            atomicAdd(&sh_alpha_kvw[channel_idx], alpha_inz * kx[batch_b][channel_idx][hi][wip] * gdotv);
        }
        integral += alpha_inz * gdotv;
    }

    // sum thread-local alpha_sums & integral across block
    atomicAdd(&sh_alpha_sum[0], alpha_sum);
    atomicAdd(&sh_integral[0], integral);
    __syncthreads();

    // finalize integral
    if (threadIdx.x == 0) sh_integral[0] /= sh_alpha_sum[0];
    __syncthreads();

    // "broadcast" alpha sum & integral back to thread-local registers
    alpha_sum = sh_alpha_sum[0];
    integral = sh_integral[0];

    // dq
    for (int channel_block_i = 0; channel_block_i < (num_channels / blockDim.x) + 1; channel_block_i++) {
        int channel_idx = channel_block_i * blockDim.x + threadIdx.x;
        if (channel_idx >= num_channels) break;
        dydq[batch_b][channel_idx][ho][wo] =
            (sh_alpha_kvw[channel_idx] * sh_alpha_sum[0] - sh_alpha_vw[channel_idx] * sh_alpha_k[channel_idx]) /
            (alpha_sum * alpha_sum);
    }
    __syncthreads();

    // dk & dv
    for (int psi_block = 0; psi_block < (psi_nnz_ho / blockDim.x) + 1; psi_block++) {
        int idz = psi_block * blockDim.x + threadIdx.x;
        // skip if index >= length of psi_idx because last loop iteration will have extra threads
        if (idz >= psi_nnz_ho) break;
        int nz_col_idx = psi_col_idx[psi_offset + idz];

        // compute input indices from psi datastructure
        int hi = nz_col_idx / nlon_in;
        // account for output shift and ensure positive index due to circular condition
        // int wi = (nz_col_idx % nlon_in - wo) % nlon_in;
        int wi = nz_col_idx % nlon_in;
        int wip = (wi + wo) % nlon_in;

        float gdotv = 0.0;
        float qdotk = 0.0;
        for (int channel_idx = 0; channel_idx < num_channels; channel_idx++) {
            gdotv += sh_dy_ho_wo[channel_idx] * vx[batch_b][channel_idx][hi][wip];
            qdotk += sh_qy_ho_wo[channel_idx] * kx[batch_b][channel_idx][hi][wip];
        }
        float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];

        // multiply alpha/sum_alpha, vx, and quadrature weights
        for (int channel_idx = 0; channel_idx < num_channels; channel_idx++) {
            atomicAdd(&dydk[batch_b][channel_idx][hi][wip],
                      sh_qy_ho_wo[channel_idx] * (alpha_inz / alpha_sum) * (gdotv - integral));
            atomicAdd(&dydv[batch_b][channel_idx][hi][wip],
                      (alpha_inz / alpha_sum) * sh_dy_ho_wo[channel_idx]);
        }
    }
    __syncthreads();
}

// New kernel: s2_attention_bwd_dkvq_kernel_mbT
// This kernel assumes kx, vx, qy, dy, dydk, dydv, dydq are all [batch, ho, wo, channel] (transposed)
// This kernel computes the backward pass for the S2 attention mechanism, using
// shared memory as a cache and one warp per output point, warp-parallel over
// channels, which should be layed out in the fastest dimension for coalesced
// memory access.
template<int BDIM_X>
__global__ __launch_bounds__(BDIM_X)
void s2_attention_bwd_dkvq_kernel_mbT(
void s2_attention_bwd_dkvq_kernel(int num_channels, int nlon_in, int nlat_out, ...

@@ -859,116 +239,8 @@ __launch_bounds__(BDIM_X)
    }
}
```
```diff
at::Tensor s2_attention_bwd_dk_cuda(at::Tensor kx, at::Tensor vx, at::Tensor qy, at::Tensor dy,
                                    at::Tensor quad_weights, at::Tensor psi_col_idx, at::Tensor psi_row_off,
                                    int nlon_in, int nlat_out, int nlon_out)
{
    CHECK_CUDA_TENSOR(kx);
    CHECK_CUDA_TENSOR(vx);
    CHECK_CUDA_TENSOR(qy);
    CHECK_CUDA_TENSOR(quad_weights);
    CHECK_CUDA_TENSOR(psi_col_idx);
    CHECK_CUDA_TENSOR(psi_row_off);
    CHECK_CUDA_TENSOR(dy);

    auto stream = at::cuda::getCurrentCUDAStream().stream();

    torch::Tensor dydk = torch::zeros_like(kx);

    size_t uo_num_channels = kx.size(1);
    size_t sharedMemSize = (2 * uo_num_channels + 3) * sizeof(float);
    const int batch_size = kx.size(0);

    // cuda grid y,z size limitations
    assert(nlon_out < 65535);
    assert(batch_size < 65535);

    // block-parallel over output points and batches
    dim3 gridDim(nlat_out, nlon_out, batch_size);
    // threads compute "blocks" of neighborhood and also "blocks" of channels
    dim3 blockDim(256, 1, 1);

    s2_attention_bwd_dk_kernel<<<gridDim, blockDim, sharedMemSize, stream>>>(
        uo_num_channels, nlon_in, nlat_out, nlon_out,
        kx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        vx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        qy.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        dy.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        dydk.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
        psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
        quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>());

    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return dydk;
}

at::Tensor s2_attention_bwd_dq_cuda(at::Tensor kx, at::Tensor vx, at::Tensor qy, at::Tensor dy,
                                    at::Tensor quad_weights, at::Tensor psi_col_idx, at::Tensor psi_row_off,
                                    int nlon_in, int nlat_out, int nlon_out)
{
    CHECK_CUDA_TENSOR(kx);
    CHECK_CUDA_TENSOR(vx);
    CHECK_CUDA_TENSOR(qy);
    CHECK_CUDA_TENSOR(quad_weights);
    CHECK_CUDA_TENSOR(psi_col_idx);
    CHECK_CUDA_TENSOR(psi_row_off);
    CHECK_CUDA_TENSOR(dy);

    auto stream = at::cuda::getCurrentCUDAStream().stream();

    torch::Tensor dydq = torch::zeros_like(qy);

    size_t uo_num_channels = kx.size(1);
    size_t sharedMemSize = (5 * uo_num_channels + 2) * sizeof(float);
    const int batch_size = kx.size(0);

    // cuda grid y,z size limitations
    assert(nlon_out < 65535);
    assert(batch_size < 65535);

    // block-parallel over output points and batches
    dim3 gridDim(nlat_out, nlon_out, batch_size);
    // threads compute "blocks" of neighborhood and also "blocks" of channels
    dim3 blockDim(256, 1, 1);

    s2_attention_bwd_dq_kernel<<<gridDim, blockDim, sharedMemSize, stream>>>(
        uo_num_channels, nlon_in, nlat_out, nlon_out,
        kx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        vx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        qy.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        dy.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        dydq.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
        psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
        quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>());

    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return dydq;
}

std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tensor kx, at::Tensor vx, at::Tensor qy, ...

@@ -988,183 +260,114 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
    auto stream = at::cuda::getCurrentCUDAStream().stream();

    size_t uo_num_channels = kx.size(1);
    const int batch_size = kx.size(0);

    // enum for which kernel version
    enum KERNEL_VERSION {
        OLD_VERSION = 0,
        HOWO_WARP_VERSION = 2,
    };
    auto version = HOWO_WARP_VERSION;
    // auto version = OLD_VERSION;

    if (version == OLD_VERSION) {
        // printf("old version\n");
        torch::Tensor dydk = torch::zeros_like(qy);
        torch::Tensor dydv = torch::zeros_like(qy);
        torch::Tensor dydq = torch::zeros_like(qy);
        size_t sharedMemSize = (6 * uo_num_channels + 3) * sizeof(float);

        // cuda grid y,z size limitations
        assert(nlon_out < 65535);
        assert(batch_size < 65535);

        auto k_channel_first = kx.strides()[1] == 1;
        auto v_channel_first = vx.strides()[1] == 1;
        auto q_channel_first = qy.strides()[1] == 1;
        auto dy_channel_first = dy.strides()[1] == 1;

        // block-parallel over output points and batches
        dim3 gridDim(nlat_out, nlon_out, batch_size);
        // Transpose to [batch, ho, wo, channel]
        nvtxRangePush("s2_attention_bwd_dkvq_kernel_mbT permute inputs");
        // auto* permute_timer = new ScopeTimer("permute inputs");
        // threads compute "blocks" of neighborhood and also "blocks" of channels
        dim3 blockDim(256, 1, 1);

        // Define CUDA event variables for timing
        cudaEvent_t start_event, stop_event;
        cudaEventCreate(&start_event);
        cudaEventCreate(&stop_event);
        // Record the start event
        cudaEventRecord(start_event, stream);

        s2_attention_bwd_dkvq_kernel<<<gridDim, blockDim, sharedMemSize, stream>>>(
            uo_num_channels, nlon_in, nlat_out, nlon_out,
            kx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
            vx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
            qy.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
            dy.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
            dydk.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
            dydv.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
            dydq.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
            psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
            psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
            quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>());

        // Record the stop event
        cudaEventRecord(stop_event, stream);
        cudaEventSynchronize(stop_event);

        // Calculate elapsed time
        float kernel_time_ms;
        cudaEventElapsedTime(&kernel_time_ms, start_event, stop_event);
        // Output the result
        // [1, 256, 1, (721, 1440), (721, 1440), "equiangular", "equiangular", 1e-5, 1e-5],
        // Old bwd kernel execution time: 803.477 ms
        // std::cout << "Old bwd kernel execution time: " << kernel_time_ms << " ms" << std::endl;

        // Cleanup events
        cudaEventDestroy(start_event);
        cudaEventDestroy(stop_event);

        C10_CUDA_KERNEL_LAUNCH_CHECK();

        return std::make_tuple(dydk, dydv, dydq);
    }
    else if (version == HOWO_WARP_VERSION) {
        // ScopeTimer timer("Full s2_attention_bwd_dkvq_kernel_mbT");
        // Time this function via C++
        auto k_channel_first = kx.strides()[1] == 1;
        auto v_channel_first = vx.strides()[1] == 1;
        auto q_channel_first = qy.strides()[1] == 1;
        auto dy_channel_first = dy.strides()[1] == 1;

        // Transpose to [batch, ho, wo, channel]
        nvtxRangePush("s2_attention_bwd_dkvq_kernel_mbT permute inputs");
        // auto* permute_timer = new ScopeTimer("permute inputs");

        // Permute kx,vx,qy,dy to [batch, ho, wo, channel] in memory layout, but keep the original shape [batch, channel, ho, wo]
        auto kxP = at::Tensor();
        if (!k_channel_first) {
            // printf("Permuting kx from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
            kxP = kx.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
        } else {
            kxP = kx;
        }
        auto vxP = at::Tensor();
        if (!v_channel_first) {
            // printf("Permuting vx from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
            vxP = vx.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
        } else {
            vxP = vx;
        }
        auto qyP = at::Tensor();
        if (!q_channel_first) {
            // printf("Permuting qy from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
            qyP = qy.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
        } else {
            qyP = qy;
        }
        auto dyP = at::Tensor();
        if (!dy_channel_first) {
            // printf("Permuting dy from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
            dyP = dy.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
        } else {
            dyP = dy;
        }
        // cudaDeviceSynchronize();
        // delete permute_timer;
        nvtxRangePop();

        nvtxRangePush("s2_attention_bwd_dkvq_kernel_mbT output allocation & zero");
        auto dydkP = torch::zeros_like(qyP);
        auto dydvP = torch::zeros_like(qyP);
        auto dydqP = torch::zeros_like(qyP);
        // print strdie of dydkP, dydvP, dydqP
        nvtxRangePop();

        size_t uo_num_channels = kx.size(1);
        const int batch_size = kx.size(0);

        // Permute kx,vx,qy,dy to [batch, ho, wo, channel] in memory layout, but keep the original shape [batch, channel, ho, wo]
        auto kxP = at::Tensor();
        if (!k_channel_first) {
            // printf("Permuting kx from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
            kxP = kx.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
        } else {
            kxP = kx;
        }
        auto vxP = at::Tensor();
        if (!v_channel_first) {
            // printf("Permuting vx from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
            vxP = vx.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
        } else {
            vxP = vx;
        }
        auto qyP = at::Tensor();
        if (!q_channel_first) {
            // printf("Permuting qy from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
            qyP = qy.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
        } else {
            qyP = qy;
        }
        auto dyP = at::Tensor();
        if (!dy_channel_first) {
            // printf("Permuting dy from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
            dyP = dy.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
        } else {
            dyP = dy;
        }
        // cudaDeviceSynchronize();
        // delete permute_timer;
        nvtxRangePop();

        dim3 block(WARP_SIZE, THREADS / WARP_SIZE);
        dim3 grid(DIV_UP(nlat_out * nlon_out, block.y), batch_size);
        size_t shared_size = sizeof(float) * uo_num_channels * 5 * block.y;  // 4 arrays per warp

        nvtxRangePush("s2_attention_bwd_dkvq_kernel_mbT output allocation & zero");
        auto dydkP = torch::zeros_like(qyP);
        auto dydvP = torch::zeros_like(qyP);
        auto dydqP = torch::zeros_like(qyP);
        // print strdie of dydkP, dydvP, dydqP
        nvtxRangePop();

        cudaEvent_t start, stop;
        float milliseconds = 0;
        CHECK_CUDA(cudaEventCreate(&start));
        CHECK_CUDA(cudaEventCreate(&stop));
        CHECK_CUDA(cudaEventRecord(start, stream));

        size_t uo_num_channels = kx.size(1);
        const int batch_size = kx.size(0);

        s2_attention_bwd_dkvq_kernel_mbT<THREADS><<<grid, block, shared_size, stream>>>(
            uo_num_channels, nlon_in, nlat_out, nlon_out,
            kxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
            vxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
            qyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
            dyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
            dydkP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
            dydvP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
            dydqP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
            psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
            psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
            quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>());

        dim3 block(WARP_SIZE, THREADS / WARP_SIZE);
        dim3 grid(DIV_UP(nlat_out * nlon_out, block.y), batch_size);
        size_t shared_size = sizeof(float) * uo_num_channels * 5 * block.y;  // 4 arrays per warp

        cudaEvent_t start, stop;
        float milliseconds = 0;
        CHECK_CUDA(cudaEventCreate(&start));
        CHECK_CUDA(cudaEventCreate(&stop));
        CHECK_CUDA(cudaEventRecord(start, stream));

        s2_attention_bwd_dkvq_kernel<THREADS><<<grid, block, shared_size, stream>>>(
            uo_num_channels, nlon_in, nlat_out, nlon_out,
            kxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
            vxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
            qyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
            dyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
            dydkP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
            dydvP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
            dydqP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
            psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
            psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
            quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>());

        CHECK_CUDA(cudaEventRecord(stop, stream));
        CHECK_CUDA(cudaEventSynchronize(stop));
        CHECK_CUDA(cudaEventElapsedTime(&milliseconds, start, stop));
        // [1, 256, 1, (721, 1440), (721, 1440), "equiangular", "equiangular", 1e-5, 1e-5],
        // s2_attention_bwd_kernel_mbT execution time: 63.280128 ms
        // printf("s2_attention_bwd_kernel_mbT execution time: %f ms\n", milliseconds);
        CHECK_CUDA(cudaEventDestroy(start));
        CHECK_CUDA(cudaEventDestroy(stop));

        CHECK_CUDA(cudaEventRecord(stop, stream));
        CHECK_CUDA(cudaEventSynchronize(stop));
        CHECK_CUDA(cudaEventElapsedTime(&milliseconds, start, stop));
        // [1, 256, 1, (721, 1440), (721, 1440), "equiangular", "equiangular", 1e-5, 1e-5],
        // s2_attention_bwd_kernel_mbT execution time: 63.280128 ms
        // printf("s2_attention_bwd_kernel_mbT execution time: %f ms\n", milliseconds);
        CHECK_CUDA(cudaEventDestroy(start));
        CHECK_CUDA(cudaEventDestroy(stop));

        C10_CUDA_KERNEL_LAUNCH_CHECK();
        C10_CUDA_KERNEL_LAUNCH_CHECK();

        // Permute outputs back to memory layout given by input. if input had channels
        // first, leave it in that layout, otherwise permute layout back to [batch,
        // channel, ho, wo]
        at::Tensor dydk, dydv, dydq;
        if (!k_channel_first) dydk = dydkP.contiguous();
        else dydk = dydkP;
        if (!v_channel_first) dydv = dydvP.contiguous();
        else dydv = dydvP;
        if (!q_channel_first) dydq = dydqP.contiguous();
        else dydq = dydqP;

        // printf("dydk strides:[");
        // for(auto& stride : dydk.strides()) {
        //     printf("%ld,", stride);
        // }
        // printf("]\n");
        // cudaDeviceSynchronize();
        // delete permute_output_timer;
        // nvtxRangePop();

        return std::make_tuple(dydk, dydv, dydq);

        // Permute outputs back to [batch, channel, ho, wo]
        // nvtxRangePush("s2_attention_bwd_dkvq_kernel_mbT output permutation");
        // auto* permute_output_timer = new ScopeTimer("permute outputs");
        // auto dydk = dydkP.permute({0,3,1,2}).contiguous().permute({0,3,1,2});
        // auto dydv = dydvP.permute({0,3,1,2}).contiguous();
        // auto dydq = dydqP.permute({0, 3, 1, 2}).contiguous();
        // cudaDeviceSynchronize();
        // delete permute_output_timer;
        // nvtxRangePop();

        return std::make_tuple(dydkP, dydvP, dydqP);
    }
    else {
        throw std::runtime_error("Invalid kernel version specified");
    }
}
```
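A note on the permute({0,2,3,1}).contiguous().permute({0,3,1,2}) pattern in the wrapper above: it changes only the memory layout, not the logical shape, which is what lets the channel-fastest ("mbT") kernel read contiguous channel vectors while callers keep seeing [batch, channel, ho, wo] tensors. A small standalone PyTorch sketch with illustrative sizes:

```python
import torch

x = torch.randn(2, 8, 6, 12)      # logical layout [batch, channel, ho, wo]; channel stride != 1
print(x.stride()[1] == 1)         # False

# Same logical shape, but the data is now laid out as [batch, ho, wo, channel] in memory.
xP = x.permute(0, 2, 3, 1).contiguous().permute(0, 3, 1, 2)
print(xP.shape == x.shape)        # True
print(xP.stride()[1] == 1)        # True: channel is the fastest-varying dimension

# Returning to a plain channel-first layout is a single .contiguous(), which is what the
# wrapper does for dydk/dydv/dydq when the inputs were not channel-fastest to begin with.
x_back = xP.contiguous()
print(torch.equal(x_back, x))     # True: values are unchanged, only the layout moved
```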
torch_harmonics/csrc/attention/attention_interface.cu

```diff
@@ -33,10 +33,6 @@
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
 {
     m.def("forward", &s2_attention_fwd_cuda, "(Local) Attention on S2");
-    m.def("backward_dk", &s2_attention_bwd_dk_cuda, "(Local) Attention gradient on S2 (gradient for k)");
-    m.def("backward_dv", &s2_attention_bwd_dv_cuda, "(Local) Attention gradient on S2 (gradient for v)");
-    m.def("backward_dq", &s2_attention_bwd_dq_cuda, "(Local) Attention gradient on S2 (gradient for q)");
     m.def("backward_dkvq", &s2_attention_bwd_dkvq_cuda, "(Local) Attention gradient on S2 (gradient for k,v,&q)");
 }
```
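With the per-tensor bindings gone, Python callers go through the fused entry point only. The following is a hedged sketch of how the built extension might be driven; the module name, tensor shapes, dtypes, and the dummy neighborhood structure are assumptions for illustration, and the argument order is inferred from the removed per-tensor wrappers and the s2_attention_bwd_dkvq_cuda declaration shown above, not from documented API:

```python
import torch
# Placeholder import: use whatever name the TORCH_EXTENSION_NAME module is built under.
import attention_cuda as _C

B, C, nlat, nlon = 1, 256, 721, 1440                  # assumed sizes, echoing the benchmark comment above
kx = torch.randn(B, C, nlat, nlon, device="cuda:0")
vx = torch.randn(B, C, nlat, nlon, device="cuda:0")
qy = torch.randn(B, C, nlat, nlon, device="cuda:0")
dy = torch.randn(B, C, nlat, nlon, device="cuda:0")   # gradient w.r.t. the attention output
quad_weights = torch.rand(nlat, device="cuda:0")      # quadrature weights per input latitude

# psi_row_off / psi_col_idx: CSR neighborhood structure (int64), normally built elsewhere in
# torch-harmonics; a trivial one-neighbor-per-row dummy is used here just to show the types.
psi_row_off = torch.arange(nlat + 1, dtype=torch.int64, device="cuda:0")
psi_col_idx = torch.zeros(int(psi_row_off[-1]), dtype=torch.int64, device="cuda:0")

# Fused backward: one call returns (dk, dv, dq), replacing backward_dk/backward_dv/backward_dq.
dk, dv, dq = _C.backward_dkvq(kx, vx, qy, dy, quad_weights,
                              psi_col_idx, psi_row_off,
                              nlon, nlat, nlon)
```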