numpy.h 35.9 KB
Newer Older
Wenzel Jakob's avatar
Wenzel Jakob committed
1
/*
2
    pybind11/numpy.h: Basic NumPy support, vectorize() wrapper
Wenzel Jakob's avatar
Wenzel Jakob committed
3

4
    Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
Wenzel Jakob's avatar
Wenzel Jakob committed
5
6
7
8
9
10
11

    All rights reserved. Use of this source code is governed by a
    BSD-style license that can be found in the LICENSE file.
*/

#pragma once

12
13
#include "pybind11.h"
#include "complex.h"
14
15
#include <numeric>
#include <algorithm>
16
#include <array>
17
#include <cstdlib>
18
#include <cstring>
19
#include <sstream>
20
#include <string>
21
#include <initializer_list>
22
#include <functional>
23

Wenzel Jakob's avatar
Wenzel Jakob committed
24
#if defined(_MSC_VER)
25
26
#  pragma warning(push)
#  pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
Wenzel Jakob's avatar
Wenzel Jakob committed
27
28
#endif

29
30
31
32
33
34
/* This will be true on all flat address space platforms and allows us to reduce the
   whole npy_intp / size_t / Py_intptr_t business down to just size_t for all size
   and dimension types (e.g. shape, strides, indexing), instead of inflicting this
   upon the library user. */
static_assert(sizeof(size_t) == sizeof(Py_intptr_t), "size_t != Py_intptr_t");

35
NAMESPACE_BEGIN(pybind11)
36
NAMESPACE_BEGIN(detail)
37
// Primary template: deliberately empty. Supported element types are handled by
// specializations of this template; the second parameter exists so that those
// specializations can be selected with enable_if-style conditions.
template <typename type, typename SFINAE = void> struct npy_format_descriptor { };
Ivan Smirnov's avatar
Ivan Smirnov committed
38
// Forward declaration of the trait; it is referred to in this header before
// its definition appears.
template <typename type> struct is_pod_struct;
Wenzel Jakob's avatar
Wenzel Jakob committed
39

40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
// Field layout used by the PyArrayDescr_GET_ macro, which reinterpret_casts a
// Python object pointer to this type and reads a member from it.
// NOTE(review): presumably mirrors the leading members of NumPy's descriptor
// struct; member order and types must then stay in sync with NumPy — confirm
// against the NumPy versions being targeted.
struct PyArrayDescr_Proxy {
    PyObject_HEAD
    PyObject *typeobj;
    char kind;        // read by dtype::kind()
    char type;
    char byteorder;
    char flags;
    int type_num;
    int elsize;       // read by dtype::itemsize() and array::itemsize()
    int alignment;
    char *subarray;
    PyObject *fields;
    PyObject *names;  // read by dtype::has_fields() (non-null means "has fields")
};

// Field layout used by the PyArray_GET_ macro, which reinterpret_casts a Python
// object pointer to this type and reads or writes a member.
// NOTE(review): presumably mirrors the leading members of NumPy's array object
// struct; member order and types must then stay in sync with NumPy — confirm.
struct PyArray_Proxy {
    PyObject_HEAD
    char *data;           // start of the element buffer (see array::data())
    int nd;               // number of dimensions (see array::ndim())
    ssize_t *dimensions;  // nd entries (see array::shape())
    ssize_t *strides;     // nd entries (see array::strides())
    PyObject *base;       // see array::base(); assigned in the array constructor
    PyObject *descr;      // dtype descriptor object (see array::dtype())
    int flags;            // tested via PyArray_FLAGS_ / PyArray_CHKFLAGS_
};

66
67
68
69
struct npy_api {
    enum constants {
        NPY_C_CONTIGUOUS_ = 0x0001,
        NPY_F_CONTIGUOUS_ = 0x0002,
70
        NPY_ARRAY_OWNDATA_ = 0x0004,
71
72
        NPY_ARRAY_FORCECAST_ = 0x0010,
        NPY_ENSURE_ARRAY_ = 0x0040,
73
74
        NPY_ARRAY_ALIGNED_ = 0x0100,
        NPY_ARRAY_WRITEABLE_ = 0x0400,
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
        NPY_BOOL_ = 0,
        NPY_BYTE_, NPY_UBYTE_,
        NPY_SHORT_, NPY_USHORT_,
        NPY_INT_, NPY_UINT_,
        NPY_LONG_, NPY_ULONG_,
        NPY_LONGLONG_, NPY_ULONGLONG_,
        NPY_FLOAT_, NPY_DOUBLE_, NPY_LONGDOUBLE_,
        NPY_CFLOAT_, NPY_CDOUBLE_, NPY_CLONGDOUBLE_,
        NPY_OBJECT_ = 17,
        NPY_STRING_, NPY_UNICODE_, NPY_VOID_
    };

    static npy_api& get() {
        static npy_api api = lookup();
        return api;
    }

92
93
94
95
96
97
    bool PyArray_Check_(PyObject *obj) const {
        return (bool) PyObject_TypeCheck(obj, PyArray_Type_);
    }
    bool PyArrayDescr_Check_(PyObject *obj) const {
        return (bool) PyObject_TypeCheck(obj, PyArrayDescr_Type_);
    }
98
99
100
101
102
103
104
105

    PyObject *(*PyArray_DescrFromType_)(int);
    PyObject *(*PyArray_NewFromDescr_)
        (PyTypeObject *, PyObject *, int, Py_intptr_t *,
         Py_intptr_t *, void *, int, PyObject *);
    PyObject *(*PyArray_DescrNewFromType_)(int);
    PyObject *(*PyArray_NewCopy_)(PyObject *, int);
    PyTypeObject *PyArray_Type_;
106
    PyTypeObject *PyArrayDescr_Type_;
107
108
109
110
111
    PyObject *(*PyArray_FromAny_) (PyObject *, PyObject *, int, int, int, PyObject *);
    int (*PyArray_DescrConverter_) (PyObject *, PyObject **);
    bool (*PyArray_EquivTypes_) (PyObject *, PyObject *);
    int (*PyArray_GetArrayParamsFromObject_)(PyObject *, PyObject *, char, PyObject **, int *,
                                             Py_ssize_t *, PyObject **, PyObject *);
112
    PyObject *(*PyArray_Squeeze_)(PyObject *);
113
114
115
private:
    enum functions {
        API_PyArray_Type = 2,
116
        API_PyArrayDescr_Type = 3,
117
118
119
120
121
122
123
124
        API_PyArray_DescrFromType = 45,
        API_PyArray_FromAny = 69,
        API_PyArray_NewCopy = 85,
        API_PyArray_NewFromDescr = 94,
        API_PyArray_DescrNewFromType = 9,
        API_PyArray_DescrConverter = 174,
        API_PyArray_EquivTypes = 182,
        API_PyArray_GetArrayParamsFromObject = 278,
125
        API_PyArray_Squeeze = 136
126
127
128
129
    };

    static npy_api lookup() {
        module m = module::import("numpy.core.multiarray");
130
        auto c = m.attr("_ARRAY_API");
131
#if PY_MAJOR_VERSION >= 3
132
        void **api_ptr = (void **) PyCapsule_GetPointer(c.ptr(), NULL);
133
#else
134
        void **api_ptr = (void **) PyCObject_AsVoidPtr(c.ptr());
135
#endif
136
        npy_api api;
137
#define DECL_NPY_API(Func) api.Func##_ = (decltype(api.Func##_)) api_ptr[API_##Func];
138
        DECL_NPY_API(PyArray_Type);
139
        DECL_NPY_API(PyArrayDescr_Type);
140
141
142
143
144
145
146
147
        DECL_NPY_API(PyArray_DescrFromType);
        DECL_NPY_API(PyArray_FromAny);
        DECL_NPY_API(PyArray_NewCopy);
        DECL_NPY_API(PyArray_NewFromDescr);
        DECL_NPY_API(PyArray_DescrNewFromType);
        DECL_NPY_API(PyArray_DescrConverter);
        DECL_NPY_API(PyArray_EquivTypes);
        DECL_NPY_API(PyArray_GetArrayParamsFromObject);
148
        DECL_NPY_API(PyArray_Squeeze);
149
#undef DECL_NPY_API
150
151
152
        return api;
    }
};
153
NAMESPACE_END(detail)
Wenzel Jakob's avatar
Wenzel Jakob committed
154

155
156
157
158
// Access member `attr` of the array object behind `ptr` through the
// PyArray_Proxy layout.
#define PyArray_GET_(ptr, attr) \
    (reinterpret_cast<::pybind11::detail::PyArray_Proxy*>(ptr)->attr)
// Access member `attr` of the descriptor object behind `ptr` through the
// PyArrayDescr_Proxy layout.
#define PyArrayDescr_GET_(ptr, attr) \
    (reinterpret_cast<::pybind11::detail::PyArrayDescr_Proxy*>(ptr)->attr)
// The `flags` member of the array object behind `ptr`.
#define PyArray_FLAGS_(ptr) \
    PyArray_GET_(ptr, flags)
// True when every bit of `flag` is set in the array's flags. `flag` is
// parenthesised so that compound expressions (e.g. `A | B`) expand correctly.
#define PyArray_CHKFLAGS_(ptr, flag) \
    ((flag) == (PyArray_FLAGS_(ptr) & (flag)))
163

164
class dtype : public object {
165
public:
166
    PYBIND11_OBJECT_DEFAULT(dtype, object, detail::npy_api::get().PyArrayDescr_Check_);
Wenzel Jakob's avatar
Wenzel Jakob committed
167

168
    dtype(const buffer_info &info) {
169
        dtype descr(_dtype_from_pep3118()(PYBIND11_STR_TYPE(info.format)));
170
171
        m_ptr = descr.strip_padding().release().ptr();
    }
172

173
174
    dtype(std::string format) {
        m_ptr = from_args(pybind11::str(format)).release().ptr();
Wenzel Jakob's avatar
Wenzel Jakob committed
175
176
    }

177
178
    dtype(const char *format) : dtype(std::string(format)) { }

179
180
181
182
183
    dtype(list names, list formats, list offsets, size_t itemsize) {
        dict args;
        args["names"] = names;
        args["formats"] = formats;
        args["offsets"] = offsets;
184
        args["itemsize"] = pybind11::int_(itemsize);
185
186
187
        m_ptr = from_args(args).release().ptr();
    }

188
189
190
191
192
193
194
    static dtype from_args(object args) {
        // This is essentially the same as calling np.dtype() constructor in Python
        PyObject *ptr = nullptr;
        if (!detail::npy_api::get().PyArray_DescrConverter_(args.release().ptr(), &ptr) || !ptr)
            pybind11_fail("NumPy: failed to create structured dtype");
        return object(ptr, false);
    }
195

196
    template <typename T> static dtype of() {
197
        return detail::npy_format_descriptor<typename std::remove_cv<T>::type>::dtype();
198
    }
199

200
    size_t itemsize() const {
201
        return (size_t) PyArrayDescr_GET_(m_ptr, elsize);
Wenzel Jakob's avatar
Wenzel Jakob committed
202
203
    }

204
    bool has_fields() const {
205
        return PyArrayDescr_GET_(m_ptr, names) != nullptr;
206
207
    }

208
209
    char kind() const {
        return PyArrayDescr_GET_(m_ptr, kind);
210
211
212
    }

private:
213
214
215
216
    static object _dtype_from_pep3118() {
        static PyObject *obj = module::import("numpy.core._internal")
            .attr("_dtype_from_pep3118").cast<object>().release().ptr();
        return object(obj, true);
217
    }
218

219
    dtype strip_padding() {
220
221
        // Recursively strip all void fields with empty names that are generated for
        // padding fields (as of NumPy v1.11).
222
        if (!has_fields())
223
            return *this;
224

225
        struct field_descr { PYBIND11_STR_TYPE name; object format; pybind11::int_ offset; };
226
227
        std::vector<field_descr> field_descriptors;

228
        for (auto field : attr("fields").attr("items")()) {
229
230
            auto spec = object(field, true).cast<tuple>();
            auto name = spec[0].cast<pybind11::str>();
231
            auto format = spec[1].cast<tuple>()[0].cast<dtype>();
232
            auto offset = spec[1].cast<tuple>()[1].cast<pybind11::int_>();
233
            if (!len(name) && format.kind() == 'V')
234
                continue;
235
            field_descriptors.push_back({(PYBIND11_STR_TYPE) name, format.strip_padding(), offset});
236
237
238
239
        }

        std::sort(field_descriptors.begin(), field_descriptors.end(),
                  [](const field_descr& a, const field_descr& b) {
240
                      return a.offset.cast<int>() < b.offset.cast<int>();
241
242
243
244
                  });

        list names, formats, offsets;
        for (auto& descr : field_descriptors) {
245
246
247
            names.append(descr.name);
            formats.append(descr.format);
            offsets.append(descr.offset);
248
        }
249
        return dtype(names, formats, offsets, itemsize());
250
251
    }
};
252

253
254
255
256
257
258
259
260
261
262
class array : public buffer {
public:
    PYBIND11_OBJECT_DEFAULT(array, buffer, detail::npy_api::get().PyArray_Check_)

    enum {
        c_style = detail::npy_api::NPY_C_CONTIGUOUS_,
        f_style = detail::npy_api::NPY_F_CONTIGUOUS_,
        forcecast = detail::npy_api::NPY_ARRAY_FORCECAST_
    };

263
264
265
    array(const pybind11::dtype &dt, const std::vector<size_t> &shape,
          const std::vector<size_t> &strides, const void *ptr = nullptr,
          handle base = handle()) {
266
        auto& api = detail::npy_api::get();
267
268
269
270
        auto ndim = shape.size();
        if (shape.size() != strides.size())
            pybind11_fail("NumPy: shape ndim doesn't match strides ndim");
        auto descr = dt;
271
272
273
274
275
276
277
278
279
280
281
282

        int flags = 0;
        if (base && ptr) {
            array base_array(base, true);
            if (base_array.check())
                /* Copy flags from base (except baseship bit) */
                flags = base_array.flags() & ~detail::npy_api::NPY_ARRAY_OWNDATA_;
            else
                /* Writable by default, easy to downgrade later on if needed */
                flags = detail::npy_api::NPY_ARRAY_WRITEABLE_;
        }

283
284
        object tmp(api.PyArray_NewFromDescr_(
            api.PyArray_Type_, descr.release().ptr(), (int) ndim, (Py_intptr_t *) shape.data(),
285
            (Py_intptr_t *) strides.data(), const_cast<void *>(ptr), flags, nullptr), false);
286
287
        if (!tmp)
            pybind11_fail("NumPy: unable to create array!");
288
289
290
291
292
293
294
        if (ptr) {
            if (base) {
                PyArray_GET_(tmp.ptr(), base) = base.inc_ref().ptr();
            } else {
                tmp = object(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */), false);
            }
        }
295
296
297
        m_ptr = tmp.release().ptr();
    }

298
299
300
    array(const pybind11::dtype &dt, const std::vector<size_t> &shape,
          const void *ptr = nullptr, handle base = handle())
        : array(dt, shape, default_strides(shape, dt.itemsize()), ptr, base) { }
301

302
303
304
    array(const pybind11::dtype &dt, size_t count, const void *ptr = nullptr,
          handle base = handle())
        : array(dt, std::vector<size_t>{ count }, ptr, base) { }
305
306

    template<typename T> array(const std::vector<size_t>& shape,
307
308
309
                               const std::vector<size_t>& strides,
                               const T* ptr, handle base = handle())
    : array(pybind11::dtype::of<T>(), shape, strides, (void *) ptr, base) { }
310

311
312
313
314
    template <typename T>
    array(const std::vector<size_t> &shape, const T *ptr,
          handle base = handle())
        : array(shape, default_strides(shape, sizeof(T)), ptr, base) { }
315

316
317
318
    template <typename T>
    array(size_t count, const T *ptr, handle base = handle())
        : array(std::vector<size_t>{ count }, ptr, base) { }
319
320

    array(const buffer_info &info)
321
    : array(pybind11::dtype(info), info.shape, info.strides, info.ptr) { }
322

323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
    /// Array descriptor (dtype)
    pybind11::dtype dtype() const {
        return object(PyArray_GET_(m_ptr, descr), true);
    }

    /// Total number of elements
    size_t size() const {
        return std::accumulate(shape(), shape() + ndim(), (size_t) 1, std::multiplies<size_t>());
    }

    /// Byte size of a single element
    size_t itemsize() const {
        return (size_t) PyArrayDescr_GET_(PyArray_GET_(m_ptr, descr), elsize);
    }

    /// Total number of bytes
    size_t nbytes() const {
        return size() * itemsize();
    }

    /// Number of dimensions
    size_t ndim() const {
        return (size_t) PyArray_GET_(m_ptr, nd);
    }

348
349
350
351
352
    /// Base object
    object base() const {
        return object(PyArray_GET_(m_ptr, base), true);
    }

353
354
355
356
357
358
359
360
    /// Dimensions of the array
    const size_t* shape() const {
        return reinterpret_cast<const size_t *>(PyArray_GET_(m_ptr, dimensions));
    }

    /// Dimension along a given axis
    size_t shape(size_t dim) const {
        if (dim >= ndim())
361
            fail_dim_check(dim, "invalid axis");
362
363
364
365
366
367
368
369
370
371
372
        return shape()[dim];
    }

    /// Strides of the array
    const size_t* strides() const {
        return reinterpret_cast<const size_t *>(PyArray_GET_(m_ptr, strides));
    }

    /// Stride along a given axis
    size_t strides(size_t dim) const {
        if (dim >= ndim())
373
            fail_dim_check(dim, "invalid axis");
374
375
376
        return strides()[dim];
    }

377
378
379
380
381
    /// Return the NumPy array flags
    int flags() const {
        return PyArray_FLAGS_(m_ptr);
    }

382
383
384
385
386
387
388
389
390
391
    /// If set, the array is writeable (otherwise the buffer is read-only)
    bool writeable() const {
        return PyArray_CHKFLAGS_(m_ptr, detail::npy_api::NPY_ARRAY_WRITEABLE_);
    }

    /// If set, the array owns the data (will be freed when the array is deleted)
    bool owndata() const {
        return PyArray_CHKFLAGS_(m_ptr, detail::npy_api::NPY_ARRAY_OWNDATA_);
    }

392
393
394
395
    /// Pointer to the contained data. If index is not provided, points to the
    /// beginning of the buffer. May throw if the index would lead to out of bounds access.
    template<typename... Ix> const void* data(Ix&&... index) const {
        return static_cast<const void *>(PyArray_GET_(m_ptr, data) + offset_at(index...));
396
397
    }

398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
    /// Mutable pointer to the contained data. If index is not provided, points to the
    /// beginning of the buffer. May throw if the index would lead to out of bounds access.
    /// May throw if the array is not writeable.
    template<typename... Ix> void* mutable_data(Ix&&... index) {
        check_writeable();
        return static_cast<void *>(PyArray_GET_(m_ptr, data) + offset_at(index...));
    }

    /// Byte offset from beginning of the array to a given index (full or partial).
    /// May throw if the index would lead to out of bounds access.
    template<typename... Ix> size_t offset_at(Ix&&... index) const {
        if (sizeof...(index) > ndim())
            fail_dim_check(sizeof...(index), "too many indices for an array");
        return get_byte_offset(index...);
    }

    size_t offset_at() const { return 0; }

    /// Item count from beginning of the array to a given index (full or partial).
    /// May throw if the index would lead to out of bounds access.
    template<typename... Ix> size_t index_at(Ix&&... index) const {
        return offset_at(index...) / itemsize();
420
421
    }

422
423
424
425
426
427
    /// Return a new view with all of the dimensions of length 1 removed
    array squeeze() {
        auto& api = detail::npy_api::get();
        return array(api.PyArray_Squeeze_(m_ptr), false);
    }

428
429
430
431
432
433
434
    /// Ensure that the argument is a NumPy array
    static array ensure(object input, int ExtraFlags = 0) {
        auto& api = detail::npy_api::get();
        return array(api.PyArray_FromAny_(
            input.release().ptr(), nullptr, 0, 0, detail::npy_api::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr), false);
    }

435
protected:
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
    template<typename, typename> friend struct detail::npy_format_descriptor;

    void fail_dim_check(size_t dim, const std::string& msg) const {
        throw index_error(msg + ": " + std::to_string(dim) +
                          " (ndim = " + std::to_string(ndim()) + ")");
    }

    template<typename... Ix> size_t get_byte_offset(Ix&&... index) const {
        const size_t idx[] = { (size_t) index... };
        if (!std::equal(idx + 0, idx + sizeof...(index), shape(), std::less<size_t>{})) {
            auto mismatch = std::mismatch(idx + 0, idx + sizeof...(index), shape(), std::less<size_t>{});
            throw index_error(std::string("index ") + std::to_string(*mismatch.first) +
                              " is out of bounds for axis " + std::to_string(mismatch.first - idx) +
                              " with size " + std::to_string(*mismatch.second));
        }
        return std::inner_product(idx + 0, idx + sizeof...(index), strides(), (size_t) 0);
    }

    size_t get_byte_offset() const { return 0; }

    void check_writeable() const {
        if (!writeable())
            throw std::runtime_error("array is not writeable");
    }
460
461
462
463
464
465
466
467
468
469
470
471

    static std::vector<size_t> default_strides(const std::vector<size_t>& shape, size_t itemsize) {
        auto ndim = shape.size();
        std::vector<size_t> strides(ndim);
        if (ndim) {
            std::fill(strides.begin(), strides.end(), itemsize);
            for (size_t i = 0; i < ndim - 1; i++)
                for (size_t j = 0; j < ndim - 1 - i; j++)
                    strides[j] *= shape[ndim - 1 - i];
        }
        return strides;
    }
Wenzel Jakob's avatar
Wenzel Jakob committed
472
473
};

474
template <typename T, int ExtraFlags = array::forcecast> class array_t : public array {
Wenzel Jakob's avatar
Wenzel Jakob committed
475
public:
476
    PYBIND11_OBJECT_CVT(array_t, array, is_non_null, m_ptr = ensure_(m_ptr));
477

478
    array_t() : array() { }
479
480
481

    array_t(const buffer_info& info) : array(info) { }

482
483
484
485
    array_t(const std::vector<size_t> &shape,
            const std::vector<size_t> &strides, const T *ptr = nullptr,
            handle base = handle())
        : array(shape, strides, ptr, base) { }
486

487
488
489
    array_t(const std::vector<size_t> &shape, const T *ptr = nullptr,
            handle base = handle())
        : array(shape, ptr, base) { }
490

491
492
    array_t(size_t count, const T *ptr = nullptr, handle base = handle())
        : array(count, ptr, base) { }
493

494
495
    constexpr size_t itemsize() const {
        return sizeof(T);
496
497
    }

498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
    template<typename... Ix> size_t index_at(Ix&... index) const {
        return offset_at(index...) / itemsize();
    }

    template<typename... Ix> const T* data(Ix&&... index) const {
        return static_cast<const T*>(array::data(index...));
    }

    template<typename... Ix> T* mutable_data(Ix&&... index) {
        return static_cast<T*>(array::mutable_data(index...));
    }

    // Reference to element at a given index
    template<typename... Ix> const T& at(Ix&&... index) const {
        if (sizeof...(index) != ndim())
            fail_dim_check(sizeof...(index), "index dimension mismatch");
        // not using offset_at() / index_at() here so as to avoid another dimension check
        return *(static_cast<const T*>(array::data()) + get_byte_offset(index...) / itemsize());
    }

    // Mutable reference to element at a given index
    template<typename... Ix> T& mutable_at(Ix&&... index) {
        if (sizeof...(index) != ndim())
            fail_dim_check(sizeof...(index), "index dimension mismatch");
        // not using offset_at() / index_at() here so as to avoid another dimension check
        return *(static_cast<T*>(array::mutable_data()) + get_byte_offset(index...) / itemsize());
524
    }
525

Wenzel Jakob's avatar
Wenzel Jakob committed
526
    static bool is_non_null(PyObject *ptr) { return ptr != nullptr; }
527

528
    static PyObject *ensure_(PyObject *ptr) {
529
530
        if (ptr == nullptr)
            return nullptr;
531
        auto& api = detail::npy_api::get();
532
        PyObject *result = api.PyArray_FromAny_(ptr, pybind11::dtype::of<T>().release().ptr(), 0, 0,
533
                                                detail::npy_api::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr);
534
535
        if (!result)
            PyErr_Clear();
536
537
        Py_DECREF(ptr);
        return result;
Wenzel Jakob's avatar
Wenzel Jakob committed
538
539
540
    }
};

541
template <typename T>
542
struct format_descriptor<T, detail::enable_if_t<detail::is_pod_struct<T>::value>> {
543
544
545
    static std::string format() {
        return detail::npy_format_descriptor<typename std::remove_cv<T>::type>::format();
    }
546
547
548
};

/// Buffer format string for a fixed-size char array: "<N>s".
template <size_t N> struct format_descriptor<char[N]> {
    static std::string format() { return std::to_string(N) + "s"; }
};
/// Buffer format string for std::array<char, N>: "<N>s".
template <size_t N> struct format_descriptor<std::array<char, N>> {
    static std::string format() { return std::to_string(N) + "s"; }
};

555
NAMESPACE_BEGIN(detail)
Ivan Smirnov's avatar
Ivan Smirnov committed
556
557
558
559
560
561
/// Trait: `value` is true exactly for specializations of std::array.
template <typename T>
struct is_std_array : std::integral_constant<bool, false> { };
template <typename Elem, size_t Extent>
struct is_std_array<std::array<Elem, Extent>> : std::integral_constant<bool, true> { };

template <typename T>
struct is_pod_struct {
    enum { value = std::is_pod<T>::value && // offsetof only works correctly for POD types
562
           !std::is_reference<T>::value &&
Ivan Smirnov's avatar
Ivan Smirnov committed
563
564
565
           !std::is_array<T>::value &&
           !is_std_array<T>::value &&
           !std::is_integral<T>::value &&
566
567
568
569
570
           !std::is_same<typename std::remove_cv<T>::type, float>::value &&
           !std::is_same<typename std::remove_cv<T>::type, double>::value &&
           !std::is_same<typename std::remove_cv<T>::type, bool>::value &&
           !std::is_same<typename std::remove_cv<T>::type, std::complex<float>>::value &&
           !std::is_same<typename std::remove_cv<T>::type, std::complex<double>>::value };
Ivan Smirnov's avatar
Ivan Smirnov committed
571
};
572

573
/// Descriptor for integral types. The NumPy type number is picked from
/// `values` by the byte width of T (via log2 of sizeof(T)) and its signedness.
template <typename T> struct npy_format_descriptor<T, enable_if_t<std::is_integral<T>::value>> {
private:
    // Laid out as {signed, unsigned} pairs for 1-, 2-, 4- and 8-byte integers.
    constexpr static const int values[8] = {
        npy_api::NPY_BYTE_, npy_api::NPY_UBYTE_, npy_api::NPY_SHORT_,    npy_api::NPY_USHORT_,
        npy_api::NPY_INT_,  npy_api::NPY_UINT_,  npy_api::NPY_LONGLONG_, npy_api::NPY_ULONGLONG_ };
public:
    enum { value = values[detail::log2(sizeof(T)) * 2 + (std::is_unsigned<T>::value ? 1 : 0)] };
    /// dtype for T; fails via pybind11_fail() if NumPy returns no descriptor.
    static pybind11::dtype dtype() {
        if (auto ptr = npy_api::get().PyArray_DescrFromType_(value))
            return object(ptr, true);
        pybind11_fail("Unsupported buffer format!");
    }
    /// Type name: "int<bits>" for signed T, "uint<bits>" otherwise.
    template <typename T2 = T, enable_if_t<std::is_signed<T2>::value, int> = 0>
    static PYBIND11_DESCR name() { return _("int") + _<sizeof(T)*8>(); }
    template <typename T2 = T, enable_if_t<!std::is_signed<T2>::value, int> = 0>
    static PYBIND11_DESCR name() { return _("uint") + _<sizeof(T)*8>(); }
};
// Out-of-class definition of the static constexpr table.
template <typename T> constexpr const int npy_format_descriptor<
    T, enable_if_t<std::is_integral<T>::value>>::values[8];
592

593
#define DECL_FMT(Type, NumPyName, Name) template<> struct npy_format_descriptor<Type> { \
594
    enum { value = npy_api::NumPyName }; \
595
    static pybind11::dtype dtype() { \
596
        if (auto ptr = npy_api::get().PyArray_DescrFromType_(value)) \
597
598
            return object(ptr, true); \
        pybind11_fail("Unsupported buffer format!"); \
599
    } \
600
    static PYBIND11_DESCR name() { return _(Name); } }
601
602
603
604
605
DECL_FMT(float, NPY_FLOAT_, "float32");
DECL_FMT(double, NPY_DOUBLE_, "float64");
DECL_FMT(bool, NPY_BOOL_, "bool");
DECL_FMT(std::complex<float>, NPY_CFLOAT_, "complex64");
DECL_FMT(std::complex<double>, NPY_CDOUBLE_, "complex128");
606
607
#undef DECL_FMT

608
609
#define DECL_CHAR_FMT \
    static PYBIND11_DESCR name() { return _("S") + _<N>(); } \
610
    static pybind11::dtype dtype() { return std::string("S") + std::to_string(N); }
611
612
613
614
template <size_t N> struct npy_format_descriptor<char[N]> { DECL_CHAR_FMT };
template <size_t N> struct npy_format_descriptor<std::array<char, N>> { DECL_CHAR_FMT };
#undef DECL_CHAR_FMT

615
616
struct field_descriptor {
    const char *name;
617
    size_t offset;
618
    size_t size;
619
    std::string format;
620
    dtype descr;
621
622
};

623
template <typename T>
624
struct npy_format_descriptor<T, enable_if_t<is_pod_struct<T>::value>> {
625
    static PYBIND11_DESCR name() { return _("struct"); }
626

627
    static pybind11::dtype dtype() {
628
        if (!dtype_ptr)
629
            pybind11_fail("NumPy: unsupported buffer format!");
630
        return object(dtype_ptr, true);
631
632
    }

633
634
    static std::string format() {
        if (!dtype_ptr)
635
            pybind11_fail("NumPy: unsupported buffer format!");
636
        return format_str;
637
638
639
    }

    static void register_dtype(std::initializer_list<field_descriptor> fields) {
640
        list names, formats, offsets;
641
642
643
        for (auto field : fields) {
            if (!field.descr)
                pybind11_fail("NumPy: unsupported field dtype");
644
            names.append(PYBIND11_STR_TYPE(field.name));
645
            formats.append(field.descr);
646
            offsets.append(pybind11::int_(field.offset));
647
        }
648
        dtype_ptr = pybind11::dtype(names, formats, offsets, sizeof(T)).release().ptr();
649
650
651
652
653
654
655
656
657
658

        // There is an existing bug in NumPy (as of v1.11): trailing bytes are
        // not encoded explicitly into the format string. This will supposedly
        // get fixed in v1.12; for further details, see these:
        // - https://github.com/numpy/numpy/issues/7797
        // - https://github.com/numpy/numpy/pull/7798
        // Because of this, we won't use numpy's logic to generate buffer format
        // strings and will just do it ourselves.
        std::vector<field_descriptor> ordered_fields(fields);
        std::sort(ordered_fields.begin(), ordered_fields.end(),
659
                  [](const field_descriptor &a, const field_descriptor &b) {
660
661
662
663
664
665
666
667
668
669
670
                      return a.offset < b.offset;
                  });
        size_t offset = 0;
        std::ostringstream oss;
        oss << "T{";
        for (auto& field : ordered_fields) {
            if (field.offset > offset)
                oss << (field.offset - offset) << 'x';
            // note that '=' is required to cover the case of unaligned fields
            oss << '=' << field.format << ':' << field.name << ':';
            offset = field.offset + field.size;
Ivan Smirnov's avatar
Ivan Smirnov committed
671
        }
672
673
674
        if (sizeof(T) > offset)
            oss << (sizeof(T) - offset) << 'x';
        oss << '}';
675
        format_str = oss.str();
676
677

        // Sanity check: verify that NumPy properly parses our buffer format string
678
        auto& api = npy_api::get();
679
680
        auto arr =  array(buffer_info(nullptr, sizeof(T), format(), 1));
        if (!api.PyArray_EquivTypes_(dtype_ptr, arr.dtype().ptr()))
681
            pybind11_fail("NumPy: invalid buffer descriptor!");
682
683
684
    }

private:
685
686
    static std::string format_str;
    static PyObject* dtype_ptr;
687
688
};

689
template <typename T>
690
std::string npy_format_descriptor<T, enable_if_t<is_pod_struct<T>::value>>::format_str;
691
template <typename T>
692
PyObject* npy_format_descriptor<T, enable_if_t<is_pod_struct<T>::value>>::dtype_ptr = nullptr;
693

694
// Extract name, offset and format descriptor for a struct field.
// Expands to a field_descriptor initializer holding the stringized field name,
// its offsetof(), its size, its buffer format string and its dtype.
#define PYBIND11_FIELD_DESCRIPTOR(Type, Field) \
    ::pybind11::detail::field_descriptor { \
        #Field, offsetof(Type, Field), sizeof(decltype(static_cast<Type*>(0)->Field)), \
        ::pybind11::format_descriptor<decltype(static_cast<Type*>(0)->Field)>::format(), \
        ::pybind11::detail::npy_format_descriptor<decltype(static_cast<Type*>(0)->Field)>::dtype() \
    }
701
702
703

// The main idea of this macro is borrowed from https://github.com/swansontec/map-macro
// (C) William Swanson, Paul Fultz
//
// The EVAL* macros force repeated rescanning of their argument so that the
// mutually recursive MAP_LIST0 / MAP_LIST1 macros below keep expanding.
#define PYBIND11_EVAL0(...) __VA_ARGS__
#define PYBIND11_EVAL1(...) PYBIND11_EVAL0 (PYBIND11_EVAL0 (PYBIND11_EVAL0 (__VA_ARGS__)))
#define PYBIND11_EVAL2(...) PYBIND11_EVAL1 (PYBIND11_EVAL1 (PYBIND11_EVAL1 (__VA_ARGS__)))
#define PYBIND11_EVAL3(...) PYBIND11_EVAL2 (PYBIND11_EVAL2 (PYBIND11_EVAL2 (__VA_ARGS__)))
#define PYBIND11_EVAL4(...) PYBIND11_EVAL3 (PYBIND11_EVAL3 (PYBIND11_EVAL3 (__VA_ARGS__)))
#define PYBIND11_EVAL(...)  PYBIND11_EVAL4 (PYBIND11_EVAL4 (PYBIND11_EVAL4 (__VA_ARGS__)))
#define PYBIND11_MAP_END(...)
#define PYBIND11_MAP_OUT
#define PYBIND11_MAP_COMMA ,
// Applied to the `()` end marker this yields `0, PYBIND11_MAP_END`, which makes
// the NEXT0 selector below pick PYBIND11_MAP_END and so stops the recursion.
#define PYBIND11_MAP_GET_END() 0, PYBIND11_MAP_END
#define PYBIND11_MAP_NEXT0(test, next, ...) next PYBIND11_MAP_OUT
#define PYBIND11_MAP_NEXT1(test, next) PYBIND11_MAP_NEXT0 (test, next, 0)
#define PYBIND11_MAP_NEXT(test, next)  PYBIND11_MAP_NEXT1 (PYBIND11_MAP_GET_END test, next)
#ifdef _MSC_VER // MSVC is not as eager to expand macros, hence this workaround
#define PYBIND11_MAP_LIST_NEXT1(test, next) \
    PYBIND11_EVAL0 (PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0))
#else
#define PYBIND11_MAP_LIST_NEXT1(test, next) \
    PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0)
#endif
#define PYBIND11_MAP_LIST_NEXT(test, next) \
    PYBIND11_MAP_LIST_NEXT1 (PYBIND11_MAP_GET_END test, next)
#define PYBIND11_MAP_LIST0(f, t, x, peek, ...) \
    f(t, x) PYBIND11_MAP_LIST_NEXT (peek, PYBIND11_MAP_LIST1) (f, t, peek, __VA_ARGS__)
#define PYBIND11_MAP_LIST1(f, t, x, peek, ...) \
    f(t, x) PYBIND11_MAP_LIST_NEXT (peek, PYBIND11_MAP_LIST0) (f, t, peek, __VA_ARGS__)
// PYBIND11_MAP_LIST(f, t, a1, a2, ...) expands to f(t, a1), f(t, a2), ...
#define PYBIND11_MAP_LIST(f, t, ...) \
    PYBIND11_EVAL (PYBIND11_MAP_LIST1 (f, t, __VA_ARGS__, (), 0))
733

734
// PYBIND11_NUMPY_DTYPE(Type, field1, field2, ...) expands to a call of
// npy_format_descriptor<Type>::register_dtype with a braced list holding one
// PYBIND11_FIELD_DESCRIPTOR(Type, fieldN) entry per listed field.
// Every line of the definition except the last must end in a backslash so the
// whole call stays inside the macro body.
#define PYBIND11_NUMPY_DTYPE(Type, ...) \
    ::pybind11::detail::npy_format_descriptor<Type>::register_dtype \
        ({PYBIND11_MAP_LIST (PYBIND11_FIELD_DESCRIPTOR, Type, __VA_ARGS__)})
737

738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
// Iterator type used to walk the elements of a buffer: a plain pointer to T
// (std::add_pointer also strips a reference from T before adding the `*`).
template <typename T>
using array_iterator = typename std::add_pointer<T>::type;

// Returns an iterator to the start of `buffer`, viewing buffer.ptr as T*.
template <typename T>
array_iterator<T> array_begin(const buffer_info &buffer) {
    T *first = reinterpret_cast<T *>(buffer.ptr);
    return array_iterator<T>(first);
}

// Returns the past-the-end iterator of `buffer`: buffer.ptr viewed as T*,
// advanced by buffer.size elements.
template <typename T>
array_iterator<T> array_end(const buffer_info &buffer) {
    T *first = reinterpret_cast<T *>(buffer.ptr);
    return array_iterator<T>(first + buffer.size);
}

// Byte-level cursor into a single buffer.
//
// The constructor turns the per-dimension byte strides of the buffer into
// per-dimension "carry" steps: increment(dim) is meant to be called when the
// index of `dim` advances by one while every later dimension has just reached
// its last position (shape - 1) and is reset to 0. The step stored for `dim`
// is therefore the stride of `dim` minus the distance already travelled
// through the later dimensions. For the last dimension it is just its stride.
class common_iterator {
public:
    using container_type = std::vector<size_t>;
    using value_type = container_type::value_type;
    using size_type = container_type::size_type;

    // Null cursor with no dimensions.
    common_iterator() : p_ptr(nullptr), m_strides() {}

    // ptr:     start of the buffer data.
    // strides: byte stride of each dimension.
    // shape:   extent of each dimension; indexed with the same indices as
    //          `strides`, entries 1 .. strides.size() - 1 are read.
    common_iterator(void* ptr, const container_type& strides, const std::vector<size_t>& shape)
        : p_ptr(reinterpret_cast<char*>(ptr)), m_strides(strides.size()) {
        // Rank 0: nothing to step through. Without this guard back() and the
        // `size() - 1` loop start below would operate on an empty vector.
        if (m_strides.empty())
            return;

        m_strides.back() = strides.back();
        for (size_type i = m_strides.size() - 1; i != 0; --i) {
            const size_type j = i - 1;
            // Unsigned arithmetic on purpose: an intermediate result may wrap
            // around, which stands for a backwards step.
            m_strides[j] = strides[j] + m_strides[i] - strides[i] * shape[i];
        }
    }

    // Advances the cursor by the carry step of dimension `dim`.
    void increment(size_type dim) {
        p_ptr += m_strides[dim];
    }

    // Current position inside the buffer.
    void* data() const {
        return p_ptr;
    }

private:
    char* p_ptr;
    container_type m_strides;
};

782
template <size_t N> class multi_array_iterator {
783
784
785
public:
    using container_type = std::vector<size_t>;

786
787
788
789
790
    multi_array_iterator(const std::array<buffer_info, N> &buffers,
                         const std::vector<size_t> &shape)
        : m_shape(shape.size()), m_index(shape.size(), 0),
          m_common_iterator() {

791
        // Manual copy to avoid conversion warning if using std::copy
792
        for (size_t i = 0; i < shape.size(); ++i)
793
794
795
            m_shape[i] = static_cast<container_type::value_type>(shape[i]);

        container_type strides(shape.size());
796
        for (size_t i = 0; i < N; ++i)
797
798
799
800
801
802
803
804
805
            init_common_iterator(buffers[i], shape, m_common_iterator[i], strides);
    }

    multi_array_iterator& operator++() {
        for (size_t j = m_index.size(); j != 0; --j) {
            size_t i = j - 1;
            if (++m_index[i] != m_shape[i]) {
                increment_common_iterator(i);
                break;
806
            } else {
807
808
809
810
811
812
                m_index[i] = 0;
            }
        }
        return *this;
    }

813
    template <size_t K, class T> const T& data() const {
814
815
816
817
818
819
820
        return *reinterpret_cast<T*>(m_common_iterator[K].data());
    }

private:

    using common_iter = common_iterator;

821
822
823
    void init_common_iterator(const buffer_info &buffer,
                              const std::vector<size_t> &shape,
                              common_iter &iterator, container_type &strides) {
824
825
826
827
828
829
830
        auto buffer_shape_iter = buffer.shape.rbegin();
        auto buffer_strides_iter = buffer.strides.rbegin();
        auto shape_iter = shape.rbegin();
        auto strides_iter = strides.rbegin();

        while (buffer_shape_iter != buffer.shape.rend()) {
            if (*shape_iter == *buffer_shape_iter)
831
                *strides_iter = static_cast<size_t>(*buffer_strides_iter);
832
833
834
835
836
837
838
839
840
841
842
843
844
845
            else
                *strides_iter = 0;

            ++buffer_shape_iter;
            ++buffer_strides_iter;
            ++shape_iter;
            ++strides_iter;
        }

        std::fill(strides_iter, strides.rend(), 0);
        iterator = common_iter(buffer.ptr, strides, shape);
    }

    void increment_common_iterator(size_t dim) {
846
        for (auto &iter : m_common_iterator)
847
848
849
850
851
852
853
854
855
            iter.increment(dim);
    }

    container_type m_shape;
    container_type m_index;
    std::array<common_iter, N> m_common_iterator;
};

template <size_t N>
856
857
bool broadcast(const std::array<buffer_info, N>& buffers, size_t& ndim, std::vector<size_t>& shape) {
    ndim = std::accumulate(buffers.begin(), buffers.end(), size_t(0), [](size_t res, const buffer_info& buf) {
858
859
860
        return std::max(res, buf.ndim);
    });

861
    shape = std::vector<size_t>(ndim, 1);
862
863
864
865
    bool trivial_broadcast = true;
    for (size_t i = 0; i < N; ++i) {
        auto res_iter = shape.rbegin();
        bool i_trivial_broadcast = (buffers[i].size == 1) || (buffers[i].ndim == ndim);
866
867
868
869
        for (auto shape_iter = buffers[i].shape.rbegin();
             shape_iter != buffers[i].shape.rend(); ++shape_iter, ++res_iter) {

            if (*res_iter == 1)
870
                *res_iter = *shape_iter;
871
            else if ((*shape_iter != 1) && (*res_iter != *shape_iter))
872
                pybind11_fail("pybind11::vectorize: incompatible size/dimension of inputs!");
873

874
875
876
877
878
879
880
            i_trivial_broadcast = i_trivial_broadcast && (*res_iter == *shape_iter);
        }
        trivial_broadcast = trivial_broadcast && i_trivial_broadcast;
    }
    return trivial_broadcast;
}

881
882
883
884
template <typename Func, typename Return, typename... Args>
struct vectorize_helper {
    typename std::remove_reference<Func>::type f;

885
886
    template <typename T>
    vectorize_helper(T&&f) : f(std::forward<T>(f)) { }
Wenzel Jakob's avatar
Wenzel Jakob committed
887

888
    object operator()(array_t<Args, array::c_style | array::forcecast>... args) {
889
890
        return run(args..., typename make_index_sequence<sizeof...(Args)>::type());
    }
Wenzel Jakob's avatar
Wenzel Jakob committed
891

892
    template <size_t ... Index> object run(array_t<Args, array::c_style | array::forcecast>&... args, index_sequence<Index...> index) {
Wenzel Jakob's avatar
Wenzel Jakob committed
893
        /* Request buffers from all parameters */
894
        const size_t N = sizeof...(Args);
895

Wenzel Jakob's avatar
Wenzel Jakob committed
896
897
898
        std::array<buffer_info, N> buffers {{ args.request()... }};

        /* Determine dimensions parameters of output array */
899
        size_t ndim = 0;
900
901
        std::vector<size_t> shape(0);
        bool trivial_broadcast = broadcast(buffers, ndim, shape);
902

903
        size_t size = 1;
Wenzel Jakob's avatar
Wenzel Jakob committed
904
905
        std::vector<size_t> strides(ndim);
        if (ndim > 0) {
906
            strides[ndim-1] = sizeof(Return);
907
            for (size_t i = ndim - 1; i > 0; --i) {
908
909
910
911
                strides[i - 1] = strides[i] * shape[i];
                size *= shape[i];
            }
            size *= shape[0];
Wenzel Jakob's avatar
Wenzel Jakob committed
912
913
        }

914
        if (size == 1)
915
            return cast(f(*((Args *) buffers[Index].ptr)...));
Wenzel Jakob's avatar
Wenzel Jakob committed
916

917
918
919
        array_t<Return> result(shape, strides);
        auto buf = result.request();
        auto output = (Return *) buf.ptr;
920

921
        if (trivial_broadcast) {
922
            /* Call the function */
923
            for (size_t i = 0; i < size; ++i) {
924
                output[i] = f((buffers[Index].size == 1
925
926
                               ? *((Args *) buffers[Index].ptr)
                               : ((Args *) buffers[Index].ptr)[i])...);
927
            }
928
        } else {
929
930
            apply_broadcast<N, Index...>(buffers, buf, index);
        }
931
932

        return result;
933
    }
934
935

    template <size_t N, size_t... Index>
936
937
    void apply_broadcast(const std::array<buffer_info, N> &buffers,
                         buffer_info &output, index_sequence<Index...>) {
938
939
940
941
942
943
        using input_iterator = multi_array_iterator<N>;
        using output_iterator = array_iterator<Return>;

        input_iterator input_iter(buffers, output.shape);
        output_iterator output_end = array_end<Return>(output);

944
945
        for (output_iterator iter = array_begin<Return>(output);
             iter != output_end; ++iter, ++input_iter) {
946
947
948
            *iter = f((input_iter.template data<Index, Args>())...);
        }
    }
949
950
};

951
// Descriptor text for array_t<T, Flags>: "numpy.ndarray[", followed by the
// name reported by type_caster<T>, followed by "]".
template <typename T, int Flags>
struct handle_type_name<array_t<T, Flags>> {
    static PYBIND11_DESCR name() {
        return _("numpy.ndarray[") + type_caster<T>::name() + _("]");
    }
};

955
NAMESPACE_END(detail)
Wenzel Jakob's avatar
Wenzel Jakob committed
956

957
958
959
// Builds a vectorize_helper around `f`. The second, unnamed argument is never
// read; its function-pointer type only supplies Return and Args.
template <typename Func, typename Return, typename... Args>
detail::vectorize_helper<Func, Return, Args...> vectorize(const Func &f, Return (*) (Args ...)) {
    using helper = detail::vectorize_helper<Func, Return, Args...>;
    return helper(f);
}

962
963
964
// Overload for plain function pointers: the pointer serves both as the
// callable and as the carrier of the signature.
template <typename Return, typename... Args>
detail::vectorize_helper<Return (*) (Args ...), Return, Args...> vectorize(Return (*f) (Args ...)) {
    using function_ptr = Return (*) (Args ...);
    return vectorize<function_ptr, Return, Args...>(f, f);
}

967
968
969
970
971
// Overload for objects with an operator() (e.g. lambdas): the call signature
// is obtained by passing the type of &Func::operator() through
// detail::remove_class, and a null pointer of that function type is handed to
// the two-argument overload purely to convey Return and Args.
template <typename Func>
auto vectorize(Func &&f) -> decltype(
        vectorize(std::forward<Func>(f), (typename detail::remove_class<decltype(&std::remove_reference<Func>::type::operator())>::type *) nullptr)) {
    using call_operator = decltype(&std::remove_reference<Func>::type::operator());
    using signature = typename detail::remove_class<call_operator>::type;
    return vectorize(std::forward<Func>(f), static_cast<signature *>(nullptr));
}

974
NAMESPACE_END(pybind11)
Wenzel Jakob's avatar
Wenzel Jakob committed
975
976
977
978

#if defined(_MSC_VER)
#pragma warning(pop)
#endif