Commit 1807abc6 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

code cleanup

parent 143543e4
...@@ -479,11 +479,12 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward, ...@@ -479,11 +479,12 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
instruction_ref hidden_states = prog.end(); instruction_ref hidden_states = prog.end();
instruction_ref last_output{}; instruction_ref last_output{};
long seq_len = static_cast<long>(seq->get_shape().lens()[0]); migraphx::shape seq_shape = seq->get_shape();
long hs = static_cast<long>(r->get_shape().lens()[2]); migraphx::shape r_shape = r->get_shape();
long seq_len = static_cast<long>(seq_shape.lens()[0]);
long hs = static_cast<long>(r_shape.lens()[2]);
migraphx::shape s(seq->get_shape().type(), migraphx::shape s(seq_shape.type(), {seq_shape.lens()[1], r_shape.lens()[2]});
{seq->get_shape().lens()[1], static_cast<std::size_t>(hs)});
std::vector<int> data(s.elements(), 1); std::vector<int> data(s.elements(), 1);
auto l1 = prog.add_literal(migraphx::literal{s, data}); auto l1 = prog.add_literal(migraphx::literal{s, data});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment