Unverified commit f290bd43, authored by Chang Su, committed by GitHub

[Bugfix] Fix embedding model hangs with `--enable-metrics` (#2822)

parent 8f157893
@@ -128,7 +128,7 @@ class ModelConfig:
         self.num_hidden_layers = self.hf_text_config.num_hidden_layers
         self.vocab_size = self.hf_text_config.vocab_size
 
-        # Veirfy quantization
+        # Verify quantization
         self._verify_quantization()
 
         # Cache attributes
...
@@ -688,7 +688,7 @@ class TokenizerManager:
             if self.enable_metrics:
                 completion_tokens = (
                     recv_obj.completion_tokens[i]
-                    if recv_obj.completion_tokens
+                    if getattr(recv_obj, "completion_tokens", None)
                     else 0
                 )
@@ -716,7 +716,11 @@ class TokenizerManager:
                         time.time() - state.created_time
                     )
                     # Compute time_per_output_token for the non-streaming case
-                    if not state.obj.stream and completion_tokens >= 1:
+                    if (
+                        hasattr(state.obj, "stream")
+                        and not state.obj.stream
+                        and completion_tokens >= 1
+                    ):
                         self.metrics_collector.observe_time_per_output_token(
                             (time.time() - state.created_time)
                             / completion_tokens
...
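Note on the two guards above: with `--is-embedding`, the batch objects that reach this metrics path come from the embedding pipeline and do not carry `completion_tokens`, and the corresponding request objects need not expose `stream`. Plain attribute access there can raise, leaving the waiting request unresolved, which is consistent with the hang this commit addresses. The snippet below is a minimal, self-contained sketch of the defensive-access pattern only; `BatchStrOut` and `BatchEmbeddingOut` are simplified stand-ins with assumed shapes, not the real sglang classes.

# Minimal sketch of the defensive attribute access used in the fix.
# BatchStrOut / BatchEmbeddingOut below are simplified stand-ins (assumed
# shapes), not the actual sglang output classes.
from dataclasses import dataclass
from typing import List


@dataclass
class BatchStrOut:  # generation-style output: carries per-request token counts
    completion_tokens: List[int]


@dataclass
class BatchEmbeddingOut:  # embedding-style output: no completion_tokens field
    embeddings: List[List[float]]


def completion_tokens_for_metrics(recv_obj, i: int) -> int:
    # getattr(..., None) yields None instead of raising AttributeError when the
    # field is missing, so embedding outputs simply report 0 completion tokens.
    return (
        recv_obj.completion_tokens[i]
        if getattr(recv_obj, "completion_tokens", None)
        else 0
    )


print(completion_tokens_for_metrics(BatchStrOut(completion_tokens=[42]), 0))  # 42
print(completion_tokens_for_metrics(BatchEmbeddingOut(embeddings=[[0.1]]), 0))  # 0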
@@ -724,7 +724,7 @@ class ModelRunner:
         elif forward_batch.forward_mode.is_idle():
             return self.forward_idle(forward_batch)
         else:
-            raise ValueError(f"Invaid forward mode: {forward_batch.forward_mode}")
+            raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")
 
     def sample(
         self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
...
@@ -14,6 +14,7 @@ import openai
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.utils import kill_process_tree
 from sglang.test.test_utils import (
+    DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST,
     DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
@@ -675,5 +676,45 @@ class TestOpenAIServerEBNF(unittest.TestCase):
         ), "Function name should be add for the above response"
 
 
+class TestOpenAIEmbedding(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.api_key = "sk-123456"
+        # Configure embedding-specific args
+        other_args = ["--is-embedding", "--enable-metrics"]
+
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            api_key=cls.api_key,
+            other_args=other_args,
+        )
+        cls.base_url += "/v1"
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def test_embedding_single(self):
+        """Test single embedding request"""
+        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
+        response = client.embeddings.create(model=self.model, input="Hello world")
+        self.assertEqual(len(response.data), 1)
+        self.assertTrue(len(response.data[0].embedding) > 0)
+
+    def test_embedding_batch(self):
+        """Test batch embedding request"""
+        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
+        response = client.embeddings.create(
+            model=self.model, input=["Hello world", "Test text"]
+        )
+        self.assertEqual(len(response.data), 2)
+        self.assertTrue(len(response.data[0].embedding) > 0)
+        self.assertTrue(len(response.data[1].embedding) > 0)
+
+
 if __name__ == "__main__":
     unittest.main()
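Beyond the unit tests added above, a quick manual sanity check is possible once a server is running with `--is-embedding --enable-metrics`. The snippet below is hypothetical and not part of this commit: the server address, the model string, and the `/metrics` path for the Prometheus scrape endpoint are all assumptions; adjust them to the actual launch configuration.

# Hypothetical manual check (not part of this commit): send one embedding
# request, then scrape the metrics endpoint to confirm the server responded
# instead of hanging. Address, model name, and /metrics path are assumptions.
import openai
import requests

BASE_URL = "http://127.0.0.1:30000"       # assumed host/port of the server
EMBEDDING_MODEL = "your-embedding-model"  # placeholder; match --model-path

client = openai.Client(api_key="sk-123456", base_url=f"{BASE_URL}/v1")
response = client.embeddings.create(model=EMBEDDING_MODEL, input="Hello world")
print(len(response.data[0].embedding))    # should print promptly, not hang

metrics = requests.get(f"{BASE_URL}/metrics", timeout=10)
print(metrics.status_code)                # expected 200 if /metrics is exposed
print(metrics.text[:300])                 # first few Prometheus metric lines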