Unverified Commit 3bc4a779 authored by Chris Austen's avatar Chris Austen Committed by GitHub
Browse files

Merge branch 'develop' into doc-standard

parents 3053fc95 6a72e8fc
...@@ -27,15 +27,17 @@ ...@@ -27,15 +27,17 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/apply_alpha_beta.hpp> #include <migraphx/apply_alpha_beta.hpp>
struct gemm_multi_3args_alpha0 : verify_program<gemm_multi_3args_alpha0>
template <migraphx::shape::type_t DType>
struct gemm_multi_3args_alpha0 : verify_program<gemm_multi_3args_alpha0<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::float_type, {1, 2, 3}}; migraphx::shape m1_shape{DType, {1, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 4}}; migraphx::shape m2_shape{DType, {1, 3, 4}};
migraphx::shape m3_shape{migraphx::shape::float_type, {1, 2, 4}}; migraphx::shape m3_shape{DType, {1, 2, 4}};
auto l1 = mm->add_parameter("1", m1_shape); auto l1 = mm->add_parameter("1", m1_shape);
auto l2 = mm->add_parameter("2", m2_shape); auto l2 = mm->add_parameter("2", m2_shape);
auto l3 = mm->add_parameter("3", m3_shape); auto l3 = mm->add_parameter("3", m3_shape);
...@@ -46,3 +48,7 @@ struct gemm_multi_3args_alpha0 : verify_program<gemm_multi_3args_alpha0> ...@@ -46,3 +48,7 @@ struct gemm_multi_3args_alpha0 : verify_program<gemm_multi_3args_alpha0>
return p; return p;
} }
}; };
template struct gemm_multi_3args_alpha0<migraphx::shape::float_type>;
template struct gemm_multi_3args_alpha0<migraphx::shape::half_type>;
template struct gemm_multi_3args_alpha0<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -28,15 +28,16 @@ ...@@ -28,15 +28,16 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct gemm_multi_3args_beta0 : verify_program<gemm_multi_3args_beta0> template <migraphx::shape::type_t DType>
struct gemm_multi_3args_beta0 : verify_program<gemm_multi_3args_beta0<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::float_type, {1, 2, 3}}; migraphx::shape m1_shape{DType, {1, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 4}}; migraphx::shape m2_shape{DType, {1, 3, 4}};
migraphx::shape m3_shape{migraphx::shape::float_type, {1, 2, 4}}; migraphx::shape m3_shape{DType, {1, 2, 4}};
auto l1 = mm->add_parameter("1", m1_shape); auto l1 = mm->add_parameter("1", m1_shape);
auto l2 = mm->add_parameter("2", m2_shape); auto l2 = mm->add_parameter("2", m2_shape);
auto l3 = mm->add_parameter("3", m3_shape); auto l3 = mm->add_parameter("3", m3_shape);
...@@ -47,3 +48,7 @@ struct gemm_multi_3args_beta0 : verify_program<gemm_multi_3args_beta0> ...@@ -47,3 +48,7 @@ struct gemm_multi_3args_beta0 : verify_program<gemm_multi_3args_beta0>
return p; return p;
} }
}; };
template struct gemm_multi_3args_beta0<migraphx::shape::float_type>;
template struct gemm_multi_3args_beta0<migraphx::shape::half_type>;
template struct gemm_multi_3args_beta0<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -28,15 +28,16 @@ ...@@ -28,15 +28,16 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct gemm_multi_3args_c25 : verify_program<gemm_multi_3args_c25> template <migraphx::shape::type_t DType>
struct gemm_multi_3args_c25 : verify_program<gemm_multi_3args_c25<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3}}; migraphx::shape m1_shape{DType, {2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {3, 5}}; migraphx::shape m2_shape{DType, {3, 5}};
migraphx::shape m3_shape{migraphx::shape::float_type, {2, 5}}; migraphx::shape m3_shape{DType, {2, 5}};
auto l1 = mm->add_parameter("1", m1_shape); auto l1 = mm->add_parameter("1", m1_shape);
auto l2 = mm->add_parameter("2", m2_shape); auto l2 = mm->add_parameter("2", m2_shape);
...@@ -47,3 +48,7 @@ struct gemm_multi_3args_c25 : verify_program<gemm_multi_3args_c25> ...@@ -47,3 +48,7 @@ struct gemm_multi_3args_c25 : verify_program<gemm_multi_3args_c25>
return p; return p;
} }
}; };
template struct gemm_multi_3args_c25<migraphx::shape::float_type>;
template struct gemm_multi_3args_c25<migraphx::shape::half_type>;
template struct gemm_multi_3args_c25<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,14 +27,15 @@ ...@@ -27,14 +27,15 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct gemm_multi_dim_2 : verify_program<gemm_multi_dim_2> template <migraphx::shape::type_t DType>
struct gemm_multi_dim_2 : verify_program<gemm_multi_dim_2<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 2, 3}}; migraphx::shape m1_shape{DType, {2, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 4}}; migraphx::shape m2_shape{DType, {2, 3, 4}};
auto l1 = mm->add_parameter("1", m1_shape); auto l1 = mm->add_parameter("1", m1_shape);
auto l2 = mm->add_parameter("2", m2_shape); auto l2 = mm->add_parameter("2", m2_shape);
...@@ -43,3 +44,7 @@ struct gemm_multi_dim_2 : verify_program<gemm_multi_dim_2> ...@@ -43,3 +44,7 @@ struct gemm_multi_dim_2 : verify_program<gemm_multi_dim_2>
return p; return p;
} }
}; };
template struct gemm_multi_dim_2<migraphx::shape::float_type>;
template struct gemm_multi_dim_2<migraphx::shape::half_type>;
template struct gemm_multi_dim_2<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,14 +27,15 @@ ...@@ -27,14 +27,15 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct gemm_multi_dim_2_3 : verify_program<gemm_multi_dim_2_3> template <migraphx::shape::type_t DType>
struct gemm_multi_dim_2_3 : verify_program<gemm_multi_dim_2_3<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3, 2, 3}}; migraphx::shape m1_shape{DType, {2, 3, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 2}}; migraphx::shape m2_shape{DType, {2, 3, 3, 2}};
auto l1 = mm->add_parameter("1", m1_shape); auto l1 = mm->add_parameter("1", m1_shape);
auto l2 = mm->add_parameter("2", m2_shape); auto l2 = mm->add_parameter("2", m2_shape);
...@@ -43,3 +44,7 @@ struct gemm_multi_dim_2_3 : verify_program<gemm_multi_dim_2_3> ...@@ -43,3 +44,7 @@ struct gemm_multi_dim_2_3 : verify_program<gemm_multi_dim_2_3>
return p; return p;
} }
}; };
template struct gemm_multi_dim_2_3<migraphx::shape::float_type>;
template struct gemm_multi_dim_2_3<migraphx::shape::half_type>;
template struct gemm_multi_dim_2_3<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -28,14 +28,15 @@ ...@@ -28,14 +28,15 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct gemm_multi_transpose : verify_program<gemm_multi_transpose> template <migraphx::shape::type_t DType>
struct gemm_multi_transpose : verify_program<gemm_multi_transpose<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 2, 3}}; migraphx::shape m1_shape{DType, {2, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {3, 2, 4}}; migraphx::shape m2_shape{DType, {3, 2, 4}};
auto l1 = mm->add_parameter("1", m1_shape); auto l1 = mm->add_parameter("1", m1_shape);
auto l2 = mm->add_parameter("2", m2_shape); auto l2 = mm->add_parameter("2", m2_shape);
auto tl2 = auto tl2 =
...@@ -47,3 +48,7 @@ struct gemm_multi_transpose : verify_program<gemm_multi_transpose> ...@@ -47,3 +48,7 @@ struct gemm_multi_transpose : verify_program<gemm_multi_transpose>
return p; return p;
} }
}; };
template struct gemm_multi_transpose<migraphx::shape::float_type>;
template struct gemm_multi_transpose<migraphx::shape::half_type>;
template struct gemm_multi_transpose<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -77,6 +77,5 @@ int main(int argc, const char* argv[]) ...@@ -77,6 +77,5 @@ int main(int argc, const char* argv[])
"test_split_single_dyn_dim", "test_split_single_dyn_dim",
"test_instancenorm_large_3d<migraphx::shape::float_type>", "test_instancenorm_large_3d<migraphx::shape::float_type>",
"test_instancenorm_large_3d<migraphx::shape::half_type>"}); "test_instancenorm_large_3d<migraphx::shape::half_type>"});
rv.disable_test_for("gpu", {"test_conv_bn_add"});
rv.run(argc, argv); rv.run(argc, argv);
} }
...@@ -27,17 +27,21 @@ ...@@ -27,17 +27,21 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct quant_conv : verify_program<quant_conv> template <migraphx::shape::type_t DType>
struct quant_conv : verify_program<quant_conv<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}}; migraphx::shape a_shape{DType, {2, 3, 4, 4}};
auto pa = mm->add_parameter("a", a_shape); auto pa = mm->add_parameter("a", a_shape);
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}}; migraphx::shape c_shape{DType, {2, 3, 3, 3}};
auto pc = mm->add_parameter("c", c_shape); auto pc = mm->add_parameter("c", c_shape);
mm->add_instruction(migraphx::make_op("quant_convolution"), pa, pc); mm->add_instruction(migraphx::make_op("quant_convolution"), pa, pc);
return p; return p;
} }
}; };
template struct quant_conv<migraphx::shape::int8_type>;
template struct quant_conv<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,17 +27,21 @@ ...@@ -27,17 +27,21 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/op/quant_convolution.hpp> #include <migraphx/op/quant_convolution.hpp>
struct quant_conv_1 : verify_program<quant_conv_1> template <migraphx::shape::type_t DType>
struct quant_conv_1 : verify_program<quant_conv_1<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}}; migraphx::shape a_shape{DType, {2, 3, 4, 4}};
auto pa = mm->add_parameter("a", a_shape); auto pa = mm->add_parameter("a", a_shape);
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}}; migraphx::shape c_shape{DType, {2, 3, 3, 3}};
auto pc = mm->add_parameter("c", c_shape); auto pc = mm->add_parameter("c", c_shape);
mm->add_instruction(migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}}, pa, pc); mm->add_instruction(migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}}, pa, pc);
return p; return p;
} }
}; };
template struct quant_conv_1<migraphx::shape::int8_type>;
template struct quant_conv_1<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,15 +27,16 @@ ...@@ -27,15 +27,16 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct quant_conv_1d : verify_program<quant_conv_1d> template <migraphx::shape::type_t DType>
struct quant_conv_1d : verify_program<quant_conv_1d<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4}}; migraphx::shape a_shape{DType, {2, 3, 4}};
auto pa = mm->add_parameter("a", a_shape); auto pa = mm->add_parameter("a", a_shape);
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3}}; migraphx::shape c_shape{DType, {2, 3, 3}};
auto pc = mm->add_parameter("c", c_shape); auto pc = mm->add_parameter("c", c_shape);
mm->add_instruction( mm->add_instruction(
migraphx::make_op("quant_convolution", migraphx::make_op("quant_convolution",
...@@ -45,3 +46,7 @@ struct quant_conv_1d : verify_program<quant_conv_1d> ...@@ -45,3 +46,7 @@ struct quant_conv_1d : verify_program<quant_conv_1d>
return p; return p;
} }
}; };
template struct quant_conv_1d<migraphx::shape::int8_type>;
// MLIR 1D convolution is not supported in MIGraphX yet. Enable this through MIOpen route later.
// template struct quant_conv_1d<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,17 +27,21 @@ ...@@ -27,17 +27,21 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/op/quant_convolution.hpp> #include <migraphx/op/quant_convolution.hpp>
struct quant_conv_2 : verify_program<quant_conv_2> template <migraphx::shape::type_t DType>
struct quant_conv_2 : verify_program<quant_conv_2<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::int8_type, {16, 16, 4, 4}}; migraphx::shape a_shape{DType, {16, 16, 4, 4}};
auto pa = mm->add_parameter("a", a_shape); auto pa = mm->add_parameter("a", a_shape);
migraphx::shape c_shape{migraphx::shape::int8_type, {16, 16, 3, 3}}; migraphx::shape c_shape{DType, {16, 16, 3, 3}};
auto pc = mm->add_parameter("c", c_shape); auto pc = mm->add_parameter("c", c_shape);
mm->add_instruction(migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}}, pa, pc); mm->add_instruction(migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}}, pa, pc);
return p; return p;
} }
}; };
template struct quant_conv_2<migraphx::shape::int8_type>;
template struct quant_conv_2<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,15 +27,16 @@ ...@@ -27,15 +27,16 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct quant_conv_padding : verify_program<quant_conv_padding> template <migraphx::shape::type_t DType>
struct quant_conv_padding : verify_program<quant_conv_padding<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}}; migraphx::shape a_shape{DType, {2, 3, 4, 4}};
auto pa = mm->add_parameter("a", a_shape); auto pa = mm->add_parameter("a", a_shape);
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}}; migraphx::shape c_shape{DType, {2, 3, 3, 3}};
auto pc = mm->add_parameter("c", c_shape); auto pc = mm->add_parameter("c", c_shape);
mm->add_instruction( mm->add_instruction(
migraphx::make_op("quant_convolution", {{"padding", {1, 1}}, {"stride", {1, 1}}}), migraphx::make_op("quant_convolution", {{"padding", {1, 1}}, {"stride", {1, 1}}}),
...@@ -44,3 +45,6 @@ struct quant_conv_padding : verify_program<quant_conv_padding> ...@@ -44,3 +45,6 @@ struct quant_conv_padding : verify_program<quant_conv_padding>
return p; return p;
} }
}; };
template struct quant_conv_padding<migraphx::shape::int8_type>;
template struct quant_conv_padding<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,15 +27,16 @@ ...@@ -27,15 +27,16 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct quant_conv_padding_stride : verify_program<quant_conv_padding_stride> template <migraphx::shape::type_t DType>
struct quant_conv_padding_stride : verify_program<quant_conv_padding_stride<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}}; migraphx::shape a_shape{DType, {2, 3, 4, 4}};
auto pa = mm->add_parameter("a", a_shape); auto pa = mm->add_parameter("a", a_shape);
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}}; migraphx::shape c_shape{DType, {2, 3, 3, 3}};
auto pc = mm->add_parameter("c", c_shape); auto pc = mm->add_parameter("c", c_shape);
mm->add_instruction( mm->add_instruction(
migraphx::make_op("quant_convolution", {{"padding", {1, 1}}, {"stride", {2, 2}}}), migraphx::make_op("quant_convolution", {{"padding", {1, 1}}, {"stride", {2, 2}}}),
...@@ -45,3 +46,5 @@ struct quant_conv_padding_stride : verify_program<quant_conv_padding_stride> ...@@ -45,3 +46,5 @@ struct quant_conv_padding_stride : verify_program<quant_conv_padding_stride>
return p; return p;
} }
}; };
template struct quant_conv_padding_stride<migraphx::shape::int8_type>;
template struct quant_conv_padding_stride<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -142,7 +142,8 @@ std::vector<migraphx::argument> run_verify::run_ref(migraphx::program p, ...@@ -142,7 +142,8 @@ std::vector<migraphx::argument> run_verify::run_ref(migraphx::program p,
{ {
migraphx::target t = migraphx::make_target("ref"); migraphx::target t = migraphx::make_target("ref");
auto_print pp{p, t.name()}; auto_print pp{p, t.name()};
compile_check(p, t, c_opts); auto trace_target = migraphx::string_value_of(MIGRAPHX_TRACE_TEST_COMPILE{});
compile_check(p, t, c_opts, (trace_target == "ref"));
return p.eval(std::move(inputs)); return p.eval(std::move(inputs));
} }
std::pair<migraphx::program, std::vector<migraphx::argument>> std::pair<migraphx::program, std::vector<migraphx::argument>>
......
This diff is collapsed.
...@@ -29,16 +29,20 @@ ...@@ -29,16 +29,20 @@
#include <cassert> #include <cassert>
struct test_contiguous : verify_program<test_contiguous> template <migraphx::shape::type_t DType>
struct test_contiguous : verify_program<test_contiguous<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {4, 4, 4, 3}, {48, 4, 1, 16}}; migraphx::shape s{DType, {4, 4, 4, 3}, {48, 4, 1, 16}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
mm->add_instruction(migraphx::make_op("contiguous"), x); mm->add_instruction(migraphx::make_op("contiguous"), x);
assert(p.get_output_shapes().back().standard()); assert(p.get_output_shapes().back().standard());
return p; return p;
} }
}; };
template struct test_contiguous<migraphx::shape::float_type>;
template struct test_contiguous<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,17 +27,19 @@ ...@@ -27,17 +27,19 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_conv : verify_program<test_conv> template <migraphx::shape::type_t DType>
struct test_conv : verify_program<test_conv<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto input = auto input = mm->add_parameter("x", migraphx::shape{DType, {4, 3, 3, 3}});
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); auto weights = mm->add_parameter("w", migraphx::shape{DType, {4, 3, 3, 3}});
auto weights =
mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
mm->add_instruction(migraphx::make_op("convolution"), input, weights); mm->add_instruction(migraphx::make_op("convolution"), input, weights);
return p; return p;
} }
}; };
template struct test_conv<migraphx::shape::float_type>;
template struct test_conv<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,16 +27,15 @@ ...@@ -27,16 +27,15 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_conv2 : verify_program<test_conv2> template <migraphx::shape::type_t DType>
struct test_conv2 : verify_program<test_conv2<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto input = auto input = mm->add_parameter("x", migraphx::shape{DType, {1, 512, 28, 28}});
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 512, 28, 28}}); auto weights = mm->add_parameter("w", migraphx::shape{DType, {256, 512, 1, 1}});
auto weights =
mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {256, 512, 1, 1}});
mm->add_instruction( mm->add_instruction(
migraphx::make_op("convolution", migraphx::make_op("convolution",
{{"padding", {0, 0}}, {"stride", {1, 1}}, {"dilation", {1, 1}}}), {{"padding", {0, 0}}, {"stride", {1, 1}}, {"dilation", {1, 1}}}),
...@@ -45,3 +44,5 @@ struct test_conv2 : verify_program<test_conv2> ...@@ -45,3 +44,5 @@ struct test_conv2 : verify_program<test_conv2>
return p; return p;
} }
}; };
template struct test_conv2<migraphx::shape::float_type>;
template struct test_conv2<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,18 +27,17 @@ ...@@ -27,18 +27,17 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_conv_add : verify_program<test_conv_add> template <migraphx::shape::type_t DType>
struct test_conv_add : verify_program<test_conv_add<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", {migraphx::shape::float_type, {1, 8, 4, 4}}); auto x = mm->add_parameter("x", {DType, {1, 8, 4, 4}});
auto w = mm->add_literal( auto w = mm->add_literal(migraphx::generate_literal({DType, {2, 8, 3, 3}}, 1));
migraphx::generate_literal({migraphx::shape::float_type, {2, 8, 3, 3}}, 1)); auto y = mm->add_parameter("y", {DType, {1, 8, 4, 4}});
auto y = mm->add_parameter("y", {migraphx::shape::float_type, {1, 8, 4, 4}}); auto v = mm->add_literal(migraphx::generate_literal({DType, {2, 8, 3, 3}}, 2));
auto v = mm->add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {2, 8, 3, 3}}, 2));
auto conv1 = mm->add_instruction(migraphx::make_op("convolution"), x, w); auto conv1 = mm->add_instruction(migraphx::make_op("convolution"), x, w);
auto conv2 = mm->add_instruction(migraphx::make_op("convolution"), y, v); auto conv2 = mm->add_instruction(migraphx::make_op("convolution"), y, v);
auto sum = mm->add_instruction(migraphx::make_op("add"), conv1, conv2); auto sum = mm->add_instruction(migraphx::make_op("add"), conv1, conv2);
...@@ -46,3 +45,6 @@ struct test_conv_add : verify_program<test_conv_add> ...@@ -46,3 +45,6 @@ struct test_conv_add : verify_program<test_conv_add>
return p; return p;
} }
}; };
template struct test_conv_add<migraphx::shape::float_type>;
template struct test_conv_add<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,18 +27,17 @@ ...@@ -27,18 +27,17 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_conv_add_1x1_diff_strides : verify_program<test_conv_add_1x1_diff_strides> template <migraphx::shape::type_t DType>
struct test_conv_add_1x1_diff_strides : verify_program<test_conv_add_1x1_diff_strides<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", {migraphx::shape::float_type, {1, 8, 2, 2}}); auto x = mm->add_parameter("x", {DType, {1, 8, 2, 2}});
auto w = mm->add_literal( auto w = mm->add_literal(migraphx::generate_literal({DType, {2, 8, 1, 1}}, 1));
migraphx::generate_literal({migraphx::shape::float_type, {2, 8, 1, 1}}, 1)); auto y = mm->add_parameter("y", {DType, {1, 8, 4, 4}});
auto y = mm->add_parameter("y", {migraphx::shape::float_type, {1, 8, 4, 4}}); auto v = mm->add_literal(migraphx::generate_literal({DType, {2, 8, 1, 1}}, 2));
auto v = mm->add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {2, 8, 1, 1}}, 2));
auto conv1 = mm->add_instruction(migraphx::make_op("convolution"), x, w); auto conv1 = mm->add_instruction(migraphx::make_op("convolution"), x, w);
auto conv2 = mm->add_instruction( auto conv2 = mm->add_instruction(
migraphx::make_op("convolution", {{"padding", {0, 0}}, {"stride", {2, 2}}}), y, v); migraphx::make_op("convolution", {{"padding", {0, 0}}, {"stride", {2, 2}}}), y, v);
...@@ -47,3 +46,6 @@ struct test_conv_add_1x1_diff_strides : verify_program<test_conv_add_1x1_diff_st ...@@ -47,3 +46,6 @@ struct test_conv_add_1x1_diff_strides : verify_program<test_conv_add_1x1_diff_st
return p; return p;
} }
}; };
template struct test_conv_add_1x1_diff_strides<migraphx::shape::float_type>;
template struct test_conv_add_1x1_diff_strides<migraphx::shape::fp8e4m3fnuz_type>;
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment