Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6446f76c
Commit
6446f76c
authored
Nov 27, 2025
by
王敏
Browse files
解决宽松mtp request_ids报错找不到问题
parent
03e822d1
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
34 additions
and
17 deletions
+34
-17
vllm/v1/spec_decode/utils.py
vllm/v1/spec_decode/utils.py
+14
-4
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+20
-13
No files found.
vllm/v1/spec_decode/utils.py
View file @
6446f76c
...
@@ -52,7 +52,7 @@ class DraftProbs(ABC): # type: ignore[call-arg]
...
@@ -52,7 +52,7 @@ class DraftProbs(ABC): # type: ignore[call-arg]
draft_probs
:
torch
.
Tensor
draft_probs
:
torch
.
Tensor
# The request id list.
# The request id list.
_req_ids
:
list
[
str
]
_req_ids
:
list
[
str
]
=
[]
def
__init__
(
self
,
draft_probs
,
req_ids
):
def
__init__
(
self
,
draft_probs
,
req_ids
):
assert
len
(
req_ids
)
==
len
(
draft_probs
)
assert
len
(
req_ids
)
==
len
(
draft_probs
)
...
@@ -64,10 +64,15 @@ class DraftProbs(ABC): # type: ignore[call-arg]
...
@@ -64,10 +64,15 @@ class DraftProbs(ABC): # type: ignore[call-arg]
tmp_req_ids
:
list
[
str
]):
tmp_req_ids
:
list
[
str
]):
diff_req_ids
=
[
item
for
item
in
self
.
_req_ids
if
item
not
in
tmp_req_ids
]
diff_req_ids
=
[
item
for
item
in
self
.
_req_ids
if
item
not
in
tmp_req_ids
]
index
=
[
self
.
_req_ids
.
index
(
req_id
)
for
req_id
in
diff_req_ids
]
index
=
[
self
.
_req_ids
.
index
(
req_id
)
for
req_id
in
diff_req_ids
]
self
.
_req_ids
=
diff_req_ids
index_tensor
=
async_tensor_h2d
(
self
.
draft_probs
=
self
.
draft_probs
[
index
]
index
,
dtype
=
torch
.
int32
,
target_device
=
self
.
draft_probs
.
device
,
pin_memory
=
True
)
self
.
draft_probs
=
self
.
draft_probs
[
index_tensor
]
self
.
draft_probs
=
torch
.
cat
([
self
.
draft_probs
,
draft_probs
])
self
.
draft_probs
=
torch
.
cat
([
self
.
draft_probs
,
draft_probs
])
self
.
_req_ids
=
diff_req_ids
self
.
_req_ids
.
extend
(
tmp_req_ids
)
self
.
_req_ids
.
extend
(
tmp_req_ids
)
assert
len
(
self
.
_req_ids
)
==
len
(
self
.
draft_probs
)
assert
len
(
self
.
_req_ids
)
==
len
(
self
.
draft_probs
)
...
@@ -76,7 +81,12 @@ class DraftProbs(ABC): # type: ignore[call-arg]
...
@@ -76,7 +81,12 @@ class DraftProbs(ABC): # type: ignore[call-arg]
if
new_req_ids
!=
self
.
_req_ids
:
if
new_req_ids
!=
self
.
_req_ids
:
# Batch contents changed - prune removed sequences.
# Batch contents changed - prune removed sequences.
index
=
[
self
.
_req_ids
.
index
(
req_id
)
for
req_id
in
new_req_ids
]
index
=
[
self
.
_req_ids
.
index
(
req_id
)
for
req_id
in
new_req_ids
]
self
.
draft_probs
=
self
.
draft_probs
[
index
]
index_tensor
=
async_tensor_h2d
(
index
,
dtype
=
torch
.
int32
,
target_device
=
self
.
draft_probs
.
device
,
pin_memory
=
True
)
self
.
draft_probs
=
self
.
draft_probs
[
index_tensor
]
self
.
_req_ids
=
new_req_ids
self
.
_req_ids
=
new_req_ids
def
get_probs
(
self
,
req_ids
:
list
[
str
]):
def
get_probs
(
self
,
req_ids
:
list
[
str
]):
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
6446f76c
...
@@ -1531,6 +1531,8 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
...
@@ -1531,6 +1531,8 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
)
)
sampler_output
.
sampled_token_ids
=
output_token_ids
sampler_output
.
sampled_token_ids
=
output_token_ids
else
:
else
:
sampling_metadata
.
all_greedy
=
True
sampling_metadata
.
all_random
=
False
sampler_output
=
self
.
sampler
(
sampler_output
=
self
.
sampler
(
logits
=
logits
,
logits
=
logits
,
sampling_metadata
=
sampling_metadata
,
sampling_metadata
=
sampling_metadata
,
...
@@ -1643,12 +1645,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
...
@@ -1643,12 +1645,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
if
not
envs
.
VLLM_REJECT_SAMPLE_OPT
:
if
not
envs
.
VLLM_REJECT_SAMPLE_OPT
:
spec_token_ids
=
spec_result
spec_token_ids
=
spec_result
else
:
else
:
spec_token_ids
,
draft_probs
=
spec_result
spec_token_ids
,
_
=
spec_result
if
self
.
draft_probs
is
None
:
self
.
draft_probs
=
DraftProbs
(
draft_probs
,
self
.
input_batch
.
req_ids
)
else
:
self
.
draft_probs
.
update
(
draft_probs
,
self
.
input_batch
.
req_ids
)
# Clear KVConnector state after all KVs are generated.
# Clear KVConnector state after all KVs are generated.
if
has_kv_transfer_group
():
if
has_kv_transfer_group
():
...
@@ -1787,8 +1784,14 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
...
@@ -1787,8 +1784,14 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
draft_token_ids
=
draft_result
draft_token_ids
=
draft_result
spec_token_ids
=
draft_token_ids
.
tolist
()
spec_token_ids
=
draft_token_ids
.
tolist
()
return
spec_token_ids
return
spec_token_ids
draft_req_ids
=
list
(
scheduler_output
.
num_scheduled_tokens
.
keys
())
draft_token_ids
,
draft_probs
=
draft_result
draft_token_ids
,
draft_probs
=
draft_result
if
self
.
draft_probs
is
None
:
self
.
draft_probs
=
DraftProbs
(
draft_probs
,
draft_req_ids
)
else
:
self
.
draft_probs
.
update
(
draft_probs
,
draft_req_ids
)
spec_token_ids
=
draft_token_ids
.
tolist
()
spec_token_ids
=
draft_token_ids
.
tolist
()
return
spec_token_ids
,
draft_probs
return
spec_token_ids
,
draft_probs
...
@@ -3357,6 +3360,8 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
...
@@ -3357,6 +3360,8 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
)
)
sampler_output
.
sampled_token_ids
=
output_token_ids
sampler_output
.
sampled_token_ids
=
output_token_ids
else
:
else
:
sampling_metadata
.
all_greedy
=
True
sampling_metadata
.
all_random
=
False
sampler_output
=
self
.
sampler
(
sampler_output
=
self
.
sampler
(
logits
=
logits
,
logits
=
logits
,
sampling_metadata
=
sampling_metadata
,
sampling_metadata
=
sampling_metadata
,
...
@@ -3438,12 +3443,8 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
...
@@ -3438,12 +3443,8 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
if
not
envs
.
VLLM_REJECT_SAMPLE_OPT
:
if
not
envs
.
VLLM_REJECT_SAMPLE_OPT
:
spec_token_ids
=
spec_result
spec_token_ids
=
spec_result
else
:
else
:
spec_token_ids
,
draft_probs
=
spec_result
spec_token_ids
,
_
=
spec_result
if
self
.
draft_probs
is
None
:
self
.
draft_probs
=
DraftProbs
(
draft_probs
,
self
.
input_batch
.
req_ids
)
else
:
self
.
draft_probs
.
update
(
draft_probs
,
self
.
input_batch
.
req_ids
)
if
max_gen_len
==
1
:
if
max_gen_len
==
1
:
# No spec decode tokens.
# No spec decode tokens.
valid_sampled_token_ids
=
sampled_token_ids
.
tolist
()
valid_sampled_token_ids
=
sampled_token_ids
.
tolist
()
...
@@ -3607,7 +3608,13 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
...
@@ -3607,7 +3608,13 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
spec_token_ids
=
draft_token_ids
.
tolist
()
spec_token_ids
=
draft_token_ids
.
tolist
()
return
spec_token_ids
return
spec_token_ids
draft_req_ids
=
list
(
scheduler_output
.
num_scheduled_tokens
.
keys
())
draft_token_ids
,
draft_probs
=
draft_result
draft_token_ids
,
draft_probs
=
draft_result
if
self
.
draft_probs
is
None
:
self
.
draft_probs
=
DraftProbs
(
draft_probs
,
draft_req_ids
)
else
:
self
.
draft_probs
.
update
(
draft_probs
,
draft_req_ids
)
spec_token_ids
=
draft_token_ids
.
tolist
()
spec_token_ids
=
draft_token_ids
.
tolist
()
return
spec_token_ids
,
draft_probs
return
spec_token_ids
,
draft_probs
return
spec_token_ids
return
spec_token_ids
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment