numpy.h 8.96 KB
Newer Older
Wenzel Jakob's avatar
Wenzel Jakob committed
1
/*
    pybind11/numpy.h: Basic NumPy support, auto-vectorization support

    Copyright (c) 2015 Wenzel Jakob <wenzel@inf.ethz.ch>

    All rights reserved. Use of this source code is governed by a
    BSD-style license that can be found in the LICENSE file.
*/

#pragma once

12
13
#include "pybind11.h"
#include "complex.h"
14

Wenzel Jakob's avatar
Wenzel Jakob committed
15
16
17
18
19
#if defined(_MSC_VER)
#pragma warning(push)
#pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
#endif

20
NAMESPACE_BEGIN(pybind11)
Wenzel Jakob's avatar
Wenzel Jakob committed
21

Wenzel Jakob's avatar
Wenzel Jakob committed
22
23
/// Primary template mapping a C++ element type onto a NumPy type number via a
/// nested ``value`` enumerator. Only explicit specializations define ``value``,
/// so using an unsupported element type is rejected at compile time.
template <typename T> struct npy_format_descriptor {};

Wenzel Jakob's avatar
Wenzel Jakob committed
24
class array : public buffer {
Wenzel Jakob's avatar
Wenzel Jakob committed
25
public:
Wenzel Jakob's avatar
Wenzel Jakob committed
26
27
28
29
30
31
32
    struct API {
        enum Entries {
            API_PyArray_Type = 2,
            API_PyArray_DescrFromType = 45,
            API_PyArray_FromAny = 69,
            API_PyArray_NewCopy = 85,
            API_PyArray_NewFromDescr = 94,
33
34
35
36
37
38
39
40
41
42
43
44
45

            NPY_C_CONTIGUOUS_ = 0x0001,
            NPY_F_CONTIGUOUS_ = 0x0002,
            NPY_ARRAY_FORCECAST_ = 0x0010,
            NPY_ENSURE_ARRAY_ = 0x0040,
            NPY_BOOL_ = 0,
            NPY_BYTE_, NPY_UBYTE_,
            NPY_SHORT_, NPY_USHORT_,
            NPY_INT_, NPY_UINT_,
            NPY_LONG_, NPY_ULONG_,
            NPY_LONGLONG_, NPY_ULONGLONG_,
            NPY_FLOAT_, NPY_DOUBLE_, NPY_LONGDOUBLE_,
            NPY_CFLOAT_, NPY_CDOUBLE_, NPY_CLONGDOUBLE_
Wenzel Jakob's avatar
Wenzel Jakob committed
46
47
48
        };

        static API lookup() {
49
50
            module m = module::import("numpy.core.multiarray");
            object c = (object) m.attr("_ARRAY_API");
51
#if PY_MAJOR_VERSION >= 3
52
            void **api_ptr = (void **) (c ? PyCapsule_GetPointer(c.ptr(), NULL) : nullptr);
53
#else
54
            void **api_ptr = (void **) (c ? PyCObject_AsVoidPtr(c.ptr()) : nullptr);
55
#endif
Wenzel Jakob's avatar
Wenzel Jakob committed
56
            API api;
57
58
59
60
61
            api.PyArray_Type_          = (decltype(api.PyArray_Type_))          api_ptr[API_PyArray_Type];
            api.PyArray_DescrFromType_ = (decltype(api.PyArray_DescrFromType_)) api_ptr[API_PyArray_DescrFromType];
            api.PyArray_FromAny_       = (decltype(api.PyArray_FromAny_))       api_ptr[API_PyArray_FromAny];
            api.PyArray_NewCopy_       = (decltype(api.PyArray_NewCopy_))       api_ptr[API_PyArray_NewCopy];
            api.PyArray_NewFromDescr_  = (decltype(api.PyArray_NewFromDescr_))  api_ptr[API_PyArray_NewFromDescr];
Wenzel Jakob's avatar
Wenzel Jakob committed
62
63
64
            return api;
        }

65
        bool PyArray_Check_(PyObject *obj) const { return (bool) PyObject_TypeCheck(obj, PyArray_Type_); }
Wenzel Jakob's avatar
Wenzel Jakob committed
66

67
68
        PyObject *(*PyArray_DescrFromType_)(int);
        PyObject *(*PyArray_NewFromDescr_)
Wenzel Jakob's avatar
Wenzel Jakob committed
69
70
            (PyTypeObject *, PyObject *, int, Py_intptr_t *,
             Py_intptr_t *, void *, int, PyObject *);
71
72
73
        PyObject *(*PyArray_NewCopy_)(PyObject *, int);
        PyTypeObject *PyArray_Type_;
        PyObject *(*PyArray_FromAny_) (PyObject *, PyObject *, int, int, int, PyObject *);
Wenzel Jakob's avatar
Wenzel Jakob committed
74
    };
Wenzel Jakob's avatar
Wenzel Jakob committed
75

76
    PYBIND11_OBJECT_DEFAULT(array, buffer, lookup_api().PyArray_Check_)
Wenzel Jakob's avatar
Wenzel Jakob committed
77
78
79

    template <typename Type> array(size_t size, const Type *ptr) {
        API& api = lookup_api();
80
        PyObject *descr = api.PyArray_DescrFromType_(npy_format_descriptor<Type>::value);
Wenzel Jakob's avatar
Wenzel Jakob committed
81
82
83
        if (descr == nullptr)
            throw std::runtime_error("NumPy: unsupported buffer format!");
        Py_intptr_t shape = (Py_intptr_t) size;
84
85
86
87
88
        object tmp = object(api.PyArray_NewFromDescr_(
            api.PyArray_Type_, descr, 1, &shape, nullptr, (void *) ptr, 0, nullptr), false);
        if (ptr && tmp)
            tmp = object(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */), false);
        if (!tmp)
Wenzel Jakob's avatar
Wenzel Jakob committed
89
            throw std::runtime_error("NumPy: unable to create array!");
90
        m_ptr = tmp.release();
Wenzel Jakob's avatar
Wenzel Jakob committed
91
92
93
94
    }

    array(const buffer_info &info) {
        API& api = lookup_api();
95
        if ((info.format.size() < 1) || (info.format.size() > 2))
Wenzel Jakob's avatar
Wenzel Jakob committed
96
            throw std::runtime_error("Unsupported buffer format!");
Wenzel Jakob's avatar
Wenzel Jakob committed
97
        int fmt = (int) info.format[0];
98
99
100
101
        if (info.format == "Zd")      fmt = API::NPY_CDOUBLE_;
        else if (info.format == "Zf") fmt = API::NPY_CFLOAT_;

        PyObject *descr = api.PyArray_DescrFromType_(fmt);
Wenzel Jakob's avatar
Wenzel Jakob committed
102
103
        if (descr == nullptr)
            throw std::runtime_error("NumPy: unsupported buffer format '" + info.format + "'!");
104
        object tmp(api.PyArray_NewFromDescr_(
105
            api.PyArray_Type_, descr, info.ndim, (Py_intptr_t *) &info.shape[0],
106
107
108
109
            (Py_intptr_t *) &info.strides[0], info.ptr, 0, nullptr), false);
        if (info.ptr && tmp)
            tmp = object(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */), false);
        if (!tmp)
Wenzel Jakob's avatar
Wenzel Jakob committed
110
            throw std::runtime_error("NumPy: unable to create array!");
111
        m_ptr = tmp.release();
Wenzel Jakob's avatar
Wenzel Jakob committed
112
113
114
115
116
117
118
119
120
    }

protected:
    static API &lookup_api() {
        static API api = API::lookup();
        return api;
    }
};

121
template <typename T> class array_t : public array {
Wenzel Jakob's avatar
Wenzel Jakob committed
122
public:
123
    PYBIND11_OBJECT_CVT(array_t, array, is_non_null, m_ptr = ensure(m_ptr));
124
    array_t() : array() { }
Wenzel Jakob's avatar
Wenzel Jakob committed
125
    static bool is_non_null(PyObject *ptr) { return ptr != nullptr; }
126
    static PyObject *ensure(PyObject *ptr) {
127
128
        if (ptr == nullptr)
            return nullptr;
Wenzel Jakob's avatar
Wenzel Jakob committed
129
        API &api = lookup_api();
130
131
132
133
        PyObject *descr = api.PyArray_DescrFromType_(npy_format_descriptor<T>::value);
        return api.PyArray_FromAny_(ptr, descr, 0, 0,
                                    API::NPY_C_CONTIGUOUS_ | API::NPY_ENSURE_ARRAY_ |
                                    API::NPY_ARRAY_FORCECAST_, nullptr);
Wenzel Jakob's avatar
Wenzel Jakob committed
134
135
136
    }
};

137
#define DECL_FMT(t, n) template<> struct npy_format_descriptor<t> { enum { value = array::API::n }; }
138
139
140
141
142
DECL_FMT(int8_t, NPY_BYTE_);  DECL_FMT(uint8_t, NPY_UBYTE_); DECL_FMT(int16_t, NPY_SHORT_);
DECL_FMT(uint16_t, NPY_USHORT_); DECL_FMT(int32_t, NPY_INT_); DECL_FMT(uint32_t, NPY_UINT_);
DECL_FMT(int64_t, NPY_LONGLONG_); DECL_FMT(uint64_t, NPY_ULONGLONG_); DECL_FMT(float, NPY_FLOAT_);
DECL_FMT(double, NPY_DOUBLE_); DECL_FMT(bool, NPY_BOOL_); DECL_FMT(std::complex<float>, NPY_CFLOAT_);
DECL_FMT(std::complex<double>, NPY_CDOUBLE_);
143
144
#undef DECL_FMT

Wenzel Jakob's avatar
Wenzel Jakob committed
145
146
NAMESPACE_BEGIN(detail)

147
148
149
150
template <typename T> struct handle_type_name<array_t<T>> {
    /// Signature text for array_t<T> parameters: "array[" + element type name + "]".
    static PYBIND11_DESCR name() {
        return _("array[") +
               type_caster<T>::name() +
               _("]");
    }
};

template <typename Func, typename Return, typename... Args>
struct vectorize_helper {
    typename std::remove_reference<Func>::type f;

155
156
    template <typename T>
    vectorize_helper(T&&f) : f(std::forward<T>(f)) { }
Wenzel Jakob's avatar
Wenzel Jakob committed
157

158
    object operator()(array_t<Args>... args) {
159
160
        return run(args..., typename make_index_sequence<sizeof...(Args)>::type());
    }
Wenzel Jakob's avatar
Wenzel Jakob committed
161

162
    template <size_t ... Index> object run(array_t<Args>&... args, index_sequence<Index...>) {
Wenzel Jakob's avatar
Wenzel Jakob committed
163
        /* Request buffers from all parameters */
164
        const size_t N = sizeof...(Args);
Wenzel Jakob's avatar
Wenzel Jakob committed
165
166
167
168
169
170
171
172
173
174
175
176
177
178
        std::array<buffer_info, N> buffers {{ args.request()... }};

        /* Determine dimensions parameters of output array */
        int ndim = 0; size_t count = 0;
        std::vector<size_t> shape;
        for (size_t i=0; i<N; ++i) {
            if (buffers[i].count > count) {
                ndim = buffers[i].ndim;
                shape = buffers[i].shape;
                count = buffers[i].count;
            }
        }
        std::vector<size_t> strides(ndim);
        if (ndim > 0) {
179
            strides[ndim-1] = sizeof(Return);
Wenzel Jakob's avatar
Wenzel Jakob committed
180
181
182
183
184
            for (int i=ndim-1; i>0; --i)
                strides[i-1] = strides[i] * shape[i];
        }

        /* Check if the parameters are actually compatible */
185
        for (size_t i=0; i<N; ++i)
Wenzel Jakob's avatar
Wenzel Jakob committed
186
            if (buffers[i].count != 1 && (buffers[i].ndim != ndim || buffers[i].shape != shape))
187
                throw std::runtime_error("pybind11::vectorize: incompatible size/dimension of inputs!");
Wenzel Jakob's avatar
Wenzel Jakob committed
188
189

        if (count == 1)
190
            return cast(f(*((Args *) buffers[Index].ptr)...));
Wenzel Jakob's avatar
Wenzel Jakob committed
191

192
        array result(buffer_info(nullptr, sizeof(Return),
193
            format_descriptor<Return>::value(),
Wenzel Jakob's avatar
Wenzel Jakob committed
194
            ndim, shape, strides));
195
196
197
198
199
200
201
202
203
204
205

        buffer_info buf = result.request();
        Return *output = (Return *) buf.ptr;

        /* Call the function */
        for (size_t i=0; i<count; ++i)
            output[i] = f((buffers[Index].count == 1
                               ? *((Args *) buffers[Index].ptr)
                               : ((Args *) buffers[Index].ptr)[i])...);

        return result;
206
207
208
209
    }
};

NAMESPACE_END(detail)
Wenzel Jakob's avatar
Wenzel Jakob committed
210

211
212
213
/// Wrap an arbitrary callable ``f`` whose call signature is Return(Args...).
/// The unnamed second argument only drives template deduction; its value is unused.
template <typename Func, typename Return, typename... Args>
detail::vectorize_helper<Func, Return, Args...> vectorize(const Func &f, Return (*) (Args ...)) {
    using helper_type = detail::vectorize_helper<Func, Return, Args...>;
    return helper_type(f);
}

/// Wrap a plain function pointer; the pointer is stored directly as the callable.
template <typename Return, typename... Args>
detail::vectorize_helper<Return (*) (Args ...), Return, Args...> vectorize(Return (*f) (Args ...)) {
    using helper_type = detail::vectorize_helper<Return (*) (Args ...), Return, Args...>;
    return helper_type(f);
}

/// Wrap a lambda or function object: its call signature is recovered from
/// ``operator()`` and forwarded as a typed null pointer for deduction.
template <typename func> auto vectorize(func &&f) -> decltype(
        vectorize(std::forward<func>(f), (typename detail::remove_class<decltype(&std::remove_reference<func>::type::operator())>::type *) nullptr)) {
    using signature = typename detail::remove_class<
        decltype(&std::remove_reference<func>::type::operator())>::type;
    return vectorize(std::forward<func>(f), static_cast<signature *>(nullptr));
}

NAMESPACE_END(pybind11)
Wenzel Jakob's avatar
Wenzel Jakob committed
228
229
230
231

#if defined(_MSC_VER)
#pragma warning(pop)
#endif