OpenDAS / torch-harmonics / Commits / 45fc2a46

Commit 45fc2a46, authored Jul 16, 2025 by Thorsten Kurth
Parent: 51200bda

    cleanup with contiguous checks

Showing 7 changed files with 406 additions and 307 deletions (+406 −307)
tests/test_attention.py                                  +2    −3
torch_harmonics/_neighborhood_attention.py               +15   −0
torch_harmonics/csrc/attention/attention.cuh             +5    −1
torch_harmonics/csrc/attention/attention_bwd_cuda.cu     +125  −112
torch_harmonics/csrc/attention/attention_fwd_cuda.cu     +26   −21
torch_harmonics/csrc/attention/attention_utils.cu        +60   −168
torch_harmonics/csrc/attention/attention_utils.cuh       +173  −2
tests/test_attention.py

@@ -78,7 +78,8 @@ class TestNeighborhoodAttentionS2(unittest.TestCase):
             [4, 4, 1, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
             [4, 4, 2, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
             [4, 4, 4, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
-            [4, 4, 1, (6, 12), (6, 12), "legendre-gauss", "legendre-gauss", 1e-5, 1e-3],
+            [4, 1, 1, (2, 4), (2, 4), "equiangular", "equiangular", 1e-5, 1e-3],
+            [4, 4, 4, (6, 12), (6, 12), "legendre-gauss", "legendre-gauss", 1e-5, 1e-3],
             [4, 4, 1, (6, 12), (6, 12), "lobatto", "lobatto", 1e-5, 1e-3],
         ],
         skip_on_empty=True,

@@ -156,8 +157,6 @@ class TestNeighborhoodAttentionS2(unittest.TestCase):
         [
             # Format: [batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol]
             [4, 4, 1, (6, 12), (6, 12), "equiangular", "equiangular", 1e-2, 0],
-            # [4, 4, 2, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
-            # [4, 4, 4, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
             [4, 4, 1, (6, 12), (6, 12), "legendre-gauss", "legendre-gauss", 1e-2, 0],
             [4, 4, 1, (6, 12), (6, 12), "lobatto", "lobatto", 1e-2, 0],
         ],
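For orientation, each parameter row above maps directly onto the tensor shapes the test builds. A minimal sketch of that mapping for the small num_channels == 1 case, using plain PyTorch only (variable names here are illustrative, not the actual test fixture's):

    import torch

    # One row of the parametrization:
    # [batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol]
    batch_size, channels, heads = 4, 1, 1
    in_shape, out_shape = (2, 4), (2, 4)      # (nlat, nlon) on the sphere
    atol, rtol = 1e-5, 1e-3

    # Inputs to the S2 neighborhood attention tests are laid out as [batch, channels, nlat, nlon]
    x = torch.randn(batch_size, channels, *in_shape)
    print(x.shape)  # torch.Size([4, 1, 2, 4]) -- the num_channels == 1 case this commit is concerned with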
torch_harmonics/_neighborhood_attention.py

@@ -520,6 +520,16 @@ class _NeighborhoodAttentionS2Cuda(torch.autograd.Function):
         B, _, H, W = grad_output.shape
         grad_output = grad_output.reshape(B * nh, -1, H, W)

+        # save type and convert to float32
+        kw_dtype = kw.dtype
+        vw_dtype = vw.dtype
+        qw_dtype = qw.dtype
+        kw = kw.to(torch.float32).contiguous()
+        vw = vw.to(torch.float32).contiguous()
+        qw = qw.to(torch.float32).contiguous()
+        grad_output = grad_output.to(torch.float32).contiguous()
+
         dkw, dvw, dqw = attention_cuda_extension.backward_dkvq(kw, vw, qw, grad_output,
                                                                 quad_weights,
                                                                 col_idx, row_off,

@@ -533,6 +543,11 @@ class _NeighborhoodAttentionS2Cuda(torch.autograd.Function):
         _, C, H, W = dqw.shape
         dqw = dqw.reshape(B, -1, H, W)

+        # convert precision
+        dkw = dkw.to(dtype=kw_dtype)
+        dvw = dvw.to(dtype=vw_dtype)
+        dqw = dqw.to(dtype=qw_dtype)
+
         # input grads
         dv = torch.nn.functional.conv2d(dvw, weight=wv.permute([1, 0, 2, 3]), bias=None)
         dk = torch.nn.functional.conv2d(dkw, weight=wk.permute([1, 0, 2, 3]), bias=None)
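The backward hook above now pins down both dtype and layout before handing tensors to the CUDA extension. A compact sketch of that save/cast/restore pattern, with the extension call left as a placeholder (op is a hypothetical stand-in, not a real torch-harmonics symbol):

    import torch

    def call_float32_contiguous(op, *tensors):
        # Remember each input's dtype, give the kernel float32 contiguous tensors,
        # then cast every output back to the dtype of the positionally matching input.
        dtypes = [t.dtype for t in tensors]
        args = [t.to(torch.float32).contiguous() for t in tensors]
        outs = op(*args)
        return tuple(o.to(dtype=d) for o, d in zip(outs, dtypes))

    # usage sketch (hypothetical extension handle):
    # dkw, dvw, dqw = call_float32_contiguous(ext.backward_dkvq, kw, vw, qw, grad_output)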
torch_harmonics/csrc/attention/attention.cuh

@@ -34,7 +34,11 @@
 #include <cstdint>
 #include <torch/torch.h>

-#define CHECK_CUDA_TENSOR(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CUDA_TENSOR(x) TORCH_INTERNAL_ASSERT(x.device().type() == torch::kCUDA)
+#define CHECK_CONTIGUOUS_TENSOR(x) TORCH_INTERNAL_ASSERT(x.is_contiguous() || x.is_contiguous(at::MemoryFormat::ChannelsLast))
+#define CHECK_CUDA_INPUT_TENSOR(x) \
+    CHECK_CUDA_TENSOR(x);          \
+    CHECK_CONTIGUOUS_TENSOR(x)

 torch::Tensor s2_attention_fwd_cuda(at::Tensor kx, at::Tensor vx, at::Tensor qy, at::Tensor quad_weights,
                                     at::Tensor psi_col_idx, at::Tensor psi_row_off, int nlon_in, int nlat_out,
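CHECK_CONTIGUOUS_TENSOR accepts either a plain contiguous tensor or a channels-last contiguous one, so strided views have to be materialized before they reach the extension. A small PyTorch illustration, not repository code, of what does and does not satisfy such a check:

    import torch

    x = torch.randn(4, 8, 6, 12)      # [batch, channels, nlat, nlon]
    y = x.permute(0, 2, 3, 1)         # a strided view: neither memory format is contiguous
    print(y.is_contiguous())                                      # False
    print(y.is_contiguous(memory_format=torch.channels_last))     # False

    a = x.contiguous()                                    # standard layout
    b = x.contiguous(memory_format=torch.channels_last)   # channels-last layout
    print(a.is_contiguous())                                      # True
    print(b.is_contiguous(memory_format=torch.channels_last))     # True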
torch_harmonics/csrc/attention/attention_bwd_cuda.cu

@@ -785,11 +785,13 @@ static void s2_attn_bwd_dispatch(int batch_size,
                           at::Tensor quad_weights,
                           at::Tensor dkxP,
                           at::Tensor dvxP,
-                          at::Tensor dqyP)
+                          at::Tensor dqyP,
+                          cudaStream_t stream)
 {
     static_assert(0 == (MAX_LOCAL_ARR_LEN & (MAX_LOCAL_ARR_LEN - 1)));

-    // get stream
-    auto stream = at::cuda::getCurrentCUDAStream().stream();
-
     // sort row indices (ho-s) in descending order
     // based on (row_off[ho+1]-row_off[ho])
     at::Tensor row_idx = sortRows(nlat_out, row_off, stream);

@@ -890,122 +892,129 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda(at::Te
                           int nlon_in, int nlat_out, int nlon_out)
 {
-    CHECK_CUDA_TENSOR(kx);
-    CHECK_CUDA_TENSOR(vx);
-    CHECK_CUDA_TENSOR(qy);
+    CHECK_CUDA_INPUT_TENSOR(kx);
+    CHECK_CUDA_INPUT_TENSOR(vx);
+    CHECK_CUDA_INPUT_TENSOR(qy);
+    CHECK_CUDA_INPUT_TENSOR(dy);
     CHECK_CUDA_TENSOR(quad_weights);
     CHECK_CUDA_TENSOR(psi_col_idx);
     CHECK_CUDA_TENSOR(psi_row_off);
-    CHECK_CUDA_TENSOR(dy);

     auto stream = at::cuda::getCurrentCUDAStream().stream();

-#if 0
+// #if 0
+// // extract dtype
+// auto kx_type = kx.dtype();
+// auto vx_type = vx.dtype();
+// auto qy_type = qy.dtype();
+// auto dy_type = dy.dtype();
+// // exract memory format
+// auto kx_is_channels_last = kx.is_contiguous(at::MemoryFormat::ChannelsLast);
+// auto vx_is_channels_last = vx.is_contiguous(at::MemoryFormat::ChannelsLast);
+// auto qy_is_channels_last = qy.is_contiguous(at::MemoryFormat::ChannelsLast);
+// auto dy_is_channels_last = dy.is_contiguous(at::MemoryFormat::ChannelsLast);
+// // convert to channels-last
+// auto kxP = kx.to(torch::kFloat32).to(at::MemoryFormat::ChannelsLast);
+// auto vxP = vx.to(torch::kFloat32).to(at::MemoryFormat::ChannelsLast);
+// auto qyP = qy.to(torch::kFloat32).to(at::MemoryFormat::ChannelsLast);
+// auto dyP = dy.to(torch::kFloat32).to(at::MemoryFormat::ChannelsLast);
+// // create output arrays
+// auto dydk = torch::zeros_like(qyP);
+// auto dydv = torch::zeros_like(qyP);
+// auto dydq = torch::zeros_like(qyP);
+// size_t uo_num_channels = kx.size(1);
+// const int batch_size = kx.size(0);
+// dim3 block(WARP_SIZE, THREADS / WARP_SIZE);
+// dim3 grid(DIV_UP(nlat_out * nlon_out, block.y), batch_size);
+// size_t shared_size = sizeof(float) * uo_num_channels * 5 * block.y; // 4 arrays per warp
+// cudaEvent_t start, stop;
+// float milliseconds = 0;
+// CHECK_CUDA(cudaEventCreate(&start));
+// CHECK_CUDA(cudaEventCreate(&stop));
+// CHECK_CUDA(cudaEventRecord(start, stream));
+// s2_attention_bwd_dkvq_kernel<THREADS><<<grid, block, shared_size, stream>>>(
+//     uo_num_channels, nlon_in, nlat_out, nlon_out, kxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
+//     vxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
+//     qyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
+//     dyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
+//     dydk.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
+//     dydv.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
+//     dydq.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
+//     psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
+//     psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
+//     quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>());
+// CHECK_CUDA(cudaEventRecord(stop, stream));
+// CHECK_CUDA(cudaEventSynchronize(stop));
+// CHECK_CUDA(cudaEventElapsedTime(&milliseconds, start, stop));
+// // [1, 256, 1, (721, 1440), (721, 1440), "equiangular", "equiangular", 1e-5, 1e-5],
+// // s2_attention_bwd_kernel_mbT execution time: 63.280128 ms
+// CHECK_CUDA(cudaEventDestroy(start));
+// CHECK_CUDA(cudaEventDestroy(stop));
+// C10_CUDA_KERNEL_LAUNCH_CHECK();
+// // Permute outputs back to memory layout given by input. if input had channels
+// // first, leave it in that layout, otherwise permute layout back to [batch,
+// // channel, ho, wo]
+// // convert back to original dtype
+// dydk = dydk.to(kx_type);
+// dydv = dydv.to(vx_type);
+// dydq = dydq.to(qy_type);
+// // permute back to original layout
+// if (!kx_is_channels_last) {
+//     dydk = dydk.to(kx_type).to(at::MemoryFormat::Contiguous);
+// } else {
+//     dydk = dydk.to(kx_type);
+// }
+// if (!vx_is_channels_last) {
+//     dydv = dydv.to(vx_type).to(at::MemoryFormat::Contiguous);
+// } else {
+//     dydv = dydv.to(vx_type);
+// }
+// if (!qy_is_channels_last) {
+//     dydq = dydq.to(qy_type).to(at::MemoryFormat::Contiguous);
+// } else {
+//     dydq = dydq.to(qy_type);
+// }
+// return std::make_tuple(dydk, dydv, dydq);
+// #else
+
+    const size_t uo_num_channels = kx.size(1);
+    const int batch_size = kx.size(0);

     // extract dtype
     auto kx_type = kx.dtype();
     auto vx_type = vx.dtype();
     auto qy_type = qy.dtype();
     auto dy_type = dy.dtype();

-    // exract memory format
-    auto kx_is_channels_last = kx.is_contiguous(at::MemoryFormat::ChannelsLast);
-    auto vx_is_channels_last = vx.is_contiguous(at::MemoryFormat::ChannelsLast);
-    auto qy_is_channels_last = qy.is_contiguous(at::MemoryFormat::ChannelsLast);
-    auto dy_is_channels_last = dy.is_contiguous(at::MemoryFormat::ChannelsLast);
-    // convert to channels-last
-    auto kxP = kx.to(torch::kFloat32).to(at::MemoryFormat::ChannelsLast);
-    auto vxP = vx.to(torch::kFloat32).to(at::MemoryFormat::ChannelsLast);
-    auto qyP = qy.to(torch::kFloat32).to(at::MemoryFormat::ChannelsLast);
-    auto dyP = dy.to(torch::kFloat32).to(at::MemoryFormat::ChannelsLast);
-    // create output arrays
-    auto dydk = torch::zeros_like(qyP);
-    auto dydv = torch::zeros_like(qyP);
-    auto dydq = torch::zeros_like(qyP);
-    size_t uo_num_channels = kx.size(1);
-    const int batch_size = kx.size(0);
-    dim3 block(WARP_SIZE, THREADS / WARP_SIZE);
-    dim3 grid(DIV_UP(nlat_out * nlon_out, block.y), batch_size);
-    size_t shared_size = sizeof(float) * uo_num_channels * 5 * block.y; // 4 arrays per warp
-    cudaEvent_t start, stop;
-    float milliseconds = 0;
-    CHECK_CUDA(cudaEventCreate(&start));
-    CHECK_CUDA(cudaEventCreate(&stop));
-    CHECK_CUDA(cudaEventRecord(start, stream));
-    s2_attention_bwd_dkvq_kernel<THREADS><<<grid, block, shared_size, stream>>>(
-        uo_num_channels, nlon_in, nlat_out, nlon_out, kxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
-        vxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
-        qyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
-        dyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
-        dydk.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
-        dydv.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
-        dydq.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
-        psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
-        psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
-        quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>());
-    CHECK_CUDA(cudaEventRecord(stop, stream));
-    CHECK_CUDA(cudaEventSynchronize(stop));
-    CHECK_CUDA(cudaEventElapsedTime(&milliseconds, start, stop));
-    // [1, 256, 1, (721, 1440), (721, 1440), "equiangular", "equiangular", 1e-5, 1e-5],
-    // s2_attention_bwd_kernel_mbT execution time: 63.280128 ms
-    CHECK_CUDA(cudaEventDestroy(start));
-    CHECK_CUDA(cudaEventDestroy(stop));
-    C10_CUDA_KERNEL_LAUNCH_CHECK();
-    // Permute outputs back to memory layout given by input. if input had channels
-    // first, leave it in that layout, otherwise permute layout back to [batch,
-    // channel, ho, wo]
-    // convert back to original dtype
-    dydk = dydk.to(kx_type);
-    dydv = dydv.to(vx_type);
-    dydq = dydq.to(qy_type);
-    // permute back to original layout
-    if (!kx_is_channels_last) {
-        dydk = dydk.to(kx_type).to(at::MemoryFormat::Contiguous);
-    } else {
-        dydk = dydk.to(kx_type);
-    }
-    if (!vx_is_channels_last) {
-        dydv = dydv.to(vx_type).to(at::MemoryFormat::Contiguous);
-    } else {
-        dydv = dydv.to(vx_type);
-    }
-    if (!qy_is_channels_last) {
-        dydq = dydq.to(qy_type).to(at::MemoryFormat::Contiguous);
-    } else {
-        dydq = dydq.to(qy_type);
-    }
-    return std::make_tuple(dydk, dydv, dydq);
-#else
-    const size_t uo_num_channels = kx.size(1);
-    const int batch_size = kx.size(0);
-    torch::Tensor kxP = kx;
-    torch::Tensor vxP = vx;
-    torch::Tensor qyP = qy;
-    torch::Tensor dyP = dy;
-    auto kx_channel_first = kx.strides()[1] == 1;
-    auto vx_channel_first = vx.strides()[1] == 1;
-    auto qy_channel_first = qy.strides()[1] == 1;
-    auto dy_channel_first = dy.strides()[1] == 1;
-    if (!kx_channel_first) { kxP = permute_4D_floatT_to0231(kx, stream); }
-    if (!vx_channel_first) { vxP = permute_4D_floatT_to0231(vx, stream); }
-    if (!qy_channel_first) { qyP = permute_4D_floatT_to0231(qy, stream); }
-    if (!dy_channel_first) { dyP = permute_4D_floatT_to0231(dy, stream); }
+    torch::Tensor kxP = kx.to(torch::kFloat32);
+    torch::Tensor vxP = vx.to(torch::kFloat32);
+    torch::Tensor qyP = qy.to(torch::kFloat32);
+    torch::Tensor dyP = dy.to(torch::kFloat32);
+
+    // exract memory format: this is much safer than checking is_contiguous(at::MemoryFormat::ChannelsLast)
+    // the former fails for num_channels == 1
+    bool kx_is_channels_last = kxP.strides()[1] == 1;
+    bool vx_is_channels_last = vxP.strides()[1] == 1;
+    bool qy_is_channels_last = qyP.strides()[1] == 1;
+    bool dy_is_channels_last = dyP.strides()[1] == 1;
+
+    // transpose if required
+    if (!kx_is_channels_last) { kxP = permute_4D_to0231(kxP); }
+    if (!vx_is_channels_last) { vxP = permute_4D_to0231(vxP); }
+    if (!qy_is_channels_last) { qyP = permute_4D_to0231(qyP); }
+    if (!dy_is_channels_last) { dyP = permute_4D_to0231(dyP); }

     torch::Tensor dkxP = torch::zeros_like(kxP);
     torch::Tensor dvxP = torch::zeros_like(vxP);

@@ -1020,17 +1029,21 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda(at::Te
                          psi_row_off,
                          psi_col_idx,
                          quad_weights,
-                         dkxP, dvxP, dqyP);
+                         // out tensors
+                         dkxP, dvxP, dqyP,
+                         stream);

     torch::Tensor dkx = dkxP;
     torch::Tensor dvx = dvxP;
     torch::Tensor dqy = dqyP;
-    if (!kx_channel_first) { dkx = permute_4D_floatT_to0312(dkxP, stream); }
-    if (!vx_channel_first) { dvx = permute_4D_floatT_to0312(dvxP, stream); }
-    if (!qy_channel_first) { dqy = permute_4D_floatT_to0312(dqyP, stream); }
+    if (!kx_is_channels_last) { dkx = permute_4D_to0312(dkx); }
+    if (!vx_is_channels_last) { dvx = permute_4D_to0312(dvx); }
+    if (!qy_is_channels_last) { dqy = permute_4D_to0312(dqy); }
+
+    // convert precision back to starting
+    dkx = dkx.to(kx_type);
+    dvx = dvx.to(vx_type);
+    dqy = dqy.to(qy_type);

     return std::make_tuple(dkx, dvx, dqy);

-#endif
+// #endif
 }
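The new comment in this file ("the former fails for num_channels == 1") points at a real ambiguity in PyTorch's memory-format query: when the channel dimension has size 1, a plain contiguous tensor also reports itself as channels-last contiguous, so an is_contiguous(ChannelsLast) test cannot distinguish the two layouts. A short check, in plain PyTorch rather than repository code, showing the ambiguity and the stride-based test the commit switches to:

    import torch

    x = torch.randn(4, 1, 6, 12)   # [batch, channels, nlat, nlon] with a single channel
    # Both format queries return True, because strides of size-1 dimensions are ignored:
    print(x.is_contiguous())                                    # True
    print(x.is_contiguous(memory_format=torch.channels_last))   # True as well
    # The stride-based test used in the new code looks at the channel stride directly:
    print(x.stride(1) == 1)   # False here (stride is nlat * nlon), so the tensor is treated as channels-first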
torch_harmonics/csrc/attention/attention_fwd_cuda.cu

@@ -374,11 +374,13 @@ static void s2_attn_fwd_dispatch(int batch_size,
                          at::Tensor row_off,
                          at::Tensor col_idx,
                          at::Tensor quad_weights,
-                         at::Tensor yP)
+                         at::Tensor yP,
+                         cudaStream_t stream)
 {
     static_assert(0 == (MAX_LOCAL_ARR_LEN & (MAX_LOCAL_ARR_LEN - 1)));

-    // get stream
-    auto stream = at::cuda::getCurrentCUDAStream().stream();
-
     // sort row indices (ho-s) in descending order
     // based on (row_off[ho+1]-row_off[ho])
     at::Tensor row_idx = sortRows(nlat_out, row_off, stream);

@@ -470,32 +472,33 @@ torch::Tensor s2_attention_fwd_cuda(at::Tensor kx,
                                    int nlon_in,
                                    int nlat_out,
                                    int nlon_out) {
-    CHECK_CUDA_TENSOR(kx);
-    CHECK_CUDA_TENSOR(vx);
-    CHECK_CUDA_TENSOR(qy);
+    CHECK_CUDA_INPUT_TENSOR(kx);
+    CHECK_CUDA_INPUT_TENSOR(vx);
+    CHECK_CUDA_INPUT_TENSOR(qy);
     CHECK_CUDA_TENSOR(quad_weights);
     CHECK_CUDA_TENSOR(psi_col_idx);
     CHECK_CUDA_TENSOR(psi_row_off);

+    // TODO: check sizes
     auto stream = at::cuda::getCurrentCUDAStream().stream();

     size_t uo_num_channels = kx.size(1);
     const int batch_size = kx.size(0);

-    torch::Tensor kxP = kx;
-    torch::Tensor vxP = vx;
-    torch::Tensor qyP = qy;
-    auto k_channel_first = kx.strides()[1] == 1;
-    auto v_channel_first = vx.strides()[1] == 1;
-    auto q_channel_first = qy.strides()[1] == 1;
+    // extract dtype
+    auto qy_type = qy.dtype();
+
+    torch::Tensor kxP = kx.to(torch::kFloat32);
+    torch::Tensor vxP = vx.to(torch::kFloat32);
+    torch::Tensor qyP = qy.to(torch::kFloat32);
+
+    // these are much safer than checking is_contiguous(at::MemoryFormat::ChannelsLast)
+    // the former fails for num_channels == 1
+    bool kx_is_channels_last = kxP.strides()[1] == 1;
+    bool vx_is_channels_last = vxP.strides()[1] == 1;
+    bool qy_is_channels_last = qyP.strides()[1] == 1;

-    if (!k_channel_first) { kxP = permute_4D_floatT_to0231(kx, stream); }
-    if (!v_channel_first) { vxP = permute_4D_floatT_to0231(vx, stream); }
-    if (!q_channel_first) { qyP = permute_4D_floatT_to0231(qy, stream); }
+    if (!kx_is_channels_last) { kxP = permute_4D_to0231(kxP); }
+    if (!vx_is_channels_last) { vxP = permute_4D_to0231(vxP); }
+    if (!qy_is_channels_last) { qyP = permute_4D_to0231(qyP); }

     torch::Tensor yP = torch::empty_like(qyP);

@@ -508,11 +511,13 @@ torch::Tensor s2_attention_fwd_cuda(at::Tensor kx,
                         psi_row_off,
                         psi_col_idx,
                         quad_weights,
-                        yP);
+                        // out tensor
+                        yP,
+                        stream);

     torch::Tensor y = yP;
-    if (!q_channel_first) { y = permute_4D_floatT_to0312(yP, stream); }
+    if (!qy_is_channels_last) { y = permute_4D_to0312(y); }
+
+    // convert precision back to starting
+    y = y.to(qy_type);

     C10_CUDA_KERNEL_LAUNCH_CHECK();
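In PyTorch index terms, the 0231 permutation used here reorders [batch, channel, nlat, nlon] data into [batch, nlat, nlon, channel], and the 0312 permutation is its inverse. A quick equivalence check in plain PyTorch (not repository code; the CUDA kernels produce the same values via a tiled shared-memory transpose):

    import torch

    x = torch.randn(2, 8, 6, 12)                  # [batch, channel, nlat, nlon]
    nhwc = x.permute(0, 2, 3, 1).contiguous()     # what the 0231 permutation produces
    back = nhwc.permute(0, 3, 1, 2).contiguous()  # what the 0312 permutation produces
    print(torch.equal(x, back))                   # True: the two permutations are inverses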
torch_harmonics/csrc/attention/attention_utils.cu

@@ -111,66 +111,6 @@ at::Tensor sortRows(int nlat_out, at::Tensor row_off, cudaStream_t stream) {
 // BEGIN - 4D tensor permutation kernels and functions

-template<int BDIM_X, int BDIM_Y, typename VAL_T>
-__global__ __launch_bounds__(BDIM_X * BDIM_Y)
-void permute_to0231_k(const int nchn,
-                      const int nlat,
-                      const int nlon,
-                      const torch::PackedTensorAccessor32<VAL_T, 4, torch::RestrictPtrTraits> src,
-                      torch::PackedTensorAccessor32<VAL_T, 4, torch::RestrictPtrTraits> dst)
-{
-    static_assert(!(BDIM_X & (BDIM_X - 1)));
-    static_assert(!(BDIM_Y & (BDIM_Y - 1)));
-    static_assert(BDIM_X >= BDIM_Y);
-
-    __shared__ VAL_T sh[BDIM_X][BDIM_X + 1];
-
-    const int tidx = threadIdx.x;
-    const int tidy = threadIdx.y;
-
-    const int coff = blockIdx.x * BDIM_X;        // channel offset
-    const int woff = blockIdx.y * BDIM_X;        // width offset
-    const int batch = blockIdx.z / nlat;         // batch  (same for all block)
-    const int h = blockIdx.z - (batch * nlat);   // height (same for all block)
-
-    const int nchn_full = (nchn - coff) >= BDIM_X;
-    const int nlon_full = (nlon - woff) >= BDIM_X;
-
-    if (nchn_full && nlon_full) {
-        #pragma unroll
-        for (int j = 0; j < BDIM_X; j += BDIM_Y) {
-            sh[j + tidy][tidx] = src[batch][coff + j + tidy][h][woff + tidx];
-        }
-        __syncthreads();
-
-        #pragma unroll
-        for (int j = 0; j < BDIM_X; j += BDIM_Y) {
-            dst[batch][h][woff + j + tidy][coff + tidx] = sh[tidx][j + tidy];
-        }
-    } else {
-        if (woff + tidx < nlon) {
-            #pragma unroll
-            for (int j = 0; j < BDIM_X; j += BDIM_Y) {
-                sh[j + tidy][tidx] = (coff + j + tidy < nchn) ? src[batch][coff + j + tidy][h][woff + tidx] : 0.f;
-            }
-        }
-        __syncthreads();
-
-        if (coff + tidx < nchn) {
-            #pragma unroll
-            for (int j = 0; j < BDIM_X; j += BDIM_Y) {
-                if (woff + j + tidy < nlon) {
-                    dst[batch][h][woff + j + tidy][coff + tidx] = sh[tidx][j + tidy];
-                }
-            }
-        }
-    }
-    return;
-}
-
 __global__ void empty_k() {}

 static int getPtxver() {

@@ -179,144 +119,96 @@ static int getPtxver() {
     return attrs.ptxVersion * 10;
 }

-at::Tensor permute_4D_floatT_to0231(at::Tensor src, cudaStream_t stream)
-{
-    dim3 block;
-    dim3 grid;
-
-    block.x = WARP_SIZE;
-    grid.x = DIV_UP(src.size(1), block.x);
-    grid.y = DIV_UP(src.size(3), block.x);
-    grid.z = src.size(2) * src.size(0);
-
-    assert(grid.y < 65536);
-    assert(grid.z < 65536);
-
-    auto options = torch::TensorOptions().dtype(torch::kFloat32).device(src.device());
+at::Tensor permute_4D_to0231(at::Tensor src)
+{
+    //            dim3 block;
+    //            dim3 grid;
+    //            block.x = WARP_SIZE;
+    //            grid.x = DIV_UP(src.size(1), block.x);
+    //            grid.y = DIV_UP(src.size(3), block.x);
+    //            grid.z = src.size(2)*src.size(0);
+    //            assert(grid.y < 65536);
+    //            assert(grid.z < 65536);
+
+    auto options = torch::TensorOptions().dtype(src.dtype()).device(src.device());
     torch::Tensor dst = torch::empty({src.size(0), src.size(2), src.size(3), src.size(1)}, options);

     const int ptxv = getPtxver();

     // to be further specialized for additional archs, if necessary
     if (ptxv < 100) {
-        block.y = TRANSP_WARPS_X_TILE_GENERIC;
-        permute_to0231_k<WARP_SIZE, TRANSP_WARPS_X_TILE_GENERIC>
-            <<<grid, block, 0, stream>>>(src.size(1),
-                                         src.size(2),
-                                         src.size(3),
-                                         src.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
-                                         dst.packed_accessor32<float, 4, torch::RestrictPtrTraits>());
+        AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0231_k_tile_generic", ([&] {
+            launch_permute_to0231<TRANSP_WARPS_X_TILE_GENERIC, scalar_t>(src, dst);
+        }));
+        //block.y = TRANSP_WARPS_X_TILE_GENERIC;
+        //permute_to0231_k<WARP_SIZE, TRANSP_WARPS_X_TILE_GENERIC>
+        // <<<grid, block, 0, stream>>>(src.size(1),
+        // src.size(2),
+        // src.size(3),
+        // src.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
+        // dst.packed_accessor32<float, 4, torch::RestrictPtrTraits>());
         CHECK_ERROR("permute_to0231_k_tile_generic");
     } else {
-        block.y = TRANSP_WARPS_X_TILE_SM100;
-        permute_to0231_k<WARP_SIZE, TRANSP_WARPS_X_TILE_SM100>
-            <<<grid, block, 0, stream>>>(src.size(1),
-                                         src.size(2),
-                                         src.size(3),
-                                         src.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
-                                         dst.packed_accessor32<float, 4, torch::RestrictPtrTraits>());
+        AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0231_k_tile_sm100", ([&] {
+            launch_permute_to0231<TRANSP_WARPS_X_TILE_SM100, scalar_t>(src, dst);
+        }));
+        //block.y = TRANSP_WARPS_X_TILE_SM100;
+        //permute_to0231_k<WARP_SIZE, TRANSP_WARPS_X_TILE_SM100>
+        // <<<grid, block, 0, stream>>>(src.size(1),
+        // src.size(2),
+        // src.size(3),
+        // src.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
+        // dst.packed_accessor32<float, 4, torch::RestrictPtrTraits>());
         CHECK_ERROR("permute_to0231_k_tile_sm100");
     }

     return dst;
 }

-template<int BDIM_X, int BDIM_Y, typename VAL_T>
-__global__ __launch_bounds__(BDIM_X * BDIM_Y)
-void permute_to0312_k(const int nchn,
-                      const int nlat,
-                      const int nlon,
-                      const torch::PackedTensorAccessor32<VAL_T, 4, torch::RestrictPtrTraits> src,
-                      torch::PackedTensorAccessor32<VAL_T, 4, torch::RestrictPtrTraits> dst)
-{
-    static_assert(!(BDIM_X & (BDIM_X - 1)));
-    static_assert(!(BDIM_Y & (BDIM_Y - 1)));
-    static_assert(BDIM_X >= BDIM_Y);
-
-    __shared__ VAL_T sh[BDIM_X][BDIM_X + 1];
-
-    const int tidx = threadIdx.x;
-    const int tidy = threadIdx.y;
-
-    const int woff = blockIdx.x * BDIM_X;        // width offset
-    const int coff = blockIdx.y * BDIM_X;        // channel offset
-    const int batch = blockIdx.z / nlat;         // batch  (same for all block)
-    const int h = blockIdx.z - (batch * nlat);   // height (same for all block)
-
-    const int nchn_full = (nchn - coff) >= BDIM_X;
-    const int nlon_full = (nlon - woff) >= BDIM_X;
-
-    if (nchn_full && nlon_full) {
-        #pragma unroll
-        for (int j = 0; j < BDIM_X; j += BDIM_Y) {
-            sh[j + tidy][tidx] = src[batch][h][woff + j + tidy][coff + tidx];
-        }
-        __syncthreads();
-
-        #pragma unroll
-        for (int j = 0; j < BDIM_X; j += BDIM_Y) {
-            dst[batch][coff + j + tidy][h][woff + tidx] = sh[tidx][j + tidy];
-        }
-    } else {
-        if (coff + tidx < nchn) {
-            #pragma unroll
-            for (int j = 0; j < BDIM_X; j += BDIM_Y) {
-                sh[j + tidy][tidx] = (woff + j + tidy < nlon) ? src[batch][h][woff + j + tidy][coff + tidx] : 0.f;
-            }
-        }
-        __syncthreads();
-
-        if (woff + tidx < nlon) {
-            #pragma unroll
-            for (int j = 0; j < BDIM_X; j += BDIM_Y) {
-                if (coff + j + tidy < nchn) {
-                    dst[batch][coff + j + tidy][h][woff + tidx] = sh[tidx][j + tidy];
-                }
-            }
-        }
-    }
-    return;
-}
-
-at::Tensor permute_4D_floatT_to0312(at::Tensor src, cudaStream_t stream)
-{
-    dim3 block;
-    dim3 grid;
-
-    block.x = WARP_SIZE;
-    grid.x = DIV_UP(src.size(2), block.x);
-    grid.y = DIV_UP(src.size(3), block.x);
-    grid.z = src.size(1) * src.size(0);
-
-    assert(grid.y < 65536);
-    assert(grid.z < 65536);
-
-    auto options = torch::TensorOptions().dtype(torch::kFloat32).device(src.device());
+at::Tensor permute_4D_to0312(at::Tensor src)
+{
+    //            dim3 block;
+    //            dim3 grid;
+    //            block.x = WARP_SIZE;
+    //            grid.x = DIV_UP(src.size(2), block.x);
+    //            grid.y = DIV_UP(src.size(3), block.x);
+    //            grid.z = src.size(1)*src.size(0);
+    //            assert(grid.y < 65536);
+    //            assert(grid.z < 65536);
+
+    auto options = torch::TensorOptions().dtype(src.dtype()).device(src.device());
     torch::Tensor dst = torch::empty({src.size(0), src.size(3), src.size(1), src.size(2)}, options);

     const int ptxv = getPtxver();

     // to be further specialized for additional archs, if necessary
     if (ptxv < 100) {
-        block.y = TRANSP_WARPS_X_TILE_GENERIC;
-        permute_to0312_k<WARP_SIZE, TRANSP_WARPS_X_TILE_GENERIC>
-            <<<grid, block, 0, stream>>>(src.size(3),
-                                         src.size(1),
-                                         src.size(2),
-                                         src.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
-                                         dst.packed_accessor32<float, 4, torch::RestrictPtrTraits>());
+        //block.y = TRANSP_WARPS_X_TILE_GENERIC;
+        AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0312_k_tile_generic", ([&] {
+            launch_permute_to0312<TRANSP_WARPS_X_TILE_GENERIC, scalar_t>(src, dst);
+        }));
+        //permute_to0312_k<WARP_SIZE, TRANSP_WARPS_X_TILE_GENERIC>
+        // <<<grid, block, 0, stream>>>(src.size(3),
+        // src.size(1),
+        // src.size(2),
+        // src.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
+        // dst.packed_accessor32<float, 4, torch::RestrictPtrTraits>());
         CHECK_ERROR("permute_to0312_k_tile_generic");
     } else {
-        block.y = TRANSP_WARPS_X_TILE_SM100;
-        permute_to0312_k<WARP_SIZE, TRANSP_WARPS_X_TILE_SM100>
-            <<<grid, block, 0, stream>>>(src.size(3),
-                                         src.size(1),
-                                         src.size(2),
-                                         src.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
-                                         dst.packed_accessor32<float, 4, torch::RestrictPtrTraits>());
+        //block.y = TRANSP_WARPS_X_TILE_SM100;
+        AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0312_k_tile_sm100", ([&] {
+            launch_permute_to0312<TRANSP_WARPS_X_TILE_SM100, scalar_t>(src, dst);
+        }));
+        //permute_to0312_k<WARP_SIZE, TRANSP_WARPS_X_TILE_SM100>
+        // <<<grid, block, 0, stream>>>(src.size(3),
+        // src.size(1),
+        // src.size(2),
+        // src.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
+        // dst.packed_accessor32<float, 4, torch::RestrictPtrTraits>());
         CHECK_ERROR("permute_to0312_k_tile_sm100");
     }
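With the launches now wrapped in AT_DISPATCH_FLOATING_TYPES and the output allocated with src.dtype(), the permutation is no longer tied to float32. A small dtype-generic reference check in plain PyTorch (not repository code) of the index relation dst[b, h, w, c] = src[b, c, h, w] that the 0231 kernel implements:

    import torch

    def reference_0231(src):
        b, c, h, w = src.shape
        dst = torch.empty(b, h, w, c, dtype=src.dtype)
        for bi in range(b):
            for ci in range(c):
                dst[bi, :, :, ci] = src[bi, ci]   # dst[b, h, w, c] = src[b, c, h, w]
        return dst

    for dtype in (torch.float32, torch.float64):  # the dtypes covered by AT_DISPATCH_FLOATING_TYPES
        x = torch.randn(2, 3, 4, 5, dtype=dtype)
        assert torch.equal(reference_0231(x), x.permute(0, 2, 3, 1).contiguous())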
torch_harmonics/csrc/attention/attention_utils.cuh

@@ -40,8 +40,8 @@
 at::Tensor sortRows(int nlat_out, at::Tensor row_off, cudaStream_t stream);

 // 4D tensor permutation kernels and functions
-at::Tensor permute_4D_floatT_to0231(at::Tensor src, cudaStream_t stream);
-at::Tensor permute_4D_floatT_to0312(at::Tensor src, cudaStream_t stream);
+at::Tensor permute_4D_to0231(at::Tensor src);
+at::Tensor permute_4D_to0312(at::Tensor src);

 // Host tensor dump and CSR manipulation functions
 void dump_tensor(const char *fname, at::Tensor t);

@@ -200,3 +200,174 @@ __device__ VAL_T __block_sum(VAL_T val) {
     }
     return val;
 }
+
+// transpose utils
+template<int BDIM_X, int BDIM_Y, typename VAL_T>
+__global__ __launch_bounds__(BDIM_X * BDIM_Y)
+void permute_to0231_k(const int nchn,
+                      const int nlat,
+                      const int nlon,
+                      const at::PackedTensorAccessor32<VAL_T, 4, at::RestrictPtrTraits> src,
+                      at::PackedTensorAccessor32<VAL_T, 4, at::RestrictPtrTraits> dst)
+{
+    static_assert(!(BDIM_X & (BDIM_X - 1)));
+    static_assert(!(BDIM_Y & (BDIM_Y - 1)));
+    static_assert(BDIM_X >= BDIM_Y);
+
+    __shared__ VAL_T sh[BDIM_X][BDIM_X + 1];
+
+    const int tidx = threadIdx.x;
+    const int tidy = threadIdx.y;
+
+    const int coff = blockIdx.x * BDIM_X;        // channel offset
+    const int woff = blockIdx.y * BDIM_X;        // width offset
+    const int batch = blockIdx.z / nlat;         // batch  (same for all block)
+    const int h = blockIdx.z - (batch * nlat);   // height (same for all block)
+
+    const int nchn_full = (nchn - coff) >= BDIM_X;
+    const int nlon_full = (nlon - woff) >= BDIM_X;
+
+    if (nchn_full && nlon_full) {
+        #pragma unroll
+        for (int j = 0; j < BDIM_X; j += BDIM_Y) {
+            sh[j + tidy][tidx] = src[batch][coff + j + tidy][h][woff + tidx];
+        }
+        __syncthreads();
+
+        #pragma unroll
+        for (int j = 0; j < BDIM_X; j += BDIM_Y) {
+            dst[batch][h][woff + j + tidy][coff + tidx] = sh[tidx][j + tidy];
+        }
+    } else {
+        if (woff + tidx < nlon) {
+            #pragma unroll
+            for (int j = 0; j < BDIM_X; j += BDIM_Y) {
+                sh[j + tidy][tidx] = (coff + j + tidy < nchn) ? src[batch][coff + j + tidy][h][woff + tidx] : VAL_T(0);
+            }
+        }
+        __syncthreads();
+
+        if (coff + tidx < nchn) {
+            #pragma unroll
+            for (int j = 0; j < BDIM_X; j += BDIM_Y) {
+                if (woff + j + tidy < nlon) {
+                    dst[batch][h][woff + j + tidy][coff + tidx] = sh[tidx][j + tidy];
+                }
+            }
+        }
+    }
+    return;
+}
+
+template<int TRANSP_WARPS_X_TILE_SIZE, typename VAL_T>
+void launch_permute_to0231(at::Tensor src, at::Tensor dst) {
+    dim3 block;
+    dim3 grid;
+
+    block.x = WARP_SIZE;
+    block.y = TRANSP_WARPS_X_TILE_SIZE;
+    grid.x = DIV_UP(src.size(1), block.x);
+    grid.y = DIV_UP(src.size(3), block.x);
+    grid.z = src.size(2) * src.size(0);
+
+    assert(grid.y < 65536);
+    assert(grid.z < 65536);
+
+    // get stream
+    auto stream = at::cuda::getCurrentCUDAStream().stream();
+
+    permute_to0231_k<WARP_SIZE, TRANSP_WARPS_X_TILE_SIZE>
+        <<<grid, block, 0, stream>>>(src.size(1),
+                                     src.size(2),
+                                     src.size(3),
+                                     src.packed_accessor32<VAL_T, 4, at::RestrictPtrTraits>(),
+                                     dst.packed_accessor32<VAL_T, 4, at::RestrictPtrTraits>());
+}
+
+template<int BDIM_X, int BDIM_Y, typename VAL_T>
+__global__ __launch_bounds__(BDIM_X * BDIM_Y)
+void permute_to0312_k(const int nchn,
+                      const int nlat,
+                      const int nlon,
+                      const at::PackedTensorAccessor32<VAL_T, 4, at::RestrictPtrTraits> src,
+                      at::PackedTensorAccessor32<VAL_T, 4, at::RestrictPtrTraits> dst)
+{
+    static_assert(!(BDIM_X & (BDIM_X - 1)));
+    static_assert(!(BDIM_Y & (BDIM_Y - 1)));
+    static_assert(BDIM_X >= BDIM_Y);
+
+    __shared__ VAL_T sh[BDIM_X][BDIM_X + 1];
+
+    const int tidx = threadIdx.x;
+    const int tidy = threadIdx.y;
+
+    const int woff = blockIdx.x * BDIM_X;        // width offset
+    const int coff = blockIdx.y * BDIM_X;        // channel offset
+    const int batch = blockIdx.z / nlat;         // batch  (same for all block)
+    const int h = blockIdx.z - (batch * nlat);   // height (same for all block)
+
+    const int nchn_full = (nchn - coff) >= BDIM_X;
+    const int nlon_full = (nlon - woff) >= BDIM_X;
+
+    if (nchn_full && nlon_full) {
+        #pragma unroll
+        for (int j = 0; j < BDIM_X; j += BDIM_Y) {
+            sh[j + tidy][tidx] = src[batch][h][woff + j + tidy][coff + tidx];
+        }
+        __syncthreads();
+
+        #pragma unroll
+        for (int j = 0; j < BDIM_X; j += BDIM_Y) {
+            dst[batch][coff + j + tidy][h][woff + tidx] = sh[tidx][j + tidy];
+        }
+    } else {
+        if (coff + tidx < nchn) {
+            #pragma unroll
+            for (int j = 0; j < BDIM_X; j += BDIM_Y) {
+                sh[j + tidy][tidx] = (woff + j + tidy < nlon) ? src[batch][h][woff + j + tidy][coff + tidx] : VAL_T(0);
+            }
+        }
+        __syncthreads();
+
+        if (woff + tidx < nlon) {
+            #pragma unroll
+            for (int j = 0; j < BDIM_X; j += BDIM_Y) {
+                if (coff + j + tidy < nchn) {
+                    dst[batch][coff + j + tidy][h][woff + tidx] = sh[tidx][j + tidy];
+                }
+            }
+        }
+    }
+    return;
+}
+
+template<int TRANSP_WARPS_X_TILE_SIZE, typename VAL_T>
+void launch_permute_to0312(at::Tensor src, at::Tensor dst) {
+    dim3 block;
+    dim3 grid;
+
+    block.x = WARP_SIZE;
+    block.y = TRANSP_WARPS_X_TILE_SIZE;
+    grid.x = DIV_UP(src.size(2), block.x);
+    grid.y = DIV_UP(src.size(3), block.x);
+    grid.z = src.size(1) * src.size(0);
+
+    assert(grid.y < 65536);
+    assert(grid.z < 65536);
+
+    // get stream
+    auto stream = at::cuda::getCurrentCUDAStream().stream();
+
+    permute_to0312_k<WARP_SIZE, TRANSP_WARPS_X_TILE_SIZE>
+        <<<grid, block, 0, stream>>>(src.size(3),
+                                     src.size(1),
+                                     src.size(2),
+                                     src.packed_accessor32<VAL_T, 4, at::RestrictPtrTraits>(),
+                                     dst.packed_accessor32<VAL_T, 4, at::RestrictPtrTraits>());
+}