Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
Bw-bestperf
KamNet_pytorch
Commits
b5881ee2
Commit
b5881ee2
authored
Feb 04, 2026
by
maming
Browse files
Initial commit
parents
Changes
81
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
3565 additions
and
0 deletions
+3565
-0
lie_learn/lie_learn/representations/SO3/test_SO3_irrep_bases.py
...arn/lie_learn/representations/SO3/test_SO3_irrep_bases.py
+166
-0
lie_learn/lie_learn/representations/SO3/test_spherical_harmonics.py
...lie_learn/representations/SO3/test_spherical_harmonics.py
+40
-0
lie_learn/lie_learn/representations/SO3/test_wigner_d.py
lie_learn/lie_learn/representations/SO3/test_wigner_d.py
+373
-0
lie_learn/lie_learn/representations/SO3/wigner_d.py
lie_learn/lie_learn/representations/SO3/wigner_d.py
+291
-0
lie_learn/lie_learn/representations/__init__.py
lie_learn/lie_learn/representations/__init__.py
+0
-0
lie_learn/lie_learn/spaces/S2.py
lie_learn/lie_learn/spaces/S2.py
+390
-0
lie_learn/lie_learn/spaces/S3.py
lie_learn/lie_learn/spaces/S3.py
+202
-0
lie_learn/lie_learn/spaces/Tn.py
lie_learn/lie_learn/spaces/Tn.py
+17
-0
lie_learn/lie_learn/spaces/__init__.py
lie_learn/lie_learn/spaces/__init__.py
+1
-0
lie_learn/lie_learn/spaces/rn.py
lie_learn/lie_learn/spaces/rn.py
+106
-0
lie_learn/lie_learn/spaces/spherical_quadrature.pyx
lie_learn/lie_learn/spaces/spherical_quadrature.pyx
+72
-0
lie_learn/lie_learn/spectral/FFTBase.py
lie_learn/lie_learn/spectral/FFTBase.py
+11
-0
lie_learn/lie_learn/spectral/PolarFFT.py
lie_learn/lie_learn/spectral/PolarFFT.py
+35
-0
lie_learn/lie_learn/spectral/S2FFT.py
lie_learn/lie_learn/spectral/S2FFT.py
+172
-0
lie_learn/lie_learn/spectral/S2FFT_NFFT.py
lie_learn/lie_learn/spectral/S2FFT_NFFT.py
+100
-0
lie_learn/lie_learn/spectral/S2_conv.py
lie_learn/lie_learn/spectral/S2_conv.py
+154
-0
lie_learn/lie_learn/spectral/SE2FFT.py
lie_learn/lie_learn/spectral/SE2FFT.py
+730
-0
lie_learn/lie_learn/spectral/SO3FFT_Naive.py
lie_learn/lie_learn/spectral/SO3FFT_Naive.py
+563
-0
lie_learn/lie_learn/spectral/SO3_conv.py
lie_learn/lie_learn/spectral/SO3_conv.py
+73
-0
lie_learn/lie_learn/spectral/T1FFT.py
lie_learn/lie_learn/spectral/T1FFT.py
+69
-0
No files found.
lie_learn/lie_learn/representations/SO3/test_SO3_irrep_bases.py
0 → 100755
View file @
b5881ee2
from
lie_learn.representations.SO3.irrep_bases
import
*
from
.spherical_harmonics
import
*
TEST_L_MAX
=
5
def test_change_of_basis_matrix():
    """
    Testing if change of basis matrix is consistent with spherical harmonics functions
    """
    for l in range(TEST_L_MAX):
        # Random evaluation point on the sphere: theta in [0, pi), phi in [0, 2pi)
        theta = np.random.rand() * np.pi
        phi = np.random.rand() * np.pi * 2

        for from_field in ['complex', 'real']:
            for from_normalization in ['seismology', 'quantum', 'geodesy', 'nfft']:
                for from_cs in ['cs', 'nocs']:
                    for to_field in ['complex', 'real']:
                        for to_normalization in ['seismology', 'quantum', 'geodesy', 'nfft']:
                            for to_cs in ['cs', 'nocs']:
                                # Spherical harmonics of degree l, all orders -l..l,
                                # in the source and target conventions
                                Y_from = sh(l, np.arange(-l, l + 1), theta, phi,
                                            from_field, from_normalization, from_cs == 'cs')
                                Y_to = sh(l, np.arange(-l, l + 1), theta, phi,
                                          to_field, to_normalization, to_cs == 'cs')

                                B = change_of_basis_matrix(
                                    l=l,
                                    frm=(from_field, from_normalization, 'centered', from_cs),
                                    to=(to_field, to_normalization, 'centered', to_cs))

                                print(from_field, from_normalization, from_cs, '->',
                                      to_field, to_normalization, to_cs,
                                      np.sum(np.abs(B.dot(Y_from) - Y_to)))

                                # B must map the source basis onto the target basis,
                                # and its inverse must map back.
                                assert np.isclose(np.sum(np.abs(B.dot(Y_from) - Y_to)), 0.0)
                                assert np.isclose(
                                    np.sum(np.abs(np.linalg.inv(B).dot(Y_to) - Y_from)), 0.0)
def test_change_of_basis_function():
    """
    Testing if change of basis function is consistent with spherical harmonics functions
    """
    for l in range(TEST_L_MAX):
        # Random evaluation point on the sphere
        theta = np.random.rand() * np.pi
        phi = np.random.rand() * np.pi * 2

        for from_field in ['complex', 'real']:
            for from_normalization in ['seismology', 'quantum', 'geodesy', 'nfft']:
                for from_cs in ['cs', 'nocs']:
                    for to_field in ['complex', 'real']:
                        for to_normalization in ['seismology', 'quantum', 'geodesy', 'nfft']:
                            for to_cs in ['cs', 'nocs']:
                                Y_from = sh(l, np.arange(-l, l + 1), theta, phi,
                                            from_field, from_normalization, from_cs == 'cs')
                                Y_to = sh(l, np.arange(-l, l + 1), theta, phi,
                                          to_field, to_normalization, to_cs == 'cs')

                                # Functional (matrix-free) change of basis
                                f = change_of_basis_function(
                                    l=l,
                                    frm=(from_field, from_normalization, 'centered', from_cs),
                                    to=(to_field, to_normalization, 'centered', to_cs))

                                print(from_field, from_normalization, from_cs, '->',
                                      to_field, to_normalization, to_cs,
                                      np.sum(np.abs(f(Y_from) - Y_to)))

                                assert np.isclose(np.sum(np.abs(f(Y_from) - Y_to)), 0.0)
def test_change_of_basis_function_lists():
    """
    Testing change of basis function for spherical harmonics for multiple orders at once.
    The change-of-basis function for spherical harmonics should be consistent with the CSH & RSH functions.
    """
    # Degrees 0..3; ls/ms enumerate every (l, m) pair with -l <= m <= l
    l = np.arange(4)
    ls = np.array([0, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3])
    ms = np.array([0, -1, 0, 1, -2, -1, 0, 1, 2, -3, -2, -1, 0, 1, 2, 3])

    theta = np.random.rand() * np.pi
    phi = np.random.rand() * np.pi * 2

    for from_field in ['complex', 'real']:
        for from_normalization in ['seismology', 'quantum', 'geodesy', 'nfft']:
            for from_cs in ['cs', 'nocs']:
                for to_field in ['complex', 'real']:
                    for to_normalization in ['seismology', 'quantum', 'geodesy', 'nfft']:
                        for to_cs in ['cs', 'nocs']:
                            # Vectorized evaluation over all (l, m) pairs at once
                            Y_from = sh(ls, ms, theta, phi,
                                        from_field, from_normalization, from_cs == 'cs')
                            Y_to = sh(ls, ms, theta, phi,
                                      to_field, to_normalization, to_cs == 'cs')

                            f = change_of_basis_function(
                                l=l,
                                frm=(from_field, from_normalization, 'centered', from_cs),
                                to=(to_field, to_normalization, 'centered', to_cs))

                            print(from_field, from_normalization, from_cs, '->',
                                  to_field, to_normalization, to_cs,
                                  np.sum(np.abs(f(Y_from) - Y_to)))

                            assert np.isclose(np.sum(np.abs(f(Y_from) - Y_to)), 0.0)
def test_invertibility():
    """
    Testing if change_of_basis_function for SO(3) is invertible
    """
    for l in range(TEST_L_MAX):
        # NOTE(review): theta/phi are unused below, but kept to preserve the
        # stream of np.random draws exactly as in the original.
        theta = np.random.rand() * np.pi
        phi = np.random.rand() * np.pi * 2

        for from_field in ['complex', 'real']:
            for from_normalization in ['seismology', 'quantum', 'geodesy', 'nfft']:
                for from_cs in ['cs', 'nocs']:
                    for from_order in ['centered', 'block']:
                        for to_field in ['complex', 'real']:
                            for to_normalization in ['seismology', 'quantum', 'geodesy', 'nfft']:
                                for to_cs in ['cs', 'nocs']:
                                    for to_order in ['centered', 'block']:
                                        # A truly complex function cannot be made real;
                                        if from_field == 'complex' and to_field == 'real':
                                            continue

                                        # Random degree-l coefficient vector in the source field
                                        if from_field == 'complex':
                                            Y = (np.random.randn(2 * l + 1)
                                                 + np.random.randn(2 * l + 1) * 1j)
                                        else:
                                            Y = np.random.randn(2 * l + 1)

                                        f = change_of_basis_function(
                                            l=l,
                                            frm=(from_field, from_normalization, from_order, from_cs),
                                            to=(to_field, to_normalization, to_order, to_cs))
                                        f_inv = change_of_basis_function(
                                            l=l,
                                            frm=(to_field, to_normalization, to_order, to_cs),
                                            to=(from_field, from_normalization, from_order, from_cs))

                                        print(from_field, from_normalization, from_cs, from_order,
                                              '->',
                                              to_field, to_normalization, to_cs, to_order,
                                              np.sum(np.abs(f_inv(f(Y)) - Y)))

                                        # Round trip must be the identity
                                        assert np.isclose(np.sum(np.abs(f_inv(f(Y)) - Y)), 0.)
                                        #assert np.isclose(np.sum(np.abs(f(f_inv(Y)) - Y)), 0.)
def test_linearity_change_of_basis():
    """
    Testing that SO3 change of basis is indeed linear:
    f(a * Y1 + b * Y2) == a * f(Y1) + b * f(Y2)
    """
    for l in range(TEST_L_MAX):
        # NOTE(review): theta/phi are unused below, but kept to preserve the
        # stream of np.random draws exactly as in the original.
        theta = np.random.rand() * np.pi
        phi = np.random.rand() * np.pi * 2

        for from_field in ['complex', 'real']:
            for from_normalization in ['seismology', 'quantum', 'geodesy', 'nfft']:
                for from_cs in ['cs', 'nocs']:
                    for from_order in ['centered', 'block']:
                        for to_field in ['complex', 'real']:
                            for to_normalization in ['seismology', 'quantum', 'geodesy', 'nfft']:
                                for to_cs in ['cs', 'nocs']:
                                    for to_order in ['centered', 'block']:
                                        # A truly complex function cannot be made real;
                                        if from_field == 'complex' and to_field == 'real':
                                            continue

                                        Y1 = np.random.randn(2 * l + 1)
                                        Y2 = np.random.randn(2 * l + 1)
                                        a = np.random.randn(1)
                                        b = np.random.randn(1)

                                        # BUGFIX: the original passed from_order in the
                                        # `to=` tuple, so the to_order loop was dead and
                                        # the printed to_order did not match the actual
                                        # target basis. Use to_order here.
                                        f = change_of_basis_function(
                                            l=l,
                                            frm=(from_field, from_normalization, from_order, from_cs),
                                            to=(to_field, to_normalization, to_order, to_cs))

                                        print(from_field, from_normalization, from_cs, from_order,
                                              '->',
                                              to_field, to_normalization, to_cs, to_order,
                                              np.sum(np.abs(a * f(Y1) + b * f(Y2)
                                                            - f(a * Y1 + b * Y2))))

                                        assert np.isclose(
                                            np.sum(np.abs(a * f(Y1) + b * f(Y2)
                                                          - f(a * Y1 + b * Y2))), 0.)
lie_learn/lie_learn/representations/SO3/test_spherical_harmonics.py
0 → 100755
View file @
b5881ee2
import
numpy
as
np
import
lie_learn.spaces.S2
as
S2
from
lie_learn.representations.SO3.spherical_harmonics
import
sh
,
sh_squared_norm
def check_orthogonality(L_max=3, grid_type='Gauss-Legendre',
                        field='real', normalization='quantum', condon_shortley=True):
    """
    Numerically check orthogonality of the spherical harmonics in the given
    convention, using both quadrature on a grid and adaptive integration.
    """
    theta, phi = S2.meshgrid(b=L_max + 1, grid_type=grid_type)
    w = S2.quadrature_weights(b=L_max + 1, grid_type=grid_type)

    for l in range(L_max):
        for m in range(-l, l + 1):
            for l2 in range(L_max):
                for m2 in range(-l2, l2 + 1):
                    Ylm = sh(l, m, theta, phi, field, normalization, condon_shortley)
                    Ylm2 = sh(l2, m2, theta, phi, field, normalization, condon_shortley)

                    # Inner product via quadrature on the sampled grid
                    dot_numerical = S2.integrate_quad(Ylm * Ylm2.conj(),
                                                      grid_type=grid_type,
                                                      normalize=False, w=w)

                    # Inner product via direct numerical integration
                    dot_numerical2 = S2.integrate(
                        lambda t, p: (sh(l, m, t, p, field, normalization, condon_shortley)
                                      * sh(l2, m2, t, p, field, normalization,
                                           condon_shortley).conj()),
                        normalize=False)

                    # <Y_lm, Y_l'm'> = |Y_lm|^2 delta(l,l') delta(m,m')
                    sqnorm_analytical = sh_squared_norm(l, normalization, normalized_haar=False)
                    dot_analytical = sqnorm_analytical * (l == l2 and m == m2)

                    print(l, m, l2, m2, field, normalization, condon_shortley,
                          dot_analytical, dot_numerical, dot_numerical2)
                    assert np.isclose(dot_numerical, dot_analytical)
                    assert np.isclose(dot_numerical2, dot_analytical)
def test_orthogonality():
    """Run the orthogonality check over every field / normalization / CS-phase combination."""
    L_max = 2
    grid_type = 'Gauss-Legendre'
    for field in ('real', 'complex'):
        for normalization in ('quantum', 'seismology', 'geodesy', 'nfft'):
            for condon_shortley in (True, False):
                check_orthogonality(L_max, grid_type, field, normalization, condon_shortley)
lie_learn/lie_learn/representations/SO3/test_wigner_d.py
0 → 100755
View file @
b5881ee2
import
numpy
as
np
from
lie_learn.representations.SO3.wigner_d
import
wigner_D_matrix
,
wigner_d_matrix
,
\
wigner_d_naive
,
wigner_d_naive_v2
,
wigner_d_naive_v3
,
wigner_d_function
,
wigner_D_function
import
lie_learn.spaces.S3
as
S3
TEST_L_MAX
=
3
def check_unitarity_wigner_D():
    """
    Check that the Wigner-D matrices are unitary.
    We test every normalization convention and a range of input angles.
    Note: only the quantum- or seismology-normalized Wigner-D matrices are unitary,
    so we do not check the geodesy and nfft normalized matrices.
    """
    # NOTE(review): the docstring says geodesy/nfft are skipped, but the loop
    # below does include them — confirm which is intended.
    for l in range(TEST_L_MAX):
        for field in ('real', 'complex'):
            for normalization in ('quantum', 'seismology', 'geodesy', 'nfft'):
                for order in ('centered', 'block'):
                    for condon_shortley in ('cs', 'nocs'):
                        for a in np.linspace(0, 2 * np.pi, 10):
                            for b in np.linspace(0, np.pi, 10):
                                for c in np.linspace(0, 2 * np.pi, 10):
                                    m = wigner_D_matrix(l, a, b, c, field,
                                                        normalization, order, condon_shortley)

                                    # Unitarity: m^H m == I and m m^H == I
                                    diff = np.abs(m.conj().T.dot(m)
                                                  - np.eye(m.shape[0])).sum()
                                    diff += np.abs(m.dot(m.conj().T)
                                                   - np.eye(m.shape[0])).sum()

                                    print(l, field, normalization, order, condon_shortley,
                                          a, b, c, diff)
                                    assert np.isclose(diff, 0.)
def check_normalization_wigner_D():
    """
    According to [1], the Wigner D functions satisfy:
    int_0^2pi da int_0^pi db sin(b) int_0^2pi |D^l_mn(a,b,c)|^2 = 8 pi^2 / (2l+1)
    The factor 8 pi^2 is removed if we integrate with respect to the normalized Haar measure.
    Here we test this equality by numerical integration.
    NOTE: this test is subsumed in check_orthogonality_wigner_D, but that function is very slow
    """
    w = S3.quadrature_weights(b=TEST_L_MAX + 1, grid_type='SOFT')

    for l in range(TEST_L_MAX):
        for m in range(-l, l + 1):
            for n in range(-l, l + 1):
                for field in ('real', 'complex'):
                    for normalization in ('quantum', 'seismology', 'geodesy', 'nfft'):
                        for order in ('centered', 'block'):
                            for condon_shortley in ('cs', 'nocs'):
                                # |D^l_mn(a,b,c)|^2 as a function of the Euler angles
                                f = lambda a, b, c: np.abs(
                                    wigner_D_function(l=l, m=m, n=n,
                                                      alpha=a, beta=b, gamma=c,
                                                      field=field,
                                                      normalization=normalization,
                                                      order=order,
                                                      condon_shortley=condon_shortley)) ** 2

                                sqnorm_numerical = S3.integrate(f, normalize=True)

                                # Same integral, evaluated by quadrature on a SOFT grid
                                D = make_D_sample_grid(b=TEST_L_MAX + 1, l=l, m=m, n=n,
                                                       field=field,
                                                       normalization=normalization,
                                                       order=order,
                                                       condon_shortley=condon_shortley)
                                sqnorm_numerical2 = S3.integrate_quad(D * D.conj(),
                                                                      grid_type='SOFT',
                                                                      normalize=True, w=w)

                                sqnorm_analytical = 1. / (2 * l + 1)

                                print(l, m, n, field, normalization, order, condon_shortley,
                                      sqnorm_numerical, sqnorm_numerical2, sqnorm_analytical)
                                assert np.isclose(sqnorm_numerical, sqnorm_analytical)
                                assert np.isclose(sqnorm_numerical2, sqnorm_analytical)
def check_orthogonality_wigner_D():
    """
    According to [1], the Wigner D functions satisfy:
    int_0^2pi da int_0^pi db sin(b) int_0^2pi D^l_mn(a,b,c) D^l'_m'n'(a,b,c)*
    =
    8 pi^2 / (2l+1) delta(ll') delta(mm') delta(nn')
    The factor 8 pi^2 is removed if we integrate with respect to the normalized Haar measure.
    Here we test this equality by numerical integration.
    """
    w = S3.quadrature_weights(b=TEST_L_MAX + 1, grid_type='SOFT')

    for field in ('real', 'complex'):
        for normalization in ('quantum', 'seismology', 'geodesy', 'nfft'):
            for order in ('centered', 'block'):
                for condon_shortley in ('cs', 'nocs'):
                    for l in range(TEST_L_MAX):
                        for m in range(-l, l + 1):
                            for n in range(-l, l + 1):
                                for l2 in range(TEST_L_MAX):
                                    for m2 in range(-l2, l2 + 1):
                                        for n2 in range(-l2, l2 + 1):
                                            # Product D^l_mn * conj(D^l2_m2n2)
                                            f = lambda a, b, c: (
                                                wigner_D_function(
                                                    l=l, m=m, n=n,
                                                    alpha=a, beta=b, gamma=c,
                                                    field=field,
                                                    normalization=normalization,
                                                    order=order,
                                                    condon_shortley=condon_shortley)
                                                * wigner_D_function(
                                                    l=l2, m=m2, n=n2,
                                                    alpha=a, beta=b, gamma=c,
                                                    field=field,
                                                    normalization=normalization,
                                                    order=order,
                                                    condon_shortley=condon_shortley).conj())

                                            D1 = make_D_sample_grid(
                                                b=TEST_L_MAX + 1, l=l, m=m, n=n,
                                                field=field, normalization=normalization,
                                                order=order, condon_shortley=condon_shortley)
                                            D2 = make_D_sample_grid(
                                                b=TEST_L_MAX + 1, l=l2, m=m2, n=n2,
                                                field=field, normalization=normalization,
                                                order=order, condon_shortley=condon_shortley)

                                            numerical_norm2 = S3.integrate_quad(
                                                D1 * D2.conj(), grid_type='SOFT',
                                                normalize=True, w=w)
                                            numerical_norm = S3.integrate(f, normalize=True)

                                            # delta(l,l') delta(m,m') delta(n,n') / (2l+1)
                                            analytical_norm = ((l == l2) * (m == m2)
                                                               * (n == n2)) / (2 * l + 1)

                                            print(field, normalization, order, condon_shortley,
                                                  l, m, n, l2, m2, n2,
                                                  np.round(numerical_norm, 2),
                                                  np.round(numerical_norm2, 2),
                                                  np.round(analytical_norm, 2))
                                            assert np.isclose(numerical_norm, analytical_norm)
                                            assert np.isclose(numerical_norm2, analytical_norm)
def check_normalization_complex_wigner_d():
    """
    According to [1], the following is true (eq. 12)
    int_0^pi |d^l_mn(beta)|^2 sin(beta) dbeta = 1 / (2 l + 1)
    NOTE: this function only tests the Wigner-d functions in the *complex basis*.
    In this basis, the Wigner-d functions all have the same, simple norm: 2. / (2l + 1)
    In the real basis, some functions are identically 0 and for the rest the norm is hard to understand.
    We treat these in a separate function below.
    [1] SOFT: SO(3) Fourier Transforms
    Peter J. Kostelec and Daniel N. Rockmore
    """
    # The squared L2 norm
    # By squared L2 norm of f we mean |f|^2 = int_SO(3) |f(g)|^2 dg, where dg is the normalized Haar measure
    L2_norm = lambda l: 2. / (2 * l + 1)  # Note the factor 2..

    for l in range(TEST_L_MAX):
        for m in range(-l, l + 1):
            for n in range(-l, l + 1):
                for field in ('complex',):  # Only test complex d functions here
                    for normalization in ('quantum', 'seismology', 'geodesy', 'nfft'):
                        for condon_shortley in ('cs', 'nocs'):
                            for order in ('centered', 'block'):
                                # |d^l_mn(b)|^2 sin(b); entry (l+m, l+n) holds d^l_mn
                                f = lambda b: wigner_d_matrix(
                                    l=l, beta=b, field=field,
                                    normalization=normalization, order=order,
                                    condon_shortley=condon_shortley)[l + m, l + n] ** 2 \
                                    * np.sin(b)

                                # from scipy.integrate import quad
                                # res = quad(f, a=0, b=np.pi, full_output=1)
                                # val = res[0]
                                # if not np.isclose(val, L2_norm[normalization](l)):
                                #     print(res)
                                val = myquad(f, 0, np.pi)

                                print(l, m, n, field, normalization, order, condon_shortley,
                                      np.round(val, 2), np.round(L2_norm(l), 2))
                                assert np.isclose(val, L2_norm(l))
def check_orthogonality_complex_wigner_d():
    """
    According to [1], the following is true (eq. 12)
    int_0^pi d^l_mn(beta) d^l'_mn(beta) sin(beta) dbeta = 1 / (2 l + 1) delta(l,l')
    NOTE: this function only tests the Wigner-d functions in the *complex basis*.
    In this basis, the Wigner-d functions all have the same, simple norm: 2. / (2l + 1)
    In the real basis, some functions are identically 0 and for the rest the norm is hard to understand.
    We treat these in a separate function below.
    NOTE: we only test in centered, not the block basis. For some reason this equality fails in the block basis.
    I have not investigated the reason for this yet.
    [1] SOFT: SO(3) Fourier Transforms
    Peter J. Kostelec and Daniel N. Rockmore
    :return:
    """
    # The squared L2 norm for each of the normalizations
    # By squared L2 norm of f we mean |f|^2 = int_SO(3) |f(g)|^2 dg, where dg is the normalized Haar measure
    L2_norm = lambda l: 2. / (2 * l + 1)

    for field in ('complex',):
        for normalization in ('quantum', 'seismology', 'geodesy', 'nfft'):
            for condon_shortley in ('cs', 'nocs'):
                for order in ('centered',):  # 'block'):
                    for m in range(-TEST_L_MAX, TEST_L_MAX + 1):
                        for n in range(-TEST_L_MAX, TEST_L_MAX + 1):
                            # d^l_mn exists only for l >= max(|m|, |n|)
                            for l in range(np.maximum(np.abs(m), np.abs(n)), TEST_L_MAX):
                                for l2 in range(np.maximum(np.abs(m), np.abs(n)), TEST_L_MAX):
                                    f = lambda b: (
                                        wigner_d_function(
                                            l=l, m=m, n=n, beta=b, field=field,
                                            normalization=normalization, order=order,
                                            condon_shortley=condon_shortley)
                                        * wigner_d_function(
                                            l=l2, m=m, n=n, beta=b, field=field,
                                            normalization=normalization, order=order,
                                            condon_shortley=condon_shortley)
                                        * np.sin(b))

                                    # from scipy.integrate import quad
                                    # res = quad(f, a=0, b=np.pi, full_output=1)
                                    # val = res[0]
                                    # if not np.isclose(val, L2_norm[normalization](l)):
                                    #     print(res)
                                    numerical_inner_product = myquad(f, 0, np.pi)
                                    analytical_inner_product = L2_norm(l) * (l == l2)

                                    print(l, l2, m, n, field, normalization, order,
                                          condon_shortley,
                                          np.round(numerical_inner_product, 2),
                                          np.round(analytical_inner_product, 2))
                                    assert np.isclose(numerical_inner_product,
                                                      analytical_inner_product,
                                                      rtol=1e-4, atol=1e-5)
def check_orthogonality_naive_wigner_d():
    """
    According to [1], the following is true (eq. 12)
    int_0^pi d^l_mn(beta) d^l'_mn(beta) sin(beta) dbeta = 1 / (2 l + 1) delta(l, l')
    Here we check this equality numerically for the *naive* implementations of the Wigner-d functions
    [1] SOFT: SO(3) Fourier Transforms
    Peter J. Kostelec and Daniel N. Rockmore
    :return:
    """
    # The squared L2 norm for each of the normalizations
    # By squared L2 norm of f we mean |f|^2 = int_SO(3) |f(g)|^2 dg, where dg is the normalized Haar measure
    L2_norm = lambda l: 2. / (2 * l + 1)

    for m in range(-TEST_L_MAX, TEST_L_MAX + 1):
        for n in range(-TEST_L_MAX, TEST_L_MAX + 1):
            for l in range(np.maximum(np.abs(m), np.abs(n)), TEST_L_MAX):
                for l2 in range(np.maximum(np.abs(m), np.abs(n)), TEST_L_MAX):
                    # One integrand per naive implementation; all three should agree
                    f1 = lambda b: (wigner_d_naive(l=l, m=m, n=n, beta=b)
                                    * wigner_d_naive(l=l2, m=m, n=n, beta=b)
                                    * np.sin(b))
                    f2 = lambda b: (wigner_d_naive_v2(l=l, m=m, n=n, beta=b)
                                    * wigner_d_naive_v2(l=l2, m=m, n=n, beta=b)
                                    * np.sin(b))
                    # v3 returns a callable of beta rather than a value
                    f3 = lambda b: (wigner_d_naive_v3(l=l, m=m, n=n)(b)
                                    * wigner_d_naive_v3(l=l2, m=m, n=n)(b)
                                    * np.sin(b))

                    for f in (f1, f2, f3):
                        # from scipy.integrate import quad
                        # res = quad(f, a=0, b=np.pi, full_output=1)
                        # val = res[0]
                        # if not np.isclose(val, L2_norm[normalization](l)):
                        #     print(res)
                        numerical_inner_product = myquad(f, 0, np.pi)
                        analytical_inner_product = L2_norm(l) * (l == l2)

                        print(l, l2, m, n,
                              np.round(numerical_inner_product, 2),
                              np.round(analytical_inner_product, 2))
                        assert np.isclose(numerical_inner_product,
                                          analytical_inner_product,
                                          rtol=1e-4, atol=1e-5)
def myquad(f, a, b, n=1000):
    """
    Crude numerical approximation of int_a^b f(x) dx by a left Riemann sum.

    :param f: callable of a single float argument
    :param a: lower integration bound
    :param b: upper integration bound
    :param n: number of equally spaced sample points
              (generalized from the previously hard-coded 1000; default preserves
              the original behavior)
    :return: approximate value of the integral
    """
    v = 0.
    # endpoint=False makes this a left Riemann sum with step (b - a) / n
    for x in np.linspace(a, b, num=n, endpoint=False):
        v += f(x)
    return v * (b - a) / n
# TODO: this test is failing - I'm not sure what the norms for real Wigner-d functions should be (see comments below)
def check_normalization_wigner_d_real(L_max=TEST_L_MAX):
    """
    According to [1], the following is true (eq. 12)
    int_0^pi d^l_mn(beta) d^l'_mn(beta) sin(beta) dbeta = 1 / (2 l + 1) delta(l, l')
    [1] SOFT: SO(3) Fourier Transforms
    Peter J. Kostelec and Daniel N. Rockmore
    :return:
    """
    # Note: this function is called "check" not "test" because this function is expensive to evaluate and we don't
    # want to automatically call this when running nosetests.

    # The squared L2 norm for each of the normalizations
    # By L2 norm of f we mean int_SO(3) |f(g)|^2 dg, where dg is the normalized Haar measure
    L2_norm = {'quantum': lambda l: 2. / (2 * l + 1),
               'seismology': lambda l: 2. / (2 * l + 1),
               'geodesy': lambda l: 2. / (2 * l + 1),
               'nfft': lambda l: 2. / (2 * l + 1)}

    # Per-degree (2l+1) x (2l+1) result matrices, indexed [l][l + m, l + n]
    correct = [np.zeros((2 * l + 1, 2 * l + 1)) for l in range(L_max)]
    ratio = [np.zeros((2 * l + 1, 2 * l + 1)) for l in range(L_max)]
    vals = [np.zeros((2 * l + 1, 2 * l + 1)) for l in range(L_max)]

    # Note: this seems to be correct for complex wigners in all normalizations, orders, cs, l, m, n,
    # For the real ones, we can understand which wigners are identically zero,
    # the norms for the non-zeros appear to be pretty complicated. Plotting the norms for l=9, we see a band pattern
    # similar to the appearance of the wigner-d matrix itself.
    # This matrix is symmetric, so the norm for dmn equals the norm for dnm
    # See note above in check_normalization_wigner_d_complex
    # The norm of the real wigner-d functions seems to be hard to understand. The norm now depends on m,n as well as l
    # We can understand which real wigners are identically zero (see real_zeros below, or plot a d-matrix).
    # Plotting the norms for l=9, we see a moire-like pattern for the non-zero wigners,
    # similar in appearance to the wigner-d matrix itself.
    # This matrix is symmetric, so the norm for dmn equals the norm for dnm
    for order in ('centered',):  # 'block'):
        for l in range(L_max):
            for m in range(-l, l + 1):
                for n in range(-l, l + 1):
                    for field in ('real',):  # only test real here
                        for normalization in ('quantum',):
                            # 'seismology', 'geodesy', 'nfft'): all normalization seem to give the same behaviour
                            for condon_shortley in ('cs',):  # 'nocs'): doesn't seem to matter
                                f = lambda b: wigner_d_matrix(
                                    l=l, beta=b, field=field,
                                    normalization=normalization, order=order,
                                    condon_shortley=condon_shortley)[l + m, l + n] ** 2 \
                                    * np.sin(b)

                                # from scipy.integrate import quad
                                # res = quad(f, a=0, b=np.pi, full_output=1)
                                # val = res[0]
                                # if not np.isclose(val, L2_norm[normalization](l)):
                                #     print(res)
                                val = myquad(f, 0, np.pi)

                                # Real Wigner-d entries with m, n on opposite
                                # sides of zero are identically zero
                                real_zeros = ((m < 0 and n >= 0)
                                              or (m >= 0 and n < 0)) and field == 'real'

                                print(l, m, n,
                                      # field, normalization, order, condon_shortley,
                                      np.round(val, 2),
                                      np.round(L2_norm[normalization](l) * (not real_zeros), 2))
                                # assert np.isclose(val, L2_norm[normalization](l))
                                # if not np.isclose(val, L2_norm[normalization](l)):
                                #     print("!!!!!")

                                correct[l][l + m, l + n] = np.isclose(
                                    val, L2_norm[normalization](l) * (not real_zeros))
                                ratio[l][l + m, l + n] = (
                                    val / (L2_norm[normalization](l) * (not real_zeros))
                                    if (not real_zeros) else 1)
                                vals[l][l + m, l + n] = val

    return correct, ratio, vals
def make_D_sample_grid(b=4, l=0, m=0, n=0, field='complex', normalization='seismology',
                       order='centered', condon_shortley='cs'):
    """
    Sample the Wigner-D function D^l_mn on a SOFT-style Euler-angle grid of bandwidth b.

    The grid has shape (2b, 2b, 2b) with
      alpha_j1 = 2 pi j1 / (2b),  beta_k = pi (2k + 1) / (4b),  gamma_j2 = 2 pi j2 / (2b).

    :param b: bandwidth; the grid has 2b samples per axis
    :param l, m, n: degree and orders of the Wigner-D function
    :param field: 'real' or 'complex'
    :param normalization: 'quantum', 'seismology', 'geodesy' or 'nfft'
    :param order: 'centered' or 'block'
    :param condon_shortley: 'cs' or 'nocs'
    :return: array of shape (2b, 2b, 2b) with D^l_mn evaluated at each grid point
    """
    from lie_learn.representations.SO3.wigner_d import wigner_D_function

    # BUGFIX: the original lambda `D = lambda a, b, c: wigner_D_function(l, m, n,
    # alpha, beta, gamma, ...)` ignored its own arguments and closed over the loop
    # variables alpha/beta/gamma instead — it only worked because the call site
    # happened to pass the same variables. Use the parameters explicitly.
    def D(alpha, beta, gamma):
        return wigner_D_function(l, m, n, alpha, beta, gamma,
                                 field=field, normalization=normalization,
                                 order=order, condon_shortley=condon_shortley)

    f = np.zeros((2 * b, 2 * b, 2 * b),
                 dtype=complex if field == 'complex' else float)
    for j1 in range(f.shape[0]):
        alpha = 2 * np.pi * j1 / (2. * b)
        for k in range(f.shape[1]):
            beta = np.pi * (2 * k + 1) / (4. * b)
            for j2 in range(f.shape[2]):
                gamma = 2 * np.pi * j2 / (2. * b)
                f[j1, k, j2] = D(alpha, beta, gamma)
    return f
lie_learn/lie_learn/representations/SO3/wigner_d.py
0 → 100755
View file @
b5881ee2
import
numpy
as
np
from
lie_learn.representations.SO3.pinchon_hoggan.pinchon_hoggan_dense
import
Jd
,
rot_mat
from
lie_learn.representations.SO3.irrep_bases
import
change_of_basis_matrix
def wigner_d_matrix(l, beta, field='real', normalization='quantum',
                    order='centered', condon_shortley='cs'):
    """
    Compute the Wigner-d matrix of degree l at beta, in the basis defined by
    (field, normalization, order, condon_shortley)
    The Wigner-d matrix of degree l has shape (2l + 1) x (2l + 1).
    :param l: the degree of the Wigner-d function. l >= 0
    :param beta: the argument. 0 <= beta <= pi
    :param field: 'real' or 'complex'
    :param normalization: 'quantum', 'seismology', 'geodesy' or 'nfft'
    :param order: 'centered' or 'block'
    :param condon_shortley: 'cs' or 'nocs'
    :return: d^l_mn(beta) in the chosen basis
    """
    # This returns the d matrix in the (real, quantum-normalized, centered, cs) convention
    d = rot_mat(alpha=0., beta=beta, gamma=0., l=l, J=Jd[l])

    default_convention = ('real', 'quantum', 'centered', 'cs')
    if (field, normalization, order, condon_shortley) != default_convention:
        # TODO use change of basis function instead of matrix?
        B = change_of_basis_matrix(l, frm=default_convention,
                                   to=(field, normalization, order, condon_shortley))
        BB = change_of_basis_matrix(l, frm=(field, normalization, order, condon_shortley),
                                    to=default_convention)
        d = B.dot(d).dot(BB)

        # The Wigner-d matrices are always real, even in the complex basis
        # (I tested this numerically, and have seen it in several texts)
        # assert np.isclose(np.sum(np.abs(d.imag)), 0.0)
        d = d.real

    return d
def wigner_D_matrix(l, alpha, beta, gamma, field='real', normalization='quantum',
                    order='centered', condon_shortley='cs'):
    """
    Evaluate the Wigner-D matrix D^l_mn(alpha, beta, gamma)
    :param l: the degree of the Wigner-d function. l >= 0
    :param alpha: the argument. 0 <= alpha <= 2 pi
    :param beta: the argument. 0 <= beta <= pi
    :param gamma: the argument. 0 <= gamma <= 2 pi
    :param field: 'real' or 'complex'
    :param normalization: 'quantum', 'seismology', 'geodesy' or 'nfft'
    :param order: 'centered' or 'block'
    :param condon_shortley: 'cs' or 'nocs'
    :return: D^l_mn(alpha, beta, gamma) in the chosen basis
    """
    # rot_mat yields D in the (real, quantum, centered, cs) convention
    D = rot_mat(alpha=alpha, beta=beta, gamma=gamma, l=l, J=Jd[l])

    default_convention = ('real', 'quantum', 'centered', 'cs')
    if (field, normalization, order, condon_shortley) != default_convention:
        B = change_of_basis_matrix(l, frm=default_convention,
                                   to=(field, normalization, order, condon_shortley))
        BB = change_of_basis_matrix(l, frm=(field, normalization, order, condon_shortley),
                                    to=default_convention)
        D = B.dot(D).dot(BB)

    if field == 'real':
        # In a real basis the matrix must have a (numerically) vanishing imaginary part
        # print('WIGNER D IMAG PART:', np.sum(np.abs(D.imag)))
        assert np.isclose(np.sum(np.abs(D.imag)), 0.0)
        D = D.real

    return D
def wigner_d_function(l, m, n, beta, field='real', normalization='quantum',
                      order='centered', condon_shortley='cs'):
    """
    Evaluate a single Wigner-d function d^l_mn(beta)
    NOTE: for now, we implement this by computing the entire degree-l Wigner-d matrix and then selecting
    the (m,n) element, so this function is not fast.
    :param l: the degree of the Wigner-d function. l >= 0
    :param m: the order of the Wigner-d function. -l <= m <= l
    :param n: the order of the Wigner-d function. -l <= n <= l
    :param beta: the argument. 0 <= beta <= pi
    :param field: 'real' or 'complex'
    :param normalization: 'quantum', 'seismology', 'geodesy' or 'nfft'
    :param order: 'centered' or 'block'
    :param condon_shortley: 'cs' or 'nocs'
    :return: d^l_mn(beta) in the chosen basis
    """
    # Row/column 0 of the matrix corresponds to order -l, hence the l + m / l + n shift
    full_matrix = wigner_d_matrix(l, beta, field, normalization, order, condon_shortley)
    return full_matrix[l + m, l + n]
def wigner_D_function(l, m, n, alpha, beta, gamma, field='real', normalization='quantum',
                      order='centered', condon_shortley='cs'):
    """
    Evaluate a single Wigner-D function D^l_mn(alpha, beta, gamma)
    NOTE: for now, we implement this by computing the entire degree-l Wigner-D matrix and then selecting
    the (m,n) element, so this function is not fast.
    :param l: the degree of the Wigner-d function. l >= 0
    :param m: the order of the Wigner-d function. -l <= m <= l
    :param n: the order of the Wigner-d function. -l <= n <= l
    :param alpha: the argument. 0 <= alpha <= 2 pi
    :param beta: the argument. 0 <= beta <= pi
    :param gamma: the argument. 0 <= gamma <= 2 pi
    :param field: 'real' or 'complex'
    :param normalization: 'quantum', 'seismology', 'geodesy' or 'nfft'
    :param order: 'centered' or 'block'
    :param condon_shortley: 'cs' or 'nocs'
    :return: D^l_mn(alpha, beta, gamma) in the chosen basis
    """
    # Row/column 0 of the matrix corresponds to order -l, hence the l + m / l + n shift
    full_matrix = wigner_D_matrix(l, alpha, beta, gamma,
                                  field, normalization, order, condon_shortley)
    return full_matrix[l + m, l + n]
def wigner_D_norm(l, normalized_haar=True):
    """
    Compute the squared norm of the Wigner-D functions.
    The squared norm of a function on the SO(3) is defined as
    |f|^2 = int_SO(3) |f(g)|^2 dg
    where dg is a Haar measure.
    :param l: for some normalization conventions, the norm of a Wigner-D function D^l_mn depends on the degree l
    :param normalized_haar: whether to use the Haar measure da db sinb dc or the normalized Haar measure
    da db sinb dc / 8pi^2
    :return: the squared norm of the spherical harmonic with respect to given measure
    """
    # Total measure of SO(3): 1 under the normalized Haar measure, 8 pi^2 otherwise
    total_measure = 1. if normalized_haar else 8 * np.pi ** 2
    return total_measure / (2 * l + 1)
def wigner_d_naive(l, m, n, beta):
    """
    Numerically naive implementation of the Wigner-d function.
    This is useful for checking the correctness of other implementations.

    Uses the closed-form expression in terms of Jacobi polynomials:
      d^l_mn(beta) = xi_mn * z * sin(beta/2)^mu * cos(beta/2)^nu * P_s^{(mu, nu)}(cos beta)
    with mu = |m - n|, nu = |m + n|, s = l - (mu + nu)/2.

    Fixes relative to the original:
      * the original imported scipy's eval_jacobi and immediately shadowed it with a
        (very slow) sympy substitution; scipy's C implementation is used directly now.
      * the scipy.misc.factorial fallback used a bare `except:`; scipy.misc was
        removed long ago, so math.factorial (exact integer arithmetic) is used.
      * s is computed with integer division (mu + nu is always even), so factorial
        receives an int.

    :param l: the degree of the Wigner-d function. l >= 0
    :param m: the order of the Wigner-d function. -l <= m <= l
    :param n: the order of the Wigner-d function. -l <= n <= l
    :param beta: the argument. 0 <= beta <= pi
    :return: d^l_mn(beta) in the TODO: what basis? complex, quantum(?), centered, cs(?)
    """
    from math import factorial
    from scipy.special import eval_jacobi

    mu = np.abs(m - n)
    nu = np.abs(m + n)
    # mu + nu = 2 * max(|m|, |n|), always even, so s is a non-negative integer
    s = l - (mu + nu) // 2
    xi = 1 if n >= m else (-1) ** (n - m)

    jac = eval_jacobi(s, mu, nu, np.cos(beta))
    z = np.sqrt((factorial(s) * factorial(s + mu + nu))
                / (factorial(s + mu) * factorial(s + nu)))

    assert (np.isfinite(mu) and np.isfinite(nu) and np.isfinite(s)
            and np.isfinite(xi) and np.isfinite(jac) and np.isfinite(z))
    d = xi * z * np.sin(beta / 2) ** mu * np.cos(beta / 2) ** nu * jac
    assert np.isfinite(d)
    return d
def wigner_d_naive_v2(l, m, n, beta):
    """
    Wigner d functions as defined in the SOFT 2.0 documentation.

    When approx_lim is set to a high value, this function appears to give
    identical results to Johann Goetz' wignerd() function.

    However, integration fails: does not satisfy orthogonality relations everywhere...

    :param l: the degree of the Wigner-d function. l >= 0 (integer)
    :param m: the order of the Wigner-d function. -l <= m <= l (integer)
    :param n: the order of the Wigner-d function. -l <= n <= l (integer)
    :param beta: the argument. 0 <= beta <= pi
    :return: d^l_mn(beta)
    """
    # np.math was a (deprecated) alias for the stdlib math module and was removed
    # in NumPy 2.0; use math.factorial directly.
    from math import factorial
    from scipy.special import jacobi

    # Sign factor: 1 if n >= m, else (-1)^(n - m).
    if n >= m:
        xi = 1
    else:
        xi = (-1) ** (n - m)

    mu = np.abs(m - n)
    nu = np.abs(n + m)
    # mu + nu = 2 * max(|m|, |n|) is always even for integer m, n, so integer
    # division is exact here and keeps s an integer, as required by factorial
    # (math.factorial rejects floats on Python >= 3.12).
    s = l - (mu + nu) // 2

    # Normalization: sqrt( s! (s+mu+nu)! / ( (s+mu)! (s+nu)! ) ).
    sq = np.sqrt((factorial(s) * factorial(s + mu + nu))
                 / (factorial(s + mu) * factorial(s + nu)))
    sinb = np.sin(beta * 0.5) ** mu
    cosb = np.cos(beta * 0.5) ** nu
    # Jacobi polynomial P_s^(mu,nu) evaluated at cos(beta).
    P = jacobi(s, mu, nu)(np.cos(beta))
    return xi * sq * sinb * cosb * P
def wigner_d_naive_v3(l, m, n, approx_lim=1000000):
    """
    Wigner "small d" matrix. (Euler z-y-z convention)

    example:
        l = 2
        m = 1
        n = 0
        beta = linspace(0,pi,100)
        wd210 = wignerd(l,m,n)(beta)

    some conditions have to be met:
        l >= 0
        -l <= m <= l
        -l <= n <= l

    The approx_lim determines at what point
    bessel functions are used. Default is when:
        l > m+10
          and
        l > n+10

    for integer l and n=0, we can use the spherical harmonics. If in
    addition m=0, we can use the ordinary legendre polynomials.

    :param l: degree
    :param m: order
    :param n: order
    :param approx_lim: threshold above which the Bessel approximation is used
    :return: a callable evaluating d^l_mn at (an array of) beta values
    :raises ValueError: if (l, m, n) violate the conditions above
    """
    from scipy.special import jv, legendre, sph_harm, jacobi
    # scipy.misc.factorial/comb were removed in scipy >= 1.3; fall back to
    # scipy.special. Catch ImportError only so unrelated errors are not hidden.
    try:
        from scipy.misc import factorial, comb
    except ImportError:
        from scipy.special import factorial, comb
    from numpy import floor, sqrt, sin, cos, exp, power
    from math import pi

    if (l < 0) or (abs(m) > l) or (abs(n) > l):
        raise ValueError("wignerd(l = {0}, m = {1}, n = {2}) value error.".format(l, m, n)
                         + " Valid range for parameters: l>=0, -l<=m,n<=l.")

    if (l > (m + approx_lim)) and (l > (n + approx_lim)):
        # Bessel-function approximation for very large degree l.
        #print 'bessel (approximation)'
        return lambda beta: jv(m - n, l * beta)

    if (floor(l) == l) and (n == 0):
        if m == 0:
            # Exact evaluation via the ordinary Legendre polynomial.
            #print 'legendre (exact)'
            return lambda beta: legendre(l)(cos(beta))
        elif False:
            # Disabled alternative: exact evaluation via spherical harmonics.
            #print 'spherical harmonics (exact)'
            a = sqrt(4. * pi / (2. * l + 1.))
            return lambda beta: a * sph_harm(m, l, beta, 0.).conj()

    # Choose the smallest of the four equivalent Jacobi-polynomial
    # parameterizations; duplicate keys (e.g. when n == -m) collapse in the dict,
    # which is harmless since min() only needs the smallest key.
    jmn_terms = {
        l + n: (m - n, m - n),
        l - n: (n - m, 0.),
        l + m: (n - m, 0.),
        l - m: (m - n, m - n),
    }

    k = min(jmn_terms)
    a, lmb = jmn_terms[k]

    b = 2. * l - 2. * k - a

    if (a < 0) or (b < 0):
        raise ValueError("wignerd(l = {0}, m = {1}, n = {2}) value error.".format(l, m, n)
                         + " Encountered negative values in (a,b) = ({0},{1})".format(a, b))

    coeff = power(-1., lmb) * sqrt(comb(2. * l - k, k + a)) * (1. / sqrt(comb(k + b, b)))

    # Exact evaluation via Jacobi polynomials.
    #print 'jacobi (exact)'
    return lambda beta: coeff \
        * power(sin(0.5 * beta), a) \
        * power(cos(0.5 * beta), b) \
        * jacobi(k, a, b)(cos(beta))
lie_learn/lie_learn/representations/__init__.py
0 → 100755
View file @
b5881ee2
lie_learn/lie_learn/spaces/S2.py
0 → 100755
View file @
b5881ee2
"""
The 2-sphere, S^2
"""
import
numpy
as
np
from
numpy.polynomial.legendre
import
leggauss
def change_coordinates(coords, p_from='C', p_to='S'):
    """
    Convert between Cartesian and spherical coordinates for points on S^2.

    Spherical coordinates are (beta, alpha) with beta in [0, pi] and alpha in
    [0, 2pi]. The names beta and alpha are used for compatibility with the
    SO(3) code (S^2 being the quotient SO(3)/SO(2)); many sources (e.g.
    wikipedia) call them theta and phi.

    :param coords: coordinate array; the last axis holds a point's coordinates
    :param p_from: 'C' for Cartesian or 'S' for spherical coordinates
    :param p_to: 'C' for Cartesian or 'S' for spherical coordinates
    :return: array of converted coordinates
    """
    if p_from == p_to:
        return coords

    if p_from == 'S' and p_to == 'C':
        beta, alpha = coords[..., 0], coords[..., 1]
        radius = 1.
        sin_b, cos_b = np.sin(beta), np.cos(beta)
        sin_a, cos_a = np.sin(alpha), np.cos(alpha)
        out = np.empty(beta.shape + (3,))
        out[..., 0] = radius * sin_b * cos_a  # x
        out[..., 1] = radius * sin_b * sin_a  # y
        out[..., 2] = radius * cos_b          # z
        return out

    if p_from == 'C' and p_to == 'S':
        x, y, z = coords[..., 0], coords[..., 1], coords[..., 2]
        out = np.empty(x.shape + (2,))
        out[..., 0] = np.arccos(z)      # beta (assumes unit-norm input)
        out[..., 1] = np.arctan2(y, x)  # alpha
        return out

    raise ValueError('Unknown conversion:' + str(p_from) + ' to ' + str(p_to))
def meshgrid(b, grid_type='Driscoll-Healy'):
    """
    Create a coordinate grid for the 2-sphere, as a pair of 2D arrays
    (beta, alpha) indexed 'ij'.

    The supported grid types (see linspace for the sample values):
      'Driscoll-Healy'  beta_j = pi j / (2b),          alpha_k = pi k / b           [4, 5]
      'SOFT'            beta_j = pi (2j + 1) / (4b),   alpha_k = pi k / b           [1, 6]
      'Clenshaw-Curtis' beta_j = j pi / (2b),          alpha_k = k pi / (b + 1)     [2]
      'Gauss-Legendre'  beta_j = Gauss-Legendre nodes, alpha_k = k pi / (b + 1)     [2, 7]
      'HEALPix'         TODO [2]
      'equidistribution' TODO [2]

    [1] SOFT: SO(3) Fourier Transforms — Kostelec & Rockmore
    [2] Fast evaluation of quadrature formulae on the sphere — Keiner & Potts
    [3] A Fast Algorithm for Spherical Grid Rotations and its Application to
        Singular Quadrature — Gimbutas & Veerapaneni
    [4] Computing Fourier transforms and convolutions on the 2-sphere —
        Driscoll & Healy
    [5] Engineering Applications of Noncommutative Harmonic Analysis —
        Chrikjian & Kyatkin
    [6] FFTs for the 2-Sphere - Improvements and Variations — Healy, Rockmore,
        Kostelec, Moore
    [7] A Fast Algorithm for Spherical Grid Rotations and its Application to
        Singular Quadrature — Gimbutas, Veerapaneni

    :param b: the bandwidth / resolution
    :param grid_type: one of the grid names above
    :return: a meshgrid on S^2
    """
    axes = linspace(b, grid_type)
    return np.meshgrid(*axes, indexing='ij')
def linspace(b, grid_type='Driscoll-Healy'):
    """
    Return 1D arrays (beta, alpha) of sampling angles on S^2 for the given
    bandwidth b. See meshgrid for the definition of each grid type.

    :param b: the bandwidth / resolution
    :param grid_type: 'Driscoll-Healy', 'SOFT', 'Clenshaw-Curtis',
        'Gauss-Legendre', 'HEALPix' or 'equidistribution'
    :return: (beta, alpha) arrays (HEALPix delegates to healpy and returns its
        meshgrid directly)
    """
    if grid_type == 'Driscoll-Healy':
        beta = np.arange(2 * b) * np.pi / (2. * b)
        alpha = np.arange(2 * b) * np.pi / b
        return beta, alpha

    if grid_type == 'SOFT':
        beta = np.pi * (2 * np.arange(2 * b) + 1) / (4. * b)
        alpha = np.arange(2 * b) * np.pi / b
        return beta, alpha

    if grid_type == 'Clenshaw-Curtis':
        # beta = np.arange(2 * b + 1) * np.pi / (2 * b)
        # alpha = np.arange(2 * b + 2) * np.pi / (b + 1)
        # Must use np.linspace to prevent numerical errors that cause beta > pi
        beta = np.linspace(0, np.pi, 2 * b + 1)
        alpha = np.linspace(0, 2 * np.pi, 2 * b + 2, endpoint=False)
        return beta, alpha

    if grid_type == 'Gauss-Legendre':
        # TODO: leggauss docs state that this may only be stable for orders > 100
        nodes, _ = leggauss(b + 1)
        beta = np.arccos(nodes)
        alpha = np.arange(2 * b + 2) * np.pi / (b + 1)
        return beta, alpha

    if grid_type == 'HEALPix':
        # TODO: implement this here so that we don't need the dependency on healpy / healpix_compat
        from healpix_compat import healpy_sphere_meshgrid
        return healpy_sphere_meshgrid(b)

    if grid_type == 'equidistribution':
        raise NotImplementedError('Not implemented yet; see Fast evaluation of quadrature formulae on the sphere.')

    raise ValueError('Unknown grid_type:' + grid_type)
def quadrature_weights(b, grid_type='Gauss-Legendre'):
    """
    Compute quadrature weights for a given grid type.
    The function S2.meshgrid generates the points that correspond to the
    weights generated by this function.

    if grid_type == 'Gauss-Legendre':
        The quadrature formula is exact for polynomials up to degree M less
        than or equal to 2b + 1, so that we can compute exact Fourier
        coefficients for f a polynomial of degree at most b.
    if grid_type == 'Clenshaw-Curtis':
        The quadrature formula is exact for polynomials up to degree M less
        than or equal to 2b, so that we can compute exact Fourier coefficients
        for f a polynomial of degree at most b.

    :param b: the grid resolution. See S2.meshgrid
    :param grid_type: 'Gauss-Legendre', 'Clenshaw-Curtis' or 'SOFT'
    :return: W, a 2D array of weights matching the meshgrid shape
    """
    if grid_type == 'Clenshaw-Curtis':
        # FFT-based computation of the per-beta weights; every alpha column
        # gets the same weight vector.
        # (see "Fast evaluation of quadrature formulae on the sphere")
        col = _clenshaw_curtis_weights(n=2 * b)
        return np.tile(col[:, None], (1, 2 * b + 2))

    if grid_type == 'Gauss-Legendre':
        # Formula from "A Fast Algorithm for Spherical Grid Rotations and its
        # Application to Singular Quadrature", eq. 10.
        _, leg_w = leggauss(b + 1)
        return np.outer(leg_w, np.full(2 * b + 2, 2 * np.pi / (2 * b + 2)))

    if grid_type == 'SOFT':
        print("WARNING: SOFT quadrature weights don't work yet")
        ks = np.arange(0, b)
        vals = []
        for j in range(2 * b):
            outer = (2. / b) * np.sin(np.pi * (2. * j + 1.) / (4. * b))
            inner = np.sum((1. / (2 * ks + 1)) * np.sin((2 * j + 1) * (2 * ks + 1) * np.pi / (4. * b)))
            vals.append(outer * inner)
        w = np.array(vals)
        return np.tile(w[:, None], (1, 2 * b))

    raise ValueError('Unknown grid_type:' + str(grid_type))
def integrate(f, normalize=True):
    """
    Integrate a function f : S^2 -> R over the sphere S^2 using the invariant
    integration measure
        mu((beta, alpha)) = sin(beta) dbeta dalpha
    i.e. this returns
        int_S^2 f(x) dmu(x) = int_0^2pi int_0^pi f(beta, alpha) sin(beta) dbeta dalpha

    :param f: a function of two scalar variables (beta, alpha) returning a scalar.
    :param normalize: if True, divide by 4 pi (the measure of S^2) so that the
        constant function 1 integrates to 1.
    :return: the integral of f over the 2-sphere
    """
    from scipy.integrate import quad

    def outer_integrand(alpha):
        # For fixed alpha, integrate over beta with the sin(beta) weight.
        return quad(lambda beta: f(beta, alpha) * np.sin(beta), a=0, b=np.pi)[0]

    integral = quad(outer_integrand, 0, 2 * np.pi)[0]
    return integral / (4 * np.pi) if normalize else integral
def integrate_quad(f, grid_type, normalize=True, w=None):
    """
    Integrate a function f : S^2 -> R, sampled on a grid of type grid_type,
    using quadrature weights w.

    :param f: an ndarray containing function values on a grid
    :param grid_type: the type of grid used to sample f
        ('Gauss-Legendre' or 'Clenshaw-Curtis')
    :param normalize: whether to use the normalized Haar measure or not
    :param w: the quadrature weights. If not given, they are computed.
    :return: the integral of f over S^2.
    """
    if grid_type not in ('Gauss-Legendre', 'Clenshaw-Curtis'):
        raise NotImplementedError

    # Infer the bandwidth from the alpha axis; holds for both supported grids.
    b = (f.shape[1] - 2) // 2

    if w is None:
        w = quadrature_weights(b, grid_type)

    integral = np.sum(f * w)
    return integral / (4 * np.pi) if normalize else integral
def plot_sphere_func(f, grid='Clenshaw-Curtis', beta=None, alpha=None, colormap='jet', fignum=0, normalize=True):
    """
    Plot a real function on the sphere using mayavi.

    :param f: array of function values sampled on the given grid
    :param grid: grid type used to sample f ('Driscoll-Healy',
        'Clenshaw-Curtis', 'SOFT' or 'Gauss-Legendre')
    :param beta: optional precomputed beta grid; computed from `grid` if None
    :param alpha: optional precomputed alpha grid; computed from `grid` if None
    :param colormap: mayavi colormap name
    :param fignum: mayavi figure number
    :param normalize: if True, rescale f to [0, 1] before plotting
    """
    # TODO: All grids except Clenshaw-Curtis have holes at the poles
    # TODO: update this function now that we changed the order of axes in f
    import matplotlib
    matplotlib.use('WxAgg')
    matplotlib.interactive(True)
    from mayavi import mlab

    if normalize:
        f = (f - np.min(f)) / (np.max(f) - np.min(f))

    if beta is None or alpha is None:
        # Bug fix: use integer (floor) division; under Python 3, `/` yields a
        # float b, which breaks downstream array-size computations in meshgrid.
        # (plot_sphere_func2 already used // here.)
        if grid == 'Driscoll-Healy':
            b = f.shape[0] // 2
        elif grid == 'Clenshaw-Curtis':
            b = (f.shape[0] - 2) // 2
        elif grid == 'SOFT':
            b = f.shape[0] // 2
        elif grid == 'Gauss-Legendre':
            b = (f.shape[0] - 2) // 2
        else:
            # Previously an unknown grid fell through with b unbound
            # (UnboundLocalError); fail with a clear message instead.
            raise ValueError('Unknown grid type: ' + str(grid))
        beta, alpha = meshgrid(b=b, grid_type=grid)

    # Close the seam in the alpha direction by repeating the first row.
    alpha = np.r_[alpha, alpha[0, :][None, :]]
    beta = np.r_[beta, beta[0, :][None, :]]
    f = np.r_[f, f[0, :][None, :]]

    x = np.sin(beta) * np.cos(alpha)
    y = np.sin(beta) * np.sin(alpha)
    z = np.cos(beta)

    mlab.figure(fignum, bgcolor=(1, 1, 1), fgcolor=(0, 0, 0), size=(600, 400))
    mlab.clf()
    mlab.mesh(x, y, z, scalars=f, colormap=colormap)

    #mlab.view(90, 70, 6.2, (-1.3, -2.9, 0.25))
    mlab.show()
def plot_sphere_func2(f, grid='Clenshaw-Curtis', beta=None, alpha=None, colormap='jet', fignum=0, normalize=True):
    """
    Plot a function on the sphere as a colored 3D surface using matplotlib.

    :param f: array of function values sampled on the given grid; if 2D it is
        mapped through the gray colormap to produce facecolors
    :param grid: grid type used to sample f ('Driscoll-Healy',
        'Clenshaw-Curtis', 'SOFT' or 'Gauss-Legendre')
    :param beta: optional precomputed beta grid; derived from `grid` if None
    :param alpha: optional precomputed alpha grid; derived from `grid` if None
    :param colormap: unused at present (facecolors come from cm.gray) — kept
        for interface parity with plot_sphere_func
    :param fignum: unused at present; kept for interface parity
    :param normalize: if True, rescale f to [0, 1] before plotting
    """
    # TODO: update this function now that we have changed the order of axes in f
    import matplotlib.pyplot as plt
    from matplotlib import cm, colors
    from mpl_toolkits.mplot3d import Axes3D
    import numpy as np
    from scipy.special import sph_harm

    if normalize:
        f = (f - np.min(f)) / (np.max(f) - np.min(f))

    # Recover the bandwidth b from the array shape for each grid convention.
    # NOTE(review): an unknown grid string leaves b unbound and fails later
    # with UnboundLocalError when beta/alpha are not supplied.
    if grid == 'Driscoll-Healy':
        b = f.shape[0] // 2
    elif grid == 'Clenshaw-Curtis':
        b = (f.shape[0] - 2) // 2
    elif grid == 'SOFT':
        b = f.shape[0] // 2
    elif grid == 'Gauss-Legendre':
        b = (f.shape[0] - 2) // 2

    if beta is None or alpha is None:
        beta, alpha = meshgrid(b=b, grid_type=grid)

    # Close the seam in the alpha direction by repeating the first row.
    alpha = np.r_[alpha, alpha[0, :][None, :]]
    beta = np.r_[beta, beta[0, :][None, :]]
    f = np.r_[f, f[0, :][None, :]]

    # Spherical-to-Cartesian for the unit sphere.
    x = np.sin(beta) * np.cos(alpha)
    y = np.sin(beta) * np.sin(alpha)
    z = np.cos(beta)

    # m, l = 2, 3
    # Calculate the spherical harmonic Y(l,m) and normalize to [0,1]
    # fcolors = sph_harm(m, l, beta, alpha).real
    # fmax, fmin = fcolors.max(), fcolors.min()
    # fcolors = (fcolors - fmin) / (fmax - fmin)

    print(x.shape, f.shape)

    if f.ndim == 2:
        # Map scalar values through the gray colormap to get RGBA facecolors.
        f = cm.gray(f)
        print('2')

    # Set the aspect ratio to 1 so our sphere looks spherical
    fig = plt.figure(figsize=plt.figaspect(1.))
    ax = fig.add_subplot(111, projection='3d')
    ax.plot_surface(x, y, z, rstride=1, cstride=1, facecolors=f)  # cm.gray(f))

    # Turn off the axis planes
    ax.set_axis_off()
    plt.show()
def _clenshaw_curtis_weights(n):
    """
    Computes the Clenshaw-Curtis quadrature weights using a fast FFT method.

    This is a 'brainless' port of MATLAB code found in:
    Fast Construction of the Fejer and Clenshaw-Curtis Quadrature Rules
    Jorg Waldvogel, 2005
    http://www.sam.math.ethz.ch/~joergw/Papers/fejer.pdf

    :param n: number of intervals; nodes are x_k = cos(k*pi/n), k = 0..n
    :return: array of n+1 Clenshaw-Curtis weights (first and last coincide),
        rescaled for the python/scipy fft convention
    """
    from scipy.fftpack import ifft, fft, fftshift
    # TODO python3 handles division differently from python2. Check how MATLAB interprets /, and if this code is still correct for python3

    # Original MATLAB:
    # function [wf1,wf2,wcc] = fejer(n)
    # Weights of the Fejer2, Clenshaw-Curtis and Fejer1 quadratures by DFTs
    # n>1. Nodes: x_k = cos(k*pi/n)

    # N = [1:2:n-1]'; l=length(N); m=n-l; K=[0:m-1]';
    N = np.arange(start=1, stop=n, step=2)[:, None]
    l = N.size
    m = n - l
    # K is only needed by the Fejer1 computation, which is commented out below;
    # kept here to mirror the MATLAB source.
    K = np.arange(start=0, stop=m)[:, None]

    # Fejer2 nodes: k=0,1,...,n; weights: wf2, wf2_n=wf2_0=0
    # v0 = [2./N./(N-2); 1/N(end); zeros(m,1)];
    v0 = np.vstack([2. / N / (N - 2), 1. / N[-1]] + [0] * m)

    # v2 = -v0(1:end-1) - v0(end:-1:2);
    # wf2 = ifft(v2);
    v2 = -v0[:-1] - v0[:0:-1]

    # Clenshaw-Curtis nodes: k=0,1,...,n; weights: wcc, wcc_n=wcc_0
    # g0 = -ones(n,1);
    g0 = -np.ones((n, 1))
    # g0(1 + l) = g0(1 + l) + n;
    g0[l] = g0[l] + n
    # g0(1+m) = g0(1 + m) + n;
    g0[m] = g0[m] + n
    # g = g0/(n^2-1+mod(n,2));
    g = g0 / (n ** 2 - 1 + n % 2)
    # wcc=ifft(v2 + g);
    wcc = ifft((v2 + g).flatten()).real
    # Append the first weight so that both endpoints (k=0 and k=n) are covered.
    wcc = np.hstack([wcc, wcc[0]])

    # Fejer1 nodes: k=1/2,3/2,...,n-1/2; vector of weights: wf1
    # v0=[2*exp(i*pi*K/n)./(1-4*K.^2); zeros(l+1,1)];
    # v1=v0(1:end-1)+conj(v0(end:-1:2)); wf1=ifft(v1);
    # don't need these

    return wcc * np.pi / (n / 2 + 1)  # adjust for different scaling of python vs MATLAB fft
lie_learn/lie_learn/spaces/S3.py
0 → 100755
View file @
b5881ee2
from
functools
import
lru_cache
import
numpy
as
np
import
lie_learn.spaces.S2
as
S2
def change_coordinates(coords, p_from='C', p_to='S'):
    """
    Change Spherical to Cartesian coordinates and vice versa, for points x in S^3.

    We use the following coordinate system:
    https://en.wikipedia.org/wiki/N-sphere#Spherical_coordinates
    Except that we use the order (alpha, beta, gamma), where beta ranges from
    0 to pi while alpha and gamma range from 0 to 2 pi.

    x0 = r * cos(alpha)
    x1 = r * sin(alpha) * cos(gamma)
    x2 = r * sin(alpha) * sin(gamma) * cos(beta)
    x3 = r * sin(alpha) * sin(gamma) * sin(beta)

    :param coords: coordinate array; the last axis holds a point's coordinates
    :param p_from: 'S' (spherical) or 'C' (Cartesian)
    :param p_to: 'S' (spherical) or 'C' (Cartesian)
    :return: array of converted coordinates
    """
    if p_from == p_to:
        return coords
    elif p_from == 'S' and p_to == 'C':
        alpha = coords[..., 0]
        beta = coords[..., 1]
        gamma = coords[..., 2]

        r = 1.

        out = np.empty(alpha.shape + (4,))

        ca = np.cos(alpha)
        cb = np.cos(beta)
        cc = np.cos(gamma)
        sa = np.sin(alpha)
        sb = np.sin(beta)
        sc = np.sin(gamma)

        out[..., 0] = r * ca
        out[..., 1] = r * sa * cc
        out[..., 2] = r * sa * sc * cb
        out[..., 3] = r * sa * sc * sb
        return out

    elif p_from == 'C' and p_to == 'S':
        # The inverse conversion is not implemented.
        raise NotImplementedError

        # NOTE(review): everything below the raise is unreachable, and looks
        # unfinished: w is unused, r is unused, and the gamma formula duplicates
        # the beta formula (both arctan2(y, x)).
        x = coords[..., 0]
        y = coords[..., 1]
        z = coords[..., 2]
        w = coords[..., 3]

        r = np.sqrt((coords ** 2).sum(axis=-1))

        out = np.empty(x.shape + (3,))
        out[..., 0] = np.arccos(z)  # alpha
        out[..., 1] = np.arctan2(y, x)  # beta
        out[..., 2] = np.arctan2(y, x)  # gamma
        return out
    else:
        raise ValueError('Unknown conversion:' + str(p_from) + ' to ' + str(p_to))
def linspace(b, grid_type='SOFT'):
    """
    Compute a linspace on the 3-sphere.

    Since S^3 is isomorphic to SO(3), we use the grid from:
    FFTs on the Rotation Group
    Peter J. Kostelec and Daniel N. Rockmore
    http://www.cs.dartmouth.edu/~geelong/soft/03-11-060.pdf

    :param b: the bandwidth / resolution
    :param grid_type: forwarded to S2.linspace for the (beta, alpha) part
    :return: 1D arrays (alpha, beta, gamma)
    """
    # alpha = 2 * np.pi * np.arange(2 * b) / (2. * b)
    # beta = np.pi * (2 * np.arange(2 * b) + 1) / (4. * b)
    # gamma = 2 * np.pi * np.arange(2 * b) / (2. * b)
    beta, alpha = S2.linspace(b, grid_type)
    # According to "Sampling sets and quadrature formulae on the rotation group",
    # appending an S^1 sampling grid to an S^2 sampling grid yields a sampling
    # grid for SO(3).
    circle = np.arange(2 * b)
    gamma = 2 * np.pi * circle / (2. * b)
    return alpha, beta, gamma
def meshgrid(b, grid_type='SOFT'):
    """
    Create a 3D coordinate grid (alpha, beta, gamma) on S^3 ~ SO(3), with
    'ij' indexing, from the 1D sample arrays produced by linspace.

    :param b: the bandwidth / resolution
    :param grid_type: grid convention, forwarded to linspace
    :return: list of three 3D coordinate arrays
    """
    axes = linspace(b, grid_type)
    return np.meshgrid(*axes, indexing='ij')
def integrate(f, normalize=True):
    """
    Integrate a function f : S^3 -> R over the 3-sphere S^3, using the
    invariant integration measure
        mu((alpha, beta, gamma)) = dalpha sin(beta) dbeta dgamma
    i.e. this returns
        int_S^3 f(x) dmu(x) =
            int_0^2pi int_0^pi int_0^2pi f(alpha, beta, gamma) dalpha sin(beta) dbeta dgamma

    :param f: a function of three scalar variables returning a scalar.
    :param normalize: with the measure dalpha sin(beta) dbeta dgamma the
        constant function 1 integrates to 8 pi^2; if True, divide by that
        constant (i.e. use the normalized Haar measure).
    :return: the integral of f over the 3-sphere
    """
    from scipy.integrate import quad

    def over_beta(alpha, gamma):
        # Innermost integral: over beta with the sin(beta) weight.
        return quad(lambda beta: f(alpha, beta, gamma) * np.sin(beta), a=0, b=np.pi)[0]

    def over_gamma(alpha):
        return quad(lambda gamma: over_beta(alpha, gamma), a=0, b=2 * np.pi)[0]

    integral = quad(over_gamma, 0, 2 * np.pi)[0]
    return integral / (8 * np.pi ** 2) if normalize else integral
def integrate_quad(f, grid_type, normalize=True, w=None):
    """
    Integrate a function f : SO(3) -> R, sampled on a grid of type grid_type,
    using quadrature weights w.

    :param f: an ndarray containing function values on a grid
        (shape (2b, 2b, 2b) for the 'SOFT' grid)
    :param grid_type: the type of grid used to sample f (only 'SOFT')
    :param normalize: whether to use the normalized Haar measure or not
    :param w: the quadrature weights (length 2b, applied along the beta axis).
        If not given, they are computed.
    :return: the integral of f over SO(3).
    """
    if grid_type != 'SOFT':
        raise NotImplementedError('Unsupported grid_type:', grid_type)

    b = f.shape[0] // 2
    if w is None:
        w = quadrature_weights(b, grid_type)

    # Weights depend only on beta (axis 1); broadcast over alpha and gamma.
    integral = np.sum(f * w[None, :, None])
    return integral if normalize else integral * 8 * np.pi ** 2
@lru_cache(maxsize=32)
def quadrature_weights(b, grid_type='SOFT'):
    """
    Compute quadrature weights for the grid used by Kostelec & Rockmore [1, 2].

    This grid is:
        alpha = 2 pi i / 2b
        beta = pi (2 j + 1) / 4b
        gamma = 2 pi k / 2b
    where 0 <= i, j, k < 2b are indices.
    This grid can be obtained from the function: S3.linspace or S3.meshgrid

    The quadrature weights for this grid are
        w_B(j) = 2/b * sin(pi(2j + 1) / 4b) *
                 sum_{k=0}^{b-1} 1 / (2 k + 1) sin((2j + 1)(2k + 1) pi / 4b)
    This is eq. 23 in [1] and eq. 2.15 in [2].

    [1] SOFT: SO(3) Fourier Transforms
        Peter J. Kostelec and Daniel N. Rockmore
    [2] FFTs on the Rotation Group
        Peter J. Kostelec, Daniel N. Rockmore

    :param b: bandwidth (grid has shape 2b * 2b * 2b)
    :return: w: an array of length 2b containing the quadrature weights
    """
    if grid_type != 'SOFT':
        raise NotImplementedError

    ks = np.arange(0, b)
    vals = []
    for j in range(2 * b):
        envelope = (2. / b) * np.sin(np.pi * (2. * j + 1.) / (4. * b))
        series = np.sum((1. / (2 * ks + 1)) * np.sin((2 * j + 1) * (2 * ks + 1) * np.pi / (4. * b)))
        vals.append(envelope * series)
    w = np.array(vals)

    # This is not in the SOFT documentation, but we found that it is necessary
    # to divide by this factor to get correct results.
    w /= 2. * ((2 * b) ** 2)

    # In the SOFT source, they talk about the following weights being used for
    # odd-order transforms. Do not understand this, and the weights used above
    # (defined in the SOFT papers) seem to work.
    # w = np.array([(2. / b) *
    #              (np.sum((1. / (2 * k + 1))
    #                      * np.sin((2 * j + 1) * (2 * k + 1)
    #                               * np.pi / (4. * b))))
    #               for j in range(2 * b)])
    return w
\ No newline at end of file
lie_learn/lie_learn/spaces/Tn.py
0 → 100755
View file @
b5881ee2
"""
The n-Torus
"""
import
numpy
as
np
def linspace(b, n=1, convention='regular'):
    """
    Return, for each of the n torus dimensions, an array of b regularly
    spaced angles in [0, 2pi).

    :param b: number of samples per dimension
    :param n: number of torus dimensions
    :param convention: sampling convention; only 'regular' is supported
    :return: a list of n arrays, each of shape (b,)
    :raises ValueError: for an unknown convention
    """
    if convention != 'regular':
        raise ValueError('Unknown convention:' + convention)
    # Each dimension gets its own (freshly allocated) array of sample angles.
    return [np.arange(b) * 2 * np.pi / b for _ in range(n)]
\ No newline at end of file
lie_learn/lie_learn/spaces/__init__.py
0 → 100755
View file @
b5881ee2
__author__
=
'tsc'
lie_learn/lie_learn/spaces/rn.py
0 → 100755
View file @
b5881ee2
"""
n-dimensional real space, R^n.
"""
import
numpy
as
np
# The following functions are part of the public interface of this module;
# other spaces / groups define their own meshgrid and linspace functions that work in an analogous way;
# for R^n the standard numpy functions fulfill this role.
from
numpy
import
meshgrid
,
linspace
def change_coordinates(coords, n, p_from='C', p_to='S'):
    """
    Convert between coordinate systems on R^n (n = 2 or 3 only).

    For n = 2: polar ('P'/'polar') <-> Cartesian ('C'/'cartesian'), and
    Cartesian <-> homogeneous ('H'/'homogeneous').
    For n = 3: spherical ('S'), ordered (theta, phi, r), <-> Cartesian ('C').

    todo: make this work for R^n and not just R^2, R^3

    :param coords: array of coordinates; the last axis indexes the components
    :param n: dimension of the space (2 or 3)
    :param p_from: source coordinate system
    :param p_to: target coordinate system
    :return: array of converted coordinates
    """
    coords = np.asarray(coords)
    if p_from == p_to:
        return coords

    if n == 2:
        polar_names = ('P', 'polar')
        cart_names = ('C', 'cartesian')
        hom_names = ('H', 'homogeneous')

        if p_from in polar_names and p_to in cart_names:
            radius = coords[..., 0]
            angle = coords[..., 1]
            out = np.empty_like(coords)
            out[..., 0] = radius * np.cos(angle)
            out[..., 1] = radius * np.sin(angle)
            return out

        if p_from in cart_names and p_to in polar_names:
            x = coords[..., 0]
            y = coords[..., 1]
            out = np.empty_like(coords)
            out[..., 0] = np.sqrt(x ** 2 + y ** 2)
            out[..., 1] = np.arctan2(y, x)
            return out

        if p_from in cart_names and p_to in hom_names:
            # Append a homogeneous coordinate equal to 1.
            out = np.empty(coords.shape[:-1] + (3,))
            out[..., 0] = coords[..., 0]
            out[..., 1] = coords[..., 1]
            out[..., 2] = 1.
            return out

        if p_from in hom_names and p_to in cart_names:
            # Divide through by the homogeneous coordinate.
            scale = coords[..., 2]
            out = np.empty(coords.shape[:-1] + (2,))
            out[..., 0] = coords[..., 0] / scale
            out[..., 1] = coords[..., 1] / scale
            return out

        raise ValueError('Unknown conversion' + str(p_from) + ' to ' + str(p_to))

    if n == 3:
        if p_from == 'S' and p_to == 'C':
            theta = coords[..., 0]
            phi = coords[..., 1]
            r = coords[..., 2]
            out = np.empty(theta.shape + (3,))
            st, ct = np.sin(theta), np.cos(theta)
            sp, cp = np.sin(phi), np.cos(phi)
            out[..., 0] = r * st * cp  # x
            out[..., 1] = r * st * sp  # y
            out[..., 2] = r * ct       # z
            return out

        if p_from == 'C' and p_to == 'S':
            x = coords[..., 0]
            y = coords[..., 1]
            z = coords[..., 2]
            out = np.empty_like(coords)
            out[..., 2] = np.sqrt(x ** 2 + y ** 2 + z ** 2)  # r
            out[..., 0] = np.arccos(z / out[..., 2])         # theta
            out[..., 1] = np.arctan2(y, x)                   # phi
            return out

        raise ValueError('Unknown conversion:' + str(p_from) + ' to ' + str(p_to))

    raise ValueError('Only dimension n=2 and n=3 supported for now.')
def linspace(b, convention):
    # NOTE(review): unimplemented stub that returns None. It also shadows the
    # `linspace` imported from numpy at the top of this module, so callers of
    # rn.linspace get this stub rather than numpy's — presumably a placeholder
    # to mirror the linspace functions of the other space modules; confirm.
    pass
lie_learn/lie_learn/spaces/spherical_quadrature.pyx
0 → 100755
View file @
b5881ee2
from
lie_learn.representations.SO3.spherical_harmonics
import
rsh
import
numpy
as
np
cimport
numpy
as
np
def estimate_spherical_quadrature_weights(sampling_set, max_bandwidth, normalization='quantum', condon_shortley=True, verbose=True):
    """
    Estimate quadrature weights for an arbitrary spherical sampling set by
    least squares: find weights so that the weighted sample products of real
    spherical harmonics best reproduce their orthonormality relations.

    :param sampling_set: array of shape (M, 2) with (theta, phi) samples
    :param max_bandwidth: maximum spherical-harmonic degree N used in the fit
    :param normalization: normalization convention passed to rsh
    :param condon_shortley: whether to include the Condon-Shortley phase
    :param verbose: print progress messages
    :return: the np.linalg.lstsq result (weights, residuals, rank, singular values)
    """
    cdef int l
    cdef int m
    cdef int ll
    cdef int mm
    cdef int i
    cdef int M = sampling_set.shape[0]
    cdef int N = max_bandwidth
    cdef int N_total = (N + 1) ** 2  # = sum_l=0^N (2l + 1)
    cdef np.ndarray[np.float64_t, ndim=2] l_array = np.empty((N_total, 1))
    cdef np.ndarray[np.float64_t, ndim=2] m_array = np.empty((N_total, 1))

    theta = sampling_set[:, 0]
    phi = sampling_set[:, 1]

    if verbose:
        # Python-2 print statements converted to function calls: Cython 3
        # defaults to language_level=3, where the statement form is a syntax error.
        print('Computing index arrays...')

    # Flat (l, m) index arrays in the order (0,0), (1,-1), (1,0), (1,1), ...
    i = 0
    for l in range(N + 1):
        for m in range(-l, l + 1):
            l_array[i, 0] = l
            m_array[i, 0] = m
            i += 1

    if verbose:
        print('Computing spherical harmonics...')

    # Sample all (l, m) harmonics at all sampling points in one broadcast call.
    Y = rsh(l_array, m_array, theta[None, :], phi[None, :],
            normalization=normalization, condon_shortley=condon_shortley)

    if verbose:
        print('Computing least squares input')

    B = np.empty((N_total ** 2, M))
    t = np.empty(N_total ** 2)
    i = 0
    for l in range(N + 1):
        for m in range(-l, l + 1):
            rlm = Y[l ** 2 + l + m, :]
            for ll in range(N + 1):
                for mm in range(-ll, ll + 1):
                    # Row: pointwise product Y_lm * Y_ll,mm at all samples.
                    # Target: delta_{l,ll} delta_{m,mm} (orthonormality).
                    B[i, :] = rlm * Y[ll ** 2 + ll + mm, :]
                    t[i] = float(ll == l and mm == m)
                    i += 1

    if verbose:
        print('Computing least squares solution')

    return np.linalg.lstsq(B, t)
lie_learn/lie_learn/spectral/FFTBase.py
0 → 100755
View file @
b5881ee2
class FFTBase(object):
    """Abstract base class for Fourier transform implementations.

    Subclasses provide `analyze` (forward transform) and `synthesize`
    (inverse transform); this base raises NotImplementedError for both.
    """

    def __init__(self):
        pass

    def analyze(self, f):
        """Forward transform: spatial samples f -> spectral coefficients."""
        raise NotImplementedError('FFTBase.analyze should be implemented in subclass')

    def synthesize(self, f_hat):
        """Inverse transform: spectral coefficients f_hat -> spatial samples."""
        raise NotImplementedError('FFTBase.synthesize should be implemented in subclass')
lie_learn/lie_learn/spectral/PolarFFT.py
0 → 100755
View file @
b5881ee2
import
numpy
as
np
from
.FFTBase
import
FFTBase
from
pynfft.nfft
import
NFFT
# UNFINISHED
class PolarFFT(FFTBase):
    """
    Polar FFT built on the non-equispaced FFT (NFFT). Marked UNFINISHED upstream.
    """

    def __init__(self, nx, ny, nt, nr):
        """
        :param nx: Cartesian spectral grid size, x axis
        :param ny: Cartesian spectral grid size, y axis
        :param nt: number of angular samples
        :param nr: number of radial samples
        """
        # Initialize the non-equispaced FFT
        self.nfft = NFFT(N=(nx, ny), M=nx * ny, n=None, m=12, flags=None)

        # Set up the polar sampling grid
        # NOTE(review): pynfft expects nodes in [-0.5, 0.5); these theta in
        # [0, 2pi] and r in [0, 1] values look out of range — confirm before use.
        theta = np.linspace(0, 2 * np.pi, nt)
        r = np.linspace(0, 1., nr)
        T, R = np.meshgrid(theta, r)
        self.nfft.x = np.c_[T[..., None], R[..., None]].flatten()
        self.nfft.precompute()

    def analyze(self, f):
        """Forward NFFT of the coefficients f; returns the transform result."""
        self.nfft.f_hat = f
        f_hat = self.nfft.forward()
        return f_hat

    def synthesize(self, f_hat):
        """Adjoint NFFT of f_hat; returns the computed samples."""
        self.nfft.f = f_hat
        f = self.nfft.adjoint()
        # Bug fix: previously returned the *input* f_hat instead of the
        # computed adjoint transform f.
        return f
lie_learn/lie_learn/spectral/S2FFT.py
0 → 100755
View file @
b5881ee2
import
numpy
as
np
from
scipy.fftpack
import
fft
,
ifft
,
fftshift
from
lie_learn.spectral.FFTBase
import
FFTBase
import
lie_learn.spaces.S2
as
S2
from
lie_learn.representations.SO3.spherical_harmonics
import
csh
,
sh
class S2_FT_Naive(FFTBase):
    """
    The most naive implementation of the discrete spherical Fourier transform:
    explicitly construct the Fourier matrix F and multiply by it to perform the Fourier transform.
    """

    def __init__(self, L_max, grid_type='Gauss-Legendre',
                 field='real', normalization='quantum', condon_shortley='cs'):
        """
        :param L_max: maximum spherical harmonic degree; bandwidth b = L_max + 1
        :param grid_type: spatial sampling grid (see S2.meshgrid)
        :param field: 'real' or 'complex' spherical harmonics
        :param normalization: normalization convention passed to sh
        :param condon_shortley: 'cs' to include the Condon-Shortley phase
        """
        super().__init__()

        self.b = L_max + 1

        # Compute a grid of spatial sampling points and associated quadrature weights
        beta, alpha = S2.meshgrid(b=self.b, grid_type=grid_type)
        self.w = S2.quadrature_weights(b=self.b, grid_type=grid_type)
        self.spatial_grid_shape = beta.shape
        self.num_spatial_points = beta.size

        # Determine for which degree and order we want the spherical harmonics
        irreps = np.arange(self.b)  # TODO find out upper limit for exact integration for each grid type
        # NOTE(review): the comprehension variables below shadow the names
        # `ls` / `ms` they are assigned to; works, but is easy to misread.
        ls = [[ls] * (2 * ls + 1) for ls in irreps]
        ls = np.array([ll for sublist in ls for ll in sublist])  # 0, 1, 1, 1, 2, 2, 2, 2, 2, ...
        ms = [list(range(-ls, ls + 1)) for ls in irreps]
        ms = np.array([mm for sublist in ms for mm in sublist])  # 0, -1, 0, 1, -2, -1, 0, 1, 2, ...
        self.num_spectral_points = ms.size  # This equals sum_{l=0}^{b-1} 2l+1 = b^2

        # In one shot, sample the spherical harmonics at all spectral (l, m) and spatial (beta, alpha) coordinates
        self.Y = sh(ls[None, None, :], ms[None, None, :], beta[:, :, None], alpha[:, :, None],
                    field=field, normalization=normalization, condon_shortley=condon_shortley == 'cs')

        # Convert to a matrix
        self.Ymat = self.Y.reshape(self.num_spatial_points, self.num_spectral_points)

    def analyze(self, f):
        # Forward transform: quadrature-weighted inner products with the
        # (conjugated) sampled harmonics.
        return self.Ymat.T.conj().dot((f * self.w).flatten())

    def synthesize(self, f_hat):
        # Inverse transform: linear combination of the sampled harmonics,
        # reshaped back to the spatial grid.
        return self.Ymat.dot(f_hat).reshape(self.spatial_grid_shape)
def setup_legendre_transform(b):
    """
    Compute a set of matrices containing coefficients to be used in a discrete Legendre transform.

    The discrete Legendre transform of a data vector s[k] (k=0, ..., 2b-1) is defined as
        s_hat(l, m) = sum_{k=0}^{2b-1} P_l^m(cos(beta_k)) s[k]
    for l = 0, ..., b-1 and -l <= m <= l,
    where P_l^m is the associated Legendre function of degree l and order m,
    and beta_k are the Driscoll-Healy sample angles.

    Computing Fourier Transforms and Convolutions on the 2-Sphere
    J.R. Driscoll, D.M. Healy

    FFTs for the 2-Sphere - Improvements and Variations
    D.M. Healy, Jr., D.N. Rockmore, P.J. Kostelec, and S. Moore

    :param b: bandwidth of the transform
    :return: lt, an array of shape (2b, N), containing weighted samples of the
        Legendre functions, where N = b^2 is the number of spectral points for
        a signal of bandwidth b.
    """
    # N = sum_{l=0}^{b-1} (2l + 1) = b^2 spectral points.
    dim = np.sum(np.arange(b) * 2 + 1)
    lt = np.empty((2 * b, dim))

    beta, _ = S2.linspace(b, grid_type='Driscoll-Healy')
    sample_points = np.cos(beta)

    # TODO move quadrature weight computation to S2.py
    # Driscoll-Healy quadrature weights for the beta samples.
    weights = [(1. / b) * np.sin(np.pi * j * 0.5 / b) *
               np.sum([1. / (2 * l + 1) * np.sin((2 * l + 1) * np.pi * j * 0.5 / b)
                       for l in range(b)])
               for j in range(2 * b)]
    weights = np.array(weights)

    zeros = np.zeros_like(sample_points)
    i = 0
    for l in range(b):
        for m in range(-l, l + 1):
            # Z = np.sqrt(((2 * l + 1) * factorial(l - m)) / float(4 * np.pi * factorial(l + m))) * np.pi / 2
            # lt[i, :] = lpmv(m, l, sample_points) * weights * Z
            # The spherical harmonics code appears to be more stable than the (unnormalized) associated Legendre
            # function code.
            # (Note: the spherical harmonics evaluated at alpha=0 is the associated Legendre function))
            lt[:, i] = csh(l, m, beta, zeros, normalization='seismology').real * weights * np.pi / 2
            i += 1

    return lt
def setup_legendre_transform_indices(b):
    """
    Compute, for each spectral index (l, m) in the flat ordering
    (0,0), (1,-1), (1,0), (1,1), (2,-2), ..., the non-negative FFT frequency
    bin m mod 2b that the order m occupies in a length-2b FFT.

    :param b: bandwidth of the transform
    :return: list of length b^2 with the frequency index for each (l, m) pair
    """
    indices = []
    for degree in range(b):
        for order in range(-degree, degree + 1):
            # Negative orders wrap around to the top of the FFT spectrum.
            indices.append(order % (2 * b))
    return indices
def sphere_fft(f, lt=None, lti=None):
    """
    Compute the Spherical Fourier transform of f.

    We use complex, seismology-normalized, centered spherical harmonics, which are orthonormal (see rep_bases.py).

    The spherical Fourier transform is defined:
        \\hat{f}_l^m = int_0^pi dbeta sin(beta) int_0^2pi dalpha f(beta, alpha) Y_l^{m*}(beta, alpha)
    (where we use the invariant area element dOmega = sin(beta) dbeta dalpha for the 2-sphere)

    We have Y_l^m(beta, alpha) = P_l^m(cos(beta)) * e^{im alpha}, where P_l^m is the associated Legendre function,
    so we can rewrite:
        \\hat{f}_l^m = int_0^pi dbeta sin(beta) (int_0^2pi dalpha f(beta, alpha) e^{im alpha} ) P_l^m(cos(beta))

    The integral over alpha can be evaluated by FFT:
        \\bar{f}(beta_k, m) = int_0^2pi dalpha f(beta_k, alpha) e^{im alpha} = FFT(f, axis=1)[beta_k, m]

    Then we have
        \\hat{f}_l^m = int_0^pi sin(beta) dbeta \\bar{f}(beta, m) P_l^m(cos(beta))
                     = sum_k \\bar{f}[beta_k, m] P_l^m(cos(beta_k)) w_k
    for appropriate quadrature weights w_k. This sum is called the discrete Legendre transform of \\bar{f}.

    We return \\hat{f} as a flat vector. Hence, the precomputed P_l^m(cos(beta_k)) w_k is stored as an array with a
    combined (l, m)-axis and a k axis. We bring the data \\bar{f}[beta_k, m] into the same form by indexing with lti,
    and then reduce over the beta_k axis.

    Main source:
    Engineering Applications of Noncommutative Harmonic Analysis.
    4.7.2 - Orthogonal Expansions on the Sphere
    G.S. Chrikjian, A.B. Kyatkin

    Further information:
    SOFT: SO(3) Fourier Transforms
    Peter J. Kostelec and Daniel N. Rockmore

    Generalized FFTs-a survey of some recent results
    Maslen & Rockmore

    Computing Fourier transforms and convolutions on the 2-sphere.
    Driscoll, J., & Healy, D. (1994).

    :param f: array of samples of the function to be transformed. Shape (2 * b, 2 * b). grid_type: Driscoll-Healy
    :param lt: precomputed Legendre transform matrices, from setup_legendre_transform().
    :param lti: precomputed Legendre transform indices, from setup_legendre_transform_indices().
    :return: f_hat, the spherical Fourier transform of f. This is an array of size sum_l=0^{b-1} 2 l + 1.
        the coefficients are ordered as (l=0, m=0), (l=1, m=-1), (l=1, m=0), (l=1,m=1), ...
    """
    # Input must be a square grid with an even side length 2b.
    assert f.shape[-2] == f.shape[-1]
    assert f.shape[-2] % 2 == 0
    b = f.shape[-2] // 2

    if lt is None:
        lt = setup_legendre_transform(b)
    if lti is None:
        lti = setup_legendre_transform_indices(b)

    # First, FFT along the alpha axis (last axis)
    # This gives the array f_bar with axes for beta and m.
    f_bar = fft(f, axis=-1)

    # Perform Legendre transform
    f_hat = (f_bar[..., lti] * lt).sum(axis=-2)
    return f_hat
lie_learn/lie_learn/spectral/S2FFT_NFFT.py
0 → 100755
View file @
b5881ee2
import
numpy
as
np
from
.FFTBase
import
FFTBase
from
pynfft
import
nfsft
from
lie_learn.spaces.spherical_quadrature
import
estimate_spherical_quadrature_weights
from
lie_learn.representations.SO3.irrep_bases
import
change_of_basis_matrix
,
change_of_basis_function
class S2FFT_NFFT(FFTBase):
    """
    Spherical Fourier transform on an arbitrary sampling grid, backed by the
    NFFT library's NFSFT (nonequispaced fast spherical Fourier transform).

    Internally the NFSFT works with NFFT-normalized, centered, complex
    spherical harmonics without Condon-Shortley phase; this class converts
    to/from quantum-normalized, centered, real spherical harmonics with
    Condon-Shortley phase via precomputed change-of-basis functions.
    """

    def __init__(self, L_max, x, w=None):
        """
        :param L_max: maximum spherical harmonic degree
        :param x: coordinates on spherical / spatial grid
        :param w: quadrature weights for the grid x
        """
        # If x is a list (generated by S2.meshgrid), convert to (M, 2) array
        if isinstance(x, list):
            x = np.c_[x[0].flatten()[:, None], x[1].flatten()[:, None]]

        # The NFSFT class can synthesize / analyze functions in terms of
        # NFFT-normalized, centered, complex spherical harmonics without Condon-Shortley phase.
        self._nfsft = nfsft.NFSFT(N=L_max, x=x)

        # Compute a change-of-basis matrix from the NFFT spherical harmonics to our preferred choice, the
        # quantum-normalized, centered, real spherical harmonics with Condon-Shortley phase.
        #TODO: change this to change_of_basis_function (test that it works..)
        #self._c2r = change_of_basis_matrix(np.arange(L_max + 1),
        #                                   frm=('complex', 'nfft', 'centered', 'nocs'),
        #                                   to=('real', 'quantum', 'centered', 'cs'))
        #self._r2c = change_of_basis_matrix(np.arange(L_max + 1),
        #                                   to=('complex', 'nfft', 'centered', 'nocs'),
        #                                   frm=('real', 'quantum', 'centered', 'cs'))
        #self._c = change_of_basis_matrix(np.arange(L_max + 1),
        #                                 frm=('real', 'nfft', 'centered', 'cs'),
        #                                 to=('complex', 'quantum', 'centered', 'nocs'))
        self._c2r_func = change_of_basis_function(np.arange(L_max + 1),
                                                  frm=('complex', 'nfft', 'centered', 'nocs'),
                                                  to=('real', 'quantum', 'centered', 'cs'))
        #self._r2c_func = change_of_basis_function(np.arange(L_max + 1),
        #                                          frm=('real', 'quantum', 'centered', 'cs'),
        #                                          to=('complex', 'nfft', 'centered', 'nocs'))

        # In the synthesize() function, we will need c2r.conj().T as a function (not a matrix).
        # It happens to be the case that the following is equal to c2r.conj().T:
        c2r_conj_T = change_of_basis_function(np.arange(L_max + 1),
                                              frm=('real', 'nfft', 'centered', 'cs'),
                                              to=('complex', 'quantum', 'centered', 'nocs'))
        # c2r.T applied to vec: conjugate in, apply c2r.conj().T, conjugate out.
        self._c2r_T = lambda vec: c2r_conj_T(vec.conj()).conj()

        if w is None:
            # Precompute quadrature weights
            self.w = estimate_spherical_quadrature_weights(
                sampling_set=x, max_bandwidth=L_max,
                normalization='quantum', condon_shortley=True)[0]
        else:
            self.w = w.flatten()

        self.x = x
        self.L_max = L_max

    def analyze(self, f):
        """
        Compute the spherical Fourier coefficients of f (sampled on self.x),
        with respect to real, quantum-normalized, CS-phase spherical harmonics.

        :param f: samples of the function at the grid points self.x.
        :return: real coefficient vector of length (L_max + 1)^2.
        """
        # We want to perform the *weighted* adjoint FFT, so that we get the exact Fourier coefficients
        # (at least for a proper sampling grid such as Clenshaw-Curtis or Gauss-Legendre and the respective weights)
        # Hence, the function to be transformed is f * w
        self._nfsft.f = f * self.w

        # Expand the weighted function in terms of the conjugate of
        # NFFT-normalized, centered, complex spherical harmonics without Condon-Shortley phase:
        # a_lm = sum_i=0^M Y_lm(theta_i, phi_i).conj() * w_i * f(theta_i, phi_i)
        self._nfsft.adjoint()

        # The computed Fourier components a_lm are with respect to the basis of NFFT spherical harmonics,
        # so change the basis.
        # Let Y denote the M by (L_max+1)^2 matrix of NFFT spherical harmonics.
        # then a = Y.conj().T.dot(f * w), as computed by _nfsft.adjoint()
        # Since, Y.conj().T = r2c.conj().dot(R.T), we have a = r2c.conj().dot(R.T.dot(f * w))
        # To cancel the r2c.conj(), we multiply with c2r.conj()
        #a = self._c2r.conj().dot(self._nfsft.get_f_hat_flat()).real
        #b = self._c2r_func(self._nfsft.get_f_hat_flat().conj()).conj().real
        #print 'DIFF', np.sum(np.abs(a-b))
        #return self._c2r.conj().dot(self._nfsft.get_f_hat_flat()).real
        return self._c2r_func(self._nfsft.get_f_hat_flat().conj()).conj().real

    def synthesize(self, f_hat):
        """
        Evaluate the function with (real-basis) coefficients f_hat at the grid points self.x.

        :param f_hat: flat coefficient vector of length (L_max + 1)^2.
        :return: real function values at the sampling set self.x.
        """
        # self._nfsft.trafo() computes the synthesis / forward transform using NFFT complex SH:
        # f = Y f_hat, where Y is the M by (L_max+1)^2 matrix of complex NFFT spherical harmonics.
        # We have R.T = c2r.dot(Y.T), so f = R.dot(f_hat) = Y.dot(c2r.T.dot(f_hat))
        #cfh = self._c2r.T.dot(f_hat)
        cfh = self._c2r_T(f_hat)
        self._nfsft.set_f_hat_flat(cfh)
        f = self._nfsft.trafo(use_dft=False, return_copy=True)
        return f.real
\ No newline at end of file
lie_learn/lie_learn/spectral/S2_conv.py
0 → 100755
View file @
b5881ee2
import
numpy
as
np
import
lie_learn.spaces.S2
as
S2
import
lie_learn.groups.SO3
as
SO3
from
lie_learn.spectral.S2FFT
import
S2_FT_Naive
from
lie_learn.spectral.SO3FFT_Naive
import
SO3_FT_Naive
def conv_test():
    """
    Smoke test of the spectral S^2 convolution: convolve two constant
    functions sampled on a bandwidth-10 Gauss-Legendre grid.

    :return: f12, the convolution of the two constant functions, sampled on SO(3).
    """
    b = 10
    # Constant test signals on the (2b + 2, b + 1) Gauss-Legendre S^2 grid.
    f1 = np.ones((2 * b + 2, b + 1))
    f2 = np.ones((2 * b + 2, b + 1))
    # Delegate to spectral_S2_conv, which implements exactly the pipeline this
    # test previously duplicated inline (S2 FFT -> block-wise outer product ->
    # inverse SO(3) FFT); the redundant local SO3_FT_Naive import is gone too.
    return spectral_S2_conv(f1, f2)
def spectral_S2_conv(f1, f2, s2_fft=None, so3_fft=None):
    """
    Convolve two functions on the 2-sphere via the Fourier domain.

    For f1, f2 : S^2 -> R the convolution is
        f1 * f2(g) = int_{S^2} f1(x) f2(g^{-1} x) dx,   g in SO(3),
    with dx the normalized Haar measure on S^2. The SO(3) Fourier transform of
    f1 * f2 is the block-wise outer product of the spherical Fourier transforms
    of f1 and f2: the coefficient vectors split into chunks of size 2l + 1 for
    l = 0, ..., b - 1, and each degree-l pair of chunks yields a
    (2l + 1) x (2l + 1) block of the result's SO(3) Fourier transform.
    See the note on "Convolution on S^2 and SO(3)" for a derivation.

    :param f1: samples of the first function, shape (2b + 2, b + 1) (Gauss-Legendre grid assumed).
    :param f2: samples of the second function, same shape as f1.
    :param s2_fft: optional precomputed S2_FT_Naive instance.
    :param so3_fft: optional precomputed SO3_FT_Naive instance.
    :return: samples of f1 * f2 on the SO(3) grid.
    """
    b = f1.shape[1] - 1
    # TODO we assume a Gauss-Legendre grid for S^2 here
    if s2_fft is None:
        s2_fft = S2_FT_Naive(L_max=b - 1, grid_type='Gauss-Legendre', field='real',
                             normalization='quantum', condon_shortley='cs')
    if so3_fft is None:
        so3_fft = SO3_FT_Naive(L_max=b - 1, field='real', normalization='quantum',
                               order='centered', condon_shortley='cs')

    # Forward spherical Fourier transforms.
    f1_hat = s2_fft.analyze(f1)
    f2_hat = s2_fft.analyze(f2)

    # Degree-l coefficients live at f_hat[l**2 : l**2 + 2l + 1]; take the
    # outer product of matching blocks.
    f12_hat = []
    for degree in range(b):
        begin = degree ** 2
        end = begin + 2 * degree + 1
        block1 = f1_hat[begin:end]
        block2 = f2_hat[begin:end]
        f12_hat.append(block1[:, None] * block2[None, :].conj())

    # Back to the group via the inverse SO(3) Fourier transform.
    return so3_fft.synthesize(f12_hat)
def naive_S2_conv(f1, f2, alpha, beta, gamma, g_parameterization='EA323'):
    """
    Directly evaluate int_S^2 f1(x) f2(g^{-1} x)* dx by quadrature,
    where x = (theta, phi) is a point on the sphere S^2 and
    g = (alpha, beta, gamma) is a rotation in Euler-angle parameterization.

    :param f1, f2: functions to be convolved
    :param alpha, beta, gamma: the rotation at which to evaluate the result of convolution
    :return: the value of the convolution integral at g.
    """
    # This fails
    def pointwise_product(theta, phi):
        # Invert g and move the unit-sphere point (theta, phi, r=1) by it.
        rot_inv = SO3.invert((alpha, beta, gamma), parameterization=g_parameterization)
        theta_rot, phi_rot, _ = SO3.transform_r3(g=rot_inv, x=(theta, phi, 1.),
                                                 g_parameterization=g_parameterization,
                                                 x_parameterization='S')
        return f1(theta, phi) * f2(theta_rot, phi_rot).conj()

    return S2.integrate(f=pointwise_product, normalize=True)
def naive_S2_conv_v2(f1, f2, alpha, beta, gamma, g_parameterization='EA323'):
    """
    Compute int_S^2 f1(x) f2(g^{-1} x)* dx,
    where x = (theta, phi) is a point on the sphere S^2,
    and g = (alpha, beta, gamma) is a point in SO(3) in Euler angle parameterization

    NOTE(review): experimental / debugging variant — it contains print
    statements, uses a fixed bandwidth b=3 grid, and (as flagged below) does
    not actually invert g. Also note f2 is NOT conjugated here, unlike the
    docstring formula — confirm which is intended.

    :param f1, f2: functions to be convolved
    :param alpha, beta, gamma: the rotation at which to evaluate the result of convolution
    :return: quadrature approximation of the convolution value at g.
    """
    # Fixed small Gauss-Legendre sampling grid and matching quadrature weights.
    theta, phi = S2.meshgrid(b=3, grid_type='Gauss-Legendre')
    w = S2.quadrature_weights(b=3, grid_type='Gauss-Legendre')
    print(theta.shape, phi.shape)
    s2_coords = np.c_[theta[..., None], phi[..., None]]
    print(s2_coords.shape)
    # Lift the S^2 grid to R^3 spherical coordinates with unit radius.
    r3_coords = np.c_[theta[..., None], phi[..., None], np.ones_like(theta)[..., None]]
    # g_inv = SO3.invert((alpha, beta, gamma), parameterization=g_parameterization)
    # g_inv = (-gamma, -beta, -alpha)
    g_inv = (alpha, beta, gamma)
    # wrong
    ginvx = SO3.transform_r3(g=g_inv, x=r3_coords, g_parameterization=g_parameterization, x_parameterization='S')
    print(ginvx.shape)
    # Split transformed spherical coordinates back into (theta, phi, r).
    g_inv_theta = ginvx[..., 0]
    g_inv_phi = ginvx[..., 1]
    g_inv_r = ginvx[..., 2]
    print(g_inv_theta, g_inv_phi, g_inv_r)
    # Evaluate both functions on the grid and take the weighted sum.
    f1_grid = f1(theta, phi)
    f2_grid = f2(g_inv_theta, g_inv_phi)
    print(f1_grid.shape, f2_grid.shape, w.shape)
    return np.sum(f1_grid * f2_grid * w)
lie_learn/lie_learn/spectral/SE2FFT.py
0 → 100755
View file @
b5881ee2
import
numpy
as
np
# from numpy.fft import fft, fft2, ifft, ifft2, fftshift
from
spectral.T1FFT
import
T1FFT
from
spectral.T2FFT
import
T2FFT
from
scipy.ndimage.interpolation
import
map_coordinates
from
spectral.FFTBase
import
FFTBase
from
spectral.fourier_interpolation
import
FourierInterpolator
import
groups.SE2
as
SE2
def bilinear_interpolate(f, x, y):
    """
    Bilinearly interpolate the array f at (fractional) coordinates (x, y).

    The x coordinate indexes axis 1 and y indexes axis 0; any trailing axes of
    f (e.g. channels) are carried through, so f is expected to have at least
    one axis beyond (y, x) for the [..., None] weight broadcasting to line up.
    Coordinates are clamped to the valid index range.
    NOTE(review): at exactly the last row/column index the clamped corner
    weights all vanish and the result is 0 — confirm whether callers rely on this.

    Fix over the original: removed leftover debug print statements.

    :param f: input array, shape (H, W, ...).
    :param x: x (column) sample coordinates, scalar or array.
    :param y: y (row) sample coordinates, same shape as x.
    :return: interpolated values, shape x.shape + f.shape[2:].
    """
    x = np.asarray(x)
    y = np.asarray(y)

    # Integer corners surrounding each sample point.
    x0 = np.floor(x).astype(int)
    x1 = x0 + 1
    y0 = np.floor(y).astype(int)
    y1 = y0 + 1

    # Clamp the corners into the array bounds.
    x0 = np.clip(x0, 0, f.shape[1] - 1)
    x1 = np.clip(x1, 0, f.shape[1] - 1)
    y0 = np.clip(y0, 0, f.shape[0] - 1)
    y1 = np.clip(y1, 0, f.shape[0] - 1)

    # The four corner values.
    Ia = f[y0, x0]
    Ib = f[y1, x0]
    Ic = f[y0, x1]
    Id = f[y1, x1]

    # Bilinear weights: each corner is weighted by the area of the opposite cell.
    wa = (x1 - x) * (y1 - y)
    wb = (x1 - x) * (y - y0)
    wc = (x - x0) * (y1 - y)
    wd = (x - x0) * (y - y0)

    return wa[..., None] * Ia + wb[..., None] * Ib + wc[..., None] * Ic + wd[..., None] * Id
def mul(fh1, fh2):
    """
    Degree-wise product of two SE(2) Fourier coefficient arrays.

    The axes of fh are (r, p, q). For each r, we multiply the infinite
    dimensional matrices indexed by (p, q), assuming the values are zero
    outside the range stored. Thus, the p-axis of the second array fh2 is
    truncated at both sides so that fh1.dot(fh2) is defined, and so that the
    zero-frequency q-component of fh1 lines up with the zero-frequency
    p-component of fh2.

    Fix over the original: the upper slice bound was computed with np.ceil,
    which returns a float — floats are rejected as slice indices on Python 3 /
    modern numpy. It is now computed with integer arithmetic.

    :param fh1, fh2: coefficient arrays of identical shape (r, p, q).
    :return: array of shape (r, p, q) with the per-r matrix products.
    """
    assert fh1.shape == fh2.shape
    p0 = fh1.shape[1] // 2  # index of the zero-frequency p-component
    q0 = fh1.shape[2] // 2  # index of the zero-frequency q-component
    # The lower and upper bound of the p-range; (n + 1) // 2 == ceil(n / 2).
    a = p0 - q0
    b = p0 + (fh2.shape[2] + 1) // 2
    fh12 = []
    for i in range(fh1.shape[0]):
        fh12.append(fh1[i, :, :].dot(fh2[i, a:b, :]))
    fh12 = np.c_[fh12]
    #.transpose(2, 0, 1)
    return fh12
def mulT(fh1, fh2):
    """
    Degree-wise product of fh1 with the (p, q)-transpose of fh2.

    The axes of fh are (r, p, q). For each r, fh1[r] is multiplied with
    fh2[r].T as if both were infinite-dimensional matrices that are zero
    outside the stored range; the q-axis of the product is then truncated at
    both sides so the zero-frequency components line up, yielding an output
    of the same shape as the inputs.

    Fix over the original: the upper slice bound was computed with np.ceil,
    which returns a float — floats are rejected as slice indices on Python 3 /
    modern numpy. It is now computed with integer arithmetic.

    :param fh1, fh2: coefficient arrays of identical shape (r, p, q).
    :return: array of shape (r, p, q) with the per-r products fh1[r].dot(fh2[r].T), truncated.
    """
    assert fh1.shape == fh2.shape
    p0 = fh1.shape[1] // 2  # index of the zero-frequency p-component
    q0 = fh1.shape[2] // 2  # index of the zero-frequency q-component
    # The lower and upper bound of the retained range; (n + 1) // 2 == ceil(n / 2).
    a = p0 - q0
    b = p0 + (fh2.shape[2] + 1) // 2
    fh12 = []
    for i in range(fh1.shape[0]):
        fh12.append(fh1[i, :, :].dot(fh2[i, :, :].T)[:, a:b])
    fh12 = np.c_[fh12]
    #.transpose(2, 0, 1)
    return fh12
def conv_test():
    """
    Visual check of SE(2) auto-correlation: transform the bar image produced by
    test(), form fh.dot(fh.T) per radius via mulT, synthesize, and plot each
    angular slice of the result.
    """
    # Forward-transform the test signal; only the coefficients fh are used here.
    f, f1c, f1p, f2, f2f, fh, fi, f1ci, f1pi, f2i, f2fi, fhi = test()

    corr_hat = mulT(fh, fh)

    fft = SE2_FFT(spatial_grid_size=(40, 40, 42),
                  interpolation_method='spline',
                  oversampling_factor=5)
    fi, f1ci, f1pi, f2i, f2fi, fhi = fft.synthesize(corr_hat)

    from utils.visualize import plotmat
    # Plot every angular slice with a shared color range.
    vmin = np.min(fi.real)
    vmax = np.max(fi.real)
    for slice_idx in range(fi.shape[0]):
        plotmat(fi[:, :, slice_idx].real, slice_idx, range=(vmin, vmax))
def SE2_matrix_element(r, p, q, tau, theta):
    """
    Matrix element T^r_pq of the irreducible representation of SE(2) of
    weight / radius r, evaluated at the group element g = (tau, theta),
    where tau is a 2D translation vector and theta a rotation angle.

    :param r: radius / weight of the representation.
    :param p, q: integer matrix-element indices.
    :param tau: translation, a pair (tau_x, tau_y).
    :param theta: rotation angle in radians.
    :return: the complex matrix element.
    """
    from scipy.special import jv
    # Polar form of the translation vector tau.
    tau_norm = np.sqrt(tau[0] ** 2 + tau[1] ** 2)
    tau_angle = np.angle(z=tau[0] + 1j * tau[1], deg=0)
    phase = np.exp(1j * ((p - q) * tau_angle + q * theta))
    return 1j ** (q - p) * phase * jv(p - q, r * tau_norm)
def SE2_matrix_element_grid(r, p, q, spatial_grid_size=(10, 10, 10)):
    """
    Sample the SE(2) matrix element T^r_pq on a regular (tau_x, tau_y, theta)
    grid with tau components in [-1, 1] and theta in [0, 2*pi].

    Fix over the original: a single tau axis of length spatial_grid_size[0]
    was indexed with indices up to spatial_grid_size[1], which raises
    IndexError for non-square grids; the two tau axes now get their own
    linspace. Behavior for square grids is unchanged.

    :param r: radius / weight of the representation.
    :param p, q: integer matrix-element indices.
    :param spatial_grid_size: (n_tau_x, n_tau_y, n_theta) sample counts.
    :return: complex array of shape spatial_grid_size.
    """
    mat = np.zeros(spatial_grid_size, dtype='complex')
    taus_x = np.linspace(-1, 1, spatial_grid_size[0])
    taus_y = np.linspace(-1, 1, spatial_grid_size[1])
    thetas = np.linspace(0, 2 * np.pi, spatial_grid_size[2])
    for itau1 in range(spatial_grid_size[0]):
        for itau2 in range(spatial_grid_size[1]):
            for itheta in range(spatial_grid_size[2]):
                mat[itau1, itau2, itheta] = SE2_matrix_element(
                    r, p, q, tau=(taus_x[itau1], taus_y[itau2]), theta=thetas[itheta])
    return mat
def SE2_matrix_element_chirkijian(r, p, q, tau, theta):
    """
    SE(2) matrix element per eq. 10.3 of Chirikjian & Kyatkin, "Engineering
    Applications of Noncommutative Harmonic Analysis". Not yet tested.

    NOTE(review): this appears to be the complex conjugate of the variant
    derived independently elsewhere in this module — confirm which convention
    is wanted.

    :param r: radius / weight of the representation.
    :param p, q: integer matrix-element indices.
    :param tau: translation, a pair (tau_x, tau_y).
    :param theta: rotation angle in radians.
    :return: the complex matrix element.
    """
    from scipy.special import jv
    # Polar form of the translation vector tau.
    tau_norm = np.sqrt(tau[0] ** 2 + tau[1] ** 2)
    tau_angle = np.angle(z=tau[0] + 1j * tau[1], deg=0)
    exponent = -1j * (q * theta + (p - q) * tau_angle)
    return 1j ** (q - p) * np.exp(exponent) * jv(q - p, r * tau_norm)
def pix_to_ndc(C, w, h, flip_y=True):
    """
    Map pixel coordinates to normalized device coordinates in [-1, 1].

    :param C: array whose last axis holds (x_pix, y_pix) pairs.
    :param w, h: image width and height in pixels.
    :param flip_y: if True, invert the y axis (pixel rows grow downward, NDC y grows upward).
    :return: array like C with (x_ndc, y_ndc) pairs on the last axis.
    """
    px = C[..., 0]
    py = C[..., 1]
    nx = (2. * px - w) / w
    # (-1) ** flip_y is -1 when flipping, +1 otherwise.
    ny = ((2. * py - h) / h) * (-1) ** flip_y
    return np.c_[nx[..., None], ny[..., None]]
def ndc_to_pix(C, w, h, flip_y=True):
    """
    Map normalized device coordinates in [-1, 1] back to pixel coordinates
    (the inverse of pix_to_ndc).

    :param C: array whose last axis holds (x_ndc, y_ndc) pairs.
    :param w, h: image width and height in pixels.
    :param flip_y: if True, undo the y-axis flip applied by the forward mapping.
    :return: array like C with (x_pix, y_pix) pairs on the last axis.
    """
    nx = C[..., 0]
    # (-1) ** flip_y is -1 when flipping, +1 otherwise.
    ny = C[..., 1] * (-1) ** flip_y
    px = (nx + 1) * 0.5 * w
    py = (ny + 1) * 0.5 * h
    return np.c_[px[..., None], py[..., None]]
def R2_SE2_convolve_naive(f1, f2, t_res=21, r_res=21, f_res=None):
    """
    Naive SE(2)-correlation of two planar images:
    compute int_R2 f1(x) f2(x^{-1} g) dx = int_R2 f1(gx) f2(x^{-1}) dx
    for a grid of group elements g = (theta, t1, t2).

    Fix over the original: the flipped sampling grid (Xi, Yi) was computed but
    never used — Cinv was built from the unflipped (X, Y), so f2 was resampled
    without the intended x -> x^{-1} flip. Cinv now uses (Xi, Yi).

    :param f1, f2: 2D image arrays of identical shape.
    :param t_res: number of translation samples per axis, covering [-1, 1].
    :param r_res: number of rotation samples, covering [0, 2*pi).
    :param f_res: resolution of the integration grid (defaults to f1.shape[0]).
    :return: array of shape (r_res, t_res, t_res) indexed by (theta, t1, t2).
    """
    w = f1.shape[0] - 1
    h = f1.shape[1] - 1
    if f_res is None:
        f_res = f1.shape[0]

    # Integration grid in normalized device coordinates.
    X, Y = np.meshgrid(np.linspace(-1, 1, f_res),
                       np.linspace(-1, 1, f_res),
                       indexing='ij')
    C = np.c_[X[..., None], Y[..., None]]

    # Create a flipped image f2(x^{-1}) by sampling f2 on the reversed grid.
    Xi, Yi = np.meshgrid(np.linspace(1, -1, f1.shape[0]),
                         np.linspace(1, -1, f1.shape[1]),
                         indexing='ij')
    Cinv = np.c_[Xi[..., None], Yi[..., None]]
    Cinv_pix = ndc_to_pix(Cinv, w=w, h=h)
    f2inv = map_coordinates(f2, Cinv_pix.transpose(2, 0, 1), order=0, mode='constant', cval=0.0)

    out = np.empty((r_res, t_res, t_res))
    translations = np.linspace(-1, 1, t_res)
    rotations = np.linspace(0, 2 * np.pi, r_res, endpoint=False)
    for t1i in range(translations.size):
        for t2i in range(translations.size):
            for thetai in range(rotations.size):
                t1 = translations[t1i]
                t2 = translations[t2i]
                theta = rotations[thetai]

                # Transform the sampling grid by g:
                gC = SE2.transform(g=(theta, t1, t2),
                                   g_parameterization='rotation-translation',
                                   x=C,
                                   x_parameterization='cartesian')

                # Map normalized device coordinates to array indices
                gC_pix = ndc_to_pix(gC, w=w, h=h)

                # Evaluate f1 at the transformed grid:
                f1_gx = map_coordinates(f1, order=1,
                                        coordinates=gC_pix.transpose(2, 0, 1),
                                        mode='constant', cval=0.0)

                # Compute dot product:
                out[thetai, t1i, t2i] = (f1_gx * f2inv).sum()
    return out
def map_wrap(f, coords):
    """
    Linearly interpolate the 2D array f at (possibly fractional) coordinates,
    treating both axes as periodic.

    :param f: 2D array of samples (real-valued; complex input would be cast).
    :param coords: array of shape (2, ...) with axis-0 and axis-1 sample coordinates.
    :return: interpolated values, shape coords.shape[1:].
    """
    rows = f.shape[0]
    cols = f.shape[1]

    # Augment f with one wrap-around row and column so that interpolation
    # between the last and first sample needs no special-casing.
    padded = np.empty((rows + 1, cols + 1))
    padded[:-1, :-1] = f
    padded[-1, :-1] = f[0, :]
    padded[:-1, -1] = f[:, 0]
    padded[-1, -1] = f[0, 0]

    # Fold the sample coordinates into [0, rows) x [0, cols).
    folded_r = coords[0, ...] % rows
    folded_c = coords[1, ...] % cols
    folded = np.r_[folded_r[None, ...], folded_c[None, ...]]

    return map_coordinates(padded, folded, order=1, mode='constant', cval=np.nan, prefilter=False)
def test():
    """
    Round-trip check of the SE(2) FFT: analyze a bar-shaped indicator image,
    synthesize it back, print the reconstruction error, and return every
    intermediate stage of both directions.
    """
    # Bar-shaped indicator on a 40x40 translation grid, constant over 42 angles.
    signal = np.zeros((40, 40, 42))
    signal[19:21, 10:30, :] = 1.

    fft = SE2_FFT(spatial_grid_size=(40, 40, 42),
                  interpolation_method='spline',
                  spline_order=1,
                  oversampling_factor=5)

    f, f1c, f1p, f2, f2f, fh = fft.analyze(signal)
    fi, f1ci, f1pi, f2i, f2fi, fhi = fft.synthesize(fh)

    # Total absolute reconstruction error of the round trip.
    print(np.sum(np.abs(f - fi)))

    return f, f1c, f1p, f2, f2f, fh, fi, f1ci, f1pi, f2i, f2fi, fhi
def test_phaseshift1():
    """
    Check the Fourier shift theorem on T^2: the FT of an image circularly
    shifted by half its size should equal the FT of the original image times
    a phase factor.

    Fixes over the original: integer division for the shift offsets (float
    indices are rejected by numpy on Python 3), and removal of a dead
    impulse-image setup that was immediately overwritten by the random image
    (it also contained an apparent p0/q0 typo).

    :return: (f1, f1shift, f1h, f1sh, f1psh); f1sh and f1psh should agree.
    """
    nx = 20
    ny = 20
    p0 = nx // 2
    q0 = ny // 2

    # Random complex test image.
    f1 = np.random.randn(nx, ny) + 1j * np.random.randn(nx, ny)

    # Circularly shift the image by (p0, q0).
    X, Y = np.meshgrid(np.arange(p0, p0 + f1.shape[0]) % f1.shape[0],
                       np.arange(q0, q0 + f1.shape[1]) % f1.shape[1],
                       indexing='ij')
    f1shift = f1[X, Y]

    f1h = T2FFT.analyze(f1)
    f1sh = T2FFT.analyze(f1shift)

    # Do a phase shift and check that it is equal to the FT of the shifted image
    delta = -0.5  # we're shifting from [0, 1) to [-0.5, 0.5)
    xi1 = np.arange(-np.floor(f1.shape[0] / 2.), np.ceil(f1.shape[0] / 2.))
    xi2 = np.arange(-np.floor(f1.shape[1] / 2.), np.ceil(f1.shape[1] / 2.))
    XI1, XI2 = np.meshgrid(xi1, xi2, indexing='ij')
    phase = np.exp(-2 * np.pi * 1j * delta * (XI1 + XI2))
    f1psh = f1h * phase

    return f1, f1shift, f1h, f1sh, f1psh
def imrot(f, t):
    """
    Rotate array f around its center by t radians counterclockwise.

    Works for complex input: real and imaginary parts are resampled separately
    with periodic (wrap-around) linear interpolation.

    :param f: 2D array to rotate.
    :param t: rotation angle in radians.
    :return: rotated complex array of the same shape as f.
    """
    nx = f.shape[0]
    ny = f.shape[1]
    center = np.array([nx / 2, ny / 2])

    #X, Y = np.meshgrid(np.arange(p0, p0 + nx) % nx,
    #                   np.arange(q0, q0 + ny) % ny,
    #                   indexing='ij')
    rows, cols = np.meshgrid(np.arange(0, nx), np.arange(0, ny), indexing='ij')

    # Rotation by -t: sample the source image at the inversely rotated positions.
    cos_t = np.cos(-t)
    sin_t = np.sin(-t)
    rot = np.array([[cos_t, -sin_t],
                    [sin_t, cos_t]])

    grid = np.c_[rows[..., None], cols[..., None]] - center[None, None, :]
    src = np.einsum('ij,abj->abi', rot, grid) + center[None, None, :]

    #Rfr = map_coordinates(f.real, src.transpose(2, 0, 1), order=1, mode='wrap')
    #Rfi = map_coordinates(f.imag, src.transpose(2, 0, 1), order=1, mode='wrap')
    re_part = map_wrap(f.real, src.transpose(2, 0, 1))
    im_part = map_wrap(f.imag, src.transpose(2, 0, 1))
    return re_part + im_part * 1j
def shift_fft(f):
    """
    2D torus FT of f with the spatial origin taken at the center pixel:
    circularly shift f by (nx // 2, ny // 2) before transforming.

    Fix over the original: integer division for the shift offsets — float
    p0/q0 made np.arange produce float arrays, which numpy rejects as indices
    on Python 3.

    :param f: array of shape (nx, ny, ...); trailing axes pass through.
    :return: T2FFT.analyze of the shifted array.
    """
    nx = f.shape[0]
    ny = f.shape[1]
    p0 = nx // 2
    q0 = ny // 2
    # Index grids that implement the circular shift.
    X, Y = np.meshgrid(np.arange(p0, p0 + nx) % nx,
                       np.arange(q0, q0 + ny) % ny,
                       indexing='ij')
    fs = f[X, Y, ...]
    return T2FFT.analyze(fs, axes=(0, 1))
def shift_ifft(fh):
    """
    Inverse of shift_fft: synthesize from the torus Fourier coefficients and
    circularly shift the result back by (-nx // 2-style offsets).

    Fix over the original: integer division for the shift offsets — float
    p0/q0 made np.arange produce float arrays, which numpy rejects as indices
    on Python 3.

    :param fh: coefficient array of shape (nx, ny, ...); trailing axes pass through.
    :return: the spatial-domain array with the origin moved back to pixel (0, 0).
    """
    nx = fh.shape[0]
    ny = fh.shape[1]
    p0 = nx // 2
    q0 = ny // 2
    # Index grids that implement the inverse circular shift.
    X, Y = np.meshgrid(np.arange(-p0, -p0 + nx) % nx,
                       np.arange(-q0, -q0 + ny) % ny,
                       indexing='ij')
    fs = T2FFT.synthesize(fh, axes=(0, 1))
    f = fs[X, Y, ...]
    return f
def test_ft_rotation3():
    """
    Analyze a family of rotated plane waves with the SE(2) FFT and collect
    every intermediate stage, to inspect how the transform behaves under
    rotation of the input.

    :return: per-angle lists (fs, f1cs, f1ps, f2s, f2fs, fhs) of the analyze() stages.
    """
    nx = 20
    ny = 20
    nt = 10
    center = np.array([nx / 2, ny / 2])

    rows, cols = np.meshgrid(np.arange(0, nx), np.arange(0, ny), indexing='ij')
    base_grid = np.c_[rows[..., None], cols[..., None]]

    fft = SE2_FFT(spatial_grid_size=(nx, ny, nt),
                  interpolation_method='spline',
                  oversampling_factor=5,
                  spline_order=1)

    fs, f1cs, f1ps, f2s, f2fs, fhs = [], [], [], [], [], []
    for step in range(36):
        angle = step * 2 * np.pi / 36.
        rot = np.array([[np.cos(-angle), -np.sin(-angle)],
                        [np.sin(-angle), np.cos(-angle)]])
        # Rotate the sampling grid around the image center.
        rotated = np.einsum('ij,abj->abi', rot,
                            base_grid - center[None, None, :]) + center[None, None, :]

        #Rfr = map_wrap(f1.real, rotated.transpose(2, 0, 1))
        #Rfi = map_wrap(f1.imag, rotated.transpose(2, 0, 1))
        #Rf = Rfr + Rfi * 1j
        # Plane wave of frequency 5 along the rotated first axis, constant in theta.
        wave = np.exp(1j * 2 * np.pi * (5 * rotated[..., 0, None] * np.ones(nt)[None, None, :]) / nx)

        f, f1c, f1p, f2, f2f, f_hat = fft.analyze(wave)
        fs.append(f)
        f1cs.append(f1c)
        f1ps.append(f1p)
        f2s.append(f2)
        f2fs.append(f2f)
        fhs.append(f_hat)

    return fs, f1cs, f1ps, f2s, f2fs, fhs
def test_ft_rotation2(t=np.pi / 10):
    """
    Sweep 36 rotations of a plane wave and collect the wave and its 2D torus
    FT at each angle. Shows that when the image rotates around the center
    (p0, q0), the FT also rotates around the zero-frequency component.

    (The parameter t is overwritten inside the loop; it is kept for interface
    compatibility only.)

    :return: (fs, fhs), the waves and their Fourier transforms.
    """
    nx = 20
    ny = 20
    center = np.array([nx / 2, ny / 2])

    #f1 = np.zeros((nx, ny), dtype='complex')
    #f1[p0 - 3:p0 + 4, q0 - 3:q0+4] = np.random.randn(7, 7) + 1j * np.random.randn(7, 7)
    rows, cols = np.meshgrid(np.arange(0, nx), np.arange(0, ny), indexing='ij')
    base_grid = np.c_[rows[..., None], cols[..., None]]
    #f1 = np.exp(1j * 2 * np.pi * (5 * X) / nx)

    fs = []
    fhs = []
    for step in range(36):
        t = step * 2 * np.pi / 36.
        rot = np.array([[np.cos(-t), -np.sin(-t)],
                        [np.sin(-t), np.cos(-t)]])
        rotated = np.einsum('ij,abj->abi', rot,
                            base_grid - center[None, None, :]) + center[None, None, :]

        #Rfr = map_wrap(f1.real, rotated.transpose(2, 0, 1))
        #Rfi = map_wrap(f1.imag, rotated.transpose(2, 0, 1))
        #Rf = Rfr + Rfi * 1j
        wave = np.exp(1j * 2 * np.pi * (5 * rotated[..., 0]) / nx)

        # NOTE(review): the shift_fft result is immediately overwritten below;
        # presumably the two variants were being compared — confirm intent.
        wave_hat = shift_fft(wave)
        wave_hat = T2FFT.analyze(wave)
        fs.append(wave)
        fhs.append(wave_hat)

    return fs, fhs
def test_ft_rotation(t=np.pi / 3.):
    """
    Compare the centered FT of a random image patch with the FT of its rotated
    version, to check that rotating an image rotates its Fourier transform.

    Fixes over the original: integer division for the center indices — float
    p0/q0 in slice bounds raise TypeError on Python 3 — and removal of the
    dead `f2 = np.zeros(...)` that was immediately overwritten by imrot.

    :param t: rotation angle in radians.
    :return: (f1, f2, f1h, f2h, f1h_rot, f1h_irot)
    """
    nx = 500
    ny = 500
    p0 = nx // 2
    q0 = ny // 2

    # A random complex 7x7 patch in the center of an otherwise zero image.
    f1 = np.zeros((nx, ny), dtype='complex')
    f1[p0 - 3:p0 + 4, q0 - 3:q0 + 4] = np.random.randn(7, 7) + 1j * np.random.randn(7, 7)

    # The same image, rotated by t around the center.
    f2 = imrot(f1, t)

    # Centered FTs of both, plus rotated / inversely-rotated versions of f1's FT.
    f1h = shift_fft(f1)
    f2h = shift_fft(f2)
    f1h_rot = imrot(f1h, t)
    f1h_irot = imrot(f1h, -t)

    return f1, f2, f1h, f2h, f1h_rot, f1h_irot
def cartesian_grid(nx, ny):
    """
    Build an (nx, ny) Cartesian sampling grid covering [-0.5, 0.5) on both axes.

    :param nx: number of samples along the first axis.
    :param ny: number of samples along the second axis.
    :return: (X, Y) coordinate arrays of shape (nx, ny), 'ij'-indexed.
    """
    xs = np.linspace(-0.5, 0.5, nx, endpoint=False)
    ys = np.linspace(-0.5, 0.5, ny, endpoint=False)
    grid_x, grid_y = np.meshgrid(xs, ys, indexing='ij')
    return grid_x, grid_y
class
SE2_FFT
(
FFTBase
):
def
__init__
(
self
,
spatial_grid_size
=
(
10
,
10
,
10
),
interpolation_method
=
'spline'
,
spline_order
=
1
,
oversampling_factor
=
1
):
self
.
spatial_grid_size
=
spatial_grid_size
# tau_x, tau_y, theta
self
.
interpolation_method
=
interpolation_method
if
interpolation_method
==
'spline'
:
self
.
spline_order
=
spline_order
# The array coordinates of the zero-frequency component
self
.
p0
=
spatial_grid_size
[
0
]
//
2
self
.
q0
=
spatial_grid_size
[
1
]
//
2
# The distance, in pixels, from the (0, 0) pixel to the center of frequency space
self
.
r_max
=
np
.
sqrt
(
self
.
p0
**
2
+
self
.
q0
**
2
)
# Precomputation for cartesian-to-polar regridding
self
.
n_samples_r
=
oversampling_factor
*
(
np
.
ceil
(
self
.
r_max
)
+
1
)
self
.
n_samples_t
=
oversampling_factor
*
(
np
.
ceil
(
2
*
np
.
pi
*
self
.
r_max
))
r
=
np
.
linspace
(
0
,
self
.
r_max
,
self
.
n_samples_r
,
endpoint
=
True
)
theta
=
np
.
linspace
(
0
,
2
*
np
.
pi
,
self
.
n_samples_t
,
endpoint
=
False
)
R
,
THETA
,
=
np
.
meshgrid
(
r
,
theta
,
indexing
=
'ij'
)
# Convert polar to Cartesian coordinates
X
=
R
*
np
.
cos
(
THETA
)
Y
=
R
*
np
.
sin
(
THETA
)
# Transform to array indices (note; these are not the usual coordinates where y axis is flipped)
I
=
X
+
self
.
p0
J
=
Y
+
self
.
q0
self
.
c2p_coords
=
np
.
r_
[
I
[
None
,
...],
J
[
None
,
...]]
# Precomputation for polar-to-cartesian regridding
i
=
np
.
arange
(
0
,
self
.
spatial_grid_size
[
0
])
j
=
np
.
arange
(
0
,
self
.
spatial_grid_size
[
1
])
x
=
i
-
self
.
p0
y
=
j
-
self
.
q0
X
,
Y
=
np
.
meshgrid
(
x
,
y
,
indexing
=
'ij'
)
# Convert Cartesian to polar coordinates:
R
=
np
.
sqrt
(
X
**
2
+
Y
**
2
)
T
=
np
.
arctan2
(
Y
,
X
)
# % (2 * np.pi)
# Convert to array indices
# Maximum of R is r_max, maximum index in array is (n_samples_r - 1)
R
*=
(
self
.
n_samples_r
-
1
)
/
self
.
r_max
# The maximum angle in T is arbitrarily close to 2 pi,
# but this should end up 1 pixel past the last index n_samples_t - 1, i.e. it should end up at n_samples_t
# which is equal to index 0 since wraparound is used.
T
*=
self
.
n_samples_t
/
(
2
*
np
.
pi
)
self
.
p2c_coords
=
np
.
r_
[
R
[
None
,
...],
T
[
None
,
...]]
elif
interpolation_method
==
'Fourier'
:
#r_max = np.sqrt(2)
r_max
=
1.
/
np
.
sqrt
(
2.
)
#nr = spatial_grid_size[0] + 1
nr
=
15
*
np
.
ceil
(
r_max
*
spatial_grid_size
[
0
])
nt
=
5
*
np
.
ceil
(
2
*
np
.
pi
*
r_max
*
spatial_grid_size
[
0
])
nx
=
spatial_grid_size
[
0
]
ny
=
spatial_grid_size
[
1
]
self
.
flerp
=
FourierInterpolator
.
init_cartesian_to_polar
(
nr
,
nt
,
nx
,
ny
)
else
:
raise
ValueError
(
'Unknown interpolation method:'
+
str
(
interpolation_method
))
def
analyze
(
self
,
f
):
"""
Compute the SE(2) Fourier Transform of a function f : SE(2) -> C or f : SE(2) -> R.
The SE(2) Fourier Transform expands f in the basis of matrix elements of irreducible representations of SE(2).
Let T^r_pq(g) be the (p, q) matrix element of the irreducible representation of SE(2) of weight / radius r,
then the FT is:
F^r_pq = int_SE(2) f(g) conjugate(T^r_pq(g^{-1})) dg
We assume g in SE(2) to be parameterized as g = (tau_x, tau_y, theta), where tau is a 2D translation vector
and theta is a rotation angle.
The input f is a 3D array of shape (N_x, N_y, N_t),
where the axes correspond to tau_x, tau_y, theta in the ranges:
tau_x in np.linspace(-0.5, 0.5, N_x, endpoint=False)
tau_y in np.linspace(-0.5, 0.5, N_y, endpoint=False)
theta in np.linspace(0, 2 * np.pi, N_t, endpoint=False)
See:
"Engineering Applications of Noncommutative Harmonic Analysis", section 11.2
Chrikjian & Kyatkin
"The Mackey Machine: a user manual"
Taco S. Cohen
:param f: discretely sampled function on SE(2).
The first two axes of f correspond to translation parameters tau_x, tau_y, and the third axis corresponds to
rotation angle theta.
:return: F, the SE(2) Fourier Transform of f. Axes of F are (r, p, q)
"""
# First, FFT along translation parameters tau_1 and tau_2
#f1c_shift = T2FFT.analyze(f, axes=(0, 1))
# This gives: f1c_shift[xi_1, xi_2, theta]
# where xi_1 and xi_2 are Cartesian (c) coordinates of the frequency domain.
# However, this is the FT of the *shifted* function on [0, 1), so shift the coefficient back:
#delta = -0.5 # we're shifting from [0, 1) to [-0.5, 0.5)
#xi1 = np.arange(-np.floor(f1c_shift.shape[0] / 2.), np.ceil(f1c_shift.shape[0] / 2.))
#xi2 = np.arange(-np.floor(f1c_shift.shape[1] / 2.), np.ceil(f1c_shift.shape[1] / 2.))
#XI1, XI2 = np.meshgrid(xi1, xi2, indexing='ij')
#phase = np.exp(-2 * np.pi * 1j * delta * (XI1 + XI2))
#f1c = f1c_shift * phase[:, :, None]
f1c
=
shift_fft
(
f
)
# Change from Cartesian (c) to a polar (p) grid:
f1p
=
self
.
resample_c2p_3d
(
f1c
)
# This gives f1p[r, varphi, theta]
# FFT along rotation angle theta
# We conjugate the argument and the ouput so that the complex exponential has positive instead of negative sign
f2
=
T1FFT
.
analyze
(
f1p
.
conj
(),
axis
=
2
).
conj
()
# This gives f2[r, varphi, q]
# where q ranges from q = -floor(f1p.shape[2] / 2) to q = ceil(f1p.shape[2] / 2) - 1 (inclusive)
# Multiply f2 by a (varphi, q)-dependent phase factor:
m_min
=
-
np
.
floor
(
f2
.
shape
[
2
]
/
2.
)
m_max
=
np
.
ceil
(
f1p
.
shape
[
2
]
/
2.
)
-
1
varphi
=
np
.
linspace
(
0
,
2
*
np
.
pi
,
f2
.
shape
[
1
],
endpoint
=
False
)
# may not need this many points on every circle
factor
=
np
.
exp
(
-
1j
*
varphi
[
None
,
:,
None
]
*
np
.
arange
(
m_min
,
m_max
+
1
)[
None
,
None
,
:])
f2f
=
f2
*
factor
# FFT along polar coordinate of frequency domain
f_hat
=
T1FFT
.
analyze
(
f2f
.
conj
(),
axis
=
1
).
conj
()
# This gives f_hat[r, p, q]
return
f
,
f1c
,
f1p
,
f2
,
f2f
,
f_hat
def
synthesize
(
self
,
f_hat
):
f2f
=
T1FFT
.
synthesize
(
f_hat
.
conj
(),
axis
=
1
).
conj
()
# Multiply f_2 by a phase factor:
m_min
=
-
np
.
floor
(
f2f
.
shape
[
2
]
/
2
)
m_max
=
np
.
ceil
(
f2f
.
shape
[
2
]
/
2
)
-
1
psi
=
np
.
linspace
(
0
,
2
*
np
.
pi
,
f2f
.
shape
[
1
],
endpoint
=
False
)
# may not need this many points on every circle
factor
=
np
.
exp
(
1j
*
psi
[:,
None
]
*
np
.
arange
(
m_min
,
m_max
+
1
)[
None
,
:])
f2
=
f2f
*
factor
[
None
,
...]
f1p
=
T1FFT
.
synthesize
(
f2
.
conj
(),
axis
=
2
).
conj
()
f1c
=
self
.
resample_p2c_3d
(
f1p
)
# delta = -0.5 # we're shifting from [0, 1) to [-0.5, 0.5)
# xi1 = np.arange(-np.floor(f1c.shape[0] / 2), np.ceil(f1c.shape[0] / 2))
# xi2 = np.arange(-np.floor(f1c.shape[1] / 2), np.ceil(f1c.shape[1] / 2))
# XI1, XI2 = np.meshgrid(xi1, xi2, indexing='ij')
# phase = np.exp(-2 * np.pi * 1j * delta * (XI1 + XI2))
# f1c_shift = f1c / phase[:, :, None]
#f = T2FFT.synthesize(f1c, axes=(0, 1))
#f = T2FFT.synthesize(f1c_shift, axes=(0, 1))
f
=
shift_ifft
(
f1c
)
return
f
,
f1c
,
f1p
,
f2
,
f2f
,
f_hat
def resample_c2p(self, fc):
    """
    Resample a function on a Cartesian grid to a polar grid.

    The center of the Cartesian coordinate system is assumed to be in the center of the image at index
    x0 = fc.shape[0] / 2 - 0.5
    y0 = fc.shape[1] / 2 - 0.5
    i.e. for a 2-pixel image, x0 would be at 'index' 2/2-0.5 = 0.5, in between the two pixels.
    The first dimension of the output corresponds to the radius r in [0, r_max=fc.shape[0] / 2. - 0.5]
    while the second dimension corresponds to the angle theta in [0, 2pi).

    :param fc: function values sampled on a Cartesian grid.
    :return: resampled function on a polar grid
    """
    # Three coordinate frames are involved:
    # - the array indices / image coordinates (i, j) of the input data array,
    # - a Cartesian frame (x, y) centered in the image, same directions and units (pixels) on the axes,
    # - a polar frame (r, theta), also centered in the image, with theta=0 along the x axis.
    # self.c2p_coords maps each polar output sample to its source position in image coordinates;
    # note the center (x0, y0) is expressed in image coordinates but need not be an integer.
    # Real and imaginary parts are interpolated separately, since map_wrap operates on real arrays.
    real_part = map_wrap(fc.real, self.c2p_coords)
    imag_part = map_wrap(fc.imag, self.c2p_coords)
    return real_part + 1j * imag_part
def resample_p2c(self, fp):
    """
    Resample a function on a polar grid back to a Cartesian grid.

    Spline interpolation of order self.spline_order with wrap-around boundary
    handling is used; self.p2c_coords maps each Cartesian output sample to its
    source position on the polar grid.

    :param fp: complex function values sampled on a polar grid
    :return: resampled complex function on a Cartesian grid
    """
    real_part = map_coordinates(fp.real, self.p2c_coords,
                                order=self.spline_order, mode='wrap')
    imag_part = map_coordinates(fp.imag, self.p2c_coords,
                                order=self.spline_order, mode='wrap')
    return real_part + 1j * imag_part
def resample_c2p_3d(self, fc):
    """
    Resample a stack of Cartesian-grid slices to polar coordinates, slice by slice.

    The last axis of fc indexes the slices. Dispatches on
    self.interpolation_method: 'spline' uses resample_c2p per slice,
    'Fourier' uses the Fourier interpolator self.flerp.

    :param fc: array whose slices fc[:, :, i] are sampled on a Cartesian grid
    :return: array of resampled slices, last axis again indexing the slices
    """
    if self.interpolation_method == 'spline':
        slices = [self.resample_c2p(fc[:, :, k]) for k in range(fc.shape[2])]
        return np.c_[slices].transpose(1, 2, 0)
    elif self.interpolation_method == 'Fourier':
        slices = [self.flerp.forward(fc[:, :, k]) for k in range(fc.shape[2])]
        return np.c_[slices].transpose(1, 2, 0)
def resample_p2c_3d(self, fp):
    """
    Resample a stack of polar-grid slices back to Cartesian coordinates, slice by slice.

    The last axis of fp indexes the slices. Dispatches on
    self.interpolation_method: 'spline' uses resample_p2c per slice,
    'Fourier' uses the inverse Fourier interpolator self.flerp.backward.

    :param fp: array whose slices fp[:, :, i] are sampled on a polar grid
    :return: array of resampled slices, last axis again indexing the slices
    """
    if self.interpolation_method == 'spline':
        slices = [self.resample_p2c(fp[:, :, k]) for k in range(fp.shape[2])]
        return np.c_[slices].transpose(1, 2, 0)
    elif self.interpolation_method == 'Fourier':
        slices = [self.flerp.backward(fp[:, :, k]) for k in range(fp.shape[2])]
        return np.c_[slices].transpose(1, 2, 0)
def R2_SE2_FFT(f):
    """
    Compute the SE(2) Fourier Transform of f : R^2 -> R as if f is a function on SE(2).
    That is, we view f as a function on SE(2):
    f+(g) = (L_g f)(0) = f(g^{-1} 0)
    If we parameterize g = (r, theta), where r is a translation vector and theta a rotation angle,
    we see that f+ is constant along theta, because a rotation of 0 is 0.
    Therefore, (presumably) the theta axis carries no information and the transform
    reduces to a transform over the translation part only -- confirm when implementing.

    NOTE(review): not implemented yet; this stub returns None.
    """
    pass
lie_learn/lie_learn/spectral/SO3FFT_Naive.py
0 → 100755
View file @
b5881ee2
from
functools
import
lru_cache
import
numpy
as
np
from
lie_learn.representations.SO3.irrep_bases
import
change_of_basis_matrix
from
lie_learn.representations.SO3.pinchon_hoggan.pinchon_hoggan_dense
import
rot_mat
,
Jd
from
lie_learn.representations.SO3.wigner_d
import
wigner_d_matrix
,
wigner_D_matrix
import
lie_learn.spaces.S3
as
S3
from
lie_learn.representations.SO3.indexing
import
flat_ind_zp_so3
,
flat_ind_so3
from
.FFTBase
import
FFTBase
from
scipy.fftpack
import
fft2
,
ifft2
,
fftshift
# TODO:
# Write testing code for these FFTs
# Write fast code for the real, quantum-normalized, centered / block-diagonal bases.
# The real Wigner-d functions d^l_mn are identically 0 whenever either (m < 0 and n >= 0) or (m >= 0 and n < 0),
# so we can save work in the Wigner-d transform
class SO3_FT_Naive(FFTBase):
    """
    The most naive implementation of the discrete SO(3) Fourier transform:
    explicitly construct the Fourier matrix F and multiply by it to perform the Fourier transform.

    We use the following convention:
    Let D^l_mn(g) (the Wigner D function) be normalized so that it is unitary.
    FFT(f)^l_mn = int_SO(3) f(g) \conj(D^l_mn(g)) dg
    where dg is the normalized Haar measure on SO(3).
    IFFT(\hat(f))(g) = \sum_{l=0}^L_max (2l + 1) \sum_{m=-l}^l \sum_{n=-l}^l \hat(f)^l_mn D^l_mn(g)

    Under this convention, where (2l+1) appears in the IFFT, we have:
    - The Fourier transform of D^l_mn is a one-hot vector where FFT(D^l_mn)^l_mn = 1 / (2l + 1),
      because 1 / (2l + 1) is the squared norm of D^l_mn.
    - The convolution theorem is
      FFT(f * psi) = FFT(f) FFT(psi)^{*T},
      i.e. the second argument is conjugate-transposed, and there is no normalization constant required.
    """

    def __init__(self, L_max, field='complex', normalization='quantum',
                 order='centered', condon_shortley='cs'):
        # :param L_max: maximum degree l of the transform; the bandwidth is b = L_max + 1
        # :param field, normalization, order, condon_shortley: basis/normalization
        #     convention, forwarded to wigner_D_matrix (see irrep_bases.py)
        super().__init__()

        # TODO allow user to specify the grid (now using SOFT implicitly)
        # Explicitly construct the Wigner-D matrices evaluated at each point in a grid in SO(3).
        # self.D[l] has shape (2b, 2b, 2b, 2l+1, 2l+1), with axes (alpha, beta, gamma, m, n).
        self.D = []
        b = L_max + 1
        for l in range(b):
            self.D.append(np.zeros((2 * b, 2 * b, 2 * b, 2 * l + 1, 2 * l + 1),
                                   dtype=complex if field == 'complex' else float))
            for j1 in range(2 * b):
                alpha = 2 * np.pi * j1 / (2. * b)
                for k in range(2 * b):
                    # SOFT grid: beta_k = pi (2k + 1) / 4b (samples avoid the poles)
                    beta = np.pi * (2 * k + 1) / (4. * b)
                    for j2 in range(2 * b):
                        gamma = 2 * np.pi * j2 / (2. * b)
                        self.D[-1][j1, k, j2, :, :] = wigner_D_matrix(
                            l, alpha, beta, gamma,
                            field, normalization, order, condon_shortley)

        # Compute quadrature weights (one weight per beta sample)
        self.w = S3.quadrature_weights(b=b, grid_type='SOFT')

        # Stack D into a single Fourier matrix
        # The first axis corresponds to the spatial samples.
        # The spatial grid has shape (2b, 2b, 2b), so this axis has length (2b)^3.
        # The second axis of this matrix has length sum_{l=0}^L_max (2l+1)^2,
        # which corresponds to all the spectral coefficients flattened into a vector.
        # (normally these are stored as matrices D^l of shape (2l+1)x(2l+1))
        self.F = np.hstack([self.D[l].reshape((2 * b) ** 3, (2 * l + 1) ** 2)
                            for l in range(b)])

        # For the IFFT / synthesis transform, we need to weight the order-l Fourier coefficients by (2l + 1)
        # Here we precompute these coefficients.
        ls = [[ls] * (2 * ls + 1) ** 2 for ls in range(b)]
        ls = np.array([ll for sublist in ls for ll in sublist])  # (0,) + 9 * (1,) + 25 * (2,), ...
        self.l_weights = 2 * ls + 1

    def analyze(self, f):
        # Forward (spatial -> spectral) transform by quadrature: weight f along
        # the beta axis (axis 1) and project onto each sampled Wigner-D matrix.
        # f is expected on the same (2b, 2b, 2b) SOFT grid used in __init__,
        # so f.shape[0] // 2 recovers the bandwidth b.
        f_hat = []
        for l in range(f.shape[0] // 2):
            f_hat.append(np.einsum('ijkmn,ijk->mn', self.D[l], f * self.w[None, :, None]))
        return f_hat

    def analyze_by_matmul(self, f):
        # Same transform as analyze(), but implemented as a single matrix product
        # with the stacked Fourier matrix F; the result is one flat coefficient vector.
        f = f * self.w[None, :, None]
        f = f.flatten()
        return self.F.T.conj().dot(f)

    def synthesize(self, f_hat):
        # Inverse (spectral -> spatial) transform: accumulate the (2l+1)-weighted
        # coefficient matrices against the sampled Wigner-D matrices.
        # :param f_hat: list of (2l+1, 2l+1) coefficient matrices, l = 0, ..., b-1
        b = len(self.D)
        f = np.zeros((2 * b, 2 * b, 2 * b), dtype=self.D[0].dtype)
        for l in range(b):
            f += np.einsum('ijkmn,mn->ijk', self.D[l], f_hat[l] * (2 * l + 1))
        return f

    def synthesize_by_matmul(self, f_hat):
        # Same as synthesize(), but for a flat coefficient vector f_hat;
        # the per-degree (2l+1) weights were precomputed into self.l_weights.
        return self.F.dot(f_hat * self.l_weights)
class SO3_FFT_SemiNaive_Complex(FFTBase):
    """
    Semi-naive complex SO(3) FFT: precomputes sampled Wigner-d functions and
    quadrature weights, and delegates the transforms to SO3_FFT_analyze /
    SO3_FFT_synthesize.
    """

    def __init__(self, L_max, d=None, w=None, L2_normalized=True,
                 field='complex', normalization='quantum', order='centered',
                 condon_shortley='cs'):
        super().__init__()

        # Use caller-supplied Wigner-d samples / quadrature weights when given,
        # otherwise fall back to the defaults for bandwidth b = L_max + 1.
        if d is not None:
            self.d = d
        else:
            self.d = setup_d_transform(
                b=L_max + 1, L2_normalized=L2_normalized,
                field=field, normalization=normalization,
                order=order, condon_shortley=condon_shortley)

        if w is not None:
            self.w = w
        else:
            self.w = S3.quadrature_weights(b=L_max + 1)

        # Precompute the product of the quadrature weights and the Wigner-d samples.
        self.wd = weigh_wigner_d(self.d, self.w)

    def analyze(self, f):
        """Forward (spatial to spectral) SO(3) Fourier transform of f."""
        return SO3_FFT_analyze(f)  # , self.wd)

    def synthesize(self, f_hat):
        """
        Perform the inverse (spectral to spatial) SO(3) Fourier transform.

        :param f_hat: a list of matrices of with shapes [1x1, 3x3, 5x5, ..., 2 L_max + 1 x 2 L_max + 1]
        """
        return SO3_FFT_synthesize(f_hat)  # , self.d)
class SO3_FFT_NaiveReal(FFTBase):
    """
    SO(3) FFT for real-valued functions, implemented via the complex semi-naive
    FFT plus a change of basis between real and complex Wigner-D conventions.
    """

    def __init__(self, L_max, d=None, L2_normalized=True):
        """
        :param L_max: maximum degree l of the transform (bandwidth is L_max + 1)
        :param d: optional precomputed Wigner-d samples, forwarded to the complex FFT
        :param L2_normalized: whether the complex FFT uses L2-normalized Wigner-d functions
        """
        # Consistency fix: the sibling FFTBase subclasses in this module call
        # super().__init__(); this class previously skipped it.
        super().__init__()
        self.L_max = L_max
        self.complex_fft = SO3_FFT_SemiNaive_Complex(L_max=L_max, d=d,
                                                     L2_normalized=L2_normalized)

        # Precompute the change-of-basis matrices, one per degree l, from the
        # (complex, seismology, centered, cs) to the (real, quantum, centered, cs) basis.
        self.c2b = [change_of_basis_matrix(l,
                                           frm=('complex', 'seismology', 'centered', 'cs'),
                                           to=('real', 'quantum', 'centered', 'cs'))
                    for l in range(L_max + 1)]

    def analyze(self, f):
        raise NotImplementedError('SO3 analyze function not implemented yet')

    def synthesize(self, f_hat):
        """
        Inverse (spectral to spatial) SO(3) Fourier transform for real coefficients.

        We have R = B * C * B.conj().T, where
        B is the real-to-complex change of basis, C are the complex Wigner D functions,
        and R are the real Wigner D functions.
        We want to compute Tr(eta^T R) = Tr( (B.T * eta * B.conj())^T C),
        so the real coefficients are conjugated by the change-of-basis matrices
        and fed to the complex synthesis transform.

        :param f_hat: list of real coefficient matrices, entry l of shape (2l+1, 2l+1)
        :return: the synthesized real-valued function on the SO(3) grid
        """
        f_hat_complex = [self.c2b[l].T.dot(f_hat[l]).dot(self.c2b[l].conj())
                         for l in range(self.L_max + 1)]
        f = self.complex_fft.synthesize(f_hat_complex)
        return f.real

    def synthesize_direct(self, f_hat):
        # TODO: synthesize without using the complex FFT
        pass
def SO3_FFT_analyze(f):
    """
    Compute the complex SO(3) Fourier transform of f, sampled on a (2B, 2B, 2B)
    grid with axes (alpha, beta, gamma).

    The transform is taken against normalized Wigner D functions
        \tilde{D} = 1/2pi sqrt((2J+1)/2) D
    (D in the basis of complex, seismology-normalized, centered spherical harmonics),
    i.e. this computes
        \hat{f}^l_mn = \int_SO(3) f(g) \tilde{D}^{l*}_mn(g) dg
    so that the FT of f = \tilde{D}^l_mn is 1 at (l, m, n) and zero elsewhere.

    :param f: array of shape (2B, 2B, 2B), where B is the bandwidth
    :return: list of length B; entry l is a (2l+1, 2l+1) array of projections of f
        onto the matrix elements of the l-th irreducible representation of SO(3)

    Main source: "SOFT: SO(3) Fourier Transforms", Kostelec & Rockmore.
    See also Maslen & Rockmore, "Generalized FFTs - a survey of some recent results",
    and Chirikjian & Kyatkin, "Engineering Applications of Noncommutative Harmonic
    Analysis", section 9.5.
    """
    size = f.shape[0]
    assert size == f.shape[1]
    assert f.shape[1] == f.shape[2]
    assert size % 2 == 0
    # FFT over the alpha (axis 0) and gamma (axis 2) angles, then center the spectrum.
    spectrum = fftshift(fft2(f, axes=(0, 2)), axes=(0, 2))
    # Finish with the Wigner-d transform along the beta axis.
    return wigner_d_transform_analysis(spectrum)
def SO3_FFT_synthesize(f_hat):
    """
    Perform the inverse (spectral to spatial) SO(3) Fourier transform.

    :param f_hat: a list of matrices of with shapes [1x1, 3x3, 5x5, ..., 2 L_max + 1 x 2 L_max + 1]
    :return: the synthesized function on a (2b, 2b, 2b) grid, b = len(f_hat)
    """
    b = len(f_hat)
    # Wigner-d synthesis along beta; the rest is a standard torus FFT over
    # alpha and gamma, undoing the centering applied during analysis.
    spectrum = fftshift(wigner_d_transform_synthesis(f_hat), axes=(0, 2))
    return ifft2(spectrum, axes=(0, 2)) * (2 * b) ** 2
def SO3_ifft(f_hat):
    """
    Naive inverse SO(3) Fourier transform.

    :param f_hat: a list of b coefficient matrices, entry l of shape (2l+1, 2l+1)
    :return: the synthesized function on a (2b, 2b, 2b) grid
    """
    b = len(f_hat)
    # BUG FIX: setup_d_transform has no default for its required L2_normalized
    # parameter, so the previous call setup_d_transform(b) raised a TypeError.
    # Use the un-normalized samples, consistent with wigner_d_transform_synthesis.
    d = setup_d_transform(b, L2_normalized=False)
    df_hat = [d[l] * f_hat[l][:, None, :] for l in range(len(d))]

    # Note: the frequencies where m=-B or n=-B are set to zero,
    # because they are not used in the forward transform either
    # (the forward transform is up to m=-l, l<B)
    F = np.zeros((2 * b, 2 * b, 2 * b), dtype=complex)
    for l in range(b):
        F[b - l:b + l + 1, :, b - l:b + l + 1] += df_hat[l]

    F = fftshift(F, axes=(0, 2))
    f = ifft2(F, axes=(0, 2))
    return f * 2 * (b ** 2) / np.pi
def wigner_d_transform_analysis(f):
    """
    The discrete Wigner-d transform [1] is defined as

    WdT(s)[l, m, n] = sum_k=0^{2b-1} w_b(k) d^l_mn(beta_k) s_k

    where:
    - w_b(k) is the k-th quadrature weight for an order b grid,
    - d^l_mn is a Wigner-d function,
    - beta_k = pi(2k + 1) / 4b
    - s is a data vector of length 2b

    In practice we want to transform many data vectors at once; we have an input array of shape (2b, 2b, 2b)

    [1] SOFT: SO(3) Fourier Transforms
    Peter J. Kostelec and Daniel N. Rockmore

    :param f: input array of shape (2b, 2b, 2b) with axes (M, beta, M'), i.e. the
        torus-FFT of the signal over alpha and gamma, fftshifted so frequency 0 is centered
    :return: list of length b; entry l is the (2l+1, 2l+1) coefficient array for degree l
    """
    assert f.shape[0] == f.shape[1]
    assert f.shape[1] == f.shape[2]
    assert f.shape[0] % 2 == 0

    b = f.shape[0] // 2   # The bandwidth
    f0 = f.shape[0] // 2  # The index of the 0-frequency / DC component
    wd = weighted_d(b)    # quadrature-weighted Wigner-d samples (cached)

    f_hat = []  # To store the result
    Z = 2 * np.pi / ((2 * b) ** 2)  # Normalizing constant
    # NOTE: the factor 1. / (2 (2b)^2) comes from the quadrature integration - see S3.integrate_quad
    # Maybe it makes more sense to integrate this factor into the quadrature weights.
    # The factor 4 pi is probably related to the normalization of the Haar measure on S^2

    # The array F we have computed so far still has shape (2b, 2b, 2b),
    # where the axes correspond to (M, beta, M').
    # For each l = 0, ..., b-1, select a subarray of shape (2l + 1, 2b, 2l + 1)
    f_sub = [f[f0 - l:f0 + l + 1, :, f0 - l:f0 + l + 1] for l in range(b)]

    for l in range(b):
        # Dot the vectors F_mn and d_mn over the middle axis (beta),
        # where -l <= m,n <= l, which corresponds to
        # f0 - l <= m,n < f0 + l + 1
        # for 0-based indexing and zero-frequency location f0
        f_hat.append(np.einsum('mbn,mbn->mn', wd[l], f_sub[l]) * Z)

    return f_hat
def wigner_d_transform_analysis_vectorized(f, wd_flat, idxs):
    """
    Vectorized Wigner-d transform analysis; returns the flattened blocks of
    f_hat as a single vector.

    :param f: the input signal, shape (2b, 2b, 2b), axes (m, beta, n)
    :param wd_flat: flattened weighted Wigner-d functions, shape (num_spectral, 2b),
        axes (l*m*n, beta)
    :param idxs: index array selecting all analysis blocks from the flattened signal
    :return: vector of length num_spectral, axis l*m*n
    """
    n_beta = f.shape[1]
    # Bring beta last, flatten the (m, n) axes, then gather the spectral rows.
    rows = f.transpose([0, 2, 1]).reshape(-1, n_beta)[idxs]  # (num_spectral, 2b)
    # Weighted reduction over beta.
    return (rows * wd_flat).sum(axis=1)
def wigner_d_transform_analysis_vectorized_v2(f, wd_flat_t, idxs):
    """
    Variant of the vectorized Wigner-d analysis with beta as the leading axis.

    :param f: the SO(3) signal, shape (2b, 2b, 2b), axes (beta, m, n)
    :param wd_flat_t: flattened weighted Wigner-d functions, shape (2b, num_spectral),
        axes (beta, l*m*n)
    :param idxs: index array selecting the spectral columns
    :return: vector of length num_spectral, axis l*m*n
    """
    # Flatten (m, n), gather the spectral columns, and reduce over beta.
    cols = f.reshape(f.shape[0], -1)[..., idxs]  # (2b, num_spectral)
    return (cols * wd_flat_t).sum(axis=0)
def wigner_d_transform_synthesis(f_hat):
    """
    Inverse Wigner-d transform: expand the coefficient matrices f_hat into a
    dense (2b, 2b, 2b) array with axes (m, beta, n).

    :param f_hat: list of b coefficient matrices, entry l of shape (2l+1, 2l+1)
    :return: complex array of shape (2b, 2b, 2b)
    """
    b = len(f_hat)
    d = setup_d_transform(b, L2_normalized=False)

    # Brute-force Wigner-d synthesis: each degree-l block is zero-padded into
    # the center of the output. Frequencies m = -B or n = -B stay zero because
    # the forward transform never produces them (it only covers -l <= m, n <= l, l < B).
    out = np.zeros((2 * b, 2 * b, 2 * b), dtype=complex)
    for l, coeffs in enumerate(f_hat):
        ctr = slice(b - l, b + l + 1)
        out[ctr, :, ctr] += d[l] * coeffs[:, None, :]
    return out
def wigner_d_transform_synthesis_vectorized(f_hat_flat, b):
    """
    Vectorized inverse Wigner-d transform from a flattened coefficient vector.

    :param f_hat_flat: flattened spectral coefficients, ordered by (l, m, n)
    :param b: bandwidth
    :return: array of shape (2b, 2b, 2b), axes (m, beta, n)
    """
    # Scatter the flat coefficients into the zero-padded (l, m, n) layout,
    # then contract against the dense Wigner-d tensor.
    padded = f_hat_flat[zero_padding_inds(b)].reshape(b, 2 * b, 1, 2 * b)
    return np.einsum('lmbn,lmbn->mbn', padded, vectorized_d(b))
def vectorize_d(d):
    """
    Pack the list of Wigner-d sample arrays into one dense, zero-padded tensor,
    as required by the vectorized Wigner-d synthesis transform.

    :param d: list of length b; d[l] has shape (2l+1, 2b, 2l+1), axes (m, beta, n)
    :return: array of shape (b, 2b, 2b, 2b), axes (l, m, beta, n), with each d[l]
        placed in the center of its (2b, 2b, 2b) slice and zeros elsewhere
    """
    b = len(d)
    dense = np.zeros((b, 2 * b, 2 * b, 2 * b))
    for l, dl in enumerate(d):
        lo, hi = b - l, b + l + 1
        dense[l, lo:hi, :, lo:hi] = dl
    return dense
def weigh_wigner_d(d, w):
    """
    Pre-multiply each Wigner-d sample array by its beta quadrature weight.

    The Wigner-d transform sums products of data, d-function samples, and
    quadrature weights; since the latter two are data-independent, their
    product is precomputed here. Each weight corresponds to one beta value,
    i.e. axis 1 of every d[l], and is broadcast over the m and n axes.

    :param d: a list of samples of the Wigner-d function, as returned by setup_d_transform
    :param w: an array of quadrature weights, as returned by S3.quadrature_weights
    :return: the weighted d function samples, with the same shape as d
    """
    w_col = w[None, :, None]
    return [dl * w_col for dl in d]
@lru_cache(maxsize=32)
def vectorized_d(b):
    """Return the zero-padded dense Wigner-d tensor for bandwidth b (cached)."""
    return vectorize_d(setup_d_transform(b, L2_normalized=False))
@lru_cache(maxsize=32)
def zero_padding_inds(b):
    """
    To vectorize the Wigner-d transform, we have to take a list of matrices f_hat = [f_hat^0, ..., f_hat^L],
    where f_hat^l has shape (2l+1, 2l+1), and flatten it into a vector.
    Then we turn it into a single array F of shape (b, 2b, 2b) with axes l, m, n.
    The (2b, 2b) matrix F[l] has non-zeros in the (2l+1, 2l+1) center.
    To implement the latter operation, we need indices. These are computed by this function.

    :param b: bandwidth
    :return: index array of length b * (2b)^2
    """
    # BUG FIX: np.int is a removed alias (NumPy 1.24+); the builtin int gives
    # the equivalent platform-default integer dtype.
    inds = np.zeros(b * 2 * b * 2 * b, dtype=int)
    for l in range(b):
        for m in range(-l, l + 1):
            for n in range(-l, l + 1):
                inds[flat_ind_zp_so3(l, m, n, b)] = flat_ind_so3(l, m, n)
    return inds
@lru_cache(maxsize=32)
def setup_d_transform(b, L2_normalized,
                      field='complex', normalization='quantum',
                      order='centered', condon_shortley='cs'):
    """
    Precompute arrays of samples from the Wigner-d function, for use in the Wigner-d transform.

    Specifically, the samples that are required are:
    d^l_mn(beta_k)
    for:
    l = 0, ..., b - 1
    -l <= m, n <= l
    k = 0, ..., 2b - 1
    (where beta_k = pi (2 k + 1) / 4b)

    This data is returned as a list d indexed by l (of length b),
    where each element of the list is an array d[l] of shape (2l+1, 2b, 2l+1) indexed by (m, k, n)

    In the Wigner-d transform, for each l, we reduce an array d[l] of shape (2l+1, 2b, 2l+1)
    against a data array of the same shape, along the beta axis (axis 1 of length 2b).

    :param b: bandwidth of the transform
    :param L2_normalized: whether to use L2_normalized versions of the Wigner-d functions.
    :param field, normalization, order, condon_shortley: the basis and normalization convention (see irrep_bases.py)
    :return: a list d of length b, where d[l] is an array of shape (2l+1, 2b, 2l+1)
    """
    # Compute array of beta values as described in SOFT 2.0 documentation
    beta = np.pi * (2 * np.arange(0, 2 * b) + 1) / (4. * b)

    # For each l=0, ..., b-1, we compute a 3D tensor of shape (2l+1, 2b, 2l+1) for axes (m, beta, n)
    # Together, these indices (l, m, beta, n) identify d^l_mn(beta)
    convention = {'field': field, 'normalization': normalization,
                  'order': order, 'condon_shortley': condon_shortley}
    d = [np.array([wigner_d_matrix(l, bt, **convention) for bt in beta]).transpose(1, 0, 2)
         for l in range(b)]

    if L2_normalized:
        # TODO: this should be integrated in the normalization spec above, no?
        # The Unitary matrix elements have norm:
        # | U^\lambda_mn |^2 = 1/(2l+1)
        # where the 2-norm is defined in terms of normalized Haar measure.
        # So T = sqrt(2l + 1) U are L2-normalized functions
        d = [d[l] * np.sqrt(2 * l + 1) for l in range(len(d))]

        # We want the L2 normalized functions:
        # d = [d[l] * np.sqrt(l + 0.5) for l in range(len(d))]

    return d
@lru_cache(maxsize=32)
def weighted_d(b):
    """Return quadrature-weighted Wigner-d samples for bandwidth b on the SOFT grid (cached)."""
    return weigh_wigner_d(setup_d_transform(b, L2_normalized=False),
                          S3.quadrature_weights(b, grid_type='SOFT'))
def get_wigner_analysis_sub_block_indices(b, l):
    """
    Compute the indices of the order-l sub-block used in the Wigner analysis.

    The indices address a flattened (2b, 2b) array and select the centered
    (2l+1, 2l+1) block: entry (i, j) is (i + b - l) * 2b + (j + b - l).

    :param b: bandwidth
    :param l: degree of the sub-block
    :return: int index array of shape (2l+1, 2l+1)
    """
    span = np.arange(2 * l + 1) + (b - l)  # row/column positions of the centered block
    return span[:, None] * (2 * b) + span[None, :]
def get_wigner_analysis_block_indices(b):
    """
    Compute the flattened vector of all sub-block indices up to order b,
    used in the Wigner analysis.

    :param b: bandwidth
    :return: concatenated 1D index array over l = 0, ..., b-1
    """
    blocks = []
    for l in range(b):
        blocks.append(get_wigner_analysis_sub_block_indices(b, l).ravel())
    return np.concatenate(blocks)
def get_wigner_analysis_indices(b):
    """
    For each flattened spectral index (l, m, n), compute the position of the
    corresponding (m, n) Fourier coefficient in a flattened (2b, 2b) array,
    using numpy's native FFT frequency order (negative frequencies wrap around).

    :param b: bandwidth
    :return: int index array of length sum_{l<b} (2l+1)^2
    """
    # NOTE: an unused inner helper (mn_ind_fftshift) that computed the same
    # mapping for fftshifted arrays was dead code and has been removed.
    def mn_ind(m, n):
        # Map (possibly negative) frequencies to 0-based positions by wrapping.
        m_zero_based = m % (2 * b)
        n_zero_based = n % (2 * b)
        array_height = 2 * b
        return m_zero_based * array_height + n_zero_based

    num_spectral_coefficients = np.sum([(2 * l + 1) ** 2 for l in range(b)])
    inds = np.empty(num_spectral_coefficients, dtype=int)
    for l in range(b):
        for m in range(-l, l + 1):
            for n in range(-l, l + 1):
                inds[flat_ind_so3(l, m, n)] = mn_ind(m, n)
    return inds
def get_flattened_weighted_ds(wd):
    """
    Flatten the weighted Wigner-d matrices into one array.

    Each wd[l] of shape (2l+1, 2b, 2l+1) is rearranged so beta is the last
    axis and the (m, n) axes are merged, then all degrees are stacked.

    :param wd: list of weighted Wigner-d arrays, as returned by weigh_wigner_d
    :return: array of shape (sum_l (2l+1)^2, 2b), axes (l*m*n, beta)
    """
    flat = []
    for m in wd:
        n_beta = m.shape[1]
        flat.append(np.swapaxes(m, 1, 2).reshape(-1, n_beta))
    return np.concatenate(flat)
# TODO update these
def SO3_convolve(f, g, dw=None, d=None):
    """
    Convolve two functions f, g : SO(3) -> C sampled on the same grid.

    :param f, g: sampled functions with f.shape == g.shape
    :param dw: precomputed weighted Wigner-d samples, forwarded to SO3_fft
    :param d: precomputed Wigner-d samples for the inverse transform
    :return: the sampled convolution f * g
    """
    assert f.shape == g.shape
    assert f.shape[0] % 2 == 0
    # BUG FIX: integer division; under Python 3, f.shape[0] / 2 is a float and
    # breaks the range()-based loops inside setup_d_transform.
    b = f.shape[0] // 2
    if d is None:
        # BUG FIX: setup_d_transform requires L2_normalized (it has no default).
        d = setup_d_transform(b, L2_normalized=False)
    # NOTE(review): SO3_fft is not defined in this module, so this legacy
    # helper raises NameError when called — confirm the intended entry point.
    # To convolve, first perform a Fourier transform on f and g:
    F = SO3_fft(f, dw)
    G = SO3_fft(g, dw)
    # The Fourier transform of the convolution f*g is the matrix product FG
    # of their Fourier transforms F and G:
    FG = [np.dot(a, mat) for (a, mat) in zip(F, G)]
    # The convolution is obtained by inverse Fourier transforming FG:
    return SO3_ifft(FG, d)
def SO3_convolve_complex_fft(f, g):
    """
    Convolve two functions on SO(3) using the semi-naive complex SO(3) FFT.

    :param f, g: sampled functions on a (2b, 2b, 2b) SO(3) grid, f.shape == g.shape
    :return: the sampled convolution f * g
    """
    assert f.shape == g.shape
    assert f.shape[0] % 2 == 0
    # BUG FIX: integer division; under Python 3, f.shape[0] / 2 yields a float
    # L_max, which breaks the range()-based loops in the FFT setup.
    b = f.shape[0] // 2
    fft = SO3_FFT_SemiNaive_Complex(L_max=b - 1, d=None, w=None, L2_normalized=False)
    f_hat = fft.analyze(f)
    g_hat = fft.analyze(g)
    # Convolution theorem on SO(3): the FT of f * g is the per-degree matrix
    # product of the FTs of f and g.
    fg_hat = [np.dot(a, m) for (a, m) in zip(f_hat, g_hat)]
    return fft.synthesize(fg_hat)
lie_learn/lie_learn/spectral/SO3_conv.py
0 → 100755
View file @
b5881ee2
import
numpy
as
np
from
lie_learn.spectral.SO3FFT_Naive
import
SO3_FT_Naive
def conv_test():
    """
    Compute the convolution of two functions on SO(3).
    Let f1 : SO(3) -> R and f2 : SO(3) -> R, then the convolution is defined as
    f1 * f2(g) = int_{SO(3)} f1(h) f2(g^{-1} h) dh,
    where g in SO(3) and dh is the normalized Haar measure on SO(3).
    The convolution is computed by a Fourier transform.
    It can be shown that the SO(3) Fourier transform of the convolution f1 * f2 is equal to the matrix product
    of the SO(3) Fourier transforms of f1 and f2.
    For more details, see the note on "Convolution on S^2 and SO(3)"

    :return: the synthesized convolution f12 on the SO(3) grid
    """
    from lie_learn.spectral.SO3FFT_Naive import SO3_FT_Naive
    b = 10
    f1 = np.ones((2 * b + 2, b + 1))
    #TODO
    f2 = np.ones((2 * b + 2, b + 1))
    # NOTE(review): S2_FT_Naive is not imported anywhere in this module, so the
    # next line raises NameError — presumably it should come from the S2 FFT
    # module of lie_learn.spectral; confirm and add the import.
    s2_fft = S2_FT_Naive(L_max=b - 1, grid_type='Gauss-Legendre',
                         field='real', normalization='quantum', condon_shortley='cs')
    so3_fft = SO3_FT_Naive(L_max=b - 1,
                           field='real', normalization='quantum', order='centered',
                           condon_shortley='cs')
    # Spherical Fourier transform
    f1_hat = s2_fft.analyze(f1)
    f2_hat = s2_fft.analyze(f2)
    # Perform block-wise outer product: for each degree l, the degree-l
    # coefficients occupy the flat slice [l^2, l^2 + 2l + 1).
    f12_hat = []
    for l in range(b):
        f1_hat_l = f1_hat[l ** 2:l ** 2 + 2 * l + 1]
        f2_hat_l = f2_hat[l ** 2:l ** 2 + 2 * l + 1]
        f12_hat_l = f1_hat_l[:, None] * f2_hat_l[None, :].conj()
        f12_hat.append(f12_hat_l)
    # Inverse SO(3) Fourier transform
    f12 = so3_fft.synthesize(f12_hat)
    return f12
def SO3_convolve(f, g, dw=None, d=None):
    """
    Convolve two functions f, g : SO(3) -> C sampled on the same grid.

    :param f, g: sampled functions with f.shape == g.shape
    :param dw: precomputed weighted Wigner-d samples, forwarded to SO3_fft
    :param d: precomputed Wigner-d samples for the inverse transform
    :return: the sampled convolution f * g
    """
    assert f.shape == g.shape
    assert f.shape[0] % 2 == 0
    # BUG FIX: integer division; under Python 3, f.shape[0] / 2 is a float and
    # breaks the range()-based loops inside setup_d_transform.
    b = f.shape[0] // 2
    if d is None:
        # BUG FIX: setup_d_transform requires L2_normalized (it has no default).
        d = setup_d_transform(b, L2_normalized=False)
    # NOTE(review): setup_d_transform, SO3_fft and SO3_ifft are not imported in
    # this module, so this legacy helper raises NameError when called — confirm
    # the intended entry points (see lie_learn.spectral.SO3FFT_Naive).
    # To convolve, first perform a Fourier transform on f and g:
    F = SO3_fft(f, dw)
    G = SO3_fft(g, dw)
    # The Fourier transform of the convolution f*g is the matrix product FG
    # of their Fourier transforms F and G:
    FG = [np.dot(a, mat) for (a, mat) in zip(F, G)]
    # The convolution is obtained by inverse Fourier transforming FG:
    return SO3_ifft(FG, d)
\ No newline at end of file
lie_learn/lie_learn/spectral/T1FFT.py
0 → 100755
View file @
b5881ee2
import
numpy
as
np
from
numpy.fft
import
fft
,
ifft
,
fftshift
,
ifftshift
from
.FFTBase
import
FFTBase
class T1FFT(FFTBase):
    """
    The Fast Fourier Transform on the Circle / 1-Torus / 1-Sphere.
    """

    @staticmethod
    def analyze(f, axis=0):
        """
        Compute the Fourier Transform of the discretely sampled function f : T^1 -> C.

        Let f : T^1 -> C be a band-limited function on the circle.
        The samples f(theta_k) correspond to points on a regular grid on the circle,
        as returned by spaces.T1.linspace:
        theta_k = 2 pi k / N
        for k = 0, ..., N - 1

        This function computes
        \\hat{f}_n = (1/N) \\sum_{k=0}^{N-1} f(theta_k) e^{-i n theta_k}
        which, if f has band-limit less than N, is equal to:
        \\hat{f}_n = \\int_0^{2pi} f(theta) e^{-i n theta} dtheta / 2pi,
                   = <f(theta), e^{i n theta}>
        where dtheta / 2pi is the normalized Haar measure on T^1, and < , > denotes the
        inner product on Hilbert space, with respect to which this transform is unitary.

        The range of frequencies n is -floor(N/2) <= n <= ceil(N/2) - 1

        :param f: sampled function values; the transformed axis must have length N
        :param axis: the axis along which to transform
        :return: Fourier coefficients ordered from n = -floor(N/2) to ceil(N/2) - 1
        """
        # numpy's FFT returns coefficients in a different order (0, positive,
        # then negative frequencies) and without the 1/N normalization;
        # fftshift and the division below fix both.
        fhat = fft(f, axis=axis)
        fhat = fftshift(fhat, axes=axis)
        return fhat / f.shape[axis]

    @staticmethod
    def synthesize(f_hat, axis=0):
        """
        Compute the inverse / synthesis Fourier transform of the function f_hat : Z -> C.
        The function f_hat(n) is sampled at points in a limited range -floor(N/2) <= n <= ceil(N/2) - 1

        This function returns
        f[k] = f(theta_k) = sum_{n=-floor(N/2)}^{ceil(N/2)-1} f_hat(n) exp(i n theta_k)
        where theta_k = 2 pi k / N
        for k = 0, ..., N - 1

        :param f_hat: Fourier coefficients ordered from n = -floor(N/2) to ceil(N/2) - 1
        :param axis: the axis along which to transform
        :return: the synthesized samples f(theta_k)
        """
        # Undo both the centering and the 1/N normalization applied by analyze().
        f_hat = ifftshift(f_hat * f_hat.shape[axis], axes=axis)
        f = ifft(f_hat, axis=axis)
        return f

    @staticmethod
    def analyze_naive(f):
        """
        Reference O(N^2) implementation of analyze() for 1D input.

        :param f: 1D array of N samples (real or complex)
        :return: complex Fourier coefficients ordered from n = -floor(N/2) to ceil(N/2) - 1
        """
        # BUG FIX: the accumulator must be complex. np.zeros_like(f) inherits
        # f's dtype, so for real-valued input the complex in-place addition
        # below raised a casting error.
        f_hat = np.zeros(f.shape, dtype=complex)
        for n in range(f.size):
            for k in range(f.size):
                theta_k = k * 2 * np.pi / f.size
                f_hat[n] += f[k] * np.exp(-1j * n * theta_k)
        return fftshift(f_hat / f.size, axes=0)
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment