// rank_worker.hpp
#pragma once

#include "../cache/cache.hpp"
#include "../models/model_factory.hpp"
#include "distributed/distributed.hpp"

#include <any>
#include <condition_variable>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>

namespace infinilm::engine {

class RankWorker {
    enum class Command {
        INIT,
        LOAD,
        RUN,
Ceng's avatar
Ceng committed
21
        RESET_CACHE,
22
23
24
25
        STOP
    };

public:
Jiacheng Huang's avatar
Jiacheng Huang committed
26
    RankWorker(const InfinilmModel::Config &model_config,
27
               const distributed::RankInfo &rank_info,
PanZezhong's avatar
PanZezhong committed
28
               const cache::CacheConfig *cache_config);
29
30
31
32
33

    // Submit a parameter load job and wait until the load completes on the worker thread.
    void load_param(const std::string &name,
                    const infinicore::Tensor &param);

34
35
36
    // return the parameters (i.e. weights and biases).
    std::unordered_map<std::string, infinicore::nn::Parameter> state_dict();

PanZezhong's avatar
PanZezhong committed
37
    // Submit a run (forward) job.
38
    void run(const InfinilmModel::Input &args);
39

40
    // Reset the internal cache with a new configuration
PanZezhong's avatar
PanZezhong committed
41
    void reset_cache(const cache::CacheConfig *new_config);
Ceng's avatar
Ceng committed
42

PanZezhong's avatar
PanZezhong committed
43
44
45
    // Wait until run job completes. The result can be retrieved with get_output().
    void wait();

46
47
48
49
    // Request worker shutdown and join the thread.
    void close();

    // Thread-safe accessor for last output produced by RUN.
50
    InfinilmModel::Output get_output();
51
52
53
54
55
56
57
58

    std::string info() const;

private:
    void thread_loop();

private:
    // Worker properties
Jiacheng Huang's avatar
Jiacheng Huang committed
59
    const InfinilmModel::Config &model_config_;
60
61
    distributed::RankInfo rank_info_;
    std::shared_ptr<InfinilmModel> model_;
PanZezhong's avatar
PanZezhong committed
62
    std::shared_ptr<cache::Cache> cache_;
63
64
65
66
67
68
69
70
71
72
73
74
75

    // Command for the pending job (protected by mutex_)
    Command job_cmd_;

    // State flags (protected by mutex_)
    bool has_job_ = false;     // a job is pending
    bool job_done_ = false;    // last job completed
    bool should_exit_ = false; // request to stop
    bool init_done_ = false;   // initialization finished

    // Task payloads (protected by mutex)
    std::string pending_param_name_;
    infinicore::Tensor pending_param_;
76
    InfinilmModel::Input pending_args_;
PanZezhong's avatar
PanZezhong committed
77
    std::unique_ptr<cache::CacheConfig> pending_cache_config_;
78
79

    // Output (protected by mutex)
80
    InfinilmModel::Output output_;
81
82
83
84
85
86
87
88

    // Thread sync
    std::thread thread_;
    std::mutex mutex_;
    std::condition_variable cv_;
};

} // namespace infinilm::engine