Commit a22ec139 authored by Manupa Karunaratne's avatar Manupa Karunaratne
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into mlir-attention

parents 9898823d 650ba45f
......@@ -632,6 +632,9 @@ struct find_transpose_contiguous_reshaper_unary
}
};
// simplifies broadcast->transpose to transpose->broadcast
// in the case of a scalar, simply rewrite to broadcast
// this can allow for further optimizations with find_inner_broadcast() in simplify_algebra.cpp
struct find_broadcast_transpose
{
auto matcher() const
......@@ -642,17 +645,30 @@ struct find_broadcast_transpose
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto ins_lens = ins->get_shape().lens();
auto transpose = r.result;
auto transpose_lens = transpose->get_shape().lens();
auto bcast_ins = r.instructions["bcast_ins"];
auto input = bcast_ins->inputs().front();
// for now, focusing on scalar transformation
// scalar transformation does not need extra transpose
if(not input->get_shape().scalar())
return;
{
// find common shape
auto in_lens = input->get_shape().lens();
int lens_diff = transpose_lens.size() - in_lens.size();
// insert unsqueeze if input lens < transpose lens
if(lens_diff > 0)
{
std::vector<size_t> unsqueeze_axes(lens_diff);
std::iota(unsqueeze_axes.begin(), unsqueeze_axes.end(), 0);
input = m.insert_instruction(
bcast_ins, make_op("unsqueeze", {{"axes", unsqueeze_axes}}), input);
}
// apply transpose before the multibroadcast
input = m.insert_instruction(bcast_ins, transpose->get_operator(), input);
}
auto new_mbcast = m.insert_instruction(
bcast_ins, make_op("multibroadcast", {{"out_lens", ins_lens}}), input);
m.replace_instruction(ins, new_mbcast);
bcast_ins, make_op("multibroadcast", {{"out_lens", transpose_lens}}), input);
m.replace_instruction(transpose, new_mbcast);
}
};
......
......@@ -91,6 +91,19 @@ struct post_op : reflect_equality<post_op>, reflect_stream<post_op>
}
};
template <class F>
struct execute_wrapper
{
F f;
argument operator()(context&, const std::vector<argument>& args) const { return f(args); }
};
template <class F>
execute_wrapper<F> make_execute_wrapper(F f)
{
return {std::move(f)};
}
template <class Derived, class Primitive>
struct dnnl_op : auto_register_op<Derived>
{
......@@ -308,7 +321,7 @@ struct dnnl_op : auto_register_op<Derived>
#ifndef NDEBUG
auto prim_attr = get_primitive_attr(md);
#endif
execute = [=](context&, const std::vector<argument>& args) {
execute = make_execute_wrapper([=](const std::vector<argument>& args) {
#ifndef NDEBUG
// Check that the memory descriptors have not changed
auto debug_args = args;
......@@ -379,7 +392,7 @@ struct dnnl_op : auto_register_op<Derived>
m[arg_lookup[i]] = to_dnnl_memory(md.at(arg_lookup[i]), args[i]);
prim.execute(get_dnnl_context().stream, m);
return args.back();
};
});
}
std::vector<shape> trim_post_op_inputs(const std::vector<shape>& inputs) const
{
......
......@@ -24,7 +24,7 @@
#ifndef MIGRAPHX_GUARD_CPU_FUSE_OPS_HPP
#define MIGRAPHX_GUARD_CPU_FUSE_OPS_HPP
#include <migraphx/config.hpp>
#include <migraphx/cpu/context.hpp>
#include <string>
namespace migraphx {
......@@ -34,9 +34,7 @@ struct module;
namespace cpu {
struct context;
struct fuse_ops
struct MIGRAPHX_CPU_EXPORT fuse_ops
{
context* ctx = nullptr;
std::string name() const { return "cpu::fuse_ops"; }
......
......@@ -24,6 +24,7 @@
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_CPU_POINTWISE_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_CPU_POINTWISE_HPP
#include <array>
#include <migraphx/config.hpp>
#include <migraphx/context.hpp>
#include <migraphx/check_shapes.hpp>
......
......@@ -185,8 +185,7 @@ struct compile_plan
results.begin(), results.end(), std::back_inserter(times), [&](const auto& cr) {
if(not cr.has_value())
return std::numeric_limits<double>::max();
return time_op(*ctx, cr->replace.code_object, to_shapes(cr->ins->inputs()), 20)
.first;
return time_op(*ctx, cr->replace.code_object, to_shapes(cr->ins->inputs()), 20);
});
auto i = std::distance(times.begin(), std::min_element(times.begin(), times.end()));
std::cout << "Fastest solution: " << config->solutions.at(i) << std::endl;
......
......@@ -38,10 +38,8 @@ struct compile_op : action<compile_op>
context ctx;
auto inputs = p.parse_shapes(v.at("inputs"));
auto op = gpu::compile_op(v.at("name").to<std::string>(), ctx, inputs, v);
auto [host_time, device_time] = time_op(ctx, op, inputs, p.get(v, "iterations", 100));
std::cout << op << ": " << host_time << "ms";
if(device_time > 0)
std::cout << ", " << device_time << "ms";
auto t = time_op(ctx, op, inputs, p.get(v, "iterations", 100));
std::cout << op << ": " << t << "ms";
std::cout << std::endl;
}
};
......
......@@ -43,8 +43,8 @@ struct run_op : action<run_op>
auto op = make_op(name);
if(v.contains("fields"))
op.from_value(v.at("fields"));
auto [host_time, device_time] = time_op(ctx, op, inputs, p.get(v, "iterations", 100));
std::cout << op << ": " << host_time << "ms" << std::endl;
auto t = time_op(ctx, op, inputs, p.get(v, "iterations", 100));
std::cout << op << ": " << t << "ms" << std::endl;
}
};
......
......@@ -22,9 +22,9 @@
* THE SOFTWARE.
*/
#include <migraphx/gpu/fuse_ck.hpp>
#include <migraphx/gpu/gemm_softmax_gemm.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
namespace migraphx {
......@@ -55,7 +55,7 @@ struct ck_gemm
{
check_shapes{inputs, *this}.same_ndims();
if(inputs.size() < 2)
MIGRAPHX_THROW("should have at least two inputs.");
MIGRAPHX_THROW(name() + ": should have at least two inputs.");
auto a = inputs[0];
auto b = inputs[1];
for(const auto& input : inputs)
......@@ -65,21 +65,27 @@ struct ck_gemm
return r;
return r.with_type(mods.front()->get_output_shapes().front().type());
}
static bool is_ck_supported_type(shape::type_t t)
{
return contains({shape::half_type, shape::int8_type, shape::int32_type}, t);
}
};
MIGRAPHX_REGISTER_OP(ck_gemm);
namespace {
bool is_ck_supported_type(shape::type_t t)
struct ck_gemm_softmax_gemm : gemm_softmax_gemm
{
return contains({shape::half_type, shape::int8_type, shape::int32_type}, t);
}
std::string name() const { return "gpu::ck_gemm_softmax_gemm"; }
};
MIGRAPHX_REGISTER_OP(ck_gemm_softmax_gemm);
namespace {
MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
{
if(ins->name() != "dot" and ins->name() != "quant_dot")
return false;
if(not is_ck_supported_type(ins->get_shape().type()))
if(not ck_gemm::is_ck_supported_type(ins->get_shape().type()))
return false;
auto a = ins->inputs().front()->get_shape();
auto b = ins->inputs().back()->get_shape();
......@@ -127,7 +133,11 @@ struct find_ck_gemm_pointwise
ins->get_shape().type() != gemm_ins->get_shape().type())
return;
if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [](auto input) {
return not is_ck_supported_type(input->get_shape().type());
return not ck_gemm::is_ck_supported_type(input->get_shape().type());
}))
return;
if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [](auto input) {
return not input->inputs().empty() and input->inputs().front()->name() == "capture";
}))
return;
assert(gemm_it != inputs.end());
......@@ -152,7 +162,7 @@ struct find_ck_gemm_pointwise
struct find_ck_gemm
{
auto matcher() const { return match::name("dot")(is_ck_gemm().bind("gemm")); }
auto matcher() const { return match::name("dot", "quant_dot")(is_ck_gemm().bind("gemm")); }
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
......@@ -161,11 +171,26 @@ struct find_ck_gemm
}
};
struct find_ck_gemm_softmax_gemm
{
auto matcher() const { return match::name("gpu::pre_gemm_softmax_gemm"); }
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto ins = r.result;
auto v = ins->get_operator().to_value();
assert(v.contains("scale"));
auto scale = v.at("scale").to<float>();
mpm.get_module().replace_instruction(
ins, ck_gemm_softmax_gemm{migraphx::make_op("dot"), scale}, ins->inputs());
}
};
} // namespace
void fuse_ck::apply(module_pass_manager& mpm) const
{
match::find_matches(mpm, find_ck_gemm_pointwise{});
match::find_matches(mpm, find_ck_gemm_softmax_gemm{}, find_ck_gemm_pointwise{});
match::find_matches(mpm, find_ck_gemm{});
}
......
......@@ -36,24 +36,14 @@ struct module;
namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_EXTRA_MLIR);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_MLIR);
bool mlir_enabled()
{
#ifdef MIGRAPHX_MLIR
const bool mlir_enabled = enabled(MIGRAPHX_ENABLE_MLIR{});
if(mlir_enabled)
{
return true;
}
else
{
std::cerr << "WARNING: MIGraphX built with MLIR but it is not enabled. Please set the env "
"var MIGRAPHX_ENABLE_MLIR to use MLIR kernel generator."
<< std::endl;
return false;
}
const bool mlir_disabled = enabled(MIGRAPHX_DISABLE_MLIR{});
return not mlir_disabled;
#else
return false;
#endif
......@@ -131,9 +121,16 @@ fuse_input_ops_and_gemm_based_op(module_ref mm, instruction_ref gemm_based_op)
for(instruction_ref input : gemm_based_op->inputs())
{
std::vector<operation> op_stream;
while(contains({"slice", "transpose", "contiguous", "reshape"}, input->name()))
while(contains(
{"slice", "transpose", "contiguous", "reshape", "squeeze", "flatten", "unsqueeze"},
input->name()))
{
op_stream.push_back(input->get_operator());
operation op = input->get_operator();
if(contains({"squeeze", "flatten", "unsqueeze"}, input->name()))
{
op = migraphx::make_op("reshape", {{"dims", input->get_shape().lens()}});
}
op_stream.push_back(op);
input = input->inputs().at(0);
}
top_inputs.push_back(input);
......@@ -150,18 +147,62 @@ fuse_input_ops_and_gemm_based_op(module_ref mm, instruction_ref gemm_based_op)
return {new_gemm_based_op, top_inputs};
}
MIGRAPHX_PRED_MATCHER(is_mlir_conv, instruction_ref ins)
enum class mlir_mode
{
if(ins->name() != "convolution" and ins->name() != "quant_convolution")
return false;
value v = ins->get_operator().to_value();
auto group = v.at("group").to<int>();
if(group != 1)
return false;
// Avoid MLIR assertion: Index < Length && "Invalid index!"
if(ins->get_shape().lens().size() != 4)
return false;
return true;
all,
fast,
int8,
none
};
auto is_mlir_dot(mlir_mode mode)
{
return match::make_basic_pred_matcher([=](instruction_ref ins) {
if(mode == mlir_mode::none)
return false;
if(ins->name() != "dot" and ins->name() != "quant_dot")
return false;
if(mode != mlir_mode::fast)
return true;
auto a = ins->inputs().front()->get_shape();
auto b = ins->inputs().back()->get_shape();
// auto m = a.lens()[a.lens().size() - 2];
// auto n = b.lens().back();
auto k = a.lens().back();
// Skipping GEMMs with a K dimension greater than 2048 is a course-grained strategy
// to avoid poor-performing GEMM kernels from MLIR
// To-do: Investigate a more precise strategy
return k <= 2048;
});
}
auto is_mlir_conv(mlir_mode mode)
{
return match::make_basic_pred_matcher([=](instruction_ref ins) {
if(mode == mlir_mode::none)
return false;
if(ins->name() != "convolution" and ins->name() != "quant_convolution")
return false;
value v = ins->get_operator().to_value();
auto group = v.at("group").to<int>();
if(group != 1)
return false;
// Avoid MLIR assertion: Index < Length && "Invalid index!"
if(ins->get_shape().lens().size() != 4)
return false;
if(ins->get_shape().type() == shape::int8_type)
return true;
if(mode == mlir_mode::int8)
return false;
if(mode == mlir_mode::all)
return true;
auto w = ins->inputs().at(1)->get_shape();
if(w.lens().size() != 4)
return true;
if(w.lens()[2] != w.lens()[3])
return true;
return (w.lens()[3] % 3) != 0;
});
}
std::unordered_map<instruction_ref, instruction_ref>
......@@ -273,11 +314,12 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i)
struct find_mlir_fused_ops
{
mlir_mode conv_mode = mlir_mode::none;
mlir_mode dot_mode = mlir_mode::none;
auto matcher() const
{
auto dot_or_conv = match::skip(match::name("contiguous"))(
match::any_of(match::name("dot"), match::name("quant_dot"), is_mlir_conv())
.bind("gemm_based_op"));
match::any_of(is_mlir_dot(dot_mode), is_mlir_conv(conv_mode)).bind("gemm_based_op"));
return match::name("pointwise")(match::any_of[match::inputs()](dot_or_conv.bind("x")));
}
......@@ -323,9 +365,12 @@ struct find_mlir_fused_ops
}
};
template <auto Matcher>
struct find_mlir_standalone_op
{
mlir_mode mode = mlir_mode::none;
auto matcher() const { return Matcher(mode); }
void rewrite(module_pass_manager& mpm, instruction_ref top_ins) const
{
static size_t counter = 0;
......@@ -340,6 +385,7 @@ struct find_mlir_standalone_op
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto conv_based_op = r.result;
//
// enable only for fp32/fp16/i8 types
if(std::any_of(conv_based_op->inputs().begin(), conv_based_op->inputs().end(), [&](auto i) {
return not contains(
......@@ -351,18 +397,12 @@ struct find_mlir_standalone_op
}
};
struct find_mlir_standalone_convolution_op : find_mlir_standalone_op
{
auto matcher() const { return is_mlir_conv; }
};
struct find_mlir_standalone_dot_op : find_mlir_standalone_op
{
auto matcher() const { return match::any_of(match::name("dot"), match::name("quant_dot")); }
};
using find_mlir_standalone_convolution_op = find_mlir_standalone_op<&is_mlir_conv>;
using find_mlir_standalone_dot_op = find_mlir_standalone_op<&is_mlir_dot>;
struct find_mlir_standalone_attention_op
{
mlir_mode mode = mlir_mode::none;
void rewrite(module_pass_manager& mpm, const match::matcher_result& r) const
{
static size_t counter = 0;
......@@ -373,6 +413,7 @@ struct find_mlir_standalone_attention_op
std::unordered_map<instruction_ref, instruction_ref> ins_map;
auto top_ins = r.instructions["top_dot"];
auto [new_top_ins, top_inputs] = fuse_input_ops_and_gemm_based_op(mm, top_ins);
inputs.insert(inputs.begin(), top_inputs.begin(), top_inputs.end());
ins_map[top_ins] = new_top_ins;
if(r.instructions.find("scale") != r.instructions.end()){
auto scale_ins = r.instructions["scale"];
......@@ -401,7 +442,6 @@ struct find_mlir_standalone_attention_op
new_bottom_dot = fold_pointwise_mod(r.instructions["trailing_pm"], mm, ins_map)[0];
}
mm->add_return({new_bottom_dot});
inputs.insert(inputs.end(), top_inputs.begin(), top_inputs.end());
mpm.get_module().replace_instruction(
r.instructions["bottom_dot"], mlir_op{new_bottom_dot->get_operator()}, inputs, {mm});
}
......@@ -413,6 +453,11 @@ struct find_mlir_standalone_attention_op
}
bool check(const match::matcher_result& r) const {
// We are only enabling attention
// in the highest enablement mode for now
if(mode != mlir_mode::all){
return false;
}
auto top_dot = r.instructions["top_dot"];
// Check the pointwise mod only contains a single mul
if(r.instructions.find("scale") != r.instructions.end()){
......@@ -491,44 +536,15 @@ struct find_mlir_attention_fused_ops : public find_mlir_standalone_attention_op
* intended to be primarily used by rocMLIR developers.
*/
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_USE_SPECIFIC_OPS);
bool is_self_decide() { return string_value_of(MIGRAPHX_MLIR_USE_SPECIFIC_OPS{}, "").empty(); }
bool is_requested(std::string_view option)
bool is_requested(std::string_view option, bool fallback = false)
{
assert(not is_self_decide());
auto string_value = string_value_of(MIGRAPHX_MLIR_USE_SPECIFIC_OPS{}, "");
if(string_value.empty())
return fallback;
const auto options = split_string(string_value, ',');
return contains(options, option);
}
bool is_enabled(std::string_view op_name, context* ctx)
{
if(is_self_decide())
{
if(op_name == "fused")
{
return true;
}
else if(op_name == "convolution" or op_name == "quant_convolution")
{
if(ctx == nullptr)
{
return false;
}
else
{
const auto& device = ctx->get_current_device();
const std::string navi_family{"gfx110"};
return starts_with(device.get_gfx_name(), navi_family);
}
}
else
{
return false;
}
}
return is_requested(op_name);
}
} // namespace
#endif // MIGRAPHX_MLIR
......@@ -536,22 +552,34 @@ bool is_enabled(std::string_view op_name, context* ctx)
void fuse_mlir::apply(module_pass_manager& mpm) const
{
#ifdef MIGRAPHX_MLIR
match::find_matches(mpm, find_mlir_standalone_attention_op{});
if(is_enabled("fused", this->ctx))
{
match::find_matches(mpm, find_mlir_attention_fused_ops{});
match::find_matches(mpm, find_mlir_fused_ops{});
}
const auto& device_name = ctx == nullptr ? "" : ctx->get_current_device().get_gfx_name();
const bool is_navi = starts_with(device_name, "gfx110");
auto get_mode = [&](std::string_view option, mlir_mode m1, mlir_mode m2 = mlir_mode::fast) {
if(is_requested(option))
return mlir_mode::all;
if(is_navi)
return mlir_mode::all;
return std::max(m1, m2);
};
if(is_enabled("convolution", this->ctx))
{
match::find_matches(mpm, find_mlir_standalone_convolution_op{});
}
mlir_mode mode =
(enabled(MIGRAPHX_ENABLE_EXTRA_MLIR{}) or enable_extra) ? mlir_mode::fast : mlir_mode::none;
if(is_enabled("dot", this->ctx))
{
match::find_matches(mpm, find_mlir_standalone_dot_op{});
}
// Attention offloads; default disabled
match::find_matches(mpm,
find_mlir_attention_fused_ops{get_mode("attention", mlir_mode::none)});
match::find_matches(mpm,
find_mlir_standalone_attention_op{get_mode("attention", mlir_mode::none)});
match::find_matches(mpm,
find_mlir_fused_ops{.conv_mode = get_mode("fused", mlir_mode::fast),
.dot_mode = get_mode("fused", mode)});
match::find_matches(
mpm,
find_mlir_standalone_convolution_op{get_mode("convolution", mlir_mode::int8)},
find_mlir_standalone_dot_op{get_mode("dot", mlir_mode::none)});
#else
(void)mpm;
#endif
......
......@@ -27,6 +27,7 @@
#include <migraphx/msgpack.hpp>
#include <migraphx/file_buffer.hpp>
#include <migraphx/ranges.hpp>
#include <array>
#include <iostream>
#include <cstring>
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_GPU_CK_HPP
#define MIGRAPHX_GUARD_GPU_CK_HPP
#include <migraphx/compile_src.hpp>
#include <migraphx/env.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/stringutils.hpp>
#include <string_view>
#include "ck/host/device_gemm_multiple_d.hpp"
#include "ck/host/device_batched_gemm_softmax_gemm.hpp"
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
#ifndef _WIN32
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_LOG_CK_GEMM);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_DEBUG);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TUNE_CK);
#endif
// NOLINTNEXTLINE
const char* const disable_warning_pragma = R"__migraphx__(
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
${content}
#pragma clang diagnostic pop
)__migraphx__";
template <class P>
std::string ck_disable_warnings(P p)
{
return interpolate_string(disable_warning_pragma,
{{"content", std::string{p.data(), p.size()}}});
}
static std::unordered_map<std::string, std::string> create_ck_header_strings()
{
std::unordered_map<std::string, std::string> result;
auto ck_headers = ck::host::GetHeaders();
std::transform(
ck_headers.begin(), ck_headers.end(), std::inserter(result, result.begin()), [&](auto& p) {
return std::pair<std::string, std::string>(p.first, ck_disable_warnings(p.second));
});
return result;
}
static std::vector<src_file> create_ck_headers()
{
static const auto& header_strings = create_ck_header_strings();
std::vector<src_file> srcs;
std::transform(header_strings.begin(),
header_strings.end(),
std::back_inserter(srcs),
[&](auto& p) { return src_file{p}; });
return srcs;
}
static inline const std::vector<src_file>& ck_headers()
{
static const auto& headers = create_ck_headers();
return headers;
}
inline bool transposed_matrix(const shape& s) { return s.strides().back() != 1; }
inline ck::host::DataType get_type(const shape& s)
{
if(s.type() == shape::half_type)
return ck::host::DataType::Half;
else if(s.type() == shape::float_type)
return ck::host::DataType::Float;
else if(s.type() == shape::int8_type)
return ck::host::DataType::Int8;
else if(s.type() == shape::int32_type)
return ck::host::DataType::Int32;
MIGRAPHX_THROW("Unsupported ck type");
}
inline std::size_t get_batch_count(const shape& s)
{
return std::accumulate(
s.lens().rbegin() + 2, s.lens().rend(), std::size_t{1}, std::multiplies<std::size_t>());
}
inline void fold_batch_dims(shape& s)
{
auto lens = s.lens();
if(lens.size() <= 2)
return;
auto batch_count = get_batch_count(s);
auto m1 = lens.at(lens.size() - 2);
auto m2 = lens.at(lens.size() - 1);
if(transposed_matrix(s))
s = shape{s.type(), {m1, m2 * batch_count}};
else
s = shape{s.type(), {m1 * batch_count, m2}};
}
inline void remove_batch_dims(shape& s)
{
auto lens = s.lens();
if(lens.size() <= 2)
return;
auto m1 = lens.at(lens.size() - 2);
auto m2 = lens.at(lens.size() - 1);
s = shape{s.type(), {m1, m2}};
}
inline bool standard_batch(const shape& s)
{
if(s.lens().size() < 3)
return true;
std::vector<std::size_t> lens(s.lens().begin(), s.lens().end() - 2);
std::vector<std::size_t> strides(s.strides().begin(), s.strides().end() - 2);
auto base = *(s.lens().end() - 2) * *(s.lens().end() - 1);
std::transform(strides.begin(), strides.end(), strides.begin(), [&](auto stride) {
return stride / base;
});
return shape{s.type(), lens, strides}.standard();
}
inline bool can_fold_batch(const std::vector<shape>& inputs)
{
const auto& b_shape = inputs[1];
if(std::any_of(inputs.begin() + 2, inputs.end() - 1, [](auto input) {
return not standard_batch(input);
}))
return false;
const auto& b_strides = b_shape.strides();
return std::all_of(
b_strides.begin(), b_strides.end() - 2, [](auto stride) { return stride == 0; });
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_GPU_CK_HPP
......@@ -299,23 +299,6 @@ struct context
any_ptr get_queue() { return get_stream().get(); }
void enable_perf_measurement(bool b = true)
{
if(b)
{
start_event = create_event_for_timing();
stop_event = create_event_for_timing();
get_stream().record(start_event.get());
get_stream().record(stop_event.get());
}
else
{
start_event = nullptr;
stop_event = nullptr;
}
measure_perf = b;
}
std::pair<hipEvent_t, hipEvent_t> get_perf_events() const
{
if(measure_perf)
......@@ -323,12 +306,12 @@ struct context
return std::make_pair(nullptr, nullptr);
}
float get_elapsed_ms() const
static float get_elapsed_ms(hipEvent_t start, hipEvent_t stop)
{
float result = 0;
if(start_event != nullptr and stop_event != nullptr)
if(start != nullptr and stop != nullptr)
{
auto status = hipEventElapsedTime(&result, start_event.get(), stop_event.get());
auto status = hipEventElapsedTime(&result, start, stop);
if(status != hipSuccess)
MIGRAPHX_THROW("Failed hipEventElapsedTime: " + hip_error(status));
}
......
......@@ -38,6 +38,7 @@ MIGRAPHX_GPU_EXPORT bool mlir_enabled();
struct MIGRAPHX_GPU_EXPORT fuse_mlir
{
context* ctx = nullptr;
bool enable_extra = false;
std::string name() const { return "gpu::fuse_mlir"; }
void apply(module_pass_manager& mpm) const;
};
......
......@@ -24,7 +24,6 @@
#ifndef MIGRAPHX_GUARD_RTGLIB_FUSE_OPS_HPP
#define MIGRAPHX_GUARD_RTGLIB_FUSE_OPS_HPP
#include <migraphx/config.hpp>
#include <migraphx/gpu/context.hpp>
namespace migraphx {
......@@ -34,7 +33,7 @@ struct module;
namespace gpu {
struct fuse_ops
struct MIGRAPHX_GPU_EXPORT fuse_ops
{
context* ctx = nullptr;
bool fast_math = true;
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_GPU_GEMM_SOFTMAX_GEMM_HPP
#define MIGRAPHX_GUARD_GPU_GEMM_SOFTMAX_GEMM_HPP
#include <migraphx/make_op.hpp>
#include <migraphx/check_shapes.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct gemm_softmax_gemm
{
operation op = make_op("dot");
float scale = 1.0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.op, "op"), f(self.scale, "scale"));
}
std::string name() const { return "gpu::gemm_softmax_gemm"; }
void check_gemm_shape(const shape& s) const
{
if(not contains(range(s.strides().rbegin(), s.strides().rbegin() + 3), 1))
MIGRAPHX_THROW("Invalid shape for " + name());
}
shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>&) const
{
check_shapes{inputs, *this}.same_ndims();
if(inputs.size() < 3)
MIGRAPHX_THROW(name() + ": Expected 3 inputs but got " + to_string(inputs.size()));
auto a = inputs[0];
auto b = inputs[1];
auto b1 = inputs[2];
for(const auto& input : inputs)
{
check_gemm_shape(input);
}
return op.compute_shape({op.compute_shape({a, b}), b1});
}
static bool is_ck_supported_type(shape::type_t t) { return contains({shape::half_type}, t); }
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_GPU_GEMM_SOFTMAX_GEMM_HPP
......@@ -24,7 +24,7 @@
#ifndef MIGRAPHX_GUARD_GPU_PREFUSE_OPS_HPP
#define MIGRAPHX_GUARD_GPU_PREFUSE_OPS_HPP
#include <migraphx/config.hpp>
#include <migraphx/gpu/config.hpp>
#include <string>
namespace migraphx {
......@@ -34,7 +34,7 @@ struct module_pass_manager;
namespace gpu {
struct prefuse_ops
struct MIGRAPHX_GPU_EXPORT prefuse_ops
{
std::string name() const { return "gpu::prefuse_ops"; }
void apply(module_pass_manager& mpm) const;
......
......@@ -32,7 +32,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
MIGRAPHX_GPU_EXPORT std::pair<double, double>
MIGRAPHX_GPU_EXPORT double
time_op(context& ictx, operation op, const std::vector<shape>& inputs, int n = 100);
} // namespace gpu
......
......@@ -27,6 +27,7 @@
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/ck.hpp>
#include <migraphx/env.hpp>
#include <migraphx/file_buffer.hpp>
#include <migraphx/gpu/compile_gen.hpp>
......@@ -37,8 +38,6 @@
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include "ck/host/device_gemm_multiple_d.hpp"
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -46,12 +45,6 @@ namespace gpu {
using namespace migraphx::gpu::gen; // NOLINT
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_LOG_CK_GEMM);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_TUNING);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_TUNING_VALUE);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_DEBUG);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TUNE_CK);
// NOLINTNEXTLINE
static const char* const ck_gemm_kernel = R"__migraphx__(
#include <args.hpp>
......@@ -79,219 +72,10 @@ MIGRAPHX_GLOBAL void ${kernel}(${params})
)__migraphx__";
// NOLINTNEXTLINE
static const char* const disable_warning_pragma = R"__migraphx__(
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
${content}
#pragma clang diagnostic pop
)__migraphx__";
template <class P>
static std::string ck_disable_warnings(P p)
{
return interpolate_string(disable_warning_pragma,
{{"content", std::string{p.first, p.second}}});
}
static std::unordered_map<std::string, std::string> create_ck_header_strings()
{
std::unordered_map<std::string, std::string> result;
auto ck_headers = ck::host::GetHeaders();
std::transform(
ck_headers.begin(), ck_headers.end(), std::inserter(result, result.begin()), [&](auto&& p) {
return std::make_pair(p.first, ck_disable_warnings(p.second));
});
return result;
}
static std::vector<src_file> create_ck_headers()
{
static const auto& header_strings = create_ck_header_strings();
std::vector<src_file> srcs;
std::transform(
header_strings.begin(), header_strings.end(), std::back_inserter(srcs), [&](auto&& p) {
return src_file{p.first, p.second};
});
return srcs;
}
static const std::vector<src_file>& ck_headers()
{
static const auto& headers = create_ck_headers();
return headers;
}
static bool transposed_matrix(const shape& s) { return s.strides().back() != 1; }
using tuning_entry = std::pair<std::vector<shape>, size_t>;
static std::vector<tuning_entry> read_tuning(const std::string& s)
{
if(not fs::exists(s))
return {};
return from_value<std::vector<tuning_entry>>(from_json_string(read_string(s)));
}
static float matrix_distance(const shape& x, const shape& y)
{
if(x.type() != y.type())
return std::numeric_limits<float>::max();
if(transposed_matrix(x) != transposed_matrix(y))
return std::numeric_limits<float>::max();
auto sum_squared = std::inner_product(x.lens().rbegin(),
x.lens().rbegin() + 2,
y.lens().rbegin(),
0,
std::plus<>{},
[](auto a, auto b) { return (a - b) * (a - b); });
return std::sqrt(sum_squared);
}
static std::size_t get_tuning_for(const std::vector<shape>& inputs)
{
static auto tuning = read_tuning(string_value_of(MIGRAPHX_CK_TUNING{}, ""));
if(tuning.empty())
{
std::cout << "*********** Warning: No CK tuning! for config:" << std::endl;
std::cout << " " << inputs[0] << std::endl;
std::cout << " " << inputs[1] << std::endl;
std::cout << " " << inputs[2] << std::endl;
}
auto it = std::find_if(
tuning.begin(), tuning.end(), [&](const auto& p) { return p.first == inputs; });
if(it == tuning.end())
{
std::cout << "*********** Warning: CK tuning missing for config!" << std::endl;
std::cout << " " << inputs[0] << std::endl;
std::cout << " " << inputs[1] << std::endl;
std::cout << " " << inputs[2] << std::endl;
std::vector<std::pair<float, std::size_t>> w;
std::transform(tuning.begin(), tuning.end(), std::back_inserter(w), [&](const auto& p) {
if(inputs.size() < 3 or p.first.size() < 3)
MIGRAPHX_THROW("Invalid CK config");
auto avg_distance = std::inner_product(
p.first.begin(),
p.first.begin() + 3,
inputs.begin(),
0.0f,
std::plus<>{},
[](const auto& x, const auto& y) { return matrix_distance(x, y) / 3.0f; });
return std::make_pair(avg_distance, p.second);
});
std::sort(w.begin(), w.end());
std::size_t default_value = 4;
if(not w.empty())
default_value = w.front().second;
auto tuning_val = value_of(MIGRAPHX_CK_TUNING_VALUE{}, default_value);
std::cout << "*********** Warning: CK try tuning: " << tuning_val << std::endl;
return tuning_val;
}
return it->second;
}
struct ck_gemm_compiler : compiler<ck_gemm_compiler>
{
static std::string get_layout(const shape& s)
{
return transposed_matrix(s) ? "ck::tensor_layout::gemm::ColumnMajor"
: "ck::tensor_layout::gemm::RowMajor";
}
static ck::host::DataType get_type(const shape& s)
{
if(s.type() == shape::half_type)
return ck::host::DataType::Half;
else if(s.type() == shape::float_type)
return ck::host::DataType::Float;
else if(s.type() == shape::int8_type)
return ck::host::DataType::Int8;
else if(s.type() == shape::int32_type)
return ck::host::DataType::Int32;
MIGRAPHX_THROW("Unsupported ck type");
}
template <class Iterator, class F>
static std::string ck_tuple(Iterator start, Iterator last, F f)
{
std::vector<std::string> s;
std::transform(start, last, std::back_inserter(s), f);
return "ck::Tuple<" + join_strings(s, ",") + ">";
}
static std::vector<shape> adjust_inputs(std::vector<shape> inputs, bool& swap_inputs)
{
swap_inputs = false;
auto c_shape = inputs.back();
if(not transposed_matrix(c_shape))
return inputs;
std::vector<int64_t> perm(c_shape.lens().size());
std::iota(perm.begin(), perm.end(), 0);
std::swap(perm[perm.size() - 1], perm[perm.size() - 2]);
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](shape s) {
return reorder_shape(s, perm);
});
swap_inputs = true;
return inputs;
}
static std::size_t get_batch_count(const shape& s)
{
return std::accumulate(
s.lens().rbegin() + 2, s.lens().rend(), std::size_t{1}, std::multiplies<std::size_t>());
}
static void fold_batch_dims(shape& s)
{
auto lens = s.lens();
if(lens.size() <= 2)
return;
auto batch_count = get_batch_count(s);
auto m1 = lens.at(lens.size() - 2);
auto m2 = lens.at(lens.size() - 1);
if(transposed_matrix(s))
s = shape{s.type(), {m1, m2 * batch_count}};
else
s = shape{s.type(), {m1 * batch_count, m2}};
}
static void remove_batch_dims(shape& s)
{
auto lens = s.lens();
if(lens.size() <= 2)
return;
auto m1 = lens.at(lens.size() - 2);
auto m2 = lens.at(lens.size() - 1);
s = shape{s.type(), {m1, m2}};
}
std::vector<std::string> names() const { return {"ck_gemm", "gpu::ck_gemm"}; }
static bool standard_batch(const shape& s)
{
if(s.lens().size() < 3)
return true;
std::vector<std::size_t> lens(s.lens().begin(), s.lens().end() - 2);
std::vector<std::size_t> strides(s.strides().begin(), s.strides().end() - 2);
auto base = *(s.lens().end() - 2) * *(s.lens().end() - 1);
std::transform(strides.begin(), strides.end(), strides.begin(), [&](auto stride) {
return stride / base;
});
return shape{s.type(), lens, strides}.standard();
}
bool can_fold_batch(const std::vector<shape>& inputs) const
{
const auto& b_shape = inputs[1];
if(std::any_of(inputs.begin() + 2, inputs.end() - 1, [](auto input) {
return not standard_batch(input);
}))
return false;
const auto& b_strides = b_shape.strides();
return std::all_of(
b_strides.begin(), b_strides.end() - 2, [](auto stride) { return stride == 0; });
}
ck::host::device_gemm_multiple_d::Problem create_problem(const std::vector<shape>& inputs,
const value& v) const
{
......@@ -300,8 +84,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
const auto& c_shape = inputs.back();
// cppcheck-suppress unreadVariable
auto rank = a_shape.ndim();
auto rank = a_shape.ndim();
auto batch_count = get_batch_count(c_shape);
auto m = c_shape.lens()[rank - 2];
m = can_fold_batch(inputs) ? m * batch_count : m;
......@@ -351,12 +134,8 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
const auto& a_shape = inputs[0];
const auto& b_shape = inputs[1];
const auto& c_shape = inputs.back();
auto tuning_value = v.get("tuning_value", 4);
if(not v.contains("tuning_value"))
tuning_value = get_tuning_for({a_shape, b_shape, c_shape});
auto tuning_value = v.get("tuning_value", 34);
auto batch_count = get_batch_count(c_shape);
auto problem = create_problem(inputs, v);
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <fstream>
#include <migraphx/filesystem.hpp>
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/env.hpp>
#include <migraphx/file_buffer.hpp>
#include <migraphx/gpu/ck.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/module.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
using namespace migraphx::gpu::gen; // NOLINT
// NOLINTNEXTLINE
static const char* const ck_gemm_softmax_gemm_kernel = R"__migraphx__(
#include <args.hpp>
#include <migraphx/kernels/ck_gemm_softmax_gemm.hpp>
#include <migraphx/kernels/pointwise.hpp>
#include <migraphx/kernels/ops.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
#include <${include}>
namespace migraphx {
${preamble}
extern "C" {
MIGRAPHX_GLOBAL void ${kernel}(${params})
{
transform_args(make_tensors(), rotate_last())(${args})([](auto... xs) {
auto settings = make_ck_gemm_softmax_gemm_settings(MIGRAPHX_MAKE_CONSTANT(float{SCALE}));
ck_gemm_softmax_gemm<${solution}, ${blocks_per_batch}>(settings, xs...);
});
}
}
} // namespace migraphx
)__migraphx__";
struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
{
std::vector<std::string> names() const
{
return {"ck_gemm_softmax_gemm", "gpu::ck_gemm_softmax_gemm"};
}
ck::host::device_batched_gemm_softmax_gemm::Problem
create_problem(const std::vector<shape>& inputs, const value&) const
{
const auto& a_shape = inputs[0];
const auto& b_shape = inputs[1];
const auto& b1_shape = inputs[2];
const auto& c_shape = inputs.back();
// cppcheck-suppress unreadVariable
auto rank = a_shape.ndim();
auto batch_count = get_batch_count(c_shape);
auto m = c_shape.lens()[rank - 2];
m = can_fold_batch(inputs) ? m * batch_count : m;
auto n = c_shape.lens().back();
auto k = a_shape.lens().back();
auto o = c_shape.lens().back();
const bool trans_a = transposed_matrix(a_shape);
const bool trans_b = transposed_matrix(b_shape);
const bool trans_b1 = transposed_matrix(b1_shape);
const bool trans_c = transposed_matrix(c_shape);
const auto a_type = get_type(a_shape);
const auto b_type = get_type(b_shape);
const auto b1_type = get_type(b1_shape);
const auto c_type = get_type(c_shape);
std::string ck_passthrough = "ck_passthrough";
return ck::host::device_batched_gemm_softmax_gemm::Problem{m,
n,
k,
o,
trans_a,
trans_b,
trans_b1,
trans_c,
a_type,
b_type,
b1_type,
c_type,
ck_passthrough,
ck_passthrough,
ck_passthrough,
ck_passthrough};
}
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
const auto& c_shape = inputs.back();
auto tuning_value = v.get("tuning_value", 5);
auto batch_count = get_batch_count(c_shape);
auto problem = create_problem(inputs, v);
const auto include_header = problem.GetIncludeHeader();
const auto solutions = problem.GetSolutions(ctx.get_current_device().get_gfx_name());
const auto& solution = solutions.at(tuning_value);
const auto template_str = solution.template_str;
const auto blocks_per_batch = solution.grid_size;
const auto block_size = solution.block_size;
hip_compile_options options;
options.additional_src_files = ck_headers();
auto grid_size = can_fold_batch(inputs) ? blocks_per_batch : batch_count * blocks_per_batch;
options.set_launch_params(v, grid_size * block_size, block_size);
options.inputs = inputs;
options.output = c_shape;
options.kernel_name = v.get("kernel", "ck_gemm_softmax_gemm_kernel");
options.virtual_inputs = inputs;
if(can_fold_batch(inputs))
{
auto vinputs = inputs;
fold_batch_dims(vinputs[0]);
remove_batch_dims(vinputs[1]);
std::for_each(vinputs.begin() + 2, vinputs.end(), fold_batch_dims);
options.virtual_inputs = vinputs;
}
if(v.get("check", false) or enabled(MIGRAPHX_CK_DEBUG{}))
options.params += " -DMIGRAPHX_CK_CHECK=1";
// scale
assert(v.contains("scale"));
auto scale = v.at("scale").to<float>();
options.params += " -DSCALE=" + std::to_string(scale);
auto src = interpolate_string(ck_gemm_softmax_gemm_kernel,
{{"solution", template_str},
{"include", include_header},
{"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")},
{"blocks_per_batch", to_string(blocks_per_batch)},
{"preamble", v.get("preamble", std::string{})},
{"kernel", options.kernel_name}});
return compile_hip_code_object(src, options);
}
value create_settings(instruction_ref ins, const operation& op) const
{
auto v = op.to_value();
v["kernel"] = "ck_gemm_softmax_gemm_kernel";
if(not ins->module_inputs().empty())
{
auto* pm = ins->module_inputs().front();
v["preamble"] = generate_pointwise(*pm, "post_ck_gemm_softmax_gemm_function") +
"\nMIGRAPHX_LIFT_CLASS(post_ck_gemm_softmax_gemm, "
"post_ck_gemm_softmax_gemm_function);";
v["post"] = "ck_function_adaptor<post_ck_gemm_softmax_gemm>";
v["kernel"] = "ck_gemm_softmax_gemm_" + generate_name_from_ops(*pm) + "_kernel";
}
return v;
}
compiler_replace
compile(context& ctx, instruction_ref ins, const operation& op, const value& solution) const
{
auto shapes = to_shapes(ins->inputs());
auto v = create_settings(ins, op);
if(not solution.is_null())
v["tuning_value"] = solution;
return {compile_op(ctx, shapes, v),
[=](module& m, instruction_ref ins2, const operation& code_object) {
if(enabled(MIGRAPHX_LOG_CK_GEMM{}))
{
std::vector<shape> gemm_shapes{
shapes[0], shapes[1], shapes.back().with_type(shapes[0].type())};
std::cout << "gpu::ck_gemm_softmax_gemm: "
<< to_json_string(to_value(gemm_shapes)) << std::endl;
}
m.replace_instruction(ins2, code_object, ins2->inputs());
}};
}
optional<tuning_config>
get_tuning_config(context& ctx, instruction_ref ins, const operation& op, bool exhaustive) const
{
if(not exhaustive and not enabled(MIGRAPHX_TUNE_CK{}))
return nullopt;
tuning_config tc;
auto shapes = to_shapes(ins->inputs());
auto problem = create_problem(shapes, create_settings(ins, op));
auto solutions = problem.GetSolutions(ctx.get_current_device().get_gfx_name());
tc.solutions.resize(solutions.size());
std::iota(tc.solutions.begin(), tc.solutions.end(), 0);
std::vector<shape> gemm_shapes{shapes[0], shapes[1], shapes.back()};
tc.problem = to_value(gemm_shapes);
return tc;
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -154,6 +154,17 @@ struct ck_add
}
};
// In CK, the B matrix is ordered as N,K instead of K,N
template <class Dims>
constexpr auto ck_transposeb_dims(Dims dims)
{
return unpack(dims, [](auto k, auto n) { return make_const_array(n, k); });
}
template <class Tensor>
using ck_transposeb = decltype(make_shape(ck_transposeb_dims(get_shape_c<Tensor>{}.lens),
ck_transposeb_dims(get_shape_c<Tensor>{}.strides)));
#ifdef MIGRAPHX_CK_CHECK
#define MIGRAPHX_CK_STATIC_ASSERT static_assert
#else
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment