Commit d4594903 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

fixed build error.

parent 1596cf1f
...@@ -12,6 +12,7 @@ add_library(migraphx ...@@ -12,6 +12,7 @@ add_library(migraphx
eliminate_concat.cpp eliminate_concat.cpp
fwd_conv_batchnorm_rewrite.cpp fwd_conv_batchnorm_rewrite.cpp
rewrite_rnn.cpp rewrite_rnn.cpp
rewrite_gru.cpp
env.cpp env.cpp
generate.cpp generate.cpp
instruction.cpp instruction.cpp
......
...@@ -21,7 +21,7 @@ struct rewrite_gru ...@@ -21,7 +21,7 @@ struct rewrite_gru
void apply(program& prog) const; void apply(program& prog) const;
private: private:
std::vector<instruction_ref> rnn_gru(bool is_forward, std::vector<instruction_ref> gru_oper(bool is_forward,
program& prog, program& prog,
instruction_ref ins, instruction_ref ins,
instruction_ref input, instruction_ref input,
...@@ -29,7 +29,9 @@ struct rewrite_gru ...@@ -29,7 +29,9 @@ struct rewrite_gru
instruction_ref wh, instruction_ref wh,
instruction_ref ih, instruction_ref ih,
instruction_ref bias, instruction_ref bias,
operation& actv_func) const; int linear_before_reset,
operation& actv_func1,
operation& actv_func2) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -33,17 +33,16 @@ void rewrite_gru::apply(program& prog) const ...@@ -33,17 +33,16 @@ void rewrite_gru::apply(program& prog) const
op::gru::gru_direction_t dicrt = gru_op.direction; op::gru::gru_direction_t dicrt = gru_op.direction;
if(dicrt == op::gru::bidirectional) if(dicrt == op::gru::bidirectional)
{ {
long hs = static_cast<long>(hidden_size);
// forward weight // forward weight
auto uw_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]); auto uw_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
auto w_forward = prog.insert_instruction(ins, op::squeeze{{0}, uw_forward}); auto w_forward = prog.insert_instruction(ins, op::squeeze{{0}}, uw_forward);
auto ur_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]); auto ur_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]);
auto r_forward = prog.insert_instruction(ins, op::squeeze{{0}}, ur_forward); auto r_forward = prog.insert_instruction(ins, op::squeeze{{0}}, ur_forward);
// reverse weight // reverse weight
auto uw_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]); auto uw_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]);
auto w_reverse = prog.insert_instruction(ins, op::squeeze{{0}, uw_reverse}); auto w_reverse = prog.insert_instruction(ins, op::squeeze{{0}}, uw_reverse);
auto ur_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]); auto ur_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]);
auto r_reverse = prog.insert_instruction(ins, op::squeeze{{0}}, ur_reverse); auto r_reverse = prog.insert_instruction(ins, op::squeeze{{0}}, ur_reverse);
...@@ -92,7 +91,7 @@ void rewrite_gru::apply(program& prog) const ...@@ -92,7 +91,7 @@ void rewrite_gru::apply(program& prog) const
gru_op.actv_funcs.at(0), gru_op.actv_funcs.at(0),
gru_op.actv_funcs.at(1)); gru_op.actv_funcs.at(1));
auto ret_reverse = rnn_oper(false, auto ret_reverse = gru_oper(false,
prog, prog,
ins, ins,
args[0], args[0],
...@@ -136,7 +135,7 @@ void rewrite_gru::apply(program& prog) const ...@@ -136,7 +135,7 @@ void rewrite_gru::apply(program& prog) const
} }
else else
{ {
ih = prog.add_literal(migraphx::literal{s, data}); ih = prog.add_literal(migraphx::literal{ih_shape, data});
} }
auto ret = gru_oper(is_forward, auto ret = gru_oper(is_forward,
...@@ -175,7 +174,7 @@ std::vector<instruction_ref> rewrite_gru::gru_oper(bool is_forward, ...@@ -175,7 +174,7 @@ std::vector<instruction_ref> rewrite_gru::gru_oper(bool is_forward,
long seq_index = is_forward ? 0 : seq_len - 1; long seq_index = is_forward ? 0 : seq_len - 1;
migraphx::shape s(input->get_shape().type(), {1}); migraphx::shape s(input->get_shape().type(), {1});
auto l1 = prog.add_literal(migraphx::leteral{s, {1}}); auto l1 = prog.add_literal(migraphx::literal{s, {1}});
// weight matrix // weight matrix
std::vector<int64_t> perm{1, 0}; std::vector<int64_t> perm{1, 0};
...@@ -199,12 +198,12 @@ std::vector<instruction_ref> rewrite_gru::gru_oper(bool is_forward, ...@@ -199,12 +198,12 @@ std::vector<instruction_ref> rewrite_gru::gru_oper(bool is_forward,
{ {
auto wbz = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, bias); auto wbz = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, bias);
auto wbr = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, bias); auto wbr = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, bias);
wbh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, bias); auto wbh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, bias);
br_wbh = prog.insert_instruction(ins, op::broadcast{1, ih->get_shape()}, wbh); br_wbh = prog.insert_instruction(ins, op::broadcast{1, ih->get_shape()}, wbh);
auto rbz = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, bias); auto rbz = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, bias);
auto rbr = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {5 * hs}}, bias); auto rbr = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {5 * hs}}, bias);
rbh = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, bias); auto rbh = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, bias);
br_rbh = prog.insert_instruction(ins, op::broadcast{1, ih->get_shape()}, rbh); br_rbh = prog.insert_instruction(ins, op::broadcast{1, ih->get_shape()}, rbh);
auto bz = prog.insert_instruction(ins, op::add{}, wbz, rbz); auto bz = prog.insert_instruction(ins, op::add{}, wbz, rbz);
...@@ -245,7 +244,7 @@ std::vector<instruction_ref> rewrite_gru::gru_oper(bool is_forward, ...@@ -245,7 +244,7 @@ std::vector<instruction_ref> rewrite_gru::gru_oper(bool is_forward,
auto xwht = prog.insert_instruction(ins, op::dot{}, xt, twh); auto xwht = prog.insert_instruction(ins, op::dot{}, xt, twh);
auto rt_ht = prog.insert_instruction(ins, op::mul{}, rt, ih); auto rt_ht = prog.insert_instruction(ins, op::mul{}, rt, ih);
auto rt_rh = prog.insert_instruction(ins, op::dot{}, rt_ht, trh); auto rt_rh = prog.insert_instruction(ins, op::dot{}, rt_ht, trh);
xwhh_rt = prog.insert_instruction(ins, op::add{}, xwht, rt_rt); xwhh_rt = prog.insert_instruction(ins, op::add{}, xwht, rt_rh);
if(bias != prog.end()) if(bias != prog.end())
{ {
xwhh_rt = prog.insert_instruction(ins, op::add{}, xwhh_rt, br_bh); xwhh_rt = prog.insert_instruction(ins, op::add{}, xwhh_rt, br_bh);
...@@ -267,13 +266,13 @@ std::vector<instruction_ref> rewrite_gru::gru_oper(bool is_forward, ...@@ -267,13 +266,13 @@ std::vector<instruction_ref> rewrite_gru::gru_oper(bool is_forward,
xwhh_rt = prog.insert_instruction(ins, op::add{}, xwhh_rt, br_wbh); xwhh_rt = prog.insert_instruction(ins, op::add{}, xwhh_rt, br_wbh);
} }
} }
ht = prog.insert_instruction(ins, actv_func2, xwhh_rt); auto ht = prog.insert_instruction(ins, actv_func2, xwhh_rt);
// equation Ht = (1 - zt) (.) ht + zt (.) Ht-1 // equation Ht = (1 - zt) (.) ht + zt (.) Ht-1
auto 1zt = prog.insert_instruction(ins, op::sub{}, l1, zt); auto z1t = prog.insert_instruction(ins, op::sub{}, l1, zt);
auto 1ztht = prog.insert_instruction(ins, op::mul{}, 1zt, ht); auto z1tht = prog.insert_instruction(ins, op::mul{}, z1t, ht);
auto ztht1 = prog.insert_instruction(ins, op::mul{}, zt, ih); auto ztht1 = prog.insert_instruction(ins, op::mul{}, zt, ih);
ih = prog.insert_instruction(ins, op::add{}, 1ztht ztht1); ih = prog.insert_instruction(ins, op::add{}, z1tht, ztht1);
final_out = ih; final_out = ih;
if(is_forward) if(is_forward)
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <migraphx/cpu/lowering.hpp> #include <migraphx/cpu/lowering.hpp>
#include <migraphx/auto_contiguous.hpp> #include <migraphx/auto_contiguous.hpp>
#include <migraphx/rewrite_rnn.hpp> #include <migraphx/rewrite_rnn.hpp>
#include <migraphx/rewrite_gru.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -12,7 +13,10 @@ std::string target::name() const { return "cpu"; } ...@@ -12,7 +13,10 @@ std::string target::name() const { return "cpu"; }
std::vector<pass> target::get_passes(migraphx::context&) const std::vector<pass> target::get_passes(migraphx::context&) const
{ {
return {auto_contiguous{}, rewrite_rnn{}, lowering{}}; return {auto_contiguous{},
rewrite_rnn{},
rewrite_gru{},
lowering{}};
} }
} // namespace cpu } // namespace cpu
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <migraphx/common_subexpression_elimination.hpp> #include <migraphx/common_subexpression_elimination.hpp>
#include <migraphx/fwd_conv_batchnorm_rewrite.hpp> #include <migraphx/fwd_conv_batchnorm_rewrite.hpp>
#include <migraphx/rewrite_rnn.hpp> #include <migraphx/rewrite_rnn.hpp>
#include <migraphx/rewrite_gru.hpp>
#include <migraphx/eliminate_concat.hpp> #include <migraphx/eliminate_concat.hpp>
#include <migraphx/gpu/concat_gpu_opt.hpp> #include <migraphx/gpu/concat_gpu_opt.hpp>
...@@ -36,6 +37,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const ...@@ -36,6 +37,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
dead_code_elimination{}, dead_code_elimination{},
rewrite_rnn{}, rewrite_rnn{},
dead_code_elimination{}, dead_code_elimination{},
rewrite_gru{},
dead_code_elimination{},
simplify_algebra{}, simplify_algebra{},
dead_code_elimination{}, dead_code_elimination{},
constant_propagate{}, constant_propagate{},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment