// infinilm_model.hpp
#pragma once

#include "infinicore/nn/module.hpp"

#include "../cache/cache.hpp"

#include <any>
#include <string>

namespace infinilm {
class InfinilmModel : public infinicore::nn::Module {
public:
Jiacheng Huang's avatar
Jiacheng Huang committed
12
13
14
15
16
17
    struct Config {
        std::string model_type;

        virtual ~Config() = default;
    };

18
19
20
21
22
    struct Input {
        /// Token IDs tensor of shape `[batch, seq_len]`.
        infinicore::Tensor input_ids;
        /// Position IDs tensor of shape `[batch, seq_len]` or `[seq_len]`.
        infinicore::Tensor position_ids;
PanZezhong's avatar
PanZezhong committed
23
24
        /// Past Lengths of cached sequence for each request, of shape `[num_requests]`.
        infinicore::Tensor cache_positions;
25
26
27
28
29
30
31
    };

    struct Output {
        /// Output tensor of shape [batch, seq_len, vocab_size].
        infinicore::Tensor logits;
    };

32
    virtual ~InfinilmModel() = default;
33
    virtual Output forward(const Input &input) const = 0;
PanZezhong's avatar
PanZezhong committed
34
35

    virtual void reset_cache(const cache::CacheConfig *cache_config) = 0;
36
37
};
} // namespace infinilm