"docs/source/conf.py" did not exist on "bb0a870006fc08570a23d75968b4d7acb86834ff"
Unverified Commit 55182aac authored by mvermeulen's avatar mvermeulen Committed by GitHub
Browse files

Merge pull request #336 from ROCmSoftwarePlatform/bugs_for_bert

Bugs fix related to bert model (batch running and quantization)
parents 7534546a 8a9cfb25
...@@ -206,6 +206,16 @@ struct onnx_parser ...@@ -206,6 +206,16 @@ struct onnx_parser
return out_lens; return out_lens;
} }
instruction_ref make_contiguous(instruction_ref ins)
{
if(ins->get_shape().standard())
{
return ins;
}
return prog.add_instruction(op::contiguous{}, ins);
}
template <class T> template <class T>
instruction_ref add_broadcastable_binary_op(instruction_ref arg0, instruction_ref arg1, T x) instruction_ref add_broadcastable_binary_op(instruction_ref arg0, instruction_ref arg1, T x)
{ {
...@@ -437,12 +447,7 @@ struct onnx_parser ...@@ -437,12 +447,7 @@ struct onnx_parser
s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); }); s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
} }
if(!args[0]->get_shape().standard()) return prog.add_instruction(op, make_contiguous(args[0]));
{
args[0] = prog.add_instruction(op::contiguous{}, args[0]);
}
return prog.add_instruction(op, args[0]);
} }
instruction_ref instruction_ref
...@@ -490,8 +495,9 @@ struct onnx_parser ...@@ -490,8 +495,9 @@ struct onnx_parser
{ {
axis = parse_value(attributes.at("axis")).at<int>(); axis = parse_value(attributes.at("axis")).at<int>();
} }
op::gather op{axis}; op::gather op{axis};
return prog.add_instruction(op, std::move(args)); return prog.add_instruction(op, make_contiguous(args[0]), make_contiguous(args[1]));
} }
instruction_ref instruction_ref
......
...@@ -74,7 +74,8 @@ void quantize(program& prog, const std::vector<std::string>& ins_names) ...@@ -74,7 +74,8 @@ void quantize(program& prog, const std::vector<std::string>& ins_names)
// if the input is a convert operator, uses its input // if the input is a convert operator, uses its input
// as its current input // as its current input
instruction_ref input_fp16{}; instruction_ref input_fp16{};
if(input->name() == "convert") if(input->name() == "convert" and
input->inputs().front()->get_shape().type() == shape::half_type)
{ {
input_fp16 = input->inputs().front(); input_fp16 = input->inputs().front();
} }
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include <migraphx/operators.hpp> #include <migraphx/operators.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/onnx.hpp>
#include "test.hpp" #include "test.hpp"
...@@ -509,6 +510,32 @@ TEST_CASE(shape_gather_test) ...@@ -509,6 +510,32 @@ TEST_CASE(shape_gather_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(transpose_gather_test)
{
migraphx::program p;
auto make_contiguous = [&p](migraphx::instruction_ref ins) {
if(ins->get_shape().standard())
{
return ins;
}
return p.add_instruction(migraphx::op::contiguous{}, ins);
};
auto data = p.add_parameter("data", migraphx::shape{migraphx::shape::float_type, {3, 5, 4, 6}});
auto ind =
p.add_parameter("indices", migraphx::shape{migraphx::shape::int32_type, {2, 4, 3, 5}});
auto tr_data = p.add_instruction(migraphx::op::transpose{{0, 2, 1, 3}}, data);
auto tr_ind = p.add_instruction(migraphx::op::transpose{{0, 2, 1, 3}}, ind);
int axis = 1;
p.add_instruction(
migraphx::op::gather{axis}, make_contiguous(tr_data), make_contiguous(tr_ind));
auto prog = migraphx::parse_onnx("transpose_gather.onnx");
EXPECT(p == prog);
}
TEST_CASE(flatten_test) TEST_CASE(flatten_test)
{ {
migraphx::program p; migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment