Unverified commit 51f264a6, authored by mvermeulen, committed by GitHub

Merge pull request #293 from ROCmSoftwarePlatform/opt_log_softmax

Opt log softmax

Parents: 5af04e80 c98b06a1
@@ -99,6 +99,8 @@ struct shape
     /// Map element index to space index
     std::size_t index(std::size_t i) const;
+    /// Map element index to multi-dimensional index
+    std::vector<std::size_t> multi(std::size_t i) const;
     /// Returns true if the shape is packed with no padding
     bool packed() const;
     /// Returns true if the shape has been transposed. That is, the strides are not in descending
......
@@ -138,6 +138,24 @@ std::size_t shape::index(std::size_t i) const
         return result;
     }
 }

+std::vector<std::size_t> shape::multi(std::size_t i) const
+{
+    assert(this->standard());
+    std::vector<std::size_t> indices(lens().size());
+    std::transform(strides().begin(),
+                   strides().end(),
+                   lens().begin(),
+                   indices.begin(),
+                   [&](std::size_t stride, std::size_t len) {
+                       assert(len > 0 and stride > 0);
+                       return (i / stride) % len;
+                   });
+    return indices;
+}
+
 bool shape::packed() const { return this->elements() == this->element_space(); }

 bool shape::transposed() const
......
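The new shape::multi member inverts shape::index for standard (packed, row-major) shapes: dividing the flat offset by each dimension's stride and wrapping by its length recovers the per-dimension index. A standalone sketch of the same arithmetic, with illustrative names that are not part of MIGraphX:

#include <cassert>
#include <cstddef>
#include <vector>

// Convert a flat element offset into a multi-dimensional index for a
// packed row-major shape: idx[d] = (i / stride[d]) % len[d].
std::vector<std::size_t> multi_index(const std::vector<std::size_t>& lens,
                                     const std::vector<std::size_t>& strides,
                                     std::size_t i)
{
    std::vector<std::size_t> idx(lens.size());
    for(std::size_t d = 0; d < lens.size(); ++d)
        idx[d] = (i / strides[d]) % lens[d];
    return idx;
}

int main()
{
    // lens {2, 3, 4} has row-major strides {12, 4, 1}; offset 17 -> {1, 1, 1}
    auto idx = multi_index({2, 3, 4}, {12, 4, 1}, 17);
    assert(idx == (std::vector<std::size_t>{1, 1, 1}));
    return 0;
}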
@@ -539,18 +539,11 @@ struct cpu_softmax
     std::string name() const { return "cpu::softmax"; }
     shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }

-    template <typename T>
-    std::size_t compute_batch_index(T idx, shape& batch_shape, int axis) const
-    {
-        idx[axis] = 0;
-        return batch_shape.index(idx);
-    }
-
     argument compute(context&, const shape& output_shape, std::vector<argument> args) const
     {
         argument result{output_shape};
         auto batch_lens     = output_shape.lens();
+        std::size_t n_dims  = batch_lens[op.axis];
         batch_lens[op.axis] = 1;
         shape batch_shape{shape::int32_type, batch_lens};
@@ -558,26 +551,33 @@ struct cpu_softmax
             using value_type = typename decltype(input)::value_type;
             std::vector<value_type> batch_max(batch_shape.elements(),
                                               std::numeric_limits<value_type>::lowest());
-            shape_for_each(output_shape, [&](auto idx) {
-                auto index       = this->compute_batch_index(idx, batch_shape, op.axis);
-                batch_max[index] = std::max(batch_max[index], input(idx.begin(), idx.end()));
-            });
-
-            shape_for_each(output_shape, [&](auto idx) {
-                auto index = this->compute_batch_index(idx, batch_shape, op.axis);
-                output(idx.begin(), idx.end()) =
-                    std::exp(input(idx.begin(), idx.end()) - batch_max[index]);
-            });
-
-            std::vector<value_type> batch_sum(batch_shape.elements(), value_type(0));
-            shape_for_each(output_shape, [&](auto idx) {
-                auto index = this->compute_batch_index(idx, batch_shape, op.axis);
-                batch_sum[index] += output(idx.begin(), idx.end());
-            });
-
-            shape_for_each(output_shape, [&](auto idx) {
-                auto index = this->compute_batch_index(idx, batch_shape, op.axis);
-                output(idx.begin(), idx.end()) /= batch_sum[index];
-            });
+            std::vector<value_type> batch_sum(batch_shape.elements(), value_type(0));
+            // process each batch item in parallel: max, exponentiate, sum, normalize
+            par_for(batch_shape.elements(), [&](auto i) {
+                auto idx = batch_shape.multi(i);
+                for(std::size_t j = 0; j < n_dims; ++j)
+                {
+                    idx[op.axis] = j;
+                    batch_max[i] = std::max(batch_max[i], input(idx.begin(), idx.end()));
+                }
+                for(std::size_t j = 0; j < n_dims; ++j)
+                {
+                    idx[op.axis]      = j;
+                    std::size_t index = output_shape.index(idx);
+                    output[index]     = std::exp(input[index] - batch_max[i]);
+                }
+                for(std::size_t j = 0; j < n_dims; ++j)
+                {
+                    idx[op.axis] = j;
+                    batch_sum[i] += output(idx.begin(), idx.end());
+                }
+                for(std::size_t j = 0; j < n_dims; ++j)
+                {
+                    idx[op.axis] = j;
+                    output(idx.begin(), idx.end()) /= batch_sum[i];
+                }
+            });
         });
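The rewritten loop is the standard numerically stable softmax: subtracting the per-batch max before exponentiating keeps every exponent at or below zero, so std::exp cannot overflow, and the extra factor exp(-max) cancels when dividing by the sum. A minimal single-row sketch of the same computation (illustrative, not MIGraphX API):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// softmax(x)_j = exp(x_j - m) / sum_k exp(x_k - m), with m = max_k x_k.
std::vector<float> stable_softmax(const std::vector<float>& x)
{
    float m = *std::max_element(x.begin(), x.end());
    std::vector<float> y(x.size());
    float sum = 0.0f;
    for(std::size_t j = 0; j < x.size(); ++j)
    {
        y[j] = std::exp(x[j] - m); // exponent <= 0, no overflow
        sum += y[j];
    }
    for(auto& v : y)
        v /= sum; // normalize so the row sums to 1
    return y;
}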
@@ -597,49 +597,50 @@ struct cpu_logsoftmax
     std::string name() const { return "cpu::logsoftmax"; }
     shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }

-    template <typename T>
-    std::size_t compute_batch_index(T idx, const shape& batch_shape, int axis) const
-    {
-        idx[axis] = 0;
-        return batch_shape.index(idx);
-    }
-
     argument compute(context&, const shape& output_shape, std::vector<argument> args) const
     {
         argument result{output_shape};
         auto batch_lens     = output_shape.lens();
+        std::size_t n_dims  = batch_lens[op.axis];
         batch_lens[op.axis] = 1;
         shape batch_shape{shape::int32_type, batch_lens};

+        // use a parallel implementation to achieve better performance:
+        // one thread for one batch
         visit_all(result, args[0])([&](auto output, auto input) {
             using value_type = typename decltype(input)::value_type;
             std::vector<value_type> batch_max(batch_shape.elements(),
                                               std::numeric_limits<value_type>::lowest());
-            shape_for_each(output_shape, [&](auto idx) {
-                auto index       = this->compute_batch_index(idx, batch_shape, op.axis);
-                batch_max[index] = std::max(batch_max[index], input(idx.begin(), idx.end()));
-            });
-
-            shape_for_each(output_shape, [&](auto idx) {
-                auto index = this->compute_batch_index(idx, batch_shape, op.axis);
-                output(idx.begin(), idx.end()) = input(idx.begin(), idx.end()) - batch_max[index];
-            });
-
-            std::vector<value_type> batch_sum(batch_shape.elements(), value_type(0));
-            shape_for_each(output_shape, [&](auto idx) {
-                auto index = this->compute_batch_index(idx, batch_shape, op.axis);
-                batch_sum[index] += std::exp(output(idx.begin(), idx.end()));
-            });
-
-            for(std::size_t i = 0; i < batch_sum.size(); ++i)
-            {
-                batch_sum[i] = std::log(batch_sum[i]);
-            }
-
-            shape_for_each(output_shape, [&](auto idx) {
-                auto index = this->compute_batch_index(idx, batch_shape, op.axis);
-                output(idx.begin(), idx.end()) -= batch_sum[index];
-            });
+            std::vector<value_type> batch_sum(batch_shape.elements(), value_type(0));
+            par_for(batch_shape.elements(), [&](auto i) {
+                auto idx = batch_shape.multi(i);
+                for(std::size_t j = 0; j < n_dims; ++j)
+                {
+                    idx[op.axis] = j;
+                    batch_max[i] = std::max(batch_max[i], input(idx.begin(), idx.end()));
+                }
+                for(std::size_t j = 0; j < n_dims; ++j)
+                {
+                    idx[op.axis]      = j;
+                    std::size_t index = output_shape.index(idx);
+                    output[index]     = input[index] - batch_max[i];
+                }
+                for(std::size_t j = 0; j < n_dims; ++j)
+                {
+                    idx[op.axis] = j;
+                    batch_sum[i] += std::exp(output(idx.begin(), idx.end()));
+                }
+                batch_sum[i] = std::log(batch_sum[i]);
+                for(std::size_t j = 0; j < n_dims; ++j)
+                {
+                    idx[op.axis] = j;
+                    output(idx.begin(), idx.end()) -= batch_sum[i];
+                }
+            });
         });
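The logsoftmax version leans on the shifted log-sum-exp identity, logsoftmax(x)_j = (x_j - m) - log(sum_k exp(x_k - m)) with m = max_k x_k, which is why the loop can store the shifted values first and subtract a single logarithm at the end. A single-row reference sketch (illustrative, not MIGraphX API):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// logsoftmax(x)_j = (x_j - m) - log(sum_k exp(x_k - m)), m = max_k x_k;
// shifting by m keeps the exponentials from overflowing.
std::vector<float> log_softmax_ref(const std::vector<float>& x)
{
    float m = *std::max_element(x.begin(), x.end());
    std::vector<float> y(x.size());
    float sum = 0.0f;
    for(std::size_t j = 0; j < x.size(); ++j)
    {
        y[j] = x[j] - m;       // store the shifted values first
        sum += std::exp(y[j]); // accumulate exp of the shifted values
    }
    float log_sum = std::log(sum);
    for(auto& v : y)
        v -= log_sum; // subtract log of the sum once at the end
    return y;
}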
......
@@ -73,7 +73,7 @@ __host__ __device__ auto gs_invoke(F&& f, std::size_t i, index) -> decltype(f(i))

 inline auto gs_launch(hipStream_t stream, std::size_t n, std::size_t local = 1024)
 {
-    std::size_t groups  = 1 + n / local;
+    std::size_t groups  = (n + local - 1) / local;
     std::size_t nglobal = std::min<std::size_t>(256, groups) * local;

     return [=](auto f) {
......
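The one-line gs_launch fix replaces 1 + n / local with true ceiling division: the old expression launched a whole extra, entirely idle group whenever n was an exact multiple of local. Two boundary cases, checked at compile time under the same assumptions:

#include <cstddef>

// groups = ceil(n / local), computed in integer arithmetic.
constexpr std::size_t ceil_div(std::size_t n, std::size_t local)
{
    return (n + local - 1) / local;
}

static_assert(ceil_div(2048, 1024) == 2, "exact multiple: no idle extra group");
static_assert(ceil_div(2049, 1024) == 3, "one element past the boundary adds a group");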
...@@ -119,13 +119,13 @@ tensor_view<device_type<T>> device_cast(tensor_view<T> x) ...@@ -119,13 +119,13 @@ tensor_view<device_type<T>> device_cast(tensor_view<T> x)
} }
template <class T> template <class T>
T to_hip_type(T x) __device__ __host__ T to_hip_type(T x)
{ {
return x; return x;
} }
// Hip doens't support __fp16 // Hip doens't support __fp16
inline float to_hip_type(gpu_half x) { return x; } inline __device__ __host__ float to_hip_type(gpu_half x) { return x; }
} // namespace device } // namespace device
} // namespace gpu } // namespace gpu
......
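Marking to_hip_type as __device__ __host__ is what makes it callable from the new kernels: in HIP, as in CUDA, a function compiled without __device__ exists only on the host and cannot be referenced from device code. A hypothetical kernel (not from this diff, and assuming the surrounding MIGraphX device headers) showing the kind of call the annotation enables:

// Hypothetical: copies half-precision data out as float on the device.
// Compiles only because to_hip_type now has a __device__ overload.
__global__ void half_to_float(float* out, const gpu_half* in, std::size_t n)
{
    std::size_t i = blockIdx.x * blockDim.x + threadIdx.x;
    if(i < n)
        out[i] = to_hip_type(in[i]); // gpu_half -> float promotion
}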
 #include <migraphx/shape.hpp>
 #include <migraphx/argument.hpp>
 #include <migraphx/gpu/device/logsoftmax.hpp>
+#include <migraphx/gpu/device/reduce.hpp>
 #include <migraphx/gpu/device/tensor.hpp>
 #include <migraphx/gpu/device/launch.hpp>
 #include <migraphx/gpu/device/types.hpp>
@@ -11,53 +12,45 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {

-argument logsoftmax(hipStream_t stream, argument result, argument arg, int axis)
+void logsoftmax(hipStream_t stream, const argument& result, const argument& arg, int axis)
 {
-    auto lens         = result.get_shape().lens();
-    auto num_in_batch = lens[axis];
-    auto batch_lens   = lens;
-    batch_lens[axis]  = 1;
-    shape batch_shape{result.get_shape().type(), batch_lens};
+    auto lens                  = result.get_shape().lens();
+    auto batch_lens            = lens;
+    std::size_t batch_item_num = lens[axis];
+    batch_lens[axis]           = 1;
+    migraphx::shape batch_shape{result.get_shape().type(), batch_lens};

     hip_visit_all(result, arg, batch_shape)([&](auto output, auto input, auto batch) {
-        // each thread is for one item in the batch
-        gs_launch(stream, batch_shape.elements())([=](auto i) {
-            auto batch_idx = batch.multi(i);
-            auto data_idx  = batch_idx;
-
-            // get max
-            auto batch_max = input[batch_idx];
-            for(std::size_t j = 1; j < num_in_batch; ++j)
-            {
-                data_idx[axis] = j;
-                batch_max = std::max(to_hip_type(batch_max), to_hip_type(input[data_idx]));
-            }
-
-            for(std::size_t j = 0; j < num_in_batch; ++j)
-            {
-                data_idx[axis]   = j;
-                output[data_idx] = input[data_idx] - batch_max;
-            }
-
-            auto batch_sum = ::exp(to_hip_type(output[batch_idx]));
-            for(std::size_t j = 1; j < num_in_batch; ++j)
-            {
-                data_idx[axis] = j;
-                batch_sum += ::exp(to_hip_type(output[data_idx]));
-            }
-            batch_sum = ::log(to_hip_type(batch_sum));
-
-            for(std::size_t j = 0; j < num_in_batch; ++j)
-            {
-                data_idx[axis] = j;
-                output[data_idx] -= batch_sum;
-            }
-        });
+        const std::size_t max_block_size = 256;
+        const std::size_t block_size     = compute_block_size(batch_item_num, max_block_size);
+        gs_launch(stream,
+                  batch_shape.elements() * block_size,
+                  block_size)([=](auto i, auto idx) __device__ {
+            auto data_idx = batch.multi(i / block_size);
+            using type = device_type<std::remove_cv_t<typename decltype(input)::value_type>>;
+            type init  = lowest();
+
+            auto batch_max = block_reduce<max_block_size>(
+                idx, max{}, init, batch_item_num, [&](auto j) __device__ {
+                    data_idx[axis] = j;
+                    return input[data_idx];
+                });
+
+            auto batch_sum =
+                block_reduce<max_block_size>(idx, sum{}, 0, batch_item_num, [&](auto j) __device__ {
+                    data_idx[axis] = j;
+                    auto val       = input[data_idx] - batch_max;
+                    return ::exp(to_hip_type(val));
+                });
+
+            auto log_batch_sum = ::log(to_hip_type(batch_sum)) + batch_max;
+
+            idx.local_stride(batch_item_num, [&](auto j) {
+                data_idx[axis]   = j;
+                output[data_idx] = input[data_idx] - log_batch_sum;
+            });
+        });
     });
-    return result;
 }

 } // namespace device
......
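block_reduce, pulled in through the new reduce.hpp include, lets all the threads of a block cooperate on the max and the exponential sum instead of one thread walking the axis serially. Its implementation is not part of this diff; the sketch below is a generic shared-memory tree reduction showing the pattern such a primitive typically follows, with every name hypothetical:

// Generic HIP/CUDA-style block reduction sketch (hypothetical, not the
// MIGraphX implementation). Assumes the block has exactly N threads and
// N is a power of two; op is associative; f(j) yields the j-th element.
template <std::size_t N, class Op, class T, class F>
__device__ T block_reduce_sketch(std::size_t tid, Op op, T init, std::size_t n, F f)
{
    __shared__ T buffer[N];
    T acc = init;
    for(std::size_t j = tid; j < n; j += N) // each thread folds a strided slice
        acc = op(acc, f(j));
    buffer[tid] = acc;
    __syncthreads();
    for(std::size_t s = N / 2; s > 0; s /= 2) // log2(N) halving steps
    {
        if(tid < s)
            buffer[tid] = op(buffer[tid], buffer[tid + s]);
        __syncthreads();
    }
    return buffer[0]; // the reduced value, visible to every thread
}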
@@ -2,6 +2,7 @@
 #include <migraphx/argument.hpp>
 #include <migraphx/dfor.hpp>
 #include <migraphx/gpu/device/softmax.hpp>
+#include <migraphx/gpu/device/reduce.hpp>
 #include <migraphx/gpu/device/tensor.hpp>
 #include <migraphx/gpu/device/launch.hpp>
 #include <migraphx/gpu/device/types.hpp>
@@ -12,51 +13,44 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {

-argument softmax(hipStream_t stream, argument result, argument arg, int axis)
+void softmax(hipStream_t stream, const argument& result, const argument& arg, int axis)
 {
-    auto lens        = result.get_shape().lens();
-    auto batch_lens  = lens;
-    size_t n_dims    = lens[axis];
-    batch_lens[axis] = 1;
-    shape batch_shape{result.get_shape().type(), batch_lens};
+    auto lens                  = result.get_shape().lens();
+    auto batch_lens            = lens;
+    std::size_t batch_item_num = lens[axis];
+    batch_lens[axis]           = 1;
+    migraphx::shape batch_shape{result.get_shape().type(), batch_lens};

     hip_visit_all(result, arg, batch_shape)([&](auto output, auto input, auto batch) {
-        // each thread is for one item in the batch
-        gs_launch(stream, batch_shape.elements())([=](auto i) {
-            auto batch_idx = batch.multi(i);
-            auto data_idx  = batch_idx;
-
-            // get max
-            auto batch_max = input[batch_idx];
-            for(std::size_t j = 1; j < n_dims; ++j)
-            {
-                data_idx[axis] = j;
-                batch_max = std::max(to_hip_type(batch_max), to_hip_type(input[data_idx]));
-            }
-
-            for(std::size_t j = 0; j < n_dims; ++j)
-            {
-                data_idx[axis]   = j;
-                output[data_idx] = exp(to_hip_type(input[data_idx] - batch_max));
-            }
-
-            auto batch_sum = output[batch_idx];
-            for(std::size_t j = 1; j < n_dims; ++j)
-            {
-                data_idx[axis] = j;
-                batch_sum += output[data_idx];
-            }
-
-            for(std::size_t j = 0; j < n_dims; ++j)
-            {
-                data_idx[axis] = j;
-                output[data_idx] = output[data_idx] / batch_sum;
-            }
-        });
+        const std::size_t max_block_size = 256;
+        const std::size_t block_size     = compute_block_size(batch_item_num, max_block_size);
+        gs_launch(stream,
+                  batch_shape.elements() * block_size,
+                  block_size)([=](auto i, auto idx) __device__ {
+            auto data_idx = batch.multi(i / block_size);
+            using type = device_type<std::remove_cv_t<typename decltype(input)::value_type>>;
+            type init  = lowest();
+
+            auto batch_max = block_reduce<max_block_size>(
+                idx, max{}, init, batch_item_num, [&](auto j) __device__ {
+                    data_idx[axis] = j;
+                    return input[data_idx];
+                });
+
+            auto batch_sum =
+                block_reduce<max_block_size>(idx, sum{}, 0, batch_item_num, [&](auto j) __device__ {
+                    data_idx[axis] = j;
+                    auto val       = input[data_idx] - batch_max;
+                    return ::exp(to_hip_type(val));
+                });
+
+            idx.local_stride(batch_item_num, [&](auto j) {
+                data_idx[axis]   = j;
+                auto val         = input[data_idx] - batch_max;
+                output[data_idx] = ::exp(to_hip_type(val)) / batch_sum;
+            });
+        });
     });
-    return result;
 }

 } // namespace device
......
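Both device kernels now share the same launch geometry: gs_launch is asked for batch_shape.elements() * block_size work items, so every batch item owns a full block of up to max_block_size (256) threads; i / block_size recovers the batch index, the two block_reduce calls produce the max and the exponential sum cooperatively, and idx.local_stride spreads the final elementwise write across the block's threads. The old scheme of one thread per batch item serialized the entire reduction axis on a single thread, which is what this pull request optimizes away.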
@@ -10,7 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {

-argument logsoftmax(hipStream_t stream, argument result, argument arg, int axis);
+void logsoftmax(hipStream_t stream, const argument& result, const argument& arg, int axis);

 } // namespace device
 } // namespace gpu
......
@@ -10,7 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {

-argument softmax(hipStream_t stream, argument result, argument arg, int axis);
+void softmax(hipStream_t stream, const argument& result, const argument& arg, int axis);

 } // namespace device
 } // namespace gpu
......
@@ -18,7 +18,8 @@ shape hip_logsoftmax::compute_shape(const std::vector<shape>& inputs) const
 argument
 hip_logsoftmax::compute(context& ctx, const shape&, const std::vector<argument>& args) const
 {
-    return device::logsoftmax(ctx.get_stream().get(), args[1], args[0], op.axis);
+    device::logsoftmax(ctx.get_stream().get(), args.back(), args.front(), op.axis);
+    return args.back();
 }

 } // namespace gpu
......
@@ -39,7 +39,8 @@ shape hip_softmax::compute_shape(const std::vector<shape>& inputs) const
 argument hip_softmax::compute(context& ctx, const shape&, const std::vector<argument>& args) const
 {
-    return device::softmax(ctx.get_stream().get(), args[1], args[0], op.axis);
+    device::softmax(ctx.get_stream().get(), args.back(), args.front(), op.axis);
+    return args.back();
 }

 } // namespace gpu
......
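With the device functions returning void, each compute method now writes into the preallocated output buffer and returns it directly: args.back() is the output argument the framework appends after the inputs, and args.front() is the input tensor, matching the old args[1]/args[0] without hard-coding the positions.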
@@ -592,13 +592,13 @@ struct test_softmax2 : verify_program<test_softmax2>
     }
 };

-template <int Axis>
-struct test_softmax : verify_program<test_softmax<Axis>>
+template <int Axis, migraphx::shape::type_t T>
+struct test_softmax : verify_program<test_softmax<Axis, T>>
 {
     migraphx::program create_program() const
     {
         migraphx::program p;
-        migraphx::shape s{migraphx::shape::float_type, {3, 4, 5, 6}};
+        migraphx::shape s{T, {512, 4, 1067, 6}};
         auto param = p.add_parameter("0", s);
         p.add_instruction(migraphx::op::softmax{Axis}, param);
@@ -606,10 +606,14 @@ struct test_softmax : verify_program<test_softmax<Axis, T>>
     }
 };

-template struct test_softmax<0>;
-template struct test_softmax<1>;
-template struct test_softmax<2>;
-template struct test_softmax<3>;
+template struct test_softmax<0, migraphx::shape::float_type>;
+template struct test_softmax<2, migraphx::shape::float_type>;
+template struct test_softmax<1, migraphx::shape::double_type>;
+template struct test_softmax<3, migraphx::shape::double_type>;
+template struct test_softmax<0, migraphx::shape::half_type>;
+template struct test_softmax<1, migraphx::shape::half_type>;
+template struct test_softmax<2, migraphx::shape::half_type>;
+template struct test_softmax<3, migraphx::shape::half_type>;

 struct test_conv : verify_program<test_conv>
 {
@@ -3345,12 +3349,12 @@ struct test_lstm_bidirct_default_actv2 : verify_program<test_lstm_bidirct_default_actv2>
 };

 template <int Axis>
-struct test_logsoftmax : verify_program<test_logsoftmax<Axis>>
+struct test_logsoftmax_1 : verify_program<test_logsoftmax_1<Axis>>
 {
     migraphx::program create_program() const
     {
         migraphx::program p;
-        migraphx::shape s{migraphx::shape::float_type, {3, 4, 5, 6}};
+        migraphx::shape s{migraphx::shape::float_type, {3}};
         auto param = p.add_parameter("0", s);
         p.add_instruction(migraphx::op::logsoftmax{Axis}, param);
@@ -3358,18 +3362,15 @@ struct test_logsoftmax : verify_program<test_logsoftmax<Axis>>
     }
 };

-template struct test_logsoftmax<0>;
-template struct test_logsoftmax<1>;
-template struct test_logsoftmax<2>;
-template struct test_logsoftmax<3>;
+template struct test_logsoftmax_1<0>;

-template <int Axis>
-struct test_logsoftmax_1 : verify_program<test_logsoftmax_1<Axis>>
+template <int Axis, migraphx::shape::type_t T>
+struct test_logsoftmax : verify_program<test_logsoftmax<Axis, T>>
 {
     migraphx::program create_program() const
     {
         migraphx::program p;
-        migraphx::shape s{migraphx::shape::float_type, {3}};
+        migraphx::shape s{T, {10, 4, 2080, 6}};
         auto param = p.add_parameter("0", s);
         p.add_instruction(migraphx::op::logsoftmax{Axis}, param);
@@ -3377,7 +3378,16 @@ struct test_logsoftmax_1 : verify_program<test_logsoftmax_1<Axis>>
     }
 };

-template struct test_logsoftmax_1<0>;
+template struct test_logsoftmax<0, migraphx::shape::float_type>;
+template struct test_logsoftmax<1, migraphx::shape::float_type>;
+template struct test_logsoftmax<2, migraphx::shape::float_type>;
+template struct test_logsoftmax<3, migraphx::shape::float_type>;
+template struct test_logsoftmax<1, migraphx::shape::double_type>;
+template struct test_logsoftmax<3, migraphx::shape::double_type>;
+template struct test_logsoftmax<1, migraphx::shape::half_type>;
+template struct test_logsoftmax<0, migraphx::shape::half_type>;
+template struct test_logsoftmax<2, migraphx::shape::half_type>;
+template struct test_logsoftmax<3, migraphx::shape::half_type>;

 struct test_fp32_fp16_lall : verify_program<test_fp32_fp16_lall>
 {
......
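The reshuffled tests enlarge the reduced dimensions well beyond the kernels' 256-thread block (1067 and 2080 elements), presumably to exercise the strided multi-pass reduction path, and instantiate softmax and logsoftmax across float, double, and half so every supported element type runs through the new code; the scalar {3} shape survives as the single test_logsoftmax_1<0> case.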