Commit ee956e8e authored by carlushuang's avatar carlushuang
Browse files

add one more test case

parent 9b24c143
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#include "ck_tile/ops/reduce.hpp" #include "ck_tile/ops/reduce.hpp"
#ifndef TEST_TILE_REDUCE_VERBOSE #ifndef TEST_TILE_REDUCE_VERBOSE
#define TEST_TILE_REDUCE_VERBOSE 0 #define TEST_TILE_REDUCE_VERBOSE 1
#endif #endif
#define HIP_CALL(call) \ #define HIP_CALL(call) \
...@@ -30,13 +30,13 @@ ...@@ -30,13 +30,13 @@
#define BLOCK_SIZE 256 #define BLOCK_SIZE 256
template <int Rows, int Cols, typename DataType> template <int Rows, int Cols, typename DataType, int BytesPerIssue = 16>
__global__ void reduce_row(DataType* p_src, DataType* p_dst) __global__ void reduce_row(DataType* p_src, DataType* p_dst)
{ {
using namespace ck_tile; using namespace ck_tile;
// some constexpr vars // some constexpr vars
constexpr index_t vec = 16 / sizeof(DataType); constexpr index_t vec = BytesPerIssue / sizeof(DataType);
static_assert(Cols % vec == 0); static_assert(Cols % vec == 0);
constexpr index_t col_lanes = Cols / vec; constexpr index_t col_lanes = Cols / vec;
constexpr index_t warp_size = ck_tile::get_warp_size(); constexpr index_t warp_size = ck_tile::get_warp_size();
...@@ -109,8 +109,8 @@ __global__ void reduce_row(DataType* p_src, DataType* p_dst) ...@@ -109,8 +109,8 @@ __global__ void reduce_row(DataType* p_src, DataType* p_dst)
} }
} }
template <int Rows, int Cols, typename DataType> template <int Rows, int Cols, typename DataType, int BytesPerIssue = 16>
int test_tile_reduce() bool test_tile_reduce()
{ {
std::srand(std::time(nullptr)); std::srand(std::time(nullptr));
DataType* src = reinterpret_cast<DataType*>(malloc(Rows * Cols * sizeof(DataType))); DataType* src = reinterpret_cast<DataType*>(malloc(Rows * Cols * sizeof(DataType)));
...@@ -133,8 +133,8 @@ int test_tile_reduce() ...@@ -133,8 +133,8 @@ int test_tile_reduce()
constexpr int bdim = BLOCK_SIZE; constexpr int bdim = BLOCK_SIZE;
int gdim = 1; int gdim = 1;
reduce_row<Rows, Cols, DataType><<<gdim, bdim>>>(reinterpret_cast<DataType*>(dev_src), reduce_row<Rows, Cols, DataType, BytesPerIssue><<<gdim, bdim>>>(
reinterpret_cast<DataType*>(dev_dst)); reinterpret_cast<DataType*>(dev_src), reinterpret_cast<DataType*>(dev_dst));
HIP_CALL(hipMemcpy(dst, dev_dst, Rows * sizeof(DataType), hipMemcpyDeviceToHost)); HIP_CALL(hipMemcpy(dst, dev_dst, Rows * sizeof(DataType), hipMemcpyDeviceToHost));
...@@ -168,11 +168,14 @@ int test_tile_reduce() ...@@ -168,11 +168,14 @@ int test_tile_reduce()
free(src); free(src);
free(dst); free(dst);
return err_cnt == 0 ? 0 : -1; return err_cnt == 0 ? true : false;
} }
int main() int main()
{ {
int r = test_tile_reduce<32, 64, float>(); bool r = true;
return r; r &= test_tile_reduce<32, 64, float>();
r &= test_tile_reduce<32, 16, float, 4>();
return r ? 0 : -1;
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment