#include "socket.h"
#include "debug.h"
#include "check.h"

#include <sys/socket.h>
#include <arpa/inet.h>
#include <netinet/tcp.h>
#include <netdb.h>
#include <fcntl.h>
#include <poll.h>

using namespace sccl;

// Number of request slots per comm; also scales the length of each helper thread's task queue.
#define MAX_REQUESTS 8
// Capacity of the per-comm helper thread arrays (helperThread / threadResources).
#define MAX_THREADS 16
// Capacity of the per-comm socket array and of a request's subtask pointer array.
#define MAX_SOCKETS 64
// One chunk of a request, progressed on a single socket by a helper thread.
struct scclNetSocketTask {
    // Operation passed to scclSocketProgress (e.g. SCCL_SOCKET_SEND).
    int op;
    // Start of this chunk inside the request's buffer.
    void* data;
    // Chunk size handed to scclSocketProgress.
    int size;
    // Socket this chunk is transferred on.
    struct scclSocket* sock;
    // Progress so far; the task is complete once offset == size.
    int offset;
    // 0 = slot free, 1 = slot holds a live task.
    int used;
    // Last result returned by scclSocketProgress for this task (scclSuccess initially).
    scclResult_t result;
};
// Fixed-size ring of task slots owned by one helper thread.
struct scclNetSocketTaskQueue {
    // Index of the slot that will be handed out next; wraps modulo len.
    int next;
    // Number of slots in tasks.
    int len;
    // Slot array; NULL until the first task is requested for this thread.
    struct scclNetSocketTask* tasks;
};
// A user-visible send/recv request, polled through scclNetSocketTest.
struct scclNetSocketRequest {
    // Operation passed to scclSocketProgress (e.g. SCCL_SOCKET_SEND / SCCL_SOCKET_RECV).
    int op;
    // User buffer for the whole request.
    void* data;
    // Requested size; replaced by the size exchanged with the peer once known.
    int size;
    // Socket used to exchange the size, and the payload when no subtasks are created.
    struct scclSocket* ctrlSock;
    // Payload progress when the request is driven on the calling thread.
    int offset;
    // 0 = idle/finished, 1 = size exchange pending, 2 = size exchanged and payload in flight.
    int used;
    // Communicator supplying the sockets and helper threads for subtasks.
    struct scclNetSocketComm* comm;
    // Subtasks this request was split into; only the first nSubs entries are valid.
    struct scclNetSocketTask* tasks[MAX_SOCKETS];
    // Number of subtasks; 0 means the payload is progressed on ctrlSock by the caller.
    int nSubs;
};
// Per-helper-thread state passed to persistentSocketThread.
struct scclNetSocketThreadResources {
    // Task slots this thread scans and progresses.
    struct scclNetSocketTaskQueue threadTaskQueue;
    // Non-zero asks the helper thread to exit.
    int stop;
    // Owning communicator.
    struct scclNetSocketComm* comm;
    // Protects updates of threadTaskQueue.next and pairs with threadCond.
    pthread_mutex_t threadLock;
    // Signalled when a new task is queued so an idle thread wakes up.
    pthread_cond_t threadCond;
};
// Communicator: a set of sockets plus the helper threads that drive them.
struct scclNetSocketComm {
    // Control socket of the communicator.
    // NOTE(review): requests carry their own ctrlSock pointer; presumably it points here — confirm against the code that builds requests.
    struct scclSocket ctrlSock;
    // Sockets assigned to tasks one after another via nextSock.
    struct scclSocket socks[MAX_SOCKETS];
    // Device identifiers; in this file they only appear in the helper thread name.
    int dev;
    int hipDev;
    // Number of entries of socks in use; a request is split into at most this many tasks.
    int nSocks;
    // Number of helper threads; a task goes to thread (nextSock % nThreads).
    int nThreads;
    // Index of the socket the next task will use; advanced modulo nSocks.
    int nextSock;
    // Request slots (not referenced by the code visible in this file section).
    struct scclNetSocketRequest requests[MAX_REQUESTS];
    // Handles of the helper threads, indexed like threadResources.
    pthread_t helperThread[MAX_THREADS];
    // Per-thread queues and synchronisation objects.
    struct scclNetSocketThreadResources threadResources[MAX_THREADS];
};

// Integer division rounded up (for positive operands); note that y is evaluated twice.
#define DIVUP(x, y) (((x) + (y) - 1) / (y))
// Lower bound (64 KiB) on the size of a subtask when a request is split across sockets.
#define MIN_CHUNKSIZE (64 * 1024)

/**
 * Allocates a zero-initialised array of nelem objects of type T.
 *
 * @param ptr      Receives the new array on success; left untouched on failure.
 * @param nelem    Number of elements to allocate.
 * @param filefunc Call-site file name supplied by the scclCalloc macro (currently unused).
 * @param line     Call-site line number supplied by the scclCalloc macro (currently unused).
 * @return scclSuccess on success, scclSystemError when the allocation fails.
 */
template <typename T>
scclResult_t scclCallocDebug(T** ptr, size_t nelem, const char* filefunc, int line) {
    (void)filefunc;
    (void)line;
    // calloc zero-fills and fails cleanly when nelem * sizeof(T) would overflow,
    // whereas malloc(nelem * sizeof(T)) would silently allocate a too-small buffer.
    void* p = calloc(nelem, sizeof(T));
    if(p == NULL) {
        WARN("Failed to malloc %zu elements of %zu bytes", nelem, sizeof(T));
        return scclSystemError;
    }
    *ptr = (T*)p;
    return scclSuccess;
}
// Convenience wrapper that records the call site.
#define scclCalloc(...) scclCallocDebug(__VA_ARGS__, __FILE__, __LINE__)

/**
 * Gives `thread` a printf-style formatted name.
 *
 * The formatted text is truncated to fit a 16-byte buffer (terminator included).
 * Compiles to a no-op when _GNU_SOURCE is not defined, because pthread_setname_np
 * is only used under that guard. The result of pthread_setname_np is ignored.
 *
 * @param thread Thread to rename.
 * @param fmt    printf-style format string, followed by its arguments.
 */
void scclSetThreadName(pthread_t thread, const char* fmt, ...) {
#ifdef _GNU_SOURCE
    char label[16];
    va_list ap;

    va_start(ap, fmt);
    vsnprintf(label, sizeof(label), fmt, ap);
    va_end(ap);

    pthread_setname_np(thread, label);
#endif
}

/**
 * Helper-thread entry point: repeatedly progresses every live task in this
 * thread's queue and sleeps on the condition variable when nothing is pending.
 *
 * The queue is scanned in groups of nSocksPerThread slots; a group is retried
 * until none of its tasks has bytes left, then the scan moves to the next group.
 *
 * @param args_ Pointer to this thread's scclNetSocketThreadResources.
 * @return Always NULL. The thread exits when `stop` becomes non-zero or when
 *         scclSocketProgress reports an error (the error is left in the task's
 *         `result` field for scclNetSocketTest to pick up).
 */
void* persistentSocketThread(void* args_) {
    struct scclNetSocketThreadResources* resource = (struct scclNetSocketThreadResources*)args_;
    struct scclNetSocketComm* comm                = resource->comm;
    struct scclNetSocketTaskQueue* myQueue        = &resource->threadTaskQueue;
    // Group size for the scan. Clamp to at least 1: a zero stride (nSocks < nThreads)
    // would never advance the outer loop, and nThreads == 0 would divide by zero.
    int nSocksPerThread = comm->nThreads > 0 ? comm->nSocks / comm->nThreads : 0;
    if(nSocksPerThread < 1)
        nSocksPerThread = 1;
    while(1) {
        int idle = 1;
        int mark = myQueue->next; // mark newest task seen
        for(int i = 0; i < myQueue->len; i += nSocksPerThread) {
            int repeat;
            do {
                repeat = 0;
                // The second bound keeps the last group inside the queue when len is
                // not a multiple of nSocksPerThread.
                for(int j = 0; j < nSocksPerThread && i + j < myQueue->len; j++) {
                    struct scclNetSocketTask* r = myQueue->tasks + i + j;
                    if(r->used == 1 && r->offset < r->size) {
                        r->result = scclSocketProgress(r->op, r->sock, r->data, r->size, &r->offset);
                        if(r->result != scclSuccess) {
                            WARN("NET/Socket : socket progress error");
                            return NULL;
                        }
                        idle = 0;
                        if(r->offset < r->size)
                            repeat = 1;
                    }
                }
            } while(repeat);
        }
        if(idle) {
            pthread_mutex_lock(&resource->threadLock);
            while(mark == myQueue->next && resource->stop == 0) { // no new tasks, wait
                pthread_cond_wait(&resource->threadCond, &resource->threadLock);
            }
            pthread_mutex_unlock(&resource->threadLock);
        }
        if(resource->stop)
            return NULL;
    }
}

/**
 * Reserves the next task slot for a chunk of a request and wakes the helper
 * thread responsible for it.
 *
 * The helper thread is chosen as (comm->nextSock % comm->nThreads). On the first
 * use of a thread its task queue, mutex, condition variable and the thread itself
 * are created lazily.
 *
 * @param comm Communicator providing sockets and helper threads.
 * @param op   Operation to run on the chunk (e.g. SCCL_SOCKET_SEND).
 * @param data Start of the chunk.
 * @param size Size of the chunk.
 * @param req  Receives the reserved task on success.
 * @return scclSuccess on success;
 *         scclInternalError when the comm has no sockets/threads or the next slot is still in use;
 *         scclSystemError when the helper thread cannot be created;
 *         any error from scclCalloc.
 */
scclResult_t scclNetSocketGetTask(struct scclNetSocketComm* comm, int op, void* data, int size, struct scclNetSocketTask** req) {
    // Both counts are used as modulo divisors below.
    if(comm->nThreads <= 0 || comm->nSocks <= 0) {
        WARN("NET/Socket : invalid configuration nSocks %d nThreads %d", comm->nSocks, comm->nThreads);
        return scclInternalError;
    }
    int tid                                  = comm->nextSock % comm->nThreads;
    struct scclNetSocketThreadResources* res = comm->threadResources + tid;
    struct scclNetSocketTaskQueue* queue     = &res->threadTaskQueue;
    // create helper threads and prepare per-thread task queue
    if(queue->tasks == NULL) {
        // each request can be divided up to nSocks tasks, and
        // these tasks are distributed to nThreads threads,
        // we need to make sure each thread queue has enough slots for MAX_REQUESTS
        queue->len = MAX_REQUESTS * DIVUP(comm->nSocks, comm->nThreads);
        SCCLCHECK(scclCalloc(&queue->tasks, queue->len));
        queue->next = 0;
        res->comm   = comm;
        pthread_mutex_init(&res->threadLock, NULL);
        pthread_cond_init(&res->threadCond, NULL);
        int err = pthread_create(comm->helperThread + tid, NULL, persistentSocketThread, res);
        if(err != 0) {
            // Without a thread nothing would ever drain this queue. Undo the setup so
            // that a later call retries instead of queueing tasks that never progress.
            WARN("NET/Socket : failed to create helper thread %d (error %d)", tid, err);
            pthread_cond_destroy(&res->threadCond);
            pthread_mutex_destroy(&res->threadLock);
            free(queue->tasks);
            queue->tasks = NULL;
            queue->len   = 0;
            return scclSystemError;
        }
        scclSetThreadName(comm->helperThread[tid], "NCCL Sock%c%1u%2u%2u", op == SCCL_SOCKET_SEND ? 'S' : 'R', comm->dev, tid, comm->hipDev);
    }
    struct scclNetSocketTask* r = queue->tasks + queue->next;
    if(r->used == 0) {
        r->op          = op;
        r->data        = data;
        r->size        = size;
        r->sock        = comm->socks + comm->nextSock;
        r->offset      = 0;
        r->result      = scclSuccess;
        comm->nextSock = (comm->nextSock + 1) % comm->nSocks;
        r->used        = 1;
        *req           = r;
        // Advance the queue head under the lock so the helper thread's
        // "no new task" check and this update cannot interleave.
        pthread_mutex_lock(&res->threadLock);
        queue->next = (queue->next + 1) % queue->len;
        pthread_cond_signal(&res->threadCond);
        pthread_mutex_unlock(&res->threadLock);
        return scclSuccess;
    }
    WARN("NET/Socket : unable to allocate subtasks");
    return scclInternalError;
}

/**
 * @brief Polls a socket request and drives it forward.
 *
 * Behaviour depends on the request state stored in `used`:
 * - used == 0: nothing is done; the call returns scclSuccess with *done == 0.
 * - used == 1: the message size is exchanged on the control socket. If no byte
 *   could be transferred yet the call returns and must be retried. Otherwise the
 *   exchange is completed, a receive larger than the posted buffer is rejected,
 *   and the payload is split into subtasks (one per socket, each at least
 *   MIN_CHUNKSIZE) when the communicator has sockets. The state becomes 2.
 * - used == 2: either all subtasks are checked for completion/errors, or, when
 *   there are no subtasks, the payload is progressed on the control socket by
 *   the calling thread. On completion the request and its task slots are released.
 *
 * @param request Pointer to a scclNetSocketRequest.
 * @param done    Output: 1 when the request has completed, 0 otherwise.
 * @param size    Optional output (may be NULL): number of bytes transferred, set on completion.
 * @return scclSuccess when polling succeeded (completed or not);
 *         scclInternalError for a NULL request;
 *         scclInvalidUsage when the peer sends more than the receive buffer holds;
 *         otherwise the error reported by the socket layer or by a subtask.
 */
scclResult_t scclNetSocketTest(void* request, int* done, int* size) {
    *done                          = 0;
    struct scclNetSocketRequest* r = (struct scclNetSocketRequest*)request;
    if(r == NULL) {
        INFO(SCCL_LOG_CODEALL, "NET/Socket : test called with NULL request");
        return scclInternalError;
    }
    INFO(SCCL_LOG_CODEALL, "NET/Socket : test called request used:%d\n", r->used);
    if(r->used == 1) { /* try to send/recv size */
        int data   = r->size;
        int offset = 0;
        SCCLCHECK(scclSocketProgress(r->op, r->ctrlSock, &data, sizeof(int), &offset));

        if(offset == 0)
            return scclSuccess; /* Not ready -- retry later */

        // Not sure we could ever receive less than 4 bytes, but just in case ...
        if(offset < (int)sizeof(int))
            SCCLCHECK(scclSocketWait(r->op, r->ctrlSock, &data, sizeof(int), &offset));

        // Check size is less or equal to the size provided by the user
        if(r->op == SCCL_SOCKET_RECV && data > r->size) {
            char line[SOCKET_NAME_MAXLEN + 1];
            union scclSocketAddress addr;
            scclSocketGetAddr(r->ctrlSock, &addr);
            WARN("NET/Socket : peer %s message truncated : receiving %d bytes instead of %d. If you believe your socket network is in healthy state, \
          there may be a mismatch in collective sizes or environment settings (e.g. SCCL_PROTO, SCCL_ALGO) between ranks",
                 scclSocketToString(&addr, line),
                 data,
                 r->size);
            return scclInvalidUsage;
        }
        r->size   = data;
        r->offset = 0;
        r->used   = 2; // done exchanging size
        // divide into subtasks
        int chunkOffset = 0, i = 0;
        // A request without a communicator (or with no sockets) has no helper
        // threads to use; it is progressed on the control socket below.
        if(r->comm != NULL && r->comm->nSocks > 0) {
            // each request can be divided up to nSocks tasks
            int taskSize = std::max(MIN_CHUNKSIZE, DIVUP(r->size, r->comm->nSocks));
            while(chunkOffset < r->size) {
                int chunkSize = std::min(taskSize, r->size - chunkOffset);
                SCCLCHECK(scclNetSocketGetTask(r->comm, r->op, (char*)(r->data) + chunkOffset, chunkSize, r->tasks + i++));
                chunkOffset += chunkSize;
            }
        }
        r->nSubs = i;
    }
    if(r->used == 2) { // already exchanged size
        if(r->nSubs > 0) {
            int nCompleted = 0;
            for(int i = 0; i < r->nSubs; i++) {
                struct scclNetSocketTask* sub = r->tasks[i];
                if(sub->result != scclSuccess)
                    return sub->result;
                if(sub->offset == sub->size)
                    nCompleted++;
            }
            if(nCompleted == r->nSubs) {
                if(size)
                    *size = r->size;
                *done   = 1;
                r->used = 0;
                // Hand the task slots back to their queues.
                for(int i = 0; i < r->nSubs; i++) {
                    struct scclNetSocketTask* sub = r->tasks[i];
                    sub->used                     = 0;
                }
            }
        } else { // progress request using main thread
            if(r->offset < r->size) {
                SCCLCHECK(scclSocketProgress(r->op, r->ctrlSock, r->data, r->size, &r->offset));
            }
            if(r->offset == r->size) {
                if(size)
                    *size = r->size;
                *done   = 1;
                r->used = 0;
            }
        }
    }
    return scclSuccess;
}

int main(int argc, char* argv[]) {
    struct scclNetSocketRequest* request = (struct scclNetSocketRequest*)malloc(sizeof(struct scclNetSocketRequest));
    request->op                          = SCCL_SOCKET_SEND;
    request->used                        = 1;
    request->size                        = 1024;
    request->data                        = (char*)malloc(request->size);
    request->ctrlSock                    = NULL;
    request->comm                        = NULL;
    request->nSubs                       = 0;
    int done;
    int sizes[32];
    printf("test\n");
    INFO(SCCL_LOG_CODEALL, "test INFO");
    SCCLCHECK(scclSocketInit(request));
    SCCLCHECK(scclNetSocketTest(request, &done, sizes));
    if(done) {
        printf("done\n");
    }
}