Commit bc9eac75 authored by Shucai Xiao
Browse files

version that softmax half2 works

parent 23a18b2b
......@@ -50,7 +50,7 @@ void mul(hipStream_t stream, const argument& result, const argument& arg1, const
}
else
{
nary(stream, result, arg1, arg2)([](auto x, auto y) __device__ { return x + y; });
nary(stream, result, arg1, arg2)([](auto x, auto y) __device__ { return x * y; });
}
}
......
......@@ -114,28 +114,30 @@ softmax_kernel2(void* data_in, index_int batch_item_num, index_int block_size, v
{
__half* input = reinterpret_cast<__half*>(data_in);
__half* output = reinterpret_cast<__half*>(data_out);
int tid = blockDim.x * blockIdx.x + threadIdx.x;
extern MIGRAPHX_DEVICE_SHARED __half buffer[];
__half* in_data_reduce = buffer;
__half* in_data = buffer + batch_item_num;
int start = tid / block_size * batch_item_num;
int start = blockIdx.x * batch_item_num;
for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{
auto d = input[i + start];
in_data[i] = d;
in_data_reduce[i] = d;
// printf("blockIdx = %d, ori_val = %f\n", start, __half2float(d));
}
auto batch_max = block_reduce2(in_data_reduce, batch_item_num, threadIdx.x, block_size, max{});
// printf("blockIdx = %d, batch_max = %f\n", start, __half2float(batch_max));
for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{
in_data[i] = __float2half(::exp(__half2float(in_data[i]) - __half2float(batch_max)));
in_data_reduce[i] = in_data[i];
// printf("blockIdx = %d, exp_val = %f\n", start, __half2float(in_data[i]));
}
auto batch_sum = block_reduce2(in_data_reduce, batch_item_num, threadIdx.x, block_size, sum{});
// printf("blockIdx = %d, batch_sum = %f\n", start, __half2float(batch_sum));
for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment