Commit 1d6cfb11 authored by 王敏
Browse files

[fix]修复test_long_context中报错问题,单测依然无法通过,nv也是同样的问题

parent 70c661da
...@@ -81,6 +81,9 @@ def generate( ...@@ -81,6 +81,9 @@ def generate(
inputs: Tuple[str, SamplingParams, Optional[LoRARequest]], inputs: Tuple[str, SamplingParams, Optional[LoRARequest]],
): ):
prompts, sampling_param, lora_request = inputs prompts, sampling_param, lora_request = inputs
max_model_len = llm.llm_engine.model_config.max_model_len
if len(prompts) >= max_model_len:
prompts = prompts[:max_model_len-1]
outputs = llm.generate(prompts, sampling_param, lora_request=lora_request) outputs = llm.generate(prompts, sampling_param, lora_request=lora_request)
return outputs[0].outputs[0].text.strip() return outputs[0].outputs[0].text.strip()
...@@ -89,8 +92,11 @@ def batched_generate( ...@@ -89,8 +92,11 @@ def batched_generate(
llm: vllm.LLM, llm: vllm.LLM,
inputs: List[Tuple[str, SamplingParams, Optional[LoRARequest]]], inputs: List[Tuple[str, SamplingParams, Optional[LoRARequest]]],
): ):
max_model_len = llm.llm_engine.model_config.max_model_len
for input in inputs: for input in inputs:
prompt, sampling_param, lora_req = input prompt, sampling_param, lora_req = input
if len(prompt) >= max_model_len:
prompt = prompt[:max_model_len-1]
# Add requests to the engine and run the engine # Add requests to the engine and run the engine
llm._validate_and_add_requests(prompt, llm._validate_and_add_requests(prompt,
sampling_param, sampling_param,
...@@ -111,7 +117,7 @@ def lora_llm(long_context_infos): ...@@ -111,7 +117,7 @@ def lora_llm(long_context_infos):
llm = vllm.LLM("meta-llama/Llama-2-13b-chat-hf", llm = vllm.LLM("meta-llama/Llama-2-13b-chat-hf",
enable_lora=True, enable_lora=True,
max_num_seqs=16, max_num_seqs=16,
max_loras=2, max_loras=8,
long_lora_scaling_factors=tuple(scaling_factors), long_lora_scaling_factors=tuple(scaling_factors),
max_num_batched_tokens=4096 * 8, max_num_batched_tokens=4096 * 8,
tensor_parallel_size=4, tensor_parallel_size=4,
......
...@@ -83,7 +83,9 @@ class LoRARequest( ...@@ -83,7 +83,9 @@ class LoRARequest(
and comparison lora adapter across engines. and comparison lora adapter across engines.
""" """
return isinstance(value, return isinstance(value,
self.__class__) and self.lora_name == value.lora_name self.__class__) and self.lora_name == value.lora_name and \
self.lora_int_id == value.lora_int_id and \
self.lora_path == value.lora_path
def __hash__(self) -> int: def __hash__(self) -> int:
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment