numpy.h 43.1 KB
Newer Older
Wenzel Jakob's avatar
Wenzel Jakob committed
1
/*
    pybind11/numpy.h: Basic NumPy support, vectorize() wrapper

    Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>

    All rights reserved. Use of this source code is governed by a
    BSD-style license that can be found in the LICENSE file.
*/

#pragma once

12
13
#include "pybind11.h"
#include "complex.h"
14
15
#include <numeric>
#include <algorithm>
16
#include <array>
17
#include <cstdlib>
18
#include <cstring>
19
#include <sstream>
20
#include <string>
21
#include <initializer_list>
22
#include <functional>
23
#include <utility>
24
#include <typeindex>
25

Wenzel Jakob's avatar
Wenzel Jakob committed
26
#if defined(_MSC_VER)
27
28
#  pragma warning(push)
#  pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
Wenzel Jakob's avatar
Wenzel Jakob committed
29
30
#endif

31
32
33
34
35
36
/* This will be true on all flat address space platforms and allows us to reduce the
   whole npy_intp / size_t / Py_intptr_t business down to just size_t for all size
   and dimension types (e.g. shape, strides, indexing), instead of inflicting this
   upon the library user. The array code below relies on this: it reinterpret_casts
   `size_t*` shape/stride buffers to `Py_intptr_t*` (and back) when talking to NumPy. */
static_assert(sizeof(size_t) == sizeof(Py_intptr_t), "size_t != Py_intptr_t");

37
NAMESPACE_BEGIN(pybind11)
38
NAMESPACE_BEGIN(detail)
39
// Maps a C++ type to its NumPy descriptor. The second parameter exists so that
// partial specializations can be selected via enable_if (arithmetic, enum, char arrays, ...).
template <typename type, typename SFINAE = void> struct npy_format_descriptor;
Wenzel Jakob's avatar
Wenzel Jakob committed
40

41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
// Layout mirror of NumPy's dtype object, used so descriptor fields can be read without
// NumPy's headers. Objects are reached via reinterpret_cast (see array_descriptor_proxy),
// so member order and types must match NumPy's struct exactly -- do not reorder.
struct PyArrayDescr_Proxy {
    PyObject_HEAD
    PyObject *typeobj;
    char kind;        // single-character type code (read by dtype::kind())
    char type;
    char byteorder;
    char flags;
    int type_num;
    int elsize;       // element size in bytes (read by dtype::itemsize())
    int alignment;
    char *subarray;   // NOTE(review): a pointer to NumPy's subarray struct; only used as opaque here
    PyObject *fields;
    PyObject *names;  // non-null for structured dtypes (see dtype::has_fields())
};

// Layout mirror of NumPy's ndarray object. Accessed through reinterpret_cast (see array_proxy),
// so member order and types must match NumPy's struct exactly -- do not reorder.
struct PyArray_Proxy {
    PyObject_HEAD
    char *data;           // start of the array data (array::data())
    int nd;               // number of dimensions (array::ndim())
    ssize_t *dimensions;  // shape, nd entries (array::shape())
    ssize_t *strides;     // byte strides, nd entries (array::strides())
    PyObject *base;       // object keeping the data alive, may be null (array::base())
    PyObject *descr;      // dtype descriptor (array::dtype())
    int flags;            // NPY_ARRAY_* flag bits (array::flags())
};

67
68
69
70
71
72
73
74
// Presumably a layout mirror of NumPy's void-scalar object (the scalar type used for
// structured records); not used in the visible part of this file -- confirm against NumPy
// before relying on any field. Member order must match NumPy's struct exactly.
struct PyVoidScalarObject_Proxy {
    PyObject_VAR_HEAD
    char *obval;
    PyArrayDescr_Proxy *descr;
    int flags;
    PyObject *base;
};

75
76
77
78
79
80
81
82
/// Record kept for each C++ type registered as a NumPy structured dtype.
struct numpy_type_info {
    PyObject *dtype_ptr;     // the registered NumPy descriptor
    std::string format_str;  // presumably the buffer format string built for the type -- confirm at registration site
};

struct numpy_internals {
    std::unordered_map<std::type_index, numpy_type_info> registered_dtypes;

83
84
    numpy_type_info *get_type_info(const std::type_info& tinfo, bool throw_if_missing = true) {
        auto it = registered_dtypes.find(std::type_index(tinfo));
85
86
87
        if (it != registered_dtypes.end())
            return &(it->second);
        if (throw_if_missing)
88
            pybind11_fail(std::string("NumPy type info missing for ") + tinfo.name());
89
90
        return nullptr;
    }
91
92
93
94

    template<typename T> numpy_type_info *get_type_info(bool throw_if_missing = true) {
        return get_type_info(typeid(typename std::remove_cv<T>::type), throw_if_missing);
    }
95
96
};

Ivan Smirnov's avatar
Ivan Smirnov committed
97
98
/// Point `ptr` at the shared numpy_internals instance stored under "_numpy_internals".
/// Kept out of line so the fast path in get_numpy_internals() stays small.
inline PYBIND11_NOINLINE void load_numpy_internals(numpy_internals* &ptr) {
    auto &shared = get_or_create_shared_data<numpy_internals>("_numpy_internals");
    ptr = &shared;
}

/// Return the shared numpy_internals, resolving it on first call and caching the pointer.
inline numpy_internals& get_numpy_internals() {
    static numpy_internals* cached = nullptr;
    if (cached == nullptr) {
        load_numpy_internals(cached);
    }
    return *cached;
}

108
109
struct npy_api {
    enum constants {
110
111
        NPY_ARRAY_C_CONTIGUOUS_ = 0x0001,
        NPY_ARRAY_F_CONTIGUOUS_ = 0x0002,
112
        NPY_ARRAY_OWNDATA_ = 0x0004,
113
        NPY_ARRAY_FORCECAST_ = 0x0010,
114
        NPY_ARRAY_ENSUREARRAY_ = 0x0040,
115
116
        NPY_ARRAY_ALIGNED_ = 0x0100,
        NPY_ARRAY_WRITEABLE_ = 0x0400,
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
        NPY_BOOL_ = 0,
        NPY_BYTE_, NPY_UBYTE_,
        NPY_SHORT_, NPY_USHORT_,
        NPY_INT_, NPY_UINT_,
        NPY_LONG_, NPY_ULONG_,
        NPY_LONGLONG_, NPY_ULONGLONG_,
        NPY_FLOAT_, NPY_DOUBLE_, NPY_LONGDOUBLE_,
        NPY_CFLOAT_, NPY_CDOUBLE_, NPY_CLONGDOUBLE_,
        NPY_OBJECT_ = 17,
        NPY_STRING_, NPY_UNICODE_, NPY_VOID_
    };

    static npy_api& get() {
        static npy_api api = lookup();
        return api;
    }

134
135
136
137
138
139
    bool PyArray_Check_(PyObject *obj) const {
        return (bool) PyObject_TypeCheck(obj, PyArray_Type_);
    }
    bool PyArrayDescr_Check_(PyObject *obj) const {
        return (bool) PyObject_TypeCheck(obj, PyArrayDescr_Type_);
    }
140
141
142
143
144
145
146
147

    PyObject *(*PyArray_DescrFromType_)(int);
    PyObject *(*PyArray_NewFromDescr_)
        (PyTypeObject *, PyObject *, int, Py_intptr_t *,
         Py_intptr_t *, void *, int, PyObject *);
    PyObject *(*PyArray_DescrNewFromType_)(int);
    PyObject *(*PyArray_NewCopy_)(PyObject *, int);
    PyTypeObject *PyArray_Type_;
148
    PyTypeObject *PyVoidArrType_Type_;
149
    PyTypeObject *PyArrayDescr_Type_;
150
    PyObject *(*PyArray_DescrFromScalar_)(PyObject *);
151
152
153
154
155
    PyObject *(*PyArray_FromAny_) (PyObject *, PyObject *, int, int, int, PyObject *);
    int (*PyArray_DescrConverter_) (PyObject *, PyObject **);
    bool (*PyArray_EquivTypes_) (PyObject *, PyObject *);
    int (*PyArray_GetArrayParamsFromObject_)(PyObject *, PyObject *, char, PyObject **, int *,
                                             Py_ssize_t *, PyObject **, PyObject *);
156
    PyObject *(*PyArray_Squeeze_)(PyObject *);
Jason Rhinelander's avatar
Jason Rhinelander committed
157
    int (*PyArray_SetBaseObject_)(PyObject *, PyObject *);
158
159
160
private:
    enum functions {
        API_PyArray_Type = 2,
161
        API_PyArrayDescr_Type = 3,
162
        API_PyVoidArrType_Type = 39,
163
        API_PyArray_DescrFromType = 45,
164
        API_PyArray_DescrFromScalar = 57,
165
166
167
168
169
170
171
        API_PyArray_FromAny = 69,
        API_PyArray_NewCopy = 85,
        API_PyArray_NewFromDescr = 94,
        API_PyArray_DescrNewFromType = 9,
        API_PyArray_DescrConverter = 174,
        API_PyArray_EquivTypes = 182,
        API_PyArray_GetArrayParamsFromObject = 278,
Jason Rhinelander's avatar
Jason Rhinelander committed
172
173
        API_PyArray_Squeeze = 136,
        API_PyArray_SetBaseObject = 282
174
175
176
177
    };

    static npy_api lookup() {
        module m = module::import("numpy.core.multiarray");
178
        auto c = m.attr("_ARRAY_API");
179
#if PY_MAJOR_VERSION >= 3
180
        void **api_ptr = (void **) PyCapsule_GetPointer(c.ptr(), NULL);
181
#else
182
        void **api_ptr = (void **) PyCObject_AsVoidPtr(c.ptr());
183
#endif
184
        npy_api api;
185
#define DECL_NPY_API(Func) api.Func##_ = (decltype(api.Func##_)) api_ptr[API_##Func];
186
        DECL_NPY_API(PyArray_Type);
187
        DECL_NPY_API(PyVoidArrType_Type);
188
        DECL_NPY_API(PyArrayDescr_Type);
189
        DECL_NPY_API(PyArray_DescrFromType);
190
        DECL_NPY_API(PyArray_DescrFromScalar);
191
192
193
194
195
196
197
        DECL_NPY_API(PyArray_FromAny);
        DECL_NPY_API(PyArray_NewCopy);
        DECL_NPY_API(PyArray_NewFromDescr);
        DECL_NPY_API(PyArray_DescrNewFromType);
        DECL_NPY_API(PyArray_DescrConverter);
        DECL_NPY_API(PyArray_EquivTypes);
        DECL_NPY_API(PyArray_GetArrayParamsFromObject);
198
        DECL_NPY_API(PyArray_Squeeze);
Jason Rhinelander's avatar
Jason Rhinelander committed
199
        DECL_NPY_API(PyArray_SetBaseObject);
200
#undef DECL_NPY_API
201
202
203
        return api;
    }
};
Wenzel Jakob's avatar
Wenzel Jakob committed
204

205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
/// View a raw ndarray pointer through the PyArray_Proxy layout mirror.
inline PyArray_Proxy* array_proxy(void* ptr) { return reinterpret_cast<PyArray_Proxy*>(ptr); }

/// Const overload of array_proxy.
inline const PyArray_Proxy* array_proxy(const void* ptr) { return reinterpret_cast<const PyArray_Proxy*>(ptr); }

/// View a dtype object through the PyArrayDescr_Proxy layout mirror.
inline PyArrayDescr_Proxy* array_descriptor_proxy(PyObject* ptr) { return reinterpret_cast<PyArrayDescr_Proxy*>(ptr); }

/// Const overload of array_descriptor_proxy.
inline const PyArrayDescr_Proxy* array_descriptor_proxy(const PyObject* ptr) { return reinterpret_cast<const PyArrayDescr_Proxy*>(ptr); }

/// True when every bit of `flag` is set in the array's flags.
inline bool check_flags(const void* ptr, int flag) {
    const int present = array_proxy(ptr)->flags & flag;
    return present == flag;
}

225
226
227
228
229
230
231
232
233
234
template <typename T> struct is_std_array : std::false_type { };
template <typename T, size_t N> struct is_std_array<std::array<T, N>> : std::true_type { };
template <typename T> struct is_complex : std::false_type { };
template <typename T> struct is_complex<std::complex<T>> : std::true_type { };

/// Trait selecting user-defined POD structs that may be exposed as NumPy structured dtypes.
/// POD is required because fields are accessed directly in memory; types that already have
/// their own NumPy mapping (arithmetic, complex, enums, arrays, references) are excluded.
template <typename T> using is_pod_struct = all_of<
    std::is_pod<T>,
    satisfies_none_of<T, std::is_enum, is_complex, std::is_arithmetic, is_std_array, std::is_array, std::is_reference>
>;

235
NAMESPACE_END(detail)
236

237
class dtype : public object {
238
public:
239
    PYBIND11_OBJECT_DEFAULT(dtype, object, detail::npy_api::get().PyArrayDescr_Check_);
Wenzel Jakob's avatar
Wenzel Jakob committed
240

241
    explicit dtype(const buffer_info &info) {
242
        dtype descr(_dtype_from_pep3118()(PYBIND11_STR_TYPE(info.format)));
243
244
        // If info.itemsize == 0, use the value calculated from the format string
        m_ptr = descr.strip_padding(info.itemsize ? info.itemsize : descr.itemsize()).release().ptr();
245
    }
246

247
    explicit dtype(const std::string &format) {
248
        m_ptr = from_args(pybind11::str(format)).release().ptr();
Wenzel Jakob's avatar
Wenzel Jakob committed
249
250
    }

251
    dtype(const char *format) : dtype(std::string(format)) { }
252

253
254
255
256
257
    dtype(list names, list formats, list offsets, size_t itemsize) {
        dict args;
        args["names"] = names;
        args["formats"] = formats;
        args["offsets"] = offsets;
258
        args["itemsize"] = pybind11::int_(itemsize);
259
260
261
        m_ptr = from_args(args).release().ptr();
    }

Ivan Smirnov's avatar
Ivan Smirnov committed
262
    /// This is essentially the same as calling numpy.dtype(args) in Python.
263
264
265
    static dtype from_args(object args) {
        PyObject *ptr = nullptr;
        if (!detail::npy_api::get().PyArray_DescrConverter_(args.release().ptr(), &ptr) || !ptr)
266
            throw error_already_set();
267
        return reinterpret_steal<dtype>(ptr);
268
    }
269

Ivan Smirnov's avatar
Ivan Smirnov committed
270
    /// Return dtype associated with a C++ type.
271
    template <typename T> static dtype of() {
272
        return detail::npy_format_descriptor<typename std::remove_cv<T>::type>::dtype();
273
    }
274

Ivan Smirnov's avatar
Ivan Smirnov committed
275
    /// Size of the data type in bytes.
276
    size_t itemsize() const {
277
        return (size_t) detail::array_descriptor_proxy(m_ptr)->elsize;
Wenzel Jakob's avatar
Wenzel Jakob committed
278
279
    }

Ivan Smirnov's avatar
Ivan Smirnov committed
280
    /// Returns true for structured data types.
281
    bool has_fields() const {
282
        return detail::array_descriptor_proxy(m_ptr)->names != nullptr;
283
284
    }

Ivan Smirnov's avatar
Ivan Smirnov committed
285
    /// Single-character type code.
286
    char kind() const {
287
        return detail::array_descriptor_proxy(m_ptr)->kind;
288
289
290
    }

private:
291
292
293
    static object _dtype_from_pep3118() {
        static PyObject *obj = module::import("numpy.core._internal")
            .attr("_dtype_from_pep3118").cast<object>().release().ptr();
294
        return reinterpret_borrow<object>(obj);
295
    }
296

297
    dtype strip_padding(size_t itemsize) {
298
299
        // Recursively strip all void fields with empty names that are generated for
        // padding fields (as of NumPy v1.11).
300
        if (!has_fields())
301
            return *this;
302

303
        struct field_descr { PYBIND11_STR_TYPE name; object format; pybind11::int_ offset; };
304
305
        std::vector<field_descr> field_descriptors;

306
        for (auto field : attr("fields").attr("items")()) {
307
            auto spec = field.cast<tuple>();
308
            auto name = spec[0].cast<pybind11::str>();
309
            auto format = spec[1].cast<tuple>()[0].cast<dtype>();
310
            auto offset = spec[1].cast<tuple>()[1].cast<pybind11::int_>();
311
            if (!len(name) && format.kind() == 'V')
312
                continue;
313
            field_descriptors.push_back({(PYBIND11_STR_TYPE) name, format.strip_padding(format.itemsize()), offset});
314
315
316
317
        }

        std::sort(field_descriptors.begin(), field_descriptors.end(),
                  [](const field_descr& a, const field_descr& b) {
318
                      return a.offset.cast<int>() < b.offset.cast<int>();
319
320
321
322
                  });

        list names, formats, offsets;
        for (auto& descr : field_descriptors) {
323
324
325
            names.append(descr.name);
            formats.append(descr.format);
            offsets.append(descr.offset);
326
        }
327
        return dtype(names, formats, offsets, itemsize);
328
329
    }
};
330

331
332
class array : public buffer {
public:
333
    PYBIND11_OBJECT_CVT(array, buffer, detail::npy_api::get().PyArray_Check_, raw_array)
334
335

    enum {
336
337
        c_style = detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_,
        f_style = detail::npy_api::NPY_ARRAY_F_CONTIGUOUS_,
338
339
340
        forcecast = detail::npy_api::NPY_ARRAY_FORCECAST_
    };

341
342
    array() : array(0, static_cast<const double *>(nullptr)) {}

343
344
    array(const pybind11::dtype &dt, const std::vector<size_t> &shape,
          const std::vector<size_t> &strides, const void *ptr = nullptr,
345
          handle base = handle()) {
346
        auto& api = detail::npy_api::get();
347
348
349
350
        auto ndim = shape.size();
        if (shape.size() != strides.size())
            pybind11_fail("NumPy: shape ndim doesn't match strides ndim");
        auto descr = dt;
351
352
353

        int flags = 0;
        if (base && ptr) {
354
            if (isinstance<array>(base))
Wenzel Jakob's avatar
Wenzel Jakob committed
355
                /* Copy flags from base (except ownership bit) */
356
                flags = reinterpret_borrow<array>(base).flags() & ~detail::npy_api::NPY_ARRAY_OWNDATA_;
357
358
359
360
361
            else
                /* Writable by default, easy to downgrade later on if needed */
                flags = detail::npy_api::NPY_ARRAY_WRITEABLE_;
        }

362
        auto tmp = reinterpret_steal<object>(api.PyArray_NewFromDescr_(
363
364
365
366
            api.PyArray_Type_, descr.release().ptr(), (int) ndim,
            reinterpret_cast<Py_intptr_t *>(const_cast<size_t*>(shape.data())),
            reinterpret_cast<Py_intptr_t *>(const_cast<size_t*>(strides.data())),
            const_cast<void *>(ptr), flags, nullptr));
367
368
        if (!tmp)
            pybind11_fail("NumPy: unable to create array!");
369
370
        if (ptr) {
            if (base) {
Jason Rhinelander's avatar
Jason Rhinelander committed
371
                api.PyArray_SetBaseObject_(tmp.ptr(), base.inc_ref().ptr());
372
            } else {
373
                tmp = reinterpret_steal<object>(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */));
374
375
            }
        }
376
377
378
        m_ptr = tmp.release().ptr();
    }

379
380
381
    array(const pybind11::dtype &dt, const std::vector<size_t> &shape,
          const void *ptr = nullptr, handle base = handle())
        : array(dt, shape, default_strides(shape, dt.itemsize()), ptr, base) { }
382

383
384
385
    array(const pybind11::dtype &dt, size_t count, const void *ptr = nullptr,
          handle base = handle())
        : array(dt, std::vector<size_t>{ count }, ptr, base) { }
386
387

    template<typename T> array(const std::vector<size_t>& shape,
388
389
                               const std::vector<size_t>& strides,
                               const T* ptr, handle base = handle())
390
    : array(pybind11::dtype::of<T>(), shape, strides, (const void *) ptr, base) { }
391

392
393
394
395
    template <typename T>
    array(const std::vector<size_t> &shape, const T *ptr,
          handle base = handle())
        : array(shape, default_strides(shape, sizeof(T)), ptr, base) { }
396

397
398
399
    template <typename T>
    array(size_t count, const T *ptr, handle base = handle())
        : array(std::vector<size_t>{ count }, ptr, base) { }
400

401
    explicit array(const buffer_info &info)
402
    : array(pybind11::dtype(info), info.shape, info.strides, info.ptr) { }
403

404
405
    /// Array descriptor (dtype)
    pybind11::dtype dtype() const {
406
        return reinterpret_borrow<pybind11::dtype>(detail::array_proxy(m_ptr)->descr);
407
408
409
410
411
412
413
414
415
    }

    /// Total number of elements
    size_t size() const {
        return std::accumulate(shape(), shape() + ndim(), (size_t) 1, std::multiplies<size_t>());
    }

    /// Byte size of a single element
    size_t itemsize() const {
416
        return (size_t) detail::array_descriptor_proxy(detail::array_proxy(m_ptr)->descr)->elsize;
417
418
419
420
421
422
423
424
425
    }

    /// Total number of bytes
    size_t nbytes() const {
        return size() * itemsize();
    }

    /// Number of dimensions
    size_t ndim() const {
426
        return (size_t) detail::array_proxy(m_ptr)->nd;
427
428
    }

429
430
    /// Base object
    object base() const {
431
        return reinterpret_borrow<object>(detail::array_proxy(m_ptr)->base);
432
433
    }

434
435
    /// Dimensions of the array
    const size_t* shape() const {
436
        return reinterpret_cast<const size_t *>(detail::array_proxy(m_ptr)->dimensions);
437
438
439
440
441
    }

    /// Dimension along a given axis
    size_t shape(size_t dim) const {
        if (dim >= ndim())
442
            fail_dim_check(dim, "invalid axis");
443
444
445
446
447
        return shape()[dim];
    }

    /// Strides of the array
    const size_t* strides() const {
448
        return reinterpret_cast<const size_t *>(detail::array_proxy(m_ptr)->strides);
449
450
451
452
453
    }

    /// Stride along a given axis
    size_t strides(size_t dim) const {
        if (dim >= ndim())
454
            fail_dim_check(dim, "invalid axis");
455
456
457
        return strides()[dim];
    }

458
459
    /// Return the NumPy array flags
    int flags() const {
460
        return detail::array_proxy(m_ptr)->flags;
461
462
    }

463
464
    /// If set, the array is writeable (otherwise the buffer is read-only)
    bool writeable() const {
465
        return detail::check_flags(m_ptr, detail::npy_api::NPY_ARRAY_WRITEABLE_);
466
467
468
469
    }

    /// If set, the array owns the data (will be freed when the array is deleted)
    bool owndata() const {
470
        return detail::check_flags(m_ptr, detail::npy_api::NPY_ARRAY_OWNDATA_);
471
472
    }

473
474
    /// Pointer to the contained data. If index is not provided, points to the
    /// beginning of the buffer. May throw if the index would lead to out of bounds access.
475
    template<typename... Ix> const void* data(Ix... index) const {
476
        return static_cast<const void *>(detail::array_proxy(m_ptr)->data + offset_at(index...));
477
478
    }

479
480
481
    /// Mutable pointer to the contained data. If index is not provided, points to the
    /// beginning of the buffer. May throw if the index would lead to out of bounds access.
    /// May throw if the array is not writeable.
482
    template<typename... Ix> void* mutable_data(Ix... index) {
483
        check_writeable();
484
        return static_cast<void *>(detail::array_proxy(m_ptr)->data + offset_at(index...));
485
486
487
488
    }

    /// Byte offset from beginning of the array to a given index (full or partial).
    /// May throw if the index would lead to out of bounds access.
489
    template<typename... Ix> size_t offset_at(Ix... index) const {
490
491
        if (sizeof...(index) > ndim())
            fail_dim_check(sizeof...(index), "too many indices for an array");
492
        return byte_offset(size_t(index)...);
493
494
495
496
497
498
    }

    size_t offset_at() const { return 0; }

    /// Item count from beginning of the array to a given index (full or partial).
    /// May throw if the index would lead to out of bounds access.
499
    template<typename... Ix> size_t index_at(Ix... index) const {
500
        return offset_at(index...) / itemsize();
501
502
    }

503
504
505
    /// Return a new view with all of the dimensions of length 1 removed
    array squeeze() {
        auto& api = detail::npy_api::get();
506
        return reinterpret_steal<array>(api.PyArray_Squeeze_(m_ptr));
507
508
    }

509
    /// Ensure that the argument is a NumPy array
510
511
512
513
514
515
    /// In case of an error, nullptr is returned and the Python error is cleared.
    static array ensure(handle h, int ExtraFlags = 0) {
        auto result = reinterpret_steal<array>(raw_array(h.ptr(), ExtraFlags));
        if (!result)
            PyErr_Clear();
        return result;
516
517
    }

518
protected:
519
520
521
522
523
524
525
    template<typename, typename> friend struct detail::npy_format_descriptor;

    void fail_dim_check(size_t dim, const std::string& msg) const {
        throw index_error(msg + ": " + std::to_string(dim) +
                          " (ndim = " + std::to_string(ndim()) + ")");
    }

526
527
528
529
530
531
532
    template<typename... Ix> size_t byte_offset(Ix... index) const {
        check_dimensions(index...);
        return byte_offset_unsafe(index...);
    }

    template<size_t dim = 0, typename... Ix> size_t byte_offset_unsafe(size_t i, Ix... index) const {
        return i * strides()[dim] + byte_offset_unsafe<dim + 1>(index...);
533
534
    }

535
    template<size_t dim = 0> size_t byte_offset_unsafe() const { return 0; }
536
537
538

    void check_writeable() const {
        if (!writeable())
539
            throw std::domain_error("array is not writeable");
540
    }
541

542
    static std::vector<size_t> default_strides(const std::vector<size_t>& shape, size_t itemsize) {
543
544
545
546
547
548
549
550
551
552
        auto ndim = shape.size();
        std::vector<size_t> strides(ndim);
        if (ndim) {
            std::fill(strides.begin(), strides.end(), itemsize);
            for (size_t i = 0; i < ndim - 1; i++)
                for (size_t j = 0; j < ndim - 1 - i; j++)
                    strides[j] *= shape[ndim - 1 - i];
        }
        return strides;
    }
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567

    template<typename... Ix> void check_dimensions(Ix... index) const {
        check_dimensions_impl(size_t(0), shape(), size_t(index)...);
    }

    void check_dimensions_impl(size_t, const size_t*) const { }

    template<typename... Ix> void check_dimensions_impl(size_t axis, const size_t* shape, size_t i, Ix... index) const {
        if (i >= *shape) {
            throw index_error(std::string("index ") + std::to_string(i) +
                              " is out of bounds for axis " + std::to_string(axis) +
                              " with size " + std::to_string(*shape));
        }
        check_dimensions_impl(axis + 1, shape + 1, index...);
    }
568
569
570
571
572
573

    /// Create array from any object -- always returns a new reference
    static PyObject *raw_array(PyObject *ptr, int ExtraFlags = 0) {
        if (ptr == nullptr)
            return nullptr;
        return detail::npy_api::get().PyArray_FromAny_(
574
            ptr, nullptr, 0, 0, detail::npy_api::NPY_ARRAY_ENSUREARRAY_ | ExtraFlags, nullptr);
575
    }
Wenzel Jakob's avatar
Wenzel Jakob committed
576
577
};

578
template <typename T, int ExtraFlags = array::forcecast> class array_t : public array {
Wenzel Jakob's avatar
Wenzel Jakob committed
579
public:
580
581
582
    array_t() : array(0, static_cast<const T *>(nullptr)) {}
    array_t(handle h, borrowed_t) : array(h, borrowed) { }
    array_t(handle h, stolen_t) : array(h, stolen) { }
583

584
585
586
587
588
    PYBIND11_DEPRECATED("Use array_t<T>::ensure() instead")
    array_t(handle h, bool is_borrowed) : array(raw_array_t(h.ptr()), stolen) {
        if (!m_ptr) PyErr_Clear();
        if (!is_borrowed) Py_XDECREF(h.ptr());
    }
589

590
591
592
    array_t(const object &o) : array(raw_array_t(o.ptr()), stolen) {
        if (!m_ptr) throw error_already_set();
    }
593

594
    explicit array_t(const buffer_info& info) : array(info) { }
595

596
597
598
599
    array_t(const std::vector<size_t> &shape,
            const std::vector<size_t> &strides, const T *ptr = nullptr,
            handle base = handle())
        : array(shape, strides, ptr, base) { }
600

601
    explicit array_t(const std::vector<size_t> &shape, const T *ptr = nullptr,
602
603
            handle base = handle())
        : array(shape, ptr, base) { }
604

605
    explicit array_t(size_t count, const T *ptr = nullptr, handle base = handle())
606
        : array(count, ptr, base) { }
607

608
609
    constexpr size_t itemsize() const {
        return sizeof(T);
610
611
    }

612
    template<typename... Ix> size_t index_at(Ix... index) const {
613
614
615
        return offset_at(index...) / itemsize();
    }

616
    template<typename... Ix> const T* data(Ix... index) const {
617
618
619
        return static_cast<const T*>(array::data(index...));
    }

620
    template<typename... Ix> T* mutable_data(Ix... index) {
621
622
623
624
        return static_cast<T*>(array::mutable_data(index...));
    }

    // Reference to element at a given index
625
    template<typename... Ix> const T& at(Ix... index) const {
626
627
        if (sizeof...(index) != ndim())
            fail_dim_check(sizeof...(index), "index dimension mismatch");
628
        return *(static_cast<const T*>(array::data()) + byte_offset(size_t(index)...) / itemsize());
629
630
631
    }

    // Mutable reference to element at a given index
632
    template<typename... Ix> T& mutable_at(Ix... index) {
633
634
        if (sizeof...(index) != ndim())
            fail_dim_check(sizeof...(index), "index dimension mismatch");
635
        return *(static_cast<T*>(array::mutable_data()) + byte_offset(size_t(index)...) / itemsize());
636
    }
637

Jason Rhinelander's avatar
Jason Rhinelander committed
638
639
    /// Ensure that the argument is a NumPy array of the correct dtype (and if not, try to convert
    /// it).  In case of an error, nullptr is returned and the Python error is cleared.
640
641
    static array_t ensure(handle h) {
        auto result = reinterpret_steal<array_t>(raw_array_t(h.ptr()));
642
643
        if (!result)
            PyErr_Clear();
644
        return result;
Wenzel Jakob's avatar
Wenzel Jakob committed
645
    }
646

Wenzel Jakob's avatar
Wenzel Jakob committed
647
    static bool check_(handle h) {
648
649
        const auto &api = detail::npy_api::get();
        return api.PyArray_Check_(h.ptr())
650
               && api.PyArray_EquivTypes_(detail::array_proxy(h.ptr())->descr, dtype::of<T>().ptr());
651
652
653
654
655
656
657
658
659
    }

protected:
    /// Create array from any object -- always returns a new reference
    static PyObject *raw_array_t(PyObject *ptr) {
        if (ptr == nullptr)
            return nullptr;
        return detail::npy_api::get().PyArray_FromAny_(
            ptr, dtype::of<T>().release().ptr(), 0, 0,
660
            detail::npy_api::NPY_ARRAY_ENSUREARRAY_ | ExtraFlags, nullptr);
661
    }
Wenzel Jakob's avatar
Wenzel Jakob committed
662
663
};

664
template <typename T>
665
struct format_descriptor<T, detail::enable_if_t<detail::is_pod_struct<T>::value>> {
666
667
668
    static std::string format() {
        return detail::npy_format_descriptor<typename std::remove_cv<T>::type>::format();
    }
669
670
671
};

/// Fixed-length byte strings (`char[N]`) use the "<N>s" buffer format.
template <size_t N> struct format_descriptor<char[N]> {
    static std::string format() { return std::to_string(N).append("s"); }
};
/// `std::array<char, N>` is treated the same as `char[N]`.
template <size_t N> struct format_descriptor<std::array<char, N>> {
    static std::string format() { return std::to_string(N).append("s"); }
};

678
679
680
681
682
683
684
685
/// Enums use the buffer format of their (cv-stripped) underlying integer type.
template <typename T>
struct format_descriptor<T, detail::enable_if_t<std::is_enum<T>::value>> {
    static std::string format() {
        using underlying_t = typename std::underlying_type<T>::type;
        using bare_t = typename std::remove_cv<underlying_t>::type;
        return format_descriptor<bare_t>::format();
    }
};

686
NAMESPACE_BEGIN(detail)
687
688
689
690
/// Type caster for array_t arguments and return values.
template <typename T, int ExtraFlags>
struct pyobject_caster<array_t<T, ExtraFlags>> {
    using type = array_t<T, ExtraFlags>;

    /// Without `convert`, accept only arrays whose dtype already matches; otherwise let
    /// array_t::ensure() attempt a conversion. Returns whether a value was obtained.
    bool load(handle src, bool convert) {
        const bool exact_match_required = !convert;
        if (exact_match_required && !type::check_(src))
            return false;
        value = type::ensure(src);
        return static_cast<bool>(value);
    }

    /// Returning an array to Python just hands back a new reference to the same object.
    static handle cast(const handle &src, return_value_policy /* policy */, handle /* parent */) {
        return src.inc_ref();
    }
    PYBIND11_TYPE_CASTER(type, handle_type_name<type>::name());
};

704
/// NumPy descriptor mapping for arithmetic and std::complex types.
template <typename T> struct npy_format_descriptor<T, enable_if_t<satisfies_any_of<T, std::is_arithmetic, is_complex>::value>> {
private:
    // NB: the order here must match the one in common.h
    constexpr static const int values[15] = {
        npy_api::NPY_BOOL_,
        npy_api::NPY_BYTE_,   npy_api::NPY_UBYTE_,   npy_api::NPY_SHORT_,    npy_api::NPY_USHORT_,
        npy_api::NPY_INT_,    npy_api::NPY_UINT_,    npy_api::NPY_LONGLONG_, npy_api::NPY_ULONGLONG_,
        npy_api::NPY_FLOAT_,  npy_api::NPY_DOUBLE_,  npy_api::NPY_LONGDOUBLE_,
        npy_api::NPY_CFLOAT_, npy_api::NPY_CDOUBLE_, npy_api::NPY_CLONGDOUBLE_
    };

public:
    /// NumPy type number for T.
    static constexpr int value = values[detail::is_fmt_numeric<T>::index];

    /// NumPy descriptor for T; fails if NumPy does not know the type number.
    static pybind11::dtype dtype() {
        // PyArray_DescrFromType returns a new reference: take ownership of it rather than
        // adding another one (reinterpret_borrow leaked a reference on every call).
        if (auto ptr = npy_api::get().PyArray_DescrFromType_(value))
            return reinterpret_steal<pybind11::dtype>(ptr);
        pybind11_fail("Unsupported buffer format!");
    }

    /// Signature name for integral types: "bool", or "int<bits>" / "uint<bits>".
    template <typename T2 = T, enable_if_t<std::is_integral<T2>::value, int> = 0>
    static PYBIND11_DESCR name() {
        return _<std::is_same<T, bool>::value>(_("bool"),
            _<std::is_signed<T>::value>("int", "uint") + _<sizeof(T)*8>());
    }
    /// Signature name for floating types: "float<bits>" or "longdouble".
    template <typename T2 = T, enable_if_t<std::is_floating_point<T2>::value, int> = 0>
    static PYBIND11_DESCR name() {
        return _<std::is_same<T, float>::value || std::is_same<T, double>::value>(
                _("float") + _<sizeof(T)*8>(), _("longdouble"));
    }
    /// Signature name for complex types: "complex<bits>" or "longcomplex".
    template <typename T2 = T, enable_if_t<is_complex<T2>::value, int> = 0>
    static PYBIND11_DESCR name() {
        return _<std::is_same<typename T2::value_type, float>::value || std::is_same<typename T2::value_type, double>::value>(
                _("complex") + _<sizeof(typename T2::value_type)*16>(), _("longcomplex"));
    }
};
739
740

#define PYBIND11_DECL_CHAR_FMT \
741
    static PYBIND11_DESCR name() { return _("S") + _<N>(); } \
742
    static pybind11::dtype dtype() { return pybind11::dtype(std::string("S") + std::to_string(N)); }
743
744
745
template <size_t N> struct npy_format_descriptor<char[N]> { PYBIND11_DECL_CHAR_FMT };
template <size_t N> struct npy_format_descriptor<std::array<char, N>> { PYBIND11_DECL_CHAR_FMT };
#undef PYBIND11_DECL_CHAR_FMT
746

747
748
749
750
751
752
753
754
// Enumerations reuse the descriptor of their underlying integer type.
template <typename T>
struct npy_format_descriptor<T, enable_if_t<std::is_enum<T>::value>> {
private:
    using underlying = typename std::underlying_type<T>::type;
    using delegate = npy_format_descriptor<underlying>;

public:
    /// Same descriptor name as the underlying integer type.
    static PYBIND11_DESCR name() {
        return delegate::name();
    }

    /// Same dtype as the underlying integer type.
    static pybind11::dtype dtype() {
        return delegate::dtype();
    }
};

755
756
struct field_descriptor {
    const char *name;
757
    size_t offset;
758
    size_t size;
759
    size_t alignment;
760
    std::string format;
761
    dtype descr;
762
763
};

764
765
766
// Builds a structured NumPy dtype for the C++ type `tinfo` (of size `itemsize`) from `fields`,
// stores it together with a buffer-protocol format string in the numpy internals registry,
// and appends `direct_converter` to the type's direct conversions.
// Calls pybind11_fail if the type is already registered, if a field has no dtype, or if NumPy
// does not parse the generated format string into an equivalent dtype.
inline PYBIND11_NOINLINE void register_structured_dtype(
    const std::initializer_list<field_descriptor>& fields,
    const std::type_info& tinfo, size_t itemsize,
    bool (*direct_converter)(PyObject *, void *&)) {

    auto& numpy_internals = get_numpy_internals();
    if (numpy_internals.get_type_info(tinfo, false))
        pybind11_fail("NumPy: dtype is already registered");

    list names, formats, offsets;
    for (const auto& field : fields) {
        if (!field.descr)
            pybind11_fail(std::string("NumPy: unsupported field dtype: `") +
                            field.name + "` @ " + tinfo.name());
        names.append(PYBIND11_STR_TYPE(field.name));
        formats.append(field.descr);
        offsets.append(pybind11::int_(field.offset));
    }
    // Keep the new dtype owned until registration succeeds, so a failing sanity check
    // below does not leak it; the reference is handed to the registry at the end.
    auto dtype_obj = pybind11::dtype(names, formats, offsets, itemsize);

    // There is an existing bug in NumPy (as of v1.11): trailing bytes are
    // not encoded explicitly into the format string. This will supposedly
    // get fixed in v1.12; for further details, see these:
    // - https://github.com/numpy/numpy/issues/7797
    // - https://github.com/numpy/numpy/pull/7798
    // Because of this, we won't use numpy's logic to generate buffer format
    // strings and will just do it ourselves.
    std::vector<field_descriptor> ordered_fields(fields);
    std::sort(ordered_fields.begin(), ordered_fields.end(),
        [](const field_descriptor &a, const field_descriptor &b) { return a.offset < b.offset; });
    size_t offset = 0;
    std::ostringstream oss;
    oss << "T{";
    for (const auto& field : ordered_fields) {
        // explicit padding bytes between the previous field and this one
        if (field.offset > offset)
            oss << (field.offset - offset) << 'x';
        // mark unaligned fields with '^' (unaligned native type)
        if (field.offset % field.alignment)
            oss << '^';
        oss << field.format << ':' << field.name << ':';
        offset = field.offset + field.size;
    }
    // explicit trailing padding up to the full item size
    if (itemsize > offset)
        oss << (itemsize - offset) << 'x';
    oss << '}';
    auto format_str = oss.str();

    // Sanity check: verify that NumPy properly parses our buffer format string
    auto& api = npy_api::get();
    auto arr = array(buffer_info(nullptr, itemsize, format_str, 1));
    if (!api.PyArray_EquivTypes_(dtype_obj.ptr(), arr.dtype().ptr()))
        pybind11_fail("NumPy: invalid buffer descriptor!");

    auto tindex = std::type_index(tinfo);
    numpy_internals.registered_dtypes[tindex] = { dtype_obj.release().ptr(), format_str };
    get_internals().direct_conversions[tindex].push_back(direct_converter);
}

template <typename T, typename SFINAE> struct npy_format_descriptor {
    static_assert(is_pod_struct<T>::value, "Attempt to use a non-POD or unimplemented POD type as a numpy dtype");

825
    static PYBIND11_DESCR name() { return _("struct"); }
826

827
    static pybind11::dtype dtype() {
828
        return reinterpret_borrow<pybind11::dtype>(dtype_ptr());
829
830
    }

831
    static std::string format() {
832
        static auto format_str = get_numpy_internals().get_type_info<T>(true)->format_str;
833
        return format_str;
834
835
    }

836
837
838
    static void register_dtype(const std::initializer_list<field_descriptor>& fields) {
        register_structured_dtype(fields, typeid(typename std::remove_cv<T>::type),
                                  sizeof(T), &direct_converter);
839
840
841
    }

private:
842
843
844
845
    static PyObject* dtype_ptr() {
        static PyObject* ptr = get_numpy_internals().get_type_info<T>(true)->dtype_ptr;
        return ptr;
    }
846

847
848
849
    static bool direct_converter(PyObject *obj, void*& value) {
        auto& api = npy_api::get();
        if (!PyObject_TypeCheck(obj, api.PyVoidArrType_Type_))
850
            return false;
851
        if (auto descr = reinterpret_steal<object>(api.PyArray_DescrFromScalar_(obj))) {
852
            if (api.PyArray_EquivTypes_(dtype_ptr(), descr.ptr())) {
853
854
855
856
857
858
                value = ((PyVoidScalarObject_Proxy *) obj)->obval;
                return true;
            }
        }
        return false;
    }
859
860
};

861
862
863
// Build a field_descriptor for member `Field` of struct `T`, named `Name` in the dtype.
#define PYBIND11_FIELD_DESCRIPTOR_EX(T, Field, Name)                                          \
    ::pybind11::detail::field_descriptor {                                                    \
        Name, offsetof(T, Field), sizeof(decltype(std::declval<T>().Field)),                  \
        alignof(decltype(std::declval<T>().Field)),                                           \
        ::pybind11::format_descriptor<decltype(std::declval<T>().Field)>::format(),           \
        ::pybind11::detail::npy_format_descriptor<decltype(std::declval<T>().Field)>::dtype() \
    }

// Extract name, offset and format descriptor for a struct field
#define PYBIND11_FIELD_DESCRIPTOR(T, Field) PYBIND11_FIELD_DESCRIPTOR_EX(T, Field, #Field)

// The main idea of this macro is borrowed from https://github.com/swansontec/map-macro
// (C) William Swanson, Paul Fultz
// PYBIND11_EVAL forces many rescans so that the self-referencing MAP_LIST steps expand.
#define PYBIND11_EVAL0(...) __VA_ARGS__
#define PYBIND11_EVAL1(...) PYBIND11_EVAL0 (PYBIND11_EVAL0 (PYBIND11_EVAL0 (__VA_ARGS__)))
#define PYBIND11_EVAL2(...) PYBIND11_EVAL1 (PYBIND11_EVAL1 (PYBIND11_EVAL1 (__VA_ARGS__)))
#define PYBIND11_EVAL3(...) PYBIND11_EVAL2 (PYBIND11_EVAL2 (PYBIND11_EVAL2 (__VA_ARGS__)))
#define PYBIND11_EVAL4(...) PYBIND11_EVAL3 (PYBIND11_EVAL3 (PYBIND11_EVAL3 (__VA_ARGS__)))
#define PYBIND11_EVAL(...)  PYBIND11_EVAL4 (PYBIND11_EVAL4 (PYBIND11_EVAL4 (__VA_ARGS__)))
#define PYBIND11_MAP_END(...)
#define PYBIND11_MAP_OUT
#define PYBIND11_MAP_COMMA ,
// End-of-list detection: applied to the `()` sentinel, GET_END yields "0, MAP_END",
// which shifts MAP_END into the `next` position and stops the recursion.
#define PYBIND11_MAP_GET_END() 0, PYBIND11_MAP_END
#define PYBIND11_MAP_NEXT0(test, next, ...) next PYBIND11_MAP_OUT
#define PYBIND11_MAP_NEXT1(test, next) PYBIND11_MAP_NEXT0 (test, next, 0)
#define PYBIND11_MAP_NEXT(test, next)  PYBIND11_MAP_NEXT1 (PYBIND11_MAP_GET_END test, next)
#ifdef _MSC_VER // MSVC is not as eager to expand macros, hence this workaround
#define PYBIND11_MAP_LIST_NEXT1(test, next) \
    PYBIND11_EVAL0 (PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0))
#else
#define PYBIND11_MAP_LIST_NEXT1(test, next) \
    PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0)
#endif
#define PYBIND11_MAP_LIST_NEXT(test, next) \
    PYBIND11_MAP_LIST_NEXT1 (PYBIND11_MAP_GET_END test, next)
#define PYBIND11_MAP_LIST0(f, t, x, peek, ...) \
    f(t, x) PYBIND11_MAP_LIST_NEXT (peek, PYBIND11_MAP_LIST1) (f, t, peek, __VA_ARGS__)
#define PYBIND11_MAP_LIST1(f, t, x, peek, ...) \
    f(t, x) PYBIND11_MAP_LIST_NEXT (peek, PYBIND11_MAP_LIST0) (f, t, peek, __VA_ARGS__)
// PYBIND11_MAP_LIST(f, t, a1, a2, ...) expands to f(t, a1), f(t, a2), ...
#define PYBIND11_MAP_LIST(f, t, ...) \
    PYBIND11_EVAL (PYBIND11_MAP_LIST1 (f, t, __VA_ARGS__, (), 0))

// Register struct `Type` as a NumPy dtype from the listed members; each field is named
// after its member, e.g. PYBIND11_NUMPY_DTYPE(A, x, y) gives fields "x" and "y".
#define PYBIND11_NUMPY_DTYPE(Type, ...) \
    ::pybind11::detail::npy_format_descriptor<Type>::register_dtype \
        ({PYBIND11_MAP_LIST (PYBIND11_FIELD_DESCRIPTOR, Type, __VA_ARGS__)})

#ifdef _MSC_VER
#define PYBIND11_MAP2_LIST_NEXT1(test, next) \
    PYBIND11_EVAL0 (PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0))
#else
#define PYBIND11_MAP2_LIST_NEXT1(test, next) \
    PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0)
#endif
#define PYBIND11_MAP2_LIST_NEXT(test, next) \
    PYBIND11_MAP2_LIST_NEXT1 (PYBIND11_MAP_GET_END test, next)
// Pairwise steps: each one consumes two arguments (x1, x2) before peeking at the next.
#define PYBIND11_MAP2_LIST0(f, t, x1, x2, peek, ...) \
    f(t, x1, x2) PYBIND11_MAP2_LIST_NEXT (peek, PYBIND11_MAP2_LIST1) (f, t, peek, __VA_ARGS__)
#define PYBIND11_MAP2_LIST1(f, t, x1, x2, peek, ...) \
    f(t, x1, x2) PYBIND11_MAP2_LIST_NEXT (peek, PYBIND11_MAP2_LIST0) (f, t, peek, __VA_ARGS__)
// PYBIND11_MAP2_LIST(f, t, a1, a2, ...) expands to f(t, a1, a2), f(t, a3, a4), ...
#define PYBIND11_MAP2_LIST(f, t, ...) \
    PYBIND11_EVAL (PYBIND11_MAP2_LIST1 (f, t, __VA_ARGS__, (), 0))

// Register struct `Type` as a NumPy dtype from (member, field-name) pairs,
// e.g. PYBIND11_NUMPY_DTYPE_EX(A, x, "x", y, "y").
#define PYBIND11_NUMPY_DTYPE_EX(Type, ...) \
    ::pybind11::detail::npy_format_descriptor<Type>::register_dtype \
        ({PYBIND11_MAP2_LIST (PYBIND11_FIELD_DESCRIPTOR_EX, Type, __VA_ARGS__)})

/// Iterator used to walk a contiguous buffer linearly: a plain pointer to T.
template <class T>
using array_iterator = typename std::add_pointer<T>::type;

/// Iterator to the first element of `buffer`, viewed as an array of T.
template <class T>
array_iterator<T> array_begin(const buffer_info& buffer) {
    T *first = reinterpret_cast<T*>(buffer.ptr);
    return array_iterator<T>(first);
}

/// One-past-the-end iterator: `buffer.size` elements after array_begin.
template <class T>
array_iterator<T> array_end(const buffer_info& buffer) {
    array_iterator<T> first = array_begin<T>(buffer);
    return first + buffer.size;
}

// Walks one operand of a broadcast N-d iteration. Instead of raw strides it stores, per
// dimension, the pointer delta to apply when that dimension's index is incremented while
// all inner (faster-varying) indices wrap from their last value back to zero.
class common_iterator {
public:
    using container_type = std::vector<size_t>;
    using value_type = container_type::value_type;
    using size_type = container_type::size_type;

    common_iterator() : p_ptr(nullptr), m_strides() {}

    /// @param ptr     base address of the operand's data
    /// @param strides byte strides of the operand, one per dimension (0 for broadcast dims)
    /// @param shape   extents of the iteration space (same length as `strides`)
    common_iterator(void* ptr, const container_type& strides, const std::vector<size_t>& shape)
        : p_ptr(reinterpret_cast<char*>(ptr)), m_strides(strides.size()) {
        // Zero-dimensional iteration has nothing to step; without this guard back() on an
        // empty vector and `size() - 1` below would be undefined / wrap around.
        if (m_strides.empty())
            return;
        m_strides.back() = static_cast<value_type>(strides.back());
        for (size_type i = m_strides.size() - 1; i != 0; --i) {
            size_type j = i - 1;
            value_type s = static_cast<value_type>(shape[i]);
            // Unsigned arithmetic: a negative delta is represented by modular wrap-around.
            m_strides[j] = strides[j] + m_strides[i] - strides[i] * s;
        }
    }

    /// Steps along dimension `dim` (inner dimensions are assumed to have just wrapped).
    void increment(size_type dim) {
        // Reinterpret the (possibly wrapped) unsigned delta as a signed offset, so backward
        // steps are ordinary pointer arithmetic instead of pointer overflow.
        p_ptr += static_cast<container_type::difference_type>(m_strides[dim]);
    }

    void* data() const {
        return p_ptr;
    }

private:
    char* p_ptr;
    container_type m_strides;
};

template <size_t N> class multi_array_iterator {
974
975
976
public:
    using container_type = std::vector<size_t>;

977
978
979
980
981
    multi_array_iterator(const std::array<buffer_info, N> &buffers,
                         const std::vector<size_t> &shape)
        : m_shape(shape.size()), m_index(shape.size(), 0),
          m_common_iterator() {

982
        // Manual copy to avoid conversion warning if using std::copy
983
        for (size_t i = 0; i < shape.size(); ++i)
984
985
986
            m_shape[i] = static_cast<container_type::value_type>(shape[i]);

        container_type strides(shape.size());
987
        for (size_t i = 0; i < N; ++i)
988
989
990
991
992
993
994
995
996
            init_common_iterator(buffers[i], shape, m_common_iterator[i], strides);
    }

    multi_array_iterator& operator++() {
        for (size_t j = m_index.size(); j != 0; --j) {
            size_t i = j - 1;
            if (++m_index[i] != m_shape[i]) {
                increment_common_iterator(i);
                break;
997
            } else {
998
999
1000
1001
1002
1003
                m_index[i] = 0;
            }
        }
        return *this;
    }

1004
    template <size_t K, class T> const T& data() const {
1005
1006
1007
1008
1009
1010
1011
        return *reinterpret_cast<T*>(m_common_iterator[K].data());
    }

private:

    using common_iter = common_iterator;

1012
1013
1014
    void init_common_iterator(const buffer_info &buffer,
                              const std::vector<size_t> &shape,
                              common_iter &iterator, container_type &strides) {
1015
1016
1017
1018
1019
1020
1021
        auto buffer_shape_iter = buffer.shape.rbegin();
        auto buffer_strides_iter = buffer.strides.rbegin();
        auto shape_iter = shape.rbegin();
        auto strides_iter = strides.rbegin();

        while (buffer_shape_iter != buffer.shape.rend()) {
            if (*shape_iter == *buffer_shape_iter)
1022
                *strides_iter = static_cast<size_t>(*buffer_strides_iter);
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
            else
                *strides_iter = 0;

            ++buffer_shape_iter;
            ++buffer_strides_iter;
            ++shape_iter;
            ++strides_iter;
        }

        std::fill(strides_iter, strides.rend(), 0);
        iterator = common_iter(buffer.ptr, strides, shape);
    }

    void increment_common_iterator(size_t dim) {
1037
        for (auto &iter : m_common_iterator)
1038
1039
1040
1041
1042
1043
1044
1045
1046
            iter.increment(dim);
    }

    container_type m_shape;
    container_type m_index;
    std::array<common_iter, N> m_common_iterator;
};

template <size_t N>
1047
1048
bool broadcast(const std::array<buffer_info, N>& buffers, size_t& ndim, std::vector<size_t>& shape) {
    ndim = std::accumulate(buffers.begin(), buffers.end(), size_t(0), [](size_t res, const buffer_info& buf) {
1049
1050
1051
        return std::max(res, buf.ndim);
    });

1052
    shape = std::vector<size_t>(ndim, 1);
1053
1054
1055
1056
    bool trivial_broadcast = true;
    for (size_t i = 0; i < N; ++i) {
        auto res_iter = shape.rbegin();
        bool i_trivial_broadcast = (buffers[i].size == 1) || (buffers[i].ndim == ndim);
1057
1058
1059
1060
        for (auto shape_iter = buffers[i].shape.rbegin();
             shape_iter != buffers[i].shape.rend(); ++shape_iter, ++res_iter) {

            if (*res_iter == 1)
1061
                *res_iter = *shape_iter;
1062
            else if ((*shape_iter != 1) && (*res_iter != *shape_iter))
1063
                pybind11_fail("pybind11::vectorize: incompatible size/dimension of inputs!");
1064

1065
1066
1067
1068
1069
1070
1071
            i_trivial_broadcast = i_trivial_broadcast && (*res_iter == *shape_iter);
        }
        trivial_broadcast = trivial_broadcast && i_trivial_broadcast;
    }
    return trivial_broadcast;
}

1072
1073
1074
1075
template <typename Func, typename Return, typename... Args>
struct vectorize_helper {
    typename std::remove_reference<Func>::type f;

1076
    template <typename T>
1077
    explicit vectorize_helper(T&&f) : f(std::forward<T>(f)) { }
Wenzel Jakob's avatar
Wenzel Jakob committed
1078

1079
    object operator()(array_t<Args, array::c_style | array::forcecast>... args) {
1080
        return run(args..., make_index_sequence<sizeof...(Args)>());
1081
    }
Wenzel Jakob's avatar
Wenzel Jakob committed
1082

1083
    template <size_t ... Index> object run(array_t<Args, array::c_style | array::forcecast>&... args, index_sequence<Index...> index) {
Wenzel Jakob's avatar
Wenzel Jakob committed
1084
        /* Request buffers from all parameters */
1085
        const size_t N = sizeof...(Args);
1086

Wenzel Jakob's avatar
Wenzel Jakob committed
1087
1088
1089
        std::array<buffer_info, N> buffers {{ args.request()... }};

        /* Determine dimensions parameters of output array */
1090
        size_t ndim = 0;
1091
1092
        std::vector<size_t> shape(0);
        bool trivial_broadcast = broadcast(buffers, ndim, shape);
1093

1094
        size_t size = 1;
Wenzel Jakob's avatar
Wenzel Jakob committed
1095
1096
        std::vector<size_t> strides(ndim);
        if (ndim > 0) {
1097
            strides[ndim-1] = sizeof(Return);
1098
            for (size_t i = ndim - 1; i > 0; --i) {
1099
1100
1101
1102
                strides[i - 1] = strides[i] * shape[i];
                size *= shape[i];
            }
            size *= shape[0];
Wenzel Jakob's avatar
Wenzel Jakob committed
1103
1104
        }

1105
        if (size == 1)
1106
            return cast(f(*((Args *) buffers[Index].ptr)...));
Wenzel Jakob's avatar
Wenzel Jakob committed
1107

1108
1109
1110
        array_t<Return> result(shape, strides);
        auto buf = result.request();
        auto output = (Return *) buf.ptr;
1111

1112
        if (trivial_broadcast) {
1113
            /* Call the function */
1114
            for (size_t i = 0; i < size; ++i) {
1115
                output[i] = f((buffers[Index].size == 1
1116
1117
                               ? *((Args *) buffers[Index].ptr)
                               : ((Args *) buffers[Index].ptr)[i])...);
1118
            }
1119
        } else {
1120
1121
            apply_broadcast<N, Index...>(buffers, buf, index);
        }
1122
1123

        return result;
1124
    }
1125
1126

    template <size_t N, size_t... Index>
1127
1128
    void apply_broadcast(const std::array<buffer_info, N> &buffers,
                         buffer_info &output, index_sequence<Index...>) {
1129
1130
1131
1132
1133
1134
        using input_iterator = multi_array_iterator<N>;
        using output_iterator = array_iterator<Return>;

        input_iterator input_iter(buffers, output.shape);
        output_iterator output_end = array_end<Return>(output);

1135
1136
        for (output_iterator iter = array_begin<Return>(output);
             iter != output_end; ++iter, ++input_iter) {
1137
1138
1139
            *iter = f((input_iter.template data<Index, Args>())...);
        }
    }
1140
1141
};

1142
template <typename T, int Flags> struct handle_type_name<array_t<T, Flags>> {
1143
    static PYBIND11_DESCR name() { return _("numpy.ndarray[") + make_caster<T>::name() + _("]"); }
1144
1145
};

1146
NAMESPACE_END(detail)
Wenzel Jakob's avatar
Wenzel Jakob committed
1147

1148
template <typename Func, typename Return, typename... Args>
1149
detail::vectorize_helper<Func, Return, Args...>
1150
vectorize(const Func &f, Return (*) (Args ...)) {
1151
    return detail::vectorize_helper<Func, Return, Args...>(f);
Wenzel Jakob's avatar
Wenzel Jakob committed
1152
1153
}

1154
1155
1156
/// Vectorize a plain function pointer.
template <typename Return, typename... Args>
detail::vectorize_helper<Return (*) (Args ...), Return, Args...>
vectorize(Return (*f) (Args ...)) {
    return vectorize<Return (*) (Args ...), Return, Args...>(f, f);
}

/// Vectorize a callable object (e.g. a lambda); the element signature is deduced
/// from its operator().
template <typename Func, typename FuncType = typename detail::remove_class<decltype(&std::remove_reference<Func>::type::operator())>::type>
auto vectorize(Func &&f) -> decltype(
        vectorize(std::forward<Func>(f), (FuncType *) nullptr)) {
    return vectorize(std::forward<Func>(f), (FuncType *) nullptr);
}

NAMESPACE_END(pybind11)

#if defined(_MSC_VER)
#pragma warning(pop)
#endif