Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9532c498
Unverified
Commit
9532c498
authored
Mar 14, 2025
by
Lucas Wilkinson
Committed by
GitHub
Mar 13, 2025
Browse files
[Attention] MLA get rid of materialization (#14770)
Signed-off-by:
Lucas Wilkinson
<
lwilkins@redhat.com
>
parent
0c2af17c
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
117 additions
and
499 deletions
+117
-499
vllm/attention/backends/mla/common.py
vllm/attention/backends/mla/common.py
+57
-210
vllm/envs.py
vllm/envs.py
+0
-19
vllm/model_executor/layers/quantization/utils/fp8_utils.py
vllm/model_executor/layers/quantization/utils/fp8_utils.py
+2
-57
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+58
-213
No files found.
vllm/attention/backends/mla/common.py
View file @
9532c498
...
...
@@ -7,22 +7,22 @@ First we define:
Sq as Q sequence length
Skv as KV sequence length
MLA has two possible ways of computing, a data-movement friendly approach and a
compute friendly approach, we generally want to use the compute friendly
approach for "prefill" (i.e. the ratio Sq / Skv is "small", is near 1)
and the data-movement friendly approach for "decode" (i.e. the ratio
Sq / Skv is "large").
NOTE what we deem small and large is currently determined by if its labelled
prefill or decode by the scheduler, but this is something we should probably
MLA has two possible ways of computing, a data-movement friendly approach and a
compute friendly approach, we generally want to use the compute friendly
approach for "prefill" (i.e. the ratio Skv / Sq is "small", is near 1)
and the data-movement friendly approach for "decode" (i.e. the ratio
Skv / Sq is "large").
NOTE: what we deem small and large is currently determined by whether it is labelled
prefill or decode by the scheduler, but this is something we should probably
tune.
Main reference: DeepseekV2 paper, and FlashInfer Implementation
(https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).
Deepseek's MLA attention works the following way:
* Use a single latent vector to represent the per-token entry of the KV cache.
* For decode (i.e. the memory friendly approach) the attention "simulates" a
* Use a single latent vector to represent the per-token entry of the KV cache.
* For decode (i.e. the memory friendly approach) the attention "simulates" a
multi-head attention, while the compute is similar to multi-query attention.
Below is an example of both paths, assuming batch size = 1
...
...
@@ -54,9 +54,9 @@ W_DQ project h_t to q_c shape [H, Lq]
W_UQ project q_c to q_nope shape [Lq, N * P]
W_QR project q_c to q_pe shape [Lq, N * R]
W_DKV project h_t to kv_c shape [H, Lkv]
W_UK project kv_c to k_nope shape [Lkv, N
*
P]
W_KR project h_t to k_pe shape [H,
N *
R]
W_UV project kv_c to v shape [Lkv, N
*
V]
W_UK project kv_c to k_nope shape [Lkv, N
,
P]
W_KR project h_t to k_pe shape [H, R]
W_UV project kv_c to v shape [Lkv, N
,
V]
W_O project v to h_t shape [N * V, H]
...
...
@@ -69,8 +69,8 @@ new_kv_c = h_t @ W_DKV
new_k_pe = RoPE(h_t @ W_KR)
kv_c = torch.cat([new_kv_c, cache_kv_c], dim=0)
k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0)
k_nope = (kv_c @ W_UK).view(Skv, N, P)
v = (kv_c @ W_UV).view(Skv, N, V)
k_nope = (kv_c @ W_UK
.view(Lkv, N * P)
).view(Skv, N, P)
v = (kv_c @ W_UV
.view(Lkv, N * V)
).view(Skv, N, V)
// MHA with QK headdim = P + R
// V headdim = V
...
...
@@ -90,20 +90,10 @@ NOTE: in the actual code,
## Data-Movement Friendly Approach (i.e. "_forward_decode"):
Ahead of time, compute:
% this projects from q_c to [Sq, N * Lkv]
W_UQ_UK = einsum("qnp,knp -> qnk"
W_UQ.view(Lq, N, P), W_UK.view(Lkv, N, P)
).view(Lkv, N * Lkv)
% this projects from attn output [Sq, N * Lkv] to [Sq, H]
W_UV_O = einsum("knv,nvh -> nkh"
W_UV.view(Lkv, N, V), W_O.view(N, V, H)
).view(N * Lkv, H)
Runtime
q_c = h_t @ W_DQ
q_latent = q_c @ W_UQ_UK.view(Sq, N, Lkv)
q_nope = (q_c @ W_UQ).view(-1, N, P)
ql_nope = einsum("snp,lnp->snl", q_nope, W_UK)
q_pe = RoPE(q_c @ W_QR).view(Sq, N, R)
new_kv_c = h_t @ W_DKV
new_k_pe = RoPE(h_t @ W_KR)
...
...
@@ -116,11 +106,13 @@ k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0)
// NOTE: this is less compute-friendly since Lkv > P
// but is more data-movement friendly since its MQA vs MHA
spda_o = scaled_dot_product_attention(
torch.cat([q
_latent
, q_pe], dim=-1),
torch.cat([q
l_nope
, q_pe], dim=-1),
torch.cat([kv_c, k_pe], dim=-1),
kv_c
)
return spda_o.reshape(-1, N * Lkv) @ W_UV_O
o = einsum("snl,lnv->snv", spda_o.reshape(-1, N, Lkv), W_UV)
return o.view(-1, N * V) @ W_O
## Chunked Prefill
...
...
@@ -146,8 +138,8 @@ q_nope = (q_c @ W_UQ).view(Sq, N, P)
q_pe = RoPE(q_c @ W_QR).view(Sq, N, R)
new_kv_c = h_t @ W_DKV
new_k_pe = RoPE(h_t @ W_KR)
new_k_nope = (new_kv_c @ W_UK).view(Sq, N, P)
new_v = (new_kv_c @ W_UV).view(Sq, N, V)
new_k_nope = (new_kv_c @ W_UK
.view(Lkv, N * P)
).view(Sq, N, P)
new_v = (new_kv_c @ W_UV
.view(Lkv, N * V)
).view(Sq, N, V)
// MHA between queries and new KV
// with QK headdim = P + R
...
...
@@ -171,17 +163,17 @@ for chunk_idx in range(cdiv(C, MCC)):
cache_k_pe_chunk = cache_k_pe[chunk_start:chunk_end]
cache_k_nope_chunk = (cache_kv_c_chunk @ W_UK).view(-1, N, P)
cache_v_chunk = (cache_kv_c_chunk @ W_UV).view(-1, N, V)
chunk_o, chunk_lse = scaled_dot_product_attention(
torch.cat([q_nope, q_pe], dim=-1),
torch.cat([cache_k_nope_chunk,
cache_k_pe_chunk.unsqueeze(1).expand(-1, N, -1)],
torch.cat([cache_k_nope_chunk,
cache_k_pe_chunk.unsqueeze(1).expand(-1, N, -1)],
dim=-1),
cache_v_chunk,
causal=False,
return_softmax_lse=True
)
curr_o, curr_lse = merge_attn_states(
suffix_output=curr_o,
suffix_lse=curr_lse,
...
...
@@ -202,7 +194,6 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Tuple,
Type
,
TypeVar
)
import
torch
from
compressed_tensors.quantization
import
QuantizationStrategy
from
vllm
import
_custom_ops
as
ops
from
vllm
import
envs
...
...
@@ -215,20 +206,9 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
get_flash_attn_version
,
is_block_tables_empty
)
from
vllm.attention.ops.triton_merge_attn_states
import
merge_attn_states
from
vllm.distributed
import
(
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
)
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearBase
,
RowParallelLinear
,
UnquantizedLinearMethod
)
from
vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors
import
(
# noqa: E501
CompressedTensorsLinearMethod
)
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
CompressedTensorsW8A8Fp8
)
from
vllm.model_executor.layers.quantization.fp8
import
Fp8LinearMethod
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
Fp8LinearGenericOp
,
is_fp8
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
scaled_quantize
)
from
vllm.model_executor.layers.rotary_embedding
import
(
DeepseekScalingRotaryEmbedding
,
RotaryEmbedding
)
from
vllm.multimodal
import
MultiModalPlaceholderMap
...
...
@@ -1057,7 +1037,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
self
.
kv_b_proj
=
kv_b_proj
self
.
o_proj
=
o_proj
self
.
triton_fa_func
=
triton_attention
self
.
fp8_linear_generic
=
Fp8LinearGenericOp
()
# Handle the differences between the flash_attn_varlen from flash_attn
# and the one from vllm_flash_attn. The former is used on ROCm and the
...
...
@@ -1070,79 +1049,28 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
fa_version
=
self
.
vllm_flash_attn_version
)
def
_v_up_proj_and_o_proj
(
self
,
x
):
if
envs
.
VLLM_MLA_PERFORM_MATRIX_ABSORPTION
:
if
is_fp8
(
self
.
W_UV_O
):
output_parallel
=
self
.
fp8_linear_generic
.
apply
(
x
.
flatten
(
start_dim
=
1
),
self
.
W_UV_O
,
self
.
W_UV_O_scales
,
self
.
reqaunt_input_group_shape
,
self
.
reqaunt_weight_group_shape
)
else
:
output_parallel
=
torch
.
matmul
(
x
.
flatten
(
start_dim
=
1
),
self
.
W_UV_O
)
if
self
.
tp_size
>
1
:
output
=
tensor_model_parallel_all_reduce
(
output_parallel
)
else
:
output
=
output_parallel
return
output
else
:
x
=
torch
.
einsum
(
"bnl,lnv->bnv"
,
x
,
self
.
W_UV
)
return
self
.
o_proj
(
x
.
reshape
(
-
1
,
self
.
num_heads
*
self
.
v_head_dim
))[
0
]
# Convert from (B, N, L) to (N, B, L)
x
=
x
.
view
(
-
1
,
self
.
num_heads
,
self
.
kv_lora_rank
).
transpose
(
0
,
1
)
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
x
=
torch
.
bmm
(
x
,
self
.
W_UV
)
# Convert from (N, B, V) to (B, N * V)
x
=
x
.
transpose
(
0
,
1
).
reshape
(
-
1
,
self
.
num_heads
*
self
.
v_head_dim
)
return
self
.
o_proj
(
x
)[
0
]
# Return `ql_nope`, `q_pe`
def
_q_proj_and_k_up_proj
(
self
,
x
):
if
envs
.
VLLM_MLA_PERFORM_MATRIX_ABSORPTION
:
if
is_fp8
(
self
.
W_Q_UK
):
return
self
.
fp8_linear_generic
.
apply
(
x
,
self
.
W_Q_UK
,
self
.
W_Q_UK_scales
,
self
.
reqaunt_input_group_shape
,
self
.
reqaunt_weight_group_shape
).
view
(
-
1
,
self
.
num_heads
,
self
.
kv_lora_rank
)
return
torch
.
matmul
(
x
,
self
.
W_Q_UK
)
\
.
view
(
-
1
,
self
.
num_heads
,
self
.
kv_lora_rank
)
else
:
x
=
torch
.
matmul
(
x
,
self
.
W_Q
)
\
.
view
(
-
1
,
self
.
num_heads
,
self
.
qk_nope_head_dim
)
return
torch
.
einsum
(
"bnp,lnp->bnl"
,
x
,
self
.
W_UK
)
\
.
view
(
-
1
,
self
.
num_heads
,
self
.
kv_lora_rank
)
q_nope
,
q_pe
=
self
.
q_proj
(
x
)[
0
]
\
.
view
(
-
1
,
self
.
num_heads
,
self
.
qk_head_dim
)
\
.
split
([
self
.
qk_nope_head_dim
,
self
.
qk_rope_head_dim
],
dim
=-
1
)
def
process_weights_after_loading
(
self
,
act_dtype
:
torch
.
dtype
):
# Convert from (B, N, P) to (N, B, P)
q_nope
=
q_nope
.
transpose
(
0
,
1
)
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
ql_nope
=
torch
.
bmm
(
q_nope
,
self
.
W_UK_T
)
# Convert from (N, B, L) to (B, N, L)
return
ql_nope
.
transpose
(
0
,
1
),
q_pe
# TODO(lucas) This is very gross, we need a more wide scale refactor of
# all the FP8 code with a more standard way of
# defining schemes/group-shapes, we should also potentially force
# quant_methods to support a decompress function
#
# returns input_group_shape, weight_group_shape
def
get_scale_group_shapes_for_fp8
(
layer
:
LinearBase
)
->
\
Tuple
[
Tuple
[
int
,
int
],
Tuple
[
int
,
int
]]:
if
isinstance
(
layer
.
quant_method
,
Fp8LinearMethod
):
if
layer
.
quant_method
.
block_quant
:
weight_block_size
=
\
layer
.
quant_method
.
quant_config
.
weight_block_size
# per-token-group (1, X), block-quantized (X, Y)
return
(
1
,
weight_block_size
[
-
1
]),
weight_block_size
else
:
return
(
-
1
,
-
1
),
(
-
1
,
-
1
)
# per-tensor, per-tensor
elif
isinstance
(
layer
.
quant_method
,
CompressedTensorsLinearMethod
)
\
and
isinstance
(
layer
.
scheme
,
CompressedTensorsW8A8Fp8
):
# this is hacky but we always assume the for
# CompressedTensorsW8A8Fp8 the input is dynamic per-token
# we ignore if it is static-per-tensor since we are going to
# requantize after later anyways
strategy
=
layer
.
scheme
.
strategy
if
strategy
==
QuantizationStrategy
.
TENSOR
:
return
(
1
,
-
1
),
(
-
1
,
-
1
)
# per-token, per-tensor
elif
strategy
==
QuantizationStrategy
.
CHANNEL
:
return
(
1
,
-
1
),
(
-
1
,
1
)
# per-token, per-channel
else
:
raise
NotImplementedError
(
f
"QuantizationStrategy.
{
strategy
}
is not supported for "
"fp8 MLA, please run with VLLM_MLA_DISABLE=1"
)
else
:
raise
NotImplementedError
(
"Can't determine scale group shapes for "
f
"
{
layer
.
quant_method
}
, please run with VLLM_MLA_DISABLE=1"
)
def
process_weights_after_loading
(
self
,
act_dtype
:
torch
.
dtype
):
def
get_layer_weight
(
layer
):
WEIGHT_NAMES
=
(
"weight"
,
"qweight"
,
"weight_packed"
)
...
...
@@ -1167,10 +1095,9 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
return
dequant_weights
.
T
return
layer
.
weight
weight_dtype
=
get_layer_weight
(
self
.
kv_b_proj
).
dtype
assert
get_layer_weight
(
self
.
o_proj
).
dtype
==
weight_dtype
assert
get_layer_weight
(
self
.
q_proj
).
dtype
==
weight_dtype
# we currently do not have quantized bmm's which are needed for
# `W_UV` and `W_UK_T`, so we just store fp16/bf16 copies and perform
# the bmm's in 16-bit, the extra memory overhead of this is fairly low
kv_b_proj_weight
=
get_and_maybe_dequant_weights
(
self
.
kv_b_proj
).
T
assert
kv_b_proj_weight
.
shape
==
(
self
.
kv_lora_rank
,
...
...
@@ -1189,89 +1116,10 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
W_UK
,
W_UV
=
kv_b_proj_weight
.
split
(
[
self
.
qk_nope_head_dim
,
self
.
v_head_dim
],
dim
=-
1
)
q_proj_weight
=
get_and_maybe_dequant_weights
(
self
.
q_proj
).
T
\
.
view
(
-
1
,
self
.
num_heads
,
self
.
qk_head_dim
)
# can be W_Q or W_UQ depending q_lora_rank, the former if
# q_lora_rank is None, the latter otherwise. From the Attention backend
# perspective though we call these both W_Q and rely on the layer
# to pass in the correct matrix
W_Q
=
q_proj_weight
[...,
:
self
.
qk_nope_head_dim
]
self
.
W_QR
=
q_proj_weight
[...,
self
.
qk_nope_head_dim
:]
\
.
flatten
(
start_dim
=
1
).
contiguous
()
# W_QR is small so for simplicity we dont bother requantizing it
self
.
W_QR
=
self
.
W_QR
.
to
(
act_dtype
)
if
envs
.
VLLM_MLA_PERFORM_MATRIX_ABSORPTION
:
requantization_enabled
=
not
envs
.
VLLM_MLA_DISABLE_REQUANTIZATION
if
is_fp8
(
weight_dtype
)
and
requantization_enabled
:
# This assumes it wise to requantize using the same group shapes
# (i.e. strategy, per-tensor, per-channel, block etc.) that the
# weights were originally quantized
requant_input_group_shape
,
requant_weight_group_shape
=
\
get_scale_group_shapes_for_fp8
(
self
.
q_proj
)
assert
(
requant_input_group_shape
,
requant_weight_group_shape
)
\
==
get_scale_group_shapes_for_fp8
(
self
.
kv_b_proj
)
assert
(
requant_input_group_shape
,
requant_weight_group_shape
)
\
==
get_scale_group_shapes_for_fp8
(
self
.
o_proj
)
self
.
reqaunt_input_group_shape
=
requant_input_group_shape
self
.
reqaunt_weight_group_shape
=
requant_weight_group_shape
#
# Perform matrix-absorption following
# https://github.com/flashinfer-ai/flashinfer/pull/551
# for decode, as a result we end up with absorbed weights for decode
# and another copy of raw weights for prefill.
#
self
.
W_UK
,
self
.
W_UV
=
kv_b_proj_weight
.
split
(
[
self
.
qk_nope_head_dim
,
self
.
v_head_dim
],
dim
=-
1
)
# We absorb `W_UK` into `W_Q` resulting in either W_Q_UK or W_UQ_UK
# depending q_lora_rank, the former if q_lora_rank is None, the
# latter otherwise
# basically if q_lora_rank is none we are absorbing into q_proj
# instead of UQ
W_Q_UK
=
torch
.
einsum
(
"qnd,lnd -> qnl"
,
W_Q
,
W_UK
)
\
.
flatten
(
start_dim
=
1
).
contiguous
()
if
is_fp8
(
weight_dtype
)
and
requantization_enabled
:
W_Q_UK
,
W_Q_UK_scales
=
scaled_quantize
(
W_Q_UK
,
self
.
reqaunt_weight_group_shape
,
quant_dtype
=
current_platform
.
fp8_dtype
())
# For FP8 save the transpose so we can use
# `apply_w8a8_block_fp8_linear` directly
self
.
W_Q_UK
=
W_Q_UK
.
T
.
contiguous
()
self
.
W_Q_UK_scales
=
W_Q_UK_scales
.
T
.
contiguous
()
else
:
self
.
W_Q_UK
=
W_Q_UK
.
to
(
act_dtype
)
W_O
=
get_and_maybe_dequant_weights
(
self
.
o_proj
)
\
.
view
(
-
1
,
self
.
num_heads
,
self
.
v_head_dim
)
W_UV_O
=
torch
.
einsum
(
"lnd,hnd -> nlh"
,
W_UV
,
W_O
)
\
.
flatten
(
start_dim
=
0
,
end_dim
=
1
).
contiguous
()
if
is_fp8
(
weight_dtype
)
and
requantization_enabled
:
W_UV_O
,
W_UV_O_scales
=
scaled_quantize
(
W_UV_O
,
self
.
reqaunt_weight_group_shape
,
quant_dtype
=
current_platform
.
fp8_dtype
())
# For FP8 save the transpose so we can use
# `apply_w8a8_block_fp8_linear` directly
self
.
W_UV_O
=
W_UV_O
.
T
.
contiguous
()
self
.
W_UV_O_scales
=
W_UV_O_scales
.
T
.
contiguous
()
else
:
self
.
W_UV_O
=
W_UV_O
.
to
(
act_dtype
)
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
else
:
if
is_fp8
(
weight_dtype
):
raise
NotImplementedError
(
"Currently fp8 requires matrix absorption"
)
self
.
W_UV
=
W_UV
self
.
W_UK
=
W_UK
self
.
W_Q
=
W_Q
.
flatten
(
start_dim
=
1
)
# Convert from (L, N, V) to (N, L, V)
self
.
W_UV
=
W_UV
.
transpose
(
0
,
1
)
# Convert from (L, N, P) to (N, P, L)
self
.
W_UK_T
=
W_UK
.
permute
(
1
,
2
,
0
)
def
_compute_prefill_context
(
self
,
...
...
@@ -1471,7 +1319,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
@
abstractmethod
def
_forward_decode
(
self
,
q_nope
:
torch
.
Tensor
,
q
l
_nope
:
torch
.
Tensor
,
q_pe
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
attn_metadata
:
T
,
...
...
@@ -1525,9 +1373,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
prefill_k_c_normed
=
k_c_normed
[:
num_prefill_tokens
]
if
has_decode
:
decode_q_nope
=
self
.
_q_proj_and_k_up_proj
(
decode_hs_or_q_c
)
decode_q_pe
=
torch
.
matmul
(
decode_hs_or_q_c
,
self
.
W_QR
)
\
.
view
(
-
1
,
self
.
num_heads
,
self
.
qk_rope_head_dim
)
decode_ql_nope
,
decode_q_pe
=
\
self
.
_q_proj_and_k_up_proj
(
decode_hs_or_q_c
)
decode_q_pe
[...],
decode_k_pe
[...]
=
self
.
rotary_emb
(
decode_input_positions
,
decode_q_pe
,
decode_k_pe
)
...
...
@@ -1561,6 +1408,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
if
has_decode
:
output
[
num_prefill_tokens
:]
=
self
.
_forward_decode
(
decode_q_nope
,
decode_q_pe
,
kv_cache
,
attn_metadata
)
decode_q
l
_nope
,
decode_q_pe
,
kv_cache
,
attn_metadata
)
return
output
vllm/envs.py
View file @
9532c498
...
...
@@ -84,8 +84,6 @@ if TYPE_CHECKING:
VLLM_SERVER_DEV_MODE
:
bool
=
False
VLLM_V1_OUTPUT_PROC_CHUNK_SIZE
:
int
=
128
VLLM_MLA_DISABLE
:
bool
=
False
VLLM_MLA_PERFORM_MATRIX_ABSORPTION
:
bool
=
True
VLLM_MLA_DISABLE_REQUANTIZATION
:
bool
=
False
VLLM_MLA_CUDA_MEM_ALIGN_KV_CACHE
:
bool
=
True
VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON
:
bool
=
False
VLLM_RAY_PER_WORKER_GPUS
:
float
=
1.0
...
...
@@ -563,23 +561,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_MLA_DISABLE"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_MLA_DISABLE"
,
"0"
))),
# Flag that can control whether or not we perform matrix-absorption for MLA
# decode, i.e. absorb W_UK into W_Q/W_UK and W_UV into W_O, absorbing the
# matrices reduces the runtime FLOPs needed to compute MLA but requires
# storing more weights, W_Q_UK and W_UV_O, so can increase memory usage,
# the is enabled by default
"VLLM_MLA_PERFORM_MATRIX_ABSORPTION"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_MLA_PERFORM_MATRIX_ABSORPTION"
,
"1"
))),
# When running MLA with matrix-absorption enabled and fp8 quantized weights
# we perform the matrix-absorption in float32 precision, after the matrices
# are absorbed we requantize the weights back to fp8, this flag can be used
# to disable the requantization step, and instead convert the absorbed
# matrices to match the activation type. This can lead to higher memory and
# compute usage but better preserves the accuracy of the original model.
"VLLM_MLA_DISABLE_REQUANTIZATION"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_MLA_DISABLE_REQUANTIZATION"
,
"0"
))),
# If set, vLLM will use the Triton implementation of moe_align_block_size,
# i.e. moe_align_block_size_triton in fused_moe.py.
"VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON"
:
...
...
vllm/model_executor/layers/quantization/utils/fp8_utils.py
View file @
9532c498
...
...
@@ -13,10 +13,9 @@ import triton.language as tl
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
_normalize_quant_group_shape
,
scaled_dequantize
)
scaled_dequantize
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
CUTLASS_BLOCK_FP8_SUPPORTED
,
Fp8LinearOp
,
cutlass_block_fp8_supported
,
cutlass_fp8_supported
)
CUTLASS_BLOCK_FP8_SUPPORTED
)
from
vllm.platforms
import
current_platform
from
vllm.utils
import
direct_register_custom_op
...
...
@@ -101,60 +100,6 @@ direct_register_custom_op(
)
# Unify the interface between `apply_w8a8_block_fp8_linear` and
# `apply_fp8_linear`
# NOTE(lucas): this is quite messy, we should think through this more formally
# TODO(luka): unify this better
# https://github.com/vllm-project/vllm/issues/14397
class
Fp8LinearGenericOp
:
def
__init__
(
self
,
cutlass_fp8_supported
:
bool
=
cutlass_fp8_supported
(),
cutlass_block_fp8_supported
:
bool
=
cutlass_block_fp8_supported
(),
):
self
.
cutlass_block_fp8_supported
=
cutlass_block_fp8_supported
self
.
fp8_linear
=
Fp8LinearOp
(
cutlass_fp8_supported
=
cutlass_fp8_supported
)
def
apply
(
self
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight_scale
:
torch
.
Tensor
,
input_group_shape
:
Tuple
[
int
,
int
],
weight_group_shape
:
Tuple
[
int
,
int
],
input_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
# static scale if one
)
->
torch
.
Tensor
:
# View input as 2D matrix for fp8 methods
input
=
input
.
view
(
-
1
,
input
.
shape
[
-
1
])
weight_group_shape
=
_normalize_quant_group_shape
(
\
weight
,
weight_group_shape
)
input_group_shape
=
_normalize_quant_group_shape
(
input
,
input_group_shape
)
def
is_dim_blocked
(
dim
,
shape
,
group_shape
):
return
group_shape
<
shape
[
dim
]
and
group_shape
>
1
if
is_dim_blocked
(
0
,
weight
.
shape
,
weight_group_shape
[
0
])
\
and
is_dim_blocked
(
1
,
weight
.
shape
,
weight_group_shape
[
1
])
and
\
input_group_shape
==
(
1
,
weight_group_shape
[
1
]):
return
apply_w8a8_block_fp8_linear
(
input
,
weight
,
list
(
weight_group_shape
),
weight_scale
,
cutlass_block_fp8_supported
=
self
.
cutlass_block_fp8_supported
)
else
:
# Despite having linear in the name it doesn't conform to
# `torch.nn.functional.linear` which is defined as
# `input @ weight.T` so we explicitly transpose the weight matrix
return
self
.
fp8_linear
.
apply
(
input
,
weight
.
T
,
weight_scale
.
T
,
use_per_token_if_dynamic
=
\
(
input_group_shape
==
(
1
,
input
.
shape
[
1
])))
def
input_to_float8
(
x
:
torch
.
Tensor
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
...
...
vllm/v1/attention/backends/mla/common.py
View file @
9532c498
...
...
@@ -21,7 +21,7 @@ Main reference: DeepseekV2 paper, and FlashInfer Implementation
(https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).
Deepseek's MLA attention works the following way:
* Use a single latent vector to represent the per-token entry of the KV cache.
* Use a single latent vector to represent the per-token entry of the KV cache.
* For decode (i.e. the memory friendly approach) the attention "simulates" a
multi-head attention, while the compute is similar to multi-query attention.
...
...
@@ -54,9 +54,9 @@ W_DQ project h_t to q_c shape [H, Lq]
W_UQ project q_c to q_nope shape [Lq, N * P]
W_QR project q_c to q_pe shape [Lq, N * R]
W_DKV project h_t to kv_c shape [H, Lkv]
W_UK project kv_c to k_nope shape [Lkv, N
*
P]
W_KR project h_t to k_pe shape [H,
N *
R]
W_UV project kv_c to v shape [Lkv, N
*
V]
W_UK project kv_c to k_nope shape [Lkv, N
,
P]
W_KR project h_t to k_pe shape [H, R]
W_UV project kv_c to v shape [Lkv, N
,
V]
W_O project v to h_t shape [N * V, H]
...
...
@@ -69,8 +69,8 @@ new_kv_c = h_t @ W_DKV
new_k_pe = RoPE(h_t @ W_KR)
kv_c = torch.cat([new_kv_c, cache_kv_c], dim=0)
k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0)
k_nope = (kv_c @ W_UK).view(Skv, N, P)
v = (kv_c @ W_UV).view(Skv, N, V)
k_nope = (kv_c @ W_UK
.view(Lkv, N * P)
).view(Skv, N, P)
v = (kv_c @ W_UV
.view(Lkv, N * V)
).view(Skv, N, V)
// MHA with QK headdim = P + R
// V headdim = V
...
...
@@ -79,7 +79,7 @@ spda_o = scaled_dot_product_attention(
torch.cat([q_nope, q_pe], dim=-1),
torch.cat([k_nope, k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1),
v
)
)
return spda_o @ W_O
NOTE: in the actual code,
...
...
@@ -90,20 +90,10 @@ NOTE: in the actual code,
## Data-Movement Friendly Approach (i.e. "_forward_decode"):
Ahead of time, compute:
% this projects from q_c to [Sq, N * Lkv]
W_UQ_UK = einsum("qnp,knp -> qnk"
W_UQ.view(Lq, N, P), W_UK.view(Lkv, N, P)
).view(Lkv, N * Lkv)
% this projects from attn output [Sq, N * Lkv] to [Sq, H]
W_UV_O = einsum("knv,nvh -> nkh"
W_UV.view(Lkv, N, V), W_O.view(N, V, H)
).view(N * Lkv, H)
Runtime
q_c = h_t @ W_DQ
q_latent = q_c @ W_UQ_UK.view(Sq, N, Lkv)
q_nope = (q_c @ W_UQ).view(-1, N, P)
ql_nope = einsum("snp,lnp->snl", q_nope, W_UK)
q_pe = RoPE(q_c @ W_QR).view(Sq, N, R)
new_kv_c = h_t @ W_DKV
new_k_pe = RoPE(h_t @ W_KR)
...
...
@@ -116,29 +106,31 @@ k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0)
// NOTE: this is less compute-friendly since Lkv > P
// but is more data-movement friendly since its MQA vs MHA
spda_o = scaled_dot_product_attention(
torch.cat([q
_latent
, q_pe], dim=-1),
torch.cat([q
l_nope
, q_pe], dim=-1),
torch.cat([kv_c, k_pe], dim=-1),
kv_c
)
return spda_o.reshape(-1, N * Lkv) @ W_UV_O
o = einsum("snl,lnv->snv", spda_o.reshape(-1, N, Lkv), W_UV)
return o.view(-1, N * V) @ W_O
## Chunked Prefill
For chunked prefill we want to use the compute friendly algorithm. We are
assuming sufficiently large Sq / Skv ratio, in the future may want to switch to
For chunked prefill we want to use the compute friendly algorithm. We are
assuming sufficiently large Sq / Skv ratio, in the future may want to switch to
the data-movement friendly approach if the chunk (i.e. `Sq`) is small.
However, the compute-friendly approach can potentially run out of memory if Skv
is large due to: `k_nope = (kv_c @ W_UK).view(Skv, N, P)`
To mitigate this, we chunk the computation of attention with respect to the
current context (i.e. `cache_kv_c` and `cache_k_pe`) so that we can used a
To mitigate this, we chunk the computation of attention with respect to the
current context (i.e. `cache_kv_c` and `cache_k_pe`) so that we can use a
fixed workspace size.
The chunked prefill approach is as follows:
MCC Max chunk of context to process per iter, computed dynamically,
MCC Max chunk of context to process per iter, computed dynamically,
used to bound the memory usage
q_c = h_t @ W_DQ
...
...
@@ -146,8 +138,8 @@ q_nope = (q_c @ W_UQ).view(Sq, N, P)
q_pe = RoPE(q_c @ W_QR).view(Sq, N, R)
new_kv_c = h_t @ W_DKV
new_k_pe = RoPE(h_t @ W_KR)
new_k_nope = (new_kv_c @ W_UK).view(Sq, N, P)
new_v = (new_kv_c @ W_UV).view(Sq, N, V)
new_k_nope = (new_kv_c @ W_UK
.view(Lkv, N * P)
).view(Sq, N, P)
new_v = (new_kv_c @ W_UV
.view(Lkv, N * V)
).view(Sq, N, V)
// MHA between queries and new KV
// with QK headdim = P + R
...
...
@@ -160,7 +152,7 @@ curr_o, curr_lse = scaled_dot_product_attention(
new_v,
causal=True,
return_softmax_lse=True
)
)
// Compute attention with the already existing context
for chunk_idx in range(cdiv(C, MCC)):
...
...
@@ -198,30 +190,17 @@ from dataclasses import dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
Generic
,
Optional
,
TypeVar
import
torch
from
compressed_tensors.quantization
import
QuantizationStrategy
from
vllm
import
_custom_ops
as
ops
from
vllm
import
envs
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionLayer
,
AttentionMetadata
,
MLAAttentionImpl
)
from
vllm.attention.backends.utils
import
get_flash_attn_version
from
vllm.attention.ops.triton_merge_attn_states
import
merge_attn_states
from
vllm.distributed
import
(
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
)
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearBase
,
RowParallelLinear
,
UnquantizedLinearMethod
)
from
vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors
import
(
# noqa: E501
CompressedTensorsLinearMethod
)
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
CompressedTensorsW8A8Fp8
)
from
vllm.model_executor.layers.quantization.fp8
import
Fp8LinearMethod
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
Fp8LinearGenericOp
,
is_fp8
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
scaled_quantize
)
from
vllm.model_executor.layers.rotary_embedding
import
RotaryEmbedding
from
vllm.platforms
import
current_platform
from
vllm.utils
import
cdiv
,
round_down
...
...
@@ -646,7 +625,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
self
.
kv_b_proj
=
kv_b_proj
self
.
o_proj
=
o_proj
self
.
vllm_flash_attn_version
=
get_flash_attn_version
()
self
.
fp8_linear_generic
=
Fp8LinearGenericOp
()
# Handle the differences between the flash_attn_varlen from flash_attn
# and the one from vllm_flash_attn. The former is used on ROCm and the
...
...
@@ -658,88 +636,37 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
fa_version
=
self
.
vllm_flash_attn_version
)
def
_v_up_proj_and_o_proj
(
self
,
x
):
if
envs
.
VLLM_MLA_PERFORM_MATRIX_ABSORPTION
:
if
is_fp8
(
self
.
W_UV_O
):
output_parallel
=
self
.
fp8_linear_generic
.
apply
(
x
.
flatten
(
start_dim
=
1
),
self
.
W_UV_O
,
self
.
W_UV_O_scales
,
self
.
reqaunt_input_group_shape
,
self
.
reqaunt_weight_group_shape
)
else
:
output_parallel
=
torch
.
matmul
(
x
.
flatten
(
start_dim
=
1
),
self
.
W_UV_O
)
if
self
.
tp_size
>
1
:
output
=
tensor_model_parallel_all_reduce
(
output_parallel
)
else
:
output
=
output_parallel
return
output
else
:
x
=
torch
.
einsum
(
"bnl,lnv->bnv"
,
x
,
self
.
W_UV
)
return
self
.
o_proj
(
x
.
reshape
(
-
1
,
self
.
num_heads
*
self
.
v_head_dim
))[
0
]
# Convert from (B, N, L) to (N, B, L)
x
=
x
.
view
(
-
1
,
self
.
num_heads
,
self
.
kv_lora_rank
).
transpose
(
0
,
1
)
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
x
=
torch
.
bmm
(
x
,
self
.
W_UV
)
# Convert from (N, B, V) to (B, N * V)
x
=
x
.
transpose
(
0
,
1
).
reshape
(
-
1
,
self
.
num_heads
*
self
.
v_head_dim
)
return
self
.
o_proj
(
x
)[
0
]
# Return `ql_nope`, `q_pe`
def
_q_proj_and_k_up_proj
(
self
,
x
):
if
envs
.
VLLM_MLA_PERFORM_MATRIX_ABSORPTION
:
if
is_fp8
(
self
.
W_Q_UK
):
return
self
.
fp8_linear_generic
.
apply
(
x
,
self
.
W_Q_UK
,
self
.
W_Q_UK_scales
,
self
.
reqaunt_input_group_shape
,
self
.
reqaunt_weight_group_shape
).
view
(
-
1
,
self
.
num_heads
,
self
.
kv_lora_rank
)
return
torch
.
matmul
(
x
,
self
.
W_Q_UK
)
\
.
view
(
-
1
,
self
.
num_heads
,
self
.
kv_lora_rank
)
else
:
x
=
torch
.
matmul
(
x
,
self
.
W_Q
)
\
.
view
(
-
1
,
self
.
num_heads
,
self
.
qk_nope_head_dim
)
return
torch
.
einsum
(
"bnp,lnp->bnl"
,
x
,
self
.
W_UK
)
\
.
view
(
-
1
,
self
.
num_heads
,
self
.
kv_lora_rank
)
q_nope
,
q_pe
=
self
.
q_proj
(
x
)[
0
]
\
.
view
(
-
1
,
self
.
num_heads
,
self
.
qk_head_dim
)
\
.
split
([
self
.
qk_nope_head_dim
,
self
.
qk_rope_head_dim
],
dim
=-
1
)
def
process_weights_after_loading
(
self
,
act_dtype
:
torch
.
dtype
):
# Convert from (B, N, P) to (N, B, P)
q_nope
=
q_nope
.
transpose
(
0
,
1
)
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
ql_nope
=
torch
.
bmm
(
q_nope
,
self
.
W_UK_T
)
# Convert from (N, B, L) to (B, N, L)
return
ql_nope
.
transpose
(
0
,
1
),
q_pe
# TODO(lucas) This is very gross, we need a more wide scale refactor of
# all the FP8 code with a more standard way of
# defining schemes/group-shapes, we should also potentially force
# quant_methods to support a decompress function
#
# returns input_group_shape, weight_group_shape
def
get_scale_group_shapes_for_fp8
(
layer
:
LinearBase
)
->
\
tuple
[
tuple
[
int
,
int
],
tuple
[
int
,
int
]]:
if
isinstance
(
layer
.
quant_method
,
Fp8LinearMethod
):
if
layer
.
quant_method
.
block_quant
:
weight_block_size
=
\
layer
.
quant_method
.
quant_config
.
weight_block_size
# per-token-group (1, X), block-quantized (X, Y)
return
(
1
,
weight_block_size
[
-
1
]),
weight_block_size
else
:
return
(
-
1
,
-
1
),
(
-
1
,
-
1
)
# per-tensor, per-tensor
elif
isinstance
(
layer
.
quant_method
,
CompressedTensorsLinearMethod
)
\
and
isinstance
(
layer
.
scheme
,
CompressedTensorsW8A8Fp8
):
# this is hacky but we always assume the for
# CompressedTensorsW8A8Fp8 the input is dynamic per-token
# we ignore if it is static-per-tensor since we are going to
# requantize after later anyways
strategy
=
layer
.
scheme
.
strategy
if
strategy
==
QuantizationStrategy
.
TENSOR
:
return
(
1
,
-
1
),
(
-
1
,
-
1
)
# per-token, per-tensor
elif
strategy
==
QuantizationStrategy
.
CHANNEL
:
return
(
1
,
-
1
),
(
-
1
,
1
)
# per-token, per-channel
else
:
raise
NotImplementedError
(
f
"QuantizationStrategy.
{
strategy
}
is not supported for "
"fp8 MLA, please run with VLLM_MLA_DISABLE=1"
)
else
:
raise
NotImplementedError
(
"Can't determine scale group shapes for "
f
"
{
layer
.
quant_method
}
, please run with VLLM_MLA_DISABLE=1"
)
def
process_weights_after_loading
(
self
,
act_dtype
:
torch
.
dtype
):
def
get_layer_weight
(
layer
):
if
hasattr
(
layer
,
"weight"
)
:
return
layer
.
weight
el
if
hasattr
(
layer
,
"qweight"
):
return
layer
.
qweight
else
:
raise
A
ttribute
Error
(
f
"Layer '
{
layer
}
' has neither weight nor qweight
"
)
WEIGHT_NAMES
=
(
"weight"
,
"qweight"
,
"weight
_packed
"
)
for
attr
in
WEIGHT_NAMES
:
if
hasattr
(
layer
,
attr
):
return
getattr
(
layer
,
attr
)
raise
AttributeError
(
f
"Layer '
{
layer
}
' has no recognized weight a
ttribute
:"
f
"
{
WEIGHT_NAMES
}
.
"
)
def
get_and_maybe_dequant_weights
(
layer
:
LinearBase
):
if
not
isinstance
(
layer
.
quant_method
,
UnquantizedLinearMethod
):
...
...
@@ -755,10 +682,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
return
dequant_weights
.
T
return
layer
.
weight
weight_dtype
=
get_layer_weight
(
self
.
kv_b_proj
).
dtype
assert
get_layer_weight
(
self
.
o_proj
).
dtype
==
weight_dtype
assert
get_layer_weight
(
self
.
q_proj
).
dtype
==
weight_dtype
# we currently do not have quantized bmm's which are needed for
# `W_UV` and `W_UK_T`, so we just store fp16/bf16 copies and perform
# the bmm's in 16-bit, the extra memory overhead of this is fairly low
kv_b_proj_weight
=
get_and_maybe_dequant_weights
(
self
.
kv_b_proj
).
T
assert
kv_b_proj_weight
.
shape
==
(
self
.
kv_lora_rank
,
...
...
@@ -777,89 +703,10 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
W_UK
,
W_UV
=
kv_b_proj_weight
.
split
(
[
self
.
qk_nope_head_dim
,
self
.
v_head_dim
],
dim
=-
1
)
q_proj_weight
=
get_and_maybe_dequant_weights
(
self
.
q_proj
).
T
\
.
view
(
-
1
,
self
.
num_heads
,
self
.
qk_head_dim
)
# can be W_Q or W_UQ depending on q_lora_rank, the former if
# q_lora_rank is None, the latter otherwise. From the Attention backend
# perspective though we call these both W_Q and rely on the layer
# to pass in the correct matrix
W_Q
=
q_proj_weight
[...,
:
self
.
qk_nope_head_dim
]
self
.
W_QR
=
q_proj_weight
[...,
self
.
qk_nope_head_dim
:]
\
.
flatten
(
start_dim
=
1
).
contiguous
()
# W_QR is small so for simplicity we don't bother requantizing it
self
.
W_QR
=
self
.
W_QR
.
to
(
act_dtype
)
if
envs
.
VLLM_MLA_PERFORM_MATRIX_ABSORPTION
:
requantization_enabled
=
not
envs
.
VLLM_MLA_DISABLE_REQUANTIZATION
if
is_fp8
(
weight_dtype
)
and
requantization_enabled
:
# This assumes it is wise to requantize using the same group shapes
# (i.e. strategy, per-tensor, per-channel, block etc.) that the
# weights were originally quantized
requant_input_group_shape
,
requant_weight_group_shape
=
\
get_scale_group_shapes_for_fp8
(
self
.
q_proj
)
assert
(
requant_input_group_shape
,
requant_weight_group_shape
)
\
==
get_scale_group_shapes_for_fp8
(
self
.
kv_b_proj
)
assert
(
requant_input_group_shape
,
requant_weight_group_shape
)
\
==
get_scale_group_shapes_for_fp8
(
self
.
o_proj
)
self
.
reqaunt_input_group_shape
=
requant_input_group_shape
self
.
reqaunt_weight_group_shape
=
requant_weight_group_shape
#
# Perform matrix-absorption following
# https://github.com/flashinfer-ai/flashinfer/pull/551
# for decode, as a result we end up with absorbed weights for decode
# and another copy of raw weights for prefill.
#
self
.
W_UK
,
self
.
W_UV
=
kv_b_proj_weight
.
split
(
[
self
.
qk_nope_head_dim
,
self
.
v_head_dim
],
dim
=-
1
)
# We absorb `W_UK` into `W_Q` resulting in either W_Q_UK or W_UQ_UK
# depending on q_lora_rank, the former if q_lora_rank is None, the
# latter otherwise
# basically if q_lora_rank is none we are absorbing into q_proj
# instead of UQ
W_Q_UK
=
torch
.
einsum
(
"qnd,lnd -> qnl"
,
W_Q
,
W_UK
)
\
.
flatten
(
start_dim
=
1
).
contiguous
()
if
is_fp8
(
weight_dtype
)
and
requantization_enabled
:
W_Q_UK
,
W_Q_UK_scales
=
scaled_quantize
(
W_Q_UK
,
self
.
reqaunt_weight_group_shape
,
quant_dtype
=
current_platform
.
fp8_dtype
())
# For FP8 save the transpose so we can use
# `apply_w8a8_block_fp8_linear` directly
self
.
W_Q_UK
=
W_Q_UK
.
T
.
contiguous
()
self
.
W_Q_UK_scales
=
W_Q_UK_scales
.
T
.
contiguous
()
else
:
self
.
W_Q_UK
=
W_Q_UK
.
to
(
act_dtype
)
W_O
=
get_and_maybe_dequant_weights
(
self
.
o_proj
)
\
.
view
(
-
1
,
self
.
num_heads
,
self
.
v_head_dim
)
W_UV_O
=
torch
.
einsum
(
"lnd,hnd -> nlh"
,
W_UV
,
W_O
)
\
.
flatten
(
start_dim
=
0
,
end_dim
=
1
).
contiguous
()
if
is_fp8
(
weight_dtype
)
and
requantization_enabled
:
W_UV_O
,
W_UV_O_scales
=
scaled_quantize
(
W_UV_O
,
self
.
reqaunt_weight_group_shape
,
quant_dtype
=
current_platform
.
fp8_dtype
())
# For FP8 save the transpose so we can use
# `apply_w8a8_block_fp8_linear` directly
self
.
W_UV_O
=
W_UV_O
.
T
.
contiguous
()
self
.
W_UV_O_scales
=
W_UV_O_scales
.
T
.
contiguous
()
else
:
self
.
W_UV_O
=
W_UV_O
.
to
(
act_dtype
)
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
else
:
if
is_fp8
(
weight_dtype
):
raise
NotImplementedError
(
"Currently fp8 requires matrix absorption"
)
self
.
W_UV
=
W_UV
self
.
W_UK
=
W_UK
self
.
W_Q
=
W_Q
.
flatten
(
start_dim
=
1
)
# Convert from (L, N, V) to (N, L, V)
self
.
W_UV
=
W_UV
.
transpose
(
0
,
1
)
# Convert from (L, N, P) to (N, P, L)
self
.
W_UK_T
=
W_UK
.
permute
(
1
,
2
,
0
)
def
_compute_prefill_context
(
self
,
...
...
@@ -998,7 +845,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
@
abstractmethod
def
_forward_decode
(
self
,
q_nope
:
torch
.
Tensor
,
q
l
_nope
:
torch
.
Tensor
,
q_pe
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
attn_metadata
:
M
,
...
...
@@ -1051,10 +898,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
if
has_decode
:
assert
attn_metadata
.
decode
is
not
None
decode_q_nope
=
self
.
_q_proj_and_k_up_proj
(
decode_hs_or_q_c
)
decode_q_pe
=
torch
.
matmul
(
decode_hs_or_q_c
,
self
.
W_QR
)
\
.
view
(
-
1
,
self
.
num_heads
,
self
.
qk_rope_head_dim
)
decode_ql_nope
,
decode_q_pe
=
\
self
.
_q_proj_and_k_up_proj
(
decode_hs_or_q_c
)
decode_q_pe
[...],
decode_k_pe
[...]
=
self
.
rotary_emb
(
attn_metadata
.
decode
.
input_positions
,
decode_q_pe
.
contiguous
(),
decode_k_pe
)
...
...
@@ -1087,6 +932,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
if
has_decode
:
output
[:
num_decode_tokens
]
=
self
.
_forward_decode
(
decode_q_nope
,
decode_q_pe
,
kv_cache
,
attn_metadata
)
decode_q
l
_nope
,
decode_q_pe
,
kv_cache
,
attn_metadata
)
return
output_padded
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment