Commit e2eb6036 authored by Paul's avatar Paul
Browse files

Merge

parents 298c93d5 1e0bbd78
...@@ -15,8 +15,6 @@ target_link_libraries(migraphx_ref migraphx Threads::Threads) ...@@ -15,8 +15,6 @@ target_link_libraries(migraphx_ref migraphx Threads::Threads)
target_include_directories(migraphx_ref PRIVATE ${BLAZE_INCLUDE}) target_include_directories(migraphx_ref PRIVATE ${BLAZE_INCLUDE})
target_compile_definitions(migraphx_ref PRIVATE -DBLAZE_USE_CPP_THREADS) target_compile_definitions(migraphx_ref PRIVATE -DBLAZE_USE_CPP_THREADS)
target_link_libraries(migraphx_all_targets INTERFACE migraphx_ref)
rocm_install_targets( rocm_install_targets(
TARGETS migraphx_ref TARGETS migraphx_ref
INCLUDE INCLUDE
......
...@@ -819,9 +819,9 @@ struct ref_apply ...@@ -819,9 +819,9 @@ struct ref_apply
void apply_pooling(instruction_ref ins) const void apply_pooling(instruction_ref ins) const
{ {
auto&& op = any_cast<op::pooling>(ins->get_operator()); auto&& op = any_cast<op::pooling>(ins->get_operator());
if(op.mode == "max") if(op.mode == op::pooling_mode::max)
mod->replace_instruction(ins, ref_pooling<max_pool>{op}, ins->inputs()); mod->replace_instruction(ins, ref_pooling<max_pool>{op}, ins->inputs());
else if(op.mode == "average") else if(op.mode == op::pooling_mode::average)
mod->replace_instruction(ins, ref_pooling<avg_pool>{op}, ins->inputs()); mod->replace_instruction(ins, ref_pooling<avg_pool>{op}, ins->inputs());
} }
}; };
......
...@@ -19,7 +19,7 @@ target_compile_options(tf-proto PRIVATE -w) ...@@ -19,7 +19,7 @@ target_compile_options(tf-proto PRIVATE -w)
target_link_libraries(tf-proto PRIVATE ${PROTOBUF_LIBRARY}) target_link_libraries(tf-proto PRIVATE ${PROTOBUF_LIBRARY})
set_target_properties(tf-proto PROPERTIES POSITION_INDEPENDENT_CODE On) set_target_properties(tf-proto PROPERTIES POSITION_INDEPENDENT_CODE On)
file(GLOB TF_SRCS *.cpp) file(GLOB TF_SRCS ${CONFIGURE_DEPENDS} *.cpp)
add_library(migraphx_tf ${TF_SRCS}) add_library(migraphx_tf ${TF_SRCS})
target_include_directories(migraphx_tf PRIVATE include) target_include_directories(migraphx_tf PRIVATE include)
set_target_properties(migraphx_tf PROPERTIES EXPORT_NAME tf) set_target_properties(migraphx_tf PROPERTIES EXPORT_NAME tf)
......
...@@ -19,7 +19,12 @@ struct parse_pooling : op_parser<parse_pooling> ...@@ -19,7 +19,12 @@ struct parse_pooling : op_parser<parse_pooling>
tf_parser::node_info info, tf_parser::node_info info,
std::vector<instruction_ref> args) const std::vector<instruction_ref> args) const
{ {
op::pooling op{starts_with(opd.tf_name, "Max") ? "max" : "average"}; if(!starts_with(opd.tf_name, "Max") && !starts_with(opd.tf_name, "Av"))
{
MIGRAPHX_THROW("tf pooling mode must be Max or Average");
}
op::pooling op{starts_with(opd.tf_name, "Max") ? op::pooling_mode::max
: op::pooling_mode::average};
if(contains(info.attributes, "strides")) if(contains(info.attributes, "strides"))
{ {
......
...@@ -499,8 +499,7 @@ literal tf_parser::parse_tensor(const tensorflow::TensorProto& t) const ...@@ -499,8 +499,7 @@ literal tf_parser::parse_tensor(const tensorflow::TensorProto& t) const
return create_literal(shape::int64_type, dims, get_data_vals(t.int64_val(), shape_size)); return create_literal(shape::int64_type, dims, get_data_vals(t.int64_val(), shape_size));
case tensorflow::DataType::DT_BOOL: case tensorflow::DataType::DT_BOOL:
return create_literal(shape::int32_type, dims, get_data_vals(t.bool_val(), shape_size)); return create_literal(shape::int32_type, dims, get_data_vals(t.bool_val(), shape_size));
case tensorflow::DataType::DT_HALF: case tensorflow::DataType::DT_HALF: {
{
std::vector<int> data_int32 = get_data_vals(t.half_val(), shape_size); std::vector<int> data_int32 = get_data_vals(t.half_val(), shape_size);
std::vector<uint16_t> data_uint16(data_int32.begin(), data_int32.end()); std::vector<uint16_t> data_uint16(data_int32.begin(), data_int32.end());
std::vector<half> data_half; std::vector<half> data_half;
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include <migraphx/errors.hpp> #include <migraphx/errors.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/optional.hpp>
#include <unordered_map> #include <unordered_map>
#include <utility> #include <utility>
...@@ -138,6 +139,7 @@ value::value(const std::string& pkey, const value& rhs) ...@@ -138,6 +139,7 @@ value::value(const std::string& pkey, const value& rhs)
{ {
} }
value::value(const std::string& pkey, const char* i) : value(pkey, std::string(i)) {}
value::value(const char* i) : value(std::string(i)) {} value::value(const char* i) : value(std::string(i)) {}
#define MIGRAPHX_VALUE_GENERATE_DEFINE_METHODS(vt, cpp_type) \ #define MIGRAPHX_VALUE_GENERATE_DEFINE_METHODS(vt, cpp_type) \
...@@ -161,6 +163,12 @@ value::value(const char* i) : value(std::string(i)) {} ...@@ -161,6 +163,12 @@ value::value(const char* i) : value(std::string(i)) {}
const cpp_type* value::if_##vt() const { return x ? x->if_##vt() : nullptr; } const cpp_type* value::if_##vt() const { return x ? x->if_##vt() : nullptr; }
MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_DEFINE_METHODS) MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_DEFINE_METHODS)
value& value::operator=(const char* c)
{
*this = std::string{c};
return *this;
}
value& value::operator=(std::nullptr_t) value& value::operator=(std::nullptr_t)
{ {
x = nullptr; x = nullptr;
...@@ -410,25 +418,12 @@ value value::with_key(const std::string& pkey) const ...@@ -410,25 +418,12 @@ value value::with_key(const std::string& pkey) const
return result; return result;
} }
template <class F, class T, class U, class Common = typename std::common_type<T, U>::type> template <class T>
auto compare_common_impl( const T& compare_decay(const T& x)
rank<1>, F f, const std::string& keyx, const T& x, const std::string& keyy, const U& y)
{
return f(std::forward_as_tuple(keyx, Common(x)), std::forward_as_tuple(keyy, Common(y)));
}
template <class F>
auto compare_common_impl(
rank<1>, F f, const std::string& keyx, std::nullptr_t, const std::string& keyy, std::nullptr_t)
{
return f(std::forward_as_tuple(keyx, 0), std::forward_as_tuple(keyy, 0));
}
template <class F, class T, class U>
auto compare_common_impl(rank<0>, F, const std::string&, const T&, const std::string&, const U&)
{ {
return false; return x;
} }
int compare_decay(std::nullptr_t) { return 0; }
template <class F> template <class F>
bool compare(const value& x, const value& y, F f) bool compare(const value& x, const value& y, F f)
...@@ -436,7 +431,11 @@ bool compare(const value& x, const value& y, F f) ...@@ -436,7 +431,11 @@ bool compare(const value& x, const value& y, F f)
bool result = false; bool result = false;
x.visit_value([&](auto&& a) { x.visit_value([&](auto&& a) {
y.visit_value([&](auto&& b) { y.visit_value([&](auto&& b) {
result = compare_common_impl(rank<1>{}, f, x.get_key(), a, y.get_key(), b); if constexpr(std::is_same<decltype(a), decltype(b)>{})
result = f(std::forward_as_tuple(x.get_key(), compare_decay(a)),
std::forward_as_tuple(y.get_key(), compare_decay(b)));
else
assert(false); // NOLINT
}); });
}); });
return result; return result;
...@@ -455,11 +454,16 @@ bool operator==(const value& x, const value& y) ...@@ -455,11 +454,16 @@ bool operator==(const value& x, const value& y)
return false; return false;
return compare(x, y, std::equal_to<>{}); return compare(x, y, std::equal_to<>{});
} }
bool operator!=(const value& x, const value& y) { return !(x == y); } bool operator!=(const value& x, const value& y) { return not(x == y); }
bool operator<(const value& x, const value& y) { return compare(x, y, std::less<>{}); } bool operator<(const value& x, const value& y)
bool operator<=(const value& x, const value& y) { return x == y or x < y; } {
if(x.get_type() != y.get_type())
return x.get_type() < y.get_type();
return compare(x, y, std::less<>{});
}
bool operator<=(const value& x, const value& y) { return not(x > y); }
bool operator>(const value& x, const value& y) { return y < x; } bool operator>(const value& x, const value& y) { return y < x; }
bool operator>=(const value& x, const value& y) { return x == y or x > y; } bool operator>=(const value& x, const value& y) { return not(x < y); }
void print_value(std::ostream& os, std::nullptr_t) { os << "null"; } void print_value(std::ostream& os, std::nullptr_t) { os << "null"; }
......
...@@ -90,7 +90,7 @@ function(add_test_executable TEST_NAME) ...@@ -90,7 +90,7 @@ function(add_test_executable TEST_NAME)
target_include_directories(${TEST_NAME} PUBLIC include) target_include_directories(${TEST_NAME} PUBLIC include)
endfunction(add_test_executable) endfunction(add_test_executable)
file(GLOB TESTS *.cpp) file(GLOB TESTS ${CONFIGURE_DEPENDS} *.cpp)
foreach(TEST ${TESTS}) foreach(TEST ${TESTS})
get_filename_component(BASE_NAME ${TEST} NAME_WE) get_filename_component(BASE_NAME ${TEST} NAME_WE)
...@@ -100,7 +100,7 @@ endforeach() ...@@ -100,7 +100,7 @@ endforeach()
if(MIGRAPHX_ENABLE_GPU) if(MIGRAPHX_ENABLE_GPU)
# gpu tests # gpu tests
file(GLOB GPU_TESTS gpu/*.cpp) file(GLOB GPU_TESTS ${CONFIGURE_DEPENDS} gpu/*.cpp)
foreach(TEST ${GPU_TESTS}) foreach(TEST ${GPU_TESTS})
get_filename_component(BASE_NAME ${TEST} NAME_WE) get_filename_component(BASE_NAME ${TEST} NAME_WE)
...@@ -120,7 +120,7 @@ file (GLOB ONNX_TESTS ${TEST_ONNX_DIR}/*.cpp) ...@@ -120,7 +120,7 @@ file (GLOB ONNX_TESTS ${TEST_ONNX_DIR}/*.cpp)
foreach(ONNX_TEST ${ONNX_TESTS}) foreach(ONNX_TEST ${ONNX_TESTS})
get_filename_component(BASE_NAME ${ONNX_TEST} NAME_WE) get_filename_component(BASE_NAME ${ONNX_TEST} NAME_WE)
set(TEST_NAME test_${BASE_NAME}) set(TEST_NAME test_${BASE_NAME})
add_executable(${TEST_NAME} ${TES_ONNX_DIR}/${ONNX_TEST}) add_executable(${TEST_NAME} ${ONNX_TEST})
rocm_clang_tidy_check(${TEST_NAME}) rocm_clang_tidy_check(${TEST_NAME})
target_link_libraries(${TEST_NAME} migraphx_onnx migraphx_ref) target_link_libraries(${TEST_NAME} migraphx_onnx migraphx_ref)
target_include_directories(${TEST_NAME} PUBLIC include) target_include_directories(${TEST_NAME} PUBLIC include)
...@@ -160,7 +160,7 @@ function(test_header NAME HEADER) ...@@ -160,7 +160,7 @@ function(test_header NAME HEADER)
endfunction() endfunction()
function(test_headers PREFIX) function(test_headers PREFIX)
file(GLOB HEADERS ${ARGN}) file(GLOB HEADERS ${CONFIGURE_DEPENDS} ${ARGN})
foreach(HEADER ${HEADERS}) foreach(HEADER ${HEADERS})
file(RELATIVE_PATH HEADER_REL ${CMAKE_SOURCE_DIR} ${HEADER}) file(RELATIVE_PATH HEADER_REL ${CMAKE_SOURCE_DIR} ${HEADER})
......
#include <migraphx/any_ptr.hpp>
#include <test.hpp>
TEST_CASE(test_int_id)
{
int i = 1;
migraphx::any_ptr p = &i;
EXPECT(p.get<int*>() == &i);
EXPECT(p.get(migraphx::get_type_name(i)) == &i);
EXPECT(p.unsafe_get() == &i);
EXPECT(test::throws([&] { p.get<float*>(); }));
EXPECT(test::throws([&] { p.get(migraphx::get_type_name(&i)); }));
}
TEST_CASE(test_int_name)
{
int i = 1;
void* vp = &i;
migraphx::any_ptr p{vp, migraphx::get_type_name(i)};
EXPECT(p.get<int*>() == &i);
EXPECT(p.get(migraphx::get_type_name(i)) == &i);
EXPECT(p.unsafe_get() == &i);
EXPECT(test::throws([&] { p.get<float*>(); }));
EXPECT(test::throws([&] { p.get(migraphx::get_type_name(&i)); }));
EXPECT(test::throws([&] { p.get(migraphx::get_type_name(float{})); }));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -10,7 +10,11 @@ function(add_api_test TEST_NAME TEST_SRC TEST_DIR) ...@@ -10,7 +10,11 @@ function(add_api_test TEST_NAME TEST_SRC TEST_DIR)
add_dependencies(check ${NAME}) add_dependencies(check ${NAME})
endfunction() endfunction()
add_api_test(assign test_assign.cpp ${TEST_ONNX_DIR})
add_api_test(custom_op test_custom_op.cpp ${TEST_ONNX_DIR})
add_api_test(compile_options test_compile_options.cpp ${TEST_ONNX_DIR}) add_api_test(compile_options test_compile_options.cpp ${TEST_ONNX_DIR})
add_api_test(lookup test_lookup.cpp ${TEST_ONNX_DIR})
add_api_test(module_construct test_module_construct.cpp ${TEST_ONNX_DIR})
add_api_test(ref test_cpu.cpp ${TEST_ONNX_DIR}) add_api_test(ref test_cpu.cpp ${TEST_ONNX_DIR})
add_api_test(save_load test_save_load.cpp ${TEST_ONNX_DIR}) add_api_test(save_load test_save_load.cpp ${TEST_ONNX_DIR})
add_api_test(op test_op_construct.cpp ${TEST_ONNX_DIR}) add_api_test(op test_op_construct.cpp ${TEST_ONNX_DIR})
......
#include <migraphx/migraphx.h>
#include <migraphx/migraphx.hpp>
#include "test.hpp"
TEST_CASE(shape_assign)
{
auto s1_cpp = migraphx::shape{migraphx_shape_float_type, {1, 3}};
std::vector<size_t> lens{2, 3};
// handle ptr is const, workaround to construct shape using C API
migraphx_shape_t s2;
migraphx_shape_create(&s2, migraphx_shape_float_type, lens.data(), lens.size());
auto s2_cpp = migraphx::shape(s2, migraphx::own{});
CHECK(bool{s1_cpp != s2_cpp});
// use C++ API for assignment
s1_cpp.assign_to_handle(s2);
CHECK(bool{s1_cpp == s2_cpp});
auto s3_cpp = migraphx::shape{migraphx_shape_float_type, lens};
// use C API for assignment
migraphx_shape_assign_to(s2, s3_cpp.get_handle_ptr());
CHECK(bool{s2_cpp == s3_cpp});
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include <migraphx/migraphx.h>
#include <migraphx/migraphx.hpp>
#include "test.hpp"
struct simple_custom_op final : migraphx::experimental_custom_op_base
{
virtual std::string name() const override { return "simple_custom_op"; }
virtual migraphx::shape compute_shape(migraphx::shapes inputs) const override
{
return inputs.front();
}
};
TEST_CASE(register_custom_op)
{
simple_custom_op simple_op;
migraphx::register_experimental_custom_op(simple_op);
auto op = migraphx::operation("simple_custom_op");
EXPECT(op.name() == "simple_custom_op");
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -25,6 +25,23 @@ TEST_CASE(load_and_run) ...@@ -25,6 +25,23 @@ TEST_CASE(load_and_run)
CHECK(bool{shapes_before.front() == outputs.front().get_shape()}); CHECK(bool{shapes_before.front() == outputs.front().get_shape()});
} }
TEST_CASE(load_and_run_ctx)
{
auto p = migraphx::parse_onnx("conv_relu_maxpool_test.onnx");
migraphx::compile_options options;
options.set_offload_copy();
p.compile(migraphx::target("gpu"), options);
migraphx::program_parameters pp;
auto param_shapes = p.get_parameter_shapes();
for(auto&& name : param_shapes.names())
{
pp.add(name, migraphx::argument::generate(param_shapes[name]));
}
auto ctx = p.experimental_get_context();
p.eval(pp);
ctx.finish();
}
TEST_CASE(if_pl_test) TEST_CASE(if_pl_test)
{ {
auto run_prog = [&](auto cond) { auto run_prog = [&](auto cond) {
......
#include <migraphx/migraphx.hpp>
#include "test.hpp"
template <class T>
std::false_type has_handle(migraphx::rank<0>, T)
{
return {};
}
template <class T>
auto has_handle(migraphx::rank<1>, T*) -> decltype(migraphx::as_handle<T>{}, std::true_type{})
{
return {};
}
TEST_CASE(shape)
{
static_assert(std::is_same<migraphx::as_handle<migraphx_shape>, migraphx::shape>{}, "Failed");
static_assert(std::is_same<migraphx::as_handle<migraphx_shape_t>, migraphx::shape>{}, "Failed");
static_assert(std::is_same<migraphx::as_handle<const_migraphx_shape_t>, migraphx::shape>{},
"Failed");
}
TEST_CASE(non_handle)
{
int i = 0;
EXPECT(bool{has_handle(migraphx::rank<1>{}, migraphx_shape_t{})});
EXPECT(bool{not has_handle(migraphx::rank<1>{}, &i)});
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include <numeric>
#include <migraphx/migraphx.h>
#include <migraphx/migraphx.hpp>
#include "test.hpp"
TEST_CASE(add_op)
{
migraphx::program p;
migraphx::module m = p.get_main_module();
migraphx::shape param_shape{migraphx_shape_float_type, {3, 3}};
auto x = m.add_parameter("x", param_shape);
auto y = m.add_parameter("y", param_shape);
auto add_op = migraphx::operation("add");
auto r = m.add_instruction(add_op, {x, y});
m.add_return({r});
// run on ref target
p.compile(migraphx::target("ref"));
migraphx::program_parameters pp;
std::vector<float> x_data(9, 1);
std::vector<float> y_data(9, -1);
pp.add("x", migraphx::argument(param_shape, x_data.data()));
pp.add("y", migraphx::argument(param_shape, y_data.data()));
auto outputs = p.eval(pp);
auto output = outputs[0];
std::vector<float> expected(9, 0);
CHECK(bool(output == migraphx::argument(param_shape, expected.data())));
}
TEST_CASE(if_then_else_op)
{
migraphx::shape param_shape{migraphx_shape_float_type, {3, 3}};
migraphx::shape cond_s{migraphx_shape_bool_type};
auto create_program = [&]() {
migraphx::program p;
auto mm = p.get_main_module();
auto cond = mm.add_parameter("cond", cond_s);
auto x = mm.add_parameter("x", param_shape);
auto y = mm.add_parameter("y", param_shape);
auto then_mod = p.create_module("If_0_if");
auto x_identity = then_mod.add_instruction(migraphx::operation("identity"), {x});
then_mod.add_return({x_identity});
auto else_mod = p.create_module("If_0_else");
auto y_identity = else_mod.add_instruction(migraphx::operation("identity"), {y});
else_mod.add_return({y_identity});
auto if_ins = mm.add_instruction(migraphx::operation("if"), {cond}, {then_mod, else_mod});
auto get_tuple_op = migraphx::operation("get_tuple_elem", "{index: 0}");
auto ret = mm.add_instruction(get_tuple_op, {if_ins});
mm.add_return({ret});
return p;
};
std::vector<float> x_data(9, 1);
std::vector<float> y_data(9, -1);
auto x_arg = migraphx::argument(param_shape, x_data.data());
auto y_arg = migraphx::argument(param_shape, y_data.data());
auto run_prog = [&](bool cond) {
auto p = create_program();
p.compile(migraphx::target("ref"));
auto outputs =
p.eval({{"cond", migraphx::argument(cond_s, &cond)}, {"x", x_arg}, {"y", y_arg}});
return outputs;
};
// then branch
auto then_res = run_prog(true);
CHECK(bool{then_res[0] == x_arg});
// else branch
auto else_res = run_prog(false);
CHECK(bool{else_res[0] == y_arg});
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
TEST_CASE(load_save_default) TEST_CASE(load_save_default)
{ {
std::string filename = "migraphx_api_load_save.dat"; std::string filename = "migraphx_api_load_save.mxr";
auto p1 = migraphx::parse_onnx("conv_relu_maxpool_test.onnx"); auto p1 = migraphx::parse_onnx("conv_relu_maxpool_test.onnx");
auto s1 = p1.get_output_shapes(); auto s1 = p1.get_output_shapes();
......
...@@ -135,4 +135,52 @@ TEST_CASE(two_transpose_gather) ...@@ -135,4 +135,52 @@ TEST_CASE(two_transpose_gather)
EXPECT(m1 == m2); EXPECT(m1 == m2);
} }
TEST_CASE(standard_reshape)
{
migraphx::module m1;
{
auto data = m1.add_parameter("2x2", {migraphx::shape::float_type, {2, 3, 4, 5}});
auto add = m1.add_instruction(migraphx::make_op("add"), data, data);
auto r = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 1, 12, 5}}}), add);
m1.add_return({r});
}
run_pass(m1);
migraphx::module m2;
{
auto data = m2.add_parameter("2x2", {migraphx::shape::float_type, {2, 3, 4, 5}});
auto add = m2.add_instruction(migraphx::make_op("add"), data, data);
auto ca = m2.add_instruction(migraphx::make_op("contiguous"), add);
auto r = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 1, 12, 5}}}), ca);
m2.add_return({r});
}
EXPECT(m1 == m2);
}
TEST_CASE(dead_instruction)
{
migraphx::module m1;
{
auto data = m1.add_parameter("2x2", {migraphx::shape::float_type, {2, 3, 4, 5}});
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 1, 3}}}), data);
auto r = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 1, 3}}}),
data);
m1.add_return({r});
}
run_pass(m1);
migraphx::module m2;
{
auto data = m2.add_parameter("2x2", {migraphx::shape::float_type, {2, 3, 4, 5}});
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 1, 3}}}), data);
auto r = m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 1, 3}}}),
data);
auto cr = m2.add_instruction(migraphx::make_op("contiguous"), r);
m2.add_return({cr});
}
EXPECT(m1 == m2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -6,6 +6,12 @@ ...@@ -6,6 +6,12 @@
#include <test.hpp> #include <test.hpp>
void run_pass(migraphx::program& p)
{
migraphx::run_passes(
p, {migraphx::eliminate_common_subexpression{}, migraphx::dead_code_elimination{}});
}
void run_pass(migraphx::module& m) void run_pass(migraphx::module& m)
{ {
migraphx::run_passes( migraphx::run_passes(
...@@ -142,4 +148,59 @@ TEST_CASE(cse_test_literal) ...@@ -142,4 +148,59 @@ TEST_CASE(cse_test_literal)
EXPECT(m1 == m2); EXPECT(m1 == m2);
} }
TEST_CASE(cse_test_submodule)
{
migraphx::shape si{migraphx::shape::int64_type};
migraphx::shape s{migraphx::shape::int64_type, {1}};
migraphx::shape sc{migraphx::shape::bool_type};
auto create_program = [&](bool remove_literal = false) {
migraphx::program p;
std::vector<bool> vc = {true};
std::vector<int64_t> vd = {3};
auto* mm = p.get_main_module();
auto in_cond = mm->add_parameter("ccond", sc);
auto in_val = mm->add_parameter("val", s);
auto b0 = mm->add_literal(migraphx::literal(sc, vc));
auto b1 = b0;
if(not(remove_literal))
b1 = mm->add_literal(migraphx::literal(sc, vc));
auto* body1 = p.create_module("loop_module1");
body1->add_parameter("#loop_module_in_1", sc);
auto in_v1 = body1->add_parameter("#loop_module_in_2", s);
auto l1 = body1->add_literal(migraphx::literal(si, vd));
auto ad1 = body1->add_instruction(migraphx::make_op("add"), l1, l1);
auto val1 = body1->add_instruction(migraphx::make_op("add"), in_v1, ad1);
auto cond1 = body1->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::bool_type}}), b0);
auto cond2 = body1->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::bool_type}}), b1);
body1->add_return({cond1, cond2, val1, val1});
auto* body2 = p.create_module("loop_module2");
body2->add_parameter("#loop_module_in_1", sc);
auto in_v2 = body2->add_parameter("#loop_module_in_2", s);
auto l2 = body2->add_literal(migraphx::literal(si, vd));
auto ad2 = body2->add_instruction(migraphx::make_op("add"), l2, l2);
auto val2 = body2->add_instruction(migraphx::make_op("add"), in_v2, ad2);
auto cond3 = body2->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::bool_type}}), b1);
body2->add_return({cond3, val2, val2});
auto loop1 = mm->add_instruction(
migraphx::make_op("loop", {{"max_iterations", 1}}), {in_cond, in_val}, {body1});
auto loop2 = mm->add_instruction(
migraphx::make_op("loop", {{"max_iterations", 1}}), {in_cond, in_val}, {body2});
mm->add_return({loop1, loop2});
return p;
};
auto p = create_program();
run_pass(p);
EXPECT(p == create_program(true));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -55,7 +55,8 @@ TEST_CASE(rewrite_pad) ...@@ -55,7 +55,8 @@ TEST_CASE(rewrite_pad)
auto l0 = create_im2col(padded_img, channels, m); auto l0 = create_im2col(padded_img, channels, m);
auto l1 = create_conv(padded_img, channels, m); auto l1 = create_conv(padded_img, channels, m);
auto l2 = m.add_instruction(migraphx::make_op("pooling", {{"mode", "max"}}), padded_img); auto l2 = m.add_instruction(
migraphx::make_op("pooling", {{"mode", migraphx::op::pooling_mode::max}}), padded_img);
m.add_instruction(migraphx::make_op("identity"), l0, l1, l2); m.add_instruction(migraphx::make_op("identity"), l0, l1, l2);
auto s0 = l0->get_shape(); auto s0 = l0->get_shape();
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
#include <migraphx/context.hpp> #include <migraphx/context.hpp>
#include "test.hpp" #include "test.hpp"
TEST_CASE(gpu_context) TEST_CASE(gpu_context_serialize)
{ {
migraphx::context ctx = migraphx::gpu::context{0, 3}; migraphx::context ctx = migraphx::gpu::context{0, 3};
...@@ -25,4 +25,10 @@ TEST_CASE(gpu_context) ...@@ -25,4 +25,10 @@ TEST_CASE(gpu_context)
EXPECT(v == v1); EXPECT(v == v1);
} }
TEST_CASE(context_queue)
{
migraphx::context ctx = migraphx::gpu::context{0, 3};
EXPECT(ctx.get_queue().get<hipStream_t>() != nullptr);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include <test.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/op/multibroadcast.hpp>
#include <migraphx/op/pow.hpp>
#include <migraphx/op/tanh.hpp>
#include <migraphx/gpu/target.hpp>
#include <migraphx/instruction.hpp>
migraphx::program create_gelu()
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<float> data0 = {0.044715};
std::vector<float> data1 = {0.797885};
std::vector<float> data2 = {3};
std::vector<float> data3 = {0.5};
migraphx::shape s0{migraphx::shape::float_type, {1}};
std::vector<size_t> x_dims{1, 1, 5};
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, x_dims});
auto const_val = mm->add_literal(migraphx::literal{s0, data0});
auto sqrt_2_pi = mm->add_literal(migraphx::literal{s0, data1});
auto three_val = mm->add_literal(migraphx::literal{s0, data2});
auto half_val = mm->add_literal(migraphx::literal{s0, data3});
auto mbcast_3 = mm->add_instruction(migraphx::op::multibroadcast{x_dims}, three_val);
auto pow_op = mm->add_instruction(migraphx::op::pow{}, x, mbcast_3);
auto mbcast_const = mm->add_instruction(migraphx::op::multibroadcast{x_dims}, const_val);
auto mul_const = mm->add_instruction(migraphx::op::mul{}, mbcast_const, pow_op);
auto add_x = mm->add_instruction(migraphx::op::add{}, x, mul_const);
auto mbcast_sqrt_2_pi = mm->add_instruction(migraphx::op::multibroadcast{x_dims}, sqrt_2_pi);
auto mul_add_x = mm->add_instruction(migraphx::op::mul{}, mbcast_sqrt_2_pi, add_x);
auto tanh_op = mm->add_instruction(migraphx::op::tanh{}, mul_add_x);
auto mbcast_half = mm->add_instruction(migraphx::op::multibroadcast{x_dims}, half_val);
auto mul_half = mm->add_instruction(migraphx::op::mul{}, mbcast_half, tanh_op);
auto add_mul_half = mm->add_instruction(migraphx::op::add{}, mul_half, mbcast_half);
auto mul_x = mm->add_instruction(migraphx::op::mul{}, x, add_mul_half);
mm->add_return({mul_x});
return p;
}
TEST_CASE(enable_fast_gelu)
{
migraphx::program p = create_gelu();
p.compile(migraphx::gpu::target{});
CHECK(any_of(*p.get_main_module(), [&](auto&& i) { return i.name() == "gpu::gelu"; }));
}
TEST_CASE(disable_fast_gelu)
{
migraphx::program p = create_gelu();
migraphx::compile_options options;
options.fast_math = false;
p.compile(migraphx::gpu::target{}, options);
CHECK(any_of(*p.get_main_module(), [&](auto&& i) { return i.name() == "gpu::gelu_new"; }));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment