Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ac4dedb1
Commit
ac4dedb1
authored
Jul 07, 2025
by
王敏
Browse files
[feat]支持v1 engine flashmla和mtp同时开启
parent
9e27b5e4
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
134 additions
and
26 deletions
+134
-26
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+95
-11
vllm/v1/attention/backends/utils.py
vllm/v1/attention/backends/utils.py
+2
-1
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+34
-13
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+3
-0
vllm/worker/worker_base.py
vllm/worker/worker_base.py
+0
-1
No files found.
vllm/v1/attention/backends/mla/common.py
View file @
ac4dedb1
...
...
@@ -161,7 +161,7 @@ curr_o, curr_lse = scaled_dot_product_attention(
for chunk_idx in range(cdiv(C, MCC)):
chunk_start = chunk_idx * MCC
chunk_end = min(chunk_start + MCC, C)
Sc = chunk_end - chunk_start
Sc = chunk_end - chunk_start
_table
cache_kv_c_chunk = cache_kv_c[chunk_start:chunk_end]
cache_k_pe_chunk = cache_k_pe[chunk_start:chunk_end]
cache_k_nope_chunk = (cache_kv_c_chunk @ W_UK).view(-1, N, P)
...
...
@@ -191,6 +191,9 @@ import functools
from
abc
import
abstractmethod
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
Generic
,
Optional
,
TypeVar
from
itertools
import
chain
import
numpy
as
np
import
torch
import
os
...
...
@@ -208,10 +211,10 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearBase
,
UnquantizedLinearMethod
)
from
vllm.platforms
import
current_platform
from
vllm.utils
import
cdiv
,
round_down
from
vllm.utils
import
cdiv
,
round_down
,
is_pin_memory_available
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.worker.block_table
import
BlockTable
from
vllm.v1.worker.block_table
import
BlockTable
s
try
:
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
...
...
@@ -385,6 +388,32 @@ class MLACommonMetadataBuilder(Generic[M]):
)
self
.
block_table
=
block_table
self
.
_use_spec_decode
=
False
self
.
pin_memory
=
is_pin_memory_available
()
self
.
_num_scheduled_tokens
=
torch
.
zeros
(
scheduler_config
.
max_num_seqs
,
dtype
=
torch
.
int32
,
device
=
runner
.
device
)
self
.
_num_scheduled_tokens_cpu_tensor
=
torch
.
zeros
(
(
scheduler_config
.
max_num_seqs
,
),
device
=
"cpu"
,
dtype
=
torch
.
int32
,
pin_memory
=
self
.
pin_memory
,
)
self
.
_num_scheduled_tokens_np
=
self
.
_num_scheduled_tokens_cpu_tensor
.
numpy
()
self
.
_seq_lens_minus
=
torch
.
zeros
(
scheduler_config
.
max_num_seqs
*
5
,
dtype
=
torch
.
int32
,
device
=
runner
.
device
)
self
.
_seq_lens_minus_cpu_tensor
=
torch
.
zeros
(
(
scheduler_config
.
max_num_seqs
*
5
,
),
device
=
"cpu"
,
dtype
=
torch
.
int32
,
pin_memory
=
self
.
pin_memory
,
)
self
.
_seq_lens_minus_np
=
self
.
_seq_lens_minus_cpu_tensor
.
numpy
()
def
reorder_batch
(
self
,
input_batch
:
"InputBatch"
,
scheduler_output
:
"SchedulerOutput"
)
->
bool
:
# We now want to reorder the batch so that the "decode" requests are and
...
...
@@ -397,6 +426,8 @@ class MLACommonMetadataBuilder(Generic[M]):
prefills
=
[]
num_decode_tokens
=
0
num_prefill_tokens
=
0
use_spec_decode
=
len
(
scheduler_output
.
scheduled_spec_decode_tokens
)
>
0
for
i
,
req_id
in
enumerate
(
input_batch
.
req_ids
):
num_tokens
=
scheduler_output
.
num_scheduled_tokens
[
req_id
]
...
...
@@ -404,12 +435,23 @@ class MLACommonMetadataBuilder(Generic[M]):
# we should update this to something like < 8 in the future but
# currently the TritonMLA._forward_decode only supports
# num_tokens = 1
if
num_tokens
==
1
:
decodes
.
append
(
i
)
num_decode_tokens
+=
num_tokens
else
:
# if num_tokens == 2 or num_tokens == 1:
# decodes.append(i)
# num_decode_tokens += num_tokens
# else:
# prefills.append(i)
# num_prefill_tokens += num_tokens
req_idx
=
input_batch
.
req_id_to_index
[
req_id
]
num_computed_tokens
=
input_batch
.
num_computed_tokens_cpu
[
req_idx
]
num_prompt_tokens
=
input_batch
.
num_prompt_tokens
[
req_idx
]
self
.
_num_scheduled_tokens_np
[
i
]
=
num_tokens
if
num_computed_tokens
<
num_prompt_tokens
:
prefills
.
append
(
i
)
num_prefill_tokens
+=
num_tokens
else
:
decodes
.
append
(
i
)
num_decode_tokens
+=
num_tokens
# We hope that this is fairly minimal since decodes
# should be around for a number of iterations so hopefully they are
...
...
@@ -435,6 +477,11 @@ class MLACommonMetadataBuilder(Generic[M]):
input_batch
.
swap_states
(
prefills
[
i
-
1
],
decode_idx
)
modified_batch
=
True
# num_scheduled_tokens also need to be swapped
tmp
=
self
.
_num_scheduled_tokens_np
[
decode_idx
]
self
.
_num_scheduled_tokens_np
[
decode_idx
]
=
self
.
_num_scheduled_tokens_np
[
prefills
[
i
-
1
]]
self
.
_num_scheduled_tokens_np
[
prefills
[
i
-
1
]]
=
tmp
# Save for next `build` call
# TODO(lucas): this is a bit of a hack, we should probably have a
# better way of doing this
...
...
@@ -443,6 +490,12 @@ class MLACommonMetadataBuilder(Generic[M]):
self
.
_num_decode_tokens
=
num_decode_tokens
self
.
_num_prefill_tokens
=
num_prefill_tokens
self
.
_use_spec_decode
=
use_spec_decode
if
use_spec_decode
:
self
.
_num_scheduled_tokens
[:
len
(
input_batch
.
req_ids
)].
copy_
(
self
.
_num_scheduled_tokens_cpu_tensor
[:
len
(
input_batch
.
req_ids
)],
non_blocking
=
True
)
return
modified_batch
def
_build_decode
(
self
,
block_table_tensor
:
torch
.
Tensor
,
...
...
@@ -548,6 +601,37 @@ class MLACommonMetadataBuilder(Generic[M]):
decode_metadata
=
None
if
self
.
_num_decodes
>
0
:
if
self
.
_use_spec_decode
:
# generate block_table/seq_lens of mla in spec decoding scenarios
if
common_attn_metadata
.
num_rejected_tokens_tuple
is
None
:
repeats
=
self
.
_num_scheduled_tokens
[:
self
.
_num_decodes
]
repeats_cpu
=
self
.
_num_scheduled_tokens_np
[:
self
.
_num_decodes
]
else
:
repeats
=
self
.
_num_scheduled_tokens
[:
self
.
_num_decodes
]
-
\
common_attn_metadata
.
num_rejected_tokens_tuple
[
1
][:
self
.
_num_decodes
]
num_rejected_tokens
=
common_attn_metadata
.
num_rejected_tokens_tuple
[
0
][:
self
.
_num_decodes
]
repeats_cpu
=
self
.
_num_scheduled_tokens_np
[:
self
.
_num_decodes
]
-
\
np
.
array
(
num_rejected_tokens
)
self
.
_num_decode_tokens
-=
sum
(
num_rejected_tokens
)
decode_block_table_tensor
=
torch
.
repeat_interleave
(
block_table_tensor
[:
self
.
_num_decodes
,
...],
repeats
,
dim
=
0
)
total_decode_tokens
=
np
.
sum
(
repeats_cpu
)
decode_seq_lens
=
torch
.
repeat_interleave
(
seq_lens
[:
self
.
_num_decodes
],
repeats
,
dim
=
0
)
self
.
_seq_lens_minus_np
[:
total_decode_tokens
]
=
np
.
fromiter
(
chain
.
from_iterable
(
np
.
flip
(
np
.
arange
(
x
))
for
x
in
repeats_cpu
),
dtype
=
int
)
self
.
_seq_lens_minus
[:
total_decode_tokens
].
copy_
(
self
.
_seq_lens_minus_cpu_tensor
[:
total_decode_tokens
],
non_blocking
=
True
)
decode_seq_lens
=
decode_seq_lens
-
self
.
_seq_lens_minus
[:
total_decode_tokens
]
decode_metadata
=
self
.
_build_decode
(
block_table_tensor
=
decode_block_table_tensor
,
seq_lens
=
decode_seq_lens
,
)
else
:
decode_metadata
=
self
.
_build_decode
(
block_table_tensor
=
block_table_tensor
[:
self
.
_num_decodes
,
...],
seq_lens
=
seq_lens
[:
self
.
_num_decodes
],
...
...
vllm/v1/attention/backends/utils.py
View file @
ac4dedb1
...
...
@@ -17,7 +17,8 @@ class CommonAttentionMetadata:
seq_lens
:
torch
.
Tensor
"""(batch_size,), the length of each request including both computed tokens
and newly scheduled tokens"""
num_rejected_tokens_tuple
:
tuple
[
list
[
int
],
torch
.
Tensor
]
=
None
"""(batch_size,), record the rejected tokens number in cpu and gpu"""
def
validate_kv_sharing_target
(
current_layer_name
,
target_layer_name
,
static_forward_context
):
...
...
vllm/v1/spec_decode/eagle.py
View file @
ac4dedb1
...
...
@@ -14,6 +14,7 @@ from vllm.model_executor.models import supports_multimodal
from
vllm.model_executor.models.llama_eagle3
import
Eagle3LlamaForCausalLM
from
vllm.v1.attention.backends.flash_attn
import
(
CommonAttentionMetadata
,
FlashAttentionMetadata
)
from
vllm.v1.attention.backends.mla.common
import
MLACommonMetadata
,
MLACommonDecodeMetadata
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.spec_decode.utils
import
prepare_eagle_input_kernel
...
...
@@ -91,7 +92,9 @@ class EagleProposer:
cu_num_tokens
:
torch
.
Tensor
,
# [batch_size, max_num_blocks_per_req]
block_table
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
# [batch_size]
num_rejected_tokens_tuple
:
tuple
[
list
[
int
],
torch
.
Tensor
],
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
num_tokens
=
target_token_ids
.
shape
[
0
]
batch_size
=
next_token_ids
.
shape
[
0
]
...
...
@@ -138,7 +141,9 @@ class EagleProposer:
max_query_len
=
query_lens
.
max
().
item
()
common_attn_metadata
=
CommonAttentionMetadata
(
query_start_loc
=
cu_num_tokens
,
seq_lens
=
seq_lens
)
query_start_loc
=
cu_num_tokens
,
seq_lens
=
seq_lens
,
num_rejected_tokens_tuple
=
num_rejected_tokens_tuple
)
assert
self
.
runner
is
not
None
...
...
@@ -210,6 +215,17 @@ class EagleProposer:
attn_metadata
.
num_actual_tokens
=
batch_size
attn_metadata
.
max_query_len
=
1
attn_metadata
.
query_start_loc
=
self
.
arange
[:
batch_size
+
1
]
if
isinstance
(
attn_metadata
,
MLACommonMetadata
):
attn_metadata
.
num_decodes
=
batch_size
attn_metadata
.
num_decode_tokens
=
batch_size
attn_metadata
.
num_prefills
=
0
block_table
=
self
.
runner
.
attn_metadata_builders
[
0
].
block_table
.
get_device_tensor
()[:
batch_size
,
...]
attn_metadata
.
decode
=
self
.
runner
.
attn_metadata_builders
[
0
].
_build_decode
(
block_table_tensor
=
block_table
,
seq_lens
=
(
seq_lens
+
1
),
)
for
_
in
range
(
self
.
num_speculative_tokens
-
1
):
# Update the inputs.
# cast to int32 is crucial when eagle model is compiled.
...
...
@@ -229,12 +245,17 @@ class EagleProposer:
clamped_positions
=
torch
.
where
(
exceeds_max_model_len
,
0
,
positions
)
if
isinstance
(
attn_metadata
,
MLACommonMetadata
):
attn_metadata
.
decode
.
seq_lens
+=
1
else
:
attn_metadata
.
seq_lens
+=
1
# Increment the sequence lengths.
attn_metadata
.
max_seq_len
+=
1
attn_metadata
.
seq_lens
+=
1
# Consider max model length.
attn_metadata
.
max_seq_len
=
min
(
attn_metadata
.
max_seq_len
,
self
.
max_model_len
)
# For the requests that exceed the max model length, we set the
# sequence length to 1 to minimize their overheads in attention.
attn_metadata
.
seq_lens
.
masked_fill_
(
exceeds_max_model_len
,
1
)
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
ac4dedb1
...
...
@@ -1441,6 +1441,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
else
:
block_table
=
None
num_rejected_tokens_tuple
=
None
if
spec_decode_metadata
is
None
:
# input_ids can be None for multimodal models.
target_token_ids
=
self
.
input_ids
[:
num_scheduled_tokens
]
...
...
@@ -1480,6 +1481,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
target_hidden_states
=
hidden_states
[
token_indices
]
target_slot_mapping
=
eagle_attn_metadata
.
slot_mapping
[
token_indices
]
num_rejected_tokens_tuple
=
(
num_rejected_tokens
,
num_rejected_tokens_tensor
)
draft_token_ids
=
self
.
drafter
.
propose
(
target_token_ids
=
target_token_ids
,
target_positions
=
target_positions
,
...
...
@@ -1489,6 +1491,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
cu_num_tokens
=
cu_num_tokens
,
block_table
=
block_table
,
sampling_metadata
=
sampling_metadata
,
num_rejected_tokens_tuple
=
num_rejected_tokens_tuple
)
spec_token_ids
=
draft_token_ids
.
tolist
()
...
...
vllm/worker/worker_base.py
View file @
ac4dedb1
...
...
@@ -28,7 +28,6 @@ from vllm.worker.model_runner_base import (BroadcastableModelInput,
ModelRunnerBase
,
ModelRunnerInputBase
)
torch
.
_C
.
_set_blas_preferred_backend
(
torch
.
_C
.
_BlasBackend
.
Cublas
)
logger
=
init_logger
(
__name__
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment