/*
    pybind11/numpy.h: Basic NumPy support, vectorize() wrapper

    Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>

    All rights reserved. Use of this source code is governed by a
    BSD-style license that can be found in the LICENSE file.
*/

#pragma once

#include "pybind11.h"
#include "complex.h"

#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <initializer_list>
#include <numeric>
#include <sstream>
#include <string>

Wenzel Jakob's avatar
Wenzel Jakob committed
21
22
23
24
25
#if defined(_MSC_VER)
#pragma warning(push)
#pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
#endif

26
NAMESPACE_BEGIN(pybind11)
27
28
29
namespace detail {
template <typename type, typename SFINAE = void> struct npy_format_descriptor { };

30
31
object fix_dtype(object);

32
33
34
35
36
template <typename T>
struct is_pod_struct {
    enum { value = std::is_pod<T>::value && // offsetof only works correctly for POD types
           !std::is_integral<T>::value &&
           !std::is_same<T, float>::value &&
37
           !std::is_same<T, double>::value &&
38
39
40
41
42
           !std::is_same<T, bool>::value &&
           !std::is_same<T, std::complex<float>>::value &&
           !std::is_same<T, std::complex<double>>::value };
};
}
Wenzel Jakob's avatar
Wenzel Jakob committed
43

Wenzel Jakob's avatar
Wenzel Jakob committed
44
class array : public buffer {
Wenzel Jakob's avatar
Wenzel Jakob committed
45
public:
Wenzel Jakob's avatar
Wenzel Jakob committed
46
47
48
49
50
51
52
    struct API {
        enum Entries {
            API_PyArray_Type = 2,
            API_PyArray_DescrFromType = 45,
            API_PyArray_FromAny = 69,
            API_PyArray_NewCopy = 85,
            API_PyArray_NewFromDescr = 94,
53
            API_PyArray_DescrNewFromType = 9,
54
            API_PyArray_DescrConverter = 174,
55
            API_PyArray_EquivTypes = 182,
56
            API_PyArray_GetArrayParamsFromObject = 278,
57
58
59
60
61
62
63
64
65
66
67
68

            NPY_C_CONTIGUOUS_ = 0x0001,
            NPY_F_CONTIGUOUS_ = 0x0002,
            NPY_ARRAY_FORCECAST_ = 0x0010,
            NPY_ENSURE_ARRAY_ = 0x0040,
            NPY_BOOL_ = 0,
            NPY_BYTE_, NPY_UBYTE_,
            NPY_SHORT_, NPY_USHORT_,
            NPY_INT_, NPY_UINT_,
            NPY_LONG_, NPY_ULONG_,
            NPY_LONGLONG_, NPY_ULONGLONG_,
            NPY_FLOAT_, NPY_DOUBLE_, NPY_LONGDOUBLE_,
69
70
71
            NPY_CFLOAT_, NPY_CDOUBLE_, NPY_CLONGDOUBLE_,
            NPY_OBJECT_ = 17,
            NPY_STRING_, NPY_UNICODE_, NPY_VOID_
Wenzel Jakob's avatar
Wenzel Jakob committed
72
73
74
        };

        static API lookup() {
75
76
            module m = module::import("numpy.core.multiarray");
            object c = (object) m.attr("_ARRAY_API");
77
#if PY_MAJOR_VERSION >= 3
78
            void **api_ptr = (void **) (c ? PyCapsule_GetPointer(c.ptr(), NULL) : nullptr);
79
#else
80
            void **api_ptr = (void **) (c ? PyCObject_AsVoidPtr(c.ptr()) : nullptr);
81
#endif
Wenzel Jakob's avatar
Wenzel Jakob committed
82
            API api;
83
84
85
86
87
88
#define DECL_NPY_API(Func) api.Func##_ = (decltype(api.Func##_)) api_ptr[API_##Func];
            DECL_NPY_API(PyArray_Type);
            DECL_NPY_API(PyArray_DescrFromType);
            DECL_NPY_API(PyArray_FromAny);
            DECL_NPY_API(PyArray_NewCopy);
            DECL_NPY_API(PyArray_NewFromDescr);
89
            DECL_NPY_API(PyArray_DescrNewFromType);
90
            DECL_NPY_API(PyArray_DescrConverter);
91
            DECL_NPY_API(PyArray_EquivTypes);
92
            DECL_NPY_API(PyArray_GetArrayParamsFromObject);
93
#undef DECL_NPY_API
Wenzel Jakob's avatar
Wenzel Jakob committed
94
95
96
            return api;
        }

97
        bool PyArray_Check_(PyObject *obj) const { return (bool) PyObject_TypeCheck(obj, PyArray_Type_); }
Wenzel Jakob's avatar
Wenzel Jakob committed
98

99
100
        PyObject *(*PyArray_DescrFromType_)(int);
        PyObject *(*PyArray_NewFromDescr_)
Wenzel Jakob's avatar
Wenzel Jakob committed
101
102
            (PyTypeObject *, PyObject *, int, Py_intptr_t *,
             Py_intptr_t *, void *, int, PyObject *);
103
        PyObject *(*PyArray_DescrNewFromType_)(int);
104
105
106
        PyObject *(*PyArray_NewCopy_)(PyObject *, int);
        PyTypeObject *PyArray_Type_;
        PyObject *(*PyArray_FromAny_) (PyObject *, PyObject *, int, int, int, PyObject *);
107
        int (*PyArray_DescrConverter_) (PyObject *, PyObject **);
108
        bool (*PyArray_EquivTypes_) (PyObject *, PyObject *);
109
110
        int (*PyArray_GetArrayParamsFromObject_)(PyObject *, PyObject *, char, PyObject **, int *,
                                                 Py_ssize_t *, PyObject **, PyObject *);
Wenzel Jakob's avatar
Wenzel Jakob committed
111
    };
Wenzel Jakob's avatar
Wenzel Jakob committed
112

113
    PYBIND11_OBJECT_DEFAULT(array, buffer, lookup_api().PyArray_Check_)
Wenzel Jakob's avatar
Wenzel Jakob committed
114

115
116
    enum {
        c_style = API::NPY_C_CONTIGUOUS_,
117
118
        f_style = API::NPY_F_CONTIGUOUS_,
        forcecast = API::NPY_ARRAY_FORCECAST_
119
120
    };

Wenzel Jakob's avatar
Wenzel Jakob committed
121
122
    template <typename Type> array(size_t size, const Type *ptr) {
        API& api = lookup_api();
123
        PyObject *descr = detail::npy_format_descriptor<Type>::dtype().release().ptr();
Wenzel Jakob's avatar
Wenzel Jakob committed
124
        Py_intptr_t shape = (Py_intptr_t) size;
125
126
127
        object tmp = object(api.PyArray_NewFromDescr_(
            api.PyArray_Type_, descr, 1, &shape, nullptr, (void *) ptr, 0, nullptr), false);
        if (!tmp)
Wenzel Jakob's avatar
Wenzel Jakob committed
128
            pybind11_fail("NumPy: unable to create array!");
129
130
        if (ptr)
            tmp = object(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */), false);
Wenzel Jakob's avatar
Wenzel Jakob committed
131
        m_ptr = tmp.release().ptr();
Wenzel Jakob's avatar
Wenzel Jakob committed
132
133
134
    }

    array(const buffer_info &info) {
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
        auto& api = lookup_api();

        // _dtype_from_pep3118 returns dtypes with padding fields in, however the array
        // constructor seems to then consume them, so we don't need to strip them ourselves
        auto numpy_internal = module::import("numpy.core._internal");
        auto dtype_from_fmt = (object) numpy_internal.attr("_dtype_from_pep3118");
        auto dtype = dtype_from_fmt(pybind11::str(info.format));
        auto dtype2 = strip_padding_fields(dtype);

        object tmp(api.PyArray_NewFromDescr_(
            api.PyArray_Type_, dtype2.release().ptr(), (int) info.ndim, (Py_intptr_t *) &info.shape[0],
            (Py_intptr_t *) &info.strides[0], info.ptr, 0, nullptr), false);
        if (!tmp)
            pybind11_fail("NumPy: unable to create array!");
        if (info.ptr)
            tmp = object(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */), false);
        m_ptr = tmp.release().ptr();
        auto d = (object) this->attr("dtype");
Wenzel Jakob's avatar
Wenzel Jakob committed
153
154
    }

155
// protected:
Wenzel Jakob's avatar
Wenzel Jakob committed
156
157
158
159
    static API &lookup_api() {
        static API api = API::lookup();
        return api;
    }
160
161

    template <typename T, typename SFINAE> friend struct detail::npy_format_descriptor;
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203

    static object strip_padding_fields(object dtype) {
        // Recursively strip all void fields with empty names that are generated for
        // padding fields (as of NumPy v1.11).
        auto fields = dtype.attr("fields").cast<object>();
        if (fields.ptr() == Py_None)
            return dtype;

        struct field_descr { pybind11::str name; object format; int_ offset; };
        std::vector<field_descr> field_descriptors;

        auto items = fields.attr("items").cast<object>();
        for (auto field : items()) {
            auto spec = object(field, true).cast<tuple>();
            auto name = spec[0].cast<pybind11::str>();
            auto format = spec[1].cast<tuple>()[0].cast<object>();
            auto offset = spec[1].cast<tuple>()[1].cast<int_>();
            if (!len(name) && (std::string) dtype.attr("kind").cast<pybind11::str>() == "V")
                    continue;
            field_descriptors.push_back({name, strip_padding_fields(format), offset});
        }

        std::sort(field_descriptors.begin(), field_descriptors.end(),
                  [](const field_descr& a, const field_descr& b) {
                      return (int) a.offset < (int) b.offset;
                  });

        list names, formats, offsets;
        for (auto& descr : field_descriptors) {
            names.append(descr.name);
            formats.append(descr.format);
            offsets.append(descr.offset);
        }
        auto args = dict();
        args["names"] = names; args["formats"] = formats; args["offsets"] = offsets;
        args["itemsize"] = dtype.attr("itemsize").cast<int_>();

        PyObject *descr = nullptr;
        if (!lookup_api().PyArray_DescrConverter_(args.release().ptr(), &descr) || !descr)
            pybind11_fail("NumPy: failed to create structured dtype");
        return object(descr, false);
    }
Wenzel Jakob's avatar
Wenzel Jakob committed
204
205
};

206
template <typename T, int ExtraFlags = array::forcecast> class array_t : public array {
Wenzel Jakob's avatar
Wenzel Jakob committed
207
public:
208
    PYBIND11_OBJECT_CVT(array_t, array, is_non_null, m_ptr = ensure(m_ptr));
209
    array_t() : array() { }
Johan Mabille's avatar
Johan Mabille committed
210
    array_t(const buffer_info& info) : array(info) {}
Wenzel Jakob's avatar
Wenzel Jakob committed
211
    static bool is_non_null(PyObject *ptr) { return ptr != nullptr; }
212
    static PyObject *ensure(PyObject *ptr) {
213
214
        if (ptr == nullptr)
            return nullptr;
Wenzel Jakob's avatar
Wenzel Jakob committed
215
        API &api = lookup_api();
216
        PyObject *descr = detail::npy_format_descriptor<T>::dtype().release().ptr();
217
218
        PyObject *result = api.PyArray_FromAny_(ptr, descr, 0, 0,
                                                API::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr);
219
220
        if (!result)
            PyErr_Clear();
221
222
        Py_DECREF(ptr);
        return result;
Wenzel Jakob's avatar
Wenzel Jakob committed
223
224
225
    }
};

226
227
template <typename T>
struct format_descriptor<T, typename std::enable_if<detail::is_pod_struct<T>::value>::type> {
228
229
    static const char *format() {
        return detail::npy_format_descriptor<T>::format();
230
231
232
    }
};

233
234
/// NumPy dtype object for the C++ type T, as supplied by detail::npy_format_descriptor<T>.
template <typename T>
object dtype_of() {
    using descriptor = detail::npy_format_descriptor<T>;
    return descriptor::dtype();
}

238
239
NAMESPACE_BEGIN(detail)

240
241
template <typename T> struct npy_format_descriptor<T, typename std::enable_if<std::is_integral<T>::value>::type> {
private:
Johan Mabille's avatar
Johan Mabille committed
242
    constexpr static const int values[8] = {
243
244
245
        array::API::NPY_BYTE_, array::API::NPY_UBYTE_, array::API::NPY_SHORT_,    array::API::NPY_USHORT_,
        array::API::NPY_INT_,  array::API::NPY_UINT_,  array::API::NPY_LONGLONG_, array::API::NPY_ULONGLONG_ };
public:
246
247
248
249
250
    enum { value = values[detail::log2(sizeof(T)) * 2 + (std::is_unsigned<T>::value ? 1 : 0)] };
    static object dtype() {
        if (auto ptr = array::lookup_api().PyArray_DescrFromType_(value))
            return object(ptr, true);
        pybind11_fail("Unsupported buffer format!");
251
    }
252
253
254
255
    template <typename T2 = T, typename std::enable_if<std::is_signed<T2>::value, int>::type = 0>
    static PYBIND11_DESCR name() { return _("int") + _<sizeof(T)*8>(); }
    template <typename T2 = T, typename std::enable_if<!std::is_signed<T2>::value, int>::type = 0>
    static PYBIND11_DESCR name() { return _("uint") + _<sizeof(T)*8>(); }
256
257
258
259
};
template <typename T> constexpr const int npy_format_descriptor<
    T, typename std::enable_if<std::is_integral<T>::value>::type>::values[8];

260
#define DECL_FMT(Type, NumPyName, Name) template<> struct npy_format_descriptor<Type> { \
261
262
263
264
265
    enum { value = array::API::NumPyName }; \
    static object dtype() { \
        if (auto ptr = array::lookup_api().PyArray_DescrFromType_(value)) \
            return object(ptr, true); \
        pybind11_fail("Unsupported buffer format!"); \
266
    } \
267
    static PYBIND11_DESCR name() { return _(Name); } }
268
269
270
271
272
DECL_FMT(float, NPY_FLOAT_, "float32");
DECL_FMT(double, NPY_DOUBLE_, "float64");
DECL_FMT(bool, NPY_BOOL_, "bool");
DECL_FMT(std::complex<float>, NPY_CFLOAT_, "complex64");
DECL_FMT(std::complex<double>, NPY_CDOUBLE_, "complex128");
273
274
#undef DECL_FMT

275
276
struct field_descriptor {
    const char *name;
277
    size_t offset;
278
279
    size_t size;
    const char *format;
280
    object descr;
281
282
};

283

284
template <typename T>
Ivan Smirnov's avatar
Ivan Smirnov committed
285
struct npy_format_descriptor<T, typename std::enable_if<is_pod_struct<T>::value>::type> {
286
287
    static PYBIND11_DESCR name() { return _("user-defined"); }

288
289
    static object dtype() {
        if (!dtype_())
290
            pybind11_fail("NumPy: unsupported buffer format!");
291
        return object(dtype_(), true);
292
293
    }

294
    static const char* format() {
295
296
        if (!dtype_())
            pybind11_fail("NumPy: unsupported buffer format!");
297
        return format_();
298
299
300
    }

    static void register_dtype(std::initializer_list<field_descriptor> fields) {
301
        auto& api = array::lookup_api();
302
303
        auto args = dict();
        list names { }, offsets { }, formats { };
304
305
306
        for (auto field : fields) {
            if (!field.descr)
                pybind11_fail("NumPy: unsupported field dtype");
307
308
            names.append(str(field.name));
            offsets.append(int_(field.offset));
309
            formats.append(field.descr);
310
        }
311
        args["names"] = names; args["offsets"] = offsets; args["formats"] = formats;
312
        args["itemsize"] = int_(sizeof(T));
313
314
        // This is essentially the same as calling np.dtype() constructor in Python and passing
        // it a dict of the form {'names': ..., 'formats': ..., 'offsets': ...}.
315
        if (!api.PyArray_DescrConverter_(args.release().ptr(), &dtype_()) || !dtype_())
316
            pybind11_fail("NumPy: failed to create structured dtype");
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338

        // There is an existing bug in NumPy (as of v1.11): trailing bytes are
        // not encoded explicitly into the format string. This will supposedly
        // get fixed in v1.12; for further details, see these:
        // - https://github.com/numpy/numpy/issues/7797
        // - https://github.com/numpy/numpy/pull/7798
        // Because of this, we won't use numpy's logic to generate buffer format
        // strings and will just do it ourselves.
        std::vector<field_descriptor> ordered_fields(fields);
        std::sort(ordered_fields.begin(), ordered_fields.end(),
                  [](const field_descriptor& a, const field_descriptor &b) {
                      return a.offset < b.offset;
                  });
        size_t offset = 0;
        std::ostringstream oss;
        oss << "T{";
        for (auto& field : ordered_fields) {
            if (field.offset > offset)
                oss << (field.offset - offset) << 'x';
            // note that '=' is required to cover the case of unaligned fields
            oss << '=' << field.format << ':' << field.name << ':';
            offset = field.offset + field.size;
Ivan Smirnov's avatar
Ivan Smirnov committed
339
        }
340
341
342
343
344
345
346
347
348
349
350
351
        if (sizeof(T) > offset)
            oss << (sizeof(T) - offset) << 'x';
        oss << '}';
        std::strncpy(format_(), oss.str().c_str(), 4096);

        // Sanity check: verify that NumPy properly parses our buffer format string
        auto arr =  array(buffer_info(nullptr, sizeof(T), format(), 1, { 0 }, { sizeof(T) }));
        auto dtype = (object) arr.attr("dtype");
        auto fixed_dtype = dtype;
        // auto fixed_dtype = array::strip_padding_fields(object(dtype_(), true));
        // if (!api.PyArray_EquivTypes_(dtype_(), fixed_dtype.ptr()))
        //     pybind11_fail("NumPy: invalid buffer descriptor!");
352
353
354
    }

private:
355
    static inline PyObject*& dtype_() { static PyObject *ptr = nullptr; return ptr; }
356
    static inline char* format_() { static char s[4096]; return s; }
357
358
};

359
// Extract name, offset, size, buffer format string and NumPy dtype of a struct field as a
// detail::field_descriptor. `static_cast<Type*>(0)->Field` only appears as the operand of
// decltype, which is not evaluated, so no null pointer is dereferenced.
#define PYBIND11_FIELD_DESCRIPTOR(Type, Field) \
    ::pybind11::detail::field_descriptor { \
        #Field, offsetof(Type, Field), sizeof(decltype(static_cast<Type*>(0)->Field)), \
        ::pybind11::format_descriptor<decltype(static_cast<Type*>(0)->Field)>::format(), \
        ::pybind11::detail::npy_format_descriptor<decltype(static_cast<Type*>(0)->Field)>::dtype() \
    }

// The main idea of this macro is borrowed from https://github.com/swansontec/map-macro
// (C) William Swanson, Paul Fultz
//
// How it works:
//  - PYBIND11_EVAL expands PYBIND11_EVAL0 3^5 = 243 times; every expansion rescans the
//    text. Those rescans drive the deferred recursion below, so they roughly bound how many
//    list elements can be processed.
//  - PYBIND11_MAP_OUT expands to nothing. Placed between a macro name and its argument list,
//    it defers that invocation to the next rescan; a macro cannot invoke itself directly.
//  - PYBIND11_MAP_GET_END only expands when followed by `()`, i.e. when the look-ahead
//    element is the `()` sentinel appended by PYBIND11_MAP_LIST. Its expansion
//    `0, PYBIND11_MAP_END` shifts the arguments of PYBIND11_MAP_NEXT0, selecting
//    PYBIND11_MAP_END (which swallows its arguments) instead of the continuation.
#define PYBIND11_EVAL0(...) __VA_ARGS__
#define PYBIND11_EVAL1(...) PYBIND11_EVAL0 (PYBIND11_EVAL0 (PYBIND11_EVAL0 (__VA_ARGS__)))
#define PYBIND11_EVAL2(...) PYBIND11_EVAL1 (PYBIND11_EVAL1 (PYBIND11_EVAL1 (__VA_ARGS__)))
#define PYBIND11_EVAL3(...) PYBIND11_EVAL2 (PYBIND11_EVAL2 (PYBIND11_EVAL2 (__VA_ARGS__)))
#define PYBIND11_EVAL4(...) PYBIND11_EVAL3 (PYBIND11_EVAL3 (PYBIND11_EVAL3 (__VA_ARGS__)))
#define PYBIND11_EVAL(...)  PYBIND11_EVAL4 (PYBIND11_EVAL4 (PYBIND11_EVAL4 (__VA_ARGS__)))
#define PYBIND11_MAP_END(...)
#define PYBIND11_MAP_OUT
#define PYBIND11_MAP_COMMA ,
#define PYBIND11_MAP_GET_END() 0, PYBIND11_MAP_END
#define PYBIND11_MAP_NEXT0(test, next, ...) next PYBIND11_MAP_OUT
#define PYBIND11_MAP_NEXT1(test, next) PYBIND11_MAP_NEXT0 (test, next, 0)
#define PYBIND11_MAP_NEXT(test, next)  PYBIND11_MAP_NEXT1 (PYBIND11_MAP_GET_END test, next)
#ifdef _MSC_VER // MSVC is not as eager to expand macros, hence this workaround
#define PYBIND11_MAP_LIST_NEXT1(test, next) \
    PYBIND11_EVAL0 (PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0))
#else
#define PYBIND11_MAP_LIST_NEXT1(test, next) \
    PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0)
#endif
#define PYBIND11_MAP_LIST_NEXT(test, next) \
    PYBIND11_MAP_LIST_NEXT1 (PYBIND11_MAP_GET_END test, next)
#define PYBIND11_MAP_LIST0(f, t, x, peek, ...) \
    f(t, x) PYBIND11_MAP_LIST_NEXT (peek, PYBIND11_MAP_LIST1) (f, t, peek, __VA_ARGS__)
#define PYBIND11_MAP_LIST1(f, t, x, peek, ...) \
    f(t, x) PYBIND11_MAP_LIST_NEXT (peek, PYBIND11_MAP_LIST0) (f, t, peek, __VA_ARGS__)
// PYBIND11_MAP_LIST(f, t, a1, a2, ...) expands to f(t, a1), f(t, a2), ...
#define PYBIND11_MAP_LIST(f, t, ...) \
    PYBIND11_EVAL (PYBIND11_MAP_LIST1 (f, t, __VA_ARGS__, (), 0))

// Register the NumPy dtype and buffer format string of POD struct `Type` from a list of its
// data members, e.g. PYBIND11_NUMPY_DTYPE(MyStruct, a, b, c);
#define PYBIND11_NUMPY_DTYPE(Type, ...) \
    ::pybind11::detail::npy_format_descriptor<Type>::register_dtype \
        ({PYBIND11_MAP_LIST (PYBIND11_FIELD_DESCRIPTOR, Type, __VA_ARGS__)})
402

403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
/// Pointer used as an iterator over the elements of a buffer viewed as type T.
template <class T>
using array_iterator = typename std::add_pointer<T>::type;

/// Iterator to the first element of `buffer`.
template <class T>
array_iterator<T> array_begin(const buffer_info& buffer) {
    T *first = reinterpret_cast<T *>(buffer.ptr);
    return array_iterator<T>(first);
}

/// Iterator one past the last of the `buffer.size` elements of `buffer`.
template <class T>
array_iterator<T> array_end(const buffer_info& buffer) {
    return array_begin<T>(buffer) + buffer.size;
}

/// Byte pointer into a strided buffer, advanced one dimension at a time.
/// Callers (multi_array_iterator) advance only the innermost dimension that did not wrap.
class common_iterator {
public:
    using container_type = std::vector<size_t>;
    using value_type = container_type::value_type;
    using size_type = container_type::size_type;

    common_iterator() : p_ptr(nullptr), m_strides() {}

    /// `strides` holds the per-dimension byte strides of the buffer (0 for broadcast
    /// dimensions); `shape` is the iteration shape and has the same length.
    common_iterator(void* ptr, const container_type& strides, const std::vector<size_t>& shape)
        : p_ptr(reinterpret_cast<char*>(ptr)), m_strides(strides.size()) {
        // A 0-dimensional iteration has nothing to advance; back() on an empty vector is UB.
        if (m_strides.empty())
            return;
        m_strides.back() = static_cast<value_type>(strides.back());
        // m_strides[j] is the pointer delta applied when dimension j advances while every
        // inner dimension i > j wraps from shape[i]-1 back to 0:
        //   strides[j] - sum over i > j of (shape[i] - 1) * strides[i].
        // Net-negative deltas rely on unsigned wrap-around.
        for (size_type i = m_strides.size() - 1; i != 0; --i) {
            size_type j = i - 1;
            value_type s = static_cast<value_type>(shape[i]);
            m_strides[j] = strides[j] + m_strides[i] - strides[i] * s;
        }
    }

    /// Move to the next element when dimension `dim` is the one being advanced.
    void increment(size_type dim) {
        p_ptr += m_strides[dim];
    }

    /// Address of the current element.
    void* data() const {
        return p_ptr;
    }

private:
    char* p_ptr;
    container_type m_strides;
};

447
template <size_t N> class multi_array_iterator {
448
449
450
public:
    using container_type = std::vector<size_t>;

451
452
453
454
455
    multi_array_iterator(const std::array<buffer_info, N> &buffers,
                         const std::vector<size_t> &shape)
        : m_shape(shape.size()), m_index(shape.size(), 0),
          m_common_iterator() {

456
        // Manual copy to avoid conversion warning if using std::copy
457
        for (size_t i = 0; i < shape.size(); ++i)
458
459
460
            m_shape[i] = static_cast<container_type::value_type>(shape[i]);

        container_type strides(shape.size());
461
        for (size_t i = 0; i < N; ++i)
462
463
464
465
466
467
468
469
470
            init_common_iterator(buffers[i], shape, m_common_iterator[i], strides);
    }

    multi_array_iterator& operator++() {
        for (size_t j = m_index.size(); j != 0; --j) {
            size_t i = j - 1;
            if (++m_index[i] != m_shape[i]) {
                increment_common_iterator(i);
                break;
471
            } else {
472
473
474
475
476
477
                m_index[i] = 0;
            }
        }
        return *this;
    }

478
    template <size_t K, class T> const T& data() const {
479
480
481
482
483
484
485
        return *reinterpret_cast<T*>(m_common_iterator[K].data());
    }

private:

    using common_iter = common_iterator;

486
487
488
    void init_common_iterator(const buffer_info &buffer,
                              const std::vector<size_t> &shape,
                              common_iter &iterator, container_type &strides) {
489
490
491
492
493
494
495
        auto buffer_shape_iter = buffer.shape.rbegin();
        auto buffer_strides_iter = buffer.strides.rbegin();
        auto shape_iter = shape.rbegin();
        auto strides_iter = strides.rbegin();

        while (buffer_shape_iter != buffer.shape.rend()) {
            if (*shape_iter == *buffer_shape_iter)
496
                *strides_iter = static_cast<size_t>(*buffer_strides_iter);
497
498
499
500
501
502
503
504
505
506
507
508
509
510
            else
                *strides_iter = 0;

            ++buffer_shape_iter;
            ++buffer_strides_iter;
            ++shape_iter;
            ++strides_iter;
        }

        std::fill(strides_iter, strides.rend(), 0);
        iterator = common_iter(buffer.ptr, strides, shape);
    }

    void increment_common_iterator(size_t dim) {
511
        for (auto &iter : m_common_iterator)
512
513
514
515
516
517
518
519
520
            iter.increment(dim);
    }

    container_type m_shape;
    container_type m_index;
    std::array<common_iter, N> m_common_iterator;
};

template <size_t N>
521
522
bool broadcast(const std::array<buffer_info, N>& buffers, size_t& ndim, std::vector<size_t>& shape) {
    ndim = std::accumulate(buffers.begin(), buffers.end(), size_t(0), [](size_t res, const buffer_info& buf) {
523
524
525
        return std::max(res, buf.ndim);
    });

526
    shape = std::vector<size_t>(ndim, 1);
527
528
529
530
    bool trivial_broadcast = true;
    for (size_t i = 0; i < N; ++i) {
        auto res_iter = shape.rbegin();
        bool i_trivial_broadcast = (buffers[i].size == 1) || (buffers[i].ndim == ndim);
531
532
533
534
        for (auto shape_iter = buffers[i].shape.rbegin();
             shape_iter != buffers[i].shape.rend(); ++shape_iter, ++res_iter) {

            if (*res_iter == 1)
535
                *res_iter = *shape_iter;
536
            else if ((*shape_iter != 1) && (*res_iter != *shape_iter))
537
                pybind11_fail("pybind11::vectorize: incompatible size/dimension of inputs!");
538

539
540
541
542
543
544
545
            i_trivial_broadcast = i_trivial_broadcast && (*res_iter == *shape_iter);
        }
        trivial_broadcast = trivial_broadcast && i_trivial_broadcast;
    }
    return trivial_broadcast;
}

546
547
548
549
template <typename Func, typename Return, typename... Args>
struct vectorize_helper {
    typename std::remove_reference<Func>::type f;

550
551
    template <typename T>
    vectorize_helper(T&&f) : f(std::forward<T>(f)) { }
Wenzel Jakob's avatar
Wenzel Jakob committed
552

553
    object operator()(array_t<Args, array::c_style | array::forcecast>... args) {
554
555
        return run(args..., typename make_index_sequence<sizeof...(Args)>::type());
    }
Wenzel Jakob's avatar
Wenzel Jakob committed
556

557
    template <size_t ... Index> object run(array_t<Args, array::c_style | array::forcecast>&... args, index_sequence<Index...> index) {
Wenzel Jakob's avatar
Wenzel Jakob committed
558
        /* Request buffers from all parameters */
559
        const size_t N = sizeof...(Args);
560

Wenzel Jakob's avatar
Wenzel Jakob committed
561
562
563
        std::array<buffer_info, N> buffers {{ args.request()... }};

        /* Determine dimensions parameters of output array */
564
        size_t ndim = 0;
565
566
        std::vector<size_t> shape(0);
        bool trivial_broadcast = broadcast(buffers, ndim, shape);
567

568
        size_t size = 1;
Wenzel Jakob's avatar
Wenzel Jakob committed
569
570
        std::vector<size_t> strides(ndim);
        if (ndim > 0) {
571
            strides[ndim-1] = sizeof(Return);
572
            for (size_t i = ndim - 1; i > 0; --i) {
573
574
575
576
                strides[i - 1] = strides[i] * shape[i];
                size *= shape[i];
            }
            size *= shape[0];
Wenzel Jakob's avatar
Wenzel Jakob committed
577
578
        }

579
        if (size == 1)
580
            return cast(f(*((Args *) buffers[Index].ptr)...));
Wenzel Jakob's avatar
Wenzel Jakob committed
581

582
        array result(buffer_info(nullptr, sizeof(Return),
583
            format_descriptor<Return>::format(),
Wenzel Jakob's avatar
Wenzel Jakob committed
584
            ndim, shape, strides));
585
586
587
588

        buffer_info buf = result.request();
        Return *output = (Return *) buf.ptr;

589
        if (trivial_broadcast) {
590
591
592
            /* Call the function */
            for (size_t i=0; i<size; ++i) {
                output[i] = f((buffers[Index].size == 1
593
594
                               ? *((Args *) buffers[Index].ptr)
                               : ((Args *) buffers[Index].ptr)[i])...);
595
            }
596
        } else {
597
598
            apply_broadcast<N, Index...>(buffers, buf, index);
        }
599
600

        return result;
601
    }
602
603

    template <size_t N, size_t... Index>
604
605
    void apply_broadcast(const std::array<buffer_info, N> &buffers,
                         buffer_info &output, index_sequence<Index...>) {
606
607
608
609
610
611
        using input_iterator = multi_array_iterator<N>;
        using output_iterator = array_iterator<Return>;

        input_iterator input_iter(buffers, output.shape);
        output_iterator output_end = array_end<Return>(output);

612
613
        for (output_iterator iter = array_begin<Return>(output);
             iter != output_end; ++iter, ++input_iter) {
614
615
616
            *iter = f((input_iter.template data<Index, Args>())...);
        }
    }
617
618
};

619
/// Signature name shown for array_t<T, Flags> parameters: "numpy.ndarray[" + name of T + "]".
template <typename T, int Flags> struct handle_type_name<array_t<T, Flags>> {
    static PYBIND11_DESCR name() {
        return _("numpy.ndarray[")
             + type_caster<T>::name()
             + _("]");
    }
};

623
NAMESPACE_END(detail)
Wenzel Jakob's avatar
Wenzel Jakob committed
624

625
626
627
/// Vectorize callable `f`. Its scalar signature is supplied by the (unused) function-pointer
/// tag argument.
template <typename Func, typename Return, typename... Args>
detail::vectorize_helper<Func, Return, Args...> vectorize(const Func &f, Return (*) (Args ...)) {
    using helper_type = detail::vectorize_helper<Func, Return, Args...>;
    return helper_type(f);
}

/// Vectorize a plain function pointer; its own type supplies the signature.
template <typename Return, typename... Args>
detail::vectorize_helper<Return (*) (Args ...), Return, Args...> vectorize(Return (*f) (Args ...)) {
    using function_pointer = Return (*) (Args ...);
    return vectorize<function_pointer, Return, Args...>(f, f);
}

/// Vectorize a function object (e.g. a lambda). The signature is deduced from its
/// (non-overloaded) operator().
template <typename func> auto vectorize(func &&f) -> decltype(
        vectorize(std::forward<func>(f), (typename detail::remove_class<decltype(&std::remove_reference<func>::type::operator())>::type *) nullptr)) {
    using signature_ptr = typename detail::remove_class<decltype(
        &std::remove_reference<func>::type::operator())>::type *;
    return vectorize(std::forward<func>(f), static_cast<signature_ptr>(nullptr));
}

641
NAMESPACE_END(pybind11)
Wenzel Jakob's avatar
Wenzel Jakob committed
642
643
644
645

#if defined(_MSC_VER)
#pragma warning(pop)
#endif