communication_group.cpp 1.51 KB
Newer Older
1
2
3
4
5
#include "communication_group.hpp"
#include "../../utils.hpp"

namespace infinilm::engine::distributed {

6
7
// Builds a communication group for the devices listed in
// `dist_config.tp_device_ids`.
//
// @param dist_config  Distribution config; stored in `dist_config_`.
// @param device_type  Device type used for every rank in the group.
//
// One communicator slot (initially nullptr) is allocated per device id.
// The slots are only initialized through infinicclCommInitAll when more than
// one device id is configured; with zero or one id they stay nullptr.
CommunicationGroup::CommunicationGroup(const DistConfig &dist_config, infinicore::Device::Type device_type)
    : dist_config_(dist_config),
      device_type_(device_type),
      communicators_(dist_config.tp_device_ids.size(), nullptr) {
    // Switch the current context to device 0 of the requested type when the
    // active device is of a different type.
    const bool type_matches = infinicore::context::getDevice().getType() == device_type_;
    if (!type_matches) {
        infinicore::context::setDevice(infinicore::Device(device_type_, 0));
    }

    const auto &device_ids = dist_config.tp_device_ids;
    if (device_ids.size() <= 1) {
        // Nothing to initialize for a single-device (or empty) group.
        return;
    }

    RUN_INFINI(infinicclCommInitAll(
        static_cast<infiniDevice_t>(infinicore::context::getDevice().getType()),
        communicators_.data(),
        device_ids.size(),
        device_ids.data()));
}

21
// Returns a read-only reference to the distribution config held by this group.
// The reference refers to the `dist_config_` member, so it must not be used
// after the group is destroyed.
const DistConfig &CommunicationGroup::get_dist_config() const {
    return this->dist_config_;
}

25
26
27
28
29
30
31
32
33
34
35
// Assembles the RankInfo describing one rank of this group.
//
// @param rank  Index into `tp_device_ids` / `communicators_`.
// @return RankInfo with:
//           tp_size - number of configured device ids,
//           tp_rank - `rank` as passed in,
//           device  - Device(device_type_, tp_device_ids[rank]),
//           comm    - the communicator slot for `rank` (nullptr when the
//                     constructor did not initialize communicators).
// @throws std::out_of_range if `rank` is negative or not less than the number
//         of configured device ids (previously this was undefined behaviour).
RankInfo CommunicationGroup::get_rank_info(int rank) const {
    RankInfo info;
    info.tp_size = dist_config_.tp_device_ids.size();
    info.tp_rank = rank;
    // `.at()` bounds-checks the index; a negative rank converts to a huge
    // unsigned value and is rejected as well.
    info.device = infinicore::Device(device_type_, dist_config_.tp_device_ids.at(rank));
    info.comm = communicators_.at(rank);
    return info;
}

// Returns the number of device ids configured in `dist_config_.tp_device_ids`,
// converted to int.
int CommunicationGroup::get_world_size() const {
    const auto device_count = dist_config_.tp_device_ids.size();
    return static_cast<int>(device_count);
}

// Destroys every communicator held by the group.
//
// Groups with zero or one slot are skipped: the constructor only initializes
// communicators when more than one device id is configured. Any result of
// infinicclCommDestroy is discarded.
CommunicationGroup::~CommunicationGroup() {
    if (communicators_.size() <= 1) {
        return;
    }
    for (auto it = communicators_.begin(); it != communicators_.end(); ++it) {
        infinicclCommDestroy(*it);
    }
}

} // namespace infinilm::engine::distributed