Unverified Commit 2e337c7f authored by Shucai Xiao, committed by GitHub

Softmax perf optimization (#1014)

Changed the number of threads in a block from 256 to 128.
Increased the max number of blocks in the kernel from 256 to 1B (1,073,741,824).
For the case where the softmax axis is the last dimension, we removed the multi-dimensional index computation since it is not required.

With these changes, we get about a 2x speedup over the develop branch for the softmax op used in the BertSquad model (see the launch-arithmetic sketch below the commit metadata).
parent e758d457
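To make the launch-configuration change concrete, here is a small host-side sketch, not code from this PR, that mirrors the sizing arithmetic in `gs_launch`; the helper name `iterations_per_thread` is ours. It shows how the cap on the number of blocks controls how many grid-stride iterations each thread must perform:

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>

using index_int = std::uint64_t;

// Mirrors gs_launch's sizing: `groups` blocks of `local` threads cover `n`
// work items, the grid is clamped to `max_blocks`, and each thread then
// loops over the remainder with a grid-sized stride.
index_int iterations_per_thread(index_int n, index_int local, index_int max_blocks)
{
    index_int groups  = (n + local - 1) / local;              // blocks needed to cover n
    index_int nglobal = std::min(max_blocks, groups) * local; // threads actually launched
    return (n + nglobal - 1) / nglobal;                       // grid-stride trip count
}

int main()
{
    const index_int n = index_int{1} << 24; // a large workload, e.g. a big softmax batch
    // Old configuration: 256 threads/block, at most 256 blocks -> 256 iterations/thread.
    std::cout << iterations_per_thread(n, 256, 256) << '\n';
    // New configuration: 128 threads/block, cap 1,073,741,824 -> 1 iteration/thread.
    std::cout << iterations_per_thread(n, 128, 1073741824) << '\n';
}
```

With the old cap, large tensors forced every thread through hundreds of stride iterations; raising the cap lets the runtime launch enough blocks to cover the whole input in a single pass.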
@@ -76,7 +76,8 @@ MIGRAPHX_DEVICE_CONSTEXPR auto gs_invoke(F&& f, index_int i, index) -> decltype(
 inline auto gs_launch(hipStream_t stream, index_int n, index_int local = 1024)
 {
     index_int groups  = (n + local - 1) / local;
-    index_int nglobal = std::min<index_int>(256, groups) * local;
+    // max possible number of blocks is set to 1B (1,073,741,824)
+    index_int nglobal = std::min<index_int>(1073741824, groups) * local;
     return [=](auto f) {
         launch(stream, nglobal, local)([=](auto idx) __device__ {
......
@@ -20,23 +20,46 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in
     migraphx::shape batch_shape{result.get_shape().type(), batch_lens};
     hip_visit_all(result, arg, batch_shape)([&](auto output, auto input, auto batch) {
-        const index_int max_block_size = 256;
+        const index_int max_block_size = 128;
         const index_int block_size = compute_block_size(batch_item_num, max_block_size);
-        gs_launch(stream,
-                  batch_shape.elements() * block_size,
-                  block_size)([=](auto i, auto idx) __device__ {
-            auto data_idx = batch.multi(i / block_size);
         using type = device_type<std::remove_cv_t<typename decltype(input)::value_type>>;
         type init = lowest();
+        if(axis == batch_lens.size() - 1)
+        {
+            gs_launch(stream, batch_shape.elements() * block_size, block_size)(
+                [=](auto i, auto idx) __device__ {
+                    auto start_loc = i / block_size * batch_item_num;
+                    auto batch_max = block_reduce<max_block_size>(
+                        idx, max{}, init, batch_item_num, [&](auto j) __device__ {
+                            return input[start_loc + j];
+                        });
+                    auto batch_sum = block_reduce<max_block_size>(
+                        idx, sum{}, 0, batch_item_num, [&](auto j) __device__ {
+                            auto val = input[start_loc + j] - batch_max;
+                            return ::exp(to_hip_type(val));
+                        });
+                    idx.local_stride(batch_item_num, [&](auto j) __device__ {
+                        auto val = input[start_loc + j] - batch_max;
+                        output[start_loc + j] = ::exp(to_hip_type(val)) / batch_sum;
+                    });
+                });
+        }
+        else
+        {
+            gs_launch(stream, batch_shape.elements() * block_size, block_size)(
+                [=](auto i, auto idx) __device__ {
+                    auto data_idx = batch.multi(i / block_size);
                     auto batch_max = block_reduce<max_block_size>(
                         idx, max{}, init, batch_item_num, [&](auto j) __device__ {
                             data_idx[axis] = j;
                             return input[data_idx];
                         });
-            auto batch_sum =
-                block_reduce<max_block_size>(idx, sum{}, 0, batch_item_num, [&](auto j) __device__ {
+                    auto batch_sum = block_reduce<max_block_size>(
+                        idx, sum{}, 0, batch_item_num, [&](auto j) __device__ {
                             data_idx[axis] = j;
                             auto val = input[data_idx] - batch_max;
                             return ::exp(to_hip_type(val));
@@ -48,6 +71,7 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in
                     output[data_idx] = ::exp(to_hip_type(val)) / batch_sum;
                 });
             });
+        }
     });
 }
......
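The last-axis fast path in the diff above relies on the reduced dimension being contiguous in memory: the start of each softmax row is found with a single multiply (start_loc), and elements are read as start_loc + j with no per-element multi-index updates. The standalone sketch below is our own illustration rather than the PR's kernel (it assigns one thread per row instead of reducing across a 128-thread block with block_reduce), but it shows the same flat addressing:

```cpp
// Standalone sketch of last-axis softmax with flat row offsets.
// Assumptions: float input, rows contiguous (softmax axis is the last dim).
#include <hip/hip_runtime.h>
#include <cstdio>
#include <vector>

__global__ void softmax_last_axis(const float* in, float* out, int rows, int item_num)
{
    // Grid-stride loop over rows; each thread owns whole rows, so the
    // only index arithmetic is one multiply per row (no multi-index).
    for(int r = blockIdx.x * blockDim.x + threadIdx.x; r < rows;
        r += gridDim.x * blockDim.x)
    {
        const float* x = in + static_cast<size_t>(r) * item_num;
        float* y       = out + static_cast<size_t>(r) * item_num;

        float m = x[0]; // row max, subtracted for numerical stability
        for(int j = 1; j < item_num; ++j)
            m = fmaxf(m, x[j]);

        float s = 0.0f; // sum of shifted exponentials
        for(int j = 0; j < item_num; ++j)
            s += expf(x[j] - m);

        for(int j = 0; j < item_num; ++j)
            y[j] = expf(x[j] - m) / s;
    }
}

int main()
{
    const int rows = 4, item_num = 8;
    std::vector<float> h(rows * item_num, 1.0f); // uniform rows
    float *din, *dout;
    hipMalloc(&din, h.size() * sizeof(float));
    hipMalloc(&dout, h.size() * sizeof(float));
    hipMemcpy(din, h.data(), h.size() * sizeof(float), hipMemcpyHostToDevice);
    softmax_last_axis<<<1, 128>>>(din, dout, rows, item_num);
    hipMemcpy(h.data(), dout, h.size() * sizeof(float), hipMemcpyDeviceToHost);
    printf("%f\n", h[0]); // uniform input -> 1/item_num = 0.125
    hipFree(din);
    hipFree(dout);
    return 0;
}
```

For a uniform row the program prints 0.125, i.e. 1/item_num, the expected softmax of identical inputs. When the axis is not last, the elements of a row are strided, which is why the PR keeps the multi-index (data_idx) path in the else branch.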