weights_loader.hpp 1.83 KB
Newer Older
blkmjsian's avatar
blkmjsian committed
1
2
3
4
5
6
7
8
9
#ifndef WEIGHTS_LOADER_HPP
#define WEIGHTS_LOADER_HPP

#include "../tensor.hpp"

#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

namespace infinicore {
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36

namespace weights {
/// How a weight is laid out across ranks: kept whole (FULL), or distributed
/// by ROW or by COLUMN.
/// NOTE(review): the exact partitioning performed for ROW/COLUMN is decided by
/// the loading implementation, which is not visible in this header.
enum DistributionType {
    FULL = 0,   ///< not distributed
    ROW = 1,    ///< distributed by row
    COLUMN = 2, ///< distributed by column
};
/// @brief A weight tensor together with the placement it was registered with.
///
/// Holds a shared handle to the destination tensor, the rank it belongs to,
/// the total number of ranks (`nrank`), and its DistributionType.
/// NOTE(review): how rank/nrank select a slice for ROW/COLUMN weights is
/// decided by load()'s implementation, which is not visible here.
class Weight {
private:
    std::shared_ptr<Tensor> _tensor;
    int _rank;
    int _nrank;
    DistributionType _dist_type;

public:
    /// @param tensor    destination tensor handle
    /// @param rank      rank this weight belongs to (default 0)
    /// @param nrank     total number of ranks (default 1)
    /// @param dist_type FULL, or distributed by ROW or COLUMN (default FULL)
    Weight(std::shared_ptr<Tensor> tensor,
           int rank = 0,
           int nrank = 1,
           DistributionType dist_type = DistributionType::FULL)
        // `tensor` is already a by-value copy; move it into place rather than
        // paying for a second atomic reference-count increment.
        : _tensor(std::move(tensor)), _rank(rank), _nrank(nrank), _dist_type(dist_type) {}

    /// @return the tensor handle given at construction.
    std::shared_ptr<Tensor> tensor() const { return _tensor; }
    /// @return the rank given at construction.
    int rank() const { return _rank; }
    /// @return the total number of ranks given at construction.
    int nrank() const { return _nrank; }
    /// @return the distribution type given at construction.
    DistributionType dist_type() const { return _dist_type; }

    /// @brief Fill the tensor from host memory.
    /// @param host_data source buffer in host memory
    /// @param stream    runtime stream to use (default nullptr)
    /// NOTE(review): presumably copies this rank's portion of `host_data`
    /// according to the distribution type; implementation not visible here.
    void load(const void *host_data, infinirtStream_t stream = nullptr);
};

class Loader {
blkmjsian's avatar
blkmjsian committed
37
protected:
38
    std::vector<std::unordered_map<std::string, std::shared_ptr<Weight>>> _weights_maps;
blkmjsian's avatar
blkmjsian committed
39
40
41
42
43
    infiniDevice_t _device;
    std::vector<int> _dev_ids;
    std::vector<infinirtStream_t> _streams;

public:
44
45
46
47
48
49
50
    Loader(infiniDevice_t, const std::vector<int> &dev_ids);

    /// @brief register a tensor to the loader
    /// @param name name (aka key) of the tensor
    /// @param tensor
    /// @param rank the rank of the weight tensor (default 0)
    /// @param dist_type either FULL, or distributed by ROW or COLUMN (default FULL)
PanZezhong1725's avatar
PanZezhong1725 committed
51
    void register_weight(const std::string &name, std::shared_ptr<Tensor> tensor, int rank = 0, DistributionType dist_type = DistributionType::FULL);
52
    void load(const std::string &name, const void *host_data);
blkmjsian's avatar
blkmjsian committed
53
54
    void finalize();
    std::shared_ptr<Tensor> get(const std::string &name, int rank = 0);
55
    const std::vector<int> &devIds() const { return _dev_ids; }
blkmjsian's avatar
blkmjsian committed
56
57
    infiniDevice_t device() const { return _device; }
};
58
} // namespace weights
blkmjsian's avatar
blkmjsian committed
59
60
61
} // namespace infinicore

#endif // WEIGHTS_LOADER_HPP