Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-harmonics
Commits
34927a33
Commit
34927a33
authored
Dec 15, 2024
by
Boris Bonev
Committed by
Boris Bonev
Jan 14, 2025
Browse files
adding better filter basis normalization mode
parent
7286a0d6
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
63 additions
and
1 deletion
+63
-1
torch_harmonics/filter_basis.py
torch_harmonics/filter_basis.py
+63
-1
No files found.
torch_harmonics/filter_basis.py
View file @
34927a33
...
@@ -41,6 +41,8 @@ def get_filter_basis(kernel_shape: Union[int, List[int], Tuple[int, int]], basis
...
@@ -41,6 +41,8 @@ def get_filter_basis(kernel_shape: Union[int, List[int], Tuple[int, int]], basis
if
basis_type
==
"piecewise linear"
:
if
basis_type
==
"piecewise linear"
:
return
PiecewiseLinearFilterBasis
(
kernel_shape
=
kernel_shape
)
return
PiecewiseLinearFilterBasis
(
kernel_shape
=
kernel_shape
)
elif
basis_type
==
"disk morlet"
:
return
DiskMorletFilterBasis
(
kernel_shape
=
kernel_shape
)
else
:
else
:
raise
ValueError
(
f
"Unknown basis_type
{
basis_type
}
"
)
raise
ValueError
(
f
"Unknown basis_type
{
basis_type
}
"
)
...
@@ -85,7 +87,7 @@ class PiecewiseLinearFilterBasis(AbstractFilterBasis):
...
@@ -85,7 +87,7 @@ class PiecewiseLinearFilterBasis(AbstractFilterBasis):
if
len
(
kernel_shape
)
==
1
:
if
len
(
kernel_shape
)
==
1
:
kernel_shape
=
[
kernel_shape
[
0
],
1
]
kernel_shape
=
[
kernel_shape
[
0
],
1
]
elif
len
(
kernel_shape
)
!=
2
:
elif
len
(
kernel_shape
)
!=
2
:
raise
ValueError
(
f
"expected kernel_shape to be a list or tuple of length 1 or 2 bu
u
got
{
kernel_shape
}
instead."
)
raise
ValueError
(
f
"expected kernel_shape to be a list or tuple of length 1 or 2 bu
t
got
{
kernel_shape
}
instead."
)
super
().
__init__
(
kernel_shape
=
kernel_shape
)
super
().
__init__
(
kernel_shape
=
kernel_shape
)
...
@@ -187,3 +189,63 @@ class PiecewiseLinearFilterBasis(AbstractFilterBasis):
...
@@ -187,3 +189,63 @@ class PiecewiseLinearFilterBasis(AbstractFilterBasis):
return
self
.
_compute_support_vals_anisotropic
(
r
,
phi
,
r_cutoff
=
r_cutoff
)
return
self
.
_compute_support_vals_anisotropic
(
r
,
phi
,
r_cutoff
=
r_cutoff
)
else
:
else
:
return
self
.
_compute_support_vals_isotropic
(
r
,
phi
,
r_cutoff
=
r_cutoff
)
return
self
.
_compute_support_vals_isotropic
(
r
,
phi
,
r_cutoff
=
r_cutoff
)
class
DiskMorletFilterBasis
(
AbstractFilterBasis
):
"""
Morlet-like Filter basis. A Gaussian is multiplied with a Fourier basis in x and y.
"""
def
__init__
(
self
,
kernel_shape
:
Union
[
int
,
List
[
int
],
Tuple
[
int
,
int
]],
):
if
isinstance
(
kernel_shape
,
int
):
kernel_shape
=
[
kernel_shape
,
kernel_shape
]
if
len
(
kernel_shape
)
!=
2
:
raise
ValueError
(
f
"expected kernel_shape to be a list or tuple of 2 but got
{
kernel_shape
}
instead."
)
super
().
__init__
(
kernel_shape
=
kernel_shape
)
@
property
def
kernel_size
(
self
):
return
self
.
kernel_shape
[
0
]
*
self
.
kernel_shape
[
1
]
def
_gaussian_envelope
(
self
,
r
:
torch
.
Tensor
,
width
:
float
=
1.0
):
return
1
/
(
2
*
math
.
pi
*
width
**
2
)
*
torch
.
exp
(
-
0.5
*
r
**
2
/
(
width
**
2
))
def
compute_support_vals
(
self
,
r
:
torch
.
Tensor
,
phi
:
torch
.
Tensor
,
r_cutoff
:
float
):
"""
Computes the index set that falls into the isotropic kernel's support and returns both indices and values.
"""
# enumerator for basis function
ikernel
=
torch
.
arange
(
self
.
kernel_size
).
reshape
(
-
1
,
1
,
1
)
nkernel
=
ikernel
%
self
.
kernel_shape
[
1
]
mkernel
=
ikernel
//
self
.
kernel_shape
[
1
]
# get relevant indices
iidx
=
torch
.
argwhere
((
r
<=
r_cutoff
)
&
torch
.
full_like
(
ikernel
,
True
,
dtype
=
torch
.
bool
))
# # computes the Gaussian envelope. To ensure that the curve is roughly 0 at the boundary, we rescale the Gaussian by 0.25
width
=
0.01
# width = 1.0
# envelope = self._gaussian_envelope(r, width=0.25 * r_cutoff)
# get x and y
x
=
r
*
torch
.
sin
(
phi
)
/
r_cutoff
y
=
r
*
torch
.
cos
(
phi
)
/
r_cutoff
harmonic
=
torch
.
where
(
nkernel
%
2
==
1
,
torch
.
sin
(
torch
.
ceil
(
nkernel
/
2
)
*
math
.
pi
*
x
/
width
),
torch
.
cos
(
torch
.
ceil
(
nkernel
/
2
)
*
math
.
pi
*
x
/
width
))
harmonic
*=
torch
.
where
(
mkernel
%
2
==
1
,
torch
.
sin
(
torch
.
ceil
(
mkernel
/
2
)
*
math
.
pi
*
y
/
width
),
torch
.
cos
(
torch
.
ceil
(
mkernel
/
2
)
*
math
.
pi
*
y
/
width
))
# disk area
# disk_area = 2.0 * math.pi * (1.0 - math.cos(r_cutoff))
disk_area
=
1
# computes the Gaussian envelope. To ensure that the curve is roughly 0 at the boundary, we rescale the Gaussian by 0.25
vals
=
self
.
_gaussian_envelope
(
r
[
iidx
[:,
1
],
iidx
[:,
2
]]
/
r_cutoff
,
width
=
width
)
*
harmonic
[
iidx
[:,
0
],
iidx
[:,
1
],
iidx
[:,
2
]]
/
disk_area
vals
=
torch
.
ones_like
(
vals
)
return
iidx
,
vals
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment