Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2a935929
Commit
2a935929
authored
May 16, 2025
by
lizhigong
Browse files
修复zero-overhead首字正确性问题,zero-overhead不使用默认流调整,增加two-batch-overlap功能
parent
cf1d8464
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
649 additions
and
83 deletions
+649
-83
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+2
-0
vllm/executor/executor_base.py
vllm/executor/executor_base.py
+3
-0
vllm/forward_context.py
vllm/forward_context.py
+17
-1
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+7
-1
vllm/model_executor/layers/vocab_parallel_embedding.py
vllm/model_executor/layers/vocab_parallel_embedding.py
+7
-1
vllm/two_batch_overlap/forward_context.py
vllm/two_batch_overlap/forward_context.py
+35
-0
vllm/two_batch_overlap/two_batch_overlap.py
vllm/two_batch_overlap/two_batch_overlap.py
+465
-0
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+26
-11
vllm/worker/worker_base.py
vllm/worker/worker_base.py
+2
-1
vllm/zero_overhead/llm_engine.py
vllm/zero_overhead/llm_engine.py
+75
-68
vllm/zero_overhead/utils.py
vllm/zero_overhead/utils.py
+10
-0
No files found.
vllm/engine/llm_engine.py
View file @
2a935929
...
...
@@ -62,6 +62,7 @@ from vllm.utils import (Counter, Device, deprecate_kwargs,
resolve_obj_by_qualname
,
weak_bind
)
from
vllm.version
import
__version__
as
VLLM_VERSION
from
vllm.worker.model_runner_base
import
InputProcessingError
from
vllm.profiler.prof
import
profile
logger
=
init_logger
(
__name__
)
_LOCAL_LOGGING_INTERVAL_SEC
=
5
...
...
@@ -413,6 +414,7 @@ class LLMEngine:
# Flag to set when an input fails to process and the engine should run
# the next step without re-scheduling.
self
.
_skip_scheduling_next_step
=
False
profile
.
StartTracer
()
def
_initialize_kv_caches
(
self
)
->
None
:
"""Initialize the KV cache in the worker(s).
...
...
vllm/executor/executor_base.py
View file @
2a935929
...
...
@@ -16,6 +16,7 @@ from vllm.lora.request import LoRARequest
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sequence
import
ExecuteModelRequest
,
PoolerOutput
from
vllm.two_batch_overlap.two_batch_overlap
import
finish_two_batch_overlap
from
vllm.utils
import
make_async
from
vllm.worker.worker_base
import
WorkerBase
...
...
@@ -143,6 +144,7 @@ class ExecutorBase(ABC):
def stop_remote_worker_execution_loop(self) -> None:
    """Release parallel workers from the model loop.

    This base implementation only tears down the two-batch-overlap helper
    through ``finish_two_batch_overlap`` and then returns ``None``.
    """
    finish_two_batch_overlap()
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
...
...
@@ -301,6 +303,7 @@ class DistributedExecutorBase(ExecutorBase):
return
driver_outputs
def
stop_remote_worker_execution_loop
(
self
)
->
None
:
finish_two_batch_overlap
()
if
self
.
parallel_worker_tasks
is
None
:
return
...
...
vllm/forward_context.py
View file @
2a935929
# SPDX-License-Identifier: Apache-2.0
import
os
import
time
from
collections
import
defaultdict
from
contextlib
import
contextmanager
...
...
@@ -16,6 +17,7 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
is_v1_kv_transfer_group
)
from
vllm.distributed.kv_transfer.kv_connector.v1
import
KVConnectorBase_V1
from
vllm.logger
import
init_logger
from
vllm.two_batch_overlap.forward_context
import
get_tbo_forward_context
,
set_tbo_forward_context
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionMetadata
...
...
@@ -28,6 +30,9 @@ forward_start_time: float = 0
batchsize_logging_interval
:
float
=
envs
.
VLLM_LOG_BATCHSIZE_INTERVAL
batchsize_forward_time
:
defaultdict
=
defaultdict
(
list
)
enable_tbo
=
os
.
environ
.
get
(
'VLLM_ENABLE_TBO'
)
==
'1'
def is_enable_tbo():
    """Return the module-level ``enable_tbo`` flag.

    The flag is computed once at import time and is True only when the
    ``VLLM_ENABLE_TBO`` environment variable equals ``'1'``.
    """
    return enable_tbo
@
dataclass
class
DPMetadata
:
...
...
@@ -50,6 +55,14 @@ _forward_context: Optional[ForwardContext] = None
def
get_forward_context
()
->
ForwardContext
:
if
is_enable_tbo
():
forward_context
=
get_tbo_forward_context
()
"""Get the current forward context."""
assert
forward_context
is
not
None
,
(
"Forward context is not set. "
"Please use `set_forward_context` to set the forward context."
)
return
forward_context
"""Get the current forward context."""
assert
_forward_context
is
not
None
,
(
"Forward context is not set. "
...
...
@@ -112,7 +125,8 @@ def set_forward_context(attn_metadata: Any,
kv_connector
=
get_kv_transfer_group
()
assert
isinstance
(
kv_connector
,
KVConnectorBase_V1
)
kv_connector
.
start_load_kv
(
_forward_context
)
if
is_enable_tbo
():
set_tbo_forward_context
(
_forward_context
)
try
:
yield
finally
:
...
...
@@ -157,3 +171,5 @@ def set_forward_context(attn_metadata: Any,
kv_connector
.
wait_for_save
()
_forward_context
=
prev_context
if
is_enable_tbo
():
set_tbo_forward_context
(
_forward_context
)
vllm/model_executor/layers/linear.py
View file @
2a935929
...
...
@@ -1237,6 +1237,9 @@ class RowParallelLinear(LinearBase):
})
else
:
self
.
register_parameter
(
"bias"
,
None
)
from
vllm.two_batch_overlap.two_batch_overlap
import
tbo_all_reduce
,
is_enable_tbo
self
.
tbo_all_reduce
=
tbo_all_reduce
self
.
enable_tbo
=
is_enable_tbo
()
def
weight_loader
(
self
,
param
:
Parameter
,
loaded_weight
:
torch
.
Tensor
):
tp_rank
=
get_tensor_model_parallel_rank
()
...
...
@@ -1307,7 +1310,10 @@ class RowParallelLinear(LinearBase):
input_parallel
,
bias
=
bias_
)
if
self
.
reduce_results
and
self
.
tp_size
>
1
:
output
=
tensor_model_parallel_all_reduce
(
output_parallel
)
if
self
.
enable_tbo
:
output
=
self
.
tbo_all_reduce
(
output_parallel
)
else
:
output
=
tensor_model_parallel_all_reduce
(
output_parallel
)
else
:
output
=
output_parallel
...
...
vllm/model_executor/layers/vocab_parallel_embedding.py
View file @
2a935929
...
...
@@ -283,6 +283,9 @@ class VocabParallelEmbedding(torch.nn.Module):
self
.
num_embeddings_padded
,
params_dtype
=
params_dtype
,
weight_loader
=
self
.
weight_loader
)
from
vllm.two_batch_overlap.two_batch_overlap
import
tbo_all_reduce
,
is_enable_tbo
self
.
tbo_all_reduce
=
tbo_all_reduce
self
.
enable_tbo
=
is_enable_tbo
()
@
classmethod
def
_get_indices
(
cls
,
vocab_size_padded
:
int
,
org_vocab_size_padded
:
int
,
...
...
@@ -434,7 +437,10 @@ class VocabParallelEmbedding(torch.nn.Module):
if
self
.
tp_size
>
1
:
output_parallel
.
masked_fill_
(
input_mask
.
unsqueeze
(
-
1
),
0
)
# Reduce across all the model parallel GPUs.
output
=
tensor_model_parallel_all_reduce
(
output_parallel
)
if
self
.
enable_tbo
:
output
=
self
.
tbo_all_reduce
(
output_parallel
)
else
:
output
=
tensor_model_parallel_all_reduce
(
output_parallel
)
return
output
def
extra_repr
(
self
)
->
str
:
...
...
vllm/two_batch_overlap/forward_context.py
0 → 100644
View file @
2a935929
import
threading
# Forward context last stored by the thread registered as the "left" worker.
_forward_context_left = None
# Forward context last stored by any other thread (set_tbo_forward_context
# uses this slot for every caller whose ident differs from _left_tid).
_forward_context_right = None
# Thread ident of the left worker, registered via init_tbo_forward_context.
_left_tid = 0
# Thread ident of the right worker; recorded here but not read in this module.
_right_tid = 0
def init_tbo_forward_context(left_flag, tid):
    """Register *tid* as the ident of one of the two worker threads.

    Args:
        left_flag: truthy to record *tid* as the left worker, falsy to
            record it as the right worker.
        tid: thread identifier, as returned by ``threading.get_ident()``.
    """
    global _left_tid, _right_tid
    if not left_flag:
        _right_tid = tid
        return
    _left_tid = tid
def set_tbo_forward_context(_forward_context):
    """Store *_forward_context* in the slot that belongs to the calling thread.

    The left slot is used when the caller's ident equals ``_left_tid``;
    every other thread writes the right slot.
    """
    global _forward_context_left, _forward_context_right
    if threading.get_ident() == _left_tid:
        _forward_context_left = _forward_context
        return
    _forward_context_right = _forward_context
def get_tbo_forward_context():
    """Return the forward context stored for the calling thread.

    The left slot is returned when the caller's ident equals ``_left_tid``;
    every other thread gets the right slot.
    """
    called_from_left = threading.get_ident() == _left_tid
    return _forward_context_left if called_from_left else _forward_context_right
vllm/two_batch_overlap/two_batch_overlap.py
0 → 100644
View file @
2a935929
import
os
import
queue
import
threading
import
torch
from
vllm.attention.backends.rocm_flash_attn
import
ROCmFlashAttentionMetadata
from
vllm.distributed.communication_op
import
tensor_model_parallel_all_reduce
from
vllm.forward_context
import
set_forward_context
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal.inputs
import
MultiModalKwargs
from
vllm.two_batch_overlap.forward_context
import
init_tbo_forward_context
from
vllm.utils
import
async_tensor_h2d
from
vllm.logger
import
init_logger
from
vllm.profiler.prof
import
profile
# Feature switches, read once from the environment when this module is
# imported.  Each one is on only when the variable is exactly the string '1'.

# Master switch for two-batch overlap (TBO).
enable_tbo = os.environ.get('VLLM_ENABLE_TBO') == '1'
# Allow non-prompt (decode) batches to be split as well; without it
# tbo_model_executable runs them through the ordinary single-batch path.
enable_tbo_decode = os.environ.get('VLLM_TBO_DECODE') == '1'
# When set, the all-reduce is issued without switching to the dedicated
# stream and without the event record/wait pairs.
tbo_one_stream = os.environ.get('VLLM_TBO_ONE_STREAM') == '1'
logger = init_logger(__name__)
def is_enable_tbo():
    """Return the module-level ``enable_tbo`` flag.

    The flag is computed once at import time and is True only when the
    ``VLLM_ENABLE_TBO`` environment variable equals ``'1'``.
    """
    return enable_tbo
class TwoBatchOverlap():
    """Run the two halves of a batch on two worker threads.

    The caller passes both halves to ``set_model_input``, then services the
    workers' all-reduce requests with ``all_reduce`` and finally collects
    both results with ``get_model_output``.  At every synchronisation point
    a worker releases the other worker's semaphore and blocks on its own, so
    the two threads take turns.
    """

    def __init__(self):
        # One input queue per worker; ``None`` is the shutdown sentinel.
        self.model_input_left_queue = queue.Queue()
        self.model_input_right_queue = queue.Queue()
        # One result queue per worker.
        self.states_left_queue = queue.Queue()
        self.states_right_queue = queue.Queue()
        # Requests ``[tensor, event_c2t, event_t2c]`` consumed by
        # ``all_reduce``; ``None`` tells ``all_reduce`` to return.
        self.all_reduce_queue = queue.Queue()
        # Reduced tensors handed back to the requesting worker.
        self.all_reduce_out = queue.Queue()
        self.left_thread = None
        self.right_thread = None
        self.left_tid = 0
        self.right_tid = 0
        self.sem_left = threading.Semaphore(0)
        self.sem_right = threading.Semaphore(0)
        # True until the left worker has passed its first synchronisation.
        self.left_first = False
        # Read by ``tbo_all_reduce`` to decide whether to use the hand-off.
        self.tbo_running = False
        # Stream on which ``all_reduce`` issues the collective (unless
        # ``tbo_one_stream`` is set).
        self.stream = torch.cuda.Stream()
        # *_c2t: recorded by a worker, waited on by ``self.stream`` before the
        # all-reduce.  *_t2c: recorded after the all-reduce, waited on by the
        # worker's current stream in ``tbo_all_reduce``.
        self.event_left_c2t = torch.cuda.Event(enable_timing=False)
        self.event_right_c2t = torch.cuda.Event(enable_timing=False)
        self.event_left_t2c = torch.cuda.Event(enable_timing=False)
        self.event_right_t2c = torch.cuda.Event(enable_timing=False)

    @staticmethod
    def _drain_queue(pending):
        """Discard everything currently buffered in *pending* without blocking."""
        while True:
            try:
                pending.get_nowait()
            except queue.Empty:
                return

    def init_tbo_thread(self):
        """Start the left and the right worker thread.

        Both input queues are emptied first.  The original code called
        ``Queue.empty()`` here, which only reports whether the queue is empty
        and discards nothing.
        """
        self._drain_queue(self.model_input_left_queue)
        self._drain_queue(self.model_input_right_queue)
        self.left_thread = threading.Thread(
            target=self.thread_two_batch_overlap,
            args=(self.model_input_left_queue, ))
        self.right_thread = threading.Thread(
            target=self.thread_two_batch_overlap,
            args=(self.model_input_right_queue, ))
        self.left_thread.start()
        self.right_thread.start()

    def finish_thread(self):
        """Send the shutdown sentinel to each started worker and join it."""
        if self.left_thread is not None:
            self.model_input_left_queue.put(None)
            self.left_thread.join()
            self.left_thread = None
        if self.right_thread is not None:
            self.model_input_right_queue.put(None)
            self.right_thread.join()
            self.right_thread = None
        logger.info('tbo:finish threads')

    @torch.inference_mode()
    def thread_two_batch_overlap(self, queue):
        """Worker loop: run the model on every input taken from *queue*.

        Args:
            queue: this worker's input queue.  The worker is the left one
                when it is ``self.model_input_left_queue`` and the right one
                otherwise.  (The parameter shadows the ``queue`` module
                inside this method.)
        """
        is_left_thread = queue is self.model_input_left_queue
        if is_left_thread:
            self.left_tid = threading.get_ident()
            logger.info('tbo:new thread %d', self.left_tid)
            init_tbo_forward_context(True, self.left_tid)
        else:
            self.right_tid = threading.get_ident()
            logger.info('tbo:new thread %d', self.right_tid)
            init_tbo_forward_context(False, self.right_tid)
        while True:
            model_input = queue.get()
            if model_input is None:
                # Shutdown sentinel put by finish_thread.
                break
            profile.ProfRangePush('start')
            self.tbo_thread_synchronize(False)
            # The attributes read below are assigned by set_model_input
            # before it puts anything on the input queues.
            with set_forward_context(model_input.attn_metadata,
                                     self.vllm_config, self.virtual_engine):
                hidden_or_intermediate_states = self.model_executable(
                    input_ids=model_input.input_tokens,
                    positions=model_input.input_positions,
                    intermediate_tensors=self.intermediate_tensors,
                    **MultiModalKwargs.as_kwargs(self.multi_modal_kwargs,
                                                 device=self.self_device),
                    **self.seqlen_agnostic_kwargs,
                    **self.model_kwargs,
                )
            # NOTE(review): this pushes a range that is never popped in this
            # loop -- confirm against the profiler API whether a pop was meant.
            profile.ProfRangePush('end')
            if is_left_thread:
                # Let the right worker run to completion.
                self.sem_right.release()
                self.states_left_queue.put(hidden_or_intermediate_states)
            else:
                # Tell all_reduce() that no further requests will arrive.
                self.all_reduce_queue.put(None)
                self.states_right_queue.put(hidden_or_intermediate_states)

    def tbo_thread_synchronize(self, recode_flag=True):
        """Hand execution to the other worker and block until it hands back.

        Args:
            recode_flag: when true (and ``tbo_one_stream`` is off) the
                caller's ``c2t`` event is recorded before handing over.

        Returns:
            The ``(event_c2t, event_t2c)`` pair of the calling worker.  Any
            thread whose ident is not ``self.left_tid`` is treated as the
            right worker.
        """
        tid = threading.get_ident()
        if tid == self.left_tid:
            if recode_flag and not tbo_one_stream:
                logger.debug('###left_c2t_recorded')
                self.event_left_c2t.record()
            # The first left synchronisation of a run must not wake the right
            # worker, otherwise both would run at the same time.
            if not self.left_first:
                self.sem_right.release()
            profile.ProfRangePop()
            self.sem_left.acquire()
            profile.ProfRangePush('left')
            self.left_first = False
            return self.event_left_c2t, self.event_left_t2c
        if recode_flag and not tbo_one_stream:
            logger.debug('###right_c2t_recorded')
            self.event_right_c2t.record()
        self.sem_left.release()
        profile.ProfRangePop()
        self.sem_right.acquire()
        profile.ProfRangePush('right')
        return self.event_right_c2t, self.event_right_t2c

    def set_model_input(self, model_input_left, model_input_right,
                        vllm_config, virtual_engine, model_executable,
                        intermediate_tensors, multi_modal_kwargs, self_device,
                        seqlen_agnostic_kwargs, model_kwargs):
        """Store the shared call arguments and queue one half per worker.

        The worker threads are started on first use.
        """
        if self.left_thread is None:
            self.init_tbo_thread()
        self.vllm_config = vllm_config
        self.virtual_engine = virtual_engine
        self.model_executable = model_executable
        self.intermediate_tensors = intermediate_tensors
        self.multi_modal_kwargs = multi_modal_kwargs
        self.self_device = self_device
        self.seqlen_agnostic_kwargs = seqlen_agnostic_kwargs
        self.model_kwargs = model_kwargs
        # Queue the inputs last: the workers read the attributes above as
        # soon as they receive an input.
        self.model_input_left_queue.put(model_input_left)
        self.model_input_right_queue.put(model_input_right)

    def get_model_output(self):
        """Block until both workers have produced a result.

        Returns:
            ``(states_left, states_right)``.
        """
        states_left = self.states_left_queue.get()
        states_right = self.states_right_queue.get()
        return states_left, states_right

    def all_reduce(self):
        """Serve all-reduce requests until the ``None`` sentinel arrives.

        Each request is ``[tensor, event_c2t, event_t2c]``.  The reduced
        tensor is put on ``all_reduce_out``.
        """
        while True:
            obj = self.all_reduce_queue.get()
            if obj is None:
                break
            buf, event_c2t, event_t2c = obj
            if tbo_one_stream:
                output = tensor_model_parallel_all_reduce(buf)
            else:
                with torch.cuda.stream(self.stream):
                    logger.debug(
                        '###stream.wait_event event_c2t before all_reduce')
                    # Wait for the work the requester recorded in c2t.
                    self.stream.wait_event(event_c2t)
                    output = tensor_model_parallel_all_reduce(buf)
                    logger.debug('###event_t2c recorded')
                    # The requester waits on this event before using output.
                    event_t2c.record()
            self.all_reduce_out.put(output)
# Module-level TwoBatchOverlap instance: created by init_two_batch_overlap()
# and reset to None by finish_two_batch_overlap().
tbo_obj = None
def init_two_batch_overlap():
    """Create the module-level ``TwoBatchOverlap`` instance on first use.

    Does nothing when TBO is disabled or when the instance already exists.
    """
    global tbo_obj
    if enable_tbo and tbo_obj is None:
        tbo_obj = TwoBatchOverlap()
def finish_two_batch_overlap():
    """Stop the worker threads of the module-level instance and drop it.

    Safe to call when no instance exists; in that case nothing happens.
    """
    global tbo_obj
    if tbo_obj is not None:
        tbo_obj.finish_thread()
        tbo_obj = None
def tbo_all_reduce(obj):
    """All-reduce *obj*, going through the TBO hand-off while a run is active.

    While ``tbo_obj.tbo_running`` is set, the caller first yields to the
    other worker, then queues the tensor for ``TwoBatchOverlap.all_reduce``
    and blocks until the reduced tensor comes back.  In every other case
    this is a plain ``tensor_model_parallel_all_reduce``.

    Args:
        obj: the tensor to reduce.

    Returns:
        The reduced tensor.
    """
    # Snapshot the global once so the checks and the uses see the same object.
    overlap = tbo_obj
    if not (enable_tbo and overlap is not None and overlap.tbo_running):
        return tensor_model_parallel_all_reduce(obj)

    event_c2t, event_t2c = overlap.tbo_thread_synchronize()
    overlap.all_reduce_queue.put([obj, event_c2t, event_t2c])
    output = overlap.all_reduce_out.get()
    if not tbo_one_stream:
        # The collective ran on another stream; make the current stream wait
        # for the event recorded after it.
        current_stream = torch.cuda.current_stream()
        logger.debug('###current_stream wait event event_t2c')
        current_stream.wait_event(event_t2c)
    return output
def cumsum(lst):
    """Return the running totals of *lst*, starting with a leading 0.

    The result has ``len(lst) + 1`` entries; entry ``i`` is the sum of the
    first ``i`` items, so ``cumsum([1, 2, 3]) == [0, 1, 3, 6]``.
    """
    # The accumulator is read from the list tail, so the builtin ``sum`` is
    # no longer shadowed by a local variable.
    totals = [0]
    for value in lst:
        totals.append(totals[-1] + value)
    return totals
def split_model_input(model_input, self_device, batch_size_left,
                      batch_size_right):
    """Split one batched model input into a left and a right model input.

    The first ``batch_size_left`` sequences form the left half and the
    remaining ones the right half.  Tensors indexed by token are cut by the
    summed query lengths of each half; tensors indexed by sequence are cut by
    the two batch sizes.

    Args:
        model_input: the batched input to split.
        self_device: device for the small helper tensors created here.
        batch_size_left: number of sequences in the left half.
        batch_size_right: number of sequences in the right half.

    Returns:
        ``(model_input_left, model_input_right)``.
    """
    # Imported inside the function, exactly as the original code does.
    from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata

    attn = model_input.attn_metadata
    sampling = model_input.sampling_metadata
    is_prompt = model_input.is_prompt

    batch_slices = (slice(0, batch_size_left), slice(batch_size_left, None))
    batch_sizes = [batch_size_left, batch_size_right]
    token_counts = [sum(model_input.query_lens[part]) for part in batch_slices]

    token_parts = torch.split(model_input.input_tokens, token_counts, dim=0)
    position_parts = torch.split(model_input.input_positions,
                                 token_counts,
                                 dim=0)
    seq_lens_tensor_parts = torch.split(attn.seq_lens_tensor,
                                        batch_sizes,
                                        dim=0)
    block_table_parts = torch.split(attn.block_tables, batch_sizes, dim=0)
    slot_mapping_parts = torch.split(attn.slot_mapping, token_counts, dim=0)
    context_lens_parts = torch.split(attn.context_lens_tensor,
                                     batch_sizes,
                                     dim=0)
    zero_tensor = torch.tensor([0], device=self_device, dtype=torch.int32)

    # The mapping is cut by iteration order: the first batch_size_left
    # entries go left, the rest go right.
    request_id_parts = ({}, {})
    for index, (request_id, seq_ids) in enumerate(
            model_input.request_ids_to_seq_ids.items()):
        side = 0 if index < batch_size_left else 1
        request_id_parts[side][request_id] = seq_ids

    if sampling.seq_groups is None:
        seq_group_parts = (None, None)
    else:
        seq_group_parts = tuple(sampling.seq_groups[part]
                                for part in batch_slices)

    halves = []
    for side, part in enumerate(batch_slices):
        seq_lens = attn.seq_lens[part]
        query_lens = model_input.query_lens[part]

        # A batch is treated as all-prefill or all-decode, based on is_prompt.
        num_prefills = num_prefill_tokens = num_decode_tokens = 0
        max_prefill_seq_len = max_decode_seq_len = 0
        if is_prompt:
            num_prefills = batch_sizes[side]
            num_prefill_tokens = token_counts[side]
            max_prefill_seq_len = max(seq_lens)
        else:
            num_decode_tokens = batch_sizes[side]
            max_decode_seq_len = max(seq_lens)

        query_start_loc = async_tensor_h2d(cumsum(query_lens), torch.int32,
                                           self_device, True)
        seq_lens_running = seq_lens_tensor_parts[side].cumsum(dim=0)
        seq_start_loc = torch.cat((zero_tensor, seq_lens_running),
                                  dim=0).to(torch.int32)
        # NOTE(review): derived from sequence lengths, not query lengths --
        # confirm this is intended when the two differ.
        selected_token_indices = seq_lens_tensor_parts[side].cumsum(dim=0) - 1

        attn_half = ROCmFlashAttentionMetadata(
            seq_lens_tensor=seq_lens_tensor_parts[side],
            max_decode_seq_len=max_decode_seq_len,
            block_tables=block_table_parts[side],
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            slot_mapping=slot_mapping_parts[side],
            multi_modal_placeholder_index_maps={},
            enable_kv_scales_calculation=attn.enable_kv_scales_calculation,
            seq_lens=seq_lens,
            max_prefill_seq_len=max_prefill_seq_len,
            use_cuda_graph=attn.use_cuda_graph,
            max_query_len=max(query_lens),
            query_start_loc=query_start_loc,
            seq_start_loc=seq_start_loc,
            context_lens_tensor=context_lens_parts[side],
            max_decode_query_len=None,
            _cached_prefill_metadata=None,
            _cached_decode_metadata=None,
            tree_attention_masks_tensor=None,
            block_tables_list=attn.block_tables_list[part],
            # Encoder / cross-attention fields are always left unset here.
            encoder_seq_lens=None,
            encoder_seq_lens_tensor=None,
            max_encoder_seq_len=None,
            num_encoder_tokens=None,
            cross_slot_mapping=None,
            cross_block_tables=None,
        )
        sampling_half = SamplingMetadata(
            seq_groups=seq_group_parts[side],
            selected_token_indices=selected_token_indices,
            categorized_sample_indices=sampling.categorized_sample_indices,
            num_prompts=num_prefills,
            skip_sampler_cpu_output=sampling.skip_sampler_cpu_output,
            reuse_sampling_tensors=sampling.reuse_sampling_tensors,
        )
        halves.append(
            ModelInputForGPUWithSamplingMetadata(
                input_tokens=token_parts[side],
                input_positions=position_parts[side],
                token_types=None,
                seq_lens=seq_lens,
                query_lens=query_lens,
                # The fields below are shared by both halves, not split.
                lora_mapping=model_input.lora_mapping,
                lora_requests=model_input.lora_requests,
                attn_metadata=attn_half,
                prompt_adapter_mapping=model_input.prompt_adapter_mapping,
                prompt_adapter_requests=model_input.prompt_adapter_requests,
                multi_modal_kwargs=model_input.multi_modal_kwargs,
                request_ids_to_seq_ids=request_id_parts[side],
                finished_requests_ids=model_input.finished_requests_ids,
                virtual_engine=model_input.virtual_engine,
                async_callback=model_input.async_callback,
                scheduler_outputs=model_input.scheduler_outputs,
                previous_hidden_states=model_input.previous_hidden_states,
                sampling_metadata=sampling_half,
                is_prompt=is_prompt,
            ))

    model_input_left, model_input_right = halves
    return model_input_left, model_input_right
def merge_model_output(states_left, states_right):
    """Concatenate the two halves' outputs along dimension 0, left first."""
    return torch.cat((states_left, states_right), dim=0)
def tbo_model_executable(
    model_input,
    vllm_config,
    virtual_engine,
    model_executable,
    intermediate_tensors,
    multi_modal_kwargs,
    self_device,
    seqlen_agnostic_kwargs,
    model_kwargs,
):
    """Run the model forward pass, as two overlapped halves when possible.

    The batch is executed as a single ordinary forward pass when any of the
    following holds:

    * it has fewer than two sequences,
    * it is a decode batch and ``VLLM_TBO_DECODE`` is not enabled,
    * its attention metadata is not ``ROCmFlashAttentionMetadata``,
    * it is a decode batch that uses a CUDA graph.

    Otherwise it is split in two (the right half takes the extra sequence of
    an odd batch), both halves are run by the worker threads and their
    outputs are concatenated.

    Returns:
        The hidden or intermediate states for the whole batch.
    """
    init_two_batch_overlap()
    is_rocm_fa = isinstance(model_input.attn_metadata,
                            ROCmFlashAttentionMetadata)
    is_cuda_graph_decode = (model_input.attn_metadata.use_cuda_graph
                            and not model_input.is_prompt)
    batch_size = len(model_input.attn_metadata.seq_lens)
    # ``< 2`` rather than ``== 1``: an empty batch cannot be split either
    # (the split would call max() on an empty list).
    if batch_size < 2 or \
        (not model_input.is_prompt and not enable_tbo_decode) or \
        not is_rocm_fa or \
        is_cuda_graph_decode:
        with set_forward_context(model_input.attn_metadata, vllm_config,
                                 virtual_engine):
            hidden_or_intermediate_states = model_executable(
                input_ids=model_input.input_tokens,
                positions=model_input.input_positions,
                intermediate_tensors=intermediate_tensors,
                **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
                                             device=self_device),
                **seqlen_agnostic_kwargs,
                **model_kwargs,
            )
        return hidden_or_intermediate_states

    # Keep one reference for the whole run so the flag reset in ``finally``
    # hits the same object that was marked as running.
    overlap = tbo_obj
    overlap.tbo_running = True
    try:
        overlap.left_first = True
        batch_size_left = batch_size // 2
        batch_size_right = batch_size - batch_size_left
        model_input_left, model_input_right = split_model_input(
            model_input, self_device, batch_size_left, batch_size_right)
        overlap.set_model_input(model_input_left, model_input_right,
                                vllm_config, virtual_engine, model_executable,
                                intermediate_tensors, multi_modal_kwargs,
                                self_device, seqlen_agnostic_kwargs,
                                model_kwargs)
        # Serve the workers' all-reduce requests on this thread until the
        # right worker signals that it is done.
        overlap.all_reduce()
        states_left, states_right = overlap.get_model_output()
        return merge_model_output(states_left, states_right)
    finally:
        # Always clear the flag: if it stayed set after an error, later
        # ordinary all-reduces would enter the thread hand-off and block.
        overlap.tbo_running = False
vllm/worker/model_runner.py
View file @
2a935929
...
...
@@ -50,6 +50,7 @@ from vllm.prompt_adapter.worker_manager import (
LRUCacheWorkerPromptAdapterManager
)
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
IntermediateTensors
,
SequenceGroupMetadata
from
vllm.two_batch_overlap.two_batch_overlap
import
is_enable_tbo
,
tbo_model_executable
from
vllm.utils
import
(
DeviceMemoryProfiler
,
GiB_bytes
,
PyObjectCache
,
async_tensor_h2d
,
flatten_2d_lists
,
is_pin_memory_available
,
supports_dynamo
,
...
...
@@ -158,6 +159,7 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
tensor_dict
=
{
"input_tokens"
:
self
.
input_tokens
,
"input_positions"
:
self
.
input_positions
,
"query_lens"
:
self
.
query_lens
,
"lora_requests"
:
self
.
lora_requests
,
"lora_mapping"
:
self
.
lora_mapping
,
"multi_modal_kwargs"
:
self
.
multi_modal_kwargs
,
...
...
@@ -166,6 +168,7 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
"virtual_engine"
:
self
.
virtual_engine
,
"request_ids_to_seq_ids"
:
self
.
request_ids_to_seq_ids
,
"finished_requests_ids"
:
self
.
finished_requests_ids
,
"is_prompt"
:
self
.
is_prompt
,
}
_add_attn_metadata_broadcastable_dict
(
tensor_dict
,
self
.
attn_metadata
)
_add_sampling_metadata_broadcastable_dict
(
tensor_dict
,
...
...
@@ -1776,17 +1779,29 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
model_forward_start
.
record
()
if
not
bypass_model_exec
:
with
set_forward_context
(
model_input
.
attn_metadata
,
self
.
vllm_config
,
virtual_engine
):
hidden_or_intermediate_states
=
model_executable
(
input_ids
=
model_input
.
input_tokens
,
positions
=
model_input
.
input_positions
,
intermediate_tensors
=
intermediate_tensors
,
**
MultiModalKwargs
.
as_kwargs
(
multi_modal_kwargs
,
device
=
self
.
device
),
**
seqlen_agnostic_kwargs
,
**
model_kwargs
,
)
if
is_enable_tbo
():
hidden_or_intermediate_states
=
tbo_model_executable
(
model_input
,
self
.
vllm_config
,
virtual_engine
,
model_executable
,
intermediate_tensors
,
multi_modal_kwargs
,
self
.
device
,
seqlen_agnostic_kwargs
,
model_kwargs
)
else
:
with
set_forward_context
(
model_input
.
attn_metadata
,
self
.
vllm_config
,
virtual_engine
):
hidden_or_intermediate_states
=
model_executable
(
input_ids
=
model_input
.
input_tokens
,
positions
=
model_input
.
input_positions
,
intermediate_tensors
=
intermediate_tensors
,
**
MultiModalKwargs
.
as_kwargs
(
multi_modal_kwargs
,
device
=
self
.
device
),
**
seqlen_agnostic_kwargs
,
**
model_kwargs
,
)
if
(
self
.
observability_config
is
not
None
and
self
.
observability_config
.
collect_model_forward_time
):
...
...
vllm/worker/worker_base.py
View file @
2a935929
...
...
@@ -18,6 +18,7 @@ from vllm.logger import init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
,
IntermediateTensors
from
vllm.two_batch_overlap.two_batch_overlap
import
finish_two_batch_overlap
from
vllm.utils
import
(
enable_trace_function_call_for_thread
,
resolve_obj_by_qualname
,
run_method
,
update_environment_variables
,
...
...
@@ -77,7 +78,6 @@ class WorkerBase:
from
vllm.platforms
import
current_platform
self
.
current_platform
=
current_platform
def
init_device
(
self
)
->
None
:
"""Initialize device state, such as loading the model or other on-device
memory allocations.
...
...
@@ -113,6 +113,7 @@ class WorkerBase:
while
True
:
output
=
self
.
execute_model
(
execute_model_req
=
None
)
if
output
is
None
:
finish_two_batch_overlap
()
return
None
def
determine_num_available_blocks
(
self
)
->
Tuple
[
int
,
int
]:
...
...
vllm/zero_overhead/llm_engine.py
View file @
2a935929
...
...
@@ -40,7 +40,7 @@ from vllm.zero_overhead.tokenizer import ZeroOverheadDetokenizer
from
vllm.usage.usage_lib
import
(
UsageContext
,
is_usage_stats_enabled
,
usage_message
)
from
vllm.profiler.prof
import
profile
from
vllm.zero_overhead.utils
import
SpecStepKind
,
get_accepted_token_ids
,
get_spec_step
,
is_zero_no_thread
,
set_spec_step
from
vllm.zero_overhead.utils
import
SpecStepKind
,
get_accepted_token_ids
,
get_spec_step
,
is_zero_no_thread
,
set_spec_step
,
zero_overhead_stream
logger
=
init_logger
(
__name__
)
...
...
@@ -87,6 +87,7 @@ class ZeroOverheadEngine(LLMEngine):
self
.
log_stats
=
log_stats
self
.
use_cached_outputs
=
use_cached_outputs
self
.
thread_running
=
False
if
not
self
.
model_config
.
skip_tokenizer_init
:
self
.
tokenizer
=
self
.
_init_tokenizer
()
...
...
@@ -254,8 +255,8 @@ class ZeroOverheadEngine(LLMEngine):
self
.
async_d2h
=
None
self
.
last_record
=
None
self
.
async_event
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
self
.
thread_running
=
False
self
.
q_recorder
=
queue
.
Queue
()
self
.
use_stream
=
zero_overhead_stream
(
self
.
model_executor
.
device_config
.
device
)
if
not
is_zero_no_thread
():
self
.
zero_thread
=
threading
.
Thread
(
target
=
self
.
thread_zero_overhead
)
self
.
thread_running
=
True
...
...
@@ -271,73 +272,78 @@ class ZeroOverheadEngine(LLMEngine):
if
self
.
thread_running
:
self
.
thread_running
=
False
self
.
sem_m2s
.
release
()
def
thread_zero_overhead
(
self
):
logger
.
info
(
'zero overhead thread start!'
)
last_sampler
=
get_last_sampler
()
last_sampler
.
seq_ids
.
clear
()
try
:
while
True
:
self
.
sem_m2s
.
acquire
()
if
not
self
.
thread_running
:
logger
.
debug
(
"Stopping remote worker execution loop."
)
self
.
model_executor
.
stop_remote_worker_execution_loop
()
break
virtual_engine
=
0
# Clear outputs for each new scheduler iteration
# Schedule iteration
(
seq_group_metadata_list
,
scheduler_outputs
,
allow_async_output_proc
)
=
self
.
scheduler
[
virtual_engine
].
schedule
()
if
self
.
last_record
is
not
None
:
last_sampler
=
self
.
last_record
[
1
]
with
torch
.
cuda
.
stream
(
self
.
use_stream
):
while
True
:
self
.
sem_m2s
.
acquire
()
if
not
self
.
thread_running
:
logger
.
debug
(
"Stopping remote worker execution loop."
)
self
.
model_executor
.
stop_remote_worker_execution_loop
()
break
virtual_engine
=
0
# Clear outputs for each new scheduler iteration
# Schedule iteration
(
seq_group_metadata_list
,
scheduler_outputs
,
allow_async_output_proc
)
=
self
.
scheduler
[
virtual_engine
].
schedule
()
if
self
.
last_record
is
not
None
:
last_sampler
=
self
.
last_record
[
1
]
spec_step
=
get_spec_step
()
if
spec_step
==
SpecStepKind
.
KIND_DEFAULT
:
self
.
async_d2h
=
last_sampler
.
sampled_token_ids_tensor
.
to
(
'cpu'
,
non_blocking
=
True
)
elif
spec_step
==
SpecStepKind
.
SCORE_DECODE
:
self
.
async_d2h
=
last_sampler
.
to
(
'cpu'
,
non_blocking
=
True
)
self
.
async_event
.
record
()
self
.
q_recorder
.
put
(
self
.
last_record
)
else
:
self
.
q_recorder
.
put
(
None
)
if
len
(
seq_group_metadata_list
)
==
0
:
self
.
last_record
=
None
continue
finished_requests_ids
=
self
.
scheduler
[
virtual_engine
].
get_and_reset_finished_requests_ids
()
assert
seq_group_metadata_list
is
not
None
assert
scheduler_outputs
is
not
None
last_sampled_token_ids
=
\
self
.
_get_last_sampled_token_ids
(
virtual_engine
)
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
blocks_to_swap_in
=
scheduler_outputs
.
blocks_to_swap_in
,
blocks_to_swap_out
=
scheduler_outputs
.
blocks_to_swap_out
,
blocks_to_copy
=
scheduler_outputs
.
blocks_to_copy
,
num_lookahead_slots
=
scheduler_outputs
.
num_lookahead_slots
,
running_queue_size
=
scheduler_outputs
.
running_queue_size
,
finished_requests_ids
=
finished_requests_ids
,
# We use ExecuteModelRequest to pass the last sampled_token_ids
# to each of the non-last PP stages for in-place prepare_input.
last_sampled_token_ids
=
last_sampled_token_ids
)
outputs
=
self
.
model_executor
.
execute_model
(
execute_model_req
=
execute_model_req
)
for
output
in
outputs
:
self
.
_advance_to_next_step
(
output
,
seq_group_metadata_list
,
scheduler_outputs
.
scheduled_seq_groups
)
scheduler_outputs
.
scheduled_seq_groups
=
[
item
for
item
in
scheduler_outputs
.
scheduled_seq_groups
]
#deep copy
last_sampler
=
None
spec_step
=
get_spec_step
()
if
spec_step
==
SpecStepKind
.
KIND_DEFAULT
:
self
.
async_d2h
=
last_sampler
.
sampled_token_ids_tensor
.
to
(
'cpu'
,
non_blocking
=
True
)
last_sampler
=
get_last_sampler
(
)
elif
spec_step
==
SpecStepKind
.
SCORE_DECODE
:
self
.
async_d2h
=
last_sampler
.
to
(
'cpu'
,
non_blocking
=
True
)
self
.
async_event
.
record
()
self
.
q_recorder
.
put
(
self
.
last_record
)
else
:
self
.
q_recorder
.
put
(
None
)
if
len
(
seq_group_metadata_list
)
==
0
:
self
.
last_record
=
None
continue
finished_requests_ids
=
self
.
scheduler
[
virtual_engine
].
get_and_reset_finished_requests_ids
()
assert
seq_group_metadata_list
is
not
None
assert
scheduler_outputs
is
not
None
last_sampled_token_ids
=
\
self
.
_get_last_sampled_token_ids
(
virtual_engine
)
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
blocks_to_swap_in
=
scheduler_outputs
.
blocks_to_swap_in
,
blocks_to_swap_out
=
scheduler_outputs
.
blocks_to_swap_out
,
blocks_to_copy
=
scheduler_outputs
.
blocks_to_copy
,
num_lookahead_slots
=
scheduler_outputs
.
num_lookahead_slots
,
running_queue_size
=
scheduler_outputs
.
running_queue_size
,
finished_requests_ids
=
finished_requests_ids
,
# We use ExecuteModelRequest to pass the last sampled_token_ids
# to each of the non-last PP stages for in-place prepare_input.
last_sampled_token_ids
=
last_sampled_token_ids
)
outputs
=
self
.
model_executor
.
execute_model
(
execute_model_req
=
execute_model_req
)
for
output
in
outputs
:
self
.
_advance_to_next_step
(
output
,
seq_group_metadata_list
,
scheduler_outputs
.
scheduled_seq_groups
)
scheduler_outputs
.
scheduled_seq_groups
=
[
item
for
item
in
scheduler_outputs
.
scheduled_seq_groups
]
#deep copy
last_sampler
=
None
spec_step
=
get_spec_step
()
if
spec_step
==
SpecStepKind
.
KIND_DEFAULT
:
last_sampler
=
get_last_sampler
()
elif
spec_step
==
SpecStepKind
.
SCORE_DECODE
:
last_sampler
,
_
=
get_accepted_token_ids
()
self
.
last_record
=
[
outputs
,
last_sampler
,
seq_group_metadata_list
,
scheduler_outputs
,
spec_step
]
last_sampler
,
_
=
get_accepted_token_ids
()
self
.
last_record
=
[
outputs
,
last_sampler
,
seq_group_metadata_list
,
scheduler_outputs
,
spec_step
]
except
Exception
as
e
:
print
(
f
"thread_zero_overhead error :
{
e
}
"
)
...
...
@@ -560,14 +566,15 @@ class ZeroOverheadEngine(LLMEngine):
return
ctx
.
request_outputs
def
step
(
self
)
->
List
[
Union
[
RequestOutput
,
PoolingRequestOutput
]]:
if
is_zero_no_thread
():
out
=
self
.
no_thread_step
()
if
out
is
None
:
#the first step need launch twice
with
torch
.
cuda
.
stream
(
self
.
use_stream
):
if
is_zero_no_thread
():
out
=
self
.
no_thread_step
()
else
:
out
=
self
.
zero_overh
ead_step
()
if
out
is
None
:
#the first step need launch twice
if
out
is
None
:
#the first step need launch twice
out
=
self
.
no_thr
ead_step
()
else
:
out
=
self
.
zero_overhead_step
()
if
out
is
None
:
#the first step need launch twice
out
=
self
.
zero_overhead_step
()
return
out
def
_add_processed_request
(
...
...
vllm/zero_overhead/utils.py
View file @
2a935929
...
...
@@ -2,6 +2,7 @@
from
enum
import
Enum
import
os
import
torch
# Feature switches, evaluated once at import time. Each flag is True only
# when its environment variable is set to exactly the string '1'.
zero_overhead = os.getenv('VLLM_ZERO_OVERHEAD') == '1'
zero_no_thread = os.getenv('VLLM_ZERO_NO_THREAD') == '1'
...
...
@@ -62,3 +63,12 @@ def record_accepted_token_ids(tensor, seq_ids):
def get_accepted_token_ids():
    """Return ``(accepted_token_ids, accepted_seq_ids)`` as stored on ``spec_context``."""
    token_ids = spec_context.accepted_token_ids
    seq_ids = spec_context.accepted_seq_ids
    return token_ids, seq_ids
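# Zero-overhead scheduling does not run inference on the default stream, to avoid the stream-synchronization issue in memory allocation introduced by the runtime.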
# Cache of CUDA streams keyed by device, filled lazily by
# zero_overhead_stream().
alloc_stream = {}


def zero_overhead_stream(target_device):
    """Return the CUDA stream cached for ``target_device``, creating it on first use.

    Streams are memoized in the module-level ``alloc_stream`` dict, so
    repeated calls with the same device return the same stream object.

    Args:
        target_device: Device the stream belongs to. It is used as the cache
            key and passed as ``device=`` to ``torch.cuda.Stream``.

    Returns:
        The stream stored in ``alloc_stream`` for ``target_device``.
    """
    if target_device not in alloc_stream:
        alloc_stream[target_device] = torch.cuda.Stream(device=target_device)
    return alloc_stream[target_device]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment