Unverified Commit 2466dd6f authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Refactor program to module (#684)



* code backup

* clang format

* change corresponding tool files

* clang format
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent de10423f
...@@ -9,12 +9,13 @@ struct test_div : verify_program<test_div> ...@@ -9,12 +9,13 @@ struct test_div : verify_program<test_div>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3}}; migraphx::shape s{migraphx::shape::float_type, {3}};
auto x = p.add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto y = p.add_parameter("y", s); auto y = mm->add_parameter("y", s);
auto z = p.add_parameter("z", s); auto z = mm->add_parameter("z", s);
auto diff = p.add_instruction(migraphx::op::div{}, x, y); auto diff = mm->add_instruction(migraphx::op::div{}, x, y);
p.add_instruction(migraphx::op::div{}, diff, z); mm->add_instruction(migraphx::op::div{}, diff, z);
return p; return p;
} }
}; };
...@@ -9,14 +9,15 @@ struct test_div2 : verify_program<test_div2> ...@@ -9,14 +9,15 @@ struct test_div2 : verify_program<test_div2>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3}}; migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::shape b{migraphx::shape::float_type, {3}}; migraphx::shape b{migraphx::shape::float_type, {3}};
auto x = p.add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto y = p.add_parameter("y", s); auto y = mm->add_parameter("y", s);
auto z = p.add_parameter("z", b); auto z = mm->add_parameter("z", b);
auto zb = p.add_instruction(migraphx::op::broadcast{1, s.lens()}, z); auto zb = mm->add_instruction(migraphx::op::broadcast{1, s.lens()}, z);
auto diff = p.add_instruction(migraphx::op::div{}, x, y); auto diff = mm->add_instruction(migraphx::op::div{}, x, y);
p.add_instruction(migraphx::op::div{}, diff, zb); mm->add_instruction(migraphx::op::div{}, diff, zb);
return p; return p;
} }
}; };
...@@ -9,8 +9,9 @@ struct test_elu : verify_program<test_elu> ...@@ -9,8 +9,9 @@ struct test_elu : verify_program<test_elu>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); auto* mm = p.get_main_module();
p.add_instruction(migraphx::op::leaky_relu{1.0}, x); auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
mm->add_instruction(migraphx::op::leaky_relu{1.0}, x);
return p; return p;
} }
}; };
...@@ -9,12 +9,13 @@ struct test_equal : verify_program<test_equal> ...@@ -9,12 +9,13 @@ struct test_equal : verify_program<test_equal>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::double_type, {2, 3, 4, 6}}; migraphx::shape s{migraphx::shape::double_type, {2, 3, 4, 6}};
auto input1 = p.add_parameter("x", s); auto input1 = mm->add_parameter("x", s);
auto input2 = p.add_parameter("y", s); auto input2 = mm->add_parameter("y", s);
auto r = p.add_instruction(migraphx::op::equal{}, input1, input2); auto r = mm->add_instruction(migraphx::op::equal{}, input1, input2);
p.add_return({r}); mm->add_return({r});
return p; return p;
}; };
}; };
...@@ -9,13 +9,14 @@ struct test_equal_brcst : verify_program<test_equal_brcst> ...@@ -9,13 +9,14 @@ struct test_equal_brcst : verify_program<test_equal_brcst>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s0{migraphx::shape::float_type, {3, 3}}; migraphx::shape s0{migraphx::shape::float_type, {3, 3}};
auto l0 = p.add_parameter("x", s0); auto l0 = mm->add_parameter("x", s0);
migraphx::shape s1{migraphx::shape::float_type, {3, 1}}; migraphx::shape s1{migraphx::shape::float_type, {3, 1}};
auto l1 = p.add_parameter("y", s1); auto l1 = mm->add_parameter("y", s1);
auto bl1 = p.add_instruction(migraphx::op::multibroadcast{s0.lens()}, l1); auto bl1 = mm->add_instruction(migraphx::op::multibroadcast{s0.lens()}, l1);
auto r = p.add_instruction(migraphx::op::equal{}, l0, bl1); auto r = mm->add_instruction(migraphx::op::equal{}, l0, bl1);
p.add_return({r}); mm->add_return({r});
return p; return p;
}; };
......
...@@ -9,9 +9,10 @@ struct test_erf : verify_program<test_erf> ...@@ -9,9 +9,10 @@ struct test_erf : verify_program<test_erf>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 6}}; migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 6}};
auto param = p.add_parameter("x", s); auto param = mm->add_parameter("x", s);
p.add_instruction(migraphx::op::erf{}, param); mm->add_instruction(migraphx::op::erf{}, param);
return p; return p;
} }
}; };
...@@ -9,9 +9,10 @@ struct test_exp : verify_program<test_exp> ...@@ -9,9 +9,10 @@ struct test_exp : verify_program<test_exp>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {6}}; migraphx::shape s{migraphx::shape::float_type, {6}};
auto x = p.add_instruction(migraphx::op::abs{}, p.add_parameter("x", s)); auto x = mm->add_instruction(migraphx::op::abs{}, mm->add_parameter("x", s));
p.add_instruction(migraphx::op::exp{}, x); mm->add_instruction(migraphx::op::exp{}, x);
return p; return p;
} }
}; };
...@@ -9,10 +9,11 @@ struct test_floor : verify_program<test_floor> ...@@ -9,10 +9,11 @@ struct test_floor : verify_program<test_floor>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 6}}; migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 6}};
auto param = p.add_parameter("x", s); auto param = mm->add_parameter("x", s);
p.add_instruction(migraphx::op::floor{}, param); mm->add_instruction(migraphx::op::floor{}, param);
return p; return p;
}; };
}; };
...@@ -10,12 +10,13 @@ struct test_fp32_fp16_add : verify_program<test_fp32_fp16_add> ...@@ -10,12 +10,13 @@ struct test_fp32_fp16_add : verify_program<test_fp32_fp16_add>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3}}; migraphx::shape s{migraphx::shape::float_type, {2, 3}};
auto p1 = p.add_parameter("x", s); auto p1 = mm->add_parameter("x", s);
auto p2 = p.add_parameter("y", s); auto p2 = mm->add_parameter("y", s);
auto sum = p.add_instruction(migraphx::op::add{}, p1, p2); auto sum = mm->add_instruction(migraphx::op::add{}, p1, p2);
auto diff = p.add_instruction(migraphx::op::sub{}, sum, p2); auto diff = mm->add_instruction(migraphx::op::sub{}, sum, p2);
p.add_instruction(migraphx::op::add{}, diff, p1); mm->add_instruction(migraphx::op::add{}, diff, p1);
migraphx::quantize_fp16(p, {"add"}); migraphx::quantize_fp16(p, {"add"});
return p; return p;
......
...@@ -10,12 +10,13 @@ struct test_fp32_fp16_ladd : verify_program<test_fp32_fp16_ladd> ...@@ -10,12 +10,13 @@ struct test_fp32_fp16_ladd : verify_program<test_fp32_fp16_ladd>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3}}; migraphx::shape s{migraphx::shape::float_type, {2, 3}};
std::vector<float> data(2 * 3); std::vector<float> data(2 * 3);
std::iota(data.begin(), data.end(), 1.0f); std::iota(data.begin(), data.end(), 1.0f);
auto l1 = p.add_literal(migraphx::literal(s, data)); auto l1 = mm->add_literal(migraphx::literal(s, data));
auto l2 = p.add_parameter("p2", s); auto l2 = mm->add_parameter("p2", s);
p.add_instruction(migraphx::op::add{}, l1, l2); mm->add_instruction(migraphx::op::add{}, l1, l2);
migraphx::quantize_fp16(p, {"add"}); migraphx::quantize_fp16(p, {"add"});
return p; return p;
}; };
......
...@@ -10,12 +10,13 @@ struct test_fp32_fp16_lall : verify_program<test_fp32_fp16_lall> ...@@ -10,12 +10,13 @@ struct test_fp32_fp16_lall : verify_program<test_fp32_fp16_lall>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3}}; migraphx::shape s{migraphx::shape::float_type, {2, 3}};
std::vector<float> data(2 * 3); std::vector<float> data(2 * 3);
std::iota(data.begin(), data.end(), 1.0f); std::iota(data.begin(), data.end(), 1.0f);
auto l1 = p.add_literal(migraphx::literal(s, data)); auto l1 = mm->add_literal(migraphx::literal(s, data));
auto l2 = p.add_parameter("p2", s); auto l2 = mm->add_parameter("p2", s);
p.add_instruction(migraphx::op::add{}, l1, l2); mm->add_instruction(migraphx::op::add{}, l1, l2);
migraphx::quantize_fp16(p, {"all"}); migraphx::quantize_fp16(p, {"all"});
return p; return p;
}; };
......
...@@ -10,12 +10,13 @@ struct test_fp32_fp16_sub : verify_program<test_fp32_fp16_sub> ...@@ -10,12 +10,13 @@ struct test_fp32_fp16_sub : verify_program<test_fp32_fp16_sub>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3}}; migraphx::shape s{migraphx::shape::float_type, {2, 3}};
auto p1 = p.add_parameter("x", s); auto p1 = mm->add_parameter("x", s);
auto p2 = p.add_parameter("y", s); auto p2 = mm->add_parameter("y", s);
auto sum = p.add_instruction(migraphx::op::add{}, p1, p2); auto sum = mm->add_instruction(migraphx::op::add{}, p1, p2);
auto diff = p.add_instruction(migraphx::op::sub{}, sum, p2); auto diff = mm->add_instruction(migraphx::op::sub{}, sum, p2);
p.add_instruction(migraphx::op::add{}, diff, p1); mm->add_instruction(migraphx::op::add{}, diff, p1);
migraphx::quantize_fp16(p, {"sub"}); migraphx::quantize_fp16(p, {"sub"});
return p; return p;
......
...@@ -9,13 +9,14 @@ struct test_gather : verify_program<test_gather> ...@@ -9,13 +9,14 @@ struct test_gather : verify_program<test_gather>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3, 3}}; migraphx::shape s{migraphx::shape::float_type, {3, 3}};
migraphx::shape s_indices{migraphx::shape::int32_type, {2, 2}}; migraphx::shape s_indices{migraphx::shape::int32_type, {2, 2}};
std::vector<int> indices{1, 2, 2, 1}; std::vector<int> indices{1, 2, 2, 1};
auto a0 = p.add_parameter("data", s); auto a0 = mm->add_parameter("data", s);
auto a1 = p.add_literal(migraphx::literal{s_indices, indices}); auto a1 = mm->add_literal(migraphx::literal{s_indices, indices});
int axis = 0; int axis = 0;
p.add_instruction(migraphx::op::gather{axis}, a0, a1); mm->add_instruction(migraphx::op::gather{axis}, a0, a1);
return p; return p;
} }
}; };
...@@ -9,13 +9,14 @@ struct test_gather_1d_index : verify_program<test_gather_1d_index> ...@@ -9,13 +9,14 @@ struct test_gather_1d_index : verify_program<test_gather_1d_index>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3, 3}}; migraphx::shape s{migraphx::shape::float_type, {3, 3}};
migraphx::shape s_indices{migraphx::shape::int32_type, {1}}; migraphx::shape s_indices{migraphx::shape::int32_type, {1}};
std::vector<int> indices{1}; std::vector<int> indices{1};
auto a0 = p.add_parameter("data", s); auto a0 = mm->add_parameter("data", s);
auto a1 = p.add_literal(migraphx::literal{s_indices, indices}); auto a1 = mm->add_literal(migraphx::literal{s_indices, indices});
int axis = -1; int axis = -1;
p.add_instruction(migraphx::op::gather{axis}, a0, a1); mm->add_instruction(migraphx::op::gather{axis}, a0, a1);
return p; return p;
} }
}; };
...@@ -9,13 +9,14 @@ struct test_gather_neg_axis : verify_program<test_gather_neg_axis> ...@@ -9,13 +9,14 @@ struct test_gather_neg_axis : verify_program<test_gather_neg_axis>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3, 3}}; migraphx::shape s{migraphx::shape::float_type, {3, 3}};
migraphx::shape s_indices{migraphx::shape::int32_type, {2, 2}}; migraphx::shape s_indices{migraphx::shape::int32_type, {2, 2}};
std::vector<int> indices{1, 2, 2, 1}; std::vector<int> indices{1, 2, 2, 1};
auto a0 = p.add_parameter("data", s); auto a0 = mm->add_parameter("data", s);
auto a1 = p.add_literal(migraphx::literal{s_indices, indices}); auto a1 = mm->add_literal(migraphx::literal{s_indices, indices});
int axis = -1; int axis = -1;
p.add_instruction(migraphx::op::gather{axis}, a0, a1); mm->add_instruction(migraphx::op::gather{axis}, a0, a1);
return p; return p;
} }
}; };
...@@ -9,13 +9,14 @@ struct test_gather_neg_indices : verify_program<test_gather_neg_indices> ...@@ -9,13 +9,14 @@ struct test_gather_neg_indices : verify_program<test_gather_neg_indices>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3, 3}}; migraphx::shape s{migraphx::shape::float_type, {3, 3}};
migraphx::shape s_indices{migraphx::shape::int32_type, {2, 2}}; migraphx::shape s_indices{migraphx::shape::int32_type, {2, 2}};
std::vector<int> indices{-2, -1, -1, -2}; std::vector<int> indices{-2, -1, -1, -2};
auto a0 = p.add_parameter("data", s); auto a0 = mm->add_parameter("data", s);
auto a1 = p.add_literal(migraphx::literal{s_indices, indices}); auto a1 = mm->add_literal(migraphx::literal{s_indices, indices});
int axis = -1; int axis = -1;
p.add_instruction(migraphx::op::gather{axis}, a0, a1); mm->add_instruction(migraphx::op::gather{axis}, a0, a1);
return p; return p;
} }
}; };
...@@ -9,13 +9,14 @@ struct test_gather_scalar_index : verify_program<test_gather_scalar_index> ...@@ -9,13 +9,14 @@ struct test_gather_scalar_index : verify_program<test_gather_scalar_index>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3, 3}}; migraphx::shape s{migraphx::shape::float_type, {3, 3}};
migraphx::shape s_indices{migraphx::shape::int32_type}; migraphx::shape s_indices{migraphx::shape::int32_type};
std::vector<int> indices{1}; std::vector<int> indices{1};
auto a0 = p.add_parameter("data", s); auto a0 = mm->add_parameter("data", s);
auto a1 = p.add_literal(migraphx::literal{s_indices, indices}); auto a1 = mm->add_literal(migraphx::literal{s_indices, indices});
int axis = -1; int axis = -1;
p.add_instruction(migraphx::op::gather{axis}, a0, a1); mm->add_instruction(migraphx::op::gather{axis}, a0, a1);
return p; return p;
} }
}; };
...@@ -9,13 +9,14 @@ struct test_gather_scalar_output : verify_program<test_gather_scalar_output> ...@@ -9,13 +9,14 @@ struct test_gather_scalar_output : verify_program<test_gather_scalar_output>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3}}; migraphx::shape s{migraphx::shape::float_type, {3}};
migraphx::shape s_indices{migraphx::shape::int32_type}; migraphx::shape s_indices{migraphx::shape::int32_type};
std::vector<int> indices{1}; std::vector<int> indices{1};
auto a0 = p.add_parameter("data", s); auto a0 = mm->add_parameter("data", s);
auto a1 = p.add_literal(migraphx::literal{s_indices, indices}); auto a1 = mm->add_literal(migraphx::literal{s_indices, indices});
int axis = 0; int axis = 0;
p.add_instruction(migraphx::op::gather{axis}, a0, a1); mm->add_instruction(migraphx::op::gather{axis}, a0, a1);
return p; return p;
} }
}; };
...@@ -9,19 +9,20 @@ struct test_gelu : verify_program<test_gelu> ...@@ -9,19 +9,20 @@ struct test_gelu : verify_program<test_gelu>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
std::vector<size_t> input_lens{1, 1, 5}; std::vector<size_t> input_lens{1, 1, 5};
auto x = p.add_parameter("x", {migraphx::shape::float_type, input_lens}); auto x = mm->add_parameter("x", {migraphx::shape::float_type, input_lens});
auto half = p.add_literal(0.5f); auto half = mm->add_literal(0.5f);
auto one = p.add_literal(1.0f); auto one = mm->add_literal(1.0f);
auto sqrt2 = p.add_literal(static_cast<float>(M_SQRT2)); auto sqrt2 = mm->add_literal(static_cast<float>(M_SQRT2));
auto half_mbcast = p.add_instruction(migraphx::op::multibroadcast{input_lens}, half); auto half_mbcast = mm->add_instruction(migraphx::op::multibroadcast{input_lens}, half);
auto mul_half = p.add_instruction(migraphx::op::mul{}, x, half_mbcast); auto mul_half = mm->add_instruction(migraphx::op::mul{}, x, half_mbcast);
auto sqrt2_mbcast = p.add_instruction(migraphx::op::multibroadcast{input_lens}, sqrt2); auto sqrt2_mbcast = mm->add_instruction(migraphx::op::multibroadcast{input_lens}, sqrt2);
auto div = p.add_instruction(migraphx::op::div{}, x, sqrt2_mbcast); auto div = mm->add_instruction(migraphx::op::div{}, x, sqrt2_mbcast);
auto erf = p.add_instruction(migraphx::op::erf{}, div); auto erf = mm->add_instruction(migraphx::op::erf{}, div);
auto one_mbcast = p.add_instruction(migraphx::op::multibroadcast{input_lens}, one); auto one_mbcast = mm->add_instruction(migraphx::op::multibroadcast{input_lens}, one);
auto add_one = p.add_instruction(migraphx::op::add{}, erf, one_mbcast); auto add_one = mm->add_instruction(migraphx::op::add{}, erf, one_mbcast);
p.add_instruction(migraphx::op::mul{}, mul_half, add_one); mm->add_instruction(migraphx::op::mul{}, mul_half, add_one);
return p; return p;
} }
}; };
...@@ -9,9 +9,10 @@ struct test_gemm : verify_program<test_gemm> ...@@ -9,9 +9,10 @@ struct test_gemm : verify_program<test_gemm>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto a = p.add_parameter("a", migraphx::shape{migraphx::shape::float_type, {4, 5}}); auto* mm = p.get_main_module();
auto b = p.add_parameter("b", migraphx::shape{migraphx::shape::float_type, {5, 3}}); auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {4, 5}});
p.add_instruction(migraphx::op::dot{}, a, b); auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::float_type, {5, 3}});
mm->add_instruction(migraphx::op::dot{}, a, b);
return p; return p;
} }
}; };
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment