Commit 8ce6758a authored by Shucai Xiao

further refactor softmax/logsoftmax gpu implementation

parent 01a4fde5
@@ -73,7 +73,7 @@ __host__ __device__ auto gs_invoke(F&& f, std::size_t i, index) -> decltype(f(i)
inline auto gs_launch(hipStream_t stream, std::size_t n, std::size_t local = 1024)
{
std::size_t groups = 1 + n / local;
std::size_t groups = (n + local - 1) / local;
std::size_t nglobal = std::min<std::size_t>(256, groups) * local;
return [=](auto f) {
......
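The gs_launch hunk above replaces the group count with an exact ceiling division. A minimal host-side sketch (not part of the commit) contrasting the two formulas; when n is an exact multiple of local, the old expression launches one group too many:

```cpp
#include <cstddef>
#include <iostream>

int main()
{
    const std::size_t local = 1024;
    for(std::size_t n : {std::size_t{1023}, std::size_t{1024}, std::size_t{1025}})
    {
        std::size_t old_groups = 1 + n / local;          // old: 1, 2, 2
        std::size_t new_groups = (n + local - 1) / local; // new: 1, 1, 2
        std::cout << n << ": " << old_groups << " -> " << new_groups << "\n";
    }
    return 0;
}
```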
@@ -78,7 +78,7 @@ __device__ auto block_reduce(index idx, Op op, T init, std::size_t n, F f)
for(std::size_t s = 1; s < idx.nlocal(); s *= 2)
{
const std::size_t index = 2 * s * idx.local;
if(index < idx.nlocal())
if(index + s < idx.nlocal())
{
buffer[index] = op(buffer[index], buffer[index + s]);
}
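This hunk tightens the bound in the shared-memory tree reduction: the guard now covers the element that is read (buffer[index + s]), not only the one written, so the loop stays in bounds when the block size is not a power of two. A host-side model of the same loop (a sketch, not the device kernel):

```cpp
#include <cstddef>
#include <iostream>
#include <numeric>
#include <vector>

// Host-side model of the in-place tree reduction in block_reduce: after the
// loop, buffer[0] holds op applied over all n elements.
template <class T, class Op>
T tree_reduce(std::vector<T> buffer, Op op)
{
    const std::size_t n = buffer.size();
    for(std::size_t s = 1; s < n; s *= 2)
    {
        for(std::size_t local = 0; local < n; local++)
        {
            const std::size_t index = 2 * s * local;
            // Guard the element that is read, not just the one written,
            // so a non-power-of-two n never indexes past the end.
            if(index + s < n)
                buffer[index] = op(buffer[index], buffer[index + s]);
        }
    }
    return buffer[0];
}

int main()
{
    std::vector<int> v(1000);
    std::iota(v.begin(), v.end(), 1);
    std::cout << tree_reduce(v, [](int a, int b) { return a + b; }) << "\n"; // 500500
}
```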
@@ -167,7 +167,7 @@ __device__ inline void dpp_reduce(float& x, sum)
}
template <std::size_t N, class Op, class T, class F>
__device__ auto block_reduce(index idx, Op op, T init, std::size_t n, F f)
__device__ auto block_reduce(index idx, Op op, T init, std::size_t n, F f)
{
using type = decltype(f(idx.local));
MIGRAPHX_DEVICE_SHARED type buffer[N / 64];
@@ -185,7 +185,7 @@ __device__ auto block_reduce(index idx, Op op, T init, std::size_t n, F f)
type y = 0;
for(std::size_t i = 0; i < idx.nlocal() / 64; i++)
{
y += buffer[i];
y = op(y, buffer[i]);
}
return y;
}
......
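In the wavefront-based block_reduce, the per-wavefront partial results are now combined with the reduction's own operator instead of +=, so non-additive reductions (such as the max used by softmax below) combine correctly. A minimal sketch of the difference, using a max reduction as the example:

```cpp
#include <algorithm>
#include <iostream>

int main()
{
    const int partials[] = {3, 7, 5}; // per-wavefront partial maxima
    auto op = [](int a, int b) { return std::max(a, b); };

    int plus_combined = 0;           // old style: only valid when the reduction is a sum
    int op_combined   = partials[0]; // seed with a partial (or the operator's identity)
    for(int x : partials)
    {
        plus_combined += x;
        op_combined = op(op_combined, x);
    }
    std::cout << plus_combined << " vs " << op_combined << "\n"; // 15 vs 7
}
```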
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/logsoftmax.hpp>
#include <migraphx/gpu/device/reduce_opers.hpp>
#include <migraphx/gpu/device/reduce.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
@@ -14,81 +14,37 @@ namespace device {
void logsoftmax(hipStream_t stream, const argument& result, const argument& arg, int axis)
{
auto lens = result.get_shape().lens();
auto batch_item_num = lens[axis];
auto batch_lens = lens;
batch_lens[axis] = 1;
auto lens = result.get_shape().lens();
auto batch_lens = lens;
std::size_t batch_item_num = lens[axis];
batch_lens[axis] = 1;
migraphx::shape batch_shape{result.get_shape().type(), batch_lens};
hip_visit_all(result, arg, batch_shape)([&](auto output, auto input, auto batch) {
// use one block for items in one batch.
const std::size_t max_block_size = 1024;
std::size_t block_size = 1;
while(block_size < max_block_size and block_size < batch_item_num)
{
block_size *= 2;
}
launch(stream, batch_shape.elements() * block_size, block_size)([=](auto idx) __device__ {
std::size_t thr_idx = idx.local;
std::size_t blk_idx = idx.group;
using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
MIGRAPHX_DEVICE_SHARED type lds_data[max_block_size + 1];
auto batch_idx = batch.multi(blk_idx);
auto data_idx = batch_idx;
// load data to lds and compute the batch max
std::size_t remaining_item_num = batch_item_num;
std::size_t round_item_num =
(batch_item_num + block_size - 1) / block_size * block_size;
lds_data[max_block_size] = input[0];
for(std::size_t i = thr_idx; i < round_item_num; i += block_size)
{
if(i < batch_item_num)
{
data_idx[axis] = i;
lds_data[thr_idx] = input[data_idx];
}
__syncthreads();
auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num;
block_reduce<type, max_op<type>>(
lds_data, max_op<type>{}, block_size, thr_idx, item_num, max_block_size);
remaining_item_num -= block_size;
}
auto batch_max = lds_data[max_block_size];
__syncthreads();
lds_data[max_block_size] = 0;
remaining_item_num = batch_item_num;
for(std::size_t i = thr_idx; i < round_item_num; i += block_size)
{
if(i < batch_item_num)
{
data_idx[axis] = i;
lds_data[thr_idx] = input[data_idx] - batch_max;
lds_data[thr_idx] = ::exp(to_hip_type(lds_data[thr_idx]));
}
__syncthreads();
auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num;
block_reduce<type, sum_op<type>>(
lds_data, sum_op<type>{}, block_size, thr_idx, item_num, max_block_size);
remaining_item_num -= block_size;
}
auto log_batch_sum = ::log(to_hip_type(lds_data[max_block_size])) + batch_max;
for(std::size_t i = thr_idx; i < batch_item_num; i += block_size)
{
data_idx[axis] = i;
const std::size_t max_block_size = 256;
const std::size_t block_size = compute_block_size(batch_item_num, max_block_size);
gs_launch(stream, batch_shape.elements() * block_size, block_size)([=](auto i, auto idx) __device__ {
auto data_idx = batch.multi(i / block_size);
using type = device_type<std::remove_cv_t<typename decltype(input)::value_type>>;
type init = lowest();
auto batch_max = block_reduce<max_block_size>(idx, max{}, init, batch_item_num, [&](auto j) __device__ {
data_idx[axis] = j;
return input[data_idx];
});
auto batch_sum = block_reduce<max_block_size>(idx, sum{}, 0, batch_item_num, [&](auto j) __device__ {
data_idx[axis] = j;
auto val = input[data_idx] - batch_max;
return ::exp(to_hip_type(val));
});
auto log_batch_sum = ::log(to_hip_type(batch_sum)) + batch_max;
idx.local_stride(batch_item_num, [&](auto j) {
data_idx[axis] = j;
output[data_idx] = input[data_idx] - log_batch_sum;
}
});
});
});
}
......
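The rewritten logsoftmax kernel computes each output as x_i - (m + log(sum_j exp(x_j - m))), where m is the per-batch maximum: the numerically stable log-sum-exp form. A plain CPU reference of that math over one reduction axis (a sketch for illustration, not MIGraphX code):

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Reference log-softmax over a single axis of length batch_item_num.
std::vector<float> log_softmax_reference(const std::vector<float>& x)
{
    const float batch_max = *std::max_element(x.begin(), x.end());
    float batch_sum       = 0.0f;
    for(float v : x)
        batch_sum += std::exp(v - batch_max); // shift by the max to avoid overflow
    const float log_batch_sum = std::log(batch_sum) + batch_max;

    std::vector<float> out(x.size());
    for(std::size_t i = 0; i < x.size(); i++)
        out[i] = x[i] - log_batch_sum;
    return out;
}
```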
@@ -2,7 +2,7 @@
#include <migraphx/argument.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/gpu/device/softmax.hpp>
#include <migraphx/gpu/device/reduce_opers.hpp>
#include <migraphx/gpu/device/reduce.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
@@ -22,73 +22,29 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in
migraphx::shape batch_shape{result.get_shape().type(), batch_lens};
hip_visit_all(result, arg, batch_shape)([&](auto output, auto input, auto batch) {
// use one block for items in one batch.
const std::size_t max_block_size = 1024;
std::size_t block_size = 1;
while(block_size < max_block_size and block_size < batch_item_num)
{
block_size *= 2;
}
launch(stream, batch_shape.elements() * block_size, block_size)([=](auto idx) __device__ {
std::size_t thr_idx = idx.local;
std::size_t blk_idx = idx.group;
using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
MIGRAPHX_DEVICE_SHARED type lds_data[max_block_size + 1];
auto batch_idx = batch.multi(blk_idx);
auto data_idx = batch_idx;
// load data to lds and compute the batch max
std::size_t remaining_item_num = batch_item_num;
std::size_t round_item_num =
(batch_item_num + block_size - 1) / block_size * block_size;
lds_data[max_block_size] = input[0];
for(std::size_t i = thr_idx; i < round_item_num; i += block_size)
{
if(i < batch_item_num)
{
data_idx[axis] = i;
lds_data[thr_idx] = input[data_idx];
}
__syncthreads();
auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num;
block_reduce<type, max_op<type>>(
lds_data, max_op<type>{}, block_size, thr_idx, item_num, max_block_size);
remaining_item_num -= block_size;
}
auto batch_max = lds_data[max_block_size];
__syncthreads();
lds_data[max_block_size] = 0;
remaining_item_num = batch_item_num;
for(std::size_t i = thr_idx; i < round_item_num; i += block_size)
{
if(i < batch_item_num)
{
data_idx[axis] = i;
lds_data[thr_idx] = input[data_idx] - batch_max;
lds_data[thr_idx] = ::exp(to_hip_type(lds_data[thr_idx]));
}
__syncthreads();
auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num;
block_reduce<type, sum_op<type>>(
lds_data, sum_op<type>{}, block_size, thr_idx, item_num, max_block_size);
remaining_item_num -= block_size;
}
auto batch_sum = lds_data[max_block_size];
for(std::size_t i = thr_idx; i < batch_item_num; i += block_size)
{
data_idx[axis] = i;
auto val = input[data_idx] - batch_max;
const std::size_t max_block_size = 256;
const std::size_t block_size = compute_block_size(batch_item_num, max_block_size);
gs_launch(stream, batch_shape.elements() * block_size, block_size)([=](auto i, auto idx) __device__ {
auto data_idx = batch.multi(i / block_size);
using type = device_type<std::remove_cv_t<typename decltype(input)::value_type>>;
type init = lowest();
auto batch_max = block_reduce<max_block_size>(idx, max{}, init, batch_item_num, [&](auto j) __device__ {
data_idx[axis] = j;
return input[data_idx];
});
auto batch_sum = block_reduce<max_block_size>(idx, sum{}, 0, batch_item_num, [&](auto j) __device__ {
data_idx[axis] = j;
auto val = input[data_idx] - batch_max;
return ::exp(to_hip_type(val));
});
idx.local_stride(batch_item_num, [&](auto j) {
data_idx[axis] = j;
auto val = input[data_idx] - batch_max;
output[data_idx] = ::exp(to_hip_type(val)) / batch_sum;
}
});
});
});
}
......
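The softmax kernel follows the same pattern but exponentiates and divides by the sum instead of subtracting its log: out_i = exp(x_i - m) / sum_j exp(x_j - m). A matching CPU reference (again a sketch, not MIGraphX code):

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Reference softmax over a single axis of length batch_item_num.
std::vector<float> softmax_reference(const std::vector<float>& x)
{
    const float batch_max = *std::max_element(x.begin(), x.end());
    float batch_sum       = 0.0f;
    for(float v : x)
        batch_sum += std::exp(v - batch_max);

    std::vector<float> out(x.size());
    for(std::size_t i = 0; i < x.size(); i++)
        out[i] = std::exp(x[i] - batch_max) / batch_sum;
    return out;
}
```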
@@ -586,7 +586,7 @@ struct test_softmax2 : verify_program<test_softmax2>
{
migraphx::program p;
auto x =
p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1000, 1, 1}});
p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 1028, 1, 25}});
p.add_instruction(migraphx::op::softmax{}, x);
return p;
}
......
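The widened test shape makes the reduced axis (the operator's default softmax axis) 1028 elements long: larger than the new 256-thread cap and not a power of two, which exercises the strided reduction and local_stride paths above. Both kernels also call compute_block_size, which is not shown in this diff; a rough model of what it presumably computes (the starting value of 64, one wavefront, is an assumption, not taken from the source):

```cpp
#include <cstddef>

// Presumed behavior of compute_block_size (helper not shown in this diff):
// the smallest power-of-two block size covering the reduction length, capped
// at max_block_size.
std::size_t compute_block_size_sketch(std::size_t item_num, std::size_t max_block_size)
{
    std::size_t block_size = 64; // one 64-lane wavefront; an assumption
    while(block_size < max_block_size and block_size < item_num)
        block_size *= 2;
    return block_size;
}
```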