Commit 0a463c1e authored by Alan Turner's avatar Alan Turner
Browse files

Formatting

parent 8ab0b22e
...@@ -73,7 +73,6 @@ struct ck_gemm ...@@ -73,7 +73,6 @@ struct ck_gemm
}; };
MIGRAPHX_REGISTER_OP(ck_gemm); MIGRAPHX_REGISTER_OP(ck_gemm);
struct ck_gemm_softmax_gemm struct ck_gemm_softmax_gemm
{ {
operation op = make_op("dot"); operation op = make_op("dot");
...@@ -107,10 +106,7 @@ struct ck_gemm_softmax_gemm ...@@ -107,10 +106,7 @@ struct ck_gemm_softmax_gemm
return op.compute_shape({op.compute_shape({a, b}), b1}); return op.compute_shape({op.compute_shape({a, b}), b1});
} }
static bool is_ck_supported_type(shape::type_t t) static bool is_ck_supported_type(shape::type_t t) { return contains({shape::half_type}, t); }
{
return contains({shape::half_type}, t);
}
}; };
MIGRAPHX_REGISTER_OP(ck_gemm_softmax_gemm); MIGRAPHX_REGISTER_OP(ck_gemm_softmax_gemm);
...@@ -140,7 +136,7 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins) ...@@ -140,7 +136,7 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
// Skipping GEMMs with a K dimension greater than 2048 is a course-grained strategy // Skipping GEMMs with a K dimension greater than 2048 is a course-grained strategy
// to avoid poor-performing GEMM kernels from CK // to avoid poor-performing GEMM kernels from CK
// To-do: Investigate a more precise strategy // To-do: Investigate a more precise strategy
return true;//k <= 2048; return true; // k <= 2048;
} }
struct find_ck_gemm_softmax_gemm struct find_ck_gemm_softmax_gemm
......
...@@ -266,7 +266,10 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler> ...@@ -266,7 +266,10 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
s = shape{s.type(), {m1, m2}}; s = shape{s.type(), {m1, m2}};
} }
std::vector<std::string> names() const { return {"ck_gemm_softmax_gemm", "gpu::ck_gemm_softmax_gemm"}; } std::vector<std::string> names() const
{
return {"ck_gemm_softmax_gemm", "gpu::ck_gemm_softmax_gemm"};
}
static bool standard_batch(const shape& s) static bool standard_batch(const shape& s)
{ {
...@@ -293,8 +296,8 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler> ...@@ -293,8 +296,8 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
b_strides.begin(), b_strides.end() - 2, [](auto stride) { return stride == 0; }); b_strides.begin(), b_strides.end() - 2, [](auto stride) { return stride == 0; });
} }
ck::host::device_batched_gemm_softmax_gemm::Problem create_problem(const std::vector<shape>& inputs, ck::host::device_batched_gemm_softmax_gemm::Problem
const value& v) const create_problem(const std::vector<shape>& inputs, const value& v) const
{ {
const auto& a_shape = inputs[0]; const auto& a_shape = inputs[0];
const auto& b_shape = inputs[1]; const auto& b_shape = inputs[1];
...@@ -403,7 +406,8 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler> ...@@ -403,7 +406,8 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
{ {
auto* pm = ins->module_inputs().front(); auto* pm = ins->module_inputs().front();
v["preamble"] = generate_pointwise(*pm, "post_ck_gemm_softmax_gemm_function") + v["preamble"] = generate_pointwise(*pm, "post_ck_gemm_softmax_gemm_function") +
"\nMIGRAPHX_LIFT_CLASS(post_ck_gemm_softmax_gemm, post_ck_gemm_softmax_gemm_function);"; "\nMIGRAPHX_LIFT_CLASS(post_ck_gemm_softmax_gemm, "
"post_ck_gemm_softmax_gemm_function);";
v["post"] = "ck_function_adaptor<post_ck_gemm_softmax_gemm>"; v["post"] = "ck_function_adaptor<post_ck_gemm_softmax_gemm>";
v["kernel"] = "ck_gemm_softmax_gemm_" + generate_name_from_ops(*pm) + "_kernel"; v["kernel"] = "ck_gemm_softmax_gemm_" + generate_name_from_ops(*pm) + "_kernel";
} }
...@@ -423,8 +427,8 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler> ...@@ -423,8 +427,8 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
{ {
std::vector<shape> gemm_shapes{ std::vector<shape> gemm_shapes{
shapes[0], shapes[1], shapes.back().with_type(shapes[0].type())}; shapes[0], shapes[1], shapes.back().with_type(shapes[0].type())};
std::cout << "gpu::ck_gemm_softmax_gemm: " << to_json_string(to_value(gemm_shapes)) std::cout << "gpu::ck_gemm_softmax_gemm: "
<< std::endl; << to_json_string(to_value(gemm_shapes)) << std::endl;
} }
m.replace_instruction(ins2, code_object, ins2->inputs()); m.replace_instruction(ins2, code_object, ins2->inputs());
}}; }};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment