Commit 913e80d4 authored by apaaris, committed by Boris Bonev

Improved docstrings in disco convolution

parent 313b1b73
@@ -76,16 +76,28 @@ class _DiscoS2ContractionCuda(torch.autograd.Function):
         r"""
         Forward pass for CUDA S2 convolution contraction.
 
-        Parameters:
-        x: input tensor
-        roff_idx: row offset indices for sparse computation
-        ker_idx: kernel indices
-        row_idx: row indices for sparse computation
-        col_idx: column indices for sparse computation
-        vals: values for sparse computation
-        kernel_size: size of the kernel
-        nlat_out: number of output latitude points
-        nlon_out: number of output longitude points
+        Parameters
+        -----------
+        ctx: torch.autograd.function.Context
+            Context object
+        x: torch.Tensor
+            Input tensor
+        roff_idx: torch.Tensor
+            Row offset indices for sparse computation
+        ker_idx: torch.Tensor
+            Kernel indices
+        row_idx: torch.Tensor
+            Row indices for sparse computation
+        col_idx: torch.Tensor
+            Column indices for sparse computation
+        vals: torch.Tensor
+            Values for sparse computation
+        kernel_size: int
+            Size of the kernel
+        nlat_out: int
+            Number of output latitude points
+        nlon_out: int
+            Number of output longitude points
         """
         ctx.save_for_backward(roff_idx, ker_idx, row_idx, col_idx, vals)
         ctx.kernel_size = kernel_size
@@ -104,11 +116,15 @@ class _DiscoS2ContractionCuda(torch.autograd.Function):
         r"""
         Backward pass for CUDA S2 convolution contraction.
 
-        Parameters:
-        grad_output: gradient of the output
+        Parameters
+        -----------
+        grad_output: torch.Tensor
+            Gradient of the output
 
-        Returns:
-        gradient of the input
+        Returns
+        --------
+        grad_input: torch.Tensor
+            Gradient of the input
         """
         roff_idx, ker_idx, row_idx, col_idx, vals = ctx.saved_tensors
         gtype = grad_output.dtype
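For readers unfamiliar with this pattern, the sketch below shows the generic torch.autograd.Function plumbing these docstrings describe: forward() receives ctx plus the documented arguments and stashes tensors with ctx.save_for_backward, while backward() retrieves them from ctx.saved_tensors and returns one gradient slot per forward input. The class name and the toy elementwise multiply are hypothetical stand-ins, not the actual CUDA contraction.

import torch

class _ToyContraction(torch.autograd.Function):
    # Hypothetical stand-in for _DiscoS2ContractionCuda: forward() stashes
    # whatever backward() will need, backward() reads it back -- the same
    # flow documented in the docstrings above.
    @staticmethod
    def forward(ctx, x, vals):
        ctx.save_for_backward(vals)
        return x * vals

    @staticmethod
    def backward(ctx, grad_output):
        (vals,) = ctx.saved_tensors
        # one gradient per forward() argument after ctx; None for vals
        return grad_output * vals, None

x = torch.randn(8, requires_grad=True)
y = _ToyContraction.apply(x, torch.full((8,), 2.0))
y.sum().backward()
print(x.grad)  # tensor of 2.0s

Note that autograd Functions are invoked through .apply(), which supplies ctx implicitly; the remaining arguments follow the documented signature.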
@@ -135,16 +151,28 @@ class _DiscoS2TransposeContractionCuda(torch.autograd.Function):
         r"""
         Forward pass for CUDA transpose S2 convolution contraction.
 
-        Parameters:
-        x: input tensor
-        roff_idx: row offset indices for sparse computation
-        ker_idx: kernel indices
-        row_idx: row indices for sparse computation
-        col_idx: column indices for sparse computation
-        vals: values for sparse computation
-        kernel_size: size of the kernel
-        nlat_out: number of output latitude points
-        nlon_out: number of output longitude points
+        Parameters
+        -----------
+        ctx: torch.autograd.function.Context
+            Context object
+        x: torch.Tensor
+            Input tensor
+        roff_idx: torch.Tensor
+            Row offset indices for sparse computation
+        ker_idx: torch.Tensor
+            Kernel indices
+        row_idx: torch.Tensor
+            Row indices for sparse computation
+        col_idx: torch.Tensor
+            Column indices for sparse computation
+        vals: torch.Tensor
+            Values for sparse computation
+        kernel_size: int
+            Size of the kernel
+        nlat_out: int
+            Number of output latitude points
+        nlon_out: int
+            Number of output longitude points
         """
         ctx.save_for_backward(roff_idx, ker_idx, row_idx, col_idx, vals)
         ctx.kernel_size = kernel_size
@@ -163,11 +191,15 @@ class _DiscoS2TransposeContractionCuda(torch.autograd.Function):
         r"""
         Backward pass for CUDA transpose S2 convolution contraction.
 
-        Parameters:
-        grad_output: gradient of the output
+        Parameters
+        -----------
+        grad_output: torch.Tensor
+            Gradient of the output
 
-        Returns:
-        gradient of the input
+        Returns
+        --------
+        grad_input: torch.Tensor
+            Gradient of the input
         """
         roff_idx, ker_idx, row_idx, col_idx, vals = ctx.saved_tensors
         gtype = grad_output.dtype
@@ -197,6 +229,20 @@ def _disco_s2_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: int
     Reference implementation of the custom contraction as described in [1]. This requires repeated
     shifting of the input tensor, which can potentially be costly. For an efficient implementation
     on GPU, make sure to use the custom kernel written in CUDA.
+
+    Parameters
+    -----------
+    x: torch.Tensor
+        Input tensor
+    psi: torch.Tensor
+        Kernel tensor
+    nlon_out: int
+        Number of output longitude points
+
+    Returns
+    --------
+    y: torch.Tensor
+        Output tensor
     """
     assert len(psi.shape) == 3
     assert len(x.shape) == 4
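As a rough illustration of the "repeated shifting" the docstring warns about, consider the dense sketch below. The function name, the dense psi layout, and the assumption that nlon_in is an integer multiple of nlon_out are all inventions for illustration; the real implementation contracts against a sparse psi tensor.

import torch

def toy_shift_contraction(x: torch.Tensor, psi: torch.Tensor, nlon_out: int) -> torch.Tensor:
    # Assumed toy shapes (not the library's actual layout):
    #   x:   (batch, chans, nlat_in, nlon_in)          input signal
    #   psi: (kernel_size, nlat_out, nlat_in, nlon_in) dense filter stand-in
    batch, chans, nlat_in, nlon_in = x.shape
    kernel_size, nlat_out = psi.shape[:2]
    y = x.new_zeros(batch, chans, kernel_size, nlat_out, nlon_out)
    for p in range(nlon_out):
        # roll the input in longitude, then contract over the input grid;
        # this per-output-longitude roll is the repeated-shifting cost
        xs = torch.roll(x, shifts=-p * (nlon_in // nlon_out), dims=-1)
        y[..., p] = torch.einsum("bcij,koij->bcko", xs, psi)
    return y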
@@ -233,6 +279,20 @@ def _disco_s2_transpose_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: int
     Reference implementation of the custom contraction as described in [1]. This requires repeated
     shifting of the input tensor, which can potentially be costly. For an efficient implementation
     on GPU, make sure to use the custom kernel written in CUDA.
+
+    Parameters
+    -----------
+    x: torch.Tensor
+        Input tensor
+    psi: torch.Tensor
+        Kernel tensor
+    nlon_out: int
+        Number of output longitude points
+
+    Returns
+    --------
+    y: torch.Tensor
+        Output tensor
     """
     assert len(psi.shape) == 3
     assert len(x.shape) == 5
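A matching sketch of the transpose direction, again with invented names and dense shapes: the 5-D input mirrors the assert above (the extra axis being the kernel dimension), and contributions are scattered into a rolled output rather than gathered from a rolled input. This sketch assumes nlon_out is an integer multiple of nlon_in.

import torch

def toy_shift_transpose(x: torch.Tensor, psi: torch.Tensor, nlon_out: int) -> torch.Tensor:
    # Assumed toy shapes (not the library's actual layout):
    #   x:   (batch, chans, kernel_size, nlat_in, nlon_in)  5-D input
    #   psi: (kernel_size, nlat_in, nlat_out, nlon_out)     dense filter stand-in
    batch, chans, kernel_size, nlat_in, nlon_in = x.shape
    nlat_out = psi.shape[2]
    y = x.new_zeros(batch, chans, nlat_out, nlon_out)
    for p in range(nlon_in):
        # scatter each input-longitude slice into a rolled copy of the output
        contrib = torch.einsum("bcki,kiop->bcop", x[..., p], psi)
        y = y + torch.roll(contrib, shifts=p * (nlon_out // nlon_in), dims=-1)
    return y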