Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
35017fdf
Commit
35017fdf
authored
Feb 09, 2026
by
zhuwenwen
Browse files
Merge branch 'v0.9.2-dev-wm' into 'v0.9.2-dev'
[fix] 解决宽松mtp引入的同步问题 (fix the synchronization issue introduced by relaxed MTP). See merge request dcutoolkit/deeplearing/vllm!417
parents
d73be361
b70256d7
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
50 additions
and
34 deletions
+50
-34
vllm/v1/utils.py
vllm/v1/utils.py
+6
-1
vllm/v1/worker/gpu_input_batch.py
vllm/v1/worker/gpu_input_batch.py
+37
-24
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+7
-9
No files found.
vllm/v1/utils.py
View file @
35017fdf
...
...
@@ -321,7 +321,7 @@ def bind_kv_cache(
def
copy_slice
(
from_tensor
:
torch
.
Tensor
,
to_tensor
:
torch
.
Tensor
,
length
:
int
)
->
torch
.
Tensor
:
length
:
int
,
repeat_counts
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
"""
Copy the first length elements of a tensor into another tensor in a
non-blocking manner.
...
...
@@ -330,6 +330,11 @@ def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
Returns the sliced target tensor.
"""
if
repeat_counts
is
not
None
:
from_tensor_tmp
=
torch
.
repeat_interleave
(
from_tensor
[:
length
],
repeat_counts
,
dim
=
0
)
length
=
torch
.
sum
(
repeat_counts
).
item
()
from_tensor
[:
length
].
copy_
(
from_tensor_tmp
)
return
to_tensor
[:
length
].
copy_
(
from_tensor
[:
length
],
non_blocking
=
True
)
...
...
vllm/v1/worker/gpu_input_batch.py
View file @
35017fdf
...
...
@@ -9,6 +9,7 @@ import numpy as np
import
torch
from
vllm
import
envs
from
vllm.config
import
get_current_vllm_config
from
vllm.lora.request
import
LoRARequest
from
vllm.multimodal.inputs
import
MultiModalKwargs
,
PlaceholderRange
from
vllm.pooling_params
import
PoolingParams
...
...
@@ -79,6 +80,10 @@ class InputBatch:
is_spec_decode
:
bool
=
False
,
logits_processing_needs_token_ids
:
bool
=
False
,
):
ori_max_num_reqs
=
max_num_reqs
if
is_spec_decode
and
envs
.
VLLM_REJECT_SAMPLE_OPT
:
vllm_config
=
get_current_vllm_config
()
max_num_reqs
=
max_num_reqs
*
(
1
+
vllm_config
.
speculative_config
.
num_speculative_tokens
)
self
.
is_spec_decode
=
is_spec_decode
self
.
max_num_reqs
=
max_num_reqs
self
.
max_model_len
=
max_model_len
...
...
@@ -97,7 +102,7 @@ class InputBatch:
# This buffer is not directly transferred to the GPU, so it does not
# need to be pinned.
self
.
token_ids_cpu_tensor
=
torch
.
zeros
(
(
max_num_reqs
,
max_model_len
),
(
ori_
max_num_reqs
,
max_model_len
),
device
=
"cpu"
,
dtype
=
torch
.
int32
,
pin_memory
=
False
,
...
...
@@ -651,36 +656,44 @@ class InputBatch:
or
repeat_counts
is
not
None
or
self
.
_sampling_metadata_is_expanded
)
if
needs_rebuild
:
if
repeat_counts
is
None
:
self
.
sampling_metadata
=
self
.
_make_sampling_metadata
()
else
:
self
.
sampling_metadata
=
self
.
_make_sampling_metadata_expanded
(
repeat_counts
)
# if repeat_counts is None:
# self.sampling_metadata = self._make_sampling_metadata()
# else:
# self.sampling_metadata = self._make_sampling_metadata_expanded(
# repeat_counts)
self
.
sampling_metadata
=
self
.
_make_sampling_metadata
(
repeat_counts
)
self
.
_sampling_metadata_is_expanded
=
repeat_counts
is
not
None
# Expanded metadata is built on demand; do not cache a copy here.
def
_make_sampling_metadata
(
self
)
->
SamplingMetadata
:
def
_make_sampling_metadata
(
self
,
repeat_counts
:
Optional
[
torch
.
Tensor
]
=
None
)
->
SamplingMetadata
:
num_reqs
=
self
.
num_reqs
if
not
self
.
all_greedy
:
temperature
=
copy_slice
(
self
.
temperature_cpu_tensor
,
self
.
temperature
,
num_reqs
)
self
.
temperature
,
num_reqs
,
repeat_counts
)
else
:
temperature
=
None
if
not
self
.
no_top_p
:
copy_slice
(
self
.
top_p_cpu_tensor
,
self
.
top_p
,
num_reqs
)
top_p
=
copy_slice
(
self
.
top_p_cpu_tensor
,
self
.
top_p
,
num_reqs
,
repeat_counts
)
if
not
self
.
no_top_k
:
copy_slice
(
self
.
top_k_cpu_tensor
,
self
.
top_k
,
num_reqs
)
top_k
=
copy_slice
(
self
.
top_k_cpu_tensor
,
self
.
top_k
,
num_reqs
,
repeat_counts
)
frequency_penalties
=
None
presence_penalties
=
None
repetition_penalties
=
None
if
not
self
.
no_penalties
:
# Since syncing these tensors is expensive only copy them
# if necessary i.e. if there are requests which require
# penalties to be applied during sampling.
copy_slice
(
self
.
frequency_penalties_cpu_tensor
,
self
.
frequency_penalties
,
num_reqs
)
copy_slice
(
self
.
presence_penalties_cpu_tensor
,
self
.
presence_penalties
,
num_reqs
)
copy_slice
(
self
.
repetition_penalties_cpu_tensor
,
self
.
repetition_penalties
,
num_reqs
)
frequency_penalties
=
copy_slice
(
self
.
frequency_penalties_cpu_tensor
,
self
.
frequency_penalties
,
num_reqs
,
repeat_counts
)
presence_penalties
=
copy_slice
(
self
.
presence_penalties_cpu_tensor
,
self
.
presence_penalties
,
num_reqs
,
repeat_counts
)
repetition_penalties
=
copy_slice
(
self
.
repetition_penalties_cpu_tensor
,
self
.
repetition_penalties
,
num_reqs
,
repeat_counts
)
needs_prompt_token_ids
=
(
not
self
.
no_penalties
or
(
self
.
num_reqs
>
0
...
...
@@ -697,9 +710,9 @@ class InputBatch:
allowed_token_ids_mask
:
Optional
[
torch
.
Tensor
]
=
None
if
not
self
.
no_allowed_token_ids
:
assert
self
.
allowed_token_ids_mask
is
not
None
copy_slice
(
self
.
allowed_token_ids_mask_cpu_tensor
,
self
.
allowed_token_ids_mask
,
num_reqs
)
allowed_token_ids_mask
=
self
.
allowed_token_ids_mask
[:
num_reqs
]
allowed_token_ids_mask
=
copy_slice
(
self
.
allowed_token_ids_mask_cpu_tensor
,
self
.
allowed_token_ids_mask
,
num_reqs
,
repeat_counts
)
# Host-side summaries to avoid device synchronization in sampling
# fast paths (e.g. reduced top-k/top-p sampling).
...
...
@@ -714,14 +727,14 @@ class InputBatch:
temperature
=
temperature
,
all_greedy
=
self
.
all_greedy
,
all_random
=
self
.
all_random
,
top_p
=
None
if
self
.
no_top_p
else
self
.
top_p
[:
num_reqs
]
,
top_k
=
None
if
self
.
no_top_k
else
self
.
top_k
[:
num_reqs
]
,
top_p
=
None
if
self
.
no_top_p
else
top_p
,
top_k
=
None
if
self
.
no_top_k
else
top_k
,
generators
=
self
.
generators
,
max_num_logprobs
=
self
.
max_num_logprobs
,
prompt_token_ids
=
prompt_token_ids
,
frequency_penalties
=
self
.
frequency_penalties
[:
num_reqs
]
,
presence_penalties
=
self
.
presence_penalties
[:
num_reqs
]
,
repetition_penalties
=
self
.
repetition_penalties
[:
num_reqs
]
,
frequency_penalties
=
None
if
self
.
no_penalties
else
frequency_penalties
,
presence_penalties
=
None
if
self
.
no_penalties
else
presence_penalties
,
repetition_penalties
=
None
if
self
.
no_penalties
else
repetition_penalties
,
output_token_ids
=
cast
(
list
[
list
[
int
]],
self
.
req_output_token_ids
),
no_penalties
=
self
.
no_penalties
,
allowed_token_ids_mask
=
allowed_token_ids_mask
,
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
35017fdf
...
...
@@ -586,17 +586,19 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# Refresh batch metadata with any pending updates. If we are in spec
# decode + reject mode, also expand sampling metadata to token shape
# using per-request repeat counts.
repeat_counts
:
Optional
[
torch
.
Tensor
]
=
None
repeat_counts
=
None
if
envs
.
VLLM_REJECT_SAMPLE_OPT
and
\
scheduler_output
.
scheduled_spec_decode_tokens
:
num_reqs
=
self
.
input_batch
.
num_reqs
num_draft_tokens
=
np
.
zeros
(
num_reqs
,
dtype
=
np
.
int32
)
repeat_counts
=
[
1
]
*
self
.
input_batch
.
num_reqs
#num_reqs = self.input_batch.num_reqs
#num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
for
req_id
,
draft_token_ids
in
(
scheduler_output
.
scheduled_spec_decode_tokens
.
items
()):
req_idx
=
self
.
input_batch
.
req_id_to_index
.
get
(
req_id
)
if
req_idx
is
not
None
:
num_draft_tokens
[
req_idx
]
=
len
(
draft_token_ids
)
repeat_counts
=
torch
.
from_numpy
(
num_draft_tokens
).
add_
(
1
)
repeat_counts
[
req_idx
]
+=
len
(
draft_token_ids
)
repeat_counts
=
torch
.
tensor
(
repeat_counts
,
dtype
=
torch
.
int32
,
device
=
"cpu"
)
self
.
input_batch
.
refresh_metadata
(
repeat_counts
)
def
_get_cumsum_and_arange
(
...
...
@@ -1565,8 +1567,6 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
)
sampler_output
.
sampled_token_ids
=
output_token_ids
else
:
sampling_metadata
.
all_greedy
=
True
sampling_metadata
.
all_random
=
False
sampler_output
=
self
.
sampler
(
logits
=
logits
,
sampling_metadata
=
sampling_metadata
,
...
...
@@ -3431,8 +3431,6 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
)
sampler_output
.
sampled_token_ids
=
output_token_ids
else
:
# sampling_metadata.all_greedy = True
# sampling_metadata.all_random = False
sampler_output
=
self
.
sampler
(
logits
=
logits
,
sampling_metadata
=
sampling_metadata
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment