Commit 7a7fd9a0 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

resolve merge conflict

parents 522b1da2 51f264a6
...@@ -586,19 +586,19 @@ struct test_softmax2 : verify_program<test_softmax2> ...@@ -586,19 +586,19 @@ struct test_softmax2 : verify_program<test_softmax2>
{ {
migraphx::program p; migraphx::program p;
auto x = auto x =
p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 1028, 1, 25}}); p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1000, 1, 1}});
p.add_instruction(migraphx::op::softmax{}, x); p.add_instruction(migraphx::op::softmax{}, x);
return p; return p;
} }
}; };
template <int Axis> template <int Axis, migraphx::shape::type_t T>
struct test_softmax : verify_program<test_softmax<Axis>> struct test_softmax : verify_program<test_softmax<Axis, T>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {2, 4, 1026, 6}}; migraphx::shape s{T, {512, 4, 1067, 6}};
auto param = p.add_parameter("0", s); auto param = p.add_parameter("0", s);
p.add_instruction(migraphx::op::softmax{Axis}, param); p.add_instruction(migraphx::op::softmax{Axis}, param);
...@@ -606,10 +606,14 @@ struct test_softmax : verify_program<test_softmax<Axis>> ...@@ -606,10 +606,14 @@ struct test_softmax : verify_program<test_softmax<Axis>>
} }
}; };
template struct test_softmax<0>; template struct test_softmax<0, migraphx::shape::float_type>;
template struct test_softmax<1>; template struct test_softmax<2, migraphx::shape::float_type>;
template struct test_softmax<2>; template struct test_softmax<1, migraphx::shape::double_type>;
template struct test_softmax<3>; template struct test_softmax<3, migraphx::shape::double_type>;
template struct test_softmax<0, migraphx::shape::half_type>;
template struct test_softmax<1, migraphx::shape::half_type>;
template struct test_softmax<2, migraphx::shape::half_type>;
template struct test_softmax<3, migraphx::shape::half_type>;
template <class T, int Axis> template <class T, int Axis>
struct test_arg_ops : verify_program<test_arg_ops<T, Axis>> struct test_arg_ops : verify_program<test_arg_ops<T, Axis>>
...@@ -3368,13 +3372,13 @@ struct test_lstm_bidirct_default_actv2 : verify_program<test_lstm_bidirct_defaul ...@@ -3368,13 +3372,13 @@ struct test_lstm_bidirct_default_actv2 : verify_program<test_lstm_bidirct_defaul
} }
}; };
template <int Axis> template <int Axis, migraphx::shape::type_t T>
struct test_logsoftmax : verify_program<test_logsoftmax<Axis>> struct test_logsoftmax : verify_program<test_logsoftmax<Axis, T>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {17, 4, 1025, 6}}; migraphx::shape s{T, {10, 4, 2080, 6}};
auto param = p.add_parameter("0", s); auto param = p.add_parameter("0", s);
p.add_instruction(migraphx::op::logsoftmax{Axis}, param); p.add_instruction(migraphx::op::logsoftmax{Axis}, param);
...@@ -3382,26 +3386,16 @@ struct test_logsoftmax : verify_program<test_logsoftmax<Axis>> ...@@ -3382,26 +3386,16 @@ struct test_logsoftmax : verify_program<test_logsoftmax<Axis>>
} }
}; };
template struct test_logsoftmax<0>; template struct test_logsoftmax<0, migraphx::shape::float_type>;
template struct test_logsoftmax<1>; template struct test_logsoftmax<1, migraphx::shape::float_type>;
template struct test_logsoftmax<2>; template struct test_logsoftmax<2, migraphx::shape::float_type>;
template struct test_logsoftmax<3>; template struct test_logsoftmax<3, migraphx::shape::float_type>;
template struct test_logsoftmax<1, migraphx::shape::double_type>;
template <int Axis> template struct test_logsoftmax<3, migraphx::shape::double_type>;
struct test_logsoftmax_1 : verify_program<test_logsoftmax_1<Axis>> template struct test_logsoftmax<1, migraphx::shape::half_type>;
{ template struct test_logsoftmax<0, migraphx::shape::half_type>;
migraphx::program create_program() const template struct test_logsoftmax<2, migraphx::shape::half_type>;
{ template struct test_logsoftmax<3, migraphx::shape::half_type>;
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3}};
auto param = p.add_parameter("0", s);
p.add_instruction(migraphx::op::logsoftmax{Axis}, param);
return p;
}
};
template struct test_logsoftmax_1<0>;
struct test_fp32_fp16_lall : verify_program<test_fp32_fp16_lall> struct test_fp32_fp16_lall : verify_program<test_fp32_fp16_lall>
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment