"docs/source/api/vscode:/vscode.git/clone" did not exist on "bcff14ccde7956da18a675b5cb30c4774611f866"
Unverified commit 501dfa6b authored by Liangsheng Yin, committed by GitHub

Remove sampling info events and overlap thread file (#11300)

parent 79d34951
@@ -783,16 +783,6 @@ class SchedulerDisaggregationDecodeMixin:
                        self.prepare_mlp_sync_batch(batch)
                    result = self.run_batch(batch)
                    self.result_queue.append((batch.copy(), result))
-                    if (self.last_batch is None) or (not self.last_batch_in_queue):
-                        # Create a dummy first batch to start the pipeline for overlap schedule.
-                        # It is now used for triggering the sampling_info_done event.
-                        tmp_batch = ScheduleBatch(
-                            reqs=None,
-                            forward_mode=ForwardMode.DUMMY_FIRST,
-                            next_batch_sampling_info=self.tp_worker.cur_sampling_info,
-                        )
-                        self.set_next_batch_sampling_info_done(tmp_batch)
                    last_batch_in_queue = True
            elif prepare_mlp_sync_flag:
@@ -806,9 +796,6 @@ class SchedulerDisaggregationDecodeMixin:
            # Process the results of the previous batch but skip if the last batch is extend
            if self.last_batch and self.last_batch_in_queue:
                tmp_batch, tmp_result = self.result_queue.popleft()
-                tmp_batch.next_batch_sampling_info = (
-                    self.tp_worker.cur_sampling_info if batch else None
-                )
                self.process_batch_result(tmp_batch, tmp_result)
            queue_size = (
...
@@ -338,21 +338,8 @@ class SchedulerDisaggregationPrefillMixin:
                result = self.run_batch(batch)
                self.result_queue.append((batch.copy(), result))
-                if self.last_batch is None:
-                    # Create a dummy first batch to start the pipeline for overlap schedule.
-                    # It is now used for triggering the sampling_info_done event.
-                    tmp_batch = ScheduleBatch(
-                        reqs=None,
-                        forward_mode=ForwardMode.DUMMY_FIRST,
-                        next_batch_sampling_info=self.tp_worker.cur_sampling_info,
-                    )
-                    self.set_next_batch_sampling_info_done(tmp_batch)
            if self.last_batch:
                tmp_batch, tmp_result = self.result_queue.popleft()
-                tmp_batch.next_batch_sampling_info = (
-                    self.tp_worker.cur_sampling_info if batch else None
-                )
                self.process_batch_result_disagg_prefill(tmp_batch, tmp_result)
            if len(self.disagg_prefill_inflight_queue) > 0:
@@ -491,8 +478,6 @@ class SchedulerDisaggregationPrefillMixin:
                if self.enable_overlap:
                    self.send_kv_chunk(req, last_chunk=False, end_idx=req.tmp_end_idx)
-        # We need to remove the sync in the following function for overlap schedule.
-        self.set_next_batch_sampling_info_done(batch)
        self.maybe_send_health_check_signal()
    def process_disagg_prefill_inflight_queue(
...
@@ -891,7 +891,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
    # Sampling info
    sampling_info: SamplingBatchInfo = None
-    next_batch_sampling_info: SamplingBatchInfo = None
    # Batched arguments to model runner
    input_ids: torch.Tensor = None  # shape: [b], int64
...
@@ -1012,22 +1012,9 @@ class Scheduler(
                result = self.run_batch(batch)
                self.result_queue.append((batch.copy(), result))
-                if self.last_batch is None:
-                    # Create a dummy first batch to start the pipeline for overlap schedule.
-                    # It is now used for triggering the sampling_info_done event.
-                    tmp_batch = ScheduleBatch(
-                        reqs=None,
-                        forward_mode=ForwardMode.DUMMY_FIRST,
-                        next_batch_sampling_info=self.tp_worker.cur_sampling_info,
-                    )
-                    self.process_batch_result(tmp_batch, None)
            if self.last_batch:
                # Process the results of the last batch
                tmp_batch, tmp_result = self.result_queue.popleft()
-                tmp_batch.next_batch_sampling_info = (
-                    self.tp_worker.cur_sampling_info if batch else None
-                )
                self.process_batch_result(tmp_batch, tmp_result)
            elif batch is None:
                # When the server is idle, do self-check and re-init some states
@@ -2100,7 +2087,7 @@ class Scheduler(
            self.record_batch_in_overlap(model_worker_batch)
        # Sampling info will be modified during forward
-        model_worker_batch.sampling_info = self.tp_worker.cur_sampling_info = (
+        model_worker_batch.sampling_info = (
            model_worker_batch.sampling_info.copy_for_forward()
        )
@@ -2219,9 +2206,6 @@ class Scheduler(
        if self.enable_overlap:
            if result.copy_done is not None:
                result.copy_done.synchronize()
-            self.set_next_batch_sampling_info_done(batch)
-        elif batch.forward_mode.is_dummy_first():
-            self.set_next_batch_sampling_info_done(batch)
        self.maybe_send_health_check_signal()
@@ -2431,13 +2415,6 @@ class Scheduler(
                self._add_request_to_queue(req)
        self.grammar_queue = self.grammar_queue[num_ready_reqs:]
-    def set_next_batch_sampling_info_done(self, batch: ScheduleBatch):
-        if batch.next_batch_sampling_info:
-            if batch.next_batch_sampling_info.grammars is not None:
-                batch.next_batch_sampling_info.update_regex_vocab_mask()
-                self.default_stream.synchronize()
-            batch.next_batch_sampling_info.sampling_info_done.set()
    def watchdog_thread(self):
        """A watch dog thread that will try to kill the server itself if one forward batch takes too long."""
        self.watchdog_last_forward_ct = 0
...
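Note on the hunk above: set_next_batch_sampling_info_done was the scheduler side of a cross-thread handshake. The scheduler finished the CPU-heavy grammar/vocab-mask work for the next batch and then set a threading.Event (sampling_info_done) that the sampling path waited on. A minimal, self-contained sketch of that handshake, with illustrative names only (this is not the SGLang API):

import threading


class ToyNextBatchSamplingInfo:
    """Illustrative stand-in for the removed next_batch_sampling_info object."""

    def __init__(self) -> None:
        self.sampling_info_done = threading.Event()
        self.vocab_mask_ready = False

    def update_regex_vocab_mask(self) -> None:
        # Stand-in for the CPU-heavy grammar / vocab-mask computation.
        self.vocab_mask_ready = True


def scheduler_side(info: ToyNextBatchSamplingInfo) -> None:
    # Old flow: finish the mask for the *next* batch, then unblock the sampler.
    info.update_regex_vocab_mask()
    info.sampling_info_done.set()


def sampler_side(info: ToyNextBatchSamplingInfo) -> None:
    # Old flow: block until the scheduler signals that the mask is ready.
    info.sampling_info_done.wait()
    assert info.vocab_mask_ready

After this commit the event is gone: each batch hands the forward path its own copy_for_forward() snapshot, and update_regex_vocab_mask() runs directly in _preprocess_logits (see the ModelRunner hunk further below), so no cross-thread signalling is required.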
@@ -173,8 +173,6 @@ class SchedulerOutputProcessorMixin:
                    )
                    logprob_pt += num_input_logprobs
-            self.set_next_batch_sampling_info_done(batch)
        else:  # embedding or reward model
            embeddings = result.embeddings.tolist()
@@ -295,7 +293,6 @@ class SchedulerOutputProcessorMixin:
                        self.abort_request(AbortReq(rid=req.rid))
                    req.grammar.finished = req.finished()
-        self.set_next_batch_sampling_info_done(batch)
        self.stream_output(batch.reqs, batch.return_logprob)
        self.token_to_kv_pool_allocator.free_group_end()
...
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A tensor parallel worker."""
from __future__ import annotations
import dataclasses
import logging
import signal
import threading
from queue import Queue
from typing import TYPE_CHECKING, List, Optional, Tuple
import psutil
import torch
from sglang.srt.managers.io_struct import (
DestroyWeightsUpdateGroupReqInput,
GetWeightsByNameReqInput,
InitWeightsSendGroupForRemoteInstanceReqInput,
InitWeightsUpdateGroupReqInput,
LoadLoRAAdapterReqInput,
SendWeightsToRemoteInstanceReqInput,
UnloadLoRAAdapterReqInput,
UpdateWeightFromDiskReqInput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromTensorReqInput,
)
from sglang.srt.managers.overlap_utils import FutureMap
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.model_executor.forward_batch_info import ForwardBatchOutput
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import DynamicGradMode
from sglang.utils import get_exception_traceback
if TYPE_CHECKING:
from sglang.srt.managers.cache_controller import LayerDoneCounter
logger = logging.getLogger(__name__)
class TpModelWorkerClient:
"""A tensor parallel model worker."""
def __init__(
self,
server_args: ServerArgs,
gpu_id: int,
tp_rank: int,
moe_ep_rank: int,
pp_rank: int,
dp_rank: Optional[int],
nccl_port: int,
):
# Load the model
self.worker = TpModelWorker(
server_args, gpu_id, tp_rank, moe_ep_rank, pp_rank, dp_rank, nccl_port
)
self.max_running_requests = self.worker.max_running_requests
self.device = self.worker.device
self.gpu_id = gpu_id
# Init future mappings
self.future_map = FutureMap(self.max_running_requests, self.device)
# Launch threads
self.input_queue = Queue[Tuple[ModelWorkerBatch, int, torch.Event]]()
self.output_queue = Queue()
self.forward_stream = torch.get_device_module(self.device).Stream()
self.forward_thread = threading.Thread(
target=self.forward_thread_func,
)
self.forward_thread.start()
self.parent_process = psutil.Process().parent()
self.scheduler_stream = torch.get_device_module(self.device).current_stream()
if self.device == "cpu":
self.scheduler_stream.synchronize = lambda: None # No-op for CPU
self.hicache_layer_transfer_counter = None
def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter):
self.hicache_layer_transfer_counter = counter
def get_worker_info(self):
return self.worker.get_worker_info()
def get_tokens_per_layer_info(self):
return self.worker.get_tokens_per_layer_info()
@property
def sliding_window_size(self) -> Optional[int]:
return self.worker.sliding_window_size
@property
def is_hybrid(self) -> bool:
return self.worker.is_hybrid
def get_pad_input_ids_func(self):
return self.worker.get_pad_input_ids_func()
def get_tp_group(self):
return self.worker.get_tp_group()
def get_attention_tp_group(self):
return self.worker.get_attention_tp_group()
def get_attention_tp_cpu_group(self):
return self.worker.get_attention_tp_cpu_group()
def get_memory_pool(self):
return (
self.worker.model_runner.req_to_token_pool,
self.worker.model_runner.token_to_kv_pool_allocator,
)
def get_kv_cache(self):
return self.worker.model_runner.token_to_kv_pool
def forward_thread_func(self):
try:
with torch.get_device_module(self.device).stream(self.forward_stream):
self.forward_thread_func_()
except Exception:
traceback = get_exception_traceback()
logger.error(f"TpModelWorkerClient hit an exception: {traceback}")
self.parent_process.send_signal(signal.SIGQUIT)
@DynamicGradMode()
def forward_thread_func_(self):
batch_pt = 0
batch_lists: List = [None] * 2
while True:
model_worker_batch, future_map_ct, sync_event = self.input_queue.get()
if not model_worker_batch:
break
sync_event.wait()
# Keep a reference of model_worker_batch by storing it into a list.
# Otherwise, the tensor members of model_worker_batch will be released
# by pytorch and cause CUDA illegal memory access errors.
batch_lists[batch_pt % 2] = model_worker_batch
batch_pt += 1
# Create event
copy_done = torch.get_device_module(self.device).Event()
# Resolve future tokens in the input
self.future_map.resolve_future(model_worker_batch)
# Run forward
forward_batch_output = self.worker.forward_batch_generation(
model_worker_batch,
model_worker_batch.launch_done,
)
logits_output, next_token_ids, can_run_cuda_graph = (
forward_batch_output.logits_output,
forward_batch_output.next_token_ids,
forward_batch_output.can_run_cuda_graph,
)
# Update the future token ids map
bs = len(model_worker_batch.seq_lens)
if model_worker_batch.is_prefill_only:
# For prefill-only requests, create dummy token IDs on CPU
next_token_ids = torch.zeros(bs, dtype=torch.long)
# store the future indices into future map
self.future_map.store_to_map(future_map_ct, bs, next_token_ids)
# Copy results to the CPU
if model_worker_batch.return_logprob:
if logits_output.next_token_logprobs is not None:
logits_output.next_token_logprobs = (
logits_output.next_token_logprobs.to("cpu", non_blocking=True)
)
if logits_output.input_token_logprobs is not None:
logits_output.input_token_logprobs = (
logits_output.input_token_logprobs.to("cpu", non_blocking=True)
)
if logits_output.hidden_states is not None:
logits_output.hidden_states = logits_output.hidden_states.to(
"cpu", non_blocking=True
)
# Only copy to CPU if not already on CPU
if next_token_ids.device.type != "cpu":
next_token_ids = next_token_ids.to("cpu", non_blocking=True)
copy_done.record()
self.output_queue.put(
(copy_done, logits_output, next_token_ids, can_run_cuda_graph)
)
def resolve_last_batch_result(self, launch_done: Optional[threading.Event] = None):
"""
This function is called to resolve the last batch result and
wait for the current batch to be launched. Used in overlap mode.
"""
copy_done, logits_output, next_token_ids, can_run_cuda_graph = (
self.output_queue.get()
)
if launch_done is not None:
launch_done.wait()
copy_done.synchronize()
if logits_output.next_token_logprobs is not None:
logits_output.next_token_logprobs = (
logits_output.next_token_logprobs.tolist()
)
if logits_output.input_token_logprobs is not None:
logits_output.input_token_logprobs = tuple(
logits_output.input_token_logprobs.tolist()
)
next_token_ids = next_token_ids.tolist()
return logits_output, next_token_ids, can_run_cuda_graph
def forward_batch_generation(
self, model_worker_batch: ModelWorkerBatch
) -> ForwardBatchOutput:
# Create a new copy of sampling_info because it will be updated in-place by the scheduler for the next batch.
model_worker_batch.sampling_info = self.cur_sampling_info = (
model_worker_batch.sampling_info.copy_for_forward()
)
# A cuda stream sync here to avoid the cuda illegal memory access error.
sync_event = torch.get_device_module(self.device).Event()
sync_event.record(self.scheduler_stream)
# Push a new batch to the queue
bs = len(model_worker_batch.seq_lens)
cur_future_map_ct = self.future_map.update_ct(bs)
self.input_queue.put((model_worker_batch, cur_future_map_ct, sync_event))
# get this forward batch's future token ids
future_next_token_ids = self.future_map.update_next_future(
cur_future_map_ct, bs
)
return ForwardBatchOutput(
next_token_ids=future_next_token_ids,
can_run_cuda_graph=False,
)
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
success, message = self.worker.update_weights_from_disk(recv_req)
return success, message
def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
success, message = self.worker.init_weights_update_group(recv_req)
return success, message
def destroy_weights_update_group(self, recv_req: DestroyWeightsUpdateGroupReqInput):
success, message = self.worker.destroy_weights_update_group(recv_req)
return success, message
def init_weights_send_group_for_remote_instance(
self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
):
success, message = self.worker.init_weights_send_group_for_remote_instance(
recv_req
)
return success, message
def send_weights_to_remote_instance(
self, recv_req: SendWeightsToRemoteInstanceReqInput
):
success, message = self.worker.send_weights_to_remote_instance(recv_req)
return success, message
def update_weights_from_distributed(
self, recv_req: UpdateWeightsFromDistributedReqInput
):
success, message = self.worker.update_weights_from_distributed(recv_req)
return success, message
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
success, message = self.worker.update_weights_from_tensor(recv_req)
return success, message
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
return self.worker.get_weights_by_name(recv_req)
def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
return self.worker.load_lora_adapter(recv_req)
def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
return self.worker.unload_lora_adapter(recv_req)
def can_run_lora_batch(self, lora_ids: list[str]) -> bool:
return self.worker.can_run_lora_batch(lora_ids)
def __delete__(self):
self.input_queue.put((None, None))
self.copy_queue.put((None, None, None))
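The worker listing above depends on FutureMap from sglang.srt.managers.overlap_utils, which this diff does not show. Below is a rough sketch of the general "future token id" pattern such a map serves: the scheduler immediately receives negative placeholder ids, the forward thread stores the real sampled ids into a buffer, and placeholders are resolved before the next forward. The encoding (negative indices into a preallocated buffer, wrap-around ignored) is an assumption for illustration, not the actual FutureMap implementation.

import torch


class ToyFutureMap:
    """Illustrative placeholder map; not SGLang's FutureMap."""

    def __init__(self, capacity: int, device: str = "cpu") -> None:
        self.buf = torch.zeros(capacity, dtype=torch.int64, device=device)
        self.ct = 0  # running counter over issued future slots

    def update_ct(self, bs: int) -> int:
        # Reserve `bs` slots and return the starting index.
        start = self.ct
        self.ct += bs
        return start

    def update_next_future(self, start: int, bs: int) -> torch.Tensor:
        # Hand out negative placeholders; the real ids will land in buf[start:start+bs].
        return -(torch.arange(start, start + bs, dtype=torch.int64) + 1)

    def store_to_map(self, start: int, bs: int, next_token_ids: torch.Tensor) -> None:
        self.buf[start : start + bs] = next_token_ids

    def resolve_future(self, input_ids: torch.Tensor) -> torch.Tensor:
        # Replace every negative placeholder with the token id stored for it.
        mask = input_ids < 0
        resolved = input_ids.clone()
        resolved[mask] = self.buf[-input_ids[mask] - 1]
        return resolved

In the listing above, forward_batch_generation plays the role of update_ct plus update_next_future, while forward_thread_func_ calls resolve_future before running the model and store_to_map once the real next-token ids are available.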
@@ -75,10 +75,6 @@ class ForwardMode(IntEnum):
    # Used in speculative decoding: extend a batch in the draft model.
    DRAFT_EXTEND = auto()
-    # A dummy first batch to start the pipeline for overlap scheduler.
-    # It is now used for triggering the sampling_info_done event for the first prefill batch.
-    DUMMY_FIRST = auto()
    # Split Prefill for PD multiplexing
    SPLIT_PREFILL = auto()
@@ -128,9 +124,6 @@ class ForwardMode(IntEnum):
    def is_cpu_graph(self):
        return self == ForwardMode.DECODE
-    def is_dummy_first(self):
-        return self == ForwardMode.DUMMY_FIRST
    def is_split_prefill(self):
        return self == ForwardMode.SPLIT_PREFILL
...
@@ -2057,15 +2057,11 @@ class ModelRunner:
    def _preprocess_logits(
        self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
    ):
-        # Apply logit bias
-        if sampling_info.sampling_info_done:
-            # Overlap mode: the function update_regex_vocab_mask was executed
-            # in process_batch_result of the last batch.
-            if sampling_info.grammars:
-                sampling_info.sampling_info_done.wait()
-        else:
-            # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
-            sampling_info.update_regex_vocab_mask()
+        # NOTE: In overlap mode, the function update_regex_vocab_mask (in sample)
+        # was executed after we processed last batch's results.
+
+        # Calculate logits bias and apply it to next_token_logits.
+        sampling_info.update_regex_vocab_mask()
        sampling_info.apply_logits_bias(logits_output.next_token_logits)
    def sample(
...
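Under the new _preprocess_logits, grammar masking and logit bias always run right before sampling, in both overlap and non-overlap mode. As a plain illustration of what those two steps amount to (not the SGLang implementation; the real logic lives in SamplingBatchInfo.update_regex_vocab_mask and apply_logits_bias):

from typing import Optional

import torch


def mask_and_bias(
    logits: torch.Tensor,
    vocab_mask: torch.Tensor,
    acc_linear_penalties: Optional[torch.Tensor],
) -> None:
    # Grammar-disallowed tokens are set to -inf so they can never be sampled.
    logits.masked_fill_(vocab_mask, float("-inf"))
    # Any accumulated additive penalty/bias is added in place.
    if acc_linear_penalties is not None:
        logits.add_(acc_linear_penalties)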
@@ -44,12 +44,9 @@ class SamplingBatchInfo:
    vocab_mask: Optional[torch.Tensor] = None
    apply_mask_func: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None
-    # An event used for overlap schedule
-    sampling_info_done: Optional[threading.Event] = None
    # Penalizer
    penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None
-    linear_penalty: torch.Tensor = None
+    acc_linear_penalties: torch.Tensor = None  # Used in the overlap mode
    # Whether any request has custom logit processor
    has_custom_logit_processor: bool = False
@@ -217,19 +214,19 @@ class SamplingBatchInfo:
    def update_penalties(self):
        if self.penalizer_orchestrator.is_required:
-            self.linear_penalty = torch.zeros(
+            self.acc_linear_penalties = torch.zeros(
                (len(self.temperatures), self.vocab_size),
                dtype=torch.float32,
                device=self.temperatures.device,
            )
-            self.penalizer_orchestrator.apply(self.linear_penalty)
+            self.penalizer_orchestrator.apply(self.acc_linear_penalties)
        else:
-            self.linear_penalty = None
+            self.acc_linear_penalties = None
    def apply_logits_bias(self, logits: torch.Tensor):
-        if self.linear_penalty is not None:
+        if self.acc_linear_penalties is not None:
            # Used in the overlap mode
-            logits.add_(self.linear_penalty)
+            logits.add_(self.acc_linear_penalties)
        if self.penalizer_orchestrator and self.penalizer_orchestrator.is_required:
            # Used in the non-overlap mode
@@ -373,11 +370,7 @@ class SamplingBatchInfo:
    def copy_for_forward(self):
        # Accumulate the penalty into a pre-allocated buffer to get rid of the dependency of `penalizer_orchestrator` later
        self.update_penalties()
-        return dataclasses.replace(
-            self,
-            sampling_info_done=threading.Event(),
-            penalizer_orchestrator=None,
-        )
+        return dataclasses.replace(self, penalizer_orchestrator=None)
    def merge_bias_tensor(
...
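The copy_for_forward change above is the other half of removing the event: update_penalties() folds every active penalty into one dense additive buffer (acc_linear_penalties), so the snapshot handed to the forward path carries only plain tensors and no longer needs a penalizer_orchestrator reference or a sampling_info_done event. A simplified sketch with toy field names (not the real class):

import dataclasses
from typing import Optional

import torch


@dataclasses.dataclass
class ToySamplingBatchInfo:
    temperatures: torch.Tensor
    vocab_size: int
    penalizer_orchestrator: Optional[object] = None
    acc_linear_penalties: Optional[torch.Tensor] = None

    def update_penalties(self) -> None:
        if self.penalizer_orchestrator is not None:
            # Fold every active penalty into a single additive buffer.
            self.acc_linear_penalties = torch.zeros(
                (len(self.temperatures), self.vocab_size), dtype=torch.float32
            )
            # In the real class, the orchestrator writes its penalties into this buffer here.
        else:
            self.acc_linear_penalties = None

    def copy_for_forward(self) -> "ToySamplingBatchInfo":
        self.update_penalties()
        # The copy holds only plain tensors, so the forward thread can read it
        # while the scheduler keeps mutating its own copy.
        return dataclasses.replace(self, penalizer_orchestrator=None)

This mirrors the Scheduler hunk at @@ -2100 above, where run_batch now snapshots model_worker_batch.sampling_info via copy_for_forward() instead of also publishing it as tp_worker.cur_sampling_info.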