Commit e9a3bdaa authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

manual merge

parents d7dfe995 a09dc502
......@@ -26,16 +26,21 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_reverse : verify_program<test_reverse>
template <migraphx::shape::type_t DType>
struct test_reverse : verify_program<test_reverse<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {4, 16}};
migraphx::shape s{DType, {4, 16}};
auto a0 = mm->add_parameter("data", s);
std::vector<int64_t> axis = {0};
mm->add_instruction(migraphx::make_op("reverse", {{"axes", axis}}), a0);
return p;
}
};
template struct test_reverse<migraphx::shape::float_type>;
template struct test_reverse<migraphx::shape::half_type>;
template struct test_reverse<migraphx::shape::fp8e4m3fnuz_type>;
......@@ -31,7 +31,8 @@
#include <migraphx/op/common.hpp>
struct test_rnn_sql_1 : verify_program<test_rnn_sql_1>
template <migraphx::shape::type_t DType>
struct test_rnn_sql_1 : verify_program<test_rnn_sql_1<DType>>
{
migraphx::program create_program() const
{
......@@ -44,12 +45,12 @@ struct test_rnn_sql_1 : verify_program<test_rnn_sql_1>
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
migraphx::shape in_shape{DType, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{DType, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{DType, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{DType, {num_dirct, 2 * hidden_size}};
migraphx::shape s_shape{migraphx::shape::int32_type, {batch_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape ih_shape{DType, {num_dirct, batch_size, hidden_size}};
auto seq = mm->add_parameter("seq", in_shape);
auto w = mm->add_parameter("w", w_shape);
......@@ -81,3 +82,7 @@ struct test_rnn_sql_1 : verify_program<test_rnn_sql_1>
}
std::string section() const { return "rnn"; }
};
template struct test_rnn_sql_1<migraphx::shape::float_type>;
template struct test_rnn_sql_1<migraphx::shape::half_type>;
template struct test_rnn_sql_1<migraphx::shape::fp8e4m3fnuz_type>;
......@@ -27,16 +27,17 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_scatter0 : verify_program<test_scatter0>
template <migraphx::shape::type_t DType>
struct test_scatter0 : verify_program<test_scatter0<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sd{migraphx::shape::float_type, {3, 3}};
migraphx::shape sd{DType, {3, 3}};
migraphx::shape si{migraphx::shape::int32_type, {2, 3}};
std::vector<int> vi = {1, 0, 2, 0, 2, 1};
migraphx::shape su{migraphx::shape::float_type, {2, 3}};
migraphx::shape su{DType, {2, 3}};
auto pd = mm->add_parameter("data", sd);
auto li = mm->add_literal(migraphx::literal{si, vi});
......@@ -47,3 +48,7 @@ struct test_scatter0 : verify_program<test_scatter0>
return p;
}
};
template struct test_scatter0<migraphx::shape::float_type>;
template struct test_scatter0<migraphx::shape::half_type>;
template struct test_scatter0<migraphx::shape::fp8e4m3fnuz_type>;
......@@ -27,13 +27,14 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_topk_0 : verify_program<test_topk_0>
template <migraphx::shape::type_t DType>
struct test_topk_0 : verify_program<test_topk_0<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3, 5}};
migraphx::shape s{DType, {3, 5}};
auto data = mm->add_parameter("data", s);
auto r = mm->add_instruction(
migraphx::make_op("topk", {{"axis", 1}, {"k", 4}, {"largest", 1}}), data);
......@@ -43,3 +44,7 @@ struct test_topk_0 : verify_program<test_topk_0>
return p;
}
};
template struct test_topk_0<migraphx::shape::float_type>;
template struct test_topk_0<migraphx::shape::half_type>;
template struct test_topk_0<migraphx::shape::fp8e4m3fnuz_type>;
......@@ -27,15 +27,17 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/apply_alpha_beta.hpp>
struct test_unbatched_gemm_1 : verify_program<test_unbatched_gemm_1>
template <migraphx::shape::type_t DType>
struct test_unbatched_gemm_1 : verify_program<test_unbatched_gemm_1<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 32, 64}};
migraphx::shape m2_shape{migraphx::shape::float_type, {64, 64}};
migraphx::shape m3_shape{migraphx::shape::float_type, {2, 32, 192}};
migraphx::shape m1_shape{DType, {2, 32, 64}};
migraphx::shape m2_shape{DType, {64, 64}};
migraphx::shape m3_shape{DType, {2, 32, 192}};
auto l1 = mm->add_parameter("1", m1_shape);
auto l2 = mm->add_literal(migraphx::generate_literal(m2_shape));
l2 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 64, 64}}}),
......@@ -56,3 +58,7 @@ struct test_unbatched_gemm_1 : verify_program<test_unbatched_gemm_1>
return p;
}
};
template struct test_unbatched_gemm_1<migraphx::shape::float_type>;
template struct test_unbatched_gemm_1<migraphx::shape::half_type>;
template struct test_unbatched_gemm_1<migraphx::shape::fp8e4m3fnuz_type>;
......@@ -27,14 +27,16 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/apply_alpha_beta.hpp>
struct test_unbatched_gemm_2 : verify_program<test_unbatched_gemm_2>
template <migraphx::shape::type_t DType>
struct test_unbatched_gemm_2 : verify_program<test_unbatched_gemm_2<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::float_type, {4, 32, 64}};
migraphx::shape m2_shape{migraphx::shape::float_type, {64, 64}};
migraphx::shape m1_shape{DType, {4, 32, 64}};
migraphx::shape m2_shape{DType, {64, 64}};
auto l1 = mm->add_parameter("1", m1_shape);
auto l2 = mm->add_literal(migraphx::generate_literal(m2_shape));
l2 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {4, 64, 64}}}),
......@@ -44,3 +46,7 @@ struct test_unbatched_gemm_2 : verify_program<test_unbatched_gemm_2>
return p;
}
};
template struct test_unbatched_gemm_2<migraphx::shape::float_type>;
template struct test_unbatched_gemm_2<migraphx::shape::half_type>;
template struct test_unbatched_gemm_2<migraphx::shape::fp8e4m3fnuz_type>;
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment