Unverified Commit 98356735 authored by Varun Sundar Rabindranath's avatar Varun Sundar Rabindranath Committed by GitHub
Browse files

[misc] benchmark_throughput : Add LoRA (#11267)


Signed-off-by: default avatarVarun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: default avatarVarun Sundar Rabindranath <varun@neuralmagic.com>
parent f26c4aee
...@@ -4,7 +4,8 @@ import dataclasses ...@@ -4,7 +4,8 @@ import dataclasses
import json import json
import random import random
import time import time
from typing import List, Optional from functools import cache
from typing import Dict, List, Optional, Tuple
import torch import torch
import uvloop import uvloop
...@@ -17,8 +18,11 @@ from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs ...@@ -17,8 +18,11 @@ from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.entrypoints.openai.api_server import ( from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args) build_async_engine_client_from_engine_args)
from vllm.inputs import TextPrompt from vllm.inputs import TextPrompt
from vllm.lora.request import LoRARequest
from vllm.lora.utils import get_adapter_absolute_path
from vllm.multimodal import MultiModalDataDict from vllm.multimodal import MultiModalDataDict
from vllm.sampling_params import BeamSearchParams from vllm.sampling_params import BeamSearchParams
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer
from vllm.utils import FlexibleArgumentParser, merge_async_iterators from vllm.utils import FlexibleArgumentParser, merge_async_iterators
...@@ -28,15 +32,17 @@ class SampleRequest: ...@@ -28,15 +32,17 @@ class SampleRequest:
Attributes: Attributes:
prompt: The input text prompt for the model. prompt: The input text prompt for the model.
multi_modal_data: Optional dictionary containing multi-modal data (e.g.
images).
prompt_len: The length of the prompt in tokens. prompt_len: The length of the prompt in tokens.
expected_output_len: The expected length of the output in tokens. expected_output_len: The expected length of the output in tokens.
multi_modal_data: Optional dictionary containing multi-modal data (e.g.
images).
lora_request: Optional LoRARequest specifying the LoRA to use.
""" """
prompt: str prompt: str
prompt_len: int prompt_len: int
expected_output_len: int expected_output_len: int
multi_modal_data: Optional[MultiModalDataDict] = None multi_modal_data: Optional[MultiModalDataDict] = None
lora_request: Optional[LoRARequest] = None
def _get_prompt_for_image_model(question: str, *, model: str) -> str: def _get_prompt_for_image_model(question: str, *, model: str) -> str:
...@@ -60,8 +66,30 @@ def _get_prompt_for_image_model(question: str, *, model: str) -> str: ...@@ -60,8 +66,30 @@ def _get_prompt_for_image_model(question: str, *, model: str) -> str:
raise ValueError(f"Unsupported model {model}") raise ValueError(f"Unsupported model {model}")
@cache
def lora_path_on_disk(lora_path: str) -> str:
return get_adapter_absolute_path(lora_path)
lora_tokenizer_cache: Dict[int, AnyTokenizer] = {}
def get_random_lora_request(
args: argparse.Namespace
) -> Tuple[LoRARequest, Optional[AnyTokenizer]]:
global lora_tokenizer_cache
lora_id = random.randint(1, args.max_loras)
lora_request = LoRARequest(lora_name=str(lora_id),
lora_int_id=lora_id,
lora_path=lora_path_on_disk(args.lora_path))
if lora_id not in lora_tokenizer_cache:
lora_tokenizer_cache[lora_id] = get_lora_tokenizer(lora_request)
return lora_request, lora_tokenizer_cache[lora_id]
def sample_requests(tokenizer: PreTrainedTokenizerBase, def sample_requests(tokenizer: PreTrainedTokenizerBase,
args: argparse.Namespace) -> List[SampleRequest]: args: argparse.Namespace) -> List[SampleRequest]:
dataset_path: str = args.dataset dataset_path: str = args.dataset
num_requests: int = args.num_prompts num_requests: int = args.num_prompts
fixed_output_len: Optional[int] = args.output_len fixed_output_len: Optional[int] = args.output_len
...@@ -79,7 +107,9 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase, ...@@ -79,7 +107,9 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
# Filter out sequences that are too long or too short # Filter out sequences that are too long or too short
filtered_dataset: List[SampleRequest] = [] filtered_dataset: List[SampleRequest] = []
for data in dataset: for data in tqdm(dataset,
total=len(filtered_dataset),
desc="sampling requests"):
if len(filtered_dataset) == num_requests: if len(filtered_dataset) == num_requests:
break break
...@@ -102,9 +132,16 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase, ...@@ -102,9 +132,16 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
continue continue
prompt = _get_prompt_for_image_model(question=prompt, model=model) prompt = _get_prompt_for_image_model(question=prompt, model=model)
request_tokenizer = tokenizer
lora_request: Optional[LoRARequest] = None
if args.enable_lora:
lora_request, lora_tokenizer = get_random_lora_request(args)
if lora_tokenizer:
request_tokenizer = lora_tokenizer
# Tokenize the prompts and completions. # Tokenize the prompts and completions.
prompt_token_ids = tokenizer(prompt).input_ids prompt_token_ids = request_tokenizer(prompt).input_ids
completion_token_ids = tokenizer(completion).input_ids completion_token_ids = request_tokenizer(completion).input_ids
prompt_len = len(prompt_token_ids) prompt_len = len(prompt_token_ids)
output_len = len(completion_token_ids output_len = len(completion_token_ids
) if fixed_output_len is None else fixed_output_len ) if fixed_output_len is None else fixed_output_len
...@@ -118,7 +155,8 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase, ...@@ -118,7 +155,8 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
SampleRequest(prompt=prompt, SampleRequest(prompt=prompt,
prompt_len=prompt_len, prompt_len=prompt_len,
expected_output_len=output_len, expected_output_len=output_len,
multi_modal_data=multi_modal_data)) multi_modal_data=multi_modal_data,
lora_request=lora_request))
return filtered_dataset return filtered_dataset
...@@ -146,14 +184,21 @@ def run_vllm( ...@@ -146,14 +184,21 @@ def run_vllm(
ignore_eos=True, ignore_eos=True,
max_tokens=request.expected_output_len, max_tokens=request.expected_output_len,
)) ))
lora_requests: Optional[List[LoRARequest]] = None
if engine_args.enable_lora:
lora_requests = [request.lora_request for request in requests]
use_beam_search = False use_beam_search = False
if not use_beam_search: if not use_beam_search:
start = time.perf_counter() start = time.perf_counter()
llm.generate(prompts, sampling_params, use_tqdm=True) llm.generate(prompts,
sampling_params,
lora_request=lora_requests,
use_tqdm=True)
end = time.perf_counter() end = time.perf_counter()
else: else:
assert lora_requests is None, "BeamSearch API does not support LoRA"
prompts = [request.prompt for request in requests] prompts = [request.prompt for request in requests]
# output_len should be the same for all requests. # output_len should be the same for all requests.
output_len = requests[0][2] output_len = requests[0][2]
...@@ -185,6 +230,7 @@ async def run_vllm_async( ...@@ -185,6 +230,7 @@ async def run_vllm_async(
# Add the requests to the engine. # Add the requests to the engine.
prompts: List[TextPrompt] = [] prompts: List[TextPrompt] = []
sampling_params: List[SamplingParams] = [] sampling_params: List[SamplingParams] = []
lora_requests: List[Optional[LoRARequest]] = []
for request in requests: for request in requests:
prompts.append( prompts.append(
TextPrompt(prompt=request.prompt, TextPrompt(prompt=request.prompt,
...@@ -197,11 +243,16 @@ async def run_vllm_async( ...@@ -197,11 +243,16 @@ async def run_vllm_async(
ignore_eos=True, ignore_eos=True,
max_tokens=request.expected_output_len, max_tokens=request.expected_output_len,
)) ))
lora_requests.append(request.lora_request)
generators = [] generators = []
start = time.perf_counter() start = time.perf_counter()
for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)): for i, (prompt, sp,
generator = llm.generate(prompt, sp, request_id=f"test{i}") lr) in enumerate(zip(prompts, sampling_params, lora_requests)):
generator = llm.generate(prompt,
sp,
lora_request=lr,
request_id=f"test{i}")
generators.append(generator) generators.append(generator)
all_gens = merge_async_iterators(*generators) all_gens = merge_async_iterators(*generators)
async for i, res in all_gens: async for i, res in all_gens:
...@@ -297,6 +348,14 @@ def main(args: argparse.Namespace): ...@@ -297,6 +348,14 @@ def main(args: argparse.Namespace):
vocab_size = tokenizer.vocab_size vocab_size = tokenizer.vocab_size
requests = [] requests = []
for _ in range(args.num_prompts): for _ in range(args.num_prompts):
request_tokenizer = tokenizer
lora_request: Optional[LoRARequest] = None
if args.enable_lora:
lora_request, lora_tokenizer = get_random_lora_request(args)
if lora_tokenizer:
request_tokenizer = lora_tokenizer
# Synthesize a prompt with the given input length. # Synthesize a prompt with the given input length.
candidate_ids = [ candidate_ids = [
random.randint(0, vocab_size - 1) random.randint(0, vocab_size - 1)
...@@ -305,8 +364,8 @@ def main(args: argparse.Namespace): ...@@ -305,8 +364,8 @@ def main(args: argparse.Namespace):
# As tokenizer may add additional tokens like BOS, we need to try # As tokenizer may add additional tokens like BOS, we need to try
# different lengths to get the desired input length. # different lengths to get the desired input length.
for _ in range(5): # Max attempts to correct for _ in range(5): # Max attempts to correct
candidate_prompt = tokenizer.decode(candidate_ids) candidate_prompt = request_tokenizer.decode(candidate_ids)
tokenized_len = len(tokenizer.encode(candidate_prompt)) tokenized_len = len(request_tokenizer.encode(candidate_prompt))
if tokenized_len == args.input_len: if tokenized_len == args.input_len:
break break
...@@ -323,7 +382,8 @@ def main(args: argparse.Namespace): ...@@ -323,7 +382,8 @@ def main(args: argparse.Namespace):
requests.append( requests.append(
SampleRequest(prompt=candidate_prompt, SampleRequest(prompt=candidate_prompt,
prompt_len=args.input_len, prompt_len=args.input_len,
expected_output_len=args.output_len)) expected_output_len=args.output_len,
lora_request=lora_request))
else: else:
requests = sample_requests(tokenizer, args) requests = sample_requests(tokenizer, args)
...@@ -422,6 +482,14 @@ if __name__ == "__main__": ...@@ -422,6 +482,14 @@ if __name__ == "__main__":
action='store_true', action='store_true',
default=False, default=False,
help="Disable decoupled async engine frontend.") help="Disable decoupled async engine frontend.")
# LoRA
parser.add_argument(
"--lora-path",
type=str,
default=None,
help="Path to the lora adapters to use. This can be an absolute path, "
"a relative path, or a Hugging Face model identifier.")
parser = AsyncEngineArgs.add_cli_args(parser) parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args() args = parser.parse_args()
if args.tokenizer is None: if args.tokenizer is None:
...@@ -431,6 +499,8 @@ if __name__ == "__main__": ...@@ -431,6 +499,8 @@ if __name__ == "__main__":
assert args.output_len is not None assert args.output_len is not None
else: else:
assert args.input_len is None assert args.input_len is None
if args.enable_lora:
assert args.lora_path is not None
if args.backend == "vllm": if args.backend == "vllm":
if args.hf_max_batch_size is not None: if args.hf_max_batch_size is not None:
...@@ -440,6 +510,9 @@ if __name__ == "__main__": ...@@ -440,6 +510,9 @@ if __name__ == "__main__":
raise ValueError("HF max batch size is required for HF backend.") raise ValueError("HF max batch size is required for HF backend.")
if args.quantization is not None: if args.quantization is not None:
raise ValueError("Quantization is only for vLLM backend.") raise ValueError("Quantization is only for vLLM backend.")
if args.enable_lora is not None:
raise ValueError("LoRA benchmarking is only supported for vLLM"
" backend")
elif args.backend == "mii": elif args.backend == "mii":
if args.dtype != "auto": if args.dtype != "auto":
raise ValueError("dtype must be auto for MII backend.") raise ValueError("dtype must be auto for MII backend.")
...@@ -452,4 +525,7 @@ if __name__ == "__main__": ...@@ -452,4 +525,7 @@ if __name__ == "__main__":
if args.tokenizer != args.model: if args.tokenizer != args.model:
raise ValueError("Tokenizer must be the same as the model for MII " raise ValueError("Tokenizer must be the same as the model for MII "
"backend.") "backend.")
if args.enable_lora is not None:
raise ValueError("LoRA benchmarking is only supported for vLLM"
" backend")
main(args) main(args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment