OpenDAS / torch-harmonics / Commits

Commit ec53e666
authored Jul 17, 2025 by Andrea Paris, committed by Boris Bonev, Jul 21, 2025
further cleanup
parent 1ef713bb
Showing 4 changed files with 1 addition and 56 deletions:

torch_harmonics/_disco_convolution.py   +1 -32
torch_harmonics/attention.py            +0 -6
torch_harmonics/convolution.py          +0 -6
torch_harmonics/sht.py                  +0 -12
torch_harmonics/_disco_convolution.py

@@ -42,35 +42,7 @@ except ImportError as err:
 # some helper functions
 def _get_psi(kernel_size: int, psi_idx: torch.Tensor, psi_vals: torch.Tensor, nlat_in: int, nlon_in: int, nlat_out: int, nlon_out: int, nlat_in_local: Optional[int] = None, nlat_out_local: Optional[int] = None, semi_transposed: Optional[bool] = False):
-    """Creates a sparse tensor for spherical harmonic convolution operations.
-
-    This function constructs a sparse COO tensor from indices and values, with optional
-    semi-transposition for computational efficiency in spherical harmonic convolutions.
-
-    Args:
-        kernel_size: Number of kernel elements.
-        psi_idx: Tensor of shape (3, n_nonzero) containing the indices for the sparse tensor.
-            The three dimensions represent [kernel_idx, lat_idx, combined_lat_lon_idx].
-        psi_vals: Tensor of shape (n_nonzero,) containing the values for the sparse tensor.
-        nlat_in: Number of input latitude points.
-        nlon_in: Number of input longitude points.
-        nlat_out: Number of output latitude points.
-        nlon_out: Number of output longitude points.
-        nlat_in_local: Local number of input latitude points. If None, defaults to nlat_in.
-        nlat_out_local: Local number of output latitude points. If None, defaults to nlat_out.
-        semi_transposed: If True, performs a semi-transposition to facilitate computation
-            by flipping the longitude axis and reorganizing indices.
-
-    Returns:
-        torch.Tensor: A sparse COO tensor of shape (kernel_size, nlat_out_local, nlat_in_local * nlon)
-            where nlon is either nlon_in or nlon_out depending on semi_transposed flag.
-            The tensor is coalesced to remove duplicate indices.
-
-    Note:
-        When semi_transposed=True, the function performs a partial transpose operation
-        that flips the longitude axis and reorganizes the indices to facilitate
-        efficient spherical harmonic convolution computations.
-    """
+    """Creates a sparse tensor for spherical harmonic convolution operations."""
     nlat_in_local = nlat_in_local if nlat_in_local is not None else nlat_in
     nlat_out_local = nlat_out_local if nlat_out_local is not None else nlat_out
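
The contract spelled out in the removed docstring survives in the one-line summary: psi_idx and psi_vals become a coalesced sparse COO tensor of shape (kernel_size, nlat_out_local, nlat_in_local * nlon). A minimal, self-contained sketch of that construction, with illustrative sizes and values that are not taken from the library:

import torch

# Illustrative sizes: 2 kernel elements, a 4x8 input grid, 4 output latitudes.
kernel_size, nlat_in, nlon_in, nlat_out = 2, 4, 8, 4

# psi_idx rows: [kernel_idx, output_lat_idx, flattened (lat_in, lon_in) idx]
psi_idx = torch.tensor([[0, 1], [0, 2], [3, 17]])
psi_vals = torch.tensor([0.5, 0.25])

psi = torch.sparse_coo_tensor(
    psi_idx, psi_vals, size=(kernel_size, nlat_out, nlat_in * nlon_in)
).coalesce()  # coalesce() merges duplicate indices, as the docstring noted
print(psi.shape)  # torch.Size([2, 4, 32])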

@@ -90,7 +62,6 @@ def _get_psi(kernel_size: int, psi_idx: torch.Tensor, psi_vals: torch.Tensor, nl
 class _DiscoS2ContractionCuda(torch.autograd.Function):
     @staticmethod
     @custom_fwd(device_type="cuda")
     def forward(ctx, x: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor,
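
The @custom_fwd(device_type="cuda") decorator is torch.amp's hook for custom autograd Functions under autocast. A standalone sketch of the same decorator pattern (a toy Function, not the library's CUDA kernel):

import torch
from torch.amp import custom_fwd, custom_bwd

class _SquareCuda(torch.autograd.Function):
    @staticmethod
    @custom_fwd(device_type="cuda")  # makes the Function autocast-aware on CUDA
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    @custom_bwd(device_type="cuda")  # replays forward's autocast state in backward
    def backward(ctx, grad_out: torch.Tensor) -> torch.Tensor:
        (x,) = ctx.saved_tensors
        return 2 * x * grad_out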

@@ -123,7 +94,6 @@ class _DiscoS2ContractionCuda(torch.autograd.Function):
 class _DiscoS2TransposeContractionCuda(torch.autograd.Function):
     @staticmethod
     @custom_fwd(device_type="cuda")
     def forward(ctx, x: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor,

@@ -169,7 +139,6 @@ def _disco_s2_transpose_contraction_cuda(x: torch.Tensor, roff_idx: torch.Tensor
 def _disco_s2_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: int):
     assert len(psi.shape) == 3
     assert len(x.shape) == 4
     psi = psi.to(x.device)
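
The asserts pin the expected ranks: a sparse rank-3 psi from _get_psi and a rank-4 input of shape (batch, channels, nlat_in, nlon_in). As a rough illustration of the kind of contraction this sets up (a hand-rolled sketch under those shape assumptions, not the library's implementation), each output longitude can be computed by rolling the input along the periodic longitude axis and applying the sparse filter:

import torch

def disco_contraction_sketch(x: torch.Tensor, psi: torch.Tensor, nlon_out: int) -> torch.Tensor:
    # x: (B, C, nlat_in, nlon_in) dense; psi: (K, nlat_out, nlat_in * nlon_in) sparse COO
    B, C, nlat_in, nlon_in = x.shape
    K, nlat_out, _ = psi.shape
    assert nlon_in % nlon_out == 0  # assumed integer downsampling in longitude
    shift = nlon_in // nlon_out
    out = x.new_zeros(B, C, K, nlat_out, nlon_out)
    for p in range(nlon_out):
        # align output longitude p with the filter's reference longitude
        xr = torch.roll(x, shifts=-p * shift, dims=-1).reshape(B * C, nlat_in * nlon_in)
        for k in range(K):
            # sparse-dense matmul: (nlat_out, nlat*nlon) @ (nlat*nlon, B*C)
            out[:, :, k, :, p] = torch.sparse.mm(psi[k], xr.T).T.reshape(B, C, nlat_out)
    return out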
torch_harmonics/attention.py

@@ -142,9 +142,6 @@ class AttentionS2(nn.Module):
     def extra_repr(self):
-        r"""
-        Pretty print module
-        """
         return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_channels={self.in_channels}, out_channels={self.out_channels}, k_channels={self.k_channels}"

     def forward(self, query: torch.Tensor, key: Optional[torch.Tensor] = None, value: Optional[torch.Tensor] = None) -> torch.Tensor:

@@ -317,9 +314,6 @@ class NeighborhoodAttentionS2(nn.Module):
         self.proj_bias = None

     def extra_repr(self):
-        r"""
-        Pretty print module
-        """
         return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_channels={self.in_channels}, out_channels={self.out_channels}, k_channels={self.k_channels}"

     def forward(self, query: torch.Tensor, key: Optional[torch.Tensor] = None, value: Optional[torch.Tensor] = None) -> torch.Tensor:
torch_harmonics/convolution.py

@@ -501,9 +501,6 @@ class DiscreteContinuousConvS2(DiscreteContinuousConv):
         self.psi = _get_psi(self.kernel_size, self.psi_idx, self.psi_vals, self.nlat_in, self.nlon_in, self.nlat_out, self.nlon_out)

     def extra_repr(self):
-        r"""
-        Pretty print module
-        """
         return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}"

     @property

@@ -660,9 +657,6 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
         self.psi_st = _get_psi(self.kernel_size, self.psi_idx, self.psi_vals, self.nlat_in, self.nlon_in, self.nlat_out, self.nlon_out, semi_transposed=True)
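
The transpose convolution is the one caller that requests semi_transposed=True; per the docstring removed above, that amounts to flipping the longitude component of the flattened column index before the sparse tensor is assembled. One plausible reading of that index manipulation, as an illustrative sketch only (the actual reindexing lives inside _get_psi):

import torch

# Hypothetical flattened (lat, lon) column indices on an nlon=8 grid.
nlon = 8
col = torch.tensor([3, 17, 30])        # encodes lat * nlon + lon
lat, lon = col // nlon, col % nlon
# Flip the periodic longitude axis: lon -> (nlon - lon) % nlon.
col_flipped = lat * nlon + (nlon - lon) % nlon
print(col_flipped)  # tensor([ 5, 23, 26])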

     def extra_repr(self):
-        r"""
-        Pretty print module
-        """
         return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}"

     @property
torch_harmonics/sht.py

@@ -118,9 +118,6 @@ class RealSHT(nn.Module):
         self.register_buffer("weights", weights, persistent=False)
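
The persistent=False flag on these precomputed buffers is standard PyTorch: the buffer follows the module across .to()/.cuda() but is excluded from state_dict(), so it is rebuilt at construction time rather than stored in checkpoints. A minimal demonstration (toy module, illustrative only):

import torch
import torch.nn as nn

class WithWeights(nn.Module):
    def __init__(self, n=5):
        super().__init__()
        # non-persistent: moves with the module but is not serialized
        self.register_buffer("weights", torch.linspace(0.0, 1.0, n), persistent=False)

m = WithWeights()
print("weights" in m.state_dict())  # False
print(m.weights.shape)              # torch.Size([5])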

     def extra_repr(self):
-        r"""
-        Pretty print module
-        """
         return f"nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}"

     def forward(self, x: torch.Tensor):

@@ -223,9 +220,6 @@ class InverseRealSHT(nn.Module):
         self.register_buffer("pct", pct, persistent=False)

     def extra_repr(self):
-        r"""
-        Pretty print module
-        """
         return f"nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}"

     def forward(self, x: torch.Tensor):

@@ -332,9 +326,6 @@ class RealVectorSHT(nn.Module):
         self.register_buffer("weights", weights, persistent=False)

     def extra_repr(self):
-        r"""
-        Pretty print module
-        """
         return f"nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}"

     def forward(self, x: torch.Tensor):

@@ -449,9 +440,6 @@ class InverseRealVectorSHT(nn.Module):
         self.register_buffer("dpct", dpct, persistent=False)

     def extra_repr(self):
-        r"""
-        Pretty print module
-        """
         return f"nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}"

     def forward(self, x: torch.Tensor):