Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b01efa0b
Commit
b01efa0b
authored
Apr 14, 2025
by
zhuwenwen
Browse files
remove unused mla utils.py
parent
4a19cdf5
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
4 additions
and
549 deletions
+4
-549
vllm/attention/backends/mla/common.py
vllm/attention/backends/mla/common.py
+4
-2
vllm/attention/backends/mla/utils.py
vllm/attention/backends/mla/utils.py
+0
-547
No files found.
vllm/attention/backends/mla/common.py
View file @
b01efa0b
...
...
@@ -1297,7 +1297,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
output
=
self
.
flash_attn_varlen_func
(
q
=
q
,
k
=
k
,
v
=
v
_padded
,
v
=
v
,
cu_seqlens_q
=
prefill_metadata
.
query_start_loc
,
cu_seqlens_k
=
prefill_metadata
.
query_start_loc
,
max_seqlen_q
=
prefill_metadata
.
max_prefill_seq_len
,
...
...
@@ -1323,8 +1323,10 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
)
# slice by `:v.shape[-1]` in order to remove v headdim padding
# output = output\
# .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
# .reshape(-1, self.num_heads * v.shape[-1])
output
=
output
\
.
view
(
-
1
,
self
.
num_heads
,
q
.
shape
[
-
1
])[...,
:
v
.
shape
[
-
1
]]
\
.
reshape
(
-
1
,
self
.
num_heads
*
v
.
shape
[
-
1
])
return
self
.
o_proj
(
output
)[
0
]
...
...
vllm/attention/backends/mla/utils.py
deleted
100644 → 0
View file @
4a19cdf5
# SPDX-License-Identifier: Apache-2.0
import
os
import
functools
from
abc
import
abstractmethod
from
dataclasses
import
dataclass
from
typing
import
Any
,
Dict
,
Generic
,
List
,
Optional
,
Tuple
import
torch
from
compressed_tensors.quantization
import
QuantizationStrategy
from
vllm
import
_custom_ops
as
ops
from
vllm
import
envs
from
vllm.attention.backends.abstract
import
(
AttentionLayer
,
AttentionMetadata
,
MLAAttentionImpl
,
T
)
from
vllm.attention.backends.utils
import
get_flash_attn_version
from
vllm.distributed
import
(
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
)
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearBase
,
RowParallelLinear
,
UnquantizedLinearMethod
)
from
vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors
import
(
# noqa: E501
CompressedTensorsLinearMethod
)
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
CompressedTensorsW8A8Fp8
)
from
vllm.model_executor.layers.quantization.fp8
import
Fp8LinearMethod
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
apply_fp8_linear_generic
,
current_platform_fp8_dtype
,
is_fp8
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
scaled_quantize
)
from
vllm.model_executor.layers.rotary_embedding
import
(
DeepseekScalingRotaryEmbedding
,
RotaryEmbedding
)
try
:
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
except
ImportError
:
from
flash_attn
import
flash_attn_varlen_func
@
dataclass
class
MLACommonMetadata
(
AttentionMetadata
):
# Input positions for rotrary embeddings since for MLA the rotary
# position embeddings are applied inside the attention backend
input_positions
:
torch
.
Tensor
class
MLACommonImpl
(
MLAAttentionImpl
[
T
],
Generic
[
T
]):
"""
Common class for implementing repeated parts
Main reference: DeepseekV2 paper, and FlashInfer Implementation
(https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).
Deepseek's MLA attention works the following way:
* Use a single latent vector to represent the entire KV cache.
* The attention "simulates" a multi-head attention, while the compute is
similar to multi-query attention.
* The dataflow is as follows,
* B: batch/sequence length
* H: hidden size
* N: number of attention heads
* Lq: latent dimension for Q
* Lkv: latent dimension for K/V
* P: nope dimension, P+R is the actual head_dim in common attention.
* R: rope dimension, this slide of the head_dim goes through rope.
* V: V head dim.
* kv_c: latent/compressed KV
* q_c: latent/compressed Q
#
# Outside the MLA attention backend
#
1. The hidden states (B, H) are projected down into cq (B, Lq) and
kv_c_k_pe (B, Lkv+R).
2. The kv_c_k_pe is split into kv_c (B, Lkv) and k_pe (B, R). cq
and kv_c are normalized.
#
# Inside the MLA attention backend
#
* if prefill:
3. The q_c is then projected up into the multi-head version.
* q_c goes from (B, Lq) to (B, N, (P+R)), which is split into q_nope
(B, N, P) and q_pe (B, N, R).
4. q_pe, k_pe are then passed through rotary embeddings.
5. kv_c and k_pe are concatenated and inserted into the cache
6. The kv_c is then projected up into the multi-head version.
* kv_c goes from (B, Lkv) to (B, N, (P+V)) which has the nope
dimensions for K and V, which is split into k_nope (B, N, P)
and v (B, N, V).
7. q (B, N, (P+R)) and k (B, N, (P+R)) matrices are assembled from
q_nope, q_pe, k_nope, k_pe.
8. Attention is computued with q, k, v.
9. The attention computation returns (B, N, V), which is projected back
to (B, H) using out projection.
* if decode:
3. Here's the change, we do not perform up the full up projection for
q_c, and there is no up projection at all for kv_c. This is
achieved by the technique of "weight absorption". The paper says
"Fortunately, due to the associative law of matrix multiplication,
we can absorb WUK into WUQ, and WUV into WO"
* The q up projection turns (B, Lq) into (B, N, (P+R)), we split it
into W_UQ (Lq, N, P) and W_QR (Lq, N, R).
* The kv_c up projection turns (B, Lkv) into (B, N, (P+V)), we split
it into W_UK (Lkv, N, P) and W_UV (Lkv, N, V).
* The out projection shape W_O (N*V, H) turns (B, N, V) into (B, H).
* We can precompute the product of W_UQ and W_UK into
W_UQ_UK (Lq, N, Lkv), which is possible due to QK^T operation in
attention.
* We can precompute the product of W_UV and W_O into
W_UV_O (N, Lkv, H), which is possible due to V@O as the
"epilogue" of attention
4. We still need to compute q_pe (B, N, R) by applying W_QR to q_latent.
5. q_pe, k_pe are then passed through rotary embeddings.
6. kv_c and k_pe are concatenated and inserted into the cache
7. By applying W_UQ_UK to q_latent, we have the new q_nope of shape
(B, N, Lkv).
8. q (B, N, (Lkv+R)), k (B, (Lkv+R)) are assembled from q_nope, q_pe,
kv_a, k_pe. v (B, Lkv) is exactly the same vector as kv_a.
9. The attention is computed with q, k, v. Note that we just performed
a MQA attention with (LKv+R) as our head dim.
10. The KV cache is updated using the new entries k (B, N, (Lkv+R)),
which included the v and rope values.
11. The attention computation returns (B, N, Lkv), which is projected
back to (B, H) using W_UV_O.
From @tsu-bin's calculation, we only want to use the absorption technique
for decode. The prefill algorithm should still use the up-projected MHA
for less flops and memory usage.
"""
def
__init__
(
self
,
num_heads
:
int
,
head_size
:
int
,
scale
:
float
,
num_kv_heads
:
int
,
alibi_slopes
:
Optional
[
List
[
float
]],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]],
logits_soft_cap
:
Optional
[
float
],
attn_type
:
str
,
# MLA Specific Arguments
q_lora_rank
:
Optional
[
int
],
kv_lora_rank
:
int
,
qk_nope_head_dim
:
int
,
qk_rope_head_dim
:
int
,
qk_head_dim
:
int
,
v_head_dim
:
int
,
rotary_emb
:
RotaryEmbedding
,
# q_proj should be q_b_proj if q_lora_rank is not None, but from an
# attention backend perspective we rely on the layer to pass in the
# correct matrix
q_proj
:
ColumnParallelLinear
,
kv_b_proj
:
ColumnParallelLinear
,
o_proj
:
RowParallelLinear
,
)
->
None
:
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
self
.
num_kv_heads
=
num_kv_heads
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
q_lora_rank
=
q_lora_rank
self
.
kv_lora_rank
=
kv_lora_rank
self
.
qk_nope_head_dim
=
qk_nope_head_dim
self
.
qk_rope_head_dim
=
qk_rope_head_dim
self
.
qk_head_dim
=
qk_head_dim
self
.
v_head_dim
=
v_head_dim
self
.
rotary_emb
=
rotary_emb
self
.
use_yarn_rope
=
isinstance
(
rotary_emb
,
DeepseekScalingRotaryEmbedding
)
self
.
q_proj
=
q_proj
self
.
kv_b_proj
=
kv_b_proj
self
.
o_proj
=
o_proj
self
.
vllm_flash_attn_version
=
get_flash_attn_version
()
# Handle the differences between the flash_attn_varlen from flash_attn
# and the one from vllm_flash_attn. The former is used on RoCM and the
# latter has an additional parameter to control FA2 vs FA3
self
.
flash_attn_varlen_func
=
flash_attn_varlen_func
if
self
.
vllm_flash_attn_version
is
not
None
:
self
.
flash_attn_varlen_func
=
\
functools
.
partial
(
flash_attn_varlen_func
,
fa_version
=
self
.
vllm_flash_attn_version
)
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
def
_v_up_proj_and_o_proj
(
self
,
x
):
if
envs
.
VLLM_MLA_PERFORM_MATRIX_ABSORPTION
:
if
is_fp8
(
self
.
W_UV_O
):
output_parallel
=
apply_fp8_linear_generic
(
x
.
flatten
(
start_dim
=
1
),
self
.
W_UV_O
,
self
.
W_UV_O_scales
,
self
.
reqaunt_input_group_shape
,
self
.
reqaunt_weight_group_shape
)
else
:
output_parallel
=
torch
.
matmul
(
x
.
flatten
(
start_dim
=
1
),
self
.
W_UV_O
)
if
self
.
tp_size
>
1
:
output
=
tensor_model_parallel_all_reduce
(
output_parallel
)
else
:
output
=
output_parallel
return
output
else
:
x
=
torch
.
einsum
(
"bnl,lnv->bnv"
,
x
,
self
.
W_UV
)
return
self
.
o_proj
(
x
.
reshape
(
-
1
,
self
.
num_heads
*
self
.
v_head_dim
))[
0
]
def
_q_proj_and_k_up_proj
(
self
,
x
):
if
envs
.
VLLM_MLA_PERFORM_MATRIX_ABSORPTION
:
if
is_fp8
(
self
.
W_Q_UK
):
return
apply_fp8_linear_generic
(
x
,
self
.
W_Q_UK
,
self
.
W_Q_UK_scales
,
self
.
reqaunt_input_group_shape
,
self
.
reqaunt_weight_group_shape
).
view
(
-
1
,
self
.
num_heads
,
self
.
kv_lora_rank
)
return
torch
.
matmul
(
x
,
self
.
W_Q_UK
)
\
.
view
(
-
1
,
self
.
num_heads
,
self
.
kv_lora_rank
)
else
:
x
=
torch
.
matmul
(
x
,
self
.
W_Q
)
\
.
view
(
-
1
,
self
.
num_heads
,
self
.
qk_nope_head_dim
)
return
torch
.
einsum
(
"bnp,lnp->bnl"
,
x
,
self
.
W_UK
)
\
.
view
(
-
1
,
self
.
num_heads
,
self
.
kv_lora_rank
)
def
process_weights_after_loading
(
self
,
act_dtype
:
torch
.
dtype
):
# TODO(lucas) This is very gross, we need a more wide scale refactor of
# all the FP8 code with a more standard way of
# defining schemes/group-shapes, we should also potentially force
# quant_methods to support a decompress function
#
# returns input_group_shape, weight_group_shape
def
get_scale_group_shapes_for_fp8
(
layer
:
LinearBase
)
->
\
Tuple
[
Tuple
[
int
,
int
],
Tuple
[
int
,
int
]]:
if
isinstance
(
layer
.
quant_method
,
Fp8LinearMethod
):
if
layer
.
quant_method
.
block_quant
:
weight_block_size
=
\
layer
.
quant_method
.
quant_config
.
weight_block_size
# per-token-group (1, X), block-quantized (X, Y)
return
(
1
,
weight_block_size
[
-
1
]),
weight_block_size
else
:
return
(
-
1
,
-
1
),
(
-
1
,
-
1
)
# per-tensor, per-tensor
elif
isinstance
(
layer
.
quant_method
,
CompressedTensorsLinearMethod
)
\
and
isinstance
(
layer
.
scheme
,
CompressedTensorsW8A8Fp8
):
# this is hacky but we always assume the for
# CompressedTensorsW8A8Fp8 the input is dynamic per-token
# we ignore if it is static-per-tensor since we are going to
# requantize after later anyways
strategy
=
layer
.
scheme
.
strategy
if
strategy
==
QuantizationStrategy
.
TENSOR
:
return
(
1
,
-
1
),
(
-
1
,
-
1
)
# per-token, per-tensor
elif
strategy
==
QuantizationStrategy
.
CHANNEL
:
return
(
1
,
-
1
),
(
-
1
,
1
)
# per-token, per-channel
else
:
raise
NotImplementedError
(
f
"QuantizationStrategy.
{
strategy
}
is not supported for "
"fp8 MLA, please run with VLLM_MLA_DISABLE=1"
)
else
:
raise
NotImplementedError
(
"Can't determine scale group shapes for "
f
"
{
layer
.
quant_method
}
, please run with VLLM_MLA_DISABLE=1"
)
def
get_layer_weight
(
layer
):
if
hasattr
(
layer
,
"weight"
):
return
layer
.
weight
elif
hasattr
(
layer
,
"qweight"
):
return
layer
.
qweight
else
:
raise
AttributeError
(
f
"Layer '
{
layer
}
' has neither weight nor qweight"
)
def
get_and_maybe_dequant_weights
(
layer
:
LinearBase
):
if
not
isinstance
(
layer
.
quant_method
,
UnquantizedLinearMethod
):
# NOTE: This should only be used offline, since it's O(N^3)
eye
=
torch
.
eye
(
layer
.
input_size_per_partition
,
dtype
=
act_dtype
,
device
=
get_layer_weight
(
layer
).
device
)
dequant_weights
=
layer
.
quant_method
.
apply
(
layer
,
eye
,
bias
=
None
)
del
eye
# standardize to (output, input)
return
dequant_weights
.
T
return
layer
.
weight
weight_dtype
=
get_layer_weight
(
self
.
kv_b_proj
).
dtype
assert
get_layer_weight
(
self
.
o_proj
).
dtype
==
weight_dtype
assert
get_layer_weight
(
self
.
q_proj
).
dtype
==
weight_dtype
if
self
.
use_llama_nn
and
isinstance
(
self
.
kv_b_proj
.
quant_method
,
UnquantizedLinearMethod
):
kv_b_proj_weight
=
get_and_maybe_dequant_weights
(
self
.
kv_b_proj
)
else
:
kv_b_proj_weight
=
get_and_maybe_dequant_weights
(
self
.
kv_b_proj
).
T
assert
kv_b_proj_weight
.
shape
==
(
self
.
kv_lora_rank
,
self
.
num_heads
*
(
self
.
qk_nope_head_dim
+
self
.
v_head_dim
)),
(
f
"
{
kv_b_proj_weight
.
shape
=
}
, "
f
"
{
self
.
kv_lora_rank
=
}
, "
f
"
{
self
.
num_heads
=
}
, "
f
"
{
self
.
qk_nope_head_dim
=
}
, "
f
"
{
self
.
v_head_dim
=
}
"
)
kv_b_proj_weight
=
kv_b_proj_weight
.
view
(
self
.
kv_lora_rank
,
self
.
num_heads
,
self
.
qk_nope_head_dim
+
self
.
v_head_dim
,
)
W_UK
,
W_UV
=
kv_b_proj_weight
.
split
(
[
self
.
qk_nope_head_dim
,
self
.
v_head_dim
],
dim
=-
1
)
if
self
.
use_llama_nn
and
isinstance
(
self
.
q_proj
.
quant_method
,
UnquantizedLinearMethod
):
q_proj_weight
=
get_and_maybe_dequant_weights
(
self
.
q_proj
)
\
.
view
(
-
1
,
self
.
num_heads
,
self
.
qk_head_dim
)
else
:
q_proj_weight
=
get_and_maybe_dequant_weights
(
self
.
q_proj
).
T
\
.
view
(
-
1
,
self
.
num_heads
,
self
.
qk_head_dim
)
# can be W_Q or W_UQ depending q_lora_rank, the former if
# q_lora_rank is None, the latter otherwise. From the Attention backend
# perspective though we call these both W_Q and rely on the layer
# to pass in the correct matrix
W_Q
=
q_proj_weight
[...,
:
self
.
qk_nope_head_dim
]
self
.
W_QR
=
q_proj_weight
[...,
self
.
qk_nope_head_dim
:]
\
.
flatten
(
start_dim
=
1
).
contiguous
()
# W_QR is small so for simplicity we dont bother requantizing it
self
.
W_QR
=
self
.
W_QR
.
to
(
act_dtype
)
if
envs
.
VLLM_MLA_PERFORM_MATRIX_ABSORPTION
:
requantization_enabled
=
not
envs
.
VLLM_MLA_DISABLE_REQUANTIZATION
if
is_fp8
(
weight_dtype
)
and
requantization_enabled
:
# This assumes it wise to requantize using the same group shapes
# (i.e. strategy, per-tensor, per-channel, block etc.) that the
# weights were originally quantized
requant_input_group_shape
,
requant_weight_group_shape
=
\
get_scale_group_shapes_for_fp8
(
self
.
q_proj
)
assert
(
requant_input_group_shape
,
requant_weight_group_shape
)
\
==
get_scale_group_shapes_for_fp8
(
self
.
kv_b_proj
)
assert
(
requant_input_group_shape
,
requant_weight_group_shape
)
\
==
get_scale_group_shapes_for_fp8
(
self
.
o_proj
)
self
.
reqaunt_input_group_shape
=
requant_input_group_shape
self
.
reqaunt_weight_group_shape
=
requant_weight_group_shape
#
# Perform matrix-absorption following
# https://github.com/flashinfer-ai/flashinfer/pull/551
# for decode, as a result we end up with absorbed weights for decode
# and another copy of raw weights for prefill.
#
self
.
W_UK
,
self
.
W_UV
=
kv_b_proj_weight
.
split
(
[
self
.
qk_nope_head_dim
,
self
.
v_head_dim
],
dim
=-
1
)
# We absorb `W_UK` into `W_Q` resulting in either W_Q_UK or W_UQ_UK
# depending q_lora_rank, the former if q_lora_rank is None, the
# latter otherwise
# basically if q_lora_rank is none we are absorbing into q_proj
# instead of UQ
W_Q_UK
=
torch
.
einsum
(
"qnd,lnd -> qnl"
,
W_Q
,
W_UK
)
\
.
flatten
(
start_dim
=
1
).
contiguous
()
if
is_fp8
(
weight_dtype
)
and
requantization_enabled
:
W_Q_UK
,
W_Q_UK_scales
=
scaled_quantize
(
W_Q_UK
,
self
.
reqaunt_weight_group_shape
,
quant_dtype
=
current_platform_fp8_dtype
)
# For FP8 save the transpose so we can use
# `apply_w8a8_block_fp8_linear` directly
self
.
W_Q_UK
=
W_Q_UK
.
T
.
contiguous
()
self
.
W_Q_UK_scales
=
W_Q_UK_scales
.
T
.
contiguous
()
else
:
self
.
W_Q_UK
=
W_Q_UK
.
to
(
act_dtype
)
if
self
.
use_llama_nn
and
isinstance
(
self
.
o_proj
.
quant_method
,
UnquantizedLinearMethod
):
W_O
=
get_and_maybe_dequant_weights
(
self
.
o_proj
).
T
\
.
view
(
-
1
,
self
.
num_heads
,
self
.
v_head_dim
)
else
:
W_O
=
get_and_maybe_dequant_weights
(
self
.
o_proj
)
\
.
view
(
-
1
,
self
.
num_heads
,
self
.
v_head_dim
)
W_UV_O
=
torch
.
einsum
(
"lnd,hnd -> nlh"
,
W_UV
,
W_O
)
\
.
flatten
(
start_dim
=
0
,
end_dim
=
1
).
contiguous
()
if
is_fp8
(
weight_dtype
)
and
requantization_enabled
:
W_UV_O
,
W_UV_O_scales
=
scaled_quantize
(
W_UV_O
,
self
.
reqaunt_weight_group_shape
,
quant_dtype
=
current_platform_fp8_dtype
)
# For FP8 save the transpose so we can use
# `apply_w8a8_block_fp8_linear` directly
self
.
W_UV_O
=
W_UV_O
.
T
.
contiguous
()
self
.
W_UV_O_scales
=
W_UV_O_scales
.
T
.
contiguous
()
else
:
self
.
W_UV_O
=
W_UV_O
.
to
(
act_dtype
)
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
else
:
if
is_fp8
(
weight_dtype
):
raise
NotImplementedError
(
"Currently fp8 requires matrix absorption"
)
self
.
W_UV
=
W_UV
self
.
W_UK
=
W_UK
self
.
W_Q
=
W_Q
.
flatten
(
start_dim
=
1
)
@
abstractmethod
def
_forward_prefill
(
self
,
q
:
torch
.
Tensor
,
kv_c_normed
:
torch
.
Tensor
,
k_pe
:
torch
.
Tensor
,
attn_metadata
:
T
,
)
->
torch
.
Tensor
:
raise
NotImplementedError
@
abstractmethod
def
_forward_decode
(
self
,
q_nope
:
torch
.
Tensor
,
q_pe
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
T
,
)
->
torch
.
Tensor
:
raise
NotImplementedError
def
forward
(
self
,
layer
:
AttentionLayer
,
hidden_states_or_q_c
:
torch
.
Tensor
,
# query in unified attn
k_c_normed
:
torch
.
Tensor
,
# key in unified attn
k_pe
:
torch
.
Tensor
,
# value in unified attn
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
T
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
if
output
is
not
None
:
raise
NotImplementedError
(
"output is not yet supported for MLAImplBase"
)
is_decode
=
attn_metadata
.
decode_metadata
is
not
None
is_prefill
=
attn_metadata
.
prefill_metadata
is
not
None
if
(
is_decode
and
is_prefill
):
raise
NotImplementedError
(
"chunked prefill is not supported for MLAImplBase"
)
# Restore head dim (for rotary embedding)
k_pe
=
k_pe
.
unsqueeze
(
1
)
assert
hasattr
(
attn_metadata
,
"input_positions"
)
if
is_decode
:
q_nope
=
self
.
_q_proj_and_k_up_proj
(
hidden_states_or_q_c
)
q_pe
=
torch
.
matmul
(
hidden_states_or_q_c
,
self
.
W_QR
)
\
.
view
(
-
1
,
self
.
num_heads
,
self
.
qk_rope_head_dim
)
q_pe
,
k_pe
=
self
.
rotary_emb
(
attn_metadata
.
input_positions
,
q_pe
,
k_pe
)
else
:
assert
is_prefill
q
=
self
.
q_proj
(
hidden_states_or_q_c
)[
0
]
\
.
view
(
-
1
,
self
.
num_heads
,
self
.
qk_head_dim
)
# TODO(lucas): there must be a nicer way to write this line
q
[...,
self
.
qk_nope_head_dim
:],
k_pe
=
\
self
.
rotary_emb
(
attn_metadata
.
input_positions
,
q
[...,
self
.
qk_nope_head_dim
:],
k_pe
)
# write the latent and rope to kv cache
if
kv_cache
.
numel
()
>
0
:
ops
.
concat_and_cache_mla
(
k_c_normed
,
k_pe
.
squeeze
(
1
),
kv_cache
,
attn_metadata
.
slot_mapping
.
flatten
(),
kv_cache_dtype
=
self
.
kv_cache_dtype
,
scale
=
layer
.
_k_scale
,
)
if
attn_metadata
.
prefill_metadata
is
not
None
:
return
self
.
_forward_prefill
(
q
,
k_c_normed
,
k_pe
,
attn_metadata
)
if
attn_metadata
.
decode_metadata
is
not
None
:
return
self
.
_forward_decode
(
q_nope
,
q_pe
,
kv_cache
,
attn_metadata
)
# Optional common flash-attn based prefill
def
_forward_prefill_flash
(
self
,
q
:
torch
.
Tensor
,
k_c_normed
:
torch
.
Tensor
,
k_pe
:
torch
.
Tensor
,
seq_start_loc
:
torch
.
Tensor
,
max_prefill_seq_len
:
int
,
)
->
torch
.
Tensor
:
kv_nope
=
self
.
kv_b_proj
(
k_c_normed
)[
0
]
\
.
view
(
-
1
,
self
.
num_heads
,
self
.
qk_nope_head_dim
+
self
.
v_head_dim
)
k_nope
,
v
=
kv_nope
\
.
split
([
self
.
qk_nope_head_dim
,
self
.
v_head_dim
],
dim
=-
1
)
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
dim
=-
1
)
# For MLA the v head dim is smaller than qk head dim so we pad out
# v with 0s to match the qk head dim
# v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
# value=0)
v_padded
=
torch
.
nn
.
functional
.
pad
(
v
,
[
0
,
(
q
.
shape
[
-
1
]
-
v
.
shape
[
-
1
]
-
32
)],
value
=
0
)
if
torch
.
cuda
.
get_device_properties
(
torch
.
cuda
.
current_device
()).
multi_processor_count
==
120
:
attn_output
=
flash_attn_varlen_func
(
q
=
q
,
k
=
k
,
v
=
v_padded
,
cu_seqlens_q
=
seq_start_loc
,
cu_seqlens_k
=
seq_start_loc
,
max_seqlen_q
=
max_prefill_seq_len
,
max_seqlen_k
=
max_prefill_seq_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
)
attn_output
=
attn_output
\
.
view
(
-
1
,
self
.
num_heads
,
q
.
shape
[
-
1
])[...,
:
v
.
shape
[
-
1
]]
\
.
reshape
(
-
1
,
self
.
num_heads
*
v
.
shape
[
-
1
])
else
:
attn_output
=
flash_attn_varlen_func
(
q
=
q
,
k
=
k
,
v
=
v
,
cu_seqlens_q
=
seq_start_loc
,
cu_seqlens_k
=
seq_start_loc
,
max_seqlen_q
=
max_prefill_seq_len
,
max_seqlen_k
=
max_prefill_seq_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
)
attn_output
=
attn_output
\
.
reshape
(
-
1
,
self
.
num_heads
*
v
.
shape
[
-
1
])
return
self
.
o_proj
(
attn_output
)[
0
]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment