Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
38d80967
Commit
38d80967
authored
Sep 12, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.10.2rc2' into v0.10.2rc2-ori
parents
33650733
880c741b
Changes
560
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
589 additions
and
232 deletions
+589
-232
vllm/model_executor/layers/fused_moe/prepare_finalize.py
vllm/model_executor/layers/fused_moe/prepare_finalize.py
+1
-3
vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
+2
-3
vllm/model_executor/layers/fused_moe/routing_simulator.py
vllm/model_executor/layers/fused_moe/routing_simulator.py
+5
-3
vllm/model_executor/layers/layernorm.py
vllm/model_executor/layers/layernorm.py
+130
-16
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+84
-134
vllm/model_executor/layers/logits_processor.py
vllm/model_executor/layers/logits_processor.py
+3
-2
vllm/model_executor/layers/mamba/linear_attn.py
vllm/model_executor/layers/mamba/linear_attn.py
+1
-11
vllm/model_executor/layers/mamba/mamba_mixer2.py
vllm/model_executor/layers/mamba/mamba_mixer2.py
+3
-0
vllm/model_executor/layers/mamba/mamba_utils.py
vllm/model_executor/layers/mamba/mamba_utils.py
+49
-6
vllm/model_executor/layers/mamba/ops/causal_conv1d.py
vllm/model_executor/layers/mamba/ops/causal_conv1d.py
+45
-8
vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py
vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py
+3
-0
vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py
vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py
+1
-1
vllm/model_executor/layers/mamba/ops/ssd_combined.py
vllm/model_executor/layers/mamba/ops/ssd_combined.py
+6
-3
vllm/model_executor/layers/mamba/ops/ssd_state_passing.py
vllm/model_executor/layers/mamba/ops/ssd_state_passing.py
+62
-20
vllm/model_executor/layers/mla.py
vllm/model_executor/layers/mla.py
+158
-0
vllm/model_executor/layers/pooler.py
vllm/model_executor/layers/pooler.py
+31
-15
vllm/model_executor/layers/quantization/__init__.py
vllm/model_executor/layers/quantization/__init__.py
+0
-3
vllm/model_executor/layers/quantization/auto_round.py
vllm/model_executor/layers/quantization/auto_round.py
+2
-1
vllm/model_executor/layers/quantization/awq_marlin.py
vllm/model_executor/layers/quantization/awq_marlin.py
+2
-2
vllm/model_executor/layers/quantization/awq_triton.py
vllm/model_executor/layers/quantization/awq_triton.py
+1
-1
No files found.
Too many changes to show.
To preserve performance only
560 of 560+
files are displayed.
Plain diff
Email patch
vllm/model_executor/layers/fused_moe/prepare_finalize.py
View file @
38d80967
...
...
@@ -38,9 +38,7 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
expert_map
:
Optional
[
torch
.
Tensor
],
apply_router_weight_on_input
:
bool
,
quant_config
:
FusedMoEQuantConfig
,
)
->
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
],
Optional
[
mk
.
ExpertTokensMetadata
],
Optional
[
torch
.
Tensor
],
Optional
[
torch
.
Tensor
]]:
)
->
mk
.
PrepareResultType
:
if
apply_router_weight_on_input
:
topk
=
topk_ids
.
size
(
1
)
...
...
vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
View file @
38d80967
...
...
@@ -420,9 +420,8 @@ def shuffle_weights(
Args:
*tensors: Variable number of torch.Tensor objects.
layout: A pair of integers specifying the
block sizes used to divide the tensors during shuffling.
Default is (16, 16).
layout: A pair of integers specifying the block sizes used to divide
the tensors during shuffling. Default is (16, 16).
Returns:
A Tuple of shuffled tensors.
...
...
vllm/model_executor/layers/fused_moe/routing_simulator.py
View file @
38d80967
...
...
@@ -10,7 +10,7 @@ like uniform random routing.
"""
from
abc
import
ABC
,
abstractmethod
from
typing
import
Optional
from
typing
import
Any
,
Optional
import
torch
...
...
@@ -50,7 +50,9 @@ class DistributionBasedRouting(RoutingStrategy):
distributions for testing different routing patterns.
"""
def
__init__
(
self
,
distribution
:
str
=
"uniform"
,
**
distribution_params
):
def
__init__
(
self
,
distribution
:
str
=
"uniform"
,
**
distribution_params
:
Any
):
"""
Initialize distribution-based routing.
...
...
@@ -244,7 +246,7 @@ class RoutingSimulator:
cls
.
_routing_strategies
[
name
]
=
strategy
@
classmethod
def
get_available_strategies
(
cls
):
def
get_available_strategies
(
cls
)
->
list
[
str
]
:
"""
Get list of available routing strategy names.
...
...
vllm/model_executor/layers/layernorm.py
View file @
38d80967
...
...
@@ -9,11 +9,11 @@ import torch.nn as nn
import
vllm.envs
as
envs
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.platforms
import
current_platform
from
vllm.utils
import
direct_register_custom_op
def
is_rocm_aiter_rmsnorm_enabled
()
->
bool
:
return
current_platform
.
is_rocm
()
\
and
envs
.
VLLM_ROCM_USE_AITER_RMSNORM
\
return
envs
.
VLLM_ROCM_USE_AITER_RMSNORM
\
and
envs
.
VLLM_ROCM_USE_AITER
...
...
@@ -43,7 +43,21 @@ def fused_add_rms_norm(
return
x
,
residual
def
rocm_aiter_rms_norm
(
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
def
poly_norm
(
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
,
variance_epsilon
:
float
)
->
torch
.
Tensor
:
from
vllm
import
_custom_ops
as
ops
out
=
torch
.
empty_like
(
x
)
ops
.
poly_norm
(
out
,
x
,
weight
,
bias
,
variance_epsilon
,
)
return
out
def
rocm_aiter_rms_norm_impl
(
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
variance_epsilon
:
float
)
->
torch
.
Tensor
:
import
aiter
as
rocm_aiter
if
x
.
dim
()
>
2
:
...
...
@@ -55,7 +69,7 @@ def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor,
return
rocm_aiter
.
rms_norm
(
x
,
weight
,
variance_epsilon
)
def
rocm_aiter_
fused_add_rms_norm
(
def
rocm_aiter_
rmsnorm2d_fwd_with_add_impl
(
x
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
variance_epsilon
:
float
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
...
...
@@ -74,14 +88,48 @@ def rocm_aiter_fused_add_rms_norm(
return
output
,
residual_out
def
dispatch_cuda_rmsnorm_func
(
add_residual
:
bool
):
if
add_residual
:
if
is_rocm_aiter_rmsnorm_enabled
():
return
rocm_aiter_fused_add_rms_norm
return
fused_add_rms_norm
def
rocm_aiter_rms_norm_fake
(
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
variance_epsilon
:
float
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
x
)
def
rocm_aiter_rmsnorm2d_fwd_with_add_fake
(
x
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
variance_epsilon
:
float
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
return
torch
.
empty_like
(
x
),
torch
.
empty_like
(
residual
)
if
current_platform
.
is_rocm
():
direct_register_custom_op
(
op_name
=
"rocm_aiter_rms_norm"
,
op_func
=
rocm_aiter_rms_norm_impl
,
mutates_args
=
[],
fake_impl
=
rocm_aiter_rms_norm_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
)
direct_register_custom_op
(
op_name
=
"rocm_aiter_rmsnorm2d_fwd_with_add"
,
op_func
=
rocm_aiter_rmsnorm2d_fwd_with_add_impl
,
mutates_args
=
[],
fake_impl
=
rocm_aiter_rmsnorm2d_fwd_with_add_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
)
if
is_rocm_aiter_rmsnorm_enabled
():
return
rocm_aiter_rms_norm
def
dispatch_rocm_rmsnorm_func
(
with_fused_add
:
bool
,
dtype
:
torch
.
dtype
):
use_aiter
=
is_rocm_aiter_rmsnorm_enabled
()
and
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
]
if
use_aiter
and
with_fused_add
:
return
torch
.
ops
.
vllm
.
rocm_aiter_rmsnorm2d_fwd_with_add
if
use_aiter
:
return
torch
.
ops
.
vllm
.
rocm_aiter_rms_norm
# fall back to CUDA implementation
if
with_fused_add
:
return
fused_add_rms_norm
return
rms_norm
...
...
@@ -114,6 +162,13 @@ class RMSNorm(CustomOp):
self
.
weight
=
torch
.
ones
(
hidden_size
)
if
self
.
has_weight
:
self
.
weight
=
nn
.
Parameter
(
self
.
weight
)
weight_dtype
=
self
.
weight
.
data
.
dtype
if
current_platform
.
is_rocm
():
self
.
rocm_norm_func
=
dispatch_rocm_rmsnorm_func
(
with_fused_add
=
False
,
dtype
=
weight_dtype
)
self
.
rocm_norm_func_with_add
=
dispatch_rocm_rmsnorm_func
(
with_fused_add
=
True
,
dtype
=
weight_dtype
)
def
forward_native
(
self
,
...
...
@@ -162,13 +217,27 @@ class RMSNorm(CustomOp):
return
self
.
forward_native
(
x
,
residual
)
add_residual
=
residual
is
not
None
norm_func
=
dispatch_cuda_rmsnorm_func
(
add_residual
)
if
add_residual
:
return
fused_add_rms_norm
(
x
,
residual
,
self
.
weight
.
data
,
self
.
variance_epsilon
)
else
:
return
rms_norm
(
x
,
self
.
weight
.
data
,
self
.
variance_epsilon
)
def
forward_hip
(
self
,
x
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]:
if
self
.
variance_size_override
is
not
None
:
return
self
.
forward_native
(
x
,
residual
)
add_residual
=
residual
is
not
None
if
add_residual
:
return
norm_func
(
x
,
residual
,
self
.
weight
.
data
,
return
self
.
rocm_norm_func_with_add
(
x
,
residual
,
self
.
weight
.
data
,
self
.
variance_epsilon
)
else
:
return
norm_func
(
x
,
self
.
weight
.
data
,
self
.
variance_epsilon
)
return
self
.
rocm_norm_func
(
x
,
self
.
weight
.
data
,
self
.
variance_epsilon
)
def
forward_xpu
(
self
,
...
...
@@ -265,3 +334,48 @@ class GemmaRMSNorm(CustomOp):
self
.
forward_static
)
self
.
_is_compiled
=
True
return
self
.
forward_native
(
x
,
residual
)
@
CustomOp
.
register
(
"poly_norm"
)
class
PolyNorm
(
CustomOp
):
"""Polynomial normalization.
Computes x -> w_0 * RMSNorm(x^3) + w_1 * RMSNorm(x^2) + w_2 * RMSNorm(x) + b
where w_n is the learned weight and b is the bias.
Refer to https://arxiv.org/html/2411.03884v1
"""
def
__init__
(
self
,
eps
:
float
=
1e-6
,
)
->
None
:
super
().
__init__
()
self
.
weight
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
3
)
/
3
)
self
.
bias
=
torch
.
nn
.
Parameter
(
torch
.
zeros
(
1
))
self
.
variance_epsilon
=
eps
def
_norm
(
self
,
x
):
return
x
/
torch
.
sqrt
(
x
.
pow
(
2
).
mean
(
-
1
,
keepdim
=
True
)
+
self
.
variance_epsilon
)
def
forward_native
(
self
,
x
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
"""PyTorch-native implementation equivalent to forward().
Refer to https://github.com/BryceZhuo/PolyCom?tab=readme-ov-file/README.md
"""
orig_dtype
=
x
.
dtype
x_float
=
x
.
to
(
torch
.
float32
)
output
=
(
self
.
weight
[
0
]
*
self
.
_norm
(
x_float
**
3
)
+
self
.
weight
[
1
]
*
self
.
_norm
(
x_float
**
2
)
+
self
.
weight
[
2
]
*
self
.
_norm
(
x_float
)
+
self
.
bias
)
return
output
.
to
(
orig_dtype
)
def
forward_cuda
(
self
,
x
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
return
poly_norm
(
x
,
self
.
weight
,
self
.
bias
,
self
.
variance_epsilon
)
vllm/model_executor/layers/linear.py
View file @
38d80967
...
...
@@ -9,7 +9,6 @@ import torch
import
torch.nn
as
nn
from
torch.nn.parameter
import
Parameter
,
UninitializedParameter
from
vllm
import
envs
from
vllm.distributed
import
(
divide
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
split_tensor_along_last_dim
,
...
...
@@ -200,26 +199,10 @@ class UnquantizedLinearMethod(LinearMethodBase):
set_weight_attrs
(
weight
,
extra_weight_attrs
)
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
# special postprocessing for CPU SGL
if
current_platform
.
is_cpu
()
and
envs
.
VLLM_CPU_SGL_KERNEL
:
from
vllm.model_executor.layers.utils
import
check_cpu_sgl_kernel
N
,
K
=
layer
.
weight
.
size
()
dtype
=
layer
.
weight
.
dtype
if
check_cpu_sgl_kernel
(
N
,
K
,
dtype
):
packed_weight
=
torch
.
ops
.
_C
.
convert_weight_packed
(
layer
.
weight
)
assert
packed_weight
.
size
()
==
layer
.
weight
.
size
()
layer
.
weight
.
copy_
(
packed_weight
)
if
layer
.
bias
is
not
None
:
layer
.
bias
=
Parameter
(
layer
.
bias
.
to
(
torch
.
float32
),
requires_grad
=
False
)
layer
.
use_cpu_sgl
=
True
else
:
logger
.
warning
(
"CPU SGL kernels require Intel AMX support,"
" bf16/fp16/int8 weight, IC and OC are divisible by "
"32 and 16."
)
layer
.
use_cpu_sgl
=
False
if
current_platform
.
is_cpu
():
from
vllm.model_executor.layers.utils
import
(
dispatch_cpu_unquantized_gemm
)
dispatch_cpu_unquantized_gemm
(
layer
,
remove_weight
=
True
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
...
...
@@ -240,6 +223,7 @@ class LinearBase(CustomOp):
quant_config: Quantization configure.
prefix: Prefix for parameter names.
return_bias: If true, return bias together with outputs in forward pass.
disable_tp: If true, tensor parallelism will be disabled for this layer.
"""
def
__init__
(
...
...
@@ -252,6 +236,7 @@ class LinearBase(CustomOp):
prefix
:
str
=
""
,
*
,
return_bias
:
bool
=
True
,
disable_tp
:
bool
=
False
,
):
super
().
__init__
()
...
...
@@ -271,6 +256,17 @@ class LinearBase(CustomOp):
self
.
quant_method
=
quant_config
.
get_quant_method
(
self
,
prefix
=
prefix
)
self
.
return_bias
=
return_bias
self
.
disable_tp
=
disable_tp
self
.
tp_rank
=
(
get_tensor_model_parallel_rank
()
if
not
disable_tp
else
0
)
self
.
tp_size
=
(
get_tensor_model_parallel_world_size
()
if
not
disable_tp
else
1
)
def
update_param_tp_status
(
self
):
for
param
in
self
.
parameters
():
if
isinstance
(
param
,
BasevLLMParameter
):
param
.
tp_rank
=
self
.
tp_rank
param
.
tp_size
=
self
.
tp_size
@
CustomOp
.
register
(
"replicated_linear"
)
...
...
@@ -287,6 +283,7 @@ class ReplicatedLinear(LinearBase):
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
return_bias: If true, return bias together with outputs in forward pass.
disable_tp: Take no effect for replicated linear layers.
"""
def
__init__
(
...
...
@@ -300,26 +297,21 @@ class ReplicatedLinear(LinearBase):
prefix
:
str
=
""
,
*
,
return_bias
:
bool
=
True
,
disable_tp
:
bool
=
False
,
):
# If MergedReplicatedLinear, use output size of each partition.
if
hasattr
(
self
,
"output_sizes"
):
self
.
output_partition_sizes
=
self
.
output_sizes
else
:
self
.
output_partition_sizes
=
[
output_size
]
super
().
__init__
(
input_size
,
output_size
,
skip_bias_add
,
params_dtype
,
quant_config
,
prefix
=
prefix
,
return_bias
=
return_bias
)
return_bias
=
return_bias
,
disable_tp
=
disable_tp
)
# All the linear layer supports quant method.
assert
self
.
quant_method
is
not
None
self
.
quant_method
.
create_weights
(
self
,
self
.
input_size
,
self
.
output_partition_sizes
,
self
.
input_size
,
[
self
.
output_size
],
self
.
input_size
,
self
.
output_size
,
self
.
params_dtype
,
...
...
@@ -375,74 +367,6 @@ class ReplicatedLinear(LinearBase):
return
s
class
MergedReplicatedLinear
(
ReplicatedLinear
):
"""Replicated linear layer.
Args:
input_size: input dimension of the linear layer.
output_sizes: list of output dimensions of the linear layer.
bias: If true, add bias.
skip_bias_add: If true, skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
return_bias: If true, return bias together with outputs in forward pass.
"""
def
__init__
(
self
,
input_size
:
int
,
output_sizes
:
list
[
int
],
bias
:
bool
=
True
,
skip_bias_add
:
bool
=
False
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
*
,
return_bias
:
bool
=
True
,
):
self
.
output_sizes
=
output_sizes
super
().
__init__
(
input_size
,
sum
(
output_sizes
),
bias
,
skip_bias_add
,
params_dtype
,
quant_config
,
prefix
=
prefix
,
return_bias
=
return_bias
)
def
weight_loader
(
self
,
param
:
Union
[
Parameter
,
BasevLLMParameter
],
loaded_weight
:
torch
.
Tensor
,
loaded_shard_id
:
Optional
[
int
]
=
None
):
assert
loaded_shard_id
is
not
None
assert
loaded_shard_id
<
len
(
self
.
output_sizes
)
if
isinstance
(
param
,
BlockQuantScaleParameter
):
from
vllm.model_executor.layers.quantization.fp8
import
(
Fp8LinearMethod
,
Fp8MoEMethod
)
assert
self
.
quant_method
is
not
None
assert
isinstance
(
self
.
quant_method
,
(
Fp8LinearMethod
,
Fp8MoEMethod
))
weight_block_size
=
self
.
quant_method
.
quant_config
.
weight_block_size
assert
weight_block_size
is
not
None
block_n
,
_
=
weight_block_size
[
0
],
weight_block_size
[
1
]
shard_offset
=
(
(
sum
(
self
.
output_sizes
[:
loaded_shard_id
])
+
block_n
-
1
)
//
block_n
)
shard_size
=
((
self
.
output_sizes
[
loaded_shard_id
]
+
block_n
-
1
)
//
block_n
)
elif
isinstance
(
param
,
PerTensorScaleParameter
):
shard_offset
=
loaded_shard_id
shard_size
=
1
else
:
shard_offset
=
sum
(
self
.
output_sizes
[:
loaded_shard_id
])
shard_size
=
self
.
output_sizes
[
loaded_shard_id
]
param
.
data
[
shard_offset
:
shard_offset
+
shard_size
]
=
loaded_weight
@
CustomOp
.
register
(
"column_parallel_linear"
)
class
ColumnParallelLinear
(
LinearBase
):
"""Linear layer with column parallelism.
...
...
@@ -466,6 +390,8 @@ class ColumnParallelLinear(LinearBase):
the list would be size 3.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
return_bias: If true, return bias together with outputs in forward pass.
disable_tp: If true, weights matrix won't be sharded through tp rank.
"""
def
__init__
(
...
...
@@ -481,9 +407,13 @@ class ColumnParallelLinear(LinearBase):
prefix
:
str
=
""
,
*
,
return_bias
:
bool
=
True
,
disable_tp
:
bool
=
False
,
):
# Divide the weight matrix along the last dimension.
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_rank
=
(
get_tensor_model_parallel_rank
()
if
not
disable_tp
else
0
)
self
.
tp_size
=
(
get_tensor_model_parallel_world_size
()
if
not
disable_tp
else
1
)
self
.
input_size_per_partition
=
input_size
self
.
output_size_per_partition
=
divide
(
output_size
,
self
.
tp_size
)
self
.
output_partition_sizes
=
[
self
.
output_size_per_partition
]
...
...
@@ -500,7 +430,8 @@ class ColumnParallelLinear(LinearBase):
params_dtype
,
quant_config
,
prefix
,
return_bias
=
return_bias
)
return_bias
=
return_bias
,
disable_tp
=
disable_tp
)
self
.
gather_output
=
gather_output
...
...
@@ -528,8 +459,7 @@ class ColumnParallelLinear(LinearBase):
})
else
:
self
.
register_parameter
(
"bias"
,
None
)
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
update_param_tp_status
()
def
weight_loader
(
self
,
param
:
Parameter
,
loaded_weight
:
torch
.
Tensor
):
...
...
@@ -571,7 +501,8 @@ class ColumnParallelLinear(LinearBase):
assert
param_data
.
shape
==
loaded_weight
.
shape
param_data
.
copy_
(
loaded_weight
)
def
weight_loader_v2
(
self
,
param
:
Parameter
,
loaded_weight
:
torch
.
Tensor
):
def
weight_loader_v2
(
self
,
param
:
BasevLLMParameter
,
loaded_weight
:
torch
.
Tensor
):
# Special case for loading scales off disk, which often do not
# have a shape (such as in the case of AutoFP8).
if
len
(
loaded_weight
.
shape
)
==
0
:
...
...
@@ -587,7 +518,7 @@ class ColumnParallelLinear(LinearBase):
# Matrix multiply.
assert
self
.
quant_method
is
not
None
output_parallel
=
self
.
quant_method
.
apply
(
self
,
input_
,
bias
)
if
self
.
gather_output
:
if
self
.
gather_output
and
self
.
tp_size
>
1
:
# All-gather across the partitions.
output
=
tensor_model_parallel_all_gather
(
output_parallel
)
else
:
...
...
@@ -601,7 +532,7 @@ class ColumnParallelLinear(LinearBase):
s
=
f
"in_features=
{
self
.
input_size
}
"
s
+=
f
", output_features=
{
self
.
output_size_per_partition
}
"
s
+=
f
", bias=
{
self
.
bias
is
not
None
}
"
s
+=
f
", tp_size=
{
get_tensor_model_parallel_world
_size
()
}
"
s
+=
f
", tp_size=
{
self
.
tp
_size
}
"
s
+=
f
", gather_output=
{
self
.
gather_output
}
"
return
s
...
...
@@ -628,6 +559,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
return_bias: If true, return bias together with outputs in forward pass.
disable_tp: If true, all weights matrix won't be sharded, this layer
will be treated as a "Replicated" MergedLinear.
"""
def
__init__
(
...
...
@@ -642,10 +575,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
prefix
:
str
=
""
,
*
,
return_bias
:
bool
=
True
,
disable_tp
:
bool
=
False
,
):
self
.
output_sizes
=
output_sizes
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
tp_size
=
(
get_tensor_model_parallel_world_size
()
if
not
disable_tp
else
1
)
self
.
tp_rank
=
(
get_tensor_model_parallel_rank
()
if
not
disable_tp
else
0
)
assert
all
(
output_size
%
self
.
tp_size
==
0
for
output_size
in
output_sizes
)
...
...
@@ -657,7 +593,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
params_dtype
=
params_dtype
,
quant_config
=
quant_config
,
prefix
=
prefix
,
return_bias
=
return_bias
)
return_bias
=
return_bias
,
disable_tp
=
disable_tp
)
def
weight_loader
(
self
,
param
:
Parameter
,
...
...
@@ -722,8 +659,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
# If quantized, we need to adjust the offset and size to account
# for the packing.
if
packed_dim
==
output_dim
:
shard_size
=
shard_size
//
param
.
pack_factor
shard_offset
=
shard_offset
//
param
.
pack_factor
shard_size
=
shard_size
//
param
.
pack
ed
_factor
shard_offset
=
shard_offset
//
param
.
pack
ed
_factor
# Special case for Marlin.
shard_size
,
shard_offset
=
adjust_marlin_shard
(
param
,
shard_size
,
shard_offset
)
...
...
@@ -756,8 +693,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
# for the packing.
packed_dim
=
getattr
(
param
,
"packed_dim"
,
None
)
if
packed_dim
==
output_dim
:
shard_size
=
shard_size
//
param
.
pack_factor
shard_offset
=
shard_offset
//
param
.
pack_factor
shard_size
=
shard_size
//
param
.
pack
ed
_factor
shard_offset
=
shard_offset
//
param
.
pack
ed
_factor
# Special case for Marlin.
shard_size
,
shard_offset
=
adjust_marlin_shard
(
param
,
shard_size
,
shard_offset
)
...
...
@@ -849,8 +786,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
assert
loaded_shard_id
<
len
(
self
.
output_sizes
)
tp_size
=
get_tensor_model_parallel_world_size
()
if
isinstance
(
param
,
BlockQuantScaleParameter
):
from
vllm.model_executor.layers.quantization.fp8
import
(
Fp8LinearMethod
,
Fp8MoEMethod
)
...
...
@@ -862,17 +797,19 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
block_n
,
_
=
weight_block_size
[
0
],
weight_block_size
[
1
]
shard_offset
=
(
(
sum
(
self
.
output_sizes
[:
loaded_shard_id
])
+
block_n
-
1
)
//
block_n
)
//
tp_size
block_n
)
//
self
.
tp_size
shard_size
=
((
self
.
output_sizes
[
loaded_shard_id
]
+
block_n
-
1
)
//
block_n
//
tp_size
)
block_n
//
self
.
tp_size
)
else
:
shard_offset
=
sum
(
self
.
output_sizes
[:
loaded_shard_id
])
//
tp_size
shard_size
=
self
.
output_sizes
[
loaded_shard_id
]
//
tp_size
shard_offset
=
sum
(
self
.
output_sizes
[:
loaded_shard_id
])
//
self
.
tp_size
shard_size
=
self
.
output_sizes
[
loaded_shard_id
]
//
self
.
tp_size
param
.
load_merged_column_weight
(
loaded_weight
=
loaded_weight
,
shard_id
=
loaded_shard_id
,
shard_offset
=
shard_offset
,
shard_size
=
shard_size
)
shard_size
=
shard_size
,
tp_rank
=
self
.
tp_rank
)
class
QKVParallelLinear
(
ColumnParallelLinear
):
...
...
@@ -900,6 +837,7 @@ class QKVParallelLinear(ColumnParallelLinear):
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
return_bias: If true, return bias together with outputs in forward pass.
disable_tp: If true, weights matrix won't be sharded through tp rank.
"""
def
__init__
(
...
...
@@ -915,6 +853,7 @@ class QKVParallelLinear(ColumnParallelLinear):
prefix
:
str
=
""
,
*
,
return_bias
:
bool
=
True
,
disable_tp
:
bool
=
False
,
):
self
.
hidden_size
=
hidden_size
self
.
head_size
=
head_size
...
...
@@ -923,7 +862,8 @@ class QKVParallelLinear(ColumnParallelLinear):
total_num_kv_heads
=
total_num_heads
self
.
total_num_kv_heads
=
total_num_kv_heads
# Divide the weight matrix along the last dimension.
tp_size
=
get_tensor_model_parallel_world_size
()
tp_size
=
(
get_tensor_model_parallel_world_size
()
if
not
disable_tp
else
1
)
self
.
num_heads
=
divide
(
self
.
total_num_heads
,
tp_size
)
if
tp_size
>=
self
.
total_num_kv_heads
:
self
.
num_kv_heads
=
1
...
...
@@ -949,7 +889,8 @@ class QKVParallelLinear(ColumnParallelLinear):
params_dtype
=
params_dtype
,
quant_config
=
quant_config
,
prefix
=
prefix
,
return_bias
=
return_bias
)
return_bias
=
return_bias
,
disable_tp
=
disable_tp
)
def
_get_shard_offset_mapping
(
self
,
loaded_shard_id
:
str
):
shard_offset_mapping
=
{
...
...
@@ -1010,10 +951,13 @@ class QKVParallelLinear(ColumnParallelLinear):
loaded_shard_id
:
Optional
[
str
]
=
None
):
if
loaded_shard_id
is
None
:
# special case for certain models
if
isinstance
(
param
,
PerTensorScaleParameter
):
param
.
load_qkv_weight
(
loaded_weight
=
loaded_weight
,
shard_id
=
0
)
param
.
load_qkv_weight
(
loaded_weight
=
loaded_weight
,
shard_id
=
0
,
tp_rank
=
self
.
tp_rank
)
return
elif
type
(
param
)
in
(
RowvLLMParameter
,
BasevLLMParameter
):
param
.
load_qkv_weight
(
loaded_weight
=
loaded_weight
)
param
.
load_qkv_weight
(
loaded_weight
=
loaded_weight
,
tp_rank
=
self
.
tp_rank
)
return
# TODO: @dsikka - move to parameter.py
self
.
_load_fused_module_from_checkpoint
(
param
,
loaded_weight
)
...
...
@@ -1037,7 +981,8 @@ class QKVParallelLinear(ColumnParallelLinear):
num_heads
=
self
.
num_kv_head_replicas
,
shard_id
=
loaded_shard_id
,
shard_offset
=
shard_offset
,
shard_size
=
shard_size
)
shard_size
=
shard_size
,
tp_rank
=
self
.
tp_rank
)
def
weight_loader
(
self
,
param
:
Parameter
,
...
...
@@ -1107,8 +1052,8 @@ class QKVParallelLinear(ColumnParallelLinear):
# If quantized, we need to adjust the offset and size to account
# for the packing.
if
packed_dim
==
output_dim
:
shard_size
=
shard_size
//
param
.
pack_factor
shard_offset
=
shard_offset
//
param
.
pack_factor
shard_size
=
shard_size
//
param
.
pack
ed
_factor
shard_offset
=
shard_offset
//
param
.
pack
ed
_factor
# Special case for Marlin.
shard_size
,
shard_offset
=
adjust_marlin_shard
(
...
...
@@ -1155,8 +1100,8 @@ class QKVParallelLinear(ColumnParallelLinear):
# for the packing.
packed_dim
=
getattr
(
param
,
"packed_dim"
,
None
)
if
packed_dim
==
output_dim
:
shard_size
=
shard_size
//
param
.
pack_factor
shard_offset
=
shard_offset
//
param
.
pack_factor
shard_size
=
shard_size
//
param
.
pack
ed
_factor
shard_offset
=
shard_offset
//
param
.
pack
ed
_factor
# Special case for Marlin.
shard_size
,
shard_offset
=
adjust_marlin_shard
(
...
...
@@ -1243,6 +1188,7 @@ class RowParallelLinear(LinearBase):
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.down_proj)
return_bias: If true, return bias together with outputs in forward pass.
disable_tp: If true, weights matrix won't be sharded through tp rank.
"""
def
__init__
(
...
...
@@ -1258,10 +1204,13 @@ class RowParallelLinear(LinearBase):
prefix
:
str
=
""
,
*
,
return_bias
:
bool
=
True
,
disable_tp
:
bool
=
False
,
):
# Divide the weight matrix along the first dimension.
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_rank
=
(
get_tensor_model_parallel_rank
()
if
not
disable_tp
else
0
)
self
.
tp_size
=
(
get_tensor_model_parallel_world_size
()
if
not
disable_tp
else
1
)
self
.
input_size_per_partition
=
divide
(
input_size
,
self
.
tp_size
)
self
.
output_size_per_partition
=
output_size
self
.
output_partition_sizes
=
[
output_size
]
...
...
@@ -1272,7 +1221,8 @@ class RowParallelLinear(LinearBase):
params_dtype
,
quant_config
,
prefix
,
return_bias
=
return_bias
)
return_bias
=
return_bias
,
disable_tp
=
disable_tp
)
self
.
input_is_parallel
=
input_is_parallel
self
.
reduce_results
=
reduce_results
...
...
@@ -1301,6 +1251,7 @@ class RowParallelLinear(LinearBase):
})
else
:
self
.
register_parameter
(
"bias"
,
None
)
self
.
update_param_tp_status
()
def
weight_loader
(
self
,
param
:
Parameter
,
loaded_weight
:
torch
.
Tensor
):
input_dim
=
getattr
(
param
,
"input_dim"
,
None
)
...
...
@@ -1356,10 +1307,9 @@ class RowParallelLinear(LinearBase):
if
self
.
input_is_parallel
:
input_parallel
=
input_
else
:
tp_rank
=
get_tensor_model_parallel_rank
()
splitted_input
=
split_tensor_along_last_dim
(
input_
,
num_partitions
=
self
.
tp_size
)
input_parallel
=
splitted_input
[
tp_rank
].
contiguous
()
input_parallel
=
splitted_input
[
self
.
tp_rank
].
contiguous
()
# Matrix multiply.
assert
self
.
quant_method
is
not
None
...
...
vllm/model_executor/layers/logits_processor.py
View file @
38d80967
...
...
@@ -6,11 +6,11 @@ from concurrent.futures import ThreadPoolExecutor
from
typing
import
Optional
import
torch
import
torch.nn
as
nn
import
vllm.envs
as
envs
from
vllm.distributed
import
(
tensor_model_parallel_all_gather
,
tensor_model_parallel_gather
)
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
...
...
@@ -22,7 +22,8 @@ if envs.VLLM_LOGITS_PROCESSOR_THREADS is not None:
envs
.
VLLM_LOGITS_PROCESSOR_THREADS
)
class
LogitsProcessor
(
nn
.
Module
):
@
CustomOp
.
register
(
"logits_processor"
)
class
LogitsProcessor
(
CustomOp
):
"""Process logits and apply logits processors from sampling metadata.
This layer does the following:
...
...
vllm/model_executor/layers/mamba/linear_attn.py
View file @
38d80967
...
...
@@ -83,17 +83,7 @@ class MiniMaxText01RMSNormTP(CustomOp):
variance
=
tensor_model_parallel_all_reduce
(
variance
)
/
self
.
tp_world
x
=
x
*
torch
.
rsqrt
(
variance
+
self
.
variance_epsilon
)
weight
=
self
.
weight
if
x
.
size
(
-
1
)
!=
self
.
weight
.
size
(
0
):
if
self
.
weight
.
size
(
0
)
<
x
.
size
(
-
1
):
repeat_count
=
(
x
.
size
(
-
1
)
+
self
.
weight
.
size
(
0
))
//
x
.
size
(
-
1
)
full_weight
=
self
.
weight
.
repeat
(
repeat_count
)
weight
=
full_weight
[:
x
.
size
(
-
1
)]
else
:
weight
=
self
.
weight
[:
x
.
size
(
-
1
)]
x
=
x
.
to
(
orig_dtype
)
*
weight
x
=
x
.
to
(
orig_dtype
)
*
self
.
weight
return
x
def
forward
(
...
...
vllm/model_executor/layers/mamba/mamba_mixer2.py
View file @
38d80967
...
...
@@ -291,6 +291,7 @@ class MambaMixer2(MambaBase, CustomOp):
output_size
=
self
.
conv_dim
,
bias
=
use_conv_bias
,
quant_config
=
None
,
prefix
=
f
"
{
prefix
}
.conv1d"
,
)
# unsqueeze to fit conv1d weights shape into the linear weights shape.
# Can't do this in `weight_loader` since it already exists in
...
...
@@ -303,6 +304,7 @@ class MambaMixer2(MambaBase, CustomOp):
output_size
=
intermediate_size
+
self
.
conv_dim
+
self
.
num_heads
,
bias
=
use_bias
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.in_proj"
,
)
# - because in_proj is a concatenation of 3 weights, we
...
...
@@ -402,6 +404,7 @@ class MambaMixer2(MambaBase, CustomOp):
bias
=
use_bias
,
input_is_parallel
=
True
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.out_proj"
,
)
self
.
norm
=
Mixer2RMSNormGated
(
intermediate_size
,
...
...
vllm/model_executor/layers/mamba/mamba_utils.py
View file @
38d80967
...
...
@@ -30,12 +30,8 @@ class MambaStateDtypeCalculator:
mamba_cache_dtype
:
MambaDType
,
mamba_ssm_cache_dtype
:
MambaDType
,
)
->
tuple
[
torch
.
dtype
,
...]:
# TODO (tdoublep) requires kernel changes
if
mamba_cache_dtype
==
"float32"
or
mamba_ssm_cache_dtype
==
"float32"
:
raise
ValueError
(
"fp32 state for mamba1 is not yet supported"
)
else
:
return
MambaStateDtypeCalculator
.
mamba2_state_dtype
(
model_dtype
,
mamba_cache_dtype
,
mamba_ssm_cache_dtype
)
return
cls
.
_mamba_state_dtype
(
model_dtype
,
mamba_cache_dtype
,
mamba_ssm_cache_dtype
)
@
classmethod
def
mamba2_state_dtype
(
...
...
@@ -43,6 +39,16 @@ class MambaStateDtypeCalculator:
model_dtype
:
Union
[
ModelDType
,
torch
.
dtype
],
mamba_cache_dtype
:
MambaDType
,
mamba_ssm_cache_dtype
:
MambaDType
,
)
->
tuple
[
torch
.
dtype
,
...]:
return
cls
.
_mamba_state_dtype
(
model_dtype
,
mamba_cache_dtype
,
mamba_ssm_cache_dtype
)
@
classmethod
def
_mamba_state_dtype
(
cls
,
model_dtype
:
Union
[
ModelDType
,
torch
.
dtype
],
mamba_cache_dtype
:
MambaDType
,
mamba_ssm_cache_dtype
:
MambaDType
,
)
->
tuple
[
torch
.
dtype
,
...]:
conv_state_dtype
=
get_kv_cache_torch_dtype
(
mamba_cache_dtype
,
model_dtype
)
...
...
@@ -64,6 +70,15 @@ class MambaStateDtypeCalculator:
model_dtype
)
return
(
conv_state_dtype
,
)
@
classmethod
def
gated_delta_net_state_dtype
(
cls
,
model_dtype
:
Union
[
ModelDType
,
torch
.
dtype
],
mamba_cache_dtype
:
MambaDType
,
)
->
tuple
[
torch
.
dtype
,
torch
.
dtype
]:
state_dtype
=
get_kv_cache_torch_dtype
(
mamba_cache_dtype
,
model_dtype
)
return
(
state_dtype
,
state_dtype
)
class
MambaStateShapeCalculator
:
...
...
@@ -157,3 +172,31 @@ class MambaStateShapeCalculator:
# for n_groups == 1, this is exactly tp_size - n_groups
return
tp_size
-
ngroups
@
classmethod
def
gated_delta_net_state_shape
(
cls
,
tp_world_size
:
int
,
num_k_heads
:
int
,
num_v_heads
:
int
,
head_k_dim
:
int
,
head_v_dim
:
int
,
conv_kernel_size
:
int
,
num_spec
:
int
=
0
,
use_v1
:
bool
=
True
,
):
conv_dim
=
(
head_k_dim
*
num_k_heads
*
2
+
head_v_dim
*
num_v_heads
)
conv_state_shape
=
(
divide
(
conv_dim
,
tp_world_size
),
conv_kernel_size
-
1
+
num_spec
,
)
# In V0, the conv_state shape was swapped during allocation in
# MambaCacheManager, but in V1 it needs to be determined here at the
# calculation level
if
use_v1
:
conv_state_shape
=
conv_state_shape
[
1
],
conv_state_shape
[
0
]
temporal_state_shape
=
(
divide
(
num_v_heads
,
tp_world_size
),
head_k_dim
,
head_v_dim
)
return
conv_state_shape
,
temporal_state_shape
vllm/model_executor/layers/mamba/ops/causal_conv1d.py
View file @
38d80967
...
...
@@ -464,7 +464,9 @@ def causal_conv1d_fn(
# 3. mapping from sequence x[idx] to a cache line at index as specified via cache_indices[idx]
# 4. computation can be skipped if cache_indices[idx] == pad_slot_id
num_cache_lines
=
conv_states
.
size
(
0
)
assert
(
num_cache_lines
,
dim
,
width
-
1
)
==
conv_states
.
shape
assert
(
num_cache_lines
==
conv_states
.
shape
[
0
]
and
dim
==
conv_states
.
shape
[
1
]
and
width
-
1
<=
conv_states
.
shape
[
2
])
stride_istate_seq
=
conv_states
.
stride
(
0
)
stride_istate_dim
=
conv_states
.
stride
(
1
)
stride_istate_token
=
conv_states
.
stride
(
2
)
...
...
@@ -623,6 +625,7 @@ def _causal_conv1d_update_kernel(
conv_state_ptr
,
cache_seqlens_ptr
,
# circular buffer
conv_state_indices_ptr
,
num_accepted_tokens_ptr
,
o_ptr
,
# (batch, dim, seqlen)
# Matrix dimensions
batch
:
int
,
...
...
@@ -639,6 +642,7 @@ def _causal_conv1d_update_kernel(
stride_conv_state_seq
:
tl
.
constexpr
,
stride_conv_state_dim
:
tl
.
constexpr
,
stride_conv_state_tok
:
tl
.
constexpr
,
stride_state_indices
:
tl
.
constexpr
,
stride_o_seq
:
tl
.
constexpr
,
stride_o_dim
:
tl
.
constexpr
,
stride_o_token
:
tl
.
constexpr
,
...
...
@@ -649,6 +653,7 @@ def _causal_conv1d_update_kernel(
KERNEL_WIDTH
:
tl
.
constexpr
,
SILU_ACTIVATION
:
tl
.
constexpr
,
IS_CONTINUOUS_BATCHING
:
tl
.
constexpr
,
IS_SPEC_DECODING
:
tl
.
constexpr
,
NP2_STATELEN
:
tl
.
constexpr
,
USE_PAD_SLOT
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
...
...
@@ -663,7 +668,8 @@ def _causal_conv1d_update_kernel(
if
IS_CONTINUOUS_BATCHING
:
# mask = idx_seq < batch
conv_state_batch_coord
=
tl
.
load
(
conv_state_indices_ptr
+
idx_seq
).
to
(
conv_state_batch_coord
=
tl
.
load
(
conv_state_indices_ptr
+
idx_seq
*
stride_state_indices
).
to
(
tl
.
int64
)
else
:
conv_state_batch_coord
=
idx_seq
...
...
@@ -672,13 +678,32 @@ def _causal_conv1d_update_kernel(
# not processing as this is not the actual sequence
return
if
IS_SPEC_DECODING
:
# The rolling of conv state:
#
# Before forward, the conv_state is:
# [history1, history2, ..., historyM].
#
# After forward, the conv_state becomes:
# [history2, ..., historyM, draft1, draft2, ..., draftN].
#
# After acceptance, it becomes:
#
# - accept 1 tokens: [history2, ..., historyM, draft1]
# - accept 2 tokens: [history3, ..., historyM, draft1, draft2]
# - and so on.
conv_state_token_offset
=
(
tl
.
load
(
num_accepted_tokens_ptr
+
idx_seq
)
-
1
)
else
:
conv_state_token_offset
=
0
# STEP 1: READ init_state data
conv_states_base
=
(
conv_state_ptr
+
(
conv_state_batch_coord
*
stride_conv_state_seq
)
+
(
idx_feats
*
stride_conv_state_dim
))
mask_w
=
idx_feats
<
dim
prior_tokens
=
conv_states_base
prior_tokens
=
conv_states_base
+
conv_state_token_offset
*
stride_conv_state_tok
if
KERNEL_WIDTH
>=
2
:
conv_states_ptrs
=
prior_tokens
# [BLOCK_N]
col0
=
tl
.
load
(
conv_states_ptrs
,
mask_w
,
0.0
)
...
...
@@ -695,11 +720,15 @@ def _causal_conv1d_update_kernel(
# STEP 2: assume state_len > seqlen
idx_tokens
=
tl
.
arange
(
0
,
NP2_STATELEN
)
# [BLOCK_M]
# With speculative decoding, the conv_state updates works in a sliding
# window manner, at each forward pass, the tokens are shift by 1, so we
# load since idx_tokens + 1.
conv_state_ptrs_source
=
(
conv_state_ptr
+
(
conv_state_batch_coord
*
stride_conv_state_seq
)
+
conv_state_token_offset
*
stride_conv_state_tok
+
(
idx_feats
*
stride_conv_state_dim
)[
None
,
:]
+
((
idx_tokens
+
seqlen
)
*
stride_conv_state_tok
)[:,
None
]
)
# [BLOCK_M, BLOCK_N]
((
idx_tokens
+
(
1
if
IS_SPEC_DECODING
else
seqlen
)
)
*
stride_conv_state_tok
)[:,
None
]
)
# [BLOCK_M, BLOCK_N]
mask
=
((
conv_state_batch_coord
<
num_cache_lines
)
&
((
idx_tokens
+
seqlen
)
<
state_len
)[:,
None
]
&
(
idx_feats
<
dim
)[
None
,
:])
...
...
@@ -820,6 +849,7 @@ def causal_conv1d_update(
activation
:
Union
[
bool
,
str
,
None
]
=
None
,
cache_seqlens
:
Optional
[
torch
.
Tensor
]
=
None
,
conv_state_indices
:
Optional
[
torch
.
Tensor
]
=
None
,
num_accepted_tokens
:
Optional
[
torch
.
Tensor
]
=
None
,
pad_slot_id
:
int
=
PAD_SLOT_ID
,
metadata
=
None
,
validate_data
=
False
,
...
...
@@ -890,9 +920,13 @@ def causal_conv1d_update(
)
# X (batch, dim, seqlen)
stride_o_seq
,
stride_o_dim
,
stride_o_token
=
out
.
stride
()
stride_istate_seq
,
stride_istate_dim
,
stride_istate_token
=
conv_state
.
stride
(
)
stride_state_indices
=
conv_state_indices
.
stride
(
0
)
if
conv_state_indices
is
not
None
else
0
if
num_accepted_tokens
is
not
None
:
state_len
=
width
-
1
+
(
seqlen
-
1
)
# effective state_len needed
else
:
state_len
=
width
-
1
np2_statelen
=
triton
.
next_power_of_2
(
state_len
)
...
...
@@ -910,6 +944,7 @@ def causal_conv1d_update(
conv_state
,
cache_seqlens
,
conv_state_indices
,
num_accepted_tokens
,
out
,
# Matrix dimensions
batch
,
...
...
@@ -926,6 +961,7 @@ def causal_conv1d_update(
stride_istate_seq
,
stride_istate_dim
,
stride_istate_token
,
stride_state_indices
,
stride_o_seq
,
stride_o_dim
,
stride_o_token
,
...
...
@@ -936,6 +972,7 @@ def causal_conv1d_update(
KERNEL_WIDTH
=
width
,
SILU_ACTIVATION
=
activation
in
[
"silu"
,
"swish"
],
IS_CONTINUOUS_BATCHING
=
conv_state_indices
is
not
None
,
IS_SPEC_DECODING
=
num_accepted_tokens
is
not
None
,
NP2_STATELEN
=
np2_statelen
,
USE_PAD_SLOT
=
pad_slot_id
is
not
None
,
BLOCK_N
=
256
,
...
...
vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py
View file @
38d80967
...
...
@@ -289,6 +289,9 @@ def _chunk_scan_fwd_kernel(
# get the cs at the offset boundary
# - c_off == 0 is a passthrough
# - We need dA_cs at the boundary, defined by c_off - no need
# to increase pointer by pid_m (it is a constant offset,
# i.e. the same for all blocks)
dA_cs_m_boundary
=
tl
.
load
(
dA_cumsum_ptr
+
(
c_off
-
1
)
*
stride_dA_cs_csize
,
mask
=
(((
c_off
-
1
)
>
-
1
)
and
((
c_off
)
<
chunk_size
)),
...
...
vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py
View file @
38d80967
...
...
@@ -502,7 +502,7 @@ def _chunk_state_varlen_kernel(
dA_cumsum_ptrs
+=
BLOCK_SIZE_K
*
stride_dA_cs_csize
# If the sequence starts after the last chunk idx, we don't need to add the contribution from the last chunk
# If HAS_INITSTATES==True need to consider two possib
l
ties
# If HAS_INITSTATES==True need to consider two possib
ili
ties
# - if start_idx < pid_c * chunk_size, then we need to take the past_states_ptrs
# - if state_idx >= pid * chunk_size, then we need to insert initstates
if
((
start_idx
<
pid_c
*
chunk_size
)
# first chunk
...
...
vllm/model_executor/layers/mamba/ops/ssd_combined.py
View file @
38d80967
...
...
@@ -106,21 +106,24 @@ def _mamba_chunk_scan_combined_fwd(x,
# 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
# (middle term of factorization of off-diag blocks; A terms)
# - for handling chunked prefill, this requires i) initial_states
# ii) seq_idx
and
iii) is_cont_batched to be all specified.
# ii) seq_idx iii) is_cont_batched
and (iv) chunk_offsets
to be all specified.
# - When a new seq_idx is detected, we will stop passing the prev_state
# and switch accordingly to the init_state corresponding to the new seq_idx.
# - We will also make sure that the dA_cumsum is taken only from the start of the
# sequence (hence we need the full dA_cumsum tensor and not just the values at chunk boundaries)
# - this will ensure that states will be updated with the rightmost flushed seq_idx
# of the previous chunk. This implies that the first chunk of states is either 0
# or equal to init_states of the first example.
states
,
final_states
=
_state_passing_fwd
(
rearrange
(
states
,
"... p n -> ... (p n)"
),
dA_cumsum
[:,
:,
:,
-
1
]
,
dA_cumsum
,
initial_states
=
rearrange
(
initial_states
,
"... p n -> ... (p n)"
)
if
initial_states
is
not
None
else
None
,
seq_idx
=
seq_idx
,
chunk_size
=
chunk_size
,
out_dtype
=
state_dtype
if
state_dtype
is
not
None
else
C
.
dtype
,
is_cont_batched
=
cu_seqlens
is
not
None
)
is_cont_batched
=
cu_seqlens
is
not
None
,
chunk_offsets
=
chunk_offsets
)
states
,
final_states
=
(
rearrange
(
t
,
"... (p n) -> ... p n"
,
n
=
dstate
)
for
t
in
[
states
,
final_states
])
...
...
vllm/model_executor/layers/mamba/ops/ssd_state_passing.py
View file @
38d80967
...
...
@@ -31,6 +31,8 @@ def _state_passing_fwd_kernel(
dA_cs_ptr
,
initstates_ptr
,
seq_idx_ptr
,
chunk_offsets_ptr
,
chunk_meta_num
,
# Matrix dimensions
dim
,
nchunks
,
...
...
@@ -51,6 +53,7 @@ def _state_passing_fwd_kernel(
stride_dA_cs_batch
,
stride_dA_cs_chunk
,
stride_dA_cs_head
,
stride_dA_cs_csize
,
stride_initstates_batch
,
stride_initstates_head
,
stride_initstates_dim
,
...
...
@@ -66,7 +69,8 @@ def _state_passing_fwd_kernel(
pid_h
=
tl
.
program_id
(
axis
=
2
)
pid_m
=
tl
.
program_id
(
axis
=
0
)
states_ptr
+=
pid_b
*
stride_states_batch
+
pid_h
*
stride_states_head
dA_cs_ptr
+=
pid_b
*
stride_dA_cs_batch
+
pid_h
*
stride_dA_cs_head
dA_cs_ptr
+=
pid_b
*
stride_dA_cs_batch
+
pid_h
*
stride_dA_cs_head
+
(
chunk_size
-
1
)
*
stride_dA_cs_csize
out_ptr
+=
pid_b
*
stride_out_batch
+
pid_h
*
stride_out_head
final_states_ptr
+=
pid_b
*
stride_final_states_batch
+
pid_h
*
stride_final_states_head
if
HAS_INITSTATES
:
...
...
@@ -95,35 +99,62 @@ def _state_passing_fwd_kernel(
tl
.
store
(
out_ptrs
,
states
,
mask
=
offs_m
<
dim
)
out_ptrs
+=
stride_out_chunk
seq_idx
=
0
prev_seq_idx_chunk_end
=
0
logical_chunk_idx
=
0
for
c
in
range
(
nchunks
):
new_states
=
tl
.
load
(
states_ptrs
,
mask
=
offs_m
<
dim
,
other
=
0.0
).
to
(
tl
.
float32
)
dA_cs
=
tl
.
load
(
dA_cs_ptr
).
to
(
tl
.
float32
)
scale
=
tl
.
exp
(
dA_cs
)
scale
_mask
=
True
if
HAS_SEQ_IDX
:
# - the seq to pass forward is the one that is flushed to the right
# boundary.
# - that is given by seq_idx_new below.
seq_idx_new
=
tl
.
load
(
seq_idx_ptr
+
(
min
((
c
+
1
)
*
chunk_size
,
seqlen
)
-
1
)
*
stride_seq_idx_seqlen
)
# - that is given by seq_idx_chunk_end below: the sequence index at the end of the chunk.
seq_idx_chunk_end
=
tl
.
load
(
seq_idx_ptr
+
(
min
(
(
c
+
1
)
*
chunk_size
,
seqlen
)
-
1
)
*
stride_seq_idx_seqlen
)
if
HAS_INITSTATES
:
if
IS_CONT_BATCHED
and
seq_idx
!=
seq_idx_
new
:
if
IS_CONT_BATCHED
and
prev_
seq_idx
_chunk_end
!=
seq_idx_
chunk_end
:
# this means in the current chunk the rightmost flushed seq
# has changed.
# - so we do not propagate the state from previous chunk
# - but rather we load that sequence's init state
initstates_ptrs
=
initstates_ptr
+
seq_idx_
new
*
stride_initstates_batch
initstates_ptrs
=
initstates_ptr
+
seq_idx_
chunk_end
*
stride_initstates_batch
# - update state with seq_idx_new's init state
states
=
tl
.
load
(
initstates_ptrs
,
mask
=
offs_m
<
dim
,
other
=
0.0
).
to
(
tl
.
float32
)
# - we need to consider the cumsum only of the last sequence in the chunk
# - find its starting position (given by c_off of the logical chunk index)
# - and subtract the cumsum just before that position from the total cumsum
# - first, update the logical chunk index (add the number of sequences in the current physical chunk):
# sequence index at the start of the current chunk
seq_idx_chunk_start
=
tl
.
load
(
seq_idx_ptr
+
min
(
c
*
chunk_size
,
seqlen
)
*
stride_seq_idx_seqlen
)
logical_chunk_idx
+=
seq_idx_chunk_end
-
seq_idx_chunk_start
# - load the chunk offset:
c_off
=
tl
.
load
(
chunk_offsets_ptr
+
logical_chunk_idx
,
mask
=
logical_chunk_idx
<
chunk_meta_num
,
other
=
0
)
# - if offset is 0, then the sequence starts at the beginning of the chunk, and we don't need to subtract anything
if
c_off
>
0
:
# - dA_cs_ptr currently points to the cumsum at the end of the chunk - subtract the chunk size and add the offset
dA_cs_boundary
=
tl
.
load
(
dA_cs_ptr
-
(
chunk_size
-
1
)
*
stride_dA_cs_csize
+
(
c_off
-
1
)
*
stride_dA_cs_csize
,
mask
=
(
c_off
-
1
)
>
-
1
and
c_off
<
chunk_size
,
other
=
0.0
)
dA_cs
-=
dA_cs_boundary
# - increment logical chunk index for every physical chunk
logical_chunk_idx
+=
1
else
:
scale
=
tl
.
where
(
seq_idx_new
==
seq_idx
,
scale
,
0.0
)
scale_mask
=
seq_idx_chunk_end
==
prev_seq_idx_chunk_end
prev_seq_idx_chunk_end
=
seq_idx_chunk_end
seq_idx
=
seq_idx_new
scale
=
tl
.
where
(
scale_mask
,
tl
.
exp
(
dA_cs
),
0.0
)
states
=
scale
*
states
+
new_states
if
c
<
nchunks
-
1
:
tl
.
store
(
out_ptrs
,
states
,
mask
=
offs_m
<
dim
)
...
...
@@ -136,28 +167,36 @@ def _state_passing_fwd_kernel(
def
_state_passing_fwd
(
states
,
dA_
chunk_
cumsum
,
dA_cumsum
,
initial_states
=
None
,
seq_idx
=
None
,
chunk_size
=
None
,
out_dtype
=
None
,
is_cont_batched
=
False
,
chunk_offsets
=
None
,
):
batch
,
nchunks
,
nheads
,
dim
=
states
.
shape
assert
dA_chunk_cumsum
.
shape
==
(
batch
,
nheads
,
nchunks
)
if
chunk_size
is
None
:
chunk_size
=
dA_cumsum
.
shape
[
-
1
]
else
:
assert
chunk_size
==
dA_cumsum
.
shape
[
-
1
]
assert
dA_cumsum
.
shape
==
(
batch
,
nheads
,
nchunks
,
chunk_size
)
if
initial_states
is
not
None
:
if
is_cont_batched
:
# - if cu_seqlens is provided, then the initial states
# are used for continuous batching. In which case we
# require seq_idx to be provided
assert
seq_idx
is
not
None
,
""
assert
seq_idx
is
not
None
,
"seq_idx must be provided for continuous batching"
# - we also need chunk_offsets to be provided, to account
# for computation of dA_cumsum from the start of the
# sequence
assert
chunk_offsets
is
not
None
,
"chunk_offsets must be provided for continuous batching"
else
:
# - this is the regular batching case, where initial
# states are used are for each example of the batch.
assert
initial_states
.
shape
==
(
batch
,
nheads
,
dim
)
if
seq_idx
is
not
None
:
assert
chunk_size
is
not
None
seqlen
=
seq_idx
.
shape
[
-
1
]
assert
seq_idx
.
shape
==
(
batch
,
seqlen
)
out_dtype
=
states
.
dtype
if
out_dtype
is
None
else
out_dtype
...
...
@@ -173,13 +212,15 @@ def _state_passing_fwd(
states
,
out
,
final_states
,
dA_
chunk_
cumsum
,
dA_cumsum
,
initial_states
,
seq_idx
,
chunk_offsets
,
len
(
chunk_offsets
)
if
chunk_offsets
is
not
None
else
0
,
dim
,
nchunks
,
seqlen
if
seq_idx
is
not
None
else
0
,
chunk_size
if
seq_idx
is
not
None
else
0
,
chunk_size
,
states
.
stride
(
0
),
states
.
stride
(
1
),
states
.
stride
(
2
),
...
...
@@ -191,9 +232,10 @@ def _state_passing_fwd(
final_states
.
stride
(
0
),
final_states
.
stride
(
1
),
final_states
.
stride
(
2
),
dA_chunk_cumsum
.
stride
(
0
),
dA_chunk_cumsum
.
stride
(
2
),
dA_chunk_cumsum
.
stride
(
1
),
dA_cumsum
.
stride
(
0
),
dA_cumsum
.
stride
(
2
),
dA_cumsum
.
stride
(
1
),
dA_cumsum
.
stride
(
3
),
*
((
initial_states
.
stride
(
0
),
initial_states
.
stride
(
1
),
initial_states
.
stride
(
2
))
if
initial_states
is
not
None
else
(
0
,
0
,
0
)),
...
...
vllm/model_executor/layers/mla.py
0 → 100644
View file @
38d80967
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
from
typing
import
Optional
import
torch
from
vllm.attention
import
Attention
from
vllm.config
import
CacheConfig
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
@
dataclass
class
MLAModules
:
"""Modules used in MLA.
"""
kv_a_layernorm
:
torch
.
nn
.
Module
kv_b_proj
:
torch
.
nn
.
Module
rotary_emb
:
torch
.
nn
.
Module
o_proj
:
torch
.
nn
.
Module
fused_qkv_a_proj
:
Optional
[
torch
.
nn
.
Module
]
kv_a_proj_with_mqa
:
Optional
[
torch
.
nn
.
Module
]
q_a_layernorm
:
Optional
[
torch
.
nn
.
Module
]
q_b_proj
:
Optional
[
torch
.
nn
.
Module
]
q_proj
:
Optional
[
torch
.
nn
.
Module
]
@
CustomOp
.
register
(
"multi_head_latent_attention"
)
class
MultiHeadLatentAttention
(
CustomOp
):
"""MLA layer registered as CustomOp.
Note that currently MLA ignores the enable/disable mechanism of CustomOp
because there is only one in-tree implementation in forward_native.
TODO: implement this with a new PluggableLayer mechanism.
This class takes positions and hidden_states as input.
The input tensors can either contain prefill tokens or decode tokens.
The class does the following:
1. MLA Preprocess.
2. Perform multi-head attention to prefill tokens and
multi-query attention to decode tokens separately.
3. Return the output tensor.
"""
def
__init__
(
self
,
hidden_size
:
int
,
num_heads
:
int
,
scale
:
float
,
qk_nope_head_dim
:
int
,
qk_rope_head_dim
:
int
,
v_head_dim
:
int
,
q_lora_rank
:
Optional
[
int
],
kv_lora_rank
:
int
,
mla_modules
:
MLAModules
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
hidden_size
self
.
qk_nope_head_dim
=
qk_nope_head_dim
self
.
qk_rope_head_dim
=
qk_rope_head_dim
self
.
qk_head_dim
=
qk_nope_head_dim
+
qk_rope_head_dim
self
.
v_head_dim
=
v_head_dim
self
.
q_lora_rank
=
q_lora_rank
self
.
kv_lora_rank
=
kv_lora_rank
self
.
num_heads
=
num_heads
self
.
fused_qkv_a_proj
=
mla_modules
.
fused_qkv_a_proj
self
.
kv_a_proj_with_mqa
=
mla_modules
.
kv_a_proj_with_mqa
self
.
q_a_layernorm
=
mla_modules
.
q_a_layernorm
self
.
q_b_proj
=
mla_modules
.
q_b_proj
self
.
q_proj
=
mla_modules
.
q_proj
self
.
kv_a_layernorm
=
mla_modules
.
kv_a_layernorm
self
.
kv_b_proj
=
mla_modules
.
kv_b_proj
self
.
rotary_emb
=
mla_modules
.
rotary_emb
self
.
o_proj
=
mla_modules
.
o_proj
# In the MLA backend, kv_cache includes both k_c and
# pe (i.e. decoupled position embeddings). In particular,
# the concat_and_cache_mla op requires
# k_c.size(1) + k_pe.size(1) == kv_cache.size(2)
# i.e.
# kv_lora_rank + qk_rope_head_dim == head_size
self
.
mla_attn
=
Attention
(
num_heads
=
self
.
num_heads
,
head_size
=
self
.
kv_lora_rank
+
self
.
qk_rope_head_dim
,
scale
=
scale
,
num_kv_heads
=
1
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.attn"
,
use_mla
=
True
,
# MLA Args
q_lora_rank
=
self
.
q_lora_rank
,
kv_lora_rank
=
self
.
kv_lora_rank
,
qk_nope_head_dim
=
self
.
qk_nope_head_dim
,
qk_rope_head_dim
=
self
.
qk_rope_head_dim
,
qk_head_dim
=
self
.
qk_head_dim
,
v_head_dim
=
self
.
v_head_dim
,
kv_b_proj
=
self
.
kv_b_proj
,
)
self
.
prefix
=
prefix
self
.
debug_layer_idx
=
int
(
self
.
prefix
.
split
(
"."
)[
-
2
])
def
forward_native
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
q_c
=
None
kv_lora
=
None
if
self
.
q_lora_rank
is
not
None
:
assert
self
.
fused_qkv_a_proj
is
not
None
,
\
"fused_qkv_a_proj is required when q_lora_rank is not None"
assert
self
.
q_a_layernorm
is
not
None
,
\
"q_a_layernorm is required when q_lora_rank is not None"
assert
self
.
q_b_proj
is
not
None
,
\
"q_b_proj is required when q_lora_rank is not None"
qkv_lora
=
self
.
fused_qkv_a_proj
(
hidden_states
)[
0
]
q_c
,
kv_lora
=
qkv_lora
.
split
(
[
self
.
q_lora_rank
,
self
.
kv_lora_rank
+
self
.
qk_rope_head_dim
],
dim
=-
1
,
)
q_c
=
self
.
q_a_layernorm
(
q_c
)
q
=
self
.
q_b_proj
(
q_c
)[
0
]
else
:
assert
self
.
kv_a_proj_with_mqa
is
not
None
,
\
"kv_a_proj_with_mqa is required when q_lora_rank is None"
assert
self
.
q_proj
is
not
None
,
\
"q_proj is required when q_lora_rank is None"
kv_lora
=
self
.
kv_a_proj_with_mqa
(
hidden_states
)[
0
]
q
=
self
.
q_proj
(
hidden_states
)[
0
]
kv_c
,
k_pe
=
kv_lora
.
split
([
self
.
kv_lora_rank
,
self
.
qk_rope_head_dim
],
dim
=-
1
)
kv_c_normed
=
self
.
kv_a_layernorm
(
kv_c
)
q
=
q
.
view
(
-
1
,
self
.
num_heads
,
self
.
qk_head_dim
)
# Add head dim of 1 to k_pe
k_pe
=
k_pe
.
unsqueeze
(
1
)
q
[...,
self
.
qk_nope_head_dim
:],
k_pe
=
self
.
rotary_emb
(
positions
,
q
[...,
self
.
qk_nope_head_dim
:],
k_pe
)
attn_out
=
self
.
mla_attn
(
q
,
kv_c_normed
,
k_pe
,
output_shape
=
(
hidden_states
.
shape
[
0
],
self
.
num_heads
*
self
.
v_head_dim
))
return
self
.
o_proj
(
attn_out
)[
0
]
def
forward_cuda
(
self
,
*
args
,
**
kwargs
):
return
self
.
forward_native
(
*
args
,
**
kwargs
)
vllm/model_executor/layers/pooler.py
View file @
38d80967
...
...
@@ -5,7 +5,7 @@ from collections.abc import Mapping, Set
from
dataclasses
import
dataclass
from
enum
import
IntEnum
from
itertools
import
groupby
from
typing
import
Callable
,
Optional
,
TypeVar
,
Union
,
cast
from
typing
import
Callable
,
Optional
,
TypeVar
,
Union
import
torch
import
torch.nn
as
nn
...
...
@@ -362,14 +362,13 @@ class PoolerIdentity(PoolerActivation):
class
PoolerNormalize
(
PoolerActivation
):
def
forward_chunk
(
self
,
pooled_data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
x
=
F
.
normalize
(
pooled_data
.
float
(),
p
=
2
,
dim
=-
1
)
return
x
.
to
(
pooled_data
.
dtype
)
return
F
.
normalize
(
pooled_data
,
p
=
2
,
dim
=-
1
)
class
PoolerMultiLabelClassify
(
PoolerActivation
):
def
forward_chunk
(
self
,
pooled_data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
F
.
sigmoid
(
pooled_data
.
float
()).
to
(
pooled_data
.
dtype
)
return
F
.
sigmoid
(
pooled_data
)
class
PoolerClassify
(
PoolerActivation
):
...
...
@@ -394,9 +393,9 @@ class PoolerClassify(PoolerActivation):
pooled_data
.
shape
[
-
1
])
if
num_labels
<
2
:
return
F
.
sigmoid
(
pooled_data
.
float
()).
to
(
pooled_data
.
dtype
)
return
F
.
sigmoid
(
pooled_data
)
return
F
.
softmax
(
pooled_data
.
float
()
,
dim
=-
1
)
.
to
(
pooled_data
.
dtype
)
return
F
.
softmax
(
pooled_data
,
dim
=-
1
)
class
LambdaPoolerActivation
(
PoolerActivation
):
...
...
@@ -432,8 +431,9 @@ class EmbeddingPoolerHead(PoolerHead):
from
vllm.model_executor.models.adapters
import
_load_st_projector
vllm_config
=
get_current_vllm_config
()
self
.
projector
=
_load_st_projector
(
self
.
projector
:
Optional
[
nn
.
Module
]
=
_load_st_projector
(
vllm_config
.
model_config
)
if
vllm_config
else
None
self
.
head_dtype
=
vllm_config
.
model_config
.
head_dtype
def
forward
(
self
,
pooled_data
:
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
],
pooling_metadata
:
PoolingMetadata
):
...
...
@@ -442,16 +442,11 @@ class EmbeddingPoolerHead(PoolerHead):
pooled_data
=
torch
.
stack
(
pooled_data
)
# pooled_data shape: [batchsize, hidden_dimension]
pooled_data
=
pooled_data
.
to
(
self
.
head_dtype
)
# Apply ST projector
if
self
.
projector
is
not
None
:
projector
=
cast
(
nn
.
Module
,
self
.
projector
)
def
_proj
(
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
orig_dtype
=
x
.
dtype
y
=
projector
(
x
.
to
(
torch
.
float32
))
return
y
.
to
(
orig_dtype
)
pooled_data
=
_proj
(
pooled_data
)
pooled_data
=
self
.
projector
(
pooled_data
)
# pooled_data shape: [batchsize, embedding_dimension]
pooling_params
=
get_pooling_params
(
pooling_metadata
)
...
...
@@ -494,8 +489,18 @@ class RewardPoolerHead(PoolerHead):
def
__init__
(
self
)
->
None
:
super
().
__init__
(
activation
=
PoolerClassify
(
static_num_labels
=
False
))
from
vllm.config
import
get_current_vllm_config
vllm_config
=
get_current_vllm_config
()
self
.
head_dtype
=
vllm_config
.
model_config
.
head_dtype
def
forward
(
self
,
pooled_data
:
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
],
pooling_metadata
:
PoolingMetadata
):
if
isinstance
(
pooled_data
,
list
):
pooled_data
=
[
p
.
to
(
self
.
head_dtype
)
for
p
in
pooled_data
]
else
:
pooled_data
=
pooled_data
.
to
(
self
.
head_dtype
)
pooling_params
=
get_pooling_params
(
pooling_metadata
)
# for softmax
...
...
@@ -633,9 +638,15 @@ class ClassifierPooler(Pooler):
)
->
None
:
super
().
__init__
()
from
vllm.config
import
get_current_vllm_config
vllm_config
=
get_current_vllm_config
()
self
.
pooling
=
pooling
self
.
classifier
=
classifier
self
.
act_fn
=
act_fn
or
PoolerClassify
()
self
.
logit_bias
:
Optional
[
float
]
=
vllm_config
.
model_config
.
pooler_config
.
logit_bias
self
.
head_dtype
=
vllm_config
.
model_config
.
head_dtype
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
{
"classify"
,
"score"
}
...
...
@@ -650,10 +661,15 @@ class ClassifierPooler(Pooler):
pooled_data
=
torch
.
stack
(
pooled_data
)
# pooled_data shape: [batchsize, hidden_size]
pooled_data
=
pooled_data
.
to
(
self
.
head_dtype
)
if
self
.
classifier
is
not
None
:
pooled_data
=
self
.
classifier
(
pooled_data
)
# pooled_data shape: [batchsize, num_labels]
if
self
.
logit_bias
is
not
None
:
pooled_data
-=
self
.
logit_bias
pooling_params
=
get_pooling_params
(
pooling_metadata
)
flags
=
[
p
.
activation
for
p
in
pooling_params
]
...
...
vllm/model_executor/layers/quantization/__init__.py
View file @
38d80967
...
...
@@ -26,7 +26,6 @@ QuantizationMethods = Literal[
"bitsandbytes"
,
"hqq"
,
"experts_int8"
,
"neuron_quant"
,
"ipex"
,
"quark"
,
"moe_wna16"
,
...
...
@@ -108,7 +107,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
from
.modelopt
import
ModelOptFp8Config
,
ModelOptNvFp4Config
from
.moe_wna16
import
MoeWNA16Config
from
.mxfp4
import
Mxfp4Config
from
.neuron_quant
import
NeuronQuantConfig
from
.petit
import
PetitNvFp4Config
from
.ptpc_fp8
import
PTPCFp8Config
from
.rtn
import
RTNConfig
...
...
@@ -135,7 +133,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"ptpc_fp8"
:
PTPCFp8Config
,
"hqq"
:
HQQMarlinConfig
,
"experts_int8"
:
ExpertsInt8Config
,
"neuron_quant"
:
NeuronQuantConfig
,
"ipex"
:
IPEXConfig
,
"quark"
:
QuarkConfig
,
"moe_wna16"
:
MoeWNA16Config
,
...
...
vllm/model_executor/layers/quantization/auto_round.py
View file @
38d80967
...
...
@@ -327,6 +327,8 @@ class AutoRoundConfig(QuantizationConfig):
if
isinstance
(
layer
,
FusedMoE
):
if
use_marlin
:
return
GPTQMarlinMoEMethod
(
quant_args_marlin
,
layer
.
moe
)
else
:
from
vllm.model_executor.layers.quantization.moe_wna16
import
(
MoeWNA16Config
)
...
...
@@ -339,7 +341,6 @@ class AutoRoundConfig(QuantizationConfig):
}
return
MoeWNA16Config
.
from_config
(
config
).
get_quant_method
(
layer
,
prefix
)
return
GPTQMarlinMoEMethod
(
quant_args_marlin
,
layer
.
moe
)
if
isinstance
(
layer
,
(
LinearBase
,
ParallelLMHead
)):
if
use_marlin
:
...
...
vllm/model_executor/layers/quantization/awq_marlin.py
View file @
38d80967
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Any
,
Callable
,
Optional
from
typing
import
Any
,
Callable
,
Optional
,
Union
import
torch
from
torch.nn
import
Parameter
...
...
@@ -505,7 +505,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
expert_load_view
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]
:
assert
self
.
fused_experts
is
None
if
enable_eplb
:
...
...
vllm/model_executor/layers/quantization/awq_triton.py
View file @
38d80967
...
...
@@ -19,7 +19,7 @@ def awq_dequantize_kernel(
num_rows
,
# input num rows in qweight
BLOCK_SIZE_X
:
tl
.
constexpr
,
BLOCK_SIZE_Y
:
tl
.
constexpr
):
# Setup the pids.
# Set
up the pids.
pid_x
=
tl
.
program_id
(
axis
=
0
)
pid_y
=
tl
.
program_id
(
axis
=
1
)
...
...
Prev
1
…
24
25
26
27
28
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment