Unverified Commit 352c2465 authored by shivadbhavsar, committed by GitHub

fix issues with compiling lstm ops in fp16 mode (#1450)

Currently, quantizing a program with rnn layers to fp16 results in a segmentation fault due to a "convert" operation being applied to an "undefined" instruction. Once the placeholder is wrapped in a convert, passes that detect omitted optional inputs by checking for ins->name() == "undefined" no longer recognize it.

The following changes fix this issue (a minimal sketch of the pattern follows the list):

- Added an is_undefined method to the instruction class that returns true if the instruction is an "undefined" op or if all of its inputs are (recursively) undefined.
- Updated the rewrite_rnn pass to use the new is_undefined method rather than comparing ins->name() against "undefined".
- Updated the dead_code_elimination pass to also use this method rather than only checking the instruction name.
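
A minimal sketch of the pattern, not part of this commit and mirroring the check_undefined test added below: fp16 quantization (normally via migraphx::quantize_fp16) wraps every input, including the "undefined" placeholders used for omitted optional rnn/lstm arguments, in a "convert" instruction, so a name check sees "convert" while the new recursive check still reports the argument as undefined.

#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/shape.hpp>
#include <cassert>

int main()
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    // Placeholder produced for an omitted optional rnn/lstm argument
    auto und = mm->add_instruction(migraphx::make_op("undefined"));
    // fp16 quantization inserts a convert on top of every input, including the placeholder
    // (inserted by hand here to keep the sketch small)
    auto cvt = mm->add_instruction(
        migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), und);
    // A name-based check no longer sees "undefined" ...
    assert(cvt->name() == "convert");
    // ... but the new recursive check still reports the argument as undefined,
    // so rewrite_rnn and dead_code_elimination treat it as an omitted input.
    assert(cvt->is_undefined());
}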
parent 37c3c4a9
@@ -51,8 +51,8 @@ void dead_code_elimination::apply(module& m) const
         // Skip instruction with empty shape as output unless its [dynamic, builtin, undefined,
         // identity, allocate]
         if((not i->get_shape().dynamic() and i->get_shape().elements() == 0) and
-           i->name().front() != '@' and
-           not contains({"undefined", "identity", "allocate"}, i->name()))
+           not(i->name().front() == '@') and not contains({"identity", "allocate"}, i->name()) and
+           not i->is_undefined())
             continue;
         assert(std::distance(m.begin(), i) <= std::distance(m.begin(), last));
         std::unordered_set<instruction_ref> visited;
@@ -121,6 +121,8 @@ struct instruction
     bool can_eval() const;

+    bool is_undefined() const;
+
     argument eval(bool check_eval = true) const;

     void finalize(context& ctx);
@@ -302,6 +302,24 @@ void instruction::replace_mod_argument(module_ref old, module_ref new_mod)
     std::replace(module_args.begin(), module_args.end(), old, new_mod);
 }

+bool instruction::is_undefined() const
+{
+    if(op.name() == "undefined")
+    {
+        return true;
+    }
+    else if(this->inputs().empty())
+    {
+        return false;
+    }
+    else
+    {
+        return std::all_of(this->inputs().begin(), this->inputs().end(), [](auto arg) {
+            return arg->is_undefined();
+        });
+    }
+}
+
 bool instruction::can_eval() const
 {
     if(op.name() == "@literal")
@@ -92,7 +92,7 @@ void rewrite_rnn::apply_vanilla_rnn(module& m, instruction_ref ins) const
     // process sequence length
     instruction_ref seq_lens = m.end();
-    if((args.size() >= 5) && args[4]->name() != "undefined")
+    if((args.size() >= 5) and not args[4]->is_undefined())
     {
         seq_lens = args[4];
     }
@@ -117,7 +117,7 @@ void rewrite_rnn::apply_vanilla_rnn(module& m, instruction_ref ins) const
     // process bias
     instruction_ref bias_forward = m.end();
     instruction_ref bias_reverse = m.end();
-    if(args.size() >= 4 && args[3]->name() != "undefined")
+    if(args.size() >= 4 and not args[3]->is_undefined())
     {
         bias_forward = m.insert_instruction(
             ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[3]);
@@ -129,7 +129,7 @@ void rewrite_rnn::apply_vanilla_rnn(module& m, instruction_ref ins) const
     // or the 5th one (if the sequence len argument is ignored)
     instruction_ref ih_forward{};
     instruction_ref ih_reverse{};
-    if(args.size() == 6 && args[5]->name() != "undefined")
+    if(args.size() == 6 and not args[5]->is_undefined())
     {
         ih_forward = m.insert_instruction(
             ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[5]);
@@ -195,14 +195,14 @@ void rewrite_rnn::apply_vanilla_rnn(module& m, instruction_ref ins) const
     // process bias and initial hidden state
     instruction_ref bias = m.end();
-    if(args.size() >= 4 && args[3]->name() != "undefined")
+    if(args.size() >= 4 and not args[3]->is_undefined())
     {
         bias = args[3];
     }

     // process intial hidden state
     instruction_ref ih;
-    if(args.size() == 6 && args[5]->name() != "undefined")
+    if(args.size() == 6 and not args[5]->is_undefined())
     {
         ih = args[5];
     }
@@ -398,7 +398,7 @@ void rewrite_rnn::apply_gru(module& m, instruction_ref ins) const
     // process sequence length
     instruction_ref seq_lens = m.end();
-    if((args.size() >= 5) && args[4]->name() != "undefined")
+    if((args.size() >= 5) and not args[4]->is_undefined())
     {
         seq_lens = args[4];
     }
@@ -423,7 +423,7 @@ void rewrite_rnn::apply_gru(module& m, instruction_ref ins) const
     // bias
     instruction_ref bias_forward = m.end();
     instruction_ref bias_reverse = m.end();
-    if(args.size() >= 4 && args[3]->name() != "undefined")
+    if(args.size() >= 4 and not args[3]->is_undefined())
     {
         bias_forward = m.insert_instruction(
             ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[3]);
@@ -434,7 +434,7 @@ void rewrite_rnn::apply_gru(module& m, instruction_ref ins) const
     // intial hidden state
     instruction_ref ih_forward{};
     instruction_ref ih_reverse{};
-    if(args.size() == 6 && args[5]->name() != "undefined")
+    if(args.size() == 6 and not args[5]->is_undefined())
     {
         ih_forward = m.insert_instruction(
             ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[5]);
@@ -501,14 +501,14 @@ void rewrite_rnn::apply_gru(module& m, instruction_ref ins) const
     // bias
     instruction_ref bias = m.end();
-    if(args.size() >= 4 && args[3]->name() != "undefined")
+    if(args.size() >= 4 and not args[3]->is_undefined())
     {
         bias = args[3];
     }

     // intial hidden state
     instruction_ref ih{};
-    if(args.size() == 6 && args[5]->name() != "undefined")
+    if(args.size() == 6 and not args[5]->is_undefined())
     {
         ih = args[5];
     }
@@ -784,7 +784,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
     // process sequence length
     instruction_ref seq_lens = m.end();
-    if((args.size() >= 5) && args[4]->name() != "undefined")
+    if((args.size() >= 5) and not args[4]->is_undefined())
     {
         seq_lens = args[4];
     }
@@ -813,7 +813,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
     // process bias
     instruction_ref bias_forward = m.end();
     instruction_ref bias_reverse = m.end();
-    if(args.size() >= 4 && args[3]->name() != "undefined")
+    if(args.size() >= 4 and not args[3]->is_undefined())
     {
         bias_forward = m.insert_instruction(
             ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[3]);
@@ -824,7 +824,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
     // process intial hidden state, it is the 6th argument
     instruction_ref ih_forward{};
     instruction_ref ih_reverse{};
-    if(args.size() >= 6 && args[5]->name() != "undefined")
+    if(args.size() >= 6 and not args[5]->is_undefined())
     {
         ih_forward = m.insert_instruction(
             ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[5]);
@@ -840,7 +840,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
     // process initial cell value
     instruction_ref ic_forward{};
     instruction_ref ic_reverse{};
-    if(args.size() >= 7 && args[6]->name() != "undefined")
+    if(args.size() >= 7 and not args[6]->is_undefined())
     {
         ic_forward = m.insert_instruction(
             ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[6]);
@@ -856,7 +856,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
     // process weight of the peephole
     instruction_ref pph_forward = m.end();
     instruction_ref pph_reverse = m.end();
-    if(args.size() == 8 && args[7]->name() != "undefined")
+    if(args.size() == 8 and not args[7]->is_undefined())
     {
         pph_forward = m.insert_instruction(
             ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[7]);
@@ -940,14 +940,14 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
     // bias
     instruction_ref bias = m.end();
-    if(args.size() >= 4 && args[3]->name() != "undefined")
+    if(args.size() >= 4 and not args[3]->is_undefined())
     {
         bias = args[3];
     }

     // initial hidden state
     instruction_ref ih{};
-    if(args.size() >= 6 && args[5]->name() != "undefined")
+    if(args.size() >= 6 and not args[5]->is_undefined())
     {
         ih = args[5];
     }
@@ -958,7 +958,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
     // initial cell value
     instruction_ref ic{};
-    if(args.size() >= 7 && args[6]->name() != "undefined")
+    if(args.size() >= 7 and not args[6]->is_undefined())
     {
         ic = args[6];
     }
@@ -969,7 +969,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
     // process weight of the peephole
     instruction_ref pph = m.end();
-    if(args.size() == 8 && args[7]->name() != "undefined")
+    if(args.size() == 8 and not args[7]->is_undefined())
     {
         pph = args[7];
     }
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/program.hpp>
#include <migraphx/make_op.hpp>

#include "test.hpp"

TEST_CASE(check_undefined)
{
    migraphx::module m;
    auto und = m.add_instruction(migraphx::make_op("undefined"));
    auto cov = m.add_instruction(
        migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), und);
    auto abs = m.add_instruction(migraphx::make_op("abs"), cov);

    migraphx::shape xs{migraphx::shape::float_type, {2, 3}};
    std::vector<float> datax = {1, 2, 3, 4, 5, 6};
    auto lit = m.add_literal(migraphx::literal(xs, datax));
    auto mul = m.add_instruction(migraphx::make_op("mul"), lit, lit);

    EXPECT(und->is_undefined());
    EXPECT(cov->is_undefined());
    EXPECT(abs->is_undefined());
    EXPECT(not lit->is_undefined());
    EXPECT(not mul->is_undefined());
}

int main(int argc, const char* argv[]) { test::run(argc, argv); }
@@ -31,6 +31,7 @@
 #include <migraphx/onnx.hpp>
 #include <migraphx/make_op.hpp>
+#include <migraphx/quantization.hpp>
 #include <migraphx/serialize.hpp>
 #include "test.hpp"
@@ -932,6 +933,90 @@ TEST_CASE(rnn_bidirectional)
     }
 }
TEST_CASE(rnn_fp16)
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
std::vector<float> w_data{0.4691,
0.3185,
-0.2227,
0.4423,
-0.0609,
-0.2803,
0.1744,
0.3146,
0.4049,
-0.3973,
-0.0890,
-0.1636};
std::vector<float> r_data{-0.0456,
0.1061,
0.1574,
-0.4928,
-0.4300,
-0.1909,
-0.0225,
-0.2668,
0.1840,
-0.4453,
-0.4896,
0.1302,
-0.0929,
0.3545,
-0.4981,
0.0616};
std::vector<float> bias_data{
-0.4938, 0.4355, -0.3186, 0.2094, 0.1037, -0.1071, 0.4504, -0.3990};
std::vector<float> ih_data(num_dirct * batch_size * hidden_size, 0);
std::vector<float> input(seq_len * batch_size * input_size, 0);
input[0] = input[1] = 1.0;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
float clip = 0.0f;
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_literal(migraphx::literal{in_shape, input});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto seq_half = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), seq);
auto w_half = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), w);
auto r_half = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), r);
auto out_hs = mm->add_instruction(
migraphx::make_op("rnn",
{{"hidden_size", hidden_size},
{"actv_func", {}},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip}}),
seq_half,
w_half,
r_half);
mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs);
p.compile(migraphx::ref::target{});
auto last_output = p.eval({}).back();
std::vector<float> last_output_data;
last_output.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); });
std::vector<float> last_output_data_gold{
0.2935145, -0.23719997, -0.31123261, -0.18357255, 0., 0., 0., 0.};
EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold, 5e4));
}
 TEST_CASE(gru_forward)
 {
     std::size_t batch_size = 2;
@@ -2797,6 +2882,116 @@ TEST_CASE(gru_bidirectional_seq_1)
     EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
 }
TEST_CASE(gru_fp16)
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, input_size}};
std::vector<float> w_data{
0.3485, -0.0378, -0.1782, 0.1416, -0.3096, -0.2212, -0.3883, 0.1983, -0.2418,
0.1480, -0.3255, 0.1359, -0.3551, -0.3605, -0.3482, -0.1424, -0.0495, -0.1640,
-0.1979, -0.2577, -0.4097, -0.1211, -0.0412, 0.1801, 0.1721, -0.4327, -0.0498,
0.2628, -0.1573, -0.1577, 0.2759, -0.2023, -0.1185, -0.2136, 0.1294, -0.2331,
0.0701, 0.4316, 0.0480, 0.0247, -0.0166, -0.2729, 0.1712, -0.3984, -0.3905};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, hidden_size}};
std::vector<float> r_data{
0.2848, -0.2851, -0.3466, -0.1718, -0.1492, -0.0082, 0.2452, -0.0401, 0.3399, 0.2529,
-0.0953, -0.0903, -0.1518, -0.1373, 0.3848, -0.0130, -0.4339, 0.0406, -0.1926, -0.1131,
0.4285, -0.0013, 0.2243, 0.2752, 0.1776, -0.1720, 0.0822, -0.0295, 0.1062, -0.2721,
-0.2736, -0.1826, 0.3541, -0.4259, 0.2188, 0.0706, 0.3650, 0.3947, 0.2522, 0.2179,
-0.0744, 0.2122, -0.4346, 0.2760, 0.4076, 0.1183, -0.1500, -0.1704, 0.3090, -0.0706,
-0.2442, 0.3021, 0.1680, 0.0783, -0.3754, -0.3469, -0.2972, -0.0170, 0.4143, 0.3801,
0.3852, -0.1170, -0.2937, 0.2979, -0.1357, 0.4257, 0.3884, -0.2916, 0.1071, 0.0934,
0.3645, -0.4310, -0.3480, 0.0702, -0.1558};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
std::vector<float> bias_data{
0.0560, 0.0310, -0.1669, -0.0781, 0.1793, -0.1758, 0.3173, -0.1650, -0.3732, 0.2946,
-0.0912, 0.3118, 0.1391, 0.2755, 0.2695, -0.1059, -0.2357, 0.3629, -0.2534, -0.0494,
0.0556, 0.0881, -0.2592, -0.2213, 0.2310, -0.4044, 0.1801, 0.1438, 0.3108, -0.3607};
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
std::vector<float> input{-0.8432,
-0.9887,
1.3041,
-2.6430,
-0.3306,
-0.8504,
-0.3933,
0.5151,
-0.2951,
0.0093,
-1.1948,
-0.1239,
0.0373,
1.3211,
0.7854,
-0.4838,
-1.0536,
-0.2529};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
std::vector<float> ih_data{
-0.0468, 0.5691, -0.0882, 0.8340, 0.1483, -0.3902, -0.5348, 0.4178, 1.0175, 0.9212};
float clip = 0.0f;
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_literal(migraphx::literal{in_shape, input});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
auto und = mm->add_instruction(migraphx::make_op("undefined"));
auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data});
auto seq_half = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), seq);
auto w_half = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), w);
auto r_half = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), r);
auto bias_half = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), bias);
auto und_half = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), und);
auto ih_half = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), ih);
mm->add_instruction(
migraphx::make_op("gru",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{
migraphx::make_op("sigmoid"), migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip},
{"linear_before_reset", 1}}),
seq_half,
w_half,
r_half,
bias_half,
und_half,
ih_half);
p.compile(migraphx::ref::target{});
auto hs_concat = p.eval({}).back();
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{-0.27298412, 0.42363745, -0.09368783, 0.4823072, -0.02183238,
-0.6873896, 0.16144305, 0.31932795, 0.6104771, 0.79759157,
-0.31791314, 0.5249062, 0.08800987, 0.46404213, -0.11872687,
-0.26210734, 0.34448293, -0.0176422, 0.48523626, 0.60002893,
-0.3969709, 0.43360898, 0.35775262, 0.23280787, -0.52179873,
-0.21944991, 0.4535257, -0.13735442, 0.51757574, 0.50380427};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold, 5e4));
}
 TEST_CASE(lstm_forward)
 {
     std::size_t batch_size = 3;
@@ -4651,4 +4846,148 @@ TEST_CASE(lstm_bidirectional_actv_func)
     }
 }
TEST_CASE(lstm_fp16)
{
std::size_t batch_size = 3;
std::size_t seq_len = 4;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
std::vector<float> w_data{
0.1236, -0.3942, 0.4149, 0.0795, 0.4934, -0.2858, 0.2602, -0.3098, 0.0567, 0.3344,
0.3607, -0.0551, 0.4952, 0.3799, 0.0630, -0.3532, 0.0023, -0.0592, 0.4267, 0.2382,
-0.0784, -0.0032, -0.2476, -0.0206, -0.4963, 0.4837, 0.0827, 0.0123, -0.1203, -0.0279,
-0.0049, 0.4721, -0.3564, -0.1286, 0.4090, -0.0504, 0.0575, -0.2138, 0.1071, 0.1976,
-0.0758, 0.0139, -0.0761, 0.3991, -0.2965, -0.4845, -0.1496, 0.3285};
std::vector<float> r_data{
0.1237, 0.1229, -0.0766, -0.1144, -0.1186, 0.2922, 0.2478, 0.3159, -0.0522, 0.1685,
-0.4621, 0.1728, 0.0670, -0.2458, -0.3835, -0.4589, -0.3109, 0.4908, -0.0133, -0.1858,
-0.0590, -0.0347, -0.2353, -0.0671, -0.3812, -0.0004, -0.1432, 0.2406, 0.1033, -0.0265,
-0.3902, 0.0755, 0.3733, 0.4383, -0.3140, 0.2537, -0.1818, -0.4127, 0.3506, 0.2562,
0.2926, 0.1620, -0.4849, -0.4861, 0.4426, 0.2106, -0.0005, 0.4418, -0.2926, -0.3100,
0.1500, -0.0362, -0.3801, -0.0065, -0.0631, 0.1277, 0.2315, 0.4087, -0.3963, -0.4161,
-0.2169, -0.1344, 0.3468, -0.2260};
std::vector<float> bias_data{0.0088, 0.1183, 0.1642, -0.2631, -0.1330, -0.4008, 0.3881,
-0.4407, -0.2760, 0.1274, -0.0083, -0.2885, 0.3949, -0.0182,
0.4445, 0.3477, 0.2266, 0.3423, -0.0674, -0.4067, 0.0807,
0.1109, -0.2036, 0.1782, -0.2467, -0.0730, -0.4216, 0.0316,
-0.3025, 0.3637, -0.3181, -0.4655};
std::vector<float> input_data{
-0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331,
-0.0910, 1.2122, -0.1952, 0.4661, 0.6494, 2.1332, -1.0972, 0.9816, 0.1122,
0.3577, 1.3508, -0.5366, 1.7449, 0.5483, -0.0701, -0.4100, -2.2344, 0.3685,
0.4583, 2.3794, 1.0372, -0.8887, 0.7892, -0.4012, -0.2818, -2.3374, 1.5310};
std::vector<float> ih_data{1.9104,
-1.9004,
0.3337,
0.5741,
0.5671,
0.0458,
0.4514,
-0.8968,
-0.9201,
0.1962,
0.5771,
-0.5332};
std::vector<float> ic_data{0.9569,
-0.5981,
1.1312,
1.0945,
1.1055,
-0.1212,
-0.9097,
0.7831,
-1.6991,
-1.9498,
-1.2567,
-0.4114};
std::vector<float> pph_data{1.84369764,
0.68413646,
-0.44892886,
-1.50904413,
0.3860796,
-0.52186625,
1.08474445,
-1.80867321,
1.32594529,
0.4336262,
-0.83699064,
0.49162736};
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_literal(migraphx::literal{in_shape, input_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data});
auto und = mm->add_instruction(migraphx::make_op("undefined"));
auto seq_half = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), seq);
auto w_half = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), w);
auto r_half = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), r);
auto bias_half = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), bias);
auto ih_half = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), ih);
auto ic_half = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), ic);
auto und_half = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), und);
mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip},
{"input_forget", 0}}),
seq_half,
w_half,
r_half,
bias_half,
und_half,
ih_half,
ic_half,
und_half);
p.compile(migraphx::ref::target{});
auto hs_concat = p.eval({}).back();
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{
0.0417273, -0.272355, 0.206765, 0.223879, 0.138193, -0.0322939, -0.0891815,
0.15773, 0.19139, -0.127708, -0.409371, -0.136186, 0.0742487, -0.0800085,
0.259897, 0.0670196, 0.184266, 0.0610048, -0.138041, 0.0963885, 0.0213755,
-0.146027, -0.0324509, -0.0620429, -0.00532985, 0.0440265, 0.29654, -0.0463156,
0.0498799, 0.125772, 0.0533032, -0.131413, 0.0988431, -0.018085, -0.159434,
0.030266, -0.0847427, 0.0874114, 0.304256, -0.0585745, -0.0223018, 0.131113,
0.135643, -0.0566208, 0.142701, 0.0342236, -0.198664, 0.0702607};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold, 5e4));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }