Commit 2a935929 authored by lizhigong's avatar lizhigong
Browse files

修复zero-overhead首字正确性问题,zero-overhead不使用默认流调整,增加two-batch-overlap功能

parent cf1d8464
...@@ -62,6 +62,7 @@ from vllm.utils import (Counter, Device, deprecate_kwargs, ...@@ -62,6 +62,7 @@ from vllm.utils import (Counter, Device, deprecate_kwargs,
resolve_obj_by_qualname, weak_bind) resolve_obj_by_qualname, weak_bind)
from vllm.version import __version__ as VLLM_VERSION from vllm.version import __version__ as VLLM_VERSION
from vllm.worker.model_runner_base import InputProcessingError from vllm.worker.model_runner_base import InputProcessingError
from vllm.profiler.prof import profile
logger = init_logger(__name__) logger = init_logger(__name__)
_LOCAL_LOGGING_INTERVAL_SEC = 5 _LOCAL_LOGGING_INTERVAL_SEC = 5
...@@ -413,6 +414,7 @@ class LLMEngine: ...@@ -413,6 +414,7 @@ class LLMEngine:
# Flag to set when an input fails to process and the engine should run # Flag to set when an input fails to process and the engine should run
# the next step without re-scheduling. # the next step without re-scheduling.
self._skip_scheduling_next_step = False self._skip_scheduling_next_step = False
profile.StartTracer()
def _initialize_kv_caches(self) -> None: def _initialize_kv_caches(self) -> None:
"""Initialize the KV cache in the worker(s). """Initialize the KV cache in the worker(s).
......
...@@ -16,6 +16,7 @@ from vllm.lora.request import LoRARequest ...@@ -16,6 +16,7 @@ from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest, PoolerOutput from vllm.sequence import ExecuteModelRequest, PoolerOutput
from vllm.two_batch_overlap.two_batch_overlap import finish_two_batch_overlap
from vllm.utils import make_async from vllm.utils import make_async
from vllm.worker.worker_base import WorkerBase from vllm.worker.worker_base import WorkerBase
...@@ -143,6 +144,7 @@ class ExecutorBase(ABC): ...@@ -143,6 +144,7 @@ class ExecutorBase(ABC):
def stop_remote_worker_execution_loop(self) -> None: def stop_remote_worker_execution_loop(self) -> None:
"""Releases parallel workers from model loop.""" """Releases parallel workers from model loop."""
finish_two_batch_overlap()
return return
def add_lora(self, lora_request: LoRARequest) -> bool: def add_lora(self, lora_request: LoRARequest) -> bool:
...@@ -301,6 +303,7 @@ class DistributedExecutorBase(ExecutorBase): ...@@ -301,6 +303,7 @@ class DistributedExecutorBase(ExecutorBase):
return driver_outputs return driver_outputs
def stop_remote_worker_execution_loop(self) -> None: def stop_remote_worker_execution_loop(self) -> None:
finish_two_batch_overlap()
if self.parallel_worker_tasks is None: if self.parallel_worker_tasks is None:
return return
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import os
import time import time
from collections import defaultdict from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
...@@ -16,6 +17,7 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group, ...@@ -16,6 +17,7 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
is_v1_kv_transfer_group) is_v1_kv_transfer_group)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.two_batch_overlap.forward_context import get_tbo_forward_context, set_tbo_forward_context
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
...@@ -28,6 +30,9 @@ forward_start_time: float = 0 ...@@ -28,6 +30,9 @@ forward_start_time: float = 0
batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
batchsize_forward_time: defaultdict = defaultdict(list) batchsize_forward_time: defaultdict = defaultdict(list)
enable_tbo = os.environ.get('VLLM_ENABLE_TBO') == '1'
def is_enable_tbo():
return enable_tbo
@dataclass @dataclass
class DPMetadata: class DPMetadata:
...@@ -50,6 +55,14 @@ _forward_context: Optional[ForwardContext] = None ...@@ -50,6 +55,14 @@ _forward_context: Optional[ForwardContext] = None
def get_forward_context() -> ForwardContext: def get_forward_context() -> ForwardContext:
if is_enable_tbo():
forward_context = get_tbo_forward_context()
"""Get the current forward context."""
assert forward_context is not None, (
"Forward context is not set. "
"Please use `set_forward_context` to set the forward context.")
return forward_context
"""Get the current forward context.""" """Get the current forward context."""
assert _forward_context is not None, ( assert _forward_context is not None, (
"Forward context is not set. " "Forward context is not set. "
...@@ -112,7 +125,8 @@ def set_forward_context(attn_metadata: Any, ...@@ -112,7 +125,8 @@ def set_forward_context(attn_metadata: Any,
kv_connector = get_kv_transfer_group() kv_connector = get_kv_transfer_group()
assert isinstance(kv_connector, KVConnectorBase_V1) assert isinstance(kv_connector, KVConnectorBase_V1)
kv_connector.start_load_kv(_forward_context) kv_connector.start_load_kv(_forward_context)
if is_enable_tbo():
set_tbo_forward_context(_forward_context)
try: try:
yield yield
finally: finally:
...@@ -157,3 +171,5 @@ def set_forward_context(attn_metadata: Any, ...@@ -157,3 +171,5 @@ def set_forward_context(attn_metadata: Any,
kv_connector.wait_for_save() kv_connector.wait_for_save()
_forward_context = prev_context _forward_context = prev_context
if is_enable_tbo():
set_tbo_forward_context(_forward_context)
...@@ -1237,6 +1237,9 @@ class RowParallelLinear(LinearBase): ...@@ -1237,6 +1237,9 @@ class RowParallelLinear(LinearBase):
}) })
else: else:
self.register_parameter("bias", None) self.register_parameter("bias", None)
from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce, is_enable_tbo
self.tbo_all_reduce = tbo_all_reduce
self.enable_tbo = is_enable_tbo()
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
...@@ -1307,7 +1310,10 @@ class RowParallelLinear(LinearBase): ...@@ -1307,7 +1310,10 @@ class RowParallelLinear(LinearBase):
input_parallel, input_parallel,
bias=bias_) bias=bias_)
if self.reduce_results and self.tp_size > 1: if self.reduce_results and self.tp_size > 1:
output = tensor_model_parallel_all_reduce(output_parallel) if self.enable_tbo:
output = self.tbo_all_reduce(output_parallel)
else:
output = tensor_model_parallel_all_reduce(output_parallel)
else: else:
output = output_parallel output = output_parallel
......
...@@ -283,6 +283,9 @@ class VocabParallelEmbedding(torch.nn.Module): ...@@ -283,6 +283,9 @@ class VocabParallelEmbedding(torch.nn.Module):
self.num_embeddings_padded, self.num_embeddings_padded,
params_dtype=params_dtype, params_dtype=params_dtype,
weight_loader=self.weight_loader) weight_loader=self.weight_loader)
from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce, is_enable_tbo
self.tbo_all_reduce = tbo_all_reduce
self.enable_tbo = is_enable_tbo()
@classmethod @classmethod
def _get_indices(cls, vocab_size_padded: int, org_vocab_size_padded: int, def _get_indices(cls, vocab_size_padded: int, org_vocab_size_padded: int,
...@@ -434,7 +437,10 @@ class VocabParallelEmbedding(torch.nn.Module): ...@@ -434,7 +437,10 @@ class VocabParallelEmbedding(torch.nn.Module):
if self.tp_size > 1: if self.tp_size > 1:
output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0) output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
# Reduce across all the model parallel GPUs. # Reduce across all the model parallel GPUs.
output = tensor_model_parallel_all_reduce(output_parallel) if self.enable_tbo:
output = self.tbo_all_reduce(output_parallel)
else:
output = tensor_model_parallel_all_reduce(output_parallel)
return output return output
def extra_repr(self) -> str: def extra_repr(self) -> str:
......
import threading
_forward_context_left = None
_forward_context_right = None
_left_tid = 0
_right_tid = 0
def init_tbo_forward_context(left_flag, tid):
global _left_tid
global _right_tid
if left_flag:
_left_tid = tid
else:
_right_tid = tid
def set_tbo_forward_context(_forward_context):
global _forward_context_left
global _forward_context_right
tid = threading.get_ident()
if tid == _left_tid:
_forward_context_left = _forward_context
else:
_forward_context_right = _forward_context
def get_tbo_forward_context():
tid = threading.get_ident()
if tid == _left_tid:
return _forward_context_left
else:
return _forward_context_right
This diff is collapsed.
...@@ -50,6 +50,7 @@ from vllm.prompt_adapter.worker_manager import ( ...@@ -50,6 +50,7 @@ from vllm.prompt_adapter.worker_manager import (
LRUCacheWorkerPromptAdapterManager) LRUCacheWorkerPromptAdapterManager)
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.two_batch_overlap.two_batch_overlap import is_enable_tbo, tbo_model_executable
from vllm.utils import (DeviceMemoryProfiler, GiB_bytes, PyObjectCache, from vllm.utils import (DeviceMemoryProfiler, GiB_bytes, PyObjectCache,
async_tensor_h2d, flatten_2d_lists, async_tensor_h2d, flatten_2d_lists,
is_pin_memory_available, supports_dynamo, is_pin_memory_available, supports_dynamo,
...@@ -158,6 +159,7 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU): ...@@ -158,6 +159,7 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
tensor_dict = { tensor_dict = {
"input_tokens": self.input_tokens, "input_tokens": self.input_tokens,
"input_positions": self.input_positions, "input_positions": self.input_positions,
"query_lens": self.query_lens,
"lora_requests": self.lora_requests, "lora_requests": self.lora_requests,
"lora_mapping": self.lora_mapping, "lora_mapping": self.lora_mapping,
"multi_modal_kwargs": self.multi_modal_kwargs, "multi_modal_kwargs": self.multi_modal_kwargs,
...@@ -166,6 +168,7 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU): ...@@ -166,6 +168,7 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
"virtual_engine": self.virtual_engine, "virtual_engine": self.virtual_engine,
"request_ids_to_seq_ids": self.request_ids_to_seq_ids, "request_ids_to_seq_ids": self.request_ids_to_seq_ids,
"finished_requests_ids": self.finished_requests_ids, "finished_requests_ids": self.finished_requests_ids,
"is_prompt": self.is_prompt,
} }
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
_add_sampling_metadata_broadcastable_dict(tensor_dict, _add_sampling_metadata_broadcastable_dict(tensor_dict,
...@@ -1776,17 +1779,29 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1776,17 +1779,29 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
model_forward_start.record() model_forward_start.record()
if not bypass_model_exec: if not bypass_model_exec:
with set_forward_context(model_input.attn_metadata, if is_enable_tbo():
self.vllm_config, virtual_engine): hidden_or_intermediate_states = tbo_model_executable(
hidden_or_intermediate_states = model_executable( model_input,
input_ids=model_input.input_tokens, self.vllm_config,
positions=model_input.input_positions, virtual_engine,
intermediate_tensors=intermediate_tensors, model_executable,
**MultiModalKwargs.as_kwargs(multi_modal_kwargs, intermediate_tensors,
device=self.device), multi_modal_kwargs,
**seqlen_agnostic_kwargs, self.device,
**model_kwargs, seqlen_agnostic_kwargs,
) model_kwargs)
else:
with set_forward_context(model_input.attn_metadata,
self.vllm_config, virtual_engine):
hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
device=self.device),
**seqlen_agnostic_kwargs,
**model_kwargs,
)
if (self.observability_config is not None if (self.observability_config is not None
and self.observability_config.collect_model_forward_time): and self.observability_config.collect_model_forward_time):
......
...@@ -18,6 +18,7 @@ from vllm.logger import init_logger ...@@ -18,6 +18,7 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest, IntermediateTensors from vllm.sequence import ExecuteModelRequest, IntermediateTensors
from vllm.two_batch_overlap.two_batch_overlap import finish_two_batch_overlap
from vllm.utils import (enable_trace_function_call_for_thread, from vllm.utils import (enable_trace_function_call_for_thread,
resolve_obj_by_qualname, run_method, resolve_obj_by_qualname, run_method,
update_environment_variables, update_environment_variables,
...@@ -77,7 +78,6 @@ class WorkerBase: ...@@ -77,7 +78,6 @@ class WorkerBase:
from vllm.platforms import current_platform from vllm.platforms import current_platform
self.current_platform = current_platform self.current_platform = current_platform
def init_device(self) -> None: def init_device(self) -> None:
"""Initialize device state, such as loading the model or other on-device """Initialize device state, such as loading the model or other on-device
memory allocations. memory allocations.
...@@ -113,6 +113,7 @@ class WorkerBase: ...@@ -113,6 +113,7 @@ class WorkerBase:
while True: while True:
output = self.execute_model(execute_model_req=None) output = self.execute_model(execute_model_req=None)
if output is None: if output is None:
finish_two_batch_overlap()
return None return None
def determine_num_available_blocks(self) -> Tuple[int, int]: def determine_num_available_blocks(self) -> Tuple[int, int]:
......
...@@ -40,7 +40,7 @@ from vllm.zero_overhead.tokenizer import ZeroOverheadDetokenizer ...@@ -40,7 +40,7 @@ from vllm.zero_overhead.tokenizer import ZeroOverheadDetokenizer
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message) usage_message)
from vllm.profiler.prof import profile from vllm.profiler.prof import profile
from vllm.zero_overhead.utils import SpecStepKind, get_accepted_token_ids, get_spec_step, is_zero_no_thread, set_spec_step from vllm.zero_overhead.utils import SpecStepKind, get_accepted_token_ids, get_spec_step, is_zero_no_thread, set_spec_step, zero_overhead_stream
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -87,6 +87,7 @@ class ZeroOverheadEngine(LLMEngine): ...@@ -87,6 +87,7 @@ class ZeroOverheadEngine(LLMEngine):
self.log_stats = log_stats self.log_stats = log_stats
self.use_cached_outputs = use_cached_outputs self.use_cached_outputs = use_cached_outputs
self.thread_running = False
if not self.model_config.skip_tokenizer_init: if not self.model_config.skip_tokenizer_init:
self.tokenizer = self._init_tokenizer() self.tokenizer = self._init_tokenizer()
...@@ -254,8 +255,8 @@ class ZeroOverheadEngine(LLMEngine): ...@@ -254,8 +255,8 @@ class ZeroOverheadEngine(LLMEngine):
self.async_d2h = None self.async_d2h = None
self.last_record = None self.last_record = None
self.async_event = torch.cuda.Event(enable_timing=False) self.async_event = torch.cuda.Event(enable_timing=False)
self.thread_running = False
self.q_recorder = queue.Queue() self.q_recorder = queue.Queue()
self.use_stream = zero_overhead_stream(self.model_executor.device_config.device)
if not is_zero_no_thread(): if not is_zero_no_thread():
self.zero_thread = threading.Thread(target=self.thread_zero_overhead) self.zero_thread = threading.Thread(target=self.thread_zero_overhead)
self.thread_running = True self.thread_running = True
...@@ -271,73 +272,78 @@ class ZeroOverheadEngine(LLMEngine): ...@@ -271,73 +272,78 @@ class ZeroOverheadEngine(LLMEngine):
if self.thread_running: if self.thread_running:
self.thread_running = False self.thread_running = False
self.sem_m2s.release() self.sem_m2s.release()
def thread_zero_overhead(self): def thread_zero_overhead(self):
logger.info('zero overhead thread start!') logger.info('zero overhead thread start!')
last_sampler = get_last_sampler()
last_sampler.seq_ids.clear()
try: try:
while True: with torch.cuda.stream(self.use_stream):
self.sem_m2s.acquire() while True:
if not self.thread_running: self.sem_m2s.acquire()
logger.debug("Stopping remote worker execution loop.") if not self.thread_running:
self.model_executor.stop_remote_worker_execution_loop() logger.debug("Stopping remote worker execution loop.")
break self.model_executor.stop_remote_worker_execution_loop()
virtual_engine = 0 break
# Clear outputs for each new scheduler iteration
virtual_engine = 0
# Schedule iteration # Clear outputs for each new scheduler iteration
(seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc # Schedule iteration
) = self.scheduler[virtual_engine].schedule() (seq_group_metadata_list, scheduler_outputs,
if self.last_record is not None: allow_async_output_proc
last_sampler = self.last_record[1] ) = self.scheduler[virtual_engine].schedule()
if self.last_record is not None:
last_sampler = self.last_record[1]
spec_step = get_spec_step()
if spec_step == SpecStepKind.KIND_DEFAULT:
self.async_d2h = last_sampler.sampled_token_ids_tensor.to('cpu', non_blocking=True)
elif spec_step == SpecStepKind.SCORE_DECODE:
self.async_d2h = last_sampler.to('cpu', non_blocking=True)
self.async_event.record()
self.q_recorder.put(self.last_record)
else:
self.q_recorder.put(None)
if len(seq_group_metadata_list) == 0:
self.last_record = None
continue
finished_requests_ids = self.scheduler[
virtual_engine].get_and_reset_finished_requests_ids()
assert seq_group_metadata_list is not None
assert scheduler_outputs is not None
last_sampled_token_ids = \
self._get_last_sampled_token_ids(virtual_engine)
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
blocks_to_copy=scheduler_outputs.blocks_to_copy,
num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
running_queue_size=scheduler_outputs.running_queue_size,
finished_requests_ids=finished_requests_ids,
# We use ExecuteModelRequest to pass the last sampled_token_ids
# to each of the non-last PP stages for in-place prepare_input.
last_sampled_token_ids=last_sampled_token_ids)
outputs = self.model_executor.execute_model(
execute_model_req=execute_model_req)
for output in outputs:
self._advance_to_next_step(
output, seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)
scheduler_outputs.scheduled_seq_groups = [item for item in scheduler_outputs.scheduled_seq_groups] #deep copy
last_sampler = None
spec_step = get_spec_step() spec_step = get_spec_step()
if spec_step == SpecStepKind.KIND_DEFAULT: if spec_step == SpecStepKind.KIND_DEFAULT:
self.async_d2h = last_sampler.sampled_token_ids_tensor.to('cpu', non_blocking=True) last_sampler = get_last_sampler()
elif spec_step == SpecStepKind.SCORE_DECODE: elif spec_step == SpecStepKind.SCORE_DECODE:
self.async_d2h = last_sampler.to('cpu', non_blocking=True) last_sampler, _ = get_accepted_token_ids()
self.async_event.record() self.last_record = [outputs, last_sampler, seq_group_metadata_list, scheduler_outputs, spec_step]
self.q_recorder.put(self.last_record)
else:
self.q_recorder.put(None)
if len(seq_group_metadata_list) == 0:
self.last_record = None
continue
finished_requests_ids = self.scheduler[
virtual_engine].get_and_reset_finished_requests_ids()
assert seq_group_metadata_list is not None
assert scheduler_outputs is not None
last_sampled_token_ids = \
self._get_last_sampled_token_ids(virtual_engine)
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
blocks_to_copy=scheduler_outputs.blocks_to_copy,
num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
running_queue_size=scheduler_outputs.running_queue_size,
finished_requests_ids=finished_requests_ids,
# We use ExecuteModelRequest to pass the last sampled_token_ids
# to each of the non-last PP stages for in-place prepare_input.
last_sampled_token_ids=last_sampled_token_ids)
outputs = self.model_executor.execute_model(
execute_model_req=execute_model_req)
for output in outputs:
self._advance_to_next_step(
output, seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)
scheduler_outputs.scheduled_seq_groups = [item for item in scheduler_outputs.scheduled_seq_groups] #deep copy
last_sampler = None
spec_step = get_spec_step()
if spec_step == SpecStepKind.KIND_DEFAULT:
last_sampler = get_last_sampler()
elif spec_step == SpecStepKind.SCORE_DECODE:
last_sampler, _ = get_accepted_token_ids()
self.last_record = [outputs, last_sampler, seq_group_metadata_list, scheduler_outputs, spec_step]
except Exception as e: except Exception as e:
print(f"thread_zero_overhead error : {e}") print(f"thread_zero_overhead error : {e}")
...@@ -560,14 +566,15 @@ class ZeroOverheadEngine(LLMEngine): ...@@ -560,14 +566,15 @@ class ZeroOverheadEngine(LLMEngine):
return ctx.request_outputs return ctx.request_outputs
def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]: def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
if is_zero_no_thread(): with torch.cuda.stream(self.use_stream):
out = self.no_thread_step() if is_zero_no_thread():
if out is None: #the first step need launch twice
out = self.no_thread_step() out = self.no_thread_step()
else: if out is None: #the first step need launch twice
out = self.zero_overhead_step() out = self.no_thread_step()
if out is None: #the first step need launch twice else:
out = self.zero_overhead_step() out = self.zero_overhead_step()
if out is None: #the first step need launch twice
out = self.zero_overhead_step()
return out return out
def _add_processed_request( def _add_processed_request(
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
from enum import Enum from enum import Enum
import os import os
import torch
zero_overhead = os.environ.get('VLLM_ZERO_OVERHEAD') == '1' zero_overhead = os.environ.get('VLLM_ZERO_OVERHEAD') == '1'
zero_no_thread = os.environ.get('VLLM_ZERO_NO_THREAD') == '1' zero_no_thread = os.environ.get('VLLM_ZERO_NO_THREAD') == '1'
...@@ -62,3 +63,12 @@ def record_accepted_token_ids(tensor, seq_ids): ...@@ -62,3 +63,12 @@ def record_accepted_token_ids(tensor, seq_ids):
def get_accepted_token_ids(): def get_accepted_token_ids():
return spec_context.accepted_token_ids, spec_context.accepted_seq_ids return spec_context.accepted_token_ids, spec_context.accepted_seq_ids
# 零消耗调度不在默认流上推理,用以规避runtime引入的内存申请流同步问题。
alloc_stream = {}
def zero_overhead_stream(target_device):
"""Asynchronously create a tensor and copy it from host to device."""
if target_device not in alloc_stream.keys():
alloc_stream[target_device] = torch.cuda.Stream(device=target_device)
return alloc_stream[target_device]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment