Commit a7a686d5 authored by Shucai Xiao

clang format

parent 8ce6758a
@@ -167,7 +167,7 @@ __device__ inline void dpp_reduce(float& x, sum)
 }
 template <std::size_t N, class Op, class T, class F>
-__device__ auto block_reduce(index idx, Op op, T init, std::size_t n, F f)
+__device__ auto block_reduce(index idx, Op op, T init, std::size_t n, F f)
 {
     using type = decltype(f(idx.local));
     MIGRAPHX_DEVICE_SHARED type buffer[N / 64];
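Note: block_reduce (together with the dpp_reduce helper above it) takes a reduction operation, an initial value, the number of elements, and a loader lambda f(j); partial results are combined per wavefront and staged through the buffer[N / 64] shared array. A serial host-side analogue of that call contract, purely for illustration (reduce_analogue is not a MIGraphX function), would be:

#include <cstddef>

// Sketch only: serial stand-in for the block-wide reduction contract.
// The real device routine splits the j-loop across the thread block and
// merges per-wavefront partials via dpp_reduce and shared memory.
template <class Op, class T, class F>
T reduce_analogue(Op op, T init, std::size_t n, F f)
{
    T acc = init;                // e.g. lowest() for max, 0 for sum
    for(std::size_t j = 0; j < n; j++)
        acc = op(acc, f(j));     // f(j) produces the j-th value to reduce
    return acc;
}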
@@ -23,26 +23,30 @@ void logsoftmax(hipStream_t stream, const argument& result, const argument& arg,
     hip_visit_all(result, arg, batch_shape)([&](auto output, auto input, auto batch) {
         const std::size_t max_block_size = 256;
         const std::size_t block_size = compute_block_size(batch_item_num, max_block_size);
-        gs_launch(stream, batch_shape.elements() * block_size, block_size)([=](auto i, auto idx) __device__ {
+        gs_launch(stream,
+                  batch_shape.elements() * block_size,
+                  block_size)([=](auto i, auto idx) __device__ {
             auto data_idx = batch.multi(i / block_size);
-            using type = device_type<std::remove_cv_t<typename decltype(input)::value_type>>;
-            type init = lowest();
-            auto batch_max = block_reduce<max_block_size>(idx, max{}, init, batch_item_num, [&](auto j) __device__ {
-                data_idx[axis] = j;
-                return input[data_idx];
-            });
-            auto batch_sum = block_reduce<max_block_size>(idx, sum{}, 0, batch_item_num, [&](auto j) __device__ {
-                data_idx[axis] = j;
-                auto val = input[data_idx] - batch_max;
-                return ::exp(to_hip_type(val));
-            });
+            using type = device_type<std::remove_cv_t<typename decltype(input)::value_type>>;
+            type init = lowest();
+            auto batch_max = block_reduce<max_block_size>(
+                idx, max{}, init, batch_item_num, [&](auto j) __device__ {
+                    data_idx[axis] = j;
+                    return input[data_idx];
+                });
+            auto batch_sum =
+                block_reduce<max_block_size>(idx, sum{}, 0, batch_item_num, [&](auto j) __device__ {
+                    data_idx[axis] = j;
+                    auto val = input[data_idx] - batch_max;
+                    return ::exp(to_hip_type(val));
+                });
             auto log_batch_sum = ::log(to_hip_type(batch_sum)) + batch_max;
-            idx.local_stride(batch_item_num, [&](auto j) {
-                data_idx[axis] = j;
+            idx.local_stride(batch_item_num, [&](auto j) {
+                data_idx[axis] = j;
                 output[data_idx] = input[data_idx] - log_batch_sum;
             });
         });
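The logsoftmax kernel above computes, for each slice along axis, the numerically stable form x_j - (batch_max + log sum_k exp(x_k - batch_max)): one block reduction finds the maximum, a second sums the shifted exponentials, and the final local_stride pass writes the result. A minimal host-side reference of that per-slice math (a sketch for checking intent, not MIGraphX code):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Reference log-softmax over a single slice, mirroring the kernel's
// max reduction, shifted-exp sum, and final subtraction.
std::vector<float> logsoftmax_ref(const std::vector<float>& x)
{
    float m  = *std::max_element(x.begin(), x.end());   // batch_max
    double s = 0.0;
    for(float v : x)
        s += std::exp(v - m);                            // batch_sum
    float log_s = static_cast<float>(std::log(s)) + m;   // log_batch_sum
    std::vector<float> out(x.size());
    for(std::size_t j = 0; j < x.size(); j++)
        out[j] = x[j] - log_s;
    return out;
}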
@@ -24,25 +24,29 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in
     hip_visit_all(result, arg, batch_shape)([&](auto output, auto input, auto batch) {
         const std::size_t max_block_size = 256;
         const std::size_t block_size = compute_block_size(batch_item_num, max_block_size);
-        gs_launch(stream, batch_shape.elements() * block_size, block_size)([=](auto i, auto idx) __device__ {
+        gs_launch(stream,
+                  batch_shape.elements() * block_size,
+                  block_size)([=](auto i, auto idx) __device__ {
             auto data_idx = batch.multi(i / block_size);
-            using type = device_type<std::remove_cv_t<typename decltype(input)::value_type>>;
-            type init = lowest();
+            using type = device_type<std::remove_cv_t<typename decltype(input)::value_type>>;
+            type init = lowest();
-            auto batch_max = block_reduce<max_block_size>(idx, max{}, init, batch_item_num, [&](auto j) __device__ {
-                data_idx[axis] = j;
-                return input[data_idx];
-            });
+            auto batch_max = block_reduce<max_block_size>(
+                idx, max{}, init, batch_item_num, [&](auto j) __device__ {
+                    data_idx[axis] = j;
+                    return input[data_idx];
+                });
-            auto batch_sum = block_reduce<max_block_size>(idx, sum{}, 0, batch_item_num, [&](auto j) __device__ {
-                data_idx[axis] = j;
-                auto val = input[data_idx] - batch_max;
-                return ::exp(to_hip_type(val));
-            });
+            auto batch_sum =
+                block_reduce<max_block_size>(idx, sum{}, 0, batch_item_num, [&](auto j) __device__ {
+                    data_idx[axis] = j;
+                    auto val = input[data_idx] - batch_max;
+                    return ::exp(to_hip_type(val));
+                });
-            idx.local_stride(batch_item_num, [&](auto j) {
-                data_idx[axis] = j;
-                auto val = input[data_idx] - batch_max;
+            idx.local_stride(batch_item_num, [&](auto j) {
+                data_idx[axis] = j;
+                auto val = input[data_idx] - batch_max;
                 output[data_idx] = ::exp(to_hip_type(val)) / batch_sum;
             });
         });
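The softmax kernel shares the same two reductions; only the final pass differs, writing exp(x_j - batch_max) / batch_sum instead of subtracting a log term. A matching host-side sketch (again illustrative, not MIGraphX code):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Reference softmax over a single slice: max reduction, shifted-exp sum,
// then normalization by that sum.
std::vector<float> softmax_ref(const std::vector<float>& x)
{
    float m  = *std::max_element(x.begin(), x.end());    // batch_max
    double s = 0.0;
    for(float v : x)
        s += std::exp(v - m);                             // batch_sum
    std::vector<float> out(x.size());
    for(std::size_t j = 0; j < x.size(); j++)
        out[j] = static_cast<float>(std::exp(x[j] - m) / s);
    return out;
}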