Commit a810671a authored by zhuwenwen
Browse files

Merge tag 'v0.14.0rc0' into v0.14.0rc0-ori

parents 86b5aefe 6a09612b
......@@ -181,7 +181,7 @@ class GPTQMarlinConfig(QuantizationConfig):
# Class-level query that returns the minimum capability value (an int)
# required by this quantization config.
# NOTE(review): this is a flattened diff hunk (@@ -181,7 +181,7 @@) whose +/-
# markers and indentation were lost. The two `return` lines below are
# presumably the removed line followed by its replacement, i.e. the
# post-merge value would be 75 -- confirm against the real file.
@classmethod
def get_min_capability(cls) -> int:
return 80
return 75
@classmethod
def get_config_filenames(cls) -> list[str]:
......
......@@ -871,7 +871,7 @@ class ModelOptNvFp4Config(ModelOptQuantConfigBase):
# Class-level query that returns the minimum capability value (an int)
# required by this quantization config.
# NOTE(review): this is a flattened diff hunk (@@ -871,7 +871,7 @@) whose +/-
# markers and indentation were lost. The two `return` lines below are
# presumably the removed line followed by its replacement, i.e. the
# post-merge value would be 75 -- confirm against the real file.
@classmethod
def get_min_capability(cls) -> int:
return 80
return 75
@classmethod
def override_quantization_method(
......@@ -1458,16 +1458,14 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
)
logger.debug_once("Finished shuffling weights for TRT-LLM MOE")
layer.gemm1_weights_fp4_shuffled = Parameter(
layer.w13_weight = Parameter(
gemm1_weights_fp4_shuffled, requires_grad=False
)
layer.gemm2_weights_fp4_shuffled = Parameter(
gemm2_weights_fp4_shuffled, requires_grad=False
)
layer.gemm1_scales_fp4_shuffled = Parameter(
layer.w2_weight = Parameter(gemm2_weights_fp4_shuffled, requires_grad=False)
layer.w13_weight_scale = Parameter(
gemm1_scales_fp4_shuffled, requires_grad=False
)
layer.gemm2_scales_fp4_shuffled = Parameter(
layer.w2_weight_scale = Parameter(
gemm2_scales_fp4_shuffled, requires_grad=False
)
......@@ -1476,12 +1474,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
(layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32),
requires_grad=False,
)
# Clean up weights that won't be used by TRT-LLM
del layer.w2_weight
del layer.w2_weight_scale
del layer.w13_weight
del layer.w13_weight_scale
elif self.use_marlin:
# Marlin processing
prepare_moe_fp4_layer_for_marlin(layer)
......@@ -1530,6 +1522,24 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
w2_blockscale_swizzled, requires_grad=False
)
def prepare_dp_allgather_tensor(
    self,
    layer: FusedMoE,
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
) -> tuple[torch.Tensor, list[torch.Tensor]]:
    """Quantize ``hidden_states`` with flashinfer before the DP allgather/EP step.

    Args:
        layer: MoE layer; its ``w13_input_scale_quant`` attribute is handed
            to ``flashinfer.fp4_quantize`` as the scale argument.
        hidden_states: Tensor to quantize.
        router_logits: Accepted for interface compatibility; not read here.

    Returns:
        A pair ``(quantized, extras)`` where ``quantized`` is the first
        output of ``fp4_quantize`` and ``extras`` is a one-element list
        holding its second output (the scale factors), so that both can be
        carried through the gather together.
    """
    # flashinfer is imported at call time rather than at module import.
    from flashinfer import fp4_quantize

    quantized, scale_factors = fp4_quantize(
        hidden_states,
        layer.w13_input_scale_quant,
        is_sf_swizzled_layout=False,
    )
    return quantized, [scale_factors]
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
......@@ -1584,8 +1594,13 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
e_score_correction_bias=layer.e_score_correction_bias,
)
# Hidden_states in select_experts is only used to extract metadata
if isinstance(x, tuple):
x_routing, _ = x
else:
x_routing = x
topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
hidden_states=x_routing,
router_logits=router_logits,
)
......
......@@ -95,12 +95,12 @@ def get_mxfp4_backend_with_lora() -> Mxfp4Backend:
# SM120 needs this fix: https://github.com/triton-lang/triton/pull/8498
and (9, 0) <= current_platform.get_device_capability() < (11, 0)
)
if envs.VLLM_MXFP4_USE_MARLIN or not triton_kernels_supported:
logger.info_once("[get_mxfp4_backend_with_lora] Using Marlin backend")
return Mxfp4Backend.MARLIN
if envs.VLLM_MXFP4_USE_MARLIN is False and triton_kernels_supported:
logger.info_once("[get_mxfp4_backend_with_lora] Using Triton backend")
return Mxfp4Backend.TRITON
logger.info_once("[get_mxfp4_backend_with_lora] Using Triton backend")
return Mxfp4Backend.TRITON
logger.info_once("[get_mxfp4_backend_with_lora] Using Marlin backend")
return Mxfp4Backend.MARLIN
def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend:
......
......@@ -218,6 +218,49 @@ class QuarkConfig(QuantizationConfig):
else:
return False
def _is_fp8_w4a8(
self,
weight_quant: list[dict[str, Any]] | None,
input_quant: dict[str, Any] | None,
) -> bool:
# Confirm weights and input quantized.
if weight_quant is None or input_quant is None:
return False
if not isinstance(weight_quant, list) or len(weight_quant) != 2:
return False
# Confirm weight scheme is supported
is_w4a8_dtype = (
weight_quant[0].get("dtype") == "fp8_e4m3"
and weight_quant[1].get("dtype") == "int4"
and input_quant.get("dtype") == "fp8_e4m3"
)
is_static_weight = not weight_quant[0].get("is_dynamic") and not weight_quant[
1
].get("is_dynamic")
is_per_tensor_fp8_and_per_channel_int4_weight = (
weight_quant[0].get("qscheme") == "per_tensor"
and weight_quant[1].get("qscheme") == "per_channel"
and weight_quant[1].get("symmetric") is True
and weight_quant[1].get("ch_axis") == 0
)
if not (
is_w4a8_dtype
and is_static_weight
and is_per_tensor_fp8_and_per_channel_int4_weight
):
return False
# Dynamic quantization is always supported if weights supported.
if input_quant.get("is_dynamic"):
return True
# Confirm activation scheme is supported.
is_per_tensor_activation = input_quant.get("qscheme") == "per_tensor"
return is_per_tensor_activation
def _is_fp8_w8a8(
self,
weight_quant: dict[str, Any] | None,
......
......@@ -63,8 +63,9 @@ class QuarkMoEMethod(FusedMoEMethodBase):
)
weight_config = layer_quant_config.get("weight")
input_config = layer_quant_config.get("input_tensors")
if quant_config._is_fp8_w8a8(weight_config, input_config):
if quant_config._is_fp8_w4a8(weight_config, input_config):
return QuarkW4A8Fp8MoEMethod(weight_config, input_config, module.moe_config)
elif quant_config._is_fp8_w8a8(weight_config, input_config):
return QuarkW8A8Fp8MoEMethod(weight_config, input_config, module.moe_config)
elif quant_config._is_ocp_mx(weight_config, input_config):
return QuarkOCP_MX_MoEMethod(weight_config, input_config, module.moe_config)
......@@ -396,6 +397,161 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
)
# MoE quantization method for the Quark "W4A8 FP8" scheme. The dispatcher in
# QuarkMoEMethod returns this class when QuarkConfig._is_fp8_w4a8(...) accepts
# the layer's weight/input config.
#
# Per layer it registers:
#   * w13_weight / w2_weight            - packed weights stored as torch.uint32
#                                         (last dim is divided by 8),
#   * w13_weight_scale / w2_weight_scale     - the "per-tensor fp8" scales,
#   * w13_weight_scale_2 / w2_weight_scale_2 - the "per-channel int4" scales.
# After loading, the two scale sets are folded into the *_scale_2 tensors,
# which are the only scales handed to the fused-MoE quant config.
#
# NOTE(review): the code below is a flattened paste (indentation lost); the
# comments describe the nesting that the statements imply.
class QuarkW4A8Fp8MoEMethod(QuarkMoEMethod):
# Stores the weight/input quant configs and requires the ROCm AITER fused-MoE
# path to be enabled.
# NOTE(review): weight_config is annotated as dict[str, Any], but
# _is_fp8_w4a8 only accepts a 2-element list of dicts, so the annotation
# looks stale -- confirm against the caller.
# NOTE(review): the requirement is enforced with `assert`, which Python drops
# under -O.
def __init__(
self,
weight_config: dict[str, Any],
input_config: dict[str, Any],
moe: FusedMoEConfig,
):
super().__init__(moe)
self.weight_quant = weight_config
self.input_quant = input_config
assert rocm_aiter_ops.is_fused_moe_enabled(), (
"W4A8 FP8 MoE requires ROCm AITER fused MoE support."
)
# Allocates and registers the packed weights and both sets of weight scales
# on `layer`.
#   layer: module that receives the parameters via register_parameter.
#   num_experts / hidden_size / intermediate_size_per_partition: sizes used
#     for the tensor shapes below.
#   params_dtype: ignored -- it is overwritten with torch.uint32 on the first
#     line of the body.
#   extra_weight_attrs: attributes applied to every parameter through
#     set_weight_attrs. The dict is updated in place between calls, so the
#     weights get the caller's attrs as passed, the first scale pair is tagged
#     with quant_method=TENSOR and the second pair with quant_method=CHANNEL
#     (presumably set_weight_attrs reads the dict at call time -- confirm).
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
# Packed weights are always allocated as 32-bit words, whatever dtype the
# caller asked for.
params_dtype = torch.uint32
# w13_weight: (num_experts, 2 * intermediate_size_per_partition,
#              hidden_size // 8)
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_size // 8, # INT32 packing for W4
dtype=params_dtype,
),
requires_grad=False,
)
# w2_weight: (num_experts, hidden_size,
#             intermediate_size_per_partition // 8)
w2_weight = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition // 8, # INT32 packing for W4
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w13_weight", w13_weight)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
set_weight_attrs(w2_weight, extra_weight_attrs)
# Per-tensor fp8 weight scales, initialised to 1.0:
# w13 has two values per expert (one per shard; see the shard_id loop in
# process_weights_after_loading), w2 has one value per expert.
w13_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
)
w2_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
)
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
# Per-channel int4 weight scales, initialised to 1.0: one value per output
# row of the corresponding weight (2 * intermediate_size_per_partition rows
# for w13, hidden_size rows for w2).
w13_weight_scale_2 = torch.nn.Parameter(
torch.ones(
num_experts,
2 * intermediate_size_per_partition,
dtype=torch.float32,
),
requires_grad=False,
)
w2_weight_scale_2 = torch.nn.Parameter(
torch.ones(num_experts, hidden_size, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)
layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2)
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
)
set_weight_attrs(w13_weight_scale_2, extra_weight_attrs)
set_weight_attrs(w2_weight_scale_2, extra_weight_attrs)
# Post-load fix-up: shuffles the packed weights for the AITER kernel and folds
# the per-tensor scales into the per-channel (*_scale_2) scales in place.
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# Re-lay-out both weights with rocm_aiter_ops.shuffle_weights and re-wrap
# the results as non-trainable Parameters.
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight.data, layer.w2_weight.data
)
layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
# INT4-FP8: collapse the two per-shard FP8 scales of w13 into a single
# per-expert scale (the maximum of the two).
# The FP8 MoE kernel needs one FP8 w13 scale per expert.
# The FP8 weights of each expert are not requantized (they are not directly
# available); instead, for a shard whose scale differs from the maximum, that
# shard's slice of the per-channel INT4 scale (w13_weight_scale_2, called
# "w13_weight_scale1" in the original comment) is multiplied by
# shard_scale / max_scale.
shard_size = layer.intermediate_size_per_partition
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
assert torch.all(max_w13_scales != 0), "fp8 weight scale cannot be zero."
for expert_id in range(layer.local_num_experts):
# `start` is the row offset of the current shard inside w13_weight_scale_2.
start = 0
max_w13_scale_fp8 = max_w13_scales[expert_id]
for shard_id in range(2):
if layer.w13_weight_scale[expert_id][shard_id] != max_w13_scale_fp8:
int4_rescale = (
layer.w13_weight_scale[expert_id][shard_id] / max_w13_scale_fp8
)
layer.w13_weight_scale_2[expert_id][start : start + shard_size] *= (
int4_rescale
)
# NOTE(review): indentation was lost; this increment presumably runs once per
# shard (outside the `if`), otherwise the second shard's slice would start at
# the wrong row whenever the first shard already holds the maximum -- confirm.
start += shard_size
# From here on w13_weight_scale holds one value per expert (the maximum).
layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False)
# Workaround for asm_moe, which applies (per-channel scale * per-tensor scale)
# as one post-GEMM scaling, so the product is precomputed into the *_scale_2
# tensors here. The ideal design would apply the per-column scale before the
# GEMM and the per-tensor scale after it.
for expert_id in range(layer.local_num_experts):
layer.w13_weight_scale_2[expert_id] *= max_w13_scales[expert_id]
layer.w2_weight_scale_2[expert_id] *= layer.w2_weight_scale[expert_id]
# Builds the fused-MoE quant config through fp8_w8a8_moe_quant_config, passing
# only the folded per-channel scales and per_out_ch_quant=True.
def get_fused_moe_quant_config(self, layer):
return fp8_w8a8_moe_quant_config(
w1_scale=layer.w13_weight_scale_2,
w2_scale=layer.w2_weight_scale_2,
per_out_ch_quant=True,
)
# Runs the MoE forward pass: picks experts with layer.select_experts, then
# returns the result of rocm_aiter_fused_experts on the shuffled weights.
#   layer: the FusedMoE layer providing weights, routing and activation.
#   x: hidden states, passed both to expert selection and to the kernel.
#   router_logits: routing logits passed to select_experts.
def apply(
self,
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
# The third value returned by select_experts is not used here.
topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
)
# The AITER kernel wrapper is imported at call time, not at module import.
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
rocm_aiter_fused_experts,
)
# self.moe_quant_config is not assigned anywhere in this class; presumably it
# is populated elsewhere from get_fused_moe_quant_config -- confirm.
return rocm_aiter_fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=layer.activation,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
quant_config=self.moe_quant_config,
expert_map=layer.expert_map,
)
class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
def __init__(
self,
......
......@@ -238,7 +238,7 @@ def prepare_static_weights_for_trtllm_fp4_moe(
def flashinfer_trtllm_fp4_moe(
layer: torch.nn.Module,
x: torch.Tensor,
x: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
router_logits: torch.Tensor,
top_k: int,
global_num_experts: int,
......@@ -269,12 +269,16 @@ def flashinfer_trtllm_fp4_moe(
from vllm.model_executor.models.llama4 import Llama4MoE
# Quantize input to FP4
a1_gscale = layer.w13_input_scale_quant
(hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
x,
a1_gscale,
is_sf_swizzled_layout=False,
)
if isinstance(x, tuple):
hidden_states_fp4, hidden_states_scale_linear_fp4 = x
else:
# x is not quantized yet: quantize the input to FP4 here
a1_gscale = layer.w13_input_scale_quant
(hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
x,
a1_gscale,
is_sf_swizzled_layout=False,
)
# Determine routing method type
use_llama4_routing = custom_routing_function is Llama4MoE.custom_routing_function
......@@ -301,18 +305,14 @@ def flashinfer_trtllm_fp4_moe(
hidden_states_scale=hidden_states_scale_linear_fp4.view(
torch.float8_e4m3fn
).flatten(),
gemm1_weights=layer.gemm1_weights_fp4_shuffled.data,
gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.data.view(
torch.float8_e4m3fn
),
gemm1_weights=layer.w13_weight.data,
gemm1_weights_scale=layer.w13_weight_scale.data.view(torch.float8_e4m3fn),
gemm1_bias=None,
gemm1_alpha=None,
gemm1_beta=None,
gemm1_clamp_limit=None,
gemm2_weights=layer.gemm2_weights_fp4_shuffled.data,
gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.data.view(
torch.float8_e4m3fn
),
gemm2_weights=layer.w2_weight.data,
gemm2_weights_scale=layer.w2_weight_scale.data.view(torch.float8_e4m3fn),
gemm2_bias=None,
output1_scale_scalar=layer.g1_scale_c.data,
output1_scale_gate_scalar=layer.g1_alphas.data,
......@@ -364,13 +364,17 @@ def flashinfer_trtllm_fp4_routed_moe(
torch.bfloat16
).view(torch.int16)
# Quantize input to FP4
a1_gscale = layer.w13_input_scale_quant
(hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
x,
a1_gscale,
is_sf_swizzled_layout=False,
)
if isinstance(x, tuple):
# hidden_states is already quantized: x is (fp4 tensor, scale factors)
hidden_states_fp4, hidden_states_scale_linear_fp4 = x
else:
# Quantize input to FP4
a1_gscale = layer.w13_input_scale_quant
(hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
x,
a1_gscale,
is_sf_swizzled_layout=False,
)
# Call TRT-LLM FP4 block-scale MoE kernel
out = flashinfer.fused_moe.trtllm_fp4_block_scale_routed_moe(
......@@ -380,18 +384,14 @@ def flashinfer_trtllm_fp4_routed_moe(
hidden_states_scale=hidden_states_scale_linear_fp4.view(
torch.float8_e4m3fn
).flatten(),
gemm1_weights=layer.gemm1_weights_fp4_shuffled.data,
gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.data.view(
torch.float8_e4m3fn
),
gemm1_weights=layer.w13_weight.data,
gemm1_weights_scale=layer.w13_weight_scale.data.view(torch.float8_e4m3fn),
gemm1_bias=None,
gemm1_alpha=None,
gemm1_beta=None,
gemm1_clamp_limit=None,
gemm2_weights=layer.gemm2_weights_fp4_shuffled.data,
gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.data.view(
torch.float8_e4m3fn
),
gemm2_weights=layer.w2_weight.data,
gemm2_weights_scale=layer.w2_weight_scale.data.view(torch.float8_e4m3fn),
gemm2_bias=None,
output1_scale_scalar=layer.g1_scale_c.data,
output1_scale_gate_scalar=layer.g1_alphas.data,
......
......@@ -1437,14 +1437,17 @@ def maybe_post_process_fp8_weight_block(layer: torch.nn.Module):
layer.orig_dtype, layer.weight
)
if should_use_deepgemm:
scale_attr = (
"weight_scale_inv" if hasattr(layer, "weight_scale_inv") else "weight_scale"
)
dg_weight, dg_weight_scale = deepgemm_post_process_fp8_weight_block(
wq=layer.weight.data,
ws=layer.weight_scale_inv.data,
ws=getattr(layer, scale_attr).data,
quant_block_shape=tuple(layer.weight_block_size),
use_e8m0=is_deep_gemm_e8m0_used(),
)
replace_parameter(layer, "weight", dg_weight)
replace_parameter(layer, "weight_scale_inv", dg_weight_scale)
replace_parameter(layer, scale_attr, dg_weight_scale)
def expert_weight_is_col_major(x: torch.Tensor) -> bool:
......
......@@ -38,7 +38,10 @@ class RotaryEmbeddingBase(CustomOp):
# and current_platform.is_cuda()
# and has_flashinfer()
# and self.head_size in [64, 128, 256, 512])
self.use_flashinfer = False
# Check if use_flashinfer is already set
if not hasattr(self, "use_flashinfer"):
self.use_flashinfer = False
cache = self._compute_cos_sin_cache()
if not self.use_flashinfer:
......
......@@ -6,6 +6,7 @@ import math
import torch
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer
from .base import RotaryEmbeddingBase
from .common import (
......@@ -56,6 +57,13 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbeddingBase):
/ yarn_get_mscale(self.scaling_factor, float(mscale_all_dim))
* attn_factor
)
self.use_flashinfer = (
self.enabled()
and dtype in (torch.float16, torch.bfloat16)
and current_platform.is_cuda()
and has_flashinfer()
and head_size in [64, 128, 256, 512]
)
super().__init__(
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
)
......@@ -162,4 +170,15 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbeddingBase):
key: torch.Tensor | None = None,
offsets: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
return self.forward_native(positions, query, key, offsets)
if self.use_flashinfer:
torch.ops.vllm.flashinfer_rotary_embedding(
torch.add(positions, offsets) if offsets is not None else positions,
query,
key,
self.head_size,
self.cos_sin_cache,
self.is_neox_style,
)
return query, key
else:
return self.forward_native(positions, query, key, offsets)
......@@ -23,6 +23,7 @@ import torch
from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
from safetensors.torch import load, load_file, safe_open, save_file
from tqdm.auto import tqdm
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
from vllm import envs
from vllm.config import ModelConfig
......@@ -448,12 +449,31 @@ def download_weights_from_hf(
fs = HfFileSystem()
file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
# Use the first pattern found in the HF repo's files.
for pattern in allow_patterns:
matching = fnmatch.filter(file_list, pattern)
if len(matching) > 0:
allow_patterns = [pattern]
break
# If downloading safetensors and an index file exists, use the
# specific file names from the index to avoid downloading
# unnecessary files (e.g., from subdirectories like "original/").
index_file = f"{model_name_or_path}/{SAFE_WEIGHTS_INDEX_NAME}"
if "*.safetensors" in allow_patterns and index_file in file_list:
index_path = hf_hub_download(
repo_id=model_name_or_path,
filename=SAFE_WEIGHTS_INDEX_NAME,
cache_dir=cache_dir,
revision=revision,
)
with open(index_path) as f:
weight_map = json.load(f)["weight_map"]
if weight_map:
# Extra [] so that weight_map files are treated as a
# single allow_pattern in the loop below
allow_patterns = [list(set(weight_map.values()))] # type: ignore[list-item]
else:
allow_patterns = ["*.safetensors"]
else:
# Use the first pattern found in the HF repo's files.
for pattern in allow_patterns:
if fnmatch.filter(file_list, pattern):
allow_patterns = [pattern]
break
except Exception as e:
logger.warning(
"Failed to get file list for '%s'. Trying each pattern in "
......@@ -480,6 +500,9 @@ def download_weights_from_hf(
)
# If we have downloaded weights for this allow_pattern,
# we don't need to check the rest.
# allow_pattern can be a list (from weight_map) or str (glob)
if isinstance(allow_pattern, list):
break
if any(Path(hf_folder).glob(allow_pattern)):
break
time_taken = time.perf_counter() - start_time
......
......@@ -8,7 +8,7 @@ from collections.abc import Iterable
import torch
import torch.nn as nn
from vllm.attention.layer import MultiHeadAttention
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.utils import divide
from vllm.model_executor.layers.activation import SiluAndMul
......@@ -126,7 +126,7 @@ class AIMv2Attention(nn.Module):
self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
self.attn = MultiHeadAttention(
self.attn = MMEncoderAttention(
self.num_heads_per_partition, self.head_dim, self.scale
)
......
......@@ -55,7 +55,9 @@ class BertEmbedding(nn.Module):
"position_ids",
torch.arange(config.max_position_embeddings).unsqueeze(0),
)
self.position_embedding_type = config.position_embedding_type
self.position_embedding_type = getattr(
config, "position_embedding_type", "absolute"
)
if self.position_embedding_type != "absolute":
raise ValueError(
"Only 'absolute' position_embedding_type" + " is supported"
......
......@@ -9,7 +9,7 @@ import torch
import torch.nn as nn
from transformers import Blip2VisionConfig, BlipVisionConfig
from vllm.attention.layer import MultiHeadAttention
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.conv import Conv2dLayer
......@@ -122,7 +122,7 @@ class BlipAttention(nn.Module):
self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
self.attn = MultiHeadAttention(
self.attn = MMEncoderAttention(
self.num_heads_per_partition, self.head_dim, self.scale
)
......
......@@ -14,7 +14,8 @@ from transformers import (
CLIPVisionConfig,
)
from vllm.attention.layer import Attention, MultiHeadAttention
from vllm.attention.layer import Attention
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import divide, get_tensor_model_parallel_world_size
......@@ -354,7 +355,7 @@ class CLIPAttention(nn.Module):
quant_config: QuantizationConfig | None = None,
*,
prefix: str = "",
attn_cls: type[Attention] | type[MultiHeadAttention],
attn_cls: type[Attention] | type[MMEncoderAttention],
) -> None:
super().__init__()
......@@ -449,7 +450,7 @@ class CLIPEncoderLayer(nn.Module):
quant_config: QuantizationConfig | None = None,
*,
prefix: str = "",
attn_cls: type[Attention] | type[MultiHeadAttention],
attn_cls: type[Attention] | type[MMEncoderAttention],
) -> None:
super().__init__()
self.self_attn = CLIPAttention(
......@@ -493,7 +494,7 @@ class CLIPEncoder(nn.Module):
num_hidden_layers_override: int | None = None,
*,
prefix: str = "",
attn_cls: type[Attention] | type[MultiHeadAttention],
attn_cls: type[Attention] | type[MMEncoderAttention],
) -> None:
super().__init__()
......@@ -638,7 +639,7 @@ class CLIPVisionTransformer(nn.Module):
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override,
prefix=f"{prefix}.encoder",
attn_cls=MultiHeadAttention,
attn_cls=MMEncoderAttention,
)
num_hidden_layers = config.num_hidden_layers
......
......@@ -308,12 +308,6 @@ class MambaModelConfig(VerifyAndUpdateConfig):
if cache_config.mamba_block_size is None:
cache_config.mamba_block_size = model_config.max_model_len
# TODO(tdoublep): remove once cascade attention is supported
logger.info(
"Disabling cascade attention since it is not supported for hybrid models."
)
model_config.disable_cascade_attn = True
class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
@classmethod
......
......@@ -18,7 +18,7 @@ import torch.nn as nn
import torch.nn.functional as F
from transformers import CLIPVisionConfig
from vllm.attention.layer import MultiHeadAttention
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
......@@ -628,7 +628,7 @@ class DeepCLIPVisionTransformer(nn.Module):
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override,
prefix=f"{prefix}.encoder",
attn_cls=MultiHeadAttention,
attn_cls=MMEncoderAttention,
)
num_hidden_layers = config.num_hidden_layers
......
......@@ -141,6 +141,7 @@ class DeepSeekMultiTokenPredictor(nn.Module):
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
prefix=maybe_prefix(prefix, "embed_tokens"),
)
self.logits_processor = LogitsProcessor(config.vocab_size)
......
......@@ -837,7 +837,11 @@ class Indexer(nn.Module):
)
self.k_norm = LayerNorm(self.head_dim, eps=1e-6)
self.weights_proj = ReplicatedLinear(
hidden_size, self.n_head, quant_config=None, prefix=f"{prefix}.weights_proj"
hidden_size,
self.n_head,
bias=False,
quant_config=None,
prefix=f"{prefix}.weights_proj",
)
self.softmax_scale = self.head_dim**-0.5
......
......@@ -38,7 +38,10 @@ from vllm.model_executor.layers.linear import (
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
......@@ -463,12 +466,20 @@ class Gemma3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
super().__init__()
self.config = config
# currently all existing Gemma models have `tie_word_embeddings` enabled
assert config.tie_word_embeddings
self.quant_config = quant_config
self.model = Gemma3Model(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
)
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
if config.tie_word_embeddings:
self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)
self.logits_processor = LogitsProcessor(
config.vocab_size, soft_cap=config.final_logit_softcapping
)
......@@ -496,7 +507,7 @@ class Gemma3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self,
hidden_states: torch.Tensor,
) -> torch.Tensor | None:
logits = self.logits_processor(self.model.embed_tokens, hidden_states)
logits = self.logits_processor(self.lm_head, hidden_states)
return logits
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
......
......@@ -19,7 +19,7 @@ from transformers import BatchFeature, PreTrainedTokenizer, TensorType
from transformers.image_utils import ImageInput
from transformers.tokenization_utils_base import TextInput
from vllm.attention.layer import MultiHeadAttention
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size
......@@ -135,7 +135,7 @@ class EVA2CLIPAttention(nn.Module):
prefix=f"{prefix}.dense",
)
self.attn = MultiHeadAttention(
self.attn = MMEncoderAttention(
self.num_heads_per_rank, self.head_dim, self.scale
)
self.output_dropout = torch.nn.Dropout(config.dropout_prob)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment