numpy.h 64.2 KB
Newer Older
Wenzel Jakob's avatar
Wenzel Jakob committed
1
/*
    pybind11/numpy.h: Basic NumPy support, vectorize() wrapper

    Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>

    All rights reserved. Use of this source code is governed by a
    BSD-style license that can be found in the LICENSE file.
*/

#pragma once

12
13
#include "pybind11.h"
#include "complex.h"
14
15
#include <numeric>
#include <algorithm>
16
#include <array>
17
#include <cstdlib>
18
#include <cstring>
19
#include <sstream>
20
#include <string>
21
#include <initializer_list>
22
#include <functional>
23
#include <utility>
24
#include <typeindex>
25

Wenzel Jakob's avatar
Wenzel Jakob committed
26
#if defined(_MSC_VER)
27
28
#  pragma warning(push)
#  pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
Wenzel Jakob's avatar
Wenzel Jakob committed
29
30
#endif

31
/* This will be true on all flat address space platforms and allows us to reduce the
32
   whole npy_intp / ssize_t / Py_intptr_t business down to just ssize_t for all size
33
34
   and dimension types (e.g. shape, strides, indexing), instead of inflicting this
   upon the library user. */
35
static_assert(sizeof(ssize_t) == sizeof(Py_intptr_t), "ssize_t != Py_intptr_t");
36

37
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
38
39
40

class array; // Forward declaration

41
NAMESPACE_BEGIN(detail)
42
template <typename type, typename SFINAE = void> struct npy_format_descriptor;
Wenzel Jakob's avatar
Wenzel Jakob committed
43

44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
/// Stand-in for NumPy's dtype descriptor object, read directly through a pointer cast
/// (see array_descriptor_proxy()) so that no NumPy C headers are needed at compile time.
/// Because dtype objects are reinterpret_cast to this type, field order and types must
/// match NumPy's C struct exactly (presumably PyArray_Descr -- confirm against the
/// supported NumPy versions before changing anything here).
struct PyArrayDescr_Proxy {
    PyObject_HEAD
    PyObject *typeobj;
    char kind;        // single-character type code, returned by dtype::kind()
    char type;
    char byteorder;
    char flags;
    int type_num;
    int elsize;       // element size in bytes, returned by dtype::itemsize()
    int alignment;
    char *subarray;
    PyObject *fields;
    PyObject *names;  // non-null for structured dtypes (checked by dtype::has_fields())
};

/// Stand-in for a NumPy ndarray object, read directly through a pointer cast (see
/// array_proxy()).  Field order and types must match NumPy's array object layout exactly
/// (presumably PyArrayObject_fields -- confirm before changing).
struct PyArray_Proxy {
    PyObject_HEAD
    char *data;            // start of the element buffer (used by array::data())
    int nd;                // number of dimensions (array::ndim())
    ssize_t *dimensions;   // extents, `nd` entries (array::shape())
    ssize_t *strides;      // byte strides, `nd` entries (array::strides())
    PyObject *base;        // base object, may be null (array::base())
    PyObject *descr;       // the array's dtype object (array::dtype())
    int flags;             // NPY_ARRAY_* flag bits (tested by check_flags())
};

70
71
72
73
74
75
76
77
/// Stand-in for a NumPy void (structured) scalar object, meant to be accessed through a
/// pointer cast like the other *_Proxy structs.  Presumably mirrors NumPy's
/// PyVoidScalarObject; it is not used in the visible part of this file, so confirm the
/// layout against NumPy before relying on or changing it.
struct PyVoidScalarObject_Proxy {
    PyObject_VAR_HEAD
    char *obval;
    PyArrayDescr_Proxy *descr;
    int flags;
    PyObject *base;
};

78
79
80
81
82
83
84
85
/// Entry of numpy_internals::registered_dtypes: what is known about a C++ type that has
/// been registered with NumPy.
struct numpy_type_info {
    PyObject* dtype_ptr;     // dtype object for the type (ownership not visible here -- confirm)
    std::string format_str;  // presumably the buffer-protocol format string -- confirm
};

struct numpy_internals {
    std::unordered_map<std::type_index, numpy_type_info> registered_dtypes;

86
87
    numpy_type_info *get_type_info(const std::type_info& tinfo, bool throw_if_missing = true) {
        auto it = registered_dtypes.find(std::type_index(tinfo));
88
89
90
        if (it != registered_dtypes.end())
            return &(it->second);
        if (throw_if_missing)
91
            pybind11_fail(std::string("NumPy type info missing for ") + tinfo.name());
92
93
        return nullptr;
    }
94
95
96
97

    template<typename T> numpy_type_info *get_type_info(bool throw_if_missing = true) {
        return get_type_info(typeid(typename std::remove_cv<T>::type), throw_if_missing);
    }
98
99
};

Ivan Smirnov's avatar
Ivan Smirnov committed
100
101
/// Fetches (creating it on first use) the numpy_internals object kept in pybind11's shared
/// data under the key "_numpy_internals" and stores its address in `ptr`.
/// Marked PYBIND11_NOINLINE so this slow path stays out of callers.
inline PYBIND11_NOINLINE void load_numpy_internals(numpy_internals* &ptr) {
    numpy_internals &shared = get_or_create_shared_data<numpy_internals>("_numpy_internals");
    ptr = &shared;
}

/// Returns the shared numpy_internals.  The first call resolves it via
/// load_numpy_internals(); the pointer is then cached in a function-local static.
inline numpy_internals& get_numpy_internals() {
    static numpy_internals* cached = nullptr;
    if (cached == nullptr)
        load_numpy_internals(cached);
    return *cached;
}

111
112
struct npy_api {
    enum constants {
113
114
        NPY_ARRAY_C_CONTIGUOUS_ = 0x0001,
        NPY_ARRAY_F_CONTIGUOUS_ = 0x0002,
115
        NPY_ARRAY_OWNDATA_ = 0x0004,
116
        NPY_ARRAY_FORCECAST_ = 0x0010,
117
        NPY_ARRAY_ENSUREARRAY_ = 0x0040,
118
119
        NPY_ARRAY_ALIGNED_ = 0x0100,
        NPY_ARRAY_WRITEABLE_ = 0x0400,
120
121
122
123
124
125
126
127
128
129
130
131
        NPY_BOOL_ = 0,
        NPY_BYTE_, NPY_UBYTE_,
        NPY_SHORT_, NPY_USHORT_,
        NPY_INT_, NPY_UINT_,
        NPY_LONG_, NPY_ULONG_,
        NPY_LONGLONG_, NPY_ULONGLONG_,
        NPY_FLOAT_, NPY_DOUBLE_, NPY_LONGDOUBLE_,
        NPY_CFLOAT_, NPY_CDOUBLE_, NPY_CLONGDOUBLE_,
        NPY_OBJECT_ = 17,
        NPY_STRING_, NPY_UNICODE_, NPY_VOID_
    };

uentity's avatar
uentity committed
132
133
134
135
136
    typedef struct {
        Py_intptr_t *ptr;
        int len;
    } PyArray_Dims;

137
138
139
140
141
    static npy_api& get() {
        static npy_api api = lookup();
        return api;
    }

142
143
144
145
146
147
    bool PyArray_Check_(PyObject *obj) const {
        return (bool) PyObject_TypeCheck(obj, PyArray_Type_);
    }
    bool PyArrayDescr_Check_(PyObject *obj) const {
        return (bool) PyObject_TypeCheck(obj, PyArrayDescr_Type_);
    }
148

149
    unsigned int (*PyArray_GetNDArrayCFeatureVersion_)();
150
151
152
153
154
    PyObject *(*PyArray_DescrFromType_)(int);
    PyObject *(*PyArray_NewFromDescr_)
        (PyTypeObject *, PyObject *, int, Py_intptr_t *,
         Py_intptr_t *, void *, int, PyObject *);
    PyObject *(*PyArray_DescrNewFromType_)(int);
155
    int (*PyArray_CopyInto_)(PyObject *, PyObject *);
156
157
    PyObject *(*PyArray_NewCopy_)(PyObject *, int);
    PyTypeObject *PyArray_Type_;
158
    PyTypeObject *PyVoidArrType_Type_;
159
    PyTypeObject *PyArrayDescr_Type_;
160
    PyObject *(*PyArray_DescrFromScalar_)(PyObject *);
161
162
163
164
165
    PyObject *(*PyArray_FromAny_) (PyObject *, PyObject *, int, int, int, PyObject *);
    int (*PyArray_DescrConverter_) (PyObject *, PyObject **);
    bool (*PyArray_EquivTypes_) (PyObject *, PyObject *);
    int (*PyArray_GetArrayParamsFromObject_)(PyObject *, PyObject *, char, PyObject **, int *,
                                             Py_ssize_t *, PyObject **, PyObject *);
166
    PyObject *(*PyArray_Squeeze_)(PyObject *);
Jason Rhinelander's avatar
Jason Rhinelander committed
167
    int (*PyArray_SetBaseObject_)(PyObject *, PyObject *);
uentity's avatar
uentity committed
168
    PyObject* (*PyArray_Resize_)(PyObject*, PyArray_Dims*, int, int);
169
170
private:
    enum functions {
171
        API_PyArray_GetNDArrayCFeatureVersion = 211,
172
        API_PyArray_Type = 2,
173
        API_PyArrayDescr_Type = 3,
174
        API_PyVoidArrType_Type = 39,
175
        API_PyArray_DescrFromType = 45,
176
        API_PyArray_DescrFromScalar = 57,
177
        API_PyArray_FromAny = 69,
uentity's avatar
uentity committed
178
        API_PyArray_Resize = 80,
179
        API_PyArray_CopyInto = 82,
180
181
182
183
184
185
        API_PyArray_NewCopy = 85,
        API_PyArray_NewFromDescr = 94,
        API_PyArray_DescrNewFromType = 9,
        API_PyArray_DescrConverter = 174,
        API_PyArray_EquivTypes = 182,
        API_PyArray_GetArrayParamsFromObject = 278,
Jason Rhinelander's avatar
Jason Rhinelander committed
186
187
        API_PyArray_Squeeze = 136,
        API_PyArray_SetBaseObject = 282
188
189
190
191
    };

    static npy_api lookup() {
        module m = module::import("numpy.core.multiarray");
192
        auto c = m.attr("_ARRAY_API");
193
#if PY_MAJOR_VERSION >= 3
194
        void **api_ptr = (void **) PyCapsule_GetPointer(c.ptr(), NULL);
195
#else
196
        void **api_ptr = (void **) PyCObject_AsVoidPtr(c.ptr());
197
#endif
198
        npy_api api;
199
#define DECL_NPY_API(Func) api.Func##_ = (decltype(api.Func##_)) api_ptr[API_##Func];
200
201
202
        DECL_NPY_API(PyArray_GetNDArrayCFeatureVersion);
        if (api.PyArray_GetNDArrayCFeatureVersion_() < 0x7)
            pybind11_fail("pybind11 numpy support requires numpy >= 1.7.0");
203
        DECL_NPY_API(PyArray_Type);
204
        DECL_NPY_API(PyVoidArrType_Type);
205
        DECL_NPY_API(PyArrayDescr_Type);
206
        DECL_NPY_API(PyArray_DescrFromType);
207
        DECL_NPY_API(PyArray_DescrFromScalar);
208
        DECL_NPY_API(PyArray_FromAny);
uentity's avatar
uentity committed
209
        DECL_NPY_API(PyArray_Resize);
210
        DECL_NPY_API(PyArray_CopyInto);
211
212
213
214
215
216
        DECL_NPY_API(PyArray_NewCopy);
        DECL_NPY_API(PyArray_NewFromDescr);
        DECL_NPY_API(PyArray_DescrNewFromType);
        DECL_NPY_API(PyArray_DescrConverter);
        DECL_NPY_API(PyArray_EquivTypes);
        DECL_NPY_API(PyArray_GetArrayParamsFromObject);
217
        DECL_NPY_API(PyArray_Squeeze);
Jason Rhinelander's avatar
Jason Rhinelander committed
218
        DECL_NPY_API(PyArray_SetBaseObject);
219
#undef DECL_NPY_API
220
221
222
        return api;
    }
};
Wenzel Jakob's avatar
Wenzel Jakob committed
223

224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
/// Views a NumPy array object through the PyArray_Proxy layout.
inline PyArray_Proxy* array_proxy(void* ptr) {
    return static_cast<PyArray_Proxy*>(ptr);
}

/// Read-only variant of array_proxy().
inline const PyArray_Proxy* array_proxy(const void* ptr) {
    return static_cast<const PyArray_Proxy*>(ptr);
}

/// Views a dtype object through the PyArrayDescr_Proxy layout.
inline PyArrayDescr_Proxy* array_descriptor_proxy(PyObject* ptr) {
    return reinterpret_cast<PyArrayDescr_Proxy*>(ptr);
}

/// Read-only variant of array_descriptor_proxy().
inline const PyArrayDescr_Proxy* array_descriptor_proxy(const PyObject* ptr) {
    return reinterpret_cast<const PyArrayDescr_Proxy*>(ptr);
}

/// True when every bit of `flag` is set in the array's flags word.
inline bool check_flags(const void* ptr, int flag) {
    const int present = array_proxy(ptr)->flags & flag;
    return present == flag;
}

244
245
246
247
248
template <typename T> struct is_std_array : std::false_type { };
template <typename T, size_t N> struct is_std_array<std::array<T, N>> : std::true_type { };
template <typename T> struct is_complex : std::false_type { };
template <typename T> struct is_complex<std::complex<T>> : std::true_type { };

249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
/// Fallback traits for a type that is not treated as an array: it is its own underlying
/// type, has no extents, and appends nothing to a shape.
template <typename T> struct array_info_scalar {
    typedef T type;
    static constexpr bool is_array = false;
    static constexpr bool is_empty = false;
    static PYBIND11_DESCR extents() { return _(""); }
    static void append_extents(list& /* shape */) { }
};
// Computes underlying type and a comma-separated list of extents for array
// types (any mix of std::array and built-in arrays). An array of char is
// treated as scalar because it gets special handling.
template <typename T> struct array_info : array_info_scalar<T> { };
// std::array<T, N>: N is the outermost extent; T's own extents (if any) follow it.
template <typename T, size_t N> struct array_info<std::array<T, N>> {
    using type = typename array_info<T>::type;
    static constexpr bool is_array = true;
    // True if this extent or any nested extent is zero.
    static constexpr bool is_empty = (N == 0) || array_info<T>::is_empty;
    static constexpr size_t extent = N;

    // appends the extents to shape (outermost first)
    static void append_extents(list& shape) {
        shape.append(N);
        array_info<T>::append_extents(shape);
    }

    // Innermost array level: the descriptor is just N.
    template<typename T2 = T, enable_if_t<!array_info<T2>::is_array, int> = 0>
    static PYBIND11_DESCR extents() {
        return _<N>();
    }

    // Nested array level: N followed by the element type's extents.
    template<typename T2 = T, enable_if_t<array_info<T2>::is_array, int> = 0>
    static PYBIND11_DESCR extents() {
        return concat(_<N>(), array_info<T>::extents());
    }
};
// For numpy we have special handling for arrays of characters, so we don't include
// the size in the array extents.
template <size_t N> struct array_info<char[N]> : array_info_scalar<char[N]> { };
template <size_t N> struct array_info<std::array<char, N>> : array_info_scalar<std::array<char, N>> { };
// Built-in arrays T[N] are handled exactly like std::array<T, N>.
template <typename T, size_t N> struct array_info<T[N]> : array_info<std::array<T, N>> { };
// Element type left after stripping every (non-char) array extent from T.
template <typename T> using remove_all_extents_t = typename array_info<T>::type;

289
/// Trait: true when T is a plain data struct whose objects can be accessed directly in
/// memory, i.e. standard layout, trivially copyable, and not a reference, array,
/// std::array, arithmetic, std::complex or enum type.
template <typename T> using is_pod_struct = all_of<
    // Direct memory access needs a standard-layout type.
    std::is_standard_layout<T>,
#if !defined(__GNUG__) || defined(_LIBCPP_VERSION) || defined(_GLIBCXX_USE_CXX11_ABI)
    // _GLIBCXX_USE_CXX11_ABI identifies libstdc++ from GCC 5 or newer regardless of the
    // compiler in use (Clang may use libstdc++ but always reports __GNUC__ == 4).
    std::is_trivially_copyable<T>,
#else
    // Older libstdc++ (GCC 4) lacks is_trivially_copyable; approximate it.
    std::is_trivially_destructible<T>,
    satisfies_any_of<T, std::has_trivial_copy_constructor, std::has_trivial_copy_assign>,
#endif
    // Scalars, arrays, references, complex numbers and enums are excluded.
    satisfies_none_of<T, std::is_reference, std::is_array, is_std_array, std::is_arithmetic, is_complex, std::is_enum>
>;

303
304
305
306
// Byte offset of an element: the sum of index[k] * strides[Dim + k] over the given
// indices.  No bounds or dimensionality checks are performed.
// Base case: no indices remain, so they contribute nothing.
template <ssize_t Dim = 0, typename Strides> ssize_t byte_offset_unsafe(const Strides &) { return 0; }
// Recursive case: `i` indexes dimension `Dim`; the remaining indices start at `Dim + 1`.
template <ssize_t Dim = 0, typename Strides, typename... Ix>
ssize_t byte_offset_unsafe(const Strides &strides, ssize_t i, Ix... index) {
    const ssize_t remainder = byte_offset_unsafe<Dim + 1>(strides, index...);
    return remainder + i * strides[Dim];
}

309
310
/**
 * Proxy class providing unsafe, unchecked const access to array data.  This is constructed through
311
312
 * the `unchecked<T, N>()` method of `array` or the `unchecked<N>()` method of `array_t<T>`.  `Dims`
 * will be -1 for dimensions determined at runtime.
313
 */
314
template <typename T, ssize_t Dims>
315
316
class unchecked_reference {
protected:
317
    static constexpr bool Dynamic = Dims < 0;
318
319
    const unsigned char *data_;
    // Storing the shape & strides in local variables (i.e. these arrays) allows the compiler to
320
    // make large performance gains on big, nested loops, but requires compile-time dimensions
321
322
323
    conditional_t<Dynamic, const ssize_t *, std::array<ssize_t, (size_t) Dims>>
            shape_, strides_;
    const ssize_t dims_;
324
325

    friend class pybind11::array;
326
327
    // Constructor for compile-time dimensions:
    template <bool Dyn = Dynamic>
328
    unchecked_reference(const void *data, const ssize_t *shape, const ssize_t *strides, enable_if_t<!Dyn, ssize_t>)
329
    : data_{reinterpret_cast<const unsigned char *>(data)}, dims_{Dims} {
330
        for (size_t i = 0; i < (size_t) dims_; i++) {
331
332
333
334
            shape_[i] = shape[i];
            strides_[i] = strides[i];
        }
    }
335
336
    // Constructor for runtime dimensions:
    template <bool Dyn = Dynamic>
337
    unchecked_reference(const void *data, const ssize_t *shape, const ssize_t *strides, enable_if_t<Dyn, ssize_t> dims)
338
    : data_{reinterpret_cast<const unsigned char *>(data)}, shape_{shape}, strides_{strides}, dims_{dims} {}
339
340

public:
341
342
    /**
     * Unchecked const reference access to data at the given indices.  For a compile-time known
343
344
     * number of dimensions, this requires the correct number of arguments; for run-time
     * dimensionality, this is not checked (and so is up to the caller to use safely).
345
     */
346
    template <typename... Ix> const T &operator()(Ix... index) const {
Jason Rhinelander's avatar
Jason Rhinelander committed
347
        static_assert(ssize_t{sizeof...(Ix)} == Dims || Dynamic,
348
                "Invalid number of indices for unchecked array reference");
349
        return *reinterpret_cast<const T *>(data_ + byte_offset_unsafe(strides_, ssize_t(index)...));
350
    }
351
352
    /**
     * Unchecked const reference access to data; this operator only participates if the reference
353
354
     * is to a 1-dimensional array.  When present, this is exactly equivalent to `obj(index)`.
     */
355
356
    template <ssize_t D = Dims, typename = enable_if_t<D == 1 || Dynamic>>
    const T &operator[](ssize_t index) const { return operator()(index); }
357

358
    /// Pointer access to the data at the given indices.
359
    template <typename... Ix> const T *data(Ix... ix) const { return &operator()(ssize_t(ix)...); }
360
361

    /// Returns the item size, i.e. sizeof(T)
362
    constexpr static ssize_t itemsize() { return sizeof(T); }
363

364
    /// Returns the shape (i.e. size) of dimension `dim`
365
    ssize_t shape(ssize_t dim) const { return shape_[(size_t) dim]; }
366
367

    /// Returns the number of dimensions of the array
368
    ssize_t ndim() const { return dims_; }
369
370
371

    /// Returns the total number of elements in the referenced array, i.e. the product of the shapes
    template <bool Dyn = Dynamic>
372
373
    enable_if_t<!Dyn, ssize_t> size() const {
        return std::accumulate(shape_.begin(), shape_.end(), (ssize_t) 1, std::multiplies<ssize_t>());
374
375
    }
    template <bool Dyn = Dynamic>
376
377
    enable_if_t<Dyn, ssize_t> size() const {
        return std::accumulate(shape_, shape_ + ndim(), (ssize_t) 1, std::multiplies<ssize_t>());
378
379
380
381
    }

    /// Returns the total number of bytes used by the referenced data.  Note that the actual span in
    /// memory may be larger if the referenced array has non-contiguous strides (e.g. for a slice).
382
    ssize_t nbytes() const {
383
384
        return size() * itemsize();
    }
385
386
};

387
template <typename T, ssize_t Dims>
388
389
390
391
class unchecked_mutable_reference : public unchecked_reference<T, Dims> {
    friend class pybind11::array;
    using ConstBase = unchecked_reference<T, Dims>;
    using ConstBase::ConstBase;
392
    using ConstBase::Dynamic;
393
394
395
public:
    /// Mutable, unchecked access to data at the given indices.
    template <typename... Ix> T& operator()(Ix... index) {
Jason Rhinelander's avatar
Jason Rhinelander committed
396
        static_assert(ssize_t{sizeof...(Ix)} == Dims || Dynamic,
397
                "Invalid number of indices for unchecked array reference");
398
399
        return const_cast<T &>(ConstBase::operator()(index...));
    }
400
401
    /**
     * Mutable, unchecked access data at the given index; this operator only participates if the
402
403
     * reference is to a 1-dimensional array (or has runtime dimensions).  When present, this is
     * exactly equivalent to `obj(index)`.
404
     */
405
406
    template <ssize_t D = Dims, typename = enable_if_t<D == 1 || Dynamic>>
    T &operator[](ssize_t index) { return operator()(index); }
407
408

    /// Mutable pointer access to the data at the given indices.
409
    template <typename... Ix> T *mutable_data(Ix... ix) { return &operator()(ssize_t(ix)...); }
410
411
};

412
template <typename T, ssize_t Dim>
413
struct type_caster<unchecked_reference<T, Dim>> {
414
    static_assert(Dim == 0 && Dim > 0 /* always fail */, "unchecked array proxy object is not castable");
415
};
416
template <typename T, ssize_t Dim>
417
418
struct type_caster<unchecked_mutable_reference<T, Dim>> : type_caster<unchecked_reference<T, Dim>> {};

419
NAMESPACE_END(detail)
420

421
class dtype : public object {
422
public:
423
    PYBIND11_OBJECT_DEFAULT(dtype, object, detail::npy_api::get().PyArrayDescr_Check_);
Wenzel Jakob's avatar
Wenzel Jakob committed
424

425
    explicit dtype(const buffer_info &info) {
426
        dtype descr(_dtype_from_pep3118()(PYBIND11_STR_TYPE(info.format)));
427
428
        // If info.itemsize == 0, use the value calculated from the format string
        m_ptr = descr.strip_padding(info.itemsize ? info.itemsize : descr.itemsize()).release().ptr();
429
    }
430

431
    explicit dtype(const std::string &format) {
432
        m_ptr = from_args(pybind11::str(format)).release().ptr();
Wenzel Jakob's avatar
Wenzel Jakob committed
433
434
    }

435
    dtype(const char *format) : dtype(std::string(format)) { }
436

437
    dtype(list names, list formats, list offsets, ssize_t itemsize) {
438
439
440
441
        dict args;
        args["names"] = names;
        args["formats"] = formats;
        args["offsets"] = offsets;
442
        args["itemsize"] = pybind11::int_(itemsize);
443
444
445
        m_ptr = from_args(args).release().ptr();
    }

Ivan Smirnov's avatar
Ivan Smirnov committed
446
    /// This is essentially the same as calling numpy.dtype(args) in Python.
447
448
449
    static dtype from_args(object args) {
        PyObject *ptr = nullptr;
        if (!detail::npy_api::get().PyArray_DescrConverter_(args.release().ptr(), &ptr) || !ptr)
450
            throw error_already_set();
451
        return reinterpret_steal<dtype>(ptr);
452
    }
453

Ivan Smirnov's avatar
Ivan Smirnov committed
454
    /// Return dtype associated with a C++ type.
455
    template <typename T> static dtype of() {
456
        return detail::npy_format_descriptor<typename std::remove_cv<T>::type>::dtype();
457
    }
458

Ivan Smirnov's avatar
Ivan Smirnov committed
459
    /// Size of the data type in bytes.
460
461
    ssize_t itemsize() const {
        return detail::array_descriptor_proxy(m_ptr)->elsize;
Wenzel Jakob's avatar
Wenzel Jakob committed
462
463
    }

Ivan Smirnov's avatar
Ivan Smirnov committed
464
    /// Returns true for structured data types.
465
    bool has_fields() const {
466
        return detail::array_descriptor_proxy(m_ptr)->names != nullptr;
467
468
    }

Ivan Smirnov's avatar
Ivan Smirnov committed
469
    /// Single-character type code.
470
    char kind() const {
471
        return detail::array_descriptor_proxy(m_ptr)->kind;
472
473
474
    }

private:
475
476
477
    static object _dtype_from_pep3118() {
        static PyObject *obj = module::import("numpy.core._internal")
            .attr("_dtype_from_pep3118").cast<object>().release().ptr();
478
        return reinterpret_borrow<object>(obj);
479
    }
480

481
    dtype strip_padding(ssize_t itemsize) {
482
483
        // Recursively strip all void fields with empty names that are generated for
        // padding fields (as of NumPy v1.11).
484
        if (!has_fields())
485
            return *this;
486

487
        struct field_descr { PYBIND11_STR_TYPE name; object format; pybind11::int_ offset; };
488
489
        std::vector<field_descr> field_descriptors;

490
        for (auto field : attr("fields").attr("items")()) {
491
            auto spec = field.cast<tuple>();
492
            auto name = spec[0].cast<pybind11::str>();
493
            auto format = spec[1].cast<tuple>()[0].cast<dtype>();
494
            auto offset = spec[1].cast<tuple>()[1].cast<pybind11::int_>();
495
            if (!len(name) && format.kind() == 'V')
496
                continue;
497
            field_descriptors.push_back({(PYBIND11_STR_TYPE) name, format.strip_padding(format.itemsize()), offset});
498
499
500
501
        }

        std::sort(field_descriptors.begin(), field_descriptors.end(),
                  [](const field_descr& a, const field_descr& b) {
502
                      return a.offset.cast<int>() < b.offset.cast<int>();
503
504
505
506
                  });

        list names, formats, offsets;
        for (auto& descr : field_descriptors) {
507
508
509
            names.append(descr.name);
            formats.append(descr.format);
            offsets.append(descr.offset);
510
        }
511
        return dtype(names, formats, offsets, itemsize);
512
513
    }
};
514

515
516
class array : public buffer {
public:
517
    PYBIND11_OBJECT_CVT(array, buffer, detail::npy_api::get().PyArray_Check_, raw_array)

    // Requirement flags with the same bit values as NumPy's NPY_ARRAY_C_CONTIGUOUS,
    // NPY_ARRAY_F_CONTIGUOUS and NPY_ARRAY_FORCECAST (presumably for use as ExtraFlags,
    // e.g. with ensure() -- confirm).
    enum {
        c_style = detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_,
        f_style = detail::npy_api::NPY_ARRAY_F_CONTIGUOUS_,
        forcecast = detail::npy_api::NPY_ARRAY_FORCECAST_
    };

525
    /// Default: an empty one-dimensional array of doubles (shape {0}).
    array() : array({{0}}, static_cast<const double *>(nullptr)) { }

    // Shapes and strides can be passed as arbitrary container types (see detail::any_container).
    using ShapeContainer = detail::any_container<ssize_t>;
    using StridesContainer = detail::any_container<ssize_t>;
529
530
531
532
533
534

    /// Constructs an array of dtype `dt`, taking shape/strides from arbitrary container types.
    ///
    /// - Empty `strides`: C-contiguous strides are computed from `shape` and the itemsize.
    /// - `ptr == nullptr`: NumPy allocates fresh storage.
    /// - `ptr` and `base`: the array views `ptr` and keeps `base` alive as its base object.
    ///   Flags are copied from `base` when it is an array (minus OWNDATA), else WRITEABLE.
    /// - `ptr` without `base`: the data at `ptr` is copied into a new array.
    ///
    /// Throws error_already_set if NumPy fails to create or copy the array or to attach
    /// `base`.
    array(const pybind11::dtype &dt, ShapeContainer shape, StridesContainer strides,
          const void *ptr = nullptr, handle base = handle()) {

        if (strides->empty())
            *strides = c_strides(*shape, dt.itemsize());

        auto ndim = shape->size();
        if (ndim != strides->size())
            pybind11_fail("NumPy: shape ndim doesn't match strides ndim");
        auto descr = dt;

        int flags = 0;
        if (base && ptr) {
            if (isinstance<array>(base))
                /* Copy flags from base (except ownership bit) */
                flags = reinterpret_borrow<array>(base).flags() & ~detail::npy_api::NPY_ARRAY_OWNDATA_;
            else
                /* Writable by default, easy to downgrade later on if needed */
                flags = detail::npy_api::NPY_ARRAY_WRITEABLE_;
        }

        auto &api = detail::npy_api::get();
        // NumPy takes Py_intptr_t pointers.  ssize_t has the same size (static_assert near
        // the top of this file) but may be a distinct type, so cast explicitly.
        auto tmp = reinterpret_steal<object>(api.PyArray_NewFromDescr_(
            api.PyArray_Type_, descr.release().ptr(), (int) ndim,
            reinterpret_cast<Py_intptr_t *>(shape->data()),
            reinterpret_cast<Py_intptr_t *>(strides->data()),
            const_cast<void *>(ptr), flags, nullptr));
        if (!tmp)
            throw error_already_set();
        if (ptr) {
            if (base) {
                // The new reference to `base` is handed to (stolen by) PyArray_SetBaseObject.
                if (api.PyArray_SetBaseObject_(tmp.ptr(), base.inc_ref().ptr()) < 0)
                    throw error_already_set();
            } else {
                tmp = reinterpret_steal<object>(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */));
                if (!tmp)
                    throw error_already_set();
            }
        }
        m_ptr = tmp.release().ptr();
    }

568
569
    array(const pybind11::dtype &dt, ShapeContainer shape, const void *ptr = nullptr, handle base = handle())
        : array(dt, std::move(shape), {}, ptr, base) { }
570

571
572
573
    template <typename T, typename = detail::enable_if_t<std::is_integral<T>::value && !std::is_same<bool, T>::value>>
    array(const pybind11::dtype &dt, T count, const void *ptr = nullptr, handle base = handle())
        : array(dt, {{count}}, ptr, base) { }
574

575
576
577
    template <typename T>
    array(ShapeContainer shape, StridesContainer strides, const T *ptr, handle base = handle())
        : array(pybind11::dtype::of<T>(), std::move(shape), std::move(strides), ptr, base) { }
578

579
    template <typename T>
580
581
    array(ShapeContainer shape, const T *ptr, handle base = handle())
        : array(std::move(shape), {}, ptr, base) { }
582

583
    template <typename T>
584
    explicit array(ssize_t count, const T *ptr, handle base = handle()) : array({count}, {}, ptr, base) { }
585

586
    explicit array(const buffer_info &info)
587
    : array(pybind11::dtype(info), info.shape, info.strides, info.ptr) { }
588

589
590
    /// Array descriptor (dtype)
    pybind11::dtype dtype() const {
591
        return reinterpret_borrow<pybind11::dtype>(detail::array_proxy(m_ptr)->descr);
592
593
594
    }

    /// Total number of elements
595
596
    ssize_t size() const {
        return std::accumulate(shape(), shape() + ndim(), (ssize_t) 1, std::multiplies<ssize_t>());
597
598
599
    }

    /// Byte size of a single element
600
601
    ssize_t itemsize() const {
        return detail::array_descriptor_proxy(detail::array_proxy(m_ptr)->descr)->elsize;
602
603
604
    }

    /// Total number of bytes
605
    ssize_t nbytes() const {
606
607
608
609
        return size() * itemsize();
    }

    /// Number of dimensions
610
611
    ssize_t ndim() const {
        return detail::array_proxy(m_ptr)->nd;
612
613
    }

614
615
    /// Base object
    object base() const {
616
        return reinterpret_borrow<object>(detail::array_proxy(m_ptr)->base);
617
618
    }

619
    /// Dimensions of the array
620
621
    const ssize_t* shape() const {
        return detail::array_proxy(m_ptr)->dimensions;
622
623
624
    }

    /// Dimension along a given axis
625
    ssize_t shape(ssize_t dim) const {
626
        if (dim >= ndim())
627
            fail_dim_check(dim, "invalid axis");
628
629
630
631
        return shape()[dim];
    }

    /// Strides of the array
632
    const ssize_t* strides() const {
633
        return detail::array_proxy(m_ptr)->strides;
634
635
636
    }

    /// Stride along a given axis
637
    ssize_t strides(ssize_t dim) const {
638
        if (dim >= ndim())
639
            fail_dim_check(dim, "invalid axis");
640
641
642
        return strides()[dim];
    }

643
644
    /// Return the NumPy array flags
    int flags() const {
645
        return detail::array_proxy(m_ptr)->flags;
646
647
    }

648
649
    /// If set, the array is writeable (otherwise the buffer is read-only)
    bool writeable() const {
650
        return detail::check_flags(m_ptr, detail::npy_api::NPY_ARRAY_WRITEABLE_);
651
652
653
654
    }

    /// If set, the array owns the data (will be freed when the array is deleted)
    bool owndata() const {
655
        return detail::check_flags(m_ptr, detail::npy_api::NPY_ARRAY_OWNDATA_);
656
657
    }

658
659
    /// Pointer to the contained data. If index is not provided, points to the
    /// beginning of the buffer. May throw if the index would lead to out of bounds access.
660
    template<typename... Ix> const void* data(Ix... index) const {
661
        return static_cast<const void *>(detail::array_proxy(m_ptr)->data + offset_at(index...));
662
663
    }

664
665
666
    /// Mutable pointer to the contained data. If index is not provided, points to the
    /// beginning of the buffer. May throw if the index would lead to out of bounds access.
    /// May throw if the array is not writeable.
667
    template<typename... Ix> void* mutable_data(Ix... index) {
668
        check_writeable();
669
        return static_cast<void *>(detail::array_proxy(m_ptr)->data + offset_at(index...));
670
671
672
673
    }

    /// Byte offset from beginning of the array to a given index (full or partial).
    /// May throw if the index would lead to out of bounds access.
674
    template<typename... Ix> ssize_t offset_at(Ix... index) const {
675
        if ((ssize_t) sizeof...(index) > ndim())
676
            fail_dim_check(sizeof...(index), "too many indices for an array");
677
        return byte_offset(ssize_t(index)...);
678
679
    }

680
    ssize_t offset_at() const { return 0; }
681
682
683

    /// Item count from beginning of the array to a given index (full or partial).
    /// May throw if the index would lead to out of bounds access.
684
    template<typename... Ix> ssize_t index_at(Ix... index) const {
685
        return offset_at(index...) / itemsize();
686
687
    }

688
689
    /**
     * Returns a proxy object that provides access to the array's data without bounds or
690
691
692
693
     * dimensionality checking.  Will throw if the array is missing the `writeable` flag.  Use with
     * care: the array must not be destroyed or reshaped for the duration of the returned object,
     * and the caller must take care not to access invalid dimensions or dimension indices.
     */
694
    template <typename T, ssize_t Dims = -1> detail::unchecked_mutable_reference<T, Dims> mutable_unchecked() & {
695
        if (Dims >= 0 && ndim() != Dims)
696
697
            throw std::domain_error("array has incorrect number of dimensions: " + std::to_string(ndim()) +
                    "; expected " + std::to_string(Dims));
698
        return detail::unchecked_mutable_reference<T, Dims>(mutable_data(), shape(), strides(), ndim());
699
700
    }

701
702
    /**
     * Returns a proxy object that provides const access to the array's data without bounds or
703
704
705
706
707
     * dimensionality checking.  Unlike `mutable_unchecked()`, this does not require that the
     * underlying array have the `writable` flag.  Use with care: the array must not be destroyed or
     * reshaped for the duration of the returned object, and the caller must take care not to access
     * invalid dimensions or dimension indices.
     */
708
    template <typename T, ssize_t Dims = -1> detail::unchecked_reference<T, Dims> unchecked() const & {
709
        if (Dims >= 0 && ndim() != Dims)
710
711
            throw std::domain_error("array has incorrect number of dimensions: " + std::to_string(ndim()) +
                    "; expected " + std::to_string(Dims));
712
        return detail::unchecked_reference<T, Dims>(data(), shape(), strides(), ndim());
713
714
    }

715
716
717
    /// Return a new view of this array with every length-1 dimension removed
    /// (delegates to NumPy's `PyArray_Squeeze`; the result reference is taken over).
    array squeeze() {
        auto squeezed = detail::npy_api::get().PyArray_Squeeze_(m_ptr);
        return reinterpret_steal<array>(squeezed);
    }

uentity's avatar
uentity committed
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
    /// Resize the array in place to `new_shape` (calls NumPy's `PyArray_Resize`).
    /// If `refcheck` is true and more than one reference to this array exists, the resize
    /// succeeds only if it is a pure reshape, i.e. the total size does not change.
    /// Throws `error_already_set` if NumPy reports an error.
    void resize(ShapeContainer new_shape, bool refcheck = true) {
        detail::npy_api::PyArray_Dims d = {
            new_shape->data(), int(new_shape->size())
        };
        // Attempt the resize; the ordering argument is passed as -1 because it is not used.
        object new_array = reinterpret_steal<object>(
            detail::npy_api::get().PyArray_Resize_(m_ptr, &d, int(refcheck), -1)
        );
        if (!new_array) throw error_already_set();
        // Only adopt the result when NumPy hands back an array object; otherwise (e.g. None)
        // `*this` keeps referring to the array that was resized in place.
        if (isinstance<array>(new_array)) { *this = std::move(new_array); }
    }

736
    /// Ensure that the argument is a NumPy array
737
738
739
740
741
742
    /// In case of an error, nullptr is returned and the Python error is cleared.
    static array ensure(handle h, int ExtraFlags = 0) {
        auto result = reinterpret_steal<array>(raw_array(h.ptr(), ExtraFlags));
        if (!result)
            PyErr_Clear();
        return result;
743
744
    }

745
protected:
746
747
    template<typename, typename> friend struct detail::npy_format_descriptor;

748
    void fail_dim_check(ssize_t dim, const std::string& msg) const {
749
750
751
752
        throw index_error(msg + ": " + std::to_string(dim) +
                          " (ndim = " + std::to_string(ndim()) + ")");
    }

753
    /// Byte offset of a (full or partial) index, after bounds-checking it with `check_dimensions`.
    template<typename... Ix> ssize_t byte_offset(Ix... index) const {
        check_dimensions(index...);
        const ssize_t offset = detail::byte_offset_unsafe(strides(), ssize_t(index)...);
        return offset;
    }

758
759
    /// Throw `std::domain_error` unless the array's `writeable` flag is set.
    void check_writeable() const {
        if (writeable())
            return;
        throw std::domain_error("array is not writeable");
    }
762

763
764
    // Default, C-style (row-major) strides: the last axis is contiguous (stride == itemsize) and
    // each earlier axis steps over the full extent of all axes after it.  Returns one stride per
    // dimension; a zero-dimensional shape yields an empty vector.
    static std::vector<ssize_t> c_strides(const std::vector<ssize_t> &shape, ssize_t itemsize) {
        auto ndim = shape.size();
        std::vector<ssize_t> strides(ndim, itemsize);
        // Guard ndim == 0: `ndim - 1` would wrap around to SIZE_MAX and index out of bounds.
        if (ndim > 0)
            for (size_t i = ndim - 1; i > 0; --i)
                strides[i - 1] = strides[i] * shape[i];
        return strides;
    }

    // F-style (column-major) strides, the default when constructing an array_t with
    // `ExtraFlags & f_style`: the first axis is contiguous (stride == itemsize) and each later
    // axis steps over the full extent of all axes before it.
    static std::vector<ssize_t> f_strides(const std::vector<ssize_t> &shape, ssize_t itemsize) {
        const size_t rank = shape.size();
        std::vector<ssize_t> result(rank, itemsize);
        // Each stride is the previous stride times the previous axis' extent.
        size_t axis = 1;
        while (axis < rank) {
            result[axis] = result[axis - 1] * shape[axis - 1];
            ++axis;
        }
        return result;
    }
780
781

    template<typename... Ix> void check_dimensions(Ix... index) const {
782
        check_dimensions_impl(ssize_t(0), shape(), ssize_t(index)...);
783
784
    }

785
    void check_dimensions_impl(ssize_t, const ssize_t*) const { }
786

787
    template<typename... Ix> void check_dimensions_impl(ssize_t axis, const ssize_t* shape, ssize_t i, Ix... index) const {
788
789
790
791
792
793
794
        if (i >= *shape) {
            throw index_error(std::string("index ") + std::to_string(i) +
                              " is out of bounds for axis " + std::to_string(axis) +
                              " with size " + std::to_string(*shape));
        }
        check_dimensions_impl(axis + 1, shape + 1, index...);
    }
795
796
797

    /// Create array from any object -- always returns a new reference.
    /// A null `ptr` sets a Python `ValueError` and yields nullptr.
    static PyObject *raw_array(PyObject *ptr, int ExtraFlags = 0) {
        if (!ptr) {
            PyErr_SetString(PyExc_ValueError, "cannot create a pybind11::array from a nullptr");
            return nullptr;
        }
        const int flags = detail::npy_api::NPY_ARRAY_ENSUREARRAY_ | ExtraFlags;
        return detail::npy_api::get().PyArray_FromAny_(ptr, nullptr, 0, 0, flags, nullptr);
    }
Wenzel Jakob's avatar
Wenzel Jakob committed
805
806
};

807
template <typename T, int ExtraFlags = array::forcecast> class array_t : public array {
808
809
810
811
812
private:
    struct private_ctor {};
    // Delegating constructor needed when both moving and accessing in the same constructor
    array_t(private_ctor, ShapeContainer &&shape, StridesContainer &&strides, const T *ptr, handle base)
        : array(std::move(shape), std::move(strides), ptr, base) {}
Wenzel Jakob's avatar
Wenzel Jakob committed
813
public:
814
815
    static_assert(!detail::array_info<T>::is_array, "Array types cannot be used with array_t");

816
817
    using value_type = T;

818
    array_t() : array(0, static_cast<const T *>(nullptr)) {}
819
820
    array_t(handle h, borrowed_t) : array(h, borrowed_t{}) { }
    array_t(handle h, stolen_t) : array(h, stolen_t{}) { }
821

822
    PYBIND11_DEPRECATED("Use array_t<T>::ensure() instead")
823
    array_t(handle h, bool is_borrowed) : array(raw_array_t(h.ptr()), stolen_t{}) {
824
825
826
        if (!m_ptr) PyErr_Clear();
        if (!is_borrowed) Py_XDECREF(h.ptr());
    }
827

828
    array_t(const object &o) : array(raw_array_t(o.ptr()), stolen_t{}) {
829
830
        if (!m_ptr) throw error_already_set();
    }
831

832
    explicit array_t(const buffer_info& info) : array(info) { }
833

834
835
    array_t(ShapeContainer shape, StridesContainer strides, const T *ptr = nullptr, handle base = handle())
        : array(std::move(shape), std::move(strides), ptr, base) { }
836

837
    explicit array_t(ShapeContainer shape, const T *ptr = nullptr, handle base = handle())
838
839
840
        : array_t(private_ctor{}, std::move(shape),
                ExtraFlags & f_style ? f_strides(*shape, itemsize()) : c_strides(*shape, itemsize()),
                ptr, base) { }
841

842
843
844
    explicit array_t(size_t count, const T *ptr = nullptr, handle base = handle())
        : array({count}, {}, ptr, base) { }

845
    constexpr ssize_t itemsize() const {
846
        return sizeof(T);
847
848
    }

849
    template<typename... Ix> ssize_t index_at(Ix... index) const {
850
        return offset_at(index...) / itemsize();
851
852
    }

853
    template<typename... Ix> const T* data(Ix... index) const {
854
855
856
        return static_cast<const T*>(array::data(index...));
    }

857
    template<typename... Ix> T* mutable_data(Ix... index) {
858
859
860
861
        return static_cast<T*>(array::mutable_data(index...));
    }

    // Reference to element at a given index
862
    template<typename... Ix> const T& at(Ix... index) const {
863
864
        if (sizeof...(index) != ndim())
            fail_dim_check(sizeof...(index), "index dimension mismatch");
865
        return *(static_cast<const T*>(array::data()) + byte_offset(ssize_t(index)...) / itemsize());
866
867
868
    }

    // Mutable reference to element at a given index
869
    template<typename... Ix> T& mutable_at(Ix... index) {
870
871
        if (sizeof...(index) != ndim())
            fail_dim_check(sizeof...(index), "index dimension mismatch");
872
        return *(static_cast<T*>(array::mutable_data()) + byte_offset(ssize_t(index)...) / itemsize());
873
    }
874

875
876
    /**
     * Returns a proxy object that provides access to the array's data without bounds or
877
878
879
880
     * dimensionality checking.  Will throw if the array is missing the `writeable` flag.  Use with
     * care: the array must not be destroyed or reshaped for the duration of the returned object,
     * and the caller must take care not to access invalid dimensions or dimension indices.
     */
881
    template <ssize_t Dims = -1> detail::unchecked_mutable_reference<T, Dims> mutable_unchecked() & {
882
883
884
        return array::mutable_unchecked<T, Dims>();
    }

885
886
    /**
     * Returns a proxy object that provides const access to the array's data without bounds or
887
888
889
890
891
     * dimensionality checking.  Unlike `unchecked()`, this does not require that the underlying
     * array have the `writable` flag.  Use with care: the array must not be destroyed or reshaped
     * for the duration of the returned object, and the caller must take care not to access invalid
     * dimensions or dimension indices.
     */
892
    template <ssize_t Dims = -1> detail::unchecked_reference<T, Dims> unchecked() const & {
893
894
895
        return array::unchecked<T, Dims>();
    }

Jason Rhinelander's avatar
Jason Rhinelander committed
896
897
    /// Ensure that the argument is a NumPy array of the correct dtype (and if not, try to convert
    /// it).  In case of an error, nullptr is returned and the Python error is cleared.
898
899
    static array_t ensure(handle h) {
        auto result = reinterpret_steal<array_t>(raw_array_t(h.ptr()));
900
901
        if (!result)
            PyErr_Clear();
902
        return result;
Wenzel Jakob's avatar
Wenzel Jakob committed
903
    }
904

Wenzel Jakob's avatar
Wenzel Jakob committed
905
    static bool check_(handle h) {
906
907
        const auto &api = detail::npy_api::get();
        return api.PyArray_Check_(h.ptr())
908
               && api.PyArray_EquivTypes_(detail::array_proxy(h.ptr())->descr, dtype::of<T>().ptr());
909
910
911
912
913
    }

protected:
    /// Create array from any object -- always returns a new reference
    static PyObject *raw_array_t(PyObject *ptr) {
914
915
        if (ptr == nullptr) {
            PyErr_SetString(PyExc_ValueError, "cannot create a pybind11::array_t from a nullptr");
916
            return nullptr;
917
        }
918
919
        return detail::npy_api::get().PyArray_FromAny_(
            ptr, dtype::of<T>().release().ptr(), 0, 0,
920
            detail::npy_api::NPY_ARRAY_ENSUREARRAY_ | ExtraFlags, nullptr);
921
    }
Wenzel Jakob's avatar
Wenzel Jakob committed
922
923
};

924
template <typename T>
925
struct format_descriptor<T, detail::enable_if_t<detail::is_pod_struct<T>::value>> {
926
927
928
    static std::string format() {
        return detail::npy_format_descriptor<typename std::remove_cv<T>::type>::format();
    }
929
930
931
};

template <size_t N> struct format_descriptor<char[N]> {
    // Fixed-width byte string: the count followed by 's' (e.g. "16s" for char[16]).
    static std::string format() {
        std::string fmt = std::to_string(N);
        fmt.push_back('s');
        return fmt;
    }
};
template <size_t N> struct format_descriptor<std::array<char, N>> {
    // Same encoding as char[N]: the count followed by 's'.
    static std::string format() {
        std::string fmt = std::to_string(N);
        fmt.push_back('s');
        return fmt;
    }
};

938
939
940
941
942
943
944
945
template <typename T>
struct format_descriptor<T, detail::enable_if_t<std::is_enum<T>::value>> {
    // An enumeration uses the format of its (cv-stripped) underlying integer type.
    static std::string format() {
        using underlying = typename std::remove_cv<typename std::underlying_type<T>::type>::type;
        return format_descriptor<underlying>::format();
    }
};

946
947
948
949
950
951
952
953
954
template <typename T>
struct format_descriptor<T, detail::enable_if_t<detail::array_info<T>::is_array>> {
    static std::string format() {
        using detail::_;
        PYBIND11_DESCR extents = _("(") + detail::array_info<T>::extents() + _(")");
        return extents.text() + format_descriptor<detail::remove_all_extents_t<T>>::format();
    }
};

955
NAMESPACE_BEGIN(detail)
956
957
958
959
template <typename T, int ExtraFlags>
struct pyobject_caster<array_t<T, ExtraFlags>> {
    using type = array_t<T, ExtraFlags>;

    // With `convert` disabled only objects that already are arrays of the right dtype are
    // accepted; otherwise `ensure` attempts the conversion (and may still fail).
    bool load(handle src, bool convert) {
        if (convert || type::check_(src)) {
            value = type::ensure(src);
            return static_cast<bool>(value);
        }
        return false;
    }

    // The value is already a Python object: return a new reference to it.
    static handle cast(const handle &src, return_value_policy /* policy */, handle /* parent */) {
        return src.inc_ref();
    }
    PYBIND11_TYPE_CASTER(type, handle_type_name<type>::name());
};

973
974
975
976
977
978
979
template <typename T>
struct compare_buffer_info<T, detail::enable_if_t<detail::is_pod_struct<T>::value>> {
    // A buffer matches a structured type when NumPy considers its dtype equivalent to T's.
    static bool compare(const buffer_info& b) {
        const auto expected = dtype::of<T>();
        const auto actual = dtype(b);
        return npy_api::get().PyArray_EquivTypes_(expected.ptr(), actual.ptr());
    }
};

980
template <typename T> struct npy_format_descriptor<T, enable_if_t<satisfies_any_of<T, std::is_arithmetic, is_complex>::value>> {
private:
    // NB: the order here must match the one in common.h
    constexpr static const int values[15] = {
        npy_api::NPY_BOOL_,
        npy_api::NPY_BYTE_,   npy_api::NPY_UBYTE_,   npy_api::NPY_SHORT_,    npy_api::NPY_USHORT_,
        npy_api::NPY_INT_,    npy_api::NPY_UINT_,    npy_api::NPY_LONGLONG_, npy_api::NPY_ULONGLONG_,
        npy_api::NPY_FLOAT_,  npy_api::NPY_DOUBLE_,  npy_api::NPY_LONGDOUBLE_,
        npy_api::NPY_CFLOAT_, npy_api::NPY_CDOUBLE_, npy_api::NPY_CLONGDOUBLE_
    };

public:
    /// NumPy type number for T, selected by T's position in the numeric format list.
    static constexpr int value = values[detail::is_fmt_numeric<T>::index];

    /// dtype object for T.  `PyArray_DescrFromType` returns a *new* reference, so it is stolen;
    /// borrowing it (as before) leaked one reference on every call.
    static pybind11::dtype dtype() {
        if (auto ptr = npy_api::get().PyArray_DescrFromType_(value))
            return reinterpret_steal<pybind11::dtype>(ptr);
        pybind11_fail("Unsupported buffer format!");
    }

    // Type name for signatures: "bool", "int<bits>" / "uint<bits>" for integers.
    template <typename T2 = T, enable_if_t<std::is_integral<T2>::value, int> = 0>
    static PYBIND11_DESCR name() {
        return _<std::is_same<T, bool>::value>(_("bool"),
            _<std::is_signed<T>::value>("int", "uint") + _<sizeof(T)*8>());
    }
    // "float<bits>" for float/double, "longdouble" otherwise.
    template <typename T2 = T, enable_if_t<std::is_floating_point<T2>::value, int> = 0>
    static PYBIND11_DESCR name() {
        return _<std::is_same<T, float>::value || std::is_same<T, double>::value>(
                _("float") + _<sizeof(T)*8>(), _("longdouble"));
    }
    // "complex<bits>" for complex float/double, "longcomplex" otherwise.
    template <typename T2 = T, enable_if_t<is_complex<T2>::value, int> = 0>
    static PYBIND11_DESCR name() {
        return _<std::is_same<typename T2::value_type, float>::value || std::is_same<typename T2::value_type, double>::value>(
                _("complex") + _<sizeof(typename T2::value_type)*16>(), _("longcomplex"));
    }
};
1015
1016

#define PYBIND11_DECL_CHAR_FMT \
1017
    static PYBIND11_DESCR name() { return _("S") + _<N>(); } \
1018
    static pybind11::dtype dtype() { return pybind11::dtype(std::string("S") + std::to_string(N)); }
1019
1020
1021
template <size_t N> struct npy_format_descriptor<char[N]> { PYBIND11_DECL_CHAR_FMT };
template <size_t N> struct npy_format_descriptor<std::array<char, N>> { PYBIND11_DECL_CHAR_FMT };
#undef PYBIND11_DECL_CHAR_FMT
1022

1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
// C-style arrays become NumPy subarray dtypes: (element dtype, shape tuple of the extents).
template<typename T> struct npy_format_descriptor<T, enable_if_t<array_info<T>::is_array>> {
private:
    using element_descr = npy_format_descriptor<typename array_info<T>::type>;
public:
    static_assert(!array_info<T>::is_empty, "Zero-sized arrays are not supported");

    static PYBIND11_DESCR name() {
        return _("(") + array_info<T>::extents() + _(")") + element_descr::name();
    }
    static pybind11::dtype dtype() {
        list extents;
        array_info<T>::append_extents(extents);
        auto spec = pybind11::make_tuple(element_descr::dtype(), extents);
        return pybind11::dtype::from_args(spec);
    }
};

1037
1038
1039
1040
1041
1042
1043
1044
// Enumerations reuse the name and dtype of their underlying integer type.
template<typename T> struct npy_format_descriptor<T, enable_if_t<std::is_enum<T>::value>> {
private:
    using underlying_descr = npy_format_descriptor<typename std::underlying_type<T>::type>;
public:
    static PYBIND11_DESCR name() { return underlying_descr::name(); }
    static pybind11::dtype dtype() { return underlying_descr::dtype(); }
};

1045
1046
struct field_descriptor {
    const char *name;
1047
1048
    ssize_t offset;
    ssize_t size;
1049
    std::string format;
1050
    dtype descr;
1051
1052
};

1053
1054
/// Register a structured NumPy dtype for the C++ type identified by `tinfo`.
///
/// Builds the dtype from `fields` (names, dtypes, offsets) and `itemsize`, generates a matching
/// buffer-protocol format string, verifies NumPy parses that string to an equivalent dtype, and
/// records dtype + format string in the NumPy internals together with `direct_converter`.
/// Calls `pybind11_fail` if the type is already registered, if a field has no dtype, or if the
/// generated format string does not round-trip.
inline PYBIND11_NOINLINE void register_structured_dtype(
    const std::initializer_list<field_descriptor>& fields,
    const std::type_info& tinfo, ssize_t itemsize,
    bool (*direct_converter)(PyObject *, void *&)) {

    auto& numpy_internals = get_numpy_internals();
    if (numpy_internals.get_type_info(tinfo, false))
        pybind11_fail("NumPy: dtype is already registered");

    list names, formats, offsets;
    // Iterate by const reference: each field_descriptor owns a std::string and a dtype handle,
    // so copying it per field would allocate and adjust Python reference counts for nothing.
    for (const auto &field : fields) {
        if (!field.descr)
            pybind11_fail(std::string("NumPy: unsupported field dtype: `") +
                            field.name + "` @ " + tinfo.name());
        names.append(PYBIND11_STR_TYPE(field.name));
        formats.append(field.descr);
        offsets.append(pybind11::int_(field.offset));
    }
    // Keep the new dtype owned by an RAII object until registration succeeds, so its reference
    // is not leaked if the sanity check below fails or throws.
    auto structured_dtype = pybind11::dtype(names, formats, offsets, itemsize);

    // There is an existing bug in NumPy (as of v1.11): trailing bytes are
    // not encoded explicitly into the format string. This will supposedly
    // get fixed in v1.12; for further details, see these:
    // - https://github.com/numpy/numpy/issues/7797
    // - https://github.com/numpy/numpy/pull/7798
    // Because of this, we won't use numpy's logic to generate buffer format
    // strings and will just do it ourselves.
    std::vector<field_descriptor> ordered_fields(fields);
    std::sort(ordered_fields.begin(), ordered_fields.end(),
        [](const field_descriptor &a, const field_descriptor &b) { return a.offset < b.offset; });
    ssize_t offset = 0;
    std::ostringstream oss;
    // mark the structure as unaligned with '^', because numpy and C++ don't
    // always agree about alignment (particularly for complex), and we're
    // explicitly listing all our padding. This depends on none of the fields
    // overriding the endianness. Putting the ^ in front of individual fields
    // isn't guaranteed to work due to https://github.com/numpy/numpy/issues/9049
    oss << "^T{";
    for (const auto &field : ordered_fields) {
        // Emit explicit padding ('x') for any gap before this field.
        if (field.offset > offset)
            oss << (field.offset - offset) << 'x';
        oss << field.format << ':' << field.name << ':';
        offset = field.offset + field.size;
    }
    // Trailing padding up to the full item size.
    if (itemsize > offset)
        oss << (itemsize - offset) << 'x';
    oss << '}';
    auto format_str = oss.str();

    // Sanity check: verify that NumPy properly parses our buffer format string
    auto& api = npy_api::get();
    auto arr = array(buffer_info(nullptr, itemsize, format_str, 1));
    if (!api.PyArray_EquivTypes_(structured_dtype.ptr(), arr.dtype().ptr()))
        pybind11_fail("NumPy: invalid buffer descriptor!");

    auto tindex = std::type_index(tinfo);
    // Ownership of the dtype reference passes to the registry only now.
    numpy_internals.registered_dtypes[tindex] = { structured_dtype.release().ptr(), format_str };
    get_internals().direct_conversions[tindex].push_back(direct_converter);
}

1113
1114
1115
template <typename T, typename SFINAE> struct npy_format_descriptor {
    static_assert(is_pod_struct<T>::value, "Attempt to use a non-POD or unimplemented POD type as a numpy dtype");

1116
    static PYBIND11_DESCR name() { return make_caster<T>::name(); }
1117

1118
    static pybind11::dtype dtype() {
1119
        return reinterpret_borrow<pybind11::dtype>(dtype_ptr());
1120
1121
    }

1122
    static std::string format() {
1123
        static auto format_str = get_numpy_internals().get_type_info<T>(true)->format_str;
1124
        return format_str;
1125
1126
    }

1127
1128
1129
    static void register_dtype(const std::initializer_list<field_descriptor>& fields) {
        register_structured_dtype(fields, typeid(typename std::remove_cv<T>::type),
                                  sizeof(T), &direct_converter);
1130
1131
1132
    }

private:
1133
1134
1135
1136
    static PyObject* dtype_ptr() {
        static PyObject* ptr = get_numpy_internals().get_type_info<T>(true)->dtype_ptr;
        return ptr;
    }
1137

1138
1139
1140
    static bool direct_converter(PyObject *obj, void*& value) {
        auto& api = npy_api::get();
        if (!PyObject_TypeCheck(obj, api.PyVoidArrType_Type_))
1141
            return false;
1142
        if (auto descr = reinterpret_steal<object>(api.PyArray_DescrFromScalar_(obj))) {
1143
            if (api.PyArray_EquivTypes_(dtype_ptr(), descr.ptr())) {
1144
1145
1146
1147
1148
1149
                value = ((PyVoidScalarObject_Proxy *) obj)->obval;
                return true;
            }
        }
        return false;
    }
1150
1151
};

1152
1153
1154
1155
1156
#ifdef __CLION_IDE__ // replace heavy macro with dummy code for the IDE (doesn't affect code)
# define PYBIND11_NUMPY_DTYPE(Type, ...) ((void)0)
# define PYBIND11_NUMPY_DTYPE_EX(Type, ...) ((void)0)
#else

1157
1158
1159
1160
1161
// Build a ::pybind11::detail::field_descriptor for member `Field` of struct `T`, exposed to NumPy
// under `Name`: records the member's byte offset, size, buffer format string and dtype.
#define PYBIND11_FIELD_DESCRIPTOR_EX(T, Field, Name)                                          \
    ::pybind11::detail::field_descriptor {                                                    \
        Name, offsetof(T, Field), sizeof(decltype(std::declval<T>().Field)),                  \
        ::pybind11::format_descriptor<decltype(std::declval<T>().Field)>::format(),           \
        ::pybind11::detail::npy_format_descriptor<decltype(std::declval<T>().Field)>::dtype() \
1162
    }
1163

1164
1165
1166
// Extract name, offset and format descriptor for a struct field
#define PYBIND11_FIELD_DESCRIPTOR(T, Field) PYBIND11_FIELD_DESCRIPTOR_EX(T, Field, #Field)

1167
1168
// The main idea of this macro is borrowed from https://github.com/swansontec/map-macro
// (C) William Swanson, Paul Fultz
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
// PYBIND11_EVAL nests PYBIND11_EVAL0 3^5 = 243 levels deep, forcing the preprocessor to rescan
// its argument that many times.  That lets the self-referencing MAP_LIST macros below keep
// expanding, and effectively caps how many arguments they can process.
#define PYBIND11_EVAL0(...) __VA_ARGS__
#define PYBIND11_EVAL1(...) PYBIND11_EVAL0 (PYBIND11_EVAL0 (PYBIND11_EVAL0 (__VA_ARGS__)))
#define PYBIND11_EVAL2(...) PYBIND11_EVAL1 (PYBIND11_EVAL1 (PYBIND11_EVAL1 (__VA_ARGS__)))
#define PYBIND11_EVAL3(...) PYBIND11_EVAL2 (PYBIND11_EVAL2 (PYBIND11_EVAL2 (__VA_ARGS__)))
#define PYBIND11_EVAL4(...) PYBIND11_EVAL3 (PYBIND11_EVAL3 (PYBIND11_EVAL3 (__VA_ARGS__)))
#define PYBIND11_EVAL(...)  PYBIND11_EVAL4 (PYBIND11_EVAL4 (PYBIND11_EVAL4 (__VA_ARGS__)))
// End-of-list detection: PYBIND11_MAP_NEXT(test, next) yields `next` unless `test` is the `()`
// sentinel.  In that case `PYBIND11_MAP_GET_END ()` expands to `0, PYBIND11_MAP_END`, shifting
// the arguments so MAP_NEXT0 selects PYBIND11_MAP_END, which swallows its arguments and stops
// the recursion.  The empty PYBIND11_MAP_OUT after `next` defers its expansion to a later rescan.
#define PYBIND11_MAP_END(...)
#define PYBIND11_MAP_OUT
#define PYBIND11_MAP_COMMA ,
#define PYBIND11_MAP_GET_END() 0, PYBIND11_MAP_END
#define PYBIND11_MAP_NEXT0(test, next, ...) next PYBIND11_MAP_OUT
#define PYBIND11_MAP_NEXT1(test, next) PYBIND11_MAP_NEXT0 (test, next, 0)
#define PYBIND11_MAP_NEXT(test, next)  PYBIND11_MAP_NEXT1 (PYBIND11_MAP_GET_END test, next)
1182
#ifdef _MSC_VER // MSVC is not as eager to expand macros, hence this workaround
1183
1184
#define PYBIND11_MAP_LIST_NEXT1(test, next) \
    PYBIND11_EVAL0 (PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0))
1185
#else
1186
1187
#define PYBIND11_MAP_LIST_NEXT1(test, next) \
    PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0)
1188
#endif
1189
1190
1191
1192
1193
1194
// Like PYBIND11_MAP_NEXT, but prefixes the continuation with a comma so the results of `f`
// form a comma-separated list.  MAP_LIST0 and MAP_LIST1 alternate so that neither macro
// invokes itself directly.
#define PYBIND11_MAP_LIST_NEXT(test, next) \
    PYBIND11_MAP_LIST_NEXT1 (PYBIND11_MAP_GET_END test, next)
#define PYBIND11_MAP_LIST0(f, t, x, peek, ...) \
    f(t, x) PYBIND11_MAP_LIST_NEXT (peek, PYBIND11_MAP_LIST1) (f, t, peek, __VA_ARGS__)
#define PYBIND11_MAP_LIST1(f, t, x, peek, ...) \
    f(t, x) PYBIND11_MAP_LIST_NEXT (peek, PYBIND11_MAP_LIST0) (f, t, peek, __VA_ARGS__)
1195
// PYBIND11_MAP_LIST(f, t, a1, a2, ...) expands to f(t, a1), f(t, a2), ...
// (the trailing `(), 0` arguments are the end-of-list sentinel detected above).
1196
1197
#define PYBIND11_MAP_LIST(f, t, ...) \
    PYBIND11_EVAL (PYBIND11_MAP_LIST1 (f, t, __VA_ARGS__, (), 0))
1198

1199
// PYBIND11_NUMPY_DTYPE(Type, field1, field2, ...) registers a structured dtype for `Type` whose
// NumPy field names are the C++ member names.
#define PYBIND11_NUMPY_DTYPE(Type, ...) \
1200
    ::pybind11::detail::npy_format_descriptor<Type>::register_dtype \
1201
        ({PYBIND11_MAP_LIST (PYBIND11_FIELD_DESCRIPTOR, Type, __VA_ARGS__)})
1202

1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
// Two-at-a-time variant of the MAP_LIST machinery: consumes arguments in (x1, x2) pairs.
#ifdef _MSC_VER
#define PYBIND11_MAP2_LIST_NEXT1(test, next) \
    PYBIND11_EVAL0 (PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0))
#else
#define PYBIND11_MAP2_LIST_NEXT1(test, next) \
    PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0)
#endif
#define PYBIND11_MAP2_LIST_NEXT(test, next) \
    PYBIND11_MAP2_LIST_NEXT1 (PYBIND11_MAP_GET_END test, next)
#define PYBIND11_MAP2_LIST0(f, t, x1, x2, peek, ...) \
    f(t, x1, x2) PYBIND11_MAP2_LIST_NEXT (peek, PYBIND11_MAP2_LIST1) (f, t, peek, __VA_ARGS__)
#define PYBIND11_MAP2_LIST1(f, t, x1, x2, peek, ...) \
    f(t, x1, x2) PYBIND11_MAP2_LIST_NEXT (peek, PYBIND11_MAP2_LIST0) (f, t, peek, __VA_ARGS__)
// PYBIND11_MAP2_LIST(f, t, a1, a2, ...) expands to f(t, a1, a2), f(t, a3, a4), ...
#define PYBIND11_MAP2_LIST(f, t, ...) \
    PYBIND11_EVAL (PYBIND11_MAP2_LIST1 (f, t, __VA_ARGS__, (), 0))

// PYBIND11_NUMPY_DTYPE_EX(Type, member1, "name1", member2, "name2", ...) registers a structured
// dtype for `Type` with explicitly chosen NumPy field names; arguments must come in pairs.
#define PYBIND11_NUMPY_DTYPE_EX(Type, ...) \
    ::pybind11::detail::npy_format_descriptor<Type>::register_dtype \
        ({PYBIND11_MAP2_LIST (PYBIND11_FIELD_DESCRIPTOR_EX, Type, __VA_ARGS__)})

1224
1225
#endif // __CLION_IDE__

1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
// Iterator over a buffer's elements in memory order: simply a plain `T*`.
template <class T>
using array_iterator = typename std::add_pointer<T>::type;

// Iterator to the first element of `buffer`, viewed as `T`.
template <class T>
array_iterator<T> array_begin(const buffer_info& buffer) {
    array_iterator<T> first = reinterpret_cast<T*>(buffer.ptr);
    return first;
}

// Iterator one past the last of the `buffer.size` elements of `buffer`, viewed as `T`.
template <class T>
array_iterator<T> array_end(const buffer_info& buffer) {
    return array_begin<T>(buffer) + buffer.size;
}

/// Byte-pointer cursor over one buffer, advanced one axis at a time by multi_array_iterator.
///
/// `m_strides[d]` is the byte step to apply when axis `d` is incremented while every later
/// axis wraps from its last index back to zero, so a whole odometer step is one pointer add.
class common_iterator {
public:
    using container_type = std::vector<ssize_t>;
    using value_type = container_type::value_type;
    using size_type = container_type::size_type;

    common_iterator() : p_ptr(nullptr), m_strides() {}

    common_iterator(void* ptr, const container_type& strides, const container_type& shape)
        : p_ptr(reinterpret_cast<char*>(ptr)), m_strides(strides.size()) {
        // A zero-dimensional buffer has no axes to step along; calling back() on an empty
        // vector would be undefined behavior.
        if (m_strides.empty())
            return;
        m_strides.back() = static_cast<value_type>(strides.back());
        for (size_type i = m_strides.size() - 1; i != 0; --i) {
            size_type j = i - 1;
            value_type s = static_cast<value_type>(shape[i]);
            // Step along axis j, minus the distance travelled along axis i before it wrapped.
            m_strides[j] = strides[j] + m_strides[i] - strides[i] * s;
        }
    }

    /// Advance along axis `dim` (later axes are assumed to have just wrapped to zero).
    void increment(size_type dim) {
        p_ptr += m_strides[dim];
    }

    /// Current element address.
    void* data() const {
        return p_ptr;
    }

private:
    char* p_ptr;
    container_type m_strides;
};

1270
template <size_t N> class multi_array_iterator {
1271
public:
1272
    using container_type = std::vector<ssize_t>;
1273

1274
    multi_array_iterator(const std::array<buffer_info, N> &buffers,
1275
                         const container_type &shape)
1276
1277
1278
        : m_shape(shape.size()), m_index(shape.size(), 0),
          m_common_iterator() {

1279
        // Manual copy to avoid conversion warning if using std::copy
1280
        for (size_t i = 0; i < shape.size(); ++i)
1281
            m_shape[i] = shape[i];
1282

1283
        container_type strides(shape.size());
1284
        for (size_t i = 0; i < N; ++i)
1285
1286
1287
1288
1289
1290
1291
1292
1293
            init_common_iterator(buffers[i], shape, m_common_iterator[i], strides);
    }

    multi_array_iterator& operator++() {
        for (size_t j = m_index.size(); j != 0; --j) {
            size_t i = j - 1;
            if (++m_index[i] != m_shape[i]) {
                increment_common_iterator(i);
                break;
1294
            } else {
1295
1296
1297
1298
1299
1300
                m_index[i] = 0;
            }
        }
        return *this;
    }

1301
1302
    template <size_t K, class T = void> T* data() const {
        return reinterpret_cast<T*>(m_common_iterator[K].data());
1303
1304
1305
1306
1307
1308
    }

private:

    using common_iter = common_iterator;

1309
    void init_common_iterator(const buffer_info &buffer,
1310
1311
1312
                              const container_type &shape,
                              common_iter &iterator,
                              container_type &strides) {
1313
1314
1315
1316
1317
1318
1319
        auto buffer_shape_iter = buffer.shape.rbegin();
        auto buffer_strides_iter = buffer.strides.rbegin();
        auto shape_iter = shape.rbegin();
        auto strides_iter = strides.rbegin();

        while (buffer_shape_iter != buffer.shape.rend()) {
            if (*shape_iter == *buffer_shape_iter)
1320
                *strides_iter = *buffer_strides_iter;
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
            else
                *strides_iter = 0;

            ++buffer_shape_iter;
            ++buffer_strides_iter;
            ++shape_iter;
            ++strides_iter;
        }

        std::fill(strides_iter, strides.rend(), 0);
        iterator = common_iter(buffer.ptr, strides, shape);
    }

    void increment_common_iterator(size_t dim) {
1335
        for (auto &iter : m_common_iterator)
1336
1337
1338
1339
1340
1341
1342
1343
            iter.increment(dim);
    }

    container_type m_shape;
    container_type m_index;
    std::array<common_iter, N> m_common_iterator;
};

1344
1345
1346
1347
1348
1349
enum class broadcast_trivial { non_trivial, c_trivial, f_trivial };

// Populates the shape and number of dimensions for the set of buffers.  Returns a broadcast_trivial
// enum value indicating whether the broadcast is "trivial"--that is, has each buffer being either a
// singleton or a full-size, C-contiguous (`c_trivial`) or Fortran-contiguous (`f_trivial`) storage
// buffer; returns `non_trivial` otherwise.
1350
template <size_t N>
1351
broadcast_trivial broadcast(const std::array<buffer_info, N> &buffers, ssize_t &ndim, std::vector<ssize_t> &shape) {
1352
    ndim = std::accumulate(buffers.begin(), buffers.end(), ssize_t(0), [](ssize_t res, const buffer_info &buf) {
1353
1354
1355
        return std::max(res, buf.ndim);
    });

1356
    shape.clear();
1357
    shape.resize((size_t) ndim, 1);
1358

1359
1360
    // Figure out the output size, and make sure all input arrays conform (i.e. are either size 1 or
    // the full size).
1361
1362
    for (size_t i = 0; i < N; ++i) {
        auto res_iter = shape.rbegin();
1363
1364
        auto end = buffers[i].shape.rend();
        for (auto shape_iter = buffers[i].shape.rbegin(); shape_iter != end; ++shape_iter, ++res_iter) {
1365
1366
1367
1368
1369
1370
1371
            const auto &dim_size_in = *shape_iter;
            auto &dim_size_out = *res_iter;

            // Each input dimension can either be 1 or `n`, but `n` values must match across buffers
            if (dim_size_out == 1)
                dim_size_out = dim_size_in;
            else if (dim_size_in != 1 && dim_size_in != dim_size_out)
1372
                pybind11_fail("pybind11::vectorize: incompatible size/dimension of inputs!");
1373
1374
        }
    }
1375

1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
    bool trivial_broadcast_c = true;
    bool trivial_broadcast_f = true;
    for (size_t i = 0; i < N && (trivial_broadcast_c || trivial_broadcast_f); ++i) {
        if (buffers[i].size == 1)
            continue;

        // Require the same number of dimensions:
        if (buffers[i].ndim != ndim)
            return broadcast_trivial::non_trivial;

        // Require all dimensions be full-size:
        if (!std::equal(buffers[i].shape.cbegin(), buffers[i].shape.cend(), shape.cbegin()))
            return broadcast_trivial::non_trivial;

        // Check for C contiguity (but only if previous inputs were also C contiguous)
        if (trivial_broadcast_c) {
1392
            ssize_t expect_stride = buffers[i].itemsize;
1393
            auto end = buffers[i].shape.crend();
1394
1395
            for (auto shape_iter = buffers[i].shape.crbegin(), stride_iter = buffers[i].strides.crbegin();
                    trivial_broadcast_c && shape_iter != end; ++shape_iter, ++stride_iter) {
1396
1397
1398
1399
                if (expect_stride == *stride_iter)
                    expect_stride *= *shape_iter;
                else
                    trivial_broadcast_c = false;
1400
            }
1401
        }
1402

1403
1404
        // Check for Fortran contiguity (if previous inputs were also F contiguous)
        if (trivial_broadcast_f) {
1405
            ssize_t expect_stride = buffers[i].itemsize;
1406
            auto end = buffers[i].shape.cend();
1407
1408
            for (auto shape_iter = buffers[i].shape.cbegin(), stride_iter = buffers[i].strides.cbegin();
                    trivial_broadcast_f && shape_iter != end; ++shape_iter, ++stride_iter) {
1409
1410
1411
1412
1413
                if (expect_stride == *stride_iter)
                    expect_stride *= *shape_iter;
                else
                    trivial_broadcast_f = false;
            }
1414
1415
        }
    }
1416
1417
1418
1419
1420

    return
        trivial_broadcast_c ? broadcast_trivial::c_trivial :
        trivial_broadcast_f ? broadcast_trivial::f_trivial :
        broadcast_trivial::non_trivial;
1421
1422
}

1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
template <typename T>
struct vectorize_arg {
    static_assert(!std::is_rvalue_reference<T>::value, "Functions with rvalue reference arguments cannot be vectorized");
    // The wrapped function gets called with this type:
    using call_type = remove_reference_t<T>;
    // Is this a vectorized argument?
    static constexpr bool vectorize =
        satisfies_any_of<call_type, std::is_arithmetic, is_complex, std::is_pod>::value &&
        satisfies_none_of<call_type, std::is_pointer, std::is_array, is_std_array, std::is_enum>::value &&
        (!std::is_reference<T>::value ||
         (std::is_lvalue_reference<T>::value && std::is_const<call_type>::value));
    // Accept this type: an array for vectorized types, otherwise the type as-is:
    using type = conditional_t<vectorize, array_t<remove_cv_t<call_type>, array::forcecast>, T>;
};

1438
1439
template <typename Func, typename Return, typename... Args>
struct vectorize_helper {
1440
private:
1441
    static constexpr size_t N = sizeof...(Args);
1442
1443
1444
    static constexpr size_t NVectorized = constexpr_sum(vectorize_arg<Args>::vectorize...);
    static_assert(NVectorized >= 1,
            "pybind11::vectorize(...) requires a function with at least one vectorizable argument");
1445

1446
public:
1447
    template <typename T>
1448
    explicit vectorize_helper(T &&f) : f(std::forward<T>(f)) { }
Wenzel Jakob's avatar
Wenzel Jakob committed
1449

1450
1451
1452
1453
1454
    object operator()(typename vectorize_arg<Args>::type... args) {
        return run(args...,
                   make_index_sequence<N>(),
                   select_indices<vectorize_arg<Args>::vectorize...>(),
                   make_index_sequence<NVectorized>());
1455
    }
Wenzel Jakob's avatar
Wenzel Jakob committed
1456

1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
private:
    remove_reference_t<Func> f;

    template <size_t Index> using param_n_t = typename pack_element<Index, typename vectorize_arg<Args>::call_type...>::type;

    // Runs a vectorized function given arguments tuple and three index sequences:
    //     - Index is the full set of 0 ... (N-1) argument indices;
    //     - VIndex is the subset of argument indices with vectorized parameters, letting us access
    //       vectorized arguments (anything not in this sequence is passed through)
    //     - BIndex is a incremental sequence (beginning at 0) of the same size as VIndex, so that
    //       we can store vectorized buffer_infos in an array (argument VIndex has its buffer at
    //       index BIndex in the array).
    template <size_t... Index, size_t... VIndex, size_t... BIndex> object run(
            typename vectorize_arg<Args>::type &...args,
            index_sequence<Index...> i_seq, index_sequence<VIndex...> vi_seq, index_sequence<BIndex...> bi_seq) {

        // Pointers to values the function was called with; the vectorized ones set here will start
        // out as array_t<T> pointers, but they will be changed them to T pointers before we make
        // call the wrapped function.  Non-vectorized pointers are left as-is.
        std::array<void *, N> params{{ &args... }};

        // The array of `buffer_info`s of vectorized arguments:
        std::array<buffer_info, NVectorized> buffers{{ reinterpret_cast<array *>(params[VIndex])->request()... }};
Wenzel Jakob's avatar
Wenzel Jakob committed
1480
1481

        /* Determine dimensions parameters of output array */
1482
1483
1484
1485
        ssize_t nd = 0;
        std::vector<ssize_t> shape(0);
        auto trivial = broadcast(buffers, nd, shape);
        size_t ndim = (size_t) nd;
1486

1487
1488
1489
1490
1491
1492
1493
        size_t size = std::accumulate(shape.begin(), shape.end(), (size_t) 1, std::multiplies<size_t>());

        // If all arguments are 0-dimension arrays (i.e. single values) return a plain value (i.e.
        // not wrapped in an array).
        if (size == 1 && ndim == 0) {
            PYBIND11_EXPAND_SIDE_EFFECTS(params[VIndex] = buffers[BIndex].ptr);
            return cast(f(*reinterpret_cast<param_n_t<Index> *>(params[Index])...));
Wenzel Jakob's avatar
Wenzel Jakob committed
1494
1495
        }

1496
1497
1498
        array_t<Return> result;
        if (trivial == broadcast_trivial::f_trivial) result = array_t<Return, array::f_style>(shape);
        else result = array_t<Return>(shape);
Wenzel Jakob's avatar
Wenzel Jakob committed
1499

1500
        if (size == 0) return result;
1501

1502
        /* Call the function */
1503
1504
1505
1506
        if (trivial == broadcast_trivial::non_trivial)
            apply_broadcast(buffers, params, result, i_seq, vi_seq, bi_seq);
        else
            apply_trivial(buffers, params, result.mutable_data(), size, i_seq, vi_seq, bi_seq);
1507
1508

        return result;
1509
    }
1510

1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
    template <size_t... Index, size_t... VIndex, size_t... BIndex>
    void apply_trivial(std::array<buffer_info, NVectorized> &buffers,
                       std::array<void *, N> &params,
                       Return *out,
                       size_t size,
                       index_sequence<Index...>, index_sequence<VIndex...>, index_sequence<BIndex...>) {

        // Initialize an array of mutable byte references and sizes with references set to the
        // appropriate pointer in `params`; as we iterate, we'll increment each pointer by its size
        // (except for singletons, which get an increment of 0).
        std::array<std::pair<unsigned char *&, const size_t>, NVectorized> vecparams{{
            std::pair<unsigned char *&, const size_t>(
                    reinterpret_cast<unsigned char *&>(params[VIndex] = buffers[BIndex].ptr),
                    buffers[BIndex].size == 1 ? 0 : sizeof(param_n_t<VIndex>)
            )...
        }};

        for (size_t i = 0; i < size; ++i) {
            out[i] = f(*reinterpret_cast<param_n_t<Index> *>(params[Index])...);
            for (auto &x : vecparams) x.first += x.second;
        }
    }

    template <size_t... Index, size_t... VIndex, size_t... BIndex>
    void apply_broadcast(std::array<buffer_info, NVectorized> &buffers,
                         std::array<void *, N> &params,
                         array_t<Return> &output_array,
                         index_sequence<Index...>, index_sequence<VIndex...>, index_sequence<BIndex...>) {
1539

1540
1541
        buffer_info output = output_array.request();
        multi_array_iterator<NVectorized> input_iter(buffers, output.shape);
1542

1543
1544
1545
1546
1547
1548
1549
        for (array_iterator<Return> iter = array_begin<Return>(output), end = array_end<Return>(output);
             iter != end;
             ++iter, ++input_iter) {
            PYBIND11_EXPAND_SIDE_EFFECTS((
                params[VIndex] = input_iter.template data<BIndex>()
            ));
            *iter = f(*reinterpret_cast<param_n_t<Index> *>(std::get<Index>(params))...);
1550
1551
        }
    }
1552
1553
};

1554
1555
1556
1557
1558
1559
// Builds the vectorize_helper for a callable whose signature `Return(Args...)` is supplied as a
// (null) function-pointer tag.  The callable is taken by forwarding reference so an rvalue (e.g.
// a temporary lambda) is moved into the helper rather than copied, which also allows move-only
// callables.  The helper's stored type is the callable with references and cv-qualifiers
// stripped.
template <typename Func, typename Return, typename... Args>
vectorize_helper<remove_cv_t<remove_reference_t<Func>>, Return, Args...>
vectorize_extractor(Func &&f, Return (*) (Args ...)) {
    return detail::vectorize_helper<remove_cv_t<remove_reference_t<Func>>, Return, Args...>(std::forward<Func>(f));
}

// Signature name shown for array_t<T, Flags> arguments/returns: "numpy.ndarray[<element>]",
// where <element> comes from npy_format_descriptor<T>::name().  `Flags` is not reflected in the
// name.
template <typename T, int Flags> struct handle_type_name<array_t<T, Flags>> {
    static PYBIND11_DESCR name() {
        return _("numpy.ndarray[") + npy_format_descriptor<T>::name() + _("]");
    }
};

NAMESPACE_END(detail)

// Vanilla pointer vectorizer:
1569
template <typename Return, typename... Args>
1570
detail::vectorize_helper<Return (*)(Args...), Return, Args...>
1571
vectorize(Return (*f) (Args ...)) {
1572
    return detail::vectorize_helper<Return (*)(Args...), Return, Args...>(f);
Wenzel Jakob's avatar
Wenzel Jakob committed
1573
1574
}

1575
// lambda vectorizer:
1576
template <typename Func, detail::enable_if_t<detail::is_lambda<Func>::value, int> = 0>
1577
auto vectorize(Func &&f) -> decltype(
1578
1579
        detail::vectorize_extractor(std::forward<Func>(f), (detail::function_signature_t<Func> *) nullptr)) {
    return detail::vectorize_extractor(std::forward<Func>(f), (detail::function_signature_t<Func> *) nullptr);
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
}

// Vectorize a non-const member function.  The member pointer is wrapped with std::mem_fn, so the
// resulting helper's first parameter is the `Class *` object pointer (a pointer type, hence passed
// through rather than vectorized), followed by the method's own parameters.
template <typename Return, typename Class, typename... Args,
          typename MemberHelper = detail::vectorize_helper<
              decltype(std::mem_fn(std::declval<Return (Class::*)(Args...)>())),
              Return, Class *, Args...>>
MemberHelper vectorize(Return (Class::*f)(Args...)) {
    auto bound = std::mem_fn(f);
    return MemberHelper(std::move(bound));
}

// Vectorize a class method (const): the member pointer is wrapped with std::mem_fn, so the
// resulting helper's first parameter is a `const Class *` object pointer, followed by the
// method's own parameters.
template <typename Return, typename Class, typename... Args,
          typename Helper = detail::vectorize_helper<decltype(std::mem_fn(std::declval<Return (Class::*)(Args...) const>())), Return, const Class *, Args...>>
Helper vectorize(Return (Class::*f)(Args...) const) {
    return Helper(std::mem_fn(f));
}

NAMESPACE_END(PYBIND11_NAMESPACE)

// Restore the MSVC warning state saved by the `#pragma warning(push)` at the top of this header.
#if defined(_MSC_VER)
#pragma warning(pop)
#endif