// rank_worker.hpp — declares infinilm::engine::RankWorker.
#pragma once

#include "../models/model_factory.hpp"
#include "distributed/distributed.hpp"

#include <any>
#include <condition_variable>
#include <cstddef>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>

namespace infinilm::engine {

class RankWorker {
    enum class Command {
        INIT,
        LOAD,
        RUN,
Ceng's avatar
Ceng committed
20
        RESET_CACHE,
21
22
23
24
25
26
27
28
29
30
31
        STOP
    };

public:
    RankWorker(const std::any &model_config,
               const distributed::RankInfo &rank_info);

    // Submit a parameter load job and wait until the load completes on the worker thread.
    void load_param(const std::string &name,
                    const infinicore::Tensor &param);

32
33
34
    // return the parameters (i.e. weights and biases).
    std::unordered_map<std::string, infinicore::nn::Parameter> state_dict();

PanZezhong's avatar
PanZezhong committed
35
    // Submit a run (forward) job.
36
37
    void run(const std::vector<std::any> &args);

Ceng's avatar
Ceng committed
38
39
40
41
42
    // Reset the internal cache in the model (clears state between generations)
    // By default, this is synchronous (blocks until reset completes).
    // If async=true, this becomes asynchronous (unstable - use with caution).
    void reset_cache(size_t pos = 0, bool async = false);

PanZezhong's avatar
PanZezhong committed
43
44
45
    // Wait until run job completes. The result can be retrieved with get_output().
    void wait();

46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
    // Request worker shutdown and join the thread.
    void close();

    // Thread-safe accessor for last output produced by RUN.
    infinicore::Tensor get_output();

    std::string info() const;

private:
    void thread_loop();

private:
    // Worker properties
    std::any model_config_;
    distributed::RankInfo rank_info_;
    std::shared_ptr<InfinilmModel> model_;

    // Command for the pending job (protected by mutex_)
    Command job_cmd_;

    // State flags (protected by mutex_)
    bool has_job_ = false;     // a job is pending
    bool job_done_ = false;    // last job completed
    bool should_exit_ = false; // request to stop
    bool init_done_ = false;   // initialization finished

    // Task payloads (protected by mutex)
    std::string pending_param_name_;
    infinicore::Tensor pending_param_;
    std::vector<std::any> pending_args_;
Ceng's avatar
Ceng committed
76
    size_t pending_reset_pos_ = 0;
77
78
79
80
81
82
83
84
85
86
87

    // Output (protected by mutex)
    infinicore::Tensor output_;

    // Thread sync
    std::thread thread_;
    std::mutex mutex_;
    std::condition_variable cv_;
};

} // namespace infinilm::engine