"vscode:/vscode.git/clone" did not exist on "41a94e13a633ba61e0c7f3fa686440570fa1bdb8"
Unverified commit b73427c9 authored by Chris Austen, committed by GitHub

Merge branch 'develop' into fix_for_multiconfig_generators

parents 55e635e5 4c059fa3
@@ -49,6 +49,12 @@ std::string get_device_name()
    return props.gcnArchName;
}

+bool gfx_has_fp8_intrinsics()
+{
+    const auto device_name = trim(split_string(get_device_name(), ':').front());
+    return (starts_with(device_name, "gfx9") and device_name >= "gfx940");
+}
+
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
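For context (not part of the diff): the new check relies on plain lexicographic string comparison of the gcnArchName prefix, which is enough to separate the gfx94x targets from older gfx9 parts. A hedged illustration (the device names are examples only):

// Illustrative only: "gfx90a" < "gfx940" lexicographically, so it fails the check,
// while "gfx940", "gfx941" and "gfx942" satisfy both conditions.
gfx_has_fp8_intrinsics(); // true on gfx94x, false on gfx906/gfx90a and non-gfx9 targets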
@@ -22,11 +22,14 @@
 * THE SOFTWARE.
 */

+#include <rocblas/internal/rocblas-types.h>
#include <rocblas/rocblas.h>
+#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/gemm_impl.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/time.hpp>
+#include <type_traits>

using microseconds = std::chrono::duration<double, std::micro>;
@@ -34,6 +37,20 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

+/*
+The regular rocBLAS API takes compute_type as a `rocblas_datatype` enum value, whereas the "ex3"
+BETA API takes it as a `rocblas_computetype` enum value. `rb_compute_type` is a facilitator that
+implicitly casts the stored integer enum value to whichever type is required inside the
+`common_args` generator.
+*/
+struct rb_compute_type
+{
+    int type = 0;
+    rb_compute_type(rocblas_datatype t) : type(static_cast<int>(t)) {}
+    rb_compute_type(rocblas_computetype t) : type(static_cast<int>(t)) {}
+    operator rocblas_datatype() const { return static_cast<rocblas_datatype>(type); }
+    operator rocblas_computetype() const { return static_cast<rocblas_computetype>(type); }
+};
+
// Convert rocBLAS datatypes to equivalent Migraphx data types
rocblas_datatype get_type(shape::type_t type)
{
@@ -46,7 +63,7 @@ rocblas_datatype get_type(shape::type_t type)
    case shape::uint8_type: return rocblas_datatype_u8_r;
    case shape::int32_type: return rocblas_datatype_i32_r;
    case shape::uint32_type: return rocblas_datatype_u32_r;
-   case shape::fp8e4m3fnuz_type:
+   case shape::fp8e4m3fnuz_type: return rocblas_datatype_f8_r;
    case shape::tuple_type:
    case shape::bool_type:
    case shape::uint16_type:
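For context (not part of the diff): a minimal sketch of how the rb_compute_type shim above is meant to be used. The variable names are illustrative assumptions; only the conversions themselves come from the struct definition.

// Hypothetical illustration: one stored value converts implicitly to whichever enum
// the selected rocBLAS entry point expects.
rb_compute_type ct = rocblas_datatype_f32_r;  // constructed from the legacy datatype enum
rocblas_datatype legacy_ct = ct;              // form taken by rocblas_gemm_ex / _strided_batched_ex
rocblas_computetype ex3_ct = ct;              // form taken by the "ex3" BETA API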
@@ -183,12 +200,17 @@ struct gemm_impl
        {
            output_type = rocblas_datatype_i32_r;
        }
-       compute_type = output_type;
+       compute_type = rb_compute_type{output_type};
        if(compute_fp32)
        {
            if(arg_type == rocblas_datatype_f16_r)
                compute_type = rocblas_datatype_f32_r;
        }
+       if(arg_type == rocblas_datatype_f8_r)
+       {
+           assert(get_type(input_shapes[1].type()) == rocblas_datatype_f8_r);
+           compute_type = rocblas_compute_type_f32;
+       }

        auto a_lens = input_shapes[0].lens();
        auto b_lens = input_shapes[1].lens();
@@ -216,6 +238,34 @@ struct gemm_impl
    }

    void run(context& ctx, const std::vector<argument>& input_args, int32_t solution_idx = 0) const
+   {
+#ifdef MIGRAPHX_USE_ROCBLAS_FP8_API
+       if(rocblas_fp8_available() and
+          std::any_of(input_args.begin(), input_args.end(), [](const auto i) {
+              return i.get_shape().type() == migraphx::shape::fp8e4m3fnuz_type;
+          }))
+       {
+           if(strided_batched)
+           {
+               auto common_args = create_strided_batched_args_common(ctx, input_args);
+               rocblas_invoke(&rocblas_gemm_strided_batched_ex3,
+                              common_args,
+                              rocblas_gemm_algo_standard,
+                              solution_idx,
+                              gemm_flags);
+           }
+           else
+           {
+               auto common_args = create_gemm_ex_args_common(ctx, input_args);
+               rocblas_invoke(&rocblas_gemm_ex3,
+                              common_args,
+                              rocblas_gemm_algo_standard,
+                              solution_idx,
+                              gemm_flags);
+           }
+       }
+       else
+#endif
    {
        if(strided_batched)
        {
@@ -236,6 +286,7 @@ struct gemm_impl
                           gemm_flags);
            }
        }
+   }

#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
    auto validate(context& ctx, const std::vector<shape>& input_shapes, int32_t solution_idx) const
@@ -331,7 +382,6 @@ struct gemm_impl
                       num_matrices,
                       compute_type);
    }
    /**
     * Helper method to create that subset of a long rocBLAS argument list that is common
     * to multiple "gemm_ex..." calls.
@@ -366,6 +416,7 @@ struct gemm_impl
                       ldd,
                       compute_type);
    }

#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
    /**
     * Find best rocBLAS solution: Get list of solutions and try them all, returning the index
@@ -481,8 +532,8 @@ struct gemm_impl
    rocblas_int b_stride = 0;
    rocblas_int c_stride = 0;
    rocblas_int d_stride = 0;
-   rocblas_datatype compute_type = rocblas_datatype_f32_r;
    rocblas_datatype arg_type = rocblas_datatype_f32_r;
+   rb_compute_type compute_type = rocblas_datatype_f32_r;
    rocblas_datatype output_type = rocblas_datatype_f32_r;
    bool strided_batched = true;
    bool is_3inputs = true;
......
@@ -37,6 +37,8 @@ MIGRAPHX_GPU_EXPORT std::string get_device_name();
MIGRAPHX_GPU_EXPORT int get_device_id();

+MIGRAPHX_GPU_EXPORT bool gfx_has_fp8_intrinsics();
+
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
@@ -34,6 +34,7 @@ struct module_pass_manager;
namespace gpu {

MIGRAPHX_GPU_EXPORT bool mlir_enabled();
+MIGRAPHX_GPU_EXPORT bool mlir_attention_enabled();

struct MIGRAPHX_GPU_EXPORT fuse_mlir
{
......
@@ -66,6 +66,10 @@ struct gemm_softmax_gemm
    }

    static bool is_ck_supported_type(shape::type_t t) { return contains({shape::half_type}, t); }
+   static bool is_mlir_supported_type(shape::type_t t)
+   {
+       return contains({shape::type_t::float_type, shape::half_type}, t);
+   }
};

} // namespace gpu
......
@@ -40,6 +40,8 @@ struct context;
MIGRAPHX_GPU_EXPORT bool get_compute_fp32_flag();

+MIGRAPHX_GPU_EXPORT bool rocblas_fp8_available();
+
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
/*
 * The MIT License (MIT)
 *
- * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
@@ -21,41 +21,58 @@
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
-#ifndef MIGRAPHX_GUARD_RTGLIB_PAD_HPP
-#define MIGRAPHX_GUARD_RTGLIB_PAD_HPP
-#include <migraphx/argument.hpp>
-#include <migraphx/reflect.hpp>
-#include <migraphx/op/pad.hpp>
+#ifndef MIGRAPHX_GUARD_JIT_SCATTER_HPP
+#define MIGRAPHX_GUARD_JIT_SCATTER_HPP
+#include <migraphx/gpu/compiler.hpp>
+#include <migraphx/make_op.hpp>
+#include <migraphx/gpu/context.hpp>
+#include <migraphx/gpu/compile_hip_code_object.hpp>
+#include <migraphx/gpu/compile_hip.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

-struct context;
-struct hip_pad
+template <typename Derived>
+struct scatter_compiler : compiler<Derived>
{
-    op::pad op;
-    template <class Self, class F>
-    static auto reflect(Self& self, F f)
+    compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
    {
-        return migraphx::reflect(self.op, f);
+        const auto inputs =
+            to_shapes(std::vector<instruction_ref>{ins->inputs().begin() + 1, ins->inputs().end()});
+        hip_compile_options options;
+        options.set_launch_params(op.to_value(), compute_global_for(ctx, inputs.at(1).elements()));
+        options.inputs = inputs;
+        options.output = inputs.back();
+        options.kernel_name = derived().get_kernel_name(op);
+        options.virtual_inputs = inputs;
+        // The compiler protests the inequality comparison in assign_mul when pertaining to floating
+        // point, despite it making sense in the context. Thus the warning removal.
+        options.params += "-Wno-float-equal";
+        const auto src = derived().make_interpolated_string(op);
+        return prepend_copy_data_to_output(compile_hip_code_object(src, options));
    }
-    std::string name() const { return "gpu::pad"; }
-    shape compute_shape(std::vector<shape> inputs) const;
-    argument
-    compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
-    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
+    compiler_replace prepend_copy_data_to_output(const operation& co) const
    {
-        return shapes.size() - 1;
+        return {co, [](module& m, instruction_ref ins, const operation& op) {
+            auto args = ins->inputs();
+            args.back() =
+                m.insert_instruction(ins, make_op("hip::copy"), args.front(), args.back());
+            args.erase(args.begin());
+            return m.replace_instruction(ins, op, args);
+        }};
    }
+    std::string get_kernel_name(const operation& op) const { return op.name() + "_kernel"; }
+    const Derived& derived() const { return static_cast<const Derived&>(*this); }
};

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
@@ -21,11 +21,7 @@
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
-#include <migraphx/gpu/compiler.hpp>
-#include <migraphx/make_op.hpp>
-#include <migraphx/gpu/context.hpp>
-#include <migraphx/gpu/compile_hip_code_object.hpp>
-#include <migraphx/gpu/compile_hip.hpp>
+#include "scatter.hpp"

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
@@ -55,46 +51,21 @@ MIGRAPHX_GLOBAL void scatternd_kernel(void* in_indices, void* in_updates, void*
)__migraphx__";

-struct scatternd_compiler : compiler<scatternd_compiler>
+struct scatternd_compiler : scatter_compiler<scatternd_compiler>
{
    std::vector<std::string> names() const
    {
-        return {"scatternd_none", "scatternd_add", "scatternd_mul"};
+        return {
+            "scatternd_none", "scatternd_add", "scatternd_mul", "scatternd_min", "scatternd_max"};
    }

-    operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
+    std::string make_interpolated_string(const operation& op) const
    {
-        hip_compile_options options;
-        options.set_launch_params(v, compute_global_for(ctx, inputs.at(1).elements()));
-        options.inputs = inputs;
-        options.output = inputs.back();
-        options.kernel_name = "scatternd_kernel";
-        options.virtual_inputs = inputs;
-        auto reduction = "assign_" + v.get("reduction", std::string{"none"});
-        auto src = interpolate_string(scatternd_kernel, {{"reduction", reduction}});
-        return compile_hip_code_object(src, options);
+        const auto reduction = op.name().substr(std::char_traits<char>::length("scatternd_"));
+        return interpolate_string(scatternd_kernel, {{"reduction", "assign_" + reduction}});
    }
-    compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
-    {
-        assert(starts_with(op.name(), "scatternd_"));
-        auto reduction = op.name().substr(10);
-        return insert(compile_op(
-            ctx,
-            to_shapes(std::vector<instruction_ref>{ins->inputs().begin() + 1, ins->inputs().end()}),
-            {{"reduction", reduction}}));
-    }
-    compiler_replace insert(const operation& co) const
-    {
-        return {co, [](module& m, instruction_ref ins, const operation& op) {
-            auto args = ins->inputs();
-            args.back() =
-                m.insert_instruction(ins, make_op("hip::copy"), args.front(), args.back());
-            args.erase(args.begin());
-            return m.replace_instruction(ins, op, args);
-        }};
-    }
+    std::string get_kernel_name(const operation&) const { return "scatternd_kernel"; }
};

} // namespace gpu
......
/* ************************************************************************
* Copyright (C) 2016-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell cop-
* ies of the Software, and to permit persons to whom the Software is furnished
* to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IM-
* PLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
* FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
* COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
* IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNE-
* CTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*
* ************************************************************************ */
#ifndef MIGRAPHX_GUARD_KERNELS_BITCAST_HPP
#define MIGRAPHX_GUARD_KERNELS_BITCAST_HPP
#include <migraphx/kernels/type_traits.hpp>
namespace migraphx {
template <typename To,
typename From,
MIGRAPHX_REQUIRES(is_trivially_copyable<To>{} and is_trivially_copyable<From>{})>
inline constexpr To bit_cast(From fr) noexcept
{
static_assert(sizeof(To) == sizeof(From));
return __builtin_bit_cast(To, fr);
}
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_BITCAST_HPP
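A minimal usage sketch (illustrative, not part of the diff): reinterpreting the bits of one trivially copyable type as another of the same size.

// Hypothetical example: view a float's bit pattern as a 32-bit unsigned integer.
float f = 1.0f;
unsigned int u = migraphx::bit_cast<unsigned int>(f); // u == 0x3f800000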
@@ -49,12 +49,8 @@ constexpr unsigned int dpp_row_bcast(unsigned int x)
    return y;
}

-template <unsigned int DppCtrl,
-          unsigned int RowMask = 0xf,
-          unsigned int BankMask = 0xf,
-          bool BoundCtrl = false,
-          class T>
-__device__ T dpp_mov(T& x)
+template <class T, class F>
+__device__ T dpp_op(T& x, F f)
{
    static const index_int n = sizeof(T) < 4 ? 1 : sizeof(T) / 4;
    union type
@@ -68,10 +64,28 @@ __device__ T dpp_mov(T& x)
    input.data = x;
    for(index_int i = 0; i < n; i++)
    {
-        output.reg[i] = __hip_move_dpp(input.reg[i], DppCtrl, RowMask, BankMask, BoundCtrl);
+        output.reg[i] = f(input.reg[i]);
    }
    return output.data;
}

+template <unsigned int DppCtrl,
+          unsigned int RowMask = 0xf,
+          unsigned int BankMask = 0xf,
+          bool BoundCtrl = false,
+          class T>
+__device__ T dpp_mov(T& x)
+{
+    return dpp_op(x,
+                  [](auto i) { return __hip_move_dpp(i, DppCtrl, RowMask, BankMask, BoundCtrl); });
+}
+
+template <unsigned int Mask, class T>
+__device__ T dpp_swizzle(T& x)
+{
+    return dpp_op(x, [](auto i) { return __hip_ds_swizzle(i, Mask); });
+}
+
#endif // MIGRAPHX_HAS_DPP

} // namespace migraphx
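For context (not part of the diff): a hedged sketch of how the new wrapper composes. The mask value and its interpretation are assumptions about the ds_swizzle offset encoding, not something stated in this change.

// Hypothetical device-side illustration: dpp_op splits a 64-bit value into two 32-bit
// registers, applies the lane operation to each, and reassembles the result.
__device__ double swap_with_neighbor_lane(double x)
{
    // 0x041f is assumed to encode "xor lane id with 1" in ds_swizzle's bit-masked mode.
    return migraphx::dpp_swizzle<0x041f>(x);
}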
......
@@ -53,35 +53,35 @@ __device__ void gathernd(const T& data_t, const U& indices_t, const V& output_t,
    auto indices_shape_lens = indices_shape.lens;
    auto data_shape_lens = data_shape.lens;
    auto num_slice_dims = indices_shape_lens.back();
-   std::size_t num_slices =
+   size_t num_slices =
        accumulate(indices_shape_lens.begin(), indices_shape_lens.end() - 1, 1, op::product{});
-   std::size_t slice_size = accumulate(data_shape_lens.begin() + num_slice_dims + batch_dims,
+   size_t slice_size = accumulate(data_shape_lens.begin() + num_slice_dims + batch_dims,
                                   data_shape_lens.end(),
                                   1,
                                   op::product{});
-   const std::size_t num_batches =
+   const size_t num_batches =
        accumulate(data_shape_lens.begin(), data_shape_lens.begin() + batch_dims, 1, op::product{});
-   const std::size_t data_batch_stride =
+   const size_t data_batch_stride =
        accumulate(data_shape_lens.begin() + batch_dims, data_shape_lens.end(), 1, op::product{});
    const auto num_slices_per_batch = num_slices / num_batches;

    ind.global_stride(output_shape.elements(), [&](auto i) {
        const auto* indices_ptr = indices_t.data();
-       const std::size_t j = i / slice_size;
-       const std::size_t batch_idx = j / num_slices_per_batch;
+       const size_t j = i / slice_size;
+       const size_t batch_idx = j / num_slices_per_batch;

        auto* slice_indices = indices_ptr + (j * num_slice_dims);
-       std::size_t relative_slice_offset = 0;
-       for(std::size_t idx = 0; idx < num_slice_dims; ++idx)
+       size_t relative_slice_offset = 0;
+       for(size_t idx = 0; idx < num_slice_dims; ++idx)
        {
            int64_t index = slice_indices[idx];
-           const std::size_t input_dim_idx = batch_dims + idx;
+           const size_t input_dim_idx = batch_dims + idx;
            const auto input_dim = data_shape_lens[input_dim_idx];
            MIGRAPHX_ASSERT(index >= -static_cast<int64_t>(input_dim) and
                            index < static_cast<int64_t>(input_dim));
            if(index < 0)
                index += input_dim;
-           std::size_t size_from_slice_dims =
+           size_t size_from_slice_dims =
                accumulate(data_shape_lens.begin() + batch_dims + idx + 1,
                           data_shape_lens.begin() + batch_dims + num_slice_dims,
                           slice_size,
......
@@ -52,14 +52,17 @@ __device__ void generic_binary_layernorm(
    block::template run<reduce_output>([&](auto, auto r) {
        auto input = r.inner([&](auto x1, auto x2) { return op(x1, x2); })(input1, input2);
        using value_type = typename Input1::type;
+       using vec_value_type = vec_type<value_type>;
        constexpr auto relements = r.template elements<Input1>();
-       constexpr auto relements_r = vec_type<value_type>{1.0 / relements};
+       constexpr auto relements_r = vec_value_type{1.0 / relements};
        auto relements_rsqrt = sqrt(relements_r);
-       auto means = r.reduce(op::sum{}, make_array<vec_type<value_type>>(0, 0), [&](auto x) {
+       auto means = r.reduce(op::sum{},
+                             make_array<vec_value_type>(vec_value_type{0}, vec_value_type{0}),
+                             [&](auto x) {
            auto x_out = x * relements_r;
-           // dividing x by sqrt(relements) before squaring allows computing higher values
-           // before overflow in low precision
+           // dividing x by sqrt(relements) before squaring allows computing
+           // higher values before overflow in low precision
            auto x2_sqrt = x * relements_rsqrt;
            return make_array(x_out, x2_sqrt * x2_sqrt);
        })(input);
@@ -67,7 +70,7 @@ __device__ void generic_binary_layernorm(
        auto mean_x = means[0];
        auto mean_x2 = means[1];
        auto variance = mean_x2 - (mean_x * mean_x);
-       value_type eps_val = eps; // implicit conversion for eps
+       value_type eps_val = implicit_conversion(eps);
        r.inner([&](auto& y, auto x, auto... xs) {
            auto m = x - mean_x;
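For context (not part of the diff, and unchanged by it): the comment about dividing by sqrt(relements) rests on the identity sum_i (x_i / sqrt(n))^2 = (sum_i x_i^2) / n = mean(x^2), with n = relements. Each squared term is scaled down by n before accumulation, so the running sum stays near mean(x^2) rather than n * mean(x^2), which is what would overflow first in half precision.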
......