Commit 1da02b0f authored by Shucai Xiao

backup softmax changes

parent ae59a3b1
@@ -39,12 +39,11 @@ __device__ __half2
 block_reduce(__half2* buffer, index_int batch_item_num, index_int tid, index_int block_size, Op op)
 {
     __syncthreads();
-    for(index_int s = 1; s < block_size; s *= 2)
+    for(index_int s = block_size; s > 0; s >>= 1)
     {
-        const index_int index = 2 * s * tid;
-        if(index + s < batch_item_num)
+        if(tid < s and tid + s < batch_item_num)
         {
-            buffer[index] = op(buffer[index], buffer[index + s]);
+            buffer[tid] = op(buffer[tid], buffer[tid + s]);
         }
         __syncthreads();
     }
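
Note: both reduce helpers in this commit move from interleaved addressing (index = 2 * s * tid, stride doubling each pass) to sequential addressing (the stride halves each pass and thread tid folds buffer[tid + s] into buffer[tid]), which keeps the active threads contiguous and is the standard shared-memory tree-reduction shape. A standalone HIP sketch of that loop shape, using plain float and a hard-coded sum instead of __half2 and a generic Op (names and sizes here are illustrative, not MIGraphX code):

// reduce_example.cpp: hedged sketch, built with hipcc; only the loop shape
// mirrors the new block_reduce.
#include <hip/hip_runtime.h>
#include <cstdio>

__global__ void reduce_sum(const float* in, float* out, int n)
{
    extern __shared__ float buffer[];
    // Stage the row into shared memory; n may exceed blockDim.x (up to 2x),
    // so a thread can own more than one element.
    for(int i = threadIdx.x; i < n; i += blockDim.x)
        buffer[i] = in[i];
    __syncthreads();
    // Sequential-addressing tree reduction: the stride halves each pass and
    // thread tid folds buffer[tid + s] into buffer[tid].
    for(int s = blockDim.x; s > 0; s >>= 1)
    {
        if(threadIdx.x < s and threadIdx.x + s < n)
            buffer[threadIdx.x] += buffer[threadIdx.x + s];
        __syncthreads();
    }
    if(threadIdx.x == 0)
        *out = buffer[0];
}

int main()
{
    const int n       = 6;
    float host_in[n]  = {1, 2, 3, 4, 5, 6};
    float host_out    = 0;
    float *in, *out;
    hipMalloc((void**)&in, n * sizeof(float));
    hipMalloc((void**)&out, sizeof(float));
    hipMemcpy(in, host_in, n * sizeof(float), hipMemcpyHostToDevice);
    reduce_sum<<<1, 4, n * sizeof(float), 0>>>(in, out, n); // 4 threads, 6 elements
    hipMemcpy(&host_out, out, sizeof(float), hipMemcpyDeviceToHost);
    printf("sum = %g (expected 21)\n", host_out);
    hipFree(in);
    hipFree(out);
    return 0;
}

Because the stride starts at blockDim.x rather than blockDim.x / 2, the first pass can fold entries beyond the thread count, so one block reduces up to twice as many buffer entries as it has threads; block_reduce2 below starts at block_size / 2 and covers at most as many entries as threads.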
@@ -61,12 +60,11 @@ softmax_kernel(void* data_in, index_int batch_item_num, index_int block_size, vo
     __half2* input = reinterpret_cast<__half2*>(data_in);
     __half2* output = reinterpret_cast<__half2*>(data_out);
     batch_item_num /= 2;
-    int tid = blockDim.x * blockIdx.x + threadIdx.x;
     extern MIGRAPHX_DEVICE_SHARED __half2 buffer2[];
     __half2* in_data_reduce = buffer2;
     __half2* in_data = buffer2 + batch_item_num;
-    int start = tid / block_size * batch_item_num;
+    int start = blockIdx.x * batch_item_num;
     for(int i = threadIdx.x; i < batch_item_num; i += block_size)
     {
         auto d = input[i + start];
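
Note on the dropped global thread id: the kernel launches one block per batch row with blockDim.x equal to the block_size argument, so

    tid / block_size = (block_size * blockIdx.x + threadIdx.x) / block_size = blockIdx.x    (since threadIdx.x < block_size)

and the new start = blockIdx.x * batch_item_num is the same row offset as before, computed without the division or the now-unused tid.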
@@ -98,12 +96,11 @@ __device__ __half
 block_reduce2(__half* data, index_int batch_item_num, index_int tid, index_int block_size, Op op)
 {
     __syncthreads();
-    for(index_int s = 1; s < block_size; s *= 2)
+    for(index_int s = block_size / 2; s > 0; s >>= 1)
     {
-        const index_int index = 2 * s * tid;
-        if(index + s < batch_item_num)
+        if(tid < s and tid + s < batch_item_num)
         {
-            data[index] = op(data[index], data[index + s]);
+            data[tid] = op(data[tid], data[tid + s]);
         }
         __syncthreads();
     }
@@ -158,13 +155,13 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in
     if(axis == batch_lens.size() - 1)
     {
         auto in_type = result.get_shape().type();
-        if(in_type == shape::half_type and batch_item_num <= 2048)
+        if(in_type == shape::half_type and batch_item_num <= 1024)
         {
             int block_num = batch_shape.elements();
             int shared_size = batch_item_num * 2 * result.get_shape().type_size();
-            softmax_kernel<<<block_num, block_size, shared_size, stream>>>(
-                arg.data(), batch_item_num, block_size, result.data());
+            auto half2_block_size = block_size / 4;
+            softmax_kernel<<<block_num, half2_block_size, shared_size, stream>>>(
+                arg.data(), batch_item_num, half2_block_size, result.data());
         }
         else
         {
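
For reference, the launch arithmetic of the new half-precision path, written out as a hedged host-side sketch (launch_softmax_half and the index_int alias are illustrative stand-ins, not MIGraphX API; it assumes the softmax_kernel from this file is in scope):

#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <cstdint>

using index_int = std::uint64_t; // assumption: MIGraphX's index_int is an unsigned 64-bit type
__global__ void softmax_kernel(void* data_in, index_int batch_item_num,
                               index_int block_size, void* data_out);

void launch_softmax_half(hipStream_t stream, void* in, void* out,
                         int batch_item_num, // elements per softmax row, guarded to <= 1024
                         int batch_rows,     // batch_shape.elements(): one block per row
                         int block_size)     // the generic launch width, e.g. 1024
{
    int block_num = batch_rows;
    // Shared memory holds two staging areas (the reduce buffer and the cached
    // inputs) of batch_item_num __half values each:
    int shared_size = batch_item_num * 2 * sizeof(__half);
    // The kernel reads __half2 pairs, and its reduction's first pass folds a
    // second element per thread, so a quarter of the generic width suffices.
    int half2_block_size = block_size / 4;
    softmax_kernel<<<block_num, half2_block_size, shared_size, stream>>>(
        in, batch_item_num, half2_block_size, out);
}

Cutting the threads to block_size / 4 is consistent with the kernel above: each thread loads __half2 pairs (halving the element count) and the reduction starting at s == block_size covers twice the thread count (halving it again), so a 1024-element row needs only 256 threads when block_size is 1024.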