#!/bin/bash
#
# ROCm hipBLASLt GEMM benchmark driver.
#
# Runs hipblaslt-bench for a list of int8-input (w8a8) GEMM shapes and appends
# one row per run to a CSV file.
#
# Environment:
#   ROCM_PATH        ROCm install prefix. Only used to make
#                    ${ROCM_PATH}/bin/hipblaslt-bench executable.
#   BENCH_GPU_ID     Value exported as HIP_VISIBLE_DEVICES (default: 7).
#   HIPBLASLT_BENCH  Benchmark command to run (default: hipblaslt-bench,
#                    resolved through PATH).
#   OUTPUT_CSV       Result file (default: hipblaslt_bench_results.csv).

# Abort on unset variables. 'set -e' is deliberately not used: a failing
# benchmark run must be recorded in the CSV, not terminate the whole sweep.
set -u

HIPBLASLT_BENCH="${HIPBLASLT_BENCH:-hipblaslt-bench}"
OUTPUT_CSV="${OUTPUT_CSV:-hipblaslt_bench_results.csv}"

# Column names of the result file (every field, in row order).
readonly CSV_HEADER="shape_type,data_type,M,N,K,alpha,beta,transA,transB,grouped_gemm,batch_count,lda,stride_a,ldb,stride_b,ldc,stride_c,ldd,stride_d,a_type,b_type,c_type,d_type,compute_type,scaleA,scaleB,scaleC,scaleD,amaxD,activation_type,bias_vector,bias_type,hipblaslt-Gflops,us,status"

# Placeholder for the transA .. us columns of a run without a usable result.
readonly FAILED_FIELDS="T,N,0,1,0,0,0,0,0,0,0,0,i8_r,i8_r,i32_r,f16_r,i32_r,2,2,0,0,0,none,1,f16_r,FAILED,FAILED"

# Extended regex for one numeric field: integer or decimal, optional exponent.
readonly NUM_RE='[-+]?[0-9]+(\.[0-9]+)?([eE][-+]?[0-9]+)?'

#######################################
# Print $1 with leading and trailing whitespace removed.
#######################################
trim() {
  local s=$1
  s="${s#"${s%%[![:space:]]*}"}"
  s="${s%"${s##*[![:space:]]}"}"
  printf '%s' "${s}"
}

#######################################
# Succeed when $1 is a single numeric value (see NUM_RE).
#######################################
is_number() {
  [[ $1 =~ ^${NUM_RE}$ ]]
}

#######################################
# Pick the result row out of captured benchmark output.
# Arguments: $1 - complete captured output of one benchmark run
# Outputs:   the trimmed candidate row on stdout (empty when none is found)
#######################################
find_data_line() {
  local output=$1
  local line

  # Preferred: a row starting with the transA flag (T or N) whose last two
  # fields are numbers. Leading indentation and integer values are accepted.
  line=$(grep -E "^[[:space:]]*[TN],.*,${NUM_RE},${NUM_RE}[[:space:]]*\$" \
    <<<"${output}" | tail -n 1)

  # Otherwise: the line directly below the column header.
  if [[ -z "${line}" ]]; then
    line=$(grep -F -A1 'hipblaslt-Gflops' <<<"${output}" | tail -n 1)
  fi

  # Last resort: any comma-separated line that is not a "[...]" or "-..." line
  # and ends in two numbers.
  if [[ -z "${line}" ]]; then
    line=$(grep -v -e '^\[' -e '^-' <<<"${output}" \
      | grep -E ",.*${NUM_RE},${NUM_RE}[[:space:]]*\$" \
      | tail -n 1)
  fi

  trim "${line}"
}

#######################################
# Convert one benchmark result row into the transA .. us CSV columns.
#
# A column is looked up by its header name when the header has exactly as many
# fields as the row. Otherwise the positions below are used:
#   1:transA 2:transB 3:grouped_gemm 4:batch_count 5:m 6:n 7:k 8:alpha
#   9:lda 10:stride_a 11:beta 12:ldb 13:stride_b 14:ldc 15:stride_c
#   16:ldd 17:stride_d 18:a_type 19:b_type 20:c_type 21:d_type
#   22:compute_type 23:scaleA 24:scaleB 25:scaleC 26:scaleD 27:amaxD
#   28:activation_type 29:bias_vector 30:bias_type 31:hipblaslt-Gflops 32:us
# A row with fewer than 32 fields gets the configured defaults for lda ..
# bias_type and takes Gflops / us from its last two fields.
#
# Arguments: $1 - header line (may be empty), $2 - data row
# Outputs:   27 comma-separated values on stdout
#######################################
build_result_fields() {
  printf '%s\n%s\n' "$1" "$2" | awk -F',' '
    function trim(s) { gsub(/^[ \t\r]+|[ \t\r]+$/, "", s); return s }

    # Value of column "name": by header name, else by position, else default.
    function pick(name, pos, dflt) {
      if (by_name && (name in col)) return val[col[name]]
      if (pos > 0) return val[pos]
      return dflt
    }

    function add(name, pos, dflt) {
      row = row sep pick(name, pos, dflt)
      sep = ","
    }

    NR == 1 {
      # Drop a leading "[N]:" tag in front of the first column name.
      sub(/^[ \t]*\[[0-9]+\]:/, "")
      ncol = NF
      for (i = 1; i <= NF; i++) col[trim($i)] = i
      next
    }

    NR == 2 {
      nval = NF
      for (i = 1; i <= NF; i++) val[i] = trim($i)
    }

    END {
      by_name = (ncol > 0 && ncol == nval)
      full = (nval >= 32)

      add("transA", 1, "")
      add("transB", 2, "")
      add("grouped_gemm", 3, "")
      add("batch_count", 4, "")
      add("lda", (full ? 9 : 0), "0")
      add("stride_a", (full ? 10 : 0), "0")
      add("ldb", (full ? 12 : 0), "0")
      add("stride_b", (full ? 13 : 0), "0")
      add("ldc", (full ? 14 : 0), "0")
      add("stride_c", (full ? 15 : 0), "0")
      add("ldd", (full ? 16 : 0), "0")
      add("stride_d", (full ? 17 : 0), "0")
      add("a_type", (full ? 18 : 0), "i8_r")
      add("b_type", (full ? 19 : 0), "i8_r")
      add("c_type", (full ? 20 : 0), "i32_r")
      add("d_type", (full ? 21 : 0), "f16_r")
      add("compute_type", (full ? 22 : 0), "i32_r")
      add("scaleA", (full ? 23 : 0), "2")
      add("scaleB", (full ? 24 : 0), "2")
      add("scaleC", (full ? 25 : 0), "0")
      add("scaleD", (full ? 26 : 0), "0")
      add("amaxD", (full ? 27 : 0), "0")
      add("activation_type", (full ? 28 : 0), "none")
      add("bias_vector", (full ? 29 : 0), "1")
      add("bias_type", (full ? 30 : 0), "f16_r")
      add("hipblaslt-Gflops", (full ? 31 : nval - 1), "")
      add("us", (full ? 32 : nval), "")
      print row
    }
  '
}

#######################################
# Run one GEMM benchmark and append its result to the CSV file.
# Globals:   HIPBLASLT_BENCH (read), OUTPUT_CSV (appended to)
# Arguments: $1 shape label, $2 data type label, $3 M, $4 N, $5 K,
#            $6 alpha (default 1.0), $7 beta (default 0.0)
# Outputs:   progress and diagnostics on stdout; one CSV row with status
#            SUCCESS, ERROR (output contains "error:") or NO_OUTPUT
# Returns:   2 when fewer than five arguments are given, otherwise 0
#######################################
run_test() {
  if (( $# < 5 )); then
    printf 'usage: run_test SHAPE_TYPE DATA_TYPE M N K [ALPHA] [BETA]\n' >&2
    return 2
  fi

  local shape_type=$1
  local data_type=$2
  local M=$3
  local N=$4
  local K=$5
  local alpha=${6:-1.0}
  local beta=${7:-0.0}
  local row_prefix="${shape_type},${data_type},${M},${N},${K},${alpha},${beta}"

  echo "=========================================="
  echo "Running test: ${shape_type} - ${data_type} - M=${M} N=${N} K=${K}"

  # Build the command as an argv array so no eval is needed.
  # Leading dimensions: A is passed transposed (--transA T), so it is stored
  # K x M and its leading dimension is K; B (K x N) also uses K; C and D
  # (M x N) use M.
  local -a cmd=(
    "${HIPBLASLT_BENCH}" --api_method c
    -m "${M}" -n "${N}" -k "${K}"
    --alpha "${alpha}" --beta "${beta}"
    --transA T --transB N
    --batch_count 1
    --scaleA 2 --scaleB 2
    --bias_vector
    --bias_source d
    --a_type i8_r --lda "${K}"
    --b_type i8_r --ldb "${K}"
    --c_type i32_r --ldc "${M}"
    --d_type f16_r --ldd "${M}"
    --scale_type f32_r
    --bias_type f16_r
    --compute_type i32_r
    --print_kernel_info
    --cold_iters 50 --iters 1000
  )

  # Run the benchmark, capturing stdout and stderr together.
  local output=""
  local status=0
  output=$("${cmd[@]}" 2>&1) || status=$?

  # The benchmark reported an error.
  if grep -q 'error:' <<<"${output}"; then
    echo "Test failed: ${shape_type} - ${data_type}"
    printf '%s,%s,ERROR\n' "${row_prefix}" "${FAILED_FIELDS}" >>"${OUTPUT_CSV}"
    return 0
  fi

  local header_line=""
  local data_line=""
  local fields=""
  local gflops=""
  local us=""
  header_line=$(trim "$(grep -F 'hipblaslt-Gflops' <<<"${output}" | tail -n 1)")
  data_line=$(find_data_line "${output}")

  # The header itself can come back as the candidate when nothing follows it.
  if [[ "${data_line}" == *,* && "${data_line}" != "${header_line}" ]]; then
    # A complete row has at least 30 commas; warn about anything shorter.
    local commas=${data_line//[^,]/}
    if (( ${#commas} < 30 )); then
      echo "Warning: Data line has only ${#commas} commas, might be incomplete"
      echo "Data line: ${data_line}"
    fi

    fields=$(build_result_fields "${header_line}" "${data_line}")
    us=${fields##*,}
    gflops=${fields%,*}
    gflops=${gflops##*,}
  fi

  # Only a row with numeric measurements counts as a result.
  if is_number "${gflops}" && is_number "${us}"; then
    echo "Result: ${gflops} GFLOPS, ${us} us"
    printf '%s,%s,SUCCESS\n' "${row_prefix}" "${fields}" >>"${OUTPUT_CSV}"
  else
    echo "Test failed to produce valid output: ${shape_type} - ${data_type}"
    echo "Exit status: ${status}"
    echo "Raw output (last 20 lines):"
    tail -n 20 <<<"${output}"
    printf '%s,%s,NO_OUTPUT\n' "${row_prefix}" "${FAILED_FIELDS}" >>"${OUTPUT_CSV}"
  fi
  return 0
}

#######################################
# Prepare the environment, write the CSV header and run every scenario.
#######################################
main() {
  # GPU device used for the benchmark.
  export HIP_VISIBLE_DEVICES="${BENCH_GPU_ID:-7}"

  # Make the bundled binary executable, but only if it is really there.
  if [[ -n "${ROCM_PATH:-}" && -f "${ROCM_PATH}/bin/hipblaslt-bench" ]]; then
    chmod +x -- "${ROCM_PATH}/bin/hipblaslt-bench" \
      || echo "Warning: cannot chmod ${ROCM_PATH}/bin/hipblaslt-bench" >&2
  fi

  # The tuning override must be exported, otherwise the benchmark process
  # never sees it. Export it only when the file exists.
  if [[ -f hipblaslt.config ]]; then
    export HIPBLASLT_TUNING_OVERRIDE_FILE=hipblaslt.config
  else
    echo "Warning: hipblaslt.config not found, tuning override not applied" >&2
  fi

  # Start a fresh result file containing only the header.
  printf '%s\n' "${CSV_HEADER}" >"${OUTPUT_CSV}"

  local size

  # Scenario 3a: M, N and K all equal.
  for size in 128 256 512 1024 2048 4096 8192; do
    run_test "square" "w8a8" "${size}" "${size}" "${size}" 1.0 0.0
  done

  # Scenario 3b: sizes that are not a multiple of 4.
  for size in 4098 8190; do
    run_test "square" "w8a8" "${size}" "${size}" "${size}" 1.0 0.0
  done

  # Scenario 3c: M=8192, K=768, N=8192.
  run_test "square" "w8a8" 8192 8192 768 1.0 0.0

  echo "=========================================="
  echo "All tests completed. Results saved to ${OUTPUT_CSV}"
}

main "$@"