Unverified Commit 01017d4c authored by bjmsong's avatar bjmsong Committed by GitHub
Browse files

Support LoRA in Completion API (#2243)


Co-authored-by: root <bjmsong@126.com>
parent 94e167ea
......@@ -51,6 +51,7 @@ class RequestFuncInput:
prompt_len: int
output_len: int
model: str
lora_name: str
extra_request_body: Dict[str, Any]
......@@ -162,6 +163,7 @@ async def async_request_openai_completions(
"max_tokens": request_func_input.output_len,
"stream": not args.disable_stream,
"ignore_eos": not args.disable_ignore_eos,
"lora_path": request_func_input.lora_name,
**request_func_input.extra_request_body,
}
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
......@@ -319,6 +321,7 @@ async def async_request_sglang_generate(
"ignore_eos": not args.disable_ignore_eos,
},
"stream": not args.disable_stream,
"lora_path": request_func_input.lora_name,
**request_func_input.extra_request_body,
}
headers = {}
......@@ -884,6 +887,7 @@ async def benchmark(
request_rate: float,
max_concurrency: Optional[int],
disable_tqdm: bool,
lora_name: str,
extra_request_body: Dict[str, Any],
profile: bool,
):
......@@ -909,6 +913,7 @@ async def benchmark(
api_url=api_url,
prompt_len=test_prompt_len,
output_len=test_output_len,
lora_name=lora_name,
extra_request_body=extra_request_body,
)
test_output = await request_func(request_func_input=test_input)
......@@ -942,6 +947,7 @@ async def benchmark(
api_url=api_url,
prompt_len=prompt_len,
output_len=output_len,
lora_name=lora_name,
extra_request_body=extra_request_body,
)
tasks.append(
......@@ -1247,6 +1253,7 @@ def run_benchmark(args_: argparse.Namespace):
request_rate=args.request_rate,
max_concurrency=args.max_concurrency,
disable_tqdm=args.disable_tqdm,
lora_name=args.lora_name,
extra_request_body=extra_request_body,
profile=args.profile,
)
......@@ -1267,6 +1274,7 @@ def run_benchmark(args_: argparse.Namespace):
request_rate=rate,
max_concurrency=args.max_concurrency,
disable_tqdm=args.disable_tqdm,
lora_name=args.lora_name,
extra_request_body=extra_request_body,
profile=args.profile,
)
......@@ -1451,5 +1459,11 @@ if __name__ == "__main__":
help="Use Torch Profiler. The endpoint must be launched with "
"SGLANG_TORCH_PROFILER_DIR to enable profiler.",
)
parser.add_argument(
"--lora-name",
type=str,
default=None,
help="The name of LoRA adapter",
)
args = parser.parse_args()
run_benchmark(args)
......@@ -486,6 +486,7 @@ def v1_generate_request(
return_logprobs = []
logprob_start_lens = []
top_logprobs_nums = []
lora_paths = []
for request in all_requests:
# NOTE: with openai API, the prompt's logprobs are always not computed
......@@ -496,6 +497,7 @@ def v1_generate_request(
)
prompts.append(request.prompt)
lora_paths.append(request.lora_path)
if request.echo and request.logprobs:
current_logprob_start_len = 0
else:
......@@ -534,6 +536,7 @@ def v1_generate_request(
return_logprobs = return_logprobs[0]
logprob_start_lens = logprob_start_lens[0]
top_logprobs_nums = top_logprobs_nums[0]
lora_paths = lora_paths[0]
else:
if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
prompt_kwargs = {"text": prompts}
......@@ -549,6 +552,7 @@ def v1_generate_request(
return_text_in_logprobs=True,
stream=all_requests[0].stream,
rid=request_ids,
lora_path=lora_paths,
)
return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]
......
......@@ -166,6 +166,7 @@ class CompletionRequest(BaseModel):
temperature: float = 1.0
top_p: float = 1.0
user: Optional[str] = None
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
# Extra parameters for SRT backend only and will be ignored by OpenAI models.
json_schema: Optional[str] = None
......
......@@ -567,6 +567,7 @@ def run_bench_serving(
disable_tqdm=False,
disable_stream=disable_stream,
disable_ignore_eos=False,
lora_name=None,
extra_request_body=None,
profile=None,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment