// rank_worker.cpp
#include "rank_worker.hpp"

#include "../models/model_factory.hpp"

#include "infinicore/ops.hpp"

#include <iostream>
#include <spdlog/spdlog.h>
#include <sstream>
#include <stdexcept>

namespace infinilm::engine {

Jiacheng Huang's avatar
Jiacheng Huang committed
13
RankWorker::RankWorker(const InfinilmModel::Config &model_config,
14
                       const distributed::RankInfo &rank_info,
15
                       const cache::CacheConfig *cache_config,
16
                       RankBarrier *barrier,
17
                       bool enable_graph_compiling)
18
19
    : model_config_(model_config),
      rank_info_(rank_info),
20
      enable_graph_compiling_(enable_graph_compiling),
21
22
23
24
      job_cmd_(Command::INIT),
      has_job_(false),
      job_done_(false),
      should_exit_(false),
25
      init_done_(false),
26
27
      rng_(std::random_device{}()),
      barrier_(barrier) {
PanZezhong's avatar
PanZezhong committed
28
29
30
    if (cache_config != nullptr) {
        pending_cache_config_ = cache_config->unique_copy();
    }
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
    // start the thread
    thread_ = std::thread(&RankWorker::thread_loop, this);

    // Wait until the worker thread finishes initialization (model created)
    std::unique_lock<std::mutex> lk(mutex_);
    cv_.wait(lk, [&] { return init_done_; });
}

std::string RankWorker::info() const {
    std::stringstream ss;

    ss << "RankWorker{";

    // Rank related
    ss << rank_info_.to_string() << " ";

    // Flags
    ss << "| init_done: " << (init_done_ ? "true" : "false") << " ";
    ss << "| should_exit: " << (should_exit_ ? "true" : "false") << " ";
    ss << "| has_job: " << (has_job_ ? "true" : "false") << " ";
    ss << "| job_done: " << (job_done_ ? "true" : "false") << " ";

    ss << "}";

    return ss.str();
}

//------------------------------------------------------
// load_param -- synchronous (blocks until worker finishes loading)
//------------------------------------------------------
void RankWorker::load_param(const std::string &name,
                            const infinicore::Tensor &param) {
    {
        std::lock_guard<std::mutex> lock(mutex_);
        // If the worker is stopping, don't submit new jobs.
        if (should_exit_) {
            throw std::runtime_error("RankWorker is closing; cannot load_param");
        }

        pending_param_name_ = name;
        pending_param_ = param;

        job_cmd_ = Command::LOAD;
        has_job_ = true;
        job_done_ = false;
    }
    cv_.notify_all();

    // Wait for job completion
    std::unique_lock<std::mutex> lk(mutex_);
    cv_.wait(lk, [&] { return job_done_ || should_exit_; });

    if (should_exit_) {
        throw std::runtime_error("RankWorker stopped while loading parameter");
    }
}

88
89
90
91
//------------------------------------------------------
// state_dict --
//------------------------------------------------------
std::unordered_map<std::string, infinicore::nn::Parameter> RankWorker::state_dict() {
PanZezhong's avatar
PanZezhong committed
92
93
94
95
96
97
98
99
    std::unique_lock<std::mutex> lk(mutex_);
    cv_.wait(lk, [&] { return init_done_ || should_exit_; });

    if (!model_) {
        throw std::runtime_error("state_dict called before model initialization");
    }

    return model_->state_dict();
100
101
}

102
//------------------------------------------------------
PanZezhong's avatar
PanZezhong committed
103
// run -- asynchronous
104
//------------------------------------------------------
105
void RankWorker::run(const Input &args) {
PanZezhong's avatar
PanZezhong committed
106
    std::lock_guard<std::mutex> lock(mutex_);
107

PanZezhong's avatar
PanZezhong committed
108
109
    if (should_exit_) {
        throw std::runtime_error("RankWorker is closing; cannot run");
110
    }
PanZezhong's avatar
PanZezhong committed
111
112
113
114
115
116

    pending_args_ = args;
    job_cmd_ = Command::RUN;
    has_job_ = true;
    job_done_ = false;

117
    cv_.notify_all();
PanZezhong's avatar
PanZezhong committed
118
}
119

120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
//------------------------------------------------------
// compile -- asynchronous
//------------------------------------------------------
void RankWorker::compile() {
    std::lock_guard<std::mutex> lock(mutex_);
    if (should_exit_) {
        throw std::runtime_error("RankWorker is closing; cannot run");
    }

    job_cmd_ = Command::COMPILE;
    has_job_ = true;
    job_done_ = false;
    cv_.notify_all();
}

PanZezhong's avatar
PanZezhong committed
135
136
137
138
//------------------------------------------------------
// wait -- asynchronous
//------------------------------------------------------
void RankWorker::wait() {
139
140
141
142
    std::unique_lock<std::mutex> lk(mutex_);
    cv_.wait(lk, [&] { return job_done_ || should_exit_; });

    if (should_exit_) {
PanZezhong's avatar
PanZezhong committed
143
        throw std::runtime_error("RankWorker stopped during run");
144
145
146
    }
}

PanZezhong's avatar
PanZezhong committed
147
void RankWorker::reset_cache(const cache::CacheConfig *new_config) {
148
149
150
    std::lock_guard<std::mutex> lock(mutex_);
    if (should_exit_) {
        throw std::runtime_error("RankWorker is closing; cannot reset_cache");
Ceng's avatar
Ceng committed
151
    }
152
153

    // Store both the position and the new config
PanZezhong's avatar
PanZezhong committed
154
155
    pending_cache_config_ = new_config->unique_copy();
    job_cmd_ = Command::RESET_CACHE;
156
157
158
    has_job_ = true;
    job_done_ = false;
    cv_.notify_all();
Ceng's avatar
Ceng committed
159
160
}

161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
//------------------------------------------------------
// close -- request shutdown and join thread
//------------------------------------------------------
void RankWorker::close() {
    {
        std::lock_guard<std::mutex> lock(mutex_);
        should_exit_ = true;
        has_job_ = false; // don't keep old jobs pending
        job_cmd_ = Command::STOP;
    }
    cv_.notify_all();

    if (thread_.joinable()) {
        thread_.join();
    }
}

//------------------------------------------------------
// get_output (thread safe)
//------------------------------------------------------
181
RankWorker::Output RankWorker::get_output() {
182
183
184
185
186
187
188
189
190
191
192
    std::lock_guard<std::mutex> lock(mutex_);
    return output_;
}

//------------------------------------------------------
// thread_loop
//------------------------------------------------------
void RankWorker::thread_loop() {
    try {
        {
            std::lock_guard<std::mutex> lk(mutex_);
PanZezhong's avatar
PanZezhong committed
193
194
195
196
197
198
199
200
201

            // Initialize device & model outside of holding the main mutex to avoid blocking callers.
            infinicore::context::setDevice(rank_info_.device);

            // Create model using factory (may be expensive)
            model_ = InfinilmModelFactory::createModel(model_config_, rank_info_, pending_cache_config_ != nullptr ? pending_cache_config_.get() : nullptr);
            if (!model_) {
                throw std::runtime_error("Failed to create model");
            }
202
            if (enable_graph_compiling_) {
203
                compiler_ = std::make_unique<GeneralCompiler>(model_, barrier_);
204
205
            }

206
207
208
209
210
211
212
213
214
            init_done_ = true;
        }
        cv_.notify_all();

        // Main loop: wait for jobs or exit
        while (true) {
            Command local_cmd = Command::INIT;
            std::string local_param_name;
            infinicore::Tensor local_param;
PanZezhong's avatar
PanZezhong committed
215
            Input local_args;
PanZezhong's avatar
PanZezhong committed
216
            std::unique_ptr<cache::CacheConfig> local_cache_config;
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232

            // Wait for a job or exit
            {
                std::unique_lock<std::mutex> lk(mutex_);
                cv_.wait(lk, [&] { return has_job_ || should_exit_; });

                if (should_exit_) {
                    break;
                }

                // capture job data and clear has_job_
                local_cmd = job_cmd_;
                if (local_cmd == Command::LOAD) {
                    local_param_name = pending_param_name_;
                    local_param = pending_param_;
                } else if (local_cmd == Command::RUN) {
PanZezhong's avatar
PanZezhong committed
233
                    local_args = pending_args_;
Ceng's avatar
Ceng committed
234
                } else if (local_cmd == Command::RESET_CACHE) {
PanZezhong's avatar
PanZezhong committed
235
236
237
                    if (pending_cache_config_ != nullptr) {
                        local_cache_config = pending_cache_config_->unique_copy();
                    }
238
239
240
241
242
243
244
245
246
247
248
                }
                // mark job as being processed
                has_job_ = false;
                job_done_ = false;
            } // unlock mutex while executing the job

            // Execute job outside the lock
            if (local_cmd == Command::LOAD) {
                try {
                    model_->load_parameter(local_param_name, local_param);
                } catch (const std::exception &e) {
249
250
251
252
253
                    {
                        std::lock_guard<std::mutex> lk(mutex_);
                        should_exit_ = true;
                        job_done_ = true;
                    }
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
                    cv_.notify_all();
                    spdlog::error("[{}] exception during load_parameter_: {}\n", info(), e.what());
                    break;
                }

                // signal completion
                {
                    std::lock_guard<std::mutex> lk(mutex_);
                    job_done_ = true;
                }
                cv_.notify_all();

            } else if (local_cmd == Command::RUN) {
                try {
                    {
                        std::lock_guard<std::mutex> lk(mutex_);
270

271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
                        infinicore::Tensor logits;
                        // Try to get compiled graph
                        if (compiler_ != nullptr) {
                            auto [graph, output] = compiler_->get_compiled(local_args.to_model_input(infinicore::Device::cpu()));
                            if (graph != nullptr && output != nullptr) {
                                graph->run();
                                logits = output->logits;
                            }
                        }
                        // Fall back to eager mode
                        if (!logits) {
                            auto model_args = local_args.to_model_input(rank_info_.device);
                            logits = model_->forward(model_args).logits;
                        }

PanZezhong's avatar
PanZezhong committed
286
                        // Random sampling (rank 0 only)
287
                        if (rank_info_.tp_rank == 0) {
PanZezhong's avatar
PanZezhong committed
288
289
290
                            auto temperature{local_args.temperature};
                            auto top_p{local_args.top_p};
                            auto top_k{local_args.top_k};
291
292
293

                            const auto &logits_shape{logits->shape()};
                            const auto &vocab_size{logits_shape[2]};
294
295
296
                            const auto &total_len{logits_shape[1]};
                            const auto &batch_size{logits_shape[0]};

PanZezhong's avatar
PanZezhong committed
297
298
                            auto n_req = local_args.input_offsets.value()->size(0) - 1;
                            int64_t *input_offsets = (int64_t *)local_args.input_offsets.value()->data();
299

300
                            auto output_ids{infinicore::Tensor::empty({n_req}, infinicore::DataType::I64, rank_info_.device)};
301

302
                            for (auto i{decltype(n_req)(0)}; i < n_req; ++i) {
PanZezhong's avatar
PanZezhong committed
303
                                auto score{logits->view({batch_size * total_len, vocab_size})->narrow({{0, size_t(input_offsets[i + 1] - 1), 1}})->view({vocab_size})};
304
                                auto out{output_ids->narrow({{0, i, 1}})->view({})};
305
                                float random_val = std::uniform_real_distribution<float>(0, 1)(rng_);
306
307
308
309
310
311
312
313
314
315
316
317
318
                                infinicore::op::random_sample_(
                                    out, score, random_val, top_p, top_k, temperature);
                            }

                            output_ids = output_ids->to(infinicore::Device::cpu());

                            infinicore::context::syncStream();

                            auto out{Output{output_ids}};

                            output_ = std::move(out);
                        }

319
320
321
322
323
                        job_done_ = true;
                    }
                    cv_.notify_all();

                } catch (const std::exception &e) {
324
325
326
327
328
                    {
                        std::lock_guard<std::mutex> lk(mutex_);
                        should_exit_ = true;
                        job_done_ = true;
                    }
329
330
331
332
                    cv_.notify_all();
                    spdlog::error("[{}] exception during forward: {}\n", info(), e.what());
                    break;
                }
Ceng's avatar
Ceng committed
333
334
            } else if (local_cmd == Command::RESET_CACHE) {
                try {
PanZezhong's avatar
PanZezhong committed
335
                    model_->reset_cache(local_cache_config != nullptr ? local_cache_config.get() : nullptr);
336
337
338
339
340
341
342
                    {
                        std::lock_guard<std::mutex> lk(mutex_);
                        job_done_ = true;
                    }
                    cv_.notify_all();

                } catch (const std::exception &e) {
343
344
345
346
347
                    {
                        std::lock_guard<std::mutex> lk(mutex_);
                        should_exit_ = true;
                        job_done_ = true;
                    }
348
349
350
351
352
353
                    cv_.notify_all();
                    spdlog::error("[{}] exception during reset_cache: {}\n", info(), e.what());
                    break;
                }
            } else if (local_cmd == Command::COMPILE) {
                try {
354
355
356
                    if (compiler_ != nullptr) {
                        compiler_->compile();
                    }
Ceng's avatar
Ceng committed
357
358
359
360
361
362
363
                    {
                        std::lock_guard<std::mutex> lk(mutex_);
                        job_done_ = true;
                    }
                    cv_.notify_all();

                } catch (const std::exception &e) {
364
365
366
367
368
                    {
                        std::lock_guard<std::mutex> lk(mutex_);
                        should_exit_ = true;
                        job_done_ = true;
                    }
Ceng's avatar
Ceng committed
369
                    cv_.notify_all();
370
                    spdlog::error("[{}] exception during compile: {}\n", info(), e.what());
Ceng's avatar
Ceng committed
371
372
                    break;
                }
373

374
375
376
377
            } else {
                // Shouldn't reach here (no-op)
            }
        } // while
378
379
380

        // Some clean up should be done before exiting the thread
        compiler_.reset();
381
382
383
384
385
386
387
388
389
390
391
392
393
    } catch (const std::exception &e) {
        // Top-level exception: ensure any waiters are woken and the thread exits cleanly.
        {
            std::lock_guard<std::mutex> lk(mutex_);
            should_exit_ = true;
            job_done_ = true;
        }
        cv_.notify_all();
        spdlog::error("[{}] fatal exception in thread_loop: {} \n", info(), e.what());
    }
}

} // namespace infinilm::engine