Unverified Commit 4467c158 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Add python API to construct shape class (#1128)

Add python API to construct shape class
parent 0e6bd17c
......@@ -211,12 +211,21 @@ migraphx::shape to_shape(const py::buffer_info& info)
MIGRAPHX_PYBIND11_MODULE(migraphx, m)
{
py::class_<migraphx::shape>(m, "shape")
.def(py::init<>())
.def(py::init([](py::kwargs kwargs) {
auto v = migraphx::to_value(kwargs);
auto t = migraphx::shape::parse_type(v.get("type", std::string{"float"}));
auto lens = v.get<std::size_t>("lens", {1});
if(v.contains("strides"))
return migraphx::shape(t, lens, v.at("strides").to_vector<std::size_t>());
else
return migraphx::shape(t, lens);
}))
.def("type", &migraphx::shape::type)
.def("lens", &migraphx::shape::lens)
.def("strides", &migraphx::shape::strides)
.def("elements", &migraphx::shape::elements)
.def("bytes", &migraphx::shape::bytes)
.def("type_string", &migraphx::shape::type_string)
.def("type_size", &migraphx::shape::type_size)
.def("packed", &migraphx::shape::packed)
.def("transposed", &migraphx::shape::transposed)
......
......@@ -25,10 +25,11 @@ endforeach()
add_py_test(ref test_cpu.py WORKING_DIRECTORY ${TEST_ONNX_DIR})
add_py_test(save_load test_save_load.py WORKING_DIRECTORY ${TEST_ONNX_DIR})
add_py_test(op test_op.py WORKING_DIRECTORY ${TEST_ONNX_DIR})
add_py_test(shape test_shape.py WORKING_DIRECTORY ${TEST_ONNX_DIR})
if(MIGRAPHX_ENABLE_GPU)
add_py_test(gpu_offload test_gpu_offload.py WORKING_DIRECTORY ${TEST_ONNX_DIR})
add_py_test(gpu test_gpu.py WORKING_DIRECTORY ${TEST_ONNX_DIR})
add_py_test(array test_array.py WORKING_DIRECTORY ${TEST_ONNX_DIR})
add_py_test(backend onnx_backend_test.py WORKING_DIRECTORY ${TEST_ONNX_DIR})
add_py_test(op test_op.py WORKING_DIRECTORY ${TEST_ONNX_DIR})
endif()
import migraphx
def test_create_shape():
s = migraphx.shape(lens=[1, 64, 3, 3])
assert s.standard()
assert s.packed()
assert s.lens() == [1, 64, 3, 3]
def test_create_shape_broadcast():
s = migraphx.shape(lens=[1, 64, 3, 3], strides=[0, 1, 0, 0])
assert s.broadcasted()
assert s.lens() == [1, 64, 3, 3]
assert s.strides() == [0, 1, 0, 0]
def test_create_shape_type():
s = migraphx.shape(type='uint8')
assert s.type_string() == 'uint8_type'
assert s.type_size() == 1
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment