Commit 25d7fde8 authored by gaoqiong's avatar gaoqiong
Browse files

lite

parent 8439d29f
......@@ -10,7 +10,7 @@ namespace rocm {
#ifdef USE_ROCM
constexpr int num_elements_per_thread = 2;
constexpr int num_threads_per_block = 512;
constexpr int num_threads_per_block = 256;
#else
constexpr int num_elements_per_thread = GridDim::maxElementsPerThread;
constexpr int num_threads_per_block = GridDim::maxThreadsPerBlock;
......
......@@ -133,6 +133,7 @@ bool CanDoTranspose4DParallelizeMultipleElementsPerThreadInInnermostDim(const hi
const gsl::span<const int64_t>& input_dims,
const gsl::span<const size_t>& permutations,
dim3& grid_size, dim3& block_size) {
//printf("maxThreadsPerBlock:%d \n",prop.maxThreadsPerBlock);
if (rank == 4 &&
// the permutation leaves the last dimension in place.
permutations[3] == 3) {
......@@ -142,7 +143,9 @@ bool CanDoTranspose4DParallelizeMultipleElementsPerThreadInInnermostDim(const hi
// dims[2]: block.y + grid.x
// dims[1]: grid.y
// dims[0]: grid.z
if (input_dims[3] / num_elements_per_thread <= prop.maxThreadsPerBlock &&
const int maxThreadsPerBlock = prop.maxThreadsPerBlock;
if (input_dims[3] / num_elements_per_thread <= maxThreadsPerBlock &&
(input_dims[3] % num_elements_per_thread) == 0 &&
input_dims[1] <= prop.maxGridSize[1] &&
input_dims[0] <= prop.maxGridSize[2]) {
......@@ -150,7 +153,7 @@ bool CanDoTranspose4DParallelizeMultipleElementsPerThreadInInnermostDim(const hi
// 1. block_size_x * block_size_y <= prop.maxThreadsPerBlock
// 2. block_size_y * num_block_ext >= input_dims[2]
int64_t block_size_x = input_dims[3] / num_elements_per_thread;
int64_t max_block_size_y = prop.maxThreadsPerBlock / block_size_x;
int64_t max_block_size_y = maxThreadsPerBlock / block_size_x;
int64_t block_size_y = min(input_dims[2], max_block_size_y);
int64_t num_block_ext = CeilDiv(input_dims[2], block_size_y);
......@@ -255,14 +258,15 @@ bool CanDoTranspose4DParallelizeOneElementPerThread(const hipDeviceProp_t& prop,
// dims[2]: block.y + grid.x
// dims[1]: grid.y
// dims[0]: grid.z
if (input_dims[3] <= prop.maxThreadsPerBlock &&
const int maxThreadsPerBlock = prop.maxThreadsPerBlock;
if (input_dims[3] <= maxThreadsPerBlock &&
input_dims[1] <= prop.maxGridSize[1] &&
input_dims[0] <= prop.maxGridSize[2]) {
// There are 2 constraints when launching the kernels
// 1. block_size_x * block_size_y <= prop.maxThreadsPerBlock
// 2. block_size_y * num_block_ext >= input_dims[2]
int64_t block_size_x = input_dims[3];
int64_t max_block_size_y = prop.maxThreadsPerBlock / block_size_x;
int64_t max_block_size_y = maxThreadsPerBlock / block_size_x;
int64_t block_size_y = std::min(input_dims[2], max_block_size_y);
int64_t num_block_ext = CeilDiv(input_dims[2], block_size_y);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment