Commit a7a686d5 authored by Shucai Xiao
Browse files

clang format

parent 8ce6758a
...@@ -167,7 +167,7 @@ __device__ inline void dpp_reduce(float& x, sum) ...@@ -167,7 +167,7 @@ __device__ inline void dpp_reduce(float& x, sum)
} }
template <std::size_t N, class Op, class T, class F> template <std::size_t N, class Op, class T, class F>
__device__ auto block_reduce(index idx, Op op, T init, std::size_t n, F f) __device__ auto block_reduce(index idx, Op op, T init, std::size_t n, F f)
{ {
using type = decltype(f(idx.local)); using type = decltype(f(idx.local));
MIGRAPHX_DEVICE_SHARED type buffer[N / 64]; MIGRAPHX_DEVICE_SHARED type buffer[N / 64];
......
...@@ -23,26 +23,30 @@ void logsoftmax(hipStream_t stream, const argument& result, const argument& arg, ...@@ -23,26 +23,30 @@ void logsoftmax(hipStream_t stream, const argument& result, const argument& arg,
hip_visit_all(result, arg, batch_shape)([&](auto output, auto input, auto batch) { hip_visit_all(result, arg, batch_shape)([&](auto output, auto input, auto batch) {
const std::size_t max_block_size = 256; const std::size_t max_block_size = 256;
const std::size_t block_size = compute_block_size(batch_item_num, max_block_size); const std::size_t block_size = compute_block_size(batch_item_num, max_block_size);
gs_launch(stream, batch_shape.elements() * block_size, block_size)([=](auto i, auto idx) __device__ { gs_launch(stream,
batch_shape.elements() * block_size,
block_size)([=](auto i, auto idx) __device__ {
auto data_idx = batch.multi(i / block_size); auto data_idx = batch.multi(i / block_size);
using type = device_type<std::remove_cv_t<typename decltype(input)::value_type>>; using type = device_type<std::remove_cv_t<typename decltype(input)::value_type>>;
type init = lowest(); type init = lowest();
auto batch_max = block_reduce<max_block_size>(idx, max{}, init, batch_item_num, [&](auto j) __device__ { auto batch_max = block_reduce<max_block_size>(
data_idx[axis] = j; idx, max{}, init, batch_item_num, [&](auto j) __device__ {
return input[data_idx]; data_idx[axis] = j;
}); return input[data_idx];
});
auto batch_sum = block_reduce<max_block_size>(idx, sum{}, 0, batch_item_num, [&](auto j) __device__ {
data_idx[axis] = j; auto batch_sum =
auto val = input[data_idx] - batch_max; block_reduce<max_block_size>(idx, sum{}, 0, batch_item_num, [&](auto j) __device__ {
return ::exp(to_hip_type(val)); data_idx[axis] = j;
}); auto val = input[data_idx] - batch_max;
return ::exp(to_hip_type(val));
});
auto log_batch_sum = ::log(to_hip_type(batch_sum)) + batch_max; auto log_batch_sum = ::log(to_hip_type(batch_sum)) + batch_max;
idx.local_stride(batch_item_num, [&](auto j) { idx.local_stride(batch_item_num, [&](auto j) {
data_idx[axis] = j; data_idx[axis] = j;
output[data_idx] = input[data_idx] - log_batch_sum; output[data_idx] = input[data_idx] - log_batch_sum;
}); });
}); });
......
...@@ -24,25 +24,29 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in ...@@ -24,25 +24,29 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in
hip_visit_all(result, arg, batch_shape)([&](auto output, auto input, auto batch) { hip_visit_all(result, arg, batch_shape)([&](auto output, auto input, auto batch) {
const std::size_t max_block_size = 256; const std::size_t max_block_size = 256;
const std::size_t block_size = compute_block_size(batch_item_num, max_block_size); const std::size_t block_size = compute_block_size(batch_item_num, max_block_size);
gs_launch(stream, batch_shape.elements() * block_size, block_size)([=](auto i, auto idx) __device__ { gs_launch(stream,
batch_shape.elements() * block_size,
block_size)([=](auto i, auto idx) __device__ {
auto data_idx = batch.multi(i / block_size); auto data_idx = batch.multi(i / block_size);
using type = device_type<std::remove_cv_t<typename decltype(input)::value_type>>; using type = device_type<std::remove_cv_t<typename decltype(input)::value_type>>;
type init = lowest(); type init = lowest();
auto batch_max = block_reduce<max_block_size>(idx, max{}, init, batch_item_num, [&](auto j) __device__ { auto batch_max = block_reduce<max_block_size>(
data_idx[axis] = j; idx, max{}, init, batch_item_num, [&](auto j) __device__ {
return input[data_idx]; data_idx[axis] = j;
}); return input[data_idx];
});
auto batch_sum = block_reduce<max_block_size>(idx, sum{}, 0, batch_item_num, [&](auto j) __device__ { auto batch_sum =
data_idx[axis] = j; block_reduce<max_block_size>(idx, sum{}, 0, batch_item_num, [&](auto j) __device__ {
auto val = input[data_idx] - batch_max; data_idx[axis] = j;
return ::exp(to_hip_type(val)); auto val = input[data_idx] - batch_max;
}); return ::exp(to_hip_type(val));
});
idx.local_stride(batch_item_num, [&](auto j) { idx.local_stride(batch_item_num, [&](auto j) {
data_idx[axis] = j; data_idx[axis] = j;
auto val = input[data_idx] - batch_max; auto val = input[data_idx] - batch_max;
output[data_idx] = ::exp(to_hip_type(val)) / batch_sum; output[data_idx] = ::exp(to_hip_type(val)) / batch_sum;
}); });
}); });
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment