Unverified commit 0ce84c82, authored by fzyzcjy, committed by GitHub

Support colocating requests (#7973)

parent 59d0bf01
sglang/srt/managers/data_parallel_controller.py
@@ -26,6 +26,7 @@ import zmq
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.managers.io_struct import (
+    BlockReqInput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
 )
@@ -282,6 +283,9 @@ class DataParallelController:
                     ),
                 ):
                     self.dispatching(recv_req)
+                elif isinstance(recv_req, BlockReqInput):
+                    for worker in self.workers:
+                        worker.send_pyobj(recv_req)
                 else:
                     # Send other control messages to first worker of tp group
                     for worker in self.workers[:: self.control_message_step]:
......
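
Note on the new branch above: unblocking ends in a collective barrier across the whole world group (see PollBasedBarrier below), so every scheduler worker must observe both the BLOCK and the UNBLOCK message, while other control messages only need to reach the first worker of each tp group. A toy illustration of the two fan-out patterns (hypothetical 4-socket list with control_message_step=2; not part of the PR):

workers = ["w0", "w1", "w2", "w3"]        # stand-ins for the per-worker zmq sockets
control_message_step = 2                  # one recipient per tp group in this toy setup
print(workers)                            # BlockReqInput fans out to all four
print(workers[::control_message_step])    # other control messages: ['w0', 'w2']
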
sglang/srt/managers/io_struct.py
@@ -1103,3 +1103,13 @@ class LoRAUpdateResult:
 LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
+
+
+class BlockReqType(Enum):
+    BLOCK = 1
+    UNBLOCK = 2
+
+
+@dataclass
+class BlockReqInput:
+    type: BlockReqType
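
The wire format is deliberately minimal: a blocking round-trip is just two of these dataclass instances, which is exactly what input_blocker_guard_region (introduced below) sends. For reference, assuming an sglang checkout with this patch applied:

from sglang.srt.managers.io_struct import BlockReqInput, BlockReqType

block = BlockReqInput(type=BlockReqType.BLOCK)
unblock = BlockReqInput(type=BlockReqType.UNBLOCK)
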
sglang/srt/managers/scheduler.py
@@ -123,6 +123,7 @@ from sglang.srt.managers.schedule_policy import (
     PrefillAdder,
     SchedulePolicy,
 )
+from sglang.srt.managers.scheduler_input_blocker import SchedulerInputBlocker
 from sglang.srt.managers.scheduler_output_processor_mixin import (
     SchedulerOutputProcessorMixin,
 )
@@ -504,6 +505,12 @@ class Scheduler(
         )
         self.init_profier()
 
+        self.input_blocker = (
+            SchedulerInputBlocker(noop=self.attn_tp_rank != 0)
+            if get_bool_env_var("SGLANG_ENABLE_COLOCATED_BATCH_GEN")
+            else None
+        )
+
         # Init metrics stats
         self.init_metrics(tp_rank, pp_rank, dp_rank)
         self.init_kv_events(server_args.kv_events_config)
@@ -1035,6 +1042,9 @@ class Scheduler(
         else:
             recv_reqs = None
 
+        if self.input_blocker is not None:
+            recv_reqs = self.input_blocker.handle(recv_reqs)
+
         if self.server_args.enable_dp_attention:
             if self.attn_tp_rank == 0:
                 work_reqs = [
......
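
Both scheduler changes are gated on the SGLANG_ENABLE_COLOCATED_BATCH_GEN environment variable, and note the noop flag: only attn_tp_rank == 0 receives requests, but every rank keeps calling handle() so the collective barrier poll can complete. A sketch of the call shape on a non-zero rank, with the collective poll stubbed out since this snippet has no process group (the stubbing is ours, purely for illustration):

from sglang.srt.managers.scheduler_input_blocker import SchedulerInputBlocker

blocker = SchedulerInputBlocker(noop=True)
# Test-only stub: the real poll runs an all_reduce over the world group.
blocker._global_unblock_barrier.poll_global_arrived = lambda: False
assert blocker.handle(None) is None  # noop ranks pass None and get None back
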
sglang/srt/managers/scheduler_input_blocker.py (new file)

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import logging
from contextlib import contextmanager
from enum import Enum, auto
from typing import Any, List, Optional

from sglang.srt.managers.io_struct import BlockReqInput, BlockReqType
from sglang.srt.poll_based_barrier import PollBasedBarrier

logger = logging.getLogger(__name__)


class SchedulerInputBlocker:
    def __init__(self, noop: bool):
        self._state = _State.UNBLOCKED
        self._pending_reqs = []
        self._noop = noop
        self._global_unblock_barrier = PollBasedBarrier(noop=noop)

    def handle(self, recv_reqs: Optional[List[Any]]):
        # Only attn_tp_rank == 0 actually receives requests; noop ranks pass
        # None and merely participate in the barrier poll below.
        assert (recv_reqs is None) == self._noop

        if not self._noop:
            output_reqs = []
            for recv_req in recv_reqs:
                output_reqs += self._handle_recv_req(recv_req)

        global_arrived_unblock_barrier = (
            self._global_unblock_barrier.poll_global_arrived()
        )
        if (
            self._state == _State.GLOBAL_UNBLOCK_BARRIER
            and global_arrived_unblock_barrier
        ):
            # Noop ranks never see BlockReqInput, so their state never leaves
            # UNBLOCKED and this branch (which touches output_reqs) cannot fire.
            output_reqs += self._handle_arrive_unblock_barrier()

        if not self._noop:
            return output_reqs

    def _handle_recv_req(self, recv_req):
        if isinstance(recv_req, BlockReqInput):
            if recv_req.type == BlockReqType.BLOCK:
                self._execute_block_req()
                return []
            elif recv_req.type == BlockReqType.UNBLOCK:
                self._execute_unblock_req()
                return []
            else:
                raise NotImplementedError(f"{recv_req=}")
        else:
            if self._state == _State.UNBLOCKED:
                return [recv_req]
            else:
                self._pending_reqs.append(recv_req)
                return []

    def _execute_block_req(self):
        logger.info("Handle block req")
        self._change_state(original=_State.UNBLOCKED, target=_State.BLOCKED)

    def _execute_unblock_req(self):
        logger.info("Handle unblock req")
        self._change_state(
            original=_State.BLOCKED, target=_State.GLOBAL_UNBLOCK_BARRIER
        )
        self._global_unblock_barrier.local_arrive()

    def _handle_arrive_unblock_barrier(self):
        logger.info(f"Arrived at unblock barrier ({len(self._pending_reqs)=})")
        self._change_state(
            original=_State.GLOBAL_UNBLOCK_BARRIER, target=_State.UNBLOCKED
        )
        output_reqs = [*self._pending_reqs]
        self._pending_reqs.clear()
        return output_reqs

    def _change_state(self, original: "_State", target: "_State"):
        assert self._state == original, f"{self._state=} {original=} {target=}"
        self._state = target


class _State(Enum):
    UNBLOCKED = auto()
    BLOCKED = auto()
    GLOBAL_UNBLOCK_BARRIER = auto()


@contextmanager
def input_blocker_guard_region(send_to_scheduler):
    send_to_scheduler.send_pyobj(BlockReqInput(BlockReqType.BLOCK))
    try:
        yield
    finally:
        send_to_scheduler.send_pyobj(BlockReqInput(BlockReqType.UNBLOCK))
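
The class above is a three-state machine (UNBLOCKED -> BLOCKED -> GLOBAL_UNBLOCK_BARRIER -> UNBLOCKED). A single-process walkthrough, with the distributed barrier swapped for a local stub (_StubBarrier is ours, not part of the PR; in production the barrier is the all_reduce in poll_based_barrier.py below):

from sglang.srt.managers.io_struct import BlockReqInput, BlockReqType
from sglang.srt.managers.scheduler_input_blocker import SchedulerInputBlocker

class _StubBarrier:
    # Single-process stand-in for PollBasedBarrier: "global" arrival is just
    # local arrival.
    def __init__(self):
        self._arrived = False

    def local_arrive(self):
        self._arrived = True

    def poll_global_arrived(self):
        arrived, self._arrived = self._arrived, False
        return arrived

blocker = SchedulerInputBlocker(noop=False)
blocker._global_unblock_barrier = _StubBarrier()  # test-only swap

# BLOCK flips the state machine; requests that follow are buffered, not returned.
assert blocker.handle([BlockReqInput(BlockReqType.BLOCK), "req-1", "req-2"]) == []
# UNBLOCK arms the barrier; once it passes, buffered requests drain in order.
assert blocker.handle([BlockReqInput(BlockReqType.UNBLOCK)]) == ["req-1", "req-2"]
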
sglang/srt/managers/tokenizer_manager.py
@@ -27,6 +27,7 @@ import threading
 import time
 import uuid
 from collections import deque
+from contextlib import nullcontext
 from datetime import datetime
 from http import HTTPStatus
 from typing import (
@@ -69,6 +70,7 @@ from sglang.srt.managers.io_struct import (
     BatchMultimodalOut,
     BatchStrOut,
     BatchTokenIDOut,
+    BlockReqType,
     CloseSessionReqInput,
     ConfigureLoggingReq,
     EmbeddingReqInput,
@@ -114,6 +116,7 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.mm_utils import TensorTransportMode
 from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
+from sglang.srt.managers.scheduler_input_blocker import input_blocker_guard_region
 from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
@@ -819,12 +822,21 @@ class TokenizerManager:
                 rids.append(tmp_obj.rid)
             else:
                 # Sequential tokenization and processing
-                for i in range(batch_size):
-                    tmp_obj = obj[i]
-                    tokenized_obj = await self._tokenize_one_request(tmp_obj)
-                    state = self._send_one_request(tmp_obj, tokenized_obj, created_time)
-                    generators.append(self._wait_one_response(tmp_obj, state, request))
-                    rids.append(tmp_obj.rid)
+                with (
+                    input_blocker_guard_region(send_to_scheduler=self.send_to_scheduler)
+                    if get_bool_env_var("SGLANG_ENABLE_COLOCATED_BATCH_GEN")
+                    else nullcontext()
+                ):
+                    for i in range(batch_size):
+                        tmp_obj = obj[i]
+                        tokenized_obj = await self._tokenize_one_request(tmp_obj)
+                        state = self._send_one_request(
+                            tmp_obj, tokenized_obj, created_time
+                        )
+                        generators.append(
+                            self._wait_one_response(tmp_obj, state, request)
+                        )
+                        rids.append(tmp_obj.rid)
         else:
             # FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
             if batch_size > 128:
......
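
With SGLANG_ENABLE_COLOCATED_BATCH_GEN set, the whole sequential send loop is bracketed by the guard region, so the scheduler admits the batch atomically; otherwise nullcontext() makes the wrapper a no-op. The ordering guarantee can be seen in isolation with a stand-in for the zmq socket (FakeSender is ours, purely illustrative):

from sglang.srt.managers.io_struct import BlockReqInput, BlockReqType
from sglang.srt.managers.scheduler_input_blocker import input_blocker_guard_region

class FakeSender:
    # Records what would be sent over the send_to_scheduler zmq socket.
    def __init__(self):
        self.sent = []

    def send_pyobj(self, obj):
        self.sent.append(obj)

sender = FakeSender()
with input_blocker_guard_region(send_to_scheduler=sender):
    sender.send_pyobj("tokenized-req-1")
    sender.send_pyobj("tokenized-req-2")

# The scheduler sees: BLOCK, req-1, req-2, UNBLOCK, so both requests are
# released from the blocker in the same scheduling step on every rank.
assert sender.sent[0].type == BlockReqType.BLOCK
assert sender.sent[-1].type == BlockReqType.UNBLOCK
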
sglang/srt/poll_based_barrier.py (new file)

import torch

from sglang.srt.distributed import get_world_group


class PollBasedBarrier:
    def __init__(self, noop: bool = False):
        self._noop = noop
        self._local_arrived = False

    def local_arrive(self):
        assert not self._local_arrived
        self._local_arrived = True

    def poll_global_arrived(self) -> bool:
        # Non-blocking: each call runs one collective; returns True (and
        # resets for reuse) only once this rank has locally arrived and every
        # other rank's vote is in.
        global_arrived = self._compute_global_arrived()
        output = self._local_arrived and global_arrived
        if output:
            self._local_arrived = False
        return output

    def _compute_global_arrived(self) -> bool:
        # Noop ranks always vote "arrived", so they never hold the barrier up.
        local_arrived = self._noop or self._local_arrived
        global_arrived = torch.tensor(local_arrived)
        # Can optimize if bottleneck
        torch.distributed.all_reduce(
            global_arrived,
            torch.distributed.ReduceOp.MIN,
            group=get_world_group().cpu_group,
        )
        return global_arrived.item()
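
Why ReduceOp.MIN: each rank contributes 1 once it has locally arrived (noop ranks always contribute 1), so the minimum across ranks flips to 1 exactly when the last real participant arrives. In plain Python terms (illustrative only; the real reduction runs over the world group's CPU process group):

votes = [1, 1, 0]        # rank 2 has not arrived yet
assert min(votes) == 0   # all_reduce(MIN) -> barrier not passed
votes = [1, 1, 1]
assert min(votes) == 1   # every rank arrived -> poll_global_arrived may return True
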