Unverified Commit b73427c9 authored by Chris Austen's avatar Chris Austen Committed by GitHub
Browse files

Merge branch 'develop' into fix_for_multiconfig_generators

parents 55e635e5 4c059fa3
...@@ -117,6 +117,19 @@ struct test_layernorm_fp16 : verify_program<test_layernorm_fp16> ...@@ -117,6 +117,19 @@ struct test_layernorm_fp16 : verify_program<test_layernorm_fp16>
} }
}; };
struct test_layernorm_fp8 : verify_program<test_layernorm_fp8>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<size_t> dims = {1, 24, 64};
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::fp8e4m3fnuz_type, dims});
add_layernorm(*mm, x, dims);
return p;
}
};
struct test_layernorm_eps : verify_program<test_layernorm_eps> struct test_layernorm_eps : verify_program<test_layernorm_eps>
{ {
migraphx::program create_program() const migraphx::program create_program() const
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <limits> #include <limits>
#include <type_traits>
template <migraphx::shape::type_t Q, typename T> template <migraphx::shape::type_t Q, typename T>
struct test_literal_limits : verify_program<test_literal_limits<Q, T>> struct test_literal_limits : verify_program<test_literal_limits<Q, T>>
...@@ -33,9 +34,13 @@ struct test_literal_limits : verify_program<test_literal_limits<Q, T>> ...@@ -33,9 +34,13 @@ struct test_literal_limits : verify_program<test_literal_limits<Q, T>>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto input_s = migraphx::shape(Q, {3, 1}); auto input_s = migraphx::shape(Q, {3, 1});
auto infinity_val = std::numeric_limits<T>::infinity(); T infinity_val{0};
if constexpr(std::numeric_limits<T>::has_infinity and std::is_floating_point<T>{})
{
infinity_val = std::numeric_limits<T>::infinity();
}
std::vector<T> s_data{ std::vector<T> s_data{
infinity_val, static_cast<T>(-infinity_val), std::numeric_limits<T>::quiet_NaN()}; infinity_val, static_cast<T>(-infinity_val), std::numeric_limits<T>::quiet_NaN()};
...@@ -52,3 +57,4 @@ template struct test_literal_limits<migraphx::shape::double_type, double>; ...@@ -52,3 +57,4 @@ template struct test_literal_limits<migraphx::shape::double_type, double>;
template struct test_literal_limits<migraphx::shape::half_type, migraphx::half>; template struct test_literal_limits<migraphx::shape::half_type, migraphx::half>;
template struct test_literal_limits<migraphx::shape::int32_type, int32_t>; template struct test_literal_limits<migraphx::shape::int32_type, int32_t>;
template struct test_literal_limits<migraphx::shape::int8_type, int8_t>; template struct test_literal_limits<migraphx::shape::int8_type, int8_t>;
template struct test_literal_limits<migraphx::shape::fp8e4m3fnuz_type, migraphx::fp8::fp8e4m3fnuz>;
...@@ -27,15 +27,20 @@ ...@@ -27,15 +27,20 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_log : verify_program<test_log> template <migraphx::shape::type_t DType>
struct test_log : verify_program<test_log<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {6}}; migraphx::shape s{DType, {6}};
auto x = mm->add_instruction(migraphx::make_op("abs"), mm->add_parameter("x", s)); auto x = mm->add_instruction(migraphx::make_op("abs"), mm->add_parameter("x", s));
mm->add_instruction(migraphx::make_op("log"), x); mm->add_instruction(migraphx::make_op("log"), x);
return p; return p;
} }
}; };
template struct test_log<migraphx::shape::float_type>;
template struct test_log<migraphx::shape::half_type>;
template struct test_log<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -50,3 +50,7 @@ template struct test_logsoftmax<1, migraphx::shape::half_type>; ...@@ -50,3 +50,7 @@ template struct test_logsoftmax<1, migraphx::shape::half_type>;
template struct test_logsoftmax<0, migraphx::shape::half_type>; template struct test_logsoftmax<0, migraphx::shape::half_type>;
template struct test_logsoftmax<2, migraphx::shape::half_type>; template struct test_logsoftmax<2, migraphx::shape::half_type>;
template struct test_logsoftmax<3, migraphx::shape::half_type>; template struct test_logsoftmax<3, migraphx::shape::half_type>;
template struct test_logsoftmax<0, migraphx::shape::fp8e4m3fnuz_type>;
template struct test_logsoftmax<1, migraphx::shape::fp8e4m3fnuz_type>;
template struct test_logsoftmax<2, migraphx::shape::fp8e4m3fnuz_type>;
template struct test_logsoftmax<3, migraphx::shape::fp8e4m3fnuz_type>;
...@@ -46,7 +46,9 @@ struct test_min_max : verify_program<test_min_max<Op, T>> ...@@ -46,7 +46,9 @@ struct test_min_max : verify_program<test_min_max<Op, T>>
template struct test_min_max<migraphx::op::max, migraphx::shape::float_type>; template struct test_min_max<migraphx::op::max, migraphx::shape::float_type>;
template struct test_min_max<migraphx::op::max, migraphx::shape::half_type>; template struct test_min_max<migraphx::op::max, migraphx::shape::half_type>;
template struct test_min_max<migraphx::op::max, migraphx::shape::double_type>; template struct test_min_max<migraphx::op::max, migraphx::shape::double_type>;
template struct test_min_max<migraphx::op::max, migraphx::shape::fp8e4m3fnuz_type>;
template struct test_min_max<migraphx::op::min, migraphx::shape::float_type>; template struct test_min_max<migraphx::op::min, migraphx::shape::float_type>;
template struct test_min_max<migraphx::op::min, migraphx::shape::half_type>; template struct test_min_max<migraphx::op::min, migraphx::shape::half_type>;
template struct test_min_max<migraphx::op::min, migraphx::shape::double_type>; template struct test_min_max<migraphx::op::min, migraphx::shape::double_type>;
template struct test_min_max<migraphx::op::min, migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,17 +27,17 @@ ...@@ -27,17 +27,17 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_mul_dot_a : verify_program<test_mul_dot_a> template <migraphx::shape::type_t DType>
struct test_mul_dot_a : verify_program<test_mul_dot_a<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape as{migraphx::shape::float_type, {2, 256, 32}}; migraphx::shape as{DType, {2, 256, 32}};
migraphx::shape bs{migraphx::shape::float_type, {2, 32, 128}}; migraphx::shape bs{DType, {2, 32, 128}};
auto a = mm->add_parameter("input", as); auto a = mm->add_parameter("input", as);
auto lit = auto lit = mm->add_literal(migraphx::generate_literal({DType, {1, 1, 32}}));
mm->add_literal(migraphx::generate_literal({migraphx::shape::float_type, {1, 1, 32}}));
auto litb = mm->add_instruction( auto litb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", as.lens()}}), lit); migraphx::make_op("multibroadcast", {{"out_lens", as.lens()}}), lit);
auto mul = mm->add_instruction(migraphx::make_op("mul"), a, litb); auto mul = mm->add_instruction(migraphx::make_op("mul"), a, litb);
...@@ -47,3 +47,7 @@ struct test_mul_dot_a : verify_program<test_mul_dot_a> ...@@ -47,3 +47,7 @@ struct test_mul_dot_a : verify_program<test_mul_dot_a>
return p; return p;
} }
}; };
template struct test_mul_dot_a<migraphx::shape::float_type>;
template struct test_mul_dot_a<migraphx::shape::half_type>;
template struct test_mul_dot_a<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,17 +27,18 @@ ...@@ -27,17 +27,18 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_mul_dot_b : verify_program<test_mul_dot_b> template <migraphx::shape::type_t DType>
struct test_mul_dot_b : verify_program<test_mul_dot_b<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape as{migraphx::shape::float_type, {2, 256, 32}}; migraphx::shape as{DType, {2, 256, 32}};
migraphx::shape bs{migraphx::shape::float_type, {2, 32, 128}}; migraphx::shape bs{DType, {2, 32, 128}};
auto b = mm->add_parameter("input", bs); auto b = mm->add_parameter("input", bs);
auto lit = auto lit = mm->add_literal(migraphx::generate_literal({DType, {1, 32, 1}}));
mm->add_literal(migraphx::generate_literal({migraphx::shape::float_type, {1, 32, 1}}));
auto litb = mm->add_instruction( auto litb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", bs.lens()}}), lit); migraphx::make_op("multibroadcast", {{"out_lens", bs.lens()}}), lit);
auto mul = mm->add_instruction(migraphx::make_op("mul"), b, litb); auto mul = mm->add_instruction(migraphx::make_op("mul"), b, litb);
...@@ -47,3 +48,7 @@ struct test_mul_dot_b : verify_program<test_mul_dot_b> ...@@ -47,3 +48,7 @@ struct test_mul_dot_b : verify_program<test_mul_dot_b>
return p; return p;
} }
}; };
template struct test_mul_dot_b<migraphx::shape::float_type>;
template struct test_mul_dot_b<migraphx::shape::half_type>;
template struct test_mul_dot_b<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,7 +27,8 @@ ...@@ -27,7 +27,8 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_multinomial : verify_program<test_multinomial> template <migraphx::shape::type_t DType>
struct test_multinomial : verify_program<test_multinomial<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -40,10 +41,10 @@ struct test_multinomial : verify_program<test_multinomial> ...@@ -40,10 +41,10 @@ struct test_multinomial : verify_program<test_multinomial>
std::uniform_real_distribution<> dis(0.0, 1.0); std::uniform_real_distribution<> dis(0.0, 1.0);
std::vector<float> rand_samples(batch_size * sample_size); std::vector<float> rand_samples(batch_size * sample_size);
std::generate(rand_samples.begin(), rand_samples.end(), [&]() { return dis(gen); }); std::generate(rand_samples.begin(), rand_samples.end(), [&]() { return dis(gen); });
migraphx::shape rs{migraphx::shape::float_type, {batch_size, sample_size}}; migraphx::shape rs{DType, {batch_size, sample_size}};
auto rs_lit = mm->add_literal(migraphx::literal{rs, rand_samples}); auto rs_lit = mm->add_literal(migraphx::literal{rs, rand_samples});
migraphx::shape s{migraphx::shape::float_type, {batch_size, 5}}; migraphx::shape s{DType, {batch_size, 5}};
auto input = mm->add_parameter("input", s); auto input = mm->add_parameter("input", s);
auto maxes = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {1}}}), input); auto maxes = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {1}}}), input);
...@@ -58,3 +59,8 @@ struct test_multinomial : verify_program<test_multinomial> ...@@ -58,3 +59,8 @@ struct test_multinomial : verify_program<test_multinomial>
return p; return p;
} }
}; };
template struct test_multinomial<migraphx::shape::float_type>;
template struct test_multinomial<migraphx::shape::half_type>;
// This fails, need to figure out why
// template struct test_multinomial<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/float8.hpp>
template <class T> template <class T>
struct test_nearbyint : verify_program<test_nearbyint<T>> struct test_nearbyint : verify_program<test_nearbyint<T>>
...@@ -45,3 +46,4 @@ struct test_nearbyint : verify_program<test_nearbyint<T>> ...@@ -45,3 +46,4 @@ struct test_nearbyint : verify_program<test_nearbyint<T>>
template struct test_nearbyint<migraphx::half>; template struct test_nearbyint<migraphx::half>;
template struct test_nearbyint<float>; template struct test_nearbyint<float>;
template struct test_nearbyint<migraphx::fp8::fp8e4m3fnuz>;
...@@ -27,13 +27,14 @@ ...@@ -27,13 +27,14 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_nonzero : verify_program<test_nonzero> template <migraphx::shape::type_t DType>
struct test_nonzero : verify_program<test_nonzero<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape s{DType, {2, 3, 4, 5}};
auto x = mm->add_parameter("data", s); auto x = mm->add_parameter("data", s);
auto r = mm->add_instruction(migraphx::make_op("nonzero"), x); auto r = mm->add_instruction(migraphx::make_op("nonzero"), x);
mm->add_return({r}); mm->add_return({r});
...@@ -41,3 +42,7 @@ struct test_nonzero : verify_program<test_nonzero> ...@@ -41,3 +42,7 @@ struct test_nonzero : verify_program<test_nonzero>
return p; return p;
} }
}; };
template struct test_nonzero<migraphx::shape::float_type>;
template struct test_nonzero<migraphx::shape::half_type>;
template struct test_nonzero<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,13 +27,14 @@ ...@@ -27,13 +27,14 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_pad : verify_program<test_pad> template <migraphx::shape::type_t DType>
struct test_pad : verify_program<test_pad<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s0{migraphx::shape::int32_type, {1, 96, 165, 165}}; migraphx::shape s0{DType, {1, 96, 165, 165}};
std::vector<int64_t> pads0 = {0, 0, 0, 0, 0, 0, 1, 1}; std::vector<int64_t> pads0 = {0, 0, 0, 0, 0, 0, 1, 1};
std::vector<int64_t> pads1 = {0, 0, 0, 0, 1, 1, 1, 1}; std::vector<int64_t> pads1 = {0, 0, 0, 0, 1, 1, 1, 1};
std::vector<int64_t> pads2 = {1, 1, 1, 1, 0, 0, 0, 0}; std::vector<int64_t> pads2 = {1, 1, 1, 1, 0, 0, 0, 0};
...@@ -46,3 +47,8 @@ struct test_pad : verify_program<test_pad> ...@@ -46,3 +47,8 @@ struct test_pad : verify_program<test_pad>
return p; return p;
} }
}; };
template struct test_pad<migraphx::shape::int32_type>;
template struct test_pad<migraphx::shape::float_type>;
template struct test_pad<migraphx::shape::half_type>;
template struct test_pad<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,13 +27,15 @@ ...@@ -27,13 +27,15 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_pow : verify_program<test_pow> template <typename CType>
struct test_pow : verify_program<test_pow<CType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); migraphx::shape::type_t dtype = migraphx::shape::get_type<CType>();
migraphx::shape s{migraphx::shape::float_type, {6}}; auto* mm = p.get_main_module();
migraphx::shape s{dtype, {6}};
std::vector<float> vec_e(s.elements(), 2.0f); std::vector<float> vec_e(s.elements(), 2.0f);
auto b = mm->add_parameter("x", s); auto b = mm->add_parameter("x", s);
auto e = mm->add_literal(migraphx::literal(s, vec_e)); auto e = mm->add_literal(migraphx::literal(s, vec_e));
...@@ -41,3 +43,6 @@ struct test_pow : verify_program<test_pow> ...@@ -41,3 +43,6 @@ struct test_pow : verify_program<test_pow>
return p; return p;
} }
}; };
template struct test_pow<float>;
template struct test_pow<migraphx::half>;
template struct test_pow<migraphx::fp8::fp8e4m3fnuz>;
...@@ -23,16 +23,18 @@ ...@@ -23,16 +23,18 @@
*/ */
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_prefix_scan_sum_2d_small : verify_program<test_prefix_scan_sum_2d_small> template <migraphx::shape::type_t DType>
struct test_prefix_scan_sum_2d_small : verify_program<test_prefix_scan_sum_2d_small<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {1}}; migraphx::shape s{DType, {1}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto xb = auto xb =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 3}}}), x); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 3}}}), x);
...@@ -42,16 +44,25 @@ struct test_prefix_scan_sum_2d_small : verify_program<test_prefix_scan_sum_2d_sm ...@@ -42,16 +44,25 @@ struct test_prefix_scan_sum_2d_small : verify_program<test_prefix_scan_sum_2d_sm
} }
}; };
struct test_prefix_scan_sum_2d_large : verify_program<test_prefix_scan_sum_2d_large> template struct test_prefix_scan_sum_2d_small<migraphx::shape::float_type>;
template struct test_prefix_scan_sum_2d_small<migraphx::shape::half_type>;
template struct test_prefix_scan_sum_2d_small<migraphx::shape::fp8e4m3fnuz_type>;
template <migraphx::shape::type_t DType>
struct test_prefix_scan_sum_2d_large : verify_program<test_prefix_scan_sum_2d_large<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3, 1000}}; migraphx::shape s{DType, {3, 1000}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
mm->add_instruction( mm->add_instruction(
migraphx::make_op("prefix_scan_sum", {{"axis", 1}, {"exclusive", false}}), x); migraphx::make_op("prefix_scan_sum", {{"axis", 1}, {"exclusive", false}}), x);
return p; return p;
} }
}; };
template struct test_prefix_scan_sum_2d_large<migraphx::shape::float_type>;
template struct test_prefix_scan_sum_2d_large<migraphx::shape::half_type>;
template struct test_prefix_scan_sum_2d_large<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,14 +27,16 @@ ...@@ -27,14 +27,16 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/shape.hpp>
struct test_reduce_add : verify_program<test_reduce_add> template <migraphx::shape::type_t DType>
struct test_reduce_add : verify_program<test_reduce_add<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {4, 1000, 2, 2}}; migraphx::shape s{DType, {4, 1000, 2, 2}};
migraphx::shape bs{migraphx::shape::half_type, {1, 32, 128}}; migraphx::shape bs{migraphx::shape::half_type, {1, 32, 128}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto reduce_mean = auto reduce_mean =
...@@ -46,3 +48,6 @@ struct test_reduce_add : verify_program<test_reduce_add> ...@@ -46,3 +48,6 @@ struct test_reduce_add : verify_program<test_reduce_add>
return p; return p;
}; };
}; };
template struct test_reduce_add<migraphx::shape::float_type>;
template struct test_reduce_add<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -28,14 +28,14 @@ ...@@ -28,14 +28,14 @@
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
struct test_reduce_mean_nhwc : verify_program<test_reduce_mean_nhwc> template <migraphx::shape::type_t DType>
struct test_reduce_mean_nhwc : verify_program<test_reduce_mean_nhwc<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto s = migraphx::shape::from_permutation( auto s = migraphx::shape::from_permutation(DType, {4, 256, 2, 2}, {0, 2, 3, 1});
migraphx::shape::float_type, {4, 256, 2, 2}, {0, 2, 3, 1});
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto reduce = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {1}}}), x); auto reduce = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {1}}}), x);
auto abs = mm->add_instruction(migraphx::make_op("abs"), reduce); auto abs = mm->add_instruction(migraphx::make_op("abs"), reduce);
...@@ -44,3 +44,7 @@ struct test_reduce_mean_nhwc : verify_program<test_reduce_mean_nhwc> ...@@ -44,3 +44,7 @@ struct test_reduce_mean_nhwc : verify_program<test_reduce_mean_nhwc>
return p; return p;
}; };
}; };
template struct test_reduce_mean_nhwc<migraphx::shape::float_type>;
template struct test_reduce_mean_nhwc<migraphx::shape::half_type>;
template struct test_reduce_mean_nhwc<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -51,6 +51,22 @@ template struct test_reduce_op_large<migraphx::op::reduce_min, 1, migraphx::shap ...@@ -51,6 +51,22 @@ template struct test_reduce_op_large<migraphx::op::reduce_min, 1, migraphx::shap
template struct test_reduce_op_large<migraphx::op::reduce_prod, 2, migraphx::shape::float_type>; template struct test_reduce_op_large<migraphx::op::reduce_prod, 2, migraphx::shape::float_type>;
template struct test_reduce_op_large<migraphx::op::reduce_sum, 1, migraphx::shape::float_type>; template struct test_reduce_op_large<migraphx::op::reduce_sum, 1, migraphx::shape::float_type>;
template struct test_reduce_op_large<migraphx::op::reduce_max,
1,
migraphx::shape::fp8e4m3fnuz_type>;
template struct test_reduce_op_large<migraphx::op::reduce_mean,
1,
migraphx::shape::fp8e4m3fnuz_type>;
template struct test_reduce_op_large<migraphx::op::reduce_min,
1,
migraphx::shape::fp8e4m3fnuz_type>;
template struct test_reduce_op_large<migraphx::op::reduce_prod,
2,
migraphx::shape::fp8e4m3fnuz_type>;
template struct test_reduce_op_large<migraphx::op::reduce_sum,
1,
migraphx::shape::fp8e4m3fnuz_type>;
struct test_reduce_mean_1 : verify_program<test_reduce_mean_1> struct test_reduce_mean_1 : verify_program<test_reduce_mean_1>
{ {
migraphx::program create_program() const migraphx::program create_program() const
......
...@@ -46,13 +46,34 @@ struct test_reduce_op_small : verify_program<test_reduce_op_small<Op, Axis, T>> ...@@ -46,13 +46,34 @@ struct test_reduce_op_small : verify_program<test_reduce_op_small<Op, Axis, T>>
}; };
template struct test_reduce_op_small<migraphx::op::reduce_sum, 1, migraphx::shape::float_type>; template struct test_reduce_op_small<migraphx::op::reduce_sum, 1, migraphx::shape::float_type>;
template struct test_reduce_op_small<migraphx::op::reduce_sum, 3, migraphx::shape::float_type>;
template struct test_reduce_op_small<migraphx::op::reduce_sum, 2, migraphx::shape::int32_type>; template struct test_reduce_op_small<migraphx::op::reduce_sum, 2, migraphx::shape::int32_type>;
template struct test_reduce_op_small<migraphx::op::reduce_mean, 2, migraphx::shape::int32_type>; template struct test_reduce_op_small<migraphx::op::reduce_mean, 2, migraphx::shape::int32_type>;
template struct test_reduce_op_small<migraphx::op::reduce_max, 2, migraphx::shape::int32_type>; template struct test_reduce_op_small<migraphx::op::reduce_max, 2, migraphx::shape::int32_type>;
template struct test_reduce_op_small<migraphx::op::reduce_min, 2, migraphx::shape::int32_type>; template struct test_reduce_op_small<migraphx::op::reduce_min, 2, migraphx::shape::int32_type>;
template struct test_reduce_op_small<migraphx::op::reduce_sum, 3, migraphx::shape::half_type>;
template struct test_reduce_op_small<migraphx::op::reduce_sum, 2, migraphx::shape::half_type>; template struct test_reduce_op_small<migraphx::op::reduce_sum, 2, migraphx::shape::half_type>;
template struct test_reduce_op_small<migraphx::op::reduce_mean, 2, migraphx::shape::half_type>; template struct test_reduce_op_small<migraphx::op::reduce_mean, 2, migraphx::shape::half_type>;
template struct test_reduce_op_small<migraphx::op::reduce_max, 2, migraphx::shape::half_type>; template struct test_reduce_op_small<migraphx::op::reduce_max, 2, migraphx::shape::half_type>;
template struct test_reduce_op_small<migraphx::op::reduce_min, 2, migraphx::shape::half_type>; template struct test_reduce_op_small<migraphx::op::reduce_min, 2, migraphx::shape::half_type>;
template struct test_reduce_op_small<migraphx::op::reduce_prod, -2, migraphx::shape::half_type>; template struct test_reduce_op_small<migraphx::op::reduce_prod, -2, migraphx::shape::half_type>;
template struct test_reduce_op_small<migraphx::op::reduce_sum,
2,
migraphx::shape::fp8e4m3fnuz_type>;
template struct test_reduce_op_small<migraphx::op::reduce_sum,
3,
migraphx::shape::fp8e4m3fnuz_type>;
template struct test_reduce_op_small<migraphx::op::reduce_mean,
2,
migraphx::shape::fp8e4m3fnuz_type>;
template struct test_reduce_op_small<migraphx::op::reduce_max,
2,
migraphx::shape::fp8e4m3fnuz_type>;
template struct test_reduce_op_small<migraphx::op::reduce_min,
2,
migraphx::shape::fp8e4m3fnuz_type>;
template struct test_reduce_op_small<migraphx::op::reduce_prod,
-2,
migraphx::shape::fp8e4m3fnuz_type>;
...@@ -26,16 +26,21 @@ ...@@ -26,16 +26,21 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_reverse : verify_program<test_reverse> template <migraphx::shape::type_t DType>
struct test_reverse : verify_program<test_reverse<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {4, 16}}; migraphx::shape s{DType, {4, 16}};
auto a0 = mm->add_parameter("data", s); auto a0 = mm->add_parameter("data", s);
std::vector<int64_t> axis = {0}; std::vector<int64_t> axis = {0};
mm->add_instruction(migraphx::make_op("reverse", {{"axes", axis}}), a0); mm->add_instruction(migraphx::make_op("reverse", {{"axes", axis}}), a0);
return p; return p;
} }
}; };
template struct test_reverse<migraphx::shape::float_type>;
template struct test_reverse<migraphx::shape::half_type>;
template struct test_reverse<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -31,7 +31,8 @@ ...@@ -31,7 +31,8 @@
#include <migraphx/op/common.hpp> #include <migraphx/op/common.hpp>
struct test_rnn_sql_1 : verify_program<test_rnn_sql_1> template <migraphx::shape::type_t DType>
struct test_rnn_sql_1 : verify_program<test_rnn_sql_1<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -44,12 +45,12 @@ struct test_rnn_sql_1 : verify_program<test_rnn_sql_1> ...@@ -44,12 +45,12 @@ struct test_rnn_sql_1 : verify_program<test_rnn_sql_1>
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape in_shape{DType, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}}; migraphx::shape w_shape{DType, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; migraphx::shape r_shape{DType, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; migraphx::shape b_shape{DType, {num_dirct, 2 * hidden_size}};
migraphx::shape s_shape{migraphx::shape::int32_type, {batch_size}}; migraphx::shape s_shape{migraphx::shape::int32_type, {batch_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; migraphx::shape ih_shape{DType, {num_dirct, batch_size, hidden_size}};
auto seq = mm->add_parameter("seq", in_shape); auto seq = mm->add_parameter("seq", in_shape);
auto w = mm->add_parameter("w", w_shape); auto w = mm->add_parameter("w", w_shape);
...@@ -81,3 +82,7 @@ struct test_rnn_sql_1 : verify_program<test_rnn_sql_1> ...@@ -81,3 +82,7 @@ struct test_rnn_sql_1 : verify_program<test_rnn_sql_1>
} }
std::string section() const { return "rnn"; } std::string section() const { return "rnn"; }
}; };
template struct test_rnn_sql_1<migraphx::shape::float_type>;
template struct test_rnn_sql_1<migraphx::shape::half_type>;
template struct test_rnn_sql_1<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,15 +27,16 @@ ...@@ -27,15 +27,16 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_roialign : verify_program<test_roialign> template <migraphx::shape::type_t DType>
struct test_roialign : verify_program<test_roialign<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape x_s{migraphx::shape::float_type, {5, 4, 10, 10}}; migraphx::shape x_s{DType, {5, 4, 10, 10}};
migraphx::shape roi_s{migraphx::shape::float_type, {5, 4}}; migraphx::shape roi_s{DType, {5, 4}};
migraphx::shape ind_s{migraphx::shape::int64_type, {5}}; migraphx::shape ind_s{migraphx::shape::int64_type, {5}};
std::vector<int64_t> ind_vec = {0, 2, 3, 4, 1}; std::vector<int64_t> ind_vec = {0, 2, 3, 4, 1};
...@@ -44,10 +45,10 @@ struct test_roialign : verify_program<test_roialign> ...@@ -44,10 +45,10 @@ struct test_roialign : verify_program<test_roialign>
auto roi = mm->add_parameter("roi", roi_s); auto roi = mm->add_parameter("roi", roi_s);
auto ind = mm->add_literal(migraphx::literal(ind_s, ind_vec)); auto ind = mm->add_literal(migraphx::literal(ind_s, ind_vec));
auto r = mm->add_instruction(migraphx::make_op("roialign", auto r = mm->add_instruction(migraphx::make_op("roialign",
{{"spatial_scale", 1.0}, {{"spatial_scale", 1.0},
{"output_height", 5}, {"output_height", 5},
{"output_width", 5}, {"output_width", 5},
{"sampling_ratio", 2}}), {"sampling_ratio", 2}}),
x, x,
roi, roi,
ind); ind);
...@@ -56,3 +57,7 @@ struct test_roialign : verify_program<test_roialign> ...@@ -56,3 +57,7 @@ struct test_roialign : verify_program<test_roialign>
return p; return p;
} }
}; };
template struct test_roialign<migraphx::shape::float_type>;
template struct test_roialign<migraphx::shape::half_type>;
template struct test_roialign<migraphx::shape::fp8e4m3fnuz_type>;
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment