OpenDAS / torch-harmonics · Commits

Commit 913e80d4
Authored Jun 30, 2025 by apaaris; committed by Boris Bonev on Jul 21, 2025
Parent: 313b1b73

Improved docstrings in disco convolution

Showing 1 changed file with 88 additions and 28 deletions:
torch_harmonics/_disco_convolution.py (+88, -28)
@@ -76,16 +76,28 @@ class _DiscoS2ContractionCuda(torch.autograd.Function):
         r"""
         Forward pass for CUDA S2 convolution contraction.
 
-        Parameters:
-        x: input tensor
-        roff_idx: row offset indices for sparse computation
-        ker_idx: kernel indices
-        row_idx: row indices for sparse computation
-        col_idx: column indices for sparse computation
-        vals: values for sparse computation
-        kernel_size: size of the kernel
-        nlat_out: number of output latitude points
-        nlon_out: number of output longitude points
+        Parameters
+        -----------
+        ctx: torch.autograd.function.Context
+            Context object
+        x: torch.Tensor
+            Input tensor
+        roff_idx: torch.Tensor
+            Row offset indices for sparse computation
+        ker_idx: torch.Tensor
+            Kernel indices
+        row_idx: torch.Tensor
+            Row indices for sparse computation
+        col_idx: torch.Tensor
+            Column indices for sparse computation
+        vals: torch.Tensor
+            Values for sparse computation
+        kernel_size: int
+            Size of the kernel
+        nlat_out: int
+            Number of output latitude points
+        nlon_out: int
+            Number of output longitude points
         """
         ctx.save_for_backward(roff_idx, ker_idx, row_idx, col_idx, vals)
         ctx.kernel_size = kernel_size
@@ -104,11 +116,15 @@ class _DiscoS2ContractionCuda(torch.autograd.Function):
         r"""
         Backward pass for CUDA S2 convolution contraction.
 
-        Parameters:
-        grad_output: gradient of the output
+        Parameters
+        -----------
+        grad_output: torch.Tensor
+            Gradient of the output
 
-        Returns:
-        gradient of the input
+        Returns
+        --------
+        grad_input: torch.Tensor
+            Gradient of the input
         """
         roff_idx, ker_idx, row_idx, col_idx, vals = ctx.saved_tensors
         gtype = grad_output.dtype
@@ -135,16 +151,28 @@ class _DiscoS2TransposeContractionCuda(torch.autograd.Function):
         r"""
         Forward pass for CUDA transpose S2 convolution contraction.
 
-        Parameters:
-        x: input tensor
-        roff_idx: row offset indices for sparse computation
-        ker_idx: kernel indices
-        row_idx: row indices for sparse computation
-        col_idx: column indices for sparse computation
-        vals: values for sparse computation
-        kernel_size: size of the kernel
-        nlat_out: number of output latitude points
-        nlon_out: number of output longitude points
+        Parameters
+        -----------
+        ctx: torch.autograd.function.Context
+            Context object
+        x: torch.Tensor
+            Input tensor
+        roff_idx: torch.Tensor
+            Row offset indices for sparse computation
+        ker_idx: torch.Tensor
+            Kernel indices
+        row_idx: torch.Tensor
+            Row indices for sparse computation
+        col_idx: torch.Tensor
+            Column indices for sparse computation
+        vals: torch.Tensor
+            Values for sparse computation
+        kernel_size: int
+            Size of the kernel
+        nlat_out: int
+            Number of output latitude points
+        nlon_out: int
+            Number of output longitude points
         """
         ctx.save_for_backward(roff_idx, ker_idx, row_idx, col_idx, vals)
         ctx.kernel_size = kernel_size
@@ -163,11 +191,15 @@ class _DiscoS2TransposeContractionCuda(torch.autograd.Function):
         r"""
         Backward pass for CUDA transpose S2 convolution contraction.
 
-        Parameters:
-        grad_output: gradient of the output
+        Parameters
+        -----------
+        grad_output: torch.Tensor
+            Gradient of the output
 
-        Returns:
-        gradient of the input
+        Returns
+        --------
+        grad_input: torch.Tensor
+            Gradient of the input
         """
         roff_idx, ker_idx, row_idx, col_idx, vals = ctx.saved_tensors
         gtype = grad_output.dtype
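For readers unfamiliar with the ctx pattern these docstrings now document, the sketch below shows the general torch.autograd.Function shape used in the four hunks above: the forward stashes the non-differentiable index/value tensors with ctx.save_for_backward and plain attributes on ctx, and the backward retrieves them from ctx.saved_tensors. The class name, toy contraction body, and tensor shapes are illustrative placeholders, not the library's actual CUDA kernel.

import torch

class _SparseContractionSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, vals, kernel_size):
        # Save non-differentiable tensors for the backward pass; plain
        # Python values (like kernel_size) go straight onto ctx.
        ctx.save_for_backward(vals)
        ctx.kernel_size = kernel_size
        # Stand-in body for the CUDA contraction kernel: a simple weighting.
        return x * vals

    @staticmethod
    def backward(ctx, grad_output):
        (vals,) = ctx.saved_tensors
        gtype = grad_output.dtype  # the hunks above likewise record the dtype
        grad_input = (grad_output * vals).to(gtype)
        # One return slot per forward argument; None for non-differentiable ones.
        return grad_input, None, None

x = torch.randn(4, requires_grad=True)
vals = torch.rand(4)
y = _SparseContractionSketch.apply(x, vals, 3)
y.sum().backward()
print(x.grad)  # equals vals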
@@ -197,6 +229,20 @@ def _disco_s2_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: in
     Reference implementation of the custom contraction as described in [1]. This requires repeated
     shifting of the input tensor, which can potentially be costly. For an efficient implementation
     on GPU, make sure to use the custom kernel written in CUDA.
+
+    Parameters
+    -----------
+    x: torch.Tensor
+        Input tensor
+    psi: torch.Tensor
+        Kernel tensor
+    nlon_out: int
+        Number of output longitude points
+
+    Returns
+    --------
+    y: torch.Tensor
+        Output tensor
     """
     assert len(psi.shape) == 3
     assert len(x.shape) == 4
@@ -233,6 +279,20 @@ def _disco_s2_transpose_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nl
     Reference implementation of the custom contraction as described in [1]. This requires repeated
     shifting of the input tensor, which can potentially be costly. For an efficient implementation
     on GPU, make sure to use the custom kernel written in CUDA.
+
+    Parameters
+    -----------
+    x: torch.Tensor
+        Input tensor
+    psi: torch.Tensor
+        Kernel tensor
+    nlon_out: int
+        Number of output longitude points
+
+    Returns
+    --------
+    y: torch.Tensor
+        Output tensor
     """
     assert len(psi.shape) == 3
     assert len(x.shape) == 5
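The "repeated shifting" that both docstrings warn about can be made concrete with a small stand-alone sketch: because the kernel is shared across longitudes, each output longitude point can be computed by rolling the input along its longitude axis and applying one fixed contraction. The function name, shapes, and the dense einsum below are illustrative assumptions; the diff does not show the actual contraction body, which works on sparse index/value tensors.

import torch

def rolled_contraction(x: torch.Tensor, psi: torch.Tensor, nlon_out: int) -> torch.Tensor:
    # Illustrative shapes: x is (batch, chans, nlat_in, nlon_in),
    # psi is a dense (nlat_out, nlat_in, nlon_in) kernel.
    batch, chans, nlat_in, nlon_in = x.shape
    nlat_out = psi.shape[0]
    pscale = nlon_in // nlon_out  # longitude stride between output points
    y = x.new_zeros(batch, chans, nlat_out, nlon_out)
    for pout in range(nlon_out):
        # Contract the full (nlat_in, nlon_in) plane against the kernel...
        y[..., pout] = torch.einsum("bcij,oij->bco", x, psi)
        # ...then roll the input in longitude to realign it for the next
        # output point; this repeated shift is the potentially costly part.
        x = torch.roll(x, -pscale, dims=-1)
    return y

x = torch.randn(2, 3, 8, 16)   # batch=2, chans=3, nlat_in=8, nlon_in=16
psi = torch.randn(4, 8, 16)    # nlat_out=4
print(rolled_contraction(x, psi, nlon_out=8).shape)  # torch.Size([2, 3, 4, 8])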