Commit d294b663 authored by Paul
Browse files

Improve the batch fold check

parent d8110fc4
...@@ -269,12 +269,15 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -269,12 +269,15 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
bool can_fold_batch(const std::vector<shape>& inputs) const bool can_fold_batch(const std::vector<shape>& inputs) const
{ {
const auto& a_shape = inputs[0];
const auto& b_shape = inputs[1]; const auto& b_shape = inputs[1];
// cppcheck-suppress unreadVariable if(std::any_of(inputs.begin() + 2, inputs.end()-1, [](auto input) {
auto rank = a_shape.lens().size(); return input.broadcasted();
auto b_strides = b_shape.strides(); }))
return rank >= 3 and b_strides[rank - 3] == 0; return false;
const auto& b_strides = b_shape.strides();
return std::all_of(b_strides.begin(), b_strides.end() - 3, [](auto stride) {
return stride == 0;
});
} }
ck::host::device_gemm_multiple_d::Problem create_problem(const std::vector<shape>& inputs, ck::host::device_gemm_multiple_d::Problem create_problem(const std::vector<shape>& inputs,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment