_zeta.py 2.85 KB
Newer Older
root's avatar
root committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# This source code contains SciPy's code.
# https://github.com/scipy/scipy/blob/master/scipy/special/cephes/zeta.c
#
#
# Cephes Math Library Release 2.0:  April, 1987
# Copyright 1984, 1987 by Stephen L. Moshier
# Direct inquiries to 30 Frost Street, Cambridge, MA 02140

# TODO(YoshikawaMasashi): float implementation of zeta function

from cupy import _core


# CUDA source for the Hurwitz zeta function, kept as a string so it can be
# supplied as the ``preamble`` of the ufunc created at the bottom of this
# module.  The snippet defines:
#
#   * ``A``      -- the 12 Euler-Maclaurin expansion coefficients (2k)!/B2k,
#   * ``MACHEP`` -- the relative tolerance used to stop the summations early,
#   * ``zeta(double x, double q)`` -- the ``__device__`` routine itself.
#
# Control flow of ``zeta(x, q)``, in the order the checks are made:
#
#   1. ``x == 1``                      -> +inf (returned as ``1.0/0.0``).
#   2. ``x < 1``                       -> NaN.
#   3. ``q <= 0`` and ``q`` integral   -> +inf.
#   4. ``q <= 0``, ``x`` not integral  -> NaN (``q**-x`` is undefined).
#   5. ``q > 1e8``                     -> closed-form asymptotic expansion.
#   6. otherwise: add ``(q + n)**-x`` term by term for at least nine steps
#      and until ``q + n`` exceeds 9, then apply up to 12 Euler-Maclaurin
#      correction terms.  Either stage returns early as soon as the latest
#      term is below ``MACHEP`` relative to the running sum.
#
# Only a double-precision routine is provided (see the TODO above).
zeta_definition = '''
/* Expansion coefficients
 * for Euler-Maclaurin summation formula
 * (2k)! / B2k
 * where B2k are Bernoulli numbers
 */
__constant__ double A[] = {
    12.0,
    -720.0,
    30240.0,
    -1209600.0,
    47900160.0,
    -1.8924375803183791606e9,	/*1.307674368e12/691 */
    7.47242496e10,
    -2.950130727918164224e12,	/*1.067062284288e16/3617 */
    1.1646782814350067249e14,	/*5.109094217170944e18/43867 */
    -4.5979787224074726105e15,	/*8.028576626982912e20/174611 */
    1.8152105401943546773e17,	/*1.5511210043330985984e23/854513 */
    -7.1661652561756670113e18	/*1.6938241367317436694528e27/236364091 */
};

__constant__ double MACHEP = 1.11022302462515654042E-16;

/* 30 Nov 86 -- error in third coefficient fixed */


double __device__ zeta(double x, double q)
{
    int i;
    double a, b, k, s, t, w;

    if (x == 1.0){
        return 1.0/0.0;
    }

    if (x < 1.0) {
        return nan("");
    }

    if (q <= 0.0) {
        if (q == floor(q)) {
            return 1.0/0.0;
        }
        if (x != floor(x)){
            return nan("");	/* because q^-x not defined */
        }
    }

    /* Asymptotic expansion
     * http://dlmf.nist.gov/25.11#E43
     */
    if (q > 1e8) {
        return (1/(x - 1) + 1/(2*q)) * pow(q, 1 - x);
    }

    /* Euler-Maclaurin summation formula */

    /* Permit negative q but continue sum until n+q > +9 .
     * This case should be handled by a reflection formula.
     * If q<0 and x is an integer, there is a relation to
     * the polyGamma function.
     */
    s = pow(q, -x);
    a = q;
    i = 0;
    b = 0.0;
    while ((i < 9) || (a <= 9.0)) {
        i += 1;
        a += 1.0;
        b = pow(a, -x);
        s += b;
        if (fabs(b / s) < MACHEP){
            return s;
        }
    }

    w = a;
    s += b * w / (x - 1.0);
    s -= 0.5 * b;
    a = 1.0;
    k = 0.0;
    for (i = 0; i < 12; i++) {
        a *= x + k;
        b /= w;
        t = a * b / A[i];
        s = s + t;
        t = fabs(t / s);
        if (t < MACHEP){
            return s;
        }
        k += 1.0;
        a *= x + k;
        b /= w;
        k += 1.0;
    }
    return s;
}
'''


# Public ``zeta`` object: an element-wise ufunc built with
# ``_core.create_ufunc`` that evaluates the device routine ``zeta(x, q)``
# from ``zeta_definition`` for every pair of input elements.
#
# NOTE(review): 'ff->f' and 'dd->d' look like NumPy-style type characters
# (float32 and float64 loops respectively) -- confirm against the
# ``create_ufunc`` documentation.  The preamble only defines a ``double``
# routine, so the 'ff->f' loop presumably goes through C's implicit
# float<->double conversions; a dedicated float version is still a TODO
# (see the top of this file).
zeta = _core.create_ufunc(
    # Kernel name, followed by the type signatures the ufunc accepts.
    'cupyx_scipy_special_zeta', ('ff->f', 'dd->d'),
    # Per-element routine body: in0 is passed as x, in1 as q.
    'out0 = zeta(in0, in1)',
    # CUDA source that must precede the routine (tables + device function).
    preamble=zeta_definition,
    # User-facing documentation string attached to the ufunc.
    doc="""Hurwitz zeta function.

    Args:
        x (cupy.ndarray): Input data, must be real.
        q (cupy.ndarray): Input data, must be real.

    Returns:
        cupy.ndarray: Values of zeta(x, q).

    .. seealso:: :data:`scipy.special.zeta`

    """)