Commit 82bd8e2e authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'lstm_operator' into seq2seq_example

parents 98cd353f 1ea5faef
...@@ -27,6 +27,13 @@ void eliminate_contiguous::apply(program& p) const ...@@ -27,6 +27,13 @@ void eliminate_contiguous::apply(program& p) const
{ {
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(p))
{ {
// skip the reshape operator for now, since there is a bug
// for the transpose followed by a reshape
if(ins->name() == "reshape")
{
continue;
}
// Make a copy so we can modify it while we iterate // Make a copy so we can modify it while we iterate
auto args = ins->inputs(); auto args = ins->inputs();
for(auto arg : ins->inputs()) for(auto arg : ins->inputs())
......
...@@ -52,7 +52,6 @@ struct rewrite_rnn ...@@ -52,7 +52,6 @@ struct rewrite_rnn
program& prog, program& prog,
instruction_ref ins, instruction_ref ins,
std::vector<instruction_ref> inputs, std::vector<instruction_ref> inputs,
int input_forget,
const operation& actv_func1, const operation& actv_func1,
const operation& actv_func2, const operation& actv_func2,
const operation& actv_func3) const; const operation& actv_func3) const;
......
...@@ -786,15 +786,20 @@ struct onnx_parser ...@@ -786,15 +786,20 @@ struct onnx_parser
{ {
auto names = attributes.at("activations").strings(); auto names = attributes.at("activations").strings();
vec_names.clear(); vec_names.clear();
for_each(names.begin(), names.end(), [&](auto& fn) { vec_names.push_back(fn); }); vec_names.resize(names.size());
std::transform(
names.begin(), names.end(), vec_names.begin(), [](auto& str) { return str; });
} }
for_each(vec_names.begin(), vec_names.end(), [&](auto& fn) { if(std::any_of(vec_names.begin(), vec_names.end(), [&](auto& name) {
if(map_actv_funcs.count(fn) == 0) return (map_actv_funcs.count(name) == 0);
{ }))
MIGRAPHX_THROW("RNN: activation function " + std::string(fn) + " not supported"); {
} auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
}); return (map_actv_funcs.count(name) == 0);
});
MIGRAPHX_THROW("RNN: activation function " + std::string(*name_it) + " not supported");
}
// bidirectional case should have two activation functions. // bidirectional case should have two activation functions.
// one is for forward, and the other is for reverse. // one is for forward, and the other is for reverse.
...@@ -915,12 +920,15 @@ struct onnx_parser ...@@ -915,12 +920,15 @@ struct onnx_parser
} }
} }
for_each(vec_names.begin(), vec_names.end(), [&](auto& name) { if(std::any_of(vec_names.begin(), vec_names.end(), [&](auto& name) {
if(map_actv_funcs.count(name) == 0) return (map_actv_funcs.count(name) == 0);
{ }))
MIGRAPHX_THROW("GRU: activation function " + std::string(name) + " not supported"); {
} auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
}); return (map_actv_funcs.count(name) == 0);
});
MIGRAPHX_THROW("GRU: activation function " + std::string(*name_it) + " not supported");
}
std::vector<operation> vec_actv_funcs(vec_names.size()); std::vector<operation> vec_actv_funcs(vec_names.size());
std::transform(vec_names.begin(), vec_names.end(), vec_actv_funcs.begin(), [&](auto& name) { std::transform(vec_names.begin(), vec_names.end(), vec_actv_funcs.begin(), [&](auto& name) {
...@@ -1085,12 +1093,15 @@ struct onnx_parser ...@@ -1085,12 +1093,15 @@ struct onnx_parser
} }
} }
for_each(vec_names.begin(), vec_names.end(), [&](auto& name) { if(std::any_of(vec_names.begin(), vec_names.end(), [&](auto& name) {
if(map_actv_funcs.count(name) == 0) return (map_actv_funcs.count(name) == 0);
{ }))
MIGRAPHX_THROW("LSTM: activation function " + std::string(name) + " not supported"); {
} auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
}); return (map_actv_funcs.count(name) == 0);
});
MIGRAPHX_THROW("LSTM: activation function " + std::string(*name_it) + " not supported");
}
std::vector<operation> vec_actv_funcs(vec_names.size()); std::vector<operation> vec_actv_funcs(vec_names.size());
std::transform(vec_names.begin(), vec_names.end(), vec_actv_funcs.begin(), [&](auto& name) { std::transform(vec_names.begin(), vec_names.end(), vec_actv_funcs.begin(), [&](auto& name) {
...@@ -1342,7 +1353,15 @@ struct onnx_parser ...@@ -1342,7 +1353,15 @@ struct onnx_parser
case onnx::TensorProto::BOOL: case onnx::TensorProto::BOOL:
return literal{{shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()}; return literal{{shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
case onnx::TensorProto::FLOAT16: case onnx::TensorProto::FLOAT16:
return literal{{shape::half_type, dims}, t.float_data().begin(), t.float_data().end()}; {
std::vector<uint16_t> data_uint16(t.int32_data().begin(), t.int32_data().end());
std::vector<half> data_half;
std::transform(data_uint16.begin(),
data_uint16.end(),
std::back_inserter(data_half),
[](uint16_t raw_val) { return *reinterpret_cast<half*>(&raw_val); });
return literal{{shape::half_type, dims}, data_half.begin(), data_half.end()};
}
case onnx::TensorProto::DOUBLE: case onnx::TensorProto::DOUBLE:
return literal{ return literal{
{shape::double_type, dims}, t.double_data().begin(), t.double_data().end()}; {shape::double_type, dims}, t.double_data().begin(), t.double_data().end()};
......
...@@ -348,13 +348,17 @@ argument generic_eval(const program& p, ...@@ -348,13 +348,17 @@ argument generic_eval(const program& p,
} }
else if(ins->name() == "@param") else if(ins->name() == "@param")
{ {
results.emplace(ins, trace(ins, [&] { results.emplace(
auto param_name = ins, trace(ins, [&] {
any_cast<builtin::param>(ins->get_operator()).parameter; auto param_name = any_cast<builtin::param>(ins->get_operator()).parameter;
if(not contains(params, param_name)) if(not contains(params, param_name))
MIGRAPHX_THROW("Parameter not found: " + param_name); MIGRAPHX_THROW("Parameter not found: " + param_name);
return params.at(param_name); auto param = params.at(param_name);
})); if(param.get_shape() != ins->get_shape())
MIGRAPHX_THROW("Incorrect shape {" + to_string(param.get_shape()) +
"} for parameter: " + param_name);
return param;
}));
} }
else if(ins->name() == "@outline") else if(ins->name() == "@outline")
{ {
......
...@@ -738,25 +738,19 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const ...@@ -738,25 +738,19 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
} }
// process weight of the peephole // process weight of the peephole
instruction_ref pph_forward{}; instruction_ref pph_forward = prog.end();
instruction_ref pph_reverse{}; instruction_ref pph_reverse = prog.end();
if(args.size() == 8 && args[7]->name() != "undefined") if(args.size() == 8 && args[7]->name() != "undefined")
{ {
pph_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[7]); pph_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[7]);
pph_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[7]); pph_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[7]);
} }
else
{
pph_forward = prog.add_literal(migraphx::literal{pph_shape, pph_data});
pph_reverse = prog.add_literal(migraphx::literal{pph_shape, pph_data});
}
auto ret_forward = lstm_cell( auto ret_forward = lstm_cell(
true, true,
prog, prog,
ins, ins,
{args[0], w_forward, r_forward, bias_forward, ih_forward, ic_forward, pph_forward}, {args[0], w_forward, r_forward, bias_forward, ih_forward, ic_forward, pph_forward},
lstm_op.input_forget,
actv_funcs.at(0), actv_funcs.at(0),
actv_funcs.at(1), actv_funcs.at(1),
actv_funcs.at(2)); actv_funcs.at(2));
...@@ -766,7 +760,6 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const ...@@ -766,7 +760,6 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
prog, prog,
ins, ins,
{args[0], w_reverse, r_reverse, bias_reverse, ih_reverse, ic_reverse, pph_reverse}, {args[0], w_reverse, r_reverse, bias_reverse, ih_reverse, ic_reverse, pph_reverse},
lstm_op.input_forget,
actv_funcs.at(3), actv_funcs.at(3),
actv_funcs.at(4), actv_funcs.at(4),
actv_funcs.at(5)); actv_funcs.at(5));
...@@ -830,21 +823,16 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const ...@@ -830,21 +823,16 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
} }
// process weight of the peephole // process weight of the peephole
instruction_ref pph{}; instruction_ref pph = prog.end();
if(args.size() == 8 && args[7]->name() != "undefined") if(args.size() == 8 && args[7]->name() != "undefined")
{ {
pph = args[7]; pph = args[7];
} }
else
{
pph = prog.add_literal(migraphx::literal{pph_shape, pph_data});
}
auto ret = lstm_cell(is_forward, auto ret = lstm_cell(is_forward,
prog, prog,
ins, ins,
{args[0], w, r, bias, ih, ic, pph}, {args[0], w, r, bias, ih, ic, pph},
lstm_op.input_forget,
actv_funcs.at(0), actv_funcs.at(0),
actv_funcs.at(1), actv_funcs.at(1),
actv_funcs.at(2)); actv_funcs.at(2));
...@@ -901,7 +889,6 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -901,7 +889,6 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
program& prog, program& prog,
instruction_ref ins, instruction_ref ins,
std::vector<instruction_ref> inputs, std::vector<instruction_ref> inputs,
int input_forget,
const operation& actv_func1, const operation& actv_func1,
const operation& actv_func2, const operation& actv_func2,
const operation& actv_func3) const const operation& actv_func3) const
...@@ -991,18 +978,25 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -991,18 +978,25 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
} }
// peep hole // peep hole
auto spph = prog.insert_instruction(ins, op::squeeze{{0}}, pph); instruction_ref pphi_brcst{};
auto pphi = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, spph); instruction_ref ppho_brcst{};
auto pphi_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, pphi); instruction_ref pphf_brcst{};
pphi_brcst = prog.insert_instruction(ins, op::contiguous{}, pphi_brcst);
auto ppho = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, spph); if(pph != prog.end())
auto ppho_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, ppho); {
ppho_brcst = prog.insert_instruction(ins, op::contiguous{}, ppho_brcst); auto spph = prog.insert_instruction(ins, op::squeeze{{0}}, pph);
auto pphi = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, spph);
auto pphf = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, spph); pphi_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, pphi);
auto pphf_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, pphf); pphi_brcst = prog.insert_instruction(ins, op::contiguous{}, pphi_brcst);
pphf_brcst = prog.insert_instruction(ins, op::contiguous{}, pphf_brcst);
auto ppho = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, spph);
ppho_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, ppho);
ppho_brcst = prog.insert_instruction(ins, op::contiguous{}, ppho_brcst);
auto pphf = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, spph);
pphf_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, pphf);
pphf_brcst = prog.insert_instruction(ins, op::contiguous{}, pphf_brcst);
}
for(long i = 0; i < seq_len; ++i) for(long i = 0; i < seq_len; ++i)
{ {
...@@ -1013,9 +1007,12 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -1013,9 +1007,12 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
// equation it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi) // equation it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
auto xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_wi); auto xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_wi);
auto ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_ri); auto ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_ri);
auto pphi_ct = prog.insert_instruction(ins, op::mul{}, pphi_brcst, sic);
auto it_before_actv = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri); auto it_before_actv = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri);
it_before_actv = prog.insert_instruction(ins, op::add{}, it_before_actv, pphi_ct); if(pph != prog.end())
{
auto pphi_ct = prog.insert_instruction(ins, op::mul{}, pphi_brcst, sic);
it_before_actv = prog.insert_instruction(ins, op::add{}, it_before_actv, pphi_ct);
}
if(bias != prog.end()) if(bias != prog.end())
{ {
it_before_actv = prog.insert_instruction(ins, op::add{}, it_before_actv, bi_brcst); it_before_actv = prog.insert_instruction(ins, op::add{}, it_before_actv, bi_brcst);
...@@ -1025,9 +1022,12 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -1025,9 +1022,12 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
// equation ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf) // equation ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
auto xt_wf = prog.insert_instruction(ins, op::dot{}, xt, tran_wf); auto xt_wf = prog.insert_instruction(ins, op::dot{}, xt, tran_wf);
auto ht_rf = prog.insert_instruction(ins, op::dot{}, sih, tran_rf); auto ht_rf = prog.insert_instruction(ins, op::dot{}, sih, tran_rf);
auto pphf_ct = prog.insert_instruction(ins, op::mul{}, pphf_brcst, sic);
auto ft_before_actv = prog.insert_instruction(ins, op::add{}, xt_wf, ht_rf); auto ft_before_actv = prog.insert_instruction(ins, op::add{}, xt_wf, ht_rf);
ft_before_actv = prog.insert_instruction(ins, op::add{}, ft_before_actv, pphf_ct); if(pph != prog.end())
{
auto pphf_ct = prog.insert_instruction(ins, op::mul{}, pphf_brcst, sic);
ft_before_actv = prog.insert_instruction(ins, op::add{}, ft_before_actv, pphf_ct);
}
if(bias != prog.end()) if(bias != prog.end())
{ {
ft_before_actv = prog.insert_instruction(ins, op::add{}, ft_before_actv, bf_brcst); ft_before_actv = prog.insert_instruction(ins, op::add{}, ft_before_actv, bf_brcst);
...@@ -1053,9 +1053,12 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -1053,9 +1053,12 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
// ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo) // ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
auto xt_wo = prog.insert_instruction(ins, op::dot{}, xt, tran_wo); auto xt_wo = prog.insert_instruction(ins, op::dot{}, xt, tran_wo);
auto ht_ro = prog.insert_instruction(ins, op::dot{}, sih, tran_ro); auto ht_ro = prog.insert_instruction(ins, op::dot{}, sih, tran_ro);
auto ppho_cellt = prog.insert_instruction(ins, op::mul{}, ppho_brcst, cellt);
auto ot_before_actv = prog.insert_instruction(ins, op::add{}, xt_wo, ht_ro); auto ot_before_actv = prog.insert_instruction(ins, op::add{}, xt_wo, ht_ro);
ot_before_actv = prog.insert_instruction(ins, op::add{}, ot_before_actv, ppho_cellt); if(pph != prog.end())
{
auto ppho_cellt = prog.insert_instruction(ins, op::mul{}, ppho_brcst, cellt);
ot_before_actv = prog.insert_instruction(ins, op::add{}, ot_before_actv, ppho_cellt);
}
if(bias != prog.end()) if(bias != prog.end())
{ {
ot_before_actv = prog.insert_instruction(ins, op::add{}, ot_before_actv, bo_brcst); ot_before_actv = prog.insert_instruction(ins, op::add{}, ot_before_actv, bo_brcst);
......
...@@ -113,13 +113,18 @@ endif() ...@@ -113,13 +113,18 @@ endif()
# Onnx test # Onnx test
set(TEST_ONNX_DIR ${CMAKE_CURRENT_SOURCE_DIR}/onnx) set(TEST_ONNX_DIR ${CMAKE_CURRENT_SOURCE_DIR}/onnx)
add_executable(test_onnx ${TEST_ONNX_DIR}/onnx_test.cpp) file (GLOB ONNX_TESTS ${TEST_ONNX_DIR}/*.cpp)
rocm_clang_tidy_check(test_onnx) foreach(ONNX_TEST ${ONNX_TESTS})
target_link_libraries(test_onnx migraphx_onnx) get_filename_component(BASE_NAME ${ONNX_TEST} NAME_WE)
target_include_directories(test_onnx PUBLIC include) set(TEST_NAME test_${BASE_NAME})
add_test(NAME test_onnx COMMAND $<TARGET_FILE:test_onnx> WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/onnx) add_executable(${TEST_NAME} ${TES_ONNX_DIR}/${ONNX_TEST})
add_dependencies(tests test_onnx) rocm_clang_tidy_check(${TEST_NAME})
add_dependencies(check test_onnx) target_link_libraries(${TEST_NAME} migraphx_onnx)
target_include_directories(${TEST_NAME} PUBLIC include)
add_test(NAME ${TEST_NAME} COMMAND $<TARGET_FILE:${TEST_NAME}> WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/onnx)
add_dependencies(tests ${TEST_NAME})
add_dependencies(check ${TEST_NAME})
endforeach()
if(MIGRAPHX_ENABLE_PYTHON) if(MIGRAPHX_ENABLE_PYTHON)
add_subdirectory(py) add_subdirectory(py)
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <migraphx/verify.hpp> #include <migraphx/verify.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/onnx.hpp>
#include "test.hpp" #include "test.hpp"
#include <migraphx/half.hpp>
float sigmoid(float x) { return 1 / (1 + expf(-x)); } float sigmoid(float x) { return 1 / (1 + expf(-x)); }
...@@ -1375,4 +1376,22 @@ TEST_CASE(pad_test) ...@@ -1375,4 +1376,22 @@ TEST_CASE(pad_test)
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
TEST_CASE(fp16_test)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::half_type, {1}};
migraphx::half a{1.5};
migraphx::half b{2.5};
migraphx::half c{4.0};
auto l0 = p.add_literal(migraphx::literal{s, {a}});
auto l1 = p.add_literal(migraphx::literal{s, {b}});
p.add_instruction(migraphx::op::add{}, l0, l1);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<migraphx::half> results_vector(1);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<migraphx::half> gold{c};
EXPECT(migraphx::verify_range(results_vector, gold));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -2120,4 +2120,1242 @@ TEST_CASE(gru_bidirectional_actv_funcs) ...@@ -2120,4 +2120,1242 @@ TEST_CASE(gru_bidirectional_actv_funcs)
} }
} }
TEST_CASE(lstm_forward)
{
std::size_t batch_size = 3;
std::size_t seq_len = 4;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
std::vector<float> w_data{
0.1236, -0.3942, 0.4149, 0.0795, 0.4934, -0.2858, 0.2602, -0.3098, 0.0567, 0.3344,
0.3607, -0.0551, 0.4952, 0.3799, 0.0630, -0.3532, 0.0023, -0.0592, 0.4267, 0.2382,
-0.0784, -0.0032, -0.2476, -0.0206, -0.4963, 0.4837, 0.0827, 0.0123, -0.1203, -0.0279,
-0.0049, 0.4721, -0.3564, -0.1286, 0.4090, -0.0504, 0.0575, -0.2138, 0.1071, 0.1976,
-0.0758, 0.0139, -0.0761, 0.3991, -0.2965, -0.4845, -0.1496, 0.3285};
std::vector<float> r_data{
0.1237, 0.1229, -0.0766, -0.1144, -0.1186, 0.2922, 0.2478, 0.3159, -0.0522, 0.1685,
-0.4621, 0.1728, 0.0670, -0.2458, -0.3835, -0.4589, -0.3109, 0.4908, -0.0133, -0.1858,
-0.0590, -0.0347, -0.2353, -0.0671, -0.3812, -0.0004, -0.1432, 0.2406, 0.1033, -0.0265,
-0.3902, 0.0755, 0.3733, 0.4383, -0.3140, 0.2537, -0.1818, -0.4127, 0.3506, 0.2562,
0.2926, 0.1620, -0.4849, -0.4861, 0.4426, 0.2106, -0.0005, 0.4418, -0.2926, -0.3100,
0.1500, -0.0362, -0.3801, -0.0065, -0.0631, 0.1277, 0.2315, 0.4087, -0.3963, -0.4161,
-0.2169, -0.1344, 0.3468, -0.2260};
std::vector<float> bias_data{0.0088, 0.1183, 0.1642, -0.2631, -0.1330, -0.4008, 0.3881,
-0.4407, -0.2760, 0.1274, -0.0083, -0.2885, 0.3949, -0.0182,
0.4445, 0.3477, 0.2266, 0.3423, -0.0674, -0.4067, 0.0807,
0.1109, -0.2036, 0.1782, -0.2467, -0.0730, -0.4216, 0.0316,
-0.3025, 0.3637, -0.3181, -0.4655};
std::vector<float> input_data{
-0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331,
-0.0910, 1.2122, -0.1952, 0.4661, 0.6494, 2.1332, -1.0972, 0.9816, 0.1122,
0.3577, 1.3508, -0.5366, 1.7449, 0.5483, -0.0701, -0.4100, -2.2344, 0.3685,
0.4583, 2.3794, 1.0372, -0.8887, 0.7892, -0.4012, -0.2818, -2.3374, 1.5310};
std::vector<float> ih_data{1.9104,
-1.9004,
0.3337,
0.5741,
0.5671,
0.0458,
0.4514,
-0.8968,
-0.9201,
0.1962,
0.5771,
-0.5332};
std::vector<float> ic_data{0.9569,
-0.5981,
1.1312,
1.0945,
1.1055,
-0.1212,
-0.9097,
0.7831,
-1.6991,
-1.9498,
-1.2567,
-0.4114};
std::vector<float> pph_data{1.84369764,
0.68413646,
-0.44892886,
-1.50904413,
0.3860796,
-0.52186625,
1.08474445,
-1.80867321,
1.32594529,
0.4336262,
-0.83699064,
0.49162736};
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};
// forward, hidden state concatenation as output
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data});
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(
migraphx::op::lstm{
hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip,
0},
seq,
w,
r,
bias,
und,
ih,
ic,
und);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{
0.0417273, -0.272355, 0.206765, 0.223879, 0.138193, -0.0322939, -0.0891815,
0.15773, 0.19139, -0.127708, -0.409371, -0.136186, 0.0742487, -0.0800085,
0.259897, 0.0670196, 0.184266, 0.0610048, -0.138041, 0.0963885, 0.0213755,
-0.146027, -0.0324509, -0.0620429, -0.00532985, 0.0440265, 0.29654, -0.0463156,
0.0498799, 0.125772, 0.0533032, -0.131413, 0.0988431, -0.018085, -0.159434,
0.030266, -0.0847427, 0.0874114, 0.304256, -0.0585745, -0.0223018, 0.131113,
0.135643, -0.0566208, 0.142701, 0.0342236, -0.198664, 0.0702607};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
// forward, last_output as program output
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto hs = p.add_instruction(
migraphx::op::lstm{
hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip,
0},
seq,
w,
r,
bias,
und,
ih,
ic,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, hs);
p.compile(migraphx::cpu::target{});
auto last_hs = p.eval({});
std::vector<float> output_data;
last_hs.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{-0.0847427,
0.0874114,
0.304256,
-0.0585745,
-0.0223018,
0.131113,
0.135643,
-0.0566208,
0.142701,
0.0342236,
-0.198664,
0.0702607};
EXPECT(migraphx::verify_range(output_data, output_data_gold));
}
// forward, last_cell_output as program output
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto hs = p.add_instruction(
migraphx::op::lstm{
hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip,
0},
seq,
w,
r,
bias,
und,
ih,
ic,
und);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, hs);
p.compile(migraphx::cpu::target{});
auto last_hs = p.eval({});
std::vector<float> output_data;
last_hs.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{-0.111454,
0.247794,
0.471087,
-0.220574,
-0.048196,
0.263184,
0.283258,
-0.14882,
0.605585,
0.078598,
-0.64457,
0.119811};
EXPECT(migraphx::verify_range(output_data, output_data_gold));
}
}
TEST_CASE(lstm_forward_more)
{
std::size_t batch_size = 3;
std::size_t seq_len = 4;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
std::vector<float> w_data{
0.1236, -0.3942, 0.4149, 0.0795, 0.4934, -0.2858, 0.2602, -0.3098, 0.0567, 0.3344,
0.3607, -0.0551, 0.4952, 0.3799, 0.0630, -0.3532, 0.0023, -0.0592, 0.4267, 0.2382,
-0.0784, -0.0032, -0.2476, -0.0206, -0.4963, 0.4837, 0.0827, 0.0123, -0.1203, -0.0279,
-0.0049, 0.4721, -0.3564, -0.1286, 0.4090, -0.0504, 0.0575, -0.2138, 0.1071, 0.1976,
-0.0758, 0.0139, -0.0761, 0.3991, -0.2965, -0.4845, -0.1496, 0.3285};
std::vector<float> r_data{
0.1237, 0.1229, -0.0766, -0.1144, -0.1186, 0.2922, 0.2478, 0.3159, -0.0522, 0.1685,
-0.4621, 0.1728, 0.0670, -0.2458, -0.3835, -0.4589, -0.3109, 0.4908, -0.0133, -0.1858,
-0.0590, -0.0347, -0.2353, -0.0671, -0.3812, -0.0004, -0.1432, 0.2406, 0.1033, -0.0265,
-0.3902, 0.0755, 0.3733, 0.4383, -0.3140, 0.2537, -0.1818, -0.4127, 0.3506, 0.2562,
0.2926, 0.1620, -0.4849, -0.4861, 0.4426, 0.2106, -0.0005, 0.4418, -0.2926, -0.3100,
0.1500, -0.0362, -0.3801, -0.0065, -0.0631, 0.1277, 0.2315, 0.4087, -0.3963, -0.4161,
-0.2169, -0.1344, 0.3468, -0.2260};
std::vector<float> bias_data{0.0088, 0.1183, 0.1642, -0.2631, -0.1330, -0.4008, 0.3881,
-0.4407, -0.2760, 0.1274, -0.0083, -0.2885, 0.3949, -0.0182,
0.4445, 0.3477, 0.2266, 0.3423, -0.0674, -0.4067, 0.0807,
0.1109, -0.2036, 0.1782, -0.2467, -0.0730, -0.4216, 0.0316,
-0.3025, 0.3637, -0.3181, -0.4655};
std::vector<float> input_data{
-0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331,
-0.0910, 1.2122, -0.1952, 0.4661, 0.6494, 2.1332, -1.0972, 0.9816, 0.1122,
0.3577, 1.3508, -0.5366, 1.7449, 0.5483, -0.0701, -0.4100, -2.2344, 0.3685,
0.4583, 2.3794, 1.0372, -0.8887, 0.7892, -0.4012, -0.2818, -2.3374, 1.5310};
std::vector<float> ih_data{1.9104,
-1.9004,
0.3337,
0.5741,
0.5671,
0.0458,
0.4514,
-0.8968,
-0.9201,
0.1962,
0.5771,
-0.5332};
std::vector<float> ic_data{0.9569,
-0.5981,
1.1312,
1.0945,
1.1055,
-0.1212,
-0.9097,
0.7831,
-1.6991,
-1.9498,
-1.2567,
-0.4114};
std::vector<float> pph_data{1.84369764,
0.68413646,
-0.44892886,
-1.50904413,
0.3860796,
-0.52186625,
1.08474445,
-1.80867321,
1.32594529,
0.4336262,
-0.83699064,
0.49162736};
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};
// forward, 3 args
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
p.add_instruction(
migraphx::op::lstm{
hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip,
0},
seq,
w,
r);
p.compile(migraphx::cpu::target{});
auto last_hs = p.eval({});
std::vector<float> output_data;
last_hs.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
-0.0327039, -0.0543852, 0.114378, -0.0768855, 0.0319021, -0.00298698, -0.0623361,
0.0598866, 0.101585, 0.0687269, -0.161725, -0.25617, -0.0786602, -0.0613048,
0.179592, -0.071286, 0.074206, 0.0124086, -0.139544, 0.108016, -0.00973633,
-0.0552699, 0.0252681, -0.0562072, -0.102509, -0.0372696, 0.252296, -0.144544,
0.00496085, 0.0662588, -0.048577, -0.187329, 0.0855831, -0.0171894, -0.140202,
0.0828391, -0.165194, -0.0372928, 0.273786, -0.100877, -0.0458544, -0.0401315,
0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774};
EXPECT(migraphx::verify_range(output_data, output_data_gold));
}
// forward, 8 args
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data});
auto pph = p.add_literal(migraphx::literal{pph_shape, pph_data});
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(
migraphx::op::lstm{
hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip,
0},
seq,
w,
r,
bias,
und,
ih,
ic,
pph);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{
0.079753, -0.289854, 0.160043, 0.115056, 0.294074, -0.0319677, -0.0955337,
0.104168, 0.022618, -0.121195, -0.4065, -0.252054, 0.186991, -0.0624168,
0.205513, 0.0836373, 0.421857, 0.0459771, -0.144955, 0.0720673, -0.0300906,
-0.0890598, -0.135266, -0.0413375, 0.0459032, 0.0414126, 0.272303, 0.0393149,
0.218258, 0.0944405, 0.0431211, -0.132394, 0.103489, 0.0142918, -0.123408,
0.0401075, -0.058052, 0.0795391, 0.266617, -0.0128746, 0.0309878, 0.0971544,
0.149294, -0.0492549, 0.187761, 0.0501726, -0.121584, 0.0606723};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
// seq_len = 1
{
seq_len = 1;
migraphx::shape in_shape1{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
std::vector<float> input_data1{
-0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331};
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape1, input_data1});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data});
auto pph = p.add_literal(migraphx::literal{pph_shape, pph_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto hs = p.add_instruction(
migraphx::op::lstm{
hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip,
0},
seq,
w,
r,
bias,
und,
ih,
ic,
pph);
p.add_instruction(migraphx::op::rnn_last_output{}, hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{0.079753,
-0.289854,
0.160043,
0.115056,
0.294074,
-0.0319677,
-0.0955337,
0.104168,
0.022618,
-0.121195,
-0.4065,
-0.252054};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
}
TEST_CASE(lstm_reverse)
{
std::size_t batch_size = 3;
std::size_t seq_len = 4;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
std::vector<float> w_data{
-0.2763, -0.4715, -0.3010, -0.2306, -0.2283, -0.2656, 0.2035, 0.3570, -0.1499, 0.4390,
-0.1843, 0.2351, 0.3357, 0.1217, 0.1401, 0.3300, -0.0429, 0.3266, 0.4834, -0.3914,
-0.1480, 0.3734, -0.0372, -0.1746, 0.0550, 0.4177, -0.1332, 0.4391, -0.3287, -0.4401,
0.1486, 0.1346, 0.1048, -0.4361, 0.0886, -0.3840, -0.2730, -0.1710, 0.3274, 0.0169,
-0.4462, 0.0729, 0.3983, -0.0669, 0.0756, 0.4150, -0.4684, -0.2522};
std::vector<float> r_data{
-0.4564, -0.4432, 0.1605, 0.4387, 0.0034, 0.4116, 0.2824, 0.4775, -0.2729, -0.4707,
0.1363, 0.2218, 0.0559, 0.2828, 0.2093, 0.4687, 0.3794, -0.1069, -0.3049, 0.1430,
-0.2506, 0.4644, 0.2755, -0.3645, -0.3155, 0.1425, 0.2891, 0.1786, -0.3274, 0.2365,
0.2522, -0.4312, -0.0562, -0.2748, 0.0776, -0.3154, 0.2851, -0.3930, -0.1174, 0.4360,
0.2436, 0.0164, -0.0680, 0.3403, -0.2857, -0.0459, -0.2991, -0.2624, 0.4194, -0.3291,
-0.4659, 0.3300, 0.0454, 0.4981, -0.4706, -0.4584, 0.2596, 0.2871, -0.3509, -0.1910,
0.3987, -0.1687, -0.0032, -0.1038};
std::vector<float> bias_data{-0.0258, 0.0073, -0.4780, -0.4101, -0.3556, -0.1017, 0.3632,
-0.1823, 0.1479, 0.1677, -0.2603, 0.0381, 0.1575, 0.1896,
0.4755, -0.4794, 0.2167, -0.4474, -0.3139, 0.1018, 0.4470,
-0.4232, 0.3247, -0.1636, -0.1582, -0.1703, 0.3920, 0.2055,
-0.4386, 0.4208, 0.0717, 0.3789};
std::vector<float> input_data{
-0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331,
-0.0910, 1.2122, -0.1952, 0.4661, 0.6494, 2.1332, -1.0972, 0.9816, 0.1122,
0.3577, 1.3508, -0.5366, 1.7449, 0.5483, -0.0701, -0.4100, -2.2344, 0.3685,
0.4583, 2.3794, 1.0372, -0.8887, 0.7892, -0.4012, -0.2818, -2.3374, 1.5310};
std::vector<float> ih_data{1.5289,
1.0986,
0.6091,
1.6462,
0.8720,
0.5349,
-0.1962,
-1.7416,
-0.9912,
1.2831,
1.0896,
-0.6959};
std::vector<float> ic_data{-0.8323,
0.3998,
0.1831,
0.5938,
2.7096,
-0.1790,
0.0022,
-0.8040,
0.1578,
0.0567,
0.8069,
-0.5141};
std::vector<float> pph_data{-0.8271,
-0.5683,
0.4562,
-1.2545,
1.2729,
-0.4082,
-0.4392,
-0.9406,
0.7794,
1.8194,
-0.5811,
0.2166};
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};
float clip = 0.0f;
// reverse, concatenation of hidden states as program output
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto pph = p.add_literal(migraphx::literal{pph_shape, pph_data});
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(
migraphx::op::lstm{
hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::reverse,
clip,
0},
seq,
w,
r,
bias,
und,
ih,
ic,
pph);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
-0.120174, 0.043157, 0.117138, -0.222188, 0.789732, 0.128538, 0.20909,
0.0553812, -0.224905, 0.32421, 0.344048, 0.271694, -0.175114, -0.00543549,
0.178681, -0.266999, 0.928866, 0.113685, 0.220626, -0.0432316, -0.063456,
0.148524, 0.05108, -0.0234895, -0.182201, -0.0232277, 0.235501, -0.213485,
0.960938, 0.133565, 0.269741, 0.130438, -0.0252804, 0.267356, 0.146353,
0.0789186, -0.185038, -0.026845, 0.177273, -0.0774616, 0.946669, 0.0868676,
0.044508, -0.373961, -0.0681467, 0.382748, 0.230211, -0.161537};
EXPECT(migraphx::verify_range(output_data, output_data_gold));
}
// reverse, 3 args, last cell output as program output
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto hs = p.add_instruction(
migraphx::op::lstm{
hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::reverse,
clip,
0},
seq,
w,
r);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{-0.443077,
-0.325425,
-0.249367,
-0.270812,
0.122913,
0.118537,
0.0370199,
-0.0164687,
-0.00754759,
0.141613,
0.348002,
0.667298};
EXPECT(migraphx::verify_range(output_data, output_data_gold));
}
// reverse, 3 args, 0 actv function
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto hs = p.add_instruction(
migraphx::op::lstm{hidden_size, {}, migraphx::op::rnn_direction::reverse, clip, 0},
seq,
w,
r);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{-0.443077,
-0.325425,
-0.249367,
-0.270812,
0.122913,
0.118537,
0.0370199,
-0.0164687,
-0.00754759,
0.141613,
0.348002,
0.667298};
EXPECT(migraphx::verify_range(output_data, output_data_gold));
}
// reverse, 3 args, 1 actv function
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
p.add_instruction(migraphx::op::lstm{hidden_size,
{migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::reverse,
clip,
0},
seq,
w,
r);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
0.246078, 0.199709, 0.303753, 0.301178, 0.264634, 0.304661, 0.349371, 0.288934,
0.405483, 0.445586, 0.515814, 0.473186, 0.301937, 0.264893, 0.254353, 0.269231,
0.359258, 0.400097, 0.288884, 0.247329, 0.276519, 0.264249, 0.1769, 0.23213,
0.310306, 0.262902, 0.276964, 0.295002, 0.373802, 0.366785, 0.419791, 0.393216,
0.262827, 0.371441, 0.369022, 0.298262, 0.334143, 0.309444, 0.174822, 0.251634,
0.244564, 0.214386, 0.185994, 0.226699, 0.28445, 0.376092, 0.338326, 0.259502};
EXPECT(migraphx::verify_range(output_data, output_data_gold));
}
// reverse, 3 args, 2 actv functions
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto hs =
p.add_instruction(migraphx::op::lstm{hidden_size,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::reverse,
clip,
0},
seq,
w,
r);
p.add_instruction(migraphx::op::rnn_last_output{}, hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{-0.132123,
-0.37531,
-0.12943,
-0.00798307,
-0.133882,
-0.0251383,
0.0486486,
-0.0220606,
0.292495,
0.233866,
0.48646,
0.481844};
EXPECT(migraphx::verify_range(output_data, output_data_gold));
}
// reverse, 3 args, seq_len = 1, concatenation of hidden states as program output
{
seq_len = 1;
std::vector<float> input_data1{
-0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331};
migraphx::shape in_shape1{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape1, input_data1});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
p.add_instruction(
migraphx::op::lstm{
hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::reverse,
clip,
0},
seq,
w,
r);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{-0.104351,
-0.0471426,
-0.0905753,
0.01506,
0.059797,
0.104239,
-0.0266768,
0.0727547,
-0.146298,
0.070535,
0.327809,
0.407388};
EXPECT(migraphx::verify_range(output_data, output_data_gold));
}
}
TEST_CASE(lstm_bidirectional)
{
std::size_t batch_size = 3;
std::size_t seq_len = 4;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 2;
std::vector<float> w_data{
0.1236, -0.3942, 0.4149, 0.0795, 0.4934, -0.2858, 0.2602, -0.3098, 0.0567, 0.3344,
0.3607, -0.0551, 0.4952, 0.3799, 0.0630, -0.3532, 0.0023, -0.0592, 0.4267, 0.2382,
-0.0784, -0.0032, -0.2476, -0.0206, -0.4963, 0.4837, 0.0827, 0.0123, -0.1203, -0.0279,
-0.0049, 0.4721, -0.3564, -0.1286, 0.4090, -0.0504, 0.0575, -0.2138, 0.1071, 0.1976,
-0.0758, 0.0139, -0.0761, 0.3991, -0.2965, -0.4845, -0.1496, 0.3285, -0.2763, -0.4715,
-0.3010, -0.2306, -0.2283, -0.2656, 0.2035, 0.3570, -0.1499, 0.4390, -0.1843, 0.2351,
0.3357, 0.1217, 0.1401, 0.3300, -0.0429, 0.3266, 0.4834, -0.3914, -0.1480, 0.3734,
-0.0372, -0.1746, 0.0550, 0.4177, -0.1332, 0.4391, -0.3287, -0.4401, 0.1486, 0.1346,
0.1048, -0.4361, 0.0886, -0.3840, -0.2730, -0.1710, 0.3274, 0.0169, -0.4462, 0.0729,
0.3983, -0.0669, 0.0756, 0.4150, -0.4684, -0.2522};
std::vector<float> r_data{
0.1237, 0.1229, -0.0766, -0.1144, -0.1186, 0.2922, 0.2478, 0.3159, -0.0522, 0.1685,
-0.4621, 0.1728, 0.0670, -0.2458, -0.3835, -0.4589, -0.3109, 0.4908, -0.0133, -0.1858,
-0.0590, -0.0347, -0.2353, -0.0671, -0.3812, -0.0004, -0.1432, 0.2406, 0.1033, -0.0265,
-0.3902, 0.0755, 0.3733, 0.4383, -0.3140, 0.2537, -0.1818, -0.4127, 0.3506, 0.2562,
0.2926, 0.1620, -0.4849, -0.4861, 0.4426, 0.2106, -0.0005, 0.4418, -0.2926, -0.3100,
0.1500, -0.0362, -0.3801, -0.0065, -0.0631, 0.1277, 0.2315, 0.4087, -0.3963, -0.4161,
-0.2169, -0.1344, 0.3468, -0.2260, -0.4564, -0.4432, 0.1605, 0.4387, 0.0034, 0.4116,
0.2824, 0.4775, -0.2729, -0.4707, 0.1363, 0.2218, 0.0559, 0.2828, 0.2093, 0.4687,
0.3794, -0.1069, -0.3049, 0.1430, -0.2506, 0.4644, 0.2755, -0.3645, -0.3155, 0.1425,
0.2891, 0.1786, -0.3274, 0.2365, 0.2522, -0.4312, -0.0562, -0.2748, 0.0776, -0.3154,
0.2851, -0.3930, -0.1174, 0.4360, 0.2436, 0.0164, -0.0680, 0.3403, -0.2857, -0.0459,
-0.2991, -0.2624, 0.4194, -0.3291, -0.4659, 0.3300, 0.0454, 0.4981, -0.4706, -0.4584,
0.2596, 0.2871, -0.3509, -0.1910, 0.3987, -0.1687, -0.0032, -0.1038};
std::vector<float> bias_data{
0.0088, 0.1183, 0.1642, -0.2631, -0.1330, -0.4008, 0.3881, -0.4407, -0.2760, 0.1274,
-0.0083, -0.2885, 0.3949, -0.0182, 0.4445, 0.3477, 0.2266, 0.3423, -0.0674, -0.4067,
0.0807, 0.1109, -0.2036, 0.1782, -0.2467, -0.0730, -0.4216, 0.0316, -0.3025, 0.3637,
-0.3181, -0.4655, -0.0258, 0.0073, -0.4780, -0.4101, -0.3556, -0.1017, 0.3632, -0.1823,
0.1479, 0.1677, -0.2603, 0.0381, 0.1575, 0.1896, 0.4755, -0.4794, 0.2167, -0.4474,
-0.3139, 0.1018, 0.4470, -0.4232, 0.3247, -0.1636, -0.1582, -0.1703, 0.3920, 0.2055,
-0.4386, 0.4208, 0.0717, 0.3789};
std::vector<float> input_data{
-0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331,
-0.0910, 1.2122, -0.1952, 0.4661, 0.6494, 2.1332, -1.0972, 0.9816, 0.1122,
0.3577, 1.3508, -0.5366, 1.7449, 0.5483, -0.0701, -0.4100, -2.2344, 0.3685,
0.4583, 2.3794, 1.0372, -0.8887, 0.7892, -0.4012, -0.2818, -2.3374, 1.5310};
std::vector<float> ih_data{1.9104, -1.9004, 0.3337, 0.5741, 0.5671, 0.0458,
0.4514, -0.8968, -0.9201, 0.1962, 0.5771, -0.5332,
1.5289, 1.0986, 0.6091, 1.6462, 0.8720, 0.5349,
-0.1962, -1.7416, -0.9912, 1.2831, 1.0896, -0.6959};
std::vector<float> ic_data{0.9569, -0.5981, 1.1312, 1.0945, 1.1055, -0.1212,
-0.9097, 0.7831, -1.6991, -1.9498, -1.2567, -0.4114,
-0.8323, 0.3998, 0.1831, 0.5938, 2.7096, -0.1790,
0.0022, -0.8040, 0.1578, 0.0567, 0.8069, -0.5141};
std::vector<float> pph_data{1.84369764, 0.68413646, -0.44892886, -1.50904413, 0.3860796,
-0.52186625, 1.08474445, -1.80867321, 1.32594529, 0.4336262,
-0.83699064, 0.49162736, -0.8271, -0.5683, 0.4562,
-1.2545, 1.2729, -0.4082, -0.4392, -0.9406,
0.7794, 1.8194, -0.5811, 0.2166};
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};
// concatenation of hidden states as program output
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto pph = p.add_literal(migraphx::literal{pph_shape, pph_data});
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(
migraphx::op::lstm{
hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
0},
seq,
w,
r,
bias,
und,
ih,
ic,
pph);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
0.079753, -0.289854, 0.160043, 0.115056, 0.294074, -0.0319677, -0.0955337,
0.104168, 0.022618, -0.121195, -0.4065, -0.252054, -0.120174, 0.043157,
0.117138, -0.222188, 0.789732, 0.128538, 0.20909, 0.0553812, -0.224905,
0.32421, 0.344048, 0.271694, 0.186991, -0.0624168, 0.205513, 0.0836373,
0.421857, 0.0459771, -0.144955, 0.0720673, -0.0300906, -0.0890598, -0.135266,
-0.0413375, -0.175114, -0.00543549, 0.178681, -0.266999, 0.928866, 0.113685,
0.220626, -0.0432316, -0.063456, 0.148524, 0.05108, -0.0234895, 0.0459032,
0.0414126, 0.272303, 0.0393149, 0.218258, 0.0944405, 0.0431211, -0.132394,
0.103489, 0.0142918, -0.123408, 0.0401075, -0.182201, -0.0232277, 0.235501,
-0.213485, 0.960938, 0.133565, 0.269741, 0.130438, -0.0252804, 0.267356,
0.146353, 0.0789186, -0.058052, 0.0795391, 0.266617, -0.0128746, 0.0309878,
0.0971544, 0.149294, -0.0492549, 0.187761, 0.0501726, -0.121584, 0.0606723,
-0.185038, -0.026845, 0.177273, -0.0774616, 0.946669, 0.0868676, 0.044508,
-0.373961, -0.0681467, 0.382748, 0.230211, -0.161537};
EXPECT(migraphx::verify_range(output_data, output_data_gold));
}
// last hidden state as program output
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto pph = p.add_literal(migraphx::literal{pph_shape, pph_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto hs = p.add_instruction(
migraphx::op::lstm{
hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
0},
seq,
w,
r,
bias,
und,
ih,
ic,
pph);
p.add_instruction(migraphx::op::rnn_last_output{}, hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
-0.058052, 0.0795391, 0.266617, -0.0128746, 0.0309878, 0.0971544, 0.149294, -0.0492549,
0.187761, 0.0501726, -0.121584, 0.0606723, -0.120174, 0.043157, 0.117138, -0.222188,
0.789732, 0.128538, 0.20909, 0.0553812, -0.224905, 0.32421, 0.344048, 0.271694};
EXPECT(migraphx::verify_range(output_data, output_data_gold));
}
// last cell output as program output
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto pph = p.add_literal(migraphx::literal{pph_shape, pph_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto hs = p.add_instruction(
migraphx::op::lstm{
hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
0},
seq,
w,
r,
bias,
und,
ih,
ic,
pph);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
-0.077353, 0.245616, 0.361023, -0.0443759, 0.0685243, 0.20465, 0.277867, -0.112934,
0.67312, 0.120508, -0.726968, 0.113845, -0.889294, 0.182463, 0.186512, -0.402334,
1.48161, 0.524116, 0.347113, 0.181813, -0.434265, 0.747833, 0.416053, 0.558713};
EXPECT(migraphx::verify_range(output_data, output_data_gold));
}
// 3 args, concatenation of hidden states as program output
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
p.add_instruction(
migraphx::op::lstm{
hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
0},
seq,
w,
r);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
-0.0327039, -0.0543852, 0.114378, -0.0768855, 0.0319021, -0.00298698, -0.0623361,
0.0598866, 0.101585, 0.0687269, -0.161725, -0.25617, -0.162851, -0.102647,
-0.113827, -0.142818, 0.0513685, 0.0547876, 0.0201981, -0.00808453, -0.00520328,
0.0945081, 0.264123, 0.410805, -0.0786602, -0.0613048, 0.179592, -0.071286,
0.074206, 0.0124086, -0.139544, 0.108016, -0.00973633, -0.0552699, 0.0252681,
-0.0562072, -0.123496, -0.153616, -0.032874, -0.195349, 0.0192675, -0.108636,
0.098927, -0.140733, 0.162602, 0.0143099, -0.0455534, 0.0151574, -0.102509,
-0.0372696, 0.252296, -0.144544, 0.00496085, 0.0662588, -0.048577, -0.187329,
0.0855831, -0.0171894, -0.140202, 0.0828391, -0.1073, -0.150145, 0.015065,
-0.192699, -0.112764, -0.120496, 0.155754, 0.148256, 0.208491, 0.348432,
0.0291103, 0.230275, -0.165194, -0.0372928, 0.273786, -0.100877, -0.0458544,
-0.0401315, 0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774,
-0.021205, -0.125423, 0.0206439, -0.187097, -0.0051453, -0.0767618, -0.0735348,
-0.0826436, 0.214159, 0.262295, 0.0247127, 0.14472};
EXPECT(migraphx::verify_range(output_data, output_data_gold));
}
// sequence length is 1, contenation of hidden state as program output
{
migraphx::program p;
seq_len = 1;
migraphx::shape in_shape1{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
std::vector<float> input_data1{
-0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331};
auto seq = p.add_literal(migraphx::literal{in_shape1, input_data1});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
p.add_instruction(
migraphx::op::lstm{
hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
0},
seq,
w,
r);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
-0.0327039, -0.0543852, 0.114378, -0.0768855, 0.0319021, -0.00298698,
-0.0623361, 0.0598866, 0.101585, 0.0687269, -0.161725, -0.25617,
-0.104351, -0.0471426, -0.0905753, 0.01506, 0.059797, 0.104239,
-0.0266768, 0.0727547, -0.146298, 0.070535, 0.327809, 0.407388};
EXPECT(migraphx::verify_range(output_data, output_data_gold));
}
}
TEST_CASE(lstm_bidirectional_actv_func)
{
std::size_t batch_size = 3;
std::size_t seq_len = 4;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 2;
std::vector<float> w_data{
0.1236, -0.3942, 0.4149, 0.0795, 0.4934, -0.2858, 0.2602, -0.3098, 0.0567, 0.3344,
0.3607, -0.0551, 0.4952, 0.3799, 0.0630, -0.3532, 0.0023, -0.0592, 0.4267, 0.2382,
-0.0784, -0.0032, -0.2476, -0.0206, -0.4963, 0.4837, 0.0827, 0.0123, -0.1203, -0.0279,
-0.0049, 0.4721, -0.3564, -0.1286, 0.4090, -0.0504, 0.0575, -0.2138, 0.1071, 0.1976,
-0.0758, 0.0139, -0.0761, 0.3991, -0.2965, -0.4845, -0.1496, 0.3285, -0.2763, -0.4715,
-0.3010, -0.2306, -0.2283, -0.2656, 0.2035, 0.3570, -0.1499, 0.4390, -0.1843, 0.2351,
0.3357, 0.1217, 0.1401, 0.3300, -0.0429, 0.3266, 0.4834, -0.3914, -0.1480, 0.3734,
-0.0372, -0.1746, 0.0550, 0.4177, -0.1332, 0.4391, -0.3287, -0.4401, 0.1486, 0.1346,
0.1048, -0.4361, 0.0886, -0.3840, -0.2730, -0.1710, 0.3274, 0.0169, -0.4462, 0.0729,
0.3983, -0.0669, 0.0756, 0.4150, -0.4684, -0.2522};
std::vector<float> r_data{
0.1237, 0.1229, -0.0766, -0.1144, -0.1186, 0.2922, 0.2478, 0.3159, -0.0522, 0.1685,
-0.4621, 0.1728, 0.0670, -0.2458, -0.3835, -0.4589, -0.3109, 0.4908, -0.0133, -0.1858,
-0.0590, -0.0347, -0.2353, -0.0671, -0.3812, -0.0004, -0.1432, 0.2406, 0.1033, -0.0265,
-0.3902, 0.0755, 0.3733, 0.4383, -0.3140, 0.2537, -0.1818, -0.4127, 0.3506, 0.2562,
0.2926, 0.1620, -0.4849, -0.4861, 0.4426, 0.2106, -0.0005, 0.4418, -0.2926, -0.3100,
0.1500, -0.0362, -0.3801, -0.0065, -0.0631, 0.1277, 0.2315, 0.4087, -0.3963, -0.4161,
-0.2169, -0.1344, 0.3468, -0.2260, -0.4564, -0.4432, 0.1605, 0.4387, 0.0034, 0.4116,
0.2824, 0.4775, -0.2729, -0.4707, 0.1363, 0.2218, 0.0559, 0.2828, 0.2093, 0.4687,
0.3794, -0.1069, -0.3049, 0.1430, -0.2506, 0.4644, 0.2755, -0.3645, -0.3155, 0.1425,
0.2891, 0.1786, -0.3274, 0.2365, 0.2522, -0.4312, -0.0562, -0.2748, 0.0776, -0.3154,
0.2851, -0.3930, -0.1174, 0.4360, 0.2436, 0.0164, -0.0680, 0.3403, -0.2857, -0.0459,
-0.2991, -0.2624, 0.4194, -0.3291, -0.4659, 0.3300, 0.0454, 0.4981, -0.4706, -0.4584,
0.2596, 0.2871, -0.3509, -0.1910, 0.3987, -0.1687, -0.0032, -0.1038};
std::vector<float> input_data{
-0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331,
-0.0910, 1.2122, -0.1952, 0.4661, 0.6494, 2.1332, -1.0972, 0.9816, 0.1122,
0.3577, 1.3508, -0.5366, 1.7449, 0.5483, -0.0701, -0.4100, -2.2344, 0.3685,
0.4583, 2.3794, 1.0372, -0.8887, 0.7892, -0.4012, -0.2818, -2.3374, 1.5310};
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}};
// 3 args, 0 actv func
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
p.add_instruction(
migraphx::op::lstm{
hidden_size, {}, migraphx::op::rnn_direction::bidirectional, clip, 0},
seq,
w,
r);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
-0.0327039, -0.0543852, 0.114378, -0.0768855, 0.0319021, -0.00298698, -0.0623361,
0.0598866, 0.101585, 0.0687269, -0.161725, -0.25617, -0.162851, -0.102647,
-0.113827, -0.142818, 0.0513685, 0.0547876, 0.0201981, -0.00808453, -0.00520328,
0.0945081, 0.264123, 0.410805, -0.0786602, -0.0613048, 0.179592, -0.071286,
0.074206, 0.0124086, -0.139544, 0.108016, -0.00973633, -0.0552699, 0.0252681,
-0.0562072, -0.123496, -0.153616, -0.032874, -0.195349, 0.0192675, -0.108636,
0.098927, -0.140733, 0.162602, 0.0143099, -0.0455534, 0.0151574, -0.102509,
-0.0372696, 0.252296, -0.144544, 0.00496085, 0.0662588, -0.048577, -0.187329,
0.0855831, -0.0171894, -0.140202, 0.0828391, -0.1073, -0.150145, 0.015065,
-0.192699, -0.112764, -0.120496, 0.155754, 0.148256, 0.208491, 0.348432,
0.0291103, 0.230275, -0.165194, -0.0372928, 0.273786, -0.100877, -0.0458544,
-0.0401315, 0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774,
-0.021205, -0.125423, 0.0206439, -0.187097, -0.0051453, -0.0767618, -0.0735348,
-0.0826436, 0.214159, 0.262295, 0.0247127, 0.14472};
EXPECT(migraphx::verify_range(output_data, output_data_gold));
}
// 3 args, 1 actv func
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
p.add_instruction(migraphx::op::lstm{hidden_size,
{migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional,
clip,
0},
seq,
w,
r);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
0.227861, 0.328562, 0.277867, 0.272945, 0.204389, 0.296123, 0.223834, 0.311113,
0.424666, 0.173974, 0.40628, 0.286631, 0.246078, 0.199709, 0.303753, 0.301178,
0.264634, 0.304661, 0.349371, 0.288934, 0.405483, 0.445586, 0.515814, 0.473186,
0.339438, 0.29655, 0.331832, 0.242338, 0.409384, 0.236272, 0.306045, 0.26269,
0.261246, 0.334357, 0.23622, 0.245288, 0.301937, 0.264893, 0.254353, 0.269231,
0.359258, 0.400097, 0.288884, 0.247329, 0.276519, 0.264249, 0.1769, 0.23213,
0.374123, 0.283167, 0.377129, 0.245726, 0.444712, 0.203168, 0.411446, 0.269965,
0.172792, 0.296224, 0.17319, 0.352547, 0.310306, 0.262902, 0.276964, 0.295002,
0.373802, 0.366785, 0.419791, 0.393216, 0.262827, 0.371441, 0.369022, 0.298262,
0.450186, 0.263538, 0.402895, 0.216177, 0.267257, 0.342535, 0.257797, 0.268563,
0.193043, 0.275645, 0.167678, 0.350889, 0.334143, 0.309444, 0.174822, 0.251634,
0.244564, 0.214386, 0.185994, 0.226699, 0.28445, 0.376092, 0.338326, 0.259502};
EXPECT(migraphx::verify_range(output_data, output_data_gold));
}
// 3 args, 2 actv func
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto hs =
p.add_instruction(migraphx::op::lstm{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
0},
seq,
w,
r);
p.add_instruction(migraphx::op::rnn_last_output{}, hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
-0.165194, -0.0372928, 0.273786, -0.100877, -0.0458544, -0.0401315,
0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774,
-0.162851, -0.102647, -0.113827, -0.142818, 0.0513685, 0.0547876,
0.0201981, -0.00808453, -0.00520328, 0.0945081, 0.264123, 0.410805};
EXPECT(migraphx::verify_range(output_data, output_data_gold));
}
// 3 args, 4 actv func
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto hs = p.add_instruction(migraphx::op::lstm{hidden_size,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional,
clip,
0},
seq,
w,
r);
p.add_instruction(migraphx::op::rnn_last_output{}, hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
-0.165194, -0.0372928, 0.273786, -0.100877, -0.0458544, -0.0401315,
0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774,
0.246078, 0.199709, 0.303753, 0.301178, 0.264634, 0.304661,
0.349371, 0.288934, 0.405483, 0.445586, 0.515814, 0.473186};
EXPECT(migraphx::verify_range(output_data, output_data_gold));
}
// 3 args, 5 actv func
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto hs = p.add_instruction(migraphx::op::lstm{hidden_size,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
0},
seq,
w,
r);
p.add_instruction(migraphx::op::rnn_last_output{}, hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
-0.165194, -0.0372928, 0.273786, -0.100877, -0.0458544, -0.0401315,
0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774,
-0.162851, -0.102647, -0.113827, -0.142818, 0.0513685, 0.0547876,
0.0201981, -0.00808453, -0.00520328, 0.0945081, 0.264123, 0.410805};
EXPECT(migraphx::verify_range(output_data, output_data_gold));
}
// 3 args, 6 actv func
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto hs = p.add_instruction(migraphx::op::lstm{hidden_size,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
0},
seq,
w,
r);
p.add_instruction(migraphx::op::rnn_last_output{}, hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
-0.165194, -0.0372928, 0.273786, -0.100877, -0.0458544, -0.0401315,
0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774,
-0.162851, -0.102647, -0.113827, -0.142818, 0.0513685, 0.0547876,
0.0201981, -0.00808453, -0.00520328, 0.0945081, 0.264123, 0.410805};
EXPECT(migraphx::verify_range(output_data, output_data_gold));
}
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -128,7 +128,7 @@ TEST_CASE(print_test) ...@@ -128,7 +128,7 @@ TEST_CASE(print_test)
{ {
migraphx::program p; migraphx::program p;
auto x = p.add_parameter("x", {migraphx::shape::int64_type}); auto x = p.add_parameter("x", {migraphx::shape::int32_type});
auto two = p.add_literal(2); auto two = p.add_literal(2);
p.add_instruction(sum_op{}, x, two); p.add_instruction(sum_op{}, x, two);
...@@ -142,8 +142,8 @@ TEST_CASE(param_test) ...@@ -142,8 +142,8 @@ TEST_CASE(param_test)
{ {
migraphx::program p; migraphx::program p;
auto x = p.add_parameter("x", {migraphx::shape::int64_type}); auto x = p.add_parameter("x", {migraphx::shape::int32_type});
auto y = p.add_parameter("y", {migraphx::shape::int64_type}); auto y = p.add_parameter("y", {migraphx::shape::int32_type});
p.add_instruction(sum_op{}, x, y); p.add_instruction(sum_op{}, x, y);
auto result = p.eval( auto result = p.eval(
...@@ -156,8 +156,8 @@ TEST_CASE(param_error_test) ...@@ -156,8 +156,8 @@ TEST_CASE(param_error_test)
{ {
migraphx::program p; migraphx::program p;
auto x = p.add_parameter("x", {migraphx::shape::int64_type}); auto x = p.add_parameter("x", {migraphx::shape::int32_type});
auto y = p.add_parameter("y", {migraphx::shape::int64_type}); auto y = p.add_parameter("y", {migraphx::shape::int32_type});
p.add_instruction(sum_op{}, x, y); p.add_instruction(sum_op{}, x, y);
EXPECT(test::throws<migraphx::exception>( EXPECT(test::throws<migraphx::exception>(
...@@ -167,6 +167,22 @@ TEST_CASE(param_error_test) ...@@ -167,6 +167,22 @@ TEST_CASE(param_error_test)
"Parameter not found: y")); "Parameter not found: y"));
} }
TEST_CASE(param_shape_error_test)
{
migraphx::program p;
auto x = p.add_parameter("x", {migraphx::shape::int32_type, {1, 2}});
auto y = p.add_parameter("y", {migraphx::shape::int32_type, {1, 2}});
p.add_instruction(sum_op{}, x, y);
EXPECT(test::throws<migraphx::exception>(
[&] {
p.eval({{"x", migraphx::literal{1}.get_argument()},
{"y", migraphx::literal{2}.get_argument()}});
},
"Incorrect shape"));
}
TEST_CASE(replace_test) TEST_CASE(replace_test)
{ {
migraphx::program p; migraphx::program p;
......
...@@ -825,6 +825,22 @@ struct test_contiguous ...@@ -825,6 +825,22 @@ struct test_contiguous
} }
}; };
struct test_eliminate_contiguous
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 5}};
auto seq = p.add_parameter("seq", s);
std::vector<int64_t> perm{0, 2, 1, 3};
auto tran_seq = p.add_instruction(migraphx::op::transpose{perm}, seq);
std::vector<int64_t> out_shape{0, 0, -1};
p.add_instruction(migraphx::op::reshape{out_shape}, tran_seq);
return p;
}
};
struct test_transpose struct test_transpose
{ {
migraphx::program create_program() const migraphx::program create_program() const
...@@ -2114,6 +2130,715 @@ struct test_gru_bidirct_default_actv1 ...@@ -2114,6 +2130,715 @@ struct test_gru_bidirct_default_actv1
} }
}; };
struct test_lstm_forward_last
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape);
auto ic = p.add_parameter("ic", ic_shape);
auto pph = p.add_parameter("pph", pph_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto output = p.add_instruction(
migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r,
bias,
und,
ih,
ic,
pph);
p.add_instruction(migraphx::op::rnn_last_output{}, output);
return p;
}
};
struct test_lstm_forward_hs
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape);
auto ic = p.add_parameter("ic", ic_shape);
auto pph = p.add_parameter("pph", pph_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(
migraphx::op::lstm{
hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r,
bias,
und,
ih,
ic,
pph);
return p;
}
};
struct test_lstm_forward_3args_und
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(
migraphx::op::lstm{
hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r,
und,
und,
und,
und,
und);
return p;
}
};
struct test_lstm_forward_3args
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
p.add_instruction(
migraphx::op::lstm{
hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r);
return p;
}
};
struct test_lstm_forward_seq1
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 1;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
p.add_instruction(
migraphx::op::lstm{
hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r);
return p;
}
};
struct test_lstm_forward_default_actv
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 1;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
p.add_instruction(
migraphx::op::lstm{hidden_size, {}, migraphx::op::rnn_direction::forward, clip},
seq,
w,
r);
return p;
}
};
struct test_lstm_forward_default_actv1
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(
migraphx::op::lstm{
hidden_size, {migraphx::op::sigmoid{}}, migraphx::op::rnn_direction::forward, clip},
seq,
w,
r,
bias,
und,
ih);
return p;
}
};
struct test_lstm_reverse_last
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape);
auto ic = p.add_parameter("ic", ic_shape);
auto pph = p.add_parameter("pph", pph_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto output = p.add_instruction(
migraphx::op::lstm{
hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
r,
bias,
und,
ih,
ic,
pph);
p.add_instruction(migraphx::op::rnn_last_output{}, output);
return p;
}
};
struct test_lstm_reverse_3args
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
p.add_instruction(
migraphx::op::lstm{
hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
r);
return p;
}
};
struct test_lstm_reverse_3args_cell_output
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto hs = p.add_instruction(
migraphx::op::lstm{
hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
r);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, hs);
return p;
}
};
struct test_lstm_bidirct_last
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 2;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape);
auto ic = p.add_parameter("ic", ic_shape);
auto pph = p.add_parameter("pph", pph_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto output = p.add_instruction(
migraphx::op::lstm{
hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
und,
ih,
ic,
pph);
p.add_instruction(migraphx::op::rnn_last_output{}, output);
return p;
}
};
struct test_lstm_bidirct_hs
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 2;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(migraphx::op::lstm{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
und,
ih);
return p;
}
};
struct test_lstm_bidirct_3args_und
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 2;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(
migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
und,
und,
und,
und,
und);
return p;
}
};
struct test_lstm_bidirct_3args
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 2;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
p.add_instruction(migraphx::op::lstm{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r);
return p;
}
};
struct test_lstm_bidirct_seq1
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 1;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 2;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
p.add_instruction(migraphx::op::lstm{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r);
return p;
}
};
struct test_lstm_bidirct_default_actv
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 1;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 2;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
p.add_instruction(
migraphx::op::lstm{hidden_size, {}, migraphx::op::rnn_direction::bidirectional, clip},
seq,
w,
r);
return p;
}
};
struct test_lstm_bidirct_default_actv1
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 2;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(migraphx::op::lstm{hidden_size,
{migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
und,
ih);
return p;
}
};
struct test_lstm_bidirct_default_actv2
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 2;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 4 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(migraphx::op::lstm{hidden_size,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
und,
ih);
return p;
}
};
int main() int main()
{ {
verify_program<test_relu_lrn>(); verify_program<test_relu_lrn>();
...@@ -2169,6 +2894,7 @@ int main() ...@@ -2169,6 +2894,7 @@ int main()
verify_program<test_gemm_transposea>(); verify_program<test_gemm_transposea>();
verify_program<test_gemm_transposeab>(); verify_program<test_gemm_transposeab>();
verify_program<test_contiguous>(); verify_program<test_contiguous>();
verify_program<test_eliminate_contiguous>();
verify_program<test_transpose>(); verify_program<test_transpose>();
verify_program<test_batchnorm_inference>(); verify_program<test_batchnorm_inference>();
verify_program<test_batchnorm_inference_2>(); verify_program<test_batchnorm_inference_2>();
...@@ -2204,4 +2930,22 @@ int main() ...@@ -2204,4 +2930,22 @@ int main()
verify_program<test_gru_bidirct_seq1>(); verify_program<test_gru_bidirct_seq1>();
verify_program<test_gru_bidirct_default_actv>(); verify_program<test_gru_bidirct_default_actv>();
verify_program<test_gru_bidirct_default_actv1>(); verify_program<test_gru_bidirct_default_actv1>();
verify_program<test_lstm_forward_last>();
verify_program<test_lstm_forward_hs>();
verify_program<test_lstm_forward_3args_und>();
verify_program<test_lstm_forward_3args>();
verify_program<test_lstm_forward_seq1>();
verify_program<test_lstm_forward_default_actv>();
verify_program<test_lstm_forward_default_actv1>();
verify_program<test_lstm_reverse_last>();
verify_program<test_lstm_reverse_3args>();
verify_program<test_lstm_reverse_3args_cell_output>();
verify_program<test_lstm_bidirct_last>();
verify_program<test_lstm_bidirct_hs>();
verify_program<test_lstm_bidirct_3args_und>();
verify_program<test_lstm_bidirct_3args>();
verify_program<test_lstm_bidirct_seq1>();
verify_program<test_lstm_bidirct_default_actv>();
verify_program<test_lstm_bidirct_default_actv1>();
verify_program<test_lstm_bidirct_default_actv2>();
} }
add-fp16-example:m

0
12"Add test-add-fp16* 
*|B0* 
*B1Z
0


Z
1


b
2


B
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment