"docs/_removed/examples.rst" did not exist on "abd164c2598d4cf19a081b4e5c1070de7bea8386"
Commit 1da02b0f authored by Shucai Xiao's avatar Shucai Xiao
Browse files

backup softmax changes

parent ae59a3b1
......@@ -39,12 +39,11 @@ __device__ __half2
block_reduce(__half2* buffer, index_int batch_item_num, index_int tid, index_int block_size, Op op)
{
__syncthreads();
for(index_int s = 1; s < block_size; s *= 2)
for(index_int s = block_size; s > 0; s >>= 1)
{
const index_int index = 2 * s * tid;
if(index + s < batch_item_num)
if(tid < s and tid + s < batch_item_num)
{
buffer[index] = op(buffer[index], buffer[index + s]);
buffer[tid] = op(buffer[tid], buffer[tid + s]);
}
__syncthreads();
}
......@@ -61,12 +60,11 @@ softmax_kernel(void* data_in, index_int batch_item_num, index_int block_size, vo
__half2* input = reinterpret_cast<__half2*>(data_in);
__half2* output = reinterpret_cast<__half2*>(data_out);
batch_item_num /= 2;
int tid = blockDim.x * blockIdx.x + threadIdx.x;
extern MIGRAPHX_DEVICE_SHARED __half2 buffer2[];
__half2* in_data_reduce = buffer2;
__half2* in_data = buffer2 + batch_item_num;
int start = tid / block_size * batch_item_num;
int start = blockIdx.x * batch_item_num;
for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{
auto d = input[i + start];
......@@ -98,12 +96,11 @@ __device__ __half
block_reduce2(__half* data, index_int batch_item_num, index_int tid, index_int block_size, Op op)
{
__syncthreads();
for(index_int s = 1; s < block_size; s *= 2)
for(index_int s = block_size / 2; s > 0; s >>= 1)
{
const index_int index = 2 * s * tid;
if(index + s < batch_item_num)
if(tid < s and tid + s < batch_item_num)
{
data[index] = op(data[index], data[index + s]);
data[tid] = op(data[tid], data[tid + s]);
}
__syncthreads();
}
......@@ -158,13 +155,13 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in
if(axis == batch_lens.size() - 1)
{
auto in_type = result.get_shape().type();
if(in_type == shape::half_type and batch_item_num <= 2048)
if(in_type == shape::half_type and batch_item_num <= 1024)
{
int block_num = batch_shape.elements();
int shared_size = batch_item_num * 2 * result.get_shape().type_size();
softmax_kernel<<<block_num, block_size, shared_size, stream>>>(
arg.data(), batch_item_num, block_size, result.data());
auto half2_block_size = block_size / 4;
softmax_kernel<<<block_num, half2_block_size, shared_size, stream>>>(
arg.data(), batch_item_num, half2_block_size, result.data());
}
else
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment