OpenDAS / torch-harmonics · Commits · 4aaff021

Commit 4aaff021 (unverified) — authored Jul 17, 2025 by Thorsten Kurth, committed via GitHub on Jul 17, 2025

Merge pull request #91 from NVIDIA/maurob/devel

Attention Backward improvement

Parents: ab44ba59, fa58767d

Changes: 10 changed files with 1518 additions and 656 deletions (+1518 / -656)
Changed files:

  setup.py                                                +2    -0
  tests/test_attention.py                                 +2    -3
  torch_harmonics/_neighborhood_attention.py              +15   -0
  torch_harmonics/csrc/attention/attention.cuh            +5    -1
  torch_harmonics/csrc/attention/attention_bwd_cuda.cu    +753  -115
  torch_harmonics/csrc/attention/attention_fwd_cuda.cu    +132  -532
  torch_harmonics/csrc/attention/attention_utils.cu       +180  -0
  torch_harmonics/csrc/attention/attention_utils.cuh      +373  -0
  torch_harmonics/csrc/attention/cudamacro.h              +47   -0
  torch_harmonics/examples/losses.py                      +9    -5
setup.py

@@ -61,6 +61,7 @@ def get_compile_args(module_name):
    nvcc_extra_flags = []
    if profile_mode:
        nvcc_extra_flags.append("-lineinfo")
        nvcc_extra_flags.append("-Xptxas=-v")
    if debug_mode:
        print(f"WARNING: Compiling {module_name} with debugging flags")
...
@@ -102,6 +103,7 @@ def get_ext_modules():
        CUDAExtension(
            name="attention_cuda_extension",
            sources=[
                "torch_harmonics/csrc/attention/attention_utils.cu",
                "torch_harmonics/csrc/attention/attention_fwd_cuda.cu",
                "torch_harmonics/csrc/attention/attention_bwd_cuda.cu",
                "torch_harmonics/csrc/attention/attention_interface.cu",
...
tests/test_attention.py

@@ -78,7 +78,8 @@ class TestNeighborhoodAttentionS2(unittest.TestCase):
            [4, 4, 1, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
            [4, 4, 2, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
            [4, 4, 4, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
            [4, 4, 1, (6, 12), (6, 12), "legendre-gauss", "legendre-gauss", 1e-5, 1e-3],
            [4, 1, 1, (2, 4), (2, 4), "equiangular", "equiangular", 1e-5, 1e-3],
            [4, 4, 4, (6, 12), (6, 12), "legendre-gauss", "legendre-gauss", 1e-5, 1e-3],
            [4, 4, 1, (6, 12), (6, 12), "lobatto", "lobatto", 1e-5, 1e-3],
        ],
        skip_on_empty=True,
...
@@ -156,8 +157,6 @@ class TestNeighborhoodAttentionS2(unittest.TestCase):
        [
            # Format: [batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol]
            [4, 4, 1, (6, 12), (6, 12), "equiangular", "equiangular", 1e-2, 0],
            # [4, 4, 2, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
            # [4, 4, 4, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
            [4, 4, 1, (6, 12), (6, 12), "legendre-gauss", "legendre-gauss", 1e-2, 0],
            [4, 4, 1, (6, 12), (6, 12), "lobatto", "lobatto", 1e-2, 0],
        ],
...
torch_harmonics/_neighborhood_attention.py

@@ -520,6 +520,16 @@ class _NeighborhoodAttentionS2Cuda(torch.autograd.Function):
        B, _, H, W = grad_output.shape
        grad_output = grad_output.reshape(B * nh, -1, H, W)

        # save type and convert to float32
        kw_dtype = kw.dtype
        vw_dtype = vw.dtype
        qw_dtype = qw.dtype
        kw = kw.to(torch.float32).contiguous()
        vw = vw.to(torch.float32).contiguous()
        qw = qw.to(torch.float32).contiguous()
        grad_output = grad_output.to(torch.float32).contiguous()

        dkw, dvw, dqw = attention_cuda_extension.backward_dkvq(kw, vw, qw, grad_output,
                                                               quad_weights, col_idx, row_off,
...
@@ -533,6 +543,11 @@ class _NeighborhoodAttentionS2Cuda(torch.autograd.Function):
        _, C, H, W = dqw.shape
        dqw = dqw.reshape(B, -1, H, W)

        # convert precision
        dkw = dkw.to(dtype=kw_dtype)
        dvw = dvw.to(dtype=vw_dtype)
        dqw = dqw.to(dtype=qw_dtype)

        # input grads
        dv = torch.nn.functional.conv2d(dvw, weight=wv.permute([1, 0, 2, 3]), bias=None)
        dk = torch.nn.functional.conv2d(dkw, weight=wk.permute([1, 0, 2, 3]), bias=None)
...
torch_harmonics/csrc/attention/attention.cuh

@@ -34,7 +34,11 @@
#include <cstdint>
#include <torch/torch.h>

-#define CHECK_CUDA_TENSOR(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CUDA_TENSOR(x) TORCH_INTERNAL_ASSERT(x.device().type() == torch::kCUDA)
+#define CHECK_CONTIGUOUS_TENSOR(x) TORCH_INTERNAL_ASSERT(x.is_contiguous() || x.is_contiguous(at::MemoryFormat::ChannelsLast))
+#define CHECK_CUDA_INPUT_TENSOR(x) \
+    CHECK_CUDA_TENSOR(x);          \
+    CHECK_CONTIGUOUS_TENSOR(x)

torch::Tensor s2_attention_fwd_cuda(at::Tensor kx, at::Tensor vx, at::Tensor qy, at::Tensor quad_weights,
                                    at::Tensor psi_col_idx, at::Tensor psi_row_off, int nlon_in, int nlat_out,
...
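For reference, a minimal usage sketch of the new composed check (not part of this commit; the validate_inputs wrapper and tensor names are illustrative): the macro accepts CUDA tensors stored either contiguously or in channels-last layout, and rejects CPU tensors and arbitrary strided views.

#include <torch/torch.h>

#define CHECK_CUDA_TENSOR(x) TORCH_INTERNAL_ASSERT(x.device().type() == torch::kCUDA)
#define CHECK_CONTIGUOUS_TENSOR(x) TORCH_INTERNAL_ASSERT(x.is_contiguous() || x.is_contiguous(at::MemoryFormat::ChannelsLast))
#define CHECK_CUDA_INPUT_TENSOR(x) \
    CHECK_CUDA_TENSOR(x);          \
    CHECK_CONTIGUOUS_TENSOR(x)

// hypothetical helper showing how an extension entry point would validate its inputs
static void validate_inputs(at::Tensor kx, at::Tensor vx) {
    CHECK_CUDA_INPUT_TENSOR(kx);   // device check + (contiguous or channels-last) check
    CHECK_CUDA_INPUT_TENSOR(vx);
}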
torch_harmonics/csrc/attention/attention_bwd_cuda.cu

@@ -41,33 +41,18 @@
#include <cub/cub.cuh>
#include <limits>

#ifndef WARP_SIZE
#define WARP_SIZE (32)
#endif

#ifndef FULL_MASK
#define FULL_MASK (0xFFFFFFFF)
#endif

#ifndef THREADS
#define THREADS (64)
#endif

#ifndef DIV_UP
#define DIV_UP(a, b) (((a) + ((b) - 1)) / (b))
#endif

#ifndef CHECK_CUDA
#define CHECK_CUDA(call) {                                                                                          \
    cudaError_t err = call;                                                                                         \
    if (cudaSuccess != err) {                                                                                       \
        fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\n", __FILE__, __LINE__, cudaGetErrorString(err)); \
        exit(EXIT_FAILURE);                                                                                         \
    }                                                                                                               \
}
#endif

#include "cudamacro.h"
#include "attention_utils.cuh"

#include <iostream>
#include <chrono>
#include <string>

#define THREADS (64)
#define MAX_LOCAL_ARR_LEN (16)

#if 0
class ScopeTimer
{
  public:
...
@@ -88,13 +73,6 @@ class ScopeTimer
    std::chrono::high_resolution_clock::time_point start_;
};

static __device__ float __warp_sum(float val)
{
#pragma unroll
    for (int i = WARP_SIZE / 2; i; i /= 2) { val += __shfl_xor_sync(FULL_MASK, val, i); }
    return val;
}

// easier to understand version of manual shfl_xor_sync, performance appears similar
static __device__ float __warp_sum_cub(float val)
{
...
@@ -216,6 +194,697 @@ __global__ __launch_bounds__(BDIM_X) void s2_attention_bwd_dkvq_kernel(
            }
        }
    }
#endif
// BEGIN backward kernels and functions

// called with (blockDim.x=32 and blockDim.y>1, BDIM=blockDim.x*blockDim.y)
template<int BDIM_X,
         typename FLOATV_T> // either float or float4
__global__ __launch_bounds__(BDIM_X)
void s2_attn_bwd_generic_vec_k(int nchan, // no. of FLOATV_T elements along channel dim
                               int nlat_in, int nlon_in, int nlat_out, int nlon_out,
                               const FLOATV_T *__restrict__ kx, const FLOATV_T *__restrict__ vx,
                               const FLOATV_T *__restrict__ qy, const FLOATV_T *__restrict__ dy,
                               const int32_t *__restrict__ row_idx, const int64_t *__restrict__ row_off,
                               const int64_t *__restrict__ col_idx, const float *__restrict__ quad_weights,
                               FLOATV_T *__restrict__ dkx, FLOATV_T *__restrict__ dvx, FLOATV_T *__restrict__ dqy)
{
    extern __shared__ __align__(sizeof(float4)) float shext[];

    // for dqy
    FLOATV_T *sh_alpha_k__ = reinterpret_cast<FLOATV_T *>(shext) + threadIdx.y * nchan * 5;
    FLOATV_T *sh_alpha_vw_ = sh_alpha_k__ + nchan;
    FLOATV_T *sh_alpha_kvw = sh_alpha_vw_ + nchan;
    FLOATV_T *sh_dy = sh_alpha_kvw + nchan;
    FLOATV_T *sh_qy = sh_dy + nchan;

    const int batch = blockIdx.y;
    const uint64_t wid = uint64_t(blockIdx.x) * blockDim.y + threadIdx.y;

    if (wid >= uint64_t(nlat_out) * nlon_in) { return; }

    const int tidx = threadIdx.x;

    // use permuted rows
    const int h = wid / nlon_out;
    const int wo = wid - (h * nlon_out);
    const int ho = row_idx[h];

    // offset input tensors
    kx += int64_t(batch) * nlat_in * nlon_in * nchan;
    vx += int64_t(batch) * nlat_in * nlon_in * nchan;
    qy += int64_t(batch) * nlat_out * nlon_out * nchan + int64_t(ho) * nlon_out * nchan + int64_t(wo) * nchan;
    dy += int64_t(batch) * nlat_out * nlon_out * nchan + int64_t(ho) * nlon_out * nchan + int64_t(wo) * nchan;

    // offset output tensors
    dkx += int64_t(batch) * nlat_in * nlon_in * nchan;
    dvx += int64_t(batch) * nlat_in * nlon_in * nchan;
    dqy += int64_t(batch) * nlat_out * nlon_out * nchan + int64_t(ho) * nlon_out * nchan + int64_t(wo) * nchan;

    // zero/init shared memory
    for (int chan = tidx; chan < nchan; chan += WARP_SIZE) {
        sh_alpha_k__[chan] = __vset<FLOATV_T>(0.0f);
        sh_alpha_vw_[chan] = __vset<FLOATV_T>(0.0f);
        sh_alpha_kvw[chan] = __vset<FLOATV_T>(0.0f);
        sh_dy[chan] = dy[chan];
        sh_qy[chan] = qy[chan];
    }

#if __CUDA_ARCH__ < 900
    // for architectures < 9.0, sh_dy and sh_qy will be read
    // as individual floats at the end of the kernel, which
    // breaks the assumption that each FLOATV_T location is
    // written to and read by the same thread throughout the
    // kernel, in the case FLOATV_T==float4
    if constexpr (std::is_same<FLOATV_T, float4>::value) { __syncwarp(); }
#endif

    // for dkx, dvx, dqy
    float alpha_sum = 0.0f;
    float qdotk_max = -FLT_MAX;

    // for dkx
    float integral = 0.0f;

    const int64_t rbeg = row_off[ho];
    const int64_t rend = row_off[ho + 1];

    col_idx += rbeg;

    const int rlen = rend - rbeg;

    // accumulate alpha_sum, integral, and shared stats,
    // along with a progressively computed qdotk_max.
    for (int off = 0; off < rlen; off++) {

        const int64_t col = col_idx[off];

        const int hi = col / nlon_in;
        const int wi = col - (hi * nlon_in);
        const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in;

        const FLOATV_T *_kx = kx + int64_t(hi) * nlon_in * nchan + int64_t(wip) * nchan;
        const FLOATV_T *_vx = vx + int64_t(hi) * nlon_in * nchan + int64_t(wip) * nchan;

        FLOATV_T qdotk_v = __vset<FLOATV_T>(0.0f);
        FLOATV_T gdotv_v = __vset<FLOATV_T>(0.0f);

        for (int chan = tidx; chan < nchan; chan += WARP_SIZE) {
            qdotk_v = __vadd(qdotk_v, __vmul(sh_qy[chan], _kx[chan]));
            gdotv_v = __vadd(gdotv_v, __vmul(sh_dy[chan], _vx[chan]));
        }

        const float qdotk = __warp_sum(__vred(qdotk_v));
        const float gdotv = __warp_sum(__vred(gdotv_v));

        const float qdotk_max_tmp = max(qdotk_max, qdotk);

        const float alpha_inz = expf(qdotk - qdotk_max_tmp) * quad_weights[hi];
        const float max_correction = expf(qdotk_max - qdotk_max_tmp);

        alpha_sum = alpha_sum * max_correction + alpha_inz;
        integral = integral * max_correction + alpha_inz * gdotv;

        const float ainz_gdotv = alpha_inz * gdotv;

        for (int chan = tidx; chan < nchan; chan += WARP_SIZE) {
            const FLOATV_T kxval = _kx[chan];
            sh_alpha_k__[chan] = __vadd(__vscale(max_correction, sh_alpha_k__[chan]), __vscale(alpha_inz, kxval));
            sh_alpha_vw_[chan] = __vadd(__vscale(max_correction, sh_alpha_vw_[chan]), __vset<FLOATV_T>(ainz_gdotv));
            sh_alpha_kvw[chan] = __vadd(__vscale(max_correction, sh_alpha_kvw[chan]), __vscale(ainz_gdotv, kxval));
        }

        qdotk_max = qdotk_max_tmp;
    }

    const float alpha_sum_inv = 1.0f / alpha_sum;

    integral *= alpha_sum_inv;

    // Write dqy
    for (int chan = tidx; chan < nchan; chan += WARP_SIZE) {
        dqy[chan] = __vscale(alpha_sum_inv * alpha_sum_inv,
                             __vsub(__vscale(alpha_sum, sh_alpha_kvw[chan]),
                                    __vmul(sh_alpha_vw_[chan], sh_alpha_k__[chan])));
    }

    // accumulate gradients for k and v
    for (int off = 0; off < rlen; off++) {

        const int64_t col = col_idx[off];

        const int hi = col / nlon_in;
        const int wi = col - (hi * nlon_in);
        const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in;

        const FLOATV_T *_kx = kx + int64_t(hi) * nlon_in * nchan + int64_t(wip) * nchan;
        const FLOATV_T *_vx = vx + int64_t(hi) * nlon_in * nchan + int64_t(wip) * nchan;

        FLOATV_T qdotk_v = __vset<FLOATV_T>(0.0f);
        FLOATV_T gdotv_v = __vset<FLOATV_T>(0.0f);

        for (int chan = tidx; chan < nchan; chan += WARP_SIZE) {
            qdotk_v = __vadd(qdotk_v, __vmul(sh_qy[chan], _kx[chan]));
            gdotv_v = __vadd(gdotv_v, __vmul(sh_dy[chan], _vx[chan]));
        }

        const float qdotk = __warp_sum(__vred(qdotk_v));
        const float gdotv = __warp_sum(__vred(gdotv_v));

        const float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];

        FLOATV_T *_dkx = dkx + int64_t(hi) * nlon_in * nchan + int64_t(wip) * nchan;
        FLOATV_T *_dvx = dvx + int64_t(hi) * nlon_in * nchan + int64_t(wip) * nchan;

        const float alpha_mul = alpha_inz * alpha_sum_inv;
        const float scale_fact_qy = (gdotv - integral) * alpha_mul;
        const float scale_fact_dy = alpha_mul;

        // float4, 128-bit atomics are only supported by devices of compute
        // capability 9.x+, so on older devices we resort to 32-bit atomics
#if __CUDA_ARCH__ < 900
        // to use 32-bit operations on consecutive addresses
        float *sh_qy_scl = reinterpret_cast<float *>(sh_qy);
        float *sh_dy_scl = reinterpret_cast<float *>(sh_dy);
        float *_dkx_scl = reinterpret_cast<float *>(_dkx);
        float *_dvx_scl = reinterpret_cast<float *>(_dvx);

        constexpr int VEC_SIZE = sizeof(FLOATV_T) / sizeof(float);

        // 32-bit, consecutive atomics to glmem;
        // strided atomics results in a severe slowdown
        for (int chan = tidx; chan < nchan * VEC_SIZE; chan += WARP_SIZE) {
            atomicAdd(_dkx_scl + chan, scale_fact_qy * sh_qy_scl[chan]);
            atomicAdd(_dvx_scl + chan, scale_fact_dy * sh_dy_scl[chan]);
        }
#else
        // 128-bit, consecutive atomics to glmem
        for (int chan = tidx; chan < nchan; chan += WARP_SIZE) {
            atomicAdd(_dkx + chan, __vscale(scale_fact_qy, sh_qy[chan]));
            atomicAdd(_dvx + chan, __vscale(scale_fact_dy, sh_dy[chan]));
        }
#endif
    }
    return;
}
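The two passes above use the same running-max ("online softmax") update as the forward kernel: each new neighbor contribution is scaled by exp(qdotk - max), and previously accumulated sums are rescaled by exp(old_max - new_max) so the final result equals the numerically stable softmax-weighted sum. A minimal scalar sketch of that recurrence (not from this commit; the score and weight values are made up):

#include <cmath>
#include <cfloat>
#include <cstdio>

int main() {
    const float qdotk[4] = {1.0f, 3.0f, 2.0f, -1.0f};   // hypothetical q.k scores
    const float w[4]     = {0.2f, 0.3f, 0.3f, 0.2f};    // hypothetical quadrature weights

    float alpha_sum = 0.0f, qdotk_max = -FLT_MAX;
    for (int i = 0; i < 4; i++) {
        const float m_new = fmaxf(qdotk_max, qdotk[i]);
        const float alpha = expf(qdotk[i] - m_new) * w[i];   // new term, referenced to m_new
        const float corr  = expf(qdotk_max - m_new);         // rescale the old partial sum
        alpha_sum = alpha_sum * corr + alpha;
        qdotk_max = m_new;
    }
    // alpha_sum now equals sum_i w_i * exp(qdotk_i - max_j qdotk_j), computed in one pass
    printf("%f\n", alpha_sum);
    return 0;
}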
// called with either (BDIM_X=32 and BDIM_Y>1) || (2^K=BDIM_X > 32 and BDIM_Y=1)
template<int BDIM_X, int BDIM_Y, int NLOC,
         typename FLOATV_T> // either float or float4
__global__ __launch_bounds__(BDIM_X * BDIM_Y)
void s2_attn_bwd_special_vec_k(int nchan, // no. of FLOATV_T elements along channel dim
                               int nlat_in, int nlon_in, int nlat_out, int nlon_out,
                               const FLOATV_T *__restrict__ kx, const FLOATV_T *__restrict__ vx,
                               const FLOATV_T *__restrict__ qy, const FLOATV_T *__restrict__ dy,
                               const int32_t *__restrict__ row_idx, const int64_t *__restrict__ row_off,
                               const int64_t *__restrict__ col_idx, const float *__restrict__ quad_weights,
                               FLOATV_T *__restrict__ dkx, FLOATV_T *__restrict__ dvx, FLOATV_T *__restrict__ dqy)
{
    static_assert(0 == (BDIM_X & (BDIM_X - 1)));
    static_assert(0 == (BDIM_Y & (BDIM_Y - 1)));
    static_assert((BDIM_X == 32 && BDIM_Y > 1) || (BDIM_X > 32 && BDIM_Y == 1));

    constexpr int NLOC_M1 = NLOC - 1;

    const int tidx = threadIdx.x;
    const int batch = blockIdx.y;

    const uint64_t ctaid = uint64_t(blockIdx.x) * blockDim.y + threadIdx.y;

    if (ctaid >= uint64_t(nlat_out) * nlon_in) { return; }

    extern __shared__ __align__(sizeof(float4)) float shext[];
    FLOATV_T *sh_dy = reinterpret_cast<FLOATV_T *>(shext) + threadIdx.y * nchan * 2 + tidx;
    FLOATV_T *sh_qy = sh_dy + nchan;

    // for dqy
    FLOATV_T loc_k__[NLOC];
    FLOATV_T loc_vw_[NLOC];
    FLOATV_T loc_kvw[NLOC];

    // use permuted rows
    const int h = ctaid / nlon_out;
    const int wo = ctaid - (h * nlon_out);
    const int ho = row_idx[h];

    // offset input tensors
    kx += int64_t(batch) * nlat_in * nlon_in * nchan + tidx;
    vx += int64_t(batch) * nlat_in * nlon_in * nchan + tidx;
    qy += int64_t(batch) * nlat_out * nlon_out * nchan + int64_t(ho) * nlon_out * nchan + int64_t(wo) * nchan + tidx;
    dy += int64_t(batch) * nlat_out * nlon_out * nchan + int64_t(ho) * nlon_out * nchan + int64_t(wo) * nchan + tidx;

    // offset output tensors
    dkx += int64_t(batch) * nlat_in * nlon_in * nchan + tidx;
    dvx += int64_t(batch) * nlat_in * nlon_in * nchan + tidx;
    dqy += int64_t(batch) * nlat_out * nlon_out * nchan + int64_t(ho) * nlon_out * nchan + int64_t(wo) * nchan + tidx;

#pragma unroll
    for (int i = 0; i < NLOC; i++) {
        loc_k__[i] = __vset<FLOATV_T>(0.0f);
        loc_vw_[i] = __vset<FLOATV_T>(0.0f);
        loc_kvw[i] = __vset<FLOATV_T>(0.0f);
    }

#pragma unroll
    for (int i = 0; i < NLOC_M1; i++) {
        sh_dy[i * BDIM_X] = dy[i * BDIM_X];
        sh_qy[i * BDIM_X] = qy[i * BDIM_X];
    }
    if (NLOC_M1 * BDIM_X + tidx < nchan) {
        sh_dy[NLOC_M1 * BDIM_X] = dy[NLOC_M1 * BDIM_X];
        sh_qy[NLOC_M1 * BDIM_X] = qy[NLOC_M1 * BDIM_X];
    }

#if __CUDA_ARCH__ < 900
    // for architectures < 9.0, sh_dy and sh_qy will be read
    // as individual floats at the end of the kernel, which
    // breaks the assumption that each FLOATV_T location is
    // written to and read by the same thread throughout the
    // kernel, in the case FLOATV_T==float4
    if constexpr (std::is_same<FLOATV_T, float4>::value) {
        if constexpr (BDIM_X == 32) { __syncwarp();    }
        else                        { __syncthreads(); }
    }
#endif

    // for dkx, dvx, dqy
    float alpha_sum = 0.0f;
    float qdotk_max = -FLT_MAX;

    // for dkx
    float integral = 0.0f;

    const int64_t rbeg = row_off[ho];
    const int64_t rend = row_off[ho + 1];

    col_idx += rbeg;

    const int rlen = rend - rbeg;

    // accumulate alpha_sum, integral, and shared stats,
    // along with a progressively computed qdotk_max.
    for (int off = 0; off < rlen; off++) {

        const int64_t col = col_idx[off];

        const int hi = col / nlon_in;
        const int wi = col - (hi * nlon_in);
        const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in;

        const FLOATV_T *_kx = kx + int64_t(hi) * nlon_in * nchan + int64_t(wip) * nchan;
        const FLOATV_T *_vx = vx + int64_t(hi) * nlon_in * nchan + int64_t(wip) * nchan;

        FLOATV_T qdotk_v = __vset<FLOATV_T>(0.0f);
        FLOATV_T gdotv_v = __vset<FLOATV_T>(0.0f);

#pragma unroll
        for (int i = 0; i < NLOC_M1; i++) {
            qdotk_v = __vadd(qdotk_v, __vmul(sh_qy[i * BDIM_X], _kx[i * BDIM_X]));
            gdotv_v = __vadd(gdotv_v, __vmul(sh_dy[i * BDIM_X], _vx[i * BDIM_X]));
        }
        if (NLOC_M1 * BDIM_X + tidx < nchan) {
            qdotk_v = __vadd(qdotk_v, __vmul(sh_qy[NLOC_M1 * BDIM_X], _kx[NLOC_M1 * BDIM_X]));
            gdotv_v = __vadd(gdotv_v, __vmul(sh_dy[NLOC_M1 * BDIM_X], _vx[NLOC_M1 * BDIM_X]));
        }

        float qdotk = __vred(qdotk_v);
        float gdotv = __vred(gdotv_v);

        if constexpr (BDIM_X == 32) {
            qdotk = __warp_sum(qdotk);
            gdotv = __warp_sum(gdotv);
        } else {
            qdotk = __block_sum<BDIM_X>(qdotk);
            gdotv = __block_sum<BDIM_X>(gdotv);
        }

        const float qdotk_max_tmp = max(qdotk_max, qdotk);
        const float alpha_inz = expf(qdotk - qdotk_max_tmp) * quad_weights[hi];
        const float max_correction = expf(qdotk_max - qdotk_max_tmp);

        alpha_sum = alpha_sum * max_correction + alpha_inz;
        integral = integral * max_correction + alpha_inz * gdotv;

        const float ainz_gdotv = alpha_inz * gdotv;

#pragma unroll
        for (int i = 0; i < NLOC_M1; i++) {
            const FLOATV_T kxval = _kx[i * BDIM_X];
            loc_k__[i] = __vadd(__vscale(max_correction, loc_k__[i]), __vscale(alpha_inz, kxval));
            loc_vw_[i] = __vadd(__vscale(max_correction, loc_vw_[i]), __vset<FLOATV_T>(ainz_gdotv));
            loc_kvw[i] = __vadd(__vscale(max_correction, loc_kvw[i]), __vscale(ainz_gdotv, kxval));
        }
        if (NLOC_M1 * BDIM_X + tidx < nchan) {
            const FLOATV_T kxval = _kx[NLOC_M1 * BDIM_X];
            loc_k__[NLOC_M1] = __vadd(__vscale(max_correction, loc_k__[NLOC_M1]), __vscale(alpha_inz, kxval));
            loc_vw_[NLOC_M1] = __vadd(__vscale(max_correction, loc_vw_[NLOC_M1]), __vset<FLOATV_T>(ainz_gdotv));
            loc_kvw[NLOC_M1] = __vadd(__vscale(max_correction, loc_kvw[NLOC_M1]), __vscale(ainz_gdotv, kxval));
        }

        qdotk_max = qdotk_max_tmp;
    }

    const float alpha_sum_inv = 1.0f / alpha_sum;

    integral *= alpha_sum_inv;

    // Write dqy
    const float alpha_sum_inv_sq = alpha_sum_inv * alpha_sum_inv;

#pragma unroll
    for (int i = 0; i < NLOC_M1; i++) {
        dqy[i * BDIM_X] = __vscale(alpha_sum_inv_sq, __vsub(__vscale(alpha_sum, loc_kvw[i]), __vmul(loc_vw_[i], loc_k__[i])));
    }
    if (NLOC_M1 * BDIM_X + tidx < nchan) {
        dqy[NLOC_M1 * BDIM_X] = __vscale(alpha_sum_inv_sq, __vsub(__vscale(alpha_sum, loc_kvw[NLOC_M1]), __vmul(loc_vw_[NLOC_M1], loc_k__[NLOC_M1])));
    }

    // accumulate gradients for k and v
    for (int off = 0; off < rlen; off++) {

        const int64_t col = col_idx[off];

        const int hi = col / nlon_in;
        const int wi = col - (hi * nlon_in);
        const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in;

        const FLOATV_T *_kx = kx + int64_t(hi) * nlon_in * nchan + int64_t(wip) * nchan;
        const FLOATV_T *_vx = vx + int64_t(hi) * nlon_in * nchan + int64_t(wip) * nchan;

        FLOATV_T qdotk_v = __vset<FLOATV_T>(0.0f);
        FLOATV_T gdotv_v = __vset<FLOATV_T>(0.0f);

#pragma unroll
        for (int i = 0; i < NLOC_M1; i++) {
            qdotk_v = __vadd(qdotk_v, __vmul(sh_qy[i * BDIM_X], _kx[i * BDIM_X]));
            gdotv_v = __vadd(gdotv_v, __vmul(sh_dy[i * BDIM_X], _vx[i * BDIM_X]));
        }
        if (NLOC_M1 * BDIM_X + tidx < nchan) {
            qdotk_v = __vadd(qdotk_v, __vmul(sh_qy[NLOC_M1 * BDIM_X], _kx[NLOC_M1 * BDIM_X]));
            gdotv_v = __vadd(gdotv_v, __vmul(sh_dy[NLOC_M1 * BDIM_X], _vx[NLOC_M1 * BDIM_X]));
        }

        float qdotk = __vred(qdotk_v);
        float gdotv = __vred(gdotv_v);

        if constexpr (BDIM_X == 32) {
            qdotk = __warp_sum(qdotk);
            gdotv = __warp_sum(gdotv);
        } else {
            qdotk = __block_sum<BDIM_X>(qdotk);
            gdotv = __block_sum<BDIM_X>(gdotv);
        }

        const float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];

        FLOATV_T *_dkx = dkx + int64_t(hi) * nlon_in * nchan + int64_t(wip) * nchan;
        FLOATV_T *_dvx = dvx + int64_t(hi) * nlon_in * nchan + int64_t(wip) * nchan;

        const float alpha_mul = alpha_inz * alpha_sum_inv;
        const float scale_fact_qy = (gdotv - integral) * alpha_mul;
        const float scale_fact_dy = alpha_mul;

        // float4, 128-bit atomics are only supported by devices of compute
        // capability 9.x+, so on older devices we resort to 32-bit atomics
#if __CUDA_ARCH__ < 900
        // making the loop count known at compile time doesn't seem
        // to make any difference here so let's keep this (much)
        // simpler version
        float *sh_qy_scl = reinterpret_cast<float *>(sh_qy - tidx);
        float *sh_dy_scl = reinterpret_cast<float *>(sh_dy - tidx);
        float *_dkx_scl = reinterpret_cast<float *>(_dkx - tidx);
        float *_dvx_scl = reinterpret_cast<float *>(_dvx - tidx);

        constexpr int VEC_SIZE = sizeof(FLOATV_T) / sizeof(float);

        // 32-bit, consecutive atomics to glmem
        // strided atomics results in a severe slowdown
        for (int chan = tidx; chan < nchan * VEC_SIZE; chan += BDIM_X) {
            atomicAdd(_dkx_scl + chan, scale_fact_qy * sh_qy_scl[chan]);
            atomicAdd(_dvx_scl + chan, scale_fact_dy * sh_dy_scl[chan]);
        }
#else
#pragma unroll
        for (int i = 0; i < NLOC_M1; i++) {
            atomicAdd(_dkx + i * BDIM_X, __vscale(scale_fact_qy, sh_qy[i * BDIM_X]));
            atomicAdd(_dvx + i * BDIM_X, __vscale(scale_fact_dy, sh_dy[i * BDIM_X]));
        }
        if (NLOC_M1 * BDIM_X + tidx < nchan) {
            atomicAdd(_dkx + NLOC_M1 * BDIM_X, __vscale(scale_fact_qy, sh_qy[NLOC_M1 * BDIM_X]));
            atomicAdd(_dvx + NLOC_M1 * BDIM_X, __vscale(scale_fact_dy, sh_dy[NLOC_M1 * BDIM_X]));
        }
#endif
    }
    return;
}
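Compared to the generic kernel, the specialized kernel keeps the per-output accumulators in registers: with NLOC = ceil(nchan / BDIM_X), thread tidx owns channels tidx, tidx + BDIM_X, and so on, and only the last tile is bounds-checked. A standalone sketch of that indexing pattern (illustrative only; the helper below is not part of this commit):

// Minimal sketch of the register-tiling layout used by the "special" kernels.
template<int BDIM_X, int NLOC>
__device__ void scale_channels(int nchan, int tidx, const float *in, float *out, float s)
{
    float loc[NLOC];                                 // NLOC values per thread cover nchan channels
#pragma unroll
    for (int i = 0; i < NLOC - 1; i++) {             // full tiles are always in range
        loc[i] = in[i * BDIM_X + tidx];
    }
    if ((NLOC - 1) * BDIM_X + tidx < nchan) {        // the last tile may be partial
        loc[NLOC - 1] = in[(NLOC - 1) * BDIM_X + tidx];
    }
#pragma unroll
    for (int i = 0; i < NLOC - 1; i++) {
        out[i * BDIM_X + tidx] = s * loc[i];
    }
    if ((NLOC - 1) * BDIM_X + tidx < nchan) {
        out[(NLOC - 1) * BDIM_X + tidx] = s * loc[NLOC - 1];
    }
}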
template<typename FLOATV_T>
void launch_gen_attn_bwd(int batch_size, int nchans, int nlat_in, int nlon_in, int nlat_out, int nlon_out,
                         FLOATV_T *_kxp, FLOATV_T *_vxp, FLOATV_T *_qyp, FLOATV_T *_dyp,
                         int32_t *_row_idx, int64_t *_row_off, int64_t *_col_idx, float *_quad_weights,
                         FLOATV_T *_dkxp, FLOATV_T *_dvxp, FLOATV_T *_dqyp,
                         cudaStream_t stream)
{
    dim3 block(WARP_SIZE, THREADS / WARP_SIZE);
    dim3 grid(DIV_UP(nlat_out * nlon_out, block.y), batch_size);

    size_t shsize = sizeof(FLOATV_T) * nchans * 5 * block.y; // 5 arrays per warp

    s2_attn_bwd_generic_vec_k<THREADS>
        <<<grid, block, shsize, stream>>>(nchans, nlat_in, nlon_in, nlat_out, nlon_out,
                                          _kxp, _vxp, _qyp, _dyp,
                                          _row_idx, _row_off, _col_idx, _quad_weights,
                                          _dkxp, _dvxp, _dqyp);
    CHECK_ERROR("s2_attn_bwd_generic_vec_k");

    return;
}
template<int BDIM_X, int BDIM_Y, int CUR_LOC_SIZE,
         int MAX_LOC_SIZE, // max size of FLOATV_T[] local array
         typename FLOATV_T>
void launch_spc_attn_bwd(int batch_size,
                         int nloc, // "BDIM_X*nloc" >= nchans
                         int nchans, int nlat_in, int nlon_in, int nlat_out, int nlon_out,
                         FLOATV_T *_kxp, FLOATV_T *_vxp, FLOATV_T *_qyp, FLOATV_T *_dyp,
                         int32_t *_row_idx, int64_t *_row_off, int64_t *_col_idx, float *_quad_weights,
                         FLOATV_T *_dkxp, FLOATV_T *_dvxp, FLOATV_T *_dqyp,
                         cudaStream_t stream)
{
    if (CUR_LOC_SIZE == nloc) {

        dim3 block(BDIM_X, BDIM_Y);
        dim3 grid(DIV_UP(nlat_out * nlon_out, block.y), batch_size);

        size_t shsize = sizeof(FLOATV_T) * nchans * 2 * block.y; // 2 arrays per cta, block.y > 1 iff block.x==32

        s2_attn_bwd_special_vec_k<BDIM_X, BDIM_Y, CUR_LOC_SIZE>
            <<<grid, block, shsize, stream>>>(nchans, nlat_in, nlon_in, nlat_out, nlon_out,
                                              _kxp, _vxp, _qyp, _dyp,
                                              _row_idx, _row_off, _col_idx, _quad_weights,
                                              _dkxp, _dvxp, _dqyp);
        CHECK_ERROR("s2_attn_bwd_special_vec_k");

        return;
    }
    if constexpr (CUR_LOC_SIZE < MAX_LOC_SIZE) {
        launch_spc_attn_bwd<BDIM_X, BDIM_Y, CUR_LOC_SIZE + 1, MAX_LOC_SIZE>(batch_size, nloc, nchans,
                                                                            nlat_in, nlon_in, nlat_out, nlon_out,
                                                                            _kxp, _vxp, _qyp, _dyp,
                                                                            _row_idx, _row_off, _col_idx, _quad_weights,
                                                                            _dkxp, _dvxp, _dqyp, stream);
    }
    return;
}
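launch_spc_attn_bwd turns the runtime value nloc into the compile-time template parameter NLOC of the kernel by recursing on CUR_LOC_SIZE until it matches. A minimal standalone sketch of the same dispatch pattern (not from this commit):

#include <cstdio>

template<int CUR, int MAX>
void dispatch(int n) {
    if (CUR == n) {
        printf("instantiated with compile-time N=%d\n", CUR);
        return;
    }
    if constexpr (CUR < MAX) {
        dispatch<CUR + 1, MAX>(n);   // recursion depth is bounded by MAX at compile time
    }
}

int main() {
    dispatch<1, 16>(4);   // the runtime value 4 selects the CUR==4 instantiation
    return 0;
}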
static void s2_attn_bwd_dispatch(int batch_size, int nchans, int nlon_in, int nlat_out, int nlon_out,
                                 at::Tensor kxP, at::Tensor vxP, at::Tensor qyP, at::Tensor dyP,
                                 at::Tensor row_off, at::Tensor col_idx, at::Tensor quad_weights,
                                 at::Tensor dkxP, at::Tensor dvxP, at::Tensor dqyP)
{
    static_assert(0 == (MAX_LOCAL_ARR_LEN & (MAX_LOCAL_ARR_LEN - 1)));

    // get stream
    auto stream = at::cuda::getCurrentCUDAStream().stream();

    // sort row indices (ho-s) in descending order
    // based on (row_off[ho+1]-row_off[ho])
    at::Tensor row_idx = sortRows(nlat_out, row_off, stream);

    const int nlat_in = kxP.size(1);

    // smallest power of two "bdimx" (>=32) s.t. bdimx*MAX_LOCAL_ARR_LEN >= nchans
    int bdimx;
    bdimx = DIV_UP(nchans, MAX_LOCAL_ARR_LEN);
    bdimx = max(bdimx, WARP_SIZE);
    bdimx = next_pow2(bdimx);

    float *_kxp = reinterpret_cast<float *>(kxP.data_ptr());
    float *_vxp = reinterpret_cast<float *>(vxP.data_ptr());
    float *_qyp = reinterpret_cast<float *>(qyP.data_ptr());
    float *_dyp = reinterpret_cast<float *>(dyP.data_ptr());

    float *_dkxp = reinterpret_cast<float *>(dkxP.data_ptr());
    float *_dvxp = reinterpret_cast<float *>(dvxP.data_ptr());
    float *_dqyp = reinterpret_cast<float *>(dqyP.data_ptr());

    int32_t *_row_idx = reinterpret_cast<int32_t *>(row_idx.data_ptr());
    int64_t *_row_off = reinterpret_cast<int64_t *>(row_off.data_ptr());
    int64_t *_col_idx = reinterpret_cast<int64_t *>(col_idx.data_ptr());
    float *_quad_weights = reinterpret_cast<float *>(quad_weights.data_ptr());

    constexpr int VEC_SIZE = sizeof(float4) / sizeof(float);

    if (!is_aligned<sizeof(float4)>(_kxp)  ||
        !is_aligned<sizeof(float4)>(_vxp)  ||
        !is_aligned<sizeof(float4)>(_qyp)  ||
        !is_aligned<sizeof(float4)>(_dyp)  ||
        !is_aligned<sizeof(float4)>(_dkxp) ||
        !is_aligned<sizeof(float4)>(_dvxp) ||
        !is_aligned<sizeof(float4)>(_dqyp) ||
        (nchans % VEC_SIZE) != 0) {

        const int nloc = DIV_UP(nchans, bdimx);

        // to avoid the compilation of unused template instances;
        // we use a block size BDIM_X that is the smallest power of 2
        // such that BDIM_X*MAX_LOCAL_ARR_LEN >= nchans, so
        // BDIM_X > 32 are used only for:
        //
        // (BDIM_X-1)*MAX_LOCAL_ARR_LEN < nchans <= BDIM_X*MAX_LOCAL_ARR_LEN
        constexpr int MIN_LOC_ARR_LEN = MAX_LOCAL_ARR_LEN / 2 + 1;

        // use 2D blocks only if 32 threads are enough; w.r.t. forward,
        // we use the special kernel only up to BDIM_X=512 as with 1024
        // each thread cannot use more than 64 registers, resulting in
        // large amounts of register spills
        switch (bdimx) {
            case  32: launch_spc_attn_bwd< 32, 2, 1, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, _dyp, _row_idx, _row_off, _col_idx, _quad_weights, _dkxp, _dvxp, _dqyp, stream); break;
            case  64: launch_spc_attn_bwd< 64, 1, MIN_LOC_ARR_LEN, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, _dyp, _row_idx, _row_off, _col_idx, _quad_weights, _dkxp, _dvxp, _dqyp, stream); break;
            case 128: launch_spc_attn_bwd<128, 1, MIN_LOC_ARR_LEN, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, _dyp, _row_idx, _row_off, _col_idx, _quad_weights, _dkxp, _dvxp, _dqyp, stream); break;
            case 256: launch_spc_attn_bwd<256, 1, MIN_LOC_ARR_LEN, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, _dyp, _row_idx, _row_off, _col_idx, _quad_weights, _dkxp, _dvxp, _dqyp, stream); break;
            case 512: launch_spc_attn_bwd<512, 1, MIN_LOC_ARR_LEN, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, _dyp, _row_idx, _row_off, _col_idx, _quad_weights, _dkxp, _dvxp, _dqyp, stream); break;
            default:  launch_gen_attn_bwd(batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, _dyp, _row_idx, _row_off, _col_idx, _quad_weights, _dkxp, _dvxp, _dqyp, stream); break;
        }
    } else {
        float4 *_kxp4 = reinterpret_cast<float4 *>(kxP.data_ptr());
        float4 *_vxp4 = reinterpret_cast<float4 *>(vxP.data_ptr());
        float4 *_qyp4 = reinterpret_cast<float4 *>(qyP.data_ptr());
        float4 *_dyp4 = reinterpret_cast<float4 *>(dyP.data_ptr());

        float4 *_dkxp4 = reinterpret_cast<float4 *>(dkxP.data_ptr());
        float4 *_dvxp4 = reinterpret_cast<float4 *>(dvxP.data_ptr());
        float4 *_dqyp4 = reinterpret_cast<float4 *>(dqyP.data_ptr());

        nchans /= VEC_SIZE;

        const int nloc = DIV_UP(nchans, bdimx);

        constexpr int MAX_LOCAL_VEC_LEN = MAX_LOCAL_ARR_LEN / VEC_SIZE;
        constexpr int MIN_LOC_VEC_LEN = MAX_LOCAL_VEC_LEN / 2 + 1;

        // use 2D blocks only if 32 threads are enough
        switch (bdimx) {
            case  32: launch_spc_attn_bwd< 32, 2, 1, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, _dyp4, _row_idx, _row_off, _col_idx, _quad_weights, _dkxp4, _dvxp4, _dqyp4, stream); break;
            case  64: launch_spc_attn_bwd< 64, 1, MIN_LOC_VEC_LEN, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, _dyp4, _row_idx, _row_off, _col_idx, _quad_weights, _dkxp4, _dvxp4, _dqyp4, stream); break;
            case 128: launch_spc_attn_bwd<128, 1, MIN_LOC_VEC_LEN, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, _dyp4, _row_idx, _row_off, _col_idx, _quad_weights, _dkxp4, _dvxp4, _dqyp4, stream); break;
            case 256: launch_spc_attn_bwd<256, 1, MIN_LOC_VEC_LEN, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, _dyp4, _row_idx, _row_off, _col_idx, _quad_weights, _dkxp4, _dvxp4, _dqyp4, stream); break;
            case 512: launch_spc_attn_bwd<512, 1, MIN_LOC_VEC_LEN, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, _dyp4, _row_idx, _row_off, _col_idx, _quad_weights, _dkxp4, _dvxp4, _dqyp4, stream); break;
            default:  launch_gen_attn_bwd(batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, _dyp4, _row_idx, _row_off, _col_idx, _quad_weights, _dkxp4, _dvxp4, _dqyp4, stream); break;
        }
    }
    return;
}
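The dispatcher takes the float4 path only when every pointer is 16-byte aligned and the channel count is divisible by 4; otherwise it falls back to the scalar-float kernels. A small sketch of that gating test (illustrative; it mirrors the is_aligned helper defined in attention_fwd_cuda.cu, with made-up data):

#include <cstdint>
#include <cstdio>

template<unsigned int ALIGN>
static int is_aligned(const void *ptr) {
    static_assert(0 == (ALIGN & (ALIGN - 1)), "ALIGN must be a power of two");
    return 0 == (uintptr_t(ptr) & (ALIGN - 1));
}

int main() {
    alignas(16) float buf[8];                  // hypothetical channels-last row
    const int nchans = 8;
    const bool can_vectorize = is_aligned<16>(buf) && (nchans % 4 == 0);
    printf("float4 path: %s\n", can_vectorize ? "yes" : "no");
    return 0;
}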
// END backward kernels and functions

std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tensor kx, at::Tensor vx, at::Tensor qy,
                                                                          at::Tensor dy, at::Tensor quad_weights,
...
@@ -223,15 +892,16 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda(at::Te
                                                          int nlon_in, int nlat_out, int nlon_out)
{
-   CHECK_CUDA_TENSOR(kx);
-   CHECK_CUDA_TENSOR(vx);
-   CHECK_CUDA_TENSOR(qy);
+   CHECK_CUDA_INPUT_TENSOR(kx);
+   CHECK_CUDA_INPUT_TENSOR(vx);
+   CHECK_CUDA_INPUT_TENSOR(qy);
+   CHECK_CUDA_INPUT_TENSOR(dy);
    CHECK_CUDA_TENSOR(quad_weights);
    CHECK_CUDA_TENSOR(psi_col_idx);
    CHECK_CUDA_TENSOR(psi_row_off);
-   CHECK_CUDA_TENSOR(dy);

    auto stream = at::cuda::getCurrentCUDAStream().stream();

+   const size_t uo_num_channels = kx.size(1);
+   const int batch_size = kx.size(0);

    // extract dtype
    auto kx_type = kx.dtype();
...
@@ -239,84 +909,52 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda(at::Te
    auto qy_type = qy.dtype();
    auto dy_type = dy.dtype();

-   // extract memory format
-   auto kx_is_channels_last = kx.is_contiguous(at::MemoryFormat::ChannelsLast);
-   auto vx_is_channels_last = vx.is_contiguous(at::MemoryFormat::ChannelsLast);
-   auto qy_is_channels_last = qy.is_contiguous(at::MemoryFormat::ChannelsLast);
-   auto dy_is_channels_last = dy.is_contiguous(at::MemoryFormat::ChannelsLast);
-
-   // convert to channels-last
-   auto kxP = kx.to(torch::kFloat32).to(at::MemoryFormat::ChannelsLast);
-   auto vxP = vx.to(torch::kFloat32).to(at::MemoryFormat::ChannelsLast);
-   auto qyP = qy.to(torch::kFloat32).to(at::MemoryFormat::ChannelsLast);
-   auto dyP = dy.to(torch::kFloat32).to(at::MemoryFormat::ChannelsLast);
-
-   // create output arrays
-   auto dydk = torch::zeros_like(qyP);
-   auto dydv = torch::zeros_like(qyP);
-   auto dydq = torch::zeros_like(qyP);
-
-   size_t uo_num_channels = kx.size(1);
-   const int batch_size = kx.size(0);
-
-   dim3 block(WARP_SIZE, THREADS / WARP_SIZE);
-   dim3 grid(DIV_UP(nlat_out * nlon_out, block.y), batch_size);
-   size_t shared_size = sizeof(float) * uo_num_channels * 5 * block.y; // 4 arrays per warp
-
-   cudaEvent_t start, stop;
-   float milliseconds = 0;
-   CHECK_CUDA(cudaEventCreate(&start));
-   CHECK_CUDA(cudaEventCreate(&stop));
-   CHECK_CUDA(cudaEventRecord(start, stream));
-   s2_attention_bwd_dkvq_kernel<THREADS><<<grid, block, shared_size, stream>>>(
-       uo_num_channels, nlon_in, nlat_out, nlon_out,
-       kxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
-       vxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
-       qyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
-       dyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
-       dydk.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
-       dydv.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
-       dydq.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
-       psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
-       psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
-       quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>());
-   CHECK_CUDA(cudaEventRecord(stop, stream));
-   CHECK_CUDA(cudaEventSynchronize(stop));
-   CHECK_CUDA(cudaEventElapsedTime(&milliseconds, start, stop));
-   // [1, 256, 1, (721, 1440), (721, 1440), "equiangular", "equiangular", 1e-5, 1e-5],
-   // s2_attention_bwd_kernel_mbT execution time: 63.280128 ms
-   CHECK_CUDA(cudaEventDestroy(start));
-   CHECK_CUDA(cudaEventDestroy(stop));
-   C10_CUDA_KERNEL_LAUNCH_CHECK();
-
-   // Permute outputs back to memory layout given by input. if input had channels
-   // first, leave it in that layout, otherwise permute layout back to [batch,
-   // channel, ho, wo]
-   // convert back to original dtype
-   dydk = dydk.to(kx_type);
-   dydv = dydv.to(vx_type);
-   dydq = dydq.to(qy_type);
-
-   // permute back to original layout
-   if (!kx_is_channels_last) { dydk = dydk.to(kx_type).to(at::MemoryFormat::Contiguous); }
-   else                      { dydk = dydk.to(kx_type); }
-   if (!vx_is_channels_last) { dydv = dydv.to(vx_type).to(at::MemoryFormat::Contiguous); }
-   else                      { dydv = dydv.to(vx_type); }
-   if (!qy_is_channels_last) { dydq = dydq.to(qy_type).to(at::MemoryFormat::Contiguous); }
-   else                      { dydq = dydq.to(qy_type); }
-
-   return std::make_tuple(dydk, dydv, dydq);
+   torch::Tensor kxP = kx.to(torch::kFloat32);
+   torch::Tensor vxP = vx.to(torch::kFloat32);
+   torch::Tensor qyP = qy.to(torch::kFloat32);
+   torch::Tensor dyP = dy.to(torch::kFloat32);
+
+   // extract memory format: this is much safer than checking is_contiguous(at::MemoryFormat::ChannelsLast)
+   // the former fails for num_channels == 1
+   bool kx_is_channels_last = kxP.strides()[1] == 1;
+   bool vx_is_channels_last = vxP.strides()[1] == 1;
+   bool qy_is_channels_last = qyP.strides()[1] == 1;
+   bool dy_is_channels_last = dyP.strides()[1] == 1;
+
+   // transpose if required
+   if (!kx_is_channels_last) { kxP = permute_4D_to0231(kxP); }
+   if (!vx_is_channels_last) { vxP = permute_4D_to0231(vxP); }
+   if (!qy_is_channels_last) { qyP = permute_4D_to0231(qyP); }
+   if (!dy_is_channels_last) { dyP = permute_4D_to0231(dyP); }
+
+   torch::Tensor dkxP = torch::zeros_like(kxP);
+   torch::Tensor dvxP = torch::zeros_like(vxP);
+   torch::Tensor dqyP = torch::zeros_like(qyP);
+
+   s2_attn_bwd_dispatch(batch_size, uo_num_channels, nlon_in, nlat_out, nlon_out,
+                        kxP, vxP, qyP, dyP,
+                        psi_row_off, psi_col_idx, quad_weights,
+                        dkxP, dvxP, dqyP);
+
+   torch::Tensor dkx = dkxP;
+   torch::Tensor dvx = dvxP;
+   torch::Tensor dqy = dqyP;
+
+   if (!kx_is_channels_last) { dkx = permute_4D_to0312(dkx); }
+   if (!vx_is_channels_last) { dvx = permute_4D_to0312(dvx); }
+   if (!qy_is_channels_last) { dqy = permute_4D_to0312(dqy); }
+
+   // convert precision back to starting
+   dkx = dkx.to(kx_type);
+   dvx = dvx.to(vx_type);
+   dqy = dqy.to(qy_type);
+
+   return std::make_tuple(dkx, dvx, dqy);
+   // #endif
}
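The rewritten wrapper infers a channels-last layout from strides()[1] == 1 rather than from is_contiguous(at::MemoryFormat::ChannelsLast), which, per the comment above, fails when num_channels == 1. A minimal libtorch sketch of that probe (illustrative, not part of the commit):

#include <torch/torch.h>
#include <iostream>

int main() {
    auto x  = torch::randn({2, 4, 8, 16});                     // NCHW, contiguous storage
    auto xc = x.contiguous(at::MemoryFormat::ChannelsLast);    // NHWC storage, same shape

    // a channels-last NCHW tensor has stride 1 along the channel dimension
    std::cout << "x  channels_last? " << (x.strides()[1] == 1) << std::endl;   // 0
    std::cout << "xc channels_last? " << (xc.strides()[1] == 1) << std::endl;  // 1
    return 0;
}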
torch_harmonics/csrc/attention/attention_fwd_cuda.cu

@@ -39,147 +39,20 @@
#include <cub/cub.cuh>
#include <limits>

#define WARP_SIZE (32)
#define FULL_MASK (0xFFFFFFFF)
#define THREADS (64)
#define DIV_UP(a,b) (((a)+((b)-1))/(b))

#include "cudamacro.h"
#include "attention_utils.cuh"

#define TRANSP_WARPS_X_TILE_GENERIC (32)
#define TRANSP_WARPS_X_TILE_SM100 (4)

#define THREADS (64)
#define MAX_LOCAL_ARR_LEN (16)

#define CHECK_CUDA(call) {                                            \
    cudaError_t err = call;                                           \
    if( cudaSuccess != err) {                                         \
        fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\n", \
                __FILE__, __LINE__, cudaGetErrorString( err) );       \
        exit(EXIT_FAILURE);                                           \
    }}

#define CHECK_ERROR(errorMessage) {                                        \
    cudaError_t err = cudaGetLastError();                                  \
    if( cudaSuccess != err) {                                              \
        fprintf(stderr, "Cuda error: %s in file '%s' in line %i : %s.\n",  \
                errorMessage, __FILE__, __LINE__, cudaGetErrorString( err) ); \
        exit(EXIT_FAILURE);                                                \
    }}

// BEGIN - forward kernels and functions

template<typename VAL_T>
__device__ VAL_T __warp_sum(VAL_T val)
{
#pragma unroll
    for (int i = WARP_SIZE / 2; i; i /= 2) { val += __shfl_xor_sync(FULL_MASK, val, i); }
    return val;
}

// called with (blockDim.x=32 and blockDim.y>1, BDIM_X=blockDim.x*blockDim.y)
template<int BDIM_X, int BDIM_Y = 1, int BDIM_Z = 1, typename VAL_T>
__device__ VAL_T __block_sum(VAL_T val)
{
    const int NWARP = (BDIM_X * BDIM_Y * BDIM_Z) / WARP_SIZE;

    val = __warp_sum(val);

    if constexpr (NWARP > 1) {

        int tid = threadIdx.x;
        if constexpr (BDIM_Y > 1) { tid += threadIdx.y * BDIM_X; }
        if constexpr (BDIM_Z > 1) { tid += threadIdx.z * BDIM_X * BDIM_Y; }

        const int lid = tid % WARP_SIZE;
        const int wid = tid / WARP_SIZE;

        __shared__ VAL_T sh[NWARP];

        if (lid == 0) { sh[wid] = val; }
        __syncthreads();

        if (wid == 0) {
            val = (lid < NWARP) ? sh[lid] : 0;
            val = __warp_sum(val);
            __syncwarp();
            if (!lid) { sh[0] = val; }
        }
        __syncthreads();

        val = sh[0];
        __syncthreads();
    }
    return val;
}

template<typename FLOATV_T> __device__ FLOATV_T __vset(float x) {}

template<> __device__ float __forceinline__ __vset<float>(float x) { return x; }
__device__ float __forceinline__ __vmul(float a, float b) { return a * b; }
__device__ float __forceinline__ __vadd(float a, float b) { return a + b; }
__device__ float __forceinline__ __vred(float a) { return a; }
__device__ float __forceinline__ __vscale(float s, float v) { return v * s; }
__device__ float __forceinline__ __vdiv(float s, float v) { return v / s; }

template<> __device__ float4 __forceinline__ __vset<float4>(float x) { return make_float4(x, x, x, x); }
__device__ float4 __forceinline__ __vmul(float4 a, float4 b) { return make_float4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w); }
__device__ float4 __forceinline__ __vadd(float4 a, float4 b) { return make_float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w); }
__device__ float __forceinline__ __vred(float4 a) { return a.x + a.y + a.z + a.w; }
__device__ float4 __forceinline__ __vscale(float s, float4 v) { return make_float4(s * v.x, s * v.y, s * v.z, s * v.w); }
__device__ float4 __forceinline__ __vdiv(float s, float4 v) { return make_float4(s / v.x, s / v.y, s / v.z, s / v.w); }
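__warp_sum (now provided by attention_utils.cuh) is the standard butterfly reduction over one warp via __shfl_xor_sync. A self-contained sketch of that reduction (not from this commit; the kernel and buffer names are mine):

#include <cuda_runtime.h>
#include <cstdio>

__global__ void warp_sum_demo(const float *in, float *out)
{
    float val = in[threadIdx.x];            // one value per lane of a single warp
#pragma unroll
    for (int i = 16; i; i /= 2) {           // 16, 8, 4, 2, 1: butterfly exchange over 32 lanes
        val += __shfl_xor_sync(0xFFFFFFFF, val, i);
    }
    if (threadIdx.x == 0) { *out = val; }   // every lane now holds the full sum
}

int main()
{
    float h_in[32], *d_in, *d_out, h_out = 0.0f;
    for (int i = 0; i < 32; i++) h_in[i] = 1.0f;
    cudaMalloc(&d_in, 32 * sizeof(float));
    cudaMalloc(&d_out, sizeof(float));
    cudaMemcpy(d_in, h_in, 32 * sizeof(float), cudaMemcpyHostToDevice);
    warp_sum_demo<<<1, 32>>>(d_in, d_out);
    cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
    printf("sum = %f\n", h_out);            // expected: 32.0
    cudaFree(d_in); cudaFree(d_out);
    return 0;
}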
// called with (blockDim.x=32 and blockDim.y>1, BDIM=blockDim.x*blockDim.y)
template<int BDIM,
         typename FLOATV_T> // either float or float4
-__global__ __launch_bounds__(BDIM)
+__global__ __launch_bounds__(BDIM_X)
void s2_attn_fwd_generic_vec_k(int nchan, // no. of FLOATV_T elements along channel dim
                               int nlat_in,
                               int nlon_in,
...
@@ -188,10 +61,10 @@ void s2_attn_fwd_generic_vec_k(int nchan, // no. of FLOATV_T elements along cha
                               const FLOATV_T *__restrict__ kx,
                               const FLOATV_T *__restrict__ vx,
                               const FLOATV_T *__restrict__ qy,
-                              const torch::PackedTensorAccessor32<int, 1, torch::RestrictPtrTraits> row_idx,
-                              const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> row_off,
-                              const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> col_idx,
-                              const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights,
+                              const int32_t *__restrict__ row_idx,
+                              const int64_t *__restrict__ row_off,
+                              const int64_t *__restrict__ col_idx,
+                              const float *__restrict__ quad_weights,
                               FLOATV_T *__restrict__ y)
{
    extern __shared__ __align__(sizeof(float4)) float shext[];
...
@@ -225,11 +98,13 @@ void s2_attn_fwd_generic_vec_k(int nchan, // no. of FLOATV_T elements along cha
    const int64_t rbeg = row_off[ho];
    const int64_t rend = row_off[ho + 1];

+   col_idx += rbeg;

    const int rlen = rend - rbeg;

    for (int off = 0; off < rlen; off++) {

-       const int64_t col = col_idx[rbeg + off];
+       const int64_t col = col_idx[off];

        const int hi = col / nlon_in;
        const int wi = col - (hi * nlon_in);
...
@@ -273,39 +148,6 @@ void s2_attn_fwd_generic_vec_k(int nchan, // no. of FLOATV_T elements along cha
    return;
}

template<typename FLOATV_T>
void launch_gen_attn_kernel(int batch_size, int nchans, int nlat_in, int nlon_in, int nlat_out, int nlon_out,
                            FLOATV_T *__restrict__ _kxp, FLOATV_T *__restrict__ _vxp, FLOATV_T *__restrict__ _qyp,
                            at::Tensor row_idx, at::Tensor row_off, at::Tensor col_idx, at::Tensor quad_weights,
                            FLOATV_T *__restrict__ _yp,
                            cudaStream_t stream)
{
    dim3 block(WARP_SIZE, THREADS / WARP_SIZE);
    dim3 grid(DIV_UP(nlat_out * nlon_out, block.y), batch_size);

    size_t shsize = sizeof(FLOATV_T) * nchans * block.y;

    auto _row_idx = row_idx.packed_accessor32<int, 1, torch::RestrictPtrTraits>();
    auto _row_off = row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>();
    auto _col_idx = col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>();
    auto _quad_weights = quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>();

    s2_attn_fwd_generic_vec_k<THREADS>
        <<<grid, block, shsize, stream>>>(nchans, nlat_in, nlon_in, nlat_out, nlon_out,
                                          _kxp, _vxp, _qyp,
                                          _row_idx, _row_off, _col_idx, _quad_weights, _yp);
    return;
}
// called with either (BDIM_X=32 and BDIM_Y>1) || (2^K=BDIM_X > 32 and BDIM_Y=1)
template<int BDIM_X,
         int BDIM_Y,
...
@@ -321,10 +163,10 @@ void s2_attn_fwd_special_vec_k(int nchan, // no. of FLOATV_T elements along chan
                               const FLOATV_T *__restrict__ kx,
                               const FLOATV_T *__restrict__ vx,
                               const FLOATV_T *__restrict__ qy,
-                              const torch::PackedTensorAccessor32<int, 1, torch::RestrictPtrTraits> row_idx,
-                              const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> row_off,
-                              const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> col_idx,
-                              const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights,
+                              const int32_t *__restrict__ row_idx,
+                              const int64_t *__restrict__ row_off,
+                              const int64_t *__restrict__ col_idx,
+                              const float *__restrict__ quad_weights,
                               FLOATV_T *__restrict__ y)
{
    static_assert(0 == (BDIM_X & (BDIM_X - 1)));
...
@@ -375,11 +217,13 @@ void s2_attn_fwd_special_vec_k(int nchan, // no. of FLOATV_T elements along chan
    const int64_t rbeg = row_off[ho];
    const int64_t rend = row_off[ho + 1];

+   col_idx += rbeg;

    const int rlen = rend - rbeg;

    for (int off = 0; off < rlen; off++) {

-       const int64_t col = col_idx[rbeg + off];
+       const int64_t col = col_idx[off];

        const int hi = col / nlon_in;
        const int wi = col - (hi * nlon_in);
...
@@ -442,139 +286,84 @@ void s2_attn_fwd_special_vec_k(int nchan, // no. of FLOATV_T elements along chan
    return;
}
template<typename FLOATV_T>
void launch_gen_attn_fwd(int batch_size, int nchans, int nlat_in, int nlon_in, int nlat_out, int nlon_out,
                         FLOATV_T *__restrict__ _kxp, FLOATV_T *__restrict__ _vxp, FLOATV_T *__restrict__ _qyp,
                         int32_t *_row_idx, int64_t *_row_off, int64_t *_col_idx, float *_quad_weights,
                         FLOATV_T *__restrict__ _yp,
                         cudaStream_t stream)
{
    dim3 block(WARP_SIZE, THREADS / WARP_SIZE);
    dim3 grid(DIV_UP(nlat_out * nlon_out, block.y), batch_size);

    size_t shsize = sizeof(FLOATV_T) * nchans * block.y;

    s2_attn_fwd_generic_vec_k<THREADS>
        <<<grid, block, shsize, stream>>>(nchans, nlat_in, nlon_in, nlat_out, nlon_out,
                                          _kxp, _vxp, _qyp,
                                          _row_idx, _row_off, _col_idx, _quad_weights, _yp);
    CHECK_ERROR("s2_attn_fwd_generic_vec_k");

    return;
}
template<int BDIM_X, int BDIM_Y, int CUR_LOC_SIZE,
         int MAX_LOC_SIZE, // max size of FLOATV_T[] local array
         typename FLOATV_T>
-void launch_spc_attn_kernel(int batch_size,
-                            int nloc, // "BDIM_X*nloc" >= nchans
-                            int nchans, int nlat_in, int nlon_in, int nlat_out, int nlon_out,
-                            FLOATV_T *__restrict__ _kxp, FLOATV_T *__restrict__ _vxp, FLOATV_T *__restrict__ _qyp,
-                            at::Tensor row_idx, at::Tensor row_off, at::Tensor col_idx, at::Tensor quad_weights,
-                            FLOATV_T *__restrict__ _yp,
-                            cudaStream_t stream)
-{
+void launch_spc_attn_fwd(int batch_size,
+                         int nloc, // "BDIM_X*nloc" >= nchans
+                         int nchans, int nlat_in, int nlon_in, int nlat_out, int nlon_out,
+                         FLOATV_T *__restrict__ _kxp, FLOATV_T *__restrict__ _vxp, FLOATV_T *__restrict__ _qyp,
+                         int32_t *_row_idx, int64_t *_row_off, int64_t *_col_idx, float *_quad_weights,
+                         FLOATV_T *__restrict__ _yp,
+                         cudaStream_t stream)
+{
    if (CUR_LOC_SIZE == nloc) {

-       auto _row_idx = row_idx.packed_accessor32<int, 1, torch::RestrictPtrTraits>();
-       auto _row_off = row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>();
-       auto _col_idx = col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>();
-       auto _quad_weights = quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>();

        dim3 block(BDIM_X, BDIM_Y);
        dim3 grid(DIV_UP(nlat_out * nlon_out, block.y), batch_size);

        //printf("block: (%d, %d)\n", block.x, block.y);
        //printf("grid: (%d, %d)\n", grid.x, grid.y);

        size_t shsize = sizeof(FLOATV_T) * nchans * block.y; // block.y > 1 iff block.x==32

        s2_attn_fwd_special_vec_k<BDIM_X, BDIM_Y, CUR_LOC_SIZE>
            <<<grid, block, shsize, stream>>>(nchans, nlat_in, nlon_in, nlat_out, nlon_out,
                                              _kxp, _vxp, _qyp,
                                              _row_idx, _row_off, _col_idx, _quad_weights, _yp);
        CHECK_ERROR("s2_attn_fwd_special_vec_k");

        return;
    }
    if constexpr (CUR_LOC_SIZE < MAX_LOC_SIZE) {
-       launch_spc_attn_kernel<BDIM_X, BDIM_Y, CUR_LOC_SIZE + 1, MAX_LOC_SIZE>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out,
-                                                                              _kxp, _vxp, _qyp, row_idx, row_off, col_idx, quad_weights, _yp, stream);
+       launch_spc_attn_fwd<BDIM_X, BDIM_Y, CUR_LOC_SIZE + 1, MAX_LOC_SIZE>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out,
+                                                                           _kxp, _vxp, _qyp, _row_idx, _row_off, _col_idx, _quad_weights, _yp, stream);
    }
    return;
}
__global__ void set_rlen_rids_k(const int n,
                                const int64_t *__restrict__ offs,
                                int *__restrict__ rids,
                                int *__restrict__ rlen)
{
    const int nth = gridDim.x * blockDim.x;
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;

    for (int i = tid; i < n; i += nth) {
        rids[i] = i;
        rlen[i] = offs[i + 1] - offs[i];
    }
    return;
}

at::Tensor sortRows(int nlat_out, at::Tensor row_off, cudaStream_t stream)
{
    int64_t *_row_off_d = reinterpret_cast<int64_t *>(row_off.data_ptr());

    auto options = torch::TensorOptions().dtype(torch::kInt32).device(row_off.device());

    torch::Tensor rids_d = torch::empty({nlat_out}, options);
    torch::Tensor rlen_d = torch::empty({nlat_out}, options);

    int *_rids_d = reinterpret_cast<int *>(rids_d.data_ptr());
    int *_rlen_d = reinterpret_cast<int *>(rlen_d.data_ptr());

    const int grid = DIV_UP(nlat_out, THREADS);
    const int block = THREADS;

    set_rlen_rids_k<<<grid, block, 0, stream>>>(nlat_out, _row_off_d, _rids_d, _rlen_d);

    torch::Tensor rids_sort_d = torch::empty({nlat_out}, options);
    torch::Tensor rlen_sort_d = torch::empty({nlat_out}, options);

    int *_rids_sort_d = reinterpret_cast<int *>(rids_sort_d.data_ptr());
    int *_rlen_sort_d = reinterpret_cast<int *>(rlen_sort_d.data_ptr());

    size_t temp_storage_bytes = 0;
    CHECK_CUDA(cub::DeviceRadixSort::SortPairsDescending(NULL, temp_storage_bytes,
                                                         _rlen_d, _rlen_sort_d,
                                                         _rids_d, _rids_sort_d,
                                                         nlat_out, 0, sizeof(*_rlen_d) * 8, stream));

    options = torch::TensorOptions().dtype(torch::kByte).device(row_off.device());
    torch::Tensor temp_storage_d = torch::empty({int64_t(temp_storage_bytes)}, options);
    void *_temp_storage_d = reinterpret_cast<void *>(temp_storage_d.data_ptr());

    CHECK_CUDA(cub::DeviceRadixSort::SortPairsDescending(_temp_storage_d, temp_storage_bytes,
                                                         _rlen_d, _rlen_sort_d,
                                                         _rids_d, _rids_sort_d,
                                                         nlat_out, 0, sizeof(*_rlen_d) * 8, stream));

    return rids_sort_d;
}
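sortRows uses the usual two-phase CUB pattern: call SortPairsDescending once with a null temporary buffer to query the scratch size, allocate it, then call again to sort the row lengths (keys) while carrying the row indices along as values. A minimal standalone sketch of that pattern (not from this commit; the sizes and data are made up):

#include <cub/cub.cuh>
#include <cstdio>

int main()
{
    const int n = 4;
    int h_keys[n] = {3, 9, 1, 7}, h_vals[n] = {0, 1, 2, 3};

    int *d_keys_in, *d_keys_out, *d_vals_in, *d_vals_out;
    cudaMalloc(&d_keys_in, n * sizeof(int));  cudaMalloc(&d_keys_out, n * sizeof(int));
    cudaMalloc(&d_vals_in, n * sizeof(int));  cudaMalloc(&d_vals_out, n * sizeof(int));
    cudaMemcpy(d_keys_in, h_keys, n * sizeof(int), cudaMemcpyHostToDevice);
    cudaMemcpy(d_vals_in, h_vals, n * sizeof(int), cudaMemcpyHostToDevice);

    void  *d_temp = nullptr;
    size_t temp_bytes = 0;
    // pass 1: only reports the required temporary storage size
    cub::DeviceRadixSort::SortPairsDescending(d_temp, temp_bytes, d_keys_in, d_keys_out,
                                              d_vals_in, d_vals_out, n);
    cudaMalloc(&d_temp, temp_bytes);
    // pass 2: performs the descending key sort, permuting the values alongside
    cub::DeviceRadixSort::SortPairsDescending(d_temp, temp_bytes, d_keys_in, d_keys_out,
                                              d_vals_in, d_vals_out, n);

    int h_out[n];
    cudaMemcpy(h_out, d_vals_out, n * sizeof(int), cudaMemcpyDeviceToHost);
    printf("rows by descending length: %d %d %d %d\n", h_out[0], h_out[1], h_out[2], h_out[3]); // 1 3 0 2
    cudaFree(d_keys_in); cudaFree(d_keys_out); cudaFree(d_vals_in); cudaFree(d_vals_out); cudaFree(d_temp);
    return 0;
}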
template<unsigned int ALIGN>
int is_aligned(const void *ptr)
{
    static_assert(0 == (ALIGN & (ALIGN - 1)));
    return (0 == (uintptr_t(ptr) & (ALIGN - 1)));
}

static unsigned int next_pow2(unsigned int x)
{
    x -= 1;
#pragma unroll
    for (int i = 1; i <= sizeof(x) * 8 / 2; i *= 2) { x |= x >> i; }
    return x + 1;
}

-static void s2_attention_dipatch(int batch_size,
+static void s2_attn_fwd_dispatch(int batch_size,
                                 int nchans,
                                 int nlon_in,
                                 int nlat_out,
...
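next_pow2 rounds the candidate block size up to a power of two by smearing the highest set bit downward before adding one back; for example 33 rounds up to 64 while 64 stays 64. A small worked sketch (illustrative only, not part of this commit):

#include <cstdio>

static unsigned int next_pow2(unsigned int x) {
    x -= 1;                                                              // so exact powers of two map to themselves
    for (unsigned int i = 1; i <= sizeof(x) * 8 / 2; i *= 2) { x |= x >> i; }  // smear the top bit into all lower bits
    return x + 1;
}

int main() {
    printf("%u %u %u\n", next_pow2(33), next_pow2(64), next_pow2(100));  // 64 64 128
    return 0;
}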
@@ -585,11 +374,13 @@ static void s2_attention_dipatch(int batch_size,
                                 at::Tensor row_off,
                                 at::Tensor col_idx,
                                 at::Tensor quad_weights,
-                                at::Tensor yP,
-                                cudaStream_t stream)
+                                at::Tensor yP)
{
    static_assert(0 == (MAX_LOCAL_ARR_LEN & (MAX_LOCAL_ARR_LEN - 1)));

+   // get stream
+   auto stream = at::cuda::getCurrentCUDAStream().stream();

    // sort row indices (ho-s) in descending order
    // based on (row_off[ho+1]-row_off[ho])
    at::Tensor row_idx = sortRows(nlat_out, row_off, stream);
...
@@ -607,6 +398,11 @@ static void s2_attention_dipatch(int batch_size,
    float *_qyp = reinterpret_cast<float *>(qyP.data_ptr());
    float *_yp = reinterpret_cast<float *>(yP.data_ptr());

+   int32_t *_row_idx = reinterpret_cast<int32_t *>(row_idx.data_ptr());
+   int64_t *_row_off = reinterpret_cast<int64_t *>(row_off.data_ptr());
+   int64_t *_col_idx = reinterpret_cast<int64_t *>(col_idx.data_ptr());
+   float *_quad_weights = reinterpret_cast<float *>(quad_weights.data_ptr());

    constexpr int VEC_SIZE = sizeof(float4) / sizeof(float);

    if (!is_aligned<sizeof(float4)>(_kxp) ||
...
@@ -616,16 +412,24 @@ static void s2_attention_dipatch(int batch_size,
        (nchans % VEC_SIZE) != 0) {

        const int nloc = DIV_UP(nchans, bdimx);

        // to avoid the compilation of unused template instances;
        // we use a block size BDIM_X that is the smallest power of 2
        // such that BDIM_X*MAX_LOCAL_ARR_LEN >= nchans, so
        // BDIM_X > 32 are used only for:
        //
        // (BDIM_X-1)*MAX_LOCAL_ARR_LEN < nchans <= BDIM_X*MAX_LOCAL_ARR_LEN
        constexpr int MIN_LOC_ARR_LEN = MAX_LOCAL_ARR_LEN / 2 + 1;

        // use 2D blocks only if 32 threads are enough
        switch (bdimx) {
            case   32: launch_spc_attn_kernel<  32, 2, 1, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, row_idx, row_off, col_idx, quad_weights, _yp, stream); break;
            case   64: launch_spc_attn_kernel<  64, 1, 1, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, row_idx, row_off, col_idx, quad_weights, _yp, stream); break;
            case  128: launch_spc_attn_kernel< 128, 1, 1, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, row_idx, row_off, col_idx, quad_weights, _yp, stream); break;
            case  256: launch_spc_attn_kernel< 256, 1, 1, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, row_idx, row_off, col_idx, quad_weights, _yp, stream); break;
            case  512: launch_spc_attn_kernel< 512, 1, 1, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, row_idx, row_off, col_idx, quad_weights, _yp, stream); break;
            case 1024: launch_spc_attn_kernel<1024, 1, 1, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, row_idx, row_off, col_idx, quad_weights, _yp, stream); break;
            default:   launch_gen_attn_kernel(batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, row_idx, row_off, col_idx, quad_weights, _yp, stream); break;
            case   32: launch_spc_attn_fwd<  32, 2,               1, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, _row_idx, _row_off, _col_idx, _quad_weights, _yp, stream); break;
            case   64: launch_spc_attn_fwd<  64, 1, MIN_LOC_ARR_LEN, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, _row_idx, _row_off, _col_idx, _quad_weights, _yp, stream); break;
            case  128: launch_spc_attn_fwd< 128, 1, MIN_LOC_ARR_LEN, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, _row_idx, _row_off, _col_idx, _quad_weights, _yp, stream); break;
            case  256: launch_spc_attn_fwd< 256, 1, MIN_LOC_ARR_LEN, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, _row_idx, _row_off, _col_idx, _quad_weights, _yp, stream); break;
            case  512: launch_spc_attn_fwd< 512, 1, MIN_LOC_ARR_LEN, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, _row_idx, _row_off, _col_idx, _quad_weights, _yp, stream); break;
            case 1024: launch_spc_attn_fwd<1024, 1, MIN_LOC_ARR_LEN, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, _row_idx, _row_off, _col_idx, _quad_weights, _yp, stream); break;
            default:   launch_gen_attn_fwd(batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, _row_idx, _row_off, _col_idx, _quad_weights, _yp, stream); break;
        }
    } else {
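The comment in the hunk above explains the specialization strategy: pick the smallest power-of-two block size BDIM_X with BDIM_X*MAX_LOCAL_ARR_LEN >= nchans, and fall back to the generic kernel once even 1024 threads are not enough. A minimal host-side sketch of that selection logic, with hypothetical names (pick_bdimx, next_pow2_host) and not the committed code, which uses next_pow2() from attention_utils.cu:

    // Hypothetical sketch: choose the smallest power-of-two block size BDIM_X such that
    // BDIM_X * MAX_LOCAL_ARR_LEN >= nchans; returning 0 signals "use the generic fallback kernel".
    #include <algorithm>

    static unsigned int next_pow2_host(unsigned int x) {   // mirrors next_pow2() in attention_utils.cu
        x -= 1;
        for (unsigned int i = 1; i <= sizeof(x) * 8 / 2; i *= 2) { x |= x >> i; }
        return x + 1;
    }

    static int pick_bdimx(int nchans, int max_local_arr_len) {
        unsigned int bdimx = next_pow2_host((nchans + max_local_arr_len - 1) / max_local_arr_len);
        bdimx = std::max(bdimx, 32u);                       // never smaller than one warp
        return (bdimx <= 1024u) ? static_cast<int>(bdimx) : 0;
    }

For example, with a placeholder max_local_arr_len of 16, nchans = 1000 yields bdimx = 64, since 32*16 = 512 < 1000 <= 64*16 = 1024.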
...
...
@@ -638,233 +442,26 @@ static void s2_attention_dipatch(int batch_size,
        nchans /= VEC_SIZE;

        const int nloc = DIV_UP(nchans, bdimx);

        static constexpr int MAX_LOCAL_VEC_LEN = MAX_LOCAL_ARR_LEN / VEC_SIZE;
        constexpr int MAX_LOCAL_VEC_LEN = MAX_LOCAL_ARR_LEN / VEC_SIZE;
        constexpr int MIN_LOC_VEC_LEN = MAX_LOCAL_VEC_LEN / 2 + 1;

        // use 2D blocks only if 32 threads are enough
        switch (bdimx) {
            case   32: launch_spc_attn_kernel<  32, 2, 1, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, row_idx, row_off, col_idx, quad_weights, _yp4, stream); break;
            case   64: launch_spc_attn_kernel<  64, 1, 1, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, row_idx, row_off, col_idx, quad_weights, _yp4, stream); break;
            case  128: launch_spc_attn_kernel< 128, 1, 1, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, row_idx, row_off, col_idx, quad_weights, _yp4, stream); break;
            case  256: launch_spc_attn_kernel< 256, 1, 1, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, row_idx, row_off, col_idx, quad_weights, _yp4, stream); break;
            case  512: launch_spc_attn_kernel< 512, 1, 1, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, row_idx, row_off, col_idx, quad_weights, _yp4, stream); break;
            case 1024: launch_spc_attn_kernel<1024, 1, 1, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, row_idx, row_off, col_idx, quad_weights, _yp4, stream); break;
            default:   launch_gen_attn_kernel(batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, row_idx, row_off, col_idx, quad_weights, _yp4, stream); break;
            case   32: launch_spc_attn_fwd<  32, 2,               1, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, _row_idx, _row_off, _col_idx, _quad_weights, _yp4, stream); break;
            case   64: launch_spc_attn_fwd<  64, 1, MIN_LOC_VEC_LEN, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, _row_idx, _row_off, _col_idx, _quad_weights, _yp4, stream); break;
            case  128: launch_spc_attn_fwd< 128, 1, MIN_LOC_VEC_LEN, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, _row_idx, _row_off, _col_idx, _quad_weights, _yp4, stream); break;
            case  256: launch_spc_attn_fwd< 256, 1, MIN_LOC_VEC_LEN, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, _row_idx, _row_off, _col_idx, _quad_weights, _yp4, stream); break;
            case  512: launch_spc_attn_fwd< 512, 1, MIN_LOC_VEC_LEN, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, _row_idx, _row_off, _col_idx, _quad_weights, _yp4, stream); break;
            case 1024: launch_spc_attn_fwd<1024, 1, MIN_LOC_VEC_LEN, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, _row_idx, _row_off, _col_idx, _quad_weights, _yp4, stream); break;
            default:   launch_gen_attn_fwd(batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, _row_idx, _row_off, _col_idx, _quad_weights, _yp4, stream); break;
        }
    }
    return;
}
// END - forward kernels and functions
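Both dispatch hunks share the same precondition for the vectorized path: the per-point channel vectors must be 16-byte aligned and the channel count divisible by 4, otherwise the scalar float branch above is taken. A host-side sketch of that eligibility test, with hypothetical names (can_use_float4, is_aligned_host) and assuming the elided "..." in the condition checks the remaining pointers the same way as _kxp:

    // Sketch of the float4-path eligibility check; not the committed code.
    #include <cstdint>
    #include <cuda_runtime.h>   // float4

    template<unsigned int ALIGN>
    static int is_aligned_host(const void *ptr) {      // same idea as is_aligned<> in attention_utils.cuh
        static_assert(0 == (ALIGN & (ALIGN - 1)), "ALIGN must be a power of two");
        return 0 == (uintptr_t(ptr) & (ALIGN - 1));
    }

    static bool can_use_float4(const float *kxp, const float *vxp, const float *qyp,
                               const float *yp, int nchans) {
        constexpr int VEC_SIZE = sizeof(float4) / sizeof(float);   // 4
        return is_aligned_host<sizeof(float4)>(kxp) && is_aligned_host<sizeof(float4)>(vxp) &&
               is_aligned_host<sizeof(float4)>(qyp) && is_aligned_host<sizeof(float4)>(yp) &&
               (nchans % VEC_SIZE) == 0;
    }

When the test passes, nchans is divided by VEC_SIZE and the float4-typed pointers (_kxp4, _vxp4, _qyp4, _yp4) are dispatched instead.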
// BEGIN - tensor permutation kernels and functions
template<int BDIM_X,
         int BDIM_Y,
         typename VAL_T>
__global__
__launch_bounds__(BDIM_X*BDIM_Y)
void permute_to0231_k(const int nchn,
                      const int nlat,
                      const int nlon,
                      const torch::PackedTensorAccessor32<VAL_T, 4, torch::RestrictPtrTraits> src,
                            torch::PackedTensorAccessor32<VAL_T, 4, torch::RestrictPtrTraits> dst) {

    static_assert(!(BDIM_X & (BDIM_X-1)));
    static_assert(!(BDIM_Y & (BDIM_Y-1)));
    static_assert(BDIM_X >= BDIM_Y);

    __shared__ VAL_T sh[BDIM_X][BDIM_X+1];

    const int tidx = threadIdx.x;
    const int tidy = threadIdx.y;

    const int coff = blockIdx.x*BDIM_X;         // channel offset
    const int woff = blockIdx.y*BDIM_X;         // width offset
    const int batch = blockIdx.z / nlat;        // batch (same for all block)
    const int h = blockIdx.z - (batch*nlat);    // height (same for all block)

    const int nchn_full = (nchn-coff) >= BDIM_X;
    const int nlon_full = (nlon-woff) >= BDIM_X;

    if (nchn_full && nlon_full) {
        #pragma unroll
        for(int j = 0; j < BDIM_X; j += BDIM_Y) {
            sh[j+tidy][tidx] = src[batch][coff+j+tidy][h][woff+tidx];
        }
        __syncthreads();

        #pragma unroll
        for(int j = 0; j < BDIM_X; j += BDIM_Y) {
            dst[batch][h][woff+j+tidy][coff+tidx] = sh[tidx][j+tidy];
        }
    } else {
        if (woff+tidx < nlon) {
            #pragma unroll
            for(int j = 0; j < BDIM_X; j += BDIM_Y) {
                sh[j+tidy][tidx] = (coff+j+tidy < nchn) ? src[batch][coff+j+tidy][h][woff+tidx] : 0.f;
            }
        }
        __syncthreads();

        if (coff+tidx < nchn) {
            #pragma unroll
            for(int j = 0; j < BDIM_X; j += BDIM_Y) {
                if (woff+j+tidy < nlon) {
                    dst[batch][h][woff+j+tidy][coff+tidx] = sh[tidx][j+tidy];
                }
            }
        }
    }
    return;
}

__global__ void empty_k() {}

static int getPtxver() {

    cudaFuncAttributes attrs;
    CHECK_CUDA(cudaFuncGetAttributes(&attrs, empty_k));

    return attrs.ptxVersion*10;
}

static at::Tensor permute_4D_floatT_to0231(at::Tensor src, cudaStream_t stream) {

    dim3 block;
    dim3 grid;

    block.x = WARP_SIZE;
    grid.x = DIV_UP(src.size(1), block.x);
    grid.y = DIV_UP(src.size(3), block.x);
    grid.z = src.size(2)*src.size(0);

    assert(grid.y < 65536);
    assert(grid.z < 65536);

    auto options = torch::TensorOptions().dtype(torch::kFloat32).device(src.device());
    torch::Tensor dst = torch::empty({src.size(0), src.size(2), src.size(3), src.size(1)}, options);

    const int ptxv = getPtxver();

    // to be further specialized for additional archs, if necessary
    if (ptxv < 100) {
        block.y = TRANSP_WARPS_X_TILE_GENERIC;
        permute_to0231_k<WARP_SIZE, TRANSP_WARPS_X_TILE_GENERIC>
                        <<<grid, block, 0, stream>>>(src.size(1), src.size(2), src.size(3),
                                                     src.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
                                                     dst.packed_accessor32<float, 4, torch::RestrictPtrTraits>());
    } else {
        block.y = TRANSP_WARPS_X_TILE_SM100;
        permute_to0231_k<WARP_SIZE, TRANSP_WARPS_X_TILE_SM100>
                        <<<grid, block, 0, stream>>>(src.size(1), src.size(2), src.size(3),
                                                     src.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
                                                     dst.packed_accessor32<float, 4, torch::RestrictPtrTraits>());
    }
    return dst;
}

template<int BDIM_X,
         int BDIM_Y,
         typename VAL_T>
__global__
__launch_bounds__(BDIM_X*BDIM_Y)
void permute_to0312_k(const int nchn,
                      const int nlat,
                      const int nlon,
                      const torch::PackedTensorAccessor32<VAL_T, 4, torch::RestrictPtrTraits> src,
                            torch::PackedTensorAccessor32<VAL_T, 4, torch::RestrictPtrTraits> dst) {

    static_assert(!(BDIM_X & (BDIM_X-1)));
    static_assert(!(BDIM_Y & (BDIM_Y-1)));
    static_assert(BDIM_X >= BDIM_Y);

    __shared__ VAL_T sh[BDIM_X][BDIM_X+1];

    const int tidx = threadIdx.x;
    const int tidy = threadIdx.y;

    const int woff = blockIdx.x*BDIM_X;         // width offset
    const int coff = blockIdx.y*BDIM_X;         // channel offset
    const int batch = blockIdx.z / nlat;        // batch (same for all block)
    const int h = blockIdx.z - (batch*nlat);    // height (same for all block)

    const int nchn_full = (nchn-coff) >= BDIM_X;
    const int nlon_full = (nlon-woff) >= BDIM_X;

    if (nchn_full && nlon_full) {
        #pragma unroll
        for(int j = 0; j < BDIM_X; j += BDIM_Y) {
            sh[j+tidy][tidx] = src[batch][h][woff+j+tidy][coff+tidx];
        }
        __syncthreads();

        #pragma unroll
        for(int j = 0; j < BDIM_X; j += BDIM_Y) {
            dst[batch][coff+j+tidy][h][woff+tidx] = sh[tidx][j+tidy];
        }
    } else {
        if (coff+tidx < nchn) {
            #pragma unroll
            for(int j = 0; j < BDIM_X; j += BDIM_Y) {
                sh[j+tidy][tidx] = (woff+j+tidy < nlon) ? src[batch][h][woff+j+tidy][coff+tidx] : 0.f;
            }
        }
        __syncthreads();

        if (woff+tidx < nlon) {
            #pragma unroll
            for(int j = 0; j < BDIM_X; j += BDIM_Y) {
                if (coff+j+tidy < nchn) {
                    dst[batch][coff+j+tidy][h][woff+tidx] = sh[tidx][j+tidy];
                }
            }
        }
    }
    return;
}

static at::Tensor permute_4D_floatT_to0312(at::Tensor src, cudaStream_t stream) {

    dim3 block;
    dim3 grid;

    block.x = WARP_SIZE;
    grid.x = DIV_UP(src.size(2), block.x);
    grid.y = DIV_UP(src.size(3), block.x);
    grid.z = src.size(1)*src.size(0);

    assert(grid.y < 65536);
    assert(grid.z < 65536);

    auto options = torch::TensorOptions().dtype(torch::kFloat32).device(src.device());
    torch::Tensor dst = torch::empty({src.size(0), src.size(3), src.size(1), src.size(2)}, options);

    const int ptxv = getPtxver();

    // to be further specialized for additional archs, if necessary
    if (ptxv < 100) {
        block.y = TRANSP_WARPS_X_TILE_GENERIC;
        permute_to0312_k<WARP_SIZE, TRANSP_WARPS_X_TILE_GENERIC>
                        <<<grid, block, 0, stream>>>(src.size(3), src.size(1), src.size(2),
                                                     src.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
                                                     dst.packed_accessor32<float, 4, torch::RestrictPtrTraits>());
    } else {
        block.y = TRANSP_WARPS_X_TILE_SM100;
        permute_to0312_k<WARP_SIZE, TRANSP_WARPS_X_TILE_SM100>
                        <<<grid, block, 0, stream>>>(src.size(3), src.size(1), src.size(2),
                                                     src.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
                                                     dst.packed_accessor32<float, 4, torch::RestrictPtrTraits>());
    }
    return dst;
}

// END - tensor permutation kernels and functions
// END - forward kernels and functions
torch::Tensor s2_attention_fwd_cuda(at::Tensor kx,
                                    at::Tensor vx,
...
...
@@ -875,36 +472,37 @@ torch::Tensor s2_attention_fwd_cuda(at::Tensor kx,
                                    int nlon_in,
                                    int nlat_out,
                                    int nlon_out) {

    CHECK_CUDA_TENSOR(kx);
    CHECK_CUDA_TENSOR(vx);
    CHECK_CUDA_TENSOR(qy);
    CHECK_CUDA_INPUT_TENSOR(kx);
    CHECK_CUDA_INPUT_TENSOR(vx);
    CHECK_CUDA_INPUT_TENSOR(qy);
    CHECK_CUDA_TENSOR(quad_weights);
    CHECK_CUDA_TENSOR(psi_col_idx);
    CHECK_CUDA_TENSOR(psi_row_off);

    // TODO: check sizes

    auto stream = at::cuda::getCurrentCUDAStream().stream();

    size_t uo_num_channels = kx.size(1);
    const int batch_size = kx.size(0);

    torch::Tensor kxP = kx;
    torch::Tensor vxP = vx;
    torch::Tensor qyP = qy;

    // extract dtype
    auto qy_type = qy.dtype();

    torch::Tensor kxP = kx.to(torch::kFloat32);
    torch::Tensor vxP = vx.to(torch::kFloat32);
    torch::Tensor qyP = qy.to(torch::kFloat32);

    auto k_channel_first = kx.strides()[1] == 1;
    auto v_channel_first = vx.strides()[1] == 1;
    auto q_channel_first = qy.strides()[1] == 1;

    // these are much safer than checking is_contiguous(at::MemoryFormat::ChannelsLast)
    // the former fails for num_channels == 1
    bool kx_is_channels_last = kxP.strides()[1] == 1;
    bool vx_is_channels_last = vxP.strides()[1] == 1;
    bool qy_is_channels_last = qyP.strides()[1] == 1;

    if (!k_channel_first) { kxP = permute_4D_floatT_to0231(kx, stream); }
    if (!v_channel_first) { vxP = permute_4D_floatT_to0231(vx, stream); }
    if (!q_channel_first) { qyP = permute_4D_floatT_to0231(qy, stream); }
    if (!kx_is_channels_last) { kxP = permute_4D_to0231(kxP); }
    if (!vx_is_channels_last) { vxP = permute_4D_to0231(vxP); }
    if (!qy_is_channels_last) { qyP = permute_4D_to0231(qyP); }

    torch::Tensor yP = torch::empty_like(qyP);

    s2_attention_dipatch(batch_size,
    s2_attn_fwd_dispatch(batch_size,
                         uo_num_channels,
                         nlon_in,
                         nlat_out,
...
...
@@ -913,11 +511,13 @@ torch::Tensor s2_attention_fwd_cuda(at::Tensor kx,
                         psi_row_off,
                         psi_col_idx,
                         quad_weights,
                         yP, // out tensor
                         stream);
                         yP);

    torch::Tensor y = yP;

    if (!q_channel_first) { y = permute_4D_floatT_to0312(yP, stream); }
    if (!qy_is_channels_last) { y = permute_4D_to0312(y); }

    // convert precision back to starting
    y = y.to(qy_type);

    C10_CUDA_KERNEL_LAUNCH_CHECK();
...
...
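One of the changes above replaces the old *_channel_first stride tests with explicit *_is_channels_last booleans, following the in-diff comment that checking strides()[1] == 1 is more robust than is_contiguous(at::MemoryFormat::ChannelsLast) when num_channels == 1. A small libtorch sketch (illustration only; the shape and main() are made up) that prints both checks side by side:

    #include <torch/torch.h>
    #include <iostream>

    int main() {
        // single-channel NCHW tensor converted to a channels-last layout
        auto x = torch::randn({2, 1, 8, 16}).contiguous(at::MemoryFormat::ChannelsLast);
        bool by_strides = (x.strides()[1] == 1);                           // check used in the diff
        bool by_format  = x.is_contiguous(at::MemoryFormat::ChannelsLast); // per the diff comment, may disagree for C == 1
        std::cout << "strides check: " << by_strides
                  << ", memory-format check: " << by_format << std::endl;
        return 0;
    }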
torch_harmonics/csrc/attention/attention_utils.cu
0 → 100644
// coding=utf-8
//
// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "attention.cuh"
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/CUDAUtils.h>
#include <cuda_runtime.h>
#include <cub/cub.cuh>
#include <limits>
#include "cudamacro.h"
#include "attention_utils.cuh"
#define THREADS (64)
#define TRANSP_WARPS_X_TILE_GENERIC (32)
#define TRANSP_WARPS_X_TILE_SM100 (4)
// BEGIN - CSR rows sorting kernels and functions
__global__ void set_rlen_rids_k(const int n,
                                const int64_t *__restrict__ offs,
                                int *__restrict__ rids,
                                int *__restrict__ rlen) {

    const int nth = gridDim.x*blockDim.x;
    const int tid = blockIdx.x*blockDim.x + threadIdx.x;

    for(int i = tid; i < n; i += nth) {
        rids[i] = i;
        rlen[i] = offs[i+1] - offs[i];
    }
    return;
}
at::Tensor sortRows(int nlat_out, at::Tensor row_off, cudaStream_t stream) {

    int64_t *_row_off_d = reinterpret_cast<int64_t *>(row_off.data_ptr());

    auto options = torch::TensorOptions().dtype(torch::kInt32).device(row_off.device());

    torch::Tensor rids_d = torch::empty({nlat_out}, options);
    torch::Tensor rlen_d = torch::empty({nlat_out}, options);

    int *_rids_d = reinterpret_cast<int *>(rids_d.data_ptr());
    int *_rlen_d = reinterpret_cast<int *>(rlen_d.data_ptr());

    const int grid = DIV_UP(nlat_out, THREADS);
    const int block = THREADS;

    set_rlen_rids_k<<<grid, block, 0, stream>>>(nlat_out, _row_off_d, _rids_d, _rlen_d);

    torch::Tensor rids_sort_d = torch::empty({nlat_out}, options);
    torch::Tensor rlen_sort_d = torch::empty({nlat_out}, options);

    int *_rids_sort_d = reinterpret_cast<int *>(rids_sort_d.data_ptr());
    int *_rlen_sort_d = reinterpret_cast<int *>(rlen_sort_d.data_ptr());

    size_t temp_storage_bytes = 0;
    CHECK_CUDA(cub::DeviceRadixSort::SortPairsDescending(NULL, temp_storage_bytes,
                                                         _rlen_d, _rlen_sort_d,
                                                         _rids_d, _rids_sort_d,
                                                         nlat_out, 0, sizeof(*_rlen_d)*8, stream));

    options = torch::TensorOptions().dtype(torch::kByte).device(row_off.device());
    torch::Tensor temp_storage_d = torch::empty({int64_t(temp_storage_bytes)}, options);
    void *_temp_storage_d = reinterpret_cast<void *>(temp_storage_d.data_ptr());

    CHECK_CUDA(cub::DeviceRadixSort::SortPairsDescending(_temp_storage_d, temp_storage_bytes,
                                                         _rlen_d, _rlen_sort_d,
                                                         _rids_d, _rids_sort_d,
                                                         nlat_out, 0, sizeof(*_rlen_d)*8, stream));

    return rids_sort_d;
}
// END - CSR rows sorting kernels and functions
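sortRows() builds one (row id, row length) pair per output latitude from the int64 CSR offsets and sorts the ids by descending length with cub::DeviceRadixSort, so the most expensive rows can be scheduled first. A hypothetical usage sketch (the function name rows_longest_first is made up; assumes linking against attention_utils.cu):

    #include <ATen/ATen.h>
    #include <ATen/cuda/CUDAContext.h>
    #include "attention_utils.cuh"

    // Returns an int32 tensor of length nlat_out with output-row indices ordered
    // from the row with the most neighbors to the row with the fewest.
    at::Tensor rows_longest_first(at::Tensor psi_row_off, int nlat_out) {
        cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
        return sortRows(nlat_out, psi_row_off, stream);
    }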
// BEGIN - 4D tensor permutation kernels and functions
__global__ void empty_k() {}

static int getPtxver() {

    cudaFuncAttributes attrs;
    CHECK_CUDA(cudaFuncGetAttributes(&attrs, empty_k));

    return attrs.ptxVersion*10;
}
at::Tensor permute_4D_to0231(at::Tensor src) {

    auto options = torch::TensorOptions().dtype(src.dtype()).device(src.device());
    torch::Tensor dst = torch::empty({src.size(0), src.size(2), src.size(3), src.size(1)}, options);

    const int ptxv = getPtxver();

    // to be further specialized for additional archs, if necessary
    if (ptxv < 100) {
        AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0231_k_tile_generic", ([&] {
            launch_permute_to0231<TRANSP_WARPS_X_TILE_GENERIC, scalar_t>(src, dst);
        }));
        CHECK_ERROR("permute_to0231_k_tile_generic");
    } else {
        AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0231_k_tile_sm100", ([&] {
            launch_permute_to0231<TRANSP_WARPS_X_TILE_SM100, scalar_t>(src, dst);
        }));
        CHECK_ERROR("permute_to0231_k_tile_sm100");
    }
    return dst;
}
at::Tensor permute_4D_to0312(at::Tensor src) {

    auto options = torch::TensorOptions().dtype(src.dtype()).device(src.device());
    torch::Tensor dst = torch::empty({src.size(0), src.size(3), src.size(1), src.size(2)}, options);

    const int ptxv = getPtxver();

    // to be further specialized for additional archs, if necessary
    if (ptxv < 100) {
        AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0312_k_tile_generic", ([&] {
            launch_permute_to0312<TRANSP_WARPS_X_TILE_GENERIC, scalar_t>(src, dst);
        }));
        CHECK_ERROR("permute_to0312_k_tile_generic");
    } else {
        AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0312_k_tile_sm100", ([&] {
            launch_permute_to0312<TRANSP_WARPS_X_TILE_SM100, scalar_t>(src, dst);
        }));
        CHECK_ERROR("permute_to0312_k_tile_sm100");
    }
    return dst;
}
// END - tensor permutation kernels and functions
// BEGIN - general host-side functions
unsigned int next_pow2(unsigned int x) {

    x -= 1;

    #pragma unroll
    for(int i = 1; i <= sizeof(x)*8 / 2; i *= 2) {
        x |= x >> i;
    }
    return x+1;
}
// END - general host-side functions
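next_pow2() rounds an unsigned integer up to the next power of two by smearing the highest set bit to the right and adding one. A quick check (sketch only; assumes the declaration from attention_utils.cuh and linking against this file):

    #include <cstdio>

    unsigned int next_pow2(unsigned int x);   // defined above in attention_utils.cu

    int main() {
        const unsigned int xs[] = {1, 5, 64, 65, 1000};
        for (unsigned int x : xs) {
            printf("next_pow2(%u) = %u\n", x, next_pow2(x));   // prints 1, 8, 64, 128, 1024
        }
        return 0;
    }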
torch_harmonics/csrc/attention/attention_utils.cuh
0 → 100644
// coding=utf-8
//
// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include <ATen/ATen.h>
#define WARP_SIZE (32)
#define FULL_MASK (0xFFFFFFFF)
#define DIV_UP(a,b) (((a)+((b)-1))/(b))
// CSR rows sorting kernels and functions
at::Tensor sortRows(int nlat_out, at::Tensor row_off, cudaStream_t stream);

// 4D tensor permutation kernels and functions
at::Tensor permute_4D_to0231(at::Tensor src);
at::Tensor permute_4D_to0312(at::Tensor src);
// Host tensor dump and CSR manipulation functions
void dump_tensor(const char *fname, at::Tensor t);
void dump_csr(const char *fname, at::Tensor roff, at::Tensor cols);

int part_csr_rows(int *row_perm,
                  const at::Tensor roff,
                  const at::Tensor cols,
                  int **part_off,
                  int **part_val);

int verify_part(const int npart,
                const int *part_off,
                const int *part_val,
                const at::Tensor roff,
                const at::Tensor cols);

void verify_part_new(const int nlon_out,
                     const int nlat_in,
                     const int nlon_in,
                     const int npart, // partitioning data
                     const int *part_off,
                     const int *part_val,
                     const at::Tensor roff,
                     const at::Tensor cols);

unsigned int next_pow2(unsigned int x);
// utility host functions and templates
template<unsigned int ALIGN>
int is_aligned(const void *ptr) {
    static_assert(0 == (ALIGN & (ALIGN-1)));
    return (0 == (uintptr_t(ptr) & (ALIGN-1)));
}
// utility device functions and templates
template<typename FLOATV_T>
__device__ FLOATV_T __vset(float x) {
    static_assert(sizeof(FLOATV_T) == 0, "Unsupported type for __vset");
    return FLOATV_T{};
}

template<>
__device__ float __forceinline__ __vset<float>(float x) {
    return x;
}

__device__ float __forceinline__ __vmul(float a, float b) {
    return a*b;
}

__device__ float __forceinline__ __vadd(float a, float b) {
    return a+b;
}

__device__ float __forceinline__ __vsub(float a, float b) {
    return a-b;
}

__device__ float __forceinline__ __vred(float a) {
    return a;
}

__device__ float __forceinline__ __vscale(float s, float v) {
    return v*s;
}

__device__ float __forceinline__ __vdiv(float s, float v) {
    return v/s;
}

template<>
__device__ float4 __forceinline__ __vset<float4>(float x) {
    return make_float4(x, x, x, x);
}

__device__ float4 __forceinline__ __vmul(float4 a, float4 b) {
    return make_float4(a.x*b.x, a.y*b.y, a.z*b.z, a.w*b.w);
}

__device__ float4 __forceinline__ __vadd(float4 a, float4 b) {
    return make_float4(a.x+b.x, a.y+b.y, a.z+b.z, a.w+b.w);
}

__device__ float4 __forceinline__ __vsub(float4 a, float4 b) {
    return make_float4(a.x-b.x, a.y-b.y, a.z-b.z, a.w-b.w);
}

__device__ float __forceinline__ __vred(float4 a) {
    return a.x + a.y + a.z + a.w;
}

__device__ float4 __forceinline__ __vscale(float s, float4 v) {
    return make_float4(s*v.x, s*v.y, s*v.z, s*v.w);
}

__device__ float4 __forceinline__ __vdiv(float s, float4 v) {
    return make_float4(s/v.x, s/v.y, s/v.z, s/v.w);
}
template<typename VAL_T>
__device__ VAL_T __warp_sum(VAL_T val) {

    #pragma unroll
    for(int i = WARP_SIZE/2; i; i /= 2) {
        val += __shfl_xor_sync(FULL_MASK, val, i, WARP_SIZE);
    }
    return val;
}
template<int BDIM_X,
         int BDIM_Y = 1,
         int BDIM_Z = 1,
         typename VAL_T>
__device__ VAL_T __block_sum(VAL_T val) {

    const int NWARP = (BDIM_X*BDIM_Y*BDIM_Z) / WARP_SIZE;

    val = __warp_sum(val);

    if constexpr(NWARP > 1) {

        int tid = threadIdx.x;
        if constexpr(BDIM_Y > 1) { tid += threadIdx.y*BDIM_X; }
        if constexpr(BDIM_Z > 1) { tid += threadIdx.z*BDIM_X*BDIM_Y; }

        const int lid = tid % WARP_SIZE;
        const int wid = tid / WARP_SIZE;

        __shared__ VAL_T sh[NWARP];

        if (lid == 0) { sh[wid] = val; }

        __syncthreads();

        if (wid == 0) {
            val = (lid < NWARP) ? sh[lid] : 0;
            val = __warp_sum(val);

            __syncwarp();

            if (!lid) { sh[0] = val; }
        }
        __syncthreads();

        val = sh[0];

        __syncthreads();
    }
    return val;
}
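__block_sum<> first reduces within each warp via shuffles and then combines the per-warp partial sums through shared memory, returning the block-wide sum to every thread. A hypothetical kernel sketch using it (block_sum_demo is made up; the launch configuration must match the template parameters, here 128 threads per block):

    #include "attention_utils.cuh"

    __global__ void block_sum_demo(const float *x, float *out, int n) {
        const int i = blockIdx.x * blockDim.x + threadIdx.x;
        float v = (i < n) ? x[i] : 0.0f;
        v = __block_sum<128>(v);                       // every thread now holds the block-wide sum
        if (threadIdx.x == 0) { out[blockIdx.x] = v; }
    }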
// transpose utils
template<int BDIM_X,
         int BDIM_Y,
         typename VAL_T>
__global__
__launch_bounds__(BDIM_X*BDIM_Y)
void permute_to0231_k(const int nchn,
                      const int nlat,
                      const int nlon,
                      const at::PackedTensorAccessor32<VAL_T, 4, at::RestrictPtrTraits> src,
                            at::PackedTensorAccessor32<VAL_T, 4, at::RestrictPtrTraits> dst) {

    static_assert(!(BDIM_X & (BDIM_X-1)));
    static_assert(!(BDIM_Y & (BDIM_Y-1)));
    static_assert(BDIM_X >= BDIM_Y);

    __shared__ VAL_T sh[BDIM_X][BDIM_X+1];

    const int tidx = threadIdx.x;
    const int tidy = threadIdx.y;

    const int coff = blockIdx.x*BDIM_X;         // channel offset
    const int woff = blockIdx.y*BDIM_X;         // width offset
    const int batch = blockIdx.z / nlat;        // batch (same for all block)
    const int h = blockIdx.z - (batch*nlat);    // height (same for all block)

    const int nchn_full = (nchn-coff) >= BDIM_X;
    const int nlon_full = (nlon-woff) >= BDIM_X;

    if (nchn_full && nlon_full) {
        #pragma unroll
        for(int j = 0; j < BDIM_X; j += BDIM_Y) {
            sh[j+tidy][tidx] = src[batch][coff+j+tidy][h][woff+tidx];
        }
        __syncthreads();

        #pragma unroll
        for(int j = 0; j < BDIM_X; j += BDIM_Y) {
            dst[batch][h][woff+j+tidy][coff+tidx] = sh[tidx][j+tidy];
        }
    } else {
        if (woff+tidx < nlon) {
            #pragma unroll
            for(int j = 0; j < BDIM_X; j += BDIM_Y) {
                sh[j+tidy][tidx] = (coff+j+tidy < nchn) ? src[batch][coff+j+tidy][h][woff+tidx] : VAL_T(0);
            }
        }
        __syncthreads();

        if (coff+tidx < nchn) {
            #pragma unroll
            for(int j = 0; j < BDIM_X; j += BDIM_Y) {
                if (woff+j+tidy < nlon) {
                    dst[batch][h][woff+j+tidy][coff+tidx] = sh[tidx][j+tidy];
                }
            }
        }
    }
    return;
}
template<int WARPS_X_TILE, typename VAL_T>
void launch_permute_to0231(at::Tensor src, at::Tensor dst) {

    dim3 block;
    dim3 grid;

    block.x = WARP_SIZE;
    block.y = WARPS_X_TILE;
    grid.x = DIV_UP(src.size(1), block.x);
    grid.y = DIV_UP(src.size(3), block.x);
    grid.z = src.size(2)*src.size(0);

    assert(grid.y < 65536);
    assert(grid.z < 65536);

    // get stream
    auto stream = at::cuda::getCurrentCUDAStream().stream();

    permute_to0231_k<WARP_SIZE, WARPS_X_TILE>
                    <<<grid, block, 0, stream>>>(src.size(1), src.size(2), src.size(3),
                                                 src.packed_accessor32<VAL_T, 4, at::RestrictPtrTraits>(),
                                                 dst.packed_accessor32<VAL_T, 4, at::RestrictPtrTraits>());
}
template<int BDIM_X,
         int BDIM_Y,
         typename VAL_T>
__global__
__launch_bounds__(BDIM_X*BDIM_Y)
void permute_to0312_k(const int nchn,
                      const int nlat,
                      const int nlon,
                      const at::PackedTensorAccessor32<VAL_T, 4, at::RestrictPtrTraits> src,
                            at::PackedTensorAccessor32<VAL_T, 4, at::RestrictPtrTraits> dst) {

    static_assert(!(BDIM_X & (BDIM_X-1)));
    static_assert(!(BDIM_Y & (BDIM_Y-1)));
    static_assert(BDIM_X >= BDIM_Y);

    __shared__ VAL_T sh[BDIM_X][BDIM_X+1];

    const int tidx = threadIdx.x;
    const int tidy = threadIdx.y;

    const int woff = blockIdx.x*BDIM_X;         // width offset
    const int coff = blockIdx.y*BDIM_X;         // channel offset
    const int batch = blockIdx.z / nlat;        // batch (same for all block)
    const int h = blockIdx.z - (batch*nlat);    // height (same for all block)

    const int nchn_full = (nchn-coff) >= BDIM_X;
    const int nlon_full = (nlon-woff) >= BDIM_X;

    if (nchn_full && nlon_full) {
        #pragma unroll
        for(int j = 0; j < BDIM_X; j += BDIM_Y) {
            sh[j+tidy][tidx] = src[batch][h][woff+j+tidy][coff+tidx];
        }
        __syncthreads();

        #pragma unroll
        for(int j = 0; j < BDIM_X; j += BDIM_Y) {
            dst[batch][coff+j+tidy][h][woff+tidx] = sh[tidx][j+tidy];
        }
    } else {
        if (coff+tidx < nchn) {
            #pragma unroll
            for(int j = 0; j < BDIM_X; j += BDIM_Y) {
                sh[j+tidy][tidx] = (woff+j+tidy < nlon) ? src[batch][h][woff+j+tidy][coff+tidx] : VAL_T(0);
            }
        }
        __syncthreads();

        if (woff+tidx < nlon) {
            #pragma unroll
            for(int j = 0; j < BDIM_X; j += BDIM_Y) {
                if (coff+j+tidy < nchn) {
                    dst[batch][coff+j+tidy][h][woff+tidx] = sh[tidx][j+tidy];
                }
            }
        }
    }
    return;
}
template<int WARPS_X_TILE, typename VAL_T>
void launch_permute_to0312(at::Tensor src, at::Tensor dst) {

    dim3 block;
    dim3 grid;

    block.x = WARP_SIZE;
    block.y = WARPS_X_TILE;
    grid.x = DIV_UP(src.size(2), block.x);
    grid.y = DIV_UP(src.size(3), block.x);
    grid.z = src.size(1)*src.size(0);

    assert(grid.y < 65536);
    assert(grid.z < 65536);

    // get stream
    auto stream = at::cuda::getCurrentCUDAStream().stream();

    permute_to0312_k<WARP_SIZE, WARPS_X_TILE>
                    <<<grid, block, 0, stream>>>(src.size(3), src.size(1), src.size(2),
                                                 src.packed_accessor32<VAL_T, 4, at::RestrictPtrTraits>(),
                                                 dst.packed_accessor32<VAL_T, 4, at::RestrictPtrTraits>());
}
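Functionally, the two launchers above implement the layout conversions used by the attention front end: (batch, channel, lat, lon) to channels-last and back. A hedged equivalence check against plain torch permutes (check_permute_0231 is made up; assumes a 4D floating-point CUDA tensor and linking against attention_utils.cu):

    #include <ATen/ATen.h>
    #include "attention_utils.cuh"

    // permute_4D_to0231(src) should match src.permute({0, 2, 3, 1}).contiguous(),
    // and permute_4D_to0312 the inverse ordering {0, 3, 1, 2}.
    bool check_permute_0231(at::Tensor src) {
        at::Tensor ref = src.permute({0, 2, 3, 1}).contiguous();
        at::Tensor out = permute_4D_to0231(src);
        return at::allclose(out, ref);
    }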
torch_harmonics/csrc/attention/cudamacro.h
0 → 100644
// coding=utf-8
//
// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#define CHECK_CUDA(call) { \
cudaError_t err = call; \
if( cudaSuccess != err) { \
fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\n", \
__FILE__, __LINE__, cudaGetErrorString( err) ); \
exit(EXIT_FAILURE); \
}}
#define CHECK_ERROR(errorMessage) { \
cudaError_t err = cudaGetLastError(); \
if( cudaSuccess != err) { \
fprintf(stderr, "Cuda error: %s in file '%s' in line %i : %s.\n", \
errorMessage, __FILE__, __LINE__, cudaGetErrorString( err) );\
exit(EXIT_FAILURE); \
}}
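CHECK_CUDA wraps CUDA runtime calls that return a cudaError_t, while CHECK_ERROR is meant to follow a kernel launch and report the last error; both print the file and line and then terminate the process. A small usage sketch (noop_k and main are made up):

    #include <cstdio>
    #include <cstdlib>
    #include <cuda_runtime.h>
    #include "cudamacro.h"

    __global__ void noop_k() {}

    int main() {
        float *buf = nullptr;
        CHECK_CUDA(cudaMalloc(&buf, 1024 * sizeof(float)));
        noop_k<<<1, 32>>>();
        CHECK_ERROR("noop_k launch");     // catches launch-configuration errors
        CHECK_CUDA(cudaFree(buf));
        return 0;
    }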
torch_harmonics/examples/losses.py
...
...
@@ -31,6 +31,7 @@
import torch
import torch.nn as nn
import torch.amp as amp
import torch.nn.functional as F

from typing import Optional
from abc import ABC, abstractmethod
...
...
@@ -259,11 +260,14 @@ class W11LossS2(SphericalLossBase):
        self.register_buffer("k_theta_mesh", k_theta_mesh)

    def _compute_loss_term(self, prd: torch.Tensor, tar: torch.Tensor) -> torch.Tensor:
        prd_prime_fft2_phi_h = torch.fft.ifft2(1j * self.k_phi_mesh * torch.fft.fft2(prd)).real
        prd_prime_fft2_theta_h = torch.fft.ifft2(1j * self.k_theta_mesh * torch.fft.fft2(prd)).real
        tar_prime_fft2_phi_h = torch.fft.ifft2(1j * self.k_phi_mesh * torch.fft.fft2(tar)).real
        tar_prime_fft2_theta_h = torch.fft.ifft2(1j * self.k_theta_mesh * torch.fft.fft2(tar)).real
        prdtype = prd.dtype
        with amp.autocast(device_type="cuda", enabled=False):
            prd = prd.to(torch.float32)
            prd_prime_fft2_phi_h = torch.fft.ifft2(1j * self.k_phi_mesh * torch.fft.fft2(prd)).real
            prd_prime_fft2_theta_h = torch.fft.ifft2(1j * self.k_theta_mesh * torch.fft.fft2(prd)).real
            tar_prime_fft2_phi_h = torch.fft.ifft2(1j * self.k_phi_mesh * torch.fft.fft2(tar)).real
            tar_prime_fft2_theta_h = torch.fft.ifft2(1j * self.k_theta_mesh * torch.fft.fft2(tar)).real

        # Return the element-wise loss term
        return torch.abs(prd_prime_fft2_phi_h - tar_prime_fft2_phi_h) + torch.abs(prd_prime_fft2_theta_h - tar_prime_fft2_theta_h)
...
...