Commit b9575730 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Fix build errors.

parent d5a32cd2
...@@ -30,30 +30,28 @@ void logsoftmax(hipStream_t stream, argument result, argument arg, int axis) ...@@ -30,30 +30,28 @@ void logsoftmax(hipStream_t stream, argument result, argument arg, int axis)
block_size *= 2; block_size *= 2;
} }
launch( launch(stream, batch_shape.elements() * block_size, block_size)([=](auto idx) __device__ {
stream, batch_shape.elements() * block_size, block_size)([=](auto idx) __device__ {
size_t thr_idx = idx.local; size_t thr_idx = idx.local;
size_t blk_idx = idx.group; size_t blk_idx = idx.group;
using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>; using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
MIGRAPHX_DEVICE_SHARED type lds_data[max_block_size + 1]; MIGRAPHX_DEVICE_SHARED type lds_data[max_block_size + 1];
auto batch_idx = batch.multi(blk_idx); auto batch_idx = batch.multi(blk_idx);
auto data_idx = batch_idx; auto data_idx = batch_idx;
// load data to lds and compute the batch max // load data to lds and compute the batch max
size_t remaining_item_num = batch_item_num; size_t remaining_item_num = batch_item_num;
size_t round_item_num = (batch_item_num + block_size - 1) / block_size * block_size; size_t round_item_num = (batch_item_num + block_size - 1) / block_size * block_size;
lds_data[block_size] = input[0]; lds_data[block_size] = input[0];
for(size_t i = thr_idx; i < round_item_num; i += block_size) for(size_t i = thr_idx; i < round_item_num; i += block_size)
{ {
if(i < batch_item_num) if(i < batch_item_num)
{ {
data_idx[axis] = i; data_idx[axis] = i;
lds_data[thr_idx] = input[desc_data.linear(data_idx)]; lds_data[thr_idx] = input[data_idx];
} }
__syncthreads(); __syncthreads();
auto item_num = auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num;
(remaining_item_num > block_size) ? block_size : remaining_item_num;
reduce_max(lds_data, block_size, thr_idx, item_num); reduce_max(lds_data, block_size, thr_idx, item_num);
remaining_item_num -= block_size; remaining_item_num -= block_size;
...@@ -69,14 +67,13 @@ void logsoftmax(hipStream_t stream, argument result, argument arg, int axis) ...@@ -69,14 +67,13 @@ void logsoftmax(hipStream_t stream, argument result, argument arg, int axis)
if(i < batch_item_num) if(i < batch_item_num)
{ {
data_idx[axis] = i; data_idx[axis] = i;
lds_data[thr_idx] = input[desc_data.linear(data_idx)] - batch_max; lds_data[thr_idx] = input[data_idx] - batch_max;
lds_data[thr_idx] = ::exp(to_hip_type(lds_data[thr_idx])); lds_data[thr_idx] = ::exp(to_hip_type(lds_data[thr_idx]));
} }
__syncthreads(); __syncthreads();
auto item_num = auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num;
(remaining_item_num > block_size) ? block_size : remaining_item_num;
reduce_sum(lds_data, block_size, thr_idx, item_num); reduce_sum(lds_data, block_size, thr_idx, item_num);
remaining_item_num -= block_size; remaining_item_num -= block_size;
...@@ -86,9 +83,8 @@ void logsoftmax(hipStream_t stream, argument result, argument arg, int axis) ...@@ -86,9 +83,8 @@ void logsoftmax(hipStream_t stream, argument result, argument arg, int axis)
for(size_t i = thr_idx; i < batch_item_num; i += block_size) for(size_t i = thr_idx; i < batch_item_num; i += block_size)
{ {
data_idx[axis] = i; data_idx[axis] = i;
size_t index = desc_data.linear(data_idx); output[data_idx] = input[data_idx] - log_batch_sum;
output[index] = input[index] - log_batch_sum;
} }
}); });
}); });
......
...@@ -30,31 +30,29 @@ void softmax(hipStream_t stream, argument result, argument arg, int axis) ...@@ -30,31 +30,29 @@ void softmax(hipStream_t stream, argument result, argument arg, int axis)
block_size *= 2; block_size *= 2;
} }
launch( launch(stream, batch_shape.elements() * block_size, block_size)([=](auto idx) __device__ {
stream, batch_shape.elements() * block_size, block_size)([=](auto idx) __device__ {
size_t thr_idx = idx.local; size_t thr_idx = idx.local;
size_t blk_idx = idx.group; size_t blk_idx = idx.group;
using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>; using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
MIGRAPHX_DEVICE_SHARED type lds_data[max_block_size + 1]; MIGRAPHX_DEVICE_SHARED type lds_data[max_block_size + 1];
auto batch_idx = batch.multi(blk_idx); auto batch_idx = batch.multi(blk_idx);
auto data_idx = batch_idx; auto data_idx = batch_idx;
// load data to lds and compute the batch max // load data to lds and compute the batch max
size_t remaining_item_num = batch_item_num; size_t remaining_item_num = batch_item_num;
size_t round_item_num = (batch_item_num + block_size - 1) / block_size * block_size; size_t round_item_num = (batch_item_num + block_size - 1) / block_size * block_size;
lds_data[block_size] = input[0]; lds_data[block_size] = input[0];
for(size_t i = thr_idx; i < round_item_num; i += block_size) for(size_t i = thr_idx; i < round_item_num; i += block_size)
{ {
if(i < batch_item_num) if(i < batch_item_num)
{ {
data_idx[axis] = i; data_idx[axis] = i;
lds_data[thr_idx] = input[desc_data.linear(data_idx)]; lds_data[thr_idx] = input[data_idx];
} }
__syncthreads(); __syncthreads();
auto item_num = auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num;
(remaining_item_num > block_size) ? block_size : remaining_item_num;
reduce_max(lds_data, block_size, thr_idx, item_num); reduce_max(lds_data, block_size, thr_idx, item_num);
remaining_item_num -= block_size; remaining_item_num -= block_size;
...@@ -70,14 +68,13 @@ void softmax(hipStream_t stream, argument result, argument arg, int axis) ...@@ -70,14 +68,13 @@ void softmax(hipStream_t stream, argument result, argument arg, int axis)
if(i < batch_item_num) if(i < batch_item_num)
{ {
data_idx[axis] = i; data_idx[axis] = i;
lds_data[thr_idx] = input[desc_data.linear(data_idx)] - batch_max; lds_data[thr_idx] = input[data_idx] - batch_max;
lds_data[thr_idx] = ::exp(to_hip_type(lds_data[thr_idx])); lds_data[thr_idx] = ::exp(to_hip_type(lds_data[thr_idx]));
} }
__syncthreads(); __syncthreads();
auto item_num = auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num;
(remaining_item_num > block_size) ? block_size : remaining_item_num;
reduce_sum(lds_data, block_size, thr_idx, item_num); reduce_sum(lds_data, block_size, thr_idx, item_num);
remaining_item_num -= block_size; remaining_item_num -= block_size;
...@@ -86,10 +83,9 @@ void softmax(hipStream_t stream, argument result, argument arg, int axis) ...@@ -86,10 +83,9 @@ void softmax(hipStream_t stream, argument result, argument arg, int axis)
for(size_t i = thr_idx; i < batch_item_num; i += block_size) for(size_t i = thr_idx; i < batch_item_num; i += block_size)
{ {
data_idx[axis] = i; data_idx[axis] = i;
size_t index = desc_data.linear(data_idx); auto val = input[data_idx] - batch_max;
auto val = input[index] - batch_max; output[data_idx] = ::exp(to_hip_type(val)) / batch_sum;
output[index] = ::exp(to_hip_type(val)) / batch_sum;
} }
}); });
}); });
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment