Unverified commit 0ce84c82 authored by fzyzcjy, committed by GitHub

Support colocating requests (#7973)

parent 59d0bf01
@@ -26,6 +26,7 @@ import zmq

 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.managers.io_struct import (
+    BlockReqInput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
 )
@@ -282,6 +283,9 @@ class DataParallelController:
                 ),
             ):
                 self.dispatching(recv_req)
+            elif isinstance(recv_req, BlockReqInput):
+                for worker in self.workers:
+                    worker.send_pyobj(recv_req)
             else:
                 # Send other control messages to first worker of tp group
                 for worker in self.workers[:: self.control_message_step]:
...
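Note on the dispatch above: a BlockReqInput is fanned out to every worker, while other control messages only go to the first worker of each TP group via the control_message_step stride. Every scheduler rank must observe the block/unblock transition, or the global unblock barrier below could never complete. A minimal sketch of the difference (plain strings stand in for the zmq worker sockets; the names here are hypothetical):

# Sketch only: real entries in self.workers are zmq sockets.
workers = [f"worker-{i}" for i in range(8)]
control_message_step = 4  # stride: first worker of each TP group of size 4

# Ordinary control messages hit one worker per TP group.
control_targets = workers[::control_message_step]
assert control_targets == ["worker-0", "worker-4"]

# BlockReqInput must reach every worker so each scheduler's input blocker
# changes state and later arrives at the shared unblock barrier.
block_targets = workers
assert len(block_targets) == 8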
@@ -1103,3 +1103,13 @@ class LoRAUpdateResult:


 LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
+
+
+class BlockReqType(Enum):
+    BLOCK = 1
+    UNBLOCK = 2
+
+
+@dataclass
+class BlockReqInput:
+    type: BlockReqType
@@ -123,6 +123,7 @@ from sglang.srt.managers.schedule_policy import (
     PrefillAdder,
     SchedulePolicy,
 )
+from sglang.srt.managers.scheduler_input_blocker import SchedulerInputBlocker
 from sglang.srt.managers.scheduler_output_processor_mixin import (
     SchedulerOutputProcessorMixin,
 )
@@ -504,6 +505,12 @@ class Scheduler(
         )
         self.init_profier()

+        self.input_blocker = (
+            SchedulerInputBlocker(noop=self.attn_tp_rank != 0)
+            if get_bool_env_var("SGLANG_ENABLE_COLOCATED_BATCH_GEN")
+            else None
+        )
+
         # Init metrics stats
         self.init_metrics(tp_rank, pp_rank, dp_rank)
         self.init_kv_events(server_args.kv_events_config)
@@ -1035,6 +1042,9 @@ class Scheduler(
         else:
             recv_reqs = None

+        if self.input_blocker is not None:
+            recv_reqs = self.input_blocker.handle(recv_reqs)
+
         if self.server_args.enable_dp_attention:
             if self.attn_tp_rank == 0:
                 work_reqs = [
...
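The handle() call above runs on every scheduler rank, but only attn_tp_rank 0 actually drains the request socket; the other ranks pass recv_reqs=None and construct the blocker with noop=True, which is exactly the invariant asserted at the top of SchedulerInputBlocker.handle() below. A small sketch of that contract (recv_requests here is a hypothetical stand-in for the scheduler's receive step):

from typing import Any, List, Optional

def recv_requests(attn_tp_rank: int) -> Optional[List[Any]]:
    # Hypothetical stand-in: only attention-TP rank 0 reads the zmq socket;
    # the other ranks get the batch later via broadcast and see None here.
    return ["req-a", "req-b"] if attn_tp_rank == 0 else None

for rank in (0, 1):
    recv_reqs = recv_requests(rank)
    noop = rank != 0  # mirrors SchedulerInputBlocker(noop=self.attn_tp_rank != 0)
    assert (recv_reqs is None) == noop  # the assert inside handle()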
new file: sglang/srt/managers/scheduler_input_blocker.py

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import logging
from contextlib import contextmanager
from enum import Enum, auto
from typing import Any, List, Optional

from sglang.srt.managers.io_struct import BlockReqInput, BlockReqType
from sglang.srt.poll_based_barrier import PollBasedBarrier

logger = logging.getLogger(__name__)


class SchedulerInputBlocker:
    def __init__(self, noop: bool):
        self._state = _State.UNBLOCKED
        self._pending_reqs = []
        self._noop = noop
        self._global_unblock_barrier = PollBasedBarrier(noop=noop)

    def handle(self, recv_reqs: Optional[List[Any]]):
        assert (recv_reqs is None) == self._noop

        if not self._noop:
            output_reqs = []
            for recv_req in recv_reqs:
                output_reqs += self._handle_recv_req(recv_req)

        global_arrived_unblock_barrier = (
            self._global_unblock_barrier.poll_global_arrived()
        )
        if (
            self._state == _State.GLOBAL_UNBLOCK_BARRIER
            and global_arrived_unblock_barrier
        ):
            output_reqs += self._handle_arrive_unblock_barrier()

        if not self._noop:
            return output_reqs

    def _handle_recv_req(self, recv_req):
        if isinstance(recv_req, BlockReqInput):
            if recv_req.type == BlockReqType.BLOCK:
                self._execute_block_req()
                return []
            elif recv_req.type == BlockReqType.UNBLOCK:
                self._execute_unblock_req()
                return []
            else:
                raise NotImplementedError(f"{recv_req=}")
        else:
            if self._state == _State.UNBLOCKED:
                return [recv_req]
            else:
                self._pending_reqs.append(recv_req)
                return []

    def _execute_block_req(self):
        logger.info("Handle block req")
        self._change_state(original=_State.UNBLOCKED, target=_State.BLOCKED)

    def _execute_unblock_req(self):
        logger.info("Handle unblock req")
        self._change_state(
            original=_State.BLOCKED, target=_State.GLOBAL_UNBLOCK_BARRIER
        )
        self._global_unblock_barrier.local_arrive()

    def _handle_arrive_unblock_barrier(self):
        logger.info(f"Arrived at unblock barrier ({len(self._pending_reqs)=})")
        self._change_state(
            original=_State.GLOBAL_UNBLOCK_BARRIER, target=_State.UNBLOCKED
        )
        output_reqs = [*self._pending_reqs]
        self._pending_reqs.clear()
        return output_reqs

    def _change_state(self, original: "_State", target: "_State"):
        assert self._state == original, f"{self._state=} {original=} {target=}"
        self._state = target


class _State(Enum):
    UNBLOCKED = auto()
    BLOCKED = auto()
    GLOBAL_UNBLOCK_BARRIER = auto()


@contextmanager
def input_blocker_guard_region(send_to_scheduler):
    send_to_scheduler.send_pyobj(BlockReqInput(BlockReqType.BLOCK))
    try:
        yield
    finally:
        send_to_scheduler.send_pyobj(BlockReqInput(BlockReqType.UNBLOCK))
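To make the state machine concrete, here is a self-contained walk-through of UNBLOCKED -> BLOCKED -> GLOBAL_UNBLOCK_BARRIER -> UNBLOCKED. It swaps in a stub barrier so the example runs without torch.distributed; replacing the private _global_unblock_barrier attribute is not a public API, just an illustration trick, and it assumes this commit's class layout:

from sglang.srt.managers.io_struct import BlockReqInput, BlockReqType
from sglang.srt.managers.scheduler_input_blocker import SchedulerInputBlocker

class _InstantBarrier:
    # Stub for PollBasedBarrier: reports global arrival as soon as the
    # local rank arrives, skipping the all_reduce. Illustration only.
    def __init__(self):
        self._arrived = False

    def local_arrive(self):
        self._arrived = True

    def poll_global_arrived(self):
        arrived, self._arrived = self._arrived, False
        return arrived

blocker = SchedulerInputBlocker(noop=False)
blocker._global_unblock_barrier = _InstantBarrier()  # bypass torch.distributed

assert blocker.handle(["r1"]) == ["r1"]              # UNBLOCKED: passes through
blocker.handle([BlockReqInput(BlockReqType.BLOCK)])
assert blocker.handle(["r2", "r3"]) == []            # BLOCKED: buffered
out = blocker.handle([BlockReqInput(BlockReqType.UNBLOCK)])
assert out == ["r2", "r3"]                           # barrier passes: batch flushed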
@@ -27,6 +27,7 @@ import threading
 import time
 import uuid
 from collections import deque
+from contextlib import nullcontext
 from datetime import datetime
 from http import HTTPStatus
 from typing import (
@@ -69,6 +70,7 @@ from sglang.srt.managers.io_struct import (
     BatchMultimodalOut,
     BatchStrOut,
     BatchTokenIDOut,
+    BlockReqType,
     CloseSessionReqInput,
     ConfigureLoggingReq,
     EmbeddingReqInput,
@@ -114,6 +116,7 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.mm_utils import TensorTransportMode
 from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
+from sglang.srt.managers.scheduler_input_blocker import input_blocker_guard_region
 from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
@@ -819,12 +822,21 @@ class TokenizerManager:
                     rids.append(tmp_obj.rid)
             else:
                 # Sequential tokenization and processing
-                for i in range(batch_size):
-                    tmp_obj = obj[i]
-                    tokenized_obj = await self._tokenize_one_request(tmp_obj)
-                    state = self._send_one_request(tmp_obj, tokenized_obj, created_time)
-                    generators.append(self._wait_one_response(tmp_obj, state, request))
-                    rids.append(tmp_obj.rid)
+                with (
+                    input_blocker_guard_region(send_to_scheduler=self.send_to_scheduler)
+                    if get_bool_env_var("SGLANG_ENABLE_COLOCATED_BATCH_GEN")
+                    else nullcontext()
+                ):
+                    for i in range(batch_size):
+                        tmp_obj = obj[i]
+                        tokenized_obj = await self._tokenize_one_request(tmp_obj)
+                        state = self._send_one_request(
+                            tmp_obj, tokenized_obj, created_time
+                        )
+                        generators.append(
+                            self._wait_one_response(tmp_obj, state, request)
+                        )
+                        rids.append(tmp_obj.rid)
         else:
             # FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
             if batch_size > 128:
...
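Both sides of the feature are gated on the same SGLANG_ENABLE_COLOCATED_BATCH_GEN environment variable: the Scheduler only constructs a SchedulerInputBlocker, and the sequential tokenization path above only enters input_blocker_guard_region, when it is set. A hedged sketch of enabling it (assuming sglang's get_bool_env_var treats "true" as truthy, which matches its usage elsewhere; check your version):

import os

# Must be set in the environment of the server process before the
# TokenizerManager and Scheduler are constructed, e.g. when launching:
#   SGLANG_ENABLE_COLOCATED_BATCH_GEN=true python -m sglang.launch_server ...
os.environ["SGLANG_ENABLE_COLOCATED_BATCH_GEN"] = "true"

# With the flag on, batched requests that take the sequential tokenization
# path are wrapped in input_blocker_guard_region: the scheduler first
# receives BLOCK, buffers arrivals until the whole batch has been sent,
# then receives UNBLOCK and admits the batch as one colocated group.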
new file: sglang/srt/poll_based_barrier.py

import torch

from sglang.srt.distributed import get_world_group


class PollBasedBarrier:
    def __init__(self, noop: bool = False):
        self._noop = noop
        self._local_arrived = False

    def local_arrive(self):
        assert not self._local_arrived
        self._local_arrived = True

    def poll_global_arrived(self) -> bool:
        global_arrived = self._compute_global_arrived()
        output = self._local_arrived and global_arrived
        if output:
            self._local_arrived = False
        return output

    def _compute_global_arrived(self) -> bool:
        local_arrived = self._noop or self._local_arrived
        global_arrived = torch.tensor(local_arrived)
        # Can optimize if bottleneck
        torch.distributed.all_reduce(
            global_arrived,
            torch.distributed.ReduceOp.MIN,
            group=get_world_group().cpu_group,
        )
        return global_arrived.item()
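The MIN all_reduce over 0/1 arrival flags is the whole trick: the reduction yields 1 only once every rank has arrived (noop ranks always contribute 1). A dependency-free model of one collective step, with the all_reduce written as min() over per-rank flags; the class and function names here are hypothetical:

class MiniPollBarrier:
    # Simplified re-implementation of PollBasedBarrier for illustration.
    def __init__(self, noop: bool = False):
        self._noop = noop
        self._local_arrived = False

    def local_arrive(self):
        assert not self._local_arrived
        self._local_arrived = True

def poll_all(ranks):
    # Models one collective all_reduce(MIN): every rank contributes its
    # 0/1 flag (noop ranks always contribute 1) and sees the same result.
    global_arrived = min(r._noop or r._local_arrived for r in ranks)
    outputs = []
    for r in ranks:
        out = r._local_arrived and global_arrived
        if out:
            r._local_arrived = False  # reset for the next block/unblock cycle
        outputs.append(out)
    return outputs

a, b, c = MiniPollBarrier(), MiniPollBarrier(noop=True), MiniPollBarrier()
ranks = [a, b, c]

a.local_arrive()
assert poll_all(ranks) == [False, False, False]  # c has not arrived yet
c.local_arrive()
assert poll_all(ranks) == [True, False, True]    # noop rank b never "fires"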