Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
5e5ed37a
Unverified
Commit
5e5ed37a
authored
May 10, 2022
by
Umang Yadav
Committed by
GitHub
May 10, 2022
Browse files
Expose `add_literal` in C and Python API (#1173)
Expose add_literal method in C/C++ api
parent
ddbbe54b
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
87 additions
and
17 deletions
+87
-17
doc/src/reference/py.rst
doc/src/reference/py.rst
+7
-0
src/api/api.cpp
src/api/api.cpp
+16
-0
src/api/include/migraphx/migraphx.h
src/api/include/migraphx/migraphx.h
+5
-0
src/api/include/migraphx/migraphx.hpp
src/api/include/migraphx/migraphx.hpp
+9
-0
src/api/migraphx.py
src/api/migraphx.py
+3
-0
src/py/migraphx_py.cpp
src/py/migraphx_py.cpp
+8
-0
test/api/test_module_construct.cpp
test/api/test_module_construct.cpp
+5
-7
test/py/test_module_construct.py
test/py/test_module_construct.py
+12
-10
test/py/test_numpy.py
test/py/test_numpy.py
+22
-0
No files found.
doc/src/reference/py.rst
View file @
5e5ed37a
...
...
@@ -146,6 +146,13 @@ module
:param list[module] mod_args: optional list of module arguments to the operator.
:rtype instruction
.. py:method:: add_literal(data)
Adds constant or literal data of provided shape into the module from python buffer which includes numpy array.
:param py::buffer data: Python buffer or numpy array
:rtype instruction
.. py:method:: add_parameter(name, shape)
Adds a parameter to the module with provided name and shape.
...
...
src/api/api.cpp
View file @
5e5ed37a
...
...
@@ -1072,6 +1072,22 @@ migraphx_module_add_instruction_with_mod_args(migraphx_instruction_t* out,
return
api_error_result
;
}
extern
"C"
migraphx_status
migraphx_module_add_literal
(
migraphx_instruction_t
*
out
,
migraphx_module_t
module
,
const_migraphx_shape_t
shape
,
const
char
*
buffer
)
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
if
(
module
==
nullptr
)
MIGRAPHX_THROW
(
migraphx_status_bad_param
,
"Bad parameter module: Null pointer"
);
if
(
shape
==
nullptr
)
MIGRAPHX_THROW
(
migraphx_status_bad_param
,
"Bad parameter shape: Null pointer"
);
*
out
=
allocate
<
migraphx_instruction_t
>
(
(
module
->
object
).
add_literal
((
shape
->
object
),
(
buffer
)));
});
return
api_error_result
;
}
extern
"C"
migraphx_status
migraphx_module_add_parameter
(
migraphx_instruction_t
*
out
,
migraphx_module_t
module
,
const
char
*
name
,
...
...
src/api/include/migraphx/migraphx.h
View file @
5e5ed37a
...
...
@@ -258,6 +258,11 @@ migraphx_status migraphx_module_add_instruction_with_mod_args(migraphx_instructi
migraphx_instructions_t
args
,
migraphx_modules_t
module_refs
);
migraphx_status
migraphx_module_add_literal
(
migraphx_instruction_t
*
out
,
migraphx_module_t
module
,
const_migraphx_shape_t
shape
,
const
char
*
buffer
);
migraphx_status
migraphx_module_add_parameter
(
migraphx_instruction_t
*
out
,
migraphx_module_t
module
,
const
char
*
name
,
...
...
src/api/include/migraphx/migraphx.hpp
View file @
5e5ed37a
...
...
@@ -762,6 +762,15 @@ struct module
return
instruction
(
op_ins
,
own
{});
}
template
<
typename
T
>
instruction
add_literal
(
const
migraphx
::
shape
&
s
,
T
*
buffer
)
{
migraphx_instruction_t
literal_ins
;
const
auto
*
buffer_ptr
=
reinterpret_cast
<
const
char
*>
(
buffer
);
call
(
&
migraphx_module_add_literal
,
&
literal_ins
,
mm
.
get
(),
s
.
get_handle_ptr
(),
buffer_ptr
);
return
instruction
(
literal_ins
,
own
{});
}
instruction
add_parameter
(
const
std
::
string
&
name
,
shape
s
)
{
migraphx_instruction_t
param_ins
;
...
...
src/api/migraphx.py
View file @
5e5ed37a
...
...
@@ -212,6 +212,9 @@ def module(h):
module_refs
=
'std::vector<migraphx::module*>'
),
fname
=
'add_instruction'
,
returns
=
'migraphx::instruction_ref'
)
h
.
method
(
'add_literal'
,
api
.
params
(
shape
=
'const migraphx::shape&'
,
buffer
=
'const char*'
),
returns
=
'migraphx::instruction_ref'
)
h
.
method
(
'add_parameter'
,
api
.
params
(
name
=
'const char*'
,
shape
=
'const migraphx::shape&'
),
returns
=
'migraphx::instruction_ref'
)
...
...
src/py/migraphx_py.cpp
View file @
5e5ed37a
...
...
@@ -273,6 +273,14 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
py
::
arg
(
"op"
),
py
::
arg
(
"args"
),
py
::
arg
(
"mod_args"
)
=
std
::
vector
<
migraphx
::
module
*>
{})
.
def
(
"add_literal"
,
[](
migraphx
::
module
&
mm
,
py
::
buffer
data
)
{
py
::
buffer_info
info
=
data
.
request
();
auto
literal_shape
=
to_shape
(
info
);
return
mm
.
add_literal
(
literal_shape
,
reinterpret_cast
<
char
*>
(
info
.
ptr
));
},
py
::
arg
(
"data"
))
.
def
(
"add_parameter"
,
[](
migraphx
::
module
&
mm
,
const
std
::
string
&
name
,
const
migraphx
::
shape
shape
)
{
...
...
test/api/test_module_construct.cpp
View file @
5e5ed37a
...
...
@@ -3,23 +3,21 @@
#include <migraphx/migraphx.hpp>
#include "test.hpp"
TEST_CASE
(
add_
op
)
TEST_CASE
(
add_
literals
)
{
migraphx
::
program
p
;
migraphx
::
module
m
=
p
.
get_main_module
();
migraphx
::
shape
param_shape
{
migraphx_shape_float_type
,
{
3
,
3
}};
auto
x
=
m
.
add_parameter
(
"x"
,
param_shape
);
auto
y
=
m
.
add_parameter
(
"y"
,
param_shape
);
std
::
vector
<
float
>
x_values
(
9
,
1
);
auto
x
=
m
.
add_literal
(
param_shape
,
x_values
.
data
());
std
::
vector
<
float
>
y_values
(
9
,
-
1
);
auto
y
=
m
.
add_literal
(
param_shape
,
y_values
.
data
());
auto
add_op
=
migraphx
::
operation
(
"add"
);
auto
r
=
m
.
add_instruction
(
add_op
,
{
x
,
y
});
m
.
add_return
({
r
});
// run on ref target
p
.
compile
(
migraphx
::
target
(
"ref"
));
migraphx
::
program_parameters
pp
;
std
::
vector
<
float
>
x_data
(
9
,
1
);
std
::
vector
<
float
>
y_data
(
9
,
-
1
);
pp
.
add
(
"x"
,
migraphx
::
argument
(
param_shape
,
x_data
.
data
()));
pp
.
add
(
"y"
,
migraphx
::
argument
(
param_shape
,
y_data
.
data
()));
auto
outputs
=
p
.
eval
(
pp
);
auto
output
=
outputs
[
0
];
std
::
vector
<
float
>
expected
(
9
,
0
);
...
...
test/py/test_module_construct.py
View file @
5e5ed37a
import
migraphx
import
migraphx
,
array
,
sys
def
create_buffer
(
t
,
data
,
shape
):
a
=
array
.
array
(
t
,
data
)
m
=
memoryview
(
a
.
tobytes
())
return
m
.
cast
(
t
,
shape
)
def
test_add_op
():
p
=
migraphx
.
program
()
mm
=
p
.
get_main_module
()
param_shape
=
migraphx
.
shape
(
lens
=
[
3
,
3
],
type
=
"float"
)
x
=
mm
.
add_parameter
(
"x"
,
param_shape
)
y
=
mm
.
add_parameter
(
"y"
,
param_shape
)
x
=
mm
.
add_literal
(
create_buffer
(
'f'
,
[
1.0
]
*
9
,
(
3
,
3
)))
y
=
mm
.
add_literal
(
create_buffer
(
'f'
,
[
2.0
]
*
9
,
(
3
,
3
)))
add_op
=
mm
.
add_instruction
(
migraphx
.
op
(
"add"
),
[
x
,
y
])
mm
.
add_return
([
add_op
])
p
.
compile
(
migraphx
.
get_target
(
"ref"
))
params
=
{}
params
[
"x"
]
=
migraphx
.
generate_argument
(
param_shape
)
params
[
"y"
]
=
migraphx
.
generate_argument
(
param_shape
)
output
=
p
.
run
(
params
)[
-
1
].
tolist
()
assert
output
==
[
a
+
b
for
a
,
b
in
zip
(
params
[
"x"
].
tolist
(),
params
[
"y"
].
tolist
())
]
assert
output
==
list
([
3.0
]
*
9
)
def
test_if_then_else
():
...
...
@@ -60,5 +61,6 @@ def test_if_then_else():
if
__name__
==
"__main__"
:
if
sys
.
version_info
>=
(
3
,
0
):
test_add_op
()
test_if_then_else
()
test/py/test_numpy.py
0 → 100644
View file @
5e5ed37a
import
migraphx
,
sys
try
:
import
numpy
as
np
except
:
sys
.
exit
()
def
test_add_op
():
p
=
migraphx
.
program
()
mm
=
p
.
get_main_module
()
x
=
mm
.
add_literal
(
np
.
ones
((
3
,
3
),
dtype
=
'float32'
))
y
=
mm
.
add_literal
(
2
*
np
.
ones
((
3
,
3
),
dtype
=
'float32'
))
add_op
=
mm
.
add_instruction
(
migraphx
.
op
(
"add"
),
[
x
,
y
])
mm
.
add_return
([
add_op
])
p
.
compile
(
migraphx
.
get_target
(
"ref"
))
params
=
{}
output
=
p
.
run
(
params
)[
-
1
].
tolist
()
assert
output
==
list
(
3
*
np
.
ones
((
9
),
dtype
=
'float32'
))
if
__name__
==
"__main__"
:
test_add_op
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment