#pragma once

#include <hip/hip_runtime_api.h>
#include <hip/hip_fp16.h>
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/types.h>
#include <unistd.h>
#include <pthread.h>
#include <pwd.h>
#include <errno.h>

#include "debug.h"

// Library version components; SCCL_VERSION() folds them into one integer.
#define SCCL_MAJOR 1
#define SCCL_MINOR 0
#define SCCL_PATCH 0
#define SCCL_SUFFIX ""

// Encodes version X.Y.Z as X*1000 + Y*100 + Z.
// NOTE(review): this encoding is only unambiguous while Y < 10 and Z < 100
// (e.g. 1.10.0 and 2.0.0 both encode to 2000).
#define SCCL_VERSION(X, Y, Z) ((Z) + (Y) * 100 + (X) * 1000)

namespace sccl {
/**
 * @brief Result codes returned by SCCL functions.
 *
 * scclSuccess and scclInProgress are the two values that the SCCLCHECK /
 * SCCLCHECKGOTO macros in this header do not treat as failures.
 * scclNumResults is a count of the codes, not a real result.
 */
typedef enum {
    scclSuccess = 0,           // no error
    scclUnhandledHipError = 1, // unhandled HIP error
    scclSystemError = 2,       // unhandled system error
    scclInternalError = 3,     // internal error - please report it to the SCCL developers
    scclInvalidArgument = 4,   // invalid argument
    scclInvalidUsage = 5,      // invalid usage
    scclRemoteError = 6,       // remote process exited or a network error occurred
    scclInProgress = 7,        // operation is still in progress
    scclNumResults = 8         // number of result codes above
} scclResult_t;

/**
 * @brief Result codes for test code.
 * NOTE(review): no users are visible in this header; the meanings are only
 * what the enumerator names suggest - confirm against the test harness.
 */
typedef enum {
    testSuccess = 0,
    testInternalError = 1,
    testHipError = 2,
    testScclError = 3,
    testTimeout = 4,
    testNumResults = 5 // number of codes above
} testResult_t;

/**
 * @brief Map an scclResult_t to a human-readable description.
 *
 * @param code result code to describe.
 * @return pointer to a string literal (never NULL, must not be freed).
 *         Values without a dedicated message, including scclNumResults,
 *         yield "unknown result code".
 *
 * Declared `static inline` because it is defined in a header: a plain
 * `static` definition gives every including translation unit its own copy
 * and triggers -Wunused-function wherever it is not called.
 */
static inline const char* scclGetErrorString(scclResult_t code) {
    switch(code) {
        case scclSuccess: return "success";
        case scclUnhandledHipError: return "unhandled hip error (run with SCCL_DEBUG=INFO for details)";
        case scclSystemError: return "unhandled system error (run with SCCL_DEBUG=INFO for details)";
        case scclInternalError: return "internal error - please report this issue to the SCCL developers";
        case scclInvalidArgument: return "invalid argument (run with SCCL_DEBUG=WARN for details)";
        case scclInvalidUsage: return "invalid usage (run with SCCL_DEBUG=WARN for details)";
        case scclRemoteError: return "remote process exited or there was a network error";
        case scclInProgress: return "SCCL operation in progress";
        default: return "unknown result code";
    }
}

////////////////////////////// SCCL and HIP //////////////////////////////

// Propagate errors up.
//
// SCCLCHECK(call): evaluate `call` once (it must yield an scclResult_t). Any
// result other than scclSuccess / scclInProgress is logged with the call site
// and returned from the enclosing function.
// The local has a distinctive name so it cannot capture a caller variable:
// with a generic name such as `RES`, `SCCLCHECK(f(RES))` would expand to the
// self-initialization `scclResult_t RES = f(RES);`.
//
// NOTE(review): the expansion ends in `while(0);`, so
// `if (c) SCCLCHECK(x); else ...` does not compile. The trailing semicolon is
// kept because existing callers may omit their own.
#define SCCLCHECK(call)                                                                   \
    do {                                                                                  \
        scclResult_t scclCheckRes_ = (call);                                              \
        if(scclCheckRes_ != scclSuccess && scclCheckRes_ != scclInProgress) {             \
            /* Print the back trace: one line per propagation level. */                   \
            INFO(SCCL_LOG_CODEALL, "%s:%d check fail: %s", __FILE__, __LINE__,            \
                 scclGetErrorString(scclCheckRes_));                                      \
            return scclCheckRes_;                                                         \
        }                                                                                 \
    } while(0);

// SCCLCHECKGOTO(call, RES, label): store the result of `call` in RES and,
// unless it is scclSuccess / scclInProgress, log it and `goto label`.
// Same trailing-semicolon caveat as SCCLCHECK.
#define SCCLCHECKGOTO(call, RES, label)                                                        \
    do {                                                                                       \
        RES = (call);                                                                          \
        if(RES != scclSuccess && RES != scclInProgress) {                                      \
            INFO(SCCL_LOG_CODEALL, "%s:%d %s -> %d", __func__, __LINE__, __FILE__, (int)(RES)); \
            goto label;                                                                        \
        }                                                                                      \
    } while(0);

// HIPCHECK(cmd): evaluate a HIP runtime call once. On failure, log the host
// name, the call site and the HIP error string, then return
// scclUnhandledHipError from the enclosing function.
// Locals use distinctive names so they cannot capture caller variables
// appearing inside `cmd`.
#define HIPCHECK(cmd)                                                                                \
    do {                                                                                             \
        hipError_t hipCheckErr_ = (cmd);                                                             \
        if(hipCheckErr_ != hipSuccess) {                                                             \
            char hipCheckHost_[1024];                                                                \
            /* gethostname() may fail, or truncate without a terminator. */                          \
            if(gethostname(hipCheckHost_, sizeof(hipCheckHost_)) != 0) {                             \
                hipCheckHost_[0] = '\0';                                                             \
            }                                                                                        \
            hipCheckHost_[sizeof(hipCheckHost_) - 1] = '\0';                                         \
            INFO(SCCL_LOG_CODEALL, "%s: Test HIP failure %s:%d '%s'\n", hipCheckHost_, __FILE__,     \
                 __LINE__, hipGetErrorString(hipCheckErr_));                                         \
            return scclUnhandledHipError;                                                            \
        }                                                                                            \
    } while(0)

// HIPCHECKGOTO(cmd, RES, label): evaluate a HIP runtime call once. On failure,
// WARN with the HIP error string, store scclUnhandledHipError in RES and
// `goto label`.
#define HIPCHECKGOTO(cmd, RES, label)                                     \
    do {                                                                  \
        hipError_t hipCheckErr_ = (cmd);                                  \
        if(hipCheckErr_ != hipSuccess) {                                  \
            WARN("HIP failure '%s'", hipGetErrorString(hipCheckErr_));    \
            RES = scclUnhandledHipError;                                  \
            goto label;                                                   \
        }                                                                 \
    } while(0)

////////////////////////////// Value checks //////////////////////////////

// Value checks for calls that report failure through errno.
// Each macro evaluates `statement` once and compares it with `value`. When
// the check trips, it logs "file:line -> scclSystemError (strerror(errno))".
// It then either returns scclSystemError or stores it in RES and jumps to
// `label`. Log arguments are passed in the order the format string expects:
// file (%s), line (%d), code (%d), errno text (%s).
//
// NOTE(review): these expansions end in `while(0);`, so they cannot be the
// body of an if that has an else. The semicolon is kept for callers that omit
// their own.

// Fail (return scclSystemError) when (statement) == (value).
#define EQCHECK(statement, value)                                                                  \
    do {                                                                                           \
        if((statement) == (value)) {                                                               \
            /* Print the back trace */                                                             \
            INFO(SCCL_LOG_CODEALL, "%s:%d -> %d (%s)", __FILE__, __LINE__, (int)scclSystemError,   \
                 strerror(errno));                                                                 \
            return scclSystemError;                                                                \
        }                                                                                          \
    } while(0);

// Fail (RES = scclSystemError; goto label) when (statement) == (value).
#define EQCHECKGOTO(statement, value, RES, label)                                                  \
    do {                                                                                           \
        if((statement) == (value)) {                                                               \
            /* Print the back trace */                                                             \
            RES = scclSystemError;                                                                 \
            INFO(SCCL_LOG_CODEALL, "%s:%d -> %d (%s)", __FILE__, __LINE__, (int)(RES),             \
                 strerror(errno));                                                                 \
            goto label;                                                                            \
        }                                                                                          \
    } while(0);

// Fail (return scclSystemError) when (statement) != (value).
#define NEQCHECK(statement, value)                                                                 \
    do {                                                                                           \
        if((statement) != (value)) {                                                               \
            /* Print the back trace */                                                             \
            INFO(SCCL_LOG_CODEALL, "%s:%d -> %d (%s)", __FILE__, __LINE__, (int)scclSystemError,   \
                 strerror(errno));                                                                 \
            return scclSystemError;                                                                \
        }                                                                                          \
    } while(0);

// Fail (RES = scclSystemError; goto label) when (statement) != (value).
#define NEQCHECKGOTO(statement, value, RES, label)                                                 \
    do {                                                                                           \
        if((statement) != (value)) {                                                               \
            /* Print the back trace */                                                             \
            RES = scclSystemError;                                                                 \
            INFO(SCCL_LOG_CODEALL, "%s:%d -> %d (%s)", __FILE__, __LINE__, (int)(RES),             \
                 strerror(errno));                                                                 \
            goto label;                                                                            \
        }                                                                                          \
    } while(0);

// Fail (return scclSystemError) when (statement) <= (value).
#define LECHECK(statement, value)                                                                  \
    do {                                                                                           \
        if((statement) <= (value)) {                                                               \
            /* Print the back trace */                                                             \
            INFO(SCCL_LOG_CODEALL, "%s:%d -> %d (%s)", __FILE__, __LINE__, (int)scclSystemError,   \
                 strerror(errno));                                                                 \
            return scclSystemError;                                                                \
        }                                                                                          \
    } while(0);

// Fail (return scclSystemError) when (statement) < (value).
#define LTCHECK(statement, value)                                                                  \
    do {                                                                                           \
        if((statement) < (value)) {                                                                \
            /* Print the back trace */                                                             \
            INFO(SCCL_LOG_CODEALL, "%s:%d -> %d (%s)", __FILE__, __LINE__, (int)scclSystemError,   \
                 strerror(errno));                                                                 \
            return scclSystemError;                                                                \
        }                                                                                          \
    } while(0);

////////////////////////////// SYS //////////////////////////////

// Check system calls.
//
// SYSCHECKSYNC(call, name, retval): assign `call` to `retval`, re-issuing the
// call for as long as it returns -1 with errno set to EINTR, EWOULDBLOCK or
// EAGAIN. Each retry is logged. `name` must be a string literal because it is
// pasted into the log message.
// NOTE(review): EAGAIN/EWOULDBLOCK are retried immediately with no wait, so a
// non-blocking descriptor that stays unready makes this spin.
#define SYSCHECKSYNC(call, name, retval)                                                    \
    do {                                                                                    \
        retval = call;                                                                      \
        if(retval != -1 || (errno != EINTR && errno != EWOULDBLOCK && errno != EAGAIN)) {   \
            break;                                                                          \
        }                                                                                   \
        INFO(SCCL_LOG_CODEALL, "Call to " name " returned %s, retrying", strerror(errno)); \
    } while(1)

// SYSCHECKVAL(call, name, retval): SYSCHECKSYNC, then a final result of -1 is
// reported with WARN and scclSystemError is returned from the enclosing
// function.
#define SYSCHECKVAL(call, name, retval)                            \
    do {                                                           \
        SYSCHECKSYNC(call, name, retval);                          \
        if(-1 == retval) {                                         \
            WARN("Call to " name " failed : %s", strerror(errno)); \
            return scclSystemError;                                \
        }                                                          \
    } while(0)

// SYSCHECK(call, name): SYSCHECKVAL with a scratch int holding the result.
#define SYSCHECK(call, name)             \
    do {                                 \
        int retval;                      \
        SYSCHECKVAL(call, name, retval); \
    } while(0)

// SYSCHECKGOTO(statement, RES, label): if `statement` evaluates to -1, store
// scclSystemError in RES, log "file:line -> code (strerror(errno))" and jump
// to `label`. Log arguments follow the format string's order:
// file (%s), line (%d), code (%d), errno text (%s).
// NOTE(review): the expansion ends in `while(0);` (kept for callers that omit
// their own semicolon), so it cannot be the body of an if that has an else.
#define SYSCHECKGOTO(statement, RES, label)                                              \
    do {                                                                                 \
        if((statement) == -1) {                                                          \
            /* Print the back trace */                                                   \
            RES = scclSystemError;                                                       \
            INFO(SCCL_LOG_CODEALL, "%s:%d -> %d (%s)", __FILE__, __LINE__, (int)(RES),   \
                 strerror(errno));                                                       \
            goto label;                                                                  \
        }                                                                                \
    } while(0);

} // namespace sccl
