Commit 23a18b2b authored by Shucai Xiao

fix bugs in softmax half2 implementation

parent 08818705
@@ -66,7 +66,7 @@ softmax_kernel(void* data_in, index_int batch_item_num, index_int block_size, vo
     __half2* in_data_reduce = buffer2;
     __half2* in_data = buffer2 + batch_item_num;
     int start = tid / block_size * batch_item_num;
-    for(int i = tid; i < batch_item_num; i += block_size)
+    for(int i = threadIdx.x; i < batch_item_num; i += block_size)
     {
         auto d = input[i + start];
         in_data[i] = d;
@@ -76,7 +76,7 @@ softmax_kernel(void* data_in, index_int batch_item_num, index_int block_size, vo
     auto batch_max =
         block_reduce(in_data_reduce, batch_item_num, threadIdx.x, block_size, half2_max{});
-    for(int i = tid; i < batch_item_num; i += block_size)
+    for(int i = threadIdx.x; i < batch_item_num; i += block_size)
     {
         in_data[i] = h2exp(__hsub2(in_data[i], batch_max));
         in_data_reduce[i] = in_data[i];
@@ -85,7 +85,7 @@ softmax_kernel(void* data_in, index_int batch_item_num, index_int block_size, vo
     auto batch_sum =
         block_reduce(in_data_reduce, batch_item_num, threadIdx.x, block_size, half2_sum{});
-    for(int i = tid; i < batch_item_num; i += block_size)
+    for(int i = threadIdx.x; i < batch_item_num; i += block_size)
     {
         output[i + start] = __h2div(in_data[i], batch_sum);
     }
@@ -163,7 +163,7 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in
     {
         int block_num = batch_shape.elements();
         int shared_size = batch_item_num * 2 * result.get_shape().type_size();
-        softmax_kernel2<<<block_num, block_size, shared_size, stream>>>(
+        softmax_kernel<<<block_num, block_size, shared_size, stream>>>(
             arg.data(), batch_item_num, block_size, result.data());
     }
     else
...
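
The three kernel hunks make the same correction: each block handles one batch row (start is derived from the block's position via tid / block_size), so the per-row loops must stride with threadIdx.x, the thread's index within its block, rather than the global tid; with tid, threads in every block after the first begin past their row and leave elements unprocessed. The last hunk launches softmax_kernel, the __half2 kernel shown above, instead of softmax_kernel2, with shared_size = batch_item_num * 2 * type_size covering the two __half2 buffers in_data and in_data_reduce.

For reference, a minimal sketch of the block-per-row pattern the fix restores. This is not code from this commit: plain float in place of __half2, a power-of-two tree reduction in place of block_reduce, and the name row_softmax are all illustrative assumptions.

#include <hip/hip_runtime.h>
#include <cfloat>

// One block per row; each thread strides over its row with threadIdx.x.
// Assumes a power-of-two block size and blockDim.x floats of shared memory.
__global__ void row_softmax(const float* in, float* out, int row_len)
{
    extern __shared__ float buf[]; // one float per thread, reused for both reductions
    const int row_start = blockIdx.x * row_len;

    // 1) row maximum
    float m = -FLT_MAX;
    for(int i = threadIdx.x; i < row_len; i += blockDim.x)
        m = fmaxf(m, in[row_start + i]);
    buf[threadIdx.x] = m;
    __syncthreads();
    for(int s = blockDim.x / 2; s > 0; s /= 2)
    {
        if(threadIdx.x < s)
            buf[threadIdx.x] = fmaxf(buf[threadIdx.x], buf[threadIdx.x + s]);
        __syncthreads();
    }
    const float row_max = buf[0];
    __syncthreads(); // every thread must read buf[0] before it is overwritten below

    // 2) exponentials and row sum
    float sum = 0.f;
    for(int i = threadIdx.x; i < row_len; i += blockDim.x)
    {
        const float e = expf(in[row_start + i] - row_max);
        out[row_start + i] = e;
        sum += e;
    }
    buf[threadIdx.x] = sum;
    __syncthreads();
    for(int s = blockDim.x / 2; s > 0; s /= 2)
    {
        if(threadIdx.x < s)
            buf[threadIdx.x] += buf[threadIdx.x + s];
        __syncthreads();
    }
    const float row_sum = buf[0];

    // 3) normalize
    for(int i = threadIdx.x; i < row_len; i += blockDim.x)
        out[row_start + i] /= row_sum;
}

A possible launch, one block per row with a power-of-two block size: row_softmax<<<num_rows, 256, 256 * sizeof(float), stream>>>(d_in, d_out, row_len);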