// rank_worker.hpp
#pragma once

#include "../backends/attention_backends.hpp"
#include "../cache/cache.hpp"
#include "../config/model_config.hpp"
#include "../models/model_factory.hpp"
#include "compiler/general_compiler.hpp"
#include "distributed/distributed.hpp"
#include "rank_barrier.hpp"

#include <any>
#include <condition_variable>
#include <memory>
#include <mutex>
#include <optional>
#include <random>
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>

namespace infinilm::engine {

class RankWorker {
    enum class Command {
        INIT,
        LOAD,
        RUN,
Ceng's avatar
Ceng committed
26
        RESET_CACHE,
27
        COMPILE,
28
29
30
31
        STOP
    };

public:
32
33
34
35
36
37
    struct Input {
        /// Token IDs tensor of shape `[batch, seq_len]`.
        std::optional<infinicore::Tensor> input_ids;
        /// Position IDs tensor of shape `[batch, seq_len]` or `[seq_len]`.
        std::optional<infinicore::Tensor> position_ids;
        /// Past Lengths of cached sequence for each request, of shape `[num_requests]`.
38
39
40
        std::optional<infinicore::Tensor> past_sequence_lengths;
        /// ToTal Lengths for each request sequence, of shape `[num_requests]`.
        std::optional<infinicore::Tensor> total_sequence_lengths;
41
        /// Offsets of each request in a continous-batched sequence, of shape `[num_requests + 1]`.
42
        std::optional<infinicore::Tensor> input_offsets;
43
44
        /// Cumulative total sequence lengths for each request, of shape `[num_requests + 1]`.
        std::optional<infinicore::Tensor> cu_seqlens;
45
46
47
48
49
50
51
52
53
54
55
        /// Block ids for each request `[batch, max_block_table_length]`. Used for paged cache.
        std::optional<infinicore::Tensor> block_tables;
        /// Slot ids for each token `[seq]`. Used for paged cache.
        std::optional<infinicore::Tensor> slot_mapping;

        float temperature{1};

        int top_k{50};

        float top_p{1};

yaoht's avatar
yaoht committed
56
57
58
        /// Fills max_seqlen_q/k for paged FA prefill using tensor shapes + cache config only (no tensor D2H).
        infinilm::InfinilmModel::Input to_model_input(infinicore::Device device,
                                                      const cache::CacheConfig *cache_config = nullptr) const;
59
60
61
62
63
64
    };

    struct Output {
        infinicore::Tensor output_ids;
    };

Jiacheng Huang's avatar
Jiacheng Huang committed
65
    RankWorker(const InfinilmModel::Config &model_config,
66
               const distributed::RankInfo &rank_info,
67
               const cache::CacheConfig *cache_config,
68
               RankBarrier *barrier,
69
70
               bool enable_graph_compiling,
               backends::AttentionBackend attention_backend);
71

72
73
74
75
    RankWorker(std::shared_ptr<infinilm::config::ModelConfig> model_config,
               const distributed::RankInfo &rank_info,
               const cache::CacheConfig *cache_config,
               RankBarrier *barrier,
76
77
               bool enable_graph_compiling,
               backends::AttentionBackend attention_backend);
78

79
80
81
82
    // Submit a parameter load job and wait until the load completes on the worker thread.
    void load_param(const std::string &name,
                    const infinicore::Tensor &param);

83
84
85
    // return the parameters (i.e. weights and biases).
    std::unordered_map<std::string, infinicore::nn::Parameter> state_dict();

PanZezhong's avatar
PanZezhong committed
86
    // Submit a run (forward) job.
87
    void run(const Input &args);
88

89
    // Reset the internal cache with a new configuration
PanZezhong's avatar
PanZezhong committed
90
    void reset_cache(const cache::CacheConfig *new_config);
Ceng's avatar
Ceng committed
91

92
93
94
    // Compile the model graph if enabled.
    void compile();

PanZezhong's avatar
PanZezhong committed
95
96
97
    // Wait until run job completes. The result can be retrieved with get_output().
    void wait();

98
99
100
101
    // Request worker shutdown and join the thread.
    void close();

    // Thread-safe accessor for last output produced by RUN.
102
    Output get_output();
103
104
105
106
107
108
109
110

    std::string info() const;

private:
    void thread_loop();

private:
    // Worker properties
111
112
    const InfinilmModel::Config &legacy_model_config_ = InfinilmModel::Config();
    std::shared_ptr<infinilm::config::ModelConfig> model_config_;
113
114
    distributed::RankInfo rank_info_;
    std::shared_ptr<InfinilmModel> model_;
PanZezhong's avatar
PanZezhong committed
115
    std::shared_ptr<cache::Cache> cache_;
116

117
118
119
    // Backends
    backends::AttentionBackend attention_backend_;

120
121
122
123
    // Graph Compiling
    bool enable_graph_compiling_;
    std::unique_ptr<GraphCompiler> compiler_;

124
125
126
127
128
129
130
131
132
133
134
135
    // Command for the pending job (protected by mutex_)
    Command job_cmd_;

    // State flags (protected by mutex_)
    bool has_job_ = false;     // a job is pending
    bool job_done_ = false;    // last job completed
    bool should_exit_ = false; // request to stop
    bool init_done_ = false;   // initialization finished

    // Task payloads (protected by mutex)
    std::string pending_param_name_;
    infinicore::Tensor pending_param_;
136
    Input pending_args_;
PanZezhong's avatar
PanZezhong committed
137
    std::unique_ptr<cache::CacheConfig> pending_cache_config_;
138
139

    // Output (protected by mutex)
140
    Output output_;
141
142
143
144
145

    // Thread sync
    std::thread thread_;
    std::mutex mutex_;
    std::condition_variable cv_;
146
147
148

    // Random
    std::mt19937 rng_;
149
150

    RankBarrier *barrier_;
151
152
153
};

} // namespace infinilm::engine