#pragma once

#include "debug.h"
#include "check.h"
#include <sys/socket.h>
#include <arpa/inet.h>
#include <netinet/tcp.h>
#include <netdb.h>
#include <fcntl.h>
#include <poll.h>

using namespace sccl;

// One entry of a user-supplied interface filter list, as produced by
// parseStringList and consumed by matchIfList.
struct netIf {
    char prefix[64]; // NUL-terminated interface-name prefix to match against
    int port;        // port to match, or -1 to accept any port (see matchPort)
};
// Per-thread switch read by SCCLCHECKGOTO: when non-zero, the failure
// back-trace INFO line is suppressed (the goto still happens).
// NOTE(review): `static` in a header gives every translation unit its own
// copy, so setting it in one source file does not affect SCCLCHECKGOTO
// expansions in another file — confirm this is intended; an `extern`
// declaration here plus one definition in a .cc file would share it.
static thread_local int scclDebugNoWarn = 0;

// SYSCHECK(call, name): evaluate the syscall-style expression `call` (which
// reports failure as -1), retrying while it fails with EINTR/EWOULDBLOCK/EAGAIN
// (via SYSCHECKVAL -> SYSCHECKSYNC). On any other -1 it WARNs with
// strerror(errno) and returns scclSystemError from the enclosing function.
// The call's return value is discarded; use SYSCHECKVAL to keep it.
// `name` must be a string literal: it is pasted into the log format string.
#define SYSCHECK(call, name)             \
    do {                                 \
        int retval;                      \
        SYSCHECKVAL(call, name, retval); \
    } while(false)

// SYSCHECKVAL(call, name, retval): run `call` through SYSCHECKSYNC (which
// retries transient EINTR/EWOULDBLOCK/EAGAIN failures) and leave its result in
// the lvalue `retval`. A final result of -1 is fatal: it is logged with WARN and
// the enclosing function returns scclSystemError. `name` must be a string literal.
#define SYSCHECKVAL(call, name, retval)                        \
    do {                                                       \
        SYSCHECKSYNC(call, name, retval);                      \
        if(retval != -1)                                       \
            break;                                             \
        WARN("Call to " name " failed : %s", strerror(errno)); \
        return scclSystemError;                                \
    } while(false)
// SYSCHECKSYNC(call, name, retval): assign `call` to `retval`, re-evaluating it
// for as long as it returns -1 with errno set to EINTR, EWOULDBLOCK or EAGAIN
// (each retry is logged at INFO level). Any other outcome, success or a
// different error, leaves the loop with that value in `retval`.
// Note: transient errors are retried immediately, with no back-off.
#define SYSCHECKSYNC(call, name, retval)                                                   \
    do {                                                                                   \
        retval = call;                                                                     \
        if(retval != -1 || (errno != EINTR && errno != EWOULDBLOCK && errno != EAGAIN))    \
            break;                                                                         \
        INFO(SCCL_LOG_CODEALL, "Call to " name " returned %s, retrying", strerror(errno)); \
    } while(true)
// EQCHECK(statement, value): if `statement` compares equal to `value`, log the
// location and strerror(errno) via INFO and return scclSystemError from the
// enclosing function. `statement` is evaluated exactly once.
// Both arguments are parenthesized so expressions such as `a ? b : c` compare
// as a whole. The expansion ends without ';' so the macro is a single statement
// (safe inside an unbraced if/else); callers supply the ';'.
#define EQCHECK(statement, value)                                                                             \
    do {                                                                                                      \
        if((statement) == (value)) {                                                                          \
            /* Print the back trace*/                                                                         \
            INFO(SCCL_LOG_CODEALL, "%s:%d -> %d (%s)", __FILE__, __LINE__, scclSystemError, strerror(errno)); \
            return scclSystemError;                                                                           \
        }                                                                                                     \
    } while(0)
// NEQCHECKGOTO(statement, value, RES, label): if `statement` differs from
// `value`, set RES to scclSystemError, log the location and strerror(errno) via
// INFO, and `goto label`. `statement` is evaluated exactly once.
// Arguments are parenthesized so compound expressions compare as a whole, and
// the expansion ends without ';' so it is a single statement; callers supply the ';'.
#define NEQCHECKGOTO(statement, value, RES, label)                                                  \
    do {                                                                                            \
        if((statement) != (value)) {                                                                \
            /* Print the back trace*/                                                               \
            (RES) = scclSystemError;                                                                \
            INFO(SCCL_LOG_CODEALL, "%s:%d -> %d (%s)", __FILE__, __LINE__, (RES), strerror(errno)); \
            goto label;                                                                             \
        }                                                                                           \
    } while(0)
// SYSCHECKGOTO(statement, RES, label): if the syscall-style `statement`
// evaluates to -1, set RES to scclSystemError, log the location and
// strerror(errno) via INFO, and `goto label`. Unlike SYSCHECK, it does not retry
// on EINTR/EAGAIN. The expansion ends without ';' so it is a single statement;
// callers supply the ';'.
#define SYSCHECKGOTO(statement, RES, label)                                                         \
    do {                                                                                            \
        if((statement) == -1) {                                                                     \
            /* Print the back trace*/                                                               \
            (RES) = scclSystemError;                                                                \
            INFO(SCCL_LOG_CODEALL, "%s:%d -> %d (%s)", __FILE__, __LINE__, (RES), strerror(errno)); \
            goto label;                                                                             \
        }                                                                                           \
    } while(0)
// SCCLCHECKGOTO(call, RES, label): store the result of `call` in RES. If it is
// neither scclSuccess nor scclInProgress, log the location via INFO (unless the
// per-thread scclDebugNoWarn is non-zero) and `goto label`. On the success path
// it logs a "check pass" INFO line on every invocation.
// The expansion ends without ';' so it is a single statement (safe inside an
// unbraced if/else); callers supply the ';'.
#define SCCLCHECKGOTO(call, RES, label)                                                \
    do {                                                                               \
        (RES) = (call);                                                                \
        if((RES) != scclSuccess && (RES) != scclInProgress) {                          \
            /* Print the back trace*/                                                  \
            if(scclDebugNoWarn == 0)                                                   \
                INFO(SCCL_LOG_CODEALL, "%s:%d -> %d", __FILE__, __LINE__, (RES));      \
            goto label;                                                                \
        }                                                                              \
        INFO(SCCL_LOG_CODEALL, "check pass %s:%d -> %d", __FILE__, __LINE__, (RES));   \
    } while(0)
// EQCHECKGOTO(statement, value, RES, label): if `statement` compares equal to
// `value`, set RES to scclSystemError, log the location and strerror(errno)
// via INFO, and `goto label`. `statement` is evaluated exactly once.
// Arguments are parenthesized so compound expressions compare as a whole, and
// the expansion ends without ';' so it is a single statement; callers supply the ';'.
#define EQCHECKGOTO(statement, value, RES, label)                                                   \
    do {                                                                                            \
        if((statement) == (value)) {                                                                \
            /* Print the back trace*/                                                               \
            (RES) = scclSystemError;                                                                \
            INFO(SCCL_LOG_CODEALL, "%s:%d -> %d (%s)", __FILE__, __LINE__, (RES), strerror(errno)); \
            goto label;                                                                             \
        }                                                                                           \
    } while(0)

// Parse a comma-separated list of "prefix[:port]" items (e.g. "eth0:1234,ib")
// into ifList.
//
// string  - list to parse; NULL yields 0 entries.
// ifList  - output array with room for at least maxList entries.
// maxList - capacity of ifList; parsing stops once it is full. Values <= 0
//           yield 0 entries and leave ifList untouched.
// returns - number of entries written. Items without ":port" get port -1;
//           for items with one, port is atoi() of the text after ':'.
//           Empty items (e.g. ",," or a leading ":1234") are skipped.
//
// Prefixes longer than the prefix buffer are truncated (and still
// NUL-terminated) instead of overflowing into the next entry.
static int parseStringList(const char* string, struct netIf* ifList, int maxList) {
    if(!string || !ifList || maxList <= 0)
        return 0;

    // Leave room for the terminating NUL in each prefix.
    const int maxPrefixLen = (int)sizeof(ifList[0].prefix) - 1;
    const char* ptr        = string;

    int ifNum = 0;
    int ifC   = 0;
    char c;
    do {
        c = *ptr;
        if(c == ':') {
            if(ifC > 0) {
                ifList[ifNum].prefix[ifC] = '\0';
                ifList[ifNum].port        = atoi(ptr + 1);
                ifNum++;
                ifC = 0;
            }
            // Skip the port text up to the next separator or end of string.
            while(c != ',' && c != '\0')
                c = *(++ptr);
        } else if(c == ',' || c == '\0') {
            if(ifC > 0) {
                ifList[ifNum].prefix[ifC] = '\0';
                ifList[ifNum].port        = -1;
                ifNum++;
                ifC = 0;
            }
        } else if(ifC < maxPrefixLen) {
            ifList[ifNum].prefix[ifC] = c;
            ifC++;
        }
        // Otherwise the prefix is full: drop the extra characters.
        ptr++;
    } while(ifNum < maxList && c);
    return ifNum;
}
// Test whether interface name `string` matches the user pattern `ref`.
// Exact mode also compares the terminating NUL, so "eth0" does not match
// "eth01". Prefix mode only requires `string` to start with `ref`.
static bool matchIf(const char* string, const char* ref, bool matchExact) {
    int cmpLen;
    if(matchExact) {
        cmpLen = strlen(string) + 1;
    } else {
        cmpLen = strlen(ref);
    }
    return !strncmp(string, ref, cmpLen);
}

// Ports match when they are equal or when either side is the wildcard -1.
static bool matchPort(const int port1, const int port2) {
    const bool eitherWildcard = (port1 == -1) || (port2 == -1);
    return eitherWildcard || port1 == port2;
}

static bool matchIfList(const char* string, int port, struct netIf* ifList, int listSize, bool matchExact) {
    // Make an exception for the case where no user list is defined
    if(listSize == 0)
        return true;

    for(int i = 0; i < listSize; i++) {
        if(matchIf(string, ifList[i].prefix, matchExact) && matchPort(port, ifList[i].port)) {
            return true;
        }
    }
    return false;
}

// Socket / interface-discovery tuning constants.
// NOTE(review): RETRY_REFUSED_TIMES is the double literal 2e4, so comparisons
// against it are done in floating point (exact for these small counts).
#define MAX_IFS 16              // NOTE(review): presumably the maxIfs bound used with scclFindInterfaces — confirm at call sites
#define MAX_IF_NAME_SIZE 16     // NOTE(review): presumably the per-name ifNameMaxSize; 16 matches Linux IFNAMSIZ — confirm
#define SLEEP_INT 1000          // connection retry sleep interval in usec
#define RETRY_REFUSED_TIMES 2e4 // connection refused retry times before reporting a timeout (2e4 * SLEEP_INT = 20 sec)
#define RETRY_TIMEDOUT_TIMES 3  // connection timed out retry times (each one can take 20s)
#define SOCKET_NAME_MAXLEN (NI_MAXHOST + NI_MAXSERV) // room for a host plus service string; presumably the buf size for scclSocketToString
#define SCCL_SOCKET_MAGIC 0x564ab9f2fc4b9d6cULL      // default `magic` argument of scclSocketInit

// Storage large enough for either an IPv4 or an IPv6 socket address, with a
// generic `struct sockaddr` view for passing to the sockets API.
// NOTE(review): presumably sa.sa_family selects which member is valid — confirm callers set it.
union scclSocketAddress {
    struct sockaddr sa;       // generic view
    struct sockaddr_in sin;   // IPv4 view
    struct sockaddr_in6 sin6; // IPv6 view
};

// Lifecycle state of a scclSocket (stored in scclSocket::state).
// NOTE(review): the transition rules live in the implementation, not this
// header; the per-value notes below are inferred from names — confirm.
enum scclSocketState {
    scclSocketStateNone           = 0,
    scclSocketStateInitialized    = 1,
    scclSocketStateAccepting      = 2,
    scclSocketStateAccepted       = 3,
    scclSocketStateConnecting     = 4,
    scclSocketStateConnectPolling = 5, // presumably waiting on a non-blocking connect to complete
    scclSocketStateConnected      = 6,
    scclSocketStateReady          = 7,
    scclSocketStateClosed         = 8,
    scclSocketStateError          = 9,
    scclSocketStateNum            = 10 // count of states, not a real state
};

// Tag recording what a scclSocket is used for; stored in scclSocket::type and
// passed to scclSocketInit (default scclSocketTypeUnknown).
enum scclSocketType {
    scclSocketTypeUnknown   = 0,
    scclSocketTypeBootstrap = 1,
    scclSocketTypeProxy     = 2,
    scclSocketTypeNetSocket = 3,
    scclSocketTypeNetIb     = 4
};
// A socket endpoint together with its connection bookkeeping. Fields are
// filled in by scclSocketInit and updated by the listen/connect/accept calls
// declared below.
struct scclSocket {
    int fd;                         // socket fd; set after a successful scclSocketListen / scclSocketConnect / scclSocketAccept
    int acceptFd;                   // NOTE(review): presumably the fd returned by accept() before the connection is fully set up — confirm
    int timedOutRetries;            // presumably counted against RETRY_TIMEDOUT_TIMES — confirm
    int refusedRetries;             // presumably counted against RETRY_REFUSED_TIMES — confirm
    union scclSocketAddress addr;   // address to listen on / connect to; remote peer address after accept
    volatile uint32_t* abortFlag;   // from scclSocketInit (may be NULL); presumably polled to abort blocking operations
    int asyncFlag;                  // from scclSocketInit; presumably selects non-blocking mode — confirm
    enum scclSocketState state;     // current lifecycle state
    int salen;                      // presumably the sockaddr length of addr — confirm
    uint64_t magic;                 // from scclSocketInit (default SCCL_SOCKET_MAGIC); presumably checked against the peer's
    enum scclSocketType type;       // what this socket is used for
};

// Format addr as text into buf and return a C string describing it.
// NOTE(review): buf is presumably expected to hold SOCKET_NAME_MAXLEN bytes, and
// numericHostForm != 0 presumably requests a numeric host rather than a
// resolved name — confirm against the implementation.
const char* scclSocketToString(union scclSocketAddress* addr, char* buf, const int numericHostForm = 1);
// Parse an "ip:port" style string (per the parameter name) into *ua.
scclResult_t scclSocketGetAddrFromString(union scclSocketAddress* ua, const char* ip_port_pair);
// Find local interfaces reachable on the same subnet as remoteAddr, writing
// names into ifNames (ifNameMaxSize bytes each, presumably) and addresses into
// localAddrs. NOTE(review): presumably returns the number found, at most maxIfs.
int scclFindInterfaceMatchSubnet(char* ifNames, union scclSocketAddress* localAddrs, union scclSocketAddress* remoteAddr, int ifNameMaxSize, int maxIfs);
// Enumerate usable network interfaces into ifNames / ifAddrs.
// NOTE(review): presumably returns the number found, at most maxIfs.
int scclFindInterfaces(char* ifNames, union scclSocketAddress* ifAddrs, int ifNameMaxSize, int maxIfs);

// Initialize a socket descriptor: records addr (copied if non-NULL, presumably),
// magic, type, abortFlag and asyncFlag in *sock. Defaults describe an
// untyped, blocking socket with no abort flag and the standard magic.
scclResult_t scclSocketInit(struct scclSocket* sock,
                            union scclSocketAddress* addr = NULL,
                            uint64_t magic                = SCCL_SOCKET_MAGIC,
                            enum scclSocketType type      = scclSocketTypeUnknown,
                            volatile uint32_t* abortFlag  = NULL,
                            int asyncFlag                 = 0);
// Create a listening socket. sock->addr can be pre-filled with IP & port info. sock->fd is set after a successful call.
scclResult_t scclSocketListen(struct scclSocket* sock);
// Copy the socket's address into *addr.
scclResult_t scclSocketGetAddr(struct scclSocket* sock, union scclSocketAddress* addr);
// Connect to sock->addr. sock->fd is set after a successful call.
// NOTE(review): portReuse != 0 presumably enables address/port reuse on the local side — confirm.
scclResult_t scclSocketConnect(struct scclSocket* sock, int portReuse = 0);
// Report the socket's connection state through *running
// (presumably non-zero once the socket is ready for I/O — confirm).
scclResult_t scclSocketReady(struct scclSocket* sock, int* running);
// Accept an incoming connection on ulistenSock->fd and keep the file descriptor in sock->fd, with the remote side IP/port in sock->addr.
scclResult_t scclSocketAccept(struct scclSocket* sock, struct scclSocket* ulistenSock);
// Read / overwrite the underlying file descriptor of sock.
scclResult_t scclSocketGetFd(struct scclSocket* sock, int* fd);
scclResult_t scclSocketSetFd(int fd, struct scclSocket* sock);

// Values for the `op` argument of scclSocketProgress / scclSocketWait.
#define SCCL_SOCKET_SEND 0
#define SCCL_SOCKET_RECV 1

// Transfer part of the `size`-byte buffer at ptr in direction `op`;
// NOTE(review): *offset presumably tracks bytes completed so far and is advanced in place.
scclResult_t scclSocketProgress(int op, struct scclSocket* sock, void* ptr, int size, int* offset);
// NOTE(review): presumably repeats scclSocketProgress until *offset reaches size.
scclResult_t scclSocketWait(int op, struct scclSocket* sock, void* ptr, int size, int* offset);
// Send / receive exactly `size` bytes (presumably blocking until complete).
scclResult_t scclSocketSend(struct scclSocket* sock, void* ptr, int size);
scclResult_t scclSocketRecv(struct scclSocket* sock, void* ptr, int size);
// Attempt a receive; *closed reports whether the peer closed the connection,
// and `blocking` selects whether to wait (semantics per implementation — confirm).
scclResult_t scclSocketTryRecv(struct scclSocket* sock, void* ptr, int size, int* closed, bool blocking);
// Close the socket.
scclResult_t scclSocketClose(struct scclSocket* sock);
