Unverified Commit 0ca3b8e0 authored by Aleksandr Malyshev's avatar Aleksandr Malyshev Committed by GitHub
Browse files

[BUGFIX] Skip tokenization support for throughput benchmark (#12712)


Signed-off-by: default avatarroot <root@banff-cyxtera-s73-5.ctr.dcgpu>
Signed-off-by: default avatarAleksandr Malyshev <maleksan@amd.com>
Co-authored-by: default avatarroot <root@banff-cyxtera-s73-5.ctr.dcgpu>
Co-authored-by: default avatarAleksandr Malyshev <maleksan@amd.com>
parent cc102814
...@@ -7,7 +7,7 @@ import os ...@@ -7,7 +7,7 @@ import os
import random import random
import time import time
from functools import cache from functools import cache
from typing import Any, Optional from typing import Any, Optional, Union
import torch import torch
import uvloop import uvloop
...@@ -20,7 +20,7 @@ from transformers import (AutoModelForCausalLM, AutoTokenizer, ...@@ -20,7 +20,7 @@ from transformers import (AutoModelForCausalLM, AutoTokenizer,
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.entrypoints.openai.api_server import ( from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args) build_async_engine_client_from_engine_args)
from vllm.inputs import TextPrompt from vllm.inputs import TextPrompt, TokensPrompt
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.lora.utils import get_adapter_absolute_path from vllm.lora.utils import get_adapter_absolute_path
from vllm.multimodal import MultiModalDataDict from vllm.multimodal import MultiModalDataDict
...@@ -178,10 +178,13 @@ def run_vllm( ...@@ -178,10 +178,13 @@ def run_vllm(
"Please ensure that max_model_len is greater than the sum of" "Please ensure that max_model_len is greater than the sum of"
" prompt_len and expected_output_len for all requests.") " prompt_len and expected_output_len for all requests.")
# Add the requests to the engine. # Add the requests to the engine.
prompts: list[TextPrompt] = [] prompts: list[Union[TextPrompt, TokensPrompt]] = []
sampling_params: list[SamplingParams] = [] sampling_params: list[SamplingParams] = []
for request in requests: for request in requests:
prompts.append( prompts.append(
TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"],
multi_modal_data=request.multi_modal_data)
if "prompt_token_ids" in request.prompt else \
TextPrompt(prompt=request.prompt, TextPrompt(prompt=request.prompt,
multi_modal_data=request.multi_modal_data)) multi_modal_data=request.multi_modal_data))
sampling_params.append( sampling_params.append(
...@@ -242,11 +245,14 @@ async def run_vllm_async( ...@@ -242,11 +245,14 @@ async def run_vllm_async(
" prompt_len and expected_output_len for all requests.") " prompt_len and expected_output_len for all requests.")
# Add the requests to the engine. # Add the requests to the engine.
prompts: list[TextPrompt] = [] prompts: list[Union[TextPrompt, TokensPrompt]] = []
sampling_params: list[SamplingParams] = [] sampling_params: list[SamplingParams] = []
lora_requests: list[Optional[LoRARequest]] = [] lora_requests: list[Optional[LoRARequest]] = []
for request in requests: for request in requests:
prompts.append( prompts.append(
TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"],
multi_modal_data=request.multi_modal_data)
if "prompt_token_ids" in request.prompt else \
TextPrompt(prompt=request.prompt, TextPrompt(prompt=request.prompt,
multi_modal_data=request.multi_modal_data)) multi_modal_data=request.multi_modal_data))
sampling_params.append( sampling_params.append(
...@@ -393,24 +399,29 @@ def main(args: argparse.Namespace): ...@@ -393,24 +399,29 @@ def main(args: argparse.Namespace):
random.randint(0, vocab_size - 1) random.randint(0, vocab_size - 1)
for _ in range(args.input_len) for _ in range(args.input_len)
] ]
# As tokenizer may add additional tokens like BOS, we need to try
# different lengths to get the desired input length. candidate_prompt = {"prompt_token_ids": candidate_ids}
for _ in range(5): # Max attempts to correct
candidate_prompt = request_tokenizer.decode(candidate_ids) if not args.skip_tokenizer_init:
tokenized_len = len(request_tokenizer.encode(candidate_prompt)) # As tokenizer may add additional tokens like BOS, we need
# to try different lengths to get the desired input length.
if tokenized_len == args.input_len: for _ in range(5): # Max attempts to correct
break candidate_prompt = request_tokenizer.decode(candidate_ids)
tokenized_len = len(
# Adjust length based on difference request_tokenizer.encode(candidate_prompt))
diff = args.input_len - tokenized_len
if diff > 0: if tokenized_len == args.input_len:
candidate_ids.extend([ break
random.randint(100, vocab_size - 100)
for _ in range(diff) # Adjust length based on difference
]) diff = args.input_len - tokenized_len
else: if diff > 0:
candidate_ids = candidate_ids[:diff] candidate_ids.extend([
random.randint(100, vocab_size - 100)
for _ in range(diff)
])
else:
candidate_ids = candidate_ids[:diff]
requests.append( requests.append(
SampleRequest(prompt=candidate_prompt, SampleRequest(prompt=candidate_prompt,
prompt_len=args.input_len, prompt_len=args.input_len,
......
...@@ -276,7 +276,9 @@ class EngineArgs: ...@@ -276,7 +276,9 @@ class EngineArgs:
parser.add_argument( parser.add_argument(
'--skip-tokenizer-init', '--skip-tokenizer-init',
action='store_true', action='store_true',
help='Skip initialization of tokenizer and detokenizer.') help='Skip initialization of tokenizer and detokenizer. '
'Expects valid prompt_token_ids and None for prompt from '
'the input. The generated output will contain token ids.')
parser.add_argument( parser.add_argument(
'--revision', '--revision',
type=nullable_str, type=nullable_str,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment