Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d2b52805
Commit
d2b52805
authored
Sep 07, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.10.2rc1' into v0.10.2rc1-ori
parents
9a521c23
5438967f
Changes
511
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
789 additions
and
205 deletions
+789
-205
vllm/model_executor/layers/mamba/mamba_utils.py
vllm/model_executor/layers/mamba/mamba_utils.py
+24
-0
vllm/model_executor/layers/mamba/short_conv.py
vllm/model_executor/layers/mamba/short_conv.py
+270
-0
vllm/model_executor/layers/pooler.py
vllm/model_executor/layers/pooler.py
+113
-125
vllm/model_executor/layers/quantization/__init__.py
vllm/model_executor/layers/quantization/__init__.py
+3
-6
vllm/model_executor/layers/quantization/awq_marlin.py
vllm/model_executor/layers/quantization/awq_marlin.py
+2
-0
vllm/model_executor/layers/quantization/bitsandbytes.py
vllm/model_executor/layers/quantization/bitsandbytes.py
+2
-0
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
...ers/quantization/compressed_tensors/compressed_tensors.py
+82
-28
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+113
-15
vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
...ayers/quantization/compressed_tensors/schemes/__init__.py
+3
-1
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py
...mpressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py
+8
-30
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_fp8.py
...compressed_tensors/schemes/compressed_tensors_w4a8_fp8.py
+169
-0
No files found.
Too many changes to show.
To preserve performance only
511 of 511+
files are displayed.
Plain diff
Email patch
vllm/model_executor/layers/mamba/mamba_utils.py
View file @
d2b52805
...
@@ -54,6 +54,16 @@ class MambaStateDtypeCalculator:
...
@@ -54,6 +54,16 @@ class MambaStateDtypeCalculator:
return
(
conv_state_dtype
,
temporal_state_dtype
)
return
(
conv_state_dtype
,
temporal_state_dtype
)
@classmethod
def short_conv_state_dtype(
    cls,
    model_dtype: Union[ModelDType, torch.dtype],
    mamba_cache_dtype: MambaDType,
) -> tuple[torch.dtype, ...]:
    """Return the dtype tuple for a short-conv layer's cached state.

    The only entry is the conv-state dtype, obtained by passing
    ``mamba_cache_dtype`` and ``model_dtype`` (in that order) to
    ``get_kv_cache_torch_dtype``.
    """
    # A short-conv layer keeps a single state tensor, hence the 1-tuple.
    return (get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype), )
class
MambaStateShapeCalculator
:
class
MambaStateShapeCalculator
:
...
@@ -122,6 +132,20 @@ class MambaStateShapeCalculator:
...
@@ -122,6 +132,20 @@ class MambaStateShapeCalculator:
tp_world_size
),
head_dim
,
state_size
)
tp_world_size
),
head_dim
,
state_size
)
return
conv_state_shape
,
temporal_state_shape
return
conv_state_shape
,
temporal_state_shape
@classmethod
def short_conv_state_shape(
    cls,
    tp_world_size: int,
    intermediate_size: int,
    conv_kernel: int,
    use_v1: bool = True,
) -> tuple[tuple[int, int]]:
    """Return the shape tuple for a short-conv layer's cached state.

    The single entry is ``(conv_kernel - 1, conv_dim)`` when ``use_v1`` is
    true and ``(conv_dim, conv_kernel - 1)`` otherwise, where ``conv_dim``
    is ``divide(intermediate_size, tp_world_size)``.
    """
    per_rank_dim = divide(intermediate_size, tp_world_size)
    history_len = conv_kernel - 1
    if use_v1:
        shape = (history_len, per_rank_dim)
    else:
        # Non-V1 callers get the two dimensions in swapped order.
        shape = (per_rank_dim, history_len)
    return (shape, )
@
classmethod
@
classmethod
def
extra_groups_for_head_shards
(
cls
,
ngroups
:
int
,
tp_size
:
int
):
def
extra_groups_for_head_shards
(
cls
,
ngroups
:
int
,
tp_size
:
int
):
"""Compute the increase in group numbers to account for
"""Compute the increase in group numbers to account for
...
...
vllm/model_executor/layers/mamba/short_conv.py
0 → 100644
View file @
d2b52805
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
TYPE_CHECKING
,
Optional
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionBackend
import
torch
from
vllm
import
envs
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
ModelConfig
,
get_current_vllm_config
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.forward_context
import
ForwardContext
,
get_forward_context
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
MergedColumnParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.mamba.abstract
import
MambaBase
from
vllm.model_executor.layers.mamba.mamba2_metadata
import
update_metadata
from
vllm.model_executor.layers.mamba.mamba_utils
import
(
MambaStateDtypeCalculator
,
MambaStateShapeCalculator
)
from
vllm.model_executor.layers.mamba.ops.causal_conv1d
import
(
causal_conv1d_fn
,
causal_conv1d_update
)
from
vllm.platforms
import
current_platform
from
vllm.utils
import
direct_register_custom_op
from
vllm.v1.attention.backends.short_conv_attn
import
(
ShortConvAttentionMetadata
)
@CustomOp.register("short_conv")
class ShortConv(MambaBase, CustomOp):
    """Short causal-convolution layer, registered as the "short_conv" op.

    The forward pass projects the input into three equal chunks ``B``,
    ``C`` and ``x``, runs a causal 1-D convolution over ``B * x`` using a
    per-layer cached convolution state, multiplies the result by ``C`` and
    applies the output projection.
    """

    def __init__(self,
                 config,
                 dim: int,
                 layer_idx: int,
                 model_config: Optional[ModelConfig] = None,
                 cache_config: Optional[CacheConfig] = None,
                 prefix: str = ""):
        """Build the projections and register the layer under ``prefix``.

        Args:
            config: Object providing ``conv_L_cache`` and ``conv_bias``.
            dim: Size used for the projections and stored as ``conv_dim``.
            layer_idx: Stored on the instance as ``layer_idx``.
            model_config: Stored; required later by ``get_state_dtype``.
            cache_config: Stored; required later by ``get_state_dtype``.
            prefix: Name prefix for the sub-layers and the key under which
                this layer is stored in ``static_forward_context``.

        Raises:
            AssertionError: If ``envs.VLLM_USE_V1`` is falsy.
            ValueError: If ``prefix`` is already registered.
        """
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.conv_dim = dim
        self.L_cache = config.conv_L_cache
        self.bias = config.conv_bias

        # The convolution weights are held in a linear layer whose input
        # size is the cache length and whose output size is `dim`.
        self.conv = ColumnParallelLinear(
            input_size=self.L_cache,
            output_size=dim,
            bias=self.bias,
            prefix=f"{prefix}.conv1d",
        )
        # unsqueeze to fit conv1d weights shape into the linear weights shape.
        # Can't do this in `weight_loader` since it already exists in
        # `ColumnParallelLinear` and `set_weight_attrs`
        # doesn't allow to override it
        self.conv.weight.data = self.conv.weight.data.unsqueeze(1)

        # One merged projection producing the three `dim`-sized chunks that
        # forward_cuda splits into B, C and x.
        self.in_proj = MergedColumnParallelLinear(
            input_size=dim,
            output_sizes=[dim] * 3,
            bias=self.bias,
            prefix=f"{prefix}.in_proj",
        )
        self.out_proj = RowParallelLinear(
            input_size=dim,
            output_size=dim,
            bias=self.bias,
            prefix=f"{prefix}.out_proj",
        )

        assert envs.VLLM_USE_V1, ("ShortConv layers are only supported in V1")
        compilation_config = get_current_vllm_config().compilation_config
        # Layer names must be unique: the prefix is the lookup key for this
        # layer in the static forward context.
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self
        # The outer list is for v0 PP virtual engine. Though this code path
        # only runs for v1, we have to do this to unify with the interface
        # of Attention + v0 PP.
        self.kv_cache = [(torch.tensor([]), )]

        self.model_config = model_config
        self.cache_config = cache_config
        self.prefix = prefix

    def forward_native(
        self,
        hidden_states: torch.Tensor,
        output: torch.Tensor,
        conv_metadata: ShortConvAttentionMetadata,
    ):
        """No-op: returns None without reading or writing any argument."""
        return

    def forward(
        self,
        hidden_states: torch.Tensor,
        output: torch.Tensor,
        conv_metadata: ShortConvAttentionMetadata,
    ):
        """Invoke the ``torch.ops.vllm.short_conv`` op for this layer.

        Only ``hidden_states``, ``output`` and ``self.prefix`` are passed to
        the op; ``conv_metadata`` is not forwarded. Nothing is returned.
        """
        torch.ops.vllm.short_conv(
            hidden_states,
            output,
            self.prefix,
        )

    def forward_cuda(
        self,
        hidden_states: torch.Tensor,
        output: torch.Tensor,
        conv_metadata: ShortConvAttentionMetadata,
    ):
        """Run the short convolution and write the result into ``output``.

        Per-layer metadata is read from the current forward context; when
        it is present it replaces the ``conv_metadata`` argument. When the
        forward context carries no attention metadata, the convolution is
        skipped and the projected ``C * (B * x)`` is returned instead of
        being written to ``output``. Otherwise the first
        ``num_actual_tokens`` rows of ``output`` are assigned in place and
        the method returns None.
        """
        forward_context = get_forward_context()
        # ShortConvAttentionMetadata contains metadata necessary for the
        # short_conv triton kernels to operate in continuous batching and in
        # chunked prefill modes; they are computed at top-level model forward
        # since they stay the same and reused for all mamba layers in the same
        # iteration.
        attn_metadata: AttentionMetadata = forward_context.attn_metadata
        if attn_metadata is not None:
            # The context holds a dict keyed by layer prefix; pick this
            # layer's entry and use it as the conv metadata from here on.
            assert isinstance(attn_metadata, dict)
            attn_metadata = attn_metadata[self.prefix]
            conv_metadata = attn_metadata
            assert isinstance(attn_metadata, ShortConvAttentionMetadata)
            self_kv_cache = self.kv_cache[forward_context.virtual_engine]
            # Swap the last two dims of the cached state.
            # NOTE(review): presumably the layout the conv kernels expect —
            # confirm against causal_conv1d_fn / causal_conv1d_update.
            conv_state = self_kv_cache[0].transpose(-1, -2)
            state_indices_tensor = attn_metadata.state_indices_tensor
            has_initial_states_p = attn_metadata.has_initial_states

        BCx, _ = self.in_proj(hidden_states)

        B, C, x = BCx.chunk(3, dim=-1)

        # Drop the middle dimension added by unsqueeze(1) in __init__.
        conv_weights = self.conv.weight.view(self.conv.weight.size(0),
                                             self.conv.weight.size(2))

        if attn_metadata is None:
            # V1 profile run
            # NOTE(review): this path returns the tensor rather than writing
            # into `output` — confirm callers do not rely on `output` here.
            Bx = (B * x).contiguous()
            hidden_states = C * Bx
            contextualized_states, _ = self.out_proj(hidden_states)
            return contextualized_states

        num_prefills = attn_metadata.num_prefills  # request count
        num_decodes = attn_metadata.num_decode_tokens  # token count (=request)
        num_prefill_tokens = attn_metadata.num_prefill_tokens  # token count
        has_prefill = num_prefills > 0
        has_decode = num_decodes > 0
        num_actual_tokens = num_decodes + num_prefill_tokens

        # NOTE: V1 puts decode before prefill
        # Separate prefill and decode by splitting varlen input
        # Split along token dimension
        B_d, B_p = torch.split(
            B[:num_actual_tokens],
            [num_decodes, num_prefill_tokens],
            dim=0,
        )
        C_d, C_p = torch.split(
            C[:num_actual_tokens],
            [num_decodes, num_prefill_tokens],
            dim=0,
        )
        x_d, x_p = torch.split(
            x[:num_actual_tokens],
            [num_decodes, num_prefill_tokens],
            dim=0,
        )
        # Split along batch dimension
        state_indices_tensor_d, state_indices_tensor_p = torch.split(
            state_indices_tensor,
            [num_decodes, num_prefills],
            dim=0,
        )
        # Take the last num_prefills + 1 entries and subtract num_decodes so
        # the offsets are relative to the prefill slice.
        query_start_loc_p = (
            attn_metadata.query_start_loc[-num_prefills - 1:] -
            num_decodes if has_prefill else None)

        conv_output_list = []

        if has_prefill:
            Bx_p = (B_p * x_p).transpose(0, 1)
            # Fill in the conv metadata only if it has not been computed yet.
            if conv_metadata.cu_seqlen is None:
                conv_metadata = update_metadata(Bx_p, query_start_loc_p,
                                                conv_metadata)
            Bx = causal_conv1d_fn(Bx_p,
                                  conv_weights,
                                  self.conv.bias,
                                  activation=None,
                                  conv_states=conv_state,
                                  has_initial_state=has_initial_states_p,
                                  cache_indices=state_indices_tensor_p,
                                  metadata=conv_metadata,
                                  query_start_loc=query_start_loc_p).transpose(
                                      0, 1)[:num_prefill_tokens]

            y = C_p * Bx
            conv_output_list.append(y)

        if has_decode:
            Bx_d = (B_d * x_d).contiguous()
            Bx = causal_conv1d_update(
                Bx_d,
                conv_state,
                conv_weights,
                self.conv.bias,
                activation=None,
                conv_state_indices=state_indices_tensor_d)
            y = C_d * Bx
            # Decode output goes first to keep the decode-before-prefill
            # token order of the input.
            conv_output_list.insert(0, y)

        # Merge prefill and decode outputs before passing to gated MLP
        hidden_states = torch.vstack(conv_output_list)

        # Final linear projection
        output[:num_actual_tokens], _ = self.out_proj(hidden_states)

    def get_state_dtype(self) -> tuple[torch.dtype, ...]:
        """Return the state dtypes for this layer.

        Delegates to ``MambaStateDtypeCalculator.short_conv_state_dtype``
        with the model dtype and the configured mamba cache dtype. Asserts
        that both ``model_config`` and ``cache_config`` were provided.
        """
        assert self.model_config is not None
        assert self.cache_config is not None
        return MambaStateDtypeCalculator.short_conv_state_dtype(
            self.model_config.dtype,
            self.cache_config.mamba_cache_dtype,
        )

    def get_state_shape(self) -> tuple[tuple[int, ...]]:
        """Return the state shapes for this layer.

        Delegates to ``MambaStateShapeCalculator.short_conv_state_shape``
        using the tensor-parallel world size, ``conv_dim`` and ``L_cache``.
        """
        return MambaStateShapeCalculator.short_conv_state_shape(
            tp_world_size=get_tensor_model_parallel_world_size(),
            intermediate_size=self.conv_dim,
            conv_kernel=self.L_cache,
        )

    @property
    def mamba_type(self) -> str:
        """Identifier string for this layer kind: "short_conv"."""
        return "short_conv"

    def get_attn_backend(self) -> type["AttentionBackend"]:
        """Return the ``ShortConvAttentionBackend`` class."""
        # Imported here rather than at module top level.
        from vllm.v1.attention.backends.short_conv_attn import (
            ShortConvAttentionBackend)
        return ShortConvAttentionBackend
def short_conv(
    hidden_states: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
) -> None:
    """Dispatch to the ``forward_cuda`` of the layer named ``layer_name``.

    The layer is looked up in ``no_compile_layers`` of the current forward
    context and called with ``conv_metadata=None``. Whatever
    ``forward_cuda`` returns is discarded.
    """
    layer = get_forward_context().no_compile_layers[layer_name]
    layer.forward_cuda(
        hidden_states=hidden_states,
        output=output,
        conv_metadata=None,
    )
def short_conv_fake(
    hidden_states: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
) -> None:
    """Placeholder with the same signature as ``short_conv``.

    Performs no computation, leaves every argument untouched and returns
    None.
    """
    return None
# Register `short_conv` as a custom op at import time. `short_conv` is the
# implementation, `short_conv_fake` is the fake implementation, and `output`
# is declared as the argument the op mutates. The dispatch key is taken from
# the current platform. ShortConv.forward calls the op as
# torch.ops.vllm.short_conv.
direct_register_custom_op(
    op_name="short_conv",
    op_func=short_conv,
    mutates_args=["output"],
    fake_impl=short_conv_fake,
    dispatch_key=current_platform.dispatch_key,
)
vllm/model_executor/layers/pooler.py
View file @
d2b52805
...
@@ -5,7 +5,7 @@ from collections.abc import Mapping, Set
...
@@ -5,7 +5,7 @@ from collections.abc import Mapping, Set
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
enum
import
IntEnum
from
enum
import
IntEnum
from
itertools
import
groupby
from
itertools
import
groupby
from
typing
import
Callable
,
Optional
,
TypeVar
,
Union
from
typing
import
Callable
,
Optional
,
TypeVar
,
Union
,
cast
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -13,16 +13,15 @@ import torch.nn.functional as F
...
@@ -13,16 +13,15 @@ import torch.nn.functional as F
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
from
vllm.config
import
ModelConfig
,
PoolerConfig
from
vllm.config
import
ModelConfig
,
PoolerConfig
from
vllm.model_executor.pooling_metadata
import
(
# noqa: E501
from
vllm.logger
import
init_logger
PoolingMetadata
as
V0PoolingMetadata
)
from
vllm.model_executor.pooling_metadata
import
PoolingTensors
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.sequence
import
PoolerOutput
,
PoolingSequenceGroupOutput
from
vllm.sequence
import
PoolerOutput
,
PoolingSequenceGroupOutput
from
vllm.tasks
import
PoolingTask
from
vllm.tasks
import
PoolingTask
from
vllm.utils
import
resolve_obj_by_qualname
from
vllm.utils
import
current_stream
,
resolve_obj_by_qualname
from
vllm.v1.pool.metadata
import
PoolingMetadata
as
V1PoolingMetadata
from
vllm.v1.pool.metadata
import
PoolingCursor
,
PoolingMetadata
logger
=
init_logger
(
__name__
)
PoolingMetadata
=
Union
[
V0PoolingMetadata
,
V1PoolingMetadata
]
PoolingFn
=
Callable
[
PoolingFn
=
Callable
[
[
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
PoolingMetadata
],
[
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
PoolingMetadata
],
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]]
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]]
...
@@ -126,16 +125,11 @@ def get_prompt_lens(
...
@@ -126,16 +125,11 @@ def get_prompt_lens(
hidden_states
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
hidden_states
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
pooling_metadata
:
PoolingMetadata
,
pooling_metadata
:
PoolingMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
isinstance
(
pooling_metadata
,
V1PoolingMetadata
):
return
pooling_metadata
.
prompt_lens
return
pooling_metadata
.
prompt_lens
return
PoolingTensors
.
from_pooling_metadata
(
pooling_metadata
,
hidden_states
[
0
].
device
).
prompt_lens
def
get_prompt_token_ids
(
def
get_prompt_token_ids
(
pooling_metadata
:
PoolingMetadata
)
->
list
[
torch
.
Tensor
]:
pooling_metadata
:
PoolingMetadata
)
->
list
[
torch
.
Tensor
]:
if
isinstance
(
pooling_metadata
,
V1PoolingMetadata
):
assert
pooling_metadata
.
prompt_token_ids
is
not
None
,
(
assert
pooling_metadata
.
prompt_token_ids
is
not
None
,
(
"Please set `requires_token_ids=True` in `get_pooling_updates`"
)
"Please set `requires_token_ids=True` in `get_pooling_updates`"
)
...
@@ -144,17 +138,9 @@ def get_prompt_token_ids(
...
@@ -144,17 +138,9 @@ def get_prompt_token_ids(
for
i
,
num
in
enumerate
(
pooling_metadata
.
prompt_lens
)
for
i
,
num
in
enumerate
(
pooling_metadata
.
prompt_lens
)
]
]
return
[
torch
.
tensor
(
seq_data_i
.
prompt_token_ids
)
for
seq_data_i
in
pooling_metadata
.
seq_data
.
values
()
]
def
get_pooling_params
(
def
get_pooling_params
(
pooling_metadata
:
PoolingMetadata
)
->
list
[
PoolingParams
]:
pooling_metadata
:
PoolingMetadata
)
->
list
[
PoolingParams
]:
if
isinstance
(
pooling_metadata
,
V0PoolingMetadata
):
pooling_params
=
[
p
for
_
,
p
in
pooling_metadata
.
seq_groups
]
else
:
pooling_params
=
pooling_metadata
.
pooling_params
pooling_params
=
pooling_metadata
.
pooling_params
return
pooling_params
return
pooling_params
...
@@ -172,6 +158,15 @@ def get_tasks(pooling_metadata: PoolingMetadata) -> list[PoolingTask]:
...
@@ -172,6 +158,15 @@ def get_tasks(pooling_metadata: PoolingMetadata) -> list[PoolingTask]:
def
get_classification_activation_function
(
config
:
PretrainedConfig
):
def
get_classification_activation_function
(
config
:
PretrainedConfig
):
# Implement alignment with transformers ForSequenceClassificationLoss
# https://github.com/huggingface/transformers/blob/57bb6db6ee4cfaccc45b8d474dfad5a17811ca60/src/transformers/loss/loss_utils.py#L92
problem_type
=
getattr
(
config
,
"problem_type"
,
""
)
if
problem_type
==
"regression"
:
return
PoolerIdentity
()
if
problem_type
==
"single_label_classification"
:
return
PoolerClassify
()
if
problem_type
==
"multi_label_classification"
:
return
PoolerMultiLabelClassify
()
return
PoolerClassify
()
return
PoolerClassify
()
...
@@ -191,11 +186,18 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
...
@@ -191,11 +186,18 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
fn
=
resolve_obj_by_qualname
(
function_name
)()
fn
=
resolve_obj_by_qualname
(
function_name
)()
return
PoolerActivation
.
wraps
(
fn
)
return
PoolerActivation
.
wraps
(
fn
)
return
Pooler
Score
()
return
Pooler
Classify
()
def
build_output
(
def
build_output
(
all_data
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
)
->
PoolerOutput
:
all_data
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
)
->
PoolerOutput
:
# Pooling models D2H & synchronize occurs here
if
isinstance
(
all_data
,
list
):
all_data
=
[
d
.
to
(
"cpu"
,
non_blocking
=
True
)
for
d
in
all_data
]
else
:
all_data
=
all_data
.
to
(
"cpu"
,
non_blocking
=
True
)
current_stream
().
synchronize
()
all_outputs
=
[
PoolingSequenceGroupOutput
(
data
)
for
data
in
all_data
]
all_outputs
=
[
PoolingSequenceGroupOutput
(
data
)
for
data
in
all_data
]
return
PoolerOutput
(
outputs
=
all_outputs
)
return
PoolerOutput
(
outputs
=
all_outputs
)
...
@@ -222,40 +224,21 @@ class PoolingMethod(nn.Module, ABC):
...
@@ -222,40 +224,21 @@ class PoolingMethod(nn.Module, ABC):
def
get_pooling_updates
(
self
,
task
:
PoolingTask
)
->
PoolingParamsUpdate
:
def
get_pooling_updates
(
self
,
task
:
PoolingTask
)
->
PoolingParamsUpdate
:
return
PoolingParamsUpdate
()
return
PoolingParamsUpdate
()
@
abstractmethod
def
forward_one
(
self
,
hidden_states
:
torch
.
Tensor
,
prompt_len
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
"""
Note:
`prompt_len=None` means `prompt_len=len(hidden_states)`.
"""
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
forward_all
(
def
forward_all
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
p
rompt_lens
:
torch
.
Ten
sor
,
p
ooling_cursor
:
PoolingCur
sor
,
)
->
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
]:
)
->
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
]:
raise
NotImplementedError
raise
NotImplementedError
def
forward
(
def
forward
(
self
,
self
,
hidden_states
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
pooling_metadata
:
PoolingMetadata
,
)
->
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
]:
)
->
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
]:
prompt_lens
=
get_prompt_lens
(
hidden_states
,
pooling_metadata
)
pooling_cursor
=
pooling_metadata
.
pooling_cursor
return
self
.
forward_all
(
hidden_states
,
pooling_cursor
)
if
isinstance
(
hidden_states
,
list
):
return
[
self
.
forward_one
(
h
,
prompt_len
)
for
h
,
prompt_len
in
zip
(
hidden_states
,
prompt_lens
)
]
return
self
.
forward_all
(
hidden_states
,
prompt_lens
)
class
CLSPool
(
PoolingMethod
):
class
CLSPool
(
PoolingMethod
):
...
@@ -263,24 +246,15 @@ class CLSPool(PoolingMethod):
...
@@ -263,24 +246,15 @@ class CLSPool(PoolingMethod):
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
{
"encode"
,
"embed"
,
"classify"
,
"score"
}
return
{
"encode"
,
"embed"
,
"classify"
,
"score"
}
def
forward_one
(
self
,
hidden_states
:
torch
.
Tensor
,
prompt_len
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
assert
prompt_len
is
None
or
prompt_len
==
hidden_states
.
shape
[
0
],
\
"partial prefill not supported with CLS pooling"
return
hidden_states
[
0
]
def
forward_all
(
def
forward_all
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
p
rompt_lens
:
torch
.
Ten
sor
,
p
ooling_cursor
:
PoolingCur
sor
,
)
->
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
]:
)
->
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
]:
first_token_flat_indices
=
torch
.
zeros_like
(
prompt_lens
)
assert
not
pooling_cursor
.
is_partial_prefill
(),
\
first_token_flat_indices
[
1
:]
+=
torch
.
cumsum
(
prompt_lens
,
dim
=
0
)[:
-
1
]
"partial prefill not supported with CLS pooling"
return
hidden_states
[
first_token_flat_indices
]
return
hidden_states
[
pooling_cursor
.
first_token_indices_gpu
]
class
LastPool
(
PoolingMethod
):
class
LastPool
(
PoolingMethod
):
...
@@ -288,20 +262,12 @@ class LastPool(PoolingMethod):
...
@@ -288,20 +262,12 @@ class LastPool(PoolingMethod):
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
{
"encode"
,
"embed"
,
"classify"
,
"score"
}
return
{
"encode"
,
"embed"
,
"classify"
,
"score"
}
def
forward_one
(
self
,
hidden_states
:
torch
.
Tensor
,
prompt_len
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
return
hidden_states
[
-
1
]
def
forward_all
(
def
forward_all
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
p
rompt_lens
:
torch
.
Ten
sor
,
p
ooling_cursor
:
PoolingCur
sor
,
)
->
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
]:
)
->
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
]:
last_token_flat_indices
=
torch
.
cumsum
(
prompt_lens
,
dim
=
0
)
-
1
return
hidden_states
[
pooling_cursor
.
last_token_indices_gpu
]
return
hidden_states
[
last_token_flat_indices
]
class
AllPool
(
PoolingMethod
):
class
AllPool
(
PoolingMethod
):
...
@@ -309,22 +275,19 @@ class AllPool(PoolingMethod):
...
@@ -309,22 +275,19 @@ class AllPool(PoolingMethod):
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
{
"encode"
}
return
{
"encode"
}
def
forward_one
(
self
,
hidden_states
:
torch
.
Tensor
,
prompt_len
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
assert
prompt_len
is
None
or
prompt_len
==
hidden_states
.
shape
[
0
],
\
"partial prefill not supported with ALL pooling"
return
hidden_states
def
forward_all
(
def
forward_all
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
p
rompt_lens
:
torch
.
Ten
sor
,
p
ooling_cursor
:
PoolingCur
sor
,
)
->
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
]:
)
->
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
]:
return
list
(
hidden_states
.
split_with_sizes
(
prompt_lens
.
tolist
()))
assert
not
pooling_cursor
.
is_partial_prefill
(),
\
"partial prefill not supported with ALL pooling"
hidden_states_lst
=
list
(
hidden_states
.
split
(
pooling_cursor
.
num_scheduled_tokens_cpu
.
tolist
()))
return
[
hidden_states_lst
[
i
]
for
i
in
pooling_cursor
.
index
]
class
MeanPool
(
PoolingMethod
):
class
MeanPool
(
PoolingMethod
):
...
@@ -332,31 +295,25 @@ class MeanPool(PoolingMethod):
...
@@ -332,31 +295,25 @@ class MeanPool(PoolingMethod):
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
{
"encode"
,
"embed"
,
"classify"
,
"score"
}
return
{
"encode"
,
"embed"
,
"classify"
,
"score"
}
def
forward_
one
(
def
forward_
all
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
prompt_len
:
Optional
[
torch
.
Tensor
]
=
None
,
pooling_cursor
:
PoolingCursor
,
)
->
torch
.
Tensor
:
)
->
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
]:
assert
prompt_len
is
None
or
prompt_len
==
hidden_states
.
shape
[
0
],
\
assert
not
pooling_cursor
.
is_partial_prefill
(),
\
"partial prefill not supported with MEAN pooling"
"partial prefill not supported with MEAN pooling"
return
hidden_states
.
mean
(
dim
=
0
,
dtype
=
torch
.
float32
)
prompt_lens
=
pooling_cursor
.
prompt_lens_cpu
.
to
(
hidden_states
.
device
,
non_blocking
=
True
)
def
forward_all
(
self
,
hidden_states
:
torch
.
Tensor
,
prompt_lens
:
torch
.
Tensor
,
)
->
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
]:
# Use float32 for torch.cumsum in MeanPool,
# Use float32 for torch.cumsum in MeanPool,
# otherwise precision will be lost significantly.
# otherwise precision will be lost significantly.
cumsum
=
torch
.
cumsum
(
hidden_states
,
dim
=
0
,
dtype
=
torch
.
float32
)
cumsum
=
torch
.
cumsum
(
hidden_states
,
dim
=
0
,
dtype
=
torch
.
float32
)
start_indices
=
torch
.
cat
([
start_indices
=
pooling_cursor
.
first_token_indices_gpu
torch
.
tensor
([
0
],
device
=
hidden_states
.
device
),
end_indices
=
pooling_cursor
.
last_token_indices_gpu
torch
.
cumsum
(
prompt_lens
[:
-
1
],
dim
=
0
)
return
(
cumsum
[
end_indices
]
-
cumsum
[
start_indices
]
+
])
end_indices
=
torch
.
cumsum
(
prompt_lens
,
dim
=
0
)
return
(
cumsum
[
end_indices
-
1
]
-
cumsum
[
start_indices
]
+
hidden_states
[
start_indices
])
/
prompt_lens
.
unsqueeze
(
1
)
hidden_states
[
start_indices
])
/
prompt_lens
.
unsqueeze
(
1
)
...
@@ -409,24 +366,37 @@ class PoolerNormalize(PoolerActivation):
...
@@ -409,24 +366,37 @@ class PoolerNormalize(PoolerActivation):
return
x
.
to
(
pooled_data
.
dtype
)
return
x
.
to
(
pooled_data
.
dtype
)
class
PoolerClassify
(
PoolerActivation
):
class
Pooler
MultiLabel
Classify
(
PoolerActivation
):
def
forward_chunk
(
self
,
pooled_data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward_chunk
(
self
,
pooled_data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
num_labels
=
pooled_data
.
shape
[
-
1
]
if
num_labels
<
2
:
return
F
.
sigmoid
(
pooled_data
.
float
()).
to
(
pooled_data
.
dtype
)
return
F
.
sigmoid
(
pooled_data
.
float
()).
to
(
pooled_data
.
dtype
)
return
F
.
softmax
(
pooled_data
.
float
(),
dim
=-
1
).
to
(
pooled_data
.
dtype
)
class
PoolerClassify
(
PoolerActivation
):
class
PoolerScore
(
PoolerActivation
):
def __init__(self, *, static_num_labels: bool = True) -> None:
    """Set up ``self.num_labels`` for the classification activation.

    Args:
        static_num_labels: When true, ``num_labels`` is read once from
            the ``hf_config`` of the current vLLM model config (0 when
            the attribute is missing) and a warning is logged if it is
            0. When false, ``self.num_labels`` is set to None.
    """
    super().__init__()

    if static_num_labels:
        # Imported here rather than at module top level.
        from vllm.config import get_current_vllm_config
        vllm_config = get_current_vllm_config()
        self.num_labels = getattr(vllm_config.model_config.hf_config,
                                  "num_labels", 0)
        if self.num_labels == 0:
            # Adjacent string literals are concatenated verbatim, so each
            # fragment needs its own trailing space. The first fragment
            # previously lacked one, logging "classificationmodels".
            logger.warning("num_labels should be > 0 for classification "
                           "models, falling back to softmax. "
                           "Please check if the configuration is correct.")
    else:
        self.num_labels = None
def
forward_chunk
(
self
,
pooled_data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward_chunk
(
self
,
pooled_data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
num_labels
=
pooled_data
.
shape
[
-
1
]
num_labels
=
(
self
.
num_labels
if
self
.
num_labels
is
not
None
else
pooled_data
.
shape
[
-
1
])
if
num_labels
<
2
:
if
num_labels
<
2
:
return
F
.
sigmoid
(
pooled_data
.
float
()).
to
(
pooled_data
.
dtype
)
return
F
.
sigmoid
(
pooled_data
.
float
()).
to
(
pooled_data
.
dtype
)
return
pooled_data
return
F
.
softmax
(
pooled_data
.
float
(),
dim
=-
1
).
to
(
pooled_data
.
dtype
)
class
LambdaPoolerActivation
(
PoolerActivation
):
class
LambdaPoolerActivation
(
PoolerActivation
):
...
@@ -457,9 +427,33 @@ class EmbeddingPoolerHead(PoolerHead):
...
@@ -457,9 +427,33 @@ class EmbeddingPoolerHead(PoolerHead):
def
__init__
(
self
)
->
None
:
def
__init__
(
self
)
->
None
:
super
().
__init__
(
activation
=
PoolerNormalize
())
super
().
__init__
(
activation
=
PoolerNormalize
())
# Load ST projector if available
from
vllm.config
import
get_current_vllm_config
from
vllm.model_executor.models.adapters
import
_load_st_projector
vllm_config
=
get_current_vllm_config
()
self
.
projector
=
_load_st_projector
(
vllm_config
.
model_config
)
if
vllm_config
else
None
def
forward
(
self
,
pooled_data
:
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
],
def
forward
(
self
,
pooled_data
:
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
],
pooling_metadata
:
PoolingMetadata
):
pooling_metadata
:
PoolingMetadata
):
if
isinstance
(
pooled_data
,
list
):
pooled_data
=
torch
.
stack
(
pooled_data
)
# pooled_data shape: [batchsize, hidden_dimension]
# Apply ST projector
if
self
.
projector
is
not
None
:
projector
=
cast
(
nn
.
Module
,
self
.
projector
)
def
_proj
(
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
orig_dtype
=
x
.
dtype
y
=
projector
(
x
.
to
(
torch
.
float32
))
return
y
.
to
(
orig_dtype
)
pooled_data
=
_proj
(
pooled_data
)
# pooled_data shape: [batchsize, embedding_dimension]
pooling_params
=
get_pooling_params
(
pooling_metadata
)
pooling_params
=
get_pooling_params
(
pooling_metadata
)
# for matryoshka representation
# for matryoshka representation
...
@@ -491,13 +485,14 @@ class EmbeddingPoolerHead(PoolerHead):
...
@@ -491,13 +485,14 @@ class EmbeddingPoolerHead(PoolerHead):
for
vecs
,
f
in
zip
(
pooled_data
,
flags
)
for
vecs
,
f
in
zip
(
pooled_data
,
flags
)
]
]
# pooled_data shape: [batchsize, embedding_dimension]
return
pooled_data
return
pooled_data
class
RewardPoolerHead
(
PoolerHead
):
class
RewardPoolerHead
(
PoolerHead
):
def
__init__
(
self
)
->
None
:
def
__init__
(
self
)
->
None
:
super
().
__init__
(
activation
=
PoolerClassify
())
super
().
__init__
(
activation
=
PoolerClassify
(
static_num_labels
=
False
))
def
forward
(
self
,
pooled_data
:
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
],
def
forward
(
self
,
pooled_data
:
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
],
pooling_metadata
:
PoolingMetadata
):
pooling_metadata
:
PoolingMetadata
):
...
@@ -651,15 +646,13 @@ class ClassifierPooler(Pooler):
...
@@ -651,15 +646,13 @@ class ClassifierPooler(Pooler):
pooling_metadata
:
PoolingMetadata
,
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
)
->
PoolerOutput
:
pooled_data
=
self
.
pooling
(
hidden_states
,
pooling_metadata
)
pooled_data
=
self
.
pooling
(
hidden_states
,
pooling_metadata
)
if
isinstance
(
pooled_data
,
list
):
pooled_data
=
torch
.
stack
(
pooled_data
)
# pooled_data shape: [batchsize, hidden_size]
if
self
.
classifier
is
not
None
:
if
self
.
classifier
is
not
None
:
# apply classifier once on the full batch if possible
if
isinstance
(
pooled_data
,
torch
.
Tensor
):
pooled_data
=
self
.
classifier
(
pooled_data
)
pooled_data
=
self
.
classifier
(
pooled_data
)
elif
len
({
data
.
shape
for
data
in
pooled_data
})
<=
1
:
# pooled_data shape: [batchsize, num_labels]
pooled_data
=
self
.
classifier
(
torch
.
stack
(
pooled_data
))
else
:
pooled_data
=
[
self
.
classifier
(
data
)
for
data
in
pooled_data
]
pooling_params
=
get_pooling_params
(
pooling_metadata
)
pooling_params
=
get_pooling_params
(
pooling_metadata
)
flags
=
[
p
.
activation
for
p
in
pooling_params
]
flags
=
[
p
.
activation
for
p
in
pooling_params
]
...
@@ -672,6 +665,7 @@ class ClassifierPooler(Pooler):
...
@@ -672,6 +665,7 @@ class ClassifierPooler(Pooler):
for
vecs
,
f
in
zip
(
pooled_data
,
flags
)
for
vecs
,
f
in
zip
(
pooled_data
,
flags
)
]
]
# scores shape: [batchsize, num_labels]
return
build_output
(
scores
)
return
build_output
(
scores
)
...
@@ -702,12 +696,6 @@ class DispatchPooler(Pooler):
...
@@ -702,12 +696,6 @@ class DispatchPooler(Pooler):
)
->
PoolerOutput
:
)
->
PoolerOutput
:
poolers_by_task
=
self
.
poolers_by_task
poolers_by_task
=
self
.
poolers_by_task
if
isinstance
(
hidden_states
,
list
):
hidden_states_lst
=
hidden_states
else
:
prompt_lens
=
get_prompt_lens
(
hidden_states
,
pooling_metadata
)
hidden_states_lst
=
list
(
hidden_states
.
split
(
prompt_lens
.
tolist
()))
outputs
=
list
[
PoolingSequenceGroupOutput
]()
outputs
=
list
[
PoolingSequenceGroupOutput
]()
offset
=
0
offset
=
0
for
task
,
group
in
groupby
(
get_tasks
(
pooling_metadata
)):
for
task
,
group
in
groupby
(
get_tasks
(
pooling_metadata
)):
...
@@ -718,7 +706,7 @@ class DispatchPooler(Pooler):
...
@@ -718,7 +706,7 @@ class DispatchPooler(Pooler):
num_items
=
len
(
list
(
group
))
num_items
=
len
(
list
(
group
))
group_output
:
PoolerOutput
=
pooler
(
group_output
:
PoolerOutput
=
pooler
(
hidden_states
_lst
[
offset
:
offset
+
num_items
]
,
hidden_states
,
pooling_metadata
[
offset
:
offset
+
num_items
],
pooling_metadata
[
offset
:
offset
+
num_items
],
)
)
...
...
vllm/model_executor/layers/quantization/__init__.py
View file @
d2b52805
...
@@ -15,7 +15,6 @@ QuantizationMethods = Literal[
...
@@ -15,7 +15,6 @@ QuantizationMethods = Literal[
"fbgemm_fp8"
,
"fbgemm_fp8"
,
"modelopt"
,
"modelopt"
,
"modelopt_fp4"
,
"modelopt_fp4"
,
"marlin"
,
"bitblas"
,
"bitblas"
,
"gguf"
,
"gguf"
,
"gptq_marlin_24"
,
"gptq_marlin_24"
,
...
@@ -25,7 +24,6 @@ QuantizationMethods = Literal[
...
@@ -25,7 +24,6 @@ QuantizationMethods = Literal[
"gptq"
,
"gptq"
,
"compressed-tensors"
,
"compressed-tensors"
,
"bitsandbytes"
,
"bitsandbytes"
,
"qqq"
,
"hqq"
,
"hqq"
,
"experts_int8"
,
"experts_int8"
,
"neuron_quant"
,
"neuron_quant"
,
...
@@ -37,6 +35,7 @@ QuantizationMethods = Literal[
...
@@ -37,6 +35,7 @@ QuantizationMethods = Literal[
"rtn"
,
"rtn"
,
"inc"
,
"inc"
,
"mxfp4"
,
"mxfp4"
,
"petit_nvfp4"
,
]
]
QUANTIZATION_METHODS
:
list
[
str
]
=
list
(
get_args
(
QuantizationMethods
))
QUANTIZATION_METHODS
:
list
[
str
]
=
list
(
get_args
(
QuantizationMethods
))
...
@@ -106,13 +105,12 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
...
@@ -106,13 +105,12 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
from
.hqq_marlin
import
HQQMarlinConfig
from
.hqq_marlin
import
HQQMarlinConfig
from
.inc
import
INCConfig
from
.inc
import
INCConfig
from
.ipex_quant
import
IPEXConfig
from
.ipex_quant
import
IPEXConfig
from
.marlin
import
MarlinConfig
from
.modelopt
import
ModelOptFp8Config
,
ModelOptNvFp4Config
from
.modelopt
import
ModelOptFp8Config
,
ModelOptNvFp4Config
from
.moe_wna16
import
MoeWNA16Config
from
.moe_wna16
import
MoeWNA16Config
from
.mxfp4
import
Mxfp4Config
from
.mxfp4
import
Mxfp4Config
from
.neuron_quant
import
NeuronQuantConfig
from
.neuron_quant
import
NeuronQuantConfig
from
.petit
import
PetitNvFp4Config
from
.ptpc_fp8
import
PTPCFp8Config
from
.ptpc_fp8
import
PTPCFp8Config
from
.qqq
import
QQQConfig
from
.rtn
import
RTNConfig
from
.rtn
import
RTNConfig
from
.torchao
import
TorchAOConfig
from
.torchao
import
TorchAOConfig
from
.tpu_int8
import
Int8TpuConfig
from
.tpu_int8
import
Int8TpuConfig
...
@@ -125,7 +123,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
...
@@ -125,7 +123,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"fbgemm_fp8"
:
FBGEMMFp8Config
,
"fbgemm_fp8"
:
FBGEMMFp8Config
,
"modelopt"
:
ModelOptFp8Config
,
"modelopt"
:
ModelOptFp8Config
,
"modelopt_fp4"
:
ModelOptNvFp4Config
,
"modelopt_fp4"
:
ModelOptNvFp4Config
,
"marlin"
:
MarlinConfig
,
"bitblas"
:
BitBLASConfig
,
"bitblas"
:
BitBLASConfig
,
"gguf"
:
GGUFConfig
,
"gguf"
:
GGUFConfig
,
"gptq_marlin_24"
:
GPTQMarlin24Config
,
"gptq_marlin_24"
:
GPTQMarlin24Config
,
...
@@ -136,7 +133,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
...
@@ -136,7 +133,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"compressed-tensors"
:
CompressedTensorsConfig
,
"compressed-tensors"
:
CompressedTensorsConfig
,
"bitsandbytes"
:
BitsAndBytesConfig
,
"bitsandbytes"
:
BitsAndBytesConfig
,
"ptpc_fp8"
:
PTPCFp8Config
,
"ptpc_fp8"
:
PTPCFp8Config
,
"qqq"
:
QQQConfig
,
"hqq"
:
HQQMarlinConfig
,
"hqq"
:
HQQMarlinConfig
,
"experts_int8"
:
ExpertsInt8Config
,
"experts_int8"
:
ExpertsInt8Config
,
"neuron_quant"
:
NeuronQuantConfig
,
"neuron_quant"
:
NeuronQuantConfig
,
...
@@ -148,6 +144,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
...
@@ -148,6 +144,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"rtn"
:
RTNConfig
,
"rtn"
:
RTNConfig
,
"inc"
:
INCConfig
,
"inc"
:
INCConfig
,
"mxfp4"
:
Mxfp4Config
,
"mxfp4"
:
Mxfp4Config
,
"petit_nvfp4"
:
PetitNvFp4Config
,
}
}
# Update the `method_to_config` with customized quantization methods.
# Update the `method_to_config` with customized quantization methods.
method_to_config
.
update
(
_CUSTOMIZED_METHOD_TO_QUANT_CONFIG
)
method_to_config
.
update
(
_CUSTOMIZED_METHOD_TO_QUANT_CONFIG
)
...
...
vllm/model_executor/layers/quantization/awq_marlin.py
View file @
d2b52805
...
@@ -497,6 +497,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
...
@@ -497,6 +497,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
scoring_func
:
str
=
"softmax"
,
routed_scaling_factor
:
float
=
1.0
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
activation
:
str
=
"silu"
,
...
@@ -523,6 +524,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
...
@@ -523,6 +524,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
num_expert_group
=
num_expert_group
,
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
,
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
scoring_func
=
scoring_func
,
routed_scaling_factor
=
routed_scaling_factor
,
e_score_correction_bias
=
e_score_correction_bias
,
e_score_correction_bias
=
e_score_correction_bias
,
indices_type
=
self
.
topk_indices_dtype
)
indices_type
=
self
.
topk_indices_dtype
)
...
...
vllm/model_executor/layers/quantization/bitsandbytes.py
View file @
d2b52805
...
@@ -466,6 +466,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
...
@@ -466,6 +466,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
scoring_func
:
str
=
"softmax"
,
routed_scaling_factor
:
float
=
1.0
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
activation
:
str
=
"silu"
,
...
@@ -490,6 +491,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
...
@@ -490,6 +491,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
num_expert_group
=
num_expert_group
,
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
,
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
scoring_func
=
scoring_func
,
routed_scaling_factor
=
routed_scaling_factor
,
e_score_correction_bias
=
e_score_correction_bias
,
e_score_correction_bias
=
e_score_correction_bias
,
indices_type
=
self
.
topk_indices_dtype
)
indices_type
=
self
.
topk_indices_dtype
)
if
self
.
quant_config
.
load_in_8bit
:
if
self
.
quant_config
.
load_in_8bit
:
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
View file @
d2b52805
...
@@ -11,6 +11,7 @@ from compressed_tensors.config import (CompressionFormat,
...
@@ -11,6 +11,7 @@ from compressed_tensors.config import (CompressionFormat,
from
compressed_tensors.quantization
import
(
QuantizationArgs
,
from
compressed_tensors.quantization
import
(
QuantizationArgs
,
QuantizationStrategy
,
QuantizationStrategy
,
QuantizationType
)
QuantizationType
)
from
compressed_tensors.transform
import
TransformConfig
from
pydantic
import
BaseModel
from
pydantic
import
BaseModel
import
vllm.envs
as
envs
import
vllm.envs
as
envs
...
@@ -26,10 +27,12 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
...
@@ -26,10 +27,12 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
W4A16SPARSE24_SUPPORTED_BITS
,
WNA16_SUPPORTED_BITS
,
CompressedTensors24
,
W4A16SPARSE24_SUPPORTED_BITS
,
WNA16_SUPPORTED_BITS
,
CompressedTensors24
,
CompressedTensorsScheme
,
CompressedTensorsW4A4Fp4
,
CompressedTensorsScheme
,
CompressedTensorsW4A4Fp4
,
CompressedTensorsW4A8Int
,
CompressedTensorsW4A16Fp4
,
CompressedTensorsW4A8Fp8
,
CompressedTensorsW4A8Int
,
CompressedTensorsW4A16Sparse24
,
CompressedTensorsW8A8Fp8
,
CompressedTensorsW4A16Fp4
,
CompressedTensorsW4A16Sparse24
,
CompressedTensorsW8A8Int8
,
CompressedTensorsW8A16Fp8
,
CompressedTensorsW8A8Fp8
,
CompressedTensorsW8A8Int8
,
CompressedTensorsWNA16
)
CompressedTensorsW8A16Fp8
,
CompressedTensorsWNA16
)
from
vllm.model_executor.layers.quantization.compressed_tensors.transform.linear
import
(
# noqa: E501
CompressedTensorsLinearTransformMethod
,
get_linear_transform_schemes
)
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
find_matched_target
,
is_activation_quantization_format
,
find_matched_target
,
is_activation_quantization_format
,
should_ignore_layer
)
should_ignore_layer
)
...
@@ -60,6 +63,7 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -60,6 +63,7 @@ class CompressedTensorsConfig(QuantizationConfig):
sparsity_ignore_list
:
list
[
str
],
sparsity_ignore_list
:
list
[
str
],
kv_cache_scheme
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
kv_cache_scheme
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
config
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
config
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
transform_config
:
Optional
[
TransformConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
ignore
=
ignore
self
.
ignore
=
ignore
...
@@ -71,6 +75,12 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -71,6 +75,12 @@ class CompressedTensorsConfig(QuantizationConfig):
self
.
sparsity_ignore_list
=
sparsity_ignore_list
self
.
sparsity_ignore_list
=
sparsity_ignore_list
self
.
config
=
config
self
.
config
=
config
if
transform_config
is
not
None
:
self
.
transform_config
=
TransformConfig
.
model_validate
(
transform_config
)
else
:
self
.
transform_config
=
None
def
get_linear_method
(
self
)
->
"CompressedTensorsLinearMethod"
:
def
get_linear_method
(
self
)
->
"CompressedTensorsLinearMethod"
:
return
CompressedTensorsLinearMethod
(
self
)
return
CompressedTensorsLinearMethod
(
self
)
...
@@ -103,18 +113,27 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -103,18 +113,27 @@ class CompressedTensorsConfig(QuantizationConfig):
)
->
Optional
[
"QuantizeMethodBase"
]:
)
->
Optional
[
"QuantizeMethodBase"
]:
from
vllm.attention.layer
import
Attention
# Avoid circular import
from
vllm.attention.layer
import
Attention
# Avoid circular import
# Check if the layer is skipped for quantization.
# TODO (@robertgshaw2): support module names
if
should_ignore_layer
(
prefix
,
ignore
=
self
.
ignore
,
fused_mapping
=
self
.
packed_modules_mapping
):
return
UnquantizedLinearMethod
()
if
isinstance
(
layer
,
LinearBase
):
if
isinstance
(
layer
,
LinearBase
):
scheme
=
self
.
get_scheme
(
layer
=
layer
,
layer_name
=
prefix
)
# collect schemes
if
scheme
is
None
:
quant_scheme
=
self
.
get_scheme
(
layer
=
layer
,
layer_name
=
prefix
)
return
UnquantizedLinearMethod
()
input_tfms
,
output_tfms
=
get_linear_transform_schemes
(
layer
.
scheme
=
scheme
layer
,
prefix
,
self
.
transform_config
,
return
CompressedTensorsLinearMethod
(
self
)
self
.
packed_modules_mapping
)
# choose quantization method
quant_method
:
LinearMethodBase
=
UnquantizedLinearMethod
()
if
quant_scheme
is
not
None
:
layer
.
scheme
=
quant_scheme
quant_method
=
CompressedTensorsLinearMethod
(
self
)
# choose transform method
if
any
((
input_tfms
,
output_tfms
)):
return
CompressedTensorsLinearTransformMethod
.
from_schemes
(
quant_method
,
input_tfms
,
output_tfms
)
else
:
return
quant_method
if
isinstance
(
layer
,
Attention
):
if
isinstance
(
layer
,
Attention
):
return
CompressedTensorsKVCacheMethod
(
self
)
return
CompressedTensorsKVCacheMethod
(
self
)
if
isinstance
(
layer
,
FusedMoE
):
if
isinstance
(
layer
,
FusedMoE
):
...
@@ -129,6 +148,7 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -129,6 +148,7 @@ class CompressedTensorsConfig(QuantizationConfig):
config
=
config
)
config
=
config
)
sparsity_scheme_map
,
sparsity_ignore_list
=
cls
.
_parse_sparsity_config
(
sparsity_scheme_map
,
sparsity_ignore_list
=
cls
.
_parse_sparsity_config
(
config
=
config
)
config
=
config
)
transform_config
=
config
.
get
(
"transform_config"
)
return
cls
(
return
cls
(
target_scheme_map
=
target_scheme_map
,
target_scheme_map
=
target_scheme_map
,
...
@@ -137,6 +157,7 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -137,6 +157,7 @@ class CompressedTensorsConfig(QuantizationConfig):
sparsity_scheme_map
=
sparsity_scheme_map
,
sparsity_scheme_map
=
sparsity_scheme_map
,
sparsity_ignore_list
=
sparsity_ignore_list
,
sparsity_ignore_list
=
sparsity_ignore_list
,
config
=
config
,
config
=
config
,
transform_config
=
transform_config
,
)
)
@
classmethod
@
classmethod
...
@@ -200,8 +221,10 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -200,8 +221,10 @@ class CompressedTensorsConfig(QuantizationConfig):
format
format
)
if
format
is
not
None
else
is_activation_quantization_format
(
)
if
format
is
not
None
else
is_activation_quantization_format
(
quant_format
)
quant_format
)
if
act_quant_format
:
# TODO(czhu): w4a8fp8 is in packed-quantized format
# but needs input activation quantization
input_activations
=
quant_config
.
get
(
"input_activations"
)
input_activations
=
quant_config
.
get
(
"input_activations"
)
if
act_quant_format
or
input_activations
:
# The only case where we have activation quant supported
# The only case where we have activation quant supported
# but no input_activations provided in the config
# but no input_activations provided in the config
# should be w8a16fp8 w8a16fp8 can also run for cases where
# should be w8a16fp8 w8a16fp8 can also run for cases where
...
@@ -352,6 +375,28 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -352,6 +375,28 @@ class CompressedTensorsConfig(QuantizationConfig):
input_quant
.
strategy
==
QuantizationStrategy
.
TENSOR
)
input_quant
.
strategy
==
QuantizationStrategy
.
TENSOR
)
return
is_symmetric_activation
and
is_per_tensor_activation
return
is_symmetric_activation
and
is_per_tensor_activation
def
_is_fp8_w4a8
(
self
,
weight_quant
:
BaseModel
,
input_quant
:
BaseModel
)
->
bool
:
if
not
weight_quant
or
not
input_quant
:
return
False
is_weight_4_bits
=
weight_quant
.
num_bits
==
4
is_activation_8_bits
=
input_quant
.
num_bits
==
8
weight_strategy
=
(
weight_quant
.
strategy
==
QuantizationStrategy
.
GROUP
.
value
)
is_token
=
(
weight_strategy
and
input_quant
.
strategy
==
QuantizationStrategy
.
TOKEN
.
value
)
is_dynamic
=
not
weight_quant
.
dynamic
and
input_quant
.
dynamic
is_symmetric
=
weight_quant
.
symmetric
and
input_quant
.
symmetric
# Only per-group symmetric weight (4bit)
# + per-tok symmetric activation (8bit) quantization supported.
return
(
is_weight_4_bits
and
is_activation_8_bits
and
is_token
and
is_symmetric
and
is_dynamic
)
def
_is_fp8_w4a8_sm90
(
self
,
weight_quant
:
BaseModel
,
input_quant
:
BaseModel
)
->
bool
:
return
(
self
.
_check_scheme_supported
(
90
,
error
=
False
,
match_exact
=
True
)
and
self
.
_is_fp8_w4a8
(
weight_quant
,
input_quant
))
def
_is_fp8_w8a8_sm90
(
self
,
weight_quant
:
BaseModel
,
def
_is_fp8_w8a8_sm90
(
self
,
weight_quant
:
BaseModel
,
input_quant
:
BaseModel
)
->
bool
:
input_quant
:
BaseModel
)
->
bool
:
return
(
self
.
_check_scheme_supported
(
90
,
error
=
False
,
match_exact
=
True
)
return
(
self
.
_check_scheme_supported
(
90
,
error
=
False
,
match_exact
=
True
)
...
@@ -401,19 +446,30 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -401,19 +446,30 @@ class CompressedTensorsConfig(QuantizationConfig):
weight_quant
:
BaseModel
,
weight_quant
:
BaseModel
,
input_quant
:
BaseModel
,
input_quant
:
BaseModel
,
format
:
Optional
[
str
]
=
None
)
->
"CompressedTensorsScheme"
:
format
:
Optional
[
str
]
=
None
)
->
"CompressedTensorsScheme"
:
# use the per-layer format if defined, otherwise, use global format
format
=
format
if
format
is
not
None
else
self
.
quant_format
# Detect If Mixed Precision
# Detect If Mixed Precision
if
self
.
_is_fp4a16_nvfp4
(
weight_quant
,
input_quant
):
if
self
.
_is_fp4a16_nvfp4
(
weight_quant
,
input_quant
):
return
CompressedTensorsW4A16Fp4
()
return
CompressedTensorsW4A16Fp4
()
if
self
.
_is_fp8_w4a8_sm90
(
weight_quant
,
input_quant
):
return
CompressedTensorsW4A8Fp8
(
num_bits
=
weight_quant
.
num_bits
,
strategy
=
weight_quant
.
strategy
,
symmetric
=
weight_quant
.
symmetric
,
group_size
=
weight_quant
.
group_size
,
actorder
=
weight_quant
.
actorder
)
if
self
.
_is_wNa16_group_channel
(
weight_quant
,
input_quant
):
if
self
.
_is_wNa16_group_channel
(
weight_quant
,
input_quant
):
if
(
self
.
quant_
format
==
CompressionFormat
.
marlin_24
.
value
if
(
format
==
CompressionFormat
.
marlin_24
.
value
and
weight_quant
.
num_bits
in
W4A16SPARSE24_SUPPORTED_BITS
):
and
weight_quant
.
num_bits
in
W4A16SPARSE24_SUPPORTED_BITS
):
assert
weight_quant
.
symmetric
assert
weight_quant
.
symmetric
return
CompressedTensorsW4A16Sparse24
(
return
CompressedTensorsW4A16Sparse24
(
strategy
=
weight_quant
.
strategy
,
strategy
=
weight_quant
.
strategy
,
num_bits
=
weight_quant
.
num_bits
,
num_bits
=
weight_quant
.
num_bits
,
group_size
=
weight_quant
.
group_size
)
group_size
=
weight_quant
.
group_size
)
if
(
self
.
quant_
format
==
CompressionFormat
.
pack_quantized
.
value
if
(
format
==
CompressionFormat
.
pack_quantized
.
value
and
weight_quant
.
num_bits
in
WNA16_SUPPORTED_BITS
):
and
weight_quant
.
num_bits
in
WNA16_SUPPORTED_BITS
):
return
CompressedTensorsWNA16
(
return
CompressedTensorsWNA16
(
num_bits
=
weight_quant
.
num_bits
,
num_bits
=
weight_quant
.
num_bits
,
...
@@ -422,10 +478,7 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -422,10 +478,7 @@ class CompressedTensorsConfig(QuantizationConfig):
group_size
=
weight_quant
.
group_size
,
group_size
=
weight_quant
.
group_size
,
actorder
=
weight_quant
.
actorder
)
actorder
=
weight_quant
.
actorder
)
act_quant_format
=
is_activation_quantization_format
(
act_quant_format
=
is_activation_quantization_format
(
format
)
format
)
if
format
is
not
None
else
is_activation_quantization_format
(
self
.
quant_format
)
if
act_quant_format
:
if
act_quant_format
:
if
self
.
_is_fp4a4_nvfp4
(
weight_quant
,
input_quant
):
if
self
.
_is_fp4a4_nvfp4
(
weight_quant
,
input_quant
):
if
cutlass_fp4_supported
(
if
cutlass_fp4_supported
(
...
@@ -505,9 +558,11 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -505,9 +558,11 @@ class CompressedTensorsConfig(QuantizationConfig):
# Find the "target" in the compressed-tensors config
# Find the "target" in the compressed-tensors config
# that our layer conforms to.
# that our layer conforms to.
# TODO (@robertgshaw): add compressed-tensors as dep
# TODO (@kylesayrs): support ignore module names with ct matching utils
# so we do not have to re-write these functions
if
should_ignore_layer
(
layer_name
,
# need to make accelerate optional in ct to do this
ignore
=
self
.
ignore
,
fused_mapping
=
self
.
packed_modules_mapping
):
return
None
# Will be empty for models with only sparsity
# Will be empty for models with only sparsity
weight_quant
=
input_quant
=
None
weight_quant
=
input_quant
=
None
...
@@ -524,7 +579,7 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -524,7 +579,7 @@ class CompressedTensorsConfig(QuantizationConfig):
format
=
scheme_dict
.
get
(
"format"
)
format
=
scheme_dict
.
get
(
"format"
)
# Find the sparsity scheme of the layer
# Find the sparsity scheme of the layer
# assume that fused layers iner
h
it first component's sparsity scheme
# assume that fused layers in
h
erit first component's sparsity scheme
sparsity_targets
=
(
self
.
sparsity_scheme_map
.
keys
()
-
sparsity_targets
=
(
self
.
sparsity_scheme_map
.
keys
()
-
set
(
self
.
sparsity_ignore_list
))
set
(
self
.
sparsity_ignore_list
))
sparsity_scheme
:
Optional
[
SparsityCompressionConfig
]
=
None
sparsity_scheme
:
Optional
[
SparsityCompressionConfig
]
=
None
...
@@ -690,7 +745,6 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
...
@@ -690,7 +745,6 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
layer input. See LinearMethodBase for param details
layer input. See LinearMethodBase for param details
"""
"""
scheme
=
layer
.
scheme
scheme
=
layer
.
scheme
if
scheme
is
None
:
if
scheme
is
None
:
raise
ValueError
(
"A scheme must be defined for each layer"
)
raise
ValueError
(
"A scheme must be defined for each layer"
)
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
d2b52805
...
@@ -22,6 +22,8 @@ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
...
@@ -22,6 +22,8 @@ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
is_valid_flashinfer_cutlass_fused_moe
)
is_valid_flashinfer_cutlass_fused_moe
)
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16
import
(
# noqa
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16
import
(
# noqa
WNA16_SUPPORTED_BITS
,
WNA16_SUPPORTED_TYPES_MAP
)
WNA16_SUPPORTED_BITS
,
WNA16_SUPPORTED_TYPES_MAP
)
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
find_matched_target
)
from
vllm.model_executor.layers.quantization.utils
import
replace_parameter
from
vllm.model_executor.layers.quantization.utils
import
replace_parameter
from
vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe
import
(
from
vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe
import
(
build_flashinfer_fp4_cutlass_moe_prepare_finalize
,
reorder_w1w3_to_w3w1
,
build_flashinfer_fp4_cutlass_moe_prepare_finalize
,
reorder_w1w3_to_w3w1
,
...
@@ -65,12 +67,40 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
...
@@ -65,12 +67,40 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
@
staticmethod
@
staticmethod
def
get_moe_method
(
def
get_moe_method
(
quant_config
:
"CompressedTensorsConfig"
,
# type: ignore # noqa E501
quant_config
:
"CompressedTensorsConfig"
,
# type: ignore # noqa E501
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
)
->
"CompressedTensorsMoEMethod"
:
)
->
"CompressedTensorsMoEMethod"
:
# TODO: @dsikka: refactor this to use schemes as other kernels
# TODO: @dsikka: refactor this to use schemes as other kernels
# are supported + check if the layer is being ignored.
# are supported + check if the layer is being ignored.
weight_quant
=
quant_config
.
target_scheme_map
[
"Linear"
].
get
(
"weights"
)
# Check if a using "Linear" to select schemes
input_quant
=
quant_config
.
target_scheme_map
[
"Linear"
].
get
(
if
"Linear"
in
quant_config
.
target_scheme_map
:
matched_target
=
"Linear"
else
:
# May have instead defined the linear layers in the fused model
fused_layers
=
[
"re:.*down_proj.*"
,
"re:.*gate_proj.*"
,
"re:.*up_proj.*"
]
current_scheme
=
None
for
fused_layer
in
fused_layers
:
# Check if one of the fused layers are defined in quant_config
matched_target
=
find_matched_target
(
layer_name
=
fused_layer
,
module
=
layer
,
targets
=
quant_config
.
target_scheme_map
.
keys
(),
fused_mapping
=
quant_config
.
packed_modules_mapping
)
# Only valid if down_proj, gate_proj, and up_proj
# are mapped to the same quant scheme in the quant_config
if
current_scheme
is
None
:
current_scheme
=
quant_config
.
target_scheme_map
.
get
(
matched_target
)
else
:
assert
current_scheme
==
quant_config
.
target_scheme_map
.
get
(
matched_target
)
weight_quant
=
quant_config
.
target_scheme_map
[
matched_target
].
get
(
"weights"
)
input_quant
=
quant_config
.
target_scheme_map
[
matched_target
].
get
(
"input_activations"
)
"input_activations"
)
if
quant_config
.
_is_wNa16_group_channel
(
weight_quant
,
input_quant
):
if
quant_config
.
_is_wNa16_group_channel
(
weight_quant
,
input_quant
):
...
@@ -246,11 +276,11 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
...
@@ -246,11 +276,11 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
return
return
# swizzle weight scales
# swizzle weight scales
layer
.
w13_
blockscale_swizzled
=
torch
.
nn
.
Parameter
(
swizzle_blockscale
(
layer
.
w13_
weight_scale
=
torch
.
nn
.
Parameter
(
swizzle_blockscale
(
layer
.
w13_weight_scale
),
layer
.
w13_weight_scale
),
requires_grad
=
False
)
requires_grad
=
False
)
layer
.
w2_
blockscale_swizzled
=
torch
.
nn
.
Parameter
(
swizzle_blockscale
(
layer
.
w2_
weight_scale
=
torch
.
nn
.
Parameter
(
swizzle_blockscale
(
layer
.
w2_weight_scale
),
layer
.
w2_weight_scale
),
requires_grad
=
False
)
requires_grad
=
False
)
...
@@ -292,6 +322,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
...
@@ -292,6 +322,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
self
,
self
,
prepare_finalize
:
mk
.
FusedMoEPrepareAndFinalize
,
prepare_finalize
:
mk
.
FusedMoEPrepareAndFinalize
,
moe
:
FusedMoEConfig
,
moe
:
FusedMoEConfig
,
layer
:
torch
.
nn
.
Module
,
)
->
mk
.
FusedMoEPermuteExpertsUnpermute
:
)
->
mk
.
FusedMoEPermuteExpertsUnpermute
:
"""Return the appropriate GEMM experts implementation."""
"""Return the appropriate GEMM experts implementation."""
experts
=
select_nvfp4_gemm_impl
(
experts
=
select_nvfp4_gemm_impl
(
...
@@ -319,6 +350,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
...
@@ -319,6 +350,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
scoring_func
:
str
=
"softmax"
,
routed_scaling_factor
:
float
=
1.0
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
activation
:
str
=
"silu"
,
...
@@ -344,6 +376,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
...
@@ -344,6 +376,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
num_expert_group
=
num_expert_group
,
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
,
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
scoring_func
=
scoring_func
,
routed_scaling_factor
=
routed_scaling_factor
,
e_score_correction_bias
=
e_score_correction_bias
,
e_score_correction_bias
=
e_score_correction_bias
,
indices_type
=
self
.
topk_indices_dtype
,
indices_type
=
self
.
topk_indices_dtype
,
)
)
...
@@ -383,8 +416,35 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
...
@@ -383,8 +416,35 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
activation
=
activation
,
activation
=
activation
,
global_num_experts
=
global_num_experts
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
expert_map
=
expert_map
,
w1_scale
=
layer
.
w13_blockscale_swizzled
,
w1_scale
=
layer
.
w13_weight_scale
,
w2_scale
=
layer
.
w2_blockscale_swizzled
,
w2_scale
=
layer
.
w2_weight_scale
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
)
elif
self
.
allow_flashinfer
:
from
vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe
import
(
# noqa: E501
flashinfer_cutlass_moe_fp4
)
assert
is_valid_flashinfer_cutlass_fused_moe
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
),
(
"Flashinfer CUTLASS Fused MoE not applicable!"
)
return
flashinfer_cutlass_moe_fp4
(
hidden_states
=
x
,
w1
=
layer
.
w13_weight
,
w2
=
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
False
,
# TODO(shuw): fix later, now output is high prec
activation
=
activation
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
w1_scale
=
layer
.
w13_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
,
g1_alphas
=
layer
.
g1_alphas
,
g2_alphas
=
layer
.
g2_alphas
,
a1_gscale
=
layer
.
w13_input_scale_quant
,
a2_gscale
=
layer
.
w2_input_scale_quant
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
)
)
...
@@ -400,8 +460,8 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
...
@@ -400,8 +460,8 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
a
=
x
,
a
=
x
,
w1_fp4
=
layer
.
w13_weight
,
w1_fp4
=
layer
.
w13_weight
,
w2_fp4
=
layer
.
w2_weight
,
w2_fp4
=
layer
.
w2_weight
,
w1_blockscale
=
layer
.
w13_
blockscale_swizzled
,
w1_blockscale
=
layer
.
w13_
weight_scale
,
w2_blockscale
=
layer
.
w2_
blockscale_swizzled
,
w2_blockscale
=
layer
.
w2_
weight_scale
,
g1_alphas
=
layer
.
g1_alphas
,
g1_alphas
=
layer
.
g1_alphas
,
g2_alphas
=
layer
.
g2_alphas
,
g2_alphas
=
layer
.
g2_alphas
,
a1_gscale
=
layer
.
w13_input_scale_quant
,
a1_gscale
=
layer
.
w13_input_scale_quant
,
...
@@ -642,11 +702,29 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
...
@@ -642,11 +702,29 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
from
vllm.model_executor.layers.fused_moe
import
fused_experts
from
vllm.model_executor.layers.fused_moe
import
fused_experts
self
.
fused_experts_func
=
fused_experts
self
.
fused_experts_func
=
fused_experts
if
self
.
use_cutlass
:
device
=
layer
.
w13_weight
.
device
# ab_strides1 and c_strides2 are the same
self
.
ab_strides1_c_strides2
=
torch
.
full
(
(
layer
.
local_num_experts
,
),
layer
.
hidden_size
,
device
=
device
,
dtype
=
torch
.
int64
)
self
.
ab_strides2
=
torch
.
full
(
(
layer
.
local_num_experts
,
),
layer
.
intermediate_size_per_partition
,
device
=
device
,
dtype
=
torch
.
int64
)
self
.
c_strides1
=
torch
.
full
(
(
layer
.
local_num_experts
,
),
2
*
layer
.
intermediate_size_per_partition
,
device
=
device
,
dtype
=
torch
.
int64
)
def
select_gemm_impl
(
def
select_gemm_impl
(
self
,
self
,
prepare_finalize
:
FusedMoEPrepareAndFinalize
,
prepare_finalize
:
FusedMoEPrepareAndFinalize
,
moe
:
FusedMoEConfig
,
moe
:
FusedMoEConfig
,
)
->
FusedMoEPermuteExpertsUnpermute
:
layer
:
torch
.
nn
.
Module
)
->
FusedMoEPermuteExpertsUnpermute
:
# cutlass path
# cutlass path
if
self
.
use_cutlass
:
if
self
.
use_cutlass
:
from
vllm.model_executor.layers.fused_moe
import
(
from
vllm.model_executor.layers.fused_moe
import
(
...
@@ -666,6 +744,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
...
@@ -666,6 +744,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
moe
.
in_dtype
,
moe
.
in_dtype
,
self
.
input_quant
.
strategy
==
QuantizationStrategy
.
TOKEN
,
self
.
input_quant
.
strategy
==
QuantizationStrategy
.
TOKEN
,
self
.
weight_quant
.
strategy
==
QuantizationStrategy
.
CHANNEL
,
self
.
weight_quant
.
strategy
==
QuantizationStrategy
.
CHANNEL
,
ab_strides1
=
self
.
ab_strides1_c_strides2
,
ab_strides2
=
self
.
ab_strides2
,
c_strides1
=
self
.
c_strides1
,
c_strides2
=
self
.
ab_strides1_c_strides2
,
)
)
else
:
else
:
logger
.
debug
(
"CutlassExpertsFp8(%s)"
,
self
.
__class__
.
__name__
)
logger
.
debug
(
"CutlassExpertsFp8(%s)"
,
self
.
__class__
.
__name__
)
...
@@ -673,6 +755,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
...
@@ -673,6 +755,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
moe
.
in_dtype
,
moe
.
in_dtype
,
self
.
input_quant
.
strategy
==
QuantizationStrategy
.
TOKEN
,
self
.
input_quant
.
strategy
==
QuantizationStrategy
.
TOKEN
,
self
.
weight_quant
.
strategy
==
QuantizationStrategy
.
CHANNEL
,
self
.
weight_quant
.
strategy
==
QuantizationStrategy
.
CHANNEL
,
ab_strides1
=
self
.
ab_strides1_c_strides2
,
ab_strides2
=
self
.
ab_strides2
,
c_strides1
=
self
.
c_strides1
,
c_strides2
=
self
.
ab_strides1_c_strides2
,
)
)
self
.
disable_expert_map
=
(
num_dispatchers
>
1
self
.
disable_expert_map
=
(
num_dispatchers
>
1
...
@@ -725,6 +811,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
...
@@ -725,6 +811,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
scoring_func
:
str
=
"softmax"
,
routed_scaling_factor
:
float
=
1.0
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
activation
:
str
=
"silu"
,
...
@@ -748,6 +835,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
...
@@ -748,6 +835,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
num_expert_group
=
num_expert_group
,
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
,
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
scoring_func
=
scoring_func
,
routed_scaling_factor
=
routed_scaling_factor
,
e_score_correction_bias
=
e_score_correction_bias
,
e_score_correction_bias
=
e_score_correction_bias
,
indices_type
=
self
.
topk_indices_dtype
,
indices_type
=
self
.
topk_indices_dtype
,
)
)
...
@@ -795,6 +883,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
...
@@ -795,6 +883,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
expert_map
=
None
if
self
.
disable_expert_map
else
expert_map
,
expert_map
=
None
if
self
.
disable_expert_map
else
expert_map
,
w1_scale
=
layer
.
w13_weight_scale
,
w1_scale
=
layer
.
w13_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
,
ab_strides1
=
self
.
ab_strides1_c_strides2
,
ab_strides2
=
self
.
ab_strides2
,
c_strides1
=
self
.
c_strides1
,
c_strides2
=
self
.
ab_strides1_c_strides2
,
a1_scale
=
layer
.
w13_input_scale
,
a1_scale
=
layer
.
w13_input_scale
,
a2_scale
=
layer
.
w2_input_scale
,
a2_scale
=
layer
.
w2_input_scale
,
)
)
...
@@ -969,6 +1061,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
...
@@ -969,6 +1061,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
scoring_func
:
str
=
"softmax"
,
routed_scaling_factor
:
float
=
1.0
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
activation
:
str
=
"silu"
,
...
@@ -996,6 +1089,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
...
@@ -996,6 +1089,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
num_expert_group
=
num_expert_group
,
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
,
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
scoring_func
=
scoring_func
,
routed_scaling_factor
=
routed_scaling_factor
,
e_score_correction_bias
=
e_score_correction_bias
,
e_score_correction_bias
=
e_score_correction_bias
,
indices_type
=
self
.
topk_indices_dtype
)
indices_type
=
self
.
topk_indices_dtype
)
...
@@ -1273,6 +1367,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
...
@@ -1273,6 +1367,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
scoring_func
:
str
=
"softmax"
,
routed_scaling_factor
:
float
=
1.0
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
activation
:
str
=
"silu"
,
...
@@ -1301,6 +1396,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
...
@@ -1301,6 +1396,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
num_expert_group
=
num_expert_group
,
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
,
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
scoring_func
=
scoring_func
,
routed_scaling_factor
=
routed_scaling_factor
,
e_score_correction_bias
=
e_score_correction_bias
,
e_score_correction_bias
=
e_score_correction_bias
,
indices_type
=
self
.
topk_indices_dtype
)
indices_type
=
self
.
topk_indices_dtype
)
...
@@ -1504,6 +1600,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
...
@@ -1504,6 +1600,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
scoring_func
:
str
=
"softmax"
,
routed_scaling_factor
:
float
=
1.0
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
activation
:
str
=
"silu"
,
...
@@ -1530,6 +1627,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
...
@@ -1530,6 +1627,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
num_expert_group
=
num_expert_group
,
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
,
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
scoring_func
=
scoring_func
,
routed_scaling_factor
=
routed_scaling_factor
,
e_score_correction_bias
=
e_score_correction_bias
,
e_score_correction_bias
=
e_score_correction_bias
,
indices_type
=
self
.
topk_indices_dtype
)
indices_type
=
self
.
topk_indices_dtype
)
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
View file @
d2b52805
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
from
.compressed_tensors_scheme
import
CompressedTensorsScheme
from
.compressed_tensors_scheme
import
CompressedTensorsScheme
from
.compressed_tensors_w4a4_nvfp4
import
CompressedTensorsW4A4Fp4
from
.compressed_tensors_w4a4_nvfp4
import
CompressedTensorsW4A4Fp4
from
.compressed_tensors_w4a8_fp8
import
CompressedTensorsW4A8Fp8
from
.compressed_tensors_w4a8_int
import
CompressedTensorsW4A8Int
from
.compressed_tensors_w4a8_int
import
CompressedTensorsW4A8Int
from
.compressed_tensors_w4a16_24
import
(
W4A16SPARSE24_SUPPORTED_BITS
,
from
.compressed_tensors_w4a16_24
import
(
W4A16SPARSE24_SUPPORTED_BITS
,
CompressedTensorsW4A16Sparse24
)
CompressedTensorsW4A16Sparse24
)
...
@@ -21,5 +22,6 @@ __all__ = [
...
@@ -21,5 +22,6 @@ __all__ = [
"CompressedTensorsW8A8Int8"
,
"CompressedTensorsW8A8Fp8"
,
"CompressedTensorsW8A8Int8"
,
"CompressedTensorsW8A8Fp8"
,
"WNA16_SUPPORTED_BITS"
,
"W4A16SPARSE24_SUPPORTED_BITS"
,
"WNA16_SUPPORTED_BITS"
,
"W4A16SPARSE24_SUPPORTED_BITS"
,
"CompressedTensors24"
,
"CompressedTensorsW4A16Fp4"
,
"CompressedTensors24"
,
"CompressedTensorsW4A16Fp4"
,
"CompressedTensorsW4A4Fp4"
,
"CompressedTensorsW4A8Int"
"CompressedTensorsW4A4Fp4"
,
"CompressedTensorsW4A8Int"
,
"CompressedTensorsW4A8Fp8"
]
]
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py
View file @
d2b52805
...
@@ -12,6 +12,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
...
@@ -12,6 +12,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme
)
CompressedTensorsScheme
)
from
vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils
import
(
# noqa: E501
from
vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils
import
(
# noqa: E501
run_nvfp4_emulations
)
run_nvfp4_emulations
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
swizzle_blockscale
)
from
vllm.model_executor.parameter
import
(
GroupQuantScaleParameter
,
from
vllm.model_executor.parameter
import
(
GroupQuantScaleParameter
,
ModelWeightParameter
,
ModelWeightParameter
,
PerTensorScaleParameter
)
PerTensorScaleParameter
)
...
@@ -83,29 +85,6 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
...
@@ -83,29 +85,6 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
weight_loader
=
weight_loader
)
weight_loader
=
weight_loader
)
layer
.
register_parameter
(
"input_global_scale"
,
input_global_scale
)
layer
.
register_parameter
(
"input_global_scale"
,
input_global_scale
)
def
swizzle_blockscale
(
self
,
scale
:
torch
.
tensor
):
assert
(
scale
.
dtype
==
torch
.
float8_e4m3fn
)
# Pad and blockwise interleave weight_scale
scale_ndim
=
scale
.
ndim
if
scale
.
ndim
==
2
:
scale
=
scale
.
unsqueeze
(
0
)
assert
scale
.
ndim
==
3
B
,
M
,
K
=
scale
.
shape
round_up_multiple
=
lambda
x
,
m
:
(
x
+
m
-
1
)
//
m
*
m
M_padded
=
round_up_multiple
(
M
,
128
)
K_padded
=
round_up_multiple
(
K
,
4
)
padded_scale
=
torch
.
zeros
((
B
,
M_padded
,
K_padded
),
dtype
=
scale
.
dtype
)
padded_scale
[:
B
,
:
M
,
:
K
]
=
scale
batches
,
rows
,
cols
=
padded_scale
.
shape
assert
rows
%
128
==
0
assert
cols
%
4
==
0
padded_scale
=
padded_scale
.
reshape
(
batches
,
rows
//
128
,
4
,
32
,
cols
//
4
,
4
)
swizzled_scale
=
padded_scale
.
permute
((
0
,
1
,
4
,
3
,
2
,
5
))
swizzled_scale
=
swizzled_scale
.
contiguous
().
cuda
()
return
(
swizzled_scale
.
reshape
(
M
,
K
)
if
scale_ndim
==
2
else
swizzled_scale
.
reshape
(
B
,
M
,
K
))
def
process_weights_after_loading
(
self
,
layer
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
)
->
None
:
global_input_scale
=
layer
.
input_global_scale
.
max
().
to
(
torch
.
float32
)
global_input_scale
=
layer
.
input_global_scale
.
max
().
to
(
torch
.
float32
)
...
@@ -133,12 +112,11 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
...
@@ -133,12 +112,11 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
torch
.
uint8
),
epilogue_tile_m
).
reshape
(
torch
.
uint8
),
epilogue_tile_m
).
reshape
(
weight_scale
.
shape
).
view
(
torch
.
float8_e4m3fn
))
weight_scale
.
shape
).
view
(
torch
.
float8_e4m3fn
))
layer
.
weight_scale_swizzled
=
Parameter
(
weight_scale
,
layer
.
weight_scale
=
Parameter
(
weight_scale
,
requires_grad
=
False
)
requires_grad
=
False
)
layer
.
weight_packed
=
Parameter
(
weight
,
requires_grad
=
False
)
layer
.
weight_packed
=
Parameter
(
weight
,
requires_grad
=
False
)
else
:
else
:
swizzled_weight_scale
=
self
.
swizzle_blockscale
(
layer
.
weight_scale
)
swizzled_weight_scale
=
swizzle_blockscale
(
layer
.
weight_scale
)
layer
.
weight_scale
_swizzled
=
Parameter
(
swizzled_weight_scale
,
layer
.
weight_scale
=
Parameter
(
swizzled_weight_scale
,
requires_grad
=
False
)
requires_grad
=
False
)
layer
.
weight_packed
=
Parameter
(
layer
.
weight_packed
.
data
,
layer
.
weight_packed
=
Parameter
(
layer
.
weight_packed
.
data
,
requires_grad
=
False
)
requires_grad
=
False
)
...
@@ -157,7 +135,7 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
...
@@ -157,7 +135,7 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
x
=
x
,
x
=
x
,
input_global_scale
=
layer
.
input_global_scale
,
input_global_scale
=
layer
.
input_global_scale
,
weight
=
layer
.
weight_packed
,
weight
=
layer
.
weight_packed
,
weight_scale_swizzled
=
layer
.
weight_scale
_swizzled
,
weight_scale_swizzled
=
layer
.
weight_scale
,
weight_global_scale
=
layer
.
weight_global_scale
)
weight_global_scale
=
layer
.
weight_global_scale
)
if
bias
is
not
None
:
if
bias
is
not
None
:
out
=
out
+
bias
out
=
out
+
bias
...
@@ -170,7 +148,7 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
...
@@ -170,7 +148,7 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
x_fp4
,
x_blockscale
=
scaled_fp4_quant
(
x
,
layer
.
input_global_scale
)
x_fp4
,
x_blockscale
=
scaled_fp4_quant
(
x
,
layer
.
input_global_scale
)
mm_args
=
(
x_fp4
,
layer
.
weight_packed
,
x_blockscale
,
mm_args
=
(
x_fp4
,
layer
.
weight_packed
,
x_blockscale
,
layer
.
weight_scale
_swizzled
,
layer
.
alpha
,
output_dtype
)
layer
.
weight_scale
,
layer
.
alpha
,
output_dtype
)
if
self
.
backend
==
"flashinfer-trtllm"
:
if
self
.
backend
==
"flashinfer-trtllm"
:
out
=
flashinfer_scaled_fp4_mm
(
*
mm_args
,
backend
=
"trtllm"
)
out
=
flashinfer_scaled_fp4_mm
(
*
mm_args
,
backend
=
"trtllm"
)
elif
self
.
backend
==
"flashinfer-cutlass"
:
elif
self
.
backend
==
"flashinfer-cutlass"
:
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_fp8.py
0 → 100644
View file @
d2b52805
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Callable
,
Optional
import
torch
from
compressed_tensors.quantization
import
ActivationOrdering
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
CompressedTensorsScheme
)
from
vllm.model_executor.layers.quantization.kernels.mixed_precision
import
(
MPLinearLayerConfig
,
choose_mp_linear_kernel
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
marlin_repeat_scales_on_all_ranks
)
# yapf conflicts with isort for this block
# yapf: disable
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
ChannelQuantScaleParameter
,
GroupQuantScaleParameter
,
PackedvLLMParameter
)
# yapf: enable
from
vllm.scalar_type
import
scalar_types
# Module-level logger named after this module.
logger = init_logger(__name__)

__all__ = ["CompressedTensorsW4A8Fp8"]

# Weight bit-width -> scalar type used for the quantized weights.
W4A8_SUPPORTED_TYPES_MAP = {4: scalar_types.int4}

# Bit-widths accepted by the scheme (the keys of the map above).
W4A8_SUPPORTED_BITS = list(W4A8_SUPPORTED_TYPES_MAP)
class CompressedTensorsW4A8Fp8(CompressedTensorsScheme):
    """Compressed-tensors linear scheme with 4-bit weights and FP8 activations.

    The scheme validates its quantization arguments, registers the packed
    weight / scale parameters on the layer, and delegates repacking and the
    matmul itself to the kernel class returned by
    ``choose_mp_linear_kernel``.
    """

    # Names of kernel classes that have already been logged. Deliberately a
    # class-level (shared) set so each backend is reported only once, no
    # matter how many layers instantiate this scheme.
    _kernel_backends_being_used: set[str] = set()

    def __init__(self,
                 strategy: str,
                 num_bits: int,
                 group_size: Optional[int] = None,
                 symmetric: Optional[bool] = True,
                 actorder: Optional[ActivationOrdering] = None):
        """Validate and store the quantization arguments.

        Args:
            strategy: quantization strategy; only ``"group"`` is accepted.
            num_bits: weight bit-width; must be a key of
                ``W4A8_SUPPORTED_TYPES_MAP``.
            group_size: quantization group size; only ``128`` is accepted
                (``None`` is stored as ``-1`` and therefore rejected).
            symmetric: whether quantization is symmetric; zero points are
                requested from the kernel when this is falsy.
            actorder: activation ordering; ``ActivationOrdering.GROUP``
                sets ``has_g_idx``.

        Raises:
            ValueError: if the strategy/group size combination or
                ``num_bits`` is unsupported.
        """
        self.strategy = strategy
        self.symmetric = symmetric
        self.group_size = -1 if group_size is None else group_size
        self.has_g_idx = actorder == ActivationOrdering.GROUP

        if self.group_size != 128 or self.strategy != "group":
            raise ValueError("W4A8 kernels require group quantization "
                             "with group size 128")

        if num_bits not in W4A8_SUPPORTED_TYPES_MAP:
            raise ValueError(
                f"Unsupported num_bits = {num_bits}. "
                f"Supported num_bits = {W4A8_SUPPORTED_TYPES_MAP.keys()}")

        # Derived only after num_bits has been validated, so an invalid value
        # such as 0 yields the ValueError above rather than a
        # ZeroDivisionError. Weights are packed into 32-bit words.
        self.pack_factor = 32 // num_bits
        self.quant_type = W4A8_SUPPORTED_TYPES_MAP[num_bits]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum compute capability for this scheme (90)."""
        # hopper
        return 90

    def create_weights(self, layer: torch.nn.Module, output_size: int,
                       input_size: int, output_partition_sizes: list[int],
                       input_size_per_partition: int,
                       params_dtype: torch.dtype, weight_loader: Callable,
                       **kwargs):
        """Select a kernel and register the quantized parameters on ``layer``.

        Registers ``weight_packed``, ``weight_scale``, ``weight_shape`` and
        ``weight_chan_scale`` on ``layer`` and stores the instantiated kernel
        on ``self.kernel``.

        Args:
            layer: module that receives the parameters.
            output_size: full (unpartitioned) output dimension.
            input_size: full (unpartitioned) input dimension.
            output_partition_sizes: output sizes held by this partition;
                their sum is the per-partition output dimension.
            input_size_per_partition: input dimension held by this partition.
            params_dtype: passed to the kernel config as its output type.
            weight_loader: loader callable attached to every parameter.
            **kwargs: accepted for interface compatibility; unused here.
        """
        output_size_per_partition = sum(output_partition_sizes)

        mp_linear_kernel_config = MPLinearLayerConfig(
            full_weight_shape=(input_size, output_size),
            partition_weight_shape=(input_size_per_partition,
                                    output_size_per_partition),
            weight_type=self.quant_type,
            act_type=torch.float8_e4m3fn,  # always use fp8(e4m3)
            group_size=self.group_size,
            zero_points=not self.symmetric,
            has_g_idx=self.has_g_idx,
            out_type=params_dtype)

        kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)

        # Log each selected backend only the first time it is seen.
        if kernel_type.__name__ not in self._kernel_backends_being_used:
            logger.info("Using %s for CompressedTensorsW4A8Fp8",
                        kernel_type.__name__)
            self._kernel_backends_being_used.add(kernel_type.__name__)

        # If group_size is -1, we are in channelwise case.
        # (__init__ currently rejects anything but 128, so this fallback is
        # not reachable through the constructor.)
        group_size = self.group_size if self.group_size != -1 else input_size
        row_parallel = (input_size != input_size_per_partition)
        partition_scales = not marlin_repeat_scales_on_all_ranks(
            self.has_g_idx, self.group_size, row_parallel)

        # Number of scale groups along the input dimension: the full input
        # size by default, the per-partition size when scales are partitioned.
        scales_and_zp_size = input_size // group_size

        if partition_scales:
            assert input_size_per_partition % group_size == 0
            scales_and_zp_size = input_size_per_partition // group_size

        # Quantized weights, pack_factor values per int32 along dim 1.
        weight = PackedvLLMParameter(input_dim=1,
                                     output_dim=0,
                                     weight_loader=weight_loader,
                                     packed_factor=self.pack_factor,
                                     packed_dim=1,
                                     data=torch.empty(
                                         output_size_per_partition,
                                         input_size_per_partition //
                                         self.pack_factor,
                                         dtype=torch.int32,
                                     ))

        # TODO(czhu): allocate the packed fp8 scales memory here?
        # the scales will be expanded by 8x via `cutlass_pack_scale_fp8`
        weight_scale_args = {
            "weight_loader":
            weight_loader,
            "data":
            torch.empty(
                output_size_per_partition,
                scales_and_zp_size,
                dtype=torch.float8_e4m3fn,
            )
        }

        if not partition_scales:
            weight_scale = ChannelQuantScaleParameter(output_dim=0,
                                                      **weight_scale_args)
        else:
            weight_scale = GroupQuantScaleParameter(output_dim=0,
                                                    input_dim=1,
                                                    **weight_scale_args)

        # A 2D array defining the original shape of the weights
        # before packing
        weight_shape = BasevLLMParameter(data=torch.empty(2,
                                                          dtype=torch.int64),
                                         weight_loader=weight_loader)

        # per-channel scales
        weight_chan_scale = ChannelQuantScaleParameter(
            data=torch.empty((output_size_per_partition, 1),
                             dtype=torch.float32),
            output_dim=0,
            weight_loader=weight_loader)

        layer.register_parameter("weight_packed", weight)
        layer.register_parameter("weight_scale", weight_scale)
        layer.register_parameter("weight_shape", weight_shape)
        layer.register_parameter("weight_chan_scale", weight_chan_scale)

        # NOTE(review): the kernel is told the zero-point / g_idx parameter
        # names, but this method registers neither "weight_zero_point" nor
        # "weight_g_idx" -- confirm the kernel never reads them for this
        # scheme.
        self.kernel = kernel_type(mp_linear_kernel_config,
                                  w_q_param_name="weight_packed",
                                  w_s_param_name="weight_scale",
                                  w_zp_param_name="weight_zero_point",
                                  w_gidx_param_name="weight_g_idx")

    # Checkpoints are serialized in compressed-tensors format, which is
    # different from the format the kernel may want. Handle repacking here.
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Delegate post-load weight processing to the selected kernel."""
        self.kernel.process_weights_after_loading(layer)

    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Run the selected kernel on ``x`` and return its output."""
        return self.kernel.apply_weights(layer, x, bias)
Prev
1
…
22
23
24
25
26
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment