Unverified commit 10d072b9, authored by Shucai Xiao, committed by GitHub

Small bugs related to rnn (#499)



* code changes for small bugs

* clang format

* remove standard shape requirement for transpose.

* add a unit test

* clang format
Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent ccf491a4
@@ -83,6 +83,8 @@ void verify_instructions(const program& prog, double tolerance)
             continue;
         if(ins.name() == "reshape")
             continue;
+        if(ins.name() == "undefined")
+            continue;
         program p;
         std::vector<instruction_ref> inputs;
         for(auto&& arg : ins.inputs())
...
@@ -15,6 +15,7 @@
 #include <migraphx/op/sub.hpp>
 #include <migraphx/op/transpose.hpp>
 #include <migraphx/op/unsqueeze.hpp>
+#include <migraphx/op/contiguous.hpp>
 #include <migraphx/iterator_for.hpp>
 #include <migraphx/dfor.hpp>
 #include <migraphx/op/common.hpp>
@@ -240,7 +241,8 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
     {
         long seq_index = is_forward ? i : (seq_len - 1 - i);
         auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, input);
-        xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
+        auto cont_xt = prog.insert_instruction(ins, op::contiguous{}, xt);
+        xt = prog.insert_instruction(ins, op::squeeze{{0}}, cont_xt);
         auto xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_sw);
         auto ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_sr);
         if(bias != prog.end())
@@ -536,7 +538,8 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
     {
         long seq_index = is_forward ? i : (seq_len - 1 - i);
         auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, seq);
-        xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
+        auto cont_xt = prog.insert_instruction(ins, op::contiguous{}, xt);
+        xt = prog.insert_instruction(ins, op::squeeze{{0}}, cont_xt);
         auto xt_w = prog.insert_instruction(ins, op::dot{}, xt, tw);
         auto ih1_rzr = prog.insert_instruction(ins, op::dot{}, sih, trzr);
@@ -949,7 +952,8 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
     {
         long seq_index = is_forward ? i : (seq_len - 1 - i);
         auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, seq);
-        xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
+        auto cont_xt = prog.insert_instruction(ins, op::contiguous{}, xt);
+        xt = prog.insert_instruction(ins, op::squeeze{{0}}, cont_xt);
         auto xt_tsw = prog.insert_instruction(ins, op::dot{}, xt, tsw);
         auto sih_tsr = prog.insert_instruction(ins, op::dot{}, sih, tsr);
...
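All three RNN cell rewrites receive the same fix: the output of `op::slice` is a view into the sequence tensor, so an `op::contiguous` is inserted before the squeeze and the dots that follow. Below is a minimal sketch of that slice → contiguous → squeeze pattern, not the code in the diff: the function name `make_step`, the parameter name `"seq"`, and the input shape {3, 2, 4} (seq_len, batch, input_size) are illustrative assumptions, and the per-op headers are assumed to follow the include style shown above.

```cpp
#include <migraphx/program.hpp>
#include <migraphx/op/slice.hpp>
#include <migraphx/op/contiguous.hpp>
#include <migraphx/op/squeeze.hpp>

// Sketch only: mirrors the slice -> contiguous -> squeeze pattern from the diff.
migraphx::program make_step(long seq_index)
{
    migraphx::program p;
    migraphx::shape in_shape{migraphx::shape::float_type, {3, 2, 4}};
    auto seq = p.add_parameter("seq", in_shape);

    // Take one time step; the result is a view into "seq", not necessarily a
    // standard (packed, offset-free) buffer.
    auto xt = p.add_instruction(
        migraphx::op::slice{{0}, {seq_index}, {seq_index + 1}}, seq);
    // Materialize the view so later ops see a standard layout.
    auto cont_xt = p.add_instruction(migraphx::op::contiguous{}, xt);
    // Drop the singleton time dimension: {1, 2, 4} -> {2, 4}.
    p.add_instruction(migraphx::op::squeeze{{0}}, cont_xt);
    return p;
}
```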
@@ -59,6 +59,12 @@ struct rocblas_gemm
         auto dim_0 = strides.size() - 2;
         auto matrix_size = std::max(strides[dim_0], strides[dim_0 + 1]);
         std::vector<std::size_t> batch(strides.begin(), strides.begin() + dim_0);
+        if(std::all_of(batch.begin(), batch.end(), [&](auto i) { return (i < matrix_size); }))
+        {
+            MIGRAPHX_THROW("GPU_GEMM: matrix size and batch size {" + to_string_range(strides) +
+                           "} are transposed!");
+        }
         if(std::adjacent_find(batch.begin(), batch.end(), [&](auto i, auto j) {
                return (i < j or i < matrix_size or j < matrix_size);
            }) != batch.end())
...
@@ -1716,6 +1716,25 @@ struct gemm_multi_3args_alpha0 : verify_program<gemm_multi_3args_alpha0>
     }
 };
 
+struct gemm_multi_transpose : verify_program<gemm_multi_transpose>
+{
+    migraphx::program create_program() const
+    {
+        migraphx::program p;
+        migraphx::shape m1_shape{migraphx::shape::float_type, {2, 2, 3}};
+        migraphx::shape m2_shape{migraphx::shape::float_type, {3, 2, 4}};
+        auto l1 = p.add_parameter("1", m1_shape);
+        auto l2 = p.add_parameter("2", m2_shape);
+        auto tl2 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, l2);
+
+        float alpha = 1.0f;
+        float beta  = 1.0f;
+        p.add_instruction(migraphx::op::dot{alpha, beta}, l1, tl2);
+        return p;
+    }
+};
+
 struct quant_dot_3args_1 : verify_program<quant_dot_3args_1>
 {
     migraphx::program create_program() const
...
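The shapes in the new gemm_multi_transpose test make the rocblas_gemm guard concrete: l1 is {2, 2, 3}, transposing the {3, 2, 4} parameter with permutation {1, 0, 2} yields lens {2, 3, 4} with strides {4, 8, 1}, and the dot therefore produces {2, 2, 4}. Below is a hand-worked sketch of the stride arithmetic behind the added all_of check, written as plain standalone C++ rather than MIGraphX code; the variable names mirror the diff, and the stride values assume a standard row-major layout for the untransposed parameter.

```cpp
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

int main()
{
    // m2_shape {3, 2, 4} in row-major order has strides {8, 4, 1}.
    // transpose{{1, 0, 2}} permutes lens to {2, 3, 4} and strides to {4, 8, 1}.
    std::vector<std::size_t> strides{4, 8, 1};

    auto dim_0       = strides.size() - 2;                               // 1
    auto matrix_size = std::max(strides[dim_0], strides[dim_0 + 1]);     // max(8, 1) = 8
    std::vector<std::size_t> batch(strides.begin(), strides.begin() + dim_0); // {4}

    // Every batch stride (4) is smaller than the matrix size (8): the batch
    // dimension is interleaved inside the matrices rather than outside them.
    bool transposed =
        std::all_of(batch.begin(), batch.end(), [&](auto i) { return i < matrix_size; });
    std::cout << std::boolalpha << transposed << '\n'; // prints "true"
}
```

When the condition holds, the layout appears to be exactly what the new MIGRAPHX_THROW message calls "matrix size and batch size are transposed", which gets reported explicitly instead of being handed to the batched GEMM as-is.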