OpenDAS / torch-harmonics / Commits / 328200ab
Commit 328200ab
Authored Jul 16, 2025 by Andrea Paris; committed Jul 21, 2025 by Boris Bonev
Parent: d70dee87

removed docstrings from backward passes
Showing 2 changed files, with 3 additions and 110 deletions:

torch_harmonics/_neighborhood_attention.py   +0 -58
torch_harmonics/distributed/primitives.py    +3 -52
torch_harmonics/_neighborhood_attention.py
@@ -502,35 +502,6 @@ class _NeighborhoodAttentionS2(torch.autograd.Function):
     @staticmethod
     @custom_bwd(device_type="cpu")
     def backward(ctx, grad_output):
-        r"""
-        Backward pass for CPU neighborhood attention on S2.
-
-        Parameters
-        -----------
-        grad_output: torch.Tensor
-            Gradient of the output
-
-        Returns
-        --------
-        dk: torch.Tensor
-            Gradient of the key tensor
-        dv: torch.Tensor
-            Gradient of the value tensor
-        dq: torch.Tensor
-            Gradient of the query tensor
-        dwk: torch.Tensor
-            Gradient of the key weight tensor
-        dwv: torch.Tensor
-            Gradient of the value weight tensor
-        dwq: torch.Tensor
-            Gradient of the query weight tensor
-        dbk: torch.Tensor or None
-            Gradient of the key bias tensor
-        dbv: torch.Tensor or None
-            Gradient of the value bias tensor
-        dbq: torch.Tensor or None
-            Gradient of the query bias tensor
-        """
         col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq = ctx.saved_tensors
         nh = ctx.nh
         nlon_in = ctx.nlon_in
...
@@ -712,35 +683,6 @@ class _NeighborhoodAttentionS2Cuda(torch.autograd.Function):
     @staticmethod
     @custom_bwd(device_type="cuda")
     def backward(ctx, grad_output):
-        r"""
-        Backward pass for CUDA neighborhood attention on S2.
-
-        Parameters
-        -----------
-        grad_output: torch.Tensor
-            Gradient of the output
-
-        Returns
-        --------
-        dk: torch.Tensor
-            Gradient of the key tensor
-        dv: torch.Tensor
-            Gradient of the value tensor
-        dq: torch.Tensor
-            Gradient of the query tensor
-        dwk: torch.Tensor
-            Gradient of the key weight tensor
-        dwv: torch.Tensor
-            Gradient of the value weight tensor
-        dwq: torch.Tensor
-            Gradient of the query weight tensor
-        dbk: torch.Tensor or None
-            Gradient of the key bias tensor
-        dbv: torch.Tensor or None
-            Gradient of the value bias tensor
-        dbq: torch.Tensor or None
-            Gradient of the query bias tensor
-        """
         col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq = ctx.saved_tensors
         nh = ctx.nh
         max_psi_nnz = ctx.max_psi_nnz
...
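Aside: both hunks follow the standard torch.autograd.Function layout, in which backward is a static method decorated with torch.amp.custom_bwd (the device_type form requires a recent PyTorch) and recovers its state from ctx. The sketch below is a minimal, hypothetical stand-in (_ScaledMatmul is not part of torch-harmonics) illustrating why deleting these docstrings is behavior-neutral: autograd only invokes the method, it never reads its documentation.

import torch
from torch.amp import custom_fwd, custom_bwd

class _ScaledMatmul(torch.autograd.Function):
    # Same structural pattern as _NeighborhoodAttentionS2: forward stashes
    # tensors on ctx, backward returns one gradient per forward input.

    @staticmethod
    @custom_fwd(device_type="cpu")
    def forward(ctx, x, w, scale):
        ctx.save_for_backward(x, w)
        ctx.scale = scale              # non-tensor state lives on ctx directly
        return scale * (x @ w)

    @staticmethod
    @custom_bwd(device_type="cpu")
    def backward(ctx, grad_output):
        # A docstring here is inert metadata; removing it (as this commit
        # does) cannot change autograd behavior.
        x, w = ctx.saved_tensors
        dx = ctx.scale * grad_output @ w.t()
        dw = ctx.scale * x.t() @ grad_output
        return dx, dw, None            # None for the non-tensor `scale` input

x = torch.randn(4, 3, requires_grad=True)
w = torch.randn(3, 2, requires_grad=True)
_ScaledMatmul.apply(x, w, 0.5).sum().backward()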
torch_harmonics/distributed/primitives.py
@@ -162,19 +162,6 @@ class distributed_transpose_azimuth(torch.autograd.Function):
     @staticmethod
     @custom_bwd(device_type="cuda")
     def backward(ctx, go):
-        r"""
-        Backward pass for distributed azimuthal transpose.
-
-        Parameters
-        ----------
-        go: torch.Tensor
-            The gradient of the output
-
-        Returns
-        -------
-        gi: torch.Tensor
-            The gradient of the input
-        """
         dims = ctx.dims
         dim0_split_sizes = ctx.dim0_split_sizes
         # WAR for a potential contig check torch bug for channels last contig tensors
...
@@ -200,19 +187,7 @@ class distributed_transpose_polar(torch.autograd.Function):
     @staticmethod
     @custom_bwd(device_type="cuda")
     def backward(ctx, go):
-        r"""
-        Backward pass for distributed polar transpose.
-
-        Parameters
-        ----------
-        go: torch.Tensor
-            The gradient of the output
-
-        Returns
-        -------
-        gi: torch.Tensor
-            The gradient of the input
-        """
         dim = ctx.dim
         dim0_split_sizes = ctx.dim0_split_sizes
         # WAR for a potential contig check torch bug for channels last contig tensors
...
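Both transpose primitives share one property worth keeping in mind when reading these hunks: the backward of a distributed transpose is the same all_to_all exchange run in the opposite direction, so the removed docstrings carried no information beyond the forward's. A rough sketch of that pattern, under the simplifying assumption of evenly divisible dimensions (the real classes carry dim0_split_sizes precisely because splits can be uneven; _DistributedTranspose here is hypothetical, not the library's implementation):

import torch
import torch.distributed as dist

class _DistributedTranspose(torch.autograd.Function):
    # Forward: re-split a tensor from dim0 to dim1 across the group via
    # all_to_all. Backward: the identical exchange with the dims swapped.

    @staticmethod
    def forward(ctx, x, dim0, dim1, group):
        ctx.dim0, ctx.dim1, ctx.group = dim0, dim1, group
        world = dist.get_world_size(group=group)
        send = [c.contiguous() for c in x.chunk(world, dim=dim1)]
        recv = [torch.empty_like(c) for c in send]
        dist.all_to_all(recv, send, group=group)
        return torch.cat(recv, dim=dim0)

    @staticmethod
    def backward(ctx, go):
        world = dist.get_world_size(group=ctx.group)
        send = [c.contiguous() for c in go.chunk(world, dim=ctx.dim0)]
        recv = [torch.empty_like(c) for c in send]
        dist.all_to_all(recv, send, group=ctx.group)
        # gradient only for the tensor input; dims and group get None
        return torch.cat(recv, dim=ctx.dim1), None, None, None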
@@ -337,19 +312,7 @@ class _CopyToPolarRegion(torch.autograd.Function):
     @staticmethod
     @custom_bwd(device_type="cuda")
     def backward(ctx, grad_output):
-        r"""
-        Backward pass for copying to polar region.
-
-        Parameters
-        ----------
-        grad_output: torch.Tensor
-            The gradient of the output
-
-        Returns
-        -------
-        grad_output: torch.Tensor
-            The gradient of the output
-        """
         if is_distributed_polar():
             return _reduce(grad_output, group=polar_group())
         else:
...
@@ -371,19 +334,7 @@ class _CopyToAzimuthRegion(torch.autograd.Function):
     @staticmethod
     @custom_bwd(device_type="cuda")
     def backward(ctx, grad_output):
-        r"""
-        Backward pass for copying to azimuth region.
-
-        Parameters
-        ----------
-        grad_output: torch.Tensor
-            The gradient of the output
-
-        Returns
-        -------
-        grad_output: torch.Tensor
-            The gradient of the output
-        """
         if is_distributed_azimuth():
             return _reduce(grad_output, group=azimuth_group())
         else:
...
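The two copy primitives are the familiar identity-forward / sum-backward pair from tensor parallelism: forward hands the same tensor to every rank in the group, so backward must accumulate the gradient contributions from all of them, which is what the retained _reduce(grad_output, group=...) calls above do. A minimal, hypothetical sketch of that pattern (the actual classes additionally check is_distributed_*() and delegate to a _reduce helper):

import torch
import torch.distributed as dist

class _CopyToRegion(torch.autograd.Function):
    # Forward is the identity: every rank of the group consumes the same
    # tensor. Backward therefore sums the per-rank gradients.

    @staticmethod
    def forward(ctx, x, group):
        ctx.group = group
        return x

    @staticmethod
    def backward(ctx, grad_output):
        grad = grad_output.clone()     # avoid mutating autograd's buffer
        dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=ctx.group)
        return grad, None              # no gradient for the process group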