Unverified Commit 31436e8b authored by hustxiayang's avatar hustxiayang Committed by GitHub
Browse files

[Misc] Add request_id into benchmark_serve.py (#23065)


Signed-off-by: default avataryangxia <yangxiast@gmail.com>
parent 4efd43e9
...@@ -34,6 +34,7 @@ class RequestFuncInput: ...@@ -34,6 +34,7 @@ class RequestFuncInput:
multi_modal_content: Optional[dict | list[dict]] = None multi_modal_content: Optional[dict | list[dict]] = None
ignore_eos: bool = False ignore_eos: bool = False
language: Optional[str] = None language: Optional[str] = None
request_id: Optional[str] = None
@dataclass @dataclass
...@@ -71,6 +72,9 @@ async def async_request_tgi( ...@@ -71,6 +72,9 @@ async def async_request_tgi(
"inputs": request_func_input.prompt, "inputs": request_func_input.prompt,
"parameters": params, "parameters": params,
} }
headers = None
if request_func_input.request_id:
headers = {"x-request-id": request_func_input.request_id}
output = RequestFuncOutput() output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len output.prompt_len = request_func_input.prompt_len
if request_func_input.ignore_eos: if request_func_input.ignore_eos:
...@@ -82,7 +86,9 @@ async def async_request_tgi( ...@@ -82,7 +86,9 @@ async def async_request_tgi(
st = time.perf_counter() st = time.perf_counter()
most_recent_timestamp = st most_recent_timestamp = st
try: try:
async with session.post(url=api_url, json=payload) as response: async with session.post(
url=api_url, json=payload, headers=headers
) as response:
if response.status == 200: if response.status == 200:
async for chunk_bytes in response.content: async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip() chunk_bytes = chunk_bytes.strip()
...@@ -145,6 +151,9 @@ async def async_request_trt_llm( ...@@ -145,6 +151,9 @@ async def async_request_trt_llm(
} }
if request_func_input.ignore_eos: if request_func_input.ignore_eos:
payload["min_length"] = request_func_input.output_len payload["min_length"] = request_func_input.output_len
headers = None
if request_func_input.request_id:
headers = {"x-request-id": request_func_input.request_id}
output = RequestFuncOutput() output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len output.prompt_len = request_func_input.prompt_len
...@@ -152,7 +161,9 @@ async def async_request_trt_llm( ...@@ -152,7 +161,9 @@ async def async_request_trt_llm(
st = time.perf_counter() st = time.perf_counter()
most_recent_timestamp = st most_recent_timestamp = st
try: try:
async with session.post(url=api_url, json=payload) as response: async with session.post(
url=api_url, json=payload, headers=headers
) as response:
if response.status == 200: if response.status == 200:
async for chunk_bytes in response.content: async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip() chunk_bytes = chunk_bytes.strip()
...@@ -211,6 +222,8 @@ async def async_request_deepspeed_mii( ...@@ -211,6 +222,8 @@ async def async_request_deepspeed_mii(
"top_p": 1.0, "top_p": 1.0,
} }
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
if request_func_input.request_id:
headers["x-request-id"] = request_func_input.request_id
output = RequestFuncOutput() output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len output.prompt_len = request_func_input.prompt_len
...@@ -283,6 +296,8 @@ async def async_request_openai_completions( ...@@ -283,6 +296,8 @@ async def async_request_openai_completions(
if request_func_input.extra_body: if request_func_input.extra_body:
payload.update(request_func_input.extra_body) payload.update(request_func_input.extra_body)
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
if request_func_input.request_id:
headers["x-request-id"] = request_func_input.request_id
output = RequestFuncOutput() output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len output.prompt_len = request_func_input.prompt_len
...@@ -395,6 +410,8 @@ async def async_request_openai_chat_completions( ...@@ -395,6 +410,8 @@ async def async_request_openai_chat_completions(
"Content-Type": "application/json", "Content-Type": "application/json",
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
} }
if request_func_input.request_id:
headers["x-request-id"] = request_func_input.request_id
output = RequestFuncOutput() output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len output.prompt_len = request_func_input.prompt_len
...@@ -491,6 +508,8 @@ async def async_request_openai_audio( ...@@ -491,6 +508,8 @@ async def async_request_openai_audio(
headers = { headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
} }
if request_func_input.request_id:
headers["x-request-id"] = request_func_input.request_id
# Send audio file # Send audio file
def to_bytes(y, sr): def to_bytes(y, sr):
......
...@@ -19,6 +19,7 @@ import logging ...@@ -19,6 +19,7 @@ import logging
import random import random
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Mapping from collections.abc import Mapping
from copy import deepcopy
from dataclasses import dataclass from dataclasses import dataclass
from functools import cache from functools import cache
from io import BytesIO from io import BytesIO
...@@ -54,6 +55,7 @@ class SampleRequest: ...@@ -54,6 +55,7 @@ class SampleRequest:
expected_output_len: int expected_output_len: int
multi_modal_data: Optional[Union[MultiModalDataDict, dict, list[dict]]] = None multi_modal_data: Optional[Union[MultiModalDataDict, dict, list[dict]]] = None
lora_request: Optional[LoRARequest] = None lora_request: Optional[LoRARequest] = None
request_id: Optional[str] = None
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
...@@ -155,7 +157,10 @@ class BenchmarkDataset(ABC): ...@@ -155,7 +157,10 @@ class BenchmarkDataset(ABC):
@abstractmethod @abstractmethod
def sample( def sample(
self, tokenizer: PreTrainedTokenizerBase, num_requests: int self,
tokenizer: PreTrainedTokenizerBase,
num_requests: int,
request_id_prefix: str = "",
) -> list[SampleRequest]: ) -> list[SampleRequest]:
""" """
Abstract method to generate sample requests from the dataset. Abstract method to generate sample requests from the dataset.
...@@ -167,6 +172,7 @@ class BenchmarkDataset(ABC): ...@@ -167,6 +172,7 @@ class BenchmarkDataset(ABC):
tokenizer (PreTrainedTokenizerBase): The tokenizer to be used tokenizer (PreTrainedTokenizerBase): The tokenizer to be used
for processing the dataset's text. for processing the dataset's text.
num_requests (int): The number of sample requests to generate. num_requests (int): The number of sample requests to generate.
request_id_prefix (str) The prefix of request_id.
Returns: Returns:
list[SampleRequest]: A list of sample requests generated from the list[SampleRequest]: A list of sample requests generated from the
...@@ -175,7 +181,10 @@ class BenchmarkDataset(ABC): ...@@ -175,7 +181,10 @@ class BenchmarkDataset(ABC):
raise NotImplementedError("sample must be implemented in subclasses.") raise NotImplementedError("sample must be implemented in subclasses.")
def maybe_oversample_requests( def maybe_oversample_requests(
self, requests: list[SampleRequest], num_requests: int self,
requests: list[SampleRequest],
num_requests: int,
request_id_prefix: str = "",
) -> None: ) -> None:
""" """
Oversamples the list of requests if its size is less than the desired Oversamples the list of requests if its size is less than the desired
...@@ -183,11 +192,18 @@ class BenchmarkDataset(ABC): ...@@ -183,11 +192,18 @@ class BenchmarkDataset(ABC):
Args: Args:
requests (List[SampleRequest]): The current list of sampled requests (List[SampleRequest]): The current list of sampled
requests. num_requests (int): The target number of requests. requests.
num_requests (int): The target number of requests.
request_id_prefix (str) The prefix of the request ids.
""" """
if len(requests) < num_requests: if len(requests) < num_requests:
random.seed(self.random_seed) random.seed(self.random_seed)
additional = random.choices(requests, k=num_requests - len(requests)) additional = deepcopy(
random.choices(requests, k=num_requests - len(requests))
)
for i in range(len(additional)):
req = additional[i]
req.request_id = request_id_prefix + str(len(requests) + i)
requests.extend(additional) requests.extend(additional)
logger.info("Oversampled requests to reach %d total samples.", num_requests) logger.info("Oversampled requests to reach %d total samples.", num_requests)
...@@ -303,6 +319,7 @@ class RandomDataset(BenchmarkDataset): ...@@ -303,6 +319,7 @@ class RandomDataset(BenchmarkDataset):
range_ratio: float = DEFAULT_RANGE_RATIO, range_ratio: float = DEFAULT_RANGE_RATIO,
input_len: int = DEFAULT_INPUT_LEN, input_len: int = DEFAULT_INPUT_LEN,
output_len: int = DEFAULT_OUTPUT_LEN, output_len: int = DEFAULT_OUTPUT_LEN,
request_id_prefix: str = "",
**kwargs, **kwargs,
) -> list[SampleRequest]: ) -> list[SampleRequest]:
# Enforce range_ratio < 1 # Enforce range_ratio < 1
...@@ -363,8 +380,10 @@ class RandomDataset(BenchmarkDataset): ...@@ -363,8 +380,10 @@ class RandomDataset(BenchmarkDataset):
prompt=prompt, prompt=prompt,
prompt_len=total_input_len, prompt_len=total_input_len,
expected_output_len=int(output_lens[i]), expected_output_len=int(output_lens[i]),
request_id=request_id_prefix + str(i),
) )
) )
return requests return requests
...@@ -406,9 +425,11 @@ class ShareGPTDataset(BenchmarkDataset): ...@@ -406,9 +425,11 @@ class ShareGPTDataset(BenchmarkDataset):
max_loras: Optional[int] = None, max_loras: Optional[int] = None,
output_len: Optional[int] = None, output_len: Optional[int] = None,
enable_multimodal_chat: bool = False, enable_multimodal_chat: bool = False,
request_id_prefix: str = "",
**kwargs, **kwargs,
) -> list: ) -> list:
samples: list = [] samples: list = []
ind = 0
for entry in self.data: for entry in self.data:
if len(samples) >= num_requests: if len(samples) >= num_requests:
break break
...@@ -444,9 +465,11 @@ class ShareGPTDataset(BenchmarkDataset): ...@@ -444,9 +465,11 @@ class ShareGPTDataset(BenchmarkDataset):
expected_output_len=new_output_len, expected_output_len=new_output_len,
lora_request=lora_request, lora_request=lora_request,
multi_modal_data=mm_content, multi_modal_data=mm_content,
request_id=request_id_prefix + str(ind),
) )
) )
self.maybe_oversample_requests(samples, num_requests) ind += 1
self.maybe_oversample_requests(samples, num_requests, request_id_prefix)
return samples return samples
...@@ -512,10 +535,11 @@ class CustomDataset(BenchmarkDataset): ...@@ -512,10 +535,11 @@ class CustomDataset(BenchmarkDataset):
output_len: Optional[int] = None, output_len: Optional[int] = None,
enable_multimodal_chat: bool = False, enable_multimodal_chat: bool = False,
skip_chat_template: bool = False, skip_chat_template: bool = False,
request_id_prefix: str = "",
**kwargs, **kwargs,
) -> list: ) -> list:
sampled_requests = [] sampled_requests = []
for item in self.data: for i, item in enumerate(self.data):
if len(sampled_requests) >= num_requests: if len(sampled_requests) >= num_requests:
break break
prompt = item["prompt"] prompt = item["prompt"]
...@@ -534,9 +558,12 @@ class CustomDataset(BenchmarkDataset): ...@@ -534,9 +558,12 @@ class CustomDataset(BenchmarkDataset):
prompt=prompt, prompt=prompt,
prompt_len=prompt_len, prompt_len=prompt_len,
expected_output_len=output_len, expected_output_len=output_len,
request_id=request_id_prefix + str(i),
) )
) )
self.maybe_oversample_requests(sampled_requests, num_requests) self.maybe_oversample_requests(
sampled_requests, num_requests, request_id_prefix
)
return sampled_requests return sampled_requests
...@@ -578,6 +605,7 @@ class SonnetDataset(BenchmarkDataset): ...@@ -578,6 +605,7 @@ class SonnetDataset(BenchmarkDataset):
input_len: int = DEFAULT_INPUT_LEN, input_len: int = DEFAULT_INPUT_LEN,
output_len: int = DEFAULT_OUTPUT_LEN, output_len: int = DEFAULT_OUTPUT_LEN,
return_prompt_formatted: bool = False, return_prompt_formatted: bool = False,
request_id_prefix: str = "",
**kwargs, **kwargs,
) -> list: ) -> list:
# Calculate average token length for a poem line. # Calculate average token length for a poem line.
...@@ -603,6 +631,7 @@ class SonnetDataset(BenchmarkDataset): ...@@ -603,6 +631,7 @@ class SonnetDataset(BenchmarkDataset):
prefix_lines = self.data[:num_prefix_lines] prefix_lines = self.data[:num_prefix_lines]
samples = [] samples = []
ind = 0
while len(samples) < num_requests: while len(samples) < num_requests:
extra_lines = random.choices( extra_lines = random.choices(
self.data, k=num_input_lines - num_prefix_lines self.data, k=num_input_lines - num_prefix_lines
...@@ -613,14 +642,17 @@ class SonnetDataset(BenchmarkDataset): ...@@ -613,14 +642,17 @@ class SonnetDataset(BenchmarkDataset):
msg, add_generation_prompt=True, tokenize=False msg, add_generation_prompt=True, tokenize=False
) )
prompt_len = len(tokenizer(prompt_formatted).input_ids) prompt_len = len(tokenizer(prompt_formatted).input_ids)
if prompt_len <= input_len: if prompt_len <= input_len:
samples.append( samples.append(
SampleRequest( SampleRequest(
prompt=prompt_formatted if return_prompt_formatted else prompt, prompt=prompt_formatted if return_prompt_formatted else prompt,
prompt_len=prompt_len, prompt_len=prompt_len,
expected_output_len=output_len, expected_output_len=output_len,
request_id=request_id_prefix + str(ind),
) )
) )
ind += 1
return samples return samples
...@@ -672,6 +704,7 @@ class BurstGPTDataset(BenchmarkDataset): ...@@ -672,6 +704,7 @@ class BurstGPTDataset(BenchmarkDataset):
num_requests: int, num_requests: int,
max_loras: Optional[int] = None, max_loras: Optional[int] = None,
lora_path: Optional[str] = None, lora_path: Optional[str] = None,
request_id_prefix: str = "",
**kwargs, **kwargs,
) -> list[SampleRequest]: ) -> list[SampleRequest]:
samples = [] samples = []
...@@ -693,6 +726,7 @@ class BurstGPTDataset(BenchmarkDataset): ...@@ -693,6 +726,7 @@ class BurstGPTDataset(BenchmarkDataset):
prompt_len=input_len, prompt_len=input_len,
expected_output_len=output_len, expected_output_len=output_len,
lora_request=lora_req, lora_request=lora_req,
request_id=request_id_prefix + str(i),
) )
) )
return samples return samples
...@@ -752,12 +786,14 @@ class ConversationDataset(HuggingFaceDataset): ...@@ -752,12 +786,14 @@ class ConversationDataset(HuggingFaceDataset):
num_requests: int, num_requests: int,
output_len: Optional[int] = None, output_len: Optional[int] = None,
enable_multimodal_chat: bool = False, enable_multimodal_chat: bool = False,
request_id_prefix: str = "",
**kwargs, **kwargs,
) -> list: ) -> list:
# Filter examples with at least 2 conversations # Filter examples with at least 2 conversations
filtered_data = self.data.filter(lambda x: len(x["conversations"]) >= 2) filtered_data = self.data.filter(lambda x: len(x["conversations"]) >= 2)
sampled_requests = [] sampled_requests = []
dynamic_output = output_len is None dynamic_output = output_len is None
ind = 0
for item in filtered_data: for item in filtered_data:
if len(sampled_requests) >= num_requests: if len(sampled_requests) >= num_requests:
...@@ -785,9 +821,13 @@ class ConversationDataset(HuggingFaceDataset): ...@@ -785,9 +821,13 @@ class ConversationDataset(HuggingFaceDataset):
prompt_len=prompt_len, prompt_len=prompt_len,
expected_output_len=output_len, expected_output_len=output_len,
multi_modal_data=mm_content, multi_modal_data=mm_content,
request_id=request_id_prefix + str(ind),
) )
) )
self.maybe_oversample_requests(sampled_requests, num_requests) ind += 1
self.maybe_oversample_requests(
sampled_requests, num_requests, request_id_prefix
)
return sampled_requests return sampled_requests
...@@ -814,11 +854,12 @@ class VisionArenaDataset(HuggingFaceDataset): ...@@ -814,11 +854,12 @@ class VisionArenaDataset(HuggingFaceDataset):
num_requests: int, num_requests: int,
output_len: Optional[int] = None, output_len: Optional[int] = None,
enable_multimodal_chat: bool = False, enable_multimodal_chat: bool = False,
request_id_prefix: str = "",
**kwargs, **kwargs,
) -> list: ) -> list:
output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
sampled_requests = [] sampled_requests = []
for item in self.data: for i, item in enumerate(self.data):
if len(sampled_requests) >= num_requests: if len(sampled_requests) >= num_requests:
break break
parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path) parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path)
...@@ -838,9 +879,12 @@ class VisionArenaDataset(HuggingFaceDataset): ...@@ -838,9 +879,12 @@ class VisionArenaDataset(HuggingFaceDataset):
prompt_len=prompt_len, prompt_len=prompt_len,
expected_output_len=output_len, expected_output_len=output_len,
multi_modal_data=mm_content, multi_modal_data=mm_content,
request_id=request_id_prefix + str(i),
) )
) )
self.maybe_oversample_requests(sampled_requests, num_requests) self.maybe_oversample_requests(
sampled_requests, num_requests, request_id_prefix
)
return sampled_requests return sampled_requests
...@@ -870,11 +914,12 @@ class InstructCoderDataset(HuggingFaceDataset): ...@@ -870,11 +914,12 @@ class InstructCoderDataset(HuggingFaceDataset):
num_requests: int, num_requests: int,
output_len: Optional[int] = None, output_len: Optional[int] = None,
enable_multimodal_chat: bool = False, enable_multimodal_chat: bool = False,
request_id_prefix: str = "",
**kwargs, **kwargs,
) -> list: ) -> list:
output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
sampled_requests = [] sampled_requests = []
for item in self.data: for i, item in enumerate(self.data):
if len(sampled_requests) >= num_requests: if len(sampled_requests) >= num_requests:
break break
prompt = f"{item['input']}\n\n{item['instruction']} Just output \ prompt = f"{item['input']}\n\n{item['instruction']} Just output \
...@@ -892,9 +937,12 @@ class InstructCoderDataset(HuggingFaceDataset): ...@@ -892,9 +937,12 @@ class InstructCoderDataset(HuggingFaceDataset):
prompt=prompt, prompt=prompt,
prompt_len=prompt_len, prompt_len=prompt_len,
expected_output_len=output_len, expected_output_len=output_len,
request_id=request_id_prefix + str(i),
) )
) )
self.maybe_oversample_requests(sampled_requests, num_requests) self.maybe_oversample_requests(
sampled_requests, num_requests, request_id_prefix
)
return sampled_requests return sampled_requests
...@@ -924,12 +972,13 @@ class MTBenchDataset(HuggingFaceDataset): ...@@ -924,12 +972,13 @@ class MTBenchDataset(HuggingFaceDataset):
num_requests: int, num_requests: int,
output_len: Optional[int] = None, output_len: Optional[int] = None,
enable_multimodal_chat: bool = False, enable_multimodal_chat: bool = False,
request_id_prefix: str = "",
**kwargs, **kwargs,
) -> list: ) -> list:
output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
sampled_requests = [] sampled_requests = []
for item in self.data: for i, item in enumerate(self.data):
if len(sampled_requests) >= num_requests: if len(sampled_requests) >= num_requests:
break break
prompt = item["turns"][0] prompt = item["turns"][0]
...@@ -947,9 +996,12 @@ class MTBenchDataset(HuggingFaceDataset): ...@@ -947,9 +996,12 @@ class MTBenchDataset(HuggingFaceDataset):
prompt=prompt, prompt=prompt,
prompt_len=prompt_len, prompt_len=prompt_len,
expected_output_len=output_len, expected_output_len=output_len,
request_id=request_id_prefix + str(i),
) )
) )
self.maybe_oversample_requests(sampled_requests, num_requests) self.maybe_oversample_requests(
sampled_requests, num_requests, request_id_prefix
)
return sampled_requests return sampled_requests
...@@ -974,10 +1026,12 @@ class AIMODataset(HuggingFaceDataset): ...@@ -974,10 +1026,12 @@ class AIMODataset(HuggingFaceDataset):
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
num_requests: int, num_requests: int,
output_len: Optional[int] = None, output_len: Optional[int] = None,
request_id_prefix: str = "",
**kwargs, **kwargs,
) -> list: ) -> list:
sampled_requests = [] sampled_requests = []
dynamic_output = output_len is None dynamic_output = output_len is None
ind = 0
for item in self.data: for item in self.data:
if len(sampled_requests) >= num_requests: if len(sampled_requests) >= num_requests:
...@@ -1000,9 +1054,13 @@ class AIMODataset(HuggingFaceDataset): ...@@ -1000,9 +1054,13 @@ class AIMODataset(HuggingFaceDataset):
prompt_len=prompt_len, prompt_len=prompt_len,
expected_output_len=output_len, expected_output_len=output_len,
multi_modal_data=None, multi_modal_data=None,
request_id=request_id_prefix + str(ind),
) )
) )
self.maybe_oversample_requests(sampled_requests, num_requests) ind += 1
self.maybe_oversample_requests(
sampled_requests, num_requests, request_id_prefix
)
return sampled_requests return sampled_requests
...@@ -1072,12 +1130,18 @@ class NextEditPredictionDataset(HuggingFaceDataset): ...@@ -1072,12 +1130,18 @@ class NextEditPredictionDataset(HuggingFaceDataset):
"zed-industries/zeta": _format_zeta_prompt, "zed-industries/zeta": _format_zeta_prompt,
} }
def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int, **kwargs): def sample(
self,
tokenizer: PreTrainedTokenizerBase,
num_requests: int,
request_id_prefix: str = "",
**kwargs,
):
formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get(self.dataset_path) formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get(self.dataset_path)
if formatting_prompt_func is None: if formatting_prompt_func is None:
raise ValueError(f"Unsupported dataset path: {self.dataset_path}") raise ValueError(f"Unsupported dataset path: {self.dataset_path}")
samples = [] samples = []
for sample in self.data: for i, sample in enumerate(self.data):
sample = formatting_prompt_func(sample) sample = formatting_prompt_func(sample)
samples.append( samples.append(
SampleRequest( SampleRequest(
...@@ -1086,11 +1150,12 @@ class NextEditPredictionDataset(HuggingFaceDataset): ...@@ -1086,11 +1150,12 @@ class NextEditPredictionDataset(HuggingFaceDataset):
expected_output_len=len( expected_output_len=len(
tokenizer(sample["expected_output"]).input_ids tokenizer(sample["expected_output"]).input_ids
), ),
request_id=request_id_prefix + str(i),
) )
) )
if len(samples) >= num_requests: if len(samples) >= num_requests:
break break
self.maybe_oversample_requests(samples, num_requests) self.maybe_oversample_requests(samples, num_requests, request_id_prefix)
return samples return samples
...@@ -1139,6 +1204,7 @@ class ASRDataset(HuggingFaceDataset): ...@@ -1139,6 +1204,7 @@ class ASRDataset(HuggingFaceDataset):
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
num_requests: int, num_requests: int,
output_len: Optional[int] = None, output_len: Optional[int] = None,
request_id_prefix: str = "",
**kwargs, **kwargs,
) -> list: ) -> list:
import librosa import librosa
...@@ -1148,6 +1214,7 @@ class ASRDataset(HuggingFaceDataset): ...@@ -1148,6 +1214,7 @@ class ASRDataset(HuggingFaceDataset):
prompt_len = len(tokenizer(prompt).input_ids) prompt_len = len(tokenizer(prompt).input_ids)
sampled_requests = [] sampled_requests = []
skipped = 0 skipped = 0
ind = 0
for item in self.data: for item in self.data:
if len(sampled_requests) >= num_requests: if len(sampled_requests) >= num_requests:
break break
...@@ -1166,8 +1233,10 @@ class ASRDataset(HuggingFaceDataset): ...@@ -1166,8 +1233,10 @@ class ASRDataset(HuggingFaceDataset):
prompt_len=prompt_len, prompt_len=prompt_len,
expected_output_len=output_len, expected_output_len=output_len,
multi_modal_data=mm_content, multi_modal_data=mm_content,
request_id=request_id_prefix + str(ind),
) )
) )
ind += 1
if skipped: if skipped:
logger.warning( logger.warning(
"%d samples discarded from dataset due to" "%d samples discarded from dataset due to"
...@@ -1175,5 +1244,7 @@ class ASRDataset(HuggingFaceDataset): ...@@ -1175,5 +1244,7 @@ class ASRDataset(HuggingFaceDataset):
" what Whisper supports.", " what Whisper supports.",
skipped, skipped,
) )
self.maybe_oversample_requests(sampled_requests, num_requests) self.maybe_oversample_requests(
sampled_requests, num_requests, request_id_prefix
)
return sampled_requests return sampled_requests
...@@ -375,11 +375,12 @@ async def benchmark( ...@@ -375,11 +375,12 @@ async def benchmark(
rps_change_events.append({"rps": rps_val, "timestamp": timestamp}) rps_change_events.append({"rps": rps_val, "timestamp": timestamp})
last_int_rps = current_int_rps last_int_rps = current_int_rps
prompt, prompt_len, output_len, mm_content = ( prompt, prompt_len, output_len, mm_content, request_id = (
request.prompt, request.prompt,
request.prompt_len, request.prompt_len,
request.expected_output_len, request.expected_output_len,
request.multi_modal_data, request.multi_modal_data,
request.request_id,
) )
req_model_id, req_model_name = model_id, model_name req_model_id, req_model_name = model_id, model_name
if lora_modules: if lora_modules:
...@@ -397,6 +398,7 @@ async def benchmark( ...@@ -397,6 +398,7 @@ async def benchmark(
multi_modal_content=mm_content, multi_modal_content=mm_content,
ignore_eos=ignore_eos, ignore_eos=ignore_eos,
extra_body=extra_body, extra_body=extra_body,
request_id=request_id,
) )
task = limited_request_func(request_func_input=request_func_input, pbar=pbar) task = limited_request_func(request_func_input=request_func_input, pbar=pbar)
tasks.append(asyncio.create_task(task)) tasks.append(asyncio.create_task(task))
...@@ -665,6 +667,7 @@ def main(args: argparse.Namespace): ...@@ -665,6 +667,7 @@ def main(args: argparse.Namespace):
tokenizer=tokenizer, tokenizer=tokenizer,
output_len=args.custom_output_len, output_len=args.custom_output_len,
skip_chat_template=args.custom_skip_chat_template, skip_chat_template=args.custom_skip_chat_template,
request_id_prefix=args.request_id_prefix,
) )
elif args.dataset_name == "sonnet": elif args.dataset_name == "sonnet":
...@@ -678,6 +681,7 @@ def main(args: argparse.Namespace): ...@@ -678,6 +681,7 @@ def main(args: argparse.Namespace):
prefix_len=args.sonnet_prefix_len, prefix_len=args.sonnet_prefix_len,
tokenizer=tokenizer, tokenizer=tokenizer,
return_prompt_formatted=False, return_prompt_formatted=False,
request_id_prefix=args.request_id_prefix,
) )
else: else:
assert tokenizer.chat_template or tokenizer.default_chat_template, ( assert tokenizer.chat_template or tokenizer.default_chat_template, (
...@@ -690,6 +694,7 @@ def main(args: argparse.Namespace): ...@@ -690,6 +694,7 @@ def main(args: argparse.Namespace):
prefix_len=args.sonnet_prefix_len, prefix_len=args.sonnet_prefix_len,
tokenizer=tokenizer, tokenizer=tokenizer,
return_prompt_formatted=True, return_prompt_formatted=True,
request_id_prefix=args.request_id_prefix,
) )
elif args.dataset_name == "hf": elif args.dataset_name == "hf":
...@@ -751,6 +756,7 @@ def main(args: argparse.Namespace): ...@@ -751,6 +756,7 @@ def main(args: argparse.Namespace):
num_requests=args.num_prompts, num_requests=args.num_prompts,
tokenizer=tokenizer, tokenizer=tokenizer,
output_len=args.hf_output_len, output_len=args.hf_output_len,
request_id_prefix=args.request_id_prefix,
) )
else: else:
...@@ -762,10 +768,15 @@ def main(args: argparse.Namespace): ...@@ -762,10 +768,15 @@ def main(args: argparse.Namespace):
tokenizer=tokenizer, tokenizer=tokenizer,
num_requests=args.num_prompts, num_requests=args.num_prompts,
output_len=args.sharegpt_output_len, output_len=args.sharegpt_output_len,
request_id_prefix=args.request_id_prefix,
), ),
"burstgpt": lambda: BurstGPTDataset( "burstgpt": lambda: BurstGPTDataset(
random_seed=args.seed, dataset_path=args.dataset_path random_seed=args.seed, dataset_path=args.dataset_path
).sample(tokenizer=tokenizer, num_requests=args.num_prompts), ).sample(
tokenizer=tokenizer,
num_requests=args.num_prompts,
request_id_prefix=args.request_id_prefix,
),
"random": lambda: RandomDataset(dataset_path=args.dataset_path).sample( "random": lambda: RandomDataset(dataset_path=args.dataset_path).sample(
tokenizer=tokenizer, tokenizer=tokenizer,
num_requests=args.num_prompts, num_requests=args.num_prompts,
...@@ -773,6 +784,7 @@ def main(args: argparse.Namespace): ...@@ -773,6 +784,7 @@ def main(args: argparse.Namespace):
input_len=args.random_input_len, input_len=args.random_input_len,
output_len=args.random_output_len, output_len=args.random_output_len,
range_ratio=args.random_range_ratio, range_ratio=args.random_range_ratio,
request_id_prefix=args.request_id_prefix,
), ),
} }
...@@ -1118,6 +1130,13 @@ def create_argument_parser(): ...@@ -1118,6 +1130,13 @@ def create_argument_parser():
"goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 " "goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 "
"and the blog: https://hao-ai-lab.github.io/blogs/distserve", "and the blog: https://hao-ai-lab.github.io/blogs/distserve",
) )
parser.add_argument(
"--request-id-prefix",
type=str,
required=False,
default="benchmark-serving",
help="Specify the prefix of request id.",
)
# group for dataset specific arguments # group for dataset specific arguments
custom_group = parser.add_argument_group("custom dataset options") custom_group = parser.add_argument_group("custom dataset options")
......
...@@ -18,6 +18,7 @@ import logging ...@@ -18,6 +18,7 @@ import logging
import random import random
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Mapping from collections.abc import Mapping
from copy import deepcopy
from dataclasses import dataclass from dataclasses import dataclass
from functools import cache from functools import cache
from io import BytesIO from io import BytesIO
...@@ -76,6 +77,7 @@ class SampleRequest: ...@@ -76,6 +77,7 @@ class SampleRequest:
Union[MultiModalDataDict, dict, list[dict]] Union[MultiModalDataDict, dict, list[dict]]
] = None ] = None
lora_request: Optional[LoRARequest] = None lora_request: Optional[LoRARequest] = None
request_id: Optional[str] = None
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
...@@ -183,7 +185,8 @@ class BenchmarkDataset(ABC): ...@@ -183,7 +185,8 @@ class BenchmarkDataset(ABC):
@abstractmethod @abstractmethod
def sample(self, tokenizer: PreTrainedTokenizerBase, def sample(self, tokenizer: PreTrainedTokenizerBase,
num_requests: int) -> list[SampleRequest]: num_requests: int,
request_id_prefix: str = "") -> list[SampleRequest]:
""" """
Abstract method to generate sample requests from the dataset. Abstract method to generate sample requests from the dataset.
...@@ -194,6 +197,8 @@ class BenchmarkDataset(ABC): ...@@ -194,6 +197,8 @@ class BenchmarkDataset(ABC):
tokenizer (PreTrainedTokenizerBase): The tokenizer to be used tokenizer (PreTrainedTokenizerBase): The tokenizer to be used
for processing the dataset's text. for processing the dataset's text.
num_requests (int): The number of sample requests to generate. num_requests (int): The number of sample requests to generate.
request_id_prefix (str) The prefix of request_id.
Returns: Returns:
list[SampleRequest]: A list of sample requests generated from the list[SampleRequest]: A list of sample requests generated from the
...@@ -201,8 +206,12 @@ class BenchmarkDataset(ABC): ...@@ -201,8 +206,12 @@ class BenchmarkDataset(ABC):
""" """
raise NotImplementedError("sample must be implemented in subclasses.") raise NotImplementedError("sample must be implemented in subclasses.")
def maybe_oversample_requests(self, requests: list[SampleRequest], def maybe_oversample_requests(
num_requests: int) -> None: self,
requests: list[SampleRequest],
num_requests: int,
request_id_prefix: str = "",
) -> None:
""" """
Oversamples the list of requests if its size is less than the desired Oversamples the list of requests if its size is less than the desired
number. number.
...@@ -211,11 +220,17 @@ class BenchmarkDataset(ABC): ...@@ -211,11 +220,17 @@ class BenchmarkDataset(ABC):
requests (List[SampleRequest]): The current list of sampled requests (List[SampleRequest]): The current list of sampled
requests. requests.
num_requests (int): The target number of requests. num_requests (int): The target number of requests.
request_id_prefix (str) The prefix of the request ids.
""" """
if len(requests) < num_requests: if len(requests) < num_requests:
random.seed(self.random_seed) random.seed(self.random_seed)
additional = random.choices(requests, additional = deepcopy(
k=num_requests - len(requests)) random.choices(requests, k=num_requests - len(requests))
)
for i in range(len(additional)):
req = additional[i]
req.request_id = request_id_prefix + str(len(requests) + i)
requests.extend(additional) requests.extend(additional)
logger.info("Oversampled requests to reach %d total samples.", logger.info("Oversampled requests to reach %d total samples.",
num_requests) num_requests)
...@@ -334,6 +349,7 @@ class RandomDataset(BenchmarkDataset): ...@@ -334,6 +349,7 @@ class RandomDataset(BenchmarkDataset):
range_ratio: float = DEFAULT_RANGE_RATIO, range_ratio: float = DEFAULT_RANGE_RATIO,
input_len: int = DEFAULT_INPUT_LEN, input_len: int = DEFAULT_INPUT_LEN,
output_len: int = DEFAULT_OUTPUT_LEN, output_len: int = DEFAULT_OUTPUT_LEN,
request_id_prefix: str = "",
**kwargs, **kwargs,
) -> list[SampleRequest]: ) -> list[SampleRequest]:
# Enforce range_ratio < 1 # Enforce range_ratio < 1
...@@ -391,6 +407,7 @@ class RandomDataset(BenchmarkDataset): ...@@ -391,6 +407,7 @@ class RandomDataset(BenchmarkDataset):
prompt=prompt, prompt=prompt,
prompt_len=total_input_len, prompt_len=total_input_len,
expected_output_len=int(output_lens[i]), expected_output_len=int(output_lens[i]),
request_id=request_id_prefix + str(i),
)) ))
return requests return requests
...@@ -432,9 +449,11 @@ class ShareGPTDataset(BenchmarkDataset): ...@@ -432,9 +449,11 @@ class ShareGPTDataset(BenchmarkDataset):
max_loras: Optional[int] = None, max_loras: Optional[int] = None,
output_len: Optional[int] = None, output_len: Optional[int] = None,
enable_multimodal_chat: bool = False, enable_multimodal_chat: bool = False,
request_id_prefix: str = "",
**kwargs, **kwargs,
) -> list: ) -> list:
samples: list = [] samples: list = []
ind = 0
for entry in self.data: for entry in self.data:
if len(samples) >= num_requests: if len(samples) >= num_requests:
break break
...@@ -470,8 +489,10 @@ class ShareGPTDataset(BenchmarkDataset): ...@@ -470,8 +489,10 @@ class ShareGPTDataset(BenchmarkDataset):
expected_output_len=new_output_len, expected_output_len=new_output_len,
lora_request=lora_request, lora_request=lora_request,
multi_modal_data=mm_content, multi_modal_data=mm_content,
request_id=request_id_prefix + str(ind),
)) ))
self.maybe_oversample_requests(samples, num_requests) ind += 1
self.maybe_oversample_requests(samples, num_requests, request_id_prefix)
return samples return samples
...@@ -647,6 +668,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: ...@@ -647,6 +668,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
tokenizer=tokenizer, tokenizer=tokenizer,
output_len=args.custom_output_len, output_len=args.custom_output_len,
skip_chat_template=args.custom_skip_chat_template, skip_chat_template=args.custom_skip_chat_template,
request_id_prefix=args.request_id_prefix,
) )
elif args.dataset_name == "sonnet": elif args.dataset_name == "sonnet":
...@@ -660,6 +682,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: ...@@ -660,6 +682,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
prefix_len=args.sonnet_prefix_len, prefix_len=args.sonnet_prefix_len,
tokenizer=tokenizer, tokenizer=tokenizer,
return_prompt_formatted=False, return_prompt_formatted=False,
request_id_prefix=args.request_id_prefix,
) )
else: else:
assert tokenizer.chat_template or tokenizer.default_chat_template, ( assert tokenizer.chat_template or tokenizer.default_chat_template, (
...@@ -671,6 +694,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: ...@@ -671,6 +694,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
prefix_len=args.sonnet_prefix_len, prefix_len=args.sonnet_prefix_len,
tokenizer=tokenizer, tokenizer=tokenizer,
return_prompt_formatted=True, return_prompt_formatted=True,
request_id_prefix=args.request_id_prefix,
) )
elif args.dataset_name == "hf": elif args.dataset_name == "hf":
...@@ -730,6 +754,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: ...@@ -730,6 +754,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
num_requests=args.num_prompts, num_requests=args.num_prompts,
tokenizer=tokenizer, tokenizer=tokenizer,
output_len=args.hf_output_len, output_len=args.hf_output_len,
request_id_prefix=args.request_id_prefix,
) )
else: else:
...@@ -741,11 +766,13 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: ...@@ -741,11 +766,13 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
tokenizer=tokenizer, tokenizer=tokenizer,
num_requests=args.num_prompts, num_requests=args.num_prompts,
output_len=args.sharegpt_output_len, output_len=args.sharegpt_output_len,
request_id_prefix=args.request_id_prefix,
), ),
"burstgpt": "burstgpt":
lambda: BurstGPTDataset(random_seed=args.seed, lambda: BurstGPTDataset(random_seed=args.seed,
dataset_path=args.dataset_path). dataset_path=args.dataset_path).
sample(tokenizer=tokenizer, num_requests=args.num_prompts), sample(tokenizer=tokenizer, num_requests=args.num_prompts,
request_id_prefix=args.request_id_prefix,),
"random": "random":
lambda: RandomDataset(random_seed=args.seed, lambda: RandomDataset(random_seed=args.seed,
dataset_path=args.dataset_path).sample( dataset_path=args.dataset_path).sample(
...@@ -755,6 +782,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: ...@@ -755,6 +782,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
input_len=args.random_input_len, input_len=args.random_input_len,
output_len=args.random_output_len, output_len=args.random_output_len,
range_ratio=args.random_range_ratio, range_ratio=args.random_range_ratio,
request_id_prefix=args.request_id_prefix,
), ),
"prefix_repetition": "prefix_repetition":
lambda: PrefixRepetitionRandomDataset( lambda: PrefixRepetitionRandomDataset(
...@@ -766,6 +794,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: ...@@ -766,6 +794,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
suffix_len=args.prefix_repetition_suffix_len, suffix_len=args.prefix_repetition_suffix_len,
num_prefixes=args.prefix_repetition_num_prefixes, num_prefixes=args.prefix_repetition_num_prefixes,
output_len=args.prefix_repetition_output_len, output_len=args.prefix_repetition_output_len,
request_id_prefix=args.request_id_prefix,
), ),
} }
...@@ -839,10 +868,11 @@ class CustomDataset(BenchmarkDataset): ...@@ -839,10 +868,11 @@ class CustomDataset(BenchmarkDataset):
output_len: Optional[int] = None, output_len: Optional[int] = None,
enable_multimodal_chat: bool = False, enable_multimodal_chat: bool = False,
skip_chat_template: bool = False, skip_chat_template: bool = False,
request_id_prefix: str = "",
**kwargs, **kwargs,
) -> list: ) -> list:
sampled_requests = [] sampled_requests = []
for item in self.data: for i, item in enumerate(self.data):
if len(sampled_requests) >= num_requests: if len(sampled_requests) >= num_requests:
break break
prompt = item["prompt"] prompt = item["prompt"]
...@@ -864,8 +894,10 @@ class CustomDataset(BenchmarkDataset): ...@@ -864,8 +894,10 @@ class CustomDataset(BenchmarkDataset):
prompt=prompt, prompt=prompt,
prompt_len=prompt_len, prompt_len=prompt_len,
expected_output_len=output_len, expected_output_len=output_len,
request_id=request_id_prefix + str(i),
)) ))
self.maybe_oversample_requests(sampled_requests, num_requests) self.maybe_oversample_requests(sampled_requests, num_requests,
request_id_prefix)
return sampled_requests return sampled_requests
...@@ -909,6 +941,7 @@ class SonnetDataset(BenchmarkDataset): ...@@ -909,6 +941,7 @@ class SonnetDataset(BenchmarkDataset):
input_len: int = DEFAULT_INPUT_LEN, input_len: int = DEFAULT_INPUT_LEN,
output_len: int = DEFAULT_OUTPUT_LEN, output_len: int = DEFAULT_OUTPUT_LEN,
return_prompt_formatted: bool = False, return_prompt_formatted: bool = False,
request_id_prefix: str = "",
**kwargs, **kwargs,
) -> list: ) -> list:
# Calculate average token length for a poem line. # Calculate average token length for a poem line.
...@@ -934,6 +967,7 @@ class SonnetDataset(BenchmarkDataset): ...@@ -934,6 +967,7 @@ class SonnetDataset(BenchmarkDataset):
prefix_lines = self.data[:num_prefix_lines] prefix_lines = self.data[:num_prefix_lines]
samples = [] samples = []
ind = 0
while len(samples) < num_requests: while len(samples) < num_requests:
extra_lines = random.choices(self.data, extra_lines = random.choices(self.data,
k=num_input_lines - num_prefix_lines) k=num_input_lines - num_prefix_lines)
...@@ -949,7 +983,9 @@ class SonnetDataset(BenchmarkDataset): ...@@ -949,7 +983,9 @@ class SonnetDataset(BenchmarkDataset):
if return_prompt_formatted else prompt, if return_prompt_formatted else prompt,
prompt_len=prompt_len, prompt_len=prompt_len,
expected_output_len=output_len, expected_output_len=output_len,
request_id=request_id_prefix + str(ind),
)) ))
ind += 1
return samples return samples
...@@ -1000,6 +1036,7 @@ class BurstGPTDataset(BenchmarkDataset): ...@@ -1000,6 +1036,7 @@ class BurstGPTDataset(BenchmarkDataset):
num_requests: int, num_requests: int,
max_loras: Optional[int] = None, max_loras: Optional[int] = None,
lora_path: Optional[str] = None, lora_path: Optional[str] = None,
request_id_prefix: str = "",
**kwargs, **kwargs,
) -> list[SampleRequest]: ) -> list[SampleRequest]:
samples = [] samples = []
...@@ -1020,6 +1057,7 @@ class BurstGPTDataset(BenchmarkDataset): ...@@ -1020,6 +1057,7 @@ class BurstGPTDataset(BenchmarkDataset):
prompt_len=input_len, prompt_len=input_len,
expected_output_len=output_len, expected_output_len=output_len,
lora_request=lora_req, lora_request=lora_req,
request_id=request_id_prefix + str(i),
)) ))
return samples return samples
...@@ -1075,11 +1113,13 @@ class ConversationDataset(HuggingFaceDataset): ...@@ -1075,11 +1113,13 @@ class ConversationDataset(HuggingFaceDataset):
num_requests: int, num_requests: int,
output_len: Optional[int] = None, output_len: Optional[int] = None,
enable_multimodal_chat: bool = False, enable_multimodal_chat: bool = False,
request_id_prefix: str = "",
**kwargs) -> list: **kwargs) -> list:
# Filter examples with at least 2 conversations # Filter examples with at least 2 conversations
filtered_data = self.data.filter( filtered_data = self.data.filter(
lambda x: len(x["conversations"]) >= 2) lambda x: len(x["conversations"]) >= 2)
sampled_requests = [] sampled_requests = []
ind = 0
dynamic_output = output_len is None dynamic_output = output_len is None
for item in filtered_data: for item in filtered_data:
...@@ -1111,8 +1151,11 @@ class ConversationDataset(HuggingFaceDataset): ...@@ -1111,8 +1151,11 @@ class ConversationDataset(HuggingFaceDataset):
prompt_len=prompt_len, prompt_len=prompt_len,
expected_output_len=output_len, expected_output_len=output_len,
multi_modal_data=mm_content, multi_modal_data=mm_content,
request_id=request_id_prefix + str(ind),
)) ))
self.maybe_oversample_requests(sampled_requests, num_requests) ind += 1
self.maybe_oversample_requests(sampled_requests, num_requests,
request_id_prefix)
return sampled_requests return sampled_requests
...@@ -1141,12 +1184,13 @@ class VisionArenaDataset(HuggingFaceDataset): ...@@ -1141,12 +1184,13 @@ class VisionArenaDataset(HuggingFaceDataset):
num_requests: int, num_requests: int,
output_len: Optional[int] = None, output_len: Optional[int] = None,
enable_multimodal_chat: bool = False, enable_multimodal_chat: bool = False,
request_id_prefix: str = "",
**kwargs, **kwargs,
) -> list: ) -> list:
output_len = (output_len output_len = (output_len
if output_len is not None else self.DEFAULT_OUTPUT_LEN) if output_len is not None else self.DEFAULT_OUTPUT_LEN)
sampled_requests = [] sampled_requests = []
for item in self.data: for i, item in enumerate(self.data):
if len(sampled_requests) >= num_requests: if len(sampled_requests) >= num_requests:
break break
parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path) parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path)
...@@ -1168,8 +1212,10 @@ class VisionArenaDataset(HuggingFaceDataset): ...@@ -1168,8 +1212,10 @@ class VisionArenaDataset(HuggingFaceDataset):
prompt_len=prompt_len, prompt_len=prompt_len,
expected_output_len=output_len, expected_output_len=output_len,
multi_modal_data=mm_content, multi_modal_data=mm_content,
request_id=request_id_prefix + str(i),
)) ))
self.maybe_oversample_requests(sampled_requests, num_requests) self.maybe_oversample_requests(sampled_requests, num_requests,
request_id_prefix)
return sampled_requests return sampled_requests
...@@ -1198,11 +1244,12 @@ class InstructCoderDataset(HuggingFaceDataset): ...@@ -1198,11 +1244,12 @@ class InstructCoderDataset(HuggingFaceDataset):
num_requests: int, num_requests: int,
output_len: Optional[int] = None, output_len: Optional[int] = None,
enable_multimodal_chat: bool = False, enable_multimodal_chat: bool = False,
request_id_prefix: str = "",
**kwargs) -> list: **kwargs) -> list:
output_len = (output_len output_len = (output_len
if output_len is not None else self.DEFAULT_OUTPUT_LEN) if output_len is not None else self.DEFAULT_OUTPUT_LEN)
sampled_requests = [] sampled_requests = []
for item in self.data: for i, item in enumerate(self.data):
if len(sampled_requests) >= num_requests: if len(sampled_requests) >= num_requests:
break break
prompt = f"{item['input']}\n\n{item['instruction']} Just output \ prompt = f"{item['input']}\n\n{item['instruction']} Just output \
...@@ -1224,8 +1271,10 @@ class InstructCoderDataset(HuggingFaceDataset): ...@@ -1224,8 +1271,10 @@ class InstructCoderDataset(HuggingFaceDataset):
prompt=prompt, prompt=prompt,
prompt_len=prompt_len, prompt_len=prompt_len,
expected_output_len=output_len, expected_output_len=output_len,
request_id=request_id_prefix + str(i),
)) ))
self.maybe_oversample_requests(sampled_requests, num_requests) self.maybe_oversample_requests(sampled_requests, num_requests,
request_id_prefix)
return sampled_requests return sampled_requests
...@@ -1255,13 +1304,14 @@ class MTBenchDataset(HuggingFaceDataset): ...@@ -1255,13 +1304,14 @@ class MTBenchDataset(HuggingFaceDataset):
num_requests: int, num_requests: int,
output_len: Optional[int] = None, output_len: Optional[int] = None,
enable_multimodal_chat: bool = False, enable_multimodal_chat: bool = False,
request_id_prefix: str = "",
**kwargs, **kwargs,
) -> list: ) -> list:
output_len = (output_len output_len = (output_len
if output_len is not None else self.DEFAULT_OUTPUT_LEN) if output_len is not None else self.DEFAULT_OUTPUT_LEN)
sampled_requests = [] sampled_requests = []
for item in self.data: for i, item in enumerate(self.data):
if len(sampled_requests) >= num_requests: if len(sampled_requests) >= num_requests:
break break
prompt = item["turns"][0] prompt = item["turns"][0]
...@@ -1282,8 +1332,10 @@ class MTBenchDataset(HuggingFaceDataset): ...@@ -1282,8 +1332,10 @@ class MTBenchDataset(HuggingFaceDataset):
prompt=prompt, prompt=prompt,
prompt_len=prompt_len, prompt_len=prompt_len,
expected_output_len=output_len, expected_output_len=output_len,
request_id=request_id_prefix + str(i),
)) ))
self.maybe_oversample_requests(sampled_requests, num_requests) self.maybe_oversample_requests(sampled_requests, num_requests,
request_id_prefix)
return sampled_requests return sampled_requests
...@@ -1305,8 +1357,10 @@ class AIMODataset(HuggingFaceDataset): ...@@ -1305,8 +1357,10 @@ class AIMODataset(HuggingFaceDataset):
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
num_requests: int, num_requests: int,
output_len: Optional[int] = None, output_len: Optional[int] = None,
request_id_prefix: str = "",
**kwargs) -> list: **kwargs) -> list:
sampled_requests = [] sampled_requests = []
ind = 0
dynamic_output = output_len is None dynamic_output = output_len is None
for item in self.data: for item in self.data:
...@@ -1331,8 +1385,12 @@ class AIMODataset(HuggingFaceDataset): ...@@ -1331,8 +1385,12 @@ class AIMODataset(HuggingFaceDataset):
prompt_len=prompt_len, prompt_len=prompt_len,
expected_output_len=output_len, expected_output_len=output_len,
multi_modal_data=None, multi_modal_data=None,
request_id=request_id_prefix + str(ind),
)) ))
self.maybe_oversample_requests(sampled_requests, num_requests) ind += 1
self.maybe_oversample_requests(sampled_requests, num_requests,
request_id_prefix)
return sampled_requests return sampled_requests
...@@ -1403,13 +1461,14 @@ class NextEditPredictionDataset(HuggingFaceDataset): ...@@ -1403,13 +1461,14 @@ class NextEditPredictionDataset(HuggingFaceDataset):
} }
def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int, def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int,
request_id_prefix: str = "",
**kwargs): **kwargs):
formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get( formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get(
self.dataset_path) self.dataset_path)
if formatting_prompt_func is None: if formatting_prompt_func is None:
raise ValueError(f"Unsupported dataset path: {self.dataset_path}") raise ValueError(f"Unsupported dataset path: {self.dataset_path}")
samples = [] samples = []
for sample in self.data: for i, sample in enumerate(self.data):
sample = formatting_prompt_func(sample) sample = formatting_prompt_func(sample)
samples.append( samples.append(
SampleRequest( SampleRequest(
...@@ -1417,10 +1476,11 @@ class NextEditPredictionDataset(HuggingFaceDataset): ...@@ -1417,10 +1476,11 @@ class NextEditPredictionDataset(HuggingFaceDataset):
prompt_len=len(tokenizer(sample["prompt"]).input_ids), prompt_len=len(tokenizer(sample["prompt"]).input_ids),
expected_output_len=len( expected_output_len=len(
tokenizer(sample["expected_output"]).input_ids), tokenizer(sample["expected_output"]).input_ids),
request_id=request_id_prefix + str(i),
)) ))
if len(samples) >= num_requests: if len(samples) >= num_requests:
break break
self.maybe_oversample_requests(samples, num_requests) self.maybe_oversample_requests(samples, num_requests, request_id_prefix)
return samples return samples
...@@ -1470,6 +1530,7 @@ class ASRDataset(HuggingFaceDataset): ...@@ -1470,6 +1530,7 @@ class ASRDataset(HuggingFaceDataset):
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
num_requests: int, num_requests: int,
output_len: Optional[int] = None, output_len: Optional[int] = None,
request_id_prefix: str = "",
**kwargs, **kwargs,
) -> list: ) -> list:
output_len = (output_len output_len = (output_len
...@@ -1477,6 +1538,7 @@ class ASRDataset(HuggingFaceDataset): ...@@ -1477,6 +1538,7 @@ class ASRDataset(HuggingFaceDataset):
prompt = ASRDataset.TRANSCRIPTION_PREAMBLE prompt = ASRDataset.TRANSCRIPTION_PREAMBLE
prompt_len = len(tokenizer(prompt).input_ids) prompt_len = len(tokenizer(prompt).input_ids)
sampled_requests = [] sampled_requests = []
ind = 0
skipped = 0 skipped = 0
for item in self.data: for item in self.data:
if len(sampled_requests) >= num_requests: if len(sampled_requests) >= num_requests:
...@@ -1496,7 +1558,9 @@ class ASRDataset(HuggingFaceDataset): ...@@ -1496,7 +1558,9 @@ class ASRDataset(HuggingFaceDataset):
prompt_len=prompt_len, prompt_len=prompt_len,
expected_output_len=output_len, expected_output_len=output_len,
multi_modal_data=mm_content, multi_modal_data=mm_content,
request_id=request_id_prefix + str(ind),
)) ))
ind += 1
if skipped: if skipped:
logger.warning( logger.warning(
"%d samples discarded from dataset due to" "%d samples discarded from dataset due to"
...@@ -1504,7 +1568,8 @@ class ASRDataset(HuggingFaceDataset): ...@@ -1504,7 +1568,8 @@ class ASRDataset(HuggingFaceDataset):
" what Whisper supports.", " what Whisper supports.",
skipped, skipped,
) )
self.maybe_oversample_requests(sampled_requests, num_requests) self.maybe_oversample_requests(sampled_requests, num_requests,
request_id_prefix)
return sampled_requests return sampled_requests
...@@ -1541,11 +1606,13 @@ class MLPerfDataset(HuggingFaceDataset): ...@@ -1541,11 +1606,13 @@ class MLPerfDataset(HuggingFaceDataset):
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
num_requests: int, num_requests: int,
output_len: Optional[int] = None, output_len: Optional[int] = None,
request_id_prefix: str = "",
**kwargs, **kwargs,
) -> list[SampleRequest]: ) -> list[SampleRequest]:
# Force dynamic output length based on reference completion. # Force dynamic output length based on reference completion.
dynamic_output = output_len is None dynamic_output = output_len is None
sampled_requests: list[SampleRequest] = [] sampled_requests: list[SampleRequest] = []
ind = 0
for item in self.data: for item in self.data:
if len(sampled_requests) >= num_requests: if len(sampled_requests) >= num_requests:
...@@ -1580,10 +1647,13 @@ class MLPerfDataset(HuggingFaceDataset): ...@@ -1580,10 +1647,13 @@ class MLPerfDataset(HuggingFaceDataset):
prompt=prompt_formatted, prompt=prompt_formatted,
prompt_len=prompt_len, prompt_len=prompt_len,
expected_output_len=expected_output_len, expected_output_len=expected_output_len,
request_id=request_id_prefix + str(ind),
) )
) )
ind += 1
self.maybe_oversample_requests(sampled_requests, num_requests) self.maybe_oversample_requests(sampled_requests, num_requests,
request_id_prefix)
return sampled_requests return sampled_requests
...@@ -1616,6 +1686,7 @@ class PrefixRepetitionRandomDataset(BenchmarkDataset): ...@@ -1616,6 +1686,7 @@ class PrefixRepetitionRandomDataset(BenchmarkDataset):
suffix_len: int = DEFAULT_SUFFIX_LEN, suffix_len: int = DEFAULT_SUFFIX_LEN,
num_prefixes: int = DEFAULT_NUM_PREFIXES, num_prefixes: int = DEFAULT_NUM_PREFIXES,
output_len: int = DEFAULT_OUTPUT_LEN, output_len: int = DEFAULT_OUTPUT_LEN,
request_id_prefix: str = "",
**kwargs, **kwargs,
) -> list[SampleRequest]: ) -> list[SampleRequest]:
vocab_size = tokenizer.vocab_size vocab_size = tokenizer.vocab_size
......
...@@ -31,6 +31,7 @@ class RequestFuncInput: ...@@ -31,6 +31,7 @@ class RequestFuncInput:
multi_modal_content: Optional[dict | list[dict]] = None multi_modal_content: Optional[dict | list[dict]] = None
ignore_eos: bool = False ignore_eos: bool = False
language: Optional[str] = None language: Optional[str] = None
request_id: Optional[str] = None
@dataclass @dataclass
...@@ -87,6 +88,8 @@ async def async_request_openai_completions( ...@@ -87,6 +88,8 @@ async def async_request_openai_completions(
headers = { headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}" "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
} }
if request_func_input.request_id:
headers["x-request-id"] = request_func_input.request_id
output = RequestFuncOutput() output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len output.prompt_len = request_func_input.prompt_len
...@@ -210,6 +213,8 @@ async def async_request_openai_chat_completions( ...@@ -210,6 +213,8 @@ async def async_request_openai_chat_completions(
"Content-Type": "application/json", "Content-Type": "application/json",
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
} }
if request_func_input.request_id:
headers["x-request-id"] = request_func_input.request_id
output = RequestFuncOutput() output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len output.prompt_len = request_func_input.prompt_len
...@@ -311,6 +316,8 @@ async def async_request_openai_audio( ...@@ -311,6 +316,8 @@ async def async_request_openai_audio(
headers = { headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
} }
if request_func_input.request_id:
headers["x-request-id"] = request_func_input.request_id
# Send audio file # Send audio file
def to_bytes(y, sr): def to_bytes(y, sr):
......
...@@ -478,11 +478,12 @@ async def benchmark( ...@@ -478,11 +478,12 @@ async def benchmark(
"timestamp": timestamp "timestamp": timestamp
}) })
last_int_rps = current_int_rps last_int_rps = current_int_rps
prompt, prompt_len, output_len, mm_content = ( prompt, prompt_len, output_len, mm_content, request_id = (
request.prompt, request.prompt,
request.prompt_len, request.prompt_len,
request.expected_output_len, request.expected_output_len,
request.multi_modal_data, request.multi_modal_data,
request.request_id,
) )
req_model_id, req_model_name = model_id, model_name req_model_id, req_model_name = model_id, model_name
if lora_modules: if lora_modules:
...@@ -498,7 +499,8 @@ async def benchmark( ...@@ -498,7 +499,8 @@ async def benchmark(
logprobs=logprobs, logprobs=logprobs,
multi_modal_content=mm_content, multi_modal_content=mm_content,
ignore_eos=ignore_eos, ignore_eos=ignore_eos,
extra_body=extra_body) extra_body=extra_body,
request_id=request_id,)
tasks.append( tasks.append(
asyncio.create_task( asyncio.create_task(
limited_request_func(request_func_input=request_func_input, limited_request_func(request_func_input=request_func_input,
...@@ -865,6 +867,14 @@ def add_cli_args(parser: argparse.ArgumentParser): ...@@ -865,6 +867,14 @@ def add_cli_args(parser: argparse.ArgumentParser):
"goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 " "goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 "
"and the blog: https://hao-ai-lab.github.io/blogs/distserve", "and the blog: https://hao-ai-lab.github.io/blogs/distserve",
) )
parser.add_argument(
"--request-id-prefix",
type=str,
required=False,
default="benchmark-serving",
help="Specify the prefix of request id.",
)
sampling_group = parser.add_argument_group("sampling parameters") sampling_group = parser.add_argument_group("sampling parameters")
sampling_group.add_argument( sampling_group.add_argument(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment