OpenDAS / torch-harmonics / Commits / ffee67f9

Commit ffee67f9, authored Jun 30, 2025 by apaaris, committed by Boris Bonev on Jul 21, 2025

Improved docstrings in neighborhood attention

parent 913e80d4
Showing 1 changed file with 118 additions and 43 deletions:

torch_harmonics/_neighborhood_attention.py  (+118, -43)
@@ -475,23 +475,40 @@ class _NeighborhoodAttentionS2(torch.autograd.Function):
         r"""
         Forward pass for CPU neighborhood attention on S2.
-        Parameters:
-        k: key tensor
-        v: value tensor
-        q: query tensor
-        wk: key weight tensor
-        wv: value weight tensor
-        wq: query weight tensor
-        bk: key bias tensor (optional)
-        bv: value bias tensor (optional)
-        bq: query bias tensor (optional)
-        quad_weights: quadrature weights for spherical integration
-        col_idx: column indices for sparse computation
-        row_off: row offsets for sparse computation
-        nh: number of attention heads
-        nlon_in: number of input longitude points
-        nlat_out: number of output latitude points
-        nlon_out: number of output longitude points
+        Parameters
+        -----------
+        k: torch.Tensor
+            Key tensor
+        v: torch.Tensor
+            Value tensor
+        q: torch.Tensor
+            Query tensor
+        wk: torch.Tensor
+            Key weight tensor
+        wv: torch.Tensor
+            Value weight tensor
+        wq: torch.Tensor
+            Query weight tensor
+        bk: torch.Tensor or None
+            Key bias tensor (optional)
+        bv: torch.Tensor or None
+            Value bias tensor (optional)
+        bq: torch.Tensor or None
+            Query bias tensor (optional)
+        quad_weights: torch.Tensor
+            Quadrature weights for spherical integration
+        col_idx: torch.Tensor
+            Column indices for sparse computation
+        row_off: torch.Tensor
+            Row offsets for sparse computation
+        nh: int
+            Number of attention heads
+        nlon_in: int
+            Number of input longitude points
+        nlat_out: int
+            Number of output latitude points
+        nlon_out: int
+            Number of output longitude points
         """
         ctx.save_for_backward(col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq)
         ctx.nh = nh
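
The col_idx / row_off pair documented above describes each output point's neighborhood in CSR (compressed sparse row) form: row_off holds one offset per output point (plus one), and col_idx lists the input-grid indices of that point's neighbors. A minimal sketch of reading such a layout, assuming this interpretation (the helper gather_neighbors below is hypothetical, not code from this file):

    # Illustrative sketch of the CSR neighborhood layout assumed by the docstring above.
    # `gather_neighbors` is a hypothetical helper, not part of _neighborhood_attention.py.
    import torch

    def gather_neighbors(x: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, out_point: int) -> torch.Tensor:
        # Neighbors of output point `out_point` occupy the slice
        # col_idx[row_off[out_point] : row_off[out_point + 1]] of the input grid.
        start, end = int(row_off[out_point]), int(row_off[out_point + 1])
        return x[col_idx[start:end]]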
@@ -530,11 +547,31 @@ class _NeighborhoodAttentionS2(torch.autograd.Function):
         r"""
         Backward pass for CPU neighborhood attention on S2.
-        Parameters:
-        grad_output: gradient of the output
+        Parameters
+        -----------
+        grad_output: torch.Tensor
+            Gradient of the output

-        Returns:
-        gradients for all input tensors and parameters
+        Returns
+        --------
+        dk: torch.Tensor
+            Gradient of the key tensor
+        dv: torch.Tensor
+            Gradient of the value tensor
+        dq: torch.Tensor
+            Gradient of the query tensor
+        dwk: torch.Tensor
+            Gradient of the key weight tensor
+        dwv: torch.Tensor
+            Gradient of the value weight tensor
+        dwq: torch.Tensor
+            Gradient of the query weight tensor
+        dbk: torch.Tensor or None
+            Gradient of the key bias tensor
+        dbv: torch.Tensor or None
+            Gradient of the value bias tensor
+        dbq: torch.Tensor or None
+            Gradient of the query bias tensor
         """
         col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq = ctx.saved_tensors
         nh = ctx.nh
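
Because forward also receives index tensors and plain integers (quad_weights, col_idx, row_off, nh, nlon_in, nlat_out, nlon_out), the backward of a torch.autograd.Function has to return one entry per forward argument, with None in the positions that need no gradient. A minimal sketch of that contract, deliberately unrelated to the actual attention math:

    # Minimal stand-in showing only the autograd.Function calling convention;
    # the arithmetic here is NOT the neighborhood attention kernel.
    import torch

    class _ScaleByQuadWeights(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, quad_weights, nh):
            # save tensors needed in backward; plain ints go on ctx directly
            ctx.save_for_backward(quad_weights)
            ctx.nh = nh
            return x * quad_weights

        @staticmethod
        def backward(ctx, grad_output):
            (quad_weights,) = ctx.saved_tensors
            dx = grad_output * quad_weights
            # one return value per forward input: x, quad_weights, nh
            return dx, None, None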
@@ -683,24 +720,42 @@ class _NeighborhoodAttentionS2Cuda(torch.autograd.Function):
         r"""
         Forward pass for CUDA neighborhood attention on S2.
-        Parameters:
-        k: key tensor
-        v: value tensor
-        q: query tensor
-        wk: key weight tensor
-        wv: value weight tensor
-        wq: query weight tensor
-        bk: key bias tensor (optional)
-        bv: value bias tensor (optional)
-        bq: query bias tensor (optional)
-        quad_weights: quadrature weights for spherical integration
-        col_idx: column indices for sparse computation
-        row_off: row offsets for sparse computation
-        max_psi_nnz: maximum number of non-zero elements in sparse tensor
-        nh: number of attention heads
-        nlon_in: number of input longitude points
-        nlat_out: number of output latitude points
-        nlon_out: number of output longitude points
+        Parameters
+        -----------
+        k: torch.Tensor
+            Key tensor
+        v: torch.Tensor
+            Value tensor
+        q: torch.Tensor
+            Query tensor
+        wk: torch.Tensor
+            Key weight tensor
+        wv: torch.Tensor
+            Value weight tensor
+        wq: torch.Tensor
+            Query weight tensor
+        bk: torch.Tensor or None
+            Key bias tensor (optional)
+        bv: torch.Tensor or None
+            Value bias tensor (optional)
+        bq: torch.Tensor or None
+            Query bias tensor (optional)
+        quad_weights: torch.Tensor
+            Quadrature weights for spherical integration
+        col_idx: torch.Tensor
+            Column indices for sparse computation
+        row_off: torch.Tensor
+            Row offsets for sparse computation
+        max_psi_nnz: int
+            Maximum number of non-zero elements in sparse tensor
+        nh: int
+            Number of attention heads
+        nlon_in: int
+            Number of input longitude points
+        nlat_out: int
+            Number of output latitude points
+        nlon_out: int
+            Number of output longitude points
         """
         ctx.save_for_backward(col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq)
         ctx.nh = nh
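
The CUDA variant takes the same arguments as the CPU path plus max_psi_nnz. Custom autograd Functions are invoked through .apply rather than called directly, so a call site would presumably look roughly like the sketch below; the argument order follows the docstring, but the surrounding code and comments are assumptions for illustration, not taken from this file.

    # Hypothetical call site; argument order follows the docstring above,
    # everything else is an assumption for illustration only.
    out = _NeighborhoodAttentionS2Cuda.apply(
        k, v, q,                      # key / value / query tensors
        wk, wv, wq,                   # projection weights
        bk, bv, bq,                   # optional biases (may be None)
        quad_weights,                 # quadrature weights for spherical integration
        col_idx, row_off,             # CSR description of each output point's neighborhood
        max_psi_nnz,                  # upper bound on non-zeros in the sparse pattern
        nh,                           # number of attention heads
        nlon_in, nlat_out, nlon_out,  # grid sizes
    )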
@@ -745,11 +800,31 @@ class _NeighborhoodAttentionS2Cuda(torch.autograd.Function):
         r"""
         Backward pass for CUDA neighborhood attention on S2.
-        Parameters:
-        grad_output: gradient of the output
+        Parameters
+        -----------
+        grad_output: torch.Tensor
+            Gradient of the output

-        Returns:
-        gradients for all input tensors and parameters
+        Returns
+        --------
+        dk: torch.Tensor
+            Gradient of the key tensor
+        dv: torch.Tensor
+            Gradient of the value tensor
+        dq: torch.Tensor
+            Gradient of the query tensor
+        dwk: torch.Tensor
+            Gradient of the key weight tensor
+        dwv: torch.Tensor
+            Gradient of the value weight tensor
+        dwq: torch.Tensor
+            Gradient of the query weight tensor
+        dbk: torch.Tensor or None
+            Gradient of the key bias tensor
+        dbv: torch.Tensor or None
+            Gradient of the value bias tensor
+        dbq: torch.Tensor or None
+            Gradient of the query bias tensor
         """
         col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq = ctx.saved_tensors
         nh = ctx.nh