Commit bcb2c0a4 authored by Shucai Xiao's avatar Shucai Xiao

Merge branch 'test_branch_for_ort2' into branch_for_ort2

parents 3c95b34d 401d0f68
......@@ -12,31 +12,31 @@ shape
.. py:method:: type()
An integer that represents the type
An integer that represents the type.
:rtype: int
.. py:method:: lens()
A list of the lengths of the shape
A list of the lengths of the shape.
:rtype: list[int]
.. py:method:: strides()
A list of the strides of the shape
A list of the strides of the shape.
:rtype: list[int]
.. py:method:: elements()
The number of elements in the shape
The number of elements in the shape.
:rtype: int
.. py:method:: bytes()
The number of bytes the shape uses
The number of bytes the shape uses.
:rtype: int
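As a quick illustration (a minimal sketch; the ``lens`` and ``type`` values are arbitrary), these accessors can be queried on any constructed shape::

    import migraphx

    s = migraphx.shape(lens=[2, 3], type="float")
    print(s.type())      # integer code of the element type
    print(s.lens())      # [2, 3]
    print(s.strides())   # [3, 1] for a packed row-major shape
    print(s.elements())  # 6
    print(s.bytes())     # 24, i.e. 6 elements * 4 bytes per float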
......@@ -102,20 +102,25 @@ argument
Generate an argument with random data.
:param shape s: Shape of argument to generate.
:param int seed: The seed used for random number generation
:param int seed: The seed used for random number generation.
:rtype: argument
.. py:function:: fill_argument(s, value)
Fill argument of shape s with value.
:param shape s: Shape of argument to fill.
:param int value: Value to fill in the argument.
:rtype: argument
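For example (a hedged sketch; the shape and fill value are arbitrary)::

    import migraphx

    s = migraphx.shape(lens=[2, 2], type="float")
    rand_arg = migraphx.generate_argument(s, seed=0)  # random data
    ones_arg = migraphx.fill_argument(s, 1)           # every element set to 1
    print(ones_arg.tolist())                          # expected: [1.0, 1.0, 1.0, 1.0]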
target
------
.. py:class:: target()
This represents the compiliation target.
This represents the compilation target.
.. py:function:: get_target(name)
......@@ -126,6 +131,37 @@ target
:rtype: target
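A minimal sketch; ``"ref"`` is the CPU reference target used by the tests in this change, and the ``"gpu"`` target is assumed to exist only in GPU-enabled builds::

    import migraphx

    t = migraphx.get_target("ref")
    # t = migraphx.get_target("gpu")  # assumption: only available when built with GPU support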
module
------
.. py:method:: print()
Prints the contents of the module as a list of instructions.
.. py:method:: add_instruction(op, args, mod_args=[])
Adds an instruction to the module.
:param operation op: The ``migraphx.op`` to be added as an instruction.
:param list[instruction] args: List of inputs to the op.
:param list[module] mod_args: Optional list of module arguments to the operator.
:rtype: instruction
.. py:method:: add_parameter(name, shape)
Adds a parameter to the module with the provided name and shape.
:param str name: Name of the parameter.
:param shape shape: Shape of the parameter.
:rtype: instruction
.. py:method:: add_return(args)
Adds a return instruction to the module.
:param list[instruction] args: Instructions whose results are returned from the module.
:rtype: instruction
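These three methods are typically used together; a minimal sketch mirroring the ``test_add_op`` case added in this change::

    import migraphx

    p = migraphx.program()
    mm = p.get_main_module()
    s = migraphx.shape(lens=[3, 3], type="float")
    x = mm.add_parameter("x", s)
    y = mm.add_parameter("y", s)
    add_ins = mm.add_instruction(migraphx.op("add"), [x, y])
    mm.add_return([add_ins])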
program
-------
......@@ -135,21 +171,27 @@ program
.. py:method:: clone()
Make a copy of the program
Make a copy of the program.
:rtype: program
.. py:method:: get_parameter_names()
Get the names of all input parameters of the program as a list.
:rtype: list[str]
.. py:method:: get_parameter_shapes()
Get the shapes of all the input parameters in the program.
:rtype: dict[str, shape]
.. py:method:: get_shape()
.. py:method:: get_output_shapes()
Get the shape of the final output of the program.
Get the shapes of the final outputs of the program.
:rtype: shape
:rtype: list[shape]
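For instance (a sketch that assumes ``p`` is the program built in the module example above)::

    print(p.get_parameter_names())   # e.g. ['x', 'y']
    print(p.get_parameter_shapes())  # dict mapping each parameter name to its shape
    print(p.get_output_shapes())     # one shape per program output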
.. py:method:: compile(t, offload_copy=True, fast_math=True)
......@@ -159,6 +201,19 @@ program
:param bool offload_copy: For targets with offloaded memory (such as the GPU), this inserts instructions during compilation to copy the input parameters to the offloaded memory and to copy the final result from the offloaded memory back to main memory.
:param bool fast_math: Optimize math functions to use faster approximate versions. There may be slight accuracy degradation when enabled.
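A short sketch of compiling the program above for the reference target, leaving both options at their defaults::

    p.compile(migraphx.get_target("ref"))
    # On a GPU-enabled build one might instead compile for the GPU target (assumption):
    # p.compile(migraphx.get_target("gpu"), offload_copy=True)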
.. py:method:: get_main_module()
Get the main module of the program.
:rtype: module
.. py:method:: create_module(name)
Create and add a module with the provided name to the program.
:param str name: Name of the new module.
:rtype: module
.. py:method:: run(params)
Run the program.
......@@ -167,7 +222,11 @@ program
:type params: dict[str, argument]
:return: The result of the last instruction.
:rtype: argument
:rtype: list[argument]
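A hedged sketch continuing the example above; ``run`` returns one ``argument`` per program output::

    params = {}
    params["x"] = migraphx.generate_argument(s)
    params["y"] = migraphx.generate_argument(s)
    results = p.run(params)      # list[argument]
    print(results[-1].tolist())  # values of the final output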
.. py:method:: sort()
Sort the modules of the program such that instructions appear in topologically sorted order.
.. py:function:: quantize_fp16(prog, ins_names=["all"])
......@@ -190,10 +249,22 @@ program
:type ins_names: list[str]
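For example, a minimal sketch that quantizes every eligible instruction (the default ``ins_names=["all"]``) before compiling; ``model.onnx`` is a hypothetical path::

    import migraphx

    p = migraphx.parse_onnx("model.onnx")  # hypothetical model file
    migraphx.quantize_fp16(p)
    p.compile(migraphx.get_target("ref"))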
op
--
.. py:class:: op(name, kwargs)
Construct an operation with a name and arguments.
:param str name: Name of the operation; must be supported by MIGraphX.
:param dict[str, any] kwargs: Arguments to the operation.
:rtype: operation
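A sketch of both call styles, taken from the module-construction test added in this change::

    add_op = migraphx.op("add")
    tuple_op = migraphx.op("get_tuple_elem", **{"index": 0})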
parse_onnx
----------
.. py:function:: parse_onnx(filename, default_dim_value=1, map_input_dims={}, skip_unknown_operators=false, print_program_on_error=false)
.. py:function:: parse_onnx(filename, default_dim_value=1, map_input_dims={}, skip_unknown_operators=False, print_program_on_error=False, max_loop_iterations=10)
Load and parse an onnx file.
......@@ -202,20 +273,21 @@ parse_onnx
:param dict[str, list[int]] map_input_dims: Explicitly specify the dims of an input.
:param bool skip_unknown_operators: Continue parsing the onnx file if an unknown operator is found.
:param bool print_program_on_error: Print the program if an error occurs.
:param int max_loop_iterations: Maximum number of iterations for the loop operator.
:rtype: program
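A minimal sketch; ``resnet50.onnx`` and the input name ``"data"`` are hypothetical::

    import migraphx

    p = migraphx.parse_onnx("resnet50.onnx",
                            default_dim_value=1,
                            map_input_dims={"data": [4, 3, 224, 224]})
    print(p.get_parameter_shapes())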
parse_tf
--------
.. py:function:: parse_tf(filename, is_nhwc=True, batch_size=1)
.. py:function:: parse_tf(filename, is_nhwc=True, batch_size=1, map_input_dims=dict(), output_names=[])
Load and parse a TensorFlow protobuf file.
:param str filename: Path to file.
:param bool is_nhwc: Use NHWC as the default format.
:param int batch_size: Default batch size to use (if not specified in the protobuf).
:param dict[str, list[int]] map_input_dims: Optional argument to explicitly specify the dimensions of the inputs.
:param list[str] output_names: Optional argument to specify the names of the output nodes.
:rtype: program
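A sketch under the same assumptions (``frozen_graph.pb`` and the output node name are hypothetical)::

    import migraphx

    p = migraphx.parse_tf("frozen_graph.pb",
                          is_nhwc=True,
                          batch_size=1,
                          output_names=["softmax"])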
load
......@@ -223,7 +295,7 @@ load
.. py:function:: load(filename, format='msgpack')
Load a MIGraphX program
Load a MIGraphX program.
:param str filename: Path to file.
:param str format: Format of file. Valid options are msgpack or json.
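A hedged round-trip sketch using the companion ``save`` function documented below; ``saved.msgpack`` is a hypothetical path and ``p`` is assumed to be an existing program::

    import migraphx

    migraphx.save(p, "saved.msgpack")   # format defaults to 'msgpack'
    p2 = migraphx.load("saved.msgpack")
    print(p2.get_output_shapes())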
......@@ -235,7 +307,7 @@ save
.. py:function:: save(p, filename, format='msgpack')
Save a MIGraphX program
Save a MIGraphX program.
:param program p: Program to save.
:param str filename: Path to file.
......
......@@ -47,25 +47,25 @@ void auto_contiguous::apply(module& p) const
}
}
// if ops used as output param are alias 0, add a contiguous for the output
// so return outputs with standard shape
if(last->name() == "@return")
{
auto inputs = last->inputs();
for(auto ins : inputs)
{
if(ins->name() == "contiguous")
continue;
// // if ops used as output param are alias 0, add a contiguous for the output
// // so return outputs with standard shape
// if(last->name() == "@return")
// {
// auto inputs = last->inputs();
// for(auto ins : inputs)
// {
// if(ins->name() == "contiguous")
// continue;
auto ins_alias = ins->get_operator().output_alias({});
if(ins_alias == 0 and ins->get_shape().element_space() !=
ins->inputs().front()->get_shape().element_space())
{
auto cont_ins = p.insert_instruction(last, make_op("contiguous"), ins);
p.replace_instruction(ins, cont_ins);
}
}
}
// auto ins_alias = ins->get_operator().output_alias({});
// if(ins_alias == 0 and ins->get_shape().element_space() !=
// ins->inputs().front()->get_shape().element_space())
// {
// auto cont_ins = p.insert_instruction(last, make_op("contiguous"), ins);
// p.replace_instruction(ins, cont_ins);
// }
// }
// }
}
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -66,7 +66,7 @@ struct reduce_op : op_name<Derived>
{
value normalize;
normalize["axes"] = value::array{normalize_attribute::include_min};
return {{"normalize_axes", normalize}, {"std_shape", true}};
return {{"normalize_axes", normalize}};
}
std::vector<int64_t> tune_axes(std::size_t n_dim) const
......
......@@ -362,7 +362,7 @@ struct value
v(this->get_##vt()); \
return; \
}
MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_CASE)
MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_CASE_VALUE)
MIGRAPHX_VALUE_GENERATE_CASE(array, )
MIGRAPHX_VALUE_GENERATE_CASE(object, )
}
......@@ -434,6 +434,8 @@ struct value
void debug_print(bool show_type = false) const;
type_t get_type() const;
private:
template <class T>
std::vector<value> from_values(const T& r)
......@@ -443,7 +445,6 @@ struct value
r.begin(), r.end(), std::back_inserter(v), [&](auto&& e) { return value(e); });
return v;
}
type_t get_type() const;
std::shared_ptr<value_base_impl> x;
std::string key;
};
......
......@@ -3,6 +3,8 @@
#include <pybind11/stl.h>
#include <pybind11/numpy.h>
#include <migraphx/program.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/ref/target.hpp>
......@@ -95,7 +97,6 @@ migraphx::value to_value(py::kwargs kwargs)
auto&& val = arg.second;
visit_py(val, [&](auto py_val) { v[key] = py_val; });
}
return v;
}
} // namespace migraphx
......@@ -256,13 +257,38 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
py::class_<migraphx::target>(m, "target");
py::class_<migraphx::module>(m, "module")
py::class_<migraphx::instruction_ref>(m, "instruction_ref");
py::class_<migraphx::module, std::unique_ptr<migraphx::module, py::nodelete>>(m, "module")
.def("print", [](const migraphx::module& mm) { std::cout << mm << std::endl; })
.def("__eq__", std::equal_to<migraphx::module>{})
.def("__ne__", std::not_equal_to<migraphx::module>{})
.def(
"add_instruction",
[](migraphx::module& mm,
const migraphx::operation& op,
std::vector<migraphx::instruction_ref>& args,
std::vector<migraphx::module*>& mod_args) {
return mm.add_instruction(op, args, mod_args);
},
py::arg("op"),
py::arg("args"),
py::arg("mod_args") = std::vector<migraphx::module*>{})
.def(
"add_parameter",
[](migraphx::module& mm, const std::string& name, const migraphx::shape shape) {
return mm.add_parameter(name, shape);
},
py::arg("name"),
py::arg("shape"))
.def(
"add_return",
[](migraphx::module& mm, std::vector<migraphx::instruction_ref>& args) {
return mm.add_return(args);
},
py::arg("args"))
.def("__repr__", [](const migraphx::module& mm) { return migraphx::to_string(mm); });
py::class_<migraphx::program>(m, "program")
.def(py::init([]() { return migraphx::program(); }))
.def("get_parameter_names", &migraphx::program::get_parameter_names)
.def("get_parameter_shapes", &migraphx::program::get_parameter_shapes)
.def("get_output_shapes", &migraphx::program::get_output_shapes)
......@@ -277,11 +303,11 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
py::arg("t"),
py::arg("offload_copy") = true,
py::arg("fast_math") = true)
.def("get_main_module",
[](migraphx::program& p) {
auto* mm = p.get_main_module();
return *mm;
})
.def("get_main_module", [](const migraphx::program& p) { return p.get_main_module(); })
.def(
"create_module",
[](migraphx::program& p, const std::string& name) { return p.create_module(name); },
py::arg("name"))
.def("run",
[](migraphx::program& p, py::dict params) {
migraphx::parameter_map pm;
......@@ -399,6 +425,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
m.def("get_target", &migraphx::make_target);
m.def("generate_argument", &migraphx::generate_argument, py::arg("s"), py::arg("seed") = 0);
m.def("fill_argument", &migraphx::fill_argument, py::arg("s"), py::arg("value"));
m.def("quantize_fp16",
&migraphx::quantize_fp16,
py::arg("prog"),
......
......@@ -120,17 +120,19 @@ struct find_nop_reshapes
void apply(module& p, const match::matcher_result& mr) const
{
auto ins = mr.result;
// output of reshape and contiguous is standard, so no need to add another contiguous
// if the output is used an a ret value
if(ins->name() == "contiguous" and ins->name() != "contiguous" and ins->name() != "reshape")
{
auto& outputs = ins->outputs();
if(std::any_of(
outputs.begin(), outputs.end(), [&](auto o) { return o->name() == "@return"; }))
{
return;
}
}
// // output of reshape and contiguous is standard, so no need to add another contiguous
// // if the output is used an a ret value
// if(ins->name() == "contiguous" and ins->name() != "contiguous" and ins->name() !=
// "reshape")
// {
// auto& outputs = ins->outputs();
// if(std::any_of(
// outputs.begin(), outputs.end(), [&](auto o) { return o->name() == "@return";
// }))
// {
// return;
// }
// }
p.replace_instruction(ins, ins->inputs().front());
}
};
......
......@@ -35,8 +35,8 @@ struct half2_max
// in_data is in shared memory
template <class Op>
__device__ __half2
block_reduce(__half2* buffer, index_int batch_item_num, index_int tid, index_int block_size, Op op)
__device__ __half2 block_reduce_half2(
__half2* buffer, index_int batch_item_num, index_int tid, index_int block_size, Op op)
{
__syncthreads();
for(index_int s = block_size; s > 0; s >>= 1)
......@@ -55,7 +55,7 @@ block_reduce(__half2* buffer, index_int batch_item_num, index_int tid, index_int
}
__global__ void
softmax_kernel(void* data_in, index_int batch_item_num, index_int block_size, void* data_out)
softmax_kernel_half2(void* data_in, index_int batch_item_num, index_int block_size, void* data_out)
{
__half2* input = reinterpret_cast<__half2*>(data_in);
__half2* output = reinterpret_cast<__half2*>(data_out);
......@@ -73,7 +73,7 @@ softmax_kernel(void* data_in, index_int batch_item_num, index_int block_size, vo
}
auto batch_max =
block_reduce(in_data_reduce, batch_item_num, threadIdx.x, block_size, half2_max{});
block_reduce_half2(in_data_reduce, batch_item_num, threadIdx.x, block_size, half2_max{});
for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{
......@@ -82,7 +82,7 @@ softmax_kernel(void* data_in, index_int batch_item_num, index_int block_size, vo
}
auto batch_sum =
block_reduce(in_data_reduce, batch_item_num, threadIdx.x, block_size, half2_sum{});
block_reduce_half2(in_data_reduce, batch_item_num, threadIdx.x, block_size, half2_sum{});
for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{
......@@ -92,8 +92,8 @@ softmax_kernel(void* data_in, index_int batch_item_num, index_int block_size, vo
// in_data is in shared memory
template <class Op>
__device__ __half
block_reduce2(__half* data, index_int batch_item_num, index_int tid, index_int block_size, Op op)
__device__ __half block_reduce_half(
__half* data, index_int batch_item_num, index_int tid, index_int block_size, Op op)
{
__syncthreads();
for(index_int s = block_size / 2; s > 0; s >>= 1)
......@@ -109,7 +109,7 @@ block_reduce2(__half* data, index_int batch_item_num, index_int tid, index_int b
}
__global__ void
softmax_kernel2(void* data_in, index_int batch_item_num, index_int block_size, void* data_out)
softmax_kernel_half(void* data_in, index_int batch_item_num, index_int block_size, void* data_out)
{
__half* input = reinterpret_cast<__half*>(data_in);
__half* output = reinterpret_cast<__half*>(data_out);
......@@ -125,14 +125,16 @@ softmax_kernel2(void* data_in, index_int batch_item_num, index_int block_size, v
in_data_reduce[i] = d;
}
auto batch_max = block_reduce2(in_data_reduce, batch_item_num, threadIdx.x, block_size, max{});
auto batch_max =
block_reduce_half(in_data_reduce, batch_item_num, threadIdx.x, block_size, max{});
for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{
in_data[i] = __float2half(::exp(__half2float(in_data[i]) - __half2float(batch_max)));
in_data_reduce[i] = in_data[i];
}
auto batch_sum = block_reduce2(in_data_reduce, batch_item_num, threadIdx.x, block_size, sum{});
auto batch_sum =
block_reduce_half(in_data_reduce, batch_item_num, threadIdx.x, block_size, sum{});
for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{
output[i + start] = __float2half(__half2float(in_data[i]) / __half2float(batch_sum));
......@@ -161,7 +163,7 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in
int block_num = batch_shape.elements();
int shared_size = batch_item_num * 2 * result.get_shape().type_size();
half2_block_size = half2_block_size / 4;
softmax_kernel<<<block_num, half2_block_size, shared_size, stream>>>(
softmax_kernel_half2<<<block_num, half2_block_size, shared_size, stream>>>(
arg.data(), batch_item_num, half2_block_size, result.data());
}
else
......
......@@ -4,6 +4,7 @@
#include <migraphx/errors.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/value.hpp>
#include <migraphx/optional.hpp>
#include <unordered_map>
#include <utility>
......@@ -417,25 +418,12 @@ value value::with_key(const std::string& pkey) const
return result;
}
template <class F, class T, class U, class Common = typename std::common_type<T, U>::type>
auto compare_common_impl(
rank<1>, F f, const std::string& keyx, const T& x, const std::string& keyy, const U& y)
{
return f(std::forward_as_tuple(keyx, Common(x)), std::forward_as_tuple(keyy, Common(y)));
}
template <class F>
auto compare_common_impl(
rank<1>, F f, const std::string& keyx, std::nullptr_t, const std::string& keyy, std::nullptr_t)
{
return f(std::forward_as_tuple(keyx, 0), std::forward_as_tuple(keyy, 0));
}
template <class F, class T, class U>
auto compare_common_impl(rank<0>, F, const std::string&, const T&, const std::string&, const U&)
template <class T>
const T& compare_decay(const T& x)
{
return false;
return x;
}
int compare_decay(std::nullptr_t) { return 0; }
template <class F>
bool compare(const value& x, const value& y, F f)
......@@ -443,7 +431,11 @@ bool compare(const value& x, const value& y, F f)
bool result = false;
x.visit_value([&](auto&& a) {
y.visit_value([&](auto&& b) {
result = compare_common_impl(rank<1>{}, f, x.get_key(), a, y.get_key(), b);
if constexpr(std::is_same<decltype(a), decltype(b)>{})
result = f(std::forward_as_tuple(x.get_key(), compare_decay(a)),
std::forward_as_tuple(y.get_key(), compare_decay(b)));
else
assert(false); // NOLINT
});
});
return result;
......@@ -462,11 +454,16 @@ bool operator==(const value& x, const value& y)
return false;
return compare(x, y, std::equal_to<>{});
}
bool operator!=(const value& x, const value& y) { return !(x == y); }
bool operator<(const value& x, const value& y) { return compare(x, y, std::less<>{}); }
bool operator<=(const value& x, const value& y) { return x == y or x < y; }
bool operator!=(const value& x, const value& y) { return not(x == y); }
bool operator<(const value& x, const value& y)
{
if(x.get_type() != y.get_type())
return x.get_type() < y.get_type();
return compare(x, y, std::less<>{});
}
bool operator<=(const value& x, const value& y) { return not(x > y); }
bool operator>(const value& x, const value& y) { return y < x; }
bool operator>=(const value& x, const value& y) { return x == y or x > y; }
bool operator>=(const value& x, const value& y) { return not(x < y); }
void print_value(std::ostream& os, std::nullptr_t) { os << "null"; }
......
......@@ -68,9 +68,9 @@ struct nop
{
static std::string as_string() { return ""; }
template <class T>
static decltype(auto) call(T&& x)
static auto call(T&& x)
{
return x;
return static_cast<T&&>(x);
}
};
......@@ -113,6 +113,33 @@ inline auto operator<<(Stream& s, const Range& v) -> decltype(stream_range(s, v.
return s;
}
template <class T>
const T& get_value(const T& x)
{
return x;
}
template <class T, class Operator = nop>
struct lhs_expression;
template <class T>
lhs_expression<T> make_lhs_expression(T&& lhs);
template <class T, class Operator>
lhs_expression<T, Operator> make_lhs_expression(T&& lhs, Operator);
// NOLINTNEXTLINE
#define TEST_EXPR_BINARY_OPERATOR(op, name) \
template <class V> \
auto operator op(const V& rhs2) const \
{ \
return make_expression(*this, rhs2, name{}); /* NOLINT */ \
}
// NOLINTNEXTLINE
#define TEST_EXPR_UNARY_OPERATOR(op, name) \
auto operator op() const { return make_lhs_expression(lhs, name{}); /* NOLINT */ }
template <class T, class U, class Operator>
struct expression
{
......@@ -125,7 +152,12 @@ struct expression
return s;
}
decltype(auto) value() const { return Operator::call(lhs, rhs); };
friend decltype(auto) get_value(const expression& e) { return e.value(); }
decltype(auto) value() const { return Operator::call(get_value(lhs), get_value(rhs)); };
TEST_FOREACH_UNARY_OPERATORS(TEST_EXPR_UNARY_OPERATOR)
TEST_FOREACH_BINARY_OPERATORS(TEST_EXPR_BINARY_OPERATOR)
};
// TODO: Remove rvalue references
......@@ -135,9 +167,6 @@ expression<T, U, Operator> make_expression(T&& rhs, U&& lhs, Operator)
return {std::forward<T>(rhs), std::forward<U>(lhs)};
}
template <class T, class Operator = nop>
struct lhs_expression;
// TODO: Remove rvalue reference
template <class T>
lhs_expression<T> make_lhs_expression(T&& lhs)
......@@ -166,22 +195,12 @@ struct lhs_expression
return s;
}
decltype(auto) value() const { return Operator::call(lhs); }
// NOLINTNEXTLINE
#define TEST_LHS_BINARY_OPERATOR(op, name) \
template <class U> \
auto operator op(const U& rhs) const \
{ \
return make_expression(lhs, rhs, name{}); /* NOLINT */ \
}
friend decltype(auto) get_value(const lhs_expression& e) { return e.value(); }
TEST_FOREACH_BINARY_OPERATORS(TEST_LHS_BINARY_OPERATOR)
decltype(auto) value() const { return Operator::call(get_value(lhs)); }
// NOLINTNEXTLINE
#define TEST_LHS_UNARY_OPERATOR(op, name) \
auto operator op() const { return make_lhs_expression(lhs, name{}); /* NOLINT */ }
TEST_FOREACH_UNARY_OPERATORS(TEST_LHS_UNARY_OPERATOR)
TEST_FOREACH_BINARY_OPERATORS(TEST_EXPR_BINARY_OPERATOR)
TEST_FOREACH_UNARY_OPERATORS(TEST_EXPR_UNARY_OPERATOR)
// NOLINTNEXTLINE
#define TEST_LHS_REOPERATOR(op) \
......@@ -223,6 +242,13 @@ auto make_predicate(const std::string& msg, F f)
return make_lhs_expression(predicate<F>{msg, f}, function{});
}
inline std::string as_string(bool x)
{
if(x)
return "true";
return "false";
}
template <class T>
std::string as_string(const T& x)
{
......@@ -627,6 +653,9 @@ inline void run(int argc, const char* argv[])
} // namespace test
// NOLINTNEXTLINE
#define TEST_CAPTURE(...) test::capture{}->*__VA_ARGS__
// NOLINTNEXTLINE
#define CHECK(...) \
test::failed( \
......@@ -634,7 +663,7 @@ inline void run(int argc, const char* argv[])
})
// NOLINTNEXTLINE
#define EXPECT(...) \
test::failed(test::capture{}->*__VA_ARGS__, \
test::failed(TEST_CAPTURE(__VA_ARGS__), \
#__VA_ARGS__, \
__PRETTY_FUNCTION__, \
__FILE__, \
......
......@@ -27,6 +27,7 @@ add_py_test(ref test_cpu.py WORKING_DIRECTORY ${TEST_ONNX_DIR})
add_py_test(save_load test_save_load.py WORKING_DIRECTORY ${TEST_ONNX_DIR})
add_py_test(op test_op.py WORKING_DIRECTORY ${TEST_ONNX_DIR})
add_py_test(shape test_shape.py WORKING_DIRECTORY ${TEST_ONNX_DIR})
add_py_test(module_construct test_module_construct.py WORKING_DIRECTORY ${TEST_ONNX_DIR})
if(MIGRAPHX_ENABLE_GPU)
add_py_test(gpu_offload test_gpu_offload.py WORKING_DIRECTORY ${TEST_ONNX_DIR})
add_py_test(gpu test_gpu.py WORKING_DIRECTORY ${TEST_ONNX_DIR})
......
import migraphx
def test_add_op():
p = migraphx.program()
mm = p.get_main_module()
param_shape = migraphx.shape(lens=[3, 3], type="float")
x = mm.add_parameter("x", param_shape)
y = mm.add_parameter("y", param_shape)
add_op = mm.add_instruction(migraphx.op("add"), [x, y])
mm.add_return([add_op])
p.compile(migraphx.get_target("ref"))
params = {}
params["x"] = migraphx.generate_argument(param_shape)
params["y"] = migraphx.generate_argument(param_shape)
output = p.run(params)[-1].tolist()
assert output == [
a + b for a, b in zip(params["x"].tolist(), params["y"].tolist())
]
def test_if_then_else():
param_shape = migraphx.shape(lens=[3, 3], type="float")
cond_shape = migraphx.shape(type="bool", lens=[1], strides=[0])
def create_program():
p = migraphx.program()
mm = p.get_main_module()
cond = mm.add_parameter("cond", cond_shape)
x = mm.add_parameter("x", param_shape)
y = mm.add_parameter("y", param_shape)
then_mod = p.create_module("If_0_if")
x_identity = then_mod.add_instruction(migraphx.op("identity"), [x])
then_mod.add_return([x_identity])
else_mod = p.create_module("If_0_else")
y_identity = else_mod.add_instruction(migraphx.op("identity"), [y])
else_mod.add_return([y_identity])
if_ins = mm.add_instruction(migraphx.op("if"), [cond],
[then_mod, else_mod])
ret = mm.add_instruction(migraphx.op("get_tuple_elem", **{"index": 0}),
[if_ins])
mm.add_return([ret])
return p
params = {}
params["x"] = migraphx.generate_argument(param_shape)
params["y"] = migraphx.generate_argument(param_shape)
def run_prog(cond):
p = create_program()
p.compile(migraphx.get_target("ref"))
params["cond"] = migraphx.fill_argument(cond_shape, cond)
output = p.run(params)[-1]
return output
assert run_prog(True) == params["x"]
assert run_prog(False) == params["y"]
if __name__ == "__main__":
test_add_op()
test_if_then_else()
......@@ -540,6 +540,14 @@ TEST_CASE(value_construct_object_string_mixed_value)
EXPECT(v.at("two").get_int64() == 2);
}
template <class Expression>
auto compare_predicate(const Expression& e)
{
bool result = e.value();
return test::make_predicate(test::as_string(e) + " => " + test::as_string(result),
[=] { return result; });
}
TEST_CASE(value_compare)
{
EXPECT(migraphx::value(1) == migraphx::value(1));
......@@ -553,6 +561,46 @@ TEST_CASE(value_compare)
EXPECT(migraphx::value(2) > migraphx::value(1));
EXPECT(migraphx::value(2) >= migraphx::value(1));
EXPECT(migraphx::value(1) >= migraphx::value(1));
EXPECT(migraphx::value(1) != migraphx::value("1"));
EXPECT(migraphx::value(1) != migraphx::value());
}
// NOLINTNEXTLINE
#define MIGRAPHX_VALUE_TEST_COMPARE(...) compare_predicate(TEST_CAPTURE(__VA_ARGS__))
// NOLINTNEXTLINE
#define EXPECT_TOTALLY_ORDERED_IMPL(_, x, y) \
EXPECT(_(x <= y) or _(x >= y)); \
EXPECT(_(x < y) or _(x > y) or _(x == y)); \
EXPECT((_(x < y) or _(x > y)) == _(x != y)); \
EXPECT(_(x < y) == _(y > x)); \
EXPECT(_(x <= y) == _(y >= x)); \
EXPECT(_(x < y) != _(x >= y)); \
EXPECT(_(x > y) != _(x <= y)); \
EXPECT(_(x == y) != _(x != y))
// NOLINTNEXTLINE
#define EXPECT_TOTALLY_ORDERED(x, y) \
EXPECT_TOTALLY_ORDERED_IMPL(MIGRAPHX_VALUE_TEST_COMPARE, x, y); \
EXPECT_TOTALLY_ORDERED_IMPL(MIGRAPHX_VALUE_TEST_COMPARE, y, x)
// NOLINTNEXTLINE(readability-function-size)
TEST_CASE(value_compare_ordered)
{
EXPECT_TOTALLY_ORDERED(migraphx::value(), migraphx::value());
EXPECT_TOTALLY_ORDERED(migraphx::value(1), migraphx::value(1));
EXPECT_TOTALLY_ORDERED(migraphx::value(1), migraphx::value(2));
EXPECT_TOTALLY_ORDERED(migraphx::value("key", 1), migraphx::value("key", 1));
EXPECT_TOTALLY_ORDERED(migraphx::value("key1", 1), migraphx::value("key2", 2));
EXPECT_TOTALLY_ORDERED(migraphx::value("key", 1), migraphx::value("key", 2));
EXPECT_TOTALLY_ORDERED(migraphx::value("key1", 1), migraphx::value("key2", 2));
EXPECT_TOTALLY_ORDERED(migraphx::value("key", 1), migraphx::value("key", "2"));
EXPECT_TOTALLY_ORDERED(migraphx::value("key1", 1), migraphx::value("key2", "2"));
EXPECT_TOTALLY_ORDERED(migraphx::value(std::int64_t{1}), migraphx::value(std::uint64_t{1}));
EXPECT_TOTALLY_ORDERED(migraphx::value(std::int64_t{1}), migraphx::value(std::uint64_t{2}));
EXPECT_TOTALLY_ORDERED(migraphx::value(std::int64_t{2}), migraphx::value(std::uint64_t{1}));
EXPECT_TOTALLY_ORDERED(migraphx::value(1), migraphx::value("1"));
EXPECT_TOTALLY_ORDERED(migraphx::value(1), migraphx::value());
}
TEST_CASE(value_to_from_string)
......