Commit b9575730 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

fix build errors.

parent d5a32cd2
......@@ -30,8 +30,7 @@ void logsoftmax(hipStream_t stream, argument result, argument arg, int axis)
block_size *= 2;
}
launch(
stream, batch_shape.elements() * block_size, block_size)([=](auto idx) __device__ {
launch(stream, batch_shape.elements() * block_size, block_size)([=](auto idx) __device__ {
size_t thr_idx = idx.local;
size_t blk_idx = idx.group;
using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
......@@ -48,12 +47,11 @@ void logsoftmax(hipStream_t stream, argument result, argument arg, int axis)
if(i < batch_item_num)
{
data_idx[axis] = i;
lds_data[thr_idx] = input[desc_data.linear(data_idx)];
lds_data[thr_idx] = input[data_idx];
}
__syncthreads();
auto item_num =
(remaining_item_num > block_size) ? block_size : remaining_item_num;
auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num;
reduce_max(lds_data, block_size, thr_idx, item_num);
remaining_item_num -= block_size;
......@@ -69,14 +67,13 @@ void logsoftmax(hipStream_t stream, argument result, argument arg, int axis)
if(i < batch_item_num)
{
data_idx[axis] = i;
lds_data[thr_idx] = input[desc_data.linear(data_idx)] - batch_max;
lds_data[thr_idx] = input[data_idx] - batch_max;
lds_data[thr_idx] = ::exp(to_hip_type(lds_data[thr_idx]));
}
__syncthreads();
auto item_num =
(remaining_item_num > block_size) ? block_size : remaining_item_num;
auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num;
reduce_sum(lds_data, block_size, thr_idx, item_num);
remaining_item_num -= block_size;
......@@ -87,8 +84,7 @@ void logsoftmax(hipStream_t stream, argument result, argument arg, int axis)
for(size_t i = thr_idx; i < batch_item_num; i += block_size)
{
data_idx[axis] = i;
size_t index = desc_data.linear(data_idx);
output[index] = input[index] - log_batch_sum;
output[data_idx] = input[data_idx] - log_batch_sum;
}
});
});
......
......@@ -30,8 +30,7 @@ void softmax(hipStream_t stream, argument result, argument arg, int axis)
block_size *= 2;
}
launch(
stream, batch_shape.elements() * block_size, block_size)([=](auto idx) __device__ {
launch(stream, batch_shape.elements() * block_size, block_size)([=](auto idx) __device__ {
size_t thr_idx = idx.local;
size_t blk_idx = idx.group;
using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
......@@ -48,13 +47,12 @@ void softmax(hipStream_t stream, argument result, argument arg, int axis)
if(i < batch_item_num)
{
data_idx[axis] = i;
lds_data[thr_idx] = input[desc_data.linear(data_idx)];
lds_data[thr_idx] = input[data_idx];
}
__syncthreads();
auto item_num =
(remaining_item_num > block_size) ? block_size : remaining_item_num;
auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num;
reduce_max(lds_data, block_size, thr_idx, item_num);
remaining_item_num -= block_size;
......@@ -70,14 +68,13 @@ void softmax(hipStream_t stream, argument result, argument arg, int axis)
if(i < batch_item_num)
{
data_idx[axis] = i;
lds_data[thr_idx] = input[desc_data.linear(data_idx)] - batch_max;
lds_data[thr_idx] = input[data_idx] - batch_max;
lds_data[thr_idx] = ::exp(to_hip_type(lds_data[thr_idx]));
}
__syncthreads();
auto item_num =
(remaining_item_num > block_size) ? block_size : remaining_item_num;
auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num;
reduce_sum(lds_data, block_size, thr_idx, item_num);
remaining_item_num -= block_size;
......@@ -87,9 +84,8 @@ void softmax(hipStream_t stream, argument result, argument arg, int axis)
for(size_t i = thr_idx; i < batch_item_num; i += block_size)
{
data_idx[axis] = i;
size_t index = desc_data.linear(data_idx);
auto val = input[index] - batch_max;
output[index] = ::exp(to_hip_type(val)) / batch_sum;
auto val = input[data_idx] - batch_max;
output[data_idx] = ::exp(to_hip_type(val)) / batch_sum;
}
});
});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment