Unverified Commit 01d6d400 authored by Jiashi Li's avatar Jiashi Li Committed by GitHub
Browse files

Merge pull request #45 from yangsijia-serena/main

fix(benchmark): store 'compare' and 'one' perf results in csv files and visualize them
parents 6492cabb b6798030
# MLA Triton kernel is from: https://github.com/monellz/vllm/commit/feebaa7c063be6bfb590a876741aeef1c5f58cf8#diff-7b2e1c9032522f7266051b9887246a65753871dfb3625a258fee40109fe6e87a # MLA Triton kernel is from: https://github.com/monellz/vllm/commit/feebaa7c063be6bfb590a876741aeef1c5f58cf8#diff-7b2e1c9032522f7266051b9887246a65753871dfb3625a258fee40109fe6e87a
import argparse
import math import math
import random import random
import flashinfer
import torch import torch
import triton import triton
import triton.language as tl import triton.language as tl
import argparse
# pip install flashinfer-python # pip install flashinfer-python
from flash_mla import get_mla_metadata, flash_mla_with_kvcache from flash_mla import flash_mla_with_kvcache, get_mla_metadata
import flashinfer
def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
query = query.float() query = query.float()
...@@ -443,6 +444,7 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal ...@@ -443,6 +444,7 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8)
print(f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10 ** 9 / perf_a:.0f} TFLOPS, {bytes / 10 ** 6 / perf_a:.0f} GB/s") print(f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10 ** 9 / perf_a:.0f} TFLOPS, {bytes / 10 ** 6 / perf_a:.0f} GB/s")
print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s") print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s")
return bytes / 10 ** 6 / perf_a, bytes / 10 ** 6 / perf_b
def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
...@@ -501,7 +503,8 @@ def get_args(): ...@@ -501,7 +503,8 @@ def get_args():
if __name__ == "__main__": if __name__ == "__main__":
args = get_args() args = get_args()
with open("all_perf.csv", "w") as fout: benchmark_type = "all" if args.all else f"{args.baseline}_vs_{args.target}" if args.compare else args.target
with open(f"{benchmark_type}_perf.csv", "w") as fout:
fout.write("name,batch,seqlen,head,bw\n") fout.write("name,batch,seqlen,head,bw\n")
for shape in shape_configs: for shape in shape_configs:
if args.all: if args.all:
...@@ -509,6 +512,9 @@ if __name__ == "__main__": ...@@ -509,6 +512,9 @@ if __name__ == "__main__":
perf = compare_a(target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"]) perf = compare_a(target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"])
fout.write(f'{target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n') fout.write(f'{target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n')
elif args.compare: elif args.compare:
compare_ab(args.baseline, args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"]) perfa, prefb = compare_ab(args.baseline, args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"])
fout.write(f'{args.baseline},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perfa:.0f}\n')
fout.write(f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{prefb:.0f}\n')
elif args.one: elif args.one:
compare_a(args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"]) perf = compare_a(args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"])
\ No newline at end of file fout.write(f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n')
\ No newline at end of file
import argparse
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import pandas as pd import pandas as pd
file_path = 'all_perf.csv'
def parse_args():
parser = argparse.ArgumentParser(description='Visualize benchmark results')
parser.add_argument('--file', type=str, default='all_perf.csv',
help='Path to the CSV file with benchmark results (default: all_perf.csv)')
return parser.parse_args()
args = parse_args()
file_path = args.file
df = pd.read_csv(file_path) df = pd.read_csv(file_path)
...@@ -16,4 +26,4 @@ plt.xlabel('seqlen') ...@@ -16,4 +26,4 @@ plt.xlabel('seqlen')
plt.ylabel('bw (GB/s)') plt.ylabel('bw (GB/s)')
plt.legend() plt.legend()
plt.savefig('bandwidth_vs_seqlen.png') plt.savefig(f'{file_path.split(".")[0].split("/")[-1]}_bandwidth_vs_seqlen.png')
\ No newline at end of file \ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment