Commit 45da3115 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

backup code changes related to softmax

parent ea656c84
...@@ -59,8 +59,8 @@ __global__ void add_gelu_kernel(void* a, void* b, int n_dim, void* r, int n) ...@@ -59,8 +59,8 @@ __global__ void add_gelu_kernel(void* a, void* b, int n_dim, void* r, int n)
__half2 sqrt2 = __float2half2_rn(M_SQRT1_2); __half2 sqrt2 = __float2half2_rn(M_SQRT1_2);
auto x = __hmul2(sum, sqrt2); auto x = __hmul2(sum, sqrt2);
auto f2 = __half22float2(x); auto f2 = __half22float2(x);
f2.x = ::erf(f2.x); f2.x = ::erff(f2.x);
f2.y = ::erf(f2.y); f2.y = ::erff(f2.y);
auto h2 = __floats2half2_rn(f2.x, f2.y); auto h2 = __floats2half2_rn(f2.x, f2.y);
auto one = __float2half2_rn(1.0f); auto one = __float2half2_rn(1.0f);
......
...@@ -24,6 +24,8 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_NARY); ...@@ -24,6 +24,8 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_NARY);
if(enabled(MIGRAPHX_TRACE_NARY{})) \ if(enabled(MIGRAPHX_TRACE_NARY{})) \
std::cout << "nary device function: " << __PRETTY_FUNCTION__ << std::endl; std::cout << "nary device function: " << __PRETTY_FUNCTION__ << std::endl;
static index_int group_num_global = (1 << 20);
template <class... Ts> template <class... Ts>
constexpr auto pack(Ts... xs) constexpr auto pack(Ts... xs)
{ {
...@@ -87,7 +89,7 @@ void nary_broadcast_vec_impl( ...@@ -87,7 +89,7 @@ void nary_broadcast_vec_impl(
const index_int vec_size = 4; const index_int vec_size = 4;
const index_int nlocal = 1024; const index_int nlocal = 1024;
const index_int nglobal = 256 * nlocal; const index_int nglobal = group_num_global * nlocal;
const index_int bdim_vec_len = bdim_len / vec_size; const index_int bdim_vec_len = bdim_len / vec_size;
hip_vec_visit_all<vec_size>(result, barg, args...)( hip_vec_visit_all<vec_size>(result, barg, args...)(
[&](auto output, auto binput, auto... inputs) { [&](auto output, auto binput, auto... inputs) {
...@@ -134,7 +136,7 @@ void nary_broadcast_impl(hipStream_t stream, F f, argument result, argument barg ...@@ -134,7 +136,7 @@ void nary_broadcast_impl(hipStream_t stream, F f, argument result, argument barg
auto broadcast_idx = create_broadcast_index(bdim_len, bdim_stride); auto broadcast_idx = create_broadcast_index(bdim_len, bdim_stride);
const index_int nlocal = 1024; const index_int nlocal = 1024;
const index_int nglobal = 256 * nlocal; const index_int nglobal = group_num_global * nlocal;
index_int nelements = result.get_shape().elements(); index_int nelements = result.get_shape().elements();
hip_visit_all(result, barg, args...)([&](auto output, auto binput, auto... inputs) { hip_visit_all(result, barg, args...)([&](auto output, auto binput, auto... inputs) {
using type = typename decltype(output)::value_type; using type = typename decltype(output)::value_type;
...@@ -178,7 +180,7 @@ void nary_double_broadcast_vec_impl( ...@@ -178,7 +180,7 @@ void nary_double_broadcast_vec_impl(
const index_int vec_size = 4; const index_int vec_size = 4;
const index_int nlocal = 1024; const index_int nlocal = 1024;
const index_int nglobal = 256 * nlocal; const index_int nglobal = group_num_global * nlocal;
const index_int bdim_vec_len = bdim_len / vec_size; const index_int bdim_vec_len = bdim_len / vec_size;
hip_vec_visit_all<vec_size>(result, barg1, barg2, args...)( hip_vec_visit_all<vec_size>(result, barg1, barg2, args...)(
[&](auto output, auto binput1, auto binput2, auto... inputs) { [&](auto output, auto binput1, auto binput2, auto... inputs) {
...@@ -234,7 +236,7 @@ void nary_double_broadcast_impl( ...@@ -234,7 +236,7 @@ void nary_double_broadcast_impl(
auto broadcast_idx = create_broadcast_index(bdim_len, bdim_stride); auto broadcast_idx = create_broadcast_index(bdim_len, bdim_stride);
const index_int nlocal = 1024; const index_int nlocal = 1024;
const index_int nglobal = 256 * nlocal; const index_int nglobal = group_num_global * nlocal;
index_int nelements = result.get_shape().elements(); index_int nelements = result.get_shape().elements();
hip_visit_all(result, barg1, barg2, args...)( hip_visit_all(result, barg1, barg2, args...)(
[&](auto output, auto binput1, auto binput2, auto... inputs) { [&](auto output, auto binput1, auto binput2, auto... inputs) {
......
...@@ -6,12 +6,144 @@ ...@@ -6,12 +6,144 @@
#include <migraphx/gpu/device/tensor.hpp> #include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp> #include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp> #include <migraphx/gpu/device/types.hpp>
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace device { namespace device {
// Binary functor: lane-wise addition of two __half2 values.
// Used as the reduction op for the softmax denominator (sum of exps).
struct half2_sum
{
    MIGRAPHX_DEVICE_CONSTEXPR auto operator()(__half2 x, __half2 y) const
    {
        return __hadd2(x, y);
    }
};
// Lane-wise maximum of two __half2 values, computed in float.
// half -> float conversion is exact, so rounding back with
// __floats2half2_rn introduces no error for half-representable inputs.
inline __device__ __half2 hmax2(__half2 x, __half2 y)
{
    const float2 a  = __half22float2(x);
    const float2 b  = __half22float2(y);
    const float lo  = a.x > b.x ? a.x : b.x;
    const float hi  = a.y > b.y ? a.y : b.y;
    return __floats2half2_rn(lo, hi);
}
// Binary functor: lane-wise maximum of two __half2 values (delegates to
// hmax2 above). Used as the reduction op for the softmax max-subtraction.
struct half2_max
{
    MIGRAPHX_DEVICE_CONSTEXPR auto operator()(__half2 x, __half2 y) const
    {
        return hmax2(x, y);
    }
};
// Tree-reduce `buffer` (shared memory holding `batch_item_num` __half2
// elements) with the binary functor `op`, then fold the low and high half
// lanes of the result together, so every thread returns the full scalar
// reduction broadcast into both lanes of a __half2.
//
// NOTE(review): the interleaved indexing gives each thread at most one pair
// per stride, so a complete reduction appears to require
// batch_item_num <= 2 * block_size with a power-of-two block_size — confirm
// against the launch configuration.
template <class Op>
__device__ __half2 block_reduce(
    __half2* buffer, index_int batch_item_num, index_int tid, index_int block_size, Op op)
{
    // Callers fill `buffer` with per-thread stores immediately before calling;
    // a barrier is required before any thread reads elements written by
    // other threads (missing barrier here was a data race).
    __syncthreads();
    for(index_int s = 1; s < block_size; s *= 2)
    {
        const index_int index = 2 * s * tid;
        if(index + s < batch_item_num)
        {
            buffer[index] = op(buffer[index], buffer[index + s]);
        }
        // Barrier sits outside the divergent `if` so all threads reach it.
        __syncthreads();
    }
    // buffer[0] now holds the per-lane reduction; combine its two half lanes
    // so both lanes of the returned __half2 carry the full result.
    auto lows2  = __low2half2(buffer[0]);
    auto highs2 = __high2half2(buffer[0]);
    return op(lows2, highs2);
}
// Softmax over the last (contiguous) axis for half data, two halves per
// thread lane via __half2. One block handles one batch row of
// `batch_item_num` halves (batch_item_num must be even; it is halved below
// to count __half2 elements). Dynamic shared memory must hold
// 2 * batch_item_num * sizeof(__half) bytes: one copy used as reduction
// scratch plus one stable copy of the row.
// Assumes blockDim.x == block_size at the launch site.
__global__ void
softmax_kernel(void* data_in, index_int batch_item_num, index_int block_size, void* data_out)
{
    __half2* input  = reinterpret_cast<__half2*>(data_in);
    __half2* output = reinterpret_cast<__half2*>(data_out);
    batch_item_num /= 2; // element count in units of __half2
    extern MIGRAPHX_DEVICE_SHARED __half2 buffer2[];
    __half2* in_data_reduce = buffer2;
    __half2* in_data        = buffer2 + batch_item_num;
    // Offset of this block's row in the flat input/output.
    const index_int start = blockIdx.x * batch_item_num;
    // BUGFIX: shared-memory indices must be block-local, so iterate from
    // threadIdx.x, not the global thread id — the global id indexed past the
    // row (or skipped the loop entirely) for every block but the first.
    for(index_int i = threadIdx.x; i < batch_item_num; i += block_size)
    {
        auto d            = input[i + start];
        in_data[i]        = d;
        in_data_reduce[i] = d;
    }
    __syncthreads(); // publish the row before the cross-thread reduction
    auto batch_max =
        block_reduce(in_data_reduce, batch_item_num, threadIdx.x, block_size, half2_max{});
    __syncthreads(); // let every thread read buffer[0] before overwriting it
    for(index_int i = threadIdx.x; i < batch_item_num; i += block_size)
    {
        in_data[i]        = h2exp(__hsub2(in_data[i], batch_max));
        in_data_reduce[i] = in_data[i];
    }
    __syncthreads();
    auto batch_sum =
        block_reduce(in_data_reduce, batch_item_num, threadIdx.x, block_size, half2_sum{});
    for(index_int i = threadIdx.x; i < batch_item_num; i += block_size)
    {
        output[i + start] = __h2div(in_data[i], batch_sum);
    }
}
// Tree-reduce `data` (shared memory holding `batch_item_num` __half
// elements) with the binary functor `op`; every thread returns the value
// left in data[0].
//
// NOTE(review): each thread handles at most one pair per stride, so a
// complete reduction appears to require batch_item_num <= 2 * block_size
// with a power-of-two block_size — confirm against the launch configuration.
template <class Op>
__device__ __half
block_reduce2(__half* data, index_int batch_item_num, index_int tid, index_int block_size, Op op)
{
    // Callers store into `data` immediately before calling; barrier before
    // reading elements written by other threads (missing barrier here was a
    // data race).
    __syncthreads();
    for(index_int s = 1; s < block_size; s *= 2)
    {
        const index_int index = 2 * s * tid;
        if(index + s < batch_item_num)
        {
            data[index] = op(data[index], data[index + s]);
        }
        // Outside the divergent branch so every thread reaches the barrier.
        __syncthreads();
    }
    return data[0];
}
// Softmax over the last (contiguous) axis for half data, one __half per
// lane. One block handles one row of `batch_item_num` elements. Dynamic
// shared memory must hold 2 * batch_item_num * sizeof(__half) bytes: one
// copy used as reduction scratch plus one stable copy of the row.
// Uses the project-wide `max`/`sum` reduction functors.
// Assumes blockDim.x == block_size at the launch site.
__global__ void
softmax_kernel2(void* data_in, index_int batch_item_num, index_int block_size, void* data_out)
{
    __half* input  = reinterpret_cast<__half*>(data_in);
    __half* output = reinterpret_cast<__half*>(data_out);
    extern MIGRAPHX_DEVICE_SHARED __half buffer[];
    __half* in_data_reduce = buffer;
    __half* in_data        = buffer + batch_item_num;
    // Offset of this block's row in the flat input/output.
    const index_int start = blockIdx.x * batch_item_num;
    for(index_int i = threadIdx.x; i < batch_item_num; i += block_size)
    {
        auto d            = input[i + start];
        in_data[i]        = d;
        in_data_reduce[i] = d;
    }
    __syncthreads(); // publish the row before the cross-thread reduction
    auto batch_max =
        block_reduce2(in_data_reduce, batch_item_num, threadIdx.x, block_size, max{});
    __syncthreads(); // let every thread read data[0] before overwriting it
    for(index_int i = threadIdx.x; i < batch_item_num; i += block_size)
    {
        // ::expf keeps the math in single precision; ::exp on a float
        // promotes to double, which is needlessly slow on the GPU.
        in_data[i] =
            __float2half(::expf(__half2float(in_data[i]) - __half2float(batch_max)));
        in_data_reduce[i] = in_data[i];
    }
    __syncthreads();
    auto batch_sum =
        block_reduce2(in_data_reduce, batch_item_num, threadIdx.x, block_size, sum{});
    for(index_int i = threadIdx.x; i < batch_item_num; i += block_size)
    {
        output[i + start] = __float2half(__half2float(in_data[i]) / __half2float(batch_sum));
    }
}
void softmax(hipStream_t stream, const argument& result, const argument& arg, int64_t axis) void softmax(hipStream_t stream, const argument& result, const argument& arg, int64_t axis)
{ {
auto batch_lens = result.get_shape().lens(); auto batch_lens = result.get_shape().lens();
...@@ -27,25 +159,35 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in ...@@ -27,25 +159,35 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in
if(axis == batch_lens.size() - 1) if(axis == batch_lens.size() - 1)
{ {
gs_launch(stream, batch_shape.elements() * block_size, block_size)( auto in_type = result.get_shape().type();
[=](auto i, auto idx) __device__ { if (in_type == shape::half_type and batch_item_num <= 2048)
auto start_loc = i / block_size * batch_item_num; {
auto batch_max = block_reduce<max_block_size>( int block_num = batch_shape.elements();
idx, max{}, init, batch_item_num, [&](auto j) __device__ { int shared_size = batch_item_num * 2 * result.get_shape().type_size();
return input[start_loc + j]; softmax_kernel2<<<block_num, block_size, shared_size, stream>>>(arg.data(), batch_item_num, block_size, result.data());
}); }
else
{
gs_launch(stream, batch_shape.elements() * block_size, block_size)(
[=](auto i, auto idx) __device__ {
auto start_loc = i / block_size * batch_item_num;
auto batch_max = block_reduce<max_block_size>(
idx, max{}, init, batch_item_num, [&](auto j) __device__ {
return input[start_loc + j];
});
auto batch_sum = block_reduce<max_block_size>( auto batch_sum = block_reduce<max_block_size>(
idx, sum{}, 0, batch_item_num, [&](auto j) __device__ { idx, sum{}, 0, batch_item_num, [&](auto j) __device__ {
auto val = input[start_loc + j] - batch_max; auto val = input[start_loc + j] - batch_max;
return ::exp(to_hip_type(val)); return ::exp(to_hip_type(val));
}); });
idx.local_stride(batch_item_num, [&](auto j) __device__ { idx.local_stride(batch_item_num, [&](auto j) __device__ {
auto val = input[start_loc + j] - batch_max; auto val = input[start_loc + j] - batch_max;
output[start_loc + j] = ::exp(to_hip_type(val)) / batch_sum; output[start_loc + j] = ::exp(to_hip_type(val)) / batch_sum;
});
}); });
}); }
} }
else else
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment