Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
5e5ed37a
Unverified
Commit
5e5ed37a
authored
May 10, 2022
by
Umang Yadav
Committed by
GitHub
May 10, 2022
Browse files
Expose `add_literal` in C and Python API (#1173)
Expose add_literal method in C/C++ api
parent
ddbbe54b
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
87 additions
and
17 deletions
+87
-17
doc/src/reference/py.rst
doc/src/reference/py.rst
+7
-0
src/api/api.cpp
src/api/api.cpp
+16
-0
src/api/include/migraphx/migraphx.h
src/api/include/migraphx/migraphx.h
+5
-0
src/api/include/migraphx/migraphx.hpp
src/api/include/migraphx/migraphx.hpp
+9
-0
src/api/migraphx.py
src/api/migraphx.py
+3
-0
src/py/migraphx_py.cpp
src/py/migraphx_py.cpp
+8
-0
test/api/test_module_construct.cpp
test/api/test_module_construct.cpp
+5
-7
test/py/test_module_construct.py
test/py/test_module_construct.py
+12
-10
test/py/test_numpy.py
test/py/test_numpy.py
+22
-0
No files found.
doc/src/reference/py.rst
View file @
5e5ed37a
...
@@ -146,6 +146,13 @@ module
...
@@ -146,6 +146,13 @@ module
:param list[module] mod_args: optional list of module arguments to the operator.
:param list[module] mod_args: optional list of module arguments to the operator.
:rtype instruction
:rtype instruction
.. py:method:: add_literal(data)
Adds constant or literal data of provided shape into the module from python buffer which includes numpy array.
:param py::buffer data: Python buffer or numpy array
:rtype instruction
.. py:method:: add_parameter(name, shape)
.. py:method:: add_parameter(name, shape)
Adds a parameter to the module with provided name and shape.
Adds a parameter to the module with provided name and shape.
...
...
src/api/api.cpp
View file @
5e5ed37a
...
@@ -1072,6 +1072,22 @@ migraphx_module_add_instruction_with_mod_args(migraphx_instruction_t* out,
...
@@ -1072,6 +1072,22 @@ migraphx_module_add_instruction_with_mod_args(migraphx_instruction_t* out,
return
api_error_result
;
return
api_error_result
;
}
}
extern
"C"
migraphx_status
migraphx_module_add_literal
(
migraphx_instruction_t
*
out
,
migraphx_module_t
module
,
const_migraphx_shape_t
shape
,
const
char
*
buffer
)
{
auto
api_error_result
=
migraphx
::
try_
([
&
]
{
if
(
module
==
nullptr
)
MIGRAPHX_THROW
(
migraphx_status_bad_param
,
"Bad parameter module: Null pointer"
);
if
(
shape
==
nullptr
)
MIGRAPHX_THROW
(
migraphx_status_bad_param
,
"Bad parameter shape: Null pointer"
);
*
out
=
allocate
<
migraphx_instruction_t
>
(
(
module
->
object
).
add_literal
((
shape
->
object
),
(
buffer
)));
});
return
api_error_result
;
}
extern
"C"
migraphx_status
migraphx_module_add_parameter
(
migraphx_instruction_t
*
out
,
extern
"C"
migraphx_status
migraphx_module_add_parameter
(
migraphx_instruction_t
*
out
,
migraphx_module_t
module
,
migraphx_module_t
module
,
const
char
*
name
,
const
char
*
name
,
...
...
src/api/include/migraphx/migraphx.h
View file @
5e5ed37a
...
@@ -258,6 +258,11 @@ migraphx_status migraphx_module_add_instruction_with_mod_args(migraphx_instructi
...
@@ -258,6 +258,11 @@ migraphx_status migraphx_module_add_instruction_with_mod_args(migraphx_instructi
migraphx_instructions_t
args
,
migraphx_instructions_t
args
,
migraphx_modules_t
module_refs
);
migraphx_modules_t
module_refs
);
migraphx_status
migraphx_module_add_literal
(
migraphx_instruction_t
*
out
,
migraphx_module_t
module
,
const_migraphx_shape_t
shape
,
const
char
*
buffer
);
migraphx_status
migraphx_module_add_parameter
(
migraphx_instruction_t
*
out
,
migraphx_status
migraphx_module_add_parameter
(
migraphx_instruction_t
*
out
,
migraphx_module_t
module
,
migraphx_module_t
module
,
const
char
*
name
,
const
char
*
name
,
...
...
src/api/include/migraphx/migraphx.hpp
View file @
5e5ed37a
...
@@ -762,6 +762,15 @@ struct module
...
@@ -762,6 +762,15 @@ struct module
return
instruction
(
op_ins
,
own
{});
return
instruction
(
op_ins
,
own
{});
}
}
template
<
typename
T
>
instruction
add_literal
(
const
migraphx
::
shape
&
s
,
T
*
buffer
)
{
migraphx_instruction_t
literal_ins
;
const
auto
*
buffer_ptr
=
reinterpret_cast
<
const
char
*>
(
buffer
);
call
(
&
migraphx_module_add_literal
,
&
literal_ins
,
mm
.
get
(),
s
.
get_handle_ptr
(),
buffer_ptr
);
return
instruction
(
literal_ins
,
own
{});
}
instruction
add_parameter
(
const
std
::
string
&
name
,
shape
s
)
instruction
add_parameter
(
const
std
::
string
&
name
,
shape
s
)
{
{
migraphx_instruction_t
param_ins
;
migraphx_instruction_t
param_ins
;
...
...
src/api/migraphx.py
View file @
5e5ed37a
...
@@ -212,6 +212,9 @@ def module(h):
...
@@ -212,6 +212,9 @@ def module(h):
module_refs
=
'std::vector<migraphx::module*>'
),
module_refs
=
'std::vector<migraphx::module*>'
),
fname
=
'add_instruction'
,
fname
=
'add_instruction'
,
returns
=
'migraphx::instruction_ref'
)
returns
=
'migraphx::instruction_ref'
)
h
.
method
(
'add_literal'
,
api
.
params
(
shape
=
'const migraphx::shape&'
,
buffer
=
'const char*'
),
returns
=
'migraphx::instruction_ref'
)
h
.
method
(
'add_parameter'
,
h
.
method
(
'add_parameter'
,
api
.
params
(
name
=
'const char*'
,
shape
=
'const migraphx::shape&'
),
api
.
params
(
name
=
'const char*'
,
shape
=
'const migraphx::shape&'
),
returns
=
'migraphx::instruction_ref'
)
returns
=
'migraphx::instruction_ref'
)
...
...
src/py/migraphx_py.cpp
View file @
5e5ed37a
...
@@ -273,6 +273,14 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
...
@@ -273,6 +273,14 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
py
::
arg
(
"op"
),
py
::
arg
(
"op"
),
py
::
arg
(
"args"
),
py
::
arg
(
"args"
),
py
::
arg
(
"mod_args"
)
=
std
::
vector
<
migraphx
::
module
*>
{})
py
::
arg
(
"mod_args"
)
=
std
::
vector
<
migraphx
::
module
*>
{})
.
def
(
"add_literal"
,
[](
migraphx
::
module
&
mm
,
py
::
buffer
data
)
{
py
::
buffer_info
info
=
data
.
request
();
auto
literal_shape
=
to_shape
(
info
);
return
mm
.
add_literal
(
literal_shape
,
reinterpret_cast
<
char
*>
(
info
.
ptr
));
},
py
::
arg
(
"data"
))
.
def
(
.
def
(
"add_parameter"
,
"add_parameter"
,
[](
migraphx
::
module
&
mm
,
const
std
::
string
&
name
,
const
migraphx
::
shape
shape
)
{
[](
migraphx
::
module
&
mm
,
const
std
::
string
&
name
,
const
migraphx
::
shape
shape
)
{
...
...
test/api/test_module_construct.cpp
View file @
5e5ed37a
...
@@ -3,23 +3,21 @@
...
@@ -3,23 +3,21 @@
#include <migraphx/migraphx.hpp>
#include <migraphx/migraphx.hpp>
#include "test.hpp"
#include "test.hpp"
TEST_CASE
(
add_
op
)
TEST_CASE
(
add_
literals
)
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
migraphx
::
module
m
=
p
.
get_main_module
();
migraphx
::
module
m
=
p
.
get_main_module
();
migraphx
::
shape
param_shape
{
migraphx_shape_float_type
,
{
3
,
3
}};
migraphx
::
shape
param_shape
{
migraphx_shape_float_type
,
{
3
,
3
}};
auto
x
=
m
.
add_parameter
(
"x"
,
param_shape
);
std
::
vector
<
float
>
x_values
(
9
,
1
);
auto
y
=
m
.
add_parameter
(
"y"
,
param_shape
);
auto
x
=
m
.
add_literal
(
param_shape
,
x_values
.
data
());
std
::
vector
<
float
>
y_values
(
9
,
-
1
);
auto
y
=
m
.
add_literal
(
param_shape
,
y_values
.
data
());
auto
add_op
=
migraphx
::
operation
(
"add"
);
auto
add_op
=
migraphx
::
operation
(
"add"
);
auto
r
=
m
.
add_instruction
(
add_op
,
{
x
,
y
});
auto
r
=
m
.
add_instruction
(
add_op
,
{
x
,
y
});
m
.
add_return
({
r
});
m
.
add_return
({
r
});
// run on ref target
// run on ref target
p
.
compile
(
migraphx
::
target
(
"ref"
));
p
.
compile
(
migraphx
::
target
(
"ref"
));
migraphx
::
program_parameters
pp
;
migraphx
::
program_parameters
pp
;
std
::
vector
<
float
>
x_data
(
9
,
1
);
std
::
vector
<
float
>
y_data
(
9
,
-
1
);
pp
.
add
(
"x"
,
migraphx
::
argument
(
param_shape
,
x_data
.
data
()));
pp
.
add
(
"y"
,
migraphx
::
argument
(
param_shape
,
y_data
.
data
()));
auto
outputs
=
p
.
eval
(
pp
);
auto
outputs
=
p
.
eval
(
pp
);
auto
output
=
outputs
[
0
];
auto
output
=
outputs
[
0
];
std
::
vector
<
float
>
expected
(
9
,
0
);
std
::
vector
<
float
>
expected
(
9
,
0
);
...
...
test/py/test_module_construct.py
View file @
5e5ed37a
import
migraphx
import
migraphx
,
array
,
sys
def
create_buffer
(
t
,
data
,
shape
):
a
=
array
.
array
(
t
,
data
)
m
=
memoryview
(
a
.
tobytes
())
return
m
.
cast
(
t
,
shape
)
def
test_add_op
():
def
test_add_op
():
p
=
migraphx
.
program
()
p
=
migraphx
.
program
()
mm
=
p
.
get_main_module
()
mm
=
p
.
get_main_module
()
param_shape
=
migraphx
.
shape
(
lens
=
[
3
,
3
],
type
=
"float"
)
x
=
mm
.
add_literal
(
create_buffer
(
'f'
,
[
1.0
]
*
9
,
(
3
,
3
)))
x
=
mm
.
add_parameter
(
"x"
,
param_shape
)
y
=
mm
.
add_literal
(
create_buffer
(
'f'
,
[
2.0
]
*
9
,
(
3
,
3
)))
y
=
mm
.
add_parameter
(
"y"
,
param_shape
)
add_op
=
mm
.
add_instruction
(
migraphx
.
op
(
"add"
),
[
x
,
y
])
add_op
=
mm
.
add_instruction
(
migraphx
.
op
(
"add"
),
[
x
,
y
])
mm
.
add_return
([
add_op
])
mm
.
add_return
([
add_op
])
p
.
compile
(
migraphx
.
get_target
(
"ref"
))
p
.
compile
(
migraphx
.
get_target
(
"ref"
))
params
=
{}
params
=
{}
params
[
"x"
]
=
migraphx
.
generate_argument
(
param_shape
)
params
[
"y"
]
=
migraphx
.
generate_argument
(
param_shape
)
output
=
p
.
run
(
params
)[
-
1
].
tolist
()
output
=
p
.
run
(
params
)[
-
1
].
tolist
()
assert
output
==
[
assert
output
==
list
([
3.0
]
*
9
)
a
+
b
for
a
,
b
in
zip
(
params
[
"x"
].
tolist
(),
params
[
"y"
].
tolist
())
]
def
test_if_then_else
():
def
test_if_then_else
():
...
@@ -60,5 +61,6 @@ def test_if_then_else():
...
@@ -60,5 +61,6 @@ def test_if_then_else():
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_add_op
()
if
sys
.
version_info
>=
(
3
,
0
):
test_add_op
()
test_if_then_else
()
test_if_then_else
()
test/py/test_numpy.py
0 → 100644
View file @
5e5ed37a
import
migraphx
,
sys
try
:
import
numpy
as
np
except
:
sys
.
exit
()
def
test_add_op
():
p
=
migraphx
.
program
()
mm
=
p
.
get_main_module
()
x
=
mm
.
add_literal
(
np
.
ones
((
3
,
3
),
dtype
=
'float32'
))
y
=
mm
.
add_literal
(
2
*
np
.
ones
((
3
,
3
),
dtype
=
'float32'
))
add_op
=
mm
.
add_instruction
(
migraphx
.
op
(
"add"
),
[
x
,
y
])
mm
.
add_return
([
add_op
])
p
.
compile
(
migraphx
.
get_target
(
"ref"
))
params
=
{}
output
=
p
.
run
(
params
)[
-
1
].
tolist
()
assert
output
==
list
(
3
*
np
.
ones
((
9
),
dtype
=
'float32'
))
if
__name__
==
"__main__"
:
test_add_op
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment