/*
    pybind11/numpy.h: Basic NumPy support, vectorize() wrapper

    Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>

    All rights reserved. Use of this source code is governed by a
    BSD-style license that can be found in the LICENSE file.
*/

#pragma once

#include "pybind11.h"
#include "complex.h"

#include <algorithm>
#include <array>
#include <cstdlib>
#include <cstring>
#include <functional>
#include <initializer_list>
#include <numeric>
#include <sstream>
#include <string>
#include <typeindex>
#include <unordered_map>
#include <utility>
#include <vector>

#if defined(_MSC_VER)
27
28
#  pragma warning(push)
#  pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
Wenzel Jakob's avatar
Wenzel Jakob committed
29
30
#endif

31
32
33
34
35
36
/* This will be true on all flat address space platforms and allows us to reduce the
   whole npy_intp / size_t / Py_intptr_t business down to just size_t for all size
   and dimension types (e.g. shape, strides, indexing), instead of inflicting this
   upon the library user. */
static_assert(sizeof(size_t) == sizeof(Py_intptr_t), "size_t != Py_intptr_t");

37
NAMESPACE_BEGIN(pybind11)
38
NAMESPACE_BEGIN(detail)
/// Maps a C++ type to its NumPy dtype and name; concrete specializations
/// (arithmetic/complex, char arrays, enums) are provided later in this header.
template <typename type, typename SFINAE = void> struct npy_format_descriptor;
Wenzel Jakob's avatar
Wenzel Jakob committed
40

41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
struct PyArrayDescr_Proxy {
    PyObject_HEAD
    PyObject *typeobj;
    char kind;
    char type;
    char byteorder;
    char flags;
    int type_num;
    int elsize;
    int alignment;
    char *subarray;
    PyObject *fields;
    PyObject *names;
};

/// Layout-compatible mirror of the leading fields of NumPy's array object,
/// accessed through `array_proxy()` so array attributes can be read without
/// NumPy's C headers. Do not reorder or retype fields: the layout must match.
struct PyArray_Proxy {
    PyObject_HEAD
    char *data;             // start of the array's data buffer
    int nd;                 // number of dimensions
    ssize_t *dimensions;    // extent of each dimension (nd entries)
    ssize_t *strides;       // byte stride of each dimension (nd entries)
    PyObject *base;         // object kept alive as the owner of the data
    PyObject *descr;        // dtype descriptor object
    int flags;              // NPY_ARRAY_* flag bits
};

67
68
69
70
71
72
73
74
struct PyVoidScalarObject_Proxy {
    PyObject_VAR_HEAD
    char *obval;
    PyArrayDescr_Proxy *descr;
    int flags;
    PyObject *base;
};

75
76
77
78
79
80
81
82
struct numpy_type_info {
    PyObject* dtype_ptr;
    std::string format_str;
};

struct numpy_internals {
    std::unordered_map<std::type_index, numpy_type_info> registered_dtypes;

83
84
    numpy_type_info *get_type_info(const std::type_info& tinfo, bool throw_if_missing = true) {
        auto it = registered_dtypes.find(std::type_index(tinfo));
85
86
87
        if (it != registered_dtypes.end())
            return &(it->second);
        if (throw_if_missing)
88
            pybind11_fail(std::string("NumPy type info missing for ") + tinfo.name());
89
90
        return nullptr;
    }
91
92
93
94

    template<typename T> numpy_type_info *get_type_info(bool throw_if_missing = true) {
        return get_type_info(typeid(typename std::remove_cv<T>::type), throw_if_missing);
    }
95
96
};

Ivan Smirnov's avatar
Ivan Smirnov committed
97
98
inline PYBIND11_NOINLINE void load_numpy_internals(numpy_internals* &ptr) {
    ptr = &get_or_create_shared_data<numpy_internals>("_numpy_internals");
99
100
101
}

/// Return the shared NumPy dtype registry. The pointer is resolved on the first
/// call and cached in a function-local static afterwards.
inline numpy_internals& get_numpy_internals() {
    static numpy_internals* ptr = nullptr;
    if (!ptr)
        load_numpy_internals(ptr);
    return *ptr;
}

108
109
110
111
struct npy_api {
    enum constants {
        NPY_C_CONTIGUOUS_ = 0x0001,
        NPY_F_CONTIGUOUS_ = 0x0002,
112
        NPY_ARRAY_OWNDATA_ = 0x0004,
113
114
        NPY_ARRAY_FORCECAST_ = 0x0010,
        NPY_ENSURE_ARRAY_ = 0x0040,
115
116
        NPY_ARRAY_ALIGNED_ = 0x0100,
        NPY_ARRAY_WRITEABLE_ = 0x0400,
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
        NPY_BOOL_ = 0,
        NPY_BYTE_, NPY_UBYTE_,
        NPY_SHORT_, NPY_USHORT_,
        NPY_INT_, NPY_UINT_,
        NPY_LONG_, NPY_ULONG_,
        NPY_LONGLONG_, NPY_ULONGLONG_,
        NPY_FLOAT_, NPY_DOUBLE_, NPY_LONGDOUBLE_,
        NPY_CFLOAT_, NPY_CDOUBLE_, NPY_CLONGDOUBLE_,
        NPY_OBJECT_ = 17,
        NPY_STRING_, NPY_UNICODE_, NPY_VOID_
    };

    static npy_api& get() {
        static npy_api api = lookup();
        return api;
    }

134
135
136
137
138
139
    bool PyArray_Check_(PyObject *obj) const {
        return (bool) PyObject_TypeCheck(obj, PyArray_Type_);
    }
    bool PyArrayDescr_Check_(PyObject *obj) const {
        return (bool) PyObject_TypeCheck(obj, PyArrayDescr_Type_);
    }
140
141
142
143
144
145
146
147

    PyObject *(*PyArray_DescrFromType_)(int);
    PyObject *(*PyArray_NewFromDescr_)
        (PyTypeObject *, PyObject *, int, Py_intptr_t *,
         Py_intptr_t *, void *, int, PyObject *);
    PyObject *(*PyArray_DescrNewFromType_)(int);
    PyObject *(*PyArray_NewCopy_)(PyObject *, int);
    PyTypeObject *PyArray_Type_;
148
    PyTypeObject *PyVoidArrType_Type_;
149
    PyTypeObject *PyArrayDescr_Type_;
150
    PyObject *(*PyArray_DescrFromScalar_)(PyObject *);
151
152
153
154
155
    PyObject *(*PyArray_FromAny_) (PyObject *, PyObject *, int, int, int, PyObject *);
    int (*PyArray_DescrConverter_) (PyObject *, PyObject **);
    bool (*PyArray_EquivTypes_) (PyObject *, PyObject *);
    int (*PyArray_GetArrayParamsFromObject_)(PyObject *, PyObject *, char, PyObject **, int *,
                                             Py_ssize_t *, PyObject **, PyObject *);
156
    PyObject *(*PyArray_Squeeze_)(PyObject *);
157
158
159
private:
    enum functions {
        API_PyArray_Type = 2,
160
        API_PyArrayDescr_Type = 3,
161
        API_PyVoidArrType_Type = 39,
162
        API_PyArray_DescrFromType = 45,
163
        API_PyArray_DescrFromScalar = 57,
164
165
166
167
168
169
170
        API_PyArray_FromAny = 69,
        API_PyArray_NewCopy = 85,
        API_PyArray_NewFromDescr = 94,
        API_PyArray_DescrNewFromType = 9,
        API_PyArray_DescrConverter = 174,
        API_PyArray_EquivTypes = 182,
        API_PyArray_GetArrayParamsFromObject = 278,
171
        API_PyArray_Squeeze = 136
172
173
174
175
    };

    static npy_api lookup() {
        module m = module::import("numpy.core.multiarray");
176
        auto c = m.attr("_ARRAY_API");
177
#if PY_MAJOR_VERSION >= 3
178
        void **api_ptr = (void **) PyCapsule_GetPointer(c.ptr(), NULL);
179
#else
180
        void **api_ptr = (void **) PyCObject_AsVoidPtr(c.ptr());
181
#endif
182
        npy_api api;
183
#define DECL_NPY_API(Func) api.Func##_ = (decltype(api.Func##_)) api_ptr[API_##Func];
184
        DECL_NPY_API(PyArray_Type);
185
        DECL_NPY_API(PyVoidArrType_Type);
186
        DECL_NPY_API(PyArrayDescr_Type);
187
        DECL_NPY_API(PyArray_DescrFromType);
188
        DECL_NPY_API(PyArray_DescrFromScalar);
189
190
191
192
193
194
195
        DECL_NPY_API(PyArray_FromAny);
        DECL_NPY_API(PyArray_NewCopy);
        DECL_NPY_API(PyArray_NewFromDescr);
        DECL_NPY_API(PyArray_DescrNewFromType);
        DECL_NPY_API(PyArray_DescrConverter);
        DECL_NPY_API(PyArray_EquivTypes);
        DECL_NPY_API(PyArray_GetArrayParamsFromObject);
196
        DECL_NPY_API(PyArray_Squeeze);
197
#undef DECL_NPY_API
198
199
200
        return api;
    }
};
/// View a NumPy array object through its layout proxy (mutable access).
inline PyArray_Proxy* array_proxy(void* ptr) {
    return reinterpret_cast<PyArray_Proxy*>(ptr);
}

/// View a NumPy array object through its layout proxy (read-only access).
inline const PyArray_Proxy* array_proxy(const void* ptr) {
    return reinterpret_cast<const PyArray_Proxy*>(ptr);
}

/// View a dtype object through its layout proxy (mutable access).
inline PyArrayDescr_Proxy* array_descriptor_proxy(PyObject* ptr) {
    auto *descr_view = reinterpret_cast<PyArrayDescr_Proxy*>(ptr);
    return descr_view;
}

/// View a dtype object through its layout proxy (read-only access).
inline const PyArrayDescr_Proxy* array_descriptor_proxy(const PyObject* ptr) {
    auto *descr_view = reinterpret_cast<const PyArrayDescr_Proxy*>(ptr);
    return descr_view;
}

/// True when every bit in `flag` is set in the array's NumPy flags.
inline bool check_flags(const void* ptr, int flag) {
    const int matched_bits = array_proxy(ptr)->flags & flag;
    return matched_bits == flag;
}

/// Trait: true exactly for `std::array<T, N>`.
template <typename T> struct is_std_array : std::false_type { };
template <typename T, size_t N> struct is_std_array<std::array<T, N>> : std::true_type { };
/// Trait: true exactly for `std::complex<T>`.
template <typename T> struct is_complex : std::false_type { };
template <typename T> struct is_complex<std::complex<T>> : std::true_type { };

/// Trait: POD types that can be registered as NumPy structured dtypes, i.e. PODs
/// that are not references, C arrays, std::arrays, arithmetic, complex or enum types.
template <typename T> using is_pod_struct = all_of<
    std::is_pod<T>, // since we're accessing directly in memory we need a POD type
    satisfies_none_of<T, std::is_reference, std::is_array, is_std_array, std::is_arithmetic, is_complex, std::is_enum>
>;
NAMESPACE_END(detail)
233

234
class dtype : public object {
235
public:
236
    PYBIND11_OBJECT_DEFAULT(dtype, object, detail::npy_api::get().PyArrayDescr_Check_);
Wenzel Jakob's avatar
Wenzel Jakob committed
237

238
    explicit dtype(const buffer_info &info) {
239
        dtype descr(_dtype_from_pep3118()(PYBIND11_STR_TYPE(info.format)));
240
241
        // If info.itemsize == 0, use the value calculated from the format string
        m_ptr = descr.strip_padding(info.itemsize ? info.itemsize : descr.itemsize()).release().ptr();
242
    }
243

244
    explicit dtype(const std::string &format) {
245
        m_ptr = from_args(pybind11::str(format)).release().ptr();
Wenzel Jakob's avatar
Wenzel Jakob committed
246
247
    }

248
    dtype(const char *format) : dtype(std::string(format)) { }
249

250
251
252
253
254
    dtype(list names, list formats, list offsets, size_t itemsize) {
        dict args;
        args["names"] = names;
        args["formats"] = formats;
        args["offsets"] = offsets;
255
        args["itemsize"] = pybind11::int_(itemsize);
256
257
258
        m_ptr = from_args(args).release().ptr();
    }

Ivan Smirnov's avatar
Ivan Smirnov committed
259
    /// This is essentially the same as calling numpy.dtype(args) in Python.
260
261
262
    static dtype from_args(object args) {
        PyObject *ptr = nullptr;
        if (!detail::npy_api::get().PyArray_DescrConverter_(args.release().ptr(), &ptr) || !ptr)
263
            throw error_already_set();
264
        return reinterpret_steal<dtype>(ptr);
265
    }
266

Ivan Smirnov's avatar
Ivan Smirnov committed
267
    /// Return dtype associated with a C++ type.
268
    template <typename T> static dtype of() {
269
        return detail::npy_format_descriptor<typename std::remove_cv<T>::type>::dtype();
270
    }
271

Ivan Smirnov's avatar
Ivan Smirnov committed
272
    /// Size of the data type in bytes.
273
    size_t itemsize() const {
274
        return (size_t) detail::array_descriptor_proxy(m_ptr)->elsize;
Wenzel Jakob's avatar
Wenzel Jakob committed
275
276
    }

Ivan Smirnov's avatar
Ivan Smirnov committed
277
    /// Returns true for structured data types.
278
    bool has_fields() const {
279
        return detail::array_descriptor_proxy(m_ptr)->names != nullptr;
280
281
    }

Ivan Smirnov's avatar
Ivan Smirnov committed
282
    /// Single-character type code.
283
    char kind() const {
284
        return detail::array_descriptor_proxy(m_ptr)->kind;
285
286
287
    }

private:
288
289
290
    static object _dtype_from_pep3118() {
        static PyObject *obj = module::import("numpy.core._internal")
            .attr("_dtype_from_pep3118").cast<object>().release().ptr();
291
        return reinterpret_borrow<object>(obj);
292
    }
293

294
    dtype strip_padding(size_t itemsize) {
295
296
        // Recursively strip all void fields with empty names that are generated for
        // padding fields (as of NumPy v1.11).
297
        if (!has_fields())
298
            return *this;
299

300
        struct field_descr { PYBIND11_STR_TYPE name; object format; pybind11::int_ offset; };
301
302
        std::vector<field_descr> field_descriptors;

303
        for (auto field : attr("fields").attr("items")()) {
304
            auto spec = field.cast<tuple>();
305
            auto name = spec[0].cast<pybind11::str>();
306
            auto format = spec[1].cast<tuple>()[0].cast<dtype>();
307
            auto offset = spec[1].cast<tuple>()[1].cast<pybind11::int_>();
308
            if (!len(name) && format.kind() == 'V')
309
                continue;
310
            field_descriptors.push_back({(PYBIND11_STR_TYPE) name, format.strip_padding(format.itemsize()), offset});
311
312
313
314
        }

        std::sort(field_descriptors.begin(), field_descriptors.end(),
                  [](const field_descr& a, const field_descr& b) {
315
                      return a.offset.cast<int>() < b.offset.cast<int>();
316
317
318
319
                  });

        list names, formats, offsets;
        for (auto& descr : field_descriptors) {
320
321
322
            names.append(descr.name);
            formats.append(descr.format);
            offsets.append(descr.offset);
323
        }
324
        return dtype(names, formats, offsets, itemsize);
325
326
    }
};
327

328
329
class array : public buffer {
public:
330
    PYBIND11_OBJECT_CVT(array, buffer, detail::npy_api::get().PyArray_Check_, raw_array)
331
332
333
334
335
336
337

    enum {
        c_style = detail::npy_api::NPY_C_CONTIGUOUS_,
        f_style = detail::npy_api::NPY_F_CONTIGUOUS_,
        forcecast = detail::npy_api::NPY_ARRAY_FORCECAST_
    };

338
339
    array() : array(0, static_cast<const double *>(nullptr)) {}

340
341
342
    array(const pybind11::dtype &dt, const std::vector<size_t> &shape,
          const std::vector<size_t> &strides, const void *ptr = nullptr,
          handle base = handle()) {
343
        auto& api = detail::npy_api::get();
344
345
346
347
        auto ndim = shape.size();
        if (shape.size() != strides.size())
            pybind11_fail("NumPy: shape ndim doesn't match strides ndim");
        auto descr = dt;
348
349
350

        int flags = 0;
        if (base && ptr) {
351
            if (isinstance<array>(base))
Wenzel Jakob's avatar
Wenzel Jakob committed
352
                /* Copy flags from base (except ownership bit) */
353
                flags = reinterpret_borrow<array>(base).flags() & ~detail::npy_api::NPY_ARRAY_OWNDATA_;
354
355
356
357
358
            else
                /* Writable by default, easy to downgrade later on if needed */
                flags = detail::npy_api::NPY_ARRAY_WRITEABLE_;
        }

359
        auto tmp = reinterpret_steal<object>(api.PyArray_NewFromDescr_(
360
361
362
363
            api.PyArray_Type_, descr.release().ptr(), (int) ndim,
            reinterpret_cast<Py_intptr_t *>(const_cast<size_t*>(shape.data())),
            reinterpret_cast<Py_intptr_t *>(const_cast<size_t*>(strides.data())),
            const_cast<void *>(ptr), flags, nullptr));
364
365
        if (!tmp)
            pybind11_fail("NumPy: unable to create array!");
366
367
        if (ptr) {
            if (base) {
368
                detail::array_proxy(tmp.ptr())->base = base.inc_ref().ptr();
369
            } else {
370
                tmp = reinterpret_steal<object>(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */));
371
372
            }
        }
373
374
375
        m_ptr = tmp.release().ptr();
    }

376
377
378
    array(const pybind11::dtype &dt, const std::vector<size_t> &shape,
          const void *ptr = nullptr, handle base = handle())
        : array(dt, shape, default_strides(shape, dt.itemsize()), ptr, base) { }
379

380
381
382
    array(const pybind11::dtype &dt, size_t count, const void *ptr = nullptr,
          handle base = handle())
        : array(dt, std::vector<size_t>{ count }, ptr, base) { }
383
384

    template<typename T> array(const std::vector<size_t>& shape,
385
386
                               const std::vector<size_t>& strides,
                               const T* ptr, handle base = handle())
387
    : array(pybind11::dtype::of<T>(), shape, strides, (const void *) ptr, base) { }
388

389
390
391
392
    template <typename T>
    array(const std::vector<size_t> &shape, const T *ptr,
          handle base = handle())
        : array(shape, default_strides(shape, sizeof(T)), ptr, base) { }
393

394
395
396
    template <typename T>
    array(size_t count, const T *ptr, handle base = handle())
        : array(std::vector<size_t>{ count }, ptr, base) { }
397

398
    explicit array(const buffer_info &info)
399
    : array(pybind11::dtype(info), info.shape, info.strides, info.ptr) { }
400

401
402
    /// Array descriptor (dtype)
    pybind11::dtype dtype() const {
403
        return reinterpret_borrow<pybind11::dtype>(detail::array_proxy(m_ptr)->descr);
404
405
406
407
408
409
410
411
412
    }

    /// Total number of elements
    size_t size() const {
        return std::accumulate(shape(), shape() + ndim(), (size_t) 1, std::multiplies<size_t>());
    }

    /// Byte size of a single element
    size_t itemsize() const {
413
        return (size_t) detail::array_descriptor_proxy(detail::array_proxy(m_ptr)->descr)->elsize;
414
415
416
417
418
419
420
421
422
    }

    /// Total number of bytes
    size_t nbytes() const {
        return size() * itemsize();
    }

    /// Number of dimensions
    size_t ndim() const {
423
        return (size_t) detail::array_proxy(m_ptr)->nd;
424
425
    }

426
427
    /// Base object
    object base() const {
428
        return reinterpret_borrow<object>(detail::array_proxy(m_ptr)->base);
429
430
    }

431
432
    /// Dimensions of the array
    const size_t* shape() const {
433
        return reinterpret_cast<const size_t *>(detail::array_proxy(m_ptr)->dimensions);
434
435
436
437
438
    }

    /// Dimension along a given axis
    size_t shape(size_t dim) const {
        if (dim >= ndim())
439
            fail_dim_check(dim, "invalid axis");
440
441
442
443
444
        return shape()[dim];
    }

    /// Strides of the array
    const size_t* strides() const {
445
        return reinterpret_cast<const size_t *>(detail::array_proxy(m_ptr)->strides);
446
447
448
449
450
    }

    /// Stride along a given axis
    size_t strides(size_t dim) const {
        if (dim >= ndim())
451
            fail_dim_check(dim, "invalid axis");
452
453
454
        return strides()[dim];
    }

455
456
    /// Return the NumPy array flags
    int flags() const {
457
        return detail::array_proxy(m_ptr)->flags;
458
459
    }

460
461
    /// If set, the array is writeable (otherwise the buffer is read-only)
    bool writeable() const {
462
        return detail::check_flags(m_ptr, detail::npy_api::NPY_ARRAY_WRITEABLE_);
463
464
465
466
    }

    /// If set, the array owns the data (will be freed when the array is deleted)
    bool owndata() const {
467
        return detail::check_flags(m_ptr, detail::npy_api::NPY_ARRAY_OWNDATA_);
468
469
    }

470
471
    /// Pointer to the contained data. If index is not provided, points to the
    /// beginning of the buffer. May throw if the index would lead to out of bounds access.
472
    template<typename... Ix> const void* data(Ix... index) const {
473
        return static_cast<const void *>(detail::array_proxy(m_ptr)->data + offset_at(index...));
474
475
    }

476
477
478
    /// Mutable pointer to the contained data. If index is not provided, points to the
    /// beginning of the buffer. May throw if the index would lead to out of bounds access.
    /// May throw if the array is not writeable.
479
    template<typename... Ix> void* mutable_data(Ix... index) {
480
        check_writeable();
481
        return static_cast<void *>(detail::array_proxy(m_ptr)->data + offset_at(index...));
482
483
484
485
    }

    /// Byte offset from beginning of the array to a given index (full or partial).
    /// May throw if the index would lead to out of bounds access.
486
    template<typename... Ix> size_t offset_at(Ix... index) const {
487
488
        if (sizeof...(index) > ndim())
            fail_dim_check(sizeof...(index), "too many indices for an array");
489
        return byte_offset(size_t(index)...);
490
491
492
493
494
495
    }

    size_t offset_at() const { return 0; }

    /// Item count from beginning of the array to a given index (full or partial).
    /// May throw if the index would lead to out of bounds access.
496
    template<typename... Ix> size_t index_at(Ix... index) const {
497
        return offset_at(index...) / itemsize();
498
499
    }

500
501
502
    /// Return a new view with all of the dimensions of length 1 removed
    array squeeze() {
        auto& api = detail::npy_api::get();
503
        return reinterpret_steal<array>(api.PyArray_Squeeze_(m_ptr));
504
505
    }

506
    /// Ensure that the argument is a NumPy array
507
508
509
510
511
512
    /// In case of an error, nullptr is returned and the Python error is cleared.
    static array ensure(handle h, int ExtraFlags = 0) {
        auto result = reinterpret_steal<array>(raw_array(h.ptr(), ExtraFlags));
        if (!result)
            PyErr_Clear();
        return result;
513
514
    }

515
protected:
516
517
518
519
520
521
522
    template<typename, typename> friend struct detail::npy_format_descriptor;

    void fail_dim_check(size_t dim, const std::string& msg) const {
        throw index_error(msg + ": " + std::to_string(dim) +
                          " (ndim = " + std::to_string(ndim()) + ")");
    }

523
524
525
526
527
528
529
    template<typename... Ix> size_t byte_offset(Ix... index) const {
        check_dimensions(index...);
        return byte_offset_unsafe(index...);
    }

    template<size_t dim = 0, typename... Ix> size_t byte_offset_unsafe(size_t i, Ix... index) const {
        return i * strides()[dim] + byte_offset_unsafe<dim + 1>(index...);
530
531
    }

532
    template<size_t dim = 0> size_t byte_offset_unsafe() const { return 0; }
533
534
535
536
537

    void check_writeable() const {
        if (!writeable())
            throw std::runtime_error("array is not writeable");
    }
538
539
540
541
542
543
544
545
546
547
548
549

    static std::vector<size_t> default_strides(const std::vector<size_t>& shape, size_t itemsize) {
        auto ndim = shape.size();
        std::vector<size_t> strides(ndim);
        if (ndim) {
            std::fill(strides.begin(), strides.end(), itemsize);
            for (size_t i = 0; i < ndim - 1; i++)
                for (size_t j = 0; j < ndim - 1 - i; j++)
                    strides[j] *= shape[ndim - 1 - i];
        }
        return strides;
    }
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564

    template<typename... Ix> void check_dimensions(Ix... index) const {
        check_dimensions_impl(size_t(0), shape(), size_t(index)...);
    }

    void check_dimensions_impl(size_t, const size_t*) const { }

    template<typename... Ix> void check_dimensions_impl(size_t axis, const size_t* shape, size_t i, Ix... index) const {
        if (i >= *shape) {
            throw index_error(std::string("index ") + std::to_string(i) +
                              " is out of bounds for axis " + std::to_string(axis) +
                              " with size " + std::to_string(*shape));
        }
        check_dimensions_impl(axis + 1, shape + 1, index...);
    }
565
566
567
568
569
570
571
572

    /// Create array from any object -- always returns a new reference
    static PyObject *raw_array(PyObject *ptr, int ExtraFlags = 0) {
        if (ptr == nullptr)
            return nullptr;
        return detail::npy_api::get().PyArray_FromAny_(
            ptr, nullptr, 0, 0, detail::npy_api::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr);
    }
Wenzel Jakob's avatar
Wenzel Jakob committed
573
574
};

575
template <typename T, int ExtraFlags = array::forcecast> class array_t : public array {
Wenzel Jakob's avatar
Wenzel Jakob committed
576
public:
577
578
579
    array_t() : array(0, static_cast<const T *>(nullptr)) {}
    array_t(handle h, borrowed_t) : array(h, borrowed) { }
    array_t(handle h, stolen_t) : array(h, stolen) { }
580

581
582
583
584
585
    PYBIND11_DEPRECATED("Use array_t<T>::ensure() instead")
    array_t(handle h, bool is_borrowed) : array(raw_array_t(h.ptr()), stolen) {
        if (!m_ptr) PyErr_Clear();
        if (!is_borrowed) Py_XDECREF(h.ptr());
    }
586

587
588
589
    array_t(const object &o) : array(raw_array_t(o.ptr()), stolen) {
        if (!m_ptr) throw error_already_set();
    }
590

591
    explicit array_t(const buffer_info& info) : array(info) { }
592

593
594
595
596
    array_t(const std::vector<size_t> &shape,
            const std::vector<size_t> &strides, const T *ptr = nullptr,
            handle base = handle())
        : array(shape, strides, ptr, base) { }
597

598
    explicit array_t(const std::vector<size_t> &shape, const T *ptr = nullptr,
599
600
            handle base = handle())
        : array(shape, ptr, base) { }
601

602
    explicit array_t(size_t count, const T *ptr = nullptr, handle base = handle())
603
        : array(count, ptr, base) { }
604

605
606
    constexpr size_t itemsize() const {
        return sizeof(T);
607
608
    }

609
    template<typename... Ix> size_t index_at(Ix... index) const {
610
611
612
        return offset_at(index...) / itemsize();
    }

613
    template<typename... Ix> const T* data(Ix... index) const {
614
615
616
        return static_cast<const T*>(array::data(index...));
    }

617
    template<typename... Ix> T* mutable_data(Ix... index) {
618
619
620
621
        return static_cast<T*>(array::mutable_data(index...));
    }

    // Reference to element at a given index
622
    template<typename... Ix> const T& at(Ix... index) const {
623
624
        if (sizeof...(index) != ndim())
            fail_dim_check(sizeof...(index), "index dimension mismatch");
625
        return *(static_cast<const T*>(array::data()) + byte_offset(size_t(index)...) / itemsize());
626
627
628
    }

    // Mutable reference to element at a given index
629
    template<typename... Ix> T& mutable_at(Ix... index) {
630
631
        if (sizeof...(index) != ndim())
            fail_dim_check(sizeof...(index), "index dimension mismatch");
632
        return *(static_cast<T*>(array::mutable_data()) + byte_offset(size_t(index)...) / itemsize());
633
    }
634

635
636
637
638
    /// Ensure that the argument is a NumPy array of the correct dtype.
    /// In case of an error, nullptr is returned and the Python error is cleared.
    static array_t ensure(handle h) {
        auto result = reinterpret_steal<array_t>(raw_array_t(h.ptr()));
639
640
        if (!result)
            PyErr_Clear();
641
        return result;
Wenzel Jakob's avatar
Wenzel Jakob committed
642
    }
643

Wenzel Jakob's avatar
Wenzel Jakob committed
644
    static bool check_(handle h) {
645
646
        const auto &api = detail::npy_api::get();
        return api.PyArray_Check_(h.ptr())
647
               && api.PyArray_EquivTypes_(detail::array_proxy(h.ptr())->descr, dtype::of<T>().ptr());
648
649
650
651
652
653
654
655
656
657
658
    }

protected:
    /// Create array from any object -- always returns a new reference
    static PyObject *raw_array_t(PyObject *ptr) {
        if (ptr == nullptr)
            return nullptr;
        return detail::npy_api::get().PyArray_FromAny_(
            ptr, dtype::of<T>().release().ptr(), 0, 0,
            detail::npy_api::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr);
    }
Wenzel Jakob's avatar
Wenzel Jakob committed
659
660
};

661
template <typename T>
662
struct format_descriptor<T, detail::enable_if_t<detail::is_pod_struct<T>::value>> {
663
664
665
    static std::string format() {
        return detail::npy_format_descriptor<typename std::remove_cv<T>::type>::format();
    }
666
667
668
};

/// Fixed-length byte strings (`char[N]`, `std::array<char, N>`) use the buffer
/// format "<N>s".
template <size_t N> struct format_descriptor<char[N]> {
    static std::string format() { return std::to_string(N) + "s"; }
};
template <size_t N> struct format_descriptor<std::array<char, N>> {
    static std::string format() { return std::to_string(N) + "s"; }
};

675
676
677
678
679
680
681
682
template <typename T>
struct format_descriptor<T, detail::enable_if_t<std::is_enum<T>::value>> {
    static std::string format() {
        return format_descriptor<
            typename std::remove_cv<typename std::underlying_type<T>::type>::type>::format();
    }
};

683
NAMESPACE_BEGIN(detail)
684
685
686
687
688
template <typename T, int ExtraFlags>
struct pyobject_caster<array_t<T, ExtraFlags>> {
    using type = array_t<T, ExtraFlags>;

    bool load(handle src, bool /* convert */) {
689
        value = type::ensure(src);
690
691
692
693
694
695
696
697
698
        return static_cast<bool>(value);
    }

    static handle cast(const handle &src, return_value_policy /* policy */, handle /* parent */) {
        return src.inc_ref();
    }
    PYBIND11_TYPE_CASTER(type, handle_type_name<type>::name());
};

699
template <typename T> struct npy_format_descriptor<T, enable_if_t<satisfies_any_of<T, std::is_arithmetic, is_complex>::value>> {
700
private:
701
702
703
704
705
706
707
708
709
    // NB: the order here must match the one in common.h
    constexpr static const int values[15] = {
        npy_api::NPY_BOOL_,
        npy_api::NPY_BYTE_,   npy_api::NPY_UBYTE_,   npy_api::NPY_SHORT_,    npy_api::NPY_USHORT_,
        npy_api::NPY_INT_,    npy_api::NPY_UINT_,    npy_api::NPY_LONGLONG_, npy_api::NPY_ULONGLONG_,
        npy_api::NPY_FLOAT_,  npy_api::NPY_DOUBLE_,  npy_api::NPY_LONGDOUBLE_,
        npy_api::NPY_CFLOAT_, npy_api::NPY_CDOUBLE_, npy_api::NPY_CLONGDOUBLE_
    };

710
public:
711
712
    static constexpr int value = values[detail::is_fmt_numeric<T>::index];

713
    static pybind11::dtype dtype() {
714
        if (auto ptr = npy_api::get().PyArray_DescrFromType_(value))
715
            return reinterpret_borrow<pybind11::dtype>(ptr);
716
        pybind11_fail("Unsupported buffer format!");
717
    }
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
    template <typename T2 = T, enable_if_t<std::is_integral<T2>::value, int> = 0>
    static PYBIND11_DESCR name() {
        return _<std::is_same<T, bool>::value>(_("bool"),
            _<std::is_signed<T>::value>("int", "uint") + _<sizeof(T)*8>());
    }
    template <typename T2 = T, enable_if_t<std::is_floating_point<T2>::value, int> = 0>
    static PYBIND11_DESCR name() {
        return _<std::is_same<T, float>::value || std::is_same<T, double>::value>(
                _("float") + _<sizeof(T)*8>(), _("longdouble"));
    }
    template <typename T2 = T, enable_if_t<is_complex<T2>::value, int> = 0>
    static PYBIND11_DESCR name() {
        return _<std::is_same<typename T2::value_type, float>::value || std::is_same<typename T2::value_type, double>::value>(
                _("complex") + _<sizeof(T2::value_type)*16>(), _("longcomplex"));
    }
733
};
734
735

#define PYBIND11_DECL_CHAR_FMT \
736
    static PYBIND11_DESCR name() { return _("S") + _<N>(); } \
737
    static pybind11::dtype dtype() { return pybind11::dtype(std::string("S") + std::to_string(N)); }
738
739
740
template <size_t N> struct npy_format_descriptor<char[N]> { PYBIND11_DECL_CHAR_FMT };
template <size_t N> struct npy_format_descriptor<std::array<char, N>> { PYBIND11_DECL_CHAR_FMT };
#undef PYBIND11_DECL_CHAR_FMT
741

742
743
744
745
746
747
748
749
template<typename T> struct npy_format_descriptor<T, enable_if_t<std::is_enum<T>::value>> {
private:
    using base_descr = npy_format_descriptor<typename std::underlying_type<T>::type>;
public:
    static PYBIND11_DESCR name() { return base_descr::name(); }
    static pybind11::dtype dtype() { return base_descr::dtype(); }
};

750
751
struct field_descriptor {
    const char *name;
752
    size_t offset;
753
    size_t size;
754
    size_t alignment;
755
    std::string format;
756
    dtype descr;
757
758
};

759
760
761
/// Registers a structured NumPy dtype for the C++ type identified by `tinfo`.
///
/// The dtype is built from the per-field name/offset/dtype in `fields` plus
/// the explicit `itemsize`. A matching buffer-protocol format string
/// ("T{...}") is generated by hand and round-tripped through NumPy as a
/// sanity check. Both the dtype and the format string are then stored in the
/// numpy internals, and `direct_converter` is appended to the type's
/// direct-conversion list.
///
/// Calls pybind11_fail if the type is already registered, if a field has no
/// dtype, or if NumPy does not parse the generated format string into an
/// equivalent dtype.
inline PYBIND11_NOINLINE void register_structured_dtype(
    const std::initializer_list<field_descriptor>& fields,
    const std::type_info& tinfo, size_t itemsize,
    bool (*direct_converter)(PyObject *, void *&)) {

    auto& numpy_internals = get_numpy_internals();
    if (numpy_internals.get_type_info(tinfo, false))
        pybind11_fail("NumPy: dtype is already registered");

    list names, formats, offsets;
    // Iterate by reference: each field_descriptor owns a std::string and a
    // dtype handle, so a by-value loop would allocate and touch refcounts.
    for (const auto& field : fields) {
        if (!field.descr)
            pybind11_fail(std::string("NumPy: unsupported field dtype: `") +
                            field.name + "` @ " + tinfo.name());
        names.append(PYBIND11_STR_TYPE(field.name));
        formats.append(field.descr);
        offsets.append(pybind11::int_(field.offset));
    }
    // Keep the new dtype owned by a handle until registration succeeds, so a
    // failure in the sanity check below does not leak the reference.
    auto dtype_obj = pybind11::dtype(names, formats, offsets, itemsize);

    // There is an existing bug in NumPy (as of v1.11): trailing bytes are
    // not encoded explicitly into the format string. This will supposedly
    // get fixed in v1.12; for further details, see these:
    // - https://github.com/numpy/numpy/issues/7797
    // - https://github.com/numpy/numpy/pull/7798
    // Because of this, we won't use numpy's logic to generate buffer format
    // strings and will just do it ourselves.
    std::vector<field_descriptor> ordered_fields(fields);
    std::sort(ordered_fields.begin(), ordered_fields.end(),
        [](const field_descriptor &a, const field_descriptor &b) { return a.offset < b.offset; });
    size_t offset = 0;
    std::ostringstream oss;
    oss << "T{";
    for (const auto& field : ordered_fields) {
        // explicit padding bytes between the previous field and this one
        if (field.offset > offset)
            oss << (field.offset - offset) << 'x';
        // mark unaligned fields with '^' (unaligned native type)
        if (field.offset % field.alignment)
            oss << '^';
        oss << field.format << ':' << field.name << ':';
        offset = field.offset + field.size;
    }
    // trailing padding up to the full item size
    if (itemsize > offset)
        oss << (itemsize - offset) << 'x';
    oss << '}';
    auto format_str = oss.str();

    // Sanity check: verify that NumPy properly parses our buffer format string
    auto& api = npy_api::get();
    auto arr = array(buffer_info(nullptr, itemsize, format_str, 1));
    if (!api.PyArray_EquivTypes_(dtype_obj.ptr(), arr.dtype().ptr()))
        pybind11_fail("NumPy: invalid buffer descriptor!");

    // Registration succeeded: hand the reference over to the registry.
    auto tindex = std::type_index(tinfo);
    numpy_internals.registered_dtypes[tindex] = { dtype_obj.release().ptr(), format_str };
    get_internals().direct_conversions[tindex].push_back(direct_converter);
}

817
818
819
template <typename T, typename SFINAE> struct npy_format_descriptor {
    static_assert(is_pod_struct<T>::value, "Attempt to use a non-POD or unimplemented POD type as a numpy dtype");

820
    static PYBIND11_DESCR name() { return _("struct"); }
821

822
    static pybind11::dtype dtype() {
823
        return reinterpret_borrow<pybind11::dtype>(dtype_ptr());
824
825
    }

826
    static std::string format() {
827
        static auto format_str = get_numpy_internals().get_type_info<T>(true)->format_str;
828
        return format_str;
829
830
    }

831
832
833
    static void register_dtype(const std::initializer_list<field_descriptor>& fields) {
        register_structured_dtype(fields, typeid(typename std::remove_cv<T>::type),
                                  sizeof(T), &direct_converter);
834
835
836
    }

private:
837
838
839
840
    static PyObject* dtype_ptr() {
        static PyObject* ptr = get_numpy_internals().get_type_info<T>(true)->dtype_ptr;
        return ptr;
    }
841

842
843
844
    static bool direct_converter(PyObject *obj, void*& value) {
        auto& api = npy_api::get();
        if (!PyObject_TypeCheck(obj, api.PyVoidArrType_Type_))
845
            return false;
846
        if (auto descr = reinterpret_steal<object>(api.PyArray_DescrFromScalar_(obj))) {
847
            if (api.PyArray_EquivTypes_(dtype_ptr(), descr.ptr())) {
848
849
850
851
852
853
                value = ((PyVoidScalarObject_Proxy *) obj)->obval;
                return true;
            }
        }
        return false;
    }
854
855
};

856
857
858
// Builds a field_descriptor for data member `Field` of struct `T`, exposed to
// NumPy under the C-string `Name`. Offset/size/alignment come from
// offsetof/sizeof/alignof; the format string and dtype come from the member
// type's format_descriptor and npy_format_descriptor.
#define PYBIND11_FIELD_DESCRIPTOR_EX(T, Field, Name)                                          \
    ::pybind11::detail::field_descriptor {                                                    \
        Name,                                                                                 \
        offsetof(T, Field),                                                                   \
        sizeof(decltype(std::declval<T>().Field)),                                            \
        alignof(decltype(std::declval<T>().Field)),                                           \
        ::pybind11::format_descriptor<decltype(std::declval<T>().Field)>::format(),           \
        ::pybind11::detail::npy_format_descriptor<decltype(std::declval<T>().Field)>::dtype() \
    }

// Same as PYBIND11_FIELD_DESCRIPTOR_EX, naming the field after the member itself.
#define PYBIND11_FIELD_DESCRIPTOR(T, Field) PYBIND11_FIELD_DESCRIPTOR_EX(T, Field, #Field)

867
868
// The main idea of this macro is borrowed from https://github.com/swansontec/map-macro
// (C) William Swanson, Paul Fultz
869
870
871
872
873
874
875
876
877
878
879
880
881
// PYBIND11_EVAL forces repeated rescanning of its argument: each EVALk level
// applies the level below three times, so PYBIND11_EVAL performs on the order
// of 3^5 rescans. Each rescan lets the deferred PYBIND11_MAP_LIST0/1 pair
// below advance by one element, so the number of mappable arguments is
// presumably bounded by this depth.
#define PYBIND11_EVAL0(...) __VA_ARGS__
#define PYBIND11_EVAL1(...) PYBIND11_EVAL0 (PYBIND11_EVAL0 (PYBIND11_EVAL0 (__VA_ARGS__)))
#define PYBIND11_EVAL2(...) PYBIND11_EVAL1 (PYBIND11_EVAL1 (PYBIND11_EVAL1 (__VA_ARGS__)))
#define PYBIND11_EVAL3(...) PYBIND11_EVAL2 (PYBIND11_EVAL2 (PYBIND11_EVAL2 (__VA_ARGS__)))
#define PYBIND11_EVAL4(...) PYBIND11_EVAL3 (PYBIND11_EVAL3 (PYBIND11_EVAL3 (__VA_ARGS__)))
#define PYBIND11_EVAL(...)  PYBIND11_EVAL4 (PYBIND11_EVAL4 (PYBIND11_EVAL4 (__VA_ARGS__)))
// Terminator: swallows the remaining argument list.
#define PYBIND11_MAP_END(...)
// Empty; placed after the selected continuation so it is not expanded in the
// same scan (expansion is deferred to the next EVAL rescan).
#define PYBIND11_MAP_OUT
#define PYBIND11_MAP_COMMA ,
// Applied to the look-ahead element: when that element is the `()` sentinel
// this expands to two arguments (`0, PYBIND11_MAP_END`), shifting
// PYBIND11_MAP_END into the `next` slot of PYBIND11_MAP_NEXT0; for any other
// element it does not expand and the caller's `next` is selected.
#define PYBIND11_MAP_GET_END() 0, PYBIND11_MAP_END
#define PYBIND11_MAP_NEXT0(test, next, ...) next PYBIND11_MAP_OUT
#define PYBIND11_MAP_NEXT1(test, next) PYBIND11_MAP_NEXT0 (test, next, 0)
#define PYBIND11_MAP_NEXT(test, next)  PYBIND11_MAP_NEXT1 (PYBIND11_MAP_GET_END test, next)
882
// List variant: selects either the terminator or `, next` (comma-prefixed).
#ifdef _MSC_VER // MSVC is not as eager to expand macros, hence this workaround
883
884
#define PYBIND11_MAP_LIST_NEXT1(test, next) \
    PYBIND11_EVAL0 (PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0))
885
#else
886
887
#define PYBIND11_MAP_LIST_NEXT1(test, next) \
    PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0)
888
#endif
889
890
891
892
893
894
#define PYBIND11_MAP_LIST_NEXT(test, next) \
    PYBIND11_MAP_LIST_NEXT1 (PYBIND11_MAP_GET_END test, next)
// Two mutually-referencing steps: each emits f(t, x) and then, unless `peek`
// is the `()` sentinel, continues with the *other* step on the remaining
// arguments (alternating names sidesteps the preprocessor's ban on a macro
// re-expanding itself).
#define PYBIND11_MAP_LIST0(f, t, x, peek, ...) \
    f(t, x) PYBIND11_MAP_LIST_NEXT (peek, PYBIND11_MAP_LIST1) (f, t, peek, __VA_ARGS__)
#define PYBIND11_MAP_LIST1(f, t, x, peek, ...) \
    f(t, x) PYBIND11_MAP_LIST_NEXT (peek, PYBIND11_MAP_LIST0) (f, t, peek, __VA_ARGS__)
895
// PYBIND11_MAP_LIST(f, t, a1, a2, ...) expands to f(t, a1), f(t, a2), ...
// The appended `(), 0` provides the end sentinel plus one dummy argument so
// the final step still has a non-empty `...`.
896
897
#define PYBIND11_MAP_LIST(f, t, ...) \
    PYBIND11_EVAL (PYBIND11_MAP_LIST1 (f, t, __VA_ARGS__, (), 0))
898

899
// PYBIND11_NUMPY_DTYPE(Type, member1, member2, ...) registers `Type` as a
// structured NumPy dtype with one field per listed data member, each field
// named after the member (via PYBIND11_FIELD_DESCRIPTOR).
#define PYBIND11_NUMPY_DTYPE(Type, ...)                                      \
    ::pybind11::detail::npy_format_descriptor<Type>::register_dtype(        \
        {PYBIND11_MAP_LIST (PYBIND11_FIELD_DESCRIPTOR, Type, __VA_ARGS__)})
902

903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
// Pairwise counterpart of PYBIND11_MAP_LIST: identical sentinel/deferral
// machinery, but each step consumes two arguments. On MSVC the selection is
// wrapped in an extra PYBIND11_EVAL0 (see the PYBIND11_MAP_LIST_NEXT1
// workaround above).
#ifdef _MSC_VER
#define PYBIND11_MAP2_LIST_NEXT1(test, next) \
    PYBIND11_EVAL0 (PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0))
#else
#define PYBIND11_MAP2_LIST_NEXT1(test, next) \
    PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0)
#endif
#define PYBIND11_MAP2_LIST_NEXT(test, next) \
    PYBIND11_MAP2_LIST_NEXT1 (PYBIND11_MAP_GET_END test, next)
// Each step emits f(t, x1, x2) and continues with the other step unless the
// next argument (`peek`) is the `()` sentinel.
#define PYBIND11_MAP2_LIST0(f, t, x1, x2, peek, ...) \
    f(t, x1, x2) PYBIND11_MAP2_LIST_NEXT (peek, PYBIND11_MAP2_LIST1) (f, t, peek, __VA_ARGS__)
#define PYBIND11_MAP2_LIST1(f, t, x1, x2, peek, ...) \
    f(t, x1, x2) PYBIND11_MAP2_LIST_NEXT (peek, PYBIND11_MAP2_LIST0) (f, t, peek, __VA_ARGS__)
// PYBIND11_MAP2_LIST(f, t, a1, a2, ...) expands to f(t, a1, a2), f(t, a3, a4), ...
// Arguments are consumed two at a time, so an even number is expected.
#define PYBIND11_MAP2_LIST(f, t, ...) \
    PYBIND11_EVAL (PYBIND11_MAP2_LIST1 (f, t, __VA_ARGS__, (), 0))

// PYBIND11_NUMPY_DTYPE_EX(Type, member1, "name1", member2, "name2", ...)
// registers `Type` as a structured dtype with explicitly named fields.
#define PYBIND11_NUMPY_DTYPE_EX(Type, ...) \
    ::pybind11::detail::npy_format_descriptor<Type>::register_dtype \
        ({PYBIND11_MAP2_LIST (PYBIND11_FIELD_DESCRIPTOR_EX, Type, __VA_ARGS__)})

924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
/// Iterator over a flat buffer of `T`: simply a `T*`.
template <class T>
using array_iterator = typename std::add_pointer<T>::type;

/// Iterator to the first element of `buffer`, viewed as a flat array of `T`.
template <class T>
array_iterator<T> array_begin(const buffer_info& buffer) {
    T* const first = reinterpret_cast<T*>(buffer.ptr);
    return array_iterator<T>(first);
}

/// Iterator one past the last of the `buffer.size` elements of `buffer`.
template <class T>
array_iterator<T> array_end(const buffer_info& buffer) {
    return array_begin<T>(buffer) + buffer.size;
}

/// Strided cursor over one operand of a broadcast operation.
///
/// Built from a base pointer, per-dimension byte strides and the iteration
/// shape. For every dimension `d` it precomputes the pointer delta to apply
/// when the index in `d` advances by one while all faster-varying dimensions
/// (d+1 .. ndim-1) wrap back to zero, so `increment(d)` is a single addition.
class common_iterator {
public:
    using container_type = std::vector<size_t>;
    using value_type = container_type::value_type;
    using size_type = container_type::size_type;

    common_iterator() : p_ptr(nullptr), m_strides() {}

    common_iterator(void* ptr, const container_type& strides, const std::vector<size_t>& shape)
        : p_ptr(reinterpret_cast<char*>(ptr)), m_strides(strides.size()) {
        // A 0-d operand has nothing to step through (and back() on an empty
        // vector would be undefined).
        if (m_strides.empty())
            return;
        // Deltas are kept signed: broadcast (zero) or negative strides yield
        // negative deltas, and adding a wrapped-around huge unsigned value to
        // a pointer is undefined behaviour.
        m_strides.back() = static_cast<difference_type>(strides.back());
        for (size_type i = m_strides.size() - 1; i != 0; --i) {
            size_type j = i - 1;
            auto extent = static_cast<difference_type>(shape[i]);
            // step in dim j, minus the distance walked through dim i's range
            m_strides[j] = static_cast<difference_type>(strides[j]) + m_strides[i]
                         - static_cast<difference_type>(strides[i]) * extent;
        }
    }

    /// Advances the cursor by one step in dimension `dim`.
    void increment(size_type dim) {
        p_ptr += m_strides[dim];
    }

    void* data() const {
        return p_ptr;
    }

private:
    using difference_type = container_type::difference_type;

    char* p_ptr;
    std::vector<difference_type> m_strides;
};

968
template <size_t N> class multi_array_iterator {
969
970
971
public:
    using container_type = std::vector<size_t>;

972
973
974
975
976
    multi_array_iterator(const std::array<buffer_info, N> &buffers,
                         const std::vector<size_t> &shape)
        : m_shape(shape.size()), m_index(shape.size(), 0),
          m_common_iterator() {

977
        // Manual copy to avoid conversion warning if using std::copy
978
        for (size_t i = 0; i < shape.size(); ++i)
979
980
981
            m_shape[i] = static_cast<container_type::value_type>(shape[i]);

        container_type strides(shape.size());
982
        for (size_t i = 0; i < N; ++i)
983
984
985
986
987
988
989
990
991
            init_common_iterator(buffers[i], shape, m_common_iterator[i], strides);
    }

    multi_array_iterator& operator++() {
        for (size_t j = m_index.size(); j != 0; --j) {
            size_t i = j - 1;
            if (++m_index[i] != m_shape[i]) {
                increment_common_iterator(i);
                break;
992
            } else {
993
994
995
996
997
998
                m_index[i] = 0;
            }
        }
        return *this;
    }

999
    template <size_t K, class T> const T& data() const {
1000
1001
1002
1003
1004
1005
1006
        return *reinterpret_cast<T*>(m_common_iterator[K].data());
    }

private:

    using common_iter = common_iterator;

1007
1008
1009
    void init_common_iterator(const buffer_info &buffer,
                              const std::vector<size_t> &shape,
                              common_iter &iterator, container_type &strides) {
1010
1011
1012
1013
1014
1015
1016
        auto buffer_shape_iter = buffer.shape.rbegin();
        auto buffer_strides_iter = buffer.strides.rbegin();
        auto shape_iter = shape.rbegin();
        auto strides_iter = strides.rbegin();

        while (buffer_shape_iter != buffer.shape.rend()) {
            if (*shape_iter == *buffer_shape_iter)
1017
                *strides_iter = static_cast<size_t>(*buffer_strides_iter);
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
            else
                *strides_iter = 0;

            ++buffer_shape_iter;
            ++buffer_strides_iter;
            ++shape_iter;
            ++strides_iter;
        }

        std::fill(strides_iter, strides.rend(), 0);
        iterator = common_iter(buffer.ptr, strides, shape);
    }

    void increment_common_iterator(size_t dim) {
1032
        for (auto &iter : m_common_iterator)
1033
1034
1035
1036
1037
1038
1039
1040
1041
            iter.increment(dim);
    }

    container_type m_shape;
    container_type m_index;
    std::array<common_iter, N> m_common_iterator;
};

template <size_t N>
1042
1043
bool broadcast(const std::array<buffer_info, N>& buffers, size_t& ndim, std::vector<size_t>& shape) {
    ndim = std::accumulate(buffers.begin(), buffers.end(), size_t(0), [](size_t res, const buffer_info& buf) {
1044
1045
1046
        return std::max(res, buf.ndim);
    });

1047
    shape = std::vector<size_t>(ndim, 1);
1048
1049
1050
1051
    bool trivial_broadcast = true;
    for (size_t i = 0; i < N; ++i) {
        auto res_iter = shape.rbegin();
        bool i_trivial_broadcast = (buffers[i].size == 1) || (buffers[i].ndim == ndim);
1052
1053
1054
1055
        for (auto shape_iter = buffers[i].shape.rbegin();
             shape_iter != buffers[i].shape.rend(); ++shape_iter, ++res_iter) {

            if (*res_iter == 1)
1056
                *res_iter = *shape_iter;
1057
            else if ((*shape_iter != 1) && (*res_iter != *shape_iter))
1058
                pybind11_fail("pybind11::vectorize: incompatible size/dimension of inputs!");
1059

1060
1061
1062
1063
1064
1065
1066
            i_trivial_broadcast = i_trivial_broadcast && (*res_iter == *shape_iter);
        }
        trivial_broadcast = trivial_broadcast && i_trivial_broadcast;
    }
    return trivial_broadcast;
}

1067
1068
1069
1070
template <typename Func, typename Return, typename... Args>
struct vectorize_helper {
    typename std::remove_reference<Func>::type f;

1071
    template <typename T>
1072
    explicit vectorize_helper(T&&f) : f(std::forward<T>(f)) { }
Wenzel Jakob's avatar
Wenzel Jakob committed
1073

1074
    object operator()(array_t<Args, array::c_style | array::forcecast>... args) {
1075
        return run(args..., make_index_sequence<sizeof...(Args)>());
1076
    }
Wenzel Jakob's avatar
Wenzel Jakob committed
1077

1078
    template <size_t ... Index> object run(array_t<Args, array::c_style | array::forcecast>&... args, index_sequence<Index...> index) {
Wenzel Jakob's avatar
Wenzel Jakob committed
1079
        /* Request buffers from all parameters */
1080
        const size_t N = sizeof...(Args);
1081

Wenzel Jakob's avatar
Wenzel Jakob committed
1082
1083
1084
        std::array<buffer_info, N> buffers {{ args.request()... }};

        /* Determine dimensions parameters of output array */
1085
        size_t ndim = 0;
1086
1087
        std::vector<size_t> shape(0);
        bool trivial_broadcast = broadcast(buffers, ndim, shape);
1088

1089
        size_t size = 1;
Wenzel Jakob's avatar
Wenzel Jakob committed
1090
1091
        std::vector<size_t> strides(ndim);
        if (ndim > 0) {
1092
            strides[ndim-1] = sizeof(Return);
1093
            for (size_t i = ndim - 1; i > 0; --i) {
1094
1095
1096
1097
                strides[i - 1] = strides[i] * shape[i];
                size *= shape[i];
            }
            size *= shape[0];
Wenzel Jakob's avatar
Wenzel Jakob committed
1098
1099
        }

1100
        if (size == 1)
1101
            return cast(f(*((Args *) buffers[Index].ptr)...));
Wenzel Jakob's avatar
Wenzel Jakob committed
1102

1103
1104
1105
        array_t<Return> result(shape, strides);
        auto buf = result.request();
        auto output = (Return *) buf.ptr;
1106

1107
        if (trivial_broadcast) {
1108
            /* Call the function */
1109
            for (size_t i = 0; i < size; ++i) {
1110
                output[i] = f((buffers[Index].size == 1
1111
1112
                               ? *((Args *) buffers[Index].ptr)
                               : ((Args *) buffers[Index].ptr)[i])...);
1113
            }
1114
        } else {
1115
1116
            apply_broadcast<N, Index...>(buffers, buf, index);
        }
1117
1118

        return result;
1119
    }
1120
1121

    template <size_t N, size_t... Index>
1122
1123
    void apply_broadcast(const std::array<buffer_info, N> &buffers,
                         buffer_info &output, index_sequence<Index...>) {
1124
1125
1126
1127
1128
1129
        using input_iterator = multi_array_iterator<N>;
        using output_iterator = array_iterator<Return>;

        input_iterator input_iter(buffers, output.shape);
        output_iterator output_end = array_end<Return>(output);

1130
1131
        for (output_iterator iter = array_begin<Return>(output);
             iter != output_end; ++iter, ++input_iter) {
1132
1133
1134
            *iter = f((input_iter.template data<Index, Args>())...);
        }
    }
1135
1136
};

1137
/// Signature name of array_t<T, Flags>: "numpy.ndarray[<name of T>]".
template <typename T, int Flags>
struct handle_type_name<array_t<T, Flags>> {
    static PYBIND11_DESCR name() {
        return _("numpy.ndarray[") + make_caster<T>::name() + _("]");
    }
};

1141
NAMESPACE_END(detail)
Wenzel Jakob's avatar
Wenzel Jakob committed
1142

1143
1144
1145
/// Wraps callable `f` into a vectorized callable. The unnamed function-pointer
/// argument is never used; it only carries the `Return(Args...)` signature
/// for deduction.
template <typename Func, typename Return, typename... Args /*,*/ PYBIND11_NOEXCEPT_TPL_ARG>
detail::vectorize_helper<Func, Return, Args...>
vectorize(const Func &f, Return (*) (Args ...) PYBIND11_NOEXCEPT_SPECIFIER) {
    using helper_type = detail::vectorize_helper<Func, Return, Args...>;
    return helper_type(f);
}

1149
1150
1151
/// Vectorizes a plain function pointer.
///
/// The helper is constructed directly with the exact pointer type, including
/// a C++17 `noexcept` qualifier when present. Delegating through
/// vectorize<Return (*)(Args...)> would produce a helper for the
/// non-noexcept pointer type, which is not convertible to the declared return
/// type because the helper's constructor is explicit.
template <typename Return, typename... Args /*,*/ PYBIND11_NOEXCEPT_TPL_ARG>
detail::vectorize_helper<Return (*) (Args ...) PYBIND11_NOEXCEPT_SPECIFIER, Return, Args...>
vectorize(Return (*f) (Args ...) PYBIND11_NOEXCEPT_SPECIFIER) {
    return detail::vectorize_helper<Return (*) (Args ...) PYBIND11_NOEXCEPT_SPECIFIER, Return, Args...>(f);
}

1155
1156
1157
1158
1159
/// Vectorizes a callable object (e.g. a lambda), deducing `Return` and
/// `Args...` from the signature of its `operator()`.
template <typename Func>
auto vectorize(Func &&f) -> decltype(
        vectorize(std::forward<Func>(f), (typename detail::remove_class<decltype(&std::remove_reference<Func>::type::operator())>::type *) nullptr)) {
    using call_signature = typename detail::remove_class<decltype(
        &std::remove_reference<Func>::type::operator())>::type;
    return vectorize(std::forward<Func>(f), static_cast<call_signature *>(nullptr));
}

1162
NAMESPACE_END(pybind11)
Wenzel Jakob's avatar
Wenzel Jakob committed
1163
1164
1165
1166

#if defined(_MSC_VER)
#pragma warning(pop)
#endif