Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a3f8d5dd
Commit
a3f8d5dd
authored
Dec 17, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.13.0rc2' into v0.13.0rc2-ori
parents
8d75f22e
f34eca5f
Changes
499
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
701 additions
and
314 deletions
+701
-314
vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py
...del_executor/layers/quantization/kernels/scaled_mm/cpu.py
+6
-5
vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py
...executor/layers/quantization/kernels/scaled_mm/cutlass.py
+12
-5
vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py
..._executor/layers/quantization/kernels/scaled_mm/triton.py
+46
-17
vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py
...del_executor/layers/quantization/kernels/scaled_mm/xla.py
+6
-5
vllm/model_executor/layers/quantization/modelopt.py
vllm/model_executor/layers/quantization/modelopt.py
+87
-6
vllm/model_executor/layers/quantization/moe_wna16.py
vllm/model_executor/layers/quantization/moe_wna16.py
+5
-0
vllm/model_executor/layers/quantization/mxfp4.py
vllm/model_executor/layers/quantization/mxfp4.py
+4
-4
vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
..._executor/layers/quantization/utils/flashinfer_fp4_moe.py
+80
-1
vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
...el_executor/layers/quantization/utils/flashinfer_utils.py
+2
-7
vllm/model_executor/layers/quantization/utils/fp8_utils.py
vllm/model_executor/layers/quantization/utils/fp8_utils.py
+6
-3
vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
+12
-6
vllm/model_executor/layers/rotary_embedding/__init__.py
vllm/model_executor/layers/rotary_embedding/__init__.py
+160
-166
vllm/model_executor/layers/rotary_embedding/base.py
vllm/model_executor/layers/rotary_embedding/base.py
+17
-3
vllm/model_executor/layers/rotary_embedding/common.py
vllm/model_executor/layers/rotary_embedding/common.py
+153
-71
vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py
...model_executor/layers/rotary_embedding/ernie45_vl_rope.py
+10
-3
vllm/model_executor/layers/rotary_embedding/mrope.py
vllm/model_executor/layers/rotary_embedding/mrope.py
+20
-5
vllm/model_executor/layers/rotary_embedding/xdrope.py
vllm/model_executor/layers/rotary_embedding/xdrope.py
+62
-4
vllm/model_executor/models/adapters.py
vllm/model_executor/models/adapters.py
+12
-0
vllm/model_executor/models/afmoe.py
vllm/model_executor/models/afmoe.py
+1
-2
vllm/model_executor/models/apertus.py
vllm/model_executor/models/apertus.py
+0
-1
No files found.
vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py
View file @
a3f8d5dd
...
...
@@ -19,14 +19,15 @@ from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfi
class
CPUScaledMMLinearKernel
(
ScaledMMLinearKernel
):
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
return
75
def
is_supported
(
cls
,
compute_capability
:
int
|
None
=
None
)
->
tuple
[
bool
,
str
|
None
]:
if
not
current_platform
.
is_cpu
():
return
False
,
"Requires CPU."
return
True
,
None
@
classmethod
def
can_implement
(
cls
,
c
:
ScaledMMLinearLayerConfig
)
->
tuple
[
bool
,
str
|
None
]:
if
not
current_platform
.
is_cpu
():
return
False
,
"CPUScaledMM requires running on CPU."
return
True
,
None
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
...
...
vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py
View file @
a3f8d5dd
...
...
@@ -16,14 +16,21 @@ from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfi
class
CutlassScaledMMLinearKernel
(
ScaledMMLinearKernel
):
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
return
75
def
is_supported
(
cls
,
compute_capability
:
int
|
None
=
None
)
->
tuple
[
bool
,
str
|
None
]:
if
not
current_platform
.
is_cuda
():
return
False
,
"Requires CUDA."
if
compute_capability
is
None
:
_cc
=
current_platform
.
get_device_capability
()
if
_cc
is
not
None
:
compute_capability
=
_cc
.
major
*
10
+
_cc
.
minor
if
compute_capability
is
not
None
and
compute_capability
<
75
:
return
False
,
f
"requires capability 75, got
{
compute_capability
}
"
return
True
,
None
@
classmethod
def
can_implement
(
cls
,
c
:
ScaledMMLinearLayerConfig
)
->
tuple
[
bool
,
str
|
None
]:
if
not
current_platform
.
is_cuda
():
return
False
,
"CutlassScaledMM requires running on CUDA."
return
True
,
None
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
...
...
vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py
View file @
a3f8d5dd
...
...
@@ -4,34 +4,53 @@
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm
import
(
# noqa: E501
triton_scaled_mm
,
)
from
vllm.model_executor.layers.quantization.utils
import
replace_parameter
from
vllm.platforms
import
current_platform
from
.cutlass
import
CutlassScaledMMLinearKernel
from
.ScaledMMLinearKernel
import
ScaledMMLinearLayerConfig
from
.ScaledMMLinearKernel
import
ScaledMMLinearKernel
,
ScaledMMLinearLayerConfig
class
TritonScaledMMLinearKernel
(
Cutlass
ScaledMMLinearKernel
):
class
TritonScaledMMLinearKernel
(
ScaledMMLinearKernel
):
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
return
75
def
is_supported
(
cls
,
compute_capability
:
int
|
None
=
None
)
->
tuple
[
bool
,
str
|
None
]:
if
current_platform
.
is_cuda_alike
():
return
True
,
None
return
False
,
"Requires ROCm or CUDA."
@
classmethod
def
can_implement
(
cls
,
c
:
ScaledMMLinearLayerConfig
)
->
tuple
[
bool
,
str
|
None
]:
if
current_platform
.
is_cpu
():
return
(
False
,
"TritonScaledMMLinearKernel requires Triton which is not "
+
"currently supported on CPU."
,
)
if
not
c
.
input_symmetric
:
return
(
False
,
"TritonScaledMMLinearKernel only supports symmetric "
+
"quantization."
,
)
return
False
,
"Only symmetric input is supported."
return
True
,
None
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
super
().
process_weights_after_loading
(
layer
)
weight
=
getattr
(
layer
,
self
.
w_q_name
)
replace_parameter
(
layer
,
self
.
w_q_name
,
torch
.
nn
.
Parameter
(
weight
.
t
().
data
,
requires_grad
=
False
),
)
# INPUT SCALE
if
self
.
config
.
is_static_input_scheme
:
input_scale
=
getattr
(
layer
,
self
.
i_s_name
)
replace_parameter
(
layer
,
self
.
i_s_name
,
torch
.
nn
.
Parameter
(
input_scale
.
max
(),
requires_grad
=
False
),
)
setattr
(
layer
,
self
.
i_zp_name
,
None
)
else
:
setattr
(
layer
,
self
.
i_s_name
,
None
)
setattr
(
layer
,
self
.
i_zp_name
,
None
)
setattr
(
layer
,
self
.
azp_adj_name
,
None
)
def
apply_weights
(
self
,
...
...
@@ -39,4 +58,14 @@ class TritonScaledMMLinearKernel(CutlassScaledMMLinearKernel):
x
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
return
super
().
apply_weights
(
layer
,
x
,
bias
)
w_q
,
w_s
,
i_s
,
i_zp
,
azp_adj
=
self
.
_get_weight_params
(
layer
)
x_q
,
x_s
,
x_zp
=
ops
.
scaled_int8_quant
(
x
.
contiguous
(),
i_s
,
i_zp
,
symmetric
=
True
)
assert
x_zp
is
None
,
"Triton kernel only supports symmetric quantization"
return
triton_scaled_mm
(
x_q
,
w_q
,
scale_a
=
x_s
,
scale_b
=
w_s
,
out_dtype
=
x
.
dtype
,
bias
=
bias
)
vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py
View file @
a3f8d5dd
...
...
@@ -17,11 +17,12 @@ from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfi
class
XLAScaledMMLinearKernel
(
ScaledMMLinearKernel
):
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
raise
NotImplementedError
(
"TPU platform does have a concept of compute capability, "
"this method should not be called."
)
def
is_supported
(
cls
,
compute_capability
:
int
|
None
=
None
)
->
tuple
[
bool
,
str
|
None
]:
if
not
current_platform
.
is_tpu
():
return
False
,
"Requires TPU."
return
True
,
None
@
classmethod
def
can_implement
(
cls
,
c
:
ScaledMMLinearLayerConfig
)
->
tuple
[
bool
,
str
|
None
]:
...
...
vllm/model_executor/layers/quantization/modelopt.py
View file @
a3f8d5dd
...
...
@@ -38,6 +38,7 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from
vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe
import
(
build_flashinfer_fp4_cutlass_moe_prepare_finalize
,
flashinfer_trtllm_fp4_moe
,
flashinfer_trtllm_fp4_routed_moe
,
prepare_static_weights_for_trtllm_fp4_moe
,
reorder_w1w3_to_w3w1
,
select_nvfp4_gemm_impl
,
...
...
@@ -80,6 +81,7 @@ from vllm.utils.flashinfer import (
has_flashinfer
,
has_flashinfer_moe
,
)
from
vllm.utils.math_utils
import
round_up
if
TYPE_CHECKING
:
from
vllm.model_executor.models.utils
import
WeightsMapper
...
...
@@ -186,7 +188,24 @@ class ModelOptQuantConfigBase(QuantizationConfig):
def
apply_vllm_mapper
(
self
,
hf_to_vllm_mapper
:
"WeightsMapper"
):
if
len
(
self
.
exclude_modules
)
>
0
:
self
.
exclude_modules
=
hf_to_vllm_mapper
.
apply_list
(
self
.
exclude_modules
)
# This is a workaround for the weights remapping issue:
# https://github.com/vllm-project/vllm/issues/28072
# Right now, the Nvidia ModelOpt library use just one wildcard pattern:
# module_path*
# It gets applied if the whole tree of modules rooted at module_path
# is not quantized. Here we replace such pattern by 2 patterns that are
# collectively equivalent to the original pattern:
# module_path
# module_path.*
new_exclude_modules
=
[]
for
exclude
in
self
.
exclude_modules
:
if
len
(
exclude
)
>=
2
and
exclude
[
-
1
]
==
"*"
and
exclude
[
-
2
]
!=
"."
:
new_exclude_modules
.
append
(
exclude
[:
-
1
])
new_exclude_modules
.
append
(
exclude
[:
-
1
]
+
".*"
)
else
:
new_exclude_modules
.
append
(
exclude
)
self
.
exclude_modules
=
hf_to_vllm_mapper
.
apply_list
(
new_exclude_modules
)
@
staticmethod
def
get_config_filenames
()
->
list
[
str
]:
...
...
@@ -606,6 +625,9 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
Only supports pre-quantized checkpoints with FP8 weights and scales.
"""
if
self
.
flashinfer_moe_backend
is
not
None
:
self
.
_maybe_pad_intermediate_for_flashinfer
(
layer
)
layer
.
w13_weight
=
Parameter
(
layer
.
w13_weight
.
data
,
requires_grad
=
False
)
layer
.
w2_weight
=
Parameter
(
layer
.
w2_weight
.
data
,
requires_grad
=
False
)
...
...
@@ -683,6 +705,50 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
rotate_flashinfer_fp8_moe_weights
(
layer
.
w13_weight
,
layer
.
w2_weight
)
register_moe_scaling_factors
(
layer
)
def
_maybe_pad_intermediate_for_flashinfer
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
"""Pad intermediate size so FlashInfer kernels' alignment constraints hold.
Some FlashInfer FP8 MoE kernels require the (gated) intermediate size
used for GEMM to be divisible by a small alignment value. When this is
not satisfied (e.g. with certain tensor-parallel sizes), we pad the
gate/up and down projection weights along the intermediate dim.
"""
if
not
hasattr
(
layer
,
"w13_weight"
)
or
not
hasattr
(
layer
,
"w2_weight"
):
return
# Current local intermediate size (per partition) is the K dimension of
# the down projection.
num_experts
,
hidden_size
,
intermediate
=
layer
.
w2_weight
.
shape
min_alignment
=
16
padded_intermediate
=
round_up
(
intermediate
,
min_alignment
)
if
padded_intermediate
==
intermediate
:
return
logger
.
info
(
"Padding intermediate size from %d to %d for up/down projection weights."
,
intermediate
,
padded_intermediate
,
)
up_mult
=
2
if
self
.
moe
.
is_act_and_mul
else
1
padded_gate_up_dim
=
up_mult
*
padded_intermediate
# Pad w13 and w12 along its intermediate dimension.
w13
=
layer
.
w13_weight
.
data
padded_w13
=
w13
.
new_zeros
((
num_experts
,
padded_gate_up_dim
,
hidden_size
))
padded_w13
[:,
:
w13
.
shape
[
1
],
:]
=
w13
layer
.
w13_weight
.
data
=
padded_w13
w2
=
layer
.
w2_weight
.
data
padded_w2
=
w2
.
new_zeros
((
num_experts
,
hidden_size
,
padded_intermediate
))
padded_w2
[:,
:,
:
intermediate
]
=
w2
layer
.
w2_weight
.
data
=
padded_w2
if
hasattr
(
layer
,
"intermediate_size_per_partition"
):
layer
.
intermediate_size_per_partition
=
padded_intermediate
def
get_fused_moe_quant_config
(
self
,
layer
:
torch
.
nn
.
Module
)
->
FusedMoEQuantConfig
|
None
:
...
...
@@ -1325,7 +1391,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
"Accuracy may be affected."
)
w13_weight_scale_2
=
layer
.
w13_weight_scale_2
[:,
0
]
w13_weight_scale_2
=
layer
.
w13_weight_scale_2
[:,
0
]
.
contiguous
()
layer
.
w13_weight_scale_2
=
Parameter
(
w13_weight_scale_2
,
requires_grad
=
False
)
# Common processing for input scales and alphas
...
...
@@ -1482,6 +1548,10 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
a2_gscale
=
layer
.
w2_input_scale_quant
,
)
@
property
def
supports_eplb
(
self
)
->
bool
:
return
True
def
apply
(
self
,
layer
:
FusedMoE
,
...
...
@@ -1500,11 +1570,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
if
(
self
.
allow_flashinfer
and
self
.
flashinfer_moe_backend
==
FlashinferMoeBackend
.
TENSORRT_LLM
and
not
layer
.
enable_eplb
):
if
layer
.
enable_eplb
:
raise
NotImplementedError
(
"EPLB not supported for `ModelOptNvFp4FusedMoE` yet."
)
return
flashinfer_trtllm_fp4_moe
(
layer
=
layer
,
x
=
x
,
...
...
@@ -1522,6 +1589,20 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
router_logits
=
router_logits
,
)
# EPLB path
if
(
self
.
allow_flashinfer
and
self
.
flashinfer_moe_backend
==
FlashinferMoeBackend
.
TENSORRT_LLM
):
return
flashinfer_trtllm_fp4_routed_moe
(
layer
=
layer
,
x
=
x
,
topk_ids
=
topk_ids
,
topk_weights
=
topk_weights
,
top_k
=
layer
.
top_k
,
global_num_experts
=
layer
.
global_num_experts
,
)
if
self
.
use_marlin
:
return
fused_marlin_moe
(
x
,
...
...
vllm/model_executor/layers/quantization/moe_wna16.py
View file @
a3f8d5dd
...
...
@@ -17,6 +17,9 @@ from vllm.model_executor.layers.fused_moe.layer import (
FusedMoEMethodBase
,
FusedMoeWeightScaleSupported
,
)
from
vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method
import
(
UnquantizedFusedMoEMethod
,
)
from
vllm.model_executor.layers.linear
import
LinearBase
,
UnquantizedLinearMethod
from
vllm.model_executor.layers.quantization
import
QuantizationMethods
from
vllm.model_executor.layers.quantization.base_config
import
(
...
...
@@ -162,6 +165,8 @@ class MoeWNA16Config(QuantizationConfig):
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
)
->
Optional
[
"QuantizeMethodBase"
]:
if
is_layer_skipped_quant
(
prefix
,
self
.
modules_to_not_convert
):
if
isinstance
(
layer
,
FusedMoE
):
return
UnquantizedFusedMoEMethod
(
layer
.
moe_config
)
return
UnquantizedLinearMethod
()
elif
isinstance
(
layer
,
LinearBase
):
# Avoid circular import
...
...
vllm/model_executor/layers/quantization/mxfp4.py
View file @
a3f8d5dd
...
...
@@ -118,19 +118,19 @@ def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend:
logger
.
info_once
(
"Using FlashInfer MXFP4 BF16 backend for SM90"
)
return
Mxfp4Backend
.
SM90_FI_MXFP4_BF16
elif
(
current_platform
.
is_device_capability
(
100
)
current_platform
.
is_device_capability
_family
(
100
)
and
has_flashinfer
()
and
envs
.
VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS
):
logger
.
info_once
(
"Using FlashInfer MXFP4 MXFP8 CUTLASS backend for SM100"
)
return
Mxfp4Backend
.
SM100_FI_MXFP4_MXFP8_CUTLASS
elif
(
current_platform
.
is_device_capability
(
100
)
current_platform
.
is_device_capability
_family
(
100
)
and
has_flashinfer
()
and
envs
.
VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
):
return
Mxfp4Backend
.
SM100_FI_MXFP4_MXFP8_TRTLLM
elif
current_platform
.
is_device_capability
(
100
)
and
has_flashinfer
():
elif
current_platform
.
is_device_capability
_family
(
100
)
and
has_flashinfer
():
logger
.
info_once
(
"Using FlashInfer MXFP4 BF16 backend for SM100, "
"For faster performance on SM100, consider setting "
...
...
@@ -139,7 +139,7 @@ def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend:
)
return
Mxfp4Backend
.
SM100_FI_MXFP4_BF16
elif
(
current_platform
.
is_device_capability
(
100
)
current_platform
.
is_device_capability
_family
(
100
)
or
current_platform
.
is_device_capability
(
90
)
)
and
not
has_flashinfer
():
logger
.
warning_once
(
...
...
vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
View file @
a3f8d5dd
...
...
@@ -50,7 +50,7 @@ def is_flashinfer_fp4_cutedsl_moe_available() -> bool:
envs
.
VLLM_USE_FLASHINFER_MOE_FP4
and
has_flashinfer_cutedsl_grouped_gemm_nt_masked
()
and
current_platform
.
is_cuda
()
and
current_platform
.
is_device_capability
(
100
)
and
current_platform
.
is_device_capability
_family
(
100
)
)
...
...
@@ -331,3 +331,82 @@ def flashinfer_trtllm_fp4_moe(
)[
0
]
return
out
def
flashinfer_trtllm_fp4_routed_moe
(
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
top_k
:
int
,
global_num_experts
:
int
,
)
->
torch
.
Tensor
:
"""
Apply FlashInfer TensorRT-LLM FP4 MoE kernel. Uses packed
input top k expert indices and scores rather than computing
top k expert indices from scores.
Args:
layer: The MoE layer with weights and scales
x: Input tensor
topk_ids: Ids of selected experts
top_k: Number of experts to select per token
global_num_experts: Total number of experts across all ranks
Returns:
Output tensor from the MoE layer
"""
import
flashinfer
# Pack top k ids and expert weights into a single int32 tensor, as
# required by TRT-LLM
packed_tensor
=
(
topk_ids
.
to
(
torch
.
int32
)
<<
16
)
|
topk_weights
.
to
(
torch
.
bfloat16
).
view
(
torch
.
int16
)
# Quantize input to FP4
a1_gscale
=
layer
.
w13_input_scale_quant
(
hidden_states_fp4
,
hidden_states_scale_linear_fp4
)
=
flashinfer
.
fp4_quantize
(
x
,
a1_gscale
,
is_sf_swizzled_layout
=
False
,
)
# Call TRT-LLM FP4 block-scale MoE kernel
out
=
flashinfer
.
fused_moe
.
trtllm_fp4_block_scale_routed_moe
(
topk_ids
=
packed_tensor
,
routing_bias
=
None
,
hidden_states
=
hidden_states_fp4
,
hidden_states_scale
=
hidden_states_scale_linear_fp4
.
view
(
torch
.
float8_e4m3fn
).
flatten
(),
gemm1_weights
=
layer
.
gemm1_weights_fp4_shuffled
.
data
,
gemm1_weights_scale
=
layer
.
gemm1_scales_fp4_shuffled
.
data
.
view
(
torch
.
float8_e4m3fn
),
gemm1_bias
=
None
,
gemm1_alpha
=
None
,
gemm1_beta
=
None
,
gemm1_clamp_limit
=
None
,
gemm2_weights
=
layer
.
gemm2_weights_fp4_shuffled
.
data
,
gemm2_weights_scale
=
layer
.
gemm2_scales_fp4_shuffled
.
data
.
view
(
torch
.
float8_e4m3fn
),
gemm2_bias
=
None
,
output1_scale_scalar
=
layer
.
g1_scale_c
.
data
,
output1_scale_gate_scalar
=
layer
.
g1_alphas
.
data
,
output2_scale_scalar
=
layer
.
g2_alphas
.
data
,
num_experts
=
global_num_experts
,
top_k
=
top_k
,
n_group
=
0
,
topk_group
=
0
,
intermediate_size
=
layer
.
intermediate_size_per_partition
,
local_expert_offset
=
layer
.
ep_rank
*
layer
.
local_num_experts
,
local_num_experts
=
layer
.
local_num_experts
,
routed_scaling_factor
=
None
,
tile_tokens_dim
=
None
,
routing_method_type
=
1
,
do_finalize
=
True
,
)[
0
]
return
out
vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
View file @
a3f8d5dd
...
...
@@ -247,11 +247,6 @@ def flashinfer_cutlass_moe_fp8(
assert
quant_config
is
not
None
# Construct modular kernel with block-scale support when requested.
parallel_config
=
getattr
(
getattr
(
layer
,
"vllm_config"
,
None
),
"parallel_config"
,
None
,
)
fused_experts
=
mk
.
FusedMoEModularKernel
(
build_flashinfer_fp8_cutlass_moe_prepare_finalize
(
moe
=
moe
,
use_deepseek_fp8_block_scale
=
use_deepseek_fp8_block_scale
...
...
@@ -262,7 +257,7 @@ def flashinfer_cutlass_moe_fp8(
out_dtype
=
hidden_states
.
dtype
,
use_deepseek_fp8_block_scale
=
use_deepseek_fp8_block_scale
,
),
parallel_config
=
parallel_config
,
moe_
parallel_config
=
layer
.
moe_
parallel_config
,
)
return
fused_experts
(
...
...
@@ -290,7 +285,7 @@ def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
if
flashinfer_moe_backend
in
backend_map
:
if
(
flashinfer_moe_backend
==
"latency"
and
not
current_platform
.
is_device_capability
(
100
)
and
not
current_platform
.
is_device_capability
_family
(
100
)
):
logger
.
info_once
(
"Flashinfer TRTLLM MOE backend is only supported on "
...
...
vllm/model_executor/layers/quantization/utils/fp8_utils.py
View file @
a3f8d5dd
...
...
@@ -247,7 +247,7 @@ class W8A8BlockFp8LinearOp:
self
.
act_quant_group_shape
=
act_quant_group_shape
self
.
is_deep_gemm_supported
=
is_deep_gemm_supported
()
self
.
is_hopper
=
current_platform
.
is_device_capability
(
90
)
self
.
is_blackwell
=
current_platform
.
is_device_capability
(
100
)
self
.
is_blackwell
=
current_platform
.
is_device_capability
_family
(
100
)
self
.
use_deep_gemm_e8m0
=
is_deep_gemm_e8m0_used
()
# Get the correct blockscale mul and input quant operations.
...
...
@@ -762,9 +762,12 @@ def per_token_group_quant_fp8(
)
assert
x
.
stride
(
-
1
)
==
1
,
"`x` groups must be contiguous"
# Using the default value (240.0) from pytorch will cause accuracy
# issue on dynamic quantization models. Here use 224.0 for fnuz on ROCm
# platforms that use the torch.float8_e4mefnuz dtype.
finfo
=
torch
.
finfo
(
dtype
)
fp8_min
=
finfo
.
min
fp8_max
=
finfo
.
max
fp8_min
=
-
224.0
if
current_platform
.
is_fp8_fnuz
()
else
finfo
.
min
fp8_max
=
224.0
if
current_platform
.
is_fp8_fnuz
()
else
finfo
.
max
assert
out_q
is
None
or
out_q
.
shape
==
x
.
shape
x_q
=
out_q
...
...
vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
View file @
a3f8d5dd
...
...
@@ -57,12 +57,18 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps):
mx_axis
=
1
,
num_warps
=
num_warps
)
)
if
current_platform
.
is_cuda
()
and
current_platform
.
is_device_capability
(
100
):
constraints
=
{
"is_persistent"
:
True
,
"epilogue_subtile"
:
1
,
}
opt_flags
.
update_opt_flags_constraints
(
constraints
)
if
current_platform
.
is_cuda
():
if
current_platform
.
is_device_capability
(
90
):
constraints
=
{
"split_k"
:
1
,
}
opt_flags
.
update_opt_flags_constraints
(
constraints
)
elif
current_platform
.
is_device_capability_family
(
100
):
constraints
=
{
"is_persistent"
:
True
,
"epilogue_subtile"
:
1
,
}
opt_flags
.
update_opt_flags_constraints
(
constraints
)
# transpose the tensor so that the quantization axis is on dim1
quant_tensor
=
quant_tensor
.
transpose
(
-
2
,
-
1
)
scale
=
scale
.
transpose
(
-
2
,
-
1
)
...
...
vllm/model_executor/layers/rotary_embedding/__init__.py
View file @
a3f8d5dd
...
...
@@ -25,7 +25,6 @@ _ROPE_DICT: dict[tuple, RotaryEmbedding] = {}
def
get_rope
(
head_size
:
int
,
rotary_dim
:
int
,
max_position
:
int
,
is_neox_style
:
bool
=
True
,
rope_parameters
:
dict
[
str
,
Any
]
|
None
=
None
,
...
...
@@ -54,12 +53,15 @@ def get_rope(
else
:
dual_chunk_attention_args
=
None
partial_rotary_factor
=
1.0
if
rope_parameters
is
not
None
:
partial_rotary_factor
=
rope_parameters
.
get
(
"partial_rotary_factor"
,
1.0
)
rope_parameters
=
rope_parameters
or
{}
base
=
rope_parameters
.
get
(
"rope_theta"
,
10000
)
scaling_type
=
rope_parameters
.
get
(
"rope_type"
,
"default"
)
partial_rotary_factor
=
rope_parameters
.
get
(
"partial_rotary_factor"
,
1.0
)
if
partial_rotary_factor
<=
0.0
or
partial_rotary_factor
>
1.0
:
raise
ValueError
(
f
"
{
partial_rotary_factor
=
}
must be between 0.0 and 1.0"
)
rotary_dim
=
int
(
head_size
*
partial_rotary_factor
)
if
partial_rotary_factor
<
1.0
:
rotary_dim
=
int
(
rotary_dim
*
partial_rotary_factor
)
key
=
(
head_size
,
rotary_dim
,
...
...
@@ -72,7 +74,6 @@ def get_rope(
if
key
in
_ROPE_DICT
:
return
_ROPE_DICT
[
key
]
base
=
rope_parameters
[
"rope_theta"
]
if
rope_parameters
else
10000
if
dual_chunk_attention_config
is
not
None
:
extra_kwargs
=
{
k
:
v
...
...
@@ -88,208 +89,201 @@ def get_rope(
dtype
,
**
extra_kwargs
,
)
elif
not
rope_parameters
:
rotary_emb
=
RotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
dtype
)
else
:
scaling_type
=
rope_parameters
[
"rope_type"
]
if
scaling_type
==
"llama3"
:
scaling_factor
=
rope_parameters
[
"factor"
]
low_freq_factor
=
rope_parameters
[
"low_freq_factor"
]
high_freq_factor
=
rope_parameters
[
"high_freq_factor"
]
original_max_position
=
rope_parameters
[
"original_max_position_embeddings"
]
rotary_emb
=
Llama3RotaryEmbedding
(
elif
scaling_type
==
"default"
:
if
"mrope_section"
in
rope_parameters
:
rotary_emb
=
MRotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
dtype
,
scaling_factor
,
low_freq_factor
,
high_freq_factor
,
original_max_position
,
mrope_section
=
rope_parameters
[
"mrope_section"
],
mrope_interleaved
=
rope_parameters
.
get
(
"mrope_interleaved"
,
False
),
)
elif
scaling_type
==
"mllama4"
:
rotary_emb
=
Llama4VisionRotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
dtype
)
elif
scaling_type
==
"default"
:
if
"mrope_section"
in
rope_parameters
:
rotary_emb
=
MRotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
dtype
,
mrope_section
=
rope_parameters
[
"mrope_section"
],
mrope_interleaved
=
rope_parameters
.
get
(
"mrope_interleaved"
,
False
),
)
else
:
rotary_emb
=
RotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
dtype
,
)
elif
scaling_type
==
"linear"
:
scaling_factor
=
rope_parameters
[
"factor"
]
rotary_emb
=
LinearScalingRotaryEmbedding
(
else
:
rotary_emb
=
RotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
scaling_factor
,
dtype
,
)
elif
scaling_type
==
"ntk"
:
scaling_factor
=
rope_parameters
[
"factor"
]
mixed_b
=
rope_parameters
.
get
(
"mixed_b"
)
rotary_emb
=
NTKScalingRotaryEmbedding
(
elif
scaling_type
==
"llama3"
:
scaling_factor
=
rope_parameters
[
"factor"
]
low_freq_factor
=
rope_parameters
[
"low_freq_factor"
]
high_freq_factor
=
rope_parameters
[
"high_freq_factor"
]
original_max_position
=
rope_parameters
[
"original_max_position_embeddings"
]
rotary_emb
=
Llama3RotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
dtype
,
scaling_factor
,
low_freq_factor
,
high_freq_factor
,
original_max_position
,
)
elif
scaling_type
==
"mllama4"
:
rotary_emb
=
Llama4VisionRotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
dtype
)
elif
scaling_type
==
"linear"
:
scaling_factor
=
rope_parameters
[
"factor"
]
rotary_emb
=
LinearScalingRotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
scaling_factor
,
dtype
,
)
elif
scaling_type
==
"ntk"
:
scaling_factor
=
rope_parameters
[
"factor"
]
mixed_b
=
rope_parameters
.
get
(
"mixed_b"
)
rotary_emb
=
NTKScalingRotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
scaling_factor
,
dtype
,
mixed_b
,
)
elif
scaling_type
==
"dynamic"
:
if
"alpha"
in
rope_parameters
:
scaling_alpha
=
rope_parameters
[
"alpha"
]
rotary_emb
=
DynamicNTKAlphaRotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
scaling_
factor
,
scaling_
alpha
,
dtype
,
mixed_b
,
)
elif
scaling_type
==
"dynamic"
:
if
"alpha"
in
rope_parameters
:
scaling_alpha
=
rope_parameters
[
"alpha"
]
rotary_emb
=
DynamicNTKAlphaRotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
scaling_alpha
,
dtype
,
)
elif
"factor"
in
rope_parameters
:
scaling_factor
=
rope_parameters
[
"factor"
]
rotary_emb
=
DynamicNTKScalingRotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
scaling_factor
,
dtype
,
)
else
:
raise
ValueError
(
"Dynamic rope scaling must contain either 'alpha' or 'factor' field"
)
elif
scaling_type
==
"xdrope"
:
scaling_alpha
=
rope_parameters
[
"alpha"
]
rotary_emb
=
XDRotaryEmbedding
(
elif
"factor"
in
rope_parameters
:
scaling_factor
=
rope_parameters
[
"factor"
]
rotary_emb
=
DynamicNTKScalingRotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
scaling_
alpha
,
scaling_
factor
,
dtype
,
xdrope_section
=
rope_parameters
[
"xdrope_section"
],
)
elif
scaling_type
==
"yarn"
:
scaling_factor
=
rope_parameters
[
"factor"
]
original_max_position
=
rope_parameters
[
"original_max_position_embeddings"
]
extra_kwargs
=
{
k
:
v
for
k
,
v
in
rope_parameters
.
items
()
if
k
in
(
"extrapolation_factor"
,
"attn_factor"
,
"beta_fast"
,
"beta_slow"
,
"apply_yarn_scaling"
,
"truncate"
,
)
}
if
"mrope_section"
in
rope_parameters
:
extra_kwargs
.
pop
(
"apply_yarn_scaling"
,
None
)
rotary_emb
=
MRotaryEmbedding
(
head_size
,
rotary_dim
,
original_max_position
,
base
,
is_neox_style
,
dtype
,
mrope_section
=
rope_parameters
[
"mrope_section"
],
mrope_interleaved
=
rope_parameters
.
get
(
"mrope_interleaved"
,
False
),
scaling_factor
=
scaling_factor
,
**
extra_kwargs
,
)
else
:
rotary_emb
=
YaRNScalingRotaryEmbedding
(
head_size
,
rotary_dim
,
original_max_position
,
base
,
is_neox_style
,
scaling_factor
,
dtype
,
**
extra_kwargs
,
)
elif
scaling_type
in
[
"deepseek_yarn"
,
"deepseek_llama_scaling"
]:
scaling_factor
=
rope_parameters
[
"factor"
]
original_max_position
=
rope_parameters
[
"original_max_position_embeddings"
]
# assert max_position == original_max_position * scaling_factor
extra_kwargs
=
{
k
:
v
for
k
,
v
in
rope_parameters
.
items
()
if
k
in
(
"extrapolation_factor"
,
"attn_factor"
,
"beta_fast"
,
"beta_slow"
,
"mscale"
,
"mscale_all_dim"
,
)
}
rotary_emb
=
DeepseekScalingRotaryEmbedding
(
else
:
raise
ValueError
(
"Dynamic rope scaling must contain either 'alpha' or 'factor' field"
)
elif
scaling_type
==
"xdrope"
:
scaling_alpha
=
rope_parameters
[
"alpha"
]
rotary_emb
=
XDRotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
scaling_alpha
,
dtype
,
xdrope_section
=
rope_parameters
[
"xdrope_section"
],
)
elif
scaling_type
==
"yarn"
:
scaling_factor
=
rope_parameters
[
"factor"
]
original_max_position
=
rope_parameters
[
"original_max_position_embeddings"
]
extra_kwargs
=
{
k
:
v
for
k
,
v
in
rope_parameters
.
items
()
if
k
in
(
"extrapolation_factor"
,
"attn_factor"
,
"beta_fast"
,
"beta_slow"
,
"apply_yarn_scaling"
,
"truncate"
,
)
}
if
"mrope_section"
in
rope_parameters
:
extra_kwargs
.
pop
(
"apply_yarn_scaling"
,
None
)
rotary_emb
=
MRotaryEmbedding
(
head_size
,
rotary_dim
,
original_max_position
,
base
,
is_neox_style
,
scaling_factor
,
dtype
,
mrope_section
=
rope_parameters
[
"mrope_section"
],
mrope_interleaved
=
rope_parameters
.
get
(
"mrope_interleaved"
,
False
),
scaling_factor
=
scaling_factor
,
**
extra_kwargs
,
)
elif
scaling_type
==
"longrope"
:
short_factor
=
rope_parameters
[
"short_factor"
]
long_factor
=
rope_parameters
[
"long_factor"
]
original_max_position
=
rope_parameters
[
"original_max_position_embeddings"
]
extra_kwargs
=
{
k
:
v
for
k
,
v
in
rope_parameters
.
items
()
if
k
in
(
"short_mscale"
,
"long_mscale"
)
}
rotary_emb
=
Phi3LongRoPEScaledRotaryEmbedding
(
else
:
rotary_emb
=
YaRNScalingRotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
original_max_position
,
base
,
is_neox_style
,
scaling_factor
,
dtype
,
short_factor
,
long_factor
,
**
extra_kwargs
,
)
else
:
raise
ValueError
(
f
"Unknown RoPE scaling type
{
scaling_type
}
"
)
elif
scaling_type
in
[
"deepseek_yarn"
,
"deepseek_llama_scaling"
]:
scaling_factor
=
rope_parameters
[
"factor"
]
original_max_position
=
rope_parameters
[
"original_max_position_embeddings"
]
# assert max_position == original_max_position * scaling_factor
extra_kwargs
=
{
k
:
v
for
k
,
v
in
rope_parameters
.
items
()
if
k
in
(
"extrapolation_factor"
,
"attn_factor"
,
"beta_fast"
,
"beta_slow"
,
"mscale"
,
"mscale_all_dim"
,
)
}
rotary_emb
=
DeepseekScalingRotaryEmbedding
(
head_size
,
rotary_dim
,
original_max_position
,
base
,
is_neox_style
,
scaling_factor
,
dtype
,
**
extra_kwargs
,
)
elif
scaling_type
==
"longrope"
:
short_factor
=
rope_parameters
[
"short_factor"
]
long_factor
=
rope_parameters
[
"long_factor"
]
original_max_position
=
rope_parameters
[
"original_max_position_embeddings"
]
extra_kwargs
=
{
k
:
v
for
k
,
v
in
rope_parameters
.
items
()
if
k
in
(
"short_mscale"
,
"long_mscale"
)
}
rotary_emb
=
Phi3LongRoPEScaledRotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
original_max_position
,
base
,
is_neox_style
,
dtype
,
short_factor
,
long_factor
,
**
extra_kwargs
,
)
else
:
raise
ValueError
(
f
"Unknown RoPE scaling type
{
scaling_type
}
"
)
_ROPE_DICT
[
key
]
=
rotary_emb
return
rotary_emb
vllm/model_executor/layers/rotary_embedding/base.py
View file @
a3f8d5dd
...
...
@@ -7,7 +7,7 @@ import torch
from
vllm._aiter_ops
import
rocm_aiter_ops
from
vllm.model_executor.custom_op
import
CustomOp
from
.common
import
a
pply
_r
otary
_emb_torch
from
.common
import
A
pply
R
otary
Emb
@
CustomOp
.
register
(
"rotary_embedding"
)
...
...
@@ -49,6 +49,10 @@ class RotaryEmbeddingBase(CustomOp):
rocm_aiter_ops
.
is_triton_rotary_embed_enabled
()
)
self
.
apply_rotary_emb
=
ApplyRotaryEmb
(
is_neox_style
=
self
.
is_neox_style
,
)
def
_compute_inv_freq
(
self
,
base
:
float
)
->
torch
.
Tensor
:
"""Compute the inverse frequency."""
# NOTE(woosuk): To exactly match the HF implementation, we need to
...
...
@@ -123,7 +127,12 @@ class RotaryEmbedding(RotaryEmbeddingBase):
query
=
query
.
view
(
num_tokens
,
-
1
,
head_size
)
query_rot
=
query
[...,
:
rotary_dim
]
query_pass
=
query
[...,
rotary_dim
:]
query_rot
=
apply_rotary_emb_torch
(
query_rot
,
cos
,
sin
,
is_neox_style
)
query_rot
=
ApplyRotaryEmb
.
forward_static
(
query_rot
,
cos
,
sin
,
is_neox_style
,
)
query
=
torch
.
cat
((
query_rot
,
query_pass
),
dim
=-
1
).
reshape
(
query_shape
)
# key may be None in some cases, e.g. cross-layer KV sharing
...
...
@@ -132,7 +141,12 @@ class RotaryEmbedding(RotaryEmbeddingBase):
key
=
key
.
view
(
num_tokens
,
-
1
,
head_size
)
key_rot
=
key
[...,
:
rotary_dim
]
key_pass
=
key
[...,
rotary_dim
:]
key_rot
=
apply_rotary_emb_torch
(
key_rot
,
cos
,
sin
,
is_neox_style
)
key_rot
=
ApplyRotaryEmb
.
forward_static
(
key_rot
,
cos
,
sin
,
is_neox_style
,
)
key
=
torch
.
cat
((
key_rot
,
key_pass
),
dim
=-
1
).
reshape
(
key_shape
)
return
query
,
key
...
...
vllm/model_executor/layers/rotary_embedding/common.py
View file @
a3f8d5dd
...
...
@@ -2,19 +2,14 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
math
from
collections.abc
import
Callable
from
functools
import
cache
from
importlib.util
import
find_spec
import
torch
from
vllm.logger
import
init_logger
from
vllm.
platforms
import
current_platform
from
vllm.
model_executor.custom_op
import
CustomOp
from
vllm.utils.torch_utils
import
direct_register_custom_op
if
current_platform
.
is_cuda
():
from
vllm.vllm_flash_attn.layers.rotary
import
apply_rotary_emb
logger
=
init_logger
(
__name__
)
...
...
@@ -32,71 +27,6 @@ def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
return
x
.
flatten
(
-
2
)
def apply_rotary_emb_torch(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    is_neox_style: bool,
) -> torch.Tensor:
    """Rotate ``x`` by the given cos/sin tables using plain PyTorch ops.

    ``cos`` and ``sin`` get an extra axis inserted at position -2 and are
    cast to ``x.dtype`` so they broadcast against the two halves of ``x``.

    Args:
        x: Tensor whose last dimension holds the values to rotate.
        cos: Cosine table; its last dimension must match half of ``x``'s
            last dimension after broadcasting.
        sin: Sine table, same layout as ``cos``.
        is_neox_style: If True, the last dimension is split into a first
            and a second half (Neox layout); otherwise even and odd
            positions are paired (GPT-J layout).

    Returns:
        The rotated tensor, laid out the same way as ``x``.
    """
    cos_b = cos.unsqueeze(-2).to(x.dtype)
    sin_b = sin.unsqueeze(-2).to(x.dtype)

    if is_neox_style:
        # Neox layout: contiguous halves, re-joined by concatenation.
        first, second = torch.chunk(x, 2, dim=-1)
        rot_first = first * cos_b - second * sin_b
        rot_second = second * cos_b + first * sin_b
        return torch.cat((rot_first, rot_second), dim=-1)

    # GPT-J layout: interleaved pairs, re-interleaved via stack + flatten.
    even = x[..., ::2]
    odd = x[..., 1::2]
    rot_even = even * cos_b - odd * sin_b
    rot_odd = odd * cos_b + even * sin_b
    return torch.stack((rot_even, rot_odd), dim=-1).flatten(-2)
def apply_rotary_emb_dispatch(
    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, is_neox_style: bool
) -> torch.Tensor:
    """
    Apply rotary embeddings with the platform-appropriate implementation.

    Args:
        x: [num_tokens, num_heads, head_size]
        cos: [num_tokens, head_size // 2]
        sin: [num_tokens, head_size // 2]
        is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
            positional embeddings.
    """
    if not current_platform.is_cuda():
        # Non-CUDA platforms use the pure PyTorch implementation.
        return apply_rotary_emb_torch(x, cos, sin, is_neox_style)

    # The CUDA kernel is fed a leading axis of size 1 and takes an
    # "interleaved" flag, which is the negation of is_neox_style.
    batched = x.unsqueeze(0)
    interleaved = not is_neox_style
    rotated = apply_rotary_emb(batched, cos, sin, interleaved)
    return rotated.squeeze(0)
@cache
def dispatch_rotary_emb_function(
    default: Callable[..., torch.Tensor] | None = None,
) -> Callable[..., torch.Tensor]:
    """Pick the rotary-embedding callable for the current platform.

    The result is memoised by ``functools.cache``, keyed on ``default``.

    Args:
        default: Callable returned when no platform-specific kernel is
            selected. When None, ``apply_rotary_emb_torch`` is used.

    Returns:
        ``apply_rotary_emb`` on CUDA; flash_attn's Triton ``apply_rotary``
        on ROCm when not compiling and flash_attn is importable; otherwise
        ``default`` or the PyTorch implementation.
    """
    if current_platform.is_cuda():
        return apply_rotary_emb

    # Outside of torch.compile on ROCm, prefer the flash_attn rotary
    # kernel. When torch.compile is active the naive PyTorch
    # implementation is faster, so it is used instead.
    rocm_eager = current_platform.is_rocm() and not torch.compiler.is_compiling()
    if rocm_eager:
        if find_spec("flash_attn") is None:
            logger.warning(
                "flash_attn is not installed. Falling back to PyTorch "
                "implementation for rotary embeddings."
            )
        else:
            from flash_attn.ops.triton.rotary import apply_rotary

            return apply_rotary

    return apply_rotary_emb_torch if default is None else default
# yarn functions
# Inverse dim formula to find dim based on number of rotations
def
yarn_find_correction_dim
(
...
...
@@ -186,3 +116,155 @@ direct_register_custom_op(
mutates_args
=
[
"query"
,
"key"
],
# These tensors are modified in-place
fake_impl
=
_flashinfer_rotary_embedding_fake
,
)
@CustomOp.register("apply_rotary_emb")
class ApplyRotaryEmb(CustomOp):
    """Custom op that applies rotary position embeddings to a tensor.

    Provides a pure PyTorch path (``forward_static`` / ``forward_native``),
    a CUDA path backed by ``vllm.vllm_flash_attn`` and a HIP path backed by
    the ``flash_attn`` Triton kernel when that package is importable.

    NOTE(review): which ``forward_*`` method is actually invoked is decided
    by ``CustomOp`` - confirm against its implementation.
    """

    def __init__(
        self,
        enforce_enable: bool = False,
        is_neox_style: bool = True,
        enable_fp32_compute: bool = False,
    ) -> None:
        """
        Args:
            enforce_enable: Forwarded unchanged to ``CustomOp.__init__``.
            is_neox_style: Whether to use the Neox-style or GPT-J-style.
            enable_fp32_compute: Temporarily convert x, cos, sin to FP32
                dtype for higher accuracy.
        """
        super().__init__(enforce_enable)
        self.is_neox_style = is_neox_style
        self.enable_fp32_compute = enable_fp32_compute

        # Resolved once at construction; stays None when flash_attn is not
        # installed, in which case forward_hip falls back to forward_native.
        self.apply_rotary_emb_flash_attn = None
        if find_spec("flash_attn") is not None:
            from flash_attn.ops.triton.rotary import apply_rotary

            self.apply_rotary_emb_flash_attn = apply_rotary

    @staticmethod
    def forward_static(
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        is_neox_style: bool = True,
        enable_fp32_compute: bool = False,
    ) -> torch.Tensor:
        """
        Args:
            x: [batch_size (optional), seq_len, num_heads, head_size]
            cos: [seq_len, head_size // 2]
            sin: [seq_len, head_size // 2]
            is_neox_style: Whether to use the Neox-style or GPT-J-style.
            enable_fp32_compute: Temporarily convert x, cos, sin to FP32 dtype
                                 for higher accuracy.

        Returns:
            The rotated tensor, cast back to the dtype ``x`` had on entry
            when ``enable_fp32_compute`` is set.
        """
        origin_dtype = x.dtype
        if enable_fp32_compute:
            x = x.float()

        # Insert a heads axis so cos/sin broadcast over num_heads, and match
        # the (possibly up-cast) dtype of x.
        cos = cos.unsqueeze(-2).to(x.dtype)
        sin = sin.unsqueeze(-2).to(x.dtype)

        if is_neox_style:
            x1, x2 = torch.chunk(x, 2, dim=-1)
        else:
            x1 = x[..., ::2]
            x2 = x[..., 1::2]

        o1 = x1 * cos - x2 * sin
        o2 = x2 * cos + x1 * sin

        if is_neox_style:
            output = torch.cat((o1, o2), dim=-1)
        else:
            output = torch.stack((o1, o2), dim=-1).flatten(-2)

        if enable_fp32_compute:
            output = output.to(origin_dtype)
        return output

    def forward_native(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> torch.Tensor:
        """Run ``forward_static`` with this instance's configuration."""
        output = self.forward_static(
            x, cos, sin, self.is_neox_style, self.enable_fp32_compute
        )
        return output

    def forward_cuda(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> torch.Tensor:
        """Apply rotary embeddings through ``vllm_flash_attn``'s kernel.

        A 3-D ``x`` ([seq_len, num_heads, head_size]) gets a temporary
        leading axis of size 1 for the kernel call, removed before returning.
        """
        from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb

        origin_dtype = x.dtype
        if self.enable_fp32_compute:
            x = x.float()
            cos = cos.float()
            sin = sin.float()

        origin_shape = x.shape
        if len(origin_shape) == 3:
            # x: [seq_len, num_heads, head_size]
            x = x.unsqueeze(0)

        # Arguments of apply_rotary_emb() in vllm_flash_attn:
        #     x: [batch_size, seq_len, nheads, headdim]
        #     cos, sin: [seqlen_rotary, rotary_dim / 2]
        #     interleaved: defaults to False (Neox-style).
        interleaved = not self.is_neox_style
        output = apply_rotary_emb(x, cos, sin, interleaved)

        if len(origin_shape) == 3:
            output = output.squeeze(0)
        if self.enable_fp32_compute:
            output = output.to(origin_dtype)
        return output

    def forward_hip(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> torch.Tensor:
        """Apply rotary embeddings through flash_attn's Triton kernel.

        Falls back to ``forward_native`` when ``flash_attn`` was not
        importable at construction time.
        """
        if self.apply_rotary_emb_flash_attn is not None:
            origin_dtype = x.dtype
            if self.enable_fp32_compute:
                x = x.float()
                cos = cos.float()
                sin = sin.float()

            origin_shape = x.shape
            if len(origin_shape) == 3:
                # x: [seq_len, num_heads, head_size]
                x = x.unsqueeze(0)

            # Arguments of apply_rotary() in flash_attn:
            #     x: [batch_size, seq_len, nheads, headdim]
            #     cos, sin: [seqlen_rotary, rotary_dim / 2]
            #     interleaved: defaults to False (Neox-style).
            interleaved = not self.is_neox_style
            output = self.apply_rotary_emb_flash_attn(
                x, cos, sin, interleaved=interleaved
            ).type_as(x)

            if len(origin_shape) == 3:
                output = output.squeeze(0)
            if self.enable_fp32_compute:
                output = output.to(origin_dtype)
        else:
            # Falling back to PyTorch native implementation.
            output = self.forward_native(x, cos, sin)

        return output

    def extra_repr(self) -> str:
        """Return the configuration fields as a comma-separated string."""
        s = f"is_neox_style={self.is_neox_style}"
        # The separator keeps the two fields from running together.
        s += f", enable_fp32_compute={self.enable_fp32_compute}"
        return s
vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py
View file @
a3f8d5dd
...
...
@@ -4,7 +4,6 @@
import
torch
from
.common
import
apply_rotary_emb_dispatch
from
.mrope
import
MRotaryEmbedding
...
...
@@ -55,14 +54,22 @@ class Ernie4_5_VLRotaryEmbedding(MRotaryEmbedding):
query
=
query
.
view
(
num_tokens
,
-
1
,
self
.
head_size
)
query_rot
=
query
[...,
:
self
.
rotary_dim
]
query_pass
=
query
[...,
self
.
rotary_dim
:]
query_rot
=
apply_rotary_emb_dispatch
(
query_rot
,
cos
,
sin
,
self
.
is_neox_style
)
query_rot
=
self
.
apply_rotary_emb
.
forward_native
(
query_rot
,
cos
,
sin
,
)
query
=
torch
.
cat
((
query_rot
,
query_pass
),
dim
=-
1
).
reshape
(
query_shape
)
key_shape
=
key
.
shape
key
=
key
.
view
(
num_tokens
,
-
1
,
self
.
head_size
)
key_rot
=
key
[...,
:
self
.
rotary_dim
]
key_pass
=
key
[...,
self
.
rotary_dim
:]
key_rot
=
apply_rotary_emb_dispatch
(
key_rot
,
cos
,
sin
,
self
.
is_neox_style
)
key_rot
=
self
.
apply_rotary_emb
.
forward_native
(
key_rot
,
cos
,
sin
,
)
key
=
torch
.
cat
((
key_rot
,
key_pass
),
dim
=-
1
).
reshape
(
key_shape
)
return
query
,
key
...
...
vllm/model_executor/layers/rotary_embedding/mrope.py
View file @
a3f8d5dd
...
...
@@ -8,7 +8,6 @@ import torch
from
vllm.triton_utils
import
tl
,
triton
from
.base
import
RotaryEmbeddingBase
from
.common
import
apply_rotary_emb_dispatch
from
.yarn_scaling_rope
import
YaRNScalingRotaryEmbedding
,
yarn_get_mscale
...
...
@@ -301,14 +300,22 @@ class MRotaryEmbedding(RotaryEmbeddingBase):
query
=
query
.
view
(
num_tokens
,
-
1
,
self
.
head_size
)
query_rot
=
query
[...,
:
self
.
rotary_dim
]
query_pass
=
query
[...,
self
.
rotary_dim
:]
query_rot
=
apply_rotary_emb_dispatch
(
query_rot
,
cos
,
sin
,
self
.
is_neox_style
)
query_rot
=
self
.
apply_rotary_emb
.
forward_native
(
query_rot
,
cos
,
sin
,
)
query
=
torch
.
cat
((
query_rot
,
query_pass
),
dim
=-
1
).
reshape
(
query_shape
)
key_shape
=
key
.
shape
key
=
key
.
view
(
num_tokens
,
-
1
,
self
.
head_size
)
key_rot
=
key
[...,
:
self
.
rotary_dim
]
key_pass
=
key
[...,
self
.
rotary_dim
:]
key_rot
=
apply_rotary_emb_dispatch
(
key_rot
,
cos
,
sin
,
self
.
is_neox_style
)
key_rot
=
self
.
apply_rotary_emb
.
forward_native
(
key_rot
,
cos
,
sin
,
)
key
=
torch
.
cat
((
key_rot
,
key_pass
),
dim
=-
1
).
reshape
(
key_shape
)
return
query
,
key
...
...
@@ -347,13 +354,21 @@ class MRotaryEmbedding(RotaryEmbeddingBase):
query
=
query
.
view
(
num_tokens
,
-
1
,
self
.
head_size
)
query_rot
=
query
[...,
:
self
.
rotary_dim
]
query_pass
=
query
[...,
self
.
rotary_dim
:]
query_rot
=
apply_rotary_emb_dispatch
(
query_rot
,
cos
,
sin
,
self
.
is_neox_style
)
query_rot
=
self
.
apply_rotary_emb
(
query_rot
,
cos
,
sin
,
)
query
=
torch
.
cat
((
query_rot
,
query_pass
),
dim
=-
1
).
reshape
(
query_shape
)
key
=
key
.
view
(
num_tokens
,
-
1
,
self
.
head_size
)
key_rot
=
key
[...,
:
self
.
rotary_dim
]
key_pass
=
key
[...,
self
.
rotary_dim
:]
key_rot
=
apply_rotary_emb_dispatch
(
key_rot
,
cos
,
sin
,
self
.
is_neox_style
)
key_rot
=
self
.
apply_rotary_emb
(
key_rot
,
cos
,
sin
,
)
key
=
torch
.
cat
((
key_rot
,
key_pass
),
dim
=-
1
).
reshape
(
key_shape
)
return
query
,
key
...
...
vllm/model_executor/layers/rotary_embedding/xdrope.py
View file @
a3f8d5dd
...
...
@@ -4,7 +4,6 @@
import
numpy
as
np
import
torch
from
.common
import
apply_rotary_emb_dispatch
from
.dynamic_ntk_alpha_rope
import
DynamicNTKAlphaRotaryEmbedding
...
...
@@ -36,7 +35,7 @@ class XDRotaryEmbedding(DynamicNTKAlphaRotaryEmbedding):
dtype
,
)
def
forward
(
def
forward
_native
(
self
,
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
...
...
@@ -68,14 +67,73 @@ class XDRotaryEmbedding(DynamicNTKAlphaRotaryEmbedding):
query
=
query
.
view
(
num_tokens
,
-
1
,
self
.
head_size
)
query_rot
=
query
[...,
:
self
.
rotary_dim
]
query_pass
=
query
[...,
self
.
rotary_dim
:]
query_rot
=
apply_rotary_emb_dispatch
(
query_rot
,
cos
,
sin
,
self
.
is_neox_style
)
query_rot
=
self
.
apply_rotary_emb
.
forward_native
(
query_rot
,
cos
,
sin
,
)
query
=
torch
.
cat
((
query_rot
,
query_pass
),
dim
=-
1
).
reshape
(
query_shape
)
key_shape
=
key
.
shape
key
=
key
.
view
(
num_tokens
,
-
1
,
self
.
head_size
)
key_rot
=
key
[...,
:
self
.
rotary_dim
]
key_pass
=
key
[...,
self
.
rotary_dim
:]
key_rot
=
self
.
apply_rotary_emb
.
forward_native
(
key_rot
,
cos
,
sin
,
)
key
=
torch
.
cat
((
key_rot
,
key_pass
),
dim
=-
1
).
reshape
(
key_shape
)
return
query
,
key
def
forward_cuda
(
self
,
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
|
None
=
None
,
offsets
:
torch
.
Tensor
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
|
None
]:
"""PyTorch-native implementation equivalent to forward().
Args:
positions:
[4, num_tokens] (P/W/H/T positions with multimodal inputs)
query: [num_tokens, num_heads * head_size]
key: [num_tokens, num_kv_heads * head_size]
"""
assert
positions
.
ndim
==
2
assert
key
is
not
None
num_tokens
=
positions
.
shape
[
-
1
]
cos_sin
=
self
.
cos_sin_cache
[
positions
]
cos
,
sin
=
cos_sin
.
chunk
(
2
,
dim
=-
1
)
cos
=
torch
.
cat
(
[
m
[
i
]
for
i
,
m
in
enumerate
(
cos
.
split
(
self
.
xdrope_section
,
dim
=-
1
))],
dim
=-
1
)
sin
=
torch
.
cat
(
[
m
[
i
]
for
i
,
m
in
enumerate
(
sin
.
split
(
self
.
xdrope_section
,
dim
=-
1
))],
dim
=-
1
)
query_shape
=
query
.
shape
query
=
query
.
view
(
num_tokens
,
-
1
,
self
.
head_size
)
query_rot
=
query
[...,
:
self
.
rotary_dim
]
query_pass
=
query
[...,
self
.
rotary_dim
:]
query_rot
=
self
.
apply_rotary_emb
(
query_rot
,
cos
,
sin
,
)
query
=
torch
.
cat
((
query_rot
,
query_pass
),
dim
=-
1
).
reshape
(
query_shape
)
key_shape
=
key
.
shape
key
=
key
.
view
(
num_tokens
,
-
1
,
self
.
head_size
)
key_rot
=
key
[...,
:
self
.
rotary_dim
]
key_pass
=
key
[...,
self
.
rotary_dim
:]
key_rot
=
apply_rotary_emb_dispatch
(
key_rot
,
cos
,
sin
,
self
.
is_neox_style
)
key_rot
=
self
.
apply_rotary_emb
(
key_rot
,
cos
,
sin
,
)
key
=
torch
.
cat
((
key_rot
,
key_pass
),
dim
=-
1
).
reshape
(
key_shape
)
return
query
,
key
...
...
vllm/model_executor/models/adapters.py
View file @
a3f8d5dd
...
...
@@ -337,6 +337,18 @@ def as_seq_cls_model(cls: _T) -> _T:
tokens
=
getattr
(
text_config
,
"classifier_from_token"
,
None
)
method
=
getattr
(
text_config
,
"method"
,
None
)
def auto_set_score_bias(weights):
    """Filter ``score.bias`` out of a weight stream and install it.

    Generator over ``(name, weight)`` pairs. The ``"score.bias"`` entry is
    moved to the device and dtype of ``self.score.weight``, stored as
    ``self.score.bias`` and ``skip_bias_add`` is switched off; it is not
    yielded. Every other pair is passed through untouched. Because this is
    a generator, the bias is only installed while the stream is consumed.
    """
    for param_name, tensor in weights:
        if name_is_other := (param_name != "score.bias"):
            yield param_name, tensor
            continue

        score_weight = self.score.weight
        target_device = score_weight.device
        target_dtype = self.score.weight.dtype
        moved = tensor.to(target_device).to(target_dtype)
        self.score.bias = torch.nn.Parameter(moved)
        self.score.skip_bias_add = False
weights
=
auto_set_score_bias
(
weights
)
if
tokens
is
None
and
method
is
None
:
return
super
().
load_weights
(
weights
)
else
:
...
...
vllm/model_executor/models/afmoe.py
View file @
a3f8d5dd
...
...
@@ -241,9 +241,8 @@ class AfmoeAttention(nn.Module):
if
self
.
is_local_attention
:
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
rotary_dim
=
self
.
head_dim
,
max_position
=
max_position_embeddings
,
rope_parameters
=
config
[
"
rope_parameters
"
]
,
rope_parameters
=
config
.
rope_parameters
,
is_neox_style
=
True
,
)
else
:
...
...
vllm/model_executor/models/apertus.py
View file @
a3f8d5dd
...
...
@@ -226,7 +226,6 @@ class ApertusAttention(nn.Module):
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
rotary_dim
=
self
.
head_dim
,
max_position
=
self
.
max_position_embeddings
,
rope_parameters
=
config
.
rope_parameters
,
is_neox_style
=
is_neox_style
,
...
...
Prev
1
…
11
12
13
14
15
16
17
18
19
…
25
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment