Unverified commit 501dfa6b authored by Liangsheng Yin, committed by GitHub

Remove sampling info events and overlap thread file (#11300)

parent 79d34951
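
For context on the mechanism this commit removes: previously the overlap scheduler and the TP worker coordinated grammar-mask preparation through a sampling_info_done threading.Event (see the set_next_batch_sampling_info_done and _preprocess_logits hunks below); the event is gone and update_regex_vocab_mask() now runs directly in _preprocess_logits. The snippet below is a minimal, self-contained sketch of that kind of event handshake, not SGLang code; prepare_mask and forward_worker are invented names.

    import threading

    mask_ready = threading.Event()   # plays the role of the removed sampling_info_done
    shared = {"vocab_mask": None}

    def prepare_mask():
        # Scheduler side: build the (potentially expensive) grammar vocab mask, then signal.
        shared["vocab_mask"] = [False, True, True, False]  # stand-in for update_regex_vocab_mask()
        mask_ready.set()

    def forward_worker():
        # Worker side: block until the mask for this batch is ready before sampling.
        mask_ready.wait()
        print("vocab mask available:", shared["vocab_mask"])

    t = threading.Thread(target=forward_worker)
    t.start()
    prepare_mask()
    t.join()
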
@@ -783,16 +783,6 @@ class SchedulerDisaggregationDecodeMixin:
self.prepare_mlp_sync_batch(batch)
result = self.run_batch(batch)
self.result_queue.append((batch.copy(), result))
if (self.last_batch is None) or (not self.last_batch_in_queue):
# Create a dummy first batch to start the pipeline for overlap schedule.
# It is now used for triggering the sampling_info_done event.
tmp_batch = ScheduleBatch(
reqs=None,
forward_mode=ForwardMode.DUMMY_FIRST,
next_batch_sampling_info=self.tp_worker.cur_sampling_info,
)
self.set_next_batch_sampling_info_done(tmp_batch)
last_batch_in_queue = True
elif prepare_mlp_sync_flag:
@@ -806,9 +796,6 @@ class SchedulerDisaggregationDecodeMixin:
# Process the results of the previous batch but skip if the last batch is extend
if self.last_batch and self.last_batch_in_queue:
tmp_batch, tmp_result = self.result_queue.popleft()
tmp_batch.next_batch_sampling_info = (
self.tp_worker.cur_sampling_info if batch else None
)
self.process_batch_result(tmp_batch, tmp_result)
queue_size = (
......
@@ -338,21 +338,8 @@ class SchedulerDisaggregationPrefillMixin:
result = self.run_batch(batch)
self.result_queue.append((batch.copy(), result))
if self.last_batch is None:
# Create a dummy first batch to start the pipeline for overlap schedule.
# It is now used for triggering the sampling_info_done event.
tmp_batch = ScheduleBatch(
reqs=None,
forward_mode=ForwardMode.DUMMY_FIRST,
next_batch_sampling_info=self.tp_worker.cur_sampling_info,
)
self.set_next_batch_sampling_info_done(tmp_batch)
if self.last_batch:
tmp_batch, tmp_result = self.result_queue.popleft()
tmp_batch.next_batch_sampling_info = (
self.tp_worker.cur_sampling_info if batch else None
)
self.process_batch_result_disagg_prefill(tmp_batch, tmp_result)
if len(self.disagg_prefill_inflight_queue) > 0:
@@ -491,8 +478,6 @@ class SchedulerDisaggregationPrefillMixin:
if self.enable_overlap:
self.send_kv_chunk(req, last_chunk=False, end_idx=req.tmp_end_idx)
# We need to remove the sync in the following function for overlap schedule.
self.set_next_batch_sampling_info_done(batch)
self.maybe_send_health_check_signal()
def process_disagg_prefill_inflight_queue(
......
@@ -891,7 +891,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
# Sampling info
sampling_info: SamplingBatchInfo = None
next_batch_sampling_info: SamplingBatchInfo = None
# Batched arguments to model runner
input_ids: torch.Tensor = None # shape: [b], int64
......
@@ -1012,22 +1012,9 @@ class Scheduler(
result = self.run_batch(batch)
self.result_queue.append((batch.copy(), result))
if self.last_batch is None:
# Create a dummy first batch to start the pipeline for overlap schedule.
# It is now used for triggering the sampling_info_done event.
tmp_batch = ScheduleBatch(
reqs=None,
forward_mode=ForwardMode.DUMMY_FIRST,
next_batch_sampling_info=self.tp_worker.cur_sampling_info,
)
self.process_batch_result(tmp_batch, None)
if self.last_batch:
# Process the results of the last batch
tmp_batch, tmp_result = self.result_queue.popleft()
tmp_batch.next_batch_sampling_info = (
self.tp_worker.cur_sampling_info if batch else None
)
self.process_batch_result(tmp_batch, tmp_result)
elif batch is None:
# When the server is idle, do self-check and re-init some states
@@ -2100,7 +2087,7 @@ class Scheduler(
self.record_batch_in_overlap(model_worker_batch)
# Sampling info will be modified during forward
model_worker_batch.sampling_info = self.tp_worker.cur_sampling_info = (
model_worker_batch.sampling_info = (
model_worker_batch.sampling_info.copy_for_forward()
)
@@ -2219,9 +2206,6 @@ class Scheduler(
if self.enable_overlap:
if result.copy_done is not None:
result.copy_done.synchronize()
self.set_next_batch_sampling_info_done(batch)
elif batch.forward_mode.is_dummy_first():
self.set_next_batch_sampling_info_done(batch)
self.maybe_send_health_check_signal()
@@ -2431,13 +2415,6 @@ class Scheduler(
self._add_request_to_queue(req)
self.grammar_queue = self.grammar_queue[num_ready_reqs:]
def set_next_batch_sampling_info_done(self, batch: ScheduleBatch):
if batch.next_batch_sampling_info:
if batch.next_batch_sampling_info.grammars is not None:
batch.next_batch_sampling_info.update_regex_vocab_mask()
self.default_stream.synchronize()
batch.next_batch_sampling_info.sampling_info_done.set()
def watchdog_thread(self):
"""A watch dog thread that will try to kill the server itself if one forward batch takes too long."""
self.watchdog_last_forward_ct = 0
......
@@ -173,8 +173,6 @@ class SchedulerOutputProcessorMixin:
)
logprob_pt += num_input_logprobs
self.set_next_batch_sampling_info_done(batch)
else: # embedding or reward model
embeddings = result.embeddings.tolist()
@@ -295,7 +293,6 @@ class SchedulerOutputProcessorMixin:
self.abort_request(AbortReq(rid=req.rid))
req.grammar.finished = req.finished()
self.set_next_batch_sampling_info_done(batch)
self.stream_output(batch.reqs, batch.return_logprob)
self.token_to_kv_pool_allocator.free_group_end()
......
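
All of the scheduler hunks above share one overlap pattern: launch the current batch, push (batch.copy(), result) onto result_queue, and only then pop and process the previous batch's results, so CPU-side post-processing overlaps with the GPU forward pass of the next batch. A stripped-down sketch of that one-step-delayed loop (run_batch and process_batch_result are placeholders, not the real methods):

    from collections import deque

    def run_batch(batch):
        return f"result-of-{batch}"        # placeholder for launching the GPU forward

    def process_batch_result(batch, result):
        print("processed", batch, result)  # placeholder for CPU post-processing

    result_queue = deque()
    last_batch = None
    for batch in ["b0", "b1", "b2"]:
        result = run_batch(batch)          # launch the current batch first ...
        result_queue.append((batch, result))
        if last_batch is not None:         # ... then finish the previous one
            process_batch_result(*result_queue.popleft())
        last_batch = batch

    while result_queue:                    # drain the final pending batch
        process_batch_result(*result_queue.popleft())
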
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A tensor parallel worker."""

from __future__ import annotations

import dataclasses
import logging
import signal
import threading
from queue import Queue
from typing import TYPE_CHECKING, List, Optional, Tuple

import psutil
import torch

from sglang.srt.managers.io_struct import (
    DestroyWeightsUpdateGroupReqInput,
    GetWeightsByNameReqInput,
    InitWeightsSendGroupForRemoteInstanceReqInput,
    InitWeightsUpdateGroupReqInput,
    LoadLoRAAdapterReqInput,
    SendWeightsToRemoteInstanceReqInput,
    UnloadLoRAAdapterReqInput,
    UpdateWeightFromDiskReqInput,
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromTensorReqInput,
)
from sglang.srt.managers.overlap_utils import FutureMap
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.model_executor.forward_batch_info import ForwardBatchOutput
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import DynamicGradMode
from sglang.utils import get_exception_traceback

if TYPE_CHECKING:
    from sglang.srt.managers.cache_controller import LayerDoneCounter

logger = logging.getLogger(__name__)


class TpModelWorkerClient:
    """A tensor parallel model worker."""

    def __init__(
        self,
        server_args: ServerArgs,
        gpu_id: int,
        tp_rank: int,
        moe_ep_rank: int,
        pp_rank: int,
        dp_rank: Optional[int],
        nccl_port: int,
    ):
        # Load the model
        self.worker = TpModelWorker(
            server_args, gpu_id, tp_rank, moe_ep_rank, pp_rank, dp_rank, nccl_port
        )
        self.max_running_requests = self.worker.max_running_requests
        self.device = self.worker.device
        self.gpu_id = gpu_id

        # Init future mappings
        self.future_map = FutureMap(self.max_running_requests, self.device)

        # Launch threads
        self.input_queue = Queue[Tuple[ModelWorkerBatch, int, torch.Event]]()
        self.output_queue = Queue()
        self.forward_stream = torch.get_device_module(self.device).Stream()
        self.forward_thread = threading.Thread(
            target=self.forward_thread_func,
        )
        self.forward_thread.start()
        self.parent_process = psutil.Process().parent()
        self.scheduler_stream = torch.get_device_module(self.device).current_stream()
        if self.device == "cpu":
            self.scheduler_stream.synchronize = lambda: None  # No-op for CPU

        self.hicache_layer_transfer_counter = None

    def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter):
        self.hicache_layer_transfer_counter = counter

    def get_worker_info(self):
        return self.worker.get_worker_info()

    def get_tokens_per_layer_info(self):
        return self.worker.get_tokens_per_layer_info()

    @property
    def sliding_window_size(self) -> Optional[int]:
        return self.worker.sliding_window_size

    @property
    def is_hybrid(self) -> bool:
        return self.worker.is_hybrid

    def get_pad_input_ids_func(self):
        return self.worker.get_pad_input_ids_func()

    def get_tp_group(self):
        return self.worker.get_tp_group()

    def get_attention_tp_group(self):
        return self.worker.get_attention_tp_group()

    def get_attention_tp_cpu_group(self):
        return self.worker.get_attention_tp_cpu_group()

    def get_memory_pool(self):
        return (
            self.worker.model_runner.req_to_token_pool,
            self.worker.model_runner.token_to_kv_pool_allocator,
        )

    def get_kv_cache(self):
        return self.worker.model_runner.token_to_kv_pool

    def forward_thread_func(self):
        try:
            with torch.get_device_module(self.device).stream(self.forward_stream):
                self.forward_thread_func_()
        except Exception:
            traceback = get_exception_traceback()
            logger.error(f"TpModelWorkerClient hit an exception: {traceback}")
            self.parent_process.send_signal(signal.SIGQUIT)

    @DynamicGradMode()
    def forward_thread_func_(self):
        batch_pt = 0
        batch_lists: List = [None] * 2

        while True:
            model_worker_batch, future_map_ct, sync_event = self.input_queue.get()
            if not model_worker_batch:
                break

            sync_event.wait()

            # Keep a reference of model_worker_batch by storing it into a list.
            # Otherwise, the tensor members of model_worker_batch will be released
            # by pytorch and cause CUDA illegal memory access errors.
            batch_lists[batch_pt % 2] = model_worker_batch
            batch_pt += 1

            # Create event
            copy_done = torch.get_device_module(self.device).Event()

            # Resolve future tokens in the input
            self.future_map.resolve_future(model_worker_batch)

            # Run forward
            forward_batch_output = self.worker.forward_batch_generation(
                model_worker_batch,
                model_worker_batch.launch_done,
            )
            logits_output, next_token_ids, can_run_cuda_graph = (
                forward_batch_output.logits_output,
                forward_batch_output.next_token_ids,
                forward_batch_output.can_run_cuda_graph,
            )

            # Update the future token ids map
            bs = len(model_worker_batch.seq_lens)
            if model_worker_batch.is_prefill_only:
                # For prefill-only requests, create dummy token IDs on CPU
                next_token_ids = torch.zeros(bs, dtype=torch.long)

            # store the future indices into future map
            self.future_map.store_to_map(future_map_ct, bs, next_token_ids)

            # Copy results to the CPU
            if model_worker_batch.return_logprob:
                if logits_output.next_token_logprobs is not None:
                    logits_output.next_token_logprobs = (
                        logits_output.next_token_logprobs.to("cpu", non_blocking=True)
                    )
                if logits_output.input_token_logprobs is not None:
                    logits_output.input_token_logprobs = (
                        logits_output.input_token_logprobs.to("cpu", non_blocking=True)
                    )
            if logits_output.hidden_states is not None:
                logits_output.hidden_states = logits_output.hidden_states.to(
                    "cpu", non_blocking=True
                )
            # Only copy to CPU if not already on CPU
            if next_token_ids.device.type != "cpu":
                next_token_ids = next_token_ids.to("cpu", non_blocking=True)
            copy_done.record()

            self.output_queue.put(
                (copy_done, logits_output, next_token_ids, can_run_cuda_graph)
            )

    def resolve_last_batch_result(self, launch_done: Optional[threading.Event] = None):
        """
        This function is called to resolve the last batch result and
        wait for the current batch to be launched. Used in overlap mode.
        """
        copy_done, logits_output, next_token_ids, can_run_cuda_graph = (
            self.output_queue.get()
        )
        if launch_done is not None:
            launch_done.wait()
        copy_done.synchronize()

        if logits_output.next_token_logprobs is not None:
            logits_output.next_token_logprobs = (
                logits_output.next_token_logprobs.tolist()
            )
        if logits_output.input_token_logprobs is not None:
            logits_output.input_token_logprobs = tuple(
                logits_output.input_token_logprobs.tolist()
            )
        next_token_ids = next_token_ids.tolist()
        return logits_output, next_token_ids, can_run_cuda_graph

    def forward_batch_generation(
        self, model_worker_batch: ModelWorkerBatch
    ) -> ForwardBatchOutput:
        # Create a new copy of sampling_info because it will be updated in-place by the scheduler for the next batch.
        model_worker_batch.sampling_info = self.cur_sampling_info = (
            model_worker_batch.sampling_info.copy_for_forward()
        )

        # A cuda stream sync here to avoid the cuda illegal memory access error.
        sync_event = torch.get_device_module(self.device).Event()
        sync_event.record(self.scheduler_stream)

        # Push a new batch to the queue
        bs = len(model_worker_batch.seq_lens)
        cur_future_map_ct = self.future_map.update_ct(bs)
        self.input_queue.put((model_worker_batch, cur_future_map_ct, sync_event))

        # get this forward batch's future token ids
        future_next_token_ids = self.future_map.update_next_future(
            cur_future_map_ct, bs
        )
        return ForwardBatchOutput(
            next_token_ids=future_next_token_ids,
            can_run_cuda_graph=False,
        )

    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
        success, message = self.worker.update_weights_from_disk(recv_req)
        return success, message

    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
        success, message = self.worker.init_weights_update_group(recv_req)
        return success, message

    def destroy_weights_update_group(self, recv_req: DestroyWeightsUpdateGroupReqInput):
        success, message = self.worker.destroy_weights_update_group(recv_req)
        return success, message

    def init_weights_send_group_for_remote_instance(
        self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
    ):
        success, message = self.worker.init_weights_send_group_for_remote_instance(
            recv_req
        )
        return success, message

    def send_weights_to_remote_instance(
        self, recv_req: SendWeightsToRemoteInstanceReqInput
    ):
        success, message = self.worker.send_weights_to_remote_instance(recv_req)
        return success, message

    def update_weights_from_distributed(
        self, recv_req: UpdateWeightsFromDistributedReqInput
    ):
        success, message = self.worker.update_weights_from_distributed(recv_req)
        return success, message

    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
        success, message = self.worker.update_weights_from_tensor(recv_req)
        return success, message

    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
        return self.worker.get_weights_by_name(recv_req)

    def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
        return self.worker.load_lora_adapter(recv_req)

    def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
        return self.worker.unload_lora_adapter(recv_req)

    def can_run_lora_batch(self, lora_ids: list[str]) -> bool:
        return self.worker.can_run_lora_batch(lora_ids)

    def __delete__(self):
        self.input_queue.put((None, None))
        self.copy_queue.put((None, None, None))

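The class above overlaps scheduling and execution with two queues, a stream-sync event, and a FutureMap that hands the scheduler placeholder token ids which the forward thread fills in later. The toy below reproduces only that control flow; ToyFutureMap and its methods are invented for illustration and do not match the real FutureMap API.

    import threading
    from queue import Queue

    class ToyFutureMap:
        """Invented stand-in: maps a counter to token ids produced later."""
        def __init__(self):
            self.slots = {}
        def store(self, ct, tokens):
            self.slots[ct] = tokens
        def resolve(self, ct):
            return self.slots[ct]

    input_queue, output_queue = Queue(), Queue()
    futures = ToyFutureMap()

    def forward_thread():
        while True:
            item = input_queue.get()
            if item is None:               # shutdown sentinel
                break
            ct, batch = item
            tokens = [len(batch)] * 2      # stand-in for the model forward
            futures.store(ct, tokens)      # fulfil the future for the scheduler
            output_queue.put((ct, tokens))

    t = threading.Thread(target=forward_thread)
    t.start()

    input_queue.put((0, "batch-0"))        # scheduler: enqueue and keep going
    ct, tokens = output_queue.get()        # later: collect the finished result
    print(ct, tokens, futures.resolve(ct))
    input_queue.put(None)
    t.join()
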
@@ -75,10 +75,6 @@ class ForwardMode(IntEnum):
# Used in speculative decoding: extend a batch in the draft model.
DRAFT_EXTEND = auto()
# A dummy first batch to start the pipeline for overlap scheduler.
# It is now used for triggering the sampling_info_done event for the first prefill batch.
DUMMY_FIRST = auto()
# Split Prefill for PD multiplexing
SPLIT_PREFILL = auto()
@@ -128,9 +124,6 @@ class ForwardMode(IntEnum):
def is_cpu_graph(self):
return self == ForwardMode.DECODE
def is_dummy_first(self):
return self == ForwardMode.DUMMY_FIRST
def is_split_prefill(self):
return self == ForwardMode.SPLIT_PREFILL
......
@@ -2057,15 +2057,11 @@ class ModelRunner:
def _preprocess_logits(
self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
):
# Apply logit bias
if sampling_info.sampling_info_done:
# Overlap mode: the function update_regex_vocab_mask was executed
# in process_batch_result of the last batch.
if sampling_info.grammars:
sampling_info.sampling_info_done.wait()
else:
# Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
sampling_info.update_regex_vocab_mask()
# NOTE: In overlap mode, the function update_regex_vocab_mask (in sample)
# was executed after we processed last batch's results.
# Calculate logits bias and apply it to next_token_logits.
sampling_info.update_regex_vocab_mask()
sampling_info.apply_logits_bias(logits_output.next_token_logits)
def sample(
......
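
With the event removed, _preprocess_logits above unconditionally calls update_regex_vocab_mask() and then apply_logits_bias() on the next-token logits. Roughly, those two steps amount to adding accumulated penalties and masking grammar-disallowed tokens; the torch snippet below is an illustrative approximation with made-up shapes and values, not the SamplingBatchInfo implementation:

    import torch

    next_token_logits = torch.randn(2, 8)      # [batch, vocab]

    # Additive penalties accumulated per request (cf. acc_linear_penalties).
    penalties = torch.zeros(2, 8)
    penalties[:, 3] = -1.0                     # e.g. penalize token 3 for both requests

    # Grammar vocab mask: True marks tokens the grammar currently forbids.
    vocab_mask = torch.zeros(2, 8, dtype=torch.bool)
    vocab_mask[0, 5:] = True                   # request 0 may only emit tokens 0-4

    next_token_logits.add_(penalties)
    next_token_logits.masked_fill_(vocab_mask, float("-inf"))
    print(next_token_logits.argmax(dim=-1))    # greedy pick respects mask and penalties
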
@@ -44,12 +44,9 @@ class SamplingBatchInfo:
vocab_mask: Optional[torch.Tensor] = None
apply_mask_func: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None
# An event used for overlap schedule
sampling_info_done: Optional[threading.Event] = None
# Penalizer
penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None
linear_penalty: torch.Tensor = None
acc_linear_penalties: torch.Tensor = None # Used in the overlap mode
# Whether any request has custom logit processor
has_custom_logit_processor: bool = False
@@ -217,19 +214,19 @@ class SamplingBatchInfo:
def update_penalties(self):
if self.penalizer_orchestrator.is_required:
self.linear_penalty = torch.zeros(
self.acc_linear_penalties = torch.zeros(
(len(self.temperatures), self.vocab_size),
dtype=torch.float32,
device=self.temperatures.device,
)
self.penalizer_orchestrator.apply(self.linear_penalty)
self.penalizer_orchestrator.apply(self.acc_linear_penalties)
else:
self.linear_penalty = None
self.acc_linear_penalties = None
def apply_logits_bias(self, logits: torch.Tensor):
if self.linear_penalty is not None:
if self.acc_linear_penalties is not None:
# Used in the overlap mode
logits.add_(self.linear_penalty)
logits.add_(self.acc_linear_penalties)
if self.penalizer_orchestrator and self.penalizer_orchestrator.is_required:
# Used in the non-overlap mode
@@ -373,11 +370,7 @@ class SamplingBatchInfo:
def copy_for_forward(self):
# Accumulate the penalty into a pre-allocated buffer to get rid of the dependency of `penalizer_orchestrator` later
self.update_penalties()
return dataclasses.replace(
self,
sampling_info_done=threading.Event(),
penalizer_orchestrator=None,
)
return dataclasses.replace(self, penalizer_orchestrator=None)
def merge_bias_tensor(
......
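
The copy_for_forward hunk above now folds pending penalties into acc_linear_penalties via update_penalties() and returns a copy with penalizer_orchestrator stripped, so the forward thread no longer depends on scheduler-owned state (and no longer needs a fresh threading.Event). A hedged sketch of that snapshot idea with a toy dataclass; FakeSamplingInfo and its fields are invented:

    import dataclasses
    from typing import Optional

    import torch

    @dataclasses.dataclass
    class FakeSamplingInfo:
        temperatures: torch.Tensor
        acc_linear_penalties: Optional[torch.Tensor] = None
        penalizer_orchestrator: Optional[object] = None  # scheduler-owned, mutated between batches

        def update_penalties(self):
            # Materialize everything the forward thread will need up front.
            self.acc_linear_penalties = torch.zeros_like(self.temperatures)

        def copy_for_forward(self):
            self.update_penalties()
            # Hand the forward thread a copy without scheduler-side references.
            return dataclasses.replace(self, penalizer_orchestrator=None)

    info = FakeSamplingInfo(temperatures=torch.ones(4), penalizer_orchestrator=object())
    snapshot = info.copy_for_forward()
    assert snapshot.penalizer_orchestrator is None
    print(snapshot.acc_linear_penalties)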