#include "rocm_smi_wrap.h"
#include "topo_utils.h"

namespace sccl {
namespace hardware {
namespace topology {
namespace bootstrap {

/**
 * Evaluate a ROCm SMI call and, if it does not return RSMI_STATUS_SUCCESS,
 * log the library's status string and return scclInternalError from the
 * enclosing function.
 *
 * Only usable inside functions whose return type is scclResult_t.
 */
#define ROCMSMICHECK(cmd)                              \
    do {                                               \
        rsmi_status_t ret = (cmd);                     \
        if(ret != RSMI_STATUS_SUCCESS) {               \
            const char* err;                           \
            rsmi_status_string(ret, &err);             \
            WARN("ROCm SMI init failure %s", err);     \
            return scclInternalError;                  \
        }                                              \
    } while(false)

/**
 * Initialize the ROCm SMI library and log its version.
 *
 * @return scclSuccess on success; scclInternalError if rsmi_init() or
 *         rsmi_version_get() fails.
 * @note The library version is written to the SCCL_LOG_TOPO log.
 */
scclResult_t rocm_smi_init() {
    ROCMSMICHECK(rsmi_init(0));

    rsmi_version_t version;
    ROCMSMICHECK(rsmi_version_get(&version));
    INFO(SCCL_LOG_TOPO, "rocm_smi_lib: version %d.%d.%d.%s", version.major, version.minor, version.patch, version.build);
    return scclSuccess;
}

/**
 * Query the number of devices monitored by ROCm SMI.
 *
 * @param num_devs [out] receives the device count
 * @return scclSuccess on success; scclInternalError if the query fails.
 */
scclResult_t rocm_smi_getNumDevice(uint32_t* num_devs) {
    ROCMSMICHECK(rsmi_num_monitor_devices(num_devs));
    return scclSuccess;
}

/**
 * Format the PCI bus ID of a device as "dddd:bb:dd.f" (hexadecimal).
 *
 * @param deviceIndex ROCm SMI device index
 * @param busId       [out] destination buffer for the formatted string
 * @param len         size of busId in bytes; output is truncated to fit
 * @return scclSuccess on success; scclInternalError if the PCI ID query fails.
 */
scclResult_t rocm_smi_getDevicePciBusIdString(uint32_t deviceIndex, char* busId, size_t len) {
    uint64_t id;
    ROCMSMICHECK(rsmi_dev_pci_id_get(deviceIndex, &id));

    /** rocm_smi's bus ID format
     *  | Name     | Field   |
     *  ---------- | ------- |
     *  | Domain   | [63:32] |
     *  | Reserved | [31:16] |
     *  | Bus      | [15: 8] |
     *  | Device   | [ 7: 3] |
     *  | Function | [ 2: 0] |
     **/
    const unsigned long domain   = static_cast<unsigned long>(id >> 32);
    const unsigned long bus      = static_cast<unsigned long>((id & 0xff00) >> 8);
    const unsigned long device   = static_cast<unsigned long>((id & 0xf8) >> 3);
    const unsigned long function = static_cast<unsigned long>(id & 0x7);

    // Must be snprintf: the result has to land in the caller's buffer.
    // (A plain printf here would treat busId as the format string and
    // never fill the buffer.)
    snprintf(busId, len, "%04lx:%02lx:%02lx.%01lx", domain, bus, device, function);
    return scclSuccess;
}

/**
 * Find the ROCm SMI device index that corresponds to a PCI bus ID string.
 *
 * @param pciBusId    PCI bus ID string, parsed by busIdToInt64()
 * @param deviceIndex [out] receives the matching device index
 * @return scclSuccess if a device matches; scclInternalError if a ROCm SMI
 *         query fails or no device has the requested bus ID.
 */
scclResult_t rocm_smi_getDeviceIndexByPciBusId(const char* pciBusId, uint32_t* deviceIndex) {
    int64_t busid = 0;
    busIdToInt64(pciBusId, &busid);

    /** convert to rocm_smi's bus ID format
     *  | Name     | Field   |
     *  ---------- | ------- |
     *  | Domain   | [63:32] |
     *  | Reserved | [31:16] |
     *  | Bus      | [15: 8] |
     *  | Device   | [ 7: 3] |
     *  | Function | [ 2: 0] |
     **/
    // Repack in unsigned arithmetic so the comparison against the uint64_t
    // value reported by ROCm SMI is between like types.
    const uint64_t packed = static_cast<uint64_t>(busid);
    const uint64_t target = ((packed & 0xffff00000ULL) << 12) + ((packed & 0xff000ULL) >> 4) + ((packed & 0xff0ULL) >> 1) + (packed & 0x7ULL);

    uint32_t num_devs = 0;
    ROCMSMICHECK(rsmi_num_monitor_devices(&num_devs));

    for(uint32_t i = 0; i < num_devs; i++) {
        uint64_t bdfid;
        ROCMSMICHECK(rsmi_dev_pci_id_get(i, &bdfid));
        if(bdfid == target) {
            *deviceIndex = i;
            return scclSuccess;
        }
    }

    WARN("rocm_smi_lib: %s device index not found", pciBusId);
    return scclInternalError;
}

/**
 * Query link information between two ROCm devices.
 *
 * @param srcIndex  source device index
 * @param dstIndex  destination device index
 * @param rsmi_type [out] link type reported by ROCm SMI
 * @param hops      [out] 2 by default; 1 when the link type is XGMI and the
 *                  link weight is 15
 * @param count     [out] 1 by default; for an XGMI link with weight 15, set to
 *                  max_bw / min_bw when both bandwidths are non-zero
 *
 * @return scclSuccess on success; scclInternalError if a ROCm SMI call fails.
 *
 * @note The min/max bandwidth is only queried when the ROCm SMI library
 *       reports a major version >= 2; otherwise count stays at 1.
 */
scclResult_t rocm_smi_getLinkInfo(int srcIndex, int dstIndex, RSMI_IO_LINK_TYPE* rsmi_type, int* hops, int* count) {
    // rsmi_hops is required by the API as an output argument but is not used
    // here; the hop count returned to the caller is derived from the weight.
    uint64_t rsmi_hops, rsmi_weight;
    ROCMSMICHECK(rsmi_topo_get_link_type(srcIndex, dstIndex, &rsmi_hops, rsmi_type));
    ROCMSMICHECK(rsmi_topo_get_link_weight(srcIndex, dstIndex, &rsmi_weight));

    *hops  = 2;
    *count = 1;

    if(*rsmi_type == RSMI_IOLINK_TYPE_XGMI && rsmi_weight == 15) {
        *hops = 1;

        uint64_t min_bw = 0, max_bw = 0;
        rsmi_version_t version;
        ROCMSMICHECK(rsmi_version_get(&version));
        if(version.major >= 2) {
            ROCMSMICHECK(rsmi_minmax_bandwidth_get(srcIndex, dstIndex, &min_bw, &max_bw));
        }
        // Both values must be non-zero: guards the division and leaves the
        // default count in place when no bandwidth was obtained.
        if(max_bw && min_bw) {
            *count = max_bw / min_bw;
        }
        INFO(SCCL_LOG_GRAPH, "rocm smi srcIndex:%d dstIndex:%d min_bw:%ld max_bw:%ld count:%d", srcIndex, dstIndex, min_bw, max_bw, *count);
    }
    return scclSuccess;
}

} // namespace bootstrap
} // namespace topology
} // namespace hardware
} // namespace sccl