Commit ffee67f9 authored by apaaris's avatar apaaris Committed by Boris Bonev
Browse files

Improved docstrings in neighborhood attention

parent 913e80d4
...@@ -475,23 +475,40 @@ class _NeighborhoodAttentionS2(torch.autograd.Function): ...@@ -475,23 +475,40 @@ class _NeighborhoodAttentionS2(torch.autograd.Function):
r""" r"""
Forward pass for CPU neighborhood attention on S2. Forward pass for CPU neighborhood attention on S2.
Parameters: Parameters
k: key tensor ----------
v: value tensor k: torch.Tensor
q: query tensor Key tensor
wk: key weight tensor v: torch.Tensor
wv: value weight tensor Value tensor
wq: query weight tensor q: torch.Tensor
bk: key bias tensor (optional) Query tensor
bv: value bias tensor (optional) wk: torch.Tensor
bq: query bias tensor (optional) Key weight tensor
quad_weights: quadrature weights for spherical integration wv: torch.Tensor
col_idx: column indices for sparse computation Value weight tensor
row_off: row offsets for sparse computation wq: torch.Tensor
nh: number of attention heads Query weight tensor
nlon_in: number of input longitude points bk: torch.Tensor or None
nlat_out: number of output latitude points Key bias tensor (optional)
nlon_out: number of output longitude points bv: torch.Tensor or None
Value bias tensor (optional)
bq: torch.Tensor or None
Query bias tensor (optional)
quad_weights: torch.Tensor
Quadrature weights for spherical integration
col_idx: torch.Tensor
Column indices for sparse computation
row_off: torch.Tensor
Row offsets for sparse computation
nh: int
Number of attention heads
nlon_in: int
Number of input longitude points
nlat_out: int
Number of output latitude points
nlon_out: int
Number of output longitude points
""" """
ctx.save_for_backward(col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq) ctx.save_for_backward(col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq)
ctx.nh = nh ctx.nh = nh
...@@ -530,11 +547,31 @@ class _NeighborhoodAttentionS2(torch.autograd.Function): ...@@ -530,11 +547,31 @@ class _NeighborhoodAttentionS2(torch.autograd.Function):
r""" r"""
Backward pass for CPU neighborhood attention on S2. Backward pass for CPU neighborhood attention on S2.
Parameters: Parameters
grad_output: gradient of the output ----------
grad_output: torch.Tensor
Gradient of the output
Returns: Returns
gradients for all input tensors and parameters -------
dk: torch.Tensor
Gradient of the key tensor
dv: torch.Tensor
Gradient of the value tensor
dq: torch.Tensor
Gradient of the query tensor
dwk: torch.Tensor
Gradient of the key weight tensor
dwv: torch.Tensor
Gradient of the value weight tensor
dwq: torch.Tensor
Gradient of the query weight tensor
dbk: torch.Tensor or None
Gradient of the key bias tensor
dbv: torch.Tensor or None
Gradient of the value bias tensor
dbq: torch.Tensor or None
Gradient of the query bias tensor
""" """
col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq = ctx.saved_tensors col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq = ctx.saved_tensors
nh = ctx.nh nh = ctx.nh
...@@ -683,24 +720,42 @@ class _NeighborhoodAttentionS2Cuda(torch.autograd.Function): ...@@ -683,24 +720,42 @@ class _NeighborhoodAttentionS2Cuda(torch.autograd.Function):
r""" r"""
Forward pass for CUDA neighborhood attention on S2. Forward pass for CUDA neighborhood attention on S2.
Parameters: Parameters
k: key tensor -----------
v: value tensor k: torch.Tensor
q: query tensor Key tensor
wk: key weight tensor v: torch.Tensor
wv: value weight tensor Value tensor
wq: query weight tensor q: torch.Tensor
bk: key bias tensor (optional) Query tensor
bv: value bias tensor (optional) wk: torch.Tensor
bq: query bias tensor (optional) Key weight tensor
quad_weights: quadrature weights for spherical integration wv: torch.Tensor
col_idx: column indices for sparse computation Value weight tensor
row_off: row offsets for sparse computation wq: torch.Tensor
max_psi_nnz: maximum number of non-zero elements in sparse tensor Query weight tensor
nh: number of attention heads bk: torch.Tensor or None
nlon_in: number of input longitude points Key bias tensor (optional)
nlat_out: number of output latitude points bv: torch.Tensor or None
nlon_out: number of output longitude points Value bias tensor (optional)
bq: torch.Tensor or None
Query bias tensor (optional)
quad_weights: torch.Tensor
Quadrature weights for spherical integration
col_idx: torch.Tensor
Column indices for sparse computation
row_off: torch.Tensor
Row offsets for sparse computation
max_psi_nnz: int
Maximum number of non-zero elements in sparse tensor
nh: int
Number of attention heads
nlon_in: int
Number of input longitude points
nlat_out: int
Number of output latitude points
nlon_out: int
Number of output longitude points
""" """
ctx.save_for_backward(col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq) ctx.save_for_backward(col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq)
ctx.nh = nh ctx.nh = nh
...@@ -745,11 +800,31 @@ class _NeighborhoodAttentionS2Cuda(torch.autograd.Function): ...@@ -745,11 +800,31 @@ class _NeighborhoodAttentionS2Cuda(torch.autograd.Function):
r""" r"""
Backward pass for CUDA neighborhood attention on S2. Backward pass for CUDA neighborhood attention on S2.
Parameters: Parameters
grad_output: gradient of the output ----------
grad_output: torch.Tensor
Gradient of the output
Returns: Returns
gradients for all input tensors and parameters -------
dk: torch.Tensor
Gradient of the key tensor
dv: torch.Tensor
Gradient of the value tensor
dq: torch.Tensor
Gradient of the query tensor
dwk: torch.Tensor
Gradient of the key weight tensor
dwv: torch.Tensor
Gradient of the value weight tensor
dwq: torch.Tensor
Gradient of the query weight tensor
dbk: torch.Tensor or None
Gradient of the key bias tensor
dbv: torch.Tensor or None
Gradient of the value bias tensor
dbq: torch.Tensor or None
Gradient of the query bias tensor
""" """
col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq = ctx.saved_tensors col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq = ctx.saved_tensors
nh = ctx.nh nh = ctx.nh
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment