"vscode:/vscode.git/clone" did not exist on "d1c6799b8870e513bf4f2305cbf6cda9fc3d773b"
Commit 7de30014 authored by zhuwenwen
Browse files

update perf

parent 2cfcb974
...@@ -34,6 +34,7 @@ class RequestFuncInput: ...@@ -34,6 +34,7 @@ class RequestFuncInput:
multi_modal_content: Optional[dict | list[dict]] = None multi_modal_content: Optional[dict | list[dict]] = None
ignore_eos: bool = False ignore_eos: bool = False
language: Optional[str] = None language: Optional[str] = None
request_id: Optional[str] = None
@dataclass @dataclass
...@@ -71,6 +72,9 @@ async def async_request_tgi( ...@@ -71,6 +72,9 @@ async def async_request_tgi(
"inputs": request_func_input.prompt, "inputs": request_func_input.prompt,
"parameters": params, "parameters": params,
} }
headers = None
if request_func_input.request_id:
headers = {"x-request-id": request_func_input.request_id}
output = RequestFuncOutput() output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len output.prompt_len = request_func_input.prompt_len
if request_func_input.ignore_eos: if request_func_input.ignore_eos:
...@@ -82,7 +86,9 @@ async def async_request_tgi( ...@@ -82,7 +86,9 @@ async def async_request_tgi(
st = time.perf_counter() st = time.perf_counter()
most_recent_timestamp = st most_recent_timestamp = st
try: try:
async with session.post(url=api_url, json=payload) as response: async with session.post(
url=api_url, json=payload, headers=headers
) as response:
if response.status == 200: if response.status == 200:
async for chunk_bytes in response.content: async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip() chunk_bytes = chunk_bytes.strip()
...@@ -145,6 +151,9 @@ async def async_request_trt_llm( ...@@ -145,6 +151,9 @@ async def async_request_trt_llm(
} }
if request_func_input.ignore_eos: if request_func_input.ignore_eos:
payload["min_length"] = request_func_input.output_len payload["min_length"] = request_func_input.output_len
headers = None
if request_func_input.request_id:
headers = {"x-request-id": request_func_input.request_id}
output = RequestFuncOutput() output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len output.prompt_len = request_func_input.prompt_len
...@@ -152,7 +161,9 @@ async def async_request_trt_llm( ...@@ -152,7 +161,9 @@ async def async_request_trt_llm(
st = time.perf_counter() st = time.perf_counter()
most_recent_timestamp = st most_recent_timestamp = st
try: try:
async with session.post(url=api_url, json=payload) as response: async with session.post(
url=api_url, json=payload, headers=headers
) as response:
if response.status == 200: if response.status == 200:
async for chunk_bytes in response.content: async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip() chunk_bytes = chunk_bytes.strip()
...@@ -211,6 +222,8 @@ async def async_request_deepspeed_mii( ...@@ -211,6 +222,8 @@ async def async_request_deepspeed_mii(
"top_p": 1.0, "top_p": 1.0,
} }
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
if request_func_input.request_id:
headers["x-request-id"] = request_func_input.request_id
output = RequestFuncOutput() output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len output.prompt_len = request_func_input.prompt_len
...@@ -283,6 +296,8 @@ async def async_request_openai_completions( ...@@ -283,6 +296,8 @@ async def async_request_openai_completions(
if request_func_input.extra_body: if request_func_input.extra_body:
payload.update(request_func_input.extra_body) payload.update(request_func_input.extra_body)
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
if request_func_input.request_id:
headers["x-request-id"] = request_func_input.request_id
output = RequestFuncOutput() output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len output.prompt_len = request_func_input.prompt_len
...@@ -395,6 +410,8 @@ async def async_request_openai_chat_completions( ...@@ -395,6 +410,8 @@ async def async_request_openai_chat_completions(
"Content-Type": "application/json", "Content-Type": "application/json",
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
} }
if request_func_input.request_id:
headers["x-request-id"] = request_func_input.request_id
output = RequestFuncOutput() output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len output.prompt_len = request_func_input.prompt_len
...@@ -491,6 +508,8 @@ async def async_request_openai_audio( ...@@ -491,6 +508,8 @@ async def async_request_openai_audio(
headers = { headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
} }
if request_func_input.request_id:
headers["x-request-id"] = request_func_input.request_id
# Send audio file # Send audio file
def to_bytes(y, sr): def to_bytes(y, sr):
......
...@@ -19,6 +19,7 @@ import logging ...@@ -19,6 +19,7 @@ import logging
import random import random
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Mapping from collections.abc import Mapping
from copy import deepcopy
from dataclasses import dataclass from dataclasses import dataclass
from functools import cache from functools import cache
from io import BytesIO from io import BytesIO
...@@ -54,6 +55,7 @@ class SampleRequest: ...@@ -54,6 +55,7 @@ class SampleRequest:
expected_output_len: int expected_output_len: int
multi_modal_data: Optional[Union[MultiModalDataDict, dict, list[dict]]] = None multi_modal_data: Optional[Union[MultiModalDataDict, dict, list[dict]]] = None
lora_request: Optional[LoRARequest] = None lora_request: Optional[LoRARequest] = None
request_id: Optional[str] = None
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
...@@ -155,7 +157,10 @@ class BenchmarkDataset(ABC): ...@@ -155,7 +157,10 @@ class BenchmarkDataset(ABC):
@abstractmethod @abstractmethod
def sample( def sample(
self, tokenizer: PreTrainedTokenizerBase, num_requests: int self,
tokenizer: PreTrainedTokenizerBase,
num_requests: int,
request_id_prefix: str = "",
) -> list[SampleRequest]: ) -> list[SampleRequest]:
""" """
Abstract method to generate sample requests from the dataset. Abstract method to generate sample requests from the dataset.
...@@ -167,6 +172,7 @@ class BenchmarkDataset(ABC): ...@@ -167,6 +172,7 @@ class BenchmarkDataset(ABC):
tokenizer (PreTrainedTokenizerBase): The tokenizer to be used tokenizer (PreTrainedTokenizerBase): The tokenizer to be used
for processing the dataset's text. for processing the dataset's text.
num_requests (int): The number of sample requests to generate. num_requests (int): The number of sample requests to generate.
request_id_prefix (str): The prefix of request_id.
Returns: Returns:
list[SampleRequest]: A list of sample requests generated from the list[SampleRequest]: A list of sample requests generated from the
...@@ -175,7 +181,10 @@ class BenchmarkDataset(ABC): ...@@ -175,7 +181,10 @@ class BenchmarkDataset(ABC):
raise NotImplementedError("sample must be implemented in subclasses.") raise NotImplementedError("sample must be implemented in subclasses.")
def maybe_oversample_requests( def maybe_oversample_requests(
self, requests: list[SampleRequest], num_requests: int self,
requests: list[SampleRequest],
num_requests: int,
request_id_prefix: str = "",
) -> None: ) -> None:
""" """
Oversamples the list of requests if its size is less than the desired Oversamples the list of requests if its size is less than the desired
...@@ -183,11 +192,18 @@ class BenchmarkDataset(ABC): ...@@ -183,11 +192,18 @@ class BenchmarkDataset(ABC):
Args: Args:
requests (List[SampleRequest]): The current list of sampled requests (List[SampleRequest]): The current list of sampled
requests. num_requests (int): The target number of requests. requests.
num_requests (int): The target number of requests.
request_id_prefix (str): The prefix of the request ids.
""" """
if len(requests) < num_requests: if len(requests) < num_requests:
random.seed(self.random_seed) random.seed(self.random_seed)
additional = random.choices(requests, k=num_requests - len(requests)) additional = deepcopy(
random.choices(requests, k=num_requests - len(requests))
)
for i in range(len(additional)):
req = additional[i]
req.request_id = request_id_prefix + str(len(requests) + i)
requests.extend(additional) requests.extend(additional)
logger.info("Oversampled requests to reach %d total samples.", num_requests) logger.info("Oversampled requests to reach %d total samples.", num_requests)
...@@ -277,6 +293,41 @@ def process_image(image: Any) -> Mapping[str, Any]: ...@@ -277,6 +293,41 @@ def process_image(image: Any) -> Mapping[str, Any]:
) )
def process_video(video: Any) -> Mapping[str, Any]:
    """
    Process a single video input and return a multimedia content dictionary.

    Supports the following input types:

    1. Dictionary with raw video bytes: - Expects a dict with a 'bytes' key
       containing raw video data. The bytes are base64-encoded and embedded
       in a ``data:video/mp4;base64,...`` URL.

    2. String input: - Treats the string as a URL or local file path. -
       Prepends "file://" if the string doesn't start with "http://",
       "https://" or "file://". - Returns a dictionary with the video URL.

    Args:
        video: Either a dict holding raw bytes under the 'bytes' key, or a
            string with a remote URL / local path.

    Returns:
        A dict of the form
        ``{"type": "video_url", "video_url": {"url": <url>}}``.

    Raises:
        ValueError: If the input is not a supported type.
    """
    if isinstance(video, dict) and "bytes" in video:
        video_bytes = video["bytes"]
        video_base64 = base64.b64encode(video_bytes).decode("utf-8")
        return {
            "type": "video_url",
            "video_url": {"url": f"data:video/mp4;base64,{video_base64}"},
        }

    if isinstance(video, str):
        # "https://" must be recognised too; otherwise a secure remote URL
        # would be mistaken for a local path and turned into
        # "file://https://...".
        video_url = (
            video
            if video.startswith(("http://", "https://", "file://"))
            else f"file://{video}"
        )
        return {"type": "video_url", "video_url": {"url": video_url}}

    raise ValueError(
        f"Invalid video input {video}. Must be a string of local path/remote url, or a dictionary with raw video bytes in the form of `{{'bytes': raw_video_bytes}}`."  # noqa: E501
    )
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Random Dataset Implementation (Synthetic Data) # Random Dataset Implementation (Synthetic Data)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
...@@ -303,6 +354,7 @@ class RandomDataset(BenchmarkDataset): ...@@ -303,6 +354,7 @@ class RandomDataset(BenchmarkDataset):
range_ratio: float = DEFAULT_RANGE_RATIO, range_ratio: float = DEFAULT_RANGE_RATIO,
input_len: int = DEFAULT_INPUT_LEN, input_len: int = DEFAULT_INPUT_LEN,
output_len: int = DEFAULT_OUTPUT_LEN, output_len: int = DEFAULT_OUTPUT_LEN,
request_id_prefix: str = "",
**kwargs, **kwargs,
) -> list[SampleRequest]: ) -> list[SampleRequest]:
# Enforce range_ratio < 1 # Enforce range_ratio < 1
...@@ -363,8 +415,10 @@ class RandomDataset(BenchmarkDataset): ...@@ -363,8 +415,10 @@ class RandomDataset(BenchmarkDataset):
prompt=prompt, prompt=prompt,
prompt_len=total_input_len, prompt_len=total_input_len,
expected_output_len=int(output_lens[i]), expected_output_len=int(output_lens[i]),
request_id=request_id_prefix + str(i),
) )
) )
return requests return requests
...@@ -406,9 +460,11 @@ class ShareGPTDataset(BenchmarkDataset): ...@@ -406,9 +460,11 @@ class ShareGPTDataset(BenchmarkDataset):
max_loras: Optional[int] = None, max_loras: Optional[int] = None,
output_len: Optional[int] = None, output_len: Optional[int] = None,
enable_multimodal_chat: bool = False, enable_multimodal_chat: bool = False,
request_id_prefix: str = "",
**kwargs, **kwargs,
) -> list: ) -> list:
samples: list = [] samples: list = []
ind = 0
for entry in self.data: for entry in self.data:
if len(samples) >= num_requests: if len(samples) >= num_requests:
break break
...@@ -430,9 +486,10 @@ class ShareGPTDataset(BenchmarkDataset): ...@@ -430,9 +486,10 @@ class ShareGPTDataset(BenchmarkDataset):
skip_min_output_len_check=output_len is not None, skip_min_output_len_check=output_len is not None,
): ):
continue continue
# TODO: Also support ShareGPT4Video.
if image_path := entry.get("image"): if image_path := entry.get("image"):
mm_content = process_image(image_path) mm_content = process_image(image_path)
elif video_path := entry.get("video"):
mm_content = process_video(video_path)
else: else:
mm_content = None mm_content = None
if enable_multimodal_chat: if enable_multimodal_chat:
...@@ -444,9 +501,11 @@ class ShareGPTDataset(BenchmarkDataset): ...@@ -444,9 +501,11 @@ class ShareGPTDataset(BenchmarkDataset):
expected_output_len=new_output_len, expected_output_len=new_output_len,
lora_request=lora_request, lora_request=lora_request,
multi_modal_data=mm_content, multi_modal_data=mm_content,
request_id=request_id_prefix + str(ind),
) )
) )
self.maybe_oversample_requests(samples, num_requests) ind += 1
self.maybe_oversample_requests(samples, num_requests, request_id_prefix)
return samples return samples
...@@ -512,10 +571,11 @@ class CustomDataset(BenchmarkDataset): ...@@ -512,10 +571,11 @@ class CustomDataset(BenchmarkDataset):
output_len: Optional[int] = None, output_len: Optional[int] = None,
enable_multimodal_chat: bool = False, enable_multimodal_chat: bool = False,
skip_chat_template: bool = False, skip_chat_template: bool = False,
request_id_prefix: str = "",
**kwargs, **kwargs,
) -> list: ) -> list:
sampled_requests = [] sampled_requests = []
for item in self.data: for i, item in enumerate(self.data):
if len(sampled_requests) >= num_requests: if len(sampled_requests) >= num_requests:
break break
prompt = item["prompt"] prompt = item["prompt"]
...@@ -534,9 +594,12 @@ class CustomDataset(BenchmarkDataset): ...@@ -534,9 +594,12 @@ class CustomDataset(BenchmarkDataset):
prompt=prompt, prompt=prompt,
prompt_len=prompt_len, prompt_len=prompt_len,
expected_output_len=output_len, expected_output_len=output_len,
request_id=request_id_prefix + str(i),
)
) )
self.maybe_oversample_requests(
sampled_requests, num_requests, request_id_prefix
) )
self.maybe_oversample_requests(sampled_requests, num_requests)
return sampled_requests return sampled_requests
...@@ -578,6 +641,7 @@ class SonnetDataset(BenchmarkDataset): ...@@ -578,6 +641,7 @@ class SonnetDataset(BenchmarkDataset):
input_len: int = DEFAULT_INPUT_LEN, input_len: int = DEFAULT_INPUT_LEN,
output_len: int = DEFAULT_OUTPUT_LEN, output_len: int = DEFAULT_OUTPUT_LEN,
return_prompt_formatted: bool = False, return_prompt_formatted: bool = False,
request_id_prefix: str = "",
**kwargs, **kwargs,
) -> list: ) -> list:
# Calculate average token length for a poem line. # Calculate average token length for a poem line.
...@@ -603,6 +667,7 @@ class SonnetDataset(BenchmarkDataset): ...@@ -603,6 +667,7 @@ class SonnetDataset(BenchmarkDataset):
prefix_lines = self.data[:num_prefix_lines] prefix_lines = self.data[:num_prefix_lines]
samples = [] samples = []
ind = 0
while len(samples) < num_requests: while len(samples) < num_requests:
extra_lines = random.choices( extra_lines = random.choices(
self.data, k=num_input_lines - num_prefix_lines self.data, k=num_input_lines - num_prefix_lines
...@@ -613,14 +678,17 @@ class SonnetDataset(BenchmarkDataset): ...@@ -613,14 +678,17 @@ class SonnetDataset(BenchmarkDataset):
msg, add_generation_prompt=True, tokenize=False msg, add_generation_prompt=True, tokenize=False
) )
prompt_len = len(tokenizer(prompt_formatted).input_ids) prompt_len = len(tokenizer(prompt_formatted).input_ids)
if prompt_len <= input_len: if prompt_len <= input_len:
samples.append( samples.append(
SampleRequest( SampleRequest(
prompt=prompt_formatted if return_prompt_formatted else prompt, prompt=prompt_formatted if return_prompt_formatted else prompt,
prompt_len=prompt_len, prompt_len=prompt_len,
expected_output_len=output_len, expected_output_len=output_len,
request_id=request_id_prefix + str(ind),
) )
) )
ind += 1
return samples return samples
...@@ -672,6 +740,7 @@ class BurstGPTDataset(BenchmarkDataset): ...@@ -672,6 +740,7 @@ class BurstGPTDataset(BenchmarkDataset):
num_requests: int, num_requests: int,
max_loras: Optional[int] = None, max_loras: Optional[int] = None,
lora_path: Optional[str] = None, lora_path: Optional[str] = None,
request_id_prefix: str = "",
**kwargs, **kwargs,
) -> list[SampleRequest]: ) -> list[SampleRequest]:
samples = [] samples = []
...@@ -693,6 +762,7 @@ class BurstGPTDataset(BenchmarkDataset): ...@@ -693,6 +762,7 @@ class BurstGPTDataset(BenchmarkDataset):
prompt_len=input_len, prompt_len=input_len,
expected_output_len=output_len, expected_output_len=output_len,
lora_request=lora_req, lora_request=lora_req,
request_id=request_id_prefix + str(i),
) )
) )
return samples return samples
...@@ -752,12 +822,14 @@ class ConversationDataset(HuggingFaceDataset): ...@@ -752,12 +822,14 @@ class ConversationDataset(HuggingFaceDataset):
num_requests: int, num_requests: int,
output_len: Optional[int] = None, output_len: Optional[int] = None,
enable_multimodal_chat: bool = False, enable_multimodal_chat: bool = False,
request_id_prefix: str = "",
**kwargs, **kwargs,
) -> list: ) -> list:
# Filter examples with at least 2 conversations # Filter examples with at least 2 conversations
filtered_data = self.data.filter(lambda x: len(x["conversations"]) >= 2) filtered_data = self.data.filter(lambda x: len(x["conversations"]) >= 2)
sampled_requests = [] sampled_requests = []
dynamic_output = output_len is None dynamic_output = output_len is None
ind = 0
for item in filtered_data: for item in filtered_data:
if len(sampled_requests) >= num_requests: if len(sampled_requests) >= num_requests:
...@@ -785,9 +857,13 @@ class ConversationDataset(HuggingFaceDataset): ...@@ -785,9 +857,13 @@ class ConversationDataset(HuggingFaceDataset):
prompt_len=prompt_len, prompt_len=prompt_len,
expected_output_len=output_len, expected_output_len=output_len,
multi_modal_data=mm_content, multi_modal_data=mm_content,
request_id=request_id_prefix + str(ind),
) )
) )
self.maybe_oversample_requests(sampled_requests, num_requests) ind += 1
self.maybe_oversample_requests(
sampled_requests, num_requests, request_id_prefix
)
return sampled_requests return sampled_requests
...@@ -814,11 +890,12 @@ class VisionArenaDataset(HuggingFaceDataset): ...@@ -814,11 +890,12 @@ class VisionArenaDataset(HuggingFaceDataset):
num_requests: int, num_requests: int,
output_len: Optional[int] = None, output_len: Optional[int] = None,
enable_multimodal_chat: bool = False, enable_multimodal_chat: bool = False,
request_id_prefix: str = "",
**kwargs, **kwargs,
) -> list: ) -> list:
output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
sampled_requests = [] sampled_requests = []
for item in self.data: for i, item in enumerate(self.data):
if len(sampled_requests) >= num_requests: if len(sampled_requests) >= num_requests:
break break
parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path) parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path)
...@@ -838,9 +915,12 @@ class VisionArenaDataset(HuggingFaceDataset): ...@@ -838,9 +915,12 @@ class VisionArenaDataset(HuggingFaceDataset):
prompt_len=prompt_len, prompt_len=prompt_len,
expected_output_len=output_len, expected_output_len=output_len,
multi_modal_data=mm_content, multi_modal_data=mm_content,
request_id=request_id_prefix + str(i),
)
) )
self.maybe_oversample_requests(
sampled_requests, num_requests, request_id_prefix
) )
self.maybe_oversample_requests(sampled_requests, num_requests)
return sampled_requests return sampled_requests
...@@ -870,15 +950,18 @@ class InstructCoderDataset(HuggingFaceDataset): ...@@ -870,15 +950,18 @@ class InstructCoderDataset(HuggingFaceDataset):
num_requests: int, num_requests: int,
output_len: Optional[int] = None, output_len: Optional[int] = None,
enable_multimodal_chat: bool = False, enable_multimodal_chat: bool = False,
request_id_prefix: str = "",
**kwargs, **kwargs,
) -> list: ) -> list:
output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
sampled_requests = [] sampled_requests = []
for item in self.data: for i, item in enumerate(self.data):
if len(sampled_requests) >= num_requests: if len(sampled_requests) >= num_requests:
break break
prompt = f"{item['input']}\n\n{item['instruction']} Just output \ prompt = (
the code, do not include any explanation." f"{item['input']}\n\n{item['instruction']} Just output "
"the code, do not include any explanation."
)
# apply template # apply template
prompt = tokenizer.apply_chat_template( prompt = tokenizer.apply_chat_template(
...@@ -892,9 +975,12 @@ class InstructCoderDataset(HuggingFaceDataset): ...@@ -892,9 +975,12 @@ class InstructCoderDataset(HuggingFaceDataset):
prompt=prompt, prompt=prompt,
prompt_len=prompt_len, prompt_len=prompt_len,
expected_output_len=output_len, expected_output_len=output_len,
request_id=request_id_prefix + str(i),
) )
) )
self.maybe_oversample_requests(sampled_requests, num_requests) self.maybe_oversample_requests(
sampled_requests, num_requests, request_id_prefix
)
return sampled_requests return sampled_requests
...@@ -924,12 +1010,13 @@ class MTBenchDataset(HuggingFaceDataset): ...@@ -924,12 +1010,13 @@ class MTBenchDataset(HuggingFaceDataset):
num_requests: int, num_requests: int,
output_len: Optional[int] = None, output_len: Optional[int] = None,
enable_multimodal_chat: bool = False, enable_multimodal_chat: bool = False,
request_id_prefix: str = "",
**kwargs, **kwargs,
) -> list: ) -> list:
output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
sampled_requests = [] sampled_requests = []
for item in self.data: for i, item in enumerate(self.data):
if len(sampled_requests) >= num_requests: if len(sampled_requests) >= num_requests:
break break
prompt = item["turns"][0] prompt = item["turns"][0]
...@@ -947,9 +1034,12 @@ class MTBenchDataset(HuggingFaceDataset): ...@@ -947,9 +1034,12 @@ class MTBenchDataset(HuggingFaceDataset):
prompt=prompt, prompt=prompt,
prompt_len=prompt_len, prompt_len=prompt_len,
expected_output_len=output_len, expected_output_len=output_len,
request_id=request_id_prefix + str(i),
)
) )
self.maybe_oversample_requests(
sampled_requests, num_requests, request_id_prefix
) )
self.maybe_oversample_requests(sampled_requests, num_requests)
return sampled_requests return sampled_requests
...@@ -974,10 +1064,12 @@ class AIMODataset(HuggingFaceDataset): ...@@ -974,10 +1064,12 @@ class AIMODataset(HuggingFaceDataset):
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
num_requests: int, num_requests: int,
output_len: Optional[int] = None, output_len: Optional[int] = None,
request_id_prefix: str = "",
**kwargs, **kwargs,
) -> list: ) -> list:
sampled_requests = [] sampled_requests = []
dynamic_output = output_len is None dynamic_output = output_len is None
ind = 0
for item in self.data: for item in self.data:
if len(sampled_requests) >= num_requests: if len(sampled_requests) >= num_requests:
...@@ -1000,9 +1092,13 @@ class AIMODataset(HuggingFaceDataset): ...@@ -1000,9 +1092,13 @@ class AIMODataset(HuggingFaceDataset):
prompt_len=prompt_len, prompt_len=prompt_len,
expected_output_len=output_len, expected_output_len=output_len,
multi_modal_data=None, multi_modal_data=None,
request_id=request_id_prefix + str(ind),
) )
) )
self.maybe_oversample_requests(sampled_requests, num_requests) ind += 1
self.maybe_oversample_requests(
sampled_requests, num_requests, request_id_prefix
)
return sampled_requests return sampled_requests
...@@ -1072,12 +1168,18 @@ class NextEditPredictionDataset(HuggingFaceDataset): ...@@ -1072,12 +1168,18 @@ class NextEditPredictionDataset(HuggingFaceDataset):
"zed-industries/zeta": _format_zeta_prompt, "zed-industries/zeta": _format_zeta_prompt,
} }
def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int, **kwargs): def sample(
self,
tokenizer: PreTrainedTokenizerBase,
num_requests: int,
request_id_prefix: str = "",
**kwargs,
):
formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get(self.dataset_path) formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get(self.dataset_path)
if formatting_prompt_func is None: if formatting_prompt_func is None:
raise ValueError(f"Unsupported dataset path: {self.dataset_path}") raise ValueError(f"Unsupported dataset path: {self.dataset_path}")
samples = [] samples = []
for sample in self.data: for i, sample in enumerate(self.data):
sample = formatting_prompt_func(sample) sample = formatting_prompt_func(sample)
samples.append( samples.append(
SampleRequest( SampleRequest(
...@@ -1086,11 +1188,12 @@ class NextEditPredictionDataset(HuggingFaceDataset): ...@@ -1086,11 +1188,12 @@ class NextEditPredictionDataset(HuggingFaceDataset):
expected_output_len=len( expected_output_len=len(
tokenizer(sample["expected_output"]).input_ids tokenizer(sample["expected_output"]).input_ids
), ),
request_id=request_id_prefix + str(i),
) )
) )
if len(samples) >= num_requests: if len(samples) >= num_requests:
break break
self.maybe_oversample_requests(samples, num_requests) self.maybe_oversample_requests(samples, num_requests, request_id_prefix)
return samples return samples
...@@ -1139,6 +1242,7 @@ class ASRDataset(HuggingFaceDataset): ...@@ -1139,6 +1242,7 @@ class ASRDataset(HuggingFaceDataset):
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
num_requests: int, num_requests: int,
output_len: Optional[int] = None, output_len: Optional[int] = None,
request_id_prefix: str = "",
**kwargs, **kwargs,
) -> list: ) -> list:
import librosa import librosa
...@@ -1148,6 +1252,7 @@ class ASRDataset(HuggingFaceDataset): ...@@ -1148,6 +1252,7 @@ class ASRDataset(HuggingFaceDataset):
prompt_len = len(tokenizer(prompt).input_ids) prompt_len = len(tokenizer(prompt).input_ids)
sampled_requests = [] sampled_requests = []
skipped = 0 skipped = 0
ind = 0
for item in self.data: for item in self.data:
if len(sampled_requests) >= num_requests: if len(sampled_requests) >= num_requests:
break break
...@@ -1166,8 +1271,10 @@ class ASRDataset(HuggingFaceDataset): ...@@ -1166,8 +1271,10 @@ class ASRDataset(HuggingFaceDataset):
prompt_len=prompt_len, prompt_len=prompt_len,
expected_output_len=output_len, expected_output_len=output_len,
multi_modal_data=mm_content, multi_modal_data=mm_content,
request_id=request_id_prefix + str(ind),
) )
) )
ind += 1
if skipped: if skipped:
logger.warning( logger.warning(
"%d samples discarded from dataset due to" "%d samples discarded from dataset due to"
...@@ -1175,5 +1282,7 @@ class ASRDataset(HuggingFaceDataset): ...@@ -1175,5 +1282,7 @@ class ASRDataset(HuggingFaceDataset):
" what Whisper supports.", " what Whisper supports.",
skipped, skipped,
) )
self.maybe_oversample_requests(sampled_requests, num_requests) self.maybe_oversample_requests(
sampled_requests, num_requests, request_id_prefix
)
return sampled_requests return sampled_requests
\ No newline at end of file
...@@ -489,8 +489,10 @@ class BenchmarkWorker: ...@@ -489,8 +489,10 @@ class BenchmarkWorker:
) )
# NOTE(woosuk): The current naming convention uses w2.shape[2], which # NOTE(woosuk): The current naming convention uses w2.shape[2], which
# is the intermediate size after silu_and_mul. # is the intermediate size after silu_and_mul.
block_n = block_quant_shape[0] if block_quant_shape else None
block_k = block_quant_shape[1] if block_quant_shape else None
op_config = get_moe_configs( op_config = get_moe_configs(
num_experts, shard_intermediate_size // 2, dtype_str, use_nn_moe=nn_moe num_experts, shard_intermediate_size // 2, dtype_str, block_n, block_k, use_nn_moe=nn_moe
) )
if op_config is None: if op_config is None:
config = get_default_config( config = get_default_config(
...@@ -500,8 +502,8 @@ class BenchmarkWorker: ...@@ -500,8 +502,8 @@ class BenchmarkWorker:
hidden_size, hidden_size,
topk, topk,
dtype_str, dtype_str,
is_marlin=False, block_quant_shape,
use_nn_moe=nn_moe use_nn_moe=nn_moe,
) )
else: else:
config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))] config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))]
......
...@@ -375,11 +375,12 @@ async def benchmark( ...@@ -375,11 +375,12 @@ async def benchmark(
rps_change_events.append({"rps": rps_val, "timestamp": timestamp}) rps_change_events.append({"rps": rps_val, "timestamp": timestamp})
last_int_rps = current_int_rps last_int_rps = current_int_rps
prompt, prompt_len, output_len, mm_content = ( prompt, prompt_len, output_len, mm_content, request_id = (
request.prompt, request.prompt,
request.prompt_len, request.prompt_len,
request.expected_output_len, request.expected_output_len,
request.multi_modal_data, request.multi_modal_data,
request.request_id,
) )
req_model_id, req_model_name = model_id, model_name req_model_id, req_model_name = model_id, model_name
if lora_modules: if lora_modules:
...@@ -397,6 +398,7 @@ async def benchmark( ...@@ -397,6 +398,7 @@ async def benchmark(
multi_modal_content=mm_content, multi_modal_content=mm_content,
ignore_eos=ignore_eos, ignore_eos=ignore_eos,
extra_body=extra_body, extra_body=extra_body,
request_id=request_id,
) )
task = limited_request_func(request_func_input=request_func_input, pbar=pbar) task = limited_request_func(request_func_input=request_func_input, pbar=pbar)
tasks.append(asyncio.create_task(task)) tasks.append(asyncio.create_task(task))
...@@ -665,6 +667,7 @@ def main(args: argparse.Namespace): ...@@ -665,6 +667,7 @@ def main(args: argparse.Namespace):
tokenizer=tokenizer, tokenizer=tokenizer,
output_len=args.custom_output_len, output_len=args.custom_output_len,
skip_chat_template=args.custom_skip_chat_template, skip_chat_template=args.custom_skip_chat_template,
request_id_prefix=args.request_id_prefix,
) )
elif args.dataset_name == "sonnet": elif args.dataset_name == "sonnet":
...@@ -678,6 +681,7 @@ def main(args: argparse.Namespace): ...@@ -678,6 +681,7 @@ def main(args: argparse.Namespace):
prefix_len=args.sonnet_prefix_len, prefix_len=args.sonnet_prefix_len,
tokenizer=tokenizer, tokenizer=tokenizer,
return_prompt_formatted=False, return_prompt_formatted=False,
request_id_prefix=args.request_id_prefix,
) )
else: else:
assert tokenizer.chat_template or tokenizer.default_chat_template, ( assert tokenizer.chat_template or tokenizer.default_chat_template, (
...@@ -690,6 +694,7 @@ def main(args: argparse.Namespace): ...@@ -690,6 +694,7 @@ def main(args: argparse.Namespace):
prefix_len=args.sonnet_prefix_len, prefix_len=args.sonnet_prefix_len,
tokenizer=tokenizer, tokenizer=tokenizer,
return_prompt_formatted=True, return_prompt_formatted=True,
request_id_prefix=args.request_id_prefix,
) )
elif args.dataset_name == "hf": elif args.dataset_name == "hf":
...@@ -751,6 +756,7 @@ def main(args: argparse.Namespace): ...@@ -751,6 +756,7 @@ def main(args: argparse.Namespace):
num_requests=args.num_prompts, num_requests=args.num_prompts,
tokenizer=tokenizer, tokenizer=tokenizer,
output_len=args.hf_output_len, output_len=args.hf_output_len,
request_id_prefix=args.request_id_prefix,
) )
else: else:
...@@ -762,10 +768,15 @@ def main(args: argparse.Namespace): ...@@ -762,10 +768,15 @@ def main(args: argparse.Namespace):
tokenizer=tokenizer, tokenizer=tokenizer,
num_requests=args.num_prompts, num_requests=args.num_prompts,
output_len=args.sharegpt_output_len, output_len=args.sharegpt_output_len,
request_id_prefix=args.request_id_prefix,
), ),
"burstgpt": lambda: BurstGPTDataset( "burstgpt": lambda: BurstGPTDataset(
random_seed=args.seed, dataset_path=args.dataset_path random_seed=args.seed, dataset_path=args.dataset_path
).sample(tokenizer=tokenizer, num_requests=args.num_prompts), ).sample(
tokenizer=tokenizer,
num_requests=args.num_prompts,
request_id_prefix=args.request_id_prefix,
),
"random": lambda: RandomDataset(dataset_path=args.dataset_path).sample( "random": lambda: RandomDataset(dataset_path=args.dataset_path).sample(
tokenizer=tokenizer, tokenizer=tokenizer,
num_requests=args.num_prompts, num_requests=args.num_prompts,
...@@ -773,6 +784,7 @@ def main(args: argparse.Namespace): ...@@ -773,6 +784,7 @@ def main(args: argparse.Namespace):
input_len=args.random_input_len, input_len=args.random_input_len,
output_len=args.random_output_len, output_len=args.random_output_len,
range_ratio=args.random_range_ratio, range_ratio=args.random_range_ratio,
request_id_prefix=args.request_id_prefix,
), ),
} }
...@@ -1118,6 +1130,13 @@ def create_argument_parser(): ...@@ -1118,6 +1130,13 @@ def create_argument_parser():
"goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 " "goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 "
"and the blog: https://hao-ai-lab.github.io/blogs/distserve", "and the blog: https://hao-ai-lab.github.io/blogs/distserve",
) )
parser.add_argument(
"--request-id-prefix",
type=str,
required=False,
default="benchmark-serving",
help="Specify the prefix of request id.",
)
# group for dataset specific arguments # group for dataset specific arguments
custom_group = parser.add_argument_group("custom dataset options") custom_group = parser.add_argument_group("custom dataset options")
......
...@@ -150,7 +150,6 @@ def run_vllm( ...@@ -150,7 +150,6 @@ def run_vllm(
end = time.perf_counter() end = time.perf_counter()
else: else:
assert lora_requests is None, "BeamSearch API does not support LoRA" assert lora_requests is None, "BeamSearch API does not support LoRA"
prompts = [request.prompt for request in requests]
# output_len should be the same for all requests. # output_len should be the same for all requests.
output_len = requests[0].expected_output_len output_len = requests[0].expected_output_len
for request in requests: for request in requests:
...@@ -653,8 +652,8 @@ def validate_args(args): ...@@ -653,8 +652,8 @@ def validate_args(args):
# https://github.com/vllm-project/vllm/issues/16222 # https://github.com/vllm-project/vllm/issues/16222
if args.data_parallel_size > 1: if args.data_parallel_size > 1:
raise ValueError( raise ValueError(
"Data parallel is not supported in offline benchmark, \ "Data parallel is not supported in offline benchmark, "
please use benchmark serving instead" "please use benchmark serving instead"
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment