Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-harmonics
Commits
913e80d4
Commit
913e80d4
authored
Jun 30, 2025
by
apaaris
Committed by
Boris Bonev
Jul 21, 2025
Browse files
Improved docstrings in disco convolution
parent
313b1b73
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
88 additions
and
28 deletions
+88
-28
torch_harmonics/_disco_convolution.py
torch_harmonics/_disco_convolution.py
+88
-28
No files found.
torch_harmonics/_disco_convolution.py
View file @
913e80d4
...
...
@@ -76,16 +76,28 @@ class _DiscoS2ContractionCuda(torch.autograd.Function):
r
"""
Forward pass for CUDA S2 convolution contraction.
Parameters:
x: input tensor
roff_idx: row offset indices for sparse computation
ker_idx: kernel indices
row_idx: row indices for sparse computation
col_idx: column indices for sparse computation
vals: values for sparse computation
kernel_size: size of the kernel
nlat_out: number of output latitude points
nlon_out: number of output longitude points
Parameters
-----------
ctx: torch.autograd.function.Context
Context object
x: torch.Tensor
Input tensor
roff_idx: torch.Tensor
Row offset indices for sparse computation
ker_idx: torch.Tensor
Kernel indices
row_idx: torch.Tensor
Row indices for sparse computation
col_idx: torch.Tensor
Column indices for sparse computation
vals: torch.Tensor
Values for sparse computation
kernel_size: int
Size of the kernel
nlat_out: int
Number of output latitude points
nlon_out: int
Number of output longitude points
"""
ctx
.
save_for_backward
(
roff_idx
,
ker_idx
,
row_idx
,
col_idx
,
vals
)
ctx
.
kernel_size
=
kernel_size
...
...
@@ -104,11 +116,15 @@ class _DiscoS2ContractionCuda(torch.autograd.Function):
r
"""
Backward pass for CUDA S2 convolution contraction.
Parameters:
grad_output: gradient of the output
Parameters
-----------
grad_output: torch.Tensor
Gradient of the output
Returns:
gradient of the input
Returns
--------
grad_input: torch.Tensor
Gradient of the input
"""
roff_idx
,
ker_idx
,
row_idx
,
col_idx
,
vals
=
ctx
.
saved_tensors
gtype
=
grad_output
.
dtype
...
...
@@ -135,16 +151,28 @@ class _DiscoS2TransposeContractionCuda(torch.autograd.Function):
r
"""
Forward pass for CUDA transpose S2 convolution contraction.
Parameters:
x: input tensor
roff_idx: row offset indices for sparse computation
ker_idx: kernel indices
row_idx: row indices for sparse computation
col_idx: column indices for sparse computation
vals: values for sparse computation
kernel_size: size of the kernel
nlat_out: number of output latitude points
nlon_out: number of output longitude points
Parameters
-----------
ctx: torch.autograd.function.Context
Context object
x: torch.Tensor
Input tensor
roff_idx: torch.Tensor
Row offset indices for sparse computation
ker_idx: torch.Tensor
Kernel indices
row_idx: torch.Tensor
Row indices for sparse computation
col_idx: torch.Tensor
Column indices for sparse computation
vals: torch.Tensor
Values for sparse computation
kernel_size: int
Size of the kernel
nlat_out: int
Number of output latitude points
nlon_out: int
Number of output longitude points
"""
ctx
.
save_for_backward
(
roff_idx
,
ker_idx
,
row_idx
,
col_idx
,
vals
)
ctx
.
kernel_size
=
kernel_size
...
...
@@ -163,11 +191,15 @@ class _DiscoS2TransposeContractionCuda(torch.autograd.Function):
r
"""
Backward pass for CUDA transpose S2 convolution contraction.
Parameters:
grad_output: gradient of the output
Parameters
-----------
grad_output: torch.Tensor
Gradient of the output
Returns:
gradient of the input
Returns
--------
grad_input: torch.Tensor
Gradient of the input
"""
roff_idx
,
ker_idx
,
row_idx
,
col_idx
,
vals
=
ctx
.
saved_tensors
gtype
=
grad_output
.
dtype
...
...
@@ -197,6 +229,20 @@ def _disco_s2_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: in
Reference implementation of the custom contraction as described in [1]. This requires repeated
shifting of the input tensor, which can potentially be costly. For an efficient implementation
on GPU, make sure to use the custom kernel written in CUDA.
Parameters
-----------
x: torch.Tensor
Input tensor
psi: torch.Tensor
Kernel tensor
nlon_out: int
Number of output longitude points
Returns
--------
y: torch.Tensor
Output tensor
"""
assert
len
(
psi
.
shape
)
==
3
assert
len
(
x
.
shape
)
==
4
...
...
@@ -233,6 +279,20 @@ def _disco_s2_transpose_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nl
Reference implementation of the custom contraction as described in [1]. This requires repeated
shifting of the input tensor, which can potentially be costly. For an efficient implementation
on GPU, make sure to use the custom kernel written in CUDA.
Parameters
-----------
x: torch.Tensor
Input tensor
psi: torch.Tensor
Kernel tensor
nlon_out: int
Number of output longitude points
Returns
--------
y: torch.Tensor
Output tensor
"""
assert
len
(
psi
.
shape
)
==
3
assert
len
(
x
.
shape
)
==
5
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment