Unverified Commit 002eb4e2 authored by Paul Fultz II, committed by GitHub

Add C++ ability to construct operators by name (#616)



* Add make_op function

* Formatting

* Add more values

* Formatting

* Remove templates from parse_conv functions

* Formatting

* Remove mat_mul template

* Formatting

* Reduce header includes

* Fix compiling for gpu

* Formatting

* Use make_op in lowering

* Formatting

* Sort lines

* Formatting

* Add more tests

* Formatting

* Fix tidy error

* Formatting

* Add const refs

* Add explicit this

* Add more const refs

* Sort the program

* Remove commented out code

* Formatting

* Infer gpu prefix

* Formatting
Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 56b3bf58
@@ -20,6 +20,7 @@ add_library(migraphx
    env.cpp
    generate.cpp
    instruction.cpp
+   make_op.cpp
    msgpack.cpp
    program.cpp
    quantization.cpp
...
#ifndef MIGRAPHX_GUARD_RTGLIB_MAKE_OP_HPP
#define MIGRAPHX_GUARD_RTGLIB_MAKE_OP_HPP
#include <migraphx/config.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/value.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
operation make_op(const std::string& name);
operation make_op(const std::string& name, const value& v);
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
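A minimal usage sketch of the two overloads declared above, based on the make_op unit tests added later in this diff (it assumes the core operator registry is linked in):

#include <migraphx/make_op.hpp>
#include <migraphx/op/convolution.hpp>

int main()
{
    // Construct an operator purely by its registered name.
    auto relu = migraphx::make_op("relu");

    // Construct by name and set attributes through a value; this is
    // equivalent to building the typed operator directly.
    auto conv = migraphx::make_op(
        "convolution", {{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {2, 2}}});
    migraphx::operation typed = migraphx::op::convolution{{1, 1}, {2, 2}, {2, 2}};

    return (conv == typed and relu.name() == "relu") ? 0 : 1;
}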
@@ -23,7 +23,7 @@ struct quant_dot
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
-        return pack(f(as_number(self.alpha), "alpha"), f(as_number(self.beta), "beta"));
+        return pack(f(self.alpha, "alpha"), f(self.beta, "beta"));
    }

    std::string name() const { return "quant_dot"; }
...
@@ -47,7 +47,7 @@ value to_value_impl(rank<1>, const std::pair<T, U>& x)
template <class T>
auto to_value_impl(rank<2>, const T& x) -> decltype(x.begin(), x.end(), value{})
{
-    value result;
+    value result = value::array{};
    for(auto&& y : x)
    {
        auto e = to_value(y);
@@ -59,7 +59,7 @@ auto to_value_impl(rank<2>, const T& x) -> decltype(x.begin(), x.end(), value{})
template <class T, MIGRAPHX_REQUIRES(is_reflectable<T>{})>
value to_value_impl(rank<3>, const T& x)
{
-    value result;
+    value result = value::object{};
    reflect_each(x, [&](auto&& y, std::string name) { result.emplace(name, to_value(y)); });
    return result;
}
...
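The effect of seeding result with value::array{} / value::object{} is that empty ranges and empty reflectable structs now serialize to an empty array or object instead of a null value, which the serialize_empty_array and serialize_empty_struct tests below rely on. A small sketch (assuming to_value is provided by <migraphx/serialize.hpp>, as in this diff):

#include <migraphx/serialize.hpp>
#include <migraphx/value.hpp>
#include <cassert>
#include <vector>

int main()
{
    std::vector<std::size_t> empty = {};
    migraphx::value v = migraphx::to_value(empty);
    // Previously this would have been a null value; now it is an empty array
    // that can be appended to directly.
    assert(v.is_array());
    assert(v.empty());
    v.push_back(1);
    assert(v.size() == 1);
    return 0;
}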
@@ -73,6 +73,13 @@ struct value_converter<std::pair<T, U>>
};

namespace detail {

+template <class To, class Key, class From>
+auto try_convert_value_impl(rank<2>, const std::pair<Key, From>& x)
+    -> decltype(value_converter<To>::apply(x.second))
+{
+    return value_converter<To>::apply(x.second);
+}
+
template <class To, class From>
auto try_convert_value_impl(rank<1>, const From& x) -> decltype(value_converter<To>::apply(x))
{
@@ -89,7 +96,7 @@ To try_convert_value_impl(rank<0>, const From& x)
template <class To, class From>
To try_convert_value(const From& x)
{
-    return detail::try_convert_value_impl<To>(rank<1>{}, x);
+    return detail::try_convert_value_impl<To>(rank<2>{}, x);
}

struct value
@@ -159,6 +166,26 @@ struct value
    using is_pickable =
        std::integral_constant<bool, (std::is_arithmetic<T>{} and not std::is_pointer<T>{})>;

+    template <class T>
+    using range_value = std::decay_t<decltype(std::declval<T>().end(), *std::declval<T>().begin())>;
+
+    template <class T>
+    using is_generic_range =
+        std::integral_constant<bool,
+                               (std::is_convertible<range_value<T>, value>{} and
+                                not std::is_convertible<T, array>{} and
+                                not std::is_convertible<T, object>{})>;
+
+    template <class T, MIGRAPHX_REQUIRES(is_generic_range<T>{})>
+    value(const T& r) : value(from_values(r))
+    {
+    }
+
+    template <class T, MIGRAPHX_REQUIRES(is_generic_range<T>{})>
+    value(const std::string& pkey, const T& r) : value(pkey, from_values(r))
+    {
+    }
+
    template <class T, MIGRAPHX_REQUIRES(is_pickable<T>{})>
    value(T i) : value(pick<T>{i})
    {
@@ -176,6 +203,11 @@ struct value
    {
        return *this = pick<T>{rhs}; // NOLINT
    }

+    template <class T, MIGRAPHX_REQUIRES(is_generic_range<T>{})>
+    value& operator=(T rhs)
+    {
+        return *this = from_values(rhs); // NOLINT
+    }
+
    value& operator=(std::nullptr_t);
@@ -214,6 +246,10 @@ struct value
    const value& operator[](std::size_t i) const;
    value& operator[](const std::string& pkey);

+    void clear();
+    void resize(std::size_t n);
+    void resize(std::size_t n, const value& v);
+
    std::pair<value*, bool> insert(const value& v);
    value* insert(const value* pos, const value& v);
@@ -294,6 +330,14 @@ struct value
    void debug_print(bool show_type = false) const;

    private:
+    template <class T>
+    std::vector<value> from_values(const T& r)
+    {
+        std::vector<value> v;
+        std::transform(
+            r.begin(), r.end(), std::back_inserter(v), [&](auto&& e) { return value(e); });
+        return v;
+    }
+
    type_t get_type() const;

    std::shared_ptr<value_base_impl> x;
    std::string key;
...
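Taken together, the generic-range constructors/assignment and the new clear/resize declarations let a value be built straight from any range of convertible elements and then mutated as an array, mirroring the value_construct_from_vector and value_resize tests added below. A small sketch:

#include <migraphx/value.hpp>
#include <cassert>
#include <vector>

int main()
{
    std::vector<int> ints = {1, 2, 3};

    // Construct a value directly from a generic range.
    migraphx::value v = ints;
    assert(v.is_array());
    assert(v.to_vector<int>() == ints);

    // Ranges also work as keyed entries inside an object.
    migraphx::value obj = {{"a", ints}};
    assert(obj.at("a").size() == 3);

    // New array mutators: resize pads with nulls (or a given value), clear empties it.
    v.resize(5);
    assert(v.size() == 5 and v.at(3).is_null());
    v.clear();
    assert(v.empty());
    return 0;
}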
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

operation make_op(const std::string& name) { return load_op(name); }

operation make_op(const std::string& name, const value& v)
{
    auto op = load_op(name);
    // Merge values
    value w = op.to_value();
    for(auto&& x : v)
    {
        w.at(x.get_key()) = x.without_key();
    }
    op.from_value(w);
    return op;
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
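The merge loop starts from the operator's default values and overwrites only the keys that were supplied; because value::at throws for a missing key, a misspelled attribute name is rejected rather than silently ignored (the make_op_invalid_key test below checks exactly this). A short sketch of both behaviours:

#include <migraphx/make_op.hpp>
#include <migraphx/op/convolution.hpp>
#include <cassert>

int main()
{
    // Only "padding" is overridden; "stride" and "dilation" keep their defaults.
    migraphx::operation x = migraphx::make_op("convolution", {{"padding", {1, 1}}});
    migraphx::operation y = migraphx::op::convolution{{1, 1}};
    assert(x == y);

    // An attribute the operator does not have throws instead of being dropped.
    bool threw = false;
    try
    {
        migraphx::make_op("convolution", {{"paddings", {1, 1}}});
    }
    catch(...)
    {
        threw = true;
    }
    assert(threw);
    return 0;
}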
... (collapsed diff for one file, not shown)
+#include <rocblas.h>
#include <migraphx/gpu/lowering.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/instruction.hpp>
-#include <migraphx/operators.hpp>
-#include <migraphx/generate.hpp>
-#include <migraphx/shape_for_each.hpp>
-#include <migraphx/gpu/miopen.hpp>
-#include <migraphx/gpu/hip.hpp>
-#include <migraphx/dfor.hpp>
-#include <migraphx/gpu/device/contiguous.hpp>
-#include <migraphx/gpu/device/add.hpp>
-#include <migraphx/iterator_for.hpp>
-#include <migraphx/gpu/argmax.hpp>
-#include <migraphx/gpu/argmin.hpp>
-#include <migraphx/gpu/rocblas.hpp>
+#include <migraphx/make_op.hpp>
+#include <migraphx/op/abs.hpp>
+#include <migraphx/op/batch_norm_inference.hpp>
+#include <migraphx/op/convolution.hpp>
+#include <migraphx/op/deconvolution.hpp>
+#include <migraphx/op/dot.hpp>
+#include <migraphx/op/elu.hpp>
+#include <migraphx/op/leaky_relu.hpp>
+#include <migraphx/op/lrn.hpp>
+#include <migraphx/op/pooling.hpp>
+#include <migraphx/op/reshape.hpp>
+#include <migraphx/op/quant_convolution.hpp>
+#include <migraphx/op/quant_dot.hpp>
+#include <migraphx/gpu/abs.hpp>
+#include <migraphx/gpu/batch_norm_inference.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/convolution.hpp>
#include <migraphx/gpu/deconvolution.hpp>
-#include <migraphx/gpu/quant_convolution.hpp>
-#include <migraphx/gpu/contiguous.hpp>
-#include <migraphx/gpu/relu.hpp>
-#include <migraphx/gpu/sigmoid.hpp>
-#include <migraphx/gpu/abs.hpp>
-#include <migraphx/gpu/leaky_relu.hpp>
#include <migraphx/gpu/elu.hpp>
-#include <migraphx/gpu/softmax.hpp>
-#include <migraphx/gpu/logsoftmax.hpp>
-#include <migraphx/gpu/add.hpp>
-#include <migraphx/gpu/sub.hpp>
-#include <migraphx/gpu/div.hpp>
-#include <migraphx/gpu/exp.hpp>
-#include <migraphx/gpu/erf.hpp>
-#include <migraphx/gpu/log.hpp>
-#include <migraphx/gpu/sin.hpp>
-#include <migraphx/gpu/sign.hpp>
-#include <migraphx/gpu/cos.hpp>
-#include <migraphx/gpu/tan.hpp>
-#include <migraphx/gpu/sinh.hpp>
-#include <migraphx/gpu/cosh.hpp>
-#include <migraphx/gpu/tanh.hpp>
-#include <migraphx/gpu/asin.hpp>
-#include <migraphx/gpu/acos.hpp>
-#include <migraphx/gpu/atan.hpp>
-#include <migraphx/gpu/asinh.hpp>
-#include <migraphx/gpu/acosh.hpp>
-#include <migraphx/gpu/atanh.hpp>
-#include <migraphx/gpu/mul.hpp>
-#include <migraphx/gpu/max.hpp>
-#include <migraphx/gpu/min.hpp>
-#include <migraphx/gpu/batch_norm_inference.hpp>
-#include <migraphx/gpu/pooling.hpp>
#include <migraphx/gpu/gemm.hpp>
-#include <migraphx/gpu/concat.hpp>
-#include <migraphx/gpu/pad.hpp>
-#include <migraphx/gpu/gather.hpp>
-#include <migraphx/gpu/lrn.hpp>
-#include <migraphx/gpu/convert.hpp>
-#include <migraphx/gpu/clip.hpp>
-#include <migraphx/gpu/round.hpp>
-#include <migraphx/gpu/ceil.hpp>
-#include <migraphx/gpu/floor.hpp>
-#include <migraphx/gpu/rsqrt.hpp>
-#include <migraphx/gpu/sqrt.hpp>
-#include <migraphx/gpu/reduce_max.hpp>
-#include <migraphx/gpu/reduce_mean.hpp>
-#include <migraphx/gpu/reduce_min.hpp>
-#include <migraphx/gpu/reduce_prod.hpp>
-#include <migraphx/gpu/reduce_sum.hpp>
-#include <migraphx/gpu/pow.hpp>
-#include <migraphx/gpu/sqdiff.hpp>
+#include <migraphx/gpu/hip.hpp>
#include <migraphx/gpu/int8_conv_pack.hpp>
-#include <migraphx/gpu/prelu.hpp>
-#include <migraphx/gpu/recip.hpp>
-#include <migraphx/gpu/rnn_variable_seq_lens.hpp>
+#include <migraphx/gpu/leaky_relu.hpp>
+#include <migraphx/gpu/lrn.hpp>
+#include <migraphx/gpu/miopen.hpp>
+#include <migraphx/gpu/pooling.hpp>
+#include <migraphx/gpu/quant_convolution.hpp>
+#include <migraphx/gpu/rocblas.hpp>
+#include <migraphx/iterator_for.hpp>
#include <utility>
#include <functional>
#include <algorithm>
@@ -136,61 +95,59 @@ struct miopen_apply
        add_miopen_extend_op<miopen_leaky_relu, op::leaky_relu>("leaky_relu", make_leaky_relu);
        add_miopen_extend_op<miopen_elu, op::elu>("elu", make_elu);

-        add_generic_op<hip_add>("add");
-        add_generic_op<hip_sub>("sub");
-        add_generic_op<hip_exp>("exp");
-        add_generic_op<hip_erf>("erf");
-        add_generic_op<hip_log>("log");
-        add_generic_op<hip_sin>("sin");
-        add_generic_op<hip_cos>("cos");
-        add_generic_op<hip_tan>("tan");
-        add_generic_op<hip_sinh>("sinh");
-        add_generic_op<hip_cosh>("cosh");
-        add_generic_op<hip_tanh>("tanh");
-        add_generic_op<hip_asin>("asin");
-        add_generic_op<hip_acos>("acos");
-        add_generic_op<hip_atan>("atan");
-        add_generic_op<hip_asinh>("asinh");
-        add_generic_op<hip_acosh>("acosh");
-        add_generic_op<hip_atanh>("atanh");
-        add_generic_op<hip_sqrt>("sqrt");
-        add_generic_op<hip_mul>("mul");
-        add_generic_op<hip_div>("div");
-        add_generic_op<hip_max>("max");
-        add_generic_op<hip_min>("min");
-        add_generic_op<hip_rsqrt>("rsqrt");
-        add_generic_op<hip_round>("round");
-        add_generic_op<hip_pow>("pow");
-        add_generic_op<hip_sqdiff>("sqdiff");
-        add_generic_op<hip_relu>("relu");
-        add_generic_op<hip_prelu>("prelu");
-        add_generic_op<hip_sign>("sign");
-        add_generic_op<hip_sigmoid>("sigmoid");
-        add_generic_op<hip_ceil>("ceil");
-        add_generic_op<hip_floor>("floor");
-        add_generic_op<hip_recip>("recip");
-        add_generic_op<miopen_contiguous>("contiguous");
-        add_extend_op<hip_concat, op::concat>("concat");
-        add_extend_op<hip_softmax, op::softmax>("softmax");
-        add_extend_op<hip_logsoftmax, op::logsoftmax>("logsoftmax");
-        add_extend_op<hip_argmax, op::argmax>("argmax");
-        add_extend_op<hip_argmin, op::argmin>("argmin");
-        add_extend_op<hip_gather, op::gather>("gather");
-        add_extend_op<hip_pad, op::pad>("pad");
-        add_extend_op<hip_convert, op::convert>("convert");
-        add_extend_op<hip_clip, op::clip>("clip");
-        add_extend_op<hip_reduce_max, op::reduce_max>("reduce_max");
-        add_extend_op<hip_reduce_mean, op::reduce_mean>("reduce_mean");
-        add_extend_op<hip_reduce_min, op::reduce_min>("reduce_min");
-        add_extend_op<hip_reduce_prod, op::reduce_prod>("reduce_prod");
-        add_extend_op<hip_reduce_sum, op::reduce_sum>("reduce_sum");
-        add_extend_op<hip_rnn_var_sl_shift_output, op::rnn_var_sl_shift_output>(
-            "rnn_var_sl_shift_output");
-        add_extend_op<hip_rnn_var_sl_shift_sequence, op::rnn_var_sl_shift_sequence>(
-            "rnn_var_sl_shift_sequence");
-        add_extend_op<hip_rnn_var_sl_last_output, op::rnn_var_sl_last_output>(
-            "rnn_var_sl_last_output");
+        add_generic_op("acos");
+        add_generic_op("acosh");
+        add_generic_op("add");
+        add_generic_op("asin");
+        add_generic_op("asinh");
+        add_generic_op("atan");
+        add_generic_op("atanh");
+        add_generic_op("ceil");
+        add_generic_op("contiguous");
+        add_generic_op("cos");
+        add_generic_op("cosh");
+        add_generic_op("div");
+        add_generic_op("erf");
+        add_generic_op("exp");
+        add_generic_op("floor");
+        add_generic_op("log");
+        add_generic_op("max");
+        add_generic_op("min");
+        add_generic_op("mul");
+        add_generic_op("pow");
+        add_generic_op("prelu");
+        add_generic_op("recip");
+        add_generic_op("relu");
+        add_generic_op("round");
+        add_generic_op("rsqrt");
+        add_generic_op("sigmoid");
+        add_generic_op("sign");
+        add_generic_op("sin");
+        add_generic_op("sinh");
+        add_generic_op("sqdiff");
+        add_generic_op("sqrt");
+        add_generic_op("sub");
+        add_generic_op("tan");
+        add_generic_op("tanh");
+        add_extend_op("argmax");
+        add_extend_op("argmin");
+        add_extend_op("clip");
+        add_extend_op("concat");
+        add_extend_op("convert");
+        add_extend_op("gather");
+        add_extend_op("logsoftmax");
+        add_extend_op("pad");
+        add_extend_op("reduce_max");
+        add_extend_op("reduce_mean");
+        add_extend_op("reduce_min");
+        add_extend_op("reduce_prod");
+        add_extend_op("reduce_sum");
+        add_extend_op("rnn_var_sl_last_output");
+        add_extend_op("rnn_var_sl_shift_output");
+        add_extend_op("rnn_var_sl_shift_sequence");
+        add_extend_op("softmax");

        add_gemm_op<op::dot>("dot");
        add_gemm_op<op::quant_dot>("quant_dot");
        add_lrn_op();
@@ -379,28 +336,30 @@ struct miopen_apply
        });
    }

-    template <class T>
-    void add_generic_op(std::string name)
+    void add_generic_op(const std::string& name) { add_generic_op(name, "gpu::" + name); }
+
+    void add_generic_op(const std::string& op_name, const std::string& gpu_name)
    {
-        apply_map.emplace(name, [=](instruction_ref ins) {
+        apply_map.emplace(op_name, [=](instruction_ref ins) {
            auto output = insert_allocation(ins, ins->get_shape());
            std::vector<instruction_ref> refs = ins->inputs();
            refs.push_back(output);
-            return prog->replace_instruction(ins, T{}, refs);
+            return prog->replace_instruction(ins, make_op(gpu_name), refs);
        });
    }

-    template <class T, class Op>
-    void add_extend_op(std::string name)
+    void add_extend_op(const std::string& name) { add_extend_op(name, "gpu::" + name); }
+
+    void add_extend_op(const std::string& op_name, const std::string& gpu_name)
    {
-        apply_map.emplace(name, [=](instruction_ref ins) {
+        apply_map.emplace(op_name, [=](instruction_ref ins) {
-            auto&& op = any_cast<Op>(ins->get_operator());
+            auto&& op = ins->get_operator();
            auto output = insert_allocation(ins, ins->get_shape());
            std::vector<instruction_ref> refs = ins->inputs();
            refs.push_back(output);
-            return prog->replace_instruction(ins, T{op}, refs);
+            return prog->replace_instruction(ins, make_op(gpu_name, op.to_value()), refs);
        });
    }
@@ -472,7 +431,8 @@ struct miopen_apply
            std::vector<float> zeros(s.elements(), 0.0f);
            auto l0 = prog->add_literal(literal(s, zeros));
            auto output = insert_allocation(ins, s);
-            return prog->replace_instruction(ins, hip_sub{}, l0, ins->inputs().front(), output);
+            return prog->replace_instruction(
+                ins, make_op("gpu::sub"), l0, ins->inputs().front(), output);
        });
    }
};
...
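The refactored helpers above drop their template parameters: the one-argument overloads of add_generic_op/add_extend_op infer the GPU operator name by prepending "gpu::" and build it through make_op, with add_extend_op carrying the source operator's attributes across via to_value(). A hedged, self-contained sketch of that lookup (make_gpu_op is a hypothetical helper, not part of this diff, and it assumes the GPU target library with its registered gpu:: operators is linked in):

#include <migraphx/make_op.hpp>
#include <migraphx/operation.hpp>
#include <string>

// Hypothetical helper mirroring the "gpu::" prefix inference used in lowering.
migraphx::operation make_gpu_op(const std::string& name)
{
    // e.g. "add" -> "gpu::add", "pad" -> "gpu::pad"
    return migraphx::make_op("gpu::" + name);
}

int main()
{
    auto gpu_add = make_gpu_op("add");
    return gpu_add.name() == "gpu::add" ? 0 : 1;
}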
@@ -209,6 +209,14 @@ std::vector<value>& get_array_impl(const std::shared_ptr<value_base_impl>& x)
    return *a;
}

+std::vector<value>& get_array_throw(const std::shared_ptr<value_base_impl>& x)
+{
+    auto* a = if_array_impl(x);
+    if(a == nullptr)
+        MIGRAPHX_THROW("Expected an array or object");
+    return *a;
+}
+
value* find_impl(const std::shared_ptr<value_base_impl>& x, const std::string& key)
{
    auto* a = if_array_impl(x);
@@ -302,15 +310,29 @@ const value& value::at(const std::string& pkey) const
{
    auto* r = find(pkey);
    if(r == nullptr)
-        MIGRAPHX_THROW("Not an object");
+        MIGRAPHX_THROW("Not an object for field: " + pkey);
    if(r == end())
-        MIGRAPHX_THROW("Key not found");
+        MIGRAPHX_THROW("Key not found: " + pkey);
    return *r;
}

value& value::operator[](std::size_t i) { return *(begin() + i); }
const value& value::operator[](std::size_t i) const { return *(begin() + i); }
value& value::operator[](const std::string& pkey) { return *emplace(pkey, nullptr).first; }

+void value::clear() { get_array_throw(x).clear(); }
+
+void value::resize(std::size_t n)
+{
+    if(not is_array())
+        MIGRAPHX_THROW("Expected an array.");
+    get_array_impl(x).resize(n);
+}
+
+void value::resize(std::size_t n, const value& v)
+{
+    if(not is_array())
+        MIGRAPHX_THROW("Expected an array.");
+    get_array_impl(x).resize(n, v);
+}
+
std::pair<value*, bool> value::insert(const value& v)
{
    if(v.key.empty())
...
@@ -44,6 +44,9 @@ function(add_test_command NAME EXE)
    #     --args $<TARGET_FILE:${EXE}> ${ARGN})
    set(TEST_DIR ${CMAKE_CURRENT_BINARY_DIR}/gdb/test_${NAME})
    file(MAKE_DIRECTORY ${TEST_DIR})
+    if (NOT EXISTS ${TEST_DIR})
+        message(FATAL_ERROR "Failed to create test directory: ${TEST_DIR}")
+    endif()
    file(GENERATE OUTPUT "${TEST_DIR}/run.cmake"
        CONTENT "
# Remove previous core dump
...
@@ -2014,7 +2014,7 @@ TEST_CASE(transpose_gather_test)
    auto prog = optimize_onnx("transpose_gather_test.onnx");

-    EXPECT(p == prog);
+    EXPECT(p.sort() == prog.sort());
}

TEST_CASE(undefined_test)
...
#include <migraphx/register_op.hpp>
#include <migraphx/operation.hpp>
+#include <migraphx/make_op.hpp>
+#include <migraphx/op/convolution.hpp>
#include <sstream>
#include <string>
#include "test.hpp"
@@ -13,6 +15,35 @@ TEST_CASE(load_op)
    }
}

+TEST_CASE(make_op)
+{
+    for(const auto& name : migraphx::get_operators())
+    {
+        auto op = migraphx::load_op(name);
+        CHECK(op == migraphx::make_op(name));
+    }
+}
+
+TEST_CASE(make_op_from_value1)
+{
+    migraphx::operation x = migraphx::make_op(
+        "convolution", {{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {2, 2}}});
+    migraphx::operation y = migraphx::op::convolution{{1, 1}, {2, 2}, {2, 2}};
+    EXPECT(x == y);
+}
+
+TEST_CASE(make_op_from_value2)
+{
+    migraphx::operation x = migraphx::make_op("convolution", {{"padding", {1, 1}}});
+    migraphx::operation y = migraphx::op::convolution{{1, 1}};
+    EXPECT(x == y);
+}
+
+TEST_CASE(make_op_invalid_key)
+{
+    EXPECT(test::throws([] { migraphx::make_op("convolution", {{"paddings", {1, 1}}}); }));
+}
+
TEST_CASE(ops)
{
    auto names = migraphx::get_operators();
...
@@ -69,4 +69,35 @@ TEST_CASE(serialize_reflectable_type)
    EXPECT(v2 != v3);
}

+TEST_CASE(serialize_empty_array)
+{
+    std::vector<std::size_t> ints = {};
+    migraphx::value v = migraphx::to_value(ints);
+    EXPECT(v.is_array());
+    EXPECT(v.empty());
+    v.push_back(1);
+    EXPECT(v.size() == 1);
+    EXPECT(v.front().to<int>() == 1);
+}
+
+struct empty_struct
+{
+    template <class Self, class F>
+    static auto reflect(Self&, F)
+    {
+        return migraphx::pack();
+    }
+};
+
+TEST_CASE(serialize_empty_struct)
+{
+    empty_struct es{};
+    migraphx::value v = migraphx::to_value(es);
+    EXPECT(v.is_object());
+    EXPECT(v.empty());
+    v["a"] = 1;
+    EXPECT(v.size() == 1);
+    EXPECT(v.at("a").to<int>() == 1);
+}
+
int main(int argc, const char* argv[]) { test::run(argc, argv); }
@@ -563,4 +563,132 @@ TEST_CASE(print)
    EXPECT(ss.str() == "{1, {one: 1, two: 2}, {1, 2}, null}");
}

+TEST_CASE(value_clear)
+{
+    migraphx::value values = {1, 2, 3};
+    EXPECT(values.is_array());
+    EXPECT(values.size() == 3);
+    values.clear();
+    EXPECT(values.empty());
+    values.push_back(3);
+    EXPECT(values.size() == 1);
+    EXPECT(values.at(0).to<int>() == 3);
+}
+
+TEST_CASE(value_clear_non_array)
+{
+    migraphx::value values = 1.0;
+    EXPECT(test::throws([&] { values.clear(); }));
+}
+
+TEST_CASE(value_clear_object)
+{
+    migraphx::value values = {{"a", 1}, {"b", 2}};
+    EXPECT(values.is_object());
+    EXPECT(values.size() == 2);
+    values.clear();
+    EXPECT(values.empty());
+    values["c"] = 3;
+    EXPECT(values.size() == 1);
+    EXPECT(values.at("c").to<int>() == 3);
+}
+
+TEST_CASE(value_clear_empty_array)
+{
+    migraphx::value values = migraphx::value::array{};
+    EXPECT(values.empty());
+    values.clear();
+    EXPECT(values.empty());
+}
+
+TEST_CASE(value_clear_empty_object)
+{
+    migraphx::value values = migraphx::value::object{};
+    EXPECT(values.empty());
+    values.clear();
+    EXPECT(values.empty());
+}
+
+TEST_CASE(value_resize)
+{
+    migraphx::value values = {1, 2, 3};
+    EXPECT(values.is_array());
+    EXPECT(values.size() == 3);
+    values.resize(5);
+    EXPECT(values.size() == 5);
+    EXPECT(values.at(3).is_null());
+    EXPECT(values.at(4).is_null());
+}
+
+TEST_CASE(value_resize_with_value)
+{
+    migraphx::value values = {1, 2, 3};
+    EXPECT(values.is_array());
+    EXPECT(values.size() == 3);
+    values.resize(5, 7);
+    EXPECT(values.size() == 5);
+    EXPECT(values.at(3).to<int>() == 7);
+    EXPECT(values.at(4).to<int>() == 7);
+}
+
+TEST_CASE(value_resize_empty_array)
+{
+    migraphx::value values = migraphx::value::array{};
+    EXPECT(values.is_array());
+    EXPECT(values.empty());
+    values.resize(3);
+    EXPECT(values.size() == 3);
+    EXPECT(values.at(0).is_null());
+    EXPECT(values.at(1).is_null());
+    EXPECT(values.at(2).is_null());
+}
+
+TEST_CASE(value_resize_object)
+{
+    migraphx::value values = migraphx::value::object{};
+    EXPECT(values.is_object());
+    EXPECT(test::throws([&] { values.resize(4); }));
+}
+
+TEST_CASE(value_resize_n_object)
+{
+    migraphx::value values = migraphx::value::object{};
+    EXPECT(values.is_object());
+    EXPECT(test::throws([&] { values.resize(4, ""); }));
+}
+
+TEST_CASE(value_assign_construct_from_vector)
+{
+    std::vector<int> v = {1, 2, 3};
+    migraphx::value values = v;
+    EXPECT(values.to_vector<int>() == v);
+}
+
+TEST_CASE(value_construct_from_vector)
+{
+    std::vector<int> v = {1, 2, 3};
+    migraphx::value values(v);
+    EXPECT(values.to_vector<int>() == v);
+}
+
+TEST_CASE(value_assign_from_vector)
+{
+    std::vector<int> v = {1, 2, 3};
+    migraphx::value values{};
+    values = v;
+    EXPECT(values.to_vector<int>() == v);
+}
+
+TEST_CASE(value_init_from_vector)
+{
+    std::vector<int> v = {1, 2, 3};
+    migraphx::value values = {{"a", v}};
+    EXPECT(values.at("a").to_vector<int>() == v);
+}
+
int main(int argc, const char* argv[]) { test::run(argc, argv); }