Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
c570fb57
Commit
c570fb57
authored
Feb 01, 2019
by
Paul
Browse files
Formatting
parent
389f556d
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
15 additions
and
21 deletions
+15
-21
src/py/migraphx_py.cpp
src/py/migraphx_py.cpp
+15
-21
No files found.
src/py/migraphx_py.cpp
View file @
c570fb57
...
...
@@ -6,12 +6,12 @@
namespace
py
=
pybind11
;
template
<
class
F
>
template
<
class
F
>
struct
skip_half
{
F
f
;
template
<
class
A
>
template
<
class
A
>
void
operator
()(
A
a
)
const
{
f
(
a
);
...
...
@@ -20,34 +20,33 @@ struct skip_half
void
operator
()(
migraphx
::
shape
::
as
<
migraphx
::
half
>
)
const
{
throw
std
::
runtime_error
(
"Half not supported in python yet."
);
}
}
};
template
<
class
F
>
template
<
class
F
>
void
visit_type
(
const
migraphx
::
shape
&
s
,
F
f
)
{
s
.
visit_type
(
skip_half
<
F
>
{
f
});
}
template
<
class
T
>
template
<
class
T
>
py
::
buffer_info
to_buffer_info
(
T
&
x
)
{
migraphx
::
shape
s
=
x
.
get_shape
();
py
::
buffer_info
b
;
visit_type
(
s
,
[
&
](
auto
as
)
{
b
=
py
::
buffer_info
(
x
.
data
(),
as
.
size
(),
py
::
format_descriptor
<
decltype
(
as
())
>::
format
(),
s
.
lens
().
size
(),
s
.
lens
(),
s
.
strides
()
);
b
=
py
::
buffer_info
(
x
.
data
(),
as
.
size
(),
py
::
format_descriptor
<
decltype
(
as
())
>::
format
(),
s
.
lens
().
size
(),
s
.
lens
(),
s
.
strides
());
});
return
b
;
}
PYBIND11_MODULE
(
migraphx
,
m
)
{
PYBIND11_MODULE
(
migraphx
,
m
)
{
py
::
class_
<
migraphx
::
shape
>
(
m
,
"shape"
)
.
def
(
py
::
init
<>
())
.
def
(
"type"
,
&
migraphx
::
shape
::
type
)
...
...
@@ -63,15 +62,11 @@ PYBIND11_MODULE(migraphx, m) {
.
def
(
"scalar"
,
&
migraphx
::
shape
::
scalar
);
py
::
class_
<
migraphx
::
argument
>
(
m
,
"argument"
,
py
::
buffer_protocol
())
.
def_buffer
([](
migraphx
::
argument
&
x
)
->
py
::
buffer_info
{
return
to_buffer_info
(
x
);
});
.
def_buffer
([](
migraphx
::
argument
&
x
)
->
py
::
buffer_info
{
return
to_buffer_info
(
x
);
});
py
::
class_
<
migraphx
::
program
>
(
m
,
"program"
)
.
def
(
"get_parameter_shapes"
,
&
migraphx
::
program
::
get_parameter_shapes
)
.
def
(
"compile"
,
[](
migraphx
::
program
&
p
,
const
migraphx
::
target
&
t
)
{
p
.
compile
(
t
);
})
.
def
(
"compile"
,
[](
migraphx
::
program
&
p
,
const
migraphx
::
target
&
t
)
{
p
.
compile
(
t
);
})
.
def
(
"eval"
,
&
migraphx
::
program
::
eval
);
m
.
def
(
"parse_onnx"
,
&
migraphx
::
parse_onnx
);
...
...
@@ -82,4 +77,3 @@ PYBIND11_MODULE(migraphx, m) {
m
.
attr
(
"__version__"
)
=
"dev"
;
#endif
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment