#!/bin/bash
# Runtime environment for the GPU / collective-communication stack.
# NOTE(review): this script only launches the benchmark client
# (benchmark_serving.py) against an already-running server on $PORT, so these
# exports reach that client process and its children, not the server itself.
# Confirm they are needed here rather than (or in addition to) in the server's
# launch script.

# HIP: make devices 0-7 visible.
export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7

# Collective / attention tuning knobs. Several names look vendor- or
# build-specific; confirm semantics against the vLLM/RCCL build in use.
export ALLREDUCE_STREAM_WITH_COMPUTE=1
export NCCL_MIN_NCHANNELS=16 NCCL_MAX_NCHANNELS=16
export VLLM_PCIE_USE_CUSTOM_ALLREDUCE=1 VLLM_USE_TRITON_PREFIX_FLASH_ATTN=1

# NUMA binding; presumably maps tensor-parallel rank N to NUMA node N for
# ranks 0..7 (TODO confirm with the vLLM build that reads these).
export VLLM_NUMA_BIND=1
export \
    VLLM_RANK0_NUMA=0 VLLM_RANK1_NUMA=1 \
    VLLM_RANK2_NUMA=2 VLLM_RANK3_NUMA=3 \
    VLLM_RANK4_NUMA=4 VLLM_RANK5_NUMA=5 \
    VLLM_RANK6_NUMA=6 VLLM_RANK7_NUMA=7

# Benchmark parameters, all taken from the environment:
#   MODEL_NAME   - label used in the result CSV and log file names
#   MODEL_PATH   - passed to the benchmark as --model
#   PORT         - port of the server under test (this script does not start one)
#   TP           - tensor-parallel size; only used to label results and logs
#   DATA_TYPE    - recorded verbatim in the CSV "data_type" column
#   BATCH_LIST   - space-separated --num-prompts values, e.g. "1 8 32"
#   PROMPT_PAIRS - comma-separated "input_len output_len" pairs,
#                  e.g. "512 128,2048 256"
model_name="${MODEL_NAME}"
model_path="${MODEL_PATH}"
port="${PORT}"
tp="${TP}"
data_type="${DATA_TYPE}"
batch_list="${BATCH_LIST}"
prompt_pairs="${PROMPT_PAIRS}"

# ---------------------------------------------------------------------------
# Run the benchmark grid and collect results.
#
# For every batch size in $batch_list and every "input output" token pair in
# $prompt_pairs, run benchmark_serving.py (resolved relative to the current
# working directory) against the server on $port. Each run's full console
# output goes to its own log file, and one CSV row of headline metrics is
# appended to $result_file.
# ---------------------------------------------------------------------------

# Fail fast with a clear message. Otherwise the script would write files such
# as "_tp.csv" and launch the benchmark with empty arguments.
: "${model_name:?MODEL_NAME must be set}"
: "${model_path:?MODEL_PATH must be set}"
: "${tp:?TP must be set}"
: "${port:?PORT must be set}"
: "${batch_list:?BATCH_LIST must be set}"
: "${prompt_pairs:?PROMPT_PAIRS must be set}"

# Result file name: <model>_tp<tp>.csv. The header truncates any previous
# results for the same model/TP combination.
results_dir="/workspace/test/inference_outputs/results"
result_file="${results_dir}/${model_name}_tp${tp}.csv"
# Unlike the per-run log dirs, this directory was assumed to exist; create it
# so the header write cannot fail on a fresh tree.
if ! mkdir -p -- "$results_dir"; then
    printf 'error: cannot create results directory %s\n' "$results_dir" >&2
    exit 1
fi
echo "tp,data_type,batch,prompt_tokens,completion_tokens,TOTAL_THROUGHPUT(toks/s),generate_throughput(toks/s),TTFT(ms),TPOT(ms),ITL(ms)" > "$result_file"

# extract_metric PATTERN FIELD FILE
# Print whitespace-separated field FIELD of the FIRST line of FILE matching
# the awk regex PATTERN, or nothing if no line matches. Stopping at the first
# match guarantees at most one value, so a CSV row can never be split.
extract_metric() {
    local pattern=$1 field=$2 file=$3
    awk -v pat="$pattern" -v fld="$field" '$0 ~ pat { print $fld; exit }' "$file"
}

# Split the inputs: batches on spaces, prompt pairs on commas.
IFS=' ' read -ra batches <<< "$batch_list"
IFS=',' read -ra pairs <<< "$prompt_pairs"

# Run every (batch, pair) combination.
for batch in "${batches[@]}"; do
    for pair in "${pairs[@]}"; do
        # "_" absorbs any extra words so completion_tokens stays a single value.
        IFS=' ' read -r prompt_tokens completion_tokens _ <<< "$pair"
        # Skip empty or incomplete entries (e.g. a doubled comma, or "512" with
        # no output length) instead of passing an empty length to the benchmark.
        if [[ -z "$prompt_tokens" || -z "$completion_tokens" ]]; then
            printf 'warning: skipping malformed prompt pair: "%s"\n' "$pair" >&2
            continue
        fi

        log_file="/workspace/test/inference_outputs/logs/models/${model_name}_${tp}/batch_${batch}_prompt_${prompt_tokens}_completion_${completion_tokens}.log"
        mkdir -p "$(dirname "$log_file")"
        echo "Running: batch=$batch, prompt=$prompt_tokens, completion=$completion_tokens"
        python benchmark_serving.py \
            --backend openai \
            --port "$port" \
            --model "$model_path" \
            --trust-remote-code \
            --dataset-name random \
            --ignore-eos \
            --random-input-len "$prompt_tokens" \
            --random-output-len "$completion_tokens" \
            --num-prompts "$batch" \
            2>&1 | tee "$log_file"
        # The pipeline's own status is tee's. PIPESTATUS[0] holds python's status
        # and must be read immediately, before any other command runs.
        bench_status=${PIPESTATUS[0]}
        if (( bench_status != 0 )); then
            printf 'warning: benchmark exited with status %d (batch=%s, prompt=%s, completion=%s); see %s\n' \
                "$bench_status" "$batch" "$prompt_tokens" "$completion_tokens" "$log_file" >&2
        fi

        # Extract headline metrics from the benchmark's summary lines. A row is
        # still written on failure so the CSV keeps one line per attempted
        # combination; missing metrics appear as empty fields.
        total_throughput=$(extract_metric '^Total Token' 5 "$log_file")
        gen_throughput=$(extract_metric '^Output token' 5 "$log_file")
        ttft=$(extract_metric '^Mean TTFT' 4 "$log_file")
        tpot=$(extract_metric '^Mean TPOT' 4 "$log_file")
        itl=$(extract_metric '^Mean ITL' 4 "$log_file")
        printf '%s,%s,%s,%s,%s,%s,%s,%s,%s,%s\n' \
            "$tp" "$data_type" "$batch" "$prompt_tokens" "$completion_tokens" \
            "$total_throughput" "$gen_throughput" "$ttft" "$tpot" "$itl" \
            >> "$result_file"
    done
done
