#pragma once

#include <string.h>
#include <cstddef>
#include <vector>
#include "base.h"
#include "topo_utils.h"
#include "comm.h"
#include "rocm_smi_wrap.h"

namespace sccl {
namespace hardware {
namespace topology {
namespace bootstrap {

typedef union net::net_socket::scclSocketAddress scclSocketAddress_t;
typedef struct net::net_socket::scclSocket scclSocket_t;
typedef net::scclNet_t scclNet_t;

// 用于初始化时广播0号rank的地址信息
struct BootstrapHandle {
    uint64_t magic = 0;       // 随机码，用于socket通信
    scclSocketAddress_t addr; // 地址，用于网络通信
};

#define SCCL_UNIQUE_ID_BYTES (40) // sizeof(BootstrapHandle)
typedef struct {
    char internal[SCCL_UNIQUE_ID_BYTES];
} scclUniqueId;

// 仅用于初始化的函数bootstrapCreateRoot，用于传递detach线程的参数
struct bootstrapRootArgs {
    uint64_t magic;
    scclSocket_t* listenSock = nullptr; // 根节点的监听
};

// 用于初始建立连接阶段，0号rank之外的进程向其传递的信息
struct BootstrapNodeBasic {
    int rank;
    int nRanks;               // 进程的总数量
    uint64_t hostHash;        // 用于区分host的CPU编号
    scclSocketAddress_t addr; // 各个进程的监听套接字地址，用于网络通信
};

// 定义每个rank所持有的所有拓扑节点
struct topoLocalNode {
    struct {
        scclSocket_t listen_sock; // 监听套接字
    } cpu;                        // CPU节点
    struct {
        int64_t busId; // PCI总线ID以int64_t格式表示
    } pci;             // pci节点
    struct {
        int dev;      // NVML设备编号
        char name[8]; // 设备名称
        char gcn[7];  // GCN架构名称
        int compCap;  // CUDA计算能力
    } gpu;            // GPU节点
    struct {
        int count; // 网卡数量
        net::scclNetProperties_t props;
    } net; // 网络节点
};

// 定义结构体 scclNodeInfo，用于存储每个rank的通信节点的信息
struct scclNodeInfo {
    struct topoLocalNode localNode;
    int rank      = -1; // 当前节点的全局排名
    int localRank = -1; // 当前节点在本地计算节点中的排名

    uint64_t hostHash = 0; // 主机哈希值
    uint64_t pidHash  = 0; // 进程 ID 哈希值
};

// 每个节点的信息
struct scclNodeInfoSet {
    int nUniqueInfos; // 通信节点的数量
    std::vector<struct scclNodeInfo> node_info_vec;

    // 构造函数声明
    scclNodeInfoSet(int nRanks);
};

// BootstrapComm 结构体定义，用于存储引导通信信息
struct BootstrapComm {
    void init(int rank, int nRanks, int localRank, int nLocalRanks);
    void destroy();

public:
    scclNet_t* scclNet;
    struct scclNodeInfoSet* node_info_set;

    cpu_set_t cpuAffinity; // CPU亲和性
    int rank        = -1;  // 当前节点的全局排名
    int nRanks      = 0;   // 总的节点数量
    int localRank   = -1;  // 当前节点在本地计算节点中的排名
    int nLocalRanks = 0;   // 本地计算节点中的节点总数
    int hipDev      = -1;  // CUDA 设备 ID
    int deviceCnt   = 0;   // 设备数量

    // proxy通信
    uint64_t magic;               // 魔术数，用于验证结构体
    volatile uint32_t* abortFlag; // 中止标志，非阻塞套接字设置

    // int splitShare;      // 是否使用共享内存进行分割
    // int* topParentRanks; // 顶级父节点的rank
    // /* 与代理相关的共享资源 */
    // struct scclProxyState* proxyState;
};

////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// 获取主机唯一标识的哈希值，该哈希值在裸机和容器实例中都是唯一的
uint64_t getHostHash(void);

// 获取当前进程的唯一哈希标识符
uint64_t getPidHash(void);

// 从/dev/urandom设备获取随机数据填充缓冲区
scclResult_t getRandomData(void* buffer, size_t bytes);

// 获取指定CUDA设备的PCI总线ID并转换为64位整数
scclResult_t getBusId(int hipDev, int64_t* busId);

// 获取当前HIP设备的计算能力版本号
int scclCudaCompCap(void);

// 打印唯一的拓扑信息
scclResult_t printNodeInfo(struct scclNodeInfo* info);

// 实现类似于std::span的功能，将字节数组转换为类型数组
template <typename T>
class ByteSpan {
public:
    ByteSpan(const char* data, std::size_t size) : data_(reinterpret_cast<const T*>(data)), size_(size / sizeof(T)) {}

    const T* data() const { return data_; }
    std::size_t size() const { return size_; }

    const T& operator[](std::size_t index) const { return data_[index]; }

private:
    const T* data_;
    std::size_t size_;
};

} // namespace bootstrap
} // namespace topology
} // namespace hardware
} // namespace sccl
