Commit 9bd32639 authored by lizhigong's avatar lizhigong
Browse files

zero overhead engine update

parent 6b7651af
...@@ -61,6 +61,7 @@ from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, ...@@ -61,6 +61,7 @@ from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message) usage_message)
from vllm.utils import Counter, Device, deprecate_kwargs, weak_bind from vllm.utils import Counter, Device, deprecate_kwargs, weak_bind
from vllm.version import __version__ as VLLM_VERSION from vllm.version import __version__ as VLLM_VERSION
from vllm.profiler.prof import profile
logger = init_logger(__name__) logger = init_logger(__name__)
_LOCAL_LOGGING_INTERVAL_SEC = 5 _LOCAL_LOGGING_INTERVAL_SEC = 5
...@@ -408,6 +409,11 @@ class LLMEngine: ...@@ -408,6 +409,11 @@ class LLMEngine:
self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {} self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {}
self.zero_overhead = os.environ.get('VLLM_ZERO_OVERHEAD') == '1'
self.step_switch = 0 # 0 step A 1 step B
self.output_recorder = [None, None]
profile.StartTracer()
def _initialize_kv_caches(self) -> None: def _initialize_kv_caches(self) -> None:
"""Initialize the KV cache in the worker(s). """Initialize the KV cache in the worker(s).
...@@ -1271,6 +1277,9 @@ class LLMEngine: ...@@ -1271,6 +1277,9 @@ class LLMEngine:
else: else:
seq.append_token_id(sample.output_token, sample.logprobs) seq.append_token_id(sample.output_token, sample.logprobs)
def trans_last_output_tensor(self, last_output) -> torch.Tensor:
    """Convert a recorded step output into the value sent as ``last_outputs``.

    Placeholder: the body ignores ``last_output`` and always returns
    ``None``. ``step()`` calls this with
    ``self.output_recorder[self.step_switch]`` when ``self.zero_overhead``
    is set and forwards the result as ``ExecuteModelRequest.last_outputs``.

    NOTE(review): the return annotation says ``torch.Tensor`` but the only
    value ever returned is ``None``; presumably the real conversion is still
    to be written -- confirm, and widen the annotation until it is.
    """
    return None
def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]: def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
"""Performs one decoding iteration and returns newly generated results. """Performs one decoding iteration and returns newly generated results.
...@@ -1346,6 +1355,7 @@ class LLMEngine: ...@@ -1346,6 +1355,7 @@ class LLMEngine:
# Skip the scheduler if there are any remaining steps in the seq groups. # Skip the scheduler if there are any remaining steps in the seq groups.
# This ensures that the scheduler is only called again when the current # This ensures that the scheduler is only called again when the current
# batch has completed. # batch has completed.
profile.ProfRangeAutoPush('has_remain')
if not self._has_remaining_steps(seq_group_metadata_list): if not self._has_remaining_steps(seq_group_metadata_list):
# Schedule iteration # Schedule iteration
(seq_group_metadata_list, scheduler_outputs, (seq_group_metadata_list, scheduler_outputs,
...@@ -1375,6 +1385,10 @@ class LLMEngine: ...@@ -1375,6 +1385,10 @@ class LLMEngine:
assert seq_group_metadata_list is not None assert seq_group_metadata_list is not None
assert scheduler_outputs is not None assert scheduler_outputs is not None
profile.ProfRangeAutoPush('execute_model')
last_outputs = None
if self.zero_overhead:
last_outputs = self.trans_last_output_tensor(self.output_recorder[self.step_switch])
if not scheduler_outputs.is_empty(): if not scheduler_outputs.is_empty():
# Check if we have a cached last_output from the previous iteration. # Check if we have a cached last_output from the previous iteration.
...@@ -1384,6 +1398,14 @@ class LLMEngine: ...@@ -1384,6 +1398,14 @@ class LLMEngine:
last_sampled_token_ids = \ last_sampled_token_ids = \
self._get_last_sampled_token_ids(virtual_engine) self._get_last_sampled_token_ids(virtual_engine)
# print('seq_group_metadata_list', len(seq_group_metadata_list))
# print('scheduler_outputs.blocks_to_swap_in', len(scheduler_outputs.blocks_to_swap_in))
# print('scheduler_outputs.num_lookahead_slots', scheduler_outputs.num_lookahead_slots)
# print('scheduler_outputs.running_queue_size', scheduler_outputs.running_queue_size)
# print('finished_requests_ids', len(finished_requests_ids))
# print('last_sampled_token_ids', last_sampled_token_ids)
# print('self.model_executor', type(self.model_executor))
execute_model_req = ExecuteModelRequest( execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list, seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
...@@ -1394,15 +1416,15 @@ class LLMEngine: ...@@ -1394,15 +1416,15 @@ class LLMEngine:
finished_requests_ids=finished_requests_ids, finished_requests_ids=finished_requests_ids,
# We use ExecuteModelRequest to pass the last sampled_token_ids # We use ExecuteModelRequest to pass the last sampled_token_ids
# to each of the non-last PP stages for in-place prepare_input. # to each of the non-last PP stages for in-place prepare_input.
last_sampled_token_ids=last_sampled_token_ids) last_sampled_token_ids=last_sampled_token_ids,
last_outputs = last_outputs)
if allow_async_output_proc: if allow_async_output_proc:
execute_model_req.async_callback = self.async_callbacks[ execute_model_req.async_callback = self.async_callbacks[
virtual_engine] virtual_engine]
outputs = self.model_executor.execute_model( outputs = self.model_executor.execute_model(
execute_model_req=execute_model_req) execute_model_req=execute_model_req)
print('###outputs', outputs)
# We need to do this here so that last step's sampled_token_ids can # We need to do this here so that last step's sampled_token_ids can
# be passed to the next iteration for PP. # be passed to the next iteration for PP.
if self.scheduler_config.is_multi_step: if self.scheduler_config.is_multi_step:
...@@ -1420,6 +1442,15 @@ class LLMEngine: ...@@ -1420,6 +1442,15 @@ class LLMEngine:
for seq_group in seq_group_metadata_list: for seq_group in seq_group_metadata_list:
seq_group.finish_step() seq_group.finish_step()
if self.zero_overhead:
self.output_recorder[self.step_switch] = outputs
self.step_switch = 1 - self.step_switch
outputs = self.output_recorder[self.step_switch]
if outputs is None:
return None
# Synchronize the previous step's output
if not self._has_remaining_steps(seq_group_metadata_list): if not self._has_remaining_steps(seq_group_metadata_list):
# clear the cache if we have finished all the steps. # clear the cache if we have finished all the steps.
if self.scheduler_config.is_multi_step: if self.scheduler_config.is_multi_step:
...@@ -1460,6 +1491,7 @@ class LLMEngine: ...@@ -1460,6 +1491,7 @@ class LLMEngine:
# Multi-step case # Multi-step case
return ctx.request_outputs return ctx.request_outputs
profile.ProfRangeAutoPush('has_unfinish')
if not self.has_unfinished_requests(): if not self.has_unfinished_requests():
# Drain async postprocessor (if exists) # Drain async postprocessor (if exists)
if len(ctx.output_queue) > 0: if len(ctx.output_queue) > 0:
......
...@@ -1388,6 +1388,8 @@ class LLM: ...@@ -1388,6 +1388,8 @@ class LLM:
total_out_toks = 0 total_out_toks = 0
while self.llm_engine.has_unfinished_requests(): while self.llm_engine.has_unfinished_requests():
step_outputs = self.llm_engine.step() step_outputs = self.llm_engine.step()
if step_outputs is None:
continue
for output in step_outputs: for output in step_outputs:
if output.finished: if output.finished:
outputs.append(output) outputs.append(output)
......
import torch
import triton
import triton.language as tl
@triton.jit
def _update_input_tokens(
    sample_output,
    seq_ids,
    input_tokens,
    input_seq_ids,
    BATCH_SIZE1,
    BATCH_SIZE2,
):
    """Overwrite ``input_tokens`` entries with matching sampled tokens.

    Each program instance handles the single index ``tl.program_id(0)``.
    Instances whose index is not below ``BATCH_SIZE2`` do nothing. Otherwise
    the instance reads ``input_seq_ids[index]``, scans ``seq_ids[0]`` up to
    ``seq_ids[BATCH_SIZE1 - 1]`` and, for every equal entry ``j``, takes
    ``sample_output[j]`` as the new token (the last match wins). The result,
    or the unchanged original token when nothing matched, is written back to
    ``input_tokens[index]``.

    NOTE(review): assumes all four pointer arguments are 1-D device buffers
    and that the launch grid has at least ``BATCH_SIZE2`` programs -- confirm
    against the call site, which is not part of this file.
    """
    row = tl.program_id(0)
    if row < BATCH_SIZE2:
        wanted_seq_id = tl.load(input_seq_ids + row)
        # Start from the current value so a row without a match is rewritten
        # with its own token.
        token = tl.load(input_tokens + row)
        for src in range(BATCH_SIZE1):
            if tl.load(seq_ids + src) == wanted_seq_id:
                token = tl.load(sample_output + src)
        tl.store(input_tokens + row, token)
\ No newline at end of file
from ctypes import *
import os
import time
import threading
class Prof:
    """Emit NVTX and/or rocTX profiling ranges through ctypes.

    A backend is enabled when its environment variable is present at
    construction time, whatever its value (even empty):

    * ``VLLM_PROF_NVTX``  -> ``libnvToolsExt.so``  (``nvtxRangePushA/Pop``)
    * ``VLLM_PROF_ROCTX`` -> ``libroctracer64.so`` (``roctxRangePushA/Pop``,
      ``roctracer_start/stop``)

    rocTX ranges are only emitted between ``StartTracer()`` and
    ``StopTracer()``. With neither variable set every method is a no-op.

    Attributes:
        use_nvtx / use_roctx: backend switches.
        roc_tracer_flag: True while the roctracer is started.
        lib: the most recently loaded library handle (kept for backward
            compatibility; each backend uses its own private handle).
        tm: ``time.perf_counter()`` taken at construction.
        push_depth: ``{thread ident: number of currently open ranges}``; the
            counter is shared by both backends.

    NOTE(review): upstream ROCm ships the ``roctx*`` symbols in
    ``libroctx64.so``; confirm that ``libroctracer64.so`` exports them on the
    target platform.
    """

    def __init__(self):
        self.use_nvtx = os.getenv('VLLM_PROF_NVTX') is not None
        self.use_roctx = os.getenv('VLLM_PROF_ROCTX') is not None
        self.roc_tracer_flag = False
        self.lib = None
        # One handle per backend so enabling both never routes NVTX calls
        # into the roctracer library (or vice versa).
        self._nvtx_lib = None
        self._roctx_lib = None
        if self.use_nvtx:
            self._get_nvtx_lib()
        if self.use_roctx:
            self._get_roctx_lib()
        self.tm = time.perf_counter()
        self.push_depth = {}

    def _get_nvtx_lib(self):
        """Return the NVTX library handle, loading and typing it on first use."""
        if self._nvtx_lib is None:
            handle = cdll.LoadLibrary("libnvToolsExt.so")
            handle.nvtxRangePushA.argtypes = [c_char_p]
            handle.nvtxRangePushA.restype = c_int
            handle.nvtxRangePop.restype = c_int
            self._nvtx_lib = handle
            self.lib = handle
        return self._nvtx_lib

    def _get_roctx_lib(self):
        """Return the roctracer library handle, loading and typing it on first use."""
        if self._roctx_lib is None:
            handle = cdll.LoadLibrary("libroctracer64.so")
            handle.roctxRangePushA.argtypes = [c_char_p]
            handle.roctxRangePushA.restype = c_int
            handle.roctxRangePop.restype = c_int
            self._roctx_lib = handle
            self.lib = handle
        return self._roctx_lib

    def StartTracer(self):
        """Start the roctracer and allow rocTX ranges; no-op without rocTX."""
        if not self.use_roctx:
            return
        self._get_roctx_lib().roctracer_start()
        self.roc_tracer_flag = True

    def StopTracer(self):
        """Stop the roctracer and suppress rocTX ranges; no-op without rocTX."""
        if not self.use_roctx:
            return
        self._get_roctx_lib().roctracer_stop()
        self.roc_tracer_flag = False

    def thread_depth_add(self, num):
        """Adjust the calling thread's open-range counter by ``num``.

        Returns:
            False (leaving the counter untouched) when ``num`` is negative
            and the thread has no open range; True otherwise.
        """
        thread_id = threading.get_ident()
        depth = self.push_depth.setdefault(thread_id, 0)
        if num < 0 and depth == 0:
            return False
        self.push_depth[thread_id] = depth + num
        return True

    def ProfRangePush(self, message):
        """Open a range named ``message`` on every active backend."""
        nvtx_active = self.use_nvtx
        roctx_active = self.use_roctx and self.roc_tracer_flag
        if not (nvtx_active or roctx_active):
            return
        encoded = message.encode('utf-8')
        if nvtx_active:
            self._get_nvtx_lib().nvtxRangePushA(encoded)
            self.thread_depth_add(1)
        if roctx_active:
            self._get_roctx_lib().roctxRangePushA(encoded)
            self.thread_depth_add(1)

    def ProfRangePop(self):
        """Close the innermost range on every active backend.

        Does nothing further as soon as the calling thread's counter shows
        no open range, so an unmatched pop never reaches the native library.
        """
        if self.use_nvtx:
            if not self.thread_depth_add(-1):
                return
            self._get_nvtx_lib().nvtxRangePop()
        if self.use_roctx and self.roc_tracer_flag:
            if not self.thread_depth_add(-1):
                return
            self._get_roctx_lib().roctxRangePop()

    def ProfRangeAutoPush(self, message):
        """Close the current range (if any) and open a new one named ``message``."""
        self.ProfRangePop()
        self.ProfRangePush(message)


# Process-wide profiler instance imported by the engine and workers.
profile = Prof()
...@@ -1402,6 +1402,9 @@ class ExecuteModelRequest( ...@@ -1402,6 +1402,9 @@ class ExecuteModelRequest(
# Optional slot mapping of kvcache that pending to be moved generated from draft model. # Optional slot mapping of kvcache that pending to be moved generated from draft model.
kvcache_slot_to_be_moved: Optional[torch.Tensor] = None kvcache_slot_to_be_moved: Optional[torch.Tensor] = None
# for zero-overhead scheduler
last_outputs : Optional[torch.Tensor] = None
@property @property
def is_first_multi_step(self) -> bool: def is_first_multi_step(self) -> bool:
# TODO(will) make this be able to handle batches with variable number of # TODO(will) make this be able to handle batches with variable number of
...@@ -1451,7 +1454,8 @@ class ExecuteModelRequest( ...@@ -1451,7 +1454,8 @@ class ExecuteModelRequest(
async_callback=self.async_callback, async_callback=self.async_callback,
tree_attn_masks=self.tree_attn_masks, tree_attn_masks=self.tree_attn_masks,
tree_position_ids=self.tree_position_ids, tree_position_ids=self.tree_position_ids,
kvcache_slot_to_be_moved=self.kvcache_slot_to_be_moved) kvcache_slot_to_be_moved=self.kvcache_slot_to_be_moved,
last_outputs = self.last_outputs)
@dataclass @dataclass
......
# SPDX-License-Identifier: Apache-2.0
try: try:
from ._version import __version__, __version_tuple__ __version__ = "0.7.2"
__version_tuple__ = (0, 7, 2)
__hcu_version__ = f'0.7.2+das.opt1.cust1.6b7651a.dtk2504'
from vllm.version import __version__, __version_tuple__, __hcu_version__
except Exception as e: except Exception as e:
import warnings import warnings
warnings.warn(f"Failed to read commit hash:\n{e}", warnings.warn(f"Failed to read commit hash:\n + str(e)",
RuntimeWarning, RuntimeWarning,
stacklevel=2) stacklevel=2)
__version__ = "dev" __version__ = "dev"
__version_tuple__ = (0, 0, __version__) __version_tuple__ = (0, 0, __version__)
...@@ -59,6 +59,8 @@ from vllm.worker.model_runner_base import ( ...@@ -59,6 +59,8 @@ from vllm.worker.model_runner_base import (
_init_attn_metadata_from_tensor_dict, _init_attn_metadata_from_tensor_dict,
_init_sampling_metadata_from_tensor_dict) _init_sampling_metadata_from_tensor_dict)
from vllm.profiler.prof import profile
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.abstract import AttentionBackend
...@@ -271,7 +273,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -271,7 +273,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.computed_block_nums = computed_block_nums self.computed_block_nums = computed_block_nums
self.n_seqs = n_seqs self.n_seqs = n_seqs
self.encoder_seq_len = encoder_seq_len self.encoder_seq_len = encoder_seq_len
if reinit: if reinit:
if len(self.seq_ids) == 1 and reinit_use_defaults: if len(self.seq_ids) == 1 and reinit_use_defaults:
self.simple_reinit() self.simple_reinit()
...@@ -901,6 +902,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -901,6 +902,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.runner.device, self.runner.device,
self.runner.pin_memory) self.runner.pin_memory)
token_types_tensor = async_tensor_h2d(token_types, torch.long, token_types_tensor = async_tensor_h2d(token_types, torch.long,
self.runner.device, self.runner.device,
self.runner.pin_memory) \ self.runner.pin_memory) \
...@@ -1670,7 +1672,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1670,7 +1672,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
self.set_active_prompt_adapters( self.set_active_prompt_adapters(
model_input.prompt_adapter_requests, model_input.prompt_adapter_requests,
model_input.prompt_adapter_mapping) model_input.prompt_adapter_mapping)
profile.ProfRangeAutoPush('begin_forward')
self.attn_state.begin_forward(model_input) self.attn_state.begin_forward(model_input)
# Currently cuda graph is only supported by the decode phase. # Currently cuda graph is only supported by the decode phase.
...@@ -1772,6 +1774,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1772,6 +1774,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
torch.tensor(model_forward_time + orig_model_forward_time)) torch.tensor(model_forward_time + orig_model_forward_time))
return hidden_or_intermediate_states return hidden_or_intermediate_states
profile.ProfRangeAutoPush('compute_logits')
logits = self.model.compute_logits(hidden_or_intermediate_states, logits = self.model.compute_logits(hidden_or_intermediate_states,
model_input.sampling_metadata) model_input.sampling_metadata)
...@@ -1782,10 +1785,12 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1782,10 +1785,12 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
model_input.async_callback() model_input.async_callback()
# Sample the next token. # Sample the next token.
profile.ProfRangeAutoPush('sample')
output: SamplerOutput = self.model.sample( output: SamplerOutput = self.model.sample(
logits=logits, logits=logits,
sampling_metadata=model_input.sampling_metadata, sampling_metadata=model_input.sampling_metadata,
) )
profile.ProfRangeAutoPush('sample_end')
if (self.observability_config is not None if (self.observability_config is not None
and self.observability_config.collect_model_forward_time and self.observability_config.collect_model_forward_time
and output is not None): and output is not None):
...@@ -1803,6 +1808,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1803,6 +1808,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
output.model_forward_time = (orig_model_forward_time + output.model_forward_time = (orig_model_forward_time +
model_forward_time) model_forward_time)
profile.ProfRangeAutoPush('output')
if self.return_hidden_states: if self.return_hidden_states:
# we only need to pass hidden states of most recent token # we only need to pass hidden states of most recent token
assert model_input.sampling_metadata is not None assert model_input.sampling_metadata is not None
......
...@@ -189,6 +189,7 @@ class ModelRunnerBase(ABC, Generic[T]): ...@@ -189,6 +189,7 @@ class ModelRunnerBase(ABC, Generic[T]):
self.speculative_config = vllm_config.speculative_config self.speculative_config = vllm_config.speculative_config
self.prompt_adapter_config = vllm_config.prompt_adapter_config self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config self.observability_config = vllm_config.observability_config
self.last_output = None
# Map of request_id -> generator used for seeded random sampling # Map of request_id -> generator used for seeded random sampling
generators: Dict[str, torch.Generator] = {} generators: Dict[str, torch.Generator] = {}
......
...@@ -25,6 +25,7 @@ from vllm.worker.model_runner_base import (BroadcastableModelInput, ...@@ -25,6 +25,7 @@ from vllm.worker.model_runner_base import (BroadcastableModelInput,
ModelRunnerBase, ModelRunnerBase,
ModelRunnerInputBase) ModelRunnerInputBase)
from vllm.profiler.prof import profile
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -352,6 +353,7 @@ class LocalOrDistributedWorkerBase(WorkerBase): ...@@ -352,6 +353,7 @@ class LocalOrDistributedWorkerBase(WorkerBase):
worker_input: WorkerInput = self.prepare_worker_input( worker_input: WorkerInput = self.prepare_worker_input(
execute_model_req=execute_model_req) execute_model_req=execute_model_req)
self.model_runner.last_output = execute_model_req.last_outputs
model_input: ModelRunnerInputBase = ( model_input: ModelRunnerInputBase = (
self.model_runner.prepare_model_input( self.model_runner.prepare_model_input(
execute_model_req.seq_group_metadata_list, execute_model_req.seq_group_metadata_list,
...@@ -444,7 +446,7 @@ class LocalOrDistributedWorkerBase(WorkerBase): ...@@ -444,7 +446,7 @@ class LocalOrDistributedWorkerBase(WorkerBase):
and self.observability_config.collect_model_execute_time): and self.observability_config.collect_model_execute_time):
orig_model_execute_time = intermediate_tensors.tensors.get( orig_model_execute_time = intermediate_tensors.tensors.get(
"model_execute_time", torch.tensor(0)).item() "model_execute_time", torch.tensor(0)).item()
profile.ProfRangeAutoPush('execute')
output = self.model_runner.execute_model( output = self.model_runner.execute_model(
model_input=model_input, model_input=model_input,
kv_caches=self.kv_cache[worker_input.virtual_engine] kv_caches=self.kv_cache[worker_input.virtual_engine]
...@@ -453,6 +455,7 @@ class LocalOrDistributedWorkerBase(WorkerBase): ...@@ -453,6 +455,7 @@ class LocalOrDistributedWorkerBase(WorkerBase):
num_steps=num_steps, num_steps=num_steps,
**kwargs, **kwargs,
) )
profile.ProfRangeAutoPush('output')
model_execute_time = time.perf_counter() - start_time model_execute_time = time.perf_counter() - start_time
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment