Commit ffee67f9 authored by apaaris, committed by Boris Bonev
Browse files

Improved docstrings in neighborhood attention

parent 913e80d4
......@@ -475,23 +475,40 @@ class _NeighborhoodAttentionS2(torch.autograd.Function):
r"""
Forward pass for CPU neighborhood attention on S2.
Parameters:
k: key tensor
v: value tensor
q: query tensor
wk: key weight tensor
wv: value weight tensor
wq: query weight tensor
bk: key bias tensor (optional)
bv: value bias tensor (optional)
bq: query bias tensor (optional)
quad_weights: quadrature weights for spherical integration
col_idx: column indices for sparse computation
row_off: row offsets for sparse computation
nh: number of attention heads
nlon_in: number of input longitude points
nlat_out: number of output latitude points
nlon_out: number of output longitude points
Parameters
-----------
k: torch.Tensor
Key tensor
v: torch.Tensor
Value tensor
q: torch.Tensor
Query tensor
wk: torch.Tensor
Key weight tensor
wv: torch.Tensor
Value weight tensor
wq: torch.Tensor
Query weight tensor
bk: torch.Tensor or None
Key bias tensor (optional)
bv: torch.Tensor or None
Value bias tensor (optional)
bq: torch.Tensor or None
Query bias tensor (optional)
quad_weights: torch.Tensor
Quadrature weights for spherical integration
col_idx: torch.Tensor
Column indices for sparse computation
row_off: torch.Tensor
Row offsets for sparse computation
nh: int
Number of attention heads
nlon_in: int
Number of input longitude points
nlat_out: int
Number of output latitude points
nlon_out: int
Number of output longitude points
"""
ctx.save_for_backward(col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq)
ctx.nh = nh
......@@ -530,11 +547,31 @@ class _NeighborhoodAttentionS2(torch.autograd.Function):
r"""
Backward pass for CPU neighborhood attention on S2.
Parameters:
grad_output: gradient of the output
Parameters
-----------
grad_output: torch.Tensor
Gradient of the output
Returns:
gradients for all input tensors and parameters
Returns
--------
dk: torch.Tensor
Gradient of the key tensor
dv: torch.Tensor
Gradient of the value tensor
dq: torch.Tensor
Gradient of the query tensor
dwk: torch.Tensor
Gradient of the key weight tensor
dwv: torch.Tensor
Gradient of the value weight tensor
dwq: torch.Tensor
Gradient of the query weight tensor
dbk: torch.Tensor or None
Gradient of the key bias tensor
dbv: torch.Tensor or None
Gradient of the value bias tensor
dbq: torch.Tensor or None
Gradient of the query bias tensor
"""
col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq = ctx.saved_tensors
nh = ctx.nh
......@@ -683,24 +720,42 @@ class _NeighborhoodAttentionS2Cuda(torch.autograd.Function):
r"""
Forward pass for CUDA neighborhood attention on S2.
Parameters:
k: key tensor
v: value tensor
q: query tensor
wk: key weight tensor
wv: value weight tensor
wq: query weight tensor
bk: key bias tensor (optional)
bv: value bias tensor (optional)
bq: query bias tensor (optional)
quad_weights: quadrature weights for spherical integration
col_idx: column indices for sparse computation
row_off: row offsets for sparse computation
max_psi_nnz: maximum number of non-zero elements in sparse tensor
nh: number of attention heads
nlon_in: number of input longitude points
nlat_out: number of output latitude points
nlon_out: number of output longitude points
Parameters
-----------
k: torch.Tensor
Key tensor
v: torch.Tensor
Value tensor
q: torch.Tensor
Query tensor
wk: torch.Tensor
Key weight tensor
wv: torch.Tensor
Value weight tensor
wq: torch.Tensor
Query weight tensor
bk: torch.Tensor or None
Key bias tensor (optional)
bv: torch.Tensor or None
Value bias tensor (optional)
bq: torch.Tensor or None
Query bias tensor (optional)
quad_weights: torch.Tensor
Quadrature weights for spherical integration
col_idx: torch.Tensor
Column indices for sparse computation
row_off: torch.Tensor
Row offsets for sparse computation
max_psi_nnz: int
Maximum number of non-zero elements in sparse tensor
nh: int
Number of attention heads
nlon_in: int
Number of input longitude points
nlat_out: int
Number of output latitude points
nlon_out: int
Number of output longitude points
"""
ctx.save_for_backward(col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq)
ctx.nh = nh
......@@ -745,11 +800,31 @@ class _NeighborhoodAttentionS2Cuda(torch.autograd.Function):
r"""
Backward pass for CUDA neighborhood attention on S2.
Parameters:
grad_output: gradient of the output
Parameters
-----------
grad_output: torch.Tensor
Gradient of the output
Returns:
gradients for all input tensors and parameters
Returns
--------
dk: torch.Tensor
Gradient of the key tensor
dv: torch.Tensor
Gradient of the value tensor
dq: torch.Tensor
Gradient of the query tensor
dwk: torch.Tensor
Gradient of the key weight tensor
dwv: torch.Tensor
Gradient of the value weight tensor
dwq: torch.Tensor
Gradient of the query weight tensor
dbk: torch.Tensor or None
Gradient of the key bias tensor
dbv: torch.Tensor or None
Gradient of the value bias tensor
dbq: torch.Tensor or None
Gradient of the query bias tensor
"""
col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq = ctx.saved_tensors
nh = ctx.nh
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment