Commit b8782a5f authored by Shucai Xiao's avatar Shucai Xiao
Browse files

further optimization of the softmax and logsoftmax operator.

parent 5384c7d7
...@@ -530,17 +530,26 @@ struct cpu_softmax ...@@ -530,17 +530,26 @@ struct cpu_softmax
std::string name() const { return "cpu::softmax"; } std::string name() const { return "cpu::softmax"; }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); } shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
template <typename T> std::vector<size_t> compute_batch_indices(size_t idx, const shape& s) const
std::size_t compute_batch_index(T idx, shape& batch_shape, int axis) const
{ {
idx[axis] = 0; std::vector<std::size_t> indices(s.lens().size());
return batch_shape.index(idx); std::transform(s.strides().begin(),
s.strides().end(),
s.lens().begin(),
indices.begin(),
[&](std::size_t stride, std::size_t len) {
assert(len > 0 and stride > 0);
return (idx / stride) % len;
});
return indices;
} }
argument compute(context&, const shape& output_shape, std::vector<argument> args) const argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
auto batch_lens = output_shape.lens(); auto batch_lens = output_shape.lens();
size_t n_dims = batch_lens[op.axis];
batch_lens[op.axis] = 1; batch_lens[op.axis] = 1;
shape batch_shape{shape::int32_type, batch_lens}; shape batch_shape{shape::int32_type, batch_lens};
...@@ -548,27 +557,35 @@ struct cpu_softmax ...@@ -548,27 +557,35 @@ struct cpu_softmax
using value_type = typename decltype(input)::value_type; using value_type = typename decltype(input)::value_type;
std::vector<value_type> batch_max(batch_shape.elements(), std::vector<value_type> batch_max(batch_shape.elements(),
std::numeric_limits<value_type>::lowest()); std::numeric_limits<value_type>::lowest());
shape_for_each(output_shape, [&](auto idx) { std::vector<value_type> batch_sum(batch_shape.elements(), value_type(0));
auto index = this->compute_batch_index(idx, batch_shape, op.axis); par_for(batch_shape.elements(), [&](auto i) {
batch_max[index] = std::max(batch_max[index], input(idx.begin(), idx.end())); auto idx = compute_batch_indices(i, batch_shape);
});
shape_for_each(output_shape, [&](auto idx) { for (size_t j = 0; j < n_dims; ++j)
auto index = this->compute_batch_index(idx, batch_shape, op.axis); {
output(idx.begin(), idx.end()) = idx[op.axis] = j;
std::exp(input(idx.begin(), idx.end()) - batch_max[index]); batch_max[i] = std::max(batch_max[i], input(idx.begin(), idx.end()));
}); }
std::vector<value_type> batch_sum(batch_shape.elements(), value_type(0)); for (size_t j = 0; j < n_dims; ++j)
shape_for_each(output_shape, [&](auto idx) { {
auto index = this->compute_batch_index(idx, batch_shape, op.axis); idx[op.axis] = j;
batch_sum[index] += output(idx.begin(), idx.end()); size_t index = output_shape.index(idx);
}); output[index] = std::exp(input[index] - batch_max[i]);
}
shape_for_each(output_shape, [&](auto idx) { for (size_t j = 0; j < n_dims; ++j)
auto index = this->compute_batch_index(idx, batch_shape, op.axis); {
output(idx.begin(), idx.end()) /= batch_sum[index]; idx[op.axis] = j;
}); batch_sum[i] += output(idx.begin(), idx.end());
}
for (size_t j = 0; j < n_dims; ++j)
{
idx[op.axis] = j;
output(idx.begin(), idx.end()) /= batch_sum[i];
}
});
}); });
return result; return result;
...@@ -588,48 +605,65 @@ struct cpu_logsoftmax ...@@ -588,48 +605,65 @@ struct cpu_logsoftmax
std::string name() const { return "cpu::logsoftmax"; } std::string name() const { return "cpu::logsoftmax"; }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); } shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
template <typename T> std::vector<size_t> compute_batch_indices(size_t idx, const shape& s) const
std::size_t compute_batch_index(T idx, const shape& batch_shape, int axis) const
{ {
idx[axis] = 0; std::vector<std::size_t> indices(s.lens().size());
return batch_shape.index(idx); std::transform(s.strides().begin(),
s.strides().end(),
s.lens().begin(),
indices.begin(),
[&](std::size_t stride, std::size_t len) {
assert(len > 0 and stride > 0);
return (idx / stride) % len;
});
return indices;
} }
argument compute(context&, const shape& output_shape, std::vector<argument> args) const argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
auto batch_lens = output_shape.lens(); auto batch_lens = output_shape.lens();
size_t n_dims = batch_lens[op.axis];
batch_lens[op.axis] = 1; batch_lens[op.axis] = 1;
shape batch_shape{shape::int32_type, batch_lens}; shape batch_shape{shape::int32_type, batch_lens};
// use a parallel implementation to acheive better performance
// one thread for one batch
visit_all(result, args[0])([&](auto output, auto input) { visit_all(result, args[0])([&](auto output, auto input) {
using value_type = typename decltype(input)::value_type; using value_type = typename decltype(input)::value_type;
std::vector<value_type> batch_max(batch_shape.elements(), std::vector<value_type> batch_max(batch_shape.elements(),
std::numeric_limits<value_type>::lowest()); std::numeric_limits<value_type>::lowest());
shape_for_each(output_shape, [&](auto idx) { std::vector<value_type> batch_sum(batch_shape.elements(), value_type(0));
auto index = this->compute_batch_index(idx, batch_shape, op.axis);
batch_max[index] = std::max(batch_max[index], input(idx.begin(), idx.end()));
});
shape_for_each(output_shape, [&](auto idx) { par_for(batch_shape.elements(), [&](auto i) {
auto index = this->compute_batch_index(idx, batch_shape, op.axis); auto idx = compute_batch_indices(i, batch_shape);
output(idx.begin(), idx.end()) = input(idx.begin(), idx.end()) - batch_max[index]; for (size_t j = 0; j < n_dims; ++j)
}); {
idx[op.axis] = j;
batch_max[i] = std::max(batch_max[i], input(idx.begin(), idx.end()));
}
std::vector<value_type> batch_sum(batch_shape.elements(), value_type(0)); for (size_t j = 0; j < n_dims; ++j)
shape_for_each(output_shape, [&](auto idx) { {
auto index = this->compute_batch_index(idx, batch_shape, op.axis); idx[op.axis] = j;
batch_sum[index] += std::exp(output(idx.begin(), idx.end())); size_t index = output_shape.index(idx);
}); output[index] = input[index] - batch_max[i];
}
for (size_t j = 0; j < n_dims; ++j)
{
idx[op.axis] = j;
batch_sum[i] += std::exp(output(idx.begin(), idx.end()));
}
for(std::size_t i = 0; i < batch_sum.size(); ++i)
{
batch_sum[i] = std::log(batch_sum[i]); batch_sum[i] = std::log(batch_sum[i]);
}
shape_for_each(output_shape, [&](auto idx) { for (size_t j = 0; j < n_dims; ++j)
auto index = this->compute_batch_index(idx, batch_shape, op.axis); {
output(idx.begin(), idx.end()) -= batch_sum[index]; idx[op.axis] = j;
output(idx.begin(), idx.end()) -= batch_sum[i];
}
}); });
}); });
......
...@@ -18,7 +18,7 @@ argument logsoftmax(hipStream_t stream, ...@@ -18,7 +18,7 @@ argument logsoftmax(hipStream_t stream,
{ {
auto lens = output_shape.lens(); auto lens = output_shape.lens();
auto num_in_batch = lens[axis]; auto n_dims = lens[axis];
auto batch_lens = lens; auto batch_lens = lens;
batch_lens[axis] = 1; batch_lens[axis] = 1;
migraphx::shape batch_shape{output_shape.type(), batch_lens}; migraphx::shape batch_shape{output_shape.type(), batch_lens};
...@@ -33,7 +33,13 @@ argument logsoftmax(hipStream_t stream, ...@@ -33,7 +33,13 @@ argument logsoftmax(hipStream_t stream,
// use one block for items in one batch. // use one block for items in one batch.
// opt 1, load all data to lds then use the same approach as // opt 1, load all data to lds then use the same approach as
// the current optimization // the current optimization
const size_t block_size = 1024; const size_t max_block_size = 1024;
size_t block_size = 1;
while (block_size < max_block_size and block_size < n_dim)
{
block_size *= 2;
}
launch( launch(
stream, batch_shape.elements() * block_size, block_size)([=](auto idx) __device__ { stream, batch_shape.elements() * block_size, block_size)([=](auto idx) __device__ {
size_t thr_idx = idx.local; size_t thr_idx = idx.local;
...@@ -42,17 +48,20 @@ argument logsoftmax(hipStream_t stream, ...@@ -42,17 +48,20 @@ argument logsoftmax(hipStream_t stream,
// all data can be loaded to the lds once, so all operations are // all data can be loaded to the lds once, so all operations are
// done in lds // done in lds
MIGRAPHX_DEVICE_SHARED type lds_data[block_size + 2]; MIGRAPHX_DEVICE_SHARED type lds_data[max_block_size + 2];
auto batch_idx = desc_batch.multi(blk_idx); auto batch_idx = desc_batch.multi(blk_idx);
auto data_idx = batch_idx; auto data_idx = batch_idx;
// load data to lds and compute the batch max // load data to lds and compute the batch max
size_t item_num = num_in_batch; size_t item_num = n_dims;
size_t thread_num = (n_dims + block_size - 1) / block_size * block_size;
lds_data[block_size] = input_ptr[0]; lds_data[block_size] = input_ptr[0];
for(size_t i = thr_idx; i < num_in_batch; i += block_size) for(size_t i = thr_idx; i < thread_num; i += block_size)
{ {
data_idx[axis] = i; if (i < n_dims)
lds_data[i] = input_ptr[desc_data.linear(data_idx)]; {
data_idx[axis] = i;
lds_data[thr_idx] = input_ptr[desc_data.linear(data_idx)];
}
__syncthreads(); __syncthreads();
auto size = (item_num > block_size) ? block_size : item_num; auto size = (item_num > block_size) ? block_size : item_num;
...@@ -85,13 +94,16 @@ argument logsoftmax(hipStream_t stream, ...@@ -85,13 +94,16 @@ argument logsoftmax(hipStream_t stream,
const size_t block_size1 = block_size + 1; const size_t block_size1 = block_size + 1;
lds_data[block_size1] = 0; lds_data[block_size1] = 0;
item_num = num_in_batch; item_num = n_dims;
for(size_t i = thr_idx; i < num_in_batch; i += block_size) for(size_t i = thr_idx; i < thread_num; i += block_size)
{ {
data_idx[axis] = i; if (i < n_dims)
lds_data[i] = input_ptr[desc_data.linear(data_idx)] - lds_data[block_size]; {
lds_data[i] = ::exp(to_hip_type(lds_data[i])); data_idx[axis] = i;
lds_data[thr_idx] = input_ptr[desc_data.linear(data_idx)] - lds_data[block_size];
lds_data[thr_idx] = ::exp(to_hip_type(lds_data[thr_idx]));
}
__syncthreads(); __syncthreads();
auto size = (item_num > block_size) ? block_size : item_num; auto size = (item_num > block_size) ? block_size : item_num;
...@@ -120,8 +132,7 @@ argument logsoftmax(hipStream_t stream, ...@@ -120,8 +132,7 @@ argument logsoftmax(hipStream_t stream,
auto log_batch_sum = auto log_batch_sum =
::log(to_hip_type(lds_data[block_size1])) + lds_data[block_size]; ::log(to_hip_type(lds_data[block_size1])) + lds_data[block_size];
item_num = num_in_batch; for(size_t i = thr_idx; i < n_dims; i += block_size)
for(size_t i = thr_idx; i < num_in_batch; i += block_size)
{ {
data_idx[axis] = i; data_idx[axis] = i;
size_t index = desc_data.linear(data_idx); size_t index = desc_data.linear(data_idx);
......
...@@ -21,7 +21,7 @@ argument softmax(hipStream_t stream, ...@@ -21,7 +21,7 @@ argument softmax(hipStream_t stream,
auto batch_lens = lens; auto batch_lens = lens;
size_t n_dims = lens[axis]; size_t n_dims = lens[axis];
batch_lens[axis] = 1; batch_lens[axis] = 1;
migraphx::shape batch_shape{shape::int32_type, batch_lens}; migraphx::shape batch_shape{output_shape.type(), batch_lens};
visit_all(args.back(), args.front())([&](auto output, auto input) { visit_all(args.back(), args.front())([&](auto output, auto input) {
const auto* input_ptr = device_cast(input.data()); const auto* input_ptr = device_cast(input.data());
...@@ -31,7 +31,13 @@ argument softmax(hipStream_t stream, ...@@ -31,7 +31,13 @@ argument softmax(hipStream_t stream,
hip_tensor_descriptor<n_dim> desc_data(output_shape); hip_tensor_descriptor<n_dim> desc_data(output_shape);
// use one block for items in one batch. // use one block for items in one batch.
const size_t block_size = 1024; const size_t max_block_size = 1024;
size_t block_size = 1;
while (block_size < max_block_size and block_size < n_dims)
{
block_size *= 2;
}
launch( launch(
stream, batch_shape.elements() * block_size, block_size)([=](auto idx) __device__ { stream, batch_shape.elements() * block_size, block_size)([=](auto idx) __device__ {
size_t thr_idx = idx.local; size_t thr_idx = idx.local;
...@@ -40,16 +46,21 @@ argument softmax(hipStream_t stream, ...@@ -40,16 +46,21 @@ argument softmax(hipStream_t stream,
// all data can be loaded to the lds once, so all operations are // all data can be loaded to the lds once, so all operations are
// done in lds // done in lds
MIGRAPHX_DEVICE_SHARED type lds_data[block_size + 2]; MIGRAPHX_DEVICE_SHARED type lds_data[max_block_size + 2];
auto batch_idx = desc_batch.multi(blk_idx); auto batch_idx = desc_batch.multi(blk_idx);
auto data_idx = batch_idx; auto data_idx = batch_idx;
// load data to lds and compute the batch max // load data to lds and compute the batch max
size_t item_num = n_dims; size_t item_num = n_dims;
size_t thread_num = (n_dims + block_size - 1) / block_size * block_size;
lds_data[block_size] = input_ptr[0]; lds_data[block_size] = input_ptr[0];
for(size_t i = thr_idx; i < n_dims; i += block_size) lds_data[block_size + 1] = 0;
for(size_t i = thr_idx; i < thread_num; i += block_size)
{ {
data_idx[axis] = i; if (i < n_dims)
lds_data[i] = input_ptr[desc_data.linear(data_idx)]; {
data_idx[axis] = i;
lds_data[thr_idx] = input_ptr[desc_data.linear(data_idx)];
}
__syncthreads(); __syncthreads();
...@@ -81,14 +92,15 @@ argument softmax(hipStream_t stream, ...@@ -81,14 +92,15 @@ argument softmax(hipStream_t stream,
item_num -= block_size; item_num -= block_size;
} }
const size_t block_size1 = block_size + 1;
lds_data[block_size1] = 0;
item_num = n_dims; item_num = n_dims;
for(size_t i = thr_idx; i < n_dims; i += block_size) for(size_t i = thr_idx; i < thread_num; i += block_size)
{ {
data_idx[axis] = i; if (i < n_dims)
lds_data[i] = input_ptr[desc_data.linear(data_idx)] - lds_data[block_size]; {
lds_data[i] = ::exp(to_hip_type(lds_data[i])); data_idx[axis] = i;
lds_data[thr_idx] = input_ptr[desc_data.linear(data_idx)] - lds_data[block_size];
lds_data[thr_idx] = ::exp(to_hip_type(lds_data[thr_idx]));
}
__syncthreads(); __syncthreads();
...@@ -109,7 +121,7 @@ argument softmax(hipStream_t stream, ...@@ -109,7 +121,7 @@ argument softmax(hipStream_t stream,
if(thr_idx == 0) if(thr_idx == 0)
{ {
lds_data[block_size1] += lds_data[0]; lds_data[block_size + 1] += lds_data[0];
} }
__syncthreads(); __syncthreads();
...@@ -121,7 +133,7 @@ argument softmax(hipStream_t stream, ...@@ -121,7 +133,7 @@ argument softmax(hipStream_t stream,
data_idx[axis] = i; data_idx[axis] = i;
size_t index = desc_data.linear(data_idx); size_t index = desc_data.linear(data_idx);
auto val = input_ptr[index] - lds_data[block_size]; auto val = input_ptr[index] - lds_data[block_size];
output_ptr[index] = ::exp(to_hip_type(val)) / lds_data[block_size1]; output_ptr[index] = ::exp(to_hip_type(val)) / lds_data[block_size + 1];
} }
}); });
}); });
......
...@@ -598,7 +598,7 @@ struct test_softmax : verify_program<test_softmax<Axis>> ...@@ -598,7 +598,7 @@ struct test_softmax : verify_program<test_softmax<Axis>>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 4, 5, 6}}; migraphx::shape s{migraphx::shape::float_type, {2080, 4, 1026, 6}};
auto param = p.add_parameter("0", s); auto param = p.add_parameter("0", s);
p.add_instruction(migraphx::op::softmax{Axis}, param); p.add_instruction(migraphx::op::softmax{Axis}, param);
...@@ -3350,7 +3350,7 @@ struct test_logsoftmax : verify_program<test_logsoftmax<Axis>> ...@@ -3350,7 +3350,7 @@ struct test_logsoftmax : verify_program<test_logsoftmax<Axis>>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 4, 5, 6}}; migraphx::shape s{migraphx::shape::float_type, {1025, 4, 1025, 6}};
auto param = p.add_parameter("0", s); auto param = p.add_parameter("0", s);
p.add_instruction(migraphx::op::logsoftmax{Axis}, param); p.add_instruction(migraphx::op::logsoftmax{Axis}, param);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment