// rank_worker.cpp — RankWorker implementation (per-rank model worker thread)
#include "rank_worker.hpp"

#include "../models/model_factory.hpp"

#include "infinicore/ops.hpp"

#include <iostream>
#include <spdlog/spdlog.h>
#include <stdexcept>

namespace infinilm::engine {

13
14
15
16
17
18
19
20
21
22
23
24
/**
 * @deprecated This function is deprecated and will be REMOVED in the next major release (v0.2.0).
 *
 * ⚠️ DEVELOPMENT POLICY:
 *   - NO new development or feature additions permitted on this interface
 *   - Only critical bug fixes (security/stability) allowed until removal
 *   - All new code MUST migrate to the polymorphic overload below
 *
 * Replacement: Use the polymorphic overload of this same function name with updated signature
 * Reason: Legacy signature lacks support for dynamic quantization modes.
 * Removal target: v0.2.0 (Q2 2026)
 */
Jiacheng Huang's avatar
Jiacheng Huang committed
25
RankWorker::RankWorker(const InfinilmModel::Config &model_config,
26
                       const distributed::RankInfo &rank_info,
27
                       const cache::CacheConfig *cache_config,
28
                       RankBarrier *barrier,
29
                       bool enable_graph_compiling)
30
    : legacy_model_config_(model_config),
31
      rank_info_(rank_info),
32
      enable_graph_compiling_(enable_graph_compiling),
33
34
35
36
      job_cmd_(Command::INIT),
      has_job_(false),
      job_done_(false),
      should_exit_(false),
37
      init_done_(false),
38
39
      rng_(std::random_device{}()),
      barrier_(barrier) {
PanZezhong's avatar
PanZezhong committed
40
41
42
    if (cache_config != nullptr) {
        pending_cache_config_ = cache_config->unique_copy();
    }
43
44
45
46
47
48
49
50
    // start the thread
    thread_ = std::thread(&RankWorker::thread_loop, this);

    // Wait until the worker thread finishes initialization (model created)
    std::unique_lock<std::mutex> lk(mutex_);
    cv_.wait(lk, [&] { return init_done_; });
}

51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
RankWorker::RankWorker(
    std::shared_ptr<infinilm::config::ModelConfig> model_config,
    const distributed::RankInfo &rank_info,
    const cache::CacheConfig *cache_config,
    RankBarrier *barrier,
    bool enable_graph_compiling)
    : model_config_(model_config),
      rank_info_(rank_info),
      enable_graph_compiling_(enable_graph_compiling),
      job_cmd_(Command::INIT),
      has_job_(false),
      job_done_(false),
      should_exit_(false),
      init_done_(false),
      rng_(std::random_device{}()),
      barrier_(barrier) {
    if (cache_config != nullptr) {
        pending_cache_config_ = cache_config->unique_copy();
    }
    // start the thread
    thread_ = std::thread(&RankWorker::thread_loop, this);
    // Wait until the worker thread finishes initialization (model created)
    std::unique_lock<std::mutex> lk(mutex_);
    cv_.wait(lk, [&] { return init_done_; });
}

77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
std::string RankWorker::info() const {
    std::stringstream ss;

    ss << "RankWorker{";

    // Rank related
    ss << rank_info_.to_string() << " ";

    // Flags
    ss << "| init_done: " << (init_done_ ? "true" : "false") << " ";
    ss << "| should_exit: " << (should_exit_ ? "true" : "false") << " ";
    ss << "| has_job: " << (has_job_ ? "true" : "false") << " ";
    ss << "| job_done: " << (job_done_ ? "true" : "false") << " ";

    ss << "}";

    return ss.str();
}

//------------------------------------------------------
// load_param -- synchronous (blocks until worker finishes loading)
//------------------------------------------------------
void RankWorker::load_param(const std::string &name,
                            const infinicore::Tensor &param) {
    std::unique_lock<std::mutex> lk(mutex_);

    // Refuse new work once shutdown has begun.
    if (should_exit_) {
        throw std::runtime_error("RankWorker is closing; cannot load_param");
    }

    // Publish the LOAD job for the worker thread to pick up.
    pending_param_name_ = name;
    pending_param_ = param;
    job_cmd_ = Command::LOAD;
    has_job_ = true;
    job_done_ = false;

    lk.unlock();
    cv_.notify_all();
    lk.lock();

    // Block until the worker reports completion (or dies mid-job).
    cv_.wait(lk, [&] { return job_done_ || should_exit_; });
    if (should_exit_) {
        throw std::runtime_error("RankWorker stopped while loading parameter");
    }
}

126
127
128
129
//------------------------------------------------------
// state_dict --
//------------------------------------------------------
std::unordered_map<std::string, infinicore::nn::Parameter> RankWorker::state_dict() {
PanZezhong's avatar
PanZezhong committed
130
131
132
133
134
135
136
137
    std::unique_lock<std::mutex> lk(mutex_);
    cv_.wait(lk, [&] { return init_done_ || should_exit_; });

    if (!model_) {
        throw std::runtime_error("state_dict called before model initialization");
    }

    return model_->state_dict();
138
139
}

140
//------------------------------------------------------
PanZezhong's avatar
PanZezhong committed
141
// run -- asynchronous
142
//------------------------------------------------------
143
void RankWorker::run(const Input &args) {
PanZezhong's avatar
PanZezhong committed
144
    std::lock_guard<std::mutex> lock(mutex_);
145

PanZezhong's avatar
PanZezhong committed
146
147
    if (should_exit_) {
        throw std::runtime_error("RankWorker is closing; cannot run");
148
    }
PanZezhong's avatar
PanZezhong committed
149
150
151
152
153
154

    pending_args_ = args;
    job_cmd_ = Command::RUN;
    has_job_ = true;
    job_done_ = false;

155
    cv_.notify_all();
PanZezhong's avatar
PanZezhong committed
156
}
157

158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
//------------------------------------------------------
// compile -- asynchronous
//------------------------------------------------------
void RankWorker::compile() {
    std::lock_guard<std::mutex> lock(mutex_);
    if (should_exit_) {
        throw std::runtime_error("RankWorker is closing; cannot run");
    }

    job_cmd_ = Command::COMPILE;
    has_job_ = true;
    job_done_ = false;
    cv_.notify_all();
}

PanZezhong's avatar
PanZezhong committed
173
174
175
176
//------------------------------------------------------
// wait -- asynchronous
//------------------------------------------------------
void RankWorker::wait() {
177
178
179
180
    std::unique_lock<std::mutex> lk(mutex_);
    cv_.wait(lk, [&] { return job_done_ || should_exit_; });

    if (should_exit_) {
PanZezhong's avatar
PanZezhong committed
181
        throw std::runtime_error("RankWorker stopped during run");
182
183
184
    }
}

PanZezhong's avatar
PanZezhong committed
185
void RankWorker::reset_cache(const cache::CacheConfig *new_config) {
186
187
188
    std::lock_guard<std::mutex> lock(mutex_);
    if (should_exit_) {
        throw std::runtime_error("RankWorker is closing; cannot reset_cache");
Ceng's avatar
Ceng committed
189
    }
190
191

    // Store both the position and the new config
PanZezhong's avatar
PanZezhong committed
192
193
    pending_cache_config_ = new_config->unique_copy();
    job_cmd_ = Command::RESET_CACHE;
194
195
196
    has_job_ = true;
    job_done_ = false;
    cv_.notify_all();
Ceng's avatar
Ceng committed
197
198
}

199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
//------------------------------------------------------
// close -- request shutdown and join thread
//------------------------------------------------------
void RankWorker::close() {
    {
        std::lock_guard<std::mutex> lock(mutex_);
        should_exit_ = true;
        has_job_ = false; // don't keep old jobs pending
        job_cmd_ = Command::STOP;
    }
    cv_.notify_all();

    if (thread_.joinable()) {
        thread_.join();
    }
}

//------------------------------------------------------
// get_output (thread safe)
//------------------------------------------------------
219
RankWorker::Output RankWorker::get_output() {
220
221
222
223
224
225
226
227
228
229
230
    std::lock_guard<std::mutex> lock(mutex_);
    return output_;
}

//------------------------------------------------------
// thread_loop
//------------------------------------------------------
void RankWorker::thread_loop() {
    try {
        {
            std::lock_guard<std::mutex> lk(mutex_);
PanZezhong's avatar
PanZezhong committed
231
232
233
234
235

            // Initialize device & model outside of holding the main mutex to avoid blocking callers.
            infinicore::context::setDevice(rank_info_.device);

            // Create model using factory (may be expensive)
236
237
238
239
240
241
242
            if (model_config_ == nullptr) {
                model_ = InfinilmModelFactory::createModel(legacy_model_config_, rank_info_, pending_cache_config_ != nullptr ? pending_cache_config_.get() : nullptr);

            } else {
                model_ = InfinilmModelFactory::createModel(model_config_, rank_info_, pending_cache_config_ != nullptr ? pending_cache_config_.get() : nullptr);
            }

PanZezhong's avatar
PanZezhong committed
243
244
245
            if (!model_) {
                throw std::runtime_error("Failed to create model");
            }
246
            if (enable_graph_compiling_) {
247
                compiler_ = std::make_unique<GeneralCompiler>(model_, barrier_);
248
249
            }

250
251
252
253
254
255
256
257
258
            init_done_ = true;
        }
        cv_.notify_all();

        // Main loop: wait for jobs or exit
        while (true) {
            Command local_cmd = Command::INIT;
            std::string local_param_name;
            infinicore::Tensor local_param;
PanZezhong's avatar
PanZezhong committed
259
            Input local_args;
PanZezhong's avatar
PanZezhong committed
260
            std::unique_ptr<cache::CacheConfig> local_cache_config;
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276

            // Wait for a job or exit
            {
                std::unique_lock<std::mutex> lk(mutex_);
                cv_.wait(lk, [&] { return has_job_ || should_exit_; });

                if (should_exit_) {
                    break;
                }

                // capture job data and clear has_job_
                local_cmd = job_cmd_;
                if (local_cmd == Command::LOAD) {
                    local_param_name = pending_param_name_;
                    local_param = pending_param_;
                } else if (local_cmd == Command::RUN) {
PanZezhong's avatar
PanZezhong committed
277
                    local_args = pending_args_;
Ceng's avatar
Ceng committed
278
                } else if (local_cmd == Command::RESET_CACHE) {
PanZezhong's avatar
PanZezhong committed
279
280
281
                    if (pending_cache_config_ != nullptr) {
                        local_cache_config = pending_cache_config_->unique_copy();
                    }
282
283
284
285
286
287
288
289
290
291
292
                }
                // mark job as being processed
                has_job_ = false;
                job_done_ = false;
            } // unlock mutex while executing the job

            // Execute job outside the lock
            if (local_cmd == Command::LOAD) {
                try {
                    model_->load_parameter(local_param_name, local_param);
                } catch (const std::exception &e) {
293
294
295
296
297
                    {
                        std::lock_guard<std::mutex> lk(mutex_);
                        should_exit_ = true;
                        job_done_ = true;
                    }
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
                    cv_.notify_all();
                    spdlog::error("[{}] exception during load_parameter_: {}\n", info(), e.what());
                    break;
                }

                // signal completion
                {
                    std::lock_guard<std::mutex> lk(mutex_);
                    job_done_ = true;
                }
                cv_.notify_all();

            } else if (local_cmd == Command::RUN) {
                try {
                    {
                        std::lock_guard<std::mutex> lk(mutex_);
314

315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
                        infinicore::Tensor logits;
                        // Try to get compiled graph
                        if (compiler_ != nullptr) {
                            auto [graph, output] = compiler_->get_compiled(local_args.to_model_input(infinicore::Device::cpu()));
                            if (graph != nullptr && output != nullptr) {
                                graph->run();
                                logits = output->logits;
                            }
                        }
                        // Fall back to eager mode
                        if (!logits) {
                            auto model_args = local_args.to_model_input(rank_info_.device);
                            logits = model_->forward(model_args).logits;
                        }

PanZezhong's avatar
PanZezhong committed
330
                        // Random sampling (rank 0 only)
331
                        if (rank_info_.tp_rank == 0) {
PanZezhong's avatar
PanZezhong committed
332
333
334
                            auto temperature{local_args.temperature};
                            auto top_p{local_args.top_p};
                            auto top_k{local_args.top_k};
335
336
337

                            const auto &logits_shape{logits->shape()};
                            const auto &vocab_size{logits_shape[2]};
338
339
340
                            const auto &total_len{logits_shape[1]};
                            const auto &batch_size{logits_shape[0]};

PanZezhong's avatar
PanZezhong committed
341
                            auto n_req = local_args.input_offsets.value()->size(0) - 1;
342
                            int32_t *input_offsets = (int32_t *)local_args.input_offsets.value()->data();
343

344
                            auto output_ids{infinicore::Tensor::empty({n_req}, infinicore::DataType::I64, rank_info_.device)};
345

346
                            for (auto i{decltype(n_req)(0)}; i < n_req; ++i) {
PanZezhong's avatar
PanZezhong committed
347
                                auto score{logits->view({batch_size * total_len, vocab_size})->narrow({{0, size_t(input_offsets[i + 1] - 1), 1}})->view({vocab_size})};
348
                                auto out{output_ids->narrow({{0, i, 1}})->view({})};
349
                                float random_val = std::uniform_real_distribution<float>(0, 1)(rng_);
350
351
352
353
354
355
356
357
358
359
360
361
362
                                infinicore::op::random_sample_(
                                    out, score, random_val, top_p, top_k, temperature);
                            }

                            output_ids = output_ids->to(infinicore::Device::cpu());

                            infinicore::context::syncStream();

                            auto out{Output{output_ids}};

                            output_ = std::move(out);
                        }

363
364
365
366
367
                        job_done_ = true;
                    }
                    cv_.notify_all();

                } catch (const std::exception &e) {
368
369
370
371
372
                    {
                        std::lock_guard<std::mutex> lk(mutex_);
                        should_exit_ = true;
                        job_done_ = true;
                    }
373
374
375
376
                    cv_.notify_all();
                    spdlog::error("[{}] exception during forward: {}\n", info(), e.what());
                    break;
                }
Ceng's avatar
Ceng committed
377
378
            } else if (local_cmd == Command::RESET_CACHE) {
                try {
PanZezhong's avatar
PanZezhong committed
379
                    model_->reset_cache(local_cache_config != nullptr ? local_cache_config.get() : nullptr);
380
381
382
383
384
385
386
                    {
                        std::lock_guard<std::mutex> lk(mutex_);
                        job_done_ = true;
                    }
                    cv_.notify_all();

                } catch (const std::exception &e) {
387
388
389
390
391
                    {
                        std::lock_guard<std::mutex> lk(mutex_);
                        should_exit_ = true;
                        job_done_ = true;
                    }
392
393
394
395
396
397
                    cv_.notify_all();
                    spdlog::error("[{}] exception during reset_cache: {}\n", info(), e.what());
                    break;
                }
            } else if (local_cmd == Command::COMPILE) {
                try {
398
399
400
                    if (compiler_ != nullptr) {
                        compiler_->compile();
                    }
Ceng's avatar
Ceng committed
401
402
403
404
405
406
407
                    {
                        std::lock_guard<std::mutex> lk(mutex_);
                        job_done_ = true;
                    }
                    cv_.notify_all();

                } catch (const std::exception &e) {
408
409
410
411
412
                    {
                        std::lock_guard<std::mutex> lk(mutex_);
                        should_exit_ = true;
                        job_done_ = true;
                    }
Ceng's avatar
Ceng committed
413
                    cv_.notify_all();
414
                    spdlog::error("[{}] exception during compile: {}\n", info(), e.what());
Ceng's avatar
Ceng committed
415
416
                    break;
                }
417

418
419
420
421
            } else {
                // Shouldn't reach here (no-op)
            }
        } // while
422
423
424

        // Some clean up should be done before exiting the thread
        compiler_.reset();
425
426
427
428
429
430
431
432
433
434
435
436
437
    } catch (const std::exception &e) {
        // Top-level exception: ensure any waiters are woken and the thread exits cleanly.
        {
            std::lock_guard<std::mutex> lk(mutex_);
            should_exit_ = true;
            job_done_ = true;
        }
        cv_.notify_all();
        spdlog::error("[{}] fatal exception in thread_loop: {} \n", info(), e.what());
    }
}

} // namespace infinilm::engine