Commit 37939805 authored by Alan Turner

Move fuse_gsg to fuse_ck and fix bugs

parent 0a463c1e
@@ -76,11 +76,12 @@ MIGRAPHX_REGISTER_OP(ck_gemm);
 struct ck_gemm_softmax_gemm
 {
     operation op = make_op("dot");
+    double scale = 1.0;
 
     template <class Self, class F>
     static auto reflect(Self& self, F f)
     {
-        return pack(f(self.op, "op"));
+        return pack(f(self.op, "op"), f(self.scale, "scale"));
     }
 
     std::string name() const { return "gpu::ck_gemm_softmax_gemm"; }
@@ -91,7 +92,7 @@ struct ck_gemm_softmax_gemm
             MIGRAPHX_THROW("Invalid shape for ck_gemm_softmax_gemm");
     }
 
-    shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
+    shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>&) const
     {
         check_shapes{inputs, *this}.same_ndims();
         if(inputs.size() < 2)
@@ -136,38 +137,9 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
     // Skipping GEMMs with a K dimension greater than 2048 is a coarse-grained strategy
     // to avoid poor-performing GEMM kernels from CK
     // To-do: Investigate a more precise strategy
-    return true; // k <= 2048;
+    return k <= 2048;
 }
-
-struct find_ck_gemm_softmax_gemm
-{
-    auto matcher() const
-    {
-        auto gemm1 =
-            match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1")));
-        auto mul = match::name("mul")(match::any_of[match::inputs()](gemm1)).bind("scale");
-        auto softmax = match::name("softmax")(match::any_of[match::inputs()](mul)).bind("softmax");
-        return match::name("dot")(is_ck_gemm().bind("gemm2"))(
-            match::any_of[match::inputs()](softmax));
-    }
-
-    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
-    {
-        auto ins = r.result;
-        auto gemm2_ins = r.instructions["gemm2"];
-        auto gemm1_ins = r.instructions["gemm1"];
-        // if (not ck_gemm_softmax_gemm::is_ck_supported_type(gemm1_ins->get_shape().type()))
-        //     return;
-        auto inputs = gemm1_ins->inputs(); // A, B
-        inputs.push_back(gemm2_ins->inputs().back()); // B1
-        mpm.get_module().replace_instruction(
-            ins, ck_gemm_softmax_gemm{gemm2_ins->get_operator()}, inputs);
-    }
-};
 
 struct find_ck_gemm_pointwise
 {
     // Find a gemm followed by a pointwise operation.
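Note on the `k <= 2048` gate restored above: the predicate bails out on large contraction dimensions because CK's pre-tuned kernels perform poorly there. A minimal sketch of how K can be read off the first gemm operand, assuming a standard `dot` whose A input is `{..., M, K}` (the helper name is invented; the real computation of `k` lives in the elided matcher body):

```cpp
#include <migraphx/shape.hpp>

// Hedged sketch, not the matcher's exact code: K is the last dimension of A.
static bool k_small_enough_for_ck(const migraphx::shape& a_shape)
{
    const auto k = a_shape.lens().back(); // contraction dimension
    return k <= 2048;                     // mirrors the coarse-grained cutoff above
}
```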
@@ -231,6 +203,74 @@ struct find_ck_gemm
     }
 };
+
+static bool is_mul_module(const module& m)
+{
+    for(auto& ins : m)
+    {
+        if(starts_with(ins.name(), "@"))
+            continue;
+        if(contains({"multibroadcast", "contiguous"}, ins.name()))
+            continue;
+        if(ins.name() == "pointwise")
+        {
+            return is_mul_module(*ins.module_inputs().front());
+        }
+        else if(ins.name() == "mul")
+        {
+            return true;
+        }
+    }
+    return false;
+}
+
+struct find_ck_gemm_softmax_gemm
+{
+    auto matcher() const
+    {
+        auto gemm1 =
+            match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1")));
+        auto mul = match::name("pointwise")(match::any_of[match::inputs()](gemm1)).bind("scale");
+        auto softmax = match::name("softmax")(match::any_of[match::inputs()](mul)).bind("softmax");
+        return match::name("dot")(is_ck_gemm().bind("gemm2"))(
+            match::any_of[match::inputs()](softmax));
+    }
+
+    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
+    {
+        auto ins = r.result;
+        auto gemm2_ins = r.instructions["gemm2"];
+        auto gemm1_ins = r.instructions["gemm1"];
+        auto scale_ins = r.instructions["scale"];
+        if(scale_ins->module_inputs().size() != 1 or
+           not is_mul_module(*scale_ins->module_inputs().front()))
+            return;
+        if(not ck_gemm_softmax_gemm::is_ck_supported_type(gemm1_ins->get_shape().type()))
+            return;
+        double scale = 1.0;
+        for(auto& in : scale_ins->inputs())
+        {
+            if(in->can_eval())
+            {
+                in->get_literal().visit([&](const auto s) {
+                    if(std::all_of(s.begin() + 1, s.end(), [&](auto v) {
+                           return float_equal(v, s.front());
+                       }))
+                        scale = s.front();
+                });
+            }
+        }
+        auto inputs = gemm1_ins->inputs(); // A, B
+        inputs.push_back(gemm2_ins->inputs().back()); // B1
+        mpm.get_module().replace_instruction(
+            ins, ck_gemm_softmax_gemm{gemm2_ins->get_operator(), scale}, inputs);
+    }
+};
 
 } // namespace
 
 void fuse_ck::apply(module_pass_manager& mpm) const
...
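The new `is_mul_module` helper decides whether a `pointwise` module is really a scale: it skips `@`-prefixed instructions (parameters, returns) plus `multibroadcast`/`contiguous`, recurses into nested `pointwise` calls, and reports whether it finds a `mul`. A hedged illustration of a module it accepts; the names and shapes here are invented for the example:

```cpp
#include <migraphx/make_op.hpp>
#include <migraphx/module.hpp>
#include <migraphx/shape.hpp>

// Builds a pointwise body of the form x0 * x1 -- the scale pattern the matcher wants.
migraphx::module make_scale_module()
{
    migraphx::module m("scale_mod");
    migraphx::shape s{migraphx::shape::half_type, {64, 64}};
    auto x0 = m.add_parameter("x0", s);
    auto x1 = m.add_parameter("x1", s);
    m.add_return({m.add_instruction(migraphx::make_op("mul"), x0, x1)});
    return m;
}
// is_mul_module(make_scale_module()) returns true; a body built around "add"
// instead of "mul" would make find_ck_gemm_softmax_gemm::apply bail out.
```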
@@ -58,6 +58,8 @@ static const char* const ck_gemm_softmax_gemm_kernel = R"__migraphx__(
 #include <migraphx/kernels/ck_gemm_softmax_gemm.hpp>
 #include <migraphx/kernels/pointwise.hpp>
 #include <migraphx/kernels/ops.hpp>
+#include <migraphx/kernels/integral_constant.hpp>
+#include <migraphx/kernels/generic_constant.hpp>
 #include <${include}>
 
 namespace migraphx {
@@ -69,7 +71,8 @@ extern "C" {
 MIGRAPHX_GLOBAL void ${kernel}(${params})
 {
     transform_args(make_tensors(), rotate_last())(${args})([](auto... xs) {
-        ck_gemm_softmax_gemm<${solution}, ${blocks_per_batch}>(xs...);
+        auto settings = make_ck_gemm_softmax_gemm_settings(MIGRAPHX_MAKE_CONSTANT(float{SCALE}));
+        ck_gemm_softmax_gemm<${solution}, ${blocks_per_batch}>(settings, xs...);
     });
 }
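The kernel template now receives the softmax scale as the `SCALE` preprocessor define (injected by `compile_op` below) instead of a runtime argument, and `MIGRAPHX_MAKE_CONSTANT` lifts it into a compile-time value. A simplified model of that mechanism, assuming the real macro in `generic_constant.hpp` works along these lines; compile with C++20:

```cpp
// Simplified stand-in for MIGRAPHX_MAKE_CONSTANT (the real macro may differ):
// a stateless lambda bakes the expression into a type, so the value is known
// at compile time and costs no kernel argument.
template <class F>
struct generic_constant
{
    static constexpr auto value = F{}();
    constexpr operator decltype(value)() const { return value; }
};

#define MAKE_CONSTANT(x) generic_constant<decltype([] { return x; })>{}

#define SCALE 0.125 // stands in for the -DSCALE=... flag added by compile_op
constexpr float scale = MAKE_CONSTANT(float{SCALE});
static_assert(scale == 0.125f, "the scale is available at compile time");
```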
@@ -158,6 +161,7 @@ static std::size_t get_tuning_for(const std::vector<shape>& inputs)
         std::cout << " " << inputs[0] << std::endl;
         std::cout << " " << inputs[1] << std::endl;
         std::cout << " " << inputs[2] << std::endl;
+        std::cout << " " << inputs[3] << std::endl;
     }
     auto it = std::find_if(
         tuning.begin(), tuning.end(), [&](const auto& p) { return p.first == inputs; });
@@ -167,6 +171,7 @@ static std::size_t get_tuning_for(const std::vector<shape>& inputs)
         std::cout << " " << inputs[0] << std::endl;
         std::cout << " " << inputs[1] << std::endl;
         std::cout << " " << inputs[2] << std::endl;
+        std::cout << " " << inputs[3] << std::endl;
         std::vector<std::pair<float, std::size_t>> w;
         std::transform(tuning.begin(), tuning.end(), std::back_inserter(w), [&](const auto& p) {
             if(inputs.size() < 3 or p.first.size() < 3)
@@ -181,7 +186,7 @@ static std::size_t get_tuning_for(const std::vector<shape>& inputs)
             return std::make_pair(avg_distance, p.second);
         });
         std::sort(w.begin(), w.end());
-        std::size_t default_value = 4;
+        std::size_t default_value = 5;
         if(not w.empty())
             default_value = w.front().second;
         auto tuning_val = value_of(MIGRAPHX_CK_TUNING_VALUE{}, default_value);
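`get_tuning_for` now keys on all four shapes (A, B, B1, output) to match the four-input fused op, and the hard-coded fallback index moves from 4 to 5; `MIGRAPHX_CK_TUNING_VALUE` still overrides whatever the lookup produces. A stripped-down sketch of the lookup strategy only, exact match first and then a fallback; the nearest-shape scoring shown above is omitted, and `lookup_tuning` is an invented name:

```cpp
#include <algorithm>
#include <cstddef>
#include <utility>
#include <vector>

// Exact-match lookup into a (shapes -> solution index) table with a fallback.
template <class Key>
std::size_t lookup_tuning(const std::vector<std::pair<Key, std::size_t>>& tuning,
                          const Key& inputs,
                          std::size_t fallback)
{
    auto it = std::find_if(
        tuning.begin(), tuning.end(), [&](const auto& p) { return p.first == inputs; });
    return it == tuning.end() ? fallback : it->second;
}
```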
@@ -322,12 +327,8 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
         const auto b_type = get_type(b_shape);
         const auto b1_type = get_type(b1_shape);
         const auto c_type = get_type(c_shape);
-        const auto scale = 1.0f;
 
         std::string ck_passthrough = "ck_passthrough";
-        std::string cde_op = ck_passthrough;
-
-        /// update params after adding to jitlib
         return ck::host::device_batched_gemm_softmax_gemm::Problem{m,
                                                                    n,
                                                                    k,
@@ -343,19 +344,18 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
                                                                    ck_passthrough,
                                                                    ck_passthrough,
                                                                    ck_passthrough,
-                                                                   ck_passthrough,
-                                                                   scale};
+                                                                   ck_passthrough};
     }
 
     operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
     {
         const auto& a_shape = inputs[0];
         const auto& b_shape = inputs[1];
+        const auto& b1_shape = inputs[2];
         const auto& c_shape = inputs.back();
-        /// update for 4-arg lookup?
         auto tuning_value = v.get("tuning_value", 4);
         if(not v.contains("tuning_value"))
-            tuning_value = get_tuning_for({a_shape, b_shape, c_shape});
+            tuning_value = get_tuning_for({a_shape, b_shape, b1_shape, c_shape});
         auto batch_count = get_batch_count(c_shape);
         auto problem = create_problem(inputs, v);
@@ -386,6 +386,11 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
         if(v.get("check", false) or enabled(MIGRAPHX_CK_DEBUG{}))
             options.params += " -DMIGRAPHX_CK_CHECK=1";
 
+        // scale
+        assert(v.contains("scale"));
+        auto scale = v.at("scale").to<float>();
+        options.params += " -DSCALE=" + std::to_string(scale);
+
         auto src = interpolate_string(ck_gemm_softmax_gemm_kernel,
                                       {{"solution", template_str},
                                        {"include", include_header},
@@ -394,7 +399,7 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
                                        {"blocks_per_batch", to_string(blocks_per_batch)},
                                        {"preamble", v.get("preamble", std::string{})},
                                        {"kernel", options.kernel_name}});
 
         return compile_hip_code_object(src, options);
     }
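With the hard-coded `scale = 1.0f` gone from `create_problem`, the only source of the scale is the op's `value`: `compile_op` asserts it is present and forwards it as a define. A small self-contained check of the flag formatting (`scale_flag` is a name invented here; note that `std::to_string(float)` formats like `%f`, i.e. six decimal places):

```cpp
#include <cassert>
#include <string>

// Mirrors the flag string built in compile_op, shown standalone.
static std::string scale_flag(float scale)
{
    return " -DSCALE=" + std::to_string(scale);
}

int main()
{
    assert(scale_flag(0.125f) == " -DSCALE=0.125000");
}
```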
...
@@ -44,8 +44,20 @@ template <class Tensor>
 using ck_transposeb = decltype(make_shape(ck_transposeb_dims(get_shape_c<Tensor>{}.lens),
                                           ck_transposeb_dims(get_shape_c<Tensor>{}.strides)));
 
-template <class G, class C, class A, class B, class B1>
-__device__ void ck_gemm_softmax_gemm_matrix(C c, A a, B b, B1 b1)
+template <class T>
+struct ck_gemm_softmax_gemm_settings
+{
+    T scale{};
+};
+
+template <class... Ts>
+constexpr ck_gemm_softmax_gemm_settings<Ts...> make_ck_gemm_softmax_gemm_settings(Ts... xs)
+{
+    return {xs...};
+}
+
+template <class G, class C, class A, class B, class B1, class Settings>
+__device__ void ck_gemm_softmax_gemm_matrix(C c, A a, B b, B1 b1, Settings s)
 {
     constexpr auto desc = G::make_descriptor(to_ck_tensor<A>(),
                                              to_ck_tensor<ck_transposeb<B>>(),
@@ -53,19 +65,20 @@ __device__ void ck_gemm_softmax_gemm_matrix(C c, A a, B b, B1 b1)
                                              to_ck_tensor<C>());
 
     static_assert(desc.IsValid(), "Invalid ck gemm.");
 
+    const float scale = s.scale;
     G::Run(desc,
+           scale,
            to_ck_const_pointer(a.data()),
            to_ck_const_pointer(b.data()),
            to_ck_const_pointer(b1.data()),
            to_ck_pointer(c.data()));
 }
 
-template <class G, index_int BlocksPerBatch, class... Ts>
-__device__ void ck_gemm_softmax_gemm(Ts... xs)
+template <class G, index_int BlocksPerBatch, class... Ts, class Settings>
+__device__ void ck_gemm_softmax_gemm(Settings s, Ts... xs)
 {
     gemm_batch_args(make_index(), _c<BlocksPerBatch>, xs...)(
-        [](auto... ys) { ck_gemm_softmax_gemm_matrix<G>(ys...); });
+        [&](auto... ys) { ck_gemm_softmax_gemm_matrix<G>(ys..., s); });
 }
 
 } // namespace migraphx
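The device entry point now takes the settings aggregate ahead of the tensor pack and forwards it into each per-batch `ck_gemm_softmax_gemm_matrix` call, which hands the scale to `G::Run`. A host-side illustration of the aggregate, re-declared so the snippet stands alone:

```cpp
// Same shape as the device-side helpers above, re-declared for a standalone check.
template <class T>
struct ck_gemm_softmax_gemm_settings
{
    T scale{};
};

template <class... Ts>
constexpr ck_gemm_softmax_gemm_settings<Ts...> make_ck_gemm_softmax_gemm_settings(Ts... xs)
{
    return {xs...};
}

int main()
{
    constexpr auto s = make_ck_gemm_softmax_gemm_settings(0.125f);
    static_assert(s.scale == 0.125f, "the settings carry the softmax scale");
}
```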
...