/*
    pybind11/numpy.h: Basic NumPy support, vectorize() wrapper

    Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>

    All rights reserved. Use of this source code is governed by a
    BSD-style license that can be found in the LICENSE file.
*/

#pragma once

12
13
#include "pybind11.h"
#include "complex.h"
14
15
#include <numeric>
#include <algorithm>
16
#include <array>
17
#include <cstdlib>
18
#include <cstring>
19
#include <sstream>
20
#include <initializer_list>
21

Wenzel Jakob's avatar
Wenzel Jakob committed
22
23
24
25
26
#if defined(_MSC_VER)
#pragma warning(push)
#pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
#endif

27
NAMESPACE_BEGIN(pybind11)
28
29
namespace detail {
// Primary template: intentionally empty. Supported element types are provided
// by the specializations further down in this header.
template <typename type, typename SFINAE = void> struct npy_format_descriptor { };
Ivan Smirnov's avatar
Ivan Smirnov committed
30
// Forward declaration; the trait itself is defined later in this header.
template <typename type> struct is_pod_struct;
Wenzel Jakob's avatar
Wenzel Jakob committed
31

32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
struct npy_api {
    enum constants {
        NPY_C_CONTIGUOUS_ = 0x0001,
        NPY_F_CONTIGUOUS_ = 0x0002,
        NPY_ARRAY_FORCECAST_ = 0x0010,
        NPY_ENSURE_ARRAY_ = 0x0040,
        NPY_BOOL_ = 0,
        NPY_BYTE_, NPY_UBYTE_,
        NPY_SHORT_, NPY_USHORT_,
        NPY_INT_, NPY_UINT_,
        NPY_LONG_, NPY_ULONG_,
        NPY_LONGLONG_, NPY_ULONGLONG_,
        NPY_FLOAT_, NPY_DOUBLE_, NPY_LONGDOUBLE_,
        NPY_CFLOAT_, NPY_CDOUBLE_, NPY_CLONGDOUBLE_,
        NPY_OBJECT_ = 17,
        NPY_STRING_, NPY_UNICODE_, NPY_VOID_
    };

    static npy_api& get() {
        static npy_api api = lookup();
        return api;
    }

55
56
57
58
59
60
    bool PyArray_Check_(PyObject *obj) const {
        return (bool) PyObject_TypeCheck(obj, PyArray_Type_);
    }
    bool PyArrayDescr_Check_(PyObject *obj) const {
        return (bool) PyObject_TypeCheck(obj, PyArrayDescr_Type_);
    }
61
62
63
64
65
66
67
68

    PyObject *(*PyArray_DescrFromType_)(int);
    PyObject *(*PyArray_NewFromDescr_)
        (PyTypeObject *, PyObject *, int, Py_intptr_t *,
         Py_intptr_t *, void *, int, PyObject *);
    PyObject *(*PyArray_DescrNewFromType_)(int);
    PyObject *(*PyArray_NewCopy_)(PyObject *, int);
    PyTypeObject *PyArray_Type_;
69
    PyTypeObject *PyArrayDescr_Type_;
70
71
72
73
74
75
76
77
    PyObject *(*PyArray_FromAny_) (PyObject *, PyObject *, int, int, int, PyObject *);
    int (*PyArray_DescrConverter_) (PyObject *, PyObject **);
    bool (*PyArray_EquivTypes_) (PyObject *, PyObject *);
    int (*PyArray_GetArrayParamsFromObject_)(PyObject *, PyObject *, char, PyObject **, int *,
                                             Py_ssize_t *, PyObject **, PyObject *);
private:
    enum functions {
        API_PyArray_Type = 2,
78
        API_PyArrayDescr_Type = 3,
79
80
81
82
83
84
85
86
87
88
89
90
91
        API_PyArray_DescrFromType = 45,
        API_PyArray_FromAny = 69,
        API_PyArray_NewCopy = 85,
        API_PyArray_NewFromDescr = 94,
        API_PyArray_DescrNewFromType = 9,
        API_PyArray_DescrConverter = 174,
        API_PyArray_EquivTypes = 182,
        API_PyArray_GetArrayParamsFromObject = 278,
    };

    static npy_api lookup() {
        module m = module::import("numpy.core.multiarray");
        object c = (object) m.attr("_ARRAY_API");
92
#if PY_MAJOR_VERSION >= 3
93
        void **api_ptr = (void **) (c ? PyCapsule_GetPointer(c.ptr(), NULL) : nullptr);
94
#else
95
        void **api_ptr = (void **) (c ? PyCObject_AsVoidPtr(c.ptr()) : nullptr);
96
#endif
97
        npy_api api;
98
#define DECL_NPY_API(Func) api.Func##_ = (decltype(api.Func##_)) api_ptr[API_##Func];
99
        DECL_NPY_API(PyArray_Type);
100
        DECL_NPY_API(PyArrayDescr_Type);
101
102
103
104
105
106
107
108
        DECL_NPY_API(PyArray_DescrFromType);
        DECL_NPY_API(PyArray_FromAny);
        DECL_NPY_API(PyArray_NewCopy);
        DECL_NPY_API(PyArray_NewFromDescr);
        DECL_NPY_API(PyArray_DescrNewFromType);
        DECL_NPY_API(PyArray_DescrConverter);
        DECL_NPY_API(PyArray_EquivTypes);
        DECL_NPY_API(PyArray_GetArrayParamsFromObject);
109
#undef DECL_NPY_API
110
111
112
113
        return api;
    }
};
}
Wenzel Jakob's avatar
Wenzel Jakob committed
114

115
class dtype : public object {
116
public:
117
    PYBIND11_OBJECT_DEFAULT(dtype, object, detail::npy_api::get().PyArrayDescr_Check_);
Wenzel Jakob's avatar
Wenzel Jakob committed
118

119
120
121
122
    dtype(const buffer_info &info) {
        dtype descr(_dtype_from_pep3118()(pybind11::str(info.format)));
        m_ptr = descr.strip_padding().release().ptr();
    }
123

124
125
    dtype(std::string format) {
        m_ptr = from_args(pybind11::str(format)).release().ptr();
Wenzel Jakob's avatar
Wenzel Jakob committed
126
127
    }

128
129
130
131
132
133
134
    static dtype from_args(object args) {
        // This is essentially the same as calling np.dtype() constructor in Python
        PyObject *ptr = nullptr;
        if (!detail::npy_api::get().PyArray_DescrConverter_(args.release().ptr(), &ptr) || !ptr)
            pybind11_fail("NumPy: failed to create structured dtype");
        return object(ptr, false);
    }
135

136
137
138
    template <typename T> static dtype of() {
        return detail::npy_format_descriptor<T>::dtype();
    }
139

140
141
    size_t itemsize() const {
        return (size_t) attr("itemsize").cast<int_>();
Wenzel Jakob's avatar
Wenzel Jakob committed
142
143
    }

144
145
146
147
148
149
150
151
152
153
154
155
156
    bool has_fields() const {
        return attr("fields").cast<object>().ptr() != Py_None;
    }

    std::string kind() const {
        return (std::string) attr("kind").cast<pybind11::str>();
    }

private:
    static object& _dtype_from_pep3118() {
        static object obj = module::import("numpy.core._internal").attr("_dtype_from_pep3118");
        return obj;
    }
157

158
    dtype strip_padding() {
159
160
        // Recursively strip all void fields with empty names that are generated for
        // padding fields (as of NumPy v1.11).
161
        auto fields = attr("fields").cast<object>();
162
        if (fields.ptr() == Py_None)
163
            return *this;
164
165
166
167
168
169
170
171

        struct field_descr { pybind11::str name; object format; int_ offset; };
        std::vector<field_descr> field_descriptors;

        auto items = fields.attr("items").cast<object>();
        for (auto field : items()) {
            auto spec = object(field, true).cast<tuple>();
            auto name = spec[0].cast<pybind11::str>();
172
            auto format = spec[1].cast<tuple>()[0].cast<dtype>();
173
            auto offset = spec[1].cast<tuple>()[1].cast<int_>();
174
            if (!len(name) && format.kind() == "V")
175
                continue;
176
            field_descriptors.push_back({name, format.strip_padding(), offset});
177
178
179
180
181
182
183
184
185
        }

        std::sort(field_descriptors.begin(), field_descriptors.end(),
                  [](const field_descr& a, const field_descr& b) {
                      return (int) a.offset < (int) b.offset;
                  });

        list names, formats, offsets;
        for (auto& descr : field_descriptors) {
186
            names.append(descr.name); formats.append(descr.format); offsets.append(descr.offset);
187
188
189
        }
        auto args = dict();
        args["names"] = names; args["formats"] = formats; args["offsets"] = offsets;
190
191
192
193
        args["itemsize"] = (int_) itemsize();
        return dtype::from_args(args);
    }
};
194

195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
class array : public buffer {
public:
    PYBIND11_OBJECT_DEFAULT(array, buffer, detail::npy_api::get().PyArray_Check_)

    enum {
        c_style = detail::npy_api::NPY_C_CONTIGUOUS_,
        f_style = detail::npy_api::NPY_F_CONTIGUOUS_,
        forcecast = detail::npy_api::NPY_ARRAY_FORCECAST_
    };

    template <typename Type> array(size_t size, const Type *ptr) {
        auto& api = detail::npy_api::get();
        auto descr = pybind11::dtype::of<Type>().release().ptr();
        Py_intptr_t shape = (Py_intptr_t) size;
        object tmp = object(api.PyArray_NewFromDescr_(
            api.PyArray_Type_, descr, 1, &shape, nullptr, (void *) ptr, 0, nullptr), false);
        if (!tmp)
            pybind11_fail("NumPy: unable to create array!");
        if (ptr)
            tmp = object(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */), false);
        m_ptr = tmp.release().ptr();
    }

    array(const buffer_info &info) {
        auto& api = detail::npy_api::get();
        auto descr = pybind11::dtype(info).release().ptr();
        object tmp(api.PyArray_NewFromDescr_(
            api.PyArray_Type_, descr, (int) info.ndim, (Py_intptr_t *) &info.shape[0],
            (Py_intptr_t *) &info.strides[0], info.ptr, 0, nullptr), false);
        if (!tmp)
            pybind11_fail("NumPy: unable to create array!");
        if (info.ptr)
            tmp = object(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */), false);
        m_ptr = tmp.release().ptr();
229
    }
230
231
232
233
234
235
236

    pybind11::dtype dtype() {
        return attr("dtype").cast<pybind11::dtype>();
    }

protected:
    template <typename T, typename SFINAE> friend struct detail::npy_format_descriptor;
Wenzel Jakob's avatar
Wenzel Jakob committed
237
238
};

239
template <typename T, int ExtraFlags = array::forcecast> class array_t : public array {
Wenzel Jakob's avatar
Wenzel Jakob committed
240
public:
241
    PYBIND11_OBJECT_CVT(array_t, array, is_non_null, m_ptr = ensure(m_ptr));
242
    array_t() : array() { }
Johan Mabille's avatar
Johan Mabille committed
243
    array_t(const buffer_info& info) : array(info) {}
Wenzel Jakob's avatar
Wenzel Jakob committed
244
    static bool is_non_null(PyObject *ptr) { return ptr != nullptr; }
245
    static PyObject *ensure(PyObject *ptr) {
246
247
        if (ptr == nullptr)
            return nullptr;
248
        auto& api = detail::npy_api::get();
249
        PyObject *result = api.PyArray_FromAny_(ptr, pybind11::dtype::of<T>().release().ptr(), 0, 0,
250
                                                detail::npy_api::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr);
251
252
        if (!result)
            PyErr_Clear();
253
254
        Py_DECREF(ptr);
        return result;
Wenzel Jakob's avatar
Wenzel Jakob committed
255
256
257
    }
};

258
259
template <typename T>
struct format_descriptor<T, typename std::enable_if<detail::is_pod_struct<T>::value>::type> {
260
261
262
263
264
265
266
267
    static const char *format() { return detail::npy_format_descriptor<T>::format(); }
};

/// Buffer format string for fixed-size character buffers: "<N>s".
template <size_t N> struct format_descriptor<char[N]> {
    static const char *format() { PYBIND11_DESCR s = detail::_<N>() + detail::_("s"); return s.text(); }
};
template <size_t N> struct format_descriptor<std::array<char, N>> {
    static const char *format() { PYBIND11_DESCR s = detail::_<N>() + detail::_("s"); return s.text(); }
};

270
NAMESPACE_BEGIN(detail)
Ivan Smirnov's avatar
Ivan Smirnov committed
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
// Trait that is true only for instantiations of std::array.
template <typename T>
struct is_std_array : std::false_type { };

template <typename T, size_t N>
struct is_std_array<std::array<T, N>> : std::true_type { };

// Trait selecting the types that are handled as structured (record) dtypes:
// POD types that are not already covered by one of the dedicated descriptors
// (arrays, std::array, integers, float, double, bool, complex<float/double>).
template <typename T>
struct is_pod_struct {
    enum {
        value = !(std::is_array<T>::value ||
                  is_std_array<T>::value ||
                  std::is_integral<T>::value ||
                  std::is_same<T, float>::value ||
                  std::is_same<T, double>::value ||
                  std::is_same<T, bool>::value ||
                  std::is_same<T, std::complex<float>>::value ||
                  std::is_same<T, std::complex<double>>::value) &&
                std::is_pod<T>::value // offsetof only works correctly for POD types
    };
};
286

287
288
template <typename T> struct npy_format_descriptor<T, typename std::enable_if<std::is_integral<T>::value>::type> {
private:
Johan Mabille's avatar
Johan Mabille committed
289
    constexpr static const int values[8] = {
290
291
        npy_api::NPY_BYTE_, npy_api::NPY_UBYTE_, npy_api::NPY_SHORT_,    npy_api::NPY_USHORT_,
        npy_api::NPY_INT_,  npy_api::NPY_UINT_,  npy_api::NPY_LONGLONG_, npy_api::NPY_ULONGLONG_ };
292
public:
293
    enum { value = values[detail::log2(sizeof(T)) * 2 + (std::is_unsigned<T>::value ? 1 : 0)] };
294
    static pybind11::dtype dtype() {
295
        if (auto ptr = npy_api::get().PyArray_DescrFromType_(value))
296
297
            return object(ptr, true);
        pybind11_fail("Unsupported buffer format!");
298
    }
299
300
301
302
    template <typename T2 = T, typename std::enable_if<std::is_signed<T2>::value, int>::type = 0>
    static PYBIND11_DESCR name() { return _("int") + _<sizeof(T)*8>(); }
    template <typename T2 = T, typename std::enable_if<!std::is_signed<T2>::value, int>::type = 0>
    static PYBIND11_DESCR name() { return _("uint") + _<sizeof(T)*8>(); }
303
304
305
306
};
template <typename T> constexpr const int npy_format_descriptor<
    T, typename std::enable_if<std::is_integral<T>::value>::type>::values[8];

307
#define DECL_FMT(Type, NumPyName, Name) template<> struct npy_format_descriptor<Type> { \
308
    enum { value = npy_api::NumPyName }; \
309
    static pybind11::dtype dtype() { \
310
        if (auto ptr = npy_api::get().PyArray_DescrFromType_(value)) \
311
312
            return object(ptr, true); \
        pybind11_fail("Unsupported buffer format!"); \
313
    } \
314
    static PYBIND11_DESCR name() { return _(Name); } }
315
316
317
318
319
DECL_FMT(float, NPY_FLOAT_, "float32");
DECL_FMT(double, NPY_DOUBLE_, "float64");
DECL_FMT(bool, NPY_BOOL_, "bool");
DECL_FMT(std::complex<float>, NPY_CFLOAT_, "complex64");
DECL_FMT(std::complex<double>, NPY_CDOUBLE_, "complex128");
320
321
#undef DECL_FMT

322
323
// Shared body of the descriptors for fixed-size character buffers of length N:
// the dtype is created from the string "S<N>", the buffer format is "<N>s".
#define DECL_CHAR_FMT \
    static PYBIND11_DESCR name() { return _("S") + _<N>(); } \
    static pybind11::dtype dtype() { \
        PYBIND11_DESCR fmt = _("S") + _<N>(); \
        return pybind11::dtype::from_args(pybind11::str(fmt.text())); \
    } \
    static const char *format() { PYBIND11_DESCR s = _<N>() + _("s"); return s.text(); }
template <size_t N> struct npy_format_descriptor<char[N]> { DECL_CHAR_FMT };
template <size_t N> struct npy_format_descriptor<std::array<char, N>> { DECL_CHAR_FMT };
#undef DECL_CHAR_FMT

333
334
struct field_descriptor {
    const char *name;
335
    size_t offset;
336
337
    size_t size;
    const char *format;
338
    dtype descr;
339
340
};

341
template <typename T>
Ivan Smirnov's avatar
Ivan Smirnov committed
342
struct npy_format_descriptor<T, typename std::enable_if<is_pod_struct<T>::value>::type> {
343
    static PYBIND11_DESCR name() { return _("struct"); }
344

345
    static pybind11::dtype dtype() {
346
        if (!dtype_())
347
            pybind11_fail("NumPy: unsupported buffer format!");
348
        return object(dtype_(), true);
349
350
    }

351
    static const char* format() {
352
353
        if (!dtype_())
            pybind11_fail("NumPy: unsupported buffer format!");
354
        return format_().c_str();
355
356
357
    }

    static void register_dtype(std::initializer_list<field_descriptor> fields) {
358
359
        auto args = dict();
        list names { }, offsets { }, formats { };
360
361
362
        for (auto field : fields) {
            if (!field.descr)
                pybind11_fail("NumPy: unsupported field dtype");
363
364
            names.append(str(field.name));
            offsets.append(int_(field.offset));
365
            formats.append(field.descr);
366
        }
367
        args["names"] = names; args["offsets"] = offsets; args["formats"] = formats;
368
        args["itemsize"] = int_(sizeof(T));
369
        dtype_() = pybind11::dtype::from_args(args).release().ptr();
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391

        // There is an existing bug in NumPy (as of v1.11): trailing bytes are
        // not encoded explicitly into the format string. This will supposedly
        // get fixed in v1.12; for further details, see these:
        // - https://github.com/numpy/numpy/issues/7797
        // - https://github.com/numpy/numpy/pull/7798
        // Because of this, we won't use numpy's logic to generate buffer format
        // strings and will just do it ourselves.
        std::vector<field_descriptor> ordered_fields(fields);
        std::sort(ordered_fields.begin(), ordered_fields.end(),
                  [](const field_descriptor& a, const field_descriptor &b) {
                      return a.offset < b.offset;
                  });
        size_t offset = 0;
        std::ostringstream oss;
        oss << "T{";
        for (auto& field : ordered_fields) {
            if (field.offset > offset)
                oss << (field.offset - offset) << 'x';
            // note that '=' is required to cover the case of unaligned fields
            oss << '=' << field.format << ':' << field.name << ':';
            offset = field.offset + field.size;
Ivan Smirnov's avatar
Ivan Smirnov committed
392
        }
393
394
395
        if (sizeof(T) > offset)
            oss << (sizeof(T) - offset) << 'x';
        oss << '}';
396
        format_() = oss.str();
397
398

        // Sanity check: verify that NumPy properly parses our buffer format string
399
        auto& api = npy_api::get();
400
        auto arr =  array(buffer_info(nullptr, sizeof(T), format(), 1, { 0 }, { sizeof(T) }));
401
        if (!api.PyArray_EquivTypes_(dtype_(), arr.dtype().ptr()))
402
            pybind11_fail("NumPy: invalid buffer descriptor!");
403
404
405
    }

private:
406
    static inline PyObject*& dtype_() { static PyObject *ptr = nullptr; return ptr; }
407
    static inline std::string& format_() { static std::string s; return s; }
408
409
};

410
// Extract name, offset and format descriptor for a struct field
#define PYBIND11_FIELD_DESCRIPTOR(Type, Field) \
    ::pybind11::detail::field_descriptor { \
        #Field, offsetof(Type, Field), sizeof(decltype(static_cast<Type*>(0)->Field)), \
        ::pybind11::format_descriptor<decltype(static_cast<Type*>(0)->Field)>::format(), \
        ::pybind11::detail::npy_format_descriptor<decltype(static_cast<Type*>(0)->Field)>::dtype() \
    }
417
418
419

// The main idea of this macro is borrowed from https://github.com/swansontec/map-macro
// (C) William Swanson, Paul Fultz
#define PYBIND11_EVAL0(...) __VA_ARGS__
#define PYBIND11_EVAL1(...) PYBIND11_EVAL0 (PYBIND11_EVAL0 (PYBIND11_EVAL0 (__VA_ARGS__)))
#define PYBIND11_EVAL2(...) PYBIND11_EVAL1 (PYBIND11_EVAL1 (PYBIND11_EVAL1 (__VA_ARGS__)))
#define PYBIND11_EVAL3(...) PYBIND11_EVAL2 (PYBIND11_EVAL2 (PYBIND11_EVAL2 (__VA_ARGS__)))
#define PYBIND11_EVAL4(...) PYBIND11_EVAL3 (PYBIND11_EVAL3 (PYBIND11_EVAL3 (__VA_ARGS__)))
#define PYBIND11_EVAL(...)  PYBIND11_EVAL4 (PYBIND11_EVAL4 (PYBIND11_EVAL4 (__VA_ARGS__)))
#define PYBIND11_MAP_END(...)
#define PYBIND11_MAP_OUT
#define PYBIND11_MAP_COMMA ,
#define PYBIND11_MAP_GET_END() 0, PYBIND11_MAP_END
#define PYBIND11_MAP_NEXT0(test, next, ...) next PYBIND11_MAP_OUT
#define PYBIND11_MAP_NEXT1(test, next) PYBIND11_MAP_NEXT0 (test, next, 0)
#define PYBIND11_MAP_NEXT(test, next)  PYBIND11_MAP_NEXT1 (PYBIND11_MAP_GET_END test, next)
#ifdef _MSC_VER // MSVC is not as eager to expand macros, hence this workaround
#define PYBIND11_MAP_LIST_NEXT1(test, next) \
    PYBIND11_EVAL0 (PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0))
#else
#define PYBIND11_MAP_LIST_NEXT1(test, next) \
    PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0)
#endif
#define PYBIND11_MAP_LIST_NEXT(test, next) \
    PYBIND11_MAP_LIST_NEXT1 (PYBIND11_MAP_GET_END test, next)
#define PYBIND11_MAP_LIST0(f, t, x, peek, ...) \
    f(t, x) PYBIND11_MAP_LIST_NEXT (peek, PYBIND11_MAP_LIST1) (f, t, peek, __VA_ARGS__)
#define PYBIND11_MAP_LIST1(f, t, x, peek, ...) \
    f(t, x) PYBIND11_MAP_LIST_NEXT (peek, PYBIND11_MAP_LIST0) (f, t, peek, __VA_ARGS__)
// PYBIND11_MAP_LIST(f, t, a1, a2, ...) expands to f(t, a1), f(t, a2), ...
#define PYBIND11_MAP_LIST(f, t, ...) \
    PYBIND11_EVAL (PYBIND11_MAP_LIST1 (f, t, __VA_ARGS__, (), 0))
449

450
// Registers the structured dtype of `Type`: one field descriptor is generated
// for every member name listed after `Type`.
#define PYBIND11_NUMPY_DTYPE(Type, ...) \
    ::pybind11::detail::npy_format_descriptor<Type>::register_dtype \
        ({PYBIND11_MAP_LIST (PYBIND11_FIELD_DESCRIPTOR, Type, __VA_ARGS__)})
453

454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
// Iterator over the elements of a buffer: a plain pointer to T.
template <class T>
using array_iterator = typename std::add_pointer<T>::type;

// Pointer to the first element of `buffer`, viewed as T.
template <class T>
array_iterator<T> array_begin(const buffer_info& buffer) {
    T *first = reinterpret_cast<T*>(buffer.ptr);
    return array_iterator<T>(first);
}

// Pointer `buffer.size` elements past the first element of `buffer`.
template <class T>
array_iterator<T> array_end(const buffer_info& buffer) {
    T *first = reinterpret_cast<T*>(buffer.ptr);
    return array_iterator<T>(first + buffer.size);
}

/// Byte-wise cursor into one buffer that is stepped along the dimensions of a
/// common shape.
///
/// For each dimension the constructor stores the byte offset to add when the
/// index of that dimension is advanced by one while all later dimensions wrap
/// from their last index back to zero.
class common_iterator {
public:
    using container_type = std::vector<size_t>;
    using value_type = container_type::value_type;
    using size_type = container_type::size_type;

    common_iterator() : p_ptr(nullptr), m_strides() {}

    /// @param ptr     start of the buffer
    /// @param strides byte stride of the buffer per dimension
    /// @param shape   extent per dimension; must have at least as many
    ///                entries as `strides`
    common_iterator(void* ptr, const container_type& strides, const std::vector<size_t>& shape)
        : p_ptr(reinterpret_cast<char*>(ptr)), m_strides(strides.size()) {
        // Zero-dimensional case: nothing to step over. Without this guard
        // back() is called on an empty vector and the loop index underflows.
        if (m_strides.empty())
            return;
        m_strides.back() = static_cast<value_type>(strides.back());
        for (size_type i = m_strides.size() - 1; i != 0; --i) {
            size_type j = i - 1;
            value_type s = static_cast<value_type>(shape[i]);
            // Step in dimension j, minus the distance dimension i travelled
            // over its full extent (unsigned arithmetic wraps as intended)
            m_strides[j] = strides[j] + m_strides[i] - strides[i] * s;
        }
    }

    /// Advances the cursor by the precomputed offset for dimension `dim`.
    void increment(size_type dim) {
        p_ptr += m_strides[dim];
    }

    /// Current position.
    void* data() const {
        return p_ptr;
    }

private:
    char* p_ptr;
    container_type m_strides;
};

498
template <size_t N> class multi_array_iterator {
499
500
501
public:
    using container_type = std::vector<size_t>;

502
503
504
505
506
    multi_array_iterator(const std::array<buffer_info, N> &buffers,
                         const std::vector<size_t> &shape)
        : m_shape(shape.size()), m_index(shape.size(), 0),
          m_common_iterator() {

507
        // Manual copy to avoid conversion warning if using std::copy
508
        for (size_t i = 0; i < shape.size(); ++i)
509
510
511
            m_shape[i] = static_cast<container_type::value_type>(shape[i]);

        container_type strides(shape.size());
512
        for (size_t i = 0; i < N; ++i)
513
514
515
516
517
518
519
520
521
            init_common_iterator(buffers[i], shape, m_common_iterator[i], strides);
    }

    multi_array_iterator& operator++() {
        for (size_t j = m_index.size(); j != 0; --j) {
            size_t i = j - 1;
            if (++m_index[i] != m_shape[i]) {
                increment_common_iterator(i);
                break;
522
            } else {
523
524
525
526
527
528
                m_index[i] = 0;
            }
        }
        return *this;
    }

529
    template <size_t K, class T> const T& data() const {
530
531
532
533
534
535
536
        return *reinterpret_cast<T*>(m_common_iterator[K].data());
    }

private:

    using common_iter = common_iterator;

537
538
539
    void init_common_iterator(const buffer_info &buffer,
                              const std::vector<size_t> &shape,
                              common_iter &iterator, container_type &strides) {
540
541
542
543
544
545
546
        auto buffer_shape_iter = buffer.shape.rbegin();
        auto buffer_strides_iter = buffer.strides.rbegin();
        auto shape_iter = shape.rbegin();
        auto strides_iter = strides.rbegin();

        while (buffer_shape_iter != buffer.shape.rend()) {
            if (*shape_iter == *buffer_shape_iter)
547
                *strides_iter = static_cast<size_t>(*buffer_strides_iter);
548
549
550
551
552
553
554
555
556
557
558
559
560
561
            else
                *strides_iter = 0;

            ++buffer_shape_iter;
            ++buffer_strides_iter;
            ++shape_iter;
            ++strides_iter;
        }

        std::fill(strides_iter, strides.rend(), 0);
        iterator = common_iter(buffer.ptr, strides, shape);
    }

    void increment_common_iterator(size_t dim) {
562
        for (auto &iter : m_common_iterator)
563
564
565
566
567
568
569
570
571
            iter.increment(dim);
    }

    container_type m_shape;
    container_type m_index;
    std::array<common_iter, N> m_common_iterator;
};

template <size_t N>
572
573
bool broadcast(const std::array<buffer_info, N>& buffers, size_t& ndim, std::vector<size_t>& shape) {
    ndim = std::accumulate(buffers.begin(), buffers.end(), size_t(0), [](size_t res, const buffer_info& buf) {
574
575
576
        return std::max(res, buf.ndim);
    });

577
    shape = std::vector<size_t>(ndim, 1);
578
579
580
581
    bool trivial_broadcast = true;
    for (size_t i = 0; i < N; ++i) {
        auto res_iter = shape.rbegin();
        bool i_trivial_broadcast = (buffers[i].size == 1) || (buffers[i].ndim == ndim);
582
583
584
585
        for (auto shape_iter = buffers[i].shape.rbegin();
             shape_iter != buffers[i].shape.rend(); ++shape_iter, ++res_iter) {

            if (*res_iter == 1)
586
                *res_iter = *shape_iter;
587
            else if ((*shape_iter != 1) && (*res_iter != *shape_iter))
588
                pybind11_fail("pybind11::vectorize: incompatible size/dimension of inputs!");
589

590
591
592
593
594
595
596
            i_trivial_broadcast = i_trivial_broadcast && (*res_iter == *shape_iter);
        }
        trivial_broadcast = trivial_broadcast && i_trivial_broadcast;
    }
    return trivial_broadcast;
}

597
598
599
600
template <typename Func, typename Return, typename... Args>
struct vectorize_helper {
    typename std::remove_reference<Func>::type f;

601
602
    template <typename T>
    vectorize_helper(T&&f) : f(std::forward<T>(f)) { }
Wenzel Jakob's avatar
Wenzel Jakob committed
603

604
    object operator()(array_t<Args, array::c_style | array::forcecast>... args) {
605
606
        return run(args..., typename make_index_sequence<sizeof...(Args)>::type());
    }
Wenzel Jakob's avatar
Wenzel Jakob committed
607

608
    template <size_t ... Index> object run(array_t<Args, array::c_style | array::forcecast>&... args, index_sequence<Index...> index) {
Wenzel Jakob's avatar
Wenzel Jakob committed
609
        /* Request buffers from all parameters */
610
        const size_t N = sizeof...(Args);
611

Wenzel Jakob's avatar
Wenzel Jakob committed
612
613
614
        std::array<buffer_info, N> buffers {{ args.request()... }};

        /* Determine dimensions parameters of output array */
615
        size_t ndim = 0;
616
617
        std::vector<size_t> shape(0);
        bool trivial_broadcast = broadcast(buffers, ndim, shape);
618

619
        size_t size = 1;
Wenzel Jakob's avatar
Wenzel Jakob committed
620
621
        std::vector<size_t> strides(ndim);
        if (ndim > 0) {
622
            strides[ndim-1] = sizeof(Return);
623
            for (size_t i = ndim - 1; i > 0; --i) {
624
625
626
627
                strides[i - 1] = strides[i] * shape[i];
                size *= shape[i];
            }
            size *= shape[0];
Wenzel Jakob's avatar
Wenzel Jakob committed
628
629
        }

630
        if (size == 1)
631
            return cast(f(*((Args *) buffers[Index].ptr)...));
Wenzel Jakob's avatar
Wenzel Jakob committed
632

633
        array result(buffer_info(nullptr, sizeof(Return),
634
            format_descriptor<Return>::format(),
Wenzel Jakob's avatar
Wenzel Jakob committed
635
            ndim, shape, strides));
636
637
638
639

        buffer_info buf = result.request();
        Return *output = (Return *) buf.ptr;

640
        if (trivial_broadcast) {
641
642
643
            /* Call the function */
            for (size_t i=0; i<size; ++i) {
                output[i] = f((buffers[Index].size == 1
644
645
                               ? *((Args *) buffers[Index].ptr)
                               : ((Args *) buffers[Index].ptr)[i])...);
646
            }
647
        } else {
648
649
            apply_broadcast<N, Index...>(buffers, buf, index);
        }
650
651

        return result;
652
    }
653
654

    template <size_t N, size_t... Index>
655
656
    void apply_broadcast(const std::array<buffer_info, N> &buffers,
                         buffer_info &output, index_sequence<Index...>) {
657
658
659
660
661
662
        using input_iterator = multi_array_iterator<N>;
        using output_iterator = array_iterator<Return>;

        input_iterator input_iter(buffers, output.shape);
        output_iterator output_end = array_end<Return>(output);

663
664
        for (output_iterator iter = array_begin<Return>(output);
             iter != output_end; ++iter, ++input_iter) {
665
666
667
            *iter = f((input_iter.template data<Index, Args>())...);
        }
    }
668
669
};

670
/// Type name of array_t<T, Flags>: "numpy.ndarray[<name of T>]".
template <typename T, int Flags> struct handle_type_name<array_t<T, Flags>> {
    static PYBIND11_DESCR name() { return _("numpy.ndarray[") + type_caster<T>::name() + _("]"); }
};

674
NAMESPACE_END(detail)
Wenzel Jakob's avatar
Wenzel Jakob committed
675

676
677
678
/// Wraps `f` in a vectorize_helper. The second parameter is never read; it
/// only carries the signature from which Return and Args are deduced.
template <typename Func, typename Return, typename... Args>
detail::vectorize_helper<Func, Return, Args...> vectorize(const Func &f, Return (*) (Args ...)) {
    return detail::vectorize_helper<Func, Return, Args...>(f);
}

681
682
683
/// Overload for plain function pointers: the pointer supplies both the
/// callable and its signature.
template <typename Return, typename... Args>
detail::vectorize_helper<Return (*) (Args ...), Return, Args...> vectorize(Return (*f) (Args ...)) {
    return vectorize<Return (*) (Args ...), Return, Args...>(f, f);
}

// Overload for function objects: the signature is taken from the type of the
// object's operator() (with the class part removed by detail::remove_class)
// and passed on as a null function pointer of that type.
template <typename func> auto vectorize(func &&f) -> decltype(
        vectorize(std::forward<func>(f), (typename detail::remove_class<decltype(&std::remove_reference<func>::type::operator())>::type *) nullptr)) {
    return vectorize(std::forward<func>(f), (typename detail::remove_class<decltype(
                   &std::remove_reference<func>::type::operator())>::type *) nullptr);
}

692
NAMESPACE_END(pybind11)
Wenzel Jakob's avatar
Wenzel Jakob committed
693
694
695
696

#if defined(_MSC_VER)
#pragma warning(pop)
#endif