Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9532c498
Unverified
Commit
9532c498
authored
Mar 14, 2025
by
Lucas Wilkinson
Committed by
GitHub
Mar 13, 2025
Browse files
[Attention] MLA get rid of materialization (#14770)
Signed-off-by:
Lucas Wilkinson
<
lwilkins@redhat.com
>
parent
0c2af17c
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
117 additions
and
499 deletions
+117
-499
vllm/attention/backends/mla/common.py
vllm/attention/backends/mla/common.py
+57
-210
vllm/envs.py
vllm/envs.py
+0
-19
vllm/model_executor/layers/quantization/utils/fp8_utils.py
vllm/model_executor/layers/quantization/utils/fp8_utils.py
+2
-57
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+58
-213
No files found.
vllm/attention/backends/mla/common.py
View file @
9532c498
...
@@ -54,9 +54,9 @@ W_DQ project h_t to q_c shape [H, Lq]
...
@@ -54,9 +54,9 @@ W_DQ project h_t to q_c shape [H, Lq]
W_UQ project q_c to q_nope shape [Lq, N * P]
W_UQ project q_c to q_nope shape [Lq, N * P]
W_QR project q_c to q_pe shape [Lq, N * R]
W_QR project q_c to q_pe shape [Lq, N * R]
W_DKV project h_t to kv_c shape [H, Lkv]
W_DKV project h_t to kv_c shape [H, Lkv]
W_UK project kv_c to k_nope shape [Lkv, N
*
P]
W_UK project kv_c to k_nope shape [Lkv, N
,
P]
W_KR project h_t to k_pe shape [H,
N *
R]
W_KR project h_t to k_pe shape [H, R]
W_UV project kv_c to v shape [Lkv, N
*
V]
W_UV project kv_c to v shape [Lkv, N
,
V]
W_O project v to h_t shape [N * V, H]
W_O project v to h_t shape [N * V, H]
...
@@ -69,8 +69,8 @@ new_kv_c = h_t @ W_DKV
...
@@ -69,8 +69,8 @@ new_kv_c = h_t @ W_DKV
new_k_pe = RoPE(h_t @ W_KR)
new_k_pe = RoPE(h_t @ W_KR)
kv_c = torch.cat([new_kv_c, cache_kv_c], dim=0)
kv_c = torch.cat([new_kv_c, cache_kv_c], dim=0)
k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0)
k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0)
k_nope = (kv_c @ W_UK).view(Skv, N, P)
k_nope = (kv_c @ W_UK
.view(Lkv, N * P)
).view(Skv, N, P)
v = (kv_c @ W_UV).view(Skv, N, V)
v = (kv_c @ W_UV
.view(Lkv, N * V)
).view(Skv, N, V)
// MHA with QK headdim = P + R
// MHA with QK headdim = P + R
// V headdim = V
// V headdim = V
...
@@ -90,20 +90,10 @@ NOTE: in the actual code,
...
@@ -90,20 +90,10 @@ NOTE: in the actual code,
## Data-Movement Friendly Approach (i.e. "_forward_decode"):
## Data-Movement Friendly Approach (i.e. "_forward_decode"):
Ahead of time, compute:
% this projects from q_c to [Sq, N * Lkv]
W_UQ_UK = einsum("qnp,knp -> qnk"
W_UQ.view(Lq, N, P), W_UK.view(Lkv, N, P)
).view(Lkv, N * Lkv)
% this projects from attn output [Sq, N * Lkv] to [Sq, H]
W_UV_O = einsum("knv,nvh -> nkh"
W_UV.view(Lkv, N, V), W_O.view(N, V, H)
).view(N * Lkv, H)
Runtime
Runtime
q_c = h_t @ W_DQ
q_c = h_t @ W_DQ
q_latent = q_c @ W_UQ_UK.view(Sq, N, Lkv)
q_nope = (q_c @ W_UQ).view(-1, N, P)
ql_nope = einsum("snp,lnp->snl", q_nope, W_UK)
q_pe = RoPE(q_c @ W_QR).view(Sq, N, R)
q_pe = RoPE(q_c @ W_QR).view(Sq, N, R)
new_kv_c = h_t @ W_DKV
new_kv_c = h_t @ W_DKV
new_k_pe = RoPE(h_t @ W_KR)
new_k_pe = RoPE(h_t @ W_KR)
...
@@ -116,11 +106,13 @@ k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0)
...
@@ -116,11 +106,13 @@ k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0)
// NOTE: this is less compute-friendly since Lkv > P
// NOTE: this is less compute-friendly since Lkv > P
// but is more data-movement friendly since it's MQA vs MHA
// but is more data-movement friendly since it's MQA vs MHA
spda_o = scaled_dot_product_attention(
spda_o = scaled_dot_product_attention(
torch.cat([q
_latent
, q_pe], dim=-1),
torch.cat([q
l_nope
, q_pe], dim=-1),
torch.cat([kv_c, k_pe], dim=-1),
torch.cat([kv_c, k_pe], dim=-1),
kv_c
kv_c
)
)
return spda_o.reshape(-1, N * Lkv) @ W_UV_O
o = einsum("snl,lnv->snv", spda_o.reshape(-1, N, Lkv), W_UV)
return o.view(-1, N * V) @ W_O
## Chunked Prefill
## Chunked Prefill
...
@@ -146,8 +138,8 @@ q_nope = (q_c @ W_UQ).view(Sq, N, P)
...
@@ -146,8 +138,8 @@ q_nope = (q_c @ W_UQ).view(Sq, N, P)
q_pe = RoPE(q_c @ W_QR).view(Sq, N, R)
q_pe = RoPE(q_c @ W_QR).view(Sq, N, R)
new_kv_c = h_t @ W_DKV
new_kv_c = h_t @ W_DKV
new_k_pe = RoPE(h_t @ W_KR)
new_k_pe = RoPE(h_t @ W_KR)
new_k_nope = (new_kv_c @ W_UK).view(Sq, N, P)
new_k_nope = (new_kv_c @ W_UK
.view(Lkv, N * P)
).view(Sq, N, P)
new_v = (new_kv_c @ W_UV).view(Sq, N, V)
new_v = (new_kv_c @ W_UV
.view(Lkv, N * V)
).view(Sq, N, V)
// MHA between queries and new KV
// MHA between queries and new KV
// with QK headdim = P + R
// with QK headdim = P + R
...
@@ -202,7 +194,6 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Tuple,
...
@@ -202,7 +194,6 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Tuple,
Type
,
TypeVar
)
Type
,
TypeVar
)
import
torch
import
torch
from
compressed_tensors.quantization
import
QuantizationStrategy
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm
import
envs
from
vllm
import
envs
...
@@ -215,20 +206,9 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
...
@@ -215,20 +206,9 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
get_flash_attn_version
,
get_flash_attn_version
,
is_block_tables_empty
)
is_block_tables_empty
)
from
vllm.attention.ops.triton_merge_attn_states
import
merge_attn_states
from
vllm.attention.ops.triton_merge_attn_states
import
merge_attn_states
from
vllm.distributed
import
(
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
)
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearBase
,
RowParallelLinear
,
LinearBase
,
RowParallelLinear
,
UnquantizedLinearMethod
)
UnquantizedLinearMethod
)
from
vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors
import
(
# noqa: E501
CompressedTensorsLinearMethod
)
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
CompressedTensorsW8A8Fp8
)
from
vllm.model_executor.layers.quantization.fp8
import
Fp8LinearMethod
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
Fp8LinearGenericOp
,
is_fp8
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
scaled_quantize
)
from
vllm.model_executor.layers.rotary_embedding
import
(
from
vllm.model_executor.layers.rotary_embedding
import
(
DeepseekScalingRotaryEmbedding
,
RotaryEmbedding
)
DeepseekScalingRotaryEmbedding
,
RotaryEmbedding
)
from
vllm.multimodal
import
MultiModalPlaceholderMap
from
vllm.multimodal
import
MultiModalPlaceholderMap
...
@@ -1057,7 +1037,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -1057,7 +1037,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
self
.
kv_b_proj
=
kv_b_proj
self
.
kv_b_proj
=
kv_b_proj
self
.
o_proj
=
o_proj
self
.
o_proj
=
o_proj
self
.
triton_fa_func
=
triton_attention
self
.
triton_fa_func
=
triton_attention
self
.
fp8_linear_generic
=
Fp8LinearGenericOp
()
# Handle the differences between the flash_attn_varlen from flash_attn
# Handle the differences between the flash_attn_varlen from flash_attn
# and the one from vllm_flash_attn. The former is used on ROCm and the
# and the one from vllm_flash_attn. The former is used on ROCm and the
...
@@ -1070,79 +1049,28 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -1070,79 +1049,28 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
fa_version
=
self
.
vllm_flash_attn_version
)
fa_version
=
self
.
vllm_flash_attn_version
)
def
_v_up_proj_and_o_proj
(
self
,
x
):
def
_v_up_proj_and_o_proj
(
self
,
x
):
if
envs
.
VLLM_MLA_PERFORM_MATRIX_ABSORPTION
:
# Convert from (B, N, L) to (N, B, L)
if
is_fp8
(
self
.
W_UV_O
):
x
=
x
.
view
(
-
1
,
self
.
num_heads
,
self
.
kv_lora_rank
).
transpose
(
0
,
1
)
output_parallel
=
self
.
fp8_linear_generic
.
apply
(
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
x
.
flatten
(
start_dim
=
1
),
self
.
W_UV_O
,
self
.
W_UV_O_scales
,
x
=
torch
.
bmm
(
x
,
self
.
W_UV
)
self
.
reqaunt_input_group_shape
,
# Convert from (N, B, V) to (B, N * V)
self
.
reqaunt_weight_group_shape
)
x
=
x
.
transpose
(
0
,
1
).
reshape
(
-
1
,
self
.
num_heads
*
self
.
v_head_dim
)
else
:
return
self
.
o_proj
(
x
)[
0
]
output_parallel
=
torch
.
matmul
(
x
.
flatten
(
start_dim
=
1
),
self
.
W_UV_O
)
# Return `ql_nope`, `q_pe`
if
self
.
tp_size
>
1
:
output
=
tensor_model_parallel_all_reduce
(
output_parallel
)
else
:
output
=
output_parallel
return
output
else
:
x
=
torch
.
einsum
(
"bnl,lnv->bnv"
,
x
,
self
.
W_UV
)
return
self
.
o_proj
(
x
.
reshape
(
-
1
,
self
.
num_heads
*
self
.
v_head_dim
))[
0
]
def
_q_proj_and_k_up_proj
(
self
,
x
):
def
_q_proj_and_k_up_proj
(
self
,
x
):
if
envs
.
VLLM_MLA_PERFORM_MATRIX_ABSORPTION
:
q_nope
,
q_pe
=
self
.
q_proj
(
x
)[
0
]
\
if
is_fp8
(
self
.
W_Q_UK
):
.
view
(
-
1
,
self
.
num_heads
,
self
.
qk_head_dim
)
\
return
self
.
fp8_linear_generic
.
apply
(
.
split
([
self
.
qk_nope_head_dim
,
self
.
qk_rope_head_dim
],
dim
=-
1
)
x
,
self
.
W_Q_UK
,
self
.
W_Q_UK_scales
,
self
.
reqaunt_input_group_shape
,
self
.
reqaunt_weight_group_shape
).
view
(
-
1
,
self
.
num_heads
,
self
.
kv_lora_rank
)
return
torch
.
matmul
(
x
,
self
.
W_Q_UK
)
\
.
view
(
-
1
,
self
.
num_heads
,
self
.
kv_lora_rank
)
else
:
x
=
torch
.
matmul
(
x
,
self
.
W_Q
)
\
.
view
(
-
1
,
self
.
num_heads
,
self
.
qk_nope_head_dim
)
return
torch
.
einsum
(
"bnp,lnp->bnl"
,
x
,
self
.
W_UK
)
\
.
view
(
-
1
,
self
.
num_heads
,
self
.
kv_lora_rank
)
def
process_weights_after_loading
(
self
,
act_dtype
:
torch
.
dtype
):
# Convert from (B, N, P) to (N, B, P)
q_nope
=
q_nope
.
transpose
(
0
,
1
)
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
ql_nope
=
torch
.
bmm
(
q_nope
,
self
.
W_UK_T
)
# Convert from (N, B, L) to (B, N, L)
return
ql_nope
.
transpose
(
0
,
1
),
q_pe
# TODO(lucas) This is very gross, we need a more wide scale refactor of
def
process_weights_after_loading
(
self
,
act_dtype
:
torch
.
dtype
):
# all the FP8 code with a more standard way of
# defining schemes/group-shapes, we should also potentially force
# quant_methods to support a decompress function
#
# returns input_group_shape, weight_group_shape
def
get_scale_group_shapes_for_fp8
(
layer
:
LinearBase
)
->
\
Tuple
[
Tuple
[
int
,
int
],
Tuple
[
int
,
int
]]:
if
isinstance
(
layer
.
quant_method
,
Fp8LinearMethod
):
if
layer
.
quant_method
.
block_quant
:
weight_block_size
=
\
layer
.
quant_method
.
quant_config
.
weight_block_size
# per-token-group (1, X), block-quantized (X, Y)
return
(
1
,
weight_block_size
[
-
1
]),
weight_block_size
else
:
return
(
-
1
,
-
1
),
(
-
1
,
-
1
)
# per-tensor, per-tensor
elif
isinstance
(
layer
.
quant_method
,
CompressedTensorsLinearMethod
)
\
and
isinstance
(
layer
.
scheme
,
CompressedTensorsW8A8Fp8
):
# this is hacky but we always assume the for
# CompressedTensorsW8A8Fp8 the input is dynamic per-token
# we ignore if it is static-per-tensor since we are going to
# requantize after later anyways
strategy
=
layer
.
scheme
.
strategy
if
strategy
==
QuantizationStrategy
.
TENSOR
:
return
(
1
,
-
1
),
(
-
1
,
-
1
)
# per-token, per-tensor
elif
strategy
==
QuantizationStrategy
.
CHANNEL
:
return
(
1
,
-
1
),
(
-
1
,
1
)
# per-token, per-channel
else
:
raise
NotImplementedError
(
f
"QuantizationStrategy.
{
strategy
}
is not supported for "
"fp8 MLA, please run with VLLM_MLA_DISABLE=1"
)
else
:
raise
NotImplementedError
(
"Can't determine scale group shapes for "
f
"
{
layer
.
quant_method
}
, please run with VLLM_MLA_DISABLE=1"
)
def
get_layer_weight
(
layer
):
def
get_layer_weight
(
layer
):
WEIGHT_NAMES
=
(
"weight"
,
"qweight"
,
"weight_packed"
)
WEIGHT_NAMES
=
(
"weight"
,
"qweight"
,
"weight_packed"
)
...
@@ -1167,10 +1095,9 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -1167,10 +1095,9 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
return
dequant_weights
.
T
return
dequant_weights
.
T
return
layer
.
weight
return
layer
.
weight
weight_dtype
=
get_layer_weight
(
self
.
kv_b_proj
).
dtype
# we currently do not have quantized bmm's which are needed for
assert
get_layer_weight
(
self
.
o_proj
).
dtype
==
weight_dtype
# `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform
assert
get_layer_weight
(
self
.
q_proj
).
dtype
==
weight_dtype
# the bmm's in 16-bit, the extra memory overhead of this is fairly low
kv_b_proj_weight
=
get_and_maybe_dequant_weights
(
self
.
kv_b_proj
).
T
kv_b_proj_weight
=
get_and_maybe_dequant_weights
(
self
.
kv_b_proj
).
T
assert
kv_b_proj_weight
.
shape
==
(
assert
kv_b_proj_weight
.
shape
==
(
self
.
kv_lora_rank
,
self
.
kv_lora_rank
,
...
@@ -1189,89 +1116,10 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -1189,89 +1116,10 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
W_UK
,
W_UV
=
kv_b_proj_weight
.
split
(
W_UK
,
W_UV
=
kv_b_proj_weight
.
split
(
[
self
.
qk_nope_head_dim
,
self
.
v_head_dim
],
dim
=-
1
)
[
self
.
qk_nope_head_dim
,
self
.
v_head_dim
],
dim
=-
1
)
q_proj_weight
=
get_and_maybe_dequant_weights
(
self
.
q_proj
).
T
\
# Convert from (L, N, V) to (N, L, V)
.
view
(
-
1
,
self
.
num_heads
,
self
.
qk_head_dim
)
self
.
W_UV
=
W_UV
.
transpose
(
0
,
1
)
# Convert from (L, N, P) to (N, P, L)
# can be W_Q or W_UQ depending q_lora_rank, the former if
self
.
W_UK_T
=
W_UK
.
permute
(
1
,
2
,
0
)
# q_lora_rank is None, the latter otherwise. From the Attention backend
# perspective though we call these both W_Q and rely on the layer
# to pass in the correct matrix
W_Q
=
q_proj_weight
[...,
:
self
.
qk_nope_head_dim
]
self
.
W_QR
=
q_proj_weight
[...,
self
.
qk_nope_head_dim
:]
\
.
flatten
(
start_dim
=
1
).
contiguous
()
# W_QR is small so for simplicity we dont bother requantizing it
self
.
W_QR
=
self
.
W_QR
.
to
(
act_dtype
)
if
envs
.
VLLM_MLA_PERFORM_MATRIX_ABSORPTION
:
requantization_enabled
=
not
envs
.
VLLM_MLA_DISABLE_REQUANTIZATION
if
is_fp8
(
weight_dtype
)
and
requantization_enabled
:
# This assumes it wise to requantize using the same group shapes
# (i.e. strategy, per-tensor, per-channel, block etc.) that the
# weights were originally quantized
requant_input_group_shape
,
requant_weight_group_shape
=
\
get_scale_group_shapes_for_fp8
(
self
.
q_proj
)
assert
(
requant_input_group_shape
,
requant_weight_group_shape
)
\
==
get_scale_group_shapes_for_fp8
(
self
.
kv_b_proj
)
assert
(
requant_input_group_shape
,
requant_weight_group_shape
)
\
==
get_scale_group_shapes_for_fp8
(
self
.
o_proj
)
self
.
reqaunt_input_group_shape
=
requant_input_group_shape
self
.
reqaunt_weight_group_shape
=
requant_weight_group_shape
#
# Perform matrix-absorption following
# https://github.com/flashinfer-ai/flashinfer/pull/551
# for decode, as a result we end up with absorbed weights for decode
# and another copy of raw weights for prefill.
#
self
.
W_UK
,
self
.
W_UV
=
kv_b_proj_weight
.
split
(
[
self
.
qk_nope_head_dim
,
self
.
v_head_dim
],
dim
=-
1
)
# We absorb `W_UK` into `W_Q` resulting in either W_Q_UK or W_UQ_UK
# depending q_lora_rank, the former if q_lora_rank is None, the
# latter otherwise
# basically if q_lora_rank is none we are absorbing into q_proj
# instead of UQ
W_Q_UK
=
torch
.
einsum
(
"qnd,lnd -> qnl"
,
W_Q
,
W_UK
)
\
.
flatten
(
start_dim
=
1
).
contiguous
()
if
is_fp8
(
weight_dtype
)
and
requantization_enabled
:
W_Q_UK
,
W_Q_UK_scales
=
scaled_quantize
(
W_Q_UK
,
self
.
reqaunt_weight_group_shape
,
quant_dtype
=
current_platform
.
fp8_dtype
())
# For FP8 save the transpose so we can use
# `apply_w8a8_block_fp8_linear` directly
self
.
W_Q_UK
=
W_Q_UK
.
T
.
contiguous
()
self
.
W_Q_UK_scales
=
W_Q_UK_scales
.
T
.
contiguous
()
else
:
self
.
W_Q_UK
=
W_Q_UK
.
to
(
act_dtype
)
W_O
=
get_and_maybe_dequant_weights
(
self
.
o_proj
)
\
.
view
(
-
1
,
self
.
num_heads
,
self
.
v_head_dim
)
W_UV_O
=
torch
.
einsum
(
"lnd,hnd -> nlh"
,
W_UV
,
W_O
)
\
.
flatten
(
start_dim
=
0
,
end_dim
=
1
).
contiguous
()
if
is_fp8
(
weight_dtype
)
and
requantization_enabled
:
W_UV_O
,
W_UV_O_scales
=
scaled_quantize
(
W_UV_O
,
self
.
reqaunt_weight_group_shape
,
quant_dtype
=
current_platform
.
fp8_dtype
())
# For FP8 save the transpose so we can use
# `apply_w8a8_block_fp8_linear` directly
self
.
W_UV_O
=
W_UV_O
.
T
.
contiguous
()
self
.
W_UV_O_scales
=
W_UV_O_scales
.
T
.
contiguous
()
else
:
self
.
W_UV_O
=
W_UV_O
.
to
(
act_dtype
)
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
else
:
if
is_fp8
(
weight_dtype
):
raise
NotImplementedError
(
"Currently fp8 requires matrix absorption"
)
self
.
W_UV
=
W_UV
self
.
W_UK
=
W_UK
self
.
W_Q
=
W_Q
.
flatten
(
start_dim
=
1
)
def
_compute_prefill_context
(
def
_compute_prefill_context
(
self
,
self
,
...
@@ -1471,7 +1319,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -1471,7 +1319,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
@
abstractmethod
@
abstractmethod
def
_forward_decode
(
def
_forward_decode
(
self
,
self
,
q_nope
:
torch
.
Tensor
,
q
l
_nope
:
torch
.
Tensor
,
q_pe
:
torch
.
Tensor
,
q_pe
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
attn_metadata
:
T
,
attn_metadata
:
T
,
...
@@ -1525,9 +1373,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -1525,9 +1373,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
prefill_k_c_normed
=
k_c_normed
[:
num_prefill_tokens
]
prefill_k_c_normed
=
k_c_normed
[:
num_prefill_tokens
]
if
has_decode
:
if
has_decode
:
decode_q_nope
=
self
.
_q_proj_and_k_up_proj
(
decode_hs_or_q_c
)
decode_ql_nope
,
decode_q_pe
=
\
decode_q_pe
=
torch
.
matmul
(
decode_hs_or_q_c
,
self
.
W_QR
)
\
self
.
_q_proj_and_k_up_proj
(
decode_hs_or_q_c
)
.
view
(
-
1
,
self
.
num_heads
,
self
.
qk_rope_head_dim
)
decode_q_pe
[...],
decode_k_pe
[...]
=
self
.
rotary_emb
(
decode_q_pe
[...],
decode_k_pe
[...]
=
self
.
rotary_emb
(
decode_input_positions
,
decode_q_pe
,
decode_k_pe
)
decode_input_positions
,
decode_q_pe
,
decode_k_pe
)
...
@@ -1561,6 +1408,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -1561,6 +1408,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
if
has_decode
:
if
has_decode
:
output
[
num_prefill_tokens
:]
=
self
.
_forward_decode
(
output
[
num_prefill_tokens
:]
=
self
.
_forward_decode
(
decode_q_nope
,
decode_q_pe
,
kv_cache
,
attn_metadata
)
decode_q
l
_nope
,
decode_q_pe
,
kv_cache
,
attn_metadata
)
return
output
return
output
vllm/envs.py
View file @
9532c498
...
@@ -84,8 +84,6 @@ if TYPE_CHECKING:
...
@@ -84,8 +84,6 @@ if TYPE_CHECKING:
VLLM_SERVER_DEV_MODE
:
bool
=
False
VLLM_SERVER_DEV_MODE
:
bool
=
False
VLLM_V1_OUTPUT_PROC_CHUNK_SIZE
:
int
=
128
VLLM_V1_OUTPUT_PROC_CHUNK_SIZE
:
int
=
128
VLLM_MLA_DISABLE
:
bool
=
False
VLLM_MLA_DISABLE
:
bool
=
False
VLLM_MLA_PERFORM_MATRIX_ABSORPTION
:
bool
=
True
VLLM_MLA_DISABLE_REQUANTIZATION
:
bool
=
False
VLLM_MLA_CUDA_MEM_ALIGN_KV_CACHE
:
bool
=
True
VLLM_MLA_CUDA_MEM_ALIGN_KV_CACHE
:
bool
=
True
VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON
:
bool
=
False
VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON
:
bool
=
False
VLLM_RAY_PER_WORKER_GPUS
:
float
=
1.0
VLLM_RAY_PER_WORKER_GPUS
:
float
=
1.0
...
@@ -563,23 +561,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -563,23 +561,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_MLA_DISABLE"
:
"VLLM_MLA_DISABLE"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_MLA_DISABLE"
,
"0"
))),
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_MLA_DISABLE"
,
"0"
))),
# Flag that can control whether or not we perform matrix-absorption for MLA
# decode, i.e. absorb W_UK into W_Q/W_UK and W_UV into W_O, absorbing the
# matrices reduces the runtime FLOPs needed to compute MLA but requires
# storing more weights, W_Q_UK and W_UV_O, so can increase memory usage,
# the is enabled by default
"VLLM_MLA_PERFORM_MATRIX_ABSORPTION"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_MLA_PERFORM_MATRIX_ABSORPTION"
,
"1"
))),
# When running MLA with matrix-absorption enabled and fp8 quantized weights
# we perform the matrix-absorption in float32 precision, after the matrices
# are absorbed we requantize the weights back to fp8, this flag can be used
# to disable the requantization step, and instead convert the absorbed
# matrices to match the activation type. This can lead to higher memory and
# compute usage but better preserves the accuracy of the original model.
"VLLM_MLA_DISABLE_REQUANTIZATION"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_MLA_DISABLE_REQUANTIZATION"
,
"0"
))),
# If set, vLLM will use the Triton implementation of moe_align_block_size,
# If set, vLLM will use the Triton implementation of moe_align_block_size,
# i.e. moe_align_block_size_triton in fused_moe.py.
# i.e. moe_align_block_size_triton in fused_moe.py.
"VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON"
:
"VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON"
:
...
...
vllm/model_executor/layers/quantization/utils/fp8_utils.py
View file @
9532c498
...
@@ -13,10 +13,9 @@ import triton.language as tl
...
@@ -13,10 +13,9 @@ import triton.language as tl
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
_normalize_quant_group_shape
,
scaled_dequantize
)
scaled_dequantize
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
CUTLASS_BLOCK_FP8_SUPPORTED
,
Fp8LinearOp
,
cutlass_block_fp8_supported
,
CUTLASS_BLOCK_FP8_SUPPORTED
)
cutlass_fp8_supported
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
direct_register_custom_op
from
vllm.utils
import
direct_register_custom_op
...
@@ -101,60 +100,6 @@ direct_register_custom_op(
...
@@ -101,60 +100,6 @@ direct_register_custom_op(
)
)
# Unify the interface between `apply_w8a8_block_fp8_linear` and
# `apply_fp8_linear`
# NOTE(lucas): this is quite messy, we should think through this more formally
# TODO(luka): unify this better
# https://github.com/vllm-project/vllm/issues/14397
class
Fp8LinearGenericOp
:
def
__init__
(
self
,
cutlass_fp8_supported
:
bool
=
cutlass_fp8_supported
(),
cutlass_block_fp8_supported
:
bool
=
cutlass_block_fp8_supported
(),
):
self
.
cutlass_block_fp8_supported
=
cutlass_block_fp8_supported
self
.
fp8_linear
=
Fp8LinearOp
(
cutlass_fp8_supported
=
cutlass_fp8_supported
)
def
apply
(
self
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight_scale
:
torch
.
Tensor
,
input_group_shape
:
Tuple
[
int
,
int
],
weight_group_shape
:
Tuple
[
int
,
int
],
input_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
# static scale if one
)
->
torch
.
Tensor
:
# View input as 2D matrix for fp8 methods
input
=
input
.
view
(
-
1
,
input
.
shape
[
-
1
])
weight_group_shape
=
_normalize_quant_group_shape
(
\
weight
,
weight_group_shape
)
input_group_shape
=
_normalize_quant_group_shape
(
input
,
input_group_shape
)
def
is_dim_blocked
(
dim
,
shape
,
group_shape
):
return
group_shape
<
shape
[
dim
]
and
group_shape
>
1
if
is_dim_blocked
(
0
,
weight
.
shape
,
weight_group_shape
[
0
])
\
and
is_dim_blocked
(
1
,
weight
.
shape
,
weight_group_shape
[
1
])
and
\
input_group_shape
==
(
1
,
weight_group_shape
[
1
]):
return
apply_w8a8_block_fp8_linear
(
input
,
weight
,
list
(
weight_group_shape
),
weight_scale
,
cutlass_block_fp8_supported
=
self
.
cutlass_block_fp8_supported
)
else
:
# Despite having linear in the name it doesn't conform to
# `torch.nn.functional.linear` which is defined as
# `input @ weight.T` so we explicitly transpose the weight matrix
return
self
.
fp8_linear
.
apply
(
input
,
weight
.
T
,
weight_scale
.
T
,
use_per_token_if_dynamic
=
\
(
input_group_shape
==
(
1
,
input
.
shape
[
1
])))
def
input_to_float8
(
def
input_to_float8
(
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
dtype
:
Optional
[
torch
.
dtype
]
=
None
...
...
vllm/v1/attention/backends/mla/common.py
View file @
9532c498
...
@@ -54,9 +54,9 @@ W_DQ project h_t to q_c shape [H, Lq]
...
@@ -54,9 +54,9 @@ W_DQ project h_t to q_c shape [H, Lq]
W_UQ project q_c to q_nope shape [Lq, N * P]
W_UQ project q_c to q_nope shape [Lq, N * P]
W_QR project q_c to q_pe shape [Lq, N * R]
W_QR project q_c to q_pe shape [Lq, N * R]
W_DKV project h_t to kv_c shape [H, Lkv]
W_DKV project h_t to kv_c shape [H, Lkv]
W_UK project kv_c to k_nope shape [Lkv, N
*
P]
W_UK project kv_c to k_nope shape [Lkv, N
,
P]
W_KR project h_t to k_pe shape [H,
N *
R]
W_KR project h_t to k_pe shape [H, R]
W_UV project kv_c to v shape [Lkv, N
*
V]
W_UV project kv_c to v shape [Lkv, N
,
V]
W_O project v to h_t shape [N * V, H]
W_O project v to h_t shape [N * V, H]
...
@@ -69,8 +69,8 @@ new_kv_c = h_t @ W_DKV
...
@@ -69,8 +69,8 @@ new_kv_c = h_t @ W_DKV
new_k_pe = RoPE(h_t @ W_KR)
new_k_pe = RoPE(h_t @ W_KR)
kv_c = torch.cat([new_kv_c, cache_kv_c], dim=0)
kv_c = torch.cat([new_kv_c, cache_kv_c], dim=0)
k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0)
k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0)
k_nope = (kv_c @ W_UK).view(Skv, N, P)
k_nope = (kv_c @ W_UK
.view(Lkv, N * P)
).view(Skv, N, P)
v = (kv_c @ W_UV).view(Skv, N, V)
v = (kv_c @ W_UV
.view(Lkv, N * V)
).view(Skv, N, V)
// MHA with QK headdim = P + R
// MHA with QK headdim = P + R
// V headdim = V
// V headdim = V
...
@@ -90,20 +90,10 @@ NOTE: in the actual code,
...
@@ -90,20 +90,10 @@ NOTE: in the actual code,
## Data-Movement Friendly Approach (i.e. "_forward_decode"):
## Data-Movement Friendly Approach (i.e. "_forward_decode"):
Ahead of time, compute:
% this projects from q_c to [Sq, N * Lkv]
W_UQ_UK = einsum("qnp,knp -> qnk"
W_UQ.view(Lq, N, P), W_UK.view(Lkv, N, P)
).view(Lkv, N * Lkv)
% this projects from attn output [Sq, N * Lkv] to [Sq, H]
W_UV_O = einsum("knv,nvh -> nkh"
W_UV.view(Lkv, N, V), W_O.view(N, V, H)
).view(N * Lkv, H)
Runtime
Runtime
q_c = h_t @ W_DQ
q_c = h_t @ W_DQ
q_latent = q_c @ W_UQ_UK.view(Sq, N, Lkv)
q_nope = (q_c @ W_UQ).view(-1, N, P)
ql_nope = einsum("snp,lnp->snl", q_nope, W_UK)
q_pe = RoPE(q_c @ W_QR).view(Sq, N, R)
q_pe = RoPE(q_c @ W_QR).view(Sq, N, R)
new_kv_c = h_t @ W_DKV
new_kv_c = h_t @ W_DKV
new_k_pe = RoPE(h_t @ W_KR)
new_k_pe = RoPE(h_t @ W_KR)
...
@@ -116,11 +106,13 @@ k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0)
...
@@ -116,11 +106,13 @@ k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0)
// NOTE: this is less compute-friendly since Lkv > P
// NOTE: this is less compute-friendly since Lkv > P
// but is more data-movement friendly since it's MQA vs MHA
// but is more data-movement friendly since it's MQA vs MHA
spda_o = scaled_dot_product_attention(
spda_o = scaled_dot_product_attention(
torch.cat([q
_latent
, q_pe], dim=-1),
torch.cat([q
l_nope
, q_pe], dim=-1),
torch.cat([kv_c, k_pe], dim=-1),
torch.cat([kv_c, k_pe], dim=-1),
kv_c
kv_c
)
)
return spda_o.reshape(-1, N * Lkv) @ W_UV_O
o = einsum("snl,lnv->snv", spda_o.reshape(-1, N, Lkv), W_UV)
return o.view(-1, N * V) @ W_O
## Chunked Prefill
## Chunked Prefill
...
@@ -146,8 +138,8 @@ q_nope = (q_c @ W_UQ).view(Sq, N, P)
...
@@ -146,8 +138,8 @@ q_nope = (q_c @ W_UQ).view(Sq, N, P)
q_pe = RoPE(q_c @ W_QR).view(Sq, N, R)
q_pe = RoPE(q_c @ W_QR).view(Sq, N, R)
new_kv_c = h_t @ W_DKV
new_kv_c = h_t @ W_DKV
new_k_pe = RoPE(h_t @ W_KR)
new_k_pe = RoPE(h_t @ W_KR)
new_k_nope = (new_kv_c @ W_UK).view(Sq, N, P)
new_k_nope = (new_kv_c @ W_UK
.view(Lkv, N * P)
).view(Sq, N, P)
new_v = (new_kv_c @ W_UV).view(Sq, N, V)
new_v = (new_kv_c @ W_UV
.view(Lkv, N * V)
).view(Sq, N, V)
// MHA between queries and new KV
// MHA between queries and new KV
// with QK headdim = P + R
// with QK headdim = P + R
...
@@ -198,30 +190,17 @@ from dataclasses import dataclass
...
@@ -198,30 +190,17 @@ from dataclasses import dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
Generic
,
Optional
,
TypeVar
from
typing
import
TYPE_CHECKING
,
Any
,
Generic
,
Optional
,
TypeVar
import
torch
import
torch
from
compressed_tensors.quantization
import
QuantizationStrategy
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm
import
envs
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionLayer
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionLayer
,
AttentionMetadata
,
AttentionMetadata
,
MLAAttentionImpl
)
MLAAttentionImpl
)
from
vllm.attention.backends.utils
import
get_flash_attn_version
from
vllm.attention.backends.utils
import
get_flash_attn_version
from
vllm.attention.ops.triton_merge_attn_states
import
merge_attn_states
from
vllm.attention.ops.triton_merge_attn_states
import
merge_attn_states
from
vllm.distributed
import
(
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearBase
,
RowParallelLinear
,
LinearBase
,
RowParallelLinear
,
UnquantizedLinearMethod
)
UnquantizedLinearMethod
)
from
vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors
import
(
# noqa: E501
CompressedTensorsLinearMethod
)
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
CompressedTensorsW8A8Fp8
)
from
vllm.model_executor.layers.quantization.fp8
import
Fp8LinearMethod
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
Fp8LinearGenericOp
,
is_fp8
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
scaled_quantize
)
from
vllm.model_executor.layers.rotary_embedding
import
RotaryEmbedding
from
vllm.model_executor.layers.rotary_embedding
import
RotaryEmbedding
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
cdiv
,
round_down
from
vllm.utils
import
cdiv
,
round_down
...
@@ -646,7 +625,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -646,7 +625,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
self
.
kv_b_proj
=
kv_b_proj
self
.
kv_b_proj
=
kv_b_proj
self
.
o_proj
=
o_proj
self
.
o_proj
=
o_proj
self
.
vllm_flash_attn_version
=
get_flash_attn_version
()
self
.
vllm_flash_attn_version
=
get_flash_attn_version
()
self
.
fp8_linear_generic
=
Fp8LinearGenericOp
()
# Handle the differences between the flash_attn_varlen from flash_attn
# Handle the differences between the flash_attn_varlen from flash_attn
# and the one from vllm_flash_attn. The former is used on ROCm and the
# and the one from vllm_flash_attn. The former is used on ROCm and the
...
@@ -658,88 +636,37 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -658,88 +636,37 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
fa_version
=
self
.
vllm_flash_attn_version
)
fa_version
=
self
.
vllm_flash_attn_version
)
def
_v_up_proj_and_o_proj
(
self
,
x
):
def
_v_up_proj_and_o_proj
(
self
,
x
):
if
envs
.
VLLM_MLA_PERFORM_MATRIX_ABSORPTION
:
# Convert from (B, N, L) to (N, B, L)
if
is_fp8
(
self
.
W_UV_O
):
x
=
x
.
view
(
-
1
,
self
.
num_heads
,
self
.
kv_lora_rank
).
transpose
(
0
,
1
)
output_parallel
=
self
.
fp8_linear_generic
.
apply
(
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
x
.
flatten
(
start_dim
=
1
),
self
.
W_UV_O
,
self
.
W_UV_O_scales
,
x
=
torch
.
bmm
(
x
,
self
.
W_UV
)
self
.
reqaunt_input_group_shape
,
# Convert from (N, B, V) to (B, N * V)
self
.
reqaunt_weight_group_shape
)
x
=
x
.
transpose
(
0
,
1
).
reshape
(
-
1
,
self
.
num_heads
*
self
.
v_head_dim
)
else
:
return
self
.
o_proj
(
x
)[
0
]
output_parallel
=
torch
.
matmul
(
x
.
flatten
(
start_dim
=
1
),
self
.
W_UV_O
)
# Return `ql_nope`, `q_pe`
if
self
.
tp_size
>
1
:
output
=
tensor_model_parallel_all_reduce
(
output_parallel
)
else
:
output
=
output_parallel
return
output
else
:
x
=
torch
.
einsum
(
"bnl,lnv->bnv"
,
x
,
self
.
W_UV
)
return
self
.
o_proj
(
x
.
reshape
(
-
1
,
self
.
num_heads
*
self
.
v_head_dim
))[
0
]
def
_q_proj_and_k_up_proj
(
self
,
x
):
def
_q_proj_and_k_up_proj
(
self
,
x
):
if
envs
.
VLLM_MLA_PERFORM_MATRIX_ABSORPTION
:
q_nope
,
q_pe
=
self
.
q_proj
(
x
)[
0
]
\
if
is_fp8
(
self
.
W_Q_UK
):
.
view
(
-
1
,
self
.
num_heads
,
self
.
qk_head_dim
)
\
return
self
.
fp8_linear_generic
.
apply
(
.
split
([
self
.
qk_nope_head_dim
,
self
.
qk_rope_head_dim
],
dim
=-
1
)
x
,
self
.
W_Q_UK
,
self
.
W_Q_UK_scales
,
self
.
reqaunt_input_group_shape
,
self
.
reqaunt_weight_group_shape
).
view
(
-
1
,
self
.
num_heads
,
self
.
kv_lora_rank
)
return
torch
.
matmul
(
x
,
self
.
W_Q_UK
)
\
.
view
(
-
1
,
self
.
num_heads
,
self
.
kv_lora_rank
)
else
:
x
=
torch
.
matmul
(
x
,
self
.
W_Q
)
\
.
view
(
-
1
,
self
.
num_heads
,
self
.
qk_nope_head_dim
)
return
torch
.
einsum
(
"bnp,lnp->bnl"
,
x
,
self
.
W_UK
)
\
.
view
(
-
1
,
self
.
num_heads
,
self
.
kv_lora_rank
)
def
process_weights_after_loading
(
self
,
act_dtype
:
torch
.
dtype
):
# Convert from (B, N, P) to (N, B, P)
q_nope
=
q_nope
.
transpose
(
0
,
1
)
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
ql_nope
=
torch
.
bmm
(
q_nope
,
self
.
W_UK_T
)
# Convert from (N, B, L) to (B, N, L)
return
ql_nope
.
transpose
(
0
,
1
),
q_pe
# TODO(lucas) This is very gross, we need a more wide scale refactor of
def
process_weights_after_loading
(
self
,
act_dtype
:
torch
.
dtype
):
# all the FP8 code with a more standard way of
# defining schemes/group-shapes, we should also potentially force
# quant_methods to support a decompress function
#
# returns input_group_shape, weight_group_shape
def
get_scale_group_shapes_for_fp8
(
layer
:
LinearBase
)
->
\
tuple
[
tuple
[
int
,
int
],
tuple
[
int
,
int
]]:
if
isinstance
(
layer
.
quant_method
,
Fp8LinearMethod
):
if
layer
.
quant_method
.
block_quant
:
weight_block_size
=
\
layer
.
quant_method
.
quant_config
.
weight_block_size
# per-token-group (1, X), block-quantized (X, Y)
return
(
1
,
weight_block_size
[
-
1
]),
weight_block_size
else
:
return
(
-
1
,
-
1
),
(
-
1
,
-
1
)
# per-tensor, per-tensor
elif
isinstance
(
layer
.
quant_method
,
CompressedTensorsLinearMethod
)
\
and
isinstance
(
layer
.
scheme
,
CompressedTensorsW8A8Fp8
):
# this is hacky, but we always assume that for
# CompressedTensorsW8A8Fp8 the input is dynamic per-token;
# we ignore whether it is static-per-tensor since we are going to
# requantize later anyway
strategy
=
layer
.
scheme
.
strategy
if
strategy
==
QuantizationStrategy
.
TENSOR
:
return
(
1
,
-
1
),
(
-
1
,
-
1
)
# per-token, per-tensor
elif
strategy
==
QuantizationStrategy
.
CHANNEL
:
return
(
1
,
-
1
),
(
-
1
,
1
)
# per-token, per-channel
else
:
raise
NotImplementedError
(
f
"QuantizationStrategy.
{
strategy
}
is not supported for "
"fp8 MLA, please run with VLLM_MLA_DISABLE=1"
)
else
:
raise
NotImplementedError
(
"Can't determine scale group shapes for "
f
"
{
layer
.
quant_method
}
, please run with VLLM_MLA_DISABLE=1"
)
def
get_layer_weight
(
layer
):
def
get_layer_weight
(
layer
):
if
hasattr
(
layer
,
"weight"
):
WEIGHT_NAMES
=
(
"weight"
,
"qweight"
,
"weight_packed"
)
return
layer
.
weight
for
attr
in
WEIGHT_NAMES
:
elif
hasattr
(
layer
,
"qweight"
):
if
hasattr
(
layer
,
attr
):
return
layer
.
qweight
return
getattr
(
layer
,
attr
)
else
:
raise
AttributeError
(
raise
AttributeError
(
f
"Layer '
{
layer
}
' has neither weight nor qweight"
)
f
"Layer '
{
layer
}
' has no recognized weight attribute:"
f
"
{
WEIGHT_NAMES
}
."
)
def
get_and_maybe_dequant_weights
(
layer
:
LinearBase
):
def
get_and_maybe_dequant_weights
(
layer
:
LinearBase
):
if
not
isinstance
(
layer
.
quant_method
,
UnquantizedLinearMethod
):
if
not
isinstance
(
layer
.
quant_method
,
UnquantizedLinearMethod
):
...
@@ -755,10 +682,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -755,10 +682,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
return
dequant_weights
.
T
return
dequant_weights
.
T
return
layer
.
weight
return
layer
.
weight
weight_dtype
=
get_layer_weight
(
self
.
kv_b_proj
).
dtype
# we currently do not have quantized bmm's which are needed for
assert
get_layer_weight
(
self
.
o_proj
).
dtype
==
weight_dtype
# `W_UV` and `W_UK_T`, so we just store fp16/bf16 copies and perform
assert
get_layer_weight
(
self
.
q_proj
).
dtype
==
weight_dtype
# the bmm's in 16-bit, the extra memory overhead of this is fairly low
kv_b_proj_weight
=
get_and_maybe_dequant_weights
(
self
.
kv_b_proj
).
T
kv_b_proj_weight
=
get_and_maybe_dequant_weights
(
self
.
kv_b_proj
).
T
assert
kv_b_proj_weight
.
shape
==
(
assert
kv_b_proj_weight
.
shape
==
(
self
.
kv_lora_rank
,
self
.
kv_lora_rank
,
...
@@ -777,89 +703,10 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -777,89 +703,10 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
W_UK
,
W_UV
=
kv_b_proj_weight
.
split
(
W_UK
,
W_UV
=
kv_b_proj_weight
.
split
(
[
self
.
qk_nope_head_dim
,
self
.
v_head_dim
],
dim
=-
1
)
[
self
.
qk_nope_head_dim
,
self
.
v_head_dim
],
dim
=-
1
)
q_proj_weight
=
get_and_maybe_dequant_weights
(
self
.
q_proj
).
T
\
# Convert from (L, N, V) to (N, L, V)
.
view
(
-
1
,
self
.
num_heads
,
self
.
qk_head_dim
)
self
.
W_UV
=
W_UV
.
transpose
(
0
,
1
)
# Convert from (L, N, P) to (N, P, L)
# can be W_Q or W_UQ depending on q_lora_rank, the former if
self
.
W_UK_T
=
W_UK
.
permute
(
1
,
2
,
0
)
# q_lora_rank is None, the latter otherwise. From the Attention backend
# perspective though we call these both W_Q and rely on the layer
# to pass in the correct matrix
W_Q
=
q_proj_weight
[...,
:
self
.
qk_nope_head_dim
]
self
.
W_QR
=
q_proj_weight
[...,
self
.
qk_nope_head_dim
:]
\
.
flatten
(
start_dim
=
1
).
contiguous
()
# W_QR is small so for simplicity we don't bother requantizing it
self
.
W_QR
=
self
.
W_QR
.
to
(
act_dtype
)
if
envs
.
VLLM_MLA_PERFORM_MATRIX_ABSORPTION
:
requantization_enabled
=
not
envs
.
VLLM_MLA_DISABLE_REQUANTIZATION
if
is_fp8
(
weight_dtype
)
and
requantization_enabled
:
# This assumes it is wise to requantize using the same group shapes
# (i.e. strategy, per-tensor, per-channel, block etc.) that the
# weights were originally quantized
requant_input_group_shape
,
requant_weight_group_shape
=
\
get_scale_group_shapes_for_fp8
(
self
.
q_proj
)
assert
(
requant_input_group_shape
,
requant_weight_group_shape
)
\
==
get_scale_group_shapes_for_fp8
(
self
.
kv_b_proj
)
assert
(
requant_input_group_shape
,
requant_weight_group_shape
)
\
==
get_scale_group_shapes_for_fp8
(
self
.
o_proj
)
self
.
reqaunt_input_group_shape
=
requant_input_group_shape
self
.
reqaunt_weight_group_shape
=
requant_weight_group_shape
#
# Perform matrix-absorption following
# https://github.com/flashinfer-ai/flashinfer/pull/551
# for decode, as a result we end up with absorbed weights for decode
# and another copy of raw weights for prefill.
#
self
.
W_UK
,
self
.
W_UV
=
kv_b_proj_weight
.
split
(
[
self
.
qk_nope_head_dim
,
self
.
v_head_dim
],
dim
=-
1
)
# We absorb `W_UK` into `W_Q` resulting in either W_Q_UK or W_UQ_UK
# depending on q_lora_rank, the former if q_lora_rank is None, the
# latter otherwise
# basically if q_lora_rank is none we are absorbing into q_proj
# instead of UQ
W_Q_UK
=
torch
.
einsum
(
"qnd,lnd -> qnl"
,
W_Q
,
W_UK
)
\
.
flatten
(
start_dim
=
1
).
contiguous
()
if
is_fp8
(
weight_dtype
)
and
requantization_enabled
:
W_Q_UK
,
W_Q_UK_scales
=
scaled_quantize
(
W_Q_UK
,
self
.
reqaunt_weight_group_shape
,
quant_dtype
=
current_platform
.
fp8_dtype
())
# For FP8 save the transpose so we can use
# `apply_w8a8_block_fp8_linear` directly
self
.
W_Q_UK
=
W_Q_UK
.
T
.
contiguous
()
self
.
W_Q_UK_scales
=
W_Q_UK_scales
.
T
.
contiguous
()
else
:
self
.
W_Q_UK
=
W_Q_UK
.
to
(
act_dtype
)
W_O
=
get_and_maybe_dequant_weights
(
self
.
o_proj
)
\
.
view
(
-
1
,
self
.
num_heads
,
self
.
v_head_dim
)
W_UV_O
=
torch
.
einsum
(
"lnd,hnd -> nlh"
,
W_UV
,
W_O
)
\
.
flatten
(
start_dim
=
0
,
end_dim
=
1
).
contiguous
()
if
is_fp8
(
weight_dtype
)
and
requantization_enabled
:
W_UV_O
,
W_UV_O_scales
=
scaled_quantize
(
W_UV_O
,
self
.
reqaunt_weight_group_shape
,
quant_dtype
=
current_platform
.
fp8_dtype
())
# For FP8 save the transpose so we can use
# `apply_w8a8_block_fp8_linear` directly
self
.
W_UV_O
=
W_UV_O
.
T
.
contiguous
()
self
.
W_UV_O_scales
=
W_UV_O_scales
.
T
.
contiguous
()
else
:
self
.
W_UV_O
=
W_UV_O
.
to
(
act_dtype
)
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
else
:
if
is_fp8
(
weight_dtype
):
raise
NotImplementedError
(
"Currently fp8 requires matrix absorption"
)
self
.
W_UV
=
W_UV
self
.
W_UK
=
W_UK
self
.
W_Q
=
W_Q
.
flatten
(
start_dim
=
1
)
def
_compute_prefill_context
(
def
_compute_prefill_context
(
self
,
self
,
...
@@ -998,7 +845,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -998,7 +845,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
@
abstractmethod
@
abstractmethod
def
_forward_decode
(
def
_forward_decode
(
self
,
self
,
q_nope
:
torch
.
Tensor
,
q
l
_nope
:
torch
.
Tensor
,
q_pe
:
torch
.
Tensor
,
q_pe
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
attn_metadata
:
M
,
attn_metadata
:
M
,
...
@@ -1051,10 +898,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -1051,10 +898,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
if
has_decode
:
if
has_decode
:
assert
attn_metadata
.
decode
is
not
None
assert
attn_metadata
.
decode
is
not
None
decode_q_nope
=
self
.
_q_proj_and_k_up_proj
(
decode_hs_or_q_c
)
decode_ql_nope
,
decode_q_pe
=
\
decode_q_pe
=
torch
.
matmul
(
decode_hs_or_q_c
,
self
.
W_QR
)
\
self
.
_q_proj_and_k_up_proj
(
decode_hs_or_q_c
)
.
view
(
-
1
,
self
.
num_heads
,
self
.
qk_rope_head_dim
)
decode_q_pe
[...],
decode_k_pe
[...]
=
self
.
rotary_emb
(
decode_q_pe
[...],
decode_k_pe
[...]
=
self
.
rotary_emb
(
attn_metadata
.
decode
.
input_positions
,
decode_q_pe
.
contiguous
(),
attn_metadata
.
decode
.
input_positions
,
decode_q_pe
.
contiguous
(),
decode_k_pe
)
decode_k_pe
)
...
@@ -1087,6 +932,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -1087,6 +932,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
if
has_decode
:
if
has_decode
:
output
[:
num_decode_tokens
]
=
self
.
_forward_decode
(
output
[:
num_decode_tokens
]
=
self
.
_forward_decode
(
decode_q_nope
,
decode_q_pe
,
kv_cache
,
attn_metadata
)
decode_q
l
_nope
,
decode_q_pe
,
kv_cache
,
attn_metadata
)
return
output_padded
return
output_padded
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment