Unverified Commit 90200619 authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Rnn variable seq lengths (#517)



* code backup

* clang format

* fix compiling errors

* clang format

* rename a few files

* rename a few files

* fix variable bugs

* clang format

* add an operator to shift input sequences

* clang format

* fixed a bug

* clang format

* fixed a bug

* clang format

* code backup

* clang format

* code backup

* clang format

* code backup

* clang format

* refine code related lstm operator optimization

* clang format

* fix various bugs

* clang format

* fixed a bug in rewrite_lstm

* clang format

* fixed another bug

* refine two operator names

* clang format

* refine file names

* fix cppcheck error

* clang format

* fix cppcheck error

* clang format

* fix cppcheck error

* fixed review comments

* clang format

* add unit tests

* clang format

* add unit tests

* clang format

* refine unit tests for better coverage

* clang format

* fixed a bug

* fix cppcheck error

* fix review comments

* clang format

* rename two operators according to review comments

* clang format

* fix review comments

* clang format

* fix review comments

* clang format

* fix review comments

* fix a cppcheck error

* clang format

* fix review comments

* clang format
Co-authored-by: default avatarShucai Xiao <scxiao@prj47-rack-99.local.lan>
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 369b9f60
#ifndef MIGRAPHX_GUARD_OPERATORS_RNN_LAST_CELL_OUTPUT_HPP
#define MIGRAPHX_GUARD_OPERATORS_RNN_LAST_CELL_OUTPUT_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct lstm_last_cell_output
struct rnn_last_cell_output
{
std::string name() const { return "lstm_last_cell_output"; }
std::string name() const { return "rnn_last_cell_output"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
auto dims = inputs[0].lens();
// remove the first dimension, remaing are output shape
......
#ifndef MIGRAPHX_GUARD_OPERATORS_RNN_LAST_OUTPUT_HPP
#define MIGRAPHX_GUARD_OPERATORS_RNN_LAST_OUTPUT_HPP
#ifndef MIGRAPHX_GUARD_OPERATORS_RNN_LAST_HS_OUTPUT_HPP
#define MIGRAPHX_GUARD_OPERATORS_RNN_LAST_HS_OUTPUT_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct rnn_last_output
struct rnn_last_hs_output
{
std::string name() const { return "rnn_last_output"; }
std::string name() const { return "rnn_last_hs_output"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
auto dims = inputs[0].lens();
// remove the first dimension, remaing are output shape
......
#ifndef MIGRAPHX_GUARD_OPERATORS_RNN_VAR_SL_LAST_OUTPUT_HPP
#define MIGRAPHX_GUARD_OPERATORS_RNN_VAR_SL_LAST_OUTPUT_HPP
#include <array>
#include <migraphx/op/common.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
// Extracts the last valid hidden state of an RNN run with variable
// per-batch sequence lengths. Inputs: {hidden_states, sequence_lengths}.
struct rnn_var_sl_last_output
{
    rnn_direction direction = rnn_direction::forward;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.direction, "direction"));
    }

    std::string name() const { return "rnn_var_sl_last_output"; }

    shape compute_shape(std::vector<shape> inputs) const
    {
        // Validate the input count like the sibling rnn shape ops do: the
        // op consumes the hidden states and the per-batch sequence lengths.
        check_shapes{inputs, *this}.has(2);
        auto dims = inputs[0].lens();
        // remove the first (sequence) dimension; remaining dims are the output shape
        dims.erase(dims.begin());
        return {inputs[0].type(), dims};
    }
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_RNN_VARIABLE_SEQ_LENS_HPP
#define MIGRAPHX_GUARD_OPERATORS_RNN_VARIABLE_SEQ_LENS_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
// Shifts the hidden-state output of a variable-sequence-length RNN so each
// batch's valid timesteps start at t = 0, zero-filling positions past the
// batch's sequence length.
struct rnn_var_sl_shift_output
{
    // Name of the rnn output this op post-processes.
    std::string output_name = "hidden_states";
    rnn_direction direction = rnn_direction::forward;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        // The reflect key must be the member's name ("output_name"); the
        // previous key "hidden_states" was the member's default *value*,
        // inconsistent with every other reflect in this file.
        return pack(f(self.output_name, "output_name"), f(self.direction, "direction"));
    }

    std::string name() const { return "rnn_var_sl_shift_output"; }

    // Inputs: {hidden_states, sequence_lengths}; output has the same shape
    // as the hidden states.
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(2);
        return inputs[0];
    }

    argument compute(const shape& output_shape, std::vector<argument> args) const
    {
        argument result{output_shape};
        int64_t max_len = static_cast<int64_t>(output_shape.lens()[0]);
        visit_all(result, args[0])([&](auto output, auto input) {
            using value_type = typename decltype(output)::value_type;
            args[1].visit([&](auto seq_lens) {
                par_for(output_shape.elements(), [&](auto i) {
                    // idx is (seq, direction, batch, hidden); see index use below.
                    auto idx = output_shape.multi(i);
                    auto batch_id = idx[2];
                    auto d = idx[1];
                    auto t = idx[0];
                    auto sl = seq_lens[batch_id];
                    value_type val = value_type{0};
                    if(t < sl)
                    {
                        auto in_idx = idx;
                        // The reverse pass (or the d == 1 lane of a
                        // bidirectional run) was right-aligned, so read it
                        // back with an offset of (max_len - sl).
                        int offset = (direction == rnn_direction::reverse or d == 1) ? 1 : 0;
                        in_idx[0] += offset * (max_len - sl);
                        val = input(in_idx.begin(), in_idx.end());
                    }
                    // Positions at or past the sequence length stay zero.
                    output(idx.begin(), idx.end()) = val;
                });
            });
        });
        return result;
    }
};
// Right-aligns each batch's input sequence: a sequence of length sl is moved
// so it ends at t = max_len - 1, with the leading (max_len - sl) timesteps
// zero-filled. Inputs: {sequence, sequence_lengths}.
struct rnn_var_sl_shift_sequence
{
    std::string name() const { return "rnn_var_sl_shift_sequence"; }

    // Output shape equals the input sequence shape.
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(2);
        return inputs[0];
    }

    argument compute(const shape& output_shape, std::vector<argument> args) const
    {
        argument result{output_shape};
        int64_t max_len = static_cast<int64_t>(output_shape.lens()[0]);
        visit_all(result, args[0])([&](auto output, auto input) {
            using value_type = typename decltype(output)::value_type;
            args[1].visit([&](auto seq_lens) {
                par_for(output_shape.elements(), [&](auto i) {
                    // idx[0] is the timestep, idx[1] the batch index.
                    auto idx = output_shape.multi(i);
                    auto b = idx[1];
                    auto t = idx[0];
                    auto sl = seq_lens[b];
                    value_type val = value_type{0};
                    // Only the trailing sl timesteps carry data; copy them
                    // from the front of the input sequence.
                    if(t >= max_len - sl)
                    {
                        auto in_idx = idx;
                        in_idx[0] -= (max_len - sl);
                        val = input(in_idx.begin(), in_idx.end());
                    }
                    output(idx.begin(), idx.end()) = val;
                });
            });
        });
        return result;
    }
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -66,7 +66,9 @@
#include <migraphx/op/reshape.hpp>
#include <migraphx/op/rnn.hpp>
#include <migraphx/op/rnn_last_cell_output.hpp>
#include <migraphx/op/rnn_last_output.hpp>
#include <migraphx/op/rnn_last_hs_output.hpp>
#include <migraphx/op/rnn_variable_seq_lens.hpp>
#include <migraphx/op/rnn_var_sl_last_output.hpp>
#include <migraphx/op/round.hpp>
#include <migraphx/op/rsqrt.hpp>
#include <migraphx/op/scalar.hpp>
......
......@@ -6,6 +6,7 @@
#include <migraphx/instruction_ref.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/config.hpp>
#include <migraphx/op/common.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -57,6 +58,23 @@ struct rewrite_rnn
const operation& actv_func3) const;
std::vector<operation> lstm_actv_funcs(instruction_ref ins) const;
bool is_variable_seq_lens(const program& prog, instruction_ref seq_lens) const;
instruction_ref replace_last_hs_output(program& prog,
instruction_ref ins,
instruction_ref seq_lens,
instruction_ref last_hs_output,
op::rnn_direction dirct) const;
void replace_last_cell_output(program& prog,
instruction_ref ins,
instruction_ref seq_lens,
instruction_ref cell_outputs,
instruction_ref last_cell_output,
op::rnn_direction dirct) const;
std::size_t
get_seq_len(const program& prog, instruction_ref input, instruction_ref seq_lens) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -1486,7 +1486,7 @@ struct onnx_parser
std::move(args));
// second output for the last hidden state
auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);
auto last_output = prog.add_instruction(op::rnn_last_hs_output{}, hidden_states);
return {hidden_states, last_output};
}
......@@ -1608,11 +1608,96 @@ struct onnx_parser
std::move(args));
// second output for last gru output
auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);
auto last_output = prog.add_instruction(op::rnn_last_hs_output{}, hidden_states);
return {hidden_states, last_output};
}
// Expand a partially specified list of lstm activation-function names to
// the count the operator requires: 6 for bidirectional, 3 otherwise.
// ONNX gives no spec for short lists, so we pad by repetition: up to three
// names are padded to three by repeating the last name and the resulting
// triple is reused for both directions; four or five names are padded to
// six by repeating the last name. Empty or already-complete lists are left
// untouched. This scheme may need revisiting later.
void lstm_actv_functions(op::rnn_direction dirct, std::vector<std::string>& actv_func_names)
{
    auto& names = actv_func_names;
    const auto n = names.size();
    if(dirct == op::rnn_direction::bidirectional)
    {
        if(n == 0 or n >= 6)
            return;
        if(n <= 3)
        {
            // Pad to one triple, then duplicate it for the reverse direction.
            while(names.size() < 3)
                names.push_back(names.back());
            // Copy the first three out before inserting: inserting a range
            // of a vector into itself risks iterator invalidation.
            std::vector<std::string> first_three(names.begin(), names.begin() + 3);
            names.insert(names.end(), first_three.begin(), first_three.end());
        }
        else
        {
            // 4 or 5 names given: repeat the last one until we have six.
            while(names.size() < 6)
                names.push_back(names.back());
        }
    }
    else
    {
        if(n == 0 or n >= 3)
            return;
        // Repeat the last given name until we have three.
        while(names.size() < 3)
            names.push_back(names.back());
    }
}
std::vector<instruction_ref>
parse_lstm(const std::string&, node_info info, std::vector<instruction_ref> args)
{
......@@ -1664,83 +1749,7 @@ struct onnx_parser
});
}
// need 6 activation functions for bidirectional directions
if(dirct == op::rnn_direction::bidirectional)
{
// 6 activation functions are used in the bidirectional
// scenario. No spec is provided in onnx::operator. we
// use the algorithm that: if 1 actv function is provided,
// repeat 1st six times. If 2 actv functins are provided,
// repeat 2nd once, then repeat all three once
// if 3 actv funcs are provide, repeat all three once.
// the same algorithm is used for 4, 5, and 6 actv funcions
// provided. This may need change later
switch(vec_names.size())
{
case 1:
vec_names = {vec_names.at(0),
vec_names.at(0),
vec_names.at(0),
vec_names.at(0),
vec_names.at(0),
vec_names.at(0)};
break;
case 2:
// repeat the 2nd actv func once, then repeat all three another time
vec_names = {vec_names.at(0),
vec_names.at(1),
vec_names.at(1),
vec_names.at(0),
vec_names.at(1),
vec_names.at(1)};
break;
case 3:
// repeat all three actv funcs once
vec_names = {vec_names.at(0),
vec_names.at(1),
vec_names.at(2),
vec_names.at(0),
vec_names.at(1),
vec_names.at(2)};
break;
case 4:
vec_names = {vec_names.at(0),
vec_names.at(1),
vec_names.at(2),
vec_names.at(3),
vec_names.at(3),
vec_names.at(3)};
break;
case 5:
vec_names = {vec_names.at(0),
vec_names.at(1),
vec_names.at(2),
vec_names.at(3),
vec_names.at(4),
vec_names.at(4)};
break;
default: break;
}
}
else
{
switch(vec_names.size())
{
case 1: vec_names = {vec_names.at(0), vec_names.at(0), vec_names.at(0)}; break;
case 2:
// repeat the 2nd actv func once, so we have 3 actv funcs
vec_names = {vec_names.at(0), vec_names.at(1), vec_names.at(1)};
break;
default: break;
}
}
lstm_actv_functions(dirct, vec_names);
auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
return (map_actv_funcs.count(name) == 0);
......@@ -1779,11 +1788,10 @@ struct onnx_parser
auto hidden_states = prog.add_instruction(
op::lstm{hidden_size, vec_actv_funcs, dirct, clip, input_forget}, std::move(args));
// second output for last lstm output
auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);
auto last_output = prog.add_instruction(op::rnn_last_hs_output{}, hidden_states);
// third output for last cell output
auto last_cell_output = prog.add_instruction(op::lstm_last_cell_output{}, hidden_states);
auto last_cell_output = prog.add_instruction(op::rnn_last_cell_output{}, hidden_states);
return {hidden_states, last_output, last_cell_output};
}
......
This diff is collapsed.
......@@ -18,6 +18,7 @@
#include <migraphx/op/softmax.hpp>
#include <migraphx/op/argmax.hpp>
#include <migraphx/op/argmin.hpp>
#include <migraphx/op/rnn_var_sl_last_output.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/par_dfor.hpp>
......@@ -710,6 +711,52 @@ struct cpu_softmax
}
};
// CPU implementation of rnn_var_sl_last_output: for each batch (and each
// direction lane) picks the hidden state at the batch's last valid timestep.
struct cpu_rnn_var_sl_last_output
{
    op::rnn_var_sl_last_output op;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::reflect(self.op, f);
    }

    std::string name() const { return "cpu::rnn_var_sl_last_output"; }

    shape compute_shape(std::vector<shape> inputs) const
    {
        return op.compute_shape(std::move(inputs));
    }

    argument compute(const shape& output_shape, std::vector<argument> args) const
    {
        argument result{output_shape};
        // Companion shape with the input's dims but a sequence length of 1:
        // used to map each flat output index to input coordinates.
        auto comp_lens = args[0].get_shape().lens();
        comp_lens[0]   = 1;
        shape comp_s{output_shape.type(), comp_lens};
        visit_all(result, args[0])([&](auto output, auto input) {
            args[1].visit([&](auto seq_lens) {
                par_for(output_shape.elements(), [&](auto i) {
                    auto idx   = comp_s.multi(i);
                    auto batch = idx[2];
                    // The reverse pass (or the idx[1] == 1 lane of a
                    // bidirectional run) keeps its last output at t = 0;
                    // otherwise it sits at the batch's last valid step.
                    bool from_front =
                        (op.direction == op::rnn_direction::reverse or idx[1] == 1);
                    idx[0]    = from_front ? 0 : seq_lens[batch] - 1;
                    output[i] = input(idx.begin(), idx.end());
                });
            });
        });
        return result;
    }
};
struct cpu_apply
{
program* prog;
......@@ -745,6 +792,8 @@ struct cpu_apply
apply_map["lrn"] = extend_op<cpu_lrn, op::lrn>();
apply_map["pad"] = extend_op<cpu_pad, op::pad>();
apply_map["softmax"] = extend_op<cpu_softmax<op::softmax>, op::softmax>();
apply_map["rnn_var_sl_last_output"] =
extend_op<cpu_rnn_var_sl_last_output, op::rnn_var_sl_last_output>();
}
void apply()
......
......@@ -67,6 +67,7 @@ add_library(migraphx_device
device/sub.cpp
device/tan.cpp
device/tanh.cpp
device/rnn_variable_seq_lens.cpp
)
set_target_properties(migraphx_device PROPERTIES EXPORT_NAME device)
rocm_set_soversion(migraphx_device ${MIGRAPHX_SO_VERSION})
......@@ -117,6 +118,7 @@ add_library(migraphx_gpu
int8_conv_pack.cpp
gemm_impl.cpp
preallocate_param.cpp
rnn_variable_seq_lens.cpp
)
set_target_properties(migraphx_gpu PROPERTIES EXPORT_NAME gpu)
rocm_set_soversion(migraphx_gpu ${MIGRAPHX_SO_VERSION})
......
#include <migraphx/gpu/device/rnn_variable_seq_lens.hpp>
#include <migraphx/gpu/device/nary.hpp>
#include <migraphx/gpu/device/shape.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
// GPU kernel launcher: right-align each batch's input sequence. A sequence
// of length l is shifted so it ends at t = max_len - 1; the leading
// (max_len - l) timesteps are zero-filled.
//   result - preallocated output buffer (same shape as arg_hs)
//   arg_hs - input sequences, dims (seq_len, batch, ...)
//   arg_sl - per-batch sequence lengths
void rnn_var_sl_shift_sequence(hipStream_t stream,
                               const argument& result,
                               const argument& arg_hs,
                               const argument& arg_sl)
{
    auto output_shape = result.get_shape();
    int64_t max_len   = output_shape.lens()[0];
    visit_all(result, arg_hs)([&](auto output, auto input) {
        const auto* in_data = device_cast(input.data());
        auto* out_data      = device_cast(output.data());
        auto out_s          = make_hip_shape<3>(output_shape);
        arg_sl.visit([&](auto sl) {
            const auto* sl_data = device_cast(sl.data());
            gs_launch(stream, output_shape.elements(), 256)([=](auto i) __device__ {
                // idx[0] is the timestep, idx[1] the batch index.
                auto idx = out_s.multi(i);
                auto t   = idx[0];
                auto b   = idx[1];
                auto l   = sl_data[b];
                // Read-then-overwrite deduces the element type from the
                // input; the value itself starts as zero padding.
                auto val = in_data[0];
                val      = 0;
                // Only the trailing l timesteps carry data; copy them from
                // the front of the input sequence.
                if(t >= max_len - l)
                {
                    auto in_idx = idx;
                    in_idx[0] -= (max_len - l);
                    val = in_data[out_s.index(in_idx)];
                }
                out_data[i] = val;
            });
        });
    });
}
// GPU kernel launcher: shift RNN hidden-state output so each batch's valid
// timesteps start at t = 0, zero-filling positions past the batch's
// sequence length. The reverse pass (is_reverse, or the d == 1 lane —
// presumably the reverse lane of a bidirectional run) is read back with an
// offset of (max_len - l) since it was right-aligned.
void rnn_var_sl_shift_output(hipStream_t stream,
                             const argument& result,
                             const argument& arg_hs,
                             const argument& arg_sl,
                             bool is_reverse)
{
    auto output_shape = result.get_shape();
    int64_t max_len   = output_shape.lens()[0];
    visit_all(result, arg_hs)([&](auto output, auto input) {
        const auto* in_data = device_cast(input.data());
        auto* out_data      = device_cast(output.data());
        auto out_s          = make_hip_shape<4>(output_shape);
        arg_sl.visit([&](auto sl) {
            const auto* sl_data = device_cast(sl.data());
            gs_launch(stream, output_shape.elements(), 256)([=](auto i) __device__ {
                // idx is (seq, direction, batch, hidden).
                auto idx = out_s.multi(i);
                auto t   = idx[0];
                auto d   = idx[1];
                auto b   = idx[2];
                auto l   = sl_data[b];
                // Read-then-overwrite deduces the element type from the
                // input; the value itself starts as zero padding.
                auto val = in_data[0];
                val      = 0;
                if(t < l)
                {
                    int offset  = (d == 1 or is_reverse) ? 1 : 0;
                    auto in_idx = idx;
                    in_idx[0] += offset * (max_len - l);
                    val = in_data[out_s.index(in_idx)];
                }
                out_data[i] = val;
            });
        });
    });
}
// GPU kernel launcher: gather the last valid hidden state per batch (and
// per direction lane). The reverse pass (is_reverse, or the d == 1 lane)
// keeps its last output at t = 0; otherwise it sits at t = l - 1 where l
// is the batch's sequence length.
void rnn_var_sl_last_output(hipStream_t stream,
                            const argument& result,
                            const argument& arg_hs,
                            const argument& arg_sl,
                            bool is_reverse)
{
    auto input_shape   = arg_hs.get_shape();
    // Companion shape with the input's dims but a sequence length of 1:
    // used to map each flat output index to input coordinates.
    auto out_comp_lens = input_shape.lens();
    out_comp_lens[0]   = 1;
    shape out_comp_shape{input_shape.type(), out_comp_lens};
    visit_all(result, arg_hs)([&](auto output, auto input) {
        const auto* in_data = device_cast(input.data());
        auto* out_data      = device_cast(output.data());
        arg_sl.visit([&](auto sl) {
            const auto* sl_data = device_cast(sl.data());
            auto in_s           = make_hip_shape<4>(input_shape);
            auto out_s          = make_hip_shape<4>(out_comp_shape);
            gs_launch(stream, result.get_shape().elements(), 256)([=](auto i) __device__ {
                // idx is (seq, direction, batch, hidden) with seq fixed to 1.
                auto idx = out_s.multi(i);
                auto d   = idx[1];
                auto b   = idx[2];
                auto l   = sl_data[b];
                if(is_reverse or d == 1)
                {
                    idx[0] = 0;
                }
                else
                {
                    idx[0] = l - 1;
                }
                out_data[i] = in_data[in_s.index(idx)];
            });
        });
    });
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_RNN_VARIABLE_SEQ_LENS_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_RNN_VARIABLE_SEQ_LENS_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
// Right-align each batch's input sequence to end at the last timestep,
// zero-filling the leading positions.
void rnn_var_sl_shift_sequence(hipStream_t stream,
                               const argument& result,
                               const argument& arg_hs,
                               const argument& arg_sl);
// Shift RNN hidden-state output so each batch's valid timesteps start at
// t = 0, zero-filling positions past the batch's sequence length.
void rnn_var_sl_shift_output(hipStream_t stream,
                             const argument& result,
                             const argument& arg_hs,
                             const argument& arg_sl,
                             bool is_reverse);
// Gather the last valid hidden state per batch (and per direction lane).
void rnn_var_sl_last_output(hipStream_t stream,
                            const argument& result,
                            const argument& arg_hs,
                            const argument& arg_sl,
                            bool is_reverse);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_RNN_VARIABLE_SEQ_LENS_HPP
#define MIGRAPHX_GUARD_RTGLIB_RNN_VARIABLE_SEQ_LENS_HPP
#include <migraphx/shape.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/op/rnn_variable_seq_lens.hpp>
#include <migraphx/op/rnn_var_sl_last_output.hpp>
#include <migraphx/gpu/device/rnn_variable_seq_lens.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct hip_rnn_var_sl_shift_sequence
{
op::rnn_var_sl_shift_sequence op;
std::string name() const { return "gpu::rnn_var_sl_shift_sequence"; }
shape compute_shape(std::vector<shape> inputs) const;
argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
// GPU wrapper for op::rnn_var_sl_shift_output.
struct hip_rnn_var_sl_shift_output
{
    op::rnn_var_sl_shift_output op;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::reflect(self.op, f);
    }

    std::string name() const { return "gpu::rnn_var_sl_shift_output"; }
    shape compute_shape(std::vector<shape> inputs) const;
    argument
    compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;

    // The last input is the preallocated output buffer the computation
    // writes into, so the output aliases that argument.
    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
    {
        return shapes.size() - 1;
    }
};
// GPU wrapper for op::rnn_var_sl_last_output.
struct hip_rnn_var_sl_last_output
{
    op::rnn_var_sl_last_output op;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::reflect(self.op, f);
    }

    // Derives the name from the wrapped op instead of hard-coding it.
    std::string name() const { return "gpu::" + op.name(); }
    shape compute_shape(std::vector<shape> inputs) const;
    argument compute(context& ctx, const shape&, const std::vector<argument>& args) const;

    // The last input is the preallocated output buffer the computation
    // writes into, so the output aliases that argument.
    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
    {
        return shapes.size() - 1;
    }
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -72,6 +72,7 @@
#include <migraphx/gpu/int8_conv_pack.hpp>
#include <migraphx/gpu/prelu.hpp>
#include <migraphx/gpu/recip.hpp>
#include <migraphx/gpu/rnn_variable_seq_lens.hpp>
#include <utility>
#include <functional>
#include <algorithm>
......@@ -184,9 +185,14 @@ struct miopen_apply
add_extend_op<hip_reduce_min, op::reduce_min>("reduce_min");
add_extend_op<hip_reduce_prod, op::reduce_prod>("reduce_prod");
add_extend_op<hip_reduce_sum, op::reduce_sum>("reduce_sum");
add_extend_op<hip_rnn_var_sl_shift_output, op::rnn_var_sl_shift_output>(
"rnn_var_sl_shift_output");
add_extend_op<hip_rnn_var_sl_shift_sequence, op::rnn_var_sl_shift_sequence>(
"rnn_var_sl_shift_sequence");
add_extend_op<hip_rnn_var_sl_last_output, op::rnn_var_sl_last_output>(
"rnn_var_sl_last_output");
add_gemm_op<op::dot>("dot");
add_gemm_op<op::quant_dot>("quant_dot");
add_lrn_op();
add_convolution_op();
add_deconvolution_op();
......
#include <migraphx/gpu/rnn_variable_seq_lens.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device/rnn_variable_seq_lens.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
// Shape check delegates to the reference op after stripping the trailing
// shape, which describes the preallocated output buffer.
shape hip_rnn_var_sl_shift_output::compute_shape(std::vector<shape> inputs) const
{
    std::vector<shape> in_shapes(inputs.begin(), inputs.end() - 1);
    return op.compute_shape(in_shapes);
}

// Launches the device kernel; the last argument is the output buffer.
argument hip_rnn_var_sl_shift_output::compute(context& ctx,
                                              const shape&,
                                              const std::vector<argument>& args) const
{
    const auto& output    = args.back();
    const bool is_reverse = (op.direction == op::rnn_direction::reverse);
    device::rnn_var_sl_shift_output(
        ctx.get_stream().get(), output, args.at(0), args.at(1), is_reverse);
    return output;
}
// Shape check delegates to the reference op after stripping the trailing
// shape, which describes the preallocated output buffer.
shape hip_rnn_var_sl_shift_sequence::compute_shape(std::vector<shape> inputs) const
{
    std::vector<shape> in_shapes(inputs.begin(), inputs.end() - 1);
    return op.compute_shape(in_shapes);
}

// Launches the device kernel; the last argument is the output buffer.
argument hip_rnn_var_sl_shift_sequence::compute(context& ctx,
                                                const shape&,
                                                const std::vector<argument>& args) const
{
    const auto& output = args.back();
    device::rnn_var_sl_shift_sequence(ctx.get_stream().get(), output, args.at(0), args.at(1));
    return output;
}
// Shape check delegates to the reference op after stripping the trailing
// shape, which describes the preallocated output buffer.
shape hip_rnn_var_sl_last_output::compute_shape(std::vector<shape> inputs) const
{
    std::vector<shape> in_shapes(inputs.begin(), inputs.end() - 1);
    return op.compute_shape(in_shapes);
}

// Launches the device kernel; the last argument is the output buffer.
argument hip_rnn_var_sl_last_output::compute(context& ctx,
                                             const shape&,
                                             const std::vector<argument>& args) const
{
    const auto& output    = args.back();
    const bool is_reverse = (op.direction == op::rnn_direction::reverse);
    device::rnn_var_sl_last_output(
        ctx.get_stream().get(), output, args.at(0), args.at(1), is_reverse);
    return output;
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -51,6 +51,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
rewrite_batchnorm{},
dead_code_elimination{},
rewrite_rnn{},
dead_code_elimination{},
rewrite_pooling{},
dead_code_elimination{},
eliminate_common_subexpression{},
......
This diff is collapsed.
......@@ -2580,7 +2580,7 @@ struct test_rnn_forward : verify_program<test_rnn_forward>
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, output);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, output);
return p;
}
......@@ -2622,7 +2622,7 @@ struct test_rnn_forward10 : verify_program<test_rnn_forward10>
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, output);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, output);
return p;
}
......@@ -2663,7 +2663,7 @@ struct test_rnn_two_outputs : verify_program<test_rnn_two_outputs>
bias,
und,
ih);
auto last_hs = p.add_instruction(migraphx::op::rnn_last_output{}, hs);
auto last_hs = p.add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
p.add_return({hs, last_hs});
return p;
......@@ -2850,7 +2850,7 @@ struct test_rnn_5args : verify_program<test_rnn_5args>
r,
bias,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, output);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, output);
return p;
}
......@@ -2892,7 +2892,7 @@ struct test_rnn_bidirectional : verify_program<test_rnn_bidirectional>
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, output);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, output);
return p;
}
......@@ -2933,7 +2933,7 @@ struct test_rnn_bidirectional10 : verify_program<test_rnn_bidirectional10>
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, output);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, output);
return p;
}
......@@ -2968,7 +2968,7 @@ struct test_rnn_bi_3args : verify_program<test_rnn_bi_3args>
seq,
w,
r);
p.add_instruction(migraphx::op::rnn_last_output{}, output);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, output);
return p;
}
......@@ -3012,7 +3012,7 @@ struct test_gru_forward_last : verify_program<test_gru_forward_last>
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, output);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, output);
return p;
}
......@@ -3215,7 +3215,7 @@ struct test_gru_two_outputs : verify_program<test_gru_two_outputs>
seq,
w,
r);
auto last_hs = p.add_instruction(migraphx::op::rnn_last_output{}, hs);
auto last_hs = p.add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
p.add_return({hs, last_hs});
return p;
......@@ -3301,7 +3301,7 @@ struct test_gru_reverse_last : verify_program<test_gru_reverse_last>
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, output);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, output);
return p;
}
......@@ -3377,7 +3377,7 @@ struct test_gru_bidirct_last : verify_program<test_gru_bidirct_last>
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, output);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, output);
return p;
}
......@@ -3616,6 +3616,7 @@ struct test_lstm_forward_last : verify_program<test_lstm_forward_last>
{num_dirct, 4 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape l_shape{migraphx::shape::int32_type, {batch_size}};
migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};
......@@ -3624,9 +3625,9 @@ struct test_lstm_forward_last : verify_program<test_lstm_forward_last>
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape);
auto len = p.add_literal(migraphx::literal(l_shape, {1, 2}));
auto ic = p.add_parameter("ic", ic_shape);
auto pph = p.add_parameter("pph", pph_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto output = p.add_instruction(
migraphx::op::lstm{
......@@ -3638,11 +3639,11 @@ struct test_lstm_forward_last : verify_program<test_lstm_forward_last>
w,
r,
bias,
und,
len,
ih,
ic,
pph);
p.add_instruction(migraphx::op::rnn_last_output{}, output);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, output, len);
return p;
}
......@@ -3801,7 +3802,7 @@ struct test_lstm_two_outputs : verify_program<test_lstm_two_outputs>
seq,
w,
r);
auto last_hs = p.add_instruction(migraphx::op::rnn_last_output{}, hs);
auto last_hs = p.add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
p.add_return({hs, last_hs});
return p;
......@@ -3837,8 +3838,8 @@ struct test_lstm_three_outputs : verify_program<test_lstm_three_outputs>
seq,
w,
r);
auto last_hs = p.add_instruction(migraphx::op::rnn_last_output{}, hs);
auto last_cell = p.add_instruction(migraphx::op::lstm_last_cell_output{}, hs);
auto last_hs = p.add_instruction(migraphx::op::rnn_last_hs_output{}, hs);
auto last_cell = p.add_instruction(migraphx::op::rnn_last_cell_output{}, hs);
p.add_return({hs, last_hs, last_cell});
return p;
......@@ -3995,7 +3996,7 @@ struct test_lstm_reverse_last : verify_program<test_lstm_reverse_last>
ih,
ic,
pph);
p.add_instruction(migraphx::op::rnn_last_output{}, output);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, output);
return p;
}
......@@ -4064,7 +4065,7 @@ struct test_lstm_reverse_3args_cell_output : verify_program<test_lstm_reverse_3a
seq,
w,
r);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, hs);
p.add_instruction(migraphx::op::rnn_last_cell_output{}, hs);
return p;
}
......@@ -4115,7 +4116,7 @@ struct test_lstm_bidirct_last : verify_program<test_lstm_bidirct_last>
ih,
ic,
pph);
p.add_instruction(migraphx::op::rnn_last_output{}, output);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, output);
return p;
}
......@@ -4140,13 +4141,15 @@ struct test_lstm_bidirct_hs : verify_program<test_lstm_bidirct_hs>
{num_dirct, 4 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape sl_shape{migraphx::shape::int32_type, {batch_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
std::vector<int> sl_data{3, 2};
auto sql = p.add_literal(migraphx::literal{migraphx::literal{sl_shape, sl_data}});
p.add_instruction(migraphx::op::lstm{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
......@@ -4156,7 +4159,7 @@ struct test_lstm_bidirct_hs : verify_program<test_lstm_bidirct_hs>
w,
r,
bias,
und,
sql,
ih);
return p;
......@@ -4315,13 +4318,15 @@ struct test_lstm_bidirct_default_actv1 : verify_program<test_lstm_bidirct_defaul
{num_dirct, 4 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape sl_shape{migraphx::shape::int32_type, {batch_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
std::vector<int> sl_data(batch_size, 2);
auto sql = p.add_literal(migraphx::literal{sl_shape, sl_data});
p.add_instruction(migraphx::op::lstm{hidden_size,
{migraphx::op::sigmoid{}},
......@@ -4331,7 +4336,7 @@ struct test_lstm_bidirct_default_actv1 : verify_program<test_lstm_bidirct_defaul
w,
r,
bias,
und,
sql,
ih);
return p;
......
......@@ -61,7 +61,7 @@ TEST_CASE(rnn_test_bidirectional)
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_rnn_bi.onnx");
EXPECT(p == prog);
......@@ -103,7 +103,7 @@ TEST_CASE(rnn_test_one_direction)
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_rnn_forward.onnx");
EXPECT(p == prog);
......@@ -129,7 +129,7 @@ TEST_CASE(rnn_test_one_direction)
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_rnn_reverse.onnx");
EXPECT(p == prog);
......@@ -153,7 +153,7 @@ TEST_CASE(rnn_test_one_direction)
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_rnn_3args.onnx");
EXPECT(p == prog);
......@@ -181,7 +181,7 @@ TEST_CASE(rnn_test_one_direction)
bias,
seq_len,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_rnn_5args.onnx");
EXPECT(p == prog);
......@@ -225,7 +225,7 @@ TEST_CASE(gru_test)
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_gru_forward.onnx");
EXPECT(p == prog);
......@@ -259,7 +259,7 @@ TEST_CASE(gru_test)
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_gru_reverse.onnx");
EXPECT(p == prog);
......@@ -296,7 +296,7 @@ TEST_CASE(gru_test)
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_gru_bi.onnx");
EXPECT(p == prog);
......@@ -335,7 +335,7 @@ TEST_CASE(gru_test_args)
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_gru_3arg.onnx");
EXPECT(p == prog);
......@@ -367,7 +367,7 @@ TEST_CASE(gru_test_args)
bias,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_gru_4arg.onnx");
EXPECT(p == prog);
......@@ -404,7 +404,7 @@ TEST_CASE(gru_test_args)
bias,
seq_len,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_gru_5arg.onnx");
EXPECT(p == prog);
......@@ -450,7 +450,7 @@ TEST_CASE(gru_test_actv_funcs)
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_gru_bi_0.onnx");
EXPECT(p == prog);
......@@ -487,7 +487,7 @@ TEST_CASE(gru_test_actv_funcs)
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_gru_bi_1.onnx");
EXPECT(p == prog);
......@@ -524,7 +524,7 @@ TEST_CASE(gru_test_actv_funcs)
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_gru_bi_2.onnx");
EXPECT(p == prog);
......@@ -561,7 +561,7 @@ TEST_CASE(gru_test_actv_funcs)
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_gru_bi_3.onnx");
EXPECT(p == prog);
......@@ -595,7 +595,7 @@ TEST_CASE(gru_test_actv_funcs)
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_gru_forward_0.onnx");
EXPECT(p == prog);
......@@ -629,7 +629,7 @@ TEST_CASE(gru_test_actv_funcs)
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_gru_reverse_1.onnx");
EXPECT(p == prog);
......@@ -678,7 +678,7 @@ TEST_CASE(lstm_forward)
ih,
ic,
pph);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_forward.onnx");
EXPECT(p == prog);
......@@ -707,7 +707,7 @@ TEST_CASE(lstm_forward)
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_f3args.onnx");
EXPECT(p == prog);
......@@ -764,7 +764,7 @@ TEST_CASE(lstm_forward)
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_last.onnx");
EXPECT(p == prog);
......@@ -793,7 +793,7 @@ TEST_CASE(lstm_forward)
und,
und,
und);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_cell_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_cell.onnx");
EXPECT(p == prog);
......@@ -823,7 +823,7 @@ TEST_CASE(lstm_forward)
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_f4args.onnx");
EXPECT(p == prog);
......@@ -854,8 +854,8 @@ TEST_CASE(lstm_forward)
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_cell_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_f5args.onnx");
EXPECT(p == prog);
......@@ -887,8 +887,8 @@ TEST_CASE(lstm_forward)
ih,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_cell_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_f6args.onnx");
EXPECT(p == prog);
......@@ -921,8 +921,8 @@ TEST_CASE(lstm_forward)
ih,
ic,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_cell_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_f7args.onnx");
EXPECT(p == prog);
......@@ -950,6 +950,7 @@ TEST_CASE(lstm_forward_actv_func)
auto seq = p.add_parameter("seq", seq_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
// auto seq_len = p.add_parameter("seq_len", sl_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(
......@@ -967,7 +968,7 @@ TEST_CASE(lstm_forward_actv_func)
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_f0af.onnx");
EXPECT(p == prog);
......@@ -997,7 +998,7 @@ TEST_CASE(lstm_forward_actv_func)
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_f1af.onnx");
EXPECT(p == prog);
......@@ -1028,8 +1029,8 @@ TEST_CASE(lstm_forward_actv_func)
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_cell_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_f2af.onnx");
EXPECT(p == prog);
......@@ -1078,7 +1079,7 @@ TEST_CASE(lstm_reverse)
ih,
ic,
pph);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_reverse.onnx");
EXPECT(p == prog);
......@@ -1109,8 +1110,8 @@ TEST_CASE(lstm_reverse)
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_cell_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_r5args.onnx");
EXPECT(p == prog);
......@@ -1139,7 +1140,7 @@ TEST_CASE(lstm_reverse)
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_r0af.onnx");
EXPECT(p == prog);
......@@ -1192,7 +1193,7 @@ TEST_CASE(lstm_bidirectional)
ih,
ic,
pph);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_bi.onnx");
EXPECT(p == prog);
......@@ -1225,7 +1226,7 @@ TEST_CASE(lstm_bidirectional)
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_bi3args.onnx");
EXPECT(p == prog);
......@@ -1259,7 +1260,7 @@ TEST_CASE(lstm_bidirectional)
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_bi4args.onnx");
EXPECT(p == prog);
......@@ -1294,7 +1295,7 @@ TEST_CASE(lstm_bidirectional)
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_bi5args.onnx");
EXPECT(p == prog);
......@@ -1330,7 +1331,7 @@ TEST_CASE(lstm_bidirectional)
ih,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_bi6args.onnx");
EXPECT(p == prog);
......@@ -1367,7 +1368,7 @@ TEST_CASE(lstm_bidirectional)
ih,
ic,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_bi7args.onnx");
EXPECT(p == prog);
......@@ -1417,7 +1418,7 @@ TEST_CASE(lstm_bi_actv_funcs)
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_bi0af.onnx");
EXPECT(p == prog);
......@@ -1451,7 +1452,7 @@ TEST_CASE(lstm_bi_actv_funcs)
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_bi1af.onnx");
EXPECT(p == prog);
......@@ -1486,7 +1487,7 @@ TEST_CASE(lstm_bi_actv_funcs)
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_bi2af.onnx");
EXPECT(p == prog);
......@@ -1522,7 +1523,7 @@ TEST_CASE(lstm_bi_actv_funcs)
ih,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_bi4af.onnx");
EXPECT(p == prog);
......@@ -1559,7 +1560,7 @@ TEST_CASE(lstm_bi_actv_funcs)
ih,
ic,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_bi5af.onnx");
EXPECT(p == prog);
......@@ -1592,7 +1593,7 @@ TEST_CASE(lstm_bi_actv_funcs)
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
auto prog = optimize_onnx("onnx_lstm_bi6af.onnx");
EXPECT(p == prog);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment