Commit 9b3e4576 authored by Paul
Browse files

Format

parent 74f21ca6
...@@ -269,22 +269,22 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -269,22 +269,22 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
// Predicate over the GEMM inputs: returns true only when the A input (inputs[0])
// has rank >= 3 and the B input's (inputs[1]) stride at index `rank - 3` is 0.
// NOTE(review): a zero stride presumably means B is broadcast along that batch
// dimension, which is what would make folding the batch possible - confirm.
// NOTE(review): `rank` is taken from A's lens but used to index B's strides;
// this assumes A and B have the same rank - verify against callers.
// NOTE(review): only the dimension at `rank - 3` is inspected; for rank > 3 the
// other leading batch strides of B are not checked here.
bool can_fold_batch(const std::vector<shape>& inputs) const bool can_fold_batch(const std::vector<shape>& inputs) const
{ {
auto a_shape = inputs[0]; auto a_shape = inputs[0];
auto b_shape = inputs[1]; auto b_shape = inputs[1];
auto rank = a_shape.lens().size(); auto rank = a_shape.lens().size();
auto b_strides = b_shape.strides(); auto b_strides = b_shape.strides();
return rank >= 3 and b_strides[rank - 3] == 0; return rank >= 3 and b_strides[rank - 3] == 0;
} }
ck::host::device_gemm_multiple_d::Problem create_problem(const std::vector<shape>& inputs, const value& v) const ck::host::device_gemm_multiple_d::Problem create_problem(const std::vector<shape>& inputs,
const value& v) const
{ {
auto a_shape = inputs[0]; auto a_shape = inputs[0];
auto b_shape = inputs[1]; auto b_shape = inputs[1];
auto c_shape = inputs.back(); auto c_shape = inputs.back();
auto rank = a_shape.lens().size(); auto rank = a_shape.lens().size();
auto b_strides = b_shape.strides(); auto b_strides = b_shape.strides();
auto batch_count = get_batch_count(c_shape); auto batch_count = get_batch_count(c_shape);
auto m = c_shape.lens()[rank - 2]; auto m = c_shape.lens()[rank - 2];
...@@ -318,19 +318,19 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -318,19 +318,19 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
} }
return ck::host::device_gemm_multiple_d::Problem{m, return ck::host::device_gemm_multiple_d::Problem{m,
n, n,
k, k,
transA, transA,
transB, transB,
transE, transE,
ds_layout, ds_layout,
a_type, a_type,
b_type, b_type,
e_type, e_type,
ds_type, ds_type,
ck_passthrough, ck_passthrough,
ck_passthrough, ck_passthrough,
cde_op}; cde_op};
} }
operation compile_op(context& /* ctx */, const std::vector<shape>& inputs, const value& v) const operation compile_op(context& /* ctx */, const std::vector<shape>& inputs, const value& v) const
...@@ -339,8 +339,8 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -339,8 +339,8 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
auto b_shape = inputs[1]; auto b_shape = inputs[1];
auto c_shape = inputs.back(); auto c_shape = inputs.back();
auto tuning_value = get_tuning_for({a_shape, b_shape, c_shape}); auto tuning_value = get_tuning_for({a_shape, b_shape, c_shape});
auto batch_count = get_batch_count(c_shape); auto batch_count = get_batch_count(c_shape);
auto problem = create_problem(inputs, v); auto problem = create_problem(inputs, v);
const auto include_header = problem.GetIncludeHeader(); const auto include_header = problem.GetIncludeHeader();
const auto solutions = problem.GetSolutions("gfx90a"); const auto solutions = problem.GetSolutions("gfx90a");
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment