Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7d867671
Commit
7d867671
authored
Mar 18, 2025
by
lizhigong
Browse files
fix llm_engine to zero_overhead
parent
08c2298a
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
142 additions
and
56 deletions
+142
-56
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+52
-25
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+1
-0
vllm/model_executor/layers/ops/update_input.py
vllm/model_executor/layers/ops/update_input.py
+6
-1
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+38
-21
vllm/sequence.py
vllm/sequence.py
+11
-2
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+29
-4
vllm/worker/model_runner_base.py
vllm/worker/model_runner_base.py
+2
-1
vllm/worker/worker_base.py
vllm/worker/worker_base.py
+3
-2
No files found.
vllm/engine/llm_engine.py
View file @
7d867671
...
...
@@ -1233,6 +1233,27 @@ class LLMEngine:
return
None
def _fix_last_step(
        self, output: List[SamplerOutput],
        seq_group_metadata_list: List[SequenceGroupMetadata],
        scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None:
    """Write the sampled token ids of a recorded step back into its sequences.

    The sample tensor attached to ``output[0]`` is copied to the host once
    and, for every unfinished sequence group that sampled in that step, the
    first value of its row replaces both ``sample.output_token`` and the
    token stored in the sequence (via ``Sequence.fix_last_token_id``).

    Args:
        output: Sampler outputs of the step being fixed; only ``output[0]``
            is read.
        seq_group_metadata_list: Metadata of the sequence groups scheduled
            in that step.
        scheduled_seq_groups: Scheduled sequence groups of that step, in
            the same order as ``seq_group_metadata_list``.
    """
    sampler_output = output[0]
    sampled_tensor = sampler_output.sampler_out_tenosr
    if sampled_tensor is None:
        # The field defaults to None and is only populated when a sample
        # tensor was recorded for the step; without it there is nothing to
        # write back, and calling .cpu() on None would raise.
        return

    # Single device-to-host copy for the whole batch.
    sample_out_list = sampled_tensor.cpu().tolist()

    for (seq_group_metadata, sequence_group_outputs, scheduled_seq_group,
         token_id) in zip(seq_group_metadata_list, sampler_output,
                          scheduled_seq_groups, sample_out_list):
        seq_group = scheduled_seq_group.seq_group

        if seq_group.is_finished():
            continue

        if seq_group_metadata.do_sample:
            sample = sequence_group_outputs.samples[0]

            # Only one sequence per group is supported on this path.
            assert len(seq_group.seqs) == 1
            seq = seq_group.seqs[0]

            # token_id is one row of the sample tensor; its first entry is
            # taken as the sampled token for this group.
            sample.output_token = token_id[0]
            seq.fix_last_token_id(sample.output_token)
def
_advance_to_next_step
(
self
,
output
:
List
[
SamplerOutput
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
...
...
@@ -1386,9 +1407,14 @@ class LLMEngine:
assert
scheduler_outputs
is
not
None
profile
.
ProfRangeAutoPush
(
'execute_model'
)
last_outputs
=
None
last_outputs_ids
=
None
last_outputs_tensor
=
None
if
self
.
zero_overhead
:
last_outputs
=
self
.
trans_last_output_tensor
(
self
.
output_recorder
[
self
.
step_switch
])
recode_output
=
self
.
output_recorder
[
self
.
step_switch
]
if
recode_output
is
not
None
:
last_output
=
recode_output
[
0
][
0
]
last_outputs_ids
,
last_outputs_tensor
=
last_output
.
sampler_out_ids
,
last_output
.
sampler_out_tenosr
self
.
output_recorder
[
self
.
step_switch
]
=
None
# consumed only once
if
not
scheduler_outputs
.
is_empty
():
# Check if we have a cached last_output from the previous iteration.
...
...
@@ -1398,14 +1424,6 @@ class LLMEngine:
last_sampled_token_ids
=
\
self
.
_get_last_sampled_token_ids
(
virtual_engine
)
# print('seq_group_metadata_list', len(seq_group_metadata_list))
# print('scheduler_outputs.blocks_to_swap_in', len(scheduler_outputs.blocks_to_swap_in))
# print('scheduler_outputs.num_lookahead_slots', scheduler_outputs.num_lookahead_slots)
# print('scheduler_outputs.running_queue_size', scheduler_outputs.running_queue_size)
# print('finished_requests_ids', len(finished_requests_ids))
# print('last_sampled_token_ids', last_sampled_token_ids)
# print('self.model_executor', type(self.model_executor))
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
blocks_to_swap_in
=
scheduler_outputs
.
blocks_to_swap_in
,
...
...
@@ -1417,7 +1435,8 @@ class LLMEngine:
# We use ExecuteModelRequest to pass the last sampled_token_ids
# to each of the non-last PP stages for in-place prepare_input.
last_sampled_token_ids
=
last_sampled_token_ids
,
last_outputs
=
last_outputs
)
last_outputs_ids
=
last_outputs_ids
,
last_outputs_sample
=
last_outputs_tensor
)
if
allow_async_output_proc
:
execute_model_req
.
async_callback
=
self
.
async_callbacks
[
virtual_engine
]
...
...
@@ -1437,19 +1456,28 @@ class LLMEngine:
# No outputs in this case
outputs
=
[]
# Finish the current step for all the sequence groups.
if
self
.
scheduler_config
.
is_multi_step
:
for
seq_group
in
seq_group_metadata_list
:
seq_group
.
finish_step
()
if
self
.
zero_overhead
:
self
.
output_recorder
[
self
.
step_switch
]
=
outputs
self
.
output_recorder
[
self
.
step_switch
]
=
[
outputs
,
seq_group_metadata_list
,
scheduler_outputs
]
self
.
_advance_to_next_step
(
outputs
[
0
],
seq_group_metadata_list
,
scheduler_outputs
.
scheduled_seq_groups
)
self
.
step_switch
=
1
-
self
.
step_switch
outputs
=
self
.
output_recorder
[
self
.
step_switch
]
if
outputs
is
None
:
recode_output
=
self
.
output_recorder
[
self
.
step_switch
]
if
recode_output
is
None
:
return
None
# Synchronize with the output of the previous step
outputs
,
seq_group_metadata_list
,
scheduler_outputs
=
self
.
output_recorder
[
self
.
step_switch
]
ctx
.
seq_group_metadata_list
=
seq_group_metadata_list
ctx
.
scheduler_outputs
=
scheduler_outputs
self
.
_fix_last_step
(
outputs
,
seq_group_metadata_list
,
scheduler_outputs
.
scheduled_seq_groups
)
# Finish the current step for all the sequence groups.
if
self
.
scheduler_config
.
is_multi_step
:
for
seq_group
in
seq_group_metadata_list
:
seq_group
.
finish_step
()
if
not
self
.
_has_remaining_steps
(
seq_group_metadata_list
):
# clear the cache if we have finished all the steps.
...
...
@@ -1473,7 +1501,7 @@ class LLMEngine:
if
outputs
and
allow_async_output_proc
:
assert
len
(
outputs
)
==
1
,
(
"Async postprocessor expects only a single output set"
)
if
not
self
.
zero_overhead
:
self
.
_advance_to_next_step
(
outputs
[
0
],
seq_group_metadata_list
,
scheduler_outputs
.
scheduled_seq_groups
)
...
...
@@ -1505,7 +1533,6 @@ class LLMEngine:
# queued control plane messages, such as add/remove lora adapters.
logger
.
debug
(
"Stopping remote worker execution loop."
)
self
.
model_executor
.
stop_remote_worker_execution_loop
()
return
ctx
.
request_outputs
def
_has_remaining_steps
(
...
...
vllm/entrypoints/llm.py
View file @
7d867671
...
...
@@ -1388,6 +1388,7 @@ class LLM:
total_out_toks
=
0
while
self
.
llm_engine
.
has_unfinished_requests
():
step_outputs
=
self
.
llm_engine
.
step
()
print
(
'###step_outputs'
,
step_outputs
)
if
step_outputs
is
None
:
continue
for
output
in
step_outputs
:
...
...
vllm/model_executor/layers/ops/update_input.py
View file @
7d867671
...
...
@@ -22,3 +22,8 @@ def _update_input_tokens(
if
_seq_ids
==
_input_seq_id
:
output_token
=
tl
.
load
(
sample_output
+
i
)
tl
.
store
(
input_tokens
+
pid
,
output_token
)
def UpdateInputTokens(input_tokens, input_seq_ids, last_sample, last_ids):
    """Launch ``_update_input_tokens`` to refresh ``input_tokens`` in place.

    One kernel program is launched per entry of ``input_seq_ids``.

    Args:
        input_tokens: Token tensor that the kernel writes into.
        input_seq_ids: Sequence ids, one per entry of ``input_tokens``.
        last_sample: Sample tensor of the previous step, read by the kernel.
        last_ids: Sequence ids matching ``last_sample``; moved to the device
            of ``input_tokens`` before the launch.
    """
    # Keep every pointer argument on the same device as the tensor being
    # written instead of hard-coding the default 'cuda' device.
    last_ids = last_ids.to(input_tokens.device)
    num_inputs = input_seq_ids.shape[0]
    grid = [num_inputs, 1, 1]
    _update_input_tokens[grid](last_sample, last_ids, input_tokens,
                               input_seq_ids, last_ids.shape[0], num_inputs)
\ No newline at end of file
vllm/model_executor/layers/sampler.py
View file @
7d867671
# SPDX-License-Identifier: Apache-2.0
"""A layer that samples the next tokens from the model's outputs."""
import
itertools
import
os
import
warnings
from
dataclasses
import
dataclass
from
importlib.util
import
find_spec
...
...
@@ -72,11 +73,12 @@ class SampleResultArgsType:
# Implemented by guanyu
@
dataclass
class
SampleDeviceToDevices
:
num_parent_seq
:
torch
.
Tensor
=
None
seq_id
:
torch
.
Tensor
=
None
random_samples
:
torch
.
Tensor
=
None
sample_idx
:
int
=
None
d2d_data
=
SampleDeviceToDevices
()
def
__init__
(
self
):
self
.
seq_id
:
torch
.
Tensor
=
None
self
.
random_samples
:
torch
.
Tensor
=
None
self
.
zero_overhead
:
bool
=
False
d2d_data
=
SampleDeviceToDevices
()
# Union of non-deferred (single-step scheduling)
# vs deferred (multi-step scheduling)
...
...
@@ -144,6 +146,9 @@ class SamplerOutput(
# tree-style cartesian candidates
tree_attn_masks
:
Optional
[
torch
.
Tensor
]
=
None
sampler_out_tenosr
:
Optional
[
torch
.
Tensor
]
=
None
sampler_out_ids
:
Optional
[
torch
.
Tensor
]
=
None
def __getitem__(self, idx: int) -> CompletionSequenceGroupOutput:
    """Return the entry of ``self.outputs`` at position ``idx``."""
    group_outputs = self.outputs
    return group_outputs[idx]
...
...
@@ -174,7 +179,10 @@ class SamplerOutput(
f
"sampled_token_ids=
{
sampled_token_ids_repr
}
, "
f
"spec_decode_worker_metrics=
{
self
.
spec_decode_worker_metrics
}
, "
f
"logits=
{
self
.
logits
}
, "
f
"tree_attn_masks=
{
self
.
tree_attn_masks
}
)"
)
f
"tree_attn_masks=
{
self
.
tree_attn_masks
}
, "
f
"sampler_out_tenosr=
{
self
.
sampler_out_tenosr
}
, "
f
"sampler_out_ids=
{
self
.
sampler_out_ids
}
, "
f
")"
)
class
Sampler
(
nn
.
Module
):
...
...
@@ -206,6 +214,8 @@ class Sampler(nn.Module):
# speculative decoding.
self
.
include_gpu_probs_tensor
=
False
self
.
should_modify_greedy_probs_inplace
=
False
self
.
zero_overhead
=
os
.
environ
.
get
(
'VLLM_ZERO_OVERHEAD'
)
==
'1'
d2d_data
.
zero_overhead
=
self
.
zero_overhead
def
_init_sampling_tensors
(
self
,
...
...
@@ -503,7 +513,7 @@ def _random_sample(
seq_group has do_sample=False, tuple contains ([], [])
"""
# Find the maximum n value of the prompt phase requests.
# random_samples = random_samples.cpu() was removed to avoid the GPU->CPU synchronization
if
not
d2d_data
.
zero_overhead
:
random_samples
=
random_samples
.
cpu
()
sample_idx
=
0
results
:
SampleResultType
=
[]
...
...
@@ -516,20 +526,24 @@ def _random_sample(
sampling_params
=
seq_group
.
sampling_params
is_prompt
=
seq_group
.
is_prompt
num_parent_seqs
=
len
(
seq_ids
)
d2d_data
.
num_parent_seq
=
num_parent_seqs
if
is_prompt
:
# Prompt phase.
parent_ids
=
[
0
]
*
sampling_params
.
n
if
d2d_data
.
zero_overhead
:
next_token_ids
=
[
0
]
*
sampling_params
.
n
else
:
next_token_ids
=
random_samples
[
sample_idx
,
:
sampling_params
.
n
].
tolist
()
else
:
# Generation phase.
parent_ids
=
list
(
range
(
num_parent_seqs
))
if
d2d_data
.
zero_overhead
:
next_token_ids
=
[
0
]
*
num_parent_seqs
else
:
next_token_ids
=
random_samples
[
sample_idx
:
sample_idx
+
num_parent_seqs
,
0
].
tolist
()
results
.
append
((
next_token_ids
,
parent_ids
))
sample_idx
+=
num_parent_seqs
d2d_data
.
sample_idx
=
sample_idx
return
results
...
...
@@ -707,7 +721,7 @@ def get_pythonized_sample_results(
if
sampling_type
==
SamplingType
.
GREEDY
:
sample_results
=
_greedy_sample
(
seq_groups
,
greedy_samples
)
elif
sampling_type
in
(
SamplingType
.
RANDOM
,
SamplingType
.
RANDOM_SEED
):
d2d_data
.
random_samples
=
multinomial_samples
[
sampling_type
]
# Record the random_samples tensor
d2d_data
.
random_samples
=
multinomial_samples
[
sampling_type
]
# Record the random_samples tensor
sample_results
=
_random_sample
(
seq_groups
,
multinomial_samples
[
sampling_type
])
elif
sampling_type
==
SamplingType
.
BEAM
:
...
...
@@ -744,13 +758,11 @@ def _sample_with_torch(
categorized_seq_group_ids
:
Dict
[
SamplingType
,
List
[
int
]]
=
{
t
:
[]
for
t
in
SamplingType
}
# Initialize the result containers, then group the sequence groups by sampling type
print
(
f
'sampling_metadata.seq_groups的长度:
{
len
(
sampling_metadata
.
seq_groups
)
}
'
)
# Initialize a tensor that holds the seq_ids, with every entry set to -1
d2d_data
.
seq_id
=
torch
.
zeros
(
len
(
sampling_metadata
.
seq_groups
),
1
)
-
1
}
d2d_data
.
seq_id
=
torch
.
zeros
(
len
(
sampling_metadata
.
seq_groups
))
categorized_sample_indices
=
sampling_metadata
.
categorized_sample_indices
for
i
,
seq_group
in
enumerate
(
sampling_metadata
.
seq_groups
):
d2d_data
.
seq_id
[
i
]
=
seq_group
.
seq_ids
[
0
]
# Store the seq_id of sequence group i in d2d_data.seq_id
d2d_data
.
seq_id
[
i
]
=
seq_group
.
seq_ids
[
0
]
sampling_params
=
seq_group
.
sampling_params
sampling_type
=
sampling_params
.
sampling_type
categorized_seq_group_ids
[
sampling_type
].
append
(
i
)
...
...
@@ -1280,13 +1292,18 @@ def _build_sampler_output(
sampled_token_probs
,
logprobs_tensor
,
sampled_token_ids
=
(
None
,
None
,
None
)
if
d2d_data
.
zero_overhead
:
pass
return
SamplerOutput
(
outputs
=
sampler_output
,
sampled_token_probs
=
sampled_token_probs
,
sampled_token_ids
=
sampled_token_ids
,
logprobs
=
logprobs_tensor
,
deferred_sample_results_args
=
deferred_sample_results_args
,
logits
=
logits
)
logits
=
logits
,
sampler_out_tenosr
=
d2d_data
.
random_samples
,
sampler_out_ids
=
d2d_data
.
seq_id
)
def
_get_next_prompt_tokens
(
seq_group
:
SequenceGroupToSample
)
->
List
[
int
]:
...
...
vllm/sequence.py
View file @
7d867671
...
...
@@ -582,6 +582,11 @@ class Sequence:
self
.
output_logprobs
.
append
(
logprobs
)
self
.
data
.
append_token_id
(
token_id
,
logprobs
[
token_id
].
logprob
)
def fix_last_token_id(self, token_id: int) -> None:
    """Overwrite the second-to-last stored token with ``token_id``.

    The same slot (index ``-2``) is rewritten in the output token ids, the
    newly appended tokens and the cached list of all token ids held by
    ``self.data``.
    """
    seq_data = self.data
    # Buffers are patched in this fixed order, one after another.
    for buffer_name in ("_output_token_ids", "_new_appended_tokens",
                        "_cached_all_token_ids"):
        getattr(seq_data, buffer_name)[-2] = token_id
def
get_len
(
self
)
->
int
:
return
self
.
data
.
get_len
()
...
...
@@ -1403,7 +1408,10 @@ class ExecuteModelRequest(
kvcache_slot_to_be_moved
:
Optional
[
torch
.
Tensor
]
=
None
# for zero-overhead scheduler
last_outputs
:
Optional
[
torch
.
Tensor
]
=
None
last_outputs_sample
:
Optional
[
torch
.
Tensor
]
=
None
# for zero-overhead scheduler
last_outputs_ids
:
Optional
[
torch
.
Tensor
]
=
None
@
property
def
is_first_multi_step
(
self
)
->
bool
:
...
...
@@ -1455,7 +1463,8 @@ class ExecuteModelRequest(
tree_attn_masks
=
self
.
tree_attn_masks
,
tree_position_ids
=
self
.
tree_position_ids
,
kvcache_slot_to_be_moved
=
self
.
kvcache_slot_to_be_moved
,
last_outputs
=
self
.
last_outputs
)
last_outputs_sample
=
self
.
last_outputs_sample
,
last_outputs_ids
=
self
.
last_outputs_ids
)
@
dataclass
...
...
vllm/worker/model_runner.py
View file @
7d867671
...
...
@@ -4,6 +4,7 @@ import dataclasses
import
gc
import
inspect
import
itertools
import
os
import
time
import
weakref
from
contextlib
import
contextmanager
...
...
@@ -59,6 +60,8 @@ from vllm.worker.model_runner_base import (
_init_attn_metadata_from_tensor_dict
,
_init_sampling_metadata_from_tensor_dict
)
from
vllm.model_executor.layers.ops.update_input
import
UpdateInputTokens
from
vllm.profiler.prof
import
profile
if
TYPE_CHECKING
:
...
...
@@ -476,6 +479,14 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self
.
sliding_window_blocks
*
self
.
block_size
self
.
is_encoder_decoder_model
=
self
.
runner
.
model_config
.
is_encoder_decoder
self
.
zero_overhead
=
os
.
environ
.
get
(
'VLLM_ZERO_OVERHEAD'
)
==
'1'
self
.
last_sample_tensor
=
None
self
.
last_sample_ids
=
None
self
.
req_ids
=
[]
def SetLastSamperData(self, last_sample_ids, last_sample_tensor):
    """Store the previous step's sample ids and sample tensor on the builder."""
    self.last_sample_tensor, self.last_sample_ids = (last_sample_tensor,
                                                     last_sample_ids)
def
prepare
(
self
,
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
)
->
None
:
...
...
@@ -491,6 +502,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
ModelInputForGPUBuilder
.
InterDataForSeqGroup
]
=
[]
self
.
attn_metadata_builder
.
prepare
()
self
.
req_ids
.
clear
()
def
_compute_lens
(
self
,
inter_data
:
InterDataForSeqGroup
,
seq_idx
:
int
,
seq_group_metadata
:
SequenceGroupMetadata
):
...
...
@@ -756,8 +768,9 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
encoder_seq_len
=
encoder_seq_len
)
self
.
inter_data_list
.
append
(
inter_data
)
seq_ids
=
list
(
seq_ids
)
for
seq_idx
in
range
(
n_seqs
):
self
.
req_ids
.
append
(
seq_ids
[
seq_idx
])
for
per_seq_fn
in
self
.
per_seq_compute_fns
:
per_seq_fn
(
inter_data
,
seq_idx
,
seq_group_metadata
)
for
per_seq_group_fn
in
self
.
per_seq_group_compute_fns
:
...
...
@@ -898,10 +911,18 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
if
cuda_graph_pad_size
:
input_tokens
.
extend
(
itertools
.
repeat
(
0
,
cuda_graph_pad_size
))
assert
self
.
runner
.
device
is
not
None
input_tokens_tensor
=
async_tensor_h2d
(
input_tokens
,
torch
.
long
,
self
.
runner
.
device
,
self
.
runner
.
pin_memory
)
if
self
.
zero_overhead
and
self
.
last_sample_tensor
is
not
None
:
input_ids
=
torch
.
tensor
(
self
.
req_ids
,
device
=
'cuda'
)
UpdateInputTokens
(
input_tokens_tensor
,
input_ids
,
self
.
last_sample_tensor
,
self
.
last_sample_ids
)
print
(
'####input_tokens_tensor'
,
input_tokens_tensor
)
print
(
'####input_ids'
,
input_ids
)
print
(
'####self.last_sample_tensor'
,
self
.
last_sample_tensor
)
print
(
'####self.last_sample_ids'
,
self
.
last_sample_ids
)
token_types_tensor
=
async_tensor_h2d
(
token_types
,
torch
.
long
,
self
.
runner
.
device
,
...
...
@@ -1200,7 +1221,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
def
_prepare_model_input_tensors
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
,
last_outputs_ids
:
torch
.
Tensor
=
None
,
last_output_sample
:
torch
.
Tensor
=
None
,
)
->
TModelInputForGPU
:
"""Helper method to prepare the model input based on a given sequence
group. Prepares metadata needed for the base model forward pass but not
...
...
@@ -1221,7 +1244,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self
.
builder
.
add_seq_group
(
seq_group_metadata
)
self
.
builder
.
reset_cached_inter_data
()
self
.
builder
.
SetLastSamperData
(
last_outputs_ids
,
last_output_sample
)
return
self
.
builder
.
build
()
# type: ignore
@
contextmanager
...
...
@@ -1616,6 +1639,8 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
virtual_engine
:
int
=
0
,
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
,
last_outputs_ids
:
torch
.
Tensor
=
None
,
last_output_sample
:
torch
.
Tensor
=
None
,
)
->
ModelInputForGPUWithSamplingMetadata
:
"""Prepare the model input based on a given sequence group, including
metadata for the sampling step.
...
...
@@ -1631,7 +1656,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
If cuda graph is required, this API automatically pads inputs.
"""
model_input
=
self
.
_prepare_model_input_tensors
(
seq_group_metadata_list
,
finished_requests_ids
)
seq_group_metadata_list
,
finished_requests_ids
,
last_outputs_ids
,
last_output_sample
)
if
get_pp_group
().
is_last_rank
:
# Sampling metadata is only required for the final pp group
generators
=
self
.
get_generators
(
finished_requests_ids
)
...
...
vllm/worker/model_runner_base.py
View file @
7d867671
...
...
@@ -189,7 +189,6 @@ class ModelRunnerBase(ABC, Generic[T]):
self
.
speculative_config
=
vllm_config
.
speculative_config
self
.
prompt_adapter_config
=
vllm_config
.
prompt_adapter_config
self
.
observability_config
=
vllm_config
.
observability_config
self
.
last_output
=
None
# Map of request_id -> generator used for seeded random sampling
generators
:
Dict
[
str
,
torch
.
Generator
]
=
{}
...
...
@@ -211,6 +210,8 @@ class ModelRunnerBase(ABC, Generic[T]):
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
virtual_engine
:
int
=
0
,
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
,
last_outputs_ids
:
torch
.
Tensor
=
None
,
last_output_sample
:
torch
.
Tensor
=
None
,
)
->
T
:
"""
Prepare the inputs to ModelRunnerBase.execute_model from an execution
...
...
vllm/worker/worker_base.py
View file @
7d867671
...
...
@@ -353,12 +353,13 @@ class LocalOrDistributedWorkerBase(WorkerBase):
worker_input
:
WorkerInput
=
self
.
prepare_worker_input
(
execute_model_req
=
execute_model_req
)
self
.
model_runner
.
last_output
=
execute_model_req
.
last_outputs
model_input
:
ModelRunnerInputBase
=
(
self
.
model_runner
.
prepare_model_input
(
execute_model_req
.
seq_group_metadata_list
,
execute_model_req
.
virtual_engine
,
execute_model_req
.
finished_requests_ids
))
execute_model_req
.
finished_requests_ids
,
last_outputs_ids
=
execute_model_req
.
last_outputs_ids
,
last_output_sample
=
execute_model_req
.
last_outputs_sample
))
if
self
.
tree_decoding
and
execute_model_req
.
tree_position_ids
is
not
None
and
\
execute_model_req
.
tree_attn_masks
is
not
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment