// weights_loader.cpp
#include "weights_loader.hpp"
#include "infinicore_infer/weights_loader.h"

#include "../utils.hpp"

#include <infinirt.h>

#include <cstdlib>
#include <cstring>
#include <functional>
#include <iostream>
#include <numeric>
#include <string>
namespace infinicore::weights {
/// Copy one weight's host buffer into this rank's device tensor, slicing the
/// host data according to how the weight is distributed across ranks.
///
/// @param host_data  full (undistributed) weight data in host memory
/// @param stream     stream the device copy is issued on
void Weight::load(const void *host_data, infinirtStream_t stream) {
    if (_dist_type == DistributionType::FULL) {
        // Replicated weight: every rank loads the entire buffer.
        _tensor->load(host_data, stream);
    } else if (_dist_type == DistributionType::ROW || _tensor->ndim() == 1) { // 1D column-distributed is same as row-distributed
        // Each rank owns one contiguous slice, so a simple offset suffices.
        _tensor->load((const char *)host_data + _rank * _tensor->numel() * dsize(_tensor->dtype()), stream);
    } else if (_dist_type == DistributionType::COLUMN && _tensor->ndim() > 1) { // _dist_type == DistributionType::COLUMN
        // Column-distributed: this rank's columns are strided through the host
        // buffer, so gather them into a contiguous pinned staging buffer first.
        void *rearranged_ptr;
        RUN_INFINI(infinirtMallocHost(&rearranged_ptr, _tensor->numel() * dsize(_tensor->dtype())));
        size_t row_size = _tensor->shape()[_tensor->ndim() - 1] * dsize(_tensor->dtype()); // bytes of this rank's slice per row
        size_t host_offset = _rank * row_size;    // start of this rank's columns inside each host row
        size_t host_row_size = _nrank * row_size; // stride of one full host row
        size_t rows = std::accumulate(_tensor->shape().begin(), _tensor->shape().end() - 1, size_t(1), std::multiplies<size_t>());
        for (size_t row = 0; row < rows; row++) {
            // was (char *)host_data: C-style cast silently discarded const
            memcpy((char *)rearranged_ptr + row * row_size,
                   (const char *)host_data + host_offset + row * host_row_size,
                   row_size);
        }
        _tensor->load(rearranged_ptr, stream);
        // NOTE(review): assumes Tensor::load copies from host memory before
        // returning (or stages internally); otherwise freeing here races an
        // asynchronous copy on `stream` — confirm against Tensor::load.
        RUN_INFINI(infinirtFreeHost(rearranged_ptr));
    } else {
        // static_cast so this prints (and compiles) even if DistributionType
        // is a scoped enum with no stream operator.
        std::cerr << "Unsupported distribution type: " << static_cast<int>(_dist_type) << std::endl;
        std::abort();
    }
} // (stray trailing ';' removed)

/// Build a loader for one device type over the given device ids: one copy
/// stream and one (initially empty) weight registry per rank.
Loader::Loader(infiniDevice_t dev, const std::vector<int> &dev_ids) : _device(dev), _dev_ids(dev_ids) {
    _streams.resize(_dev_ids.size());
    // resize() default-constructs each per-rank map, so no explicit
    // re-assignment per rank is needed (the original re-assigned empty maps).
    _weights_maps.resize(_dev_ids.size());

    for (int rank = 0; rank < int(_dev_ids.size()); rank++) {
        // A stream must be created while its own device is current.
        RUN_INFINI(infinirtSetDevice(_device, _dev_ids[rank]));
        RUN_INFINI(infinirtStreamCreate(&_streams[rank]));
    }
}
void Loader::register_weight(const std::string &name, std::shared_ptr<Tensor> tensor, int rank, DistributionType dist_type) {
46
    _weights_maps[rank][name] = std::make_shared<Weight>(tensor, rank, _dev_ids.size(), dist_type);
blkmjsian's avatar
blkmjsian committed
47
}
void Loader::load(const std::string &name, const void *host_data) {
blkmjsian's avatar
blkmjsian committed
49
50
    for (int rank = 0; rank < int(_dev_ids.size()); rank++) {
        RUN_INFINI(infinirtSetDevice(_device, _dev_ids[rank]));
51
52
        auto it = _weights_maps[rank].find(name);
        if (it == _weights_maps[rank].end()) {
blkmjsian's avatar
blkmjsian committed
53
54
55
56
            std::cerr << "Weight " << name << " not found in rank " << rank << std::endl;
            std::abort();
        }

57
        _weights_maps[rank][name]->load(host_data, _streams[rank]);
blkmjsian's avatar
blkmjsian committed
58
59
60
61
62
63
    }
    for (int rank = int(_dev_ids.size() - 1); rank >= 0; rank--) {
        RUN_INFINI(infinirtSetDevice(_device, _dev_ids[rank]));
        RUN_INFINI(infinirtStreamSynchronize(_streams[rank]));
    }
}

void Loader::finalize() {
blkmjsian's avatar
blkmjsian committed
66
67
68
69
70
71
72
73
74
    int dev_id;
    RUN_INFINI(infinirtGetDevice(nullptr, &dev_id));
    for (int rank = 0; rank < int(_dev_ids.size()); rank++) {
        RUN_INFINI(infinirtSetDevice(_device, _dev_ids[rank]));
        RUN_INFINI(infinirtStreamSynchronize(_streams[rank]));
        RUN_INFINI(infinirtStreamDestroy(_streams[rank]));
    }
    RUN_INFINI(infinirtSetDevice(_device, dev_id));
}
/// Look up a previously registered weight's tensor by name for `rank`.
/// Uses find() instead of operator[]: the original silently inserted a null
/// shared_ptr for an unknown name and then dereferenced it (UB). Aborts with
/// a message instead, matching Loader::load()'s error style.
std::shared_ptr<Tensor> Loader::get(const std::string &name, int rank) {
    auto it = _weights_maps[rank].find(name);
    if (it == _weights_maps[rank].end()) {
        std::cerr << "Weight " << name << " not found in rank " << rank << std::endl;
        std::abort();
    }
    return it->second->tensor();
}

} // namespace infinicore::weights
/// C-ABI entry point: forward a named host weight buffer to the Loader
/// hiding behind the opaque ModelWeights handle.
__INFINI_C void
loadModelWeight(struct ModelWeights *weights_, const char *name, void *data) {
    auto *loader = reinterpret_cast<infinicore::weights::Loader *>(weights_);
    loader->load(std::string(name), data);
}