Commit e0624a14 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.8.5-zero_overhead' into 'v0.8.5.post1-dev'

V0.8.5 zero overhead

See merge request dcutoolkit/deeplearing/vllm!110
parents a0c212c0 29e922ac
...@@ -239,8 +239,11 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]): ...@@ -239,8 +239,11 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
for i, block_table in enumerate(self.block_tables): for i, block_table in enumerate(self.block_tables):
if block_table: if block_table:
input_block_tables[i, :len(block_table)] = block_table input_block_tables[i, :len(block_table)] = block_table
block_tables = torch.from_numpy(input_block_tables).to( # block_tables = torch.from_numpy(input_block_tables).to(
device, non_blocking=True) # device, non_blocking=True)
block_tables = async_tensor_h2d(input_block_tables.tolist(), torch.int32,
device, self.runner.pin_memory)
else: else:
block_tables = make_tensor_with_pad( block_tables = make_tensor_with_pad(
self.block_tables, self.block_tables,
......
...@@ -6,6 +6,8 @@ from contextlib import contextmanager ...@@ -6,6 +6,8 @@ from contextlib import contextmanager
from typing import Iterator, List, Optional, Union from typing import Iterator, List, Optional, Union
import cloudpickle import cloudpickle
from vllm.zero_overhead.llm_engine import ZeroOverheadEngine
from vllm.zero_overhead.utils import is_zero_overhead
import zmq import zmq
from vllm import AsyncEngineArgs, SamplingParams from vllm import AsyncEngineArgs, SamplingParams
...@@ -79,7 +81,10 @@ class MQLLMEngine: ...@@ -79,7 +81,10 @@ class MQLLMEngine:
# the python object to be reused again. # the python object to be reused again.
kwargs['use_cached_outputs'] = True kwargs['use_cached_outputs'] = True
self.engine = LLMEngine(*args, **kwargs) if is_zero_overhead():
self.engine = ZeroOverheadEngine(*args, **kwargs)
else:
self.engine = LLMEngine(*args, **kwargs)
self.log_requests = log_requests self.log_requests = log_requests
self.use_async_sockets = use_async_sockets self.use_async_sockets = use_async_sockets
......
...@@ -43,6 +43,8 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer, ...@@ -43,6 +43,8 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import (Counter, Device, deprecate_args, deprecate_kwargs, from vllm.utils import (Counter, Device, deprecate_args, deprecate_kwargs,
is_list_of) is_list_of)
from vllm.zero_overhead.llm_engine import ZeroOverheadEngine
from vllm.zero_overhead.utils import is_zero_overhead
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -244,8 +246,12 @@ class LLM: ...@@ -244,8 +246,12 @@ class LLM:
) )
# Create the Engine (autoselects V0 vs V1) # Create the Engine (autoselects V0 vs V1)
self.llm_engine = LLMEngine.from_engine_args( if is_zero_overhead():
engine_args=engine_args, usage_context=UsageContext.LLM_CLASS) self.llm_engine = ZeroOverheadEngine.from_engine_args(
engine_args=engine_args, usage_context=UsageContext.LLM_CLASS)
else:
self.llm_engine = LLMEngine.from_engine_args(
engine_args=engine_args, usage_context=UsageContext.LLM_CLASS)
self.engine_class = type(self.llm_engine) self.engine_class = type(self.llm_engine)
self.request_counter = Counter() self.request_counter = Counter()
...@@ -1444,6 +1450,7 @@ class LLM: ...@@ -1444,6 +1450,7 @@ class LLM:
if use_tqdm: if use_tqdm:
pbar.close() pbar.close()
# Sort the outputs by request ID. # Sort the outputs by request ID.
# This is necessary because some requests may be finished earlier than # This is necessary because some requests may be finished earlier than
# its previous requests. # its previous requests.
......
...@@ -21,6 +21,7 @@ from vllm.sequence import (VLLM_INVALID_TOKEN_ID, ...@@ -21,6 +21,7 @@ from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
CompletionSequenceGroupOutput, Logprob, CompletionSequenceGroupOutput, Logprob,
PromptLogprobs, SampleLogprobs, SequenceOutput) PromptLogprobs, SampleLogprobs, SequenceOutput)
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
from vllm.zero_overhead.utils import is_zero_overhead
if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"): if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
import flashinfer.sampling import flashinfer.sampling
...@@ -38,6 +39,9 @@ def get_sampler() -> torch.nn.Module: ...@@ -38,6 +39,9 @@ def get_sampler() -> torch.nn.Module:
# Lazy import: the v1 package isn't distributed # Lazy import: the v1 package isn't distributed
from vllm.v1.sample.sampler import Sampler as V1Sampler from vllm.v1.sample.sampler import Sampler as V1Sampler
return V1Sampler() return V1Sampler()
if is_zero_overhead():
from vllm.zero_overhead.sampler import ZeroOverheadSampler
return ZeroOverheadSampler()
return Sampler() return Sampler()
......
from ctypes import *
import os
import time
import threading
class Prof:
    """Optional NVTX / ROCTX range-marker helper driven by environment variables.

    * ``VLLM_PROF_NVTX`` (present with any value) enables NVTX markers through
      ``libnvToolsExt.so``.
    * ``VLLM_PROF_ROCTX`` (present with any value) enables ROCTX markers
      through ``libroctracer64.so``; ROCTX ranges are only emitted between
      ``StartTracer()`` and ``StopTracer()``.

    Each backend keeps its own library handle (``nvtx_lib`` / ``roctx_lib``)
    so that enabling both does not route NVTX calls to the roctracer library.
    ``lib`` is kept as an alias of the most recently loaded handle for code
    that still reads it.

    All methods operate on ``self``; they do not depend on the module-level
    ``profile`` singleton, so independent instances behave independently.
    """

    def __init__(self):
        self.use_nvtx = os.getenv('VLLM_PROF_NVTX') is not None
        self.roc_tracer_flag = False
        self.lib = None
        self.nvtx_lib = None
        self.roctx_lib = None
        if self.use_nvtx:
            self.nvtx_lib = cdll.LoadLibrary("libnvToolsExt.so")
            self.nvtx_lib.nvtxRangePushA.argtypes = [c_char_p]
            self.nvtx_lib.nvtxRangePushA.restype = c_int
            self.nvtx_lib.nvtxRangePop.restype = c_int
            self.lib = self.nvtx_lib
        self.use_roctx = os.getenv('VLLM_PROF_ROCTX') is not None
        if self.use_roctx:
            self.roctx_lib = self._load_roctx()
            self.lib = self.roctx_lib
        self.tm = time.perf_counter()
        # Per-thread count of ranges currently open, keyed by thread ident.
        self.push_depth = {}

    @staticmethod
    def _load_roctx():
        """Load the roctracer library and declare the ROCTX prototypes."""
        lib = cdll.LoadLibrary("libroctracer64.so")
        lib.roctxRangePushA.argtypes = [c_char_p]
        lib.roctxRangePushA.restype = c_int
        lib.roctxRangePop.restype = c_int
        return lib

    def _ensure_roctx(self):
        """Return the roctracer handle, loading it if it was never loaded."""
        if self.roctx_lib is None:
            self.roctx_lib = self._load_roctx()
            self.lib = self.roctx_lib
        return self.roctx_lib

    def StartTracer(self):
        """Start roctracer collection and allow ROCTX ranges to be emitted.

        Does nothing unless ROCTX profiling is enabled.
        """
        if not self.use_roctx:
            return
        self._ensure_roctx().roctracer_start()
        self.roc_tracer_flag = True

    def StopTracer(self):
        """Stop roctracer collection and suppress further ROCTX ranges.

        Does nothing unless ROCTX profiling is enabled.
        """
        if not self.use_roctx:
            return
        self._ensure_roctx().roctracer_stop()
        self.roc_tracer_flag = False

    def thread_depth_add(self, num):
        """Adjust the calling thread's open-range counter by ``num``.

        Returns:
            ``False`` (leaving the counter at zero) when ``num`` is negative
            and the thread has no open range; ``True`` otherwise.
        """
        thread_id = threading.get_ident()
        depth = self.push_depth.setdefault(thread_id, 0)
        if num < 0 and depth == 0:
            return False
        self.push_depth[thread_id] = depth + num
        return True

    def ProfRangePush(self, message):
        """Open a named range on every enabled backend.

        Args:
            message: range label; it is UTF-8 encoded before being passed to
                the native library.
        """
        if self.use_nvtx:
            self.nvtx_lib.nvtxRangePushA(message.encode('utf-8'))
            self.thread_depth_add(1)
        if self.use_roctx and self.roc_tracer_flag:
            self.roctx_lib.roctxRangePushA(message.encode('utf-8'))
            self.thread_depth_add(1)

    def ProfRangePop(self):
        """Close the innermost range on every enabled backend.

        A pop with no matching push on the calling thread is ignored, so the
        native range stack is never under-flowed.
        """
        if self.use_nvtx:
            if not self.thread_depth_add(-1):
                return
            self.nvtx_lib.nvtxRangePop()
        if self.use_roctx and self.roc_tracer_flag:
            if not self.thread_depth_add(-1):
                return
            self.roctx_lib.roctxRangePop()

    def ProfRangeAutoPush(self, message):
        """Close the current range (if any) and open a new one named ``message``."""
        self.ProfRangePop()
        self.ProfRangePush(message)


# Process-wide profiler instance used by the rest of the code base.
profile = Prof()
...@@ -54,6 +54,7 @@ from vllm.worker.worker_base import LoRANotSupportedWorkerBase, WorkerBase ...@@ -54,6 +54,7 @@ from vllm.worker.worker_base import LoRANotSupportedWorkerBase, WorkerBase
from vllm.worker.cache_engine import CacheEngine from vllm.worker.cache_engine import CacheEngine
from vllm.attention.ops.paged_attn import PagedAttention from vllm.attention.ops.paged_attn import PagedAttention
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from vllm.zero_overhead.utils import is_zero_overhead
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -206,8 +207,14 @@ class SpecDecodeWorker(LoRANotSupportedWorkerBase): ...@@ -206,8 +207,14 @@ class SpecDecodeWorker(LoRANotSupportedWorkerBase):
# Load lm_head weight for eagle in init_device # Load lm_head weight for eagle in init_device
if draft_model_config.hf_config.model_type == "eagle": if draft_model_config.hf_config.model_type == "eagle":
enable_lm_head_weight_load = True enable_lm_head_weight_load = True
if is_zero_overhead():
proposer_worker = MultiStepWorker(**draft_worker_kwargs) assert False, (
"speculative decoding not support zero overhead scheduler yet"
)
from vllm.zero_overhead.spec_decode.muti_step_worker import ZeroOverheadMultiStepWorker
proposer_worker = ZeroOverheadMultiStepWorker(**draft_worker_kwargs)
else:
proposer_worker = MultiStepWorker(**draft_worker_kwargs)
if draft_model_config.hf_config.model_type == "deepseek_mtp": if draft_model_config.hf_config.model_type == "deepseek_mtp":
num_spec_prefill_steps = \ num_spec_prefill_steps = \
draft_model_config.hf_config.n_predict draft_model_config.hf_config.n_predict
...@@ -254,17 +261,31 @@ class SpecDecodeWorker(LoRANotSupportedWorkerBase): ...@@ -254,17 +261,31 @@ class SpecDecodeWorker(LoRANotSupportedWorkerBase):
"[Speculative Decoding] Disabling MQA scorer as the " "[Speculative Decoding] Disabling MQA scorer as the "
"target model is not running in eager mode.") "target model is not running in eager mode.")
return SpecDecodeWorker( if is_zero_overhead():
proposer_worker, from vllm.zero_overhead.spec_decode.spec_decode_worker import ZeroOverheadSpecDecodeWorker
scorer_worker, return ZeroOverheadSpecDecodeWorker(
disable_mqa_scorer=disable_mqa_scorer, proposer_worker,
disable_logprobs=disable_logprobs, scorer_worker,
disable_log_stats=disable_log_stats, disable_mqa_scorer=disable_mqa_scorer,
disable_by_batch_size=disable_by_batch_size, disable_logprobs=disable_logprobs,
spec_decode_sampler=spec_decode_sampler, disable_log_stats=disable_log_stats,
allow_zero_draft_token_step=allow_zero_draft_token_step, disable_by_batch_size=disable_by_batch_size,
enable_lm_head_weight_load=enable_lm_head_weight_load, spec_decode_sampler=spec_decode_sampler,
num_spec_prefill_steps=num_spec_prefill_steps) allow_zero_draft_token_step=allow_zero_draft_token_step,
enable_lm_head_weight_load=enable_lm_head_weight_load,
num_spec_prefill_steps=num_spec_prefill_steps)
else:
return SpecDecodeWorker(
proposer_worker,
scorer_worker,
disable_mqa_scorer=disable_mqa_scorer,
disable_logprobs=disable_logprobs,
disable_log_stats=disable_log_stats,
disable_by_batch_size=disable_by_batch_size,
spec_decode_sampler=spec_decode_sampler,
allow_zero_draft_token_step=allow_zero_draft_token_step,
enable_lm_head_weight_load=enable_lm_head_weight_load,
num_spec_prefill_steps=num_spec_prefill_steps)
def __init__( def __init__(
self, self,
......
...@@ -60,6 +60,7 @@ from vllm.worker.model_runner_base import ( ...@@ -60,6 +60,7 @@ from vllm.worker.model_runner_base import (
_add_sampling_metadata_broadcastable_dict, _add_sampling_metadata_broadcastable_dict,
_init_attn_metadata_from_tensor_dict, _init_attn_metadata_from_tensor_dict,
_init_sampling_metadata_from_tensor_dict) _init_sampling_metadata_from_tensor_dict)
from vllm.zero_overhead.utils import is_zero_overhead
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.abstract import AttentionBackend
...@@ -1636,6 +1637,9 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1636,6 +1637,9 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
_model_input_cls: Type[ModelInputForGPUWithSamplingMetadata] = ( _model_input_cls: Type[ModelInputForGPUWithSamplingMetadata] = (
ModelInputForGPUWithSamplingMetadata) ModelInputForGPUWithSamplingMetadata)
_builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder _builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder
if is_zero_overhead():
from vllm.zero_overhead.model_runner import ZeroOverheadModelInputForGpuBuilder
_builder_cls = ZeroOverheadModelInputForGpuBuilder
def make_model_input_from_broadcasted_tensor_dict( def make_model_input_from_broadcasted_tensor_dict(
self, self,
......
from functools import partial
import os
import queue
import threading
import traceback
from typing import Callable, Dict, List, Mapping, Optional, Type, Union
from zlib import ZLIB_VERSION
import torch
from vllm import envs
from vllm.config import DecodingConfig, ObservabilityConfig, VllmConfig
from vllm.core.scheduler import ScheduledSequenceGroup
from vllm.engine.llm_engine import _LOCAL_LOGGING_INTERVAL_SEC, LLMEngine, SchedulerContext, SchedulerOutputState
from vllm.engine.metrics_types import StatLoggerBase
from vllm.engine.output_processor.interfaces import SequenceGroupOutputProcessor
from vllm.logger import init_logger
from vllm.executor.executor_base import ExecutorBase
from vllm.inputs import INPUT_REGISTRY
from vllm.inputs.data import ProcessorInputs
from vllm.inputs.parse import split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor
from vllm.inputs.registry import InputRegistry
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.registry import MultiModalRegistry
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import ExecuteModelRequest, ParallelSampleSequenceGroup, SequenceGroup, SequenceGroupBase, SequenceGroupMetadata
from vllm.tracing import init_tracer
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.version import __version__ as VLLM_VERSION
from vllm.usage.usage_lib import UsageContext, is_usage_stats_enabled
from vllm.utils import resolve_obj_by_qualname, weak_bind, Counter
from vllm.zero_overhead.sampler import SampleRecorder, get_last_sampler
from vllm.zero_overhead.sequence import ZeroOverheadSequence
from vllm.zero_overhead.stop_check import ZeroOverheadStopChecker
from vllm.zero_overhead.tokenizer import ZeroOverheadDetokenizer
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message)
from vllm.profiler.prof import profile
from vllm.zero_overhead.utils import SpecStepKind, get_accepted_token_ids, get_spec_step, is_zero_no_thread, set_spec_step
logger = init_logger(__name__)
class ZeroOverheadEngine(LLMEngine):
def __init__(
self,
vllm_config: VllmConfig,
executor_class: Type[ExecutorBase],
log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
input_registry: InputRegistry = INPUT_REGISTRY,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
use_cached_outputs: bool = False,
) -> None:
# Build the engine: unpack the sub-configs of `vllm_config`, create the
# tokenizer/detokenizer, model executor, schedulers, stat loggers and the
# output processor, then set up the zero-overhead state (CUDA event, record
# queue and, unless is_zero_no_thread() is true, the background thread).
#
# NOTE(review): this body never calls super().__init__(); every attribute the
# base LLMEngine is expected to provide has to be assigned here — confirm it
# stays in sync with the base class when upgrading.
#
# Args:
#   vllm_config: source of all per-area configs stored on self below.
#   executor_class: instantiated with `vllm_config` to create the executor.
#   log_stats: when true, stat loggers are installed.
#   usage_context: forwarded to `usage_message.report_usage`.
#   stat_loggers: used as-is when not None; otherwise the logging and
#       prometheus loggers are created.
#   input_registry: used to create the input processor.
#   mm_registry: forwarded to `InputPreprocessor`.
#   use_cached_outputs: stored on self and included in the start-up log line.
#
# Raises:
#   ValueError: if `envs.VLLM_USE_V1` is truthy.
if envs.VLLM_USE_V1:
raise ValueError(
"Using V0 LLMEngine, but envs.VLLM_USE_V1=True. "
"This should not happen. As a workaround, try using "
"LLMEngine.from_vllm_config(...) or explicitly set "
"VLLM_USE_V1=0 or 1 and report this issue on Github.")
# Keep the full config and a direct reference to each sub-config.
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.lora_config = vllm_config.lora_config
self.parallel_config = vllm_config.parallel_config
self.scheduler_config = vllm_config.scheduler_config
self.device_config = vllm_config.device_config
self.speculative_config = vllm_config.speculative_config # noqa
self.load_config = vllm_config.load_config
# Fall back to default-constructed configs when these two are unset.
self.decoding_config = vllm_config.decoding_config or DecodingConfig( # noqa
)
self.prompt_adapter_config = vllm_config.prompt_adapter_config # noqa
self.observability_config = vllm_config.observability_config or ObservabilityConfig( # noqa
)
logger.info(
"Initializing a V0 LLM engine (v%s) with config: %s, "
"use_cached_outputs=%s, ",
VLLM_VERSION,
vllm_config,
use_cached_outputs,
)
self.log_stats = log_stats
self.use_cached_outputs = use_cached_outputs
# Tokenizer and detokenizer are only created when tokenizer init is not
# skipped; the detokenizer is the zero-overhead variant.
if not self.model_config.skip_tokenizer_init:
self.tokenizer = self._init_tokenizer()
self.detokenizer = ZeroOverheadDetokenizer(self.tokenizer)
tokenizer_group = self.get_tokenizer_group()
else:
self.tokenizer = None
self.detokenizer = None
tokenizer_group = None
# Ensure that the function doesn't contain a reference to self,
# to avoid engine GC issues
def get_tokenizer_for_seq(sequence: ZeroOverheadSequence) -> AnyTokenizer:
assert tokenizer_group, ("tokenizer_group cannot be None, "
"make sure skip_tokenizer_init is False")
return tokenizer_group.get_lora_tokenizer(sequence.lora_request)
# Counter used to hand out sequence ids.
self.seq_counter = Counter()
self.generation_config_fields = (
self.model_config.try_get_generation_config())
self.input_preprocessor = InputPreprocessor(self.model_config,
self.tokenizer,
mm_registry)
self.input_registry = input_registry
self.input_processor = input_registry.create_input_processor(
self.model_config)
self.model_executor = executor_class(vllm_config=vllm_config, )
# KV caches are initialized for every runner type except "pooling".
if self.model_config.runner_type != "pooling":
self._initialize_kv_caches()
# If usage stat is enabled, collect relevant info.
if is_usage_stats_enabled():
from vllm.model_executor.model_loader import (
get_architecture_class_name)
usage_message.report_usage(
get_architecture_class_name(self.model_config),
usage_context,
extra_kvs={
# Common configuration
"dtype":
str(self.model_config.dtype),
"tensor_parallel_size":
self.parallel_config.tensor_parallel_size,
"block_size":
self.cache_config.block_size,
"gpu_memory_utilization":
self.cache_config.gpu_memory_utilization,
# Quantization
"quantization":
self.model_config.quantization,
"kv_cache_dtype":
str(self.cache_config.cache_dtype),
# Feature flags
"enable_lora":
bool(self.lora_config),
"enable_prompt_adapter":
bool(self.prompt_adapter_config),
"enable_prefix_caching":
self.cache_config.enable_prefix_caching,
"enforce_eager":
self.model_config.enforce_eager,
"disable_custom_all_reduce":
self.parallel_config.disable_custom_all_reduce,
})
if self.tokenizer:
# Ping the tokenizer to ensure liveness if it runs in a
# different process.
self.tokenizer.ping()
# One cached scheduler-output slot and one scheduler context per
# pipeline-parallel stage.
self.cached_scheduler_outputs = [
SchedulerOutputState()
for _ in range(self.parallel_config.pipeline_parallel_size)
]
self.scheduler_contexts = [
SchedulerContext(multi_step_stream_outputs=self.scheduler_config.
multi_step_stream_outputs)
for _ in range(self.parallel_config.pipeline_parallel_size)
]
# With async output processing, each stage gets a callback bound to its own
# scheduler context; weak_bind is used for the bound method.
if self.model_config.use_async_output_proc:
process_model_outputs = weak_bind(self._process_model_outputs)
self.async_callbacks = [
partial(process_model_outputs,
ctx=self.scheduler_contexts[v_id])
for v_id in range(self.parallel_config.pipeline_parallel_size)
]
else:
self.async_callbacks = []
# Currently used by AsyncLLMEngine to ensure quick append
# of request outputs to asyncio queues
self.process_request_outputs_callback: Optional[Callable] = None
# Create the scheduler.
# NOTE: the cache_config here have been updated with the numbers of
# GPU and CPU blocks, which are profiled in the distributed executor.
# The scheduler class may be given as a qualified-name string or as a class.
if isinstance(self.vllm_config.scheduler_config.scheduler_cls, str):
Scheduler = resolve_obj_by_qualname(
self.vllm_config.scheduler_config.scheduler_cls)
else:
Scheduler = self.vllm_config.scheduler_config.scheduler_cls
self.scheduler = [
Scheduler(
self.scheduler_config, self.cache_config, self.lora_config,
self.parallel_config.pipeline_parallel_size,
self.async_callbacks[v_id]
if self.model_config.use_async_output_proc else None)
for v_id in range(self.parallel_config.pipeline_parallel_size)
]
# Metric Logging.
if self.log_stats:
if stat_loggers is not None:
self.stat_loggers = stat_loggers
else:
# Lazy import for prometheus multiprocessing.
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
# before prometheus_client is imported.
# See https://prometheus.github.io/client_python/multiprocess/
from vllm.engine.metrics import (LoggingStatLogger,
PrometheusStatLogger)
self.stat_loggers = {
"logging":
LoggingStatLogger(
local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
vllm_config=vllm_config),
"prometheus":
PrometheusStatLogger(
local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
labels=dict(
model_name=self.model_config.served_model_name),
vllm_config=vllm_config),
}
self.stat_loggers["prometheus"].info("cache_config",
self.cache_config)
# Tracing is enabled only when an OTLP endpoint is configured.
self.tracer = None
if self.observability_config.otlp_traces_endpoint:
self.tracer = init_tracer(
"vllm.llm_engine",
self.observability_config.otlp_traces_endpoint)
# Create sequence output processor, e.g. for beam search or
# speculative decoding.
# The stop checker installed here is the zero-overhead variant.
self.output_processor = (
SequenceGroupOutputProcessor.create_output_processor(
self.scheduler_config,
self.detokenizer,
self.scheduler,
self.seq_counter,
get_tokenizer_for_seq,
stop_checker=ZeroOverheadStopChecker(
self.scheduler_config.max_model_len,
get_tokenizer_for_seq,
),
))
# True only when the VLLM_TREE_DECODING environment variable equals '1'.
self.tree_decoding = os.environ.get('VLLM_TREE_DECODING') == '1'
self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {}
# Flag to set when an input fails to process and the engine should run
# the next step without re-scheduling.
self._skip_scheduling_next_step = False
# --- zero-overhead state ---
# Result of the latest non-blocking `.to('cpu')` copy issued by the step code.
self.async_d2h = None
# Record of the most recently launched step; None until one has been run.
self.last_record = None
# Recorded after the copy is issued and synchronized before it is read.
self.async_event = torch.cuda.Event(enable_timing=False)
self.thread_running = False
# Carries step records (or None) from the scheduling side to the step methods.
self.q_recorder = queue.Queue()
# Threaded mode only: `zero_thread` and `sem_m2s` are created in this branch
# and therefore do not exist when is_zero_no_thread() is true.
if not is_zero_no_thread():
self.zero_thread = threading.Thread(target=self.thread_zero_overhead)
self.thread_running = True
self.sem_m2s = threading.Semaphore(0) # main to scheduler thread
self.zero_thread.start()
# No-op unless ROCTX profiling is enabled in the profiler.
profile.StartTracer()
def __del__(self):
    """Stop the background thread, then run the base-class finalizer.

    ``__init__`` can raise before ``thread_running`` is assigned (for
    example the ``VLLM_USE_V1`` check at its very top). A finalizer that
    touched the attribute directly would then raise ``AttributeError`` and
    skip ``super().__del__()``, so the attribute is read defensively.
    """
    if getattr(self, "thread_running", False):
        self.finish_thread()
    return super().__del__()
def finish_thread(self):
    """Ask the background scheduling thread to exit if it is marked running.

    Clears ``thread_running`` and releases ``sem_m2s`` once so that a worker
    blocked on the semaphore wakes up and sees the cleared flag. Calling it
    while the thread is not marked running does nothing.
    """
    if not self.thread_running:
        return
    self.thread_running = False
    self.sem_m2s.release()
def thread_zero_overhead(self):
# Body of the background thread. Each time `sem_m2s` is released it either
# exits (when `thread_running` has been cleared) or runs one iteration:
# schedule, publish the previous step's record on `q_recorder`, execute the
# model for the new batch and remember it in `self.last_record`.
# Only virtual engine 0 is used.
logger.info('zero overhead thread start!')
try:
while True:
# Block until a step is requested (or shutdown is signalled).
self.sem_m2s.acquire()
if not self.thread_running:
logger.debug("Stopping remote worker execution loop.")
self.model_executor.stop_remote_worker_execution_loop()
break
virtual_engine = 0
# Clear outputs for each new scheduler iteration
# Schedule iteration
(seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc
) = self.scheduler[virtual_engine].schedule()
# Publish the previous step: start a non-blocking device-to-host copy of its
# sampled ids, then hand its record to the consumer. When there is no
# previous step, None is queued instead.
if self.last_record is not None:
last_sampler = self.last_record[1]
# NOTE(review): the step kind is read from get_spec_step() at this point,
# not from the kind stored in the record — confirm they always agree.
spec_step = get_spec_step()
if spec_step == SpecStepKind.KIND_DEFAULT:
self.async_d2h = last_sampler.sampled_token_ids_tensor.to('cpu', non_blocking=True)
elif spec_step == SpecStepKind.SCORE_DECODE:
self.async_d2h = last_sampler.to('cpu', non_blocking=True)
# The consumer synchronizes on this event before reading `async_d2h`.
self.async_event.record()
self.q_recorder.put(self.last_record)
else:
self.q_recorder.put(None)
# Nothing was scheduled: forget the previous record and wait for the next
# request.
if len(seq_group_metadata_list) == 0:
self.last_record = None
continue
finished_requests_ids = self.scheduler[
virtual_engine].get_and_reset_finished_requests_ids()
assert seq_group_metadata_list is not None
assert scheduler_outputs is not None
last_sampled_token_ids = \
self._get_last_sampled_token_ids(virtual_engine)
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
blocks_to_copy=scheduler_outputs.blocks_to_copy,
num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
running_queue_size=scheduler_outputs.running_queue_size,
finished_requests_ids=finished_requests_ids,
# We use ExecuteModelRequest to pass the last sampled_token_ids
# to each of the non-last PP stages for in-place prepare_input.
last_sampled_token_ids=last_sampled_token_ids)
outputs = self.model_executor.execute_model(
execute_model_req=execute_model_req)
for output in outputs:
self._advance_to_next_step(
output, seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)
scheduler_outputs.scheduled_seq_groups = [item for item in scheduler_outputs.scheduled_seq_groups] # new list object; the elements themselves are shared, not copied
# Capture the object holding this step's sampled ids; which one depends on
# the current speculative-step kind (left as None for any other kind).
last_sampler = None
spec_step = get_spec_step()
if spec_step == SpecStepKind.KIND_DEFAULT:
last_sampler = get_last_sampler()
elif spec_step == SpecStepKind.SCORE_DECODE:
last_sampler, _ = get_accepted_token_ids()
# Record layout: [outputs, sampler, metadata list, scheduler outputs, kind].
self.last_record = [outputs, last_sampler, seq_group_metadata_list, scheduler_outputs, spec_step]
except Exception as e:
# Any failure is printed and ends the thread function.
# NOTE(review): nothing is put on `q_recorder` in this path, so a caller
# waiting in `q_recorder.get()` would presumably block forever — confirm
# and consider propagating the error to the caller.
print(f"thread_zero_overhead error : {e}")
traceback.print_exc()
def zero_overhead_step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
# Threaded step: wake the background thread for one iteration, wait for the
# record of the previously launched step, patch its tokens with the values
# copied back from the device and run output processing on it.
#
# Returns the request outputs of that step, or None when the background
# thread queued None (there was no previous step to report).
# Restart the worker if an earlier call shut it down.
if not self.thread_running:
self.zero_thread.join()
self.thread_running = True
self.zero_thread = threading.Thread(target=self.thread_zero_overhead)
self.zero_thread.start()
# Request one iteration, then block until the worker publishes a record.
self.sem_m2s.release()
recode_output = self.q_recorder.get()
if recode_output is None: # None is for the first step
return None
virtual_engine = 0
ctx = self.scheduler_contexts[virtual_engine]
ctx.request_outputs.clear()
outputs, last_sampler, seq_group_metadata_list, scheduler_outputs, spec_step = recode_output
ctx.seq_group_metadata_list = seq_group_metadata_list
ctx.scheduler_outputs = scheduler_outputs
# Wait for the recorded event before reading `async_d2h`, then write the
# real token ids into the sequences. Other step kinds are left untouched.
if spec_step == SpecStepKind.KIND_DEFAULT:
self.async_event.synchronize()
self._fix_last_step(
outputs, last_sampler, seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)
elif spec_step == SpecStepKind.SCORE_DECODE:
self.async_event.synchronize()
self._fix_spec_decode_steps(
outputs, seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)
# is_first_step_output is True only when the num_steps of all
# the sequences are 1. When the num_steps > 1,
# multi_step_model_runner does the first-step output append.
is_first_step_output: bool = False if not seq_group_metadata_list \
else seq_group_metadata_list[0].state.num_steps == 1
# Add results to the output_queue
ctx.append_output(outputs=outputs,
seq_group_metadata_list=seq_group_metadata_list,
scheduler_outputs=scheduler_outputs,
is_async=True,
is_last_step=True,
is_first_step_output=is_first_step_output)
# Check if need to run the usual non-async path
#if not allow_async_output_proc:
# Output processing is always run here; the condition above is disabled.
self._process_model_outputs(ctx=ctx)
#profile.ProfRangeAutoPush('has_unfinish')
if not self.has_unfinished_requests():
# Drain async postprocessor (if exists)
if len(ctx.output_queue) > 0:
self._process_model_outputs(ctx=ctx)
assert len(ctx.output_queue) == 0
# Stop the execute model loop in parallel workers until there are
# more requests to process. This avoids waiting indefinitely in
# torch.distributed ops which may otherwise timeout, and unblocks
# the RPC thread in the workers so that they can process any other
# queued control plane messages, such as add/remove lora adapters.
# logger.debug("Stopping remote worker execution loop.")
# self.model_executor.stop_remote_worker_execution_loop()
# In threaded mode the worker performs the stop itself when it observes the
# shutdown signal sent by finish_thread().
self.finish_thread()
return ctx.request_outputs
def _fix_last_step(
        self, output: List[SamplerOutput],
        last_sampler: SampleRecorder,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None:
    """Write the tokens copied back from the device into the last step.

    ``self.async_d2h`` is converted to a Python list and paired positionally
    with ``last_sampler.seq_ids``. For every unfinished, sampling sequence
    group the first pair whose id equals the group's single sequence id
    supplies the token: it is stored on the group's first sample and passed
    to ``fix_last_token_id`` on the sequence. Groups with no matching id are
    left unchanged.

    Args:
        output: sampler outputs of the step; only ``output[0]`` is read.
        last_sampler: recorder providing the ``seq_ids`` of the sampled rows.
        seq_group_metadata_list: metadata of the scheduled groups.
        scheduled_seq_groups: scheduled groups, aligned with the metadata.
    """
    host_tokens = self.async_d2h.tolist()
    recorded_ids = last_sampler.seq_ids
    rows = zip(seq_group_metadata_list, output[0], scheduled_seq_groups)
    for metadata, group_output, scheduled in rows:
        group = scheduled.seq_group
        if group.is_finished() or not metadata.do_sample:
            continue
        first_sample = group_output.samples[0]
        assert len(group.seqs) == 1
        target: ZeroOverheadSequence = group.seqs[0]
        for value, recorded_id in zip(host_tokens, recorded_ids):
            if not target.seq_id == recorded_id:
                continue
            # A row may come back as a one-element list or as a bare value.
            first_sample.output_token = (value[0]
                                         if type(value) is list else value)
            target.fix_last_token_id(first_sample.output_token)
            break
def _fix_spec_decode_steps(
        self, output: List[SamplerOutput],
        seq_group_metadata_list: List[SequenceGroupMetadata],
        scheduled_seq_groups: List[ScheduledSequenceGroup]):
    """Apply the accepted-token rows copied back from the device.

    ``self.async_d2h`` is converted to a list of rows, one per scheduled
    group (paired positionally). For every unfinished, sampling group each
    row entry other than ``-1`` is passed to ``fix_last_token_id`` on the
    group's single sequence; the number of ``-1`` entries is then passed to
    ``remove_last_place_holder``.

    Args:
        output: sampler outputs of the step; not read by this method, kept
            for a signature parallel to ``_fix_last_step``.
        seq_group_metadata_list: metadata of the scheduled groups.
        scheduled_seq_groups: scheduled groups, aligned with the metadata.
    """
    accepted_rows = self.async_d2h.tolist()
    rows = zip(seq_group_metadata_list, accepted_rows, scheduled_seq_groups)
    for metadata, accepted_ids, scheduled in rows:
        group = scheduled.seq_group
        if group.is_finished() or not metadata.do_sample:
            continue
        assert len(group.seqs) == 1
        seq: ZeroOverheadSequence = group.seqs[0]
        # -1 presumably marks a slot whose draft token was not accepted —
        # TODO confirm against the producer of these rows.
        placeholders_to_drop = 0
        for token_id in accepted_ids:
            if token_id == -1:
                placeholders_to_drop += 1
            else:
                seq.fix_last_token_id(token_id)
        seq.remove_last_place_holder(placeholders_to_drop)
def no_thread_step(self):
# Single-threaded variant of the zero-overhead step: schedule and launch the
# next batch, then process the record of the previously launched one, all on
# the calling thread. Only virtual engine 0 is used.
#
# Returns the request outputs of the previous step, or None when there was
# no previous step to report.
virtual_engine = 0
# Clear outputs for each new scheduler iteration
# Schedule iteration
(seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc
) = self.scheduler[virtual_engine].schedule()
# Publish the previous step: start a non-blocking device-to-host copy of its
# sampled ids, record the event, and queue the record (None when there is no
# previous step).
if self.last_record is not None:
last_sampler = self.last_record[1]
self.async_d2h = last_sampler.sampled_token_ids_tensor.to('cpu', non_blocking=True)
self.async_event.record()
self.q_recorder.put(self.last_record)
else:
self.q_recorder.put(None)
# Nothing was scheduled: forget the previous record; otherwise launch the
# new batch.
if len(seq_group_metadata_list) == 0:
self.last_record = None
else:
finished_requests_ids = self.scheduler[
virtual_engine].get_and_reset_finished_requests_ids()
assert seq_group_metadata_list is not None
assert scheduler_outputs is not None
last_sampled_token_ids = \
self._get_last_sampled_token_ids(virtual_engine)
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
blocks_to_copy=scheduler_outputs.blocks_to_copy,
num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
running_queue_size=scheduler_outputs.running_queue_size,
finished_requests_ids=finished_requests_ids,
# We use ExecuteModelRequest to pass the last sampled_token_ids
# to each of the non-last PP stages for in-place prepare_input.
last_sampled_token_ids=last_sampled_token_ids)
outputs = self.model_executor.execute_model(
execute_model_req=execute_model_req)
# Sequences are advanced only when the executor returned exactly one output.
if len(outputs) == 1:
self._advance_to_next_step(
outputs[0], seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)
scheduler_outputs.scheduled_seq_groups = [item for item in scheduler_outputs.scheduled_seq_groups] # new list object; the elements themselves are shared, not copied
last_sampler = get_last_sampler()
# Record layout here has four fields: [outputs, sampler, metadata list,
# scheduler outputs] (no step kind).
self.last_record = [outputs, last_sampler, seq_group_metadata_list, scheduler_outputs]
# Take back the record queued above.
recode_output = self.q_recorder.get()
if recode_output is None: # None is for the first step
return None
virtual_engine = 0
ctx = self.scheduler_contexts[virtual_engine]
ctx.request_outputs.clear()
outputs, last_sampler, seq_group_metadata_list, scheduler_outputs = recode_output
ctx.seq_group_metadata_list = seq_group_metadata_list
ctx.scheduler_outputs = scheduler_outputs
# Wait for the recorded event before reading `async_d2h`, then write the
# real token ids into the sequences.
self.async_event.synchronize()
self._fix_last_step(
outputs, last_sampler, seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)
# is_first_step_output is True only when the num_steps of all
# the sequences are 1. When the num_steps > 1,
# multi_step_model_runner does the first-step output append.
is_first_step_output: bool = False if not seq_group_metadata_list \
else seq_group_metadata_list[0].state.num_steps == 1
# Add results to the output_queue
ctx.append_output(outputs=outputs,
seq_group_metadata_list=seq_group_metadata_list,
scheduler_outputs=scheduler_outputs,
is_async=True,
is_last_step=True,
is_first_step_output=is_first_step_output)
# Check if need to run the usual non-async path
#if not allow_async_output_proc:
# Output processing is always run here; the condition above is disabled.
self._process_model_outputs(ctx=ctx)
#profile.ProfRangeAutoPush('has_unfinish')
if not self.has_unfinished_requests():
# Drain async postprocessor (if exists)
if len(ctx.output_queue) > 0:
self._process_model_outputs(ctx=ctx)
assert len(ctx.output_queue) == 0
# Stop the execute model loop in parallel workers until there are
# more requests to process. This avoids waiting indefinitely in
# torch.distributed ops which may otherwise timeout, and unblocks
# the RPC thread in the workers so that they can process any other
# queued control plane messages, such as add/remove lora adapters.
logger.debug("Stopping remote worker execution loop.")
self.model_executor.stop_remote_worker_execution_loop()
return ctx.request_outputs
def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
    """Run one engine step in whichever zero-overhead mode is configured.

    ``is_zero_no_thread()`` selects between the single-threaded and the
    threaded implementation. Both return ``None`` when there is no previous
    step to report yet, in which case the selected implementation is invoked
    a second time and that result is returned.
    """
    run_once = (self.no_thread_step
                if is_zero_no_thread() else self.zero_overhead_step)
    result = run_once()
    if result is None:  # the first step need launch twice
        result = run_once()
    return result
def _add_processed_request(
    self,
    request_id: str,
    processed_inputs: ProcessorInputs,
    params: Union[SamplingParams, PoolingParams],
    arrival_time: float,
    lora_request: Optional[LoRARequest],
    prompt_adapter_request: Optional[PromptAdapterRequest],
    trace_headers: Optional[Mapping[str, str]] = None,
    priority: int = 0,
) -> Optional[SequenceGroup]:
    """Register an already-processed request with the engine.

    Requests with ``SamplingParams.n > 1`` are handed over to
    ``ParallelSampleSequenceGroup.add_request`` and ``None`` is returned.
    Otherwise the decoder (and optional encoder) inputs are wrapped in
    ``ZeroOverheadSequence`` objects, grouped into a sequence group and
    queued on the scheduler that currently has the fewest unfinished
    sequence groups.

    Returns:
        The created sequence group, or ``None`` for the parallel-sampling
        path.

    Raises:
        ValueError: if ``params`` is neither ``SamplingParams`` nor
            ``PoolingParams``.
    """
    is_parallel_sampling = (isinstance(params, SamplingParams)
                            and params.n > 1)
    if is_parallel_sampling:
        ParallelSampleSequenceGroup.add_request(
            request_id,
            self,
            params,
            processed_inputs=processed_inputs,
            arrival_time=arrival_time,
            lora_request=lora_request,
            trace_headers=trace_headers,
            prompt_adapter_request=prompt_adapter_request,
            priority=priority,
        )
        return None

    self._validate_model_inputs(processed_inputs, lora_request)

    # Gather everything both sequences have in common.
    block_size = self.cache_config.block_size
    seq_id = next(self.seq_counter)
    eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
    encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)

    def build_sequence(inputs):
        # Encoder and decoder sequences share the id and all settings.
        return ZeroOverheadSequence(seq_id, inputs, block_size,
                                    eos_token_id, lora_request,
                                    prompt_adapter_request)

    seq = build_sequence(decoder_inputs)
    if encoder_inputs is None:
        encoder_seq = None
    else:
        encoder_seq = build_sequence(encoder_inputs)

    # The kind of params decides which sequence-group factory is used.
    if isinstance(params, SamplingParams):
        seq_group = self._create_sequence_group_with_sampling(
            request_id,
            seq,
            params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            trace_headers=trace_headers,
            prompt_adapter_request=prompt_adapter_request,
            encoder_seq=encoder_seq,
            priority=priority)
    elif isinstance(params, PoolingParams):
        seq_group = self._create_sequence_group_with_pooling(
            request_id,
            seq,
            params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            encoder_seq=encoder_seq,
            priority=priority)
    else:
        raise ValueError(
            "Either SamplingParams or PoolingParams must be provided.")

    # Least-loaded scheduler wins; ties go to the earliest scheduler.
    least_loaded = min(
        self.scheduler,
        key=lambda sched: sched.get_num_unfinished_seq_groups())
    least_loaded.add_seq_group(seq_group)
    return seq_group
\ No newline at end of file
import torch
import itertools
from typing import List, Optional, Set
from vllm.lora.layers import LoRAMapping
from vllm.multimodal.inputs import MultiModalKwargs
from vllm.prompt_adapter.layers import PromptAdapterMapping
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import SequenceGroupMetadata
from vllm.utils import async_tensor_h2d, flatten_2d_lists
from vllm.worker.model_runner import ModelInputForGPU, ModelInputForGPUBuilder
from vllm.zero_overhead.sampler import get_last_sampler
from vllm.zero_overhead.utils import SpecStepKind, get_accepted_token_ids, get_proposal_token_ids, get_spec_last_step, get_spec_step
import triton
import triton.language as tl
@triton.jit
def _update_input_tokens(
    accepted_req_ids,
    accepted_req_ids_len,
    accepted_token_ids,
    accepted_token_len,
    chidren_req_ids,
    chidren_req_ids_len,
    input_tokens,
    input_tokens_len,
    input_positions,
    seq_lens,
    seq_lens_meta,
    seq_lens_tensor,
    slot_mapping,
    seq_start_loc,
    context_lens_tensor,
):
    """Patch model inputs with the tokens accepted by the previous scoring step.

    For every second entry of ``chidren_req_ids`` the kernel looks for the
    same id in ``accepted_req_ids``, counts how many leading entries of the
    matching ``accepted_token_ids`` row differ from -1, and rewrites the two
    ``input_tokens`` slots of that request. When fewer than
    ``accepted_token_len`` tokens were accepted it also rewrites the
    matching slots of ``input_positions``, ``context_lens_tensor``,
    ``slot_mapping`` and the three ``seq_lens*`` buffers. Finally
    ``seq_start_loc`` is recomputed as a running sum of ``seq_lens``.

    NOTE(review): the source indentation was lost; the nesting below is
    reconstructed (in particular, L-value fix-ups are assumed to belong to
    the partial-acceptance ``else`` branch) -- confirm against the repo.
    NOTE(review): several constructs here (scalar indexing / item
    assignment on loaded blocks, ``break``, non-constexpr ``tl.arange``
    bounds) look unlikely to compile under Triton -- verify before relying
    on this kernel.
    """
    chidren_req_ids_ = tl.load(chidren_req_ids + tl.arange(0, chidren_req_ids_len))
    # NOTE(review): this load is sized by chidren_req_ids_len although it
    # reads accepted_req_ids; accepted_req_ids_len looks intended -- confirm.
    accepted_req_ids_ = tl.load(accepted_req_ids + tl.arange(0, chidren_req_ids_len))
    # NOTE(review): `/` is true division, so the range bound is a float;
    # `//` is presumably intended -- confirm.
    for seq_id_idx in range(chidren_req_ids_len / 2):
        # Only every second request id is inspected (two slots per request).
        seq_id = chidren_req_ids_[2 * seq_id_idx]
        for i in range(accepted_req_ids_len):
            if seq_id == accepted_req_ids_[i]:
                # NOTE(review): the second argument of the outer tl.arange is
                # itself a tl.arange; probably meant
                # `i * accepted_token_len + tl.arange(0, accepted_token_len)`.
                accepted_token_ids_ = tl.load(accepted_token_ids + tl.arange(i * accepted_token_len, tl.arange(0, accepted_token_len)))
                # Count the leading entries that are not the -1 marker.
                accepted_token_counter = 0
                for j in range(accepted_token_len):
                    if accepted_token_ids_[j] == -1:
                        break
                    accepted_token_counter += 1
                if accepted_token_counter == accepted_token_len:
                    # Every token accepted: feed the last two accepted tokens.
                    tl.store(input_tokens + seq_id_idx * 2 + tl.arange(0, 2), accepted_token_ids_[-2:])
                else:
                    # Partial acceptance: first slot gets token 0, second slot
                    # gets the last accepted token.
                    tl.store(input_tokens + seq_id_idx * 2, 0)
                    tl.store(input_tokens + seq_id_idx * 2 + 1, accepted_token_ids_[accepted_token_counter - 1])
                    input_pos = tl.load(input_positions + seq_id_idx * 2 + tl.arange(0, 2))
                    input_pos[0] = 0
                    # NOTE(review): subtracts accepted_req_ids_len (number of
                    # requests) minus the accepted count; accepted_token_len
                    # looks intended -- confirm.
                    input_pos[1] = input_pos[1] - (accepted_req_ids_len - accepted_token_counter)
                    tl.store(input_positions + seq_id_idx * 2 + tl.arange(0, 2), input_pos)
                    tl.store(context_lens_tensor + seq_id_idx * 2 + tl.arange(0, 2), input_pos)
                    # First slot of the pair gets slot -1; the second slot
                    # receives the adjusted position value.
                    input_pos[0] = -1
                    tl.store(slot_mapping + seq_id_idx * 2 + tl.arange(0, 2), input_pos)
                    # Sequence lengths: 1 for the first slot, position + 1 for
                    # the second; written to all three length buffers.
                    input_pos[0] = 1
                    input_pos[1] = input_pos[1] + 1
                    tl.store(seq_lens + seq_id_idx * 2 + tl.arange(0, 2), input_pos)
                    tl.store(seq_lens_meta + seq_id_idx * 2 + tl.arange(0, 2), input_pos)
                    tl.store(seq_lens_tensor + seq_id_idx * 2 + tl.arange(0, 2), input_pos)
    # Rebuild seq_start_loc as the running sum of the sequence lengths.
    seq_lens_ = tl.load(seq_lens + tl.arange(0, input_tokens_len))
    # NOTE(review): triton.language documents `zeros_like`; `zero_like` may
    # not exist, and it is applied to a pointer argument -- confirm.
    seq_start_loc_ = tl.zero_like(seq_start_loc)
    for i in range(input_tokens_len):
        seq_start_loc_[i + 1] = seq_start_loc_[i] + seq_lens_[i]
    tl.store(seq_start_loc + tl.arange(0, input_tokens_len + 1), seq_start_loc_)
class ZeroOverheadModelInputForGpuBuilder(ModelInputForGPUBuilder):
    """Model-input builder that overwrites input tokens on the device.

    The builder remembers the sequence ids of every sequence group added to
    the current batch. After the base class has built the model input,
    ``build`` copies token ids from device-side tensors (the last sampler
    result or the recorded proposal tokens) into
    ``model_input.input_tokens``, selected by the current speculative step
    kind.
    """

    def __init__(self, runner, finished_requests_ids = None):
        super().__init__(runner, finished_requests_ids)
        # Sequence ids of the current batch, in the order they were added.
        self.req_ids = []

    def prepare(self,
                finished_requests_ids: Optional[List[str]] = None) -> None:
        """Forget the recorded sequence ids, then defer to the base class."""
        self.req_ids.clear()
        return super().prepare(finished_requests_ids)

    def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
        """Record the group's sequence ids, then defer to the base class."""
        self.req_ids.extend(seq_group_metadata.seq_data.keys())
        return super().add_seq_group(seq_group_metadata)

    def _indices_to_device(self, indices: List[int]) -> torch.Tensor:
        """Copy a list of integer indices to the runner device as int64."""
        return async_tensor_h2d(indices, torch.long, self.runner.device,
                                self.runner.pin_memory)

    def build(self) -> ModelInputForGPU:
        """Build the model input and patch its input tokens on the device.

        Returns:
            The model input produced by the base class; its
            ``input_tokens`` (and, for one step kind, further metadata
            tensors) may have been updated in place.
        """
        model_input = super().build()
        last_sampler = get_last_sampler()
        spec_step = get_spec_step()
        last_step = get_spec_last_step()
        if last_sampler is None:
            # No sampler has run yet, so there is nothing to copy from.
            return model_input

        if spec_step == SpecStepKind.KIND_DEFAULT:
            # Map each sampled sequence id to its first row in the sampler
            # output so every request is matched with one dict lookup
            # instead of a linear scan.
            first_row_of = {}
            for row, sampled_seq_id in enumerate(last_sampler.seq_ids):
                first_row_of.setdefault(sampled_seq_id, row)
            update_indices = []
            select_indices = []
            for pos, seq_id in enumerate(self.req_ids):
                row = first_row_of.get(seq_id)
                if row is not None:
                    select_indices.append(row)
                    update_indices.append(pos)
            if select_indices:
                select_indices = self._indices_to_device(select_indices)
                update_indices = self._indices_to_device(update_indices)
                model_input.input_tokens[update_indices] = \
                    last_sampler.sampled_token_ids_tensor[select_indices, 0]

        if spec_step == SpecStepKind.OTHER_PROPOSAL:
            # After another proposal step the last sampled token ids are
            # copied straight into the input tokens.
            # TODO: after the first proposal step, adjust the number of
            # input tokens to one per request.
            if last_step in (SpecStepKind.OTHER_PROPOSAL,
                             SpecStepKind.FIRST_PROPOSAL):
                update_indices = self._indices_to_device(
                    list(range(len(self.req_ids))))
                model_input.input_tokens[update_indices] = \
                    last_sampler.sampled_token_ids_tensor[update_indices, 0]

        if spec_step == SpecStepKind.SCORE_DECODE:
            proposal_token_ids = get_proposal_token_ids()
            batch_size, proposal_len = proposal_token_ids.shape[:2]
            # Each request owns proposal_len + 1 consecutive input slots;
            # the first slot of every request is left untouched.
            update_indices = self._indices_to_device([
                req * (proposal_len + 1) + tok + 1
                for req in range(batch_size)
                for tok in range(proposal_len)
            ])
            model_input.input_tokens[update_indices] = \
                proposal_token_ids.view(-1)

        if spec_step == SpecStepKind.FIRST_PROPOSAL:
            if last_step == SpecStepKind.PREFILL:
                # TODO: when the last step was a prefill, only the input id
                # of the last sequence id needs updating.
                pass
            if last_step == SpecStepKind.SCORE_DECODE:
                # TODO: when the last step was a score decode, fix input
                # ids, seq_lens and input_positions using the accepted
                # token ids.
                accept_token_ids, accept_seq_ids = get_accepted_token_ids()
                chidren_req_ids = self._indices_to_device(self.req_ids)
                grid = [1, 1, 1]
                # NOTE(review): the kernel names its length parameters
                # (seq_lens, seq_lens_meta, seq_lens_tensor) but receives
                # attn_metadata.seq_lens_tensor before attn_metadata.seq_lens;
                # the kernel writes the same values to all three, so the
                # order is kept as is -- confirm it is intended.
                _update_input_tokens[grid](
                    accept_seq_ids, accept_seq_ids.shape[0],
                    accept_token_ids, accept_token_ids.shape[1],
                    chidren_req_ids, chidren_req_ids.shape[0],
                    model_input.input_tokens,
                    model_input.input_tokens.shape[0],
                    model_input.input_positions,
                    model_input.seq_lens,
                    model_input.attn_metadata.seq_lens_tensor,
                    model_input.attn_metadata.seq_lens,
                    model_input.attn_metadata.slot_mapping,
                    model_input.attn_metadata.seq_start_loc,
                    model_input.attn_metadata.context_lens_tensor,
                )
        return model_input
from importlib.util import find_spec
from typing import Dict, List, Optional
import torch
from vllm import envs
from vllm.model_executor.layers.sampler import MultinomialSamplesType, SampleMetadataType, \
SampleResultArgsType, SampleResultType, SampleResultsDictType, SampleReturnType, Sampler, \
SamplerOutput, _apply_min_p, _apply_min_tokens_penalty, _apply_top_k_top_p, _build_sampler_output, \
_modify_greedy_probs_inplace, _top_k_top_p_multinomial_with_flashinfer, get_logprobs, _multinomial
from vllm.model_executor.layers.utils import apply_penalties
from vllm.model_executor.sampling_metadata import SamplingMetadata, SamplingTensors, SequenceGroupToSample
from vllm.sampling_params import SamplingType
from vllm.sequence import VLLM_INVALID_TOKEN_ID
# Bind the flashinfer top-k/top-p sampling routine only when it is enabled
# through VLLM_USE_FLASHINFER_SAMPLER and the package can be found; the name
# is None otherwise, which the sampler code below checks for.
if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
    import flashinfer.sampling
    # yapf: disable
    from flashinfer.sampling import (
        top_k_top_p_sampling_from_probs as flashinfer_top_k_top_p_sampling)
    # yapf: enable
else:
    flashinfer_top_k_top_p_sampling = None
class SampleRecorder:
    """Record of the most recent sampling pass.

    Both fields start as ``None`` and are filled in while sampling runs.
    """

    def __init__(self):
        # First sequence id of every sequence group in the sampled batch.
        # The sampling code in this module stores a plain Python list here.
        self.seq_ids: Optional[List[int]] = None
        # Device tensor holding the sampled token ids of that batch.
        self.sampled_token_ids_tensor: Optional[torch.Tensor] = None
# Recorder of the most recent sampler call. ZeroOverheadSampler.forward
# replaces it with a fresh SampleRecorder on every call; it stays None until
# the first call.
last_sampler = None


def get_last_sampler() -> Optional[SampleRecorder]:
    """Return the module-level ``last_sampler`` (``None`` if never set)."""
    return last_sampler
class ZeroOverheadSampler(Sampler):
    """Sampler that records each sampling pass in the module-level recorder.

    ``forward`` installs a fresh ``SampleRecorder`` as ``last_sampler``
    before sampling, so that ``get_last_sampler()`` exposes the result of
    the latest call.
    """

    def __init__(self):
        super().__init__()

    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """
        Single-step scheduling:
        * Perform GPU-side sampling computation & compute
          GPU-side logprobs tensor
        * Pythonize sampling result & logprobs tensor
        Multi-step scheduling:
        * Perform GPU-side sampling computation & compute
          GPU-side logprobs tensor
        * Defer Pythonization of sampling result & logprobs
          tensor
        * Encapsulate arguments required for deferred Pythonization
          in the :class:`SamplerOutput` structure
        Args:
            logits: (num_tokens, vocab_size).
            sampling_metadata: Metadata for sampling.
        Returns:
            The value produced by ``_build_sampler_output``.
        """
        # Start a fresh record for this call; the sampling helpers fill it.
        global last_sampler
        last_sampler = SampleRecorder()
        assert logits is not None
        # Unpacking also requires logits to be two-dimensional.
        _, vocab_size = logits.shape
        # Prepare sampling tensors with pinned memory to avoid blocking.
        if not sampling_metadata.reuse_sampling_tensors:
            self._init_sampling_tensors(logits, sampling_metadata)
        elif self._do_penalties:
            # In this case, the sampling tensors logic depends on
            # "output_tokens" of a sequence. As a result, we cannot
            # reuse sampling tensors, since "output_tokens" changes
            # between decode runs.
            self._init_sampling_tensors(logits, sampling_metadata)
        assert self._sampling_tensors is not None
        sampling_tensors = self._sampling_tensors
        do_penalties = self._do_penalties
        do_top_p_top_k = self._do_top_p_top_k
        do_min_p = self._do_min_p
        logits = _apply_min_tokens_penalty(logits, sampling_metadata)
        # Apply presence and frequency penalties.
        if do_penalties:
            logits = apply_penalties(logits, sampling_tensors.prompt_tokens,
                                     sampling_tensors.output_tokens,
                                     sampling_tensors.presence_penalties,
                                     sampling_tensors.frequency_penalties,
                                     sampling_tensors.repetition_penalties)
        # Use float32 to apply temperature scaling.
        # Use in-place division to avoid creating a new tensor.
        logits = logits.to(torch.float)
        logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))
        # When the flashinfer routine is bound, top-k/top-p is not applied
        # to the logits here; it is passed to the flashinfer sampling call.
        if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None:
            logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
                                        sampling_tensors.top_ks)
        if do_min_p:
            logits = _apply_min_p(logits, sampling_tensors.min_ps)
        # We use float32 for probabilities and log probabilities.
        # Compute the probabilities.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        # Compute the log probabilities.
        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
        # Sample the next tokens.
        maybe_deferred_sample_results, maybe_sampled_tokens_tensor = _sample(
            probs,
            logprobs,
            sampling_metadata,
            sampling_tensors,
            include_gpu_probs_tensor=self.include_gpu_probs_tensor,
            modify_greedy_probs=self._should_modify_greedy_probs_inplace,
        )
        if self.include_gpu_probs_tensor:
            # Since we will defer sampler result Pythonization,
            # preserve GPU-side tensors in support of later
            # deferred pythonization of logprobs
            assert maybe_sampled_tokens_tensor is not None
            on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
        else:
            # Since Pythonization has already happened, don't preserve
            # GPU-side tensors.
            on_device_tensors = None
        # Get the logprobs query results.
        prompt_logprobs = None
        sample_logprobs = None
        if not sampling_metadata.skip_sampler_cpu_output:
            # Pythonize logprobs now (GPU -> CPU); do not defer.
            assert not isinstance(maybe_deferred_sample_results,
                                  SampleResultArgsType)
            prompt_logprobs, sample_logprobs = get_logprobs(
                logprobs, sampling_metadata, maybe_deferred_sample_results)
        return _build_sampler_output(
            maybe_deferred_sample_results,
            sampling_metadata,
            prompt_logprobs,
            sample_logprobs,
            on_device_tensors=on_device_tensors,
            skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output,
            logits=logits)
def _greedy_sample(
selected_seq_groups: List[SequenceGroupToSample],
samples: torch.Tensor,
) -> SampleResultType:
"""Run greedy sampling on a given samples.
Args:
selected_seq_groups: A list of sequence groups batched.
samples: (num_selected_samples,) A tensor of samples. The length of
samples could be smaller than selected_seq_groups if
seq_group.do_sample is False.
Returns:
Tuple of (next_token_ids, parent_ids). The length of returned list is
same as the length of selected_seq_groups. If the corresponding
seq_group has do_sample=False, tuple contains ([], [])
"""
sample_idx = 0
results: SampleResultType = []
for seq_group in selected_seq_groups:
if not seq_group.do_sample:
results.append(([], []))
continue
seq_ids = seq_group.seq_ids
num_parent_seqs = len(seq_ids)
assert num_parent_seqs == 1, (
"Greedy sampling should have only one seq.")
parent_ids = list(range(num_parent_seqs))
assert num_parent_seqs == 1 # not support muti seqences in seqence group
next_token_ids = [0] #place holder token id
results.append((next_token_ids, parent_ids))
sample_idx += num_parent_seqs
return results
def _random_sample(
selected_seq_groups: List[SequenceGroupToSample],
random_samples: torch.Tensor,
) -> SampleResultType:
"""Run random sampling on a given samples.
Args:
selected_seq_groups: A list of sequence groups batched.
random_samples: (num_selected_samples,) A tensor of samples. The
length of samples could be smaller than selected_seq_groups if
seq_group.do_sample is False.
Returns:
Tuple of (next_token_ids, parent_ids). The length of returned list is
same as the length of selected_seq_groups. If the corresponding
seq_group has do_sample=False, tuple contains ([], [])
"""
# Find the maximum n value of the prompt phase requests.
sample_idx = 0
results: SampleResultType = []
for seq_group in selected_seq_groups:
if not seq_group.do_sample:
results.append(([], []))
continue
seq_ids = seq_group.seq_ids
sampling_params = seq_group.sampling_params
is_prompt = seq_group.is_prompt
num_parent_seqs = len(seq_ids)
if is_prompt:
# Prompt phase.
parent_ids = [0] * sampling_params.n
assert num_parent_seqs == 1 # not support muti seqences in seqence group
next_token_ids = [0] * sampling_params.n #place holder token id
else:
# Generation phase.
parent_ids = list(range(num_parent_seqs))
assert num_parent_seqs == 1 # not support muti seqences in seqence group
next_token_ids = [0] * num_parent_seqs #place holder token id
results.append((next_token_ids, parent_ids))
sample_idx += num_parent_seqs
return results
def _sample(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    sampling_tensors: SamplingTensors,
    include_gpu_probs_tensor: bool,
    modify_greedy_probs: bool,
) -> SampleReturnType:
    """Sample next tokens by delegating to ``_sample_with_torch``.

    Args:
        probs: (num_query_tokens_in_batch, num_vocab)
        logprobs: (num_query_tokens_in_batch, num_vocab)
        sampling_metadata: The metadata for a batch for sampling.
        sampling_tensors: Tensors that include sampling related metadata.
        include_gpu_probs_tensor: Forwarded unchanged.
        modify_greedy_probs: Forwarded unchanged.
    Returns:
        Exactly what ``_sample_with_torch`` returns for the same arguments.
    """
    flags = {
        "include_gpu_probs_tensor": include_gpu_probs_tensor,
        "modify_greedy_probs": modify_greedy_probs,
    }
    sampled = _sample_with_torch(probs, logprobs, sampling_metadata,
                                 sampling_tensors, **flags)
    return sampled
def _sample_with_torch(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    sampling_tensors: SamplingTensors,
    include_gpu_probs_tensor: bool,
    modify_greedy_probs: bool,
) -> SampleReturnType:
    '''Torch-oriented _sample() implementation.
    Single-step scheduling:
    * Perform GPU-side sampling computation
    * Immediately Pythonize sampling result
    Multi-step scheduling:
    * Perform GPU-side sampling computation
    * Defer Pythonization & preserve GPU-side
      tensors required for Pythonization

    Besides returning the sampling result, this function fills the
    module-level ``last_sampler`` record with the first sequence id of every
    sequence group and with the sampled token ids tensor. ``last_sampler``
    must already hold a recorder object when this is called.

    Returns:
        A pair of (pythonized results or deferred pythonization args,
        sampled token ids tensor or None).
    '''
    categorized_seq_group_ids: Dict[SamplingType, List[int]] = {
        t: []
        for t in SamplingType
    }
    # Record one sequence id per sequence group, in batch order.
    last_sampler.seq_ids = []
    categorized_sample_indices = sampling_metadata.categorized_sample_indices
    for i, seq_group in enumerate(sampling_metadata.seq_groups):
        last_sampler.seq_ids.append(seq_group.seq_ids[0])
        sampling_params = seq_group.sampling_params
        sampling_type = sampling_params.sampling_type
        categorized_seq_group_ids[sampling_type].append(i)
    sample_results_dict: SampleResultsDictType = {}
    sample_metadata: SampleMetadataType = {}
    multinomial_samples: MultinomialSamplesType = {}
    greedy_samples: Optional[torch.Tensor] = None
    # Create output tensor for sampled token ids.
    if include_gpu_probs_tensor:
        sampled_token_ids_tensor = torch.full((logprobs.shape[0], 1),
                                              VLLM_INVALID_TOKEN_ID,
                                              dtype=torch.long,
                                              device=logprobs.device)
    else:
        sampled_token_ids_tensor = None
    # Counterintuitively, having two loops here is actually faster.
    # The first loop can run without waiting on GPU<->CPU sync.
    # NOTE(review): last_sampler.sampled_token_ids_tensor is reassigned for
    # each sampling type handled below, so a batch mixing greedy and random
    # requests keeps only the tensor of the last type processed, while
    # last_sampler.seq_ids covers the whole batch -- confirm this is intended.
    for sampling_type in SamplingType:
        sample_indices = categorized_sample_indices[sampling_type]
        num_tokens = len(sample_indices)
        if num_tokens == 0:
            continue
        seq_group_id = categorized_seq_group_ids[sampling_type]
        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
        sample_metadata[sampling_type] = (seq_group_id, seq_groups)
        long_sample_indices = sample_indices.long()
        if sampling_type == SamplingType.GREEDY:
            greedy_samples = torch.argmax(logprobs[long_sample_indices],
                                          dim=-1)
            # Keep the greedy token ids on the device for the next step.
            last_sampler.sampled_token_ids_tensor = greedy_samples.unsqueeze(-1)
            if sampled_token_ids_tensor is not None:
                # Store sampled tokens in output tensor.
                sampled_token_ids_tensor[
                    long_sample_indices] = greedy_samples.unsqueeze(-1)
            if modify_greedy_probs:
                # If required, modify the probabilities such that sampling from
                # the modified distribution would always sample the argmax
                # token id.
                _modify_greedy_probs_inplace(logprobs, probs,
                                             long_sample_indices,
                                             greedy_samples)
        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            # Largest `n` among the prompt-phase groups (at least 1).
            max_n_in_batch = 1
            for seq_group in seq_groups:
                if seq_group.is_prompt:
                    sampling_params = seq_group.sampling_params
                    max_n_in_batch = max(max_n_in_batch, sampling_params.n)
            seq_groups_arg = (None if sampling_type == SamplingType.RANDOM else
                              seq_groups)
            if flashinfer_top_k_top_p_sampling is not None:
                multinomial_samples[
                    sampling_type] = _top_k_top_p_multinomial_with_flashinfer(
                        probs[long_sample_indices],
                        sampling_tensors.top_ks[long_sample_indices],
                        sampling_tensors.top_ps[long_sample_indices],
                        max_n_in_batch,
                        seq_groups_arg,
                    )
            else:
                multinomial_samples[sampling_type] = _multinomial(
                    probs[long_sample_indices],
                    max_n_in_batch,
                    seq_groups=seq_groups_arg)
            # Keep the random token ids on the device for the next step.
            last_sampler.sampled_token_ids_tensor = \
                multinomial_samples[sampling_type].to(torch.long)
            if sampled_token_ids_tensor is not None:
                # Store sampled tokens in output tensor.
                sampled_token_ids_tensor[long_sample_indices] = \
                    multinomial_samples[sampling_type].to(torch.long)
    # Encapsulate arguments for computing Pythonized sampler
    # results, whether deferred or otherwise.
    maybe_deferred_args = SampleResultArgsType(
        sampling_metadata=sampling_metadata,
        sample_metadata=sample_metadata,
        multinomial_samples=multinomial_samples,
        greedy_samples=greedy_samples,
        sample_results_dict=sample_results_dict)
    if not sampling_metadata.skip_sampler_cpu_output:
        # The upstream implementation syncs GPU<->CPU here; this module's
        # get_pythonized_sample_results builds placeholder results instead.
        # Return Pythonized sampler result & sampled token ids
        return get_pythonized_sample_results(
            maybe_deferred_args), sampled_token_ids_tensor
    else:
        # Defer sampler result Pythonization; return deferred
        # Pythonization args & sampled token ids
        return (
            maybe_deferred_args,
            sampled_token_ids_tensor,
        )
def get_pythonized_sample_results(
        sample_result_args: SampleResultArgsType) -> SampleResultType:
    '''Turn the collected sampler arguments into per-group Python results.

    For every sampling type present in ``sample_metadata`` the matching
    helper (``_greedy_sample`` or ``_random_sample``) produces one result
    per sequence group; the results are stored in ``sample_results_dict``
    under the group's index in the batch.

    Args:
        sample_result_args: GPU-side inputs to the Pythonization process
    Returns:
        One result per sequence group of the batch, in batch order; groups
        without a stored result get ``([], [])``.
    '''
    metadata_by_type = sample_result_args.sample_metadata
    batch_metadata = sample_result_args.sampling_metadata
    greedy = sample_result_args.greedy_samples
    multinomial_by_type = sample_result_args.multinomial_samples
    results_by_group = sample_result_args.sample_results_dict

    random_types = (SamplingType.RANDOM, SamplingType.RANDOM_SEED)
    for sampling_type in SamplingType:
        if sampling_type in metadata_by_type:
            group_ids, groups = metadata_by_type[sampling_type]
            if sampling_type == SamplingType.GREEDY:
                sample_results = _greedy_sample(groups, greedy)
            elif sampling_type in random_types:
                sample_results = _random_sample(
                    groups, multinomial_by_type[sampling_type])
            results_by_group.update(zip(group_ids, sample_results))

    num_groups = len(batch_metadata.seq_groups)
    return [
        results_by_group.get(group_idx, ([], []))
        for group_idx in range(num_groups)
    ]
\ No newline at end of file
from typing import Union
from vllm.sequence import Sequence
from typing import Sequence as GenericSequence
class ZeroOverheadSequence(Sequence):
    """Sequence whose newest output tokens may still be placeholders.

    ``effective_output_len`` counts the output tokens whose real value is
    known. Tokens stored beyond that count are placeholders that are later
    overwritten by ``fix_last_token_id`` or dropped by
    ``remove_last_place_holder``.
    """

    def __init__(self, seq_id, inputs, block_size, eos_token_id = None, lora_request = None, prompt_adapter_request = None):
        super().__init__(seq_id, inputs, block_size, eos_token_id, lora_request, prompt_adapter_request)
        # Number of leading output tokens that hold real (fixed) values.
        self.effective_output_len: int = 0

    def fix_last_token_id(self, token_id: int) -> None:
        """Overwrite the oldest placeholder with ``token_id``.

        Raises:
            AssertionError: if there is no placeholder to overwrite.
        """
        # Negative offset (from the end) of the oldest placeholder.
        effect_offset = self.effective_output_len - len(self.data.output_token_ids)
        assert effect_offset < 0
        self.data._output_token_ids[effect_offset] = token_id
        # The list of newly appended tokens may be shorter than the number
        # of outstanding placeholders; only patch it when the slot exists.
        if len(self.data._new_appended_tokens) >= -effect_offset:
            self.data._new_appended_tokens[effect_offset] = token_id
        self.data._cached_all_token_ids[effect_offset] = token_id
        self.effective_output_len += 1

    def remove_last_place_holder(self, count):
        """Drop the last ``count`` tokens from every token buffer.

        A non-positive ``count`` is a no-op. Without this guard a count of
        zero would slice with ``[:-0]`` and empty all buffers.
        """
        if count <= 0:
            return
        self.data._output_token_ids = self.data._output_token_ids[:-count]
        self.data._new_appended_tokens = self.data._new_appended_tokens[:-count]
        self.data._cached_all_token_ids = self.data._cached_all_token_ids[:-count]
        self.data._num_computed_tokens -= count

    def zero_overhead_get_output_token_ids(self) -> tuple[int, ...]:
        """Return the fixed output tokens, excluding placeholders."""
        return self.data.output_token_ids[:self.effective_output_len]

    def zero_overhead_get_output_len(self) -> int:
        """Return the number of fixed output tokens."""
        return self.effective_output_len

    def zero_overhead_get_last_token_id(self) -> int:
        """Return the last fixed token: output if any, else last prompt token."""
        if self.effective_output_len == 0:
            return self.data._prompt_token_ids[-1]
        return self.data._output_token_ids[self.effective_output_len - 1]

    def zero_overhead_get_len(self) -> int:
        """Return prompt length plus the number of fixed output tokens."""
        return self.effective_output_len + len(self.data._prompt_token_ids)

    def get_output_token_ids_to_return(
            self, delta: bool) -> Union[GenericSequence[int], int]:
        """Return the fixed output tokens, never the placeholders.

        If delta is True, only new tokens since the last call to
        this method are returned: a bare int when exactly one token is new
        (the common decode case), otherwise a list (empty when nothing is
        new). If delta is False, all fixed output tokens are returned.
        """
        if not delta:
            return self.zero_overhead_get_output_token_ids()
        output_len = self.zero_overhead_get_output_len()
        # Get the number of new tokens
        num_new_tokens = output_len - self._last_output_token_ids_offset
        self._last_output_token_ids_offset = output_len
        # Index the output buffer directly: `_cached_all_token_ids` also
        # holds the prompt and possibly trailing placeholders, so indices
        # relative to the output do not apply to it.
        if num_new_tokens == 1:
            # Optimization for single decode token case
            # (which is what we have most of the time)
            return self.data._output_token_ids[output_len - 1]
        if num_new_tokens <= 0:
            return []
        first_new = output_len - num_new_tokens
        return list(self.data._output_token_ids[first_new:output_len])
\ No newline at end of file
from array import array
import numpy as np
from itertools import chain, count
from typing import Iterator, List, Optional, Tuple
import torch
from vllm import SamplingParams
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import (VLLM_INVALID_TOKEN_ID, VLLM_TOKEN_ID_ARRAY_TYPE,
ExecuteModelRequest, SequenceData,
SequenceGroupMetadata, get_all_seq_ids)
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.util import nvtx_range, split_batch_by_proposal_len
from vllm.utils import async_tensor_h2d
from vllm.zero_overhead.utils import get_proposal_lens_list, record_proposal_token_ids
# Readability aliases for the integer ids used in this module.
SeqId = int
TargetSeqId = int
TokenId = int
# A SamplingParams instance built with no arguments (all defaults).
DEFAULT_SIMPLE_SAMPLING_PARAMS = SamplingParams()
class ZeroOverheadBatchExpansionTop1Scorer(BatchExpansionTop1Scorer):
    """Batch-expansion scorer that never copies proposal tokens to the host.

    The proposal token tensor is recorded through
    ``record_proposal_token_ids``, and the batch is expanded with an
    all-zero placeholder token list of the same shape.
    """

    @nvtx_range("BatchExpansionTop1Scorer.score_proposals")
    def score_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        proposals: SpeculativeProposals,
    ) -> SpeculativeScores:
        """Score the proposed tokens via the scorer model.

        The batch is expanded into target sequences, run through the scorer
        worker in a single step, and contracted back into per-request
        scores.

        Args:
            execute_model_req: The execution request.
            proposals: The speculative proposals to score.
        Returns:
            SpeculativeScores: The scores of each speculative token, along
                with which sequences were ignored during scoring.
        """
        proposal_lens_list = get_proposal_lens_list()
        record_proposal_token_ids(proposals.proposal_token_ids)

        # Placeholder token ids (all zeros) shaped like the proposals.
        placeholder_rows = np.zeros(proposals.proposal_token_ids.shape,
                                    dtype=int).tolist()
        # Filter the list to ignore invalid proposals.
        rows_without_skips = [
            row for row in placeholder_rows
            if VLLM_INVALID_TOKEN_ID not in row
        ]

        expanded = self._expand_batch(
            seq_group_metadata_list=execute_model_req.seq_group_metadata_list,
            proposal_token_ids_list=rows_without_skips,
            proposal_lens_list=proposal_lens_list,
        )
        (spec_indices, non_spec_indices, target_seq_group_metadata_list,
         num_scoring_tokens) = expanded

        scorer_request = execute_model_req.clone(
            seq_group_metadata_list=target_seq_group_metadata_list)
        sampler_outputs = self._scorer_worker.execute_model(
            execute_model_req=scorer_request)
        assert len(sampler_outputs) == 1, "expected single-step output"
        target_sampler_output = sampler_outputs[0]

        if non_spec_indices:
            # Batch has a mix of spec decode enabled and disabled seq groups
            return self._contract_batch(
                execute_model_req.seq_group_metadata_list,
                target_sampler_output=target_sampler_output,
                proposals=proposals,
                num_scoring_tokens=num_scoring_tokens,
                non_spec_indices=non_spec_indices,
                spec_indices=spec_indices,
                k=execute_model_req.num_lookahead_slots,
            )
        # All sequence groups in batch have spec decoding enabled
        return self._contract_batch_all_spec(
            target_sampler_output=target_sampler_output,
            proposals=proposals,
        )

    def _contract_non_speculative(
            self, scores: SpeculativeScores,
            seq_group_metadata_list: List[SequenceGroupMetadata],
            non_spec_indices: List[int], non_spec_outputs: SpeculativeScores,
            has_prompt_log: bool) -> SpeculativeScores:
        """Write the outputs of non-speculative requests into ``scores``.

        For every index in ``non_spec_indices`` the first position of
        ``scores`` (token ids, probs, logprobs and, when present, hidden
        states) is overwritten with the sampled entry of
        ``non_spec_outputs``. Index tensors are created via
        ``async_tensor_h2d``.

        Returns:
            ``scores``, updated in place (unchanged when
            ``non_spec_indices`` is empty).
        """
        if not non_spec_indices:
            return scores

        if not has_prompt_log:
            # In this case only sampled tokens are returned, select all.
            nospec_sampled_token_idxs = list(
                range(len(non_spec_outputs.token_ids)))
        else:
            # When prompt_logprobs is enabled, prefills yield output token
            # (and respective prob) in the last entry (prompt|out):
            # [.|.|.|prefill0_out|.|prefill1_out|decode0_out|..].
            # With chunked prefill, non-terminal chunks have -1 on each
            # position: they're still picked, but they're discarded later.
            chunk_sizes = []
            for idx in non_spec_indices:
                meta = seq_group_metadata_list[idx]
                chunk_sizes.append(
                    meta.token_chunk_size if meta.is_prompt else 1)
            nospec_sizes = torch.tensor(chunk_sizes)
            nospec_sampled_token_idxs = torch.cumsum(nospec_sizes,
                                                     0).add_(-1)

        src_rows = async_tensor_h2d(nospec_sampled_token_idxs, torch.int32,
                                    self._device, True)
        dst_rows = async_tensor_h2d(non_spec_indices, torch.int32,
                                    self._device, True)

        scores.token_ids[dst_rows, :1] = \
            non_spec_outputs.token_ids[src_rows].unsqueeze(1)
        scores.probs[dst_rows, :1, :] = \
            non_spec_outputs.probs[src_rows].unsqueeze(1)
        scores.logprobs[dst_rows, :1, :] = \
            non_spec_outputs.logprobs[src_rows].unsqueeze(1)
        if scores.hidden_states is not None:
            assert non_spec_outputs.hidden_states is not None
            scores.hidden_states[dst_rows, :1, :] = \
                non_spec_outputs.hidden_states[src_rows].unsqueeze(1)
        return scores
\ No newline at end of file
import copy
import weakref
from typing import Dict, List, Set, Tuple
import torch
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform
from vllm.sequence import (ExecuteModelRequest, HiddenStates, SequenceData,
SequenceGroupMetadata)
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.utils import async_tensor_h2d
from vllm.zero_overhead.spec_decode.top1_proproser import ZeroOverheadTop1Proposer
from vllm.zero_overhead.utils import SpecStepKind, set_spec_step
if current_platform.is_cuda_alike():
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeProposer)
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
from vllm.worker.worker_base import DelegateWorkerBase
class ZeroOverheadMultiStepWorker(MultiStepWorker):
def init_device(self) -> None:
    """Initialise the wrapped worker's device and create the proposer.

    The proposer receives a weak proxy of this worker, the worker's device
    and vocabulary size, and the maximum model length as its maximum
    proposal length.
    """
    self.worker.init_device()

    worker_ref = weakref.proxy(self)  # type: ignore[arg-type]
    device = self.device
    vocab_size = self.vocab_size
    longest_proposal = self.max_model_len
    self._proposer = ZeroOverheadTop1Proposer(
        worker_ref,
        device,
        vocab_size,
        max_proposal_len=longest_proposal,
    )
@torch.inference_mode()
def sampler_output(
self,
execute_model_req: ExecuteModelRequest,
sample_len: int,
seq_ids_with_bonus_token_in_last_step: Set[int],
) -> Tuple[List[SamplerOutput], bool]:
"""Run the model forward pass sample_len times. Returns the list of
sampler output, one per model forward pass, along with indicator of
whether torch tensor in sampler output need to be transposed in latter
sampler_output_to_torch logic.
For multi step worker, this indicator shall be True.
"""
self._raise_if_unsupported(execute_model_req)
# Expand the batch for sequences with a bonus token.
# Perform a forward pass on the expanded batch and filter the
# response to retain only the original sequences' responses.
expanded_request, indices_of_seq_with_bonus_tokens =\
self._expand_execute_model_request(
execute_model_req, seq_ids_with_bonus_token_in_last_step)
# Run model sample_len times.
model_outputs: List[SamplerOutput] = []
if current_platform.is_cuda_alike() and isinstance(
self.model_runner, TP1DraftModelRunner
) and self.model_runner.supports_gpu_multi_step(expanded_request):
# Here we run the draft_model_runner with multi-step prepare
# on the GPU directly
expanded_request.num_steps = sample_len
self.model_runner.set_indices_of_seq_with_bonus_tokens(
indices_of_seq_with_bonus_tokens)
model_outputs = self.execute_model(
execute_model_req=expanded_request)
else:
# Here we run multi-step directly, with every step prepared
# on the CPU.
# TODO: Remove this branch once DraftModelRunner supports TP>1
# and other restrictions that are part of DraftModelRunner's
# supports_gpu_multi_step(..)
set_spec_step(SpecStepKind.FIRST_PROPOSAL)
for _ in range(sample_len):
model_output: List[SamplerOutput] = self.worker.execute_model(
execute_model_req=expanded_request)
assert (len(model_output) == 1
), "composing multistep workers not supported"
model_output = model_output[0]
set_spec_step(SpecStepKind.OTHER_PROPOSAL)
self._append_new_tokens(
model_output, expanded_request.seq_group_metadata_list,
indices_of_seq_with_bonus_tokens)
model_outputs.append(model_output)
set_spec_step(SpecStepKind.SCORE_DECODE)
filtered_model_outputs = self._filter_model_output_zero_overhead(
model_outputs, indices_of_seq_with_bonus_tokens)
return filtered_model_outputs, True
def _filter_model_output_zero_overhead(self,
expanded_batch_outputs: List[SamplerOutput],
output_indices_to_retain: List[int]) -> List[SamplerOutput]:
"""
Filters the model output to include only the specified sequence
outputs. This method contracts the expanded batch output from the
model to retain the outputs of only those sequences indicated by the
provided indices.
Args:
expanded_batch_output (List[SamplerOutput]): The expanded output
batch from the model.
output_indices_to_retain (torch.Tensor): Indices of the model
outputs to retain.
Returns:
List[SamplerOutput]: A list containing the filtered model
outputs for the specified indices.
"""
indices_of_seq_with_bonus_tokens = async_tensor_h2d(output_indices_to_retain, torch.int32,
self.device,
True)
return [
SamplerOutput(
outputs=[
expanded_batch_output.outputs[i]
for i in output_indices_to_retain
] if len(expanded_batch_output.outputs) > 0 else [],
sampled_token_probs=(
expanded_batch_output.
sampled_token_probs[indices_of_seq_with_bonus_tokens]
if expanded_batch_output.sampled_token_probs is not None
else None),
logprobs=(
expanded_batch_output.logprobs[indices_of_seq_with_bonus_tokens]
if expanded_batch_output.logprobs is not None else None),
sampled_token_ids=(expanded_batch_output.
sampled_token_ids[indices_of_seq_with_bonus_tokens]
if expanded_batch_output.sampled_token_ids
is not None else None))
for expanded_batch_output in expanded_batch_outputs
]
\ No newline at end of file
import os
import copy
from collections import defaultdict
from functools import cached_property
from typing import Any, Dict, List, Optional, Set, Tuple, Type
import torch
import torch.nn as nn
from vllm.config import ParallelConfig, SpeculativeConfig, VllmConfig
from vllm.distributed.communication_op import (broadcast_tensor_dict,
get_tp_group,
tensor_model_parallel_gather)
from vllm.distributed.parallel_state import model_parallel_is_initialized
from vllm.logger import init_logger
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.layers.spec_decode_base_sampler import (
SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler)
from vllm.model_executor.layers.typical_acceptance_sampler import (
TypicalAcceptanceSampler)
from vllm.platforms import current_platform
from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
CompletionSequenceGroupOutput, ExecuteModelRequest,
HiddenStates, SequenceGroupMetadata,
get_all_seq_ids_and_request_ids, Logits)
from vllm.spec_decode.batch_expansion import BatchExpansionTreeStyleScorer
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker, prepare_prefill_hidden_states
from vllm.zero_overhead.spec_decode.batch_expansion import ZeroOverheadBatchExpansionTop1Scorer
from vllm.zero_overhead.utils import SpecStepKind, record_accepted_token_ids, set_spec_step
if current_platform.is_cuda_alike():
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.medusa_worker import MedusaWorker
from vllm.spec_decode.metrics import AsyncMetricsCollector
from vllm.spec_decode.mlp_speculator_worker import MLPSpeculatorWorker
from vllm.spec_decode.mqa_scorer import MQAScorer
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.ngram_worker import NGramWorker
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker
from vllm.spec_decode.target_model_runner import TargetModelRunner
from vllm.spec_decode.util import (Timer, create_logprobs_output,
create_sequence_group_output,
get_all_num_logprobs,
get_sampled_token_logprobs, nvtx_range,
split_batch_by_proposal_len)
from vllm.utils import async_tensor_h2d, resolve_obj_by_qualname
from vllm.worker.worker_base import LoRANotSupportedWorkerBase, WorkerBase
from vllm.worker.cache_engine import CacheEngine
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
logger = init_logger(__name__)
class ZeroOverheadSpecDecodeWorker(SpecDecodeWorker):
    """Speculative-decoding worker for the zero-overhead path.

    Index lists are uploaded with ``async_tensor_h2d`` and the accepted
    token ids are handed over as a tensor via ``record_accepted_token_ids``
    instead of being converted to Python lists in this class.
    """

    def init_device(self) -> None:
        """Initialise scorer and proposer workers and build the scorer."""
        # The scorer worker model is initialized first in case the proposer
        # model has a smaller TP degree than the target worker.
        self.scorer_worker.init_device()
        self.proposer_worker.init_device()

        # NOTE(cade): load_model is not part of the WorkerBase interface.
        self.scorer_worker.load_model()
        self.proposer_worker.load_model()

        if self._enable_lm_head_weight_load:
            # NOTE(Shangming): gather lm_head weight when tp enabled
            target_lm_head_weight: torch.Tensor = tensor_model_parallel_gather(
                self.scorer_worker.model_runner.model_runner.model.lm_head.
                weight.data,
                dim=0,
            )
            self.proposer_worker.maybe_load_lm_head_weight(
                target_lm_head_weight)

        self._metrics.init_tensors(self.rank, device_type=self.device)
        if model_parallel_is_initialized():
            sampler_rank = get_tp_group().local_rank
        else:
            sampler_rank = self.rank
        self.spec_decode_sampler.init_tensors(sampler_rank,
                                              device_type=self.device)

        scorer_cls: Type[SpeculativeScorer]
        if self.disable_mqa_scorer:
            scorer_cls = ZeroOverheadBatchExpansionTop1Scorer
            logger.info("[Speculative Decoding] Use batch "
                        "expansion for scoring proposals.")
        else:
            scorer_cls = MQAScorer
            logger.info(
                "[Speculative Decoding] Use MQA scorer for scoring proposals.")

        # Tree decoding always scores with the tree-style batch expansion
        # scorer, whatever was selected above.
        scorer_factory = (BatchExpansionTreeStyleScorer
                          if self.tree_decoding else scorer_cls)
        self.scorer = scorer_factory(scorer_worker=self.scorer_worker,
                                     device=self.device,
                                     vocab_size=self._vocab_size)

        self._configure_model_sampler_for_spec_decode()

    @nvtx_range("spec_decode_worker._run_no_spec")
    def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
                     skip_proposer: bool) -> List[SamplerOutput]:
        """Run a single generation step without any speculation.

        The request goes to the scorer and, unless ``skip_proposer`` is
        True, to the proposer as well so that both KV caches stay
        consistent. When the proposer is skipped its KV cache is not updated
        for these requests, so they cannot use speculative decoding for the
        rest of their decoding.
        """
        if self.tree_decoding and self.kvcache_slot_to_be_moved is not None:
            # Pending slot moves are attached to this request and cleared.
            execute_model_req.kvcache_slot_to_be_moved = \
                self.kvcache_slot_to_be_moved
            self.kvcache_slot_to_be_moved = None

        set_spec_step(SpecStepKind.PREFILL)
        scorer_outputs = self.scorer_worker.execute_model(execute_model_req)
        assert len(scorer_outputs) == 1
        sampler_output = scorer_outputs[0]

        # Store hidden states from target model execution, BxD.
        hidden_states = sampler_output.hidden_states
        if hidden_states is not None:
            # Only decodes and prefill terminal chunks need a hidden state.
            sampled_groups = [
                sg for sg in execute_model_req.seq_group_metadata_list
                if sg.do_sample
            ]
            if any(sg.is_prompt for sg in sampled_groups):
                # Drop hidden_states with no prediction (eg non-terminal
                # chunks).
                valid_rows = torch.where(sampler_output.sampled_token_ids -
                                         VLLM_INVALID_TOKEN_ID)[0]
                hidden_states = hidden_states[valid_rows]

            # The cached hidden states are refreshed whether or not the
            # proposer is skipped.
            if len(sampled_groups):
                if self.previous_hidden_states is None:
                    self.previous_hidden_states = HiddenStates(
                        hidden_states, sampled_groups)
                elif self.previous_hidden_states:
                    self.previous_hidden_states.update(
                        hidden_states, sampled_groups)

        # Store logits from target model execution.
        if self.tree_decoding:
            logits = sampler_output.logits
            if logits is not None:
                if self.previous_logits is None:
                    self.previous_logits = Logits(
                        logits, execute_model_req.seq_group_metadata_list)
                else:
                    self.previous_logits.update(
                        logits, execute_model_req.seq_group_metadata_list)

        if not skip_proposer:
            # We prepare the prefill hidden states here so that there no
            # additional complexity in worker for spec_decode vs
            # non_spec_decode flow and execute_model doesn't need additional
            # modifications.
            execute_model_req.previous_hidden_states = \
                prepare_prefill_hidden_states(
                    sampler_output.prefill_hidden_states)
            for prefill_step in range(self._num_spec_prefill_steps):
                execute_model_req.spec_step_idx = prefill_step
                self.proposer_worker.execute_model(execute_model_req)

        if self._disable_logprobs:
            sampler_output_to_return = \
                self._serialize_sampler_output_no_logprobs(
                    execute_model_req=execute_model_req,
                    sampler_output=sampler_output)
        else:
            sampler_output_to_return = [sampler_output]

        # Clear device tensors from sampler output. This reduces communication
        # overhead when the engine runs in a different process than the
        # workers.
        sampler_output.sampled_token_probs = None
        sampler_output.sampled_token_ids = None
        sampler_output.logprobs = None
        return sampler_output_to_return

    @nvtx_range("spec_decode_worker._verify_tokens")
    def _verify_tokens(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        proposal_scores: SpeculativeScores,
        proposals: SpeculativeProposals,
        max_proposal_len: int,
    ) -> Tuple[torch.Tensor, torch.Tensor, List[List[int]], List[int]]:
        """Decide which speculative tokens are accepted.

        Acceptance uses the probabilities of each token according to the
        proposer and the scorer.

        Returns:
            ``(accepted_token_ids, logprobs, select_indices_list,
            accept_lengths)``. The last two are only populated when tree
            candidates are present; otherwise they are ``[]`` and ``None``.
        """
        proposal_lens_list = proposals.proposal_lens

        # vLLM currently only supports proposal lens equal to zero or the
        # batch proposal len. This adds some complexity (splitting the batch
        # into spec and non spec sequences) and should be removed in the
        # future. It can be done by supporting per-sequence proposal lens.
        (_, spec_indices), (_, non_spec_indices) = split_batch_by_proposal_len(
            seq_group_metadata_list, proposal_lens_list)
        original_indices = spec_indices + non_spec_indices
        has_non_spec = bool(non_spec_indices)

        # Get probabilities of target model, including bonus tokens.
        proposal_verifier_probs = proposal_scores.probs
        if has_non_spec:
            proposal_verifier_probs = proposal_verifier_probs[spec_indices]
        if self.tree_decoding:
            retrieve_indices = proposals.retrieve_indices
            proposal_verifier_probs = \
                proposal_verifier_probs[:, retrieve_indices]

        # Get non-speculative sampled tokens from target model.
        non_spec_token_ids = proposal_scores.token_ids[non_spec_indices]

        # Get bonus tokens from target model.
        bonus_token_ids = proposal_scores.token_ids[:, -1:]
        if has_non_spec:
            bonus_token_ids = bonus_token_ids[spec_indices, :]

        # Get probabilities according to proposal method.
        proposal_probs = proposals.proposal_probs
        if proposal_probs is not None and has_non_spec:
            proposal_probs = proposal_probs[spec_indices]

        # Get proposed tokens.
        proposal_token_ids = proposals.proposal_token_ids
        if has_non_spec:
            proposal_token_ids = proposal_token_ids[spec_indices]

        # Get tree buffers.
        cart_candidates = proposals.cart_candidates
        if cart_candidates is not None and has_non_spec:
            cart_candidates = cart_candidates[spec_indices]

        # Sampler arguments
        sampler_extra_kwargs: Dict[str, Any] = {}
        if self.generators and isinstance(self.spec_decode_sampler,
                                          SpecDecodeStochasticBaseSampler):
            sampler_extra_kwargs["seeded_seqs"] = {
                idx: self.generators[sgm.request_id]
                for idx, sgm in enumerate(seq_group_metadata_list)
                if sgm.sampling_params.seed is not None
            }
        if isinstance(self.spec_decode_sampler, TypicalAcceptanceSampler):
            sampler_extra_kwargs["cart_candidates"] = cart_candidates
            # These two lists are read back after the sampler call below;
            # presumably the sampler fills them in place -- TODO confirm.
            sampler_extra_kwargs["best_candidates"] = []
            sampler_extra_kwargs["accept_lengths"] = []
            sampler_extra_kwargs["first_step_flags"] = [
                bool(next(iter(sgm.seq_data.values())).get_first_step_flag())
                for sgm in seq_group_metadata_list
            ]

        accepted_token_ids = self.spec_decode_sampler(
            target_with_bonus_probs=proposal_verifier_probs,
            bonus_token_ids=bonus_token_ids,
            draft_probs=proposal_probs,
            draft_token_ids=proposal_token_ids,
            **sampler_extra_kwargs,
        )

        # Append output tokens from non-speculative sequences to
        # the accepted token ids tensor. Only the first column carries a
        # token; the remaining columns are marked -1.
        if self.tree_decoding:
            padded_width = max_proposal_len
        else:
            padded_width = max_proposal_len + 1
        non_spec_token_ids = non_spec_token_ids.expand(-1,
                                                       padded_width).clone()
        non_spec_token_ids[:, 1:] = -1
        accepted_token_ids = torch.cat(
            [accepted_token_ids, non_spec_token_ids])
        logprobs = proposal_scores.logprobs

        # Rearrange so that results are in the order of the original seq
        # group metadata.
        original_indices = async_tensor_h2d(original_indices, torch.int32,
                                            self.device, True)
        accepted_token_ids[original_indices] = accepted_token_ids.clone()

        # B x K+1 x D
        hidden_states = proposal_scores.hidden_states
        accept_lengths = None
        select_indices_list = []

        if cart_candidates is None:
            if hidden_states is not None:
                # Only get terminal hidden states for next step
                terminal_metadata = [
                    sg for sg in seq_group_metadata_list if sg.do_sample
                ]

                # Contract hidden states based on accepted tokens
                hs_size = hidden_states.shape[-1]
                accepted_index = accepted_token_ids + 1  # Convert -1 to 0
                accepted_index = accepted_index.count_nonzero(dim=1).add_(
                    -1)  # b
                # Drop non-terminal prefill chunks hidden states.
                hidden_states = hidden_states[accepted_index !=
                                              VLLM_INVALID_TOKEN_ID]
                accepted_index = accepted_index[accepted_index !=
                                                VLLM_INVALID_TOKEN_ID]
                assert len(accepted_index) == hidden_states.shape[0] == len(
                    terminal_metadata)
                gather_index = accepted_index[:, None, None].expand(
                    -1, 1, hs_size)  # b x 1 x d
                second_last_token_hidden_states = hidden_states[:, -2]  # b x d
                hidden_states = hidden_states.gather(
                    1, gather_index).squeeze(1)  # b x d
                # Store hidden states from target model for subsequent decode
                # step
                self.previous_hidden_states = HiddenStates(
                    hidden_states, terminal_metadata,
                    second_last_token_hidden_states)
        else:
            retrieve_indices = proposals.retrieve_indices
            batch_size = len(seq_group_metadata_list)
            best_candidates = sampler_extra_kwargs["best_candidates"]
            accept_lengths = sampler_extra_kwargs["accept_lengths"]

            # Contract hidden states based on accepted tokens
            hs_size = hidden_states.shape[-1]
            hidden_states = hidden_states.view(batch_size, -1, hs_size)

            # Store logits from target model for subsequent proposal
            logits = proposal_scores.logits
            logits = logits.view(batch_size, -1, logits.shape[-1])
            # [batch_size, retrieve_size, max_depth, vocab_size]
            logits = logits[:, retrieve_indices]

            chosen_logits = []
            chosen_hidden_states = []
            retrieve_indices = retrieve_indices.cpu()
            for row in range(batch_size):
                chosen_logits.append(logits[row, best_candidates[row],
                                            accept_lengths[row]].unsqueeze(0))
                # Positions along the best candidate path up to and including
                # the accepted length; the last one supplies the hidden state.
                select_indices = retrieve_indices[
                    best_candidates[row], :accept_lengths[row] + 1]
                select_indices_list.append(select_indices)
                chosen_hidden_states.append(
                    hidden_states[row, select_indices[-1]].unsqueeze(0))

            logits = torch.cat(chosen_logits, dim=0)
            self.previous_logits = Logits(logits, seq_group_metadata_list)
            # One row per sequence: [batch_size, hidden_size]
            hidden_states = torch.cat(chosen_hidden_states, dim=0)
            self.previous_hidden_states = HiddenStates(
                hidden_states,
                seq_group_metadata_list,
            )

        return accepted_token_ids, logprobs, select_indices_list, accept_lengths

    def _create_output_sampler_list(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        accepted_token_ids: torch.Tensor,  # shape: [batch_size, k+1]
        target_logprobs: torch.Tensor,  # shape: [batch_size, k+1, vocab_size]
        prompt_logprobs: Optional[
            torch.Tensor],  # shape: [nprompt_tokens, vocab_size]
        k: int,
        stage_times: Tuple[float, float, float],
    ) -> List[SamplerOutput]:
        """Build the list of ``SamplerOutput`` for the accepted token ids.

        Prefill sequences get one output each; decode sequences get one
        output per step. Decode entries carry ``token_id=0`` here: the real
        ids stay in the ``accepted_token_ids`` tensor that is published via
        ``record_accepted_token_ids``.
        """
        batch_size, num_steps = accepted_token_ids.shape
        accepted_token_ids_by_step = accepted_token_ids.transpose(0, 1)
        max_logprobs = self.scorer_worker.model_config.max_logprobs

        if self._disable_logprobs:
            # We are skipping the logprobs. Hence don't serialize the
            # logprobs related tensors from the GPU. Instead create
            # empty/dummy lists.
            (accepted_token_id_ranks_by_step,
             accepted_token_id_logprobs_by_step, topk_logprobs_by_step,
             topk_indices_by_step) = self._create_dummy_logprob_lists(
                 batch_size, num_steps, max_logprobs)
        else:
            # Organize input tensors by step instead of by sequence.
            target_logprobs_by_step = target_logprobs.transpose(0, 1)
            # Serialize all tensors into Python lists.
            (accepted_token_id_ranks_by_step,
             accepted_token_id_logprobs_by_step, topk_logprobs_by_step,
             topk_indices_by_step) = self._create_logprob_lists_from_tensors(
                 target_logprobs_by_step, accepted_token_ids_by_step,
                 max_logprobs)

        # Get the sequence ids and num_logprobs (sampling parameter) in the
        # batch.
        seq_ids, request_ids_seq_ids_mapping = get_all_seq_ids_and_request_ids(
            seq_group_metadata_list)
        num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list)

        # The token ids are not converted to a Python list here; the tensor
        # is published together with the sequence ids instead.
        record_accepted_token_ids(accepted_token_ids, seq_ids)

        # Construct the output on a per-step, per-sequence basis.
        # Non-terminal prefill chunks will end up here as rows with just -1s
        # i.e mixed-batch [[-1, 1576], [-1, 29884], [-1, -1], [-1, -1]] while
        # terminal chunks will only have one generated token at time 0.
        sampler_output_list: List[SamplerOutput] = []

        # Prefills are not multi-step (return at most 1 token), in order to
        # avoid padding or repetition to fit decodes, we separate them.
        for i, sg in enumerate(seq_group_metadata_list):
            if not sg.is_prompt:
                # Requests are ordered as prefills|decodes=>no more prefills.
                break

            num_logprobs = num_logprobs_per_seq[i]
            seq_kwargs = {
                "token_id": -1,
                "token_id_logprob_rank": 0,
                "token_id_logprob": -float('inf'),
                "topk_token_ids": [-1] * num_logprobs,
                "topk_logprobs": [-float('inf')] * num_logprobs,
                "seq_id": seq_ids[i],
            }
            # Terminal chunk, has token.
            if sg.do_sample:
                # output only so step is 0
                seq_kwargs["token_id"] = accepted_token_ids[i][0].item()
                seq_kwargs["token_id_logprob_rank"] = \
                    accepted_token_id_ranks_by_step[0][i]
                seq_kwargs["token_id_logprob"] = \
                    accepted_token_id_logprobs_by_step[0][i]
                seq_kwargs["topk_token_ids"] = \
                    topk_indices_by_step[0][i][:num_logprobs]
                seq_kwargs["topk_logprobs"] = \
                    topk_logprobs_by_step[0][i][:num_logprobs]

            needs_plogs = (sg.sampling_params.prompt_logprobs
                           and sg.sampling_params.prompt_logprobs > 0)
            plogs = None
            if prompt_logprobs is not None:
                # Even non-terminal prompt chunks can have logprobs here.
                plogs = prompt_logprobs[i]
            elif needs_plogs:
                # Prompt logprobs are requested but `_disable_logprobs` is
                # set, so placeholder entries are produced.
                seq_data = next(iter(sg.seq_data.values()))
                # Get only the tokens in this chunk!
                chunk_start = seq_data._num_computed_tokens
                prompt_token_ids = seq_data.get_prompt_token_ids()
                prompt_token_ids = prompt_token_ids[
                    chunk_start:seq_data._num_computed_tokens +
                    sg.token_chunk_size]
                # There's no prob generated for the first token in a
                # sequence.
                if seq_data._num_computed_tokens == 0:
                    prompt_token_ids = prompt_token_ids[1:]
                plogs = [
                    create_logprobs_output(
                        token_id=p_token_id,
                        token_id_logprob_rank=-1,
                        token_id_logprob=0.0,
                        topk_token_ids=[],
                        topk_logprobs=[],
                    ) for p_token_id in prompt_token_ids
                ]
            seq_kwargs["prompt_logprobs"] = plogs

            sampler_output_list.append(
                SamplerOutput(
                    outputs=[create_sequence_group_output(
                        **seq_kwargs)]))  # type: ignore

        # Decodes, create one SamplerOutput per-step (at most K+1). No step
        # is skipped, because the token ids are not inspected here.
        for step_index in range(num_steps):
            step_output_token_ids: List[CompletionSequenceGroupOutput] = [
                create_sequence_group_output(
                    token_id=0,
                    token_id_logprob_rank=accepted_token_id_ranks_by_step[
                        step_index][sequence_index],
                    token_id_logprob=accepted_token_id_logprobs_by_step[
                        step_index][sequence_index],
                    seq_id=seq_ids[sequence_index],
                    # Each sequence may have a different num_logprobs.
                    topk_token_ids=topk_indices_by_step[step_index]
                    [sequence_index][:num_logprobs_per_seq[sequence_index]],
                    topk_logprobs=topk_logprobs_by_step[step_index]
                    [sequence_index][:num_logprobs_per_seq[sequence_index]],
                )
                for sequence_index in range(batch_size)
                # Prompts already processed above.
                if not seq_group_metadata_list[sequence_index].is_prompt
            ]
            sampler_output_list.append(
                SamplerOutput(outputs=step_output_token_ids))

        # Populate the data structures needed to keep track of sequences with
        # bonus tokens.
        self._track_sequences_with_bonus_tokens(seq_ids,
                                                request_ids_seq_ids_mapping,
                                                accepted_token_ids_by_step)

        maybe_rejsample_metrics = (
            self._metrics.maybe_collect_rejsample_metrics(k))
        if maybe_rejsample_metrics is not None and sampler_output_list:
            sampler_output_list[
                0].spec_decode_worker_metrics = maybe_rejsample_metrics

        # Log time spent in each stage periodically.
        # This is periodic because the rejection sampler emits metrics
        # periodically.
        self._maybe_log_stage_times(*stage_times)

        # First `n_prefills` entries will contain prefills SamplerOutput when
        # chunked prefill is enabled, the rest is decodes in multi-step
        # format.
        return sampler_output_list

    def _track_sequences_with_bonus_tokens(
            self, seq_ids: List[int],
            request_ids_seq_ids_mapping: Dict[str, Set[int]],
            accepted_token_ids_by_step: List[List[int]]):
        """Update the bookkeeping of sequences with bonus tokens.

        Every sequence id is recorded unconditionally;
        ``accepted_token_ids_by_step`` is not inspected.
        """
        for seq_id in seq_ids:
            self._seq_with_bonus_token_in_last_step.add(seq_id)
        for request_id, sequences in request_ids_seq_ids_mapping.items():
            self._request_id_seq_id_mapping[request_id].update(sequences)
\ No newline at end of file
import os
from typing import List, Optional, Set, Tuple
import torch
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeProposer)
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.spec_decode.util import sampler_output_to_torch
from vllm.utils import async_tensor_h2d
from vllm.zero_overhead.utils import record_proposal_lens_list
class ZeroOverheadTop1Proposer(Top1Proposer):
    """Top-1 proposer whose output merge uploads index lists asynchronously
    and publishes the per-sequence proposal lengths as a Python list."""

    def _merge_outputs(
        self,
        batch_size: int,
        proposal_len: int,
        maybe_sampler_output: Optional[List[SamplerOutput]],
        proposal_lens: List[int],
        nonzero_proposal_len_indices: List[int],
        sampler_transposed: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Merge speculation results with the sequences that were skipped.

        Returns:
            ``(proposal_tokens, proposal_probs, proposal_lens_tensor)`` with
            one row per sequence in the batch. Sequences without a proposal
            get ``-1`` tokens, zero probabilities and length ``0``.
        """
        if maybe_sampler_output is None:
            # If no speculative tokens, the sampler output will be None.
            # In this case we return empty proposals.
            def _broadcast(value, dtype, *shape):
                scalar = torch.tensor(value, dtype=dtype, device=self._device)
                return scalar.expand(*shape)

            return (
                _broadcast(-1, torch.long, batch_size, proposal_len),
                _broadcast(0, torch.float32, batch_size, proposal_len,
                           self._vocab_size),
                _broadcast(0, torch.long, len(proposal_lens)),
            )

        proposal_tokens, proposal_probs, *_ = sampler_output_to_torch(
            maybe_sampler_output, sampler_transposed)

        # The proposal lengths are built on the host and published before
        # any device work is queued.
        proposal_lens_list = [0] * batch_size
        for seq_index in nonzero_proposal_len_indices:
            proposal_lens_list[seq_index] = proposal_len
        record_proposal_lens_list(proposal_lens_list)

        scatter_index = async_tensor_h2d(nonzero_proposal_len_indices,
                                         torch.int32, self._device, True)

        # Now, reformat the output GPU tensors such that each sequence has
        # a proposal. the proposal can be empty, e.g. [-1, -1, -1]
        entire_proposal_tokens = proposal_tokens.new_full(
            size=(batch_size, *proposal_tokens.shape[1:]),
            fill_value=-1,
        )
        entire_proposal_tokens[scatter_index] = proposal_tokens

        entire_proposal_probs = proposal_probs.new_zeros(
            batch_size,
            *proposal_probs.shape[1:],
        )
        entire_proposal_probs[scatter_index] = proposal_probs

        proposal_lens_tensor = async_tensor_h2d(proposal_lens_list,
                                                torch.long, self._device,
                                                True)
        return (entire_proposal_tokens, entire_proposal_probs,
                proposal_lens_tensor)
\ No newline at end of file
from typing import Optional
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import SequenceStatus
from vllm.zero_overhead.sequence import ZeroOverheadSequence
class ZeroOverheadStopChecker(StopChecker):
    """Stop checker that reads sequence lengths and the last token through
    the ``zero_overhead_*`` accessors of ``ZeroOverheadSequence``."""

    def __init__(self, max_model_len, get_tokenizer_for_seq):
        super().__init__(max_model_len, get_tokenizer_for_seq)

    def maybe_stop_sequence(
        self,
        seq: ZeroOverheadSequence,
        new_char_count: int,
        sampling_params: SamplingParams,
        lora_req: Optional[LoRARequest] = None,
    ) -> None:
        """Mark ``seq`` as finished if any stop condition is met.

        Args:
            seq: Sequence to inspect; ``status``, ``stop_reason`` and
                ``output_text`` may be modified.
            new_char_count: Number of chars added to the sequence's output
                text for the newly generated token.
            sampling_params: Supplies min/max tokens, stop tokens and stop
                strings.
            lora_req: Forwarded to ``_get_max_model_len``.
        """
        # Until the minimum number of tokens has been generated, none of
        # the checks below apply.
        if seq.zero_overhead_get_output_len() < sampling_params.min_tokens:
            return

        # Whether the text of the newest token has to be cut off again when
        # the sequence stops on that token.
        drop_new_text = bool(new_char_count) and (
            not sampling_params.include_stop_str_in_output)

        # EOS token. Its text is removed unless explicitly requested, which
        # prevents unintended exposure of the EOS token.
        eos_enabled = not sampling_params.ignore_eos
        if (eos_enabled
                and seq.zero_overhead_get_last_token_id() == seq.eos_token_id):
            if drop_new_text:
                seq.output_text = seq.output_text[:-new_char_count]
            seq.status = SequenceStatus.FINISHED_STOPPED
            return

        # Stop tokens. This assumes a single token produced per step.
        last_token_id = seq.zero_overhead_get_last_token_id()
        stop_token_ids = sampling_params.stop_token_ids or ()
        if last_token_id in stop_token_ids:
            if drop_new_text:
                # Remove last token
                seq.output_text = seq.output_text[:-new_char_count]
            seq.status = SequenceStatus.FINISHED_STOPPED
            seq.stop_reason = last_token_id
            return

        # Stop strings.
        match = self.check_stop_strings(
            seq.output_text, new_char_count, sampling_params.stop,
            sampling_params.include_stop_str_in_output)
        if match is not None:
            stop_str, truncate_to = match
            if truncate_to != -1:
                seq.output_text = seq.output_text[:truncate_to]
            seq.status = SequenceStatus.FINISHED_STOPPED
            seq.stop_reason = stop_str
            return

        # Length limits: max_model_len first, then max_tokens.
        if seq.zero_overhead_get_len() > self._get_max_model_len(lora_req):
            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
            return

        if seq.zero_overhead_get_output_len() == sampling_params.max_tokens:
            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
            return
\ No newline at end of file
from vllm.sampling_params import SamplingParams
from vllm.sequence import VLLM_INVALID_TOKEN_ID
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.detokenizer_utils import convert_prompt_ids_to_tokens, detokenize_incrementally
from vllm.zero_overhead.sequence import ZeroOverheadSequence
class ZeroOverheadDetokenizer(Detokenizer):
    """Detokenizer that only looks at the first ``effective_output_len``
    output tokens of a ``ZeroOverheadSequence``."""

    def __init__(self, tokenizer_group):
        super().__init__(tokenizer_group)

    def decode_sequence_inplace(self, seq: ZeroOverheadSequence,
                                prms: SamplingParams) -> int:
        """Decode the newest effective token of ``seq`` in place.

        Args:
            seq: The sequence to decode.
            prms: The sampling parameters used to generate the sequence.

        Returns:
            The number of characters added to the output text.
        """
        # Token ids beyond the effective output length are ignored.
        visible_len = seq.get_prompt_len() + seq.effective_output_len
        all_input_ids = seq.get_token_ids()[:visible_len]
        newest_token_id = all_input_ids[-1]
        tokenizer = self.get_tokenizer_for_seq(seq)

        # Convert prompt token IDs to tokens if necessary.
        # Do it here so that we don't have to repeat this
        # computation for each logprob.
        if seq.tokens is None:
            (seq.tokens, seq.prefix_offset,
             seq.read_offset) = convert_prompt_ids_to_tokens(
                 tokenizer=tokenizer,
                 prompt_ids=all_input_ids[:-1],
                 skip_special_tokens=prms.skip_special_tokens,
             )

        # Arguments shared by every incremental detokenization below; the
        # sequence state they refer to is only updated at the very end.
        shared_kwargs = dict(
            tokenizer=tokenizer,
            prev_tokens=seq.tokens,
            prefix_offset=seq.prefix_offset,
            read_offset=seq.read_offset,
            skip_special_tokens=prms.skip_special_tokens,
            spaces_between_special_tokens=prms.spaces_between_special_tokens,
        )

        (new_tokens, new_decoded_token_text, prefix_offset,
         read_offset) = detokenize_incrementally(all_input_ids=all_input_ids,
                                                 **shared_kwargs)

        # Decode logprobs
        logprobs = seq.output_logprobs[-1]
        if logprobs:
            previous_tokens = all_input_ids[:-1]
            for token_id, sample_logprob in logprobs.items():
                if token_id == newest_token_id:
                    # The token was generated this iteration, so the text
                    # decoded above is reused.
                    sample_logprob.decoded_token = new_decoded_token_text
                elif (sample_logprob.decoded_token is None
                      and token_id != VLLM_INVALID_TOKEN_ID):
                    (_, alt_text, _, _) = detokenize_incrementally(
                        all_input_ids=previous_tokens + [token_id],
                        **shared_kwargs)
                    sample_logprob.decoded_token = alt_text

        seq.tokens.extend(new_tokens)
        seq.prefix_offset = prefix_offset
        seq.read_offset = read_offset
        seq.output_text += new_decoded_token_text

        return len(new_decoded_token_text)
\ No newline at end of file
from enum import Enum
import os
# Feature switches; the environment is read once, when this module is
# imported, and only the exact value '1' enables a switch.
zero_overhead = os.getenv('VLLM_ZERO_OVERHEAD') == '1'
zero_no_thread = os.getenv('VLLM_ZERO_NO_THREAD') == '1'


def is_zero_overhead():
    """Return whether ``VLLM_ZERO_OVERHEAD`` was ``'1'`` at import time."""
    return zero_overhead


def is_zero_no_thread():
    """Return whether the no-thread variant is active.

    ``VLLM_ZERO_NO_THREAD`` only takes effect together with
    ``VLLM_ZERO_OVERHEAD``.
    """
    if not zero_no_thread:
        return False
    return zero_overhead
class SpecStepKind(Enum):
    """Kind of model execution within speculative decoding.

    The names suggest: a run without speculation (``PREFILL``), the first
    and the following draft steps (``FIRST_PROPOSAL``, ``OTHER_PROPOSAL``)
    and the scoring run (``SCORE_DECODE``).
    """
    KIND_DEFAULT = 0
    PREFILL = 1
    FIRST_PROPOSAL = 2
    OTHER_PROPOSAL = 3
    SCORE_DECODE = 4


class ZeroOverheadSpecContext:
    """Mutable record shared through the module-level ``spec_context``."""

    def __init__(self):
        # Current step kind and the one that was active before it.
        self.step_kind = self.last_step = SpecStepKind.KIND_DEFAULT
        # Values stored by the ``record_*`` functions; None until recorded.
        self.proposal_lens_list = None
        self.proposal_token_ids = None
        self.accepted_token_ids = None
        self.accepted_seq_ids = None


# Single module-wide instance used by all accessors below.
spec_context = ZeroOverheadSpecContext()


def set_spec_step(_step):
    """Make ``_step`` the current step kind, remembering the previous one."""
    ctx = spec_context
    ctx.last_step, ctx.step_kind = ctx.step_kind, _step


def get_spec_step():
    """Return the current step kind."""
    return spec_context.step_kind


def get_spec_last_step():
    """Return the step kind that was current before the latest change."""
    return spec_context.last_step


def record_proposal_lens_list(list):
    """Store the proposal lengths.

    The parameter name shadows the builtin ``list``; it is kept unchanged
    for compatibility with existing callers.
    """
    spec_context.proposal_lens_list = list


def get_proposal_lens_list():
    """Return the stored proposal lengths (None if never recorded)."""
    return spec_context.proposal_lens_list


def record_proposal_token_ids(tensor):
    """Store the proposal token ids."""
    spec_context.proposal_token_ids = tensor


def get_proposal_token_ids():
    """Return the stored proposal token ids (None if never recorded)."""
    return spec_context.proposal_token_ids


def record_accepted_token_ids(tensor, seq_ids):
    """Store the accepted token ids together with their sequence ids."""
    ctx = spec_context
    ctx.accepted_token_ids = tensor
    ctx.accepted_seq_ids = seq_ids


def get_accepted_token_ids():
    """Return ``(accepted_token_ids, accepted_seq_ids)`` as last recorded."""
    ctx = spec_context
    return ctx.accepted_token_ids, ctx.accepted_seq_ids
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment