Commit 0a463c1e authored by Alan Turner's avatar Alan Turner
Browse files

Formatting

parent 8ab0b22e
......@@ -73,7 +73,6 @@ struct ck_gemm
};
MIGRAPHX_REGISTER_OP(ck_gemm);
struct ck_gemm_softmax_gemm
{
operation op = make_op("dot");
......@@ -107,10 +106,7 @@ struct ck_gemm_softmax_gemm
return op.compute_shape({op.compute_shape({a, b}), b1});
}
static bool is_ck_supported_type(shape::type_t t)
{
return contains({shape::half_type}, t);
}
static bool is_ck_supported_type(shape::type_t t) { return contains({shape::half_type}, t); }
};
MIGRAPHX_REGISTER_OP(ck_gemm_softmax_gemm);
......@@ -140,7 +136,7 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
// Skipping GEMMs with a K dimension greater than 2048 is a course-grained strategy
// to avoid poor-performing GEMM kernels from CK
// To-do: Investigate a more precise strategy
return true;//k <= 2048;
return true; // k <= 2048;
}
struct find_ck_gemm_softmax_gemm
......
......@@ -266,7 +266,10 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
s = shape{s.type(), {m1, m2}};
}
std::vector<std::string> names() const { return {"ck_gemm_softmax_gemm", "gpu::ck_gemm_softmax_gemm"}; }
std::vector<std::string> names() const
{
return {"ck_gemm_softmax_gemm", "gpu::ck_gemm_softmax_gemm"};
}
static bool standard_batch(const shape& s)
{
......@@ -293,8 +296,8 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
b_strides.begin(), b_strides.end() - 2, [](auto stride) { return stride == 0; });
}
ck::host::device_batched_gemm_softmax_gemm::Problem create_problem(const std::vector<shape>& inputs,
const value& v) const
ck::host::device_batched_gemm_softmax_gemm::Problem
create_problem(const std::vector<shape>& inputs, const value& v) const
{
const auto& a_shape = inputs[0];
const auto& b_shape = inputs[1];
......@@ -403,7 +406,8 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
{
auto* pm = ins->module_inputs().front();
v["preamble"] = generate_pointwise(*pm, "post_ck_gemm_softmax_gemm_function") +
"\nMIGRAPHX_LIFT_CLASS(post_ck_gemm_softmax_gemm, post_ck_gemm_softmax_gemm_function);";
"\nMIGRAPHX_LIFT_CLASS(post_ck_gemm_softmax_gemm, "
"post_ck_gemm_softmax_gemm_function);";
v["post"] = "ck_function_adaptor<post_ck_gemm_softmax_gemm>";
v["kernel"] = "ck_gemm_softmax_gemm_" + generate_name_from_ops(*pm) + "_kernel";
}
......@@ -423,8 +427,8 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
{
std::vector<shape> gemm_shapes{
shapes[0], shapes[1], shapes.back().with_type(shapes[0].type())};
std::cout << "gpu::ck_gemm_softmax_gemm: " << to_json_string(to_value(gemm_shapes))
<< std::endl;
std::cout << "gpu::ck_gemm_softmax_gemm: "
<< to_json_string(to_value(gemm_shapes)) << std::endl;
}
m.replace_instruction(ins2, code_object, ins2->inputs());
}};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment