Commit f3fcfcc7 authored by Alan Turner

Fix fusion pass and add tuning

parent 3133fd79
@@ -90,6 +90,7 @@ add_library(migraphx_gpu
     device_name.cpp
     elu.cpp
     fuse_ck.cpp
+    fuse_ck_gemm_softmax_gemm.cpp
     fuse_mlir.cpp
     fuse_ops.cpp
     gather.cpp
...
@@ -49,43 +49,43 @@ struct ck_gemm
 };
 MIGRAPHX_REGISTER_OP(ck_gemm);

-struct ck_gemm_scale_bias_softmax_gemm
-{
-    operation op = make_op("dot");
-    template <class Self, class F>
-    static auto reflect(Self& self, F f)
-    {
-        return pack(f(self.op, "op"));
-    }
-    std::string name() const { return "gpu::ck_gemm_softmax_gemm"; }
-    void check_gemm_shape(const shape& s) const
-    {
-        if(not contains(range(s.strides().rbegin(), s.strides().rbegin() + 3), 1))
-            MIGRAPHX_THROW("Invalid shape for ck_gemm_scale_bias_softmax_gemm");
-    }
-    shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
-    {
-        check_shapes{inputs, *this}.same_ndims();
-        // if(mods.size() != 1)
-        //     MIGRAPHX_THROW("should have one submodule.");
-        if(inputs.size() < 2)
-            MIGRAPHX_THROW("should have at least two inputs.");
-        auto a  = inputs[0];
-        auto b  = inputs[1];
-        auto b1 = inputs[2];
-        for(const auto& input : inputs)
-        {
-            // std::cout << input << std::endl;
-            check_gemm_shape(input);
-        }
-        return op.compute_shape({op.compute_shape({a, b}), b1});
-    }
-};
-MIGRAPHX_REGISTER_OP(ck_gemm_scale_bias_softmax_gemm);
+// struct ck_gemm_scale_bias_softmax_gemm
+// {
+//     operation op = make_op("dot");
+//     template <class Self, class F>
+//     static auto reflect(Self& self, F f)
+//     {
+//         return pack(f(self.op, "op"));
+//     }
+//     std::string name() const { return "gpu::ck_gemm_softmax_gemm"; }
+//     void check_gemm_shape(const shape& s) const
+//     {
+//         if(not contains(range(s.strides().rbegin(), s.strides().rbegin() + 3), 1))
+//             MIGRAPHX_THROW("Invalid shape for ck_gemm_scale_bias_softmax_gemm");
+//     }
+//     shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
+//     {
+//         check_shapes{inputs, *this}.same_ndims();
+//         // if(mods.size() != 1)
+//         //     MIGRAPHX_THROW("should have one submodule.");
+//         if(inputs.size() < 2)
+//             MIGRAPHX_THROW("should have at least two inputs.");
+//         auto a  = inputs[0];
+//         auto b  = inputs[1];
+//         auto b1 = inputs[2];
+//         for(const auto& input : inputs)
+//         {
+//             // std::cout << input << std::endl;
+//             check_gemm_shape(input);
+//         }
+//         return op.compute_shape({op.compute_shape({a, b}), b1});
+//     }
+// };
+// MIGRAPHX_REGISTER_OP(ck_gemm_scale_bias_softmax_gemm);

 namespace {
@@ -156,38 +156,38 @@ struct find_ck_gemm
 struct find_ck_gemm_scale_bias_softmax_gemm
 {
-    auto matcher() const
-    {
-        auto gemm1 =
-            match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1")));
-        auto pw =
-            match::name("pointwise")(match::any_of[match::inputs()](gemm1)).bind("scale_bias");
-        auto softmax = match::name("softmax")(match::any_of[match::inputs()](pw)).bind("softmax");
-        return match::name("dot")(is_ck_gemm().bind("gemm2"))(
-            match::any_of[match::inputs()](softmax));
-    }
-    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
-    {
-        std::cout << "Matched" << std::endl;
-        auto ins       = r.result;
-        auto gemm2_ins = r.instructions["gemm2"];
-        auto sm_ins    = r.instructions["softmax"];
-        auto pw_ins    = r.instructions["scale_bias"];
-        auto gemm1_ins = r.instructions["gemm1"];
-        gemm2_ins->debug_print();
-        sm_ins->debug_print();
-        pw_ins->debug_print();
-        gemm1_ins->debug_print();
-        auto inputs = gemm1_ins->inputs();            // A, B
-        inputs.push_back(gemm2_ins->inputs().back()); // B1
-        // inputs.push_back(pw_ins->inputs().back()); // C
-        mpm.get_module().replace_instruction(
-            ins, ck_gemm_scale_bias_softmax_gemm{gemm2_ins->get_operator()}, inputs);
-    }
+    // auto matcher() const
+    // {
+    //     auto gemm1 =
+    //         match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1")));
+    //     auto pw =
+    //         match::name("pointwise")(match::any_of[match::inputs()](gemm1)).bind("scale_bias");
+    //     auto softmax = match::name("softmax")(match::any_of[match::inputs()](pw)).bind("softmax");
+    //     return match::name("dot")(is_ck_gemm().bind("gemm2"))(
+    //         match::any_of[match::inputs()](softmax));
+    // }
+    // void apply(module_pass_manager& mpm, const match::matcher_result& r) const
+    // {
+    //     std::cout << "Matched" << std::endl;
+    //     auto ins       = r.result;
+    //     auto gemm2_ins = r.instructions["gemm2"];
+    //     auto sm_ins    = r.instructions["softmax"];
+    //     auto pw_ins    = r.instructions["scale_bias"];
+    //     auto gemm1_ins = r.instructions["gemm1"];
+    //     gemm2_ins->debug_print();
+    //     sm_ins->debug_print();
+    //     pw_ins->debug_print();
+    //     gemm1_ins->debug_print();
+    //     auto inputs = gemm1_ins->inputs();            // A, B
+    //     inputs.push_back(gemm2_ins->inputs().back()); // B1
+    //     // inputs.push_back(pw_ins->inputs().back()); // C
+    //     mpm.get_module().replace_instruction(
+    //         ins, ck_gemm_scale_bias_softmax_gemm{gemm2_ins->get_operator()}, inputs);
+    // }
     // auto matcher() const
     // {
@@ -223,11 +223,11 @@ struct find_ck_gemm_scale_bias_softmax_gemm
 void fuse_ck::apply(module_pass_manager& mpm) const
 {
     // mpm.get_module().debug_print();
-    match::find_matches(mpm, find_ck_gemm_scale_bias_softmax_gemm{});
-    // if(not enabled(MIGRAPHX_DISABLE_CK_GEMM_FUSION{}))
-    //     match::find_matches(mpm, find_ck_gemm_pointwise{});
-    // if(not enabled(MIGRAPHX_DISABLE_CK_GEMM{}))
-    //     match::find_matches(mpm, find_ck_gemm{});
+    // match::find_matches(mpm, find_ck_gemm_scale_bias_softmax_gemm{});
+    if(not enabled(MIGRAPHX_DISABLE_CK_GEMM_FUSION{}))
+        match::find_matches(mpm, find_ck_gemm_pointwise{});
+    if(not enabled(MIGRAPHX_DISABLE_CK_GEMM{}))
+        match::find_matches(mpm, find_ck_gemm{});
 }
 } // namespace gpu
...
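Note: the op this commit moves into its own pass evaluates dot(softmax(dot(A, B) * scale + bias), B1), i.e. fused attention. A minimal sketch of how compute_shape composes the two dots, in plain C++ with illustrative BERT-style dimensions (names and numbers are ours, not the commit's):

// Shape composition mirroring op.compute_shape({op.compute_shape({a, b}), b1}).
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdio>

int main()
{
    std::array<std::size_t, 4> a{2, 12, 384, 64};  // queries (A)
    std::array<std::size_t, 4> b{2, 12, 64, 384};  // keys, already transposed (B)
    std::array<std::size_t, 4> b1{2, 12, 384, 64}; // values (B1)

    assert(a[3] == b[2]); // inner dims of the first gemm must agree
    std::array<std::size_t, 4> s{a[0], a[1], a[2], b[3]}; // dot(A, B) -> {2, 12, 384, 384}
    // scale, bias and softmax are elementwise, so the shape passes through
    assert(s[3] == b1[2]); // inner dims of the second gemm must agree
    std::array<std::size_t, 4> c{s[0], s[1], s[2], b1[3]}; // dot(., B1) -> {2, 12, 384, 64}
    std::printf("%zu x %zu x %zu x %zu\n", c[0], c[1], c[2], c[3]);
}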
#include <migraphx/gpu/fuse_ck_gemm_softmax_gemm.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/env.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct module;

namespace gpu {

struct gemm_softmax_gemm_gemm
{
    operation op = make_op("dot");

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.op, "op"));
    }

    std::string name() const { return "gpu::ck_gemm_softmax_gemm"; }

    void check_gemm_shape(const shape& s) const
    {
        if(not contains(range(s.strides().rbegin(), s.strides().rbegin() + 3), 1))
            MIGRAPHX_THROW("Invalid shape for gemm_softmax_gemm_gemm");
    }

    shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
    {
        check_shapes{inputs, *this}.same_ndims();
        if(inputs.size() < 2)
            MIGRAPHX_THROW("should have at least two inputs.");
        auto a  = inputs[0];
        auto b  = inputs[1];
        auto b1 = inputs[2];
        for(const auto& input : inputs)
        {
            check_gemm_shape(input);
        }
        return op.compute_shape({op.compute_shape({a, b}), b1});
    }
};
MIGRAPHX_REGISTER_OP(gemm_softmax_gemm_gemm);

namespace {

MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
{
    if(ins->name() != "dot")
        return false;
    auto a = ins->inputs().front()->get_shape();
    auto b = ins->inputs().back()->get_shape();
    if(a.lens().back() > 2048)
        return false;
    return true;
}

struct find_gemm_softmax_gemm_gemm
{
    auto matcher() const
    {
        auto gemm1 =
            match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1")));
        auto mul     = match::name("mul")(match::any_of[match::inputs()](gemm1)).bind("scale");
        auto add     = match::name("add")(match::any_of[match::inputs()](mul));
        auto softmax = match::name("softmax")(match::any_of[match::inputs()](add)).bind("softmax");
        return match::name("dot")(is_ck_gemm().bind("gemm2"))(
            match::any_of[match::inputs()](softmax));
    }

    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
        auto ins       = r.result;
        auto gemm2_ins = r.instructions["gemm2"];
        auto gemm1_ins = r.instructions["gemm1"];
        auto inputs    = gemm1_ins->inputs();         // A, B
        inputs.push_back(gemm2_ins->inputs().back()); // B1
        mpm.get_module().replace_instruction(
            ins, gemm_softmax_gemm_gemm{gemm2_ins->get_operator()}, inputs);
    }
};

} // namespace

void fuse_ck_gemm_softmax_gemm::apply(module_pass_manager& mpm) const
{
    match::find_matches(mpm, find_gemm_softmax_gemm_gemm{});
}

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
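Note: find_gemm_softmax_gemm_gemm fires on a dot -> mul -> add -> softmax -> dot chain. A minimal program that should trigger it, mirroring the verify test later in this commit (shapes illustrative; assumes a MIGraphX build with this commit applied):

#include <migraphx/program.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <vector>

// Builds dot(softmax(dot(a, b) * 0.125 + 0), b1) so the matcher can fire.
migraphx::program make_attention_like()
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    migraphx::shape io{migraphx::shape::half_type, {2, 12, 384, 64}};
    migraphx::shape sm{migraphx::shape::half_type, {2, 12, 384, 384}};
    auto a     = mm->add_parameter("a", io);
    auto b     = mm->add_parameter("b", io);
    auto b1    = mm->add_parameter("b1", io);
    auto scale = mm->add_literal(migraphx::literal{sm, std::vector<float>(sm.elements(), 0.125f)});
    auto bias  = mm->add_literal(migraphx::literal{sm, std::vector<float>(sm.elements(), 0.0f)});
    b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b);
    auto gemm1   = mm->add_instruction(migraphx::make_op("dot"), a, b);
    auto scaled  = mm->add_instruction(migraphx::make_op("mul"), gemm1, scale);
    auto biased  = mm->add_instruction(migraphx::make_op("add"), scaled, bias);
    auto softmax = mm->add_instruction(migraphx::make_op("softmax", {{"axis", -1}}), biased);
    mm->add_instruction(migraphx::make_op("dot"), softmax, b1);
    return p;
}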
#ifndef MIGRAPHX_GUARD_GPU_FUSE_CK_GEMM_SOFTMAX_GEMM_HPP
#define MIGRAPHX_GUARD_GPU_FUSE_CK_GEMM_SOFTMAX_GEMM_HPP

#include <migraphx/config.hpp>
#include <migraphx/gpu/context.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct module_pass_manager;

namespace gpu {

struct fuse_ck_gemm_softmax_gemm
{
    context* ctx = nullptr;
    std::string name() const { return "gpu::fuse_ck_gemm_softmax_gemm"; }
    void apply(module_pass_manager& mpm) const;
};

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_GPU_FUSE_CK_GEMM_SOFTMAX_GEMM_HPP
@@ -39,7 +39,7 @@
 #include <migraphx/file_buffer.hpp>

 const std::vector<std::string>&
-get_instance(std::size_t i, const std::function<bool(const std::vector<std::string>&)>& pred);
+get_gsg_instance(std::size_t i, const std::function<bool(const std::vector<std::string>&)>& pred);

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -62,34 +62,12 @@ namespace migraphx {
 ${preamble}

-template <ck::index_t... Is>
-using S = ck::Sequence<Is...>;
-using Row = ck::tensor_layout::gemm::RowMajor;
-using Col = ck::tensor_layout::gemm::ColumnMajor;
-using F16 = ck::half_t;
-using F32 = float;
-using PassThrough = ck_passthrough;
-static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
-using AElementOp = PassThrough;
-using B0ElementOp = PassThrough;
-using Acc0ElementOp = ck_scale;//ck::tensor_operation::element_wise::Scale;
-using B1ElementOp = PassThrough;
-using CElementOp = PassThrough;//ck_add;//ck::tensor_operation::element_wise::Add;
-using gemm = CK_DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, false, std::ratio<1, 8>>;
-
 extern "C" {

 __global__ void ${kernel}(${params})
 {
-    // transform_args(make_tensors(), rotate_last())(${args})([](auto... xs) {
-    //     ck_gemm_softmax_gemm<CK_DeviceGemmMultipleD<${instance}>, ${blocks_per_batch}>(xs...);
-    // });
     transform_args(make_tensors(), rotate_last())(${args})([](auto... xs) {
-        ck_gemm_softmax_gemm<gemm, ${blocks_per_batch}>(xs...);
+        ck_gemm_softmax_gemm<CK_DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<${instance}>, ${blocks_per_batch}>(xs...);
     });
 }
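Note: ${instance} above is later filled with instance::str(), the comma-joined CK template arguments, so the hardcoded `using gemm = ...` alias is no longer needed. migraphx::interpolate_string does the substitution; a toy stand-in to show the mechanics (our sketch, not the library's implementation):

#include <iostream>
#include <map>
#include <string>

// Replace each ${key} in s with its value from vars.
std::string interpolate(std::string s, const std::map<std::string, std::string>& vars)
{
    for(const auto& [k, v] : vars)
    {
        const std::string key = "${" + k + "}";
        for(auto pos = s.find(key); pos != std::string::npos; pos = s.find(key, pos + v.size()))
            s.replace(pos, key.size(), v);
    }
    return s;
}

int main()
{
    std::string line = "ck_gemm_softmax_gemm<CK_DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<${instance}>, ${blocks_per_batch}>(xs...);";
    std::cout << interpolate(line, {{"instance", "Row, Col, Row, Row, F16, ..."},
                                    {"blocks_per_batch", "2"}})
              << "\n";
}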
@@ -112,13 +90,13 @@ struct instance
     std::size_t get_pb(std::size_t i) const
     {
-        assert(i < 4);
+        assert(i <= 4);
         return int_at(block_size_index + 1 + i);
     }

-    std::array<std::size_t, 3> get_pad(const std::array<std::size_t, 3>& config) const
+    std::array<std::size_t, 4> get_pad(const std::array<std::size_t, 4>& config) const
     {
-        std::array<std::size_t, 3> result{};
+        std::array<std::size_t, 4> result{};
         for(auto i : range(config.size()))
         {
             result[i] = int_div_ceil(config[i], get_pb(i)) * get_pb(i) - config[i];
@@ -126,33 +104,16 @@ struct instance
         return result;
     }

-    std::size_t get_grid_size(const std::array<std::size_t, 3>& config) const
+    std::size_t get_grid_size(const std::array<std::size_t, 4>& config) const
     {
-        return int_div_ceil(config[0], get_pb(0)) * int_div_ceil(config[1], get_pb(1));
+        return int_div_ceil(config[0], get_pb(0)) * int_div_ceil(config[3], get_pb(3));
     }

-    void set_ds_layout(const std::string& s)
-    {
-        assert(params[2] == "ck::Tuple<>");
-        params[2] = s;
-    }
-
-    void set_ds_type(const std::string& s)
-    {
-        assert(params[8] == "ck::Tuple<>");
-        params[8] = s;
-    }
-
-    void set_ds_op(const std::string& s)
-    {
-        assert(params[12] == "ck_passthrough");
-        params[12] = s;
-    }
-
     void set_gemm(const std::string& s)
     {
-        assert(params[13] == "ck::tensor_operation::device::GemmSpecialization::Default");
-        params[13] = s;
+        assert(params[15] == "ck::tensor_operation::device::GemmSpecialization::Default" or
+               params[15] == "ck::tensor_operation::device::GemmSpecialization::MNKOPadding");
+        params[15] = s;
     }

     std::string str() const { return join_strings(params, ","); }
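Note: get_pad rounds each problem dim up to the instance's per-block tile and returns the slack; any dim with non-zero slack contributes its letter to the GemmSpecialization name, and get_grid_size tiles only the M and O (output) dims. Re-running that arithmetic standalone, with made-up tile sizes:

#include <array>
#include <cstddef>
#include <cstdio>
#include <string>

std::size_t int_div_ceil(std::size_t a, std::size_t b) { return (a + b - 1) / b; }

int main()
{
    // Illustrative per-block tiles {MPerBlock, NPerBlock, KPerBlock, OPerBlock}
    std::array<std::size_t, 4> pb{256, 128, 32, 64};
    // Problem config {m0, n0, k0, n1}, as built in the compiler below
    std::array<std::size_t, 4> config{384, 384, 64, 64};
    const char keys[4] = {'M', 'N', 'K', 'O'};

    std::string gemm_type;
    for(std::size_t i = 0; i < pb.size(); i++)
    {
        auto pad = int_div_ceil(config[i], pb[i]) * pb[i] - config[i];
        if(pad != 0)
            gemm_type += keys[i];
    }
    // M: 384 rounds up to 512, leaving 128 of padding; N, K, O divide evenly,
    // so this config selects GemmSpecialization::MPadding.
    std::printf("GemmSpecialization::%s\n",
                gemm_type.empty() ? "Default" : (gemm_type + "Padding").c_str());
}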
@@ -179,12 +140,12 @@ static std::size_t get_tuning_for(const std::vector<shape>& inputs)
 {
     static auto tuning = read_tuning(string_value_of(MIGRAPHX_CK_TUNING{}, ""));
     if(tuning.empty())
-        std::cout << "*********** Warning: No CK tuning!" << std::endl;
+        std::cout << "*********** Warning: No CK GSG tuning!" << std::endl;
     auto it = std::find_if(
         tuning.begin(), tuning.end(), [&](const auto& p) { return p.first == inputs; });
     if(it == tuning.end())
     {
-        std::cout << "*********** Warning: CK tuning missing for config!" << std::endl;
+        std::cout << "*********** Warning: CK GSG tuning missing for config!" << std::endl;
         return 4;
     }
     return it->second;
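Note: the tuning DB is keyed by the exact input shapes and falls back to a default instance index when a config is missing. The lookup pattern in isolation (a generic sketch of ours, not the commit's code):

#include <algorithm>
#include <cstddef>
#include <utility>
#include <vector>

// Return the tuned value for `inputs`, or `fallback` when the key is absent.
template <class Key>
std::size_t lookup_tuning(const std::vector<std::pair<Key, std::size_t>>& tuning,
                          const Key& inputs,
                          std::size_t fallback = 4)
{
    auto it = std::find_if(tuning.begin(), tuning.end(), [&](const auto& p) {
        return p.first == inputs;
    });
    return it == tuning.end() ? fallback : it->second;
}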
@@ -194,8 +155,12 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
 {
     static std::string get_layout(const shape& s)
     {
-        return s.transposed() ? "ck::tensor_layout::gemm::ColumnMajor"
-                              : "ck::tensor_layout::gemm::RowMajor";
+        if(not s.transposed())
+            return "ck::tensor_layout::gemm::RowMajor";
+        auto lens = s.lens();
+        return lens[lens.size() - 1] > lens[lens.size() - 2]
+                   ? "ck::tensor_layout::gemm::ColumnMajor"
+                   : "ck::tensor_layout::gemm::RowMajor";
     }

     static std::string get_type(const shape& s)
@@ -222,6 +187,7 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
     {
         auto a_shape  = inputs[0];
         auto b_shape  = inputs[1];
+        auto b1_shape = inputs[2];
         auto c_shape  = inputs.back();
         auto m = a_shape.lens()[0];
         auto k = a_shape.lens()[1];
@@ -229,48 +195,39 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
         auto rank = a_shape.lens().size();

-        // std::array<char, 3> keys{'M', 'N', 'K'};
-        // std::array<std::size_t, 3> config{
-        //     c_shape.lens()[rank - 2], c_shape.lens().back(), a_shape.lens().back()};
-        // auto tuning_val = v.get("tuning_val", get_tuning_for({a_shape, b_shape, c_shape}));
-        // auto ip = instance{get_instance(tuning_val, [&](const auto& x) -> bool {
-        //     return get_layout(a_shape) == x[0] and get_layout(b_shape) == x[1] and
-        //            get_layout(c_shape) == x[3] and get_type(a_shape) == x[4] and
-        //            get_type(b_shape) == x[5] and get_type(c_shape) == x[9];
-        // })};
-        // assert(inputs.size() < 4 or v.contains("post"));
-        // if(v.contains("post"))
-        // {
-        //     ip.set_ds_layout(ck_tuple(inputs.begin() + 2, inputs.end() - 1, &get_layout));
-        //     ip.set_ds_type(ck_tuple(inputs.begin() + 2, inputs.end() - 1, &get_type));
-        //     ip.set_ds_op(v.at("post").to<std::string>());
-        // }
-        // auto padding = ip.get_pad(config);
-        // std::string gemm_type;
-        // for(auto i : range(padding.size()))
-        // {
-        //     if(padding[i] != 0)
-        //         gemm_type += keys[i];
-        // }
-        // if(gemm_type.empty())
-        //     gemm_type = "Default";
-        // else
-        //     gemm_type += "Padding";
-        // ip.set_gemm("ck::tensor_operation::device::GemmSpecialization::" + gemm_type);
-        auto gemm1_nperblock  = 64;
-        auto gemm01_mperblock = 256;
-        auto blocks_per_batch = int_div_ceil(m, gemm01_mperblock) *
-                                int_div_ceil(n, gemm1_nperblock); // ip.get_grid_size(config);
+        std::array<char, 4> keys{'M', 'N', 'K', 'O'};
+        // config (m0, n0, k0, n1)
+        std::array<std::size_t, 4> config{
+            c_shape.lens()[rank - 2], b_shape.lens()[rank - 2], a_shape.lens().back(), c_shape.lens().back()};
+        auto tuning_val = v.get("tuning_val", get_tuning_for({a_shape, b_shape, b1_shape, c_shape}));
+        auto ip = instance{get_gsg_instance(tuning_val, [&](const auto& x) -> bool {
+            return get_layout(a_shape) == x[0] and get_layout(b_shape) == x[1] and
+                   get_layout(c_shape) == x[3] and get_type(a_shape) == x[4] and
+                   get_type(b_shape) == x[5] and get_type(c_shape) == x[9];
+        })};
+        auto padding = ip.get_pad(config);
+        std::string gemm_type;
+        for(auto i : range(padding.size()))
+        {
+            if(padding[i] != 0)
+                gemm_type += keys[i];
+        }
+        if(gemm_type.empty())
+            gemm_type = "Default";
+        else
+            gemm_type += "Padding";
+        ip.set_gemm("ck::tensor_operation::device::GemmSpecialization::" + gemm_type);
+        auto blocks_per_batch = ip.get_grid_size(config);

         auto batch_count = std::accumulate(c_shape.lens().rbegin() + 2,
                                            c_shape.lens().rend(),
                                            std::size_t{1},
                                            std::multiplies<std::size_t>());
         hip_compile_options options;
-        auto block_size = 256; // ip.get_block_size();
+        auto block_size = ip.get_block_size();
         auto grid_size  = batch_count * blocks_per_batch;
         options.set_launch_params(v, grid_size * block_size, block_size);
         options.inputs = inputs;
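Note: the launch size is batch_count (the product of all output dims above the last two) times the per-batch block count from the selected instance. The same arithmetic standalone, with numbers from the new verify test and assumed tile sizes:

#include <cstddef>
#include <cstdio>
#include <functional>
#include <numeric>
#include <vector>

int main()
{
    std::vector<std::size_t> c_lens{2, 12, 384, 64}; // output lens: batch, heads, m0, n1
    auto batch_count = std::accumulate(c_lens.rbegin() + 2, c_lens.rend(),
                                       std::size_t{1}, std::multiplies<std::size_t>());
    std::size_t blocks_per_batch = 2 * 1; // e.g. ceil(384/256) * ceil(64/64)
    std::size_t block_size       = 256;   // instance::get_block_size() for this pick
    std::printf("batches=%zu grid=%zu threads=%zu\n", batch_count,
                batch_count * blocks_per_batch,
                batch_count * blocks_per_batch * block_size); // 24, 48, 12288
}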
@@ -282,7 +239,7 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
         options.params += " -DMIGRAPHX_CK_CHECK=1";
         auto src = interpolate_string(ck_gemm_softmax_gemm_kernel,
-                                      {{"instance", "" /* ip.str() */},
+                                      {{"instance", ip.str()},
                                        {"params", enum_params(inputs.size(), "void * private_p")},
                                        {"args", enum_params(inputs.size(), "private_p")},
                                        {"blocks_per_batch", to_string(blocks_per_batch)},
@@ -302,7 +259,6 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
         v["preamble"] = generate_pointwise(*pm, "post_ck_gemm_softmax_gemm_function") +
                         "\nMIGRAPHX_LIFT_CLASS(post_ck_gemm_softmax_gemm, "
                         "post_ck_gemm_softmax_gemm_function);";
-        v["post"]   = "ck_function_adaptor<post_ck_gemm_softmax_gemm>";
         v["kernel"] = "ck_gemm_softmax_gemm_" + generate_name_from_ops(*pm) + "_kernel";
     }
@@ -310,7 +266,7 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
         return action_decorate(replace(compile_op(ctx, shapes, v)), [=] {
             if(enabled(MIGRAPHX_LOG_CK_GEMM{}))
             {
-                std::vector<shape> gemm_shapes{shapes[0], shapes[1], shapes.back()};
+                std::vector<shape> gemm_shapes{shapes[0], shapes[1], shapes[2], shapes.back()};
                 std::cout << "ck_gemm_softmax_gemm: " << to_json_string(to_value(gemm_shapes))
                           << std::endl;
             }
...
@@ -69,7 +69,7 @@ __device__ void ck_gemm_softmax_gemm_matrix(C c, A a, B b, B1 b1)
         ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));

     constexpr const auto b_shape = get_shape_c<B>{};
-    constexpr const auto n       = b_shape.lens[0]; // col-major
+    constexpr const auto n       = b_shape.lens[1];
     constexpr const auto sb      = b_shape.strides[1]; // col-major
     constexpr const auto BK1     = gemm.get_BK1();
     constexpr const auto BK0     = k / BK1;
@@ -85,9 +85,9 @@ __device__ void ck_gemm_softmax_gemm_matrix(C c, A a, B b, B1 b1)
         ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));

     constexpr const auto b1_shape = get_shape_c<B1>{};
-    constexpr const auto k1       = b1_shape.lens[0]; // row-major
-    constexpr const auto n1       = b1_shape.lens[1]; // row-major
-    constexpr const auto sb1      = b1_shape.strides[0]; // rowl-major
+    constexpr const auto k1       = b1_shape.lens[0];
+    constexpr const auto n1       = b1_shape.lens[1];
+    constexpr const auto sb1      = b1_shape.strides[0]; // row-major
     constexpr const auto B1K1     = gemm.get_B1K1();
     constexpr const auto B1K0     = k1 / B1K1;
@@ -139,11 +139,11 @@ __device__ void ck_gemm_softmax_gemm_matrix(C c, A a, B b, B1 b1)
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

-    static_assert(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1,
-                                              b_grid_desc_bk0_n_bk1,
-                                              b1_grid_desc_bk0_n_bk1,
-                                              c_grid_desc_m_n,
-                                              block_2_ctile_map));
+    // static_assert(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1,
+    //                                           b_grid_desc_bk0_n_bk1,
+    //                                           b1_grid_desc_bk0_n_bk1,
+    //                                           c_grid_desc_m_n,
+    //                                           block_2_ctile_map));
     GridwiseGemm::template Run<HasMainKBlockLoop>(to_ck_const_pointer(a.data()),
                                                   to_ck_const_pointer(b.data()),
                                                   to_ck_const_pointer(b1.data()),
...
@@ -56,6 +56,7 @@
 #include <migraphx/gpu/context.hpp>
 #include <migraphx/gpu/device_name.hpp>
 #include <migraphx/gpu/fuse_ck.hpp>
+#include <migraphx/gpu/fuse_ck_gemm_softmax_gemm.hpp>
 #include <migraphx/gpu/fuse_mlir.hpp>
 #include <migraphx/gpu/fuse_ops.hpp>
 #include <migraphx/gpu/prefuse_ops.hpp>
@@ -131,6 +132,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
         simplify_reshapes{},
         propagate_constant{},
         dead_code_elimination{},
+        fuse_ck_gemm_softmax_gemm{&ctx},
+        dead_code_elimination{},
         enable_pass(not enabled(MIGRAPHX_DISABLE_POINTWISE_FUSION{}), fuse_pointwise{}),
         dead_code_elimination{},
         fuse_mlir{&ctx},
...
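Note: the new pass is scheduled before pointwise fusion, presumably so the mul/add/softmax chain is still visible as separate instructions when the matcher runs. A sketch of applying just this pass to a module outside the full gpu pipeline (assumes a MIGraphX build with this commit; the helper is ours, not part of the change):

#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/fuse_ck_gemm_softmax_gemm.hpp>

// Run the gemm-softmax-gemm fusion followed by DCE on a single module.
void fuse_gsg_only(migraphx::module& m, migraphx::gpu::context& ctx)
{
    migraphx::run_passes(m, {migraphx::gpu::fuse_ck_gemm_softmax_gemm{&ctx},
                             migraphx::dead_code_elimination{}});
}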
@@ -31,15 +31,39 @@ struct ck_gemm_softmax_gemm : verify_program<ck_gemm_softmax_gemm>
 {
     migraphx::program create_program() const
     {
+        // migraphx::program p;
+        // auto* mm = p.get_main_module();
+        // migraphx::shape m1_shape{migraphx::shape::half_type, {16, 12, 384, 64}};
+        // migraphx::shape m2_shape{migraphx::shape::half_type, {16, 12, 384, 384}};
+        // auto m2_elements = 16 * 12 * 384 * 384;
+        // auto a  = mm->add_parameter("1", m1_shape);
+        // auto b  = mm->add_parameter("2", m1_shape);
+        // auto b1 = mm->add_parameter("3", m1_shape);
+        // auto c  = mm->add_parameter("4", m1_shape);
+        // std::vector<float> eights(m2_elements, 0.125);
+        // auto eight = mm->add_literal(migraphx::literal{m2_shape, eights});
+        // std::vector<float> zeros(m2_elements, 0);
+        // auto zero = mm->add_literal(migraphx::literal{m2_shape, zeros});
+        // std::vector<float> ones(m2_elements, 1);
+        // auto one = mm->add_literal(migraphx::literal{m2_shape, ones});
+        // // a  = one;
+        // // b  = one;
+        // // b1 = one;
+        // b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b);
+        // auto gemm1   = mm->add_instruction(migraphx::make_op("dot"), a, b);
+        // auto scale   = mm->add_instruction(migraphx::make_op("mul"), gemm1, eight);
+        // auto bias    = mm->add_instruction(migraphx::make_op("add"), scale, zero);
+        // auto softmax = mm->add_instruction(migraphx::make_op("softmax", {{"axis", -1}}), bias);
+        // mm->add_instruction(migraphx::make_op("dot"), softmax, b1);

         migraphx::program p;
         auto* mm = p.get_main_module();
-        migraphx::shape m1_shape{migraphx::shape::half_type, {1, 12, 256, 256}};
-        migraphx::shape m2_shape{migraphx::shape::half_type, {1, 12, 256, 256}};
-        auto m2_elements = 1 * 12 * 256 * 256;
-        auto a  = mm->add_parameter("1", m1_shape);
-        auto b  = mm->add_parameter("2", m1_shape);
-        auto b1 = mm->add_parameter("3", m1_shape);
-        auto c  = mm->add_parameter("4", m1_shape);
+        size_t batch = 2;
+        migraphx::shape m1_shape{migraphx::shape::half_type, {batch, 384, 2304}};
+        migraphx::shape m2_shape{migraphx::shape::half_type, {batch, 12, 384, 384}};
+        auto m2_elements = batch * 12 * 384 * 384;
+        auto g = mm->add_parameter("1", m1_shape);
         std::vector<float> eights(m2_elements, 0.125);
         auto eight = mm->add_literal(migraphx::literal{m2_shape, eights});
         std::vector<float> zeros(m2_elements, 0);
@@ -47,7 +71,13 @@ struct ck_gemm_softmax_gemm : verify_program<ck_gemm_softmax_gemm>
         std::vector<float> ones(m2_elements, 1);
         auto one = mm->add_literal(migraphx::literal{m2_shape, ones});

+        g = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {batch, 384, 36, 64}}}), g);
+        g = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), g);
+        auto a  = mm->add_instruction(migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {12}}}), g);
+        auto b  = mm->add_instruction(migraphx::make_op("slice", {{"axes", {1}}, {"starts", {12}}, {"ends", {24}}}), g);
+        auto b1 = mm->add_instruction(migraphx::make_op("slice", {{"axes", {1}}, {"starts", {24}}, {"ends", {36}}}), g);
         b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b);
         auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b);
         auto scale = mm->add_instruction(migraphx::make_op("mul"), gemm1, eight);
         auto bias  = mm->add_instruction(migraphx::make_op("add"), scale, zero);
...
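Note: the reworked test feeds one fused QKV tensor instead of three separate parameters. Our arithmetic check of the dimensions used above:

// Why {batch, 384, 2304} works as a fused QKV input.
int main()
{
    constexpr int projections = 3;  // Q, K, V
    constexpr int heads       = 12;
    constexpr int head_dim    = 64;
    static_assert(projections * heads * head_dim == 2304, "fused QKV width");
    static_assert(projections * heads == 36, "extent of the sliced axis");
    // reshape {batch, 384, 2304} -> {batch, 384, 36, 64}, transpose to
    // {batch, 36, 384, 64}, then slice axis 1 at [0, 12), [12, 24), [24, 36)
    // to get Q, K, V, each {batch, 12, 384, 64}.
    return 0;
}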
 import os, json, subprocess, tempfile, sys, argparse, contextlib
+ck_function = -1

 @contextlib.contextmanager
 def tmp_file(dump=None):
@@ -20,6 +21,9 @@ def pretty_print(obj):
 def run_driver(b):
     print(b)
+    outfile = open("temp2.json", "w")
+    json.dump(b, outfile)
+    outfile.close()
     with tmp_file(lambda tf: json.dump(b, tf)) as tf:
         cp = subprocess.run('./bin/gpu-driver {}'.format(tf),
                             capture_output=True,
@@ -45,7 +49,7 @@ def get_device_time(s):
 def benchmark_ck(config, tuning):
     try:
-        b = {
+        b0 = {
             'settings': {
                 'iterations': 100
             },
@@ -56,6 +60,18 @@ def benchmark_ck(config, tuning):
                 'inputs': config
             }
         }
+        b1 = {
+            'settings': {
+                'iterations': 100
+            },
+            'compile_op': {
+                'name': 'ck_gemm_softmax_gemm',
+                'check': True,
+                'tuning_val': tuning,
+                'inputs': config
+            }
+        }
+        b = b0 if (ck_function == 0) else b1
         for line in run_driver(b):
             dtime = get_device_time(line)
             print(dtime)
@@ -72,17 +88,26 @@ def benchmark(config, size):
 def parse_log(f):
     for line in open(f).readlines():
         line = line.strip()
-        if not line.startswith('ck_gemm:'):
-            continue
-        line = line[len('ck_gemm:'):].strip()
-        config = json.loads(line)
-        yield config
+        global ck_function
+        if line.startswith('ck_gemm:'):
+            line = line[len('ck_gemm:'):].strip()
+            config = json.loads(line)
+            ck_function = 0
+            yield config
+        if line.startswith('ck_gemm_softmax_gemm:'):
+            line = line[len('ck_gemm_softmax_gemm:'):].strip()
+            config = json.loads(line)
+            ck_function = 1
+            yield config

 def benchmark_log(f, n):
     result = []
-    for config in parse_log(f):
-        tuned = benchmark(config, n)
+    logs = parse_log(f)
+    for config in logs:
+        additional_tv = ck_function * 2
+        tuned = benchmark(config, n + additional_tv)
         print("Tuned:", tuned)
         result.append([config, tuned])
     return result
...
#!/bin/bash
MODEL=$1
LOG="ck_bbc.log"
TUNING_DB="ck_bbc.json"
rm $LOG
touch $LOG
for N in 1 16 32 64
do
MIGRAPHX_LOG_CK_GEMM=1 ./bin/driver run $MODEL -g --fill1 input_ids --input-dim @input_ids $N 384 | grep 'ck_gemm.*: \[{' | sort -u >> $LOG
done
python3 ../tools/tune_ck.py -n 16 -l $LOG -o $TUNING_DB
\ No newline at end of file
import os, json, subprocess, tempfile, sys, argparse, contextlib

@contextlib.contextmanager
def tmp_file(dump=None):
    tmp_name = None
    try:
        with tempfile.NamedTemporaryFile(mode='w+', delete=False) as f:
            tmp_name = f.name
            if dump:
                dump(f)
        yield tmp_name
    finally:
        os.unlink(tmp_name)

def pretty_print(obj):
    print(json.dumps(obj, indent=2))

def run_driver(b):
    print(b)
    with tmp_file(lambda tf: json.dump(b, tf)) as tf:
        cp = subprocess.run('./bin/gpu-driver {}'.format(tf),
                            capture_output=True,
                            check=True,
                            shell=True)
        for line in cp.stdout.decode().split("\n"):
            s = line.strip()
            if not s:
                continue
            if not ']: ' in s:
                continue
            yield s.split(']: ')[1].strip()

def convert_to_float(s):
    return s[:-2]

def get_device_time(s):
    fields = s.split(',')
    return convert_to_float(fields[-1].strip())

def benchmark_ck(config, tuning):
    try:
        b = {
            'settings': {
                'iterations': 100
            },
            'compile_op': {
                'name': 'ck_gemm_softmax_gemm',
                'check': True,
                'tuning_val': tuning,
                'inputs': config
            }
        }
        for line in run_driver(b):
            dtime = get_device_time(line)
            print(dtime)
            return float(dtime)
    except:
        return sys.float_info.max

def benchmark(config, size):
    times = [benchmark_ck(config, i) for i in range(size)]
    return times.index(min(times))

def parse_log(f):
    for line in open(f).readlines():
        line = line.strip()
        if not line.startswith('ck_gemm_softmax_gemm:'):
            continue
        line = line[len('ck_gemm_softmax_gemm:'):].strip()
        config = json.loads(line)
        yield config

def benchmark_log(f, n):
    result = []
    for config in parse_log(f):
        tuned = benchmark(config, n)
        print("Tuned:", tuned)
        result.append([config, tuned])
    return result

def parse_args():
    parser = argparse.ArgumentParser(description="Simple tuner for CK gemms")
    parser.add_argument('--log',
                        '-l',
                        type=str,
                        metavar='file',
                        help='Path to logfile')
    parser.add_argument('--out',
                        '-o',
                        type=str,
                        metavar='file',
                        help='Output json file to save tunings')
    parser.add_argument('-n', type=int, help='Number of instances to tune')
    args = parser.parse_args()
    return args

def run(args):
    tuned = benchmark_log(args.log, args.n)
    json.dump(tuned, open(args.out, 'w+'))

run(parse_args())