numpy.h 54.1 KB
Newer Older
Wenzel Jakob's avatar
Wenzel Jakob committed
1
/*
    pybind11/numpy.h: Basic NumPy support, vectorize() wrapper

    Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>

    All rights reserved. Use of this source code is governed by a
    BSD-style license that can be found in the LICENSE file.
*/

#pragma once

12
13
#include "pybind11.h"
#include "complex.h"
14
15
#include <numeric>
#include <algorithm>
16
#include <array>
17
#include <cstdlib>
18
#include <cstring>
19
#include <sstream>
20
#include <string>
21
#include <initializer_list>
22
#include <functional>
23
#include <utility>
24
#include <typeindex>
25

Wenzel Jakob's avatar
Wenzel Jakob committed
26
#if defined(_MSC_VER)
27
28
#  pragma warning(push)
#  pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
Wenzel Jakob's avatar
Wenzel Jakob committed
29
30
#endif

31
32
33
34
35
36
/* This will be true on all flat address space platforms and allows us to reduce the
   whole npy_intp / size_t / Py_intptr_t business down to just size_t for all size
   and dimension types (e.g. shape, strides, indexing), instead of inflicting this
   upon the library user. */
static_assert(sizeof(size_t) == sizeof(Py_intptr_t), "size_t != Py_intptr_t");

37
NAMESPACE_BEGIN(pybind11)
38
39
40

class array; // Forward declaration

41
NAMESPACE_BEGIN(detail)
42
template <typename type, typename SFINAE = void> struct npy_format_descriptor;
Wenzel Jakob's avatar
Wenzel Jakob committed
43

44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
/// Layout-compatible stand-in for NumPy's array descriptor (dtype) object. Descriptor fields are
/// read through this struct (see array_descriptor_proxy(), which reinterpret_casts a PyObject*)
/// rather than through NumPy's C headers, so the member order and types below must match NumPy's
/// own descriptor struct exactly. NOTE(review): verify against the NumPy version being targeted.
struct PyArrayDescr_Proxy {
    PyObject_HEAD
    PyObject *typeobj;
    char kind;       // single-character type code; returned by dtype::kind()
    char type;
    char byteorder;
    char flags;
    int type_num;
    int elsize;      // element size in bytes; returned by dtype::itemsize()
    int alignment;
    char *subarray;
    PyObject *fields;
    PyObject *names; // non-null for structured dtypes; tested by dtype::has_fields()
};

/// Layout-compatible stand-in for NumPy's ndarray object, accessed via array_proxy().
/// The member order and types must match NumPy's array struct exactly.
/// NOTE(review): verify against the NumPy version being targeted.
struct PyArray_Proxy {
    PyObject_HEAD
    char *data;          // start of the element buffer (array::data())
    int nd;              // number of dimensions (array::ndim())
    ssize_t *dimensions; // per-axis extents; reinterpreted as size_t by array::shape()
    ssize_t *strides;    // per-axis byte strides; reinterpreted as size_t by array::strides()
    PyObject *base;      // object returned by array::base()
    PyObject *descr;     // dtype descriptor (array::dtype())
    int flags;           // NPY_ARRAY_* flag bits (array::flags(), check_flags())
};

70
71
72
73
74
75
76
77
/// Layout-compatible stand-in for NumPy's void (structured) scalar object. It is not referenced in
/// the visible part of this header. NOTE(review): presumably used by structured-dtype casting code
/// elsewhere — confirm. The member order and types must match NumPy's struct exactly.
struct PyVoidScalarObject_Proxy {
    PyObject_VAR_HEAD
    char *obval;               // NOTE(review): presumably the scalar's raw value bytes — confirm
    PyArrayDescr_Proxy *descr; // descriptor of the scalar's dtype
    int flags;
    PyObject *base;
};

78
79
80
81
82
83
84
85
/// Registration record for a C++ type that has an associated NumPy dtype; stored as the value
/// type of numpy_internals::registered_dtypes (keyed by the C++ type).
struct numpy_type_info {
    PyObject* dtype_ptr;    // the registered dtype object; NOTE(review): ownership not shown here — confirm
    std::string format_str; // format string for the type; NOTE(review): presumably a PEP 3118 format — confirm
};

struct numpy_internals {
    std::unordered_map<std::type_index, numpy_type_info> registered_dtypes;

86
87
    numpy_type_info *get_type_info(const std::type_info& tinfo, bool throw_if_missing = true) {
        auto it = registered_dtypes.find(std::type_index(tinfo));
88
89
90
        if (it != registered_dtypes.end())
            return &(it->second);
        if (throw_if_missing)
91
            pybind11_fail(std::string("NumPy type info missing for ") + tinfo.name());
92
93
        return nullptr;
    }
94
95
96
97

    template<typename T> numpy_type_info *get_type_info(bool throw_if_missing = true) {
        return get_type_info(typeid(typename std::remove_cv<T>::type), throw_if_missing);
    }
98
99
};

Ivan Smirnov's avatar
Ivan Smirnov committed
100
101
/// Points `ptr` at the numpy_internals instance stored in pybind11's shared data under the key
/// "_numpy_internals", creating it there if it does not exist yet.
/// Marked PYBIND11_NOINLINE; presumably to keep this one-time path out of get_numpy_internals().
inline PYBIND11_NOINLINE void load_numpy_internals(numpy_internals* &ptr) {
    ptr = &get_or_create_shared_data<numpy_internals>("_numpy_internals");
}

/// Returns the shared numpy_internals registry. The pointer is resolved via
/// load_numpy_internals() on the first call and cached in a function-local static afterwards.
inline numpy_internals& get_numpy_internals() {
    static numpy_internals* ptr = nullptr;
    if (!ptr)
        load_numpy_internals(ptr);
    return *ptr;
}

111
112
struct npy_api {
    enum constants {
113
114
        NPY_ARRAY_C_CONTIGUOUS_ = 0x0001,
        NPY_ARRAY_F_CONTIGUOUS_ = 0x0002,
115
        NPY_ARRAY_OWNDATA_ = 0x0004,
116
        NPY_ARRAY_FORCECAST_ = 0x0010,
117
        NPY_ARRAY_ENSUREARRAY_ = 0x0040,
118
119
        NPY_ARRAY_ALIGNED_ = 0x0100,
        NPY_ARRAY_WRITEABLE_ = 0x0400,
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
        NPY_BOOL_ = 0,
        NPY_BYTE_, NPY_UBYTE_,
        NPY_SHORT_, NPY_USHORT_,
        NPY_INT_, NPY_UINT_,
        NPY_LONG_, NPY_ULONG_,
        NPY_LONGLONG_, NPY_ULONGLONG_,
        NPY_FLOAT_, NPY_DOUBLE_, NPY_LONGDOUBLE_,
        NPY_CFLOAT_, NPY_CDOUBLE_, NPY_CLONGDOUBLE_,
        NPY_OBJECT_ = 17,
        NPY_STRING_, NPY_UNICODE_, NPY_VOID_
    };

    static npy_api& get() {
        static npy_api api = lookup();
        return api;
    }

137
138
139
140
141
142
    bool PyArray_Check_(PyObject *obj) const {
        return (bool) PyObject_TypeCheck(obj, PyArray_Type_);
    }
    bool PyArrayDescr_Check_(PyObject *obj) const {
        return (bool) PyObject_TypeCheck(obj, PyArrayDescr_Type_);
    }
143
144
145
146
147
148
149
150

    PyObject *(*PyArray_DescrFromType_)(int);
    PyObject *(*PyArray_NewFromDescr_)
        (PyTypeObject *, PyObject *, int, Py_intptr_t *,
         Py_intptr_t *, void *, int, PyObject *);
    PyObject *(*PyArray_DescrNewFromType_)(int);
    PyObject *(*PyArray_NewCopy_)(PyObject *, int);
    PyTypeObject *PyArray_Type_;
151
    PyTypeObject *PyVoidArrType_Type_;
152
    PyTypeObject *PyArrayDescr_Type_;
153
    PyObject *(*PyArray_DescrFromScalar_)(PyObject *);
154
155
156
157
158
    PyObject *(*PyArray_FromAny_) (PyObject *, PyObject *, int, int, int, PyObject *);
    int (*PyArray_DescrConverter_) (PyObject *, PyObject **);
    bool (*PyArray_EquivTypes_) (PyObject *, PyObject *);
    int (*PyArray_GetArrayParamsFromObject_)(PyObject *, PyObject *, char, PyObject **, int *,
                                             Py_ssize_t *, PyObject **, PyObject *);
159
    PyObject *(*PyArray_Squeeze_)(PyObject *);
Jason Rhinelander's avatar
Jason Rhinelander committed
160
    int (*PyArray_SetBaseObject_)(PyObject *, PyObject *);
161
162
163
private:
    enum functions {
        API_PyArray_Type = 2,
164
        API_PyArrayDescr_Type = 3,
165
        API_PyVoidArrType_Type = 39,
166
        API_PyArray_DescrFromType = 45,
167
        API_PyArray_DescrFromScalar = 57,
168
169
170
171
172
173
174
        API_PyArray_FromAny = 69,
        API_PyArray_NewCopy = 85,
        API_PyArray_NewFromDescr = 94,
        API_PyArray_DescrNewFromType = 9,
        API_PyArray_DescrConverter = 174,
        API_PyArray_EquivTypes = 182,
        API_PyArray_GetArrayParamsFromObject = 278,
Jason Rhinelander's avatar
Jason Rhinelander committed
175
176
        API_PyArray_Squeeze = 136,
        API_PyArray_SetBaseObject = 282
177
178
179
180
    };

    static npy_api lookup() {
        module m = module::import("numpy.core.multiarray");
181
        auto c = m.attr("_ARRAY_API");
182
#if PY_MAJOR_VERSION >= 3
183
        void **api_ptr = (void **) PyCapsule_GetPointer(c.ptr(), NULL);
184
#else
185
        void **api_ptr = (void **) PyCObject_AsVoidPtr(c.ptr());
186
#endif
187
        npy_api api;
188
#define DECL_NPY_API(Func) api.Func##_ = (decltype(api.Func##_)) api_ptr[API_##Func];
189
        DECL_NPY_API(PyArray_Type);
190
        DECL_NPY_API(PyVoidArrType_Type);
191
        DECL_NPY_API(PyArrayDescr_Type);
192
        DECL_NPY_API(PyArray_DescrFromType);
193
        DECL_NPY_API(PyArray_DescrFromScalar);
194
195
196
197
198
199
200
        DECL_NPY_API(PyArray_FromAny);
        DECL_NPY_API(PyArray_NewCopy);
        DECL_NPY_API(PyArray_NewFromDescr);
        DECL_NPY_API(PyArray_DescrNewFromType);
        DECL_NPY_API(PyArray_DescrConverter);
        DECL_NPY_API(PyArray_EquivTypes);
        DECL_NPY_API(PyArray_GetArrayParamsFromObject);
201
        DECL_NPY_API(PyArray_Squeeze);
Jason Rhinelander's avatar
Jason Rhinelander committed
202
        DECL_NPY_API(PyArray_SetBaseObject);
203
#undef DECL_NPY_API
204
205
206
        return api;
    }
};
Wenzel Jakob's avatar
Wenzel Jakob committed
207

208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
/// Views an ndarray object's memory through the PyArray_Proxy layout. No type check is made;
/// the caller must pass a pointer to a NumPy array object.
inline PyArray_Proxy* array_proxy(void* ptr) {
    // void* -> object pointer: static_cast yields the same address as reinterpret_cast.
    return static_cast<PyArray_Proxy *>(ptr);
}

/// Read-only counterpart of array_proxy(void*): views an ndarray through PyArray_Proxy, unchecked.
inline const PyArray_Proxy* array_proxy(const void* ptr) {
    // const void* -> const object pointer: same address as the reinterpret_cast form.
    return static_cast<const PyArray_Proxy *>(ptr);
}

/// Views a dtype (descriptor) object through the PyArrayDescr_Proxy layout. No type check is
/// made; the caller must pass a NumPy descriptor object.
inline PyArrayDescr_Proxy* array_descriptor_proxy(PyObject* ptr) {
    auto *descr = reinterpret_cast<PyArrayDescr_Proxy *>(ptr);
    return descr;
}

/// Read-only counterpart of array_descriptor_proxy(PyObject*); unchecked.
inline const PyArrayDescr_Proxy* array_descriptor_proxy(const PyObject* ptr) {
    const auto *descr = reinterpret_cast<const PyArrayDescr_Proxy *>(ptr);
    return descr;
}

/// Returns true when every bit of `flag` is set in the flags word of the ndarray at `ptr`.
inline bool check_flags(const void* ptr, int flag) {
    const int set_bits = array_proxy(ptr)->flags & flag;
    return set_bits == flag;
}

228
229
230
231
232
233
234
235
236
237
/// Trait: true iff T is a specialization of std::array.
template <typename T>
struct is_std_array : std::integral_constant<bool, false> { };
template <typename Elem, size_t N>
struct is_std_array<std::array<Elem, N>> : std::integral_constant<bool, true> { };

/// Trait: true iff T is a specialization of std::complex.
template <typename T>
struct is_complex : std::integral_constant<bool, false> { };
template <typename Real>
struct is_complex<std::complex<Real>> : std::integral_constant<bool, true> { };

/// Trait: true for "POD struct" types, meaning POD types that are not a reference, a C array,
/// a std::array, an arithmetic type, a std::complex, or an enum.
template <typename T> using is_pod_struct = all_of<
    // elements are accessed directly in memory, so a POD type is required
    std::is_pod<T>,
    satisfies_none_of<T,
        std::is_reference, std::is_array, is_std_array,
        std::is_arithmetic, is_complex, std::is_enum>
>;

238
239
240
241
242
243
244
/// Recursion terminator: once no indices remain, they contribute no further offset.
template <size_t Dim = 0, typename Strides> size_t byte_offset_unsafe(const Strides &) { return 0; }

/// Byte offset of the element addressed by (i, index...), where `i` indexes axis `Dim` and
/// `strides[k]` is the byte stride of axis k. Performs no bounds or dimensionality checks.
template <size_t Dim = 0, typename Strides, typename... Ix>
size_t byte_offset_unsafe(const Strides &strides, size_t i, Ix... index) {
    const size_t this_axis = i * strides[Dim];
    const size_t later_axes = byte_offset_unsafe<Dim + 1>(strides, index...);
    return this_axis + later_axes;
}

/** Proxy class providing unsafe, unchecked const access to array data.  This is constructed through
245
246
 * the `unchecked<T, N>()` method of `array` or the `unchecked<N>()` method of `array_t<T>`.  `Dims`
 * will be -1 for dimensions determined at runtime.
247
 */
248
template <typename T, ssize_t Dims>
249
250
class unchecked_reference {
protected:
251
    static constexpr bool Dynamic = Dims < 0;
252
253
    const unsigned char *data_;
    // Storing the shape & strides in local variables (i.e. these arrays) allows the compiler to
254
255
256
257
    // make large performance gains on big, nested loops, but requires compile-time dimensions
    conditional_t<Dynamic, const size_t *, std::array<size_t, (size_t) Dims>>
        shape_, strides_;
    const size_t dims_;
258
259

    friend class pybind11::array;
260
261
262
263
264
    // Constructor for compile-time dimensions:
    template <bool Dyn = Dynamic>
    unchecked_reference(const void *data, const size_t *shape, const size_t *strides, enable_if_t<!Dyn, size_t>)
    : data_{reinterpret_cast<const unsigned char *>(data)}, dims_{Dims} {
        for (size_t i = 0; i < dims_; i++) {
265
266
267
268
            shape_[i] = shape[i];
            strides_[i] = strides[i];
        }
    }
269
270
271
272
    // Constructor for runtime dimensions:
    template <bool Dyn = Dynamic>
    unchecked_reference(const void *data, const size_t *shape, const size_t *strides, enable_if_t<Dyn, size_t> dims)
    : data_{reinterpret_cast<const unsigned char *>(data)}, shape_{shape}, strides_{strides}, dims_{dims} {}
273
274

public:
275
276
277
    /** Unchecked const reference access to data at the given indices.  For a compile-time known
     * number of dimensions, this requires the correct number of arguments; for run-time
     * dimensionality, this is not checked (and so is up to the caller to use safely).
278
     */
279
280
281
282
    template <typename... Ix> const T &operator()(Ix... index) const {
        static_assert(sizeof...(Ix) == Dims || Dynamic,
                "Invalid number of indices for unchecked array reference");
        return *reinterpret_cast<const T *>(data_ + byte_offset_unsafe(strides_, size_t(index)...));
283
284
285
286
    }
    /** Unchecked const reference access to data; this operator only participates if the reference
     * is to a 1-dimensional array.  When present, this is exactly equivalent to `obj(index)`.
     */
287
    template <size_t D = Dims, typename = enable_if_t<D == 1 || Dynamic>>
288
289
    const T &operator[](size_t index) const { return operator()(index); }

290
291
292
293
294
295
    /// Pointer access to the data at the given indices.
    template <typename... Ix> const T *data(Ix... ix) const { return &operator()(size_t(ix)...); }

    /// Returns the item size, i.e. sizeof(T)
    constexpr static size_t itemsize() { return sizeof(T); }

296
297
298
299
    /// Returns the shape (i.e. size) of dimension `dim`
    size_t shape(size_t dim) const { return shape_[dim]; }

    /// Returns the number of dimensions of the array
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
    size_t ndim() const { return dims_; }

    /// Returns the total number of elements in the referenced array, i.e. the product of the shapes
    template <bool Dyn = Dynamic>
    enable_if_t<!Dyn, size_t> size() const {
        return std::accumulate(shape_.begin(), shape_.end(), (size_t) 1, std::multiplies<size_t>());
    }
    template <bool Dyn = Dynamic>
    enable_if_t<Dyn, size_t> size() const {
        return std::accumulate(shape_, shape_ + ndim(), (size_t) 1, std::multiplies<size_t>());
    }

    /// Returns the total number of bytes used by the referenced data.  Note that the actual span in
    /// memory may be larger if the referenced array has non-contiguous strides (e.g. for a slice).
    size_t nbytes() const {
        return size() * itemsize();
    }
317
318
};

319
template <typename T, ssize_t Dims>
320
321
322
323
class unchecked_mutable_reference : public unchecked_reference<T, Dims> {
    friend class pybind11::array;
    using ConstBase = unchecked_reference<T, Dims>;
    using ConstBase::ConstBase;
324
    using ConstBase::Dynamic;
325
326
327
public:
    /// Mutable, unchecked access to data at the given indices.
    template <typename... Ix> T& operator()(Ix... index) {
328
329
        static_assert(sizeof...(Ix) == Dims || Dynamic,
                "Invalid number of indices for unchecked array reference");
330
331
332
        return const_cast<T &>(ConstBase::operator()(index...));
    }
    /** Mutable, unchecked access data at the given index; this operator only participates if the
333
334
     * reference is to a 1-dimensional array (or has runtime dimensions).  When present, this is
     * exactly equivalent to `obj(index)`.
335
     */
336
    template <size_t D = Dims, typename = enable_if_t<D == 1 || Dynamic>>
337
    T &operator[](size_t index) { return operator()(index); }
338
339
340

    /// Mutable pointer access to the data at the given indices.
    template <typename... Ix> T *mutable_data(Ix... ix) { return &operator()(size_t(ix)...); }
341
342
};

343
template <typename T, ssize_t Dim>
344
struct type_caster<unchecked_reference<T, Dim>> {
345
    static_assert(Dim == 0 && Dim > 0 /* always fail */, "unchecked array proxy object is not castable");
346
};
347
template <typename T, ssize_t Dim>
348
349
struct type_caster<unchecked_mutable_reference<T, Dim>> : type_caster<unchecked_reference<T, Dim>> {};

350
NAMESPACE_END(detail)
351

352
class dtype : public object {
353
public:
354
    PYBIND11_OBJECT_DEFAULT(dtype, object, detail::npy_api::get().PyArrayDescr_Check_);
Wenzel Jakob's avatar
Wenzel Jakob committed
355

356
    explicit dtype(const buffer_info &info) {
357
        dtype descr(_dtype_from_pep3118()(PYBIND11_STR_TYPE(info.format)));
358
359
        // If info.itemsize == 0, use the value calculated from the format string
        m_ptr = descr.strip_padding(info.itemsize ? info.itemsize : descr.itemsize()).release().ptr();
360
    }
361

362
    explicit dtype(const std::string &format) {
363
        m_ptr = from_args(pybind11::str(format)).release().ptr();
Wenzel Jakob's avatar
Wenzel Jakob committed
364
365
    }

366
    dtype(const char *format) : dtype(std::string(format)) { }
367

368
369
370
371
372
    dtype(list names, list formats, list offsets, size_t itemsize) {
        dict args;
        args["names"] = names;
        args["formats"] = formats;
        args["offsets"] = offsets;
373
        args["itemsize"] = pybind11::int_(itemsize);
374
375
376
        m_ptr = from_args(args).release().ptr();
    }

Ivan Smirnov's avatar
Ivan Smirnov committed
377
    /// This is essentially the same as calling numpy.dtype(args) in Python.
378
379
380
    static dtype from_args(object args) {
        PyObject *ptr = nullptr;
        if (!detail::npy_api::get().PyArray_DescrConverter_(args.release().ptr(), &ptr) || !ptr)
381
            throw error_already_set();
382
        return reinterpret_steal<dtype>(ptr);
383
    }
384

Ivan Smirnov's avatar
Ivan Smirnov committed
385
    /// Return dtype associated with a C++ type.
386
    template <typename T> static dtype of() {
387
        return detail::npy_format_descriptor<typename std::remove_cv<T>::type>::dtype();
388
    }
389

Ivan Smirnov's avatar
Ivan Smirnov committed
390
    /// Size of the data type in bytes.
391
    size_t itemsize() const {
392
        return (size_t) detail::array_descriptor_proxy(m_ptr)->elsize;
Wenzel Jakob's avatar
Wenzel Jakob committed
393
394
    }

Ivan Smirnov's avatar
Ivan Smirnov committed
395
    /// Returns true for structured data types.
396
    bool has_fields() const {
397
        return detail::array_descriptor_proxy(m_ptr)->names != nullptr;
398
399
    }

Ivan Smirnov's avatar
Ivan Smirnov committed
400
    /// Single-character type code.
401
    char kind() const {
402
        return detail::array_descriptor_proxy(m_ptr)->kind;
403
404
405
    }

private:
406
407
408
    static object _dtype_from_pep3118() {
        static PyObject *obj = module::import("numpy.core._internal")
            .attr("_dtype_from_pep3118").cast<object>().release().ptr();
409
        return reinterpret_borrow<object>(obj);
410
    }
411

412
    dtype strip_padding(size_t itemsize) {
413
414
        // Recursively strip all void fields with empty names that are generated for
        // padding fields (as of NumPy v1.11).
415
        if (!has_fields())
416
            return *this;
417

418
        struct field_descr { PYBIND11_STR_TYPE name; object format; pybind11::int_ offset; };
419
420
        std::vector<field_descr> field_descriptors;

421
        for (auto field : attr("fields").attr("items")()) {
422
            auto spec = field.cast<tuple>();
423
            auto name = spec[0].cast<pybind11::str>();
424
            auto format = spec[1].cast<tuple>()[0].cast<dtype>();
425
            auto offset = spec[1].cast<tuple>()[1].cast<pybind11::int_>();
426
            if (!len(name) && format.kind() == 'V')
427
                continue;
428
            field_descriptors.push_back({(PYBIND11_STR_TYPE) name, format.strip_padding(format.itemsize()), offset});
429
430
431
432
        }

        std::sort(field_descriptors.begin(), field_descriptors.end(),
                  [](const field_descr& a, const field_descr& b) {
433
                      return a.offset.cast<int>() < b.offset.cast<int>();
434
435
436
437
                  });

        list names, formats, offsets;
        for (auto& descr : field_descriptors) {
438
439
440
            names.append(descr.name);
            formats.append(descr.format);
            offsets.append(descr.offset);
441
        }
442
        return dtype(names, formats, offsets, itemsize);
443
444
    }
};
445

446
447
class array : public buffer {
public:
448
    PYBIND11_OBJECT_CVT(array, buffer, detail::npy_api::get().PyArray_Check_, raw_array)
449
450

    enum {
451
452
        c_style = detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_,
        f_style = detail::npy_api::NPY_ARRAY_F_CONTIGUOUS_,
453
454
455
        forcecast = detail::npy_api::NPY_ARRAY_FORCECAST_
    };

456
457
    array() : array(0, static_cast<const double *>(nullptr)) {}

458
459
460
461
462
463
464
465
    using ShapeContainer = detail::any_container<Py_intptr_t>;
    using StridesContainer = detail::any_container<Py_intptr_t>;

    // Constructs an array taking shape/strides from arbitrary container types
    array(const pybind11::dtype &dt, ShapeContainer shape, StridesContainer strides,
          const void *ptr = nullptr, handle base = handle()) {

        if (strides->empty())
466
            *strides = default_strides(*shape, dt.itemsize());
467
468
469

        auto ndim = shape->size();
        if (ndim != strides->size())
470
471
            pybind11_fail("NumPy: shape ndim doesn't match strides ndim");
        auto descr = dt;
472
473
474

        int flags = 0;
        if (base && ptr) {
475
            if (isinstance<array>(base))
Wenzel Jakob's avatar
Wenzel Jakob committed
476
                /* Copy flags from base (except ownership bit) */
477
                flags = reinterpret_borrow<array>(base).flags() & ~detail::npy_api::NPY_ARRAY_OWNDATA_;
478
479
480
481
482
            else
                /* Writable by default, easy to downgrade later on if needed */
                flags = detail::npy_api::NPY_ARRAY_WRITEABLE_;
        }

483
        auto &api = detail::npy_api::get();
484
        auto tmp = reinterpret_steal<object>(api.PyArray_NewFromDescr_(
485
            api.PyArray_Type_, descr.release().ptr(), (int) ndim, shape->data(), strides->data(),
486
            const_cast<void *>(ptr), flags, nullptr));
487
488
        if (!tmp)
            pybind11_fail("NumPy: unable to create array!");
489
490
        if (ptr) {
            if (base) {
Jason Rhinelander's avatar
Jason Rhinelander committed
491
                api.PyArray_SetBaseObject_(tmp.ptr(), base.inc_ref().ptr());
492
            } else {
493
                tmp = reinterpret_steal<object>(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */));
494
495
            }
        }
496
497
498
        m_ptr = tmp.release().ptr();
    }

499
500
    array(const pybind11::dtype &dt, ShapeContainer shape, const void *ptr = nullptr, handle base = handle())
        : array(dt, std::move(shape), {}, ptr, base) { }
501

502
503
504
505
506
507
    // This constructor is only needed to avoid ambiguity with the deprecated (handle, bool)
    // constructor that comes from PYBIND11_OBJECT_CVT; once that is gone, the above constructor can
    // handle it (because ShapeContainer is implicitly constructible from arithmetic types)
    template <typename T, typename = detail::enable_if_t<std::is_arithmetic<T>::value && !std::is_same<bool, T>::value>>
    array(const pybind11::dtype &dt, T count)
        : array(dt, count, nullptr) { }
508

509
510
511
    template <typename T>
    array(ShapeContainer shape, StridesContainer strides, const T *ptr, handle base = handle())
        : array(pybind11::dtype::of<T>(), std::move(shape), std::move(strides), ptr, base) { }
512

513
    template <typename T>
514
515
    array(ShapeContainer shape, const T *ptr, handle base = handle())
        : array(std::move(shape), {}, ptr, base) { }
516

517
    explicit array(const buffer_info &info)
518
    : array(pybind11::dtype(info), info.shape, info.strides, info.ptr) { }
519

520
521
    /// Array descriptor (dtype)
    pybind11::dtype dtype() const {
522
        return reinterpret_borrow<pybind11::dtype>(detail::array_proxy(m_ptr)->descr);
523
524
525
526
527
528
529
530
531
    }

    /// Total number of elements
    size_t size() const {
        return std::accumulate(shape(), shape() + ndim(), (size_t) 1, std::multiplies<size_t>());
    }

    /// Byte size of a single element
    size_t itemsize() const {
532
        return (size_t) detail::array_descriptor_proxy(detail::array_proxy(m_ptr)->descr)->elsize;
533
534
535
536
537
538
539
540
541
    }

    /// Total number of bytes
    size_t nbytes() const {
        return size() * itemsize();
    }

    /// Number of dimensions
    size_t ndim() const {
542
        return (size_t) detail::array_proxy(m_ptr)->nd;
543
544
    }

545
546
    /// Base object
    object base() const {
547
        return reinterpret_borrow<object>(detail::array_proxy(m_ptr)->base);
548
549
    }

550
551
    /// Dimensions of the array
    const size_t* shape() const {
552
        return reinterpret_cast<const size_t *>(detail::array_proxy(m_ptr)->dimensions);
553
554
555
556
557
    }

    /// Dimension along a given axis
    size_t shape(size_t dim) const {
        if (dim >= ndim())
558
            fail_dim_check(dim, "invalid axis");
559
560
561
562
563
        return shape()[dim];
    }

    /// Strides of the array
    const size_t* strides() const {
564
        return reinterpret_cast<const size_t *>(detail::array_proxy(m_ptr)->strides);
565
566
567
568
569
    }

    /// Stride along a given axis
    size_t strides(size_t dim) const {
        if (dim >= ndim())
570
            fail_dim_check(dim, "invalid axis");
571
572
573
        return strides()[dim];
    }

574
575
    /// Return the NumPy array flags
    int flags() const {
576
        return detail::array_proxy(m_ptr)->flags;
577
578
    }

579
580
    /// If set, the array is writeable (otherwise the buffer is read-only)
    bool writeable() const {
581
        return detail::check_flags(m_ptr, detail::npy_api::NPY_ARRAY_WRITEABLE_);
582
583
584
585
    }

    /// If set, the array owns the data (will be freed when the array is deleted)
    bool owndata() const {
586
        return detail::check_flags(m_ptr, detail::npy_api::NPY_ARRAY_OWNDATA_);
587
588
    }

589
590
    /// Pointer to the contained data. If index is not provided, points to the
    /// beginning of the buffer. May throw if the index would lead to out of bounds access.
591
    template<typename... Ix> const void* data(Ix... index) const {
592
        return static_cast<const void *>(detail::array_proxy(m_ptr)->data + offset_at(index...));
593
594
    }

595
596
597
    /// Mutable pointer to the contained data. If index is not provided, points to the
    /// beginning of the buffer. May throw if the index would lead to out of bounds access.
    /// May throw if the array is not writeable.
598
    template<typename... Ix> void* mutable_data(Ix... index) {
599
        check_writeable();
600
        return static_cast<void *>(detail::array_proxy(m_ptr)->data + offset_at(index...));
601
602
603
604
    }

    /// Byte offset from beginning of the array to a given index (full or partial).
    /// May throw if the index would lead to out of bounds access.
605
    template<typename... Ix> size_t offset_at(Ix... index) const {
606
607
        if (sizeof...(index) > ndim())
            fail_dim_check(sizeof...(index), "too many indices for an array");
608
        return byte_offset(size_t(index)...);
609
610
611
612
613
614
    }

    size_t offset_at() const { return 0; }

    /// Item count from beginning of the array to a given index (full or partial).
    /// May throw if the index would lead to out of bounds access.
615
    template<typename... Ix> size_t index_at(Ix... index) const {
616
        return offset_at(index...) / itemsize();
617
618
    }

619
620
621
622
623
    /** Returns a proxy object that provides access to the array's data without bounds or
     * dimensionality checking.  Will throw if the array is missing the `writeable` flag.  Use with
     * care: the array must not be destroyed or reshaped for the duration of the returned object,
     * and the caller must take care not to access invalid dimensions or dimension indices.
     */
624
625
    template <typename T, ssize_t Dims = -1> detail::unchecked_mutable_reference<T, Dims> mutable_unchecked() {
        if (Dims >= 0 && ndim() != (size_t) Dims)
626
627
            throw std::domain_error("array has incorrect number of dimensions: " + std::to_string(ndim()) +
                    "; expected " + std::to_string(Dims));
628
        return detail::unchecked_mutable_reference<T, Dims>(mutable_data(), shape(), strides(), ndim());
629
630
631
632
633
634
635
636
    }

    /** Returns a proxy object that provides const access to the array's data without bounds or
     * dimensionality checking.  Unlike `mutable_unchecked()`, this does not require that the
     * underlying array have the `writable` flag.  Use with care: the array must not be destroyed or
     * reshaped for the duration of the returned object, and the caller must take care not to access
     * invalid dimensions or dimension indices.
     */
637
638
    template <typename T, ssize_t Dims = -1> detail::unchecked_reference<T, Dims> unchecked() const {
        if (Dims >= 0 && ndim() != (size_t) Dims)
639
640
            throw std::domain_error("array has incorrect number of dimensions: " + std::to_string(ndim()) +
                    "; expected " + std::to_string(Dims));
641
        return detail::unchecked_reference<T, Dims>(data(), shape(), strides(), ndim());
642
643
    }

644
645
646
    /// Return a new view with all of the dimensions of length 1 removed
    array squeeze() {
        auto& api = detail::npy_api::get();
647
        return reinterpret_steal<array>(api.PyArray_Squeeze_(m_ptr));
648
649
    }

650
    /// Ensure that the argument is a NumPy array
651
652
653
654
655
656
    /// In case of an error, nullptr is returned and the Python error is cleared.
    static array ensure(handle h, int ExtraFlags = 0) {
        auto result = reinterpret_steal<array>(raw_array(h.ptr(), ExtraFlags));
        if (!result)
            PyErr_Clear();
        return result;
657
658
    }

659
protected:
660
661
662
663
664
665
666
    template<typename, typename> friend struct detail::npy_format_descriptor;

    void fail_dim_check(size_t dim, const std::string& msg) const {
        throw index_error(msg + ": " + std::to_string(dim) +
                          " (ndim = " + std::to_string(ndim()) + ")");
    }

667
668
    template<typename... Ix> size_t byte_offset(Ix... index) const {
        check_dimensions(index...);
669
        return detail::byte_offset_unsafe(strides(), size_t(index)...);
670
671
    }

672
673
    void check_writeable() const {
        if (!writeable())
674
            throw std::domain_error("array is not writeable");
675
    }
676

677
    static std::vector<Py_intptr_t> default_strides(const std::vector<Py_intptr_t>& shape, size_t itemsize) {
678
        auto ndim = shape.size();
679
        std::vector<Py_intptr_t> strides(ndim);
680
681
682
683
684
685
686
687
        if (ndim) {
            std::fill(strides.begin(), strides.end(), itemsize);
            for (size_t i = 0; i < ndim - 1; i++)
                for (size_t j = 0; j < ndim - 1 - i; j++)
                    strides[j] *= shape[ndim - 1 - i];
        }
        return strides;
    }
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702

    template<typename... Ix> void check_dimensions(Ix... index) const {
        check_dimensions_impl(size_t(0), shape(), size_t(index)...);
    }

    void check_dimensions_impl(size_t, const size_t*) const { }

    template<typename... Ix> void check_dimensions_impl(size_t axis, const size_t* shape, size_t i, Ix... index) const {
        if (i >= *shape) {
            throw index_error(std::string("index ") + std::to_string(i) +
                              " is out of bounds for axis " + std::to_string(axis) +
                              " with size " + std::to_string(*shape));
        }
        check_dimensions_impl(axis + 1, shape + 1, index...);
    }
703
704
705

    /// Create array from any object -- always returns a new reference
    static PyObject *raw_array(PyObject *ptr, int ExtraFlags = 0) {
706
707
        if (ptr == nullptr) {
            PyErr_SetString(PyExc_ValueError, "cannot create a pybind11::array from a nullptr");
708
            return nullptr;
709
        }
710
        return detail::npy_api::get().PyArray_FromAny_(
711
            ptr, nullptr, 0, 0, detail::npy_api::NPY_ARRAY_ENSUREARRAY_ | ExtraFlags, nullptr);
712
    }
Wenzel Jakob's avatar
Wenzel Jakob committed
713
714
};

715
template <typename T, int ExtraFlags = array::forcecast> class array_t : public array {
Wenzel Jakob's avatar
Wenzel Jakob committed
716
public:
717
718
    using value_type = T;

719
    array_t() : array(0, static_cast<const T *>(nullptr)) {}
720
721
    array_t(handle h, borrowed_t) : array(h, borrowed_t{}) { }
    array_t(handle h, stolen_t) : array(h, stolen_t{}) { }
722

723
    PYBIND11_DEPRECATED("Use array_t<T>::ensure() instead")
724
    array_t(handle h, bool is_borrowed) : array(raw_array_t(h.ptr()), stolen_t{}) {
725
726
727
        if (!m_ptr) PyErr_Clear();
        if (!is_borrowed) Py_XDECREF(h.ptr());
    }
728

729
    array_t(const object &o) : array(raw_array_t(o.ptr()), stolen_t{}) {
730
731
        if (!m_ptr) throw error_already_set();
    }
732

733
    explicit array_t(const buffer_info& info) : array(info) { }
734

735
736
    array_t(ShapeContainer shape, StridesContainer strides, const T *ptr = nullptr, handle base = handle())
        : array(std::move(shape), std::move(strides), ptr, base) { }
737

738
739
    explicit array_t(ShapeContainer shape, const T *ptr = nullptr, handle base = handle())
        : array(std::move(shape), ptr, base) { }
740

741
742
    constexpr size_t itemsize() const {
        return sizeof(T);
743
744
    }

745
    template<typename... Ix> size_t index_at(Ix... index) const {
746
747
748
        return offset_at(index...) / itemsize();
    }

749
    template<typename... Ix> const T* data(Ix... index) const {
750
751
752
        return static_cast<const T*>(array::data(index...));
    }

753
    template<typename... Ix> T* mutable_data(Ix... index) {
754
755
756
757
        return static_cast<T*>(array::mutable_data(index...));
    }

    // Reference to element at a given index
758
    template<typename... Ix> const T& at(Ix... index) const {
759
760
        if (sizeof...(index) != ndim())
            fail_dim_check(sizeof...(index), "index dimension mismatch");
761
        return *(static_cast<const T*>(array::data()) + byte_offset(size_t(index)...) / itemsize());
762
763
764
    }

    // Mutable reference to element at a given index
765
    template<typename... Ix> T& mutable_at(Ix... index) {
766
767
        if (sizeof...(index) != ndim())
            fail_dim_check(sizeof...(index), "index dimension mismatch");
768
        return *(static_cast<T*>(array::mutable_data()) + byte_offset(size_t(index)...) / itemsize());
769
    }
770

771
772
773
774
775
    /** Returns a proxy object that provides access to the array's data without bounds or
     * dimensionality checking.  Will throw if the array is missing the `writeable` flag.  Use with
     * care: the array must not be destroyed or reshaped for the duration of the returned object,
     * and the caller must take care not to access invalid dimensions or dimension indices.
     */
776
    template <ssize_t Dims = -1> detail::unchecked_mutable_reference<T, Dims> mutable_unchecked() {
777
778
779
780
781
782
783
784
785
        return array::mutable_unchecked<T, Dims>();
    }

    /** Returns a proxy object that provides const access to the array's data without bounds or
     * dimensionality checking.  Unlike `unchecked()`, this does not require that the underlying
     * array have the `writable` flag.  Use with care: the array must not be destroyed or reshaped
     * for the duration of the returned object, and the caller must take care not to access invalid
     * dimensions or dimension indices.
     */
786
    template <ssize_t Dims = -1> detail::unchecked_reference<T, Dims> unchecked() const {
787
788
789
        return array::unchecked<T, Dims>();
    }

Jason Rhinelander's avatar
Jason Rhinelander committed
790
791
    /// Ensure that the argument is a NumPy array of the correct dtype (and if not, try to convert
    /// it).  In case of an error, nullptr is returned and the Python error is cleared.
792
793
    static array_t ensure(handle h) {
        auto result = reinterpret_steal<array_t>(raw_array_t(h.ptr()));
794
795
        if (!result)
            PyErr_Clear();
796
        return result;
Wenzel Jakob's avatar
Wenzel Jakob committed
797
    }
798

Wenzel Jakob's avatar
Wenzel Jakob committed
799
    static bool check_(handle h) {
800
801
        const auto &api = detail::npy_api::get();
        return api.PyArray_Check_(h.ptr())
802
               && api.PyArray_EquivTypes_(detail::array_proxy(h.ptr())->descr, dtype::of<T>().ptr());
803
804
805
806
807
    }

protected:
    /// Create array from any object -- always returns a new reference
    static PyObject *raw_array_t(PyObject *ptr) {
808
809
        if (ptr == nullptr) {
            PyErr_SetString(PyExc_ValueError, "cannot create a pybind11::array_t from a nullptr");
810
            return nullptr;
811
        }
812
813
        return detail::npy_api::get().PyArray_FromAny_(
            ptr, dtype::of<T>().release().ptr(), 0, 0,
814
            detail::npy_api::NPY_ARRAY_ENSUREARRAY_ | ExtraFlags, nullptr);
815
    }
Wenzel Jakob's avatar
Wenzel Jakob committed
816
817
};

818
template <typename T>
819
struct format_descriptor<T, detail::enable_if_t<detail::is_pod_struct<T>::value>> {
820
821
822
    static std::string format() {
        return detail::npy_format_descriptor<typename std::remove_cv<T>::type>::format();
    }
823
824
825
};

/// Fixed-size char arrays use the buffer-protocol format "<N>s".
template <size_t N> struct format_descriptor<char[N]> {
    static std::string format() { return std::to_string(N) + "s"; }
};
/// std::array<char, N> uses the same "<N>s" format as char[N].
template <size_t N> struct format_descriptor<std::array<char, N>> {
    static std::string format() { return std::to_string(N) + "s"; }
};

832
833
834
835
836
837
838
839
/// Enumerations reuse the buffer format string of their cv-unqualified underlying type.
template <typename T>
struct format_descriptor<T, detail::enable_if_t<std::is_enum<T>::value>> {
    static std::string format() {
        using underlying_t = typename std::underlying_type<T>::type;
        using base_t = typename std::remove_cv<underlying_t>::type;
        return format_descriptor<base_t>::format();
    }
};

840
NAMESPACE_BEGIN(detail)
841
842
843
844
/// Type caster for array_t arguments/return values.
template <typename T, int ExtraFlags>
struct pyobject_caster<array_t<T, ExtraFlags>> {
    using type = array_t<T, ExtraFlags>;

    /// Without `convert`, only objects that pass array_t::check_ (already an array of an
    /// equivalent dtype) are accepted; otherwise array_t::ensure() attempts the conversion.
    bool load(handle src, bool convert) {
        if (!convert && !type::check_(src))
            return false;
        value = type::ensure(src);
        return static_cast<bool>(value);
    }

    /// Return a new reference to the existing array object.
    static handle cast(const handle &src, return_value_policy /* policy */, handle /* parent */) {
        return src.inc_ref();
    }
    PYBIND11_TYPE_CASTER(type, handle_type_name<type>::name());
};

858
859
860
861
862
863
864
/// A buffer matches a POD struct type T when NumPy considers the dtype registered for T and
/// the dtype described by the buffer equivalent.
template <typename T>
struct compare_buffer_info<T, detail::enable_if_t<detail::is_pod_struct<T>::value>> {
    static bool compare(const buffer_info& b) {
        const auto expected = dtype::of<T>();
        const auto actual = dtype(b);
        return npy_api::get().PyArray_EquivTypes_(expected.ptr(), actual.ptr());
    }
};

865
/// dtype descriptor for arithmetic and complex types, mapped onto NumPy's builtin type numbers.
template <typename T> struct npy_format_descriptor<T, enable_if_t<satisfies_any_of<T, std::is_arithmetic, is_complex>::value>> {
private:
    // NB: the order here must match the one in common.h
    constexpr static const int values[15] = {
        npy_api::NPY_BOOL_,
        npy_api::NPY_BYTE_,   npy_api::NPY_UBYTE_,   npy_api::NPY_SHORT_,    npy_api::NPY_USHORT_,
        npy_api::NPY_INT_,    npy_api::NPY_UINT_,    npy_api::NPY_LONGLONG_, npy_api::NPY_ULONGLONG_,
        npy_api::NPY_FLOAT_,  npy_api::NPY_DOUBLE_,  npy_api::NPY_LONGDOUBLE_,
        npy_api::NPY_CFLOAT_, npy_api::NPY_CDOUBLE_, npy_api::NPY_CLONGDOUBLE_
    };

public:
    /// NumPy type number for T, selected by T's index in common.h's numeric type list.
    static constexpr int value = values[detail::is_fmt_numeric<T>::index];

    /// Builtin NumPy dtype for T. PyArray_DescrFromType returns a *new* reference, so it is
    /// stolen (not borrowed); borrowing would leak one reference on every call.
    static pybind11::dtype dtype() {
        if (auto ptr = npy_api::get().PyArray_DescrFromType_(value))
            return reinterpret_steal<pybind11::dtype>(ptr);
        pybind11_fail("Unsupported buffer format!");
    }

    /// Signature name for integral types: "bool", or "int"/"uint" followed by the bit width.
    template <typename T2 = T, enable_if_t<std::is_integral<T2>::value, int> = 0>
    static PYBIND11_DESCR name() {
        return _<std::is_same<T, bool>::value>(_("bool"),
            _<std::is_signed<T>::value>("int", "uint") + _<sizeof(T)*8>());
    }
    /// Signature name for float/double ("float32"/"float64"); other floating types are "longdouble".
    template <typename T2 = T, enable_if_t<std::is_floating_point<T2>::value, int> = 0>
    static PYBIND11_DESCR name() {
        return _<std::is_same<T, float>::value || std::is_same<T, double>::value>(
                _("float") + _<sizeof(T)*8>(), _("longdouble"));
    }
    /// Signature name for complex<float>/complex<double> ("complex64"/"complex128");
    /// other complex types are "longcomplex".
    template <typename T2 = T, enable_if_t<is_complex<T2>::value, int> = 0>
    static PYBIND11_DESCR name() {
        return _<std::is_same<typename T2::value_type, float>::value || std::is_same<typename T2::value_type, double>::value>(
                _("complex") + _<sizeof(typename T2::value_type)*16>(), _("longcomplex"));
    }
};
900
901

#define PYBIND11_DECL_CHAR_FMT \
902
    static PYBIND11_DESCR name() { return _("S") + _<N>(); } \
903
    static pybind11::dtype dtype() { return pybind11::dtype(std::string("S") + std::to_string(N)); }
904
905
906
template <size_t N> struct npy_format_descriptor<char[N]> { PYBIND11_DECL_CHAR_FMT };
template <size_t N> struct npy_format_descriptor<std::array<char, N>> { PYBIND11_DECL_CHAR_FMT };
#undef PYBIND11_DECL_CHAR_FMT
907

908
909
910
911
912
913
914
915
/// Enumerations are described exactly like their underlying integer type.
template<typename T> struct npy_format_descriptor<T, enable_if_t<std::is_enum<T>::value>> {
    static PYBIND11_DESCR name() {
        return npy_format_descriptor<typename std::underlying_type<T>::type>::name();
    }
    static pybind11::dtype dtype() {
        return npy_format_descriptor<typename std::underlying_type<T>::type>::dtype();
    }
};

916
917
struct field_descriptor {
    const char *name;
918
    size_t offset;
919
    size_t size;
920
    size_t alignment;
921
    std::string format;
922
    dtype descr;
923
924
};

925
926
927
/// Build a NumPy structured dtype and a matching buffer-protocol format string from `fields`.
/// Check that NumPy parses the format string to an equivalent dtype, then register both for
/// `tinfo`, together with `direct_converter`.
///
/// @param fields           one descriptor per struct member (any order)
/// @param tinfo            C++ type being registered
/// @param itemsize         size of the C++ type in bytes
/// @param direct_converter appended to get_internals().direct_conversions for `tinfo`
/// Calls pybind11_fail if the type is already registered, if a field has no dtype, or if the
/// generated format string does not round-trip through NumPy.
inline PYBIND11_NOINLINE void register_structured_dtype(
    const std::initializer_list<field_descriptor>& fields,
    const std::type_info& tinfo, size_t itemsize,
    bool (*direct_converter)(PyObject *, void *&)) {

    auto& numpy_internals = get_numpy_internals();
    if (numpy_internals.get_type_info(tinfo, false))
        pybind11_fail("NumPy: dtype is already registered");

    list names, formats, offsets;
    for (const auto &field : fields) {  // by reference: avoid copying the string + dtype
        if (!field.descr)
            pybind11_fail(std::string("NumPy: unsupported field dtype: `") +
                            field.name + "` @ " + tinfo.name());
        names.append(PYBIND11_STR_TYPE(field.name));
        formats.append(field.descr);
        offsets.append(pybind11::int_(field.offset));
    }
    // Keep the new dtype owned by an RAII handle until it is stored in the registry, so a
    // failing sanity check below does not leak it.
    pybind11::dtype dtype_obj(names, formats, offsets, itemsize);

    // There is an existing bug in NumPy (as of v1.11): trailing bytes are
    // not encoded explicitly into the format string. This will supposedly
    // get fixed in v1.12; for further details, see these:
    // - https://github.com/numpy/numpy/issues/7797
    // - https://github.com/numpy/numpy/pull/7798
    // Because of this, we won't use numpy's logic to generate buffer format
    // strings and will just do it ourselves.
    std::vector<field_descriptor> ordered_fields(fields);
    std::sort(ordered_fields.begin(), ordered_fields.end(),
        [](const field_descriptor &a, const field_descriptor &b) { return a.offset < b.offset; });
    size_t offset = 0;
    std::ostringstream oss;
    oss << "T{";
    for (auto& field : ordered_fields) {
        // explicit padding bytes between the end of the previous field and this one
        if (field.offset > offset)
            oss << (field.offset - offset) << 'x';
        // mark unaligned fields with '^' (unaligned native type)
        if (field.offset % field.alignment)
            oss << '^';
        oss << field.format << ':' << field.name << ':';
        offset = field.offset + field.size;
    }
    // trailing padding up to the full itemsize
    if (itemsize > offset)
        oss << (itemsize - offset) << 'x';
    oss << '}';
    auto format_str = oss.str();

    // Sanity check: verify that NumPy properly parses our buffer format string
    auto& api = npy_api::get();
    auto arr = array(buffer_info(nullptr, itemsize, format_str, 1));
    if (!api.PyArray_EquivTypes_(dtype_obj.ptr(), arr.dtype().ptr()))
        pybind11_fail("NumPy: invalid buffer descriptor!");

    auto tindex = std::type_index(tinfo);
    // The registry takes over the reference released from dtype_obj.
    numpy_internals.registered_dtypes[tindex] = { dtype_obj.release().ptr(), format_str };
    get_internals().direct_conversions[tindex].push_back(direct_converter);
}

983
984
985
template <typename T, typename SFINAE> struct npy_format_descriptor {
    static_assert(is_pod_struct<T>::value, "Attempt to use a non-POD or unimplemented POD type as a numpy dtype");

986
    static PYBIND11_DESCR name() { return make_caster<T>::name(); }
987

988
    static pybind11::dtype dtype() {
989
        return reinterpret_borrow<pybind11::dtype>(dtype_ptr());
990
991
    }

992
    static std::string format() {
993
        static auto format_str = get_numpy_internals().get_type_info<T>(true)->format_str;
994
        return format_str;
995
996
    }

997
998
999
    static void register_dtype(const std::initializer_list<field_descriptor>& fields) {
        register_structured_dtype(fields, typeid(typename std::remove_cv<T>::type),
                                  sizeof(T), &direct_converter);
1000
1001
1002
    }

private:
1003
1004
1005
1006
    static PyObject* dtype_ptr() {
        static PyObject* ptr = get_numpy_internals().get_type_info<T>(true)->dtype_ptr;
        return ptr;
    }
1007

1008
1009
1010
    static bool direct_converter(PyObject *obj, void*& value) {
        auto& api = npy_api::get();
        if (!PyObject_TypeCheck(obj, api.PyVoidArrType_Type_))
1011
            return false;
1012
        if (auto descr = reinterpret_steal<object>(api.PyArray_DescrFromScalar_(obj))) {
1013
            if (api.PyArray_EquivTypes_(dtype_ptr(), descr.ptr())) {
1014
1015
1016
1017
1018
1019
                value = ((PyVoidScalarObject_Proxy *) obj)->obval;
                return true;
            }
        }
        return false;
    }
1020
1021
};

1022
1023
1024
// Build a field_descriptor for member `Field` of struct `T`, exposed to NumPy under `Name`:
// offset/size/alignment come from offsetof/sizeof/alignof, and format/dtype from the member type.
#define PYBIND11_FIELD_DESCRIPTOR_EX(T, Field, Name)                                          \
    ::pybind11::detail::field_descriptor {                                                    \
        Name, offsetof(T, Field), sizeof(decltype(std::declval<T>().Field)),                  \
        alignof(decltype(std::declval<T>().Field)),                                           \
        ::pybind11::format_descriptor<decltype(std::declval<T>().Field)>::format(),           \
        ::pybind11::detail::npy_format_descriptor<decltype(std::declval<T>().Field)>::dtype() \
    }
1029

1030
1031
1032
// Build a field descriptor for member `Member` of struct `Struct`, using the member's own
// identifier (stringized) as the NumPy field name.
#define PYBIND11_FIELD_DESCRIPTOR(Struct, Member) \
    PYBIND11_FIELD_DESCRIPTOR_EX(Struct, Member, #Member)

1033
1034
// The main idea of this macro is borrowed from https://github.com/swansontec/map-macro
// (C) William Swanson, Paul Fultz
// PYBIND11_EVAL forces 3^5 rescans of its argument, which bounds the recursion depth below.
#define PYBIND11_EVAL0(...) __VA_ARGS__
#define PYBIND11_EVAL1(...) PYBIND11_EVAL0 (PYBIND11_EVAL0 (PYBIND11_EVAL0 (__VA_ARGS__)))
#define PYBIND11_EVAL2(...) PYBIND11_EVAL1 (PYBIND11_EVAL1 (PYBIND11_EVAL1 (__VA_ARGS__)))
#define PYBIND11_EVAL3(...) PYBIND11_EVAL2 (PYBIND11_EVAL2 (PYBIND11_EVAL2 (__VA_ARGS__)))
#define PYBIND11_EVAL4(...) PYBIND11_EVAL3 (PYBIND11_EVAL3 (PYBIND11_EVAL3 (__VA_ARGS__)))
#define PYBIND11_EVAL(...)  PYBIND11_EVAL4 (PYBIND11_EVAL4 (PYBIND11_EVAL4 (__VA_ARGS__)))
// Sentinel detection: `PYBIND11_MAP_GET_END peek` only expands (adding an extra argument) when
// `peek` is the `()` terminator, which shifts PYBIND11_MAP_NEXT0's arguments so that it selects
// PYBIND11_MAP_END instead of the continuation macro.
#define PYBIND11_MAP_END(...)
#define PYBIND11_MAP_OUT
#define PYBIND11_MAP_COMMA ,
#define PYBIND11_MAP_GET_END() 0, PYBIND11_MAP_END
#define PYBIND11_MAP_NEXT0(test, next, ...) next PYBIND11_MAP_OUT
#define PYBIND11_MAP_NEXT1(test, next) PYBIND11_MAP_NEXT0 (test, next, 0)
#define PYBIND11_MAP_NEXT(test, next)  PYBIND11_MAP_NEXT1 (PYBIND11_MAP_GET_END test, next)
#ifdef _MSC_VER // MSVC is not as eager to expand macros, hence this workaround
#define PYBIND11_MAP_LIST_NEXT1(test, next) \
    PYBIND11_EVAL0 (PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0))
#else
#define PYBIND11_MAP_LIST_NEXT1(test, next) \
    PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0)
#endif
#define PYBIND11_MAP_LIST_NEXT(test, next) \
    PYBIND11_MAP_LIST_NEXT1 (PYBIND11_MAP_GET_END test, next)
// LIST0/LIST1 alternate so that neither macro names itself during a single expansion pass.
#define PYBIND11_MAP_LIST0(f, t, x, peek, ...) \
    f(t, x) PYBIND11_MAP_LIST_NEXT (peek, PYBIND11_MAP_LIST1) (f, t, peek, __VA_ARGS__)
#define PYBIND11_MAP_LIST1(f, t, x, peek, ...) \
    f(t, x) PYBIND11_MAP_LIST_NEXT (peek, PYBIND11_MAP_LIST0) (f, t, peek, __VA_ARGS__)
// PYBIND11_MAP_LIST(f, t, a1, a2, ...) expands to f(t, a1), f(t, a2), ...
#define PYBIND11_MAP_LIST(f, t, ...) \
    PYBIND11_EVAL (PYBIND11_MAP_LIST1 (f, t, __VA_ARGS__, (), 0))
1064

1065
// Register a POD struct as a NumPy dtype: PYBIND11_NUMPY_DTYPE(Type, field1, field2, ...)
// builds one field_descriptor per named member (NumPy field name = member name).
#define PYBIND11_NUMPY_DTYPE(Type, ...) \
    ::pybind11::detail::npy_format_descriptor<Type>::register_dtype \
        ({PYBIND11_MAP_LIST (PYBIND11_FIELD_DESCRIPTOR, Type, __VA_ARGS__)})
1068

1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
// Pairwise variant of PYBIND11_MAP_LIST: consumes the variadic arguments two at a time and
// emits f(t, x1, x2) for each pair, comma-separated. It reuses PYBIND11_EVAL,
// PYBIND11_MAP_GET_END, PYBIND11_MAP_NEXT0 and PYBIND11_MAP_COMMA, and detects the end of the
// list the same way: by peeking at the argument after each pair for the `()` terminator.
// The MSVC branch wraps the selection in an extra PYBIND11_EVAL0 pass, mirroring the
// single-argument PYBIND11_MAP_LIST_NEXT1 workaround.
#ifdef _MSC_VER
#define PYBIND11_MAP2_LIST_NEXT1(test, next) \
    PYBIND11_EVAL0 (PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0))
#else
#define PYBIND11_MAP2_LIST_NEXT1(test, next) \
    PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0)
#endif
#define PYBIND11_MAP2_LIST_NEXT(test, next) \
    PYBIND11_MAP2_LIST_NEXT1 (PYBIND11_MAP_GET_END test, next)
// LIST0/LIST1 alternate so neither macro names itself within one expansion pass. `peek` is the
// first element of the next pair (or the `()` terminator), and __VA_ARGS__ starts with its partner.
#define PYBIND11_MAP2_LIST0(f, t, x1, x2, peek, ...) \
    f(t, x1, x2) PYBIND11_MAP2_LIST_NEXT (peek, PYBIND11_MAP2_LIST1) (f, t, peek, __VA_ARGS__)
#define PYBIND11_MAP2_LIST1(f, t, x1, x2, peek, ...) \
    f(t, x1, x2) PYBIND11_MAP2_LIST_NEXT (peek, PYBIND11_MAP2_LIST0) (f, t, peek, __VA_ARGS__)
// PYBIND11_MAP2_LIST(f, t, a1, a2, ...) expands to f(t, a1, a2), f(t, a3, a4), ...
#define PYBIND11_MAP2_LIST(f, t, ...) \
    PYBIND11_EVAL (PYBIND11_MAP2_LIST1 (f, t, __VA_ARGS__, (), 0))

// Register a POD struct as a NumPy dtype with explicit field names:
// PYBIND11_NUMPY_DTYPE_EX(Type, member1, "name1", member2, "name2", ...). Each (member, name)
// pair is passed to PYBIND11_FIELD_DESCRIPTOR_EX; the name becomes field_descriptor::name.
#define PYBIND11_NUMPY_DTYPE_EX(Type, ...) \
    ::pybind11::detail::npy_format_descriptor<Type>::register_dtype \
        ({PYBIND11_MAP2_LIST (PYBIND11_FIELD_DESCRIPTOR_EX, Type, __VA_ARGS__)})

1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
/// Linear iterator over a buffer's elements: simply a raw pointer to T.
template <class T>
using array_iterator = typename std::add_pointer<T>::type;

/// Iterator to the first element of `buffer`, with its storage viewed as T.
template <class T>
array_iterator<T> array_begin(const buffer_info& buffer) {
    T* const first = reinterpret_cast<T*>(buffer.ptr);
    return array_iterator<T>(first);
}

/// One-past-the-end iterator: the first element advanced by `buffer.size` elements.
template <class T>
array_iterator<T> array_end(const buffer_info& buffer) {
    return array_begin<T>(buffer) + buffer.size;
}

/// Byte pointer into one buffer that multi_array_iterator advances one axis at a time.
class common_iterator {
public:
    using container_type = std::vector<size_t>;
    using value_type = container_type::value_type;
    using size_type = container_type::size_type;

    common_iterator() : p_ptr(nullptr), m_strides() {}

    /// `strides` are the buffer's per-axis byte strides (0 for broadcast axes) and `shape` is
    /// the full iteration shape. The stored step for axis j is the byte delta applied when
    /// axis j is incremented while every inner axis wraps from its last index back to 0.
    common_iterator(void* ptr, const container_type& strides, const std::vector<size_t>& shape)
        : p_ptr(reinterpret_cast<char*>(ptr)), m_strides(strides.size()) {
        if (m_strides.empty())
            return;  // 0-d iteration: no axis to step along (back() would be undefined)
        m_strides.back() = static_cast<value_type>(strides.back());
        for (size_type i = m_strides.size() - 1; i != 0; --i) {
            size_type j = i - 1;
            value_type s = static_cast<value_type>(shape[i]);
            // step(j) = stride(j) + step(i) - (bytes covered by a full sweep of axis i)
            m_strides[j] = strides[j] + m_strides[i] - strides[i] * s;
        }
    }

    /// Advance the pointer by the precomputed step for axis `dim`.
    void increment(size_type dim) {
        p_ptr += m_strides[dim];
    }

    /// Current element address.
    void* data() const {
        return p_ptr;
    }

private:
    char* p_ptr;
    container_type m_strides;
};

1134
template <size_t N> class multi_array_iterator {
1135
1136
1137
public:
    using container_type = std::vector<size_t>;

1138
1139
1140
1141
1142
    multi_array_iterator(const std::array<buffer_info, N> &buffers,
                         const std::vector<size_t> &shape)
        : m_shape(shape.size()), m_index(shape.size(), 0),
          m_common_iterator() {

1143
        // Manual copy to avoid conversion warning if using std::copy
1144
        for (size_t i = 0; i < shape.size(); ++i)
1145
1146
1147
            m_shape[i] = static_cast<container_type::value_type>(shape[i]);

        container_type strides(shape.size());
1148
        for (size_t i = 0; i < N; ++i)
1149
1150
1151
1152
1153
1154
1155
1156
1157
            init_common_iterator(buffers[i], shape, m_common_iterator[i], strides);
    }

    multi_array_iterator& operator++() {
        for (size_t j = m_index.size(); j != 0; --j) {
            size_t i = j - 1;
            if (++m_index[i] != m_shape[i]) {
                increment_common_iterator(i);
                break;
1158
            } else {
1159
1160
1161
1162
1163
1164
                m_index[i] = 0;
            }
        }
        return *this;
    }

1165
    template <size_t K, class T> const T& data() const {
1166
1167
1168
1169
1170
1171
1172
        return *reinterpret_cast<T*>(m_common_iterator[K].data());
    }

private:

    using common_iter = common_iterator;

1173
1174
1175
    void init_common_iterator(const buffer_info &buffer,
                              const std::vector<size_t> &shape,
                              common_iter &iterator, container_type &strides) {
1176
1177
1178
1179
1180
1181
1182
        auto buffer_shape_iter = buffer.shape.rbegin();
        auto buffer_strides_iter = buffer.strides.rbegin();
        auto shape_iter = shape.rbegin();
        auto strides_iter = strides.rbegin();

        while (buffer_shape_iter != buffer.shape.rend()) {
            if (*shape_iter == *buffer_shape_iter)
1183
                *strides_iter = static_cast<size_t>(*buffer_strides_iter);
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
            else
                *strides_iter = 0;

            ++buffer_shape_iter;
            ++buffer_strides_iter;
            ++shape_iter;
            ++strides_iter;
        }

        std::fill(strides_iter, strides.rend(), 0);
        iterator = common_iter(buffer.ptr, strides, shape);
    }

    void increment_common_iterator(size_t dim) {
1198
        for (auto &iter : m_common_iterator)
1199
1200
1201
1202
1203
1204
1205
1206
            iter.increment(dim);
    }

    container_type m_shape;
    container_type m_index;
    std::array<common_iter, N> m_common_iterator;
};

1207
1208
1209
1210
1211
1212
/// Result of broadcast(): whether all non-singleton inputs can be traversed linearly in
/// C order, in Fortran order, or neither.
enum class broadcast_trivial {
    non_trivial,
    c_trivial,
    f_trivial
};

// Populates the shape and number of dimensions for the set of buffers.  Returns a broadcast_trivial
// enum value indicating whether the broadcast is "trivial"--that is, has each buffer being either a
// singleton or a full-size, C-contiguous (`c_trivial`) or Fortran-contiguous (`f_trivial`) storage
// buffer; returns `non_trivial` otherwise.
1213
template <size_t N>
1214
broadcast_trivial broadcast(const std::array<buffer_info, N> &buffers, size_t &ndim, std::vector<size_t> &shape) {
1215
    ndim = std::accumulate(buffers.begin(), buffers.end(), size_t(0), [](size_t res, const buffer_info& buf) {
1216
1217
1218
        return std::max(res, buf.ndim);
    });

1219
1220
1221
    shape.clear();
    shape.resize(ndim, 1);

1222
1223
    // Figure out the output size, and make sure all input arrays conform (i.e. are either size 1 or
    // the full size).
1224
1225
    for (size_t i = 0; i < N; ++i) {
        auto res_iter = shape.rbegin();
1226
1227
        auto end = buffers[i].shape.rend();
        for (auto shape_iter = buffers[i].shape.rbegin(); shape_iter != end; ++shape_iter, ++res_iter) {
1228
1229
1230
1231
1232
1233
1234
            const auto &dim_size_in = *shape_iter;
            auto &dim_size_out = *res_iter;

            // Each input dimension can either be 1 or `n`, but `n` values must match across buffers
            if (dim_size_out == 1)
                dim_size_out = dim_size_in;
            else if (dim_size_in != 1 && dim_size_in != dim_size_out)
1235
                pybind11_fail("pybind11::vectorize: incompatible size/dimension of inputs!");
1236
1237
        }
    }
1238

1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
    bool trivial_broadcast_c = true;
    bool trivial_broadcast_f = true;
    for (size_t i = 0; i < N && (trivial_broadcast_c || trivial_broadcast_f); ++i) {
        if (buffers[i].size == 1)
            continue;

        // Require the same number of dimensions:
        if (buffers[i].ndim != ndim)
            return broadcast_trivial::non_trivial;

        // Require all dimensions be full-size:
        if (!std::equal(buffers[i].shape.cbegin(), buffers[i].shape.cend(), shape.cbegin()))
            return broadcast_trivial::non_trivial;

        // Check for C contiguity (but only if previous inputs were also C contiguous)
        if (trivial_broadcast_c) {
            size_t expect_stride = buffers[i].itemsize;
            auto end = buffers[i].shape.crend();
            for (auto shape_iter = buffers[i].shape.crbegin(), stride_iter = buffers[i].strides.crbegin();
                    trivial_broadcast_c && shape_iter != end; ++shape_iter, ++stride_iter) {
                if (expect_stride == *stride_iter)
                    expect_stride *= *shape_iter;
                else
                    trivial_broadcast_c = false;
1263
            }
1264
        }
1265

1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
        // Check for Fortran contiguity (if previous inputs were also F contiguous)
        if (trivial_broadcast_f) {
            size_t expect_stride = buffers[i].itemsize;
            auto end = buffers[i].shape.cend();
            for (auto shape_iter = buffers[i].shape.cbegin(), stride_iter = buffers[i].strides.cbegin();
                    trivial_broadcast_f && shape_iter != end; ++shape_iter, ++stride_iter) {
                if (expect_stride == *stride_iter)
                    expect_stride *= *shape_iter;
                else
                    trivial_broadcast_f = false;
            }
1277
1278
        }
    }
1279
1280
1281
1282
1283

    return
        trivial_broadcast_c ? broadcast_trivial::c_trivial :
        trivial_broadcast_f ? broadcast_trivial::f_trivial :
        broadcast_trivial::non_trivial;
1284
1285
}

1286
1287
1288
template <typename Func, typename Return, typename... Args>
struct vectorize_helper {
    typename std::remove_reference<Func>::type f;
1289
    static constexpr size_t N = sizeof...(Args);
1290

1291
    template <typename T>
1292
    explicit vectorize_helper(T&&f) : f(std::forward<T>(f)) { }
Wenzel Jakob's avatar
Wenzel Jakob committed
1293

1294
1295
    object operator()(array_t<Args, array::forcecast>... args) {
        return run(args..., make_index_sequence<N>());
1296
    }
Wenzel Jakob's avatar
Wenzel Jakob committed
1297

1298
    template <size_t ... Index> object run(array_t<Args, array::forcecast>&... args, index_sequence<Index...> index) {
Wenzel Jakob's avatar
Wenzel Jakob committed
1299
1300
1301
1302
        /* Request buffers from all parameters */
        std::array<buffer_info, N> buffers {{ args.request()... }};

        /* Determine dimensions parameters of output array */
1303
        size_t ndim = 0;
1304
        std::vector<size_t> shape(0);
1305
        auto trivial = broadcast(buffers, ndim, shape);
1306

1307
        size_t size = 1;
Wenzel Jakob's avatar
Wenzel Jakob committed
1308
1309
        std::vector<size_t> strides(ndim);
        if (ndim > 0) {
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
            if (trivial == broadcast_trivial::f_trivial) {
                strides[0] = sizeof(Return);
                for (size_t i = 1; i < ndim; ++i) {
                    strides[i] = strides[i - 1] * shape[i - 1];
                    size *= shape[i - 1];
                }
                size *= shape[ndim - 1];
            }
            else {
                strides[ndim-1] = sizeof(Return);
                for (size_t i = ndim - 1; i > 0; --i) {
                    strides[i - 1] = strides[i] * shape[i];
                    size *= shape[i];
                }
                size *= shape[0];
1325
            }
Wenzel Jakob's avatar
Wenzel Jakob committed
1326
1327
        }

1328
        if (size == 1)
1329
            return cast(f(*reinterpret_cast<Args *>(buffers[Index].ptr)...));
Wenzel Jakob's avatar
Wenzel Jakob committed
1330

1331
        array_t<Return> result(shape, strides);
1332
1333
        auto buf = result.request();
        auto output = (Return *) buf.ptr;
1334

1335
1336
1337
1338
        /* Call the function */
        if (trivial == broadcast_trivial::non_trivial) {
            apply_broadcast<Index...>(buffers, buf, index);
        } else {
1339
1340
            for (size_t i = 0; i < size; ++i)
                output[i] = f((reinterpret_cast<Args *>(buffers[Index].ptr)[buffers[Index].size == 1 ? 0 : i])...);
1341
        }
1342
1343

        return result;
1344
    }
1345

1346
    template <size_t... Index>
1347
1348
    void apply_broadcast(const std::array<buffer_info, N> &buffers,
                         buffer_info &output, index_sequence<Index...>) {
1349
1350
1351
1352
1353
1354
        using input_iterator = multi_array_iterator<N>;
        using output_iterator = array_iterator<Return>;

        input_iterator input_iter(buffers, output.shape);
        output_iterator output_end = array_end<Return>(output);

1355
1356
        for (output_iterator iter = array_begin<Return>(output);
             iter != output_end; ++iter, ++input_iter) {
1357
1358
1359
            *iter = f((input_iter.template data<Index, Args>())...);
        }
    }
1360
1361
};

1362
/// Signature name for array_t arguments, e.g. "numpy.ndarray[float64]" for array_t<double>.
template <typename T, int Flags> struct handle_type_name<array_t<T, Flags>> {
    static PYBIND11_DESCR name() {
        return _("numpy.ndarray[") + npy_format_descriptor<T>::name() + _("]");
    }
};

1368
NAMESPACE_END(detail)
Wenzel Jakob's avatar
Wenzel Jakob committed
1369

1370
template <typename Func, typename Return, typename... Args>
1371
detail::vectorize_helper<Func, Return, Args...>
1372
vectorize(const Func &f, Return (*) (Args ...)) {
1373
    return detail::vectorize_helper<Func, Return, Args...>(f);
Wenzel Jakob's avatar
Wenzel Jakob committed
1374
1375
}

1376
1377
1378
/// Vectorize a plain function pointer; the pointer itself also serves as the signature tag.
template <typename Return, typename... Args>
detail::vectorize_helper<Return (*) (Args ...), Return, Args...>
vectorize(Return (*f) (Args ...)) {
    return vectorize<Return (*) (Args ...), Return, Args...>(f, f);
}

1382
/// Vectorize a lambda or other functor. The signature is taken from its `operator()` (via
/// detail::remove_class) and passed on as a null function-pointer tag.
template <typename Func, typename FuncType = typename detail::remove_class<decltype(&std::remove_reference<Func>::type::operator())>::type>
auto vectorize(Func &&f) -> decltype(
        vectorize(std::forward<Func>(f), (FuncType *) nullptr)) {
    return vectorize(std::forward<Func>(f), (FuncType *) nullptr);
}

1388
NAMESPACE_END(pybind11)
Wenzel Jakob's avatar
Wenzel Jakob committed
1389
1390
1391
1392

#if defined(_MSC_VER)
#pragma warning(pop)
#endif