#pragma once

#include <hip/hip_runtime_api.h>
#include <hip/hip_fp16.h>
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/types.h>
#include <unistd.h>
#include <pthread.h>
#include <string.h>
#include <stdarg.h>

#include <sys/syscall.h>

namespace sccl {

// NOTE(review): presumably the maximum number of outstanding network requests;
// it is not referenced in the visible part of this file — confirm against its users.
#define SCCL_NET_MAX_REQUESTS 8

// Log severity levels, stored in a uint8_t.
// scclDebugLog<level> emits a message only when the configured scclDebugLevel
// is numerically >= the message's level, so a higher configured value is more verbose.
typedef enum : uint8_t {
    SCCL_LOG_NONE    = 0,
    SCCL_LOG_VERSION = 1,
    SCCL_LOG_WARN    = 2,
    SCCL_LOG_INFO    = 3,
    SCCL_LOG_ABORT   = 4
} scclDebugLogLevel_t;

// Sub-system selectors for log messages. Each named sub-system is a distinct bit,
// so values are tested against scclDebugMask with a bitwise AND.
typedef enum : int64_t {
    SCCL_LOG_CODEALL    = ~0, // every bit set: matches any non-empty mask
    SCCL_LOG_NET        = 0x0001,
    SCCL_LOG_TOPO       = 0x0002,
    SCCL_LOG_BOOTSTRAP  = 0x0004,
    SCCL_LOG_TRANSPORT  = 0x0008,
    SCCL_LOG_GRAPH      = 0x0010,
    SCCL_LOG_CONNECT    = 0x0020,
    SCCL_LOG_P2P        = 0x0040,
    SCCL_LOG_COLLECTIVE = 0x0080,
    SCCL_LOG_ALLOC      = 0x0100
} scclDebugLogSubSys_t;

namespace debug {

static char scclLastError[1024] = "";                             // Human-readable text of the most recent WARN message
static char hostname[1024];                                       // Cached host name, filled in by scclDebugInit
static pthread_mutex_t scclDebugLock = PTHREAD_MUTEX_INITIALIZER; // Serializes scclDebugInit and updates of scclLastError
static __thread int tid              = -1;                        // Thread-local cache of the calling thread's id; -1 until the thread first logs
static int pid                       = -1;                        // Cached process id; -1 until scclDebugInit runs
static FILE* scclDebugFile           = stdout;                    // Stream that receives log output; defaults to stdout

static uint64_t scclDebugMask = SCCL_LOG_TOPO | SCCL_LOG_BOOTSTRAP; // Default sub-system mask: TOPO and BOOTSTRAP
static int scclDebugLevel     = -1;                                 // Configured log level; -1 means "not initialized yet"

// Set to SCCL_LOG_CODEALL by scclDebugLog on every WARN call; never read in the visible part of this file.
static int scclDebugPos = -1; // -1 means "not set"

/**
 * @brief Fetch the host name and cut it at the first delimiter.
 *
 * Calls gethostname() and truncates the result at the first occurrence of
 * @p delim. If gethostname() fails, the buffer receives "unknown" (shortened
 * to fit if necessary). Whenever anything is written, the result is
 * NUL-terminated.
 *
 * @param hostname Destination buffer; the call is a no-op when it is NULL.
 * @param maxlen   Size of @p hostname in bytes; the call is a no-op when <= 0.
 * @param delim    Character at which the host name is cut (e.g. '.').
 */
static void getHostName(char* hostname, int maxlen, const char delim) {
    if(hostname == NULL || maxlen <= 0)
        return;

    if(gethostname(hostname, (size_t)maxlen) != 0) {
        // snprintf always terminates the buffer; strncpy would not when
        // maxlen <= strlen("unknown").
        snprintf(hostname, (size_t)maxlen, "%s", "unknown");
        return;
    }

    // gethostname() may leave a truncated name unterminated, so check the
    // bound before reading each byte and always terminate explicitly.
    int i = 0;
    while((i < maxlen - 1) && (hostname[i] != delim) && (hostname[i] != '\0'))
        i++;
    hostname[i] = '\0';
}

////////////////////////////// 初始化debug //////////////////////////////
/**
 * @brief 初始化SCCL调试系统
 *
 * 该函数负责初始化SCCL的调试功能，包括：
 * 1. 从环境变量SCCL_DEBUG_LEVEL读取并设置调试等级
 * 2. 从环境变量SCCL_DEBUG_POS读取并设置调试位置
 * 3. 缓存当前进程的PID和主机名
 * 4. 根据SCCL_DEBUG_FILE环境变量创建调试日志文件
 *
 * 函数使用互斥锁保证线程安全，并通过原子操作设置最终的调试等级和位置。
 * 调试等级和位置的默认值分别为SCCL_LOG_INFO和SCCL_LOG_CODEALL。
 *
 * @note 该函数是线程安全的，但应在程序早期调用以避免竞态条件
 */
static void scclDebugInit() {
    pthread_mutex_lock(&scclDebugLock);

    if(scclDebugLevel != -1) {
        pthread_mutex_unlock(&scclDebugLock);
        return;
    }

    //// 按照debug等级划分
    int tempScclDebugLevel = -1;
    {
        const char* sccl_debug = getenv("SCCL_DEBUG_LEVEL");

        if(sccl_debug == NULL) {
            tempScclDebugLevel = SCCL_LOG_INFO;
        } else if(strcasecmp(sccl_debug, "VERSION") == 0) {
            tempScclDebugLevel = SCCL_LOG_VERSION;
        } else if(strcasecmp(sccl_debug, "WARN") == 0) {
            tempScclDebugLevel = SCCL_LOG_WARN;
        } else if(strcasecmp(sccl_debug, "INFO") == 0) {
            tempScclDebugLevel = SCCL_LOG_INFO;
        } else if(strcasecmp(sccl_debug, "ABORT") == 0) {
            tempScclDebugLevel = SCCL_LOG_ABORT;
        }
    }

    //// 按照代码位置划分
    char* scclDebugSubsysEnv = getenv("SCCL_DEBUG_SUBSYS");
    if(scclDebugSubsysEnv != NULL) {
        int invert = 0;
        if(scclDebugSubsysEnv[0] == '^') {
            invert = 1;
            scclDebugSubsysEnv++;
        }
        scclDebugMask         = invert ? ~0ULL : 0ULL;
        char* scclDebugSubsys = strdup(scclDebugSubsysEnv);
        char* subsys          = strtok(scclDebugSubsys, ",");
        while(subsys != NULL) {
            uint64_t mask = 0;
            if(strcasecmp(subsys, "NET") == 0) {
                mask = SCCL_LOG_NET;
            } else if(strcasecmp(subsys, "TOPO") == 0) {
                mask = SCCL_LOG_TOPO;
            } else if(strcasecmp(subsys, "BOOTSTRAP") == 0) {
                mask = SCCL_LOG_BOOTSTRAP;
            } else if(strcasecmp(subsys, "TRANSPORT") == 0) {
                mask = SCCL_LOG_TRANSPORT;
            } else if(strcasecmp(subsys, "GRAPH") == 0) {
                mask = SCCL_LOG_GRAPH;
            } else if(strcasecmp(subsys, "CONNECT") == 0) {
                mask = SCCL_LOG_CONNECT;
            } else if(strcasecmp(subsys, "P2P") == 0) {
                mask = SCCL_LOG_P2P;
            } else if(strcasecmp(subsys, "COLLECTIVE") == 0) {
                mask = SCCL_LOG_COLLECTIVE;
            } else if(strcasecmp(subsys, "ALLOC") == 0) {
                mask = SCCL_LOG_ALLOC;
            } else if(strcasecmp(subsys, "ALL") == 0) {
                mask = SCCL_LOG_CODEALL;
            }
            if(mask) {
                if(invert)
                    scclDebugMask &= ~mask;
                else
                    scclDebugMask |= mask;
            }
            subsys = strtok(NULL, ",");
        }
        free(scclDebugSubsys);
    }

    // Cache pid and hostname
    getHostName(hostname, 1024, '.');
    pid = getpid();

    /* Parse and expand the SCCL_DEBUG_FILE path and
     * then create the debug file. But don't bother unless the
     * SCCL_DEBUG level is > VERSION
     */
    const char* scclDebugFileEnv = getenv("SCCL_DEBUG_FILE");
    if(tempScclDebugLevel > SCCL_LOG_VERSION && scclDebugFileEnv != NULL) {
        int c                      = 0;
        char debugFn[PATH_MAX + 1] = "";
        char* dfn                  = debugFn;
        while(scclDebugFileEnv[c] != '\0' && c < PATH_MAX) {
            if(scclDebugFileEnv[c++] != '%') {
                *dfn++ = scclDebugFileEnv[c - 1];
                continue;
            }
            switch(scclDebugFileEnv[c++]) {
                case '%': // Double %
                    *dfn++ = '%';
                    break;
                case 'h': // %h = hostname
                    dfn += snprintf(dfn, PATH_MAX, "%s", hostname);
                    break;
                case 'p': // %p = pid
                    dfn += snprintf(dfn, PATH_MAX, "%d", pid);
                    break;
                default: // Echo everything we don't understand
                    *dfn++ = '%';
                    *dfn++ = scclDebugFileEnv[c - 1];
                    break;
            }
        }
        *dfn = '\0';
        if(debugFn[0] != '\0') {
            FILE* file = fopen(debugFn, "w");
            if(file != nullptr) {
                setbuf(file, nullptr); // disable buffering
                scclDebugFile = file;
            }
        }
    }

    __atomic_store_n(&scclDebugLevel, tempScclDebugLevel, __ATOMIC_RELEASE);

    pthread_mutex_unlock(&scclDebugLock);
}

////////////////////////////// Emitting debug messages //////////////////////////////

/**
 * @brief Format and emit one log message.
 *
 * The message is written only when the configured level is at least
 * @p level and @p pos_flags shares a bit with scclDebugMask. Output is produced
 * for SCCL_LOG_WARN and SCCL_LOG_INFO only; other levels pass the filter but
 * write nothing. For WARN messages that pass the filter, the formatted text is
 * also stored in scclLastError.
 *
 * Each line has the form
 *   "<host>:<pid>:<tid> <file>:<function>:<line> SCCL WARN|INFO <message>\n"
 * (WARN lines are preceded by an empty line). Lines longer than the internal
 * 1024-byte buffer are truncated.
 *
 * @tparam level    Severity of this message.
 * @param pos_flags Sub-system bit(s) this message belongs to.
 * @param filepath  Source file of the call site.
 * @param filefunc  Function name of the call site.
 * @param line      Source line of the call site.
 * @param fmt       printf-style format string, followed by its arguments.
 */
template <scclDebugLogLevel_t level>
void scclDebugLog(scclDebugLogSubSys_t pos_flags, const char* filepath, const char* filefunc, int line, const char* fmt, ...) {
    if(__atomic_load_n(&scclDebugLevel, __ATOMIC_ACQUIRE) == -1)
        scclDebugInit();

    if constexpr(level == SCCL_LOG_WARN)
        scclDebugPos = SCCL_LOG_CODEALL;

    // Drop the message when the level is too low or the sub-system is masked out.
    const bool isDebugLevelSufficient = (scclDebugLevel >= level);
    const bool isDebugPositionMatch   = (pos_flags & scclDebugMask) != 0;
    if(!isDebugLevelSufficient || !isDebugPositionMatch) {
        return;
    }

    // Save the last error (WARN) as a human readable string
    if constexpr(level == SCCL_LOG_WARN) {
        pthread_mutex_lock(&scclDebugLock);
        va_list vargs;
        va_start(vargs, fmt);
        (void)vsnprintf(scclLastError, sizeof(scclLastError), fmt, vargs);
        va_end(vargs);
        pthread_mutex_unlock(&scclDebugLock);
    }

    // Look the thread id up once per thread and cache it.
    if(tid == -1) {
        tid = syscall(SYS_gettid);
    }

    char buffer[1024];
    // One byte is always kept free for the trailing newline.
    const size_t capacity = sizeof(buffer) - 1;

    int prefixLen = 0;
    if constexpr(level == SCCL_LOG_WARN) {
        prefixLen = snprintf(buffer, sizeof(buffer), "\n%s:%d:%d %s:%s:%d SCCL WARN ", hostname, pid, tid, filepath, filefunc, line);
    } else if constexpr(level == SCCL_LOG_INFO) {
        prefixLen = snprintf(buffer, sizeof(buffer), "%s:%d:%d %s:%s:%d SCCL INFO ", hostname, pid, tid, filepath, filefunc, line);
    }

    // Levels without a prefix produce no output; a formatting error is treated the same way.
    if(prefixLen <= 0) {
        return;
    }

    // snprintf/vsnprintf return the length they wanted to write, which can
    // exceed the buffer. Clamp to what was actually stored before indexing.
    size_t len = ((size_t)prefixLen < capacity) ? (size_t)prefixLen : capacity;

    if(len < capacity) {
        va_list vargs;
        va_start(vargs, fmt);
        const int bodyLen = vsnprintf(buffer + len, sizeof(buffer) - len, fmt, vargs);
        va_end(vargs);
        if(bodyLen > 0) {
            const size_t room = capacity - len;
            len += ((size_t)bodyLen < room) ? (size_t)bodyLen : room;
        }
    }

    buffer[len++] = '\n'; // replaces the terminator; len <= sizeof(buffer) here
    fwrite(buffer, 1, len, scclDebugFile);
}

} // namespace debug

// Logging entry points. Both capture the call site (__FILE__, __func__, __LINE__)
// and take a printf-style format string followed by its arguments.
// Names are fully qualified so the macros also expand correctly outside namespace sccl.
//   WARN(fmt, ...)        - warning, tagged with every sub-system bit (SCCL_LOG_CODEALL)
//   INFO(FLAGS, fmt, ...) - informational message for the sub-system(s) given by FLAGS
#define WARN(...) ::sccl::debug::scclDebugLog< ::sccl::SCCL_LOG_WARN>(::sccl::SCCL_LOG_CODEALL, __FILE__, __func__, __LINE__, __VA_ARGS__)
#define INFO(FLAGS, ...) ::sccl::debug::scclDebugLog< ::sccl::SCCL_LOG_INFO>((FLAGS), __FILE__, __func__, __LINE__, __VA_ARGS__)

} // namespace sccl
