#pragma once

#include <string.h>
#include "base.h"
#include "bootstrap.h"
#include "physical_links.h"

namespace sccl {
namespace hardware {
namespace topology {
namespace graph {

typedef physical_links::scclTopoNode_t scclTopoNode_t;
typedef bootstrap::BootstrapComm_t BootstrapComm_t;
typedef topology::bootstrap::Bootstrap Bootstrap;

// 定义结构体 scclNodeInfo，用于存储每个rank的图连接信息
// TODO: 目前每个rank需要的node_info大小为4k+，当卡数较大时占用内存较大，可以优化。或者不作为全局变量
typedef struct scclNodeInfo {
    scclTopoNode_t* nodes; // 指向scclTopoNode_t对象数组的指针
    int nLocalRanks;
    int totalByteSize; // 表示占用的总字节数

    // 带参数的构造函数，用于初始化nodes的大小
    scclNodeInfo(int nLocalRanks) : nodes(nullptr), nLocalRanks(nLocalRanks), totalByteSize(sizeof(scclTopoNode_t) * topoNodeMaxLocalNodes / nLocalRanks) {
        nodes = reinterpret_cast<scclTopoNode_t*>(malloc(totalByteSize));
        if(nodes) {
            memset(nodes, 0, totalByteSize);
        }
    }

    // 析构函数，用于释放申请的数组空间
    virtual ~scclNodeInfo() {
        if(nodes) {
            free(nodes);
        }
    }
} scclNodeInfo_t;

//////////////////////////////////////////////////////////////////////////////////////////////////
// 定义 topoPathType_t 枚举类型，用于表示不同的路径类型。
typedef enum topoPathType {
    PATH_LOC = 0, // 本地路径
    PATH_NVL = 1, // 通过 NVLink 连接
    PATH_NVB = 2, // 通过中间 GPU 使用 NVLink 连接
    PATH_PIX = 3, // 通过最多一个 PCIe 桥连接
    PATH_PXB = 4, // 通过多个 PCIe 桥连接（不经过 PCIe 主桥）
    PATH_PXN = 5, // GPU 和 NIC 之间通过中间 GPU 连接， PXN = PCI + NVLink
    PATH_PHB = 6, // 通过 PCIe 以及 PCIe 主桥连接
    PATH_SYS = 7, // 通过 PCIe 以及 NUMA 节点之间的 SMP 互连连接
    PATH_NET = 8, // 通过网络连接
    PATH_DIS = 9  // 断开连接
} topoPathType_t;

// GPU 连接其他GPU硬件的直连类型
typedef enum LinkType : uint8_t {
    LINK_NONE = 0, // 本地路径
    LINK_LOC  = 1, // 本地路径
    LINK_NVL  = 2, // 通过 NVLink 连接
    LINK_PIX  = 3, // 通过 PCIe 桥连接
    LINK_PXN  = 4, // GPU 和 GPU 之间通过中间 NIC 连接，包括 PCIe 主桥
    LINK_NET  = 5  // 通过网络连接
} LinkType_t;

typedef struct scclTopoGraph {
    scclTopoGraph() = delete; // 删除默认构造函数
    scclTopoGraph(int nRanks);
    virtual ~scclTopoGraph();

    uint8_t* getTransportMapRowStart(int row) { return transport_map[row * nRanks]; }
    uint8_t* getTransportMapData(int row, int col) { return transport_map[row * nRanks + col]; }

    // 打印transport_map
    scclResult_t printTransportMap();

    // 打印gpu_paths信息的函数
    scclResult_t printGPUPaths();

public:
    // 使用无序映射存储图的有效节点
    std::unordered_map<uint64_t, scclTopoNode_t> graph_nodes;
    // 使用无序映射存储从每个GPU节点到其他GPU节点的所有路径，[start_node_id][end_node_id] = {path1, path2, ...}
    std::unordered_map<uint64_t, std::unordered_map<uint64_t, std::vector<std::vector<uint64_t>>>> gpu_paths;

    // 传输位图
    ByteSpanArray<uint8_t> transport_map; // 使用ByteSpanArray存储transport_map
    int nRanks;                           // 记录GPU节点的数量
} scclTopoGraph_t;

} // namespace graph
} // namespace topology
} // namespace hardware
} // namespace sccl
