Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f0b2da72
Unverified
Commit
f0b2da72
authored
Feb 14, 2025
by
Michael Goin
Committed by
GitHub
Feb 13, 2025
Browse files
Expand MLA to support most types of quantization (#13181)
parent
f2b20fe4
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
61 additions
and
132 deletions
+61
-132
vllm/attention/backends/mla/utils.py
vllm/attention/backends/mla/utils.py
+26
-45
vllm/config.py
vllm/config.py
+1
-31
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+34
-56
No files found.
vllm/attention/backends/mla/utils.py
View file @
f0b2da72
...
...
@@ -26,7 +26,7 @@ from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
apply_fp8_linear_generic
,
current_platform_fp8_dtype
,
is_fp8
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
scaled_dequantize
,
scaled_quantize
)
scaled_quantize
)
from
vllm.model_executor.layers.rotary_embedding
import
(
DeepseekScalingRotaryEmbedding
,
RotaryEmbedding
)
...
...
@@ -220,16 +220,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
.
view
(
-
1
,
self
.
num_heads
,
self
.
kv_lora_rank
)
def
process_weights_after_loading
(
self
,
act_dtype
:
torch
.
dtype
):
def
is_layer_fp8
(
layer
:
LinearBase
)
->
bool
:
return
isinstance
(
layer
.
quant_method
,
Fp8LinearMethod
)
or
\
(
isinstance
(
layer
.
quant_method
,
CompressedTensorsLinearMethod
)
\
and
isinstance
(
layer
.
scheme
,
CompressedTensorsW8A8Fp8
))
def
quantization_scheme_supported
(
layer
:
LinearBase
)
->
bool
:
return
isinstance
(
layer
.
quant_method
,
UnquantizedLinearMethod
)
or
\
is_layer_fp8
(
layer
)
# TODO(lucas) This is very gross, we need a more wide scale refactor of
# all the FP8 code with a more standard way of
# defining schemes/group-shapes, we should also potentially force
...
...
@@ -239,7 +229,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
def
get_scale_group_shapes_for_fp8
(
layer
:
LinearBase
)
->
\
Tuple
[
Tuple
[
int
,
int
],
Tuple
[
int
,
int
]]:
if
isinstance
(
layer
.
quant_method
,
Fp8LinearMethod
):
if
layer
.
quant_method
.
block_quant
is
not
None
:
if
layer
.
quant_method
.
block_quant
:
weight_block_size
=
\
layer
.
quant_method
.
quant_config
.
weight_block_size
# per-token-group (1, X), block-quantized (X, Y)
...
...
@@ -267,41 +257,32 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
f
"
{
layer
.
quant_method
}
, please run with VLLM_MLA_DISABLE=1"
)
def
get_scales
(
layer
:
LinearBase
)
->
torch
.
Tensor
:
if
hasattr
(
layer
,
"weight_scale_inv"
):
return
layer
.
weight_scale_inv
return
layer
.
weight_scale
def
get_and_maybe_dequant_weights
(
layer
:
LinearBase
):
if
is_layer_fp8
(
layer
):
if
isinstance
(
layer
.
quant_method
,
\
CompressedTensorsLinearMethod
)
and
\
isinstance
(
layer
.
scheme
,
CompressedTensorsW8A8Fp8
):
# NOTE(lucas): not sure why but `CompressedTensorsW8A8Fp8`
# seems to store weights as (input, output) instead of
# (output, input) so we need to transpose
weight
=
layer
.
weight
.
T
# standardize to (output, input)
def
get_layer_weight
(
layer
):
if
hasattr
(
layer
,
"weight"
):
return
layer
.
weight
elif
hasattr
(
layer
,
"qweight"
):
return
layer
.
qweight
else
:
weight
=
layer
.
weight
_
,
weight_scale_group_shape
=
\
get_scale_group_shapes_for_fp8
(
layer
)
scales
=
get_scales
(
layer
)
raise
AttributeError
(
f
"Layer '
{
layer
}
' has neither weight nor qweight"
)
return
scaled_dequantize
(
weight
,
scales
,
weight_scale_group_shape
)
else
:
def
get_and_maybe_dequant_weights
(
layer
:
LinearBase
):
if
not
isinstance
(
layer
.
quant_method
,
UnquantizedLinearMethod
):
# NOTE: This should only be used offline, since it's O(N^3)
eye
=
torch
.
eye
(
layer
.
input_size_per_partition
,
dtype
=
act_dtype
,
device
=
get_layer_weight
(
layer
).
device
)
dequant_weights
=
layer
.
quant_method
.
apply
(
layer
,
eye
,
bias
=
None
)
del
eye
# standardize to (output, input)
return
dequant_weights
.
T
return
layer
.
weight
if
not
(
quantization_scheme_supported
(
self
.
kv_b_proj
)
and
\
quantization_scheme_supported
(
self
.
q_proj
)
and
\
quantization_scheme_supported
(
self
.
o_proj
)):
raise
NotImplementedError
(
"Only FP8 and UnquantizedLinearMethod are supported for MLA"
", please run with VLLM_MLA_DISABLE=1"
)
weight_dtype
=
self
.
kv_b_proj
.
weight
.
dtype
assert
self
.
o_proj
.
weight
.
dtype
==
weight_dtype
assert
self
.
q_proj
.
weight
.
dtype
==
weight_dtype
weight_dtype
=
get_layer_weight
(
self
.
kv_b_proj
).
dtype
assert
get_layer_weight
(
self
.
o_proj
).
dtype
==
weight_dtype
assert
get_layer_weight
(
self
.
q_proj
).
dtype
==
weight_dtype
kv_b_proj_weight
=
get_and_maybe_dequant_weights
(
self
.
kv_b_proj
).
T
assert
kv_b_proj_weight
.
shape
==
(
...
...
vllm/config.py
View file @
f0b2da72
...
...
@@ -991,37 +991,7 @@ class ModelConfig:
@
property
def
use_mla
(
self
)
->
bool
:
if
not
self
.
is_deepseek_mla
or
envs
.
VLLM_MLA_DISABLE
:
return
False
if
self
.
quantization
is
not
None
and
self
.
quantization
not
in
[
\
"fp8"
,
"compressed-tensors"
]:
logger
.
warning
(
"MLA is not supported with %s quantization. "
"Disabling MLA."
,
self
.
quantization
)
return
False
# If using a "compressed-tensors" checkpoint, check that all groups
# have fp8 for both weights and activations.
if
self
.
quantization
==
"compressed-tensors"
:
quant_config
=
self
.
_parse_quant_hf_config
()
for
group_name
,
cfg
in
quant_config
.
get
(
"config_groups"
,
{
""
:
{}
}).
items
():
act_cfg
=
cfg
.
get
(
"input_activations"
,
{})
act_type
=
None
if
act_cfg
is
None
else
act_cfg
.
get
(
"type"
,
""
)
w_cfg
=
cfg
.
get
(
"weights"
,
{})
w_type
=
None
if
w_cfg
is
None
else
w_cfg
.
get
(
"type"
,
""
)
if
act_type
!=
"fp8"
or
w_type
!=
"fp8"
:
logger
.
warning
(
"compressed-tensors MLA support requires fp8 "
"activations and weights in group '%s', but got "
"activations type '%s' and weights type '%s'.
\n
"
"Full config: %s"
,
group_name
,
act_type
,
w_type
,
quant_config
)
return
False
return
True
return
self
.
is_deepseek_mla
and
not
envs
.
VLLM_MLA_DISABLE
@
property
def
supported_runner_types
(
self
)
->
Set
[
RunnerType
]:
...
...
vllm/model_executor/model_loader/loader.py
View file @
f0b2da72
...
...
@@ -153,6 +153,30 @@ def _initialize_model(
return
model_class
(
**
kwargs
)
def _process_weights_after_loading(model: nn.Module, model_config: ModelConfig,
                                   target_device: torch.device) -> None:
    """Run the post-load weight hooks of a freshly loaded model.

    Two passes are made over ``model.named_modules()``:

    1. Every module whose ``quant_method`` attribute is a
       ``QuantizeMethodBase`` gets
       ``quant_method.process_weights_after_loading(module)`` invoked inside
       ``device_loading_context(module, target_device)``.
    2. Every ``Attention`` module exposing ``process_weights_after_loading``
       gets that hook invoked with ``model_config.dtype``.

    Args:
        model: The model whose submodules are visited.
        model_config: Supplies the ``dtype`` handed to attention hooks.
        target_device: Device passed to ``device_loading_context`` while the
            quantization hooks run.
    """
    # Pass 1: quantization hooks. Quant methods that post-process weights
    # (repacking, quantizing, etc.) expect the parameters to be on the global
    # target device. The loading context covers the CPU-offloading case:
    # parameters are moved onto the device for processing and back off
    # afterwards.
    for _name, submodule in model.named_modules():
        method = getattr(submodule, "quant_method", None)
        if not isinstance(method, QuantizeMethodBase):
            continue
        with device_loading_context(submodule, target_device):
            method.process_weights_after_loading(submodule)

    # Pass 2: attention hooks (currently only used by MLA).
    # NOTE: This intentionally runs after all quantization hooks so the
    # weights can easily be decompressed for MLA.
    for _name, submodule in model.named_modules():
        if not isinstance(submodule, Attention):
            continue
        if not hasattr(submodule, "process_weights_after_loading"):
            continue
        # TODO(lucas): see if there is a way to unify the signatures
        # of process_weights_after_loading
        submodule.process_weights_after_loading(model_config.dtype)
class
BaseModelLoader
(
ABC
):
"""Base class for model loaders."""
...
...
@@ -376,7 +400,6 @@ class DefaultModelLoader(BaseModelLoader):
def
load_model
(
self
,
vllm_config
:
VllmConfig
)
->
nn
.
Module
:
device_config
=
vllm_config
.
device_config
model_config
=
vllm_config
.
model_config
target_device
=
torch
.
device
(
device_config
.
device
)
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
target_device
:
...
...
@@ -394,23 +417,8 @@ class DefaultModelLoader(BaseModelLoader):
"Following weights were not initialized from "
f
"checkpoint:
{
weights_not_loaded
}
"
)
for
_
,
module
in
model
.
named_modules
():
quant_method
=
getattr
(
module
,
"quant_method"
,
None
)
if
isinstance
(
quant_method
,
QuantizeMethodBase
):
# When quant methods need to process weights after loading
# (for repacking, quantizing, etc), they expect parameters
# to be on the global target device. This scope is for the
# case where cpu offloading is used, where we will move the
# parameters onto device for processing and back off after.
with
device_loading_context
(
module
,
target_device
):
quant_method
.
process_weights_after_loading
(
module
)
if
isinstance
(
module
,
Attention
)
and
\
hasattr
(
module
,
"process_weights_after_loading"
):
# Attention modules may need to process weights after loading;
# currently only used by MLA
# TODO(lucas): see if there is a way to unify the signatures
# of process_weights_after_loading
module
.
process_weights_after_loading
(
model_config
.
dtype
)
_process_weights_after_loading
(
model
,
model_config
,
target_device
)
return
model
.
eval
()
...
...
@@ -429,29 +437,15 @@ class DummyModelLoader(BaseModelLoader):
def
load_model
(
self
,
vllm_config
:
VllmConfig
)
->
nn
.
Module
:
device_config
=
vllm_config
.
device_config
model_config
=
vllm_config
.
model_config
target_device
=
torch
.
device
(
device_config
.
device
)
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
t
orch
.
device
(
device_config
.
device
)
:
with
t
arget_
device
:
model
=
_initialize_model
(
vllm_config
=
vllm_config
)
# NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights.
initialize_dummy_weights
(
model
)
for
_
,
module
in
model
.
named_modules
():
quant_method
=
getattr
(
module
,
"quant_method"
,
None
)
if
quant_method
is
not
None
:
# When quant methods need to process weights after loading
# (for repacking, quantizing, etc), they expect parameters
# to be on the global target device. This scope is for the
# case where cpu offloading is used, where we will move the
# parameters onto device for processing and back off after.
with
device_loading_context
(
module
,
torch
.
device
(
device_config
.
device
)):
quant_method
.
process_weights_after_loading
(
module
)
if
isinstance
(
module
,
Attention
)
and
\
hasattr
(
module
,
"process_weights_after_loading"
):
# Attention modules may need to process weights after loading;
# currently only used by MLA
module
.
process_weights_after_loading
(
model_config
.
dtype
)
_process_weights_after_loading
(
model
,
model_config
,
target_device
)
return
model
.
eval
()
...
...
@@ -632,6 +626,7 @@ class ShardedStateLoader(BaseModelLoader):
def
load_model
(
self
,
vllm_config
:
VllmConfig
)
->
nn
.
Module
:
device_config
=
vllm_config
.
device_config
model_config
=
vllm_config
.
model_config
target_device
=
torch
.
device
(
device_config
.
device
)
from
safetensors.torch
import
safe_open
from
vllm.distributed
import
get_tensor_model_parallel_rank
...
...
@@ -640,18 +635,10 @@ class ShardedStateLoader(BaseModelLoader):
model_config
.
revision
)
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
t
orch
.
device
(
device_config
.
device
)
:
with
t
arget_
device
:
model
=
_initialize_model
(
vllm_config
=
vllm_config
)
for
_
,
module
in
model
.
named_modules
():
quant_method
=
getattr
(
module
,
"quant_method"
,
None
)
if
quant_method
is
not
None
:
quant_method
.
process_weights_after_loading
(
module
)
if
isinstance
(
module
,
Attention
)
and
\
hasattr
(
module
,
"process_weights_after_loading"
):
# Attention modules may need to process weights after loading;
# currently only used by MLA
module
.
process_weights_after_loading
(
model_config
.
dtype
)
_process_weights_after_loading
(
model
,
model_config
,
target_device
)
rank
=
get_tensor_model_parallel_rank
()
pattern
=
os
.
path
.
join
(
local_model_path
,
...
...
@@ -1401,16 +1388,7 @@ class RunaiModelStreamerLoader(BaseModelLoader):
self
.
_get_weights_iterator
(
model_weights
,
model_config
.
revision
))
for
_
,
module
in
model
.
named_modules
():
quant_method
=
getattr
(
module
,
"quant_method"
,
None
)
if
quant_method
is
not
None
:
with
device_loading_context
(
module
,
target_device
):
quant_method
.
process_weights_after_loading
(
module
)
if
isinstance
(
module
,
Attention
)
and
\
hasattr
(
module
,
"process_weights_after_loading"
):
# Attention modules may need to process weights after loading;
# currently only used by MLA
module
.
process_weights_after_loading
(
model_config
.
dtype
)
_process_weights_after_loading
(
model
,
model_config
,
target_device
)
return
model
.
eval
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment