Unverified Commit 1cb4da5c authored by Ying Sheng, committed by GitHub

[Fix] the issue of random order when input is a list (#1199)

parent e61d13ac
@@ -437,13 +437,13 @@ class TokenizerManager:
         is_stream = hasattr(obj, "stream") and obj.stream

         tasks = [asyncio.create_task(gen.__anext__()) for gen in generators]
-        output_list = []
+        output_list = [None] * len(tasks)

         while tasks:
             done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
             for task in done:
-                gen_index = tasks.index(task)
+                cur_index = tasks.index(task)
                 try:
                     result = task.result()
@@ -451,14 +451,14 @@ class TokenizerManager:
                     if is_stream:
                         yield result
                     else:
-                        output_list.append(result)
-                    tasks[gen_index] = asyncio.create_task(
-                        generators[gen_index].__anext__()
+                        output_list[result["index"]] = result
+                    tasks[cur_index] = asyncio.create_task(
+                        generators[cur_index].__anext__()
                     )
                 except StopAsyncIteration:
-                    del generators[gen_index]
-                    del tasks[gen_index]
+                    del generators[cur_index]
+                    del tasks[cur_index]

         if not is_stream:
             yield output_list
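
Why this fixes the ordering: asyncio.wait(..., return_when=asyncio.FIRST_COMPLETED) hands tasks back in completion order, so appending results produced a list whose order depended on scheduling rather than on the input order. Preallocating output_list and writing each result at result["index"] pins every output to the slot of the request that produced it. A minimal standalone sketch of the same pattern (the fake_stream generator and its random delays are illustrative, not sglang code):

import asyncio
import random

async def fake_stream(index: int):
    # Simulate one request whose (single) result arrives after a random delay
    # and carries the index of the input that produced it.
    await asyncio.sleep(random.random() * 0.1)
    yield {"index": index, "text": f"result-{index}"}

async def gather_in_order(n: int):
    generators = [fake_stream(i) for i in range(n)]
    tasks = [asyncio.create_task(gen.__anext__()) for gen in generators]
    output_list = [None] * len(tasks)  # one slot per input
    while tasks:
        done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            cur_index = tasks.index(task)
            try:
                result = task.result()
                output_list[result["index"]] = result  # place by input index
                tasks[cur_index] = asyncio.create_task(
                    generators[cur_index].__anext__()
                )
            except StopAsyncIteration:
                del generators[cur_index]
                del tasks[cur_index]
    return output_list

print(asyncio.run(gather_in_order(4)))
# The slots line up with the inputs no matter which task finished first.
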
@@ -591,7 +591,7 @@ class Runtime:
     def generate(
         self,
-        prompt: str,
+        prompt: Union[str, List[str]],
         sampling_params: Optional[Dict] = None,
         return_logprob: Optional[Union[List[bool], bool]] = False,
         logprob_start_len: Optional[Union[List[int], int]] = None,
@@ -612,7 +612,7 @@ class Runtime:
     def encode(
         self,
-        prompt: str,
+        prompt: Union[str, List[str]],
     ):
         json_data = {
             "text": prompt,
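
With the widened signatures, Runtime.generate and Runtime.encode accept either a single prompt or a list of prompts, and the tokenizer-manager fix above guarantees batched outputs come back in input order. A hypothetical usage sketch (the model path and sampling parameters are assumptions, and the sketch assumes generate returns the server's JSON response serialized as a string, so it is parsed before use):

import json

from sglang.srt.server import Runtime

# Assumed model path; any local or hub model supported by sglang works.
runtime = Runtime(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct")

prompts = [
    "The capital of France is",
    "Today is a sunny day and I like",
]
outputs = json.loads(runtime.generate(prompts, sampling_params={"max_new_tokens": 16}))
for prompt, out in zip(prompts, outputs):
    print(prompt, "->", out["text"])  # outputs[i] corresponds to prompts[i]

runtime.shutdown()
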
@@ -28,10 +28,10 @@ from sglang.srt.server import Runtime

 DEFAULT_PROMPTS = [
     # the output of gemma-2-2b from SRT is unstable on the commented prompt
     # "The capital of France is",
+    "Apple is red. Banana is Yellow. " * 800 + "Apple is",
     "The capital of the United Kindom is",
     "Today is a sunny day and I like",
     "AI is a field of computer science focused on",
-    "Apple is red. Banana is Yellow. " * 800 + "Apple is",
 ]

 dirpath = os.path.dirname(__file__)
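
The long repeated prompt now sits at index 0 of DEFAULT_PROMPTS. A quick length check (not part of the commit) shows which entries cross the 1000-character threshold used by the embedding test below:

DEFAULT_PROMPTS = [
    "Apple is red. Banana is Yellow. " * 800 + "Apple is",
    "The capital of the United Kindom is",
    "Today is a sunny day and I like",
    "AI is a field of computer science focused on",
]

for i, prompt in enumerate(DEFAULT_PROMPTS):
    print(i, len(prompt))
# Index 0 is 32 * 800 + 8 = 25608 characters, so it takes the looser
# long_context_tolerance branch; the other three are well under 1000
# characters and must satisfy the tight 1e-5 tolerance.
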
@@ -20,7 +20,7 @@ import torch
 from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
 from sglang.test.test_utils import get_similarities

-MODELS = [("intfloat/e5-mistral-7b-instruct", 1)]
+MODELS = [("intfloat/e5-mistral-7b-instruct", 1, 0.2)]
 TORCH_DTYPES = [torch.float16]
@@ -32,6 +32,7 @@ class TestEmbeddingModels(unittest.TestCase):
         model_path,
         tp_size,
         torch_dtype,
+        long_context_tolerance,
     ) -> None:
         with HFRunner(
             model_path, torch_dtype=torch_dtype, is_generation_model=False
@@ -52,20 +53,22 @@ class TestEmbeddingModels(unittest.TestCase):
                 hf_logits = torch.Tensor(hf_outputs.embed_logits[i])
                 srt_logits = torch.Tensor(srt_outputs.embed_logits[i])

-                similarities = torch.tensor(get_similarities(hf_logits, srt_logits))
+                similarity = torch.tensor(get_similarities(hf_logits, srt_logits))

-                print("max similarity diff", torch.max(abs(similarities - 1)))
+                print("similarity diff", abs(similarity - 1))

-                if hf_logits.shape[0] <= 100:
-                    tolerance = 1e-2
-                    assert torch.all(
-                        abs(similarities - 1) < tolerance
-                    ), "embeddings are not all close"
+                if len(prompts[i]) <= 1000:
+                    tolerance = 1e-5
+                else:
+                    tolerance = long_context_tolerance
+                assert torch.all(
+                    abs(similarity - 1) < tolerance
+                ), "embeddings are not all close"

     def test_prefill_logits(self):
-        for model, tp_size in MODELS:
+        for model, tp_size, long_context_tolerance in MODELS:
             for torch_dtype in TORCH_DTYPES:
                 self.assert_close_prefill_logits(
-                    DEFAULT_PROMPTS, model, tp_size, torch_dtype
+                    DEFAULT_PROMPTS, model, tp_size, torch_dtype, long_context_tolerance
                 )
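
The test now scales its tolerance with prompt length: short prompts must match the HuggingFace reference almost exactly, while the 25k-character prompt is allowed the per-model long_context_tolerance (0.2 for e5-mistral-7b-instruct). A minimal sketch of that check with an explicit cosine similarity, assuming get_similarities computes cosine similarity as the |similarity - 1| comparison implies (the real helper lives in sglang.test.test_utils):

import torch
import torch.nn.functional as F

def embeddings_close(
    hf_logits: torch.Tensor,
    srt_logits: torch.Tensor,
    prompt: str,
    long_context_tolerance: float = 0.2,
) -> bool:
    # Cosine similarity of the two embedding vectors; 1.0 means same direction.
    similarity = F.cosine_similarity(hf_logits, srt_logits, dim=0)
    # Short prompts must match almost exactly; very long prompts tolerate more
    # numerical drift between the HF and SRT forward passes.
    tolerance = 1e-5 if len(prompt) <= 1000 else long_context_tolerance
    return bool(torch.all(abs(similarity - 1) < tolerance))

# Identical vectors give similarity 1 and pass even the tight tolerance.
v = torch.randn(4096)
print(embeddings_close(v, v.clone(), "Today is a sunny day and I like"))
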