#pragma once

#include "dist_config.hpp"

#include <infiniccl.h>
#include <infinicore/context/context.hpp>

#include <sstream>
#include <vector>

namespace infinilm::engine::distributed {

// Communicator each rank will hold
14
15
16
17
18
19
20
21
struct RankInfo {
    // Device Type and ID assigned to this rank
    infinicore::Device device;
    // Tensor parallelism size
    int tp_size;
    // Tensor parallelism rank number of this rank
    int tp_rank;
    // Communicator handle
22
    infinicclComm_t comm;
23
24
25
26
27
28
29
30
31

    RankInfo(infinicore::Device _device = infinicore::context::getDevice())
        : tp_size(1), tp_rank(0), device(_device), comm(nullptr){};

    std::string to_string() const {
        std::stringstream ss;
        ss << "RankInfo: device=" << device.toString() << ", tp_size=" << tp_size << ", tp_rank=" << tp_rank;
        return ss.str();
    }
32
33
34
35
36
};

// The communication group managed by model infer engine
class CommunicationGroup {
public:
37
38
39
    explicit CommunicationGroup(const DistConfig &dist_config, infinicore::Device::Type device_type);

    const DistConfig &get_dist_config() const;
40

41
    RankInfo get_rank_info(int rank) const;
42

43
    int get_world_size() const;
44
45
46
47
48

    ~CommunicationGroup();

protected:
    DistConfig dist_config_;
49
    infinicore::Device::Type device_type_;
50
51
52
53
    std::vector<infinicclComm_t> communicators_;
};

} // namespace infinilm::engine::distributed