"official/projects/vit/modeling/vit.py" did not exist on "4a0551a85b3b6d99c1470df5c8dee1d9c2ffe248"
Commit 03f0e278 authored by charlie's avatar charlie
Browse files

Fix parsing and add test

parent 22012c6d
......@@ -113,15 +113,19 @@ struct parse_gemm : op_parser<parse_gemm>
make_op("multibroadcast", {{"out_lens", out_lens}}), args[2]);
}
}
if(not float_equal(beta, 1.0f))
{
auto beta_literal = info.add_literal(beta);
auto beta_c = info.add_broadcastable_binary_op("mul", c_arg, beta_literal);
if(beta_c->get_shape().type() != dot_type)
c_arg = info.add_broadcastable_binary_op("mul", c_arg, beta_literal);
if(c_arg->get_shape().type() != dot_type)
{
beta_c = info.add_instruction(make_op("convert", {{"target_type", dot_type}}),
beta_c);
c_arg = info.add_instruction(
make_op("convert", {{"target_type", dot_type}}), c_arg);
}
}
return info.add_instruction(make_op("add"), dot_ins, beta_c);
return info.add_instruction(make_op("add"), dot_ins, c_arg);
}
}
return dot_ins;
......
......@@ -2215,7 +2215,7 @@ def gemm_dyn_outer_test():
@onnx_test()
def gemm_dyn_C_error():
def gemm_dyn_bias_test():
A = helper.make_tensor_value_info('A', TensorProto.FLOAT, [8, None])
B = helper.make_tensor_value_info('B', TensorProto.FLOAT, [8, 7])
C = helper.make_tensor_value_info('C', TensorProto.FLOAT, [1, 7])
......
......@@ -2278,6 +2278,26 @@ TEST_CASE(gemm_dyn_outer_test)
EXPECT(p == prog);
}
TEST_CASE(gemm_dyn_bias_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto x0 =
mm->add_parameter("A", migraphx::shape{migraphx::shape::float_type, {{8, 8}, {1, 10}}});
auto x1 = mm->add_parameter("B", migraphx::shape{migraphx::shape::float_type, {8, 7}});
auto x2 = mm->add_parameter("C", migraphx::shape{migraphx::shape::float_type, {1, 7}});
auto x0_t = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), x0);
auto dot = mm->add_instruction(migraphx::make_op("dot"), x0_t, x1);
auto x2_b = mm->add_instruction(migraphx::make_op("multibroadcast"), x2, dot);
auto ret = mm->add_instruction(migraphx::make_op("add"), dot, x2_b);
mm->add_return({ret});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 10};
auto prog = parse_onnx("gemm_dyn_bias_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(gemm_rank_error)
{
EXPECT(test::throws([&] { migraphx::parse_onnx("gemm_rank_error.onnx"); }));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment