Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
78e20661
"vscode:/vscode.git/clone" did not exist on "aff404571b0d5aba342c46fdf5d7f8a251da9383"
Commit
78e20661
authored
Feb 08, 2026
by
王敏
Browse files
[feat]宽松mtp支持temp,top-p等参数设置
parent
e807ec39
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
53 additions
and
35 deletions
+53
-35
vllm/v1/utils.py
vllm/v1/utils.py
+7
-3
vllm/v1/worker/gpu_input_batch.py
vllm/v1/worker/gpu_input_batch.py
+35
-31
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+11
-1
No files found.
vllm/v1/utils.py
View file @
78e20661
...
...
@@ -320,9 +320,8 @@ def shutdown(procs: list[BaseProcess]):
kill_process_tree
(
pid
)
def
copy_slice
(
from_tensor
:
torch
.
Tensor
,
to_tensor
:
torch
.
Tensor
,
length
:
int
)
->
torch
.
Tensor
:
def
copy_slice
(
from_tensor
:
torch
.
Tensor
,
to_tensor
:
torch
.
Tensor
,
length
:
int
,
repeat_counts
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
"""
Copy the first length elements of a tensor into another tensor in a
non-blocking manner.
...
...
@@ -331,6 +330,11 @@ def copy_slice(
Returns the sliced target tensor.
"""
if
repeat_counts
is
not
None
:
from_tensor_tmp
=
torch
.
repeat_interleave
(
from_tensor
[:
length
],
repeat_counts
,
dim
=
0
)
length
=
torch
.
sum
(
repeat_counts
).
item
()
from_tensor
[:
length
].
copy_
(
from_tensor_tmp
)
return
to_tensor
[:
length
].
copy_
(
from_tensor
[:
length
],
non_blocking
=
True
)
...
...
vllm/v1/worker/gpu_input_batch.py
View file @
78e20661
...
...
@@ -3,11 +3,13 @@
# Datastructures defining a GPU input batch
from
dataclasses
import
dataclass
from
typing
import
cast
from
typing
import
Optional
,
cast
import
numpy
as
np
import
torch
from
vllm
import
envs
from
vllm.config
import
get_current_vllm_config
from
vllm.lora.request
import
LoRARequest
from
vllm.multimodal.inputs
import
MultiModalFeatureSpec
from
vllm.pooling_params
import
PoolingParams
...
...
@@ -96,6 +98,11 @@ class InputBatch:
is_pooling_model
:
bool
=
False
,
cp_kv_cache_interleave_size
:
int
=
1
,
):
ori_max_num_reqs
=
max_num_reqs
if
is_spec_decode
and
envs
.
VLLM_REJECT_SAMPLE_OPT
:
vllm_config
=
get_current_vllm_config
()
max_num_reqs
=
max_num_reqs
*
(
1
+
vllm_config
.
speculative_config
.
num_speculative_tokens
)
self
.
is_pooling_model
=
is_pooling_model
self
.
is_spec_decode
=
is_spec_decode
self
.
max_num_reqs
=
max_num_reqs
...
...
@@ -113,7 +120,7 @@ class InputBatch:
# This buffer is not directly transferred to the GPU, so it does not
# need to be pinned.
self
.
token_ids_cpu_tensor
=
torch
.
zeros
(
(
max_num_reqs
,
max_model_len
),
(
ori_
max_num_reqs
,
max_model_len
),
device
=
"cpu"
,
dtype
=
torch
.
int32
,
pin_memory
=
False
,
...
...
@@ -753,13 +760,13 @@ class InputBatch:
del
self
.
req_output_token_ids
[
num_reqs
:]
del
self
.
spec_token_ids
[
num_reqs
:]
def
refresh_metadata
(
self
):
def
refresh_metadata
(
self
,
repeat_counts
:
Optional
[
torch
.
Tensor
]
=
None
):
"""Apply any batch updates to sampling metadata."""
if
self
.
is_pooling_model
:
batch_changed
=
self
.
batch_update_builder
.
reset
()
if
batch_changed
:
self
.
sampling_metadata
=
self
.
_make_sampling_metadata
()
self
.
sampling_metadata
=
self
.
_make_sampling_metadata
(
repeat_counts
)
return
# For non-pooling models - generate and apply logitsprocs update;
...
...
@@ -769,36 +776,36 @@ class InputBatch:
for
logit_proc
in
self
.
logitsprocs
.
all
:
logit_proc
.
update_state
(
batch_update
)
if
batch_update
:
self
.
sampling_metadata
=
self
.
_make_sampling_metadata
()
self
.
sampling_metadata
=
self
.
_make_sampling_metadata
(
repeat_counts
)
def
_make_sampling_metadata
(
self
)
->
SamplingMetadata
:
def
_make_sampling_metadata
(
self
,
repeat_counts
:
Optional
[
torch
.
Tensor
]
=
None
)
->
SamplingMetadata
:
num_reqs
=
self
.
num_reqs
if
not
self
.
all_greedy
:
temperature
=
copy_slice
(
self
.
temperature_cpu_tensor
,
self
.
temperature
,
num_reqs
self
.
temperature_cpu_tensor
,
self
.
temperature
,
num_reqs
,
repeat_counts
)
else
:
temperature
=
None
if
not
self
.
no_top_p
:
copy_slice
(
self
.
top_p_cpu_tensor
,
self
.
top_p
,
num_reqs
)
top_p
=
copy_slice
(
self
.
top_p_cpu_tensor
,
self
.
top_p
,
num_reqs
,
repeat_counts
)
if
not
self
.
no_top_k
:
copy_slice
(
self
.
top_k_cpu_tensor
,
self
.
top_k
,
num_reqs
)
top_k
=
copy_slice
(
self
.
top_k_cpu_tensor
,
self
.
top_k
,
num_reqs
,
repeat_counts
)
if
not
self
.
no_penalties
:
# Since syncing these tensors is expensive only copy them
# if necessary i.e. if there are requests which require
# penalties to be applied during sampling.
copy_slice
(
self
.
frequency_penalties_cpu_tensor
,
self
.
frequency_penalties
,
num_reqs
)
copy_slice
(
self
.
presence_penalties_cpu_tensor
,
self
.
presence_penalties
,
num_reqs
)
copy_slice
(
self
.
repetition_penalties_cpu_tensor
,
self
.
repetition_penalties
,
num_reqs
,
)
frequency_penalties
=
copy_slice
(
self
.
frequency_penalties_cpu_tensor
,
self
.
frequency_penalties
,
num_reqs
,
repeat_counts
)
presence_penalties
=
copy_slice
(
self
.
presence_penalties_cpu_tensor
,
self
.
presence_penalties
,
num_reqs
,
repeat_counts
)
repetition_penalties
=
copy_slice
(
self
.
repetition_penalties_cpu_tensor
,
self
.
repetition_penalties
,
num_reqs
,
repeat_counts
)
needs_prompt_token_ids
=
(
not
self
.
no_penalties
...
...
@@ -828,25 +835,22 @@ class InputBatch:
allowed_token_ids_mask
:
torch
.
Tensor
|
None
=
None
if
not
self
.
no_allowed_token_ids
:
assert
self
.
allowed_token_ids_mask
is
not
None
copy_slice
(
self
.
allowed_token_ids_mask_cpu_tensor
,
self
.
allowed_token_ids_mask
,
num_reqs
,
)
allowed_token_ids_mask
=
self
.
allowed_token_ids_mask
[:
num_reqs
]
allowed_token_ids_mask
=
copy_slice
(
self
.
allowed_token_ids_mask_cpu_tensor
,
self
.
allowed_token_ids_mask
,
num_reqs
,
repeat_counts
)
return
SamplingMetadata
(
temperature
=
temperature
,
all_greedy
=
self
.
all_greedy
,
all_random
=
self
.
all_random
,
top_p
=
None
if
self
.
no_top_p
else
self
.
top_p
[:
num_reqs
]
,
top_k
=
None
if
self
.
no_top_k
else
self
.
top_k
[:
num_reqs
]
,
top_p
=
None
if
self
.
no_top_p
else
top_p
,
top_k
=
None
if
self
.
no_top_k
else
top_k
,
generators
=
self
.
generators
,
max_num_logprobs
=
self
.
max_num_logprobs
,
prompt_token_ids
=
prompt_token_ids
,
frequency_penalties
=
self
.
frequency_penalties
[:
num_reqs
]
,
presence_penalties
=
self
.
presence_penalties
[:
num_reqs
]
,
repetition_penalties
=
self
.
repetition_penalties
[:
num_reqs
]
,
frequency_penalties
=
None
if
self
.
no_penalties
else
frequency_penalties
,
presence_penalties
=
None
if
self
.
no_penalties
else
presence_penalties
,
repetition_penalties
=
None
if
self
.
no_penalties
else
repetition_penalties
,
output_token_ids
=
output_token_ids
,
spec_token_ids
=
self
.
spec_token_ids
,
no_penalties
=
self
.
no_penalties
,
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
78e20661
...
...
@@ -1102,7 +1102,17 @@ class GPUModelRunner(
# Allow attention backend to reorder the batch, potentially
self
.
_may_reorder_batch
(
scheduler_output
)
# Refresh batch metadata with any pending updates.
self
.
input_batch
.
refresh_metadata
()
repeat_counts
=
None
if
envs
.
VLLM_REJECT_SAMPLE_OPT
and
\
scheduler_output
.
scheduled_spec_decode_tokens
:
repeat_counts
=
[
1
]
*
self
.
input_batch
.
num_reqs
for
req_id
,
draft_token_ids
in
(
scheduler_output
.
scheduled_spec_decode_tokens
.
items
()):
req_idx
=
self
.
input_batch
.
req_id_to_index
.
get
(
req_id
)
if
req_idx
is
not
None
:
repeat_counts
[
req_idx
]
+=
len
(
draft_token_ids
)
repeat_counts
=
torch
.
tensor
(
repeat_counts
,
dtype
=
torch
.
int32
,
device
=
"cpu"
)
self
.
input_batch
.
refresh_metadata
(
repeat_counts
)
def
_update_states_after_model_execute
(
self
,
output_token_ids
:
torch
.
Tensor
,
scheduler_output
:
"SchedulerOutput"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment