Commit c335de61 authored by Shiv's avatar Shiv
Browse files

Allow reshape ops between dq and quant_op

parent 5fe1b075
......@@ -858,9 +858,9 @@ auto skip_broadcasts_converts(Ms... ms)
}
template <class... Ms>
auto skip_broadcasts_transposes_contiguous(Ms... ms)
auto skip_post_dq_ops(Ms... ms)
{
return skip(name("broadcast", "multibroadcast", "contiguous", "transpose"))(ms...);
return skip(name("broadcast", "multibroadcast", "contiguous", "transpose", "reshape"))(ms...);
}
template <class T>
......
......@@ -112,10 +112,10 @@ struct match_find_quantizable_ops
auto matcher() const
{
return match::name(get_quantizable_op_names())(
match::arg(0)(match::skip_broadcasts_transposes_contiguous(
dequantizelinear_op("scale1", "zp1").bind("dq1"))),
match::arg(1)(match::skip_broadcasts_transposes_contiguous(
dequantizelinear_op("scale2", "zp2").bind("dq2"))));
match::arg(0)(
match::skip_post_dq_ops(dequantizelinear_op("scale1", "zp1").bind("dq1"))),
match::arg(1)(
match::skip_post_dq_ops(dequantizelinear_op("scale2", "zp2").bind("dq2"))));
}
void apply(module& m, const match::matcher_result& r) const
......
......@@ -331,10 +331,52 @@ TEST_CASE(dot_transposed)
EXPECT(m1 == m2);
}
TEST_CASE(dot_multi_scale_transposed_broadcasted)
TEST_CASE(dot_reshaped)
{
migraphx::shape sh1{migraphx::shape::float_type, {2, 3, 1280, 1000}};
migraphx::shape sh1{migraphx::shape::float_type, {1280, 1000}};
migraphx::shape sh2{migraphx::shape::float_type, {1024, 1000}};
migraphx::module m1;
{
auto t1 = m1.add_parameter("t1", sh1);
auto t2 = m1.add_parameter("t2", sh2);
auto scale = m1.add_literal(0.5f);
auto zero = m1.add_literal(std::int8_t{0});
auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale, zero);
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero);
auto d2_t = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {1000, 1024}}}), d2);
auto dot = m1.add_instruction(migraphx::make_op("dot"), d1, d2_t);
m1.add_return({dot});
}
migraphx::module m2;
{
auto t1 = m2.add_parameter("t1", sh1);
auto t2 = m2.add_parameter("t2", sh2);
auto scale = m2.add_literal(0.5f);
auto zero = m2.add_literal(std::int8_t{0});
auto q1 = add_quantize_op(m2, "quantizelinear", t1, scale, zero);
auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale, zero);
auto q2_t = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {1000, 1024}}}), q2);
auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2_t);
auto out_scale = add_scale_mul(m2, scale, scale, 1, 1, dot->get_shape().lens());
auto d3 = add_quantize_op(m2, "dequantizelinear", dot, out_scale);
m2.add_return({d3});
}
run_pass(m1);
EXPECT(m1 == m2);
}
TEST_CASE(dot_multi_scale_all_skip_post_dq_ops)
{
migraphx::shape sh1{migraphx::shape::float_type, {2, 3, 1280, 1000}};
migraphx::shape sh2{migraphx::shape::float_type, {1024, 10, 100}};
migraphx::shape sh3{migraphx::shape::float_type, {1280}};
migraphx::shape sh4{migraphx::shape::float_type, {1024}};
......@@ -346,12 +388,13 @@ TEST_CASE(dot_multi_scale_transposed_broadcasted)
auto scale2 = m1.add_literal(migraphx::generate_literal(sh4, 0));
auto zero = m1.add_literal(std::int8_t{0});
auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale1, zero, 2);
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale1, zero, 2);
auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale2, zero, 0);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale2, zero, 0);
auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale1, zero, 2);
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale1, zero, 2);
auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale2, zero, 0);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale2, zero, 0);
auto d2_r = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {1024, 1000}}}), d2);
auto d2_t =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), d2);
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), d2_r);
auto d2_mb = m1.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 1000, 1024}}}), d2_t);
auto dot = m1.add_instruction(migraphx::make_op("dot"), d1, d2_mb);
......@@ -366,10 +409,11 @@ TEST_CASE(dot_multi_scale_transposed_broadcasted)
auto scale2 = m2.add_literal(migraphx::generate_literal(sh4, 0));
auto zero = m2.add_literal(std::int8_t{0});
auto q1 = add_quantize_op(m2, "quantizelinear", t1, scale1, zero, 2);
auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale2, zero, 0);
auto q1 = add_quantize_op(m2, "quantizelinear", t1, scale1, zero, 2);
auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale2, zero, 0);
auto q2_r = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {1024, 1000}}}), q2);
auto q2_t =
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), q2);
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), q2_r);
auto q2_mb = m2.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 1000, 1024}}}), q2_t);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment