"src/vscode:/vscode.git/clone" did not exist on "442cbc1bd008e7299831a0ce1e5710ff8ea5ccb8"
Commit 873f6c0c authored by Paul's avatar Paul
Browse files

Format

parent 6ad2af4e
...@@ -25,7 +25,7 @@ struct ck_gemm ...@@ -25,7 +25,7 @@ struct ck_gemm
void check_gemm_shape(const shape& s) const void check_gemm_shape(const shape& s) const
{ {
if (contains(s.lens(), 1)) if(contains(s.lens(), 1))
MIGRAPHX_THROW("Invalid shape for ck_gemm"); MIGRAPHX_THROW("Invalid shape for ck_gemm");
} }
...@@ -54,7 +54,7 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins) ...@@ -54,7 +54,7 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
return false; return false;
auto a = ins->inputs().front()->get_shape(); auto a = ins->inputs().front()->get_shape();
auto b = ins->inputs().back()->get_shape(); auto b = ins->inputs().back()->get_shape();
if (a.lens().size() > 2 or b.lens().size() > 2) if(a.lens().size() > 2 or b.lens().size() > 2)
return false; return false;
return (a.lens()[0] % 8 == 0 and a.lens()[1] % 8 == 0 and b.lens()[0] % 8 == 0 and return (a.lens()[0] % 8 == 0 and a.lens()[1] % 8 == 0 and b.lens()[0] % 8 == 0 and
b.lens()[1] % 8 == 0); b.lens()[1] % 8 == 0);
......
...@@ -90,8 +90,8 @@ static std::size_t get_block_size(const std::vector<std::string>& s) ...@@ -90,8 +90,8 @@ static std::size_t get_block_size(const std::vector<std::string>& s)
static std::size_t get_grid_size(const std::vector<std::string>& s, std::size_t m, std::size_t n) static std::size_t get_grid_size(const std::vector<std::string>& s, std::size_t m, std::size_t n)
{ {
auto mpb = std::stoull(s[block_size_index+1]); auto mpb = std::stoull(s[block_size_index + 1]);
auto npb = std::stoull(s[block_size_index+2]); auto npb = std::stoull(s[block_size_index + 2]);
return int_div_ceil(m, mpb) * int_div_ceil(n, npb); return int_div_ceil(m, mpb) * int_div_ceil(n, npb);
} }
...@@ -99,12 +99,13 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -99,12 +99,13 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
{ {
static std::string get_layout(const shape& s) static std::string get_layout(const shape& s)
{ {
return s.transposed() ? "ck::tensor_layout::gemm::ColumnMajor" : "ck::tensor_layout::gemm::RowMajor"; return s.transposed() ? "ck::tensor_layout::gemm::ColumnMajor"
: "ck::tensor_layout::gemm::RowMajor";
} }
static std::string get_type(const shape& s) static std::string get_type(const shape& s)
{ {
if (s.type() == shape::half_type) if(s.type() == shape::half_type)
return "ck::half_t"; return "ck::half_t";
return shape::cpp_type(s.type()); return shape::cpp_type(s.type());
} }
...@@ -126,12 +127,9 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -126,12 +127,9 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
int i = v.get("tuning_val", 4); int i = v.get("tuning_val", 4);
const auto& instance = get_instance(i, [&](const auto& x) -> bool { const auto& instance = get_instance(i, [&](const auto& x) -> bool {
return get_layout(a_shape) == x[0] and return get_layout(a_shape) == x[0] and get_layout(b_shape) == x[1] and
get_layout(b_shape) == x[1] and get_layout(c_shape) == x[2] and get_type(a_shape) == x[3] and
get_layout(c_shape) == x[2] and get_type(b_shape) == x[4] and get_type(c_shape) == x[5];
get_type(a_shape) == x[3] and
get_type(b_shape) == x[4] and
get_type(c_shape) == x[5];
}); });
hip_compile_options options; hip_compile_options options;
......
This source diff could not be displayed because it is too large. You can view the blob instead.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment