Unverified Commit eff468dd authored by Xiaoyu Zhang, committed by GitHub

fix prompt-length-too-long bug in test_embedding_models (#2015)

parent a1bd7190
@@ -17,6 +17,7 @@ import multiprocessing as mp
 import unittest
 
 import torch
+from transformers import AutoConfig, AutoTokenizer
 
 from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
 from sglang.test.test_utils import get_similarities
@@ -34,6 +35,24 @@ class TestEmbeddingModels(unittest.TestCase):
     def setUpClass(cls):
         mp.set_start_method("spawn", force=True)
 
+    def _truncate_prompts(self, prompts, model_path):
+        config = AutoConfig.from_pretrained(model_path)
+        max_length = getattr(config, "max_position_embeddings", 2048)
+        tokenizer = AutoTokenizer.from_pretrained(model_path)
+
+        truncated_prompts = []
+        for prompt in prompts:
+            tokens = tokenizer(prompt, return_tensors="pt", truncation=False)
+            if len(tokens.input_ids[0]) > max_length:
+                truncated_text = tokenizer.decode(
+                    tokens.input_ids[0][: max_length - 1], skip_special_tokens=True
+                )
+                truncated_prompts.append(truncated_text)
+            else:
+                truncated_prompts.append(prompt)
+
+        return truncated_prompts
+
     def assert_close_prefill_logits(
         self,
         prompts,
@@ -42,12 +61,14 @@
         torch_dtype,
         prefill_tolerance,
     ) -> None:
+        truncated_prompts = self._truncate_prompts(prompts, model_path)
+
         with HFRunner(
             model_path,
             torch_dtype=torch_dtype,
             model_type="embedding",
         ) as hf_runner:
-            hf_outputs = hf_runner.forward(prompts)
+            hf_outputs = hf_runner.forward(truncated_prompts)
 
         with SRTRunner(
             model_path,
@@ -55,7 +76,7 @@
             torch_dtype=torch_dtype,
             model_type="embedding",
         ) as srt_runner:
-            srt_outputs = srt_runner.forward(prompts)
+            srt_outputs = srt_runner.forward(truncated_prompts)
 
         for i in range(len(prompts)):
             hf_logits = torch.Tensor(hf_outputs.embed_logits[i])
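The fix works by handing both runners prompts that already fit inside the model's context window. Below is a standalone sketch of the same truncation idea with a small usage example; the model id is only an illustrative small embedding model, not necessarily one exercised by this test, and max_position_embeddings is assumed to exist on the config (with a fallback of 2048 when it does not).

# Standalone sketch of the truncation helper added in this commit.
from transformers import AutoConfig, AutoTokenizer

def truncate_prompts(prompts, model_path):
    config = AutoConfig.from_pretrained(model_path)
    # Fall back to 2048 when the config defines no maximum sequence length.
    max_length = getattr(config, "max_position_embeddings", 2048)
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    truncated = []
    for prompt in prompts:
        ids = tokenizer(prompt, return_tensors="pt", truncation=False).input_ids[0]
        if len(ids) > max_length:
            # Keep the first max_length - 1 tokens and decode back to text so the
            # runners later re-tokenize a prompt that is guaranteed to fit.
            prompt = tokenizer.decode(ids[: max_length - 1], skip_special_tokens=True)
        truncated.append(prompt)
    return truncated

if __name__ == "__main__":
    prompts = ["hello world", "word " * 2000]  # second prompt exceeds a 512-token window
    shortened = truncate_prompts(prompts, "sentence-transformers/all-MiniLM-L6-v2")
    print([len(p) for p in shortened])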
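Beyond the lines shown, assert_close_prefill_logits goes on to compare the per-prompt embeddings produced by the two runners against prefill_tolerance. A hedged sketch of that comparison pattern follows, using torch.nn.functional.cosine_similarity as a stand-in for sglang's get_similarities helper (imported above, exact signature not shown here); the default tolerance value is hypothetical.

# Sketch of the embedding-comparison step, not the test's verbatim code.
# hf_embed_logits / srt_embed_logits are assumed to be per-prompt embedding
# vectors, matching the embed_logits fields used in the diff above.
import torch
import torch.nn.functional as F

def assert_embeddings_close(hf_embed_logits, srt_embed_logits, prefill_tolerance=1e-2):
    for i in range(len(hf_embed_logits)):
        hf_logits = torch.Tensor(hf_embed_logits[i])
        srt_logits = torch.Tensor(srt_embed_logits[i])
        # A cosine similarity close to 1.0 means the two backends agree.
        similarity = F.cosine_similarity(hf_logits, srt_logits, dim=0).item()
        assert abs(similarity - 1.0) < prefill_tolerance, (
            f"prompt {i}: similarity {similarity:.6f} exceeds tolerance"
        )

# Example usage with toy vectors:
# assert_embeddings_close([[1.0, 0.0]], [[0.999, 0.001]])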