Commit d7d0889d authored by PanZezhong

issue/80 Fix the attention prefill timing method; restructure the directories

parent 469f2884
@@ -175,6 +175,13 @@ def benchmark_Qwen3attention_prefill_torch(model, rotary_emb, req_list, test_cas
     torch.cuda.synchronize()
     for _ in range(WARMUPS):
+        for i, req in enumerate(req_list):
+            # ----------------------------------------- #
+            # Restore the kv cache to its original length
+            # ----------------------------------------- #
+            origin_len = test_cases["pastlens"][i]
+            req["past_key_values"].crop(origin_len)
         for req in req_list:
             # ----------------------------------------- #
             # Get the data for each req
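The warmup block added above mirrors a reset that the measured loop below already performs (its context lines show the same crop call): without it, every warmup pass would leave its keys and values in the cache, so later passes would prefill against an ever-growing context. A minimal sketch of the invariant, assuming a transformers DynamicCache-style object whose crop() trims the cache; the helper name reset_kv_caches is illustrative, not from this repository:

def reset_kv_caches(req_list, past_lens):
    # Trim each request's KV cache back to its pre-prefill length so every
    # warmup pass and measured pass prefills against the same context length.
    for req, origin_len in zip(req_list, past_lens):
        req["past_key_values"].crop(origin_len)  # drop entries appended by the previous pass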
@@ -216,9 +223,13 @@ def benchmark_Qwen3attention_prefill_torch(model, rotary_emb, req_list, test_cas
         origin_len = test_cases["pastlens"][i]
         req["past_key_values"].crop(origin_len)
     torch.cuda.synchronize()
-    start_time = time.time()
+    # ----------------------------------------- #
+    # Important: every req is timed from the start time of the whole batch
+    # ----------------------------------------- #
+    start_time = time.time()
+    for i, req in enumerate(req_list):
         # ----------------------------------------- #
         # Get the data for each req
         # ----------------------------------------- #
@@ -252,14 +263,15 @@ def benchmark_Qwen3attention_prefill_torch(model, rotary_emb, req_list, test_cas
         torch.cuda.synchronize()
         end_time = time.time()
-        time_consuming += (end_time - start_time) * 1000
+        # Record each req's time from when all reqs entered inference until this req finished
+        time_consuming += end_time - start_time
     out_token_count = RUNS * len(req_list)
-    latency = time_consuming / out_token_count
+    latency = time_consuming * 1000 / out_token_count
     print(
-        f"\t WARMUPS={WARMUPS} RUNS={RUNS}, Attention Torch, average latency: {round(latency, 2)} ms\n"
+        f"\t WARMUPS={WARMUPS} RUNS={RUNS}, Attention Torch, average TTFT: {round(latency, 2)} ms\n"
     )
     return req_out_list
@@ -390,7 +402,7 @@ def benchmark_Qwen3attention_decode_torch(model, rotary_emb, req_list, test_case
     throughput = out_token_count / time_consuming
     print(
-        f"\t WARMUPS={WARMUPS} RUNS={RUNS} Attention Torch average throughput: {round(throughput, 2)} /s \n"
+        f"\t WARMUPS={WARMUPS} RUNS={RUNS}, Attention Torch, average throughput: {round(throughput, 2)} tok/s \n"
     )
     return req_out_list
......
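The net effect of the prefill hunks above is a true time-to-first-token (TTFT) measurement: the clock starts once for the whole batch, and every request charges the interval from that common start until its own prefill finishes, rather than a bare per-request compute latency. A minimal sketch of the scheme, with a hypothetical run_prefill(req) standing in for the per-request attention forward pass (the helper names are illustrative, not from this repository):

import time
import torch

def measure_batch_ttft(req_list, run_prefill, runs):
    # Average TTFT when the requests of a batch are served sequentially:
    # one shared start time, one end time per request.
    time_consuming = 0.0
    for _ in range(runs):
        torch.cuda.synchronize()
        start_time = time.time()          # single start time for the whole batch
        for req in req_list:
            run_prefill(req)              # prefill forward pass for this request
            torch.cuda.synchronize()      # wait for this request's kernels to finish
            # later requests also accumulate the queueing delay of earlier ones
            time_consuming += time.time() - start_time
    return time_consuming * 1000 / (runs * len(req_list))  # average TTFT in ms

Under this accounting the last request of a batch reports a larger TTFT than the first, which matches what a client queued behind the batch would actually observe.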
@@ -85,9 +85,9 @@ def generate_moe_input_torch(testcase, dtype=torch.bfloat16):
     return input_tensor
 
-def benchmark_moe_torch(moe, input_host, device, dtype):
+def benchmark_moe_torch(moe, testcase, device, dtype):
     """"""
+    input_host = generate_moe_input_torch(testcase, dtype=dtype)
     input_device = input_host.to(device=device)
     output_device, _ = moe(input_device)
@@ -103,7 +103,11 @@ def benchmark_moe_torch(moe, input_host, device, dtype):
     torch.cuda.synchronize()
     end_time = time.time()
-    print(f" MoE Torch average latency: {(end_time - start_time) * 1000 / RUNS} ms")
+    total_time = end_time - start_time
+    total_tokens = sum(testcase["seqlens"]) * RUNS
+    print(
+        f"\t WARMUPS={WARMUPS} RUNS={RUNS}, MoE Torch average latency: {round(total_time * 1000 / RUNS, 2)} ms throughput: {round(total_tokens / total_time, 2)} tok/s"
+    )
     return output_host
@@ -141,15 +145,16 @@ if __name__ == "__main__":
     print("Test Qwen3 MoE")
     print("*" * 130)
     print(f"Test Case PREFILL_TESTCASES : {PREFILL_TESTCASES}")
-    input_prefill = generate_moe_input_torch(PREFILL_TESTCASES)
-    output_prefill = benchmark_moe_torch(moe, input_prefill, device=device, dtype=dtype)
+    output_prefill = benchmark_moe_torch(
+        moe, PREFILL_TESTCASES, device=device, dtype=dtype
+    )
     print("\n")
     print("-" * 130)
     print(f"\nTest DECODE_TESTCASES: {DECODE_TESTCASES}")
-    input_decode = generate_moe_input_torch(DECODE_TESTCASES)
-    output_decode = benchmark_moe_torch(moe, input_decode, device=device, dtype=dtype)
+    output_decode = benchmark_moe_torch(
+        moe, DECODE_TESTCASES, device=device, dtype=dtype
+    )
     # clean up device memory
     del moe
......
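benchmark_moe_torch now takes the test case itself, generates the input internally, and reports token throughput alongside per-run latency. A condensed sketch of the resulting flow, assuming a testcase dict with a "seqlens" list as in the diff; WARMUPS and RUNS stand in for the benchmark module's constants, and the helper names and values here are placeholders:

import time
import torch

WARMUPS, RUNS = 10, 100  # placeholder values, not the module's real constants

def benchmark_module(module, testcase, generate_input, device, dtype):
    input_device = generate_input(testcase, dtype=dtype).to(device=device)
    for _ in range(WARMUPS):
        module(input_device)                  # warm up kernels, caches, allocator
    torch.cuda.synchronize()
    start_time = time.time()
    for _ in range(RUNS):
        module(input_device)
    torch.cuda.synchronize()
    total_time = time.time() - start_time
    total_tokens = sum(testcase["seqlens"]) * RUNS  # every run processes all sequences
    print(f"latency: {round(total_time * 1000 / RUNS, 2)} ms, "
          f"throughput: {round(total_tokens / total_time, 2)} tok/s")

Moving input generation inside the benchmark also shrinks each call site in __main__ to a single call per test case, as the last hunk shows.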