Unverified Commit 6416f066 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Merge branch 'develop' into fastsoftmax

parents 647d5dc5 97a1ed2d
......@@ -27,7 +27,7 @@
#include <migraphx/argument.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/op/reverse.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/context.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -24,22 +24,10 @@
#ifndef MIGRAPHX_GUARD_RTGLIB_SOFTMAX_HPP
#define MIGRAPHX_GUARD_RTGLIB_SOFTMAX_HPP
#include <migraphx/gpu/lowering.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op/softmax.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/gpu/device/contiguous.hpp>
#include <migraphx/gpu/device/add.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/gpu/context.hpp>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -25,8 +25,6 @@
#define MIGRAPHX_GUARD_RTGLIB_GPU_SYNC_DEVICE_HPP
#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
......
......@@ -27,7 +27,7 @@
#include <migraphx/argument.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/op/topk.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/context.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/reduce_dims.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
using namespace migraphx::gpu::gen; // NOLINT
// NOLINTNEXTLINE
static const char* const concat_kernel = R"__migraphx__(
#include <migraphx/kernels/concat.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <args.hpp>
namespace migraphx {
extern "C" {
__global__ void ${kernel}(${params})
{
transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto y, auto... xs) {
concat<${axis}>(y, xs...);
});
}
}
} // namespace migraphx
)__migraphx__";
struct concat_compiler : compiler<concat_compiler>
{
std::vector<std::string> names() const { return {"concat"}; }
static std::size_t get_concat_elements(const std::vector<shape>& inputs)
{
return inputs.back().elements() / (inputs.size() - 1);
}
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
// TODO: Use reduce_dims
hip_compile_options options;
options.inputs = inputs;
options.output = inputs.back();
options.params = "-Wno-float-equal";
auto axis = find_fast_axis(options.inputs);
auto vec = vectorize::elements(axis, options.inputs);
options.kernel_name = v.get("kernel", "concat_kernel");
options.set_launch_params(
v, compute_global_for(ctx, get_concat_elements(options.inputs) / vec.size, 256));
auto src = interpolate_string(concat_kernel,
{{"kernel", options.kernel_name},
{"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")},
{"transformers", make_transformer_args(vec)},
{"axis", v.at("axis").to<std::string>()}});
return compile_hip_code_object(src, options);
}
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value()));
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -27,13 +27,6 @@
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -27,8 +27,6 @@
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
......
......@@ -26,16 +26,7 @@
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -26,15 +26,7 @@
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -24,16 +24,7 @@
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -24,16 +24,8 @@
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -26,15 +26,7 @@
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -80,7 +80,9 @@ void launch_kernel(hipFunction_t fun,
std::size_t global,
std::size_t local,
void* kernargs,
std::size_t size)
std::size_t size,
hipEvent_t start,
hipEvent_t stop)
{
assert(global > 0);
assert(local > 0);
......@@ -97,34 +99,55 @@ void launch_kernel(hipFunction_t fun,
#endif
};
auto status = hipExtModuleLaunchKernel(
fun, global, 1, 1, local, 1, 1, 0, stream, nullptr, reinterpret_cast<void**>(&config));
auto status = hipExtModuleLaunchKernel(fun,
global,
1,
1,
local,
1,
1,
0,
stream,
nullptr,
reinterpret_cast<void**>(&config),
start,
stop);
if(status != hipSuccess)
MIGRAPHX_THROW("Failed to launch kernel: " + hip_error(status));
if(stop != nullptr)
{
status = hipEventSynchronize(stop);
if(status != hipSuccess)
MIGRAPHX_THROW("Failed to sync event: " + hip_error(status));
}
}
void kernel::launch(hipStream_t stream,
std::size_t global,
std::size_t local,
std::vector<void*> args) const
std::vector<void*> args,
hipEvent_t start,
hipEvent_t stop) const
{
assert(impl != nullptr);
void* kernargs = args.data();
std::size_t size = args.size() * sizeof(void*);
launch_kernel(impl->fun, stream, global, local, kernargs, size);
launch_kernel(impl->fun, stream, global, local, kernargs, size, start, stop);
}
void kernel::launch(hipStream_t stream,
std::size_t global,
std::size_t local,
const std::vector<kernel_argument>& args) const
const std::vector<kernel_argument>& args,
hipEvent_t start,
hipEvent_t stop) const
{
assert(impl != nullptr);
std::vector<char> kernargs = pack_args(args);
std::size_t size = kernargs.size();
launch_kernel(impl->fun, stream, global, local, kernargs.data(), size);
launch_kernel(impl->fun, stream, global, local, kernargs.data(), size, start, stop);
}
} // namespace gpu
......
......@@ -163,7 +163,7 @@ constexpr Iterator1 search(Iterator1 first, Iterator1 last, Iterator2 s_first, I
{
return last;
}
if(!(*it == *s_it))
if(not(*it == *s_it))
{
break;
}
......
......@@ -33,49 +33,95 @@
namespace migraphx {
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_ARRAY_OP(op, binary_op) \
template <class U> \
constexpr array& operator op(const array<U, N>& x) \
{ \
for(index_int i = 0; i < N; i++) \
d[i] op x[i]; \
return *this; \
} \
template <class U, MIGRAPHX_REQUIRES(is_convertible<U, T>{})> \
constexpr array& operator op(const U& x) \
{ \
for(index_int i = 0; i < N; i++) \
d[i] op x; \
return *this; \
} \
template <class U> \
friend constexpr auto operator binary_op(const array& x, const array<U, N>& y) \
{ \
array<decltype(T {} binary_op U{}), N> z{}; \
for(index_int i = 0; i < N; i++) \
z[i] = x[i] binary_op y[i]; \
return z; \
} \
template <class U, MIGRAPHX_REQUIRES(is_convertible<U, T>{})> \
friend constexpr auto operator binary_op(const array& x, const U& y) \
{ \
array<decltype(T {} binary_op U{}), N> z{}; \
for(index_int i = 0; i < N; i++) \
z[i] = x[i] binary_op y; \
return z; \
} \
template <class U, MIGRAPHX_REQUIRES(is_convertible<U, T>{})> \
friend constexpr auto operator binary_op(const U& x, const array& y) \
{ \
array<decltype(T {} binary_op U{}), N> z{}; \
for(index_int i = 0; i < N; i++) \
z[i] = x binary_op y[i]; \
return z; \
#define MIGRAPHX_DEVICE_ARRAY_OP(op, binary_op) \
template <class U> \
constexpr array& operator op(const array<U, N>& x) \
{ \
array_detail::array_for_each(*this, x)([](auto& sy, auto sx) { sy op sx; }); \
return *this; \
} \
template <class U, MIGRAPHX_REQUIRES(is_convertible<U, T>{})> \
constexpr array& operator op(const U& x) \
{ \
array_detail::array_for_each (*this)([&](auto& sy) { sy op x; }); \
return *this; \
} \
template <class U> \
friend constexpr auto operator binary_op(const array& x, const array<U, N>& y) \
{ \
array<decltype(T {} binary_op U{}), N> z{}; \
array_detail::array_for_each(z, x, y)( \
[&](auto& sz, auto sx, auto sy) { sz = sx binary_op sy; }); \
return z; \
} \
template <class U, MIGRAPHX_REQUIRES(is_convertible<U, T>{})> \
friend constexpr auto operator binary_op(const array& x, const U& y) \
{ \
array<decltype(T {} binary_op U{}), N> z{}; \
array_detail::array_for_each(z, x)([&](auto& sz, auto sx) { sz = sx binary_op y; }); \
return z; \
} \
template <class U, MIGRAPHX_REQUIRES(is_convertible<U, T>{})> \
friend constexpr auto operator binary_op(const U& x, const array& y) \
{ \
array<decltype(T {} binary_op U{}), N> z{}; \
array_detail::array_for_each(z, y)([&](auto& sz, auto sy) { sz = x binary_op sy; }); \
return z; \
}
namespace array_detail {
template <class T>
constexpr auto is_vectorizable()
{
return not is_same<T, bool>{} and (is_fundamental<T>{} or is_same<T, half>{});
}
template <class T>
__device__ auto& array2vec(T& x)
{
using value_type = typename T::value_type;
constexpr auto size = decltype(x.size()){};
using type = vec<value_type, size>;
if constexpr(is_const<T>{})
return reinterpret_cast<const type&>(x);
else
return reinterpret_cast<type&>(x);
}
template <class T, class... Ts>
constexpr auto array_for_each(T& x, Ts&... xs)
{
MIGRAPHX_ASSERT(((x.size() == xs.size()) and ...));
return [&](auto f) {
constexpr auto size = decltype(x.size()){};
if constexpr((is_vectorizable<typename T::value_type>() or
(is_vectorizable<typename Ts::value_type>() or ...)) and
size <= 8 and size > 1 and (size % 2 == 0))
{
if(__builtin_is_constant_evaluated())
{
for(index_int i = 0; i < size; i++)
f(x[i], xs[i]...);
}
else
{
using vec_type = std::remove_reference_t<decltype(array2vec(x))>;
f(array2vec(x), __builtin_convertvector(array2vec(xs), vec_type)...);
}
}
else
{
for(index_int i = 0; i < size; i++)
f(x[i], xs[i]...);
}
};
}
} // namespace array_detail
template <class T, index_int N>
struct array
{
using value_type = T;
T d[N];
constexpr T& operator[](index_int i)
{
......@@ -108,18 +154,13 @@ struct array
constexpr T dot(const array& x) const
{
T result = 0;
for(index_int i = 0; i < N; i++)
result += x[i] * d[i];
return result;
auto r = x * (*this);
return r.reduce([](auto a, auto b) { return a + b; }, 0);
}
constexpr T product() const
{
T result = 1;
for(index_int i = 0; i < N; i++)
result *= d[i];
return result;
return reduce([](auto x, auto y) { return x * y; }, 1);
}
constexpr T single(index_int width = 100) const
......@@ -134,6 +175,24 @@ struct array
return result;
}
template <class F>
constexpr auto apply(F f) const
{
array<decltype(f(d[0])), N> result;
for(index_int i = 0; i < N; i++)
result[i] = f(d[i]);
return result;
}
template <class F>
constexpr auto reduce(F f, T init) const
{
T result = init;
for(index_int i = 0; i < N; i++)
result = f(result, d[i]);
return result;
}
MIGRAPHX_DEVICE_ARRAY_OP(+=, +)
MIGRAPHX_DEVICE_ARRAY_OP(-=, -)
MIGRAPHX_DEVICE_ARRAY_OP(*=, *)
......@@ -153,7 +212,7 @@ struct array
return true;
}
friend constexpr bool operator!=(const array& x, const array& y) { return !(x == y); }
friend constexpr bool operator!=(const array& x, const array& y) { return not(x == y); }
// This uses the product order rather than lexical order
friend constexpr bool operator<(const array& x, const array& y)
{
......@@ -201,6 +260,11 @@ struct array
}
};
template <class T, class... Ts>
constexpr array<T, sizeof...(Ts) + 1> make_array(T x, Ts... xs)
{
return {x, static_cast<T>(xs)...};
}
template <class T, T... Xs>
struct integral_const_array : array<T, sizeof...(Xs)>
{
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/tensor_view.hpp>
#ifndef MIGRAPHX_GUARD_KERNELS_CONCAT_HPP
#define MIGRAPHX_GUARD_KERNELS_CONCAT_HPP
namespace migraphx {
template <index_int Axis, class Output, class Input, class Start>
constexpr auto concat_slice(Output out, Input, Start)
{
constexpr auto lens = get_shape_c<Input>{}.lens;
constexpr auto strides = get_shape_c<Output>{}.strides;
constexpr auto offset = return_c([] {
constexpr auto output_shape = get_shape_c<Output>{};
return Start{} * output_shape.strides[Axis];
});
constexpr auto s = make_shape(lens, strides);
return make_tensor_view(&out[offset], s);
}
template <index_int Axis, class Input>
constexpr auto concat_ends(Input)
{
constexpr auto lens = get_shape_c<Input>{}.lens;
return _c<lens[Axis]>;
}
template <index_int Axis, class Output, class... Inputs>
__device__ void concat(Output output, Inputs... inputs)
{
auto idx = make_index();
fold([&](auto start, auto input) {
auto y = concat_slice<Axis>(output, input, start);
idx.global_stride(input.get_shape().elements(), [&](auto i) { y[i] = input[i]; });
return start + concat_ends<Axis>(input);
})(_c<0>, inputs...);
}
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_CONCAT_HPP
......@@ -28,9 +28,60 @@
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/type_traits.hpp>
#include <migraphx/kernels/debug.hpp>
namespace migraphx {
#if defined(MIGRAPHX_NGLOBAL) && defined(MIGRAPHX_NLOCAL)
#define MIGRAPHX_NGROUP ((MIGRAPHX_NGLOBAL + MIGRAPHX_NLOCAL - 1) / MIGRAPHX_NLOCAL)
#endif
inline __device__ __attribute__((const)) index_int compute_global_size()
{
#ifdef MIGRAPHX_NGLOBAL
return MIGRAPHX_NGLOBAL;
#else
// This actualy works even when global is not divisible by local size.
// This doesnt actually do a multiplicatiosn. Instead it calls a device
// function to get the global size, which is why it works.
return blockDim.x * gridDim.x; // NOLINT
#endif
}
// We cant just use blockDim.x to get the local size since its broken on hip
// when global is not divisible by local size. In this case, we calulate the
// size for the last group.
inline __device__ __attribute__((const)) index_int compute_local_size()
{
#ifdef MIGRAPHX_NLOCAL
const auto nlocal = MIGRAPHX_NLOCAL;
#else
const auto nlocal = blockDim.x; // NOLINT
#endif
#ifdef MIGRAPHX_NGROUP
const auto ngroup = MIGRAPHX_NGROUP;
#else
const auto ngroup = gridDim.x; // NOLINT
#endif
const auto group_id = blockIdx.x; // NOLINT
const auto nglobal = compute_global_size();
if(group_id == ngroup - 1)
{
return 1 + (nglobal - 1) % nlocal;
}
else
{
return nlocal; // NOLINT
}
}
#ifdef MIGRAPHX_NGROUP
// If global is divisible by local then local can be a const
#if(MIGRAPHX_NGLOBAL % MIGRAPHX_NLOCAL == 0) || (MIGRAPHX_NGROUP == 1)
#define MIGRAPHX_HAS_CONST_LOCAL 1
#endif
#endif
struct index
{
index_int global = 0;
......@@ -38,20 +89,44 @@ struct index
index_int group = 0;
#ifdef MIGRAPHX_NGLOBAL
constexpr index_constant<MIGRAPHX_NGLOBAL> nglobal() const { return {}; }
constexpr index_constant<MIGRAPHX_NGLOBAL> nglobal() const
{
static_assert(MIGRAPHX_NGLOBAL > 0, "Global size must be greater than 0");
return {};
}
#else
__device__ index_int nglobal() const
{
return blockDim.x * gridDim.x; // NOLINT
MIGRAPHX_ASSERT(compute_global_size() > 0);
return compute_global_size(); // NOLINT
}
#endif
#ifdef MIGRAPHX_NLOCAL
constexpr index_constant<MIGRAPHX_NLOCAL> nlocal() const { return {}; }
#ifdef MIGRAPHX_HAS_CONST_LOCAL
constexpr index_constant<MIGRAPHX_NLOCAL> nlocal() const
{
static_assert(MIGRAPHX_NLOCAL > 0, "Local size must be greater than 0");
return {};
}
#else
__device__ index_int nlocal() const
{
return blockDim.x; // NOLINT
#ifdef MIGRAPHX_NGROUP
static_assert((MIGRAPHX_NGLOBAL % MIGRAPHX_NLOCAL != 0) and (MIGRAPHX_NGROUP > 1),
"Local size should be const");
#endif
MIGRAPHX_ASSERT(compute_local_size() > 0);
return compute_local_size(); // NOLINT
}
#endif
#ifdef MIGRAPHX_NLOCAL
constexpr index_constant<MIGRAPHX_NLOCAL> max_nlocal() const { return {}; }
#else
__device__ index_int max_nlocal() const
{
MIGRAPHX_ASSERT(blockDim.x > 0);
return blockDim.x;
}
#endif
template <class N, class Stride>
......@@ -63,6 +138,7 @@ struct index
template <class F, class N, class Stride>
static constexpr void for_stride(index_int start, N n, Stride stride, F f)
{
MIGRAPHX_ASSERT(start < stride);
if constexpr(not is_integral<N>{} and not is_integral<Stride>{} and
max_stride_iterations(n, stride) == 1)
{
......
......@@ -73,10 +73,10 @@ MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(!=)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(&)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(^)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(|)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(&&)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(||)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(and)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(or)
MIGRAPHX_INTEGRAL_CONSTANT_UNARY_OP(!)
MIGRAPHX_INTEGRAL_CONSTANT_UNARY_OP(not )
MIGRAPHX_INTEGRAL_CONSTANT_UNARY_OP(~)
MIGRAPHX_INTEGRAL_CONSTANT_UNARY_OP(+)
MIGRAPHX_INTEGRAL_CONSTANT_UNARY_OP(-)
......
......@@ -29,6 +29,12 @@
namespace migraphx {
template <class T, index_int N, class Op>
constexpr auto vec_reduce(const array<T, N>& a, Op op)
{
return a.apply([&](auto x) { return vec_reduce(x, op); });
}
template <index_int Axis,
class F,
class BinOp,
......@@ -43,23 +49,21 @@ __device__ void generic_binary_layernorm(
reduce::block::run<reduce_output>([&](auto, auto r) {
using value_type = typename Input1::type;
constexpr auto relements = r.template elements<Input1>();
auto mean = [&](auto f) {
return r.reduce(op::sum{}, 0, [&](auto x1, auto x2) {
return f(x1, x2) / value_type{relements};
auto means =
r.reduce(op::sum{}, make_array<vec_type<value_type>>(0, 0), [&](auto x1, auto x2) {
auto x = op(x1, x2);
return make_array(x, x * x) * vec_type<value_type>{1.0 / relements};
})(input1, input2);
};
// mean(x)
auto mean_x = mean(op);
// mean(m ^ 2)
auto mean_m2 = mean([&](auto x1, auto x2) {
auto m = op(x1, x2) - mean_x;
return m * m;
});
auto mean_x = means[0];
auto mean_x2 = means[1];
auto variance = mean_x2 - (mean_x * mean_x);
r.inner([&](auto& y, auto x1, auto x2, auto... xs) {
auto m = op(x1, x2) - mean_x;
auto x = op(x1, x2);
auto m = x - mean_x;
// m * rsqrt(mean(m ^ 2) + 1e-12)
y = compute(m * rsqrt(mean_m2 + value_type{1e-12}), xs...);
y = compute(m * rsqrt(variance + value_type{1e-12}), xs...);
})(output, input1, input2, inputs...);
});
}
......
......@@ -94,16 +94,17 @@ MIGRAPHX_DPP_REDUCE(op::max, v_max)
MIGRAPHX_DPP_REDUCE(op::min, v_min)
MIGRAPHX_DPP_REDUCE(op::product, v_mul)
template <class Op, class T, class F>
__device__ auto block_reduce(index idx, Op op, T init, index_int n, F f)
template <class Op, class T, class Index, class F>
__device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
{
MIGRAPHX_ASSERT(idx.max_nlocal() == idx.nlocal());
#if __AMDGCN_WAVEFRONT_SIZE == 32
constexpr index_int lanes_per_thread = 16;
#else
constexpr index_int lanes_per_thread = 64;
#endif
using type = decltype(f(0));
__shared__ type buffer[idx.nlocal() / lanes_per_thread];
__shared__ type buffer[idx.max_nlocal() / lanes_per_thread];
type x = init;
idx.local_stride(n, [&](auto i) { x = op(x, f(i)); });
dpp_reduce(x, op);
......@@ -123,12 +124,12 @@ __device__ auto block_reduce(index idx, Op op, T init, index_int n, F f)
return y;
}
#else
template <class Op, class T, class F>
__device__ auto block_reduce(index idx, Op op, T init, index_int n, F f)
template <class Op, class T, class Index, class F>
__device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
{
MIGRAPHX_ASSERT(idx.max_nlocal() == idx.nlocal());
using type = decltype(f(0));
__shared__ type buffer[idx.nlocal()];
__shared__ type buffer[idx.max_nlocal()];
type x = init;
idx.local_stride(n, [&](auto i) { x = op(x, f(i)); });
buffer[idx.local] = x;
......@@ -201,12 +202,9 @@ struct block
__device__ auto reduce(Op op, T init, Read read) const
{
return sliced(slice, [=](auto x, auto... xs) {
return vec_reduce(block_reduce(idx,
op,
init,
x.get_shape().elements(),
[&](auto j) { return read(x[j], xs[j]...); }),
op);
return block_reduce(idx, op, init, x.get_shape().elements(), [&](auto j) {
return vec_reduce(read(x[j], xs[j]...), op);
});
});
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment