Commit a810671a authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.14.0rc0' into v0.14.0rc0-ori

parents 86b5aefe 6a09612b
......@@ -34,7 +34,7 @@ import torch.nn.functional as F
from transformers import BatchFeature
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import MultiHeadAttention
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.config import MultiModalConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import parallel_state
......@@ -232,7 +232,7 @@ class HunYuanVisionAttention(nn.Module):
)
self.scale = self.hidden_size_per_attention_head**-0.5
self.attn = MultiHeadAttention(
self.attn = MMEncoderAttention(
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
self.scale,
......
......@@ -27,7 +27,7 @@ from transformers.models.idefics2.configuration_idefics2 import (
Idefics2VisionConfig,
)
from vllm.attention.layer import MultiHeadAttention
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.conv import Conv2dLayer
......@@ -161,8 +161,8 @@ class Idefics2VisionAttention(nn.Module):
prefix=f"{prefix}.out_proj",
disable_tp=use_data_parallel,
)
# Use unified MultiHeadAttention with Flash Attention support
self.attn = MultiHeadAttention(
# Use unified MMEncoderAttention with Flash Attention support
self.attn = MMEncoderAttention(
self.num_heads_per_partition, self.head_dim, self.scale
)
......@@ -175,7 +175,7 @@ class Idefics2VisionAttention(nn.Module):
) # batch_size, q_len, 3 * num_heads_per_partition * head_dim
query_states, key_states, value_states = qkv.chunk(3, dim=-1)
# Use unified MultiHeadAttention implementation
# Use unified MMEncoderAttention implementation
out = self.attn(query_states, key_states, value_states)
attn_output, _ = self.out_proj(out)
return attn_output
......
......@@ -15,7 +15,7 @@ import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig
from vllm.attention.layer import MultiHeadAttention
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.distributed import (
divide,
get_tensor_model_parallel_rank,
......@@ -207,7 +207,7 @@ class InternParallelAttention(nn.Module):
disable_tp=use_data_parallel,
)
self.attn = MultiHeadAttention(
self.attn = MMEncoderAttention(
self.num_heads_per_partition, self.head_dim, self.scale
)
......
......@@ -14,7 +14,7 @@ import torch.nn as nn
from transformers import PretrainedConfig
from transformers.utils import torch_int
from vllm.attention.layer import MultiHeadAttention
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.layernorm import RMSNorm
......@@ -214,8 +214,8 @@ class InternSdpaAttention(nn.Module):
self.projection_layer = nn.Linear(self.dummy_dim, self.embed_dim)
# Use unified MultiHeadAttention with automatic backend selection
self.attn = MultiHeadAttention(self.num_heads, self.head_dim, self.scale)
# Use unified MMEncoderAttention with automatic backend selection
self.attn = MMEncoderAttention(self.num_heads, self.head_dim, self.scale)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""x shape: (B, N, C)"""
......@@ -228,7 +228,7 @@ class InternSdpaAttention(nn.Module):
q = self.q_norm(q)
k = self.k_norm(k)
# Use unified MultiHeadAttention with automatic backend selection
# Use unified MMEncoderAttention with automatic backend selection
x = self.attn(q, k, v)
x = self.projection_layer(x)
......
......@@ -31,7 +31,7 @@ from transformers.models.llama4.image_processing_llama4_fast import (
get_best_fit,
)
from vllm.attention.layer import MultiHeadAttention
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size
......@@ -255,7 +255,7 @@ class Llama4VisionAttention(nn.Module):
self.attention_dropout = config.attention_dropout
self.scaling = self.head_dim**-0.5
self.attn = MultiHeadAttention(
self.attn = MMEncoderAttention(
self.num_local_heads, self.head_dim, self.scaling
)
......
......@@ -17,7 +17,8 @@ from transformers import BatchFeature, PretrainedConfig, ProcessorMixin, TensorT
from transformers.image_utils import ImageInput
from transformers.tokenization_utils_base import TextInput
from vllm.attention.layer import Attention, MultiHeadAttention
from vllm.attention.layer import Attention
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
......@@ -222,7 +223,7 @@ class MultiHeadDotProductAttention(nn.Module):
)
self.scale = self.head_dim**-0.5
self.attn = MultiHeadAttention(
self.attn = MMEncoderAttention(
self.num_heads, self.head_dim, self.scale, num_kv_heads=self.num_kv_heads
)
......
......@@ -169,10 +169,13 @@ class DeciLMDecoderLayer(nn.Module):
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
if not self._is_no_op_ffn:
ffn_mult = block_config.ffn.ffn_mult
intermediate_size = _ffn_mult_to_intermediate_size(
ffn_mult, config.hidden_size
)
if hasattr(block_config.ffn, "ffn_mult"):
ffn_mult = block_config.ffn.ffn_mult
intermediate_size = _ffn_mult_to_intermediate_size(
ffn_mult, config.hidden_size
)
else:
intermediate_size = block_config.ffn.intermediate_size
self.mlp = LlamaMLP(
hidden_size=self.hidden_size,
......
......@@ -70,7 +70,6 @@ from vllm.multimodal.inputs import (
MultiModalFeatureSpec,
MultiModalFieldConfig,
MultiModalKwargsItems,
NestedTensors,
)
from vllm.multimodal.parse import (
AudioProcessorItems,
......@@ -1150,27 +1149,6 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
handle_oov_mm_token=handle_oov_mm_token,
)
def embed_multimodal_v0(self, **kwargs: object) -> NestedTensors | None:
audio_input = self._parse_and_validate_audio_input(**kwargs)
image_input = self._parse_and_validate_image_input(**kwargs)
video_input = self._parse_and_validate_video_input(**kwargs)
if audio_input is None and image_input is None and video_input is None:
return None
multimodal_embeddings: list[tuple[NestedTensors, str]] = []
if audio_input is not None:
audio_embeds = self._process_audio_input(audio_input)
multimodal_embeddings.append((audio_embeds, "audio"))
if image_input is not None:
image_embeds = self._process_image_input(image_input)
multimodal_embeddings.append((image_embeds, "image"))
if video_input is not None:
video_embeds = self._process_video_input(video_input)
multimodal_embeddings.append((video_embeds, "video"))
return multimodal_embeddings
def forward(
self,
input_ids: torch.Tensor,
......
......@@ -57,12 +57,6 @@ class RobertaEmbedding(nn.Module):
torch.arange(config.max_position_embeddings).unsqueeze(0),
)
self.position_embedding_type = config.position_embedding_type
if self.position_embedding_type != "absolute":
raise ValueError(
"Only 'absolute' position_embedding_type" + " is supported"
)
def forward(
self,
input_ids: torch.Tensor,
......@@ -135,12 +129,12 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
def _build_model(
self, vllm_config: VllmConfig, prefix: str = ""
) -> BertModel | BertWithRope:
if vllm_config.model_config.hf_config.position_embedding_type == "rotary":
return JinaRobertaModel(vllm_config=vllm_config, prefix=prefix)
hf_config = vllm_config.model_config.hf_config
kwargs = dict(vllm_config=vllm_config, prefix=prefix)
if getattr(hf_config, "position_embedding_type", "absolute") == "absolute":
return BertModel(**kwargs, embedding_class=RobertaEmbedding)
else:
return BertModel(
vllm_config=vllm_config, prefix=prefix, embedding_class=RobertaEmbedding
)
return JinaRobertaModel(**kwargs)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
weights_list = list(weights)
......
......@@ -16,8 +16,8 @@ from transformers import (
SiglipVisionConfig,
)
from vllm.attention.layer import MultiHeadAttention
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import divide, get_tensor_model_parallel_world_size
......@@ -379,7 +379,7 @@ class SiglipAttention(nn.Module):
quant_config: QuantizationConfig | None = None,
*,
prefix: str = "",
attn_cls: type[EncoderOnlyAttention] | type[MultiHeadAttention],
attn_cls: type[EncoderOnlyAttention] | type[MMEncoderAttention],
) -> None:
super().__init__()
......@@ -481,7 +481,7 @@ class SiglipEncoderLayer(nn.Module):
quant_config: QuantizationConfig | None = None,
*,
prefix: str = "",
attn_cls: type[EncoderOnlyAttention] | type[MultiHeadAttention],
attn_cls: type[EncoderOnlyAttention] | type[MMEncoderAttention],
) -> None:
super().__init__()
......@@ -527,7 +527,7 @@ class SiglipEncoder(nn.Module):
num_hidden_layers_override: int | None = None,
*,
prefix: str = "",
attn_cls: type[EncoderOnlyAttention] | type[MultiHeadAttention],
attn_cls: type[EncoderOnlyAttention] | type[MMEncoderAttention],
) -> None:
super().__init__()
......@@ -700,7 +700,7 @@ class SiglipVisionTransformer(nn.Module):
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override,
prefix=f"{prefix}.encoder",
attn_cls=MultiHeadAttention,
attn_cls=MMEncoderAttention,
)
num_hidden_layers = config.num_hidden_layers
......
......@@ -15,7 +15,7 @@ from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from transformers import BatchFeature, PretrainedConfig, TensorType
from vllm.attention.layer import MultiHeadAttention
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size
......@@ -753,8 +753,8 @@ class Step3VisionAttention(nn.Module):
disable_tp=use_data_parallel,
)
# Use unified MultiHeadAttention with automatic backend selection
self.attn = MultiHeadAttention(self.num_heads, self.head_dim, self.scale)
# Use unified MMEncoderAttention with automatic backend selection
self.attn = MMEncoderAttention(self.num_heads, self.head_dim, self.scale)
def forward(
self,
......@@ -767,7 +767,7 @@ class Step3VisionAttention(nn.Module):
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
# Use unified MultiHeadAttention with automatic backend selection
# Use unified MMEncoderAttention with automatic backend selection
attn_output = self.attn(q, k, v)
attn_output, _ = self.out_proj(attn_output)
......
......@@ -102,7 +102,6 @@ class SwinSelfAttention(nn.Module):
self,
hidden_states: torch.Tensor,
attention_mask: torch.FloatTensor | None = None,
head_mask: torch.FloatTensor | None = None,
output_attentions: bool | None = False,
) -> tuple[torch.Tensor, ...]:
batch_size, dim, num_channels = hidden_states.shape
......@@ -201,12 +200,9 @@ class SwinAttention(nn.Module):
self,
hidden_states: torch.Tensor,
attention_mask: torch.FloatTensor | None = None,
head_mask: torch.FloatTensor | None = None,
output_attentions: bool | None = False,
) -> tuple[torch.Tensor]:
self_outputs = self.self(
hidden_states, attention_mask, head_mask, output_attentions
)
self_outputs = self.self(hidden_states, attention_mask, output_attentions)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:]
return outputs
......@@ -339,18 +335,14 @@ class SwinStage(nn.Module):
self,
hidden_states: torch.Tensor,
input_dimensions: tuple[int, int],
head_mask: torch.FloatTensor | None = None,
output_attentions: bool | None = False,
always_partition: bool | None = False,
) -> tuple[torch.Tensor]:
height, width = input_dimensions
for i, layer_module in enumerate(self.blocks):
layer_head_mask = head_mask[i] if head_mask is not None else None
layer_outputs = layer_module(
hidden_states,
input_dimensions,
layer_head_mask,
output_attentions,
always_partition,
)
......@@ -425,17 +417,13 @@ class SwinEncoder(nn.Module):
self,
hidden_states: torch.Tensor,
input_dimensions: tuple[int, int],
head_mask: torch.FloatTensor | None = None,
output_attentions: bool | None = False,
always_partition: bool | None = False,
) -> tuple[torch.Tensor]:
for i, layer_module in enumerate(self.layers):
layer_head_mask = head_mask[i] if head_mask is not None else None
layer_outputs = layer_module(
hidden_states,
input_dimensions,
layer_head_mask,
output_attentions,
always_partition,
)
......@@ -473,7 +461,6 @@ class SwinModel(nn.Module):
def forward(
self,
pixel_values: torch.FloatTensor | None = None,
head_mask: torch.FloatTensor | None = None,
output_attentions: bool | None = None,
) -> tuple[torch.Tensor]:
embedding_output, input_dimensions = self.embeddings(pixel_values)
......@@ -481,7 +468,6 @@ class SwinModel(nn.Module):
encoder_outputs = self.encoder(
embedding_output,
input_dimensions,
head_mask=head_mask,
output_attentions=output_attentions,
)
......
......@@ -5,6 +5,7 @@
"""PyTorch Ultravox model."""
import copy
import inspect
from collections.abc import Iterable, Mapping, Sequence
from types import SimpleNamespace
from typing import Annotated, Any, Literal, TypeAlias
......@@ -380,11 +381,17 @@ class UltravoxTransformerProjector(nn.Module, ModuleUtilsMixin):
)
hidden_states = hidden_states + positions
# Backward compatibility for Transformers v4 where layer_head_mask
# was a required argument for WhisperEncoderLayer.forward
kwargs = {}
if "layer_head_mask" in inspect.signature(self.layers[0].forward).parameters:
kwargs["layer_head_mask"] = None
for layer in self.layers:
layer_outputs = layer(
hidden_states,
attention_mask=extended_attention_mask,
layer_head_mask=None,
**kwargs,
)
hidden_states = layer_outputs[0]
......@@ -479,11 +486,17 @@ class ModifiedWhisperEncoder(WhisperEncoder):
attention_mask = self.get_attention_mask_by_audio_len(audio_lens, hidden_states)
# Backward compatibility for Transformers v4 where layer_head_mask
# was a required argument for WhisperEncoderLayer.forward
kwargs = {}
if "layer_head_mask" in inspect.signature(self.layers[0].forward).parameters:
kwargs["layer_head_mask"] = None
for encoder_layer in self.layers:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
layer_head_mask=None,
**kwargs,
)
hidden_states = layer_outputs[0]
......
......@@ -16,9 +16,9 @@ from transformers import (
)
from transformers.models.whisper.modeling_whisper import sinusoids
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention, MultiHeadAttention
from vllm.attention.layer import Attention, AttentionType
from vllm.attention.layers.cross_attention import CrossAttention
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size
......@@ -141,7 +141,7 @@ class WhisperAudioInputs(TensorSchema):
]
class WhisperEncoderAttention(MultiHeadAttention):
class WhisperEncoderAttention(MMEncoderAttention):
"""Multi-headed attention for Whisper encoder with 2D tensor support."""
def forward(
......
......@@ -10,7 +10,7 @@ import torch
from tqdm import tqdm
import vllm.envs as envs
from vllm.distributed.parallel_state import get_dp_group
from vllm.distributed.parallel_state import get_dp_group, is_global_first_rank
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
from vllm.model_executor.layers.fused_moe.deep_gemm_utils import compute_aligned_M
from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEModularMethod
......@@ -175,7 +175,30 @@ def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool:
FP8_GEMM_NT_WARMUP_CACHE: set[torch.Size] = set()
def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor, max_tokens: int):
def _get_fp8_gemm_nt_m_values(w: torch.Tensor, max_tokens: int) -> list[int]:
"""Get the M values to warmup for a given weight tensor."""
n, _ = w.size()
device = w.device
# Use optimal M values only if VLLM_DEEP_GEMM_WARMUP is set to "relax".
# Otherwise warmup all token sizes to avoid JIT compilation in hotpath
if envs.VLLM_DEEP_GEMM_WARMUP == "relax":
return _generate_optimal_warmup_m_values(max_tokens, n, device)
else:
assert envs.VLLM_DEEP_GEMM_WARMUP == "full", (
"Expected "
'VLLM_DEEP_GEMM_WARMUP env to be set to "full" but got '
f"{envs.VLLM_DEEP_GEMM_WARMUP}"
)
return list(range(1, max_tokens + 1))
def _deepgemm_fp8_gemm_nt_warmup(
w: torch.Tensor,
ws: torch.Tensor,
max_tokens: int,
pbar: tqdm | None = None,
):
if w.size() in FP8_GEMM_NT_WARMUP_CACHE:
return
......@@ -189,27 +212,14 @@ def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor, max_tokens:
)
out = torch.empty((max_tokens, n), device=device, dtype=torch.bfloat16)
# Use optimal M values only if VLLM_DEEP_GEMM_WARMUP is set to "relax".
# Otherwise warmup all token sizes to avoid JIT compilation in hotpath
if envs.VLLM_DEEP_GEMM_WARMUP == "relax":
m_values = _generate_optimal_warmup_m_values(max_tokens, n, device)
desc = f"DeepGemm(fp8_gemm_nt) warmup (W={w.size()}) [relaxed]"
else:
assert envs.VLLM_DEEP_GEMM_WARMUP == "full", (
"Expected "
'VLLM_DEEP_GEMM_WARMUP env to be set to "full" but got '
f"{envs.VLLM_DEEP_GEMM_WARMUP}"
)
m_values = list(range(1, max_tokens + 1))
desc = f"DeepGemm(fp8_gemm_nt) warmup (W={w.size()}) [all tokens]"
pbar = tqdm(total=len(m_values), desc=desc)
m_values = _get_fp8_gemm_nt_m_values(w, max_tokens)
for num_tokens in m_values:
fp8_gemm_nt(
(a1q[:num_tokens], a1q_scales[:num_tokens]), (w, ws), out[:num_tokens]
)
pbar.update(1)
if pbar is not None:
pbar.update(1)
FP8_GEMM_NT_WARMUP_CACHE.add(w.size())
......@@ -217,20 +227,12 @@ def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor, max_tokens:
GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE: set[torch.Size] = set()
def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
def _get_grouped_gemm_params(
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
num_topk: int,
max_tokens: int,
):
if (
w1.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE
and w2.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE
):
return
) -> tuple[int, int, torch.Tensor]:
assert w1.size(0) == w2.size(0), "w1 and w2 must have the same number of experts"
block_m = get_mk_alignment_for_contiguous_layout()[0]
......@@ -253,6 +255,27 @@ def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
)
expert_ids = torch.repeat_interleave(expert_ids_block, block_m, dim=0)
return MAX_M, block_m, expert_ids
def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
num_topk: int,
max_tokens: int,
pbar: tqdm | None = None,
):
if (
w1.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE
and w2.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE
):
return
MAX_M, block_m, expert_ids = _get_grouped_gemm_params(w1, w2, num_topk, max_tokens)
device = w1.device
def _warmup(w: torch.Tensor, w_scale: torch.Tensor):
_, n, k = w.size()
a1q = torch.empty((MAX_M, k), device=device, dtype=torch.float8_e4m3fn)
......@@ -261,15 +284,8 @@ def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
)
out = torch.empty((MAX_M, n), device=device, dtype=torch.bfloat16)
# Generate M values in block_m increments (already optimized for MoE)
m_values = list(range(block_m, MAX_M + 1, block_m))
pbar = tqdm(
total=len(m_values),
desc=f"DeepGemm(m_grouped_fp8_gemm_nt_contiguous) warmup (W={w.size()}) "
f"[{len(m_values)} values, block_m={block_m}]",
)
for num_tokens in m_values:
m_grouped_fp8_gemm_nt_contiguous(
(a1q[:num_tokens], a1q_scales[:num_tokens]),
......@@ -277,7 +293,8 @@ def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
out[:num_tokens],
expert_ids[:num_tokens],
)
pbar.update(1)
if pbar is not None:
pbar.update(1)
for w, ws in [(w1, w1_scale), (w2, w2_scale)]:
if w.size() not in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE:
......@@ -285,16 +302,18 @@ def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE.add(w.size())
def deepgemm_fp8_gemm_nt_warmup(model: torch.nn.Module, max_tokens: int):
def deepgemm_fp8_gemm_nt_warmup(
model: torch.nn.Module, max_tokens: int, pbar: tqdm | None = None
):
dg_modules = [m for m in model.modules() if _fp8_linear_may_use_deep_gemm(m)]
for dgm in dg_modules:
w, ws, _ = _extract_data_from_linear_base_module(dgm)
_deepgemm_fp8_gemm_nt_warmup(w=w, ws=ws, max_tokens=max_tokens)
_deepgemm_fp8_gemm_nt_warmup(w=w, ws=ws, max_tokens=max_tokens, pbar=pbar)
def deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
model: torch.nn.Module, max_tokens: int
model: torch.nn.Module, max_tokens: int, pbar: tqdm | None = None
):
dg_modules = [
m for m in model.modules() if _fused_moe_grouped_gemm_may_use_deep_gemm(m)
......@@ -305,10 +324,48 @@ def deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
dgm
)
_deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
w13, w2, w13_scale, w2_scale, num_topk, max_tokens
w13, w2, w13_scale, w2_scale, num_topk, max_tokens, pbar=pbar
)
def _count_warmup_iterations(model: torch.nn.Module, max_tokens: int) -> int:
seen_fp8_sizes: set[torch.Size] = set(FP8_GEMM_NT_WARMUP_CACHE)
seen_grouped_sizes: set[torch.Size] = set(
GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE
)
total = 0
for m in model.modules():
if _fp8_linear_may_use_deep_gemm(m):
w, _, _ = _extract_data_from_linear_base_module(m)
if w.size() not in seen_fp8_sizes:
total += len(_get_fp8_gemm_nt_m_values(w, max_tokens))
seen_fp8_sizes.add(w.size())
elif _fused_moe_grouped_gemm_may_use_deep_gemm(m):
w13, _, w2, _, num_topk = _extract_data_from_fused_moe_module(m)
if w13.size() in seen_grouped_sizes and w2.size() in seen_grouped_sizes:
continue
MAX_M, block_m, _ = _get_grouped_gemm_params(w13, w2, num_topk, max_tokens)
n_values = (MAX_M - block_m) // block_m + 1
if w13.size() not in seen_grouped_sizes:
total += n_values
seen_grouped_sizes.add(w13.size())
if w2.size() not in seen_grouped_sizes:
total += n_values
seen_grouped_sizes.add(w2.size())
return total
def deep_gemm_warmup(model: torch.nn.Module, max_tokens: int):
deepgemm_fp8_gemm_nt_warmup(model, max_tokens)
deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model, max_tokens)
total = _count_warmup_iterations(model, max_tokens)
if total == 0:
return
# Only show progress bar on rank 0 to avoid cluttered output
if is_global_first_rank():
with tqdm(total=total, desc="DeepGEMM warmup") as pbar:
deepgemm_fp8_gemm_nt_warmup(model, max_tokens, pbar)
deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model, max_tokens, pbar)
else:
deepgemm_fp8_gemm_nt_warmup(model, max_tokens, None)
deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model, max_tokens, None)
......@@ -49,13 +49,12 @@ def kernel_warmup(worker: "Worker"):
except NotImplementedError:
return False
# NOTE: we add check for empty attn_groups to avoid errors when
# deploying models such as E instances and encoder-only models.
# As for those models, worker.model_runner.attn_groups is empty.
# This change is made during EPD feature development.
if (
not worker.model_runner.is_pooling_model
and worker.model_runner.attn_groups
# NOTE: This should be `any` instead of `all` but other hybrid attention
# backends don't support this dummy run. Once we remove
# `build_for_cudagraph_capture`, we can change it to `any`.
and all(
_is_flashinfer_backend(group.backend)
for groups in worker.model_runner.attn_groups
......
......@@ -124,11 +124,9 @@ def use_rocm_custom_paged_attention(
alibi_slopes: torch.Tensor | None = None,
sinks: torch.Tensor | None = None,
) -> bool:
from vllm._aiter_ops import rocm_aiter_ops
# GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
# ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
# ON_GFX11_GFX12 = any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"])
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
ON_GFX11_GFX12 = any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"])
# custom paged attn always supported on V0. On V1, requires sliding window
# disabled due to observed numerical discrepancy.
......@@ -141,7 +139,6 @@ def use_rocm_custom_paged_attention(
# and (gqa_ratio >= 1 and gqa_ratio <= 16)
# and max_seq_len <= 128 * 1024
# and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
# and not (rocm_aiter_ops.is_pa_attn_enabled())
# and sinks is None
# )
......
......@@ -162,7 +162,10 @@ class XPUPlatform(Platform):
# check and update parallel config
parallel_config = vllm_config.parallel_config
parallel_config.worker_cls = "vllm.v1.worker.xpu_worker.XPUWorker"
# Only override worker_cls if it's still the default "auto"
# This allows custom workers (like vllm-omni workers) to be used on XPU
if parallel_config.worker_cls == "auto":
parallel_config.worker_cls = "vllm.v1.worker.xpu_worker.XPUWorker"
if vllm_config.kv_transfer_config is not None:
vllm_config.kv_transfer_config.enable_permute_local_kv = True
......
......@@ -26,6 +26,8 @@ class DeepSeekV3ReasoningParser(ReasoningParser):
chat_kwargs = kwargs.pop("chat_template_kwargs", {}) or {}
thinking = bool(chat_kwargs.pop("thinking", False))
enable_thinking = bool(chat_kwargs.pop("enable_thinking", False))
thinking = thinking or enable_thinking
if thinking:
self._parser = DeepSeekR1ReasoningParser(tokenizer, *args, **kwargs)
......
......@@ -50,6 +50,8 @@ class DeepseekV32Tokenizer(CachedHfTokenizer):
**kwargs,
) -> str | list[int]:
thinking = kwargs.get("thinking", False)
enable_thinking = kwargs.get("enable_thinking", False)
thinking = thinking or enable_thinking
thinking_mode = "thinking"
if not thinking:
thinking_mode = "chat"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment