Commit a3f8d5dd authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.13.0rc2' into v0.13.0rc2-ori

parents 8d75f22e f34eca5f
......@@ -19,14 +19,15 @@ from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfi
class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
@classmethod
def get_min_capability(cls) -> int:
return 75
def is_supported(
cls, compute_capability: int | None = None
) -> tuple[bool, str | None]:
if not current_platform.is_cpu():
return False, "Requires CPU."
return True, None
@classmethod
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
if not current_platform.is_cpu():
return False, "CPUScaledMM requires running on CPU."
return True, None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
......
......@@ -16,14 +16,21 @@ from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfi
class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
@classmethod
def get_min_capability(cls) -> int:
return 75
def is_supported(
cls, compute_capability: int | None = None
) -> tuple[bool, str | None]:
if not current_platform.is_cuda():
return False, "Requires CUDA."
if compute_capability is None:
_cc = current_platform.get_device_capability()
if _cc is not None:
compute_capability = _cc.major * 10 + _cc.minor
if compute_capability is not None and compute_capability < 75:
return False, f"requires capability 75, got {compute_capability}"
return True, None
@classmethod
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
if not current_platform.is_cuda():
return False, "CutlassScaledMM requires running on CUDA."
return True, None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
......
......@@ -4,34 +4,53 @@
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm import ( # noqa: E501
triton_scaled_mm,
)
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.platforms import current_platform
from .cutlass import CutlassScaledMMLinearKernel
from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig
from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig
class TritonScaledMMLinearKernel(CutlassScaledMMLinearKernel):
class TritonScaledMMLinearKernel(ScaledMMLinearKernel):
@classmethod
def get_min_capability(cls) -> int:
return 75
def is_supported(
cls, compute_capability: int | None = None
) -> tuple[bool, str | None]:
if current_platform.is_cuda_alike():
return True, None
return False, "Requires ROCm or CUDA."
@classmethod
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
if current_platform.is_cpu():
return (
False,
"TritonScaledMMLinearKernel requires Triton which is not "
+ "currently supported on CPU.",
)
if not c.input_symmetric:
return (
False,
"TritonScaledMMLinearKernel only supports symmetric " + "quantization.",
)
return False, "Only symmetric input is supported."
return True, None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
super().process_weights_after_loading(layer)
weight = getattr(layer, self.w_q_name)
replace_parameter(
layer,
self.w_q_name,
torch.nn.Parameter(weight.t().data, requires_grad=False),
)
# INPUT SCALE
if self.config.is_static_input_scheme:
input_scale = getattr(layer, self.i_s_name)
replace_parameter(
layer,
self.i_s_name,
torch.nn.Parameter(input_scale.max(), requires_grad=False),
)
setattr(layer, self.i_zp_name, None)
else:
setattr(layer, self.i_s_name, None)
setattr(layer, self.i_zp_name, None)
setattr(layer, self.azp_adj_name, None)
def apply_weights(
self,
......@@ -39,4 +58,14 @@ class TritonScaledMMLinearKernel(CutlassScaledMMLinearKernel):
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
return super().apply_weights(layer, x, bias)
w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer)
x_q, x_s, x_zp = ops.scaled_int8_quant(
x.contiguous(), i_s, i_zp, symmetric=True
)
assert x_zp is None, "Triton kernel only supports symmetric quantization"
return triton_scaled_mm(
x_q, w_q, scale_a=x_s, scale_b=w_s, out_dtype=x.dtype, bias=bias
)
......@@ -17,11 +17,12 @@ from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfi
class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
@classmethod
def get_min_capability(cls) -> int:
raise NotImplementedError(
"TPU platform does have a concept of compute capability, "
"this method should not be called."
)
def is_supported(
cls, compute_capability: int | None = None
) -> tuple[bool, str | None]:
if not current_platform.is_tpu():
return False, "Requires TPU."
return True, None
@classmethod
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
......
......@@ -38,6 +38,7 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
build_flashinfer_fp4_cutlass_moe_prepare_finalize,
flashinfer_trtllm_fp4_moe,
flashinfer_trtllm_fp4_routed_moe,
prepare_static_weights_for_trtllm_fp4_moe,
reorder_w1w3_to_w3w1,
select_nvfp4_gemm_impl,
......@@ -80,6 +81,7 @@ from vllm.utils.flashinfer import (
has_flashinfer,
has_flashinfer_moe,
)
from vllm.utils.math_utils import round_up
if TYPE_CHECKING:
from vllm.model_executor.models.utils import WeightsMapper
......@@ -186,7 +188,24 @@ class ModelOptQuantConfigBase(QuantizationConfig):
def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
if len(self.exclude_modules) > 0:
self.exclude_modules = hf_to_vllm_mapper.apply_list(self.exclude_modules)
# This is a workaround for the weights remapping issue:
# https://github.com/vllm-project/vllm/issues/28072
# Right now, the Nvidia ModelOpt library use just one wildcard pattern:
# module_path*
# It gets applied if the whole tree of modules rooted at module_path
# is not quantized. Here we replace such pattern by 2 patterns that are
# collectively equivalent to the original pattern:
# module_path
# module_path.*
new_exclude_modules = []
for exclude in self.exclude_modules:
if len(exclude) >= 2 and exclude[-1] == "*" and exclude[-2] != ".":
new_exclude_modules.append(exclude[:-1])
new_exclude_modules.append(exclude[:-1] + ".*")
else:
new_exclude_modules.append(exclude)
self.exclude_modules = hf_to_vllm_mapper.apply_list(new_exclude_modules)
@staticmethod
def get_config_filenames() -> list[str]:
......@@ -606,6 +625,9 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
Only supports pre-quantized checkpoints with FP8 weights and scales.
"""
if self.flashinfer_moe_backend is not None:
self._maybe_pad_intermediate_for_flashinfer(layer)
layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
......@@ -683,6 +705,50 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight)
register_moe_scaling_factors(layer)
def _maybe_pad_intermediate_for_flashinfer(self, layer: torch.nn.Module) -> None:
"""Pad intermediate size so FlashInfer kernels' alignment constraints hold.
Some FlashInfer FP8 MoE kernels require the (gated) intermediate size
used for GEMM to be divisible by a small alignment value. When this is
not satisfied (e.g. with certain tensor-parallel sizes), we pad the
gate/up and down projection weights along the intermediate dim.
"""
if not hasattr(layer, "w13_weight") or not hasattr(layer, "w2_weight"):
return
# Current local intermediate size (per partition) is the K dimension of
# the down projection.
num_experts, hidden_size, intermediate = layer.w2_weight.shape
min_alignment = 16
padded_intermediate = round_up(intermediate, min_alignment)
if padded_intermediate == intermediate:
return
logger.info(
"Padding intermediate size from %d to %d for up/down projection weights.",
intermediate,
padded_intermediate,
)
up_mult = 2 if self.moe.is_act_and_mul else 1
padded_gate_up_dim = up_mult * padded_intermediate
# Pad w13 and w12 along its intermediate dimension.
w13 = layer.w13_weight.data
padded_w13 = w13.new_zeros((num_experts, padded_gate_up_dim, hidden_size))
padded_w13[:, : w13.shape[1], :] = w13
layer.w13_weight.data = padded_w13
w2 = layer.w2_weight.data
padded_w2 = w2.new_zeros((num_experts, hidden_size, padded_intermediate))
padded_w2[:, :, :intermediate] = w2
layer.w2_weight.data = padded_w2
if hasattr(layer, "intermediate_size_per_partition"):
layer.intermediate_size_per_partition = padded_intermediate
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
......@@ -1325,7 +1391,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
"Accuracy may be affected."
)
w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0].contiguous()
layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)
# Common processing for input scales and alphas
......@@ -1482,6 +1548,10 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
a2_gscale=layer.w2_input_scale_quant,
)
@property
def supports_eplb(self) -> bool:
return True
def apply(
self,
layer: FusedMoE,
......@@ -1500,11 +1570,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
if (
self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
and not layer.enable_eplb
):
if layer.enable_eplb:
raise NotImplementedError(
"EPLB not supported for `ModelOptNvFp4FusedMoE` yet."
)
return flashinfer_trtllm_fp4_moe(
layer=layer,
x=x,
......@@ -1522,6 +1589,20 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
router_logits=router_logits,
)
# EPLB path
if (
self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
):
return flashinfer_trtllm_fp4_routed_moe(
layer=layer,
x=x,
topk_ids=topk_ids,
topk_weights=topk_weights,
top_k=layer.top_k,
global_num_experts=layer.global_num_experts,
)
if self.use_marlin:
return fused_marlin_moe(
x,
......
......@@ -17,6 +17,9 @@ from vllm.model_executor.layers.fused_moe.layer import (
FusedMoEMethodBase,
FusedMoeWeightScaleSupported,
)
from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
UnquantizedFusedMoEMethod,
)
from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
......@@ -162,6 +165,8 @@ class MoeWNA16Config(QuantizationConfig):
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
if is_layer_skipped_quant(prefix, self.modules_to_not_convert):
if isinstance(layer, FusedMoE):
return UnquantizedFusedMoEMethod(layer.moe_config)
return UnquantizedLinearMethod()
elif isinstance(layer, LinearBase):
# Avoid circular import
......
......@@ -118,19 +118,19 @@ def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend:
logger.info_once("Using FlashInfer MXFP4 BF16 backend for SM90")
return Mxfp4Backend.SM90_FI_MXFP4_BF16
elif (
current_platform.is_device_capability(100)
current_platform.is_device_capability_family(100)
and has_flashinfer()
and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS
):
logger.info_once("Using FlashInfer MXFP4 MXFP8 CUTLASS backend for SM100")
return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
elif (
current_platform.is_device_capability(100)
current_platform.is_device_capability_family(100)
and has_flashinfer()
and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
):
return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
elif current_platform.is_device_capability(100) and has_flashinfer():
elif current_platform.is_device_capability_family(100) and has_flashinfer():
logger.info_once(
"Using FlashInfer MXFP4 BF16 backend for SM100, "
"For faster performance on SM100, consider setting "
......@@ -139,7 +139,7 @@ def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend:
)
return Mxfp4Backend.SM100_FI_MXFP4_BF16
elif (
current_platform.is_device_capability(100)
current_platform.is_device_capability_family(100)
or current_platform.is_device_capability(90)
) and not has_flashinfer():
logger.warning_once(
......
......@@ -50,7 +50,7 @@ def is_flashinfer_fp4_cutedsl_moe_available() -> bool:
envs.VLLM_USE_FLASHINFER_MOE_FP4
and has_flashinfer_cutedsl_grouped_gemm_nt_masked()
and current_platform.is_cuda()
and current_platform.is_device_capability(100)
and current_platform.is_device_capability_family(100)
)
......@@ -331,3 +331,82 @@ def flashinfer_trtllm_fp4_moe(
)[0]
return out
def flashinfer_trtllm_fp4_routed_moe(
layer: torch.nn.Module,
x: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
top_k: int,
global_num_experts: int,
) -> torch.Tensor:
"""
Apply FlashInfer TensorRT-LLM FP4 MoE kernel. Uses packed
input top k expert indices and scores rather than computing
top k expert indices from scores.
Args:
layer: The MoE layer with weights and scales
x: Input tensor
topk_ids: Ids of selected experts
top_k: Number of experts to select per token
global_num_experts: Total number of experts across all ranks
Returns:
Output tensor from the MoE layer
"""
import flashinfer
# Pack top k ids and expert weights into a single int32 tensor, as
# required by TRT-LLM
packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to(
torch.bfloat16
).view(torch.int16)
# Quantize input to FP4
a1_gscale = layer.w13_input_scale_quant
(hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
x,
a1_gscale,
is_sf_swizzled_layout=False,
)
# Call TRT-LLM FP4 block-scale MoE kernel
out = flashinfer.fused_moe.trtllm_fp4_block_scale_routed_moe(
topk_ids=packed_tensor,
routing_bias=None,
hidden_states=hidden_states_fp4,
hidden_states_scale=hidden_states_scale_linear_fp4.view(
torch.float8_e4m3fn
).flatten(),
gemm1_weights=layer.gemm1_weights_fp4_shuffled.data,
gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.data.view(
torch.float8_e4m3fn
),
gemm1_bias=None,
gemm1_alpha=None,
gemm1_beta=None,
gemm1_clamp_limit=None,
gemm2_weights=layer.gemm2_weights_fp4_shuffled.data,
gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.data.view(
torch.float8_e4m3fn
),
gemm2_bias=None,
output1_scale_scalar=layer.g1_scale_c.data,
output1_scale_gate_scalar=layer.g1_alphas.data,
output2_scale_scalar=layer.g2_alphas.data,
num_experts=global_num_experts,
top_k=top_k,
n_group=0,
topk_group=0,
intermediate_size=layer.intermediate_size_per_partition,
local_expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts,
routed_scaling_factor=None,
tile_tokens_dim=None,
routing_method_type=1,
do_finalize=True,
)[0]
return out
......@@ -247,11 +247,6 @@ def flashinfer_cutlass_moe_fp8(
assert quant_config is not None
# Construct modular kernel with block-scale support when requested.
parallel_config = getattr(
getattr(layer, "vllm_config", None),
"parallel_config",
None,
)
fused_experts = mk.FusedMoEModularKernel(
build_flashinfer_fp8_cutlass_moe_prepare_finalize(
moe=moe, use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale
......@@ -262,7 +257,7 @@ def flashinfer_cutlass_moe_fp8(
out_dtype=hidden_states.dtype,
use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
),
parallel_config=parallel_config,
moe_parallel_config=layer.moe_parallel_config,
)
return fused_experts(
......@@ -290,7 +285,7 @@ def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
if flashinfer_moe_backend in backend_map:
if (
flashinfer_moe_backend == "latency"
and not current_platform.is_device_capability(100)
and not current_platform.is_device_capability_family(100)
):
logger.info_once(
"Flashinfer TRTLLM MOE backend is only supported on "
......
......@@ -247,7 +247,7 @@ class W8A8BlockFp8LinearOp:
self.act_quant_group_shape = act_quant_group_shape
self.is_deep_gemm_supported = is_deep_gemm_supported()
self.is_hopper = current_platform.is_device_capability(90)
self.is_blackwell = current_platform.is_device_capability(100)
self.is_blackwell = current_platform.is_device_capability_family(100)
self.use_deep_gemm_e8m0 = is_deep_gemm_e8m0_used()
# Get the correct blockscale mul and input quant operations.
......@@ -762,9 +762,12 @@ def per_token_group_quant_fp8(
)
assert x.stride(-1) == 1, "`x` groups must be contiguous"
# Using the default value (240.0) from pytorch will cause accuracy
# issue on dynamic quantization models. Here use 224.0 for fnuz on ROCm
# platforms that use the torch.float8_e4mefnuz dtype.
finfo = torch.finfo(dtype)
fp8_min = finfo.min
fp8_max = finfo.max
fp8_min = -224.0 if current_platform.is_fp8_fnuz() else finfo.min
fp8_max = 224.0 if current_platform.is_fp8_fnuz() else finfo.max
assert out_q is None or out_q.shape == x.shape
x_q = out_q
......
......@@ -57,12 +57,18 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps):
mx_axis=1, num_warps=num_warps
)
)
if current_platform.is_cuda() and current_platform.is_device_capability(100):
constraints = {
"is_persistent": True,
"epilogue_subtile": 1,
}
opt_flags.update_opt_flags_constraints(constraints)
if current_platform.is_cuda():
if current_platform.is_device_capability(90):
constraints = {
"split_k": 1,
}
opt_flags.update_opt_flags_constraints(constraints)
elif current_platform.is_device_capability_family(100):
constraints = {
"is_persistent": True,
"epilogue_subtile": 1,
}
opt_flags.update_opt_flags_constraints(constraints)
# transpose the tensor so that the quantization axis is on dim1
quant_tensor = quant_tensor.transpose(-2, -1)
scale = scale.transpose(-2, -1)
......
......@@ -25,7 +25,6 @@ _ROPE_DICT: dict[tuple, RotaryEmbedding] = {}
def get_rope(
head_size: int,
rotary_dim: int,
max_position: int,
is_neox_style: bool = True,
rope_parameters: dict[str, Any] | None = None,
......@@ -54,12 +53,15 @@ def get_rope(
else:
dual_chunk_attention_args = None
partial_rotary_factor = 1.0
if rope_parameters is not None:
partial_rotary_factor = rope_parameters.get("partial_rotary_factor", 1.0)
rope_parameters = rope_parameters or {}
base = rope_parameters.get("rope_theta", 10000)
scaling_type = rope_parameters.get("rope_type", "default")
partial_rotary_factor = rope_parameters.get("partial_rotary_factor", 1.0)
if partial_rotary_factor <= 0.0 or partial_rotary_factor > 1.0:
raise ValueError(f"{partial_rotary_factor=} must be between 0.0 and 1.0")
rotary_dim = int(head_size * partial_rotary_factor)
if partial_rotary_factor < 1.0:
rotary_dim = int(rotary_dim * partial_rotary_factor)
key = (
head_size,
rotary_dim,
......@@ -72,7 +74,6 @@ def get_rope(
if key in _ROPE_DICT:
return _ROPE_DICT[key]
base = rope_parameters["rope_theta"] if rope_parameters else 10000
if dual_chunk_attention_config is not None:
extra_kwargs = {
k: v
......@@ -88,208 +89,201 @@ def get_rope(
dtype,
**extra_kwargs,
)
elif not rope_parameters:
rotary_emb = RotaryEmbedding(
head_size, rotary_dim, max_position, base, is_neox_style, dtype
)
else:
scaling_type = rope_parameters["rope_type"]
if scaling_type == "llama3":
scaling_factor = rope_parameters["factor"]
low_freq_factor = rope_parameters["low_freq_factor"]
high_freq_factor = rope_parameters["high_freq_factor"]
original_max_position = rope_parameters["original_max_position_embeddings"]
rotary_emb = Llama3RotaryEmbedding(
elif scaling_type == "default":
if "mrope_section" in rope_parameters:
rotary_emb = MRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
dtype,
scaling_factor,
low_freq_factor,
high_freq_factor,
original_max_position,
mrope_section=rope_parameters["mrope_section"],
mrope_interleaved=rope_parameters.get("mrope_interleaved", False),
)
elif scaling_type == "mllama4":
rotary_emb = Llama4VisionRotaryEmbedding(
head_size, rotary_dim, max_position, base, is_neox_style, dtype
)
elif scaling_type == "default":
if "mrope_section" in rope_parameters:
rotary_emb = MRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
dtype,
mrope_section=rope_parameters["mrope_section"],
mrope_interleaved=rope_parameters.get("mrope_interleaved", False),
)
else:
rotary_emb = RotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
dtype,
)
elif scaling_type == "linear":
scaling_factor = rope_parameters["factor"]
rotary_emb = LinearScalingRotaryEmbedding(
else:
rotary_emb = RotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
scaling_factor,
dtype,
)
elif scaling_type == "ntk":
scaling_factor = rope_parameters["factor"]
mixed_b = rope_parameters.get("mixed_b")
rotary_emb = NTKScalingRotaryEmbedding(
elif scaling_type == "llama3":
scaling_factor = rope_parameters["factor"]
low_freq_factor = rope_parameters["low_freq_factor"]
high_freq_factor = rope_parameters["high_freq_factor"]
original_max_position = rope_parameters["original_max_position_embeddings"]
rotary_emb = Llama3RotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
dtype,
scaling_factor,
low_freq_factor,
high_freq_factor,
original_max_position,
)
elif scaling_type == "mllama4":
rotary_emb = Llama4VisionRotaryEmbedding(
head_size, rotary_dim, max_position, base, is_neox_style, dtype
)
elif scaling_type == "linear":
scaling_factor = rope_parameters["factor"]
rotary_emb = LinearScalingRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
scaling_factor,
dtype,
)
elif scaling_type == "ntk":
scaling_factor = rope_parameters["factor"]
mixed_b = rope_parameters.get("mixed_b")
rotary_emb = NTKScalingRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
scaling_factor,
dtype,
mixed_b,
)
elif scaling_type == "dynamic":
if "alpha" in rope_parameters:
scaling_alpha = rope_parameters["alpha"]
rotary_emb = DynamicNTKAlphaRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
scaling_factor,
scaling_alpha,
dtype,
mixed_b,
)
elif scaling_type == "dynamic":
if "alpha" in rope_parameters:
scaling_alpha = rope_parameters["alpha"]
rotary_emb = DynamicNTKAlphaRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
scaling_alpha,
dtype,
)
elif "factor" in rope_parameters:
scaling_factor = rope_parameters["factor"]
rotary_emb = DynamicNTKScalingRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
scaling_factor,
dtype,
)
else:
raise ValueError(
"Dynamic rope scaling must contain either 'alpha' or 'factor' field"
)
elif scaling_type == "xdrope":
scaling_alpha = rope_parameters["alpha"]
rotary_emb = XDRotaryEmbedding(
elif "factor" in rope_parameters:
scaling_factor = rope_parameters["factor"]
rotary_emb = DynamicNTKScalingRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
scaling_alpha,
scaling_factor,
dtype,
xdrope_section=rope_parameters["xdrope_section"],
)
elif scaling_type == "yarn":
scaling_factor = rope_parameters["factor"]
original_max_position = rope_parameters["original_max_position_embeddings"]
extra_kwargs = {
k: v
for k, v in rope_parameters.items()
if k
in (
"extrapolation_factor",
"attn_factor",
"beta_fast",
"beta_slow",
"apply_yarn_scaling",
"truncate",
)
}
if "mrope_section" in rope_parameters:
extra_kwargs.pop("apply_yarn_scaling", None)
rotary_emb = MRotaryEmbedding(
head_size,
rotary_dim,
original_max_position,
base,
is_neox_style,
dtype,
mrope_section=rope_parameters["mrope_section"],
mrope_interleaved=rope_parameters.get("mrope_interleaved", False),
scaling_factor=scaling_factor,
**extra_kwargs,
)
else:
rotary_emb = YaRNScalingRotaryEmbedding(
head_size,
rotary_dim,
original_max_position,
base,
is_neox_style,
scaling_factor,
dtype,
**extra_kwargs,
)
elif scaling_type in ["deepseek_yarn", "deepseek_llama_scaling"]:
scaling_factor = rope_parameters["factor"]
original_max_position = rope_parameters["original_max_position_embeddings"]
# assert max_position == original_max_position * scaling_factor
extra_kwargs = {
k: v
for k, v in rope_parameters.items()
if k
in (
"extrapolation_factor",
"attn_factor",
"beta_fast",
"beta_slow",
"mscale",
"mscale_all_dim",
)
}
rotary_emb = DeepseekScalingRotaryEmbedding(
else:
raise ValueError(
"Dynamic rope scaling must contain either 'alpha' or 'factor' field"
)
elif scaling_type == "xdrope":
scaling_alpha = rope_parameters["alpha"]
rotary_emb = XDRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
scaling_alpha,
dtype,
xdrope_section=rope_parameters["xdrope_section"],
)
elif scaling_type == "yarn":
scaling_factor = rope_parameters["factor"]
original_max_position = rope_parameters["original_max_position_embeddings"]
extra_kwargs = {
k: v
for k, v in rope_parameters.items()
if k
in (
"extrapolation_factor",
"attn_factor",
"beta_fast",
"beta_slow",
"apply_yarn_scaling",
"truncate",
)
}
if "mrope_section" in rope_parameters:
extra_kwargs.pop("apply_yarn_scaling", None)
rotary_emb = MRotaryEmbedding(
head_size,
rotary_dim,
original_max_position,
base,
is_neox_style,
scaling_factor,
dtype,
mrope_section=rope_parameters["mrope_section"],
mrope_interleaved=rope_parameters.get("mrope_interleaved", False),
scaling_factor=scaling_factor,
**extra_kwargs,
)
elif scaling_type == "longrope":
short_factor = rope_parameters["short_factor"]
long_factor = rope_parameters["long_factor"]
original_max_position = rope_parameters["original_max_position_embeddings"]
extra_kwargs = {
k: v
for k, v in rope_parameters.items()
if k in ("short_mscale", "long_mscale")
}
rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(
else:
rotary_emb = YaRNScalingRotaryEmbedding(
head_size,
rotary_dim,
max_position,
original_max_position,
base,
is_neox_style,
scaling_factor,
dtype,
short_factor,
long_factor,
**extra_kwargs,
)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
elif scaling_type in ["deepseek_yarn", "deepseek_llama_scaling"]:
scaling_factor = rope_parameters["factor"]
original_max_position = rope_parameters["original_max_position_embeddings"]
# assert max_position == original_max_position * scaling_factor
extra_kwargs = {
k: v
for k, v in rope_parameters.items()
if k
in (
"extrapolation_factor",
"attn_factor",
"beta_fast",
"beta_slow",
"mscale",
"mscale_all_dim",
)
}
rotary_emb = DeepseekScalingRotaryEmbedding(
head_size,
rotary_dim,
original_max_position,
base,
is_neox_style,
scaling_factor,
dtype,
**extra_kwargs,
)
elif scaling_type == "longrope":
short_factor = rope_parameters["short_factor"]
long_factor = rope_parameters["long_factor"]
original_max_position = rope_parameters["original_max_position_embeddings"]
extra_kwargs = {
k: v
for k, v in rope_parameters.items()
if k in ("short_mscale", "long_mscale")
}
rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(
head_size,
rotary_dim,
max_position,
original_max_position,
base,
is_neox_style,
dtype,
short_factor,
long_factor,
**extra_kwargs,
)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
_ROPE_DICT[key] = rotary_emb
return rotary_emb
......@@ -7,7 +7,7 @@ import torch
from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.custom_op import CustomOp
from .common import apply_rotary_emb_torch
from .common import ApplyRotaryEmb
@CustomOp.register("rotary_embedding")
......@@ -49,6 +49,10 @@ class RotaryEmbeddingBase(CustomOp):
rocm_aiter_ops.is_triton_rotary_embed_enabled()
)
self.apply_rotary_emb = ApplyRotaryEmb(
is_neox_style=self.is_neox_style,
)
def _compute_inv_freq(self, base: float) -> torch.Tensor:
"""Compute the inverse frequency."""
# NOTE(woosuk): To exactly match the HF implementation, we need to
......@@ -123,7 +127,12 @@ class RotaryEmbedding(RotaryEmbeddingBase):
query = query.view(num_tokens, -1, head_size)
query_rot = query[..., :rotary_dim]
query_pass = query[..., rotary_dim:]
query_rot = apply_rotary_emb_torch(query_rot, cos, sin, is_neox_style)
query_rot = ApplyRotaryEmb.forward_static(
query_rot,
cos,
sin,
is_neox_style,
)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
# key may be None in some cases, e.g. cross-layer KV sharing
......@@ -132,7 +141,12 @@ class RotaryEmbedding(RotaryEmbeddingBase):
key = key.view(num_tokens, -1, head_size)
key_rot = key[..., :rotary_dim]
key_pass = key[..., rotary_dim:]
key_rot = apply_rotary_emb_torch(key_rot, cos, sin, is_neox_style)
key_rot = ApplyRotaryEmb.forward_static(
key_rot,
cos,
sin,
is_neox_style,
)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key
......
......@@ -2,19 +2,14 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections.abc import Callable
from functools import cache
from importlib.util import find_spec
import torch
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.model_executor.custom_op import CustomOp
from vllm.utils.torch_utils import direct_register_custom_op
if current_platform.is_cuda():
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
logger = init_logger(__name__)
......@@ -32,71 +27,6 @@ def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
return x.flatten(-2)
def apply_rotary_emb_torch(
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
is_neox_style: bool,
) -> torch.Tensor:
cos = cos.unsqueeze(-2).to(x.dtype)
sin = sin.unsqueeze(-2).to(x.dtype)
if is_neox_style:
x1, x2 = torch.chunk(x, 2, dim=-1)
else:
x1 = x[..., ::2]
x2 = x[..., 1::2]
o1 = x1 * cos - x2 * sin
o2 = x2 * cos + x1 * sin
if is_neox_style:
return torch.cat((o1, o2), dim=-1)
else:
return torch.stack((o1, o2), dim=-1).flatten(-2)
def apply_rotary_emb_dispatch(
x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, is_neox_style: bool
) -> torch.Tensor:
"""
Args:
x: [num_tokens, num_heads, head_size]
cos: [num_tokens, head_size // 2]
sin: [num_tokens, head_size // 2]
is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
positional embeddings.
"""
if current_platform.is_cuda():
return apply_rotary_emb(x.unsqueeze(0), cos, sin, not is_neox_style).squeeze(0)
else:
return apply_rotary_emb_torch(x, cos, sin, is_neox_style)
@cache
def dispatch_rotary_emb_function(
default: Callable[..., torch.Tensor] | None = None,
) -> Callable[..., torch.Tensor]:
if current_platform.is_cuda():
return apply_rotary_emb
# if torch compile is not enabled
# use rotary embedding function from flash_attn package
# otherwise use the naive pytorch embedding implementation
# is faster when torch compile is enabled.
if current_platform.is_rocm() and not torch.compiler.is_compiling():
if find_spec("flash_attn") is not None:
from flash_attn.ops.triton.rotary import apply_rotary
return apply_rotary
else:
logger.warning(
"flash_attn is not installed. Falling back to PyTorch "
"implementation for rotary embeddings."
)
if default is not None:
return default
return apply_rotary_emb_torch
# yarn functions
# Inverse dim formula to find dim based on number of rotations
def yarn_find_correction_dim(
......@@ -186,3 +116,155 @@ direct_register_custom_op(
mutates_args=["query", "key"], # These tensors are modified in-place
fake_impl=_flashinfer_rotary_embedding_fake,
)
@CustomOp.register("apply_rotary_emb")
class ApplyRotaryEmb(CustomOp):
def __init__(
self,
enforce_enable: bool = False,
is_neox_style: bool = True,
enable_fp32_compute: bool = False,
) -> None:
super().__init__(enforce_enable)
self.is_neox_style = is_neox_style
self.enable_fp32_compute = enable_fp32_compute
self.apply_rotary_emb_flash_attn = None
if find_spec("flash_attn") is not None:
from flash_attn.ops.triton.rotary import apply_rotary
self.apply_rotary_emb_flash_attn = apply_rotary
@staticmethod
def forward_static(
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
is_neox_style: bool = True,
enable_fp32_compute: bool = False,
) -> torch.Tensor:
"""
Args:
x: [batch_size (optional), seq_len, num_heads, head_size]
cos: [seq_len, head_size // 2]
sin: [seq_len, head_size // 2]
is_neox_style: Whether to use the Neox-style or GPT-J-style.
enable_fp32_compute: Temporarily convert x, cos, sin to FP32 dtype
for higher accuracy.
"""
origin_dtype = x.dtype
if enable_fp32_compute:
x = x.float()
cos = cos.unsqueeze(-2).to(x.dtype)
sin = sin.unsqueeze(-2).to(x.dtype)
if is_neox_style:
x1, x2 = torch.chunk(x, 2, dim=-1)
else:
x1 = x[..., ::2]
x2 = x[..., 1::2]
o1 = x1 * cos - x2 * sin
o2 = x2 * cos + x1 * sin
if is_neox_style:
output = torch.cat((o1, o2), dim=-1)
else:
output = torch.stack((o1, o2), dim=-1).flatten(-2)
if enable_fp32_compute:
output = output.to(origin_dtype)
return output
def forward_native(
self,
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
) -> torch.Tensor:
output = self.forward_static(
x, cos, sin, self.is_neox_style, self.enable_fp32_compute
)
return output
def forward_cuda(
self,
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
) -> torch.Tensor:
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
origin_dtype = x.dtype
if self.enable_fp32_compute:
x = x.float()
cos = cos.float()
sin = sin.float()
origin_shape = x.shape
if len(origin_shape) == 3:
# x: [seq_len, num_heads, head_size]
x = x.unsqueeze(0)
"""
Arguments of apply_rotary_emb() in vllm_flash_attn:
x: [batch_size, seq_len, nheads, headdim]
cos, sin: [seqlen_rotary, rotary_dim / 2]
interleaved: defalut as False (Neox-style).
...
"""
interleaved = not self.is_neox_style
output = apply_rotary_emb(x, cos, sin, interleaved)
if len(origin_shape) == 3:
output = output.squeeze(0)
if self.enable_fp32_compute:
output = output.to(origin_dtype)
return output
def forward_hip(
self,
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
) -> torch.Tensor:
if self.apply_rotary_emb_flash_attn is not None:
origin_dtype = x.dtype
if self.enable_fp32_compute:
x = x.float()
cos = cos.float()
sin = sin.float()
origin_shape = x.shape
if len(origin_shape) == 3:
# x: [seq_len, num_heads, head_size]
x = x.unsqueeze(0)
"""
Arguments of apply_rotary() in flash_attn:
x: [batch_size, seq_len, nheads, headdim]
cos, sin: [seqlen_rotary, rotary_dim / 2]
interleaved: defalut as False (Neox-style).
...
"""
interleaved = not self.is_neox_style
output = self.apply_rotary_emb_flash_attn(
x, cos, sin, interleaved=interleaved
).type_as(x)
if len(origin_shape) == 3:
output = output.squeeze(0)
if self.enable_fp32_compute:
output = output.to(origin_dtype)
else:
# Falling back to PyTorch native implementation.
output = self.forward_native(x, cos, sin)
return output
def extra_repr(self) -> str:
s = f"is_neox_style={self.is_neox_style}"
s += f"enable_fp32_compute={self.enable_fp32_compute}"
return s
......@@ -4,7 +4,6 @@
import torch
from .common import apply_rotary_emb_dispatch
from .mrope import MRotaryEmbedding
......@@ -55,14 +54,22 @@ class Ernie4_5_VLRotaryEmbedding(MRotaryEmbedding):
query = query.view(num_tokens, -1, self.head_size)
query_rot = query[..., : self.rotary_dim]
query_pass = query[..., self.rotary_dim :]
query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style)
query_rot = self.apply_rotary_emb.forward_native(
query_rot,
cos,
sin,
)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., : self.rotary_dim]
key_pass = key[..., self.rotary_dim :]
key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style)
key_rot = self.apply_rotary_emb.forward_native(
key_rot,
cos,
sin,
)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key
......
......@@ -8,7 +8,6 @@ import torch
from vllm.triton_utils import tl, triton
from .base import RotaryEmbeddingBase
from .common import apply_rotary_emb_dispatch
from .yarn_scaling_rope import YaRNScalingRotaryEmbedding, yarn_get_mscale
......@@ -301,14 +300,22 @@ class MRotaryEmbedding(RotaryEmbeddingBase):
query = query.view(num_tokens, -1, self.head_size)
query_rot = query[..., : self.rotary_dim]
query_pass = query[..., self.rotary_dim :]
query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style)
query_rot = self.apply_rotary_emb.forward_native(
query_rot,
cos,
sin,
)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., : self.rotary_dim]
key_pass = key[..., self.rotary_dim :]
key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style)
key_rot = self.apply_rotary_emb.forward_native(
key_rot,
cos,
sin,
)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key
......@@ -347,13 +354,21 @@ class MRotaryEmbedding(RotaryEmbeddingBase):
query = query.view(num_tokens, -1, self.head_size)
query_rot = query[..., : self.rotary_dim]
query_pass = query[..., self.rotary_dim :]
query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style)
query_rot = self.apply_rotary_emb(
query_rot,
cos,
sin,
)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., : self.rotary_dim]
key_pass = key[..., self.rotary_dim :]
key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style)
key_rot = self.apply_rotary_emb(
key_rot,
cos,
sin,
)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key
......
......@@ -4,7 +4,6 @@
import numpy as np
import torch
from .common import apply_rotary_emb_dispatch
from .dynamic_ntk_alpha_rope import DynamicNTKAlphaRotaryEmbedding
......@@ -36,7 +35,7 @@ class XDRotaryEmbedding(DynamicNTKAlphaRotaryEmbedding):
dtype,
)
def forward(
def forward_native(
self,
positions: torch.Tensor,
query: torch.Tensor,
......@@ -68,14 +67,73 @@ class XDRotaryEmbedding(DynamicNTKAlphaRotaryEmbedding):
query = query.view(num_tokens, -1, self.head_size)
query_rot = query[..., : self.rotary_dim]
query_pass = query[..., self.rotary_dim :]
query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style)
query_rot = self.apply_rotary_emb.forward_native(
query_rot,
cos,
sin,
)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., : self.rotary_dim]
key_pass = key[..., self.rotary_dim :]
key_rot = self.apply_rotary_emb.forward_native(
key_rot,
cos,
sin,
)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key
def forward_cuda(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor | None = None,
offsets: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""PyTorch-native implementation equivalent to forward().
Args:
positions:
[4, num_tokens] (P/W/H/T positions with multimodal inputs)
query: [num_tokens, num_heads * head_size]
key: [num_tokens, num_kv_heads * head_size]
"""
assert positions.ndim == 2
assert key is not None
num_tokens = positions.shape[-1]
cos_sin = self.cos_sin_cache[positions]
cos, sin = cos_sin.chunk(2, dim=-1)
cos = torch.cat(
[m[i] for i, m in enumerate(cos.split(self.xdrope_section, dim=-1))], dim=-1
)
sin = torch.cat(
[m[i] for i, m in enumerate(sin.split(self.xdrope_section, dim=-1))], dim=-1
)
query_shape = query.shape
query = query.view(num_tokens, -1, self.head_size)
query_rot = query[..., : self.rotary_dim]
query_pass = query[..., self.rotary_dim :]
query_rot = self.apply_rotary_emb(
query_rot,
cos,
sin,
)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., : self.rotary_dim]
key_pass = key[..., self.rotary_dim :]
key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style)
key_rot = self.apply_rotary_emb(
key_rot,
cos,
sin,
)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key
......
......@@ -337,6 +337,18 @@ def as_seq_cls_model(cls: _T) -> _T:
tokens = getattr(text_config, "classifier_from_token", None)
method = getattr(text_config, "method", None)
def auto_set_score_bias(weights):
for name, weight in weights:
if name == "score.bias":
device = self.score.weight.device
dtype = self.score.weight.dtype
bias = weight.to(device).to(dtype)
self.score.bias = torch.nn.Parameter(bias)
self.score.skip_bias_add = False
else:
yield name, weight
weights = auto_set_score_bias(weights)
if tokens is None and method is None:
return super().load_weights(weights)
else:
......
......@@ -241,9 +241,8 @@ class AfmoeAttention(nn.Module):
if self.is_local_attention:
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
rope_parameters=config["rope_parameters"],
rope_parameters=config.rope_parameters,
is_neox_style=True,
)
else:
......
......@@ -226,7 +226,6 @@ class ApertusAttention(nn.Module):
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
rope_parameters=config.rope_parameters,
is_neox_style=is_neox_style,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment