"...lm-evaluation-harness.git" did not exist on "baa34a5586385cdf86432a1bb1c1f7a48f0df2ae"
Commit 8f568801 authored by Paul's avatar Paul
Browse files

Merge branch 'develop' into layout-nhwc

parents 7393cf1e f7d987ba
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_SUB_HPP
#define MIGRAPHX_GUARD_RTGLIB_SUB_HPP
#include <migraphx/gpu/oper.hpp>
#include <migraphx/gpu/device/sub.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct hip_sub : binary_device<hip_sub, device::sub>
{
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_TAN_HPP
#define MIGRAPHX_GUARD_RTGLIB_TAN_HPP
#include <migraphx/gpu/oper.hpp>
#include <migraphx/gpu/device/tan.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct hip_tan : unary_device<hip_tan, device::tan>
{
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_TANH_HPP
#define MIGRAPHX_GUARD_RTGLIB_TANH_HPP
#include <migraphx/gpu/oper.hpp>
#include <migraphx/gpu/device/tanh.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct hip_tanh : unary_device<hip_tanh, device::tanh>
{
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_UNARY_NOT_HPP
#define MIGRAPHX_GUARD_RTGLIB_UNARY_NOT_HPP
#include <migraphx/gpu/oper.hpp>
#include <migraphx/gpu/device/unary_not.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct hip_unary_not : unary_device<hip_unary_not, device::unary_not>
{
std::string name() const { return "gpu::not"; }
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_WHERE_HPP
#define MIGRAPHX_GUARD_RTGLIB_WHERE_HPP
#include <migraphx/gpu/oper.hpp>
#include <migraphx/gpu/device/where.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct hip_where : ternary_device<hip_where, device::where>
{
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(4).same_dims();
auto s1 = inputs.at(1);
auto s2 = inputs.at(2);
if(s1 == s2 and s1.packed())
{
return s1;
}
else if(s1.packed() != s2.packed())
{
return s1.packed() ? s1 : s2;
}
else if(s1.broadcasted() != s2.broadcasted())
{
return s1.broadcasted() ? s2.with_lens(s1.lens()) : s1.with_lens(s1.lens());
}
else
{
return {s1.type(), s1.lens()};
}
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -74,7 +74,7 @@ struct concat_compiler : compiler<concat_compiler> ...@@ -74,7 +74,7 @@ struct concat_compiler : compiler<concat_compiler>
options.output = inputs.back(); options.output = inputs.back();
options.params = "-Wno-float-equal"; options.params = "-Wno-float-equal";
auto axis = find_fast_axis(options.inputs); auto axis = find_fast_axis(options.inputs);
auto vec = vectorize::elements(axis, options.inputs); auto vec = vectorize::elements(ctx, axis, options.inputs);
options.kernel_name = v.get("kernel", "concat_kernel"); options.kernel_name = v.get("kernel", "concat_kernel");
options.set_launch_params( options.set_launch_params(
v, compute_global_for(ctx, get_concat_elements(options.inputs) / vec.size, 256)); v, compute_global_for(ctx, get_concat_elements(options.inputs) / vec.size, 256));
......
...@@ -65,7 +65,7 @@ struct gathernd_compiler : compiler<gathernd_compiler> ...@@ -65,7 +65,7 @@ struct gathernd_compiler : compiler<gathernd_compiler>
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{ {
hip_compile_options options; hip_compile_options options;
auto out_s = inputs.back(); const auto& out_s = inputs.back();
options.set_launch_params(v, compute_global_for(ctx, out_s.elements())); options.set_launch_params(v, compute_global_for(ctx, out_s.elements()));
options.inputs = inputs; options.inputs = inputs;
options.output = out_s; options.output = out_s;
......
...@@ -50,9 +50,8 @@ ${preamble} ...@@ -50,9 +50,8 @@ ${preamble}
extern "C" { extern "C" {
__global__ void ${kernel}(${params}) __global__ void ${kernel}(${params})
{ {
auto idx = make_index();
transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto... xs) { transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto... xs) {
${layernorm}<${axis}>(${post}, xs...); ${layernorm}<${axis}>(${post}, ${eps}, xs...);
}); });
} }
...@@ -78,9 +77,8 @@ struct layernorm_compiler : compiler<layernorm_compiler> ...@@ -78,9 +77,8 @@ struct layernorm_compiler : compiler<layernorm_compiler>
// Vectorize if the axis is a reduction axis // Vectorize if the axis is a reduction axis
if(axis == faxis) if(axis == faxis)
{ {
vec = vectorize::elements(faxis, inputs); vec = vectorize::elements(ctx, faxis, inputs);
} }
auto preloads = preload::broadcasts(axis, inputs);
auto relements = inputs[0].lens()[axis] / vec.size; auto relements = inputs[0].lens()[axis] / vec.size;
auto nelements = (inputs.back().elements() / inputs[0].lens()[axis]); auto nelements = (inputs.back().elements() / inputs[0].lens()[axis]);
auto block_size = compute_block_size(relements, 256); auto block_size = compute_block_size(relements, 256);
...@@ -90,16 +88,18 @@ struct layernorm_compiler : compiler<layernorm_compiler> ...@@ -90,16 +88,18 @@ struct layernorm_compiler : compiler<layernorm_compiler>
options.output = inputs.back(); options.output = inputs.back();
options.inputs = inputs; options.inputs = inputs;
options.kernel_name = v.get("kernel", "layernorm_kernel"); options.kernel_name = v.get("kernel", "layernorm_kernel");
auto eps = v.get("epsilon", 1e-12f);
auto src = interpolate_string(layernorm_kernel, auto src = interpolate_string(layernorm_kernel,
{{"kernel", options.kernel_name}, {{"kernel", options.kernel_name},
{"params", enum_params(inputs.size(), "void * private_p")}, {"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")}, {"args", enum_params(inputs.size(), "private_p")},
{"transformers", make_transformer_args(preloads, vec)}, {"transformers", make_transformer_args(vec)},
{"post", v.get("post", std::string{"op::id{}"})}, {"post", v.get("post", std::string{"op::id{}"})},
{"preamble", v.get("preamble", std::string{})}, {"preamble", v.get("preamble", std::string{})},
{"layernorm", v.get("layernorm", std::string{"layernorm"})}, {"layernorm", v.get("layernorm", std::string{"layernorm"})},
{"axis", to_string(axis)}}); {"axis", to_string(axis)},
{"eps", to_string(eps)}});
return compile_hip_code_object(src, options); return compile_hip_code_object(src, options);
} }
......
...@@ -75,20 +75,16 @@ struct pointwise_compiler : compiler<pointwise_compiler> ...@@ -75,20 +75,16 @@ struct pointwise_compiler : compiler<pointwise_compiler>
options.virtual_inputs = reduce_dims(inputs); options.virtual_inputs = reduce_dims(inputs);
options.params = "-Wno-float-equal"; options.params = "-Wno-float-equal";
auto axis = find_fast_axis(options.virtual_inputs); auto axis = find_fast_axis(options.virtual_inputs);
auto vec = vectorize::elements(axis, options.virtual_inputs); auto vec = vectorize::elements(ctx, axis, options.virtual_inputs);
auto preloads = preload::broadcasts(axis, options.virtual_inputs);
options.kernel_name = v.get("kernel", "kernel"); options.kernel_name = v.get("kernel", "kernel");
options.set_launch_params( options.set_launch_params(
v, v, compute_global_for(ctx, options.output.elements() / vec.size, 256));
compute_global_for(ctx,
options.output.elements() / vec.size,
oversubscribe_if(not preloads.is_preloading())));
auto src = interpolate_string(pointwise_kernel, auto src = interpolate_string(pointwise_kernel,
{{"kernel", options.kernel_name}, {{"kernel", options.kernel_name},
{"params", enum_params(inputs.size(), "void * private_p")}, {"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")}, {"args", enum_params(inputs.size(), "private_p")},
{"lambda", v.at("lambda").to<std::string>()}, {"lambda", v.at("lambda").to<std::string>()},
{"transformers", make_transformer_args(preloads, vec)}, {"transformers", make_transformer_args(vec)},
{"preamble", v.get("preamble", std::string{})}}); {"preamble", v.get("preamble", std::string{})}});
return compile_hip_code_object(src, options); return compile_hip_code_object(src, options);
} }
......
...@@ -121,7 +121,7 @@ struct reduce_compiler : compiler<reduce_compiler> ...@@ -121,7 +121,7 @@ struct reduce_compiler : compiler<reduce_compiler>
// Vectorize if the axis is a reduction axis // Vectorize if the axis is a reduction axis
if(options.virtual_inputs.back().lens()[faxis] == 1) if(options.virtual_inputs.back().lens()[faxis] == 1)
{ {
vec = vectorize::elements(faxis, options.virtual_inputs); vec = vectorize::elements(ctx, faxis, options.virtual_inputs);
} }
auto relements = get_reduce_elements(options.virtual_inputs) / vec.size; auto relements = get_reduce_elements(options.virtual_inputs) / vec.size;
auto nelements = options.virtual_inputs.back().elements(); auto nelements = options.virtual_inputs.back().elements();
......
...@@ -32,6 +32,8 @@ namespace migraphx { ...@@ -32,6 +32,8 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_USE_FAST_SOFTMAX)
using namespace migraphx::gpu::gen; // NOLINT using namespace migraphx::gpu::gen; // NOLINT
static const char* const softmax_kernel = R"__migraphx__( static const char* const softmax_kernel = R"__migraphx__(
...@@ -69,7 +71,7 @@ struct softmax_compiler : compiler<softmax_compiler> ...@@ -69,7 +71,7 @@ struct softmax_compiler : compiler<softmax_compiler>
// Vectorize if the axis is a reduction axis // Vectorize if the axis is a reduction axis
if(faxis == axis) if(faxis == axis)
{ {
vec = vectorize::elements(faxis, inputs); vec = vectorize::elements(ctx, faxis, inputs);
} }
auto relements = inputs[0].lens()[axis] / vec.size; auto relements = inputs[0].lens()[axis] / vec.size;
auto nelements = (inputs.back().elements() / inputs[0].lens()[axis]); auto nelements = (inputs.back().elements() / inputs[0].lens()[axis]);
...@@ -81,6 +83,9 @@ struct softmax_compiler : compiler<softmax_compiler> ...@@ -81,6 +83,9 @@ struct softmax_compiler : compiler<softmax_compiler>
options.inputs = inputs; options.inputs = inputs;
options.kernel_name = "softmax_kernel"; options.kernel_name = "softmax_kernel";
if(enabled(MIGRAPHX_USE_FAST_SOFTMAX{}))
options.params = "-DMIGRAPHX_USE_FAST_SOFTMAX";
auto src = interpolate_string( auto src = interpolate_string(
softmax_kernel, softmax_kernel,
{{"transformers", make_transformer_args(vec)}, {"axis", to_string(axis)}}); {{"transformers", make_transformer_args(vec)}, {"axis", to_string(axis)}});
......
...@@ -33,49 +33,95 @@ ...@@ -33,49 +33,95 @@
namespace migraphx { namespace migraphx {
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_ARRAY_OP(op, binary_op) \ #define MIGRAPHX_DEVICE_ARRAY_OP(op, binary_op) \
template <class U> \ template <class U> \
constexpr array& operator op(const array<U, N>& x) \ constexpr array& operator op(const array<U, N>& x) \
{ \ { \
for(index_int i = 0; i < N; i++) \ array_detail::array_for_each(*this, x)([](auto& sy, auto sx) { sy op sx; }); \
d[i] op x[i]; \ return *this; \
return *this; \ } \
} \ template <class U, MIGRAPHX_REQUIRES(is_convertible<U, T>{})> \
template <class U, MIGRAPHX_REQUIRES(is_convertible<U, T>{})> \ constexpr array& operator op(const U& x) \
constexpr array& operator op(const U& x) \ { \
{ \ array_detail::array_for_each (*this)([&](auto& sy) { sy op x; }); \
for(index_int i = 0; i < N; i++) \ return *this; \
d[i] op x; \ } \
return *this; \ template <class U> \
} \ friend constexpr auto operator binary_op(const array& x, const array<U, N>& y) \
template <class U> \ { \
friend constexpr auto operator binary_op(const array& x, const array<U, N>& y) \ array<decltype(T {} binary_op U{}), N> z{}; \
{ \ array_detail::array_for_each(z, x, y)( \
array<decltype(T {} binary_op U{}), N> z{}; \ [&](auto& sz, auto sx, auto sy) { sz = sx binary_op sy; }); \
for(index_int i = 0; i < N; i++) \ return z; \
z[i] = x[i] binary_op y[i]; \ } \
return z; \ template <class U, MIGRAPHX_REQUIRES(is_convertible<U, T>{})> \
} \ friend constexpr auto operator binary_op(const array& x, const U& y) \
template <class U, MIGRAPHX_REQUIRES(is_convertible<U, T>{})> \ { \
friend constexpr auto operator binary_op(const array& x, const U& y) \ array<decltype(T {} binary_op U{}), N> z{}; \
{ \ array_detail::array_for_each(z, x)([&](auto& sz, auto sx) { sz = sx binary_op y; }); \
array<decltype(T {} binary_op U{}), N> z{}; \ return z; \
for(index_int i = 0; i < N; i++) \ } \
z[i] = x[i] binary_op y; \ template <class U, MIGRAPHX_REQUIRES(is_convertible<U, T>{})> \
return z; \ friend constexpr auto operator binary_op(const U& x, const array& y) \
} \ { \
template <class U, MIGRAPHX_REQUIRES(is_convertible<U, T>{})> \ array<decltype(T {} binary_op U{}), N> z{}; \
friend constexpr auto operator binary_op(const U& x, const array& y) \ array_detail::array_for_each(z, y)([&](auto& sz, auto sy) { sz = x binary_op sy; }); \
{ \ return z; \
array<decltype(T {} binary_op U{}), N> z{}; \
for(index_int i = 0; i < N; i++) \
z[i] = x binary_op y[i]; \
return z; \
} }
namespace array_detail {
template <class T>
constexpr auto is_vectorizable()
{
return not is_same<T, bool>{} and (is_fundamental<T>{} or is_same<T, half>{});
}
template <class T>
__device__ auto& array2vec(T& x)
{
using value_type = typename T::value_type;
constexpr auto size = decltype(x.size()){};
using type = vec<value_type, size>;
if constexpr(is_const<T>{})
return reinterpret_cast<const type&>(x);
else
return reinterpret_cast<type&>(x);
}
template <class T, class... Ts>
constexpr auto array_for_each(T& x, Ts&... xs)
{
MIGRAPHX_ASSERT(((x.size() == xs.size()) and ...));
return [&](auto f) {
constexpr auto size = decltype(x.size()){};
if constexpr((is_vectorizable<typename T::value_type>() or
(is_vectorizable<typename Ts::value_type>() or ...)) and
size <= 8 and size > 1 and (size % 2 == 0))
{
if(__builtin_is_constant_evaluated())
{
for(index_int i = 0; i < size; i++)
f(x[i], xs[i]...);
}
else
{
using vec_type = std::remove_reference_t<decltype(array2vec(x))>;
f(array2vec(x), __builtin_convertvector(array2vec(xs), vec_type)...);
}
}
else
{
for(index_int i = 0; i < size; i++)
f(x[i], xs[i]...);
}
};
}
} // namespace array_detail
template <class T, index_int N> template <class T, index_int N>
struct array struct array
{ {
using value_type = T;
T d[N]; T d[N];
constexpr T& operator[](index_int i) constexpr T& operator[](index_int i)
{ {
...@@ -108,18 +154,13 @@ struct array ...@@ -108,18 +154,13 @@ struct array
constexpr T dot(const array& x) const constexpr T dot(const array& x) const
{ {
T result = 0; auto r = x * (*this);
for(index_int i = 0; i < N; i++) return r.reduce([](auto a, auto b) { return a + b; }, 0);
result += x[i] * d[i];
return result;
} }
constexpr T product() const constexpr T product() const
{ {
T result = 1; return reduce([](auto x, auto y) { return x * y; }, 1);
for(index_int i = 0; i < N; i++)
result *= d[i];
return result;
} }
constexpr T single(index_int width = 100) const constexpr T single(index_int width = 100) const
...@@ -134,6 +175,24 @@ struct array ...@@ -134,6 +175,24 @@ struct array
return result; return result;
} }
template <class F>
constexpr auto apply(F f) const
{
array<decltype(f(d[0])), N> result;
for(index_int i = 0; i < N; i++)
result[i] = f(d[i]);
return result;
}
template <class F>
constexpr auto reduce(F f, T init) const
{
T result = init;
for(index_int i = 0; i < N; i++)
result = f(result, d[i]);
return result;
}
MIGRAPHX_DEVICE_ARRAY_OP(+=, +) MIGRAPHX_DEVICE_ARRAY_OP(+=, +)
MIGRAPHX_DEVICE_ARRAY_OP(-=, -) MIGRAPHX_DEVICE_ARRAY_OP(-=, -)
MIGRAPHX_DEVICE_ARRAY_OP(*=, *) MIGRAPHX_DEVICE_ARRAY_OP(*=, *)
...@@ -201,6 +260,11 @@ struct array ...@@ -201,6 +260,11 @@ struct array
} }
}; };
template <class T, class... Ts>
constexpr array<T, sizeof...(Ts) + 1> make_array(T x, Ts... xs)
{
return {x, static_cast<T>(xs)...};
}
template <class T, T... Xs> template <class T, T... Xs>
struct integral_const_array : array<T, sizeof...(Xs)> struct integral_const_array : array<T, sizeof...(Xs)>
{ {
......
...@@ -28,9 +28,60 @@ ...@@ -28,9 +28,60 @@
#include <migraphx/kernels/types.hpp> #include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/integral_constant.hpp> #include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/type_traits.hpp> #include <migraphx/kernels/type_traits.hpp>
#include <migraphx/kernels/debug.hpp>
namespace migraphx { namespace migraphx {
#if defined(MIGRAPHX_NGLOBAL) && defined(MIGRAPHX_NLOCAL)
#define MIGRAPHX_NGROUP ((MIGRAPHX_NGLOBAL + MIGRAPHX_NLOCAL - 1) / MIGRAPHX_NLOCAL)
#endif
inline __device__ __attribute__((const)) index_int compute_global_size()
{
#ifdef MIGRAPHX_NGLOBAL
return MIGRAPHX_NGLOBAL;
#else
// This actualy works even when global is not divisible by local size.
// This doesnt actually do a multiplicatiosn. Instead it calls a device
// function to get the global size, which is why it works.
return blockDim.x * gridDim.x; // NOLINT
#endif
}
// We cant just use blockDim.x to get the local size since its broken on hip
// when global is not divisible by local size. In this case, we calulate the
// size for the last group.
inline __device__ __attribute__((const)) index_int compute_local_size()
{
#ifdef MIGRAPHX_NLOCAL
const auto nlocal = MIGRAPHX_NLOCAL;
#else
const auto nlocal = blockDim.x; // NOLINT
#endif
#ifdef MIGRAPHX_NGROUP
const auto ngroup = MIGRAPHX_NGROUP;
#else
const auto ngroup = gridDim.x; // NOLINT
#endif
const auto group_id = blockIdx.x; // NOLINT
const auto nglobal = compute_global_size();
if(group_id == ngroup - 1)
{
return 1 + (nglobal - 1) % nlocal;
}
else
{
return nlocal; // NOLINT
}
}
#ifdef MIGRAPHX_NGROUP
// If global is divisible by local then local can be a const
#if(MIGRAPHX_NGLOBAL % MIGRAPHX_NLOCAL == 0) || (MIGRAPHX_NGROUP == 1)
#define MIGRAPHX_HAS_CONST_LOCAL 1
#endif
#endif
struct index struct index
{ {
index_int global = 0; index_int global = 0;
...@@ -38,20 +89,44 @@ struct index ...@@ -38,20 +89,44 @@ struct index
index_int group = 0; index_int group = 0;
#ifdef MIGRAPHX_NGLOBAL #ifdef MIGRAPHX_NGLOBAL
constexpr index_constant<MIGRAPHX_NGLOBAL> nglobal() const { return {}; } constexpr index_constant<MIGRAPHX_NGLOBAL> nglobal() const
{
static_assert(MIGRAPHX_NGLOBAL > 0, "Global size must be greater than 0");
return {};
}
#else #else
__device__ index_int nglobal() const __device__ index_int nglobal() const
{ {
return blockDim.x * gridDim.x; // NOLINT MIGRAPHX_ASSERT(compute_global_size() > 0);
return compute_global_size(); // NOLINT
} }
#endif #endif
#ifdef MIGRAPHX_NLOCAL #ifdef MIGRAPHX_HAS_CONST_LOCAL
constexpr index_constant<MIGRAPHX_NLOCAL> nlocal() const { return {}; } constexpr index_constant<MIGRAPHX_NLOCAL> nlocal() const
{
static_assert(MIGRAPHX_NLOCAL > 0, "Local size must be greater than 0");
return {};
}
#else #else
__device__ index_int nlocal() const __device__ index_int nlocal() const
{ {
return blockDim.x; // NOLINT #ifdef MIGRAPHX_NGROUP
static_assert((MIGRAPHX_NGLOBAL % MIGRAPHX_NLOCAL != 0) and (MIGRAPHX_NGROUP > 1),
"Local size should be const");
#endif
MIGRAPHX_ASSERT(compute_local_size() > 0);
return compute_local_size(); // NOLINT
}
#endif
#ifdef MIGRAPHX_NLOCAL
constexpr index_constant<MIGRAPHX_NLOCAL> max_nlocal() const { return {}; }
#else
__device__ index_int max_nlocal() const
{
MIGRAPHX_ASSERT(blockDim.x > 0);
return blockDim.x;
} }
#endif #endif
template <class N, class Stride> template <class N, class Stride>
...@@ -63,6 +138,7 @@ struct index ...@@ -63,6 +138,7 @@ struct index
template <class F, class N, class Stride> template <class F, class N, class Stride>
static constexpr void for_stride(index_int start, N n, Stride stride, F f) static constexpr void for_stride(index_int start, N n, Stride stride, F f)
{ {
MIGRAPHX_ASSERT(start < stride);
if constexpr(not is_integral<N>{} and not is_integral<Stride>{} and if constexpr(not is_integral<N>{} and not is_integral<Stride>{} and
max_stride_iterations(n, stride) == 1) max_stride_iterations(n, stride) == 1)
{ {
......
...@@ -29,6 +29,12 @@ ...@@ -29,6 +29,12 @@
namespace migraphx { namespace migraphx {
template <class T, index_int N, class Op>
constexpr auto vec_reduce(const array<T, N>& a, Op op)
{
return a.apply([&](auto x) { return vec_reduce(x, op); });
}
template <index_int Axis, template <index_int Axis,
class F, class F,
class BinOp, class BinOp,
...@@ -37,46 +43,46 @@ template <index_int Axis, ...@@ -37,46 +43,46 @@ template <index_int Axis,
class Input2, class Input2,
class... Inputs> class... Inputs>
__device__ void generic_binary_layernorm( __device__ void generic_binary_layernorm(
F compute, BinOp op, Output output, Input1 input1, Input2 input2, Inputs... inputs) F compute, BinOp op, float eps, Output output, Input1 input1, Input2 input2, Inputs... inputs)
{ {
using reduce_output = reduce::with_axis<Input1, Axis>; using reduce_output = reduce::with_axis<Input1, Axis>;
reduce::block::run<reduce_output>([&](auto, auto r) { reduce::block::run<reduce_output>([&](auto, auto r) {
using value_type = typename Input1::type; using value_type = typename Input1::type;
constexpr auto relements = r.template elements<Input1>(); constexpr auto relements = r.template elements<Input1>();
auto mean = [&](auto f) { auto means =
return r.reduce(op::sum{}, 0, [&](auto x1, auto x2) { r.reduce(op::sum{}, make_array<vec_type<value_type>>(0, 0), [&](auto x1, auto x2) {
return f(x1, x2) / value_type{relements}; auto x = op(x1, x2);
return make_array(x, x * x) * vec_type<value_type>{1.0 / relements};
})(input1, input2); })(input1, input2);
};
// mean(x) auto mean_x = means[0];
auto mean_x = mean(op); auto mean_x2 = means[1];
// mean(m ^ 2) auto variance = mean_x2 - (mean_x * mean_x);
auto mean_m2 = mean([&](auto x1, auto x2) { value_type eps_val = eps; // implicit conversion for eps
auto m = op(x1, x2) - mean_x;
return m * m;
});
r.inner([&](auto& y, auto x1, auto x2, auto... xs) { r.inner([&](auto& y, auto x1, auto x2, auto... xs) {
auto m = op(x1, x2) - mean_x; auto x = op(x1, x2);
// m * rsqrt(mean(m ^ 2) + 1e-12) auto m = x - mean_x;
y = compute(m * rsqrt(mean_m2 + value_type{1e-12}), xs...);
// m * rsqrt(mean(m ^ 2) + epsilon)
y = compute(m * rsqrt(variance + eps_val), xs...);
})(output, input1, input2, inputs...); })(output, input1, input2, inputs...);
}); });
} }
template <index_int Axis, class F, class Output, class Input, class... Inputs> template <index_int Axis, class F, class Output, class Input, class... Inputs>
__device__ void layernorm(F compute, Output output, Input input, Inputs... inputs) __device__ void layernorm(F compute, float eps, Output output, Input input, Inputs... inputs)
{ {
generic_binary_layernorm<Axis>( generic_binary_layernorm<Axis>(
compute, [](auto x, auto) { return x; }, output, input, input, inputs...); compute, [](auto x, auto) { return x; }, eps, output, input, input, inputs...);
} }
template <index_int Axis, class F, class Output, class Input1, class Input2, class... Inputs> template <index_int Axis, class F, class Output, class Input1, class Input2, class... Inputs>
__device__ void __device__ void
add_layernorm(F compute, Output output, Input1 input1, Input2 input2, Inputs... inputs) add_layernorm(F compute, float eps, Output output, Input1 input1, Input2 input2, Inputs... inputs)
{ {
generic_binary_layernorm<Axis>( generic_binary_layernorm<Axis>(
compute, [](auto x1, auto x2) { return x1 + x2; }, output, input1, input2, inputs...); compute, [](auto x1, auto x2) { return x1 + x2; }, eps, output, input1, input2, inputs...);
} }
} // namespace migraphx } // namespace migraphx
......
...@@ -104,6 +104,7 @@ MIGRAPHX_DEVICE_MATH(floor, ::floor) ...@@ -104,6 +104,7 @@ MIGRAPHX_DEVICE_MATH(floor, ::floor)
MIGRAPHX_DEVICE_MATH(isnan, ::isnan) MIGRAPHX_DEVICE_MATH(isnan, ::isnan)
MIGRAPHX_DEVICE_MATH(log, ::log) MIGRAPHX_DEVICE_MATH(log, ::log)
MIGRAPHX_DEVICE_MATH(pow, ::pow) MIGRAPHX_DEVICE_MATH(pow, ::pow)
MIGRAPHX_DEVICE_MATH(remainder, ::remainder)
MIGRAPHX_DEVICE_MATH(round, ::round) MIGRAPHX_DEVICE_MATH(round, ::round)
MIGRAPHX_DEVICE_MATH(rsqrt, ::rsqrt) MIGRAPHX_DEVICE_MATH(rsqrt, ::rsqrt)
MIGRAPHX_DEVICE_MATH(sin, ::sin) MIGRAPHX_DEVICE_MATH(sin, ::sin)
...@@ -111,6 +112,7 @@ MIGRAPHX_DEVICE_MATH(sinh, ::sinh) ...@@ -111,6 +112,7 @@ MIGRAPHX_DEVICE_MATH(sinh, ::sinh)
MIGRAPHX_DEVICE_MATH(sqrt, ::sqrt) MIGRAPHX_DEVICE_MATH(sqrt, ::sqrt)
MIGRAPHX_DEVICE_MATH(tan, ::tan) MIGRAPHX_DEVICE_MATH(tan, ::tan)
MIGRAPHX_DEVICE_MATH(tanh, ::tanh) MIGRAPHX_DEVICE_MATH(tanh, ::tanh)
MIGRAPHX_DEVICE_MATH(fmod, ::fmod)
// Float overloads // Float overloads
MIGRAPHX_DEVICE_MATH_FOR(float, acos, ::acosf) MIGRAPHX_DEVICE_MATH_FOR(float, acos, ::acosf)
...@@ -126,6 +128,7 @@ MIGRAPHX_DEVICE_MATH_FOR(float, sin, ::sinf) ...@@ -126,6 +128,7 @@ MIGRAPHX_DEVICE_MATH_FOR(float, sin, ::sinf)
MIGRAPHX_DEVICE_MATH_FOR(float, sinh, ::sinhf) MIGRAPHX_DEVICE_MATH_FOR(float, sinh, ::sinhf)
MIGRAPHX_DEVICE_MATH_FOR(float, tan, ::tanf) MIGRAPHX_DEVICE_MATH_FOR(float, tan, ::tanf)
MIGRAPHX_DEVICE_MATH_FOR(float, tanh, ::tanhf) MIGRAPHX_DEVICE_MATH_FOR(float, tanh, ::tanhf)
MIGRAPHX_DEVICE_MATH_FOR(float, fmod, ::fmodf)
// Builtin half functions // Builtin half functions
MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, abs, ::__habs) MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, abs, ::__habs)
...@@ -148,11 +151,13 @@ MIGRAPHX_DEVICE_MATH_HALF(erf, ::erf) ...@@ -148,11 +151,13 @@ MIGRAPHX_DEVICE_MATH_HALF(erf, ::erf)
MIGRAPHX_DEVICE_MATH_HALF(floor, ::floor) MIGRAPHX_DEVICE_MATH_HALF(floor, ::floor)
MIGRAPHX_DEVICE_MATH_HALF(isnan, ::isnan) MIGRAPHX_DEVICE_MATH_HALF(isnan, ::isnan)
MIGRAPHX_DEVICE_MATH_HALF(pow, ::pow) MIGRAPHX_DEVICE_MATH_HALF(pow, ::pow)
MIGRAPHX_DEVICE_MATH_HALF(remainder, ::remainder)
MIGRAPHX_DEVICE_MATH_HALF(round, ::round) MIGRAPHX_DEVICE_MATH_HALF(round, ::round)
MIGRAPHX_DEVICE_MATH_HALF(sin, ::sin) MIGRAPHX_DEVICE_MATH_HALF(sin, ::sin)
MIGRAPHX_DEVICE_MATH_HALF(sinh, ::sinh) MIGRAPHX_DEVICE_MATH_HALF(sinh, ::sinh)
MIGRAPHX_DEVICE_MATH_HALF(tan, ::tan) MIGRAPHX_DEVICE_MATH_HALF(tan, ::tan)
MIGRAPHX_DEVICE_MATH_HALF(tanh, ::tanh) MIGRAPHX_DEVICE_MATH_HALF(tanh, ::tanh)
MIGRAPHX_DEVICE_MATH_HALF(fmod, ::fmod)
// Map math functions to hip half2 functions // Map math functions to hip half2 functions
// The half2 type is defined in include/hip/amd_detail/hip_fp16_gcc.h and is 2 16-bit floats // The half2 type is defined in include/hip/amd_detail/hip_fp16_gcc.h and is 2 16-bit floats
...@@ -226,11 +231,13 @@ MIGRAPHX_DEVICE_MATH_VEC(cosh) ...@@ -226,11 +231,13 @@ MIGRAPHX_DEVICE_MATH_VEC(cosh)
MIGRAPHX_DEVICE_MATH_VEC(erf) MIGRAPHX_DEVICE_MATH_VEC(erf)
MIGRAPHX_DEVICE_MATH_VEC(exp) MIGRAPHX_DEVICE_MATH_VEC(exp)
MIGRAPHX_DEVICE_MATH_VEC(floor) MIGRAPHX_DEVICE_MATH_VEC(floor)
MIGRAPHX_DEVICE_MATH_VEC(fmod)
MIGRAPHX_DEVICE_MATH_VEC(isnan) MIGRAPHX_DEVICE_MATH_VEC(isnan)
MIGRAPHX_DEVICE_MATH_VEC(log) MIGRAPHX_DEVICE_MATH_VEC(log)
MIGRAPHX_DEVICE_MATH_VEC(max) MIGRAPHX_DEVICE_MATH_VEC(max)
MIGRAPHX_DEVICE_MATH_VEC(min) MIGRAPHX_DEVICE_MATH_VEC(min)
MIGRAPHX_DEVICE_MATH_VEC(pow) MIGRAPHX_DEVICE_MATH_VEC(pow)
MIGRAPHX_DEVICE_MATH_VEC(remainder)
MIGRAPHX_DEVICE_MATH_VEC(round) MIGRAPHX_DEVICE_MATH_VEC(round)
MIGRAPHX_DEVICE_MATH_VEC(rsqrt) MIGRAPHX_DEVICE_MATH_VEC(rsqrt)
MIGRAPHX_DEVICE_MATH_VEC(sin) MIGRAPHX_DEVICE_MATH_VEC(sin)
......
...@@ -94,16 +94,17 @@ MIGRAPHX_DPP_REDUCE(op::max, v_max) ...@@ -94,16 +94,17 @@ MIGRAPHX_DPP_REDUCE(op::max, v_max)
MIGRAPHX_DPP_REDUCE(op::min, v_min) MIGRAPHX_DPP_REDUCE(op::min, v_min)
MIGRAPHX_DPP_REDUCE(op::product, v_mul) MIGRAPHX_DPP_REDUCE(op::product, v_mul)
template <class Op, class T, class F> template <class Op, class T, class Index, class F>
__device__ auto block_reduce(index idx, Op op, T init, index_int n, F f) __device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
{ {
MIGRAPHX_ASSERT(idx.max_nlocal() == idx.nlocal());
#if __AMDGCN_WAVEFRONT_SIZE == 32 #if __AMDGCN_WAVEFRONT_SIZE == 32
constexpr index_int lanes_per_thread = 16; constexpr index_int lanes_per_thread = 16;
#else #else
constexpr index_int lanes_per_thread = 64; constexpr index_int lanes_per_thread = 64;
#endif #endif
using type = decltype(f(0)); using type = decltype(f(0));
__shared__ type buffer[idx.nlocal() / lanes_per_thread]; __shared__ type buffer[idx.max_nlocal() / lanes_per_thread];
type x = init; type x = init;
idx.local_stride(n, [&](auto i) { x = op(x, f(i)); }); idx.local_stride(n, [&](auto i) { x = op(x, f(i)); });
dpp_reduce(x, op); dpp_reduce(x, op);
...@@ -123,12 +124,12 @@ __device__ auto block_reduce(index idx, Op op, T init, index_int n, F f) ...@@ -123,12 +124,12 @@ __device__ auto block_reduce(index idx, Op op, T init, index_int n, F f)
return y; return y;
} }
#else #else
template <class Op, class T, class F> template <class Op, class T, class Index, class F>
__device__ auto block_reduce(index idx, Op op, T init, index_int n, F f) __device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
{ {
MIGRAPHX_ASSERT(idx.max_nlocal() == idx.nlocal());
using type = decltype(f(0)); using type = decltype(f(0));
__shared__ type buffer[idx.nlocal()]; __shared__ type buffer[idx.max_nlocal()];
type x = init; type x = init;
idx.local_stride(n, [&](auto i) { x = op(x, f(i)); }); idx.local_stride(n, [&](auto i) { x = op(x, f(i)); });
buffer[idx.local] = x; buffer[idx.local] = x;
...@@ -196,17 +197,14 @@ struct block ...@@ -196,17 +197,14 @@ struct block
struct reducer struct reducer
{ {
index idx; index idx;
Slicer slicer; Slicer slice;
template <class Op, class T, class Read> template <class Op, class T, class Read>
__device__ auto reduce(Op op, T init, Read read) const __device__ auto reduce(Op op, T init, Read read) const
{ {
return sliced(slicer, [=](auto x, auto... xs) { return sliced(slice, [=](auto x, auto... xs) {
return vec_reduce(block_reduce(idx, return block_reduce(idx, op, init, x.get_shape().elements(), [&](auto j) {
op, return vec_reduce(read(x[j], xs[j]...), op);
init, });
x.get_shape().elements(),
[&](auto j) { return read(x[j], xs[j]...); }),
op);
}); });
} }
...@@ -220,7 +218,7 @@ struct block ...@@ -220,7 +218,7 @@ struct block
template <class F> template <class F>
__device__ auto inner(F f) const __device__ auto inner(F f) const
{ {
return sliced(slicer, [=](auto x, auto... xs) { return sliced(slice, [=](auto x, auto... xs) {
idx.local_stride(x.get_shape().elements(), [&](auto j) { f(x[j], xs[j]...); }); idx.local_stride(x.get_shape().elements(), [&](auto j) { f(x[j], xs[j]...); });
}); });
} }
...@@ -228,7 +226,7 @@ struct block ...@@ -228,7 +226,7 @@ struct block
template <class Input> template <class Input>
constexpr auto elements() const constexpr auto elements() const
{ {
using reduce_type = decltype(slicer(Input{})); using reduce_type = decltype(slice(Input{}));
using value_type = typename Input::type; using value_type = typename Input::type;
constexpr auto relements = get_shape_c<reduce_type>{}.elements(); constexpr auto relements = get_shape_c<reduce_type>{}.elements();
if constexpr(vec_size<value_type>() > 1) if constexpr(vec_size<value_type>() > 1)
...@@ -262,11 +260,11 @@ struct lane ...@@ -262,11 +260,11 @@ struct lane
struct reducer struct reducer
{ {
index idx; index idx;
Slicer slicer; Slicer slice;
template <class Op, class T, class Read> template <class Op, class T, class Read>
__device__ auto reduce(Op op, T init, Read read) const __device__ auto reduce(Op op, T init, Read read) const
{ {
return sliced(slicer, [=](auto x, auto... xs) { return sliced(slice, [=](auto x, auto... xs) {
using type = typename decltype(x)::type; using type = typename decltype(x)::type;
type r = init; type r = init;
for(index_int j = 0; j < x.get_shape().elements(); j++) for(index_int j = 0; j < x.get_shape().elements(); j++)
...@@ -286,7 +284,7 @@ struct lane ...@@ -286,7 +284,7 @@ struct lane
template <class F> template <class F>
__device__ auto inner(F f) const __device__ auto inner(F f) const
{ {
return sliced(slicer, [=](auto x, auto... xs) { return sliced(slice, [=](auto x, auto... xs) {
for(index_int j = 0; j < x.get_shape().elements(); j++) for(index_int j = 0; j < x.get_shape().elements(); j++)
{ {
f(x[j], xs[j]...); f(x[j], xs[j]...);
...@@ -297,7 +295,7 @@ struct lane ...@@ -297,7 +295,7 @@ struct lane
template <class Input> template <class Input>
constexpr auto elements() const constexpr auto elements() const
{ {
using reduce_type = decltype(slicer(Input{})); using reduce_type = decltype(slice(Input{}));
return get_shape_c<reduce_type>{}.elements(); return get_shape_c<reduce_type>{}.elements();
} }
}; };
......
...@@ -33,11 +33,15 @@ template <index_int Axis, class Input, class Output> ...@@ -33,11 +33,15 @@ template <index_int Axis, class Input, class Output>
__device__ void softmax(Input input, Output output) __device__ void softmax(Input input, Output output)
{ {
reduce::block::run<reduce::with_axis<Input, Axis>>([&](auto, auto r) { reduce::block::run<reduce::with_axis<Input, Axis>>([&](auto, auto r) {
auto batch_max = r.reduce(op::max{}, lowest{}, op::id{})(input); #ifdef MIGRAPHX_USE_FAST_SOFTMAX
auto batch_sum = const auto c = vec_at(r.slice(input)[0], 0);
r.reduce(op::sum{}, 0, [&](auto x) { return migraphx::exp(x - batch_max); })(input); #else
r.inner([&](auto& y, auto x) { y = migraphx::exp(x - batch_max) / batch_sum; })(output, const auto c = r.reduce(op::max{}, lowest{}, op::id{})(input);
input); #endif
auto batch_sum = r.reduce(op::sum{}, 0, [&](auto x) {
return migraphx::convert<float>(migraphx::exp(x - c));
})(input);
r.inner([&](auto& y, auto x) { y = migraphx::exp(x - c) / batch_sum; })(output, input);
}); });
} }
......
...@@ -26,6 +26,8 @@ ...@@ -26,6 +26,8 @@
#include <migraphx/manage_ptr.hpp> #include <migraphx/manage_ptr.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/op/convolution.hpp> #include <migraphx/op/convolution.hpp>
#include <migraphx/op/deconvolution.hpp> #include <migraphx/op/deconvolution.hpp>
...@@ -81,77 +83,21 @@ struct miopen_apply ...@@ -81,77 +83,21 @@ struct miopen_apply
(void)i; (void)i;
} }
const std::unordered_set<std::string>& get_rocblas_fp32_archs()
{
static std::unordered_set<std::string> supported_archs{"gfx908", "gfx90a"};
return supported_archs;
}
void init() void init()
{ {
assert(mod != nullptr); assert(mod != nullptr);
assert(pass != nullptr); assert(pass != nullptr);
#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38 auto& ctx = get_context();
auto& ctx = get_context(); int8_x4_format = get_int8_x4_format(ctx);
const auto device_name = trim(split_string(get_device_name(), ':').front()); compute_fp32 = get_compute_fp32_flag();
if(contains(get_rocblas_fp32_archs(), device_name))
compute_fp32 = true;
rocblas_gemm_flags flag;
rocblas_query_int8_layout_flag(ctx.get_stream().get_rocblas(), &flag);
int8_x4_format = (flag == rocblas_gemm_flags_pack_int8x4);
#endif
offload_copy = (mod->name() == "main") ? pass->offload_copy : false; offload_copy = (mod->name() == "main") ? pass->offload_copy : false;
add_generic_op("acos");
add_generic_op("acosh");
add_generic_op("add");
add_generic_op("asin");
add_generic_op("asinh");
add_generic_op("atan");
add_generic_op("atanh");
add_generic_op("ceil");
add_generic_op("contiguous"); add_generic_op("contiguous");
add_generic_op("cos");
add_generic_op("cosh");
add_generic_op("div");
add_generic_op("equal");
add_generic_op("erf");
add_generic_op("exp");
add_generic_op("floor");
add_generic_op("greater");
add_generic_op("less");
add_generic_op("log");
add_generic_op("logical_and");
add_generic_op("logical_or");
add_generic_op("logical_xor");
add_generic_op("max");
add_generic_op("min");
add_generic_op("mul");
add_generic_op("not");
add_generic_op("pow");
add_generic_op("prelu");
add_generic_op("recip");
add_generic_op("relu");
add_generic_op("round");
add_generic_op("rsqrt");
add_generic_op("sigmoid");
add_generic_op("sign");
add_generic_op("sin");
add_generic_op("sinh");
add_generic_op("sqdiff");
add_generic_op("sqrt");
add_generic_op("sub");
add_generic_op("tan");
add_generic_op("tanh");
add_generic_op("where");
add_extend_op("abs");
add_extend_op("argmax"); add_extend_op("argmax");
add_extend_op("argmin"); add_extend_op("argmin");
add_extend_op("clip");
add_extend_op("convert");
add_extend_op("elu"); add_extend_op("elu");
add_extend_op("gather"); add_extend_op("gather");
add_extend_op("leaky_relu"); add_extend_op("leaky_relu");
...@@ -227,7 +173,8 @@ struct miopen_apply ...@@ -227,7 +173,8 @@ struct miopen_apply
init(); init();
for(auto it = mod->begin(); it != mod->end(); it++) for(auto it = mod->begin(); it != mod->end(); it++)
{ {
auto s = it->get_shape(); auto s = it->get_shape();
auto attrs = it->get_operator().attributes();
if(apply_map.count(it->name()) > 0) if(apply_map.count(it->name()) > 0)
{ {
check_shape(s, apply_map.at(it->name())(it)); check_shape(s, apply_map.at(it->name())(it));
...@@ -236,11 +183,37 @@ struct miopen_apply ...@@ -236,11 +183,37 @@ struct miopen_apply
{ {
check_shape(s, insert_precompile_op(it)); check_shape(s, insert_precompile_op(it));
} }
else if(attrs.contains("target"))
{
check_shape(s, insert_custom_op(it, attrs));
}
} }
copy_params(); copy_params();
} }
instruction_ref insert_custom_op(instruction_ref ins, const value& attrs) const
{
const auto& custom_op = ins->get_operator();
if(attrs.at("target") == "cpu")
{
auto s = ins->get_shape();
std::vector<instruction_ref> cpu_inputs;
auto inputs = ins->inputs();
auto output = inputs.back();
std::transform(
inputs.begin(), inputs.end(), std::back_inserter(cpu_inputs), [&](auto in) {
return mod->insert_instruction(ins, make_op("hip::copy_from_gpu"), in);
});
cpu_inputs.front() =
mod->insert_instruction(ins, make_op("hip::sync_stream"), cpu_inputs);
auto cpu_out = mod->insert_instruction(ins, custom_op, cpu_inputs);
auto gpu_out =
mod->insert_instruction(ins, make_op("hip::copy_to_gpu"), cpu_out, output);
return mod->replace_instruction(ins, gpu_out);
}
return ins;
}
instruction_ref insert_precompile_op(instruction_ref ins) const instruction_ref insert_precompile_op(instruction_ref ins) const
{ {
auto output = insert_allocation(ins, ins->get_shape()); auto output = insert_allocation(ins, ins->get_shape());
......
...@@ -35,6 +35,12 @@ namespace { ...@@ -35,6 +35,12 @@ namespace {
template <class Derived, std::size_t N> template <class Derived, std::size_t N>
struct layernorm_base struct layernorm_base
{ {
float epsilon = 1e-12f;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.epsilon, "epsilon"));
}
shape compute_shape(std::vector<shape> inputs, std::vector<module_ref> mods) const shape compute_shape(std::vector<shape> inputs, std::vector<module_ref> mods) const
{ {
std::size_t nargs = 1; std::size_t nargs = 1;
...@@ -62,6 +68,7 @@ struct layernorm_base ...@@ -62,6 +68,7 @@ struct layernorm_base
struct layernorm : layernorm_base<layernorm, 0> struct layernorm : layernorm_base<layernorm, 0>
{ {
std::string name() const { return "gpu::prelayernorm"; } std::string name() const { return "gpu::prelayernorm"; }
}; };
MIGRAPHX_REGISTER_OP(layernorm); MIGRAPHX_REGISTER_OP(layernorm);
...@@ -80,8 +87,9 @@ struct find_layernorm ...@@ -80,8 +87,9 @@ struct find_layernorm
{ {
auto ins = r.result; auto ins = r.result;
auto x_ins = r.instructions["x"]; auto x_ins = r.instructions["x"];
auto eps = r.instructions["eps"]->eval().at<float>();
m.replace_instruction(ins, layernorm{}, x_ins); m.replace_instruction(ins, layernorm{eps}, x_ins);
} }
}; };
...@@ -96,8 +104,9 @@ struct find_add_layernorm ...@@ -96,8 +104,9 @@ struct find_add_layernorm
{ {
auto ins = r.result; auto ins = r.result;
auto add_ins = r.instructions["add"]; auto add_ins = r.instructions["add"];
auto eps = r.instructions["eps"]->eval().at<float>();
m.replace_instruction(ins, add_layernorm{}, add_ins->inputs()); m.replace_instruction(ins, add_layernorm{eps}, add_ins->inputs());
} }
}; };
} // namespace } // namespace
......
...@@ -22,7 +22,6 @@ ...@@ -22,7 +22,6 @@
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <migraphx/gpu/quant_convolution.hpp> #include <migraphx/gpu/quant_convolution.hpp>
#include <migraphx/gpu/device/convert.hpp>
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment