Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8c552f40
Commit
8c552f40
authored
Nov 08, 2025
by
王敏
Browse files
[fix]解决开启mtp后,在极端情况碰到显存不足时,导致mla中申请的tensor数据错乱问题
parent
2c8a16d6
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
42 additions
and
11 deletions
+42
-11
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+42
-11
No files found.
vllm/v1/attention/backends/mla/common.py
View file @
8c552f40
...
@@ -212,7 +212,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...
@@ -212,7 +212,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearBase
,
LinearBase
,
UnquantizedLinearMethod
)
UnquantizedLinearMethod
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
cdiv
,
round_down
from
vllm.utils
import
cdiv
,
round_down
,
is_pin_memory_available
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
CommonAttentionMetadata
)
CommonAttentionMetadata
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.kv_cache_interface
import
AttentionSpec
...
@@ -399,18 +399,41 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
...
@@ -399,18 +399,41 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
self
.
block_table
=
block_table
self
.
block_table
=
block_table
self
.
use_spec_decode
=
False
self
.
use_spec_decode
=
False
self
.
num_scheduled_tokens_np
=
np
.
zeros
(
scheduler_config
.
max_num_seqs
,
dtype
=
np
.
int32
)
self
.
decode_token_num_threshold
=
1
# support for cudagraph spec docoding
self
.
spec_decode_block_table_tensor
=
None
self
.
spec_decode_seq_lens
=
None
self
.
decode_token_num_threshold
=
1
vllm_config
=
get_current_vllm_config
()
vllm_config
=
get_current_vllm_config
()
speculative_config
=
vllm_config
.
speculative_config
speculative_config
=
vllm_config
.
speculative_config
if
speculative_config
and
speculative_config
.
num_speculative_tokens
>
1
:
if
speculative_config
and
speculative_config
.
num_speculative_tokens
>
1
:
self
.
use_spec_decode
=
True
self
.
use_spec_decode
=
True
self
.
decode_token_num_threshold
=
1
+
speculative_config
.
num_speculative_tokens
self
.
decode_token_num_threshold
=
1
+
speculative_config
.
num_speculative_tokens
self
.
device
=
self
.
runner
.
device
self
.
pin_memory
=
is_pin_memory_available
()
#self.num_scheduled_tokens_np = np.zeros(scheduler_config.max_num_seqs, dtype=np.int32)
self
.
num_scheduled_tokens
=
torch
.
zeros
(
scheduler_config
.
max_num_seqs
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
self
.
num_scheduled_tokens_cpu
=
torch
.
zeros
(
scheduler_config
.
max_num_seqs
,
dtype
=
torch
.
int32
,
device
=
"cpu"
,
pin_memory
=
self
.
pin_memory
)
self
.
num_scheduled_tokens_np
=
self
.
num_scheduled_tokens_cpu
.
numpy
()
self
.
seq_lens_minus
=
torch
.
zeros
(
scheduler_config
.
max_num_seqs
*
self
.
decode_token_num_threshold
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
self
.
seq_lens_minus_cpu
=
torch
.
zeros
(
scheduler_config
.
max_num_seqs
*
self
.
decode_token_num_threshold
,
dtype
=
torch
.
int32
,
device
=
"cpu"
,
pin_memory
=
self
.
pin_memory
)
self
.
seq_lens_minus_np
=
self
.
seq_lens_minus_cpu
.
numpy
()
# support for cudagraph spec docoding
self
.
spec_decode_block_table_tensor
=
None
self
.
spec_decode_seq_lens
=
None
def
reorder_batch
(
self
,
input_batch
:
"InputBatch"
,
def
reorder_batch
(
self
,
input_batch
:
"InputBatch"
,
scheduler_output
:
"SchedulerOutput"
)
->
bool
:
scheduler_output
:
"SchedulerOutput"
)
->
bool
:
...
@@ -444,6 +467,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
...
@@ -444,6 +467,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
num_computed_tokens
=
input_batch
.
num_computed_tokens_cpu
[
req_idx
]
num_computed_tokens
=
input_batch
.
num_computed_tokens_cpu
[
req_idx
]
num_prompt_tokens
=
input_batch
.
num_prompt_tokens
[
req_idx
]
num_prompt_tokens
=
input_batch
.
num_prompt_tokens
[
req_idx
]
self
.
num_scheduled_tokens_np
[
i
]
=
num_tokens
self
.
num_scheduled_tokens_np
[
i
]
=
num_tokens
if
num_computed_tokens
<
num_prompt_tokens
or
(
num_tokens
>
self
.
decode_token_num_threshold
):
if
num_computed_tokens
<
num_prompt_tokens
or
(
num_tokens
>
self
.
decode_token_num_threshold
):
prefills
.
append
(
i
)
prefills
.
append
(
i
)
num_prefill_tokens
+=
num_tokens
num_prefill_tokens
+=
num_tokens
...
@@ -646,14 +670,22 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
...
@@ -646,14 +670,22 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
if
self
.
_num_decodes
>
0
:
if
self
.
_num_decodes
>
0
:
if
self
.
use_spec_decode
and
not
common_attn_metadata
.
spec_layer_decoding
:
if
self
.
use_spec_decode
and
not
common_attn_metadata
.
spec_layer_decoding
:
query_lens
=
self
.
num_scheduled_tokens_np
[:
self
.
_num_decodes
]
query_lens
=
self
.
num_scheduled_tokens_np
[:
self
.
_num_decodes
]
self
.
num_scheduled_tokens
[:
self
.
_num_decodes
].
copy_
(
self
.
num_scheduled_tokens_cpu
[:
self
.
_num_decodes
],
non_blocking
=
True
)
repeats
=
self
.
num_scheduled_tokens
[:
self
.
_num_decodes
]
cu_num_blocks
=
np
.
cumsum
(
query_lens
)
cu_num_blocks
=
np
.
cumsum
(
query_lens
)
virtual_batches
=
cu_num_blocks
[
-
1
]
virtual_batches
=
cu_num_blocks
[
-
1
]
block_offsets
=
np
.
repeat
(
cu_num_blocks
-
query_lens
,
query_lens
)
block_offsets
=
np
.
repeat
(
cu_num_blocks
-
query_lens
,
query_lens
)
arange
=
np
.
arange
(
virtual_batches
,
dtype
=
np
.
int32
)
-
block_offsets
arange
=
np
.
arange
(
virtual_batches
,
dtype
=
np
.
int32
)
-
block_offsets
rarange
=
np
.
repeat
(
query_lens
,
query_lens
)
-
arange
-
1
rarange
=
np
.
repeat
(
query_lens
,
query_lens
)
-
arange
-
1
repeats
=
torch
.
from_numpy
(
query_lens
).
pin_memory
().
to
(
self
.
seq_lens_minus_np
[:
rarange
.
size
]
=
rarange
block_table_tensor
.
device
,
non_blocking
=
True
).
contiguous
()
self
.
seq_lens_minus
[:
rarange
.
size
].
copy_
(
self
.
seq_lens_minus_cpu
[:
rarange
.
size
],
non_blocking
=
True
)
seq_lens_minus
=
self
.
seq_lens_minus
[:
rarange
.
size
]
if
envs
.
VLLM_ZERO_OVERHEAD
:
if
envs
.
VLLM_ZERO_OVERHEAD
:
decode_block_table_tensor
=
torch
.
empty
((
self
.
_num_decode_tokens
,
block_table_tensor
.
shape
[
1
]),
decode_block_table_tensor
=
torch
.
empty
((
self
.
_num_decode_tokens
,
block_table_tensor
.
shape
[
1
]),
...
@@ -670,8 +702,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
...
@@ -670,8 +702,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
block_table_tensor
[:
self
.
_num_decodes
,
...],
block_table_tensor
[:
self
.
_num_decodes
,
...],
repeats
,
dim
=
0
).
contiguous
()
repeats
,
dim
=
0
).
contiguous
()
decode_seq_lens
=
torch
.
repeat_interleave
(
seq_lens
[:
self
.
_num_decodes
],
repeats
,
dim
=
0
).
contiguous
()
decode_seq_lens
=
torch
.
repeat_interleave
(
seq_lens
[:
self
.
_num_decodes
],
repeats
,
dim
=
0
).
contiguous
()
seq_lens_minus
=
torch
.
from_numpy
(
rarange
).
to
(
torch
.
int32
).
pin_memory
().
to
(
seq_lens
.
device
,
non_blocking
=
True
).
contiguous
()
decode_seq_lens
=
decode_seq_lens
-
seq_lens_minus
decode_seq_lens
=
decode_seq_lens
-
seq_lens_minus
if
self
.
spec_decode_block_table_tensor
is
not
None
:
if
self
.
spec_decode_block_table_tensor
is
not
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment