Unverified Commit 2466dd6f authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Refactor program to module (#684)



* code backup

* clang format

* change corresponding tool files

* clang format
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent de10423f
...@@ -13,15 +13,16 @@ struct test_batchnorm_1d : verify_program<test_batchnorm_1d> ...@@ -13,15 +13,16 @@ struct test_batchnorm_1d : verify_program<test_batchnorm_1d>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {batches, channels, size}}; migraphx::shape s{migraphx::shape::float_type, {batches, channels, size}};
migraphx::shape vars{migraphx::shape::float_type, {channels}}; migraphx::shape vars{migraphx::shape::float_type, {channels}};
auto x = p.add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto scale = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1))); auto scale = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 1)));
auto bias = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2))); auto bias = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 2)));
auto mean = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3))); auto mean = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 3)));
auto variance = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4))); auto variance = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 4)));
p.add_instruction(migraphx::op::batch_norm_inference{}, x, scale, bias, mean, variance); mm->add_instruction(migraphx::op::batch_norm_inference{}, x, scale, bias, mean, variance);
return p; return p;
} }
}; };
...@@ -13,15 +13,16 @@ struct test_batchnorm_1d_per_actv : verify_program<test_batchnorm_1d_per_actv> ...@@ -13,15 +13,16 @@ struct test_batchnorm_1d_per_actv : verify_program<test_batchnorm_1d_per_actv>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {batches, channels, d1}}; migraphx::shape s{migraphx::shape::float_type, {batches, channels, d1}};
migraphx::shape vars{migraphx::shape::float_type, {channels, d1}}; migraphx::shape vars{migraphx::shape::float_type, {channels, d1}};
auto x = p.add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto scale = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1))); auto scale = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 1)));
auto bias = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2))); auto bias = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 2)));
auto mean = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3))); auto mean = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 3)));
auto variance = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4))); auto variance = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 4)));
p.add_instruction( mm->add_instruction(
migraphx::op::batch_norm_inference{ migraphx::op::batch_norm_inference{
1.0e-5, 0.96f, migraphx::op::batch_norm_inference::per_activation}, 1.0e-5, 0.96f, migraphx::op::batch_norm_inference::per_activation},
x, x,
......
...@@ -14,15 +14,16 @@ struct test_batchnorm_2d_per_actv : verify_program<test_batchnorm_2d_per_actv> ...@@ -14,15 +14,16 @@ struct test_batchnorm_2d_per_actv : verify_program<test_batchnorm_2d_per_actv>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {batches, channels, d1, d2}}; migraphx::shape s{migraphx::shape::float_type, {batches, channels, d1, d2}};
migraphx::shape vars{migraphx::shape::float_type, {channels, d1, d2}}; migraphx::shape vars{migraphx::shape::float_type, {channels, d1, d2}};
auto x = p.add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto scale = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1))); auto scale = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 1)));
auto bias = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2))); auto bias = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 2)));
auto mean = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3))); auto mean = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 3)));
auto variance = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4))); auto variance = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 4)));
p.add_instruction( mm->add_instruction(
migraphx::op::batch_norm_inference{ migraphx::op::batch_norm_inference{
1.0e-6, 0.9f, migraphx::op::batch_norm_inference::per_activation}, 1.0e-6, 0.9f, migraphx::op::batch_norm_inference::per_activation},
x, x,
......
...@@ -15,15 +15,16 @@ struct test_batchnorm_3d : verify_program<test_batchnorm_3d> ...@@ -15,15 +15,16 @@ struct test_batchnorm_3d : verify_program<test_batchnorm_3d>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {batches, channels, d1, d2, d3}}; migraphx::shape s{migraphx::shape::float_type, {batches, channels, d1, d2, d3}};
migraphx::shape vars{migraphx::shape::float_type, {channels}}; migraphx::shape vars{migraphx::shape::float_type, {channels}};
auto x = p.add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto scale = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1))); auto scale = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 1)));
auto bias = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2))); auto bias = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 2)));
auto mean = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3))); auto mean = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 3)));
auto variance = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4))); auto variance = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 4)));
p.add_instruction(migraphx::op::batch_norm_inference{}, x, scale, bias, mean, variance); mm->add_instruction(migraphx::op::batch_norm_inference{}, x, scale, bias, mean, variance);
return p; return p;
} }
}; };
...@@ -15,15 +15,16 @@ struct test_batchnorm_3d_per_actv : verify_program<test_batchnorm_3d_per_actv> ...@@ -15,15 +15,16 @@ struct test_batchnorm_3d_per_actv : verify_program<test_batchnorm_3d_per_actv>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {batches, channels, d1, d2, d3}}; migraphx::shape s{migraphx::shape::float_type, {batches, channels, d1, d2, d3}};
migraphx::shape vars{migraphx::shape::float_type, {channels, d1, d2, d3}}; migraphx::shape vars{migraphx::shape::float_type, {channels, d1, d2, d3}};
auto x = p.add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto scale = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1))); auto scale = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 1)));
auto bias = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2))); auto bias = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 2)));
auto mean = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3))); auto mean = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 3)));
auto variance = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4))); auto variance = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 4)));
p.add_instruction( mm->add_instruction(
migraphx::op::batch_norm_inference{ migraphx::op::batch_norm_inference{
1.0e-6, 0.8f, migraphx::op::batch_norm_inference::per_activation}, 1.0e-6, 0.8f, migraphx::op::batch_norm_inference::per_activation},
x, x,
......
...@@ -14,15 +14,16 @@ struct test_batchnorm_inference : verify_program<test_batchnorm_inference> ...@@ -14,15 +14,16 @@ struct test_batchnorm_inference : verify_program<test_batchnorm_inference>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {batches, channels, height, width}}; migraphx::shape s{migraphx::shape::float_type, {batches, channels, height, width}};
migraphx::shape vars{migraphx::shape::float_type, {channels}}; migraphx::shape vars{migraphx::shape::float_type, {channels}};
auto x = p.add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto scale = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1))); auto scale = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 1)));
auto bias = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2))); auto bias = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 2)));
auto mean = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3))); auto mean = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 3)));
auto variance = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4))); auto variance = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 4)));
p.add_instruction(migraphx::op::batch_norm_inference{}, x, scale, bias, mean, variance); mm->add_instruction(migraphx::op::batch_norm_inference{}, x, scale, bias, mean, variance);
return p; return p;
} }
}; };
...@@ -14,15 +14,16 @@ struct test_batchnorm_inference_2 : verify_program<test_batchnorm_inference_2> ...@@ -14,15 +14,16 @@ struct test_batchnorm_inference_2 : verify_program<test_batchnorm_inference_2>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {batches, channels, height, width}}; migraphx::shape s{migraphx::shape::float_type, {batches, channels, height, width}};
migraphx::shape vars{migraphx::shape::float_type, {channels}}; migraphx::shape vars{migraphx::shape::float_type, {channels}};
auto x = p.add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto scale = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1))); auto scale = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 1)));
auto bias = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2))); auto bias = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 2)));
auto mean = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3))); auto mean = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 3)));
auto variance = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4))); auto variance = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 4)));
p.add_instruction(migraphx::op::batch_norm_inference{}, x, scale, bias, mean, variance); mm->add_instruction(migraphx::op::batch_norm_inference{}, x, scale, bias, mean, variance);
return p; return p;
} }
}; };
...@@ -9,10 +9,11 @@ struct test_ceil : verify_program<test_ceil> ...@@ -9,10 +9,11 @@ struct test_ceil : verify_program<test_ceil>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::double_type, {2, 3, 4, 6}}; migraphx::shape s{migraphx::shape::double_type, {2, 3, 4, 6}};
auto param = p.add_parameter("x", s); auto param = mm->add_parameter("x", s);
p.add_instruction(migraphx::op::ceil{}, param); mm->add_instruction(migraphx::op::ceil{}, param);
return p; return p;
}; };
}; };
...@@ -9,12 +9,13 @@ struct test_clip : verify_program<test_clip> ...@@ -9,12 +9,13 @@ struct test_clip : verify_program<test_clip>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3}}); auto* mm = p.get_main_module();
auto min_val = p.add_literal(0.0f); auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3}});
auto max_val = p.add_literal(6.0f); auto min_val = mm->add_literal(0.0f);
min_val = p.add_instruction(migraphx::op::multibroadcast{{3}}, min_val); auto max_val = mm->add_literal(6.0f);
max_val = p.add_instruction(migraphx::op::multibroadcast{{3}}, max_val); min_val = mm->add_instruction(migraphx::op::multibroadcast{{3}}, min_val);
p.add_instruction(migraphx::op::clip{}, x, min_val, max_val); max_val = mm->add_instruction(migraphx::op::multibroadcast{{3}}, max_val);
mm->add_instruction(migraphx::op::clip{}, x, min_val, max_val);
return p; return p;
} }
}; };
...@@ -9,14 +9,15 @@ struct test_concat_axis_0 : verify_program<test_concat_axis_0> ...@@ -9,14 +9,15 @@ struct test_concat_axis_0 : verify_program<test_concat_axis_0>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
int axis = 0; int axis = 0;
migraphx::shape s0{migraphx::shape::int32_type, {2, 2}}; migraphx::shape s0{migraphx::shape::int32_type, {2, 2}};
migraphx::shape s1{migraphx::shape::int32_type, {3, 2}}; migraphx::shape s1{migraphx::shape::int32_type, {3, 2}};
migraphx::shape s2{migraphx::shape::int32_type, {1, 2}}; migraphx::shape s2{migraphx::shape::int32_type, {1, 2}};
auto l0 = p.add_parameter("x", s0); auto l0 = mm->add_parameter("x", s0);
auto l1 = p.add_parameter("y", s1); auto l1 = mm->add_parameter("y", s1);
auto l2 = p.add_parameter("z", s2); auto l2 = mm->add_parameter("z", s2);
p.add_instruction(migraphx::op::concat{axis}, l0, l1, l2); mm->add_instruction(migraphx::op::concat{axis}, l0, l1, l2);
return p; return p;
} }
}; };
...@@ -9,14 +9,15 @@ struct test_concat_axis_1 : verify_program<test_concat_axis_1> ...@@ -9,14 +9,15 @@ struct test_concat_axis_1 : verify_program<test_concat_axis_1>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
int axis = 1; int axis = 1;
migraphx::shape s0{migraphx::shape::int32_type, {2, 2}}; migraphx::shape s0{migraphx::shape::int32_type, {2, 2}};
migraphx::shape s1{migraphx::shape::int32_type, {2, 3}}; migraphx::shape s1{migraphx::shape::int32_type, {2, 3}};
migraphx::shape s2{migraphx::shape::int32_type, {2, 1}}; migraphx::shape s2{migraphx::shape::int32_type, {2, 1}};
auto l0 = p.add_parameter("x", s0); auto l0 = mm->add_parameter("x", s0);
auto l1 = p.add_parameter("y", s1); auto l1 = mm->add_parameter("y", s1);
auto l2 = p.add_parameter("z", s2); auto l2 = mm->add_parameter("z", s2);
p.add_instruction(migraphx::op::concat{axis}, l0, l1, l2); mm->add_instruction(migraphx::op::concat{axis}, l0, l1, l2);
return p; return p;
} }
}; };
...@@ -9,14 +9,15 @@ struct test_concat_axis_neg_1 : verify_program<test_concat_axis_neg_1> ...@@ -9,14 +9,15 @@ struct test_concat_axis_neg_1 : verify_program<test_concat_axis_neg_1>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
int axis = -1; int axis = -1;
migraphx::shape s0{migraphx::shape::int32_type, {2, 2}}; migraphx::shape s0{migraphx::shape::int32_type, {2, 2}};
migraphx::shape s1{migraphx::shape::int32_type, {2, 3}}; migraphx::shape s1{migraphx::shape::int32_type, {2, 3}};
migraphx::shape s2{migraphx::shape::int32_type, {2, 1}}; migraphx::shape s2{migraphx::shape::int32_type, {2, 1}};
auto l0 = p.add_parameter("x", s0); auto l0 = mm->add_parameter("x", s0);
auto l1 = p.add_parameter("y", s1); auto l1 = mm->add_parameter("y", s1);
auto l2 = p.add_parameter("z", s2); auto l2 = mm->add_parameter("z", s2);
p.add_instruction(migraphx::op::concat{axis}, l0, l1, l2); mm->add_instruction(migraphx::op::concat{axis}, l0, l1, l2);
return p; return p;
} }
}; };
...@@ -9,15 +9,16 @@ struct test_concat_pooling : verify_program<test_concat_pooling> ...@@ -9,15 +9,16 @@ struct test_concat_pooling : verify_program<test_concat_pooling>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
auto input = auto input =
p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 256, 8, 8}}); mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 256, 8, 8}});
auto transpose = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, input); auto transpose = mm->add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, input);
auto concat = p.add_instruction(migraphx::op::concat{3}, transpose); auto concat = mm->add_instruction(migraphx::op::concat{3}, transpose);
auto concat_t = p.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, concat); auto concat_t = mm->add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, concat);
auto pooling = auto pooling =
p.add_instruction(migraphx::op::pooling{"average", {0, 0}, {1, 1}, {8, 8}}, concat_t); mm->add_instruction(migraphx::op::pooling{"average", {0, 0}, {1, 1}, {8, 8}}, concat_t);
p.add_instruction(migraphx::op::relu{}, pooling); mm->add_instruction(migraphx::op::relu{}, pooling);
return p; return p;
} }
}; };
...@@ -9,18 +9,19 @@ struct test_concat_relu : verify_program<test_concat_relu> ...@@ -9,18 +9,19 @@ struct test_concat_relu : verify_program<test_concat_relu>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
int axis = 0; int axis = 0;
migraphx::shape s0{migraphx::shape::float_type, {2, 2}}; migraphx::shape s0{migraphx::shape::float_type, {2, 2}};
migraphx::shape s1{migraphx::shape::float_type, {3, 2}}; migraphx::shape s1{migraphx::shape::float_type, {3, 2}};
migraphx::shape s2{migraphx::shape::float_type, {1, 2}}; migraphx::shape s2{migraphx::shape::float_type, {1, 2}};
auto l0 = p.add_parameter("x", s0); auto l0 = mm->add_parameter("x", s0);
auto l1 = p.add_parameter("y", s1); auto l1 = mm->add_parameter("y", s1);
auto l2 = p.add_parameter("z", s2); auto l2 = mm->add_parameter("z", s2);
auto r0 = p.add_instruction(migraphx::op::relu{}, l0); auto r0 = mm->add_instruction(migraphx::op::relu{}, l0);
auto r1 = p.add_instruction(migraphx::op::relu{}, l1); auto r1 = mm->add_instruction(migraphx::op::relu{}, l1);
auto r2 = p.add_instruction(migraphx::op::relu{}, l2); auto r2 = mm->add_instruction(migraphx::op::relu{}, l2);
auto c0 = p.add_instruction(migraphx::op::concat{axis}, r0, r1, r2); auto c0 = mm->add_instruction(migraphx::op::concat{axis}, r0, r1, r2);
p.add_instruction(migraphx::op::relu{}, c0); mm->add_instruction(migraphx::op::relu{}, c0);
return p; return p;
} }
}; };
...@@ -9,15 +9,16 @@ struct test_concat_transpose : verify_program<test_concat_transpose> ...@@ -9,15 +9,16 @@ struct test_concat_transpose : verify_program<test_concat_transpose>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
int axis = 1; int axis = 1;
migraphx::shape s0{migraphx::shape::int32_type, {2, 2}}; migraphx::shape s0{migraphx::shape::int32_type, {2, 2}};
migraphx::shape s1{migraphx::shape::int32_type, {3, 2}}; migraphx::shape s1{migraphx::shape::int32_type, {3, 2}};
migraphx::shape s2{migraphx::shape::int32_type, {2, 4}}; migraphx::shape s2{migraphx::shape::int32_type, {2, 4}};
auto l0 = p.add_parameter("x", s0); auto l0 = mm->add_parameter("x", s0);
auto lp1 = p.add_parameter("y", s1); auto lp1 = mm->add_parameter("y", s1);
auto l1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, lp1); auto l1 = mm->add_instruction(migraphx::op::transpose{{1, 0}}, lp1);
auto l2 = p.add_parameter("z", s2); auto l2 = mm->add_parameter("z", s2);
p.add_instruction(migraphx::op::concat{axis}, l0, l1, l2); mm->add_instruction(migraphx::op::concat{axis}, l0, l1, l2);
return p; return p;
} }
}; };
...@@ -9,15 +9,16 @@ struct test_concat_transpose2 : verify_program<test_concat_transpose2> ...@@ -9,15 +9,16 @@ struct test_concat_transpose2 : verify_program<test_concat_transpose2>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
int axis = 1; int axis = 1;
migraphx::shape s0{migraphx::shape::int32_type, {2, 2}}; migraphx::shape s0{migraphx::shape::int32_type, {2, 2}};
migraphx::shape s1{migraphx::shape::int32_type, {2, 3}}; migraphx::shape s1{migraphx::shape::int32_type, {2, 3}};
migraphx::shape s2{migraphx::shape::int32_type, {5, 2}}; migraphx::shape s2{migraphx::shape::int32_type, {5, 2}};
auto l0 = p.add_parameter("x", s0); auto l0 = mm->add_parameter("x", s0);
auto l1 = p.add_parameter("y", s1); auto l1 = mm->add_parameter("y", s1);
auto lp2 = p.add_parameter("z", s2); auto lp2 = mm->add_parameter("z", s2);
auto l2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, lp2); auto l2 = mm->add_instruction(migraphx::op::transpose{{1, 0}}, lp2);
p.add_instruction(migraphx::op::concat{axis}, l0, l1, l2); mm->add_instruction(migraphx::op::concat{axis}, l0, l1, l2);
return p; return p;
} }
}; };
...@@ -9,16 +9,17 @@ struct test_concat_transpose3 : verify_program<test_concat_transpose3> ...@@ -9,16 +9,17 @@ struct test_concat_transpose3 : verify_program<test_concat_transpose3>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
int axis = 1; int axis = 1;
migraphx::shape s0{migraphx::shape::int32_type, {2, 2}}; migraphx::shape s0{migraphx::shape::int32_type, {2, 2}};
migraphx::shape s1{migraphx::shape::int32_type, {3, 2}}; migraphx::shape s1{migraphx::shape::int32_type, {3, 2}};
migraphx::shape s2{migraphx::shape::int32_type, {5, 2}}; migraphx::shape s2{migraphx::shape::int32_type, {5, 2}};
auto l0 = p.add_parameter("x", s0); auto l0 = mm->add_parameter("x", s0);
auto lp1 = p.add_parameter("y", s1); auto lp1 = mm->add_parameter("y", s1);
auto l1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, lp1); auto l1 = mm->add_instruction(migraphx::op::transpose{{1, 0}}, lp1);
auto lp2 = p.add_parameter("z", s2); auto lp2 = mm->add_parameter("z", s2);
auto l2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, lp2); auto l2 = mm->add_instruction(migraphx::op::transpose{{1, 0}}, lp2);
p.add_instruction(migraphx::op::concat{axis}, l0, l1, l2); mm->add_instruction(migraphx::op::concat{axis}, l0, l1, l2);
return p; return p;
} }
}; };
...@@ -10,9 +10,10 @@ struct test_contiguous : verify_program<test_contiguous> ...@@ -10,9 +10,10 @@ struct test_contiguous : verify_program<test_contiguous>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {4, 4, 4, 3}, {48, 4, 1, 16}}; migraphx::shape s{migraphx::shape::float_type, {4, 4, 4, 3}, {48, 4, 1, 16}};
auto x = p.add_parameter("x", s); auto x = mm->add_parameter("x", s);
p.add_instruction(migraphx::op::contiguous{}, x); mm->add_instruction(migraphx::op::contiguous{}, x);
assert(p.get_output_shapes().back().standard()); assert(p.get_output_shapes().back().standard());
return p; return p;
} }
......
...@@ -10,9 +10,10 @@ struct test_contiguous_broadcast : verify_program<test_contiguous_broadcast> ...@@ -10,9 +10,10 @@ struct test_contiguous_broadcast : verify_program<test_contiguous_broadcast>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {1, 2}, {0, 1}}; migraphx::shape s{migraphx::shape::float_type, {1, 2}, {0, 1}};
auto x = p.add_parameter("x", s); auto x = mm->add_parameter("x", s);
p.add_instruction(migraphx::op::contiguous{}, x); mm->add_instruction(migraphx::op::contiguous{}, x);
assert(p.get_output_shapes().back().standard()); assert(p.get_output_shapes().back().standard());
return p; return p;
} }
......
...@@ -10,9 +10,10 @@ struct test_contiguous_broadcast_transpose : verify_program<test_contiguous_broa ...@@ -10,9 +10,10 @@ struct test_contiguous_broadcast_transpose : verify_program<test_contiguous_broa
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {1, 3072, 768}, {0, 1, 3072}}; migraphx::shape s{migraphx::shape::float_type, {1, 3072, 768}, {0, 1, 3072}};
auto x = p.add_parameter("x", s); auto x = mm->add_parameter("x", s);
p.add_instruction(migraphx::op::contiguous{}, x); mm->add_instruction(migraphx::op::contiguous{}, x);
assert(p.get_output_shapes().back().standard()); assert(p.get_output_shapes().back().standard());
return p; return p;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment