Unverified Commit ea0696b9 authored by Sundara Raman Ramachandran, committed by GitHub

[Performance] Batch Send from Tokenizer Manager. (#9436)

parent 3aec3d4f
python/sglang/srt/managers/io_struct.py
@@ -533,6 +533,21 @@ class TokenizedGenerateReqInput:
     dp_balance_id: int = -1
 
 
+@dataclass
+class BatchTokenizedGenerateReqInput:
+    # The batch of tokenized requests
+    batch: List[TokenizedGenerateReqInput]
+
+    def __len__(self):
+        return len(self.batch)
+
+    def __getitem__(self, i):
+        return self.batch[i]
+
+    def __iter__(self):
+        return iter(self.batch)
+
+
 @dataclass
 class EmbeddingReqInput:
     # The input prompt. It can be a single prompt or a batch of prompts.
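The wrapper only stores a list, but implementing `__len__`, `__getitem__`, and `__iter__` lets downstream code treat a batch like a plain sequence of requests. A minimal standalone sketch of that behavior, with the request dataclass trimmed to hypothetical fields for illustration:

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class TokenizedGenerateReqInput:  # stand-in with only the fields used below
    rid: str
    input_ids: List[int] = field(default_factory=list)


@dataclass
class BatchTokenizedGenerateReqInput:
    # The batch of tokenized requests
    batch: List[TokenizedGenerateReqInput]

    def __len__(self):
        return len(self.batch)

    def __getitem__(self, i):
        return self.batch[i]

    def __iter__(self):
        return iter(self.batch)


reqs = [TokenizedGenerateReqInput(rid=f"req-{i}") for i in range(3)]
batch_req = BatchTokenizedGenerateReqInput(batch=reqs)

assert len(batch_req) == 3          # len() works via __len__
assert batch_req[0].rid == "req-0"  # indexing works via __getitem__
for r in batch_req:                 # iteration works via __iter__
    print(r.rid)
```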
@@ -668,6 +683,21 @@ class TokenizedEmbeddingReqInput:
     dp_balance_id: int = -1
 
 
+@dataclass
+class BatchTokenizedEmbeddingReqInput:
+    # The batch of tokenized embedding requests
+    batch: List[TokenizedEmbeddingReqInput]
+
+    def __len__(self):
+        return len(self.batch)
+
+    def __getitem__(self, i):
+        return self.batch[i]
+
+    def __iter__(self):
+        return iter(self.batch)
+
+
 @dataclass
 class BatchTokenIDOut:
     # The request id
...
python/sglang/srt/managers/scheduler.py
@@ -67,6 +67,8 @@ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.moe import initialize_moe_config
 from sglang.srt.managers.io_struct import (
     AbortReq,
+    BatchTokenizedEmbeddingReqInput,
+    BatchTokenizedGenerateReqInput,
     CloseSessionReqInput,
     ExpertDistributionReq,
     ExpertDistributionReqOutput,
@@ -510,6 +512,8 @@ class Scheduler(
             [
                 (TokenizedGenerateReqInput, self.handle_generate_request),
                 (TokenizedEmbeddingReqInput, self.handle_embedding_request),
+                (BatchTokenizedGenerateReqInput, self.handle_batch_generate_request),
+                (BatchTokenizedEmbeddingReqInput, self.handle_batch_embedding_request),
                 (FlushCacheReqInput, self.flush_cache_wrapped),
                 (AbortReq, self.abort_request),
                 (OpenSessionReqInput, self.open_session),
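The scheduler registers handlers in a (type, handler) table and dispatches each incoming message on its concrete class, which is why supporting the two batch types is a two-line change here. A hedged sketch of the dispatch pattern (a simplified stand-in, not sglang's actual dispatcher implementation):

```python
from typing import Any, Callable, List, Tuple, Type


class TypeBasedDispatcher:
    """Simplified stand-in: route a message to the handler registered for its type."""

    def __init__(self, mapping: List[Tuple[Type, Callable[[Any], Any]]]):
        self._mapping = mapping

    def __call__(self, obj: Any) -> Any:
        for ty, handler in self._mapping:
            if isinstance(obj, ty):  # first matching entry wins
                return handler(obj)
        raise ValueError(f"Invalid request: {obj!r}")
```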
@@ -1018,14 +1022,26 @@ class Scheduler(
                 req
                 for req in recv_reqs
                 if isinstance(
-                    req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
+                    req,
+                    (
+                        TokenizedGenerateReqInput,
+                        TokenizedEmbeddingReqInput,
+                        BatchTokenizedGenerateReqInput,
+                        BatchTokenizedEmbeddingReqInput,
+                    ),
                 )
             ]
             control_reqs = [
                 req
                 for req in recv_reqs
                 if not isinstance(
-                    req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
+                    req,
+                    (
+                        TokenizedGenerateReqInput,
+                        TokenizedEmbeddingReqInput,
+                        BatchTokenizedGenerateReqInput,
+                        BatchTokenizedEmbeddingReqInput,
+                    ),
                 )
             ]
         else:
@@ -1253,6 +1269,17 @@ class Scheduler(
         else:
             self._add_request_to_queue(req)
 
+    def handle_batch_generate_request(
+        self,
+        recv_req: BatchTokenizedGenerateReqInput,
+    ):
+        """Handle optimized batch generate request."""
+        logger.debug(f"Processing batch generate request with {len(recv_req)} requests")
+
+        # Process each request in the batch
+        for tokenized_req in recv_req:
+            self.handle_generate_request(tokenized_req)
+
     def _add_request_to_queue(self, req: Req):
         req.queue_time_start = time.perf_counter()
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
@@ -1335,6 +1362,19 @@ class Scheduler(
             req.logprob_start_len = len(req.origin_input_ids) - 1
         self._add_request_to_queue(req)
 
+    def handle_batch_embedding_request(
+        self,
+        recv_req: BatchTokenizedEmbeddingReqInput,
+    ):
+        """Handle optimized batch embedding request."""
+        logger.debug(
+            f"Processing batch embedding request with {len(recv_req)} requests"
+        )
+
+        # Process each request in the batch
+        for tokenized_req in recv_req:
+            self.handle_embedding_request(tokenized_req)
+
     def self_check_during_idle(self):
         self.check_memory()
         self.check_tree_cache()
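Both batch handlers are thin adapters: they iterate the wrapper (via its `__iter__`) and delegate to the existing single-request paths, so batching changes only the transport between processes, not the scheduling logic, and arrival order within a batch is preserved.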
@@ -2513,7 +2553,15 @@ def is_health_check_generate_req(recv_req):
 
 
 def is_work_request(recv_req):
-    return isinstance(recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput))
+    return isinstance(
+        recv_req,
+        (
+            TokenizedGenerateReqInput,
+            TokenizedEmbeddingReqInput,
+            BatchTokenizedGenerateReqInput,
+            BatchTokenizedEmbeddingReqInput,
+        ),
+    )
 
 
 def run_scheduler_process(
...
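The same four-type tuple now appears in three places in the scheduler (the work/control split above and `is_work_request`). A hedged refactor sketch, not part of this PR: hoisting the tuple into a module-level constant would keep the call sites in sync (the import path is taken from the diff above):

```python
from sglang.srt.managers.io_struct import (
    BatchTokenizedEmbeddingReqInput,
    BatchTokenizedGenerateReqInput,
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
)

# One source of truth for which message types count as "work" requests.
WORK_REQUEST_TYPES = (
    TokenizedGenerateReqInput,
    TokenizedEmbeddingReqInput,
    BatchTokenizedGenerateReqInput,
    BatchTokenizedEmbeddingReqInput,
)


def is_work_request(recv_req) -> bool:
    return isinstance(recv_req, WORK_REQUEST_TYPES)
```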
python/sglang/srt/managers/tokenizer_manager.py
@@ -71,6 +71,8 @@ from sglang.srt.managers.io_struct import (
     BatchMultimodalOut,
     BatchStrOut,
     BatchTokenIDOut,
+    BatchTokenizedEmbeddingReqInput,
+    BatchTokenizedGenerateReqInput,
     CloseSessionReqInput,
     ConfigureLoggingReq,
     EmbeddingReqInput,
@@ -768,6 +770,30 @@ class TokenizerManager:
         self.rid_to_state[obj.rid] = state
         return state
 
+    def _send_batch_request(
+        self,
+        obj: Union[GenerateReqInput, EmbeddingReqInput],
+        tokenized_objs: List[
+            Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]
+        ],
+        created_time: Optional[float] = None,
+    ):
+        """Send a batch of tokenized requests as a single batched request to the scheduler."""
+        if isinstance(tokenized_objs[0], TokenizedGenerateReqInput):
+            batch_req = BatchTokenizedGenerateReqInput(batch=tokenized_objs)
+        else:
+            batch_req = BatchTokenizedEmbeddingReqInput(batch=tokenized_objs)
+
+        self.send_to_scheduler.send_pyobj(batch_req)
+
+        # Create states for each individual request in the batch
+        for i, tokenized_obj in enumerate(tokenized_objs):
+            tmp_obj = obj[i]
+            state = ReqState(
+                [], False, asyncio.Event(), tmp_obj, created_time=created_time
+            )
+            self.rid_to_state[tmp_obj.rid] = state
+
     async def _wait_one_response(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
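The performance win is in the transport: a batch of N tokenized requests now crosses the tokenizer-manager-to-scheduler ZMQ socket as one pickled object instead of N. A self-contained sketch of that round trip; the inproc endpoint name and the trimmed dataclasses are hypothetical stand-ins for illustration:

```python
from dataclasses import dataclass
from typing import List

import zmq


@dataclass
class TokenizedGenerateReqInput:  # trimmed stand-in
    rid: str


@dataclass
class BatchTokenizedGenerateReqInput:  # trimmed stand-in
    batch: List[TokenizedGenerateReqInput]

    def __len__(self):
        return len(self.batch)

    def __iter__(self):
        return iter(self.batch)


ctx = zmq.Context.instance()
send_to_scheduler = ctx.socket(zmq.PUSH)
send_to_scheduler.bind("inproc://scheduler")  # hypothetical endpoint
recv_from_tokenizer = ctx.socket(zmq.PULL)
recv_from_tokenizer.connect("inproc://scheduler")

# Tokenizer side: one send_pyobj() call carries the whole batch (1 message, not 4).
reqs = [TokenizedGenerateReqInput(rid=f"req-{i}") for i in range(4)]
send_to_scheduler.send_pyobj(BatchTokenizedGenerateReqInput(batch=reqs))

# Scheduler side: unpack the batch and handle each request individually.
batch = recv_from_tokenizer.recv_pyobj()
for req in batch:
    print("handling", req.rid)
```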
@@ -870,10 +896,17 @@ class TokenizerManager:
             tokenized_objs = await self._batch_tokenize_and_process(batch_size, obj)
 
-            for i, tokenized_obj in enumerate(tokenized_objs):
+            # Send as a single batched request
+            self._send_batch_request(obj, tokenized_objs, created_time)
+
+            # Set up generators for each request in the batch
+            for i in range(batch_size):
                 tmp_obj = obj[i]
-                state = self._send_one_request(tmp_obj, tokenized_obj, created_time)
-                generators.append(self._wait_one_response(tmp_obj, state, request))
+                generators.append(
+                    self._wait_one_response(
+                        tmp_obj, self.rid_to_state[tmp_obj.rid], request
+                    )
+                )
                 rids.append(tmp_obj.rid)
         else:
             # Sequential tokenization and processing
...
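Note the ordering this relies on: `_send_batch_request` registers a `ReqState` in `rid_to_state` before the caller builds its response generators, so `_wait_one_response` can look each state up by rid. A hedged sketch of that event-driven wait, with `ReqState` trimmed to the fields used here (the real class carries more) and the event assumed to be set elsewhere when output for that rid arrives:

```python
import asyncio
from dataclasses import dataclass
from typing import Any, List, Optional


@dataclass
class ReqState:
    # Trimmed stand-in: positional fields mirror ReqState([], False, Event(), obj)
    out_list: List[Any]
    finished: bool
    event: asyncio.Event
    obj: Any
    created_time: Optional[float] = None


async def wait_one_response(state: ReqState):
    """Yield outputs for one request as responses arrive from the scheduler."""
    while True:
        await state.event.wait()  # set by the response loop for this rid
        state.event.clear()
        while state.out_list:
            yield state.out_list.pop(0)
        if state.finished:
            return
```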