Commit f74b77bc authored by carlushuang's avatar carlushuang
Browse files

Merge remote-tracking branch 'origin/develop' into stream-k-initial-impl

parents b5be51ed 0d911822
......@@ -13,7 +13,6 @@ using I = ck::Number<N>;
using F16 = ck::half_t;
using F32 = float;
using I8 = int8_t;
template <typename Tuple>
class TestSoftmax : public ck::TestSoftmax<Tuple>
......@@ -24,8 +23,7 @@ class TestSoftmax : public ck::TestSoftmax<Tuple>
using KernelTypes = ::testing::Types<
// InDataType, AccDataType, OutDataType, Rank
std::tuple< F16, F32, F16, I<4>>,
std::tuple< F32, F32, F32, I<4>>,
std::tuple< I8, F32, I8, I<4>>
std::tuple< F32, F32, F32, I<4>>
>;
// clang-format on
......
......@@ -61,8 +61,92 @@ class TestSoftmax : public ::testing::Test
int init_method = 1; // integer value initialization
bool log = false;
std::vector<ck::index_t> strides; // intenionally empty, to get packed layout.
bool pass = ck::profiler::profile_softmax_impl<InDataType, AccDataType, OutDataType, Rank>(
verify_, init_method, log, bench_, in_length, strides, reduce_dims, alpha, beta);
bool pass = false;
if constexpr(Rank == 3)
{
if(reduce_dims.size() == 1)
pass = ck::profiler::
profile_softmax_impl<InDataType, AccDataType, OutDataType, Rank, 1>(verify_,
init_method,
log,
bench_,
in_length,
strides,
reduce_dims,
alpha,
beta);
else if(reduce_dims.size() == 2)
pass = ck::profiler::
profile_softmax_impl<InDataType, AccDataType, OutDataType, Rank, 2>(verify_,
init_method,
log,
bench_,
in_length,
strides,
reduce_dims,
alpha,
beta);
else if(reduce_dims.size() == 3)
pass = ck::profiler::
profile_softmax_impl<InDataType, AccDataType, OutDataType, Rank, 3>(verify_,
init_method,
log,
bench_,
in_length,
strides,
reduce_dims,
alpha,
beta);
}
else if constexpr(Rank == 4)
{
if(reduce_dims.size() == 1)
pass = ck::profiler::
profile_softmax_impl<InDataType, AccDataType, OutDataType, Rank, 1>(verify_,
init_method,
log,
bench_,
in_length,
strides,
reduce_dims,
alpha,
beta);
else if(reduce_dims.size() == 2)
pass = ck::profiler::
profile_softmax_impl<InDataType, AccDataType, OutDataType, Rank, 2>(verify_,
init_method,
log,
bench_,
in_length,
strides,
reduce_dims,
alpha,
beta);
else if(reduce_dims.size() == 3)
pass = ck::profiler::
profile_softmax_impl<InDataType, AccDataType, OutDataType, Rank, 3>(verify_,
init_method,
log,
bench_,
in_length,
strides,
reduce_dims,
alpha,
beta);
else if(reduce_dims.size() == 4)
pass = ck::profiler::
profile_softmax_impl<InDataType, AccDataType, OutDataType, Rank, 4>(verify_,
init_method,
log,
bench_,
in_length,
strides,
reduce_dims,
alpha,
beta);
};
EXPECT_TRUE(pass);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment