// infinilm_model.hpp
#pragma once

#include "../cache/cache.hpp"
#include "infinicore/nn/module.hpp"
#include "nlohmann/json.hpp"

#include <any>
#include <optional>
#include <string>

namespace infinilm {
class InfinilmModel : public infinicore::nn::Module {
public:
Jiacheng Huang's avatar
Jiacheng Huang committed
14
15
16
17
18
    struct Config {
        std::string model_type;
        virtual ~Config() = default;
    };

19
20
    struct Input {
        /// Token IDs tensor of shape `[batch, seq_len]`.
21
        std::optional<infinicore::Tensor> input_ids;
22
        /// Position IDs tensor of shape `[batch, seq_len]` or `[seq_len]`.
23
        std::optional<infinicore::Tensor> position_ids;
PanZezhong's avatar
PanZezhong committed
24
        /// Past Lengths of cached sequence for each request, of shape `[num_requests]`.
25
26
27
        std::optional<infinicore::Tensor> past_sequence_lengths;
        /// ToTal Lengths for each request sequence, of shape `[num_requests]`.
        std::optional<infinicore::Tensor> total_sequence_lengths;
PanZezhong's avatar
PanZezhong committed
28
        /// Offsets of each request in a continous-batched sequence, of shape `[num_requests + 1]`.
29
        std::optional<infinicore::Tensor> input_offsets;
30
31
        /// Cumulative total sequence lengths for each request, of shape `[num_requests + 1]`.
        std::optional<infinicore::Tensor> cu_seqlens;
32
33
34
35
        /// Block ids for each request `[batch, max_block_table_length]`. Used for paged cache.
        std::optional<infinicore::Tensor> block_tables;
        /// Slot ids for each token `[seq]`. Used for paged cache.
        std::optional<infinicore::Tensor> slot_mapping;
36
37
38
    };

    struct Output {
39
        /// Logits.
40
41
42
        infinicore::Tensor logits;
    };

43
    virtual ~InfinilmModel() = default;
44
    virtual Output forward(const Input &input) const = 0;
PanZezhong's avatar
PanZezhong committed
45
46

    virtual void reset_cache(const cache::CacheConfig *cache_config) = 0;
47
    virtual const cache::CacheConfig *get_cache_config() const = 0;
48
49
};
} // namespace infinilm