"src/targets/gpu/tanh.cpp" did not exist on "fdce629342fa5f694e790d5f4c04a964a237aa8a"
Unverified Commit 77c05035 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Add migraphx.create_argument to python API (#1722)

parent 0689f873
...@@ -151,6 +151,15 @@ argument ...@@ -151,6 +151,15 @@ argument
:rtype: argument :rtype: argument
.. py:function:: create_argument(s, values)
Create an argument of shape s with a set of values.
:param shape s: Shape of argument to create.
:param list values: Values to put in the argument. Must be the same number of elements as the shape.
:rtype: argument
.. py:function:: argument_from_pointer(shape, address) .. py:function:: argument_from_pointer(shape, address)
Create argument from data stored in given address without copy. Create argument from data stored in given address without copy.
......
...@@ -93,6 +93,16 @@ struct MIGRAPHX_EXPORT argument : raw_data<argument> ...@@ -93,6 +93,16 @@ struct MIGRAPHX_EXPORT argument : raw_data<argument>
/// Return the ith element /// Return the ith element
argument element(std::size_t i) const; argument element(std::size_t i) const;
// Keeps the same data ordering as the given container
template <class Iterator>
void fill(Iterator start, Iterator end)
{
assert(std::distance(start, end) <= m_shape.elements());
this->visit([&](auto output) {
std::copy(start, end, output.begin());
});
}
private: private:
void assign_buffer(std::function<char*()> d); void assign_buffer(std::function<char*()> d);
struct data_t struct data_t
......
...@@ -547,6 +547,13 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) ...@@ -547,6 +547,13 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
py::arg("format") = "msgpack"); py::arg("format") = "msgpack");
m.def("get_target", &migraphx::make_target); m.def("get_target", &migraphx::make_target);
m.def("create_argument", [](const migraphx::shape& s, const std::vector<double>& values) {
if(values.size() != s.elements())
MIGRAPHX_THROW("Values and shape elements do not match");
migraphx::argument a{s};
a.fill(values.begin(), values.end());
return a;
});
m.def("generate_argument", &migraphx::generate_argument, py::arg("s"), py::arg("seed") = 0); m.def("generate_argument", &migraphx::generate_argument, py::arg("s"), py::arg("seed") = 0);
m.def("fill_argument", &migraphx::fill_argument, py::arg("s"), py::arg("value")); m.def("fill_argument", &migraphx::fill_argument, py::arg("s"), py::arg("value"));
m.def("quantize_fp16", m.def("quantize_fp16",
......
...@@ -51,6 +51,7 @@ add_py_test(save_load test_save_load.py WORKING_DIRECTORY ${TEST_ONNX_DIR}) ...@@ -51,6 +51,7 @@ add_py_test(save_load test_save_load.py WORKING_DIRECTORY ${TEST_ONNX_DIR})
add_py_test(op test_op.py WORKING_DIRECTORY ${TEST_ONNX_DIR}) add_py_test(op test_op.py WORKING_DIRECTORY ${TEST_ONNX_DIR})
add_py_test(shape test_shape.py WORKING_DIRECTORY ${TEST_ONNX_DIR}) add_py_test(shape test_shape.py WORKING_DIRECTORY ${TEST_ONNX_DIR})
add_py_test(module_construct test_module_construct.py WORKING_DIRECTORY ${TEST_ONNX_DIR}) add_py_test(module_construct test_module_construct.py WORKING_DIRECTORY ${TEST_ONNX_DIR})
add_py_test(literal test_literal.py WORKING_DIRECTORY ${TEST_ONNX_DIR})
if(MIGRAPHX_ENABLE_GPU) if(MIGRAPHX_ENABLE_GPU)
add_py_test(gpu_offload test_gpu_offload.py WORKING_DIRECTORY ${TEST_ONNX_DIR}) add_py_test(gpu_offload test_gpu_offload.py WORKING_DIRECTORY ${TEST_ONNX_DIR})
add_py_test(gpu test_gpu.py WORKING_DIRECTORY ${TEST_ONNX_DIR}) add_py_test(gpu test_gpu.py WORKING_DIRECTORY ${TEST_ONNX_DIR})
......
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
import migraphx
def test_add_fill():
p = migraphx.program()
mm = p.get_main_module()
x = mm.add_literal(
migraphx.fill_argument(migraphx.shape(type='float_type', lens=[3, 3]),
1))
y = mm.add_literal(
migraphx.fill_argument(migraphx.shape(type='float_type', lens=[3, 3]),
2))
add_op = mm.add_instruction(migraphx.op("add"), [x, y])
mm.add_return([add_op])
p.compile(migraphx.get_target("ref"))
params = {}
output = p.run(params)[-1].tolist()
assert output == list([3.0] * 9)
def test_add_create():
p = migraphx.program()
mm = p.get_main_module()
x = mm.add_literal(
migraphx.create_argument(
migraphx.shape(type='float_type', lens=[2, 2]), [1, 2, 3, 4]))
y = mm.add_literal(
migraphx.create_argument(
migraphx.shape(type='float_type', lens=[2, 2]), [5, 6, 7, 8]))
add_op = mm.add_instruction(migraphx.op("add"), [x, y])
mm.add_return([add_op])
p.compile(migraphx.get_target("ref"))
params = {}
output = p.run(params)[-1].tolist()
assert output == list([6, 8, 10, 12])
if __name__ == "__main__":
test_add_fill()
test_add_create()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment