#include <unistd.h>
#include <sys/types.h>
#include <string.h>
#include <sys/resource.h>
#include <iostream>
#include <iomanip>
#include <sstream>
#include <chrono>
#include <ctime>
#include <cstdint>
#include <memory> // for std::unique_ptr

#include "bootstrap.h"

namespace sccl {
namespace hardware {
namespace topology {
namespace bootstrap {

////////////////////////////////////////////////////////////////////////////////////////////////////////
pthread_mutex_t initLock = PTHREAD_MUTEX_INITIALIZER; // 线程锁
static bool initialized  = false;                     // 标志是否已经初始化
bool hsaFineGrainFlag    = true;                      // 标志变量，用于指示是否启用HSAP细粒度标志

static scclResult_t basicInit() {
    // 如果已经初始化，直接返回成功
    if(asm_ops::ld_acquire_sys_global(&initialized))
        return scclSuccess;

    // 加锁以确保初始化过程的线程安全
    pthread_mutex_lock(&initLock);
    // 如果尚未初始化，进行初始化操作
    if(!initialized) {
        initEnv(); // 初始化环境
        // 始终初始化引导网络
        SCCLCHECK(bootstrapNet::bootstrapNetInit());
        // initGdrCopy(); // 初始化GDR复制
        // SCCLCHECK(scclNetPluginInit());

        char strValue[1024];
        // 检查NUMA自动平衡是否启用
        SCCLCHECK(scclTopoGetStrFromSys("/proc/sys/kernel", "numa_balancing", strValue));
        if(strcmp(strValue, "1") == 0)
            WARN("NUMA自动平衡已启用，这可能导致RCCL性能的不稳定性！通过\"sudo sysctl kernel.numa_balancing=0\"禁用");
        // 获取内核版本信息
        SCCLCHECK(scclTopoGetStrFromSys("/proc", "version", strValue));
        char *verStr, *state;
        verStr = strtok_r(strValue, " ", &state);
        for(int i = 0; i < 2; i++) {
            verStr = strtok_r(NULL, " ", &state);
            if(verStr == NULL)
                break;
        }
        INFO(SCCL_LOG_BOOTSTRAP, "内核版本: %s", verStr);
        // 检查是否为Cray系统
        if(strstr(verStr, "cray") == NULL) {
            // 获取BIOS版本信息
            SCCLCHECK(scclTopoGetStrFromSys("/sys/devices/virtual/dmi/id", "bios_version", strValue));
            if(strncmp("Hyper-V UEFI Release", strValue, 20) != 0) {
                FILE* file;
                // 读取内核命令行参数
                if((file = fopen("/proc/cmdline", "r")) != NULL) {
                    if(feof(file) == 0 && ferror(file) == 0) {
                        int len       = fread(strValue, 1, 1024, file);
                        strValue[len] = '\0';
                    }
                    fclose(file);
                }
                // 检查是否缺少"iommu=pt"参数
                if(strstr(strValue, "iommu=pt") == NULL)
                    WARN("内核命令行中缺少\"iommu=pt\"参数，这可能导致系统不稳定或挂起！");
            }

            float* ptr;
            // 尝试分配细粒度PCIe内存
            hipError_t err = hipExtMallocWithFlags((void**)&ptr, 128, hipDeviceMallocFinegrained);
            if(err != hipSuccess)
                hsaFineGrainFlag = false;
        }

        // 设置初始化标志
        asm_ops::st_release_sys_global(&initialized, true);
    }
    // 解锁
    pthread_mutex_unlock(&initLock);
    return scclSuccess;
}

scclResult_t bootstrapGetUniqueId(struct BootstrapHandle* handle) {
    SCCLCHECK(basicInit());
    // 在每个进程中设置 handle 的值
    getRandomData(&handle->magic, sizeof(handle->magic));

    const char* env = getenv("SCCL_COMM_ID");
    if(env) {
        memset(&handle->magic, 0, sizeof(handle->magic));
        INFO(SCCL_LOG_BOOTSTRAP, "SCCL_COMM_ID set by environment to %s", env);
        if(scclSocketGetAddrFromString(&handle->addr, env) != scclSuccess) {
            WARN("Invalid SCCL_COMM_ID, please use format: <ipv4>:<port> or [<ipv6>]:<port> or <hostname>:<port>");
            return scclInvalidArgument;
        }
    } else {
        // 初始化socket
        scclSocketAddress_t localSocketAddr = bootstrapNet::getLocalSocketAddr();
        memcpy(&handle->addr, &localSocketAddr, sizeof(scclSocketAddress_t));
        // 启动根节点listen监听
        SCCLCHECK(bootstrapCreateRoot(handle));
    }
    return scclSuccess;
}

static scclResult_t setFilesLimit() {
    struct rlimit filesLimit;
    SYSCHECK(getrlimit(RLIMIT_NOFILE, &filesLimit), "getrlimit");
    filesLimit.rlim_cur = filesLimit.rlim_max;
    SYSCHECK(setrlimit(RLIMIT_NOFILE, &filesLimit), "setrlimit");
    return scclSuccess;
}

/**
 * @brief 根节点引导程序，负责收集所有rank的地址信息并广播给其他rank
 * 由于同一个socket数据传输比较慢，所以在进行数据广播时，仅传送给localRank==0的rank，再由其进行节点内广播
 * 该函数所有数据传输与 Bootstrap::bootstrapRootGatherAndBroadcast 函数相配合
 *
 * @param rargs 包含监听套接字和验证魔数的参数结构体
 * @return void* 总是返回NULL
 *
 * 该函数执行以下主要操作：
 * 1. 初始化资源并设置文件描述符限制
 * 2. 循环接收所有rank的连接请求，收集地址信息
 * 3. 验证接收到的rank信息一致性
 * 4. 计算本地rank数量(nLocalRanks)
 * 5. 使用线程池并行发送nLocalRanks值给所有rank
 * 6. 将收集到的所有rank地址信息广播给每个节点的localRank=0的进程
 * 7. 清理资源并返回
 *
 * @note 函数使用线程池加速消息分发，并通过日志记录关键操作步骤
 */
static void* bootstrapRoot(void* rargs) {
    struct bootstrapRootArgs* args = (struct bootstrapRootArgs*)rargs;
    scclSocket_t* listenSock       = args->listenSock; // 用于监听的套接字
    uint64_t magic                 = args->magic;      // 用于验证的魔数
    scclResult_t res               = scclSuccess;      // 函数结果
    class ThreadPool* pthread_pool = nullptr;          // 用于根节点分发消息的线程池

    int nRanks            = 0; // nRanks: 进程总数;
    int nLocalRanks       = 1;
    int c                 = 0; // c: 已连接的进程计数
    uint64_t rootHostHash = 0;
    struct BootstrapNodeBasic node_basic;                     // 用于存储扩展信息的结构体
    struct BootstrapNodeBasic* all_rank_node_basic = nullptr; // 所有进程的地址

    // 定义一个函数或者一个函数对象，用于执行实际的发送数据操作
    auto send_task = [](BootstrapNodeBasic& node_basic, uint64_t magic, int rank, void* data, size_t size) {
        net::net_socket::scclSocketClientManager client_manager(&node_basic.addr, magic, net::net_socket::scclSocketTypeBootstrap);
        bootstrapNet::bootstrapNetSend(client_manager.getSocket(), data, size);
    };

    // 用于验证的工具进行初始化
    scclSocketAddress_t* zero = nullptr;           // 用于初始化或比较的零地址
    setFilesLimit();                               // 设置文件描述符限制
    SCCLCHECKGOTO(scclCalloc(&zero, 1), res, out); // 为zero分配内存

    INFO(SCCL_LOG_BOOTSTRAP, "BEGIN"); // 日志：开始
    // --------------------- 1.从所有rank接收其socket地址（BootstrapNodeBasic） --------------------- //
    do {
        net::net_socket::scclSocketAcceptManager accept_manager(listenSock);
        SCCLCHECKGOTO(bootstrapNet::bootstrapNetRecv(accept_manager.getSocket(), &node_basic, sizeof(node_basic)), res, out); // 接收数据

        if(c == 0) {
            nRanks = node_basic.nRanks;
            SCCLCHECKGOTO(scclCalloc(&all_rank_node_basic, nRanks), res, out); // 为rankAddresses分配内存
            pthread_pool = new ThreadPool(nRanks);
        } else if(nRanks != node_basic.nRanks) {                                                           // 如果接收到的进程总数不匹配
            WARN("Bootstrap Root : mismatch in rank count from procs %d : %d", nRanks, node_basic.nRanks); // 警告
            goto out;                                                                                      // 跳转到out标签
        }
        if(memcmp(zero, &all_rank_node_basic[node_basic.rank].addr, sizeof(scclSocketAddress_t)) != 0) {  // 如果rank已经签到
            WARN("Bootstrap Root : rank %d of %d ranks has already checked in", node_basic.rank, nRanks); // 警告
            goto out;                                                                                     // 跳转到out标签
        }

        // 保存该rank的连接句柄
        memcpy(all_rank_node_basic + node_basic.rank, &node_basic, sizeof(struct BootstrapNodeBasic));
        ++c;                                                                                               // 增加已连接的进程计数
        INFO(SCCL_LOG_BOOTSTRAP, "Received connect from rank %d total %d/%d", node_basic.rank, c, nRanks); // 日志
    } while(c < nRanks); // 当已连接的进程数小于总数时循环
    INFO(SCCL_LOG_BOOTSTRAP, "COLLECTED ALL %d HANDLES", nRanks); // 日志：收集到所有句柄

    // --------------------- 2.计算nLocalRanks，并广播给其他所有rank --------------------- //
#if 1
    for(int r = 0; r < nRanks; ++r) {
        auto temp_node_basic = all_rank_node_basic[r];
        char line[100];
        sprintf(line, "bootstrapRoot r=%d, rank=%d/%d, hostHash=%lu,\n", r, temp_node_basic.rank, temp_node_basic.nRanks, temp_node_basic.hostHash);

        scclSocketAddress_t temp_addr = temp_node_basic.addr;
        hardware::net::printSocketAddr(&temp_addr, line);
    }
#endif
    // 首先计算nLocalRanks大小，即具有相同hostHash的节点数量
    rootHostHash = all_rank_node_basic[0].hostHash;
    for(int i = 1; i < nRanks; ++i) {
        if(rootHostHash == all_rank_node_basic[i].hostHash) {
            nLocalRanks++; // 如果hostHash相同，则增加本地节点计数
        } else {
            break; // 一旦发现不同的hostHash，停止计数
        }
    }
    // 给每个节点的localRank=0的进程发送信息，并由其进行广播，从而加快速度
    for(int r = 0; r < nRanks; ++r) {
        auto dst_node_basic = all_rank_node_basic[r];
        // 使用std::bind将参数绑定到send_task函数
        auto bound_task = std::bind(send_task, dst_node_basic, magic, r, &nLocalRanks, sizeof(int));
        // 将绑定后的任务添加到线程池
        pthread_pool->enqueue(bound_task);
    }
    // 等待所有任务完成
    while(!pthread_pool->allTasksCompleted()) {
        usleep(1000); // 每1毫秒检查一次任务完成状态
    }

    // --------------------- 3.给所有localRank==0的rank发送all_rank_node_basic数据 --------------------- //
    // 给每个节点的localRank=0的进程发送信息，并由其进行广播，从而加快速度
    for(int r = 0; r < nRanks / nLocalRanks; ++r) {
        int dst_rank        = r * nLocalRanks; // 计算目标rank
        auto dst_node_basic = all_rank_node_basic[dst_rank];
        net::net_socket::scclSocketClientManager client_manager(&dst_node_basic.addr, magic, net::net_socket::scclSocketTypeBootstrap);
        bootstrapNet::bootstrapNetSend(client_manager.getSocket(), all_rank_node_basic, sizeof(struct BootstrapNodeBasic) * nRanks);
        printf("root send nLocalRanks value to rank=%d\n", r);
    }
    // 等待所有任务完成
    while(!pthread_pool->allTasksCompleted()) {
        usleep(1000); // 每1毫秒检查一次任务完成状态
    }

    INFO(SCCL_LOG_BOOTSTRAP, "bootstrap send out all %d handles", nRanks); // 日志：发送出所有句柄
out:
    // 关闭套接字，并释放内存
    if(listenSock) {
        scclSocketClose(listenSock);
        delete listenSock;
    }
    // 释放内存
    if(all_rank_node_basic)
        free(all_rank_node_basic);
    if(zero)
        free(zero);
    if(pthread_pool)
        delete pthread_pool;
    free(rargs);                      // 释放rargs内存
    INFO(SCCL_LOG_BOOTSTRAP, "DONE"); // 日志：完成

    return NULL; // 返回NULL
}

/**
 * 创建并启动bootstrap根节点
 *
 * 该函数负责初始化监听socket，创建并启动一个独立的线程来处理bootstrap根节点逻辑。
 * 线程会被设置为detach状态，无需等待其结束。
 *
 * @param handle 包含bootstrap配置信息的句柄
 * @return 成功返回scclSuccess，失败返回相应的错误码
 */
scclResult_t bootstrapCreateRoot(struct BootstrapHandle* handle) {
    struct bootstrapRootArgs* args;
    pthread_t thread;

    // 设置根节点socket监听
    net::net_socket::scclSocketServerManager root_manager(&handle->addr, handle->magic, net::net_socket::scclSocketTypeBootstrap);

    // 为args分配内存
    SCCLCHECK(scclCalloc(&args, 1));
    // 设置线程参数
    args->listenSock = root_manager.releaseSocket();
    args->magic      = handle->magic;

    // 创建线程以执行bootstrapRoot函数， 直到线程结束才释放listenSock
    NEQCHECK(pthread_create(&thread, NULL, bootstrapRoot, (void*)args), 0);
    // 设置线程名称
    scclSetThreadName(thread, "SCCL BootstrapR");
    // 分离线程，使其在完成后自动回收资源
    NEQCHECK(pthread_detach(thread), 0); // will not be pthread_join()'d

    return scclSuccess;
}

////////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////////
// 构造函数
Bootstrap::Bootstrap(const struct BootstrapHandle* handle, int rank, int nRanks)
    : root_handle(handle), rank(rank), nRanks(nRanks), localRank(-1), nLocalRanks(0), socketInitDone(false) {
    printf("construct init Bootstrap\n");

    // 初始化线程池
    pthread_pool = new ThreadPool(nRanks);
}

Bootstrap::~Bootstrap() {
    if(pthread_pool) {
        delete pthread_pool;
    }
    if(ipcsocket) {
        delete ipcsocket;
    }
}

scclResult_t Bootstrap::init(struct BootstrapComm* bootstrap_comm) {
    // 如果已经初始化，直接返回成功
    if(asm_ops::ld_acquire_sys_global(&socketInitDone))
        return scclSuccess;

    // 加锁以确保初始化过程的线程安全
    pthread_mutex_lock(&bootstrapMutex);

    // -------------------------- 1.获取自身基础信息 ----------------------------------- //
    SCCLCHECK(basicInit());
    uint64_t hostHash                   = getHostHash();
    scclSocketAddress_t localSocketAddr = bootstrapNet::getLocalSocketAddr();
    // 设置基础信息
    struct BootstrapNodeBasic node_basic{rank, nRanks, hostHash, localSocketAddr};

    // -------------------------- 2.设置0号rank搜集的CPU信息和localRank信息 ----------------------------------- //
    // 创建根节点的数据收集
    std::vector<struct BootstrapNodeBasic> all_node_basic;
    all_node_basic.reserve(nRanks);
    SCCLCHECK(bootstrapRootGatherAndBroadcast(&node_basic, all_node_basic.data()));

    // -------------------------- 3.设置本地localRank的BootstrapComm信息 ----------------------------------- //
    // 初始化BootstrapComm类
    bootstrap_comm->init(rank, nRanks, localRank, nLocalRanks);

    if(CPU_COUNT(&bootstrap_comm->cpuAffinity)) {
        sched_setaffinity(0, sizeof(cpu_set_t), &bootstrap_comm->cpuAffinity);
    }
    bootstrap_comm->magic = root_handle->magic;

    //////// 设置显卡状态 ////////
    bootstrap_comm->hipDev = localRank; // CUDA 设备 ID
    uint32_t devices_num;
    SCCLCHECK(rocm_smi_init());                                // 初始化ROCM SMI库
    SCCLCHECK(rocm_smi_getNumDevice(&devices_num));            // 获取设备数量
    LTCHECK(devices_num, 0);                                   // 检查设备数量是否 devices_num>0
    LTCHECK(devices_num, nLocalRanks);                         // 检查设备数量是否 devices_num>nLocalRanks
    bootstrap_comm->deviceCnt = static_cast<int>(devices_num); // 将设备数量转换为int并赋值给的deviceCnt
    printf("devices_num=%d\n", bootstrap_comm->deviceCnt);

    LECHECK(devices_num, bootstrap_comm->hipDev);   // 检查hipDev是否小于deviceCnt
    HIPCHECK(hipSetDevice(bootstrap_comm->hipDev)); // 设置当前设备为hipDev

    //////// 设置启动通信的scclNet ////////
    // 获取环境变量SCCL_NET_NAME的值，如果不存在则默认使用"IB"
    const char* envNetName = getenv("SCCL_NET_NAME");
    char* netName          = (envNetName != NULL) ? strdup(envNetName) : strdup("IB");
    printf("netName=%s\n", netName);
    // 初始化网络和引导网络
    SCCLCHECK(net::scclNetInit(netName, bootstrap_comm->scclNet));
    // 释放分配的网络名称字符串
    free(netName);

    //////// 初始化唯一信息结构体 ////////
    struct scclNodeInfo local_node_info;
    local_node_info.hostHash = hostHash;
    SCCLCHECK(bootstrapCommInitNodeInfo(bootstrap_comm->scclNet, &local_node_info));

    printNodeInfo(&local_node_info);

    // -------------------------- 4.BootstrapComm信息的allgather ----------------------------------- //

    // bootstrapAllGather(bootstrap_comm->node_info);

    // 设置初始化标志
    asm_ops::st_release_sys_global(&socketInitDone, true);
    // 解锁
    pthread_mutex_unlock(&bootstrapMutex);

    return scclSuccess;
}

///////////////////////////////////////////////////////////////////////////
/**
 * @brief 执行根节点的数据收集和广播操作
 *
 * 该函数负责以下操作：
 * 1. 设置本地监听服务
 * 2. 向根节点发送本节点的基本数据
 * 3. 从根节点接收本地rank数量信息
 * 4. 当本地rank为0时，从根节点接收所有rank的IP数据
 * 5. 将收集到的所有rank数据广播给节点内其他rank
 *
 * @param send_data 发送给根节点的数据指针
 * @param recv_data 接收广播数据的缓冲区指针
 * @return scclResult_t 返回操作结果，成功返回scclSuccess
 */
scclResult_t Bootstrap::bootstrapRootGatherAndBroadcast(void* send_data, void* recv_data) {
    // 数据类型转换
    auto send_data_basic          = reinterpret_cast<struct BootstrapNodeBasic*>(send_data);
    auto recv_data_basic          = reinterpret_cast<struct BootstrapNodeBasic*>(recv_data);
    int recv_data_basic_size      = nRanks * sizeof(struct BootstrapNodeBasic);
    scclSocketAddress_t root_addr = root_handle->addr;

    // ------------- 1.各个rank在发送给根节点数据之前，首先设置监听listen ------------- //
    net::net_socket::scclSocketServerManager local_server_manager(&send_data_basic->addr, root_handle->magic, net::net_socket::scclSocketTypeBootstrap);
    // 保存监听的sock信息到本类，用于后续proxy使用
    my_listen_sock = local_server_manager.releaseSocket();

    // ------------- 2.各个节点向根节点发送数据 ------------- //
    net::net_socket::scclSocketClientManager client_manager(&root_addr, root_handle->magic, net::net_socket::scclSocketTypeBootstrap);
    SCCLCHECK(bootstrapNet::bootstrapNetSend(client_manager.getSocket(), send_data_basic, sizeof(struct BootstrapNodeBasic)));

    // ------------- 3.从根节点接收nLocalRanks值 ------------- //
    // 接收nLocalRanks信息
    {
        net::net_socket::scclSocketAcceptManager accept_manager(my_listen_sock);
        SCCLCHECK(bootstrapNet::bootstrapNetRecv(accept_manager.getSocket(), &nLocalRanks, sizeof(int)));
    }

    // ------------- 4.nLocalRanks==0时，从根节点接收所有rank的ip数据 ------------- //
    this->localRank = rank % nLocalRanks;
    if(localRank == 0) {
        net::net_socket::scclSocketAcceptManager accept_manager(my_listen_sock);
        SCCLCHECK(bootstrapNet::bootstrapNetRecv(accept_manager.getSocket(), recv_data_basic, recv_data_basic_size));
    }

    // ------------- 5.nLocalRanks==0时，将所有rank的ip数据广播给节点内其他rank ------------- //
    ipcsocket = new scclIpcSocket_t(localRank, nLocalRanks, /*hash*/ root_handle->magic);
    ipcsocket->scclIpcSocketBroadcast(recv_data_basic, recv_data_basic_size, /*localRank root*/ 0);

    return scclSuccess;
}

/**
 * @brief 初始化节点通信信息
 *
 * 该函数用于初始化节点的通信信息，包括：
 * - 设置节点的全局排名和本地排名
 * - 获取并设置进程ID哈希值
 * - 设置GPU设备属性（名称、GCN架构、计算能力）
 * - 设置RDMA网络属性
 * - 设置PCI总线ID
 * - 设置CPU套接字地址
 *
 * @param scclNet 网络句柄
 * @param socket_addr 套接字地址
 * @param node_info 节点信息结构体指针
 * @return scclResult_t 返回操作结果，成功返回scclSuccess
 */
scclResult_t Bootstrap::bootstrapCommInitNodeInfo(scclNet_t* scclNet, struct scclNodeInfo* node_info) {
    ////////////////// 设置基础信息 //////////////////
    node_info->rank      = rank;         // 当前节点的全局排名
    node_info->localRank = localRank;    // 当前节点在本地计算节点中的排名
    node_info->pidHash   = getPidHash(); // 获取进程ID哈希值并赋值给的pidHash
    int hipDev           = localRank;

    ////////////////// 设置硬件信息 //////////////////
    struct topoLocalNode* p_localNode = &node_info->localNode;

    // 设置CPU信息
    p_localNode->cpu.listen_sock = *my_listen_sock;

    // 设置PCI信息
    SCCLCHECK(getBusId(hipDev, &p_localNode->pci.busId));

    // 设置GPU信息
    p_localNode->gpu.dev = hipDev;
    hipDeviceProp_t deviceProp;
    HIPCHECK(hipGetDeviceProperties(&deviceProp, hipDev));
    snprintf(p_localNode->gpu.name, sizeof(p_localNode->gpu.name), "%s", deviceProp.name);
    snprintf(p_localNode->gpu.gcn, sizeof(p_localNode->gpu.gcn), "%s", deviceProp.gcnArchName);
    p_localNode->gpu.compCap = deviceProp.major * 10 + deviceProp.minor;

    // 设置RDMA信息
    SCCLCHECK(scclNet->getProperties(hipDev, &p_localNode->net.props));
    SCCLCHECK(scclNet->devices(&p_localNode->net.count));

    return scclSuccess;
}

scclResult_t Bootstrap::bootstrapAllGather(struct scclNodeInfo*) {
    // 1.节点内通信 allgather

    // 2.节点间通信，ring allgather

    // 3.节点内通信 allgather

    return scclSuccess;
}

/////////////////////////////////////////////////////////////////////////////////////////////
// // 将本地socket地址写入到/tmp/文件夹的文件中，通过nfs共享存储，其他rank可见
// scclResult_t bootstrapGetAllNodes(const struct scclNodeInfo* , struct BootstrapComm* comm) {
//     // // 分配并初始化IPC套接字
//     // struct scclIpcSocket ipcSock = {0};
//     // // Create a UDS socket to receive the converted fd
//     // SCCLCHECK(scclIpcSocketInit(&ipcSock, ->rank, /*hash*/ handle->magic, /*abortFlag*/ NULL));
//     // printf("fd=%d, socketName=%s\n", ipcSock.fd, ipcSock.socketName);
//     return scclInProgress;
// }

// auto node_info = bootstrap_comm->node_info;
// // 设置节点内socket通信工具
// ipcsocket = new scclIpcSocket_t(node_info->localRank, node_info->nRanks, node_info->hostHash, bootstrap_comm->abortFlag);

} // namespace bootstrap
} // namespace topology
} // namespace hardware
} // namespace sccl
