/*
    pybind11/numpy.h: Basic NumPy support, vectorize() wrapper

    Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>

    All rights reserved. Use of this source code is governed by a
    BSD-style license that can be found in the LICENSE file.
*/

#pragma once

12
13
#include "pybind11.h"
#include "complex.h"
14
15
#include <numeric>
#include <algorithm>
16
#include <array>
17
#include <cstdint>
18
#include <cstdlib>
19
#include <cstring>
20
#include <sstream>
21
#include <string>
22
#include <functional>
23
#include <type_traits>
24
#include <utility>
25
#include <vector>
26
#include <typeindex>
27

Wenzel Jakob's avatar
Wenzel Jakob committed
28
#ifdef _MSC_VER
// Save the current warning state (so a later warning(pop) can restore it), then
// silence C4127: "conditional expression is constant".
#  pragma warning(push)
#  pragma warning(disable: 4127)
#endif

33
/* This will be true on all flat address space platforms and allows us to reduce the
34
   whole npy_intp / ssize_t / Py_intptr_t business down to just ssize_t for all size
35
36
   and dimension types (e.g. shape, strides, indexing), instead of inflicting this
   upon the library user. */
37
38
39
static_assert(sizeof(::pybind11::ssize_t) == sizeof(Py_intptr_t), "ssize_t != Py_intptr_t");
static_assert(std::is_signed<Py_intptr_t>::value, "Py_intptr_t must be signed");
// We now can reinterpret_cast between py::ssize_t and Py_intptr_t (MSVC + PyPy cares)
40

41
PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
42
43
44

class array; // Forward declaration: detail:: helpers (handle_type_name, unchecked_reference) refer to it before its full definition.

45
PYBIND11_NAMESPACE_BEGIN(detail)
46
47
48

// Python-facing type name for py::array (presumably used in generated signatures).
template <>
struct handle_type_name<array> {
    static constexpr auto name = _("numpy.ndarray");
};

// Maps a C++ type to its NumPy dtype.  Only declared here; dtype::of<T>() relies on
// a specialization that provides a static dtype().
template <typename type, typename SFINAE = void>
struct npy_format_descriptor;
Wenzel Jakob's avatar
Wenzel Jakob committed
50

51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
// Layout mirror of the leading fields of NumPy's PyArray_Descr struct.  Descriptor
// objects are reinterpret_cast to this type (see array_descriptor_proxy) to read
// their fields directly, rather than including NumPy's headers.  Field order and
// types must match NumPy's binary layout exactly; do not reorder or insert fields.
// NOTE(review): confirm against every NumPy ABI version this header is meant to support.
struct PyArrayDescr_Proxy {
    PyObject_HEAD
    PyObject *typeobj;
    char kind;        // dtype kind character (read by dtype::kind())
    char type;        // dtype type character (read by dtype::char_())
    char byteorder;
    char flags;
    int type_num;
    int elsize;       // element size in bytes (read by dtype::itemsize() and array::itemsize())
    int alignment;
    char *subarray;
    PyObject *fields;
    PyObject *names;  // non-null for structured dtypes (tested by dtype::has_fields())
};

// Layout mirror of the leading fields of NumPy's array object.  Array objects are
// reinterpret_cast to this type via array_proxy(); the layout must match NumPy's
// exactly, so do not reorder or insert fields.
struct PyArray_Proxy {
    PyObject_HEAD
    char *data;          // start of the data buffer (array::data(), array::mutable_data())
    int nd;              // number of dimensions (array::ndim())
    ssize_t *dimensions; // per-axis extents (array::shape())
    ssize_t *strides;    // per-axis strides (array::strides())
    PyObject *base;      // base object, may be null (array::base())
    PyObject *descr;     // dtype descriptor (array::dtype(), array::itemsize())
    int flags;           // NPY_ARRAY_* bitmask (array::flags(), check_flags())
};

77
78
79
80
81
82
83
84
// Layout mirror of NumPy's void-scalar object (note PyObject_VAR_HEAD: a variable-size
// Python object).  Field order must match NumPy's binary layout exactly.
// NOTE(review): field meanings are inferred from NumPy naming (obval presumably points
// at the scalar's raw bytes, descr at its dtype) — TODO confirm against NumPy's headers.
struct PyVoidScalarObject_Proxy {
    PyObject_VAR_HEAD
    char *obval;
    PyArrayDescr_Proxy *descr;
    int flags;
    PyObject *base;
};

85
86
87
88
89
90
91
92
// Registry entry stored in numpy_internals::registered_dtypes for one C++ type.
struct numpy_type_info {
    PyObject* dtype_ptr;    // dtype object registered for the type (ownership not shown here)
    std::string format_str; // presumably the buffer-protocol format string — TODO confirm
};

struct numpy_internals {
    std::unordered_map<std::type_index, numpy_type_info> registered_dtypes;

93
94
    numpy_type_info *get_type_info(const std::type_info& tinfo, bool throw_if_missing = true) {
        auto it = registered_dtypes.find(std::type_index(tinfo));
95
96
97
        if (it != registered_dtypes.end())
            return &(it->second);
        if (throw_if_missing)
98
            pybind11_fail(std::string("NumPy type info missing for ") + tinfo.name());
99
100
        return nullptr;
    }
101
102
103
104

    template<typename T> numpy_type_info *get_type_info(bool throw_if_missing = true) {
        return get_type_info(typeid(typename std::remove_cv<T>::type), throw_if_missing);
    }
105
106
};

Ivan Smirnov's avatar
Ivan Smirnov committed
107
108
/// Points `ptr` at the numpy_internals instance kept in pybind11's shared data under
/// the key "_numpy_internals" (obtained via get_or_create_shared_data).  Marked
/// NOINLINE, presumably to keep this one-time path out of get_numpy_internals().
inline PYBIND11_NOINLINE void load_numpy_internals(numpy_internals* &ptr) {
    auto &shared = get_or_create_shared_data<numpy_internals>("_numpy_internals");
    ptr = &shared;
}

/// Returns the shared numpy_internals.  The pointer is resolved on the first call and
/// cached in a function-local static for later calls.
inline numpy_internals& get_numpy_internals() {
    static numpy_internals *cached = nullptr;
    if (cached == nullptr) {
        load_numpy_internals(cached);
    }
    return *cached;
}

118
119
120
121
/// Curried size comparison: `same_size<T>::as<U>` is a bool_constant that is true
/// exactly when sizeof(T) == sizeof(U), so it can be passed where a one-parameter
/// type predicate over U is expected.
template <typename T> struct same_size {
    template <typename U> using as = bool_constant<sizeof(T) == sizeof(U)>;
};

122
123
/// Recursion terminator: no candidate had the size of `Concrete`, so report -1.
template <typename Concrete> constexpr int platform_lookup() { return -1; }

/// Size-based type lookup used to normalize NumPy's fixed-width type numbers.
/// The candidate types `Candidate, Others...` are paired positionally with the
/// integer arguments.  Returns the integer paired with the first candidate whose
/// size equals sizeof(Concrete), or -1 if none matches.
template <typename Concrete, typename Candidate, typename... Others, typename... Codes>
constexpr int platform_lookup(int code, Codes... remaining) {
    return sizeof(Candidate) != sizeof(Concrete)
        ? platform_lookup<Concrete, Others...>(remaining...)
        : code;
}

130
131
struct npy_api {
    enum constants {
132
133
        NPY_ARRAY_C_CONTIGUOUS_ = 0x0001,
        NPY_ARRAY_F_CONTIGUOUS_ = 0x0002,
134
        NPY_ARRAY_OWNDATA_ = 0x0004,
135
        NPY_ARRAY_FORCECAST_ = 0x0010,
136
        NPY_ARRAY_ENSUREARRAY_ = 0x0040,
137
138
        NPY_ARRAY_ALIGNED_ = 0x0100,
        NPY_ARRAY_WRITEABLE_ = 0x0400,
139
140
141
142
143
144
145
146
147
        NPY_BOOL_ = 0,
        NPY_BYTE_, NPY_UBYTE_,
        NPY_SHORT_, NPY_USHORT_,
        NPY_INT_, NPY_UINT_,
        NPY_LONG_, NPY_ULONG_,
        NPY_LONGLONG_, NPY_ULONGLONG_,
        NPY_FLOAT_, NPY_DOUBLE_, NPY_LONGDOUBLE_,
        NPY_CFLOAT_, NPY_CDOUBLE_, NPY_CLONGDOUBLE_,
        NPY_OBJECT_ = 17,
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
        NPY_STRING_, NPY_UNICODE_, NPY_VOID_,
        // Platform-dependent normalization
        NPY_INT8_ = NPY_BYTE_,
        NPY_UINT8_ = NPY_UBYTE_,
        NPY_INT16_ = NPY_SHORT_,
        NPY_UINT16_ = NPY_USHORT_,
        // `npy_common.h` defines the integer aliases. In order, it checks:
        // NPY_BITSOF_LONG, NPY_BITSOF_LONGLONG, NPY_BITSOF_INT, NPY_BITSOF_SHORT, NPY_BITSOF_CHAR
        // and assigns the alias to the first matching size, so we should check in this order.
        NPY_INT32_ = platform_lookup<std::int32_t, long, int, short>(
            NPY_LONG_, NPY_INT_, NPY_SHORT_),
        NPY_UINT32_ = platform_lookup<std::uint32_t, unsigned long, unsigned int, unsigned short>(
            NPY_ULONG_, NPY_UINT_, NPY_USHORT_),
        NPY_INT64_ = platform_lookup<std::int64_t, long, long long, int>(
            NPY_LONG_, NPY_LONGLONG_, NPY_INT_),
        NPY_UINT64_ = platform_lookup<std::uint64_t, unsigned long, unsigned long long, unsigned int>(
            NPY_ULONG_, NPY_ULONGLONG_, NPY_UINT_),
165
166
    };

167
    struct PyArray_Dims {
uentity's avatar
uentity committed
168
169
        Py_intptr_t *ptr;
        int len;
170
    };
uentity's avatar
uentity committed
171

172
173
174
175
176
    static npy_api& get() {
        static npy_api api = lookup();
        return api;
    }

177
178
179
180
181
182
    bool PyArray_Check_(PyObject *obj) const {
        return (bool) PyObject_TypeCheck(obj, PyArray_Type_);
    }
    bool PyArrayDescr_Check_(PyObject *obj) const {
        return (bool) PyObject_TypeCheck(obj, PyArrayDescr_Type_);
    }
183

184
    unsigned int (*PyArray_GetNDArrayCFeatureVersion_)();
185
186
    PyObject *(*PyArray_DescrFromType_)(int);
    PyObject *(*PyArray_NewFromDescr_)
187
188
189
        (PyTypeObject *, PyObject *, int, Py_intptr_t const *,
         Py_intptr_t const *, void *, int, PyObject *);
    // Unused. Not removed because that affects ABI of the class.
190
    PyObject *(*PyArray_DescrNewFromType_)(int);
191
    int (*PyArray_CopyInto_)(PyObject *, PyObject *);
192
193
    PyObject *(*PyArray_NewCopy_)(PyObject *, int);
    PyTypeObject *PyArray_Type_;
194
    PyTypeObject *PyVoidArrType_Type_;
195
    PyTypeObject *PyArrayDescr_Type_;
196
    PyObject *(*PyArray_DescrFromScalar_)(PyObject *);
197
198
199
    PyObject *(*PyArray_FromAny_) (PyObject *, PyObject *, int, int, int, PyObject *);
    int (*PyArray_DescrConverter_) (PyObject *, PyObject **);
    bool (*PyArray_EquivTypes_) (PyObject *, PyObject *);
200
201
    int (*PyArray_GetArrayParamsFromObject_)(PyObject *, PyObject *, unsigned char, PyObject **, int *,
                                             Py_intptr_t *, PyObject **, PyObject *);
202
    PyObject *(*PyArray_Squeeze_)(PyObject *);
203
    // Unused. Not removed because that affects ABI of the class.
Jason Rhinelander's avatar
Jason Rhinelander committed
204
    int (*PyArray_SetBaseObject_)(PyObject *, PyObject *);
uentity's avatar
uentity committed
205
    PyObject* (*PyArray_Resize_)(PyObject*, PyArray_Dims*, int, int);
206
207
private:
    enum functions {
208
        API_PyArray_GetNDArrayCFeatureVersion = 211,
209
        API_PyArray_Type = 2,
210
        API_PyArrayDescr_Type = 3,
211
        API_PyVoidArrType_Type = 39,
212
        API_PyArray_DescrFromType = 45,
213
        API_PyArray_DescrFromScalar = 57,
214
        API_PyArray_FromAny = 69,
uentity's avatar
uentity committed
215
        API_PyArray_Resize = 80,
216
        API_PyArray_CopyInto = 82,
217
218
        API_PyArray_NewCopy = 85,
        API_PyArray_NewFromDescr = 94,
219
        API_PyArray_DescrNewFromType = 96,
220
221
222
        API_PyArray_DescrConverter = 174,
        API_PyArray_EquivTypes = 182,
        API_PyArray_GetArrayParamsFromObject = 278,
Jason Rhinelander's avatar
Jason Rhinelander committed
223
224
        API_PyArray_Squeeze = 136,
        API_PyArray_SetBaseObject = 282
225
226
227
    };

    static npy_api lookup() {
228
        module_ m = module_::import("numpy.core.multiarray");
229
        auto c = m.attr("_ARRAY_API");
230
#if PY_MAJOR_VERSION >= 3
231
        void **api_ptr = (void **) PyCapsule_GetPointer(c.ptr(), NULL);
232
#else
233
        void **api_ptr = (void **) PyCObject_AsVoidPtr(c.ptr());
234
#endif
235
        npy_api api;
236
#define DECL_NPY_API(Func) api.Func##_ = (decltype(api.Func##_)) api_ptr[API_##Func];
237
238
239
        DECL_NPY_API(PyArray_GetNDArrayCFeatureVersion);
        if (api.PyArray_GetNDArrayCFeatureVersion_() < 0x7)
            pybind11_fail("pybind11 numpy support requires numpy >= 1.7.0");
240
        DECL_NPY_API(PyArray_Type);
241
        DECL_NPY_API(PyVoidArrType_Type);
242
        DECL_NPY_API(PyArrayDescr_Type);
243
        DECL_NPY_API(PyArray_DescrFromType);
244
        DECL_NPY_API(PyArray_DescrFromScalar);
245
        DECL_NPY_API(PyArray_FromAny);
uentity's avatar
uentity committed
246
        DECL_NPY_API(PyArray_Resize);
247
        DECL_NPY_API(PyArray_CopyInto);
248
249
250
251
252
253
        DECL_NPY_API(PyArray_NewCopy);
        DECL_NPY_API(PyArray_NewFromDescr);
        DECL_NPY_API(PyArray_DescrNewFromType);
        DECL_NPY_API(PyArray_DescrConverter);
        DECL_NPY_API(PyArray_EquivTypes);
        DECL_NPY_API(PyArray_GetArrayParamsFromObject);
254
        DECL_NPY_API(PyArray_Squeeze);
Jason Rhinelander's avatar
Jason Rhinelander committed
255
        DECL_NPY_API(PyArray_SetBaseObject);
256
#undef DECL_NPY_API
257
258
259
        return api;
    }
};
Wenzel Jakob's avatar
Wenzel Jakob committed
260

261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
/// Views an ndarray object through the PyArray_Proxy layout (mutable access).
inline PyArray_Proxy* array_proxy(void* ptr) {
    return static_cast<PyArray_Proxy*>(ptr);
}

/// Read-only counterpart of array_proxy(void*).
inline const PyArray_Proxy* array_proxy(const void* ptr) {
    return static_cast<const PyArray_Proxy*>(ptr);
}

/// Views a dtype (descriptor) object through the PyArrayDescr_Proxy layout.
inline PyArrayDescr_Proxy* array_descriptor_proxy(PyObject* ptr) {
    return static_cast<PyArrayDescr_Proxy*>(static_cast<void*>(ptr));
}

/// Read-only counterpart of array_descriptor_proxy(PyObject*).
inline const PyArrayDescr_Proxy* array_descriptor_proxy(const PyObject* ptr) {
    return static_cast<const PyArrayDescr_Proxy*>(static_cast<const void*>(ptr));
}

/// True when every bit of `flag` is set in the flags of the array at `ptr`.
inline bool check_flags(const void* ptr, int flag) {
    const int present = array_proxy(ptr)->flags & flag;
    return present == flag;
}

281
282
283
284
285
template <typename T> struct is_std_array : std::false_type { };
template <typename T, size_t N> struct is_std_array<std::array<T, N>> : std::true_type { };
template <typename T> struct is_complex : std::false_type { };
template <typename T> struct is_complex<std::complex<T>> : std::true_type { };

286
template <typename T> struct array_info_scalar {
287
    using type = T;
288
289
    static constexpr bool is_array = false;
    static constexpr bool is_empty = false;
290
    static constexpr auto extents = _("");
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
    static void append_extents(list& /* shape */) { }
};
// Computes underlying type and a comma-separated list of extents for array
// types (any mix of std::array and built-in arrays). An array of char is
// treated as scalar because it gets special handling.
template <typename T> struct array_info : array_info_scalar<T> { };
template <typename T, size_t N> struct array_info<std::array<T, N>> {
    using type = typename array_info<T>::type;
    static constexpr bool is_array = true;
    static constexpr bool is_empty = (N == 0) || array_info<T>::is_empty;
    static constexpr size_t extent = N;

    // appends the extents to shape
    static void append_extents(list& shape) {
        shape.append(N);
        array_info<T>::append_extents(shape);
    }

309
310
311
    static constexpr auto extents = _<array_info<T>::is_array>(
        concat(_<N>(), array_info<T>::extents), _<N>()
    );
312
313
314
315
316
317
318
319
};
// For numpy we have special handling for arrays of characters, so we don't include
// the size in the array extents.
template <size_t N> struct array_info<char[N]> : array_info_scalar<char[N]> { };
template <size_t N> struct array_info<std::array<char, N>> : array_info_scalar<std::array<char, N>> { };
template <typename T, size_t N> struct array_info<T[N]> : array_info<std::array<T, N>> { };
template <typename T> using remove_all_extents_t = typename array_info<T>::type;

320
/// True for struct types whose memory can be accessed directly as a NumPy record:
/// standard layout, trivially copyable, and not a reference, array, std::array,
/// arithmetic, complex or enum type.
template <typename T> using is_pod_struct = all_of<
    // Direct in-memory access requires a standard-layout type.
    std::is_standard_layout<T>,
#if defined(__GLIBCXX__) && (__GLIBCXX__ < 20150422 || __GLIBCXX__ == 20150623 || __GLIBCXX__ == 20150626 || __GLIBCXX__ == 20160803)
    // libstdc++ < 5 (including 4.8.5, 4.9.3 and 4.9.4, which were released after 5)
    // lacks std::is_trivially_copyable, so approximate it with the older traits.
    std::is_trivially_destructible<T>,
    satisfies_any_of<T, std::has_trivial_copy_constructor, std::has_trivial_copy_assign>,
#else
    std::is_trivially_copyable<T>,
#endif
    satisfies_none_of<T, std::is_reference, std::is_array, is_std_array, std::is_arithmetic, is_complex, std::is_enum>
>;

333
334
335
336
337
338
/// Drop-in replacement for std::is_pod, which is deprecated in C++20:
/// a POD type is one that is both standard-layout and trivial.
template <typename T>
using is_pod = all_of<std::is_standard_layout<T>, std::is_trivial<T>>;

339
340
341
342
/// Byte offset of the element addressed by the given indices: the sum of
/// index_k * strides[Dim + k] over the supplied indices.  With no indices left the
/// offset is 0.  No bounds or dimensionality checking is performed.
template <ssize_t Dim = 0, typename Strides> ssize_t byte_offset_unsafe(const Strides &) { return 0; }
template <ssize_t Dim = 0, typename Strides, typename... Ix>
ssize_t byte_offset_unsafe(const Strides &strides, ssize_t first, Ix... rest) {
    const ssize_t head = first * strides[Dim];
    return head + byte_offset_unsafe<Dim + 1>(strides, rest...);
}

345
346
/**
 * Proxy class providing unsafe, unchecked const access to array data.  This is constructed through
347
348
 * the `unchecked<T, N>()` method of `array` or the `unchecked<N>()` method of `array_t<T>`.  `Dims`
 * will be -1 for dimensions determined at runtime.
349
 */
350
template <typename T, ssize_t Dims>
351
352
class unchecked_reference {
protected:
353
    static constexpr bool Dynamic = Dims < 0;
354
355
    const unsigned char *data_;
    // Storing the shape & strides in local variables (i.e. these arrays) allows the compiler to
356
    // make large performance gains on big, nested loops, but requires compile-time dimensions
357
358
359
    conditional_t<Dynamic, const ssize_t *, std::array<ssize_t, (size_t) Dims>>
            shape_, strides_;
    const ssize_t dims_;
360
361

    friend class pybind11::array;
362
363
    // Constructor for compile-time dimensions:
    template <bool Dyn = Dynamic>
364
    unchecked_reference(const void *data, const ssize_t *shape, const ssize_t *strides, enable_if_t<!Dyn, ssize_t>)
365
    : data_{reinterpret_cast<const unsigned char *>(data)}, dims_{Dims} {
366
        for (size_t i = 0; i < (size_t) dims_; i++) {
367
368
369
370
            shape_[i] = shape[i];
            strides_[i] = strides[i];
        }
    }
371
372
    // Constructor for runtime dimensions:
    template <bool Dyn = Dynamic>
373
    unchecked_reference(const void *data, const ssize_t *shape, const ssize_t *strides, enable_if_t<Dyn, ssize_t> dims)
374
    : data_{reinterpret_cast<const unsigned char *>(data)}, shape_{shape}, strides_{strides}, dims_{dims} {}
375
376

public:
377
378
    /**
     * Unchecked const reference access to data at the given indices.  For a compile-time known
379
380
     * number of dimensions, this requires the correct number of arguments; for run-time
     * dimensionality, this is not checked (and so is up to the caller to use safely).
381
     */
382
    template <typename... Ix> const T &operator()(Ix... index) const {
Jason Rhinelander's avatar
Jason Rhinelander committed
383
        static_assert(ssize_t{sizeof...(Ix)} == Dims || Dynamic,
384
                "Invalid number of indices for unchecked array reference");
385
        return *reinterpret_cast<const T *>(data_ + byte_offset_unsafe(strides_, ssize_t(index)...));
386
    }
387
388
    /**
     * Unchecked const reference access to data; this operator only participates if the reference
389
390
     * is to a 1-dimensional array.  When present, this is exactly equivalent to `obj(index)`.
     */
391
392
    template <ssize_t D = Dims, typename = enable_if_t<D == 1 || Dynamic>>
    const T &operator[](ssize_t index) const { return operator()(index); }
393

394
    /// Pointer access to the data at the given indices.
395
    template <typename... Ix> const T *data(Ix... ix) const { return &operator()(ssize_t(ix)...); }
396
397

    /// Returns the item size, i.e. sizeof(T)
398
    constexpr static ssize_t itemsize() { return sizeof(T); }
399

400
    /// Returns the shape (i.e. size) of dimension `dim`
401
    ssize_t shape(ssize_t dim) const { return shape_[(size_t) dim]; }
402
403

    /// Returns the number of dimensions of the array
404
    ssize_t ndim() const { return dims_; }
405
406
407

    /// Returns the total number of elements in the referenced array, i.e. the product of the shapes
    template <bool Dyn = Dynamic>
408
409
    enable_if_t<!Dyn, ssize_t> size() const {
        return std::accumulate(shape_.begin(), shape_.end(), (ssize_t) 1, std::multiplies<ssize_t>());
410
411
    }
    template <bool Dyn = Dynamic>
412
413
    enable_if_t<Dyn, ssize_t> size() const {
        return std::accumulate(shape_, shape_ + ndim(), (ssize_t) 1, std::multiplies<ssize_t>());
414
415
416
417
    }

    /// Returns the total number of bytes used by the referenced data.  Note that the actual span in
    /// memory may be larger if the referenced array has non-contiguous strides (e.g. for a slice).
418
    ssize_t nbytes() const {
419
420
        return size() * itemsize();
    }
421
422
};

423
template <typename T, ssize_t Dims>
424
425
426
427
class unchecked_mutable_reference : public unchecked_reference<T, Dims> {
    friend class pybind11::array;
    using ConstBase = unchecked_reference<T, Dims>;
    using ConstBase::ConstBase;
428
    using ConstBase::Dynamic;
429
public:
430
431
432
433
    // Bring in const-qualified versions from base class
    using ConstBase::operator();
    using ConstBase::operator[];

434
435
    /// Mutable, unchecked access to data at the given indices.
    template <typename... Ix> T& operator()(Ix... index) {
Jason Rhinelander's avatar
Jason Rhinelander committed
436
        static_assert(ssize_t{sizeof...(Ix)} == Dims || Dynamic,
437
                "Invalid number of indices for unchecked array reference");
438
439
        return const_cast<T &>(ConstBase::operator()(index...));
    }
440
441
    /**
     * Mutable, unchecked access data at the given index; this operator only participates if the
442
443
     * reference is to a 1-dimensional array (or has runtime dimensions).  When present, this is
     * exactly equivalent to `obj(index)`.
444
     */
445
446
    template <ssize_t D = Dims, typename = enable_if_t<D == 1 || Dynamic>>
    T &operator[](ssize_t index) { return operator()(index); }
447
448

    /// Mutable pointer access to the data at the given indices.
449
    template <typename... Ix> T *mutable_data(Ix... ix) { return &operator()(ssize_t(ix)...); }
450
451
};

452
template <typename T, ssize_t Dim>
453
struct type_caster<unchecked_reference<T, Dim>> {
454
    static_assert(Dim == 0 && Dim > 0 /* always fail */, "unchecked array proxy object is not castable");
455
};
456
template <typename T, ssize_t Dim>
457
458
struct type_caster<unchecked_mutable_reference<T, Dim>> : type_caster<unchecked_reference<T, Dim>> {};

459
PYBIND11_NAMESPACE_END(detail)
460

461
class dtype : public object {
462
public:
463
    PYBIND11_OBJECT_DEFAULT(dtype, object, detail::npy_api::get().PyArrayDescr_Check_);
Wenzel Jakob's avatar
Wenzel Jakob committed
464

465
    explicit dtype(const buffer_info &info) {
466
        dtype descr(_dtype_from_pep3118()(PYBIND11_STR_TYPE(info.format)));
467
468
        // If info.itemsize == 0, use the value calculated from the format string
        m_ptr = descr.strip_padding(info.itemsize ? info.itemsize : descr.itemsize()).release().ptr();
469
    }
470

471
    explicit dtype(const std::string &format) {
472
        m_ptr = from_args(pybind11::str(format)).release().ptr();
Wenzel Jakob's avatar
Wenzel Jakob committed
473
474
    }

475
    dtype(const char *format) : dtype(std::string(format)) { }
476

477
    dtype(list names, list formats, list offsets, ssize_t itemsize) {
478
479
480
481
        dict args;
        args["names"] = names;
        args["formats"] = formats;
        args["offsets"] = offsets;
482
        args["itemsize"] = pybind11::int_(itemsize);
483
484
485
        m_ptr = from_args(args).release().ptr();
    }

Ivan Smirnov's avatar
Ivan Smirnov committed
486
    /// This is essentially the same as calling numpy.dtype(args) in Python.
487
488
    static dtype from_args(object args) {
        PyObject *ptr = nullptr;
Jason Rhinelander's avatar
Jason Rhinelander committed
489
        if (!detail::npy_api::get().PyArray_DescrConverter_(args.ptr(), &ptr) || !ptr)
490
            throw error_already_set();
491
        return reinterpret_steal<dtype>(ptr);
492
    }
493

Ivan Smirnov's avatar
Ivan Smirnov committed
494
    /// Return dtype associated with a C++ type.
495
    template <typename T> static dtype of() {
496
        return detail::npy_format_descriptor<typename std::remove_cv<T>::type>::dtype();
497
    }
498

Ivan Smirnov's avatar
Ivan Smirnov committed
499
    /// Size of the data type in bytes.
500
501
    ssize_t itemsize() const {
        return detail::array_descriptor_proxy(m_ptr)->elsize;
Wenzel Jakob's avatar
Wenzel Jakob committed
502
503
    }

Ivan Smirnov's avatar
Ivan Smirnov committed
504
    /// Returns true for structured data types.
505
    bool has_fields() const {
506
        return detail::array_descriptor_proxy(m_ptr)->names != nullptr;
507
508
    }

Bertrand MICHEL's avatar
Bertrand MICHEL committed
509
510
    /// Single-character code for dtype's kind.
    /// For example, floating point types are 'f' and integral types are 'i'.
511
    char kind() const {
512
        return detail::array_descriptor_proxy(m_ptr)->kind;
513
514
    }

Bertrand MICHEL's avatar
Bertrand MICHEL committed
515
516
517
518
519
520
521
522
523
    /// Single-character for dtype's type.
    /// For example, ``float`` is 'f', ``double`` 'd', ``int`` 'i', and ``long`` 'd'.
    char char_() const {
        // Note: The signature, `dtype::char_` follows the naming of NumPy's
        // public Python API (i.e., ``dtype.char``), rather than its internal
        // C API (``PyArray_Descr::type``).
        return detail::array_descriptor_proxy(m_ptr)->type;
    }

524
private:
525
    static object _dtype_from_pep3118() {
526
        static PyObject *obj = module_::import("numpy.core._internal")
527
            .attr("_dtype_from_pep3118").cast<object>().release().ptr();
528
        return reinterpret_borrow<object>(obj);
529
    }
530

531
    dtype strip_padding(ssize_t itemsize) {
532
533
        // Recursively strip all void fields with empty names that are generated for
        // padding fields (as of NumPy v1.11).
534
        if (!has_fields())
535
            return *this;
536

537
        struct field_descr { PYBIND11_STR_TYPE name; object format; pybind11::int_ offset; };
538
539
        std::vector<field_descr> field_descriptors;

540
        for (auto field : attr("fields").attr("items")()) {
541
            auto spec = field.cast<tuple>();
542
            auto name = spec[0].cast<pybind11::str>();
543
            auto format = spec[1].cast<tuple>()[0].cast<dtype>();
544
            auto offset = spec[1].cast<tuple>()[1].cast<pybind11::int_>();
545
            if (!len(name) && format.kind() == 'V')
546
                continue;
547
            field_descriptors.push_back({(PYBIND11_STR_TYPE) name, format.strip_padding(format.itemsize()), offset});
548
549
550
551
        }

        std::sort(field_descriptors.begin(), field_descriptors.end(),
                  [](const field_descr& a, const field_descr& b) {
552
                      return a.offset.cast<int>() < b.offset.cast<int>();
553
554
555
556
                  });

        list names, formats, offsets;
        for (auto& descr : field_descriptors) {
557
558
559
            names.append(descr.name);
            formats.append(descr.format);
            offsets.append(descr.offset);
560
        }
561
        return dtype(names, formats, offsets, itemsize);
562
563
    }
};
564

565
566
class array : public buffer {
public:
567
    PYBIND11_OBJECT_CVT(array, buffer, detail::npy_api::get().PyArray_Check_, raw_array)
568
569

    enum {
570
571
        c_style = detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_,
        f_style = detail::npy_api::NPY_ARRAY_F_CONTIGUOUS_,
572
573
574
        forcecast = detail::npy_api::NPY_ARRAY_FORCECAST_
    };

575
    array() : array(0, static_cast<const double *>(nullptr)) {}
576

577
578
    using ShapeContainer = detail::any_container<ssize_t>;
    using StridesContainer = detail::any_container<ssize_t>;
579
580
581
582
583
584

    /// Constructs an array of dtype `dt`, taking shape/strides from arbitrary container
    /// types.  Empty `strides` means C-contiguous strides computed from `shape`.
    ///
    /// - `ptr == nullptr`: NumPy allocates the data.
    /// - `ptr` and `base`: the array views `ptr` and keeps `base` alive as its base
    ///   object.  If `base` is an array, its flags (minus OWNDATA) are copied;
    ///   otherwise the view is writeable.
    /// - `ptr` without `base`: the data at `ptr` is copied into a new array.
    ///
    /// Throws error_already_set if NumPy fails to create the array, attach the base
    /// object, or make the copy.
    array(const pybind11::dtype &dt, ShapeContainer shape, StridesContainer strides,
          const void *ptr = nullptr, handle base = handle()) {

        if (strides->empty())
            *strides = detail::c_strides(*shape, dt.itemsize());

        auto ndim = shape->size();
        if (ndim != strides->size())
            pybind11_fail("NumPy: shape ndim doesn't match strides ndim");
        auto descr = dt;

        int flags = 0;
        if (base && ptr) {
            if (isinstance<array>(base))
                /* Copy flags from base (except ownership bit) */
                flags = reinterpret_borrow<array>(base).flags() & ~detail::npy_api::NPY_ARRAY_OWNDATA_;
            else
                /* Writable by default, easy to downgrade later on if needed */
                flags = detail::npy_api::NPY_ARRAY_WRITEABLE_;
        }

        auto &api = detail::npy_api::get();
        auto tmp = reinterpret_steal<object>(api.PyArray_NewFromDescr_(
            api.PyArray_Type_, descr.release().ptr(), (int) ndim,
            // Use reinterpret_cast for PyPy on Windows (remove if fixed, checked on 7.3.1)
            reinterpret_cast<Py_intptr_t*>(shape->data()),
            reinterpret_cast<Py_intptr_t*>(strides->data()),
            const_cast<void *>(ptr), flags, nullptr));
        if (!tmp)
            throw error_already_set();
        if (ptr) {
            if (base) {
                // PyArray_SetBaseObject steals the new reference to `base`; it returns
                // a negative value (with a Python error set) on failure.
                if (api.PyArray_SetBaseObject_(tmp.ptr(), base.inc_ref().ptr()) < 0)
                    throw error_already_set();
            } else {
                tmp = reinterpret_steal<object>(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */));
                // A failed copy (e.g. out of memory) returns null with an error set;
                // do not silently produce an array with a null handle.
                if (!tmp)
                    throw error_already_set();
            }
        }
        m_ptr = tmp.release().ptr();
    }

621
622
    array(const pybind11::dtype &dt, ShapeContainer shape, const void *ptr = nullptr, handle base = handle())
        : array(dt, std::move(shape), {}, ptr, base) { }
623

624
625
626
    template <typename T, typename = detail::enable_if_t<std::is_integral<T>::value && !std::is_same<bool, T>::value>>
    array(const pybind11::dtype &dt, T count, const void *ptr = nullptr, handle base = handle())
        : array(dt, {{count}}, ptr, base) { }
627

628
629
630
    template <typename T>
    array(ShapeContainer shape, StridesContainer strides, const T *ptr, handle base = handle())
        : array(pybind11::dtype::of<T>(), std::move(shape), std::move(strides), ptr, base) { }
631

632
    template <typename T>
633
634
    array(ShapeContainer shape, const T *ptr, handle base = handle())
        : array(std::move(shape), {}, ptr, base) { }
635

636
    template <typename T>
637
    explicit array(ssize_t count, const T *ptr, handle base = handle()) : array({count}, {}, ptr, base) { }
638

639
640
    explicit array(const buffer_info &info, handle base = handle())
    : array(pybind11::dtype(info), info.shape, info.strides, info.ptr, base) { }
641

642
643
    /// Array descriptor (dtype)
    pybind11::dtype dtype() const {
644
        return reinterpret_borrow<pybind11::dtype>(detail::array_proxy(m_ptr)->descr);
645
646
647
    }

    /// Total number of elements
648
649
    ssize_t size() const {
        return std::accumulate(shape(), shape() + ndim(), (ssize_t) 1, std::multiplies<ssize_t>());
650
651
652
    }

    /// Byte size of a single element
653
654
    ssize_t itemsize() const {
        return detail::array_descriptor_proxy(detail::array_proxy(m_ptr)->descr)->elsize;
655
656
657
    }

    /// Total number of bytes
658
    ssize_t nbytes() const {
659
660
661
662
        return size() * itemsize();
    }

    /// Number of dimensions
663
664
    ssize_t ndim() const {
        return detail::array_proxy(m_ptr)->nd;
665
666
    }

667
668
    /// Base object
    object base() const {
669
        return reinterpret_borrow<object>(detail::array_proxy(m_ptr)->base);
670
671
    }

672
    /// Dimensions of the array
673
674
    const ssize_t* shape() const {
        return detail::array_proxy(m_ptr)->dimensions;
675
676
677
    }

    /// Dimension along a given axis
678
    ssize_t shape(ssize_t dim) const {
679
        if (dim >= ndim())
680
            fail_dim_check(dim, "invalid axis");
681
682
683
684
        return shape()[dim];
    }

    /// Strides of the array
685
    const ssize_t* strides() const {
686
        return detail::array_proxy(m_ptr)->strides;
687
688
689
    }

    /// Stride along a given axis
690
    ssize_t strides(ssize_t dim) const {
691
        if (dim >= ndim())
692
            fail_dim_check(dim, "invalid axis");
693
694
695
        return strides()[dim];
    }

696
697
    /// Return the NumPy array flags
    int flags() const {
698
        return detail::array_proxy(m_ptr)->flags;
699
700
    }

701
702
    /// If set, the array is writeable (otherwise the buffer is read-only)
    bool writeable() const {
703
        return detail::check_flags(m_ptr, detail::npy_api::NPY_ARRAY_WRITEABLE_);
704
705
706
707
    }

    /// If set, the array owns the data (will be freed when the array is deleted)
    bool owndata() const {
708
        return detail::check_flags(m_ptr, detail::npy_api::NPY_ARRAY_OWNDATA_);
709
710
    }

711
712
    /// Pointer to the contained data. If index is not provided, points to the
    /// beginning of the buffer. May throw if the index would lead to out of bounds access.
713
    template<typename... Ix> const void* data(Ix... index) const {
714
        return static_cast<const void *>(detail::array_proxy(m_ptr)->data + offset_at(index...));
715
716
    }

717
718
719
    /// Mutable pointer to the contained data. If index is not provided, points to the
    /// beginning of the buffer. May throw if the index would lead to out of bounds access.
    /// May throw if the array is not writeable.
720
    template<typename... Ix> void* mutable_data(Ix... index) {
721
        check_writeable();
722
        return static_cast<void *>(detail::array_proxy(m_ptr)->data + offset_at(index...));
723
724
725
726
    }

    /// Byte offset from beginning of the array to a given index (full or partial).
    /// May throw if the index would lead to out of bounds access.
727
    template<typename... Ix> ssize_t offset_at(Ix... index) const {
728
        if ((ssize_t) sizeof...(index) > ndim())
729
            fail_dim_check(sizeof...(index), "too many indices for an array");
730
        return byte_offset(ssize_t(index)...);
731
732
    }

733
    ssize_t offset_at() const { return 0; }
734
735
736

    /// Item count from beginning of the array to a given index (full or partial).
    /// May throw if the index would lead to out of bounds access.
737
    template<typename... Ix> ssize_t index_at(Ix... index) const {
738
        return offset_at(index...) / itemsize();
739
740
    }

741
742
    /**
     * Returns a proxy object that provides access to the array's data without bounds or
743
744
745
746
     * dimensionality checking.  Will throw if the array is missing the `writeable` flag.  Use with
     * care: the array must not be destroyed or reshaped for the duration of the returned object,
     * and the caller must take care not to access invalid dimensions or dimension indices.
     */
747
    template <typename T, ssize_t Dims = -1> detail::unchecked_mutable_reference<T, Dims> mutable_unchecked() & {
748
        if (Dims >= 0 && ndim() != Dims)
749
750
            throw std::domain_error("array has incorrect number of dimensions: " + std::to_string(ndim()) +
                    "; expected " + std::to_string(Dims));
751
        return detail::unchecked_mutable_reference<T, Dims>(mutable_data(), shape(), strides(), ndim());
752
753
    }

754
755
    /**
     * Returns a proxy object that provides const access to the array's data without bounds or
756
757
758
759
760
     * dimensionality checking.  Unlike `mutable_unchecked()`, this does not require that the
     * underlying array have the `writable` flag.  Use with care: the array must not be destroyed or
     * reshaped for the duration of the returned object, and the caller must take care not to access
     * invalid dimensions or dimension indices.
     */
761
    template <typename T, ssize_t Dims = -1> detail::unchecked_reference<T, Dims> unchecked() const & {
762
        if (Dims >= 0 && ndim() != Dims)
763
764
            throw std::domain_error("array has incorrect number of dimensions: " + std::to_string(ndim()) +
                    "; expected " + std::to_string(Dims));
765
        return detail::unchecked_reference<T, Dims>(data(), shape(), strides(), ndim());
766
767
    }

768
769
770
    /// Return a new view with all of the dimensions of length 1 removed (via PyArray_Squeeze).
    array squeeze() {
        PyObject *squeezed = detail::npy_api::get().PyArray_Squeeze_(m_ptr);
        return reinterpret_steal<array>(squeezed);
    }

    /// Resize the array to `new_shape` (via NumPy's PyArray_Resize).
    /// If `refcheck` is true and more than one reference to this array exists, the resize
    /// succeeds only if it is a reshape, i.e. the total size doesn't change.
    /// Throws `error_already_set` if NumPy reports an error.
    void resize(ShapeContainer new_shape, bool refcheck = true) {
        // Use reinterpret_cast for PyPy on Windows (remove if fixed, checked on 7.3.1)
        auto *dims_data = reinterpret_cast<Py_intptr_t*>(new_shape->data());
        const int dims_count = int(new_shape->size());
        detail::npy_api::PyArray_Dims d = { dims_data, dims_count };

        // The ordering argument is passed as -1 because it is not used anyway.
        auto result = reinterpret_steal<object>(
            detail::npy_api::get().PyArray_Resize_(m_ptr, &d, int(refcheck), -1));
        if (!result)
            throw error_already_set();
        // Rebind only when NumPy handed back an array object.
        if (isinstance<array>(result))
            *this = std::move(result);
    }

    /// Ensure that the argument is a NumPy array
792
793
794
795
796
797
    /// In case of an error, nullptr is returned and the Python error is cleared.
    static array ensure(handle h, int ExtraFlags = 0) {
        auto result = reinterpret_steal<array>(raw_array(h.ptr(), ExtraFlags));
        if (!result)
            PyErr_Clear();
        return result;
798
799
    }

800
protected:
801
802
    template<typename, typename> friend struct detail::npy_format_descriptor;

803
    void fail_dim_check(ssize_t dim, const std::string& msg) const {
804
805
806
807
        throw index_error(msg + ": " + std::to_string(dim) +
                          " (ndim = " + std::to_string(ndim()) + ")");
    }

808
    /// Bounds-check `index...` against shape(), then combine it with strides() into a byte offset.
    template<typename... Ix> ssize_t byte_offset(Ix... index) const {
        check_dimensions(index...);
        const ssize_t *axis_strides = strides();
        return detail::byte_offset_unsafe(axis_strides, ssize_t(index)...);
    }

    /// Throw `std::domain_error` unless the array has the writeable flag.
    void check_writeable() const {
        if (writeable())
            return;
        throw std::domain_error("array is not writeable");
    }

    template<typename... Ix> void check_dimensions(Ix... index) const {
819
        check_dimensions_impl(ssize_t(0), shape(), ssize_t(index)...);
820
821
    }

822
    void check_dimensions_impl(ssize_t, const ssize_t*) const { }
823

824
    template<typename... Ix> void check_dimensions_impl(ssize_t axis, const ssize_t* shape, ssize_t i, Ix... index) const {
825
826
827
828
829
830
831
        if (i >= *shape) {
            throw index_error(std::string("index ") + std::to_string(i) +
                              " is out of bounds for axis " + std::to_string(axis) +
                              " with size " + std::to_string(*shape));
        }
        check_dimensions_impl(axis + 1, shape + 1, index...);
    }
832
833
834

    /// Create an array from any object via PyArray_FromAny. Returns a new reference, or
    /// nullptr with a Python error set (including when `ptr` itself is null).
    static PyObject *raw_array(PyObject *ptr, int ExtraFlags = 0) {
        if (!ptr) {
            PyErr_SetString(PyExc_ValueError, "cannot create a pybind11::array from a nullptr");
            return nullptr;
        }
        const int requirements = detail::npy_api::NPY_ARRAY_ENSUREARRAY_ | ExtraFlags;
        return detail::npy_api::get().PyArray_FromAny_(ptr, nullptr, 0, 0, requirements, nullptr);
    }
};

844
template <typename T, int ExtraFlags = array::forcecast> class array_t : public array {
845
846
847
848
849
private:
    struct private_ctor {};
    // Delegating constructor needed when both moving and accessing in the same constructor
    array_t(private_ctor, ShapeContainer &&shape, StridesContainer &&strides, const T *ptr, handle base)
        : array(std::move(shape), std::move(strides), ptr, base) {}
Wenzel Jakob's avatar
Wenzel Jakob committed
850
public:
851
852
    static_assert(!detail::array_info<T>::is_array, "Array types cannot be used with array_t");

853
854
    using value_type = T;

855
    array_t() : array(0, static_cast<const T *>(nullptr)) {}
856
857
    array_t(handle h, borrowed_t) : array(h, borrowed_t{}) { }
    array_t(handle h, stolen_t) : array(h, stolen_t{}) { }
858

859
    PYBIND11_DEPRECATED("Use array_t<T>::ensure() instead")
860
    array_t(handle h, bool is_borrowed) : array(raw_array_t(h.ptr()), stolen_t{}) {
861
862
863
        if (!m_ptr) PyErr_Clear();
        if (!is_borrowed) Py_XDECREF(h.ptr());
    }
864

865
    array_t(const object &o) : array(raw_array_t(o.ptr()), stolen_t{}) {
866
867
        if (!m_ptr) throw error_already_set();
    }
868

869
    explicit array_t(const buffer_info& info, handle base = handle()) : array(info, base) { }
870

871
872
    array_t(ShapeContainer shape, StridesContainer strides, const T *ptr = nullptr, handle base = handle())
        : array(std::move(shape), std::move(strides), ptr, base) { }
873

874
    explicit array_t(ShapeContainer shape, const T *ptr = nullptr, handle base = handle())
875
        : array_t(private_ctor{}, std::move(shape),
876
877
878
                ExtraFlags & f_style
                ? detail::f_strides(*shape, itemsize())
                : detail::c_strides(*shape, itemsize()),
879
                ptr, base) { }
880

881
    explicit array_t(ssize_t count, const T *ptr = nullptr, handle base = handle())
882
883
        : array({count}, {}, ptr, base) { }

884
    constexpr ssize_t itemsize() const {
885
        return sizeof(T);
886
887
    }

888
    template<typename... Ix> ssize_t index_at(Ix... index) const {
889
        return offset_at(index...) / itemsize();
890
891
    }

892
    template<typename... Ix> const T* data(Ix... index) const {
893
894
895
        return static_cast<const T*>(array::data(index...));
    }

896
    template<typename... Ix> T* mutable_data(Ix... index) {
897
898
899
900
        return static_cast<T*>(array::mutable_data(index...));
    }

    // Reference to element at a given index
901
    template<typename... Ix> const T& at(Ix... index) const {
902
        if ((ssize_t) sizeof...(index) != ndim())
903
            fail_dim_check(sizeof...(index), "index dimension mismatch");
904
        return *(static_cast<const T*>(array::data()) + byte_offset(ssize_t(index)...) / itemsize());
905
906
907
    }

    // Mutable reference to element at a given index
908
    template<typename... Ix> T& mutable_at(Ix... index) {
909
        if ((ssize_t) sizeof...(index) != ndim())
910
            fail_dim_check(sizeof...(index), "index dimension mismatch");
911
        return *(static_cast<T*>(array::mutable_data()) + byte_offset(ssize_t(index)...) / itemsize());
912
    }
913

914
915
    /**
     * Returns a proxy object that provides access to the array's data without bounds or
916
917
918
919
     * dimensionality checking.  Will throw if the array is missing the `writeable` flag.  Use with
     * care: the array must not be destroyed or reshaped for the duration of the returned object,
     * and the caller must take care not to access invalid dimensions or dimension indices.
     */
920
    template <ssize_t Dims = -1> detail::unchecked_mutable_reference<T, Dims> mutable_unchecked() & {
921
922
923
        return array::mutable_unchecked<T, Dims>();
    }

924
925
    /**
     * Returns a proxy object that provides const access to the array's data without bounds or
926
927
928
929
930
     * dimensionality checking.  Unlike `unchecked()`, this does not require that the underlying
     * array have the `writable` flag.  Use with care: the array must not be destroyed or reshaped
     * for the duration of the returned object, and the caller must take care not to access invalid
     * dimensions or dimension indices.
     */
931
    template <ssize_t Dims = -1> detail::unchecked_reference<T, Dims> unchecked() const & {
932
933
934
        return array::unchecked<T, Dims>();
    }

Jason Rhinelander's avatar
Jason Rhinelander committed
935
936
    /// Ensure that the argument is a NumPy array of the correct dtype (and if not, try to convert
    /// it).  In case of an error, nullptr is returned and the Python error is cleared.
937
938
    static array_t ensure(handle h) {
        auto result = reinterpret_steal<array_t>(raw_array_t(h.ptr()));
939
940
        if (!result)
            PyErr_Clear();
941
        return result;
Wenzel Jakob's avatar
Wenzel Jakob committed
942
    }
943

Wenzel Jakob's avatar
Wenzel Jakob committed
944
    static bool check_(handle h) {
945
946
        const auto &api = detail::npy_api::get();
        return api.PyArray_Check_(h.ptr())
947
948
               && api.PyArray_EquivTypes_(detail::array_proxy(h.ptr())->descr, dtype::of<T>().ptr())
               && detail::check_flags(h.ptr(), ExtraFlags & (array::c_style | array::f_style));
949
950
951
952
953
    }

protected:
    /// Create array from any object -- always returns a new reference
    static PyObject *raw_array_t(PyObject *ptr) {
954
955
        if (ptr == nullptr) {
            PyErr_SetString(PyExc_ValueError, "cannot create a pybind11::array_t from a nullptr");
956
            return nullptr;
957
        }
958
959
        return detail::npy_api::get().PyArray_FromAny_(
            ptr, dtype::of<T>().release().ptr(), 0, 0,
960
            detail::npy_api::NPY_ARRAY_ENSUREARRAY_ | ExtraFlags, nullptr);
961
    }
Wenzel Jakob's avatar
Wenzel Jakob committed
962
963
};

964
template <typename T>
965
struct format_descriptor<T, detail::enable_if_t<detail::is_pod_struct<T>::value>> {
966
967
968
    static std::string format() {
        return detail::npy_format_descriptor<typename std::remove_cv<T>::type>::format();
    }
969
970
971
};

/// Fixed-size character buffers use the "<N>s" format code.
template <size_t N> struct format_descriptor<char[N]> {
    static std::string format() {
        std::string code = std::to_string(N);
        code += 's';
        return code;
    }
};
template <size_t N> struct format_descriptor<std::array<char, N>> {
    static std::string format() {
        return std::to_string(N).append("s");
    }
};

/// Enums use the format string of their (cv-stripped) underlying integer type.
template <typename T>
struct format_descriptor<T, detail::enable_if_t<std::is_enum<T>::value>> {
    static std::string format() {
        using underlying = typename std::underlying_type<T>::type;
        using plain_underlying = typename std::remove_cv<underlying>::type;
        return format_descriptor<plain_underlying>::format();
    }
};

template <typename T>
struct format_descriptor<T, detail::enable_if_t<detail::array_info<T>::is_array>> {
    static std::string format() {
989
990
991
        using namespace detail;
        static constexpr auto extents = _("(") + array_info<T>::extents + _(")");
        return extents.text + format_descriptor<remove_all_extents_t<T>>::format();
992
993
994
    }
};

995
PYBIND11_NAMESPACE_BEGIN(detail)
996
997
998
999
template <typename T, int ExtraFlags>
struct pyobject_caster<array_t<T, ExtraFlags>> {
    using type = array_t<T, ExtraFlags>;

1000
1001
1002
    bool load(handle src, bool convert) {
        if (!convert && !type::check_(src))
            return false;
1003
        value = type::ensure(src);
1004
1005
1006
1007
1008
1009
        return static_cast<bool>(value);
    }

    static handle cast(const handle &src, return_value_policy /* policy */, handle /* parent */) {
        return src.inc_ref();
    }
1010
    PYBIND11_TYPE_CASTER(type, handle_type_name<type>::name);
1011
1012
};

1013
1014
1015
1016
1017
1018
1019
/// A buffer matches POD struct `T` when NumPy considers the buffer's dtype equivalent to T's.
template <typename T>
struct compare_buffer_info<T, detail::enable_if_t<detail::is_pod_struct<T>::value>> {
    static bool compare(const buffer_info& b) {
        const auto expected = dtype::of<T>();
        const auto actual = dtype(b);
        return npy_api::get().PyArray_EquivTypes_(expected.ptr(), actual.ptr());
    }
};

/// Compile-time NumPy type name used in signatures, specialized per numeric category below.
template <typename T, typename = void>
struct npy_format_descriptor_name;

// Integers: "bool" for bool, otherwise "numpy.int<bits>" / "numpy.uint<bits>" by signedness.
template <typename T>
struct npy_format_descriptor_name<T, enable_if_t<std::is_integral<T>::value>> {
    static constexpr auto name = _<std::is_same<T, bool>::value>(
        _("bool"),
        _<std::is_signed<T>::value>("numpy.int", "numpy.uint") + _<sizeof(T) * 8>());
};

// Floating point: (const) float/double are "numpy.float<bits>", anything else "numpy.longdouble".
template <typename T>
struct npy_format_descriptor_name<T, enable_if_t<std::is_floating_point<T>::value>> {
    static constexpr auto name =
        _<std::is_same<typename std::remove_const<T>::type, float>::value
          || std::is_same<typename std::remove_const<T>::type, double>::value>(
            _("numpy.float") + _<sizeof(T) * 8>(),
            _("numpy.longdouble"));
};

// Complex: (const) float/double components give "numpy.complex<bits>", else "numpy.longcomplex".
template <typename T>
struct npy_format_descriptor_name<T, enable_if_t<is_complex<T>::value>> {
    static constexpr auto name =
        _<std::is_same<typename std::remove_const<typename T::value_type>::type, float>::value
          || std::is_same<typename std::remove_const<typename T::value_type>::type, double>::value>(
            _("numpy.complex") + _<sizeof(typename T::value_type) * 16>(),
            _("numpy.longcomplex"));
};

/// dtype descriptor for arithmetic and complex types: maps `T` to its builtin NumPy type number.
template <typename T>
struct npy_format_descriptor<T, enable_if_t<satisfies_any_of<T, std::is_arithmetic, is_complex>::value>>
    : npy_format_descriptor_name<T> {
private:
    // NB: the order here must match the one in common.h
    constexpr static const int values[15] = {
        npy_api::NPY_BOOL_,
        npy_api::NPY_BYTE_,   npy_api::NPY_UBYTE_,   npy_api::NPY_INT16_,    npy_api::NPY_UINT16_,
        npy_api::NPY_INT32_,  npy_api::NPY_UINT32_,  npy_api::NPY_INT64_,    npy_api::NPY_UINT64_,
        npy_api::NPY_FLOAT_,  npy_api::NPY_DOUBLE_,  npy_api::NPY_LONGDOUBLE_,
        npy_api::NPY_CFLOAT_, npy_api::NPY_CDOUBLE_, npy_api::NPY_CLONGDOUBLE_
    };

public:
    static constexpr int value = values[detail::is_fmt_numeric<T>::index];

    /// Builtin dtype for `value`; fails if NumPy cannot produce a descriptor for it.
    static pybind11::dtype dtype() {
        PyObject *descr = npy_api::get().PyArray_DescrFromType_(value);
        if (descr == nullptr)
            pybind11_fail("Unsupported buffer format!");
        return reinterpret_steal<pybind11::dtype>(descr);
    }
};

#define PYBIND11_DECL_CHAR_FMT \
1074
    static constexpr auto name = _("S") + _<N>(); \
1075
    static pybind11::dtype dtype() { return pybind11::dtype(std::string("S") + std::to_string(N)); }
1076
1077
1078
template <size_t N> struct npy_format_descriptor<char[N]> { PYBIND11_DECL_CHAR_FMT };
template <size_t N> struct npy_format_descriptor<std::array<char, N>> { PYBIND11_DECL_CHAR_FMT };
#undef PYBIND11_DECL_CHAR_FMT
1079

1080
1081
1082
1083
1084
1085
/// dtype for C-style array types: built from the element dtype plus the list of extents.
template<typename T> struct npy_format_descriptor<T, enable_if_t<array_info<T>::is_array>> {
private:
    using base_descr = npy_format_descriptor<typename array_info<T>::type>;
public:
    static_assert(!array_info<T>::is_empty, "Zero-sized arrays are not supported");

    static constexpr auto name = _("(") + array_info<T>::extents + _(")") + base_descr::name;

    static pybind11::dtype dtype() {
        list extents_list;
        array_info<T>::append_extents(extents_list);
        auto spec = pybind11::make_tuple(base_descr::dtype(), extents_list);
        return pybind11::dtype::from_args(spec);
    }
};

/// Enums reuse the dtype and signature name of their underlying integer type.
template<typename T> struct npy_format_descriptor<T, enable_if_t<std::is_enum<T>::value>> {
private:
    using base_descr = npy_format_descriptor<typename std::underlying_type<T>::type>;
public:
    static constexpr auto name = base_descr::name;
    static pybind11::dtype dtype() {
        return base_descr::dtype();
    }
};

struct field_descriptor {
    const char *name;
1104
1105
    ssize_t offset;
    ssize_t size;
1106
    std::string format;
1107
    dtype descr;
1108
1109
};

1110
/// Register a structured NumPy dtype for the C++ type identified by `tinfo`.
///
/// The fields are sorted by byte offset and combined into a dtype built from
/// (names, formats, offsets, itemsize). A matching buffer format string is generated by hand
/// (see below) and checked against NumPy. On success the dtype and format string are stored in
/// the NumPy internals registry, and `direct_converter` is appended to pybind11's direct
/// conversions for the type.
///
/// Fails via `pybind11_fail` if the type is already registered, if any field has no dtype, or
/// if NumPy does not parse the generated format string into an equivalent dtype.
inline PYBIND11_NOINLINE void register_structured_dtype(
    any_container<field_descriptor> fields,
    const std::type_info& tinfo, ssize_t itemsize,
    bool (*direct_converter)(PyObject *, void *&)) {

    auto& numpy_internals = get_numpy_internals();
    if (numpy_internals.get_type_info(tinfo, false))
        pybind11_fail("NumPy: dtype is already registered");

    // Use ordered fields because order matters as of NumPy 1.14:
    // https://docs.scipy.org/doc/numpy/release.html#multiple-field-indexing-assignment-of-structured-arrays
    std::vector<field_descriptor> ordered_fields(std::move(fields));
    std::sort(ordered_fields.begin(), ordered_fields.end(),
        [](const field_descriptor &a, const field_descriptor &b) { return a.offset < b.offset; });

    list names, formats, offsets;
    for (auto& field : ordered_fields) {
        if (!field.descr)
            pybind11_fail(std::string("NumPy: unsupported field dtype: `") +
                            field.name + "` @ " + tinfo.name());
        names.append(PYBIND11_STR_TYPE(field.name));
        formats.append(field.descr);
        offsets.append(pybind11::int_(field.offset));
    }
    // Keep the new dtype owned by a pybind11 object until registration succeeds, so the
    // reference is not leaked if one of the checks below throws.
    auto dtype_obj = pybind11::dtype(names, formats, offsets, itemsize);

    // There is an existing bug in NumPy (as of v1.11): trailing bytes are
    // not encoded explicitly into the format string. This will supposedly
    // get fixed in v1.12; for further details, see these:
    // - https://github.com/numpy/numpy/issues/7797
    // - https://github.com/numpy/numpy/pull/7798
    // Because of this, we won't use numpy's logic to generate buffer format
    // strings and will just do it ourselves.
    ssize_t offset = 0;
    std::ostringstream oss;
    // mark the structure as unaligned with '^', because numpy and C++ don't
    // always agree about alignment (particularly for complex), and we're
    // explicitly listing all our padding. This depends on none of the fields
    // overriding the endianness. Putting the ^ in front of individual fields
    // isn't guaranteed to work due to https://github.com/numpy/numpy/issues/9049
    oss << "^T{";
    for (auto& field : ordered_fields) {
        if (field.offset > offset)
            oss << (field.offset - offset) << 'x';  // explicit padding before this field
        oss << field.format << ':' << field.name << ':';
        offset = field.offset + field.size;
    }
    if (itemsize > offset)
        oss << (itemsize - offset) << 'x';  // trailing padding up to sizeof(T)
    oss << '}';
    auto format_str = oss.str();

    // Sanity check: verify that NumPy properly parses our buffer format string
    auto& api = npy_api::get();
    auto arr =  array(buffer_info(nullptr, itemsize, format_str, 1));
    if (!api.PyArray_EquivTypes_(dtype_obj.ptr(), arr.dtype().ptr()))
        pybind11_fail("NumPy: invalid buffer descriptor!");

    auto tindex = std::type_index(tinfo);
    // The registry takes ownership of the dtype reference from here on.
    numpy_internals.registered_dtypes[tindex] = { dtype_obj.release().ptr(), std::move(format_str) };
    get_internals().direct_conversions[tindex].push_back(direct_converter);
}

template <typename T, typename SFINAE> struct npy_format_descriptor {
    static_assert(is_pod_struct<T>::value, "Attempt to use a non-POD or unimplemented POD type as a numpy dtype");

1176
    static constexpr auto name = make_caster<T>::name;
1177

1178
    static pybind11::dtype dtype() {
1179
        return reinterpret_borrow<pybind11::dtype>(dtype_ptr());
1180
1181
    }

1182
    static std::string format() {
1183
        static auto format_str = get_numpy_internals().get_type_info<T>(true)->format_str;
1184
        return format_str;
1185
1186
    }

1187
1188
    static void register_dtype(any_container<field_descriptor> fields) {
        register_structured_dtype(std::move(fields), typeid(typename std::remove_cv<T>::type),
1189
                                  sizeof(T), &direct_converter);
1190
1191
1192
    }

private:
1193
1194
1195
1196
    static PyObject* dtype_ptr() {
        static PyObject* ptr = get_numpy_internals().get_type_info<T>(true)->dtype_ptr;
        return ptr;
    }
1197

1198
1199
1200
    static bool direct_converter(PyObject *obj, void*& value) {
        auto& api = npy_api::get();
        if (!PyObject_TypeCheck(obj, api.PyVoidArrType_Type_))
1201
            return false;
1202
        if (auto descr = reinterpret_steal<object>(api.PyArray_DescrFromScalar_(obj))) {
1203
            if (api.PyArray_EquivTypes_(dtype_ptr(), descr.ptr())) {
1204
1205
1206
1207
1208
1209
                value = ((PyVoidScalarObject_Proxy *) obj)->obval;
                return true;
            }
        }
        return false;
    }
1210
1211
};

1212
1213
1214
1215
1216
#ifdef __CLION_IDE__ // replace heavy macro with dummy code for the IDE (doesn't affect code)
# define PYBIND11_NUMPY_DTYPE(Type, ...) ((void)0)
# define PYBIND11_NUMPY_DTYPE_EX(Type, ...) ((void)0)
#else

// Build a field_descriptor for member `Field` of `T`, exposed to NumPy under `Name`:
// byte offset (offsetof), size (sizeof), format string and dtype of the member's type.
#define PYBIND11_FIELD_DESCRIPTOR_EX(T, Field, Name)                                          \
    ::pybind11::detail::field_descriptor {                                                    \
        Name, offsetof(T, Field), sizeof(decltype(std::declval<T>().Field)),                  \
        ::pybind11::format_descriptor<decltype(std::declval<T>().Field)>::format(),           \
        ::pybind11::detail::npy_format_descriptor<decltype(std::declval<T>().Field)>::dtype() \
    }

// Extract name, offset and format descriptor for a struct field
#define PYBIND11_FIELD_DESCRIPTOR(T, Field) PYBIND11_FIELD_DESCRIPTOR_EX(T, Field, #Field)

// The main idea of this macro is borrowed from https://github.com/swansontec/map-macro
// (C) William Swanson, Paul Fultz
#define PYBIND11_EVAL0(...) __VA_ARGS__
#define PYBIND11_EVAL1(...) PYBIND11_EVAL0 (PYBIND11_EVAL0 (PYBIND11_EVAL0 (__VA_ARGS__)))
#define PYBIND11_EVAL2(...) PYBIND11_EVAL1 (PYBIND11_EVAL1 (PYBIND11_EVAL1 (__VA_ARGS__)))
#define PYBIND11_EVAL3(...) PYBIND11_EVAL2 (PYBIND11_EVAL2 (PYBIND11_EVAL2 (__VA_ARGS__)))
#define PYBIND11_EVAL4(...) PYBIND11_EVAL3 (PYBIND11_EVAL3 (PYBIND11_EVAL3 (__VA_ARGS__)))
#define PYBIND11_EVAL(...)  PYBIND11_EVAL4 (PYBIND11_EVAL4 (PYBIND11_EVAL4 (__VA_ARGS__)))
#define PYBIND11_MAP_END(...)
#define PYBIND11_MAP_OUT
#define PYBIND11_MAP_COMMA ,
#define PYBIND11_MAP_GET_END() 0, PYBIND11_MAP_END
#define PYBIND11_MAP_NEXT0(test, next, ...) next PYBIND11_MAP_OUT
#define PYBIND11_MAP_NEXT1(test, next) PYBIND11_MAP_NEXT0 (test, next, 0)
#define PYBIND11_MAP_NEXT(test, next)  PYBIND11_MAP_NEXT1 (PYBIND11_MAP_GET_END test, next)
#if defined(_MSC_VER) && !defined(__clang__) // MSVC is not as eager to expand macros, hence this workaround
#define PYBIND11_MAP_LIST_NEXT1(test, next) \
    PYBIND11_EVAL0 (PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0))
#else
#define PYBIND11_MAP_LIST_NEXT1(test, next) \
    PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0)
#endif
#define PYBIND11_MAP_LIST_NEXT(test, next) \
    PYBIND11_MAP_LIST_NEXT1 (PYBIND11_MAP_GET_END test, next)
#define PYBIND11_MAP_LIST0(f, t, x, peek, ...) \
    f(t, x) PYBIND11_MAP_LIST_NEXT (peek, PYBIND11_MAP_LIST1) (f, t, peek, __VA_ARGS__)
#define PYBIND11_MAP_LIST1(f, t, x, peek, ...) \
    f(t, x) PYBIND11_MAP_LIST_NEXT (peek, PYBIND11_MAP_LIST0) (f, t, peek, __VA_ARGS__)
// PYBIND11_MAP_LIST(f, t, a1, a2, ...) expands to f(t, a1), f(t, a2), ...
#define PYBIND11_MAP_LIST(f, t, ...) \
    PYBIND11_EVAL (PYBIND11_MAP_LIST1 (f, t, __VA_ARGS__, (), 0))

// Register `Type` as a structured dtype; each remaining argument names a member field.
#define PYBIND11_NUMPY_DTYPE(Type, ...) \
    ::pybind11::detail::npy_format_descriptor<Type>::register_dtype \
        (::std::vector<::pybind11::detail::field_descriptor> \
         {PYBIND11_MAP_LIST (PYBIND11_FIELD_DESCRIPTOR, Type, __VA_ARGS__)})

#if defined(_MSC_VER) && !defined(__clang__)
#define PYBIND11_MAP2_LIST_NEXT1(test, next) \
    PYBIND11_EVAL0 (PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0))
#else
#define PYBIND11_MAP2_LIST_NEXT1(test, next) \
    PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0)
#endif
#define PYBIND11_MAP2_LIST_NEXT(test, next) \
    PYBIND11_MAP2_LIST_NEXT1 (PYBIND11_MAP_GET_END test, next)
#define PYBIND11_MAP2_LIST0(f, t, x1, x2, peek, ...) \
    f(t, x1, x2) PYBIND11_MAP2_LIST_NEXT (peek, PYBIND11_MAP2_LIST1) (f, t, peek, __VA_ARGS__)
#define PYBIND11_MAP2_LIST1(f, t, x1, x2, peek, ...) \
    f(t, x1, x2) PYBIND11_MAP2_LIST_NEXT (peek, PYBIND11_MAP2_LIST0) (f, t, peek, __VA_ARGS__)
// PYBIND11_MAP2_LIST(f, t, a1, a2, ...) expands to f(t, a1, a2), f(t, a3, a4), ...
#define PYBIND11_MAP2_LIST(f, t, ...) \
    PYBIND11_EVAL (PYBIND11_MAP2_LIST1 (f, t, __VA_ARGS__, (), 0))

// Like PYBIND11_NUMPY_DTYPE, but the arguments come in (member, "exposed name") pairs.
#define PYBIND11_NUMPY_DTYPE_EX(Type, ...) \
    ::pybind11::detail::npy_format_descriptor<Type>::register_dtype \
        (::std::vector<::pybind11::detail::field_descriptor> \
         {PYBIND11_MAP2_LIST (PYBIND11_FIELD_DESCRIPTOR_EX, Type, __VA_ARGS__)})

#endif // __CLION_IDE__

/// Byte pointer into one buffer that moves in lock-step with a multi-dimensional index.
/// `increment(dim)` adds a precomputed step for `dim` that already rewinds every faster-varying
/// (higher-numbered) axis from its last index back to 0.
class common_iterator {
public:
    using container_type = std::vector<ssize_t>;
    using value_type = container_type::value_type;
    using size_type = container_type::size_type;

    common_iterator() : p_ptr(nullptr), m_strides() {}

    common_iterator(void* ptr, const container_type& strides, const container_type& shape)
        : p_ptr(reinterpret_cast<char*>(ptr)), m_strides(strides.size()) {
        // A 0-d buffer has nothing to step through; back() on an empty vector would be UB.
        if (m_strides.empty())
            return;
        m_strides.back() = static_cast<value_type>(strides.back());
        for (size_type i = m_strides.size() - 1; i != 0; --i) {
            size_type j = i - 1;
            auto s = static_cast<value_type>(shape[i]);
            // One stride along j, minus rewinding axis i (and, via m_strides[i], the faster axes).
            m_strides[j] = strides[j] + m_strides[i] - strides[i] * s;
        }
    }

    /// Advance the pointer by the precomputed step for axis `dim`.
    void increment(size_type dim) {
        p_ptr += m_strides[dim];
    }

    /// Current position in the buffer.
    void* data() const {
        return p_ptr;
    }

private:
    char* p_ptr;
    container_type m_strides;
};

template <size_t N> class multi_array_iterator {
1320
public:
1321
    using container_type = std::vector<ssize_t>;
1322

1323
    multi_array_iterator(const std::array<buffer_info, N> &buffers,
1324
                         const container_type &shape)
1325
1326
1327
        : m_shape(shape.size()), m_index(shape.size(), 0),
          m_common_iterator() {

1328
        // Manual copy to avoid conversion warning if using std::copy
1329
        for (size_t i = 0; i < shape.size(); ++i)
1330
            m_shape[i] = shape[i];
1331

1332
        container_type strides(shape.size());
1333
        for (size_t i = 0; i < N; ++i)
1334
1335
1336
1337
1338
1339
1340
1341
1342
            init_common_iterator(buffers[i], shape, m_common_iterator[i], strides);
    }

    multi_array_iterator& operator++() {
        for (size_t j = m_index.size(); j != 0; --j) {
            size_t i = j - 1;
            if (++m_index[i] != m_shape[i]) {
                increment_common_iterator(i);
                break;
1343
            } else {
1344
1345
1346
1347
1348
1349
                m_index[i] = 0;
            }
        }
        return *this;
    }

1350
1351
    template <size_t K, class T = void> T* data() const {
        return reinterpret_cast<T*>(m_common_iterator[K].data());
1352
1353
1354
1355
1356
1357
    }

private:

    using common_iter = common_iterator;

1358
    void init_common_iterator(const buffer_info &buffer,
1359
1360
1361
                              const container_type &shape,
                              common_iter &iterator,
                              container_type &strides) {
1362
1363
1364
1365
1366
1367
1368
        auto buffer_shape_iter = buffer.shape.rbegin();
        auto buffer_strides_iter = buffer.strides.rbegin();
        auto shape_iter = shape.rbegin();
        auto strides_iter = strides.rbegin();

        while (buffer_shape_iter != buffer.shape.rend()) {
            if (*shape_iter == *buffer_shape_iter)
1369
                *strides_iter = *buffer_strides_iter;
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
            else
                *strides_iter = 0;

            ++buffer_shape_iter;
            ++buffer_strides_iter;
            ++shape_iter;
            ++strides_iter;
        }

        std::fill(strides_iter, strides.rend(), 0);
        iterator = common_iter(buffer.ptr, strides, shape);
    }

    void increment_common_iterator(size_t dim) {
1384
        for (auto &iter : m_common_iterator)
1385
1386
1387
1388
1389
1390
1391
1392
            iter.increment(dim);
    }

    container_type m_shape;
    container_type m_index;
    std::array<common_iter, N> m_common_iterator;
};

1393
1394
1395
1396
1397
1398
// Result of a broadcast analysis.  `non_trivial` needs the general multi-dimensional walk.
// `c_trivial` / `f_trivial` mean every operand is a singleton or a full-size buffer packed in
// C (row-major) / Fortran (column-major) order.
enum class broadcast_trivial { non_trivial = 0, c_trivial = 1, f_trivial = 2 };

// Populates the shape and number of dimensions for the set of buffers.  Returns a broadcast_trivial
// enum value indicating whether the broadcast is "trivial"--that is, has each buffer being either a
// singleton or a full-size, C-contiguous (`c_trivial`) or Fortran-contiguous (`f_trivial`) storage
// buffer; returns `non_trivial` otherwise.
1399
template <size_t N>
1400
broadcast_trivial broadcast(const std::array<buffer_info, N> &buffers, ssize_t &ndim, std::vector<ssize_t> &shape) {
1401
    ndim = std::accumulate(buffers.begin(), buffers.end(), ssize_t(0), [](ssize_t res, const buffer_info &buf) {
1402
1403
1404
        return std::max(res, buf.ndim);
    });

1405
    shape.clear();
1406
    shape.resize((size_t) ndim, 1);
1407

1408
1409
    // Figure out the output size, and make sure all input arrays conform (i.e. are either size 1 or
    // the full size).
1410
1411
    for (size_t i = 0; i < N; ++i) {
        auto res_iter = shape.rbegin();
1412
1413
        auto end = buffers[i].shape.rend();
        for (auto shape_iter = buffers[i].shape.rbegin(); shape_iter != end; ++shape_iter, ++res_iter) {
1414
1415
1416
1417
1418
1419
1420
            const auto &dim_size_in = *shape_iter;
            auto &dim_size_out = *res_iter;

            // Each input dimension can either be 1 or `n`, but `n` values must match across buffers
            if (dim_size_out == 1)
                dim_size_out = dim_size_in;
            else if (dim_size_in != 1 && dim_size_in != dim_size_out)
1421
                pybind11_fail("pybind11::vectorize: incompatible size/dimension of inputs!");
1422
1423
        }
    }
1424

1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
    bool trivial_broadcast_c = true;
    bool trivial_broadcast_f = true;
    for (size_t i = 0; i < N && (trivial_broadcast_c || trivial_broadcast_f); ++i) {
        if (buffers[i].size == 1)
            continue;

        // Require the same number of dimensions:
        if (buffers[i].ndim != ndim)
            return broadcast_trivial::non_trivial;

        // Require all dimensions be full-size:
        if (!std::equal(buffers[i].shape.cbegin(), buffers[i].shape.cend(), shape.cbegin()))
            return broadcast_trivial::non_trivial;

        // Check for C contiguity (but only if previous inputs were also C contiguous)
        if (trivial_broadcast_c) {
1441
            ssize_t expect_stride = buffers[i].itemsize;
1442
            auto end = buffers[i].shape.crend();
1443
1444
            for (auto shape_iter = buffers[i].shape.crbegin(), stride_iter = buffers[i].strides.crbegin();
                    trivial_broadcast_c && shape_iter != end; ++shape_iter, ++stride_iter) {
1445
1446
1447
1448
                if (expect_stride == *stride_iter)
                    expect_stride *= *shape_iter;
                else
                    trivial_broadcast_c = false;
1449
            }
1450
        }
1451

1452
1453
        // Check for Fortran contiguity (if previous inputs were also F contiguous)
        if (trivial_broadcast_f) {
1454
            ssize_t expect_stride = buffers[i].itemsize;
1455
            auto end = buffers[i].shape.cend();
1456
1457
            for (auto shape_iter = buffers[i].shape.cbegin(), stride_iter = buffers[i].strides.cbegin();
                    trivial_broadcast_f && shape_iter != end; ++shape_iter, ++stride_iter) {
1458
1459
1460
1461
1462
                if (expect_stride == *stride_iter)
                    expect_stride *= *shape_iter;
                else
                    trivial_broadcast_f = false;
            }
1463
1464
        }
    }
1465
1466
1467
1468
1469

    return
        trivial_broadcast_c ? broadcast_trivial::c_trivial :
        trivial_broadcast_f ? broadcast_trivial::f_trivial :
        broadcast_trivial::non_trivial;
1470
1471
}

1472
1473
1474
1475
1476
1477
1478
template <typename T>
struct vectorize_arg {
    static_assert(!std::is_rvalue_reference<T>::value, "Functions with rvalue reference arguments cannot be vectorized");
    // The wrapped function gets called with this type:
    using call_type = remove_reference_t<T>;
    // Is this a vectorized argument?
    static constexpr bool vectorize =
1479
        satisfies_any_of<call_type, std::is_arithmetic, is_complex, is_pod>::value &&
1480
1481
1482
1483
1484
1485
1486
        satisfies_none_of<call_type, std::is_pointer, std::is_array, is_std_array, std::is_enum>::value &&
        (!std::is_reference<T>::value ||
         (std::is_lvalue_reference<T>::value && std::is_const<call_type>::value));
    // Accept this type: an array for vectorized types, otherwise the type as-is:
    using type = conditional_t<vectorize, array_t<remove_cv_t<call_type>, array::forcecast>, T>;
};

1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536

// py::vectorize when a return type is present
template <typename Func, typename Return, typename... Args>
struct vectorize_returned_array {
    using Type = array_t<Return>;

    static Type create(broadcast_trivial trivial, const std::vector<ssize_t> &shape) {
        if (trivial == broadcast_trivial::f_trivial)
            return array_t<Return, array::f_style>(shape);
        else
            return array_t<Return>(shape);
    }

    static Return *mutable_data(Type &array) {
        return array.mutable_data();
    }

    static Return call(Func &f, Args &... args) {
        return f(args...);
    }

    static void call(Return *out, size_t i, Func &f, Args &... args) {
        out[i] = f(args...);
    }
};

// py::vectorize when a return type is not present
// No output buffer exists: the result object is `none`, and f is run only for its side effects.
template <typename Func, typename... Args>
struct vectorize_returned_array<Func, void, Args...> {
    using Type = none;

    static Type create(broadcast_trivial, const std::vector<ssize_t> &) { return none(); }

    // Nothing to write into.
    static void *mutable_data(Type &) { return nullptr; }

    // Scalar case: run f, then return an empty void_type placeholder (presumably converted to
    // None by the caller's cast).
    static detail::void_type call(Func &f, Args &... args) {
        f(args...);
        return detail::void_type{};
    }

    // Element case: run f; the output pointer and index are ignored.
    static void call(void *, size_t, Func &f, Args &... args) { f(args...); }
};


1537
1538
template <typename Func, typename Return, typename... Args>
struct vectorize_helper {
1539
1540
1541
1542
1543

// NVCC for some reason breaks if NVectorized is private
#ifdef __CUDACC__
public:
#else
1544
private:
1545
1546
#endif

1547
    static constexpr size_t N = sizeof...(Args);
1548
1549
1550
    static constexpr size_t NVectorized = constexpr_sum(vectorize_arg<Args>::vectorize...);
    static_assert(NVectorized >= 1,
            "pybind11::vectorize(...) requires a function with at least one vectorizable argument");
1551

1552
public:
1553
    template <typename T>
1554
    explicit vectorize_helper(T &&f) : f(std::forward<T>(f)) { }
Wenzel Jakob's avatar
Wenzel Jakob committed
1555

1556
1557
1558
1559
1560
    object operator()(typename vectorize_arg<Args>::type... args) {
        return run(args...,
                   make_index_sequence<N>(),
                   select_indices<vectorize_arg<Args>::vectorize...>(),
                   make_index_sequence<NVectorized>());
1561
    }
Wenzel Jakob's avatar
Wenzel Jakob committed
1562

1563
1564
1565
private:
    remove_reference_t<Func> f;

1566
1567
1568
1569
    // Internal compiler error in MSVC 19.16.27025.1 (Visual Studio 2017 15.9.4), when compiling with "/permissive-" flag
    // when arg_call_types is manually inlined.
    using arg_call_types = std::tuple<typename vectorize_arg<Args>::call_type...>;
    template <size_t Index> using param_n_t = typename std::tuple_element<Index, arg_call_types>::type;
1570

1571
1572
    using returned_array = vectorize_returned_array<Func, Return, Args...>;

1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
    // Runs a vectorized function given arguments tuple and three index sequences:
    //     - Index is the full set of 0 ... (N-1) argument indices;
    //     - VIndex is the subset of argument indices with vectorized parameters, letting us access
    //       vectorized arguments (anything not in this sequence is passed through)
    //     - BIndex is a incremental sequence (beginning at 0) of the same size as VIndex, so that
    //       we can store vectorized buffer_infos in an array (argument VIndex has its buffer at
    //       index BIndex in the array).
    template <size_t... Index, size_t... VIndex, size_t... BIndex> object run(
            typename vectorize_arg<Args>::type &...args,
            index_sequence<Index...> i_seq, index_sequence<VIndex...> vi_seq, index_sequence<BIndex...> bi_seq) {

        // Pointers to values the function was called with; the vectorized ones set here will start
        // out as array_t<T> pointers, but they will be changed them to T pointers before we make
        // call the wrapped function.  Non-vectorized pointers are left as-is.
        std::array<void *, N> params{{ &args... }};

        // The array of `buffer_info`s of vectorized arguments:
        std::array<buffer_info, NVectorized> buffers{{ reinterpret_cast<array *>(params[VIndex])->request()... }};
Wenzel Jakob's avatar
Wenzel Jakob committed
1591
1592

        /* Determine dimensions parameters of output array */
1593
1594
1595
        ssize_t nd = 0;
        std::vector<ssize_t> shape(0);
        auto trivial = broadcast(buffers, nd, shape);
1596
        auto ndim = (size_t) nd;
1597

1598
1599
1600
1601
1602
1603
        size_t size = std::accumulate(shape.begin(), shape.end(), (size_t) 1, std::multiplies<size_t>());

        // If all arguments are 0-dimension arrays (i.e. single values) return a plain value (i.e.
        // not wrapped in an array).
        if (size == 1 && ndim == 0) {
            PYBIND11_EXPAND_SIDE_EFFECTS(params[VIndex] = buffers[BIndex].ptr);
1604
            return cast(returned_array::call(f, *reinterpret_cast<param_n_t<Index> *>(params[Index])...));
Wenzel Jakob's avatar
Wenzel Jakob committed
1605
1606
        }

1607
        auto result = returned_array::create(trivial, shape);
Wenzel Jakob's avatar
Wenzel Jakob committed
1608

Henry Schreiner's avatar
Henry Schreiner committed
1609
        if (size == 0) return std::move(result);
1610

1611
        /* Call the function */
1612
        auto mutable_data = returned_array::mutable_data(result);
1613
        if (trivial == broadcast_trivial::non_trivial)
1614
            apply_broadcast(buffers, params, mutable_data, size, shape, i_seq, vi_seq, bi_seq);
1615
        else
1616
            apply_trivial(buffers, params, mutable_data, size, i_seq, vi_seq, bi_seq);
1617

Henry Schreiner's avatar
Henry Schreiner committed
1618
        return std::move(result);
1619
    }
1620

1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
    template <size_t... Index, size_t... VIndex, size_t... BIndex>
    void apply_trivial(std::array<buffer_info, NVectorized> &buffers,
                       std::array<void *, N> &params,
                       Return *out,
                       size_t size,
                       index_sequence<Index...>, index_sequence<VIndex...>, index_sequence<BIndex...>) {

        // Initialize an array of mutable byte references and sizes with references set to the
        // appropriate pointer in `params`; as we iterate, we'll increment each pointer by its size
        // (except for singletons, which get an increment of 0).
        std::array<std::pair<unsigned char *&, const size_t>, NVectorized> vecparams{{
            std::pair<unsigned char *&, const size_t>(
                    reinterpret_cast<unsigned char *&>(params[VIndex] = buffers[BIndex].ptr),
                    buffers[BIndex].size == 1 ? 0 : sizeof(param_n_t<VIndex>)
            )...
        }};

        for (size_t i = 0; i < size; ++i) {
1639
            returned_array::call(out, i, f, *reinterpret_cast<param_n_t<Index> *>(params[Index])...);
1640
1641
1642
1643
1644
1645
1646
            for (auto &x : vecparams) x.first += x.second;
        }
    }

    template <size_t... Index, size_t... VIndex, size_t... BIndex>
    void apply_broadcast(std::array<buffer_info, NVectorized> &buffers,
                         std::array<void *, N> &params,
1647
1648
1649
                         Return *out,
                         size_t size,
                         const std::vector<ssize_t> &output_shape,
1650
                         index_sequence<Index...>, index_sequence<VIndex...>, index_sequence<BIndex...>) {
1651

1652
        multi_array_iterator<NVectorized> input_iter(buffers, output_shape);
1653

1654
        for (size_t i = 0; i < size; ++i, ++input_iter) {
1655
1656
1657
            PYBIND11_EXPAND_SIDE_EFFECTS((
                params[VIndex] = input_iter.template data<BIndex>()
            ));
1658
            returned_array::call(out, i, f, *reinterpret_cast<param_n_t<Index> *>(std::get<Index>(params))...);
1659
1660
        }
    }
1661
1662
};

1663
1664
1665
1666
1667
1668
// Deduces Return/Args from a function-pointer type tag and builds the matching vectorize_helper
// around a copy of `f`.  Only the tag's type matters; its value is never read.
template <typename Func, typename Return, typename... Args>
vectorize_helper<Func, Return, Args...>
vectorize_extractor(const Func &f, Return (* /*signature_tag*/) (Args ...)) {
    using helper_type = detail::vectorize_helper<Func, Return, Args...>;
    return helper_type(f);
}

1669
// Python-facing type name for array_t<T, Flags>: "numpy.ndarray[" + element format name + "]".
// Flags do not appear in the name.
template <typename T, int Flags> struct handle_type_name<array_t<T, Flags>> {
    static constexpr auto name = _("numpy.ndarray[") + npy_format_descriptor<T>::name + _("]");
};

1673
PYBIND11_NAMESPACE_END(detail)
Wenzel Jakob's avatar
Wenzel Jakob committed
1674

1675
// Vanilla pointer vectorizer:
1676
template <typename Return, typename... Args>
1677
detail::vectorize_helper<Return (*)(Args...), Return, Args...>
1678
vectorize(Return (*f) (Args ...)) {
1679
    return detail::vectorize_helper<Return (*)(Args...), Return, Args...>(f);
Wenzel Jakob's avatar
Wenzel Jakob committed
1680
1681
}

1682
// lambda vectorizer:
1683
template <typename Func, detail::enable_if_t<detail::is_lambda<Func>::value, int> = 0>
1684
auto vectorize(Func &&f) -> decltype(
1685
1686
        detail::vectorize_extractor(std::forward<Func>(f), (detail::function_signature_t<Func> *) nullptr)) {
    return detail::vectorize_extractor(std::forward<Func>(f), (detail::function_signature_t<Func> *) nullptr);
1687
1688
1689
1690
1691
1692
1693
1694
1695
}

// Vectorize a class method (non-const):
// The member is wrapped with std::mem_fn, so the object pointer becomes the first argument of
// the wrapped callable.  Being a pointer, it is passed through rather than vectorized.
template <typename Return, typename Class, typename... Args,
          typename Helper = detail::vectorize_helper<decltype(std::mem_fn(std::declval<Return (Class::*)(Args...)>())), Return, Class *, Args...>>
Helper vectorize(Return (Class::*f)(Args...)) {
    auto bound = std::mem_fn(f);
    return Helper(std::move(bound));
}

1696
// Vectorize a class method (const):
1697
1698
1699
1700
template <typename Return, typename Class, typename... Args,
          typename Helper = detail::vectorize_helper<decltype(std::mem_fn(std::declval<Return (Class::*)(Args...) const>())), Return, const Class *, Args...>>
Helper vectorize(Return (Class::*f)(Args...) const) {
    return Helper(std::mem_fn(f));
Wenzel Jakob's avatar
Wenzel Jakob committed
1701
1702
}

1703
PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)
Wenzel Jakob's avatar
Wenzel Jakob committed
1704
1705
1706
1707

#if defined(_MSC_VER)
#pragma warning(pop)
#endif