Unverified Commit 7e183767 authored by Varun Sundar Rabindranath's avatar Varun Sundar Rabindranath Committed by GitHub
Browse files

[misc] Add LoRA to benchmark_serving (#12898)


Signed-off-by: default avatarVarun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: default avatarVarun Sundar Rabindranath <varun@neuralmagic.com>
parent 2880e21e
...@@ -537,6 +537,7 @@ async def benchmark( ...@@ -537,6 +537,7 @@ async def benchmark(
ignore_eos: bool, ignore_eos: bool,
goodput_config_dict: Dict[str, float], goodput_config_dict: Dict[str, float],
max_concurrency: Optional[int], max_concurrency: Optional[int],
lora_modules: Optional[List[str]],
): ):
if backend in ASYNC_REQUEST_FUNCS: if backend in ASYNC_REQUEST_FUNCS:
request_func = ASYNC_REQUEST_FUNCS[backend] request_func = ASYNC_REQUEST_FUNCS[backend]
...@@ -562,6 +563,7 @@ async def benchmark( ...@@ -562,6 +563,7 @@ async def benchmark(
multi_modal_content=test_mm_content, multi_modal_content=test_mm_content,
ignore_eos=ignore_eos, ignore_eos=ignore_eos,
) )
test_output = await request_func(request_func_input=test_input) test_output = await request_func(request_func_input=test_input)
if not test_output.success: if not test_output.success:
raise ValueError( raise ValueError(
...@@ -570,6 +572,11 @@ async def benchmark( ...@@ -570,6 +572,11 @@ async def benchmark(
else: else:
print("Initial test run completed. Starting main benchmark run...") print("Initial test run completed. Starting main benchmark run...")
if lora_modules:
# For each input request, choose a LoRA module at random.
lora_modules = iter(
[random.choice(lora_modules) for _ in range(len(input_requests))])
if profile: if profile:
print("Starting profiler...") print("Starting profiler...")
profile_input = RequestFuncInput(model=model_id, profile_input = RequestFuncInput(model=model_id,
...@@ -616,8 +623,13 @@ async def benchmark( ...@@ -616,8 +623,13 @@ async def benchmark(
tasks: List[asyncio.Task] = [] tasks: List[asyncio.Task] = []
async for request in get_request(input_requests, request_rate, burstiness): async for request in get_request(input_requests, request_rate, burstiness):
prompt, prompt_len, output_len, mm_content = request prompt, prompt_len, output_len, mm_content = request
request_func_input = RequestFuncInput(model=model_id, req_model_id, req_model_name = model_id, model_name
model_name=model_name, if lora_modules:
req_lora_module = next(lora_modules)
req_model_id, req_model_name = req_lora_module, req_lora_module
request_func_input = RequestFuncInput(model=req_model_id,
model_name=req_model_name,
prompt=prompt, prompt=prompt,
api_url=api_url, api_url=api_url,
prompt_len=prompt_len, prompt_len=prompt_len,
...@@ -900,6 +912,7 @@ def main(args: argparse.Namespace): ...@@ -900,6 +912,7 @@ def main(args: argparse.Namespace):
ignore_eos=args.ignore_eos, ignore_eos=args.ignore_eos,
goodput_config_dict=goodput_config_dict, goodput_config_dict=goodput_config_dict,
max_concurrency=args.max_concurrency, max_concurrency=args.max_concurrency,
lora_modules=args.lora_modules,
)) ))
# Save config and results to json # Save config and results to json
...@@ -1237,5 +1250,12 @@ if __name__ == "__main__": ...@@ -1237,5 +1250,12 @@ if __name__ == "__main__":
"If not specified, the model name will be the " "If not specified, the model name will be the "
"same as the ``--model`` argument. ") "same as the ``--model`` argument. ")
parser.add_argument("--lora-modules",
nargs='+',
default=None,
help="A subset of LoRA module names passed in when "
"launching the server. For each request, the "
"script chooses a LoRA module at random.")
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment