# NOTE(review): removed non-script residue captured from a repository web
# viewer (file-size header, "Newer Older" banner, author line, and the
# line-number gutter 1..137). Those lines were not shell code and broke
# parsing — the apostrophe in the author banner opened an unterminated
# single-quoted string.
#!/bin/bash

# ROCm BLAS benchmark script.
# Measures gemm_ex performance across several matrix shapes and data types
# and appends one CSV row per run to ${OUTPUT_CSV}.
#
# Required environment:
#   ROCM_PATH - root of the ROCm installation (provides bin/rocblas-bench).

# Point rocBLAS at the tuned Tensile kernel library and pin the GPU to use.
export ROCBLAS_TENSILE_LIBPATH=./library_gpu5
export HIP_VISIBLE_DEVICES=3

# Make sure the benchmark binary is executable. Guard against an unset
# ROCM_PATH: the old unquoted expansion silently ran `chmod +x /bin/rocblas-bench`.
if [[ -n "${ROCM_PATH:-}" ]]; then
    chmod +x "${ROCM_PATH}/bin/rocblas-bench"
else
    echo "warning: ROCM_PATH is not set; skipping chmod of rocblas-bench" >&2
fi

# Output file for all benchmark results.
OUTPUT_CSV="rocblas_bench_results_transB.csv"

# Write the CSV header (overwrites any previous results file).
echo "shape_type,data_type,M,N,K,alpha,beta,transA,transB,a_type,b_type,c_type,d_type,compute_type,rocblas-Gflops,us" > "${OUTPUT_CSV}"



#######################################
# Run one gemm_ex benchmark and append a CSV row to ${OUTPUT_CSV}.
# Globals:   OUTPUT_CSV (appended to)
# Arguments:
#   $1  shape_type   - shape category label (e.g. "square")
#   $2  data_type    - precision label (e.g. "fp16_fp32")
#   $3-$5  M N K     - GEMM dimensions
#   $6-$9  a/b/c/d   - rocBLAS type strings (e.g. f32_r)
#   $10 compute_type - rocBLAS compute type string
#   $11 alpha, $12 beta
#   $13 math_mode    - optional; passed as --math_mode when non-empty
# Outputs:   progress on stdout; failures noted on stderr; a row with
#            FAILED,FAILED is recorded when the benchmark produced no output.
#######################################
run_test() {
    local shape_type=$1
    local data_type=$2
    local M=$3
    local N=$4
    local K=$5
    local a_type=$6
    local b_type=$7
    local c_type=$8
    local d_type=$9
    local compute_type=${10}
    local alpha=${11}
    local beta=${12}
    local math_mode=${13}

    # Transpose flags used for the benchmark. They are also written to the
    # CSV so rows match what actually ran — the old code benchmarked with
    # --transposeB T but logged transB as N.
    local transA="N"
    local transB="T"

    echo "Running test: ${shape_type} - ${data_type} - M=${M} N=${N} K=${K}"

    # Build the command as an array: no eval, no word-splitting surprises.
    local cmd=(rocblas-bench -f gemm_ex
        -m "${M}" -n "${N}" -k "${K}"
        --alpha "${alpha}"
        --transposeA "${transA}" --transposeB "${transB}"
        --a_type "${a_type}" --lda "${M}"
        --b_type "${b_type}" --ldb "${K}"
        --beta "${beta}"
        --c_type "${c_type}" --ldc "${M}"
        --d_type "${d_type}" --ldd "${M}"
        --compute_type "${compute_type}"
        --cold_iters 50 --iters 1000)

    # Optional math mode (e.g. 1 selects TF32/XF32 on f32 storage types).
    if [[ -n "${math_mode}" ]]; then
        cmd+=(--math_mode "${math_mode}")
    fi

    # Run the benchmark; the last stdout line carries the perf numbers.
    # Declaration is separate from assignment so the exit status isn't masked.
    local output
    output=$("${cmd[@]}" 2>/dev/null | tail -1)

    if [[ -n "${output}" ]]; then
        # The last two whitespace-separated fields are rocblas-Gflops and us.
        local gflops us
        gflops=$(echo "${output}" | awk '{print $(NF-1)}')
        us=$(echo "${output}" | awk '{print $NF}')
        echo "${shape_type},${data_type},${M},${N},${K},${alpha},${beta},${transA},${transB},${a_type},${b_type},${c_type},${d_type},${compute_type},${gflops},${us}" >> "${OUTPUT_CSV}"
    else
        echo "Test failed or not supported: ${shape_type} - ${data_type}" >&2
        echo "${shape_type},${data_type},${M},${N},${K},${alpha},${beta},${transA},${transB},${a_type},${b_type},${c_type},${d_type},${compute_type},FAILED,FAILED" >> "${OUTPUT_CSV}"
    fi
}

# Scenario 3a: square GEMMs where M, N and K are all equal.
for dim in 128 256 512 1024 2048 4096 8192; do
    for precision in fp64 fp32 fp16 bf16 tf32; do
        case "$precision" in
            fp64) run_test "square" "fp64" "$dim" "$dim" "$dim" "f64_r" "f64_r" "f64_r" "f64_r" "f64_r" 1.0 0.0 "" ;;
            fp32) run_test "square" "fp32" "$dim" "$dim" "$dim" "f32_r" "f32_r" "f32_r" "f32_r" "f32_r" 1.0 0.0 "" ;;
            fp16) run_test "square" "fp16" "$dim" "$dim" "$dim" "f16_r" "f16_r" "f16_r" "f16_r" "f32_r" 1.0 0.0 "" ;;
            bf16) run_test "square" "bf16" "$dim" "$dim" "$dim" "bf16_r" "bf16_r" "bf16_r" "bf16_r" "f32_r" 1.0 0.0 "" ;;
            # tf32 uses f32 storage types plus math_mode=1
            tf32) run_test "square" "tf32" "$dim" "$dim" "$dim" "f32_r" "f32_r" "f32_r" "f32_r" "f32_r" 1.0 0.0 "1" ;;
        esac
    done
done

# Scenario 3b: dimensions that are not multiples of 4, one run per precision.
for odd_dim in 4098 8190; do
    run_test "non_aligned" "fp64" "$odd_dim" "$odd_dim" "$odd_dim" f64_r f64_r f64_r f64_r f64_r 1.0 0.0 ''
    run_test "non_aligned" "fp32" "$odd_dim" "$odd_dim" "$odd_dim" f32_r f32_r f32_r f32_r f32_r 1.0 0.0 ''
    run_test "non_aligned" "fp16" "$odd_dim" "$odd_dim" "$odd_dim" f16_r f16_r f16_r f16_r f32_r 1.0 0.0 ''
    run_test "non_aligned" "bf16" "$odd_dim" "$odd_dim" "$odd_dim" bf16_r bf16_r bf16_r bf16_r f32_r 1.0 0.0 ''
    # tf32: f32 storage types with math_mode=1
    run_test "non_aligned" "tf32" "$odd_dim" "$odd_dim" "$odd_dim" f32_r f32_r f32_r f32_r f32_r 1.0 0.0 '1'
done

# Scenario 3c: rectangular shape M=8192, N=8192, K=768, one run per precision.
m=8192
n=8192
k=768
run_test "special" "fp64" "$m" "$n" "$k" "f64_r" "f64_r" "f64_r" "f64_r" "f64_r" 1.0 0.0 ""
run_test "special" "fp32" "$m" "$n" "$k" "f32_r" "f32_r" "f32_r" "f32_r" "f32_r" 1.0 0.0 ""
run_test "special" "fp16" "$m" "$n" "$k" "f16_r" "f16_r" "f16_r" "f16_r" "f32_r" 1.0 0.0 ""
run_test "special" "bf16" "$m" "$n" "$k" "bf16_r" "bf16_r" "bf16_r" "bf16_r" "f32_r" 1.0 0.0 ""
# tf32: f32 storage types with math_mode=1
run_test "special" "tf32" "$m" "$n" "$k" "f32_r" "f32_r" "f32_r" "f32_r" "f32_r" 1.0 0.0 "1"

# Mixed-precision tests over the same size sets as scenarios 3a and 3b:
# the power-of-two sizes first, then the non-4-aligned ones. Folding both
# original loops into one list preserves the exact call order.
for sz in 128 256 512 1024 2048 4096 8192 4098 8190; do
    # fp16 inputs accumulated into fp32 outputs
    run_test "mixed" "fp16_fp32" "$sz" "$sz" "$sz" "f16_r" "f16_r" "f32_r" "f32_r" "f32_r" 1.0 0.0 ""
    # bf16 inputs accumulated into fp32 outputs
    run_test "mixed" "bf16_fp32" "$sz" "$sz" "$sz" "bf16_r" "bf16_r" "f32_r" "f32_r" "f32_r" 1.0 0.0 ""
    # int8 inputs accumulated into int32 outputs
    run_test "mixed" "int8_int32" "$sz" "$sz" "$sz" "i8_r" "i8_r" "i32_r" "i32_r" "i32_r" 1.0 0.0 ""
done


# Scenarios 5b/5c/5e: low-precision A,B with wider C,D, over three
# representative shapes (two squares and one rectangular case).
mixed_shapes=("4096 4096 4096" "8192 8192 8192" "8192 8192 768")

# Scenario 5b: A,B fp16 -> C,D fp32
for shape in "${mixed_shapes[@]}"; do
    read -r m n k <<< "$shape"
    run_test "mixed" "fp16_fp32" "$m" "$n" "$k" "f16_r" "f16_r" "f32_r" "f32_r" "f32_r" 1.0 0.0 ""
done

# Scenario 5c: A,B bf16 -> C,D fp32
for shape in "${mixed_shapes[@]}"; do
    read -r m n k <<< "$shape"
    run_test "mixed" "bf16_fp32" "$m" "$n" "$k" "bf16_r" "bf16_r" "f32_r" "f32_r" "f32_r" 1.0 0.0 ""
done

# Scenario 5e: A,B int8 -> C,D int32
for shape in "${mixed_shapes[@]}"; do
    read -r m n k <<< "$shape"
    run_test "mixed" "int8_int32" "$m" "$n" "$k" "i8_r" "i8_r" "i32_r" "i32_r" "i32_r" 1.0 0.0 ""
done