Unverified Commit eb7318f1 authored by narutolhy, committed by GitHub

support tokenized batch request (#11091)

parent 6058fb52
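
This change routes batches of pre-tokenized requests (requests that carry input_ids instead of text) through the batched send path even when `enable_tokenizer_batch_encode` is disabled. As a rough, hypothetical client-side sketch (not part of this patch), a tokenized batch request against the native /generate endpoint could look like the following; the host, port, and model name are assumptions:

import requests
from transformers import AutoTokenizer

# Hypothetical example values; any model served by SGLang would work here.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
prompts = ["What is the capital of France?", "Write a haiku about autumn."]

# Tokenize on the client so each request arrives with input_ids and no text,
# which is exactly the case the new _batch_has_text() check short-circuits.
batch_input_ids = [tokenizer.encode(p) for p in prompts]

resp = requests.post(
    "http://localhost:30000/generate",  # assumed default SGLang port
    json={
        "input_ids": batch_input_ids,               # batch of pre-tokenized prompts
        "sampling_params": {"max_new_tokens": 64},  # shared sampling params
    },
)
print(resp.json())
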
@@ -759,6 +759,14 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         """Handle batch tokenization for text inputs only."""
         logger.debug(f"Starting batch tokenization for {batch_size} text requests")
 
+        # If the batch has no text, there is nothing to tokenize,
+        # so construct the return objects directly.
+        if not self._batch_has_text(batch_size, obj):
+            # All requests already have input_ids, no need to tokenize
+            return [await self._tokenize_one_request(obj[i]) for i in range(batch_size)]
+
+        self._validate_batch_tokenization_constraints(batch_size, obj)
+
         # Collect requests and texts
         requests = [obj[i] for i in range(batch_size)]
         texts = [req.text for req in requests]
@@ -808,6 +816,30 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                     "Batch tokenization is not needed for input_embeds. Do not set `enable_tokenizer_batch_encode`."
                 )
 
+    def _batch_has_text(
+        self, batch_size: int, obj: Union[GenerateReqInput, EmbeddingReqInput]
+    ) -> bool:
+        """Check if any request in the batch contains text or multimodal input."""
+        for i in range(batch_size):
+            if obj[i].text:
+                return True
+            elif self.is_generation and obj[i].contains_mm_input():
+                return True
+        return False
+
+    def _should_use_batch_tokenization(self, batch_size, requests) -> bool:
+        """Return True if we should run the tokenizer in batch mode.
+
+        Current policy:
+        - Respect the explicit server flag `enable_tokenizer_batch_encode`.
+        - Otherwise, if no request has text or multimodal input (all use pre-tokenized
+          input_ids or input_embeds), batch the requests without tokenization.
+        """
+        return batch_size > 0 and (
+            self.server_args.enable_tokenizer_batch_encode
+            or not self._batch_has_text(batch_size, requests)
+        )
+
     def _send_one_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
@@ -942,13 +974,8 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         generators = []
         rids = []
         if getattr(obj, "parallel_sample_num", 1) == 1:
-            if self.server_args.enable_tokenizer_batch_encode:
-                # Validate batch tokenization constraints
-                self._validate_batch_tokenization_constraints(batch_size, obj)
+            if self._should_use_batch_tokenization(batch_size, obj):
                 tokenized_objs = await self._batch_tokenize_and_process(batch_size, obj)
-                # Send as a single batched request
                 self._send_batch_request(obj, tokenized_objs, created_time)
                 # Set up generators for each request in the batch
...
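
To make the new dispatch policy concrete, here is a minimal standalone sketch of the decision implemented by `_should_use_batch_tokenization`. It uses simplified stand-ins for GenerateReqInput and the server flag rather than the actual TokenizerManager API: a fully pre-tokenized batch now takes the batched path even with `enable_tokenizer_batch_encode` left off, while a batch containing raw text still requires the flag.

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class FakeReq:
    # Simplified stand-in for GenerateReqInput: a request carries either raw
    # text or pre-tokenized input_ids.
    text: Optional[str] = None
    input_ids: Optional[List[int]] = None


def should_use_batch_tokenization(batch_encode_enabled: bool, reqs: List[FakeReq]) -> bool:
    # Take the batched path when the server flag is on, or when nothing in the
    # batch needs tokenizing because every request is already pre-tokenized.
    has_text = any(r.text for r in reqs)
    return len(reqs) > 0 and (batch_encode_enabled or not has_text)


pretokenized = [FakeReq(input_ids=[1, 2, 3]), FakeReq(input_ids=[4, 5])]
mixed = [FakeReq(text="hello"), FakeReq(input_ids=[4, 5])]

print(should_use_batch_tokenization(False, pretokenized))  # True: new fast path
print(should_use_batch_tokenization(False, mixed))         # False: per-request tokenization
print(should_use_batch_tokenization(True, mixed))          # True: flag forces batch mode
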