Commit 377021cd authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 343c4cac
...@@ -282,9 +282,9 @@ TEST_CASE(dot_float) ...@@ -282,9 +282,9 @@ TEST_CASE(dot_float)
std::vector<float> vfa(sa.elements(), 0.1f); std::vector<float> vfa(sa.elements(), 0.1f);
auto fa = p.add_literal(migraphx::literal(sa, vfa)); auto fa = p.add_literal(migraphx::literal(sa, vfa));
auto ma = p.add_instruction(migraphx::op::mul{}, fa, pa); auto ma = p.add_instruction(migraphx::op::mul{}, fa, pa);
auto ra = p.add_instruction(migraphx::op::round{}, ma); auto ra = p.add_instruction(migraphx::op::round{}, ma);
auto ca = p.add_instruction(migraphx::op::clip{127.0f, -128.0f}, ra); auto ca = p.add_instruction(migraphx::op::clip{127.0f, -128.0f}, ra);
auto qa = p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, ca); auto qa = p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, ca);
// quantize parameter b to int8 type // quantize parameter b to int8 type
auto insert_loc = std::next(pb); auto insert_loc = std::next(pb);
...@@ -297,15 +297,16 @@ TEST_CASE(dot_float) ...@@ -297,15 +297,16 @@ TEST_CASE(dot_float)
p.insert_instruction(insert_loc, migraphx::op::convert{migraphx::shape::int8_type}, cb); p.insert_instruction(insert_loc, migraphx::op::convert{migraphx::shape::int8_type}, cb);
// quantize parameter c to int32 type // quantize parameter c to int32 type
auto qc = p.insert_instruction(std::next(pc), migraphx::op::convert{migraphx::shape::int32_type}, pc); auto qc = p.insert_instruction(
std::next(pc), migraphx::op::convert{migraphx::shape::int32_type}, pc);
auto qdot = p.add_instruction(migraphx::op::quant_dot{1, 0}, qa, qb); auto qdot = p.add_instruction(migraphx::op::quant_dot{1, 0}, qa, qb);
auto fdot = p.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, qdot); auto fdot = p.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, qdot);
std::vector<float> v_alpha(fdot->get_shape().elements(), 200.0f); std::vector<float> v_alpha(fdot->get_shape().elements(), 200.0f);
auto new_alpha = p.add_literal(migraphx::literal(fdot->get_shape(), v_alpha)); auto new_alpha = p.add_literal(migraphx::literal(fdot->get_shape(), v_alpha));
auto alpha_ab = p.add_instruction(migraphx::op::mul{}, new_alpha, fdot); auto alpha_ab = p.add_instruction(migraphx::op::mul{}, new_alpha, fdot);
std::vector<float> v_beta(pc->get_shape().elements(), 1.5f); std::vector<float> v_beta(pc->get_shape().elements(), 1.5f);
auto beta = p.add_literal(migraphx::literal(pc->get_shape(), v_beta)); auto beta = p.add_literal(migraphx::literal(pc->get_shape(), v_beta));
auto beta_c = p.add_instruction(migraphx::op::mul{}, beta, pc); auto beta_c = p.add_instruction(migraphx::op::mul{}, beta, pc);
p.add_instruction(migraphx::op::add{}, alpha_ab, beta_c); p.add_instruction(migraphx::op::add{}, alpha_ab, beta_c);
...@@ -345,15 +346,16 @@ TEST_CASE(dot_double_2args) ...@@ -345,15 +346,16 @@ TEST_CASE(dot_double_2args)
// quantize parameter a to int8 type, multiply the scale // quantize parameter a to int8 type, multiply the scale
std::vector<float> vfa(sa.elements(), 0.1f); std::vector<float> vfa(sa.elements(), 0.1f);
auto fpa = p.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, pa); auto fpa = p.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, pa);
auto fa = p.add_literal(migraphx::literal({migraphx::shape::float_type, sa.lens()}, vfa)); auto fa = p.add_literal(migraphx::literal({migraphx::shape::float_type, sa.lens()}, vfa));
auto ma = p.add_instruction(migraphx::op::mul{}, fa, fpa); auto ma = p.add_instruction(migraphx::op::mul{}, fa, fpa);
auto ra = p.add_instruction(migraphx::op::round{}, ma); auto ra = p.add_instruction(migraphx::op::round{}, ma);
auto ca = p.add_instruction(migraphx::op::clip{127.0f, -128.0f}, ra); auto ca = p.add_instruction(migraphx::op::clip{127.0f, -128.0f}, ra);
auto qa = p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, ca); auto qa = p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, ca);
// quantize parameter b to int8 type // quantize parameter b to int8 type
auto insert_loc = std::next(pb); auto insert_loc = std::next(pb);
auto fpb = p.insert_instruction(insert_loc, migraphx::op::convert{migraphx::shape::float_type}, pb); auto fpb = p.insert_instruction(
insert_loc, migraphx::op::convert{migraphx::shape::float_type}, pb);
std::vector<float> vfb(sb.elements(), 0.1f); std::vector<float> vfb(sb.elements(), 0.1f);
auto fb = p.add_literal(migraphx::literal({migraphx::shape::float_type, sb.lens()}, vfb)); auto fb = p.add_literal(migraphx::literal({migraphx::shape::float_type, sb.lens()}, vfb));
auto mb = p.insert_instruction(insert_loc, migraphx::op::mul{}, fb, fpb); auto mb = p.insert_instruction(insert_loc, migraphx::op::mul{}, fb, fpb);
...@@ -366,15 +368,14 @@ TEST_CASE(dot_double_2args) ...@@ -366,15 +368,14 @@ TEST_CASE(dot_double_2args)
auto fdot = p.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, qdot); auto fdot = p.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, qdot);
std::vector<float> v_alpha(fdot->get_shape().elements(), 200.0f); std::vector<float> v_alpha(fdot->get_shape().elements(), 200.0f);
auto new_alpha = p.add_literal(migraphx::literal(fdot->get_shape(), v_alpha)); auto new_alpha = p.add_literal(migraphx::literal(fdot->get_shape(), v_alpha));
auto alpha_ab = p.add_instruction(migraphx::op::mul{}, new_alpha, fdot); auto alpha_ab = p.add_instruction(migraphx::op::mul{}, new_alpha, fdot);
p.add_instruction(migraphx::op::convert{migraphx::shape::double_type}, alpha_ab); p.add_instruction(migraphx::op::convert{migraphx::shape::double_type}, alpha_ab);
return p; return p;
}; };
auto p = create_program(); auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{ const std::vector<std::pair<float, float>>& quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}};
{0.1f, 0.0f}, {0.1f, 0.0f}};
migraphx::quantize_int8(p, {"dot"}, quant_params); migraphx::quantize_int8(p, {"dot"}, quant_params);
auto qp = create_int8_quantized_prog(); auto qp = create_int8_quantized_prog();
...@@ -474,7 +475,7 @@ TEST_CASE(dot_large_alpha_beta_int32) ...@@ -474,7 +475,7 @@ TEST_CASE(dot_large_alpha_beta_int32)
std::vector<float> vfa(sa.elements(), 0.1f); std::vector<float> vfa(sa.elements(), 0.1f);
auto fa = p.add_literal(migraphx::literal({migraphx::shape::float_type, sa.lens()}, vfa)); auto fa = p.add_literal(migraphx::literal({migraphx::shape::float_type, sa.lens()}, vfa));
auto conv_a = p.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, pa); auto conv_a = p.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, pa);
auto ma = p.add_instruction(migraphx::op::mul{}, fa, conv_a); auto ma = p.add_instruction(migraphx::op::mul{}, fa, conv_a);
// add the shift // add the shift
std::vector<float> vsa(sa.elements(), 1.0f); std::vector<float> vsa(sa.elements(), 1.0f);
...@@ -488,7 +489,8 @@ TEST_CASE(dot_large_alpha_beta_int32) ...@@ -488,7 +489,8 @@ TEST_CASE(dot_large_alpha_beta_int32)
auto insert_loc = std::next(pb); auto insert_loc = std::next(pb);
std::vector<float> vfb(sb.elements(), 0.1f); std::vector<float> vfb(sb.elements(), 0.1f);
auto fb = p.add_literal(migraphx::literal({migraphx::shape::float_type, sb.lens()}, vfb)); auto fb = p.add_literal(migraphx::literal({migraphx::shape::float_type, sb.lens()}, vfb));
auto conv_b = p.insert_instruction(insert_loc, migraphx::op::convert{migraphx::shape::float_type}, pb); auto conv_b = p.insert_instruction(
insert_loc, migraphx::op::convert{migraphx::shape::float_type}, pb);
auto mb = p.insert_instruction(insert_loc, migraphx::op::mul{}, fb, conv_b); auto mb = p.insert_instruction(insert_loc, migraphx::op::mul{}, fb, conv_b);
auto rb = p.insert_instruction(insert_loc, migraphx::op::round{}, mb); auto rb = p.insert_instruction(insert_loc, migraphx::op::round{}, mb);
auto cb = p.insert_instruction(insert_loc, migraphx::op::clip{127.0f, -128.0f}, rb); auto cb = p.insert_instruction(insert_loc, migraphx::op::clip{127.0f, -128.0f}, rb);
...@@ -501,14 +503,14 @@ TEST_CASE(dot_large_alpha_beta_int32) ...@@ -501,14 +503,14 @@ TEST_CASE(dot_large_alpha_beta_int32)
}; };
auto p = create_program(); auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{{0.1f, 1.0f}, {0.1f, 0.0f}, {0.1f, 100.0f}}; const std::vector<std::pair<float, float>>& quant_params{
{0.1f, 1.0f}, {0.1f, 0.0f}, {0.1f, 100.0f}};
migraphx::quantize_int8(p, {"dot"}, quant_params); migraphx::quantize_int8(p, {"dot"}, quant_params);
auto qp = create_int8_quantized_prog(); auto qp = create_int8_quantized_prog();
EXPECT(p == qp); EXPECT(p == qp);
} }
TEST_CASE(dot_int32) TEST_CASE(dot_int32)
{ {
auto create_program = [] { auto create_program = [] {
...@@ -537,7 +539,7 @@ TEST_CASE(dot_int32) ...@@ -537,7 +539,7 @@ TEST_CASE(dot_int32)
std::vector<float> vfa(sa.elements(), 0.1f); std::vector<float> vfa(sa.elements(), 0.1f);
auto fa = p.add_literal(migraphx::literal({migraphx::shape::float_type, sa.lens()}, vfa)); auto fa = p.add_literal(migraphx::literal({migraphx::shape::float_type, sa.lens()}, vfa));
auto conv_a = p.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, pa); auto conv_a = p.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, pa);
auto ma = p.add_instruction(migraphx::op::mul{}, fa, conv_a); auto ma = p.add_instruction(migraphx::op::mul{}, fa, conv_a);
// add the shift // add the shift
std::vector<float> vsa(sa.elements(), 1.0f); std::vector<float> vsa(sa.elements(), 1.0f);
...@@ -551,7 +553,8 @@ TEST_CASE(dot_int32) ...@@ -551,7 +553,8 @@ TEST_CASE(dot_int32)
auto insert_loc = std::next(pb); auto insert_loc = std::next(pb);
std::vector<float> vfb(sb.elements(), 0.1f); std::vector<float> vfb(sb.elements(), 0.1f);
auto fb = p.add_literal(migraphx::literal({migraphx::shape::float_type, sb.lens()}, vfb)); auto fb = p.add_literal(migraphx::literal({migraphx::shape::float_type, sb.lens()}, vfb));
auto conv_b = p.insert_instruction(insert_loc, migraphx::op::convert{migraphx::shape::float_type}, pb); auto conv_b = p.insert_instruction(
insert_loc, migraphx::op::convert{migraphx::shape::float_type}, pb);
auto mb = p.insert_instruction(insert_loc, migraphx::op::mul{}, fb, conv_b); auto mb = p.insert_instruction(insert_loc, migraphx::op::mul{}, fb, conv_b);
auto rb = p.insert_instruction(insert_loc, migraphx::op::round{}, mb); auto rb = p.insert_instruction(insert_loc, migraphx::op::round{}, mb);
auto cb = p.insert_instruction(insert_loc, migraphx::op::clip{127.0f, -128.0f}, rb); auto cb = p.insert_instruction(insert_loc, migraphx::op::clip{127.0f, -128.0f}, rb);
...@@ -559,22 +562,23 @@ TEST_CASE(dot_int32) ...@@ -559,22 +562,23 @@ TEST_CASE(dot_int32)
p.insert_instruction(insert_loc, migraphx::op::convert{migraphx::shape::int8_type}, cb); p.insert_instruction(insert_loc, migraphx::op::convert{migraphx::shape::int8_type}, cb);
auto qdot = p.add_instruction(migraphx::op::quant_dot{1, 0}, qa, qb); auto qdot = p.add_instruction(migraphx::op::quant_dot{1, 0}, qa, qb);
auto fr = p.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, qdot); auto fr = p.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, qdot);
std::vector<float> v_alpha(fr->get_shape().elements(), 20.0f); std::vector<float> v_alpha(fr->get_shape().elements(), 20.0f);
auto new_alpha = p.add_literal(migraphx::literal(fr->get_shape(), v_alpha)); auto new_alpha = p.add_literal(migraphx::literal(fr->get_shape(), v_alpha));
auto alpha_ab = p.add_instruction(migraphx::op::mul{}, new_alpha, fr); auto alpha_ab = p.add_instruction(migraphx::op::mul{}, new_alpha, fr);
auto fc = p.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, pc); auto fc = p.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, pc);
std::vector<float> v_beta(fc->get_shape().elements(), 5.5f); std::vector<float> v_beta(fc->get_shape().elements(), 5.5f);
auto beta = p.add_literal(migraphx::literal(fc->get_shape(), v_beta)); auto beta = p.add_literal(migraphx::literal(fc->get_shape(), v_beta));
auto beta_c = p.add_instruction(migraphx::op::mul{}, beta, fc); auto beta_c = p.add_instruction(migraphx::op::mul{}, beta, fc);
auto f_res = p.add_instruction(migraphx::op::add{}, alpha_ab, beta_c); auto f_res = p.add_instruction(migraphx::op::add{}, alpha_ab, beta_c);
p.add_instruction(migraphx::op::convert{migraphx::shape::int32_type}, f_res); p.add_instruction(migraphx::op::convert{migraphx::shape::int32_type}, f_res);
return p; return p;
}; };
auto p = create_program(); auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{{0.1f, 1.0f}, {0.1f, 0.0f}, {0.1f, 100.0f}}; const std::vector<std::pair<float, float>>& quant_params{
{0.1f, 1.0f}, {0.1f, 0.0f}, {0.1f, 100.0f}};
migraphx::quantize_int8(p, {"dot"}, quant_params); migraphx::quantize_int8(p, {"dot"}, quant_params);
auto qp = create_int8_quantized_prog(); auto qp = create_int8_quantized_prog();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment