#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

#include <algorithm>
#include <chrono>
#include <iostream>
#include <memory>
#include <thread> // 为了使用 std::this_thread::sleep_for
#include <vector> // 引入vector库

#include "mpi.h"
#include "net.h"
#include "ipc_socket.h"
#include "thread_pool.h"

using namespace sccl;
typedef class sccl::hardware::net::ipc_socket::scclIpcSocket scclIpcSocket_t;

/**
 * Sends `dataLen` bytes from `data` to `dst_rank` through `ipcsocket`.
 *
 * On failure, prints a diagnostic to stderr and aborts the whole MPI job.
 *
 * @param ipcsocket socket whose scclIpcSocketSendData() is used
 * @param data      bytes to send
 * @param dataLen   number of bytes to send
 * @param dst_rank  destination rank
 */
template <typename T>
void send_data(T* ipcsocket, const void* data, size_t dataLen, int dst_rank) {
    const auto status = ipcsocket->scclIpcSocketSendData(data, dataLen, dst_rank);
    if(status != scclSuccess) {
        // perror() would print whatever errno happens to hold, and the sccl call is not known
        // to set it. Report the returned status (and the peer) instead.
        fprintf(stderr, "Failed to send data to rank %d (sccl status %d)\n", dst_rank, static_cast<int>(status));
        MPI_Abort(MPI_COMM_WORLD, 1);
    }
}

/**
 * Receives one message of at most `bufferLen` bytes into `buffer` through `ipcsocket`.
 *
 * The number of bytes received is written to `*receivedLen`. On failure, prints a
 * diagnostic to stderr and aborts the whole MPI job.
 *
 * @param ipcsocket   socket whose scclIpcSocketRecvData() is used
 * @param buffer      destination buffer
 * @param bufferLen   capacity of `buffer` in bytes
 * @param receivedLen out-parameter passed through to scclIpcSocketRecvData()
 */
template <typename T>
void recv_data(T* ipcsocket, void* buffer, size_t bufferLen, size_t* receivedLen) {
    const auto status = ipcsocket->scclIpcSocketRecvData(buffer, bufferLen, receivedLen);
    if(status != scclSuccess) {
        // Report the returned status rather than a possibly unrelated errno (see send_data).
        fprintf(stderr, "Failed to receive data (sccl status %d)\n", static_cast<int>(status));
        MPI_Abort(MPI_COMM_WORLD, 1);
    }
}

/**
 * Allgather test, version 1: built from point-to-point send_data / recv_data calls.
 *
 * Each rank fills a 256-byte buffer with "Data from process <rank>", sends it to every
 * other rank and receives one message per peer.
 *
 * - Every send and receive runs on its own std::thread.
 * - All threads are joined before anything is printed, so the receive buffer and the
 *   per-peer lengths are complete when they are read.
 * - Each receive writes its own length slot, so no two threads share an output variable.
 *
 * NOTE(review): recv_data() takes no source rank. Slot i therefore holds whichever message
 * the socket delivers into it, not necessarily the one sent by rank i.
 *
 * @param ipcsocket socket used for all sends and receives
 * @param rank      this process's rank
 * @param size      number of ranks
 * @return 0. Failures abort the job inside send_data / recv_data.
 */
template <typename T>
int test_allgather_ver1(T* ipcsocket, int rank, int size) {
    constexpr size_t kChunkLen = 256;
    std::vector<char> sendData(kChunkLen);
    std::vector<char> recvData(static_cast<size_t>(size) * kChunkLen);
    std::vector<size_t> receivedLens(static_cast<size_t>(size), 0);

    // Fill the outgoing buffer.
    snprintf(sendData.data(), sendData.size(), "Data from process %d", rank);

    // This rank's own contribution never goes over the socket; place it directly.
    std::copy(sendData.begin(), sendData.end(), recvData.begin() + static_cast<std::ptrdiff_t>(static_cast<size_t>(rank) * kChunkLen));
    receivedLens[static_cast<size_t>(rank)] = kChunkLen;

    const char* payload = sendData.data();
    std::vector<std::thread> workers;
    workers.reserve(2 * static_cast<size_t>(size));

    // Send to, and receive from, every other rank concurrently.
    for(int peer = 0; peer < size; ++peer) {
        if(peer == rank) {
            continue;
        }
        char* slot = recvData.data() + static_cast<size_t>(peer) * kChunkLen;
        size_t* lenOut = &receivedLens[static_cast<size_t>(peer)];
        workers.emplace_back([=] { send_data(ipcsocket, payload, kChunkLen, peer); });
        workers.emplace_back([=] { recv_data(ipcsocket, slot, kChunkLen, lenOut); });
    }

    // Wait for every transfer before touching recvData / receivedLens.
    for(std::thread& worker : workers) {
        worker.join();
    }

    printf("sendData.size()=%zu\n", kChunkLen);

    // Print what was received.
    for(int src = 0; src < size; ++src) {
        printf("Process %d received from process %d (receivedLen=%zu): %s\n",
               rank,
               src,
               receivedLens[static_cast<size_t>(src)],
               recvData.data() + static_cast<size_t>(src) * kChunkLen);
    }

    return 0;
}

/**
 * Allgather test, version 2: exercises scclIpcSocketAllgatherSync with its `wait`
 * argument set to true.
 *
 * Every rank contributes a 256-byte buffer holding "Data from process <rank>". It then
 * prints each of the `size` 256-byte slots of the gathered buffer.
 *
 * @return 0 on success. On failure, SCCLCHECK (defined elsewhere) presumably returns its
 *         error status early — TODO confirm against its definition.
 */
template <typename T>
int test_allgather_ver2(T* ipcsocket, int rank, int size) {
    constexpr int kSlotLen = 256;
    std::vector<char> outbound(kSlotLen);
    std::vector<char> gathered(size * kSlotLen);

    // This rank's contribution.
    snprintf(outbound.data(), outbound.size(), "Data from process %d", rank);
    SCCLCHECK(ipcsocket->scclIpcSocketAllgatherSync(outbound.data(), gathered.data(), outbound.size(), /*wait*/ true));

    // One line per gathered slot.
    for(int src = 0; src < size; ++src) {
        const char* slot = gathered.data() + src * outbound.size();
        printf("rank %d received from process %d: %s\n", rank, src, slot);
    }

    return 0;
}

/**
 * Allgather test, version 3: exercises scclIpcSocketAllgather (no `wait` argument).
 *
 * Each rank contributes a 256-byte buffer holding "Data from process <rank>". It then
 * walks the gathered buffer slot by slot and prints every entry.
 *
 * @return 0 on success. On failure, SCCLCHECK (defined elsewhere) presumably returns its
 *         error status early — TODO confirm against its definition.
 */
template <typename T>
int test_allgather_ver3(T* ipcsocket, int rank, int size) {
    constexpr int kSlotLen = 256;
    std::vector<char> mine(kSlotLen);
    std::vector<char> everyone(size * kSlotLen);

    snprintf(mine.data(), mine.size(), "Data from process %d", rank);
    SCCLCHECK(ipcsocket->scclIpcSocketAllgather(mine.data(), everyone.data(), mine.size()));

    // Advance one slot per peer through the gathered buffer.
    const char* cursor = everyone.data();
    for(int peer = 0; peer < size; ++peer, cursor += mine.size()) {
        printf("rank %d received from process %d: %s\n", rank, peer, cursor);
    }

    return 0;
}

/**
 * Broadcast test: rank 0 is the root. Only the root writes a message into its send
 * buffer; every rank calls scclIpcSocketBroadcast with `wait` = true and prints what
 * landed in its receive buffer.
 *
 * `size` is accepted for signature parity with the other tests but is not used here.
 *
 * @return 0 on success. On failure, SCCLCHECK (defined elsewhere) presumably returns its
 *         error status early — TODO confirm against its definition.
 */
template <typename T>
int test_broadcast_ver1(T* ipcsocket, int rank, int size) {
    constexpr int kMsgLen = 256;
    int rootRank = 0;
    std::vector<char> payload(kMsgLen);
    std::vector<char> inbox(kMsgLen);

    // Non-root ranks leave their zero-filled payload untouched.
    if(rank == rootRank)
        snprintf(payload.data(), payload.size(), "Data from root process %d", rank);

    SCCLCHECK(ipcsocket->scclIpcSocketBroadcast(payload.data(), inbox.data(), payload.size(), rootRank, /*wait*/ true));

    printf("rank %d received: %s\n", rank, inbox.data());
    return 0;
}

int main(int argc, char* argv[]) {
    MPI_Init(&argc, &argv);
    int rank, size;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    MPI_Comm_size(MPI_COMM_WORLD, &size);
    int dst_hash = 12345;

    scclIpcSocket_t* ipcsocket = new scclIpcSocket_t(rank, size, dst_hash);

    // test_allgather_ver1(ipcsocket, rank, size);
    // test_allgather_ver2(ipcsocket, rank, size);
    // test_allgather_ver3(ipcsocket, rank, size);
    test_broadcast_ver1(ipcsocket, rank, size);

    std::this_thread::sleep_for(std::chrono::seconds(10));
    // while(!ipcsocket->getPthreadPool()->allTasksCompleted()) {}
    // printf("delete ipcsocket... rank=%d\n", rank);

    delete(ipcsocket);
    MPI_Finalize();
    return 0;
}

/*
Single-node run (8 ranks on one machine):
SCCL_DEBUG_LEVEL=ABORT SCCL_DEBUG_SUBSYS=BOOTSTRAP mpirun --allow-run-as-root -np 8 3_socket_mpi_data
*/
