# basis_utils.py - sympy/numpy helpers for spherical Bessel and real spherical-harmonic bases.
import numpy as np
import sympy as sym
from scipy import special as sp
from scipy.optimize import brentq

def Jn(r, n):
    """
    Evaluate the spherical Bessel function of the first kind j_n(r) numerically.

    Uses the identity j_n(r) = sqrt(pi / (2 r)) * J_{n + 1/2}(r), where J is the
    ordinary Bessel function of the first kind (scipy.special.jv).

    r: scalar or array-like (list / np.ndarray) of evaluation points
    n: scalar or array-like of orders, broadcastable against r
    return: j_n(r) with the broadcast shape of r and n (a numpy scalar when both
        inputs are scalars)

    r == 0 is not special-cased: the sqrt(pi / (2 r)) prefactor divides by zero
    there, so the result is typically nan (with a RuntimeWarning) rather than
    the analytic limit.
    ===
    example:
        r = n = [1, 2, 3, 4]
        Jn(r, n) ~= [0.3012, 0.1984, 0.1521, 0.1249]
    ===
    """
    # np.asarray makes plain Python lists work as documented: without it
    # `2 * r` would repeat the list and `n + 0.5` would raise TypeError.
    # No dtype is forced, so scalar / array inputs keep their original precision.
    r = np.asarray(r)
    n = np.asarray(n)
    return np.sqrt(np.pi / (2 * r)) * sp.jv(n + 0.5, r)

def Jn_zeros(n, k):
    """
    Compute the first k positive zeros of the spherical Bessel functions j_0 ... j_{n-1}.

    n: int, number of orders (order n is excluded)
    k: int, number of zeros per order
    res: float32 np.ndarray of shape [n, k]; res[l, i] is the (i+1)-th positive
        zero of j_l

    The zeros of j_0(x) = sin(x) / x are exactly pi, 2 pi, 3 pi, ...
    Higher orders are found by bracketing: each pair of consecutive zeros of
    j_{l-1} is used as a brentq bracket for one zero of j_l. Every order yields
    one root fewer than the previous one. The computation therefore starts with
    k + n - 1 multiples of pi, so that k roots remain at order n - 1.
    """
    zerosj = np.zeros((n, k), dtype="float32")
    if n == 0:
        # Nothing to compute; indexing row 0 below would raise IndexError.
        return zerosj

    zerosj[0] = np.arange(1, k + 1) * np.pi

    # Brackets are kept in float64 between orders; only the returned table is float32.
    brackets = np.arange(1, k + n) * np.pi
    for order in range(1, n):
        brackets = np.array(
            [
                brentq(Jn, brackets[j], brackets[j + 1], (order,))
                for j in range(len(brackets) - 1)
            ]
        )
        zerosj[order] = brackets[:k]

    return zerosj

def spherical_bessel_formulas(n):
    """
    Build sympy expressions for the spherical Bessel functions j_0(x) ... j_{n-1}(x).

    n: int, number of orders (order n is excluded)
    res: list of n sympy expressions in the symbol x (empty when n <= 0)

    Uses Rayleigh's formula j_l(x) = (-x)^l * (1/x d/dx)^l (sin(x) / x).
    The operator (1/x d/dx) is applied repeatedly to sin(x) / x, and the
    l-th iterate is multiplied by (-x)^l.
    """
    x = sym.symbols("x")
    if n <= 0:
        # "Up to order n (excluded)" means no formulas at all here.
        return []

    # term holds (1/x d/dx)^order applied to sin(x) / x.
    term = sym.sin(x) / x
    formulas = [term]
    for order in range(1, n):
        derivative = sym.diff(term, x) / x
        formulas.append(sym.simplify(derivative * (-x) ** order))
        term = sym.simplify(derivative)
    return formulas

def bessel_basis(n, k):
    """
    Build sympy expressions for the normalized, rescaled spherical Bessel basis.

    n: int, number of orders l = 0 ... n-1 (order n is excluded)
    k: int, number of roots per order i = 0 ... k-1 (k is excluded)
    res: nested list [n][k] of sympy expressions in the symbol x

    Entry [l][i] is N_{l,i} * j_l(z_{l,i} * x). Here z_{l,i} is the (i+1)-th
    positive zero of j_l, as returned by Jn_zeros, and
    N_{l,i} = 1 / sqrt(0.5 * j_{l+1}(z_{l,i})^2).
    By the standard identity int_0^1 j_l(z x)^2 x^2 dx = j_{l+1}(z)^2 / 2,
    valid whenever j_l(z) = 0, this makes each basis function unit-norm on
    [0, 1] with weight x^2.
    """
    zeros = Jn_zeros(n, k)

    # One array of k normalization constants per order. Each squared value is
    # computed per root before converting to an array, as in the original.
    normalizer = []
    for order in range(n):
        half_sq = [0.5 * Jn(zeros[order, i], order + 1) ** 2 for i in range(k)]
        normalizer.append(1 / np.array(half_sq) ** 0.5)

    f = spherical_bessel_formulas(n)
    x = sym.symbols("x")

    # Rescale the argument so the i-th root of j_l lands at x = 1.
    return [
        [
            sym.simplify(normalizer[order][i] * f[order].subs(x, zeros[order, i] * x))
            for i in range(k)
        ]
        for order in range(n)
    ]

def sph_harm_prefactor(l, m):
    """
    Compute the constant normalization pre-factor of the spherical harmonic Y_l^m.

    l: int, degree, l >= 0
    m: int, order, -l <= m <= l (only |m| is used)
    res: float, sqrt((2 l + 1) * (l - |m|)! / (4 pi * (l + |m|)!))

    Raises ValueError (from math.factorial) when |m| > l.
    """
    # math.factorial replaces np.math.factorial: the np.math alias was
    # deprecated in NumPy 1.25 and removed in NumPy 2.0.
    import math

    abs_m = abs(m)
    return (
        (2 * l + 1)
        * math.factorial(l - abs_m)
        / (4 * np.pi * math.factorial(l + abs_m))
    ) ** 0.5

def associated_legendre_polynomials(l, zero_m_only=True):
    """
    Build sympy expressions for associated Legendre polynomials of degree < l.

    l: int, number of degrees (degree l is excluded)
    zero_m_only: if True, only the m = 0 column (the ordinary Legendre
        polynomials) is filled; otherwise every order 0 <= m <= degree is built
    return: list of l rows, or [] when l <= 0. Row j is a list of length j + 1,
        and entry [j][m] is the polynomial in the symbol z for degree j and
        order m. Entries that are not computed stay the Python int 0, and
        [0][0] is the Python int 1.

    For m > 0 the factor (1 - z^2)^(m/2) is NOT included. P_m^m is built as
    (1 - 2m) * P_{m-1}^{m-1}, i.e. (-1)^m (2m - 1)!!, and the recursions below
    stay polynomial in z.
    """
    z = sym.symbols("z")
    if l <= 0:
        return []

    P_l_m = [[0] * (j + 1) for j in range(l)]
    P_l_m[0][0] = 1

    # Row 1 only exists when l > 1; the previous `l > 0` guard raised
    # IndexError for l == 1.
    if l > 1:
        P_l_m[1][0] = z
        # Bonnet's recursion: j P_j = (2j - 1) z P_{j-1} - (j - 1) P_{j-2}
        for j in range(2, l):
            P_l_m[j][0] = sym.simplify(
                ((2 * j - 1) * z * P_l_m[j - 1][0] - (j - 1) * P_l_m[j - 2][0])
                / j
            )

    if not zero_m_only:
        for i in range(1, l):
            # Diagonal: P_i^i = (1 - 2i) P_{i-1}^{i-1}
            P_l_m[i][i] = sym.simplify((1 - 2 * i) * P_l_m[i - 1][i - 1])
            if i + 1 < l:
                # First off-diagonal: P_{i+1}^i = (2i + 1) z P_i^i
                P_l_m[i + 1][i] = sym.simplify((2 * i + 1) * z * P_l_m[i][i])
            # Upward recursion in degree at fixed order i:
            # (j - i) P_j^i = (2j - 1) z P_{j-1}^i - (i + j - 1) P_{j-2}^i
            for j in range(i + 2, l):
                P_l_m[j][i] = sym.simplify(
                    (
                        (2 * j - 1) * z * P_l_m[j - 1][i]
                        - (i + j - 1) * P_l_m[j - 2][i]
                    )
                    / (j - i)
                )

    return P_l_m

def real_sph_harm(l, zero_m_only=True, spherical_coordinates=True):
    """
    Build sympy expressions for the real spherical harmonics of degree < l.

    l: int, number of degrees (degree l is excluded)
    zero_m_only: if True, only the m = 0 harmonic of each degree is computed
    spherical_coordinates: if True, express the harmonics in the angle theta
        (and also phi when zero_m_only is False); otherwise in the cartesian
        coordinates x, y, z of a point on the unit sphere
    return: list of length l. Entry i is a list of length 2 * i + 1 holding
        m = 0 at index 0, m = j at index j and m = -j at index -j
        (1 <= j <= i). Slots that are not computed keep the placeholder
        string "0".
    """
    x, y, z = sym.symbols("x y z")

    if not zero_m_only:
        # C_m[m] / S_m[m] are the real / imaginary parts of (x + i y)^m, built
        # from (x + i y)^m = (x + i y) * (x + i y)^(m - 1).
        # They are seeded with sympy Integers rather than Python ints so that
        # .subs() below works on every entry. Plain ints made spherical mode
        # crash with AttributeError.
        S_m = [sym.Integer(0)]
        C_m = [sym.Integer(1)]
        for i in range(1, l):
            S_m.append(x * S_m[i - 1] + y * C_m[i - 1])
            C_m.append(x * C_m[i - 1] - y * S_m[i - 1])

    P_l_m = associated_legendre_polynomials(l, zero_m_only)

    if spherical_coordinates:
        theta = sym.symbols("theta")
        for row in P_l_m:
            for j, poly in enumerate(row):
                # Placeholder / constant entries are plain Python ints; only
                # sympy expressions need the substitution z -> cos(theta).
                if isinstance(poly, sym.Basic):
                    row[j] = poly.subs(z, sym.cos(theta))

        if not zero_m_only:
            phi = sym.symbols("phi")
            sin_theta = sym.sin(theta)

            def _to_angles(expr):
                # Point on the unit sphere: x = sin(theta) cos(phi), y = sin(theta) sin(phi).
                return expr.subs(x, sin_theta * sym.cos(phi)).subs(
                    y, sin_theta * sym.sin(phi)
                )

            S_m = [_to_angles(expr) for expr in S_m]
            C_m = [_to_angles(expr) for expr in C_m]

    Y_func_l_m = [["0"] * (2 * j + 1) for j in range(l)]

    for i in range(l):
        Y_func_l_m[i][0] = sym.simplify(sph_harm_prefactor(i, 0) * P_l_m[i][0])

    if not zero_m_only:
        for i in range(1, l):
            for j in range(1, i + 1):
                # m = +j pairs with the cosine-like part, m = -j with the sine-like part.
                Y_func_l_m[i][j] = sym.simplify(
                    2**0.5 * sph_harm_prefactor(i, j) * C_m[j] * P_l_m[i][j]
                )
                Y_func_l_m[i][-j] = sym.simplify(
                    2**0.5 * sph_harm_prefactor(i, -j) * S_m[j] * P_l_m[i][j]
                )

    return Y_func_l_m