rocm_smi_wrap.cpp 4.6 KB
Newer Older
lishen's avatar
lishen committed
1
#include "rocm_smi_wrap.h"
2
#include "topo_utils.h"
lishen's avatar
lishen committed
3
4

namespace sccl {
5
6
7
namespace hardware {
namespace topology {
namespace bootstrap {
lishen's avatar
lishen committed
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130

/**
 * Evaluate a ROCm SMI call and bail out of the enclosing function on failure.
 *
 * On any status other than RSMI_STATUS_SUCCESS it logs the text of the failing
 * call together with the library's description of the status, then returns
 * scclInternalError from the *enclosing* function. It may therefore only be
 * used inside functions that return scclResult_t.
 *
 * Locals carry a trailing-underscore prefix so they cannot shadow (and be
 * captured by) a caller variable referenced inside `cmd`.
 */
#define ROCMSMICHECK(cmd)                                                       \
    do {                                                                        \
        rsmi_status_t rsmi_check_ret_ = (cmd);                                  \
        if(rsmi_check_ret_ != RSMI_STATUS_SUCCESS) {                            \
            /* rsmi_status_string can itself fail; never log an unset ptr */    \
            const char* rsmi_check_err_ = "unknown error";                      \
            if(rsmi_status_string(rsmi_check_ret_, &rsmi_check_err_) !=         \
                   RSMI_STATUS_SUCCESS ||                                       \
               rsmi_check_err_ == nullptr) {                                    \
                rsmi_check_err_ = "unknown error";                              \
            }                                                                   \
            WARN("ROCm SMI call %s failed: %s", #cmd, rsmi_check_err_);         \
            return scclInternalError;                                           \
        }                                                                       \
    } while(false)

/**
 * Initialize the ROCm SMI library and log its version.
 *
 * @return scclSuccess on success; scclInternalError (via ROCMSMICHECK) if
 *         rsmi_init or rsmi_version_get fails.
 * @note Writes the ROCm SMI library version to the SCCL_LOG_TOPO log.
 */
scclResult_t rocm_smi_init() {
    ROCMSMICHECK(rsmi_init(0));
    rsmi_version_t version;
    ROCMSMICHECK(rsmi_version_get(&version));
    // rsmi_version_t's major/minor/patch are unsigned (uint32_t in rocm_smi.h),
    // so print them with %u. `build` is a library-owned pointer; guard against
    // NULL because passing NULL to %s is undefined behavior.
    INFO(SCCL_LOG_TOPO,
         "rocm_smi_lib: version %u.%u.%u.%s",
         version.major,
         version.minor,
         version.patch,
         version.build ? version.build : "");
    return scclSuccess;
}

/**
 * Query how many devices ROCm SMI can monitor.
 *
 * @param num_devs [out] receives the count reported by rsmi_num_monitor_devices.
 * @return scclSuccess on success; scclInternalError (via ROCMSMICHECK) if the
 *         query fails.
 */
scclResult_t rocm_smi_getNumDevice(uint32_t* num_devs) {
    uint32_t* const device_count_out = num_devs;
    ROCMSMICHECK(rsmi_num_monitor_devices(device_count_out));
    return scclSuccess;
}

/**
 * Format the PCI bus ID of a ROCm SMI device as "DDDD:BB:DD.F" (hex).
 *
 * rocm_smi's packed bus ID format:
 *  | Name     | Field   |
 *  | -------- | ------- |
 *  | Domain   | [63:32] |
 *  | Reserved | [31:16] |
 *  | Bus      | [15: 8] |
 *  | Device   | [ 7: 3] |
 *  | Function | [ 2: 0] |
 *
 * @param deviceIndex ROCm SMI device index.
 * @param busId       [out] destination buffer for the NUL-terminated string.
 * @param len         size of busId in bytes.
 * @return scclSuccess on success; scclInternalError if busId is NULL or len is
 *         0, if the ROCm SMI query fails, or if the formatted ID does not fit.
 */
scclResult_t rocm_smi_getDevicePciBusIdString(uint32_t deviceIndex, char* busId, size_t len) {
    if(busId == nullptr || len == 0) {
        WARN("rocm_smi_lib: invalid bus ID buffer for device %u", deviceIndex);
        return scclInternalError;
    }

    uint64_t id = 0;
    ROCMSMICHECK(rsmi_dev_pci_id_get(deviceIndex, &id));

    // Widen to unsigned long long so the %llx specifiers match on every ABI.
    const unsigned long long domain = id >> 32;
    const unsigned long long bus    = (id & 0xff00) >> 8;
    const unsigned long long device = (id & 0xf8) >> 3;
    const unsigned long long func   = id & 0x7;

    // snprintf bounds the write to len bytes and always NUL-terminates.
    const int written = snprintf(busId, len, "%04llx:%02llx:%02llx.%01llx", domain, bus, device, func);
    if(written < 0 || static_cast<size_t>(written) >= len) {
        // A truncated bus ID would silently fail to match later; report it.
        WARN("rocm_smi_lib: bus ID for device %u does not fit in %zu bytes", deviceIndex, len);
        return scclInternalError;
    }
    return scclSuccess;
}

/**
 * Find the ROCm SMI device index whose PCI bus ID matches pciBusId.
 *
 * @param pciBusId    bus ID string (e.g. "0000:43:00.0"), parsed with busIdToInt64.
 * @param deviceIndex [out] receives the matching ROCm SMI device index.
 * @return scclSuccess if a device matches; scclInternalError on NULL
 *         arguments, a ROCm SMI query failure, or when no device matches.
 */
scclResult_t rocm_smi_getDeviceIndexByPciBusId(const char* pciBusId, uint32_t* deviceIndex) {
    if(pciBusId == nullptr || deviceIndex == nullptr) {
        WARN("rocm_smi_lib: NULL argument passed to rocm_smi_getDeviceIndexByPciBusId");
        return scclInternalError;
    }

    int64_t parsed = 0;
    busIdToInt64(pciBusId, &parsed);

    /** Convert to rocm_smi's bus ID format:
     *  | Name     | Field   |
     *  | -------- | ------- |
     *  | Domain   | [63:32] |
     *  | Reserved | [31:16] |
     *  | Bus      | [15: 8] |
     *  | Device   | [ 7: 3] |
     *  | Function | [ 2: 0] |
     *
     * The masks assume busIdToInt64 packs one hex digit per nibble, i.e.
     * domain in [35:20], bus in [19:12], device in [11:4], function in [3:0]
     * (TODO confirm against topo_utils).
     **/
    const uint64_t raw    = static_cast<uint64_t>(parsed);
    const uint64_t target = ((raw & 0xffff00000ULL) << 12) + ((raw & 0xff000ULL) >> 4) + ((raw & 0xff0ULL) >> 1) + (raw & 0x7ULL);

    uint32_t num_devs = 0;
    ROCMSMICHECK(rsmi_num_monitor_devices(&num_devs));
    for(uint32_t i = 0; i < num_devs; i++) {
        uint64_t bdfid = 0;
        ROCMSMICHECK(rsmi_dev_pci_id_get(i, &bdfid));
        // Both sides are unsigned, avoiding a signed/unsigned comparison.
        if(bdfid == target) {
            *deviceIndex = i;
            return scclSuccess;
        }
    }

    WARN("rocm_smi_lib: %s device index not found", pciBusId);
    return scclInternalError;
}

/**
 * Query link information between two ROCm SMI devices.
 *
 * @param srcIndex  source device index
 * @param dstIndex  destination device index
 * @param rsmi_type [out] link type reported by rsmi_topo_get_link_type
 * @param hops      [out] 1 when the link is XGMI and its ROCm SMI weight is 15;
 *                  otherwise 2
 * @param count     [out] 1 by default. For the XGMI/weight-15 case, and only
 *                  when the library's major version is >= 2, it is
 *                  max_bw / min_bw from rsmi_minmax_bandwidth_get
 *                  (left at 1 if the bandwidths are 0 or max_bw < min_bw).
 *
 * @return scclSuccess on success; scclInternalError (via ROCMSMICHECK) if any
 *         ROCm SMI query fails.
 */
scclResult_t rocm_smi_getLinkInfo(int srcIndex, int dstIndex, RSMI_IO_LINK_TYPE* rsmi_type, int* hops, int* count) {
    uint64_t rsmi_hops = 0, rsmi_weight = 0;
    ROCMSMICHECK(rsmi_topo_get_link_type(srcIndex, dstIndex, &rsmi_hops, rsmi_type));
    ROCMSMICHECK(rsmi_topo_get_link_weight(srcIndex, dstIndex, &rsmi_weight));
    *hops  = 2;
    *count = 1;

    // Only an XGMI link with weight 15 is treated as a single hop.
    if(*rsmi_type != RSMI_IOLINK_TYPE_XGMI || rsmi_weight != 15) {
        return scclSuccess;
    }
    *hops = 1;

    uint64_t min_bw = 0, max_bw = 0;
    rsmi_version_t version;
    ROCMSMICHECK(rsmi_version_get(&version));
    // Bandwidth is only queried for ROCm SMI major version >= 2.
    if(version.major >= 2) {
        ROCMSMICHECK(rsmi_minmax_bandwidth_get(srcIndex, dstIndex, &min_bw, &max_bw));
    }
    // Presumably the ratio approximates the number of parallel XGMI links;
    // never let it drop below 1 (e.g. if max_bw < min_bw).
    if(min_bw != 0 && max_bw >= min_bw) {
        *count = static_cast<int>(max_bw / min_bw);
    }

    INFO(SCCL_LOG_GRAPH,
         "rocm smi srcIndex:%d dstIndex:%d min_bw:%llu max_bw:%llu count:%d",
         srcIndex,
         dstIndex,
         static_cast<unsigned long long>(min_bw),
         static_cast<unsigned long long>(max_bw),
         *count);
    return scclSuccess;
}

131
132
133
} // namespace bootstrap
} // namespace topology
} // namespace hardware
lishen's avatar
lishen committed
134
} // namespace sccl