Commit ea932b63 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

add std namespace for size_t

parent 070d4904
...@@ -533,7 +533,7 @@ struct cpu_softmax ...@@ -533,7 +533,7 @@ struct cpu_softmax
{ {
argument result{output_shape}; argument result{output_shape};
auto batch_lens = output_shape.lens(); auto batch_lens = output_shape.lens();
size_t n_dims = batch_lens[op.axis]; std::size_t n_dims = batch_lens[op.axis];
batch_lens[op.axis] = 1; batch_lens[op.axis] = 1;
shape batch_shape{shape::int32_type, batch_lens}; shape batch_shape{shape::int32_type, batch_lens};
...@@ -544,26 +544,26 @@ struct cpu_softmax ...@@ -544,26 +544,26 @@ struct cpu_softmax
std::vector<value_type> batch_sum(batch_shape.elements(), value_type(0)); std::vector<value_type> batch_sum(batch_shape.elements(), value_type(0));
par_for(batch_shape.elements(), [&](auto i) { par_for(batch_shape.elements(), [&](auto i) {
auto idx = batch_shape.multi(i); auto idx = batch_shape.multi(i);
for(size_t j = 0; j < n_dims; ++j) for(std::size_t j = 0; j < n_dims; ++j)
{ {
idx[op.axis] = j; idx[op.axis] = j;
batch_max[i] = std::max(batch_max[i], input(idx.begin(), idx.end())); batch_max[i] = std::max(batch_max[i], input(idx.begin(), idx.end()));
} }
for(size_t j = 0; j < n_dims; ++j) for(std::size_t j = 0; j < n_dims; ++j)
{ {
idx[op.axis] = j; idx[op.axis] = j;
size_t index = output_shape.index(idx); std::size_t index = output_shape.index(idx);
output[index] = std::exp(input[index] - batch_max[i]); output[index] = std::exp(input[index] - batch_max[i]);
} }
for(size_t j = 0; j < n_dims; ++j) for(std::size_t j = 0; j < n_dims; ++j)
{ {
idx[op.axis] = j; idx[op.axis] = j;
batch_sum[i] += output(idx.begin(), idx.end()); batch_sum[i] += output(idx.begin(), idx.end());
} }
for(size_t j = 0; j < n_dims; ++j) for(std::size_t j = 0; j < n_dims; ++j)
{ {
idx[op.axis] = j; idx[op.axis] = j;
output(idx.begin(), idx.end()) /= batch_sum[i]; output(idx.begin(), idx.end()) /= batch_sum[i];
...@@ -591,7 +591,7 @@ struct cpu_logsoftmax ...@@ -591,7 +591,7 @@ struct cpu_logsoftmax
{ {
argument result{output_shape}; argument result{output_shape};
auto batch_lens = output_shape.lens(); auto batch_lens = output_shape.lens();
size_t n_dims = batch_lens[op.axis]; std::size_t n_dims = batch_lens[op.axis];
batch_lens[op.axis] = 1; batch_lens[op.axis] = 1;
shape batch_shape{shape::int32_type, batch_lens}; shape batch_shape{shape::int32_type, batch_lens};
...@@ -605,20 +605,20 @@ struct cpu_logsoftmax ...@@ -605,20 +605,20 @@ struct cpu_logsoftmax
par_for(batch_shape.elements(), [&](auto i) { par_for(batch_shape.elements(), [&](auto i) {
auto idx = batch_shape.multi(i); auto idx = batch_shape.multi(i);
for(size_t j = 0; j < n_dims; ++j) for(std::size_t j = 0; j < n_dims; ++j)
{ {
idx[op.axis] = j; idx[op.axis] = j;
batch_max[i] = std::max(batch_max[i], input(idx.begin(), idx.end())); batch_max[i] = std::max(batch_max[i], input(idx.begin(), idx.end()));
} }
for(size_t j = 0; j < n_dims; ++j) for(std::size_t j = 0; j < n_dims; ++j)
{ {
idx[op.axis] = j; idx[op.axis] = j;
size_t index = output_shape.index(idx); std::size_t index = output_shape.index(idx);
output[index] = input[index] - batch_max[i]; output[index] = input[index] - batch_max[i];
} }
for(size_t j = 0; j < n_dims; ++j) for(std::size_t j = 0; j < n_dims; ++j)
{ {
idx[op.axis] = j; idx[op.axis] = j;
batch_sum[i] += std::exp(output(idx.begin(), idx.end())); batch_sum[i] += std::exp(output(idx.begin(), idx.end()));
...@@ -626,7 +626,7 @@ struct cpu_logsoftmax ...@@ -626,7 +626,7 @@ struct cpu_logsoftmax
batch_sum[i] = std::log(batch_sum[i]); batch_sum[i] = std::log(batch_sum[i]);
for(size_t j = 0; j < n_dims; ++j) for(std::size_t j = 0; j < n_dims; ++j)
{ {
idx[op.axis] = j; idx[op.axis] = j;
output(idx.begin(), idx.end()) -= batch_sum[i]; output(idx.begin(), idx.end()) -= batch_sum[i];
......
...@@ -23,26 +23,26 @@ void logsoftmax(hipStream_t stream, const argument& result, const argument& arg, ...@@ -23,26 +23,26 @@ void logsoftmax(hipStream_t stream, const argument& result, const argument& arg,
hip_visit_all(result, arg, batch_shape)([&](auto output, auto input, auto batch) { hip_visit_all(result, arg, batch_shape)([&](auto output, auto input, auto batch) {
// use one block for items in one batch. // use one block for items in one batch.
const size_t max_block_size = 1024; const std::size_t max_block_size = 1024;
size_t block_size = 1; std::size_t block_size = 1;
while(block_size < max_block_size and block_size < batch_item_num) while(block_size < max_block_size and block_size < batch_item_num)
{ {
block_size *= 2; block_size *= 2;
} }
launch(stream, batch_shape.elements() * block_size, block_size)([=](auto idx) __device__ { launch(stream, batch_shape.elements() * block_size, block_size)([=](auto idx) __device__ {
size_t thr_idx = idx.local; std::size_t thr_idx = idx.local;
size_t blk_idx = idx.group; std::size_t blk_idx = idx.group;
using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>; using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
MIGRAPHX_DEVICE_SHARED type lds_data[max_block_size + 1]; MIGRAPHX_DEVICE_SHARED type lds_data[max_block_size + 1];
auto batch_idx = batch.multi(blk_idx); auto batch_idx = batch.multi(blk_idx);
auto data_idx = batch_idx; auto data_idx = batch_idx;
// load data to lds and compute the batch max // load data to lds and compute the batch max
size_t remaining_item_num = batch_item_num; std::size_t remaining_item_num = batch_item_num;
size_t round_item_num = (batch_item_num + block_size - 1) / block_size * block_size; std::size_t round_item_num = (batch_item_num + block_size - 1) / block_size * block_size;
lds_data[max_block_size] = input[0]; lds_data[max_block_size] = input[0];
for(size_t i = thr_idx; i < round_item_num; i += block_size) for(std::size_t i = thr_idx; i < round_item_num; i += block_size)
{ {
if(i < batch_item_num) if(i < batch_item_num)
{ {
...@@ -62,7 +62,7 @@ void logsoftmax(hipStream_t stream, const argument& result, const argument& arg, ...@@ -62,7 +62,7 @@ void logsoftmax(hipStream_t stream, const argument& result, const argument& arg,
lds_data[max_block_size] = 0; lds_data[max_block_size] = 0;
remaining_item_num = batch_item_num; remaining_item_num = batch_item_num;
for(size_t i = thr_idx; i < round_item_num; i += block_size) for(std::size_t i = thr_idx; i < round_item_num; i += block_size)
{ {
if(i < batch_item_num) if(i < batch_item_num)
{ {
...@@ -81,7 +81,7 @@ void logsoftmax(hipStream_t stream, const argument& result, const argument& arg, ...@@ -81,7 +81,7 @@ void logsoftmax(hipStream_t stream, const argument& result, const argument& arg,
auto log_batch_sum = ::log(to_hip_type(lds_data[max_block_size])) + batch_max; auto log_batch_sum = ::log(to_hip_type(lds_data[max_block_size])) + batch_max;
for(size_t i = thr_idx; i < batch_item_num; i += block_size) for(std::size_t i = thr_idx; i < batch_item_num; i += block_size)
{ {
data_idx[axis] = i; data_idx[axis] = i;
output[data_idx] = input[data_idx] - log_batch_sum; output[data_idx] = input[data_idx] - log_batch_sum;
......
...@@ -17,32 +17,32 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in ...@@ -17,32 +17,32 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in
{ {
auto lens = result.get_shape().lens(); auto lens = result.get_shape().lens();
auto batch_lens = lens; auto batch_lens = lens;
size_t batch_item_num = lens[axis]; std::size_t batch_item_num = lens[axis];
batch_lens[axis] = 1; batch_lens[axis] = 1;
migraphx::shape batch_shape{result.get_shape().type(), batch_lens}; migraphx::shape batch_shape{result.get_shape().type(), batch_lens};
hip_visit_all(result, arg, batch_shape)([&](auto output, auto input, auto batch) { hip_visit_all(result, arg, batch_shape)([&](auto output, auto input, auto batch) {
// use one block for items in one batch. // use one block for items in one batch.
const size_t max_block_size = 1024; const std::size_t max_block_size = 1024;
size_t block_size = 1; std::size_t block_size = 1;
while(block_size < max_block_size and block_size < batch_item_num) while(block_size < max_block_size and block_size < batch_item_num)
{ {
block_size *= 2; block_size *= 2;
} }
launch(stream, batch_shape.elements() * block_size, block_size)([=](auto idx) __device__ { launch(stream, batch_shape.elements() * block_size, block_size)([=](auto idx) __device__ {
size_t thr_idx = idx.local; std::size_t thr_idx = idx.local;
size_t blk_idx = idx.group; std::size_t blk_idx = idx.group;
using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>; using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
MIGRAPHX_DEVICE_SHARED type lds_data[max_block_size + 1]; MIGRAPHX_DEVICE_SHARED type lds_data[max_block_size + 1];
auto batch_idx = batch.multi(blk_idx); auto batch_idx = batch.multi(blk_idx);
auto data_idx = batch_idx; auto data_idx = batch_idx;
// load data to lds and compute the batch max // load data to lds and compute the batch max
size_t remaining_item_num = batch_item_num; std::size_t remaining_item_num = batch_item_num;
size_t round_item_num = (batch_item_num + block_size - 1) / block_size * block_size; std::size_t round_item_num = (batch_item_num + block_size - 1) / block_size * block_size;
lds_data[max_block_size] = input[0]; lds_data[max_block_size] = input[0];
for(size_t i = thr_idx; i < round_item_num; i += block_size) for(std::size_t i = thr_idx; i < round_item_num; i += block_size)
{ {
if(i < batch_item_num) if(i < batch_item_num)
{ {
...@@ -63,7 +63,7 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in ...@@ -63,7 +63,7 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in
lds_data[max_block_size] = 0; lds_data[max_block_size] = 0;
remaining_item_num = batch_item_num; remaining_item_num = batch_item_num;
for(size_t i = thr_idx; i < round_item_num; i += block_size) for(std::size_t i = thr_idx; i < round_item_num; i += block_size)
{ {
if(i < batch_item_num) if(i < batch_item_num)
{ {
...@@ -81,7 +81,7 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in ...@@ -81,7 +81,7 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in
} }
auto batch_sum = lds_data[max_block_size]; auto batch_sum = lds_data[max_block_size];
for(size_t i = thr_idx; i < batch_item_num; i += block_size) for(std::size_t i = thr_idx; i < batch_item_num; i += block_size)
{ {
data_idx[axis] = i; data_idx[axis] = i;
auto val = input[data_idx] - batch_max; auto val = input[data_idx] - batch_max;
......
...@@ -12,13 +12,13 @@ namespace device { ...@@ -12,13 +12,13 @@ namespace device {
template <class T> template <class T>
inline __device__ void inline __device__ void
reduce_max(T* data_ptr, size_t block_size, size_t thr_idx, size_t item_num, size_t max_index) reduce_max(T* data_ptr, std::size_t block_size, std::size_t thr_idx, std::size_t item_num, std::size_t max_index)
{ {
while(true) while(true)
{ {
auto stride = (item_num + 1) / 2; auto stride = (item_num + 1) / 2;
auto size = item_num / 2; auto size = item_num / 2;
for(size_t i = thr_idx; i < size; i += block_size) for(std::size_t i = thr_idx; i < size; i += block_size)
{ {
data_ptr[i] = ::max(to_hip_type(data_ptr[i]), to_hip_type(data_ptr[i + stride])); data_ptr[i] = ::max(to_hip_type(data_ptr[i]), to_hip_type(data_ptr[i + stride]));
} }
...@@ -40,13 +40,13 @@ reduce_max(T* data_ptr, size_t block_size, size_t thr_idx, size_t item_num, size ...@@ -40,13 +40,13 @@ reduce_max(T* data_ptr, size_t block_size, size_t thr_idx, size_t item_num, size
template <class T> template <class T>
inline __device__ void inline __device__ void
reduce_min(T* data_ptr, size_t block_size, size_t thr_idx, size_t item_num, size_t min_index) reduce_min(T* data_ptr, std::size_t block_size, std::size_t thr_idx, std::size_t item_num, std::size_t min_index)
{ {
while(true) while(true)
{ {
auto stride = (item_num + 1) / 2; auto stride = (item_num + 1) / 2;
auto size = item_num / 2; auto size = item_num / 2;
for(size_t i = thr_idx; i < size; i += block_size) for(std::size_t i = thr_idx; i < size; i += block_size)
{ {
data_ptr[i] = ::min(to_hip_type(data_ptr[i]), to_hip_type(data_ptr[i + stride])); data_ptr[i] = ::min(to_hip_type(data_ptr[i]), to_hip_type(data_ptr[i + stride]));
} }
...@@ -68,13 +68,13 @@ reduce_min(T* data_ptr, size_t block_size, size_t thr_idx, size_t item_num, size ...@@ -68,13 +68,13 @@ reduce_min(T* data_ptr, size_t block_size, size_t thr_idx, size_t item_num, size
template <class T> template <class T>
inline __device__ void inline __device__ void
reduce_sum(T* data_ptr, size_t block_size, size_t thr_idx, size_t item_num, size_t sum_index) reduce_sum(T* data_ptr, std::size_t block_size, std::size_t thr_idx, std::size_t item_num, std::size_t sum_index)
{ {
while(true) while(true)
{ {
auto stride = (item_num + 1) / 2; auto stride = (item_num + 1) / 2;
auto size = item_num / 2; auto size = item_num / 2;
for(size_t i = thr_idx; i < size; i += block_size) for(std::size_t i = thr_idx; i < size; i += block_size)
{ {
data_ptr[i] += data_ptr[i + stride]; data_ptr[i] += data_ptr[i + stride];
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment