#include <stdlib.h>
#include <string.h>
#include <errno.h>
#include "ipcsocket.h"

namespace sccl {
namespace hardware {
namespace topology {
namespace bootstrap {
// Enable Linux abstract socket naming: after the name below is formatted into
// sun_path, its first byte is overwritten with '\0' (the "abstract socket trick"
// in the init/send functions), so the "/tmp/..." text never becomes a filesystem
// entry and no unlink() is required when this macro is defined.
#define USE_ABSTRACT_SOCKET

// Per-rank socket name template: %d = rank, %lx = hash printed in hex.
// NOTE(review): %lx assumes uint64_t is `unsigned long` (true on LP64 Linux);
// PRIx64 from <inttypes.h> would be the portable spelling — confirm before porting.
#define SCCL_IPC_SOCKNAME_STR "/tmp/sccl-socket-%d-%lx"

/**
 * @brief 初始化IPC套接字
 *
 * 创建一个UNIX域数据报套接字，并绑定到指定路径。支持抽象套接字和普通文件系统套接字两种模式。
 *
 * @param handle 指向scclIpcSocket结构体的指针，用于存储套接字信息
 * @param rank 进程排名，用于生成唯一的套接字名称
 * @param hash 哈希值，与rank一起用于生成唯一的套接字名称
 * @param abortFlag 指向中止标志的指针，如果非NULL则设置套接字为非阻塞模式
 * @return scclResult_t 返回操作结果，成功返回scclSuccess，失败返回相应错误码
 */
/**
 * @brief Initialize an IPC socket.
 *
 * Creates a UNIX-domain datagram socket and binds it to a name derived from
 * (rank, hash). Supports both the Linux abstract namespace and regular
 * filesystem sockets (selected by USE_ABSTRACT_SOCKET).
 *
 * @param handle    Output handle; receives the fd, socket name and abort flag.
 * @param rank      Rank of this process, used to build a unique socket name.
 * @param hash      Hash value combined with rank to build a unique socket name.
 * @param abortFlag If non-NULL, the socket is switched to non-blocking mode and
 *                  send/recv loops poll this flag to abort.
 * @return scclSuccess on success;
 *         scclInternalError if handle is NULL or the generated name does not fit;
 *         scclSystemError if socket() or bind() fails.
 *
 * On every failure path before the fd is stored in the handle, the socket is
 * closed here, so the caller never has to clean up a half-initialized handle.
 */
scclResult_t scclIpcSocketInit(scclIpcSocket* handle, int rank, uint64_t hash, volatile uint32_t* abortFlag) {
    if(handle == NULL) {
        return scclInternalError;
    }

    // Put the handle in a well-defined "not open" state before anything can fail.
    handle->fd            = -1;
    handle->socketName[0] = '\0';
    handle->abortFlag     = abortFlag;

    int fd = socket(AF_UNIX, SOCK_DGRAM, 0);
    if(fd < 0) {
        WARN("UDS: Socket creation error : %d", errno);
        return scclSystemError;
    }

    struct sockaddr_un cliaddr;
    memset(&cliaddr, 0, sizeof(cliaddr));
    cliaddr.sun_family = AF_UNIX;

    // Create unique name for the socket.
    char temp[SCCL_IPC_SOCKNAME_LEN] = "";
    int  len = snprintf(temp, SCCL_IPC_SOCKNAME_LEN, SCCL_IPC_SOCKNAME_STR, rank, hash);
    // Reject encoding errors and names that would not leave room for a terminator in sun_path.
    if(len < 0 || (size_t)len > sizeof(cliaddr.sun_path) - 1) {
        WARN("UDS: Cannot bind provided name to socket. Name too large");
        close(fd); // do not leak the freshly created socket
        return scclInternalError;
    }
#ifndef USE_ABSTRACT_SOCKET
    // Remove a stale socket file left behind by a previous run.
    unlink(temp);
#endif

    INFO(SCCL_LOG_BOOTSTRAP, "UDS: Creating socket %s", temp);

    // sun_path was zeroed above and len < sizeof(sun_path), so the copy stays NUL-terminated.
    strncpy(cliaddr.sun_path, temp, len);
#ifdef USE_ABSTRACT_SOCKET
    cliaddr.sun_path[0] = '\0'; // Linux abstract socket trick
#endif
    // The full sockaddr_un length is used; scclIpcSocketSendFd addresses peers the same way.
    if(bind(fd, (struct sockaddr*)&cliaddr, sizeof(cliaddr)) < 0) {
        WARN("UDS: Binding to socket %s failed : %d", temp, errno);
        close(fd);
        return scclSystemError;
    }

    // From here on the fd is owned by the handle (scclIpcSocketClose releases it),
    // including when one of the fcntl checks below bails out.
    handle->fd = fd;
    strcpy(handle->socketName, temp);

    // Mark socket as non-blocking so callers can poll abortFlag instead of blocking forever.
    if(handle->abortFlag) {
        int flags;
        EQCHECK(flags = fcntl(fd, F_GETFL), -1);
        SYSCHECK(fcntl(fd, F_SETFL, flags | O_NONBLOCK), "fcntl");
    }

    return scclSuccess;
}

/**
 * 关闭IPC套接字并释放相关资源
 *
 * @param handle 指向scclIpcSocket结构体的指针，包含要关闭的套接字信息
 * @return scclResult_t 返回操作结果：
 *         - scclSuccess: 操作成功完成
 *         - scclInternalError: 传入无效句柄(handle为NULL)
 *
 * @note 如果定义了USE_ABSTRACT_SOCKET宏，则不会删除socket文件
 *       如果套接字文件描述符无效(fd<=0)，函数会直接返回成功
 */
/**
 * Close an IPC socket and release its resources.
 *
 * @param handle Handle previously filled by scclIpcSocketInit.
 * @return scclSuccess on success (including when there is nothing to close);
 *         scclInternalError if handle is NULL.
 *
 * @note When USE_ABSTRACT_SOCKET is defined there is no socket file to unlink.
 *       A handle whose fd is <= 0 is treated as already closed. After a
 *       successful close the handle is reset (fd = -1, empty name), so calling
 *       this function again is a harmless no-op.
 */
scclResult_t scclIpcSocketClose(scclIpcSocket* handle) {
    if(handle == NULL) {
        return scclInternalError;
    }
    if(handle->fd <= 0) {
        return scclSuccess;
    }
#ifndef USE_ABSTRACT_SOCKET
    if(handle->socketName[0] != '\0') {
        unlink(handle->socketName);
    }
#endif
    close(handle->fd);

    // Invalidate the handle: otherwise a second Close could close an unrelated
    // descriptor that the process has since been given under the same number.
    handle->fd            = -1;
    handle->socketName[0] = '\0';

    return scclSuccess;
}

/**
 * 通过IPC socket接收文件描述符
 *
 * @param handle 指向scclIpcSocket结构体的指针，包含socket相关信息
 * @param recvFd 用于存储接收到的文件描述符的指针
 * @return scclResult_t 返回操作结果：
 *         - scclSuccess: 成功接收文件描述符
 *         - scclSystemError: 系统调用出错
 *         - scclInternalError: 操作被中断
 *
 * @note 该函数会阻塞等待直到接收到数据或发生错误
 * @warning 调用者需要确保recvFd指向有效的内存空间
 */
/**
 * Receive a file descriptor over the IPC socket (SCM_RIGHTS control message).
 *
 * @param handle Socket created by scclIpcSocketInit.
 * @param recvFd Output: receives the descriptor; written only on success.
 * @return scclSuccess on success;
 *         scclSystemError if recvmsg() fails with a non-transient error, or the
 *                         message does not carry exactly one SCM_RIGHTS fd;
 *         scclInternalError if handle/recvFd is NULL or *handle->abortFlag is set.
 *
 * @note Retries on EAGAIN/EWOULDBLOCK/EINTR and on empty datagrams. With a
 *       blocking socket (abortFlag == NULL) each attempt blocks in recvmsg();
 *       with a non-blocking socket this is a polling loop that only exits on
 *       data, a hard error, or abort.
 */
scclResult_t scclIpcSocketRecvFd(scclIpcSocket* handle, int* recvFd) {
    if(handle == NULL || recvFd == NULL) {
        return scclInternalError;
    }

    // Union to guarantee alignment requirements for the control array.
    union {
        struct cmsghdr cm;
        char control[CMSG_SPACE(sizeof(int))];
    } control_un;

    // scclIpcSocketSendFd sends a one-byte payload alongside the descriptor.
    char dummy_buffer[1];
    struct iovec iov[1];
    iov[0].iov_base = (void*)dummy_buffer;
    iov[0].iov_len  = sizeof(dummy_buffer);

    struct msghdr msg;
    for(;;) {
        // Re-arm the header on every attempt: a successful recvmsg() (including
        // one that returns an empty datagram) rewrites msg_controllen, and reusing
        // the shrunken value would make the kernel drop the next message's fd.
        memset(&msg, 0, sizeof(msg));
        msg.msg_iov        = iov;
        msg.msg_iovlen     = 1;
        msg.msg_control    = control_un.control;
        msg.msg_controllen = sizeof(control_un.control);

        ssize_t ret = recvmsg(handle->fd, &msg, 0);
        if(ret > 0) {
            break;
        }
        // errno is only meaningful when recvmsg() failed; for ret == 0 it is stale.
        if(ret < 0 && errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR) {
            WARN("UDS: Receiving data over socket failed : %d", errno);
            return scclSystemError;
        }
        if(handle->abortFlag && *handle->abortFlag) {
            return scclInternalError;
        }
    }

    struct cmsghdr* cmptr = CMSG_FIRSTHDR(&msg);
    if(cmptr == NULL || cmptr->cmsg_len != CMSG_LEN(sizeof(int))) {
        WARN("UDS: Receiving data over socket %s failed", handle->socketName);
        return scclSystemError;
    }
    if((cmptr->cmsg_level != SOL_SOCKET) || (cmptr->cmsg_type != SCM_RIGHTS)) {
        WARN("UDS: Receiving data over socket failed");
        return scclSystemError;
    }

    // CMSG_DATA is not guaranteed to be int-aligned, so copy bytewise.
    memcpy(recvFd, CMSG_DATA(cmptr), sizeof(*recvFd));

    INFO(SCCL_LOG_BOOTSTRAP, "UDS: Got recvFd %d from socket %s", *recvFd, handle->socketName);

    return scclSuccess;
}

/**
 * 通过UNIX域套接字发送文件描述符
 *
 * @param handle      IPC套接字句柄
 * @param sendFd      要发送的文件描述符
 * @param rank        目标rank号
 * @param hash        用于生成套接字名的哈希值
 *
 * @return 成功返回scclSuccess，失败返回错误码:
 *         - scclInternalError: 内部错误(如名称过长或操作被中止)
 *         - scclSystemError:   系统调用错误
 *
 * @note 使用SCM_RIGHTS机制通过控制消息发送文件描述符
 *       在Linux下支持抽象套接字命名空间(当USE_ABSTRACT_SOCKET定义时)
 */
/**
 * Send a file descriptor over a UNIX-domain socket (SCM_RIGHTS control message).
 *
 * @param handle Local IPC socket (from scclIpcSocketInit) used to send.
 * @param sendFd Descriptor to transfer.
 * @param rank   Rank of the destination process.
 * @param hash   Hash used, together with rank, to build the destination socket name.
 *
 * @return scclSuccess on success; otherwise:
 *         - scclInternalError: NULL handle, destination name too long, or aborted
 *         - scclSystemError:   sendmsg() failed with a non-transient error
 *
 * @note The destination name is built exactly like scclIpcSocketInit builds it
 *       (same template, same abstract-namespace trick, full sockaddr_un length),
 *       so it matches the receiver's bound address. Transient errors
 *       (EAGAIN/EWOULDBLOCK/EINTR) are retried until success or abort.
 */
scclResult_t scclIpcSocketSendFd(scclIpcSocket* handle, const int sendFd, int rank, uint64_t hash) {
    if(handle == NULL) {
        return scclInternalError;
    }

    char temp[SCCL_IPC_SOCKNAME_LEN];
    struct sockaddr_un cliaddr;

    // Construct client address to send this shareable handle to
    memset(&cliaddr, 0, sizeof(cliaddr));
    cliaddr.sun_family = AF_UNIX;

    int len = snprintf(temp, SCCL_IPC_SOCKNAME_LEN, SCCL_IPC_SOCKNAME_STR, rank, hash);
    if(len < 0 || (size_t)len > sizeof(cliaddr.sun_path) - 1) {
        WARN("UDS: Cannot connect to provided name for socket. Name too large");
        return scclInternalError;
    }
    // sun_path was zeroed above and len < sizeof(sun_path), so the copy stays NUL-terminated.
    (void)strncpy(cliaddr.sun_path, temp, len);

    INFO(SCCL_LOG_BOOTSTRAP, "UDS: Sending fd %d to UDS socket %s", sendFd, temp);

#ifdef USE_ABSTRACT_SOCKET
    cliaddr.sun_path[0] = '\0'; // Linux abstract socket trick
#endif

    // Union to guarantee alignment requirements for the control array.
    union {
        struct cmsghdr cm;
        char control[CMSG_SPACE(sizeof(int))];
    } control_un;
    // Zero the control buffer and the header so no uninitialized padding or
    // fields (e.g. implementation-specific members of msghdr) reach the kernel.
    memset(&control_un, 0, sizeof(control_un));

    struct msghdr msg;
    memset(&msg, 0, sizeof(msg));
    msg.msg_control    = control_un.control;
    msg.msg_controllen = sizeof(control_un.control);

    struct cmsghdr* cmptr = CMSG_FIRSTHDR(&msg);
    cmptr->cmsg_len       = CMSG_LEN(sizeof(int));
    cmptr->cmsg_level     = SOL_SOCKET;
    cmptr->cmsg_type      = SCM_RIGHTS;
    // CMSG_DATA is not guaranteed to be int-aligned, so copy bytewise.
    memcpy(CMSG_DATA(cmptr), &sendFd, sizeof(sendFd));

    msg.msg_name    = (void*)&cliaddr;
    msg.msg_namelen = sizeof(struct sockaddr_un);

    // A one-byte payload accompanies the control message.
    struct iovec iov[1];
    iov[0].iov_base = (void*)"";
    iov[0].iov_len  = 1;
    msg.msg_iov     = iov;
    msg.msg_iovlen  = 1;
    msg.msg_flags   = 0;

    ssize_t sendResult;
    while((sendResult = sendmsg(handle->fd, &msg, 0)) <= 0) {
        if(errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR) {
            WARN("UDS: Sending data over socket %s failed : %d", temp, errno);
            return scclSystemError;
        }
        if(handle->abortFlag && *handle->abortFlag)
            return scclInternalError;
    }

    return scclSuccess;
}

} // namespace bootstrap
} // namespace topology
} // namespace hardware
} // namespace sccl
