Commit cfec69d3 authored by Paul
Browse files

Fuse batched gemms

parent e3d57f93
......@@ -25,7 +25,7 @@ struct ck_gemm
// Validates a shape for ck_gemm; raises MIGRAPHX_THROW("Invalid shape for
// ck_gemm") when the checks below reject it.
//
// NOTE(review): this listing is a diff whose +/- markers were stripped. The
// two consecutive `if` lines look like the removed and the added version of a
// single condition (hunk header is -25,7 +25,7), not an intentional nested
// check -- confirm against the repository before relying on either form.
void check_gemm_shape(const shape& s) const
{
// Presumably the pre-change condition: some dimension length equals 1.
if(contains(s.lens(), 1))
// Presumably the post-change condition: none of the last three strides is 1.
// NOTE(review): rbegin()+3 steps past rend() if s has fewer than three
// strides -- verify that rank >= 3 is guaranteed by callers.
if(not contains(range(s.strides().rbegin(), s.strides().rbegin()+3), 1))
MIGRAPHX_THROW("Invalid shape for ck_gemm");
}
......@@ -53,9 +53,7 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
return false;
auto a = ins->inputs().front()->get_shape();
auto b = ins->inputs().back()->get_shape();
if(a.lens().size() > 2 or b.lens().size() > 2)
return false;
if(a.lens()[1] > 2048)
if(a.lens().back() > 2048)
return false;
return true;
}
......@@ -82,6 +80,8 @@ struct find_ck_gemm
auto gemm_it = std::find(inputs.begin(), inputs.end(), x_ins);
auto gemm_idx = gemm_it - inputs.begin();
assert(gemm_it != inputs.end());
if (ins->get_shape().type() != shape::half_type)
return;
if(gemm_idx != 0)
{
auto first_param = pm->get_parameter(names[0]);
......
......@@ -199,9 +199,11 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
auto b_shape = inputs[1];
auto c_shape = inputs.back();
auto rank = a_shape.lens().size();
std::array<char, 3> keys{'M', 'N', 'K'};
std::array<std::size_t, 3> config{
c_shape.lens().front(), c_shape.lens().back(), a_shape.lens().back()};
c_shape.lens()[rank - 2], c_shape.lens().back(), a_shape.lens().back()};
auto tuning_val = v.get("tuning_val", get_tuning_for({a_shape, b_shape, c_shape}));
auto ip = instance{get_instance(tuning_val, [&](const auto& x) -> bool {
......@@ -231,10 +233,13 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
ip.set_gemm("ck::tensor_operation::device::GemmSpecialization::" + gemm_type);
auto blocks_per_batch = ip.get_grid_size(config);
auto batch_count = std::accumulate(
c_shape.lens().rbegin() + 2, c_shape.lens().rend(), std::size_t{1}, std::multiplies<std::size_t>());
hip_compile_options options;
auto block_size = ip.get_block_size();
auto grid_size = blocks_per_batch;
auto grid_size = batch_count * blocks_per_batch;
options.set_launch_params(v, grid_size * block_size, block_size);
options.inputs = inputs;
options.output = c_shape;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment