Commit 9f06859b authored by Shucai Xiao's avatar Shucai Xiao
Browse files

final version of softmax that works

parent bc9eac75
......@@ -38,6 +38,7 @@ template <class Op>
__device__ __half2
block_reduce(__half2* buffer, index_int batch_item_num, index_int tid, index_int block_size, Op op)
{
__syncthreads();
for(index_int s = 1; s < block_size; s *= 2)
{
const index_int index = 2 * s * tid;
......@@ -96,6 +97,7 @@ template <class Op>
__device__ __half
block_reduce2(__half* data, index_int batch_item_num, index_int tid, index_int block_size, Op op)
{
__syncthreads();
for(index_int s = 1; s < block_size; s *= 2)
{
const index_int index = 2 * s * tid;
......@@ -124,21 +126,16 @@ softmax_kernel2(void* data_in, index_int batch_item_num, index_int block_size, v
auto d = input[i + start];
in_data[i] = d;
in_data_reduce[i] = d;
// printf("blockIdx = %d, ori_val = %f\n", start, __half2float(d));
}
auto batch_max = block_reduce2(in_data_reduce, batch_item_num, threadIdx.x, block_size, max{});
// printf("blockIdx = %d, batch_max = %f\n", start, __half2float(batch_max));
for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{
in_data[i] = __float2half(::exp(__half2float(in_data[i]) - __half2float(batch_max)));
in_data_reduce[i] = in_data[i];
// printf("blockIdx = %d, exp_val = %f\n", start, __half2float(in_data[i]));
}
auto batch_sum = block_reduce2(in_data_reduce, batch_item_num, threadIdx.x, block_size, sum{});
// printf("blockIdx = %d, batch_sum = %f\n", start, __half2float(batch_sum));
for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{
output[i + start] = __float2half(__half2float(in_data[i]) / __half2float(batch_sum));
......@@ -153,7 +150,7 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in
migraphx::shape batch_shape{result.get_shape().type(), batch_lens};
hip_visit_all(result, arg, batch_shape)([&](auto output, auto input, auto batch) {
const index_int max_block_size = 128;
const index_int max_block_size = 1024;
const index_int block_size = compute_block_size(batch_item_num, max_block_size);
using type = device_type<std::remove_cv_t<typename decltype(input)::value_type>>;
type init = lowest();
......@@ -165,6 +162,7 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in
{
int block_num = batch_shape.elements();
int shared_size = batch_item_num * 2 * result.get_shape().type_size();
softmax_kernel<<<block_num, block_size, shared_size, stream>>>(
arg.data(), batch_item_num, block_size, result.data());
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment