Commit f3fcfcc7 authored by Alan Turner

Move the gemm+softmax+gemm fusion into its own pass, restore the regular CK gemm fusions in fuse_ck, and add tuning support for ck_gemm_softmax_gemm

parent 3133fd79
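
The tuning flow this commit wires up works in two steps: with MIGRAPHX_LOG_CK_GEMM=1 the compiler prints one ck_gemm_softmax_gemm line per fused problem (its A, B, B1 and output shapes), the tuner script benchmarks each logged config over a range of tuning_val indices and dumps a JSON list of [config, best_index] pairs, and at compile time get_tuning_for reads that file through MIGRAPHX_CK_TUNING, falling back to instance 4 when a config is missing. Below is a minimal Python sketch of that lookup; the helper names are illustrative and not part of this change, and "ck_bbc.json" is the file produced by the tuning script further down.

import json

DEFAULT_INSTANCE = 4  # fallback used when a config has no tuning entry

def load_tuning(path):
    # The tuner writes a list of [config, best_index] pairs, where config is the
    # JSON form of the logged gemm shapes (A, B, B1, output).
    with open(path) as f:
        return json.load(f)

def tuning_for(tuning, config):
    # Mirrors get_tuning_for in the compiler: exact match on the input shapes,
    # otherwise warn and fall back to the default instance.
    for entry_config, best_index in tuning:
        if entry_config == config:
            return best_index
    print("Warning: CK GSG tuning missing for config")
    return DEFAULT_INSTANCE

# Hypothetical usage:
# tuning = load_tuning("ck_bbc.json")
# idx = tuning_for(tuning, logged_config)
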
......@@ -90,6 +90,7 @@ add_library(migraphx_gpu
device_name.cpp
elu.cpp
fuse_ck.cpp
fuse_ck_gemm_softmax_gemm.cpp
fuse_mlir.cpp
fuse_ops.cpp
gather.cpp
......
......@@ -49,43 +49,43 @@ struct ck_gemm
};
MIGRAPHX_REGISTER_OP(ck_gemm);
struct ck_gemm_scale_bias_softmax_gemm
{
operation op = make_op("dot");
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.op, "op"));
}
std::string name() const { return "gpu::ck_gemm_softmax_gemm"; }
void check_gemm_shape(const shape& s) const
{
if(not contains(range(s.strides().rbegin(), s.strides().rbegin() + 3), 1))
MIGRAPHX_THROW("Invalid shape for ck_gemm_scale_bias_softmax_gemm");
}
shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
{
check_shapes{inputs, *this}.same_ndims();
// if(mods.size() != 1)
// MIGRAPHX_THROW("should have one submodule.");
if(inputs.size() < 2)
MIGRAPHX_THROW("should have at least two inputs.");
auto a = inputs[0];
auto b = inputs[1];
auto b1 = inputs[2];
for(const auto& input : inputs)
{
// std::cout << input << std::endl;
check_gemm_shape(input);
}
return op.compute_shape({op.compute_shape({a, b}), b1});
}
};
MIGRAPHX_REGISTER_OP(ck_gemm_scale_bias_softmax_gemm);
// struct ck_gemm_scale_bias_softmax_gemm
// {
// operation op = make_op("dot");
// template <class Self, class F>
// static auto reflect(Self& self, F f)
// {
// return pack(f(self.op, "op"));
// }
// std::string name() const { return "gpu::ck_gemm_softmax_gemm"; }
// void check_gemm_shape(const shape& s) const
// {
// if(not contains(range(s.strides().rbegin(), s.strides().rbegin() + 3), 1))
// MIGRAPHX_THROW("Invalid shape for ck_gemm_scale_bias_softmax_gemm");
// }
// shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
// {
// check_shapes{inputs, *this}.same_ndims();
// // if(mods.size() != 1)
// // MIGRAPHX_THROW("should have one submodule.");
// if(inputs.size() < 2)
// MIGRAPHX_THROW("should have at least two inputs.");
// auto a = inputs[0];
// auto b = inputs[1];
// auto b1 = inputs[2];
// for(const auto& input : inputs)
// {
// // std::cout << input << std::endl;
// check_gemm_shape(input);
// }
// return op.compute_shape({op.compute_shape({a, b}), b1});
// }
// };
// MIGRAPHX_REGISTER_OP(ck_gemm_scale_bias_softmax_gemm);
namespace {
......@@ -156,38 +156,38 @@ struct find_ck_gemm
struct find_ck_gemm_scale_bias_softmax_gemm
{
auto matcher() const
{
auto gemm1 =
match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1")));
auto pw =
match::name("pointwise")(match::any_of[match::inputs()](gemm1)).bind("scale_bias");
auto softmax = match::name("softmax")(match::any_of[match::inputs()](pw)).bind("softmax");
return match::name("dot")(is_ck_gemm().bind("gemm2"))(
match::any_of[match::inputs()](softmax));
}
// auto matcher() const
// {
// auto gemm1 =
// match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1")));
// auto pw =
// match::name("pointwise")(match::any_of[match::inputs()](gemm1)).bind("scale_bias");
// auto softmax = match::name("softmax")(match::any_of[match::inputs()](pw)).bind("softmax");
// return match::name("dot")(is_ck_gemm().bind("gemm2"))(
// match::any_of[match::inputs()](softmax));
// }
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
std::cout << "Matched" << std::endl;
auto ins = r.result;
auto gemm2_ins = r.instructions["gemm2"];
auto sm_ins = r.instructions["softmax"];
auto pw_ins = r.instructions["scale_bias"];
auto gemm1_ins = r.instructions["gemm1"];
gemm2_ins->debug_print();
sm_ins->debug_print();
pw_ins->debug_print();
gemm1_ins->debug_print();
auto inputs = gemm1_ins->inputs(); // A, B
inputs.push_back(gemm2_ins->inputs().back()); // B1
// inputs.push_back(pw_ins->inputs().back()); // C
mpm.get_module().replace_instruction(
ins, ck_gemm_scale_bias_softmax_gemm{gemm2_ins->get_operator()}, inputs);
}
// void apply(module_pass_manager& mpm, const match::matcher_result& r) const
// {
// std::cout << "Matched" << std::endl;
// auto ins = r.result;
// auto gemm2_ins = r.instructions["gemm2"];
// auto sm_ins = r.instructions["softmax"];
// auto pw_ins = r.instructions["scale_bias"];
// auto gemm1_ins = r.instructions["gemm1"];
// gemm2_ins->debug_print();
// sm_ins->debug_print();
// pw_ins->debug_print();
// gemm1_ins->debug_print();
// auto inputs = gemm1_ins->inputs(); // A, B
// inputs.push_back(gemm2_ins->inputs().back()); // B1
// // inputs.push_back(pw_ins->inputs().back()); // C
// mpm.get_module().replace_instruction(
// ins, ck_gemm_scale_bias_softmax_gemm{gemm2_ins->get_operator()}, inputs);
// }
// auto matcher() const
// {
......@@ -223,11 +223,11 @@ struct find_ck_gemm_scale_bias_softmax_gemm
void fuse_ck::apply(module_pass_manager& mpm) const
{
// mpm.get_module().debug_print();
match::find_matches(mpm, find_ck_gemm_scale_bias_softmax_gemm{});
// if(not enabled(MIGRAPHX_DISABLE_CK_GEMM_FUSION{}))
// match::find_matches(mpm, find_ck_gemm_pointwise{});
// if(not enabled(MIGRAPHX_DISABLE_CK_GEMM{}))
// match::find_matches(mpm, find_ck_gemm{});
// match::find_matches(mpm, find_ck_gemm_scale_bias_softmax_gemm{});
if(not enabled(MIGRAPHX_DISABLE_CK_GEMM_FUSION{}))
match::find_matches(mpm, find_ck_gemm_pointwise{});
if(not enabled(MIGRAPHX_DISABLE_CK_GEMM{}))
match::find_matches(mpm, find_ck_gemm{});
}
} // namespace gpu
......
#include <migraphx/gpu/fuse_ck_gemm_softmax_gemm.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/env.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;
namespace gpu {
struct gemm_softmax_gemm_gemm
{
operation op = make_op("dot");
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.op, "op"));
}
std::string name() const { return "gpu::ck_gemm_softmax_gemm"; }
void check_gemm_shape(const shape& s) const
{
if(not contains(range(s.strides().rbegin(), s.strides().rbegin() + 3), 1))
MIGRAPHX_THROW("Invalid shape for gemm_softmax_gemm_gemm");
}
shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
{
check_shapes{inputs, *this}.same_ndims();
if(inputs.size() < 3)
MIGRAPHX_THROW("should have at least three inputs.");
auto a = inputs[0];
auto b = inputs[1];
auto b1 = inputs[2];
for(const auto& input : inputs)
{
check_gemm_shape(input);
}
return op.compute_shape({op.compute_shape({a, b}), b1});
}
};
MIGRAPHX_REGISTER_OP(gemm_softmax_gemm_gemm);
namespace {
MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
{
if(ins->name() != "dot")
return false;
auto a = ins->inputs().front()->get_shape();
auto b = ins->inputs().back()->get_shape();
if(a.lens().back() > 2048)
return false;
return true;
}
struct find_gemm_softmax_gemm_gemm
{
auto matcher() const
{
auto gemm1 =
match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1")));
auto mul =
match::name("mul")(match::any_of[match::inputs()](gemm1)).bind("scale");
auto add =
match::name("add")(match::any_of[match::inputs()](mul));
auto softmax = match::name("softmax")(match::any_of[match::inputs()](add)).bind("softmax");
return match::name("dot")(is_ck_gemm().bind("gemm2"))(
match::any_of[match::inputs()](softmax));
}
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto ins = r.result;
auto gemm2_ins = r.instructions["gemm2"];
auto gemm1_ins = r.instructions["gemm1"];
auto inputs = gemm1_ins->inputs(); // A, B
inputs.push_back(gemm2_ins->inputs().back()); // B1
mpm.get_module().replace_instruction(
ins, gemm_softmax_gemm_gemm{gemm2_ins->get_operator()}, inputs);
}
};
} // namespace
void fuse_ck_gemm_softmax_gemm::apply(module_pass_manager& mpm) const
{
match::find_matches(mpm, find_gemm_softmax_gemm_gemm{});
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#ifndef MIGRAPHX_GUARD_GPU_FUSE_CK_GEMM_SOFTMAX_GEMM_HPP
#define MIGRAPHX_GUARD_GPU_FUSE_CK_GEMM_SOFTMAX_GEMM_HPP
#include <migraphx/config.hpp>
#include <migraphx/gpu/context.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module_pass_manager;
namespace gpu {
struct fuse_ck_gemm_softmax_gemm
{
context* ctx = nullptr;
std::string name() const { return "gpu::fuse_ck_gemm_softmax_gemm"; }
void apply(module_pass_manager& mpm) const;
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_GPU_FUSE_CK_GEMM_SOFTMAX_GEMM_HPP
......@@ -39,7 +39,7 @@
#include <migraphx/file_buffer.hpp>
const std::vector<std::string>&
get_instance(std::size_t i, const std::function<bool(const std::vector<std::string>&)>& pred);
get_gsg_instance(std::size_t i, const std::function<bool(const std::vector<std::string>&)>& pred);
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -62,34 +62,12 @@ namespace migraphx {
${preamble}
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using F16 = ck::half_t;
using F32 = float;
using PassThrough = ck_passthrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
using AElementOp = PassThrough;
using B0ElementOp = PassThrough;
using Acc0ElementOp = ck_scale;//ck::tensor_operation::element_wise::Scale;
using B1ElementOp = PassThrough;
using CElementOp = PassThrough;//ck_add;//ck::tensor_operation::element_wise::Add;
using gemm = CK_DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, false, std::ratio<1, 8>>;
extern "C" {
__global__ void ${kernel}(${params})
{
// transform_args(make_tensors(), rotate_last())(${args})([](auto... xs) {
// ck_gemm_softmax_gemm<CK_DeviceGemmMultipleD<${instance}>, ${blocks_per_batch}>(xs...);
// });
transform_args(make_tensors(), rotate_last())(${args})([](auto... xs) {
ck_gemm_softmax_gemm<gemm, ${blocks_per_batch}>(xs...);
ck_gemm_softmax_gemm<CK_DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<${instance}>, ${blocks_per_batch}>(xs...);
});
}
......@@ -112,13 +90,13 @@ struct instance
std::size_t get_pb(std::size_t i) const
{
assert(i < 4);
assert(i <= 4);
return int_at(block_size_index + 1 + i);
}
std::array<std::size_t, 3> get_pad(const std::array<std::size_t, 3>& config) const
std::array<std::size_t, 4> get_pad(const std::array<std::size_t, 4>& config) const
{
std::array<std::size_t, 3> result{};
std::array<std::size_t, 4> result{};
for(auto i : range(config.size()))
{
result[i] = int_div_ceil(config[i], get_pb(i)) * get_pb(i) - config[i];
......@@ -126,33 +104,16 @@ struct instance
return result;
}
std::size_t get_grid_size(const std::array<std::size_t, 3>& config) const
std::size_t get_grid_size(const std::array<std::size_t, 4>& config) const
{
return int_div_ceil(config[0], get_pb(0)) * int_div_ceil(config[1], get_pb(1));
}
void set_ds_layout(const std::string& s)
{
assert(params[2] == "ck::Tuple<>");
params[2] = s;
}
void set_ds_type(const std::string& s)
{
assert(params[8] == "ck::Tuple<>");
params[8] = s;
}
void set_ds_op(const std::string& s)
{
assert(params[12] == "ck_passthrough");
params[12] = s;
return int_div_ceil(config[0], get_pb(0)) * int_div_ceil(config[3], get_pb(3));
}
void set_gemm(const std::string& s)
{
assert(params[13] == "ck::tensor_operation::device::GemmSpecialization::Default");
params[13] = s;
assert(params[15] == "ck::tensor_operation::device::GemmSpecialization::Default" or
params[15] == "ck::tensor_operation::device::GemmSpecialization::MNKOPadding");
params[15] = s;
}
std::string str() const { return join_strings(params, ","); }
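
With the fused kernel the tile configuration has four dimensions (M, N, K, and O, the second gemm's N), so get_pad and get_grid_size now take a 4-element config: each dimension is padded up to a multiple of the instance's per-block tile, and the per-batch grid covers the M and O tiles while N and K are consumed inside a workgroup. A small worked sketch of that arithmetic, using as assumed tile sizes the values hard-coded in the template above (MPerBlock=256, NPerBlock=128, KPerBlock=32, Gemm1NPerBlock=64) and an attention-style 384x384x64x64 problem:

from math import ceil

def get_pad(config, per_block):
    # Pad each of (M, N, K, O) up to a multiple of the instance's block tile.
    return [ceil(c / pb) * pb - c for c, pb in zip(config, per_block)]

def get_grid_size(config, per_block):
    # One workgroup per (M tile, O tile) pair; N and K are reduced within a block.
    return ceil(config[0] / per_block[0]) * ceil(config[3] / per_block[3])

# Assumed (MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock) and (M, N, K, O):
per_block = [256, 128, 32, 64]
config = [384, 384, 64, 64]
print(get_pad(config, per_block))        # [128, 0, 0, 0] -> only M needs padding
print(get_grid_size(config, per_block))  # ceil(384/256) * ceil(64/64) = 2 blocks per batch
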
......@@ -179,12 +140,12 @@ static std::size_t get_tuning_for(const std::vector<shape>& inputs)
{
static auto tuning = read_tuning(string_value_of(MIGRAPHX_CK_TUNING{}, ""));
if(tuning.empty())
std::cout << "*********** Warning: No CK tuning!" << std::endl;
std::cout << "*********** Warning: No CK GSG tuning!" << std::endl;
auto it = std::find_if(
tuning.begin(), tuning.end(), [&](const auto& p) { return p.first == inputs; });
if(it == tuning.end())
{
std::cout << "*********** Warning: CK tuning missing for config!" << std::endl;
std::cout << "*********** Warning: CK GSG tuning missing for config!" << std::endl;
return 4;
}
return it->second;
......@@ -194,8 +155,12 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
{
static std::string get_layout(const shape& s)
{
return s.transposed() ? "ck::tensor_layout::gemm::ColumnMajor"
: "ck::tensor_layout::gemm::RowMajor";
if (not s.transposed())
return "ck::tensor_layout::gemm::RowMajor";
auto lens = s.lens();
return lens[lens.size() - 1] > lens[lens.size() - 2] ?
"ck::tensor_layout::gemm::ColumnMajor" : "ck::tensor_layout::gemm::RowMajor";
}
static std::string get_type(const shape& s)
......@@ -222,6 +187,7 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
{
auto a_shape = inputs[0];
auto b_shape = inputs[1];
auto b1_shape = inputs[2];
auto c_shape = inputs.back();
auto m = a_shape.lens()[0];
auto k = a_shape.lens()[1];
......@@ -229,48 +195,39 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
auto rank = a_shape.lens().size();
// std::array<char, 3> keys{'M', 'N', 'K'};
// std::array<std::size_t, 3> config{
// c_shape.lens()[rank - 2], c_shape.lens().back(), a_shape.lens().back()};
// auto tuning_val = v.get("tuning_val", get_tuning_for({a_shape, b_shape, c_shape}));
// auto ip = instance{get_instance(tuning_val, [&](const auto& x) -> bool {
// return get_layout(a_shape) == x[0] and get_layout(b_shape) == x[1] and
// get_layout(c_shape) == x[3] and get_type(a_shape) == x[4] and
// get_type(b_shape) == x[5] and get_type(c_shape) == x[9];
// })};
// assert(inputs.size() < 4 or v.contains("post"));
// if(v.contains("post"))
// {
// ip.set_ds_layout(ck_tuple(inputs.begin() + 2, inputs.end() - 1, &get_layout));
// ip.set_ds_type(ck_tuple(inputs.begin() + 2, inputs.end() - 1, &get_type));
// ip.set_ds_op(v.at("post").to<std::string>());
// }
// auto padding = ip.get_pad(config);
// std::string gemm_type;
// for(auto i : range(padding.size()))
// {
// if(padding[i] != 0)
// gemm_type += keys[i];
// }
// if(gemm_type.empty())
// gemm_type = "Default";
// else
// gemm_type += "Padding";
// ip.set_gemm("ck::tensor_operation::device::GemmSpecialization::" + gemm_type);
auto gemm1_nperblock = 64;
auto gemm01_mperblock = 256;
auto blocks_per_batch = int_div_ceil(m, gemm01_mperblock) *
int_div_ceil(n, gemm1_nperblock); // ip.get_grid_size(config);
std::array<char, 4> keys{'M', 'N', 'K', 'O'};
// config (m0, n0, k0, n1)
std::array<std::size_t, 4> config{
c_shape.lens()[rank - 2], b_shape.lens()[rank - 2], a_shape.lens().back(), c_shape.lens().back()};
auto tuning_val = v.get("tuning_val", get_tuning_for({a_shape, b_shape, b1_shape, c_shape}));
auto ip = instance{get_gsg_instance(tuning_val, [&](const auto& x) -> bool {
return get_layout(a_shape) == x[0] and get_layout(b_shape) == x[1] and
get_layout(c_shape) == x[3] and get_type(a_shape) == x[4] and
get_type(b_shape) == x[5] and get_type(c_shape) == x[9];
})};
auto padding = ip.get_pad(config);
std::string gemm_type;
for(auto i : range(padding.size()))
{
if(padding[i] != 0)
gemm_type += keys[i];
}
if(gemm_type.empty())
gemm_type = "Default";
else
gemm_type += "Padding";
ip.set_gemm("ck::tensor_operation::device::GemmSpecialization::" + gemm_type);
auto blocks_per_batch = ip.get_grid_size(config);
auto batch_count = std::accumulate(c_shape.lens().rbegin() + 2,
c_shape.lens().rend(),
std::size_t{1},
std::multiplies<std::size_t>());
hip_compile_options options;
auto block_size = 256; // ip.get_block_size();
auto block_size = ip.get_block_size();
auto grid_size = batch_count * blocks_per_batch;
options.set_launch_params(v, grid_size * block_size, block_size);
options.inputs = inputs;
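
The padding result then picks the CK GemmSpecialization (each dimension that needed padding contributes its letter, e.g. "MPadding" or "MNKOPadding", otherwise "Default"), and the launch size is the per-batch grid times the batch count times the instance's block size. A rough sketch of that selection, continuing the assumed numbers from the example above:

def gemm_specialization(padding):
    # Build "Default", "MPadding", ..., "MNKOPadding" from the per-dimension padding.
    keys = "MNKO"
    suffix = "".join(k for k, p in zip(keys, padding) if p != 0)
    return "ck::tensor_operation::device::GemmSpecialization::" + (suffix + "Padding" if suffix else "Default")

print(gemm_specialization([128, 0, 0, 0]))  # ...GemmSpecialization::MPadding

batch_count = 2 * 12      # product of the leading output dimensions (assumed 2 batches, 12 heads)
blocks_per_batch = 2      # from get_grid_size in the sketch above
block_size = 256          # the instance's BlockSize
print(batch_count * blocks_per_batch * block_size)  # total threads passed to set_launch_params
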
......@@ -282,7 +239,7 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
options.params += " -DMIGRAPHX_CK_CHECK=1";
auto src = interpolate_string(ck_gemm_softmax_gemm_kernel,
{{"instance", "" /* ip.str() */},
{{"instance", ip.str()},
{"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")},
{"blocks_per_batch", to_string(blocks_per_batch)},
......@@ -302,7 +259,6 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
v["preamble"] = generate_pointwise(*pm, "post_ck_gemm_softmax_gemm_function") +
"\nMIGRAPHX_LIFT_CLASS(post_ck_gemm_softmax_gemm, "
"post_ck_gemm_softmax_gemm_function);";
v["post"] = "ck_function_adaptor<post_ck_gemm_softmax_gemm>";
v["kernel"] = "ck_gemm_softmax_gemm_" + generate_name_from_ops(*pm) + "_kernel";
}
......@@ -310,7 +266,7 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
return action_decorate(replace(compile_op(ctx, shapes, v)), [=] {
if(enabled(MIGRAPHX_LOG_CK_GEMM{}))
{
std::vector<shape> gemm_shapes{shapes[0], shapes[1], shapes.back()};
std::vector<shape> gemm_shapes{shapes[0], shapes[1], shapes[2], shapes.back()};
std::cout << "ck_gemm_softmax_gemm: " << to_json_string(to_value(gemm_shapes))
<< std::endl;
}
......
This diff is collapsed.
......@@ -69,7 +69,7 @@ __device__ void ck_gemm_softmax_gemm_matrix(C c, A a, B b, B1 b1)
ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
constexpr const auto b_shape = get_shape_c<B>{};
constexpr const auto n = b_shape.lens[0]; // col-major
constexpr const auto n = b_shape.lens[1];
constexpr const auto sb = b_shape.strides[1]; // col-major
constexpr const auto BK1 = gemm.get_BK1();
constexpr const auto BK0 = k / BK1;
......@@ -85,9 +85,9 @@ __device__ void ck_gemm_softmax_gemm_matrix(C c, A a, B b, B1 b1)
ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
constexpr const auto b1_shape = get_shape_c<B1>{};
constexpr const auto k1 = b1_shape.lens[0]; // row-major
constexpr const auto n1 = b1_shape.lens[1]; // row-major
constexpr const auto sb1 = b1_shape.strides[0]; // rowl-major
constexpr const auto k1 = b1_shape.lens[0];
constexpr const auto n1 = b1_shape.lens[1];
constexpr const auto sb1 = b1_shape.strides[0]; // row-major
constexpr const auto B1K1 = gemm.get_B1K1();
constexpr const auto B1K0 = k1 / B1K1;
......@@ -139,11 +139,11 @@ __device__ void ck_gemm_softmax_gemm_matrix(C c, A a, B b, B1 b1)
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
static_assert(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
b1_grid_desc_bk0_n_bk1,
c_grid_desc_m_n,
block_2_ctile_map));
// static_assert(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1,
// b_grid_desc_bk0_n_bk1,
// b1_grid_desc_bk0_n_bk1,
// c_grid_desc_m_n,
// block_2_ctile_map));
GridwiseGemm::template Run<HasMainKBlockLoop>(to_ck_const_pointer(a.data()),
to_ck_const_pointer(b.data()),
to_ck_const_pointer(b1.data()),
......
......@@ -56,6 +56,7 @@
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device_name.hpp>
#include <migraphx/gpu/fuse_ck.hpp>
#include <migraphx/gpu/fuse_ck_gemm_softmax_gemm.hpp>
#include <migraphx/gpu/fuse_mlir.hpp>
#include <migraphx/gpu/fuse_ops.hpp>
#include <migraphx/gpu/prefuse_ops.hpp>
......@@ -131,6 +132,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
simplify_reshapes{},
propagate_constant{},
dead_code_elimination{},
fuse_ck_gemm_softmax_gemm{&ctx},
dead_code_elimination{},
enable_pass(not enabled(MIGRAPHX_DISABLE_POINTWISE_FUSION{}), fuse_pointwise{}),
dead_code_elimination{},
fuse_mlir{&ctx},
......
......@@ -31,15 +31,39 @@ struct ck_gemm_softmax_gemm : verify_program<ck_gemm_softmax_gemm>
{
migraphx::program create_program() const
{
// migraphx::program p;
// auto* mm = p.get_main_module();
// migraphx::shape m1_shape{migraphx::shape::half_type, {16, 12, 384, 64}};
// migraphx::shape m2_shape{migraphx::shape::half_type, {16, 12, 384, 384}};
// auto m2_elements = 16 * 12 * 384 * 384;
// auto a = mm->add_parameter("1", m1_shape);
// auto b = mm->add_parameter("2", m1_shape);
// auto b1 = mm->add_parameter("3", m1_shape);
// auto c = mm->add_parameter("4", m1_shape);
// std::vector<float> eights(m2_elements, 0.125);
// auto eight = mm->add_literal(migraphx::literal{m2_shape, eights});
// std::vector<float> zeros(m2_elements, 0);
// auto zero = mm->add_literal(migraphx::literal{m2_shape, zeros});
// std::vector<float> ones(m2_elements, 1);
// auto one = mm->add_literal(migraphx::literal{m2_shape, ones});
// // a = one;
// // b = one;
// // b1 = one;
// b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b);
// auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b);
// auto scale = mm->add_instruction(migraphx::make_op("mul"), gemm1, eight);
// auto bias = mm->add_instruction(migraphx::make_op("add"), scale, zero);
// auto softmax = mm->add_instruction(migraphx::make_op("softmax", {{"axis", -1}}), bias);
// mm->add_instruction(migraphx::make_op("dot"), softmax, b1);
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::half_type, {1, 12, 256, 256}};
migraphx::shape m2_shape{migraphx::shape::half_type, {1, 12, 256, 256}};
auto m2_elements = 1 * 12 * 256 * 256;
auto a = mm->add_parameter("1", m1_shape);
auto b = mm->add_parameter("2", m1_shape);
auto b1 = mm->add_parameter("3", m1_shape);
auto c = mm->add_parameter("4", m1_shape);
size_t batch = 2;
migraphx::shape m1_shape{migraphx::shape::half_type, {batch, 384, 2304}};
migraphx::shape m2_shape{migraphx::shape::half_type, {batch, 12, 384, 384}};
auto m2_elements = batch * 12 * 384 * 384;
auto g = mm->add_parameter("1", m1_shape);
std::vector<float> eights(m2_elements, 0.125);
auto eight = mm->add_literal(migraphx::literal{m2_shape, eights});
std::vector<float> zeros(m2_elements, 0);
......@@ -47,7 +71,13 @@ struct ck_gemm_softmax_gemm : verify_program<ck_gemm_softmax_gemm>
std::vector<float> ones(m2_elements, 1);
auto one = mm->add_literal(migraphx::literal{m2_shape, ones});
g = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {batch, 384, 36, 64}}}), g);
g = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), g);
auto a = mm->add_instruction(migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {12}}}), g);
auto b = mm->add_instruction(migraphx::make_op("slice", {{"axes", {1}}, {"starts", {12}}, {"ends", {24}}}), g);
auto b1 = mm->add_instruction(migraphx::make_op("slice", {{"axes", {1}}, {"starts", {24}}, {"ends", {36}}}), g);
b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b);
auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b);
auto scale = mm->add_instruction(migraphx::make_op("mul"), gemm1, eight);
auto bias = mm->add_instruction(migraphx::make_op("add"), scale, zero);
......
import os, json, subprocess, tempfile, sys, argparse, contextlib
ck_function = -1
@contextlib.contextmanager
def tmp_file(dump=None):
......@@ -20,6 +21,9 @@ def pretty_print(obj):
def run_driver(b):
print(b)
outfile = open("temp2.json", "w")
json.dump(b, outfile)
outfile.close()
with tmp_file(lambda tf: json.dump(b, tf)) as tf:
cp = subprocess.run('./bin/gpu-driver {}'.format(tf),
capture_output=True,
......@@ -45,7 +49,7 @@ def get_device_time(s):
def benchmark_ck(config, tuning):
try:
b = {
b0 = {
'settings': {
'iterations': 100
},
......@@ -56,6 +60,18 @@ def benchmark_ck(config, tuning):
'inputs': config
}
}
b1 = {
'settings': {
'iterations': 100
},
'compile_op': {
'name': 'ck_gemm_softmax_gemm',
'check': True,
'tuning_val': tuning,
'inputs': config
}
}
b = b0 if (ck_function == 0) else b1
for line in run_driver(b):
dtime = get_device_time(line)
print(dtime)
......@@ -72,17 +88,26 @@ def benchmark(config, size):
def parse_log(f):
for line in open(f).readlines():
line = line.strip()
if not line.startswith('ck_gemm:'):
continue
line = line[len('ck_gemm:'):].strip()
config = json.loads(line)
yield config
global ck_function
if line.startswith('ck_gemm:'):
line = line[len('ck_gemm:'):].strip()
config = json.loads(line)
ck_function = 0
yield config
if line.startswith('ck_gemm_softmax_gemm:'):
line = line[len('ck_gemm_softmax_gemm:'):].strip()
config = json.loads(line)
ck_function = 1
yield config
def benchmark_log(f, n):
result = []
for config in parse_log(f):
tuned = benchmark(config, n)
logs = parse_log(f)
for config in logs:
additional_tv = ck_function * 2
tuned = benchmark(config, n + additional_tv)
print("Tuned:", tuned)
result.append([config, tuned])
return result
......
#!/bin/bash
MODEL=$1
LOG="ck_bbc.log"
TUNING_DB="ck_bbc.json"
rm $LOG
touch $LOG
for N in 1 16 32 64
do
MIGRAPHX_LOG_CK_GEMM=1 ./bin/driver run $MODEL -g --fill1 input_ids --input-dim @input_ids $N 384 | grep 'ck_gemm.*: \[{' | sort -u >> $LOG
done
python3 ../tools/tune_ck.py -n 16 -l $LOG -o $TUNING_DB
\ No newline at end of file
import os, json, subprocess, tempfile, sys, argparse, contextlib
@contextlib.contextmanager
def tmp_file(dump=None):
tmp_name = None
try:
with tempfile.NamedTemporaryFile(mode='w+', delete=False) as f:
tmp_name = f.name
if dump:
dump(f)
yield tmp_name
finally:
os.unlink(tmp_name)
def pretty_print(obj):
print(json.dumps(obj, indent=2))
def run_driver(b):
print(b)
with tmp_file(lambda tf: json.dump(b, tf)) as tf:
cp = subprocess.run('./bin/gpu-driver {}'.format(tf),
capture_output=True,
check=True,
shell=True)
for line in cp.stdout.decode().split("\n"):
s = line.strip()
if not s:
continue
if not ']: ' in s:
continue
yield s.split(']: ')[1].strip()
def convert_to_float(s):
return s[:-2]
def get_device_time(s):
fields = s.split(',')
return convert_to_float(fields[-1].strip())
def benchmark_ck(config, tuning):
try:
b = {
'settings': {
'iterations': 100
},
'compile_op': {
'name': 'ck_gemm_softmax_gemm',
'check': True,
'tuning_val': tuning,
'inputs': config
}
}
for line in run_driver(b):
dtime = get_device_time(line)
print(dtime)
return float(dtime)
except:
return sys.float_info.max
def benchmark(config, size):
times = [benchmark_ck(config, i) for i in range(size)]
return times.index(min(times))
def parse_log(f):
for line in open(f).readlines():
line = line.strip()
if not line.startswith('ck_gemm_softmax_gemm:'):
continue
line = line[len('ck_gemm_softmax_gemm:'):].strip()
config = json.loads(line)
yield config
def benchmark_log(f, n):
result = []
for config in parse_log(f):
tuned = benchmark(config, n)
print("Tuned:", tuned)
result.append([config, tuned])
return result
def parse_args():
parser = argparse.ArgumentParser(description="Simple tuner for CK gemms")
parser.add_argument('--log',
'-l',
type=str,
metavar='file',
help='Path to logfile')
parser.add_argument('--out',
'-o',
type=str,
metavar='file',
help='Output json file to save tunings')
parser.add_argument('-n', type=int, help='Number of instances to tune')
args = parser.parse_args()
return args
def run(args):
tuned = benchmark_log(args.log, args.n)
json.dump(tuned, open(args.out, 'w+'))
run(parse_args())