Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f6aa3d19
Commit
f6aa3d19
authored
Dec 05, 2025
by
zhuwenwen
Browse files
Merge branch 'v0.11.0-dev-wm-1205' into 'v0.11.0-dev'
去掉无效代码 See merge request dcutoolkit/deeplearing/vllm!285
parents
b8412df6
7343379a
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
27 additions
and
114 deletions
+27
-114
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+17
-103
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+10
-11
No files found.
vllm/v1/attention/backends/mla/common.py
View file @
f6aa3d19
...
...
@@ -186,17 +186,13 @@ for chunk_idx in range(cdiv(C, MCC)):
return curr_o @ W_O
"""
import
os
import
functools
from
abc
import
abstractmethod
import
numpy
as
np
from
dataclasses
import
dataclass
,
field
from
typing
import
Generic
,
Optional
,
TypeVar
,
Union
import
torch
import
os
from
tqdm
import
tqdm
import
vllm.envs
as
envs
...
...
@@ -558,14 +554,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
device
=
device
,
)
self
.
block_table
=
block_table
self
.
use_spec_decode
=
False
# support for cudagraph spec docoding
self
.
spec_decode_block_table_tensor
=
None
self
.
spec_decode_seq_lens
=
None
def
_build_fi_prefill_wrappers
(
self
,
prefill
:
FlashInferPrefillMetadata
):
qo_indptr
=
prefill
.
query_start_loc
...
...
@@ -659,31 +647,13 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
Currently, only decode is supported for full cudagraphs with MLA.
"""
m
=
common_attn_metadata
# assert m.num_reqs <= (m.num_actual_tokens *
# self.reorder_batch_threshold), \
# "MLA only supports decode-only full CUDAGraph capture. " \
# "Make sure all cudagraph capture sizes <= max_num_seq."
# assert m.max_query_len <= self.reorder_batch_threshold # decode only
self
.
use_spec_decode
=
m
.
num_speculative_tokens
>
0
# support for cudagraph spec docoding
if
self
.
use_spec_decode
:
for
i
in
range
(
m
.
num_reqs
):
self
.
num_scheduled_tokens_np
[
i
]
=
m
.
num_actual_tokens
//
m
.
num_reqs
if
self
.
spec_decode_block_table_tensor
is
None
:
max_num_reqs
=
m
.
seq_lens
.
shape
[
0
]
block_table_tensor
=
self
.
block_table
.
get_device_tensor
()
tokens_per_seq
=
1
+
m
.
num_speculative_tokens
self
.
spec_decode_block_table_tensor
=
torch
.
zeros
((
block_table_tensor
.
shape
[
0
]
*
tokens_per_seq
,
block_table_tensor
.
shape
[
1
]),
dtype
=
block_table_tensor
.
dtype
,
device
=
m
.
seq_lens
.
device
)
self
.
spec_decode_seq_lens
=
torch
.
zeros
(
max_num_reqs
*
tokens_per_seq
,
dtype
=
m
.
seq_lens
.
dtype
,
device
=
m
.
seq_lens
.
device
)
assert
m
.
num_reqs
<=
(
m
.
num_actual_tokens
*
self
.
reorder_batch_threshold
),
\
"MLA only supports decode-only full CUDAGraph capture. "
\
"Make sure all cudagraph capture sizes <= max_num_seq."
assert
m
.
max_query_len
<=
self
.
reorder_batch_threshold
# decode only
return
self
.
build
(
0
,
m
)
def
build
(
self
,
...
...
@@ -699,15 +669,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
# function. We should avoid GPU -> CPU sync as much as possible because
# it blocks on all previous kernels.
device
=
self
.
device
block_table
=
self
.
block_table
block_table_tensor
=
common_attn_metadata
.
block_table_tensor
slot_mapping
=
common_attn_metadata
.
slot_mapping
if
slot_mapping
is
None
:
block_table
.
slot_mapping
[:
num_tokens
].
copy_
(
block_table
.
slot_mapping_cpu
[:
num_tokens
],
non_blocking
=
True
)
block_table
.
slot_mapping
[
num_tokens
:].
fill_
(
-
1
)
slot_mapping
=
block_table
.
slot_mapping
[:
num_tokens
]
query_start_loc
=
common_attn_metadata
.
query_start_loc
query_start_loc_cpu
=
common_attn_metadata
.
query_start_loc_cpu
...
...
@@ -873,65 +836,16 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
-
prefill_query_start_loc
[:
-
1
]
prefill_metadata
.
cudnn_workspace
=
self
.
cudnn_workspace
# TODO @ wangming
decode_metadata
=
None
# if num_decodes > 0:
# if self.use_spec_decode and not common_attn_metadata.spec_layer_decoding:
# query_lens = self.num_scheduled_tokens_np[:num_decodes]
# cu_num_blocks = np.cumsum(query_lens)
# virtual_batches = cu_num_blocks[-1]
# block_offsets = np.repeat(cu_num_blocks - query_lens, query_lens)
# arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets
# rarange = np.repeat(query_lens, query_lens) - arange - 1
# repeats = torch.from_numpy(query_lens).pin_memory().to(
# block_table_tensor.device, non_blocking=True).contiguous()
# decode_block_table_tensor = torch.repeat_interleave(
# block_table_tensor[:self._num_decodes, ...],
# repeats, dim=0).contiguous()
# decode_seq_lens = torch.repeat_interleave(seq_lens[:self._num_decodes], repeats, dim=0).contiguous()
# seq_lens_minus = torch.from_numpy(rarange).to(torch.int32).pin_memory().to(
# seq_lens.device, non_blocking=True).contiguous()
# decode_seq_lens = decode_seq_lens - seq_lens_minus
# if self.spec_decode_block_table_tensor is not None:
# self.spec_decode_block_table_tensor[:self._num_decode_tokens].copy_(decode_block_table_tensor)
# self.spec_decode_seq_lens[:self._num_decode_tokens].copy_(decode_seq_lens)
# decode_metadata = self._build_decode(
# block_table_tensor=self.spec_decode_block_table_tensor[:self._num_decode_tokens, ...],
# seq_lens=self.spec_decode_seq_lens[:self._num_decode_tokens],
# )
# else:
# decode_metadata = self._build_decode(
# block_table_tensor=decode_block_table_tensor,
# seq_lens=decode_seq_lens,
# )
# else:
# self._num_decode_tokens = num_decodes
# if self.use_spec_decode and self.spec_decode_block_table_tensor is not None:
# self.spec_decode_block_table_tensor[:self._num_decode_tokens].copy_(block_table_tensor[:self._num_decode_tokens, ...])
# self.spec_decode_seq_lens[:self._num_decode_tokens].copy_(seq_lens[:self._num_decode_tokens])
# decode_metadata = self._build_decode(
# block_table_tensor=self.spec_decode_block_table_tensor[:self._num_decode_tokens, ...],
# seq_lens=self.spec_decode_seq_lens[:self._num_decode_tokens],
# )
# else:
# decode_metadata = self._build_decode(
# block_table_tensor=block_table_tensor[:self._num_decode_tokens, ...],
# seq_lens=seq_lens[:self._num_decode_tokens],
# )
decode_metadata
=
self
.
_build_decode
(
block_table_tensor
=
block_table_tensor
[:
num_decodes
,
...],
seq_lens_cpu
=
seq_lens_cpu
[:
num_decodes
],
seq_lens_device
=
seq_lens
[:
num_decodes
],
query_start_loc_cpu
=
query_start_loc_cpu
[:
num_decodes
+
1
],
query_start_loc_device
=
query_start_loc
[:
num_decodes
+
1
],
num_decode_tokens
=
num_decode_tokens
,
)
if
num_decodes
>
0
:
decode_metadata
=
self
.
_build_decode
(
block_table_tensor
=
block_table_tensor
[:
num_decodes
,
...],
seq_lens_cpu
=
seq_lens_cpu
[:
num_decodes
],
seq_lens_device
=
seq_lens
[:
num_decodes
],
query_start_loc_cpu
=
query_start_loc_cpu
[:
num_decodes
+
1
],
query_start_loc_device
=
query_start_loc
[:
num_decodes
+
1
],
num_decode_tokens
=
num_decode_tokens
,
)
attn_metadata
=
self
.
metadata_cls
(
num_reqs
=
common_attn_metadata
.
num_reqs
,
...
...
vllm/v1/spec_decode/eagle.py
View file @
f6aa3d19
...
...
@@ -131,15 +131,15 @@ class EagleProposer:
with_numpy
=
True
)
# Determine allowed attention backends once during initialization.
self
.
allowed_attn_types
:
Optional
[
tuple
]
=
None
if
current_platform
.
is_rocm
():
rocm_types
=
[
TritonAttentionMetadata
,
FlashAttentionMetadata
]
# vllm.v1.attention.backends.rocm_aiter_fa is an optional backend
if
find_spec
(
"vllm.v1.attention.backends.rocm_aiter_fa"
):
from
vllm.v1.attention.backends.rocm_aiter_fa
import
(
AiterFlashAttentionMetadata
)
rocm_types
.
append
(
AiterFlashAttentionMetadata
)
self
.
allowed_attn_types
=
tuple
(
rocm_types
)
#
self.allowed_attn_types: Optional[tuple] = None
#
if current_platform.is_rocm():
#
rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata]
#
# vllm.v1.attention.backends.rocm_aiter_fa is an optional backend
#
if find_spec("vllm.v1.attention.backends.rocm_aiter_fa"):
#
from vllm.v1.attention.backends.rocm_aiter_fa import (
#
AiterFlashAttentionMetadata)
#
rocm_types.append(AiterFlashAttentionMetadata)
#
self.allowed_attn_types = tuple(rocm_types)
# Parse the speculative token tree.
spec_token_tree
=
self
.
speculative_config
.
speculative_token_tree
...
...
@@ -273,8 +273,7 @@ class EagleProposer:
with
set_forward_context
(
per_layer_attn_metadata
,
self
.
vllm_config
,
num_tokens
=
num_input_tokens
,
skip_cuda_graphs
=
not
decoding
):
num_tokens
=
num_input_tokens
):
ret_hidden_states
=
self
.
model
(
input_ids
=
input_ids
,
positions
=
self
.
positions
[:
num_input_tokens
],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment