numpy.h 68.1 KB
Newer Older
Wenzel Jakob's avatar
Wenzel Jakob committed
1
/*
2
    pybind11/numpy.h: Basic NumPy support, vectorize() wrapper
Wenzel Jakob's avatar
Wenzel Jakob committed
3

4
    Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
Wenzel Jakob's avatar
Wenzel Jakob committed
5
6
7
8
9
10
11

    All rights reserved. Use of this source code is governed by a
    BSD-style license that can be found in the LICENSE file.
*/

#pragma once

12
13
#include "pybind11.h"
#include "complex.h"
14
15
#include <numeric>
#include <algorithm>
16
#include <array>
17
#include <cstdint>
18
#include <cstdlib>
19
#include <cstring>
20
#include <sstream>
21
#include <string>
22
#include <functional>
23
#include <type_traits>
24
#include <utility>
25
#include <vector>
26
#include <typeindex>
27

Wenzel Jakob's avatar
Wenzel Jakob committed
28
#if defined(_MSC_VER)
29
30
#  pragma warning(push)
#  pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
Wenzel Jakob's avatar
Wenzel Jakob committed
31
32
#endif

33
/* This will be true on all flat address space platforms and allows us to reduce the
34
   whole npy_intp / ssize_t / Py_intptr_t business down to just ssize_t for all size
35
36
   and dimension types (e.g. shape, strides, indexing), instead of inflicting this
   upon the library user. */
37
38
39
static_assert(sizeof(::pybind11::ssize_t) == sizeof(Py_intptr_t), "ssize_t != Py_intptr_t");
static_assert(std::is_signed<Py_intptr_t>::value, "Py_intptr_t must be signed");
// We now can reinterpret_cast between py::ssize_t and Py_intptr_t (MSVC + PyPy cares)
40

41
PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
42
43
44

class array; // Forward declaration

45
PYBIND11_NAMESPACE_BEGIN(detail)
46
47
48

template <> struct handle_type_name<array> { static constexpr auto name = _("numpy.ndarray"); };

49
template <typename type, typename SFINAE = void> struct npy_format_descriptor;
Wenzel Jakob's avatar
Wenzel Jakob committed
50

51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
// Local mirror of NumPy's dtype (descriptor) struct layout, so this header can
// read descriptor fields without including the NumPy C headers.
// NOTE(review): field semantics are inferred from how this file uses them
// (kind/type/elsize/names are read by `dtype` below) — layout must match
// numpy's `PyArray_Descr`; confirm against ndarraytypes.h when bumping numpy.
struct PyArrayDescr_Proxy {
    PyObject_HEAD
    PyObject *typeobj;
    char kind;        // dtype kind character, e.g. 'f', 'i', 'V' (read by dtype::kind())
    char type;        // dtype type character (read by dtype::char_())
    char byteorder;
    char flags;
    int type_num;
    int elsize;       // item size in bytes (read by dtype::itemsize())
    int alignment;
    char *subarray;
    PyObject *fields;
    PyObject *names;  // non-null for structured dtypes (read by dtype::has_fields())
};

// Local mirror of NumPy's ndarray struct layout (must match `PyArrayObject_fields`);
// lets pybind11 read data pointer, shape, strides and flags without numpy headers.
struct PyArray_Proxy {
    PyObject_HEAD
    char *data;            // pointer to the first element
    int nd;                // number of dimensions
    ssize_t *dimensions;   // shape array of length nd
    ssize_t *strides;      // strides array (bytes) of length nd
    PyObject *base;        // object the data is borrowed from, if any
    PyObject *descr;       // dtype descriptor (a PyArrayDescr_Proxy*)
    int flags;             // NPY_ARRAY_* flag bits
};

77
78
79
80
81
82
83
84
// Local mirror of NumPy's void-scalar object layout (structured dtype scalar);
// must match numpy's `PyVoidScalarObject`.
struct PyVoidScalarObject_Proxy {
    PyObject_VAR_HEAD
    char *obval;               // raw bytes of the scalar value
    PyArrayDescr_Proxy *descr; // dtype describing obval's layout
    int flags;
    PyObject *base;
};

85
86
87
88
89
90
91
92
// Record kept per registered C++ type: the associated NumPy dtype object
// (stored as a raw PyObject*) and its buffer-protocol format string.
struct numpy_type_info {
    PyObject* dtype_ptr;
    std::string format_str;
};

struct numpy_internals {
    std::unordered_map<std::type_index, numpy_type_info> registered_dtypes;

93
94
    numpy_type_info *get_type_info(const std::type_info& tinfo, bool throw_if_missing = true) {
        auto it = registered_dtypes.find(std::type_index(tinfo));
95
96
97
        if (it != registered_dtypes.end())
            return &(it->second);
        if (throw_if_missing)
98
            pybind11_fail(std::string("NumPy type info missing for ") + tinfo.name());
99
100
        return nullptr;
    }
101
102
103
104

    template<typename T> numpy_type_info *get_type_info(bool throw_if_missing = true) {
        return get_type_info(typeid(typename std::remove_cv<T>::type), throw_if_missing);
    }
105
106
};

Ivan Smirnov's avatar
Ivan Smirnov committed
107
108
// Fetch (or create) the process-wide numpy_internals instance from pybind11's
// shared data and cache it in `ptr`. Kept out-of-line (PYBIND11_NOINLINE) so
// the slow path does not bloat get_numpy_internals().
inline PYBIND11_NOINLINE void load_numpy_internals(numpy_internals* &ptr) {
    ptr = &get_or_create_shared_data<numpy_internals>("_numpy_internals");
}

// Accessor for the shared numpy_internals registry; resolves the shared-data
// pointer once per translation unit and caches it in a function-local static.
inline numpy_internals& get_numpy_internals() {
    static numpy_internals* ptr = nullptr;
    if (!ptr)
        load_numpy_internals(ptr);
    return *ptr;
}

118
119
120
121
// Metafunction: same_size<T>::as<U> is true_type iff sizeof(T) == sizeof(U).
// Used below to match C++ fixed-width types against platform integer types.
template <typename T> struct same_size {
    template <typename U> using as = bool_constant<sizeof(T) == sizeof(U)>;
};

// Lookup a type according to its size, and return a value corresponding to the NumPy typenum.
// Base case: no candidate type matched sizeof(Concrete) -> sentinel -1.
template <typename Concrete> constexpr int platform_lookup() { return -1; }

// Recursive case: compare Concrete against candidate T; on a size match return
// the paired value I, otherwise keep scanning the remaining (Ts..., Is...) pairs.
template <typename Concrete, typename T, typename... Ts, typename... Ints>
constexpr int platform_lookup(int I, Ints... Is) {
    return sizeof(Concrete) == sizeof(T) ? I : platform_lookup<Concrete, Ts...>(Is...);
}

130
131
struct npy_api {
    enum constants {
132
133
        NPY_ARRAY_C_CONTIGUOUS_ = 0x0001,
        NPY_ARRAY_F_CONTIGUOUS_ = 0x0002,
134
        NPY_ARRAY_OWNDATA_ = 0x0004,
135
        NPY_ARRAY_FORCECAST_ = 0x0010,
136
        NPY_ARRAY_ENSUREARRAY_ = 0x0040,
137
138
        NPY_ARRAY_ALIGNED_ = 0x0100,
        NPY_ARRAY_WRITEABLE_ = 0x0400,
139
140
141
142
143
144
145
146
147
        NPY_BOOL_ = 0,
        NPY_BYTE_, NPY_UBYTE_,
        NPY_SHORT_, NPY_USHORT_,
        NPY_INT_, NPY_UINT_,
        NPY_LONG_, NPY_ULONG_,
        NPY_LONGLONG_, NPY_ULONGLONG_,
        NPY_FLOAT_, NPY_DOUBLE_, NPY_LONGDOUBLE_,
        NPY_CFLOAT_, NPY_CDOUBLE_, NPY_CLONGDOUBLE_,
        NPY_OBJECT_ = 17,
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
        NPY_STRING_, NPY_UNICODE_, NPY_VOID_,
        // Platform-dependent normalization
        NPY_INT8_ = NPY_BYTE_,
        NPY_UINT8_ = NPY_UBYTE_,
        NPY_INT16_ = NPY_SHORT_,
        NPY_UINT16_ = NPY_USHORT_,
        // `npy_common.h` defines the integer aliases. In order, it checks:
        // NPY_BITSOF_LONG, NPY_BITSOF_LONGLONG, NPY_BITSOF_INT, NPY_BITSOF_SHORT, NPY_BITSOF_CHAR
        // and assigns the alias to the first matching size, so we should check in this order.
        NPY_INT32_ = platform_lookup<std::int32_t, long, int, short>(
            NPY_LONG_, NPY_INT_, NPY_SHORT_),
        NPY_UINT32_ = platform_lookup<std::uint32_t, unsigned long, unsigned int, unsigned short>(
            NPY_ULONG_, NPY_UINT_, NPY_USHORT_),
        NPY_INT64_ = platform_lookup<std::int64_t, long, long long, int>(
            NPY_LONG_, NPY_LONGLONG_, NPY_INT_),
        NPY_UINT64_ = platform_lookup<std::uint64_t, unsigned long, unsigned long long, unsigned int>(
            NPY_ULONG_, NPY_ULONGLONG_, NPY_UINT_),
165
166
    };

uentity's avatar
uentity committed
167
168
169
170
171
    typedef struct {
        Py_intptr_t *ptr;
        int len;
    } PyArray_Dims;

172
173
174
175
176
    static npy_api& get() {
        static npy_api api = lookup();
        return api;
    }

177
178
179
180
181
182
    bool PyArray_Check_(PyObject *obj) const {
        return (bool) PyObject_TypeCheck(obj, PyArray_Type_);
    }
    bool PyArrayDescr_Check_(PyObject *obj) const {
        return (bool) PyObject_TypeCheck(obj, PyArrayDescr_Type_);
    }
183

184
    unsigned int (*PyArray_GetNDArrayCFeatureVersion_)();
185
186
    PyObject *(*PyArray_DescrFromType_)(int);
    PyObject *(*PyArray_NewFromDescr_)
187
188
189
        (PyTypeObject *, PyObject *, int, Py_intptr_t const *,
         Py_intptr_t const *, void *, int, PyObject *);
    // Unused. Not removed because that affects ABI of the class.
190
    PyObject *(*PyArray_DescrNewFromType_)(int);
191
    int (*PyArray_CopyInto_)(PyObject *, PyObject *);
192
193
    PyObject *(*PyArray_NewCopy_)(PyObject *, int);
    PyTypeObject *PyArray_Type_;
194
    PyTypeObject *PyVoidArrType_Type_;
195
    PyTypeObject *PyArrayDescr_Type_;
196
    PyObject *(*PyArray_DescrFromScalar_)(PyObject *);
197
198
199
    PyObject *(*PyArray_FromAny_) (PyObject *, PyObject *, int, int, int, PyObject *);
    int (*PyArray_DescrConverter_) (PyObject *, PyObject **);
    bool (*PyArray_EquivTypes_) (PyObject *, PyObject *);
200
201
    int (*PyArray_GetArrayParamsFromObject_)(PyObject *, PyObject *, unsigned char, PyObject **, int *,
                                             Py_intptr_t *, PyObject **, PyObject *);
202
    PyObject *(*PyArray_Squeeze_)(PyObject *);
203
    // Unused. Not removed because that affects ABI of the class.
Jason Rhinelander's avatar
Jason Rhinelander committed
204
    int (*PyArray_SetBaseObject_)(PyObject *, PyObject *);
uentity's avatar
uentity committed
205
    PyObject* (*PyArray_Resize_)(PyObject*, PyArray_Dims*, int, int);
206
207
private:
    enum functions {
208
        API_PyArray_GetNDArrayCFeatureVersion = 211,
209
        API_PyArray_Type = 2,
210
        API_PyArrayDescr_Type = 3,
211
        API_PyVoidArrType_Type = 39,
212
        API_PyArray_DescrFromType = 45,
213
        API_PyArray_DescrFromScalar = 57,
214
        API_PyArray_FromAny = 69,
uentity's avatar
uentity committed
215
        API_PyArray_Resize = 80,
216
        API_PyArray_CopyInto = 82,
217
218
        API_PyArray_NewCopy = 85,
        API_PyArray_NewFromDescr = 94,
219
        API_PyArray_DescrNewFromType = 96,
220
221
222
        API_PyArray_DescrConverter = 174,
        API_PyArray_EquivTypes = 182,
        API_PyArray_GetArrayParamsFromObject = 278,
Jason Rhinelander's avatar
Jason Rhinelander committed
223
224
        API_PyArray_Squeeze = 136,
        API_PyArray_SetBaseObject = 282
225
226
227
    };

    static npy_api lookup() {
228
        module_ m = module_::import("numpy.core.multiarray");
229
        auto c = m.attr("_ARRAY_API");
230
#if PY_MAJOR_VERSION >= 3
231
        void **api_ptr = (void **) PyCapsule_GetPointer(c.ptr(), NULL);
232
#else
233
        void **api_ptr = (void **) PyCObject_AsVoidPtr(c.ptr());
234
#endif
235
        npy_api api;
236
#define DECL_NPY_API(Func) api.Func##_ = (decltype(api.Func##_)) api_ptr[API_##Func];
237
238
239
        DECL_NPY_API(PyArray_GetNDArrayCFeatureVersion);
        if (api.PyArray_GetNDArrayCFeatureVersion_() < 0x7)
            pybind11_fail("pybind11 numpy support requires numpy >= 1.7.0");
240
        DECL_NPY_API(PyArray_Type);
241
        DECL_NPY_API(PyVoidArrType_Type);
242
        DECL_NPY_API(PyArrayDescr_Type);
243
        DECL_NPY_API(PyArray_DescrFromType);
244
        DECL_NPY_API(PyArray_DescrFromScalar);
245
        DECL_NPY_API(PyArray_FromAny);
uentity's avatar
uentity committed
246
        DECL_NPY_API(PyArray_Resize);
247
        DECL_NPY_API(PyArray_CopyInto);
248
249
250
251
252
253
        DECL_NPY_API(PyArray_NewCopy);
        DECL_NPY_API(PyArray_NewFromDescr);
        DECL_NPY_API(PyArray_DescrNewFromType);
        DECL_NPY_API(PyArray_DescrConverter);
        DECL_NPY_API(PyArray_EquivTypes);
        DECL_NPY_API(PyArray_GetArrayParamsFromObject);
254
        DECL_NPY_API(PyArray_Squeeze);
Jason Rhinelander's avatar
Jason Rhinelander committed
255
        DECL_NPY_API(PyArray_SetBaseObject);
256
#undef DECL_NPY_API
257
258
259
        return api;
    }
};
Wenzel Jakob's avatar
Wenzel Jakob committed
260

261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
// Reinterpret a raw ndarray PyObject pointer as the local layout mirror.
inline PyArray_Proxy* array_proxy(void* ptr) {
    return reinterpret_cast<PyArray_Proxy*>(ptr);
}

// Const overload of the above.
inline const PyArray_Proxy* array_proxy(const void* ptr) {
    return reinterpret_cast<const PyArray_Proxy*>(ptr);
}

// Reinterpret a dtype PyObject pointer as the local descriptor mirror.
inline PyArrayDescr_Proxy* array_descriptor_proxy(PyObject* ptr) {
   return reinterpret_cast<PyArrayDescr_Proxy*>(ptr);
}

// Const overload of the above.
inline const PyArrayDescr_Proxy* array_descriptor_proxy(const PyObject* ptr) {
   return reinterpret_cast<const PyArrayDescr_Proxy*>(ptr);
}

// True iff ALL bits of `flag` are set on the array's flags word.
inline bool check_flags(const void* ptr, int flag) {
    return (flag == (array_proxy(ptr)->flags & flag));
}

281
282
283
284
285
// Trait: true_type iff T is a std::array specialization.
template <typename T> struct is_std_array : std::false_type { };
template <typename T, size_t N> struct is_std_array<std::array<T, N>> : std::true_type { };
// Trait: true_type iff T is a std::complex specialization.
template <typename T> struct is_complex : std::false_type { };
template <typename T> struct is_complex<std::complex<T>> : std::true_type { };

286
template <typename T> struct array_info_scalar {
287
    using type = T;
288
289
    static constexpr bool is_array = false;
    static constexpr bool is_empty = false;
290
    static constexpr auto extents = _("");
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
    static void append_extents(list& /* shape */) { }
};
// Computes underlying type and a comma-separated list of extents for array
// types (any mix of std::array and built-in arrays). An array of char is
// treated as scalar because it gets special handling.
template <typename T> struct array_info : array_info_scalar<T> { };
template <typename T, size_t N> struct array_info<std::array<T, N>> {
    using type = typename array_info<T>::type;
    static constexpr bool is_array = true;
    static constexpr bool is_empty = (N == 0) || array_info<T>::is_empty;
    static constexpr size_t extent = N;

    // appends the extents to shape
    static void append_extents(list& shape) {
        shape.append(N);
        array_info<T>::append_extents(shape);
    }

    // Compile-time "N" or "N,<inner extents>" descriptor string.
    static constexpr auto extents = _<array_info<T>::is_array>(
        concat(_<N>(), array_info<T>::extents), _<N>()
    );
};
// For numpy we have special handling for arrays of characters, so we don't include
// the size in the array extents.
template <size_t N> struct array_info<char[N]> : array_info_scalar<char[N]> { };
template <size_t N> struct array_info<std::array<char, N>> : array_info_scalar<std::array<char, N>> { };
// Built-in arrays delegate to the std::array specialization above.
template <typename T, size_t N> struct array_info<T[N]> : array_info<std::array<T, N>> { };
template <typename T> using remove_all_extents_t = typename array_info<T>::type;

320
// Trait selecting types that can be treated as plain memory by numpy:
// standard-layout, trivially copyable, and not one of the categories that
// gets dedicated handling elsewhere (arithmetic, complex, enum, arrays, refs).
template <typename T> using is_pod_struct = all_of<
    std::is_standard_layout<T>,     // since we're accessing directly in memory we need a standard layout type
#if defined(__GLIBCXX__) && (__GLIBCXX__ < 20150422 || __GLIBCXX__ == 20150623 || __GLIBCXX__ == 20150626 || __GLIBCXX__ == 20160803)
    // libstdc++ < 5 (including versions 4.8.5, 4.9.3 and 4.9.4 which were released after 5)
    // don't implement is_trivially_copyable, so approximate it
    std::is_trivially_destructible<T>,
    satisfies_any_of<T, std::has_trivial_copy_constructor, std::has_trivial_copy_assign>,
#else
    std::is_trivially_copyable<T>,
#endif
    satisfies_none_of<T, std::is_reference, std::is_array, is_std_array, std::is_arithmetic, is_complex, std::is_enum>
>;

333
334
335
336
337
338
// Replacement for std::is_pod (deprecated in C++20)
// POD == standard-layout AND trivial, per the classic definition.
template <typename T> using is_pod = all_of<
    std::is_standard_layout<T>,
    std::is_trivial<T>
>;

339
340
341
342
// Compute a byte offset from per-dimension strides and indices, with no
// bounds checking. Base case: no indices left -> offset 0.
template <ssize_t Dim = 0, typename Strides> ssize_t byte_offset_unsafe(const Strides &) { return 0; }
// Recursive case: index i selects along dimension Dim; remaining indices
// are folded in against the following dimensions.
template <ssize_t Dim = 0, typename Strides, typename... Ix>
ssize_t byte_offset_unsafe(const Strides &strides, ssize_t i, Ix... index) {
    return i * strides[Dim] + byte_offset_unsafe<Dim + 1>(strides, index...);
}

345
346
/**
 * Proxy class providing unsafe, unchecked const access to array data.  This is constructed through
347
348
 * the `unchecked<T, N>()` method of `array` or the `unchecked<N>()` method of `array_t<T>`.  `Dims`
 * will be -1 for dimensions determined at runtime.
349
 */
350
template <typename T, ssize_t Dims>
351
352
class unchecked_reference {
protected:
353
    static constexpr bool Dynamic = Dims < 0;
354
355
    const unsigned char *data_;
    // Storing the shape & strides in local variables (i.e. these arrays) allows the compiler to
356
    // make large performance gains on big, nested loops, but requires compile-time dimensions
357
358
359
    conditional_t<Dynamic, const ssize_t *, std::array<ssize_t, (size_t) Dims>>
            shape_, strides_;
    const ssize_t dims_;
360
361

    friend class pybind11::array;
362
363
    // Constructor for compile-time dimensions:
    template <bool Dyn = Dynamic>
364
    unchecked_reference(const void *data, const ssize_t *shape, const ssize_t *strides, enable_if_t<!Dyn, ssize_t>)
365
    : data_{reinterpret_cast<const unsigned char *>(data)}, dims_{Dims} {
366
        for (size_t i = 0; i < (size_t) dims_; i++) {
367
368
369
370
            shape_[i] = shape[i];
            strides_[i] = strides[i];
        }
    }
371
372
    // Constructor for runtime dimensions:
    template <bool Dyn = Dynamic>
373
    unchecked_reference(const void *data, const ssize_t *shape, const ssize_t *strides, enable_if_t<Dyn, ssize_t> dims)
374
    : data_{reinterpret_cast<const unsigned char *>(data)}, shape_{shape}, strides_{strides}, dims_{dims} {}
375
376

public:
377
378
    /**
     * Unchecked const reference access to data at the given indices.  For a compile-time known
379
380
     * number of dimensions, this requires the correct number of arguments; for run-time
     * dimensionality, this is not checked (and so is up to the caller to use safely).
381
     */
382
    template <typename... Ix> const T &operator()(Ix... index) const {
Jason Rhinelander's avatar
Jason Rhinelander committed
383
        static_assert(ssize_t{sizeof...(Ix)} == Dims || Dynamic,
384
                "Invalid number of indices for unchecked array reference");
385
        return *reinterpret_cast<const T *>(data_ + byte_offset_unsafe(strides_, ssize_t(index)...));
386
    }
387
388
    /**
     * Unchecked const reference access to data; this operator only participates if the reference
389
390
     * is to a 1-dimensional array.  When present, this is exactly equivalent to `obj(index)`.
     */
391
392
    template <ssize_t D = Dims, typename = enable_if_t<D == 1 || Dynamic>>
    const T &operator[](ssize_t index) const { return operator()(index); }
393

394
    /// Pointer access to the data at the given indices.
395
    template <typename... Ix> const T *data(Ix... ix) const { return &operator()(ssize_t(ix)...); }
396
397

    /// Returns the item size, i.e. sizeof(T)
398
    constexpr static ssize_t itemsize() { return sizeof(T); }
399

400
    /// Returns the shape (i.e. size) of dimension `dim`
401
    ssize_t shape(ssize_t dim) const { return shape_[(size_t) dim]; }
402
403

    /// Returns the number of dimensions of the array
404
    ssize_t ndim() const { return dims_; }
405
406
407

    /// Returns the total number of elements in the referenced array, i.e. the product of the shapes
    template <bool Dyn = Dynamic>
408
409
    enable_if_t<!Dyn, ssize_t> size() const {
        return std::accumulate(shape_.begin(), shape_.end(), (ssize_t) 1, std::multiplies<ssize_t>());
410
411
    }
    template <bool Dyn = Dynamic>
412
413
    enable_if_t<Dyn, ssize_t> size() const {
        return std::accumulate(shape_, shape_ + ndim(), (ssize_t) 1, std::multiplies<ssize_t>());
414
415
416
417
    }

    /// Returns the total number of bytes used by the referenced data.  Note that the actual span in
    /// memory may be larger if the referenced array has non-contiguous strides (e.g. for a slice).
418
    ssize_t nbytes() const {
419
420
        return size() * itemsize();
    }
421
422
};

423
template <typename T, ssize_t Dims>
424
425
426
427
class unchecked_mutable_reference : public unchecked_reference<T, Dims> {
    friend class pybind11::array;
    using ConstBase = unchecked_reference<T, Dims>;
    using ConstBase::ConstBase;
428
    using ConstBase::Dynamic;
429
public:
430
431
432
433
    // Bring in const-qualified versions from base class
    using ConstBase::operator();
    using ConstBase::operator[];

434
435
    /// Mutable, unchecked access to data at the given indices.
    template <typename... Ix> T& operator()(Ix... index) {
Jason Rhinelander's avatar
Jason Rhinelander committed
436
        static_assert(ssize_t{sizeof...(Ix)} == Dims || Dynamic,
437
                "Invalid number of indices for unchecked array reference");
438
439
        return const_cast<T &>(ConstBase::operator()(index...));
    }
440
441
    /**
     * Mutable, unchecked access data at the given index; this operator only participates if the
442
443
     * reference is to a 1-dimensional array (or has runtime dimensions).  When present, this is
     * exactly equivalent to `obj(index)`.
444
     */
445
446
    template <ssize_t D = Dims, typename = enable_if_t<D == 1 || Dynamic>>
    T &operator[](ssize_t index) { return operator()(index); }
447
448

    /// Mutable pointer access to the data at the given indices.
449
    template <typename... Ix> T *mutable_data(Ix... ix) { return &operator()(ssize_t(ix)...); }
450
451
};

452
template <typename T, ssize_t Dim>
453
struct type_caster<unchecked_reference<T, Dim>> {
454
    static_assert(Dim == 0 && Dim > 0 /* always fail */, "unchecked array proxy object is not castable");
455
};
456
template <typename T, ssize_t Dim>
457
458
struct type_caster<unchecked_mutable_reference<T, Dim>> : type_caster<unchecked_reference<T, Dim>> {};

459
PYBIND11_NAMESPACE_END(detail)
460

461
class dtype : public object {
462
public:
463
    PYBIND11_OBJECT_DEFAULT(dtype, object, detail::npy_api::get().PyArrayDescr_Check_);
Wenzel Jakob's avatar
Wenzel Jakob committed
464

465
    explicit dtype(const buffer_info &info) {
466
        dtype descr(_dtype_from_pep3118()(PYBIND11_STR_TYPE(info.format)));
467
468
        // If info.itemsize == 0, use the value calculated from the format string
        m_ptr = descr.strip_padding(info.itemsize ? info.itemsize : descr.itemsize()).release().ptr();
469
    }
470

471
    explicit dtype(const std::string &format) {
472
        m_ptr = from_args(pybind11::str(format)).release().ptr();
Wenzel Jakob's avatar
Wenzel Jakob committed
473
474
    }

475
    dtype(const char *format) : dtype(std::string(format)) { }
476

477
    dtype(list names, list formats, list offsets, ssize_t itemsize) {
478
479
480
481
        dict args;
        args["names"] = names;
        args["formats"] = formats;
        args["offsets"] = offsets;
482
        args["itemsize"] = pybind11::int_(itemsize);
483
484
485
        m_ptr = from_args(args).release().ptr();
    }

Ivan Smirnov's avatar
Ivan Smirnov committed
486
    /// This is essentially the same as calling numpy.dtype(args) in Python.
487
488
    static dtype from_args(object args) {
        PyObject *ptr = nullptr;
Jason Rhinelander's avatar
Jason Rhinelander committed
489
        if (!detail::npy_api::get().PyArray_DescrConverter_(args.ptr(), &ptr) || !ptr)
490
            throw error_already_set();
491
        return reinterpret_steal<dtype>(ptr);
492
    }
493

Ivan Smirnov's avatar
Ivan Smirnov committed
494
    /// Return dtype associated with a C++ type.
495
    template <typename T> static dtype of() {
496
        return detail::npy_format_descriptor<typename std::remove_cv<T>::type>::dtype();
497
    }
498

Ivan Smirnov's avatar
Ivan Smirnov committed
499
    /// Size of the data type in bytes.
500
501
    ssize_t itemsize() const {
        return detail::array_descriptor_proxy(m_ptr)->elsize;
Wenzel Jakob's avatar
Wenzel Jakob committed
502
503
    }

Ivan Smirnov's avatar
Ivan Smirnov committed
504
    /// Returns true for structured data types.
505
    bool has_fields() const {
506
        return detail::array_descriptor_proxy(m_ptr)->names != nullptr;
507
508
    }

Bertrand MICHEL's avatar
Bertrand MICHEL committed
509
510
    /// Single-character code for dtype's kind.
    /// For example, floating point types are 'f' and integral types are 'i'.
511
    char kind() const {
512
        return detail::array_descriptor_proxy(m_ptr)->kind;
513
514
    }

Bertrand MICHEL's avatar
Bertrand MICHEL committed
515
516
517
518
519
520
521
522
523
    /// Single-character for dtype's type.
    /// For example, ``float`` is 'f', ``double`` 'd', ``int`` 'i', and ``long`` 'd'.
    char char_() const {
        // Note: The signature, `dtype::char_` follows the naming of NumPy's
        // public Python API (i.e., ``dtype.char``), rather than its internal
        // C API (``PyArray_Descr::type``).
        return detail::array_descriptor_proxy(m_ptr)->type;
    }

524
private:
525
    static object _dtype_from_pep3118() {
526
        static PyObject *obj = module_::import("numpy.core._internal")
527
            .attr("_dtype_from_pep3118").cast<object>().release().ptr();
528
        return reinterpret_borrow<object>(obj);
529
    }
530

531
    dtype strip_padding(ssize_t itemsize) {
532
533
        // Recursively strip all void fields with empty names that are generated for
        // padding fields (as of NumPy v1.11).
534
        if (!has_fields())
535
            return *this;
536

537
        struct field_descr { PYBIND11_STR_TYPE name; object format; pybind11::int_ offset; };
538
539
        std::vector<field_descr> field_descriptors;

540
        for (auto field : attr("fields").attr("items")()) {
541
            auto spec = field.cast<tuple>();
542
            auto name = spec[0].cast<pybind11::str>();
543
            auto format = spec[1].cast<tuple>()[0].cast<dtype>();
544
            auto offset = spec[1].cast<tuple>()[1].cast<pybind11::int_>();
545
            if (!len(name) && format.kind() == 'V')
546
                continue;
547
            field_descriptors.push_back({(PYBIND11_STR_TYPE) name, format.strip_padding(format.itemsize()), offset});
548
549
550
551
        }

        std::sort(field_descriptors.begin(), field_descriptors.end(),
                  [](const field_descr& a, const field_descr& b) {
552
                      return a.offset.cast<int>() < b.offset.cast<int>();
553
554
555
556
                  });

        list names, formats, offsets;
        for (auto& descr : field_descriptors) {
557
558
559
            names.append(descr.name);
            formats.append(descr.format);
            offsets.append(descr.offset);
560
        }
561
        return dtype(names, formats, offsets, itemsize);
562
563
    }
};
564

565
566
class array : public buffer {
public:
567
    PYBIND11_OBJECT_CVT(array, buffer, detail::npy_api::get().PyArray_Check_, raw_array)
568
569

    enum {
570
571
        c_style = detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_,
        f_style = detail::npy_api::NPY_ARRAY_F_CONTIGUOUS_,
572
573
574
        forcecast = detail::npy_api::NPY_ARRAY_FORCECAST_
    };

575
    array() : array(0, static_cast<const double *>(nullptr)) {}
576

577
578
    using ShapeContainer = detail::any_container<ssize_t>;
    using StridesContainer = detail::any_container<ssize_t>;
579
580
581
582
583
584

    // Constructs an array taking shape/strides from arbitrary container types
    array(const pybind11::dtype &dt, ShapeContainer shape, StridesContainer strides,
          const void *ptr = nullptr, handle base = handle()) {

        if (strides->empty())
585
            *strides = detail::c_strides(*shape, dt.itemsize());
586
587
588

        auto ndim = shape->size();
        if (ndim != strides->size())
589
590
            pybind11_fail("NumPy: shape ndim doesn't match strides ndim");
        auto descr = dt;
591
592
593

        int flags = 0;
        if (base && ptr) {
594
            if (isinstance<array>(base))
Wenzel Jakob's avatar
Wenzel Jakob committed
595
                /* Copy flags from base (except ownership bit) */
596
                flags = reinterpret_borrow<array>(base).flags() & ~detail::npy_api::NPY_ARRAY_OWNDATA_;
597
598
599
600
601
            else
                /* Writable by default, easy to downgrade later on if needed */
                flags = detail::npy_api::NPY_ARRAY_WRITEABLE_;
        }

602
        auto &api = detail::npy_api::get();
603
        auto tmp = reinterpret_steal<object>(api.PyArray_NewFromDescr_(
604
605
606
607
            api.PyArray_Type_, descr.release().ptr(), (int) ndim,
            // Use reinterpret_cast for PyPy on Windows (remove if fixed, checked on 7.3.1)
            reinterpret_cast<Py_intptr_t*>(shape->data()),
            reinterpret_cast<Py_intptr_t*>(strides->data()),
608
            const_cast<void *>(ptr), flags, nullptr));
609
        if (!tmp)
610
            throw error_already_set();
611
612
        if (ptr) {
            if (base) {
Jason Rhinelander's avatar
Jason Rhinelander committed
613
                api.PyArray_SetBaseObject_(tmp.ptr(), base.inc_ref().ptr());
614
            } else {
615
                tmp = reinterpret_steal<object>(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */));
616
617
            }
        }
618
619
620
        m_ptr = tmp.release().ptr();
    }

621
622
    array(const pybind11::dtype &dt, ShapeContainer shape, const void *ptr = nullptr, handle base = handle())
        : array(dt, std::move(shape), {}, ptr, base) { }
623

624
625
626
    template <typename T, typename = detail::enable_if_t<std::is_integral<T>::value && !std::is_same<bool, T>::value>>
    array(const pybind11::dtype &dt, T count, const void *ptr = nullptr, handle base = handle())
        : array(dt, {{count}}, ptr, base) { }
627

628
629
630
    template <typename T>
    array(ShapeContainer shape, StridesContainer strides, const T *ptr, handle base = handle())
        : array(pybind11::dtype::of<T>(), std::move(shape), std::move(strides), ptr, base) { }
631

632
    template <typename T>
633
634
    array(ShapeContainer shape, const T *ptr, handle base = handle())
        : array(std::move(shape), {}, ptr, base) { }
635

636
    template <typename T>
637
    explicit array(ssize_t count, const T *ptr, handle base = handle()) : array({count}, {}, ptr, base) { }
638

639
640
    explicit array(const buffer_info &info, handle base = handle())
    : array(pybind11::dtype(info), info.shape, info.strides, info.ptr, base) { }
641

642
643
    /// Array descriptor (dtype)
    pybind11::dtype dtype() const {
644
        return reinterpret_borrow<pybind11::dtype>(detail::array_proxy(m_ptr)->descr);
645
646
647
    }

    /// Total number of elements
648
649
    ssize_t size() const {
        return std::accumulate(shape(), shape() + ndim(), (ssize_t) 1, std::multiplies<ssize_t>());
650
651
652
    }

    /// Byte size of a single element
653
654
    ssize_t itemsize() const {
        return detail::array_descriptor_proxy(detail::array_proxy(m_ptr)->descr)->elsize;
655
656
657
    }

    /// Total number of bytes
658
    ssize_t nbytes() const {
659
660
661
662
        return size() * itemsize();
    }

    /// Number of dimensions
663
664
    ssize_t ndim() const {
        return detail::array_proxy(m_ptr)->nd;
665
666
    }

667
668
    /// Base object
    object base() const {
669
        return reinterpret_borrow<object>(detail::array_proxy(m_ptr)->base);
670
671
    }

672
    /// Dimensions of the array
673
674
    const ssize_t* shape() const {
        return detail::array_proxy(m_ptr)->dimensions;
675
676
677
    }

    /// Dimension along a given axis
678
    ssize_t shape(ssize_t dim) const {
679
        if (dim >= ndim())
680
            fail_dim_check(dim, "invalid axis");
681
682
683
684
        return shape()[dim];
    }

    /// Strides of the array
685
    const ssize_t* strides() const {
686
        return detail::array_proxy(m_ptr)->strides;
687
688
689
    }

    /// Stride along a given axis
690
    ssize_t strides(ssize_t dim) const {
691
        if (dim >= ndim())
692
            fail_dim_check(dim, "invalid axis");
693
694
695
        return strides()[dim];
    }

696
697
    /// Return the NumPy array flags
    int flags() const {
698
        return detail::array_proxy(m_ptr)->flags;
699
700
    }

701
702
    /// If set, the array is writeable (otherwise the buffer is read-only)
    bool writeable() const {
703
        return detail::check_flags(m_ptr, detail::npy_api::NPY_ARRAY_WRITEABLE_);
704
705
706
707
    }

    /// If set, the array owns the data (will be freed when the array is deleted)
    bool owndata() const {
708
        return detail::check_flags(m_ptr, detail::npy_api::NPY_ARRAY_OWNDATA_);
709
710
    }

711
712
    /// Pointer to the contained data. If index is not provided, points to the
    /// beginning of the buffer. May throw if the index would lead to out of bounds access.
713
    template<typename... Ix> const void* data(Ix... index) const {
714
        return static_cast<const void *>(detail::array_proxy(m_ptr)->data + offset_at(index...));
715
716
    }

717
718
719
    /// Mutable pointer to the contained data. If index is not provided, points to the
    /// beginning of the buffer. May throw if the index would lead to out of bounds access.
    /// May throw if the array is not writeable.
720
    template<typename... Ix> void* mutable_data(Ix... index) {
721
        check_writeable();
722
        return static_cast<void *>(detail::array_proxy(m_ptr)->data + offset_at(index...));
723
724
725
726
    }

    /// Byte offset from beginning of the array to a given index (full or partial).
    /// May throw if the index would lead to out of bounds access.
727
    template<typename... Ix> ssize_t offset_at(Ix... index) const {
728
        if ((ssize_t) sizeof...(index) > ndim())
729
            fail_dim_check(sizeof...(index), "too many indices for an array");
730
        return byte_offset(ssize_t(index)...);
731
732
    }

733
    ssize_t offset_at() const { return 0; }
734
735
736

    /// Item count from beginning of the array to a given index (full or partial).
    /// May throw if the index would lead to out of bounds access.
737
    template<typename... Ix> ssize_t index_at(Ix... index) const {
738
        return offset_at(index...) / itemsize();
739
740
    }

741
742
    /**
     * Returns a proxy object that provides access to the array's data without bounds or
743
744
745
746
     * dimensionality checking.  Will throw if the array is missing the `writeable` flag.  Use with
     * care: the array must not be destroyed or reshaped for the duration of the returned object,
     * and the caller must take care not to access invalid dimensions or dimension indices.
     */
747
    template <typename T, ssize_t Dims = -1> detail::unchecked_mutable_reference<T, Dims> mutable_unchecked() & {
748
        if (Dims >= 0 && ndim() != Dims)
749
750
            throw std::domain_error("array has incorrect number of dimensions: " + std::to_string(ndim()) +
                    "; expected " + std::to_string(Dims));
751
        return detail::unchecked_mutable_reference<T, Dims>(mutable_data(), shape(), strides(), ndim());
752
753
    }

754
755
    /**
     * Returns a proxy object that provides const access to the array's data without bounds or
756
757
758
759
760
     * dimensionality checking.  Unlike `mutable_unchecked()`, this does not require that the
     * underlying array have the `writable` flag.  Use with care: the array must not be destroyed or
     * reshaped for the duration of the returned object, and the caller must take care not to access
     * invalid dimensions or dimension indices.
     */
761
    template <typename T, ssize_t Dims = -1> detail::unchecked_reference<T, Dims> unchecked() const & {
762
        if (Dims >= 0 && ndim() != Dims)
763
764
            throw std::domain_error("array has incorrect number of dimensions: " + std::to_string(ndim()) +
                    "; expected " + std::to_string(Dims));
765
        return detail::unchecked_reference<T, Dims>(data(), shape(), strides(), ndim());
766
767
    }

768
769
770
    /// Return a new view with all of the dimensions of length 1 removed
    array squeeze() {
        auto& api = detail::npy_api::get();
771
        return reinterpret_steal<array>(api.PyArray_Squeeze_(m_ptr));
772
773
    }

uentity's avatar
uentity committed
774
775
776
777
778
    /// Resize array to given shape
    /// If refcheck is true and more that one reference exist to this array
    /// then resize will succeed only if it makes a reshape, i.e. original size doesn't change
    void resize(ShapeContainer new_shape, bool refcheck = true) {
        detail::npy_api::PyArray_Dims d = {
779
780
781
            // Use reinterpret_cast for PyPy on Windows (remove if fixed, checked on 7.3.1)
            reinterpret_cast<Py_intptr_t*>(new_shape->data()),
            int(new_shape->size())
uentity's avatar
uentity committed
782
783
        };
        // try to resize, set ordering param to -1 cause it's not used anyway
784
        auto new_array = reinterpret_steal<object>(
uentity's avatar
uentity committed
785
786
787
788
789
790
            detail::npy_api::get().PyArray_Resize_(m_ptr, &d, int(refcheck), -1)
        );
        if (!new_array) throw error_already_set();
        if (isinstance<array>(new_array)) { *this = std::move(new_array); }
    }

791
    /// Ensure that the argument is a NumPy array
792
793
794
795
796
797
    /// In case of an error, nullptr is returned and the Python error is cleared.
    static array ensure(handle h, int ExtraFlags = 0) {
        auto result = reinterpret_steal<array>(raw_array(h.ptr(), ExtraFlags));
        if (!result)
            PyErr_Clear();
        return result;
798
799
    }

800
protected:
801
802
    template<typename, typename> friend struct detail::npy_format_descriptor;

803
    void fail_dim_check(ssize_t dim, const std::string& msg) const {
804
805
806
807
        throw index_error(msg + ": " + std::to_string(dim) +
                          " (ndim = " + std::to_string(ndim()) + ")");
    }

808
    template<typename... Ix> ssize_t byte_offset(Ix... index) const {
809
        check_dimensions(index...);
810
        return detail::byte_offset_unsafe(strides(), ssize_t(index)...);
811
812
    }

813
814
    void check_writeable() const {
        if (!writeable())
815
            throw std::domain_error("array is not writeable");
816
    }
817

818
    template<typename... Ix> void check_dimensions(Ix... index) const {
819
        check_dimensions_impl(ssize_t(0), shape(), ssize_t(index)...);
820
821
    }

822
    void check_dimensions_impl(ssize_t, const ssize_t*) const { }
823

824
    template<typename... Ix> void check_dimensions_impl(ssize_t axis, const ssize_t* shape, ssize_t i, Ix... index) const {
825
826
827
828
829
830
831
        if (i >= *shape) {
            throw index_error(std::string("index ") + std::to_string(i) +
                              " is out of bounds for axis " + std::to_string(axis) +
                              " with size " + std::to_string(*shape));
        }
        check_dimensions_impl(axis + 1, shape + 1, index...);
    }
832
833
834

    /// Create array from any object -- always returns a new reference
    static PyObject *raw_array(PyObject *ptr, int ExtraFlags = 0) {
835
836
        if (ptr == nullptr) {
            PyErr_SetString(PyExc_ValueError, "cannot create a pybind11::array from a nullptr");
837
            return nullptr;
838
        }
839
        return detail::npy_api::get().PyArray_FromAny_(
840
            ptr, nullptr, 0, 0, detail::npy_api::NPY_ARRAY_ENSUREARRAY_ | ExtraFlags, nullptr);
841
    }
Wenzel Jakob's avatar
Wenzel Jakob committed
842
843
};

844
template <typename T, int ExtraFlags = array::forcecast> class array_t : public array {
845
846
847
848
849
private:
    struct private_ctor {};
    // Delegating constructor needed when both moving and accessing in the same constructor
    array_t(private_ctor, ShapeContainer &&shape, StridesContainer &&strides, const T *ptr, handle base)
        : array(std::move(shape), std::move(strides), ptr, base) {}
Wenzel Jakob's avatar
Wenzel Jakob committed
850
public:
851
852
    static_assert(!detail::array_info<T>::is_array, "Array types cannot be used with array_t");

853
854
    using value_type = T;

855
    array_t() : array(0, static_cast<const T *>(nullptr)) {}
856
857
    array_t(handle h, borrowed_t) : array(h, borrowed_t{}) { }
    array_t(handle h, stolen_t) : array(h, stolen_t{}) { }
858

859
    PYBIND11_DEPRECATED("Use array_t<T>::ensure() instead")
860
    array_t(handle h, bool is_borrowed) : array(raw_array_t(h.ptr()), stolen_t{}) {
861
862
863
        if (!m_ptr) PyErr_Clear();
        if (!is_borrowed) Py_XDECREF(h.ptr());
    }
864

865
    array_t(const object &o) : array(raw_array_t(o.ptr()), stolen_t{}) {
866
867
        if (!m_ptr) throw error_already_set();
    }
868

869
    explicit array_t(const buffer_info& info, handle base = handle()) : array(info, base) { }
870

871
872
    array_t(ShapeContainer shape, StridesContainer strides, const T *ptr = nullptr, handle base = handle())
        : array(std::move(shape), std::move(strides), ptr, base) { }
873

874
    explicit array_t(ShapeContainer shape, const T *ptr = nullptr, handle base = handle())
875
        : array_t(private_ctor{}, std::move(shape),
876
877
878
                ExtraFlags & f_style
                ? detail::f_strides(*shape, itemsize())
                : detail::c_strides(*shape, itemsize()),
879
                ptr, base) { }
880

881
    explicit array_t(ssize_t count, const T *ptr = nullptr, handle base = handle())
882
883
        : array({count}, {}, ptr, base) { }

884
    constexpr ssize_t itemsize() const {
885
        return sizeof(T);
886
887
    }

888
    template<typename... Ix> ssize_t index_at(Ix... index) const {
889
        return offset_at(index...) / itemsize();
890
891
    }

892
    template<typename... Ix> const T* data(Ix... index) const {
893
894
895
        return static_cast<const T*>(array::data(index...));
    }

896
    template<typename... Ix> T* mutable_data(Ix... index) {
897
898
899
900
        return static_cast<T*>(array::mutable_data(index...));
    }

    // Reference to element at a given index
901
    template<typename... Ix> const T& at(Ix... index) const {
902
        if ((ssize_t) sizeof...(index) != ndim())
903
            fail_dim_check(sizeof...(index), "index dimension mismatch");
904
        return *(static_cast<const T*>(array::data()) + byte_offset(ssize_t(index)...) / itemsize());
905
906
907
    }

    // Mutable reference to element at a given index
908
    template<typename... Ix> T& mutable_at(Ix... index) {
909
        if ((ssize_t) sizeof...(index) != ndim())
910
            fail_dim_check(sizeof...(index), "index dimension mismatch");
911
        return *(static_cast<T*>(array::mutable_data()) + byte_offset(ssize_t(index)...) / itemsize());
912
    }
913

914
915
    /**
     * Returns a proxy object that provides access to the array's data without bounds or
916
917
918
919
     * dimensionality checking.  Will throw if the array is missing the `writeable` flag.  Use with
     * care: the array must not be destroyed or reshaped for the duration of the returned object,
     * and the caller must take care not to access invalid dimensions or dimension indices.
     */
920
    template <ssize_t Dims = -1> detail::unchecked_mutable_reference<T, Dims> mutable_unchecked() & {
921
922
923
        return array::mutable_unchecked<T, Dims>();
    }

924
925
    /**
     * Returns a proxy object that provides const access to the array's data without bounds or
926
927
928
929
930
     * dimensionality checking.  Unlike `unchecked()`, this does not require that the underlying
     * array have the `writable` flag.  Use with care: the array must not be destroyed or reshaped
     * for the duration of the returned object, and the caller must take care not to access invalid
     * dimensions or dimension indices.
     */
931
    template <ssize_t Dims = -1> detail::unchecked_reference<T, Dims> unchecked() const & {
932
933
934
        return array::unchecked<T, Dims>();
    }

Jason Rhinelander's avatar
Jason Rhinelander committed
935
936
    /// Ensure that the argument is a NumPy array of the correct dtype (and if not, try to convert
    /// it).  In case of an error, nullptr is returned and the Python error is cleared.
937
938
    static array_t ensure(handle h) {
        auto result = reinterpret_steal<array_t>(raw_array_t(h.ptr()));
939
940
        if (!result)
            PyErr_Clear();
941
        return result;
Wenzel Jakob's avatar
Wenzel Jakob committed
942
    }
943

Wenzel Jakob's avatar
Wenzel Jakob committed
944
    static bool check_(handle h) {
945
946
        const auto &api = detail::npy_api::get();
        return api.PyArray_Check_(h.ptr())
947
948
               && api.PyArray_EquivTypes_(detail::array_proxy(h.ptr())->descr, dtype::of<T>().ptr())
               && detail::check_flags(h.ptr(), ExtraFlags & (array::c_style | array::f_style));
949
950
951
952
953
    }

protected:
    /// Create array from any object -- always returns a new reference
    static PyObject *raw_array_t(PyObject *ptr) {
954
955
        if (ptr == nullptr) {
            PyErr_SetString(PyExc_ValueError, "cannot create a pybind11::array_t from a nullptr");
956
            return nullptr;
957
        }
958
959
        return detail::npy_api::get().PyArray_FromAny_(
            ptr, dtype::of<T>().release().ptr(), 0, 0,
960
            detail::npy_api::NPY_ARRAY_ENSUREARRAY_ | ExtraFlags, nullptr);
961
    }
Wenzel Jakob's avatar
Wenzel Jakob committed
962
963
};

964
template <typename T>
965
struct format_descriptor<T, detail::enable_if_t<detail::is_pod_struct<T>::value>> {
966
967
968
    static std::string format() {
        return detail::npy_format_descriptor<typename std::remove_cv<T>::type>::format();
    }
969
970
971
};

// Fixed-size char arrays map to the "Ns" (bytes) buffer format
template <size_t N> struct format_descriptor<char[N]> {
    static std::string format() { return std::to_string(N) + "s"; }
};
// std::array<char, N> behaves exactly like char[N] for buffer formats
template <size_t N> struct format_descriptor<std::array<char, N>> {
    static std::string format() { return std::to_string(N) + "s"; }
};

978
979
980
981
982
983
984
985
template <typename T>
struct format_descriptor<T, detail::enable_if_t<std::is_enum<T>::value>> {
    static std::string format() {
        return format_descriptor<
            typename std::remove_cv<typename std::underlying_type<T>::type>::type>::format();
    }
};

986
987
988
template <typename T>
struct format_descriptor<T, detail::enable_if_t<detail::array_info<T>::is_array>> {
    static std::string format() {
989
990
991
        using namespace detail;
        static constexpr auto extents = _("(") + array_info<T>::extents + _(")");
        return extents.text + format_descriptor<remove_all_extents_t<T>>::format();
992
993
994
    }
};

995
PYBIND11_NAMESPACE_BEGIN(detail)
996
997
998
999
template <typename T, int ExtraFlags>
struct pyobject_caster<array_t<T, ExtraFlags>> {
    using type = array_t<T, ExtraFlags>;

1000
1001
1002
    bool load(handle src, bool convert) {
        if (!convert && !type::check_(src))
            return false;
1003
        value = type::ensure(src);
1004
1005
1006
1007
1008
1009
        return static_cast<bool>(value);
    }

    static handle cast(const handle &src, return_value_policy /* policy */, handle /* parent */) {
        return src.inc_ref();
    }
1010
    PYBIND11_TYPE_CASTER(type, handle_type_name<type>::name);
1011
1012
};

1013
1014
1015
1016
1017
1018
1019
template <typename T>
struct compare_buffer_info<T, detail::enable_if_t<detail::is_pod_struct<T>::value>> {
    static bool compare(const buffer_info& b) {
        return npy_api::get().PyArray_EquivTypes_(dtype::of<T>().ptr(), dtype(b).ptr());
    }
};

1020
1021
1022
1023
1024
1025
template <typename T, typename = void>
struct npy_format_descriptor_name;

template <typename T>
struct npy_format_descriptor_name<T, enable_if_t<std::is_integral<T>::value>> {
    static constexpr auto name = _<std::is_same<T, bool>::value>(
1026
        _("bool"), _<std::is_signed<T>::value>("numpy.int", "numpy.uint") + _<sizeof(T)*8>()
1027
1028
1029
1030
1031
1032
    );
};

// Signature names for floating-point types: "numpy.floatNN" or "numpy.longdouble"
template <typename T>
struct npy_format_descriptor_name<T, enable_if_t<std::is_floating_point<T>::value>> {
    static constexpr auto name = _<std::is_same<T, float>::value || std::is_same<T, double>::value>(
        _("numpy.float") + _<sizeof(T)*8>(), _("numpy.longdouble")
    );
};

// Signature names for complex types: "numpy.complexNN" or "numpy.longcomplex"
template <typename T>
struct npy_format_descriptor_name<T, enable_if_t<is_complex<T>::value>> {
    static constexpr auto name = _<std::is_same<typename T::value_type, float>::value
                                   || std::is_same<typename T::value_type, double>::value>(
        _("numpy.complex") + _<sizeof(typename T::value_type)*16>(), _("numpy.longcomplex")
    );
};

// Descriptor for arithmetic and complex scalars: maps the C++ type to a NumPy type number
template <typename T>
struct npy_format_descriptor<T, enable_if_t<satisfies_any_of<T, std::is_arithmetic, is_complex>::value>>
    : npy_format_descriptor_name<T> {
private:
    // NB: the order here must match the one in common.h
    constexpr static const int values[15] = {
        npy_api::NPY_BOOL_,
        npy_api::NPY_BYTE_,   npy_api::NPY_UBYTE_,   npy_api::NPY_INT16_,    npy_api::NPY_UINT16_,
        npy_api::NPY_INT32_,  npy_api::NPY_UINT32_,  npy_api::NPY_INT64_,    npy_api::NPY_UINT64_,
        npy_api::NPY_FLOAT_,  npy_api::NPY_DOUBLE_,  npy_api::NPY_LONGDOUBLE_,
        npy_api::NPY_CFLOAT_, npy_api::NPY_CDOUBLE_, npy_api::NPY_CLONGDOUBLE_
    };

public:
    static constexpr int value = values[detail::is_fmt_numeric<T>::index];

    static pybind11::dtype dtype() {
        if (auto ptr = npy_api::get().PyArray_DescrFromType_(value))
            return reinterpret_steal<pybind11::dtype>(ptr);
        pybind11_fail("Unsupported buffer format!");
    }
};
1067
1068

#define PYBIND11_DECL_CHAR_FMT \
1069
    static constexpr auto name = _("S") + _<N>(); \
1070
    static pybind11::dtype dtype() { return pybind11::dtype(std::string("S") + std::to_string(N)); }
1071
1072
1073
template <size_t N> struct npy_format_descriptor<char[N]> { PYBIND11_DECL_CHAR_FMT };
template <size_t N> struct npy_format_descriptor<std::array<char, N>> { PYBIND11_DECL_CHAR_FMT };
#undef PYBIND11_DECL_CHAR_FMT
1074

1075
1076
1077
1078
1079
1080
// Descriptor for C arrays: a sub-array dtype of the element descriptor with the given extents
template<typename T> struct npy_format_descriptor<T, enable_if_t<array_info<T>::is_array>> {
private:
    using base_descr = npy_format_descriptor<typename array_info<T>::type>;
public:
    static_assert(!array_info<T>::is_empty, "Zero-sized arrays are not supported");

    static constexpr auto name = _("(") + array_info<T>::extents + _(")") + base_descr::name;
    static pybind11::dtype dtype() {
        list shape;
        array_info<T>::append_extents(shape);
        return pybind11::dtype::from_args(pybind11::make_tuple(base_descr::dtype(), shape));
    }
};

1089
1090
1091
1092
// Descriptor for enums: identical to that of the underlying integer type
template<typename T> struct npy_format_descriptor<T, enable_if_t<std::is_enum<T>::value>> {
private:
    using base_descr = npy_format_descriptor<typename std::underlying_type<T>::type>;
public:
    static constexpr auto name = base_descr::name;
    static pybind11::dtype dtype() { return base_descr::dtype(); }
};

1097
1098
struct field_descriptor {
    const char *name;
1099
1100
    ssize_t offset;
    ssize_t size;
1101
    std::string format;
1102
    dtype descr;
1103
1104
};

1105
/// Register a structured (record) dtype for a POD struct: builds the NumPy dtype from the
/// field list, derives a matching buffer-protocol format string, sanity-checks the two against
/// each other, and records both (plus the direct scalar converter) in the internals tables.
inline PYBIND11_NOINLINE void register_structured_dtype(
    any_container<field_descriptor> fields,
    const std::type_info& tinfo, ssize_t itemsize,
    bool (*direct_converter)(PyObject *, void *&)) {

    auto& numpy_internals = get_numpy_internals();
    if (numpy_internals.get_type_info(tinfo, false))
        pybind11_fail("NumPy: dtype is already registered");

    // Use ordered fields because order matters as of NumPy 1.14:
    // https://docs.scipy.org/doc/numpy/release.html#multiple-field-indexing-assignment-of-structured-arrays
    std::vector<field_descriptor> ordered_fields(std::move(fields));
    std::sort(ordered_fields.begin(), ordered_fields.end(),
        [](const field_descriptor &a, const field_descriptor &b) { return a.offset < b.offset; });

    list names, formats, offsets;
    for (auto& field : ordered_fields) {
        if (!field.descr)
            pybind11_fail(std::string("NumPy: unsupported field dtype: `") +
                            field.name + "` @ " + tinfo.name());
        names.append(PYBIND11_STR_TYPE(field.name));
        formats.append(field.descr);
        offsets.append(pybind11::int_(field.offset));
    }
    auto dtype_ptr = pybind11::dtype(names, formats, offsets, itemsize).release().ptr();

    // There is an existing bug in NumPy (as of v1.11): trailing bytes are
    // not encoded explicitly into the format string. This will supposedly
    // get fixed in v1.12; for further details, see these:
    // - https://github.com/numpy/numpy/issues/7797
    // - https://github.com/numpy/numpy/pull/7798
    // Because of this, we won't use numpy's logic to generate buffer format
    // strings and will just do it ourselves.
    ssize_t offset = 0;
    std::ostringstream oss;
    // mark the structure as unaligned with '^', because numpy and C++ don't
    // always agree about alignment (particularly for complex), and we're
    // explicitly listing all our padding. This depends on none of the fields
    // overriding the endianness. Putting the ^ in front of individual fields
    // isn't guaranteed to work due to https://github.com/numpy/numpy/issues/9049
    oss << "^T{";
    for (auto& field : ordered_fields) {
        if (field.offset > offset)
            oss << (field.offset - offset) << 'x';
        oss << field.format << ':' << field.name << ':';
        offset = field.offset + field.size;
    }
    if (itemsize > offset)
        oss << (itemsize - offset) << 'x';
    oss << '}';
    auto format_str = oss.str();

    // Sanity check: verify that NumPy properly parses our buffer format string
    auto& api = npy_api::get();
    auto arr = array(buffer_info(nullptr, itemsize, format_str, 1));
    if (!api.PyArray_EquivTypes_(dtype_ptr, arr.dtype().ptr()))
        pybind11_fail("NumPy: invalid buffer descriptor!");

    auto tindex = std::type_index(tinfo);
    numpy_internals.registered_dtypes[tindex] = { dtype_ptr, format_str };
    get_internals().direct_conversions[tindex].push_back(direct_converter);
}

1168
1169
1170
template <typename T, typename SFINAE> struct npy_format_descriptor {
    static_assert(is_pod_struct<T>::value, "Attempt to use a non-POD or unimplemented POD type as a numpy dtype");

1171
    static constexpr auto name = make_caster<T>::name;
1172

1173
    static pybind11::dtype dtype() {
1174
        return reinterpret_borrow<pybind11::dtype>(dtype_ptr());
1175
1176
    }

1177
    static std::string format() {
1178
        static auto format_str = get_numpy_internals().get_type_info<T>(true)->format_str;
1179
        return format_str;
1180
1181
    }

1182
1183
    static void register_dtype(any_container<field_descriptor> fields) {
        register_structured_dtype(std::move(fields), typeid(typename std::remove_cv<T>::type),
1184
                                  sizeof(T), &direct_converter);
1185
1186
1187
    }

private:
1188
1189
1190
1191
    static PyObject* dtype_ptr() {
        static PyObject* ptr = get_numpy_internals().get_type_info<T>(true)->dtype_ptr;
        return ptr;
    }
1192

1193
1194
1195
    static bool direct_converter(PyObject *obj, void*& value) {
        auto& api = npy_api::get();
        if (!PyObject_TypeCheck(obj, api.PyVoidArrType_Type_))
1196
            return false;
1197
        if (auto descr = reinterpret_steal<object>(api.PyArray_DescrFromScalar_(obj))) {
1198
            if (api.PyArray_EquivTypes_(dtype_ptr(), descr.ptr())) {
1199
1200
1201
1202
1203
1204
                value = ((PyVoidScalarObject_Proxy *) obj)->obval;
                return true;
            }
        }
        return false;
    }
1205
1206
};

#ifdef __CLION_IDE__ // replace heavy macro with dummy code for the IDE (doesn't affect code)
# define PYBIND11_NUMPY_DTYPE(Type, ...) ((void)0)
# define PYBIND11_NUMPY_DTYPE_EX(Type, ...) ((void)0)
#else

#define PYBIND11_FIELD_DESCRIPTOR_EX(T, Field, Name)                                          \
    ::pybind11::detail::field_descriptor {                                                    \
        Name, offsetof(T, Field), sizeof(decltype(std::declval<T>().Field)),                  \
        ::pybind11::format_descriptor<decltype(std::declval<T>().Field)>::format(),           \
        ::pybind11::detail::npy_format_descriptor<decltype(std::declval<T>().Field)>::dtype() \
    }

// Extract name, offset and format descriptor for a struct field
#define PYBIND11_FIELD_DESCRIPTOR(T, Field) PYBIND11_FIELD_DESCRIPTOR_EX(T, Field, #Field)

// The main idea of this macro is borrowed from https://github.com/swansontec/map-macro
// (C) William Swanson, Paul Fultz
#define PYBIND11_EVAL0(...) __VA_ARGS__
#define PYBIND11_EVAL1(...) PYBIND11_EVAL0 (PYBIND11_EVAL0 (PYBIND11_EVAL0 (__VA_ARGS__)))
#define PYBIND11_EVAL2(...) PYBIND11_EVAL1 (PYBIND11_EVAL1 (PYBIND11_EVAL1 (__VA_ARGS__)))
#define PYBIND11_EVAL3(...) PYBIND11_EVAL2 (PYBIND11_EVAL2 (PYBIND11_EVAL2 (__VA_ARGS__)))
#define PYBIND11_EVAL4(...) PYBIND11_EVAL3 (PYBIND11_EVAL3 (PYBIND11_EVAL3 (__VA_ARGS__)))
#define PYBIND11_EVAL(...)  PYBIND11_EVAL4 (PYBIND11_EVAL4 (PYBIND11_EVAL4 (__VA_ARGS__)))
#define PYBIND11_MAP_END(...)
#define PYBIND11_MAP_OUT
#define PYBIND11_MAP_COMMA ,
#define PYBIND11_MAP_GET_END() 0, PYBIND11_MAP_END
#define PYBIND11_MAP_NEXT0(test, next, ...) next PYBIND11_MAP_OUT
#define PYBIND11_MAP_NEXT1(test, next) PYBIND11_MAP_NEXT0 (test, next, 0)
#define PYBIND11_MAP_NEXT(test, next)  PYBIND11_MAP_NEXT1 (PYBIND11_MAP_GET_END test, next)
#if defined(_MSC_VER) && !defined(__clang__) // MSVC is not as eager to expand macros, hence this workaround
#define PYBIND11_MAP_LIST_NEXT1(test, next) \
    PYBIND11_EVAL0 (PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0))
#else
#define PYBIND11_MAP_LIST_NEXT1(test, next) \
    PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0)
#endif
#define PYBIND11_MAP_LIST_NEXT(test, next) \
    PYBIND11_MAP_LIST_NEXT1 (PYBIND11_MAP_GET_END test, next)
#define PYBIND11_MAP_LIST0(f, t, x, peek, ...) \
    f(t, x) PYBIND11_MAP_LIST_NEXT (peek, PYBIND11_MAP_LIST1) (f, t, peek, __VA_ARGS__)
#define PYBIND11_MAP_LIST1(f, t, x, peek, ...) \
    f(t, x) PYBIND11_MAP_LIST_NEXT (peek, PYBIND11_MAP_LIST0) (f, t, peek, __VA_ARGS__)
// PYBIND11_MAP_LIST(f, t, a1, a2, ...) expands to f(t, a1), f(t, a2), ...
#define PYBIND11_MAP_LIST(f, t, ...) \
    PYBIND11_EVAL (PYBIND11_MAP_LIST1 (f, t, __VA_ARGS__, (), 0))

#define PYBIND11_NUMPY_DTYPE(Type, ...) \
    ::pybind11::detail::npy_format_descriptor<Type>::register_dtype \
        (::std::vector<::pybind11::detail::field_descriptor> \
         {PYBIND11_MAP_LIST (PYBIND11_FIELD_DESCRIPTOR, Type, __VA_ARGS__)})

#if defined(_MSC_VER) && !defined(__clang__)
#define PYBIND11_MAP2_LIST_NEXT1(test, next) \
    PYBIND11_EVAL0 (PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0))
#else
#define PYBIND11_MAP2_LIST_NEXT1(test, next) \
    PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0)
#endif
#define PYBIND11_MAP2_LIST_NEXT(test, next) \
    PYBIND11_MAP2_LIST_NEXT1 (PYBIND11_MAP_GET_END test, next)
#define PYBIND11_MAP2_LIST0(f, t, x1, x2, peek, ...) \
    f(t, x1, x2) PYBIND11_MAP2_LIST_NEXT (peek, PYBIND11_MAP2_LIST1) (f, t, peek, __VA_ARGS__)
#define PYBIND11_MAP2_LIST1(f, t, x1, x2, peek, ...) \
    f(t, x1, x2) PYBIND11_MAP2_LIST_NEXT (peek, PYBIND11_MAP2_LIST0) (f, t, peek, __VA_ARGS__)
// PYBIND11_MAP2_LIST(f, t, a1, a2, ...) expands to f(t, a1, a2), f(t, a3, a4), ...
#define PYBIND11_MAP2_LIST(f, t, ...) \
    PYBIND11_EVAL (PYBIND11_MAP2_LIST1 (f, t, __VA_ARGS__, (), 0))

#define PYBIND11_NUMPY_DTYPE_EX(Type, ...) \
    ::pybind11::detail::npy_format_descriptor<Type>::register_dtype \
        (::std::vector<::pybind11::detail::field_descriptor> \
         {PYBIND11_MAP2_LIST (PYBIND11_FIELD_DESCRIPTOR_EX, Type, __VA_ARGS__)})

#endif // __CLION_IDE__

1283
1284
/// Byte-stepping iterator over one buffer participating in a broadcast loop.
/// The constructor converts per-axis strides into "carry-adjusted" strides:
/// incrementing axis `d` is done after the inner axes have wrapped, so the
/// stored stride for `d` already subtracts the distance walked along axis d+1.
class common_iterator {
public:
    using container_type = std::vector<ssize_t>;
    using value_type = container_type::value_type;
    using size_type = container_type::size_type;

    // Fix: use nullptr (not the integer literal 0) to initialize the pointer member
    common_iterator() : p_ptr(nullptr), m_strides() {}

    /// @param ptr      start of the buffer's data
    /// @param strides  per-axis byte strides (0 for broadcast axes)
    /// @param shape    common (broadcast) shape; must be non-empty
    common_iterator(void* ptr, const container_type& strides, const container_type& shape)
        : p_ptr(reinterpret_cast<char*>(ptr)), m_strides(strides.size()) {
        m_strides.back() = static_cast<value_type>(strides.back());
        for (size_type i = m_strides.size() - 1; i != 0; --i) {
            size_type j = i - 1;
            auto s = static_cast<value_type>(shape[i]);
            // Net displacement for bumping axis j once, after axis i has wrapped s times
            m_strides[j] = strides[j] + m_strides[i] - strides[i] * s;
        }
    }

    /// Advance along the given axis (caller guarantees inner axes just wrapped)
    void increment(size_type dim) {
        p_ptr += m_strides[dim];
    }

    /// Current element pointer
    void* data() const {
        return p_ptr;
    }

private:
    char* p_ptr;
    container_type m_strides;
};

1314
template <size_t N> class multi_array_iterator {
1315
public:
1316
    using container_type = std::vector<ssize_t>;
1317

1318
    multi_array_iterator(const std::array<buffer_info, N> &buffers,
1319
                         const container_type &shape)
1320
1321
1322
        : m_shape(shape.size()), m_index(shape.size(), 0),
          m_common_iterator() {

1323
        // Manual copy to avoid conversion warning if using std::copy
1324
        for (size_t i = 0; i < shape.size(); ++i)
1325
            m_shape[i] = shape[i];
1326

1327
        container_type strides(shape.size());
1328
        for (size_t i = 0; i < N; ++i)
1329
1330
1331
1332
1333
1334
1335
1336
1337
            init_common_iterator(buffers[i], shape, m_common_iterator[i], strides);
    }

    multi_array_iterator& operator++() {
        for (size_t j = m_index.size(); j != 0; --j) {
            size_t i = j - 1;
            if (++m_index[i] != m_shape[i]) {
                increment_common_iterator(i);
                break;
1338
            } else {
1339
1340
1341
1342
1343
1344
                m_index[i] = 0;
            }
        }
        return *this;
    }

1345
1346
    template <size_t K, class T = void> T* data() const {
        return reinterpret_cast<T*>(m_common_iterator[K].data());
1347
1348
1349
1350
1351
1352
    }

private:

    using common_iter = common_iterator;

1353
    void init_common_iterator(const buffer_info &buffer,
1354
1355
1356
                              const container_type &shape,
                              common_iter &iterator,
                              container_type &strides) {
1357
1358
1359
1360
1361
1362
1363
        auto buffer_shape_iter = buffer.shape.rbegin();
        auto buffer_strides_iter = buffer.strides.rbegin();
        auto shape_iter = shape.rbegin();
        auto strides_iter = strides.rbegin();

        while (buffer_shape_iter != buffer.shape.rend()) {
            if (*shape_iter == *buffer_shape_iter)
1364
                *strides_iter = *buffer_strides_iter;
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
            else
                *strides_iter = 0;

            ++buffer_shape_iter;
            ++buffer_strides_iter;
            ++shape_iter;
            ++strides_iter;
        }

        std::fill(strides_iter, strides.rend(), 0);
        iterator = common_iter(buffer.ptr, strides, shape);
    }

    void increment_common_iterator(size_t dim) {
1379
        for (auto &iter : m_common_iterator)
1380
1381
1382
1383
1384
1385
1386
1387
            iter.increment(dim);
    }

    container_type m_shape;
    container_type m_index;
    std::array<common_iter, N> m_common_iterator;
};

1388
1389
1390
1391
1392
1393
// Classification of a buffer set's broadcast: `c_trivial` / `f_trivial` when
// every non-singleton buffer is full-size and C- / Fortran-contiguous
// respectively; `non_trivial` otherwise.
enum class broadcast_trivial { non_trivial, c_trivial, f_trivial };

// Populates the shape and number of dimensions for the set of buffers.  Returns a broadcast_trivial
// enum value indicating whether the broadcast is "trivial"--that is, has each buffer being either a
// singleton or a full-size, C-contiguous (`c_trivial`) or Fortran-contiguous (`f_trivial`) storage
// buffer; returns `non_trivial` otherwise.
1394
template <size_t N>
1395
broadcast_trivial broadcast(const std::array<buffer_info, N> &buffers, ssize_t &ndim, std::vector<ssize_t> &shape) {
1396
    ndim = std::accumulate(buffers.begin(), buffers.end(), ssize_t(0), [](ssize_t res, const buffer_info &buf) {
1397
1398
1399
        return std::max(res, buf.ndim);
    });

1400
    shape.clear();
1401
    shape.resize((size_t) ndim, 1);
1402

1403
1404
    // Figure out the output size, and make sure all input arrays conform (i.e. are either size 1 or
    // the full size).
1405
1406
    for (size_t i = 0; i < N; ++i) {
        auto res_iter = shape.rbegin();
1407
1408
        auto end = buffers[i].shape.rend();
        for (auto shape_iter = buffers[i].shape.rbegin(); shape_iter != end; ++shape_iter, ++res_iter) {
1409
1410
1411
1412
1413
1414
1415
            const auto &dim_size_in = *shape_iter;
            auto &dim_size_out = *res_iter;

            // Each input dimension can either be 1 or `n`, but `n` values must match across buffers
            if (dim_size_out == 1)
                dim_size_out = dim_size_in;
            else if (dim_size_in != 1 && dim_size_in != dim_size_out)
1416
                pybind11_fail("pybind11::vectorize: incompatible size/dimension of inputs!");
1417
1418
        }
    }
1419

1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
    bool trivial_broadcast_c = true;
    bool trivial_broadcast_f = true;
    for (size_t i = 0; i < N && (trivial_broadcast_c || trivial_broadcast_f); ++i) {
        if (buffers[i].size == 1)
            continue;

        // Require the same number of dimensions:
        if (buffers[i].ndim != ndim)
            return broadcast_trivial::non_trivial;

        // Require all dimensions be full-size:
        if (!std::equal(buffers[i].shape.cbegin(), buffers[i].shape.cend(), shape.cbegin()))
            return broadcast_trivial::non_trivial;

        // Check for C contiguity (but only if previous inputs were also C contiguous)
        if (trivial_broadcast_c) {
1436
            ssize_t expect_stride = buffers[i].itemsize;
1437
            auto end = buffers[i].shape.crend();
1438
1439
            for (auto shape_iter = buffers[i].shape.crbegin(), stride_iter = buffers[i].strides.crbegin();
                    trivial_broadcast_c && shape_iter != end; ++shape_iter, ++stride_iter) {
1440
1441
1442
1443
                if (expect_stride == *stride_iter)
                    expect_stride *= *shape_iter;
                else
                    trivial_broadcast_c = false;
1444
            }
1445
        }
1446

1447
1448
        // Check for Fortran contiguity (if previous inputs were also F contiguous)
        if (trivial_broadcast_f) {
1449
            ssize_t expect_stride = buffers[i].itemsize;
1450
            auto end = buffers[i].shape.cend();
1451
1452
            for (auto shape_iter = buffers[i].shape.cbegin(), stride_iter = buffers[i].strides.cbegin();
                    trivial_broadcast_f && shape_iter != end; ++shape_iter, ++stride_iter) {
1453
1454
1455
1456
1457
                if (expect_stride == *stride_iter)
                    expect_stride *= *shape_iter;
                else
                    trivial_broadcast_f = false;
            }
1458
1459
        }
    }
1460
1461
1462
1463
1464

    return
        trivial_broadcast_c ? broadcast_trivial::c_trivial :
        trivial_broadcast_f ? broadcast_trivial::f_trivial :
        broadcast_trivial::non_trivial;
1465
1466
}

template <typename T>
struct vectorize_arg {
    static_assert(!std::is_rvalue_reference<T>::value, "Functions with rvalue reference arguments cannot be vectorized");
    // The wrapped function gets called with this type:
    using call_type = remove_reference_t<T>;
    // Is this a vectorized argument?
    static constexpr bool vectorize =
1474
        satisfies_any_of<call_type, std::is_arithmetic, is_complex, is_pod>::value &&
1475
1476
1477
1478
1479
1480
1481
        satisfies_none_of<call_type, std::is_pointer, std::is_array, is_std_array, std::is_enum>::value &&
        (!std::is_reference<T>::value ||
         (std::is_lvalue_reference<T>::value && std::is_const<call_type>::value));
    // Accept this type: an array for vectorized types, otherwise the type as-is:
    using type = conditional_t<vectorize, array_t<remove_cv_t<call_type>, array::forcecast>, T>;
};


// py::vectorize when a return type is present
template <typename Func, typename Return, typename... Args>
struct vectorize_returned_array {
    using Type = array_t<Return>;

    // Allocate the output array; use Fortran ordering when all inputs were
    // Fortran-contiguous, C ordering otherwise.
    static Type create(broadcast_trivial trivial, const std::vector<ssize_t> &shape) {
        if (trivial == broadcast_trivial::f_trivial) {
            return array_t<Return, array::f_style>(shape);
        }
        return array_t<Return>(shape);
    }

    // Writable pointer to the output array's storage.
    static Return *mutable_data(Type &array) {
        return array.mutable_data();
    }

    // Scalar invocation: hand the result straight back to the caller.
    static Return call(Func &f, Args &... args) {
        return f(args...);
    }

    // Element-wise invocation: write the result at flat position `i`.
    static void call(Return *out, size_t i, Func &f, Args &... args) {
        out[i] = f(args...);
    }
};

// py::vectorize when a return type is not present
template <typename Func, typename... Args>
struct vectorize_returned_array<Func, void, Args...> {
    // No output array exists for void functions; `none` stands in.
    using Type = none;

    static Type create(broadcast_trivial, const std::vector<ssize_t> &) { return none(); }

    static void *mutable_data(Type &) { return nullptr; }

    // Scalar invocation: run for side effects, yield a void_type marker.
    static detail::void_type call(Func &f, Args &... args) {
        f(args...);
        return detail::void_type();
    }

    // Element-wise invocation: nothing to store.
    static void call(void *, size_t, Func &f, Args &... args) { f(args...); }
};


template <typename Func, typename Return, typename... Args>
struct vectorize_helper {
1534
1535
1536
1537
1538

// NVCC for some reason breaks if NVectorized is private
#ifdef __CUDACC__
public:
#else
1539
private:
1540
1541
#endif

1542
    static constexpr size_t N = sizeof...(Args);
1543
1544
1545
    static constexpr size_t NVectorized = constexpr_sum(vectorize_arg<Args>::vectorize...);
    static_assert(NVectorized >= 1,
            "pybind11::vectorize(...) requires a function with at least one vectorizable argument");
1546

1547
public:
1548
    template <typename T>
1549
    explicit vectorize_helper(T &&f) : f(std::forward<T>(f)) { }
Wenzel Jakob's avatar
Wenzel Jakob committed
1550

1551
1552
1553
1554
1555
    object operator()(typename vectorize_arg<Args>::type... args) {
        return run(args...,
                   make_index_sequence<N>(),
                   select_indices<vectorize_arg<Args>::vectorize...>(),
                   make_index_sequence<NVectorized>());
1556
    }
Wenzel Jakob's avatar
Wenzel Jakob committed
1557

1558
1559
1560
private:
    remove_reference_t<Func> f;

1561
1562
1563
1564
    // Internal compiler error in MSVC 19.16.27025.1 (Visual Studio 2017 15.9.4), when compiling with "/permissive-" flag
    // when arg_call_types is manually inlined.
    using arg_call_types = std::tuple<typename vectorize_arg<Args>::call_type...>;
    template <size_t Index> using param_n_t = typename std::tuple_element<Index, arg_call_types>::type;
1565

1566
1567
    using returned_array = vectorize_returned_array<Func, Return, Args...>;

1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
    // Runs a vectorized function given arguments tuple and three index sequences:
    //     - Index is the full set of 0 ... (N-1) argument indices;
    //     - VIndex is the subset of argument indices with vectorized parameters, letting us access
    //       vectorized arguments (anything not in this sequence is passed through)
    //     - BIndex is a incremental sequence (beginning at 0) of the same size as VIndex, so that
    //       we can store vectorized buffer_infos in an array (argument VIndex has its buffer at
    //       index BIndex in the array).
    template <size_t... Index, size_t... VIndex, size_t... BIndex> object run(
            typename vectorize_arg<Args>::type &...args,
            index_sequence<Index...> i_seq, index_sequence<VIndex...> vi_seq, index_sequence<BIndex...> bi_seq) {

        // Pointers to values the function was called with; the vectorized ones set here will start
        // out as array_t<T> pointers, but they will be changed them to T pointers before we make
        // call the wrapped function.  Non-vectorized pointers are left as-is.
        std::array<void *, N> params{{ &args... }};

        // The array of `buffer_info`s of vectorized arguments:
        std::array<buffer_info, NVectorized> buffers{{ reinterpret_cast<array *>(params[VIndex])->request()... }};
Wenzel Jakob's avatar
Wenzel Jakob committed
1586
1587

        /* Determine dimensions parameters of output array */
1588
1589
1590
        ssize_t nd = 0;
        std::vector<ssize_t> shape(0);
        auto trivial = broadcast(buffers, nd, shape);
1591
        auto ndim = (size_t) nd;
1592

1593
1594
1595
1596
1597
1598
        size_t size = std::accumulate(shape.begin(), shape.end(), (size_t) 1, std::multiplies<size_t>());

        // If all arguments are 0-dimension arrays (i.e. single values) return a plain value (i.e.
        // not wrapped in an array).
        if (size == 1 && ndim == 0) {
            PYBIND11_EXPAND_SIDE_EFFECTS(params[VIndex] = buffers[BIndex].ptr);
1599
            return cast(returned_array::call(f, *reinterpret_cast<param_n_t<Index> *>(params[Index])...));
Wenzel Jakob's avatar
Wenzel Jakob committed
1600
1601
        }

1602
        auto result = returned_array::create(trivial, shape);
Wenzel Jakob's avatar
Wenzel Jakob committed
1603

Henry Schreiner's avatar
Henry Schreiner committed
1604
        if (size == 0) return std::move(result);
1605

1606
        /* Call the function */
1607
        auto mutable_data = returned_array::mutable_data(result);
1608
        if (trivial == broadcast_trivial::non_trivial)
1609
            apply_broadcast(buffers, params, mutable_data, size, shape, i_seq, vi_seq, bi_seq);
1610
        else
1611
            apply_trivial(buffers, params, mutable_data, size, i_seq, vi_seq, bi_seq);
1612

Henry Schreiner's avatar
Henry Schreiner committed
1613
        return std::move(result);
1614
    }
1615

1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
    template <size_t... Index, size_t... VIndex, size_t... BIndex>
    void apply_trivial(std::array<buffer_info, NVectorized> &buffers,
                       std::array<void *, N> &params,
                       Return *out,
                       size_t size,
                       index_sequence<Index...>, index_sequence<VIndex...>, index_sequence<BIndex...>) {

        // Initialize an array of mutable byte references and sizes with references set to the
        // appropriate pointer in `params`; as we iterate, we'll increment each pointer by its size
        // (except for singletons, which get an increment of 0).
        std::array<std::pair<unsigned char *&, const size_t>, NVectorized> vecparams{{
            std::pair<unsigned char *&, const size_t>(
                    reinterpret_cast<unsigned char *&>(params[VIndex] = buffers[BIndex].ptr),
                    buffers[BIndex].size == 1 ? 0 : sizeof(param_n_t<VIndex>)
            )...
        }};

        for (size_t i = 0; i < size; ++i) {
1634
            returned_array::call(out, i, f, *reinterpret_cast<param_n_t<Index> *>(params[Index])...);
1635
1636
1637
1638
1639
1640
1641
            for (auto &x : vecparams) x.first += x.second;
        }
    }

    template <size_t... Index, size_t... VIndex, size_t... BIndex>
    void apply_broadcast(std::array<buffer_info, NVectorized> &buffers,
                         std::array<void *, N> &params,
1642
1643
1644
                         Return *out,
                         size_t size,
                         const std::vector<ssize_t> &output_shape,
1645
                         index_sequence<Index...>, index_sequence<VIndex...>, index_sequence<BIndex...>) {
1646

1647
        multi_array_iterator<NVectorized> input_iter(buffers, output_shape);
1648

1649
        for (size_t i = 0; i < size; ++i, ++input_iter) {
1650
1651
1652
            PYBIND11_EXPAND_SIDE_EFFECTS((
                params[VIndex] = input_iter.template data<BIndex>()
            ));
1653
            returned_array::call(out, i, f, *reinterpret_cast<param_n_t<Index> *>(std::get<Index>(params))...);
1654
1655
        }
    }
1656
1657
};

// Build a vectorize_helper around `f`, deducing Return/Args from the
// function-pointer signature supplied as the second argument.
template <typename Func, typename Return, typename... Args>
vectorize_helper<Func, Return, Args...>
vectorize_extractor(const Func &f, Return (*) (Args ...)) {
    using helper_t = detail::vectorize_helper<Func, Return, Args...>;
    return helper_t(f);
}

// Signature name used in docstrings for array_t parameters/returns.
template <typename T, int Flags> struct handle_type_name<array_t<T, Flags>> {
    static constexpr auto name = _("numpy.ndarray[") + npy_format_descriptor<T>::name + _("]");
};

PYBIND11_NAMESPACE_END(detail)
// Vanilla pointer vectorizer:
1671
template <typename Return, typename... Args>
1672
detail::vectorize_helper<Return (*)(Args...), Return, Args...>
1673
vectorize(Return (*f) (Args ...)) {
1674
    return detail::vectorize_helper<Return (*)(Args...), Return, Args...>(f);
Wenzel Jakob's avatar
Wenzel Jakob committed
1675
1676
}

// lambda vectorizer:
1678
template <typename Func, detail::enable_if_t<detail::is_lambda<Func>::value, int> = 0>
1679
auto vectorize(Func &&f) -> decltype(
1680
1681
        detail::vectorize_extractor(std::forward<Func>(f), (detail::function_signature_t<Func> *) nullptr)) {
    return detail::vectorize_extractor(std::forward<Func>(f), (detail::function_signature_t<Func> *) nullptr);
1682
1683
1684
1685
1686
1687
1688
1689
1690
}

// Vectorize a class method (non-const):
// std::mem_fn turns the member pointer into a callable whose leading
// argument is the bound instance, exposed here as `Class *`.
template <typename Return, typename Class, typename... Args,
          typename Helper = detail::vectorize_helper<decltype(std::mem_fn(std::declval<Return (Class::*)(Args...)>())), Return, Class *, Args...>>
Helper vectorize(Return (Class::*f)(Args...)) {
    return Helper(std::mem_fn(f));
}

// Vectorize a class method (const):
1692
1693
1694
1695
template <typename Return, typename Class, typename... Args,
          typename Helper = detail::vectorize_helper<decltype(std::mem_fn(std::declval<Return (Class::*)(Args...) const>())), Return, const Class *, Args...>>
Helper vectorize(Return (Class::*f)(Args...) const) {
    return Helper(std::mem_fn(f));
Wenzel Jakob's avatar
Wenzel Jakob committed
1696
1697
}

PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)

#if defined(_MSC_VER)
#pragma warning(pop)
#endif