"driver/olCompiling/include/logger.hpp" did not exist on "d2315b0dfcd6f31cca4328819eaf60d77e952dd6"
Commit 687a3310 authored by Paul's avatar Paul
Browse files

Format

parent 24f0cb5b
...@@ -170,14 +170,11 @@ static std::size_t get_tuning_for(const std::vector<shape>& inputs) ...@@ -170,14 +170,11 @@ static std::size_t get_tuning_for(const std::vector<shape>& inputs)
struct ck_gemm_compiler : compiler<ck_gemm_compiler> struct ck_gemm_compiler : compiler<ck_gemm_compiler>
{ {
static bool transposed_matrix(const shape& s) static bool transposed_matrix(const shape& s) { return s.strides().back() != 1; }
{
return s.strides().back() != 1;
}
static std::string get_layout(const shape& s) static std::string get_layout(const shape& s)
{ {
return transposed_matrix(s) ? "ck::tensor_layout::gemm::ColumnMajor" return transposed_matrix(s) ? "ck::tensor_layout::gemm::ColumnMajor"
: "ck::tensor_layout::gemm::RowMajor"; : "ck::tensor_layout::gemm::RowMajor";
} }
static std::string get_type(const shape& s) static std::string get_type(const shape& s)
...@@ -197,9 +194,9 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -197,9 +194,9 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
static std::vector<shape> adjust_inputs(std::vector<shape> inputs, bool& swap_inputs) static std::vector<shape> adjust_inputs(std::vector<shape> inputs, bool& swap_inputs)
{ {
swap_inputs = false; swap_inputs = false;
auto c_shape = inputs.back(); auto c_shape = inputs.back();
if (not transposed_matrix(c_shape)) if(not transposed_matrix(c_shape))
return inputs; return inputs;
std::vector<int64_t> perm(c_shape.lens().size()); std::vector<int64_t> perm(c_shape.lens().size());
std::iota(perm.begin(), perm.end(), 0); std::iota(perm.begin(), perm.end(), 0);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment