Unverified Commit cbf23dbb authored by Glen Liu, committed by GitHub


[Feature] add --lora-request-distribution arg to bench_serving.py and support skewed and distinct workloads (#12175)
parent 6dade6c3
@@ -1753,6 +1753,8 @@ async def benchmark(
     max_concurrency: Optional[int],
     disable_tqdm: bool,
     lora_names: List[str],
+    lora_request_distribution: Optional[str],
+    lora_zipf_alpha: Optional[float],
     extra_request_body: Dict[str, Any],
     profile: bool,
     pd_separated: bool = False,
@@ -1893,11 +1895,30 @@ async def benchmark(
     else:
         request_generator = get_request(input_requests, request_rate)
 
+    # Prepare LoRA request distribution parameters
+    if lora_request_distribution == "distinct":
+        lora_idx = 0
+    elif lora_request_distribution == "skewed":
+        weights = np.array([lora_zipf_alpha**-i for i in range(len(lora_names))])
+        lora_probs = weights / np.sum(weights)
+    else:
+        lora_idx = None
+        lora_probs = None
+
     pbar = None if disable_tqdm else tqdm(total=pbar_total)
     async for request in request_generator:
         if lora_names is not None and len(lora_names) != 0:
-            idx = random.randint(0, len(lora_names) - 1)
-            lora_name = lora_names[idx]
+            if lora_request_distribution == "uniform":
+                lora_name = random.choice(lora_names)
+            elif lora_request_distribution == "distinct":
+                lora_name = lora_names[lora_idx]
+                lora_idx = (lora_idx + 1) % len(lora_names)
+            else:
+                assert (
+                    lora_request_distribution == "skewed"
+                ), f"Unexpected lora_request_distribution: {lora_request_distribution}. Expected 'skewed'."
+                lora_name = np.random.choice(lora_names, p=lora_probs)
         else:
             lora_name = None
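Note: the snippet below is a minimal standalone sketch of the selection logic in the hunk above. It replays the uniform, distinct (round-robin), and skewed (Zipf-weighted) branches outside the async benchmark loop and tallies how many requests each adapter would receive. The adapter names, request count, and alpha value are hypothetical placeholders, not values from the script.

# Standalone sketch of the per-request adapter selection added above.
import random
from collections import Counter

import numpy as np

lora_names = ["lora-a", "lora-b", "lora-c"]  # hypothetical adapter names
lora_zipf_alpha = 1.5                        # default value of --lora-zipf-alpha
num_requests = 6000                          # hypothetical request count

def select_adapters(distribution: str) -> Counter:
    counts = Counter()
    lora_idx = 0
    # Same weight formula as the 'skewed' branch: weight_i = alpha ** -i, normalized.
    weights = np.array([lora_zipf_alpha**-i for i in range(len(lora_names))])
    lora_probs = weights / np.sum(weights)
    for _ in range(num_requests):
        if distribution == "uniform":
            lora_name = random.choice(lora_names)
        elif distribution == "distinct":
            # Round-robin: each request moves to the next adapter.
            lora_name = lora_names[lora_idx]
            lora_idx = (lora_idx + 1) % len(lora_names)
        else:  # "skewed": Zipf-weighted sampling
            lora_name = np.random.choice(lora_names, p=lora_probs)
        counts[str(lora_name)] += 1
    return counts

for dist in ("uniform", "distinct", "skewed"):
    print(dist, dict(select_adapters(dist)))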
@@ -2289,6 +2310,15 @@ def run_benchmark(args_: argparse.Namespace):
         not args.tokenize_prompt
     ), "`--tokenize-prompt` not compatible with image dataset"
 
+    if args.lora_request_distribution in ["distinct", "skewed"]:
+        assert (
+            args.lora_name is not None and len(args.lora_name) > 1
+        ), "More than 1 LoRA adapter must be specified via --lora-name to use 'distinct' or 'skewed' request distribution."
+        assert (
+            args.lora_zipf_alpha > 1
+        ), f"Got invalid value for --lora-zipf-alpha of {args.lora_zipf_alpha}. It must be greater than 1."
+
     print(f"{args}\n")
 
     # Read dataset
@@ -2302,6 +2332,17 @@ def run_benchmark(args_: argparse.Namespace):
     if not hasattr(args, "flush_cache"):
         args.flush_cache = False
 
+    # Prepare LoRA arguments
+    lora_request_distribution = (
+        args.lora_request_distribution if args.lora_name is not None else None
+    )
+    lora_zipf_alpha = (
+        args.lora_zipf_alpha
+        if args.lora_name is not None and args.lora_request_distribution == "skewed"
+        else None
+    )
+
     return asyncio.run(
         benchmark(
             backend=backend,
@@ -2314,6 +2355,8 @@ def run_benchmark(args_: argparse.Namespace):
            max_concurrency=args.max_concurrency,
            disable_tqdm=args.disable_tqdm,
            lora_names=args.lora_name,
+           lora_request_distribution=lora_request_distribution,
+           lora_zipf_alpha=lora_zipf_alpha,
            extra_request_body=extra_request_body,
            profile=args.profile,
            pd_separated=args.pd_separated,
@@ -2551,6 +2594,27 @@ if __name__ == "__main__":
         action=LoRAPathAction,
         help="The names of LoRA adapters. You can provide a list of names in the format {name} {name} {name}...",
     )
+    parser.add_argument(
+        "--lora-request-distribution",
+        type=str,
+        default="uniform",
+        choices=[
+            "uniform",
+            "distinct",
+            "skewed",
+        ],
+        help="What distribution to sample the LoRA adapters specified in --lora-name. Borrowed from the Punica paper. "
+        "'distinct' distribution means selecting a new LoRA adapter for every request. "
+        "'skewed' distribution follows the Zipf distribution, where the number of requests "
+        "to model i specified in --lora-name is α times the number of requests for model i+1, "
+        "where α > 1.",
+    )
+    parser.add_argument(
+        "--lora-zipf-alpha",
+        type=float,
+        default=1.5,
+        help="The parameter to use for the Zipf distribution when --lora-request-distribution='skewed'.",
+    )
     parser.add_argument(
         "--prompt-suffix",
         type=str,
...
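To make the --lora-request-distribution help text above concrete, the short example below applies the same weight formula the 'skewed' branch uses (weight_i = alpha ** -i, then normalized) to a hypothetical set of four adapters with the default alpha of 1.5; each adapter ends up with exactly alpha times the probability of the next one.

# Worked example: hypothetical four adapters, default --lora-zipf-alpha of 1.5.
import numpy as np

alpha = 1.5
n_adapters = 4
weights = np.array([alpha**-i for i in range(n_adapters)])
probs = weights / weights.sum()

print(probs.round(3))        # approximately [0.415 0.277 0.185 0.123]
print(probs[:-1] / probs[1:])  # each ratio is exactly alpha (1.5)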
@@ -806,6 +806,8 @@ def get_benchmark_args(
     device="auto",
     pd_separated: bool = False,
     lora_name=None,
+    lora_request_distribution="uniform",
+    lora_zipf_alpha=1.5,
 ):
     return SimpleNamespace(
         backend="sglang",
@@ -834,6 +836,8 @@ def get_benchmark_args(
         apply_chat_template=False,
         profile=None,
         lora_name=lora_name,
+        lora_request_distribution=lora_request_distribution,
+        lora_zipf_alpha=lora_zipf_alpha,
         prompt_suffix="",
         device=device,
         pd_separated=pd_separated,
...
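The test helper above returns a SimpleNamespace, so downstream code reads the new options as plain attributes. Below is a minimal sketch of just the LoRA-related fields, with hypothetical adapter names; a real args namespace carries the full set of benchmark fields shown in the helper.

# Sketch of the LoRA-related fields only (hypothetical adapter names and alpha).
from types import SimpleNamespace

args = SimpleNamespace(
    lora_name=["lora-a", "lora-b", "lora-c"],
    lora_request_distribution="skewed",
    lora_zipf_alpha=2.0,
)
print(args.lora_request_distribution, args.lora_zipf_alpha)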