Commit 9b3e4576 authored by Paul's avatar Paul
Browse files

Format

parent 74f21ca6
......@@ -269,22 +269,22 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
bool can_fold_batch(const std::vector<shape>& inputs) const
{
auto a_shape = inputs[0];
auto b_shape = inputs[1];
auto rank = a_shape.lens().size();
auto b_strides = b_shape.strides();
auto a_shape = inputs[0];
auto b_shape = inputs[1];
auto rank = a_shape.lens().size();
auto b_strides = b_shape.strides();
return rank >= 3 and b_strides[rank - 3] == 0;
}
ck::host::device_gemm_multiple_d::Problem create_problem(const std::vector<shape>& inputs, const value& v) const
ck::host::device_gemm_multiple_d::Problem create_problem(const std::vector<shape>& inputs,
const value& v) const
{
auto a_shape = inputs[0];
auto b_shape = inputs[1];
auto c_shape = inputs.back();
auto a_shape = inputs[0];
auto b_shape = inputs[1];
auto c_shape = inputs.back();
auto rank = a_shape.lens().size();
auto b_strides = b_shape.strides();
auto rank = a_shape.lens().size();
auto b_strides = b_shape.strides();
auto batch_count = get_batch_count(c_shape);
auto m = c_shape.lens()[rank - 2];
......@@ -318,19 +318,19 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
}
return ck::host::device_gemm_multiple_d::Problem{m,
n,
k,
transA,
transB,
transE,
ds_layout,
a_type,
b_type,
e_type,
ds_type,
ck_passthrough,
ck_passthrough,
cde_op};
n,
k,
transA,
transB,
transE,
ds_layout,
a_type,
b_type,
e_type,
ds_type,
ck_passthrough,
ck_passthrough,
cde_op};
}
operation compile_op(context& /* ctx */, const std::vector<shape>& inputs, const value& v) const
......@@ -339,8 +339,8 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
auto b_shape = inputs[1];
auto c_shape = inputs.back();
auto tuning_value = get_tuning_for({a_shape, b_shape, c_shape});
auto batch_count = get_batch_count(c_shape);
auto problem = create_problem(inputs, v);
auto batch_count = get_batch_count(c_shape);
auto problem = create_problem(inputs, v);
const auto include_header = problem.GetIncludeHeader();
const auto solutions = problem.GetSolutions("gfx90a");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment