Unverified Commit b68c4c07 authored by Ying Sheng, committed by GitHub
Browse files

fix: force max new tokens to be 1 for embedding request (#1019)

parent e712837d
......@@ -195,7 +195,8 @@ class EmbeddingReqInput:
if self.rid is None:
self.rid = uuid.uuid4().hex
if self.sampling_params is None:
self.sampling_params = {"max_new_tokens": 1}
self.sampling_params = {}
self.sampling_params["max_new_tokens"] = 1
else:
# support select operation
self.batch_size = (
......@@ -207,9 +208,9 @@ class EmbeddingReqInput:
if not isinstance(self.rid, list):
raise ValueError("The rid should be a list.")
if self.sampling_params is None:
self.sampling_params = [
{"max_new_tokens": 1} for _ in range(self.batch_size)
]
self.sampling_params = [{}] * self.batch_size
for i in range(self.batch_size):
self.sampling_params[i]["max_new_tokens"] = 1
@dataclass
......
......@@ -44,7 +44,9 @@ class TestEmbeddingModels(unittest.TestCase):
torch_dtype=torch_dtype,
is_generation_model=False,
) as srt_runner:
srt_outputs = srt_runner.forward(prompts)
srt_outputs = srt_runner.forward(
prompts,
)
for i in range(len(prompts)):
hf_logits = torch.Tensor(hf_outputs.embed_logits[i])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment