"tests/python/common/test_merge.py" did not exist on "98325b1097877b93dc872727d22ce2f402666e8f"
Unverified commit a027a9b4, authored by Sundara Raman Ramachandran, committed by GitHub

[Generative Score API] Optimization to Remove Decode. (#8840)

parent 9e426466
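For orientation, here is a minimal usage sketch of the scoring path this commit optimizes. The `engine.score(...)` call and its arguments mirror the test added below; the model path and label token IDs are illustrative placeholders, not values from this diff.

```python
# Hypothetical usage sketch of the Generative Score API after this change.
# Scoring only needs prefill-time logprobs for a fixed set of label tokens,
# so no decode step (max_new_tokens == 0) is required.
import sglang as sgl

if __name__ == "__main__":
    # Placeholder model path; substitute whatever model you actually serve.
    engine = sgl.Engine(model_path="Qwen/Qwen2.5-0.5B-Instruct")

    scores = engine.score(
        query="What is the capital of",
        items=["France", "Germany"],
        label_token_ids=[1, 2, 3],  # illustrative token IDs for candidate labels
        apply_softmax=True,
    )
    print(scores)  # one score list per item

    engine.shutdown()
```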
@@ -913,6 +913,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # Whether to return hidden states
     return_hidden_states: bool = False
 
+    # Whether this batch is prefill-only (no token generation needed)
+    is_prefill_only: bool = False
+
     # hicache pointer for synchronizing data loading from CPU to GPU
     hicache_consumer_index: int = 0
@@ -953,6 +956,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             device=req_to_token_pool.device,
             spec_algorithm=spec_algorithm,
             return_hidden_states=any(req.return_hidden_states for req in reqs),
+            is_prefill_only=all(
+                req.sampling_params.max_new_tokens == 0 for req in reqs
+            ),
             chunked_req=chunked_req,
         )
@@ -1796,6 +1802,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
             can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
             is_extend_in_batch=self.is_extend_in_batch,
+            is_prefill_only=self.is_prefill_only,
         )
 
     def _evict_tree_cache_if_needed(self, num_tokens: int):
...
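The new flag is just the conjunction of `max_new_tokens == 0` across the batch. A self-contained sketch of that derivation, using hypothetical stand-ins for the real `Req` and `SamplingParams` classes:

```python
# Minimal sketch of the is_prefill_only derivation, assuming hypothetical
# stand-ins for the real Req / SamplingParams classes.
from dataclasses import dataclass
from typing import List


@dataclass
class FakeSamplingParams:
    max_new_tokens: int


@dataclass
class FakeReq:
    sampling_params: FakeSamplingParams


def compute_is_prefill_only(reqs: List[FakeReq]) -> bool:
    # A batch is prefill-only when no request asks for any generated tokens,
    # e.g. scoring requests that only need next-token logprobs after the prompt.
    return all(req.sampling_params.max_new_tokens == 0 for req in reqs)


reqs = [FakeReq(FakeSamplingParams(0)), FakeReq(FakeSamplingParams(0))]
assert compute_is_prefill_only(reqs)  # all scoring requests
assert not compute_is_prefill_only(
    reqs + [FakeReq(FakeSamplingParams(16))]  # one normal generation request
)
```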
@@ -1466,8 +1466,9 @@ class Scheduler(
             if self.last_batch.batch_size() < last_bs:
                 self.running_batch.batch_is_full = False
 
-        # Merge the new batch into the running batch
-        if not self.last_batch.is_empty():
+        # Merge the new batch into the running batch.
+        # For prefill-only batch, we can avoid going through decoding step.
+        if not self.last_batch.is_empty() and not self.last_batch.is_prefill_only:
             if self.running_batch.is_empty():
                 self.running_batch = self.last_batch
             else:
...
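With this scheduler change, a finished prefill-only batch is never merged into `running_batch`, so it skips the decode loop entirely. A toy illustration of that control flow (the `ToyBatch` class and the module-level `running` list are invented for this sketch):

```python
# Toy illustration of the scheduling decision, not the real Scheduler code.
from dataclasses import dataclass
from typing import List


@dataclass
class ToyBatch:
    reqs: List[str]
    is_prefill_only: bool = False

    def is_empty(self) -> bool:
        return not self.reqs


running: List[str] = []


def merge_last_batch(last_batch: ToyBatch) -> None:
    # Mirrors the new condition: prefill-only batches are dropped after prefill
    # instead of being queued for decode steps.
    if not last_batch.is_empty() and not last_batch.is_prefill_only:
        running.extend(last_batch.reqs)


merge_last_batch(ToyBatch(["score-req-1", "score-req-2"], is_prefill_only=True))
merge_last_batch(ToyBatch(["gen-req-1"], is_prefill_only=False))
assert running == ["gen-req-1"]  # only the generation request goes on to decode
```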
@@ -699,7 +699,7 @@ class TokenizerManager:
         # Process all requests
         tokenized_objs = []
         for i, req in enumerate(requests):
-            self._validate_token_len(obj[i], input_ids_list[i])
+            self._validate_one_request(obj[i], input_ids_list[i])
             tokenized_objs.append(
                 self._create_tokenized_object(
                     req, req.text, input_ids_list[i], None, None
@@ -1892,6 +1892,13 @@ class TokenizerManager:
                     f"Token ID {token_id} is out of vocabulary (vocab size: {vocab_size})"
                 )
 
+        batch_request = GenerateReqInput(
+            token_ids_logprob=label_token_ids,
+            return_logprob=True,
+            stream=False,
+            sampling_params={"max_new_tokens": 0},
+        )
+
         # Handle string or tokenized query/items
         if isinstance(query, str) and (
             isinstance(items, str)
@@ -1903,13 +1910,9 @@ class TokenizerManager:
                 prompts = [f"{item}{query}" for item in items_list]
             else:
                 prompts = [f"{query}{item}" for item in items_list]
-            batch_request = GenerateReqInput(
-                text=prompts,
-                return_logprob=True,
-                token_ids_logprob=label_token_ids,
-                stream=False,
-                sampling_params={"max_new_tokens": 1},
-            )
+
+            batch_request.text = prompts
+
         elif (
             isinstance(query, list)
             and isinstance(items, list)
@@ -1921,13 +1924,8 @@ class TokenizerManager:
                 input_ids_list = [item + query for item in items]
             else:
                 input_ids_list = [query + item for item in items]
-            batch_request = GenerateReqInput(
-                input_ids=input_ids_list,
-                return_logprob=True,
-                token_ids_logprob=label_token_ids,
-                stream=False,
-                sampling_params={"max_new_tokens": 1},
-            )
+
+            batch_request.input_ids = input_ids_list
         else:
             raise ValueError(
                 "Invalid combination of query/items types for score_request."
@@ -1939,9 +1937,20 @@ class TokenizerManager:
         for result in results:
             # Get logprobs for each token
             logprobs = {}
-            for logprob, token_id, _ in result["meta_info"].get(
-                "output_token_ids_logprobs", []
-            )[0]:
+
+            # For scoring requests, we read from output_token_ids_logprobs since we want
+            # the logprobs for specific tokens mentioned in the label_token_ids at
+            # the next position after the last token in the prompt
+            output_logprobs = result["meta_info"].get("output_token_ids_logprobs", [])
+
+            # Throw an error here if output_logprobs is None
+            if output_logprobs is None:
+                raise RuntimeError(
+                    f"output_logprobs is None for request {result['meta_info'].get('id', '<unknown>')}. "
+                    "This usually indicates a problem with the scoring request or the backend output."
+                )
+
+            for logprob, token_id, _ in output_logprobs[0]:
                 if token_id in label_token_ids:
                     logprobs[token_id] = logprob
...
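Each result's `output_token_ids_logprobs[0]` holds `(logprob, token_id, ...)` tuples for the label tokens at the first position after the prompt. A self-contained sketch of turning them into per-item scores with an optional softmax; the `fake_results` structure below is fabricated for illustration:

```python
# Sketch of assembling scores from output_token_ids_logprobs, assuming a
# fabricated results structure shaped like the tokenizer manager's output.
import math
from typing import Dict, List


def scores_from_results(
    results: List[dict], label_token_ids: List[int], apply_softmax: bool = True
) -> List[List[float]]:
    all_scores = []
    for result in results:
        entries = result["meta_info"]["output_token_ids_logprobs"][0]
        logprobs: Dict[int, float] = {
            token_id: logprob for logprob, token_id, _ in entries
            if token_id in label_token_ids
        }
        row = [logprobs[t] for t in label_token_ids]
        if apply_softmax:
            # Numerically stable softmax over the label logprobs.
            m = max(row)
            exps = [math.exp(x - m) for x in row]
            total = sum(exps)
            row = [e / total for e in exps]
        all_scores.append(row)
    return all_scores


fake_results = [
    {"meta_info": {"output_token_ids_logprobs": [[(-0.1, 1, None), (-2.3, 2, None), (-4.0, 3, None)]]}}
]
print(scores_from_results(fake_results, [1, 2, 3]))  # softmaxed scores summing to 1
```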
@@ -213,6 +213,88 @@ class TestScoreAPI(CustomTestCase):
             1.0, sum(score_list), 6, "Scores should sum to 1"
         )
 
+    def test_score_request_construction(self):
+        """Test that scoring requests are constructed to avoid decode phase."""
+        from unittest.mock import patch
+
+        # Capture the internal request to verify optimization
+        captured_requests = []
+        original_gen = self.engine.tokenizer_manager.generate_request
+
+        async def mock_generate_request(req, request=None):
+            captured_requests.append(req)
+            async for result in original_gen(req, request):
+                yield result
+
+        # Patch the generate_request method
+        with patch.object(
+            self.engine.tokenizer_manager,
+            "generate_request",
+            side_effect=mock_generate_request,
+        ):
+            # Run a scoring request
+            query = "What is the capital of"
+            items = ["France", "Germany"]
+            label_token_ids = [1, 2, 3]
+
+            scores = self.engine.score(
+                query=query,
+                items=items,
+                label_token_ids=label_token_ids,
+                apply_softmax=True,
+            )
+
+            # Verify we got results
+            self.assertEqual(len(scores), len(items))
+
+            # Verify the captured request has decode-avoiding properties
+            self.assertEqual(len(captured_requests), 1)
+            request = captured_requests[0]
+
+            # Key assertions for decode phase avoidance:
+            # 1. max_new_tokens should be 0 (prevents token generation)
+            # Handle both single and batch request cases
+            if isinstance(request.sampling_params, dict):
+                max_new_tokens = request.sampling_params.get("max_new_tokens", 0)
+            elif isinstance(request.sampling_params, list):
+                # For batch requests, check the first item
+                max_new_tokens = request.sampling_params[0].get("max_new_tokens", 0)
+            else:
+                max_new_tokens = getattr(request.sampling_params, "max_new_tokens", 0)
+
+            self.assertEqual(
+                max_new_tokens, 0, "max_new_tokens should be 0 to avoid decode phase"
+            )
+
+            # 2. Should have token_ids_logprob for scoring
+            # Handle both single and batch request cases
+            if (
+                isinstance(request.token_ids_logprob, list)
+                and len(request.token_ids_logprob) > 0
+                and isinstance(request.token_ids_logprob[0], list)
+            ):
+                # Batch case: token_ids_logprob is a list of lists
+                # Each item in the batch should have the same label_token_ids
+                for item_token_ids in request.token_ids_logprob:
+                    self.assertEqual(
+                        item_token_ids,
+                        label_token_ids,
+                        "Each batch item should have label_token_ids for scoring",
+                    )
+            else:
+                # Single request case
+                self.assertEqual(
+                    request.token_ids_logprob,
+                    label_token_ids,
+                    "Should have label_token_ids for scoring",
+                )
+
+            # 3. Should request logprobs but not stream
+            self.assertTrue(
+                request.return_logprob, "Should request logprobs for scoring"
+            )
+            self.assertFalse(request.stream, "Scoring requests should not stream")
+
 
 if __name__ == "__main__":
     unittest.main()
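The test above wraps `generate_request`, an async generator, so each request object can be captured before results are streamed through. A stripped-down, standalone version of that capture pattern using only the standard library (`FakeManager` is a stand-in for the real tokenizer manager):

```python
# Standalone sketch of the async-generator capture pattern used by the test.
import asyncio
from unittest.mock import patch


class FakeManager:
    async def generate_request(self, req, request=None):
        # Pretend to stream one result per request.
        yield {"req": req, "meta_info": {}}


async def main() -> None:
    manager = FakeManager()
    captured = []
    original = manager.generate_request

    async def capturing(req, request=None):
        captured.append(req)  # record the request for later assertions
        async for out in original(req, request):
            yield out

    with patch.object(manager, "generate_request", side_effect=capturing):
        async for _ in manager.generate_request({"sampling_params": {"max_new_tokens": 0}}):
            pass

    assert captured and captured[0]["sampling_params"]["max_new_tokens"] == 0


asyncio.run(main())
```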
"""
Unit tests for enable_tokenizer_batch_encode feature.
This tests the batch tokenization functionality which allows processing
multiple text inputs in a single batch for improved performance.
Usage:
python3 -m unittest test_tokenizer_batch_encode.TestTokenizerBatchEncode.test_batch_validation_constraints
python3 -m unittest test_tokenizer_batch_encode.TestTokenizerBatchEncodeUnit.test_batch_tokenize_and_process_logic
python3 -m unittest test_tokenizer_batch_encode.TestTokenizerBatchEncodeLogic.test_batch_processing_path
"""
import asyncio
import unittest
from typing import List
from unittest.mock import AsyncMock, Mock, call, patch
from sglang.srt.managers.io_struct import GenerateReqInput, TokenizedGenerateReqInput
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
class TestTokenizerBatchEncode(unittest.TestCase):
"""Test cases for tokenizer batch encoding validation and setup."""
def setUp(self):
"""Set up test fixtures."""
self.server_args = ServerArgs(
model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
enable_tokenizer_batch_encode=True,
)
self.port_args = PortArgs.init_new(self.server_args)
with patch("zmq.asyncio.Context"), patch(
"sglang.srt.utils.get_zmq_socket"
), patch("sglang.srt.hf_transformers_utils.get_tokenizer") as mock_tokenizer:
mock_tokenizer.return_value = Mock(vocab_size=32000)
self.tokenizer_manager = TokenizerManager(self.server_args, self.port_args)
def test_batch_encode_enabled(self):
"""Test that batch encoding is enabled when configured."""
self.assertTrue(self.server_args.enable_tokenizer_batch_encode)
def test_batch_encode_disabled(self):
"""Test that batch encoding can be disabled."""
server_args_disabled = ServerArgs(
model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
enable_tokenizer_batch_encode=False,
)
self.assertFalse(server_args_disabled.enable_tokenizer_batch_encode)
def test_multimodal_input_validation(self):
"""Test that multimodal inputs are rejected in batch mode."""
req = GenerateReqInput(text="test", image_data=["dummy"])
req.contains_mm_input = Mock(return_value=True)
batch_obj = Mock()
batch_obj.__getitem__ = lambda self, i: req
self.tokenizer_manager.is_generation = True
with self.assertRaises(ValueError) as cm:
self.tokenizer_manager._validate_batch_tokenization_constraints(
1, batch_obj
)
self.assertIn("multimodal", str(cm.exception))
def test_pretokenized_input_validation(self):
"""Test that pre-tokenized inputs are rejected in batch mode."""
req = GenerateReqInput(input_ids=[1, 2, 3])
batch_obj = Mock()
batch_obj.__getitem__ = lambda self, i: req
with self.assertRaises(ValueError) as cm:
self.tokenizer_manager._validate_batch_tokenization_constraints(
1, batch_obj
)
self.assertIn("pre-tokenized", str(cm.exception))
def test_input_embeds_validation(self):
"""Test that input embeds are rejected in batch mode."""
req = GenerateReqInput(input_embeds=[0.1, 0.2])
batch_obj = Mock()
batch_obj.__getitem__ = lambda self, i: req
with self.assertRaises(ValueError) as cm:
self.tokenizer_manager._validate_batch_tokenization_constraints(
1, batch_obj
)
self.assertIn("input_embeds", str(cm.exception))
def test_valid_text_only_requests_pass_validation(self):
"""Test that valid text-only requests pass validation."""
# Create valid requests (text-only)
requests = []
for i in range(3):
req = GenerateReqInput(text=f"test text {i}")
req.contains_mm_input = Mock(return_value=False)
requests.append(req)
batch_obj = Mock()
batch_obj.__getitem__ = Mock(side_effect=lambda i: requests[i])
# Should not raise any exception
try:
self.tokenizer_manager._validate_batch_tokenization_constraints(
3, batch_obj
)
except Exception as e:
self.fail(f"Validation failed for valid text-only requests: {e}")
if __name__ == "__main__":
unittest.main(verbosity=2)
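The assertions in these tests pin down what `_validate_batch_tokenization_constraints` must reject: multimodal inputs, pre-tokenized `input_ids`, and `input_embeds`. Below is a hypothetical re-implementation written only to make that contract explicit; it is not the actual SGLang implementation, and its error messages are merely shaped to satisfy the assertions above.

```python
# Hypothetical validator mirroring the constraints asserted by the tests above;
# the real implementation lives in TokenizerManager and may differ in detail.
from typing import Any


def validate_batch_tokenization_constraints(batch_size: int, batch_obj: Any) -> None:
    for i in range(batch_size):
        req = batch_obj[i]
        if getattr(req, "contains_mm_input", None) and req.contains_mm_input():
            raise ValueError(
                "Batch tokenization is not supported for multimodal input."
            )
        if getattr(req, "input_ids", None) is not None:
            raise ValueError(
                "Batch tokenization is not needed for pre-tokenized input_ids."
            )
        if getattr(req, "input_embeds", None) is not None:
            raise ValueError(
                "Batch tokenization is not supported for input_embeds."
            )


class _Req:
    def __init__(self, **kw):
        self.__dict__.update(kw)

    def contains_mm_input(self):
        return getattr(self, "image_data", None) is not None


try:
    validate_batch_tokenization_constraints(1, [_Req(text="hi", image_data=["img"])])
except ValueError as e:
    assert "multimodal" in str(e)
```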