#pragma once

#include <sys/mman.h>
#include <unistd.h>
#include <stdlib.h>
#include <string.h>
#include "check.h"
#include "align.h"
#include "asm_ops.h"

namespace sccl {
namespace alloc {

template <typename T>
/**
 * @brief 使用HIP分配并初始化主机内存（带调试信息）
 *
 * 该函数使用HIP API分配主机内存，并将内存初始化为0。支持HIP流捕获模式切换，
 * 并记录分配调试信息（文件/行号）。
 *
 * @tparam T 要分配的数据类型
 * @param[out] ptr 指向分配内存的指针
 * @param[in] nelem 要分配的元素数量
 * @param[in] filefunc 调用位置的文件/函数名（调试用）
 * @param[in] line 调用位置的行号（调试用）
 * @return scclResult_t 返回操作结果（scclSuccess表示成功）
 *
 * @note 分配失败时会输出警告日志，成功时会记录分配信息
 */
scclResult_t scclHipHostCallocDebug(T** ptr, size_t nelem, const char* filefunc, int line) {
    scclResult_t result       = scclSuccess;
    hipStreamCaptureMode mode = hipStreamCaptureModeRelaxed;
    *ptr                      = nullptr;
    HIPCHECK(hipThreadExchangeStreamCaptureMode(&mode));
    HIPCHECKGOTO(hipHostMalloc(ptr, nelem * sizeof(T), hipHostMallocMapped), result, finish);
    memset(*ptr, 0, nelem * sizeof(T));
finish:
    HIPCHECK(hipThreadExchangeStreamCaptureMode(&mode));
    if(*ptr == nullptr)
        WARN("Failed to HIP host alloc %ld bytes", nelem * sizeof(T));
    INFO(SCCL_LOG_ALLOC, "%s:%d Hip Host Alloc Size %ld pointer %p", filefunc, line, nelem * sizeof(T), *ptr);
    return result;
}

inline scclResult_t scclHipHostFree(void* ptr) {
    HIPCHECK(hipHostFree(ptr));
    return scclSuccess;
}

/**
 * @brief 分配调试内存
 *
 * 为类型T分配指定数量的元素内存，并记录调试信息。
 *
 * @param[out] ptr 指向分配内存的指针的指针
 * @param[in] nelem 要分配的元素数量
 * @param[in] filefunc 调用位置的文件/函数信息
 * @param[in] line 调用位置的行号
 *
 * @return scclResult_t 返回操作结果，成功返回scclSuccess，失败返回scclSystemError
 */
template <typename T>
scclResult_t scclMallocDebug(T** ptr, size_t nelem, const char* filefunc, int line) {
    void* p = malloc(nelem * sizeof(T));
    if(p == NULL) {
        WARN("Failed to malloc %ld bytes", nelem * sizeof(T));
        return scclSystemError;
    }
    INFO(SCCL_LOG_ALLOC, "%s:%d malloc Size %ld pointer %p", filefunc, line, nelem * sizeof(T), p);
    *ptr = (T*)p;
    return scclSuccess;
}

/**
 * @brief 分配并清零指定数量的元素内存（调试版本）
 *
 * @tparam T 元素类型
 * @param[out] ptr 指向分配内存的指针
 * @param nelem 要分配的元素数量
 * @param filefunc 调用位置的文件/函数信息（用于调试）
 * @param line 调用位置的行号（用于调试）
 * @return scclResult_t 返回操作结果，scclSuccess表示成功
 *
 * @note 此函数会记录内存分配日志，并在失败时返回错误
 */
template <typename T>
scclResult_t scclCallocDebug(T** ptr, size_t nelem, const char* filefunc, int line) {
    void* p = malloc(nelem * sizeof(T));
    if(p == NULL) {
        WARN("Failed to malloc %ld bytes", nelem * sizeof(T));
        return scclSystemError;
    }
    INFO(SCCL_LOG_ALLOC, "%s:%d malloc Size %ld pointer %p", filefunc, line, nelem * sizeof(T), p);
    memset(p, 0, nelem * sizeof(T));
    *ptr = (T*)p;
    return scclSuccess;
}

template <typename T>
scclResult_t scclRealloc(T** ptr, size_t oldNelem, size_t nelem) {
    if(nelem < oldNelem)
        return scclInternalError;
    if(nelem == oldNelem)
        return scclSuccess;

    T* oldp = *ptr;
    T* p    = (T*)malloc(nelem * sizeof(T));
    if(p == NULL) {
        WARN("Failed to malloc %ld bytes", nelem * sizeof(T));
        return scclSystemError;
    }
    memcpy(p, oldp, oldNelem * sizeof(T));
    free(oldp);
    memset(p + oldNelem, 0, (nelem - oldNelem) * sizeof(T));
    *ptr = (T*)p;
    INFO(SCCL_LOG_ALLOC, "Mem Realloc old size %ld, new size %ld pointer %p", oldNelem * sizeof(T), nelem * sizeof(T), *ptr);
    return scclSuccess;
}

struct __attribute__((aligned(64))) allocationTracker {
    union {
        struct {
            uint64_t totalAlloc;
            uint64_t totalAllocSize;
        };
        char align[64];
    };
};
static_assert(sizeof(struct allocationTracker) == 64, "allocationTracker must be size of 64 bytes");
static constexpr int MAX_ALLOC_TRACK_NGPU = 32;
extern struct allocationTracker allocTracker[];

static int scclCuMemEnable() { return 0; }

static inline scclResult_t scclCuMemAlloc(void** ptr, void* handlep, size_t size) {
    WARN("CUMEM not supported prior to HIP 11.3");
    return scclInternalError;
}
static inline scclResult_t scclCuMemFree(void* ptr) {
    WARN("CUMEM not supported prior to HIP 11.3");
    return scclInternalError;
}

template <typename T>
/**
 * @brief 使用HIP分配设备内存（带调试信息）
 *
 * @tparam T 数据类型
 * @param filefunc 调用位置的文件/函数信息
 * @param line 调用位置的行号
 * @param[out] ptr 分配的内存指针
 * @param nelem 元素数量
 * @param isFineGrain 是否使用细粒度内存（默认为false）
 * @return scclResult_t 返回操作结果状态码
 *
 * @note 此函数会记录分配大小和指针地址的调试信息
 *       支持细粒度内存分配选项，并根据HIP_UNCACHED_MEMORY宏选择分配方式
 *       自动处理流捕获模式切换
 */
scclResult_t scclHipMallocDebug(const char* filefunc, int line, T** ptr, size_t nelem, bool isFineGrain = false) {
    scclResult_t result       = scclSuccess;
    hipStreamCaptureMode mode = hipStreamCaptureModeRelaxed;
    *ptr                      = nullptr;
    HIPCHECK(hipThreadExchangeStreamCaptureMode(&mode));
    if(isFineGrain) {
#if defined(HIP_UNCACHED_MEMORY)
        HIPCHECKGOTO(hipExtMallocWithFlags((void**)ptr, nelem * sizeof(T), hipDeviceMallocUncached), result, finish);
#else
        HIPCHECKGOTO(hipExtMallocWithFlags((void**)ptr, nelem * sizeof(T), hipDeviceMallocFinegrained), result, finish);
#endif
    } else
        HIPCHECKGOTO(hipMalloc(ptr, nelem * sizeof(T)), result, finish);
finish:
    HIPCHECK(hipThreadExchangeStreamCaptureMode(&mode));
    if(*ptr == nullptr)
        WARN("Failed to HIP malloc %ld bytes", nelem * sizeof(T));
    INFO(SCCL_LOG_ALLOC, "%s:%d Hip Alloc Size %ld pointer %p", filefunc, line, nelem * sizeof(T), *ptr);
    return result;
}

template <typename T>
/**
 * @brief 使用HIP分配并清零设备内存（调试版本）
 *
 * @tparam T 数据类型
 * @param filefunc 调用源文件名/函数名（用于调试）
 * @param line 调用行号（用于调试）
 * @param[out] ptr 分配的设备内存指针
 * @param nelem 元素数量
 * @param sideStream 可选侧边流（避免干扰图捕获）
 * @param isFineGrain 是否使用细粒度内存
 * @return scclResult_t 返回操作结果状态码
 *
 * @note 1. 会自动跟踪分配统计
 *       2. 支持细粒度内存分配（需HSA支持）
 *       3. 使用异步方式清零内存
 *       4. 会临时修改流捕获模式
 */
scclResult_t scclHipCallocDebug(const char* filefunc, int line, T** ptr, size_t nelem, hipStream_t sideStream = nullptr, bool isFineGrain = false) {
    scclResult_t result = scclSuccess;
    extern bool hsaFineGrainFlag;
    hipStreamCaptureMode mode = hipStreamCaptureModeRelaxed;
    *ptr                      = nullptr;
    HIPCHECK(hipThreadExchangeStreamCaptureMode(&mode));
    // Need a side stream so as not to interfere with graph capture.
    hipStream_t stream = sideStream;
    if(stream == nullptr)
        HIPCHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking));
    if(isFineGrain && hsaFineGrainFlag) {
#if defined(HIP_UNCACHED_MEMORY)
        HIPCHECKGOTO(hipExtMallocWithFlags((void**)ptr, nelem * sizeof(T), hipDeviceMallocUncached), result, finish);
#else
        HIPCHECKGOTO(hipExtMallocWithFlags((void**)ptr, nelem * sizeof(T), hipDeviceMallocFinegrained), result, finish);
#endif
    } else
        HIPCHECKGOTO(hipMalloc(ptr, nelem * sizeof(T)), result, finish);
    HIPCHECKGOTO(hipMemsetAsync(*ptr, 0, nelem * sizeof(T), stream), result, finish);
    HIPCHECKGOTO(hipStreamSynchronize(stream), result, finish);
    if(sideStream == nullptr)
        HIPCHECKGOTO(hipStreamDestroy(stream), result, finish);
    int dev;
    HIPCHECK(hipGetDevice(&dev));
    if(dev < MAX_ALLOC_TRACK_NGPU) {
        asm_ops::add_ref_count_relaxed(&allocTracker[dev].totalAlloc, 1);
        asm_ops::add_ref_count_relaxed(&allocTracker[dev].totalAllocSize, nelem * sizeof(T));
    }
finish:
    HIPCHECK(hipThreadExchangeStreamCaptureMode(&mode));
    if(*ptr == nullptr)
        WARN("Failed to HIP calloc %ld bytes", nelem * sizeof(T));
    INFO(SCCL_LOG_ALLOC, "%s:%d Hip Alloc Size %ld pointer %p", filefunc, line, nelem * sizeof(T), *ptr);
    return result;
}

template <typename T>
/**
 * @brief 异步分配并初始化HIP设备内存（调试版本）
 *
 * 该函数用于在HIP设备上异步分配内存并将其初始化为0，支持细粒度内存分配选项。
 * 同时会跟踪内存分配情况并记录调试信息。
 *
 * @tparam T 数据类型
 * @param filefunc 调用位置的文件名和函数名（用于调试）
 * @param line 调用位置的行号（用于调试）
 * @param[out] ptr 指向分配内存的指针
 * @param nelem 要分配的元素数量
 * @param stream HIP流，用于异步操作
 * @param isFineGrain 是否使用细粒度内存分配（默认为false）
 * @return scclResult_t 返回操作结果（scclSuccess表示成功）
 *
 * @note 该函数会修改全局内存分配跟踪器，并记录分配日志
 * @warning 分配失败时会输出警告信息
 */
scclResult_t scclHipCallocAsyncDebug(const char* filefunc, int line, T** ptr, size_t nelem, hipStream_t stream, bool isFineGrain = false) {
    scclResult_t result       = scclSuccess;
    hipStreamCaptureMode mode = hipStreamCaptureModeRelaxed;
    *ptr                      = nullptr;
    HIPCHECK(hipThreadExchangeStreamCaptureMode(&mode));
    if(isFineGrain) {
#if defined(HIP_UNCACHED_MEMORY)
        HIPCHECKGOTO(hipExtMallocWithFlags((void**)ptr, nelem * sizeof(T), hipDeviceMallocUncached), result, finish);
#else
        HIPCHECKGOTO(hipExtMallocWithFlags((void**)ptr, nelem * sizeof(T), hipDeviceMallocFinegrained), result, finish);
#endif
    } else
        HIPCHECKGOTO(hipMalloc(ptr, nelem * sizeof(T)), result, finish);
    HIPCHECKGOTO(hipMemsetAsync(*ptr, 0, nelem * sizeof(T), stream), result, finish);
    int dev;
    HIPCHECK(hipGetDevice(&dev));
    if(dev < MAX_ALLOC_TRACK_NGPU) {
        asm_ops::add_ref_count_relaxed(&allocTracker[dev].totalAlloc, 1);
        asm_ops::add_ref_count_relaxed(&allocTracker[dev].totalAllocSize, nelem * sizeof(T));
    }
finish:
    HIPCHECK(hipThreadExchangeStreamCaptureMode(&mode));
    if(*ptr == nullptr)
        WARN("Failed to HIP calloc async %ld bytes", nelem * sizeof(T));
    INFO(SCCL_LOG_ALLOC, "%s:%d Hip Alloc Size %ld pointer %p", filefunc, line, nelem * sizeof(T), *ptr);
    return result;
}

template <typename T>
/**
 * 异步执行HIP内存拷贝操作
 *
 * @tparam T 数据类型模板参数
 * @param dst 目标内存地址
 * @param src 源内存地址
 * @param nelem 要拷贝的元素数量
 * @param stream HIP流对象
 * @return scclResult_t 返回操作结果，成功返回scclSuccess
 *
 * @note 此函数会临时修改流捕获模式为hipStreamCaptureModeRelaxed，
 *       并在操作完成后恢复原始模式
 */
scclResult_t scclHipMemcpyAsync(T* dst, T* src, size_t nelem, hipStream_t stream) {
    scclResult_t result       = scclSuccess;
    hipStreamCaptureMode mode = hipStreamCaptureModeRelaxed;
    HIPCHECK(hipThreadExchangeStreamCaptureMode(&mode));
    HIPCHECKGOTO(hipMemcpyAsync(dst, src, nelem * sizeof(T), hipMemcpyDefault, stream), result, finish);
finish:
    HIPCHECK(hipThreadExchangeStreamCaptureMode(&mode));
    return result;
}

template <typename T>
/**
 * @brief 使用HIP在主机和设备间同步拷贝数据
 *
 * 该函数创建一个非阻塞流执行异步内存拷贝，并同步等待完成。
 * 使用hipStreamCaptureModeRelaxed模式避免干扰图捕获操作。
 *
 * @tparam T 数据类型模板参数
 * @param dst 目标内存地址
 * @param src 源内存地址
 * @param nelem 要拷贝的元素数量
 * @return scclResult_t 返回操作结果状态码
 */
scclResult_t scclHipMemcpy(T* dst, T* src, size_t nelem) {
    scclResult_t result       = scclSuccess;
    hipStreamCaptureMode mode = hipStreamCaptureModeRelaxed;
    HIPCHECK(hipThreadExchangeStreamCaptureMode(&mode));
    // Need a side stream so as not to interfere with graph capture.
    hipStream_t stream;
    HIPCHECKGOTO(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking), result, finish);
    SCCLCHECKGOTO(scclHipMemcpyAsync(dst, src, nelem, stream), result, finish);
    HIPCHECKGOTO(hipStreamSynchronize(stream), result, finish);
    HIPCHECKGOTO(hipStreamDestroy(stream), result, finish);
finish:
    HIPCHECK(hipThreadExchangeStreamCaptureMode(&mode));
    return result;
}

template <typename T>
/**
 * @brief 释放HIP设备内存
 *
 * 该函数用于释放通过HIP分配的设备内存指针。支持两种模式：
 * 1. 当启用CUDA内存管理时，调用scclCuMemFree释放
 * 2. 否则直接调用hipFree释放
 *
 * @tparam T 指针类型
 * @param ptr 要释放的设备内存指针
 * @return scclResult_t 返回操作结果，scclSuccess表示成功
 *
 * @note 函数会在执行前后自动处理HIP流捕获模式
 */
scclResult_t scclHipFree(T* ptr) {
    scclResult_t result       = scclSuccess;
    hipStreamCaptureMode mode = hipStreamCaptureModeRelaxed;
    INFO(SCCL_LOG_ALLOC, "Hip Free pointer %p", ptr);
    HIPCHECK(hipThreadExchangeStreamCaptureMode(&mode));
    if(scclCuMemEnable()) {
        SCCLCHECKGOTO(scclCuMemFree((void*)ptr), result, finish);
    } else {
        HIPCHECKGOTO(hipFree(ptr), result, finish);
    }
finish:
    HIPCHECK(hipThreadExchangeStreamCaptureMode(&mode));
    return result;
}

/**
 * @brief 分配对齐的内存并初始化为0（调试版本）
 *
 * 使用posix_memalign分配页面对齐的内存，并将内存区域初始化为0。
 * 记录分配信息到日志系统。
 *
 * @param[out] ptr 指向分配内存指针的指针
 * @param[in] size 请求分配的内存大小（字节）
 * @param[in] filefunc 调用位置的文件/函数信息（用于调试）
 * @param[in] line 调用位置的行号（用于调试）
 * @return scclResult_t 返回操作状态（scclSuccess或scclSystemError）
 */
inline scclResult_t scclIbMallocDebug(void** ptr, size_t size, const char* filefunc, int line) {
    size_t page_size = sysconf(_SC_PAGESIZE);
    void* p;
    int size_aligned = ROUNDUP(size, page_size);
    int ret          = posix_memalign(&p, page_size, size_aligned);
    if(ret != 0)
        return scclSystemError;
    memset(p, 0, size);
    *ptr = p;
    INFO(SCCL_LOG_ALLOC, "%s:%d Ib Alloc Size %ld pointer %p", filefunc, line, size, *ptr);
    return scclSuccess;
}

} // namespace alloc

// 定义宏 scclHipHostCalloc，用于调试版本的主机端内存分配，自动添加文件名和行号信息
#define scclHipHostCalloc(...) alloc::scclHipHostCallocDebug(__VA_ARGS__, __FILE__, __LINE__)

// 定义宏 scclCalloc，用于调试版本的常规内存分配，自动添加文件名和行号信息
#define scclMalloc(...) alloc::scclMallocDebug(__VA_ARGS__, __FILE__, __LINE__)

// 定义宏 scclCalloc，用于调试版本的常规内存分配，自动添加文件名和行号信息
#define scclCalloc(...) alloc::scclCallocDebug(__VA_ARGS__, __FILE__, __LINE__)

// 定义宏 scclHipMalloc，用于调试版本的 HIP (Heterogeneous-Compute Interface for Portability) 内存分配，自动添加文件名和行号信息
#define scclHipMalloc(...) alloc::scclHipMallocDebug(__FILE__, __LINE__, __VA_ARGS__)

// 定义宏 scclHipCalloc，用于调试版本的 HIP 内存清零分配，自动添加文件名和行号信息
#define scclHipCalloc(...) alloc::scclHipCallocDebug(__FILE__, __LINE__, __VA_ARGS__)

// 定义宏 scclHipCallocAsync，用于调试版本的异步 HIP 内存清零分配，自动添加文件名和行号信息
#define scclHipCallocAsync(...) alloc::scclHipCallocAsyncDebug(__FILE__, __LINE__, __VA_ARGS__)

// 定义宏 scclIbMalloc，用于调试版本的 InfiniBand 内存分配，自动添加文件名和行号信息
#define scclIbMalloc(...) alloc::scclIbMallocDebug(__VA_ARGS__, __FILE__, __LINE__)

///////////////////////////////////////// 内存申请和释放函数 /////////////////////////////////////////

inline scclResult_t scclHipHostFree(void* ptr) { return alloc::scclHipFree(ptr); }

template <typename T>
scclResult_t scclRealloc(T** ptr, size_t oldNelem, size_t nelem) {
    return alloc::scclRealloc(ptr, oldNelem, nelem);
}

template <typename T>
scclResult_t scclHipMemcpyAsync(T* dst, T* src, size_t nelem, hipStream_t stream) {
    return alloc::scclHipMemcpyAsync(dst, src, nelem, stream);
}

template <typename T>
scclResult_t scclHipMemcpy(T* dst, T* src, size_t nelem) {
    return alloc::scclHipMemcpy(dst, src, nelem);
}

template <typename T>
scclResult_t scclHipFree(T* ptr) {
    return alloc::scclHipFree(ptr);
}

} // namespace sccl
