Unverified Commit 56793995 authored by Sage Moore, committed by GitHub
Browse files

[Core/DBO][1/N] Add Dual-Batch Overlap mechanism to VLLM (#23693)


Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Co-authored-by: yewentao256 <zhyanwentao@126.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
parent 08369289
...@@ -87,6 +87,11 @@ def parse_args(): ...@@ -87,6 +87,11 @@ def parse_args():
default=0.8, default=0.8,
help=("Fraction of GPU memory vLLM is allowed to allocate (0.0, 1.0]."), help=("Fraction of GPU memory vLLM is allowed to allocate (0.0, 1.0]."),
) )
parser.add_argument(
"--enable-dbo",
action="store_true",
help=("Enable microbatched execution"),
)
parser.add_argument( parser.add_argument(
"--compilation-config", "--compilation-config",
type=int, type=int,
...@@ -113,6 +118,7 @@ def main( ...@@ -113,6 +118,7 @@ def main(
max_model_len, max_model_len,
compilation_config, compilation_config,
gpu_memory_utilization, gpu_memory_utilization,
enable_dbo,
quantization, quantization,
): ):
os.environ["VLLM_DP_RANK"] = str(global_dp_rank) os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
...@@ -167,6 +173,7 @@ def main( ...@@ -167,6 +173,7 @@ def main(
max_num_seqs=max_num_seqs, max_num_seqs=max_num_seqs,
max_model_len=max_model_len, max_model_len=max_model_len,
gpu_memory_utilization=gpu_memory_utilization, gpu_memory_utilization=gpu_memory_utilization,
enable_dbo=enable_dbo,
quantization=quantization, quantization=quantization,
compilation_config=compilation_config, compilation_config=compilation_config,
) )
...@@ -227,6 +234,7 @@ if __name__ == "__main__": ...@@ -227,6 +234,7 @@ if __name__ == "__main__":
args.max_model_len, args.max_model_len,
args.compilation_config, args.compilation_config,
args.gpu_memory_utilization, args.gpu_memory_utilization,
args.enable_dbo,
args.quantization, args.quantization,
), ),
) )
......
...@@ -6,7 +6,7 @@ import torch ...@@ -6,7 +6,7 @@ import torch
from tests.v1.attention.test_attention_backends import BATCH_SPECS from tests.v1.attention.test_attention_backends import BATCH_SPECS
from tests.v1.attention.utils import create_common_attn_metadata from tests.v1.attention.utils import create_common_attn_metadata
from vllm.v1.attention.backends.utils import (UbatchSlice, from vllm.v1.attention.backends.utils import (UBatchSlice,
_make_metadata_with_slice, _make_metadata_with_slice,
slice_query_start_locs, slice_query_start_locs,
split_attn_metadata) split_attn_metadata)
...@@ -106,7 +106,7 @@ def mixed_small_metadata(): ...@@ -106,7 +106,7 @@ def mixed_small_metadata():
def test_make_metadata_with_slice_decode_batch(small_decode_metadata): def test_make_metadata_with_slice_decode_batch(small_decode_metadata):
"""Test slicing decode batch metadata""" """Test slicing decode batch metadata"""
# Split first request only # Split first request only
ubatch_slice = UbatchSlice(slice(0, 1), slice(0, 1)) ubatch_slice = UBatchSlice(slice(0, 1), slice(0, 1))
result = _make_metadata_with_slice(ubatch_slice, small_decode_metadata) result = _make_metadata_with_slice(ubatch_slice, small_decode_metadata)
...@@ -120,7 +120,7 @@ def test_make_metadata_with_slice_decode_batch(small_decode_metadata): ...@@ -120,7 +120,7 @@ def test_make_metadata_with_slice_decode_batch(small_decode_metadata):
def test_make_metadata_with_slice_mixed_batch(mixed_small_metadata): def test_make_metadata_with_slice_mixed_batch(mixed_small_metadata):
"""Test slicing mixed batch metadata""" """Test slicing mixed batch metadata"""
ubatch_slice = UbatchSlice(slice(1, 3), ubatch_slice = UBatchSlice(slice(1, 3),
slice(1, 7)) # Requests 1-3, tokens 1-7 slice(1, 7)) # Requests 1-3, tokens 1-7
result = _make_metadata_with_slice(ubatch_slice, mixed_small_metadata) result = _make_metadata_with_slice(ubatch_slice, mixed_small_metadata)
...@@ -137,8 +137,8 @@ def test_split_attn_metadata_decode_batch(large_decode_metadata): ...@@ -137,8 +137,8 @@ def test_split_attn_metadata_decode_batch(large_decode_metadata):
num_tokens = large_decode_metadata.num_reqs num_tokens = large_decode_metadata.num_reqs
mid_point = num_tokens // 2 mid_point = num_tokens // 2
ubatch_slices = [ ubatch_slices = [
UbatchSlice(slice(0, mid_point), slice(0, mid_point)), UBatchSlice(slice(0, mid_point), slice(0, mid_point)),
UbatchSlice(slice(mid_point, num_tokens), slice(mid_point, UBatchSlice(slice(mid_point, num_tokens), slice(mid_point,
num_tokens)), num_tokens)),
] ]
......
...@@ -365,7 +365,9 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch): ...@@ -365,7 +365,9 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
# Mock runner for attention metadata building # Mock runner for attention metadata building
proposer.runner = mock.MagicMock() proposer.runner = mock.MagicMock()
proposer.runner.attn_groups.append([mock.MagicMock()]) proposer.runner.attn_groups.append([mock.MagicMock()])
proposer.runner.attn_groups[0][0].metadata_builder = attn_metadata_builder proposer.runner.attn_groups[0][0].metadata_builders = [
attn_metadata_builder
]
result = proposer.propose(target_token_ids=target_token_ids, result = proposer.propose(target_token_ids=target_token_ids,
target_positions=target_positions, target_positions=target_positions,
...@@ -489,7 +491,9 @@ def test_propose_tree(spec_token_tree): ...@@ -489,7 +491,9 @@ def test_propose_tree(spec_token_tree):
# Mock runner for attention metadata building. # Mock runner for attention metadata building.
proposer.runner = mock.MagicMock() proposer.runner = mock.MagicMock()
proposer.runner.attn_groups.append([mock.MagicMock()]) proposer.runner.attn_groups.append([mock.MagicMock()])
proposer.runner.attn_groups[0][0].metadata_builder = attn_metadata_builder proposer.runner.attn_groups[0][0].metadata_builders = [
attn_metadata_builder
]
# Setup inputs for the proposer. # Setup inputs for the proposer.
target_token_ids = torch.randint(0, target_token_ids = torch.randint(0,
......
...@@ -2848,6 +2848,14 @@ class VllmConfig: ...@@ -2848,6 +2848,14 @@ class VllmConfig:
"when cudagraph_mode piecewise cudagraphs is used, "\ "when cudagraph_mode piecewise cudagraphs is used, "\
f"cudagraph_mode={self.compilation_config.cudagraph_mode}" f"cudagraph_mode={self.compilation_config.cudagraph_mode}"
if self.parallel_config.enable_dbo:
a2a_backend = envs.VLLM_ALL2ALL_BACKEND
assert a2a_backend == "deepep_low_latency", \
"Microbatching currently only supports the deepep_low_latency "\
f"all2all backend. {a2a_backend} is not supported. To fix set "\
"the VLLM_ALL2ALL_BACKEND environment variable to "\
"deepep_low_latency and install the DeepEP kerenls."
if not self.instance_id: if not self.instance_id:
self.instance_id = random_uuid()[:5] self.instance_id = random_uuid()[:5]
......
...@@ -137,6 +137,14 @@ class ParallelConfig: ...@@ -137,6 +137,14 @@ class ParallelConfig:
disable_custom_all_reduce: bool = False disable_custom_all_reduce: bool = False
"""Disable the custom all-reduce kernel and fall back to NCCL.""" """Disable the custom all-reduce kernel and fall back to NCCL."""
enable_dbo: bool = False
"""Enable microbatching for the model executor."""
dbo_decode_token_threshold: int = 32
"""The threshold for microbatching. If the number of tokens in the
request is greater than this threshold, microbatching will be used.
Otherwise, the request will be processed in a single batch."""
ray_workers_use_nsight: bool = False ray_workers_use_nsight: bool = False
"""Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.""" """Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler."""
......
...@@ -251,9 +251,4 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase): ...@@ -251,9 +251,4 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
logger.debug("DeepEP all2all args %s", buffer_kwargs) logger.debug("DeepEP all2all args %s", buffer_kwargs)
handle: deep_ep.Buffer = self.handle_cache.get_or_create( handle: deep_ep.Buffer = self.handle_cache.get_or_create(
buffer_kwargs, deep_ep.Buffer) buffer_kwargs, deep_ep.Buffer)
# It is dangerous to set num sms outside this function. num_sms is not
# a part of the hash-key that identifies this object. If we are in a
# situation where we make objects with different num_sms, the hash key
# in get_or_create must be updated.
handle.set_num_sms(self.num_sms)
return handle return handle
...@@ -327,6 +327,9 @@ class EngineArgs: ...@@ -327,6 +327,9 @@ class EngineArgs:
data_parallel_hybrid_lb: bool = False data_parallel_hybrid_lb: bool = False
data_parallel_backend: str = ParallelConfig.data_parallel_backend data_parallel_backend: str = ParallelConfig.data_parallel_backend
enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
enable_dbo: bool = ParallelConfig.enable_dbo
dbo_decode_token_threshold: int = \
ParallelConfig.dbo_decode_token_threshold
eplb_config: EPLBConfig = get_field(ParallelConfig, "eplb_config") eplb_config: EPLBConfig = get_field(ParallelConfig, "eplb_config")
enable_eplb: bool = ParallelConfig.enable_eplb enable_eplb: bool = ParallelConfig.enable_eplb
expert_placement_strategy: ExpertPlacementStrategy = \ expert_placement_strategy: ExpertPlacementStrategy = \
...@@ -695,6 +698,11 @@ class EngineArgs: ...@@ -695,6 +698,11 @@ class EngineArgs:
parallel_group.add_argument( parallel_group.add_argument(
"--enable-expert-parallel", "--enable-expert-parallel",
**parallel_kwargs["enable_expert_parallel"]) **parallel_kwargs["enable_expert_parallel"])
parallel_group.add_argument("--enable-dbo",
**parallel_kwargs["enable_dbo"])
parallel_group.add_argument(
"--dbo-decode-token-threshold",
**parallel_kwargs["dbo_decode_token_threshold"])
parallel_group.add_argument("--enable-eplb", parallel_group.add_argument("--enable-eplb",
**parallel_kwargs["enable_eplb"]) **parallel_kwargs["enable_eplb"])
parallel_group.add_argument("--eplb-config", parallel_group.add_argument("--eplb-config",
...@@ -1339,6 +1347,8 @@ class EngineArgs: ...@@ -1339,6 +1347,8 @@ class EngineArgs:
data_parallel_backend=self.data_parallel_backend, data_parallel_backend=self.data_parallel_backend,
data_parallel_hybrid_lb=self.data_parallel_hybrid_lb, data_parallel_hybrid_lb=self.data_parallel_hybrid_lb,
enable_expert_parallel=self.enable_expert_parallel, enable_expert_parallel=self.enable_expert_parallel,
enable_dbo=self.enable_dbo,
dbo_decode_token_threshold=self.dbo_decode_token_threshold,
enable_eplb=self.enable_eplb, enable_eplb=self.enable_eplb,
eplb_config=self.eplb_config, eplb_config=self.eplb_config,
expert_placement_strategy=self.expert_placement_strategy, expert_placement_strategy=self.expert_placement_strategy,
......
...@@ -14,6 +14,7 @@ import vllm.envs as envs ...@@ -14,6 +14,7 @@ import vllm.envs as envs
from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.v1.worker.ubatch_utils import UBatchSlices, is_second_ubatch_empty
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
...@@ -97,6 +98,53 @@ class DPMetadata: ...@@ -97,6 +98,53 @@ class DPMetadata:
dist.all_reduce(num_tokens_tensor, group=group) dist.all_reduce(num_tokens_tensor, group=group)
return num_tokens_tensor.cpu() return num_tokens_tensor.cpu()
@staticmethod
def should_ubatch_across_dp(
        should_ubatch: bool, orig_num_tokens_per_ubatch: int,
        padded_num_tokens_per_ubatch: int, dp_size: int,
        dp_rank: int) -> tuple[bool, Optional[torch.Tensor]]:
    """Agree across data-parallel ranks on whether to microbatch.

    Every rank fills column ``dp_rank`` of a zeroed ``3 x dp_size`` int32
    tensor and the tensor is all-reduced over the DP group's device
    group:

    * row 0: this rank's unpadded token count per microbatch
    * row 1: this rank's padded token count per microbatch
    * row 2: 1 if this rank wants to microbatch, otherwise 0

    Microbatching is all-or-nothing. The function gives up and returns
    ``(False, None)`` when row 2 is not 1 for every rank after the
    reduction, or when ``is_second_ubatch_empty`` returns true for the
    minimum of row 0 and the maximum of row 1.

    Args:
        should_ubatch: Whether the calling rank wants to microbatch.
        orig_num_tokens_per_ubatch: Calling rank's token count per
            microbatch before padding.
        padded_num_tokens_per_ubatch: Calling rank's token count per
            microbatch after padding.
        dp_size: Number of columns in the exchanged tensor (one per DP
            rank).
        dp_rank: Column index written by the calling rank.

    Returns:
        ``(True, padded_counts)`` where ``padded_counts`` is row 1 of the
        reduced tensor moved to the CPU, or ``(False, None)`` when
        microbatching is abandoned.
    """
    # Imported inside the function, as in the original code
    # (NOTE(review): presumably to avoid a circular import - confirm).
    from vllm.distributed.parallel_state import get_dp_group

    exchange = torch.zeros(3,
                           dp_size,
                           device=current_platform.device_type,
                           dtype=torch.int32)
    exchange[0, dp_rank] = orig_num_tokens_per_ubatch
    exchange[1, dp_rank] = padded_num_tokens_per_ubatch
    exchange[2, dp_rank] = 1 if should_ubatch else 0

    # NOTE(review): relies on the default reduction op of all_reduce
    # (a sum, assuming `dist` is torch.distributed) so that each rank's
    # column survives the reduction unchanged - confirm.
    dist.all_reduce(exchange, group=get_dp_group().device_group)

    everyone_agrees: bool = bool(torch.all(exchange[2] == 1).item())
    if not everyone_agrees:
        return False, None

    padded_counts = exchange[1, :]
    smallest_orig = int(exchange[0, :].min().item())
    largest_padded = int(padded_counts.max().item())

    if is_second_ubatch_empty(smallest_orig, largest_padded):
        logger.debug("Aborting ubatching %s %s", smallest_orig,
                     largest_padded)
        return False, None

    return everyone_agrees, padded_counts.cpu()
@staticmethod @staticmethod
def make( def make(
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
...@@ -119,14 +167,15 @@ class DPMetadata: ...@@ -119,14 +167,15 @@ class DPMetadata:
# If num_tokens_across_dp is None, it will be computed by all_reduce # If num_tokens_across_dp is None, it will be computed by all_reduce
# Otherwise, num_tokens_across_dp[dp_rank] should be equal to batchsize # Otherwise, num_tokens_across_dp[dp_rank] should be equal to batchsize
assert (num_tokens_across_dp is None assert (num_tokens_across_dp is None or num_tokens_across_dp[dp_rank]
or num_tokens_across_dp[dp_rank] == batchsize) == batchsize), f"{num_tokens_across_dp[dp_rank]} {batchsize}"
if num_tokens_across_dp is None: if num_tokens_across_dp is None:
num_tokens_across_dp = DPMetadata.num_tokens_across_dp( num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
batchsize, dp_size, dp_rank) batchsize, dp_size, dp_rank)
max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp) max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp)
cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_across_dp, dim=0) cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_across_dp, dim=0)
return DPMetadata(max_tokens_across_dp_cpu, cu_tokens_across_dp_cpu) return DPMetadata(max_tokens_across_dp_cpu, cu_tokens_across_dp_cpu,
num_tokens_across_dp)
@contextmanager @contextmanager
def chunked_sizes(self, max_chunk_size_per_rank: int, chunk_idx: int): def chunked_sizes(self, max_chunk_size_per_rank: int, chunk_idx: int):
...@@ -179,9 +228,12 @@ class ForwardContext: ...@@ -179,9 +228,12 @@ class ForwardContext:
Type AttentionMetadata for v0, Type AttentionMetadata for v0,
Type Dict[str, AttentionMetadata] for v1, map from layer_name of each Type Dict[str, AttentionMetadata] for v1, map from layer_name of each
attention layer to its attention metadata attention layer to its attention metadata
set dynamically for each forward pass Type List[Dict[str, AttentionMetadata]] for DBO. List of size two, one
for each microbatch.
Set dynamically for each forward pass
""" """
attn_metadata: Union["AttentionMetadata", dict[str, "AttentionMetadata"]] attn_metadata: Union["AttentionMetadata", dict[str, "AttentionMetadata"],
list[dict[str, "AttentionMetadata"]]]
# TODO: remove after making all virtual_engines share the same kv cache # TODO: remove after making all virtual_engines share the same kv cache
virtual_engine: int # set dynamically for each forward pass virtual_engine: int # set dynamically for each forward pass
# set dynamically for each forward pass # set dynamically for each forward pass
...@@ -191,6 +243,8 @@ class ForwardContext: ...@@ -191,6 +243,8 @@ class ForwardContext:
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE
batch_descriptor: Optional[BatchDescriptor] = None batch_descriptor: Optional[BatchDescriptor] = None
ubatch_slices: Optional[UBatchSlices] = None
def __post_init__(self): def __post_init__(self):
assert self.cudagraph_runtime_mode in [ assert self.cudagraph_runtime_mode in [
CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], \ CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], \
...@@ -208,6 +262,39 @@ def get_forward_context() -> ForwardContext: ...@@ -208,6 +262,39 @@ def get_forward_context() -> ForwardContext:
return _forward_context return _forward_context
def create_forward_context(
        attn_metadata: Any,
        vllm_config: VllmConfig,
        virtual_engine: int = 0,
        dp_metadata: Optional[DPMetadata] = None,
        cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
        batch_descriptor: Optional[BatchDescriptor] = None,
        ubatch_slices: Optional[UBatchSlices] = None):
    """Build a ``ForwardContext`` without installing it as the current one.

    ``no_compile_layers`` is taken from
    ``vllm_config.compilation_config.static_forward_context``; every other
    argument is forwarded unchanged to the ``ForwardContext`` field of the
    same name.

    Returns:
        The newly constructed ``ForwardContext``.
    """
    static_layers = vllm_config.compilation_config.static_forward_context
    context = ForwardContext(
        no_compile_layers=static_layers,
        virtual_engine=virtual_engine,
        attn_metadata=attn_metadata,
        dp_metadata=dp_metadata,
        cudagraph_runtime_mode=cudagraph_runtime_mode,
        batch_descriptor=batch_descriptor,
        ubatch_slices=ubatch_slices,
    )
    return context
@contextmanager
def override_forward_context(forward_context: Optional[ForwardContext]):
    """Temporarily replace the module-level forward context.

    On entry the global ``_forward_context`` is set to ``forward_context``
    (which may be ``None``). On exit, including when the body raises, the
    value that was in place before entry is put back.
    """
    global _forward_context
    # Remember the outgoing context and install the override in one step.
    saved_context, _forward_context = _forward_context, forward_context
    try:
        yield
    finally:
        _forward_context = saved_context
@contextmanager @contextmanager
def set_forward_context( def set_forward_context(
attn_metadata: Any, attn_metadata: Any,
...@@ -216,7 +303,8 @@ def set_forward_context( ...@@ -216,7 +303,8 @@ def set_forward_context(
num_tokens: Optional[int] = None, num_tokens: Optional[int] = None,
num_tokens_across_dp: Optional[torch.Tensor] = None, num_tokens_across_dp: Optional[torch.Tensor] = None,
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
batch_descriptor: Optional[BatchDescriptor] = None): batch_descriptor: Optional[BatchDescriptor] = None,
ubatch_slices: Optional[UBatchSlices] = None):
"""A context manager that stores the current forward context, """A context manager that stores the current forward context,
can be attention metadata, etc. can be attention metadata, etc.
Here we can inject common logic for every model forward pass. Here we can inject common logic for every model forward pass.
...@@ -225,6 +313,7 @@ def set_forward_context( ...@@ -225,6 +313,7 @@ def set_forward_context(
need_to_track_batchsize = track_batchsize and attn_metadata is not None need_to_track_batchsize = track_batchsize and attn_metadata is not None
if need_to_track_batchsize: if need_to_track_batchsize:
forward_start_time = time.perf_counter() forward_start_time = time.perf_counter()
dp_metadata: Optional[DPMetadata] = None dp_metadata: Optional[DPMetadata] = None
if vllm_config.parallel_config.data_parallel_size > 1 and ( if vllm_config.parallel_config.data_parallel_size > 1 and (
attn_metadata is not None or num_tokens is not None): attn_metadata is not None or num_tokens is not None):
...@@ -232,20 +321,14 @@ def set_forward_context( ...@@ -232,20 +321,14 @@ def set_forward_context(
attn_metadata, num_tokens or 0, attn_metadata, num_tokens or 0,
num_tokens_across_dp) num_tokens_across_dp)
global _forward_context forward_context = create_forward_context(attn_metadata, vllm_config,
prev_context = _forward_context virtual_engine, dp_metadata,
_forward_context = ForwardContext( cudagraph_runtime_mode,
no_compile_layers=vllm_config.compilation_config. batch_descriptor, ubatch_slices)
static_forward_context,
virtual_engine=virtual_engine,
attn_metadata=attn_metadata,
dp_metadata=dp_metadata,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=batch_descriptor,
)
try: try:
yield with override_forward_context(forward_context):
yield
finally: finally:
global last_logging_time, batchsize_logging_interval global last_logging_time, batchsize_logging_interval
if need_to_track_batchsize: if need_to_track_batchsize:
...@@ -282,5 +365,3 @@ def set_forward_context( ...@@ -282,5 +365,3 @@ def set_forward_context(
logger.info(("Batchsize forward time stats " logger.info(("Batchsize forward time stats "
"(batchsize, count, median_time(ms)): %s"), "(batchsize, count, median_time(ms)): %s"),
forward_stats) forward_stats)
_forward_context = prev_context
...@@ -191,7 +191,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -191,7 +191,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: Optional[torch.Tensor], expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
) -> Callable: ) -> tuple[Callable, mk.ReceiverType]:
if apply_router_weight_on_input: if apply_router_weight_on_input:
topk = topk_ids.size(1) topk = topk_ids.size(1)
...@@ -217,13 +217,14 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -217,13 +217,14 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
a1q_scale = None a1q_scale = None
a1_post_scale = a1_scale a1_post_scale = a1_scale
return self._do_dispatch(tokens=a1q, return (lambda *args: None,
token_scales=a1q_scale, self._do_dispatch(tokens=a1q,
rank_topk_ids=topk_ids, token_scales=a1q_scale,
rank_topk_weights=topk_weights, rank_topk_ids=topk_ids,
num_experts=num_experts, rank_topk_weights=topk_weights,
a1_scale=a1_post_scale, num_experts=num_experts,
quant_config=quant_config) a1_scale=a1_post_scale,
quant_config=quant_config))
def prepare( def prepare(
self, self,
...@@ -237,10 +238,11 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -237,10 +238,11 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType: ) -> mk.PrepareResultType:
receiver = self.prepare_async(a1, a1_scale, a2_scale, topk_weights, (_, receiver) = self.prepare_async(a1, a1_scale, a2_scale,
topk_ids, num_experts, expert_map, topk_weights, topk_ids, num_experts,
apply_router_weight_on_input, expert_map,
quant_config) apply_router_weight_on_input,
quant_config)
return receiver() return receiver()
def finalize( def finalize(
......
...@@ -11,6 +11,9 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( ...@@ -11,6 +11,9 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate) TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import ( from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input, normalize_batched_scales_shape) moe_kernel_quantize_input, normalize_batched_scales_shape)
from vllm.v1.worker.ubatching import (dbo_current_ubatch_id, dbo_enabled,
dbo_maybe_run_recv_hook,
dbo_register_recv_hook, dbo_yield)
# DeepEP kernels quantize dispatch inputs in 128 element chunks. # DeepEP kernels quantize dispatch inputs in 128 element chunks.
DEEPEP_QUANT_BLOCK_SIZE = 128 DEEPEP_QUANT_BLOCK_SIZE = 128
...@@ -55,7 +58,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -55,7 +58,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# The dispatch function returns a handle that the combine function # The dispatch function returns a handle that the combine function
# requires. We store the handle here so it is available to the # requires. We store the handle here so it is available to the
# combine function. # combine function.
self.handle = None self.handles: list[Optional[tuple]] = [None, None]
self.num_dispatchers_ = num_dispatchers self.num_dispatchers_ = num_dispatchers
def num_dispatchers(self) -> int: def num_dispatchers(self) -> int:
...@@ -123,13 +126,15 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -123,13 +126,15 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: Optional[torch.Tensor], expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
) -> mk.ReceiverType: ) -> tuple[Callable, mk.ReceiverType]:
hidden_size = a1.size(1) hidden_size = a1.size(1)
assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, \ assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, \
(f"Hidden Size {hidden_size} not in supported list of hidden sizes" (f"Hidden Size {hidden_size} not in supported list of hidden sizes"
f"{self.SUPPORTED_HIDDEN_SIZES}") f"{self.SUPPORTED_HIDDEN_SIZES}")
a2a_idx = dbo_current_ubatch_id()
if self.use_fp8_dispatch: if self.use_fp8_dispatch:
assert hidden_size % 128 == 0, \ assert hidden_size % 128 == 0, \
"DeepEP kernels quantize the inputs in blocks of shape 128" "DeepEP kernels quantize the inputs in blocks of shape 128"
...@@ -148,7 +153,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -148,7 +153,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
a1 = a1 * topk_weights.to(a1.dtype) a1 = a1 * topk_weights.to(a1.dtype)
# Dispatch # Dispatch
expert_x, expert_num_tokens, self.handle, event, hook = \ expert_x, expert_num_tokens, handle, _, hook= \
self.buffer.low_latency_dispatch(a1, self.buffer.low_latency_dispatch(a1,
topk_ids, topk_ids,
self.max_tokens_per_rank, self.max_tokens_per_rank,
...@@ -156,21 +161,19 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -156,21 +161,19 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
use_fp8=self.use_fp8_dispatch, use_fp8=self.use_fp8_dispatch,
async_finish=False, async_finish=False,
return_recv_hook=True) return_recv_hook=True)
self.handles[a2a_idx] = handle
return lambda: self._receiver(hook, expert_x, expert_num_tokens, return (hook, lambda: self._receiver(expert_x, expert_num_tokens,
a1_scale, a1.dtype, quant_config) a1_scale, a1.dtype, quant_config))
def _receiver( def _receiver(
self, self,
hook: Callable,
expert_x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], expert_x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
expert_num_tokens: torch.Tensor, expert_num_tokens: torch.Tensor,
a1_scale, a1_scale,
a1_dtype, a1_dtype,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType: ) -> mk.PrepareResultType:
hook()
expert_x, expert_x_scale = self._do_quant( expert_x, expert_x_scale = self._do_quant(
expert_x, a1_scale, a1_dtype, quant_config.quant_dtype, expert_x, a1_scale, a1_dtype, quant_config.quant_dtype,
quant_config.per_act_token_quant, quant_config.block_shape) quant_config.per_act_token_quant, quant_config.block_shape)
...@@ -192,10 +195,12 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -192,10 +195,12 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType: ) -> mk.PrepareResultType:
receiver = self.prepare_async(a1, a1_scale, a2_scale, topk_weights, hook, receiver = self.prepare_async(a1, a1_scale, a2_scale,
topk_ids, num_experts, expert_map, topk_weights, topk_ids,
apply_router_weight_on_input, num_experts, expert_map,
quant_config) apply_router_weight_on_input,
quant_config)
hook()
return receiver() return receiver()
def finalize( def finalize(
...@@ -210,7 +215,11 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -210,7 +215,11 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
assert isinstance( assert isinstance(
weight_and_reduce_impl, TopKWeightAndReduceDelegate weight_and_reduce_impl, TopKWeightAndReduceDelegate
), ("Weight application and reduction happens in the combine kernel.") ), ("Weight application and reduction happens in the combine kernel.")
assert self.handle is not None
a2a_idx = dbo_current_ubatch_id()
do_recv_hook = dbo_enabled()
handle = self.handles[a2a_idx]
assert handle is not None
combine_topk_weights = topk_weights combine_topk_weights = topk_weights
if apply_router_weight_on_input: if apply_router_weight_on_input:
...@@ -218,12 +227,16 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -218,12 +227,16 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
combine_topk_weights = torch.ones_like(topk_weights) combine_topk_weights = torch.ones_like(topk_weights)
# TODO (varun) : Enable zero copy mode # TODO (varun) : Enable zero copy mode
_, event, hook = self.buffer.low_latency_combine( dbo_maybe_run_recv_hook()
_, _, recv_hook = self.buffer.low_latency_combine(
fused_expert_output, fused_expert_output,
topk_ids, topk_ids,
combine_topk_weights, combine_topk_weights,
self.handle, handle,
async_finish=False, async_finish=False,
zero_copy=False, zero_copy=False,
return_recv_hook=False, return_recv_hook=do_recv_hook,
out=output) out=output)
if recv_hook is not None:
dbo_register_recv_hook(recv_hook)
dbo_yield()
...@@ -38,6 +38,7 @@ from vllm.platforms import current_platform ...@@ -38,6 +38,7 @@ from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum from vllm.platforms.interface import CpuArchEnum
from vllm.utils import (cdiv, direct_register_custom_op, has_deep_ep, has_pplx, from vllm.utils import (cdiv, direct_register_custom_op, has_deep_ep, has_pplx,
round_up) round_up)
from vllm.v1.worker.ubatching import dbo_current_ubatch_id
if current_platform.is_cuda_alike(): if current_platform.is_cuda_alike():
from .fused_batched_moe import BatchedTritonExperts from .fused_batched_moe import BatchedTritonExperts
...@@ -992,16 +993,28 @@ class FusedMoE(CustomOp): ...@@ -992,16 +993,28 @@ class FusedMoE(CustomOp):
if (self.moe_parallel_config.use_pplx_kernels if (self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels or self.moe_parallel_config.use_deepep_ll_kernels
or self.moe_config.use_flashinfer_cutlass_kernels): or self.moe_config.use_flashinfer_cutlass_kernels):
self.batched_hidden_states = torch.zeros( if vllm_config.parallel_config.enable_dbo:
(moe.max_num_tokens, self.hidden_size), self.batched_hidden_states = torch.zeros(
dtype=moe.in_dtype, (2, moe.max_num_tokens, self.hidden_size),
device=torch.cuda.current_device()) dtype=moe.in_dtype,
device=torch.cuda.current_device())
# Note here we use `num_experts` which is logical expert count
self.batched_router_logits = torch.zeros(
(2, moe.max_num_tokens, num_experts),
dtype=moe.in_dtype,
device=torch.cuda.current_device())
else:
self.batched_hidden_states = torch.zeros(
(moe.max_num_tokens, self.hidden_size),
dtype=moe.in_dtype,
device=torch.cuda.current_device())
# Note here we use `num_experts` which is logical expert count # Note here we use `num_experts` which is logical expert count
self.batched_router_logits = torch.zeros( self.batched_router_logits = torch.zeros(
(moe.max_num_tokens, num_experts), (moe.max_num_tokens, num_experts),
dtype=moe.in_dtype, dtype=moe.in_dtype,
device=torch.cuda.current_device()) device=torch.cuda.current_device())
@property @property
def shared_experts(self) -> Optional[torch.nn.Module]: def shared_experts(self) -> Optional[torch.nn.Module]:
...@@ -1708,14 +1721,29 @@ class FusedMoE(CustomOp): ...@@ -1708,14 +1721,29 @@ class FusedMoE(CustomOp):
hidden_states = full_hidden_states[chunk_start:chunk_end, :] hidden_states = full_hidden_states[chunk_start:chunk_end, :]
router_logits = full_router_logits[chunk_start:chunk_end, :] router_logits = full_router_logits[chunk_start:chunk_end, :]
assert (self.batched_hidden_states.size(0) # type: ignore assert self.batched_hidden_states is not None
assert self.batched_router_logits is not None
# This is only true when DBO has been enabled in the config.
# Both tensors will have an outer dimension for the ubatch id
if self.batched_hidden_states.dim() == 3:
assert self.batched_router_logits.dim() == 3
batch_buffer_idx = dbo_current_ubatch_id()
batched_hidden_states = self.batched_hidden_states[
batch_buffer_idx, :]
batched_router_logits = self.batched_router_logits[
batch_buffer_idx, :]
else:
batched_hidden_states = self.batched_hidden_states
batched_router_logits = self.batched_router_logits
assert (batched_hidden_states.size(0) # type: ignore
>= chunk_size) >= chunk_size)
assert (self.batched_router_logits.size(0) # type: ignore assert (batched_router_logits.size(0) # type: ignore
>= chunk_size) >= chunk_size)
staged_hidden_states = self.batched_hidden_states[: staged_hidden_states = batched_hidden_states[:
chunk_size, :] # type: ignore chunk_size, :] # type: ignore
staged_router_logits = self.batched_router_logits[: staged_router_logits = batched_router_logits[:
chunk_size, :] # type: ignore chunk_size, :] # type: ignore
staged_hidden_states.copy_(hidden_states, non_blocking=True) staged_hidden_states.copy_(hidden_states, non_blocking=True)
staged_router_logits.copy_(router_logits, non_blocking=True) staged_router_logits.copy_(router_logits, non_blocking=True)
......
...@@ -13,6 +13,8 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig ...@@ -13,6 +13,8 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.utils import ( # yapf: disable from vllm.model_executor.layers.fused_moe.utils import ( # yapf: disable
_resize_cache, count_expert_num_tokens) _resize_cache, count_expert_num_tokens)
from vllm.utils import cdiv from vllm.utils import cdiv
from vllm.v1.worker.ubatching import (dbo_enabled, dbo_maybe_run_recv_hook,
dbo_register_recv_hook, dbo_yield)
# #
# This file defines a set of base classes used to make MoE kernels more modular. # This file defines a set of base classes used to make MoE kernels more modular.
...@@ -226,7 +228,7 @@ class FusedMoEPrepareAndFinalize(ABC): ...@@ -226,7 +228,7 @@ class FusedMoEPrepareAndFinalize(ABC):
expert_map: Optional[torch.Tensor], expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
) -> ReceiverType: ) -> tuple[Callable, ReceiverType]:
""" """
Perform any quantization (and/or) dispatching needed for this kernel Perform any quantization (and/or) dispatching needed for this kernel
but do not wait for results from other workers. but do not wait for results from other workers.
...@@ -496,6 +498,23 @@ def _chunk_scales(scales: Optional[torch.Tensor], start: int, ...@@ -496,6 +498,23 @@ def _chunk_scales(scales: Optional[torch.Tensor], start: int,
return None return None
class SharedResizableBuffer:
    """Flat scratch tensor that is allocated lazily and reused across calls.

    The backing storage is only replaced when a request needs more elements
    than it currently holds; smaller requests are served as views into the
    existing storage.
    """

    def __init__(self):
        # No storage until the first `get` call.
        self.buffer = None

    def get(self, shape: tuple[int, ...], device: torch.device,
            dtype: torch.dtype):
        """Return a view of the requested ``shape`` over the shared storage.

        Args:
            shape: Shape of the view to hand out.
            device: Device the storage is expected to live on.
            dtype: Element type the storage is expected to have.

        Returns:
            A tensor of ``shape`` viewing the first ``prod(shape)`` elements
            of the backing buffer. Views from successive calls alias the
            same memory unless the buffer had to grow in between.

        Raises:
            AssertionError: if the existing storage is large enough but its
                device or dtype differs from the one requested.
        """
        needed = prod(shape)
        too_small = self.buffer is None or self.buffer.numel() < needed
        if too_small:
            # Grow by replacing the storage; previous contents are dropped.
            self.buffer = torch.empty(needed, device=device, dtype=dtype)
        storage = self.buffer
        assert storage.device == device, \
            f"Buffer device mismatch: {self.buffer.device} != {device}"
        assert storage.dtype == dtype, \
            f"Buffer dtype mismatch: {self.buffer.dtype} != {dtype}"
        flat = storage[:needed]
        return flat.view(*shape)
@final @final
class FusedMoEModularKernel(torch.nn.Module): class FusedMoEModularKernel(torch.nn.Module):
""" """
...@@ -509,6 +528,9 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -509,6 +528,9 @@ class FusedMoEModularKernel(torch.nn.Module):
layer due to any layer specific state that may be used by the component layer due to any layer specific state that may be used by the component
objects. objects.
""" """
fused_out_buffer = SharedResizableBuffer()
workspace13_buffer = SharedResizableBuffer()
workspace2_buffer = SharedResizableBuffer()
def __init__( def __init__(
self, self,
...@@ -559,12 +581,12 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -559,12 +581,12 @@ class FusedMoEModularKernel(torch.nn.Module):
# We can reuse the memory between cache1 and cache3 because by the # We can reuse the memory between cache1 and cache3 because by the
# time we need cache3, we're done with cache1. # time we need cache3, we're done with cache1.
workspace13 = torch.empty(prod(workspace13_shape), workspace13 = self.workspace13_buffer.get(workspace13_shape,
device=a1.device, device=a1.device,
dtype=workspace_dtype) dtype=workspace_dtype)
workspace2 = torch.empty(prod(workspace2_shape), workspace2 = self.workspace2_buffer.get(workspace2_shape,
device=a1.device, device=a1.device,
dtype=workspace_dtype) dtype=workspace_dtype)
assert fused_out is None or fused_out.shape == fused_out_shape, ( assert fused_out is None or fused_out.shape == fused_out_shape, (
f"fused_out {fused_out.shape} but expected {fused_out_shape}") f"fused_out {fused_out.shape} but expected {fused_out_shape}")
...@@ -656,9 +678,9 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -656,9 +678,9 @@ class FusedMoEModularKernel(torch.nn.Module):
(_, _, fused_out_shape, _) = self.fused_experts.workspace_shapes( (_, _, fused_out_shape, _) = self.fused_experts.workspace_shapes(
a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts, a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts,
expert_tokens_meta) expert_tokens_meta)
fused_out = torch.empty(fused_out_shape, fused_out = self.fused_out_buffer.get(fused_out_shape,
device=a1q.device, device=a1q.device,
dtype=a1.dtype) dtype=a1.dtype)
def slice_input_tensors( def slice_input_tensors(
chunk_idx: int chunk_idx: int
...@@ -801,8 +823,10 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -801,8 +823,10 @@ class FusedMoEModularKernel(torch.nn.Module):
shared_output: torch.Tensor shared_output: torch.Tensor
if (not self.prepare_finalize.supports_async() if not self.prepare_finalize.supports_async():
or self.shared_experts is None): # We shouldn't be running an a2a kernel that doesn't
# support async prepare/finalize
assert not dbo_enabled()
# Run shared experts serially with dispatch. # Run shared experts serially with dispatch.
if self.shared_experts is not None: if self.shared_experts is not None:
...@@ -822,7 +846,8 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -822,7 +846,8 @@ class FusedMoEModularKernel(torch.nn.Module):
) )
else: else:
# Overlap shared expert compute with all2all dispatch. # Overlap shared expert compute with all2all dispatch.
receiver = self.prepare_finalize.prepare_async( dbo_maybe_run_recv_hook()
hook, receiver = self.prepare_finalize.prepare_async(
a1, a1,
a1_scale, a1_scale,
a2_scale, a2_scale,
...@@ -834,8 +859,16 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -834,8 +859,16 @@ class FusedMoEModularKernel(torch.nn.Module):
self.fused_experts.quant_config, self.fused_experts.quant_config,
) )
assert self.shared_experts is not None if self.shared_experts is not None:
shared_output = self.shared_experts(a1) shared_output = self.shared_experts(a1)
# If DBO is being used, register the hook with the ubatch context
# and call it in dbo_maybe_run_recv_hook instead of passing it to
# the receiver.
dbo_register_recv_hook(hook)
dbo_yield()
if not dbo_enabled():
hook()
(a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids, (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
_expert_topk_weights) = receiver() _expert_topk_weights) = receiver()
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, Union from typing import Callable, Optional, Union
import pplx_kernels as pplx import pplx_kernels as pplx
import torch import torch
...@@ -103,7 +103,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -103,7 +103,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: Optional[torch.Tensor], expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
) -> mk.ReceiverType: ) -> tuple[Callable, mk.ReceiverType]:
num_tokens = a1.size(0) # M num_tokens = a1.size(0) # M
hidden_dim = a1.size(-1) # K hidden_dim = a1.size(-1) # K
...@@ -214,41 +214,33 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -214,41 +214,33 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
do_recv=False, do_recv=False,
) )
return lambda: self._receiver( hook = lambda: self.a2a.dispatch(
out_expert_num_tokens=expert_num_tokens,
out_expert_x=expert_x,
out_expert_x_scale=expert_x_scale,
dp_x=a1q,
dp_x_scale=a1q_scale,
indices=topk_ids,
bound_m=bound_m,
do_send=False,
do_recv=True,
)
return (hook, lambda: self._receiver(
expert_num_tokens, expert_num_tokens,
expert_x, expert_x,
expert_x_scale, expert_x_scale,
a1q,
a1q_scale,
topk_ids,
bound_m,
orig_a_scale_block_shape, orig_a_scale_block_shape,
) ))
def _receiver( def _receiver(
self, self,
expert_num_tokens: torch.Tensor, expert_num_tokens: torch.Tensor,
expert_x: torch.Tensor, expert_x: torch.Tensor,
expert_x_scale: Optional[torch.Tensor], expert_x_scale: Optional[torch.Tensor],
a1q: torch.Tensor,
a1q_scale: Optional[torch.Tensor],
topk_ids: torch.Tensor,
bound_m: Optional[torch.Tensor],
orig_a_scale_block_shape: Optional[int], orig_a_scale_block_shape: Optional[int],
) -> mk.PrepareResultType: ) -> mk.PrepareResultType:
self.a2a.dispatch(
out_expert_num_tokens=expert_num_tokens,
out_expert_x=expert_x,
out_expert_x_scale=expert_x_scale,
dp_x=a1q,
dp_x_scale=a1q_scale,
indices=topk_ids,
bound_m=bound_m,
do_send=False,
do_recv=True,
)
if expert_x_scale is not None: if expert_x_scale is not None:
expert_x_scale = expert_x_scale[:, :, :orig_a_scale_block_shape] expert_x_scale = expert_x_scale[:, :, :orig_a_scale_block_shape]
assert expert_x_scale.ndim == 3 assert expert_x_scale.ndim == 3
...@@ -270,7 +262,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -270,7 +262,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType: ) -> mk.PrepareResultType:
receiver = self.prepare_async( hook, receiver = self.prepare_async(
a1, a1,
a1_scale, a1_scale,
a2_scale, a2_scale,
...@@ -281,6 +273,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -281,6 +273,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
apply_router_weight_on_input, apply_router_weight_on_input,
quant_config, quant_config,
) )
hook()
return receiver() return receiver()
def finalize( def finalize(
......
...@@ -28,6 +28,7 @@ from vllm.distributed.kv_transfer.kv_connector.utils import ( ...@@ -28,6 +28,7 @@ from vllm.distributed.kv_transfer.kv_connector.utils import (
get_kv_connector_cache_layout) get_kv_connector_cache_layout)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.ubatch_utils import UBatchSlice
logger = init_logger(__name__) logger = init_logger(__name__)
KVCacheLayoutType = Literal["NHD", "HND"] KVCacheLayoutType = Literal["NHD", "HND"]
...@@ -81,12 +82,6 @@ class CommonAttentionMetadata: ...@@ -81,12 +82,6 @@ class CommonAttentionMetadata:
encoder_seq_lens: Optional[np.ndarray] = None encoder_seq_lens: Optional[np.ndarray] = None
@dataclass
class UbatchSlice:
request_slice: slice
token_slice: slice
def slice_query_start_locs( def slice_query_start_locs(
query_start_loc: torch.Tensor, query_start_loc: torch.Tensor,
request_slice: slice, request_slice: slice,
...@@ -103,7 +98,7 @@ def slice_query_start_locs( ...@@ -103,7 +98,7 @@ def slice_query_start_locs(
def _make_metadata_with_slice( def _make_metadata_with_slice(
ubatch_slice: UbatchSlice, ubatch_slice: UBatchSlice,
attn_metadata: CommonAttentionMetadata) -> CommonAttentionMetadata: attn_metadata: CommonAttentionMetadata) -> CommonAttentionMetadata:
""" """
This function creates a new CommonAttentionMetadata that corresponds to This function creates a new CommonAttentionMetadata that corresponds to
...@@ -133,6 +128,11 @@ def _make_metadata_with_slice( ...@@ -133,6 +128,11 @@ def _make_metadata_with_slice(
torch.max(torch.abs(query_start_loc_cpu[1:] - torch.max(torch.abs(query_start_loc_cpu[1:] -
query_start_loc_cpu[:-1])).item()) query_start_loc_cpu[:-1])).item())
# This is to account for the case where we are in a dummy
# run and query_start_loc_cpu is full of 0s
if max_query_len == 0:
max_query_len = attn_metadata.max_query_len
block_table_tensor = attn_metadata.block_table_tensor[request_slice] block_table_tensor = attn_metadata.block_table_tensor[request_slice]
slot_mapping = attn_metadata.slot_mapping[token_slice] slot_mapping = attn_metadata.slot_mapping[token_slice]
...@@ -152,12 +152,12 @@ def _make_metadata_with_slice( ...@@ -152,12 +152,12 @@ def _make_metadata_with_slice(
def split_attn_metadata( def split_attn_metadata(
ubatch_slices: list[UbatchSlice], ubatch_slices: list[UBatchSlice],
common_attn_metadata: CommonAttentionMetadata, common_attn_metadata: CommonAttentionMetadata,
) -> list[CommonAttentionMetadata]: ) -> list[CommonAttentionMetadata]:
""" """
Creates a new CommonAttentionMetadata instance that corresponds to the Creates a new CommonAttentionMetadata instance that corresponds to the
requests for each UbatchSlice in ubatch_slices. requests for each UBatchSlice in ubatch_slices.
Note: This function does not modify common_attn_metadata Note: This function does not modify common_attn_metadata
""" """
......
...@@ -27,6 +27,7 @@ from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata ...@@ -27,6 +27,7 @@ from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.ubatching import dbo_current_ubatch_id
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -179,9 +180,11 @@ class EagleProposer: ...@@ -179,9 +180,11 @@ class EagleProposer:
assert self.runner is not None assert self.runner is not None
# FIXME: need to consider multiple kv_cache_groups # FIXME: need to consider multiple kv_cache_groups
attn_metadata = self.runner.attn_groups[0][0].metadata_builder\ ubatch_id = dbo_current_ubatch_id()
.build_for_drafting(common_attn_metadata=common_attn_metadata, attn_metadata_builder = \
draft_index=0) self.runner.attn_groups[0][0].metadata_builders[ubatch_id]
attn_metadata = attn_metadata_builder.build_for_drafting(
common_attn_metadata=common_attn_metadata, draft_index=0)
# At this moment, we assume all eagle layers belong to the same KV # At this moment, we assume all eagle layers belong to the same KV
# cache group, thus using the same attention metadata. # cache group, thus using the same attention metadata.
...@@ -355,8 +358,9 @@ class EagleProposer: ...@@ -355,8 +358,9 @@ class EagleProposer:
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
common_attn_metadata: CommonAttentionMetadata, common_attn_metadata: CommonAttentionMetadata,
) -> list[torch.Tensor]: ) -> list[torch.Tensor]:
ubatch_id = dbo_current_ubatch_id()
tree_attn_metadata_builder = \ tree_attn_metadata_builder = \
self.runner.attn_groups[0][0].metadata_builder self.runner.attn_groups[0][0].metadata_builders[ubatch_id]
assert isinstance(tree_attn_metadata_builder, assert isinstance(tree_attn_metadata_builder,
TreeAttentionMetadataBuilder) TreeAttentionMetadataBuilder)
......
...@@ -64,8 +64,13 @@ class CPUModelRunner(GPUModelRunner): ...@@ -64,8 +64,13 @@ class CPUModelRunner(GPUModelRunner):
if not self.attn_groups[0]: if not self.attn_groups[0]:
return return
mb = getattr(self.attn_groups[0][0], "metadata_builder", None) mb = getattr(self.attn_groups[0][0], "metadata_builders", None)
if not isinstance(mb, TorchSDPAMetadataBuilderV1): if isinstance(mb, list):
if not isinstance(mb[0], TorchSDPAMetadataBuilderV1):
return
mb[0].reorder_batch(self.input_batch, scheduler_output)
return
elif not isinstance(mb, TorchSDPAMetadataBuilderV1):
# Encoder-only / rerank models do not benefit from reordering, # Encoder-only / rerank models do not benefit from reordering,
# so we safely skip here. # so we safely skip here.
return return
......
This diff is collapsed.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
import threading
from typing import Any, Callable, Optional
import torch
from vllm.compilation.cuda_graph import CUDAGraphWrapper
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.forward_context import (create_forward_context, get_forward_context,
override_forward_context)
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts
logger = init_logger(__name__)
@dataclasses.dataclass
class UbatchMetadata:
    """Everything one microbatch thread needs to run the model.

    Instances are built by ``UBatchWrapper._make_ubatch_metadata``; the
    tensor fields are the full-batch inputs sliced to this microbatch's
    token range.
    """

    # Context manager entered by the microbatch thread around the model call.
    context: UBatchContext
    # Model inputs restricted to this microbatch's token slice.
    input_ids: torch.Tensor
    positions: torch.Tensor
    inputs_embeds: Optional[torch.Tensor]
    intermediate_tensors: Optional[IntermediateTensors]
    # Length of the token slice (stop - start) for this microbatch.
    num_tokens: int
@dataclasses.dataclass
class CUDAGraphMetaData:
    """Record kept for each captured microbatched CUDA graph."""

    # The captured graph; replayed on later calls with the same token count.
    cudagraph: torch.cuda.CUDAGraph
    # Per-microbatch metadata used during capture. This is the whole list
    # (one entry per microbatch), not a single entry.
    ubatch_metadata: list[UbatchMetadata]
    # Concatenated model output produced during capture; returned again
    # after each replay.
    outputs: Optional[Any] = None
class UBatchWrapper:
    """Callable wrapper that can run ``runnable`` as microbatches.

    When the current forward context carries ``ubatch_slices``, the inputs
    are split per slice and each microbatch ("ubatch") runs the model on its
    own thread; the per-ubatch outputs are concatenated along dim 0. When
    the runtime mode is ``CUDAGraphMode.FULL`` the microbatched run is
    captured into a CUDA graph keyed by token count and replayed afterwards.
    Without ``ubatch_slices`` the call is forwarded to the runnable (or to a
    ``CUDAGraphWrapper`` around it).
    """

    def __init__(self, runnable: Callable, vllm_config: VllmConfig,
                 runtime_mode: CUDAGraphMode, device: torch.cuda.device):
        """Store the runnable and set up the shared stream and barrier.

        Args:
            runnable: The callable to wrap.
            vllm_config: Engine configuration, forwarded to the forward
                contexts and to the optional ``CUDAGraphWrapper``.
            runtime_mode: When not ``CUDAGraphMode.NONE``, a
                ``CUDAGraphWrapper`` is created for the non-microbatched
                path and the global graph pool is fetched.
            device: Device on which the communication stream is created.
        """
        self.runnable = runnable
        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config
        self.comm_stream = torch.cuda.Stream(device=device)
        # Two ubatch threads plus the main thread
        self.ready_barrier = threading.Barrier(3)

        # Captured graphs, keyed by the total token count of the run.
        self.cudagraphs: dict[int, CUDAGraphMetaData] = {}

        self.cudagraph_wrapper = None
        self.graph_pool = None
        if runtime_mode is not CUDAGraphMode.NONE:
            self.cudagraph_wrapper = CUDAGraphWrapper(
                runnable, vllm_config, runtime_mode=runtime_mode)
            self.graph_pool = current_platform.get_global_graph_pool()

    def __getattr__(self, key: str):
        """Fall back to the wrapped runnable for unknown attributes."""
        # __getattr__ only runs for names missing on the instance. Reading
        # `self.runnable` here while it is itself missing (instance created
        # via __new__, copy or unpickle) would recurse without bound, so go
        # through __dict__ and fail with a plain AttributeError instead.
        if "runnable" not in self.__dict__:
            raise AttributeError(
                f"Attribute {key} requested before the runnable was set")
        runnable = self.__dict__["runnable"]
        # allow accessing the attributes of the runnable.
        if hasattr(runnable, key):
            return getattr(runnable, key)
        raise AttributeError(f"Attribute {key} not exists in the runnable of "
                             f"cudagraph wrapper: {self.runnable}")

    def unwrap(self) -> Callable:
        """Return the original runnable."""
        # in case we need to access the original runnable.
        return self.runnable

    @staticmethod
    def _concat_in_ubatch_order(
            results: list[tuple[int, torch.Tensor]]) -> torch.Tensor:
        """Concatenate ``(ubatch_id, output)`` pairs along dim 0 by id.

        Threads append their results in completion order, so the pairs are
        sorted first. Sorting is keyed on the id only so that two tensors
        are never compared with each other.
        """
        ordered = sorted(results, key=lambda item: item[0])
        return torch.cat([output for _, output in ordered], dim=0)

    def _capture_ubatches(self, ubatch_metadata, model) -> torch.Tensor:
        """
        Capture a cudagraph for a microbatched run.

        The logic here is somewhat complicated because we need to make sure that
        each of the ubatch threads initialize the cuda context before we start
        the graph capture.

        The flow is as follows:
        1. The main thread starts up each ubatch thread. Each thread will
        initialize its cuda context (torch.cuda.current_blas_handle())
        before going to sleep upon entering the ubatch_context.

        2. The main thread starts the graph capture and wakes up the first
        ubatch thread.

        3. Each ubatch thread runs the model to completion and returns the
        completed output tensors back to the main thread.

        4. The main thread stores the captured cudagraph along with its metadata
        and returns

        Args:
            ubatch_metadata: List of ``UbatchMetadata``; entries 0 and 1 are
                read for the compute stream and the token count.
            model: Callable invoked once per ubatch with keyword inputs.

        Returns:
            The per-ubatch outputs concatenated along dim 0, which is also
            stored on the cached ``CUDAGraphMetaData``.
        """

        @torch.inference_mode()
        def _capture_ubatch_thread(results, ubatch_metadata):
            ubatch_context = ubatch_metadata.context
            # Touch the BLAS handle on both streams before entering the
            # context (step 1 of the flow above).
            with torch.cuda.stream(ubatch_context.compute_stream):
                _ = torch.cuda.current_blas_handle()
            with torch.cuda.stream(ubatch_context.comm_stream):
                _ = torch.cuda.current_blas_handle()
            with ubatch_context:
                model_output = model(
                    input_ids=ubatch_metadata.input_ids,
                    positions=ubatch_metadata.positions,
                    intermediate_tensors=ubatch_metadata.intermediate_tensors,
                    inputs_embeds=ubatch_metadata.inputs_embeds,
                )

            results.append((ubatch_metadata.context.id, model_output))

        results: list[tuple[int, torch.Tensor]] = []
        compute_stream = ubatch_metadata[0].context.compute_stream
        num_tokens = ubatch_metadata[0].num_tokens + \
            ubatch_metadata[1].num_tokens

        # Ubatches will manually manage the forward context, so we override
        # it to None here so we can have it restored correctly later
        with override_forward_context(None):
            ubatch_threads = []
            for metadata in ubatch_metadata:
                thread = threading.Thread(target=_capture_ubatch_thread,
                                          args=(
                                              results,
                                              metadata,
                                          ))
                ubatch_threads.append(thread)
                thread.start()
            self.ready_barrier.wait()  # Wait for both threads to be ready

            # Capture the cudagraph
            cudagraph_metadata = CUDAGraphMetaData(
                cudagraph=torch.cuda.CUDAGraph(),
                ubatch_metadata=ubatch_metadata,
            )
            with torch.cuda.graph(cudagraph_metadata.cudagraph,
                                  stream=compute_stream,
                                  pool=self.graph_pool):
                # Wake the first ubatch thread only once capture is active.
                ubatch_metadata[0].context.cpu_wait_event.set()
                for thread in ubatch_threads:
                    thread.join()
                # The concatenation is issued while capture is active so it
                # is part of the graph and `outputs` is refreshed on replay.
                cudagraph_metadata.outputs = self._concat_in_ubatch_order(
                    results)
            self.cudagraphs[num_tokens] = cudagraph_metadata
        return cudagraph_metadata.outputs

    def _run_ubatches(self, ubatch_metadata, model) -> torch.Tensor:
        """Run every ubatch on its own thread without graph capture.

        Args:
            ubatch_metadata: List of ``UbatchMetadata``, one per ubatch.
            model: Callable invoked once per ubatch with keyword inputs.

        Returns:
            The per-ubatch outputs concatenated along dim 0 in id order.
        """

        @torch.inference_mode()
        def _ubatch_thread(results, model, ubatch_metadata):
            with ubatch_metadata.context:
                model_output = model(
                    input_ids=ubatch_metadata.input_ids,
                    positions=ubatch_metadata.positions,
                    intermediate_tensors=ubatch_metadata.intermediate_tensors,
                    inputs_embeds=ubatch_metadata.inputs_embeds,
                )
            results.append((ubatch_metadata.context.id, model_output))

        results: list[tuple[int, torch.Tensor]] = []

        # Ubatch threads will manually manage the forward context, so we
        # override it to None here so we can have it restored correctly
        # after both threads have finished
        with override_forward_context(None):
            ubatch_threads = []
            for metadata in ubatch_metadata:
                thread = threading.Thread(target=_ubatch_thread,
                                          args=(
                                              results,
                                              model,
                                              metadata,
                                          ))
                ubatch_threads.append(thread)
                thread.start()
            self.ready_barrier.wait()  # Wait for both threads to be ready
            ubatch_metadata[0].context.cpu_wait_event.set()
            for thread in ubatch_threads:
                thread.join()
        return self._concat_in_ubatch_order(results)

    def _make_ubatch_metadata(self, ubatch_slices, attn_metadata, input_ids,
                              positions, inputs_embeds, intermediate_tensors,
                              compute_stream, dp_metadata, batch_descriptor,
                              cudagraph_runtime_mode) -> list[UbatchMetadata]:
        """Build one ``UbatchMetadata`` per entry of ``ubatch_slices``.

        A forward context is created for every ubatch (using
        ``attn_metadata[i]`` when attention metadata is given), the ubatch
        contexts are created around them, and the model inputs are sliced
        with each slice's ``token_slice``.
        """
        # Create one forward context per ubatch
        forward_contexts = [
            create_forward_context(
                attn_metadata[i] if attn_metadata is not None else None,
                self.vllm_config,
                dp_metadata=dp_metadata,
                batch_descriptor=batch_descriptor,
                cudagraph_runtime_mode=cudagraph_runtime_mode)
            for i in range(len(ubatch_slices))
        ]

        ubatch_ctxs = make_ubatch_contexts(
            num_micro_batches=len(ubatch_slices),
            comm_stream=self.comm_stream,
            compute_stream=compute_stream,
            forward_contexts=forward_contexts,
            ready_barrier=self.ready_barrier)

        ubatch_metadata: list[UbatchMetadata] = []
        for i, ubatch_slice in enumerate(ubatch_slices):
            token_slice = ubatch_slice.token_slice
            (sliced_input_ids, sliced_positions, sliced_inputs_embeds,
             sliced_intermediate_tensors) = self._slice_model_inputs(
                 token_slice, input_ids, positions, inputs_embeds,
                 intermediate_tensors)
            ubatch_metadata.append(
                UbatchMetadata(
                    context=ubatch_ctxs[i],
                    input_ids=sliced_input_ids,
                    positions=sliced_positions,
                    inputs_embeds=sliced_inputs_embeds,
                    intermediate_tensors=sliced_intermediate_tensors,
                    num_tokens=token_slice.stop - token_slice.start))

        return ubatch_metadata

    def _slice_model_inputs(self, tokens_slice: slice, input_ids, positions,
                            inputs_embeds, intermediate_tensors):
        """Restrict the model inputs to ``tokens_slice``.

        Returns:
            ``(input_ids, positions, inputs_embeds, intermediate_tensors)``
            sliced to the token range; the last two stay ``None`` when the
            corresponding argument is ``None``.
        """
        sliced_input_ids = input_ids[tokens_slice]
        # if we are using mrope. Mrope adds an additional dimension to the
        # positions tensor
        if positions.ndim == 2:
            sliced_positions = positions[:, tokens_slice]
        else:
            sliced_positions = positions[tokens_slice]
        # Compare against None explicitly: the truth value of a tensor with
        # more than one element raises a RuntimeError.
        sliced_inputs_embeds = (inputs_embeds[tokens_slice]
                                if inputs_embeds is not None else None)
        sliced_intermediate_tensors = (intermediate_tensors[tokens_slice] if
                                       intermediate_tensors is not None else
                                       None)
        return (sliced_input_ids, sliced_positions, sliced_inputs_embeds,
                sliced_intermediate_tensors)

    def __call__(self, *args, **kwargs):
        """Run the wrapped model, microbatched when the context asks for it.

        Without ``ubatch_slices`` in the forward context the call goes to
        the runnable (modes NONE / PIECEWISE) or to the ``CUDAGraphWrapper``
        (any other mode). With ``ubatch_slices`` the inputs must be passed
        as the keyword arguments ``input_ids``, ``positions``,
        ``intermediate_tensors`` and ``inputs_embeds``.
        """
        forward_context = get_forward_context()
        batch_descriptor = forward_context.batch_descriptor
        ubatch_slices = forward_context.ubatch_slices
        cudagraph_runtime_mode = forward_context.cudagraph_runtime_mode

        # If there's no ubatching, just run the runnable object
        if ubatch_slices is None:
            if cudagraph_runtime_mode in (CUDAGraphMode.NONE,
                                          CUDAGraphMode.PIECEWISE):
                return self.runnable(*args, **kwargs)
            assert self.cudagraph_wrapper is not None
            return self.cudagraph_wrapper(*args, **kwargs)

        attn_metadata = forward_context.attn_metadata
        # The graph cache key is twice the first slice's length, i.e. it
        # treats both ubatches as having the first ubatch's size.
        first_slice = ubatch_slices[0].token_slice
        num_tokens = (first_slice.stop - first_slice.start) * 2
        input_ids = kwargs['input_ids']
        positions = kwargs['positions']
        intermediate_tensors = kwargs['intermediate_tensors']
        inputs_embeds = kwargs['inputs_embeds']
        compute_stream = torch.cuda.current_stream()

        dp_metadata = forward_context.dp_metadata

        # We shouldn't be here unless we are running with multiple DP ranks
        assert dp_metadata is not None

        use_full_graph = cudagraph_runtime_mode is CUDAGraphMode.FULL

        # Only replay a cached graph when this call was dispatched in FULL
        # mode; a call made in another mode must execute eagerly even if a
        # graph for the same token count was captured earlier.
        if use_full_graph and num_tokens in self.cudagraphs:
            cudagraph_metadata = self.cudagraphs[num_tokens]
            cudagraph_metadata.cudagraph.replay()
            return cudagraph_metadata.outputs

        ubatch_metadata = self._make_ubatch_metadata(
            ubatch_slices=ubatch_slices,
            attn_metadata=attn_metadata,
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
            compute_stream=compute_stream,
            dp_metadata=dp_metadata,
            batch_descriptor=batch_descriptor,
            cudagraph_runtime_mode=CUDAGraphMode.NONE)

        # `model` is not set on the wrapper itself; it is resolved through
        # __getattr__ on the wrapped runnable.
        if use_full_graph:
            return self._capture_ubatches(ubatch_metadata, self.model)
        return self._run_ubatches(ubatch_metadata, self.model)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
from vllm.config import VllmConfig
from vllm.forward_context import DPMetadata
from vllm.logger import init_logger
from vllm.utils import round_up
from vllm.v1.worker.ubatch_utils import (UBatchSlice, UBatchSlices,
is_second_ubatch_empty)
logger = init_logger(__name__)
def should_ubatch_with_num_tokens(
    should_ubatch: bool,
    orig_num_tokens_per_ubatch: int,
    padded_num_tokens_per_ubatch: int,
    vllm_config: VllmConfig,
) -> tuple[bool, Optional[torch.Tensor]]:
    """Call ``DPMetadata.should_ubatch_across_dp`` for this DP rank.

    The data-parallel size and rank are read from
    ``vllm_config.parallel_config`` and passed along with the three
    per-rank values.

    Returns:
        Whatever ``DPMetadata.should_ubatch_across_dp`` returns: a
        ``(should_ubatch, num_tokens_across_dp)`` pair.
    """
    parallel_config = vllm_config.parallel_config
    return DPMetadata.should_ubatch_across_dp(
        should_ubatch,
        orig_num_tokens_per_ubatch,
        padded_num_tokens_per_ubatch,
        parallel_config.data_parallel_size,
        parallel_config.data_parallel_rank,
    )
def get_dp_padding_ubatch(
        num_tokens_unpadded: int, num_tokens_padded: int,
        should_attempt_ubatching: bool,
        vllm_config: VllmConfig) -> tuple[bool, Optional[torch.Tensor]]:
    """
    1. Decides if each DP rank is going to microbatch. Either all ranks
    run with microbatching or none of them do. If this function decides
    not to run with microbatching, it will "abort", meaning that no padding
    information will be returned to the caller. It will return (False, None).

    2. Determines the total number of tokens that each rank will run.
    All ranks will be padded out so that they run with the same number
    of tokens.

    Returns: tuple[
        should_ubatch: Are all DP ranks going to microbatch
        num_tokens_after_padding: A tensor containing the total number of
        tokens per-microbatch for each DP rank including padding. Will be
        None if should_ubatch is False
    ]
    """
    assert num_tokens_padded >= num_tokens_unpadded
    dp_size = vllm_config.parallel_config.data_parallel_size

    # Early exit.
    if dp_size == 1:
        return False, None

    # If this DP rank doesn't want to attempt microbatching it still calls
    # the cross-rank helper, voting False with zero token counts.
    if not should_attempt_ubatching:
        vote = should_ubatch_with_num_tokens(False, 0, 0, vllm_config)
        (should_ubatch, num_tokens_across_dp) = vote
        assert should_ubatch is False
        assert num_tokens_across_dp is None
        return should_ubatch, num_tokens_across_dp

    # Round up to the next multiple of two for even divisibility
    num_tokens_padded = round_up(num_tokens_padded, 2)
    num_tokens_per_ubatch = num_tokens_padded // 2

    # Sanity Check that the existing padding isn't giving us an empty second
    # ubatch. Abort if so
    second_is_empty = is_second_ubatch_empty(num_tokens_unpadded,
                                             num_tokens_padded)
    if second_is_empty:
        logger.debug(
            "Empty second µbatch detected: unpadded tokens: %s, padded "
            "tokens: %s", num_tokens_unpadded, num_tokens_padded)
    should_ubatch = not second_is_empty

    # Note that we compute the number of padded tokens per ubatch
    (should_ubatch, num_tokens_across_dp) = should_ubatch_with_num_tokens(
        should_ubatch, num_tokens_unpadded // 2, num_tokens_per_ubatch,
        vllm_config)
    if not should_ubatch:
        assert num_tokens_across_dp is None
        return should_ubatch, num_tokens_across_dp

    assert num_tokens_across_dp is not None

    # Every rank is padded to the largest per-ubatch count seen on any rank.
    max_tokens_across_dp_cpu = int(torch.max(num_tokens_across_dp).item())
    num_tokens_after_padding = torch.full((dp_size, ),
                                          max_tokens_across_dp_cpu,
                                          dtype=torch.int32,
                                          device="cpu")
    return should_ubatch, num_tokens_after_padding
def ubatch_split(
    max_num_scheduled_tokens: int,
    num_tokens_unpadded: int,
    num_tokens_padded: int,
    vllm_config: VllmConfig,
) -> tuple[Optional[UBatchSlices], Optional[torch.Tensor]]:
    """
    Coordinates amongst all DP ranks to determine if and how the full batch
    should be split into microbatches.

    Returns: tuple[
        ubatch_slices: if this is set then all DP ranks have agreed to
        microbatch
        num_tokens_after_padding: A tensor containing the total number of
        tokens per-microbatch for each DP rank including padding. Will be
        None if ubatch_slices is None
    ]
    """
    parallel_config = vllm_config.parallel_config

    # Don't bother with the should_ubatch handshaking unless microbatching
    # is enabled
    if not parallel_config.enable_dbo:
        return (None, None)

    # Check preconditions for microbatching: enough tokens, and at most one
    # scheduled token per request.
    enough_tokens = (num_tokens_unpadded
                     >= parallel_config.dbo_decode_token_threshold)
    single_token_requests = max_num_scheduled_tokens == 1
    should_attempt_ubatching = enough_tokens and single_token_requests

    # Don't microbatch unless every other DP worker is also microbatching
    (should_ubatch, num_tokens_after_padding) = get_dp_padding_ubatch(
        num_tokens_unpadded, num_tokens_padded, should_attempt_ubatching,
        vllm_config)
    if not should_ubatch:
        return (None, None)

    # This doesn't actually pad the ubatch slices. It just initializes the
    # split point to the padded value so that padding can be applied
    # to the second ubatch in pad_out_ubatch_slice after attention
    # metadata creation
    assert num_tokens_after_padding is not None
    split_point = int(num_tokens_after_padding[0].item())
    first_half = slice(0, split_point)
    second_half = slice(split_point, num_tokens_unpadded)

    # Note there's an assumption here that there's 1 token per request,
    # which is why the same slice serves as request slice and token slice.
    ubatch_slices = [
        UBatchSlice(first_half, first_half),
        UBatchSlice(second_half, second_half),
    ]

    return (ubatch_slices, num_tokens_after_padding)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing_extensions import TypeAlias
@dataclass
class UBatchSlice:
    """Index ranges that select one microbatch out of a full batch."""

    # Range over the batch's requests belonging to this microbatch.
    request_slice: slice
    # Range over the batch's tokens belonging to this microbatch.
    token_slice: slice
# Alias for the list of UBatchSlice entries describing how one batch is
# divided into microbatches.
UBatchSlices: TypeAlias = list[UBatchSlice]
def is_second_ubatch_empty(orig_num_tokens_per_ubatch: int,
                           padded_num_tokens_per_ubatch: int) -> bool:
    """Return True when the padded count is at least twice the original.

    In that case the first microbatch alone already covers every original
    token, leaving nothing for the second one.
    """
    doubled_orig = 2 * orig_num_tokens_per_ubatch
    return doubled_orig <= padded_num_tokens_per_ubatch
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment