Unverified commit beac202b, authored by Baizhou Zhang and committed by GitHub
Browse files

Add lora_path argument to bench_multiturn.py (#10092)

parent 21b9a4b4
......@@ -130,6 +130,12 @@ def parse_args():
help="Tag of a certain run in the log file",
)
parser.add_argument("--seed", type=int, default=1, help="The random seed.")
parser.add_argument(
"--lora-path",
type=str,
default="",
help="String of LoRA path. Currently we only support benchmarking on a single LoRA adaptor.",
)
return parser.parse_args()
......@@ -205,7 +211,7 @@ async def async_request_sglang_generate(
return output
def gen_payload(prompt, output_len):
def gen_payload(prompt, output_len, lora_path=""):
payload = {
"text": prompt,
"sampling_params": {
......@@ -215,7 +221,7 @@ def gen_payload(prompt, output_len):
},
"stream": True,
"stream_options": {"include_usage": True},
"lora_path": "",
"lora_path": lora_path,
"return_logprob": False,
"logprob_start_len": -1,
}
......@@ -303,7 +309,12 @@ class WorkloadGenerator:
)
init_requests = [
(i, gen_payload(self.candidate_inputs[i], args.output_length))
(
i,
gen_payload(
self.candidate_inputs[i], args.output_length, args.lora_path
),
)
for i in range(args.num_clients)
]
self.client_records = {
......@@ -399,6 +410,7 @@ class WorkloadGenerator:
gen_payload(
self.client_records[client_id]["history"],
self.output_length,
args.lora_path,
),
)
)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment