Commit 328200ab authored by Andrea Paris, committed by Boris Bonev

removed docstrings from backward passes

parent d70dee87
@@ -502,35 +502,6 @@ class _NeighborhoodAttentionS2(torch.autograd.Function):
     @staticmethod
     @custom_bwd(device_type="cpu")
     def backward(ctx, grad_output):
-        r"""
-        Backward pass for CPU neighborhood attention on S2.
-        Parameters
-        -----------
-        grad_output: torch.Tensor
-            Gradient of the output
-        Returns
-        --------
-        dk: torch.Tensor
-            Gradient of the key tensor
-        dv: torch.Tensor
-            Gradient of the value tensor
-        dq: torch.Tensor
-            Gradient of the query tensor
-        dwk: torch.Tensor
-            Gradient of the key weight tensor
-        dwv: torch.Tensor
-            Gradient of the value weight tensor
-        dwq: torch.Tensor
-            Gradient of the query weight tensor
-        dbk: torch.Tensor or None
-            Gradient of the key bias tensor
-        dbv: torch.Tensor or None
-            Gradient of the value bias tensor
-        dbq: torch.Tensor or None
-            Gradient of the query bias tensor
-        """
         col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq = ctx.saved_tensors
         nh = ctx.nh
         nlon_in = ctx.nlon_in
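The removed docstrings document the standard torch.autograd.Function contract rather than anything specific to these kernels: backward receives the gradient of the output and must return one gradient per argument of forward, in the same order, with None for non-differentiable inputs such as index buffers or plain Python values. A minimal, hypothetical sketch of that contract (illustrative names only, not the actual forward signature of _NeighborhoodAttentionS2), assuming PyTorch >= 2.4 where torch.amp.custom_fwd/custom_bwd accept a device_type argument:

```python
import torch
from torch.amp import custom_bwd, custom_fwd


class _ScaleByWeight(torch.autograd.Function):
    # hypothetical toy op: y = x * w, with a non-differentiable flag argument

    @staticmethod
    @custom_fwd(device_type="cpu")
    def forward(ctx, x, w, some_flag):
        ctx.save_for_backward(x, w)
        ctx.some_flag = some_flag
        return x * w

    @staticmethod
    @custom_bwd(device_type="cpu")
    def backward(ctx, grad_output):
        x, w = ctx.saved_tensors
        dx = grad_output * w
        dw = grad_output * x
        # one gradient per forward input, in order; None for the non-tensor flag
        return dx, dw, None


x = torch.randn(4, requires_grad=True)
w = torch.randn(4, requires_grad=True)
_ScaleByWeight.apply(x, w, True).sum().backward()
```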
@@ -712,35 +683,6 @@ class _NeighborhoodAttentionS2Cuda(torch.autograd.Function):
     @staticmethod
     @custom_bwd(device_type="cuda")
     def backward(ctx, grad_output):
-        r"""
-        Backward pass for CUDA neighborhood attention on S2.
-        Parameters
-        -----------
-        grad_output: torch.Tensor
-            Gradient of the output
-        Returns
-        --------
-        dk: torch.Tensor
-            Gradient of the key tensor
-        dv: torch.Tensor
-            Gradient of the value tensor
-        dq: torch.Tensor
-            Gradient of the query tensor
-        dwk: torch.Tensor
-            Gradient of the key weight tensor
-        dwv: torch.Tensor
-            Gradient of the value weight tensor
-        dwq: torch.Tensor
-            Gradient of the query weight tensor
-        dbk: torch.Tensor or None
-            Gradient of the key bias tensor
-        dbv: torch.Tensor or None
-            Gradient of the value bias tensor
-        dbq: torch.Tensor or None
-            Gradient of the query bias tensor
-        """
         col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq = ctx.saved_tensors
         nh = ctx.nh
         max_psi_nnz = ctx.max_psi_nnz
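Both the CPU and CUDA backward passes recover tensors through ctx.saved_tensors and plain metadata (nh, nlon_in, max_psi_nnz) through attributes stashed on ctx in the forward pass; only tensors that autograd should manage go through save_for_backward. A hedged sketch of that pattern with made-up names (a stand-in, not the S2 attention kernel itself):

```python
import torch
from torch.amp import custom_bwd, custom_fwd


class _WindowedSum(torch.autograd.Function):
    # illustrative only: sums a sliding window along the last dimension

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, x, window):
        # tensors go through save_for_backward so autograd can track their lifetime
        ctx.save_for_backward(x)
        # plain Python values (the role nh / max_psi_nnz play above) live on ctx
        ctx.window = window
        return x.unfold(-1, window, 1).sum(-1)

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        window = ctx.window
        grad_x = torch.zeros_like(x)
        # each output position summed `window` inputs, so scatter its gradient back
        for i in range(window):
            grad_x[..., i : i + grad_output.shape[-1]] += grad_output
        return grad_x, None
```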
@@ -162,19 +162,6 @@ class distributed_transpose_azimuth(torch.autograd.Function):
     @staticmethod
     @custom_bwd(device_type="cuda")
     def backward(ctx, go):
-        r"""
-        Backward pass for distributed azimuthal transpose.
-        Parameters
-        ----------
-        go: torch.Tensor
-            The gradient of the output
-        Returns
-        -------
-        gi: torch.Tensor
-            The gradient of the input
-        """
         dims = ctx.dims
         dim0_split_sizes = ctx.dim0_split_sizes
         # WAR for a potential contig check torch bug for channels last contig tensors
@@ -200,19 +187,7 @@ class distributed_transpose_polar(torch.autograd.Function):
     @staticmethod
     @custom_bwd(device_type="cuda")
     def backward(ctx, go):
-        r"""
-        Backward pass for distributed polar transpose.
-        Parameters
-        ----------
-        go: torch.Tensor
-            The gradient of the output
-        Returns
-        -------
-        gi: torch.Tensor
-            The gradient of the input
-        """
         dim = ctx.dim
         dim0_split_sizes = ctx.dim0_split_sizes
         # WAR for a potential contig check torch bug for channels last contig tensors
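The "WAR" comment that survives in both transpose backward passes appears to refer to channels-last tensors: a dense NHWC-strided tensor reports is_contiguous() as False under the default row-major memory format, so code that feeds collectives (which expect flat, row-major buffers) forces an explicit contiguous copy first. A small illustration of the underlying PyTorch behaviour (a general fact, not code from this repository):

```python
import torch

x = torch.randn(2, 3, 8, 8).to(memory_format=torch.channels_last)

# dense, but not contiguous in the default (row-major) sense
print(x.is_contiguous())                                    # False
print(x.is_contiguous(memory_format=torch.channels_last))   # True

# an explicit row-major copy satisfies contiguity checks before communication
x_rowmajor = x.contiguous(memory_format=torch.contiguous_format)
print(x_rowmajor.is_contiguous())                           # True
```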
@@ -337,19 +312,7 @@ class _CopyToPolarRegion(torch.autograd.Function):
     @staticmethod
     @custom_bwd(device_type="cuda")
     def backward(ctx, grad_output):
-        r"""
-        Backward pass for copying to polar region.
-        Parameters
-        ----------
-        grad_output: torch.Tensor
-            The gradient of the output
-        Returns
-        -------
-        grad_output: torch.Tensor
-            The gradient of the output
-        """
         if is_distributed_polar():
             return _reduce(grad_output, group=polar_group())
         else:
@@ -371,19 +334,7 @@ class _CopyToAzimuthRegion(torch.autograd.Function):
     @staticmethod
     @custom_bwd(device_type="cuda")
     def backward(ctx, grad_output):
-        r"""
-        Backward pass for copying to azimuth region.
-        Parameters
-        ----------
-        grad_output: torch.Tensor
-            The gradient of the output
-        Returns
-        -------
-        grad_output: torch.Tensor
-            The gradient of the output
-        """
         if is_distributed_azimuth():
             return _reduce(grad_output, group=azimuth_group())
         else:
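_CopyToPolarRegion and _CopyToAzimuthRegion follow the usual model-parallel identity/all-reduce pairing: copying a tensor into a parallel region is the identity in the forward pass, so the backward pass must sum the gradient contributions from every rank in that region, which is what the call to _reduce over the polar or azimuth group does. A generic, hypothetical sketch of the pattern (not the torch-harmonics helpers themselves):

```python
import torch
import torch.distributed as dist


def _allreduce_sum(tensor, group=None):
    # sum the partial gradients produced by each rank in the region
    if dist.is_initialized() and dist.get_world_size(group=group) > 1:
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
    return tensor


class _CopyToParallelRegion(torch.autograd.Function):
    # forward: identity (every rank sees the same activation)
    # backward: all-reduce (every rank contributes a partial gradient)

    @staticmethod
    def forward(ctx, x, group):
        ctx.group = group
        return x

    @staticmethod
    def backward(ctx, grad_output):
        return _allreduce_sum(grad_output, group=ctx.group), None
```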