Commit f74b77bc authored by carlushuang's avatar carlushuang
Browse files

Merge remote-tracking branch 'origin/develop' into stream-k-initial-impl

parents b5be51ed 0d911822
...@@ -13,7 +13,6 @@ using I = ck::Number<N>; ...@@ -13,7 +13,6 @@ using I = ck::Number<N>;
using F16 = ck::half_t; using F16 = ck::half_t;
using F32 = float; using F32 = float;
using I8 = int8_t;
template <typename Tuple> template <typename Tuple>
class TestSoftmax : public ck::TestSoftmax<Tuple> class TestSoftmax : public ck::TestSoftmax<Tuple>
...@@ -24,8 +23,7 @@ class TestSoftmax : public ck::TestSoftmax<Tuple> ...@@ -24,8 +23,7 @@ class TestSoftmax : public ck::TestSoftmax<Tuple>
using KernelTypes = ::testing::Types< using KernelTypes = ::testing::Types<
// InDataType, AccDataType, OutDataType, Rank // InDataType, AccDataType, OutDataType, Rank
std::tuple< F16, F32, F16, I<4>>, std::tuple< F16, F32, F16, I<4>>,
std::tuple< F32, F32, F32, I<4>>, std::tuple< F32, F32, F32, I<4>>
std::tuple< I8, F32, I8, I<4>>
>; >;
// clang-format on // clang-format on
......
...@@ -61,8 +61,92 @@ class TestSoftmax : public ::testing::Test ...@@ -61,8 +61,92 @@ class TestSoftmax : public ::testing::Test
int init_method = 1; // integer value initialization int init_method = 1; // integer value initialization
bool log = false; bool log = false;
std::vector<ck::index_t> strides; // intenionally empty, to get packed layout. std::vector<ck::index_t> strides; // intenionally empty, to get packed layout.
bool pass = ck::profiler::profile_softmax_impl<InDataType, AccDataType, OutDataType, Rank>( bool pass = false;
verify_, init_method, log, bench_, in_length, strides, reduce_dims, alpha, beta);
if constexpr(Rank == 3)
{
if(reduce_dims.size() == 1)
pass = ck::profiler::
profile_softmax_impl<InDataType, AccDataType, OutDataType, Rank, 1>(verify_,
init_method,
log,
bench_,
in_length,
strides,
reduce_dims,
alpha,
beta);
else if(reduce_dims.size() == 2)
pass = ck::profiler::
profile_softmax_impl<InDataType, AccDataType, OutDataType, Rank, 2>(verify_,
init_method,
log,
bench_,
in_length,
strides,
reduce_dims,
alpha,
beta);
else if(reduce_dims.size() == 3)
pass = ck::profiler::
profile_softmax_impl<InDataType, AccDataType, OutDataType, Rank, 3>(verify_,
init_method,
log,
bench_,
in_length,
strides,
reduce_dims,
alpha,
beta);
}
else if constexpr(Rank == 4)
{
if(reduce_dims.size() == 1)
pass = ck::profiler::
profile_softmax_impl<InDataType, AccDataType, OutDataType, Rank, 1>(verify_,
init_method,
log,
bench_,
in_length,
strides,
reduce_dims,
alpha,
beta);
else if(reduce_dims.size() == 2)
pass = ck::profiler::
profile_softmax_impl<InDataType, AccDataType, OutDataType, Rank, 2>(verify_,
init_method,
log,
bench_,
in_length,
strides,
reduce_dims,
alpha,
beta);
else if(reduce_dims.size() == 3)
pass = ck::profiler::
profile_softmax_impl<InDataType, AccDataType, OutDataType, Rank, 3>(verify_,
init_method,
log,
bench_,
in_length,
strides,
reduce_dims,
alpha,
beta);
else if(reduce_dims.size() == 4)
pass = ck::profiler::
profile_softmax_impl<InDataType, AccDataType, OutDataType, Rank, 4>(verify_,
init_method,
log,
bench_,
in_length,
strides,
reduce_dims,
alpha,
beta);
};
EXPECT_TRUE(pass); EXPECT_TRUE(pass);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment