Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-harmonics
Commits
61dd6cf1
Commit
61dd6cf1
authored
Jul 21, 2025
by
Boris Bonev
Browse files
removing get_psi which got added from merge
parent
ba7a4996
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
10 additions
and
39 deletions
+10
-39
torch_harmonics/convolution.py
torch_harmonics/convolution.py
+10
-39
No files found.
torch_harmonics/convolution.py
View file @
61dd6cf1
...
@@ -61,7 +61,7 @@ def _normalize_convolution_tensor_s2(
...
@@ -61,7 +61,7 @@ def _normalize_convolution_tensor_s2(
psi_idx
,
psi_vals
,
in_shape
,
out_shape
,
kernel_size
,
quad_weights
,
transpose_normalization
=
False
,
basis_norm_mode
=
"mean"
,
merge_quadrature
=
False
,
eps
=
1e-9
psi_idx
,
psi_vals
,
in_shape
,
out_shape
,
kernel_size
,
quad_weights
,
transpose_normalization
=
False
,
basis_norm_mode
=
"mean"
,
merge_quadrature
=
False
,
eps
=
1e-9
):
):
"""Normalizes convolution tensor values based on specified normalization mode.
"""Normalizes convolution tensor values based on specified normalization mode.
This function applies different normalization strategies to the convolution tensor
This function applies different normalization strategies to the convolution tensor
values based on the basis_norm_mode parameter. It can normalize individual basis
values based on the basis_norm_mode parameter. It can normalize individual basis
functions, compute mean normalization across all basis functions, or use support
functions, compute mean normalization across all basis functions, or use support
...
@@ -143,7 +143,6 @@ def _normalize_convolution_tensor_s2(
...
@@ -143,7 +143,6 @@ def _normalize_convolution_tensor_s2(
# compute the support
# compute the support
support
[
ik
,
ilat
]
=
torch
.
sum
(
q
[
iidx
])
support
[
ik
,
ilat
]
=
torch
.
sum
(
q
[
iidx
])
# loop over values and renormalize
# loop over values and renormalize
for
ik
in
range
(
kernel_size
):
for
ik
in
range
(
kernel_size
):
for
ilat
in
range
(
nlat_out
):
for
ilat
in
range
(
nlat_out
):
...
@@ -166,7 +165,6 @@ def _normalize_convolution_tensor_s2(
...
@@ -166,7 +165,6 @@ def _normalize_convolution_tensor_s2(
if
merge_quadrature
:
if
merge_quadrature
:
psi_vals
[
iidx
]
=
psi_vals
[
iidx
]
*
q
[
iidx
]
psi_vals
[
iidx
]
=
psi_vals
[
iidx
]
*
q
[
iidx
]
if
transpose_normalization
and
merge_quadrature
:
if
transpose_normalization
and
merge_quadrature
:
psi_vals
=
psi_vals
/
correction_factor
psi_vals
=
psi_vals
/
correction_factor
...
@@ -178,13 +176,13 @@ def _precompute_convolution_tensor_s2(
...
@@ -178,13 +176,13 @@ def _precompute_convolution_tensor_s2(
in_shape
:
Tuple
[
int
],
in_shape
:
Tuple
[
int
],
out_shape
:
Tuple
[
int
],
out_shape
:
Tuple
[
int
],
filter_basis
:
FilterBasis
,
filter_basis
:
FilterBasis
,
grid_in
:
Optional
[
str
]
=
"equiangular"
,
grid_in
:
Optional
[
str
]
=
"equiangular"
,
grid_out
:
Optional
[
str
]
=
"equiangular"
,
grid_out
:
Optional
[
str
]
=
"equiangular"
,
theta_cutoff
:
Optional
[
float
]
=
0.01
*
math
.
pi
,
theta_cutoff
:
Optional
[
float
]
=
0.01
*
math
.
pi
,
theta_eps
:
Optional
[
float
]
=
1e-3
,
theta_eps
:
Optional
[
float
]
=
1e-3
,
transpose_normalization
:
Optional
[
bool
]
=
False
,
transpose_normalization
:
Optional
[
bool
]
=
False
,
basis_norm_mode
:
Optional
[
str
]
=
"mean"
,
basis_norm_mode
:
Optional
[
str
]
=
"mean"
,
merge_quadrature
:
Optional
[
bool
]
=
False
,
merge_quadrature
:
Optional
[
bool
]
=
False
,
):
):
"""
"""
Precomputes the rotated filters at positions $R^{-1}_j \omega_i = R^{-1}_j R_i
\n
u = Y(-
\t
heta_j)Z(\phi_i - \phi_j)Y(
\t
heta_j)
\n
u$.
Precomputes the rotated filters at positions $R^{-1}_j \omega_i = R^{-1}_j R_i
\n
u = Y(-
\t
heta_j)Z(\phi_i - \phi_j)Y(
\t
heta_j)
\n
u$.
...
@@ -515,18 +513,6 @@ class DiscreteContinuousConvS2(DiscreteContinuousConv):
...
@@ -515,18 +513,6 @@ class DiscreteContinuousConvS2(DiscreteContinuousConv):
"""
"""
return
torch
.
stack
([
self
.
psi_ker_idx
,
self
.
psi_row_idx
,
self
.
psi_col_idx
],
dim
=
0
).
contiguous
()
return
torch
.
stack
([
self
.
psi_ker_idx
,
self
.
psi_row_idx
,
self
.
psi_col_idx
],
dim
=
0
).
contiguous
()
def
get_psi
(
self
):
"""
Get the convolution tensor
Returns
-------
psi: torch.Tensor
Convolution tensor
"""
psi
=
torch
.
sparse_coo_tensor
(
self
.
psi_idx
,
self
.
psi_vals
,
size
=
(
self
.
kernel_size
,
self
.
nlat_out
,
self
.
nlat_in
*
self
.
nlon_in
)).
coalesce
()
return
psi
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
if
x
.
is_cuda
and
_cuda_extension_available
:
if
x
.
is_cuda
and
_cuda_extension_available
:
...
@@ -582,7 +568,7 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
...
@@ -582,7 +568,7 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
Whether to use bias
Whether to use bias
theta_cutoff: Optional[float]
theta_cutoff: Optional[float]
Theta cutoff for the filter basis functions
Theta cutoff for the filter basis functions
Returns
Returns
--------
--------
out: torch.Tensor
out: torch.Tensor
...
@@ -663,23 +649,8 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
...
@@ -663,23 +649,8 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
def
psi_idx
(
self
):
def
psi_idx
(
self
):
return
torch
.
stack
([
self
.
psi_ker_idx
,
self
.
psi_row_idx
,
self
.
psi_col_idx
],
dim
=
0
).
contiguous
()
return
torch
.
stack
([
self
.
psi_ker_idx
,
self
.
psi_row_idx
,
self
.
psi_col_idx
],
dim
=
0
).
contiguous
()
def
get_psi
(
self
,
semi_transposed
:
bool
=
False
):
if
semi_transposed
:
# we do a semi-transposition to faciliate the computation
tout
=
self
.
psi_idx
[
2
]
//
self
.
nlon_out
pout
=
self
.
psi_idx
[
2
]
%
self
.
nlon_out
# flip the axis of longitudes
pout
=
self
.
nlon_out
-
1
-
pout
tin
=
self
.
psi_idx
[
1
]
idx
=
torch
.
stack
([
self
.
psi_idx
[
0
],
tout
,
tin
*
self
.
nlon_out
+
pout
],
dim
=
0
)
psi
=
torch
.
sparse_coo_tensor
(
idx
,
self
.
psi_vals
,
size
=
(
self
.
kernel_size
,
self
.
nlat_out
,
self
.
nlat_in
*
self
.
nlon_out
)).
coalesce
()
else
:
psi
=
torch
.
sparse_coo_tensor
(
self
.
psi_idx
,
self
.
psi_vals
,
size
=
(
self
.
kernel_size
,
self
.
nlat_in
,
self
.
nlat_out
*
self
.
nlon_out
)).
coalesce
()
return
psi
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
# extract shape
# extract shape
B
,
C
,
H
,
W
=
x
.
shape
B
,
C
,
H
,
W
=
x
.
shape
x
=
x
.
reshape
(
B
,
self
.
groups
,
self
.
groupsize
,
H
,
W
)
x
=
x
.
reshape
(
B
,
self
.
groups
,
self
.
groupsize
,
H
,
W
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment