/*
    pybind11/numpy.h: Basic NumPy support, vectorize() wrapper

    Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>

    All rights reserved. Use of this source code is governed by a
    BSD-style license that can be found in the LICENSE file.
*/

#pragma once

12
13
#include "pybind11.h"
#include "complex.h"
14
15
#include <numeric>
#include <algorithm>
16
#include <array>
17
#include <cstdint>
18
#include <cstdlib>
19
#include <cstring>
20
#include <sstream>
21
#include <string>
22
#include <functional>
23
#include <type_traits>
24
#include <utility>
25
#include <vector>
26
#include <typeindex>
27

28
/* This will be true on all flat address space platforms and allows us to reduce the
29
   whole npy_intp / ssize_t / Py_intptr_t business down to just ssize_t for all size
30
31
   and dimension types (e.g. shape, strides, indexing), instead of inflicting this
   upon the library user. */
32
33
34
static_assert(sizeof(::pybind11::ssize_t) == sizeof(Py_intptr_t), "ssize_t != Py_intptr_t");
static_assert(std::is_signed<Py_intptr_t>::value, "Py_intptr_t must be signed");
// We now can reinterpret_cast between py::ssize_t and Py_intptr_t (MSVC + PyPy cares)
35

36
PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
37
38
39

// `array` is defined further down in this header. It is forward-declared here because
// handle_type_name<array> and the `friend class pybind11::array` declarations in the
// unchecked_reference proxies refer to it before its definition.
class array; // Forward declaration

40
PYBIND11_NAMESPACE_BEGIN(detail)
41
42
43

/// Type name that pybind11's handle_type_name machinery reports for `array`
/// (presumably what appears in generated signatures -- confirm in pybind11 core).
template <> struct handle_type_name<array> { static constexpr auto name = _("numpy.ndarray"); };

/// Forward declaration of the C++-type -> NumPy dtype mapping; dtype::of<T>() calls
/// npy_format_descriptor<T>::dtype(). The SFINAE parameter allows enable_if-based specializations.
template <typename type, typename SFINAE = void> struct npy_format_descriptor;
Wenzel Jakob's avatar
Wenzel Jakob committed
45

46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
// Layout mirror of NumPy's dtype descriptor object. This header does not include NumPy's
// C headers; dtype objects are reinterpret_cast to this struct (see array_descriptor_proxy())
// so their fields can be read directly.
// NOTE(review): field order and types presumably must match NumPy's PyArray_Descr ABI
// exactly -- do not reorder, retype, or insert fields.
struct PyArrayDescr_Proxy {
    PyObject_HEAD
    PyObject *typeobj;
    char kind;       // returned by dtype::kind()
    char type;       // returned by dtype::char_()
    char byteorder;
    char flags;
    int type_num;
    int elsize;      // returned by dtype::itemsize()
    int alignment;
    char *subarray;
    PyObject *fields;
    PyObject *names; // dtype::has_fields() tests this for non-null
};

// Layout mirror of NumPy's ndarray object; array objects are reinterpret_cast to this
// struct via array_proxy().
// NOTE(review): presumably must match NumPy's PyArrayObject layout exactly -- do not reorder.
struct PyArray_Proxy {
    PyObject_HEAD
    char *data;
    int nd;
    ssize_t *dimensions;  // ssize_t is static_assert'ed to have the size of Py_intptr_t
    ssize_t *strides;
    PyObject *base;
    PyObject *descr;
    int flags;            // bit mask tested by check_flags()
};

72
73
74
75
76
77
78
79
// Layout mirror of NumPy's void-scalar object. Unlike the array/descr proxies, it starts
// with PyObject_VAR_HEAD. Its users are not visible in this part of the file.
// NOTE(review): presumably must match NumPy's PyVoidScalarObject layout exactly; confirm
// against NumPy before changing any field.
struct PyVoidScalarObject_Proxy {
    PyObject_VAR_HEAD
    char *obval;
    PyArrayDescr_Proxy *descr;
    int flags;
    PyObject *base;
};

80
81
82
83
84
85
86
87
// Per-C++-type record stored in numpy_internals::registered_dtypes.
// NOTE(review): dtype_ptr presumably points at the registered dtype object and format_str
// holds its buffer-protocol format string -- confirm at the registration site.
struct numpy_type_info {
    PyObject* dtype_ptr;
    std::string format_str;
};

struct numpy_internals {
    std::unordered_map<std::type_index, numpy_type_info> registered_dtypes;

88
89
    numpy_type_info *get_type_info(const std::type_info& tinfo, bool throw_if_missing = true) {
        auto it = registered_dtypes.find(std::type_index(tinfo));
90
91
92
        if (it != registered_dtypes.end())
            return &(it->second);
        if (throw_if_missing)
93
            pybind11_fail(std::string("NumPy type info missing for ") + tinfo.name());
94
95
        return nullptr;
    }
96
97
98
99

    template<typename T> numpy_type_info *get_type_info(bool throw_if_missing = true) {
        return get_type_info(typeid(typename std::remove_cv<T>::type), throw_if_missing);
    }
100
101
};

102
/// Sets `ptr` to the numpy_internals instance held in pybind11's shared data under the
/// key "_numpy_internals", creating it via get_or_create_shared_data() if absent.
/// Kept out-of-line (PYBIND11_NOINLINE) so the cold initialization path is not inlined.
PYBIND11_NOINLINE void load_numpy_internals(numpy_internals* &ptr) {
    ptr = &get_or_create_shared_data<numpy_internals>("_numpy_internals");
}

/// Returns the shared numpy_internals registry. The pointer is cached in a function-local
/// static, so load_numpy_internals() runs only while the cache is still null.
/// NOTE(review): the lazy initialization is unsynchronized; presumably callers hold the GIL.
inline numpy_internals& get_numpy_internals() {
    static numpy_internals* ptr = nullptr;
    if (!ptr)
        load_numpy_internals(ptr);
    return *ptr;
}

113
114
115
116
/// Type-level predicate factory: `same_size<T>::as<U>` is a bool_constant that holds
/// exactly when `T` and `U` occupy the same number of bytes.
template <typename T> struct same_size {
    template <typename U> using as = bool_constant<sizeof(U) == sizeof(T)>;
};

117
118
/// Terminal case of platform_lookup: no candidate type matched, so return -1.
template <typename Concrete> constexpr int platform_lookup() { return -1; }

// Lookup a type according to its size, and return a value corresponding to the NumPy typenum.
/// Walks the candidate types `T, Ts...` in order, paired positionally with the values
/// `I, Is...`. Returns the value of the first candidate whose sizeof equals
/// sizeof(Concrete), or -1 if none match.
template <typename Concrete, typename T, typename... Ts, typename... Ints>
constexpr int platform_lookup(int I, Ints... Is) {
    return sizeof(Concrete) == sizeof(T) ? I : platform_lookup<Concrete, Ts...>(Is...);
}

125
126
struct npy_api {
    enum constants {
127
128
        NPY_ARRAY_C_CONTIGUOUS_ = 0x0001,
        NPY_ARRAY_F_CONTIGUOUS_ = 0x0002,
129
        NPY_ARRAY_OWNDATA_ = 0x0004,
130
        NPY_ARRAY_FORCECAST_ = 0x0010,
131
        NPY_ARRAY_ENSUREARRAY_ = 0x0040,
132
133
        NPY_ARRAY_ALIGNED_ = 0x0100,
        NPY_ARRAY_WRITEABLE_ = 0x0400,
134
135
136
137
138
139
140
141
142
        NPY_BOOL_ = 0,
        NPY_BYTE_, NPY_UBYTE_,
        NPY_SHORT_, NPY_USHORT_,
        NPY_INT_, NPY_UINT_,
        NPY_LONG_, NPY_ULONG_,
        NPY_LONGLONG_, NPY_ULONGLONG_,
        NPY_FLOAT_, NPY_DOUBLE_, NPY_LONGDOUBLE_,
        NPY_CFLOAT_, NPY_CDOUBLE_, NPY_CLONGDOUBLE_,
        NPY_OBJECT_ = 17,
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
        NPY_STRING_, NPY_UNICODE_, NPY_VOID_,
        // Platform-dependent normalization
        NPY_INT8_ = NPY_BYTE_,
        NPY_UINT8_ = NPY_UBYTE_,
        NPY_INT16_ = NPY_SHORT_,
        NPY_UINT16_ = NPY_USHORT_,
        // `npy_common.h` defines the integer aliases. In order, it checks:
        // NPY_BITSOF_LONG, NPY_BITSOF_LONGLONG, NPY_BITSOF_INT, NPY_BITSOF_SHORT, NPY_BITSOF_CHAR
        // and assigns the alias to the first matching size, so we should check in this order.
        NPY_INT32_ = platform_lookup<std::int32_t, long, int, short>(
            NPY_LONG_, NPY_INT_, NPY_SHORT_),
        NPY_UINT32_ = platform_lookup<std::uint32_t, unsigned long, unsigned int, unsigned short>(
            NPY_ULONG_, NPY_UINT_, NPY_USHORT_),
        NPY_INT64_ = platform_lookup<std::int64_t, long, long long, int>(
            NPY_LONG_, NPY_LONGLONG_, NPY_INT_),
        NPY_UINT64_ = platform_lookup<std::uint64_t, unsigned long, unsigned long long, unsigned int>(
            NPY_ULONG_, NPY_ULONGLONG_, NPY_UINT_),
160
161
    };

162
    struct PyArray_Dims {
uentity's avatar
uentity committed
163
164
        Py_intptr_t *ptr;
        int len;
165
    };
uentity's avatar
uentity committed
166

167
168
169
170
171
    static npy_api& get() {
        static npy_api api = lookup();
        return api;
    }

172
173
174
175
176
177
    bool PyArray_Check_(PyObject *obj) const {
        return (bool) PyObject_TypeCheck(obj, PyArray_Type_);
    }
    bool PyArrayDescr_Check_(PyObject *obj) const {
        return (bool) PyObject_TypeCheck(obj, PyArrayDescr_Type_);
    }
178

179
    unsigned int (*PyArray_GetNDArrayCFeatureVersion_)();
180
181
    PyObject *(*PyArray_DescrFromType_)(int);
    PyObject *(*PyArray_NewFromDescr_)
182
183
184
        (PyTypeObject *, PyObject *, int, Py_intptr_t const *,
         Py_intptr_t const *, void *, int, PyObject *);
    // Unused. Not removed because that affects ABI of the class.
185
    PyObject *(*PyArray_DescrNewFromType_)(int);
186
    int (*PyArray_CopyInto_)(PyObject *, PyObject *);
187
188
    PyObject *(*PyArray_NewCopy_)(PyObject *, int);
    PyTypeObject *PyArray_Type_;
189
    PyTypeObject *PyVoidArrType_Type_;
190
    PyTypeObject *PyArrayDescr_Type_;
191
    PyObject *(*PyArray_DescrFromScalar_)(PyObject *);
192
193
194
    PyObject *(*PyArray_FromAny_) (PyObject *, PyObject *, int, int, int, PyObject *);
    int (*PyArray_DescrConverter_) (PyObject *, PyObject **);
    bool (*PyArray_EquivTypes_) (PyObject *, PyObject *);
195
196
    int (*PyArray_GetArrayParamsFromObject_)(PyObject *, PyObject *, unsigned char, PyObject **, int *,
                                             Py_intptr_t *, PyObject **, PyObject *);
197
    PyObject *(*PyArray_Squeeze_)(PyObject *);
198
    // Unused. Not removed because that affects ABI of the class.
Jason Rhinelander's avatar
Jason Rhinelander committed
199
    int (*PyArray_SetBaseObject_)(PyObject *, PyObject *);
uentity's avatar
uentity committed
200
    PyObject* (*PyArray_Resize_)(PyObject*, PyArray_Dims*, int, int);
Nick Cullen's avatar
Nick Cullen committed
201
    PyObject* (*PyArray_Newshape_)(PyObject*, PyArray_Dims*, int);
Nick Cullen's avatar
Nick Cullen committed
202
    PyObject* (*PyArray_View_)(PyObject*, PyObject*, PyObject*);
Nick Cullen's avatar
Nick Cullen committed
203

204
205
private:
    enum functions {
206
        API_PyArray_GetNDArrayCFeatureVersion = 211,
207
        API_PyArray_Type = 2,
208
        API_PyArrayDescr_Type = 3,
209
        API_PyVoidArrType_Type = 39,
210
        API_PyArray_DescrFromType = 45,
211
        API_PyArray_DescrFromScalar = 57,
212
        API_PyArray_FromAny = 69,
uentity's avatar
uentity committed
213
        API_PyArray_Resize = 80,
214
        API_PyArray_CopyInto = 82,
215
216
        API_PyArray_NewCopy = 85,
        API_PyArray_NewFromDescr = 94,
217
        API_PyArray_DescrNewFromType = 96,
Nick Cullen's avatar
Nick Cullen committed
218
219
        API_PyArray_Newshape = 135,
        API_PyArray_Squeeze = 136,
Nick Cullen's avatar
Nick Cullen committed
220
        API_PyArray_View = 137,
221
222
223
        API_PyArray_DescrConverter = 174,
        API_PyArray_EquivTypes = 182,
        API_PyArray_GetArrayParamsFromObject = 278,
Jason Rhinelander's avatar
Jason Rhinelander committed
224
        API_PyArray_SetBaseObject = 282
225
226
227
    };

    static npy_api lookup() {
228
        module_ m = module_::import("numpy.core.multiarray");
229
        auto c = m.attr("_ARRAY_API");
230
#if PY_MAJOR_VERSION >= 3
231
        void **api_ptr = (void **) PyCapsule_GetPointer(c.ptr(), NULL);
232
#else
233
        void **api_ptr = (void **) PyCObject_AsVoidPtr(c.ptr());
234
#endif
235
        npy_api api;
236
#define DECL_NPY_API(Func) api.Func##_ = (decltype(api.Func##_)) api_ptr[API_##Func];
237
238
239
        DECL_NPY_API(PyArray_GetNDArrayCFeatureVersion);
        if (api.PyArray_GetNDArrayCFeatureVersion_() < 0x7)
            pybind11_fail("pybind11 numpy support requires numpy >= 1.7.0");
240
        DECL_NPY_API(PyArray_Type);
241
        DECL_NPY_API(PyVoidArrType_Type);
242
        DECL_NPY_API(PyArrayDescr_Type);
243
        DECL_NPY_API(PyArray_DescrFromType);
244
        DECL_NPY_API(PyArray_DescrFromScalar);
245
        DECL_NPY_API(PyArray_FromAny);
uentity's avatar
uentity committed
246
        DECL_NPY_API(PyArray_Resize);
247
        DECL_NPY_API(PyArray_CopyInto);
248
249
250
        DECL_NPY_API(PyArray_NewCopy);
        DECL_NPY_API(PyArray_NewFromDescr);
        DECL_NPY_API(PyArray_DescrNewFromType);
Nick Cullen's avatar
Nick Cullen committed
251
252
        DECL_NPY_API(PyArray_Newshape);
        DECL_NPY_API(PyArray_Squeeze);
Nick Cullen's avatar
Nick Cullen committed
253
        DECL_NPY_API(PyArray_View);
254
255
256
        DECL_NPY_API(PyArray_DescrConverter);
        DECL_NPY_API(PyArray_EquivTypes);
        DECL_NPY_API(PyArray_GetArrayParamsFromObject);
Jason Rhinelander's avatar
Jason Rhinelander committed
257
        DECL_NPY_API(PyArray_SetBaseObject);
Nick Cullen's avatar
Nick Cullen committed
258

259
#undef DECL_NPY_API
260
261
262
        return api;
    }
};
Wenzel Jakob's avatar
Wenzel Jakob committed
263

264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
/// Reinterprets an ndarray object pointer as its PyArray_Proxy layout (mutable view).
inline PyArray_Proxy* array_proxy(void* ptr) {
    PyArray_Proxy* const view = static_cast<PyArray_Proxy*>(ptr);
    return view;
}

/// Read-only counterpart of array_proxy(void*).
inline const PyArray_Proxy* array_proxy(const void* ptr) {
    const PyArray_Proxy* const view = static_cast<const PyArray_Proxy*>(ptr);
    return view;
}

/// Reinterprets a dtype object pointer as its PyArrayDescr_Proxy layout (mutable view).
inline PyArrayDescr_Proxy* array_descriptor_proxy(PyObject* ptr) {
    auto *const descr = reinterpret_cast<PyArrayDescr_Proxy*>(ptr);
    return descr;
}

/// Read-only counterpart of array_descriptor_proxy(PyObject*).
inline const PyArrayDescr_Proxy* array_descriptor_proxy(const PyObject* ptr) {
    const auto *const descr = reinterpret_cast<const PyArrayDescr_Proxy*>(ptr);
    return descr;
}

/// True when every bit set in `flag` is also set in the array's `flags` word.
inline bool check_flags(const void* ptr, int flag) {
    const int present = array_proxy(ptr)->flags & flag;
    return present == flag;
}

284
285
286
287
288
template <typename T> struct is_std_array : std::false_type { };
template <typename T, size_t N> struct is_std_array<std::array<T, N>> : std::true_type { };
template <typename T> struct is_complex : std::false_type { };
template <typename T> struct is_complex<std::complex<T>> : std::true_type { };

289
template <typename T> struct array_info_scalar {
290
    using type = T;
291
292
    static constexpr bool is_array = false;
    static constexpr bool is_empty = false;
293
    static constexpr auto extents = _("");
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
    static void append_extents(list& /* shape */) { }
};
// Computes underlying type and a comma-separated list of extents for array
// types (any mix of std::array and built-in arrays). An array of char is
// treated as scalar because it gets special handling.
template <typename T> struct array_info : array_info_scalar<T> { };
template <typename T, size_t N> struct array_info<std::array<T, N>> {
    using type = typename array_info<T>::type;
    static constexpr bool is_array = true;
    static constexpr bool is_empty = (N == 0) || array_info<T>::is_empty;
    static constexpr size_t extent = N;

    // appends the extents to shape
    static void append_extents(list& shape) {
        shape.append(N);
        array_info<T>::append_extents(shape);
    }

312
313
314
    static constexpr auto extents = _<array_info<T>::is_array>(
        concat(_<N>(), array_info<T>::extents), _<N>()
    );
315
316
317
318
319
320
321
322
};
// For numpy we have special handling for arrays of characters, so we don't include
// the size in the array extents.
template <size_t N> struct array_info<char[N]> : array_info_scalar<char[N]> { };
template <size_t N> struct array_info<std::array<char, N>> : array_info_scalar<std::array<char, N>> { };
template <typename T, size_t N> struct array_info<T[N]> : array_info<std::array<T, N>> { };
template <typename T> using remove_all_extents_t = typename array_info<T>::type;

323
/// True when `T` is a standard-layout, trivially copyable type that is not itself a
/// reference, built-in array, std::array, arithmetic, complex, or enum type.
template <typename T> using is_pod_struct = all_of<
    std::is_standard_layout<T>,     // since we're accessing directly in memory we need a standard layout type
#if defined(__GLIBCXX__) && (__GLIBCXX__ < 20150422 || __GLIBCXX__ == 20150426 || __GLIBCXX__ == 20150623 || __GLIBCXX__ == 20150626 || __GLIBCXX__ == 20160803)
    // libstdc++ < 5 (including versions 4.8.5, 4.9.3 and 4.9.4 which were released after 5)
    // don't implement is_trivially_copyable, so approximate it
    std::is_trivially_destructible<T>,
    satisfies_any_of<T, std::has_trivial_copy_constructor, std::has_trivial_copy_assign>,
#else
    std::is_trivially_copyable<T>,
#endif
    satisfies_none_of<T, std::is_reference, std::is_array, is_std_array, std::is_arithmetic, is_complex, std::is_enum>
>;

336
337
338
339
340
341
// Replacement for std::is_pod (deprecated in C++20): a type counts as POD when it is
// standard-layout and trivial at the same time.
template <typename T>
using is_pod = all_of<std::is_standard_layout<T>, std::is_trivial<T>>;

342
343
344
345
/// Byte offset of the element at the given indices: the sum of index[k] * strides[Dim + k].
/// No bounds or dimensionality checks are performed. The index-less overload ends the recursion.
template <ssize_t Dim = 0, typename Strides> ssize_t byte_offset_unsafe(const Strides &) { return 0; }
template <ssize_t Dim = 0, typename Strides, typename... Ix>
ssize_t byte_offset_unsafe(const Strides &strides, ssize_t i, Ix... index) {
    return i * strides[Dim] + byte_offset_unsafe<Dim + 1>(strides, index...);
}

348
349
/**
 * Proxy class providing unsafe, unchecked const access to array data.  This is constructed through
350
351
 * the `unchecked<T, N>()` method of `array` or the `unchecked<N>()` method of `array_t<T>`.  `Dims`
 * will be -1 for dimensions determined at runtime.
352
 */
353
template <typename T, ssize_t Dims>
354
355
class unchecked_reference {
protected:
356
    static constexpr bool Dynamic = Dims < 0;
357
358
    const unsigned char *data_;
    // Storing the shape & strides in local variables (i.e. these arrays) allows the compiler to
359
    // make large performance gains on big, nested loops, but requires compile-time dimensions
360
361
362
    conditional_t<Dynamic, const ssize_t *, std::array<ssize_t, (size_t) Dims>>
            shape_, strides_;
    const ssize_t dims_;
363
364

    friend class pybind11::array;
365
366
    // Constructor for compile-time dimensions:
    template <bool Dyn = Dynamic>
367
    unchecked_reference(const void *data, const ssize_t *shape, const ssize_t *strides, enable_if_t<!Dyn, ssize_t>)
368
    : data_{reinterpret_cast<const unsigned char *>(data)}, dims_{Dims} {
369
        for (size_t i = 0; i < (size_t) dims_; i++) {
370
371
372
373
            shape_[i] = shape[i];
            strides_[i] = strides[i];
        }
    }
374
375
    // Constructor for runtime dimensions:
    template <bool Dyn = Dynamic>
376
    unchecked_reference(const void *data, const ssize_t *shape, const ssize_t *strides, enable_if_t<Dyn, ssize_t> dims)
377
    : data_{reinterpret_cast<const unsigned char *>(data)}, shape_{shape}, strides_{strides}, dims_{dims} {}
378
379

public:
380
381
    /**
     * Unchecked const reference access to data at the given indices.  For a compile-time known
382
383
     * number of dimensions, this requires the correct number of arguments; for run-time
     * dimensionality, this is not checked (and so is up to the caller to use safely).
384
     */
385
    template <typename... Ix> const T &operator()(Ix... index) const {
Jason Rhinelander's avatar
Jason Rhinelander committed
386
        static_assert(ssize_t{sizeof...(Ix)} == Dims || Dynamic,
387
                "Invalid number of indices for unchecked array reference");
388
        return *reinterpret_cast<const T *>(data_ + byte_offset_unsafe(strides_, ssize_t(index)...));
389
    }
390
391
    /**
     * Unchecked const reference access to data; this operator only participates if the reference
392
393
     * is to a 1-dimensional array.  When present, this is exactly equivalent to `obj(index)`.
     */
394
395
    template <ssize_t D = Dims, typename = enable_if_t<D == 1 || Dynamic>>
    const T &operator[](ssize_t index) const { return operator()(index); }
396

397
    /// Pointer access to the data at the given indices.
398
    template <typename... Ix> const T *data(Ix... ix) const { return &operator()(ssize_t(ix)...); }
399
400

    /// Returns the item size, i.e. sizeof(T)
401
    constexpr static ssize_t itemsize() { return sizeof(T); }
402

403
    /// Returns the shape (i.e. size) of dimension `dim`
404
    ssize_t shape(ssize_t dim) const { return shape_[(size_t) dim]; }
405
406

    /// Returns the number of dimensions of the array
407
    ssize_t ndim() const { return dims_; }
408
409
410

    /// Returns the total number of elements in the referenced array, i.e. the product of the shapes
    template <bool Dyn = Dynamic>
411
412
    enable_if_t<!Dyn, ssize_t> size() const {
        return std::accumulate(shape_.begin(), shape_.end(), (ssize_t) 1, std::multiplies<ssize_t>());
413
414
    }
    template <bool Dyn = Dynamic>
415
416
    enable_if_t<Dyn, ssize_t> size() const {
        return std::accumulate(shape_, shape_ + ndim(), (ssize_t) 1, std::multiplies<ssize_t>());
417
418
419
420
    }

    /// Returns the total number of bytes used by the referenced data.  Note that the actual span in
    /// memory may be larger if the referenced array has non-contiguous strides (e.g. for a slice).
421
    ssize_t nbytes() const {
422
423
        return size() * itemsize();
    }
424
425
};

426
template <typename T, ssize_t Dims>
427
428
429
430
class unchecked_mutable_reference : public unchecked_reference<T, Dims> {
    friend class pybind11::array;
    using ConstBase = unchecked_reference<T, Dims>;
    using ConstBase::ConstBase;
431
    using ConstBase::Dynamic;
432
public:
433
434
435
436
    // Bring in const-qualified versions from base class
    using ConstBase::operator();
    using ConstBase::operator[];

437
438
    /// Mutable, unchecked access to data at the given indices.
    template <typename... Ix> T& operator()(Ix... index) {
Jason Rhinelander's avatar
Jason Rhinelander committed
439
        static_assert(ssize_t{sizeof...(Ix)} == Dims || Dynamic,
440
                "Invalid number of indices for unchecked array reference");
441
442
        return const_cast<T &>(ConstBase::operator()(index...));
    }
443
444
    /**
     * Mutable, unchecked access data at the given index; this operator only participates if the
445
446
     * reference is to a 1-dimensional array (or has runtime dimensions).  When present, this is
     * exactly equivalent to `obj(index)`.
447
     */
448
449
    template <ssize_t D = Dims, typename = enable_if_t<D == 1 || Dynamic>>
    T &operator[](ssize_t index) { return operator()(index); }
450
451

    /// Mutable pointer access to the data at the given indices.
452
    template <typename... Ix> T *mutable_data(Ix... ix) { return &operator()(ssize_t(ix)...); }
453
454
};

455
template <typename T, ssize_t Dim>
456
struct type_caster<unchecked_reference<T, Dim>> {
457
    static_assert(Dim == 0 && Dim > 0 /* always fail */, "unchecked array proxy object is not castable");
458
};
459
template <typename T, ssize_t Dim>
460
461
struct type_caster<unchecked_mutable_reference<T, Dim>> : type_caster<unchecked_reference<T, Dim>> {};

462
PYBIND11_NAMESPACE_END(detail)
463

464
class dtype : public object {
465
public:
466
    PYBIND11_OBJECT_DEFAULT(dtype, object, detail::npy_api::get().PyArrayDescr_Check_);
Wenzel Jakob's avatar
Wenzel Jakob committed
467

468
    explicit dtype(const buffer_info &info) {
469
        dtype descr(_dtype_from_pep3118()(PYBIND11_STR_TYPE(info.format)));
470
        // If info.itemsize == 0, use the value calculated from the format string
471
472
473
        m_ptr = descr.strip_padding(info.itemsize != 0 ? info.itemsize : descr.itemsize())
                    .release()
                    .ptr();
474
    }
475

476
    explicit dtype(const std::string &format) {
477
        m_ptr = from_args(pybind11::str(format)).release().ptr();
Wenzel Jakob's avatar
Wenzel Jakob committed
478
479
    }

480
    explicit dtype(const char *format) : dtype(std::string(format)) {}
481

482
    dtype(list names, list formats, list offsets, ssize_t itemsize) {
483
        dict args;
484
485
486
        args["names"] = std::move(names);
        args["formats"] = std::move(formats);
        args["offsets"] = std::move(offsets);
487
        args["itemsize"] = pybind11::int_(itemsize);
488
        m_ptr = from_args(std::move(args)).release().ptr();
489
490
    }

Ivan Smirnov's avatar
Ivan Smirnov committed
491
    /// This is essentially the same as calling numpy.dtype(args) in Python.
492
493
    static dtype from_args(object args) {
        PyObject *ptr = nullptr;
494
        if ((detail::npy_api::get().PyArray_DescrConverter_(args.ptr(), &ptr) == 0) || !ptr)
495
            throw error_already_set();
496
        return reinterpret_steal<dtype>(ptr);
497
    }
498

Ivan Smirnov's avatar
Ivan Smirnov committed
499
    /// Return dtype associated with a C++ type.
500
    template <typename T> static dtype of() {
501
        return detail::npy_format_descriptor<typename std::remove_cv<T>::type>::dtype();
502
    }
503

Ivan Smirnov's avatar
Ivan Smirnov committed
504
    /// Size of the data type in bytes.
505
506
    ssize_t itemsize() const {
        return detail::array_descriptor_proxy(m_ptr)->elsize;
Wenzel Jakob's avatar
Wenzel Jakob committed
507
508
    }

Ivan Smirnov's avatar
Ivan Smirnov committed
509
    /// Returns true for structured data types.
510
    bool has_fields() const {
511
        return detail::array_descriptor_proxy(m_ptr)->names != nullptr;
512
513
    }

Bertrand MICHEL's avatar
Bertrand MICHEL committed
514
515
    /// Single-character code for dtype's kind.
    /// For example, floating point types are 'f' and integral types are 'i'.
516
    char kind() const {
517
        return detail::array_descriptor_proxy(m_ptr)->kind;
518
519
    }

Bertrand MICHEL's avatar
Bertrand MICHEL committed
520
521
522
523
524
525
526
527
528
    /// Single-character for dtype's type.
    /// For example, ``float`` is 'f', ``double`` 'd', ``int`` 'i', and ``long`` 'd'.
    char char_() const {
        // Note: The signature, `dtype::char_` follows the naming of NumPy's
        // public Python API (i.e., ``dtype.char``), rather than its internal
        // C API (``PyArray_Descr::type``).
        return detail::array_descriptor_proxy(m_ptr)->type;
    }

529
private:
530
    static object _dtype_from_pep3118() {
531
        static PyObject *obj = module_::import("numpy.core._internal")
532
            .attr("_dtype_from_pep3118").cast<object>().release().ptr();
533
        return reinterpret_borrow<object>(obj);
534
    }
535

536
    dtype strip_padding(ssize_t itemsize) {
537
538
        // Recursively strip all void fields with empty names that are generated for
        // padding fields (as of NumPy v1.11).
539
        if (!has_fields())
540
            return *this;
541

542
        struct field_descr { PYBIND11_STR_TYPE name; object format; pybind11::int_ offset; };
543
544
        std::vector<field_descr> field_descriptors;

545
        for (auto field : attr("fields").attr("items")()) {
546
            auto spec = field.cast<tuple>();
547
            auto name = spec[0].cast<pybind11::str>();
548
            auto format = spec[1].cast<tuple>()[0].cast<dtype>();
549
            auto offset = spec[1].cast<tuple>()[1].cast<pybind11::int_>();
550
            if ((len(name) == 0u) && format.kind() == 'V')
551
                continue;
552
            field_descriptors.push_back({(PYBIND11_STR_TYPE) name, format.strip_padding(format.itemsize()), offset});
553
554
555
556
        }

        std::sort(field_descriptors.begin(), field_descriptors.end(),
                  [](const field_descr& a, const field_descr& b) {
557
                      return a.offset.cast<int>() < b.offset.cast<int>();
558
559
560
561
                  });

        list names, formats, offsets;
        for (auto& descr : field_descriptors) {
562
563
564
            names.append(descr.name);
            formats.append(descr.format);
            offsets.append(descr.offset);
565
        }
566
        return dtype(std::move(names), std::move(formats), std::move(offsets), itemsize);
567
568
    }
};
569

570
571
class array : public buffer {
public:
572
    PYBIND11_OBJECT_CVT(array, buffer, detail::npy_api::get().PyArray_Check_, raw_array)
573
574

    enum {
575
576
        c_style = detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_,
        f_style = detail::npy_api::NPY_ARRAY_F_CONTIGUOUS_,
577
578
579
        forcecast = detail::npy_api::NPY_ARRAY_FORCECAST_
    };

580
    array() : array(0, static_cast<const double *>(nullptr)) {}
581

582
583
    using ShapeContainer = detail::any_container<ssize_t>;
    using StridesContainer = detail::any_container<ssize_t>;
584
585
586
587
588
589

    // Constructs an array taking shape/strides from arbitrary container types
    array(const pybind11::dtype &dt, ShapeContainer shape, StridesContainer strides,
          const void *ptr = nullptr, handle base = handle()) {

        if (strides->empty())
590
            *strides = detail::c_strides(*shape, dt.itemsize());
591
592
593

        auto ndim = shape->size();
        if (ndim != strides->size())
594
595
            pybind11_fail("NumPy: shape ndim doesn't match strides ndim");
        auto descr = dt;
596
597
598

        int flags = 0;
        if (base && ptr) {
599
            if (isinstance<array>(base))
Wenzel Jakob's avatar
Wenzel Jakob committed
600
                /* Copy flags from base (except ownership bit) */
601
                flags = reinterpret_borrow<array>(base).flags() & ~detail::npy_api::NPY_ARRAY_OWNDATA_;
602
603
604
605
606
            else
                /* Writable by default, easy to downgrade later on if needed */
                flags = detail::npy_api::NPY_ARRAY_WRITEABLE_;
        }

607
        auto &api = detail::npy_api::get();
608
        auto tmp = reinterpret_steal<object>(api.PyArray_NewFromDescr_(
609
610
611
612
            api.PyArray_Type_, descr.release().ptr(), (int) ndim,
            // Use reinterpret_cast for PyPy on Windows (remove if fixed, checked on 7.3.1)
            reinterpret_cast<Py_intptr_t*>(shape->data()),
            reinterpret_cast<Py_intptr_t*>(strides->data()),
613
            const_cast<void *>(ptr), flags, nullptr));
614
        if (!tmp)
615
            throw error_already_set();
616
617
        if (ptr) {
            if (base) {
Jason Rhinelander's avatar
Jason Rhinelander committed
618
                api.PyArray_SetBaseObject_(tmp.ptr(), base.inc_ref().ptr());
619
            } else {
620
                tmp = reinterpret_steal<object>(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */));
621
622
            }
        }
623
624
625
        m_ptr = tmp.release().ptr();
    }

626
627
    array(const pybind11::dtype &dt, ShapeContainer shape, const void *ptr = nullptr, handle base = handle())
        : array(dt, std::move(shape), {}, ptr, base) { }
628

629
630
631
    template <typename T, typename = detail::enable_if_t<std::is_integral<T>::value && !std::is_same<bool, T>::value>>
    array(const pybind11::dtype &dt, T count, const void *ptr = nullptr, handle base = handle())
        : array(dt, {{count}}, ptr, base) { }
632

633
634
635
    template <typename T>
    array(ShapeContainer shape, StridesContainer strides, const T *ptr, handle base = handle())
        : array(pybind11::dtype::of<T>(), std::move(shape), std::move(strides), ptr, base) { }
636

637
    template <typename T>
638
639
    array(ShapeContainer shape, const T *ptr, handle base = handle())
        : array(std::move(shape), {}, ptr, base) { }
640

641
    template <typename T>
642
    explicit array(ssize_t count, const T *ptr, handle base = handle()) : array({count}, {}, ptr, base) { }
643

644
645
    explicit array(const buffer_info &info, handle base = handle())
    : array(pybind11::dtype(info), info.shape, info.strides, info.ptr, base) { }
646

647
648
    /// Array descriptor (dtype)
    pybind11::dtype dtype() const {
649
        return reinterpret_borrow<pybind11::dtype>(detail::array_proxy(m_ptr)->descr);
650
651
652
    }

    /// Total number of elements
653
654
    ssize_t size() const {
        return std::accumulate(shape(), shape() + ndim(), (ssize_t) 1, std::multiplies<ssize_t>());
655
656
657
    }

    /// Byte size of a single element
658
659
    ssize_t itemsize() const {
        return detail::array_descriptor_proxy(detail::array_proxy(m_ptr)->descr)->elsize;
660
661
662
    }

    /// Total number of bytes
663
    ssize_t nbytes() const {
664
665
666
667
        return size() * itemsize();
    }

    /// Number of dimensions
668
669
    ssize_t ndim() const {
        return detail::array_proxy(m_ptr)->nd;
670
671
    }

672
673
    /// Base object
    object base() const {
674
        return reinterpret_borrow<object>(detail::array_proxy(m_ptr)->base);
675
676
    }

677
    /// Dimensions of the array
678
679
    const ssize_t* shape() const {
        return detail::array_proxy(m_ptr)->dimensions;
680
681
682
    }

    /// Dimension along a given axis
683
    ssize_t shape(ssize_t dim) const {
684
        if (dim >= ndim())
685
            fail_dim_check(dim, "invalid axis");
686
687
688
689
        return shape()[dim];
    }

    /// Strides of the array
690
    const ssize_t* strides() const {
691
        return detail::array_proxy(m_ptr)->strides;
692
693
694
    }

    /// Stride along a given axis
695
    ssize_t strides(ssize_t dim) const {
696
        if (dim >= ndim())
697
            fail_dim_check(dim, "invalid axis");
698
699
700
        return strides()[dim];
    }

701
702
    /// Return the NumPy array flags
    int flags() const {
703
        return detail::array_proxy(m_ptr)->flags;
704
705
    }

706
707
    /// If set, the array is writeable (otherwise the buffer is read-only)
    bool writeable() const {
708
        return detail::check_flags(m_ptr, detail::npy_api::NPY_ARRAY_WRITEABLE_);
709
710
711
712
    }

    /// If set, the array owns the data (will be freed when the array is deleted)
    bool owndata() const {
713
        return detail::check_flags(m_ptr, detail::npy_api::NPY_ARRAY_OWNDATA_);
714
715
    }

716
717
    /// Pointer to the contained data. If index is not provided, points to the
    /// beginning of the buffer. May throw if the index would lead to out of bounds access.
718
    template<typename... Ix> const void* data(Ix... index) const {
719
        return static_cast<const void *>(detail::array_proxy(m_ptr)->data + offset_at(index...));
720
721
    }

722
723
724
    /// Mutable pointer to the contained data. If index is not provided, points to the
    /// beginning of the buffer. May throw if the index would lead to out of bounds access.
    /// May throw if the array is not writeable.
725
    template<typename... Ix> void* mutable_data(Ix... index) {
726
        check_writeable();
727
        return static_cast<void *>(detail::array_proxy(m_ptr)->data + offset_at(index...));
728
729
730
731
    }

    /// Byte offset from beginning of the array to a given index (full or partial).
    /// May throw if the index would lead to out of bounds access.
732
    template<typename... Ix> ssize_t offset_at(Ix... index) const {
733
        if ((ssize_t) sizeof...(index) > ndim())
734
            fail_dim_check(sizeof...(index), "too many indices for an array");
735
        return byte_offset(ssize_t(index)...);
736
737
    }

738
    ssize_t offset_at() const { return 0; }
739
740
741

    /// Item count from beginning of the array to a given index (full or partial).
    /// May throw if the index would lead to out of bounds access.
742
    template<typename... Ix> ssize_t index_at(Ix... index) const {
743
        return offset_at(index...) / itemsize();
744
745
    }

746
747
    /**
     * Returns a proxy object that provides access to the array's data without bounds or
748
749
750
751
     * dimensionality checking.  Will throw if the array is missing the `writeable` flag.  Use with
     * care: the array must not be destroyed or reshaped for the duration of the returned object,
     * and the caller must take care not to access invalid dimensions or dimension indices.
     */
752
    template <typename T, ssize_t Dims = -1> detail::unchecked_mutable_reference<T, Dims> mutable_unchecked() & {
753
        if (PYBIND11_SILENCE_MSVC_C4127(Dims >= 0) && ndim() != Dims)
754
755
            throw std::domain_error("array has incorrect number of dimensions: " + std::to_string(ndim()) +
                    "; expected " + std::to_string(Dims));
756
        return detail::unchecked_mutable_reference<T, Dims>(mutable_data(), shape(), strides(), ndim());
757
758
    }

759
760
    /**
     * Returns a proxy object that provides const access to the array's data without bounds or
761
762
763
764
765
     * dimensionality checking.  Unlike `mutable_unchecked()`, this does not require that the
     * underlying array have the `writable` flag.  Use with care: the array must not be destroyed or
     * reshaped for the duration of the returned object, and the caller must take care not to access
     * invalid dimensions or dimension indices.
     */
766
    template <typename T, ssize_t Dims = -1> detail::unchecked_reference<T, Dims> unchecked() const & {
767
        if (PYBIND11_SILENCE_MSVC_C4127(Dims >= 0) && ndim() != Dims)
768
769
            throw std::domain_error("array has incorrect number of dimensions: " + std::to_string(ndim()) +
                    "; expected " + std::to_string(Dims));
770
        return detail::unchecked_reference<T, Dims>(data(), shape(), strides(), ndim());
771
772
    }

773
774
775
    /// Return a new view with all of the dimensions of length 1 removed
    array squeeze() {
        auto& api = detail::npy_api::get();
776
        return reinterpret_steal<array>(api.PyArray_Squeeze_(m_ptr));
777
778
    }

uentity's avatar
uentity committed
779
780
781
782
783
    /// Resize array to given shape
    /// If refcheck is true and more that one reference exist to this array
    /// then resize will succeed only if it makes a reshape, i.e. original size doesn't change
    void resize(ShapeContainer new_shape, bool refcheck = true) {
        detail::npy_api::PyArray_Dims d = {
784
785
786
            // Use reinterpret_cast for PyPy on Windows (remove if fixed, checked on 7.3.1)
            reinterpret_cast<Py_intptr_t*>(new_shape->data()),
            int(new_shape->size())
uentity's avatar
uentity committed
787
788
        };
        // try to resize, set ordering param to -1 cause it's not used anyway
789
        auto new_array = reinterpret_steal<object>(
uentity's avatar
uentity committed
790
791
792
793
794
795
            detail::npy_api::get().PyArray_Resize_(m_ptr, &d, int(refcheck), -1)
        );
        if (!new_array) throw error_already_set();
        if (isinstance<array>(new_array)) { *this = std::move(new_array); }
    }

Nick Cullen's avatar
Nick Cullen committed
796
797
798
799
800
801
802
803
804
805
806
807
    /// Optional `order` parameter omitted, to be added as needed.
    array reshape(ShapeContainer new_shape) {
        detail::npy_api::PyArray_Dims d
            = {reinterpret_cast<Py_intptr_t *>(new_shape->data()), int(new_shape->size())};
        auto new_array
            = reinterpret_steal<array>(detail::npy_api::get().PyArray_Newshape_(m_ptr, &d, 0));
        if (!new_array) {
            throw error_already_set();
        }
        return new_array;
    }

Nick Cullen's avatar
Nick Cullen committed
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
    /// Create a view of an array in a different data type.
    /// This function may fundamentally reinterpret the data in the array.
    /// It is the responsibility of the caller to ensure that this is safe.
    /// Only supports the `dtype` argument, the `type` argument is omitted,
    /// to be added as needed.
    array view(const std::string &dtype) {
        auto &api = detail::npy_api::get();
        auto new_view = reinterpret_steal<array>(api.PyArray_View_(
            m_ptr, dtype::from_args(pybind11::str(dtype)).release().ptr(), nullptr));
        if (!new_view) {
            throw error_already_set();
        }
        return new_view;
    }

823
    /// Ensure that the argument is a NumPy array
824
825
826
827
828
829
    /// In case of an error, nullptr is returned and the Python error is cleared.
    static array ensure(handle h, int ExtraFlags = 0) {
        auto result = reinterpret_steal<array>(raw_array(h.ptr(), ExtraFlags));
        if (!result)
            PyErr_Clear();
        return result;
830
831
    }

832
protected:
833
834
    template<typename, typename> friend struct detail::npy_format_descriptor;

835
    void fail_dim_check(ssize_t dim, const std::string& msg) const {
836
837
838
839
        throw index_error(msg + ": " + std::to_string(dim) +
                          " (ndim = " + std::to_string(ndim()) + ")");
    }

840
    template<typename... Ix> ssize_t byte_offset(Ix... index) const {
841
        check_dimensions(index...);
842
        return detail::byte_offset_unsafe(strides(), ssize_t(index)...);
843
844
    }

845
846
    void check_writeable() const {
        if (!writeable())
847
            throw std::domain_error("array is not writeable");
848
    }
849

850
    template<typename... Ix> void check_dimensions(Ix... index) const {
851
        check_dimensions_impl(ssize_t(0), shape(), ssize_t(index)...);
852
853
    }

854
    void check_dimensions_impl(ssize_t, const ssize_t*) const { }
855

856
    template<typename... Ix> void check_dimensions_impl(ssize_t axis, const ssize_t* shape, ssize_t i, Ix... index) const {
857
858
859
860
861
862
863
        if (i >= *shape) {
            throw index_error(std::string("index ") + std::to_string(i) +
                              " is out of bounds for axis " + std::to_string(axis) +
                              " with size " + std::to_string(*shape));
        }
        check_dimensions_impl(axis + 1, shape + 1, index...);
    }
864
865
866

    /// Create array from any object -- always returns a new reference
    static PyObject *raw_array(PyObject *ptr, int ExtraFlags = 0) {
867
868
        if (ptr == nullptr) {
            PyErr_SetString(PyExc_ValueError, "cannot create a pybind11::array from a nullptr");
869
            return nullptr;
870
        }
871
        return detail::npy_api::get().PyArray_FromAny_(
872
            ptr, nullptr, 0, 0, detail::npy_api::NPY_ARRAY_ENSUREARRAY_ | ExtraFlags, nullptr);
873
    }
Wenzel Jakob's avatar
Wenzel Jakob committed
874
875
};

876
template <typename T, int ExtraFlags = array::forcecast> class array_t : public array {
877
878
879
880
881
private:
    struct private_ctor {};
    // Delegating constructor needed when both moving and accessing in the same constructor
    array_t(private_ctor, ShapeContainer &&shape, StridesContainer &&strides, const T *ptr, handle base)
        : array(std::move(shape), std::move(strides), ptr, base) {}
Wenzel Jakob's avatar
Wenzel Jakob committed
882
public:
883
884
    static_assert(!detail::array_info<T>::is_array, "Array types cannot be used with array_t");

885
886
    using value_type = T;

887
    array_t() : array(0, static_cast<const T *>(nullptr)) {}
888
889
    array_t(handle h, borrowed_t) : array(h, borrowed_t{}) { }
    array_t(handle h, stolen_t) : array(h, stolen_t{}) { }
890

891
    PYBIND11_DEPRECATED("Use array_t<T>::ensure() instead")
892
    array_t(handle h, bool is_borrowed) : array(raw_array_t(h.ptr()), stolen_t{}) {
893
894
895
        if (!m_ptr) PyErr_Clear();
        if (!is_borrowed) Py_XDECREF(h.ptr());
    }
896

897
    // NOLINTNEXTLINE(google-explicit-constructor)
898
    array_t(const object &o) : array(raw_array_t(o.ptr()), stolen_t{}) {
899
900
        if (!m_ptr) throw error_already_set();
    }
901

902
    explicit array_t(const buffer_info& info, handle base = handle()) : array(info, base) { }
903

904
905
    array_t(ShapeContainer shape, StridesContainer strides, const T *ptr = nullptr, handle base = handle())
        : array(std::move(shape), std::move(strides), ptr, base) { }
906

907
    explicit array_t(ShapeContainer shape, const T *ptr = nullptr, handle base = handle())
908
909
910
911
912
913
        : array_t(private_ctor{},
                  std::move(shape),
                  (ExtraFlags & f_style) != 0 ? detail::f_strides(*shape, itemsize())
                                              : detail::c_strides(*shape, itemsize()),
                  ptr,
                  base) {}
914

915
    explicit array_t(ssize_t count, const T *ptr = nullptr, handle base = handle())
916
917
        : array({count}, {}, ptr, base) { }

918
    constexpr ssize_t itemsize() const {
919
        return sizeof(T);
920
921
    }

922
    template<typename... Ix> ssize_t index_at(Ix... index) const {
923
        return offset_at(index...) / itemsize();
924
925
    }

926
    template<typename... Ix> const T* data(Ix... index) const {
927
928
929
        return static_cast<const T*>(array::data(index...));
    }

930
    template<typename... Ix> T* mutable_data(Ix... index) {
931
932
933
934
        return static_cast<T*>(array::mutable_data(index...));
    }

    // Reference to element at a given index
935
    template<typename... Ix> const T& at(Ix... index) const {
936
        if ((ssize_t) sizeof...(index) != ndim())
937
            fail_dim_check(sizeof...(index), "index dimension mismatch");
938
        return *(static_cast<const T*>(array::data()) + byte_offset(ssize_t(index)...) / itemsize());
939
940
941
    }

    // Mutable reference to element at a given index
942
    template<typename... Ix> T& mutable_at(Ix... index) {
943
        if ((ssize_t) sizeof...(index) != ndim())
944
            fail_dim_check(sizeof...(index), "index dimension mismatch");
945
        return *(static_cast<T*>(array::mutable_data()) + byte_offset(ssize_t(index)...) / itemsize());
946
    }
947

948
949
    /**
     * Returns a proxy object that provides access to the array's data without bounds or
950
951
952
953
     * dimensionality checking.  Will throw if the array is missing the `writeable` flag.  Use with
     * care: the array must not be destroyed or reshaped for the duration of the returned object,
     * and the caller must take care not to access invalid dimensions or dimension indices.
     */
954
    template <ssize_t Dims = -1> detail::unchecked_mutable_reference<T, Dims> mutable_unchecked() & {
955
956
957
        return array::mutable_unchecked<T, Dims>();
    }

958
959
    /**
     * Returns a proxy object that provides const access to the array's data without bounds or
960
961
962
963
964
     * dimensionality checking.  Unlike `unchecked()`, this does not require that the underlying
     * array have the `writable` flag.  Use with care: the array must not be destroyed or reshaped
     * for the duration of the returned object, and the caller must take care not to access invalid
     * dimensions or dimension indices.
     */
965
    template <ssize_t Dims = -1> detail::unchecked_reference<T, Dims> unchecked() const & {
966
967
968
        return array::unchecked<T, Dims>();
    }

Jason Rhinelander's avatar
Jason Rhinelander committed
969
970
    /// Ensure that the argument is a NumPy array of the correct dtype (and if not, try to convert
    /// it).  In case of an error, nullptr is returned and the Python error is cleared.
971
972
    static array_t ensure(handle h) {
        auto result = reinterpret_steal<array_t>(raw_array_t(h.ptr()));
973
974
        if (!result)
            PyErr_Clear();
975
        return result;
Wenzel Jakob's avatar
Wenzel Jakob committed
976
    }
977

Wenzel Jakob's avatar
Wenzel Jakob committed
978
    static bool check_(handle h) {
979
980
        const auto &api = detail::npy_api::get();
        return api.PyArray_Check_(h.ptr())
981
982
               && api.PyArray_EquivTypes_(detail::array_proxy(h.ptr())->descr, dtype::of<T>().ptr())
               && detail::check_flags(h.ptr(), ExtraFlags & (array::c_style | array::f_style));
983
984
985
986
987
    }

protected:
    /// Create array from any object -- always returns a new reference
    static PyObject *raw_array_t(PyObject *ptr) {
988
989
        if (ptr == nullptr) {
            PyErr_SetString(PyExc_ValueError, "cannot create a pybind11::array_t from a nullptr");
990
            return nullptr;
991
        }
992
993
        return detail::npy_api::get().PyArray_FromAny_(
            ptr, dtype::of<T>().release().ptr(), 0, 0,
994
            detail::npy_api::NPY_ARRAY_ENSUREARRAY_ | ExtraFlags, nullptr);
995
    }
Wenzel Jakob's avatar
Wenzel Jakob committed
996
997
};

998
template <typename T>
999
struct format_descriptor<T, detail::enable_if_t<detail::is_pod_struct<T>::value>> {
1000
1001
1002
    static std::string format() {
        return detail::npy_format_descriptor<typename std::remove_cv<T>::type>::format();
    }
1003
1004
1005
};

/// Fixed-size char buffer ``char[N]`` is described as an N-byte string, e.g. "16s".
template <size_t N> struct format_descriptor<char[N]> {
    static std::string format() {
        std::string spec = std::to_string(N);
        spec += 's';
        return spec;
    }
};
/// ``std::array<char, N>`` shares the "<N>s" byte-string format of ``char[N]``.
template <size_t N> struct format_descriptor<std::array<char, N>> {
    static std::string format() {
        std::string spec = std::to_string(N);
        spec += 's';
        return spec;
    }
};

1012
1013
1014
1015
1016
1017
1018
1019
/// Enumerations use the format of their (cv-stripped) underlying integer type.
template <typename T>
struct format_descriptor<T, detail::enable_if_t<std::is_enum<T>::value>> {
    static std::string format() {
        using underlying_t = typename std::underlying_type<T>::type;
        using plain_t = typename std::remove_cv<underlying_t>::type;
        return format_descriptor<plain_t>::format();
    }
};

1020
1021
1022
template <typename T>
struct format_descriptor<T, detail::enable_if_t<detail::array_info<T>::is_array>> {
    static std::string format() {
1023
1024
1025
        using namespace detail;
        static constexpr auto extents = _("(") + array_info<T>::extents + _(")");
        return extents.text + format_descriptor<remove_all_extents_t<T>>::format();
1026
1027
1028
    }
};

1029
PYBIND11_NAMESPACE_BEGIN(detail)
1030
1031
1032
1033
/// Type caster for ``array_t<T, ExtraFlags>`` arguments and return values.
template <typename T, int ExtraFlags>
struct pyobject_caster<array_t<T, ExtraFlags>> {
    using type = array_t<T, ExtraFlags>;

    /// Without conversion only arrays that already satisfy check_() are accepted;
    /// with conversion anything ensure() can turn into a matching array is.
    bool load(handle src, bool convert) {
        if (convert || type::check_(src)) {
            value = type::ensure(src);
            return static_cast<bool>(value);
        }
        return false;
    }

    /// Returning an array_t hands out a new reference to the same Python object.
    static handle cast(const handle &src, return_value_policy /* policy */, handle /* parent */) {
        return src.inc_ref();
    }

    PYBIND11_TYPE_CASTER(type, handle_type_name<type>::name);
};

1047
1048
1049
1050
1051
1052
1053
/// A buffer matches a registered POD struct type when NumPy considers the buffer's
/// dtype equivalent to the dtype registered for T.
template <typename T>
struct compare_buffer_info<T, detail::enable_if_t<detail::is_pod_struct<T>::value>> {
    static bool compare(const buffer_info& b) {
        const auto expected = dtype::of<T>();
        const auto actual = dtype(b);
        return npy_api::get().PyArray_EquivTypes_(expected.ptr(), actual.ptr());
    }
};

1054
1055
1056
1057
1058
1059
// Compile-time Python-side type names for numeric types; specialized per category below.
template <typename T, typename = void>
struct npy_format_descriptor_name;

// bool -> "bool"; other integers -> "numpy.int<bits>" / "numpy.uint<bits>".
template <typename T>
struct npy_format_descriptor_name<T, enable_if_t<std::is_integral<T>::value>> {
    static constexpr auto name = _<std::is_same<T, bool>::value>(
        _("bool"),
        _<std::is_signed<T>::value>("numpy.int", "numpy.uint") + _<sizeof(T) * 8>());
};

// float/double (optionally const) -> "numpy.float<bits>"; any other floating type
// (e.g. long double) -> "numpy.longdouble".
template <typename T>
struct npy_format_descriptor_name<T, enable_if_t<std::is_floating_point<T>::value>> {
    static constexpr auto name =
        _<std::is_same<typename std::remove_const<T>::type, float>::value
          || std::is_same<typename std::remove_const<T>::type, double>::value>(
            _("numpy.float") + _<sizeof(T) * 8>(), _("numpy.longdouble"));
};

// Complex of float/double (optionally const) -> "numpy.complex<bits>" (bits counting
// both parts); anything else -> "numpy.longcomplex".
template <typename T>
struct npy_format_descriptor_name<T, enable_if_t<is_complex<T>::value>> {
    static constexpr auto name =
        _<std::is_same<typename std::remove_const<typename T::value_type>::type, float>::value
          || std::is_same<typename std::remove_const<typename T::value_type>::type, double>::value>(
            _("numpy.complex") + _<sizeof(typename T::value_type) * 16>(), _("numpy.longcomplex"));
};

/// NumPy descriptor for arithmetic and complex types: maps T to a NumPy type number and
/// builds its dtype with PyArray_DescrFromType_.
template <typename T>
struct npy_format_descriptor<T, enable_if_t<satisfies_any_of<T, std::is_arithmetic, is_complex>::value>>
    : npy_format_descriptor_name<T> {
private:
    // NumPy type numbers indexed by detail::is_fmt_numeric<T>::index.
    // NB: the order here must match the one in common.h
    static constexpr const int values[15] = {
        npy_api::NPY_BOOL_,
        npy_api::NPY_BYTE_,   npy_api::NPY_UBYTE_,   npy_api::NPY_INT16_,    npy_api::NPY_UINT16_,
        npy_api::NPY_INT32_,  npy_api::NPY_UINT32_,  npy_api::NPY_INT64_,    npy_api::NPY_UINT64_,
        npy_api::NPY_FLOAT_,  npy_api::NPY_DOUBLE_,  npy_api::NPY_LONGDOUBLE_,
        npy_api::NPY_CFLOAT_, npy_api::NPY_CDOUBLE_, npy_api::NPY_CLONGDOUBLE_
    };

public:
    static constexpr int value = values[detail::is_fmt_numeric<T>::index];

    /// New reference to the builtin dtype for `value`; fails hard if NumPy returns null.
    static pybind11::dtype dtype() {
        PyObject *descr = npy_api::get().PyArray_DescrFromType_(value);
        if (descr == nullptr)
            pybind11_fail("Unsupported buffer format!");
        return reinterpret_steal<pybind11::dtype>(descr);
    }
};
1106
1107

#define PYBIND11_DECL_CHAR_FMT \
1108
    static constexpr auto name = _("S") + _<N>(); \
1109
    static pybind11::dtype dtype() { return pybind11::dtype(std::string("S") + std::to_string(N)); }
1110
1111
1112
template <size_t N> struct npy_format_descriptor<char[N]> { PYBIND11_DECL_CHAR_FMT };
template <size_t N> struct npy_format_descriptor<std::array<char, N>> { PYBIND11_DECL_CHAR_FMT };
#undef PYBIND11_DECL_CHAR_FMT
1113

1114
1115
1116
1117
1118
1119
/// NumPy descriptor for C arrays: a sub-array dtype built from the element descriptor
/// and the array's extents.
template<typename T> struct npy_format_descriptor<T, enable_if_t<array_info<T>::is_array>> {
private:
    using base_descr = npy_format_descriptor<typename array_info<T>::type>;
public:
    static_assert(!array_info<T>::is_empty, "Zero-sized arrays are not supported");

    static constexpr auto name = _("(") + array_info<T>::extents + _(")") + base_descr::name;

    /// Builds dtype((element_dtype, [extents...])).
    static pybind11::dtype dtype() {
        list extents;
        array_info<T>::append_extents(extents);
        auto subarray_spec = pybind11::make_tuple(base_descr::dtype(), extents);
        return pybind11::dtype::from_args(subarray_spec);
    }
};

1128
1129
1130
1131
/// Enumerations reuse the NumPy descriptor (name and dtype) of their underlying type.
template<typename T> struct npy_format_descriptor<T, enable_if_t<std::is_enum<T>::value>> {
private:
    using underlying_t = typename std::underlying_type<T>::type;
    using base_descr = npy_format_descriptor<underlying_t>;
public:
    static constexpr auto name = base_descr::name;
    static pybind11::dtype dtype() {
        return base_descr::dtype();
    }
};

1136
1137
struct field_descriptor {
    const char *name;
1138
1139
    ssize_t offset;
    ssize_t size;
1140
    std::string format;
1141
    dtype descr;
1142
1143
};

1144
/// Register a structured NumPy dtype for the C++ type identified by ``tinfo``.
///
/// Builds the dtype from ``fields`` (sorted by offset), generates a matching buffer
/// format string, verifies that NumPy parses that string into an equivalent dtype, and
/// stores both in the NumPy internals registry.  ``direct_converter`` is added to the
/// type's direct conversions.  Fails (via pybind11_fail) if the type is already
/// registered, a field has no dtype, or the generated format string does not round-trip.
PYBIND11_NOINLINE void register_structured_dtype(
    any_container<field_descriptor> fields,
    const std::type_info& tinfo, ssize_t itemsize,
    bool (*direct_converter)(PyObject *, void *&)) {

    auto& numpy_internals = get_numpy_internals();
    if (numpy_internals.get_type_info(tinfo, false))
        pybind11_fail("NumPy: dtype is already registered");

    // Use ordered fields because order matters as of NumPy 1.14:
    // https://docs.scipy.org/doc/numpy/release.html#multiple-field-indexing-assignment-of-structured-arrays
    std::vector<field_descriptor> ordered_fields(std::move(fields));
    std::sort(ordered_fields.begin(), ordered_fields.end(),
        [](const field_descriptor &a, const field_descriptor &b) { return a.offset < b.offset; });

    list names, formats, offsets;
    for (auto& field : ordered_fields) {
        if (!field.descr)
            pybind11_fail(std::string("NumPy: unsupported field dtype: `") +
                            field.name + "` @ " + tinfo.name());
        names.append(PYBIND11_STR_TYPE(field.name));
        formats.append(field.descr);
        offsets.append(pybind11::int_(field.offset));
    }
    // Keep the new dtype owned until registration is certain; releasing it up front would
    // leak the reference if the sanity check below fails.
    auto dtype_obj = pybind11::dtype(std::move(names), std::move(formats), std::move(offsets), itemsize);

    // There is an existing bug in NumPy (as of v1.11): trailing bytes are
    // not encoded explicitly into the format string. This will supposedly
    // get fixed in v1.12; for further details, see these:
    // - https://github.com/numpy/numpy/issues/7797
    // - https://github.com/numpy/numpy/pull/7798
    // Because of this, we won't use numpy's logic to generate buffer format
    // strings and will just do it ourselves.
    ssize_t offset = 0;
    std::ostringstream oss;
    // mark the structure as unaligned with '^', because numpy and C++ don't
    // always agree about alignment (particularly for complex), and we're
    // explicitly listing all our padding. This depends on none of the fields
    // overriding the endianness. Putting the ^ in front of individual fields
    // isn't guaranteed to work due to https://github.com/numpy/numpy/issues/9049
    oss << "^T{";
    for (auto& field : ordered_fields) {
        if (field.offset > offset)
            oss << (field.offset - offset) << 'x';  // explicit padding before the field
        oss << field.format << ':' << field.name << ':';
        offset = field.offset + field.size;
    }
    if (itemsize > offset)
        oss << (itemsize - offset) << 'x';  // explicit trailing padding
    oss << '}';
    auto format_str = oss.str();

    // Sanity check: verify that NumPy properly parses our buffer format string
    auto& api = npy_api::get();
    auto arr = array(buffer_info(nullptr, itemsize, format_str, 1));
    if (!api.PyArray_EquivTypes_(dtype_obj.ptr(), arr.dtype().ptr()))
        pybind11_fail("NumPy: invalid buffer descriptor!");

    // The registry keeps the dtype reference for the rest of the process.
    auto tindex = std::type_index(tinfo);
    numpy_internals.registered_dtypes[tindex] = { dtype_obj.release().ptr(), format_str };
    get_internals().direct_conversions[tindex].push_back(direct_converter);
}

1210
1211
1212
template <typename T, typename SFINAE> struct npy_format_descriptor {
    static_assert(is_pod_struct<T>::value, "Attempt to use a non-POD or unimplemented POD type as a numpy dtype");

1213
    static constexpr auto name = make_caster<T>::name;
1214

1215
    static pybind11::dtype dtype() {
1216
        return reinterpret_borrow<pybind11::dtype>(dtype_ptr());
1217
1218
    }

1219
    static std::string format() {
1220
        static auto format_str = get_numpy_internals().get_type_info<T>(true)->format_str;
1221
        return format_str;
1222
1223
    }

1224
1225
    static void register_dtype(any_container<field_descriptor> fields) {
        register_structured_dtype(std::move(fields), typeid(typename std::remove_cv<T>::type),
1226
                                  sizeof(T), &direct_converter);
1227
1228
1229
    }

private:
1230
1231
1232
1233
    static PyObject* dtype_ptr() {
        static PyObject* ptr = get_numpy_internals().get_type_info<T>(true)->dtype_ptr;
        return ptr;
    }
1234

1235
1236
1237
    static bool direct_converter(PyObject *obj, void*& value) {
        auto& api = npy_api::get();
        if (!PyObject_TypeCheck(obj, api.PyVoidArrType_Type_))
1238
            return false;
1239
        if (auto descr = reinterpret_steal<object>(api.PyArray_DescrFromScalar_(obj))) {
1240
            if (api.PyArray_EquivTypes_(dtype_ptr(), descr.ptr())) {
1241
1242
1243
1244
1245
1246
                value = ((PyVoidScalarObject_Proxy *) obj)->obval;
                return true;
            }
        }
        return false;
    }
1247
1248
};

1249
1250
1251
1252
1253
#ifdef __CLION_IDE__ // replace heavy macro with dummy code for the IDE (doesn't affect code)
# define PYBIND11_NUMPY_DTYPE(Type, ...) ((void)0)
# define PYBIND11_NUMPY_DTYPE_EX(Type, ...) ((void)0)
#else

// Builds a field_descriptor for member `Field` of struct `T`, exposed under the name `Name`:
// its offset, size, buffer format string and NumPy dtype.
#define PYBIND11_FIELD_DESCRIPTOR_EX(T, Field, Name)                                          \
    ::pybind11::detail::field_descriptor {                                                    \
        Name, offsetof(T, Field), sizeof(decltype(std::declval<T>().Field)),                  \
        ::pybind11::format_descriptor<decltype(std::declval<T>().Field)>::format(),           \
        ::pybind11::detail::npy_format_descriptor<decltype(std::declval<T>().Field)>::dtype() \
    }

// Extract name, offset and format descriptor for a struct field
#define PYBIND11_FIELD_DESCRIPTOR(T, Field) PYBIND11_FIELD_DESCRIPTOR_EX(T, Field, #Field)

// The main idea of this macro is borrowed from https://github.com/swansontec/map-macro
// (C) William Swanson, Paul Fultz
#define PYBIND11_EVAL0(...) __VA_ARGS__
#define PYBIND11_EVAL1(...) PYBIND11_EVAL0 (PYBIND11_EVAL0 (PYBIND11_EVAL0 (__VA_ARGS__)))
#define PYBIND11_EVAL2(...) PYBIND11_EVAL1 (PYBIND11_EVAL1 (PYBIND11_EVAL1 (__VA_ARGS__)))
#define PYBIND11_EVAL3(...) PYBIND11_EVAL2 (PYBIND11_EVAL2 (PYBIND11_EVAL2 (__VA_ARGS__)))
#define PYBIND11_EVAL4(...) PYBIND11_EVAL3 (PYBIND11_EVAL3 (PYBIND11_EVAL3 (__VA_ARGS__)))
#define PYBIND11_EVAL(...)  PYBIND11_EVAL4 (PYBIND11_EVAL4 (PYBIND11_EVAL4 (__VA_ARGS__)))
#define PYBIND11_MAP_END(...)
#define PYBIND11_MAP_OUT
#define PYBIND11_MAP_COMMA ,
#define PYBIND11_MAP_GET_END() 0, PYBIND11_MAP_END
#define PYBIND11_MAP_NEXT0(test, next, ...) next PYBIND11_MAP_OUT
#define PYBIND11_MAP_NEXT1(test, next) PYBIND11_MAP_NEXT0 (test, next, 0)
#define PYBIND11_MAP_NEXT(test, next)  PYBIND11_MAP_NEXT1 (PYBIND11_MAP_GET_END test, next)
#if defined(_MSC_VER) && !defined(__clang__) // MSVC is not as eager to expand macros, hence this workaround
#define PYBIND11_MAP_LIST_NEXT1(test, next) \
    PYBIND11_EVAL0 (PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0))
#else
#define PYBIND11_MAP_LIST_NEXT1(test, next) \
    PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0)
#endif
#define PYBIND11_MAP_LIST_NEXT(test, next) \
    PYBIND11_MAP_LIST_NEXT1 (PYBIND11_MAP_GET_END test, next)
#define PYBIND11_MAP_LIST0(f, t, x, peek, ...) \
    f(t, x) PYBIND11_MAP_LIST_NEXT (peek, PYBIND11_MAP_LIST1) (f, t, peek, __VA_ARGS__)
#define PYBIND11_MAP_LIST1(f, t, x, peek, ...) \
    f(t, x) PYBIND11_MAP_LIST_NEXT (peek, PYBIND11_MAP_LIST0) (f, t, peek, __VA_ARGS__)
// PYBIND11_MAP_LIST(f, t, a1, a2, ...) expands to f(t, a1), f(t, a2), ...
#define PYBIND11_MAP_LIST(f, t, ...) \
    PYBIND11_EVAL (PYBIND11_MAP_LIST1 (f, t, __VA_ARGS__, (), 0))

// Registers struct `Type` as a NumPy structured dtype whose fields are the listed members.
#define PYBIND11_NUMPY_DTYPE(Type, ...) \
    ::pybind11::detail::npy_format_descriptor<Type>::register_dtype \
        (::std::vector<::pybind11::detail::field_descriptor> \
         {PYBIND11_MAP_LIST (PYBIND11_FIELD_DESCRIPTOR, Type, __VA_ARGS__)})

#if defined(_MSC_VER) && !defined(__clang__)
#define PYBIND11_MAP2_LIST_NEXT1(test, next) \
    PYBIND11_EVAL0 (PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0))
#else
#define PYBIND11_MAP2_LIST_NEXT1(test, next) \
    PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0)
#endif
#define PYBIND11_MAP2_LIST_NEXT(test, next) \
    PYBIND11_MAP2_LIST_NEXT1 (PYBIND11_MAP_GET_END test, next)
#define PYBIND11_MAP2_LIST0(f, t, x1, x2, peek, ...) \
    f(t, x1, x2) PYBIND11_MAP2_LIST_NEXT (peek, PYBIND11_MAP2_LIST1) (f, t, peek, __VA_ARGS__)
#define PYBIND11_MAP2_LIST1(f, t, x1, x2, peek, ...) \
    f(t, x1, x2) PYBIND11_MAP2_LIST_NEXT (peek, PYBIND11_MAP2_LIST0) (f, t, peek, __VA_ARGS__)
// PYBIND11_MAP2_LIST(f, t, a1, a2, ...) expands to f(t, a1, a2), f(t, a3, a4), ...
#define PYBIND11_MAP2_LIST(f, t, ...) \
    PYBIND11_EVAL (PYBIND11_MAP2_LIST1 (f, t, __VA_ARGS__, (), 0))

// Like PYBIND11_NUMPY_DTYPE, but takes (member, "python_name") pairs.
#define PYBIND11_NUMPY_DTYPE_EX(Type, ...) \
    ::pybind11::detail::npy_format_descriptor<Type>::register_dtype \
        (::std::vector<::pybind11::detail::field_descriptor> \
         {PYBIND11_MAP2_LIST (PYBIND11_FIELD_DESCRIPTOR_EX, Type, __VA_ARGS__)})

#endif // __CLION_IDE__

// Byte pointer into one buffer that can be advanced along any dimension of an iteration space
// visited in row-major order.
class common_iterator {
public:
    using container_type = std::vector<ssize_t>;
    using value_type = container_type::value_type;
    using size_type = container_type::size_type;

    common_iterator() : m_strides() {}

    // `strides` (in bytes) and `shape` must have the same length.  m_strides[d] is precomputed
    // as the byte offset to apply when dimension d advances by one while every later dimension
    // wraps from its last index back to 0.
    common_iterator(void* ptr, const container_type& strides, const container_type& shape)
        : p_ptr(reinterpret_cast<char*>(ptr)), m_strides(strides.size()) {
        // A 0-d iteration space has no dimension to advance; calling back() on the empty
        // vector below would be undefined behavior.
        if (m_strides.empty())
            return;
        m_strides.back() = static_cast<value_type>(strides.back());
        for (size_type i = m_strides.size() - 1; i != 0; --i) {
            size_type j = i - 1;
            auto s = static_cast<value_type>(shape[i]);
            // Step along j, minus the distance already travelled along i (and beyond).
            m_strides[j] = strides[j] + m_strides[i] - strides[i] * s;
        }
    }

    // Advances the pointer by the precomputed carry stride of dimension `dim`.
    void increment(size_type dim) {
        p_ptr += m_strides[dim];
    }

    void* data() const {
        return p_ptr;
    }

private:
    char *p_ptr{nullptr};
    container_type m_strides;
};

template <size_t N> class multi_array_iterator {
1357
public:
1358
    using container_type = std::vector<ssize_t>;
1359

1360
    multi_array_iterator(const std::array<buffer_info, N> &buffers,
1361
                         const container_type &shape)
1362
1363
1364
        : m_shape(shape.size()), m_index(shape.size(), 0),
          m_common_iterator() {

1365
        // Manual copy to avoid conversion warning if using std::copy
1366
        for (size_t i = 0; i < shape.size(); ++i)
1367
            m_shape[i] = shape[i];
1368

1369
        container_type strides(shape.size());
1370
        for (size_t i = 0; i < N; ++i)
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
            init_common_iterator(buffers[i], shape, m_common_iterator[i], strides);
    }

    multi_array_iterator& operator++() {
        for (size_t j = m_index.size(); j != 0; --j) {
            size_t i = j - 1;
            if (++m_index[i] != m_shape[i]) {
                increment_common_iterator(i);
                break;
            }
1381
            m_index[i] = 0;
1382
1383
1384
1385
        }
        return *this;
    }

1386
1387
    template <size_t K, class T = void> T* data() const {
        return reinterpret_cast<T*>(m_common_iterator[K].data());
1388
1389
1390
1391
1392
1393
    }

private:

    using common_iter = common_iterator;

1394
    void init_common_iterator(const buffer_info &buffer,
1395
1396
1397
                              const container_type &shape,
                              common_iter &iterator,
                              container_type &strides) {
1398
1399
1400
1401
1402
1403
1404
        auto buffer_shape_iter = buffer.shape.rbegin();
        auto buffer_strides_iter = buffer.strides.rbegin();
        auto shape_iter = shape.rbegin();
        auto strides_iter = strides.rbegin();

        while (buffer_shape_iter != buffer.shape.rend()) {
            if (*shape_iter == *buffer_shape_iter)
1405
                *strides_iter = *buffer_strides_iter;
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
            else
                *strides_iter = 0;

            ++buffer_shape_iter;
            ++buffer_strides_iter;
            ++shape_iter;
            ++strides_iter;
        }

        std::fill(strides_iter, strides.rend(), 0);
        iterator = common_iter(buffer.ptr, strides, shape);
    }

    void increment_common_iterator(size_t dim) {
1420
        for (auto &iter : m_common_iterator)
1421
1422
1423
1424
1425
1426
1427
1428
            iter.increment(dim);
    }

    container_type m_shape;
    container_type m_index;
    std::array<common_iter, N> m_common_iterator;
};

1429
1430
1431
1432
1433
1434
// Classification of a set of input buffers, as returned by broadcast().
enum class broadcast_trivial {
    non_trivial, // general case: inputs must be walked with a multi_array_iterator
    c_trivial,   // every input is a singleton or a full-size C-contiguous buffer
    f_trivial    // every input is a singleton or a full-size Fortran-contiguous buffer
};

// Populates the shape and number of dimensions for the set of buffers.  Returns a broadcast_trivial
// enum value indicating whether the broadcast is "trivial"--that is, has each buffer being either a
// singleton or a full-size, C-contiguous (`c_trivial`) or Fortran-contiguous (`f_trivial`) storage
// buffer; returns `non_trivial` otherwise.
1435
template <size_t N>
1436
broadcast_trivial broadcast(const std::array<buffer_info, N> &buffers, ssize_t &ndim, std::vector<ssize_t> &shape) {
1437
    ndim = std::accumulate(buffers.begin(), buffers.end(), ssize_t(0), [](ssize_t res, const buffer_info &buf) {
1438
1439
1440
        return std::max(res, buf.ndim);
    });

1441
    shape.clear();
1442
    shape.resize((size_t) ndim, 1);
1443

1444
1445
    // Figure out the output size, and make sure all input arrays conform (i.e. are either size 1 or
    // the full size).
1446
1447
    for (size_t i = 0; i < N; ++i) {
        auto res_iter = shape.rbegin();
1448
1449
        auto end = buffers[i].shape.rend();
        for (auto shape_iter = buffers[i].shape.rbegin(); shape_iter != end; ++shape_iter, ++res_iter) {
1450
1451
1452
1453
1454
1455
1456
            const auto &dim_size_in = *shape_iter;
            auto &dim_size_out = *res_iter;

            // Each input dimension can either be 1 or `n`, but `n` values must match across buffers
            if (dim_size_out == 1)
                dim_size_out = dim_size_in;
            else if (dim_size_in != 1 && dim_size_in != dim_size_out)
1457
                pybind11_fail("pybind11::vectorize: incompatible size/dimension of inputs!");
1458
1459
        }
    }
1460

1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
    bool trivial_broadcast_c = true;
    bool trivial_broadcast_f = true;
    for (size_t i = 0; i < N && (trivial_broadcast_c || trivial_broadcast_f); ++i) {
        if (buffers[i].size == 1)
            continue;

        // Require the same number of dimensions:
        if (buffers[i].ndim != ndim)
            return broadcast_trivial::non_trivial;

        // Require all dimensions be full-size:
        if (!std::equal(buffers[i].shape.cbegin(), buffers[i].shape.cend(), shape.cbegin()))
            return broadcast_trivial::non_trivial;

        // Check for C contiguity (but only if previous inputs were also C contiguous)
        if (trivial_broadcast_c) {
1477
            ssize_t expect_stride = buffers[i].itemsize;
1478
            auto end = buffers[i].shape.crend();
1479
1480
            for (auto shape_iter = buffers[i].shape.crbegin(), stride_iter = buffers[i].strides.crbegin();
                    trivial_broadcast_c && shape_iter != end; ++shape_iter, ++stride_iter) {
1481
1482
1483
1484
                if (expect_stride == *stride_iter)
                    expect_stride *= *shape_iter;
                else
                    trivial_broadcast_c = false;
1485
            }
1486
        }
1487

1488
1489
        // Check for Fortran contiguity (if previous inputs were also F contiguous)
        if (trivial_broadcast_f) {
1490
            ssize_t expect_stride = buffers[i].itemsize;
1491
            auto end = buffers[i].shape.cend();
1492
1493
            for (auto shape_iter = buffers[i].shape.cbegin(), stride_iter = buffers[i].strides.cbegin();
                    trivial_broadcast_f && shape_iter != end; ++shape_iter, ++stride_iter) {
1494
1495
1496
1497
1498
                if (expect_stride == *stride_iter)
                    expect_stride *= *shape_iter;
                else
                    trivial_broadcast_f = false;
            }
1499
1500
        }
    }
1501
1502
1503
1504
1505

    return
        trivial_broadcast_c ? broadcast_trivial::c_trivial :
        trivial_broadcast_f ? broadcast_trivial::f_trivial :
        broadcast_trivial::non_trivial;
1506
1507
}

1508
1509
1510
1511
1512
1513
1514
template <typename T>
struct vectorize_arg {
    static_assert(!std::is_rvalue_reference<T>::value, "Functions with rvalue reference arguments cannot be vectorized");
    // The wrapped function gets called with this type:
    using call_type = remove_reference_t<T>;
    // Is this a vectorized argument?
    static constexpr bool vectorize =
1515
        satisfies_any_of<call_type, std::is_arithmetic, is_complex, is_pod>::value &&
1516
1517
1518
1519
1520
1521
1522
        satisfies_none_of<call_type, std::is_pointer, std::is_array, is_std_array, std::is_enum>::value &&
        (!std::is_reference<T>::value ||
         (std::is_lvalue_reference<T>::value && std::is_const<call_type>::value));
    // Accept this type: an array for vectorized types, otherwise the type as-is:
    using type = conditional_t<vectorize, array_t<remove_cv_t<call_type>, array::forcecast>, T>;
};

1523
1524
1525
1526
1527
1528
1529
1530
1531

// py::vectorize when a return type is present
template <typename Func, typename Return, typename... Args>
struct vectorize_returned_array {
    using Type = array_t<Return>;

    // Allocates the output array with the broadcast shape.  Fortran ordering is used only when
    // broadcast() classified the inputs as f_trivial; in every other case the default ordering
    // is used.
    static Type create(broadcast_trivial trivial, const std::vector<ssize_t> &shape) {
        if (trivial != broadcast_trivial::f_trivial)
            return array_t<Return>(shape);
        return array_t<Return, array::f_style>(shape);
    }

    // Writable pointer to the first element of the output array.
    static Return *mutable_data(Type &array) {
        Return *first = array.mutable_data();
        return first;
    }

    // All-scalar case: a single call whose result goes back to the caller.
    static Return call(Func &f, Args &... args) {
        return f(args...);
    }

    // Element-wise case: stores the call's result in slot `i` of `out`.
    static void call(Return *out, size_t i, Func &f, Args &... args) {
        Return *slot = out + i;
        *slot = f(args...);
    }
};

// py::vectorize when a return type is not present
template <typename Func, typename... Args>
struct vectorize_returned_array<Func, void, Args...> {
    using Type = none;

    // A void function produces nothing to collect, so the "result" is simply None.
    static Type create(broadcast_trivial, const std::vector<ssize_t> &) {
        return none();
    }

    // There is no output buffer; element-wise calls receive a null pointer and ignore it.
    static void *mutable_data(Type &) {
        return nullptr;
    }

    // All-scalar case: call once and hand back an empty placeholder value.
    static detail::void_type call(Func &f, Args &... args) {
        f(args...);
        return detail::void_type{};
    }

    // Element-wise case: call and discard; the output pointer and index are unused.
    static void call(void *, size_t, Func &f, Args &... args) {
        f(args...);
    }
};

template <typename Func, typename Return, typename... Args>
struct vectorize_helper {
1574
1575
1576
1577
1578

// NVCC for some reason breaks if NVectorized is private
#ifdef __CUDACC__
public:
#else
1579
private:
1580
1581
#endif

1582
    static constexpr size_t N = sizeof...(Args);
1583
1584
1585
    static constexpr size_t NVectorized = constexpr_sum(vectorize_arg<Args>::vectorize...);
    static_assert(NVectorized >= 1,
            "pybind11::vectorize(...) requires a function with at least one vectorizable argument");
1586

1587
public:
1588
1589
1590
1591
1592
    template <typename T,
              // SFINAE to prevent shadowing the copy constructor.
              typename = detail::enable_if_t<
                  !std::is_same<vectorize_helper, typename std::decay<T>::type>::value>>
    explicit vectorize_helper(T &&f) : f(std::forward<T>(f)) {}
Wenzel Jakob's avatar
Wenzel Jakob committed
1593

1594
1595
1596
1597
1598
    object operator()(typename vectorize_arg<Args>::type... args) {
        return run(args...,
                   make_index_sequence<N>(),
                   select_indices<vectorize_arg<Args>::vectorize...>(),
                   make_index_sequence<NVectorized>());
1599
    }
Wenzel Jakob's avatar
Wenzel Jakob committed
1600

1601
1602
1603
private:
    remove_reference_t<Func> f;

1604
1605
1606
1607
    // Internal compiler error in MSVC 19.16.27025.1 (Visual Studio 2017 15.9.4), when compiling with "/permissive-" flag
    // when arg_call_types is manually inlined.
    using arg_call_types = std::tuple<typename vectorize_arg<Args>::call_type...>;
    template <size_t Index> using param_n_t = typename std::tuple_element<Index, arg_call_types>::type;
1608

1609
1610
    using returned_array = vectorize_returned_array<Func, Return, Args...>;

1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
    // Runs a vectorized function given arguments tuple and three index sequences:
    //     - Index is the full set of 0 ... (N-1) argument indices;
    //     - VIndex is the subset of argument indices with vectorized parameters, letting us access
    //       vectorized arguments (anything not in this sequence is passed through)
    //     - BIndex is a incremental sequence (beginning at 0) of the same size as VIndex, so that
    //       we can store vectorized buffer_infos in an array (argument VIndex has its buffer at
    //       index BIndex in the array).
    template <size_t... Index, size_t... VIndex, size_t... BIndex> object run(
            typename vectorize_arg<Args>::type &...args,
            index_sequence<Index...> i_seq, index_sequence<VIndex...> vi_seq, index_sequence<BIndex...> bi_seq) {

        // Pointers to values the function was called with; the vectorized ones set here will start
        // out as array_t<T> pointers, but they will be changed them to T pointers before we make
        // call the wrapped function.  Non-vectorized pointers are left as-is.
        std::array<void *, N> params{{ &args... }};

        // The array of `buffer_info`s of vectorized arguments:
        std::array<buffer_info, NVectorized> buffers{{ reinterpret_cast<array *>(params[VIndex])->request()... }};
Wenzel Jakob's avatar
Wenzel Jakob committed
1629
1630

        /* Determine dimensions parameters of output array */
1631
1632
1633
        ssize_t nd = 0;
        std::vector<ssize_t> shape(0);
        auto trivial = broadcast(buffers, nd, shape);
1634
        auto ndim = (size_t) nd;
1635

1636
1637
1638
1639
1640
1641
        size_t size = std::accumulate(shape.begin(), shape.end(), (size_t) 1, std::multiplies<size_t>());

        // If all arguments are 0-dimension arrays (i.e. single values) return a plain value (i.e.
        // not wrapped in an array).
        if (size == 1 && ndim == 0) {
            PYBIND11_EXPAND_SIDE_EFFECTS(params[VIndex] = buffers[BIndex].ptr);
1642
            return cast(returned_array::call(f, *reinterpret_cast<param_n_t<Index> *>(params[Index])...));
Wenzel Jakob's avatar
Wenzel Jakob committed
1643
1644
        }

1645
        auto result = returned_array::create(trivial, shape);
Wenzel Jakob's avatar
Wenzel Jakob committed
1646

Henry Schreiner's avatar
Henry Schreiner committed
1647
        if (size == 0) return std::move(result);
1648

1649
        /* Call the function */
1650
        auto mutable_data = returned_array::mutable_data(result);
1651
        if (trivial == broadcast_trivial::non_trivial)
1652
            apply_broadcast(buffers, params, mutable_data, size, shape, i_seq, vi_seq, bi_seq);
1653
        else
1654
            apply_trivial(buffers, params, mutable_data, size, i_seq, vi_seq, bi_seq);
1655

Henry Schreiner's avatar
Henry Schreiner committed
1656
        return std::move(result);
1657
    }
1658

1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
    template <size_t... Index, size_t... VIndex, size_t... BIndex>
    void apply_trivial(std::array<buffer_info, NVectorized> &buffers,
                       std::array<void *, N> &params,
                       Return *out,
                       size_t size,
                       index_sequence<Index...>, index_sequence<VIndex...>, index_sequence<BIndex...>) {

        // Initialize an array of mutable byte references and sizes with references set to the
        // appropriate pointer in `params`; as we iterate, we'll increment each pointer by its size
        // (except for singletons, which get an increment of 0).
        std::array<std::pair<unsigned char *&, const size_t>, NVectorized> vecparams{{
            std::pair<unsigned char *&, const size_t>(
                    reinterpret_cast<unsigned char *&>(params[VIndex] = buffers[BIndex].ptr),
                    buffers[BIndex].size == 1 ? 0 : sizeof(param_n_t<VIndex>)
            )...
        }};

        for (size_t i = 0; i < size; ++i) {
1677
            returned_array::call(out, i, f, *reinterpret_cast<param_n_t<Index> *>(params[Index])...);
1678
1679
1680
1681
1682
1683
1684
            for (auto &x : vecparams) x.first += x.second;
        }
    }

    template <size_t... Index, size_t... VIndex, size_t... BIndex>
    void apply_broadcast(std::array<buffer_info, NVectorized> &buffers,
                         std::array<void *, N> &params,
1685
1686
1687
                         Return *out,
                         size_t size,
                         const std::vector<ssize_t> &output_shape,
1688
                         index_sequence<Index...>, index_sequence<VIndex...>, index_sequence<BIndex...>) {
1689

1690
        multi_array_iterator<NVectorized> input_iter(buffers, output_shape);
1691

1692
        for (size_t i = 0; i < size; ++i, ++input_iter) {
1693
1694
1695
            PYBIND11_EXPAND_SIDE_EFFECTS((
                params[VIndex] = input_iter.template data<BIndex>()
            ));
1696
            returned_array::call(out, i, f, *reinterpret_cast<param_n_t<Index> *>(std::get<Index>(params))...);
1697
1698
        }
    }
1699
1700
};

1701
1702
1703
1704
1705
1706
// Builds a vectorize_helper for `f`.  The unnamed null function-pointer argument only carries
// the call signature from which Return and Args are deduced.
template <typename Func, typename Return, typename... Args>
vectorize_helper<Func, Return, Args...>
vectorize_extractor(const Func &f, Return (*) (Args ...)) {
    using helper_type = vectorize_helper<Func, Return, Args...>;
    return helper_type(f);
}

// Signature name of array_t<T, Flags>: "numpy.ndarray[" + element descriptor name + "]".
template <typename T, int Flags>
struct handle_type_name<array_t<T, Flags>> {
    static constexpr auto name =
        _("numpy.ndarray[") +
        npy_format_descriptor<T>::name +
        _("]");
};

PYBIND11_NAMESPACE_END(detail)
Wenzel Jakob's avatar
Wenzel Jakob committed
1712

1713
// Vanilla pointer vectorizer:
1714
template <typename Return, typename... Args>
1715
detail::vectorize_helper<Return (*)(Args...), Return, Args...>
1716
vectorize(Return (*f) (Args ...)) {
1717
    return detail::vectorize_helper<Return (*)(Args...), Return, Args...>(f);
Wenzel Jakob's avatar
Wenzel Jakob committed
1718
1719
}

1720
// lambda vectorizer:
1721
template <typename Func, detail::enable_if_t<detail::is_lambda<Func>::value, int> = 0>
1722
auto vectorize(Func &&f) -> decltype(
1723
1724
        detail::vectorize_extractor(std::forward<Func>(f), (detail::function_signature_t<Func> *) nullptr)) {
    return detail::vectorize_extractor(std::forward<Func>(f), (detail::function_signature_t<Func> *) nullptr);
1725
1726
1727
1728
1729
1730
1731
1732
1733
}

// Vectorize a class method (non-const):
// the object pointer becomes the first argument of the wrapped callable.
template <typename Return, typename Class, typename... Args,
          typename Helper = detail::vectorize_helper<decltype(std::mem_fn(std::declval<Return (Class::*)(Args...)>())), Return, Class *, Args...>>
Helper vectorize(Return (Class::*f)(Args...)) {
    auto bound = std::mem_fn(f);
    return Helper(std::move(bound));
}

1734
// Vectorize a class method (const):
1735
1736
1737
1738
template <typename Return, typename Class, typename... Args,
          typename Helper = detail::vectorize_helper<decltype(std::mem_fn(std::declval<Return (Class::*)(Args...) const>())), Return, const Class *, Args...>>
Helper vectorize(Return (Class::*f)(Args...) const) {
    return Helper(std::mem_fn(f));
Wenzel Jakob's avatar
Wenzel Jakob committed
1739
1740
}

1741
PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)