Commit 687a3310 authored by Paul's avatar Paul
Browse files

Format

parent 24f0cb5b
...@@ -170,10 +170,7 @@ static std::size_t get_tuning_for(const std::vector<shape>& inputs) ...@@ -170,10 +170,7 @@ static std::size_t get_tuning_for(const std::vector<shape>& inputs)
struct ck_gemm_compiler : compiler<ck_gemm_compiler> struct ck_gemm_compiler : compiler<ck_gemm_compiler>
{ {
static bool transposed_matrix(const shape& s) static bool transposed_matrix(const shape& s) { return s.strides().back() != 1; }
{
return s.strides().back() != 1;
}
static std::string get_layout(const shape& s) static std::string get_layout(const shape& s)
{ {
return transposed_matrix(s) ? "ck::tensor_layout::gemm::ColumnMajor" return transposed_matrix(s) ? "ck::tensor_layout::gemm::ColumnMajor"
...@@ -199,7 +196,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -199,7 +196,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
{ {
swap_inputs = false; swap_inputs = false;
auto c_shape = inputs.back(); auto c_shape = inputs.back();
if (not transposed_matrix(c_shape)) if(not transposed_matrix(c_shape))
return inputs; return inputs;
std::vector<int64_t> perm(c_shape.lens().size()); std::vector<int64_t> perm(c_shape.lens().size());
std::iota(perm.begin(), perm.end(), 0); std::iota(perm.begin(), perm.end(), 0);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment