"vllm/vscode:/vscode.git/clone" did not exist on "91a61da9b12c483a6688841b8f860c1a32b8918c"
infinilm_model.hpp 1.64 KB
Newer Older
1
2
#pragma once

#include "../cache/cache.hpp"
#include "infinicore/nn/module.hpp"
#include "nlohmann/json.hpp"

#include <any>
#include <optional>
#include <string>

namespace infinilm {
class InfinilmModel : public infinicore::nn::Module {
public:
Jiacheng Huang's avatar
Jiacheng Huang committed
14
15
16
17
18
    struct Config {
        std::string model_type;
        virtual ~Config() = default;
    };

19
20
    struct Input {
        /// Token IDs tensor of shape `[batch, seq_len]`.
21
        std::optional<infinicore::Tensor> input_ids;
22
        /// Position IDs tensor of shape `[batch, seq_len]` or `[seq_len]`.
23
        std::optional<infinicore::Tensor> position_ids;
PanZezhong's avatar
PanZezhong committed
24
        /// Past Lengths of cached sequence for each request, of shape `[num_requests]`.
25
26
27
        std::optional<infinicore::Tensor> past_sequence_lengths;
        /// ToTal Lengths for each request sequence, of shape `[num_requests]`.
        std::optional<infinicore::Tensor> total_sequence_lengths;
PanZezhong's avatar
PanZezhong committed
28
        /// Offsets of each request in a continous-batched sequence, of shape `[num_requests + 1]`.
29
30
31
32
33
        std::optional<infinicore::Tensor> input_offsets;
        /// Block ids for each request `[batch, max_block_table_length]`. Used for paged cache.
        std::optional<infinicore::Tensor> block_tables;
        /// Slot ids for each token `[seq]`. Used for paged cache.
        std::optional<infinicore::Tensor> slot_mapping;
34
35
36
    };

    struct Output {
37
        /// Logits.
38
39
40
        infinicore::Tensor logits;
    };

41
    virtual ~InfinilmModel() = default;
42
    virtual Output forward(const Input &input) const = 0;
PanZezhong's avatar
PanZezhong committed
43
44

    virtual void reset_cache(const cache::CacheConfig *cache_config) = 0;
45
    virtual const cache::CacheConfig *get_cache_config() const = 0;
46
47
};
} // namespace infinilm