#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

#include <cerrno>
#include <chrono>
#include <climits>
#include <cstdlib>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <thread> // for std::this_thread::sleep_for
#include <vector> // std::vector

#include "mpi.h"
#include "net.h"
#include "ipc_socket.h"
#include "thread_pool.h"

using namespace sccl;
typedef class sccl::hardware::net::ipc_socket::scclIpcSocket scclIpcSocket_t;

/**
 * Benchmarks ipcsocket->scclIpcSocketAllgather: each rank contributes a dataLen-byte buffer
 * and receives size * dataLen bytes. Prints the average per-call latency and the leading
 * text of every received chunk.
 *
 * @param ipcsocket       object providing scclIpcSocketAllgather(send, recv, len)
 * @param rank            this process's rank in MPI_COMM_WORLD
 * @param size            number of participating ranks (must be > 0)
 * @param dataLen         bytes contributed per rank (must be > 0)
 * @param num_iterations  number of timed Allgather calls (must be > 0)
 * @return 0 on success, -1 on invalid arguments. A failing Allgather call is handled by
 *         SCCLCHECK (presumably it returns the error code early; confirm against its definition).
 */
template <typename T>
int test_allgather(T* ipcsocket, int rank, int size, int dataLen = 64 * 1024, int num_iterations = 1) {
    // Reject arguments that would produce empty buffers or a division by zero below.
    // NOTE(review): this assumes every rank receives identical arguments (plain mpirun launch),
    // so all ranks bail out together instead of leaving peers blocked in the collective.
    if(size <= 0 || dataLen <= 0 || num_iterations <= 0) {
        if(rank == 0) {
            fprintf(stderr,
                    "test_allgather: size (%d), dataLen (%d) and num_iterations (%d) must be positive\n",
                    size,
                    dataLen,
                    num_iterations);
        }
        return -1;
    }

    const size_t chunk = static_cast<size_t>(dataLen);
    std::vector<char> sendData(chunk);
    // Compute in size_t: size * dataLen can overflow int for large payloads or rank counts.
    std::vector<char> recvData(static_cast<size_t>(size) * chunk);

    // Fill the send buffer. With a non-zero size, snprintf always NUL-terminates inside it.
    snprintf(sendData.data(), sendData.size(), "Data from process %d", rank);
    printf("test_allgather dataLen=%d, sendData.size()=%zu\n", dataLen, sendData.size());

    // Line the ranks up before starting the clock. Otherwise early arrivals would also time
    // how long they wait for slower peers to enter the collective.
    MPI_Barrier(MPI_COMM_WORLD);

    auto start = std::chrono::high_resolution_clock::now();
    for(int i = 0; i < num_iterations; ++i) {
        SCCLCHECK(ipcsocket->scclIpcSocketAllgather(sendData.data(), recvData.data(), dataLen));
    }
    auto end = std::chrono::high_resolution_clock::now();

    // Wait until every rank has finished the timed loop before reporting.
    MPI_Barrier(MPI_COMM_WORLD);

    const std::chrono::duration<double> elapsed = end - start;
    const double average_time = elapsed.count() * 1e6 / num_iterations; // microseconds
    printf("rank %d: Average time for Allgather over %d iterations: %f us.\n", rank, num_iterations, average_time);

    // Print the received data, one chunk per source rank.
    for(int i = 0; i < size; ++i) {
        const char* chunk_text = recvData.data() + static_cast<size_t>(i) * chunk;
        // The %.*s precision caps the read at one chunk, even if the chunk has no NUL terminator.
        printf("rank %d received from process %d: %.*s\n", rank, i, dataLen, chunk_text);
    }

    return 0;
}

/**
 * Benchmarks ipcsocket->scclIpcSocketBroadcast from root rank 0. The root fills a
 * dataLen-byte buffer with a greeting, and every rank prints the buffer contents after the
 * broadcast together with its average per-call latency.
 *
 * @param ipcsocket       object providing scclIpcSocketBroadcast(buf, len, root)
 * @param rank            this process's rank in MPI_COMM_WORLD
 * @param size            number of ranks (unused here; test_allgather takes the same signature)
 * @param dataLen         broadcast payload size in bytes (must be > 0)
 * @param num_iterations  number of timed Broadcast calls (must be > 0)
 * @return 0 on success, -1 on invalid arguments. A failing Broadcast call is handled by
 *         SCCLCHECK (presumably it returns the error code early; confirm against its definition).
 */
template <typename T>
int test_broadcast(T* ipcsocket, int rank, int size, int dataLen = 64 * 1024, int num_iterations = 1) {
    (void)size;

    // Reject arguments that would produce an empty buffer or a division by zero below.
    // NOTE(review): this assumes every rank receives identical arguments, so all ranks return together.
    if(dataLen <= 0 || num_iterations <= 0) {
        if(rank == 0) {
            fprintf(stderr,
                    "test_broadcast: dataLen (%d) and num_iterations (%d) must be positive\n",
                    dataLen,
                    num_iterations);
        }
        return -1;
    }

    std::vector<char> data(static_cast<size_t>(dataLen));
    const int root = 0; // rank 0 acts as the broadcast root
    if(rank == root) {
        // Only the root fills the payload. Other ranks start with a zeroed buffer.
        snprintf(data.data(), data.size(), "Data from root process %d", rank);
    }
    printf("rank=%d, data.size()=%zu\n", rank, data.size());

    // Line the ranks up before starting the clock so skew in arrival times is not measured.
    MPI_Barrier(MPI_COMM_WORLD);

    auto start = std::chrono::high_resolution_clock::now();
    for(int i = 0; i < num_iterations; ++i) {
        SCCLCHECK(ipcsocket->scclIpcSocketBroadcast(data.data(), data.size(), root));
    }
    auto end = std::chrono::high_resolution_clock::now();

    // Wait until every rank has finished the timed loop before reporting.
    MPI_Barrier(MPI_COMM_WORLD);

    const std::chrono::duration<double> elapsed = end - start;
    const double average_time = elapsed.count() * 1e6 / num_iterations; // microseconds
    // The %.*s precision keeps the read inside the buffer even if it lacks a NUL terminator.
    printf("rank %d: data=%.*s, Average time for scclIpcSocketBroadcast over %d iterations: %f us.\n",
           rank,
           static_cast<int>(data.size()),
           data.data(),
           num_iterations,
           average_time);

    return 0;
}

int main(int argc, char* argv[]) {
    MPI_Init(&argc, &argv);

    int rank, size;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    MPI_Comm_size(MPI_COMM_WORLD, &size);

    int dst_hash               = 654321;
    scclIpcSocket_t* ipcsocket = new scclIpcSocket_t(rank, size, dst_hash);

    // 默认参数
    std::string test_type = "allgather";
    int dataLen           = 64 * 1024;
    int num_iterations    = 1;

    // 解析命令行参数
    for(int i = 1; i < argc; ++i) {
        std::istringstream iss(argv[i]);
        std::string arg;
        iss >> arg;

        if(arg == "--test-type") {
            if(++i < argc) {
                test_type = argv[i];
            }
        } else if(arg == "--data-len") {
            if(++i < argc) {
                iss.clear();
                iss.str(argv[i]);
                iss >> dataLen;
            }
        } else if(arg == "--num-iterations") {
            if(++i < argc) {
                iss.clear();
                iss.str(argv[i]);
                iss >> num_iterations;
            }
        }
    }

    if(test_type == "allgather") {
        test_allgather(ipcsocket, rank, size, dataLen, num_iterations);
    } else if(test_type == "broadcast") {
        test_broadcast(ipcsocket, rank, size, dataLen, num_iterations);
    } else {
        if(rank == 0) {
            std::cerr << "Unknown test type: " << test_type << std::endl;
        }
    }

    delete ipcsocket;
    MPI_Finalize();
    return 0;
}

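/*
Single-node runs (8 and 4 processes):
SCCL_DEBUG_LEVEL=WARN SCCL_DEBUG_SUBSYS=GRAPH mpirun --allow-run-as-root -np 8 3_socket_mpi_data
SCCL_DEBUG_LEVEL=WARN SCCL_DEBUG_SUBSYS=GRAPH mpirun --allow-run-as-root -np 4 3_socket_mpi_data

Optional arguments: --test-type allgather|broadcast, --data-len <bytes>, --num-iterations <count>
*/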
