Commit 05e81ed3 authored by charlie's avatar charlie
Browse files

Merge branch 'select_module_op' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into dyn_batch_pass

parents 89c8b52c 5de36e4a
...@@ -54,8 +54,9 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- ...@@ -54,8 +54,9 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
apt-get clean && \ apt-get clean && \
rm -rf /var/lib/apt/lists/* rm -rf /var/lib/apt/lists/*
# add this for roctracer dependencies # add this for roctracer dependencies
RUN pip3 install CppHeaderParser packaging==22.0 RUN pip3 install CppHeaderParser
# Workaround broken rocm packages # Workaround broken rocm packages
RUN ln -s /opt/rocm-* /opt/rocm RUN ln -s /opt/rocm-* /opt/rocm
......
...@@ -76,7 +76,7 @@ function(py_add_module NAME) ...@@ -76,7 +76,7 @@ function(py_add_module NAME)
) )
endfunction() endfunction()
set(PYTHON_SEARCH_VERSIONS 2.7 3.5 3.6 3.7 3.8 3.9) set(PYTHON_SEARCH_VERSIONS 2.7 3.5 3.6 3.7 3.8 3.9 3.10)
set(PYTHON_DISABLE_VERSIONS "" CACHE STRING "") set(PYTHON_DISABLE_VERSIONS "" CACHE STRING "")
foreach(PYTHON_DISABLE_VERSION ${PYTHON_DISABLE_VERSIONS}) foreach(PYTHON_DISABLE_VERSION ${PYTHON_DISABLE_VERSIONS})
list(REMOVE_ITEM PYTHON_SEARCH_VERSIONS ${PYTHON_DISABLE_VERSION}) list(REMOVE_ITEM PYTHON_SEARCH_VERSIONS ${PYTHON_DISABLE_VERSION})
......
...@@ -182,13 +182,13 @@ struct context ...@@ -182,13 +182,13 @@ struct context
void wait_for(any_ptr queue) void wait_for(any_ptr queue)
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().wait_for(std::move(queue)); (*this).private_detail_te_get_handle().wait_for(queue);
} }
void finish_on(any_ptr queue) void finish_on(any_ptr queue)
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().finish_on(std::move(queue)); (*this).private_detail_te_get_handle().finish_on(queue);
} }
void finish() const void finish() const
...@@ -261,29 +261,29 @@ struct context ...@@ -261,29 +261,29 @@ struct context
template <class T> template <class T>
static auto private_detail_te_default_wait_for(char, T&& private_detail_te_self, any_ptr queue) static auto private_detail_te_default_wait_for(char, T&& private_detail_te_self, any_ptr queue)
-> decltype(private_detail_te_self.wait_for(std::move(queue))) -> decltype(private_detail_te_self.wait_for(queue))
{ {
private_detail_te_self.wait_for(std::move(queue)); private_detail_te_self.wait_for(queue);
} }
template <class T> template <class T>
static void private_detail_te_default_wait_for(float, T&& private_detail_te_self, any_ptr queue) static void private_detail_te_default_wait_for(float, T&& private_detail_te_self, any_ptr queue)
{ {
wait_for_context(private_detail_te_self, std::move(queue)); wait_for_context(private_detail_te_self, queue);
} }
template <class T> template <class T>
static auto private_detail_te_default_finish_on(char, T&& private_detail_te_self, any_ptr queue) static auto private_detail_te_default_finish_on(char, T&& private_detail_te_self, any_ptr queue)
-> decltype(private_detail_te_self.finish_on(std::move(queue))) -> decltype(private_detail_te_self.finish_on(queue))
{ {
private_detail_te_self.finish_on(std::move(queue)); private_detail_te_self.finish_on(queue);
} }
template <class T> template <class T>
static void static void
private_detail_te_default_finish_on(float, T&& private_detail_te_self, any_ptr queue) private_detail_te_default_finish_on(float, T&& private_detail_te_self, any_ptr queue)
{ {
finish_on_context(private_detail_te_self, std::move(queue)); finish_on_context(private_detail_te_self, queue);
} }
template <typename PrivateDetailTypeErasedT> template <typename PrivateDetailTypeErasedT>
...@@ -335,13 +335,13 @@ struct context ...@@ -335,13 +335,13 @@ struct context
void wait_for(any_ptr queue) override void wait_for(any_ptr queue) override
{ {
private_detail_te_default_wait_for(char(0), private_detail_te_value, std::move(queue)); private_detail_te_default_wait_for(char(0), private_detail_te_value, queue);
} }
void finish_on(any_ptr queue) override void finish_on(any_ptr queue) override
{ {
private_detail_te_default_finish_on(char(0), private_detail_te_value, std::move(queue)); private_detail_te_default_finish_on(char(0), private_detail_te_value, queue);
} }
void finish() const override { private_detail_te_value.finish(); } void finish() const override { private_detail_te_value.finish(); }
......
...@@ -43,7 +43,7 @@ struct select_module ...@@ -43,7 +43,7 @@ struct select_module
std::string name() const { return "select_module"; } std::string name() const { return "select_module"; }
shape compute_shape(const std::vector<shape>&, std::vector<module_ref>) const shape compute_shape(const std::vector<shape>&, const std::vector<module_ref>&) const
{ {
return shape{output_dyn_shapes}; return shape{output_dyn_shapes};
} }
...@@ -72,7 +72,7 @@ struct select_module ...@@ -72,7 +72,7 @@ struct select_module
{ {
MIGRAPHX_THROW("SELECT_MODULE: no compatible submodules found for given input shapes"); MIGRAPHX_THROW("SELECT_MODULE: no compatible submodules found for given input shapes");
} }
auto module_to_run = *module_iter; auto* module_to_run = *module_iter;
std::unordered_map<std::string, argument> params; std::unordered_map<std::string, argument> params;
// add input parameters // add input parameters
......
...@@ -118,17 +118,17 @@ struct reduce_compiler : compiler<reduce_compiler> ...@@ -118,17 +118,17 @@ struct reduce_compiler : compiler<reduce_compiler>
options.virtual_inputs = reduce_dims(inputs); options.virtual_inputs = reduce_dims(inputs);
auto faxis = find_fast_axis({options.virtual_inputs.front()}); auto faxis = find_fast_axis({options.virtual_inputs.front()});
vectorize vec{}; vectorize vec{};
// Vectorize if the axis is a reduction axis
if(options.virtual_inputs.back().lens()[faxis] == 1)
{
vec = vectorize::elements(ctx, faxis, options.virtual_inputs);
}
auto relements = get_reduce_elements(options.virtual_inputs) / vec.size;
auto nelements = options.virtual_inputs.back().elements(); auto nelements = options.virtual_inputs.back().elements();
auto algo = v.get("algo", get_reduce_algo(options.virtual_inputs)); auto algo = v.get("algo", get_reduce_algo(options.virtual_inputs));
if(algo == "block") if(algo == "block")
{ {
// Vectorize if the axis is a reduction axis
if(options.virtual_inputs.back().lens()[faxis] == 1)
vec = vectorize::elements(ctx, faxis, options.virtual_inputs);
auto relements = get_reduce_elements(options.virtual_inputs) / vec.size;
auto block_size = compute_block_size(relements, 256); auto block_size = compute_block_size(relements, 256);
if(relements > block_size * 256)
algo = "block_large";
options.set_launch_params( options.set_launch_params(
v, compute_global_for(ctx, nelements * block_size, 256), block_size); v, compute_global_for(ctx, nelements * block_size, 256), block_size);
} }
...@@ -166,7 +166,7 @@ struct reduce_compiler : compiler<reduce_compiler> ...@@ -166,7 +166,7 @@ struct reduce_compiler : compiler<reduce_compiler>
auto reduce_elements = get_reduce_elements(ins->inputs()); auto reduce_elements = get_reduce_elements(ins->inputs());
auto reduce_type = ins->inputs().front()->get_shape().type(); auto reduce_type = ins->inputs().front()->get_shape().type();
v["reduction"] = "op::sum{}"; v["reduction"] = "op::sum{}";
std::string mean = "op::mean{" + std::to_string(reduce_elements) + "}"; std::string mean = "op::mean<" + std::to_string(reduce_elements) + ">{}";
// Use float accumulator when reduction size is too large for half // Use float accumulator when reduction size is too large for half
if(reduce_type == shape::half_type and reduce_elements > 16384) if(reduce_type == shape::half_type and reduce_elements > 16384)
v["read"] = "compose(" + mean + ", op::convert_to<float>{})"; v["read"] = "compose(" + mean + ", op::convert_to<float>{})";
......
...@@ -178,5 +178,9 @@ MIGRAPHX_HIP_NORETURN inline __host__ __device__ void assert_fail(const source_l ...@@ -178,5 +178,9 @@ MIGRAPHX_HIP_NORETURN inline __host__ __device__ void assert_fail(const source_l
#define MIGRAPHX_WARN(...) #define MIGRAPHX_WARN(...)
#endif #endif
// Guards a statement with a compile-time condition. Expands to a static_assert
// on the condition followed by `if constexpr` on the same condition, so the
// statement written right after the macro becomes the body of that `if constexpr`.
// NOTE(review): presumably the `if constexpr` is there so the guarded body does
// not add further diagnostics once the static_assert has fired — confirm.
#define MIGRAPHX_STATIC_ASSERT_FOR(...) \
static_assert(__VA_ARGS__); \
if constexpr(__VA_ARGS__)
} // namespace migraphx } // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_DEBUG_HPP #endif // MIGRAPHX_GUARD_KERNELS_DEBUG_HPP
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include <migraphx/kernels/integral_constant.hpp> #include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/type_traits.hpp> #include <migraphx/kernels/type_traits.hpp>
#include <migraphx/kernels/debug.hpp> #include <migraphx/kernels/debug.hpp>
#include <migraphx/kernels/functional.hpp>
namespace migraphx { namespace migraphx {
...@@ -135,42 +136,100 @@ struct index ...@@ -135,42 +136,100 @@ struct index
return (n - _c<1>) / stride + _c<1>; return (n - _c<1>) / stride + _c<1>;
} }
// Iteration bound for a stride loop over `n` items: forwards to
// max_stride_iterations using nglobal() as the stride.
template <class N>
constexpr auto max_global_stride_iterations(N n) const
{
    const auto stride = nglobal();
    return max_stride_iterations(n, stride);
}
// Iteration bound for a stride loop over `n` items: forwards to
// max_stride_iterations using nlocal() as the stride.
template <class N>
constexpr auto max_local_stride_iterations(N n) const
{
    const auto stride = nlocal();
    return max_stride_iterations(n, stride);
}
// Calls the loop body `f` with the index `i` and the second argument `d`.
// This overload participates only when `f(i, d)` is a valid expression
// (enforced by the trailing decltype).
template <class F, class I, class D>
static constexpr auto invoke_loop(F f, I i, D d) -> decltype(f(i, d))
{
return f(i, d);
}
// Fallback for loop bodies that take only the index: the third argument is
// accepted but ignored. Participates only when `f(i)` is a valid expression.
// NOTE(review): a callable that accepts both arities makes both overloads
// viable, which presumably results in an ambiguous call — confirm that no
// caller passes such a callable.
template <class F, class I, class D>
static constexpr auto invoke_loop(F f, I i, D) -> decltype(f(i))
{
return f(i);
}
// Stride loop written with `sequence` and `fold` instead of a runtime `for`:
// for every k supplied by sequence(max_stride_iterations(n, stride), ...) it
// computes i = start + stride * k and, when i < n, calls the body through
// invoke_loop(f, i, d).
// NOTE(review): assumes `sequence` expands to k = 0 .. count-1 and that `fold`
// feeds each lambda result into the next step as `d`; both helpers are
// defined outside this view — confirm.
template <class F, class N, class Stride>
static constexpr void for_stride_loop_unroll(index_int start, N n, Stride stride, F f)
{
sequence(max_stride_iterations(n, stride), [&](auto... ks) {
fold([&](auto d, auto k) {
auto i = start + stride * k;
// Skip indices at or beyond `n`.
if(i < n)
invoke_loop(f, i, d);
// Value handed back to fold; the seed passed below is _c<0>.
return d + _c<1>;
})(_c<0>, ks...);
});
}
template <class F, class N, class Stride> template <class F, class N, class Stride>
static constexpr void for_stride_loop(index_int start, N n, Stride stride, F f)
{
index_int k = 0;
for(index_int i = start; i < n; i += stride)
{
invoke_loop(f, i, k);
k++;
}
}
template <bool Unroll, class F, class N, class Stride>
static constexpr void for_stride(index_int start, N n, Stride stride, F f) static constexpr void for_stride(index_int start, N n, Stride stride, F f)
{ {
MIGRAPHX_ASSERT(start < stride); MIGRAPHX_ASSERT(start < stride);
if constexpr(not is_integral<N>{} and not is_integral<Stride>{} and if constexpr(not is_integral<N>{} and not is_integral<Stride>{})
max_stride_iterations(n, stride) == 1) {
if constexpr(max_stride_iterations(n, stride) == 1)
{ {
if constexpr(stride > n) if constexpr(stride > n)
{ {
if(start < n) if(start < n)
f(start); invoke_loop(f, start, _c<0>);
} }
else else
{ {
f(start); invoke_loop(f, start, _c<0>);
} }
} }
else else if constexpr(Unroll)
{ {
for(index_int i = start; i < n; i += stride) MIGRAPHX_STATIC_ASSERT_FOR(max_stride_iterations(n, stride) < 256)
{
for_stride_loop_unroll(start, n, stride, f);
}
}
else
{ {
f(i); for_stride_loop(start, n, stride, f);
}
} }
else
{
for_stride_loop(start, n, stride, f);
} }
} }
template <class F, class N> template <class F, class N>
__device__ void global_stride(N n, F f) const __device__ void global_stride(N n, F f) const
{ {
for_stride(global, n, nglobal(), f); for_stride<false>(global, n, nglobal(), f);
} }
template <class F, class N> template <class F, class N>
__device__ void local_stride(N n, F f) const __device__ void local_stride(N n, F f) const
{ {
for_stride(local, n, nlocal(), f); for_stride<true>(local, n, nlocal(), f);
} }
}; };
......
...@@ -46,28 +46,27 @@ template <index_int Axis, ...@@ -46,28 +46,27 @@ template <index_int Axis,
__device__ void generic_binary_layernorm( __device__ void generic_binary_layernorm(
F compute, BinOp op, float eps, Output output, Input1 input1, Input2 input2, Inputs... inputs) F compute, BinOp op, float eps, Output output, Input1 input1, Input2 input2, Inputs... inputs)
{ {
using block = reduce::auto_block<reduce::reduce_elements_with_axis<Input1, Axis>()>;
using reduce_output = reduce::with_axis<Input1, Axis>; using reduce_output = reduce::with_axis<Input1, Axis>;
reduce::block::run<reduce_output>([&](auto, auto r) { block::template run<reduce_output>([&](auto, auto r) {
auto input = r.inner([&](auto x1, auto x2) { return op(x1, x2); })(input1, input2);
using value_type = typename Input1::type; using value_type = typename Input1::type;
constexpr auto relements = r.template elements<Input1>(); constexpr auto relements = r.template elements<Input1>();
auto means = auto means = r.reduce(op::sum{}, make_array<vec_type<value_type>>(0, 0), [&](auto x) {
r.reduce(op::sum{}, make_array<vec_type<value_type>>(0, 0), [&](auto x1, auto x2) {
auto x = op(x1, x2);
return make_array(x, x * x) * vec_type<value_type>{1.0 / relements}; return make_array(x, x * x) * vec_type<value_type>{1.0 / relements};
})(input1, input2); })(input);
auto mean_x = means[0]; auto mean_x = means[0];
auto mean_x2 = means[1]; auto mean_x2 = means[1];
auto variance = mean_x2 - (mean_x * mean_x); auto variance = mean_x2 - (mean_x * mean_x);
value_type eps_val = eps; // implicit conversion for eps value_type eps_val = eps; // implicit conversion for eps
r.inner([&](auto& y, auto x1, auto x2, auto... xs) { r.inner([&](auto& y, auto x, auto... xs) {
auto x = op(x1, x2);
auto m = x - mean_x; auto m = x - mean_x;
// m * rsqrt(mean(m ^ 2) + epsilon) // m * rsqrt(mean(m ^ 2) + epsilon)
y = compute(m * rsqrt(variance + eps_val), xs...); y = compute(m * rsqrt(variance + eps_val), xs...);
})(output, input1, input2, inputs...); })(output, input, inputs...);
}); });
} }
......
...@@ -66,13 +66,22 @@ struct convert_to ...@@ -66,13 +66,22 @@ struct convert_to
} }
}; };
template <index_int N>
struct mean struct mean
{ {
index_int item_num = 1;
template <class T> template <class T>
MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x) const MIGRAPHX_DEVICE_CONSTEXPR T operator()(T x) const
{
using type = vec_type<T>;
if constexpr(is_floating_point<type>{})
{ {
return x / static_cast<T>(item_num); constexpr type d = 1.0 / N;
return x * d;
}
else
{
return x / static_cast<type>(N);
}
} }
}; };
......
...@@ -103,10 +103,10 @@ __device__ auto block_reduce(index idx, Op op, T init, Index n, F f) ...@@ -103,10 +103,10 @@ __device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
#else #else
constexpr index_int lanes_per_thread = 64; constexpr index_int lanes_per_thread = 64;
#endif #endif
using type = decltype(f(0)); using type = decltype(index::invoke_loop(f, 0, _c<0>));
__shared__ type buffer[idx.max_nlocal() / lanes_per_thread]; __shared__ type buffer[idx.max_nlocal() / lanes_per_thread];
type x = init; type x = init;
idx.local_stride(n, [&](auto i) { x = op(x, f(i)); }); idx.local_stride(n, [&](auto i, auto d) { x = op(x, index::invoke_loop(f, i, d)); });
dpp_reduce(x, op); dpp_reduce(x, op);
const auto ldsidx = idx.local / lanes_per_thread; const auto ldsidx = idx.local / lanes_per_thread;
...@@ -128,10 +128,10 @@ template <class Op, class T, class Index, class F> ...@@ -128,10 +128,10 @@ template <class Op, class T, class Index, class F>
__device__ auto block_reduce(index idx, Op op, T init, Index n, F f) __device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
{ {
MIGRAPHX_ASSERT(idx.max_nlocal() == idx.nlocal()); MIGRAPHX_ASSERT(idx.max_nlocal() == idx.nlocal());
using type = decltype(f(0)); using type = decltype(index::invoke_loop(f, 0, _c<0>));
__shared__ type buffer[idx.max_nlocal()]; __shared__ type buffer[idx.max_nlocal()];
type x = init; type x = init;
idx.local_stride(n, [&](auto i) { x = op(x, f(i)); }); idx.local_stride(n, [&](auto i, auto d) { x = op(x, index::invoke_loop(f, i, d)); });
buffer[idx.local] = x; buffer[idx.local] = x;
__syncthreads(); __syncthreads();
...@@ -167,6 +167,25 @@ constexpr auto reduce_slice(Input input, T i) ...@@ -167,6 +167,25 @@ constexpr auto reduce_slice(Input input, T i)
namespace reduce { namespace reduce {
// Empty tag type. The `inner_storage` structs of the reducers derive from it so
// that is_inner_storage can recognise them.
struct inner_storage_tag
{
};
// Trait: true when T, with references and cv-qualifiers stripped, is or derives
// from inner_storage_tag.
template <class T>
using is_inner_storage = is_base_of<inner_storage_tag, remove_cv_t<remove_reference_t<T>>>;
// Callable wrapper: inherits the call operator of `Accessor` and publishes
// `Result` as the nested `type`.
template <class Result, class Accessor>
struct storage_access : Accessor
{
    using type = Result;
};

// Factory for storage_access: R is supplied explicitly by the caller, F is
// deduced from the callable that gets wrapped.
template <class R, class F>
constexpr storage_access<R, F> make_storage_access(F f)
{
    return storage_access<R, F>{{f}};
}
template <class Slicer, class F> template <class Slicer, class F>
constexpr auto sliced(Slicer slicer, F f) constexpr auto sliced(Slicer slicer, F f)
{ {
...@@ -191,42 +210,100 @@ constexpr auto compute_reduce_axis() ...@@ -191,42 +210,100 @@ constexpr auto compute_reduce_axis()
template <class Input, index_int Axis> template <class Input, index_int Axis>
using with_axis = decltype(compute_reduce_axis<Input, Axis>()); using with_axis = decltype(compute_reduce_axis<Input, Axis>());
struct block template <class Derived>
struct reducer_base
{ {
template <class Slicer> template <class T>
struct reducer __device__ auto make_inner_slice(T x) const
{ {
index idx; if constexpr(is_inner_storage<T>{})
Slicer slice;
template <class Op, class T, class Read>
__device__ auto reduce(Op op, T init, Read read) const
{ {
return sliced(slice, [=](auto x, auto... xs) { return x;
return block_reduce(idx, op, init, x.get_shape().elements(), [&](auto j) { }
return vec_reduce(read(x[j], xs[j]...), op); else
}); {
auto&& derived = static_cast<const Derived&>(*this);
auto t = derived.slice(x);
return make_storage_access<typename decltype(t)::type>([=](auto i, auto...) -> auto& {
return t[i];
}); });
} }
}
// Size for a call with several arguments: checks through MIGRAPHX_ASSERT that
// the first argument's size equals the size of the remaining arguments, then
// returns the first argument's size.
// NOTE(review): `xs` is only referenced inside the assertion, which is
// presumably why it carries [[maybe_unused]] — confirm.
template <class T, class... Ts>
constexpr auto get_size(T&& x, [[maybe_unused]] Ts&&... xs) const
{
MIGRAPHX_ASSERT(get_size(x) == get_size(xs...));
return get_size(x);
}
// Size of a single argument: rsize() when the argument is an inner-storage
// object, otherwise size() of the slice the derived reducer produces for it.
// NOTE(review): the `class... Ts` pack is not used by this overload.
template <class T, class... Ts>
constexpr auto get_size(T&& x) const
{
if constexpr(is_inner_storage<T>{})
{
return x.rsize();
}
else
{
// Downcast to the concrete reducer, which provides `slice`.
auto&& derived = static_cast<const Derived&>(*this);
auto t = derived.slice(x);
return t.size();
}
}
template <class F> template <class F>
__device__ void outer(F f) const __device__ auto inner_sliced(F f) const
{ {
if(idx.local == 0) return [=](auto&&... xs) { return f(get_size(xs...), make_inner_slice(xs)...); };
f();
} }
template <class T>
static __device__ typename T::type& decl_inner_storage(const T&);
template <class F> template <class F>
__device__ auto inner(F f) const __device__ auto inner(F f) const
{ {
return sliced(slice, [=](auto x, auto... xs) { return this->inner_sliced([=](auto n, auto&&... xs) {
idx.local_stride(x.get_shape().elements(), [&](auto j) { f(x[j], xs[j]...); }); using result_type = decltype(f(decl_inner_storage(xs)...));
auto&& derived = static_cast<const Derived&>(*this);
if constexpr(is_void<result_type>{})
{
derived.inner_void_impl(f, n, xs...);
}
else
{
return derived.template inner_impl<result_type>(f, n, xs...);
}
}); });
} }
// Reduction entry point. Returns the callable produced by inner_sliced; the
// lambda given to it forwards the size `n` and the per-input accessors `xs` to
// the derived reducer's reduce_impl together with `op`, `init` and `read`.
template <class Op, class T, class Read>
__device__ auto reduce(Op op, T init, Read read) const
{
return this->inner_sliced([=](auto n, auto&&... xs) {
// reduce_impl is supplied by the concrete reducer (the Derived parameter).
auto&& derived = static_cast<const Derived&>(*this);
return derived.reduce_impl(op, init, read, n, xs...);
});
}
// Convenience overload: same as the three-argument form with op::id{} as `read`.
template <class Op, class T>
__device__ auto reduce(Op op, T init) const
{
// `op::id` is a qualified name: lookup of the name before `::` ignores the
// `op` parameter and finds the enclosing `op` scope that provides `id`.
return this->reduce(op, init, op::id{});
}
template <class F>
__device__ void outer(F f) const
{
f();
}
template <class Input> template <class Input>
constexpr auto elements() const constexpr auto elements() const
{ {
using reduce_type = decltype(slice(Input{})); auto&& derived = static_cast<const Derived&>(*this);
using reduce_type = decltype(derived.slice(Input{}));
using value_type = typename Input::type; using value_type = typename Input::type;
constexpr auto relements = get_shape_c<reduce_type>{}.elements(); constexpr auto relements = get_shape_c<reduce_type>{}.elements();
if constexpr(vec_size<value_type>() > 1) if constexpr(vec_size<value_type>() > 1)
...@@ -234,12 +311,69 @@ struct block ...@@ -234,12 +311,69 @@ struct block
else else
return relements; return relements;
} }
};
struct block
{
template <class Slicer>
struct reducer : reducer_base<reducer<Slicer>>
{
index idx;
Slicer slice;
template <class T, index_int N, class Size>
struct inner_storage : inner_storage_tag
{
using type = T;
array<T, N> arr;
constexpr Size rsize() const { return {}; }
template <class U, class V>
constexpr auto& operator()(U, V d) const
{
return arr[d];
}
template <class U, class V>
constexpr auto& operator()(U, V d)
{
return arr[d];
}
};
template <class Op, class T, class Read, class N, class... Ts>
__device__ auto reduce_impl(Op op, T init, Read read, N n, Ts&&... xs) const
{
return block_reduce(idx, op, init, n, [&](auto j, auto d) {
return vec_reduce(read(xs(j, d)...), op);
});
}
template <class F>
__device__ void outer(F f) const
{
if(idx.local == 0)
f();
}
template <class F, class N, class... Ts>
__device__ void inner_void_impl(F f, N n, Ts&&... xs) const
{
idx.local_stride(n, [&](auto j, auto d) { f(xs(j, d)...); });
}
template <class R, class F, class N, class... Ts>
__device__ auto inner_impl(F f, N n, Ts&&... xs) const
{
using max_iterations = decltype(idx.max_local_stride_iterations(n));
inner_storage<R, max_iterations{}, N> storage;
idx.local_stride(n, [&](auto j, auto d) { storage(j, d) = f(xs(j, d)...); });
return storage;
}
}; };
template <class Slicer> template <class Slicer>
static __device__ auto make(index idx, Slicer slicer) static __device__ auto make(index idx, Slicer slicer)
{ {
return reducer<Slicer>{idx, slicer}; return reducer<Slicer>{{}, idx, slicer};
} }
template <class Output, class F> template <class Output, class F>
...@@ -254,56 +388,143 @@ struct block ...@@ -254,56 +388,143 @@ struct block
} }
}; };
struct lane struct block_large
{ {
template <class Slicer> template <class Slicer>
struct reducer struct reducer : reducer_base<reducer<Slicer>>
{ {
index idx; index idx;
Slicer slice; Slicer slice;
template <class Op, class T, class Read>
__device__ auto reduce(Op op, T init, Read read) const template <class Size, class F>
struct inner_storage : inner_storage_tag
{ {
return sliced(slice, [=](auto x, auto... xs) { using type = remove_reference_t<decltype(declval<F>()(0, _c<0>))>;
using type = typename decltype(x)::type; F f;
type r = init; constexpr Size rsize() const { return {}; }
for(index_int j = 0; j < x.get_shape().elements(); j++) template <class U, class V>
constexpr auto operator()(U j, V d) const
{ {
r = op(r, read(x[j], xs[j]...)); return f(j, d);
} }
return r; };
template <class Size, class F>
constexpr inner_storage<Size, F> make_inner_storage(Size, F f)
{
return {f};
}
template <class Op, class T, class Read, class N, class... Ts>
__device__ auto reduce_impl(Op op, T init, Read read, N n, Ts&&... xs) const
{
return block_reduce(idx, op, init, index_int{n}, [&](auto j, auto d) {
return vec_reduce(read(xs(j, d)...), op);
}); });
} }
template <class F> template <class F>
__device__ void outer(F f) const __device__ void outer(F f) const
{ {
if(idx.local == 0)
f(); f();
} }
template <class F> template <class F, class N, class... Ts>
__device__ auto inner(F f) const __device__ void inner_void_impl(F f, N n, Ts&&... xs) const
{ {
return sliced(slice, [=](auto x, auto... xs) { idx.local_stride(index_int{n}, [&](auto j, auto d) { f(xs(j, d)...); });
for(index_int j = 0; j < x.get_shape().elements(); j++) }
template <class R, class F, class N, class... Ts>
__device__ auto inner_impl(F f, N n, Ts&&... xs) const
{
return make_inner_storage(n, [=](auto j, auto d) { return f(xs(j, d)...); });
}
};
template <class Slicer>
static __device__ auto make(index idx, Slicer slicer)
{ {
f(x[j], xs[j]...); return reducer<Slicer>{{}, idx, slicer};
} }
template <class Output, class F>
static __device__ void run(F f)
{
auto idx = make_index();
constexpr auto nelements = get_shape_c<Output>{}.elements();
idx.global_stride(nelements * idx.nlocal(), [&](auto i) {
const auto out_idx = get_shape_c<Output>{}.multi(i / idx.nlocal());
f(out_idx, make(idx, [&](auto input) { return reduce_slice<Output>(input, out_idx); }));
}); });
} }
};
template <class Input> struct lane
constexpr auto elements() const {
template <class Slicer>
struct reducer : reducer_base<reducer<Slicer>>
{ {
using reduce_type = decltype(slice(Input{})); index idx;
return get_shape_c<reduce_type>{}.elements(); Slicer slice;
template <class Size, class F>
struct inner_storage : inner_storage_tag
{
using type = remove_reference_t<decltype(declval<F>()(0, _c<0>))>;
F f;
constexpr Size rsize() const { return {}; }
template <class U, class V>
constexpr auto operator()(U j, V d) const
{
return f(j, d);
} }
}; };
template <class Size, class F>
constexpr inner_storage<Size, F> make_inner_storage(Size, F f)
{
return {f};
}
template <class Op, class T, class Read, class N, class U, class... Us>
__device__ auto reduce_impl(Op op, T init, Read read, N n, U&& x, Us&&... xs) const
{
using type = remove_reference_t<decltype(x(0, _c<0>))>;
type r = init;
for(index_int j = 0; j < n; j++)
{
r = op(r, read(x(j, _c<0>), xs(j, _c<0>)...));
}
return r;
}
template <class F>
__device__ void outer(F f) const
{
f();
}
template <class F, class N, class... Ts>
__device__ void inner_void_impl(F f, N n, Ts&&... xs) const
{
for(index_int j = 0; j < n; j++)
{
f(xs(j, _c<0>)...);
}
}
template <class R, class F, class N, class... Ts>
__device__ auto inner_impl(F f, N n, Ts&&... xs) const
{
return make_inner_storage(n, [=](auto j, auto d) { return f(xs(j, d)...); });
}
};
template <class Slicer> template <class Slicer>
static __device__ auto make(index idx, Slicer slicer) static __device__ auto make(index idx, Slicer slicer)
{ {
return reducer<Slicer>{idx, slicer}; return reducer<Slicer>{{}, idx, slicer};
} }
template <class Output, class F> template <class Output, class F>
...@@ -318,6 +539,26 @@ struct lane ...@@ -318,6 +539,26 @@ struct lane
} }
}; };
// TODO: Remove these in the future when they can be selected in the compiler class
template <index_int RElements>
constexpr auto pick_block()
{
using nlocal = decltype(index{}.max_nlocal());
if constexpr(RElements < nlocal{} * 256)
return block{};
else
return block_large{};
}
template <index_int RElements>
using auto_block = decltype(pick_block<RElements>());
// Length of dimension `Axis`, read from the compile-time shape of `Input`.
template <class Input, index_int Axis>
constexpr auto reduce_elements_with_axis()
{
    constexpr auto input_shape = get_shape_c<Input>{};
    return input_shape.lens[Axis];
}
} // namespace reduce } // namespace reduce
template <class Algo, template <class Algo,
......
...@@ -30,18 +30,20 @@ ...@@ -30,18 +30,20 @@
namespace migraphx { namespace migraphx {
template <index_int Axis, class Input, class Output> template <index_int Axis, class Input, class Output>
__device__ void softmax(Input input, Output output) __device__ void softmax(Input input1, Output output)
{ {
reduce::block::run<reduce::with_axis<Input, Axis>>([&](auto, auto r) { using block = reduce::auto_block<reduce::reduce_elements_with_axis<Input, Axis>()>;
block::template run<reduce::with_axis<Input, Axis>>([&](auto, auto r) {
auto input = r.inner(op::id{})(input1);
#ifdef MIGRAPHX_USE_FAST_SOFTMAX #ifdef MIGRAPHX_USE_FAST_SOFTMAX
const auto c = vec_at(r.slice(input)[0], 0); const auto c = vec_at(r.slice(input1)[0], 0);
#else #else
const auto c = r.reduce(op::max{}, lowest{}, op::id{})(input); const auto c = r.reduce(op::max{}, lowest{}, op::id{})(input);
#endif #endif
auto batch_sum = r.reduce(op::sum{}, 0, [&](auto x) { auto exp_in = r.inner([&](auto x) { return migraphx::exp(x - c); })(input);
return migraphx::convert<float>(migraphx::exp(x - c)); auto batch_sum =
})(input); r.reduce(op::sum{}, 0, [](auto x) { return migraphx::convert<float>(x); })(exp_in);
r.inner([&](auto& y, auto x) { y = migraphx::exp(x - c) / batch_sum; })(output, input); r.inner([&](auto& y, auto x) { y = x / batch_sum; })(output, exp_in);
}); });
} }
......
...@@ -141,6 +141,25 @@ MIGRAPHX_BUILTIN_TYPE_TRAITN(is_constructible); ...@@ -141,6 +141,25 @@ MIGRAPHX_BUILTIN_TYPE_TRAITN(is_constructible);
MIGRAPHX_BUILTIN_TYPE_TRAITN(is_nothrow_constructible); MIGRAPHX_BUILTIN_TYPE_TRAITN(is_nothrow_constructible);
MIGRAPHX_BUILTIN_TYPE_TRAITN(is_trivially_constructible); MIGRAPHX_BUILTIN_TYPE_TRAITN(is_trivially_constructible);
// Strips top-level `const` and/or `volatile` from T and exposes the result as
// the nested `type`. Unqualified types pass through unchanged.
template <class T>
struct remove_cv
{
    using type = T;
};

// `const T` -> T
template <class T>
struct remove_cv<const T> : remove_cv<T>
{
};

// `volatile T` -> T
template <class T>
struct remove_cv<volatile T> : remove_cv<T>
{
};

// `const volatile T` -> T. This specialisation is required: a const volatile
// type matches both specialisations above and neither is more specialised than
// the other, so without it remove_cv<const volatile X> fails to compile as
// ambiguous.
template <class T>
struct remove_cv<const volatile T> : remove_cv<T>
{
};

// Shorthand for typename remove_cv<T>::type.
template <class T>
using remove_cv_t = typename remove_cv<T>::type;
template <class T> template <class T>
struct remove_reference struct remove_reference
{ {
...@@ -168,6 +187,11 @@ struct add_pointer : type_identity<typename remove_reference<T>::type*> ...@@ -168,6 +187,11 @@ struct add_pointer : type_identity<typename remove_reference<T>::type*>
template <class T> template <class T>
using add_pointer_t = typename add_pointer<T>::type; using add_pointer_t = typename add_pointer<T>::type;
// Trait: true when T, ignoring cv-qualifiers, is `void`.
template <class T>
struct is_void : is_same<void, remove_cv_t<T>>
{
};
template <class... Ts> template <class... Ts>
struct common_type; struct common_type;
......
...@@ -369,7 +369,7 @@ struct miopen_apply ...@@ -369,7 +369,7 @@ struct miopen_apply
apply_map.emplace("select_module", [=](instruction_ref ins) { apply_map.emplace("select_module", [=](instruction_ref ins) {
std::vector<instruction_ref> inputs = ins->inputs(); std::vector<instruction_ref> inputs = ins->inputs();
auto mod_args = ins->module_inputs(); auto mod_args = ins->module_inputs();
for(auto smod : mod_args) for(auto* smod : mod_args)
{ {
smod->use_local_alloc = true; smod->use_local_alloc = true;
auto last_ins = std::prev(smod->end()); auto last_ins = std::prev(smod->end());
......
...@@ -7285,7 +7285,7 @@ TEST_CASE(select_module_add_test) ...@@ -7285,7 +7285,7 @@ TEST_CASE(select_module_add_test)
auto literal_ins = mm->add_literal(migraphx::literal{lit_s, {6}}); auto literal_ins = mm->add_literal(migraphx::literal{lit_s, {6}});
// create batch submodules // create batch submodules
auto create_submodule = [&](std::size_t batch_size, std::string module_name) { auto create_submodule = [&](std::size_t batch_size, const std::string& module_name) {
auto* submod = p.create_module(module_name); auto* submod = p.create_module(module_name);
migraphx::shape sm_shape{migraphx::shape::float_type, {batch_size, 4}}; migraphx::shape sm_shape{migraphx::shape::float_type, {batch_size, 4}};
auto sm_input = submod->add_parameter("data", sm_shape); auto sm_input = submod->add_parameter("data", sm_shape);
...@@ -7329,7 +7329,7 @@ TEST_CASE(select_module_reduce_test0) ...@@ -7329,7 +7329,7 @@ TEST_CASE(select_module_reduce_test0)
migraphx::program p; migraphx::program p;
// create batch submodules // create batch submodules
auto create_submodule = [&](std::size_t batch_size, std::string module_name) { auto create_submodule = [&](std::size_t batch_size, const std::string& module_name) {
auto* submod = p.create_module(module_name); auto* submod = p.create_module(module_name);
migraphx::shape sm_shape{migraphx::shape::float_type, {batch_size, 2, 2}}; migraphx::shape sm_shape{migraphx::shape::float_type, {batch_size, 2, 2}};
auto sm_input = submod->add_parameter("data", sm_shape); auto sm_input = submod->add_parameter("data", sm_shape);
...@@ -7375,7 +7375,7 @@ TEST_CASE(select_module_reduce_test1) ...@@ -7375,7 +7375,7 @@ TEST_CASE(select_module_reduce_test1)
migraphx::program p; migraphx::program p;
// create batch submodules // create batch submodules
auto create_submodule = [&](std::size_t batch_size, std::string module_name) { auto create_submodule = [&](std::size_t batch_size, const std::string& module_name) {
auto* submod = p.create_module(module_name); auto* submod = p.create_module(module_name);
migraphx::shape sm_shape{migraphx::shape::float_type, {batch_size, 2, 2}}; migraphx::shape sm_shape{migraphx::shape::float_type, {batch_size, 2, 2}};
auto sm_input = submod->add_parameter("data", sm_shape); auto sm_input = submod->add_parameter("data", sm_shape);
......
...@@ -76,3 +76,16 @@ struct test_reduce_mean_2 : verify_program<test_reduce_mean_2> ...@@ -76,3 +76,16 @@ struct test_reduce_mean_2 : verify_program<test_reduce_mean_2>
return p; return p;
}; };
}; };
// Verification program: reduce_mean over axis 1 of a float tensor of shape
// {2, 256 * 256 * 16}, i.e. 1,048,576 elements per reduction.
// NOTE(review): presumably sized to go past the 256 * block-size threshold that
// selects the "block_large" reduction algorithm — confirm.
struct test_large_reduce_mean : verify_program<test_large_reduce_mean>
{
    migraphx::program create_program() const
    {
        migraphx::program prog;
        migraphx::shape input_shape{migraphx::shape::float_type, {2, 256 * 256 * 16}};
        auto* main_module = prog.get_main_module();
        auto input        = main_module->add_parameter("x", input_shape);
        main_module->add_instruction(migraphx::op::reduce_mean{{1}}, input);
        return prog;
    };
};
...@@ -37,7 +37,7 @@ struct test_select_module_add : verify_program<test_select_module_add> ...@@ -37,7 +37,7 @@ struct test_select_module_add : verify_program<test_select_module_add>
auto literal_ins = mm->add_literal(migraphx::literal{lit_s, {6}}); auto literal_ins = mm->add_literal(migraphx::literal{lit_s, {6}});
// create batch submodules // create batch submodules
auto create_submodule = [&](std::size_t batch_size, std::string module_name) { auto create_submodule = [&](std::size_t batch_size, const std::string& module_name) {
auto* submod = p.create_module(module_name); auto* submod = p.create_module(module_name);
migraphx::shape sm_shape{migraphx::shape::float_type, {batch_size, 4}}; migraphx::shape sm_shape{migraphx::shape::float_type, {batch_size, 4}};
auto sm_input = submod->add_parameter("data", sm_shape); auto sm_input = submod->add_parameter("data", sm_shape);
......
...@@ -34,8 +34,8 @@ struct test_select_module_reduce : verify_program<test_select_module_reduce> ...@@ -34,8 +34,8 @@ struct test_select_module_reduce : verify_program<test_select_module_reduce>
migraphx::program p; migraphx::program p;
// create batch submodules // create batch submodules
auto create_submodule = [&](std::size_t batch_size, std::string module_name) { auto create_submodule = [&](std::size_t batch_size, const std::string& module_name) {
auto submod = p.create_module(module_name); auto* submod = p.create_module(module_name);
migraphx::shape sm_shape{migraphx::shape::float_type, {batch_size, 2, 2}}; migraphx::shape sm_shape{migraphx::shape::float_type, {batch_size, 2, 2}};
auto sm_input = submod->add_parameter("data", sm_shape); auto sm_input = submod->add_parameter("data", sm_shape);
auto reduce_ins = auto reduce_ins =
......
...@@ -57,7 +57,7 @@ echo "Dependencies are installed at $PREFIX" ...@@ -57,7 +57,7 @@ echo "Dependencies are installed at $PREFIX"
rbuild prepare -d $PREFIX -s develop rbuild prepare -d $PREFIX -s develop
# install onnx package for unit tests # install onnx package for unit tests
pip3 install onnx==1.10.0 numpy==1.21.6 typing==3.7.4 pytest==6.0.1 packaging==16.8 pip3 install onnx==1.10.2 numpy==1.21.6 typing==3.7.4 pytest==6.0.1 packaging==23.0
# pin version of protobuf in Python for onnx runtime unit tests # pin version of protobuf in Python for onnx runtime unit tests
pip3 install protobuf==3.20.0 pip3 install protobuf==3.20.0
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment