Commit a810671a authored by zhuwenwen
Browse files

Merge tag 'v0.14.0rc0' into v0.14.0rc0-ori

parents 86b5aefe 6a09612b
...@@ -181,7 +181,7 @@ class GPTQMarlinConfig(QuantizationConfig): ...@@ -181,7 +181,7 @@ class GPTQMarlinConfig(QuantizationConfig):
@classmethod @classmethod
def get_min_capability(cls) -> int: def get_min_capability(cls) -> int:
return 80 return 75
@classmethod @classmethod
def get_config_filenames(cls) -> list[str]: def get_config_filenames(cls) -> list[str]:
......
...@@ -871,7 +871,7 @@ class ModelOptNvFp4Config(ModelOptQuantConfigBase): ...@@ -871,7 +871,7 @@ class ModelOptNvFp4Config(ModelOptQuantConfigBase):
@classmethod @classmethod
def get_min_capability(cls) -> int: def get_min_capability(cls) -> int:
return 80 return 75
@classmethod @classmethod
def override_quantization_method( def override_quantization_method(
...@@ -1458,16 +1458,14 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1458,16 +1458,14 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
) )
logger.debug_once("Finished shuffling weights for TRT-LLM MOE") logger.debug_once("Finished shuffling weights for TRT-LLM MOE")
layer.gemm1_weights_fp4_shuffled = Parameter( layer.w13_weight = Parameter(
gemm1_weights_fp4_shuffled, requires_grad=False gemm1_weights_fp4_shuffled, requires_grad=False
) )
layer.gemm2_weights_fp4_shuffled = Parameter( layer.w2_weight = Parameter(gemm2_weights_fp4_shuffled, requires_grad=False)
gemm2_weights_fp4_shuffled, requires_grad=False layer.w13_weight_scale = Parameter(
)
layer.gemm1_scales_fp4_shuffled = Parameter(
gemm1_scales_fp4_shuffled, requires_grad=False gemm1_scales_fp4_shuffled, requires_grad=False
) )
layer.gemm2_scales_fp4_shuffled = Parameter( layer.w2_weight_scale = Parameter(
gemm2_scales_fp4_shuffled, requires_grad=False gemm2_scales_fp4_shuffled, requires_grad=False
) )
...@@ -1476,12 +1474,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1476,12 +1474,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
(layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32), (layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32),
requires_grad=False, requires_grad=False,
) )
# Clean up weights that won't be used by TRT-LLM
del layer.w2_weight
del layer.w2_weight_scale
del layer.w13_weight
del layer.w13_weight_scale
elif self.use_marlin: elif self.use_marlin:
# Marlin processing # Marlin processing
prepare_moe_fp4_layer_for_marlin(layer) prepare_moe_fp4_layer_for_marlin(layer)
...@@ -1530,6 +1522,24 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1530,6 +1522,24 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
w2_blockscale_swizzled, requires_grad=False w2_blockscale_swizzled, requires_grad=False
) )
def prepare_dp_allgather_tensor(
    self,
    layer: FusedMoE,
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
) -> tuple[torch.Tensor, list[torch.Tensor]]:
    """Quantize activations to FP4 before they go through DP allgather/EP.

    Args:
        layer: MoE layer; its ``w13_input_scale_quant`` is passed to the
            quantizer as the scale argument.
        hidden_states: Activations to quantize.
        router_logits: Accepted for interface compatibility; not read here.

    Returns:
        A pair ``(hidden_states_fp4, extra_tensors)`` where ``extra_tensors``
        is a one-element list holding the scale factors produced by
        ``flashinfer.fp4_quantize``.
    """
    # Imported at call time, exactly as the original does.
    import flashinfer

    # Scale factors are requested in the non-swizzled layout.
    quantized, scale_factors = flashinfer.fp4_quantize(
        hidden_states,
        layer.w13_input_scale_quant,
        is_sf_swizzled_layout=False,
    )
    return quantized, [scale_factors]
def get_fused_moe_quant_config( def get_fused_moe_quant_config(
self, layer: torch.nn.Module self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None: ) -> FusedMoEQuantConfig | None:
...@@ -1584,8 +1594,13 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1584,8 +1594,13 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
e_score_correction_bias=layer.e_score_correction_bias, e_score_correction_bias=layer.e_score_correction_bias,
) )
# Hidden_states in select_experts is only used to extract metadata
if isinstance(x, tuple):
x_routing, _ = x
else:
x_routing = x
topk_weights, topk_ids, _ = layer.select_experts( topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x, hidden_states=x_routing,
router_logits=router_logits, router_logits=router_logits,
) )
......
...@@ -95,12 +95,12 @@ def get_mxfp4_backend_with_lora() -> Mxfp4Backend: ...@@ -95,12 +95,12 @@ def get_mxfp4_backend_with_lora() -> Mxfp4Backend:
# SM120 needs this fix: https://github.com/triton-lang/triton/pull/8498 # SM120 needs this fix: https://github.com/triton-lang/triton/pull/8498
and (9, 0) <= current_platform.get_device_capability() < (11, 0) and (9, 0) <= current_platform.get_device_capability() < (11, 0)
) )
if envs.VLLM_MXFP4_USE_MARLIN or not triton_kernels_supported: if envs.VLLM_MXFP4_USE_MARLIN is False and triton_kernels_supported:
logger.info_once("[get_mxfp4_backend_with_lora] Using Marlin backend") logger.info_once("[get_mxfp4_backend_with_lora] Using Triton backend")
return Mxfp4Backend.MARLIN return Mxfp4Backend.TRITON
logger.info_once("[get_mxfp4_backend_with_lora] Using Triton backend") logger.info_once("[get_mxfp4_backend_with_lora] Using Marlin backend")
return Mxfp4Backend.TRITON return Mxfp4Backend.MARLIN
def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend: def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend:
......
...@@ -218,6 +218,49 @@ class QuarkConfig(QuantizationConfig): ...@@ -218,6 +218,49 @@ class QuarkConfig(QuantizationConfig):
else: else:
return False return False
def _is_fp8_w4a8(
self,
weight_quant: list[dict[str, Any]] | None,
input_quant: dict[str, Any] | None,
) -> bool:
# Confirm weights and input quantized.
if weight_quant is None or input_quant is None:
return False
if not isinstance(weight_quant, list) or len(weight_quant) != 2:
return False
# Confirm weight scheme is supported
is_w4a8_dtype = (
weight_quant[0].get("dtype") == "fp8_e4m3"
and weight_quant[1].get("dtype") == "int4"
and input_quant.get("dtype") == "fp8_e4m3"
)
is_static_weight = not weight_quant[0].get("is_dynamic") and not weight_quant[
1
].get("is_dynamic")
is_per_tensor_fp8_and_per_channel_int4_weight = (
weight_quant[0].get("qscheme") == "per_tensor"
and weight_quant[1].get("qscheme") == "per_channel"
and weight_quant[1].get("symmetric") is True
and weight_quant[1].get("ch_axis") == 0
)
if not (
is_w4a8_dtype
and is_static_weight
and is_per_tensor_fp8_and_per_channel_int4_weight
):
return False
# Dynamic quantization is always supported if weights supported.
if input_quant.get("is_dynamic"):
return True
# Confirm activation scheme is supported.
is_per_tensor_activation = input_quant.get("qscheme") == "per_tensor"
return is_per_tensor_activation
def _is_fp8_w8a8( def _is_fp8_w8a8(
self, self,
weight_quant: dict[str, Any] | None, weight_quant: dict[str, Any] | None,
......
...@@ -63,8 +63,9 @@ class QuarkMoEMethod(FusedMoEMethodBase): ...@@ -63,8 +63,9 @@ class QuarkMoEMethod(FusedMoEMethodBase):
) )
weight_config = layer_quant_config.get("weight") weight_config = layer_quant_config.get("weight")
input_config = layer_quant_config.get("input_tensors") input_config = layer_quant_config.get("input_tensors")
if quant_config._is_fp8_w4a8(weight_config, input_config):
if quant_config._is_fp8_w8a8(weight_config, input_config): return QuarkW4A8Fp8MoEMethod(weight_config, input_config, module.moe_config)
elif quant_config._is_fp8_w8a8(weight_config, input_config):
return QuarkW8A8Fp8MoEMethod(weight_config, input_config, module.moe_config) return QuarkW8A8Fp8MoEMethod(weight_config, input_config, module.moe_config)
elif quant_config._is_ocp_mx(weight_config, input_config): elif quant_config._is_ocp_mx(weight_config, input_config):
return QuarkOCP_MX_MoEMethod(weight_config, input_config, module.moe_config) return QuarkOCP_MX_MoEMethod(weight_config, input_config, module.moe_config)
...@@ -396,6 +397,161 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): ...@@ -396,6 +397,161 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
) )
class QuarkW4A8Fp8MoEMethod(QuarkMoEMethod):
    """Fused-MoE method for Quark W4A8 checkpoints (INT4 weights, FP8 path).

    Expert weights are stored as packed 32-bit words. Each projection carries
    two scale tensors: a per-tensor FP8 scale (``*_weight_scale``) and a
    per-channel INT4 scale (``*_weight_scale_2``). The forward pass runs
    through ``rocm_aiter_fused_experts``.
    """

    def __init__(
        self,
        weight_config: dict[str, Any],
        input_config: dict[str, Any],
        moe: FusedMoEConfig,
    ):
        """Store the quantization configs and require AITER fused MoE.

        Raises:
            AssertionError: if ROCm AITER fused MoE is not enabled.
        """
        super().__init__(moe)
        self.weight_quant = weight_config
        self.input_quant = input_config

        assert rocm_aiter_ops.is_fused_moe_enabled(), (
            "W4A8 FP8 MoE requires ROCm AITER fused MoE support."
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register packed weights and both scale tensors on ``layer``.

        The incoming ``params_dtype`` is not used: packed weights are always
        allocated as ``torch.uint32``, and scales as ``torch.float32``.
        """
        packed_dtype = torch.uint32
        gate_up_rows = 2 * intermediate_size_per_partition

        def _register(name: str, data: torch.Tensor) -> None:
            # Freeze the tensor, attach it to the layer, and tag it with
            # whatever loader attributes are current at this point.
            param = torch.nn.Parameter(data, requires_grad=False)
            layer.register_parameter(name, param)
            set_weight_attrs(param, extra_weight_attrs)

        # Packed INT4 weights: eight 4-bit values per 32-bit word, hence the
        # innermost dimension is divided by 8.
        _register(
            "w13_weight",
            torch.empty(
                num_experts, gate_up_rows, hidden_size // 8, dtype=packed_dtype
            ),
        )
        _register(
            "w2_weight",
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition // 8,
                dtype=packed_dtype,
            ),
        )

        # Per-tensor FP8 weight scales (two per expert for w13, one for w2).
        extra_weight_attrs["quant_method"] = FusedMoeWeightScaleSupported.TENSOR.value
        _register("w13_weight_scale", torch.ones(num_experts, 2, dtype=torch.float32))
        _register("w2_weight_scale", torch.ones(num_experts, dtype=torch.float32))

        # Per-channel INT4 weight scales (one per output row).
        extra_weight_attrs["quant_method"] = FusedMoeWeightScaleSupported.CHANNEL.value
        _register(
            "w13_weight_scale_2",
            torch.ones(num_experts, gate_up_rows, dtype=torch.float32),
        )
        _register(
            "w2_weight_scale_2",
            torch.ones(num_experts, hidden_size, dtype=torch.float32),
        )

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Shuffle weights for AITER and fold the FP8 scales into INT4 scales."""
        w13_shuffled, w2_shuffled = rocm_aiter_ops.shuffle_weights(
            layer.w13_weight.data, layer.w2_weight.data
        )
        layer.w13_weight = torch.nn.Parameter(w13_shuffled, requires_grad=False)
        layer.w2_weight = torch.nn.Parameter(w2_shuffled, requires_grad=False)

        # The FP8 MoE kernel needs a single per-expert FP8 scale for w13, but
        # the checkpoint provides one per shard. Rather than requantizing the
        # FP8 weights (not directly available), keep the larger of the two
        # scales and compensate the other shard's per-channel INT4 scales.
        shard_size = layer.intermediate_size_per_partition
        max_w13_scales = layer.w13_weight_scale.max(dim=1).values
        assert torch.all(max_w13_scales != 0), "fp8 weight scale cannot be zero."

        for expert_id in range(layer.local_num_experts):
            expert_max = max_w13_scales[expert_id]
            for shard_id in range(2):
                shard_scale = layer.w13_weight_scale[expert_id][shard_id]
                if shard_scale != expert_max:
                    lo = shard_id * shard_size
                    layer.w13_weight_scale_2[expert_id][lo : lo + shard_size] *= (
                        shard_scale / expert_max
                    )

        layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False)

        # Workaround for asm_moe: it applies (per-channel scale * per-tensor
        # scale) as a single post-GEMM scaling, so the per-tensor scale is
        # multiplied into the per-channel one here. The ideal design would
        # apply the per-channel scale before the GEMM and the per-tensor
        # scale after it.
        for expert_id in range(layer.local_num_experts):
            layer.w13_weight_scale_2[expert_id] *= max_w13_scales[expert_id]
            layer.w2_weight_scale_2[expert_id] *= layer.w2_weight_scale[expert_id]

    def get_fused_moe_quant_config(self, layer):
        """Build the quant config from the folded per-channel scales."""
        w13_scale = layer.w13_weight_scale_2
        w2_scale = layer.w2_weight_scale_2
        return fp8_w8a8_moe_quant_config(
            w1_scale=w13_scale,
            w2_scale=w2_scale,
            per_out_ch_quant=True,
        )

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Route tokens to experts and run the AITER fused-experts kernel."""
        topk_weights, topk_ids, _unused = layer.select_experts(
            hidden_states=x,
            router_logits=router_logits,
        )

        # Imported after routing, matching the original call order.
        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
            rocm_aiter_fused_experts,
        )

        kernel_kwargs = dict(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            activation=layer.activation,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
            quant_config=self.moe_quant_config,
            expert_map=layer.expert_map,
        )
        return rocm_aiter_fused_experts(**kernel_kwargs)
class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
def __init__( def __init__(
self, self,
......
...@@ -238,7 +238,7 @@ def prepare_static_weights_for_trtllm_fp4_moe( ...@@ -238,7 +238,7 @@ def prepare_static_weights_for_trtllm_fp4_moe(
def flashinfer_trtllm_fp4_moe( def flashinfer_trtllm_fp4_moe(
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int, top_k: int,
global_num_experts: int, global_num_experts: int,
...@@ -269,12 +269,16 @@ def flashinfer_trtllm_fp4_moe( ...@@ -269,12 +269,16 @@ def flashinfer_trtllm_fp4_moe(
from vllm.model_executor.models.llama4 import Llama4MoE from vllm.model_executor.models.llama4 import Llama4MoE
# Quantize input to FP4 # Quantize input to FP4
a1_gscale = layer.w13_input_scale_quant if isinstance(x, tuple):
(hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize( hidden_states_fp4, hidden_states_scale_linear_fp4 = x
x, else:
a1_gscale, # hidden_states is the already quantized
is_sf_swizzled_layout=False, a1_gscale = layer.w13_input_scale_quant
) (hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
x,
a1_gscale,
is_sf_swizzled_layout=False,
)
# Determine routing method type # Determine routing method type
use_llama4_routing = custom_routing_function is Llama4MoE.custom_routing_function use_llama4_routing = custom_routing_function is Llama4MoE.custom_routing_function
...@@ -301,18 +305,14 @@ def flashinfer_trtllm_fp4_moe( ...@@ -301,18 +305,14 @@ def flashinfer_trtllm_fp4_moe(
hidden_states_scale=hidden_states_scale_linear_fp4.view( hidden_states_scale=hidden_states_scale_linear_fp4.view(
torch.float8_e4m3fn torch.float8_e4m3fn
).flatten(), ).flatten(),
gemm1_weights=layer.gemm1_weights_fp4_shuffled.data, gemm1_weights=layer.w13_weight.data,
gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.data.view( gemm1_weights_scale=layer.w13_weight_scale.data.view(torch.float8_e4m3fn),
torch.float8_e4m3fn
),
gemm1_bias=None, gemm1_bias=None,
gemm1_alpha=None, gemm1_alpha=None,
gemm1_beta=None, gemm1_beta=None,
gemm1_clamp_limit=None, gemm1_clamp_limit=None,
gemm2_weights=layer.gemm2_weights_fp4_shuffled.data, gemm2_weights=layer.w2_weight.data,
gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.data.view( gemm2_weights_scale=layer.w2_weight_scale.data.view(torch.float8_e4m3fn),
torch.float8_e4m3fn
),
gemm2_bias=None, gemm2_bias=None,
output1_scale_scalar=layer.g1_scale_c.data, output1_scale_scalar=layer.g1_scale_c.data,
output1_scale_gate_scalar=layer.g1_alphas.data, output1_scale_gate_scalar=layer.g1_alphas.data,
...@@ -364,13 +364,17 @@ def flashinfer_trtllm_fp4_routed_moe( ...@@ -364,13 +364,17 @@ def flashinfer_trtllm_fp4_routed_moe(
torch.bfloat16 torch.bfloat16
).view(torch.int16) ).view(torch.int16)
# Quantize input to FP4 if isinstance(x, tuple):
a1_gscale = layer.w13_input_scale_quant # Hidden_states is the already quantized
(hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize( hidden_states_fp4, hidden_states_scale_linear_fp4 = x
x, else:
a1_gscale, # Quantize input to FP4
is_sf_swizzled_layout=False, a1_gscale = layer.w13_input_scale_quant
) (hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
x,
a1_gscale,
is_sf_swizzled_layout=False,
)
# Call TRT-LLM FP4 block-scale MoE kernel # Call TRT-LLM FP4 block-scale MoE kernel
out = flashinfer.fused_moe.trtllm_fp4_block_scale_routed_moe( out = flashinfer.fused_moe.trtllm_fp4_block_scale_routed_moe(
...@@ -380,18 +384,14 @@ def flashinfer_trtllm_fp4_routed_moe( ...@@ -380,18 +384,14 @@ def flashinfer_trtllm_fp4_routed_moe(
hidden_states_scale=hidden_states_scale_linear_fp4.view( hidden_states_scale=hidden_states_scale_linear_fp4.view(
torch.float8_e4m3fn torch.float8_e4m3fn
).flatten(), ).flatten(),
gemm1_weights=layer.gemm1_weights_fp4_shuffled.data, gemm1_weights=layer.w13_weight.data,
gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.data.view( gemm1_weights_scale=layer.w13_weight_scale.data.view(torch.float8_e4m3fn),
torch.float8_e4m3fn
),
gemm1_bias=None, gemm1_bias=None,
gemm1_alpha=None, gemm1_alpha=None,
gemm1_beta=None, gemm1_beta=None,
gemm1_clamp_limit=None, gemm1_clamp_limit=None,
gemm2_weights=layer.gemm2_weights_fp4_shuffled.data, gemm2_weights=layer.w2_weight.data,
gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.data.view( gemm2_weights_scale=layer.w2_weight_scale.data.view(torch.float8_e4m3fn),
torch.float8_e4m3fn
),
gemm2_bias=None, gemm2_bias=None,
output1_scale_scalar=layer.g1_scale_c.data, output1_scale_scalar=layer.g1_scale_c.data,
output1_scale_gate_scalar=layer.g1_alphas.data, output1_scale_gate_scalar=layer.g1_alphas.data,
......
...@@ -1437,14 +1437,17 @@ def maybe_post_process_fp8_weight_block(layer: torch.nn.Module): ...@@ -1437,14 +1437,17 @@ def maybe_post_process_fp8_weight_block(layer: torch.nn.Module):
layer.orig_dtype, layer.weight layer.orig_dtype, layer.weight
) )
if should_use_deepgemm: if should_use_deepgemm:
scale_attr = (
"weight_scale_inv" if hasattr(layer, "weight_scale_inv") else "weight_scale"
)
dg_weight, dg_weight_scale = deepgemm_post_process_fp8_weight_block( dg_weight, dg_weight_scale = deepgemm_post_process_fp8_weight_block(
wq=layer.weight.data, wq=layer.weight.data,
ws=layer.weight_scale_inv.data, ws=getattr(layer, scale_attr).data,
quant_block_shape=tuple(layer.weight_block_size), quant_block_shape=tuple(layer.weight_block_size),
use_e8m0=is_deep_gemm_e8m0_used(), use_e8m0=is_deep_gemm_e8m0_used(),
) )
replace_parameter(layer, "weight", dg_weight) replace_parameter(layer, "weight", dg_weight)
replace_parameter(layer, "weight_scale_inv", dg_weight_scale) replace_parameter(layer, scale_attr, dg_weight_scale)
def expert_weight_is_col_major(x: torch.Tensor) -> bool: def expert_weight_is_col_major(x: torch.Tensor) -> bool:
......
...@@ -38,7 +38,10 @@ class RotaryEmbeddingBase(CustomOp): ...@@ -38,7 +38,10 @@ class RotaryEmbeddingBase(CustomOp):
# and current_platform.is_cuda() # and current_platform.is_cuda()
# and has_flashinfer() # and has_flashinfer()
# and self.head_size in [64, 128, 256, 512]) # and self.head_size in [64, 128, 256, 512])
self.use_flashinfer = False
# Check if use_flashinfer is already set
if not hasattr(self, "use_flashinfer"):
self.use_flashinfer = False
cache = self._compute_cos_sin_cache() cache = self._compute_cos_sin_cache()
if not self.use_flashinfer: if not self.use_flashinfer:
......
...@@ -6,6 +6,7 @@ import math ...@@ -6,6 +6,7 @@ import math
import torch import torch
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer
from .base import RotaryEmbeddingBase from .base import RotaryEmbeddingBase
from .common import ( from .common import (
...@@ -56,6 +57,13 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbeddingBase): ...@@ -56,6 +57,13 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbeddingBase):
/ yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) / yarn_get_mscale(self.scaling_factor, float(mscale_all_dim))
* attn_factor * attn_factor
) )
self.use_flashinfer = (
self.enabled()
and dtype in (torch.float16, torch.bfloat16)
and current_platform.is_cuda()
and has_flashinfer()
and head_size in [64, 128, 256, 512]
)
super().__init__( super().__init__(
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
) )
...@@ -162,4 +170,15 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbeddingBase): ...@@ -162,4 +170,15 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbeddingBase):
key: torch.Tensor | None = None, key: torch.Tensor | None = None,
offsets: torch.Tensor | None = None, offsets: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]: ) -> tuple[torch.Tensor, torch.Tensor | None]:
return self.forward_native(positions, query, key, offsets) if self.use_flashinfer:
torch.ops.vllm.flashinfer_rotary_embedding(
torch.add(positions, offsets) if offsets is not None else positions,
query,
key,
self.head_size,
self.cos_sin_cache,
self.is_neox_style,
)
return query, key
else:
return self.forward_native(positions, query, key, offsets)
...@@ -23,6 +23,7 @@ import torch ...@@ -23,6 +23,7 @@ import torch
from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
from safetensors.torch import load, load_file, safe_open, save_file from safetensors.torch import load, load_file, safe_open, save_file
from tqdm.auto import tqdm from tqdm.auto import tqdm
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
from vllm import envs from vllm import envs
from vllm.config import ModelConfig from vllm.config import ModelConfig
...@@ -448,12 +449,31 @@ def download_weights_from_hf( ...@@ -448,12 +449,31 @@ def download_weights_from_hf(
fs = HfFileSystem() fs = HfFileSystem()
file_list = fs.ls(model_name_or_path, detail=False, revision=revision) file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
# Use the first pattern found in the HF repo's files. # If downloading safetensors and an index file exists, use the
for pattern in allow_patterns: # specific file names from the index to avoid downloading
matching = fnmatch.filter(file_list, pattern) # unnecessary files (e.g., from subdirectories like "original/").
if len(matching) > 0: index_file = f"{model_name_or_path}/{SAFE_WEIGHTS_INDEX_NAME}"
allow_patterns = [pattern] if "*.safetensors" in allow_patterns and index_file in file_list:
break index_path = hf_hub_download(
repo_id=model_name_or_path,
filename=SAFE_WEIGHTS_INDEX_NAME,
cache_dir=cache_dir,
revision=revision,
)
with open(index_path) as f:
weight_map = json.load(f)["weight_map"]
if weight_map:
# Extra [] so that weight_map files are treated as a
# single allow_pattern in the loop below
allow_patterns = [list(set(weight_map.values()))] # type: ignore[list-item]
else:
allow_patterns = ["*.safetensors"]
else:
# Use the first pattern found in the HF repo's files.
for pattern in allow_patterns:
if fnmatch.filter(file_list, pattern):
allow_patterns = [pattern]
break
except Exception as e: except Exception as e:
logger.warning( logger.warning(
"Failed to get file list for '%s'. Trying each pattern in " "Failed to get file list for '%s'. Trying each pattern in "
...@@ -480,6 +500,9 @@ def download_weights_from_hf( ...@@ -480,6 +500,9 @@ def download_weights_from_hf(
) )
# If we have downloaded weights for this allow_pattern, # If we have downloaded weights for this allow_pattern,
# we don't need to check the rest. # we don't need to check the rest.
# allow_pattern can be a list (from weight_map) or str (glob)
if isinstance(allow_pattern, list):
break
if any(Path(hf_folder).glob(allow_pattern)): if any(Path(hf_folder).glob(allow_pattern)):
break break
time_taken = time.perf_counter() - start_time time_taken = time.perf_counter() - start_time
......
...@@ -8,7 +8,7 @@ from collections.abc import Iterable ...@@ -8,7 +8,7 @@ from collections.abc import Iterable
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.attention.layer import MultiHeadAttention from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.utils import divide from vllm.distributed.utils import divide
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
...@@ -126,7 +126,7 @@ class AIMv2Attention(nn.Module): ...@@ -126,7 +126,7 @@ class AIMv2Attention(nn.Module):
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_partition = divide(self.num_heads, self.tp_size) self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
self.attn = MultiHeadAttention( self.attn = MMEncoderAttention(
self.num_heads_per_partition, self.head_dim, self.scale self.num_heads_per_partition, self.head_dim, self.scale
) )
......
...@@ -55,7 +55,9 @@ class BertEmbedding(nn.Module): ...@@ -55,7 +55,9 @@ class BertEmbedding(nn.Module):
"position_ids", "position_ids",
torch.arange(config.max_position_embeddings).unsqueeze(0), torch.arange(config.max_position_embeddings).unsqueeze(0),
) )
self.position_embedding_type = config.position_embedding_type self.position_embedding_type = getattr(
config, "position_embedding_type", "absolute"
)
if self.position_embedding_type != "absolute": if self.position_embedding_type != "absolute":
raise ValueError( raise ValueError(
"Only 'absolute' position_embedding_type" + " is supported" "Only 'absolute' position_embedding_type" + " is supported"
......
...@@ -9,7 +9,7 @@ import torch ...@@ -9,7 +9,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from transformers import Blip2VisionConfig, BlipVisionConfig from transformers import Blip2VisionConfig, BlipVisionConfig
from vllm.attention.layer import MultiHeadAttention from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.conv import Conv2dLayer from vllm.model_executor.layers.conv import Conv2dLayer
...@@ -122,7 +122,7 @@ class BlipAttention(nn.Module): ...@@ -122,7 +122,7 @@ class BlipAttention(nn.Module):
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_partition = divide(self.num_heads, self.tp_size) self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
self.attn = MultiHeadAttention( self.attn = MMEncoderAttention(
self.num_heads_per_partition, self.head_dim, self.scale self.num_heads_per_partition, self.head_dim, self.scale
) )
......
...@@ -14,7 +14,8 @@ from transformers import ( ...@@ -14,7 +14,8 @@ from transformers import (
CLIPVisionConfig, CLIPVisionConfig,
) )
from vllm.attention.layer import Attention, MultiHeadAttention from vllm.attention.layer import Attention
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed import divide, get_tensor_model_parallel_world_size
...@@ -354,7 +355,7 @@ class CLIPAttention(nn.Module): ...@@ -354,7 +355,7 @@ class CLIPAttention(nn.Module):
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
*, *,
prefix: str = "", prefix: str = "",
attn_cls: type[Attention] | type[MultiHeadAttention], attn_cls: type[Attention] | type[MMEncoderAttention],
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -449,7 +450,7 @@ class CLIPEncoderLayer(nn.Module): ...@@ -449,7 +450,7 @@ class CLIPEncoderLayer(nn.Module):
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
*, *,
prefix: str = "", prefix: str = "",
attn_cls: type[Attention] | type[MultiHeadAttention], attn_cls: type[Attention] | type[MMEncoderAttention],
) -> None: ) -> None:
super().__init__() super().__init__()
self.self_attn = CLIPAttention( self.self_attn = CLIPAttention(
...@@ -493,7 +494,7 @@ class CLIPEncoder(nn.Module): ...@@ -493,7 +494,7 @@ class CLIPEncoder(nn.Module):
num_hidden_layers_override: int | None = None, num_hidden_layers_override: int | None = None,
*, *,
prefix: str = "", prefix: str = "",
attn_cls: type[Attention] | type[MultiHeadAttention], attn_cls: type[Attention] | type[MMEncoderAttention],
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -638,7 +639,7 @@ class CLIPVisionTransformer(nn.Module): ...@@ -638,7 +639,7 @@ class CLIPVisionTransformer(nn.Module):
quant_config=quant_config, quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override, num_hidden_layers_override=num_hidden_layers_override,
prefix=f"{prefix}.encoder", prefix=f"{prefix}.encoder",
attn_cls=MultiHeadAttention, attn_cls=MMEncoderAttention,
) )
num_hidden_layers = config.num_hidden_layers num_hidden_layers = config.num_hidden_layers
......
...@@ -308,12 +308,6 @@ class MambaModelConfig(VerifyAndUpdateConfig): ...@@ -308,12 +308,6 @@ class MambaModelConfig(VerifyAndUpdateConfig):
if cache_config.mamba_block_size is None: if cache_config.mamba_block_size is None:
cache_config.mamba_block_size = model_config.max_model_len cache_config.mamba_block_size = model_config.max_model_len
# TODO(tdoublep): remove once cascade attention is supported
logger.info(
"Disabling cascade attention since it is not supported for hybrid models."
)
model_config.disable_cascade_attn = True
class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig): class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
@classmethod @classmethod
......
...@@ -18,7 +18,7 @@ import torch.nn as nn ...@@ -18,7 +18,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from transformers import CLIPVisionConfig from transformers import CLIPVisionConfig
from vllm.attention.layer import MultiHeadAttention from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.model_executor.layers.conv import Conv2dLayer from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...@@ -628,7 +628,7 @@ class DeepCLIPVisionTransformer(nn.Module): ...@@ -628,7 +628,7 @@ class DeepCLIPVisionTransformer(nn.Module):
quant_config=quant_config, quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override, num_hidden_layers_override=num_hidden_layers_override,
prefix=f"{prefix}.encoder", prefix=f"{prefix}.encoder",
attn_cls=MultiHeadAttention, attn_cls=MMEncoderAttention,
) )
num_hidden_layers = config.num_hidden_layers num_hidden_layers = config.num_hidden_layers
......
...@@ -141,6 +141,7 @@ class DeepSeekMultiTokenPredictor(nn.Module): ...@@ -141,6 +141,7 @@ class DeepSeekMultiTokenPredictor(nn.Module):
self.embed_tokens = VocabParallelEmbedding( self.embed_tokens = VocabParallelEmbedding(
config.vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
prefix=maybe_prefix(prefix, "embed_tokens"),
) )
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
......
...@@ -837,7 +837,11 @@ class Indexer(nn.Module): ...@@ -837,7 +837,11 @@ class Indexer(nn.Module):
) )
self.k_norm = LayerNorm(self.head_dim, eps=1e-6) self.k_norm = LayerNorm(self.head_dim, eps=1e-6)
self.weights_proj = ReplicatedLinear( self.weights_proj = ReplicatedLinear(
hidden_size, self.n_head, quant_config=None, prefix=f"{prefix}.weights_proj" hidden_size,
self.n_head,
bias=False,
quant_config=None,
prefix=f"{prefix}.weights_proj",
) )
self.softmax_scale = self.head_dim**-0.5 self.softmax_scale = self.head_dim**-0.5
......
...@@ -38,7 +38,10 @@ from vllm.model_executor.layers.linear import ( ...@@ -38,7 +38,10 @@ from vllm.model_executor.layers.linear import (
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, default_weight_loader,
maybe_remap_kv_scale_name, maybe_remap_kv_scale_name,
...@@ -463,12 +466,20 @@ class Gemma3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -463,12 +466,20 @@ class Gemma3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
super().__init__() super().__init__()
self.config = config self.config = config
# currently all existing Gemma models have `tie_word_embeddings` enabled
assert config.tie_word_embeddings
self.quant_config = quant_config self.quant_config = quant_config
self.model = Gemma3Model( self.model = Gemma3Model(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
) )
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
if config.tie_word_embeddings:
self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)
self.logits_processor = LogitsProcessor( self.logits_processor = LogitsProcessor(
config.vocab_size, soft_cap=config.final_logit_softcapping config.vocab_size, soft_cap=config.final_logit_softcapping
) )
...@@ -496,7 +507,7 @@ class Gemma3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -496,7 +507,7 @@ class Gemma3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
) -> torch.Tensor | None: ) -> torch.Tensor | None:
logits = self.logits_processor(self.model.embed_tokens, hidden_states) logits = self.logits_processor(self.lm_head, hidden_states)
return logits return logits
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
......
...@@ -19,7 +19,7 @@ from transformers import BatchFeature, PreTrainedTokenizer, TensorType ...@@ -19,7 +19,7 @@ from transformers import BatchFeature, PreTrainedTokenizer, TensorType
from transformers.image_utils import ImageInput from transformers.image_utils import ImageInput
from transformers.tokenization_utils_base import TextInput from transformers.tokenization_utils_base import TextInput
from vllm.attention.layer import MultiHeadAttention from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
...@@ -135,7 +135,7 @@ class EVA2CLIPAttention(nn.Module): ...@@ -135,7 +135,7 @@ class EVA2CLIPAttention(nn.Module):
prefix=f"{prefix}.dense", prefix=f"{prefix}.dense",
) )
self.attn = MultiHeadAttention( self.attn = MMEncoderAttention(
self.num_heads_per_rank, self.head_dim, self.scale self.num_heads_per_rank, self.head_dim, self.scale
) )
self.output_dropout = torch.nn.Dropout(config.dropout_prob) self.output_dropout = torch.nn.Dropout(config.dropout_prob)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment