Commit 0a463c1e authored by Alan Turner's avatar Alan Turner
Browse files

Formatting

parent 8ab0b22e
...@@ -73,7 +73,6 @@ struct ck_gemm ...@@ -73,7 +73,6 @@ struct ck_gemm
}; };
MIGRAPHX_REGISTER_OP(ck_gemm); MIGRAPHX_REGISTER_OP(ck_gemm);
struct ck_gemm_softmax_gemm struct ck_gemm_softmax_gemm
{ {
operation op = make_op("dot"); operation op = make_op("dot");
...@@ -107,10 +106,7 @@ struct ck_gemm_softmax_gemm ...@@ -107,10 +106,7 @@ struct ck_gemm_softmax_gemm
return op.compute_shape({op.compute_shape({a, b}), b1}); return op.compute_shape({op.compute_shape({a, b}), b1});
} }
static bool is_ck_supported_type(shape::type_t t) static bool is_ck_supported_type(shape::type_t t) { return contains({shape::half_type}, t); }
{
return contains({shape::half_type}, t);
}
}; };
MIGRAPHX_REGISTER_OP(ck_gemm_softmax_gemm); MIGRAPHX_REGISTER_OP(ck_gemm_softmax_gemm);
...@@ -140,7 +136,7 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins) ...@@ -140,7 +136,7 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
// Skipping GEMMs with a K dimension greater than 2048 is a coarse-grained strategy // Skipping GEMMs with a K dimension greater than 2048 is a coarse-grained strategy
// to avoid poor-performing GEMM kernels from CK // to avoid poor-performing GEMM kernels from CK
// To-do: Investigate a more precise strategy // To-do: Investigate a more precise strategy
return true;//k <= 2048; return true; // k <= 2048;
} }
struct find_ck_gemm_softmax_gemm struct find_ck_gemm_softmax_gemm
...@@ -163,7 +159,7 @@ struct find_ck_gemm_softmax_gemm ...@@ -163,7 +159,7 @@ struct find_ck_gemm_softmax_gemm
// if (not ck_gemm_softmax_gemm::is_ck_supported_type(gemm1_ins->get_shape().type())) // if (not ck_gemm_softmax_gemm::is_ck_supported_type(gemm1_ins->get_shape().type()))
// return; // return;
auto inputs = gemm1_ins->inputs(); // A, B auto inputs = gemm1_ins->inputs(); // A, B
inputs.push_back(gemm2_ins->inputs().back()); // B1 inputs.push_back(gemm2_ins->inputs().back()); // B1
......
...@@ -266,7 +266,10 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler> ...@@ -266,7 +266,10 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
s = shape{s.type(), {m1, m2}}; s = shape{s.type(), {m1, m2}};
} }
std::vector<std::string> names() const { return {"ck_gemm_softmax_gemm", "gpu::ck_gemm_softmax_gemm"}; } std::vector<std::string> names() const
{
return {"ck_gemm_softmax_gemm", "gpu::ck_gemm_softmax_gemm"};
}
static bool standard_batch(const shape& s) static bool standard_batch(const shape& s)
{ {
...@@ -293,13 +296,13 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler> ...@@ -293,13 +296,13 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
b_strides.begin(), b_strides.end() - 2, [](auto stride) { return stride == 0; }); b_strides.begin(), b_strides.end() - 2, [](auto stride) { return stride == 0; });
} }
ck::host::device_batched_gemm_softmax_gemm::Problem create_problem(const std::vector<shape>& inputs, ck::host::device_batched_gemm_softmax_gemm::Problem
const value& v) const create_problem(const std::vector<shape>& inputs, const value& v) const
{ {
const auto& a_shape = inputs[0]; const auto& a_shape = inputs[0];
const auto& b_shape = inputs[1]; const auto& b_shape = inputs[1];
const auto& b1_shape = inputs[2]; const auto& b1_shape = inputs[2];
const auto& c_shape = inputs.back(); const auto& c_shape = inputs.back();
// cppcheck-suppress unreadVariable // cppcheck-suppress unreadVariable
auto rank = a_shape.ndim(); auto rank = a_shape.ndim();
...@@ -311,37 +314,37 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler> ...@@ -311,37 +314,37 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
auto k = a_shape.lens().back(); auto k = a_shape.lens().back();
auto o = c_shape.lens().back(); auto o = c_shape.lens().back();
const bool trans_a = transposed_matrix(a_shape); const bool trans_a = transposed_matrix(a_shape);
const bool trans_b = transposed_matrix(b_shape); const bool trans_b = transposed_matrix(b_shape);
const bool trans_b1 = transposed_matrix(b1_shape); const bool trans_b1 = transposed_matrix(b1_shape);
const bool trans_c = transposed_matrix(c_shape); const bool trans_c = transposed_matrix(c_shape);
const auto a_type = get_type(a_shape); const auto a_type = get_type(a_shape);
const auto b_type = get_type(b_shape); const auto b_type = get_type(b_shape);
const auto b1_type = get_type(b1_shape); const auto b1_type = get_type(b1_shape);
const auto c_type = get_type(c_shape); const auto c_type = get_type(c_shape);
const auto scale = 1.0f; const auto scale = 1.0f;
std::string ck_passthrough = "ck_passthrough"; std::string ck_passthrough = "ck_passthrough";
std::string cde_op = ck_passthrough; std::string cde_op = ck_passthrough;
/// update params after adding to jitlib /// update params after adding to jitlib
return ck::host::device_batched_gemm_softmax_gemm::Problem{m, return ck::host::device_batched_gemm_softmax_gemm::Problem{m,
n, n,
k, k,
o, o,
trans_a, trans_a,
trans_b, trans_b,
trans_b1, trans_b1,
trans_c, trans_c,
a_type, a_type,
b_type, b_type,
b1_type, b1_type,
c_type, c_type,
ck_passthrough, ck_passthrough,
ck_passthrough, ck_passthrough,
ck_passthrough, ck_passthrough,
ck_passthrough, ck_passthrough,
scale}; scale};
} }
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
...@@ -350,7 +353,7 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler> ...@@ -350,7 +353,7 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
const auto& b_shape = inputs[1]; const auto& b_shape = inputs[1];
const auto& c_shape = inputs.back(); const auto& c_shape = inputs.back();
/// update for 4-arg lookup? /// update for 4-arg lookup?
auto tuning_value = v.get("tuning_value", 4); auto tuning_value = v.get("tuning_value", 4);
if(not v.contains("tuning_value")) if(not v.contains("tuning_value"))
tuning_value = get_tuning_for({a_shape, b_shape, c_shape}); tuning_value = get_tuning_for({a_shape, b_shape, c_shape});
auto batch_count = get_batch_count(c_shape); auto batch_count = get_batch_count(c_shape);
...@@ -403,7 +406,8 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler> ...@@ -403,7 +406,8 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
{ {
auto* pm = ins->module_inputs().front(); auto* pm = ins->module_inputs().front();
v["preamble"] = generate_pointwise(*pm, "post_ck_gemm_softmax_gemm_function") + v["preamble"] = generate_pointwise(*pm, "post_ck_gemm_softmax_gemm_function") +
"\nMIGRAPHX_LIFT_CLASS(post_ck_gemm_softmax_gemm, post_ck_gemm_softmax_gemm_function);"; "\nMIGRAPHX_LIFT_CLASS(post_ck_gemm_softmax_gemm, "
"post_ck_gemm_softmax_gemm_function);";
v["post"] = "ck_function_adaptor<post_ck_gemm_softmax_gemm>"; v["post"] = "ck_function_adaptor<post_ck_gemm_softmax_gemm>";
v["kernel"] = "ck_gemm_softmax_gemm_" + generate_name_from_ops(*pm) + "_kernel"; v["kernel"] = "ck_gemm_softmax_gemm_" + generate_name_from_ops(*pm) + "_kernel";
} }
...@@ -423,8 +427,8 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler> ...@@ -423,8 +427,8 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
{ {
std::vector<shape> gemm_shapes{ std::vector<shape> gemm_shapes{
shapes[0], shapes[1], shapes.back().with_type(shapes[0].type())}; shapes[0], shapes[1], shapes.back().with_type(shapes[0].type())};
std::cout << "gpu::ck_gemm_softmax_gemm: " << to_json_string(to_value(gemm_shapes)) std::cout << "gpu::ck_gemm_softmax_gemm: "
<< std::endl; << to_json_string(to_value(gemm_shapes)) << std::endl;
} }
m.replace_instruction(ins2, code_object, ins2->inputs()); m.replace_instruction(ins2, code_object, ins2->inputs());
}}; }};
......
...@@ -36,10 +36,10 @@ struct ck_gemm_softmax_gemm : verify_program<ck_gemm_softmax_gemm> ...@@ -36,10 +36,10 @@ struct ck_gemm_softmax_gemm : verify_program<ck_gemm_softmax_gemm>
migraphx::shape m1_shape{migraphx::shape::half_type, {1, 12, 256, 256}}; migraphx::shape m1_shape{migraphx::shape::half_type, {1, 12, 256, 256}};
migraphx::shape m2_shape{migraphx::shape::half_type, {1, 12, 256, 256}}; migraphx::shape m2_shape{migraphx::shape::half_type, {1, 12, 256, 256}};
auto m2_elements = 1 * 12 * 256 * 256; auto m2_elements = 1 * 12 * 256 * 256;
auto a = mm->add_parameter("1", m1_shape); auto a = mm->add_parameter("1", m1_shape);
auto b = mm->add_parameter("2", m1_shape); auto b = mm->add_parameter("2", m1_shape);
auto b1 = mm->add_parameter("3", m1_shape); auto b1 = mm->add_parameter("3", m1_shape);
auto c = mm->add_parameter("4", m1_shape); auto c = mm->add_parameter("4", m1_shape);
std::vector<float> eights(m2_elements, 0.125); std::vector<float> eights(m2_elements, 0.125);
auto eight = mm->add_literal(migraphx::literal{m2_shape, eights}); auto eight = mm->add_literal(migraphx::literal{m2_shape, eights});
std::vector<float> zeros(m2_elements, 0); std::vector<float> zeros(m2_elements, 0);
...@@ -48,9 +48,9 @@ struct ck_gemm_softmax_gemm : verify_program<ck_gemm_softmax_gemm> ...@@ -48,9 +48,9 @@ struct ck_gemm_softmax_gemm : verify_program<ck_gemm_softmax_gemm>
auto one = mm->add_literal(migraphx::literal{m2_shape, ones}); auto one = mm->add_literal(migraphx::literal{m2_shape, ones});
b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b); b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b);
auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b); auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b);
auto scale = mm->add_instruction(migraphx::make_op("mul"), gemm1, eight); auto scale = mm->add_instruction(migraphx::make_op("mul"), gemm1, eight);
auto bias = mm->add_instruction(migraphx::make_op("add"), scale, zero); auto bias = mm->add_instruction(migraphx::make_op("add"), scale, zero);
auto softmax = mm->add_instruction(migraphx::make_op("softmax", {{"axis", -1}}), bias); auto softmax = mm->add_instruction(migraphx::make_op("softmax", {{"axis", -1}}), bias);
mm->add_instruction(migraphx::make_op("dot"), softmax, b1); mm->add_instruction(migraphx::make_op("dot"), softmax, b1);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment