Unverified commit 35759efa authored by Lianmin Zheng, committed by GitHub

Support random dataset in bench_serving.py (#669)

parent 8f4b1559
@@ -273,6 +273,37 @@ def sample_sharegpt_requests(
     return filtered_dataset
 
 
+def sample_random_requests(
+    input_len: int,
+    output_len: int,
+    num_prompts: int,
+    range_ratio: float,
+    tokenizer: PreTrainedTokenizerBase,
+) -> List[Tuple[str, int, int]]:
+    input_lens = np.random.randint(
+        int(input_len * range_ratio),
+        input_len + 1,
+        size=num_prompts,
+    )
+    output_lens = np.random.randint(
+        int(output_len * range_ratio),
+        output_len + 1,
+        size=num_prompts,
+    )
+    offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts)
+    input_requests = []
+    for i in range(num_prompts):
+        prompt = tokenizer.decode(
+            [(offsets[i] + i + j) % tokenizer.vocab_size for j in range(input_lens[i])]
+        )
+        input_requests.append((prompt, int(input_lens[i]), int(output_lens[i])))
+
+    print(f"#Input tokens: {np.sum(input_lens)}")
+    print(f"#Output tokens: {np.sum(output_lens)}")
+
+    return input_requests
+
+
 async def get_request(
     input_requests: List[Tuple[str, int, int]],
     request_rate: float,
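For context: the new sampler needs no dataset file; each prompt is a consecutive run of token ids starting at a random per-request offset, decoded back to text. A minimal sketch of exercising it standalone, assuming sample_random_requests from this diff is in scope and a Hugging Face tokenizer (here "gpt2", purely for illustration) is available:

    from transformers import AutoTokenizer

    # Illustration only, not part of the commit.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    requests = sample_random_requests(
        input_len=32,
        output_len=16,
        num_prompts=4,
        range_ratio=0.5,  # input lengths sampled from [16, 32], output lengths from [8, 16]
        tokenizer=tokenizer,
    )
    for prompt, in_len, out_len in requests:
        # Same (prompt, prompt_len, output_len) tuple shape as sample_sharegpt_requests.
        print(in_len, out_len, repr(prompt[:40]))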
@@ -530,13 +561,23 @@ def fire(args: argparse.Namespace):
 
     tokenizer = get_tokenizer(tokenizer_id)
 
-    assert args.dataset is not None
-    input_requests = sample_sharegpt_requests(
-        dataset_path=args.dataset,
-        num_requests=args.num_prompts,
-        tokenizer=tokenizer,
-        fixed_output_len=args.sharegpt_output_len,
-    )
+    if args.dataset_name == "sharegpt":
+        input_requests = sample_sharegpt_requests(
+            dataset_path=args.dataset_path,
+            num_requests=args.num_prompts,
+            tokenizer=tokenizer,
+            fixed_output_len=args.sharegpt_output_len,
+        )
+    elif args.dataset_name == "random":
+        input_requests = sample_random_requests(
+            input_len=args.random_input_len,
+            output_len=args.random_output_len,
+            num_prompts=args.num_prompts,
+            range_ratio=args.random_range_ratio,
+            tokenizer=tokenizer,
+        )
+    else:
+        raise ValueError(f"Unknown dataset: {args.dataset_name}")
 
     asyncio.run(
         benchmark(
@@ -589,7 +630,14 @@ if __name__ == "__main__":
         help="If not set, the default port is configured according to its default value for different LLM Inference Engines.",
     )
     parser.add_argument(
-        "--dataset", type=str, default="sharegpt", help="Path to the ShareGPT dataset"
+        "--dataset-name",
+        type=str,
+        default="sharegpt",
+        choices=["sharegpt", "random"],
+        help="Name of the dataset to benchmark on.",
+    )
+    parser.add_argument(
+        "--dataset-path", type=str, default="", help="Path to the dataset."
     )
     parser.add_argument(
         "--model",
@@ -613,10 +661,29 @@ if __name__ == "__main__":
         default=None,
         help="Output length for each request. Overrides the output length from the ShareGPT dataset.",
     )
+    parser.add_argument(
+        "--random-input-len",
+        type=int,
+        default=1024,
+        help="Number of input tokens per request, used only for random dataset.",
+    )
+    parser.add_argument(
+        "--random-output-len",
+        type=int,
+        default=128,
+        help="Number of output tokens per request, used only for random dataset.",
+    )
+    parser.add_argument(
+        "--random-range-ratio",
+        type=float,
+        default=1.0,
+        help="Range of sampled ratio of input/output length, "
+        "used only for random dataset.",
+    )
     parser.add_argument(
         "--request-rate",
         type=float,
-        default=128.0,
+        default=float("inf"),
         help="Number of requests per second. If this is inf, then all the requests are sent at time 0. "
         "Otherwise, we use Poisson process to synthesize the request arrival times. Default is 128.0.",
     )
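With these flags in place, a random-dataset run might look like the following (hedged example; the exact entry point and any backend/endpoint flags are not shown in this diff and are omitted here):

    python3 bench_serving.py --dataset-name random --random-input-len 1024 --random-output-len 128 --random-range-ratio 0.5

--dataset-path is only needed for the sharegpt dataset, and --num-prompts controls how many random requests are generated.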
@@ -233,7 +233,7 @@ class ModelRunner:
             return
 
         logger.info(f"[gpu_id={self.gpu_id}] Capture cuda graph begin.")
-        batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 16)]
+        batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 17)]
         self.cuda_graph_runner = CudaGraphRunner(
             self, max_batch_size_to_capture=max(batch_size_list)
         )
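The one-character change above raises the largest CUDA-graph-captured batch size from 120 to 128. A quick standalone check (illustration only):

    old_list = [1, 2, 4] + [i * 8 for i in range(1, 16)]  # ..., 112, 120
    new_list = [1, 2, 4] + [i * 8 for i in range(1, 17)]  # ..., 120, 128
    print(max(old_list), max(new_list))  # 120 128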
@@ -40,7 +40,7 @@ class GenerateReqInput:
             self.text is not None and self.input_ids is not None
         ):
             raise ValueError("Either text or input_ids should be provided.")
-        if "n" in self.sampling_params and self.sampling_params["n"] != 1:
+        if self.sampling_params.get("n", 1) != 1:
            is_single = False
        else:
            if self.text is not None:
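The rewritten check treats a missing "n" as 1, which matches the behavior of the old two-part test. A tiny equivalence check (illustration only):

    for sp in ({}, {"n": 1}, {"n": 4}):
        assert ("n" in sp and sp["n"] != 1) == (sp.get("n", 1) != 1), sp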
@@ -196,14 +196,14 @@ class TokenizerManager:
         event = asyncio.Event()
         state = ReqState([], False, event)
         self.rid_to_state[rid] = state
-        if is_prefill == False:
+        if is_prefill:
+            await self._wait_for_prefill_response(event, state, obj, request, rid)
+            yield input_ids
+        else:
             async for response in self._wait_for_response(
                 event, state, obj, rid, request
             ):
                 yield response
-        else:
-            await self._wait_for_prefill_response(event, state, obj, request, rid)
-            yield input_ids
 
     async def _handle_batch_request(self, obj, request):
         batch_size = obj.batch_size