Commit 82bd8e2e authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'lstm_operator' into seq2seq_example

parents 98cd353f 1ea5faef
...@@ -27,6 +27,13 @@ void eliminate_contiguous::apply(program& p) const ...@@ -27,6 +27,13 @@ void eliminate_contiguous::apply(program& p) const
{ {
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(p))
{ {
// skip the reshape operator for now, since there is a bug
// for the transpose followed by a reshape
if(ins->name() == "reshape")
{
continue;
}
// Make a copy so we can modify it while we iterate // Make a copy so we can modify it while we iterate
auto args = ins->inputs(); auto args = ins->inputs();
for(auto arg : ins->inputs()) for(auto arg : ins->inputs())
......
...@@ -52,7 +52,6 @@ struct rewrite_rnn ...@@ -52,7 +52,6 @@ struct rewrite_rnn
program& prog, program& prog,
instruction_ref ins, instruction_ref ins,
std::vector<instruction_ref> inputs, std::vector<instruction_ref> inputs,
int input_forget,
const operation& actv_func1, const operation& actv_func1,
const operation& actv_func2, const operation& actv_func2,
const operation& actv_func3) const; const operation& actv_func3) const;
......
...@@ -786,15 +786,20 @@ struct onnx_parser ...@@ -786,15 +786,20 @@ struct onnx_parser
{ {
auto names = attributes.at("activations").strings(); auto names = attributes.at("activations").strings();
vec_names.clear(); vec_names.clear();
for_each(names.begin(), names.end(), [&](auto& fn) { vec_names.push_back(fn); }); vec_names.resize(names.size());
std::transform(
names.begin(), names.end(), vec_names.begin(), [](auto& str) { return str; });
} }
for_each(vec_names.begin(), vec_names.end(), [&](auto& fn) { if(std::any_of(vec_names.begin(), vec_names.end(), [&](auto& name) {
if(map_actv_funcs.count(fn) == 0) return (map_actv_funcs.count(name) == 0);
{ }))
MIGRAPHX_THROW("RNN: activation function " + std::string(fn) + " not supported"); {
} auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
}); return (map_actv_funcs.count(name) == 0);
});
MIGRAPHX_THROW("RNN: activation function " + std::string(*name_it) + " not supported");
}
// bidirectional case should have two activation functions. // bidirectional case should have two activation functions.
// one is for forward, and the other is for reverse. // one is for forward, and the other is for reverse.
...@@ -915,12 +920,15 @@ struct onnx_parser ...@@ -915,12 +920,15 @@ struct onnx_parser
} }
} }
for_each(vec_names.begin(), vec_names.end(), [&](auto& name) { if(std::any_of(vec_names.begin(), vec_names.end(), [&](auto& name) {
if(map_actv_funcs.count(name) == 0) return (map_actv_funcs.count(name) == 0);
{ }))
MIGRAPHX_THROW("GRU: activation function " + std::string(name) + " not supported"); {
} auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
}); return (map_actv_funcs.count(name) == 0);
});
MIGRAPHX_THROW("GRU: activation function " + std::string(*name_it) + " not supported");
}
std::vector<operation> vec_actv_funcs(vec_names.size()); std::vector<operation> vec_actv_funcs(vec_names.size());
std::transform(vec_names.begin(), vec_names.end(), vec_actv_funcs.begin(), [&](auto& name) { std::transform(vec_names.begin(), vec_names.end(), vec_actv_funcs.begin(), [&](auto& name) {
...@@ -1085,12 +1093,15 @@ struct onnx_parser ...@@ -1085,12 +1093,15 @@ struct onnx_parser
} }
} }
for_each(vec_names.begin(), vec_names.end(), [&](auto& name) { if(std::any_of(vec_names.begin(), vec_names.end(), [&](auto& name) {
if(map_actv_funcs.count(name) == 0) return (map_actv_funcs.count(name) == 0);
{ }))
MIGRAPHX_THROW("LSTM: activation function " + std::string(name) + " not supported"); {
} auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
}); return (map_actv_funcs.count(name) == 0);
});
MIGRAPHX_THROW("LSTM: activation function " + std::string(*name_it) + " not supported");
}
std::vector<operation> vec_actv_funcs(vec_names.size()); std::vector<operation> vec_actv_funcs(vec_names.size());
std::transform(vec_names.begin(), vec_names.end(), vec_actv_funcs.begin(), [&](auto& name) { std::transform(vec_names.begin(), vec_names.end(), vec_actv_funcs.begin(), [&](auto& name) {
...@@ -1342,7 +1353,15 @@ struct onnx_parser ...@@ -1342,7 +1353,15 @@ struct onnx_parser
case onnx::TensorProto::BOOL: case onnx::TensorProto::BOOL:
return literal{{shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()}; return literal{{shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
case onnx::TensorProto::FLOAT16: case onnx::TensorProto::FLOAT16:
return literal{{shape::half_type, dims}, t.float_data().begin(), t.float_data().end()}; {
std::vector<uint16_t> data_uint16(t.int32_data().begin(), t.int32_data().end());
std::vector<half> data_half;
std::transform(data_uint16.begin(),
data_uint16.end(),
std::back_inserter(data_half),
[](uint16_t raw_val) { return *reinterpret_cast<half*>(&raw_val); });
return literal{{shape::half_type, dims}, data_half.begin(), data_half.end()};
}
case onnx::TensorProto::DOUBLE: case onnx::TensorProto::DOUBLE:
return literal{ return literal{
{shape::double_type, dims}, t.double_data().begin(), t.double_data().end()}; {shape::double_type, dims}, t.double_data().begin(), t.double_data().end()};
......
...@@ -348,13 +348,17 @@ argument generic_eval(const program& p, ...@@ -348,13 +348,17 @@ argument generic_eval(const program& p,
} }
else if(ins->name() == "@param") else if(ins->name() == "@param")
{ {
results.emplace(ins, trace(ins, [&] { results.emplace(
auto param_name = ins, trace(ins, [&] {
any_cast<builtin::param>(ins->get_operator()).parameter; auto param_name = any_cast<builtin::param>(ins->get_operator()).parameter;
if(not contains(params, param_name)) if(not contains(params, param_name))
MIGRAPHX_THROW("Parameter not found: " + param_name); MIGRAPHX_THROW("Parameter not found: " + param_name);
return params.at(param_name); auto param = params.at(param_name);
})); if(param.get_shape() != ins->get_shape())
MIGRAPHX_THROW("Incorrect shape {" + to_string(param.get_shape()) +
"} for parameter: " + param_name);
return param;
}));
} }
else if(ins->name() == "@outline") else if(ins->name() == "@outline")
{ {
......
...@@ -738,25 +738,19 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const ...@@ -738,25 +738,19 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
} }
// process weight of the peephole // process weight of the peephole
instruction_ref pph_forward{}; instruction_ref pph_forward = prog.end();
instruction_ref pph_reverse{}; instruction_ref pph_reverse = prog.end();
if(args.size() == 8 && args[7]->name() != "undefined") if(args.size() == 8 && args[7]->name() != "undefined")
{ {
pph_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[7]); pph_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[7]);
pph_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[7]); pph_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[7]);
} }
else
{
pph_forward = prog.add_literal(migraphx::literal{pph_shape, pph_data});
pph_reverse = prog.add_literal(migraphx::literal{pph_shape, pph_data});
}
auto ret_forward = lstm_cell( auto ret_forward = lstm_cell(
true, true,
prog, prog,
ins, ins,
{args[0], w_forward, r_forward, bias_forward, ih_forward, ic_forward, pph_forward}, {args[0], w_forward, r_forward, bias_forward, ih_forward, ic_forward, pph_forward},
lstm_op.input_forget,
actv_funcs.at(0), actv_funcs.at(0),
actv_funcs.at(1), actv_funcs.at(1),
actv_funcs.at(2)); actv_funcs.at(2));
...@@ -766,7 +760,6 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const ...@@ -766,7 +760,6 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
prog, prog,
ins, ins,
{args[0], w_reverse, r_reverse, bias_reverse, ih_reverse, ic_reverse, pph_reverse}, {args[0], w_reverse, r_reverse, bias_reverse, ih_reverse, ic_reverse, pph_reverse},
lstm_op.input_forget,
actv_funcs.at(3), actv_funcs.at(3),
actv_funcs.at(4), actv_funcs.at(4),
actv_funcs.at(5)); actv_funcs.at(5));
...@@ -830,21 +823,16 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const ...@@ -830,21 +823,16 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
} }
// process weight of the peephole // process weight of the peephole
instruction_ref pph{}; instruction_ref pph = prog.end();
if(args.size() == 8 && args[7]->name() != "undefined") if(args.size() == 8 && args[7]->name() != "undefined")
{ {
pph = args[7]; pph = args[7];
} }
else
{
pph = prog.add_literal(migraphx::literal{pph_shape, pph_data});
}
auto ret = lstm_cell(is_forward, auto ret = lstm_cell(is_forward,
prog, prog,
ins, ins,
{args[0], w, r, bias, ih, ic, pph}, {args[0], w, r, bias, ih, ic, pph},
lstm_op.input_forget,
actv_funcs.at(0), actv_funcs.at(0),
actv_funcs.at(1), actv_funcs.at(1),
actv_funcs.at(2)); actv_funcs.at(2));
...@@ -901,7 +889,6 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -901,7 +889,6 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
program& prog, program& prog,
instruction_ref ins, instruction_ref ins,
std::vector<instruction_ref> inputs, std::vector<instruction_ref> inputs,
int input_forget,
const operation& actv_func1, const operation& actv_func1,
const operation& actv_func2, const operation& actv_func2,
const operation& actv_func3) const const operation& actv_func3) const
...@@ -991,18 +978,25 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -991,18 +978,25 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
} }
// peep hole // peep hole
auto spph = prog.insert_instruction(ins, op::squeeze{{0}}, pph); instruction_ref pphi_brcst{};
auto pphi = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, spph); instruction_ref ppho_brcst{};
auto pphi_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, pphi); instruction_ref pphf_brcst{};
pphi_brcst = prog.insert_instruction(ins, op::contiguous{}, pphi_brcst);
auto ppho = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, spph); if(pph != prog.end())
auto ppho_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, ppho); {
ppho_brcst = prog.insert_instruction(ins, op::contiguous{}, ppho_brcst); auto spph = prog.insert_instruction(ins, op::squeeze{{0}}, pph);
auto pphi = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, spph);
auto pphf = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, spph); pphi_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, pphi);
auto pphf_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, pphf); pphi_brcst = prog.insert_instruction(ins, op::contiguous{}, pphi_brcst);
pphf_brcst = prog.insert_instruction(ins, op::contiguous{}, pphf_brcst);
auto ppho = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, spph);
ppho_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, ppho);
ppho_brcst = prog.insert_instruction(ins, op::contiguous{}, ppho_brcst);
auto pphf = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, spph);
pphf_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, pphf);
pphf_brcst = prog.insert_instruction(ins, op::contiguous{}, pphf_brcst);
}
for(long i = 0; i < seq_len; ++i) for(long i = 0; i < seq_len; ++i)
{ {
...@@ -1013,9 +1007,12 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -1013,9 +1007,12 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
// equation it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi) // equation it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
auto xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_wi); auto xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_wi);
auto ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_ri); auto ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_ri);
auto pphi_ct = prog.insert_instruction(ins, op::mul{}, pphi_brcst, sic);
auto it_before_actv = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri); auto it_before_actv = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri);
it_before_actv = prog.insert_instruction(ins, op::add{}, it_before_actv, pphi_ct); if(pph != prog.end())
{
auto pphi_ct = prog.insert_instruction(ins, op::mul{}, pphi_brcst, sic);
it_before_actv = prog.insert_instruction(ins, op::add{}, it_before_actv, pphi_ct);
}
if(bias != prog.end()) if(bias != prog.end())
{ {
it_before_actv = prog.insert_instruction(ins, op::add{}, it_before_actv, bi_brcst); it_before_actv = prog.insert_instruction(ins, op::add{}, it_before_actv, bi_brcst);
...@@ -1025,9 +1022,12 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -1025,9 +1022,12 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
// equation ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf) // equation ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
auto xt_wf = prog.insert_instruction(ins, op::dot{}, xt, tran_wf); auto xt_wf = prog.insert_instruction(ins, op::dot{}, xt, tran_wf);
auto ht_rf = prog.insert_instruction(ins, op::dot{}, sih, tran_rf); auto ht_rf = prog.insert_instruction(ins, op::dot{}, sih, tran_rf);
auto pphf_ct = prog.insert_instruction(ins, op::mul{}, pphf_brcst, sic);
auto ft_before_actv = prog.insert_instruction(ins, op::add{}, xt_wf, ht_rf); auto ft_before_actv = prog.insert_instruction(ins, op::add{}, xt_wf, ht_rf);
ft_before_actv = prog.insert_instruction(ins, op::add{}, ft_before_actv, pphf_ct); if(pph != prog.end())
{
auto pphf_ct = prog.insert_instruction(ins, op::mul{}, pphf_brcst, sic);
ft_before_actv = prog.insert_instruction(ins, op::add{}, ft_before_actv, pphf_ct);
}
if(bias != prog.end()) if(bias != prog.end())
{ {
ft_before_actv = prog.insert_instruction(ins, op::add{}, ft_before_actv, bf_brcst); ft_before_actv = prog.insert_instruction(ins, op::add{}, ft_before_actv, bf_brcst);
...@@ -1053,9 +1053,12 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -1053,9 +1053,12 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
// ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo) // ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
auto xt_wo = prog.insert_instruction(ins, op::dot{}, xt, tran_wo); auto xt_wo = prog.insert_instruction(ins, op::dot{}, xt, tran_wo);
auto ht_ro = prog.insert_instruction(ins, op::dot{}, sih, tran_ro); auto ht_ro = prog.insert_instruction(ins, op::dot{}, sih, tran_ro);
auto ppho_cellt = prog.insert_instruction(ins, op::mul{}, ppho_brcst, cellt);
auto ot_before_actv = prog.insert_instruction(ins, op::add{}, xt_wo, ht_ro); auto ot_before_actv = prog.insert_instruction(ins, op::add{}, xt_wo, ht_ro);
ot_before_actv = prog.insert_instruction(ins, op::add{}, ot_before_actv, ppho_cellt); if(pph != prog.end())
{
auto ppho_cellt = prog.insert_instruction(ins, op::mul{}, ppho_brcst, cellt);
ot_before_actv = prog.insert_instruction(ins, op::add{}, ot_before_actv, ppho_cellt);
}
if(bias != prog.end()) if(bias != prog.end())
{ {
ot_before_actv = prog.insert_instruction(ins, op::add{}, ot_before_actv, bo_brcst); ot_before_actv = prog.insert_instruction(ins, op::add{}, ot_before_actv, bo_brcst);
......
...@@ -113,13 +113,18 @@ endif() ...@@ -113,13 +113,18 @@ endif()
# Onnx test # Onnx test
set(TEST_ONNX_DIR ${CMAKE_CURRENT_SOURCE_DIR}/onnx) set(TEST_ONNX_DIR ${CMAKE_CURRENT_SOURCE_DIR}/onnx)
add_executable(test_onnx ${TEST_ONNX_DIR}/onnx_test.cpp) file (GLOB ONNX_TESTS ${TEST_ONNX_DIR}/*.cpp)
rocm_clang_tidy_check(test_onnx) foreach(ONNX_TEST ${ONNX_TESTS})
target_link_libraries(test_onnx migraphx_onnx) get_filename_component(BASE_NAME ${ONNX_TEST} NAME_WE)
target_include_directories(test_onnx PUBLIC include) set(TEST_NAME test_${BASE_NAME})
add_test(NAME test_onnx COMMAND $<TARGET_FILE:test_onnx> WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/onnx) add_executable(${TEST_NAME} ${TES_ONNX_DIR}/${ONNX_TEST})
add_dependencies(tests test_onnx) rocm_clang_tidy_check(${TEST_NAME})
add_dependencies(check test_onnx) target_link_libraries(${TEST_NAME} migraphx_onnx)
target_include_directories(${TEST_NAME} PUBLIC include)
add_test(NAME ${TEST_NAME} COMMAND $<TARGET_FILE:${TEST_NAME}> WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/onnx)
add_dependencies(tests ${TEST_NAME})
add_dependencies(check ${TEST_NAME})
endforeach()
if(MIGRAPHX_ENABLE_PYTHON) if(MIGRAPHX_ENABLE_PYTHON)
add_subdirectory(py) add_subdirectory(py)
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <migraphx/verify.hpp> #include <migraphx/verify.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/onnx.hpp>
#include "test.hpp" #include "test.hpp"
#include <migraphx/half.hpp>
float sigmoid(float x) { return 1 / (1 + expf(-x)); } float sigmoid(float x) { return 1 / (1 + expf(-x)); }
...@@ -1375,4 +1376,22 @@ TEST_CASE(pad_test) ...@@ -1375,4 +1376,22 @@ TEST_CASE(pad_test)
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
TEST_CASE(fp16_test)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::half_type, {1}};
migraphx::half a{1.5};
migraphx::half b{2.5};
migraphx::half c{4.0};
auto l0 = p.add_literal(migraphx::literal{s, {a}});
auto l1 = p.add_literal(migraphx::literal{s, {b}});
p.add_instruction(migraphx::op::add{}, l0, l1);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<migraphx::half> results_vector(1);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<migraphx::half> gold{c};
EXPECT(migraphx::verify_range(results_vector, gold));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
This diff is collapsed.
...@@ -128,7 +128,7 @@ TEST_CASE(print_test) ...@@ -128,7 +128,7 @@ TEST_CASE(print_test)
{ {
migraphx::program p; migraphx::program p;
auto x = p.add_parameter("x", {migraphx::shape::int64_type}); auto x = p.add_parameter("x", {migraphx::shape::int32_type});
auto two = p.add_literal(2); auto two = p.add_literal(2);
p.add_instruction(sum_op{}, x, two); p.add_instruction(sum_op{}, x, two);
...@@ -142,8 +142,8 @@ TEST_CASE(param_test) ...@@ -142,8 +142,8 @@ TEST_CASE(param_test)
{ {
migraphx::program p; migraphx::program p;
auto x = p.add_parameter("x", {migraphx::shape::int64_type}); auto x = p.add_parameter("x", {migraphx::shape::int32_type});
auto y = p.add_parameter("y", {migraphx::shape::int64_type}); auto y = p.add_parameter("y", {migraphx::shape::int32_type});
p.add_instruction(sum_op{}, x, y); p.add_instruction(sum_op{}, x, y);
auto result = p.eval( auto result = p.eval(
...@@ -156,8 +156,8 @@ TEST_CASE(param_error_test) ...@@ -156,8 +156,8 @@ TEST_CASE(param_error_test)
{ {
migraphx::program p; migraphx::program p;
auto x = p.add_parameter("x", {migraphx::shape::int64_type}); auto x = p.add_parameter("x", {migraphx::shape::int32_type});
auto y = p.add_parameter("y", {migraphx::shape::int64_type}); auto y = p.add_parameter("y", {migraphx::shape::int32_type});
p.add_instruction(sum_op{}, x, y); p.add_instruction(sum_op{}, x, y);
EXPECT(test::throws<migraphx::exception>( EXPECT(test::throws<migraphx::exception>(
...@@ -167,6 +167,22 @@ TEST_CASE(param_error_test) ...@@ -167,6 +167,22 @@ TEST_CASE(param_error_test)
"Parameter not found: y")); "Parameter not found: y"));
} }
TEST_CASE(param_shape_error_test)
{
migraphx::program p;
auto x = p.add_parameter("x", {migraphx::shape::int32_type, {1, 2}});
auto y = p.add_parameter("y", {migraphx::shape::int32_type, {1, 2}});
p.add_instruction(sum_op{}, x, y);
EXPECT(test::throws<migraphx::exception>(
[&] {
p.eval({{"x", migraphx::literal{1}.get_argument()},
{"y", migraphx::literal{2}.get_argument()}});
},
"Incorrect shape"));
}
TEST_CASE(replace_test) TEST_CASE(replace_test)
{ {
migraphx::program p; migraphx::program p;
......
This diff is collapsed.
add-fp16-example:m

0
12"Add test-add-fp16* 
*|B0* 
*B1Z
0


Z
1


b
2


B
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment