#include "socket.h"
#include <stdlib.h>

#include <unistd.h>
#include <ifaddrs.h>
#include <net/if.h>

#include <vector>
#include <utility>
#include <unordered_set>
#include <unistd.h>
#include <sys/syscall.h>

using namespace sccl;

// Pool of candidate local (client-side) ports used by scclSocketConnect() when
// its portReuse argument is non-zero. Each entry pairs a port number with the
// set of remote peers (as formatted by scclSocketToString()) that this port
// has already been handed out for. The pool is filled lazily on the first
// portReuse connect.
// NOTE(review): this is a file-scope static with no locking visible in this
// file — confirm that callers serialize access.
static std::vector<std::pair<int, std::unordered_set<std::string>>> clientPortPool;

/* Push a send or receive as far as the socket currently allows.
 *
 * op      : SCCL_SOCKET_RECV or SCCL_SOCKET_SEND.
 * sock    : socket whose fd is used; its abortFlag (if set) is polled.
 * ptr     : start of the user buffer.
 * size    : total number of bytes to transfer.
 * offset  : in/out, number of bytes already transferred; advanced here.
 * block   : non-zero drops MSG_DONTWAIT from the recv/send flags.
 * closed  : out, set to 1 when recv() returns 0, else 0.
 *
 * Returns scclSuccess when no hard error occurred (even if nothing moved),
 * scclRemoteError on a recv/send failure other than EINTR/EWOULDBLOCK/EAGAIN,
 * and scclInternalError when the abort flag is raised.
 */
static scclResult_t socketProgressOpt(int op, struct scclSocket* sock, void* ptr, int size, int* offset, int block, int* closed) {
    int bytes  = 0;
    *closed    = 0;
    char* data = (char*)ptr;
    char line[SOCKET_NAME_MAXLEN + 1];
    do {
        if(op == SCCL_SOCKET_RECV)
            bytes = recv(sock->fd, data + (*offset), size - (*offset), block ? 0 : MSG_DONTWAIT);
        if(op == SCCL_SOCKET_SEND)
            bytes = send(sock->fd, data + (*offset), size - (*offset), block ? MSG_NOSIGNAL : MSG_DONTWAIT | MSG_NOSIGNAL);
        if(op == SCCL_SOCKET_RECV && bytes == 0) {
            *closed = 1;
            return scclSuccess;
        }
        if(bytes == -1) {
            // Capture errno right away: formatting the peer address below may overwrite it.
            const int err = errno;
            if(err != EINTR && err != EWOULDBLOCK && err != EAGAIN) {
                // Report the operation that actually failed instead of always saying "recv".
                WARN("socketProgressOpt: Call to %s %s failed : %s",
                     op == SCCL_SOCKET_RECV ? "recv from" : "send to",
                     scclSocketToString(&sock->addr, line),
                     strerror(err));
                return scclRemoteError;
            }
            // Transient condition: count it as zero bytes transferred.
            bytes = 0;
        }
        (*offset) += bytes;
        if(sock->abortFlag && *sock->abortFlag != 0) {
            INFO(SCCL_LOG_CODEALL, "socketProgressOpt: abort called");
            return scclInternalError;
        }
    } while(bytes > 0 && (*offset) < size);
    return scclSuccess;
}

/* Non-blocking progress step that treats a peer-side close as an error.
 * Forwards to socketProgressOpt() with block = 0 and returns scclRemoteError
 * (after logging the peer address) when the connection was closed.
 */
static scclResult_t socketProgress(int op, struct scclSocket* sock, void* ptr, int size, int* offset) {
    int peerClosed;
    SCCLCHECK(socketProgressOpt(op, sock, ptr, size, offset, 0 /*block*/, &peerClosed));
    if(!peerClosed)
        return scclSuccess;

    char peerName[SOCKET_NAME_MAXLEN + 1];
    WARN("socketProgress: Connection closed by remote peer %s", scclSocketToString(&sock->addr, peerName, 0));
    return scclRemoteError;
}

/* Keep calling socketProgress() until the whole buffer (size bytes) has been
 * transferred, i.e. until *offset reaches size. Any error is propagated.
 */
static scclResult_t socketWait(int op, struct scclSocket* sock, void* ptr, int size, int* offset) {
    for(; *offset < size;) {
        SCCLCHECK(socketProgress(op, sock, ptr, size, offset));
    }
    return scclSuccess;
}

/* Format a string representation of a (union scclSocketAddress *) socket address using getnameinfo()
 *
 * Output: "IPv4/IPv6 address<port>"
 *
 * addr            : address to format.
 * buf             : destination buffer; callers in this file pass SOCKET_NAME_MAXLEN + 1 bytes.
 * numericHostForm : non-zero requests the numeric host form (NI_NUMERICHOST).
 *
 * Returns buf, or NULL when addr or buf is NULL. buf is set to the empty
 * string when the family is neither AF_INET nor AF_INET6, or when
 * getnameinfo() fails.
 */
const char* scclSocketToString(union scclSocketAddress* addr, char* buf, const int numericHostForm /*= 1*/) {
    if(buf == NULL || addr == NULL)
        return NULL;
    struct sockaddr* saddr = &addr->sa;
    if(saddr->sa_family != AF_INET && saddr->sa_family != AF_INET6) {
        buf[0] = '\0';
        return buf;
    }
    char host[NI_MAXHOST], service[NI_MAXSERV];
    /* NI_NUMERICHOST: If set, then the numeric form of the hostname is returned.
     * (When not set, this will still happen in case the node's name cannot be determined.)
     */
    int flag = NI_NUMERICSERV | (numericHostForm ? NI_NUMERICHOST : 0);
    if(getnameinfo(saddr, sizeof(union scclSocketAddress), host, NI_MAXHOST, service, NI_MAXSERV, flag) != 0) {
        // host/service were not filled in; never format uninitialized stack memory.
        buf[0] = '\0';
        return buf;
    }
    sprintf(buf, "%s<%s>", host, service);
    return buf;
}

/* Return the port stored in addr, converted to host byte order.
 * AF_INET addresses use the IPv4 port field; every other family is read
 * through the IPv6 port field.
 */
static uint16_t socketToPort(union scclSocketAddress* addr) {
    uint16_t netOrderPort;
    if(addr->sa.sa_family == AF_INET)
        netOrderPort = addr->sin.sin_port;
    else
        netOrderPort = addr->sin6.sin6_port;
    return ntohs(netOrderPort);
}

