#include "rocm_smi_wrap.h"
#include "topo_utils.h"

#include <cstdio>

namespace sccl {
namespace hardware {
namespace topology {
namespace bootstrap {

/**
 * Evaluate a rocm_smi call. On any status other than RSMI_STATUS_SUCCESS, log a
 * warning naming the failing call and return scclInternalError from the
 * enclosing function.
 *
 * The status lives in a uniquely named local, so `cmd` may safely mention a
 * caller variable called `ret`. The error text starts from a fallback, so the
 * log never reads an uninitialized pointer if rsmi_status_string itself fails.
 */
#define ROCMSMICHECK(cmd)                                                               \
    do {                                                                                \
        rsmi_status_t rsmiCheckRet_ = (cmd);                                            \
        if(rsmiCheckRet_ != RSMI_STATUS_SUCCESS) {                                      \
            const char* rsmiCheckErr_ = nullptr;                                        \
            if(rsmi_status_string(rsmiCheckRet_, &rsmiCheckErr_) != RSMI_STATUS_SUCCESS \
               || rsmiCheckErr_ == nullptr) {                                           \
                rsmiCheckErr_ = "unknown error";                                        \
            }                                                                           \
            WARN("ROCm SMI call %s failed: %s", #cmd, rsmiCheckErr_);                   \
            return scclInternalError;                                                   \
        }                                                                               \
    } while(false)

/**
 * Initialize the ROCm SMI library and log its version.
 *
 * @return scclSuccess when both rsmi_init and rsmi_version_get succeed;
 *         otherwise the error produced by ROCMSMICHECK.
 * @note Emits the library version (major.minor.patch.build) under SCCL_LOG_TOPO.
 */
scclResult_t rocm_smi_init() {
    // Start the library with no init flags.
    ROCMSMICHECK(rsmi_init(0));

    rsmi_version_t libVersion;
    ROCMSMICHECK(rsmi_version_get(&libVersion));

    INFO(SCCL_LOG_TOPO,
         "rocm_smi_lib: version %d.%d.%d.%s",
         libVersion.major,
         libVersion.minor,
         libVersion.patch,
         libVersion.build);

    return scclSuccess;
}

/**
 * Query how many devices rocm_smi can monitor.
 *
 * @param num_devs [out] receives the device count reported by
 *                 rsmi_num_monitor_devices; the pointer is forwarded unchanged.
 * @return scclSuccess on success, otherwise the error produced by ROCMSMICHECK.
 */
scclResult_t rocm_smi_getNumDevice(uint32_t* num_devs) {
    // Forward directly to rocm_smi; ROCMSMICHECK translates a failing status.
    ROCMSMICHECK(rsmi_num_monitor_devices(num_devs));

    return scclSuccess;
}

/**
 * Format the PCI bus ID of a rocm_smi device as "DDDD:BB:DD.F" (hex).
 *
 * @param deviceIndex rocm_smi device index
 * @param busId       [out] caller-provided buffer that receives the
 *                    NUL-terminated bus ID string
 * @param len         size of busId in bytes
 * @return scclSuccess on success; scclInternalError if the buffer is missing
 *         or empty, the rocm_smi query fails, or the formatted ID does not fit.
 */
scclResult_t rocm_smi_getDevicePciBusIdString(uint32_t deviceIndex, char* busId, size_t len) {
    if(busId == nullptr || len == 0) {
        WARN("rocm_smi_lib: invalid bus ID buffer for device %u", deviceIndex);
        return scclInternalError;
    }

    uint64_t id;
    ROCMSMICHECK(rsmi_dev_pci_id_get(deviceIndex, &id));
    /** rocm_smi's bus ID format
     *  | Name     | Field   |
     *  ---------- | ------- |
     *  | Domain   | [63:32] |
     *  | Reserved | [31:16] |
     *  | Bus      | [15: 8] |
     *  | Device   | [ 7: 3] |
     *  | Function | [ 2: 0] |
     **/
    // Cast to unsigned long long so %llx matches on every platform
    // (uint64_t is not necessarily unsigned long).
    const unsigned long long domain   = static_cast<unsigned long long>(id >> 32);
    const unsigned long long bus      = static_cast<unsigned long long>((id & 0xff00) >> 8);
    const unsigned long long device   = static_cast<unsigned long long>((id & 0xf8) >> 3);
    const unsigned long long function = static_cast<unsigned long long>(id & 0x7);

    // snprintf bounds the write to len bytes and NUL-terminates the result.
    const int written = snprintf(busId, len, "%04llx:%02llx:%02llx.%01llx", domain, bus, device, function);
    if(written < 0 || static_cast<size_t>(written) >= len) {
        WARN("rocm_smi_lib: bus ID buffer of %zu bytes too small for device %u", len, deviceIndex);
        return scclInternalError;
    }
    return scclSuccess;
}

/**
 * Find the rocm_smi device index whose PCI bus ID matches pciBusId.
 *
 * @param pciBusId    bus ID string parsed by busIdToInt64
 * @param deviceIndex [out] receives the matching rocm_smi device index
 * @return scclSuccess if a device matches; scclInternalError if none does or a
 *         rocm_smi query fails.
 */
scclResult_t rocm_smi_getDeviceIndexByPciBusId(const char* pciBusId, uint32_t* deviceIndex) {
    int64_t parsed;
    busIdToInt64(pciBusId, &parsed);

    /** Repack the parsed value into rocm_smi's bus ID format
     *  | Name     | Field   |
     *  ---------- | ------- |
     *  | Domain   | [63:32] |
     *  | Reserved | [31:16] |
     *  | Bus      | [15: 8] |
     *  | Device   | [ 7: 3] |
     *  | Function | [ 2: 0] |
     **/
    const int64_t domainPart   = (parsed & 0xffff00000L) << 12;
    const int64_t busPart      = (parsed & 0xff000L) >> 4;
    const int64_t devicePart   = (parsed & 0xff0L) >> 1;
    const int64_t functionPart = parsed & 0x7L;
    const int64_t wanted       = domainPart + busPart + devicePart + functionPart;

    uint32_t deviceCount = 0;
    ROCMSMICHECK(rsmi_num_monitor_devices(&deviceCount));

    for(uint32_t idx = 0; idx < deviceCount; ++idx) {
        uint64_t bdfid;
        ROCMSMICHECK(rsmi_dev_pci_id_get(idx, &bdfid));
        // Same conversion the mixed-sign == comparison performs implicitly.
        if(bdfid == static_cast<uint64_t>(wanted)) {
            *deviceIndex = idx;
            return scclSuccess;
        }
    }

    WARN("rocm_smi_lib: %s device index not found", pciBusId);
    return scclInternalError;
}

/**
 * Query the link between two rocm_smi devices.
 *
 * @param srcIndex  source device index
 * @param dstIndex  destination device index
 * @param rsmi_type [out] link type reported by rsmi_topo_get_link_type
 * @param hops      [out] 2 by default; 1 when the link is XGMI with weight 15
 * @param count     [out] 1 by default; for an XGMI link of weight 15 on
 *                  rocm_smi >= 2, max_bw / min_bw when both are non-zero
 *
 * @return scclSuccess on success, otherwise the error produced by ROCMSMICHECK.
 */
scclResult_t rocm_smi_getLinkInfo(int srcIndex, int dstIndex, RSMI_IO_LINK_TYPE* rsmi_type, int* hops, int* count) {
    uint64_t rsmi_hops, rsmi_weight;
    ROCMSMICHECK(rsmi_topo_get_link_type(srcIndex, dstIndex, &rsmi_hops, rsmi_type));
    ROCMSMICHECK(rsmi_topo_get_link_weight(srcIndex, dstIndex, &rsmi_weight));

    *hops  = 2;
    *count = 1;

    // Only an XGMI link with weight 15 is refined further; everything else
    // keeps the defaults above.
    if(*rsmi_type != RSMI_IOLINK_TYPE_XGMI || rsmi_weight != 15) {
        return scclSuccess;
    }
    *hops = 1;

    // rsmi_minmax_bandwidth_get is only queried on library major version >= 2
    // (runtime check); on older versions both bandwidths stay 0 and count stays 1.
    uint64_t min_bw = 0, max_bw = 0;
    rsmi_version_t version;
    ROCMSMICHECK(rsmi_version_get(&version));
    if(version.major >= 2) {
        ROCMSMICHECK(rsmi_minmax_bandwidth_get(srcIndex, dstIndex, &min_bw, &max_bw));
    }
    if(max_bw && min_bw) {
        *count = static_cast<int>(max_bw / min_bw);
    }

    // Cast so %llu matches the argument type on every platform.
    INFO(SCCL_LOG_GRAPH,
         "rocm smi srcIndex:%d dstIndex:%d min_bw:%llu max_bw:%llu count:%d",
         srcIndex,
         dstIndex,
         static_cast<unsigned long long>(min_bw),
         static_cast<unsigned long long>(max_bw),
         *count);
    return scclSuccess;
}

} // namespace bootstrap
} // namespace topology
} // namespace hardware
} // namespace sccl
