Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
pybind11
Commits
cc8ff165
Commit
cc8ff165
authored
Oct 31, 2016
by
Ivan Smirnov
Browse files
Move register_dtype() outside of the template
(avoid code bloat if possible)
parent
f95fda0e
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
66 additions
and
53 deletions
+66
-53
include/pybind11/numpy.h
include/pybind11/numpy.h
+66
-53
No files found.
include/pybind11/numpy.h
View file @
cc8ff165
...
@@ -81,14 +81,18 @@ struct numpy_type_info {
...
@@ -81,14 +81,18 @@ struct numpy_type_info {
struct
numpy_internals
{
struct
numpy_internals
{
std
::
unordered_map
<
std
::
type_index
,
numpy_type_info
>
registered_dtypes
;
std
::
unordered_map
<
std
::
type_index
,
numpy_type_info
>
registered_dtypes
;
template
<
typename
T
>
numpy_type_info
*
get_type_info
(
bool
throw_if_missing
=
true
)
{
numpy_type_info
*
get_type_info
(
const
std
::
type_info
&
tinfo
,
bool
throw_if_missing
=
true
)
{
auto
it
=
registered_dtypes
.
find
(
std
::
type_index
(
t
ypeid
(
T
)
));
auto
it
=
registered_dtypes
.
find
(
std
::
type_index
(
t
info
));
if
(
it
!=
registered_dtypes
.
end
())
if
(
it
!=
registered_dtypes
.
end
())
return
&
(
it
->
second
);
return
&
(
it
->
second
);
if
(
throw_if_missing
)
if
(
throw_if_missing
)
pybind11_fail
(
std
::
string
(
"NumPy type info missing for "
)
+
t
ypeid
(
T
)
.
name
());
pybind11_fail
(
std
::
string
(
"NumPy type info missing for "
)
+
t
info
.
name
());
return
nullptr
;
return
nullptr
;
}
}
template
<
typename
T
>
numpy_type_info
*
get_type_info
(
bool
throw_if_missing
=
true
)
{
return
get_type_info
(
typeid
(
typename
std
::
remove_cv
<
T
>::
type
),
throw_if_missing
);
}
};
};
inline
PYBIND11_NOINLINE
void
load_numpy_internals
(
numpy_internals
*
&
ptr
)
{
inline
PYBIND11_NOINLINE
void
load_numpy_internals
(
numpy_internals
*
&
ptr
)
{
...
@@ -686,6 +690,62 @@ struct field_descriptor {
...
@@ -686,6 +690,62 @@ struct field_descriptor {
dtype
descr
;
dtype
descr
;
};
};
inline
PYBIND11_NOINLINE
void
register_structured_dtype
(
const
std
::
initializer_list
<
field_descriptor
>&
fields
,
const
std
::
type_info
&
tinfo
,
size_t
itemsize
,
bool
(
*
direct_converter
)(
PyObject
*
,
void
*&
))
{
auto
&
numpy_internals
=
get_numpy_internals
();
if
(
numpy_internals
.
get_type_info
(
tinfo
,
false
))
pybind11_fail
(
"NumPy: dtype is already registered"
);
list
names
,
formats
,
offsets
;
for
(
auto
field
:
fields
)
{
if
(
!
field
.
descr
)
pybind11_fail
(
std
::
string
(
"NumPy: unsupported field dtype: `"
)
+
field
.
name
+
"` @ "
+
tinfo
.
name
());
names
.
append
(
PYBIND11_STR_TYPE
(
field
.
name
));
formats
.
append
(
field
.
descr
);
offsets
.
append
(
pybind11
::
int_
(
field
.
offset
));
}
auto
dtype_ptr
=
pybind11
::
dtype
(
names
,
formats
,
offsets
,
itemsize
).
release
().
ptr
();
// There is an existing bug in NumPy (as of v1.11): trailing bytes are
// not encoded explicitly into the format string. This will supposedly
// get fixed in v1.12; for further details, see these:
// - https://github.com/numpy/numpy/issues/7797
// - https://github.com/numpy/numpy/pull/7798
// Because of this, we won't use numpy's logic to generate buffer format
// strings and will just do it ourselves.
std
::
vector
<
field_descriptor
>
ordered_fields
(
fields
);
std
::
sort
(
ordered_fields
.
begin
(),
ordered_fields
.
end
(),
[](
const
field_descriptor
&
a
,
const
field_descriptor
&
b
)
{
return
a
.
offset
<
b
.
offset
;
});
size_t
offset
=
0
;
std
::
ostringstream
oss
;
oss
<<
"T{"
;
for
(
auto
&
field
:
ordered_fields
)
{
if
(
field
.
offset
>
offset
)
oss
<<
(
field
.
offset
-
offset
)
<<
'x'
;
// note that '=' is required to cover the case of unaligned fields
oss
<<
'='
<<
field
.
format
<<
':'
<<
field
.
name
<<
':'
;
offset
=
field
.
offset
+
field
.
size
;
}
if
(
itemsize
>
offset
)
oss
<<
(
itemsize
-
offset
)
<<
'x'
;
oss
<<
'}'
;
auto
format_str
=
oss
.
str
();
// Sanity check: verify that NumPy properly parses our buffer format string
auto
&
api
=
npy_api
::
get
();
auto
arr
=
array
(
buffer_info
(
nullptr
,
itemsize
,
format_str
,
1
));
if
(
!
api
.
PyArray_EquivTypes_
(
dtype_ptr
,
arr
.
dtype
().
ptr
()))
pybind11_fail
(
"NumPy: invalid buffer descriptor!"
);
auto
tindex
=
std
::
type_index
(
tinfo
);
numpy_internals
.
registered_dtypes
[
tindex
]
=
{
dtype_ptr
,
format_str
};
get_internals
().
direct_conversions
[
tindex
].
push_back
(
direct_converter
);
}
template
<
typename
T
>
template
<
typename
T
>
struct
npy_format_descriptor
<
T
,
enable_if_t
<
is_pod_struct
<
T
>::
value
>>
{
struct
npy_format_descriptor
<
T
,
enable_if_t
<
is_pod_struct
<
T
>::
value
>>
{
static
PYBIND11_DESCR
name
()
{
return
_
(
"struct"
);
}
static
PYBIND11_DESCR
name
()
{
return
_
(
"struct"
);
}
...
@@ -699,56 +759,9 @@ struct npy_format_descriptor<T, enable_if_t<is_pod_struct<T>::value>> {
...
@@ -699,56 +759,9 @@ struct npy_format_descriptor<T, enable_if_t<is_pod_struct<T>::value>> {
return
format_str
;
return
format_str
;
}
}
static
void
register_dtype
(
std
::
initializer_list
<
field_descriptor
>
fields
)
{
static
void
register_dtype
(
const
std
::
initializer_list
<
field_descriptor
>&
fields
)
{
auto
&
numpy_internals
=
get_numpy_internals
();
register_structured_dtype
(
fields
,
typeid
(
typename
std
::
remove_cv
<
T
>::
type
),
if
(
numpy_internals
.
get_type_info
<
T
>
(
false
))
sizeof
(
T
),
&
direct_converter
);
pybind11_fail
(
"NumPy: dtype is already registered"
);
list
names
,
formats
,
offsets
;
for
(
auto
field
:
fields
)
{
if
(
!
field
.
descr
)
pybind11_fail
(
std
::
string
(
"NumPy: unsupported field dtype: `"
)
+
field
.
name
+
"` @ "
+
typeid
(
T
).
name
());
names
.
append
(
PYBIND11_STR_TYPE
(
field
.
name
));
formats
.
append
(
field
.
descr
);
offsets
.
append
(
pybind11
::
int_
(
field
.
offset
));
}
auto
dtype_ptr
=
pybind11
::
dtype
(
names
,
formats
,
offsets
,
sizeof
(
T
)).
release
().
ptr
();
// There is an existing bug in NumPy (as of v1.11): trailing bytes are
// not encoded explicitly into the format string. This will supposedly
// get fixed in v1.12; for further details, see these:
// - https://github.com/numpy/numpy/issues/7797
// - https://github.com/numpy/numpy/pull/7798
// Because of this, we won't use numpy's logic to generate buffer format
// strings and will just do it ourselves.
std
::
vector
<
field_descriptor
>
ordered_fields
(
fields
);
std
::
sort
(
ordered_fields
.
begin
(),
ordered_fields
.
end
(),
[](
const
field_descriptor
&
a
,
const
field_descriptor
&
b
)
{
return
a
.
offset
<
b
.
offset
;
});
size_t
offset
=
0
;
std
::
ostringstream
oss
;
oss
<<
"T{"
;
for
(
auto
&
field
:
ordered_fields
)
{
if
(
field
.
offset
>
offset
)
oss
<<
(
field
.
offset
-
offset
)
<<
'x'
;
// note that '=' is required to cover the case of unaligned fields
oss
<<
'='
<<
field
.
format
<<
':'
<<
field
.
name
<<
':'
;
offset
=
field
.
offset
+
field
.
size
;
}
if
(
sizeof
(
T
)
>
offset
)
oss
<<
(
sizeof
(
T
)
-
offset
)
<<
'x'
;
oss
<<
'}'
;
auto
format_str
=
oss
.
str
();
// Sanity check: verify that NumPy properly parses our buffer format string
auto
&
api
=
npy_api
::
get
();
auto
arr
=
array
(
buffer_info
(
nullptr
,
sizeof
(
T
),
format_str
,
1
));
if
(
!
api
.
PyArray_EquivTypes_
(
dtype_ptr
,
arr
.
dtype
().
ptr
()))
pybind11_fail
(
"NumPy: invalid buffer descriptor!"
);
auto
tindex
=
std
::
type_index
(
typeid
(
T
));
numpy_internals
.
registered_dtypes
[
tindex
]
=
{
dtype_ptr
,
format_str
};
get_internals
().
direct_conversions
[
tindex
].
push_back
(
direct_converter
);
}
}
private:
private:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment