Unverified Commit 2466dd6f authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Refactor program to module (#684)



* code backup

* clang format

* change corresponding tool files

* clang format
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent de10423f
...@@ -9,16 +9,17 @@ struct gemm_multi_3args_beta0 : verify_program<gemm_multi_3args_beta0> ...@@ -9,16 +9,17 @@ struct gemm_multi_3args_beta0 : verify_program<gemm_multi_3args_beta0>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::float_type, {1, 2, 3}}; migraphx::shape m1_shape{migraphx::shape::float_type, {1, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 4}}; migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 4}};
migraphx::shape m3_shape{migraphx::shape::float_type, {1, 2, 4}}; migraphx::shape m3_shape{migraphx::shape::float_type, {1, 2, 4}};
auto l1 = p.add_parameter("1", m1_shape); auto l1 = mm->add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape); auto l2 = mm->add_parameter("2", m2_shape);
auto l3 = p.add_parameter("3", m3_shape); auto l3 = mm->add_parameter("3", m3_shape);
float alpha = 1.0f; float alpha = 1.0f;
float beta = 0.0f; float beta = 0.0f;
p.add_instruction(migraphx::op::dot{alpha, beta}, l1, l2, l3); mm->add_instruction(migraphx::op::dot{alpha, beta}, l1, l2, l3);
return p; return p;
} }
......
...@@ -9,16 +9,17 @@ struct gemm_multi_3args_c25 : verify_program<gemm_multi_3args_c25> ...@@ -9,16 +9,17 @@ struct gemm_multi_3args_c25 : verify_program<gemm_multi_3args_c25>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3}}; migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {3, 5}}; migraphx::shape m2_shape{migraphx::shape::float_type, {3, 5}};
migraphx::shape m3_shape{migraphx::shape::float_type, {2, 5}}; migraphx::shape m3_shape{migraphx::shape::float_type, {2, 5}};
auto l1 = p.add_parameter("1", m1_shape); auto l1 = mm->add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape); auto l2 = mm->add_parameter("2", m2_shape);
auto l3 = p.add_parameter("3", m3_shape); auto l3 = mm->add_parameter("3", m3_shape);
float alpha = 0.35; float alpha = 0.35;
float beta = 0.41; float beta = 0.41;
p.add_instruction(migraphx::op::dot{alpha, beta}, l1, l2, l3); mm->add_instruction(migraphx::op::dot{alpha, beta}, l1, l2, l3);
return p; return p;
} }
......
...@@ -9,12 +9,13 @@ struct gemm_multi_dim_2 : verify_program<gemm_multi_dim_2> ...@@ -9,12 +9,13 @@ struct gemm_multi_dim_2 : verify_program<gemm_multi_dim_2>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 2, 3}}; migraphx::shape m1_shape{migraphx::shape::float_type, {2, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 4}}; migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 4}};
auto l1 = p.add_parameter("1", m1_shape); auto l1 = mm->add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape); auto l2 = mm->add_parameter("2", m2_shape);
p.add_instruction(migraphx::op::dot{}, l1, l2); mm->add_instruction(migraphx::op::dot{}, l1, l2);
return p; return p;
} }
......
...@@ -9,12 +9,13 @@ struct gemm_multi_dim_2_3 : verify_program<gemm_multi_dim_2_3> ...@@ -9,12 +9,13 @@ struct gemm_multi_dim_2_3 : verify_program<gemm_multi_dim_2_3>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3, 2, 3}}; migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 2}}; migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 2}};
auto l1 = p.add_parameter("1", m1_shape); auto l1 = mm->add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape); auto l2 = mm->add_parameter("2", m2_shape);
p.add_instruction(migraphx::op::dot{}, l1, l2); mm->add_instruction(migraphx::op::dot{}, l1, l2);
return p; return p;
} }
......
...@@ -9,15 +9,16 @@ struct gemm_multi_transpose : verify_program<gemm_multi_transpose> ...@@ -9,15 +9,16 @@ struct gemm_multi_transpose : verify_program<gemm_multi_transpose>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 2, 3}}; migraphx::shape m1_shape{migraphx::shape::float_type, {2, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {3, 2, 4}}; migraphx::shape m2_shape{migraphx::shape::float_type, {3, 2, 4}};
auto l1 = p.add_parameter("1", m1_shape); auto l1 = mm->add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape); auto l2 = mm->add_parameter("2", m2_shape);
auto tl2 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, l2); auto tl2 = mm->add_instruction(migraphx::op::transpose{{1, 0, 2}}, l2);
float alpha = 1.0f; float alpha = 1.0f;
float beta = 1.0f; float beta = 1.0f;
p.add_instruction(migraphx::op::dot{alpha, beta}, l1, tl2); mm->add_instruction(migraphx::op::dot{alpha, beta}, l1, tl2);
return p; return p;
} }
......
...@@ -9,7 +9,8 @@ ...@@ -9,7 +9,8 @@
inline void check_gpu_streams(const migraphx::program& p) inline void check_gpu_streams(const migraphx::program& p)
{ {
#ifdef HAVE_GPU #ifdef HAVE_GPU
auto races = migraphx::gpu::analyze_streams(p); const auto* mm = p.get_main_module();
auto races = migraphx::gpu::analyze_streams(*mm);
for(auto&& race : races) for(auto&& race : races)
{ {
std::cout << "FAILED: " << std::endl; std::cout << "FAILED: " << std::endl;
...@@ -23,7 +24,7 @@ inline void check_gpu_streams(const migraphx::program& p) ...@@ -23,7 +24,7 @@ inline void check_gpu_streams(const migraphx::program& p)
#endif #endif
} }
void validate_gpu(const migraphx::program& p, const migraphx::program::parameter_map& m) void validate_gpu(const migraphx::program& p, const migraphx::parameter_map& m)
{ {
check_gpu_streams(p); check_gpu_streams(p);
// Program should have an output parameter // Program should have an output parameter
......
...@@ -9,11 +9,12 @@ struct quant_conv : verify_program<quant_conv> ...@@ -9,11 +9,12 @@ struct quant_conv : verify_program<quant_conv>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}}; migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
auto pa = p.add_parameter("a", a_shape); auto pa = mm->add_parameter("a", a_shape);
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}}; migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
auto pc = p.add_parameter("c", c_shape); auto pc = mm->add_parameter("c", c_shape);
p.add_instruction(migraphx::op::quant_convolution{}, pa, pc); mm->add_instruction(migraphx::op::quant_convolution{}, pa, pc);
return p; return p;
} }
}; };
...@@ -9,11 +9,12 @@ struct quant_conv_default_mode : verify_program<quant_conv_default_mode> ...@@ -9,11 +9,12 @@ struct quant_conv_default_mode : verify_program<quant_conv_default_mode>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}}; migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
auto pa = p.add_parameter("a", a_shape); auto pa = mm->add_parameter("a", a_shape);
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}}; migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
auto pc = p.add_parameter("c", c_shape); auto pc = mm->add_parameter("c", c_shape);
p.add_instruction( mm->add_instruction(
migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}, migraphx::op::same}, migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}, migraphx::op::same},
pa, pa,
pc); pc);
......
...@@ -9,11 +9,12 @@ struct quant_conv_padding : verify_program<quant_conv_padding> ...@@ -9,11 +9,12 @@ struct quant_conv_padding : verify_program<quant_conv_padding>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}}; migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
auto pa = p.add_parameter("a", a_shape); auto pa = mm->add_parameter("a", a_shape);
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}}; migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
auto pc = p.add_parameter("c", c_shape); auto pc = mm->add_parameter("c", c_shape);
p.add_instruction(migraphx::op::quant_convolution{{{1, 1}}, {{1, 1}}}, pa, pc); mm->add_instruction(migraphx::op::quant_convolution{{{1, 1}}, {{1, 1}}}, pa, pc);
return p; return p;
} }
}; };
...@@ -9,11 +9,12 @@ struct quant_conv_padding_stride : verify_program<quant_conv_padding_stride> ...@@ -9,11 +9,12 @@ struct quant_conv_padding_stride : verify_program<quant_conv_padding_stride>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}}; migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
auto pa = p.add_parameter("a", a_shape); auto pa = mm->add_parameter("a", a_shape);
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}}; migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
auto pc = p.add_parameter("c", c_shape); auto pc = mm->add_parameter("c", c_shape);
p.add_instruction(migraphx::op::quant_convolution{{{1, 1}}, {{2, 2}}}, pa, pc); mm->add_instruction(migraphx::op::quant_convolution{{{1, 1}}, {{2, 2}}}, pa, pc);
return p; return p;
} }
......
...@@ -9,11 +9,12 @@ struct quant_conv_valid_mode : verify_program<quant_conv_valid_mode> ...@@ -9,11 +9,12 @@ struct quant_conv_valid_mode : verify_program<quant_conv_valid_mode>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}}; migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
auto pa = p.add_parameter("a", a_shape); auto pa = mm->add_parameter("a", a_shape);
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}}; migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
auto pc = p.add_parameter("c", c_shape); auto pc = mm->add_parameter("c", c_shape);
p.add_instruction( mm->add_instruction(
migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}, migraphx::op::valid}, migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}, migraphx::op::valid},
pa, pa,
pc); pc);
......
...@@ -9,14 +9,15 @@ struct quant_dot_3args_1 : verify_program<quant_dot_3args_1> ...@@ -9,14 +9,15 @@ struct quant_dot_3args_1 : verify_program<quant_dot_3args_1>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 8}}; migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 8}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {8, 7}}; migraphx::shape m2_shape{migraphx::shape::int8_type, {8, 7}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}}; migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}};
auto l1 = p.add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = p.add_parameter("b", m2_shape); auto l2 = mm->add_parameter("b", m2_shape);
auto l3 = p.add_parameter("c", m3_shape); auto l3 = mm->add_parameter("c", m3_shape);
p.add_instruction(migraphx::op::quant_dot{}, l1, l2, l3); mm->add_instruction(migraphx::op::quant_dot{}, l1, l2, l3);
return p; return p;
} }
}; };
...@@ -9,15 +9,16 @@ struct quant_dot_3args_2 : verify_program<quant_dot_3args_2> ...@@ -9,15 +9,16 @@ struct quant_dot_3args_2 : verify_program<quant_dot_3args_2>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {8, 2}}; migraphx::shape m1_shape{migraphx::shape::int8_type, {8, 2}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {8, 7}}; migraphx::shape m2_shape{migraphx::shape::int8_type, {8, 7}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}}; migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}};
auto l1 = p.add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto tl1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1); auto tl1 = mm->add_instruction(migraphx::op::transpose{{1, 0}}, l1);
auto l2 = p.add_parameter("b", m2_shape); auto l2 = mm->add_parameter("b", m2_shape);
auto l3 = p.add_parameter("c", m3_shape); auto l3 = mm->add_parameter("c", m3_shape);
p.add_instruction(migraphx::op::quant_dot{1, 3}, tl1, l2, l3); mm->add_instruction(migraphx::op::quant_dot{1, 3}, tl1, l2, l3);
return p; return p;
} }
}; };
...@@ -9,15 +9,16 @@ struct quant_dot_3args_3 : verify_program<quant_dot_3args_3> ...@@ -9,15 +9,16 @@ struct quant_dot_3args_3 : verify_program<quant_dot_3args_3>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 8}}; migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 8}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 8}}; migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 8}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}}; migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}};
auto l1 = p.add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = p.add_parameter("b", m2_shape); auto l2 = mm->add_parameter("b", m2_shape);
auto tl2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l2); auto tl2 = mm->add_instruction(migraphx::op::transpose{{1, 0}}, l2);
auto l3 = p.add_parameter("c", m3_shape); auto l3 = mm->add_parameter("c", m3_shape);
p.add_instruction(migraphx::op::quant_dot{2, 3}, l1, tl2, l3); mm->add_instruction(migraphx::op::quant_dot{2, 3}, l1, tl2, l3);
return p; return p;
} }
}; };
...@@ -9,16 +9,17 @@ struct quant_dot_3args_4 : verify_program<quant_dot_3args_4> ...@@ -9,16 +9,17 @@ struct quant_dot_3args_4 : verify_program<quant_dot_3args_4>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {8, 2}}; migraphx::shape m1_shape{migraphx::shape::int8_type, {8, 2}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 8}}; migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 8}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}}; migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}};
auto l1 = p.add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto tl1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1); auto tl1 = mm->add_instruction(migraphx::op::transpose{{1, 0}}, l1);
auto l2 = p.add_parameter("b", m2_shape); auto l2 = mm->add_parameter("b", m2_shape);
auto tl2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l2); auto tl2 = mm->add_instruction(migraphx::op::transpose{{1, 0}}, l2);
auto l3 = p.add_parameter("c", m3_shape); auto l3 = mm->add_parameter("c", m3_shape);
p.add_instruction(migraphx::op::quant_dot{3, 2}, tl1, tl2, l3); mm->add_instruction(migraphx::op::quant_dot{3, 2}, tl1, tl2, l3);
return p; return p;
} }
}; };
...@@ -71,7 +71,7 @@ target_info run_verify::get_target_info(const std::string& name) const ...@@ -71,7 +71,7 @@ target_info run_verify::get_target_info(const std::string& name) const
void run_verify::validate(const migraphx::target& t, void run_verify::validate(const migraphx::target& t,
const migraphx::program& p, const migraphx::program& p,
const migraphx::program::parameter_map& m) const const migraphx::parameter_map& m) const
{ {
auto ti = get_target_info(t.name()); auto ti = get_target_info(t.name());
if(ti.validate) if(ti.validate)
...@@ -79,22 +79,20 @@ void run_verify::validate(const migraphx::target& t, ...@@ -79,22 +79,20 @@ void run_verify::validate(const migraphx::target& t,
} }
std::vector<migraphx::argument> run_verify::run_ref(migraphx::program p, std::vector<migraphx::argument> run_verify::run_ref(migraphx::program p,
migraphx::program::parameter_map inputs) const migraphx::parameter_map inputs) const
{ {
migraphx::ref::target t{}; migraphx::ref::target t{};
auto_print pp{p, t.name()}; auto_print pp{p, t.name()};
compile_check(p, t); compile_check(p, t);
return p.eval(std::move(inputs)); return p.eval(std::move(inputs));
} }
std::pair<migraphx::program, std::vector<migraphx::argument>> std::pair<migraphx::program, std::vector<migraphx::argument>> run_verify::run_target(
run_verify::run_target(const migraphx::target& t, const migraphx::target& t, migraphx::program p, const migraphx::parameter_map& inputs) const
migraphx::program p,
const migraphx::program::parameter_map& inputs) const
{ {
auto_print pp{p, t.name()}; auto_print pp{p, t.name()};
auto trace_target = migraphx::string_value_of(MIGRAPHX_TRACE_TEST_COMPILE{}); auto trace_target = migraphx::string_value_of(MIGRAPHX_TRACE_TEST_COMPILE{});
compile_check(p, t, (trace_target == t.name())); compile_check(p, t, (trace_target == t.name()));
migraphx::program::parameter_map m; migraphx::parameter_map m;
for(auto&& input : inputs) for(auto&& input : inputs)
{ {
m[input.first] = t.copy_to(input.second); m[input.first] = t.copy_to(input.second);
...@@ -137,7 +135,7 @@ void run_verify::verify(const std::string& name, const migraphx::program& p) con ...@@ -137,7 +135,7 @@ void run_verify::verify(const std::string& name, const migraphx::program& p) con
} }
if(not target_names.empty()) if(not target_names.empty())
{ {
migraphx::program::parameter_map m; migraphx::parameter_map m;
for(auto&& x : p.get_parameter_shapes()) for(auto&& x : p.get_parameter_shapes())
{ {
m[x.first] = migraphx::generate_argument(x.second, get_hash(x.first)); m[x.first] = migraphx::generate_argument(x.second, get_hash(x.first));
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
struct target_info struct target_info
{ {
using validation_function = using validation_function =
std::function<void(const migraphx::program& p, const migraphx::program::parameter_map& m)>; std::function<void(const migraphx::program& p, const migraphx::parameter_map& m)>;
bool parallel = true; bool parallel = true;
validation_function validate; validation_function validate;
}; };
...@@ -16,14 +16,14 @@ struct target_info ...@@ -16,14 +16,14 @@ struct target_info
struct run_verify struct run_verify
{ {
std::vector<migraphx::argument> run_ref(migraphx::program p, std::vector<migraphx::argument> run_ref(migraphx::program p,
migraphx::program::parameter_map inputs) const; migraphx::parameter_map inputs) const;
std::pair<migraphx::program, std::vector<migraphx::argument>> std::pair<migraphx::program, std::vector<migraphx::argument>>
run_target(const migraphx::target& t, run_target(const migraphx::target& t,
migraphx::program p, migraphx::program p,
const migraphx::program::parameter_map& inputs) const; const migraphx::parameter_map& inputs) const;
void validate(const migraphx::target& t, void validate(const migraphx::target& t,
const migraphx::program& p, const migraphx::program& p,
const migraphx::program::parameter_map& m) const; const migraphx::parameter_map& m) const;
void verify(const std::string& name, const migraphx::program& p) const; void verify(const std::string& name, const migraphx::program& p) const;
void run(int argc, const char* argv[]) const; void run(int argc, const char* argv[]) const;
......
...@@ -9,8 +9,9 @@ struct test_abs : verify_program<test_abs> ...@@ -9,8 +9,9 @@ struct test_abs : verify_program<test_abs>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); auto* mm = p.get_main_module();
p.add_instruction(migraphx::op::abs{}, x); auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
mm->add_instruction(migraphx::op::abs{}, x);
return p; return p;
} }
}; };
...@@ -9,9 +9,10 @@ struct test_acos : verify_program<test_acos> ...@@ -9,9 +9,10 @@ struct test_acos : verify_program<test_acos>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::double_type, {16}}; migraphx::shape s{migraphx::shape::double_type, {16}};
auto x = p.add_parameter("x", s); auto x = mm->add_parameter("x", s);
p.add_instruction(migraphx::op::acos{}, x); mm->add_instruction(migraphx::op::acos{}, x);
return p; return p;
} }
}; };
...@@ -9,14 +9,15 @@ struct test_acosh : verify_program<test_acosh> ...@@ -9,14 +9,15 @@ struct test_acosh : verify_program<test_acosh>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {16}}; migraphx::shape s{migraphx::shape::float_type, {16}};
auto x = p.add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto min_val = p.add_literal(1.1f); auto min_val = mm->add_literal(1.1f);
auto max_val = p.add_literal(100.0f); auto max_val = mm->add_literal(100.0f);
min_val = p.add_instruction(migraphx::op::multibroadcast{{16}}, min_val); min_val = mm->add_instruction(migraphx::op::multibroadcast{{16}}, min_val);
max_val = p.add_instruction(migraphx::op::multibroadcast{{16}}, max_val); max_val = mm->add_instruction(migraphx::op::multibroadcast{{16}}, max_val);
auto cx = p.add_instruction(migraphx::op::clip{}, x, min_val, max_val); auto cx = mm->add_instruction(migraphx::op::clip{}, x, min_val, max_val);
p.add_instruction(migraphx::op::acosh{}, cx); mm->add_instruction(migraphx::op::acosh{}, cx);
return p; return p;
} }
}; };
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment