stl_bind.h 11.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
/*
    pybind11/stl_bind.h: Binding generators for STL data types

    Copyright (c) 2016 Sergey Lyskov and Wenzel Jakob

    All rights reserved. Use of this source code is governed by a
    BSD-style license that can be found in the LICENSE file.
*/

#pragma once

#include "common.h"
#include "operators.h"

#include <type_traits>
#include <utility>
#include <algorithm>
#include <sstream>

NAMESPACE_BEGIN(pybind11)
NAMESPACE_BEGIN(detail)

/* SFINAE helper class used by 'is_comparable'.

   Every probe below is a pair of overloads that are only ever used inside
   decltype(): the std::true_type overload is viable only when the probed
   expression or nested type is well-formed, otherwise overload resolution
   falls back to the variadic std::false_type overload. */
template <typename T>
struct container_traits {
    /* Probe: 'a == b' is a valid expression for two const lvalues of type U */
    template <typename U>
    static std::true_type test_comparable(
        decltype(std::declval<const U &>() == std::declval<const U &>()) *);
    template <typename U>
    static std::false_type test_comparable(...);

    /* Probe: U declares a nested 'value_type' */
    template <typename U>
    static std::true_type test_value(typename U::value_type *);
    template <typename U>
    static std::false_type test_value(...);

    /* Probe: U declares both 'first_type' and 'second_type' */
    template <typename U>
    static std::true_type test_pair(typename U::first_type *, typename U::second_type *);
    template <typename U>
    static std::false_type test_pair(...);

    /* operator== can be named for T (says nothing about T's elements) */
    static constexpr bool is_comparable = decltype(test_comparable<T>(nullptr))::value;
    /* T has first_type/second_type members */
    static constexpr bool is_pair = decltype(test_pair<T>(nullptr, nullptr))::value;
    /* T has a value_type member */
    static constexpr bool is_vector = decltype(test_value<T>(nullptr))::value;
    /* T is neither of the above, i.e. it is treated as a plain element */
    static constexpr bool is_element = !is_pair && !is_vector;
};

/* Primary template: unless one of the partial specializations is selected,
   a type is reported as not comparable (is_comparable -> std::false_type) */
template <typename Type, typename Enable = void>
struct is_comparable : std::false_type { };

/* Plain elements (no value_type, no first_type/second_type): comparable
   exactly when operator== can be instantiated for the type itself */
template <typename Element>
struct is_comparable<Element,
                     typename std::enable_if<container_traits<Element>::is_element &&
                                             container_traits<Element>::is_comparable>::type>
    : std::true_type { };

/* Types exposing a value_type (vector/map style data structures): the answer
   is whatever the recursive check yields for the value type, which is a
   std::pair in the case of maps */
template <typename Container>
struct is_comparable<Container,
                     typename std::enable_if<container_traits<Container>::is_vector>::type> {
    static constexpr bool value = is_comparable<typename Container::value_type>::value;
};

/* Pair-like types: comparable only if the recursive check succeeds for both
   the first and the second member type */
template <typename Pair>
struct is_comparable<Pair,
                     typename std::enable_if<container_traits<Pair>::is_pair>::type> {
    static constexpr bool value = is_comparable<typename Pair::first_type>::value &&
                                  is_comparable<typename Pair::second_type>::value;
};

/* Fallback functions: do-nothing overloads that accept any arguments by const
   reference. The constrained overloads of the same names take the class
   object by non-const reference, so they are preferred whenever their
   SFINAE condition holds; otherwise a call resolves to one of these. */
template <typename Vector, typename Class_, typename... Args>
void vector_if_copy_constructible(const Args &...) { }

template <typename Vector, typename Class_, typename... Args>
void vector_if_equal_operator(const Args &...) { }

template <typename Vector, typename Class_, typename... Args>
void vector_if_insertion_operator(const Args &...) { }

/* Adds a constructor taking 'const Vector &' to the bound class 'cl'.
   This overload is only enabled when Vector::value_type is copy
   constructible. */
template <typename Vector, typename Class_,
          typename std::enable_if<
              std::is_copy_constructible<typename Vector::value_type>::value, int>::type = 0>
void vector_if_copy_constructible(Class_ &cl) {
    cl.def(pybind11::init<const Vector &>(), "Copy constructor");
}

/* Adds every binding that has to compare elements with operator== to the
   bound class 'cl': the ==/!= operators plus ``count``, ``remove`` and
   ``__contains__``.
   This overload is only enabled when is_comparable<Vector>::value is true. */
template <typename Vector, typename Class_,
          typename std::enable_if<is_comparable<Vector>::value, int>::type = 0>
void vector_if_equal_operator(Class_ &cl) {
    using T = typename Vector::value_type;

    cl.def(self == self);
    cl.def(self != self);

    cl.def("count",
        [](const Vector &v, const T &x) {
            return std::count(v.begin(), v.end(), x);
        },
        arg("x"),
        "Return the number of times ``x`` appears in the list"
    );

    cl.def("remove", [](Vector &v, const T &x) {
            // Only the first match is erased; a missing value is reported
            // to the caller as a value_error
            auto p = std::find(v.begin(), v.end(), x);
            if (p != v.end())
                v.erase(p);
            else
                throw pybind11::value_error();
        },
        arg("x"),
        "Remove the first item from the list whose value is x. "
        "It is an error if there is no such item."
    );

    cl.def("__contains__",
        [](const Vector &v, const T &x) {
            return std::find(v.begin(), v.end(), x) != v.end();
        },
        arg("x"),
        "Return true if the container contains ``x``"
    );
}

111
/* Adds a ``__repr__`` method to the bound class 'cl' that renders the
   container as ``name[e0, e1, ...]``.

   The trailing return type is the SFINAE condition: it is only well-formed
   (and then yields void) when a Vector::value_type can be written to a
   std::ostream with operator<<. */
template <typename Vector, typename Class_>
auto vector_if_insertion_operator(Class_ &cl, std::string const &name)
    -> decltype(std::declval<std::ostream&>() << std::declval<typename Vector::value_type>(), void()) {
    using size_type = typename Vector::size_type;

    cl.def("__repr__",
           // 'name' is captured by value: the lambda is stored by cl.def()
           // while 'name' is only a reference parameter of this function
           [name](Vector &v) {
            std::ostringstream s;
            s << name << '[';
            for (size_type i=0; i < v.size(); ++i) {
                s << v[i];
                // separator between elements, but not after the last one
                if (i != v.size() - 1)
                    s << ", ";
            }
            s << ']';
            return s.str();
        },
        "Return the canonical string representation of this list."
    );
}

NAMESPACE_END(detail)


template <typename T, typename Allocator = std::allocator<T>, typename holder_type = std::unique_ptr<std::vector<T, Allocator>>, typename... Args>
135
pybind11::class_<std::vector<T, Allocator>, holder_type> bind_vector(pybind11::module &m, std::string const &name, Args&&... args) {
136
137
    using Vector = std::vector<T, Allocator>;
    using SizeType = typename Vector::size_type;
138
    using DiffType = typename Vector::difference_type;
139
    using ItType   = typename Vector::iterator;
140
141
    using Class_ = pybind11::class_<Vector, holder_type>;

142
    Class_ cl(m, name.c_str(), std::forward<Args>(args)...);
143
144
145

    cl.def(pybind11::init<>());

Wenzel Jakob's avatar
Wenzel Jakob committed
146
    // Register copy constructor (if possible)
147
148
    detail::vector_if_copy_constructible<Vector, Class_>(cl);

Wenzel Jakob's avatar
Wenzel Jakob committed
149
150
151
152
153
154
    // Register comparison-related operators and functions (if possible)
    detail::vector_if_equal_operator<Vector, Class_>(cl);

    // Register stream insertion operator (if possible)
    detail::vector_if_insertion_operator<Vector, Class_>(cl, name);

155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
    cl.def("__init__", [](Vector &v, iterable it) {
        new (&v) Vector();
        try {
            v.reserve(len(it));
            for (handle h : it)
               v.push_back(h.cast<typename Vector::value_type>());
        } catch (...) {
            v.~Vector();
            throw;
        }
    });
    cl.def("append", (void (Vector::*) (const T &)) & Vector::push_back,
           arg("x"),
           "Add an item to the end of the list");

    cl.def("extend",
       [](Vector &v, Vector &src) {
           v.reserve(v.size() + src.size());
           v.insert(v.end(), src.begin(), src.end());
       },
       arg("L"),
       "Extend the list by appending all the items in the given list"
    );

    cl.def("insert",
        [](Vector &v, SizeType i, const T &x) {
181
            v.insert(v.begin() + (DiffType) i, x);
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
        },
        arg("i") , arg("x"),
        "Insert an item at a given position."
    );

    cl.def("pop",
        [](Vector &v) {
            if (v.empty())
                throw pybind11::index_error();
            T t = v.back();
            v.pop_back();
            return t;
        },
        "Remove and return the last item"
    );

    cl.def("pop",
        [](Vector &v, SizeType i) {
            if (i >= v.size())
                throw pybind11::index_error();
            T t = v[i];
203
            v.erase(v.begin() + (DiffType) i);
204
205
206
207
208
209
210
211
212
213
214
215
216
217
            return t;
        },
        arg("i"),
        "Remove and return the item at index ``i``"
    );

    cl.def("__bool__",
        [](const Vector &v) -> bool {
            return !v.empty();
        },
        "Check whether the list is nonempty"
    );

    cl.def("__getitem__",
218
        [](const Vector &v, SizeType i) -> T {
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
            if (i >= v.size())
                throw pybind11::index_error();
            return v[i];
        }
    );

    cl.def("__setitem__",
        [](Vector &v, SizeType i, const T &t) {
            if (i >= v.size())
                throw pybind11::index_error();
            v[i] = t;
        }
    );

    cl.def("__delitem__",
        [](Vector &v, SizeType i) {
            if (i >= v.size())
                throw pybind11::index_error();
237
            v.erase(v.begin() + typename Vector::difference_type(i));
238
239
240
241
242
243
244
245
        },
        "Delete list elements using a slice object"
    );

    cl.def("__len__", &Vector::size);

    cl.def("__iter__",
        [](Vector &v) {
246
            return pybind11::make_iterator<ItType, T>(v.begin(), v.end());
247
248
249
250
251
252
253
        },
        pybind11::keep_alive<0, 1>() /* Essential: keep list alive while iterator exists */
    );

    /// Slicing protocol
    cl.def("__getitem__",
        [](const Vector &v, slice slice) -> Vector * {
254
            size_t start, stop, step, slicelength;
255
256
257
258
259
260
261

            if (!slice.compute(v.size(), &start, &stop, &step, &slicelength))
                throw pybind11::error_already_set();

            Vector *seq = new Vector();
            seq->reserve((size_t) slicelength);

262
            for (size_t i=0; i<slicelength; ++i) {
263
264
265
266
267
268
269
270
271
272
273
                seq->push_back(v[start]);
                start += step;
            }
            return seq;
        },
        arg("s"),
        "Retrieve list elements using a slice object"
    );

    cl.def("__setitem__",
        [](Vector &v, slice slice,  const Vector &value) {
274
            size_t start, stop, step, slicelength;
275
276
277
            if (!slice.compute(v.size(), &start, &stop, &step, &slicelength))
                throw pybind11::error_already_set();

278
            if (slicelength != value.size())
279
280
                throw std::runtime_error("Left and right hand size of slice assignment have different sizes!");

281
            for (size_t i=0; i<slicelength; ++i) {
282
283
284
285
286
287
288
289
290
                v[start] = value[i];
                start += step;
            }
        },
        "Assign list elements using a slice object"
    );

    cl.def("__delitem__",
        [](Vector &v, slice slice) {
291
            size_t start, stop, step, slicelength;
292
293
294
295
296

            if (!slice.compute(v.size(), &start, &stop, &step, &slicelength))
                throw pybind11::error_already_set();

            if (step == 1 && false) {
297
                v.erase(v.begin() + (DiffType) start, v.begin() + DiffType(start + slicelength));
298
            } else {
299
300
                for (size_t i = 0; i < slicelength; ++i) {
                    v.erase(v.begin() + DiffType(start));
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
                    start += step - 1;
                }
            }
        },
        "Delete list elements using a slice object"
    );

#if 0
    // C++ style functions deprecated, leaving it here as an example
    cl.def(pybind11::init<size_type>());

    cl.def("resize",
         (void (Vector::*) (size_type count)) & Vector::resize,
         "changes the number of elements stored");

    cl.def("erase",
        [](Vector &v, SizeType i) {
        if (i >= v.size())
            throw pybind11::index_error();
        v.erase(v.begin() + i);
    }, "erases element at index ``i``");

    cl.def("empty",         &Vector::empty,         "checks whether the container is empty");
    cl.def("size",          &Vector::size,          "returns the number of elements");
    cl.def("push_back", (void (Vector::*)(const T&)) &Vector::push_back, "adds an element to the end");
    cl.def("pop_back",                               &Vector::pop_back, "removes the last element");

    cl.def("max_size",      &Vector::max_size,      "returns the maximum possible number of elements");
    cl.def("reserve",       &Vector::reserve,       "reserves storage");
    cl.def("capacity",      &Vector::capacity,      "returns the number of elements that can be held in currently allocated storage");
    cl.def("shrink_to_fit", &Vector::shrink_to_fit, "reduces memory usage by freeing unused memory");

    cl.def("clear", &Vector::clear, "clears the contents");
    cl.def("swap",   &Vector::swap, "swaps the contents");

    cl.def("front", [](Vector &v) {
        if (v.size()) return v.front();
        else throw pybind11::index_error();
    }, "access the first element");

    cl.def("back", [](Vector &v) {
        if (v.size()) return v.back();
        else throw pybind11::index_error();
    }, "access the last element ");

#endif

    return cl;
}

NAMESPACE_END(pybind11)