You need to sign in or sign up before continuing.
Commit 328200ab authored by Andrea Paris's avatar Andrea Paris Committed by Boris Bonev
Browse files

removed docstrings from backward passes

parent d70dee87
...@@ -502,35 +502,6 @@ class _NeighborhoodAttentionS2(torch.autograd.Function): ...@@ -502,35 +502,6 @@ class _NeighborhoodAttentionS2(torch.autograd.Function):
@staticmethod @staticmethod
@custom_bwd(device_type="cpu") @custom_bwd(device_type="cpu")
def backward(ctx, grad_output): def backward(ctx, grad_output):
r"""
Backward pass for CPU neighborhood attention on S2.
Parameters
-----------
grad_output: torch.Tensor
Gradient of the output
Returns
--------
dk: torch.Tensor
Gradient of the key tensor
dv: torch.Tensor
Gradient of the value tensor
dq: torch.Tensor
Gradient of the query tensor
dwk: torch.Tensor
Gradient of the key weight tensor
dwv: torch.Tensor
Gradient of the value weight tensor
dwq: torch.Tensor
Gradient of the query weight tensor
dbk: torch.Tensor or None
Gradient of the key bias tensor
dbv: torch.Tensor or None
Gradient of the value bias tensor
dbq: torch.Tensor or None
Gradient of the query bias tensor
"""
col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq = ctx.saved_tensors col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq = ctx.saved_tensors
nh = ctx.nh nh = ctx.nh
nlon_in = ctx.nlon_in nlon_in = ctx.nlon_in
...@@ -712,35 +683,6 @@ class _NeighborhoodAttentionS2Cuda(torch.autograd.Function): ...@@ -712,35 +683,6 @@ class _NeighborhoodAttentionS2Cuda(torch.autograd.Function):
@staticmethod @staticmethod
@custom_bwd(device_type="cuda") @custom_bwd(device_type="cuda")
def backward(ctx, grad_output): def backward(ctx, grad_output):
r"""
Backward pass for CUDA neighborhood attention on S2.
Parameters
-----------
grad_output: torch.Tensor
Gradient of the output
Returns
--------
dk: torch.Tensor
Gradient of the key tensor
dv: torch.Tensor
Gradient of the value tensor
dq: torch.Tensor
Gradient of the query tensor
dwk: torch.Tensor
Gradient of the key weight tensor
dwv: torch.Tensor
Gradient of the value weight tensor
dwq: torch.Tensor
Gradient of the query weight tensor
dbk: torch.Tensor or None
Gradient of the key bias tensor
dbv: torch.Tensor or None
Gradient of the value bias tensor
dbq: torch.Tensor or None
Gradient of the query bias tensor
"""
col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq = ctx.saved_tensors col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq = ctx.saved_tensors
nh = ctx.nh nh = ctx.nh
max_psi_nnz = ctx.max_psi_nnz max_psi_nnz = ctx.max_psi_nnz
......
...@@ -162,19 +162,6 @@ class distributed_transpose_azimuth(torch.autograd.Function): ...@@ -162,19 +162,6 @@ class distributed_transpose_azimuth(torch.autograd.Function):
@staticmethod @staticmethod
@custom_bwd(device_type="cuda") @custom_bwd(device_type="cuda")
def backward(ctx, go): def backward(ctx, go):
r"""
Backward pass for distributed azimuthal transpose.
Parameters
----------
go: torch.Tensor
The gradient of the output
Returns
-------
gi: torch.Tensor
The gradient of the input
"""
dims = ctx.dims dims = ctx.dims
dim0_split_sizes = ctx.dim0_split_sizes dim0_split_sizes = ctx.dim0_split_sizes
# WAR for a potential contig check torch bug for channels last contig tensors # WAR for a potential contig check torch bug for channels last contig tensors
...@@ -200,19 +187,7 @@ class distributed_transpose_polar(torch.autograd.Function): ...@@ -200,19 +187,7 @@ class distributed_transpose_polar(torch.autograd.Function):
@staticmethod @staticmethod
@custom_bwd(device_type="cuda") @custom_bwd(device_type="cuda")
def backward(ctx, go): def backward(ctx, go):
r"""
Backward pass for distributed polar transpose.
Parameters
----------
go: torch.Tensor
The gradient of the output
Returns
-------
gi: torch.Tensor
The gradient of the input
"""
dim = ctx.dim dim = ctx.dim
dim0_split_sizes = ctx.dim0_split_sizes dim0_split_sizes = ctx.dim0_split_sizes
# WAR for a potential contig check torch bug for channels last contig tensors # WAR for a potential contig check torch bug for channels last contig tensors
...@@ -337,19 +312,7 @@ class _CopyToPolarRegion(torch.autograd.Function): ...@@ -337,19 +312,7 @@ class _CopyToPolarRegion(torch.autograd.Function):
@staticmethod @staticmethod
@custom_bwd(device_type="cuda") @custom_bwd(device_type="cuda")
def backward(ctx, grad_output): def backward(ctx, grad_output):
r"""
Backward pass for copying to polar region.
Parameters
----------
grad_output: torch.Tensor
The gradient of the output
Returns
-------
grad_output: torch.Tensor
The gradient of the output
"""
if is_distributed_polar(): if is_distributed_polar():
return _reduce(grad_output, group=polar_group()) return _reduce(grad_output, group=polar_group())
else: else:
...@@ -371,19 +334,7 @@ class _CopyToAzimuthRegion(torch.autograd.Function): ...@@ -371,19 +334,7 @@ class _CopyToAzimuthRegion(torch.autograd.Function):
@staticmethod @staticmethod
@custom_bwd(device_type="cuda") @custom_bwd(device_type="cuda")
def backward(ctx, grad_output): def backward(ctx, grad_output):
r"""
Backward pass for copying to azimuth region.
Parameters
----------
grad_output: torch.Tensor
The gradient of the output
Returns
-------
grad_output: torch.Tensor
The gradient of the output
"""
if is_distributed_azimuth(): if is_distributed_azimuth():
return _reduce(grad_output, group=azimuth_group()) return _reduce(grad_output, group=azimuth_group())
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment