Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
30b93f8e
Commit
30b93f8e
authored
Dec 04, 2025
by
王敏
Browse files
[fix]解决mtp保存draft prob在例如pd分离场景下的OOM问题
parent
8364249c
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
18 additions
and
1 deletion
+18
-1
vllm/v1/spec_decode/utils.py
vllm/v1/spec_decode/utils.py
+18
-1
No files found.
vllm/v1/spec_decode/utils.py
View file @
30b93f8e
...
@@ -54,14 +54,22 @@ class DraftProbs(ABC): # type: ignore[call-arg]
...
@@ -54,14 +54,22 @@ class DraftProbs(ABC): # type: ignore[call-arg]
# The request id list.
# The request id list.
_req_ids
:
list
[
str
]
=
[]
_req_ids
:
list
[
str
]
=
[]
count
=
0
req_id_to_count
:
dict
[
str
,
int
]
=
{}
prune_threshould
=
100
def
__init__
(
self
,
draft_probs
,
req_ids
):
def
__init__
(
self
,
draft_probs
,
req_ids
):
assert
len
(
req_ids
)
==
len
(
draft_probs
)
assert
len
(
req_ids
)
==
len
(
draft_probs
)
self
.
draft_probs
=
draft_probs
self
.
draft_probs
=
draft_probs
self
.
_req_ids
=
req_ids
self
.
_req_ids
=
req_ids
for
req_id
in
req_ids
:
self
.
req_id_to_count
[
req_id
]
=
self
.
count
def
update
(
self
,
def
update
(
self
,
draft_probs
:
torch
.
Tensor
,
draft_probs
:
torch
.
Tensor
,
tmp_req_ids
:
list
[
str
]):
tmp_req_ids
:
list
[
str
]):
self
.
count
+=
1
diff_req_ids
=
[
item
for
item
in
self
.
_req_ids
if
item
not
in
tmp_req_ids
]
diff_req_ids
=
[
item
for
item
in
self
.
_req_ids
if
item
not
in
tmp_req_ids
]
index
=
[
self
.
_req_ids
.
index
(
req_id
)
for
req_id
in
diff_req_ids
]
index
=
[
self
.
_req_ids
.
index
(
req_id
)
for
req_id
in
diff_req_ids
]
index_tensor
=
async_tensor_h2d
(
index_tensor
=
async_tensor_h2d
(
...
@@ -71,12 +79,21 @@ class DraftProbs(ABC): # type: ignore[call-arg]
...
@@ -71,12 +79,21 @@ class DraftProbs(ABC): # type: ignore[call-arg]
pin_memory
=
True
)
pin_memory
=
True
)
self
.
draft_probs
=
self
.
draft_probs
[
index_tensor
]
self
.
draft_probs
=
self
.
draft_probs
[
index_tensor
]
self
.
draft_probs
=
torch
.
cat
([
self
.
draft_probs
,
draft_probs
])
self
.
draft_probs
=
torch
.
cat
([
self
.
draft_probs
,
draft_probs
])
self
.
_req_ids
=
diff_req_ids
self
.
_req_ids
=
diff_req_ids
self
.
_req_ids
.
extend
(
tmp_req_ids
)
self
.
_req_ids
.
extend
(
tmp_req_ids
)
for
req_id
in
tmp_req_ids
:
self
.
req_id_to_count
[
req_id
]
=
self
.
count
assert
len
(
self
.
_req_ids
)
==
len
(
self
.
draft_probs
)
assert
len
(
self
.
_req_ids
)
==
len
(
self
.
draft_probs
)
def
prune
(
self
,
req_ids
:
list
[
str
]):
def
prune
(
self
,
req_ids
:
list
[
str
]):
if
self
.
count
%
self
.
prune_threshould
==
0
:
for
req_id
,
last_count
in
self
.
req_id_to_count
.
items
():
if
self
.
count
-
last_count
>=
self
.
prune_threshould
:
req_ids
.
append
(
req_id
)
self
.
req_id_to_count
=
{
k
:
v
for
k
,
v
in
self
.
req_id_to_count
.
items
()
if
k
not
in
req_ids
}
new_req_ids
=
[
req_id
for
req_id
in
self
.
_req_ids
if
req_id
not
in
req_ids
]
new_req_ids
=
[
req_id
for
req_id
in
self
.
_req_ids
if
req_id
not
in
req_ids
]
if
new_req_ids
!=
self
.
_req_ids
:
if
new_req_ids
!=
self
.
_req_ids
:
# Batch contents changed - prune removed sequences.
# Batch contents changed - prune removed sequences.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment