#pragma once

#include <string.h>
#include "base.h"
#include "socket.h"
#include "bootstrap_utils.h"
#include "bootstrap_net.h"
#include "thread_pool.h"
#include "ipc_socket.h"
#include "physical_links.h"

namespace sccl {
namespace hardware {
namespace topology {
namespace bootstrap {

typedef sccl::hardware::net::ipc_socket::scclIpcSocket_t scclIpcSocket_t;
typedef physical_links::scclTopoNode_t scclTopoNode_t;

///////////////////////////////////// 用于初始化时的功能函数 //////////////////////////////////////////
scclResult_t bootstrapGetUniqueId(BootstrapHandle_t* handle);
scclResult_t bootstrapCreateRoot(BootstrapHandle_t* handle);

////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// 定义结构体 scclRankInfo，用于存储每个rank的通信节点的信息
// 每个rank必定选定一个GPU、一个RDMA设备、一个CPU，作为通信节点
typedef struct scclRankInfo {
    struct {
        scclSocket_t listen_sock; // 监听套接字
    } cpu;                        // CPU节点
    struct {
        int dev;           // NVML设备编号
        char name[8];      // 设备名称
        char gcn[8];       // GCN架构名称
        int compCap;       // CUDA计算能力
        int64_t pciBusId;  // PCI总线ID以int64_t格式表示
        char pciPath[128]; // PCI设备在/sys中的路径。
    } gpu;                 // GPU节点
    struct {
        int count;          // 网卡数量
        char name[8];       // 主要用于日志记录。
        char pciPath[128];  // PCI设备在/sys中的路径。
        uint64_t guid;      // NIC芯片的唯一标识符。对于具有多个PCI功能（物理或虚拟）的卡非常重要。
        uint8_t ptrSupport; // [SCCL_PTR_HOST|SCCL_PTR_CUDA|SCCL_PTR_DMABUF]
        int speed;          // 端口速度，单位为Mbps。
        int port;           // 端口号。
        float latency;      // 网络延迟
        int maxComms;       // 可以创建的最大通信数量
        int maxRecvs;       // 最大分组接收数量。
    } net;                  // 网络节点

    int rank      = -1; // 当前节点的全局排名
    int localRank = -1; // 当前节点在本地计算节点中的排名

    uint64_t hostHash = 0; // 主机哈希值
    uint64_t pidHash  = 0; // 进程 ID 哈希值
} scclRankInfo_t;

// 定义结构体 scclNodeInfo，用于存储每个rank的图连接信息
// TODO: 目前每个rank需要的node_info大小为4k+，当卡数较大时占用内存较大，可以优化。或者不作为全局变量
typedef struct scclNodeInfo {
    scclTopoNode_t* nodes; // 指向scclTopoNode_t对象数组的指针
    int nLocalRanks;
    int totalByteSize; // 表示占用的总字节数

    // 带参数的构造函数，用于初始化nodes的大小
    scclNodeInfo(int nLocalRanks) : nodes(nullptr), nLocalRanks(nLocalRanks), totalByteSize(sizeof(scclTopoNode_t) * topoNodeMaxLocalNodes / nLocalRanks) {
        nodes = reinterpret_cast<scclTopoNode_t*>(malloc(totalByteSize));
        if(nodes) {
            memset(nodes, 0, totalByteSize);
        }
    }

    // 析构函数，用于释放申请的数组空间
    virtual ~scclNodeInfo() {
        if(nodes) {
            free(nodes);
        }
    }
} scclNodeInfo_t;

// 所有节点的信息
typedef struct scclRankPhysSet {
    // 构造函数声明
    scclRankPhysSet(int nRanks, int nLocalRanks);

    std::vector<scclRankInfo_t> rank_info_vec;
    std::vector<char> node_info_vec; // 实际为std::vector<scclNodeInfo_t>，vector不支持scclNodeInfo_t变长

public:
    int nRanks                   = 0; // 总的节点数量
    int nLocalRanks              = 0; // 本地计算节点中的节点总数
    size_t node_info_total_bytes = 0; // 记录可变长度scclNodeInfo_t类型数据的实际大小
} scclRankPhysSet_t;

// BootstrapComm 结构体定义，用于存储引导通信信息
typedef struct BootstrapComm {
    void init(int rank, int nRanks, int localRank, int nLocalRanks);
    void destroy();

public:
    scclNet_t* scclNet               = nullptr;
    scclRankPhysSet_t* rank_phys_set = nullptr;

    cpu_set_t cpuAffinity; // CPU亲和性
    int rank        = -1;  // 当前节点的全局排名
    int nRanks      = 0;   // 总的节点数量
    int localRank   = -1;  // 当前节点在本地计算节点中的排名
    int nLocalRanks = 0;   // 本地计算节点中的节点总数
    int interRank   = -1;  // 整个节点在全部节点中的位置
    int nInterRanks = 0;   // 全局拥有节点的个数

    int hipDev    = -1; // CUDA 设备 ID
    int deviceCnt = 0;  // 设备数量

    uint64_t magic; // 魔术数，用于验证结构体
} BootstrapComm_t;

///////////////////////////////////// 用于初始化时的类 //////////////////////////////////////////
class Bootstrap {
public:
    Bootstrap(const BootstrapHandle_t*, int rank, int nRanks);
    virtual ~Bootstrap();

    // 初始化bootstrap通信环境
    scclResult_t init(BootstrapComm_t* bootstrap_comm);

    // 实现跨节点的AllGather通信操作
    scclResult_t bootstrapAllGather(const void* src_data, void* dst_data, int data_size);

private:
    // 执行根节点的聚集和广播操作
    scclResult_t bootstrapRootGatherAndBroadcast(BootstrapNodeBasic_t* send_data_basic);

    // 初始化节点通信信息
    scclResult_t bootstrapCommInitNodeInfo(scclNet_t* scclNet, scclRankInfo_t* rank_info);

    // 实现rank_info信息的节点间通信的AllGather操作
    scclResult_t bootstrapCommAllGather(scclRankInfo_t* rank_info, scclNodeInfo_t* node_info, scclRankPhysSet_t* rank_phys_set);

    // 额外处理nRanks个nodes的连接关系
    scclResult_t bootstrapNodesLink(void* node_info_vec, int node_info_total_bytes);

private:
    int rank, nRanks;           // 初始化阶段获取MPI的值
    int localRank, nLocalRanks; // 通过bootstrapRootGatherAndBroadcast函数确定值
    int interRank, nInterRanks; // 整个节点在全部节点中的位置

    // TODO: 用于控制套接字终端的变量，目前不知道在哪里使用
    volatile uint32_t* abortFlag; // 中止标志，非阻塞套接字设置

    // 外部传入的0号节点的基础信息
    const BootstrapHandle_t* root_handle;
    // 节点内所有进程的基础ip信息
    BootstrapNodeBasic_t* all_node_basic = nullptr;

    // 初始化标志
    bool socketInitDone;
    // 互斥锁，用于保护初始化过程的线程安全
    pthread_mutex_t bootstrapMutex = PTHREAD_MUTEX_INITIALIZER;
    pthread_cond_t bootstrapCond   = PTHREAD_COND_INITIALIZER;

    // 节点内通信的类
    scclIpcSocket_t* ipcsocket = nullptr; // 指向scclIpcSocket类实例的指针，初始值为nullptr
};

// 打印唯一的拓扑信息
scclResult_t printRankInfo(const std::string& prefix, scclRankInfo_t* info);

} // namespace bootstrap
} // namespace topology
} // namespace hardware
} // namespace sccl
