Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-harmonics
Commits
ae8257b5
Commit
ae8257b5
authored
May 22, 2025
by
Boris Bonev
Committed by
Boris Bonev
May 24, 2025
Browse files
adapting header files
parent
13d6130e
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
51 additions
and
21 deletions
+51
-21
torch_harmonics/_neighborhood_attention.py
torch_harmonics/_neighborhood_attention.py
+19
-19
torch_harmonics/attention.py
torch_harmonics/attention.py
+1
-1
torch_harmonics/csrc/attention/attention_row_offset.cu
torch_harmonics/csrc/attention/attention_row_offset.cu
+30
-0
torch_harmonics/plotting.py
torch_harmonics/plotting.py
+1
-1
No files found.
torch_harmonics/_neighborhood_attention.py
View file @
ae8257b5
# coding=utf-8
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 202
4
The torch-harmonics Authors. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 202
5
The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# SPDX-License-Identifier: BSD-3-Clause
#
#
# Redistribution and use in source and binary forms, with or without
# Redistribution and use in source and binary forms, with or without
...
@@ -115,7 +115,7 @@ def _neighborhood_attention_s2_bwd_dv_torch(kx: torch.Tensor, vx: torch.Tensor,
...
@@ -115,7 +115,7 @@ def _neighborhood_attention_s2_bwd_dv_torch(kx: torch.Tensor, vx: torch.Tensor,
# dvx: B, C, Hi, Wi
# dvx: B, C, Hi, Wi
dvx
=
torch
.
zeros_like
(
vx
)
dvx
=
torch
.
zeros_like
(
vx
)
for
ho
in
range
(
nlat_out
):
for
ho
in
range
(
nlat_out
):
# get number of nonzeros
# get number of nonzeros
...
@@ -181,7 +181,7 @@ def _neighborhood_attention_s2_bwd_dk_torch(kx: torch.Tensor, vx: torch.Tensor,
...
@@ -181,7 +181,7 @@ def _neighborhood_attention_s2_bwd_dk_torch(kx: torch.Tensor, vx: torch.Tensor,
# quad_weights: Hi
# quad_weights: Hi
# output
# output
# dkx: B, C, Hi, Wi
# dkx: B, C, Hi, Wi
dkx
=
torch
.
zeros_like
(
kx
)
dkx
=
torch
.
zeros_like
(
kx
)
for
ho
in
range
(
nlat_out
):
for
ho
in
range
(
nlat_out
):
...
@@ -262,15 +262,15 @@ def _neighborhood_attention_s2_bwd_dq_torch(kx: torch.Tensor, vx: torch.Tensor,
...
@@ -262,15 +262,15 @@ def _neighborhood_attention_s2_bwd_dq_torch(kx: torch.Tensor, vx: torch.Tensor,
# quad_weights: Hi
# quad_weights: Hi
# output
# output
# dvx: B, C, Hi, Wi
# dvx: B, C, Hi, Wi
dqy
=
torch
.
zeros_like
(
qy
)
dqy
=
torch
.
zeros_like
(
qy
)
for
ho
in
range
(
nlat_out
):
for
ho
in
range
(
nlat_out
):
# get number of nonzeros
# get number of nonzeros
zstart
=
row_off
[
ho
]
zstart
=
row_off
[
ho
]
zend
=
row_off
[
ho
+
1
]
zend
=
row_off
[
ho
+
1
]
for
wo
in
range
(
nlon_out
):
for
wo
in
range
(
nlon_out
):
alpha
=
torch
.
zeros
((
dy
.
shape
[
0
],
zend
-
zstart
),
dtype
=
dy
.
dtype
,
device
=
dy
.
device
)
alpha
=
torch
.
zeros
((
dy
.
shape
[
0
],
zend
-
zstart
),
dtype
=
dy
.
dtype
,
device
=
dy
.
device
)
...
@@ -353,7 +353,7 @@ class _NeighborhoodAttentionS2(torch.autograd.Function):
...
@@ -353,7 +353,7 @@ class _NeighborhoodAttentionS2(torch.autograd.Function):
kw
=
kw
.
to
(
torch
.
float32
)
kw
=
kw
.
to
(
torch
.
float32
)
vw
=
vw
.
to
(
torch
.
float32
)
vw
=
vw
.
to
(
torch
.
float32
)
qw
=
qw
.
to
(
torch
.
float32
)
qw
=
qw
.
to
(
torch
.
float32
)
output
=
_neighborhood_attention_s2_fwd_torch
(
kw
,
vw
,
qw
,
quad_weights
,
output
=
_neighborhood_attention_s2_fwd_torch
(
kw
,
vw
,
qw
,
quad_weights
,
col_idx
,
row_off
,
col_idx
,
row_off
,
nlon_in
,
nlat_out
,
nlon_out
)
nlon_in
,
nlat_out
,
nlon_out
)
...
@@ -371,7 +371,7 @@ class _NeighborhoodAttentionS2(torch.autograd.Function):
...
@@ -371,7 +371,7 @@ class _NeighborhoodAttentionS2(torch.autograd.Function):
nlon_in
=
ctx
.
nlon_in
nlon_in
=
ctx
.
nlon_in
nlat_out
=
ctx
.
nlat_out
nlat_out
=
ctx
.
nlat_out
nlon_out
=
ctx
.
nlon_out
nlon_out
=
ctx
.
nlon_out
kw
=
F
.
conv2d
(
k
,
weight
=
wk
,
bias
=
bk
)
kw
=
F
.
conv2d
(
k
,
weight
=
wk
,
bias
=
bk
)
vw
=
F
.
conv2d
(
v
,
weight
=
wv
,
bias
=
bv
)
vw
=
F
.
conv2d
(
v
,
weight
=
wv
,
bias
=
bv
)
qw
=
F
.
conv2d
(
q
,
weight
=
wq
,
bias
=
bq
)
qw
=
F
.
conv2d
(
q
,
weight
=
wq
,
bias
=
bq
)
...
@@ -408,7 +408,7 @@ class _NeighborhoodAttentionS2(torch.autograd.Function):
...
@@ -408,7 +408,7 @@ class _NeighborhoodAttentionS2(torch.autograd.Function):
dvw
=
dvw
.
reshape
(
B
,
-
1
,
H
,
W
)
dvw
=
dvw
.
reshape
(
B
,
-
1
,
H
,
W
)
_
,
C
,
H
,
W
=
dqw
.
shape
_
,
C
,
H
,
W
=
dqw
.
shape
dqw
=
dqw
.
reshape
(
B
,
-
1
,
H
,
W
)
dqw
=
dqw
.
reshape
(
B
,
-
1
,
H
,
W
)
# input grads
# input grads
dv
=
torch
.
nn
.
functional
.
conv2d
(
dvw
,
weight
=
wv
.
permute
([
1
,
0
,
2
,
3
]),
bias
=
None
)
dv
=
torch
.
nn
.
functional
.
conv2d
(
dvw
,
weight
=
wv
.
permute
([
1
,
0
,
2
,
3
]),
bias
=
None
)
dk
=
torch
.
nn
.
functional
.
conv2d
(
dkw
,
weight
=
wk
.
permute
([
1
,
0
,
2
,
3
]),
bias
=
None
)
dk
=
torch
.
nn
.
functional
.
conv2d
(
dkw
,
weight
=
wk
.
permute
([
1
,
0
,
2
,
3
]),
bias
=
None
)
...
@@ -439,13 +439,13 @@ class _NeighborhoodAttentionS2(torch.autograd.Function):
...
@@ -439,13 +439,13 @@ class _NeighborhoodAttentionS2(torch.autograd.Function):
None
,
None
,
None
,
None
,
None
,
None
,
None
None
,
None
,
None
,
None
,
None
,
None
,
None
def
_neighborhood_attention_s2_torch
(
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
q
:
torch
.
Tensor
,
def
_neighborhood_attention_s2_torch
(
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
q
:
torch
.
Tensor
,
wk
:
torch
.
Tensor
,
wv
:
torch
.
Tensor
,
wq
:
torch
.
Tensor
,
wk
:
torch
.
Tensor
,
wv
:
torch
.
Tensor
,
wq
:
torch
.
Tensor
,
bk
:
Union
[
torch
.
Tensor
,
None
],
bv
:
Union
[
torch
.
Tensor
,
None
],
bk
:
Union
[
torch
.
Tensor
,
None
],
bv
:
Union
[
torch
.
Tensor
,
None
],
bq
:
Union
[
torch
.
Tensor
,
None
],
quad_weights
:
torch
.
Tensor
,
bq
:
Union
[
torch
.
Tensor
,
None
],
quad_weights
:
torch
.
Tensor
,
col_idx
:
torch
.
Tensor
,
row_off
:
torch
.
Tensor
,
col_idx
:
torch
.
Tensor
,
row_off
:
torch
.
Tensor
,
nh
:
int
,
nlon_in
:
int
,
nlat_out
:
int
,
nlon_out
:
int
)
->
torch
.
Tensor
:
nh
:
int
,
nlon_in
:
int
,
nlat_out
:
int
,
nlon_out
:
int
)
->
torch
.
Tensor
:
return
_NeighborhoodAttentionS2
.
apply
(
k
,
v
,
q
,
wk
,
wv
,
wq
,
bk
,
bv
,
bq
,
return
_NeighborhoodAttentionS2
.
apply
(
k
,
v
,
q
,
wk
,
wv
,
wq
,
bk
,
bv
,
bq
,
quad_weights
,
col_idx
,
row_off
,
quad_weights
,
col_idx
,
row_off
,
nh
,
nlon_in
,
nlat_out
,
nlon_out
)
nh
,
nlon_in
,
nlat_out
,
nlon_out
)
...
@@ -457,7 +457,7 @@ class _NeighborhoodAttentionS2Cuda(torch.autograd.Function):
...
@@ -457,7 +457,7 @@ class _NeighborhoodAttentionS2Cuda(torch.autograd.Function):
@
custom_fwd
(
device_type
=
"cuda"
)
@
custom_fwd
(
device_type
=
"cuda"
)
def
forward
(
ctx
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
q
:
torch
.
Tensor
,
def
forward
(
ctx
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
q
:
torch
.
Tensor
,
wk
:
torch
.
Tensor
,
wv
:
torch
.
Tensor
,
wq
:
torch
.
Tensor
,
wk
:
torch
.
Tensor
,
wv
:
torch
.
Tensor
,
wq
:
torch
.
Tensor
,
bk
:
Union
[
torch
.
Tensor
,
None
],
bv
:
Union
[
torch
.
Tensor
,
None
],
bq
:
Union
[
torch
.
Tensor
,
None
],
bk
:
Union
[
torch
.
Tensor
,
None
],
bv
:
Union
[
torch
.
Tensor
,
None
],
bq
:
Union
[
torch
.
Tensor
,
None
],
quad_weights
:
torch
.
Tensor
,
col_idx
:
torch
.
Tensor
,
row_off
:
torch
.
Tensor
,
quad_weights
:
torch
.
Tensor
,
col_idx
:
torch
.
Tensor
,
row_off
:
torch
.
Tensor
,
max_psi_nnz
:
int
,
nh
:
int
,
nlon_in
:
int
,
nlat_out
:
int
,
nlon_out
:
int
):
max_psi_nnz
:
int
,
nh
:
int
,
nlon_in
:
int
,
nlat_out
:
int
,
nlon_out
:
int
):
...
@@ -479,12 +479,12 @@ class _NeighborhoodAttentionS2Cuda(torch.autograd.Function):
...
@@ -479,12 +479,12 @@ class _NeighborhoodAttentionS2Cuda(torch.autograd.Function):
vw
=
vw
.
reshape
(
B
*
nh
,
-
1
,
H
,
W
)
vw
=
vw
.
reshape
(
B
*
nh
,
-
1
,
H
,
W
)
B
,
_
,
H
,
W
=
qw
.
shape
B
,
_
,
H
,
W
=
qw
.
shape
qw
=
qw
.
reshape
(
B
*
nh
,
-
1
,
H
,
W
)
qw
=
qw
.
reshape
(
B
*
nh
,
-
1
,
H
,
W
)
# convert to float32
# convert to float32
kw
=
kw
.
to
(
torch
.
float32
)
kw
=
kw
.
to
(
torch
.
float32
)
vw
=
vw
.
to
(
torch
.
float32
)
vw
=
vw
.
to
(
torch
.
float32
)
qw
=
qw
.
to
(
torch
.
float32
)
qw
=
qw
.
to
(
torch
.
float32
)
output
=
attention_cuda_extension
.
forward
(
kw
,
vw
,
qw
,
quad_weights
,
output
=
attention_cuda_extension
.
forward
(
kw
,
vw
,
qw
,
quad_weights
,
col_idx
,
row_off
,
col_idx
,
row_off
,
nlon_in
,
nlat_out
,
nlon_out
)
nlon_in
,
nlat_out
,
nlon_out
)
...
@@ -561,13 +561,13 @@ class _NeighborhoodAttentionS2Cuda(torch.autograd.Function):
...
@@ -561,13 +561,13 @@ class _NeighborhoodAttentionS2Cuda(torch.autograd.Function):
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
def
_neighborhood_attention_s2_cuda
(
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
q
:
torch
.
Tensor
,
def
_neighborhood_attention_s2_cuda
(
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
q
:
torch
.
Tensor
,
wk
:
torch
.
Tensor
,
wv
:
torch
.
Tensor
,
wq
:
torch
.
Tensor
,
wk
:
torch
.
Tensor
,
wv
:
torch
.
Tensor
,
wq
:
torch
.
Tensor
,
bk
:
Union
[
torch
.
Tensor
,
None
],
bv
:
Union
[
torch
.
Tensor
,
None
],
bk
:
Union
[
torch
.
Tensor
,
None
],
bv
:
Union
[
torch
.
Tensor
,
None
],
bq
:
Union
[
torch
.
Tensor
,
None
],
quad_weights
:
torch
.
Tensor
,
bq
:
Union
[
torch
.
Tensor
,
None
],
quad_weights
:
torch
.
Tensor
,
col_idx
:
torch
.
Tensor
,
row_off
:
torch
.
Tensor
,
max_psi_nnz
:
int
,
col_idx
:
torch
.
Tensor
,
row_off
:
torch
.
Tensor
,
max_psi_nnz
:
int
,
nh
:
int
,
nlon_in
:
int
,
nlat_out
:
int
,
nlon_out
:
int
)
->
torch
.
Tensor
:
nh
:
int
,
nlon_in
:
int
,
nlat_out
:
int
,
nlon_out
:
int
)
->
torch
.
Tensor
:
return
_NeighborhoodAttentionS2Cuda
.
apply
(
k
,
v
,
q
,
wk
,
wv
,
wq
,
bk
,
bv
,
bq
,
return
_NeighborhoodAttentionS2Cuda
.
apply
(
k
,
v
,
q
,
wk
,
wv
,
wq
,
bk
,
bv
,
bq
,
quad_weights
,
col_idx
,
row_off
,
max_psi_nnz
,
quad_weights
,
col_idx
,
row_off
,
max_psi_nnz
,
nh
,
nlon_in
,
nlat_out
,
nlon_out
)
nh
,
nlon_in
,
nlat_out
,
nlon_out
)
torch_harmonics/attention.py
View file @
ae8257b5
# coding=utf-8
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 202
4
The torch-harmonics Authors. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 202
5
The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# SPDX-License-Identifier: BSD-3-Clause
#
#
# Redistribution and use in source and binary forms, with or without
# Redistribution and use in source and binary forms, with or without
...
...
torch_harmonics/csrc/attention/attention_row_offset.cu
View file @
ae8257b5
// coding=utf-8
//
// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "ATen/core/TensorAccessor.h"
#include "ATen/core/TensorAccessor.h"
#include <cmath>
#include <cmath>
#include <cstdint>
#include <cstdint>
...
...
torch_harmonics/plotting.py
View file @
ae8257b5
# coding=utf-8
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 202
2
The torch-harmonics Authors. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 202
5
The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# SPDX-License-Identifier: BSD-3-Clause
#
#
# Redistribution and use in source and binary forms, with or without
# Redistribution and use in source and binary forms, with or without
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment