"vscode:/vscode.git/clone" did not exist on "389bc830a3199fbc74407e152edb20e8b059869c"
Commit 37939805 authored by Alan Turner's avatar Alan Turner
Browse files

Move fuse_gsg to fuse_ck and fix bugs

parent 0a463c1e
@@ -76,11 +76,12 @@ MIGRAPHX_REGISTER_OP(ck_gemm);
struct ck_gemm_softmax_gemm
{
operation op = make_op("dot");
double scale = 1.0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.op, "op"));
return pack(f(self.op, "op"), f(self.scale, "scale"));
}
std::string name() const { return "gpu::ck_gemm_softmax_gemm"; }
@@ -91,7 +92,7 @@ struct ck_gemm_softmax_gemm
MIGRAPHX_THROW("Invalid shape for ck_gemm_softmax_gemm");
}
shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>&) const
{
check_shapes{inputs, *this}.same_ndims();
if(inputs.size() < 2)
@@ -136,38 +137,9 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
// Skipping GEMMs with a K dimension greater than 2048 is a coarse-grained strategy
// to avoid poor-performing GEMM kernels from CK
// To-do: Investigate a more precise strategy
return true; // k <= 2048;
return k <= 2048;
}
struct find_ck_gemm_softmax_gemm
{
auto matcher() const
{
auto gemm1 =
match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1")));
auto mul = match::name("mul")(match::any_of[match::inputs()](gemm1)).bind("scale");
auto softmax = match::name("softmax")(match::any_of[match::inputs()](mul)).bind("softmax");
return match::name("dot")(is_ck_gemm().bind("gemm2"))(
match::any_of[match::inputs()](softmax));
}
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto ins = r.result;
auto gemm2_ins = r.instructions["gemm2"];
auto gemm1_ins = r.instructions["gemm1"];
// if (not ck_gemm_softmax_gemm::is_ck_supported_type(gemm1_ins->get_shape().type()))
// return;
auto inputs = gemm1_ins->inputs(); // A, B
inputs.push_back(gemm2_ins->inputs().back()); // B1
mpm.get_module().replace_instruction(
ins, ck_gemm_softmax_gemm{gemm2_ins->get_operator()}, inputs);
}
};
struct find_ck_gemm_pointwise
{
// Find a gemm followed by a pointwise operation.
@@ -231,6 +203,74 @@ struct find_ck_gemm
}
};
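// Returns true if a pointwise module reduces to a single multiply once
// multibroadcast/contiguous plumbing is skipped, recursing into nested
// pointwise modules. Used to confirm that the matched "scale" instruction
// really is a scalar multiplication before fusing.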
static bool is_mul_module(const module& m)
{
for(auto& ins : m)
{
if(starts_with(ins.name(), "@"))
continue;
if(contains({"multibroadcast", "contiguous"}, ins.name()))
continue;
if(ins.name() == "pointwise")
{
return is_mul_module(*ins.module_inputs().front());
}
else if(ins.name() == "mul")
{
return true;
}
}
return false;
}
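// Fuses the attention-style pattern
//   dot -> pointwise(mul by scale) -> softmax -> dot
// into a single gpu::ck_gemm_softmax_gemm operation, folding the scalar
// scale into the fused op's attributes.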
struct find_ck_gemm_softmax_gemm
{
auto matcher() const
{
auto gemm1 =
match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1")));
auto mul = match::name("pointwise")(match::any_of[match::inputs()](gemm1)).bind("scale");
auto softmax = match::name("softmax")(match::any_of[match::inputs()](mul)).bind("softmax");
return match::name("dot")(is_ck_gemm().bind("gemm2"))(
match::any_of[match::inputs()](softmax));
}
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto ins = r.result;
auto gemm2_ins = r.instructions["gemm2"];
auto gemm1_ins = r.instructions["gemm1"];
auto scale_ins = r.instructions["scale"];
if (scale_ins->module_inputs().size() != 1 or not is_mul_module(*scale_ins->module_inputs().front()))
return;
if (not ck_gemm_softmax_gemm::is_ck_supported_type(gemm1_ins->get_shape().type()))
return;
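// Recover the scalar scale from the pointwise instruction's literal input:
// if a literal input evaluates to a tensor whose elements are all equal,
// that common value becomes the scale; otherwise the default of 1.0 is kept.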
double scale = 1.0;
for (auto& in: scale_ins->inputs())
{
if (in->can_eval())
{
in->get_literal().visit([&](const auto s) {
if (std::all_of(
s.begin() + 1, s.end(), [&](auto v) { return float_equal(v, s.front()); }))
scale = s.front();
else
return;
});
}
}
auto inputs = gemm1_ins->inputs(); // A, B
inputs.push_back(gemm2_ins->inputs().back()); // B1
mpm.get_module().replace_instruction(
ins, ck_gemm_softmax_gemm{gemm2_ins->get_operator(), scale}, inputs);
}
};
} // namespace
void fuse_ck::apply(module_pass_manager& mpm) const
@@ -58,6 +58,8 @@ static const char* const ck_gemm_softmax_gemm_kernel = R"__migraphx__(
#include <migraphx/kernels/ck_gemm_softmax_gemm.hpp>
#include <migraphx/kernels/pointwise.hpp>
#include <migraphx/kernels/ops.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
#include <${include}>
namespace migraphx {
@@ -69,7 +71,8 @@ extern "C" {
MIGRAPHX_GLOBAL void ${kernel}(${params})
{
transform_args(make_tensors(), rotate_last())(${args})([](auto... xs) {
ck_gemm_softmax_gemm<${solution}, ${blocks_per_batch}>(xs...);
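// SCALE is injected at compile time via -DSCALE (set in compile_op below);
// MIGRAPHX_MAKE_CONSTANT wraps it so the scale reaches the kernel as a
// compile-time constant rather than a runtime argument.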
auto settings = make_ck_gemm_softmax_gemm_settings(MIGRAPHX_MAKE_CONSTANT(float{SCALE}));
ck_gemm_softmax_gemm<${solution}, ${blocks_per_batch}>(settings, xs...);
});
}
@@ -158,6 +161,7 @@ static std::size_t get_tuning_for(const std::vector<shape>& inputs)
std::cout << " " << inputs[0] << std::endl;
std::cout << " " << inputs[1] << std::endl;
std::cout << " " << inputs[2] << std::endl;
std::cout << " " << inputs[3] << std::endl;
}
auto it = std::find_if(
tuning.begin(), tuning.end(), [&](const auto& p) { return p.first == inputs; });
@@ -167,6 +171,7 @@ static std::size_t get_tuning_for(const std::vector<shape>& inputs)
std::cout << " " << inputs[0] << std::endl;
std::cout << " " << inputs[1] << std::endl;
std::cout << " " << inputs[2] << std::endl;
std::cout << " " << inputs[3] << std::endl;
std::vector<std::pair<float, std::size_t>> w;
std::transform(tuning.begin(), tuning.end(), std::back_inserter(w), [&](const auto& p) {
if(inputs.size() < 3 or p.first.size() < 3)
@@ -181,7 +186,7 @@ static std::size_t get_tuning_for(const std::vector<shape>& inputs)
return std::make_pair(avg_distance, p.second);
});
std::sort(w.begin(), w.end());
std::size_t default_value = 4;
std::size_t default_value = 5;
if(not w.empty())
default_value = w.front().second;
auto tuning_val = value_of(MIGRAPHX_CK_TUNING_VALUE{}, default_value);
@@ -322,12 +327,8 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
const auto b_type = get_type(b_shape);
const auto b1_type = get_type(b1_shape);
const auto c_type = get_type(c_shape);
const auto scale = 1.0f;
std::string ck_passthrough = "ck_passthrough";
std::string cde_op = ck_passthrough;
/// update params after adding to jitlib
return ck::host::device_batched_gemm_softmax_gemm::Problem{m,
n,
k,
@@ -343,19 +344,18 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
ck_passthrough,
ck_passthrough,
ck_passthrough,
ck_passthrough,
scale};
ck_passthrough};
}
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
const auto& a_shape = inputs[0];
const auto& b_shape = inputs[1];
const auto& b1_shape = inputs[2];
const auto& c_shape = inputs.back();
/// update for 4-arg lookup?
auto tuning_value = v.get("tuning_value", 4);
if(not v.contains("tuning_value"))
tuning_value = get_tuning_for({a_shape, b_shape, c_shape});
tuning_value = get_tuning_for({a_shape, b_shape, b1_shape, c_shape});
auto batch_count = get_batch_count(c_shape);
auto problem = create_problem(inputs, v);
@@ -386,6 +386,11 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
if(v.get("check", false) or enabled(MIGRAPHX_CK_DEBUG{}))
options.params += " -DMIGRAPHX_CK_CHECK=1";
// Forward the scale to the kernel as a compile-time definition,
// consumed as SCALE in the kernel source template above
assert(v.contains("scale"));
auto scale = v.at("scale").to<float>();
options.params += " -DSCALE=" + std::to_string(scale);
auto src = interpolate_string(ck_gemm_softmax_gemm_kernel,
{{"solution", template_str},
{"include", include_header},
@@ -394,7 +399,7 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
{"blocks_per_batch", to_string(blocks_per_batch)},
{"preamble", v.get("preamble", std::string{})},
{"kernel", options.kernel_name}});
return compile_hip_code_object(src, options);
}
@@ -44,8 +44,20 @@ template <class Tensor>
using ck_transposeb = decltype(make_shape(ck_transposeb_dims(get_shape_c<Tensor>{}.lens),
ck_transposeb_dims(get_shape_c<Tensor>{}.strides)));
template <class G, class C, class A, class B, class B1>
__device__ void ck_gemm_softmax_gemm_matrix(C c, A a, B b, B1 b1)
template <class T>
struct ck_gemm_softmax_gemm_settings
{
T scale{};
};
template <class... Ts>
constexpr ck_gemm_softmax_gemm_settings<Ts...> make_ck_gemm_softmax_gemm_settings(Ts... xs)
{
return {xs...};
}
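// Example (mirrors the generated kernel source above): bundle the
// compile-time scale and pass it ahead of the tensor arguments, e.g.
//   auto s = make_ck_gemm_softmax_gemm_settings(MIGRAPHX_MAKE_CONSTANT(float{SCALE}));
//   ck_gemm_softmax_gemm<Solution, BlocksPerBatch>(s, xs...);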
template <class G, class C, class A, class B, class B1, class Settings>
__device__ void ck_gemm_softmax_gemm_matrix(C c, A a, B b, B1 b1, Settings s)
{
constexpr auto desc = G::make_descriptor(to_ck_tensor<A>(),
to_ck_tensor<ck_transposeb<B>>(),
@@ -53,19 +65,20 @@ __device__ void ck_gemm_softmax_gemm_matrix(C c, A a, B b, B1 b1)
to_ck_tensor<C>());
static_assert(desc.IsValid(), "Invalid ck gemm.");
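// Read the softmax scale from the settings bundle; CK applies it to the
// first gemm's result before the softmax, i.e. softmax(scale * A*B) * B1.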
const float scale = s.scale;
G::Run(desc,
scale,
to_ck_const_pointer(a.data()),
to_ck_const_pointer(b.data()),
to_ck_const_pointer(b1.data()),
to_ck_pointer(c.data()));
}
template <class G, index_int BlocksPerBatch, class... Ts>
__device__ void ck_gemm_softmax_gemm(Ts... xs)
template <class G, index_int BlocksPerBatch, class... Ts, class Settings>
__device__ void ck_gemm_softmax_gemm(Settings s, Ts... xs)
{
gemm_batch_args(make_index(), _c<BlocksPerBatch>, xs...)(
[](auto... ys) { ck_gemm_softmax_gemm_matrix<G>(ys...); });
[&](auto... ys) { ck_gemm_softmax_gemm_matrix<G>(ys..., s); });
}
} // namespace migraphx