Unverified Commit caaa482a authored by zzzzwwjj's avatar zzzzwwjj Committed by GitHub
Browse files

[platform] Support additional forward context for OOT (#31674)


Signed-off-by: default avatarzzzzwwjj <1183291235@qq.com>
Signed-off-by: default avatarzzzzwwjj <34335947+zzzzwwjj@users.noreply.github.com>
Co-authored-by: default avatarCyrus Leung <cyrus.tl.leung@gmail.com>
parent b471aad4
......@@ -4,7 +4,7 @@
import time
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any, NamedTuple
import torch
......@@ -13,6 +13,7 @@ import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
from vllm.v1.worker.ubatch_utils import UBatchSlices
......@@ -206,6 +207,8 @@ class ForwardContext:
ubatch_slices: UBatchSlices | None = None
additional_kwargs: dict[str, Any] = field(default_factory=dict)
def __post_init__(self):
assert self.cudagraph_runtime_mode.valid_runtime_modes(), (
f"Invalid cudagraph runtime mode: {self.cudagraph_runtime_mode}"
......@@ -236,6 +239,7 @@ def create_forward_context(
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
batch_descriptor: BatchDescriptor | None = None,
ubatch_slices: UBatchSlices | None = None,
additional_kwargs: dict[str, Any] | None = None,
):
return ForwardContext(
no_compile_layers=vllm_config.compilation_config.static_forward_context,
......@@ -245,6 +249,7 @@ def create_forward_context(
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=batch_descriptor,
ubatch_slices=ubatch_slices,
additional_kwargs=additional_kwargs or {},
)
......@@ -310,6 +315,17 @@ def set_forward_context(
if cudagraph_runtime_mode != CUDAGraphMode.NONE and num_tokens is not None:
batch_descriptor = batch_descriptor or BatchDescriptor(num_tokens=num_tokens)
additional_kwargs = current_platform.set_additional_forward_context(
attn_metadata=attn_metadata,
vllm_config=vllm_config,
virtual_engine=virtual_engine,
num_tokens=num_tokens,
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=batch_descriptor,
ubatch_slices=ubatch_slices,
)
forward_context = create_forward_context(
attn_metadata,
vllm_config,
......@@ -318,6 +334,7 @@ def set_forward_context(
cudagraph_runtime_mode,
batch_descriptor,
ubatch_slices,
additional_kwargs,
)
try:
......@@ -330,8 +347,6 @@ def set_forward_context(
# we use synchronous scheduling right now,
# adding a sync point here should not affect
# scheduling of the next batch
from vllm.platforms import current_platform
synchronize = current_platform.synchronize
if synchronize is not None:
synchronize()
......
......@@ -693,6 +693,13 @@ class Platform:
"""
return max_model_len
@classmethod
def set_additional_forward_context(cls, *args, **kwargs) -> dict[str, Any]:
"""
Set some additional forward context for the current platform if needs.
"""
return {}
class UnspecifiedPlatform(Platform):
_enum = PlatformEnum.UNSPECIFIED
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment