#pragma once

#include <assert.h>
#include <pthread.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <poll.h>
#include <sys/types.h>
#include <unistd.h>

#include "base.h"
#include "net_utils.h"
#include "socket.h"

namespace sccl {
namespace hardware {
namespace net {
namespace net_socket {

/* Communication functions */
static constexpr int MAX_SOCKETS   = 64;
static constexpr int MAX_THREADS   = 16;
static constexpr int MAX_REQUESTS  = SCCL_NET_MAX_REQUESTS;
static constexpr int MIN_CHUNKSIZE = (64 * 1024);

enum scclNetSocketCommState : uint8_t {
    scclNetSocketCommStateStart   = 0,
    scclNetSocketCommStateConnect = 1,
    scclNetSocketCommStateAccept  = 3,
    scclNetSocketCommStateSend    = 4,
    scclNetSocketCommStateRecv    = 5,
};

struct scclNetSocketCommStage {
    enum scclNetSocketCommState state;
    uint8_t iteration;
    struct scclSocket* sock;
    struct scclNetSocketComm* comm = nullptr;
};

struct scclNetSocketHandle {
    union scclSocketAddress connectAddr;
    uint64_t magic; // random number to help debugging
    int nSocks;
    int nThreads;
    struct scclNetSocketCommStage stage;
};

struct scclNetSocketTask {
    int op;
    void* data;
    int size;
    struct scclSocket* sock = nullptr;
    int offset;
    int used;
    scclResult_t result;
};

struct scclNetSocketRequest {
    int op;
    void* data;
    int size;
    struct scclSocket* ctrlSock = nullptr;
    int offset;
    int used;
    struct scclNetSocketComm* comm               = nullptr;
    struct scclNetSocketTask* tasks[MAX_SOCKETS] = {nullptr};
    int nSubs;
};

struct scclNetSocketTaskQueue {
    int next;
    int len;
    struct scclNetSocketTask* tasks = nullptr;
};

struct scclNetSocketThreadResources {
    struct scclNetSocketTaskQueue threadTaskQueue;
    int stop;
    struct scclNetSocketComm* comm = nullptr;
    pthread_mutex_t threadLock;
    pthread_cond_t threadCond;
};

struct scclNetSocketListenComm {
    struct scclSocket sock;
    struct scclNetSocketCommStage stage;
    int nSocks;
    int nThreads;
    int dev;
};

struct scclNetSocketComm {
    struct scclSocket ctrlSock;
    struct scclSocket socks[MAX_SOCKETS];
    int dev;
    int hipDev;
    int nSocks;
    int nThreads;
    int nextSock;
    struct scclNetSocketRequest requests[MAX_REQUESTS];
    pthread_t helperThread[MAX_THREADS];
    struct scclNetSocketThreadResources threadResources[MAX_THREADS];
};

//////////////////////////////////
class scclNetSocket : public scclNetBase {
public:
    // 构造函数和析构函数
    scclNetSocket();
    virtual ~scclNetSocket();

    // 初始化网络。
    scclResult_t init() override;
    // 返回适配器的数量。
    scclResult_t devices(int* ndev) override;
    // 获取各种设备属性。
    scclResult_t getProperties(int dev, scclNetProperties_t* props) override;
    // 创建一个接收对象并提供一个句柄以连接到它。该句柄最多可以是 SCCL_NET_HANDLE_MAXSIZE 字节，并将在排名之间交换以创建连接。
    scclResult_t listen(int dev, void* handle, void** listenComm) override;
    // 连接到一个句柄并返回一个发送 comm 对象给该对等体。
    // 此调用不应阻塞以建立连接，而应成功返回 sendComm == NULL，并期望再次调用直到 sendComm != NULL。
    scclResult_t connect(int dev, void* handle, void** sendComm) override;
    // 在远程对等体调用 connect 后最终确定连接建立。
    // 此调用不应阻塞以建立连接，而应成功返回 recvComm == NULL，并期望再次调用直到 recvComm != NULL。
    scclResult_t accept(void* listenComm, void** recvComm) override;
    // 注册/注销内存。Comm 可以是 sendComm 或 recvComm。
    // 类型是 SCCL_PTR_HOST 或 SCCL_PTR_CUDA。
    scclResult_t regMr(void* comm, void* data, int size, int type, void** mhandle) override;
    /* DMA-BUF 支持 */
    scclResult_t regMrDmaBuf(void* comm, void* data, size_t size, int type, uint64_t offset, int fd, void** mhandle) override;
    // 注销IB内存区域(MR)
    scclResult_t deregMr(void* comm, void* mhandle) override;
    // 异步发送到对等体。
    // 如果调用不能执行（或会阻塞），则可能返回 request == NULL
    scclResult_t isend(void* sendComm, void* data, int size, int tag, void* mhandle, void** request) override;
    // 异步从对等体接收。 如果调用不能执行（或会阻塞），则可能返回 request == NULL
    scclResult_t irecv(void* recvComm, int n, void** data, int* sizes, int* tags, void** mhandles, void** request) override;
    // 执行刷新/栅栏操作，以确保所有使用 SCCL_PTR_CUDA 接收到的数据对 GPU 可见
    scclResult_t iflush(void* recvComm, int n, void** data, int* sizes, void** mhandles, void** request) override;
    // 测试请求是否完成。如果 size 不为 NULL，则返回发送/接收的字节数。
    scclResult_t test(void* request, int* done, int* sizes) override;
    // 关闭并释放 send/recv comm 对象
    scclResult_t closeSend(void* sendComm) override;
    scclResult_t closeRecv(void* recvComm) override;
    scclResult_t closeListen(void* listenComm) override;

private:
    struct scclNetSocketListenComm* socketComm = nullptr;

private:
    // 获取网络设备的PCI路径
    static scclResult_t scclNetSocketGetPciPath(char* devName, char** pciPath);
    // 获取指定网络设备的速度（单位：Mbps）
    scclResult_t scclNetSocketGetSpeed(char* devName, int* speed);
    // 持久化socket线程处理函数
    static void* persistentSocketThread(void* args_);
    // 为指定通信对象创建并获取一个网络套接字任务
    scclResult_t scclNetSocketGetTask(struct scclNetSocketComm* comm, int op, void* data, int size, struct scclNetSocketTask** req);
    // 获取指定设备的socket和线程数量配置
    scclResult_t scclNetSocketGetNsockNthread(int dev, int* ns, int* nt);
};

} // namespace net_socket
} // namespace net
} // namespace hardware
} // namespace sccl
