Commit a7a686d5 authored by Shucai Xiao
Browse files

clang format

parent 8ce6758a
......@@ -23,17 +23,21 @@ void logsoftmax(hipStream_t stream, const argument& result, const argument& arg,
hip_visit_all(result, arg, batch_shape)([&](auto output, auto input, auto batch) {
const std::size_t max_block_size = 256;
const std::size_t block_size = compute_block_size(batch_item_num, max_block_size);
gs_launch(stream, batch_shape.elements() * block_size, block_size)([=](auto i, auto idx) __device__ {
gs_launch(stream,
batch_shape.elements() * block_size,
block_size)([=](auto i, auto idx) __device__ {
auto data_idx = batch.multi(i / block_size);
using type = device_type<std::remove_cv_t<typename decltype(input)::value_type>>;
type init = lowest();
auto batch_max = block_reduce<max_block_size>(idx, max{}, init, batch_item_num, [&](auto j) __device__ {
auto batch_max = block_reduce<max_block_size>(
idx, max{}, init, batch_item_num, [&](auto j) __device__ {
data_idx[axis] = j;
return input[data_idx];
});
auto batch_sum = block_reduce<max_block_size>(idx, sum{}, 0, batch_item_num, [&](auto j) __device__ {
auto batch_sum =
block_reduce<max_block_size>(idx, sum{}, 0, batch_item_num, [&](auto j) __device__ {
data_idx[axis] = j;
auto val = input[data_idx] - batch_max;
return ::exp(to_hip_type(val));
......
......@@ -24,17 +24,21 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in
hip_visit_all(result, arg, batch_shape)([&](auto output, auto input, auto batch) {
const std::size_t max_block_size = 256;
const std::size_t block_size = compute_block_size(batch_item_num, max_block_size);
gs_launch(stream, batch_shape.elements() * block_size, block_size)([=](auto i, auto idx) __device__ {
gs_launch(stream,
batch_shape.elements() * block_size,
block_size)([=](auto i, auto idx) __device__ {
auto data_idx = batch.multi(i / block_size);
using type = device_type<std::remove_cv_t<typename decltype(input)::value_type>>;
type init = lowest();
auto batch_max = block_reduce<max_block_size>(idx, max{}, init, batch_item_num, [&](auto j) __device__ {
auto batch_max = block_reduce<max_block_size>(
idx, max{}, init, batch_item_num, [&](auto j) __device__ {
data_idx[axis] = j;
return input[data_idx];
});
auto batch_sum = block_reduce<max_block_size>(idx, sum{}, 0, batch_item_num, [&](auto j) __device__ {
auto batch_sum =
block_reduce<max_block_size>(idx, sum{}, 0, batch_item_num, [&](auto j) __device__ {
data_idx[axis] = j;
auto val = input[data_idx] - batch_max;
return ::exp(to_hip_type(val));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment