Commit d4594903 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

fixed build error.

parent 1596cf1f
......@@ -12,6 +12,7 @@ add_library(migraphx
eliminate_concat.cpp
fwd_conv_batchnorm_rewrite.cpp
rewrite_rnn.cpp
rewrite_gru.cpp
env.cpp
generate.cpp
instruction.cpp
......
......@@ -21,7 +21,7 @@ struct rewrite_gru
void apply(program& prog) const;
private:
std::vector<instruction_ref> rnn_gru(bool is_forward,
std::vector<instruction_ref> gru_oper(bool is_forward,
program& prog,
instruction_ref ins,
instruction_ref input,
......@@ -29,7 +29,9 @@ struct rewrite_gru
instruction_ref wh,
instruction_ref ih,
instruction_ref bias,
operation& actv_func) const;
int linear_before_reset,
operation& actv_func1,
operation& actv_func2) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -33,17 +33,16 @@ void rewrite_gru::apply(program& prog) const
op::gru::gru_direction_t dicrt = gru_op.direction;
if(dicrt == op::gru::bidirectional)
{
long hs = static_cast<long>(hidden_size);
// forward weight
auto uw_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
auto w_forward = prog.insert_instruction(ins, op::squeeze{{0}, uw_forward});
auto w_forward = prog.insert_instruction(ins, op::squeeze{{0}}, uw_forward);
auto ur_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]);
auto r_forward = prog.insert_instruction(ins, op::squeeze{{0}}, ur_forward);
// reverse weight
auto uw_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]);
auto w_reverse = prog.insert_instruction(ins, op::squeeze{{0}, uw_reverse});
auto w_reverse = prog.insert_instruction(ins, op::squeeze{{0}}, uw_reverse);
auto ur_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]);
auto r_reverse = prog.insert_instruction(ins, op::squeeze{{0}}, ur_reverse);
......@@ -92,7 +91,7 @@ void rewrite_gru::apply(program& prog) const
gru_op.actv_funcs.at(0),
gru_op.actv_funcs.at(1));
auto ret_reverse = rnn_oper(false,
auto ret_reverse = gru_oper(false,
prog,
ins,
args[0],
......@@ -136,7 +135,7 @@ void rewrite_gru::apply(program& prog) const
}
else
{
ih = prog.add_literal(migraphx::literal{s, data});
ih = prog.add_literal(migraphx::literal{ih_shape, data});
}
auto ret = gru_oper(is_forward,
......@@ -175,7 +174,7 @@ std::vector<instruction_ref> rewrite_gru::gru_oper(bool is_forward,
long seq_index = is_forward ? 0 : seq_len - 1;
migraphx::shape s(input->get_shape().type(), {1});
auto l1 = prog.add_literal(migraphx::leteral{s, {1}});
auto l1 = prog.add_literal(migraphx::literal{s, {1}});
// weight matrix
std::vector<int64_t> perm{1, 0};
......@@ -199,12 +198,12 @@ std::vector<instruction_ref> rewrite_gru::gru_oper(bool is_forward,
{
auto wbz = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, bias);
auto wbr = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, bias);
wbh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, bias);
auto wbh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, bias);
br_wbh = prog.insert_instruction(ins, op::broadcast{1, ih->get_shape()}, wbh);
auto rbz = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, bias);
auto rbr = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {5 * hs}}, bias);
rbh = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, bias);
auto rbh = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, bias);
br_rbh = prog.insert_instruction(ins, op::broadcast{1, ih->get_shape()}, rbh);
auto bz = prog.insert_instruction(ins, op::add{}, wbz, rbz);
......@@ -245,7 +244,7 @@ std::vector<instruction_ref> rewrite_gru::gru_oper(bool is_forward,
auto xwht = prog.insert_instruction(ins, op::dot{}, xt, twh);
auto rt_ht = prog.insert_instruction(ins, op::mul{}, rt, ih);
auto rt_rh = prog.insert_instruction(ins, op::dot{}, rt_ht, trh);
xwhh_rt = prog.insert_instruction(ins, op::add{}, xwht, rt_rt);
xwhh_rt = prog.insert_instruction(ins, op::add{}, xwht, rt_rh);
if(bias != prog.end())
{
xwhh_rt = prog.insert_instruction(ins, op::add{}, xwhh_rt, br_bh);
......@@ -267,13 +266,13 @@ std::vector<instruction_ref> rewrite_gru::gru_oper(bool is_forward,
xwhh_rt = prog.insert_instruction(ins, op::add{}, xwhh_rt, br_wbh);
}
}
ht = prog.insert_instruction(ins, actv_func2, xwhh_rt);
auto ht = prog.insert_instruction(ins, actv_func2, xwhh_rt);
// equation Ht = (1 - zt) (.) ht + zt (.) Ht-1
auto 1zt = prog.insert_instruction(ins, op::sub{}, l1, zt);
auto 1ztht = prog.insert_instruction(ins, op::mul{}, 1zt, ht);
auto z1t = prog.insert_instruction(ins, op::sub{}, l1, zt);
auto z1tht = prog.insert_instruction(ins, op::mul{}, z1t, ht);
auto ztht1 = prog.insert_instruction(ins, op::mul{}, zt, ih);
ih = prog.insert_instruction(ins, op::add{}, 1ztht ztht1);
ih = prog.insert_instruction(ins, op::add{}, z1tht, ztht1);
final_out = ih;
if(is_forward)
......
......@@ -3,6 +3,7 @@
#include <migraphx/cpu/lowering.hpp>
#include <migraphx/auto_contiguous.hpp>
#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/rewrite_gru.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -12,7 +13,10 @@ std::string target::name() const { return "cpu"; }
std::vector<pass> target::get_passes(migraphx::context&) const
{
return {auto_contiguous{}, rewrite_rnn{}, lowering{}};
return {auto_contiguous{},
rewrite_rnn{},
rewrite_gru{},
lowering{}};
}
} // namespace cpu
......
......@@ -16,6 +16,7 @@
#include <migraphx/common_subexpression_elimination.hpp>
#include <migraphx/fwd_conv_batchnorm_rewrite.hpp>
#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/rewrite_gru.hpp>
#include <migraphx/eliminate_concat.hpp>
#include <migraphx/gpu/concat_gpu_opt.hpp>
......@@ -36,6 +37,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
dead_code_elimination{},
rewrite_rnn{},
dead_code_elimination{},
rewrite_gru{},
dead_code_elimination{},
simplify_algebra{},
dead_code_elimination{},
constant_propagate{},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment