run_hipblaslt_bench.sh 8.68 KB
Newer Older
sunzhq2's avatar
sunzhq2 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
#!/bin/bash

# ROCm hipBLASLt GEMM benchmark script.
# Runs hipblaslt-bench over several matrix shapes for one data-type
# configuration (w8a8) and collects every run into a CSV file.

# Pin the benchmark to a single GPU.
export HIP_VISIBLE_DEVICES=7

# Make sure the bench binary shipped with ROCm is executable.
# NOTE(review): the tests below invoke `hipblaslt-bench` through PATH; this
# assumes PATH resolves to the same binary under ${ROCM_PATH} — confirm.
if [[ -n ${ROCM_PATH:-} ]]; then
    chmod +x "${ROCM_PATH}/bin/hipblaslt-bench"
else
    echo "Warning: ROCM_PATH is not set; skipping chmod of hipblaslt-bench" >&2
fi

# hipBLASLt tuning override file. It must be *exported*: a plain shell
# assignment is not inherited by the hipblaslt-bench child process, so the
# override would silently be ignored.
export HIPBLASLT_TUNING_OVERRIDE_FILE=hipblaslt.config
if [[ ! -f ${HIPBLASLT_TUNING_OVERRIDE_FILE} ]]; then
    echo "Warning: tuning override file '${HIPBLASLT_TUNING_OVERRIDE_FILE}' not found in $(pwd)" >&2
fi

# Output file
OUTPUT_CSV="hipblaslt_bench_results.csv"

# Write the CSV header: the script's own labels (shape/type/M/N/K/alpha/beta),
# every field hipblaslt-bench prints for a run, and a final status column.
echo "shape_type,data_type,M,N,K,alpha,beta,transA,transB,grouped_gemm,batch_count,lda,stride_a,ldb,stride_b,ldc,stride_c,ldd,stride_d,a_type,b_type,c_type,d_type,compute_type,scaleA,scaleB,scaleC,scaleD,amaxD,activation_type,bias_vector,bias_type,hipblaslt-Gflops,us,status" > "${OUTPUT_CSV}"

# Print $1 with leading and trailing whitespace (including a stray CR) removed.
_bench_trim() {
    local s=$1
    s=${s#"${s%%[![:space:]]*}"}
    s=${s%"${s##*[![:space:]]}"}
    printf '%s' "${s}"
}

# Locate the result row in hipblaslt-bench output ($1) and print it trimmed
# (empty if none is found). Strategies are tried in order, most specific first.
_bench_find_data_line() {
    local output=$1
    # Numeric field with an optional fraction: Gflops may be printed as an
    # integer (e.g. "138933"), "us" usually has a fraction (e.g. "15.457").
    local -r num='[0-9]+(\.[0-9]+)?'
    local line

    # 1) A row that starts with the transA flag (T or N), possibly indented,
    #    has at least 31 commas (32 fields) and ends with Gflops,us.
    line=$(grep -E "^[[:space:]]*(T|N)(,[^,]*){29,},${num},${num}[[:space:]]*$" <<<"${output}" | tail -n 1)

    # 2) The line right after the last "hipblaslt-Gflops" header that has a
    #    following line (never the header line itself).
    if [[ -z ${line} ]]; then
        line=$(awk 'take { last = $0; take = 0 } /hipblaslt-Gflops/ { take = 1 } END { print last }' <<<"${output}")
    fi

    # 3) Any T/N-prefixed row that ends with two numbers.
    if [[ -z ${line} ]]; then
        line=$(grep -E "^[[:space:]]*(T|N)," <<<"${output}" | grep -E "${num},${num}[[:space:]]*$" | tail -n 1)
    fi

    # 4) Last resort: any comma-separated row ending with two numbers that is
    #    not a "[..." log line, a "-..." banner, or blank.
    if [[ -z ${line} ]]; then
        line=$(grep -v -e '^\[' -e '^-' -e '^[[:space:]]*$' <<<"${output}" | grep ',' | grep -E "${num},${num}[[:space:]]*$" | tail -n 1)
    fi

    _bench_trim "${line}"
}

#######################################
# Benchmark one GEMM shape with hipblaslt-bench and append one CSV row.
#
# The run uses the C API with transA=T / transB=N, batch_count 1, int8 A/B,
# int32 C and compute type, fp16 D, an fp16 bias vector taken from D, f32
# scale type, scaleA=scaleB=2, 50 cold iterations and 1000 timed iterations.
#
# Globals:
#   OUTPUT_CSV (read) - path of the CSV file the row is appended to.
# Arguments:
#   $1 - shape label written to the CSV (e.g. "square")
#   $2 - data-type label written to the CSV (e.g. "w8a8")
#   $3 $4 $5 - GEMM sizes M, N, K
#   $6 $7 - alpha and beta
# Outputs:
#   Progress and diagnostics on stdout; exactly one row appended to
#   ${OUTPUT_CSV} whose last column is SUCCESS, ERROR or NO_OUTPUT.
#######################################
run_test() {
    local shape_type=$1
    local data_type=$2
    local M=$3
    local N=$4
    local K=$5
    local alpha=${6}
    local beta=${7}

    # Operand layout (column-major, as in BLAS). With transA=T the stored A is
    # K x M, so its packed leading dimension is K; B with transB=N is K x N
    # (ld K); C and D are M x N (ld M).
    local -r trans_a="T"
    local -r trans_b="N"
    local arg_lda arg_ldb
    if [[ ${trans_a} == "N" ]]; then arg_lda=${M}; else arg_lda=${K}; fi
    if [[ ${trans_b} == "N" ]]; then arg_ldb=${K}; else arg_ldb=${N}; fi

    # Row written when no result can be recorded; only the status column varies.
    local failed_row="${shape_type},${data_type},${M},${N},${K},${alpha},${beta},${trans_a},${trans_b},0,1,0,0,0,0,0,0,0,0,i8_r,i8_r,i32_r,f16_r,i32_r,2,2,0,0,0,none,1,f16_r,FAILED,FAILED"

    echo "=========================================="
    echo "Running test: ${shape_type} - ${data_type} - M=${M} N=${N} K=${K}"

    # Build the command as an array: no eval, no word-splitting surprises.
    # Each operand gets its own leading-dimension flag (--lda/--ldb/--ldc/--ldd).
    local -a cmd=(
        hipblaslt-bench --api_method c
        -m "${M}" -n "${N}" -k "${K}"
        --alpha "${alpha}" --beta "${beta}"
        --transA "${trans_a}" --transB "${trans_b}"
        --batch_count 1
        --scaleA 2 --scaleB 2
        --bias_vector
        --bias_source d
        --a_type i8_r --lda "${arg_lda}"
        --b_type i8_r --ldb "${arg_ldb}"
        --c_type i32_r --ldc "${M}"
        --d_type f16_r --ldd "${M}"
        --scale_type f32_r
        --bias_type f16_r
        --compute_type i32_r
        --print_kernel_info
        --cold_iters 50 --iters 1000
    )

    # Run the benchmark, capturing stdout+stderr and its exit status
    # (declaration is split from assignment so the status is not masked).
    local output
    local status=0
    output=$("${cmd[@]}" 2>&1) || status=$?

    # A non-zero exit status or an "error:" message marks the run as failed.
    if (( status != 0 )) || grep -q "error:" <<<"${output}"; then
        echo "Test failed: ${shape_type} - ${data_type}"
        if (( status != 0 )); then
            echo "hipblaslt-bench exited with status ${status}"
        fi
        echo "Raw output (last 20 lines):"
        echo "${output}" | tail -20
        echo "${failed_row},ERROR" >> "${OUTPUT_CSV}"
        return
    fi

    local data_line
    data_line=$(_bench_find_data_line "${output}")

    if [[ -z ${data_line} || ${data_line} != *,* ]]; then
        echo "Test failed to produce valid output: ${shape_type} - ${data_type}"
        echo "Raw output (last 20 lines):"
        echo "${output}" | tail -20
        echo "${failed_row},NO_OUTPUT" >> "${OUTPUT_CSV}"
        return
    fi

    # A complete row has 32 fields, i.e. 31 commas.
    local commas=${data_line//[^,]/}
    if (( ${#commas} < 31 )); then
        echo "Warning: Data line has only ${#commas} commas, might be incomplete"
        echo "Data line: ${data_line}"
    fi

    # Split the row. Field order (1-based) as printed by hipblaslt-bench:
    #  1:transA  2:transB  3:grouped_gemm  4:batch_count  5:m  6:n  7:k
    #  8:alpha  9:lda  10:stride_a  11:beta  12:ldb  13:stride_b
    #  14:ldc  15:stride_c  16:ldd  17:stride_d
    #  18:a_type  19:b_type  20:c_type  21:d_type  22:compute_type
    #  23:scaleA  24:scaleB  25:scaleC  26:scaleD  27:amaxD
    #  28:activation_type  29:bias_vector  30:bias_type
    #  31:hipblaslt-Gflops  32:us
    # The bash array below is 0-based, so field i lives at f[i-1].
    local -a f
    IFS=',' read -r -a f <<<"${data_line}"
    local field_count=${#f[@]}

    local transA=${f[0]} transB=${f[1]} grouped_gemm=${f[2]} batch_count=${f[3]}
    local lda stride_a ldb stride_b ldc stride_c ldd stride_d
    local a_type b_type c_type d_type compute_type
    local scaleA scaleB scaleC scaleD amaxD activation_type bias_vector bias_type
    local gflops us

    if (( field_count >= 32 )); then
        lda=${f[8]};  stride_a=${f[9]}
        ldb=${f[11]}; stride_b=${f[12]}
        ldc=${f[13]}; stride_c=${f[14]}
        ldd=${f[15]}; stride_d=${f[16]}
        a_type=${f[17]}; b_type=${f[18]}; c_type=${f[19]}; d_type=${f[20]}
        compute_type=${f[21]}
        scaleA=${f[22]}; scaleB=${f[23]}; scaleC=${f[24]}; scaleD=${f[25]}
        amaxD=${f[26]}; activation_type=${f[27]}
        bias_vector=${f[28]}; bias_type=${f[29]}
        gflops=${f[30]}; us=${f[31]}
    else
        # Too few fields to map by position: take Gflops/us from the end of
        # the row and fill the other columns with the values requested above.
        gflops=${f[field_count-2]}
        us=${f[field_count-1]}
        lda=0; stride_a=0; ldb=0; stride_b=0; ldc=0; stride_c=0; ldd=0; stride_d=0
        a_type="i8_r"; b_type="i8_r"; c_type="i32_r"; d_type="f16_r"
        compute_type="i32_r"; scaleA=2; scaleB=2; scaleC=0; scaleD=0
        amaxD=0; activation_type="none"; bias_vector=1; bias_type="f16_r"
    fi

    gflops=$(_bench_trim "${gflops}")
    us=$(_bench_trim "${us}")

    echo "Result: ${gflops} GFLOPS, ${us} us"

    echo "${shape_type},${data_type},${M},${N},${K},${alpha},${beta},${transA},${transB},${grouped_gemm},${batch_count},${lda},${stride_a},${ldb},${stride_b},${ldc},${stride_c},${ldd},${stride_d},${a_type},${b_type},${c_type},${d_type},${compute_type},${scaleA},${scaleB},${scaleC},${scaleD},${amaxD},${activation_type},${bias_vector},${bias_type},${gflops},${us},SUCCESS" >> "${OUTPUT_CSV}"
}

# Square GEMMs (M == N == K), run in this order:
#   scenario 3a - power-of-two sizes 128 .. 8192
#   scenario 3b - sizes that are not a multiple of 4 (4098, 8190)
for size in 128 256 512 1024 2048 4096 8192 \
            4098 8190; do
    run_test "square" "w8a8" "${size}" "${size}" "${size}" 1.0 0.0
done

# Scenario 3c: large M and N with a small K (M=8192, N=8192, K=768).
run_test "square" "w8a8" 8192 8192 768 1.0 0.0

# Final summary.
echo "=========================================="
echo "All tests completed. Results saved to ${OUTPUT_CSV}"