Commit bc9eac75 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

version that softmax half2 works

parent 23a18b2b
...@@ -50,7 +50,7 @@ void mul(hipStream_t stream, const argument& result, const argument& arg1, const ...@@ -50,7 +50,7 @@ void mul(hipStream_t stream, const argument& result, const argument& arg1, const
} }
else else
{ {
nary(stream, result, arg1, arg2)([](auto x, auto y) __device__ { return x + y; }); nary(stream, result, arg1, arg2)([](auto x, auto y) __device__ { return x * y; });
} }
} }
......
...@@ -114,28 +114,30 @@ softmax_kernel2(void* data_in, index_int batch_item_num, index_int block_size, v ...@@ -114,28 +114,30 @@ softmax_kernel2(void* data_in, index_int batch_item_num, index_int block_size, v
{ {
__half* input = reinterpret_cast<__half*>(data_in); __half* input = reinterpret_cast<__half*>(data_in);
__half* output = reinterpret_cast<__half*>(data_out); __half* output = reinterpret_cast<__half*>(data_out);
int tid = blockDim.x * blockIdx.x + threadIdx.x;
extern MIGRAPHX_DEVICE_SHARED __half buffer[]; extern MIGRAPHX_DEVICE_SHARED __half buffer[];
__half* in_data_reduce = buffer; __half* in_data_reduce = buffer;
__half* in_data = buffer + batch_item_num; __half* in_data = buffer + batch_item_num;
int start = tid / block_size * batch_item_num; int start = blockIdx.x * batch_item_num;
for(int i = threadIdx.x; i < batch_item_num; i += block_size) for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{ {
auto d = input[i + start]; auto d = input[i + start];
in_data[i] = d; in_data[i] = d;
in_data_reduce[i] = d; in_data_reduce[i] = d;
// printf("blockIdx = %d, ori_val = %f\n", start, __half2float(d));
} }
auto batch_max = block_reduce2(in_data_reduce, batch_item_num, threadIdx.x, block_size, max{}); auto batch_max = block_reduce2(in_data_reduce, batch_item_num, threadIdx.x, block_size, max{});
// printf("blockIdx = %d, batch_max = %f\n", start, __half2float(batch_max));
for(int i = threadIdx.x; i < batch_item_num; i += block_size) for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{ {
in_data[i] = __float2half(::exp(__half2float(in_data[i]) - __half2float(batch_max))); in_data[i] = __float2half(::exp(__half2float(in_data[i]) - __half2float(batch_max)));
in_data_reduce[i] = in_data[i]; in_data_reduce[i] = in_data[i];
// printf("blockIdx = %d, exp_val = %f\n", start, __half2float(in_data[i]));
} }
auto batch_sum = block_reduce2(in_data_reduce, batch_item_num, threadIdx.x, block_size, sum{}); auto batch_sum = block_reduce2(in_data_reduce, batch_item_num, threadIdx.x, block_size, sum{});
// printf("blockIdx = %d, batch_sum = %f\n", start, __half2float(batch_sum));
for(int i = threadIdx.x; i < batch_item_num; i += block_size) for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment