Commit 7aee6388 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

formatting

parent 225cd3a4
...@@ -166,7 +166,7 @@ static void remove_contiguous_noops(const std::string& op_name, module& m) ...@@ -166,7 +166,7 @@ static void remove_contiguous_noops(const std::string& op_name, module& m)
{ {
for(auto ins : iterator_for(m)) for(auto ins : iterator_for(m))
{ {
if (ins->name() != op_name) if(ins->name() != op_name)
continue; continue;
if(ins->inputs().front()->get_shape() != ins->get_shape()) if(ins->inputs().front()->get_shape() != ins->get_shape())
continue; continue;
......
...@@ -149,8 +149,8 @@ TEST_CASE(two_transpose_gather) ...@@ -149,8 +149,8 @@ TEST_CASE(two_transpose_gather)
auto ctd = m2.add_instruction(migraphx::make_op("contiguous"), td); auto ctd = m2.add_instruction(migraphx::make_op("contiguous"), td);
auto sd = m2.add_instruction(migraphx::make_op("softmax", {{"axis", 2}}), ctd); auto sd = m2.add_instruction(migraphx::make_op("softmax", {{"axis", 2}}), ctd);
auto csd = m2.add_instruction(migraphx::make_op("contiguous"), sd); auto csd = m2.add_instruction(migraphx::make_op("contiguous"), sd);
auto bd = auto bd = m2.add_instruction(
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), csd); migraphx::make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), csd);
auto cbd = m2.add_instruction(migraphx::make_op("contiguous"), bd); auto cbd = m2.add_instruction(migraphx::make_op("contiguous"), bd);
auto r = m2.add_instruction(migraphx::make_op("gather", {{"axis", 2}}), cbd, ind); auto r = m2.add_instruction(migraphx::make_op("gather", {{"axis", 2}}), cbd, ind);
auto cr = m2.add_instruction(migraphx::make_op("contiguous"), r); auto cr = m2.add_instruction(migraphx::make_op("contiguous"), r);
...@@ -177,9 +177,9 @@ TEST_CASE(standard_reshape) ...@@ -177,9 +177,9 @@ TEST_CASE(standard_reshape)
auto add = m2.add_instruction(migraphx::make_op("add"), data, data); auto add = m2.add_instruction(migraphx::make_op("add"), data, data);
auto ca = m2.add_instruction(migraphx::make_op("contiguous"), add); auto ca = m2.add_instruction(migraphx::make_op("contiguous"), add);
// extra contiguous coming from reshape logic which has "requires_std_shape" attribute // extra contiguous coming from reshape logic which has "requires_std_shape" attribute
auto cb = m2.add_instruction(migraphx::make_op("contiguous"), ca); auto cb = m2.add_instruction(migraphx::make_op("contiguous"), ca);
auto r = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 1, 12, 5}}}), cb); auto r = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 1, 12, 5}}}), cb);
auto cr = m2.add_instruction(migraphx::make_op("contiguous"), r); auto cr = m2.add_instruction(migraphx::make_op("contiguous"), r);
m2.add_return({cr}); m2.add_return({cr});
} }
......
...@@ -108,7 +108,7 @@ TEST_CASE(quant_dot) ...@@ -108,7 +108,7 @@ TEST_CASE(quant_dot)
migraphx::make_op("multibroadcast", {{"out_lens", m3_shape.lens()}}), beta); migraphx::make_op("multibroadcast", {{"out_lens", m3_shape.lens()}}), beta);
auto mul_alloc = m.add_instruction( auto mul_alloc = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(m3_shape)}})); migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(m3_shape)}}));
auto m3_beta = m.add_instruction(make_precompile_op("mul"), l3, beta_broadcast, mul_alloc); auto m3_beta = m.add_instruction(make_precompile_op("mul"), l3, beta_broadcast, mul_alloc);
auto gemm_add = m.add_instruction(make_precompile_op("add"), gemm, m3_beta, output); auto gemm_add = m.add_instruction(make_precompile_op("add"), gemm, m3_beta, output);
m.add_return({gemm_add}); m.add_return({gemm_add});
...@@ -179,8 +179,8 @@ TEST_CASE(quant_dot_trans) ...@@ -179,8 +179,8 @@ TEST_CASE(quant_dot_trans)
auto tl1_alpha_int32 = auto tl1_alpha_int32 =
m.add_instruction(make_precompile_op("mul"), alpha_broadcast, tl1_convert, mul_alloc); m.add_instruction(make_precompile_op("mul"), alpha_broadcast, tl1_convert, mul_alloc);
// convert mul_res to int8 // convert mul_res to int8
auto tl1_alpha_int8_alloc = m.add_instruction(migraphx::make_op( auto tl1_alpha_int8_alloc = m.add_instruction(
"hip::allocate", {{"shape", migraphx::to_value(ts1)}})); migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts1)}}));
auto tl1_alpha_int8 = auto tl1_alpha_int8 =
m.add_instruction(make_precompile_op(migraphx::make_op( m.add_instruction(make_precompile_op(migraphx::make_op(
"convert", {{"target_type", tl1->get_shape().type()}})), "convert", {{"target_type", tl1->get_shape().type()}})),
...@@ -291,7 +291,7 @@ TEST_CASE(quant_dot_pad) ...@@ -291,7 +291,7 @@ TEST_CASE(quant_dot_pad)
m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s3.lens()}}), beta); m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s3.lens()}}), beta);
auto mul_alloc = m.add_instruction( auto mul_alloc = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(s3)}})); migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(s3)}}));
auto m3_beta = m.add_instruction(make_precompile_op("mul"), l3, beta_broadcast, mul_alloc); auto m3_beta = m.add_instruction(make_precompile_op("mul"), l3, beta_broadcast, mul_alloc);
auto gemm_add = m.add_instruction(make_precompile_op("add"), gemm, m3_beta, output); auto gemm_add = m.add_instruction(make_precompile_op("add"), gemm, m3_beta, output);
m.add_return({gemm_add}); m.add_return({gemm_add});
return m; return m;
...@@ -345,14 +345,14 @@ TEST_CASE(quant_dot_trans_pad) ...@@ -345,14 +345,14 @@ TEST_CASE(quant_dot_trans_pad)
auto tl2 = auto tl2 =
m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2); m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2);
migraphx::shape ts2{migraphx::shape::int8_type, {3, 2, 9, 7}}; migraphx::shape ts2{migraphx::shape::int8_type, {3, 2, 9, 7}};
migraphx::instruction_ref ptb{}; migraphx::instruction_ref ptb{};
if(int8_x4) if(int8_x4)
{ {
ptb = m.add_instruction( ptb = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ps2)}})); migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ps2)}}));
} }
auto pb = tl2; auto pb = tl2;
if(int8_x4) if(int8_x4)
{ {
pb = m.add_instruction( pb = m.add_instruction(
...@@ -381,8 +381,8 @@ TEST_CASE(quant_dot_trans_pad) ...@@ -381,8 +381,8 @@ TEST_CASE(quant_dot_trans_pad)
auto tl1_alpha_int32 = auto tl1_alpha_int32 =
m.add_instruction(make_precompile_op("mul"), alpha_broadcast, tl1_convert, mul_alloc); m.add_instruction(make_precompile_op("mul"), alpha_broadcast, tl1_convert, mul_alloc);
// convert mul_res to int8 // convert mul_res to int8
auto tl1_alpha_int8_alloc = m.add_instruction(migraphx::make_op( auto tl1_alpha_int8_alloc = m.add_instruction(
"hip::allocate", {{"shape", migraphx::to_value(ts1)}})); migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts1)}}));
migraphx::instruction_ref pta{}; migraphx::instruction_ref pta{};
if(int8_x4) if(int8_x4)
...@@ -391,11 +391,10 @@ TEST_CASE(quant_dot_trans_pad) ...@@ -391,11 +391,10 @@ TEST_CASE(quant_dot_trans_pad)
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ps1)}})); migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ps1)}}));
} }
auto tl1_alpha_int8 = auto tl1_alpha_int8 = m.add_instruction(
m.add_instruction(make_precompile_op(migraphx::make_op( make_precompile_op(migraphx::make_op("convert", {{"target_type", ts1.type()}})),
"convert", {{"target_type", ts1.type()}})), tl1_alpha_int32,
tl1_alpha_int32, tl1_alpha_int8_alloc);
tl1_alpha_int8_alloc);
auto pa = tl1_alpha_int8; auto pa = tl1_alpha_int8;
if(int8_x4) if(int8_x4)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment