Commit cfec69d3 authored by Paul's avatar Paul
Browse files

Fuse batched gemms

parent e3d57f93
...@@ -25,7 +25,7 @@ struct ck_gemm ...@@ -25,7 +25,7 @@ struct ck_gemm
void check_gemm_shape(const shape& s) const void check_gemm_shape(const shape& s) const
{ {
if(contains(s.lens(), 1)) if(not contains(range(s.strides().rbegin(), s.strides().rbegin()+3), 1))
MIGRAPHX_THROW("Invalid shape for ck_gemm"); MIGRAPHX_THROW("Invalid shape for ck_gemm");
} }
...@@ -53,9 +53,7 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins) ...@@ -53,9 +53,7 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
return false; return false;
auto a = ins->inputs().front()->get_shape(); auto a = ins->inputs().front()->get_shape();
auto b = ins->inputs().back()->get_shape(); auto b = ins->inputs().back()->get_shape();
if(a.lens().size() > 2 or b.lens().size() > 2) if(a.lens().back() > 2048)
return false;
if(a.lens()[1] > 2048)
return false; return false;
return true; return true;
} }
...@@ -82,6 +80,8 @@ struct find_ck_gemm ...@@ -82,6 +80,8 @@ struct find_ck_gemm
auto gemm_it = std::find(inputs.begin(), inputs.end(), x_ins); auto gemm_it = std::find(inputs.begin(), inputs.end(), x_ins);
auto gemm_idx = gemm_it - inputs.begin(); auto gemm_idx = gemm_it - inputs.begin();
assert(gemm_it != inputs.end()); assert(gemm_it != inputs.end());
if (ins->get_shape().type() != shape::half_type)
return;
if(gemm_idx != 0) if(gemm_idx != 0)
{ {
auto first_param = pm->get_parameter(names[0]); auto first_param = pm->get_parameter(names[0]);
......
...@@ -199,9 +199,11 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -199,9 +199,11 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
auto b_shape = inputs[1]; auto b_shape = inputs[1];
auto c_shape = inputs.back(); auto c_shape = inputs.back();
auto rank = a_shape.lens().size();
std::array<char, 3> keys{'M', 'N', 'K'}; std::array<char, 3> keys{'M', 'N', 'K'};
std::array<std::size_t, 3> config{ std::array<std::size_t, 3> config{
c_shape.lens().front(), c_shape.lens().back(), a_shape.lens().back()}; c_shape.lens()[rank - 2], c_shape.lens().back(), a_shape.lens().back()};
auto tuning_val = v.get("tuning_val", get_tuning_for({a_shape, b_shape, c_shape})); auto tuning_val = v.get("tuning_val", get_tuning_for({a_shape, b_shape, c_shape}));
auto ip = instance{get_instance(tuning_val, [&](const auto& x) -> bool { auto ip = instance{get_instance(tuning_val, [&](const auto& x) -> bool {
...@@ -231,10 +233,13 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -231,10 +233,13 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
ip.set_gemm("ck::tensor_operation::device::GemmSpecialization::" + gemm_type); ip.set_gemm("ck::tensor_operation::device::GemmSpecialization::" + gemm_type);
auto blocks_per_batch = ip.get_grid_size(config); auto blocks_per_batch = ip.get_grid_size(config);
auto batch_count = std::accumulate(
c_shape.lens().rbegin() + 2, c_shape.lens().rend(), std::size_t{1}, std::multiplies<std::size_t>());
hip_compile_options options; hip_compile_options options;
auto block_size = ip.get_block_size(); auto block_size = ip.get_block_size();
auto grid_size = blocks_per_batch; auto grid_size = batch_count * blocks_per_batch;
options.set_launch_params(v, grid_size * block_size, block_size); options.set_launch_params(v, grid_size * block_size, block_size);
options.inputs = inputs; options.inputs = inputs;
options.output = c_shape; options.output = c_shape;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment