eigen.h 10.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
/*
    pybind11/eigen.h: Transparent conversion for dense and sparse Eigen matrices

    Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>

    All rights reserved. Use of this source code is governed by a
    BSD-style license that can be found in the LICENSE file.
*/

#pragma once

#include "numpy.h"
13

Wenzel Jakob's avatar
Wenzel Jakob committed
14
15
16
#if defined(__INTEL_COMPILER)
#  pragma warning(disable: 1682) // implicit conversion of a 64-bit integral type to a smaller integral type (potential portability problem)
#elif defined(__GNUG__) || defined(__clang__)
17
18
#  pragma GCC diagnostic push
#  pragma GCC diagnostic ignored "-Wconversion"
19
#  pragma GCC diagnostic ignored "-Wdeprecated-declarations"
20
21
#endif

22
23
24
#include <Eigen/Core>
#include <Eigen/SparseCore>

25
26
27
28
#if defined(__GNUG__) || defined(__clang__)
#  pragma GCC diagnostic pop
#endif

29
30
31
32
33
34
35
36
#if defined(_MSC_VER)
#pragma warning(push)
#pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
#endif

NAMESPACE_BEGIN(pybind11)
NAMESPACE_BEGIN(detail)

37
38
39
// Compile-time predicates classifying a type by the Eigen class template it is checked against.
// Each alias simply forwards to is_template_base_of with the corresponding Eigen template.

// Checked against Eigen::DenseBase.
template <typename Candidate>
using is_eigen_dense = is_template_base_of<Eigen::DenseBase, Candidate>;

// Checked against Eigen::SparseMatrixBase.
template <typename Candidate>
using is_eigen_sparse = is_template_base_of<Eigen::SparseMatrixBase, Candidate>;

// Checked against Eigen::RefBase.
template <typename Candidate>
using is_eigen_ref = is_template_base_of<Eigen::RefBase, Candidate>;
40

41
42
43
44
// Test for objects inheriting from EigenBase<Derived> that aren't captured by the above.  This
// basically covers anything that can be assigned to a dense matrix but that don't have a typical
// matrix data layout that can be copied from their .data().  For example, DiagonalMatrix and
// SelfAdjointView fall into this category.
45
46
47
48
template <typename T> using is_eigen_base = bool_constant<
    is_template_base_of<Eigen::EigenBase, T>::value
    && !is_eigen_dense<T>::value && !is_eigen_sparse<T>::value
>;
49

50
template<typename Type>
51
struct type_caster<Type, enable_if_t<is_eigen_dense<Type>::value && !is_eigen_ref<Type>::value>> {
52
53
    typedef typename Type::Scalar Scalar;
    static constexpr bool rowMajor = Type::Flags & Eigen::RowMajorBit;
54
    static constexpr bool isVector = Type::IsVectorAtCompileTime;
55
56

    bool load(handle src, bool) {
57
58
59
        array_t<Scalar> buf(src, true);
        if (!buf.check())
            return false;
60

61
        if (buf.ndim() == 1) {
62
            typedef Eigen::InnerStride<> Strides;
63
            if (!isVector &&
64
65
66
67
68
                !(Type::RowsAtCompileTime == Eigen::Dynamic &&
                  Type::ColsAtCompileTime == Eigen::Dynamic))
                return false;

            if (Type::SizeAtCompileTime != Eigen::Dynamic &&
69
                buf.shape(0) != (size_t) Type::SizeAtCompileTime)
70
71
                return false;

72
            Strides::Index n_elts = (Strides::Index) buf.shape(0);
73
            Strides::Index unity = 1;
74
75

            value = Eigen::Map<Type, 0, Strides>(
76
77
78
79
80
81
                buf.mutable_data(),
                rowMajor ? unity : n_elts,
                rowMajor ? n_elts : unity,
                Strides(buf.strides(0) / sizeof(Scalar))
            );
        } else if (buf.ndim() == 2) {
82
83
            typedef Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic> Strides;

84
85
            if ((Type::RowsAtCompileTime != Eigen::Dynamic && buf.shape(0) != (size_t) Type::RowsAtCompileTime) ||
                (Type::ColsAtCompileTime != Eigen::Dynamic && buf.shape(1) != (size_t) Type::ColsAtCompileTime))
86
87
88
                return false;

            value = Eigen::Map<Type, 0, Strides>(
89
90
91
92
93
94
                buf.mutable_data(),
                typename Strides::Index(buf.shape(0)),
                typename Strides::Index(buf.shape(1)),
                Strides(buf.strides(rowMajor ? 0 : 1) / sizeof(Scalar),
                        buf.strides(rowMajor ? 1 : 0) / sizeof(Scalar))
            );
95
96
97
98
99
100
101
        } else {
            return false;
        }
        return true;
    }

    static handle cast(const Type &src, return_value_policy /* policy */, handle /* parent */) {
102
        if (isVector) {
103
104
105
106
107
            return array(
                { (size_t) src.size() },                                      // shape
                { sizeof(Scalar) * static_cast<size_t>(src.innerStride()) },  // strides
                src.data()                                                    // data
            ).release();
108
        } else {
109
110
            return array(
                { (size_t) src.rows(),                                        // shape
111
                  (size_t) src.cols() },
112
113
114
115
                { sizeof(Scalar) * static_cast<size_t>(src.rowStride()),      // strides
                  sizeof(Scalar) * static_cast<size_t>(src.colStride()) },
                src.data()                                                    // data
            ).release();
116
        }
117
118
    }

119
120
    PYBIND11_TYPE_CASTER(Type, _("numpy.ndarray[") + npy_format_descriptor<Scalar>::name() +
            _("[") + rows() + _(", ") + cols() + _("]]"));
121

122
protected:
123
    template <typename T = Type, enable_if_t<T::RowsAtCompileTime == Eigen::Dynamic, int> = 0>
124
    static PYBIND11_DESCR rows() { return _("m"); }
125
    template <typename T = Type, enable_if_t<T::RowsAtCompileTime != Eigen::Dynamic, int> = 0>
126
    static PYBIND11_DESCR rows() { return _<T::RowsAtCompileTime>(); }
127
    template <typename T = Type, enable_if_t<T::ColsAtCompileTime == Eigen::Dynamic, int> = 0>
128
    static PYBIND11_DESCR cols() { return _("n"); }
129
    template <typename T = Type, enable_if_t<T::ColsAtCompileTime != Eigen::Dynamic, int> = 0>
130
131
132
    static PYBIND11_DESCR cols() { return _<T::ColsAtCompileTime>(); }
};

133
134
135
136
// Eigen::Ref<Derived> satisfies is_eigen_dense, but isn't constructable, so it needs a special
// type_caster to handle argument copying/forwarding.
template <typename CVDerived, int Options, typename StrideType>
struct type_caster<Eigen::Ref<CVDerived, Options, StrideType>> {
137
protected:
138
139
    using Type = Eigen::Ref<CVDerived, Options, StrideType>;
    using Derived = typename std::remove_const<CVDerived>::type;
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
    using DerivedCaster = type_caster<Derived>;
    DerivedCaster derived_caster;
    std::unique_ptr<Type> value;
public:
    bool load(handle src, bool convert) { if (derived_caster.load(src, convert)) { value.reset(new Type(derived_caster.operator Derived&())); return true; } return false; }
    static handle cast(const Type &src, return_value_policy policy, handle parent) { return DerivedCaster::cast(src, policy, parent); }
    static handle cast(const Type *src, return_value_policy policy, handle parent) { return DerivedCaster::cast(*src, policy, parent); }

    static PYBIND11_DESCR name() { return DerivedCaster::name(); }

    operator Type*() { return value.get(); }
    operator Type&() { if (!value) pybind11_fail("Eigen::Ref<...> value not loaded"); return *value; }
    template <typename _T> using cast_op_type = pybind11::detail::cast_op_type<_T>;
};

155
156
157
// type_caster for special matrix types (e.g. DiagonalMatrix): load() is not supported, but we can
// cast them into the python domain by first copying to a regular Eigen::Matrix, then casting that.
template <typename Type>
158
struct type_caster<Type, enable_if_t<is_eigen_base<Type>::value && !is_eigen_ref<Type>::value>> {
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
protected:
    using Matrix = Eigen::Matrix<typename Type::Scalar, Eigen::Dynamic, Eigen::Dynamic>;
    using MatrixCaster = type_caster<Matrix>;
public:
    [[noreturn]] bool load(handle, bool) { pybind11_fail("Unable to load() into specialized EigenBase object"); }
    static handle cast(const Type &src, return_value_policy policy, handle parent) { return MatrixCaster::cast(Matrix(src), policy, parent); }
    static handle cast(const Type *src, return_value_policy policy, handle parent) { return MatrixCaster::cast(Matrix(*src), policy, parent); }

    static PYBIND11_DESCR name() { return MatrixCaster::name(); }

    [[noreturn]] operator Type*() { pybind11_fail("Loading not supported for specialized EigenBase object"); }
    [[noreturn]] operator Type&() { pybind11_fail("Loading not supported for specialized EigenBase object"); }
    template <typename _T> using cast_op_type = pybind11::detail::cast_op_type<_T>;
};

174
template<typename Type>
175
struct type_caster<Type, enable_if_t<is_eigen_sparse<Type>::value>> {
176
177
178
179
180
181
    typedef typename Type::Scalar Scalar;
    typedef typename std::remove_reference<decltype(*std::declval<Type>().outerIndexPtr())>::type StorageIndex;
    typedef typename Type::Index Index;
    static constexpr bool rowMajor = Type::Flags & Eigen::RowMajorBit;

    bool load(handle src, bool) {
182
183
184
        if (!src)
            return false;

185
186
187
188
189
190
191
        object obj(src, true);
        object sparse_module = module::import("scipy.sparse");
        object matrix_type = sparse_module.attr(
            rowMajor ? "csr_matrix" : "csc_matrix");

        if (obj.get_type() != matrix_type.ptr()) {
            try {
192
                obj = matrix_type(obj);
193
194
195
196
197
            } catch (const error_already_set &) {
                return false;
            }
        }

198
199
200
        auto values = array_t<Scalar>((object) obj.attr("data"));
        auto innerIndices = array_t<StorageIndex>((object) obj.attr("indices"));
        auto outerIndices = array_t<StorageIndex>((object) obj.attr("indptr"));
201
202
203
        auto shape = pybind11::tuple((pybind11::object) obj.attr("shape"));
        auto nnz = obj.attr("nnz").cast<Index>();

204
        if (!values.check() || !innerIndices.check() || !outerIndices.check())
205
206
207
            return false;

        value = Eigen::MappedSparseMatrix<Scalar, Type::Flags, StorageIndex>(
208
209
            shape[0].cast<Index>(), shape[1].cast<Index>(), nnz,
            outerIndices.mutable_data(), innerIndices.mutable_data(), values.mutable_data());
210
211
212
213
214
215
216
217
218
219

        return true;
    }

    static handle cast(const Type &src, return_value_policy /* policy */, handle /* parent */) {
        const_cast<Type&>(src).makeCompressed();

        object matrix_type = module::import("scipy.sparse").attr(
            rowMajor ? "csr_matrix" : "csc_matrix");

220
221
222
        array data((size_t) src.nonZeros(), src.valuePtr());
        array outerIndices((size_t) (rowMajor ? src.rows() : src.cols()) + 1, src.outerIndexPtr());
        array innerIndices((size_t) src.nonZeros(), src.innerIndexPtr());
223

224
        return matrix_type(
225
226
227
228
229
            std::make_tuple(data, innerIndices, outerIndices),
            std::make_pair(src.rows(), src.cols())
        ).release();
    }

230
    PYBIND11_TYPE_CASTER(Type, _<(Type::Flags & Eigen::RowMajorBit) != 0>("scipy.sparse.csr_matrix[", "scipy.sparse.csc_matrix[")
231
            + npy_format_descriptor<Scalar>::name() + _("]"));
232
233
234
235
236
237
238
239
};

NAMESPACE_END(detail)
NAMESPACE_END(pybind11)

#if defined(_MSC_VER)
#pragma warning(pop)
#endif