Unverified Commit caaa482a authored by zzzzwwjj's avatar zzzzwwjj Committed by GitHub
Browse files

[platform] Support additional forward context for OOT (#31674)


Signed-off-by: default avatarzzzzwwjj <1183291235@qq.com>
Signed-off-by: default avatarzzzzwwjj <34335947+zzzzwwjj@users.noreply.github.com>
Co-authored-by: default avatarCyrus Leung <cyrus.tl.leung@gmail.com>
parent b471aad4
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
import time import time
from collections import defaultdict from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass, field
from typing import Any, NamedTuple from typing import Any, NamedTuple
import torch import torch
...@@ -13,6 +13,7 @@ import vllm.envs as envs ...@@ -13,6 +13,7 @@ import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
from vllm.v1.worker.ubatch_utils import UBatchSlices from vllm.v1.worker.ubatch_utils import UBatchSlices
...@@ -206,6 +207,8 @@ class ForwardContext: ...@@ -206,6 +207,8 @@ class ForwardContext:
ubatch_slices: UBatchSlices | None = None ubatch_slices: UBatchSlices | None = None
additional_kwargs: dict[str, Any] = field(default_factory=dict)
def __post_init__(self): def __post_init__(self):
assert self.cudagraph_runtime_mode.valid_runtime_modes(), ( assert self.cudagraph_runtime_mode.valid_runtime_modes(), (
f"Invalid cudagraph runtime mode: {self.cudagraph_runtime_mode}" f"Invalid cudagraph runtime mode: {self.cudagraph_runtime_mode}"
...@@ -236,6 +239,7 @@ def create_forward_context( ...@@ -236,6 +239,7 @@ def create_forward_context(
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
batch_descriptor: BatchDescriptor | None = None, batch_descriptor: BatchDescriptor | None = None,
ubatch_slices: UBatchSlices | None = None, ubatch_slices: UBatchSlices | None = None,
additional_kwargs: dict[str, Any] | None = None,
): ):
return ForwardContext( return ForwardContext(
no_compile_layers=vllm_config.compilation_config.static_forward_context, no_compile_layers=vllm_config.compilation_config.static_forward_context,
...@@ -245,6 +249,7 @@ def create_forward_context( ...@@ -245,6 +249,7 @@ def create_forward_context(
cudagraph_runtime_mode=cudagraph_runtime_mode, cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=batch_descriptor, batch_descriptor=batch_descriptor,
ubatch_slices=ubatch_slices, ubatch_slices=ubatch_slices,
additional_kwargs=additional_kwargs or {},
) )
...@@ -310,6 +315,17 @@ def set_forward_context( ...@@ -310,6 +315,17 @@ def set_forward_context(
if cudagraph_runtime_mode != CUDAGraphMode.NONE and num_tokens is not None: if cudagraph_runtime_mode != CUDAGraphMode.NONE and num_tokens is not None:
batch_descriptor = batch_descriptor or BatchDescriptor(num_tokens=num_tokens) batch_descriptor = batch_descriptor or BatchDescriptor(num_tokens=num_tokens)
additional_kwargs = current_platform.set_additional_forward_context(
attn_metadata=attn_metadata,
vllm_config=vllm_config,
virtual_engine=virtual_engine,
num_tokens=num_tokens,
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=batch_descriptor,
ubatch_slices=ubatch_slices,
)
forward_context = create_forward_context( forward_context = create_forward_context(
attn_metadata, attn_metadata,
vllm_config, vllm_config,
...@@ -318,6 +334,7 @@ def set_forward_context( ...@@ -318,6 +334,7 @@ def set_forward_context(
cudagraph_runtime_mode, cudagraph_runtime_mode,
batch_descriptor, batch_descriptor,
ubatch_slices, ubatch_slices,
additional_kwargs,
) )
try: try:
...@@ -330,8 +347,6 @@ def set_forward_context( ...@@ -330,8 +347,6 @@ def set_forward_context(
# we use synchronous scheduling right now, # we use synchronous scheduling right now,
# adding a sync point here should not affect # adding a sync point here should not affect
# scheduling of the next batch # scheduling of the next batch
from vllm.platforms import current_platform
synchronize = current_platform.synchronize synchronize = current_platform.synchronize
if synchronize is not None: if synchronize is not None:
synchronize() synchronize()
......
...@@ -693,6 +693,13 @@ class Platform: ...@@ -693,6 +693,13 @@ class Platform:
""" """
return max_model_len return max_model_len
@classmethod
def set_additional_forward_context(cls, *args, **kwargs) -> dict[str, Any]:
    """Return extra, platform-specific entries for the forward context.

    This base implementation accepts arbitrary positional and keyword
    arguments, ignores all of them, and returns a new empty dict.
    ``set_forward_context`` calls this hook with keyword arguments
    (``attn_metadata``, ``vllm_config``, ``virtual_engine``,
    ``num_tokens``, ``num_tokens_across_dp``, ``cudagraph_runtime_mode``,
    ``batch_descriptor``, ``ubatch_slices``) and passes the result on as
    ``additional_kwargs`` when creating the forward context.

    NOTE(review): platform subclasses are presumably meant to override
    this when they need additional context -- confirm against callers.
    """
    # A fresh dict per call, so callers never share mutable state.
    no_extra_context: dict[str, Any] = dict()
    return no_extra_context
class UnspecifiedPlatform(Platform): class UnspecifiedPlatform(Platform):
_enum = PlatformEnum.UNSPECIFIED _enum = PlatformEnum.UNSPECIFIED
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment