Commit ffe2c0cc authored by Alan Turner's avatar Alan Turner
Browse files

Formatting

parent 4b96da8d
...@@ -246,21 +246,19 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -246,21 +246,19 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
static std::size_t get_batch_count(const shape& s) static std::size_t get_batch_count(const shape& s)
{ {
return std::accumulate(s.lens().rbegin() + 2, return std::accumulate(
s.lens().rend(), s.lens().rbegin() + 2, s.lens().rend(), std::size_t{1}, std::multiplies<std::size_t>());
std::size_t{1},
std::multiplies<std::size_t>());
} }
static void fold_batch_dims(shape& s) static void fold_batch_dims(shape& s)
{ {
auto lens = s.lens(); auto lens = s.lens();
if (lens.size() <= 2) if(lens.size() <= 2)
return; return;
auto batch_count = get_batch_count(s); auto batch_count = get_batch_count(s);
auto m1 = lens.at(lens.size() - 2); auto m1 = lens.at(lens.size() - 2);
auto m2 = lens.at(lens.size() - 1); auto m2 = lens.at(lens.size() - 1);
if (transposed_matrix(s)) if(transposed_matrix(s))
s = shape{s.type(), {m1, m2 * batch_count}}; s = shape{s.type(), {m1, m2 * batch_count}};
else else
s = shape{s.type(), {m1 * batch_count, m2}}; s = shape{s.type(), {m1 * batch_count, m2}};
...@@ -269,11 +267,11 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -269,11 +267,11 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
// Strip all batch dimensions from s, keeping only the trailing
// matrix dimensions, so the shape becomes a plain 2-D {rows, cols}.
// Shapes that are already rank <= 2 are left untouched.
static void remove_batch_dims(shape& s)
{
    const auto& dims = s.lens();
    if(dims.size() <= 2)
        return;
    const auto rows = dims.at(dims.size() - 2);
    const auto cols = dims.at(dims.size() - 1);
    s = shape{s.type(), {rows, cols}};
}
std::vector<std::string> names() const { return {"ck_gemm", "gpu::ck_gemm"}; } std::vector<std::string> names() const { return {"ck_gemm", "gpu::ck_gemm"}; }
...@@ -284,15 +282,15 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -284,15 +282,15 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
auto b_shape = inputs[1]; auto b_shape = inputs[1];
auto c_shape = inputs.back(); auto c_shape = inputs.back();
auto rank = a_shape.lens().size(); auto rank = a_shape.lens().size();
auto b_strides = b_shape.strides(); auto b_strides = b_shape.strides();
bool can_fold_batch = rank >= 3 and b_strides[rank - 3] == 0; bool can_fold_batch = rank >= 3 and b_strides[rank - 3] == 0;
auto batch_count = get_batch_count(c_shape); auto batch_count = get_batch_count(c_shape);
auto m = c_shape.lens()[rank - 2]; auto m = c_shape.lens()[rank - 2];
m = can_fold_batch ? m * batch_count : m; m = can_fold_batch ? m * batch_count : m;
auto n = c_shape.lens().back(); auto n = c_shape.lens().back();
auto k = a_shape.lens().back(); auto k = a_shape.lens().back();
std::array<char, 3> keys{'M', 'N', 'K'}; std::array<char, 3> keys{'M', 'N', 'K'};
std::array<std::size_t, 3> config{m, n, k}; std::array<std::size_t, 3> config{m, n, k};
auto tuning_val = v.get("tuning_val", get_tuning_for({a_shape, b_shape, c_shape})); auto tuning_val = v.get("tuning_val", get_tuning_for({a_shape, b_shape, c_shape}));
...@@ -332,7 +330,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -332,7 +330,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
options.output = c_shape; options.output = c_shape;
options.kernel_name = v.get("kernel", "ck_gemm_kernel"); options.kernel_name = v.get("kernel", "ck_gemm_kernel");
options.virtual_inputs = inputs; options.virtual_inputs = inputs;
if (can_fold_batch) if(can_fold_batch)
{ {
auto vinputs = inputs; auto vinputs = inputs;
fold_batch_dims(vinputs[0]); fold_batch_dims(vinputs[0]);
......
...@@ -53,7 +53,7 @@ __device__ void ck_gemm_matrix(E e, A a, B b, Ds... ds) ...@@ -53,7 +53,7 @@ __device__ void ck_gemm_matrix(E e, A a, B b, Ds... ds)
constexpr const auto a_grid_desc_m_k = gemm.matrix_padder.PadADescriptor_M_K(to_ck_tensor<A>()); constexpr const auto a_grid_desc_m_k = gemm.matrix_padder.PadADescriptor_M_K(to_ck_tensor<A>());
constexpr const auto b_grid_desc_n_k = constexpr const auto b_grid_desc_n_k =
gemm.matrix_padder.PadBDescriptor_N_K(to_ck_tensor<ck_transposeb<B>>()); gemm.matrix_padder.PadBDescriptor_N_K(to_ck_tensor<ck_transposeb<B>>());
constexpr const auto e_grid_desc_m_n = gemm.matrix_padder.PadCDescriptor_M_N(to_ck_tensor<E>()); constexpr const auto e_grid_desc_m_n = gemm.matrix_padder.PadCDescriptor_M_N(to_ck_tensor<E>());
constexpr const auto ds_grid_desc_m_n = constexpr const auto ds_grid_desc_m_n =
ck::make_tuple(gemm.matrix_padder.PadCDescriptor_M_N(to_ck_tensor<Ds>())...); ck::make_tuple(gemm.matrix_padder.PadCDescriptor_M_N(to_ck_tensor<Ds>())...);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment