Commit 873f6c0c authored by Paul
Browse files

Format

parent 6ad2af4e
......@@ -25,7 +25,7 @@ struct ck_gemm
// Rejects shapes that ck_gemm cannot handle: raises MIGRAPHX_THROW with
// "Invalid shape for ck_gemm" when contains(s.lens(), 1) holds (presumably
// when any dimension length equals 1 -- confirm against `contains`).
// Returns normally otherwise.
void check_gemm_shape(const shape& s) const
{
// Diff view: the next two lines are the same condition before and after the
// formatting-only change (`if (` became `if(`).
if (contains(s.lens(), 1))
if(contains(s.lens(), 1))
MIGRAPHX_THROW("Invalid shape for ck_gemm");
}
......@@ -54,7 +54,7 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
return false;
auto a = ins->inputs().front()->get_shape();
auto b = ins->inputs().back()->get_shape();
if (a.lens().size() > 2 or b.lens().size() > 2)
if(a.lens().size() > 2 or b.lens().size() > 2)
return false;
return (a.lens()[0] % 8 == 0 and a.lens()[1] % 8 == 0 and b.lens()[0] % 8 == 0 and
b.lens()[1] % 8 == 0);
......
......@@ -90,8 +90,8 @@ static std::size_t get_block_size(const std::vector<std::string>& s)
// Computes the grid size for an m-by-n problem as
// int_div_ceil(m, mpb) * int_div_ceil(n, npb), where mpb and npb are parsed
// from the two entries of `s` that follow block_size_index.
//
// s: list of strings; entries block_size_index + 1 and block_size_index + 2
//    must exist and hold unsigned integers. The indexing is not
//    bounds-checked, and std::stoull throws std::invalid_argument or
//    std::out_of_range on unparsable input.
// m, n: the two extents divided by mpb and npb respectively.
static std::size_t get_grid_size(const std::vector<std::string>& s, std::size_t m, std::size_t n)
{
// Diff view: the first pair of declarations is the pre-format text, the second
// pair is the post-format text (spaces added around `+`).
auto mpb = std::stoull(s[block_size_index+1]);
auto npb = std::stoull(s[block_size_index+2]);
auto mpb = std::stoull(s[block_size_index + 1]);
auto npb = std::stoull(s[block_size_index + 2]);
// NOTE(review): int_div_ceil presumably rounds the quotient up -- confirm.
return int_div_ceil(m, mpb) * int_div_ceil(n, npb);
}
......@@ -99,12 +99,13 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
{
// Returns the layout name for a shape as a string:
// "ck::tensor_layout::gemm::ColumnMajor" when s.transposed() is true,
// "ck::tensor_layout::gemm::RowMajor" otherwise.
static std::string get_layout(const shape& s)
{
// Diff view: the single-line return is the pre-format text; the two-line
// return after it is the same expression wrapped by the formatter.
return s.transposed() ? "ck::tensor_layout::gemm::ColumnMajor" : "ck::tensor_layout::gemm::RowMajor";
return s.transposed() ? "ck::tensor_layout::gemm::ColumnMajor"
: "ck::tensor_layout::gemm::RowMajor";
}
// Returns the element type name for a shape as a string: "ck::half_t" for
// shape::half_type, otherwise whatever shape::cpp_type(s.type()) yields.
static std::string get_type(const shape& s)
{
// Diff view: the next two lines are the same condition before and after the
// formatting-only change (`if (` became `if(`).
if (s.type() == shape::half_type)
if(s.type() == shape::half_type)
return "ck::half_t";
return shape::cpp_type(s.type());
}
......@@ -117,21 +118,18 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
auto b_shape = inputs[1];
auto c_shape = inputs[2];
auto m = c_shape.lens().front();
auto n = c_shape.lens().back();
auto k = a_shape.lens().back();
auto sa = a_shape.strides().front();
auto sb = b_shape.strides().front();
auto sc = c_shape.strides().front();
auto m = c_shape.lens().front();
auto n = c_shape.lens().back();
auto k = a_shape.lens().back();
auto sa = a_shape.strides().front();
auto sb = b_shape.strides().front();
auto sc = c_shape.strides().front();
int i = v.get("tuning_val", 4);
int i = v.get("tuning_val", 4);
const auto& instance = get_instance(i, [&](const auto& x) -> bool {
return get_layout(a_shape) == x[0] and
get_layout(b_shape) == x[1] and
get_layout(c_shape) == x[2] and
get_type(a_shape) == x[3] and
get_type(b_shape) == x[4] and
get_type(c_shape) == x[5];
return get_layout(a_shape) == x[0] and get_layout(b_shape) == x[1] and
get_layout(c_shape) == x[2] and get_type(a_shape) == x[3] and
get_type(b_shape) == x[4] and get_type(c_shape) == x[5];
});
hip_compile_options options;
......
This source diff could not be displayed because it is too large. You can view the blob instead.
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment