Commit 54294854 authored by lizhigong's avatar lizhigong
Browse files

add v0 zero overhead

parent a0c212c0
...@@ -6,6 +6,8 @@ from contextlib import contextmanager ...@@ -6,6 +6,8 @@ from contextlib import contextmanager
from typing import Iterator, List, Optional, Union from typing import Iterator, List, Optional, Union
import cloudpickle import cloudpickle
from vllm.zero_overhead.v0.llm_engine import ZeroOverheadEngine
from vllm.zero_overhead.v0.utils import is_zero_overhead
import zmq import zmq
from vllm import AsyncEngineArgs, SamplingParams from vllm import AsyncEngineArgs, SamplingParams
...@@ -79,7 +81,10 @@ class MQLLMEngine: ...@@ -79,7 +81,10 @@ class MQLLMEngine:
# the python object to be reused again. # the python object to be reused again.
kwargs['use_cached_outputs'] = True kwargs['use_cached_outputs'] = True
self.engine = LLMEngine(*args, **kwargs) if is_zero_overhead():
self.engine = ZeroOverheadEngine(*args, **kwargs)
else:
self.engine = LLMEngine(*args, **kwargs)
self.log_requests = log_requests self.log_requests = log_requests
self.use_async_sockets = use_async_sockets self.use_async_sockets = use_async_sockets
......
...@@ -43,6 +43,8 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer, ...@@ -43,6 +43,8 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import (Counter, Device, deprecate_args, deprecate_kwargs, from vllm.utils import (Counter, Device, deprecate_args, deprecate_kwargs,
is_list_of) is_list_of)
from vllm.zero_overhead.v0.llm_engine import ZeroOverheadEngine
from vllm.zero_overhead.v0.utils import is_zero_overhead
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -244,8 +246,12 @@ class LLM: ...@@ -244,8 +246,12 @@ class LLM:
) )
# Create the Engine (autoselects V0 vs V1) # Create the Engine (autoselects V0 vs V1)
self.llm_engine = LLMEngine.from_engine_args( if is_zero_overhead():
engine_args=engine_args, usage_context=UsageContext.LLM_CLASS) self.llm_engine = ZeroOverheadEngine.from_engine_args(
engine_args=engine_args, usage_context=UsageContext.LLM_CLASS)
else:
self.llm_engine = LLMEngine.from_engine_args(
engine_args=engine_args, usage_context=UsageContext.LLM_CLASS)
self.engine_class = type(self.llm_engine) self.engine_class = type(self.llm_engine)
self.request_counter = Counter() self.request_counter = Counter()
......
...@@ -21,6 +21,8 @@ from vllm.sequence import (VLLM_INVALID_TOKEN_ID, ...@@ -21,6 +21,8 @@ from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
CompletionSequenceGroupOutput, Logprob, CompletionSequenceGroupOutput, Logprob,
PromptLogprobs, SampleLogprobs, SequenceOutput) PromptLogprobs, SampleLogprobs, SequenceOutput)
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
from vllm.zero_overhead.v0.sampler import ZeroOverheadSampler
from vllm.zero_overhead.v0.utils import is_zero_overhead
if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"): if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
import flashinfer.sampling import flashinfer.sampling
...@@ -38,6 +40,8 @@ def get_sampler() -> torch.nn.Module: ...@@ -38,6 +40,8 @@ def get_sampler() -> torch.nn.Module:
# Lazy import: the v1 package isn't distributed # Lazy import: the v1 package isn't distributed
from vllm.v1.sample.sampler import Sampler as V1Sampler from vllm.v1.sample.sampler import Sampler as V1Sampler
return V1Sampler() return V1Sampler()
if is_zero_overhead():
return ZeroOverheadSampler()
return Sampler() return Sampler()
......
from ctypes import *
import os
import time
import threading
class Prof:
def __init__(self):
self.use_nvtx = os.getenv('VLLM_PROF_NVTX') is not None
self.roc_tracer_flag = False
self.lib = None
if self.use_nvtx:
self.lib = cdll.LoadLibrary("libnvToolsExt.so")
self.lib.nvtxRangePushA.argtypes = [c_char_p]
self.lib.nvtxRangePushA.restype = c_int
self.lib.nvtxRangePop.restype = c_int
self.use_roctx = os.getenv('VLLM_PROF_ROCTX') is not None
if self.use_roctx:
self.lib = cdll.LoadLibrary("libroctracer64.so")
self.lib.roctxRangePushA.argtypes = [c_char_p]
self.lib.roctxRangePushA.restype = c_int
self.lib.roctxRangePop.restype = c_int
self.tm = time.perf_counter()
self.push_depth = {}
def StartTracer(self):
if self.use_roctx:
if self.lib is None:
self.lib = cdll.LoadLibrary("libroctracer64.so")
self.lib.roctracer_start()
self.roc_tracer_flag = True
def StopTracer(self):
if self.use_roctx:
if self.lib is None:
self.lib = cdll.LoadLibrary("libroctracer64.so")
self.lib.roctracer_stop()
self.roc_tracer_flag = False
def thread_depth_add(self, num):
current_thread = threading.current_thread()
thread_id = current_thread.ident
if thread_id not in self.push_depth.keys():
self.push_depth[thread_id] = 0
if num < 0 and self.push_depth[thread_id] == 0:
return False
self.push_depth[thread_id] += num
return True
def ProfRangePush(self, message):
if profile.use_nvtx:
profile.lib.nvtxRangePushA(message.encode('utf-8'))
self.thread_depth_add(1)
if profile.use_roctx and self.roc_tracer_flag:
profile.lib.roctxRangePushA(message.encode('utf-8'))
self.thread_depth_add(1)
def ProfRangePop(self):
if profile.use_nvtx:
if not self.thread_depth_add(-1):
return
profile.lib.nvtxRangePop()
if profile.use_roctx and self.roc_tracer_flag:
if not self.thread_depth_add(-1):
return
profile.lib.roctxRangePop()
def ProfRangeAutoPush(self, message):
self.ProfRangePop()
self.ProfRangePush(message)
profile = Prof()
...@@ -60,6 +60,8 @@ from vllm.worker.model_runner_base import ( ...@@ -60,6 +60,8 @@ from vllm.worker.model_runner_base import (
_add_sampling_metadata_broadcastable_dict, _add_sampling_metadata_broadcastable_dict,
_init_attn_metadata_from_tensor_dict, _init_attn_metadata_from_tensor_dict,
_init_sampling_metadata_from_tensor_dict) _init_sampling_metadata_from_tensor_dict)
from vllm.zero_overhead.v0.model_runner import ZeroOverheadModelInputForGpuBuilder
from vllm.zero_overhead.v0.utils import is_zero_overhead
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.abstract import AttentionBackend
...@@ -1636,6 +1638,8 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1636,6 +1638,8 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
_model_input_cls: Type[ModelInputForGPUWithSamplingMetadata] = ( _model_input_cls: Type[ModelInputForGPUWithSamplingMetadata] = (
ModelInputForGPUWithSamplingMetadata) ModelInputForGPUWithSamplingMetadata)
_builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder _builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder
if is_zero_overhead():
_builder_cls = ZeroOverheadModelInputForGpuBuilder
def make_model_input_from_broadcasted_tensor_dict( def make_model_input_from_broadcasted_tensor_dict(
self, self,
......
from collections import Counter
from functools import partial
import os
import queue
import threading
import traceback
from typing import Callable, Dict, List, Mapping, Optional, Type, Union
from zlib import ZLIB_VERSION
import torch
from vllm import envs
from vllm.config import DecodingConfig, ObservabilityConfig, VllmConfig
from vllm.core.scheduler import ScheduledSequenceGroup
from vllm.engine.llm_engine import _LOCAL_LOGGING_INTERVAL_SEC, LLMEngine, SchedulerContext, SchedulerOutputState
from vllm.engine.metrics_types import StatLoggerBase
from vllm.engine.output_processor.interfaces import SequenceGroupOutputProcessor
from vllm.entrypoints import logger
from vllm.executor.executor_base import ExecutorBase
from vllm.inputs import INPUT_REGISTRY
from vllm.inputs.data import ProcessorInputs
from vllm.inputs.parse import is_encoder_decoder_inputs
from vllm.inputs.preprocess import InputPreprocessor
from vllm.inputs.registry import InputRegistry
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.registry import MultiModalRegistry
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import ExecuteModelRequest, ParallelSampleSequenceGroup, SequenceGroup, SequenceGroupBase, SequenceGroupMetadata
from vllm.tracing import init_tracer
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.usage.usage_lib import UsageContext, is_usage_stats_enabled
from vllm.utils import resolve_obj_by_qualname, weak_bind
from vllm.zero_overhead.v0.sequence import ZeroOverheadSequence
from vllm.zero_overhead.v0.stop_check import ZeroOverheadStopChecker
from vllm.zero_overhead.v0.tokenizer import ZeroOverheadDetokenizer
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message)
from vllm.profiler.prof import profile
class ZeroOverheadEngine(LLMEngine):
def __init__(
self,
vllm_config: VllmConfig,
executor_class: Type[ExecutorBase],
log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
input_registry: InputRegistry = INPUT_REGISTRY,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
use_cached_outputs: bool = False,
) -> None:
if envs.VLLM_USE_V1:
raise ValueError(
"Using V0 LLMEngine, but envs.VLLM_USE_V1=True. "
"This should not happen. As a workaround, try using "
"LLMEngine.from_vllm_config(...) or explicitly set "
"VLLM_USE_V1=0 or 1 and report this issue on Github.")
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.lora_config = vllm_config.lora_config
self.parallel_config = vllm_config.parallel_config
self.scheduler_config = vllm_config.scheduler_config
self.device_config = vllm_config.device_config
self.speculative_config = vllm_config.speculative_config # noqa
self.load_config = vllm_config.load_config
self.decoding_config = vllm_config.decoding_config or DecodingConfig( # noqa
)
self.prompt_adapter_config = vllm_config.prompt_adapter_config # noqa
self.observability_config = vllm_config.observability_config or ObservabilityConfig( # noqa
)
logger.info(
"Initializing a V0 LLM engine (v%s) with config: %s, "
"use_cached_outputs=%s, ",
ZLIB_VERSION,
vllm_config,
use_cached_outputs,
)
self.log_stats = log_stats
self.use_cached_outputs = use_cached_outputs
if not self.model_config.skip_tokenizer_init:
self.tokenizer = self._init_tokenizer()
self.detokenizer = ZeroOverheadDetokenizer(self.tokenizer)
tokenizer_group = self.get_tokenizer_group()
else:
self.tokenizer = None
self.detokenizer = None
tokenizer_group = None
# Ensure that the function doesn't contain a reference to self,
# to avoid engine GC issues
def get_tokenizer_for_seq(sequence: ZeroOverheadSequence) -> AnyTokenizer:
assert tokenizer_group, ("tokenizer_group cannot be None, "
"make sure skip_tokenizer_init is False")
return tokenizer_group.get_lora_tokenizer(sequence.lora_request)
self.seq_counter = Counter()
self.generation_config_fields = (
self.model_config.try_get_generation_config())
self.input_preprocessor = InputPreprocessor(self.model_config,
self.tokenizer,
mm_registry)
self.input_registry = input_registry
self.input_processor = input_registry.create_input_processor(
self.model_config)
self.model_executor = executor_class(vllm_config=vllm_config, )
if self.model_config.runner_type != "pooling":
self._initialize_kv_caches()
# If usage stat is enabled, collect relevant info.
if is_usage_stats_enabled():
from vllm.model_executor.model_loader import (
get_architecture_class_name)
usage_message.report_usage(
get_architecture_class_name(self.model_config),
usage_context,
extra_kvs={
# Common configuration
"dtype":
str(self.model_config.dtype),
"tensor_parallel_size":
self.parallel_config.tensor_parallel_size,
"block_size":
self.cache_config.block_size,
"gpu_memory_utilization":
self.cache_config.gpu_memory_utilization,
# Quantization
"quantization":
self.model_config.quantization,
"kv_cache_dtype":
str(self.cache_config.cache_dtype),
# Feature flags
"enable_lora":
bool(self.lora_config),
"enable_prompt_adapter":
bool(self.prompt_adapter_config),
"enable_prefix_caching":
self.cache_config.enable_prefix_caching,
"enforce_eager":
self.model_config.enforce_eager,
"disable_custom_all_reduce":
self.parallel_config.disable_custom_all_reduce,
})
if self.tokenizer:
# Ping the tokenizer to ensure liveness if it runs in a
# different process.
self.tokenizer.ping()
self.cached_scheduler_outputs = [
SchedulerOutputState()
for _ in range(self.parallel_config.pipeline_parallel_size)
]
self.scheduler_contexts = [
SchedulerContext(multi_step_stream_outputs=self.scheduler_config.
multi_step_stream_outputs)
for _ in range(self.parallel_config.pipeline_parallel_size)
]
if self.model_config.use_async_output_proc:
process_model_outputs = weak_bind(self._process_model_outputs)
self.async_callbacks = [
partial(process_model_outputs,
ctx=self.scheduler_contexts[v_id])
for v_id in range(self.parallel_config.pipeline_parallel_size)
]
else:
self.async_callbacks = []
# Currently used by AsyncLLMEngine to ensure quick append
# of request outputs to asyncio queues
self.process_request_outputs_callback: Optional[Callable] = None
# Create the scheduler.
# NOTE: the cache_config here have been updated with the numbers of
# GPU and CPU blocks, which are profiled in the distributed executor.
if isinstance(self.vllm_config.scheduler_config.scheduler_cls, str):
Scheduler = resolve_obj_by_qualname(
self.vllm_config.scheduler_config.scheduler_cls)
else:
Scheduler = self.vllm_config.scheduler_config.scheduler_cls
self.scheduler = [
Scheduler(
self.scheduler_config, self.cache_config, self.lora_config,
self.parallel_config.pipeline_parallel_size,
self.async_callbacks[v_id]
if self.model_config.use_async_output_proc else None)
for v_id in range(self.parallel_config.pipeline_parallel_size)
]
# Metric Logging.
if self.log_stats:
if stat_loggers is not None:
self.stat_loggers = stat_loggers
else:
# Lazy import for prometheus multiprocessing.
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
# before prometheus_client is imported.
# See https://prometheus.github.io/client_python/multiprocess/
from vllm.engine.metrics import (LoggingStatLogger,
PrometheusStatLogger)
self.stat_loggers = {
"logging":
LoggingStatLogger(
local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
vllm_config=vllm_config),
"prometheus":
PrometheusStatLogger(
local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
labels=dict(
model_name=self.model_config.served_model_name),
vllm_config=vllm_config),
}
self.stat_loggers["prometheus"].info("cache_config",
self.cache_config)
self.tracer = None
if self.observability_config.otlp_traces_endpoint:
self.tracer = init_tracer(
"vllm.llm_engine",
self.observability_config.otlp_traces_endpoint)
# Create sequence output processor, e.g. for beam search or
# speculative decoding.
self.output_processor = (
SequenceGroupOutputProcessor.create_output_processor(
self.scheduler_config,
self.detokenizer,
self.scheduler,
self.seq_counter,
get_tokenizer_for_seq,
stop_checker=ZeroOverheadStopChecker(
self.scheduler_config.max_model_len,
get_tokenizer_for_seq,
),
))
self.tree_decoding = os.environ.get('VLLM_TREE_DECODING') == '1'
self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {}
# Flag to set when an input fails to process and the engine should run
# the next step without re-scheduling.
self._skip_scheduling_next_step = False
self.async_d2h = None
self.last_record = None
self.async_event = torch.cuda.Event(enable_timing=False)
self.zero_thread = threading.Thread(target=self.thread_zero_overhead)
self.q_recorder = queue.Queue()
self.thread_running = True
self.sem_m2s = threading.Semaphore(0) # main to scheduler thread
self.zero_thread.start()
profile.StartTracer()
def __del__(self):
self.finish_thread()
return super().__del__()
def finish_thread(self):
if self.thread_running:
self.thread_running = False
self.sem_m2s.release()
def thread_zero_overhead(self):
try:
while True:
self.sem_m2s.acquire()
if not self.thread_running:
break
virtual_engine = 0
# Clear outputs for each new scheduler iteration
# Schedule iteration
(seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc
) = self.scheduler[virtual_engine].schedule()
last_outputs_ids = None
last_outputs_tensor = None
if self.last_record is not None:
last_output = self.last_record[0][0]
last_outputs_ids, last_outputs_tensor = last_output.sampler_out_ids, last_output.sampler_out_tenosr
self.async_d2h = last_outputs_tensor.to('cpu', non_blocking=True)
self.async_event.record()
self.q_recorder.put(self.last_record)
else:
self.q_recorder.put(None)
if len(seq_group_metadata_list) == 0:
self.last_record = None
continue
finished_requests_ids = self.scheduler[
virtual_engine].get_and_reset_finished_requests_ids()
assert seq_group_metadata_list is not None
assert scheduler_outputs is not None
last_sampled_token_ids = \
self._get_last_sampled_token_ids(virtual_engine)
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
blocks_to_copy=scheduler_outputs.blocks_to_copy,
num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
running_queue_size=scheduler_outputs.running_queue_size,
finished_requests_ids=finished_requests_ids,
# We use ExecuteModelRequest to pass the last sampled_token_ids
# to each of the non-last PP stages for in-place prepare_input.
last_sampled_token_ids=last_sampled_token_ids,
last_outputs_ids = last_outputs_ids,
last_outputs_sample = last_outputs_tensor)
outputs = self.model_executor.execute_model(
execute_model_req=execute_model_req)
if len(outputs) == 1:
self._advance_to_next_step(
outputs[0], seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)
scheduler_outputs.scheduled_seq_groups = [item for item in scheduler_outputs.scheduled_seq_groups] #deep copy
self.last_record = [outputs, seq_group_metadata_list, scheduler_outputs]
except Exception as e:
print(f"thread_zero_overhead error : {e}")
traceback.print_exc()
def zero_overhead_step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
if not self.thread_running:
self.zero_thread.join()
self.thread_running = True
self.zero_thread = threading.Thread(target=self.thread_zero_overhead)
self.zero_thread.start()
self.sem_m2s.release()
recode_output = self.q_recorder.get()
if recode_output is None: # None is for the first step
return None
virtual_engine = 0
ctx = self.scheduler_contexts[virtual_engine]
ctx.request_outputs.clear()
outputs, seq_group_metadata_list, scheduler_outputs = recode_output
ctx.seq_group_metadata_list = seq_group_metadata_list
ctx.scheduler_outputs = scheduler_outputs
self.async_event.synchronize()
self._fix_last_step(
outputs, seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)
# is_first_step_output is True only when the num_steps of all
# the sequences are 1. When the num_steps > 1,
# multi_step_model_runner does the first-step output append.
is_first_step_output: bool = False if not seq_group_metadata_list \
else seq_group_metadata_list[0].state.num_steps == 1
# Add results to the output_queue
ctx.append_output(outputs=outputs,
seq_group_metadata_list=seq_group_metadata_list,
scheduler_outputs=scheduler_outputs,
is_async=True,
is_last_step=True,
is_first_step_output=is_first_step_output)
# Check if need to run the usual non-async path
#if not allow_async_output_proc:
self._process_model_outputs(ctx=ctx)
#profile.ProfRangeAutoPush('has_unfinish')
if not self.has_unfinished_requests():
# Drain async postprocessor (if exists)
if len(ctx.output_queue) > 0:
self._process_model_outputs(ctx=ctx)
assert len(ctx.output_queue) == 0
# Stop the execute model loop in parallel workers until there are
# more requests to process. This avoids waiting indefinitely in
# torch.distributed ops which may otherwise timeout, and unblocks
# the RPC thread in the workers so that they can process any other
# queued control plane messages, such as add/remove lora adapters.
logger.debug("Stopping remote worker execution loop.")
self.model_executor.stop_remote_worker_execution_loop()
return ctx.request_outputs
def _fix_last_step(
self, output: List[SamplerOutput],
seq_group_metadata_list: List[SequenceGroupMetadata],
scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None:
#sample_out_list = output[0].sampler_out_tenosr.cpu().tolist()
sample_out_list = self.async_d2h.tolist()
sample_out_ids = output[0].sampler_out_ids.tolist()
for seq_group_metadata, sequence_group_outputs, scheduled_seq_group in \
zip(seq_group_metadata_list, output[0], scheduled_seq_groups):
seq_group = scheduled_seq_group.seq_group
if seq_group.is_finished():
continue
if seq_group_metadata.do_sample:
sample = sequence_group_outputs.samples[0]
assert len(seq_group.seqs) == 1
seq : ZeroOverheadSequence = seq_group.seqs[0]
for token_id, seq_id in zip(sample_out_list, sample_out_ids):
if seq.seq_id == seq_id:
if type(token_id) is list:
sample.output_token = token_id[0]
else:
sample.output_token = token_id
seq.fix_last_token_id(sample.output_token)
break
def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
out = self.zero_overhead_step()
if out is None: #the first step need launch twice
out = self.zero_overhead_step()
return out
def _add_processed_request(
self,
request_id: str,
processed_inputs: ProcessorInputs,
params: Union[SamplingParams, PoolingParams],
arrival_time: float,
lora_request: Optional[LoRARequest],
prompt_adapter_request: Optional[PromptAdapterRequest],
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
) -> Optional[SequenceGroup]:
"""Add a processed request to the engine's request pool.
return the created sequence group.
"""
if isinstance(params, SamplingParams) and params.n > 1:
ParallelSampleSequenceGroup.add_request(
request_id,
self,
params,
processed_inputs=processed_inputs,
arrival_time=arrival_time,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
)
return None
self._validate_model_inputs(processed_inputs, lora_request)
# Create the sequences.
block_size = self.cache_config.block_size
seq_id = next(self.seq_counter)
eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
if is_encoder_decoder_inputs(processed_inputs):
decoder_inputs = processed_inputs["decoder"]
encoder_inputs = processed_inputs["encoder"]
else:
decoder_inputs = processed_inputs
encoder_inputs = None
seq = ZeroOverheadSequence(seq_id, decoder_inputs, block_size, eos_token_id,
lora_request, prompt_adapter_request)
encoder_seq = (None if encoder_inputs is None else ZeroOverheadSequence(
seq_id, encoder_inputs, block_size, eos_token_id, lora_request,
prompt_adapter_request))
# Create a SequenceGroup based on SamplingParams or PoolingParams
if isinstance(params, SamplingParams):
seq_group = self._create_sequence_group_with_sampling(
request_id,
seq,
params,
arrival_time=arrival_time,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
encoder_seq=encoder_seq,
priority=priority)
elif isinstance(params, PoolingParams):
seq_group = self._create_sequence_group_with_pooling(
request_id,
seq,
params,
arrival_time=arrival_time,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
encoder_seq=encoder_seq,
priority=priority)
else:
raise ValueError(
"Either SamplingParams or PoolingParams must be provided.")
# Add the sequence group to the scheduler with least unfinished seqs.
costs = [
scheduler.get_num_unfinished_seq_groups()
for scheduler in self.scheduler
]
min_cost_scheduler = self.scheduler[costs.index(min(costs))]
min_cost_scheduler.add_seq_group(seq_group)
return seq_group
\ No newline at end of file
import torch
import itertools
from typing import List, Optional, Set
from vllm.lora.layers import LoRAMapping
from vllm.multimodal.inputs import MultiModalKwargs
from vllm.prompt_adapter.layers import PromptAdapterMapping
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import SequenceGroupMetadata
from vllm.utils import async_tensor_h2d, flatten_2d_lists
from vllm.worker.model_runner import ModelInputForGPU, ModelInputForGPUBuilder
from vllm.zero_overhead.v0.sampler import get_last_sampler
from vllm.zero_overhead.v0.update_input import UpdateInputTokens
class ZeroOverheadModelInputForGpuBuilder(ModelInputForGPUBuilder):
def __init__(self, runner, finished_requests_ids = None):
super().__init__(runner, finished_requests_ids)
self.req_ids = []
def prepare(self,
finished_requests_ids: Optional[List[str]] = None) -> None:
self.req_ids.clear()
return super().prepare(finished_requests_ids)
def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
seq_ids = seq_group_metadata.seq_data.keys()
n_seqs = len(seq_ids)
seq_ids = list(seq_ids)
for seq_idx in range(n_seqs):
self.req_ids.append(seq_ids[seq_idx])
return super().add_seq_group(seq_group_metadata)
def build(self) -> ModelInputForGPU:
model_input = super().build()
last_sampler = get_last_sampler()
if last_sampler.sampled_token_ids_tensor is not None:
input_ids = async_tensor_h2d(self.req_ids, torch.long,
self.runner.device,
self.runner.pin_memory)
last_ids = async_tensor_h2d(last_sampler.seq_id.tolist(), torch.long,
self.runner.device,
self.runner.pin_memory)
UpdateInputTokens(model_input.input_tokens, input_ids, last_sampler.sampled_token_ids_tensor, last_ids)
return model_input
from importlib.util import find_spec
from typing import Dict, List, Optional
import torch
from vllm import envs
from vllm.model_executor.layers.rejection_sampler import _multinomial
from vllm.model_executor.layers.sampler import MultinomialSamplesType, SampleMetadataType, \
SampleResultArgsType, SampleResultType, SampleResultsDictType, SampleReturnType, Sampler, \
SamplerOutput, _apply_min_p, _apply_min_tokens_penalty, _apply_top_k_top_p, _build_sampler_output, \
_modify_greedy_probs_inplace, _top_k_top_p_multinomial_with_flashinfer, get_logprobs
from vllm.model_executor.layers.utils import apply_penalties
from vllm.model_executor.sampling_metadata import SamplingMetadata, SamplingTensors, SequenceGroupToSample
from vllm.sampling_params import SamplingType
from vllm.sequence import VLLM_INVALID_TOKEN_ID
if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
import flashinfer.sampling
# yapf: disable
from flashinfer.sampling import (
top_k_top_p_sampling_from_probs as flashinfer_top_k_top_p_sampling)
class SampleRecorder:
def __init__(self):
self.seq_id:torch.Tensor = None
self.sampled_token_ids_tensor:torch.Tensor = None
last_sampler = SampleRecorder()
def get_last_sampler():
return last_sampler
class ZeroOverheadSampler(Sampler):
def __init__(self):
super().__init__()
def forward(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
"""
Single-step scheduling:
* Perform GPU-side sampling computation & compute
GPU-side logprobs tensor
* Pythonize sampling result & logprobs tensor
Multi-step scheduling:
* Perform GPU-side sampling computation & compute
GPU-side logprobs tensor
* Defer Pythonization of sampling result & logprobs
tensor
* Encapsulate arguments required for deferred Pythonization
in the :class:`SamplerOutput` structure
Args:
logits: (num_tokens, vocab_size).
sampling_metadata: Metadata for sampling.
"""
assert logits is not None
_, vocab_size = logits.shape
# Prepare sampling tensors with pinned memory to avoid blocking.
if not sampling_metadata.reuse_sampling_tensors:
self._init_sampling_tensors(logits, sampling_metadata)
elif self._do_penalties:
# In this case, the sampling tensors logic depends on
# "output_tokens" of a sequence. As a result, we cannot
# reuse sampling tensors, since "output_tokens" changes
# between decode runs.
self._init_sampling_tensors(logits, sampling_metadata)
assert self._sampling_tensors is not None
sampling_tensors = self._sampling_tensors
do_penalties = self._do_penalties
do_top_p_top_k = self._do_top_p_top_k
do_min_p = self._do_min_p
logits = _apply_min_tokens_penalty(logits, sampling_metadata)
# Apply presence and frequency penalties.
if do_penalties:
logits = apply_penalties(logits, sampling_tensors.prompt_tokens,
sampling_tensors.output_tokens,
sampling_tensors.presence_penalties,
sampling_tensors.frequency_penalties,
sampling_tensors.repetition_penalties)
# Use float32 to apply temperature scaling.
# Use in-place division to avoid creating a new tensor.
logits = logits.to(torch.float)
logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))
if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None:
logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
sampling_tensors.top_ks)
if do_min_p:
logits = _apply_min_p(logits, sampling_tensors.min_ps)
# We use float32 for probabilities and log probabilities.
# Compute the probabilities.
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
# Compute the log probabilities.
logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
# Sample the next tokens.
maybe_deferred_sample_results, maybe_sampled_tokens_tensor = _sample(
probs,
logprobs,
sampling_metadata,
sampling_tensors,
include_gpu_probs_tensor=self.include_gpu_probs_tensor,
modify_greedy_probs=self._should_modify_greedy_probs_inplace,
)
if self.include_gpu_probs_tensor:
# Since we will defer sampler result Pythonization,
# preserve GPU-side tensors in support of later
# deferred pythonization of logprobs
assert maybe_sampled_tokens_tensor is not None
on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
else:
# Since Pythonization has already happened, don't preserve
# GPU-side tensors.
on_device_tensors = None
# Get the logprobs query results.
prompt_logprobs = None
sample_logprobs = None
if not sampling_metadata.skip_sampler_cpu_output:
# Pythonize logprobs now (GPU -> CPU); do not defer.
assert not isinstance(maybe_deferred_sample_results,
SampleResultArgsType)
prompt_logprobs, sample_logprobs = get_logprobs(
logprobs, sampling_metadata, maybe_deferred_sample_results)
return _build_sampler_output(
maybe_deferred_sample_results,
sampling_metadata,
prompt_logprobs,
sample_logprobs,
on_device_tensors=on_device_tensors,
skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output,
logits=logits)
def _greedy_sample(
selected_seq_groups: List[SequenceGroupToSample],
samples: torch.Tensor,
) -> SampleResultType:
"""Run greedy sampling on a given samples.
Args:
selected_seq_groups: A list of sequence groups batched.
samples: (num_selected_samples,) A tensor of samples. The length of
samples could be smaller than selected_seq_groups if
seq_group.do_sample is False.
Returns:
Tuple of (next_token_ids, parent_ids). The length of returned list is
same as the length of selected_seq_groups. If the corresponding
seq_group has do_sample=False, tuple contains ([], [])
"""
sample_idx = 0
results: SampleResultType = []
for seq_group in selected_seq_groups:
if not seq_group.do_sample:
results.append(([], []))
continue
seq_ids = seq_group.seq_ids
num_parent_seqs = len(seq_ids)
assert num_parent_seqs == 1, (
"Greedy sampling should have only one seq.")
parent_ids = list(range(num_parent_seqs))
assert num_parent_seqs == 1 # not support muti seqences in seqence group
next_token_ids = [0] #place holder token id
results.append((next_token_ids, parent_ids))
sample_idx += num_parent_seqs
return results
def _random_sample(
selected_seq_groups: List[SequenceGroupToSample],
random_samples: torch.Tensor,
) -> SampleResultType:
"""Run random sampling on a given samples.
Args:
selected_seq_groups: A list of sequence groups batched.
random_samples: (num_selected_samples,) A tensor of samples. The
length of samples could be smaller than selected_seq_groups if
seq_group.do_sample is False.
Returns:
Tuple of (next_token_ids, parent_ids). The length of returned list is
same as the length of selected_seq_groups. If the corresponding
seq_group has do_sample=False, tuple contains ([], [])
"""
# Find the maximum n value of the prompt phase requests.
sample_idx = 0
results: SampleResultType = []
for seq_group in selected_seq_groups:
if not seq_group.do_sample:
results.append(([], []))
continue
seq_ids = seq_group.seq_ids
sampling_params = seq_group.sampling_params
is_prompt = seq_group.is_prompt
num_parent_seqs = len(seq_ids)
if is_prompt:
# Prompt phase.
parent_ids = [0] * sampling_params.n
assert num_parent_seqs == 1 # not support muti seqences in seqence group
next_token_ids = [0] * sampling_params.n #place holder token id
else:
# Generation phase.
parent_ids = list(range(num_parent_seqs))
assert num_parent_seqs == 1 # not support muti seqences in seqence group
next_token_ids = [0] * num_parent_seqs #place holder token id
results.append((next_token_ids, parent_ids))
sample_idx += num_parent_seqs
return results
def _sample(
probs: torch.Tensor,
logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata,
sampling_tensors: SamplingTensors,
include_gpu_probs_tensor: bool,
modify_greedy_probs: bool,
) -> SampleReturnType:
"""
Args:
probs: (num_query_tokens_in_batch, num_vocab)
logprobs: (num_query_tokens_in_batch, num_vocab)
sampling_metadata: The metadata for a batch for sampling.
sampling_tensors: Tensors that include sampling related metadata.
Returns:
(next_token_ids, parent_seq_ids) for each seq group in a batch.
If sampling is skipped, it returns ([], [])
sampled_token_ids_tensor: A tensor of sampled token ids.
"""
return _sample_with_torch(
probs,
logprobs,
sampling_metadata,
sampling_tensors,
include_gpu_probs_tensor=include_gpu_probs_tensor,
modify_greedy_probs=modify_greedy_probs,
)
def _sample_with_torch(
probs: torch.Tensor,
logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata,
sampling_tensors: SamplingTensors,
include_gpu_probs_tensor: bool,
modify_greedy_probs: bool,
) -> SampleReturnType:
'''Torch-oriented _sample() implementation.
Single-step scheduling:
* Perform GPU-side sampling computation
* Immediately Pythonize sampling result
Multi-step scheduling:
* Perform GPU-side sampling computation
* Defer Pythonization & preserve GPU-side
tensors required for Pythonization
'''
categorized_seq_group_ids: Dict[SamplingType, List[int]] = {
t: []
for t in SamplingType
}
last_sampler.seq_id = torch.zeros(len(sampling_metadata.seq_groups), dtype=torch.int32)
categorized_sample_indices = sampling_metadata.categorized_sample_indices
for i, seq_group in enumerate(sampling_metadata.seq_groups):
last_sampler.seq_id[i] = seq_group.seq_ids[0]
sampling_params = seq_group.sampling_params
sampling_type = sampling_params.sampling_type
categorized_seq_group_ids[sampling_type].append(i)
sample_results_dict: SampleResultsDictType = {}
sample_metadata: SampleMetadataType = {}
multinomial_samples: MultinomialSamplesType = {}
greedy_samples: Optional[torch.Tensor] = None
beam_search_logprobs: Optional[torch.Tensor] = None
# Create output tensor for sampled token ids.
if include_gpu_probs_tensor:
sampled_token_ids_tensor = torch.full((logprobs.shape[0], 1),
VLLM_INVALID_TOKEN_ID,
dtype=torch.long,
device=logprobs.device)
else:
sampled_token_ids_tensor = None
# Counterintiutively, having two loops here is actually faster.
# The first loop can run without waiting on GPU<->CPU sync.
for sampling_type in SamplingType:
sample_indices = categorized_sample_indices[sampling_type]
num_tokens = len(sample_indices)
if num_tokens == 0:
continue
seq_group_id = categorized_seq_group_ids[sampling_type]
seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
sample_metadata[sampling_type] = (seq_group_id, seq_groups)
long_sample_indices = sample_indices.long()
if sampling_type == SamplingType.GREEDY:
greedy_samples = torch.argmax(logprobs[long_sample_indices],
dim=-1)
last_sampler.sampled_token_ids_tensor = greedy_samples.unsqueeze(-1)
if sampled_token_ids_tensor is not None:
# Store sampled tokens in output tensor.
sampled_token_ids_tensor[
long_sample_indices] = greedy_samples.unsqueeze(-1)
if modify_greedy_probs:
# If required, modify the probabilities such that sampling from
# the modified distribution would always sample the argmax
# token id.
_modify_greedy_probs_inplace(logprobs, probs,
long_sample_indices,
greedy_samples)
elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
max_n_in_batch = 1
for seq_group in seq_groups:
if seq_group.is_prompt:
sampling_params = seq_group.sampling_params
max_n_in_batch = max(max_n_in_batch, sampling_params.n)
seq_groups_arg = (None if sampling_type == SamplingType.RANDOM else
seq_groups)
if flashinfer_top_k_top_p_sampling is not None:
multinomial_samples[
sampling_type] = _top_k_top_p_multinomial_with_flashinfer(
probs[long_sample_indices],
sampling_tensors.top_ks[long_sample_indices],
sampling_tensors.top_ps[long_sample_indices],
max_n_in_batch,
seq_groups_arg,
)
else:
multinomial_samples[sampling_type] = _multinomial(
probs[long_sample_indices],
max_n_in_batch,
seq_groups=seq_groups_arg)
last_sampler.sampled_token_ids_tensor = \
multinomial_samples[sampling_type].to(torch.long)
if sampled_token_ids_tensor is not None:
# Store sampled tokens in output tensor.
sampled_token_ids_tensor[long_sample_indices] = \
multinomial_samples[sampling_type].to(torch.long)
elif sampling_type == SamplingType.BEAM:
beam_search_logprobs = logprobs[sample_indices]
else:
raise ValueError(f"Unsupported sampling type: {sampling_type}")
# Encapsulate arguments for computing Pythonized sampler
# results, whether deferred or otherwise.
maybe_deferred_args = SampleResultArgsType(
sampling_metadata=sampling_metadata,
sample_metadata=sample_metadata,
multinomial_samples=multinomial_samples,
greedy_samples=greedy_samples,
beam_search_logprobs=beam_search_logprobs,
sample_results_dict=sample_results_dict)
if not sampling_metadata.skip_sampler_cpu_output:
# GPU<->CPU sync happens here.
# This also converts the sampler output to a Python object.
# Return Pythonized sampler result & sampled token ids
return get_pythonized_sample_results(
maybe_deferred_args), sampled_token_ids_tensor
else:
# Defer sampler result Pythonization; return deferred
# Pythonization args & sampled token ids
return (
maybe_deferred_args,
sampled_token_ids_tensor,
)
def get_pythonized_sample_results(
sample_result_args: SampleResultArgsType) -> SampleResultType:
'''This function consumes GPU-side sampler results and computes
Pythonized CPU-side sampler results (GPU -> CPU sync.)
Single-step scheduling: this function is invoked at sampling-time
for immediate Pythonization.
Multi-step scheduling: Pythonization is deferred until after multiple
GPU-side steps have been completed.
Args:
sample_result_args: GPU-side inputs to the Pythonization process
Returns:
Pythonized sampler results
'''
(
sample_metadata,
sampling_metadata,
greedy_samples,
multinomial_samples,
sample_results_dict,
) = (
sample_result_args.sample_metadata,
sample_result_args.sampling_metadata,
sample_result_args.greedy_samples,
sample_result_args.multinomial_samples,
sample_result_args.sample_results_dict,
)
for sampling_type in SamplingType:
if sampling_type not in sample_metadata:
continue
(seq_group_id, seq_groups) = sample_metadata[sampling_type]
if sampling_type == SamplingType.GREEDY:
sample_results = _greedy_sample(seq_groups, greedy_samples)
elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
sample_results = _random_sample(seq_groups,
multinomial_samples[sampling_type])
sample_results_dict.update(zip(seq_group_id, sample_results))
return [
sample_results_dict.get(i, ([], []))
for i in range(len(sampling_metadata.seq_groups))
]
\ No newline at end of file
from typing import Union
from vllm.sequence import Sequence
from typing import Sequence as GenericSequence
class ZeroOverheadSequence(Sequence):
def __init__(self, seq_id, inputs, block_size, eos_token_id = None, lora_request = None, prompt_adapter_request = None):
super().__init__(seq_id, inputs, block_size, eos_token_id, lora_request, prompt_adapter_request)
self.effective_output_len : int = 0
def fix_last_token_id(self, token_id: int) -> None:
effect_offset = self.effective_output_len - len(self.data.output_token_ids)
assert effect_offset < 0
self.data._output_token_ids[effect_offset] = token_id
if len(self.data._new_appended_tokens) >= effect_offset * -1:
self.data._new_appended_tokens[effect_offset] = token_id
self.data._cached_all_token_ids[effect_offset] = token_id
self.effective_output_len += 1
def zero_overhead_get_output_token_ids(self) -> tuple[int, ...]:
return self.data.output_token_ids[:self.effective_output_len]
def zero_overhead_get_output_len(self) -> int:
return self.effective_output_len
def zero_overhead_get_last_token_id(self) -> int:
if self.effective_output_len == 0:
return self.data._prompt_token_ids[-1]
return self.data._output_token_ids[self.effective_output_len - 1]
def zero_overhead_get_len(self) -> int:
return self.effective_output_len + len(self.data._prompt_token_ids)
def get_output_token_ids_to_return(
self, delta: bool) -> Union[GenericSequence[int], int]:
"""If delta is True, only new tokens since the last call to
this method are returned"""
if not delta:
return self.zero_overhead_get_output_token_ids()
output_len = self.zero_overhead_get_output_len()
# Get the number of new tokens
num_new_tokens = output_len - self._last_output_token_ids_offset
self._last_output_token_ids_offset = output_len
# Return new tokens
if num_new_tokens == 1:
# Optimization for single decode token case
# (which is what we have most of the time)
return self.data._cached_all_token_ids[self.effective_output_len - 1]
if num_new_tokens == 0:
return []
effect_offset = self.effective_output_len - len(self.data.output_token_ids)
return self.data._cached_all_token_ids[-num_new_tokens : effect_offset]
\ No newline at end of file
from typing import Optional
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import SequenceStatus
from vllm.zero_overhead.v0.sequence import ZeroOverheadSequence
class ZeroOverheadStopChecker(StopChecker):
def __init__(self, max_model_len, get_tokenizer_for_seq):
super().__init__(max_model_len, get_tokenizer_for_seq)
def maybe_stop_sequence(
self,
seq: ZeroOverheadSequence,
new_char_count: int,
sampling_params: SamplingParams,
lora_req: Optional[LoRARequest] = None,
) -> None:
"""Stop the finished sequences.
new_char_count is the number of chars added to the
sequence's output text for the newly generated token
"""
# Check if the minimum number of tokens has been generated yet;
# skip the stop string/token checks if not
if seq.zero_overhead_get_output_len() < sampling_params.min_tokens:
return
# Check if the sequence has generated the EOS token.
if ((not sampling_params.ignore_eos)
and seq.zero_overhead_get_last_token_id() == seq.eos_token_id):
# Remove the last EOS token unless explicitly specified
# This prevents unintended exposure of the EOS token
if new_char_count and (
not sampling_params.include_stop_str_in_output):
seq.output_text = seq.output_text[:-new_char_count]
seq.status = SequenceStatus.FINISHED_STOPPED
return
# Check if a stop token was encountered.
# This assumes a single token produced per step.
last_token_id = seq.get_last_token_id(self.zero_overhead)
if last_token_id in (sampling_params.stop_token_ids or ()):
if new_char_count and (
not sampling_params.include_stop_str_in_output):
# Remove last token
seq.output_text = seq.output_text[:-new_char_count]
seq.status = SequenceStatus.FINISHED_STOPPED
seq.stop_reason = last_token_id
return
# Check if any stop strings are matched.
stop = self.check_stop_strings(
seq.output_text, new_char_count, sampling_params.stop,
sampling_params.include_stop_str_in_output)
if stop is not None:
stop_str, truncate_to = stop
if truncate_to != -1:
seq.output_text = seq.output_text[:truncate_to]
seq.status = SequenceStatus.FINISHED_STOPPED
seq.stop_reason = stop_str
return
# Check if the sequence has reached max_model_len.
if seq.zero_overhead_get_len() > self._get_max_model_len(lora_req):
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return
# Check if the sequence has reached max_tokens.
if seq.zero_overhead_get_output_len() == sampling_params.max_tokens:
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return
\ No newline at end of file
from vllm.sampling_params import SamplingParams
from vllm.sequence import VLLM_INVALID_TOKEN_ID
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.detokenizer_utils import convert_prompt_ids_to_tokens, detokenize_incrementally
from vllm.zero_overhead.v0.sequence import ZeroOverheadSequence
class ZeroOverheadDetokenizer(Detokenizer):
def __init__(self, tokenizer_group):
super().__init__(tokenizer_group)
def decode_sequence_inplace(self, seq: ZeroOverheadSequence,
prms: SamplingParams) -> int:
"""Decodes the new token for a sequence. In-place operation.
Args:
seq: The sequence to decode.
prms: The sampling parameters used to generate the sequence.
Returns:
The number of characters added to the output text.
"""
eff_length = seq.get_prompt_len() + seq.effective_output_len
all_input_ids = seq.get_token_ids()[ : eff_length]
token_id_generated_this_iteration = all_input_ids[-1]
tokenizer = self.get_tokenizer_for_seq(seq)
# Convert prompt token IDs to tokens if necessary.
# Do it here so that we don't have to repeat this
# computation for each logprob.
if seq.tokens is None:
(seq.tokens, seq.prefix_offset,
seq.read_offset) = convert_prompt_ids_to_tokens(
tokenizer=tokenizer,
prompt_ids=all_input_ids[:-1],
skip_special_tokens=prms.skip_special_tokens,
)
(new_tokens, new_decoded_token_text, prefix_offset,
read_offset) = detokenize_incrementally(
tokenizer=tokenizer,
all_input_ids=all_input_ids,
prev_tokens=seq.tokens,
prefix_offset=seq.prefix_offset,
read_offset=seq.read_offset,
skip_special_tokens=prms.skip_special_tokens,
spaces_between_special_tokens=prms.spaces_between_special_tokens,
)
# Decode logprobs
logprobs = seq.output_logprobs[-1]
if logprobs:
previous_tokens = all_input_ids[:-1]
for token_id, sample_logprob in logprobs.items():
# If the token was generated this iteration,
# use the provided text.
if token_id == token_id_generated_this_iteration:
sample_logprob.decoded_token = new_decoded_token_text
continue
if (sample_logprob.decoded_token is None
and token_id != VLLM_INVALID_TOKEN_ID):
all_input_ids_with_logprob = previous_tokens + [token_id]
(_, new_text, _, _) = detokenize_incrementally(
tokenizer=tokenizer,
all_input_ids=all_input_ids_with_logprob,
prev_tokens=seq.tokens,
prefix_offset=seq.prefix_offset,
read_offset=seq.read_offset,
skip_special_tokens=prms.skip_special_tokens,
spaces_between_special_tokens=prms.
spaces_between_special_tokens,
)
sample_logprob.decoded_token = new_text
seq.tokens.extend(new_tokens)
seq.prefix_offset = prefix_offset
seq.read_offset = read_offset
seq.output_text += new_decoded_token_text
return len(new_decoded_token_text)
\ No newline at end of file
import torch
import triton
import triton.language as tl
@triton.jit
def _update_input_tokens(
sample_output,
seq_ids,
input_tokens,
input_seq_ids,
BATCH_SIZE1,
BATCH_SIZE2,
):
pid = tl.program_id(0)
if pid >= BATCH_SIZE2:
return
output_token = tl.load(input_tokens + pid)
_input_seq_id = tl.load(input_seq_ids + pid)
for i in range(BATCH_SIZE1):
_seq_ids = tl.load(seq_ids + i)
if _seq_ids == _input_seq_id:
output_token = tl.load(sample_output + i)
tl.store(input_tokens + pid, output_token)
def UpdateInputTokens(input_tokens, input_seq_ids, last_sample, last_ids):
grid = [input_seq_ids.shape[0], 1, 1]
_update_input_tokens[grid](last_sample, last_ids, input_tokens, input_seq_ids, last_ids.shape[0], input_seq_ids.shape[0])
\ No newline at end of file
import os
zero_overhead = os.environ.get('VLLM_ZERO_OVERHEAD') == '1'
def is_zero_overhead():
return zero_overhead
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment