Commit 37f63907 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 45da3115
...@@ -16,10 +16,7 @@ namespace device { ...@@ -16,10 +16,7 @@ namespace device {
// Binary functor: packed addition of two __half2 values (both lanes at once).
// Passed as the reduction op to the half2 block-reduce helpers.
struct half2_sum
{
    MIGRAPHX_DEVICE_CONSTEXPR auto operator()(__half2 x, __half2 y) const
    {
        // __hadd2 adds the low and high half-precision lanes independently.
        return __hadd2(x, y);
    }
};
inline __device__ __half2 hmax2(__half2 x, __half2 y) inline __device__ __half2 hmax2(__half2 x, __half2 y)
...@@ -33,16 +30,13 @@ inline __device__ __half2 hmax2(__half2 x, __half2 y) ...@@ -33,16 +30,13 @@ inline __device__ __half2 hmax2(__half2 x, __half2 y)
// Binary functor: packed elementwise maximum of two __half2 values.
// Passed as the reduction op to the half2 block-reduce helpers.
struct half2_max
{
    MIGRAPHX_DEVICE_CONSTEXPR auto operator()(__half2 x, __half2 y) const
    {
        // hmax2 (defined above) selects the per-lane maximum.
        return hmax2(x, y);
    }
};
// in_data is in shared memory // in_data is in shared memory
template<class Op> template <class Op>
__device__ __half2 block_reduce(__half2* buffer, index_int batch_item_num, index_int tid, index_int block_size, Op op) __device__ __half2
block_reduce(__half2* buffer, index_int batch_item_num, index_int tid, index_int block_size, Op op)
{ {
for(index_int s = 1; s < block_size; s *= 2) for(index_int s = 1; s < block_size; s *= 2)
{ {
...@@ -60,7 +54,8 @@ __device__ __half2 block_reduce(__half2* buffer, index_int batch_item_num, index ...@@ -60,7 +54,8 @@ __device__ __half2 block_reduce(__half2* buffer, index_int batch_item_num, index
return op(lows2, highs2); return op(lows2, highs2);
} }
__global__ void softmax_kernel(void *data_in, index_int batch_item_num, index_int block_size, void* data_out) __global__ void
softmax_kernel(void* data_in, index_int batch_item_num, index_int block_size, void* data_out)
{ {
__half2* input = reinterpret_cast<__half2*>(data_in); __half2* input = reinterpret_cast<__half2*>(data_in);
__half2* output = reinterpret_cast<__half2*>(data_out); __half2* output = reinterpret_cast<__half2*>(data_out);
...@@ -71,32 +66,35 @@ __global__ void softmax_kernel(void *data_in, index_int batch_item_num, index_in ...@@ -71,32 +66,35 @@ __global__ void softmax_kernel(void *data_in, index_int batch_item_num, index_in
__half2* in_data_reduce = buffer2; __half2* in_data_reduce = buffer2;
__half2* in_data = buffer2 + batch_item_num; __half2* in_data = buffer2 + batch_item_num;
int start = tid / block_size * batch_item_num; int start = tid / block_size * batch_item_num;
for (int i = tid; i < batch_item_num; i += block_size) for(int i = tid; i < batch_item_num; i += block_size)
{ {
auto d = input[i + start]; auto d = input[i + start];
in_data[i] = d; in_data[i] = d;
in_data_reduce[i] = d; in_data_reduce[i] = d;
} }
auto batch_max = block_reduce(in_data_reduce, batch_item_num, threadIdx.x, block_size, half2_max{}); auto batch_max =
block_reduce(in_data_reduce, batch_item_num, threadIdx.x, block_size, half2_max{});
for (int i = tid; i < batch_item_num; i += block_size) for(int i = tid; i < batch_item_num; i += block_size)
{ {
in_data[i] = h2exp(__hsub2(in_data[i], batch_max)); in_data[i] = h2exp(__hsub2(in_data[i], batch_max));
in_data_reduce[i] = in_data[i]; in_data_reduce[i] = in_data[i];
} }
auto batch_sum = block_reduce(in_data_reduce, batch_item_num, threadIdx.x, block_size, half2_sum{}); auto batch_sum =
block_reduce(in_data_reduce, batch_item_num, threadIdx.x, block_size, half2_sum{});
for (int i = tid; i < batch_item_num; i += block_size) for(int i = tid; i < batch_item_num; i += block_size)
{ {
output[i + start] = __h2div(in_data[i], batch_sum); output[i + start] = __h2div(in_data[i], batch_sum);
} }
} }
// in_data is in shared memory // in_data is in shared memory
template<class Op> template <class Op>
__device__ __half block_reduce2(__half* data, index_int batch_item_num, index_int tid, index_int block_size, Op op) __device__ __half
block_reduce2(__half* data, index_int batch_item_num, index_int tid, index_int block_size, Op op)
{ {
for(index_int s = 1; s < block_size; s *= 2) for(index_int s = 1; s < block_size; s *= 2)
{ {
...@@ -111,7 +109,8 @@ __device__ __half block_reduce2(__half* data, index_int batch_item_num, index_in ...@@ -111,7 +109,8 @@ __device__ __half block_reduce2(__half* data, index_int batch_item_num, index_in
return data[0]; return data[0];
} }
__global__ void softmax_kernel2(void *data_in, index_int batch_item_num, index_int block_size, void* data_out) __global__ void
softmax_kernel2(void* data_in, index_int batch_item_num, index_int block_size, void* data_out)
{ {
__half* input = reinterpret_cast<__half*>(data_in); __half* input = reinterpret_cast<__half*>(data_in);
__half* output = reinterpret_cast<__half*>(data_out); __half* output = reinterpret_cast<__half*>(data_out);
...@@ -121,7 +120,7 @@ __global__ void softmax_kernel2(void *data_in, index_int batch_item_num, index_i ...@@ -121,7 +120,7 @@ __global__ void softmax_kernel2(void *data_in, index_int batch_item_num, index_i
__half* in_data_reduce = buffer; __half* in_data_reduce = buffer;
__half* in_data = buffer + batch_item_num; __half* in_data = buffer + batch_item_num;
int start = tid / block_size * batch_item_num; int start = tid / block_size * batch_item_num;
for (int i = threadIdx.x; i < batch_item_num; i += block_size) for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{ {
auto d = input[i + start]; auto d = input[i + start];
in_data[i] = d; in_data[i] = d;
...@@ -129,7 +128,7 @@ __global__ void softmax_kernel2(void *data_in, index_int batch_item_num, index_i ...@@ -129,7 +128,7 @@ __global__ void softmax_kernel2(void *data_in, index_int batch_item_num, index_i
} }
auto batch_max = block_reduce2(in_data_reduce, batch_item_num, threadIdx.x, block_size, max{}); auto batch_max = block_reduce2(in_data_reduce, batch_item_num, threadIdx.x, block_size, max{});
for (int i = threadIdx.x; i < batch_item_num; i += block_size) for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{ {
in_data[i] = __float2half(::exp(__half2float(in_data[i]) - __half2float(batch_max))); in_data[i] = __float2half(::exp(__half2float(in_data[i]) - __half2float(batch_max)));
...@@ -138,9 +137,9 @@ __global__ void softmax_kernel2(void *data_in, index_int batch_item_num, index_i ...@@ -138,9 +137,9 @@ __global__ void softmax_kernel2(void *data_in, index_int batch_item_num, index_i
auto batch_sum = block_reduce2(in_data_reduce, batch_item_num, threadIdx.x, block_size, sum{}); auto batch_sum = block_reduce2(in_data_reduce, batch_item_num, threadIdx.x, block_size, sum{});
for (int i = threadIdx.x; i < batch_item_num; i += block_size) for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{ {
output[i + start] = __float2half(__half2float(in_data[i])/__half2float(batch_sum)); output[i + start] = __float2half(__half2float(in_data[i]) / __half2float(batch_sum));
} }
} }
...@@ -160,11 +159,12 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in ...@@ -160,11 +159,12 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in
if(axis == batch_lens.size() - 1) if(axis == batch_lens.size() - 1)
{ {
auto in_type = result.get_shape().type(); auto in_type = result.get_shape().type();
if (in_type == shape::half_type and batch_item_num <= 2048) if(in_type == shape::half_type and batch_item_num <= 2048)
{ {
int block_num = batch_shape.elements(); int block_num = batch_shape.elements();
int shared_size = batch_item_num * 2 * result.get_shape().type_size(); int shared_size = batch_item_num * 2 * result.get_shape().type_size();
softmax_kernel2<<<block_num, block_size, shared_size, stream>>>(arg.data(), batch_item_num, block_size, result.data()); softmax_kernel2<<<block_num, block_size, shared_size, stream>>>(
arg.data(), batch_item_num, block_size, result.data());
} }
else else
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment