Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fb0acb6c
Unverified
Commit
fb0acb6c
authored
Mar 10, 2025
by
Simon Mo
Committed by
GitHub
Mar 10, 2025
Browse files
[Perf] Improve MLA on V1 (#14540)
Signed-off-by:
simon-mo
<
simon.mo@hey.com
>
parent
92b0ce2a
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
41 additions
and
27 deletions
+41
-27
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+41
-27
No files found.
vllm/v1/attention/backends/mla/common.py
View file @
fb0acb6c
...
...
@@ -223,6 +223,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
scaled_quantize
)
from
vllm.model_executor.layers.rotary_embedding
import
RotaryEmbedding
from
vllm.platforms
import
current_platform
from
vllm.utils
import
cdiv
,
round_down
try
:
...
...
@@ -471,18 +472,23 @@ class MLACommonMetadataBuilder(Generic[M]):
common_prefix_len
:
int
)
->
M
:
assert
self
.
_num_decodes
+
self
.
_num_prefills
==
num_reqs
# Note(simon): be careful about the CPU <> GPU memory movement in this
# function. We should avoid GPU -> CPU sync as much as possible because
# it blocks on all previous kernels.
device
=
self
.
runner
.
device
query_start_loc
=
self
.
runner
.
query_start_loc_cpu
[:
num_reqs
+
1
].
to
(
device
,
non_blocking
=
True
)
seq_lens
=
self
.
runner
.
seq_lens_cpu
[:
num_reqs
].
to
(
device
,
non_blocking
=
True
)
block_table
=
(
self
.
runner
.
input_batch
.
block_table
.
get_device_tensor
()[:
num_reqs
])
query_start_loc
=
self
.
runner
.
query_start_loc_cpu
[:
num_reqs
+
1
].
to
(
device
,
non_blocking
=
True
)
slot_mapping
=
self
.
runner
.
slot_mapping_cpu
[:
num_actual_tokens
].
to
(
device
,
non_blocking
=
True
).
long
()
input_positions
=
self
.
runner
.
positions_cpu
[:
num_actual_tokens
].
to
(
device
,
non_blocking
=
True
).
long
()
seq_lens_cpu
=
self
.
runner
.
seq_lens_cpu
[:
num_reqs
]
seq_lens
=
seq_lens_cpu
.
to
(
device
,
non_blocking
=
True
)
max_query_len
=
seq_lens_cpu
.
max
().
item
()
prefill_metadata
=
None
if
self
.
_num_prefills
>
0
:
reqs_start
=
self
.
_num_decodes
# prefill_start
...
...
@@ -490,24 +496,22 @@ class MLACommonMetadataBuilder(Generic[M]):
context_lens_cpu
=
self
.
runner
.
input_batch
.
\
num_computed_tokens_cpu_tensor
[
reqs_start
:
num_reqs
]
context_lens
=
context_lens_cpu
.
to
(
device
,
non_blocking
=
True
)
max_context_len_cpu
=
context_lens_cpu
.
max
().
item
()
num_prefills_with_context_cpu
=
(
context_lens_cpu
>
0
).
sum
().
item
()
chunked_context_metadata
=
None
if
self
.
chunked_prefill_enabled
and
self
.
_num_prefills
>
0
\
and
context_len
s
.
max
()
>
0
:
and
max_
context_len
_cpu
>
0
:
# NOTE: it is recommend you read the `Chunked Prefill` section
# in the comment at the top of the file before trying to
# understand the following code
num_prefills_with_context
=
(
context_lens
>
0
).
sum
().
item
()
# currently we allocate an equal amount of workspace for each
# prefill in the batch, we could probably use a more advanced
# algorithm here and allocate more workspace to prefills with
# longer context lengths
max_context_chunk
=
\
self
.
chunked_prefill_workspace_size
\
//
num_prefills_with_context
max_context_chunk
=
(
self
.
chunked_prefill_workspace_size
//
num_prefills_with_context_cpu
)
# align max_context_chunk to page_size by rounding down,
# currently the `gather_cache` kernel cannot handle
...
...
@@ -516,30 +520,35 @@ class MLACommonMetadataBuilder(Generic[M]):
self
.
page_size
)
assert
max_context_chunk
>
0
num_chunks
=
cdiv
(
context_len
s
.
max
()
,
max_context_chunk
)
num_chunks
=
cdiv
(
max_
context_len
_cpu
,
max_context_chunk
)
# if `max_context_chunk = 256`, `num_chunks = 3`, and
# `num_prefills_with_context = 4`, create a tensor that looks
# like
# [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]]
# Note(simon): this is done in CPU because of downstream's
# of `to_list`.
chunk_starts
=
\
torch
.
arange
(
num_chunks
,
device
=
device
,
dtype
=
torch
.
int32
)
\
torch
.
arange
(
num_chunks
,
dtype
=
torch
.
int32
)
\
.
unsqueeze
(
1
).
expand
(
-
1
,
self
.
_num_prefills
)
\
*
max_context_chunk
chunk_ends
=
torch
.
min
(
context_lens
.
unsqueeze
(
0
),
chunk_ends
=
torch
.
min
(
context_lens
_cpu
.
unsqueeze
(
0
),
chunk_starts
+
max_context_chunk
)
chunk_seq_lens
=
(
chunk_ends
-
chunk_starts
).
clamp
(
min
=
0
)
_chunk_cu_seq_lens
=
chunk_seq_lens
.
cumsum
(
dim
=
1
).
to
(
torch
.
int32
)
zero
=
torch
.
zeros
(
num_chunks
,
cu_seq_lens_cpu
=
torch
.
zeros
(
num_chunks
,
self
.
_num_prefills
+
1
,
dtype
=
torch
.
int32
,
device
=
device
).
unsqueeze
(
-
1
)
pin_memory
=
True
)
torch
.
cumsum
(
chunk_seq_lens
,
dim
=
1
,
out
=
cu_seq_lens_cpu
[:,
1
:],
dtype
=
torch
.
int32
)
chunked_context_metadata
=
\
MLACommonPrefillMetadata
.
ChunkedContextMetadata
(
cu_seq_lens
=
torch
.
cat
(
[
zero
,
_chunk_cu_seq_lens
],
dim
=
1
),
starts
=
chunk_starts
,
cu_seq_lens
=
cu_seq_lens_cpu
.
to
(
device
,
non_blocking
=
True
),
starts
=
chunk_starts
.
to
(
device
,
non_blocking
=
True
),
seq_tot
=
chunk_seq_lens
.
sum
(
dim
=
1
).
tolist
(),
max_seq_lens
=
chunk_seq_lens
.
max
(
dim
=
1
).
values
.
tolist
(),
workspace
=
self
.
chunked_prefill_workspace
,
...
...
@@ -553,7 +562,7 @@ class MLACommonMetadataBuilder(Generic[M]):
block_table
=
block_table
[
reqs_start
:,
...],
query_start_loc
=
query_start_loc
[
reqs_start
:]
-
query_start_loc
[
reqs_start
],
max_query_len
=
seq_lens
[
reqs_start
:].
max
().
item
()
,
max_query_len
=
max_query_len
,
chunked_context
=
chunked_context_metadata
,
)
...
...
@@ -629,7 +638,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
# already inside an attention custom op), pull out the forward
# method from the rotary embedding and call it directly
# TODO(lucas): we should probably find a cleaner way to do this
self
.
rotary_emb
=
rotary_emb
.
_forward_method
self
.
rotary_emb
=
rotary_emb
.
forward_native
if
current_platform
.
is_cuda
():
self
.
rotary_emb
=
rotary_emb
.
forward_cuda
self
.
q_proj
=
q_proj
self
.
kv_b_proj
=
kv_b_proj
...
...
@@ -1043,17 +1054,20 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
decode_q_nope
=
self
.
_q_proj_and_k_up_proj
(
decode_hs_or_q_c
)
decode_q_pe
=
torch
.
matmul
(
decode_hs_or_q_c
,
self
.
W_QR
)
\
.
view
(
-
1
,
self
.
num_heads
,
self
.
qk_rope_head_dim
)
decode_q_pe
[...],
decode_k_pe
[...]
=
self
.
rotary_emb
(
attn_metadata
.
decode
.
input_positions
,
decode_q_pe
,
decode_k_pe
)
attn_metadata
.
decode
.
input_positions
,
decode_q_pe
.
contiguous
(),
decode_k_pe
)
if
has_prefill
:
assert
attn_metadata
.
prefill
is
not
None
prefill_q
=
self
.
q_proj
(
prefill_hs_or_q_c
)[
0
]
\
.
view
(
-
1
,
self
.
num_heads
,
self
.
qk_head_dim
)
prefill_q_pe
=
prefill_q
[...,
self
.
qk_nope_head_dim
:]
prefill_q_pe
[...],
prefill_k_pe
[...]
=
self
.
rotary_emb
(
attn_metadata
.
prefill
.
input_positions
,
prefill_q_pe
,
prefill_k_pe
)
attn_metadata
.
prefill
.
input_positions
,
prefill_q_pe
.
contiguous
(),
prefill_k_pe
)
# write the latent and rope to kv cache
if
kv_cache
.
numel
()
>
0
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment