Unverified Commit 57201a6a authored by Xin Yang's avatar Xin Yang Committed by GitHub
Browse files

Fix rotary embedding benchmark script (#28323)


Signed-off-by: default avatarXin Yang <xyangx@amazon.com>
parent f2d9ad06
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from itertools import accumulate import itertools
import nvtx
import torch import torch
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding, get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.platforms import current_platform from vllm.triton_utils import triton
from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
batch_size_range = [2**i for i in range(0, 8, 2)]
seq_len_range = [2**i for i in range(6, 10, 1)]
num_heads_range = [32, 48]
configs = list(itertools.product(batch_size_range, seq_len_range, num_heads_range))
def benchmark_rope_kernels_multi_lora(
is_neox_style: bool, def get_benchmark(head_size, rotary_dim, is_neox_style, device):
batch_size: int, @triton.testing.perf_report(
seq_len: int, triton.testing.Benchmark(
num_heads: int, x_names=["batch_size", "seq_len", "num_heads"],
head_size: int, x_vals=[list(_) for _ in configs],
rotary_dim: int | None, line_arg="provider",
dtype: torch.dtype, line_vals=["torch", "flashinfer", "vllm"],
seed: int, line_names=["PyTorch", "FlashInfer", "vLLM"],
device: str, styles=[("blue", "-"), ("green", "-"), ("red", "-")],
max_position: int = 8192, ylabel="us",
base: float = 10000, plot_name=f"rope-perf{'-neox-style' if is_neox_style else ''}",
) -> None: args={},
current_platform.seed_everything(seed) )
torch.set_default_device(device)
if rotary_dim is None:
rotary_dim = head_size
# silulating serving 4 LoRAs
scaling_factors = [1, 2, 4, 8]
# batched RoPE can take multiple scaling factors
batched_rope = get_rope(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
{"rope_type": "linear", "factor": tuple(scaling_factors)},
) )
# non-batched RoPE takes only one scaling factor, we create multiple def benchmark(batch_size, seq_len, num_heads, provider):
# instances to simulate the same behavior dtype = torch.bfloat16
non_batched_ropes: list[RotaryEmbedding] = [] max_position = 8192
for scaling_factor in scaling_factors: base = 10000
non_batched_ropes.append( rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style)
get_rope( rope = rope.to(dtype=dtype, device=device)
head_size, cos_sin_cache = rope.cos_sin_cache.to(dtype=torch.float, device=device)
rotary_dim,
max_position, positions = torch.randint(0, max_position, (batch_size, seq_len), device=device)
base, query = torch.randn(
is_neox_style, (batch_size, seq_len, num_heads * head_size), dtype=dtype, device=device
{"rope_type": "linear", "factor": (scaling_factor,)},
)
) )
key = torch.randn_like(query)
positions = torch.randint(0, max_position, (batch_size, seq_len)) quantiles = [0.5, 0.2, 0.8]
query = torch.randn(batch_size, seq_len, num_heads * head_size, dtype=dtype)
key = torch.randn_like(query)
# create query offsets for batched RoPE, we concat multiple kv cache if provider == "torch":
# together and each query needs to find the right kv cache of its type ms, min_ms, max_ms = triton.testing.do_bench(
offset_map = torch.tensor( lambda: rope.forward_native(positions, query.clone(), key.clone()),
list( quantiles=quantiles,
accumulate(
[0]
+ [
max_position * scaling_factor * 2
for scaling_factor in scaling_factors[:-1]
]
) )
) elif provider == "flashinfer":
) ms, min_ms, max_ms = triton.testing.do_bench(
query_types = torch.randint( lambda: torch.ops.vllm.flashinfer_rotary_embedding(
0, len(scaling_factors), (batch_size, seq_len), device=device positions,
) query.clone(),
# map query types to offsets key.clone(),
query_offsets = offset_map[query_types] head_size,
# the kernel takes flattened offsets cos_sin_cache,
flatten_offsets = query_offsets.flatten() is_neox_style,
),
quantiles=quantiles,
)
else:
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: rope.forward_cuda(positions, query.clone(), key.clone()),
quantiles=quantiles,
)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
# batched queries of the same type together for non-batched RoPE return benchmark
queries = [query[query_types == i] for i in range(len(scaling_factors))]
keys = [key[query_types == i] for i in range(len(scaling_factors))]
packed_qkr = zip(queries, keys, non_batched_ropes)
# synchronize before start timing
torch.cuda.synchronize()
with nvtx.annotate("non-batched", color="yellow"):
for q, k, r in packed_qkr:
r.forward(positions, q, k)
torch.cuda.synchronize()
with nvtx.annotate("batched", color="green"):
batched_rope.forward(positions, query, key, flatten_offsets)
torch.cuda.synchronize()
if __name__ == "__main__": if __name__ == "__main__":
...@@ -116,17 +95,12 @@ if __name__ == "__main__": ...@@ -116,17 +95,12 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"--device", type=str, choices=["cuda:0", "cuda:1"], default="cuda:0" "--device", type=str, choices=["cuda:0", "cuda:1"], default="cuda:0"
) )
parser.add_argument("--save-path", type=str, default="./configs/rope/")
args = parser.parse_args() args = parser.parse_args()
print(args)
benchmark_rope_kernels_multi_lora( # Get the benchmark function
is_neox_style=args.is_neox_style, benchmark = get_benchmark(
batch_size=args.batch_size, args.head_size, args.rotary_dim, args.is_neox_style, args.device
seq_len=args.seq_len,
num_heads=args.num_heads,
head_size=args.head_size,
rotary_dim=args.rotary_dim,
dtype=getattr(torch, args.dtype),
seed=args.seed,
device=args.device,
) )
# Run performance benchmark
benchmark.run(print_data=True, save_path=args.save_path)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment