OpenDAS / torch-harmonics / Commits / 6373534a

Commit 6373534a
Authored Jul 16, 2025 by Andrea Paris; committed by Boris Bonev, Jul 21, 2025

    removed docstrings from autograd functions

Parent: 95fc83a0
Showing 3 changed files with 6 additions and 333 deletions (+6 -333)
torch_harmonics/_disco_convolution.py           +2   -86
torch_harmonics/_neighborhood_attention.py      +1   -86
torch_harmonics/distributed/primitives.py       +3   -161
torch_harmonics/_disco_convolution.py  (view file @ 6373534a)
@@ -62,43 +62,13 @@ def _get_psi(kernel_size: int, psi_idx: torch.Tensor, psi_vals: torch.Tensor, nl
 class _DiscoS2ContractionCuda(torch.autograd.Function):
-    r"""
-    CUDA implementation of the discrete-continuous convolution contraction on the sphere.
-
-    This class provides the forward and backward passes for efficient GPU computation
-    of the S2 convolution operation using custom CUDA kernels.
-    """
 
     @staticmethod
     @custom_fwd(device_type="cuda")
     def forward(ctx, x: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor, row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor, kernel_size: int, nlat_out: int, nlon_out: int):
-        r"""
-        Forward pass for CUDA S2 convolution contraction.
-
-        Parameters
-        -----------
-        ctx: torch.autograd.function.Context
-            Context object
-        x: torch.Tensor
-            Input tensor
-        roff_idx: torch.Tensor
-            Row offset indices for sparse computation
-        ker_idx: torch.Tensor
-            Kernel indices
-        row_idx: torch.Tensor
-            Row indices for sparse computation
-        col_idx: torch.Tensor
-            Column indices for sparse computation
-        vals: torch.Tensor
-            Values for sparse computation
-        kernel_size: int
-            Size of the kernel
-        nlat_out: int
-            Number of output latitude points
-        nlon_out: int
-            Number of output longitude points
-        """
         ctx.save_for_backward(roff_idx, ker_idx, row_idx, col_idx, vals)
         ctx.kernel_size = kernel_size
         ctx.nlat_in = x.shape[-2]
@@ -113,19 +83,7 @@ class _DiscoS2ContractionCuda(torch.autograd.Function):
     @staticmethod
     @custom_bwd(device_type="cuda")
     def backward(ctx, grad_output):
-        r"""
-        Backward pass for CUDA S2 convolution contraction.
-
-        Parameters
-        -----------
-        grad_output: torch.Tensor
-            Gradient of the output
-
-        Returns
-        --------
-        grad_input: torch.Tensor
-            Gradient of the input
-        """
         roff_idx, ker_idx, row_idx, col_idx, vals = ctx.saved_tensors
         gtype = grad_output.dtype
         grad_output = grad_output.to(torch.float32).contiguous()
@@ -137,43 +95,13 @@ class _DiscoS2ContractionCuda(torch.autograd.Function):
 class _DiscoS2TransposeContractionCuda(torch.autograd.Function):
-    r"""
-    CUDA implementation of the transpose discrete-continuous convolution contraction on the sphere.
-
-    This class provides the forward and backward passes for efficient GPU computation
-    of the transpose S2 convolution operation using custom CUDA kernels.
-    """
 
     @staticmethod
     @custom_fwd(device_type="cuda")
     def forward(ctx, x: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor, row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor, kernel_size: int, nlat_out: int, nlon_out: int):
-        r"""
-        Forward pass for CUDA transpose S2 convolution contraction.
-
-        Parameters
-        -----------
-        ctx: torch.autograd.function.Context
-            Context object
-        x: torch.Tensor
-            Input tensor
-        roff_idx: torch.Tensor
-            Row offset indices for sparse computation
-        ker_idx: torch.Tensor
-            Kernel indices
-        row_idx: torch.Tensor
-            Row indices for sparse computation
-        col_idx: torch.Tensor
-            Column indices for sparse computation
-        vals: torch.Tensor
-            Values for sparse computation
-        kernel_size: int
-            Size of the kernel
-        nlat_out: int
-            Number of output latitude points
-        nlon_out: int
-            Number of output longitude points
-        """
         ctx.save_for_backward(roff_idx, ker_idx, row_idx, col_idx, vals)
         ctx.kernel_size = kernel_size
         ctx.nlat_in = x.shape[-2]
@@ -188,19 +116,7 @@ class _DiscoS2TransposeContractionCuda(torch.autograd.Function):
     @staticmethod
     @custom_bwd(device_type="cuda")
     def backward(ctx, grad_output):
-        r"""
-        Backward pass for CUDA transpose S2 convolution contraction.
-
-        Parameters
-        -----------
-        grad_output: torch.Tensor
-            Gradient of the output
-
-        Returns
-        --------
-        grad_input: torch.Tensor
-            Gradient of the input
-        """
         roff_idx, ker_idx, row_idx, col_idx, vals = ctx.saved_tensors
         gtype = grad_output.dtype
         grad_output = grad_output.to(torch.float32).contiguous()
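Note: both contraction classes above follow the usual torch.autograd.Function recipe: the sparse index tensors go through ctx.save_for_backward, scalar settings are attached to ctx, and torch.amp's custom_fwd/custom_bwd decorators keep the hand-written CUDA kernels compatible with autocast, with the backward accumulating in float32 before casting back. The sketch below is a minimal, hypothetical stand-in (a dense matmul instead of the sparse DISCO contraction; it assumes PyTorch >= 2.4, where torch.amp.custom_fwd/custom_bwd accept a device_type argument) and only illustrates the pattern, not the library's kernels.

import torch
from torch.amp import custom_fwd, custom_bwd


class _ToyContractionCuda(torch.autograd.Function):
    # Hypothetical stand-in: y = x @ weight plays the role of the sparse
    # contraction; the save_for_backward / float32-backward pattern mirrors
    # the hunks above.

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, x: torch.Tensor, weight: torch.Tensor):
        ctx.save_for_backward(weight)
        return x @ weight

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        (weight,) = ctx.saved_tensors
        gtype = grad_output.dtype
        # compute the input gradient in float32, then cast back to the
        # incoming dtype, as the CUDA contractions above do
        grad_output = grad_output.to(torch.float32).contiguous()
        grad_input = (grad_output @ weight.to(torch.float32).t()).to(gtype)
        # one return value per forward argument; the weight gets no gradient here
        return grad_input, None


# x = torch.randn(8, 16, device="cuda", requires_grad=True)
# w = torch.randn(16, 4, device="cuda")
# y = _ToyContractionCuda.apply(x, w)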
torch_harmonics/_neighborhood_attention.py  (view file @ 6373534a)
@@ -459,11 +459,6 @@ def _neighborhood_attention_s2_bwd_dq_torch(kx: torch.Tensor, vx: torch.Tensor,
     return dqy
 
 class _NeighborhoodAttentionS2(torch.autograd.Function):
-    r"""
-    CPU implementation of neighborhood attention on the sphere (S2).
-
-    This class provides the forward and backward passes for efficient CPU computation
-    of neighborhood attention operations using sparse tensor operations.
-    """
 
     @staticmethod
     @custom_fwd(device_type="cpu")
@@ -472,44 +467,7 @@ class _NeighborhoodAttentionS2(torch.autograd.Function):
                 bk: Union[torch.Tensor, None], bv: Union[torch.Tensor, None], bq: Union[torch.Tensor, None],
                 quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
                 nh: int, nlon_in: int, nlat_out: int, nlon_out: int):
-        r"""
-        Forward pass for CPU neighborhood attention on S2.
-
-        Parameters
-        -----------
-        k: torch.Tensor
-            Key tensor
-        v: torch.Tensor
-            Value tensor
-        q: torch.Tensor
-            Query tensor
-        wk: torch.Tensor
-            Key weight tensor
-        wv: torch.Tensor
-            Value weight tensor
-        wq: torch.Tensor
-            Query weight tensor
-        bk: torch.Tensor or None
-            Key bias tensor (optional)
-        bv: torch.Tensor or None
-            Value bias tensor (optional)
-        bq: torch.Tensor or None
-            Query bias tensor (optional)
-        quad_weights: torch.Tensor
-            Quadrature weights for spherical integration
-        col_idx: torch.Tensor
-            Column indices for sparse computation
-        row_off: torch.Tensor
-            Row offsets for sparse computation
-        nh: int
-            Number of attention heads
-        nlon_in: int
-            Number of input longitude points
-        nlat_out: int
-            Number of output latitude points
-        nlon_out: int
-            Number of output longitude points
-        """
         ctx.save_for_backward(col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq)
         ctx.nh = nh
         ctx.nlon_in = nlon_in
@@ -704,11 +662,7 @@ def _neighborhood_attention_s2_torch(k: torch.Tensor, v: torch.Tensor, q: torch.
 class _NeighborhoodAttentionS2Cuda(torch.autograd.Function):
-    r"""
-    CUDA implementation of neighborhood attention on the sphere (S2).
-
-    This class provides the forward and backward passes for efficient GPU computation
-    of neighborhood attention operations using custom CUDA kernels.
-    """
 
     @staticmethod
     @custom_fwd(device_type="cuda")
@@ -717,46 +671,7 @@ class _NeighborhoodAttentionS2Cuda(torch.autograd.Function):
                 bk: Union[torch.Tensor, None], bv: Union[torch.Tensor, None], bq: Union[torch.Tensor, None],
                 quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
                 max_psi_nnz: int, nh: int, nlon_in: int, nlat_out: int, nlon_out: int):
-        r"""
-        Forward pass for CUDA neighborhood attention on S2.
-
-        Parameters
-        -----------
-        k: torch.Tensor
-            Key tensor
-        v: torch.Tensor
-            Value tensor
-        q: torch.Tensor
-            Query tensor
-        wk: torch.Tensor
-            Key weight tensor
-        wv: torch.Tensor
-            Value weight tensor
-        wq: torch.Tensor
-            Query weight tensor
-        bk: torch.Tensor or None
-            Key bias tensor (optional)
-        bv: torch.Tensor or None
-            Value bias tensor (optional)
-        bq: torch.Tensor or None
-            Query bias tensor (optional)
-        quad_weights: torch.Tensor
-            Quadrature weights for spherical integration
-        col_idx: torch.Tensor
-            Column indices for sparse computation
-        row_off: torch.Tensor
-            Row offsets for sparse computation
-        max_psi_nnz: int
-            Maximum number of non-zero elements in sparse tensor
-        nh: int
-            Number of attention heads
-        nlon_in: int
-            Number of input longitude points
-        nlat_out: int
-            Number of output latitude points
-        nlon_out: int
-            Number of output longitude points
-        """
         ctx.save_for_backward(col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq)
         ctx.nh = nh
        ctx.max_psi_nnz = max_psi_nnz
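Note: the attention Functions take the optional biases as Union[torch.Tensor, None] and pass them straight to ctx.save_for_backward, while plain integer arguments (nh, nlon_in, ...) are stored as attributes on ctx. The generic sketch below illustrates that convention: save_for_backward accepts None entries and hands them back as None from ctx.saved_tensors, and non-tensor arguments receive a None gradient. It is a hypothetical example, not the S2 attention implementation.

import torch
from typing import Optional


class _ToyAffine(torch.autograd.Function):
    # Hypothetical affine op y = scale * (x @ w) + b with an optional bias,
    # mirroring the optional-tensor / non-tensor-argument handling above.

    @staticmethod
    def forward(ctx, x: torch.Tensor, w: torch.Tensor, b: Optional[torch.Tensor], scale: int):
        # None is a valid entry in save_for_backward and comes back as None
        ctx.save_for_backward(x, w, b)
        ctx.scale = scale
        y = scale * (x @ w)
        if b is not None:
            y = y + b
        return y

    @staticmethod
    def backward(ctx, grad_output):
        x, w, b = ctx.saved_tensors
        scale = ctx.scale
        grad_x = scale * (grad_output @ w.t())
        grad_w = scale * (x.t() @ grad_output)
        grad_b = grad_output.sum(dim=0) if b is not None else None
        # one gradient per forward argument; the int "scale" gets None
        return grad_x, grad_w, grad_b, None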
torch_harmonics/distributed/primitives.py  (view file @ 6373534a)
@@ -146,31 +146,6 @@ def _transpose(tensor, dim0, dim1, dim1_split_sizes, group=None, async_op=False)
 class distributed_transpose_azimuth(torch.autograd.Function):
-    r"""
-    Distributed transpose operation for azimuthal dimension.
-    This class provides the forward and backward passes for distributed
-    tensor transposition along the azimuthal dimension.
-
-    Parameters
-    ----------
-    tensor: torch.Tensor
-        The tensor to transpose
-    dim0: int
-        The first dimension to transpose
-    dim1: int
-        The second dimension to transpose
-    dim1_split_sizes: List[int]
-        The split sizes for the second dimension
-
-    Returns
-    -------
-    x_recv: List[torch.Tensor]
-        The split tensors
-    dim0_split_sizes: List[int]
-        The split sizes for the first dimension
-    req: dist.Request
-        The request object
-    """
 
     @staticmethod
     @custom_fwd(device_type="cuda")
@@ -226,29 +201,6 @@ class distributed_transpose_azimuth(torch.autograd.Function):
 class distributed_transpose_polar(torch.autograd.Function):
-    r"""
-    Distributed transpose operation for polar dimension.
-    This class provides the forward and backward passes for distributed
-    tensor transposition along the polar dimension.
-
-    Parameters
-    ----------
-    x: torch.Tensor
-        The tensor to transpose
-    dims: List[int]
-        The dimensions to transpose
-    dim1_split_sizes: List[int]
-        The split sizes for the second dimension
-
-    Returns
-    -------
-    x: torch.Tensor
-        The transposed tensor
-    dim0_split_sizes: List[int]
-        The split sizes for the first dimension
-    req: dist.Request
-        The request object
-    """
 
     @staticmethod
     @custom_fwd(device_type="cuda")
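Note: both transpose primitives above wrap the module-level _transpose helper visible in the hunk context. Distributed transposes of this kind typically re-partition a tensor so that data enters split across ranks along one dimension and leaves split along the other. The sketch below is only a generic illustration of that pattern, with a hypothetical helper name, even splits, and a synchronous all_to_all; the actual _transpose additionally takes dim1_split_sizes and async_op for uneven, overlapped communication, which are omitted here.

import torch
import torch.distributed as dist


def toy_distributed_transpose(local: torch.Tensor, dim0: int, dim1: int, group=None) -> torch.Tensor:
    # Hypothetical illustration: "local" is this rank's shard, split across
    # the group along dim1; the result is the same global tensor split along
    # dim0 instead. Assumes both dimensions divide evenly by the group size.
    world = dist.get_world_size(group=group)

    # slice the local shard along dim0, one chunk per destination rank
    send = [c.contiguous() for c in torch.chunk(local, world, dim=dim0)]
    recv = [torch.empty_like(c) for c in send]

    # after the exchange, rank r holds every rank's r-th dim0 slice
    dist.all_to_all(recv, send, group=group)

    # stitch the received slices back together along dim1
    return torch.cat(recv, dim=dim1)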
@@ -403,21 +355,6 @@ def _reduce_scatter(input_, dim_, use_fp32=True, group=None):
 class _CopyToPolarRegion(torch.autograd.Function):
-    r"""
-    Copy tensor to polar region for distributed computation.
-    This class provides the forward and backward passes for copying
-    tensors to the polar region in distributed settings.
-
-    Parameters
-    ----------
-    input_: torch.Tensor
-        The tensor to copy
-
-    Returns
-    -------
-    output: torch.Tensor
-        The reduced and scattered tensor
-    """
 
     @staticmethod
     def symbolic(graph, input_):
@@ -464,12 +401,6 @@ class _CopyToPolarRegion(torch.autograd.Function):
 class _CopyToAzimuthRegion(torch.autograd.Function):
-    r"""
-    Copy tensor to azimuth region for distributed computation.
-    This class provides the forward and backward passes for copying
-    tensors to the azimuth region in distributed settings.
-    """
 
     @staticmethod
     def symbolic(graph, input_):
@@ -516,23 +447,6 @@ class _CopyToAzimuthRegion(torch.autograd.Function):
 class _ScatterToPolarRegion(torch.autograd.Function):
-    r"""
-    Scatter tensor to polar region for distributed computation.
-    This class provides the forward and backward passes for scattering
-    tensors to the polar region in distributed settings.
-
-    Parameters
-    ----------
-    input_: torch.Tensor
-        The tensor to scatter
-    dim_: int
-        The dimension to scatter along
-
-    Returns
-    -------
-    output: torch.Tensor
-        The scattered tensor
-    """
 
     @staticmethod
     def symbolic(graph, input_, dim_):
@@ -560,23 +474,7 @@ class _ScatterToPolarRegion(torch.autograd.Function):
 class _GatherFromPolarRegion(torch.autograd.Function):
-    r"""
-    Gather the input and keep it on the rank.
-
-    Parameters
-    ----------
-    input_: torch.Tensor
-        The tensor to gather
-    dim_: int
-        The dimension to gather along
-    shapes_: List[int]
-        The split sizes for the dimension to gather along
-
-    Returns
-    -------
-    output: torch.Tensor
-        The gathered tensor
-    """
 
     @staticmethod
     def symbolic(graph, input_, dim_, shapes_):
         return _gather(input_, dim_, shapes_, polar_group())
@@ -600,19 +498,6 @@ class _GatherFromPolarRegion(torch.autograd.Function):
 class _ReduceFromPolarRegion(torch.autograd.Function):
-    r"""
-    All-reduce the input from the polar region.
-
-    Parameters
-    ----------
-    input_: torch.Tensor
-        The tensor to reduce
-
-    Returns
-    -------
-    output: torch.Tensor
-        The reduced tensor
-    """
 
     @staticmethod
     def symbolic(graph, input_):
@@ -636,19 +521,7 @@ class _ReduceFromPolarRegion(torch.autograd.Function):
 class _ReduceFromAzimuthRegion(torch.autograd.Function):
-    r"""
-    All-reduce the input from the azimuth region.
-
-    Parameters
-    ----------
-    input_: torch.Tensor
-        The tensor to reduce
-
-    Returns
-    -------
-    output: torch.Tensor
-        The reduced tensor
-    """
 
     @staticmethod
     def symbolic(graph, input_):
         if is_distributed_azimuth():
@@ -671,21 +544,7 @@ class _ReduceFromAzimuthRegion(torch.autograd.Function):
 class _ReduceFromScatterToPolarRegion(torch.autograd.Function):
-    r"""
-    All-reduce the input from the polar region and scatter back to polar region.
-
-    Parameters
-    ----------
-    input_: torch.Tensor
-        The tensor to reduce
-    dim_: int
-        The dimension to reduce along
-
-    Returns
-    -------
-    output: torch.Tensor
-        The reduced tensor
-    """
 
     @staticmethod
     def symbolic(graph, input_, dim_):
         if is_distributed_polar():
@@ -715,23 +574,6 @@ class _ReduceFromScatterToPolarRegion(torch.autograd.Function):
 class _GatherFromCopyToPolarRegion(torch.autograd.Function):
-    r"""
-    Gather the input from the polar region and register BWD AR, basically the inverse of reduce-scatter.
-
-    Parameters
-    ----------
-    input_: torch.Tensor
-        The tensor to gather
-    dim_: int
-        The dimension to gather along
-    shapes_: List[int]
-        The split sizes for the dimension to gather along
-
-    Returns
-    -------
-    output: torch.Tensor
-        The gathered tensor
-    """
 
     @staticmethod
     def symbolic(graph, input_, dim_, shapes_):
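Note: the remaining primitives pair a symbolic() hook (used when the graph is traced for export) with forward and backward passes that are mirror images of each other: copy pairs with all-reduce, scatter pairs with gather, and so on. The sketch below shows the classic identity-forward / all-reduce-backward shape of a copy-to-region primitive. It is a generic illustration only, not the torch-harmonics implementation, which as the hunks suggest also short-circuits when the corresponding communicator group is not active (is_distributed_polar / is_distributed_azimuth).

import torch
import torch.distributed as dist


class _ToyCopyToGroup(torch.autograd.Function):
    # Generic identity-forward / all-reduce-backward primitive: the value is
    # used as-is on every rank of the group, so the gradients flowing back
    # from those ranks must be summed. Hypothetical sketch, not _CopyToPolarRegion.

    @staticmethod
    def forward(ctx, input_, group):
        ctx.group = group
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        grad = grad_output.contiguous()
        dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=ctx.group)
        # the gradient for "group" (a non-tensor argument) is None
        return grad, None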