Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2a935929
Commit
2a935929
authored
May 16, 2025
by
lizhigong
Browse files
修复zero-overhead首字正确性问题,zero-overhead不使用默认流调整,增加two-batch-overlap功能
parent
cf1d8464
Changes
11
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
649 additions
and
83 deletions
+649
-83
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+2
-0
vllm/executor/executor_base.py
vllm/executor/executor_base.py
+3
-0
vllm/forward_context.py
vllm/forward_context.py
+17
-1
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+7
-1
vllm/model_executor/layers/vocab_parallel_embedding.py
vllm/model_executor/layers/vocab_parallel_embedding.py
+7
-1
vllm/two_batch_overlap/forward_context.py
vllm/two_batch_overlap/forward_context.py
+35
-0
vllm/two_batch_overlap/two_batch_overlap.py
vllm/two_batch_overlap/two_batch_overlap.py
+465
-0
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+26
-11
vllm/worker/worker_base.py
vllm/worker/worker_base.py
+2
-1
vllm/zero_overhead/llm_engine.py
vllm/zero_overhead/llm_engine.py
+75
-68
vllm/zero_overhead/utils.py
vllm/zero_overhead/utils.py
+10
-0
No files found.
vllm/engine/llm_engine.py
View file @
2a935929
...
@@ -62,6 +62,7 @@ from vllm.utils import (Counter, Device, deprecate_kwargs,
...
@@ -62,6 +62,7 @@ from vllm.utils import (Counter, Device, deprecate_kwargs,
resolve_obj_by_qualname
,
weak_bind
)
resolve_obj_by_qualname
,
weak_bind
)
from
vllm.version
import
__version__
as
VLLM_VERSION
from
vllm.version
import
__version__
as
VLLM_VERSION
from
vllm.worker.model_runner_base
import
InputProcessingError
from
vllm.worker.model_runner_base
import
InputProcessingError
from
vllm.profiler.prof
import
profile
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
_LOCAL_LOGGING_INTERVAL_SEC
=
5
_LOCAL_LOGGING_INTERVAL_SEC
=
5
...
@@ -413,6 +414,7 @@ class LLMEngine:
...
@@ -413,6 +414,7 @@ class LLMEngine:
# Flag to set when an input fails to process and the engine should run
# Flag to set when an input fails to process and the engine should run
# the next step without re-scheduling.
# the next step without re-scheduling.
self
.
_skip_scheduling_next_step
=
False
self
.
_skip_scheduling_next_step
=
False
profile
.
StartTracer
()
def
_initialize_kv_caches
(
self
)
->
None
:
def
_initialize_kv_caches
(
self
)
->
None
:
"""Initialize the KV cache in the worker(s).
"""Initialize the KV cache in the worker(s).
...
...
vllm/executor/executor_base.py
View file @
2a935929
...
@@ -16,6 +16,7 @@ from vllm.lora.request import LoRARequest
...
@@ -16,6 +16,7 @@ from vllm.lora.request import LoRARequest
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sequence
import
ExecuteModelRequest
,
PoolerOutput
from
vllm.sequence
import
ExecuteModelRequest
,
PoolerOutput
from
vllm.two_batch_overlap.two_batch_overlap
import
finish_two_batch_overlap
from
vllm.utils
import
make_async
from
vllm.utils
import
make_async
from
vllm.worker.worker_base
import
WorkerBase
from
vllm.worker.worker_base
import
WorkerBase
...
@@ -143,6 +144,7 @@ class ExecutorBase(ABC):
...
@@ -143,6 +144,7 @@ class ExecutorBase(ABC):
def
stop_remote_worker_execution_loop
(
self
)
->
None
:
def
stop_remote_worker_execution_loop
(
self
)
->
None
:
"""Releases parallel workers from model loop."""
"""Releases parallel workers from model loop."""
finish_two_batch_overlap
()
return
return
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
...
@@ -301,6 +303,7 @@ class DistributedExecutorBase(ExecutorBase):
...
@@ -301,6 +303,7 @@ class DistributedExecutorBase(ExecutorBase):
return
driver_outputs
return
driver_outputs
def
stop_remote_worker_execution_loop
(
self
)
->
None
:
def
stop_remote_worker_execution_loop
(
self
)
->
None
:
finish_two_batch_overlap
()
if
self
.
parallel_worker_tasks
is
None
:
if
self
.
parallel_worker_tasks
is
None
:
return
return
...
...
vllm/forward_context.py
View file @
2a935929
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
os
import
time
import
time
from
collections
import
defaultdict
from
collections
import
defaultdict
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
...
@@ -16,6 +17,7 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
...
@@ -16,6 +17,7 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
is_v1_kv_transfer_group
)
is_v1_kv_transfer_group
)
from
vllm.distributed.kv_transfer.kv_connector.v1
import
KVConnectorBase_V1
from
vllm.distributed.kv_transfer.kv_connector.v1
import
KVConnectorBase_V1
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.two_batch_overlap.forward_context
import
get_tbo_forward_context
,
set_tbo_forward_context
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.backends.abstract
import
AttentionMetadata
...
@@ -28,6 +30,9 @@ forward_start_time: float = 0
...
@@ -28,6 +30,9 @@ forward_start_time: float = 0
batchsize_logging_interval
:
float
=
envs
.
VLLM_LOG_BATCHSIZE_INTERVAL
batchsize_logging_interval
:
float
=
envs
.
VLLM_LOG_BATCHSIZE_INTERVAL
batchsize_forward_time
:
defaultdict
=
defaultdict
(
list
)
batchsize_forward_time
:
defaultdict
=
defaultdict
(
list
)
# Two-batch overlap (TBO) is opt-in. The flag is sampled once, when this module
# is imported, and is true only if VLLM_ENABLE_TBO is exactly the string '1'.
enable_tbo = os.getenv('VLLM_ENABLE_TBO') == '1'


def is_enable_tbo():
    """Report whether two-batch overlap was enabled through VLLM_ENABLE_TBO."""
    return enable_tbo
@
dataclass
@
dataclass
class
DPMetadata
:
class
DPMetadata
:
...
@@ -50,6 +55,14 @@ _forward_context: Optional[ForwardContext] = None
...
@@ -50,6 +55,14 @@ _forward_context: Optional[ForwardContext] = None
def
get_forward_context
()
->
ForwardContext
:
def
get_forward_context
()
->
ForwardContext
:
if
is_enable_tbo
():
forward_context
=
get_tbo_forward_context
()
"""Get the current forward context."""
assert
forward_context
is
not
None
,
(
"Forward context is not set. "
"Please use `set_forward_context` to set the forward context."
)
return
forward_context
"""Get the current forward context."""
"""Get the current forward context."""
assert
_forward_context
is
not
None
,
(
assert
_forward_context
is
not
None
,
(
"Forward context is not set. "
"Forward context is not set. "
...
@@ -112,7 +125,8 @@ def set_forward_context(attn_metadata: Any,
...
@@ -112,7 +125,8 @@ def set_forward_context(attn_metadata: Any,
kv_connector
=
get_kv_transfer_group
()
kv_connector
=
get_kv_transfer_group
()
assert
isinstance
(
kv_connector
,
KVConnectorBase_V1
)
assert
isinstance
(
kv_connector
,
KVConnectorBase_V1
)
kv_connector
.
start_load_kv
(
_forward_context
)
kv_connector
.
start_load_kv
(
_forward_context
)
if
is_enable_tbo
():
set_tbo_forward_context
(
_forward_context
)
try
:
try
:
yield
yield
finally
:
finally
:
...
@@ -157,3 +171,5 @@ def set_forward_context(attn_metadata: Any,
...
@@ -157,3 +171,5 @@ def set_forward_context(attn_metadata: Any,
kv_connector
.
wait_for_save
()
kv_connector
.
wait_for_save
()
_forward_context
=
prev_context
_forward_context
=
prev_context
if
is_enable_tbo
():
set_tbo_forward_context
(
_forward_context
)
vllm/model_executor/layers/linear.py
View file @
2a935929
...
@@ -1237,6 +1237,9 @@ class RowParallelLinear(LinearBase):
...
@@ -1237,6 +1237,9 @@ class RowParallelLinear(LinearBase):
})
})
else
:
else
:
self
.
register_parameter
(
"bias"
,
None
)
self
.
register_parameter
(
"bias"
,
None
)
from
vllm.two_batch_overlap.two_batch_overlap
import
tbo_all_reduce
,
is_enable_tbo
self
.
tbo_all_reduce
=
tbo_all_reduce
self
.
enable_tbo
=
is_enable_tbo
()
def
weight_loader
(
self
,
param
:
Parameter
,
loaded_weight
:
torch
.
Tensor
):
def
weight_loader
(
self
,
param
:
Parameter
,
loaded_weight
:
torch
.
Tensor
):
tp_rank
=
get_tensor_model_parallel_rank
()
tp_rank
=
get_tensor_model_parallel_rank
()
...
@@ -1307,7 +1310,10 @@ class RowParallelLinear(LinearBase):
...
@@ -1307,7 +1310,10 @@ class RowParallelLinear(LinearBase):
input_parallel
,
input_parallel
,
bias
=
bias_
)
bias
=
bias_
)
if
self
.
reduce_results
and
self
.
tp_size
>
1
:
if
self
.
reduce_results
and
self
.
tp_size
>
1
:
output
=
tensor_model_parallel_all_reduce
(
output_parallel
)
if
self
.
enable_tbo
:
output
=
self
.
tbo_all_reduce
(
output_parallel
)
else
:
output
=
tensor_model_parallel_all_reduce
(
output_parallel
)
else
:
else
:
output
=
output_parallel
output
=
output_parallel
...
...
vllm/model_executor/layers/vocab_parallel_embedding.py
View file @
2a935929
...
@@ -283,6 +283,9 @@ class VocabParallelEmbedding(torch.nn.Module):
...
@@ -283,6 +283,9 @@ class VocabParallelEmbedding(torch.nn.Module):
self
.
num_embeddings_padded
,
self
.
num_embeddings_padded
,
params_dtype
=
params_dtype
,
params_dtype
=
params_dtype
,
weight_loader
=
self
.
weight_loader
)
weight_loader
=
self
.
weight_loader
)
from
vllm.two_batch_overlap.two_batch_overlap
import
tbo_all_reduce
,
is_enable_tbo
self
.
tbo_all_reduce
=
tbo_all_reduce
self
.
enable_tbo
=
is_enable_tbo
()
@
classmethod
@
classmethod
def
_get_indices
(
cls
,
vocab_size_padded
:
int
,
org_vocab_size_padded
:
int
,
def
_get_indices
(
cls
,
vocab_size_padded
:
int
,
org_vocab_size_padded
:
int
,
...
@@ -434,7 +437,10 @@ class VocabParallelEmbedding(torch.nn.Module):
...
@@ -434,7 +437,10 @@ class VocabParallelEmbedding(torch.nn.Module):
if
self
.
tp_size
>
1
:
if
self
.
tp_size
>
1
:
output_parallel
.
masked_fill_
(
input_mask
.
unsqueeze
(
-
1
),
0
)
output_parallel
.
masked_fill_
(
input_mask
.
unsqueeze
(
-
1
),
0
)
# Reduce across all the model parallel GPUs.
# Reduce across all the model parallel GPUs.
output
=
tensor_model_parallel_all_reduce
(
output_parallel
)
if
self
.
enable_tbo
:
output
=
self
.
tbo_all_reduce
(
output_parallel
)
else
:
output
=
tensor_model_parallel_all_reduce
(
output_parallel
)
return
output
return
output
def
extra_repr
(
self
)
->
str
:
def
extra_repr
(
self
)
->
str
:
...
...
vllm/two_batch_overlap/forward_context.py
0 → 100644
View file @
2a935929
import
threading
# Forward-context slots for two-batch overlap. The thread registered as "left"
# owns the first slot; every other thread reads and writes the second one.
_forward_context_left = None
_forward_context_right = None

# Thread identifiers in the form returned by threading.get_ident(). Both stay 0
# until init_tbo_forward_context() registers a thread. _right_tid is recorded
# but is not consulted by the lookups in this module.
_left_tid = 0
_right_tid = 0


def init_tbo_forward_context(left_flag, tid):
    """Register a thread identifier for one side of the overlap.

    Args:
        left_flag: truthy to register ``tid`` as the left thread, falsy to
            register it as the right thread.
        tid: thread identifier to store.
    """
    global _left_tid, _right_tid
    if not left_flag:
        _right_tid = tid
        return
    _left_tid = tid


def set_tbo_forward_context(_forward_context):
    """Store ``_forward_context`` in the slot belonging to the calling thread.

    The left slot is used only when the caller's thread identifier equals the
    registered left identifier; any other caller writes the right slot.
    """
    global _forward_context_left, _forward_context_right
    if threading.get_ident() != _left_tid:
        _forward_context_right = _forward_context
        return
    _forward_context_left = _forward_context


def get_tbo_forward_context():
    """Return the forward context stored for the calling thread's slot.

    Mirrors set_tbo_forward_context(): the left slot for the registered left
    thread, the right slot for every other thread. May return None when
    nothing has been stored yet.
    """
    caller_is_left = threading.get_ident() == _left_tid
    return _forward_context_left if caller_is_left else _forward_context_right
vllm/two_batch_overlap/two_batch_overlap.py
0 → 100644
View file @
2a935929
This diff is collapsed.
Click to expand it.
vllm/worker/model_runner.py
View file @
2a935929
...
@@ -50,6 +50,7 @@ from vllm.prompt_adapter.worker_manager import (
...
@@ -50,6 +50,7 @@ from vllm.prompt_adapter.worker_manager import (
LRUCacheWorkerPromptAdapterManager
)
LRUCacheWorkerPromptAdapterManager
)
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
IntermediateTensors
,
SequenceGroupMetadata
from
vllm.sequence
import
IntermediateTensors
,
SequenceGroupMetadata
from
vllm.two_batch_overlap.two_batch_overlap
import
is_enable_tbo
,
tbo_model_executable
from
vllm.utils
import
(
DeviceMemoryProfiler
,
GiB_bytes
,
PyObjectCache
,
from
vllm.utils
import
(
DeviceMemoryProfiler
,
GiB_bytes
,
PyObjectCache
,
async_tensor_h2d
,
flatten_2d_lists
,
async_tensor_h2d
,
flatten_2d_lists
,
is_pin_memory_available
,
supports_dynamo
,
is_pin_memory_available
,
supports_dynamo
,
...
@@ -158,6 +159,7 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
...
@@ -158,6 +159,7 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
tensor_dict
=
{
tensor_dict
=
{
"input_tokens"
:
self
.
input_tokens
,
"input_tokens"
:
self
.
input_tokens
,
"input_positions"
:
self
.
input_positions
,
"input_positions"
:
self
.
input_positions
,
"query_lens"
:
self
.
query_lens
,
"lora_requests"
:
self
.
lora_requests
,
"lora_requests"
:
self
.
lora_requests
,
"lora_mapping"
:
self
.
lora_mapping
,
"lora_mapping"
:
self
.
lora_mapping
,
"multi_modal_kwargs"
:
self
.
multi_modal_kwargs
,
"multi_modal_kwargs"
:
self
.
multi_modal_kwargs
,
...
@@ -166,6 +168,7 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
...
@@ -166,6 +168,7 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
"virtual_engine"
:
self
.
virtual_engine
,
"virtual_engine"
:
self
.
virtual_engine
,
"request_ids_to_seq_ids"
:
self
.
request_ids_to_seq_ids
,
"request_ids_to_seq_ids"
:
self
.
request_ids_to_seq_ids
,
"finished_requests_ids"
:
self
.
finished_requests_ids
,
"finished_requests_ids"
:
self
.
finished_requests_ids
,
"is_prompt"
:
self
.
is_prompt
,
}
}
_add_attn_metadata_broadcastable_dict
(
tensor_dict
,
self
.
attn_metadata
)
_add_attn_metadata_broadcastable_dict
(
tensor_dict
,
self
.
attn_metadata
)
_add_sampling_metadata_broadcastable_dict
(
tensor_dict
,
_add_sampling_metadata_broadcastable_dict
(
tensor_dict
,
...
@@ -1776,17 +1779,29 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
...
@@ -1776,17 +1779,29 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
model_forward_start
.
record
()
model_forward_start
.
record
()
if
not
bypass_model_exec
:
if
not
bypass_model_exec
:
with
set_forward_context
(
model_input
.
attn_metadata
,
if
is_enable_tbo
():
self
.
vllm_config
,
virtual_engine
):
hidden_or_intermediate_states
=
tbo_model_executable
(
hidden_or_intermediate_states
=
model_executable
(
model_input
,
input_ids
=
model_input
.
input_tokens
,
self
.
vllm_config
,
positions
=
model_input
.
input_positions
,
virtual_engine
,
intermediate_tensors
=
intermediate_tensors
,
model_executable
,
**
MultiModalKwargs
.
as_kwargs
(
multi_modal_kwargs
,
intermediate_tensors
,
device
=
self
.
device
),
multi_modal_kwargs
,
**
seqlen_agnostic_kwargs
,
self
.
device
,
**
model_kwargs
,
seqlen_agnostic_kwargs
,
)
model_kwargs
)
else
:
with
set_forward_context
(
model_input
.
attn_metadata
,
self
.
vllm_config
,
virtual_engine
):
hidden_or_intermediate_states
=
model_executable
(
input_ids
=
model_input
.
input_tokens
,
positions
=
model_input
.
input_positions
,
intermediate_tensors
=
intermediate_tensors
,
**
MultiModalKwargs
.
as_kwargs
(
multi_modal_kwargs
,
device
=
self
.
device
),
**
seqlen_agnostic_kwargs
,
**
model_kwargs
,
)
if
(
self
.
observability_config
is
not
None
if
(
self
.
observability_config
is
not
None
and
self
.
observability_config
.
collect_model_forward_time
):
and
self
.
observability_config
.
collect_model_forward_time
):
...
...
vllm/worker/worker_base.py
View file @
2a935929
...
@@ -18,6 +18,7 @@ from vllm.logger import init_logger
...
@@ -18,6 +18,7 @@ from vllm.logger import init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
,
IntermediateTensors
from
vllm.sequence
import
ExecuteModelRequest
,
IntermediateTensors
from
vllm.two_batch_overlap.two_batch_overlap
import
finish_two_batch_overlap
from
vllm.utils
import
(
enable_trace_function_call_for_thread
,
from
vllm.utils
import
(
enable_trace_function_call_for_thread
,
resolve_obj_by_qualname
,
run_method
,
resolve_obj_by_qualname
,
run_method
,
update_environment_variables
,
update_environment_variables
,
...
@@ -77,7 +78,6 @@ class WorkerBase:
...
@@ -77,7 +78,6 @@ class WorkerBase:
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
self
.
current_platform
=
current_platform
self
.
current_platform
=
current_platform
def
init_device
(
self
)
->
None
:
def
init_device
(
self
)
->
None
:
"""Initialize device state, such as loading the model or other on-device
"""Initialize device state, such as loading the model or other on-device
memory allocations.
memory allocations.
...
@@ -113,6 +113,7 @@ class WorkerBase:
...
@@ -113,6 +113,7 @@ class WorkerBase:
while
True
:
while
True
:
output
=
self
.
execute_model
(
execute_model_req
=
None
)
output
=
self
.
execute_model
(
execute_model_req
=
None
)
if
output
is
None
:
if
output
is
None
:
finish_two_batch_overlap
()
return
None
return
None
def
determine_num_available_blocks
(
self
)
->
Tuple
[
int
,
int
]:
def
determine_num_available_blocks
(
self
)
->
Tuple
[
int
,
int
]:
...
...
vllm/zero_overhead/llm_engine.py
View file @
2a935929
...
@@ -40,7 +40,7 @@ from vllm.zero_overhead.tokenizer import ZeroOverheadDetokenizer
...
@@ -40,7 +40,7 @@ from vllm.zero_overhead.tokenizer import ZeroOverheadDetokenizer
from
vllm.usage.usage_lib
import
(
UsageContext
,
is_usage_stats_enabled
,
from
vllm.usage.usage_lib
import
(
UsageContext
,
is_usage_stats_enabled
,
usage_message
)
usage_message
)
from
vllm.profiler.prof
import
profile
from
vllm.profiler.prof
import
profile
from
vllm.zero_overhead.utils
import
SpecStepKind
,
get_accepted_token_ids
,
get_spec_step
,
is_zero_no_thread
,
set_spec_step
from
vllm.zero_overhead.utils
import
SpecStepKind
,
get_accepted_token_ids
,
get_spec_step
,
is_zero_no_thread
,
set_spec_step
,
zero_overhead_stream
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -87,6 +87,7 @@ class ZeroOverheadEngine(LLMEngine):
...
@@ -87,6 +87,7 @@ class ZeroOverheadEngine(LLMEngine):
self
.
log_stats
=
log_stats
self
.
log_stats
=
log_stats
self
.
use_cached_outputs
=
use_cached_outputs
self
.
use_cached_outputs
=
use_cached_outputs
self
.
thread_running
=
False
if
not
self
.
model_config
.
skip_tokenizer_init
:
if
not
self
.
model_config
.
skip_tokenizer_init
:
self
.
tokenizer
=
self
.
_init_tokenizer
()
self
.
tokenizer
=
self
.
_init_tokenizer
()
...
@@ -254,8 +255,8 @@ class ZeroOverheadEngine(LLMEngine):
...
@@ -254,8 +255,8 @@ class ZeroOverheadEngine(LLMEngine):
self
.
async_d2h
=
None
self
.
async_d2h
=
None
self
.
last_record
=
None
self
.
last_record
=
None
self
.
async_event
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
self
.
async_event
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
self
.
thread_running
=
False
self
.
q_recorder
=
queue
.
Queue
()
self
.
q_recorder
=
queue
.
Queue
()
self
.
use_stream
=
zero_overhead_stream
(
self
.
model_executor
.
device_config
.
device
)
if
not
is_zero_no_thread
():
if
not
is_zero_no_thread
():
self
.
zero_thread
=
threading
.
Thread
(
target
=
self
.
thread_zero_overhead
)
self
.
zero_thread
=
threading
.
Thread
(
target
=
self
.
thread_zero_overhead
)
self
.
thread_running
=
True
self
.
thread_running
=
True
...
@@ -271,73 +272,78 @@ class ZeroOverheadEngine(LLMEngine):
...
@@ -271,73 +272,78 @@ class ZeroOverheadEngine(LLMEngine):
if
self
.
thread_running
:
if
self
.
thread_running
:
self
.
thread_running
=
False
self
.
thread_running
=
False
self
.
sem_m2s
.
release
()
self
.
sem_m2s
.
release
()
def
thread_zero_overhead
(
self
):
def
thread_zero_overhead
(
self
):
logger
.
info
(
'zero overhead thread start!'
)
logger
.
info
(
'zero overhead thread start!'
)
last_sampler
=
get_last_sampler
()
last_sampler
.
seq_ids
.
clear
()
try
:
try
:
while
True
:
with
torch
.
cuda
.
stream
(
self
.
use_stream
):
self
.
sem_m2s
.
acquire
()
while
True
:
if
not
self
.
thread_running
:
self
.
sem_m2s
.
acquire
()
logger
.
debug
(
"Stopping remote worker execution loop."
)
if
not
self
.
thread_running
:
self
.
model_executor
.
stop_remote_worker_execution_loop
()
logger
.
debug
(
"Stopping remote worker execution loop."
)
break
self
.
model_executor
.
stop_remote_worker_execution_loop
()
virtual_engine
=
0
break
# Clear outputs for each new scheduler iteration
virtual_engine
=
0
# Schedule iteration
# Clear outputs for each new scheduler iteration
(
seq_group_metadata_list
,
scheduler_outputs
,
allow_async_output_proc
# Schedule iteration
)
=
self
.
scheduler
[
virtual_engine
].
schedule
()
(
seq_group_metadata_list
,
scheduler_outputs
,
if
self
.
last_record
is
not
None
:
allow_async_output_proc
last_sampler
=
self
.
last_record
[
1
]
)
=
self
.
scheduler
[
virtual_engine
].
schedule
()
if
self
.
last_record
is
not
None
:
last_sampler
=
self
.
last_record
[
1
]
spec_step
=
get_spec_step
()
if
spec_step
==
SpecStepKind
.
KIND_DEFAULT
:
self
.
async_d2h
=
last_sampler
.
sampled_token_ids_tensor
.
to
(
'cpu'
,
non_blocking
=
True
)
elif
spec_step
==
SpecStepKind
.
SCORE_DECODE
:
self
.
async_d2h
=
last_sampler
.
to
(
'cpu'
,
non_blocking
=
True
)
self
.
async_event
.
record
()
self
.
q_recorder
.
put
(
self
.
last_record
)
else
:
self
.
q_recorder
.
put
(
None
)
if
len
(
seq_group_metadata_list
)
==
0
:
self
.
last_record
=
None
continue
finished_requests_ids
=
self
.
scheduler
[
virtual_engine
].
get_and_reset_finished_requests_ids
()
assert
seq_group_metadata_list
is
not
None
assert
scheduler_outputs
is
not
None
last_sampled_token_ids
=
\
self
.
_get_last_sampled_token_ids
(
virtual_engine
)
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
blocks_to_swap_in
=
scheduler_outputs
.
blocks_to_swap_in
,
blocks_to_swap_out
=
scheduler_outputs
.
blocks_to_swap_out
,
blocks_to_copy
=
scheduler_outputs
.
blocks_to_copy
,
num_lookahead_slots
=
scheduler_outputs
.
num_lookahead_slots
,
running_queue_size
=
scheduler_outputs
.
running_queue_size
,
finished_requests_ids
=
finished_requests_ids
,
# We use ExecuteModelRequest to pass the last sampled_token_ids
# to each of the non-last PP stages for in-place prepare_input.
last_sampled_token_ids
=
last_sampled_token_ids
)
outputs
=
self
.
model_executor
.
execute_model
(
execute_model_req
=
execute_model_req
)
for
output
in
outputs
:
self
.
_advance_to_next_step
(
output
,
seq_group_metadata_list
,
scheduler_outputs
.
scheduled_seq_groups
)
scheduler_outputs
.
scheduled_seq_groups
=
[
item
for
item
in
scheduler_outputs
.
scheduled_seq_groups
]
#deep copy
last_sampler
=
None
spec_step
=
get_spec_step
()
spec_step
=
get_spec_step
()
if
spec_step
==
SpecStepKind
.
KIND_DEFAULT
:
if
spec_step
==
SpecStepKind
.
KIND_DEFAULT
:
self
.
async_d2h
=
last_sampler
.
sampled_token_ids_tensor
.
to
(
'cpu'
,
non_blocking
=
True
)
last_sampler
=
get_last_sampler
(
)
elif
spec_step
==
SpecStepKind
.
SCORE_DECODE
:
elif
spec_step
==
SpecStepKind
.
SCORE_DECODE
:
self
.
async_d2h
=
last_sampler
.
to
(
'cpu'
,
non_blocking
=
True
)
last_sampler
,
_
=
get_accepted_token_ids
()
self
.
async_event
.
record
()
self
.
last_record
=
[
outputs
,
last_sampler
,
seq_group_metadata_list
,
scheduler_outputs
,
spec_step
]
self
.
q_recorder
.
put
(
self
.
last_record
)
else
:
self
.
q_recorder
.
put
(
None
)
if
len
(
seq_group_metadata_list
)
==
0
:
self
.
last_record
=
None
continue
finished_requests_ids
=
self
.
scheduler
[
virtual_engine
].
get_and_reset_finished_requests_ids
()
assert
seq_group_metadata_list
is
not
None
assert
scheduler_outputs
is
not
None
last_sampled_token_ids
=
\
self
.
_get_last_sampled_token_ids
(
virtual_engine
)
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
blocks_to_swap_in
=
scheduler_outputs
.
blocks_to_swap_in
,
blocks_to_swap_out
=
scheduler_outputs
.
blocks_to_swap_out
,
blocks_to_copy
=
scheduler_outputs
.
blocks_to_copy
,
num_lookahead_slots
=
scheduler_outputs
.
num_lookahead_slots
,
running_queue_size
=
scheduler_outputs
.
running_queue_size
,
finished_requests_ids
=
finished_requests_ids
,
# We use ExecuteModelRequest to pass the last sampled_token_ids
# to each of the non-last PP stages for in-place prepare_input.
last_sampled_token_ids
=
last_sampled_token_ids
)
outputs
=
self
.
model_executor
.
execute_model
(
execute_model_req
=
execute_model_req
)
for
output
in
outputs
:
self
.
_advance_to_next_step
(
output
,
seq_group_metadata_list
,
scheduler_outputs
.
scheduled_seq_groups
)
scheduler_outputs
.
scheduled_seq_groups
=
[
item
for
item
in
scheduler_outputs
.
scheduled_seq_groups
]
#deep copy
last_sampler
=
None
spec_step
=
get_spec_step
()
if
spec_step
==
SpecStepKind
.
KIND_DEFAULT
:
last_sampler
=
get_last_sampler
()
elif
spec_step
==
SpecStepKind
.
SCORE_DECODE
:
last_sampler
,
_
=
get_accepted_token_ids
()
self
.
last_record
=
[
outputs
,
last_sampler
,
seq_group_metadata_list
,
scheduler_outputs
,
spec_step
]
except
Exception
as
e
:
except
Exception
as
e
:
print
(
f
"thread_zero_overhead error :
{
e
}
"
)
print
(
f
"thread_zero_overhead error :
{
e
}
"
)
...
@@ -560,14 +566,15 @@ class ZeroOverheadEngine(LLMEngine):
...
@@ -560,14 +566,15 @@ class ZeroOverheadEngine(LLMEngine):
return
ctx
.
request_outputs
return
ctx
.
request_outputs
def
step
(
self
)
->
List
[
Union
[
RequestOutput
,
PoolingRequestOutput
]]:
def
step
(
self
)
->
List
[
Union
[
RequestOutput
,
PoolingRequestOutput
]]:
if
is_zero_no_thread
():
with
torch
.
cuda
.
stream
(
self
.
use_stream
):
out
=
self
.
no_thread_step
()
if
is_zero_no_thread
():
if
out
is
None
:
#the first step need launch twice
out
=
self
.
no_thread_step
()
out
=
self
.
no_thread_step
()
else
:
if
out
is
None
:
#the first step need launch twice
out
=
self
.
zero_overh
ead_step
()
out
=
self
.
no_thr
ead_step
()
if
out
is
None
:
#the first step need launch twice
else
:
out
=
self
.
zero_overhead_step
()
out
=
self
.
zero_overhead_step
()
if
out
is
None
:
#the first step need launch twice
out
=
self
.
zero_overhead_step
()
return
out
return
out
def
_add_processed_request
(
def
_add_processed_request
(
...
...
vllm/zero_overhead/utils.py
View file @
2a935929
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
from
enum
import
Enum
from
enum
import
Enum
import
os
import
os
import
torch
zero_overhead
=
os
.
environ
.
get
(
'VLLM_ZERO_OVERHEAD'
)
==
'1'
zero_overhead
=
os
.
environ
.
get
(
'VLLM_ZERO_OVERHEAD'
)
==
'1'
zero_no_thread
=
os
.
environ
.
get
(
'VLLM_ZERO_NO_THREAD'
)
==
'1'
zero_no_thread
=
os
.
environ
.
get
(
'VLLM_ZERO_NO_THREAD'
)
==
'1'
...
@@ -62,3 +63,12 @@ def record_accepted_token_ids(tensor, seq_ids):
...
@@ -62,3 +63,12 @@ def record_accepted_token_ids(tensor, seq_ids):
def
get_accepted_token_ids
():
def
get_accepted_token_ids
():
return
spec_context
.
accepted_token_ids
,
spec_context
.
accepted_seq_ids
return
spec_context
.
accepted_token_ids
,
spec_context
.
accepted_seq_ids
# Zero-overhead scheduling does not run inference on the default stream, in
# order to avoid the stream synchronization on memory allocation that the
# runtime introduces.
# Cache of the dedicated stream created for each device, keyed by the
# ``target_device`` value exactly as passed in.
alloc_stream = {}


def zero_overhead_stream(target_device):
    """Return the dedicated CUDA stream for ``target_device``.

    The stream is created on first request and cached in ``alloc_stream``, so
    repeated calls with the same ``target_device`` return the same object.

    Args:
        target_device: device handed to ``torch.cuda.Stream(device=...)``; it
            is also used as the cache key, so it must be hashable.

    Returns:
        The cached ``torch.cuda.Stream`` for that device.
    """
    # NOTE(review): the key is the argument itself, so a device string and an
    # equivalent torch.device would presumably get separate streams -- confirm
    # callers always pass the same representation.
    if target_device not in alloc_stream:
        alloc_stream[target_device] = torch.cuda.Stream(device=target_device)
    return alloc_stream[target_device]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment