Commit be5f3539 authored by Shucai Xiao
Browse files

merge develop branch changes

parents 7e3bdc34 ebfe9735
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/dfor.hpp> #include <migraphx/dfor.hpp>
#include <migraphx/gpu/device/softmax.hpp> #include <migraphx/gpu/device/softmax.hpp>
#include <migraphx/gpu/device/reduce.hpp>
#include <migraphx/gpu/device/tensor.hpp> #include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp> #include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp> #include <migraphx/gpu/device/types.hpp>
...@@ -12,69 +13,44 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -12,69 +13,44 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace device { namespace device {
argument softmax(hipStream_t stream, void softmax(hipStream_t stream, const argument& result, const argument& arg, int axis)
const migraphx::shape& output_shape,
std::vector<migraphx::argument> args,
int axis)
{ {
auto lens = output_shape.lens(); auto lens = result.get_shape().lens();
auto batch_lens = lens; auto batch_lens = lens;
size_t n_dims = lens[axis]; std::size_t batch_item_num = lens[axis];
batch_lens[axis] = 1; batch_lens[axis] = 1;
migraphx::shape batch_shape{shape::int32_type, batch_lens}; migraphx::shape batch_shape{result.get_shape().type(), batch_lens};
visit_all(args.back(), args.front())([&](auto output, auto input) { hip_visit_all(result, arg, batch_shape)([&](auto output, auto input, auto batch) {
const auto* input_ptr = device_cast(input.data()); const std::size_t max_block_size = 256;
auto* output_ptr = device_cast(output.data()); const std::size_t block_size = compute_block_size(batch_item_num, max_block_size);
visit_tensor_size(batch_shape.lens().size(), [&](auto n_dim) { gs_launch(stream,
hip_tensor_descriptor<n_dim> desc_batch(batch_shape); batch_shape.elements() * block_size,
hip_tensor_descriptor<n_dim> desc_data(output_shape); block_size)([=](auto i, auto idx) __device__ {
auto data_idx = batch.multi(i / block_size);
// each thread is for one item in the batch using type = device_type<std::remove_cv_t<typename decltype(input)::value_type>>;
gs_launch(stream, batch_shape.elements())([=](auto i) { type init = lowest();
auto batch_idx = desc_batch.multi(i);
auto data_idx = batch_idx; auto batch_max = block_reduce<max_block_size>(
// get max idx, max{}, init, batch_item_num, [&](auto j) __device__ {
auto batch_max = input_ptr[desc_data.linear(batch_idx)];
for(std::size_t j = 1; j < n_dims; ++j)
{
data_idx[axis] = j; data_idx[axis] = j;
batch_max = std::max(to_hip_type(batch_max), return input[data_idx];
to_hip_type(input_ptr[desc_data.linear(data_idx)])); });
}
for(std::size_t j = 0; j < n_dims; ++j)
{
data_idx[axis] = j;
auto idx = desc_data.linear(data_idx);
output_ptr[idx] = input_ptr[idx] - batch_max;
}
for(std::size_t j = 0; j < n_dims; ++j) auto batch_sum =
{ block_reduce<max_block_size>(idx, sum{}, 0, batch_item_num, [&](auto j) __device__ {
data_idx[axis] = j;
auto idx = desc_data.linear(data_idx);
output_ptr[idx] = exp(to_hip_type(output_ptr[idx]));
}
auto batch_sum = output_ptr[desc_data.linear(batch_idx)];
for(std::size_t j = 1; j < n_dims; ++j)
{
data_idx[axis] = j; data_idx[axis] = j;
batch_sum += output_ptr[desc_data.linear(data_idx)]; auto val = input[data_idx] - batch_max;
} return ::exp(to_hip_type(val));
});
for(std::size_t j = 0; j < n_dims; ++j)
{ idx.local_stride(batch_item_num, [&](auto j) {
data_idx[axis] = j; data_idx[axis] = j;
auto idx = desc_data.linear(data_idx); auto val = input[data_idx] - batch_max;
output_ptr[idx] = output_ptr[idx] / batch_sum; output[data_idx] = ::exp(to_hip_type(val)) / batch_sum;
}
}); });
}); });
}); });
return args.back();
} }
} // namespace device } // namespace device
......
This diff is collapsed.
...@@ -12,11 +12,9 @@ shape hip_gather::compute_shape(std::vector<shape> inputs) const ...@@ -12,11 +12,9 @@ shape hip_gather::compute_shape(std::vector<shape> inputs) const
return op.compute_shape(inputs); return op.compute_shape(inputs);
} }
argument hip_gather::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
    // Forward to the device gather kernel on this context's stream.
    // args.back() is the pre-allocated output buffer; args[0] is the data
    // tensor and args[1] presumably the index tensor (matches op::gather's
    // argument order — confirm against the caller). The output-shape
    // parameter is unused: args.back() already carries the result shape.
    return device::gather(ctx.get_stream().get(), args.back(), args[0], args[1], op.axis);
}
} // namespace gpu } // namespace gpu
......
#ifndef MIGRAPHX_GUARD_RTGLIB_ARGMAX_HPP
#define MIGRAPHX_GUARD_RTGLIB_ARGMAX_HPP
#include <migraphx/shape.hpp>
#include <migraphx/op/argmax.hpp>
#include <migraphx/gpu/device/argmax.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
// GPU-backend wrapper around the reference argmax operator. Shape
// computation and execution are implemented out of line; this header only
// declares the operator interface used by the gpu target.
struct hip_argmax
{
    // Wrapped reference op; carries the reduction-axis setting.
    op::argmax op;
    // Expose the wrapped op's fields for serialization/comparison via the
    // project's reflection mechanism.
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::reflect(self.op, f);
    }
    std::string name() const { return "gpu::argmax"; }
    shape compute_shape(const std::vector<shape>& inputs) const;
    argument compute(context& ctx, const shape&, const std::vector<argument>& args) const;
    // The result aliases the last input shape — the trailing argument is
    // presumably the pre-allocated output buffer (gpu-target convention;
    // confirm against compute()).
    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
    {
        return shapes.size() - 1;
    }
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -34,7 +34,7 @@ struct miopen_softmax ...@@ -34,7 +34,7 @@ struct miopen_softmax
return migraphx::reflect(self.op, f); return migraphx::reflect(self.op, f);
} }
// Operator identifier; renamed from "gpu::softmax" to "miopen::softmax" to
// distinguish the MIOpen-library implementation from the device kernel.
std::string name() const { return "miopen::softmax"; }
shape compute_shape(const std::vector<shape>& inputs) const; shape compute_shape(const std::vector<shape>& inputs) const;
argument argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const; compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.