Commit c335de61 authored by Shiv's avatar Shiv
Browse files

Allow reshape ops between dq and quant_op

parent 5fe1b075
...@@ -858,9 +858,9 @@ auto skip_broadcasts_converts(Ms... ms) ...@@ -858,9 +858,9 @@ auto skip_broadcasts_converts(Ms... ms)
} }
template <class... Ms> template <class... Ms>
auto skip_broadcasts_transposes_contiguous(Ms... ms) auto skip_post_dq_ops(Ms... ms)
{ {
return skip(name("broadcast", "multibroadcast", "contiguous", "transpose"))(ms...); return skip(name("broadcast", "multibroadcast", "contiguous", "transpose", "reshape"))(ms...);
} }
template <class T> template <class T>
......
...@@ -112,10 +112,10 @@ struct match_find_quantizable_ops ...@@ -112,10 +112,10 @@ struct match_find_quantizable_ops
auto matcher() const auto matcher() const
{ {
return match::name(get_quantizable_op_names())( return match::name(get_quantizable_op_names())(
match::arg(0)(match::skip_broadcasts_transposes_contiguous( match::arg(0)(
dequantizelinear_op("scale1", "zp1").bind("dq1"))), match::skip_post_dq_ops(dequantizelinear_op("scale1", "zp1").bind("dq1"))),
match::arg(1)(match::skip_broadcasts_transposes_contiguous( match::arg(1)(
dequantizelinear_op("scale2", "zp2").bind("dq2")))); match::skip_post_dq_ops(dequantizelinear_op("scale2", "zp2").bind("dq2"))));
} }
void apply(module& m, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
......
...@@ -331,10 +331,52 @@ TEST_CASE(dot_transposed) ...@@ -331,10 +331,52 @@ TEST_CASE(dot_transposed)
EXPECT(m1 == m2); EXPECT(m1 == m2);
} }
TEST_CASE(dot_multi_scale_transposed_broadcasted) TEST_CASE(dot_reshaped)
{ {
migraphx::shape sh1{migraphx::shape::float_type, {2, 3, 1280, 1000}}; migraphx::shape sh1{migraphx::shape::float_type, {1280, 1000}};
migraphx::shape sh2{migraphx::shape::float_type, {1024, 1000}}; migraphx::shape sh2{migraphx::shape::float_type, {1024, 1000}};
migraphx::module m1;
{
auto t1 = m1.add_parameter("t1", sh1);
auto t2 = m1.add_parameter("t2", sh2);
auto scale = m1.add_literal(0.5f);
auto zero = m1.add_literal(std::int8_t{0});
auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale, zero);
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero);
auto d2_t = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {1000, 1024}}}), d2);
auto dot = m1.add_instruction(migraphx::make_op("dot"), d1, d2_t);
m1.add_return({dot});
}
migraphx::module m2;
{
auto t1 = m2.add_parameter("t1", sh1);
auto t2 = m2.add_parameter("t2", sh2);
auto scale = m2.add_literal(0.5f);
auto zero = m2.add_literal(std::int8_t{0});
auto q1 = add_quantize_op(m2, "quantizelinear", t1, scale, zero);
auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale, zero);
auto q2_t = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {1000, 1024}}}), q2);
auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2_t);
auto out_scale = add_scale_mul(m2, scale, scale, 1, 1, dot->get_shape().lens());
auto d3 = add_quantize_op(m2, "dequantizelinear", dot, out_scale);
m2.add_return({d3});
}
run_pass(m1);
EXPECT(m1 == m2);
}
TEST_CASE(dot_multi_scale_all_skip_post_dq_ops)
{
migraphx::shape sh1{migraphx::shape::float_type, {2, 3, 1280, 1000}};
migraphx::shape sh2{migraphx::shape::float_type, {1024, 10, 100}};
migraphx::shape sh3{migraphx::shape::float_type, {1280}}; migraphx::shape sh3{migraphx::shape::float_type, {1280}};
migraphx::shape sh4{migraphx::shape::float_type, {1024}}; migraphx::shape sh4{migraphx::shape::float_type, {1024}};
...@@ -346,12 +388,13 @@ TEST_CASE(dot_multi_scale_transposed_broadcasted) ...@@ -346,12 +388,13 @@ TEST_CASE(dot_multi_scale_transposed_broadcasted)
auto scale2 = m1.add_literal(migraphx::generate_literal(sh4, 0)); auto scale2 = m1.add_literal(migraphx::generate_literal(sh4, 0));
auto zero = m1.add_literal(std::int8_t{0}); auto zero = m1.add_literal(std::int8_t{0});
auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale1, zero, 2); auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale1, zero, 2);
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale1, zero, 2); auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale1, zero, 2);
auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale2, zero, 0); auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale2, zero, 0);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale2, zero, 0); auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale2, zero, 0);
auto d2_r = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {1024, 1000}}}), d2);
auto d2_t = auto d2_t =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), d2); m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), d2_r);
auto d2_mb = m1.add_instruction( auto d2_mb = m1.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 1000, 1024}}}), d2_t); migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 1000, 1024}}}), d2_t);
auto dot = m1.add_instruction(migraphx::make_op("dot"), d1, d2_mb); auto dot = m1.add_instruction(migraphx::make_op("dot"), d1, d2_mb);
...@@ -366,10 +409,11 @@ TEST_CASE(dot_multi_scale_transposed_broadcasted) ...@@ -366,10 +409,11 @@ TEST_CASE(dot_multi_scale_transposed_broadcasted)
auto scale2 = m2.add_literal(migraphx::generate_literal(sh4, 0)); auto scale2 = m2.add_literal(migraphx::generate_literal(sh4, 0));
auto zero = m2.add_literal(std::int8_t{0}); auto zero = m2.add_literal(std::int8_t{0});
auto q1 = add_quantize_op(m2, "quantizelinear", t1, scale1, zero, 2); auto q1 = add_quantize_op(m2, "quantizelinear", t1, scale1, zero, 2);
auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale2, zero, 0); auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale2, zero, 0);
auto q2_r = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {1024, 1000}}}), q2);
auto q2_t = auto q2_t =
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), q2); m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), q2_r);
auto q2_mb = m2.add_instruction( auto q2_mb = m2.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 1000, 1024}}}), q2_t); migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 1000, 1024}}}), q2_t);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment