Unverified Commit 74f441f4 authored by fhl2000's avatar fhl2000 Committed by GitHub
Browse files

[Core] Allow full cudagraph with separate attention routines and orthogonal to...


[Core] Allow full cudagraph with separate attention routines and orthogonal to compilation, add support for FA2 and FlashInfer (#20059)
Signed-off-by: default avatarfhl <2410591650@qq.com>
Signed-off-by: default avatarfhl2000 <63384265+fhl2000@users.noreply.github.com>
Signed-off-by: default avatarLucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: default avatarLucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: default avatarLuka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: default avatarLucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: default avatarLucas Wilkinson <LucasWilkinson@users.noreply.github.com>
parent a0632a3e
......@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Optional
import torch
import vllm.envs as envs
from vllm.config import CUDAGraphMode
from vllm.logger import init_logger
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
......@@ -100,16 +101,17 @@ class XPUPlatform(Platform):
# Instances created using VllmConfig() typically have model_config as
# None by default. The modification involves adding a check to prevent
# potential null exceptions check and update model config.
if model_config is not None:
if model_config.dtype == torch.bfloat16:
bf16_supported = cls.device_support_bf16()
if not bf16_supported:
model_config.dtype = torch.float16
if not model_config.enforce_eager:
logger.warning(
"CUDA graph is not supported on XPU, fallback to the eager "
"mode.")
model_config.enforce_eager = True
if model_config is not None and model_config.dtype == torch.bfloat16 \
and not cls.device_support_bf16():
model_config.dtype = torch.float16
compilation_config = vllm_config.compilation_config
if compilation_config.cudagraph_mode is None or \
compilation_config.cudagraph_mode.max_cudagraph_mode() \
!= CUDAGraphMode.NONE:
logger.info("[XPU] CUDA graph is not supported on XPU, "
"disabling cudagraphs.")
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
# check and update parallel config
parallel_config = vllm_config.parallel_config
......
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with FlashAttention."""
from dataclasses import dataclass
from typing import ClassVar, Optional
from typing import Optional
import numpy as np
import torch
......@@ -154,9 +154,26 @@ def _get_sliding_window_configs(
class FlashAttentionMetadataBuilder(
AttentionMetadataBuilder[FlashAttentionMetadata]):
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.NEVER if get_flash_attn_version() == 2 \
else AttentionCGSupport.ALWAYS
# FA3:
# Supports full cudagraphs for all cases.
#
# FA2:
# For FA2, a graph is captured with max_query_len=1, (which is what we
# capture by default for num_tokens <= max_num_seqs when there is no
# spec-decode) then these graphs will not work for mixed prefill-decode
# (unlike FA3). This is due to special max_query_len=1 packed-GQA handling
# in FA2.
# In summary if we are running with spec decodes the graphs would
# work for mixed prefill-decode and uniform-decode. But for non-spec decodes
# the graphs would not work for mixed prefill-decode; sorta the inverse
# of UNIFORM_SINGLE_TOKEN_DECODE.
# Theres probably a better way to describe this using `AttentionCGSupport`
# but for now just set it to `UNIFORM_BATCH` to get use to drop down
# to FULL_AND_PIECEWISE.
# TODO(luka, lucas): audit FA2 as part of:
# https://github.com/vllm-project/vllm/issues/22945
cudagraph_support = AttentionCGSupport.ALWAYS \
if get_flash_attn_version() == 3 else AttentionCGSupport.UNIFORM_BATCH
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
......@@ -177,17 +194,13 @@ class FlashAttentionMetadataBuilder(
self.max_num_splits = 0 # No upper bound on the number of splits.
self.aot_schedule = (get_flash_attn_version() == 3)
self.use_full_cuda_graph = self.compilation_config.full_cuda_graph
if self.use_full_cuda_graph:
if not self.aot_schedule:
raise ValueError(
"AoT scheduling is required for full cuda graph.")
capture_sizes = self.compilation_config.cudagraph_capture_sizes
if not capture_sizes:
raise ValueError(
"cudagraph_capture_sizes should not be None when "
"full_cuda_graph is True.")
self.max_cudagraph_size = max(capture_sizes)
self.use_full_cuda_graph = \
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
if self.use_full_cuda_graph and self.aot_schedule:
self.max_cudagraph_size = self.compilation_config.max_capture_size
if self.max_cudagraph_size > 992:
# This condition derives from FA3's internal heuristic.
# TODO(woosuk): Support larger cudagraph sizes.
......@@ -310,9 +323,9 @@ class FlashAttentionMetadataBuilder(
seqlens=seq_lens,
max_seq_len=max_seq_len,
causal=causal)
if self.use_full_cuda_graph:
assert scheduler_metadata is not None
# For FA3 + full cudagraph
max_num_splits = 0
if self.use_full_cuda_graph and scheduler_metadata is not None:
n = scheduler_metadata.shape[0]
self.scheduler_metadata[:n] = scheduler_metadata
# NOTE(woosuk): We should zero out the rest of the scheduler
......@@ -322,14 +335,12 @@ class FlashAttentionMetadataBuilder(
self.scheduler_metadata[n:] = 0
scheduler_metadata = self.scheduler_metadata[:n]
max_num_splits = 0
if (self.use_full_cuda_graph
and num_actual_tokens <= self.max_cudagraph_size):
# NOTE(woosuk): Setting num_splits > 1 may increase the memory
# usage, because the intermediate buffers of size [num_splits,
# num_heads, num_tokens, head_size] are allocated. Therefore,
# we only set num_splits when using cuda graphs.
max_num_splits = self.max_num_splits
if num_actual_tokens <= self.max_cudagraph_size:
# NOTE(woosuk): Setting num_splits > 1 may increase the memory
# usage, because the intermediate buffers of size [num_splits,
# num_heads, num_tokens, head_size] are allocated. Therefore,
# we only set num_splits when using cuda graphs.
max_num_splits = self.max_num_splits
attn_metadata = FlashAttentionMetadata(
num_actual_tokens=num_actual_tokens,
......@@ -350,11 +361,6 @@ class FlashAttentionMetadataBuilder(
causal=causal)
return attn_metadata
def can_run_in_cudagraph(
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
# Full CUDA Graph always supported (FA2 support checked separately)
return True
def use_cascade_attention(self, *args, **kwargs) -> bool:
return use_cascade_attention(*args, **kwargs)
......
......@@ -17,7 +17,7 @@ from flashinfer.prefill import trtllm_batch_context_with_kv_cache
import vllm.envs as envs
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionType)
from vllm.config import VllmConfig
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.logger import init_logger
from vllm.utils import cdiv, is_pin_memory_available
from vllm.utils.flashinfer import use_trtllm_attention
......@@ -183,8 +183,8 @@ class FlashInferMetadata:
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.PURE_DECODE_ONLY
cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
reorder_batch_threshold: ClassVar[int] = 1
......@@ -203,7 +203,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self.kv_cache_spec.block_size)
max_num_reqs = vllm_config.scheduler_config.max_num_seqs
max_num_pages = max_num_reqs * max_num_pages_per_req
self.enable_cuda_graph = self.compilation_config.full_cuda_graph
self.enable_cuda_graph = self.compilation_config.cudagraph_mode.\
decode_mode() == CUDAGraphMode.FULL
if self.enable_cuda_graph:
# For full cudagraph capture, one `decode_wrapper` for each batch
# size is needed for FlashInfer.
......@@ -586,10 +587,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
return self.build(0, m)
def can_run_in_cudagraph(
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
return common_attn_metadata.max_query_len == 1
def use_cascade_attention(self, *args, **kwargs) -> bool:
if self.kv_cache_spec.dtype != self.vllm_config.model_config.dtype:
# TODO: The cascade wrapper currently does not support setting
......
......@@ -89,8 +89,8 @@ class Mamba2AttentionMetadata:
class Mamba2AttentionMetadataBuilder(
AttentionMetadataBuilder[Mamba2AttentionMetadata]):
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.PURE_DECODE_ONLY
cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
reorder_batch_threshold: ClassVar[int] = 1
......@@ -203,7 +203,3 @@ class Mamba2AttentionMetadataBuilder(
m.max_query_len = 1 # decode-only
return self.build(0, m)
def can_run_in_cudagraph(
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
return common_attn_metadata.max_query_len == 1
......@@ -575,7 +575,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
"MLA only supports decode-only full CUDAGraph capture. " \
"Make sure all cudagraph capture sizes <= max_num_seq."
m.max_query_len = 1 # decode-only
assert m.max_query_len == 1 # decode-only
return self.build(0, m)
......@@ -728,10 +728,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
return attn_metadata
def can_run_in_cudagraph(
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
return common_attn_metadata.max_query_len == 1
class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
"""
......
......@@ -22,7 +22,7 @@ logger = init_logger(__name__)
class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
# enable full CUDA Graph support for decode-only capture
attn_cudagraph_support: ClassVar[
AttentionCGSupport] = AttentionCGSupport.PURE_DECODE_ONLY
AttentionCGSupport] = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
class CutlassMLABackend(MLACommonBackend):
......
......@@ -55,8 +55,8 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.PURE_DECODE_ONLY
cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.UNIFORM_BATCH
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
......@@ -73,7 +73,7 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
device_properties = torch.cuda.get_device_properties(self.device)
num_sms = device_properties.multi_processor_count
if self.compilation_config.full_cuda_graph:
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
self.cg_buf_tile_scheduler_metadata = torch.zeros(
# Upper bound on size (<= #SMs, TileSchedulerMetaDataSize)
# TileSchedulerMetaDataSize = 8
......@@ -95,7 +95,10 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
1, # MQA for the decode path
)
if self.compilation_config.full_cuda_graph:
# TODO: we can disambiguate between decode and mixed-prefill decode here
# so we can only use the persistent buffer if a cudagraph is actually
# being used.
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
assert self.cg_buf_tile_scheduler_metadata is not None
assert self.cg_buf_num_splits is not None
......
......@@ -65,8 +65,10 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.PURE_DECODE_ONLY
# TODO(luka, lucas): audit this as part of:
# https://github.com/vllm-project/vllm/issues/22945
cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
......@@ -82,7 +84,10 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
max_num_pages = max_num_reqs * max_num_pages_per_req
# Preparing persistent buffers
if vllm_config.compilation_config.full_cuda_graph:
# TODO: we can disambiguate between decode and mixed-prefill decode here
# so we can only use the persistent buffer if a cudagraph is actually
# being used.
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
self.paged_kv_indptr = torch.zeros(max_num_reqs + 1,
dtype=torch.int32,
device=device)
......@@ -120,7 +125,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
block_table_bounds.cumsum(dim=0, dtype=torch.int32)
])
if self.compilation_config.full_cuda_graph:
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
num_actual_pages = paged_kv_indices.size(0)
......
......@@ -311,11 +311,6 @@ class AiterFlashAttentionMetadataBuilder(
)
return attn_metadata
def can_run_in_cudagraph(
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
# Full CUDA Graph always supported (FA2 support checked separately)
return True
def use_cascade_attention(self, *args, **kwargs) -> bool:
return False
......
......@@ -58,8 +58,7 @@ class TritonAttentionMetadata:
class TritonAttentionMetadataBuilder(
AttentionMetadataBuilder[TritonAttentionMetadata]):
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.ALWAYS
cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
......@@ -132,11 +131,6 @@ class TritonAttentionMetadataBuilder(
)
return attn_metadata
def can_run_in_cudagraph(
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
# Full CUDA Graph always supported
return True
class TritonAttentionBackend(AttentionBackend):
......
......@@ -158,18 +158,21 @@ class AttentionCGSupport(enum.Enum):
Here we do not consider the cascade attention, as currently
it is never cudagraph supported."""
ALWAYS = 3
"""Cudagraph always supported; supports mixed-prefill-decode"""
UNIFORM_BATCH = 2
"""Cudagraph supported for batches the only contain query lengths that are
the same, this can be used for spec-decode
i.e. "decodes" are 1 + num_speculative_tokens"""
UNIFORM_SINGLE_TOKEN_DECODE = 1
"""Cudagraph supported for batches the only contain query_len==1 decodes"""
NEVER = 0
"""NO cudagraph support"""
PURE_DECODE_ONLY = 1
"""Cudagraph supported for pure decode, need to run without
cudagraph for mixed prefill-decode batches"""
ALWAYS = 2
"""Cudagraph always supported"""
class AttentionMetadataBuilder(abc.ABC, Generic[M]):
# Does this backend/builder support CUDA Graphs for attention.
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
# Does this backend/builder support CUDA Graphs for attention (default: no).
cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.NEVER
# Does this backend/builder reorder the batch?
# If not, set this to None. Otherwise set it to the query
......@@ -199,13 +202,6 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
"""
raise NotImplementedError
def can_run_in_cudagraph(
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
"""
Can this batch (with given metadata) use CUDA Graphs for attention.
"""
return False
def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata) -> M:
"""
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
from vllm.config import CompilationLevel, CUDAGraphMode, VllmConfig
from vllm.forward_context import BatchDescriptor
from vllm.logger import init_logger
logger = init_logger(__name__)
class CudagraphDispatcher:
"""
Runtime cudagraph dispatcher to dispach keys for multiple set of cudagraphs.
The dispatcher stores two sets of dispatch keys, one for PIECEWISE and one
for FULL cudagraph runtime mode. The keys are initialized depending on
attention support and what cudagraph mode is set in CompilationConfig. The
keys stored in dispatcher are the only source of truth for valid
cudagraphs that can be dispatched at runtime.
At runtime, the dispatch method generates the runtime cudagraph mode (FULL,
PIECEWISE, or NONE for no cudagraph) and the valid key (batch descriptor)
based on the input key. After dispatching (commuicate via forward context),
the cudagraph wrappers will trust the dispatch key to do either capturing
or replaying (if mode matched), or pass through to the underlying runnable
without cudagraph (if mode no match or mode is NONE).
"""
def __init__(self, vllm_config: VllmConfig):
self.vllm_config = vllm_config
self.compilation_config = vllm_config.compilation_config
self.cudagraph_mode = self.compilation_config.cudagraph_mode
# Dict to store valid cudagraph dispatching keys.
self.cudagraph_keys: dict[CUDAGraphMode, set[BatchDescriptor]] = {
CUDAGraphMode.PIECEWISE: set(),
CUDAGraphMode.FULL: set(),
}
assert not self.cudagraph_mode.requires_piecewise_compilation() or \
(self.compilation_config.level == CompilationLevel.PIECEWISE and
self.compilation_config.splitting_ops_contain_attention()), \
"Compilation level should be CompilationLevel.PIECEWISE when "\
"cudagraph_mode piecewise cudagraphs is used, "\
f"cudagraph_mode={self.cudagraph_mode}, "\
f"compilation_level={self.compilation_config.level}, "\
f"splitting_ops={self.compilation_config.splitting_ops}"
self.keys_initialized = False
def add_cudagraph_key(self, runtime_mode: CUDAGraphMode,
batch_descriptor: BatchDescriptor):
assert runtime_mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], \
f"Invalid cudagraph runtime mode: {runtime_mode}"
self.cudagraph_keys[runtime_mode].add(batch_descriptor)
def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode,
uniform_decode_query_len: int):
# This should be called only after attention backend is initialized.
# Note: we create all valid keys possible for cudagraph but do not
# guarantee all keys would be used. For example, we create keys for
# piecewise cudagraphs when it is piecewise compilation, which is always
# valid, but for attention backend support unified routine, we may not
# trigger capturing/replaying the piecewise cudagraphs depending on
# CompilationConfig.cudagraph_mode. In addition, if we allow lazy
# capturing in future PR, some keys may never be triggered.
if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
for bs in self.compilation_config.cudagraph_capture_sizes:
self.add_cudagraph_key(
cudagraph_mode.mixed_mode(),
BatchDescriptor(num_tokens=bs, uniform_decode=False))
# if decode cudagraph mode is FULL, and we don't already have mixed
# mode full cudagraphs then add them here.
if cudagraph_mode.decode_mode() == CUDAGraphMode.FULL \
and cudagraph_mode.separate_routine():
max_num_tokens = uniform_decode_query_len * \
self.vllm_config.scheduler_config.max_num_seqs
cudagraph_capture_sizes_for_decode = [
x for x in self.compilation_config.cudagraph_capture_sizes
if x <= max_num_tokens and x >= uniform_decode_query_len
]
for bs in cudagraph_capture_sizes_for_decode:
self.add_cudagraph_key(
CUDAGraphMode.FULL,
BatchDescriptor(num_tokens=bs, uniform_decode=True))
self.keys_initialized = True
def dispatch(
self, batch_descriptor: BatchDescriptor
) -> tuple[CUDAGraphMode, Optional[BatchDescriptor]]:
"""
Given a batch descriptor, dispatch to a cudagraph mode.
A new batch descriptor is returned as we might dispatch a uniform batch
to a graph that supports a more general batch (uniform to non-uniform).
"""
# if not initialized, just skip dispatching.
if not self.keys_initialized:
logger.warning_once("cudagraph dispatching keys are not "
"initialized. No cudagraph will be used.")
return CUDAGraphMode.NONE, None
# check if key exists for full cudagraph
if batch_descriptor in self.cudagraph_keys[CUDAGraphMode.FULL]:
return CUDAGraphMode.FULL, batch_descriptor
# otherwise, check if non-uniform key exists
non_uniform_key = batch_descriptor.non_uniform
if non_uniform_key in self.cudagraph_keys[CUDAGraphMode.FULL]:
return CUDAGraphMode.FULL, non_uniform_key
# also check if non-uniform key exists for more "general"
# piecewise cudagraph
if non_uniform_key in self.cudagraph_keys[CUDAGraphMode.PIECEWISE]:
return CUDAGraphMode.PIECEWISE, non_uniform_key
# finally, just return no cudagraphs
return CUDAGraphMode.NONE, None
......@@ -21,7 +21,9 @@ from vllm.attention import Attention, AttentionType
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
from vllm.compilation.counter import compilation_counter
from vllm.config import (CompilationLevel, VllmConfig,
from vllm.compilation.cuda_graph import CUDAGraphWrapper
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
from vllm.config import (CompilationLevel, CUDAGraphMode, VllmConfig,
get_layers_from_vllm_config, update_config)
from vllm.distributed.eplb.eplb_state import EplbState
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
......@@ -29,7 +31,8 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
from vllm.distributed.parallel_state import (
get_pp_group, get_tp_group, graph_capture, is_global_first_rank,
prepare_communication_buffer_for_model)
from vllm.forward_context import DPMetadata, set_forward_context
from vllm.forward_context import (BatchDescriptor, DPMetadata,
set_forward_context)
from vllm.logger import init_logger
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaBase
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
......@@ -48,13 +51,15 @@ from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
GiB_bytes, LazyLoader, check_use_alibi, get_dtype_size,
is_pin_memory_available, round_up, supports_dynamo)
GiB_bytes, LazyLoader, cdiv, check_use_alibi,
get_dtype_size, is_pin_memory_available, round_up,
supports_dynamo)
from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend
from vllm.v1.attention.backends.utils import (
AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
make_kv_sharing_fast_prefill_attention_metadata,
reorder_batch_to_split_decodes_and_prefills)
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
from vllm.v1.kv_cache_interface import (AttentionSpec,
ChunkedLocalAttentionSpec,
FullAttentionSpec, KVCacheConfig,
......@@ -218,11 +223,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
is_spec_decode=bool(self.vllm_config.speculative_config),
)
self.use_cuda_graph = (
self.vllm_config.compilation_config.level
== CompilationLevel.PIECEWISE
and self.vllm_config.compilation_config.use_cudagraph
and not self.model_config.enforce_eager)
# TODO(woosuk): Provide an option to tune the max cudagraph batch size.
# The convention is different.
# self.cudagraph_batch_sizes sorts in ascending order.
......@@ -230,8 +230,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.cudagraph_batch_sizes = list(
reversed(self.compilation_config.cudagraph_capture_sizes))
self.full_cuda_graph = self.compilation_config.full_cuda_graph
# Cache the device properties.
self._init_device_properties()
......@@ -326,6 +324,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.kv_sharing_fast_prefill_logits_indices = torch.zeros(
self.max_num_tokens, dtype=torch.int32, device=self.device)
self.uniform_decode_query_len = 1 if not self.speculative_config else \
1 + self.speculative_config.num_speculative_tokens
# Cudagraph dispatcher for runtime cudagraph dispatching.
self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config)
self.mm_budget = (MultiModalBudget(
self.model_config,
self.scheduler_config,
......@@ -471,7 +475,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
assert (task := pooling_params.task) is not None, (
"You did not set `task` in the API")
model = cast(VllmModelForPooling, self.model)
model = cast(VllmModelForPooling, self.get_model())
to_update = model.pooler.get_pooling_updates(task)
to_update.apply(pooling_params)
......@@ -679,13 +683,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def _prepare_inputs(
self,
scheduler_output: "SchedulerOutput",
) -> tuple[dict[str,
Any], bool, torch.Tensor, Optional[SpecDecodeMetadata],
np.ndarray, Optional[CommonAttentionMetadata]]:
) -> tuple[dict[str, Any], torch.Tensor, Optional[SpecDecodeMetadata],
np.ndarray, Optional[CommonAttentionMetadata], int]:
"""
:return: tuple[
attn_metadata: layer-to-attention_metadata mapping,
attention_cuda_graphs: whether attention can run in cudagraph
logits_indices, spec_decode_metadata
]
"""
......@@ -820,7 +822,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# valid, we fill the padded indices with the last index.
self.kv_sharing_fast_prefill_logits_indices[num_logits:].fill_(
logits_indices[-1].item())
if (self.use_cuda_graph
if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
and num_logits <= self.cudagraph_batch_sizes[-1]):
# Use piecewise CUDA graphs.
# Add padding to the batch size.
......@@ -925,17 +927,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
continue
attn_metadata[layer_name] = attn_metadata_i
attention_cuda_graphs = all(
g.metadata_builder.can_run_in_cudagraph(common_attn_metadata)
for g in self._attn_group_iterator())
# Hot-Swap lora model
if self.lora_config:
self.set_active_loras(self.input_batch, num_scheduled_tokens)
return (attn_metadata, attention_cuda_graphs, logits_indices,
spec_decode_metadata, num_scheduled_tokens,
spec_decode_common_attn_metadata)
return (attn_metadata, logits_indices, spec_decode_metadata,
num_scheduled_tokens, spec_decode_common_attn_metadata,
max_num_scheduled_tokens)
def _compute_cascade_attn_prefix_len(
self,
......@@ -1259,6 +1257,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
return mm_embeds
def get_model(self) -> nn.Module:
# get raw model out of the cudagraph wrapper.
if isinstance(self.model, CUDAGraphWrapper):
return self.model.unwrap()
return self.model
def get_supported_generation_tasks(self) -> list[GenerationTask]:
......@@ -1415,9 +1416,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
return
assert self.eplb_state is not None
assert is_mixture_of_experts(self.model)
model = self.get_model()
assert is_mixture_of_experts(model)
self.eplb_state.step(
self.model,
model,
is_dummy,
is_profile,
log_stats=self.parallel_config.eplb_log_balancedness,
......@@ -1507,15 +1509,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.vllm_config)
# Prepare the decoder inputs.
(attn_metadata, attention_cuda_graphs, logits_indices,
spec_decode_metadata, num_scheduled_tokens_np,
spec_decode_common_attn_metadata) = (
self._prepare_inputs(scheduler_output))
(attn_metadata, logits_indices, spec_decode_metadata,
num_scheduled_tokens_np, spec_decode_common_attn_metadata,
max_query_len) = (self._prepare_inputs(scheduler_output))
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
if (self.use_cuda_graph
if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
# Use piecewise CUDA graphs.
# Use CUDA graphs.
# Add padding to the batch size.
num_input_tokens = self.vllm_config.pad_for_cudagraph(
num_scheduled_tokens)
......@@ -1581,10 +1582,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
intermediate_tensors = self.sync_and_slice_intermediate_tensors(
num_input_tokens, intermediate_tensors, True)
# Some attention backends only support CUDA Graphs in pure decode.
# If attention doesn't support CUDA Graphs for this batch, but we
# compiled with full CUDA graphs, we have to skip them entirely.
skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs
uniform_decode = (max_query_len == self.uniform_decode_query_len) and (
num_scheduled_tokens == self.input_batch.num_reqs * max_query_len)
batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
uniform_decode=uniform_decode)
cudagraph_runtime_mode, batch_descriptor = \
self.cudagraph_dispatcher.dispatch(batch_descriptor)
# Run the model.
# Use persistent buffers for CUDA graphs.
......@@ -1593,10 +1596,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp,
skip_cuda_graphs=skip_cuda_graphs,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=batch_descriptor,
), self.maybe_get_kv_connector_output(
scheduler_output) as kv_connector_output:
model_output = self.model(
input_ids=input_ids,
positions=positions,
......@@ -2021,20 +2024,31 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.model.compile(
fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
backend=backend)
return
# for other compilation levels, cudagraph behavior is controlled by
# CudagraphWraper and CudagraphDispatcher of vllm.
# wrap the model with full cudagraph wrapper if needed.
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
self.model = CUDAGraphWrapper(self.model,
self.vllm_config,
runtime_mode=CUDAGraphMode.FULL)
def reload_weights(self) -> None:
assert getattr(self, "model", None) is not None, \
"Cannot reload weights before model is loaded."
model_loader = get_model_loader(self.load_config)
logger.info("Reloading weights inplace...")
model_loader.load_weights(self.model, model_config=self.model_config)
model = self.get_model()
model_loader.load_weights(model, model_config=self.model_config)
def save_tensorized_model(
self,
tensorizer_config: "TensorizerConfig",
) -> None:
model = self.get_model()
TensorizerLoader.save_model(
self.model,
model,
tensorizer_config=tensorizer_config,
model_config=self.model_config,
)
......@@ -2210,31 +2224,82 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def _dummy_run(
self,
num_tokens: int,
capture_attn_cudagraph: bool = False,
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
force_attention: bool = False,
uniform_decode: bool = False,
skip_eplb: bool = False,
is_profile: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Run a dummy forward pass to warm up/profile run or capture the
CUDA graph for the model.
Args:
num_tokens: Number of tokens to run the dummy forward pass.
cudagraph_runtime_mode: used to control the behavior.
- CUDAGraphMode.NONE: No cudagraph, for warm up and profile run
- CUDAGraphMode.PIECEWISE: Piecewise cudagraph.
- CUDAGraphMode.FULL: Full cudagraph, attention metadata is
needed.
force_attention: If True, always create attention metadata. Used to
warm up attention backend when mode is NONE.
uniform_decode: If True, the batch is a uniform decode batch.
skip_eplb: If True, skip EPLB state update.
is_profile: If True, this is a profile run.
"""
assert cudagraph_runtime_mode in {
CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
}
# Padding for DP
num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
num_tokens += num_pad
# If cudagraph_mode.decode_mode() == FULL and
# cudagraph_mode.seperate_routine(). This means that we are using
# different graphs and/or modes for mixed prefill-decode batches vs.
# uniform decode batches. A uniform decode batch means that all
# requests have identical query length, except a potential virtual
# request (shorter) in the batch account for padding.
# Uniform decode batch could either be common pure decode, where
# max_query_len == 1, or speculative decode, where
# max_query_len == 1 + num_spec_decode_tokens.
# When setting max_query_len = 1, we switch to and capture the optimized
# routine of FA2 for pure decode, i.e., Flashdecode + an optimization
# for GQA/MQA.
max_query_len = self.uniform_decode_query_len if uniform_decode else \
num_tokens
# Set num_scheduled_tokens based on num_tokens and max_num_seqs
# for dummy run with LoRA so that the num_reqs collectively
# has num_tokens in total.
assert num_tokens <= self.scheduler_config.max_num_batched_tokens
max_num_reqs = self.scheduler_config.max_num_seqs
num_reqs = min(num_tokens, max_num_reqs)
min_tokens_per_req = num_tokens // num_reqs
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
if uniform_decode:
num_reqs = cdiv(num_tokens, max_query_len)
assert num_reqs <= max_num_reqs, \
"Do not capture num_reqs > max_num_reqs for uniform batch"
num_scheduled_tokens_list = [max_query_len] * num_reqs
if num_tokens % max_query_len != 0:
num_scheduled_tokens_list[-1] = num_tokens % max_query_len
else:
num_reqs = min(num_tokens, max_num_reqs)
min_tokens_per_req = num_tokens // num_reqs
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
assert sum(num_scheduled_tokens_list) == num_tokens
assert len(num_scheduled_tokens_list) == num_reqs
num_scheduled_tokens = np.array(num_scheduled_tokens_list,
dtype=np.int32)
attn_metadata: Optional[dict[str, Any]] = None
if capture_attn_cudagraph:
# If force_attention is True, we always capture attention. Otherwise,
# it only happens for cudagraph_runtime_mode=FULL.
if force_attention or cudagraph_runtime_mode == \
CUDAGraphMode.FULL:
attn_metadata = {}
# Make sure max_model_len is used at the graph capture time.
......@@ -2255,7 +2320,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_computed_tokens_cpu_tensor[:num_reqs],
num_reqs=num_reqs,
num_actual_tokens=num_tokens,
max_query_len=num_tokens,
max_query_len=max_query_len,
block_table_tensor=self.input_batch.block_table[
kv_cache_group_id].get_device_tensor()[:num_reqs],
slot_mapping=self.input_batch.
......@@ -2299,12 +2364,26 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
intermediate_tensors = self.sync_and_slice_intermediate_tensors(
num_tokens, None, False)
if cudagraph_runtime_mode == CUDAGraphMode.NONE:
batch_descriptor = None
else:
# filter out the valid batch descriptor
_cg_mode, batch_descriptor = \
self.cudagraph_dispatcher.dispatch(
BatchDescriptor(num_tokens=num_tokens,
uniform_decode=uniform_decode))
# sanity check
assert cudagraph_runtime_mode == _cg_mode, (
f"Cudagraph runtime mode mismatch at dummy_run. "
f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}.")
with self.maybe_randomize_inputs(input_ids), set_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=num_tokens,
num_tokens_across_dp=num_tokens_across_dp):
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=batch_descriptor):
outputs = self.model(
input_ids=input_ids,
positions=positions,
......@@ -2436,7 +2515,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
dtype=torch.int32,
device=self.device)
model = cast(VllmModelForPooling, self.model)
model = cast(VllmModelForPooling, self.get_model())
dummy_pooling_params = PoolingParams(task=task)
to_update = model.pooler.get_pooling_updates(task)
to_update.apply(dummy_pooling_params)
......@@ -2546,12 +2625,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
gc.collect()
def capture_model(self) -> None:
if not self.use_cuda_graph:
if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
logger.warning(
"Skipping CUDA graph capture. To turn on CUDA graph capture, "
"set -O %s and ensure `use_cudagraph` was not manually set to "
"False", CompilationLevel.PIECEWISE)
"ensure `cudagraph_mode` was not manually set to `NONE`")
return
else:
self.initialize_cudagraph_capture()
compilation_counter.num_gpu_runner_capture_triggers += 1
......@@ -2576,25 +2656,41 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Trigger CUDA graph capture for specific shapes.
# Capture the large shapes first so that the smaller shapes
# can reuse the memory pool allocated for the large shapes.
set_cudagraph_capturing_enabled(True)
with freeze_gc(), graph_capture(device=self.device):
full_cg = self.full_cuda_graph
# Only rank 0 should print progress bar during capture
compilation_cases = reversed(self.cudagraph_batch_sizes)
if is_global_first_rank():
compilation_cases = tqdm(
list(compilation_cases),
disable=not self.load_config.use_tqdm_on_load,
desc="Capturing CUDA graph shapes")
for num_tokens in compilation_cases:
# We skip EPLB here since we don't want to record dummy metrics
for _ in range(
self.compilation_config.cudagraph_num_of_warmups):
self._dummy_run(num_tokens,
capture_attn_cudagraph=full_cg,
skip_eplb=True)
self._dummy_run(num_tokens,
capture_attn_cudagraph=full_cg,
skip_eplb=True)
cudagraph_mode = self.compilation_config.cudagraph_mode
if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
cudagraph_runtime_mode = cudagraph_mode.mixed_mode()
compilation_cases = list(reversed(self.cudagraph_batch_sizes))
self._capture_cudagraphs(
compilation_cases,
cudagraph_runtime_mode=cudagraph_runtime_mode,
uniform_decode=False)
# Capture full cudagraph for uniform decode batches if we have
# dont already have full mixed prefill-decode cudagraphs
if cudagraph_mode.decode_mode() == CUDAGraphMode.FULL and \
cudagraph_mode.separate_routine():
max_num_tokens = self.scheduler_config.max_num_seqs * \
self.uniform_decode_query_len
decode_cudagraph_batch_sizes = [
x for x in self.cudagraph_batch_sizes if
x <= max_num_tokens and x >= self.uniform_decode_query_len
]
compilation_cases_decode = list(
reversed(decode_cudagraph_batch_sizes))
self._capture_cudagraphs(
compilation_cases=compilation_cases_decode,
cudagraph_runtime_mode=CUDAGraphMode.FULL,
uniform_decode=True)
# Disable cudagraph capturing globally, so any unexpected cudagraph
# capturing will be detected and raise an error after here.
# Note: We don't put it into graph_capture context manager because
# we may doing lazy capturing in future that still allows capturing
# after here.
set_cudagraph_capturing_enabled(False)
end_time = time.perf_counter()
end_free_gpu_memory = torch.cuda.mem_get_info()[0]
......@@ -2604,6 +2700,41 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
elapsed_time, cuda_graph_size / (1 << 30))
def _capture_cudagraphs(self, compilation_cases: list[int],
cudagraph_runtime_mode: CUDAGraphMode,
uniform_decode: bool):
assert cudagraph_runtime_mode != CUDAGraphMode.NONE and \
cudagraph_runtime_mode in [CUDAGraphMode.FULL,
CUDAGraphMode.PIECEWISE]
# Only rank 0 should print progress bar during capture
if is_global_first_rank():
compilation_cases = tqdm(
compilation_cases,
disable=not self.load_config.use_tqdm_on_load,
desc="Capturing CUDA graphs ({}, {})".format(
"decode" if uniform_decode else "mixed prefill-decode",
cudagraph_runtime_mode.name))
# We skip EPLB here since we don't want to record dummy metrics
for num_tokens in compilation_cases:
for _ in range(self.compilation_config.cudagraph_num_of_warmups):
# Use CUDAGraphRuntimeStyle.NONE (default) for warmup.
# But be careful, warm up with `NONE`is orthogonal to
# if we want to warm up attention or not. This is
# different from the case where `FULL` implies capture
# attention while `PIECEWISE` implies no attention.
force_attention = (
cudagraph_runtime_mode == CUDAGraphMode.FULL)
self._dummy_run(num_tokens,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
force_attention=force_attention,
uniform_decode=uniform_decode,
skip_eplb=True)
self._dummy_run(num_tokens,
cudagraph_runtime_mode=cudagraph_runtime_mode,
uniform_decode=uniform_decode,
skip_eplb=True)
def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
"""
Initialize the attention backends and attention metadata builders.
......@@ -2648,25 +2779,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
attn_metadata_builder_i,
layer_names)
attn_groups.append(attn_group)
if self.full_cuda_graph:
if attn_metadata_builder_i.attn_cudagraph_support == \
AttentionCGSupport.NEVER:
raise ValueError(
f"Full CUDAGraph not supported for "
f"{attn_backend.__name__}. Turn off "
f"CompilationConfig.full_cuda_graph or use a "
f" different attention backend.")
if attn_metadata_builder_i.attn_cudagraph_support == \
AttentionCGSupport.PURE_DECODE_ONLY:
# Limit the max cudagraph size to the max number of
# sequences for pure decode only cudagraph backend,
# whose max_query_len is 1.
self.cudagraph_batch_sizes = [
size for size in self.cudagraph_batch_sizes
if size <= self.scheduler_config.max_num_seqs
]
return attn_groups
for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
......@@ -2734,6 +2846,75 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
"All or none of the layers are expected to be encoder-only"
self.is_encoder_only_model = True
def initialize_cudagraph_capture(self) -> None:
min_cg_support = AttentionCGSupport.ALWAYS
min_cg_builder_name = None
for attn_group in self._attn_group_iterator():
builder = attn_group.metadata_builder
if builder.cudagraph_support.value < min_cg_support.value:
min_cg_support = builder.cudagraph_support
min_cg_builder_name = builder.__class__.__name__
# Flexible resolve the cudagraph mode
cudagraph_mode = self.compilation_config.cudagraph_mode
# check cudagraph for mixed batch is supported
if cudagraph_mode.mixed_mode() == CUDAGraphMode.FULL \
and min_cg_support != AttentionCGSupport.ALWAYS:
msg = (f"CUDAGraphMode.{cudagraph_mode.name} is not supported "
f"with {min_cg_builder_name} backend (support: "
f"{min_cg_support})")
if min_cg_support == AttentionCGSupport.NEVER:
# if not supported any full cudagraphs, just raise it.
msg += "; please try cudagraph_mode=PIECEWISE, and "\
"make sure compilation level is piecewise"
raise ValueError(msg)
# attempt to resolve the full cudagraph related mode
if self.compilation_config.splitting_ops_contain_attention():
msg += "; setting cudagraph_mode=FULL_AND_PIECEWISE"
cudagraph_mode = self.compilation_config.cudagraph_mode = \
CUDAGraphMode.FULL_AND_PIECEWISE
else:
msg += "; setting cudagraph_mode=FULL_DECODE_ONLY"
cudagraph_mode = self.compilation_config.cudagraph_mode = \
CUDAGraphMode.FULL_DECODE_ONLY
logger.warning(msg)
# check that if we are doing spec-decode + decode full-cudagraphs it is
# supported
if (cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
and self.uniform_decode_query_len > 1 and min_cg_support.value
< AttentionCGSupport.UNIFORM_BATCH.value):
msg = (f"CUDAGraphMode.{cudagraph_mode.name} is not supported"
f" with spec-decode for attention backend "
f"{min_cg_builder_name} (support: {min_cg_support})")
if self.compilation_config.splitting_ops_contain_attention():
msg += "; setting cudagraph_mode=PIECEWISE"
cudagraph_mode = self.compilation_config.cudagraph_mode = \
CUDAGraphMode.PIECEWISE
else:
msg += "; setting cudagraph_mode=NONE"
cudagraph_mode = self.compilation_config.cudagraph_mode = \
CUDAGraphMode.NONE
logger.warning(msg)
# double check that we can support full cudagraph if they are requested
# even after automatic downgrades
if cudagraph_mode.has_full_cudagraphs() \
and min_cg_support == AttentionCGSupport.NEVER:
raise ValueError(f"CUDAGraphMode.{cudagraph_mode.name} is not "
f"supported with {min_cg_builder_name} backend ("
f"support:{min_cg_support}) "
"; please try cudagraph_mode=PIECEWISE, "
"and make sure compilation level is piecewise")
# Trigger cudagraph dispatching keys initialization here (after
# initializing attn backends).
self.cudagraph_dispatcher.initialize_cudagraph_keys(
self.compilation_config.cudagraph_mode,
self.uniform_decode_query_len)
def calculate_reorder_batch_threshold(self) -> None:
"""
Check that if any backends reorder batches; that the reordering
......
......@@ -322,16 +322,11 @@ class Worker(WorkerBase):
if get_pp_group().is_last_rank:
max_num_reqs = min(self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens)
# activate building attn_metadata for this dummy run to avoid
# potential illegal memory access for full cudagraph relay.
attn_cudagraph = self.compilation_config.full_cuda_graph and\
not self.model_config.enforce_eager
# We skip EPLB here since we don't want to record dummy metrics
hidden_states, last_hidden_states = \
self.model_runner._dummy_run(
num_tokens=max_num_reqs,
capture_attn_cudagraph=attn_cudagraph,
skip_eplb=True,
)
if self.model_runner.is_pooling_model:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment