Commit 6af37d7c authored by Paul's avatar Paul
Browse files

Check batch is standard

parent 05e89700
...@@ -267,16 +267,29 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -267,16 +267,29 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
std::vector<std::string> names() const { return {"ck_gemm", "gpu::ck_gemm"}; } std::vector<std::string> names() const { return {"ck_gemm", "gpu::ck_gemm"}; }
static bool standard_batch(const shape& s)
{
if (s.lens().size() < 3)
return true;
std::vector<std::size_t> lens(s.lens().begin(), s.lens().end() - 2);
std::vector<std::size_t> strides(s.strides().begin(), s.strides().end() - 2);
auto base = *(s.lens().end() - 2) * *(s.lens().end() - 1);
std::transform(strides.begin(), strides.end(), strides.begin(), [&](auto stride) {
return stride / base;
});
return shape{s.type(), lens, strides}.standard();
}
bool can_fold_batch(const std::vector<shape>& inputs) const bool can_fold_batch(const std::vector<shape>& inputs) const
{ {
const auto& b_shape = inputs[1]; const auto& b_shape = inputs[1];
if(std::any_of(inputs.begin() + 2, inputs.end() - 1, [](auto input) { if(std::any_of(inputs.begin() + 2, inputs.end() - 1, [](auto input) {
return input.broadcasted(); return not standard_batch(input);
})) }))
return false; return false;
const auto& b_strides = b_shape.strides(); const auto& b_strides = b_shape.strides();
return std::all_of( return std::all_of(
b_strides.begin(), b_strides.end() - 3, [](auto stride) { return stride == 0; }); b_strides.begin(), b_strides.end() - 2, [](auto stride) { return stride == 0; });
} }
ck::host::device_gemm_multiple_d::Problem create_problem(const std::vector<shape>& inputs, ck::host::device_gemm_multiple_d::Problem create_problem(const std::vector<shape>& inputs,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment