Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
pybind11
Commits
39ff2d01
Commit
39ff2d01
authored
Aug 03, 2016
by
Wenzel Jakob
Committed by
GitHub
Aug 03, 2016
Browse files
Merge pull request #312 from jagerman/eigen-ref-args
Add support for Eigen::Ref<...> function arguments
parents
7f9603fe
5fd5074a
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
63 additions
and
1 deletion
+63
-1
example/eigen.cpp
example/eigen.cpp
+15
-0
example/eigen.py
example/eigen.py
+8
-0
example/eigen.ref
example/eigen.ref
+6
-0
include/pybind11/eigen.h
include/pybind11/eigen.h
+34
-1
No files found.
example/eigen.cpp
View file @
39ff2d01
...
@@ -9,6 +9,7 @@
...
@@ -9,6 +9,7 @@
#include "example.h"
#include "example.h"
#include <pybind11/eigen.h>
#include <pybind11/eigen.h>
#include <Eigen/Cholesky>
Eigen
::
VectorXf
double_col
(
const
Eigen
::
VectorXf
&
x
)
Eigen
::
VectorXf
double_col
(
const
Eigen
::
VectorXf
&
x
)
{
return
2.0
f
*
x
;
}
{
return
2.0
f
*
x
;
}
...
@@ -19,6 +20,14 @@ Eigen::RowVectorXf double_row(const Eigen::RowVectorXf& x)
...
@@ -19,6 +20,14 @@ Eigen::RowVectorXf double_row(const Eigen::RowVectorXf& x)
Eigen
::
MatrixXf
double_mat_cm
(
const
Eigen
::
MatrixXf
&
x
)
Eigen
::
MatrixXf
double_mat_cm
(
const
Eigen
::
MatrixXf
&
x
)
{
return
2.0
f
*
x
;
}
{
return
2.0
f
*
x
;
}
// Different ways of passing via Eigen::Ref; the first and second are the Eigen-recommended
Eigen
::
MatrixXd
cholesky1
(
Eigen
::
Ref
<
Eigen
::
MatrixXd
>
&
x
)
{
return
x
.
llt
().
matrixL
();
}
Eigen
::
MatrixXd
cholesky2
(
const
Eigen
::
Ref
<
const
Eigen
::
MatrixXd
>
&
x
)
{
return
x
.
llt
().
matrixL
();
}
Eigen
::
MatrixXd
cholesky3
(
const
Eigen
::
Ref
<
Eigen
::
MatrixXd
>
&
x
)
{
return
x
.
llt
().
matrixL
();
}
Eigen
::
MatrixXd
cholesky4
(
Eigen
::
Ref
<
const
Eigen
::
MatrixXd
>
&
x
)
{
return
x
.
llt
().
matrixL
();
}
Eigen
::
MatrixXd
cholesky5
(
Eigen
::
Ref
<
Eigen
::
MatrixXd
>
x
)
{
return
x
.
llt
().
matrixL
();
}
Eigen
::
MatrixXd
cholesky6
(
Eigen
::
Ref
<
const
Eigen
::
MatrixXd
>
x
)
{
return
x
.
llt
().
matrixL
();
}
typedef
Eigen
::
Matrix
<
float
,
Eigen
::
Dynamic
,
Eigen
::
Dynamic
,
Eigen
::
RowMajor
>
MatrixXfRowMajor
;
typedef
Eigen
::
Matrix
<
float
,
Eigen
::
Dynamic
,
Eigen
::
Dynamic
,
Eigen
::
RowMajor
>
MatrixXfRowMajor
;
MatrixXfRowMajor
double_mat_rm
(
const
MatrixXfRowMajor
&
x
)
MatrixXfRowMajor
double_mat_rm
(
const
MatrixXfRowMajor
&
x
)
{
return
2.0
f
*
x
;
}
{
return
2.0
f
*
x
;
}
...
@@ -40,6 +49,12 @@ void init_eigen(py::module &m) {
...
@@ -40,6 +49,12 @@ void init_eigen(py::module &m) {
m
.
def
(
"double_row"
,
&
double_row
);
m
.
def
(
"double_row"
,
&
double_row
);
m
.
def
(
"double_mat_cm"
,
&
double_mat_cm
);
m
.
def
(
"double_mat_cm"
,
&
double_mat_cm
);
m
.
def
(
"double_mat_rm"
,
&
double_mat_rm
);
m
.
def
(
"double_mat_rm"
,
&
double_mat_rm
);
m
.
def
(
"cholesky1"
,
&
cholesky1
);
m
.
def
(
"cholesky2"
,
&
cholesky2
);
m
.
def
(
"cholesky3"
,
&
cholesky3
);
m
.
def
(
"cholesky4"
,
&
cholesky4
);
m
.
def
(
"cholesky5"
,
&
cholesky5
);
m
.
def
(
"cholesky6"
,
&
cholesky6
);
m
.
def
(
"fixed_r"
,
[
mat
]()
->
FixedMatrixR
{
m
.
def
(
"fixed_r"
,
[
mat
]()
->
FixedMatrixR
{
return
FixedMatrixR
(
mat
);
return
FixedMatrixR
(
mat
);
...
...
example/eigen.py
View file @
39ff2d01
...
@@ -11,6 +11,7 @@ from example import sparse_r, sparse_c
...
@@ -11,6 +11,7 @@ from example import sparse_r, sparse_c
from
example
import
sparse_passthrough_r
,
sparse_passthrough_c
from
example
import
sparse_passthrough_r
,
sparse_passthrough_c
from
example
import
double_row
,
double_col
from
example
import
double_row
,
double_col
from
example
import
double_mat_cm
,
double_mat_rm
from
example
import
double_mat_cm
,
double_mat_rm
from
example
import
cholesky1
,
cholesky2
,
cholesky3
,
cholesky4
,
cholesky5
,
cholesky6
try
:
try
:
import
numpy
as
np
import
numpy
as
np
import
scipy
import
scipy
...
@@ -70,3 +71,10 @@ slices = [counting_3d[0, :, :], counting_3d[:, 0, :], counting_3d[:, :, 0]]
...
@@ -70,3 +71,10 @@ slices = [counting_3d[0, :, :], counting_3d[:, 0, :], counting_3d[:, :, 0]]
for
slice_idx
,
ref_mat
in
enumerate
(
slices
):
for
slice_idx
,
ref_mat
in
enumerate
(
slices
):
print
(
"double_mat_cm(%d) = %s"
%
(
slice_idx
,
check_got_vs_ref
(
double_mat_cm
(
ref_mat
),
2.0
*
ref_mat
)))
print
(
"double_mat_cm(%d) = %s"
%
(
slice_idx
,
check_got_vs_ref
(
double_mat_cm
(
ref_mat
),
2.0
*
ref_mat
)))
print
(
"double_mat_rm(%d) = %s"
%
(
slice_idx
,
check_got_vs_ref
(
double_mat_rm
(
ref_mat
),
2.0
*
ref_mat
)))
print
(
"double_mat_rm(%d) = %s"
%
(
slice_idx
,
check_got_vs_ref
(
double_mat_rm
(
ref_mat
),
2.0
*
ref_mat
)))
i
=
1
for
chol
in
[
cholesky1
,
cholesky2
,
cholesky3
,
cholesky4
,
cholesky5
,
cholesky6
]:
mymat
=
chol
(
np
.
array
([[
1
,
2
,
4
],
[
2
,
13
,
23
],
[
4
,
23
,
77
]]))
print
(
"cholesky"
+
str
(
i
)
+
" "
+
(
"OK"
if
(
mymat
==
np
.
array
([[
1
,
0
,
0
],
[
2
,
3
,
0
],
[
4
,
5
,
6
]])).
all
()
else
"NOT OKAY"
))
i
+=
1
example/eigen.ref
View file @
39ff2d01
...
@@ -27,3 +27,9 @@ double_mat_cm(1) = OK
...
@@ -27,3 +27,9 @@ double_mat_cm(1) = OK
double_mat_rm(1) = OK
double_mat_rm(1) = OK
double_mat_cm(2) = OK
double_mat_cm(2) = OK
double_mat_rm(2) = OK
double_mat_rm(2) = OK
cholesky1 OK
cholesky2 OK
cholesky3 OK
cholesky4 OK
cholesky5 OK
cholesky6 OK
include/pybind11/eigen.h
View file @
39ff2d01
...
@@ -40,6 +40,19 @@ public:
...
@@ -40,6 +40,19 @@ public:
static
constexpr
bool
value
=
decltype
(
test
(
std
::
declval
<
T
>
()))
::
value
;
static
constexpr
bool
value
=
decltype
(
test
(
std
::
declval
<
T
>
()))
::
value
;
};
};
// Eigen::Ref<Derived> satisfies is_eigen_dense, but isn't constructible, which means we can't load
// it (since there is no reference!), but we can cast from it.
template
<
typename
T
>
class
is_eigen_ref
{
private:
template
<
typename
Derived
>
static
typename
std
::
enable_if
<
std
::
is_same
<
typename
std
::
remove_const
<
T
>::
type
,
Eigen
::
Ref
<
Derived
>>::
value
,
Derived
>::
type
test
(
const
Eigen
::
Ref
<
Derived
>
&
);
static
void
test
(...);
public:
typedef
decltype
(
test
(
std
::
declval
<
T
>
()))
Derived
;
static
constexpr
bool
value
=
!
std
::
is_void
<
Derived
>::
value
;
};
template
<
typename
T
>
class
is_eigen_sparse
{
template
<
typename
T
>
class
is_eigen_sparse
{
private:
private:
template
<
typename
Derived
>
static
std
::
true_type
test
(
const
Eigen
::
SparseMatrixBase
<
Derived
>
&
);
template
<
typename
Derived
>
static
std
::
true_type
test
(
const
Eigen
::
SparseMatrixBase
<
Derived
>
&
);
...
@@ -49,7 +62,7 @@ public:
...
@@ -49,7 +62,7 @@ public:
};
};
template
<
typename
Type
>
template
<
typename
Type
>
struct
type_caster
<
Type
,
typename
std
::
enable_if
<
is_eigen_dense
<
Type
>::
value
>::
type
>
{
struct
type_caster
<
Type
,
typename
std
::
enable_if
<
is_eigen_dense
<
Type
>::
value
&&
!
is_eigen_ref
<
Type
>::
value
>::
type
>
{
typedef
typename
Type
::
Scalar
Scalar
;
typedef
typename
Type
::
Scalar
Scalar
;
static
constexpr
bool
rowMajor
=
Type
::
Flags
&
Eigen
::
RowMajorBit
;
static
constexpr
bool
rowMajor
=
Type
::
Flags
&
Eigen
::
RowMajorBit
;
static
constexpr
bool
isVector
=
Type
::
IsVectorAtCompileTime
;
static
constexpr
bool
isVector
=
Type
::
IsVectorAtCompileTime
;
...
@@ -149,6 +162,26 @@ protected:
...
@@ -149,6 +162,26 @@ protected:
static
PYBIND11_DESCR
cols
()
{
return
_
<
T
::
ColsAtCompileTime
>
();
}
static
PYBIND11_DESCR
cols
()
{
return
_
<
T
::
ColsAtCompileTime
>
();
}
};
};
template
<
typename
Type
>
struct
type_caster
<
Type
,
typename
std
::
enable_if
<
is_eigen_dense
<
Type
>::
value
&&
is_eigen_ref
<
Type
>::
value
>::
type
>
{
private:
using
Derived
=
typename
std
::
remove_const
<
typename
is_eigen_ref
<
Type
>::
Derived
>::
type
;
using
DerivedCaster
=
type_caster
<
Derived
>
;
DerivedCaster
derived_caster
;
protected:
std
::
unique_ptr
<
Type
>
value
;
public:
bool
load
(
handle
src
,
bool
convert
)
{
if
(
derived_caster
.
load
(
src
,
convert
))
{
value
.
reset
(
new
Type
(
derived_caster
.
operator
Derived
&
()));
return
true
;
}
return
false
;
}
static
handle
cast
(
const
Type
&
src
,
return_value_policy
policy
,
handle
parent
)
{
return
DerivedCaster
::
cast
(
src
,
policy
,
parent
);
}
static
handle
cast
(
const
Type
*
src
,
return_value_policy
policy
,
handle
parent
)
{
return
DerivedCaster
::
cast
(
*
src
,
policy
,
parent
);
}
static
PYBIND11_DESCR
name
()
{
return
DerivedCaster
::
name
();
}
operator
Type
*
()
{
return
value
.
get
();
}
operator
Type
&
()
{
if
(
!
value
)
pybind11_fail
(
"Eigen::Ref<...> value not loaded"
);
return
*
value
;
}
template
<
typename
_T
>
using
cast_op_type
=
pybind11
::
detail
::
cast_op_type
<
_T
>
;
};
template
<
typename
Type
>
template
<
typename
Type
>
struct
type_caster
<
Type
,
typename
std
::
enable_if
<
is_eigen_sparse
<
Type
>::
value
>::
type
>
{
struct
type_caster
<
Type
,
typename
std
::
enable_if
<
is_eigen_sparse
<
Type
>::
value
>::
type
>
{
typedef
typename
Type
::
Scalar
Scalar
;
typedef
typename
Type
::
Scalar
Scalar
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment