OpenDAS / torch-harmonics · Commits

Commit 4aaff021 (unverified)
Authored Jul 17, 2025 by Thorsten Kurth; committed by GitHub on Jul 17, 2025
Parents: ab44ba59, fa58767d

    Merge pull request #91 from NVIDIA/maurob/devel

    Attention Backward improvement
Showing 10 changed files with 1518 additions and 656 deletions (+1518 −656).
Files changed:

  setup.py                                                +2    −0
  tests/test_attention.py                                 +2    −3
  torch_harmonics/_neighborhood_attention.py              +15   −0
  torch_harmonics/csrc/attention/attention.cuh            +5    −1
  torch_harmonics/csrc/attention/attention_bwd_cuda.cu    +753  −115
  torch_harmonics/csrc/attention/attention_fwd_cuda.cu    +132  −532
  torch_harmonics/csrc/attention/attention_utils.cu       +180  −0    (new)
  torch_harmonics/csrc/attention/attention_utils.cuh      +373  −0    (new)
  torch_harmonics/csrc/attention/cudamacro.h              +47   −0    (new)
  torch_harmonics/examples/losses.py                      +9    −5
setup.py

Profiling builds now also pass -Xptxas=-v to nvcc (which makes ptxas report per-kernel register and shared-memory usage), and the new attention_utils.cu is added to the CUDA extension sources:

@@ -61,6 +61,7 @@ def get_compile_args(module_name):
     nvcc_extra_flags = []
     if profile_mode:
         nvcc_extra_flags.append("-lineinfo")
+        nvcc_extra_flags.append("-Xptxas=-v")
     if debug_mode:
         print(f"WARNING: Compiling {module_name} with debugging flags")

@@ -102,6 +103,7 @@ def get_ext_modules():
         CUDAExtension(
             name="attention_cuda_extension",
             sources=[
+                "torch_harmonics/csrc/attention/attention_utils.cu",
                 "torch_harmonics/csrc/attention/attention_fwd_cuda.cu",
                 "torch_harmonics/csrc/attention/attention_bwd_cuda.cu",
                 "torch_harmonics/csrc/attention/attention_interface.cu",
tests/test_attention.py

Two hunks adjust the parameterized test configurations; rows follow the format [batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol]:

@@ -78,7 +78,8 @@ class TestNeighborhoodAttentionS2(unittest.TestCase):
             [4, 4, 1, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
             [4, 4, 2, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
             [4, 4, 4, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
-            [4, 4, 1, (6, 12), (6, 12), "legendre-gauss", "legendre-gauss", 1e-5, 1e-3],
+            [4, 1, 1, (2, 4), (2, 4), "equiangular", "equiangular", 1e-5, 1e-3],
+            [4, 4, 4, (6, 12), (6, 12), "legendre-gauss", "legendre-gauss", 1e-5, 1e-3],
             [4, 4, 1, (6, 12), (6, 12), "lobatto", "lobatto", 1e-5, 1e-3],
         ],
         skip_on_empty=True,

@@ -156,8 +157,6 @@ class TestNeighborhoodAttentionS2(unittest.TestCase):
         [
             # Format: [batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol]
             [4, 4, 1, (6, 12), (6, 12), "equiangular", "equiangular", 1e-2, 0],
-            # [4, 4, 2, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
-            # [4, 4, 4, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
             [4, 4, 1, (6, 12), (6, 12), "legendre-gauss", "legendre-gauss", 1e-2, 0],
             [4, 4, 1, (6, 12), (6, 12), "lobatto", "lobatto", 1e-2, 0],
         ],
torch_harmonics/_neighborhood_attention.py

The CUDA backward path now saves the input dtypes, upcasts the projected k/v/q tensors and grad_output to float32 before calling the extension, and casts the resulting gradients back afterwards:

@@ -520,6 +520,16 @@ class _NeighborhoodAttentionS2Cuda(torch.autograd.Function):
         B, _, H, W = grad_output.shape
         grad_output = grad_output.reshape(B * nh, -1, H, W)

+        # save type and convert to float32
+        kw_dtype = kw.dtype
+        vw_dtype = vw.dtype
+        qw_dtype = qw.dtype
+        kw = kw.to(torch.float32).contiguous()
+        vw = vw.to(torch.float32).contiguous()
+        qw = qw.to(torch.float32).contiguous()
+        grad_output = grad_output.to(torch.float32).contiguous()
+
         dkw, dvw, dqw = attention_cuda_extension.backward_dkvq(kw, vw, qw, grad_output,
             quad_weights,
             col_idx, row_off,

@@ -533,6 +543,11 @@ class _NeighborhoodAttentionS2Cuda(torch.autograd.Function):
         _, C, H, W = dqw.shape
         dqw = dqw.reshape(B, -1, H, W)

+        # convert precision
+        dkw = dkw.to(dtype=kw_dtype)
+        dvw = dvw.to(dtype=vw_dtype)
+        dqw = dqw.to(dtype=qw_dtype)
+
         # input grads
         dv = torch.nn.functional.conv2d(dvw, weight=wv.permute([1, 0, 2, 3]), bias=None)
         dk = torch.nn.functional.conv2d(dkw, weight=wk.permute([1, 0, 2, 3]), bias=None)
torch_harmonics/csrc/attention/attention.cuh

The device check switches from TORCH_CHECK to TORCH_INTERNAL_ASSERT, and a contiguity check (accepting either standard-contiguous or channels-last layouts) is added; both are combined into a single input-validation macro:

@@ -34,7 +34,11 @@
 #include <cstdint>
 #include <torch/torch.h>

-#define CHECK_CUDA_TENSOR(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CUDA_TENSOR(x) TORCH_INTERNAL_ASSERT(x.device().type() == torch::kCUDA)
+#define CHECK_CONTIGUOUS_TENSOR(x) TORCH_INTERNAL_ASSERT(x.is_contiguous() || x.is_contiguous(at::MemoryFormat::ChannelsLast))
+#define CHECK_CUDA_INPUT_TENSOR(x) \
+    CHECK_CUDA_TENSOR(x);          \
+    CHECK_CONTIGUOUS_TENSOR(x)

 torch::Tensor s2_attention_fwd_cuda(at::Tensor kx, at::Tensor vx, at::Tensor qy, at::Tensor quad_weights,
                                     at::Tensor psi_col_idx, at::Tensor psi_row_off, int nlon_in, int nlat_out,
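As a hedged illustration (the function name below is hypothetical, not part of this diff), the combined macro is intended to guard extension entry points before raw data pointers are handed to kernels:

// Minimal usage sketch for the new validation macros in attention.cuh:
// the tensor must live on a CUDA device and be contiguous in either the
// standard or the channels-last memory format.
#include <torch/torch.h>
#include "attention.cuh"

at::Tensor example_entry(at::Tensor x)
{
    CHECK_CUDA_INPUT_TENSOR(x);
    return x;
}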
...
torch_harmonics/csrc/attention/attention_bwd_cuda.cu
View file @
4aaff021
This diff is collapsed.
Click to expand it.
torch_harmonics/csrc/attention/attention_fwd_cuda.cu
View file @
4aaff021
This diff is collapsed.
Click to expand it.
torch_harmonics/csrc/attention/attention_utils.cu (new file, mode 100644)
// coding=utf-8
//
// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "attention.cuh"
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/CUDAUtils.h>
#include <cuda_runtime.h>
#include <cub/cub.cuh>
#include <limits>
#include "cudamacro.h"
#include "attention_utils.cuh"
#define THREADS (64)
#define TRANSP_WARPS_X_TILE_GENERIC (32)
#define TRANSP_WARPS_X_TILE_SM100 (4)
// BEGIN - CSR rows sorting kernels and functions
__global__
void
set_rlen_rids_k
(
const
int
n
,
const
int64_t
*
__restrict__
offs
,
int
*
__restrict__
rids
,
int
*
__restrict__
rlen
)
{
const
int
nth
=
gridDim
.
x
*
blockDim
.
x
;
const
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
for
(
int
i
=
tid
;
i
<
n
;
i
+=
nth
)
{
rids
[
i
]
=
i
;
rlen
[
i
]
=
offs
[
i
+
1
]
-
offs
[
i
];
}
return
;
}
at::Tensor sortRows(int nlat_out, at::Tensor row_off, cudaStream_t stream)
{
    int64_t *_row_off_d = reinterpret_cast<int64_t *>(row_off.data_ptr());

    auto options = torch::TensorOptions().dtype(torch::kInt32).device(row_off.device());

    torch::Tensor rids_d = torch::empty({nlat_out}, options);
    torch::Tensor rlen_d = torch::empty({nlat_out}, options);

    int *_rids_d = reinterpret_cast<int *>(rids_d.data_ptr());
    int *_rlen_d = reinterpret_cast<int *>(rlen_d.data_ptr());

    const int grid = DIV_UP(nlat_out, THREADS);
    const int block = THREADS;

    set_rlen_rids_k<<<grid, block, 0, stream>>>(nlat_out, _row_off_d, _rids_d, _rlen_d);

    torch::Tensor rids_sort_d = torch::empty({nlat_out}, options);
    torch::Tensor rlen_sort_d = torch::empty({nlat_out}, options);

    int *_rids_sort_d = reinterpret_cast<int *>(rids_sort_d.data_ptr());
    int *_rlen_sort_d = reinterpret_cast<int *>(rlen_sort_d.data_ptr());

    // first CUB call only computes the required temporary-storage size
    size_t temp_storage_bytes = 0;
    CHECK_CUDA(cub::DeviceRadixSort::SortPairsDescending(NULL, temp_storage_bytes,
                                                         _rlen_d, _rlen_sort_d,
                                                         _rids_d, _rids_sort_d,
                                                         nlat_out, 0, sizeof(*_rlen_d) * 8, stream));

    options = torch::TensorOptions().dtype(torch::kByte).device(row_off.device());
    torch::Tensor temp_storage_d = torch::empty({int64_t(temp_storage_bytes)}, options);
    void *_temp_storage_d = reinterpret_cast<void *>(temp_storage_d.data_ptr());

    // second call sorts row lengths in descending order, carrying row ids along
    CHECK_CUDA(cub::DeviceRadixSort::SortPairsDescending(_temp_storage_d, temp_storage_bytes,
                                                         _rlen_d, _rlen_sort_d,
                                                         _rids_d, _rids_sort_d,
                                                         nlat_out, 0, sizeof(*_rlen_d) * 8, stream));

    return rids_sort_d;
}
// END - CSR rows sorting kernels and functions
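sortRows follows the standard two-phase CUB pattern: a first call with a null workspace pointer only computes the temporary-storage requirement, and a second call performs the sort. A minimal self-contained sketch of that pattern (the helper name sort_pairs_desc is hypothetical; cudaMallocAsync/cudaFreeAsync need CUDA 11.2+):

// Two-phase cub::DeviceRadixSort call, as used by sortRows above.
#include <cub/cub.cuh>
#include <cuda_runtime.h>

void sort_pairs_desc(int *d_keys_in, int *d_keys_out,
                     int *d_vals_in, int *d_vals_out,
                     int n, cudaStream_t stream)
{
    size_t temp_bytes = 0;
    // phase 1: null workspace -> only temp_bytes is written
    cub::DeviceRadixSort::SortPairsDescending(nullptr, temp_bytes,
                                              d_keys_in, d_keys_out, d_vals_in, d_vals_out,
                                              n, 0, sizeof(int) * 8, stream);
    void *d_temp = nullptr;
    cudaMallocAsync(&d_temp, temp_bytes, stream);
    // phase 2: the actual descending key/value sort
    cub::DeviceRadixSort::SortPairsDescending(d_temp, temp_bytes,
                                              d_keys_in, d_keys_out, d_vals_in, d_vals_out,
                                              n, 0, sizeof(int) * 8, stream);
    cudaFreeAsync(d_temp, stream);
}

In sortRows the keys are CSR row lengths and the values are row indices, so the returned tensor lists rows from longest to shortest, presumably so that the heaviest rows are scheduled first.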
// BEGIN - 4D tensor permutation kernels and functions
__global__ void empty_k() {}

// dummy kernel used only to query the PTX version the extension was built for
static int getPtxver()
{
    cudaFuncAttributes attrs;
    CHECK_CUDA(cudaFuncGetAttributes(&attrs, empty_k));
    return attrs.ptxVersion * 10;
}

at::Tensor permute_4D_to0231(at::Tensor src)
{
    auto options = torch::TensorOptions().dtype(src.dtype()).device(src.device());
    torch::Tensor dst = torch::empty({src.size(0), src.size(2), src.size(3), src.size(1)}, options);

    const int ptxv = getPtxver();

    // to be further specialized for additional archs, if necessary
    if (ptxv < 100) {
        AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0231_k_tile_generic",
                                   ([&] { launch_permute_to0231<TRANSP_WARPS_X_TILE_GENERIC, scalar_t>(src, dst); }));
        CHECK_ERROR("permute_to0231_k_tile_generic");
    } else {
        AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0231_k_tile_sm100",
                                   ([&] { launch_permute_to0231<TRANSP_WARPS_X_TILE_SM100, scalar_t>(src, dst); }));
        CHECK_ERROR("permute_to0231_k_tile_sm100");
    }
    return dst;
}

at::Tensor permute_4D_to0312(at::Tensor src)
{
    auto options = torch::TensorOptions().dtype(src.dtype()).device(src.device());
    torch::Tensor dst = torch::empty({src.size(0), src.size(3), src.size(1), src.size(2)}, options);

    const int ptxv = getPtxver();

    // to be further specialized for additional archs, if necessary
    if (ptxv < 100) {
        AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0312_k_tile_generic",
                                   ([&] { launch_permute_to0312<TRANSP_WARPS_X_TILE_GENERIC, scalar_t>(src, dst); }));
        CHECK_ERROR("permute_to0312_k_tile_generic");
    } else {
        AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0312_k_tile_sm100",
                                   ([&] { launch_permute_to0312<TRANSP_WARPS_X_TILE_SM100, scalar_t>(src, dst); }));
        CHECK_ERROR("permute_to0312_k_tile_sm100");
    }
    return dst;
}
// END - tensor permutation kernels and functions
// BEGIN - general host-side functions
unsigned int next_pow2(unsigned int x)
{
    x -= 1;

#pragma unroll
    for (int i = 1; i <= sizeof(x) * 8 / 2; i *= 2) {
        x |= x >> i;
    }
    return x + 1;
}
// END - general host-side functions
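next_pow2 rounds an integer up to the nearest power of two by smearing the most significant set bit into every lower position and adding one. A quick check, assuming this sketch is compiled against attention_utils.cuh and linked with attention_utils.cu:

#include <cassert>
#include "attention_utils.cuh"

int main()
{
    assert(next_pow2(1) == 1);    // 0 smears to 0, +1 -> 1
    assert(next_pow2(37) == 64);  // 36 = 0b100100 smears to 0b111111, +1 -> 64
    assert(next_pow2(64) == 64);  // powers of two map to themselves
    return 0;
}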
torch_harmonics/csrc/attention/attention_utils.cuh (new file, mode 100644)
// coding=utf-8
//
// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once

#include <ATen/ATen.h>

#define WARP_SIZE (32)
#define FULL_MASK (0xFFFFFFFF)

#define DIV_UP(a,b) (((a)+((b)-1))/(b))

// CSR rows sorting kernels and functions
at::Tensor sortRows(int nlat_out, at::Tensor row_off, cudaStream_t stream);

// 4D tensor permutation kernels and functions
at::Tensor permute_4D_to0231(at::Tensor src);
at::Tensor permute_4D_to0312(at::Tensor src);

// Host tensor dump and CSR manipulation functions
void dump_tensor(const char *fname, at::Tensor t);
void dump_csr(const char *fname, at::Tensor roff, at::Tensor cols);
int part_csr_rows(int *row_perm, const at::Tensor roff, const at::Tensor cols,
                  int **part_off, int **part_val);
int verify_part(const int npart, const int *part_off, const int *part_val,
                const at::Tensor roff, const at::Tensor cols);
void verify_part_new(const int nlon_out, const int nlat_in, const int nlon_in,
                     const int npart,
                     // partitioning data
                     const int *part_off, const int *part_val,
                     const at::Tensor roff, const at::Tensor cols);

unsigned int next_pow2(unsigned int x);

// utility host functions and templates
template<unsigned int ALIGN>
int is_aligned(const void *ptr)
{
    static_assert(0 == (ALIGN & (ALIGN - 1)));
    return (0 == (uintptr_t(ptr) & (ALIGN - 1)));
}
// utility device functions and templates
template<typename FLOATV_T>
__device__ FLOATV_T __vset(float x)
{
    static_assert(sizeof(FLOATV_T) == 0, "Unsupported type for __vset");
    return FLOATV_T{};
}

template<>
__device__ float __forceinline__ __vset<float>(float x)
{
    return x;
}

__device__ float __forceinline__ __vmul(float a, float b)
{
    return a * b;
}

__device__ float __forceinline__ __vadd(float a, float b)
{
    return a + b;
}

__device__ float __forceinline__ __vsub(float a, float b)
{
    return a - b;
}

__device__ float __forceinline__ __vred(float a)
{
    return a;
}

__device__ float __forceinline__ __vscale(float s, float v)
{
    return v * s;
}

__device__ float __forceinline__ __vdiv(float s, float v)
{
    return v / s;
}

template<>
__device__ float4 __forceinline__ __vset<float4>(float x)
{
    return make_float4(x, x, x, x);
}

__device__ float4 __forceinline__ __vmul(float4 a, float4 b)
{
    return make_float4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
}

__device__ float4 __forceinline__ __vadd(float4 a, float4 b)
{
    return make_float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
}

__device__ float4 __forceinline__ __vsub(float4 a, float4 b)
{
    return make_float4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
}

__device__ float __forceinline__ __vred(float4 a)
{
    return a.x + a.y + a.z + a.w;
}

__device__ float4 __forceinline__ __vscale(float s, float4 v)
{
    return make_float4(s * v.x, s * v.y, s * v.z, s * v.w);
}

__device__ float4 __forceinline__ __vdiv(float s, float4 v)
{
    return make_float4(s / v.x, s / v.y, s / v.z, s / v.w);
}
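These overloads let a kernel be written once against a generic FLOATV_T and instantiated for either float or float4; with float4 and 16-byte-aligned pointers (checked via is_aligned<16>), each thread moves four floats per memory transaction. A hedged sketch of the intended pattern (row_dot is a hypothetical helper, not part of this header):

// Generic partial dot product over one row: FLOATV_T is float or float4.
// __vred collapses the per-thread vector accumulator to a scalar; a warp or
// block reduction is still needed to combine values across threads.
template<typename FLOATV_T>
__device__ float row_dot(const FLOATV_T *__restrict__ a,
                         const FLOATV_T *__restrict__ b, int nvec)
{
    FLOATV_T acc = __vset<FLOATV_T>(0.0f);
    for (int i = threadIdx.x; i < nvec; i += blockDim.x) {
        acc = __vadd(acc, __vmul(a[i], b[i]));
    }
    return __vred(acc);
}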
template<typename VAL_T>
__device__ VAL_T __warp_sum(VAL_T val)
{
#pragma unroll
    for (int i = WARP_SIZE / 2; i; i /= 2) {
        val += __shfl_xor_sync(FULL_MASK, val, i, WARP_SIZE);
    }
    return val;
}

template<int BDIM_X,
         int BDIM_Y = 1,
         int BDIM_Z = 1,
         typename VAL_T>
__device__ VAL_T __block_sum(VAL_T val)
{
    const int NWARP = (BDIM_X * BDIM_Y * BDIM_Z) / WARP_SIZE;

    val = __warp_sum(val);

    if constexpr (NWARP > 1) {

        int tid = threadIdx.x;
        if constexpr (BDIM_Y > 1) { tid += threadIdx.y * BDIM_X; }
        if constexpr (BDIM_Z > 1) { tid += threadIdx.z * BDIM_X * BDIM_Y; }

        const int lid = tid % WARP_SIZE;
        const int wid = tid / WARP_SIZE;

        __shared__ VAL_T sh[NWARP];

        // one partial sum per warp is parked in shared memory...
        if (lid == 0) { sh[wid] = val; }

        __syncthreads();

        // ...and warp 0 reduces the partials
        if (wid == 0) {
            val = (lid < NWARP) ? sh[lid] : 0;

            val = __warp_sum(val);

            __syncwarp();

            if (!lid) { sh[0] = val; }
        }
        __syncthreads();

        val = sh[0];
        __syncthreads();
    }
    return val;
}
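__warp_sum is a butterfly reduction: after log2(32) = 5 __shfl_xor_sync exchanges, every lane of the warp holds the full warp sum, with no shared memory needed; __block_sum extends it across warps as commented above. A hedged usage sketch (dot_k is a hypothetical kernel, not part of this header):

// Grid-strided dot product whose per-thread partials are combined with
// __block_sum; one atomicAdd per block writes the final result.
__global__ void dot_k(const float *a, const float *b, int n, float *out)
{
    float partial = 0.0f;
    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
         i += gridDim.x * blockDim.x) {
        partial += a[i] * b[i];
    }
    partial = __block_sum<256>(partial); // assumes blockDim.x == 256
    if (threadIdx.x == 0) { atomicAdd(out, partial); }
}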
// transpose utils
template<int BDIM_X,
         int BDIM_Y,
         typename VAL_T>
__global__ __launch_bounds__(BDIM_X * BDIM_Y)
void permute_to0231_k(const int nchn,
                      const int nlat,
                      const int nlon,
                      const at::PackedTensorAccessor32<VAL_T, 4, at::RestrictPtrTraits> src,
                      at::PackedTensorAccessor32<VAL_T, 4, at::RestrictPtrTraits> dst)
{
    static_assert(!(BDIM_X & (BDIM_X - 1)));
    static_assert(!(BDIM_Y & (BDIM_Y - 1)));
    static_assert(BDIM_X >= BDIM_Y);

    __shared__ VAL_T sh[BDIM_X][BDIM_X + 1];

    const int tidx = threadIdx.x;
    const int tidy = threadIdx.y;

    const int coff = blockIdx.x * BDIM_X;      // channel offset
    const int woff = blockIdx.y * BDIM_X;      // width offset
    const int batch = blockIdx.z / nlat;       // batch (same for all block)
    const int h = blockIdx.z - (batch * nlat); // height (same for all block)

    const int nchn_full = (nchn - coff) >= BDIM_X;
    const int nlon_full = (nlon - woff) >= BDIM_X;

    if (nchn_full && nlon_full) {
#pragma unroll
        for (int j = 0; j < BDIM_X; j += BDIM_Y) {
            sh[j + tidy][tidx] = src[batch][coff + j + tidy][h][woff + tidx];
        }
        __syncthreads();

#pragma unroll
        for (int j = 0; j < BDIM_X; j += BDIM_Y) {
            dst[batch][h][woff + j + tidy][coff + tidx] = sh[tidx][j + tidy];
        }
    } else {
        if (woff + tidx < nlon) {
#pragma unroll
            for (int j = 0; j < BDIM_X; j += BDIM_Y) {
                sh[j + tidy][tidx] = (coff + j + tidy < nchn) ? src[batch][coff + j + tidy][h][woff + tidx] : VAL_T(0);
            }
        }
        __syncthreads();

        if (coff + tidx < nchn) {
#pragma unroll
            for (int j = 0; j < BDIM_X; j += BDIM_Y) {
                if (woff + j + tidy < nlon) {
                    dst[batch][h][woff + j + tidy][coff + tidx] = sh[tidx][j + tidy];
                }
            }
        }
    }
    return;
}

template<int WARPS_X_TILE,
         typename VAL_T>
void launch_permute_to0231(at::Tensor src, at::Tensor dst)
{
    dim3 block;
    dim3 grid;

    block.x = WARP_SIZE;
    block.y = WARPS_X_TILE;
    grid.x = DIV_UP(src.size(1), block.x);
    grid.y = DIV_UP(src.size(3), block.x);
    grid.z = src.size(2) * src.size(0);

    assert(grid.y < 65536);
    assert(grid.z < 65536);

    // get stream
    auto stream = at::cuda::getCurrentCUDAStream().stream();

    permute_to0231_k<WARP_SIZE, WARPS_X_TILE>
        <<<grid, block, 0, stream>>>(src.size(1), src.size(2), src.size(3),
                                     src.packed_accessor32<VAL_T, 4, at::RestrictPtrTraits>(),
                                     dst.packed_accessor32<VAL_T, 4, at::RestrictPtrTraits>());
}
template<int BDIM_X,
         int BDIM_Y,
         typename VAL_T>
__global__ __launch_bounds__(BDIM_X * BDIM_Y)
void permute_to0312_k(const int nchn,
                      const int nlat,
                      const int nlon,
                      const at::PackedTensorAccessor32<VAL_T, 4, at::RestrictPtrTraits> src,
                      at::PackedTensorAccessor32<VAL_T, 4, at::RestrictPtrTraits> dst)
{
    static_assert(!(BDIM_X & (BDIM_X - 1)));
    static_assert(!(BDIM_Y & (BDIM_Y - 1)));
    static_assert(BDIM_X >= BDIM_Y);

    __shared__ VAL_T sh[BDIM_X][BDIM_X + 1];

    const int tidx = threadIdx.x;
    const int tidy = threadIdx.y;

    const int woff = blockIdx.x * BDIM_X;      // width offset
    const int coff = blockIdx.y * BDIM_X;      // channel offset
    const int batch = blockIdx.z / nlat;       // batch (same for all block)
    const int h = blockIdx.z - (batch * nlat); // height (same for all block)

    const int nchn_full = (nchn - coff) >= BDIM_X;
    const int nlon_full = (nlon - woff) >= BDIM_X;

    if (nchn_full && nlon_full) {
#pragma unroll
        for (int j = 0; j < BDIM_X; j += BDIM_Y) {
            sh[j + tidy][tidx] = src[batch][h][woff + j + tidy][coff + tidx];
        }
        __syncthreads();

#pragma unroll
        for (int j = 0; j < BDIM_X; j += BDIM_Y) {
            dst[batch][coff + j + tidy][h][woff + tidx] = sh[tidx][j + tidy];
        }
    } else {
        if (coff + tidx < nchn) {
#pragma unroll
            for (int j = 0; j < BDIM_X; j += BDIM_Y) {
                sh[j + tidy][tidx] = (woff + j + tidy < nlon) ? src[batch][h][woff + j + tidy][coff + tidx] : VAL_T(0);
            }
        }
        __syncthreads();

        if (woff + tidx < nlon) {
#pragma unroll
            for (int j = 0; j < BDIM_X; j += BDIM_Y) {
                if (coff + j + tidy < nchn) {
                    dst[batch][coff + j + tidy][h][woff + tidx] = sh[tidx][j + tidy];
                }
            }
        }
    }
    return;
}

template<int WARPS_X_TILE,
         typename VAL_T>
void launch_permute_to0312(at::Tensor src, at::Tensor dst)
{
    dim3 block;
    dim3 grid;

    block.x = WARP_SIZE;
    block.y = WARPS_X_TILE;
    grid.x = DIV_UP(src.size(2), block.x);
    grid.y = DIV_UP(src.size(3), block.x);
    grid.z = src.size(1) * src.size(0);

    assert(grid.y < 65536);
    assert(grid.z < 65536);

    // get stream
    auto stream = at::cuda::getCurrentCUDAStream().stream();

    permute_to0312_k<WARP_SIZE, WARPS_X_TILE>
        <<<grid, block, 0, stream>>>(src.size(3), src.size(1), src.size(2),
                                     src.packed_accessor32<VAL_T, 4, at::RestrictPtrTraits>(),
                                     dst.packed_accessor32<VAL_T, 4, at::RestrictPtrTraits>());
}
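Both custom permutations have direct ATen equivalents (permute followed by contiguous); the dedicated tiled kernels, with their padded shared-memory tiles, presumably exist to keep both the gather and scatter sides of the transpose coalesced. A hedged LibTorch sanity-check sketch (check_permutes is a hypothetical helper):

// The kernels should reproduce ATen's permute(...).contiguous() exactly:
// permute_4D_to0231 maps NCHW -> NHWC, permute_4D_to0312 maps NHWC -> NCHW.
#include <torch/torch.h>
#include "attention_utils.cuh"

bool check_permutes(at::Tensor src) // src: 4D float CUDA tensor
{
    bool ok0231 = torch::equal(permute_4D_to0231(src),
                               src.permute({0, 2, 3, 1}).contiguous());
    bool ok0312 = torch::equal(permute_4D_to0312(src),
                               src.permute({0, 3, 1, 2}).contiguous());
    return ok0231 && ok0312;
}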
torch_harmonics/csrc/attention/cudamacro.h (new file, mode 100644)
// coding=utf-8
//
// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#define CHECK_CUDA(call) { \
cudaError_t err = call; \
if( cudaSuccess != err) { \
fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\n", \
__FILE__, __LINE__, cudaGetErrorString( err) ); \
exit(EXIT_FAILURE); \
}}
#define CHECK_ERROR(errorMessage) { \
cudaError_t err = cudaGetLastError(); \
if( cudaSuccess != err) { \
fprintf(stderr, "Cuda error: %s in file '%s' in line %i : %s.\n", \
errorMessage, __FILE__, __LINE__, cudaGetErrorString( err) );\
exit(EXIT_FAILURE); \
}}
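These are classic fail-fast wrappers: CHECK_CUDA validates the return code of a runtime API call, while CHECK_ERROR polls cudaGetLastError after a kernel launch (launches themselves return no status). A hedged usage sketch (scale_k is a hypothetical kernel, not part of this commit):

// Abort with file and line information on any CUDA failure.
#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>
#include "cudamacro.h"

__global__ void scale_k(float *x, float s, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) { x[i] *= s; }
}

int main()
{
    const int n = 1 << 20;
    float *x;
    CHECK_CUDA(cudaMalloc(&x, n * sizeof(float)));
    scale_k<<<(n + 255) / 256, 256>>>(x, 2.0f, n);
    CHECK_ERROR("scale_k");              // catches invalid launch configurations
    CHECK_CUDA(cudaDeviceSynchronize()); // catches asynchronous kernel errors
    CHECK_CUDA(cudaFree(x));
    return 0;
}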
torch_harmonics/examples/losses.py

The W11LossS2 loss now evaluates its FFT-based spectral derivatives in float32 with autocast disabled, keeping the derivative computation numerically stable under mixed-precision training:

@@ -31,6 +31,7 @@
 import torch
 import torch.nn as nn
+import torch.amp as amp
 import torch.nn.functional as F
 from typing import Optional
 from abc import ABC, abstractmethod

@@ -259,11 +260,14 @@ class W11LossS2(SphericalLossBase):
         self.register_buffer("k_theta_mesh", k_theta_mesh)

     def _compute_loss_term(self, prd: torch.Tensor, tar: torch.Tensor) -> torch.Tensor:
-        prd_prime_fft2_phi_h = torch.fft.ifft2(1j * self.k_phi_mesh * torch.fft.fft2(prd)).real
-        prd_prime_fft2_theta_h = torch.fft.ifft2(1j * self.k_theta_mesh * torch.fft.fft2(prd)).real
-        tar_prime_fft2_phi_h = torch.fft.ifft2(1j * self.k_phi_mesh * torch.fft.fft2(tar)).real
-        tar_prime_fft2_theta_h = torch.fft.ifft2(1j * self.k_theta_mesh * torch.fft.fft2(tar)).real
+        prdtype = prd.dtype
+        with amp.autocast(device_type="cuda", enabled=False):
+            prd = prd.to(torch.float32)
+            prd_prime_fft2_phi_h = torch.fft.ifft2(1j * self.k_phi_mesh * torch.fft.fft2(prd)).real
+            prd_prime_fft2_theta_h = torch.fft.ifft2(1j * self.k_theta_mesh * torch.fft.fft2(prd)).real
+            tar_prime_fft2_phi_h = torch.fft.ifft2(1j * self.k_phi_mesh * torch.fft.fft2(tar)).real
+            tar_prime_fft2_theta_h = torch.fft.ifft2(1j * self.k_theta_mesh * torch.fft.fft2(tar)).real

         # Return the element-wise loss term
         return torch.abs(prd_prime_fft2_phi_h - tar_prime_fft2_phi_h) + torch.abs(prd_prime_fft2_theta_h - tar_prime_fft2_theta_h)