Commit 0a463c1e authored by Alan Turner

Formatting

parent 8ab0b22e
@@ -73,7 +73,6 @@ struct ck_gemm
 };
 MIGRAPHX_REGISTER_OP(ck_gemm);
 
-
 struct ck_gemm_softmax_gemm
 {
     operation op = make_op("dot");
@@ -107,10 +106,7 @@ struct ck_gemm_softmax_gemm
         return op.compute_shape({op.compute_shape({a, b}), b1});
     }
 
-    static bool is_ck_supported_type(shape::type_t t)
-    {
-        return contains({shape::half_type}, t);
-    }
+    static bool is_ck_supported_type(shape::type_t t) { return contains({shape::half_type}, t); }
 };
 MIGRAPHX_REGISTER_OP(ck_gemm_softmax_gemm);
 
@@ -140,7 +136,7 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
     // Skipping GEMMs with a K dimension greater than 2048 is a coarse-grained strategy
     // to avoid poor-performing GEMM kernels from CK
     // To-do: Investigate a more precise strategy
-    return true;//k <= 2048;
+    return true; // k <= 2048;
 }
 
 struct find_ck_gemm_softmax_gemm
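For context on the disabled check above, here is a minimal standalone sketch of the coarse-grained K gate that the comment describes. The helper is hypothetical (not code from this repository); it assumes only that, for an (..., M, K) operand A, K is the last entry of its lengths, matching the a_shape.lens().back() idiom used later in this commit.

// Hypothetical sketch of the commented-out "k <= 2048" gate; not repository code.
#include <cstddef>
#include <vector>

inline bool ck_k_dim_ok(const std::vector<std::size_t>& a_lens)
{
    const std::size_t k = a_lens.back(); // K is the contraction (last) axis of A
    return k <= 2048; // coarse cutoff to skip shapes where CK kernels perform poorly
}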
@@ -163,7 +159,7 @@ struct find_ck_gemm_softmax_gemm
         // if (not ck_gemm_softmax_gemm::is_ck_supported_type(gemm1_ins->get_shape().type()))
         // return;
-        auto inputs = gemm1_ins->inputs(); // A, B
+        auto inputs = gemm1_ins->inputs();            // A, B
         inputs.push_back(gemm2_ins->inputs().back()); // B1
......
@@ -266,7 +266,10 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
         s = shape{s.type(), {m1, m2}};
     }
 
-    std::vector<std::string> names() const { return {"ck_gemm_softmax_gemm", "gpu::ck_gemm_softmax_gemm"}; }
+    std::vector<std::string> names() const
+    {
+        return {"ck_gemm_softmax_gemm", "gpu::ck_gemm_softmax_gemm"};
+    }
 
     static bool standard_batch(const shape& s)
     {
@@ -293,13 +296,13 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
             b_strides.begin(), b_strides.end() - 2, [](auto stride) { return stride == 0; });
     }
 
-    ck::host::device_batched_gemm_softmax_gemm::Problem create_problem(const std::vector<shape>& inputs,
-                                                                       const value& v) const
+    ck::host::device_batched_gemm_softmax_gemm::Problem
+    create_problem(const std::vector<shape>& inputs, const value& v) const
     {
-        const auto& a_shape = inputs[0];
-        const auto& b_shape = inputs[1];
+        const auto& a_shape  = inputs[0];
+        const auto& b_shape  = inputs[1];
         const auto& b1_shape = inputs[2];
-        const auto& c_shape = inputs.back();
+        const auto& c_shape  = inputs.back();
         // cppcheck-suppress unreadVariable
         auto rank = a_shape.ndim();
 
@@ -311,37 +314,37 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
         auto k = a_shape.lens().back();
         auto o = c_shape.lens().back();
 
-        const bool trans_a = transposed_matrix(a_shape);
-        const bool trans_b = transposed_matrix(b_shape);
+        const bool trans_a  = transposed_matrix(a_shape);
+        const bool trans_b  = transposed_matrix(b_shape);
         const bool trans_b1 = transposed_matrix(b1_shape);
-        const bool trans_c = transposed_matrix(c_shape);
-        const auto a_type = get_type(a_shape);
-        const auto b_type = get_type(b_shape);
+        const bool trans_c  = transposed_matrix(c_shape);
+        const auto a_type  = get_type(a_shape);
+        const auto b_type  = get_type(b_shape);
         const auto b1_type = get_type(b1_shape);
-        const auto c_type = get_type(c_shape);
-        const auto scale = 1.0f;
+        const auto c_type  = get_type(c_shape);
+        const auto scale   = 1.0f;
 
         std::string ck_passthrough = "ck_passthrough";
         std::string cde_op = ck_passthrough;
         /// update params after adding to jitlib
         return ck::host::device_batched_gemm_softmax_gemm::Problem{m,
-                                                               n,
-                                                               k,
-                                                               o,
-                                                               trans_a,
-                                                               trans_b,
-                                                               trans_b1,
-                                                               trans_c,
-                                                               a_type,
-                                                               b_type,
-                                                               b1_type,
-                                                               c_type,
-                                                               ck_passthrough,
-                                                               ck_passthrough,
-                                                               ck_passthrough,
-                                                               ck_passthrough,
-                                                               scale};
+                                                                   n,
+                                                                   k,
+                                                                   o,
+                                                                   trans_a,
+                                                                   trans_b,
+                                                                   trans_b1,
+                                                                   trans_c,
+                                                                   a_type,
+                                                                   b_type,
+                                                                   b1_type,
+                                                                   c_type,
+                                                                   ck_passthrough,
+                                                                   ck_passthrough,
+                                                                   ck_passthrough,
+                                                                   ck_passthrough,
+                                                                   scale};
     }
 
     operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
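As a worked example of the dimension bookkeeping in create_problem, take the {1, 12, 256, 256} half_type shapes used by the verify test at the end of this commit. The sketch below is hypothetical standalone code (plain vectors stand in for migraphx::shape::lens()); it shows how m, k, o, and the batch count fall out of those lengths.

// Hypothetical sketch: dimension extraction for A/C of shape {1, 12, 256, 256}.
#include <cassert>
#include <cstddef>
#include <vector>

int main()
{
    const std::vector<std::size_t> a_lens{1, 12, 256, 256}; // A: (batches..., M, K)
    const std::vector<std::size_t> c_lens{1, 12, 256, 256}; // C: (batches..., M, O)

    const auto rank = a_lens.size();
    const auto m    = a_lens[rank - 2]; // rows of A
    const auto k    = a_lens.back();    // contraction dim of the first gemm
    const auto o    = c_lens.back();    // output columns of the second gemm

    // The batch count is the product of all leading (non-matrix) dimensions.
    std::size_t batch_count = 1;
    for(std::size_t i = 0; i < rank - 2; ++i)
        batch_count *= c_lens[i];

    assert(m == 256 and k == 256 and o == 256 and batch_count == 12);
    return 0;
}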
@@ -350,7 +353,7 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
         const auto& b_shape = inputs[1];
         const auto& c_shape = inputs.back();
         /// update for 4-arg lookup?
-        auto tuning_value  = v.get("tuning_value", 4);
+        auto tuning_value = v.get("tuning_value", 4);
         if(not v.contains("tuning_value"))
             tuning_value = get_tuning_for({a_shape, b_shape, c_shape});
         auto batch_count = get_batch_count(c_shape);
@@ -403,7 +406,8 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
     {
         auto* pm = ins->module_inputs().front();
         v["preamble"] = generate_pointwise(*pm, "post_ck_gemm_softmax_gemm_function") +
-                        "\nMIGRAPHX_LIFT_CLASS(post_ck_gemm_softmax_gemm, post_ck_gemm_softmax_gemm_function);";
+                        "\nMIGRAPHX_LIFT_CLASS(post_ck_gemm_softmax_gemm, "
+                        "post_ck_gemm_softmax_gemm_function);";
         v["post"] = "ck_function_adaptor<post_ck_gemm_softmax_gemm>";
         v["kernel"] = "ck_gemm_softmax_gemm_" + generate_name_from_ops(*pm) + "_kernel";
     }
@@ -423,8 +427,8 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
         {
             std::vector<shape> gemm_shapes{
                 shapes[0], shapes[1], shapes.back().with_type(shapes[0].type())};
-            std::cout << "gpu::ck_gemm_softmax_gemm: " << to_json_string(to_value(gemm_shapes))
-                      << std::endl;
+            std::cout << "gpu::ck_gemm_softmax_gemm: "
+                      << to_json_string(to_value(gemm_shapes)) << std::endl;
         }
         m.replace_instruction(ins2, code_object, ins2->inputs());
     }};
......
@@ -36,10 +36,10 @@ struct ck_gemm_softmax_gemm : verify_program<ck_gemm_softmax_gemm>
         migraphx::shape m1_shape{migraphx::shape::half_type, {1, 12, 256, 256}};
         migraphx::shape m2_shape{migraphx::shape::half_type, {1, 12, 256, 256}};
         auto m2_elements = 1 * 12 * 256 * 256;
-        auto a = mm->add_parameter("1", m1_shape);
-        auto b = mm->add_parameter("2", m1_shape);
-        auto b1 = mm->add_parameter("3", m1_shape);
-        auto c = mm->add_parameter("4", m1_shape);
+        auto a  = mm->add_parameter("1", m1_shape);
+        auto b  = mm->add_parameter("2", m1_shape);
+        auto b1 = mm->add_parameter("3", m1_shape);
+        auto c  = mm->add_parameter("4", m1_shape);
         std::vector<float> eights(m2_elements, 0.125);
         auto eight = mm->add_literal(migraphx::literal{m2_shape, eights});
         std::vector<float> zeros(m2_elements, 0);
@@ -48,9 +48,9 @@ struct ck_gemm_softmax_gemm : verify_program<ck_gemm_softmax_gemm>
         auto one = mm->add_literal(migraphx::literal{m2_shape, ones});
 
         b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b);
-        auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b);
-        auto scale = mm->add_instruction(migraphx::make_op("mul"), gemm1, eight);
-        auto bias = mm->add_instruction(migraphx::make_op("add"), scale, zero);
+        auto gemm1   = mm->add_instruction(migraphx::make_op("dot"), a, b);
+        auto scale   = mm->add_instruction(migraphx::make_op("mul"), gemm1, eight);
+        auto bias    = mm->add_instruction(migraphx::make_op("add"), scale, zero);
         auto softmax = mm->add_instruction(migraphx::make_op("softmax", {{"axis", -1}}), bias);
         mm->add_instruction(migraphx::make_op("dot"), softmax, b1);
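The test builds exactly the pattern that find_ck_gemm_softmax_gemm fuses: dot, elementwise scale (0.125) and zero bias, softmax over the last axis, then a second dot. In reference form the fused op computes C = row_softmax(scale * (A * B)) * B1. Below is a minimal single-batch sketch of those semantics; it is hypothetical illustration code (float instead of half_type for clarity, the zero bias folded away), not the CK kernel itself.

// Hypothetical reference for one batch: C = row_softmax(scale * A*B) * B1.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

using matrix = std::vector<std::vector<float>>;

matrix gemm_softmax_gemm(const matrix& a, const matrix& b, const matrix& b1, float scale)
{
    const std::size_t m = a.size(), k = b.size(), n = b[0].size(), o = b1[0].size();

    // First gemm, scaled: S(m x n) = scale * A(m x k) * B(k x n)
    matrix s(m, std::vector<float>(n, 0.0f));
    for(std::size_t i = 0; i < m; ++i)
        for(std::size_t j = 0; j < n; ++j)
        {
            for(std::size_t p = 0; p < k; ++p)
                s[i][j] += a[i][p] * b[p][j];
            s[i][j] *= scale; // the test's elementwise mul by 0.125
        }

    // Numerically stable softmax over the last axis (axis = -1 in the test)
    for(auto& row : s)
    {
        const float mx = *std::max_element(row.begin(), row.end());
        float sum      = 0.0f;
        for(float& x : row)
        {
            x = std::exp(x - mx);
            sum += x;
        }
        for(float& x : row)
            x /= sum;
    }

    // Second gemm: C(m x o) = S(m x n) * B1(n x o)
    matrix c(m, std::vector<float>(o, 0.0f));
    for(std::size_t i = 0; i < m; ++i)
        for(std::size_t j = 0; j < o; ++j)
            for(std::size_t p = 0; p < n; ++p)
                c[i][j] += s[i][p] * b1[p][j];
    return c;
}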
......