Unverified Commit 3bc4a779 authored by Chris Austen's avatar Chris Austen Committed by GitHub
Browse files

Merge branch 'develop' into doc-standard

parents 3053fc95 6a72e8fc
......@@ -27,17 +27,17 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_mul_dot_a : verify_program<test_mul_dot_a>
template <migraphx::shape::type_t DType>
struct test_mul_dot_a : verify_program<test_mul_dot_a<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape as{migraphx::shape::float_type, {2, 256, 32}};
migraphx::shape bs{migraphx::shape::float_type, {2, 32, 128}};
migraphx::shape as{DType, {2, 256, 32}};
migraphx::shape bs{DType, {2, 32, 128}};
auto a = mm->add_parameter("input", as);
auto lit =
mm->add_literal(migraphx::generate_literal({migraphx::shape::float_type, {1, 1, 32}}));
auto lit = mm->add_literal(migraphx::generate_literal({DType, {1, 1, 32}}));
auto litb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", as.lens()}}), lit);
auto mul = mm->add_instruction(migraphx::make_op("mul"), a, litb);
......@@ -47,3 +47,7 @@ struct test_mul_dot_a : verify_program<test_mul_dot_a>
return p;
}
};
template struct test_mul_dot_a<migraphx::shape::float_type>;
template struct test_mul_dot_a<migraphx::shape::half_type>;
template struct test_mul_dot_a<migraphx::shape::fp8e4m3fnuz_type>;
......@@ -27,17 +27,18 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_mul_dot_b : verify_program<test_mul_dot_b>
template <migraphx::shape::type_t DType>
struct test_mul_dot_b : verify_program<test_mul_dot_b<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape as{migraphx::shape::float_type, {2, 256, 32}};
migraphx::shape bs{migraphx::shape::float_type, {2, 32, 128}};
migraphx::shape as{DType, {2, 256, 32}};
migraphx::shape bs{DType, {2, 32, 128}};
auto b = mm->add_parameter("input", bs);
auto lit =
mm->add_literal(migraphx::generate_literal({migraphx::shape::float_type, {1, 32, 1}}));
auto lit = mm->add_literal(migraphx::generate_literal({DType, {1, 32, 1}}));
auto litb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", bs.lens()}}), lit);
auto mul = mm->add_instruction(migraphx::make_op("mul"), b, litb);
......@@ -47,3 +48,7 @@ struct test_mul_dot_b : verify_program<test_mul_dot_b>
return p;
}
};
template struct test_mul_dot_b<migraphx::shape::float_type>;
template struct test_mul_dot_b<migraphx::shape::half_type>;
template struct test_mul_dot_b<migraphx::shape::fp8e4m3fnuz_type>;
......@@ -27,7 +27,8 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_multinomial : verify_program<test_multinomial>
template <migraphx::shape::type_t DType>
struct test_multinomial : verify_program<test_multinomial<DType>>
{
migraphx::program create_program() const
{
......@@ -40,10 +41,10 @@ struct test_multinomial : verify_program<test_multinomial>
std::uniform_real_distribution<> dis(0.0, 1.0);
std::vector<float> rand_samples(batch_size * sample_size);
std::generate(rand_samples.begin(), rand_samples.end(), [&]() { return dis(gen); });
migraphx::shape rs{migraphx::shape::float_type, {batch_size, sample_size}};
migraphx::shape rs{DType, {batch_size, sample_size}};
auto rs_lit = mm->add_literal(migraphx::literal{rs, rand_samples});
migraphx::shape s{migraphx::shape::float_type, {batch_size, 5}};
migraphx::shape s{DType, {batch_size, 5}};
auto input = mm->add_parameter("input", s);
auto maxes = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {1}}}), input);
......@@ -58,3 +59,8 @@ struct test_multinomial : verify_program<test_multinomial>
return p;
}
};
template struct test_multinomial<migraphx::shape::float_type>;
template struct test_multinomial<migraphx::shape::half_type>;
// This fails, need to figure out why
// template struct test_multinomial<migraphx::shape::fp8e4m3fnuz_type>;
......@@ -27,13 +27,14 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_nonzero : verify_program<test_nonzero>
template <migraphx::shape::type_t DType>
struct test_nonzero : verify_program<test_nonzero<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape s{DType, {2, 3, 4, 5}};
auto x = mm->add_parameter("data", s);
auto r = mm->add_instruction(migraphx::make_op("nonzero"), x);
mm->add_return({r});
......@@ -41,3 +42,7 @@ struct test_nonzero : verify_program<test_nonzero>
return p;
}
};
template struct test_nonzero<migraphx::shape::float_type>;
template struct test_nonzero<migraphx::shape::half_type>;
template struct test_nonzero<migraphx::shape::fp8e4m3fnuz_type>;
......@@ -23,16 +23,18 @@
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_prefix_scan_sum_2d_small : verify_program<test_prefix_scan_sum_2d_small>
template <migraphx::shape::type_t DType>
struct test_prefix_scan_sum_2d_small : verify_program<test_prefix_scan_sum_2d_small<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {1}};
migraphx::shape s{DType, {1}};
auto x = mm->add_parameter("x", s);
auto xb =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 3}}}), x);
......@@ -42,16 +44,25 @@ struct test_prefix_scan_sum_2d_small : verify_program<test_prefix_scan_sum_2d_sm
}
};
struct test_prefix_scan_sum_2d_large : verify_program<test_prefix_scan_sum_2d_large>
template struct test_prefix_scan_sum_2d_small<migraphx::shape::float_type>;
template struct test_prefix_scan_sum_2d_small<migraphx::shape::half_type>;
template struct test_prefix_scan_sum_2d_small<migraphx::shape::fp8e4m3fnuz_type>;
template <migraphx::shape::type_t DType>
struct test_prefix_scan_sum_2d_large : verify_program<test_prefix_scan_sum_2d_large<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3, 1000}};
migraphx::shape s{DType, {3, 1000}};
auto x = mm->add_parameter("x", s);
mm->add_instruction(
migraphx::make_op("prefix_scan_sum", {{"axis", 1}, {"exclusive", false}}), x);
return p;
}
};
template struct test_prefix_scan_sum_2d_large<migraphx::shape::float_type>;
template struct test_prefix_scan_sum_2d_large<migraphx::shape::half_type>;
template struct test_prefix_scan_sum_2d_large<migraphx::shape::fp8e4m3fnuz_type>;
......@@ -46,11 +46,13 @@ struct test_reduce_op_small : verify_program<test_reduce_op_small<Op, Axis, T>>
};
template struct test_reduce_op_small<migraphx::op::reduce_sum, 1, migraphx::shape::float_type>;
template struct test_reduce_op_small<migraphx::op::reduce_sum, 3, migraphx::shape::float_type>;
template struct test_reduce_op_small<migraphx::op::reduce_sum, 2, migraphx::shape::int32_type>;
template struct test_reduce_op_small<migraphx::op::reduce_mean, 2, migraphx::shape::int32_type>;
template struct test_reduce_op_small<migraphx::op::reduce_max, 2, migraphx::shape::int32_type>;
template struct test_reduce_op_small<migraphx::op::reduce_min, 2, migraphx::shape::int32_type>;
template struct test_reduce_op_small<migraphx::op::reduce_sum, 3, migraphx::shape::half_type>;
template struct test_reduce_op_small<migraphx::op::reduce_sum, 2, migraphx::shape::half_type>;
template struct test_reduce_op_small<migraphx::op::reduce_mean, 2, migraphx::shape::half_type>;
template struct test_reduce_op_small<migraphx::op::reduce_max, 2, migraphx::shape::half_type>;
......@@ -60,6 +62,9 @@ template struct test_reduce_op_small<migraphx::op::reduce_prod, -2, migraphx::sh
template struct test_reduce_op_small<migraphx::op::reduce_sum,
2,
migraphx::shape::fp8e4m3fnuz_type>;
template struct test_reduce_op_small<migraphx::op::reduce_sum,
3,
migraphx::shape::fp8e4m3fnuz_type>;
template struct test_reduce_op_small<migraphx::op::reduce_mean,
2,
migraphx::shape::fp8e4m3fnuz_type>;
......
......@@ -26,16 +26,21 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_reverse : verify_program<test_reverse>
template <migraphx::shape::type_t DType>
struct test_reverse : verify_program<test_reverse<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {4, 16}};
migraphx::shape s{DType, {4, 16}};
auto a0 = mm->add_parameter("data", s);
std::vector<int64_t> axis = {0};
mm->add_instruction(migraphx::make_op("reverse", {{"axes", axis}}), a0);
return p;
}
};
template struct test_reverse<migraphx::shape::float_type>;
template struct test_reverse<migraphx::shape::half_type>;
template struct test_reverse<migraphx::shape::fp8e4m3fnuz_type>;
......@@ -31,7 +31,8 @@
#include <migraphx/op/common.hpp>
struct test_rnn_sql_1 : verify_program<test_rnn_sql_1>
template <migraphx::shape::type_t DType>
struct test_rnn_sql_1 : verify_program<test_rnn_sql_1<DType>>
{
migraphx::program create_program() const
{
......@@ -44,12 +45,12 @@ struct test_rnn_sql_1 : verify_program<test_rnn_sql_1>
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
migraphx::shape in_shape{DType, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{DType, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{DType, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{DType, {num_dirct, 2 * hidden_size}};
migraphx::shape s_shape{migraphx::shape::int32_type, {batch_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape ih_shape{DType, {num_dirct, batch_size, hidden_size}};
auto seq = mm->add_parameter("seq", in_shape);
auto w = mm->add_parameter("w", w_shape);
......@@ -81,3 +82,7 @@ struct test_rnn_sql_1 : verify_program<test_rnn_sql_1>
}
std::string section() const { return "rnn"; }
};
template struct test_rnn_sql_1<migraphx::shape::float_type>;
template struct test_rnn_sql_1<migraphx::shape::half_type>;
template struct test_rnn_sql_1<migraphx::shape::fp8e4m3fnuz_type>;
......@@ -27,16 +27,17 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_scatter0 : verify_program<test_scatter0>
template <migraphx::shape::type_t DType>
struct test_scatter0 : verify_program<test_scatter0<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sd{migraphx::shape::float_type, {3, 3}};
migraphx::shape sd{DType, {3, 3}};
migraphx::shape si{migraphx::shape::int32_type, {2, 3}};
std::vector<int> vi = {1, 0, 2, 0, 2, 1};
migraphx::shape su{migraphx::shape::float_type, {2, 3}};
migraphx::shape su{DType, {2, 3}};
auto pd = mm->add_parameter("data", sd);
auto li = mm->add_literal(migraphx::literal{si, vi});
......@@ -47,3 +48,7 @@ struct test_scatter0 : verify_program<test_scatter0>
return p;
}
};
template struct test_scatter0<migraphx::shape::float_type>;
template struct test_scatter0<migraphx::shape::half_type>;
template struct test_scatter0<migraphx::shape::fp8e4m3fnuz_type>;
......@@ -27,13 +27,14 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_topk_0 : verify_program<test_topk_0>
template <migraphx::shape::type_t DType>
struct test_topk_0 : verify_program<test_topk_0<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3, 5}};
migraphx::shape s{DType, {3, 5}};
auto data = mm->add_parameter("data", s);
auto r = mm->add_instruction(
migraphx::make_op("topk", {{"axis", 1}, {"k", 4}, {"largest", 1}}), data);
......@@ -43,3 +44,7 @@ struct test_topk_0 : verify_program<test_topk_0>
return p;
}
};
template struct test_topk_0<migraphx::shape::float_type>;
template struct test_topk_0<migraphx::shape::half_type>;
template struct test_topk_0<migraphx::shape::fp8e4m3fnuz_type>;
......@@ -27,15 +27,17 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/apply_alpha_beta.hpp>
struct test_unbatched_gemm_1 : verify_program<test_unbatched_gemm_1>
template <migraphx::shape::type_t DType>
struct test_unbatched_gemm_1 : verify_program<test_unbatched_gemm_1<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 32, 64}};
migraphx::shape m2_shape{migraphx::shape::float_type, {64, 64}};
migraphx::shape m3_shape{migraphx::shape::float_type, {2, 32, 192}};
migraphx::shape m1_shape{DType, {2, 32, 64}};
migraphx::shape m2_shape{DType, {64, 64}};
migraphx::shape m3_shape{DType, {2, 32, 192}};
auto l1 = mm->add_parameter("1", m1_shape);
auto l2 = mm->add_literal(migraphx::generate_literal(m2_shape));
l2 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 64, 64}}}),
......@@ -56,3 +58,7 @@ struct test_unbatched_gemm_1 : verify_program<test_unbatched_gemm_1>
return p;
}
};
template struct test_unbatched_gemm_1<migraphx::shape::float_type>;
template struct test_unbatched_gemm_1<migraphx::shape::half_type>;
template struct test_unbatched_gemm_1<migraphx::shape::fp8e4m3fnuz_type>;
......@@ -27,14 +27,16 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/apply_alpha_beta.hpp>
struct test_unbatched_gemm_2 : verify_program<test_unbatched_gemm_2>
template <migraphx::shape::type_t DType>
struct test_unbatched_gemm_2 : verify_program<test_unbatched_gemm_2<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::float_type, {4, 32, 64}};
migraphx::shape m2_shape{migraphx::shape::float_type, {64, 64}};
migraphx::shape m1_shape{DType, {4, 32, 64}};
migraphx::shape m2_shape{DType, {64, 64}};
auto l1 = mm->add_parameter("1", m1_shape);
auto l2 = mm->add_literal(migraphx::generate_literal(m2_shape));
l2 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {4, 64, 64}}}),
......@@ -44,3 +46,7 @@ struct test_unbatched_gemm_2 : verify_program<test_unbatched_gemm_2>
return p;
}
};
template struct test_unbatched_gemm_2<migraphx::shape::float_type>;
template struct test_unbatched_gemm_2<migraphx::shape::half_type>;
template struct test_unbatched_gemm_2<migraphx::shape::fp8e4m3fnuz_type>;
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment