Unverified commit 7feba415 authored by Lianmin Zheng, committed by GitHub

Fix failing CI tests on long prompts; better error messages for embedding models (#1700)

parent 30ee3630
@@ -56,6 +56,9 @@ class GenerateReqInput:
     # LoRA related
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
 
+    # Whether it is a single request or a batch request
+    is_single: bool = True
+
     def post_init(self):
         if (self.text is None and self.input_ids is None) or (
             self.text is not None and self.input_ids is not None
...
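The new `is_single` field defaults to `True` and is read by the `TokenizerManager` below. As a rough sketch of the intent (the derivation logic here is an assumption for illustration, not this commit's code), `post_init` can flip the flag when the request carries a batch:

```python
# Hypothetical, simplified sketch of how `is_single` can be derived in
# post_init; the real GenerateReqInput has more fields and checks.
from dataclasses import dataclass
from typing import List, Optional, Union


@dataclass
class GenerateReqInput:
    text: Optional[Union[str, List[str]]] = None
    input_ids: Optional[Union[List[int], List[List[int]]]] = None
    # Whether it is a single request or a batch request
    is_single: bool = True

    def post_init(self):
        if (self.text is None and self.input_ids is None) or (
            self.text is not None and self.input_ids is not None
        ):
            raise ValueError("Provide exactly one of `text` or `input_ids`.")
        # Assumption: a list of strings, or a list of token-id lists, is a batch.
        if isinstance(self.text, list):
            self.is_single = False
        elif self.input_ids and isinstance(self.input_ids[0], list):
            self.is_single = False


req = GenerateReqInput(text=["Hello", "World"])
req.post_init()
print(req.is_single)  # False
```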
@@ -150,9 +150,13 @@ class TokenizerManager:
         while self.model_update_lock.locked():
             await asyncio.sleep(0.001)
 
+        if isinstance(obj, EmbeddingReqInput) and self.is_generation:
+            raise ValueError(
+                "This model does not appear to be an embedding model by default. Please add `--is-embedding` when launching the server or try another model."
+            )
         obj.post_init()
         is_single = obj.is_single
 
         if is_single:
             async for response in self._handle_single_request(obj, request):
                 yield response
...
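In isolation, the new guard makes an embedding request against a generation-mode server fail fast with an actionable message instead of an obscure downstream error. A self-contained sketch, with stub classes standing in for sglang's real request types:

```python
# Stubs standing in for sglang's request types, just to exercise the check.
class GenerateReqInput: ...
class EmbeddingReqInput: ...


def check_request_type(obj, is_generation: bool):
    # Mirrors the guard added in the hunk above.
    if isinstance(obj, EmbeddingReqInput) and is_generation:
        raise ValueError(
            "This model does not appear to be an embedding model by default. "
            "Please add `--is-embedding` when launching the server "
            "or try another model."
        )


check_request_type(GenerateReqInput(), is_generation=True)  # passes silently
try:
    check_request_type(EmbeddingReqInput(), is_generation=True)
except ValueError as e:
    print(e)  # the new, clearer error message
```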
@@ -542,8 +542,6 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
         kill_child_process(pid, including_parent=False)
         return
 
-    print(f"{res.json()=}")
-
     logger.info("The server is fired up and ready to roll!")
     if pipe_finish_writer is not None:
         pipe_finish_writer.send("ready")
...
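The dropped line was a leftover debug print of the warmup response. If that diagnostic is ever wanted again, one alternative (a suggestion, not part of this commit) is to keep it behind the module logger at DEBUG level so normal runs stay quiet:

```python
import logging

logger = logging.getLogger(__name__)


def report_warmup(res):
    # Hypothetical helper; `res` is assumed to be the warmup HTTP response.
    logger.debug("warmup response: %s", res.json())  # only visible at DEBUG level
    logger.info("The server is fired up and ready to roll!")
```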
@@ -40,20 +40,23 @@ class ModelCase:
     prefill_tolerance: float = 5e-2
     decode_tolerance: float = 5e-2
     rouge_l_tolerance: float = 1
+    skip_long_prompt: bool = False
 
 
 # Popular models that run on the CI
 CI_MODELS = [
     ModelCase("meta-llama/Llama-3.1-8B-Instruct"),
-    ModelCase("google/gemma-2-2b"),
+    ModelCase(
+        "google/gemma-2-2b", skip_long_prompt=True
+    ),  # There is a bug with new transformers library. This can only run with transformers==4.44
 ]
 
 # All other models that do not run on the CI
 ALL_OTHER_MODELS = [
     ModelCase("Qwen/Qwen2-1.5B"),
     ModelCase("Qwen/Qwen2.5-14B-Instruct"),
-    ModelCase("HuggingFaceTB/SmolLM-135M-Instruct"),
-    ModelCase("allenai/OLMo-1B-0724-hf", decode_tolerance=8e-2),
+    ModelCase("HuggingFaceTB/SmolLM-135M-Instruct", skip_long_prompt=True),
+    ModelCase("allenai/OLMo-1B-0724-hf", decode_tolerance=8e-2, skip_long_prompt=True),
 ]
 
 TORCH_DTYPES = [torch.float16]
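Assuming `ModelCase` is a dataclass (consistent with the typed default-value fields above and with `model_case.model_path` used later in the tests), the new flag composes with the existing tolerances and defaults to off, so untouched entries keep their old behavior:

```python
from dataclasses import dataclass


@dataclass
class ModelCase:
    model_path: str
    prefill_tolerance: float = 5e-2
    decode_tolerance: float = 5e-2
    rouge_l_tolerance: float = 1
    # New: opt this model out of prompts of 1000+ characters
    skip_long_prompt: bool = False


# Existing entries are unchanged; opted-out models set the flag explicitly.
print(ModelCase("Qwen/Qwen2-1.5B").skip_long_prompt)  # False
print(ModelCase("allenai/OLMo-1B-0724-hf", skip_long_prompt=True).skip_long_prompt)  # True
```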
@@ -136,8 +139,15 @@ class TestGenerationModels(unittest.TestCase):
     def test_ci_models(self):
         for model_case in CI_MODELS:
             for torch_dtype in TORCH_DTYPES:
+                # Skip long prompts for models that do not have a long context
+                prompts = DEFAULT_PROMPTS
+                if model_case.skip_long_prompt:
+                    prompts = [p for p in DEFAULT_PROMPTS if len(p) < 1000]
+
+                # Assert the logits and output strs are close
                 self.assert_close_logits_and_output_strs(
-                    DEFAULT_PROMPTS, model_case, torch_dtype
+                    prompts, model_case, torch_dtype
                 )
 
     def test_others(self):
@@ -152,13 +162,9 @@ class TestGenerationModels(unittest.TestCase):
             ):
                 continue
 
-            # Skip long prompts for models that does not have a long context
+            # Skip long prompts for models that do not have a long context
             prompts = DEFAULT_PROMPTS
-            if model_case.model_path in [
-                "HuggingFaceTB/SmolLM-135M-Instruct",
-                "allenai/OLMo-1B-0724-hf",
-                "google/gemma-2-2b",  # There is a bug with new transformers library. This can only run with transformers==4.44
-            ]:
+            if model_case.skip_long_prompt:
                 prompts = [p for p in DEFAULT_PROMPTS if len(p) < 1000]
 
             # Assert the logits and output strs are close
...
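With the flag in place, both tests apply the same rule, so the filtering could even be factored into a small helper; this helper is not in the commit and is shown only to make the shared pattern explicit:

```python
def select_prompts(model_case, default_prompts, max_len=1000):
    # Same cutoff as in the diff: for models flagged with skip_long_prompt,
    # keep only prompts shorter than 1000 characters.
    if model_case.skip_long_prompt:
        return [p for p in default_prompts if len(p) < max_len]
    return default_prompts
```

The per-model flag replaces the hard-coded `model_path` list, so opting a new short-context model out of long prompts now only takes `skip_long_prompt=True` in its `ModelCase` entry.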