Unverified Commit 41efcaeb authored by ykcombat, committed by GitHub

[Feature] PD-Multiplexing Context and Scheduler, lazy import spatial. (#12275)

parent 70562969
@@ -134,10 +134,7 @@ class LogitsMetadata:
     @classmethod
     def from_forward_batch(cls, forward_batch: ForwardBatch):
         if (
-            (
-                forward_batch.forward_mode.is_extend()
-                or forward_batch.forward_mode.is_split_prefill()
-            )
+            forward_batch.forward_mode.is_extend()
             and forward_batch.return_logprob
             and not forward_batch.forward_mode.is_target_verify()
         ):

@@ -384,8 +381,8 @@ class LogitsProcessor(nn.Module):
             input_logprob_indices = None
         elif (
             logits_metadata.forward_mode.is_extend()
-            or logits_metadata.forward_mode.is_split_prefill()
-        ) and not logits_metadata.extend_return_logprob:
+            and not logits_metadata.extend_return_logprob
+        ):
             # Prefill without input logprobs.
             if logits_metadata.padded_static_len < 0:
                 last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
@@ -72,7 +72,11 @@ from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
 from sglang.srt.mem_cache.radix_cache import RadixKey
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
-from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
+from sglang.srt.model_executor.forward_batch_info import (
+    CaptureHiddenMode,
+    ForwardBatch,
+    ForwardMode,
+)
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs, get_global_server_args
@@ -152,6 +152,7 @@ from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
+from sglang.srt.multiplex.multiplexing_mixin import SchedulerMultiplexMixin
 from sglang.srt.parser.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import PortArgs, ServerArgs, get_global_server_args
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

@@ -213,6 +214,7 @@ class Scheduler(
     SchedulerMetricsMixin,
     SchedulerDisaggregationDecodeMixin,
     SchedulerDisaggregationPrefillMixin,
+    SchedulerMultiplexMixin,
     SchedulerRuntimeCheckerMixin,
     SchedulerPPMixin,
 ):

@@ -252,6 +254,7 @@ class Scheduler(
         self.enable_lora = server_args.enable_lora
         self.max_loras_per_batch = server_args.max_loras_per_batch
         self.enable_overlap = not server_args.disable_overlap_schedule
+        self.enable_pdmux = server_args.enable_pdmux
         self.skip_tokenizer_init = server_args.skip_tokenizer_init
         self.enable_metrics = server_args.enable_metrics
         self.enable_metrics_for_all_schedulers = (

@@ -285,6 +288,10 @@ class Scheduler(
         # Init inter-process communication
         self.init_sockets(server_args, port_args)

+        # Init pdmux context
+        if self.enable_pdmux:
+            self.init_pdmux()
+
         # Init tokenizer
         self.init_tokenizer()

@@ -424,6 +431,8 @@ class Scheduler(
         self.running_batch: ScheduleBatch = ScheduleBatch(reqs=[], batch_is_full=False)
         # The current forward batch
         self.cur_batch: Optional[ScheduleBatch] = None
+        # The current split prefill batch
+        self.split_prefill_batch: Optional[ScheduleBatch] = None
         # The last forward batch
         self.last_batch: Optional[ScheduleBatch] = None
         self.forward_ct = 0

@@ -1952,7 +1961,6 @@ class Scheduler(
         # Run forward
         if self.is_generation:
             batch_or_worker_batch = batch
-
             if self.enable_overlap or self.spec_algorithm.is_none():

@@ -2009,6 +2017,9 @@ class Scheduler(
                     # The future value, usually for next batch preparation
                     # Current implementation strictly synchronizes the seq_lens
                     batch.seq_lens = batch_result.next_draft_input.new_seq_lens
+            elif self.enable_pdmux and batch.forward_mode.is_split_prefill():
+                batch_result = self.tp_worker.forward_batch_split_prefill(batch)
+                future_indices_or_next_token_ids = batch_result.next_token_ids
             else:
                 batch_result = self.model_worker.forward_batch_generation(
                     batch_or_worker_batch

@@ -2791,7 +2802,9 @@ def run_scheduler_process(
     disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode
     if disaggregation_mode == DisaggregationMode.NULL:
-        if server_args.pp_size > 1:
+        if scheduler.enable_pdmux:
+            scheduler.event_loop_pdmux()
+        elif server_args.pp_size > 1:
             scheduler.event_loop_pp()
         elif scheduler.enable_overlap:
             scheduler.event_loop_overlap()
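
For reference, a self-contained toy of the event-loop dispatch added above; only the attribute names (enable_pdmux, pp_size, enable_overlap) come from the diff, everything else is made up for illustration.

```python
from types import SimpleNamespace

def choose_event_loop(scheduler, server_args) -> str:
    # Mirrors the branch order in run_scheduler_process: pdmux takes priority
    # over pipeline parallelism and overlap scheduling.
    if scheduler.enable_pdmux:
        return "event_loop_pdmux"
    elif server_args.pp_size > 1:
        return "event_loop_pp"
    elif scheduler.enable_overlap:
        return "event_loop_overlap"
    return "event_loop_normal"  # placeholder for the remaining branches

scheduler = SimpleNamespace(enable_pdmux=True, enable_overlap=True)
server_args = SimpleNamespace(pp_size=1)
print(choose_event_loop(scheduler, server_args))  # event_loop_pdmux
```
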
@@ -35,7 +35,7 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromIPCReqInput,
     UpdateWeightsFromTensorReqInput,
 )
-from sglang.srt.managers.schedule_batch import ModelWorkerBatch
+from sglang.srt.managers.schedule_batch import ModelWorkerBatch, ScheduleBatch
 from sglang.srt.managers.scheduler import GenerationBatchResult
 from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool

@@ -425,3 +425,26 @@ class TpModelWorker(BaseTpWorker):
             pp_hidden_states_proxy_tensors=pp_proxy_tensors,
             can_run_cuda_graph=can_run_cuda_graph,
         )
+
+    def forward_batch_split_prefill(self, batch: ScheduleBatch):
+        if batch.split_index == 0:
+            model_worker_batch = batch.get_model_worker_batch()
+            forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
+            batch.split_forward_batch = forward_batch
+            batch.seq_lens_cpu_cache = model_worker_batch.seq_lens_cpu
+        else:
+            model_worker_batch = batch.get_model_worker_batch(batch.seq_lens_cpu_cache)
+
+        logits_output, can_run_cuda_graph = self.model_runner.forward(
+            batch.split_forward_batch, split_forward_count=batch.split_forward_count
+        )
+        if logits_output:
+            next_token_ids = self.model_runner.sample(logits_output, model_worker_batch)
+        else:
+            next_token_ids = None
+
+        batch_result = GenerationBatchResult(
+            logits_output=logits_output,
+            can_run_cuda_graph=can_run_cuda_graph,
+        )
+        batch_result.next_token_ids = next_token_ids
+        return batch_result
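
The worker runs the split-prefill forward in layer chunks; the scheduler (see event_loop_pdmux later in this commit) sizes each chunk from split_forward_token_budget and the batch's extend_num_tokens. A rough, self-contained sketch of that arithmetic with made-up numbers:

```python
# Made-up example values; the real budget and token count come from the pdmux
# config and the current split prefill batch, respectively.
num_hidden_layers = 32
split_forward_token_budget = 65536
extend_num_tokens = 8192

split_index = 0
chunks = []
while split_index < num_hidden_layers:
    forward_count = max(1, split_forward_token_budget // extend_num_tokens)
    next_split_index = min(split_index + forward_count, num_hidden_layers)
    chunks.append((split_index, next_split_index - split_index))
    split_index = next_split_index

print(chunks)  # [(0, 8), (8, 8), (16, 8), (24, 8)] -> four partial forwards of 8 layers
```
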
@@ -509,6 +509,7 @@ class MHATokenToKVPool(KVCache):
         enable_memory_saver: bool,
         start_layer: Optional[int] = None,
         end_layer: Optional[int] = None,
+        enable_alt_stream: bool = True,
         enable_kv_cache_copy: bool = False,
     ):
         super().__init__(

@@ -527,7 +528,9 @@ class MHATokenToKVPool(KVCache):
         self._create_buffers()

         self.device_module = torch.get_device_module(self.device)
-        self.alt_stream = self.device_module.Stream() if _is_cuda else None
+        self.alt_stream = (
+            self.device_module.Stream() if _is_cuda and enable_alt_stream else None
+        )
         if enable_kv_cache_copy:
             self._init_kv_copy_and_warmup()
@@ -96,6 +96,7 @@ class ForwardMode(IntEnum):
                 else False
             )
             or self == ForwardMode.TARGET_VERIFY
+            or self == ForwardMode.SPLIT_PREFILL
         )

     def is_decode(self):
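
With SPLIT_PREFILL folded into is_extend(), callers such as LogitsMetadata no longer need a separate is_split_prefill() check (see the logits_processor hunks above). A toy reconstruction of that idea; the real ForwardMode enum has more members and conditions:

```python
from enum import IntEnum, auto

class ToyForwardMode(IntEnum):
    # Hypothetical subset of the real enum, for illustration only.
    EXTEND = auto()
    DECODE = auto()
    TARGET_VERIFY = auto()
    SPLIT_PREFILL = auto()

    def is_extend(self):
        return self in (
            ToyForwardMode.EXTEND,
            ToyForwardMode.TARGET_VERIFY,
            ToyForwardMode.SPLIT_PREFILL,
        )

assert ToyForwardMode.SPLIT_PREFILL.is_extend()
```
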
@@ -1765,6 +1765,7 @@ class ModelRunner:
                 enable_memory_saver=self.server_args.enable_memory_saver,
                 start_layer=self.start_layer,
                 end_layer=self.end_layer,
+                enable_alt_stream=not self.server_args.enable_pdmux,
                 enable_kv_cache_copy=(
                     self.server_args.speculative_algorithm is not None
                 ),

@@ -1833,12 +1834,18 @@ class ModelRunner:
     def init_attention_backend(self):
         """Init attention kernel backend."""
-        if self.server_args.enable_two_batch_overlap and not self.is_draft_worker:
+        if self.server_args.enable_pdmux:
+            self.attn_backend = self._get_attention_backend(init_new_workspace=True)
+            self.decode_attn_backend_group = []
+            for _ in range(self.server_args.sm_group_num):
+                self.decode_attn_backend_group.append(self._get_attention_backend())
+            self.decode_attn_backend = self.decode_attn_backend_group[0]
+        elif self.server_args.enable_two_batch_overlap and not self.is_draft_worker:
             self.attn_backend = TboAttnBackend.init_new(self._get_attention_backend)
         else:
             self.attn_backend = self._get_attention_backend()

-    def _get_attention_backend(self):
+    def _get_attention_backend(self, init_new_workspace: bool = False):
         """Init attention kernel backend."""
         self.prefill_attention_backend_str, self.decode_attention_backend_str = (
             self.server_args.get_attention_backends()

@@ -1852,10 +1859,12 @@ class ModelRunner:
             attn_backend = HybridAttnBackend(
                 self,
                 decode_backend=self._get_attention_backend_from_str(
-                    self.decode_attention_backend_str
+                    self.decode_attention_backend_str,
+                    init_new_workspace=init_new_workspace,
                 ),
                 prefill_backend=self._get_attention_backend_from_str(
-                    self.prefill_attention_backend_str
+                    self.prefill_attention_backend_str,
+                    init_new_workspace=init_new_workspace,
                 ),
             )
             logger.info(

@@ -1869,7 +1878,8 @@ class ModelRunner:
             )
         else:
             attn_backend = self._get_attention_backend_from_str(
-                self.server_args.attention_backend
+                self.server_args.attention_backend,
+                init_new_workspace=init_new_workspace,
             )
         (

@@ -1878,9 +1888,12 @@ class ModelRunner:
         ) = (self.prefill_attention_backend_str, self.decode_attention_backend_str)
         return attn_backend

-    def _get_attention_backend_from_str(self, backend_str: str):
+    def _get_attention_backend_from_str(
+        self, backend_str: str, init_new_workspace: bool = False
+    ):
         if backend_str not in ATTENTION_BACKENDS:
             raise ValueError(f"Invalid attention backend: {backend_str}")
+        self.init_new_workspace = init_new_workspace
         full_attention_backend = ATTENTION_BACKENDS[backend_str](self)
         return attn_backend_wrapper(self, full_attention_backend)

@@ -1978,6 +1991,9 @@ class ModelRunner:
         device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
         tensor_parallel(self.model, device_mesh)

+    def update_decode_attn_backend(self, stream_idx: int):
+        self.decode_attn_backend = self.decode_attn_backend_group[stream_idx]
+
     def forward_decode(
         self,
         forward_batch: ForwardBatch,

@@ -1985,6 +2001,10 @@ class ModelRunner:
         pp_proxy_tensors=None,
     ) -> LogitsProcessorOutput:
         if not skip_attn_backend_init:
-            self.attn_backend.init_forward_metadata(forward_batch)
+            if self.server_args.enable_pdmux:
+                self.decode_attn_backend.init_forward_metadata(forward_batch)
+                forward_batch.attn_backend = self.decode_attn_backend
+            else:
+                self.attn_backend.init_forward_metadata(forward_batch)

         # FIXME: add pp_proxy_tensors arg to all models
         kwargs = {}

@@ -2123,18 +2143,18 @@ class ModelRunner:
                 skip_attn_backend_init=skip_attn_backend_init,
                 pp_proxy_tensors=pp_proxy_tensors,
             )
-        elif forward_batch.forward_mode.is_extend():
-            ret = self.forward_extend(
-                forward_batch,
-                skip_attn_backend_init=skip_attn_backend_init,
-                pp_proxy_tensors=pp_proxy_tensors,
-            )
         elif forward_batch.forward_mode.is_split_prefill():
             ret = self.forward_split_prefill(
                 forward_batch,
                 reinit_attn_backend=reinit_attn_backend,
                 forward_count=split_forward_count,
             )
+        elif forward_batch.forward_mode.is_extend():
+            ret = self.forward_extend(
+                forward_batch,
+                skip_attn_backend_init=skip_attn_backend_init,
+                pp_proxy_tensors=pp_proxy_tensors,
+            )
         elif forward_batch.forward_mode.is_idle():
             ret = self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
         else:
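
Each SM group gets its own decode attention backend so per-stream workspaces and forward metadata never collide across partitions. A toy illustration of the indexing scheme (class names here are made up; only the attribute and method names follow the diff):

```python
class ToyDecodeBackend:
    def __init__(self, group_idx: int) -> None:
        self.group_idx = group_idx

    def init_forward_metadata(self, forward_batch) -> None:
        print(f"init decode metadata on backend for SM group {self.group_idx}")

class ToyModelRunner:
    def __init__(self, sm_group_num: int) -> None:
        # One decode backend per SM group, as in init_attention_backend above.
        self.decode_attn_backend_group = [
            ToyDecodeBackend(i) for i in range(sm_group_num)
        ]
        self.decode_attn_backend = self.decode_attn_backend_group[0]

    def update_decode_attn_backend(self, stream_idx: int) -> None:
        self.decode_attn_backend = self.decode_attn_backend_group[stream_idx]

runner = ToyModelRunner(sm_group_num=4)
runner.update_decode_attn_backend(2)
runner.decode_attn_backend.init_forward_metadata(forward_batch=None)
```
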
"""
Mixin class providing multiplexing scheduling logic
"""
import logging
import torch
import torch.distributed as dist
from torch.cuda.streams import ExternalStream
from sglang.srt.distributed.parallel_state import set_pdmux_status
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.multiplex.pdmux_context import (
get_current_stream_idx,
get_sm_counts,
get_stream_groups,
initialize_stream_groups,
load_pdmux_config,
set_current_stream_idx,
)
logger = logging.getLogger(__name__)
class SchedulerMultiplexMixin:

    def init_pdmux(self):
        # For PD-multiplexing: init the stream groups (SM partitions) plus the
        # normal stream groups used for prefill-only and decode-only phases.
        self.pdmux_config = load_pdmux_config(self.server_args.pdmux_config_path)
        initialize_stream_groups(self.gpu_id, self.pdmux_config)
        self.stream_groups = get_stream_groups()
        self.sm_counts = get_sm_counts()
        self.real_sm_group_num = len(self.stream_groups)
        logger.info(
            f"PD-Multiplexing enabled with {self.real_sm_group_num} stream groups, "
            f"sm_counts (prefill_sm, decode_sm): {self.sm_counts}"
        )
    # TODO(jason-fxz): This is a temporary demo
    def adjust_stream_groups(self) -> tuple[int, tuple[ExternalStream, ExternalStream]]:
        if not self.running_batch.is_empty() and self.split_prefill_batch:
            decode_bs = self.running_batch.batch_size()
            manual_divisions = self.pdmux_config.manual_divisions
            if manual_divisions:
                # Pick the last manual division whose decode_bs_threshold is met;
                # default to the first partitioned group if none matches.
                stream_idx = 1
                for i in range(len(manual_divisions)):
                    _, _, threshold = manual_divisions[i]
                    if decode_bs >= threshold:
                        stream_idx = i + 1
            else:
                stream_idx = max(
                    1,
                    min(
                        self.real_sm_group_num - 2,
                        decode_bs
                        * (self.real_sm_group_num - 2)
                        // self.pdmux_config.decode_bs_divisor,
                    ),
                )
            set_current_stream_idx(stream_idx)
        elif not self.running_batch.is_empty():
            set_current_stream_idx(self.real_sm_group_num - 1)
        else:
            set_current_stream_idx(0)

        stream_idx = get_current_stream_idx()
        self.tp_worker.model_runner.update_decode_attn_backend(stream_idx)
        return stream_idx, self.stream_groups[stream_idx]
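
For the automatic path, the stream index grows roughly in proportion to the decode batch size. A quick check of the formula above, assuming real_sm_group_num = 8 and the default decode_bs_divisor of 36:

```python
real_sm_group_num = 8
decode_bs_divisor = 36

def pick_stream_idx(decode_bs: int) -> int:
    return max(
        1,
        min(
            real_sm_group_num - 2,
            decode_bs * (real_sm_group_num - 2) // decode_bs_divisor,
        ),
    )

for decode_bs in (1, 8, 16, 32, 64, 128):
    print(decode_bs, pick_stream_idx(decode_bs))
# 1 -> 1, 8 -> 1, 16 -> 2, 32 -> 5, 64 -> 6, 128 -> 6
```
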
    def update_split_prefill_batch(self, sm_count: int) -> bool:
        if self.split_prefill_batch:
            return False

        # add new request
        batch = self.get_new_batch_prefill()
        if batch and not batch.is_empty():
            # Set forward mode for split prefill
            batch.forward_mode = ForwardMode.SPLIT_PREFILL
            self.split_prefill_batch = batch
            return True
        return False
    @torch.inference_mode()
    def event_loop_pdmux(self):
        """A scheduler loop for pd multiplexing."""
        decode_done = False
        prefill_done = False
        wait_prefill_kernel_done = False
        adjust_stream_group = False
        stream_idx = get_current_stream_idx()
        stream_group = self.stream_groups[stream_idx]
        prefill_stream = stream_group[0]
        decode_stream = stream_group[1]
        torch.cuda.empty_cache()
        logger.debug("Starting event loop for pd multiplexing...")

        while True:
            with torch.cuda.stream(decode_stream):
                set_pdmux_status(False)
                recv_reqs = self.recv_requests()
                self.process_input_requests(recv_reqs)

            with torch.cuda.stream(prefill_stream):
                set_pdmux_status(True)
                sm_count = self.sm_counts[stream_idx][0]
                if not wait_prefill_kernel_done:
                    adjust_stream_group = (
                        self.update_split_prefill_batch(sm_count) or adjust_stream_group
                    )

            with torch.cuda.stream(decode_stream):
                set_pdmux_status(False)
                self.running_batch = self.update_running_batch(self.running_batch)
                adjust_stream_group = adjust_stream_group or (
                    stream_idx > 0 and self.running_batch.is_empty()
                )

                if self.running_batch.is_empty() and self.split_prefill_batch is None:
                    self.check_memory()
                    self.check_tree_cache()
                    self.new_token_ratio = self.init_new_token_ratio
                    self.maybe_sleep_on_idle()

            if adjust_stream_group:
                prefill_stream.synchronize()
                decode_stream.synchronize()
                stream_idx, stream_group = self.adjust_stream_groups()
                prefill_stream = stream_group[0]
                decode_stream = stream_group[1]
                adjust_stream_group = False
                logger.debug(
                    f"Adjusting stream groups: {stream_idx}, "
                    f"prefill sm: {self.sm_counts[stream_idx][0]}, "
                    f"decode sm: {self.sm_counts[stream_idx][1]}"
                )

            with torch.cuda.stream(decode_stream):
                set_pdmux_status(False)
                # process decode batch
                if self.running_batch and not self.running_batch.is_empty():
                    decode_result = self.run_batch(self.running_batch)
                    decode_done = True
                else:
                    decode_done = False

            with torch.cuda.stream(prefill_stream):
                set_pdmux_status(True)
                if (
                    self.split_prefill_batch
                    and not self.split_prefill_batch.is_empty()
                    and not wait_prefill_kernel_done
                ):
                    prefill_done = True
                    forward_count = (
                        max(
                            1,
                            self.pdmux_config.split_forward_token_budget
                            // self.split_prefill_batch.extend_num_tokens,
                        )
                        if self.split_prefill_batch.extend_num_tokens > 0
                        else self.model_config.num_hidden_layers
                    )
                    next_split_index = min(
                        self.split_prefill_batch.split_index + forward_count,
                        self.model_config.num_hidden_layers,
                    )
                    forward_count = (
                        next_split_index - self.split_prefill_batch.split_index
                    )
                    self.split_prefill_batch.split_forward_count = forward_count
                    prefill_result = self.run_batch(self.split_prefill_batch)
                    if next_split_index == self.model_config.num_hidden_layers:
                        self.split_prefill_batch.split_prefill_finished = True
                        prefill_exe_done = prefill_stream.record_event()
                    self.split_prefill_batch.split_index = next_split_index
                elif wait_prefill_kernel_done:
                    prefill_done = True
                else:
                    prefill_done = False

            with torch.cuda.stream(decode_stream):
                set_pdmux_status(False)
                decode_stream.synchronize()
                if decode_done:
                    self.process_batch_result(self.running_batch, decode_result)

            with torch.cuda.stream(prefill_stream):
                set_pdmux_status(True)
                if prefill_done and self.split_prefill_batch.split_prefill_finished:
                    wait_prefill_kernel_done = True
                    prefill_exe_done_flag = prefill_exe_done.query()
                    flags = (
                        torch.ones(1, device="cpu", dtype=torch.int32)
                        if prefill_exe_done_flag
                        else torch.zeros(1, device="cpu", dtype=torch.int32)
                    )
                    self.tp_cpu_group.allreduce(flags, dist.ReduceOp.SUM).wait()
                    if flags.item() == self.tp_size:
                        self.process_batch_result(
                            self.split_prefill_batch, prefill_result
                        )
                        if self.running_batch and not self.running_batch.is_empty():
                            self.running_batch.merge_batch(self.split_prefill_batch)
                        else:
                            self.running_batch = self.split_prefill_batch
                        self.split_prefill_batch = None
                        wait_prefill_kernel_done = False
                        adjust_stream_group = True
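
The loop above interleaves work on two CUDA streams so prefill and decode kernels can run concurrently on their SM partitions. A minimal, standalone sketch of that pattern using plain CUDA streams instead of green-context streams (requires a CUDA device; the matmuls are stand-ins for model forwards):

```python
import torch

prefill_stream = torch.cuda.Stream()
decode_stream = torch.cuda.Stream()

a = torch.randn(4096, 4096, device="cuda")
b = torch.randn(4096, 4096, device="cuda")

with torch.cuda.stream(prefill_stream):
    prefill_out = a @ a  # stand-in for one layer-chunked split-prefill forward

with torch.cuda.stream(decode_stream):
    decode_out = b @ b  # stand-in for one decode forward

# As in event_loop_pdmux: synchronize the decode stream before consuming its
# result, and poll prefill completion through a recorded event.
decode_stream.synchronize()
prefill_done = prefill_stream.record_event()
print("prefill kernels finished:", prefill_done.query())
```
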
New file: sglang/srt/multiplex/pdmux_context.py

from dataclasses import dataclass, field
from typing import List

import torch
import yaml

STREAM_GROUPS = []
SM_COUNTS = []
SM_GROUP_NUM = 8  # Default number of SM groups
CURRENT_STREAM_IDX = 0
CURRENT_STREAM_GROUP = None


@dataclass
class PDMuxConfig:
    sm_group_num: int = 8
    # Each entry is [prefill_sm, decode_sm, decode_bs_threshold]
    manual_divisions: List[List[int]] = field(default_factory=list)
    split_forward_token_budget: int = 65536
    decode_bs_divisor: int = 36
def load_pdmux_config(config_path: str) -> PDMuxConfig:
    """Load the pdmux configuration from a YAML file into a dataclass."""
    if not config_path:
        return PDMuxConfig()

    with open(config_path, "r") as f:
        raw = yaml.safe_load(f)

    if "sm_group_num" not in raw:
        raise ValueError("Missing required field: sm_group_num")
    if raw["sm_group_num"] < 3:
        raise ValueError("sm_group_num must be at least 3")

    manual_divisions = raw.get("manual_divisions", [])
    expected = raw["sm_group_num"] - 2
    if manual_divisions and len(manual_divisions) != expected:
        raise ValueError(
            f"manual_divisions must have {expected} entries, "
            f"but got {len(manual_divisions)}"
        )

    return PDMuxConfig(
        sm_group_num=raw["sm_group_num"],
        manual_divisions=manual_divisions,
        split_forward_token_budget=raw.get("split_forward_token_budget", 65536),
        decode_bs_divisor=raw.get("decode_bs_divisor", 36),
    )
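
A hypothetical pdmux config file covering the fields read above; only sm_group_num is required, and manual_divisions (one [prefill_sm, decode_sm, decode_bs_threshold] entry per partitioned group) is optional:

```python
import yaml

example_yaml = """
sm_group_num: 8
split_forward_token_budget: 65536
decode_bs_divisor: 36
# manual_divisions: optional list of [prefill_sm, decode_sm, decode_bs_threshold]
# entries; when present it must contain exactly sm_group_num - 2 entries.
"""

raw = yaml.safe_load(example_yaml)
print(raw)  # {'sm_group_num': 8, 'split_forward_token_budget': 65536, 'decode_bs_divisor': 36}
```
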
def get_arch_constraints(compute_capability):
    major, minor = compute_capability
    # Green context constraints per architecture: (min SMs per partition, SM multiple)
    if major == 6:
        return 1, 1
    elif major == 7:
        return 2, 2
    elif major == 8:
        return 4, 2
    elif major == 9:
        return 8, 8
    else:
        raise ValueError(f"Unsupported compute capability: {major}.{minor}")
def divide_sm(total_sms, compute_capability, groups):
    """
    :param total_sms: total SM count on a single GPU
    :param compute_capability: (major, minor)
    :param groups: number of SM partitions to generate
    :return: list of SM partitions as (prefill_sm, decode_sm) tuples
    """
    min_per_part, multiple = get_arch_constraints(compute_capability)
    possible_values = [
        x
        for x in range(min_per_part, total_sms - min_per_part + 1, multiple)
        if x >= total_sms - x and total_sms - x >= 16
    ]
    if not possible_values:
        raise ValueError(
            f"No valid partitions found for total SMs {total_sms} "
            f"with constraints (min per part: {min_per_part}, multiple: {multiple})"
        )

    if len(possible_values) >= groups:
        step = max(1, len(possible_values) // groups)
        selected_values = possible_values[::step][:groups]
    else:
        selected_values = possible_values

    divisions = []
    for part1 in selected_values:
        part2 = total_sms - part1
        divisions.append((part1, part2))
    divisions.reverse()  # Reverse so the largest prefill SM count comes first
    return divisions
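
A worked example, assuming an H100-class GPU with 132 SMs and compute capability 9.0, and the default sm_group_num of 8 (so 6 partitioned groups):

```python
from sglang.srt.multiplex.pdmux_context import divide_sm

print(divide_sm(132, (9, 0), groups=6))
# [(112, 20), (104, 28), (96, 36), (88, 44), (80, 52), (72, 60)]
# initialize_stream_groups then prepends (132, 0) and appends (0, 132), giving
# 8 (prefill_sm, decode_sm) entries in SM_COUNTS.
```
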
def initialize_stream_groups(gpu_id: int, config: PDMuxConfig):
    # Lazy import: sgl_kernel.spatial is only needed when PD-multiplexing is enabled.
    from sgl_kernel import spatial

    global STREAM_GROUPS, SM_COUNTS, SM_GROUP_NUM, CURRENT_STREAM_IDX, CURRENT_STREAM_GROUP

    # For PD-multiplexing: init stream groups
    device = torch.cuda.current_device()
    total_sm_count = spatial.get_sm_available(gpu_id)

    # (prefill_sm_count, decode_sm_count)
    if config.manual_divisions:
        divisions = [
            (prefill_sm, decode_sm)
            for prefill_sm, decode_sm, _ in config.manual_divisions
        ]
    else:
        divisions = divide_sm(
            total_sm_count,
            torch.cuda.get_device_capability(device),
            config.sm_group_num - 2,
        )

    SM_COUNTS = []
    SM_COUNTS.append((total_sm_count, 0))  # Normal stream group: prefill gets all SMs
    SM_COUNTS.extend(divisions)  # Add the divided SM counts
    SM_COUNTS.append((0, total_sm_count))  # Normal stream group: decode gets all SMs

    STREAM_GROUPS = []
    STREAM_GROUPS.append(
        (torch.cuda.Stream(gpu_id), torch.cuda.Stream(gpu_id))
    )  # Normal stream for prefill
    for prefill_sm, decode_sm in divisions:
        STREAM_GROUPS.append(
            spatial.create_greenctx_stream_by_value(prefill_sm, decode_sm, gpu_id)
        )
    STREAM_GROUPS.append(
        (torch.cuda.Stream(gpu_id), torch.cuda.Stream(gpu_id))
    )  # Normal stream for decode

    CURRENT_STREAM_IDX = 0
    CURRENT_STREAM_GROUP = STREAM_GROUPS[CURRENT_STREAM_IDX]
def set_current_stream_idx(idx: int):
    global CURRENT_STREAM_IDX, CURRENT_STREAM_GROUP
    if idx < 0 or idx >= len(STREAM_GROUPS):
        raise ValueError(f"Invalid stream index: {idx}")
    CURRENT_STREAM_IDX = idx
    CURRENT_STREAM_GROUP = STREAM_GROUPS[CURRENT_STREAM_IDX]


def get_stream_groups() -> list[tuple[torch.cuda.Stream, torch.cuda.Stream]]:
    """Get the stream groups."""
    return STREAM_GROUPS


def get_sm_counts() -> list[tuple[int, int]]:
    """Get the SM counts."""
    return SM_COUNTS


def get_current_stream_idx() -> int:
    """Get the current stream index."""
    return CURRENT_STREAM_IDX
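
Putting the helpers together; this sketch needs a CUDA device and an sgl_kernel build with spatial (green context) support, and assumes each stream group unpacks into a (prefill_stream, decode_stream) pair as above:

```python
import torch

from sglang.srt.multiplex.pdmux_context import (
    PDMuxConfig,
    get_sm_counts,
    get_stream_groups,
    initialize_stream_groups,
)

torch.cuda.set_device(0)
initialize_stream_groups(gpu_id=0, config=PDMuxConfig(sm_group_num=8))

for (prefill_sm, decode_sm), (prefill_stream, decode_stream) in zip(
    get_sm_counts(), get_stream_groups()
):
    print(f"prefill SMs={prefill_sm:3d} decode SMs={decode_sm:3d}",
          prefill_stream, decode_stream)
```
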