Commit 913e80d4 authored by apaaris, committed by Boris Bonev

Improved docstrings in disco convolution

parent 313b1b73
@@ -76,16 +76,28 @@ class _DiscoS2ContractionCuda(torch.autograd.Function):
r"""
Forward pass for CUDA S2 convolution contraction.
-Parameters:
-x: input tensor
-roff_idx: row offset indices for sparse computation
-ker_idx: kernel indices
-row_idx: row indices for sparse computation
-col_idx: column indices for sparse computation
-vals: values for sparse computation
-kernel_size: size of the kernel
-nlat_out: number of output latitude points
-nlon_out: number of output longitude points
+Parameters
+----------
+ctx: torch.autograd.function.Context
+    Context object
+x: torch.Tensor
+    Input tensor
+roff_idx: torch.Tensor
+    Row offset indices for sparse computation
+ker_idx: torch.Tensor
+    Kernel indices
+row_idx: torch.Tensor
+    Row indices for sparse computation
+col_idx: torch.Tensor
+    Column indices for sparse computation
+vals: torch.Tensor
+    Values for sparse computation
+kernel_size: int
+    Size of the kernel
+nlat_out: int
+    Number of output latitude points
+nlon_out: int
+    Number of output longitude points
"""
ctx.save_for_backward(roff_idx, ker_idx, row_idx, col_idx, vals)
ctx.kernel_size = kernel_size
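
For context, a torch.autograd.Function like this one is called through its static .apply method, which supplies ctx automatically. A minimal, hypothetical usage sketch (the index tensors are placeholders standing in for a precomputed sparse representation of the filter basis, not values from this diff):

# hypothetical usage; roff_idx, ker_idx, row_idx, col_idx and vals are assumed
# to encode the precomputed sparse filter basis in a CSR-like layout
y = _DiscoS2ContractionCuda.apply(x, roff_idx, ker_idx, row_idx, col_idx, vals,
                                  kernel_size, nlat_out, nlon_out)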
@@ -104,11 +116,15 @@ class _DiscoS2ContractionCuda(torch.autograd.Function):
r"""
Backward pass for CUDA S2 convolution contraction.
-Parameters:
-grad_output: gradient of the output
+Parameters
+----------
+grad_output: torch.Tensor
+    Gradient of the output
-Returns:
-gradient of the input
+Returns
+-------
+grad_input: torch.Tensor
+    Gradient of the input
"""
roff_idx, ker_idx, row_idx, col_idx, vals = ctx.saved_tensors
gtype = grad_output.dtype
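
A custom backward like this one is commonly validated against numerical gradients. A hedged sketch using torch.autograd.gradcheck (shapes and index tensors are placeholders; gradcheck wants double precision, which assumes the CUDA kernel supports it):

from torch.autograd import gradcheck

# hypothetical check: only x is differentiated; the index tensors and the
# integer sizes are passed through to forward unchanged
inputs = (x.double().requires_grad_(True), roff_idx, ker_idx, row_idx, col_idx,
          vals.double(), kernel_size, nlat_out, nlon_out)
gradcheck(_DiscoS2ContractionCuda.apply, inputs, eps=1e-6, atol=1e-4)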
@@ -135,16 +151,28 @@ class _DiscoS2TransposeContractionCuda(torch.autograd.Function):
r"""
Forward pass for CUDA transpose S2 convolution contraction.
-Parameters:
-x: input tensor
-roff_idx: row offset indices for sparse computation
-ker_idx: kernel indices
-row_idx: row indices for sparse computation
-col_idx: column indices for sparse computation
-vals: values for sparse computation
-kernel_size: size of the kernel
-nlat_out: number of output latitude points
-nlon_out: number of output longitude points
+Parameters
+----------
+ctx: torch.autograd.function.Context
+    Context object
+x: torch.Tensor
+    Input tensor
+roff_idx: torch.Tensor
+    Row offset indices for sparse computation
+ker_idx: torch.Tensor
+    Kernel indices
+row_idx: torch.Tensor
+    Row indices for sparse computation
+col_idx: torch.Tensor
+    Column indices for sparse computation
+vals: torch.Tensor
+    Values for sparse computation
+kernel_size: int
+    Size of the kernel
+nlat_out: int
+    Number of output latitude points
+nlon_out: int
+    Number of output longitude points
"""
ctx.save_for_backward(roff_idx, ker_idx, row_idx, col_idx, vals)
ctx.kernel_size = kernel_size
@@ -163,11 +191,15 @@ class _DiscoS2TransposeContractionCuda(torch.autograd.Function):
r"""
Backward pass for CUDA transpose S2 convolution contraction.
-Parameters:
-grad_output: gradient of the output
+Parameters
+----------
+grad_output: torch.Tensor
+    Gradient of the output
-Returns:
-gradient of the input
+Returns
+-------
+grad_input: torch.Tensor
+    Gradient of the input
"""
roff_idx, ker_idx, row_idx, col_idx, vals = ctx.saved_tensors
gtype = grad_output.dtype
@@ -197,6 +229,20 @@ def _disco_s2_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: int):
Reference implementation of the custom contraction as described in [1]. This requires repeated
shifting of the input tensor, which can potentially be costly. For an efficient implementation
on GPU, make sure to use the custom kernel written in CUDA.
+
+Parameters
+----------
+x: torch.Tensor
+    Input tensor
+psi: torch.Tensor
+    Kernel tensor
+nlon_out: int
+    Number of output longitude points
+
+Returns
+-------
+y: torch.Tensor
+    Output tensor
"""
assert len(psi.shape) == 3
assert len(x.shape) == 4
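
A minimal sketch of the shift-and-contract idea described above, assuming psi is laid out as (kernel_size, nlat_out, nlat_in * nlon_in) and using a dense stand-in for the sparse kernel; this is illustrative, not the library's exact code:

import torch

def disco_contraction_sketch(x: torch.Tensor, psi: torch.Tensor, nlon_out: int) -> torch.Tensor:
    # x: (batch, chans, nlat_in, nlon_in); psi: (kernel_size, nlat_out, nlat_in * nlon_in)
    batch, chans, nlat_in, nlon_in = x.shape
    kernel_size, nlat_out, _ = psi.shape
    pscale = nlon_in // nlon_out  # longitude shift between adjacent output points

    # broadcast the input across the kernel dimension: (kernel_size, nlat_in, nlon_in, batch*chans)
    xk = x.reshape(1, batch * chans, nlat_in, nlon_in).permute(0, 2, 3, 1).expand(kernel_size, -1, -1, -1)

    y = torch.zeros(nlon_out, kernel_size, nlat_out, batch * chans, dtype=x.dtype, device=x.device)
    for p in range(nlon_out):
        # contract the kernel matrix with the flattened input grid
        y[p] = torch.bmm(psi, xk.reshape(kernel_size, nlat_in * nlon_in, -1))
        # shift the input in longitude before computing the next output point
        xk = torch.roll(xk, -pscale, dims=2)

    return y.permute(3, 1, 2, 0).reshape(batch, chans, kernel_size, nlat_out, nlon_out)

Rolling the input instead of re-indexing psi keeps the kernel matrix fixed across iterations, at the cost of nlon_out rolls of the input, which is the overhead the docstring warns about.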
@@ -233,6 +279,20 @@ def _disco_s2_transpose_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: int):
Reference implementation of the custom contraction as described in [1]. This requires repeated
shifting of the input tensor, which can potentially be costly. For an efficient implementation
on GPU, make sure to use the custom kernel written in CUDA.
+
+Parameters
+----------
+x: torch.Tensor
+    Input tensor
+psi: torch.Tensor
+    Kernel tensor
+nlon_out: int
+    Number of output longitude points
+
+Returns
+-------
+y: torch.Tensor
+    Output tensor
"""
assert len(psi.shape) == 3
assert len(x.shape) == 5
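
The transpose version consumes the extra kernel dimension on its input and scatters those per-kernel contributions back onto the output grid. A hedged shape-level usage sketch (sizes are illustrative, and psi is a placeholder for the precomputed filter basis of the transpose operator):

# hypothetical sizes; psi's exact layout is defined where the filter basis is built
batch, chans, kernel_size, nlat_in, nlon_in, nlon_out = 2, 3, 4, 12, 24, 48
x = torch.randn(batch, chans, kernel_size, nlat_in, nlon_in)
y = _disco_s2_transpose_contraction_torch(x, psi, nlon_out)
# expected: y of shape (batch, chans, nlat_out, nlon_out), the kernel axis contracted away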