Commit 477e0162 authored by Paul's avatar Paul
Browse files

Merge branch 'jit-layernorm' into bert-opt2

parents 0db1af37 7b332efc
...@@ -25,6 +25,13 @@ ...@@ -25,6 +25,13 @@
#include <migraphx/shape.hpp> #include <migraphx/shape.hpp>
#include <migraphx/permutation.hpp> #include <migraphx/permutation.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/module.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -75,25 +82,26 @@ std::string vectorize::str() const ...@@ -75,25 +82,26 @@ std::string vectorize::str() const
preload preload::broadcasts(std::size_t axis, const std::vector<shape>& inputs) preload preload::broadcasts(std::size_t axis, const std::vector<shape>& inputs)
{ {
const std::size_t max_lds_bytes = 4096; const std::size_t max_lds_bytes = 4096;
std::vector<bool> result; std::vector<bool> result(inputs.size());
std::transform(inputs.begin(), std::vector<std::size_t> preloaded;
inputs.end(), for(auto i : range(inputs.size()))
std::back_inserter(result), {
[&](const shape& input) { return input.strides()[axis] == 0; }); if(inputs[i].strides()[axis] == 0)
auto bytes = std::inner_product(inputs.begin(), preloaded.push_back(i);
inputs.end(), }
result.begin(), std::sort(preloaded.begin(), preloaded.end(), by(std::less<>{}, [&](auto i) {
std::size_t{0}, return inputs[i].bytes();
std::plus<>{}, }));
[](const shape& s, bool b) -> std::size_t {
if(b) std::size_t bytes = 0;
return s.bytes(); for(auto i : preloaded)
return 0; {
}); auto input = inputs[i];
if(bytes < max_lds_bytes) bytes += input.bytes();
return {result}; if(bytes > max_lds_bytes)
// TODO: Try to partially preload items break;
std::fill(result.begin(), result.end(), false); result[i] = true;
}
return {result}; return {result};
} }
...@@ -125,6 +133,45 @@ std::string make_transformer_args(std::vector<std::string> transformers) ...@@ -125,6 +133,45 @@ std::string make_transformer_args(std::vector<std::string> transformers)
return join_strings(std::move(transformers), ", "); return join_strings(std::move(transformers), ", ");
} }
std::string generate_pointwise(const module& pm, const std::string& name)
{
module m = pm;
run_passes(m, {eliminate_common_subexpression{}, dead_code_elimination{}});
cpp_generator g;
g.fmap([](const std::string& fname) { return "migraphx::" + fname; });
g.add_point_op("where", "${function:where}(${0}, ${1}, ${2})");
g.add_point_op("prelu", "${function:where}(${0} < 0, ${0} * ${1}, ${0})");
g.add_point_op("sign", "${function:where}(${0} > 0, 1, ${function:where}(${0} < 0, -1, 0))");
g.add_point_op("equal", "migraphx::abs(${0} == ${1})");
g.add_point_op("less", "migraphx::abs(${0} < ${1})");
g.add_point_op("greater", "migraphx::abs(${0} > ${1})");
g.add_point_op("not", "migraphx::abs(not ${0})");
// Add explict conversions
g.fresult(
[](const shape& s) { return "migraphx::convert<" + shape::cpp_type(s.type()) + ">"; });
g.create_function(
g.generate_module(m).set_attributes({"__device__"}).set_generic_types(m).set_name(name));
return g.str();
}
static std::vector<std::string> get_op_names(const module& m)
{
std::vector<std::string> result;
for(auto& ins : m)
{
if(starts_with(ins.name(), "@"))
continue;
result.push_back(ins.name());
}
return result;
}
std::string generate_name_from_ops(const module& m)
{
auto op_names = get_op_names(m);
return join_strings(op_names, "_");
}
} // namespace gen } // namespace gen
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -43,6 +43,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -43,6 +43,7 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DEBUG); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DEBUG);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DEBUG_SYM);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_OPTIMIZE); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_OPTIMIZE);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DUMP_ASM); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DUMP_ASM);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DUMP_SRC); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DUMP_SRC);
...@@ -227,6 +228,8 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std ...@@ -227,6 +228,8 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
if(params.find("-std=") == std::string::npos) if(params.find("-std=") == std::string::npos)
params += " --std=c++17"; params += " --std=c++17";
params += " -fno-gpu-rdc"; params += " -fno-gpu-rdc";
if(enabled(MIGRAPHX_GPU_DEBUG_SYM{}))
params += " -g";
params += " -c"; params += " -c";
if(is_hcc_compiler()) if(is_hcc_compiler())
{ {
......
...@@ -831,13 +831,14 @@ void apply_conv_bias(context& ctx, module& m, const match::matcher_result& r) ...@@ -831,13 +831,14 @@ void apply_conv_bias(context& ctx, module& m, const match::matcher_result& r)
m.replace_instruction(ins, cb, input_ins, weights_ins, old_ws_ins, bias_ins, alloc_ins); m.replace_instruction(ins, cb, input_ins, weights_ins, old_ws_ins, bias_ins, alloc_ins);
} }
inline auto precompile_name(std::string s) // NOLINT template <class... Strings>
inline auto precompile_name(Strings... names) // NOLINT
{ {
return match::make_basic_pred_matcher([=](instruction_ref ins) { return match::make_basic_pred_matcher([=](instruction_ref ins) {
if(ins->name() != "gpu::precompile_op") if(ins->name() != "gpu::precompile_op")
return false; return false;
auto op = from_value<operation>(ins->get_operator().to_value().at("op")); auto op = from_value<operation>(ins->get_operator().to_value().at("op"));
return (op.name() == s); return (contains({names...}, op.name()));
}); });
} }
...@@ -1112,6 +1113,31 @@ struct find_contiguous_pointwise ...@@ -1112,6 +1113,31 @@ struct find_contiguous_pointwise
} }
}; };
struct find_layernorm_pointwise
{
auto matcher() const
{
return precompile_name("pointwise")(match::arg(0)(
precompile_name("gpu::prelayernorm", "gpu::preadd_layernorm").bind("layernorm")));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto layernorm = r.instructions["layernorm"];
auto* pm = ins->module_inputs().front();
if(not layernorm->module_inputs().empty())
return;
auto inputs = layernorm->inputs();
inputs.pop_back();
inputs.insert(inputs.end(), ins->inputs().begin() + 1, ins->inputs().end());
m.replace_instruction(ins, layernorm->get_operator(), inputs, {pm});
}
};
void fuse_ops::apply(module& m) const void fuse_ops::apply(module& m) const
{ {
match::find_matches(m, find_contiguous_pointwise{}, find_gelu{}, find_gelu_new{fast_math}); match::find_matches(m, find_contiguous_pointwise{}, find_gelu{}, find_gelu_new{fast_math});
...@@ -1134,6 +1160,7 @@ void fuse_ops::apply(module& m) const ...@@ -1134,6 +1160,7 @@ void fuse_ops::apply(module& m) const
match::find_matches(m, match::find_matches(m,
find_triadd_layernorm{}, find_triadd_layernorm{},
find_gemm_add{}, find_gemm_add{},
find_layernorm_pointwise{},
find_gemm_pointwise{}, find_gemm_pointwise{},
find_commutative_broadcast{}); find_commutative_broadcast{});
match::find_matches(m, find_contiguous{}); match::find_matches(m, find_contiguous{});
......
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#define MIGRAPHX_GUARD_GPU_COMPILE_GEN_HPP #define MIGRAPHX_GUARD_GPU_COMPILE_GEN_HPP
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/module_ref.hpp>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
...@@ -62,6 +63,10 @@ std::string make_transformer_args(Ts... xs) ...@@ -62,6 +63,10 @@ std::string make_transformer_args(Ts... xs)
return make_transformer_args({xs.str()...}); return make_transformer_args({xs.str()...});
} }
std::string generate_pointwise(const module& pm, const std::string& name);
std::string generate_name_from_ops(const module& m);
} // namespace gen } // namespace gen
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
using namespace migraphx::gpu::gen; // NOLINT
static const char* const layernorm_kernel = R"__migraphx__(
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/layernorm.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <migraphx/kernels/preload.hpp>
#include <args.hpp>
namespace migraphx {
${preamble}
extern "C" {
__global__ void ${kernel}(${params})
{
auto idx = make_index();
transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto... xs) {
${layernorm}<${axis}>(${post}, xs...);
});
}
}
} // namespace migraphx
)__migraphx__";
struct layernorm_compiler : compiler<layernorm_compiler>
{
std::vector<std::string> names() const
{
return {"layernorm", "gpu::prelayernorm", "gpu::preadd_layernorm"};
}
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
// TODO: Use reduce_dims
auto axis = inputs.front().lens().size() - 1;
auto faxis = find_fast_axis({inputs.front()});
vectorize vec{};
// Vectorize if the axis is a reduction axis
if(axis == faxis)
{
vec = vectorize::elements(faxis, inputs);
}
auto preloads = preload::broadcasts(axis, inputs);
auto relements = inputs[0].lens()[axis] / vec.size;
auto nelements = inputs.back().elements() / relements;
auto block_size = compute_block_size(relements, 256);
hip_compile_options options;
options.set_launch_params(
v, compute_global_for(ctx, nelements * block_size, 256), block_size);
options.output = inputs.back();
options.inputs = inputs;
options.kernel_name = v.get("kernel", "layernorm_kernel");
auto src = interpolate_string(layernorm_kernel,
{{"kernel", options.kernel_name},
{"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")},
{"transformers", make_transformer_args(preloads, vec)},
{"post", v.get("post", std::string{"op::id{}"})},
{"preamble", v.get("preamble", std::string{})},
{"layernorm", v.get("layernorm", std::string{"layernorm"})},
{"axis", to_string(axis)}});
return compile_hip_code_object(src, options);
}
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
auto v = op.to_value();
v["layernorm"] = "layernorm";
v["kernel"] = "layernorm_kernel";
if(op.name() == "gpu::preadd_layernorm")
{
v["layernorm"] = "add_layernorm";
v["kernel"] = "add_layernorm_kernel";
}
if(not ins->module_inputs().empty())
{
auto* pm = ins->module_inputs().front();
v["preamble"] = generate_pointwise(*pm, "post_layernorm");
v["post"] = "MIGRAPHX_LIFT(post_layernorm)";
v["kernel"] =
v["layernorm"].to<std::string>() + "_" + generate_name_from_ops(*pm) + "_kernel";
}
return replace(compile_op(ctx, to_shapes(ins->inputs()), v));
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -65,18 +65,6 @@ __global__ void ${kernel}(${params}) ...@@ -65,18 +65,6 @@ __global__ void ${kernel}(${params})
)__migraphx__"; )__migraphx__";
static std::vector<std::string> get_op_names(const module& m)
{
std::vector<std::string> result;
for(auto& ins : m)
{
if(starts_with(ins.name(), "@"))
continue;
result.push_back(ins.name());
}
return result;
}
struct pointwise_compiler : compiler<pointwise_compiler> struct pointwise_compiler : compiler<pointwise_compiler>
{ {
std::vector<std::string> names() const { return {"pointwise", "contiguous"}; } std::vector<std::string> names() const { return {"pointwise", "contiguous"}; }
...@@ -127,31 +115,13 @@ struct pointwise_compiler : compiler<pointwise_compiler> ...@@ -127,31 +115,13 @@ struct pointwise_compiler : compiler<pointwise_compiler>
{ {
assert(not ins->module_inputs().empty()); assert(not ins->module_inputs().empty());
auto* pm = ins->module_inputs().front(); auto* pm = ins->module_inputs().front();
run_passes(*pm, {eliminate_common_subexpression{}, dead_code_elimination{}}); auto pf = generate_pointwise(*pm, "inner_pointwise");
cpp_generator g; std::string lambda = "MIGRAPHX_LIFT(inner_pointwise)";
g.fmap([](const std::string& fname) { return "migraphx::" + fname; }); auto kernel_name = generate_name_from_ops(*pm) + "_kernel";
g.add_point_op("where", "${function:where}(${0}, ${1}, ${2})"); return replace(
g.add_point_op("prelu", "${function:where}(${0} < 0, ${0} * ${1}, ${0})"); compile_op(ctx,
g.add_point_op("sign",
"${function:where}(${0} > 0, 1, ${function:where}(${0} < 0, -1, 0))");
g.add_point_op("equal", "migraphx::abs(${0} == ${1})");
g.add_point_op("less", "migraphx::abs(${0} < ${1})");
g.add_point_op("greater", "migraphx::abs(${0} > ${1})");
g.add_point_op("not", "migraphx::abs(not ${0})");
// Add explict conversions
g.fresult([](const shape& s) {
return "migraphx::convert<" + shape::cpp_type(s.type()) + ">";
});
auto name = g.create_function(
g.generate_module(*pm).set_attributes({"__device__"}).set_generic_types(*pm));
std::string lambda = "MIGRAPHX_LIFT(" + name + ")";
auto op_names = get_op_names(*pm);
op_names.push_back("kernel");
auto op_name_string = join_strings(op_names, "_");
return replace(compile_op(
ctx,
to_shapes(ins->inputs()), to_shapes(ins->inputs()),
{{"lambda", lambda}, {"preamble", g.str()}, {"kernel", op_name_string}})); {{"lambda", lambda}, {"preamble", pf}, {"kernel", kernel_name}}));
} }
} }
}; };
......
...@@ -32,7 +32,8 @@ ...@@ -32,7 +32,8 @@
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define MIGRAPHX_LIFT(...) \ #define MIGRAPHX_LIFT(...) \
[](auto&&... xs) MIGRAPHX_RETURNS((__VA_ARGS__)(static_cast<decltype(xs)>(xs)...)) [](auto&&... private_lisft_xs) MIGRAPHX_RETURNS( \
(__VA_ARGS__)(static_cast<decltype(private_lisft_xs)>(private_lisft_xs)...))
namespace migraphx { namespace migraphx {
......
#ifndef MIGRAPHX_GUARD_KERNELS_LAYERNORM_HPP
#define MIGRAPHX_GUARD_KERNELS_LAYERNORM_HPP
#include <migraphx/kernels/reduce.hpp>
#include <migraphx/kernels/ops.hpp>
#include <migraphx/kernels/print.hpp>
namespace migraphx {
template <index_int Axis,
class F,
class BinOp,
class Output,
class Input1,
class Input2,
class... Inputs>
__device__ void generic_binary_layernorm(
F compute, BinOp op, Output output, Input1 input1, Input2 input2, Inputs... inputs)
{
using reduce_output = reduce::with_axis<Input1, Axis>;
constexpr auto relements =
get_shape_c<Input1>{}.elements() / get_shape_c<reduce_output>{}.elements();
MIGRAPHX_ASSERT(relements > 0);
reduce::block::run<reduce_output>([&](auto, auto r) {
using value_type = typename Input1::type;
auto mean = [&](auto f) {
return r.reduce(op::sum{}, 0, [&](auto x1, auto x2) {
return f(x1, x2) / value_type{relements};
})(input1, input2);
};
// mean(x)
auto mean_x = mean(op);
// mean(m ^ 2)
auto mean_m2 = mean([&](auto x1, auto x2) {
auto m = op(x1, x2) - mean_x;
return m * m;
});
r.inner([&](auto& y, auto x1, auto x2, auto... xs) {
auto m = op(x1, x2) - mean_x;
// m * rsqrt(mean(m ^ 2) + 1e-12)
y = compute(m * rsqrt(mean_m2 + value_type{1e-12}), xs...);
})(output, input1, input2, inputs...);
});
}
template <index_int Axis, class F, class Output, class Input, class... Inputs>
__device__ void layernorm(F compute, Output output, Input input, Inputs... inputs)
{
generic_binary_layernorm<Axis>(
compute, [](auto x, auto) { return x; }, output, input, input, inputs...);
}
template <index_int Axis, class F, class Output, class Input1, class Input2, class... Inputs>
__device__ void
add_layernorm(F compute, Output output, Input1 input1, Input2 input2, Inputs... inputs)
{
generic_binary_layernorm<Axis>(
compute, [](auto x1, auto x2) { return x1 + x2; }, output, input1, input2, inputs...);
}
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_LAYERNORM_HPP
...@@ -24,16 +24,86 @@ ...@@ -24,16 +24,86 @@
#include <migraphx/gpu/prefuse_ops.hpp> #include <migraphx/gpu/prefuse_ops.hpp>
#include <migraphx/match/layernorm.hpp> #include <migraphx/match/layernorm.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace { namespace {
template <class Derived, std::size_t N>
struct layernorm_base
{
shape compute_shape(std::vector<shape> inputs, std::vector<module_ref> mods) const
{
std::size_t nargs = 1;
if(not mods.empty())
{
auto* pm = mods.front();
nargs = pm->get_parameter_names().size();
}
check_shapes{inputs, static_cast<const Derived&>(*this)}.has(nargs + N);
auto s = inputs.at(0);
if(s.scalar())
{
return s;
}
else if(s.broadcasted())
{
return {s.type(), s.lens()};
}
else
{
return s.with_lens(s.lens());
}
}
};
struct layernorm : layernorm_base<layernorm, 0>
{
std::string name() const { return "gpu::prelayernorm"; }
};
MIGRAPHX_REGISTER_OP(layernorm);
struct add_layernorm : layernorm_base<add_layernorm, 1>
{
std::string name() const { return "gpu::preadd_layernorm"; }
};
MIGRAPHX_REGISTER_OP(add_layernorm);
struct find_layernorm struct find_layernorm
{ {
auto matcher() const { return match::layernorm(); } auto matcher() const { return match::layernorm(); }
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto x_ins = r.instructions["x"];
m.replace_instruction(ins, layernorm{}, x_ins);
}
};
struct find_add_layernorm
{
auto matcher() const
{
return match::layernorm()(match::var("x")(match::name("add").bind("add")));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto add_ins = r.instructions["add"];
m.replace_instruction(ins, add_layernorm{}, add_ins->inputs());
}
};
struct find_gpulayernorm
{
auto matcher() const { return match::layernorm(); }
void apply(module& m, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
...@@ -53,7 +123,7 @@ struct find_layernorm ...@@ -53,7 +123,7 @@ struct find_layernorm
} }
}; };
struct find_triaddlayernorm struct find_gputriaddlayernorm
{ {
auto matcher() const auto matcher() const
{ {
...@@ -91,7 +161,8 @@ struct find_triaddlayernorm ...@@ -91,7 +161,8 @@ struct find_triaddlayernorm
void prefuse_ops::apply(module& m) const void prefuse_ops::apply(module& m) const
{ {
match::find_matches(m, find_triaddlayernorm{}, find_layernorm{}); match::find_matches(m, find_add_layernorm{}, find_layernorm{});
// match::find_matches(m, find_gputriaddlayernorm{}, find_gpulayernorm{});
} }
} // namespace gpu } // namespace gpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment