Commit 5ec8f913 authored by Ted Themistokleous's avatar Ted Themistokleous Committed by Ted Themistokleous
Browse files

Merge branch 'develop' into simplify_1_mul_div_ops

parents 32d69e8e d78bcdfb
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
using namespace migraphx::gpu::gen; // NOLINT
static const char* const layernorm_kernel = R"__migraphx__(
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/layernorm.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <migraphx/kernels/preload.hpp>
#include <args.hpp>
namespace migraphx {
${preamble}
extern "C" {
__global__ void ${kernel}(${params})
{
auto idx = make_index();
transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto... xs) {
${layernorm}<${axis}>(${post}, xs...);
});
}
}
} // namespace migraphx
)__migraphx__";
struct layernorm_compiler : compiler<layernorm_compiler>
{
std::vector<std::string> names() const
{
return {"layernorm", "gpu::prelayernorm", "gpu::preadd_layernorm"};
}
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
// TODO: Use reduce_dims
auto axis = inputs.front().lens().size() - 1;
auto faxis = find_fast_axis({inputs.front()});
vectorize vec{};
// Vectorize if the axis is a reduction axis
if(axis == faxis)
{
vec = vectorize::elements(faxis, inputs);
}
auto preloads = preload::broadcasts(axis, inputs);
auto relements = inputs[0].lens()[axis] / vec.size;
auto nelements = (inputs.back().elements() / inputs[0].lens()[axis]);
auto block_size = compute_block_size(relements, 256);
hip_compile_options options;
options.set_launch_params(
v, compute_global_for(ctx, nelements * block_size, 256), block_size);
options.output = inputs.back();
options.inputs = inputs;
options.kernel_name = v.get("kernel", "layernorm_kernel");
auto src = interpolate_string(layernorm_kernel,
{{"kernel", options.kernel_name},
{"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")},
{"transformers", make_transformer_args(preloads, vec)},
{"post", v.get("post", std::string{"op::id{}"})},
{"preamble", v.get("preamble", std::string{})},
{"layernorm", v.get("layernorm", std::string{"layernorm"})},
{"axis", to_string(axis)}});
return compile_hip_code_object(src, options);
}
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
auto v = op.to_value();
v["layernorm"] = "layernorm";
v["kernel"] = "layernorm_kernel";
if(op.name() == "gpu::preadd_layernorm")
{
v["layernorm"] = "add_layernorm";
v["kernel"] = "add_layernorm_kernel";
}
if(not ins->module_inputs().empty())
{
auto* pm = ins->module_inputs().front();
v["preamble"] = generate_pointwise(*pm, "post_layernorm");
v["post"] = "MIGRAPHX_LIFT(post_layernorm)";
v["kernel"] =
v["layernorm"].to<std::string>() + "_" + generate_name_from_ops(*pm) + "_kernel";
}
return replace(compile_op(ctx, to_shapes(ins->inputs()), v));
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -26,16 +26,7 @@
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -65,18 +56,6 @@ __global__ void ${kernel}(${params})
)__migraphx__";
static std::vector<std::string> get_op_names(const module& m)
{
std::vector<std::string> result;
for(auto& ins : m)
{
if(starts_with(ins.name(), "@"))
continue;
result.push_back(ins.name());
}
return result;
}
struct pointwise_compiler : compiler<pointwise_compiler>
{
std::vector<std::string> names() const { return {"pointwise", "contiguous"}; }
......@@ -126,34 +105,14 @@ struct pointwise_compiler : compiler<pointwise_compiler>
else
{
assert(not ins->module_inputs().empty());
auto* pm = ins->module_inputs().front();
run_passes(*pm, {eliminate_common_subexpression{}, dead_code_elimination{}});
cpp_generator g;
g.fmap([](const std::string& fname) { return "migraphx::" + fname; });
g.add_point_op("where", "${function:where}(${0}, ${1}, ${2})");
g.add_point_op("prelu", "${function:where}(${0} < 0, ${0} * ${1}, ${0})");
g.add_point_op("sign",
"${function:where}(${0} > 0, 1, ${function:where}(${0} < 0, -1, 0))");
g.add_point_op("equal", "migraphx::abs(${0} == ${1})");
g.add_point_op("less", "migraphx::abs(${0} < ${1})");
g.add_point_op("greater", "migraphx::abs(${0} > ${1})");
g.add_point_op("not", "migraphx::abs(not ${0})");
g.add_point_op("mod", "migraphx::mod(${0}, ${1})");
g.add_point_op("fmod", "migraphx::fmod(${0}, ${1})");
// Add explict conversions
g.fresult([](const shape& s) {
return "migraphx::convert<" + shape::cpp_type(s.type()) + ">";
});
auto name = g.create_function(
g.generate_module(*pm).set_attributes({"__device__"}).set_generic_types(*pm));
std::string lambda = "MIGRAPHX_LIFT(" + name + ")";
auto op_names = get_op_names(*pm);
op_names.push_back("kernel");
auto op_name_string = join_strings(op_names, "_");
return replace(compile_op(
ctx,
to_shapes(ins->inputs()),
{{"lambda", lambda}, {"preamble", g.str()}, {"kernel", op_name_string}}));
auto* pm = ins->module_inputs().front();
auto pf = generate_pointwise(*pm, "inner_pointwise");
std::string lambda = "MIGRAPHX_LIFT(inner_pointwise)";
auto kernel_name = generate_name_from_ops(*pm) + "_kernel";
return replace(
compile_op(ctx,
to_shapes(ins->inputs()),
{{"lambda", lambda}, {"preamble", pf}, {"kernel", kernel_name}}));
}
}
};
......
......@@ -26,15 +26,7 @@
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -24,16 +24,7 @@
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -24,16 +24,8 @@
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -26,15 +26,7 @@
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -80,7 +80,9 @@ void launch_kernel(hipFunction_t fun,
std::size_t global,
std::size_t local,
void* kernargs,
std::size_t size)
std::size_t size,
hipEvent_t start,
hipEvent_t stop)
{
assert(global > 0);
assert(local > 0);
......@@ -97,34 +99,55 @@ void launch_kernel(hipFunction_t fun,
#endif
};
auto status = hipExtModuleLaunchKernel(
fun, global, 1, 1, local, 1, 1, 0, stream, nullptr, reinterpret_cast<void**>(&config));
auto status = hipExtModuleLaunchKernel(fun,
global,
1,
1,
local,
1,
1,
0,
stream,
nullptr,
reinterpret_cast<void**>(&config),
start,
stop);
if(status != hipSuccess)
MIGRAPHX_THROW("Failed to launch kernel: " + hip_error(status));
if(stop != nullptr)
{
status = hipEventSynchronize(stop);
if(status != hipSuccess)
MIGRAPHX_THROW("Failed to sync event: " + hip_error(status));
}
}
void kernel::launch(hipStream_t stream,
std::size_t global,
std::size_t local,
std::vector<void*> args) const
std::vector<void*> args,
hipEvent_t start,
hipEvent_t stop) const
{
assert(impl != nullptr);
void* kernargs = args.data();
std::size_t size = args.size() * sizeof(void*);
launch_kernel(impl->fun, stream, global, local, kernargs, size);
launch_kernel(impl->fun, stream, global, local, kernargs, size, start, stop);
}
void kernel::launch(hipStream_t stream,
std::size_t global,
std::size_t local,
const std::vector<kernel_argument>& args) const
const std::vector<kernel_argument>& args,
hipEvent_t start,
hipEvent_t stop) const
{
assert(impl != nullptr);
std::vector<char> kernargs = pack_args(args);
std::size_t size = kernargs.size();
launch_kernel(impl->fun, stream, global, local, kernargs.data(), size);
launch_kernel(impl->fun, stream, global, local, kernargs.data(), size, start, stop);
}
} // namespace gpu
......
......@@ -163,7 +163,7 @@ constexpr Iterator1 search(Iterator1 first, Iterator1 last, Iterator2 s_first, I
{
return last;
}
if(!(*it == *s_it))
if(not(*it == *s_it))
{
break;
}
......
......@@ -153,7 +153,7 @@ struct array
return true;
}
friend constexpr bool operator!=(const array& x, const array& y) { return !(x == y); }
friend constexpr bool operator!=(const array& x, const array& y) { return not(x == y); }
// This uses the product order rather than lexical order
friend constexpr bool operator<(const array& x, const array& y)
{
......
......@@ -31,8 +31,9 @@
->decltype(__VA_ARGS__) { return __VA_ARGS__; }
// NOLINTNEXTLINE
#define MIGRAPHX_LIFT(...) \
[](auto&&... xs) MIGRAPHX_RETURNS((__VA_ARGS__)(static_cast<decltype(xs)>(xs)...))
#define MIGRAPHX_LIFT(...) \
[](auto&&... private_lisft_xs) MIGRAPHX_RETURNS( \
(__VA_ARGS__)(static_cast<decltype(private_lisft_xs)>(private_lisft_xs)...))
namespace migraphx {
......
......@@ -73,10 +73,10 @@ MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(!=)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(&)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(^)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(|)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(&&)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(||)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(and)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(or)
MIGRAPHX_INTEGRAL_CONSTANT_UNARY_OP(!)
MIGRAPHX_INTEGRAL_CONSTANT_UNARY_OP(not )
MIGRAPHX_INTEGRAL_CONSTANT_UNARY_OP(~)
MIGRAPHX_INTEGRAL_CONSTANT_UNARY_OP(+)
MIGRAPHX_INTEGRAL_CONSTANT_UNARY_OP(-)
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_LAYERNORM_HPP
#define MIGRAPHX_GUARD_KERNELS_LAYERNORM_HPP
#include <migraphx/kernels/reduce.hpp>
#include <migraphx/kernels/ops.hpp>
#include <migraphx/kernels/print.hpp>
namespace migraphx {
template <index_int Axis,
class F,
class BinOp,
class Output,
class Input1,
class Input2,
class... Inputs>
__device__ void generic_binary_layernorm(
F compute, BinOp op, Output output, Input1 input1, Input2 input2, Inputs... inputs)
{
using reduce_output = reduce::with_axis<Input1, Axis>;
reduce::block::run<reduce_output>([&](auto, auto r) {
using value_type = typename Input1::type;
constexpr auto relements = r.template elements<Input1>();
auto mean = [&](auto f) {
return r.reduce(op::sum{}, 0, [&](auto x1, auto x2) {
return f(x1, x2) / value_type{relements};
})(input1, input2);
};
// mean(x)
auto mean_x = mean(op);
// mean(m ^ 2)
auto mean_m2 = mean([&](auto x1, auto x2) {
auto m = op(x1, x2) - mean_x;
return m * m;
});
r.inner([&](auto& y, auto x1, auto x2, auto... xs) {
auto m = op(x1, x2) - mean_x;
// m * rsqrt(mean(m ^ 2) + 1e-12)
y = compute(m * rsqrt(mean_m2 + value_type{1e-12}), xs...);
})(output, input1, input2, inputs...);
});
}
template <index_int Axis, class F, class Output, class Input, class... Inputs>
__device__ void layernorm(F compute, Output output, Input input, Inputs... inputs)
{
generic_binary_layernorm<Axis>(
compute, [](auto x, auto) { return x; }, output, input, input, inputs...);
}
template <index_int Axis, class F, class Output, class Input1, class Input2, class... Inputs>
__device__ void
add_layernorm(F compute, Output output, Input1 input1, Input2 input2, Inputs... inputs)
{
generic_binary_layernorm<Axis>(
compute, [](auto x1, auto x2) { return x1 + x2; }, output, input1, input2, inputs...);
}
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_LAYERNORM_HPP
......@@ -90,7 +90,7 @@ struct lowest
template <class T>
constexpr operator T() const
{
return numeric_lowest<T>();
return numeric_lowest<vec_type<T>>();
}
};
......@@ -99,7 +99,7 @@ struct highest
template <class T>
constexpr operator T() const
{
return numeric_max<T>();
return numeric_max<vec_type<T>>();
}
};
} // namespace migraphx
......
......@@ -224,6 +224,18 @@ struct block
idx.local_stride(x.get_shape().elements(), [&](auto j) { f(x[j], xs[j]...); });
});
}
template <class Input>
constexpr auto elements() const
{
using reduce_type = decltype(slicer(Input{}));
using value_type = typename Input::type;
constexpr auto relements = get_shape_c<reduce_type>{}.elements();
if constexpr(vec_size<value_type>() > 1)
return relements * vec_size<value_type>();
else
return relements;
}
};
template <class Slicer>
......@@ -281,6 +293,13 @@ struct lane
}
});
}
template <class Input>
constexpr auto elements() const
{
using reduce_type = decltype(slicer(Input{}));
return get_shape_c<reduce_type>{}.elements();
}
};
template <class Slicer>
......
......@@ -192,9 +192,13 @@ struct common_type<T, U, Us...>
template <class... Ts>
using common_type_t = typename common_type<Ts...>::type;
#define MIGRAPHX_REQUIRES(...) class = enable_if_t<__VA_ARGS__>
constexpr unsigned long int_max(unsigned long n) { return (1u << (n * 8)) - 1; }
template <class T>
template <class T,
MIGRAPHX_REQUIRES(is_integral<T>{} or is_floating_point<T>{} or
is_same<T, migraphx::half>{})>
constexpr T numeric_max()
{
if constexpr(is_integral<T>{})
......@@ -230,8 +234,6 @@ constexpr T numeric_lowest()
}
}
#define MIGRAPHX_REQUIRES(...) class = enable_if_t<__VA_ARGS__>
} // namespace migraphx
#endif
......@@ -175,7 +175,7 @@ template <class T, class Op>
constexpr auto vec_reduce(T x, Op op)
{
if constexpr(vec_size<T>() < 2)
return x;
return vec_type<T>{x};
else
{
vec_type<T> result = x[0];
......
......@@ -27,42 +27,24 @@
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/abs.hpp>
#include <migraphx/op/batch_norm_inference.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/deconvolution.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/elu.hpp>
#include <migraphx/op/if_op.hpp>
#include <migraphx/op/leaky_relu.hpp>
#include <migraphx/op/lrn.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/reshape.hpp>
#include <migraphx/op/quant_convolution.hpp>
#include <migraphx/op/quant_dot.hpp>
#include <migraphx/gpu/abs.hpp>
#include <migraphx/gpu/batch_norm_inference.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/convolution.hpp>
#include <migraphx/gpu/deconvolution.hpp>
#include <migraphx/gpu/device_name.hpp>
#include <migraphx/gpu/elu.hpp>
#include <migraphx/gpu/equal.hpp>
#include <migraphx/gpu/gemm.hpp>
#include <migraphx/gpu/greater.hpp>
#include <migraphx/gpu/int8_conv_pack.hpp>
#include <migraphx/gpu/leaky_relu.hpp>
#include <migraphx/gpu/less.hpp>
#include <migraphx/gpu/logical_and.hpp>
#include <migraphx/gpu/logical_or.hpp>
#include <migraphx/gpu/logical_xor.hpp>
#include <migraphx/gpu/lrn.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/quant_convolution.hpp>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/unary_not.hpp>
#include <migraphx/gpu/where.hpp>
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/program.hpp>
......@@ -341,7 +323,7 @@ struct miopen_apply
catch(migraphx::exception&)
{
// In case no solver supports the default format, retry using the other format.
compile_quant_conv_with_format(!int8_x4_format);
compile_quant_conv_with_format(not int8_x4_format);
}
auto args = ins->inputs();
......
......@@ -78,7 +78,7 @@ struct mlir_handle
friend bool operator==(ptr x, ptr y) { return x.get_value() == y.get_value(); }
friend bool operator!=(ptr x, ptr y) { return !(x == y); }
friend bool operator!=(ptr x, ptr y) { return not(x == y); }
T obj{};
};
......@@ -503,7 +503,7 @@ struct mlir_program
pp =
problem_params{ins->get_operator(), to_shapes(ins->inputs()), ins->get_shape()};
std::string tuned = get_tune_params();
if(!tuned.empty())
if(not tuned.empty())
ops.add_attributes({{"perf_config", tuned}});
// check if HW supports xdlops
if(contains(get_xdlops_archs(), target_name))
......
......@@ -154,7 +154,7 @@ void pack_int8_args::apply(module& m) const
bool transa = inputs[0]->get_shape().transposed();
bool transb = inputs[1]->get_shape().transposed();
if(!transb)
if(not transb)
{
auto packed_b = m.insert_instruction(
ins, make_op("hip::allocate", {{"shape", to_value(inputs[1]->get_shape())}}));
......
......@@ -23,13 +23,55 @@
*/
#include <migraphx/gpu/prefuse_ops.hpp>
#include <migraphx/match/layernorm.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace {
template <class Derived, std::size_t N>
struct layernorm_base
{
shape compute_shape(std::vector<shape> inputs, std::vector<module_ref> mods) const
{
std::size_t nargs = 1;
if(not mods.empty())
{
auto* pm = mods.front();
nargs = pm->get_parameter_names().size();
}
check_shapes{inputs, static_cast<const Derived&>(*this)}.has(nargs + N);
auto s = inputs.at(0);
if(s.scalar())
{
return s;
}
else if(s.broadcasted())
{
return {s.type(), s.lens()};
}
else
{
return s.with_lens(s.lens());
}
}
};
struct layernorm : layernorm_base<layernorm, 0>
{
std::string name() const { return "gpu::prelayernorm"; }
};
MIGRAPHX_REGISTER_OP(layernorm);
struct add_layernorm : layernorm_base<add_layernorm, 1>
{
std::string name() const { return "gpu::preadd_layernorm"; }
};
MIGRAPHX_REGISTER_OP(add_layernorm);
struct find_layernorm
{
auto matcher() const { return match::layernorm(); }
......@@ -39,59 +81,30 @@ struct find_layernorm
auto ins = r.result;
auto x_ins = r.instructions["x"];
if(not x_ins->get_shape().standard())
x_ins = m.insert_instruction(ins, make_op("contiguous"), x_ins);
auto relements = x_ins->get_shape().lens().back();
if(relements > 1024 or (relements % 4 != 0 and relements > 256))
return;
auto a = m.insert_instruction(
ins, make_op("hip::allocate", {{"shape", to_value(x_ins->get_shape())}}));
m.replace_instruction(ins, make_op("gpu::layernorm"), x_ins, a);
m.replace_instruction(ins, layernorm{}, x_ins);
}
};
struct find_triaddlayernorm
struct find_add_layernorm
{
auto matcher() const
{
auto add1 =
match::name("add")(match::none_of(match::is_constant()),
match::args(match::any().bind("z1"), match::any().bind("z2")));
auto add2 = match::name("add")(match::either_arg(0, 1)(add1, match::any().bind("z3")));
return match::layernorm()(match::var("x")(add2));
return match::layernorm()(match::var("x")(match::name("add").bind("add")));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto x_ins = r.instructions["z1"];
auto y_ins = r.instructions["z2"];
auto z_ins = r.instructions["z3"];
for(auto* pins : {&x_ins, &y_ins, &z_ins})
{
if(not(*pins)->get_shape().standard())
*pins = m.insert_instruction(ins, make_op("contiguous"), *pins);
}
auto relements = x_ins->get_shape().lens().back();
if(relements > 1024 or (relements % 4 != 0 and relements > 256))
return;
auto ins = r.result;
auto add_ins = r.instructions["add"];
auto a = m.insert_instruction(
ins, make_op("hip::allocate", {{"shape", to_value(x_ins->get_shape())}}));
m.replace_instruction(ins, make_op("gpu::triadd_layernorm"), x_ins, y_ins, z_ins, a);
m.replace_instruction(ins, add_layernorm{}, add_ins->inputs());
}
};
} // namespace
void prefuse_ops::apply(module& m) const
{
match::find_matches(m, find_triaddlayernorm{}, find_layernorm{});
match::find_matches(m, find_add_layernorm{}, find_layernorm{});
}
} // namespace gpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment