// rank_worker.cpp -- per-rank worker thread for the InfiniLM engine.
// Each RankWorker owns one device-bound thread that creates the model,
// then serves LOAD / RUN / RESET_CACHE jobs submitted by the engine.
#include "rank_worker.hpp"

#include "../models/model_factory.hpp"

#include "infinicore/ops.hpp"

#include <spdlog/spdlog.h>

#include <iostream>
#include <memory>
#include <sstream>
#include <stdexcept>
namespace infinilm::engine {

//------------------------------------------------------
// Constructor -- spawns the worker thread and blocks until model init
//------------------------------------------------------
// Construct a per-rank worker: record the model/rank configuration, spawn
// the worker thread, and block until that thread has created the model
// (or failed trying).
//
// @param model_config model hyper-parameters forwarded to the factory
// @param rank_info    device / tensor-parallel rank of this worker
// @param cache_config optional initial KV-cache config; nullptr lets the
//                     model pick its own default
// @throws std::runtime_error if the worker thread fails to initialize
RankWorker::RankWorker(const InfinilmModel::Config &model_config,
                       const distributed::RankInfo &rank_info,
                       const cache::CacheConfig *cache_config)
    : model_config_(model_config),
      rank_info_(rank_info),
      job_cmd_(Command::INIT),
      has_job_(false),
      job_done_(false),
      should_exit_(false),
      init_done_(false) {
    if (cache_config != nullptr) {
        pending_cache_config_ = cache_config->unique_copy();
    }

    // start the thread
    thread_ = std::thread(&RankWorker::thread_loop, this);

    // Wait until the worker thread finishes initialization (model created).
    // Also wake on should_exit_: if thread_loop throws during model creation
    // its top-level handler sets should_exit_ (never init_done_), so waiting
    // on init_done_ alone would deadlock the constructor forever.
    std::unique_lock<std::mutex> lk(mutex_);
    cv_.wait(lk, [&] { return init_done_ || should_exit_; });
    if (!init_done_) {
        // Initialization failed: reap the thread and report the failure.
        lk.unlock();
        if (thread_.joinable()) {
            thread_.join();
        }
        throw std::runtime_error("RankWorker failed to initialize model");
    }
}

// Render a one-line human-readable snapshot of this worker's state.
// NOTE(review): reads the flags without taking mutex_ — apparently deliberate,
// since info() is invoked from error paths that already hold the lock (a lock
// here would self-deadlock). Values may therefore be slightly stale.
std::string RankWorker::info() const {
    std::stringstream out;

    out << "RankWorker{"
        // Rank related
        << rank_info_.to_string() << " "
        // Flags
        << "| init_done: " << (init_done_ ? "true" : "false") << " "
        << "| should_exit: " << (should_exit_ ? "true" : "false") << " "
        << "| has_job: " << (has_job_ ? "true" : "false") << " "
        << "| job_done: " << (job_done_ ? "true" : "false") << " "
        << "}";

    return out.str();
}

//------------------------------------------------------
// load_param -- synchronous (blocks until worker finishes loading)
//------------------------------------------------------
// Submit one named parameter tensor to the worker thread and block until the
// worker has finished loading it into the model.
//
// @param name  parameter name understood by model_->load_parameter
// @param param tensor payload to load
// @throws std::runtime_error if the worker is shutting down before or while
//         the parameter is being loaded
void RankWorker::load_param(const std::string &name,
                            const infinicore::Tensor &param) {
    {
        std::lock_guard<std::mutex> lock(mutex_);
        // If the worker is stopping, don't submit new jobs.
        if (should_exit_) {
            throw std::runtime_error("RankWorker is closing; cannot load_param");
        }

        // Stage the job payload for thread_loop to pick up.
        pending_param_name_ = name;
        pending_param_ = param;

        job_cmd_ = Command::LOAD;
        has_job_ = true;
        job_done_ = false;
    }
    cv_.notify_all();

    // Wait for job completion
    std::unique_lock<std::mutex> lk(mutex_);
    cv_.wait(lk, [&] { return job_done_ || should_exit_; });

    // should_exit_ here means the worker died mid-load (see thread_loop's
    // LOAD catch block) rather than completing the job.
    if (should_exit_) {
        throw std::runtime_error("RankWorker stopped while loading parameter");
    }
}

//------------------------------------------------------
// state_dict --
//------------------------------------------------------
std::unordered_map<std::string, infinicore::nn::Parameter> RankWorker::state_dict() {
PanZezhong's avatar
PanZezhong committed
87
88
89
90
91
92
93
94
    std::unique_lock<std::mutex> lk(mutex_);
    cv_.wait(lk, [&] { return init_done_ || should_exit_; });

    if (!model_) {
        throw std::runtime_error("state_dict called before model initialization");
    }

    return model_->state_dict();
95
96
}

//------------------------------------------------------
// run -- asynchronous
//------------------------------------------------------
void RankWorker::run(const Input &args) {
PanZezhong's avatar
PanZezhong committed
101
    std::lock_guard<std::mutex> lock(mutex_);
102

PanZezhong's avatar
PanZezhong committed
103
104
    if (should_exit_) {
        throw std::runtime_error("RankWorker is closing; cannot run");
105
    }
PanZezhong's avatar
PanZezhong committed
106
107
108
109
110
111

    pending_args_ = args;
    job_cmd_ = Command::RUN;
    has_job_ = true;
    job_done_ = false;

112
    cv_.notify_all();
PanZezhong's avatar
PanZezhong committed
113
}

//------------------------------------------------------
// wait -- synchronous (blocks until the current job completes)
//------------------------------------------------------
void RankWorker::wait() {
119
120
121
122
    std::unique_lock<std::mutex> lk(mutex_);
    cv_.wait(lk, [&] { return job_done_ || should_exit_; });

    if (should_exit_) {
PanZezhong's avatar
PanZezhong committed
123
        throw std::runtime_error("RankWorker stopped during run");
124
125
126
    }
}

//------------------------------------------------------
// reset_cache -- asynchronous (pair with wait())
//------------------------------------------------------
// Ask the worker thread to reset the model's KV cache. Asynchronous; call
// wait() to block until the reset has been applied.
//
// @param new_config replacement cache configuration; nullptr means "reset
//                   with the model's own defaults" (thread_loop forwards a
//                   null config straight to model_->reset_cache)
// @throws std::runtime_error if the worker is shutting down
void RankWorker::reset_cache(const cache::CacheConfig *new_config) {
    std::lock_guard<std::mutex> lock(mutex_);
    if (should_exit_) {
        throw std::runtime_error("RankWorker is closing; cannot reset_cache");
    }

    // Store the new config. Guard the dereference: nullptr is a legal value
    // here (the constructor and thread_loop both treat a null config as
    // "use model defaults"), so don't crash on it.
    pending_cache_config_ = (new_config != nullptr) ? new_config->unique_copy() : nullptr;
    job_cmd_ = Command::RESET_CACHE;
    has_job_ = true;
    job_done_ = false;
    cv_.notify_all();
}

//------------------------------------------------------
// close -- request shutdown and join thread
//------------------------------------------------------
// Request shutdown: wake the worker thread, discard any not-yet-started job,
// and join the thread. Safe to call when the thread has already exited.
void RankWorker::close() {
    {
        std::lock_guard<std::mutex> guard(mutex_);
        should_exit_ = true;
        // Drop any job that has not started yet.
        has_job_ = false;
        job_cmd_ = Command::STOP;
    }
    cv_.notify_all();

    if (thread_.joinable()) {
        thread_.join();
    }
}

//------------------------------------------------------
// get_output (thread safe)
//------------------------------------------------------
// Return a copy of the most recent sampling output. Thread safe; call after
// wait() to read the result of the last RUN job. Only tp_rank 0 populates
// output_ (see thread_loop), so other ranks return whatever was last stored.
RankWorker::Output RankWorker::get_output() {
    std::lock_guard<std::mutex> lock(mutex_);
    return output_;
}

//------------------------------------------------------
// thread_loop
//------------------------------------------------------
void RankWorker::thread_loop() {
    try {
        {
            std::lock_guard<std::mutex> lk(mutex_);
PanZezhong's avatar
PanZezhong committed
173
174
175
176
177
178
179
180
181

            // Initialize device & model outside of holding the main mutex to avoid blocking callers.
            infinicore::context::setDevice(rank_info_.device);

            // Create model using factory (may be expensive)
            model_ = InfinilmModelFactory::createModel(model_config_, rank_info_, pending_cache_config_ != nullptr ? pending_cache_config_.get() : nullptr);
            if (!model_) {
                throw std::runtime_error("Failed to create model");
            }
182
183
184
185
186
187
188
189
190
            init_done_ = true;
        }
        cv_.notify_all();

        // Main loop: wait for jobs or exit
        while (true) {
            Command local_cmd = Command::INIT;
            std::string local_param_name;
            infinicore::Tensor local_param;
191
            InfinilmModel::Input local_args;
PanZezhong's avatar
PanZezhong committed
192
            std::unique_ptr<cache::CacheConfig> local_cache_config;
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208

            // Wait for a job or exit
            {
                std::unique_lock<std::mutex> lk(mutex_);
                cv_.wait(lk, [&] { return has_job_ || should_exit_; });

                if (should_exit_) {
                    break;
                }

                // capture job data and clear has_job_
                local_cmd = job_cmd_;
                if (local_cmd == Command::LOAD) {
                    local_param_name = pending_param_name_;
                    local_param = pending_param_;
                } else if (local_cmd == Command::RUN) {
209
                    local_args = pending_args_.to_model_input(rank_info_.device);
Ceng's avatar
Ceng committed
210
                } else if (local_cmd == Command::RESET_CACHE) {
PanZezhong's avatar
PanZezhong committed
211
212
213
                    if (pending_cache_config_ != nullptr) {
                        local_cache_config = pending_cache_config_->unique_copy();
                    }
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
                }
                // mark job as being processed
                has_job_ = false;
                job_done_ = false;
            } // unlock mutex while executing the job

            // Execute job outside the lock
            if (local_cmd == Command::LOAD) {
                try {
                    model_->load_parameter(local_param_name, local_param);
                } catch (const std::exception &e) {
                    // convert exceptions to a safe behavior: set should_exit_ and notify caller
                    std::lock_guard<std::mutex> lk(mutex_);
                    should_exit_ = true;
                    job_done_ = true;
                    cv_.notify_all();
                    // rethrow so the thread can be joined and caller sees an error if desired (optional)
                    spdlog::error("[{}] exception during load_parameter_: {}\n", info(), e.what());
                    break;
                }

                // signal completion
                {
                    std::lock_guard<std::mutex> lk(mutex_);
                    job_done_ = true;
                }
                cv_.notify_all();

            } else if (local_cmd == Command::RUN) {
                try {
                    {
                        std::lock_guard<std::mutex> lk(mutex_);
246
247
248
249
250
251
252
253
254
255
256
257

                        auto logits{model_->forward(local_args).logits};

                        if (rank_info_.tp_rank == 0) {
                            // Perform random sampling.
                            auto temperature{pending_args_.temperature};
                            auto top_p{pending_args_.top_p};
                            auto top_k{pending_args_.top_k};
                            auto random_val{pending_args_.random_val};

                            const auto &logits_shape{logits->shape()};
                            const auto &vocab_size{logits_shape[2]};
258
259
260
261
262
263
                            const auto &total_len{logits_shape[1]};
                            const auto &batch_size{logits_shape[0]};

                            auto n_req = pending_args_.input_offsets.value()->size(0);
                            int64_t *input_lengths = (int64_t *)pending_args_.input_lengths.value()->data();
                            int64_t *input_offsets = (int64_t *)pending_args_.input_offsets.value()->data();
264

265
                            auto output_ids{infinicore::Tensor::empty({n_req}, infinicore::DataType::I64, rank_info_.device)};
266

267
268
                            for (auto i{decltype(n_req)(0)}; i < n_req; ++i) {
                                auto score{logits->view({batch_size * total_len, vocab_size})->narrow({{0, size_t(input_offsets[i] + input_lengths[i] - 1), 1}})->view({vocab_size})};
269
270
271
272
273
274
275
276
277
278
279
280
281
282
                                auto out{output_ids->narrow({{0, i, 1}})->view({})};
                                infinicore::op::random_sample_(
                                    out, score, random_val, top_p, top_k, temperature);
                            }

                            output_ids = output_ids->to(infinicore::Device::cpu());

                            infinicore::context::syncStream();

                            auto out{Output{output_ids}};

                            output_ = std::move(out);
                        }

283
284
285
286
287
288
289
290
291
292
293
294
                        job_done_ = true;
                    }
                    cv_.notify_all();

                } catch (const std::exception &e) {
                    std::lock_guard<std::mutex> lk(mutex_);
                    should_exit_ = true;
                    job_done_ = true;
                    cv_.notify_all();
                    spdlog::error("[{}] exception during forward: {}\n", info(), e.what());
                    break;
                }
Ceng's avatar
Ceng committed
295
296
            } else if (local_cmd == Command::RESET_CACHE) {
                try {
PanZezhong's avatar
PanZezhong committed
297
                    model_->reset_cache(local_cache_config != nullptr ? local_cache_config.get() : nullptr);
298

Ceng's avatar
Ceng committed
299
300
301
302
303
304
305
306
307
308
309
310
311
312
                    {
                        std::lock_guard<std::mutex> lk(mutex_);
                        job_done_ = true;
                    }
                    cv_.notify_all();

                } catch (const std::exception &e) {
                    std::lock_guard<std::mutex> lk(mutex_);
                    should_exit_ = true;
                    job_done_ = true;
                    cv_.notify_all();
                    spdlog::error("[{}] exception during reset_cache: {}\n", info(), e.what());
                    break;
                }
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
            } else {
                // Shouldn't reach here (no-op)
            }
        } // while
    } catch (const std::exception &e) {
        // Top-level exception: ensure any waiters are woken and the thread exits cleanly.
        {
            std::lock_guard<std::mutex> lk(mutex_);
            should_exit_ = true;
            job_done_ = true;
        }
        cv_.notify_all();
        spdlog::error("[{}] fatal exception in thread_loop: {} \n", info(), e.what());
    }
}

} // namespace infinilm::engine