Commit 25d7fde8 authored by gaoqiong's avatar gaoqiong
Browse files

lite

parent 8439d29f
...@@ -10,7 +10,7 @@ namespace rocm { ...@@ -10,7 +10,7 @@ namespace rocm {
#ifdef USE_ROCM #ifdef USE_ROCM
constexpr int num_elements_per_thread = 2; constexpr int num_elements_per_thread = 2;
constexpr int num_threads_per_block = 512; constexpr int num_threads_per_block = 256;
#else #else
constexpr int num_elements_per_thread = GridDim::maxElementsPerThread; constexpr int num_elements_per_thread = GridDim::maxElementsPerThread;
constexpr int num_threads_per_block = GridDim::maxThreadsPerBlock; constexpr int num_threads_per_block = GridDim::maxThreadsPerBlock;
......
...@@ -133,6 +133,7 @@ bool CanDoTranspose4DParallelizeMultipleElementsPerThreadInInnermostDim(const hi ...@@ -133,6 +133,7 @@ bool CanDoTranspose4DParallelizeMultipleElementsPerThreadInInnermostDim(const hi
const gsl::span<const int64_t>& input_dims, const gsl::span<const int64_t>& input_dims,
const gsl::span<const size_t>& permutations, const gsl::span<const size_t>& permutations,
dim3& grid_size, dim3& block_size) { dim3& grid_size, dim3& block_size) {
//printf("maxThreadsPerBlock:%d \n",prop.maxThreadsPerBlock);
if (rank == 4 && if (rank == 4 &&
// the permutations is not on the last dimension. // the permutations is not on the last dimension.
permutations[3] == 3) { permutations[3] == 3) {
...@@ -142,7 +143,9 @@ bool CanDoTranspose4DParallelizeMultipleElementsPerThreadInInnermostDim(const hi ...@@ -142,7 +143,9 @@ bool CanDoTranspose4DParallelizeMultipleElementsPerThreadInInnermostDim(const hi
// dims[2]: block.y + grid.x // dims[2]: block.y + grid.x
// dims[1]: grid.y // dims[1]: grid.y
// dims[0]: grid.z // dims[0]: grid.z
if (input_dims[3] / num_elements_per_thread <= prop.maxThreadsPerBlock &&
const int maxThreadsPerBlock = prop.maxThreadsPerBlock;
if (input_dims[3] / num_elements_per_thread <= maxThreadsPerBlock &&
(input_dims[3] % num_elements_per_thread) == 0 && (input_dims[3] % num_elements_per_thread) == 0 &&
input_dims[1] <= prop.maxGridSize[1] && input_dims[1] <= prop.maxGridSize[1] &&
input_dims[0] <= prop.maxGridSize[2]) { input_dims[0] <= prop.maxGridSize[2]) {
...@@ -150,7 +153,7 @@ bool CanDoTranspose4DParallelizeMultipleElementsPerThreadInInnermostDim(const hi ...@@ -150,7 +153,7 @@ bool CanDoTranspose4DParallelizeMultipleElementsPerThreadInInnermostDim(const hi
// 1. block_size_x * block_size_y <= prop.maxThreadsPerBlock // 1. block_size_x * block_size_y <= prop.maxThreadsPerBlock
// 2. block_size_y * num_block_ext >= input_dims[2] // 2. block_size_y * num_block_ext >= input_dims[2]
int64_t block_size_x = input_dims[3] / num_elements_per_thread; int64_t block_size_x = input_dims[3] / num_elements_per_thread;
int64_t max_block_size_y = prop.maxThreadsPerBlock / block_size_x; int64_t max_block_size_y = maxThreadsPerBlock / block_size_x;
int64_t block_size_y = min(input_dims[2], max_block_size_y); int64_t block_size_y = min(input_dims[2], max_block_size_y);
int64_t num_block_ext = CeilDiv(input_dims[2], block_size_y); int64_t num_block_ext = CeilDiv(input_dims[2], block_size_y);
...@@ -255,14 +258,15 @@ bool CanDoTranspose4DParallelizeOneElementPerThread(const hipDeviceProp_t& prop, ...@@ -255,14 +258,15 @@ bool CanDoTranspose4DParallelizeOneElementPerThread(const hipDeviceProp_t& prop,
// dims[2]: block.y + grid.x // dims[2]: block.y + grid.x
// dims[1]: grid.y // dims[1]: grid.y
// dims[0]: grid.z // dims[0]: grid.z
if (input_dims[3] <= prop.maxThreadsPerBlock && const int maxThreadsPerBlock = prop.maxThreadsPerBlock;
if (input_dims[3] <= maxThreadsPerBlock &&
input_dims[1] <= prop.maxGridSize[1] && input_dims[1] <= prop.maxGridSize[1] &&
input_dims[0] <= prop.maxGridSize[2]) { input_dims[0] <= prop.maxGridSize[2]) {
// There are 2 constrains when luanching the kernels // There are 2 constrains when luanching the kernels
// 1. block_size_x * block_size_y <= prop.maxThreadsPerBlock // 1. block_size_x * block_size_y <= prop.maxThreadsPerBlock
// 2. block_size_y * num_block_ext >= input_dims[2] // 2. block_size_y * num_block_ext >= input_dims[2]
int64_t block_size_x = input_dims[3]; int64_t block_size_x = input_dims[3];
int64_t max_block_size_y = prop.maxThreadsPerBlock / block_size_x; int64_t max_block_size_y = maxThreadsPerBlock / block_size_x;
int64_t block_size_y = std::min(input_dims[2], max_block_size_y); int64_t block_size_y = std::min(input_dims[2], max_block_size_y);
int64_t num_block_ext = CeilDiv(input_dims[2], block_size_y); int64_t num_block_ext = CeilDiv(input_dims[2], block_size_y);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment