Unverified Commit e34cf6ad authored by fzyzcjy's avatar fzyzcjy Committed by GitHub
Browse files

Fix bench script making input data on L2 cache (#7739)

parent 62222bd2
...@@ -205,9 +205,9 @@ def benchmark(batch_size, seq_len, group_size, dst_dtype, provider): ...@@ -205,9 +205,9 @@ def benchmark(batch_size, seq_len, group_size, dst_dtype, provider):
quantiles = [0.5, 0.2, 0.8] quantiles = [0.5, 0.2, 0.8]
if provider == "triton": if provider == "triton":
fn = lambda: triton_per_token_group_quant_8bit(x.clone(), group_size, dst_dtype) fn = lambda: triton_per_token_group_quant_8bit(x, group_size, dst_dtype)
elif provider == "sglang": elif provider == "sglang":
fn = lambda: sglang_per_token_group_quant_8bit(x.clone(), group_size, dst_dtype) fn = lambda: sglang_per_token_group_quant_8bit(x, group_size, dst_dtype)
ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles) ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment