#pragma once #include "dist_config.hpp" #include #include #include #include namespace infinilm::engine::distributed { // Communicator each rank will hold struct RankInfo { // Device Type and ID assigned to this rank infinicore::Device device; // Tensor parallelism size int tp_size; // Tensor parallelism rank number of this rank int tp_rank; // Communicator handle infinicclComm_t comm; RankInfo(infinicore::Device _device = infinicore::context::getDevice()) : tp_size(1), tp_rank(0), device(_device), comm(nullptr){}; std::string to_string() const { std::stringstream ss; ss << "RankInfo: device=" << device.toString() << ", tp_size=" << tp_size << ", tp_rank=" << tp_rank; return ss.str(); } }; // The communication group managed by model infer engine class CommunicationGroup { public: explicit CommunicationGroup(const DistConfig &dist_config, infinicore::Device::Type device_type); const DistConfig &get_dist_config() const; RankInfo get_rank_info(int rank) const; int get_world_size() const; ~CommunicationGroup(); protected: DistConfig dist_config_; infinicore::Device::Type device_type_; std::vector communicators_; }; } // namespace infinilm::engine::distributed