/* Read SCCL_SOCKET_FAMILY from the environment so the user can force the
 * IPv4/IPv6 interface selection.
 *
 * Returns AF_INET for "AF_INET", AF_INET6 for "AF_INET6", and -1 (no family
 * forced, first one found is used) when the variable is unset or holds any
 * other value.
 */
static int envSocketFamily(void) {
    const char* requested = getenv("SCCL_SOCKET_FAMILY");
    if(requested == NULL)
        return -1;

    INFO(SCCL_LOG_CODEALL, "SCCL_SOCKET_FAMILY set by environment to %s", requested);

    if(strcmp(requested, "AF_INET") == 0)
        return AF_INET; // IPv4
    if(strcmp(requested, "AF_INET6") == 0)
        return AF_INET6; // IPv6
    return -1;
}

static int findInterfaces(const char* prefixList, char* names, union scclSocketAddress* addrs, int sock_family, int maxIfNameSize, int maxIfs) {
    struct netIf userIfs[MAX_IFS];
    bool searchNot = prefixList && prefixList[0] == '^';
    if(searchNot)
        prefixList++;
    bool searchExact = prefixList && prefixList[0] == '=';
    if(searchExact)
        prefixList++;
    int nUserIfs = parseStringList(prefixList, userIfs, MAX_IFS);

    int found = 0;
    struct ifaddrs *interfaces, *interface;
    getifaddrs(&interfaces);
    for(interface = interfaces; interface && found < maxIfs; interface = interface->ifa_next) {
        if(interface->ifa_addr == NULL)
            continue;

        /* We only support IPv4 & IPv6 */
        int family = interface->ifa_addr->sa_family;
        if(family != AF_INET && family != AF_INET6)
            continue;

        /* Allow the caller to force the socket family type */
        if(sock_family != -1 && family != sock_family)
            continue;

        /* We also need to skip IPv6 loopback interfaces */
        if(family == AF_INET6) {
            struct sockaddr_in6* sa = (struct sockaddr_in6*)(interface->ifa_addr);
            if(IN6_IS_ADDR_LOOPBACK(&sa->sin6_addr))
                continue;
        }

        // check against user specified interfaces
        if(!(matchIfList(interface->ifa_name, -1, userIfs, nUserIfs, searchExact) ^ searchNot)) {
            continue;
        }

        // Check that this interface has not already been saved
        // getifaddrs() normal order appears to be; IPv4, IPv6 Global, IPv6 Link
        bool duplicate = false;
        for(int i = 0; i < found; i++) {
            if(strcmp(interface->ifa_name, names + i * maxIfNameSize) == 0) {
                duplicate = true;
                break;
            }
        }

        if(!duplicate) {
            // Store the interface name
            strncpy(names + found * maxIfNameSize, interface->ifa_name, maxIfNameSize);
            // Store the IP address
            int salen = (family == AF_INET) ? sizeof(struct sockaddr_in) : sizeof(struct sockaddr_in6);
            memcpy(addrs + found, interface->ifa_addr, salen);
            found++;
        }
    }

    freeifaddrs(interfaces);
    return found;
}

static bool matchSubnet(struct ifaddrs local_if, union scclSocketAddress* remote) {
    /* Check family first */
    int family = local_if.ifa_addr->sa_family;
    if(family != remote->sa.sa_family) {
        return false;
    }

    if(family == AF_INET) {
        struct sockaddr_in* local_addr  = (struct sockaddr_in*)(local_if.ifa_addr);
        struct sockaddr_in* mask        = (struct sockaddr_in*)(local_if.ifa_netmask);
        struct sockaddr_in& remote_addr = remote->sin;
        struct in_addr local_subnet, remote_subnet;
        local_subnet.s_addr  = local_addr->sin_addr.s_addr & mask->sin_addr.s_addr;
        remote_subnet.s_addr = remote_addr.sin_addr.s_addr & mask->sin_addr.s_addr;
        return (local_subnet.s_addr ^ remote_subnet.s_addr) ? false : true;
    } else if(family == AF_INET6) {
        struct sockaddr_in6* local_addr  = (struct sockaddr_in6*)(local_if.ifa_addr);
        struct sockaddr_in6* mask        = (struct sockaddr_in6*)(local_if.ifa_netmask);
        struct sockaddr_in6& remote_addr = remote->sin6;
        struct in6_addr& local_in6       = local_addr->sin6_addr;
        struct in6_addr& mask_in6        = mask->sin6_addr;
        struct in6_addr& remote_in6      = remote_addr.sin6_addr;
        bool same                        = true;
        int len                          = 16; // IPv6 address is 16 unsigned char
        for(int c = 0; c < len; c++) {         // Network byte order is big-endian
            char c1 = local_in6.s6_addr[c] & mask_in6.s6_addr[c];
            char c2 = remote_in6.s6_addr[c] & mask_in6.s6_addr[c];
            if(c1 ^ c2) {
                same = false;
                break;
            }
        }
        // At last, we need to compare scope id
        // Two Link-type addresses can have the same subnet address even though they are not in the same scope
        // For Global type, this field is 0, so a comparison wouldn't matter
        same &= (local_addr->sin6_scope_id == remote_addr.sin6_scope_id);
        return same;
    } else {
        WARN("Net : Unsupported address family type");
        return false;
    }
}

/* Find a local interface that is in the same subnet as remoteAddr.
 *
 * ifNames       : output array of interface names, ifNameMaxSize bytes per slot.
 * localAddrs    : output array of local addresses.
 * remoteAddr    : remote address to match against.
 * ifNameMaxSize : size in bytes of one slot in ifNames.
 * maxIfs        : capacity of the output arrays.
 *
 * Returns the number of interfaces stored. The scan stops at the first match,
 * so the result is 0 or 1; 0 is also returned when getifaddrs() fails.
 */
int scclFindInterfaceMatchSubnet(char* ifNames, union scclSocketAddress* localAddrs, union scclSocketAddress* remoteAddr, int ifNameMaxSize, int maxIfs) {
    char line_a[SOCKET_NAME_MAXLEN + 1];
    int found = 0;
    struct ifaddrs *interfaces = NULL, *interface;
    // On failure getifaddrs() does not give us a valid list: bail out instead of walking garbage.
    if(getifaddrs(&interfaces) != 0) {
        WARN("Net : getifaddrs failed : %s", strerror(errno));
        return 0;
    }
    for(interface = interfaces; interface && !found; interface = interface->ifa_next) {
        if(interface->ifa_addr == NULL)
            continue;

        /* We only support IPv4 & IPv6 */
        int family = interface->ifa_addr->sa_family;
        if(family != AF_INET && family != AF_INET6)
            continue;

        // keep only interfaces that share a subnet with the remote address
        if(!matchSubnet(*interface, remoteAddr)) {
            continue;
        }

        // Store the local IP address
        int salen = (family == AF_INET) ? sizeof(struct sockaddr_in) : sizeof(struct sockaddr_in6);
        memcpy(localAddrs + found, interface->ifa_addr, salen);

        // Store the interface name; strncpy() does not terminate on truncation, so do it explicitly
        char* slot = ifNames + found * ifNameMaxSize;
        strncpy(slot, interface->ifa_name, ifNameMaxSize);
        if(ifNameMaxSize > 0)
            slot[ifNameMaxSize - 1] = '\0';

        found++;
        if(found == maxIfs)
            break;
    }

    if(found == 0) {
        WARN("Net : No interface found in the same subnet as remote address %s", scclSocketToString(remoteAddr, line_a));
    }
    freeifaddrs(interfaces);
    return found;
}

/* Parse "<IPv4_or_hostname>:<port>" or "[<IPv6>[%<ifname>]]:<port>" into a socket address.
 *
 * ua           : output address.
 * ip_port_pair : text to parse; a leading '[' selects the IPv6 literal form.
 *
 * Returns scclSuccess on success and scclInvalidArgument when the string is
 * missing, malformed, too long for the internal buffers, cannot be resolved,
 * or resolves to an unsupported family.
 */
scclResult_t scclSocketGetAddrFromString(union scclSocketAddress* ua, const char* ip_port_pair) {
    if(!(ip_port_pair && strlen(ip_port_pair) > 1)) {
        WARN("Net : string is null");
        return scclInvalidArgument;
    }

    bool ipv6 = ip_port_pair[0] == '[';
    /* Construct the sockaddress structure */
    if(!ipv6) {
        struct netIf ni;
        // parse <ip_or_hostname>:<port> string, expect one pair
        if(parseStringList(ip_port_pair, &ni, 1) != 1) {
            WARN("Net : No valid <IPv4_or_hostname>:<port> pair found");
            return scclInvalidArgument;
        }

        struct addrinfo hints, *p;
        int rv;
        memset(&hints, 0, sizeof(hints));
        hints.ai_family   = AF_UNSPEC;
        hints.ai_socktype = SOCK_STREAM;

        if((rv = getaddrinfo(ni.prefix, NULL, &hints, &p)) != 0) {
            WARN("Net : error encountered when getting address info : %s", gai_strerror(rv));
            return scclInvalidArgument;
        }

        scclResult_t res = scclSuccess;
        // use the first
        if(p->ai_family == AF_INET) {
            struct sockaddr_in& sin = ua->sin;
            memcpy(&sin, p->ai_addr, sizeof(struct sockaddr_in));
            sin.sin_family = AF_INET;      // IPv4
            sin.sin_port   = htons(ni.port); // port
        } else if(p->ai_family == AF_INET6) {
            struct sockaddr_in6& sin6 = ua->sin6;
            memcpy(&sin6, p->ai_addr, sizeof(struct sockaddr_in6));
            sin6.sin6_family   = AF_INET6;       // IPv6
            sin6.sin6_port     = htons(ni.port); // port
            sin6.sin6_flowinfo = 0;              // needed by IPv6, but possibly obsolete
            sin6.sin6_scope_id = 0;              // should be global scope, set to 0
        } else {
            WARN("Net : unsupported IP family");
            res = scclInvalidArgument;
        }

        // Release the resolver result on every path, including the unsupported-family one.
        freeaddrinfo(p);
        return res;

    } else {
        int i, j = -1, len = strlen(ip_port_pair);
        for(i = 1; i < len; i++) {
            if(ip_port_pair[i] == '%')
                j = i;
            if(ip_port_pair[i] == ']')
                break;
        }
        // i is the index of ']'. The port text starts at i + 2 (after the separator),
        // so at least one character must follow ']' or we would read past the string.
        if(i == len || i + 1 >= len) {
            WARN("Net : No valid [IPv6]:port pair found");
            return scclInvalidArgument;
        }
        bool global_scope = (j == -1 ? true : false); // If no % found, global scope; otherwise, link scope

        char ip_str[NI_MAXHOST], port_str[NI_MAXSERV], if_name[IFNAMSIZ];
        memset(ip_str, '\0', sizeof(ip_str));
        memset(port_str, '\0', sizeof(port_str));
        memset(if_name, '\0', sizeof(if_name));

        // Lengths of the three pieces; each must leave room for the terminating NUL.
        const int ip_len   = (global_scope ? i : j) - 1;
        const int port_len = len - (i + 2);
        const int if_len   = global_scope ? 0 : i - j - 1;
        if(ip_len >= (int)sizeof(ip_str) || port_len >= (int)sizeof(port_str) || if_len >= (int)sizeof(if_name)) {
            WARN("Net : [IPv6]:port pair is too long");
            return scclInvalidArgument;
        }
        memcpy(ip_str, ip_port_pair + 1, ip_len);
        memcpy(port_str, ip_port_pair + i + 2, port_len);
        int port = atoi(port_str);
        if(!global_scope)
            memcpy(if_name, ip_port_pair + j + 1, if_len); // If not global scope, we need the intf name

        struct sockaddr_in6& sin6 = ua->sin6;
        sin6.sin6_family          = AF_INET6; // IPv6
        if(inet_pton(AF_INET6, ip_str, &(sin6.sin6_addr)) != 1) { // IP address
            WARN("Net : invalid IPv6 address %s", ip_str);
            return scclInvalidArgument;
        }
        sin6.sin6_port     = htons(port);                                // port
        sin6.sin6_flowinfo = 0;                                          // needed by IPv6, but possibly obsolete
        sin6.sin6_scope_id = global_scope ? 0 : if_nametoindex(if_name); // 0 if global scope; intf index if link scope
    }
    return scclSuccess;
}

int scclFindInterfaces(char* ifNames, union scclSocketAddress* ifAddrs, int ifNameMaxSize, int maxIfs) {
    static int shownIfName = 0;
    int nIfs               = 0;
    // Allow user to force the INET socket family selection
    int sock_family = envSocketFamily();
    // User specified interface
    char* env = getenv("SCCL_SOCKET_IFNAME");
    if(env && strlen(env) > 1) {
        INFO(SCCL_LOG_CODEALL, "SCCL_SOCKET_IFNAME set by environment to %s", env);
        // Specified by user : find or fail
        if(shownIfName++ == 0)
            INFO(SCCL_LOG_CODEALL, "SCCL_SOCKET_IFNAME set to %s", env);
        nIfs = findInterfaces(env, ifNames, ifAddrs, sock_family, ifNameMaxSize, maxIfs);
    } else {
        // Try to automatically pick the right one
        // Start with IB
        nIfs = findInterfaces("ib", ifNames, ifAddrs, sock_family, ifNameMaxSize, maxIfs);
        // else see if we can get some hint from COMM ID
        if(nIfs == 0) {
            char* commId = getenv("SCCL_COMM_ID");
            if(commId && strlen(commId) > 1) {
                INFO(SCCL_LOG_CODEALL, "SCCL_COMM_ID set by environment to %s", commId);
                // Try to find interface that is in the same subnet as the IP in comm id
                union scclSocketAddress idAddr;
                scclSocketGetAddrFromString(&idAddr, commId);
                nIfs = scclFindInterfaceMatchSubnet(ifNames, ifAddrs, &idAddr, ifNameMaxSize, maxIfs);
            }
        }
        // Then look for anything else (but not docker or lo)
        if(nIfs == 0)
            nIfs = findInterfaces("^docker,lo", ifNames, ifAddrs, sock_family, ifNameMaxSize, maxIfs);
        // Finally look for docker, then lo.
        if(nIfs == 0)
            nIfs = findInterfaces("docker", ifNames, ifAddrs, sock_family, ifNameMaxSize, maxIfs);
        if(nIfs == 0)
            nIfs = findInterfaces("lo", ifNames, ifAddrs, sock_family, ifNameMaxSize, maxIfs);
    }
    return nIfs;
}

/* Bind a previously initialized socket to its address and start listening.
 *
 * sock : socket with a valid fd; sock->addr is the address to bind and is
 *        overwritten with the address actually assigned (getsockname()).
 *
 * Returns scclInvalidArgument for a NULL socket or an fd of -1; system call
 * failures are reported through SYSCHECK. On success the socket state becomes
 * scclSocketStateReady.
 */
scclResult_t scclSocketListen(struct scclSocket* sock) {
    if(sock == NULL) {
        WARN("scclSocketListen: pass NULL socket");
        return scclInvalidArgument;
    }
    if(sock->fd == -1) {
        WARN("scclSocketListen: file descriptor is -1");
        return scclInvalidArgument;
    }

    if(socketToPort(&sock->addr)) {
        // Port is forced by env. Make sure we get the port.
        // Socket option names are distinct integers, not bit flags, so each
        // option needs its own setsockopt() call.
        int opt = 1;
        SYSCHECK(setsockopt(sock->fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)), "setsockopt");
#if defined(SO_REUSEPORT)
        SYSCHECK(setsockopt(sock->fd, SOL_SOCKET, SO_REUSEPORT, &opt, sizeof(opt)), "setsockopt");
#endif
    }

    // addr port should be 0 (Any port)
    SYSCHECK(bind(sock->fd, &sock->addr.sa, sock->salen), "bind");

    /* Get the assigned Port */
    socklen_t size = sock->salen;
    SYSCHECK(getsockname(sock->fd, &sock->addr.sa, &size), "getsockname");

    /* Put the socket in listen mode
     * NB: The backlog will be silently truncated to the value in /proc/sys/net/core/somaxconn
     */
    SYSCHECK(listen(sock->fd, 16384), "listen");
    sock->state = scclSocketStateReady;
    return scclSuccess;
}

/* Copy the socket's address into *addr.
 * Returns scclInvalidArgument for a NULL socket and scclInternalError when
 * the socket is not in the scclSocketStateReady state.
 */
scclResult_t scclSocketGetAddr(struct scclSocket* sock, union scclSocketAddress* addr) {
    if(sock == NULL) {
        WARN("scclSocketGetAddr: pass NULL socket");
        return scclInvalidArgument;
    }
    if(sock->state == scclSocketStateReady) {
        memcpy(addr, &sock->addr, sizeof(union scclSocketAddress));
        return scclSuccess;
    }
    return scclInternalError;
}

/* Attempt a single accept() on sock->acceptFd.
 * The resulting descriptor (or -1) is stored in sock->fd and the peer address
 * in sock->addr. On success the state moves to scclSocketStateAccepted;
 * EAGAIN/EWOULDBLOCK leave the state unchanged; any other error is logged
 * and reported as scclSystemError.
 */
static scclResult_t socketTryAccept(struct scclSocket* sock) {
    socklen_t addrLen = sizeof(union scclSocketAddress);
    const int newFd   = accept(sock->acceptFd, &sock->addr.sa, &addrLen);
    sock->fd          = newFd;

    if(newFd != -1) {
        sock->state = scclSocketStateAccepted;
        return scclSuccess;
    }
    if(errno == EAGAIN || errno == EWOULDBLOCK)
        return scclSuccess; // nothing pending yet

    WARN("socketTryAccept: Accept failed: %s", strerror(errno));
    return scclSystemError;
}

/* Complete the handshake on a freshly accepted connection.
 *
 * Enables TCP_NODELAY, then reads the peer's magic number followed by its
 * socket type. If no byte has arrived yet the function returns scclSuccess
 * and leaves the state untouched so it can be retried. A wrong magic closes
 * the connection and goes back to scclSocketStateAccepting; a wrong type
 * closes it and fails with scclInternalError; otherwise the socket becomes
 * scclSocketStateReady.
 */
static scclResult_t socketFinalizeAccept(struct scclSocket* sock) {
    uint64_t peerMagic;
    enum scclSocketType peerType;
    int nRecv        = 0;
    const int enable = 1;
    SYSCHECK(setsockopt(sock->fd, IPPROTO_TCP, TCP_NODELAY, (char*)&enable, sizeof(int)), "setsockopt");

    // First a single non-waiting attempt: if nothing arrived, try again on a later call.
    SCCLCHECK(scclSocketProgress(SCCL_SOCKET_RECV, sock, &peerMagic, sizeof(peerMagic), &nRecv));
    if(nRecv == 0)
        return scclSuccess;
    SCCLCHECK(socketWait(SCCL_SOCKET_RECV, sock, &peerMagic, sizeof(peerMagic), &nRecv));

    if(peerMagic != sock->magic) {
        WARN("socketFinalizeAccept: wrong magic %lx != %lx", peerMagic, sock->magic);
        close(sock->fd);
        sock->fd = -1;
        // Ignore spurious connection and accept again
        sock->state = scclSocketStateAccepting;
        return scclSuccess;
    }

    nRecv = 0;
    SCCLCHECK(socketWait(SCCL_SOCKET_RECV, sock, &peerType, sizeof(peerType), &nRecv));
    if(peerType == sock->type) {
        sock->state = scclSocketStateReady;
        return scclSuccess;
    }

    WARN("socketFinalizeAccept: wrong type %d != %d", peerType, sock->type);
    sock->state = scclSocketStateError;
    close(sock->fd);
    sock->fd = -1;
    return scclInternalError;
}

/* Issue connect() on the socket and update its state from the outcome.
 *
 * - success           -> scclSocketStateConnected
 * - EINPROGRESS       -> scclSocketStateConnectPolling
 * - ECONNREFUSED      -> sleep and retry later, up to RETRY_REFUSED_TIMES
 * - ETIMEDOUT         -> sleep and retry later, up to RETRY_TIMEDOUT_TIMES
 * - anything else     -> scclSocketStateError / scclSystemError
 *
 * Exhausting a retry budget sets scclSocketStateError and returns scclRemoteError.
 */
static scclResult_t socketStartConnect(struct scclSocket* sock) {
    /* blocking/non-blocking connect() is determined by asyncFlag. */
    int ret = connect(sock->fd, &sock->addr.sa, sock->salen);
    // Snapshot errno immediately: usleep() and the address formatting below may overwrite it.
    const int err = errno;

    if(ret == 0) {
        sock->state = scclSocketStateConnected;
        return scclSuccess;
    } else if(err == EINPROGRESS) {
        sock->state = scclSocketStateConnectPolling;
        return scclSuccess;
    } else if(err == ECONNREFUSED) {
        if(++sock->refusedRetries == RETRY_REFUSED_TIMES) {
            sock->state = scclSocketStateError;
            WARN("socketStartConnect: exceeded retries (%d)", sock->refusedRetries);
            return scclRemoteError;
        }
        usleep(SLEEP_INT);
        if(sock->refusedRetries % 1000 == 0)
            INFO(SCCL_LOG_CODEALL, "Call to connect returned %s, retrying", strerror(err));
        return scclSuccess;
    } else if(err == ETIMEDOUT) {
        if(++sock->timedOutRetries == RETRY_TIMEDOUT_TIMES) {
            sock->state = scclSocketStateError;
            WARN("socketStartConnect: exceeded timeouts (%d)", sock->timedOutRetries);
            return scclRemoteError;
        }
        usleep(SLEEP_INT);
        return scclSuccess;
    } else {
        char line[SOCKET_NAME_MAXLEN + 1];
        sock->state = scclSocketStateError;
        WARN("socketStartConnect: Connect to %s failed : %s", scclSocketToString(&sock->addr, line), strerror(err));
        return scclSystemError;
    }
}

/* Poll an in-progress connect() and update the socket state.
 *
 * Waits up to 1 ms for the fd to become writable. A timeout or EINTR returns
 * scclSuccess without changing the state. Once poll() reports the fd, the
 * pending error is fetched with SO_ERROR:
 * - 0                     -> scclSocketStateConnected
 * - ECONNREFUSED/ETIMEDOUT-> sleep, go back to scclSocketStateConnecting
 *                            (bounded by the respective retry budget)
 * - EINPROGRESS           -> state unchanged
 * - anything else         -> scclSocketStateError / scclSystemError
 */
static scclResult_t socketPollConnect(struct scclSocket* sock) {
    struct pollfd pfd;
    int timeout    = 1, ret;
    socklen_t rlen = sizeof(int);

    memset(&pfd, 0, sizeof(struct pollfd));
    pfd.fd     = sock->fd;
    pfd.events = POLLOUT;
    ret        = poll(&pfd, 1, timeout);

    if(ret == 0 || (ret < 0 && errno == EINTR)) {
        return scclSuccess;
    } else if(ret < 0) {
        WARN("socketPollConnect poll() failed with error %s", strerror(errno));
        return scclRemoteError;
    } else {
        EQCHECK(ret == 1 && (pfd.revents & POLLOUT), 0);
    }

    /* check socket status; from here on ret holds the socket's pending error code */
    SYSCHECK(getsockopt(sock->fd, SOL_SOCKET, SO_ERROR, (void*)&ret, &rlen), "getsockopt");

    if(ret == 0) {
        sock->state = scclSocketStateConnected;
    } else if(ret == ECONNREFUSED) {
        if(++sock->refusedRetries == RETRY_REFUSED_TIMES) {
            sock->state = scclSocketStateError;
            WARN("socketPollConnect: exceeded retries (%d)", sock->refusedRetries);
            return scclRemoteError;
        }
        // The connect error lives in ret (SO_ERROR), not in errno, which is stale here.
        if(sock->refusedRetries % 1000 == 0)
            INFO(SCCL_LOG_CODEALL, "Call to connect returned %s, retrying", strerror(ret));
        usleep(SLEEP_INT);
        sock->state = scclSocketStateConnecting;
    } else if(ret == ETIMEDOUT) {
        if(++sock->timedOutRetries == RETRY_TIMEDOUT_TIMES) {
            sock->state = scclSocketStateError;
            WARN("socketPollConnect: exceeded timeouts (%d)", sock->timedOutRetries);
            return scclRemoteError;
        }
        usleep(SLEEP_INT);
        sock->state = scclSocketStateConnecting;
    } else if(ret != EINPROGRESS) {
        sock->state = scclSocketStateError;
        return scclSystemError;
    }
    return scclSuccess;
}

/* Public wrapper around socketPollConnect() that rejects a NULL socket with
 * scclInvalidArgument and otherwise propagates the helper's result.
 */
scclResult_t scclSocketPollConnect(struct scclSocket* sock) {
    if(sock != NULL) {
        SCCLCHECK(socketPollConnect(sock));
        return scclSuccess;
    }
    WARN("scclSocketPollConnect: pass NULL socket");
    return scclInvalidArgument;
}

/* Complete the handshake on the connecting side: send the magic number and
 * then the socket type. If the first non-waiting attempt sends nothing, the
 * function returns scclSuccess without changing the state so the caller can
 * retry; once both fields are sent the socket becomes scclSocketStateReady.
 */
static scclResult_t socketFinalizeConnect(struct scclSocket* sock) {
    int nSent = 0;
    SCCLCHECK(socketProgress(SCCL_SOCKET_SEND, sock, &sock->magic, sizeof(sock->magic), &nSent));
    if(nSent != 0) {
        // Something went out: finish the magic, then push the type in full.
        SCCLCHECK(socketWait(SCCL_SOCKET_SEND, sock, &sock->magic, sizeof(sock->magic), &nSent));
        nSent = 0;
        SCCLCHECK(socketWait(SCCL_SOCKET_SEND, sock, &sock->type, sizeof(sock->type), &nSent));
        sock->state = scclSocketStateReady;
    }
    return scclSuccess;
}

/* Advance the socket's connection state machine by as many steps as possible.
 * The checks run in a fixed order and each handler may change sock->state, so
 * a single call can chain several transitions (e.g. accepting -> accepted ->
 * ready). The first failing handler aborts the call with its error.
 */
static scclResult_t socketProgressState(struct scclSocket* sock) {
    // Accept side
    if(scclSocketStateAccepting == sock->state)
        SCCLCHECK(socketTryAccept(sock));
    if(scclSocketStateAccepted == sock->state)
        SCCLCHECK(socketFinalizeAccept(sock));

    // Connect side
    if(scclSocketStateConnecting == sock->state)
        SCCLCHECK(socketStartConnect(sock));
    if(scclSocketStateConnectPolling == sock->state)
        SCCLCHECK(socketPollConnect(sock));
    if(scclSocketStateConnected == sock->state)
        SCCLCHECK(socketFinalizeConnect(sock));

    return scclSuccess;
}

/* Report whether the socket is ready, progressing its state machine once if
 * it is not.
 *
 * running : out, 1 when the socket is in scclSocketStateReady, else 0.
 *
 * A NULL socket yields *running = 0 and scclSuccess. A socket in the error or
 * closed state yields scclRemoteError.
 */
scclResult_t scclSocketReady(struct scclSocket* sock, int* running) {
    if(sock == NULL) {
        *running = 0;
        return scclSuccess;
    }

    const bool unusable = (sock->state == scclSocketStateError) || (sock->state == scclSocketStateClosed);
    if(unusable) {
        WARN("scclSocketReady: unexpected socket state %d", sock->state);
        return scclRemoteError;
    }

    if(sock->state == scclSocketStateReady) {
        *running = 1;
        return scclSuccess;
    }

    // Not ready yet: give the state machine one push and look again.
    *running = 0;
    SCCLCHECK(socketProgressState(sock));
    if(sock->state == scclSocketStateReady)
        *running = 1;
    else
        *running = 0;
    return scclSuccess;
}

/* Start (and, for synchronous sockets, complete) a connection to sock->addr.
 *
 * sock      : socket in scclSocketStateInitialized with a valid fd.
 * portReuse : non-zero binds the socket to a local port from clientPortPool
 *             before connecting, picking a port not yet used for this peer.
 *
 * Returns scclInvalidArgument for a NULL socket or an fd of -1,
 * scclRemoteError/scclInternalError for a wrong initial state,
 * scclInternalError when the abort flag is raised, scclSystemError when the
 * socket ends in the error state, and scclSuccess otherwise (for async
 * sockets the connection may still be in progress).
 *
 * NOTE(review): clientPortPool is shared file-scope state and is modified
 * here without any locking visible in this file — confirm callers serialize.
 */
scclResult_t scclSocketConnect(struct scclSocket* sock, int portReuse) {
    char line[SOCKET_NAME_MAXLEN + 1];
    const int one = 1;

    if(sock == NULL) {
        WARN("scclSocketConnect: pass NULL socket");
        return scclInvalidArgument;
    }
    if(sock->fd == -1) {
        WARN("scclSocketConnect: file descriptor is -1");
        return scclInvalidArgument;
    }

    if(sock->state != scclSocketStateInitialized) {
        WARN("scclSocketConnect: wrong socket state %d", sock->state);
        if(sock->state == scclSocketStateError)
            return scclRemoteError;
        return scclInternalError;
    }
    SYSCHECK(setsockopt(sock->fd, IPPROTO_TCP, TCP_NODELAY, (char*)&one, sizeof(int)), "setsockopt");

    if(portReuse) {
        int family = sock->addr.sa.sa_family;
        if(family != AF_INET && family != AF_INET6) {
            WARN("Net : connecting to address %s with family %d is neither AF_INET(%d) nor AF_INET6(%d)",
                 scclSocketToString(&sock->addr, line),
                 family,
                 AF_INET,
                 AF_INET6);
            return scclInternalError;
        }
        int salen = (family == AF_INET) ? sizeof(struct sockaddr_in) : sizeof(struct sockaddr_in6);

        // pre-define ports according to tid, to avoid extra lock for race condition
        if(clientPortPool.size() == 0) {
            for(int tid = syscall(SYS_gettid), i = 1; i < 5; i++) {
                clientPortPool.push_back(std::make_pair(60000 + i * 1000 + tid % 1000, std::unordered_set<std::string>()));
            }
        }
        // find a port without conflict (different remote peer) in best effort
        int reused_port = -1;
        std::string remote_peer(scclSocketToString(&sock->addr, line));
        for(auto& port : clientPortPool) {
            if(port.second.find(remote_peer) == port.second.end()) {
                reused_port = port.first;
                port.second.insert(remote_peer);
                break;
            }
        }
        // bind the port in fd for connect system call
        if(reused_port != -1) {
            // Socket option names are distinct integers, not bit flags, so each
            // option needs its own setsockopt() call.
            int opt = 1;
            SYSCHECK(setsockopt(sock->fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)), "setsockopt");
            SYSCHECK(setsockopt(sock->fd, SOL_SOCKET, SO_REUSEPORT, &opt, sizeof(opt)), "setsockopt");

            // Build a wildcard local address of the SAME family as the socket so that
            // the structure handed to bind() really is salen bytes long.
            union scclSocketAddress local;
            memset(&local, 0, sizeof(local));
            if(family == AF_INET) {
                local.sin.sin_family      = AF_INET;
                local.sin.sin_addr.s_addr = htonl(INADDR_ANY);
                local.sin.sin_port        = htons(reused_port);
            } else {
                local.sin6.sin6_family = AF_INET6;
                local.sin6.sin6_addr   = in6addr_any;
                local.sin6.sin6_port   = htons(reused_port);
            }
            SYSCHECK(bind(sock->fd, &local.sa, salen), "bind_client_port");
        }
    }

    sock->state = scclSocketStateConnecting;
    do {
        SCCLCHECK(socketProgressState(sock));
    } while(sock->asyncFlag == 0 && (sock->abortFlag == NULL || *sock->abortFlag == 0) &&
            (sock->state == scclSocketStateConnecting || sock->state == scclSocketStateConnectPolling || sock->state == scclSocketStateConnected));

    if(sock->abortFlag && *sock->abortFlag != 0)
        return scclInternalError;

    switch(sock->state) {
        case scclSocketStateConnecting:
        case scclSocketStateConnectPolling:
        case scclSocketStateConnected:
        case scclSocketStateReady: return scclSuccess;
        case scclSocketStateError: return scclSystemError;
        default: WARN("scclSocketConnect: wrong socket state %d", sock->state); return scclInternalError;
    }
}

/* Accept a connection from listenSock into sock.
 *
 * On the first call for a given sock (acceptFd == -1) the listening socket's
 * settings are copied into sock and it enters scclSocketStateAccepting. The
 * state machine is then driven until the socket leaves the accepting/accepted
 * states, or only once when the socket is asynchronous or the abort flag is set.
 *
 * Returns scclInvalidArgument for NULL arguments, scclSystemError or
 * scclInternalError for a listening socket that is not ready,
 * scclInternalError on abort or an unexpected final state, scclSystemError
 * when sock ends in the error state, and scclSuccess otherwise.
 */
scclResult_t scclSocketAccept(struct scclSocket* sock, struct scclSocket* listenSock) {
    scclResult_t ret = scclSuccess;

    if(sock == NULL || listenSock == NULL) {
        WARN("scclSocketAccept: pass NULL socket");
        return scclInvalidArgument;
    }
    if(listenSock->state != scclSocketStateReady) {
        WARN("scclSocketAccept: wrong socket state %d", listenSock->state);
        return (listenSock->state == scclSocketStateError) ? scclSystemError : scclInternalError;
    }

    // First call for this socket: inherit everything from the listener.
    if(sock->acceptFd == -1) {
        memcpy(sock, listenSock, sizeof(struct scclSocket));
        sock->acceptFd = listenSock->fd;
        sock->state    = scclSocketStateAccepting;
    }

    while(true) {
        SCCLCHECKGOTO(socketProgressState(sock), ret, exit);
        if(sock->asyncFlag != 0)
            break; // asynchronous sockets get a single push per call
        if(sock->abortFlag != NULL && *sock->abortFlag != 0)
            break;
        if(sock->state != scclSocketStateAccepting && sock->state != scclSocketStateAccepted)
            break;
    }

    if(sock->abortFlag && *sock->abortFlag != 0)
        return scclInternalError;

    if(sock->state == scclSocketStateAccepting || sock->state == scclSocketStateAccepted || sock->state == scclSocketStateReady) {
        ret = scclSuccess;
    } else if(sock->state == scclSocketStateError) {
        ret = scclSystemError;
    } else {
        WARN("scclSocketAccept: wrong socket state %d", sock->state);
        ret = scclInternalError;
    }

exit:
    return ret;
}

/* Initialize a socket object and, when an address is given, create its fd.
 *
 * sock      : socket to initialize; NULL is accepted and returns scclSuccess.
 * addr      : optional address copied into the socket; its family must be
 *             AF_INET or AF_INET6. When NULL the stored address is zeroed and
 *             no fd is created.
 * magic     : magic number exchanged during the connection handshake.
 * type      : socket type exchanged during the connection handshake.
 * abortFlag : optional flag polled to abort pending operations.
 * asyncFlag : non-zero for asynchronous operation.
 *
 * The fd is made non-blocking when asyncFlag or abortFlag is set. On failure
 * any fd created here is closed again and sock->fd is reset to -1.
 */
scclResult_t
scclSocketInit(struct scclSocket* sock, union scclSocketAddress* addr, uint64_t magic, enum scclSocketType type, volatile uint32_t* abortFlag, int asyncFlag) {
    scclResult_t ret = scclSuccess;

    if(sock == NULL)
        goto exit;
    sock->timedOutRetries = 0;
    sock->refusedRetries  = 0;
    sock->abortFlag       = abortFlag;
    sock->asyncFlag       = asyncFlag;
    sock->state           = scclSocketStateInitialized;
    sock->magic           = magic;
    sock->type            = type;
    sock->fd              = -1;
    sock->acceptFd        = -1;

    if(addr) {
        /* IPv4/IPv6 support */
        int family;
        memcpy(&sock->addr, addr, sizeof(union scclSocketAddress));
        family = sock->addr.sa.sa_family;
        if(family != AF_INET && family != AF_INET6) {
            char line[SOCKET_NAME_MAXLEN + 1];
            WARN("scclSocketInit: connecting to address %s with family %d is neither AF_INET(%d) nor AF_INET6(%d)",
                 scclSocketToString(&sock->addr, line),
                 family,
                 AF_INET,
                 AF_INET6);
            ret = scclInternalError;
            goto fail;
        }
        sock->salen = (family == AF_INET) ? sizeof(struct sockaddr_in) : sizeof(struct sockaddr_in6);

        /* Connect to a hostname / port */
        sock->fd = socket(family, SOCK_STREAM, 0);
        if(sock->fd == -1) {
            WARN("scclSocketInit: Socket creation failed : %s", strerror(errno));
            ret = scclSystemError;
            goto fail;
        }
    } else {
        memset(&sock->addr, 0, sizeof(union scclSocketAddress));
    }

    /* Set socket as non-blocking if async or if we need to be able to abort */
    if((sock->asyncFlag || sock->abortFlag) && sock->fd >= 0) {
        int flags;
        EQCHECKGOTO(flags = fcntl(sock->fd, F_GETFL), -1, ret, fail);
        SYSCHECKGOTO(fcntl(sock->fd, F_SETFL, flags | O_NONBLOCK), ret, fail);
    }

exit:
    return ret;
fail:
    // Do not leak the descriptor when a later setup step (fcntl) failed.
    if(sock->fd >= 0) {
        close(sock->fd);
        sock->fd = -1;
    }
    goto exit;
}

/* Public, NULL-checked entry point for a single non-blocking progress step;
 * see socketProgress(). Returns scclInvalidArgument for a NULL socket.
 */
scclResult_t scclSocketProgress(int op, struct scclSocket* sock, void* ptr, int size, int* offset) {
    if(sock != NULL) {
        SCCLCHECK(socketProgress(op, sock, ptr, size, offset));
        return scclSuccess;
    }
    WARN("scclSocketProgress: pass NULL socket");
    return scclInvalidArgument;
}

/* Public, NULL-checked entry point that transfers the remainder of a buffer;
 * see socketWait(). Returns scclInvalidArgument for a NULL socket.
 */
scclResult_t scclSocketWait(int op, struct scclSocket* sock, void* ptr, int size, int* offset) {
    if(sock != NULL) {
        SCCLCHECK(socketWait(op, sock, ptr, size, offset));
        return scclSuccess;
    }
    WARN("scclSocketWait: pass NULL socket");
    return scclInvalidArgument;
}

/* Send exactly size bytes from ptr over a ready socket.
 * Returns scclInvalidArgument for a NULL socket and scclInternalError when
 * the socket is not in scclSocketStateReady; transfer errors are propagated.
 */
scclResult_t scclSocketSend(struct scclSocket* sock, void* ptr, int size) {
    if(sock == NULL) {
        WARN("scclSocketSend: pass NULL socket");
        return scclInvalidArgument;
    }
    if(sock->state != scclSocketStateReady) {
        WARN("scclSocketSend: socket state (%d) is not ready", sock->state);
        return scclInternalError;
    }

    int nDone = 0; // bytes sent so far
    SCCLCHECK(socketWait(SCCL_SOCKET_SEND, sock, ptr, size, &nDone));
    return scclSuccess;
}

/* Receive exactly size bytes into ptr from a ready socket.
 * Returns scclInvalidArgument for a NULL socket and scclInternalError when
 * the socket is not in scclSocketStateReady; transfer errors are propagated.
 */
scclResult_t scclSocketRecv(struct scclSocket* sock, void* ptr, int size) {
    if(sock == NULL) {
        WARN("scclSocketRecv: pass NULL socket");
        return scclInvalidArgument;
    }
    if(sock->state != scclSocketStateReady) {
        WARN("scclSocketRecv: socket state (%d) is not ready", sock->state);
        return scclInternalError;
    }

    int nDone = 0; // bytes received so far
    SCCLCHECK(socketWait(SCCL_SOCKET_RECV, sock, ptr, size, &nDone));
    return scclSuccess;
}

// Receive size bytes, or detect that the connection was closed.
//
// closed   : out, set to 1 when the peer closed the connection (the call
//            still returns scclSuccess in that case).
// blocking : true  -> keep retrying until size bytes arrived or the peer closed;
//            false -> make one attempt first and return scclInProgress if it
//                     yields no data; once any byte arrived, keep retrying
//                     until the message is complete.
//
// Every attempt is made with the non-blocking form of socketProgressOpt().
scclResult_t scclSocketTryRecv(struct scclSocket* sock, void* ptr, int size, int* closed, bool blocking) {
    if(sock == NULL) {
        WARN("scclSocketTryRecv: pass NULL socket");
        return scclInvalidArgument;
    }

    int nDone = 0;
    *closed   = 0;

    if(!blocking) {
        // Single probe: bail out early when the peer closed or nothing is available yet.
        SCCLCHECK(socketProgressOpt(SCCL_SOCKET_RECV, sock, ptr, size, &nDone, 0, closed));
        if(*closed)
            return scclSuccess;
        if(nDone <= 0)
            return scclInProgress;
    }

    // Either blocking mode, or a partial message that must be completed.
    while(nDone < size) {
        SCCLCHECK(socketProgressOpt(SCCL_SOCKET_RECV, sock, ptr, size, &nDone, 0, closed));
        if(*closed)
            return scclSuccess;
    }
    return scclSuccess;
}

/* Shut down and close the socket's fd (if any) and mark the socket closed.
 * A NULL socket is accepted; the function always returns scclSuccess.
 */
scclResult_t scclSocketClose(struct scclSocket* sock) {
    if(sock == NULL)
        return scclSuccess;

    if(sock->fd >= 0) {
        /* shutdown() is needed to send FIN packet to proxy thread; shutdown() is not affected
         * by refcount of fd, but close() is. close() won't close a fd and send FIN packet if
         * the fd is duplicated (e.g. fork()). So shutdown() guarantees the correct and graceful
         * connection close here. */
        shutdown(sock->fd, SHUT_RDWR);
        close(sock->fd);
    }
    sock->fd    = -1;
    sock->state = scclSocketStateClosed;
    return scclSuccess;
}

/* Store the socket's file descriptor in *fd when fd is non-NULL.
 * Returns scclInvalidArgument for a NULL socket, scclSuccess otherwise.
 */
scclResult_t scclSocketGetFd(struct scclSocket* sock, int* fd) {
    if(sock == NULL) {
        WARN("scclSocketGetFd: pass NULL socket");
        return scclInvalidArgument;
    }
    if(fd != NULL) {
        *fd = sock->fd;
    }
    return scclSuccess;
}

/* Replace the socket's file descriptor with fd.
 * Returns scclInvalidArgument for a NULL socket, scclSuccess otherwise.
 * The previous descriptor, if any, is not closed here.
 */
scclResult_t scclSocketSetFd(int fd, struct scclSocket* sock) {
    if(sock == NULL) {
        // Name the function that actually failed (the message used to say scclSocketGetFd).
        WARN("scclSocketSetFd: pass NULL socket");
        return scclInvalidArgument;
    }
    sock->fd = fd;
    return scclSuccess;
}
