Commit ffe2c0cc authored by Alan Turner's avatar Alan Turner
Browse files

Formatting

parent 4b96da8d
......@@ -246,21 +246,19 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
static std::size_t get_batch_count(const shape& s)
{
return std::accumulate(s.lens().rbegin() + 2,
s.lens().rend(),
std::size_t{1},
std::multiplies<std::size_t>());
return std::accumulate(
s.lens().rbegin() + 2, s.lens().rend(), std::size_t{1}, std::multiplies<std::size_t>());
}
static void fold_batch_dims(shape& s)
{
auto lens = s.lens();
if (lens.size() <= 2)
if(lens.size() <= 2)
return;
auto batch_count = get_batch_count(s);
auto m1 = lens.at(lens.size() - 2);
auto m2 = lens.at(lens.size() - 1);
if (transposed_matrix(s))
if(transposed_matrix(s))
s = shape{s.type(), {m1, m2 * batch_count}};
else
s = shape{s.type(), {m1 * batch_count, m2}};
......@@ -269,7 +267,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
static void remove_batch_dims(shape& s)
{
auto lens = s.lens();
if (lens.size() <= 2)
if(lens.size() <= 2)
return;
auto m1 = lens.at(lens.size() - 2);
auto m2 = lens.at(lens.size() - 1);
......@@ -332,7 +330,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
options.output = c_shape;
options.kernel_name = v.get("kernel", "ck_gemm_kernel");
options.virtual_inputs = inputs;
if (can_fold_batch)
if(can_fold_batch)
{
auto vinputs = inputs;
fold_batch_dims(vinputs[0]);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment