Commit 58711bbc authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Temporary changes required to make the BERT model generate correct results

parent 0e8f58b7
......@@ -214,8 +214,17 @@ struct onnx_parser
auto s0 = arg0->get_shape().lens();
auto s1 = arg1->get_shape().lens();
auto out_lens = compute_broadcasted_lens(s0, s1);
auto l0 = prog.add_instruction(op::multibroadcast{out_lens}, arg0);
auto l1 = prog.add_instruction(op::multibroadcast{out_lens}, arg1);
if(!arg0->get_shape().standard())
{
arg0 = prog.add_instruction(op::contiguous{}, arg0);
}
if(!arg1->get_shape().standard())
{
arg1 = prog.add_instruction(op::contiguous{}, arg1);
}
auto l0 = prog.add_instruction(op::multibroadcast{out_lens}, arg0);
auto l1 = prog.add_instruction(op::multibroadcast{out_lens}, arg1);
return prog.add_instruction(x, l0, l1);
}
else
......@@ -634,6 +643,10 @@ struct onnx_parser
auto l3_lens = l3->get_shape().lens();
if(!std::equal(out_lens.begin(), out_lens.end(), l3_lens.begin(), l3_lens.end()))
{
if(!args[2]->get_shape().standard())
{
args[2] = prog.add_instruction(op::contiguous{}, args[2]);
}
l3 = prog.add_instruction(op::multibroadcast{out_lens}, args[2]);
}
return prog.add_instruction(op::dot{alpha, beta}, l1, l2, l3);
......@@ -683,14 +696,30 @@ struct onnx_parser
l1_broadcasted_lens.insert(l1_broadcasted_lens.end(), l1_it, l1_lens.end());
if(l0_lens != l0_broadcasted_lens)
{
if(!l0->get_shape().standard())
{
l0 = prog.add_instruction(op::contiguous{}, l0);
}
bl0 = prog.add_instruction(op::multibroadcast{l0_broadcasted_lens}, l0);
}
if(l1_lens != l1_broadcasted_lens)
{
if(!l1->get_shape().standard())
{
l1 = prog.add_instruction(op::contiguous{}, l1);
}
bl1 = prog.add_instruction(op::multibroadcast{l1_broadcasted_lens}, l1);
}
}
if(!bl1->get_shape().standard())
{
bl1 = prog.add_instruction(op::contiguous{}, bl1);
}
if(!bl0->get_shape().standard())
{
bl0 = prog.add_instruction(op::contiguous{}, bl0);
}
auto dot_res = prog.add_instruction(op::dot{1.0f, 0.0f}, bl0, bl1);
int64_t num_axis = static_cast<int64_t>(dot_res->get_shape().lens().size());
if(is_a_prepended)
......@@ -817,6 +846,10 @@ struct onnx_parser
auto&& perm_vals = attributes["perm"].ints();
perm = std::vector<int64_t>(perm_vals.begin(), perm_vals.end());
}
if(!args.front()->get_shape().standard())
{
args.front() = prog.add_instruction(migraphx::op::contiguous{}, args.front());
}
return prog.add_instruction(migraphx::op::transpose{perm}, args.front());
}
......@@ -1011,6 +1044,10 @@ struct onnx_parser
arg_s.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
auto out_lens = compute_broadcasted_lens(in_lens, dims);
if(!args[0]->get_shape().standard())
{
args[0] = prog.add_instruction(op::contiguous{}, args[0]);
}
return prog.add_instruction(op::multibroadcast{out_lens}, args[0]);
}
......
......@@ -25,7 +25,7 @@ argument gather(hipStream_t stream, argument result, argument arg1, argument arg
arg2.visit([&](auto indices) {
const auto* indices_ptr = device_cast(indices.data());
auto* output_ptr = device_cast(output.data());
gs_launch(stream, nelements)([=](auto i) {
gs_launch(stream, nelements, 256)([=](auto i) {
auto idx = out_comp.multi(i);
idx[axis_index] = indices_ptr[idx[axis_index]];
output_ptr[i] = input[idx];
......
......@@ -170,7 +170,7 @@ rocblas_half to_rocblas_type(half x) { return reinterpret_cast<const rocblas_hal
// Computes the output shape of miopen_gemm from the shapes passed to it.
// The final entry of `inputs` is dropped before validation and before
// delegating to op.compute_shape (presumably it is the preallocated output
// buffer — TODO confirm against the caller).
// NOTE(review): assumes `inputs` is non-empty; `inputs.size() - 1` is not guarded.
shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const
{
// Copy every shape except the last one.
std::vector<shape> input_shapes(inputs.begin(), inputs.begin() + inputs.size() - 1);
// NOTE(review): this listing comes from a diff whose +/- markers were stripped;
// the not_broadcasted() check is presumably the removed line and standard() its
// replacement — confirm against the commit before treating both as live code.
check_shapes{input_shapes}.not_broadcasted();
check_shapes{input_shapes}.standard();
// The wrapped operator decides the resulting shape from the remaining inputs.
return op.compute_shape(input_shapes);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment