#include <unistd.h>
#include <sys/types.h>
#include <string.h>
#include "bootstrap_net.h"

namespace sccl {
namespace hardware {
namespace topology {
namespace bootstrap {

namespace bootstrap_net {
/* Init functions */
static char bootstrapNetIfName[MAX_IF_NAME_SIZE + 1];
static scclSocketAddress_t bootstrapNetIfAddr;
static int bootstrapNetInitDone  = 0;
pthread_mutex_t bootstrapNetLock = PTHREAD_MUTEX_INITIALIZER;

/**
 * @brief 初始化引导网络
 *
 * 该函数用于初始化SCCL的引导网络。它会检查环境变量"SCCL_COMM_ID"来获取远程地址，
 * 如果没有设置则自动查找可用的网络接口。函数使用互斥锁确保线程安全。
 *
 * @return scclResult_t 返回操作结果：
 *      - scclSuccess: 初始化成功
 *      - scclInvalidArgument: 无效的SCCL_COMM_ID格式
 *      - scclSystemError: 找不到匹配的网络接口
 *      - scclInternalError: 找不到可用的网络接口
 */
scclResult_t bootstrapNetInit() {
    if(bootstrapNetInitDone == 0) {
        pthread_mutex_lock(&bootstrapNetLock);
        if(bootstrapNetInitDone == 0) {
            char* env = getenv("SCCL_COMM_ID");
            if(env) {
                scclSocketAddress_t remoteAddr;
                if(net::host::scclSocketGetAddrFromString(&remoteAddr, env) != scclSuccess) {
                    WARN("Invalid SCCL_COMM_ID, please use format: <ipv4>:<port> or [<ipv6>]:<port> or <hostname>:<port>");
                    return scclInvalidArgument;
                }
                if(net::host::scclFindInterfaceMatchSubnet(bootstrapNetIfName, &bootstrapNetIfAddr, &remoteAddr, MAX_IF_NAME_SIZE, 1) <= 0) {
                    WARN("NET/Socket : No usable listening interface found");
                    return scclSystemError;
                }
            } else {
                int nIfs = net::host::scclFindSocketInterfaces(bootstrapNetIfName, &bootstrapNetIfAddr, MAX_IF_NAME_SIZE, 1);
                if(nIfs <= 0) {
                    WARN("Bootstrap : no socket interface found");
                    return scclInternalError;
                }
            }
            char line[SOCKET_NAME_MAXLEN + MAX_IF_NAME_SIZE + 2];
            sprintf(line, " %s:", bootstrapNetIfName);
            net::host::scclSocketToString(&bootstrapNetIfAddr, line + strlen(line));
            INFO(SCCL_LOG_BOOTSTRAP, "Bootstrap : Using%s", line);
            bootstrapNetInitDone = 1;

            printf("line=%s\n", line);
        }
        pthread_mutex_unlock(&bootstrapNetLock);
    }
    return scclSuccess;
}

// Additional sync functions
/**
 * 通过网络发送数据
 *
 * @param sock 已连接的socket指针
 * @param data 要发送的数据指针
 * @param size 要发送的数据大小(字节)
 * @return scclResult_t 返回操作结果(scclSuccess表示成功)
 *
 * @note 先发送数据大小(sizeof(int))，再发送实际数据
 */
scclResult_t bootstrapNetSend(scclSocket_t* sock, void* data, int size) {
    SCCLCHECK(net::host::scclSocketSend(sock, &size, sizeof(int)));
    SCCLCHECK(net::host::scclSocketSend(sock, data, size));
    return scclSuccess;
}

/**
 * 从socket接收数据
 *
 * @param sock 要接收数据的socket
 * @param data 接收数据的缓冲区
 * @param size 缓冲区大小
 * @return scclResult_t 返回操作结果，成功返回scclSuccess，否则返回错误码
 *
 * @note 如果接收到的数据大小超过缓冲区大小，会截断数据并返回scclInternalError
 */
scclResult_t bootstrapNetRecv(scclSocket_t* sock, void* data, int size) {
    int recvSize;
    SCCLCHECK(net::host::scclSocketRecv(sock, &recvSize, sizeof(int)));
    if(recvSize > size) {
        WARN("Message truncated : received %d bytes instead of %d", recvSize, size);
        return scclInternalError;
    }
    SCCLCHECK(net::host::scclSocketRecv(sock, data, std::min(recvSize, size)));
    return scclSuccess;
}

} // namespace bootstrap_net

/**
 * 将未预期的连接请求加入队列
 *
 * @param state 引导状态指针
 * @param peer 对端节点ID
 * @param tag 连接标签
 * @param sock 套接字指针
 * @return 成功返回scclSuccess
 *
 * @note 该函数用于处理未预期的连接请求，将其加入等待队列
 */
scclResult_t unexpectedEnqueue(struct bootstrapState* state, int peer, int tag, scclSocket_t* sock) {
    // New unex
    struct unexConn* unex;
    SCCLCHECK(scclCalloc(&unex, 1));
    unex->peer = peer;
    unex->tag  = tag;
    memcpy(&unex->sock, sock, sizeof(scclSocket_t));

    // Enqueue
    struct unexConn* list = state->unexpectedConnections;
    if(list == NULL) {
        state->unexpectedConnections = unex;
        return scclSuccess;
    }
    while(list->next)
        list = list->next;
    list->next = unex;
    return scclSuccess;
}

/**
 * 从意外连接队列中查找并移除指定peer和tag的连接
 *
 * @param state 引导状态指针
 * @param peer 目标peer ID
 * @param tag 目标tag值
 * @param sock 输出参数，用于存储找到的socket
 * @param found 输出参数，指示是否找到匹配项
 * @return 总是返回scclSuccess
 *
 * @note 该函数会遍历意外连接链表，查找匹配peer和tag的连接，
 *       找到后将其从链表中移除并释放内存，通过sock参数返回socket信息
 */
scclResult_t unexpectedDequeue(struct bootstrapState* state, int peer, int tag, scclSocket_t* sock, int* found) {
    struct unexConn* elem = state->unexpectedConnections;
    struct unexConn* prev = NULL;
    *found                = 0;
    while(elem) {
        if(elem->peer == peer && elem->tag == tag) {
            if(prev == NULL) {
                state->unexpectedConnections = elem->next;
            } else {
                prev->next = elem->next;
            }
            memcpy(sock, &elem->sock, sizeof(scclSocket_t));
            free(elem);
            *found = 1;
            return scclSuccess;
        }
        prev = elem;
        elem = elem->next;
    }
    return scclSuccess;
}

/**
 * 释放未预期的连接链表
 *
 * 遍历并释放bootstrapState中存储的所有未预期连接
 *
 * @param state 包含未预期连接链表的状态结构体指针
 */
static void unexpectedFree(struct bootstrapState* state) {
    struct unexConn* elem = state->unexpectedConnections;
    struct unexConn* prev = NULL;

    while(elem) {
        prev = elem;
        elem = elem->next;
        free(prev);
    }
    return;
}

/**
 * 执行基于环的AllGather操作
 *
 * @param commState 通信状态指针
 * @param allData 用于收集所有rank数据的缓冲区
 * @param size 每个rank数据块的大小(字节)
 * @return 成功返回scclSuccess，失败返回错误码
 *
 * @note 该函数实现了一个简单的基于环的AllGather算法：
 *       1. 每个rank在步骤i从(rank-i-1)接收数据
 *       2. 将前一步骤从(rank-i)接收的数据发送给右侧rank
 *       3. 共进行nranks-1次步骤完成全收集
 */
scclResult_t bootstrapAllGather(void* commState, void* allData, int size) {
    struct bootstrapState* state = (struct bootstrapState*)commState;
    char* data                   = (char*)allData;
    int rank                     = state->rank;
    int nranks                   = state->nranks;

    INFO(SCCL_LOG_BOOTSTRAP, "rank %d nranks %d size %d", rank, nranks, size);

    /* Simple ring based AllGather
     * At each step i receive data from (rank-i-1) from left
     * and send previous step's data from (rank-i) to right
     */
    for(int i = 0; i < nranks - 1; i++) {
        size_t rslice = (rank - i - 1 + nranks) % nranks;
        size_t sslice = (rank - i + nranks) % nranks;

        // Send slice to the right
        SCCLCHECK(bootstrap_net::bootstrapNetSend(&state->ringSendSocket, data + sslice * size, size));
        // Recv slice from the left
        SCCLCHECK(bootstrap_net::bootstrapNetRecv(&state->ringRecvSocket, data + rslice * size, size));
    }

    INFO(SCCL_LOG_BOOTSTRAP, "rank %d nranks %d size %d - DONE", rank, nranks, size);
    return scclSuccess;
}

/**
 * 通过socket向指定对等节点发送数据
 *
 * @param commState 通信状态指针
 * @param peer 对等节点编号
 * @param tag 消息标签
 * @param data 要发送的数据指针
 * @param size 数据大小(字节)
 * @return scclResult_t 返回操作结果状态码(scclSuccess表示成功)
 */
scclResult_t bootstrapSend(void* commState, int peer, int tag, void* data, int size) {
    scclResult_t ret             = scclSuccess;
    struct bootstrapState* state = (struct bootstrapState*)commState;
    scclSocket_t sock;

    SCCLCHECKGOTO(net::host::scclSocketInit(&sock, state->peerCommAddresses + peer, state->magic, net::host::scclSocketTypeBootstrap), ret, fail);
    SCCLCHECKGOTO(net::host::scclSocketConnect(&sock), ret, fail);
    SCCLCHECKGOTO(bootstrap_net::bootstrapNetSend(&sock, &state->rank, sizeof(int)), ret, fail);
    SCCLCHECKGOTO(bootstrap_net::bootstrapNetSend(&sock, &tag, sizeof(int)), ret, fail);
    SCCLCHECKGOTO(bootstrap_net::bootstrapNetSend(&sock, data, size), ret, fail);

exit:
    SCCLCHECK(net::host::scclSocketClose(&sock));
    return ret;
fail:
    goto exit;
}

/**
 * @brief 从指定对等节点接收数据
 *
 * 该函数首先检查未预期的连接队列，若找到匹配的(peer, tag)则直接接收数据。
 * 若未找到，则持续监听新连接，接收对等节点和标签信息进行匹配。
 * 若匹配成功则接收数据，否则将连接信息存入未预期队列供后续使用。
 *
 * @param commState 通信状态指针
 * @param peer 对等节点标识
 * @param tag 消息标签
 * @param data 接收数据缓冲区
 * @param size 接收数据大小
 * @return scclResult_t 返回操作结果(scclSuccess表示成功)
 */
scclResult_t bootstrapRecv(void* commState, int peer, int tag, void* data, int size) {
    scclResult_t ret             = scclSuccess;
    struct bootstrapState* state = (struct bootstrapState*)commState;
    scclSocket_t sock;
    int newPeer, newTag;

    // Search unexpected connections first
    int found;
    SCCLCHECK(unexpectedDequeue(state, peer, tag, &sock, &found));
    if(found) {
        SCCLCHECKGOTO(bootstrap_net::bootstrapNetRecv(&sock, ((char*)data), size), ret, fail);
        goto exit;
    }

    // Then look for new connections
    while(1) {
        SCCLCHECKGOTO(net::host::scclSocketInit(&sock), ret, fail);
        SCCLCHECKGOTO(net::host::scclSocketAccept(&sock, &state->listenSock), ret, fail);
        SCCLCHECKGOTO(bootstrap_net::bootstrapNetRecv(&sock, &newPeer, sizeof(int)), ret, fail);
        SCCLCHECKGOTO(bootstrap_net::bootstrapNetRecv(&sock, &newTag, sizeof(int)), ret, fail);
        if(newPeer == peer && newTag == tag) {
            SCCLCHECKGOTO(bootstrap_net::bootstrapNetRecv(&sock, ((char*)data), size), ret, fail);
            goto exit;
        }
        // Unexpected connection. Save for later.
        SCCLCHECKGOTO(unexpectedEnqueue(state, newPeer, newTag, &sock), ret, fail);
    }
exit:
    SCCLCHECK(net::host::scclSocketClose(&sock));
    return ret;
fail:
    goto exit;
}

scclResult_t bootstrapInit() {}

// /**
//  * @brief 初始化bootstrap网络通信
//  *
//  * 该函数负责初始化bootstrap网络通信环境，包括：
//  * 1. 创建监听socket供其他rank连接
//  * 2. 与root节点交换连接信息
//  * 3. 建立环形通信拓扑
//  * 4. 收集所有peer的通信地址
//  * 5. 创建并收集代理服务地址
//  *
//  * @param handle bootstrap句柄
//  * @param comm bootstrap通信上下文
//  * @return scclResult_t 返回操作结果，scclSuccess表示成功
//  */
// scclResult_t bootstrapInit(struct scclBootstrapHandle* handle, struct scclBootstrapComm* comm) {
//     int rank   = comm->rank;   // 当前进程的排名
//     int nranks = comm->nRanks; // 进程的总数

//     struct bootstrapState* state;      // 引导状态结构体
//     scclSocket_t* proxySocket;         // 代理套接字
//     scclSocketAddress_t nextAddr;      // 下一个地址
//     scclSocket_t sock, listenSockRoot; // 套接字和根监听套接字
//     struct extInfo info = {0};         // 扩展信息结构体

//     SCCLCHECK(scclCalloc(&state, 1));           // 分配引导状态结构体
//     state->rank      = rank;                    // 设置当前进程的排名
//     state->nranks    = nranks;                  // 设置进程的总数
//     state->abortFlag = comm->abortFlag;         // 设置中止标志
//     comm->bootstrap  = state;                   // 将引导状态结构体赋值给通信结构体
//     comm->magic = state->magic = handle->magic; // 设置魔术值

//     INFO(SCCL_LOG_BOOTSTRAP, "rank %d nranks %d", rank, nranks); // 打印日志信息

//     info.rank   = rank;   // 设置扩展信息结构体中的排名
//     info.nranks = nranks; // 设置扩展信息结构体中的进程总数
//     // 创建套接字供其他进程联系
//     SCCLCHECK(
//         net::host::scclSocketInit(&state->listenSock, &bootstrap_net::bootstrapNetIfAddr, comm->magic, net::host::scclSocketTypeBootstrap, comm->abortFlag));
//     SCCLCHECK(net::host::scclSocketListen(&state->listenSock));                          // 监听套接字
//     SCCLCHECK(net::host::scclSocketGetAddr(&state->listenSock, &info.extAddressListen)); // 获取监听套接字地址

//     // 创建套接字供根进程联系
//     SCCLCHECK(net::host::scclSocketInit(&listenSockRoot, &bootstrap_net::bootstrapNetIfAddr, comm->magic, net::host::scclSocketTypeBootstrap,
//     comm->abortFlag)); SCCLCHECK(net::host::scclSocketListen(&listenSockRoot));                              // 监听根进程套接字
//     SCCLCHECK(net::host::scclSocketGetAddr(&listenSockRoot, &info.extAddressListenRoot)); // 获取根进程监听套接字地址

//     // // 分散连接时间以避免根进程过载
//     // if(nranks > 128) {
//     //     long msec = rank;
//     //     struct timespec tv;
//     //     tv.tv_sec  = msec / 1000;
//     //     tv.tv_nsec = 1000000 * (msec % 1000);
//     //     TRACE(SCCL_LOG_BOOTSTRAP, "rank %d delaying connection to root by %ld msec", rank, msec);
//     //     (void)nanosleep(&tv, NULL);
//     // }

//     // 向根进程发送我的监听套接字信息
//     SCCLCHECK(net::host::scclSocketInit(&sock, &handle->addr, comm->magic, net::host::scclSocketTypeBootstrap, comm->abortFlag));
//     SCCLCHECK(net::host::scclSocketConnect(&sock));                         // 连接套接字
//     SCCLCHECK(bootstrap_net::bootstrapNetSend(&sock, &info, sizeof(info))); // 发送扩展信息
//     SCCLCHECK(net::host::scclSocketClose(&sock));                           // 关闭套接字

//     // 从根进程获取我在引导环中的“下一个”进程的信息
//     SCCLCHECK(net::host::scclSocketInit(&sock));                                               // 初始化套接字
//     SCCLCHECK(net::host::scclSocketAccept(&sock, &listenSockRoot));                            // 接受根进程的连接
//     SCCLCHECK(bootstrap_net::bootstrapNetRecv(&sock, &nextAddr, sizeof(scclSocketAddress_t))); // 接收下一个地址
//     SCCLCHECK(net::host::scclSocketClose(&sock));                                              // 关闭套接字
//     SCCLCHECK(net::host::scclSocketClose(&listenSockRoot));                                    // 关闭根监听套接字

//     SCCLCHECK(net::host::scclSocketInit(&state->ringSendSocket, &nextAddr, comm->magic, net::host::scclSocketTypeBootstrap, comm->abortFlag));
//     SCCLCHECK(net::host::scclSocketConnect(&state->ringSendSocket)); // 连接环发送套接字
//     // 接受引导环中前一个进程的连接请求
//     SCCLCHECK(net::host::scclSocketInit(&state->ringRecvSocket));                       // 初始化环接收套接字
//     SCCLCHECK(net::host::scclSocketAccept(&state->ringRecvSocket, &state->listenSock)); // 接受连接

//     // 全部收集所有监听处理器
//     SCCLCHECK(scclCalloc(&state->peerCommAddresses, nranks));                                     // 分配对等通信地址
//     SCCLCHECK(net::host::scclSocketGetAddr(&state->listenSock, state->peerCommAddresses + rank)); // 获取监听套接字地址
//     SCCLCHECK(bootstrapAllGather(state, state->peerCommAddresses, sizeof(scclSocketAddress_t)));  // 全部收集地址

//     // 创建服务代理
//     SCCLCHECK(scclCalloc(&state->peerProxyAddresses, nranks)); // 分配对等代理地址

//     // 代理通过消息中止；不要设置中止标志
//     SCCLCHECK(scclCalloc(&proxySocket, 1)); // 分配代理套接字
//     SCCLCHECK(net::host::scclSocketInit(proxySocket, &bootstrap_net::bootstrapNetIfAddr, comm->magic, net::host::scclSocketTypeProxy, comm->abortFlag));
//     SCCLCHECK(net::host::scclSocketListen(proxySocket));                                          // 监听代理套接字
//     SCCLCHECK(net::host::scclSocketGetAddr(proxySocket, state->peerProxyAddresses + rank));       // 获取代理套接字地址
//     SCCLCHECK(bootstrapAllGather(state, state->peerProxyAddresses, sizeof(scclSocketAddress_t))); // 全部收集代理地址
//     // SCCLCHECK(scclProxyInit(comm, proxySocket, state->peerProxyAddresses));

//     INFO(SCCL_LOG_BOOTSTRAP, "rank %d nranks %d - DONE", rank, nranks); // 打印完成日志信息

//     return scclSuccess; // 返回成功
// }

// /**
//  * @brief 在bootstrap通信中创建新的子通信域
//  *
//  * 该函数用于将当前通信域按照指定颜色和键值拆分为子通信域，并建立相应的环状通信拓扑。
//  *
//  * @param handle bootstrap句柄
//  * @param comm 新创建的子通信域
//  * @param parent 父通信域
//  * @param color 用于划分通信域的颜色值
//  * @param key 用于确定新通信域中进程排名的键值
//  * @param parentRanks 父通信域中的进程排名映射
//  *
//  * @return scclResult_t 返回操作结果，成功返回scclSuccess
//  *
//  * @note 函数会建立环状通信拓扑，包括：
//  *       1. 初始化监听socket和环形接收socket
//  *       2. 与前后节点交换地址信息
//  *       3. 执行AllGather收集所有节点的通信地址
//  *       4. 根据配置决定是否共享代理状态或创建新的代理服务
//  */
// scclResult_t
// bootstrapSplit(struct scclBootstrapHandle* handle, struct scclBootstrapComm* comm, struct scclBootstrapComm* parent, int color, int key, int* parentRanks) {
//     scclResult_t ret = scclSuccess;
//     int rank         = comm->rank;
//     int nranks       = comm->nRanks;
//     int prev, next;
//     scclSocketAddress_t listenAddr, tmpAddr;
//     scclSocket_t* proxySocket;
//     struct bootstrapState* state;

//     // SCCLCHECKGOTO(scclCalloc(&state, 1), ret, fail);
//     // state->rank      = rank;
//     // state->nranks    = nranks;
//     // state->abortFlag = comm->abortFlag;
//     // comm->bootstrap  = state;
//     // comm->magic = state->magic = handle->magic;

//     // prev = parentRanks[(rank - 1 + nranks) % nranks];
//     // next = parentRanks[(rank + 1) % nranks];

//     // // Setup my sockets for the allgather ring and other p2p connections
//     // SCCLCHECKGOTO(
//     //     net::host::scclSocketInit(&state->listenSock, &bootstrap_net::bootstrapNetIfAddr, comm->magic, net::host::scclSocketTypeBootstrap,
//     comm->abortFlag,
//     //     0), ret, fail);
//     // SCCLCHECKGOTO(net::host::scclSocketInit(&state->ringRecvSocket, NULL, comm->magic, net::host::scclSocketTypeBootstrap, comm->abortFlag, 0), ret,
//     fail);

//     // // Create socket for other ranks to contact me
//     // SCCLCHECKGOTO(net::host::scclSocketListen(&state->listenSock), ret, fail);

//     // // Get addr from next rank
//     // SCCLCHECKGOTO(net::host::scclSocketGetAddr(&state->listenSock, &listenAddr), ret, fail);
//     // SCCLCHECKGOTO(bootstrapSend(parent->bootstrap, prev, -2, &listenAddr, sizeof(scclSocketAddress_t)), ret, fail);
//     // SCCLCHECKGOTO(bootstrapRecv(parent->bootstrap, next, -2, &tmpAddr, sizeof(scclSocketAddress_t)), ret, fail);

//     // SCCLCHECKGOTO(net::host::scclSocketInit(&state->ringSendSocket, &tmpAddr, comm->magic, net::host::scclSocketTypeBootstrap, comm->abortFlag, 0), ret,
//     // fail); SCCLCHECKGOTO(net::host::scclSocketConnect(&state->ringSendSocket), ret, fail);
//     // // Accept the connect request from the previous rank in the AllGather ring
//     // SCCLCHECKGOTO(net::host::scclSocketAccept(&state->ringRecvSocket, &state->listenSock), ret, fail);

//     // // AllGather all listen handlers
//     // SCCLCHECKGOTO(scclCalloc(&state->peerCommAddresses, nranks), ret, fail);
//     // memcpy(state->peerCommAddresses + rank, &listenAddr, sizeof(scclSocketAddress_t));
//     // SCCLCHECKGOTO(bootstrapAllGather(state, state->peerCommAddresses, sizeof(scclSocketAddress_t)), ret, fail);

//     // if(parent->splitShare) {
//     //     /* map local rank to top parent local rank. */
//     //     for(int i = 0; i < nranks; ++i) {
//     //         comm->topParentRanks[i] = parent->topParentRanks[parentRanks[i]];
//     //     }
//     //     comm->proxyState = parent->sharedRes->proxyState;
//     //     scclAtomicRefCountIncrement(&parent->sharedRes->proxyState->refCount);
//     // } else {
//     //     // Create the service proxy
//     //     SCCLCHECKGOTO(scclCalloc(&state->peerProxyAddresses, nranks), ret, fail);
//     //     SCCLCHECKGOTO(scclCalloc(&proxySocket, 1), ret, fail);
//     //     SCCLCHECKGOTO(
//     //         net::host::scclSocketInit(proxySocket, &bootstrap_net::bootstrapNetIfAddr, comm->magic, net::host::scclSocketTypeProxy, comm->abortFlag, 0),
//     //         ret,
//     //         fail);
//     //     SCCLCHECKGOTO(net::host::scclSocketListen(proxySocket), ret, fail);
//     //     SCCLCHECKGOTO(net::host::scclSocketGetAddr(proxySocket, &tmpAddr), ret, fail);
//     //     memcpy(state->peerProxyAddresses + rank, &tmpAddr, sizeof(scclSocketAddress_t));
//     //     SCCLCHECKGOTO(bootstrapAllGather(state, state->peerProxyAddresses, sizeof(scclSocketAddress_t)), ret, fail);
//     //     // SCCLCHECKGOTO(scclProxyInit(comm, proxySocket, state->peerProxyAddresses), ret, fail);
//     // }

//     // INFO(sccl_INIT, "bootstrapSplit: rank %d nranks %d color %d key %d prev %d next %d - DONE", rank, nranks, color, key, prev, next);

// exit:
//     return ret;
// fail:
//     goto exit;
// }

} // namespace bootstrap
} // namespace topology
} // namespace hardware
} // namespace sccl
