Commit b9575730 authored by Shucai Xiao
Browse files

Fix build errors.

parent d5a32cd2
......@@ -30,30 +30,28 @@ void logsoftmax(hipStream_t stream, argument result, argument arg, int axis)
block_size *= 2;
}
launch(
stream, batch_shape.elements() * block_size, block_size)([=](auto idx) __device__ {
launch(stream, batch_shape.elements() * block_size, block_size)([=](auto idx) __device__ {
size_t thr_idx = idx.local;
size_t blk_idx = idx.group;
using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
MIGRAPHX_DEVICE_SHARED type lds_data[max_block_size + 1];
auto batch_idx = batch.multi(blk_idx);
auto data_idx = batch_idx;
// load data to lds and compute the batch max
size_t remaining_item_num = batch_item_num;
size_t round_item_num = (batch_item_num + block_size - 1) / block_size * block_size;
lds_data[block_size] = input[0];
size_t round_item_num = (batch_item_num + block_size - 1) / block_size * block_size;
lds_data[block_size] = input[0];
for(size_t i = thr_idx; i < round_item_num; i += block_size)
{
if(i < batch_item_num)
{
data_idx[axis] = i;
lds_data[thr_idx] = input[desc_data.linear(data_idx)];
lds_data[thr_idx] = input[data_idx];
}
__syncthreads();
auto item_num =
(remaining_item_num > block_size) ? block_size : remaining_item_num;
auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num;
reduce_max(lds_data, block_size, thr_idx, item_num);
remaining_item_num -= block_size;
......@@ -69,14 +67,13 @@ void logsoftmax(hipStream_t stream, argument result, argument arg, int axis)
if(i < batch_item_num)
{
data_idx[axis] = i;
lds_data[thr_idx] = input[desc_data.linear(data_idx)] - batch_max;
lds_data[thr_idx] = input[data_idx] - batch_max;
lds_data[thr_idx] = ::exp(to_hip_type(lds_data[thr_idx]));
}
__syncthreads();
auto item_num =
(remaining_item_num > block_size) ? block_size : remaining_item_num;
auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num;
reduce_sum(lds_data, block_size, thr_idx, item_num);
remaining_item_num -= block_size;
......@@ -86,9 +83,8 @@ void logsoftmax(hipStream_t stream, argument result, argument arg, int axis)
for(size_t i = thr_idx; i < batch_item_num; i += block_size)
{
data_idx[axis] = i;
size_t index = desc_data.linear(data_idx);
output[index] = input[index] - log_batch_sum;
data_idx[axis] = i;
output[data_idx] = input[data_idx] - log_batch_sum;
}
});
});
......
......@@ -30,31 +30,29 @@ void softmax(hipStream_t stream, argument result, argument arg, int axis)
block_size *= 2;
}
launch(
stream, batch_shape.elements() * block_size, block_size)([=](auto idx) __device__ {
launch(stream, batch_shape.elements() * block_size, block_size)([=](auto idx) __device__ {
size_t thr_idx = idx.local;
size_t blk_idx = idx.group;
using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
MIGRAPHX_DEVICE_SHARED type lds_data[max_block_size + 1];
auto batch_idx = batch.multi(blk_idx);
auto data_idx = batch_idx;
// load data to lds and compute the batch max
size_t remaining_item_num = batch_item_num;
size_t round_item_num = (batch_item_num + block_size - 1) / block_size * block_size;
lds_data[block_size] = input[0];
size_t round_item_num = (batch_item_num + block_size - 1) / block_size * block_size;
lds_data[block_size] = input[0];
for(size_t i = thr_idx; i < round_item_num; i += block_size)
{
if(i < batch_item_num)
{
data_idx[axis] = i;
lds_data[thr_idx] = input[desc_data.linear(data_idx)];
lds_data[thr_idx] = input[data_idx];
}
__syncthreads();
auto item_num =
(remaining_item_num > block_size) ? block_size : remaining_item_num;
auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num;
reduce_max(lds_data, block_size, thr_idx, item_num);
remaining_item_num -= block_size;
......@@ -70,14 +68,13 @@ void softmax(hipStream_t stream, argument result, argument arg, int axis)
if(i < batch_item_num)
{
data_idx[axis] = i;
lds_data[thr_idx] = input[desc_data.linear(data_idx)] - batch_max;
lds_data[thr_idx] = input[data_idx] - batch_max;
lds_data[thr_idx] = ::exp(to_hip_type(lds_data[thr_idx]));
}
__syncthreads();
auto item_num =
(remaining_item_num > block_size) ? block_size : remaining_item_num;
auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num;
reduce_sum(lds_data, block_size, thr_idx, item_num);
remaining_item_num -= block_size;
......@@ -86,10 +83,9 @@ void softmax(hipStream_t stream, argument result, argument arg, int axis)
for(size_t i = thr_idx; i < batch_item_num; i += block_size)
{
data_idx[axis] = i;
size_t index = desc_data.linear(data_idx);
auto val = input[index] - batch_max;
output[index] = ::exp(to_hip_type(val)) / batch_sum;
data_idx[axis] = i;
auto val = input[data_idx] - batch_max;
output[data_idx] = ::exp(to_hip_type(val)) / batch_sum;
}
});
});
......
Markdown is supported
Attach a file by drag & drop or click to upload.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment