Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3bbb6e9d
"ssh:/git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "80221e1884bd36751e8ae0308acf71f42946a05e"
Commit
3bbb6e9d
authored
Apr 29, 2025
by
lizhigong
Browse files
fix,去掉对base类的耦合
parent
62920e37
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
52 additions
and
99 deletions
+52
-99
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+10
-13
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+0
-1
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+15
-13
vllm/model_executor/layers/update_input.py
vllm/model_executor/layers/update_input.py
+0
-28
vllm/sequence.py
vllm/sequence.py
+1
-9
vllm/spec_decode/target_model_runner.py
vllm/spec_decode/target_model_runner.py
+2
-4
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+22
-25
vllm/worker/model_runner_base.py
vllm/worker/model_runner_base.py
+1
-3
vllm/worker/worker_base.py
vllm/worker/worker_base.py
+1
-3
No files found.
vllm/engine/llm_engine.py
View file @
3bbb6e9d
...
@@ -43,7 +43,7 @@ from vllm.logits_process import get_bad_words_logits_processors
...
@@ -43,7 +43,7 @@ from vllm.logits_process import get_bad_words_logits_processors
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.guided_decoding
import
(
from
vllm.model_executor.guided_decoding
import
(
get_local_guided_decoding_logits_processor
)
get_local_guided_decoding_logits_processor
)
from
vllm.model_executor.layers.sampler
import
Sample
rOutput
from
vllm.model_executor.layers.sampler
import
Sample
Recorder
,
SamplerOutput
,
get_last_sampler
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalRegistry
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalRegistry
from
vllm.outputs
import
(
PoolingRequestOutput
,
RequestOutput
,
from
vllm.outputs
import
(
PoolingRequestOutput
,
RequestOutput
,
RequestOutputFactory
)
RequestOutputFactory
)
...
@@ -1246,12 +1246,13 @@ class LLMEngine:
...
@@ -1246,12 +1246,13 @@ class LLMEngine:
def
_fix_last_step
(
def
_fix_last_step
(
self
,
output
:
List
[
SamplerOutput
],
self
,
output
:
List
[
SamplerOutput
],
last_sampler
:
SampleRecorder
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
scheduled_seq_groups
:
List
[
ScheduledSequenceGroup
])
->
None
:
scheduled_seq_groups
:
List
[
ScheduledSequenceGroup
])
->
None
:
#sample_out_list = output[0].sampler_out_tenosr.cpu().tolist()
#sample_out_list = output[0].sampler_out_tenosr.cpu().tolist()
sample_out_list
=
self
.
async_d2h
.
tolist
()
sample_out_list
=
self
.
async_d2h
.
tolist
()
sample_out_ids
=
output
[
0
].
sampler_out_ids
.
tolist
()
sample_out_ids
=
last_sampler
.
seq_ids
for
seq_group_metadata
,
sequence_group_outputs
,
scheduled_seq_group
in
\
for
seq_group_metadata
,
sequence_group_outputs
,
scheduled_seq_group
in
\
zip
(
seq_group_metadata_list
,
output
[
0
],
scheduled_seq_groups
):
zip
(
seq_group_metadata_list
,
output
[
0
],
scheduled_seq_groups
):
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
=
scheduled_seq_group
.
seq_group
...
@@ -1339,12 +1340,9 @@ class LLMEngine:
...
@@ -1339,12 +1340,9 @@ class LLMEngine:
(
seq_group_metadata_list
,
scheduler_outputs
,
(
seq_group_metadata_list
,
scheduler_outputs
,
allow_async_output_proc
allow_async_output_proc
)
=
self
.
scheduler
[
virtual_engine
].
schedule
()
)
=
self
.
scheduler
[
virtual_engine
].
schedule
()
last_outputs_ids
=
None
last_outputs_tensor
=
None
if
self
.
last_record
is
not
None
:
if
self
.
last_record
is
not
None
:
last_output
=
self
.
last_record
[
0
][
0
]
last_sampler
=
self
.
last_record
[
1
]
last_outputs_ids
,
last_outputs_tensor
=
last_output
.
sampler_out_ids
,
last_output
.
sampler_out_tenosr
self
.
async_d2h
=
last_sampler
.
sampled_token_ids_tensor
.
to
(
'cpu'
,
non_blocking
=
True
)
self
.
async_d2h
=
last_outputs_tensor
.
to
(
'cpu'
,
non_blocking
=
True
)
self
.
async_event
.
record
()
self
.
async_event
.
record
()
self
.
q_recorder
.
put
(
self
.
last_record
)
self
.
q_recorder
.
put
(
self
.
last_record
)
else
:
else
:
...
@@ -1371,9 +1369,7 @@ class LLMEngine:
...
@@ -1371,9 +1369,7 @@ class LLMEngine:
finished_requests_ids
=
finished_requests_ids
,
finished_requests_ids
=
finished_requests_ids
,
# We use ExecuteModelRequest to pass the last sampled_token_ids
# We use ExecuteModelRequest to pass the last sampled_token_ids
# to each of the non-last PP stages for in-place prepare_input.
# to each of the non-last PP stages for in-place prepare_input.
last_sampled_token_ids
=
last_sampled_token_ids
,
last_sampled_token_ids
=
last_sampled_token_ids
)
last_outputs_ids
=
last_outputs_ids
,
last_outputs_sample
=
last_outputs_tensor
)
outputs
=
self
.
model_executor
.
execute_model
(
outputs
=
self
.
model_executor
.
execute_model
(
execute_model_req
=
execute_model_req
)
execute_model_req
=
execute_model_req
)
...
@@ -1383,7 +1379,8 @@ class LLMEngine:
...
@@ -1383,7 +1379,8 @@ class LLMEngine:
outputs
[
0
],
seq_group_metadata_list
,
outputs
[
0
],
seq_group_metadata_list
,
scheduler_outputs
.
scheduled_seq_groups
)
scheduler_outputs
.
scheduled_seq_groups
)
scheduler_outputs
.
scheduled_seq_groups
=
[
item
for
item
in
scheduler_outputs
.
scheduled_seq_groups
]
#deep copy
scheduler_outputs
.
scheduled_seq_groups
=
[
item
for
item
in
scheduler_outputs
.
scheduled_seq_groups
]
#deep copy
self
.
last_record
=
[
outputs
,
seq_group_metadata_list
,
scheduler_outputs
]
last_sampler
=
get_last_sampler
()
self
.
last_record
=
[
outputs
,
last_sampler
,
seq_group_metadata_list
,
scheduler_outputs
]
except
Exception
as
e
:
except
Exception
as
e
:
print
(
f
"thread_zero_overhead error :
{
e
}
"
)
print
(
f
"thread_zero_overhead error :
{
e
}
"
)
...
@@ -1402,12 +1399,12 @@ class LLMEngine:
...
@@ -1402,12 +1399,12 @@ class LLMEngine:
virtual_engine
=
0
virtual_engine
=
0
ctx
=
self
.
scheduler_contexts
[
virtual_engine
]
ctx
=
self
.
scheduler_contexts
[
virtual_engine
]
ctx
.
request_outputs
.
clear
()
ctx
.
request_outputs
.
clear
()
outputs
,
seq_group_metadata_list
,
scheduler_outputs
=
recode_output
outputs
,
last_sampler
,
seq_group_metadata_list
,
scheduler_outputs
=
recode_output
ctx
.
seq_group_metadata_list
=
seq_group_metadata_list
ctx
.
seq_group_metadata_list
=
seq_group_metadata_list
ctx
.
scheduler_outputs
=
scheduler_outputs
ctx
.
scheduler_outputs
=
scheduler_outputs
self
.
async_event
.
synchronize
()
self
.
async_event
.
synchronize
()
self
.
_fix_last_step
(
self
.
_fix_last_step
(
outputs
,
seq_group_metadata_list
,
outputs
,
last_sampler
,
seq_group_metadata_list
,
scheduler_outputs
.
scheduled_seq_groups
)
scheduler_outputs
.
scheduled_seq_groups
)
# is_first_step_output is True only when the num_steps of all
# is_first_step_output is True only when the num_steps of all
...
...
vllm/entrypoints/llm.py
View file @
3bbb6e9d
...
@@ -1412,7 +1412,6 @@ class LLM:
...
@@ -1412,7 +1412,6 @@ class LLM:
if
use_tqdm
:
if
use_tqdm
:
pbar
.
close
()
pbar
.
close
()
self
.
llm_engine
.
finish_thread
()
# Sort the outputs by request ID.
# Sort the outputs by request ID.
# This is necessary because some requests may be finished earlier than
# This is necessary because some requests may be finished earlier than
# its previous requests.
# its previous requests.
...
...
vllm/model_executor/layers/sampler.py
View file @
3bbb6e9d
...
@@ -70,15 +70,17 @@ class SampleResultArgsType:
...
@@ -70,15 +70,17 @@ class SampleResultArgsType:
sampling_metadata
:
SamplingMetadata
sampling_metadata
:
SamplingMetadata
greedy_samples
:
Optional
[
torch
.
Tensor
]
greedy_samples
:
Optional
[
torch
.
Tensor
]
beam_search_logprobs
:
Optional
[
torch
.
Tensor
]
beam_search_logprobs
:
Optional
[
torch
.
Tensor
]
# Implemented by guanyu
@
dataclass
class
Sample
DeviceToDevices
:
class
Sample
Recorder
:
def
__init__
(
self
):
def
__init__
(
self
):
self
.
seq_id
:
torch
.
Tensor
=
None
self
.
seq_id
s
:
torch
.
Tensor
=
None
self
.
sampled_token_ids_tensor
:
torch
.
Tensor
=
None
self
.
sampled_token_ids_tensor
:
torch
.
Tensor
=
None
self
.
zero_overhead
:
bool
=
False
d2d_data
=
SampleDeviceToDevices
()
last_sampler
=
None
def
get_last_sampler
():
return
last_sampler
# Union of non-deferred (single-step scheduling)
# Union of non-deferred (single-step scheduling)
# vs deferred (multi-step scheduling)
# vs deferred (multi-step scheduling)
...
@@ -266,6 +268,8 @@ class Sampler(nn.Module):
...
@@ -266,6 +268,8 @@ class Sampler(nn.Module):
logits: (num_tokens, vocab_size).
logits: (num_tokens, vocab_size).
sampling_metadata: Metadata for sampling.
sampling_metadata: Metadata for sampling.
"""
"""
global
last_sampler
last_sampler
=
SampleRecorder
()
assert
logits
is
not
None
assert
logits
is
not
None
_
,
vocab_size
=
logits
.
shape
_
,
vocab_size
=
logits
.
shape
...
@@ -763,10 +767,10 @@ def _sample_with_torch(
...
@@ -763,10 +767,10 @@ def _sample_with_torch(
t
:
[]
t
:
[]
for
t
in
SamplingType
for
t
in
SamplingType
}
}
d2d_data
.
seq_id
=
torch
.
zeros
(
len
(
sampling_metadata
.
seq_groups
),
dtype
=
torch
.
int32
)
last_sampler
.
seq_ids
=
[]
categorized_sample_indices
=
sampling_metadata
.
categorized_sample_indices
categorized_sample_indices
=
sampling_metadata
.
categorized_sample_indices
for
i
,
seq_group
in
enumerate
(
sampling_metadata
.
seq_groups
):
for
i
,
seq_group
in
enumerate
(
sampling_metadata
.
seq_groups
):
d2d_data
.
seq_id
[
i
]
=
seq_group
.
seq_ids
[
0
]
last_sampler
.
seq_ids
.
append
(
seq_group
.
seq_ids
[
0
]
)
sampling_params
=
seq_group
.
sampling_params
sampling_params
=
seq_group
.
sampling_params
sampling_type
=
sampling_params
.
sampling_type
sampling_type
=
sampling_params
.
sampling_type
categorized_seq_group_ids
[
sampling_type
].
append
(
i
)
categorized_seq_group_ids
[
sampling_type
].
append
(
i
)
...
@@ -801,8 +805,7 @@ def _sample_with_torch(
...
@@ -801,8 +805,7 @@ def _sample_with_torch(
greedy_samples
=
torch
.
argmax
(
logprobs
[
long_sample_indices
],
greedy_samples
=
torch
.
argmax
(
logprobs
[
long_sample_indices
],
dim
=-
1
)
dim
=-
1
)
if
d2d_data
.
zero_overhead
:
last_sampler
.
sampled_token_ids_tensor
=
greedy_samples
.
unsqueeze
(
-
1
)
d2d_data
.
sampled_token_ids_tensor
=
greedy_samples
.
unsqueeze
(
-
1
)
if
sampled_token_ids_tensor
is
not
None
:
if
sampled_token_ids_tensor
is
not
None
:
# Store sampled tokens in output tensor.
# Store sampled tokens in output tensor.
...
@@ -841,9 +844,8 @@ def _sample_with_torch(
...
@@ -841,9 +844,8 @@ def _sample_with_torch(
max_n_in_batch
,
max_n_in_batch
,
seq_groups
=
seq_groups_arg
)
seq_groups
=
seq_groups_arg
)
if
d2d_data
.
zero_overhead
:
last_sampler
.
sampled_token_ids_tensor
=
\
d2d_data
.
sampled_token_ids_tensor
=
\
multinomial_samples
[
sampling_type
].
to
(
torch
.
long
)
multinomial_samples
[
sampling_type
].
to
(
torch
.
long
)
if
sampled_token_ids_tensor
is
not
None
:
if
sampled_token_ids_tensor
is
not
None
:
# Store sampled tokens in output tensor.
# Store sampled tokens in output tensor.
...
...
vllm/model_executor/layers/update_input.py
deleted
100644 → 0
View file @
62920e37
import
torch
import
triton
import
triton.language
as
tl
@
triton
.
jit
def
_update_input_tokens
(
sample_output
,
seq_ids
,
input_tokens
,
input_seq_ids
,
BATCH_SIZE1
,
BATCH_SIZE2
,
):
pid
=
tl
.
program_id
(
0
)
if
pid
>=
BATCH_SIZE2
:
return
output_token
=
tl
.
load
(
input_tokens
+
pid
)
_input_seq_id
=
tl
.
load
(
input_seq_ids
+
pid
)
for
i
in
range
(
BATCH_SIZE1
):
_seq_ids
=
tl
.
load
(
seq_ids
+
i
)
if
_seq_ids
==
_input_seq_id
:
output_token
=
tl
.
load
(
sample_output
+
i
)
tl
.
store
(
input_tokens
+
pid
,
output_token
)
def
UpdateInputTokens
(
input_tokens
,
input_seq_ids
,
last_sample
,
last_ids
):
grid
=
[
input_seq_ids
.
shape
[
0
],
1
,
1
]
_update_input_tokens
[
grid
](
last_sample
,
last_ids
,
input_tokens
,
input_seq_ids
,
last_ids
.
shape
[
0
],
input_seq_ids
.
shape
[
0
])
\ No newline at end of file
vllm/sequence.py
View file @
3bbb6e9d
...
@@ -1465,12 +1465,6 @@ class ExecuteModelRequest(
...
@@ -1465,12 +1465,6 @@ class ExecuteModelRequest(
# Optional slot mapping of kvcache that pending to be moved generated from draft model.
# Optional slot mapping of kvcache that pending to be moved generated from draft model.
kvcache_slot_to_be_moved
:
Optional
[
torch
.
Tensor
]
=
None
kvcache_slot_to_be_moved
:
Optional
[
torch
.
Tensor
]
=
None
# for zero-overhead scheduler
last_outputs_sample
:
Optional
[
torch
.
Tensor
]
=
None
# for zero-overhead scheduler
last_outputs_ids
:
Optional
[
torch
.
Tensor
]
=
None
@
property
@
property
def
is_first_multi_step
(
self
)
->
bool
:
def
is_first_multi_step
(
self
)
->
bool
:
# TODO(will) make this be able to handle batches with variable number of
# TODO(will) make this be able to handle batches with variable number of
...
@@ -1520,9 +1514,7 @@ class ExecuteModelRequest(
...
@@ -1520,9 +1514,7 @@ class ExecuteModelRequest(
async_callback
=
self
.
async_callback
,
async_callback
=
self
.
async_callback
,
tree_attn_masks
=
self
.
tree_attn_masks
,
tree_attn_masks
=
self
.
tree_attn_masks
,
tree_position_ids
=
self
.
tree_position_ids
,
tree_position_ids
=
self
.
tree_position_ids
,
kvcache_slot_to_be_moved
=
self
.
kvcache_slot_to_be_moved
,
kvcache_slot_to_be_moved
=
self
.
kvcache_slot_to_be_moved
)
last_outputs_sample
=
self
.
last_outputs_sample
,
last_outputs_ids
=
self
.
last_outputs_ids
)
@
dataclass
@
dataclass
...
...
vllm/spec_decode/target_model_runner.py
View file @
3bbb6e9d
...
@@ -30,13 +30,11 @@ class TargetModelRunner(ModelRunnerWrapperBase):
...
@@ -30,13 +30,11 @@ class TargetModelRunner(ModelRunnerWrapperBase):
self
,
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
virtual_engine
:
int
=
0
,
virtual_engine
:
int
=
0
,
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
,
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
last_outputs_ids
:
torch
.
Tensor
=
None
,
last_output_sample
:
torch
.
Tensor
=
None
,
)
->
ModelRunnerInputBase
:
)
->
ModelRunnerInputBase
:
model_input
:
ModelRunnerInputBase
=
\
model_input
:
ModelRunnerInputBase
=
\
self
.
model_runner
.
prepare_model_input
(
self
.
model_runner
.
prepare_model_input
(
seq_group_metadata_list
,
virtual_engine
,
finished_requests_ids
,
last_outputs_ids
,
last_output_sample
)
seq_group_metadata_list
,
virtual_engine
,
finished_requests_ids
)
# If token log probabilities is disabled then skip generating sampler
# If token log probabilities is disabled then skip generating sampler
# CPU output. We directly serialize the GPU sampled_token_id tensors
# CPU output. We directly serialize the GPU sampled_token_id tensors
# as needed. If log probabilities is enabled then synchronize all the
# as needed. If log probabilities is enabled then synchronize all the
...
...
vllm/worker/model_runner.py
View file @
3bbb6e9d
...
@@ -36,7 +36,7 @@ from vllm.lora.request import LoRARequest
...
@@ -36,7 +36,7 @@ from vllm.lora.request import LoRARequest
from
vllm.lora.worker_manager
import
LRUCacheWorkerLoRAManager
from
vllm.lora.worker_manager
import
LRUCacheWorkerLoRAManager
from
vllm.model_executor
import
SamplingMetadata
,
SamplingMetadataCache
from
vllm.model_executor
import
SamplingMetadata
,
SamplingMetadataCache
from
vllm.model_executor.layers.rotary_embedding
import
MRotaryEmbedding
from
vllm.model_executor.layers.rotary_embedding
import
MRotaryEmbedding
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_last_sampler
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.model_loader.tensorizer
import
TensorizerConfig
from
vllm.model_executor.model_loader.tensorizer
import
TensorizerConfig
from
vllm.model_executor.models
import
supports_lora
,
supports_multimodal
from
vllm.model_executor.models
import
supports_lora
,
supports_multimodal
...
@@ -61,8 +61,6 @@ from vllm.worker.model_runner_base import (
...
@@ -61,8 +61,6 @@ from vllm.worker.model_runner_base import (
_init_attn_metadata_from_tensor_dict
,
_init_attn_metadata_from_tensor_dict
,
_init_sampling_metadata_from_tensor_dict
)
_init_sampling_metadata_from_tensor_dict
)
from
vllm.model_executor.layers.update_input
import
UpdateInputTokens
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.attention.backends.abstract
import
AttentionBackend
...
@@ -479,14 +477,8 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
...
@@ -479,14 +477,8 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self
.
is_encoder_decoder_model
=
self
.
runner
.
model_config
.
is_encoder_decoder
self
.
is_encoder_decoder_model
=
self
.
runner
.
model_config
.
is_encoder_decoder
self
.
zero_overhead
=
os
.
environ
.
get
(
'VLLM_ZERO_OVERHEAD'
)
==
'1'
self
.
zero_overhead
=
os
.
environ
.
get
(
'VLLM_ZERO_OVERHEAD'
)
==
'1'
self
.
last_sample_tensor
=
None
self
.
last_sample_ids
=
None
self
.
req_ids
=
[]
self
.
req_ids
=
[]
def
SetLastSamperData
(
self
,
last_sample_ids
,
last_sample_tensor
):
self
.
last_sample_tensor
=
last_sample_tensor
self
.
last_sample_ids
=
last_sample_ids
def
prepare
(
self
,
def
prepare
(
self
,
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
)
->
None
:
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
)
->
None
:
self
.
finished_requests_ids
=
finished_requests_ids
self
.
finished_requests_ids
=
finished_requests_ids
...
@@ -915,14 +907,24 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
...
@@ -915,14 +907,24 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self
.
runner
.
device
,
self
.
runner
.
device
,
self
.
runner
.
pin_memory
)
self
.
runner
.
pin_memory
)
if
self
.
zero_overhead
and
self
.
last_sample_tensor
is
not
None
:
if
self
.
zero_overhead
:
input_ids
=
async_tensor_h2d
(
self
.
req_ids
,
torch
.
long
,
last_sampler
=
get_last_sampler
()
self
.
runner
.
device
,
update_indices
=
[]
self
.
runner
.
pin_memory
)
select_indices
=
[]
last_ids
=
async_tensor_h2d
(
self
.
last_sample_ids
.
tolist
(),
torch
.
long
,
for
i
,
seq_id
in
enumerate
(
self
.
req_ids
):
self
.
runner
.
device
,
for
j
,
seq_id_
in
enumerate
(
last_sampler
.
seq_ids
):
self
.
runner
.
pin_memory
)
if
seq_id
==
seq_id_
:
UpdateInputTokens
(
input_tokens_tensor
,
input_ids
,
self
.
last_sample_tensor
,
last_ids
)
select_indices
.
append
(
j
)
update_indices
.
append
(
i
)
break
if
len
(
select_indices
)
>
0
:
select_indices
=
async_tensor_h2d
(
select_indices
,
torch
.
long
,
self
.
runner
.
device
,
self
.
runner
.
pin_memory
)
update_indices
=
async_tensor_h2d
(
update_indices
,
torch
.
long
,
self
.
runner
.
device
,
self
.
runner
.
pin_memory
)
input_tokens_tensor
[
update_indices
]
=
last_sampler
.
sampled_token_ids_tensor
[
select_indices
,
0
]
token_types_tensor
=
async_tensor_h2d
(
token_types
,
torch
.
long
,
token_types_tensor
=
async_tensor_h2d
(
token_types
,
torch
.
long
,
self
.
runner
.
device
,
self
.
runner
.
device
,
...
@@ -1225,9 +1227,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -1225,9 +1227,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
def
_prepare_model_input_tensors
(
def
_prepare_model_input_tensors
(
self
,
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
,
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
last_outputs_ids
:
torch
.
Tensor
=
None
,
last_output_sample
:
torch
.
Tensor
=
None
,
)
->
TModelInputForGPU
:
)
->
TModelInputForGPU
:
"""Helper method to prepare the model input based on a given sequence
"""Helper method to prepare the model input based on a given sequence
group. Prepares metadata needed for the base model forward pass but not
group. Prepares metadata needed for the base model forward pass but not
...
@@ -1248,7 +1248,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -1248,7 +1248,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self
.
builder
.
add_seq_group
(
seq_group_metadata
)
self
.
builder
.
add_seq_group
(
seq_group_metadata
)
self
.
builder
.
reset_cached_inter_data
()
self
.
builder
.
reset_cached_inter_data
()
self
.
builder
.
SetLastSamperData
(
last_outputs_ids
,
last_output_sample
)
return
self
.
builder
.
build
()
# type: ignore
return
self
.
builder
.
build
()
# type: ignore
@
contextmanager
@
contextmanager
...
@@ -1642,9 +1641,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
...
@@ -1642,9 +1641,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
self
,
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
virtual_engine
:
int
=
0
,
virtual_engine
:
int
=
0
,
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
,
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
last_outputs_ids
:
torch
.
Tensor
=
None
,
last_output_sample
:
torch
.
Tensor
=
None
,
)
->
ModelInputForGPUWithSamplingMetadata
:
)
->
ModelInputForGPUWithSamplingMetadata
:
"""Prepare the model input based on a given sequence group, including
"""Prepare the model input based on a given sequence group, including
metadata for the sampling step.
metadata for the sampling step.
...
@@ -1660,7 +1657,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
...
@@ -1660,7 +1657,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
If cuda graph is required, this API automatically pads inputs.
If cuda graph is required, this API automatically pads inputs.
"""
"""
model_input
=
self
.
_prepare_model_input_tensors
(
model_input
=
self
.
_prepare_model_input_tensors
(
seq_group_metadata_list
,
finished_requests_ids
,
last_outputs_ids
,
last_output_sample
)
seq_group_metadata_list
,
finished_requests_ids
)
if
get_pp_group
().
is_last_rank
:
if
get_pp_group
().
is_last_rank
:
# Sampling metadata is only required for the final pp group
# Sampling metadata is only required for the final pp group
generators
=
self
.
get_generators
(
finished_requests_ids
)
generators
=
self
.
get_generators
(
finished_requests_ids
)
...
...
vllm/worker/model_runner_base.py
View file @
3bbb6e9d
...
@@ -209,9 +209,7 @@ class ModelRunnerBase(ABC, Generic[T]):
...
@@ -209,9 +209,7 @@ class ModelRunnerBase(ABC, Generic[T]):
self
,
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
virtual_engine
:
int
=
0
,
virtual_engine
:
int
=
0
,
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
,
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
last_outputs_ids
:
torch
.
Tensor
=
None
,
last_output_sample
:
torch
.
Tensor
=
None
,
)
->
T
:
)
->
T
:
"""
"""
Prepare the inputs to ModelRunnerBase.execute_model from an execution
Prepare the inputs to ModelRunnerBase.execute_model from an execution
...
...
vllm/worker/worker_base.py
View file @
3bbb6e9d
...
@@ -374,9 +374,7 @@ class LocalOrDistributedWorkerBase(WorkerBase):
...
@@ -374,9 +374,7 @@ class LocalOrDistributedWorkerBase(WorkerBase):
self
.
model_runner
.
prepare_model_input
(
self
.
model_runner
.
prepare_model_input
(
execute_model_req
.
seq_group_metadata_list
,
execute_model_req
.
seq_group_metadata_list
,
execute_model_req
.
virtual_engine
,
execute_model_req
.
virtual_engine
,
execute_model_req
.
finished_requests_ids
,
execute_model_req
.
finished_requests_ids
))
last_outputs_ids
=
execute_model_req
.
last_outputs_ids
,
last_output_sample
=
execute_model_req
.
last_outputs_sample
))
if
self
.
tree_decoding
and
execute_model_req
.
tree_position_ids
is
not
None
and
\
if
self
.
tree_decoding
and
execute_model_req
.
tree_position_ids
is
not
None
and
\
execute_model_req
.
tree_attn_masks
is
not
None
:
execute_model_req
.
tree_attn_masks
is
not
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment