#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

#include <array>
#include <iostream>
#include <memory>

#include <mpi.h>

#include "bootstrap.h"
#include "hardware.h"

using namespace sccl;

// Short aliases for the bootstrap types used throughout this test.
using scclUniqueId = sccl::hardware::topology::bootstrap::scclUniqueId;
using BootstrapHandle_t = sccl::hardware::topology::bootstrap::BootstrapHandle_t;
using Bootstrap = sccl::hardware::topology::bootstrap::Bootstrap;

// Global bootstrap communicator; its address is passed to Bootstrap::init() in sccl_init_step1().
struct sccl::hardware::topology::bootstrap::BootstrapComm bootstrap_comm;

/**
 * Bootstrap step 1: construct a Bootstrap from rank 0's unique id and run its
 * init() against the global bootstrap_comm.
 *
 * @param unique_id  unique id obtained on rank 0 (reinterpreted here as a BootstrapHandle_t)
 * @param rank       rank of the calling process
 * @param nRanks     total number of ranks
 * @return scclSuccess, or whatever error EQCHECK / SCCLCHECK propagate
 */
scclResult_t sccl_init_step1(const scclUniqueId* unique_id, int rank, int nRanks) {
    // ---- Step 1: view the unique id as rank 0's bootstrap handle (its address information) ---- //
    const BootstrapHandle_t* rootHandle = reinterpret_cast<const BootstrapHandle_t*>(unique_id);
    // Check whether the handle has been updated (compares its magic field against 0).
    EQCHECK(rootHandle->magic, 0);

    // ---- Step 2: create the bootstrapper and initialize, obtaining node info of all nodes ---- //
    std::unique_ptr<Bootstrap> bootstrapper = std::make_unique<Bootstrap>(rootHandle, rank, nRanks);
    // NOTE(review): bootstrapper is destroyed when this function returns; confirm that
    // bootstrap_comm does not rely on the Bootstrap object staying alive afterwards.
    SCCLCHECK(bootstrapper->init(&bootstrap_comm));

    return scclSuccess;
}

// Capacity of the fixed-size neighbor list in topoNode.
constexpr int topoNodeMaxNeighbors = 16;

/**
 * A vertex of the topology graph.
 *
 * Every member has a default initializer, so a default-constructed node never
 * contains indeterminate values (in particular neighborCount starts at 0 and
 * the neighbor list is zero-filled). Default member initializers do not affect
 * the member layout, so sizeof(topoNode_t) is unchanged.
 */
typedef struct topoNode {
    uint64_t id = 0;           // graph-vertex identifier
    int type = 0;              // graph-vertex type
    int numaId = 0;            // node id (field name suggests a NUMA node — confirm)
    char busIdStr[17] = "";    // bus ID string, e.g. "00000000:00:00.0" (16 chars + NUL)
    int speed = 0;             // speed
    int width = 0;             // width (original comment: bandwidth)
    char cpuAffinity[36] = ""; // CPU affinity string

    std::array<uint64_t, topoNodeMaxNeighbors> neighbors{}; // neighboring graph vertices
    size_t neighborCount = 0;                               // number of neighboring vertices
} topoNode_t;

/**
 * Test driver: starts MPI, prints the sizes of several bootstrap data
 * structures, obtains a unique id on rank 0 and broadcasts it to every rank,
 * runs bootstrap step 1, prints the current HIP device, then finalizes.
 */
int main(int argc, char* argv[]) {
    // -------------------------- 1. Start MPI ----------------------------------- //
    MPI_Init(&argc, &argv);
    int rank, nRanks;
    MPI_Comm_size(MPI_COMM_WORLD, &nRanks);
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    printf("rank=%d, nRanks=%d\n", rank, nRanks);
    // NOTE(review): hard-coded; here it is only used to construct nodeinfo for the size report.
    int nLocalRanks = 2;

    // Report the in-memory size of each structure. sizeof yields size_t, so use %zu.
    BootstrapHandle_t uqid;
    printf("uqid size=%zu\n", sizeof(uqid));
    sccl::hardware::topology::bootstrap::scclRankInfo_t rankinfo;
    sccl::hardware::topology::bootstrap::scclNodeInfo_t nodeinfo(nLocalRanks);
    topoNode_t topo_node;
    printf("rankinfo size=%zu\n", sizeof(rankinfo));
    printf("rankinfo cpu size=%zu\n", sizeof(rankinfo.cpu));
    printf("rankinfo gpu size=%zu\n", sizeof(rankinfo.gpu));
    printf("rankinfo net size=%zu\n", sizeof(rankinfo.net));
    // NOTE(review): assumes nodeinfo.size is an int (printed with %d) — confirm against scclNodeInfo_t.
    printf("nodeinfo size=%zu, stu size=%d\n", sizeof(nodeinfo), nodeinfo.size);
    printf("topo_node size=%zu\n", sizeof(topo_node));

    // -------------------------- 2. Rank 0 obtains the unique_id (mainly a socket address) and broadcasts it ----------------------------------- //
    scclUniqueId unique_id;
    if(rank == 0) {
        // NOTE(review): on failure rank 0 returns here while the other ranks wait in
        // MPI_Bcast; consider MPI_Abort so the whole job terminates promptly.
        SCCLCHECK(sccl::hardware::scclGetUniqueId(&unique_id));
    }
    MPI_Bcast(&unique_id, sizeof(scclUniqueId), MPI_BYTE, 0, MPI_COMM_WORLD);

    // -------------------------- 3. Initialize from the broadcast unique_id ----------------------------------- //
    // Propagate a failed bootstrap instead of silently continuing with an uninitialized bootstrap_comm.
    SCCLCHECK(sccl_init_step1(&unique_id, rank, nRanks));

    // Print which HIP device this rank is currently using.
    int cuda_id;
    HIPCHECK(hipGetDevice(&cuda_id));
    printf("rank=%d, cuda_id=%d\n", rank, cuda_id);

    MPI_Barrier(MPI_COMM_WORLD);

    SCCLCHECK(sccl::hardware::sccl_finalize());
    MPI_Finalize();
    return 0;
}

/*
Single-node runs:
SCCL_DEBUG_LEVEL=ABORT mpirun --allow-run-as-root -np 4 ./2_mpi_init_mpi_init_step1_bootstrap
SCCL_DEBUG_LEVEL=INFO SCCL_DEBUG_SUBSYS=ALL mpirun --allow-run-as-root -np 2 ./2_mpi_init_mpi_init_step1_bootstrap

Multi-node runs (hosts listed in the given hostfile):
SCCL_DEBUG_LEVEL=WARN SCCL_DEBUG_SUBSYS=BOOTSTRAP mpirun --allow-run-as-root --hostfile hostfile2 -np 4 ./2_mpi_init_mpi_init_step1_bootstrap
SCCL_DEBUG_LEVEL=WARN SCCL_DEBUG_SUBSYS=BOOTSTRAP mpirun --allow-run-as-root --hostfile hostfile -np 16 ./2_mpi_init_mpi_init_step1_bootstrap
*/
