"src/targets/vscode:/vscode.git/clone" did not exist on "3e6a9c17162af3de7b3bad1e7bfecf2d9ace76e2"
Commit 0da78d29 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

merge gru_operator changes

parents d53a69f6 ce7b4b17
......@@ -60,6 +60,30 @@ struct batch_norm_inference
}
};
struct lrn
{
float alpha = 0.0001;
float beta = 0.75;
float bias = 1.0;
int size = 1;
std::string name() const { return "lrn"; }
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.alpha, "alpha"),
f(self.beta, "beta"),
f(self.bias, "bias"),
f(self.size, "size"));
}
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
return inputs.front();
}
};
struct convolution
{
std::array<std::size_t, 2> padding = {{0, 0}};
......@@ -1140,19 +1164,19 @@ struct outline
argument compute(const shape&, const std::vector<argument>&) const { return {s, nullptr}; }
};
struct rnn
// indicate rnn computation direction
enum class rnn_direction
{
enum rnn_direction_t
{
forward,
reverse,
bidirectional,
};
};
struct rnn
{
std::size_t hidden_size = 1;
std::vector<operation> actv_funcs{tanh{}, tanh{}};
rnn_direction_t direction = forward;
rnn_direction direction = rnn_direction::forward;
float clip = 0.0f;
std::string name() const { return "rnn"; }
......@@ -1166,7 +1190,7 @@ struct rnn
}
std::size_t num_directions = 1;
if(direction == bidirectional)
if(direction == rnn_direction::bidirectional)
{
num_directions = 2;
}
......@@ -1200,16 +1224,9 @@ struct rnn_last_output
struct gru
{
enum gru_direction_t
{
forward,
reverse,
bidirectional,
};
std::size_t hidden_size = 1;
std::vector<operation> actv_funcs{sigmoid{}, tanh{}};
gru_direction_t direction = forward;
rnn_direction direction = rnn_direction::forward;
float clip = 0.0f;
int linear_before_reset = 0;
......@@ -1224,7 +1241,7 @@ struct gru
}
std::size_t num_directions = 1;
if(direction == bidirectional)
if(direction == rnn_direction::bidirectional)
{
num_directions = 2;
}
......@@ -1242,32 +1259,11 @@ struct gru
}
};
struct gru_last_output
{
std::string name() const { return "gru_last_output"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
auto dims = inputs[0].lens();
// remove the first dimension, remaing are output shape
dims.erase(dims.begin());
return {inputs[0].type(), dims};
}
};
struct lstm
{
enum lstm_direction_t
{
forward,
reverse,
bidirectional,
};
std::size_t hidden_size = 1;
std::vector<operation> actv_funcs{sigmoid{}, tanh{}, tanh{}};
lstm_direction_t direction = forward;
rnn_direction direction = rnn_direction::forward;
float clip = 0.0f;
int input_forget = 0;
......@@ -1282,7 +1278,7 @@ struct lstm
}
std::size_t num_directions = 1;
if(direction == bidirectional)
if(direction == rnn_direction::bidirectional)
{
num_directions = 2;
}
......@@ -1300,20 +1296,6 @@ struct lstm
}
};
struct lstm_last_output
{
std::string name() const { return "lstm_last_output"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
auto dims = inputs[0].lens();
// remove the first dimension, remaing are output shape
dims.erase(dims.begin());
return {inputs[0].type(), dims};
}
};
struct lstm_last_cell_output
{
std::string name() const { return "lstm_last_cell_output"; }
......
......@@ -64,6 +64,7 @@ struct onnx_parser
add_variadic_op("Max", op::max{});
add_variadic_op("Min", op::min{});
add_mem_op("LRN", &onnx_parser::parse_lrn);
add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler);
add_mem_op("LeakyRelu", &onnx_parser::parse_leaky_relu);
add_mem_op("Elu", &onnx_parser::parse_elu);
......@@ -88,6 +89,7 @@ struct onnx_parser
add_mem_op("Transpose", &onnx_parser::parse_transpose);
add_mem_op("RNN", &onnx_parser::parse_rnn);
add_mem_op("GRU", &onnx_parser::parse_gru);
add_mem_op("LSTM", &onnx_parser::parse_lstm);
add_mem_op("Pad", &onnx_parser::parse_pad);
// init the activation function map
......@@ -537,6 +539,25 @@ struct onnx_parser
return prog.add_instruction(op, args.front());
}
instruction_ref
parse_lrn(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
float alpha = 0.0001;
float beta = 0.75;
float bias = 1.0;
int size = 1;
if(contains(attributes, "alpha"))
alpha = parse_value(attributes.at("alpha")).at<float>();
if(contains(attributes, "beta"))
beta = parse_value(attributes.at("beta")).at<float>();
if(contains(attributes, "bias"))
bias = parse_value(attributes.at("bias")).at<float>();
if(contains(attributes, "size"))
size = parse_value(attributes.at("size")).at<int>();
op::lrn op{alpha, beta, bias, size};
return prog.add_instruction(op, args.front());
}
instruction_ref parse_imagescaler(const std::string&,
attribute_map attributes,
std::vector<instruction_ref> args)
......@@ -714,14 +735,14 @@ struct onnx_parser
direction = attributes.at("direction").s();
}
op::rnn::rnn_direction_t dirct = op::rnn::forward;
op::rnn_direction dirct = op::rnn_direction::forward;
if(direction == "bidirectional")
{
dirct = op::rnn::bidirectional;
dirct = op::rnn_direction::bidirectional;
}
else if(direction == "reverse")
{
dirct = op::rnn::reverse;
dirct = op::rnn_direction::reverse;
}
std::vector<std::string> vec_names{"tanh"};
......@@ -743,7 +764,7 @@ struct onnx_parser
// one is for forward, and the other is for reverse.
// if only one actv function is provided, we use it in both
// forward and reverse direction
if(dirct == op::rnn::bidirectional)
if(dirct == op::rnn_direction::bidirectional)
{
if(vec_names.size() == 1)
{
......@@ -803,14 +824,14 @@ struct onnx_parser
direction = attributes.at("direction").s();
}
op::gru::gru_direction_t dirct = op::gru::forward;
op::rnn_direction dirct = op::rnn_direction::forward;
if(direction == "bidirectional")
{
dirct = op::gru::bidirectional;
dirct = op::rnn_direction::bidirectional;
}
else if(direction == "reverse")
{
dirct = op::gru::reverse;
dirct = op::rnn_direction::reverse;
}
std::vector<std::string> vec_names = {"sigmoid", "tanh"};
......@@ -824,7 +845,7 @@ struct onnx_parser
}
// need 4 activation functions
if(dirct == op::gru::bidirectional)
if(dirct == op::rnn_direction::bidirectional)
{
// 4 activation functions are used in the bidirectional
// scenario. No spec is provided in onnx::operator. we
......@@ -895,7 +916,7 @@ struct onnx_parser
std::move(args));
// second output for last gru output
auto last_output = prog.add_instruction(op::gru_last_output{}, hidden_states);
auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);
return {hidden_states, last_output};
}
......@@ -922,18 +943,18 @@ struct onnx_parser
direction = attributes.at("direction").s();
}
op::lstm::lstm_direction_t dirct = op::lstm::forward;
op::rnn_direction dirct = op::rnn_direction::forward;
if(direction == "bidirectional")
{
dirct = op::lstm::bidirectional;
dirct = op::rnn_direction::bidirectional;
}
else if(direction == "reverse")
{
dirct = op::lstm::reverse;
dirct = op::rnn_direction::reverse;
}
else if(direction == "forward")
{
dirct = op::lstm::forward;
dirct = op::rnn_direction::forward;
}
else
{
......@@ -951,7 +972,7 @@ struct onnx_parser
}
// need 6 activation functions for bidirectional directions
if(dirct == op::lstm::bidirectional)
if(dirct == op::rnn_direction::bidirectional)
{
// 6 activation functions are used in the bidirectional
// scenario. No spec is provided in onnx::operator. we
......@@ -1034,7 +1055,7 @@ struct onnx_parser
op::lstm{hidden_size, vec_actv_funcs, dirct, clip, input_forget}, std::move(args));
// second output for last lstm output
auto last_output = prog.add_instruction(op::lstm_last_output{}, hidden_states);
auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);
// third output for last cell output
auto last_cell_output = prog.add_instruction(op::lstm_last_cell_output{}, hidden_states);
......
......@@ -45,9 +45,9 @@ void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const
auto actv_funcs = vanilla_rnn_actv_funcs(ins);
auto rnn_op = any_cast<op::rnn>(ins->get_operator());
op::rnn::rnn_direction_t dicrt = rnn_op.direction;
op::rnn_direction dicrt = rnn_op.direction;
instruction_ref last_output{};
if(dicrt == op::rnn::bidirectional)
if(dicrt == op::rnn_direction::bidirectional)
{
// input weight matrix
auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
......@@ -107,10 +107,8 @@ void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const
// The following logic is to ensure the last instruction rewritten from
// rnn operator is a concat instruction
// sequence len is 1
instruction_ref hidden_output{};
if(ret_forward[0] == prog.end())
{
hidden_output =
prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
}
else
......@@ -119,13 +117,12 @@ void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const
prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
ret_reverse[0] =
prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
hidden_output =
prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
}
}
else
{
bool is_forward = (dicrt == op::rnn::forward);
bool is_forward = (dicrt == op::rnn_direction::forward);
// input weight matrix
auto w = args[1];
......@@ -157,16 +154,15 @@ void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const
// following logic is to ensure the last instruction is a
// concat instruction
// sequence len is 1
instruction_ref hidden_output{};
if(ret[0] == prog.end())
{
hidden_output = prog.replace_instruction(ins, op::concat{0}, ret[1]);
prog.replace_instruction(ins, op::concat{0}, ret[1]);
}
else
{
auto concat_arg0 = is_forward ? ret[0] : ret[1];
auto concat_arg1 = is_forward ? ret[1] : ret[0];
hidden_output = prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
}
}
......@@ -282,7 +278,7 @@ std::vector<operation> rewrite_rnn::vanilla_rnn_actv_funcs(instruction_ref ins)
// append undefined operators to make 6 arguments when parsing
// an onnx file. Another case is user can have any num of arguments
// when writing their program.
if(rnn_op.direction == op::rnn::bidirectional)
if(rnn_op.direction == op::rnn_direction::bidirectional)
{
if(rnn_op.actv_funcs.empty())
{
......@@ -330,9 +326,9 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const
std::vector<float> data(ih_shape.elements(), 0.0);
auto gru_op = any_cast<op::gru>(ins->get_operator());
op::gru::gru_direction_t dicrt = gru_op.direction;
op::rnn_direction dicrt = gru_op.direction;
instruction_ref last_output{};
if(dicrt == op::gru::bidirectional)
if(dicrt == op::rnn_direction::bidirectional)
{
// w weight matrix
auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
......@@ -402,7 +398,7 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const
}
else
{
bool is_forward = (dicrt == op::gru::forward);
bool is_forward = (dicrt == op::rnn_direction::forward);
// weight matrix
auto w = args[1];
auto r = args[2];
......@@ -447,14 +443,14 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const
}
}
// replace the corresponding gru_last_output instruction
// with the last_output, if gru_last_output exists
// while loop to handle case of multiple gru_last_output operators
// replace the corresponding rnn_last_output instruction
// with the last_output, if rnn_last_output exists
// while loop to handle case of multiple rnn_last_output operators
auto last_output_it = ins->outputs().begin();
while(last_output_it != ins->outputs().end())
{
last_output_it = std::find_if(last_output_it, ins->outputs().end(), [](auto i) {
return i->name() == "gru_last_output";
return i->name() == "rnn_last_output";
});
if(last_output_it != ins->outputs().end())
......@@ -638,7 +634,7 @@ std::vector<operation> rewrite_rnn::gru_actv_funcs(instruction_ref ins) const
// we have 4 actv funcs, even though a user does not
// specifiy any actv func. If less than 4, use the
// algorithm in parse_gru to make 4 actv functions
if(gru_op.direction == op::gru::bidirectional)
if(gru_op.direction == op::rnn_direction::bidirectional)
{
if(gru_op.actv_funcs.empty())
return {op::sigmoid{}, op::tanh{}, op::sigmoid{}, op::tanh{}};
......@@ -689,11 +685,11 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
auto actv_funcs = lstm_actv_funcs(ins);
auto lstm_op = any_cast<op::lstm>(ins->get_operator());
op::lstm::lstm_direction_t dirct = lstm_op.direction;
op::rnn_direction dirct = lstm_op.direction;
instruction_ref last_output{};
instruction_ref last_cell_output{};
if(dirct == op::lstm::bidirectional)
if(dirct == op::rnn_direction::bidirectional)
{
// input weight matrix
// input weight matrix
......@@ -799,7 +795,7 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
}
else
{
bool is_forward = (dirct == op::lstm::forward);
bool is_forward = (dirct == op::rnn_direction::forward);
// weight matrices
auto w = args[1];
auto r = args[2];
......@@ -1100,7 +1096,7 @@ std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const
// algorithm in parse_lstm to make 6 actv functions
const auto& actv_funcs = lstm_op.actv_funcs;
std::size_t num_actv_funcs = actv_funcs.size();
if(lstm_op.direction == op::lstm::bidirectional)
if(lstm_op.direction == op::rnn_direction::bidirectional)
{
switch(num_actv_funcs)
{
......
......@@ -103,6 +103,43 @@ struct cpu_batch_norm_inference
}
};
struct cpu_lrn
{
op::lrn op;
std::string name() const { return "cpu::lrn"; }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, args[0])([&](auto output, auto input) {
int n_batch = output_shape.lens()[0];
int channels = output_shape.lens()[1];
int height = output_shape.lens()[2];
int width = output_shape.lens()[3];
float alphaoverarea = op.alpha / op.size;
int radius = (op.size - 1) / 2;
par_dfor(n_batch, height, width)([&](int b, int h, int w) {
float scale = 0;
dfor(channels)([&](int c) {
auto start = (c - radius) < 0 ? 0 : (c - radius);
auto end = (c + radius) > channels ? channels : (c + radius);
for(auto k = start; k < end; ++k)
{
scale += std::pow(input(b, k, h, w), 2);
}
scale *= alphaoverarea;
scale += op.bias;
scale = std::pow(scale, -op.beta);
output(b, c, h, w) = input(b, c, h, w) * scale;
});
});
});
return result;
}
};
struct cpu_convolution
{
op::convolution op;
......@@ -681,6 +718,7 @@ struct cpu_apply
apply_map["dot"] = extend_op<cpu_gemm, op::dot>();
apply_map["batch_norm_inference"] =
extend_op<cpu_batch_norm_inference, op::batch_norm_inference>();
apply_map["lrn"] = extend_op<cpu_lrn, op::lrn>();
apply_map["contiguous"] = extend_op<cpu_contiguous, op::contiguous>();
apply_map["pad"] = extend_op<cpu_pad, op::pad>();
apply_map["concat"] = extend_op<cpu_concat, op::concat>();
......
......@@ -61,6 +61,7 @@ add_library(migraphx_gpu
elu.cpp
pad.cpp
gather.cpp
lrn.cpp
)
set_target_properties(migraphx_gpu PROPERTIES EXPORT_NAME gpu)
rocm_clang_tidy_check(migraphx_gpu)
......
#ifndef MIGRAPHX_GUARD_RTGLIB_LRN_HPP
#define MIGRAPHX_GUARD_RTGLIB_LRN_HPP
#include <migraphx/gpu/lowering.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/gpu/device/contiguous.hpp>
#include <migraphx/gpu/device/add.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/context.hpp>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct miopen_lrn
{
shared<lrn_descriptor> ldesc;
std::string name() const { return "gpu::lrn"; }
shape compute_shape(const std::vector<shape>& inputs) const;
argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; }
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -23,6 +23,8 @@ using fusion_plan_descriptor = MIGRAPHX_MANAGE_PTR(miopenFusionPlanDescriptor_t,
miopenDestroyFusionPlan);
using fused_operator_args = MIGRAPHX_MANAGE_PTR(miopenOperatorArgs_t, miopenDestroyOperatorArgs);
using lrn_descriptor = MIGRAPHX_MANAGE_PTR(miopenLRNDescriptor_t, miopenDestroyLRNDescriptor);
template <class Result, class F, class... Ts>
Result make_obj(F f, Ts... xs)
{
......@@ -89,6 +91,13 @@ inline pooling_descriptor make_pooling(const migraphx::op::pooling& op)
return p;
}
inline lrn_descriptor make_lrn(const migraphx::op::lrn& op)
{
auto ldesc = make_obj<lrn_descriptor>(&miopenCreateLRNDescriptor);
miopenSetLRNDescriptor(ldesc.get(), miopenLRNCrossChannel, op.size, op.alpha, op.beta, op.bias);
return ldesc;
}
inline activation_descriptor make_relu()
{
auto ad = make_obj<activation_descriptor>(&miopenCreateActivationDescriptor);
......
......@@ -43,6 +43,7 @@
#include <migraphx/gpu/concat.hpp>
#include <migraphx/gpu/pad.hpp>
#include <migraphx/gpu/gather.hpp>
#include <migraphx/gpu/lrn.hpp>
#include <utility>
#include <functional>
#include <algorithm>
......@@ -99,6 +100,7 @@ struct miopen_apply
add_extend_op<hip_gather, op::gather>("gather");
add_extend_op<hip_pad, op::pad>("pad");
add_lrn_op();
add_convolution_op();
add_pooling_op();
add_batch_norm_inference_op();
......@@ -159,6 +161,17 @@ struct miopen_apply
});
}
void add_lrn_op()
{
apply_map.emplace("lrn", [=](instruction_ref ins) {
auto&& op = any_cast<op::lrn>(ins->get_operator());
auto ldesc = make_lrn(op);
auto output = insert_allocation(ins, ins->get_shape());
return prog->replace_instruction(
ins, miopen_lrn{std::move(ldesc)}, ins->inputs().at(0), output);
});
}
template <class T>
void add_generic_op(std::string name)
{
......
#include <migraphx/gpu/lrn.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
shape miopen_lrn::compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(2).not_broadcasted();
return inputs.at(1);
}
argument miopen_lrn::compute(context& ctx,
const shape& output_shape,
const std::vector<argument>& args) const
{
float alpha = 1;
float beta = 0;
auto x_desc = make_tensor(args[0].get_shape());
auto y_desc = make_tensor(output_shape);
miopenLRNForward(ctx.get_stream().get_miopen(),
ldesc.get(),
&alpha,
x_desc.get(),
args[0].implicit(),
&beta,
y_desc.get(),
args[1].implicit(),
false,
nullptr);
return args[1];
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -732,6 +732,20 @@ TEST_CASE(leaky_relu_test)
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(lrn_test)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {1, 5, 1, 1}};
auto l = p.add_literal(migraphx::literal{s, {-2.0f, 1.0f, 0.f, 1.0f, 2.0f}});
p.add_instruction(migraphx::op::lrn{0.0001, 0.75, 1, 5}, l);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector(5);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-2 / 1.000075, 1 / 1.00009, 0 / 1.000145, 1 / 1.00009, 2 / 1.000075};
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(imagescaler_test)
{
migraphx::program p;
......
This diff is collapsed.
......@@ -669,6 +669,18 @@ struct test_elu
}
};
struct test_relu_lrn
{
migraphx::program create_program() const
{
migraphx::program p;
auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 5, 2, 2}});
auto y = p.add_instruction(migraphx::op::relu{}, x);
p.add_instruction(migraphx::op::lrn{0.0001, 0.75, 1.0, 5}, y);
return p;
}
};
struct test_conv_pooling
{
migraphx::program create_program() const
......@@ -1144,7 +1156,7 @@ struct test_rnn_forward
auto output =
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::forward,
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
......@@ -1186,7 +1198,7 @@ struct test_rnn_forward10
auto output =
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::forward,
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
......@@ -1227,7 +1239,7 @@ struct test_rnn_reverse
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::reverse,
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
......@@ -1267,7 +1279,7 @@ struct test_rnn_reverse2
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::reverse,
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
......@@ -1302,7 +1314,7 @@ struct test_rnn_3args
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::reverse,
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
......@@ -1336,7 +1348,7 @@ struct test_rnn_4args
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::reverse,
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
......@@ -1373,7 +1385,7 @@ struct test_rnn_5args
auto output =
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::forward,
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
......@@ -1414,7 +1426,7 @@ struct test_rnn_bidirectional
auto output =
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::bidirectional,
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
......@@ -1455,7 +1467,7 @@ struct test_rnn_bidirectional10
auto output =
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::bidirectional,
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
......@@ -1493,7 +1505,7 @@ struct test_rnn_bi_3args
auto output =
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::bidirectional,
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
......@@ -1534,7 +1546,7 @@ struct test_gru_forward_last
auto output =
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::forward,
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
......@@ -1542,7 +1554,7 @@ struct test_gru_forward_last
bias,
und,
ih);
p.add_instruction(migraphx::op::gru_last_output{}, output);
p.add_instruction(migraphx::op::rnn_last_output{}, output);
return p;
}
......@@ -1577,7 +1589,7 @@ struct test_gru_forward_hs
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::forward,
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
......@@ -1613,7 +1625,7 @@ struct test_gru_forward_3args_und
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::forward,
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
......@@ -1648,7 +1660,7 @@ struct test_gru_forward_3args
auto r = p.add_parameter("r", r_shape);
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::forward,
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
......@@ -1680,7 +1692,7 @@ struct test_gru_forward_seq1
auto r = p.add_parameter("r", r_shape);
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::forward,
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
......@@ -1711,7 +1723,10 @@ struct test_gru_forward_default_actv
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
p.add_instruction(
migraphx::op::gru{hidden_size, {}, migraphx::op::gru::forward, clip}, seq, w, r);
migraphx::op::gru{hidden_size, {}, migraphx::op::rnn_direction::forward, clip},
seq,
w,
r);
return p;
}
......@@ -1746,7 +1761,7 @@ struct test_gru_forward_default_actv1
p.add_instruction(
migraphx::op::gru{
hidden_size, {migraphx::op::sigmoid{}}, migraphx::op::gru::forward, clip},
hidden_size, {migraphx::op::sigmoid{}}, migraphx::op::rnn_direction::forward, clip},
seq,
w,
r,
......@@ -1788,7 +1803,7 @@ struct test_gru_reverse_last
auto output =
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::reverse,
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
......@@ -1796,7 +1811,7 @@ struct test_gru_reverse_last
bias,
und,
ih);
p.add_instruction(migraphx::op::gru_last_output{}, output);
p.add_instruction(migraphx::op::rnn_last_output{}, output);
return p;
}
......@@ -1824,7 +1839,7 @@ struct test_gru_reverse_3args
auto r = p.add_parameter("r", r_shape);
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::reverse,
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
......@@ -1864,7 +1879,7 @@ struct test_gru_bidirct_last
auto output =
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::bidirectional,
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
......@@ -1872,7 +1887,7 @@ struct test_gru_bidirct_last
bias,
und,
ih);
p.add_instruction(migraphx::op::gru_last_output{}, output);
p.add_instruction(migraphx::op::rnn_last_output{}, output);
return p;
}
......@@ -1907,7 +1922,7 @@ struct test_gru_bidirct_hs
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::bidirectional,
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
......@@ -1943,7 +1958,7 @@ struct test_gru_bidirct_3args_und
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::bidirectional,
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
......@@ -1978,7 +1993,7 @@ struct test_gru_bidirct_3args
auto r = p.add_parameter("r", r_shape);
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::bidirectional,
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
......@@ -2010,7 +2025,7 @@ struct test_gru_bidirct_seq1
auto r = p.add_parameter("r", r_shape);
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::bidirectional,
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
......@@ -2041,7 +2056,10 @@ struct test_gru_bidirct_default_actv
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
p.add_instruction(
migraphx::op::gru{hidden_size, {}, migraphx::op::gru::bidirectional, clip}, seq, w, r);
migraphx::op::gru{hidden_size, {}, migraphx::op::rnn_direction::bidirectional, clip},
seq,
w,
r);
return p;
}
......@@ -2074,9 +2092,10 @@ struct test_gru_bidirct_default_actv1
auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(
migraphx::op::gru{
hidden_size, {migraphx::op::sigmoid{}}, migraphx::op::gru::bidirectional, clip},
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
......@@ -2090,6 +2109,7 @@ struct test_gru_bidirct_default_actv1
int main()
{
verify_program<test_relu_lrn>();
verify_program<test_pooling_autopad>();
verify_program<test_abs>();
verify_program<test_concat>();
......
......@@ -491,7 +491,7 @@ TEST_CASE(rnn_test)
auto out_hs =
p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn::bidirectional,
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
......@@ -523,7 +523,7 @@ TEST_CASE(rnn_test)
auto out_hs =
p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn::forward,
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
......@@ -555,7 +555,7 @@ TEST_CASE(rnn_test)
auto out_hs =
p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn::reverse,
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
......@@ -583,7 +583,7 @@ TEST_CASE(rnn_test)
auto out_hs =
p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn::reverse,
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
......@@ -615,7 +615,7 @@ TEST_CASE(rnn_test)
auto out_hs =
p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn::reverse,
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
......@@ -658,15 +658,16 @@ TEST_CASE(gru_test)
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::gru::forward,
clip},
migraphx::op::rnn_direction::forward,
clip,
1},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::gru_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_forward.onnx");
EXPECT(p == prog);
......@@ -692,7 +693,7 @@ TEST_CASE(gru_test)
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::gru::reverse,
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
......@@ -700,7 +701,7 @@ TEST_CASE(gru_test)
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::gru_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_reverse.onnx");
EXPECT(p == prog);
......@@ -723,12 +724,13 @@ TEST_CASE(gru_test)
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction(migraphx::op::gru{hs,
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::relu{},
migraphx::op::tanh{}},
migraphx::op::gru::bidirectional,
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
......@@ -736,11 +738,21 @@ TEST_CASE(gru_test)
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::gru_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi.onnx");
EXPECT(p == prog);
}
}
TEST_CASE(gru_test_args)
{
std::size_t sl = 5; // sequence len
std::size_t bs = 3; // batch size
std::size_t hs = 20; // hidden size
std::size_t is = 10; // input size
std::size_t nd = 2; // num directions
float clip = 0.0f;
// 3 arguments
{
......@@ -757,7 +769,7 @@ TEST_CASE(gru_test)
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::gru::forward,
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
......@@ -765,7 +777,7 @@ TEST_CASE(gru_test)
und,
und,
und);
p.add_instruction(migraphx::op::gru_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_3arg.onnx");
EXPECT(p == prog);
......@@ -789,7 +801,7 @@ TEST_CASE(gru_test)
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::gru::reverse,
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
......@@ -797,7 +809,7 @@ TEST_CASE(gru_test)
bias,
und,
und);
p.add_instruction(migraphx::op::gru_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_4arg.onnx");
EXPECT(p == prog);
......@@ -823,7 +835,7 @@ TEST_CASE(gru_test)
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::gru::bidirectional,
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
......@@ -831,12 +843,21 @@ TEST_CASE(gru_test)
bias,
seq_len,
und);
p.add_instruction(migraphx::op::gru_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_5arg.onnx");
EXPECT(p == prog);
}
}
TEST_CASE(gru_test_actv_funcs)
{
std::size_t sl = 5; // sequence len
std::size_t bs = 3; // batch size
std::size_t hs = 20; // hidden size
std::size_t is = 10; // input size
std::size_t nd = 2; // num directions
float clip = 0.0f;
// bidirection, 0 actv function
{
nd = 2;
......@@ -854,15 +875,15 @@ TEST_CASE(gru_test)
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs, {}, migraphx::op::gru::bidirectional, clip},
auto out_hs = p.add_instruction(
migraphx::op::gru{hs, {}, migraphx::op::rnn_direction::bidirectional, clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::gru_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_0.onnx");
EXPECT(p == prog);
......@@ -886,14 +907,15 @@ TEST_CASE(gru_test)
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction(
migraphx::op::gru{hs, {migraphx::op::tanh{}}, migraphx::op::gru::bidirectional, clip},
migraphx::op::gru{
hs, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::gru_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_1.onnx");
EXPECT(p == prog);
......@@ -919,7 +941,7 @@ TEST_CASE(gru_test)
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::gru::bidirectional,
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
......@@ -927,7 +949,7 @@ TEST_CASE(gru_test)
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::gru_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_2.onnx");
EXPECT(p == prog);
......@@ -953,7 +975,7 @@ TEST_CASE(gru_test)
auto out_hs = p.add_instruction(
migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::bidirectional,
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
......@@ -961,7 +983,7 @@ TEST_CASE(gru_test)
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::gru_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_3.onnx");
EXPECT(p == prog);
......@@ -984,14 +1006,15 @@ TEST_CASE(gru_test)
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction(migraphx::op::gru{hs, {}, migraphx::op::gru::forward, clip},
auto out_hs =
p.add_instruction(migraphx::op::gru{hs, {}, migraphx::op::rnn_direction::forward, clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::gru_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_forward_0.onnx");
EXPECT(p == prog);
......@@ -1015,14 +1038,15 @@ TEST_CASE(gru_test)
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction(
migraphx::op::gru{hs, {migraphx::op::relu{}}, migraphx::op::gru::reverse, clip},
migraphx::op::gru{
hs, {migraphx::op::relu{}}, migraphx::op::rnn_direction::reverse, clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::gru_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_reverse_1.onnx");
EXPECT(p == prog);
......@@ -1173,4 +1197,18 @@ TEST_CASE(pad_test)
migraphx::parse_onnx("pad_test.onnx");
}
TEST_CASE(lrn_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 28, 24, 24}});
migraphx::op::lrn op;
op.size = 5;
op.alpha = 0.0001;
op.beta = 0.75;
op.bias = 1.0;
p.add_instruction(op, l0);
migraphx::parse_onnx("lrn_test.onnx");
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -282,10 +282,11 @@ TEST_CASE(rnn)
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
expect_shape(migraphx::shape{migraphx::shape::float_type,
expect_shape(
migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}},
migraphx::op::rnn{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn::forward, clip},
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip},
in_shape,
w_shape,
r_shape,
......@@ -307,10 +308,11 @@ TEST_CASE(rnn)
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
expect_shape(migraphx::shape{migraphx::shape::float_type,
expect_shape(
migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}},
migraphx::op::rnn{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn::reverse, clip},
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::reverse, clip},
in_shape,
w_shape,
r_shape,
......@@ -332,11 +334,12 @@ TEST_CASE(rnn)
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
expect_shape(
migraphx::shape{migraphx::shape::float_type,
expect_shape(migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}},
migraphx::op::rnn{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn::bidirectional, clip},
migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
in_shape,
w_shape,
r_shape,
......@@ -358,9 +361,10 @@ TEST_CASE(rnn)
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
throws_shape(
migraphx::op::rnn{
hidden_size + 1, {migraphx::op::tanh{}}, migraphx::op::rnn::forward, clip},
throws_shape(migraphx::op::rnn{hidden_size + 1,
{migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
in_shape,
w_shape,
r_shape,
......@@ -382,9 +386,10 @@ TEST_CASE(rnn)
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
throws_shape(
migraphx::op::rnn{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn::bidirectional, clip},
throws_shape(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
in_shape,
w_shape,
r_shape,
......@@ -408,7 +413,7 @@ TEST_CASE(rnn)
throws_shape(
migraphx::op::rnn{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn::forward, clip},
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip},
in_shape,
w_shape,
r_shape,
......@@ -435,10 +440,11 @@ TEST_CASE(gru)
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
expect_shape(migraphx::shape{migraphx::shape::float_type,
expect_shape(
migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}},
migraphx::op::gru{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::gru::forward, clip},
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip},
in_shape,
w_shape,
r_shape,
......@@ -462,10 +468,11 @@ TEST_CASE(gru)
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
expect_shape(migraphx::shape{migraphx::shape::float_type,
expect_shape(
migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}},
migraphx::op::gru{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::gru::reverse, clip},
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::reverse, clip},
in_shape,
w_shape,
r_shape,
......@@ -489,11 +496,12 @@ TEST_CASE(gru)
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
expect_shape(
migraphx::shape{migraphx::shape::float_type,
expect_shape(migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}},
migraphx::op::gru{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::gru::bidirectional, clip},
migraphx::op::gru{hidden_size,
{migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
in_shape,
w_shape,
r_shape,
......@@ -517,9 +525,10 @@ TEST_CASE(gru)
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
throws_shape(
migraphx::op::gru{
hidden_size + 1, {migraphx::op::tanh{}}, migraphx::op::gru::forward, clip},
throws_shape(migraphx::op::gru{hidden_size + 1,
{migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
in_shape,
w_shape,
r_shape,
......@@ -543,9 +552,10 @@ TEST_CASE(gru)
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
throws_shape(
migraphx::op::gru{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::gru::bidirectional, clip},
throws_shape(migraphx::op::gru{hidden_size,
{migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
in_shape,
w_shape,
r_shape,
......@@ -571,7 +581,7 @@ TEST_CASE(gru)
throws_shape(
migraphx::op::gru{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::gru::forward, clip},
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip},
in_shape,
w_shape,
r_shape,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment