Commit a22ec139 authored by Manupa Karunaratne's avatar Manupa Karunaratne
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into mlir-attention

parents 9898823d 650ba45f
...@@ -632,6 +632,9 @@ struct find_transpose_contiguous_reshaper_unary ...@@ -632,6 +632,9 @@ struct find_transpose_contiguous_reshaper_unary
} }
}; };
// simplifies broadcast->transpose to transpose->broadcast
// in the case of a scalar, simply rewrite to broadcast
// this can allow for further optimizations with find_inner_broadcast() in simplify_algebra.cpp
struct find_broadcast_transpose struct find_broadcast_transpose
{ {
auto matcher() const auto matcher() const
...@@ -642,17 +645,30 @@ struct find_broadcast_transpose ...@@ -642,17 +645,30 @@ struct find_broadcast_transpose
void apply(module& m, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto transpose = r.result;
auto ins_lens = ins->get_shape().lens(); auto transpose_lens = transpose->get_shape().lens();
auto bcast_ins = r.instructions["bcast_ins"]; auto bcast_ins = r.instructions["bcast_ins"];
auto input = bcast_ins->inputs().front(); auto input = bcast_ins->inputs().front();
// for now, focusing on scalar transformation // scalar transformation does not need extra transpose
if(not input->get_shape().scalar()) if(not input->get_shape().scalar())
return; {
// find common shape
auto in_lens = input->get_shape().lens();
int lens_diff = transpose_lens.size() - in_lens.size();
// insert unsqueeze if input lens < transpose lens
if(lens_diff > 0)
{
std::vector<size_t> unsqueeze_axes(lens_diff);
std::iota(unsqueeze_axes.begin(), unsqueeze_axes.end(), 0);
input = m.insert_instruction(
bcast_ins, make_op("unsqueeze", {{"axes", unsqueeze_axes}}), input);
}
// apply transpose before the multibroadcast
input = m.insert_instruction(bcast_ins, transpose->get_operator(), input);
}
auto new_mbcast = m.insert_instruction( auto new_mbcast = m.insert_instruction(
bcast_ins, make_op("multibroadcast", {{"out_lens", ins_lens}}), input); bcast_ins, make_op("multibroadcast", {{"out_lens", transpose_lens}}), input);
m.replace_instruction(ins, new_mbcast); m.replace_instruction(transpose, new_mbcast);
} }
}; };
......
...@@ -91,6 +91,19 @@ struct post_op : reflect_equality<post_op>, reflect_stream<post_op> ...@@ -91,6 +91,19 @@ struct post_op : reflect_equality<post_op>, reflect_stream<post_op>
} }
}; };
template <class F>
struct execute_wrapper
{
F f;
argument operator()(context&, const std::vector<argument>& args) const { return f(args); }
};
template <class F>
execute_wrapper<F> make_execute_wrapper(F f)
{
return {std::move(f)};
}
template <class Derived, class Primitive> template <class Derived, class Primitive>
struct dnnl_op : auto_register_op<Derived> struct dnnl_op : auto_register_op<Derived>
{ {
...@@ -308,7 +321,7 @@ struct dnnl_op : auto_register_op<Derived> ...@@ -308,7 +321,7 @@ struct dnnl_op : auto_register_op<Derived>
#ifndef NDEBUG #ifndef NDEBUG
auto prim_attr = get_primitive_attr(md); auto prim_attr = get_primitive_attr(md);
#endif #endif
execute = [=](context&, const std::vector<argument>& args) { execute = make_execute_wrapper([=](const std::vector<argument>& args) {
#ifndef NDEBUG #ifndef NDEBUG
// Check that the memory descriptors have not changed // Check that the memory descriptors have not changed
auto debug_args = args; auto debug_args = args;
...@@ -379,7 +392,7 @@ struct dnnl_op : auto_register_op<Derived> ...@@ -379,7 +392,7 @@ struct dnnl_op : auto_register_op<Derived>
m[arg_lookup[i]] = to_dnnl_memory(md.at(arg_lookup[i]), args[i]); m[arg_lookup[i]] = to_dnnl_memory(md.at(arg_lookup[i]), args[i]);
prim.execute(get_dnnl_context().stream, m); prim.execute(get_dnnl_context().stream, m);
return args.back(); return args.back();
}; });
} }
std::vector<shape> trim_post_op_inputs(const std::vector<shape>& inputs) const std::vector<shape> trim_post_op_inputs(const std::vector<shape>& inputs) const
{ {
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#ifndef MIGRAPHX_GUARD_CPU_FUSE_OPS_HPP #ifndef MIGRAPHX_GUARD_CPU_FUSE_OPS_HPP
#define MIGRAPHX_GUARD_CPU_FUSE_OPS_HPP #define MIGRAPHX_GUARD_CPU_FUSE_OPS_HPP
#include <migraphx/config.hpp> #include <migraphx/cpu/context.hpp>
#include <string> #include <string>
namespace migraphx { namespace migraphx {
...@@ -34,9 +34,7 @@ struct module; ...@@ -34,9 +34,7 @@ struct module;
namespace cpu { namespace cpu {
struct context; struct MIGRAPHX_CPU_EXPORT fuse_ops
struct fuse_ops
{ {
context* ctx = nullptr; context* ctx = nullptr;
std::string name() const { return "cpu::fuse_ops"; } std::string name() const { return "cpu::fuse_ops"; }
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_CPU_POINTWISE_HPP #ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_CPU_POINTWISE_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_CPU_POINTWISE_HPP #define MIGRAPHX_GUARD_AMDMIGRAPHX_CPU_POINTWISE_HPP
#include <array>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/context.hpp> #include <migraphx/context.hpp>
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
......
...@@ -185,8 +185,7 @@ struct compile_plan ...@@ -185,8 +185,7 @@ struct compile_plan
results.begin(), results.end(), std::back_inserter(times), [&](const auto& cr) { results.begin(), results.end(), std::back_inserter(times), [&](const auto& cr) {
if(not cr.has_value()) if(not cr.has_value())
return std::numeric_limits<double>::max(); return std::numeric_limits<double>::max();
return time_op(*ctx, cr->replace.code_object, to_shapes(cr->ins->inputs()), 20) return time_op(*ctx, cr->replace.code_object, to_shapes(cr->ins->inputs()), 20);
.first;
}); });
auto i = std::distance(times.begin(), std::min_element(times.begin(), times.end())); auto i = std::distance(times.begin(), std::min_element(times.begin(), times.end()));
std::cout << "Fastest solution: " << config->solutions.at(i) << std::endl; std::cout << "Fastest solution: " << config->solutions.at(i) << std::endl;
......
...@@ -38,10 +38,8 @@ struct compile_op : action<compile_op> ...@@ -38,10 +38,8 @@ struct compile_op : action<compile_op>
context ctx; context ctx;
auto inputs = p.parse_shapes(v.at("inputs")); auto inputs = p.parse_shapes(v.at("inputs"));
auto op = gpu::compile_op(v.at("name").to<std::string>(), ctx, inputs, v); auto op = gpu::compile_op(v.at("name").to<std::string>(), ctx, inputs, v);
auto [host_time, device_time] = time_op(ctx, op, inputs, p.get(v, "iterations", 100)); auto t = time_op(ctx, op, inputs, p.get(v, "iterations", 100));
std::cout << op << ": " << host_time << "ms"; std::cout << op << ": " << t << "ms";
if(device_time > 0)
std::cout << ", " << device_time << "ms";
std::cout << std::endl; std::cout << std::endl;
} }
}; };
......
...@@ -43,8 +43,8 @@ struct run_op : action<run_op> ...@@ -43,8 +43,8 @@ struct run_op : action<run_op>
auto op = make_op(name); auto op = make_op(name);
if(v.contains("fields")) if(v.contains("fields"))
op.from_value(v.at("fields")); op.from_value(v.at("fields"));
auto [host_time, device_time] = time_op(ctx, op, inputs, p.get(v, "iterations", 100)); auto t = time_op(ctx, op, inputs, p.get(v, "iterations", 100));
std::cout << op << ": " << host_time << "ms" << std::endl; std::cout << op << ": " << t << "ms" << std::endl;
} }
}; };
......
...@@ -22,9 +22,9 @@ ...@@ -22,9 +22,9 @@
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <migraphx/gpu/fuse_ck.hpp> #include <migraphx/gpu/fuse_ck.hpp>
#include <migraphx/gpu/gemm_softmax_gemm.hpp>
#include <migraphx/matcher.hpp> #include <migraphx/matcher.hpp>
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp> #include <migraphx/register_op.hpp>
namespace migraphx { namespace migraphx {
...@@ -55,7 +55,7 @@ struct ck_gemm ...@@ -55,7 +55,7 @@ struct ck_gemm
{ {
check_shapes{inputs, *this}.same_ndims(); check_shapes{inputs, *this}.same_ndims();
if(inputs.size() < 2) if(inputs.size() < 2)
MIGRAPHX_THROW("should have at least two inputs."); MIGRAPHX_THROW(name() + ": should have at least two inputs.");
auto a = inputs[0]; auto a = inputs[0];
auto b = inputs[1]; auto b = inputs[1];
for(const auto& input : inputs) for(const auto& input : inputs)
...@@ -65,21 +65,27 @@ struct ck_gemm ...@@ -65,21 +65,27 @@ struct ck_gemm
return r; return r;
return r.with_type(mods.front()->get_output_shapes().front().type()); return r.with_type(mods.front()->get_output_shapes().front().type());
} }
static bool is_ck_supported_type(shape::type_t t)
{
return contains({shape::half_type, shape::int8_type, shape::int32_type}, t);
}
}; };
MIGRAPHX_REGISTER_OP(ck_gemm); MIGRAPHX_REGISTER_OP(ck_gemm);
namespace { struct ck_gemm_softmax_gemm : gemm_softmax_gemm
bool is_ck_supported_type(shape::type_t t)
{ {
return contains({shape::half_type, shape::int8_type, shape::int32_type}, t); std::string name() const { return "gpu::ck_gemm_softmax_gemm"; }
} };
MIGRAPHX_REGISTER_OP(ck_gemm_softmax_gemm);
namespace {
MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins) MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
{ {
if(ins->name() != "dot" and ins->name() != "quant_dot") if(ins->name() != "dot" and ins->name() != "quant_dot")
return false; return false;
if(not is_ck_supported_type(ins->get_shape().type())) if(not ck_gemm::is_ck_supported_type(ins->get_shape().type()))
return false; return false;
auto a = ins->inputs().front()->get_shape(); auto a = ins->inputs().front()->get_shape();
auto b = ins->inputs().back()->get_shape(); auto b = ins->inputs().back()->get_shape();
...@@ -127,7 +133,11 @@ struct find_ck_gemm_pointwise ...@@ -127,7 +133,11 @@ struct find_ck_gemm_pointwise
ins->get_shape().type() != gemm_ins->get_shape().type()) ins->get_shape().type() != gemm_ins->get_shape().type())
return; return;
if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [](auto input) { if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [](auto input) {
return not is_ck_supported_type(input->get_shape().type()); return not ck_gemm::is_ck_supported_type(input->get_shape().type());
}))
return;
if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [](auto input) {
return not input->inputs().empty() and input->inputs().front()->name() == "capture";
})) }))
return; return;
assert(gemm_it != inputs.end()); assert(gemm_it != inputs.end());
...@@ -152,7 +162,7 @@ struct find_ck_gemm_pointwise ...@@ -152,7 +162,7 @@ struct find_ck_gemm_pointwise
struct find_ck_gemm struct find_ck_gemm
{ {
auto matcher() const { return match::name("dot")(is_ck_gemm().bind("gemm")); } auto matcher() const { return match::name("dot", "quant_dot")(is_ck_gemm().bind("gemm")); }
void apply(module_pass_manager& mpm, const match::matcher_result& r) const void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{ {
...@@ -161,11 +171,26 @@ struct find_ck_gemm ...@@ -161,11 +171,26 @@ struct find_ck_gemm
} }
}; };
struct find_ck_gemm_softmax_gemm
{
auto matcher() const { return match::name("gpu::pre_gemm_softmax_gemm"); }
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto ins = r.result;
auto v = ins->get_operator().to_value();
assert(v.contains("scale"));
auto scale = v.at("scale").to<float>();
mpm.get_module().replace_instruction(
ins, ck_gemm_softmax_gemm{migraphx::make_op("dot"), scale}, ins->inputs());
}
};
} // namespace } // namespace
void fuse_ck::apply(module_pass_manager& mpm) const void fuse_ck::apply(module_pass_manager& mpm) const
{ {
match::find_matches(mpm, find_ck_gemm_pointwise{}); match::find_matches(mpm, find_ck_gemm_softmax_gemm{}, find_ck_gemm_pointwise{});
match::find_matches(mpm, find_ck_gemm{}); match::find_matches(mpm, find_ck_gemm{});
} }
......
...@@ -36,24 +36,14 @@ struct module; ...@@ -36,24 +36,14 @@ struct module;
namespace gpu { namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_EXTRA_MLIR);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_MLIR);
bool mlir_enabled() bool mlir_enabled()
{ {
#ifdef MIGRAPHX_MLIR #ifdef MIGRAPHX_MLIR
const bool mlir_enabled = enabled(MIGRAPHX_ENABLE_MLIR{}); const bool mlir_disabled = enabled(MIGRAPHX_DISABLE_MLIR{});
if(mlir_enabled) return not mlir_disabled;
{
return true;
}
else
{
std::cerr << "WARNING: MIGraphX built with MLIR but it is not enabled. Please set the env "
"var MIGRAPHX_ENABLE_MLIR to use MLIR kernel generator."
<< std::endl;
return false;
}
#else #else
return false; return false;
#endif #endif
...@@ -131,9 +121,16 @@ fuse_input_ops_and_gemm_based_op(module_ref mm, instruction_ref gemm_based_op) ...@@ -131,9 +121,16 @@ fuse_input_ops_and_gemm_based_op(module_ref mm, instruction_ref gemm_based_op)
for(instruction_ref input : gemm_based_op->inputs()) for(instruction_ref input : gemm_based_op->inputs())
{ {
std::vector<operation> op_stream; std::vector<operation> op_stream;
while(contains({"slice", "transpose", "contiguous", "reshape"}, input->name())) while(contains(
{"slice", "transpose", "contiguous", "reshape", "squeeze", "flatten", "unsqueeze"},
input->name()))
{
operation op = input->get_operator();
if(contains({"squeeze", "flatten", "unsqueeze"}, input->name()))
{ {
op_stream.push_back(input->get_operator()); op = migraphx::make_op("reshape", {{"dims", input->get_shape().lens()}});
}
op_stream.push_back(op);
input = input->inputs().at(0); input = input->inputs().at(0);
} }
top_inputs.push_back(input); top_inputs.push_back(input);
...@@ -150,8 +147,40 @@ fuse_input_ops_and_gemm_based_op(module_ref mm, instruction_ref gemm_based_op) ...@@ -150,8 +147,40 @@ fuse_input_ops_and_gemm_based_op(module_ref mm, instruction_ref gemm_based_op)
return {new_gemm_based_op, top_inputs}; return {new_gemm_based_op, top_inputs};
} }
MIGRAPHX_PRED_MATCHER(is_mlir_conv, instruction_ref ins) enum class mlir_mode
{
all,
fast,
int8,
none
};
auto is_mlir_dot(mlir_mode mode)
{ {
return match::make_basic_pred_matcher([=](instruction_ref ins) {
if(mode == mlir_mode::none)
return false;
if(ins->name() != "dot" and ins->name() != "quant_dot")
return false;
if(mode != mlir_mode::fast)
return true;
auto a = ins->inputs().front()->get_shape();
auto b = ins->inputs().back()->get_shape();
// auto m = a.lens()[a.lens().size() - 2];
// auto n = b.lens().back();
auto k = a.lens().back();
// Skipping GEMMs with a K dimension greater than 2048 is a course-grained strategy
// to avoid poor-performing GEMM kernels from MLIR
// To-do: Investigate a more precise strategy
return k <= 2048;
});
}
auto is_mlir_conv(mlir_mode mode)
{
return match::make_basic_pred_matcher([=](instruction_ref ins) {
if(mode == mlir_mode::none)
return false;
if(ins->name() != "convolution" and ins->name() != "quant_convolution") if(ins->name() != "convolution" and ins->name() != "quant_convolution")
return false; return false;
value v = ins->get_operator().to_value(); value v = ins->get_operator().to_value();
...@@ -161,7 +190,19 @@ MIGRAPHX_PRED_MATCHER(is_mlir_conv, instruction_ref ins) ...@@ -161,7 +190,19 @@ MIGRAPHX_PRED_MATCHER(is_mlir_conv, instruction_ref ins)
// Avoid MLIR assertion: Index < Length && "Invalid index!" // Avoid MLIR assertion: Index < Length && "Invalid index!"
if(ins->get_shape().lens().size() != 4) if(ins->get_shape().lens().size() != 4)
return false; return false;
if(ins->get_shape().type() == shape::int8_type)
return true;
if(mode == mlir_mode::int8)
return false;
if(mode == mlir_mode::all)
return true;
auto w = ins->inputs().at(1)->get_shape();
if(w.lens().size() != 4)
return true;
if(w.lens()[2] != w.lens()[3])
return true; return true;
return (w.lens()[3] % 3) != 0;
});
} }
std::unordered_map<instruction_ref, instruction_ref> std::unordered_map<instruction_ref, instruction_ref>
...@@ -273,11 +314,12 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i) ...@@ -273,11 +314,12 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i)
struct find_mlir_fused_ops struct find_mlir_fused_ops
{ {
mlir_mode conv_mode = mlir_mode::none;
mlir_mode dot_mode = mlir_mode::none;
auto matcher() const auto matcher() const
{ {
auto dot_or_conv = match::skip(match::name("contiguous"))( auto dot_or_conv = match::skip(match::name("contiguous"))(
match::any_of(match::name("dot"), match::name("quant_dot"), is_mlir_conv()) match::any_of(is_mlir_dot(dot_mode), is_mlir_conv(conv_mode)).bind("gemm_based_op"));
.bind("gemm_based_op"));
return match::name("pointwise")(match::any_of[match::inputs()](dot_or_conv.bind("x"))); return match::name("pointwise")(match::any_of[match::inputs()](dot_or_conv.bind("x")));
} }
...@@ -323,9 +365,12 @@ struct find_mlir_fused_ops ...@@ -323,9 +365,12 @@ struct find_mlir_fused_ops
} }
}; };
template <auto Matcher>
struct find_mlir_standalone_op struct find_mlir_standalone_op
{ {
mlir_mode mode = mlir_mode::none;
auto matcher() const { return Matcher(mode); }
void rewrite(module_pass_manager& mpm, instruction_ref top_ins) const void rewrite(module_pass_manager& mpm, instruction_ref top_ins) const
{ {
static size_t counter = 0; static size_t counter = 0;
...@@ -340,6 +385,7 @@ struct find_mlir_standalone_op ...@@ -340,6 +385,7 @@ struct find_mlir_standalone_op
void apply(module_pass_manager& mpm, const match::matcher_result& r) const void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{ {
auto conv_based_op = r.result; auto conv_based_op = r.result;
//
// enable only for fp32/fp16/i8 types // enable only for fp32/fp16/i8 types
if(std::any_of(conv_based_op->inputs().begin(), conv_based_op->inputs().end(), [&](auto i) { if(std::any_of(conv_based_op->inputs().begin(), conv_based_op->inputs().end(), [&](auto i) {
return not contains( return not contains(
...@@ -351,18 +397,12 @@ struct find_mlir_standalone_op ...@@ -351,18 +397,12 @@ struct find_mlir_standalone_op
} }
}; };
struct find_mlir_standalone_convolution_op : find_mlir_standalone_op using find_mlir_standalone_convolution_op = find_mlir_standalone_op<&is_mlir_conv>;
{ using find_mlir_standalone_dot_op = find_mlir_standalone_op<&is_mlir_dot>;
auto matcher() const { return is_mlir_conv; }
};
struct find_mlir_standalone_dot_op : find_mlir_standalone_op
{
auto matcher() const { return match::any_of(match::name("dot"), match::name("quant_dot")); }
};
struct find_mlir_standalone_attention_op struct find_mlir_standalone_attention_op
{ {
mlir_mode mode = mlir_mode::none;
void rewrite(module_pass_manager& mpm, const match::matcher_result& r) const void rewrite(module_pass_manager& mpm, const match::matcher_result& r) const
{ {
static size_t counter = 0; static size_t counter = 0;
...@@ -373,6 +413,7 @@ struct find_mlir_standalone_attention_op ...@@ -373,6 +413,7 @@ struct find_mlir_standalone_attention_op
std::unordered_map<instruction_ref, instruction_ref> ins_map; std::unordered_map<instruction_ref, instruction_ref> ins_map;
auto top_ins = r.instructions["top_dot"]; auto top_ins = r.instructions["top_dot"];
auto [new_top_ins, top_inputs] = fuse_input_ops_and_gemm_based_op(mm, top_ins); auto [new_top_ins, top_inputs] = fuse_input_ops_and_gemm_based_op(mm, top_ins);
inputs.insert(inputs.begin(), top_inputs.begin(), top_inputs.end());
ins_map[top_ins] = new_top_ins; ins_map[top_ins] = new_top_ins;
if(r.instructions.find("scale") != r.instructions.end()){ if(r.instructions.find("scale") != r.instructions.end()){
auto scale_ins = r.instructions["scale"]; auto scale_ins = r.instructions["scale"];
...@@ -401,7 +442,6 @@ struct find_mlir_standalone_attention_op ...@@ -401,7 +442,6 @@ struct find_mlir_standalone_attention_op
new_bottom_dot = fold_pointwise_mod(r.instructions["trailing_pm"], mm, ins_map)[0]; new_bottom_dot = fold_pointwise_mod(r.instructions["trailing_pm"], mm, ins_map)[0];
} }
mm->add_return({new_bottom_dot}); mm->add_return({new_bottom_dot});
inputs.insert(inputs.end(), top_inputs.begin(), top_inputs.end());
mpm.get_module().replace_instruction( mpm.get_module().replace_instruction(
r.instructions["bottom_dot"], mlir_op{new_bottom_dot->get_operator()}, inputs, {mm}); r.instructions["bottom_dot"], mlir_op{new_bottom_dot->get_operator()}, inputs, {mm});
} }
...@@ -413,6 +453,11 @@ struct find_mlir_standalone_attention_op ...@@ -413,6 +453,11 @@ struct find_mlir_standalone_attention_op
} }
bool check(const match::matcher_result& r) const { bool check(const match::matcher_result& r) const {
// We are only enabling attention
// in the highest enablement mode for now
if(mode != mlir_mode::all){
return false;
}
auto top_dot = r.instructions["top_dot"]; auto top_dot = r.instructions["top_dot"];
// Check the pointwise mod only contains a single mul // Check the pointwise mod only contains a single mul
if(r.instructions.find("scale") != r.instructions.end()){ if(r.instructions.find("scale") != r.instructions.end()){
...@@ -491,44 +536,15 @@ struct find_mlir_attention_fused_ops : public find_mlir_standalone_attention_op ...@@ -491,44 +536,15 @@ struct find_mlir_attention_fused_ops : public find_mlir_standalone_attention_op
* intended to be primarily used by rocMLIR developers. * intended to be primarily used by rocMLIR developers.
*/ */
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_USE_SPECIFIC_OPS); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_USE_SPECIFIC_OPS);
bool is_self_decide() { return string_value_of(MIGRAPHX_MLIR_USE_SPECIFIC_OPS{}, "").empty(); }
bool is_requested(std::string_view option) bool is_requested(std::string_view option, bool fallback = false)
{ {
assert(not is_self_decide());
auto string_value = string_value_of(MIGRAPHX_MLIR_USE_SPECIFIC_OPS{}, ""); auto string_value = string_value_of(MIGRAPHX_MLIR_USE_SPECIFIC_OPS{}, "");
if(string_value.empty())
return fallback;
const auto options = split_string(string_value, ','); const auto options = split_string(string_value, ',');
return contains(options, option); return contains(options, option);
} }
bool is_enabled(std::string_view op_name, context* ctx)
{
if(is_self_decide())
{
if(op_name == "fused")
{
return true;
}
else if(op_name == "convolution" or op_name == "quant_convolution")
{
if(ctx == nullptr)
{
return false;
}
else
{
const auto& device = ctx->get_current_device();
const std::string navi_family{"gfx110"};
return starts_with(device.get_gfx_name(), navi_family);
}
}
else
{
return false;
}
}
return is_requested(op_name);
}
} // namespace } // namespace
#endif // MIGRAPHX_MLIR #endif // MIGRAPHX_MLIR
...@@ -536,22 +552,34 @@ bool is_enabled(std::string_view op_name, context* ctx) ...@@ -536,22 +552,34 @@ bool is_enabled(std::string_view op_name, context* ctx)
void fuse_mlir::apply(module_pass_manager& mpm) const void fuse_mlir::apply(module_pass_manager& mpm) const
{ {
#ifdef MIGRAPHX_MLIR #ifdef MIGRAPHX_MLIR
match::find_matches(mpm, find_mlir_standalone_attention_op{}); const auto& device_name = ctx == nullptr ? "" : ctx->get_current_device().get_gfx_name();
if(is_enabled("fused", this->ctx)) const bool is_navi = starts_with(device_name, "gfx110");
{
match::find_matches(mpm, find_mlir_attention_fused_ops{}); auto get_mode = [&](std::string_view option, mlir_mode m1, mlir_mode m2 = mlir_mode::fast) {
match::find_matches(mpm, find_mlir_fused_ops{}); if(is_requested(option))
} return mlir_mode::all;
if(is_navi)
return mlir_mode::all;
return std::max(m1, m2);
};
if(is_enabled("convolution", this->ctx)) mlir_mode mode =
{ (enabled(MIGRAPHX_ENABLE_EXTRA_MLIR{}) or enable_extra) ? mlir_mode::fast : mlir_mode::none;
match::find_matches(mpm, find_mlir_standalone_convolution_op{});
}
if(is_enabled("dot", this->ctx)) // Attention offloads; default disabled
{ match::find_matches(mpm,
match::find_matches(mpm, find_mlir_standalone_dot_op{}); find_mlir_attention_fused_ops{get_mode("attention", mlir_mode::none)});
} match::find_matches(mpm,
find_mlir_standalone_attention_op{get_mode("attention", mlir_mode::none)});
match::find_matches(mpm,
find_mlir_fused_ops{.conv_mode = get_mode("fused", mlir_mode::fast),
.dot_mode = get_mode("fused", mode)});
match::find_matches(
mpm,
find_mlir_standalone_convolution_op{get_mode("convolution", mlir_mode::int8)},
find_mlir_standalone_dot_op{get_mode("dot", mlir_mode::none)});
#else #else
(void)mpm; (void)mpm;
#endif #endif
......
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include <migraphx/msgpack.hpp> #include <migraphx/msgpack.hpp>
#include <migraphx/file_buffer.hpp> #include <migraphx/file_buffer.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <array>
#include <iostream> #include <iostream>
#include <cstring> #include <cstring>
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_GPU_CK_HPP
#define MIGRAPHX_GUARD_GPU_CK_HPP
#include <migraphx/compile_src.hpp>
#include <migraphx/env.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/stringutils.hpp>
#include <string_view>
#include "ck/host/device_gemm_multiple_d.hpp"
#include "ck/host/device_batched_gemm_softmax_gemm.hpp"
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
#ifndef _WIN32
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_LOG_CK_GEMM);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_DEBUG);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TUNE_CK);
#endif
// NOLINTNEXTLINE
const char* const disable_warning_pragma = R"__migraphx__(
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
${content}
#pragma clang diagnostic pop
)__migraphx__";
template <class P>
std::string ck_disable_warnings(P p)
{
return interpolate_string(disable_warning_pragma,
{{"content", std::string{p.data(), p.size()}}});
}
static std::unordered_map<std::string, std::string> create_ck_header_strings()
{
std::unordered_map<std::string, std::string> result;
auto ck_headers = ck::host::GetHeaders();
std::transform(
ck_headers.begin(), ck_headers.end(), std::inserter(result, result.begin()), [&](auto& p) {
return std::pair<std::string, std::string>(p.first, ck_disable_warnings(p.second));
});
return result;
}
static std::vector<src_file> create_ck_headers()
{
static const auto& header_strings = create_ck_header_strings();
std::vector<src_file> srcs;
std::transform(header_strings.begin(),
header_strings.end(),
std::back_inserter(srcs),
[&](auto& p) { return src_file{p}; });
return srcs;
}
static inline const std::vector<src_file>& ck_headers()
{
static const auto& headers = create_ck_headers();
return headers;
}
inline bool transposed_matrix(const shape& s) { return s.strides().back() != 1; }
inline ck::host::DataType get_type(const shape& s)
{
if(s.type() == shape::half_type)
return ck::host::DataType::Half;
else if(s.type() == shape::float_type)
return ck::host::DataType::Float;
else if(s.type() == shape::int8_type)
return ck::host::DataType::Int8;
else if(s.type() == shape::int32_type)
return ck::host::DataType::Int32;
MIGRAPHX_THROW("Unsupported ck type");
}
inline std::size_t get_batch_count(const shape& s)
{
return std::accumulate(
s.lens().rbegin() + 2, s.lens().rend(), std::size_t{1}, std::multiplies<std::size_t>());
}
inline void fold_batch_dims(shape& s)
{
auto lens = s.lens();
if(lens.size() <= 2)
return;
auto batch_count = get_batch_count(s);
auto m1 = lens.at(lens.size() - 2);
auto m2 = lens.at(lens.size() - 1);
if(transposed_matrix(s))
s = shape{s.type(), {m1, m2 * batch_count}};
else
s = shape{s.type(), {m1 * batch_count, m2}};
}
inline void remove_batch_dims(shape& s)
{
auto lens = s.lens();
if(lens.size() <= 2)
return;
auto m1 = lens.at(lens.size() - 2);
auto m2 = lens.at(lens.size() - 1);
s = shape{s.type(), {m1, m2}};
}
inline bool standard_batch(const shape& s)
{
if(s.lens().size() < 3)
return true;
std::vector<std::size_t> lens(s.lens().begin(), s.lens().end() - 2);
std::vector<std::size_t> strides(s.strides().begin(), s.strides().end() - 2);
auto base = *(s.lens().end() - 2) * *(s.lens().end() - 1);
std::transform(strides.begin(), strides.end(), strides.begin(), [&](auto stride) {
return stride / base;
});
return shape{s.type(), lens, strides}.standard();
}
inline bool can_fold_batch(const std::vector<shape>& inputs)
{
const auto& b_shape = inputs[1];
if(std::any_of(inputs.begin() + 2, inputs.end() - 1, [](auto input) {
return not standard_batch(input);
}))
return false;
const auto& b_strides = b_shape.strides();
return std::all_of(
b_strides.begin(), b_strides.end() - 2, [](auto stride) { return stride == 0; });
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_GPU_CK_HPP
...@@ -299,23 +299,6 @@ struct context ...@@ -299,23 +299,6 @@ struct context
any_ptr get_queue() { return get_stream().get(); } any_ptr get_queue() { return get_stream().get(); }
void enable_perf_measurement(bool b = true)
{
if(b)
{
start_event = create_event_for_timing();
stop_event = create_event_for_timing();
get_stream().record(start_event.get());
get_stream().record(stop_event.get());
}
else
{
start_event = nullptr;
stop_event = nullptr;
}
measure_perf = b;
}
std::pair<hipEvent_t, hipEvent_t> get_perf_events() const std::pair<hipEvent_t, hipEvent_t> get_perf_events() const
{ {
if(measure_perf) if(measure_perf)
...@@ -323,12 +306,12 @@ struct context ...@@ -323,12 +306,12 @@ struct context
return std::make_pair(nullptr, nullptr); return std::make_pair(nullptr, nullptr);
} }
float get_elapsed_ms() const static float get_elapsed_ms(hipEvent_t start, hipEvent_t stop)
{ {
float result = 0; float result = 0;
if(start_event != nullptr and stop_event != nullptr) if(start != nullptr and stop != nullptr)
{ {
auto status = hipEventElapsedTime(&result, start_event.get(), stop_event.get()); auto status = hipEventElapsedTime(&result, start, stop);
if(status != hipSuccess) if(status != hipSuccess)
MIGRAPHX_THROW("Failed hipEventElapsedTime: " + hip_error(status)); MIGRAPHX_THROW("Failed hipEventElapsedTime: " + hip_error(status));
} }
......
...@@ -38,6 +38,7 @@ MIGRAPHX_GPU_EXPORT bool mlir_enabled(); ...@@ -38,6 +38,7 @@ MIGRAPHX_GPU_EXPORT bool mlir_enabled();
struct MIGRAPHX_GPU_EXPORT fuse_mlir struct MIGRAPHX_GPU_EXPORT fuse_mlir
{ {
context* ctx = nullptr; context* ctx = nullptr;
bool enable_extra = false;
std::string name() const { return "gpu::fuse_mlir"; } std::string name() const { return "gpu::fuse_mlir"; }
void apply(module_pass_manager& mpm) const; void apply(module_pass_manager& mpm) const;
}; };
......
...@@ -24,7 +24,6 @@ ...@@ -24,7 +24,6 @@
#ifndef MIGRAPHX_GUARD_RTGLIB_FUSE_OPS_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_FUSE_OPS_HPP
#define MIGRAPHX_GUARD_RTGLIB_FUSE_OPS_HPP #define MIGRAPHX_GUARD_RTGLIB_FUSE_OPS_HPP
#include <migraphx/config.hpp>
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
namespace migraphx { namespace migraphx {
...@@ -34,7 +33,7 @@ struct module; ...@@ -34,7 +33,7 @@ struct module;
namespace gpu { namespace gpu {
struct fuse_ops struct MIGRAPHX_GPU_EXPORT fuse_ops
{ {
context* ctx = nullptr; context* ctx = nullptr;
bool fast_math = true; bool fast_math = true;
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_GPU_GEMM_SOFTMAX_GEMM_HPP
#define MIGRAPHX_GUARD_GPU_GEMM_SOFTMAX_GEMM_HPP
#include <migraphx/make_op.hpp>
#include <migraphx/check_shapes.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct gemm_softmax_gemm
{
operation op = make_op("dot");
float scale = 1.0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.op, "op"), f(self.scale, "scale"));
}
std::string name() const { return "gpu::gemm_softmax_gemm"; }
void check_gemm_shape(const shape& s) const
{
if(not contains(range(s.strides().rbegin(), s.strides().rbegin() + 3), 1))
MIGRAPHX_THROW("Invalid shape for " + name());
}
shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>&) const
{
check_shapes{inputs, *this}.same_ndims();
if(inputs.size() < 3)
MIGRAPHX_THROW(name() + ": Expected 3 inputs but got " + to_string(inputs.size()));
auto a = inputs[0];
auto b = inputs[1];
auto b1 = inputs[2];
for(const auto& input : inputs)
{
check_gemm_shape(input);
}
return op.compute_shape({op.compute_shape({a, b}), b1});
}
static bool is_ck_supported_type(shape::type_t t) { return contains({shape::half_type}, t); }
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_GPU_GEMM_SOFTMAX_GEMM_HPP
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#ifndef MIGRAPHX_GUARD_GPU_PREFUSE_OPS_HPP #ifndef MIGRAPHX_GUARD_GPU_PREFUSE_OPS_HPP
#define MIGRAPHX_GUARD_GPU_PREFUSE_OPS_HPP #define MIGRAPHX_GUARD_GPU_PREFUSE_OPS_HPP
#include <migraphx/config.hpp> #include <migraphx/gpu/config.hpp>
#include <string> #include <string>
namespace migraphx { namespace migraphx {
...@@ -34,7 +34,7 @@ struct module_pass_manager; ...@@ -34,7 +34,7 @@ struct module_pass_manager;
namespace gpu { namespace gpu {
struct prefuse_ops struct MIGRAPHX_GPU_EXPORT prefuse_ops
{ {
std::string name() const { return "gpu::prefuse_ops"; } std::string name() const { return "gpu::prefuse_ops"; }
void apply(module_pass_manager& mpm) const; void apply(module_pass_manager& mpm) const;
......
...@@ -32,7 +32,7 @@ namespace migraphx { ...@@ -32,7 +32,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
MIGRAPHX_GPU_EXPORT std::pair<double, double> MIGRAPHX_GPU_EXPORT double
time_op(context& ictx, operation op, const std::vector<shape>& inputs, int n = 100); time_op(context& ictx, operation op, const std::vector<shape>& inputs, int n = 100);
} // namespace gpu } // namespace gpu
......
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/ck.hpp>
#include <migraphx/env.hpp> #include <migraphx/env.hpp>
#include <migraphx/file_buffer.hpp> #include <migraphx/file_buffer.hpp>
#include <migraphx/gpu/compile_gen.hpp> #include <migraphx/gpu/compile_gen.hpp>
...@@ -37,8 +38,6 @@ ...@@ -37,8 +38,6 @@
#include <migraphx/reduce_dims.hpp> #include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include "ck/host/device_gemm_multiple_d.hpp"
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -46,12 +45,6 @@ namespace gpu { ...@@ -46,12 +45,6 @@ namespace gpu {
using namespace migraphx::gpu::gen; // NOLINT using namespace migraphx::gpu::gen; // NOLINT
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_LOG_CK_GEMM);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_TUNING);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_TUNING_VALUE);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_DEBUG);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TUNE_CK);
// NOLINTNEXTLINE // NOLINTNEXTLINE
static const char* const ck_gemm_kernel = R"__migraphx__( static const char* const ck_gemm_kernel = R"__migraphx__(
#include <args.hpp> #include <args.hpp>
...@@ -79,219 +72,10 @@ MIGRAPHX_GLOBAL void ${kernel}(${params}) ...@@ -79,219 +72,10 @@ MIGRAPHX_GLOBAL void ${kernel}(${params})
)__migraphx__"; )__migraphx__";
// NOLINTNEXTLINE
static const char* const disable_warning_pragma = R"__migraphx__(
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
${content}
#pragma clang diagnostic pop
)__migraphx__";
template <class P>
static std::string ck_disable_warnings(P p)
{
return interpolate_string(disable_warning_pragma,
{{"content", std::string{p.first, p.second}}});
}
static std::unordered_map<std::string, std::string> create_ck_header_strings()
{
std::unordered_map<std::string, std::string> result;
auto ck_headers = ck::host::GetHeaders();
std::transform(
ck_headers.begin(), ck_headers.end(), std::inserter(result, result.begin()), [&](auto&& p) {
return std::make_pair(p.first, ck_disable_warnings(p.second));
});
return result;
}
static std::vector<src_file> create_ck_headers()
{
static const auto& header_strings = create_ck_header_strings();
std::vector<src_file> srcs;
std::transform(
header_strings.begin(), header_strings.end(), std::back_inserter(srcs), [&](auto&& p) {
return src_file{p.first, p.second};
});
return srcs;
}
static const std::vector<src_file>& ck_headers()
{
static const auto& headers = create_ck_headers();
return headers;
}
static bool transposed_matrix(const shape& s) { return s.strides().back() != 1; }
using tuning_entry = std::pair<std::vector<shape>, size_t>;
static std::vector<tuning_entry> read_tuning(const std::string& s)
{
if(not fs::exists(s))
return {};
return from_value<std::vector<tuning_entry>>(from_json_string(read_string(s)));
}
static float matrix_distance(const shape& x, const shape& y)
{
if(x.type() != y.type())
return std::numeric_limits<float>::max();
if(transposed_matrix(x) != transposed_matrix(y))
return std::numeric_limits<float>::max();
auto sum_squared = std::inner_product(x.lens().rbegin(),
x.lens().rbegin() + 2,
y.lens().rbegin(),
0,
std::plus<>{},
[](auto a, auto b) { return (a - b) * (a - b); });
return std::sqrt(sum_squared);
}
static std::size_t get_tuning_for(const std::vector<shape>& inputs)
{
static auto tuning = read_tuning(string_value_of(MIGRAPHX_CK_TUNING{}, ""));
if(tuning.empty())
{
std::cout << "*********** Warning: No CK tuning! for config:" << std::endl;
std::cout << " " << inputs[0] << std::endl;
std::cout << " " << inputs[1] << std::endl;
std::cout << " " << inputs[2] << std::endl;
}
auto it = std::find_if(
tuning.begin(), tuning.end(), [&](const auto& p) { return p.first == inputs; });
if(it == tuning.end())
{
std::cout << "*********** Warning: CK tuning missing for config!" << std::endl;
std::cout << " " << inputs[0] << std::endl;
std::cout << " " << inputs[1] << std::endl;
std::cout << " " << inputs[2] << std::endl;
std::vector<std::pair<float, std::size_t>> w;
std::transform(tuning.begin(), tuning.end(), std::back_inserter(w), [&](const auto& p) {
if(inputs.size() < 3 or p.first.size() < 3)
MIGRAPHX_THROW("Invalid CK config");
auto avg_distance = std::inner_product(
p.first.begin(),
p.first.begin() + 3,
inputs.begin(),
0.0f,
std::plus<>{},
[](const auto& x, const auto& y) { return matrix_distance(x, y) / 3.0f; });
return std::make_pair(avg_distance, p.second);
});
std::sort(w.begin(), w.end());
std::size_t default_value = 4;
if(not w.empty())
default_value = w.front().second;
auto tuning_val = value_of(MIGRAPHX_CK_TUNING_VALUE{}, default_value);
std::cout << "*********** Warning: CK try tuning: " << tuning_val << std::endl;
return tuning_val;
}
return it->second;
}
struct ck_gemm_compiler : compiler<ck_gemm_compiler> struct ck_gemm_compiler : compiler<ck_gemm_compiler>
{ {
static std::string get_layout(const shape& s)
{
return transposed_matrix(s) ? "ck::tensor_layout::gemm::ColumnMajor"
: "ck::tensor_layout::gemm::RowMajor";
}
static ck::host::DataType get_type(const shape& s)
{
if(s.type() == shape::half_type)
return ck::host::DataType::Half;
else if(s.type() == shape::float_type)
return ck::host::DataType::Float;
else if(s.type() == shape::int8_type)
return ck::host::DataType::Int8;
else if(s.type() == shape::int32_type)
return ck::host::DataType::Int32;
MIGRAPHX_THROW("Unsupported ck type");
}
template <class Iterator, class F>
static std::string ck_tuple(Iterator start, Iterator last, F f)
{
std::vector<std::string> s;
std::transform(start, last, std::back_inserter(s), f);
return "ck::Tuple<" + join_strings(s, ",") + ">";
}
static std::vector<shape> adjust_inputs(std::vector<shape> inputs, bool& swap_inputs)
{
swap_inputs = false;
auto c_shape = inputs.back();
if(not transposed_matrix(c_shape))
return inputs;
std::vector<int64_t> perm(c_shape.lens().size());
std::iota(perm.begin(), perm.end(), 0);
std::swap(perm[perm.size() - 1], perm[perm.size() - 2]);
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](shape s) {
return reorder_shape(s, perm);
});
swap_inputs = true;
return inputs;
}
static std::size_t get_batch_count(const shape& s)
{
return std::accumulate(
s.lens().rbegin() + 2, s.lens().rend(), std::size_t{1}, std::multiplies<std::size_t>());
}
static void fold_batch_dims(shape& s)
{
auto lens = s.lens();
if(lens.size() <= 2)
return;
auto batch_count = get_batch_count(s);
auto m1 = lens.at(lens.size() - 2);
auto m2 = lens.at(lens.size() - 1);
if(transposed_matrix(s))
s = shape{s.type(), {m1, m2 * batch_count}};
else
s = shape{s.type(), {m1 * batch_count, m2}};
}
static void remove_batch_dims(shape& s)
{
auto lens = s.lens();
if(lens.size() <= 2)
return;
auto m1 = lens.at(lens.size() - 2);
auto m2 = lens.at(lens.size() - 1);
s = shape{s.type(), {m1, m2}};
}
std::vector<std::string> names() const { return {"ck_gemm", "gpu::ck_gemm"}; } std::vector<std::string> names() const { return {"ck_gemm", "gpu::ck_gemm"}; }
static bool standard_batch(const shape& s)
{
if(s.lens().size() < 3)
return true;
std::vector<std::size_t> lens(s.lens().begin(), s.lens().end() - 2);
std::vector<std::size_t> strides(s.strides().begin(), s.strides().end() - 2);
auto base = *(s.lens().end() - 2) * *(s.lens().end() - 1);
std::transform(strides.begin(), strides.end(), strides.begin(), [&](auto stride) {
return stride / base;
});
return shape{s.type(), lens, strides}.standard();
}
bool can_fold_batch(const std::vector<shape>& inputs) const
{
const auto& b_shape = inputs[1];
if(std::any_of(inputs.begin() + 2, inputs.end() - 1, [](auto input) {
return not standard_batch(input);
}))
return false;
const auto& b_strides = b_shape.strides();
return std::all_of(
b_strides.begin(), b_strides.end() - 2, [](auto stride) { return stride == 0; });
}
ck::host::device_gemm_multiple_d::Problem create_problem(const std::vector<shape>& inputs, ck::host::device_gemm_multiple_d::Problem create_problem(const std::vector<shape>& inputs,
const value& v) const const value& v) const
{ {
...@@ -301,7 +85,6 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -301,7 +85,6 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
// cppcheck-suppress unreadVariable // cppcheck-suppress unreadVariable
auto rank = a_shape.ndim(); auto rank = a_shape.ndim();
auto batch_count = get_batch_count(c_shape); auto batch_count = get_batch_count(c_shape);
auto m = c_shape.lens()[rank - 2]; auto m = c_shape.lens()[rank - 2];
m = can_fold_batch(inputs) ? m * batch_count : m; m = can_fold_batch(inputs) ? m * batch_count : m;
...@@ -351,12 +134,8 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -351,12 +134,8 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{ {
const auto& a_shape = inputs[0];
const auto& b_shape = inputs[1];
const auto& c_shape = inputs.back(); const auto& c_shape = inputs.back();
auto tuning_value = v.get("tuning_value", 4); auto tuning_value = v.get("tuning_value", 34);
if(not v.contains("tuning_value"))
tuning_value = get_tuning_for({a_shape, b_shape, c_shape});
auto batch_count = get_batch_count(c_shape); auto batch_count = get_batch_count(c_shape);
auto problem = create_problem(inputs, v); auto problem = create_problem(inputs, v);
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <fstream>
#include <migraphx/filesystem.hpp>
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/env.hpp>
#include <migraphx/file_buffer.hpp>
#include <migraphx/gpu/ck.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/module.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
using namespace migraphx::gpu::gen; // NOLINT
// NOLINTNEXTLINE
static const char* const ck_gemm_softmax_gemm_kernel = R"__migraphx__(
#include <args.hpp>
#include <migraphx/kernels/ck_gemm_softmax_gemm.hpp>
#include <migraphx/kernels/pointwise.hpp>
#include <migraphx/kernels/ops.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
#include <${include}>
namespace migraphx {
${preamble}
extern "C" {
MIGRAPHX_GLOBAL void ${kernel}(${params})
{
transform_args(make_tensors(), rotate_last())(${args})([](auto... xs) {
auto settings = make_ck_gemm_softmax_gemm_settings(MIGRAPHX_MAKE_CONSTANT(float{SCALE}));
ck_gemm_softmax_gemm<${solution}, ${blocks_per_batch}>(settings, xs...);
});
}
}
} // namespace migraphx
)__migraphx__";
struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
{
std::vector<std::string> names() const
{
return {"ck_gemm_softmax_gemm", "gpu::ck_gemm_softmax_gemm"};
}
ck::host::device_batched_gemm_softmax_gemm::Problem
create_problem(const std::vector<shape>& inputs, const value&) const
{
const auto& a_shape = inputs[0];
const auto& b_shape = inputs[1];
const auto& b1_shape = inputs[2];
const auto& c_shape = inputs.back();
// cppcheck-suppress unreadVariable
auto rank = a_shape.ndim();
auto batch_count = get_batch_count(c_shape);
auto m = c_shape.lens()[rank - 2];
m = can_fold_batch(inputs) ? m * batch_count : m;
auto n = c_shape.lens().back();
auto k = a_shape.lens().back();
auto o = c_shape.lens().back();
const bool trans_a = transposed_matrix(a_shape);
const bool trans_b = transposed_matrix(b_shape);
const bool trans_b1 = transposed_matrix(b1_shape);
const bool trans_c = transposed_matrix(c_shape);
const auto a_type = get_type(a_shape);
const auto b_type = get_type(b_shape);
const auto b1_type = get_type(b1_shape);
const auto c_type = get_type(c_shape);
std::string ck_passthrough = "ck_passthrough";
return ck::host::device_batched_gemm_softmax_gemm::Problem{m,
n,
k,
o,
trans_a,
trans_b,
trans_b1,
trans_c,
a_type,
b_type,
b1_type,
c_type,
ck_passthrough,
ck_passthrough,
ck_passthrough,
ck_passthrough};
}
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
const auto& c_shape = inputs.back();
auto tuning_value = v.get("tuning_value", 5);
auto batch_count = get_batch_count(c_shape);
auto problem = create_problem(inputs, v);
const auto include_header = problem.GetIncludeHeader();
const auto solutions = problem.GetSolutions(ctx.get_current_device().get_gfx_name());
const auto& solution = solutions.at(tuning_value);
const auto template_str = solution.template_str;
const auto blocks_per_batch = solution.grid_size;
const auto block_size = solution.block_size;
hip_compile_options options;
options.additional_src_files = ck_headers();
auto grid_size = can_fold_batch(inputs) ? blocks_per_batch : batch_count * blocks_per_batch;
options.set_launch_params(v, grid_size * block_size, block_size);
options.inputs = inputs;
options.output = c_shape;
options.kernel_name = v.get("kernel", "ck_gemm_softmax_gemm_kernel");
options.virtual_inputs = inputs;
if(can_fold_batch(inputs))
{
auto vinputs = inputs;
fold_batch_dims(vinputs[0]);
remove_batch_dims(vinputs[1]);
std::for_each(vinputs.begin() + 2, vinputs.end(), fold_batch_dims);
options.virtual_inputs = vinputs;
}
if(v.get("check", false) or enabled(MIGRAPHX_CK_DEBUG{}))
options.params += " -DMIGRAPHX_CK_CHECK=1";
// scale
assert(v.contains("scale"));
auto scale = v.at("scale").to<float>();
options.params += " -DSCALE=" + std::to_string(scale);
auto src = interpolate_string(ck_gemm_softmax_gemm_kernel,
{{"solution", template_str},
{"include", include_header},
{"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")},
{"blocks_per_batch", to_string(blocks_per_batch)},
{"preamble", v.get("preamble", std::string{})},
{"kernel", options.kernel_name}});
return compile_hip_code_object(src, options);
}
value create_settings(instruction_ref ins, const operation& op) const
{
auto v = op.to_value();
v["kernel"] = "ck_gemm_softmax_gemm_kernel";
if(not ins->module_inputs().empty())
{
auto* pm = ins->module_inputs().front();
v["preamble"] = generate_pointwise(*pm, "post_ck_gemm_softmax_gemm_function") +
"\nMIGRAPHX_LIFT_CLASS(post_ck_gemm_softmax_gemm, "
"post_ck_gemm_softmax_gemm_function);";
v["post"] = "ck_function_adaptor<post_ck_gemm_softmax_gemm>";
v["kernel"] = "ck_gemm_softmax_gemm_" + generate_name_from_ops(*pm) + "_kernel";
}
return v;
}
compiler_replace
compile(context& ctx, instruction_ref ins, const operation& op, const value& solution) const
{
auto shapes = to_shapes(ins->inputs());
auto v = create_settings(ins, op);
if(not solution.is_null())
v["tuning_value"] = solution;
return {compile_op(ctx, shapes, v),
[=](module& m, instruction_ref ins2, const operation& code_object) {
if(enabled(MIGRAPHX_LOG_CK_GEMM{}))
{
std::vector<shape> gemm_shapes{
shapes[0], shapes[1], shapes.back().with_type(shapes[0].type())};
std::cout << "gpu::ck_gemm_softmax_gemm: "
<< to_json_string(to_value(gemm_shapes)) << std::endl;
}
m.replace_instruction(ins2, code_object, ins2->inputs());
}};
}
optional<tuning_config>
get_tuning_config(context& ctx, instruction_ref ins, const operation& op, bool exhaustive) const
{
if(not exhaustive and not enabled(MIGRAPHX_TUNE_CK{}))
return nullopt;
tuning_config tc;
auto shapes = to_shapes(ins->inputs());
auto problem = create_problem(shapes, create_settings(ins, op));
auto solutions = problem.GetSolutions(ctx.get_current_device().get_gfx_name());
tc.solutions.resize(solutions.size());
std::iota(tc.solutions.begin(), tc.solutions.end(), 0);
std::vector<shape> gemm_shapes{shapes[0], shapes[1], shapes.back()};
tc.problem = to_value(gemm_shapes);
return tc;
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -154,6 +154,17 @@ struct ck_add ...@@ -154,6 +154,17 @@ struct ck_add
} }
}; };
// In CK, the B matrix is ordered as N,K instead of K,N
template <class Dims>
constexpr auto ck_transposeb_dims(Dims dims)
{
return unpack(dims, [](auto k, auto n) { return make_const_array(n, k); });
}
template <class Tensor>
using ck_transposeb = decltype(make_shape(ck_transposeb_dims(get_shape_c<Tensor>{}.lens),
ck_transposeb_dims(get_shape_c<Tensor>{}.strides)));
#ifdef MIGRAPHX_CK_CHECK #ifdef MIGRAPHX_CK_CHECK
#define MIGRAPHX_CK_STATIC_ASSERT static_assert #define MIGRAPHX_CK_STATIC_ASSERT static_assert
#else #else
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment