_sph_harm.py 1.79 KB
Newer Older
root's avatar
root committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
"""
The source code here is an adaptation with minimal changes from the following
SciPy Cython file:

https://github.com/scipy/scipy/blob/master/scipy/special/sph_harm.pxd
"""

from cupy import _core

from cupyx.scipy.special._poch import poch_definition
from cupyx.scipy.special._lpmv import lpmv_definition

# CUDA preamble for the ``sph_harm`` ufunc. The device function below calls
# ``poch`` and ``pmv_wrap``; these are presumably supplied by
# ``poch_definition`` and ``lpmv_definition`` (concatenated first so they are
# declared before use) -- confirm against those modules.
sph_harmonic_definition = (
    poch_definition
    + lpmv_definition
    + """

#include <cupy/complex.cuh>

// include for CUDART_NAN, CUDART_INF
#include <cupy/math_constants.h>

#define NPY_PI        3.141592653589793238462643383279502884  /* pi */

// from scipy/special/sph_harm.pxd
//
// Spherical harmonic of integer order m and degree n. phi enters only via
// cos(phi), the argument of the associated Legendre function, and theta only
// via the phase factor exp(i*m*theta). Returns NaN when n < 0 or |m| > n.
__device__ complex<double> sph_harmonic(int m, int n, double theta, double phi)
{
    double x;
    // Only applied when m < 0; initialized so it is never read uninitialized.
    double prefactor = 1.0;
    complex<double> val;
    int mp;

    x = cos(phi);
    // Test n < 0 first: for negative n, abs(m) > n always holds, so checking
    // abs(m) > n first would make this branch unreachable.
    if (n < 0)
    {
        // sf_error.error("sph_harm", sf_error.ARG, "n should not be negative")
        return CUDART_NAN;
    }
    if (abs(m) > n)
    {
        // sf_error.error("sph_harm", sf_error.ARG,
        //                "m should not be greater than n")
        return CUDART_NAN;
    }
    if (m < 0)
    {
        // Negative order is evaluated from the Legendre function of order
        // |m|, scaled by (-1)^|m| * poch(n + |m| + 1, -2|m|).
        mp = -m;
        prefactor = poch(n + mp + 1, -2 * mp);
        if ((mp % 2) == 1)
        {
            prefactor = -prefactor;
        }
    }
    else
    {
        mp = m;
    }
    val = pmv_wrap(mp, n, x);
    if (m < 0)
    {
        val *= prefactor;
    }
    // Normalization sqrt((2n + 1) / (4 pi)) * sqrt(poch(n + m + 1, -2m)).
    val *= sqrt((2*n + 1) / 4.0 / NPY_PI);
    val *= sqrt(poch(n + m + 1, -2 * m));

    // Azimuthal phase exp(i * m * theta).
    complex<double> exponent(0, m * theta);
    val *= exp(exponent);

    return val;
}
"""
)


# Elementwise spherical harmonic mirroring ``scipy.special.sph_harm``.
# float32 angles yield complex64 and float64 angles complex128 output; the
# device function itself always evaluates in double precision and the result
# is cast to the output type.
sph_harm = _core.create_ufunc(
    # The ufunc's own name (was mistakenly the lpmv ufunc's name).
    "cupyx_scipy_sph_harm",
    ("iiff->F", "iidd->D", "llff->F", "lldd->D"),
    "out0 = out0_type(sph_harmonic(in0, in1, in2, in3));",
    preamble=sph_harmonic_definition,
    doc="""Spherical Harmonic.

    Computes the spherical harmonic ``Y_n^m(theta, phi)`` elementwise.

    Args:
        m (cupy.ndarray): Order of the harmonic (integer array).
        n (cupy.ndarray): Degree of the harmonic (integer array).
        theta (cupy.ndarray): Azimuthal coordinate; enters through the
            phase factor ``exp(1j * m * theta)``.
        phi (cupy.ndarray): Polar coordinate; enters through ``cos(phi)``
            as the argument of the associated Legendre function.

    Returns:
        cupy.ndarray: Complex values of the spherical harmonic. Elements
        where ``n < 0`` or ``|m| > n`` are NaN.

    .. seealso:: :meth:`scipy.special.sph_harm`

    """,
)