// static_batching_compiler.hpp
#pragma once

#include "graph_compiler.hpp"

#include <cstddef>
#include <functional>
#include <memory>
#include <tuple>
#include <unordered_map>

namespace infinilm::engine {
class StaticBatchingCompiler : public GraphCompiler {
public:
10
    StaticBatchingCompiler(const std::shared_ptr<InfinilmModel> &model, RankBarrier *barrier);
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36

    void compile() override;

    Compiled get_compiled(const InfinilmModel::Input &input) override;

private:
    struct TupleHash {
        size_t operator()(const std::tuple<size_t, size_t> &t) const noexcept {
            auto h1 = std::hash<size_t>{}(std::get<0>(t));
            auto h2 = std::hash<size_t>{}(std::get<1>(t));
            return h1 ^ (h2 + 0x9e3779b97f4a7c15ULL + (h1 << 6) + (h1 >> 2));
        }
    };

    struct CompiledResult {
        InfinilmModel::Input input;
        Compiled compiled;
    };

    std::unordered_map<
        std::tuple<size_t, size_t>, // (batch_size, seq_len)
        CompiledResult,
        TupleHash>
        compiled_map_;
};
} // namespace infinilm::engine