Commit 6373534a authored by Andrea Paris, committed by Boris Bonev

removed docstrings from autograd functions

parent 95fc83a0
@@ -62,43 +62,13 @@ def _get_psi(kernel_size: int, psi_idx: torch.Tensor, psi_vals: torch.Tensor, nl
class _DiscoS2ContractionCuda(torch.autograd.Function):
r"""
CUDA implementation of the discrete-continuous convolution contraction on the sphere.
This class provides the forward and backward passes for efficient GPU computation
of the S2 convolution operation using custom CUDA kernels.
"""
@staticmethod
@custom_fwd(device_type="cuda")
def forward(ctx, x: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor,
row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor,
kernel_size: int, nlat_out: int, nlon_out: int):
r"""
Forward pass for CUDA S2 convolution contraction.
Parameters
-----------
ctx: torch.autograd.function.Context
Context object
x: torch.Tensor
Input tensor
roff_idx: torch.Tensor
Row offset indices for sparse computation
ker_idx: torch.Tensor
Kernel indices
row_idx: torch.Tensor
Row indices for sparse computation
col_idx: torch.Tensor
Column indices for sparse computation
vals: torch.Tensor
Values for sparse computation
kernel_size: int
Size of the kernel
nlat_out: int
Number of output latitude points
nlon_out: int
Number of output longitude points
"""
ctx.save_for_backward(roff_idx, ker_idx, row_idx, col_idx, vals)
ctx.kernel_size = kernel_size
ctx.nlat_in = x.shape[-2]
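For readers skimming the diff: the wiring above is the standard torch.autograd.Function pattern, in which tensor arguments are stashed via ctx.save_for_backward, plain Python scalars are attached to ctx directly, and backward returns one gradient slot per forward argument. A toy but runnable analogue of that pattern (the gather-and-scale op here is illustrative, not the library's contraction):

import torch

class _ScaledGather(torch.autograd.Function):
    # toy stand-in for the contraction: gather columns of x and scale them
    @staticmethod
    def forward(ctx, x, col_idx, vals):
        ctx.save_for_backward(col_idx, vals)   # tensors go through save_for_backward
        ctx.ncols = x.shape[-1]                # plain ints sit directly on ctx
        return x[..., col_idx] * vals

    @staticmethod
    def backward(ctx, grad_output):
        col_idx, vals = ctx.saved_tensors
        grad_x = torch.zeros(*grad_output.shape[:-1], ctx.ncols,
                             dtype=grad_output.dtype, device=grad_output.device)
        grad_x.index_add_(-1, col_idx, grad_output * vals)  # adjoint of the gather
        return grad_x, None, None              # one slot per forward argument

Like the classes in this file, such a function is invoked through .apply rather than called directly.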
@@ -113,19 +83,7 @@ class _DiscoS2ContractionCuda(torch.autograd.Function):
@staticmethod
@custom_bwd(device_type="cuda")
def backward(ctx, grad_output):
r"""
Backward pass for CUDA S2 convolution contraction.
Parameters
-----------
grad_output: torch.Tensor
Gradient of the output
Returns
--------
grad_input: torch.Tensor
Gradient of the input
"""
roff_idx, ker_idx, row_idx, col_idx, vals = ctx.saved_tensors
gtype = grad_output.dtype
grad_output = grad_output.to(torch.float32).contiguous()
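Note the dtype round-trip in the backward passes: the incoming gradient's dtype is recorded, the sparse accumulation runs in float32 (stable under autocast), and the result is cast back. A small helper expressing that pattern as a sketch, with bwd_kernel standing in for the actual CUDA launch:

import torch

def upcast_backward(grad_output: torch.Tensor, bwd_kernel) -> torch.Tensor:
    # accumulate in float32, then restore the caller's (possibly half) dtype
    gtype = grad_output.dtype
    grad_input = bwd_kernel(grad_output.to(torch.float32).contiguous())
    return grad_input.to(gtype)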
@@ -137,43 +95,13 @@ class _DiscoS2ContractionCuda(torch.autograd.Function):
class _DiscoS2TransposeContractionCuda(torch.autograd.Function):
r"""
CUDA implementation of the transpose discrete-continuous convolution contraction on the sphere.
This class provides the forward and backward passes for efficient GPU computation
of the transpose S2 convolution operation using custom CUDA kernels.
"""
@staticmethod
@custom_fwd(device_type="cuda")
def forward(ctx, x: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor,
row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor,
kernel_size: int, nlat_out: int, nlon_out: int):
r"""
Forward pass for CUDA transpose S2 convolution contraction.
Parameters
-----------
ctx: torch.autograd.function.Context
Context object
x: torch.Tensor
Input tensor
roff_idx: torch.Tensor
Row offset indices for sparse computation
ker_idx: torch.Tensor
Kernel indices
row_idx: torch.Tensor
Row indices for sparse computation
col_idx: torch.Tensor
Column indices for sparse computation
vals: torch.Tensor
Values for sparse computation
kernel_size: int
Size of the kernel
nlat_out: int
Number of output latitude points
nlon_out: int
Number of output longitude points
"""
ctx.save_for_backward(roff_idx, ker_idx, row_idx, col_idx, vals)
ctx.kernel_size = kernel_size
ctx.nlat_in = x.shape[-2]
@@ -188,19 +116,7 @@ class _DiscoS2TransposeContractionCuda(torch.autograd.Function):
@staticmethod
@custom_bwd(device_type="cuda")
def backward(ctx, grad_output):
r"""
Backward pass for CUDA transpose S2 convolution contraction.
Parameters
-----------
grad_output: torch.Tensor
Gradient of the output
Returns
--------
grad_input: torch.Tensor
Gradient of the input
"""
roff_idx, ker_idx, row_idx, col_idx, vals = ctx.saved_tensors
gtype = grad_output.dtype
grad_output = grad_output.to(torch.float32).contiguous()
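The transpose contraction is the adjoint of the forward contraction, which is why each class's backward can reuse the other's kernel. One way to sanity-check that pairing is a dot-product (adjoint) test; forward_op and transpose_op below are hypothetical callables wrapping the two .apply calls:

import torch

def adjoint_test(forward_op, transpose_op, x_shape, y_shape, atol=1e-4):
    # checks <A x, y> == <x, A^T y> for randomly drawn x and y
    x = torch.randn(x_shape)
    y = torch.randn(y_shape)
    lhs = torch.sum(forward_op(x) * y)
    rhs = torch.sum(x * transpose_op(y))
    assert torch.allclose(lhs, rhs, atol=atol), (lhs.item(), rhs.item())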
...
@@ -459,11 +459,6 @@ def _neighborhood_attention_s2_bwd_dq_torch(kx: torch.Tensor, vx: torch.Tensor,
return dqy
class _NeighborhoodAttentionS2(torch.autograd.Function):
r"""
CPU implementation of neighborhood attention on the sphere (S2).
This class provides the forward and backward passes for efficient CPU computation
of neighborhood attention operations using sparse tensor operations.
"""
@staticmethod
@custom_fwd(device_type="cpu")
@@ -472,44 +467,7 @@ class _NeighborhoodAttentionS2(torch.autograd.Function):
bk: Union[torch.Tensor, None], bv: Union[torch.Tensor, None], bq: Union[torch.Tensor, None],
quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
nh: int, nlon_in: int, nlat_out: int, nlon_out: int):
r"""
Forward pass for CPU neighborhood attention on S2.
Parameters
-----------
k: torch.Tensor
Key tensor
v: torch.Tensor
Value tensor
q: torch.Tensor
Query tensor
wk: torch.Tensor
Key weight tensor
wv: torch.Tensor
Value weight tensor
wq: torch.Tensor
Query weight tensor
bk: torch.Tensor or None
Key bias tensor (optional)
bv: torch.Tensor or None
Value bias tensor (optional)
bq: torch.Tensor or None
Query bias tensor (optional)
quad_weights: torch.Tensor
Quadrature weights for spherical integration
col_idx: torch.Tensor
Column indices for sparse computation
row_off: torch.Tensor
Row offsets for sparse computation
nh: int
Number of attention heads
nlon_in: int
Number of input longitude points
nlat_out: int
Number of output latitude points
nlon_out: int
Number of output longitude points
"""
ctx.save_for_backward(col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq)
ctx.nh = nh
ctx.nlon_in = nlon_in
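In rough terms, the operation these classes implement is attention restricted to a spherical neighborhood: for each output point, the CSR pair (row_off, col_idx) lists the admissible key/value points, and the quadrature weights enter the normalization. A simplified per-row reference for already-projected single-head q, k, v of shape (npoints, dim), folding the quadrature weights into the logits as log-weights (an assumption about the exact normalization, not a statement of the library's kernel):

import torch

def neighborhood_attention_reference(q, k, v, quad_weights, row_off, col_idx):
    # q, k, v: (npoints, dim); row_off/col_idx: CSR neighborhood structure
    out = torch.zeros_like(v)
    for i in range(q.shape[0]):
        nbrs = col_idx[row_off[i]:row_off[i + 1]]
        logits = k[nbrs] @ q[i] + torch.log(quad_weights[nbrs])
        alpha = torch.softmax(logits, dim=0)   # quadrature-weighted attention over the neighborhood
        out[i] = alpha @ v[nbrs]
    return out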
@@ -704,11 +662,7 @@ def _neighborhood_attention_s2_torch(k: torch.Tensor, v: torch.Tensor, q: torch.
class _NeighborhoodAttentionS2Cuda(torch.autograd.Function):
r"""
CUDA implementation of neighborhood attention on the sphere (S2).
This class provides the forward and backward passes for efficient GPU computation
of neighborhood attention operations using custom CUDA kernels.
"""
@staticmethod
@custom_fwd(device_type="cuda")
@@ -717,46 +671,7 @@ class _NeighborhoodAttentionS2Cuda(torch.autograd.Function):
bk: Union[torch.Tensor, None], bv: Union[torch.Tensor, None], bq: Union[torch.Tensor, None],
quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
max_psi_nnz: int, nh: int, nlon_in: int, nlat_out: int, nlon_out: int):
r"""
Forward pass for CUDA neighborhood attention on S2.
Parameters
-----------
k: torch.Tensor
Key tensor
v: torch.Tensor
Value tensor
q: torch.Tensor
Query tensor
wk: torch.Tensor
Key weight tensor
wv: torch.Tensor
Value weight tensor
wq: torch.Tensor
Query weight tensor
bk: torch.Tensor or None
Key bias tensor (optional)
bv: torch.Tensor or None
Value bias tensor (optional)
bq: torch.Tensor or None
Query bias tensor (optional)
quad_weights: torch.Tensor
Quadrature weights for spherical integration
col_idx: torch.Tensor
Column indices for sparse computation
row_off: torch.Tensor
Row offsets for sparse computation
max_psi_nnz: int
Maximum number of non-zero elements in sparse tensor
nh: int
Number of attention heads
nlon_in: int
Number of input longitude points
nlat_out: int
Number of output latitude points
nlon_out: int
Number of output longitude points
"""
ctx.save_for_backward(col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq)
ctx.nh = nh
ctx.max_psi_nnz = max_psi_nnz
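Since the CPU and CUDA classes expose the same signature (the CUDA one adding only max_psi_nnz), a natural way to validate either backward pass is torch.autograd.gradcheck on small inputs, differentiating with respect to the floating-point argument while passing index tensors through. A sketch, assuming the op accepts float64 (kernels fixed to float32 would need looser fp32 tolerances):

import torch

def check_gradients(fn_apply, x, *extra):
    # differentiate with respect to x only; index tensors pass through untouched
    x = x.double().requires_grad_(True)
    return torch.autograd.gradcheck(lambda t: fn_apply(t, *extra), (x,),
                                    eps=1e-6, atol=1e-4)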
...
@@ -146,31 +146,6 @@ def _transpose(tensor, dim0, dim1, dim1_split_sizes, group=None, async_op=False)
class distributed_transpose_azimuth(torch.autograd.Function):
r"""
Distributed transpose operation for azimuthal dimension.
This class provides the forward and backward passes for distributed
tensor transposition along the azimuthal dimension.
Parameters
----------
tensor: torch.Tensor
The tensor to transpose
dim0: int
The first dimension to transpose
dim1: int
The second dimension to transpose
dim1_split_sizes: List[int]
The split sizes for the second dimension
Returns
-------
x_recv: List[torch.Tensor]
The split tensors
dim0_split_sizes: List[int]
The split sizes for the first dimension
req: dist.Request
The request object
"""
@staticmethod
@custom_fwd(device_type="cuda")
@@ -226,29 +201,6 @@ class distributed_transpose_azimuth(torch.autograd.Function):
class distributed_transpose_polar(torch.autograd.Function):
r"""
Distributed transpose operation for polar dimension.
This class provides the forward and backward passes for distributed
tensor transposition along the polar dimension.
Parameters
----------
x: torch.Tensor
The tensor to transpose
dims: List[int]
The dimensions to transpose
dim1_split_sizes: List[int]
The split sizes for the second dimension
Returns
-------
x: torch.Tensor
The transposed tensor
dim0_split_sizes: List[int]
The split sizes for the first dimension
req: dist.Request
The request object
"""
@staticmethod
@custom_fwd(device_type="cuda")
@@ -403,21 +355,6 @@ def _reduce_scatter(input_, dim_, use_fp32=True, group=None):
class _CopyToPolarRegion(torch.autograd.Function):
r"""
Copy tensor to polar region for distributed computation.
This class provides the forward and backward passes for copying
tensors to the polar region in distributed settings.
Parameters
----------
input_: torch.Tensor
The tensor to copy
Returns
-------
output: torch.Tensor
The reduced and scattered tensor
"""
@staticmethod
def symbolic(graph, input_):
@@ -464,12 +401,6 @@ class _CopyToPolarRegion(torch.autograd.Function):
class _CopyToAzimuthRegion(torch.autograd.Function):
r"""
Copy tensor to azimuth region for distributed computation.
This class provides the forward and backward passes for copying
tensors to the azimuth region in distributed settings.
"""
@staticmethod
def symbolic(graph, input_):
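_CopyToPolarRegion and _CopyToAzimuthRegion follow the classic model-parallel primitive: an identity in the forward pass paired with an all-reduce of the gradient in the backward pass (copy and reduce are adjoint operations). A generic sketch, with the process group left abstract:

import torch
import torch.distributed as dist

class _CopyToGroup(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_):
        return input_                      # identity on the way in

    @staticmethod
    def backward(ctx, grad_output):
        if dist.is_initialized():
            dist.all_reduce(grad_output)   # sum gradient contributions across the group
        return grad_output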
@@ -516,23 +447,6 @@ class _CopyToAzimuthRegion(torch.autograd.Function):
class _ScatterToPolarRegion(torch.autograd.Function):
r"""
Scatter tensor to polar region for distributed computation.
This class provides the forward and backward passes for scattering
tensors to the polar region in distributed settings.
Parameters
----------
input_: torch.Tensor
The tensor to scatter
dim_: int
The dimension to scatter along
Returns
-------
output: torch.Tensor
The scattered tensor
"""
@staticmethod
def symbolic(graph, input_, dim_):
@@ -560,23 +474,7 @@ class _ScatterToPolarRegion(torch.autograd.Function):
class _GatherFromPolarRegion(torch.autograd.Function):
r"""
Gather the input and keep it on the rank.
Parameters
----------
input_: torch.Tensor
The tensor to gather
dim_: int
The dimension to gather along
shapes_: List[int]
The split sizes for the dimension to gather along
Returns
-------
output: torch.Tensor
The gathered tensor
"""
@staticmethod
def symbolic(graph, input_, dim_, shapes_):
return _gather(input_, dim_, shapes_, polar_group())
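_ScatterToPolarRegion and _GatherFromPolarRegion are likewise adjoints of one another: what is split along a dimension in the forward pass must be gathered back in the backward pass, and vice versa. A concrete sketch of the scatter side on the default process group (equal shard sizes assumed; the library's own helpers also track uneven split shapes):

import torch
import torch.distributed as dist

class _ScatterToGroup(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_, dim_):
        # keep only this rank's shard of the input
        ctx.dim = dim_
        shards = torch.chunk(input_, dist.get_world_size(), dim=dim_)
        return shards[dist.get_rank()].contiguous()

    @staticmethod
    def backward(ctx, grad_output):
        # adjoint of the forward split: gather the shard gradients back together
        gathered = [torch.empty_like(grad_output) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered, grad_output.contiguous())
        return torch.cat(gathered, dim=ctx.dim), None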
@@ -600,19 +498,6 @@ class _GatherFromPolarRegion(torch.autograd.Function):
class _ReduceFromPolarRegion(torch.autograd.Function):
r"""
All-reduce the input from the polar region.
Parameters
----------
input_: torch.Tensor
The tensor to reduce
Returns
-------
output: torch.Tensor
The reduced tensor
"""
@staticmethod
def symbolic(graph, input_):
@@ -636,19 +521,7 @@ class _ReduceFromPolarRegion(torch.autograd.Function):
class _ReduceFromAzimuthRegion(torch.autograd.Function):
r"""
All-reduce the input from the azimuth region.
Parameters
----------
input_: torch.Tensor
The tensor to reduce
Returns
-------
output: torch.Tensor
The reduced tensor
"""
@staticmethod
def symbolic(graph, input_):
if is_distributed_azimuth():
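The reduce primitives are the mirror image of the copy primitives above: an all-reduce in the forward pass and an identity in the backward pass. Sketched generically:

import torch
import torch.distributed as dist

class _ReduceFromGroup(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_):
        output = input_.clone()
        if dist.is_initialized():
            dist.all_reduce(output)        # sum across the group on the way in
        return output

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output                 # identity: the adjoint of all-reduce is copy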
@@ -671,21 +544,7 @@ class _ReduceFromAzimuthRegion(torch.autograd.Function):
class _ReduceFromScatterToPolarRegion(torch.autograd.Function):
r"""
All-reduce the input from the polar region and scatter back to polar region.
Parameters
----------
input_: torch.Tensor
The tensor to reduce
dim_: int
The dimension to reduce along
Returns
-------
output: torch.Tensor
The reduced tensor
"""
@staticmethod
def symbolic(graph, input_, dim_):
if is_distributed_polar():
@@ -715,23 +574,6 @@ class _ReduceFromScatterToPolarRegion(torch.autograd.Function):
class _GatherFromCopyToPolarRegion(torch.autograd.Function):
r"""
Gather the input from the polar region and register a backward all-reduce; essentially the inverse of reduce-scatter.
Parameters
----------
input_: torch.Tensor
The tensor to gather
dim_: int
The dimension to gather along
shapes_: List[int]
The split sizes for the dimension to gather along
Returns
-------
output: torch.Tensor
The gathered tensor
"""
@staticmethod
def symbolic(graph, input_, dim_, shapes_):
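These last two primitives pair reduce-scatter with all-gather: _ReduceFromScatterToPolarRegion reduce-scatters in the forward pass and gathers gradients in the backward pass, while _GatherFromCopyToPolarRegion does the reverse (as its docstring notes, the inverse of reduce-scatter). The underlying collectives, sketched for a tensor sharded along dim 0 on recent PyTorch:

import torch
import torch.distributed as dist

def reduce_scatter_dim0(x):
    # forward of the first primitive / backward of the second
    out = torch.empty(x.shape[0] // dist.get_world_size(), *x.shape[1:],
                      dtype=x.dtype, device=x.device)
    dist.reduce_scatter_tensor(out, x.contiguous())
    return out

def all_gather_dim0(x):
    # forward of the second primitive / backward of the first
    out = torch.empty(x.shape[0] * dist.get_world_size(), *x.shape[1:],
                      dtype=x.dtype, device=x.device)
    dist.all_gather_into_tensor(out, x.contiguous())
    return out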
...