Unverified commit cd8d4b9d authored by Qiaolin Yu, committed by GitHub

Fix lora bench (#6302)

parent f194e14f
@@ -170,6 +170,7 @@ async def benchmark(
         prompt_len=test_prompt_len,
         output_len=test_output_len,
         lora_name="dummy",  # the lora_name argument will not be used
+        image_data=None,
         extra_request_body=extra_request_body,
     )
     test_output = await request_func(request_func_input=test_input)
@@ -194,6 +195,7 @@ async def benchmark(
             prompt_len=prompt_len,
             output_len=output_len,
             lora_name="dummy",
+            image_data=None,
             extra_request_body=extra_request_body,
         )
         tasks.append(
...
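The `image_data=None` additions keep the LoRA benchmark in step with the shared request dataclass used by the serving benchmark harness. Below is a minimal sketch of the likely failure mode, assuming `RequestFuncInput` gained an `image_data` field with no default so that every call site must now pass it explicitly; the dataclass layout shown here is illustrative, not the actual bench_serving definition.

    from dataclasses import dataclass, field
    from typing import Optional

    # Illustrative stand-in for the benchmark's request dataclass; any field
    # other than image_data, lora_name, and extra_request_body is an assumption.
    @dataclass
    class RequestFuncInput:
        prompt: str
        api_url: str
        prompt_len: int
        output_len: int
        lora_name: str
        image_data: Optional[str]  # new field without a default
        extra_request_body: dict = field(default_factory=dict)

    # Before the fix, building the input without image_data would raise a
    # missing-argument TypeError; passing None keeps the benchmark text-only.
    test_input = RequestFuncInput(
        prompt="Hello",
        api_url="http://127.0.0.1:30000/generate",
        prompt_len=5,
        output_len=16,
        lora_name="dummy",  # the lora_name argument will not be used
        image_data=None,
        extra_request_body={},
    )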
@@ -170,9 +170,7 @@ class LoRAManager:
             dim=0,
             out=self.cuda_graph_batch_info.seg_indptr[1 : bs + 1],
         )
-        self.cuda_graph_batch_info.max_len = int(
-            torch.max(self.cuda_graph_batch_info.seg_lens[:bs])
-        )
+        self.cuda_graph_batch_info.max_len = 1
         for i, lora_path in enumerate(forward_batch.lora_paths):
             self.cuda_graph_batch_info.weight_indices[i] = (
...
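The LoRAManager change swaps a data-dependent reduction for a constant on the CUDA-graph path. A plausible reading: `int(torch.max(...))` copies a scalar from GPU to host, which is an implicit synchronization that CUDA graph capture does not tolerate, and since the captured graphs presumably cover decode batches where every segment contributes a single token, a constant 1 is a safe stand-in. The snippet below (requires a CUDA device) only illustrates the synchronization behaviour; it is not sglang code.

    import torch

    # Decode-style segment lengths: one token per sequence in the batch.
    seg_lens = torch.ones(8, dtype=torch.int32, device="cuda")

    # int() on a 0-dim CUDA tensor calls .item(), forcing a device-to-host copy
    # and a synchronization point -- undesirable while capturing a CUDA graph.
    max_len = int(torch.max(seg_lens))

    # During decode the maximum segment length is always 1, which is the
    # constant the fix hard-codes for the CUDA-graph batch info.
    assert max_len == 1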