"vscode:/vscode.git/clone" did not exist on "4570535ec41e9e6f808d4cd3a9a06c6928652dea"
Commit 54294854 authored by lizhigong's avatar lizhigong
Browse files

add v0 zero overhead

parent a0c212c0
......@@ -6,6 +6,8 @@ from contextlib import contextmanager
from typing import Iterator, List, Optional, Union
import cloudpickle
from vllm.zero_overhead.v0.llm_engine import ZeroOverheadEngine
from vllm.zero_overhead.v0.utils import is_zero_overhead
import zmq
from vllm import AsyncEngineArgs, SamplingParams
......@@ -79,6 +81,9 @@ class MQLLMEngine:
# the python object to be reused again.
kwargs['use_cached_outputs'] = True
if is_zero_overhead():
self.engine = ZeroOverheadEngine(*args, **kwargs)
else:
self.engine = LLMEngine(*args, **kwargs)
self.log_requests = log_requests
......
......@@ -43,6 +43,8 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
from vllm.usage.usage_lib import UsageContext
from vllm.utils import (Counter, Device, deprecate_args, deprecate_kwargs,
is_list_of)
from vllm.zero_overhead.v0.llm_engine import ZeroOverheadEngine
from vllm.zero_overhead.v0.utils import is_zero_overhead
logger = init_logger(__name__)
......@@ -244,6 +246,10 @@ class LLM:
)
# Create the Engine (autoselects V0 vs V1)
if is_zero_overhead():
self.llm_engine = ZeroOverheadEngine.from_engine_args(
engine_args=engine_args, usage_context=UsageContext.LLM_CLASS)
else:
self.llm_engine = LLMEngine.from_engine_args(
engine_args=engine_args, usage_context=UsageContext.LLM_CLASS)
self.engine_class = type(self.llm_engine)
......
......@@ -21,6 +21,8 @@ from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
CompletionSequenceGroupOutput, Logprob,
PromptLogprobs, SampleLogprobs, SequenceOutput)
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
from vllm.zero_overhead.v0.sampler import ZeroOverheadSampler
from vllm.zero_overhead.v0.utils import is_zero_overhead
if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
import flashinfer.sampling
......@@ -38,6 +40,8 @@ def get_sampler() -> torch.nn.Module:
# Lazy import: the v1 package isn't distributed
from vllm.v1.sample.sampler import Sampler as V1Sampler
return V1Sampler()
if is_zero_overhead():
return ZeroOverheadSampler()
return Sampler()
......
from ctypes import *
import os
import time
import threading
class Prof:
def __init__(self):
self.use_nvtx = os.getenv('VLLM_PROF_NVTX') is not None
self.roc_tracer_flag = False
self.lib = None
if self.use_nvtx:
self.lib = cdll.LoadLibrary("libnvToolsExt.so")
self.lib.nvtxRangePushA.argtypes = [c_char_p]
self.lib.nvtxRangePushA.restype = c_int
self.lib.nvtxRangePop.restype = c_int
self.use_roctx = os.getenv('VLLM_PROF_ROCTX') is not None
if self.use_roctx:
self.lib = cdll.LoadLibrary("libroctracer64.so")
self.lib.roctxRangePushA.argtypes = [c_char_p]
self.lib.roctxRangePushA.restype = c_int
self.lib.roctxRangePop.restype = c_int
self.tm = time.perf_counter()
self.push_depth = {}
def StartTracer(self):
if self.use_roctx:
if self.lib is None:
self.lib = cdll.LoadLibrary("libroctracer64.so")
self.lib.roctracer_start()
self.roc_tracer_flag = True
def StopTracer(self):
if self.use_roctx:
if self.lib is None:
self.lib = cdll.LoadLibrary("libroctracer64.so")
self.lib.roctracer_stop()
self.roc_tracer_flag = False
def thread_depth_add(self, num):
current_thread = threading.current_thread()
thread_id = current_thread.ident
if thread_id not in self.push_depth.keys():
self.push_depth[thread_id] = 0
if num < 0 and self.push_depth[thread_id] == 0:
return False
self.push_depth[thread_id] += num
return True
def ProfRangePush(self, message):
if profile.use_nvtx:
profile.lib.nvtxRangePushA(message.encode('utf-8'))
self.thread_depth_add(1)
if profile.use_roctx and self.roc_tracer_flag:
profile.lib.roctxRangePushA(message.encode('utf-8'))
self.thread_depth_add(1)
def ProfRangePop(self):
if profile.use_nvtx:
if not self.thread_depth_add(-1):
return
profile.lib.nvtxRangePop()
if profile.use_roctx and self.roc_tracer_flag:
if not self.thread_depth_add(-1):
return
profile.lib.roctxRangePop()
def ProfRangeAutoPush(self, message):
self.ProfRangePop()
self.ProfRangePush(message)
profile = Prof()
......@@ -60,6 +60,8 @@ from vllm.worker.model_runner_base import (
_add_sampling_metadata_broadcastable_dict,
_init_attn_metadata_from_tensor_dict,
_init_sampling_metadata_from_tensor_dict)
from vllm.zero_overhead.v0.model_runner import ZeroOverheadModelInputForGpuBuilder
from vllm.zero_overhead.v0.utils import is_zero_overhead
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
......@@ -1636,6 +1638,8 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
_model_input_cls: Type[ModelInputForGPUWithSamplingMetadata] = (
ModelInputForGPUWithSamplingMetadata)
_builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder
if is_zero_overhead():
_builder_cls = ZeroOverheadModelInputForGpuBuilder
def make_model_input_from_broadcasted_tensor_dict(
self,
......
This diff is collapsed.
import torch
import itertools
from typing import List, Optional, Set
from vllm.lora.layers import LoRAMapping
from vllm.multimodal.inputs import MultiModalKwargs
from vllm.prompt_adapter.layers import PromptAdapterMapping
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import SequenceGroupMetadata
from vllm.utils import async_tensor_h2d, flatten_2d_lists
from vllm.worker.model_runner import ModelInputForGPU, ModelInputForGPUBuilder
from vllm.zero_overhead.v0.sampler import get_last_sampler
from vllm.zero_overhead.v0.update_input import UpdateInputTokens
class ZeroOverheadModelInputForGpuBuilder(ModelInputForGPUBuilder):
def __init__(self, runner, finished_requests_ids = None):
super().__init__(runner, finished_requests_ids)
self.req_ids = []
def prepare(self,
finished_requests_ids: Optional[List[str]] = None) -> None:
self.req_ids.clear()
return super().prepare(finished_requests_ids)
def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
seq_ids = seq_group_metadata.seq_data.keys()
n_seqs = len(seq_ids)
seq_ids = list(seq_ids)
for seq_idx in range(n_seqs):
self.req_ids.append(seq_ids[seq_idx])
return super().add_seq_group(seq_group_metadata)
def build(self) -> ModelInputForGPU:
model_input = super().build()
last_sampler = get_last_sampler()
if last_sampler.sampled_token_ids_tensor is not None:
input_ids = async_tensor_h2d(self.req_ids, torch.long,
self.runner.device,
self.runner.pin_memory)
last_ids = async_tensor_h2d(last_sampler.seq_id.tolist(), torch.long,
self.runner.device,
self.runner.pin_memory)
UpdateInputTokens(model_input.input_tokens, input_ids, last_sampler.sampled_token_ids_tensor, last_ids)
return model_input
from importlib.util import find_spec
from typing import Dict, List, Optional
import torch
from vllm import envs
from vllm.model_executor.layers.rejection_sampler import _multinomial
from vllm.model_executor.layers.sampler import MultinomialSamplesType, SampleMetadataType, \
SampleResultArgsType, SampleResultType, SampleResultsDictType, SampleReturnType, Sampler, \
SamplerOutput, _apply_min_p, _apply_min_tokens_penalty, _apply_top_k_top_p, _build_sampler_output, \
_modify_greedy_probs_inplace, _top_k_top_p_multinomial_with_flashinfer, get_logprobs
from vllm.model_executor.layers.utils import apply_penalties
from vllm.model_executor.sampling_metadata import SamplingMetadata, SamplingTensors, SequenceGroupToSample
from vllm.sampling_params import SamplingType
from vllm.sequence import VLLM_INVALID_TOKEN_ID
if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
import flashinfer.sampling
# yapf: disable
from flashinfer.sampling import (
top_k_top_p_sampling_from_probs as flashinfer_top_k_top_p_sampling)
class SampleRecorder:
def __init__(self):
self.seq_id:torch.Tensor = None
self.sampled_token_ids_tensor:torch.Tensor = None
last_sampler = SampleRecorder()
def get_last_sampler():
return last_sampler
class ZeroOverheadSampler(Sampler):
def __init__(self):
super().__init__()
def forward(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
"""
Single-step scheduling:
* Perform GPU-side sampling computation & compute
GPU-side logprobs tensor
* Pythonize sampling result & logprobs tensor
Multi-step scheduling:
* Perform GPU-side sampling computation & compute
GPU-side logprobs tensor
* Defer Pythonization of sampling result & logprobs
tensor
* Encapsulate arguments required for deferred Pythonization
in the :class:`SamplerOutput` structure
Args:
logits: (num_tokens, vocab_size).
sampling_metadata: Metadata for sampling.
"""
assert logits is not None
_, vocab_size = logits.shape
# Prepare sampling tensors with pinned memory to avoid blocking.
if not sampling_metadata.reuse_sampling_tensors:
self._init_sampling_tensors(logits, sampling_metadata)
elif self._do_penalties:
# In this case, the sampling tensors logic depends on
# "output_tokens" of a sequence. As a result, we cannot
# reuse sampling tensors, since "output_tokens" changes
# between decode runs.
self._init_sampling_tensors(logits, sampling_metadata)
assert self._sampling_tensors is not None
sampling_tensors = self._sampling_tensors
do_penalties = self._do_penalties
do_top_p_top_k = self._do_top_p_top_k
do_min_p = self._do_min_p
logits = _apply_min_tokens_penalty(logits, sampling_metadata)
# Apply presence and frequency penalties.
if do_penalties:
logits = apply_penalties(logits, sampling_tensors.prompt_tokens,
sampling_tensors.output_tokens,
sampling_tensors.presence_penalties,
sampling_tensors.frequency_penalties,
sampling_tensors.repetition_penalties)
# Use float32 to apply temperature scaling.
# Use in-place division to avoid creating a new tensor.
logits = logits.to(torch.float)
logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))
if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None:
logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
sampling_tensors.top_ks)
if do_min_p:
logits = _apply_min_p(logits, sampling_tensors.min_ps)
# We use float32 for probabilities and log probabilities.
# Compute the probabilities.
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
# Compute the log probabilities.
logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
# Sample the next tokens.
maybe_deferred_sample_results, maybe_sampled_tokens_tensor = _sample(
probs,
logprobs,
sampling_metadata,
sampling_tensors,
include_gpu_probs_tensor=self.include_gpu_probs_tensor,
modify_greedy_probs=self._should_modify_greedy_probs_inplace,
)
if self.include_gpu_probs_tensor:
# Since we will defer sampler result Pythonization,
# preserve GPU-side tensors in support of later
# deferred pythonization of logprobs
assert maybe_sampled_tokens_tensor is not None
on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
else:
# Since Pythonization has already happened, don't preserve
# GPU-side tensors.
on_device_tensors = None
# Get the logprobs query results.
prompt_logprobs = None
sample_logprobs = None
if not sampling_metadata.skip_sampler_cpu_output:
# Pythonize logprobs now (GPU -> CPU); do not defer.
assert not isinstance(maybe_deferred_sample_results,
SampleResultArgsType)
prompt_logprobs, sample_logprobs = get_logprobs(
logprobs, sampling_metadata, maybe_deferred_sample_results)
return _build_sampler_output(
maybe_deferred_sample_results,
sampling_metadata,
prompt_logprobs,
sample_logprobs,
on_device_tensors=on_device_tensors,
skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output,
logits=logits)
def _greedy_sample(
selected_seq_groups: List[SequenceGroupToSample],
samples: torch.Tensor,
) -> SampleResultType:
"""Run greedy sampling on a given samples.
Args:
selected_seq_groups: A list of sequence groups batched.
samples: (num_selected_samples,) A tensor of samples. The length of
samples could be smaller than selected_seq_groups if
seq_group.do_sample is False.
Returns:
Tuple of (next_token_ids, parent_ids). The length of returned list is
same as the length of selected_seq_groups. If the corresponding
seq_group has do_sample=False, tuple contains ([], [])
"""
sample_idx = 0
results: SampleResultType = []
for seq_group in selected_seq_groups:
if not seq_group.do_sample:
results.append(([], []))
continue
seq_ids = seq_group.seq_ids
num_parent_seqs = len(seq_ids)
assert num_parent_seqs == 1, (
"Greedy sampling should have only one seq.")
parent_ids = list(range(num_parent_seqs))
assert num_parent_seqs == 1 # not support muti seqences in seqence group
next_token_ids = [0] #place holder token id
results.append((next_token_ids, parent_ids))
sample_idx += num_parent_seqs
return results
def _random_sample(
selected_seq_groups: List[SequenceGroupToSample],
random_samples: torch.Tensor,
) -> SampleResultType:
"""Run random sampling on a given samples.
Args:
selected_seq_groups: A list of sequence groups batched.
random_samples: (num_selected_samples,) A tensor of samples. The
length of samples could be smaller than selected_seq_groups if
seq_group.do_sample is False.
Returns:
Tuple of (next_token_ids, parent_ids). The length of returned list is
same as the length of selected_seq_groups. If the corresponding
seq_group has do_sample=False, tuple contains ([], [])
"""
# Find the maximum n value of the prompt phase requests.
sample_idx = 0
results: SampleResultType = []
for seq_group in selected_seq_groups:
if not seq_group.do_sample:
results.append(([], []))
continue
seq_ids = seq_group.seq_ids
sampling_params = seq_group.sampling_params
is_prompt = seq_group.is_prompt
num_parent_seqs = len(seq_ids)
if is_prompt:
# Prompt phase.
parent_ids = [0] * sampling_params.n
assert num_parent_seqs == 1 # not support muti seqences in seqence group
next_token_ids = [0] * sampling_params.n #place holder token id
else:
# Generation phase.
parent_ids = list(range(num_parent_seqs))
assert num_parent_seqs == 1 # not support muti seqences in seqence group
next_token_ids = [0] * num_parent_seqs #place holder token id
results.append((next_token_ids, parent_ids))
sample_idx += num_parent_seqs
return results
def _sample(
probs: torch.Tensor,
logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata,
sampling_tensors: SamplingTensors,
include_gpu_probs_tensor: bool,
modify_greedy_probs: bool,
) -> SampleReturnType:
"""
Args:
probs: (num_query_tokens_in_batch, num_vocab)
logprobs: (num_query_tokens_in_batch, num_vocab)
sampling_metadata: The metadata for a batch for sampling.
sampling_tensors: Tensors that include sampling related metadata.
Returns:
(next_token_ids, parent_seq_ids) for each seq group in a batch.
If sampling is skipped, it returns ([], [])
sampled_token_ids_tensor: A tensor of sampled token ids.
"""
return _sample_with_torch(
probs,
logprobs,
sampling_metadata,
sampling_tensors,
include_gpu_probs_tensor=include_gpu_probs_tensor,
modify_greedy_probs=modify_greedy_probs,
)
def _sample_with_torch(
probs: torch.Tensor,
logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata,
sampling_tensors: SamplingTensors,
include_gpu_probs_tensor: bool,
modify_greedy_probs: bool,
) -> SampleReturnType:
'''Torch-oriented _sample() implementation.
Single-step scheduling:
* Perform GPU-side sampling computation
* Immediately Pythonize sampling result
Multi-step scheduling:
* Perform GPU-side sampling computation
* Defer Pythonization & preserve GPU-side
tensors required for Pythonization
'''
categorized_seq_group_ids: Dict[SamplingType, List[int]] = {
t: []
for t in SamplingType
}
last_sampler.seq_id = torch.zeros(len(sampling_metadata.seq_groups), dtype=torch.int32)
categorized_sample_indices = sampling_metadata.categorized_sample_indices
for i, seq_group in enumerate(sampling_metadata.seq_groups):
last_sampler.seq_id[i] = seq_group.seq_ids[0]
sampling_params = seq_group.sampling_params
sampling_type = sampling_params.sampling_type
categorized_seq_group_ids[sampling_type].append(i)
sample_results_dict: SampleResultsDictType = {}
sample_metadata: SampleMetadataType = {}
multinomial_samples: MultinomialSamplesType = {}
greedy_samples: Optional[torch.Tensor] = None
beam_search_logprobs: Optional[torch.Tensor] = None
# Create output tensor for sampled token ids.
if include_gpu_probs_tensor:
sampled_token_ids_tensor = torch.full((logprobs.shape[0], 1),
VLLM_INVALID_TOKEN_ID,
dtype=torch.long,
device=logprobs.device)
else:
sampled_token_ids_tensor = None
# Counterintiutively, having two loops here is actually faster.
# The first loop can run without waiting on GPU<->CPU sync.
for sampling_type in SamplingType:
sample_indices = categorized_sample_indices[sampling_type]
num_tokens = len(sample_indices)
if num_tokens == 0:
continue
seq_group_id = categorized_seq_group_ids[sampling_type]
seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
sample_metadata[sampling_type] = (seq_group_id, seq_groups)
long_sample_indices = sample_indices.long()
if sampling_type == SamplingType.GREEDY:
greedy_samples = torch.argmax(logprobs[long_sample_indices],
dim=-1)
last_sampler.sampled_token_ids_tensor = greedy_samples.unsqueeze(-1)
if sampled_token_ids_tensor is not None:
# Store sampled tokens in output tensor.
sampled_token_ids_tensor[
long_sample_indices] = greedy_samples.unsqueeze(-1)
if modify_greedy_probs:
# If required, modify the probabilities such that sampling from
# the modified distribution would always sample the argmax
# token id.
_modify_greedy_probs_inplace(logprobs, probs,
long_sample_indices,
greedy_samples)
elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
max_n_in_batch = 1
for seq_group in seq_groups:
if seq_group.is_prompt:
sampling_params = seq_group.sampling_params
max_n_in_batch = max(max_n_in_batch, sampling_params.n)
seq_groups_arg = (None if sampling_type == SamplingType.RANDOM else
seq_groups)
if flashinfer_top_k_top_p_sampling is not None:
multinomial_samples[
sampling_type] = _top_k_top_p_multinomial_with_flashinfer(
probs[long_sample_indices],
sampling_tensors.top_ks[long_sample_indices],
sampling_tensors.top_ps[long_sample_indices],
max_n_in_batch,
seq_groups_arg,
)
else:
multinomial_samples[sampling_type] = _multinomial(
probs[long_sample_indices],
max_n_in_batch,
seq_groups=seq_groups_arg)
last_sampler.sampled_token_ids_tensor = \
multinomial_samples[sampling_type].to(torch.long)
if sampled_token_ids_tensor is not None:
# Store sampled tokens in output tensor.
sampled_token_ids_tensor[long_sample_indices] = \
multinomial_samples[sampling_type].to(torch.long)
elif sampling_type == SamplingType.BEAM:
beam_search_logprobs = logprobs[sample_indices]
else:
raise ValueError(f"Unsupported sampling type: {sampling_type}")
# Encapsulate arguments for computing Pythonized sampler
# results, whether deferred or otherwise.
maybe_deferred_args = SampleResultArgsType(
sampling_metadata=sampling_metadata,
sample_metadata=sample_metadata,
multinomial_samples=multinomial_samples,
greedy_samples=greedy_samples,
beam_search_logprobs=beam_search_logprobs,
sample_results_dict=sample_results_dict)
if not sampling_metadata.skip_sampler_cpu_output:
# GPU<->CPU sync happens here.
# This also converts the sampler output to a Python object.
# Return Pythonized sampler result & sampled token ids
return get_pythonized_sample_results(
maybe_deferred_args), sampled_token_ids_tensor
else:
# Defer sampler result Pythonization; return deferred
# Pythonization args & sampled token ids
return (
maybe_deferred_args,
sampled_token_ids_tensor,
)
def get_pythonized_sample_results(
sample_result_args: SampleResultArgsType) -> SampleResultType:
'''This function consumes GPU-side sampler results and computes
Pythonized CPU-side sampler results (GPU -> CPU sync.)
Single-step scheduling: this function is invoked at sampling-time
for immediate Pythonization.
Multi-step scheduling: Pythonization is deferred until after multiple
GPU-side steps have been completed.
Args:
sample_result_args: GPU-side inputs to the Pythonization process
Returns:
Pythonized sampler results
'''
(
sample_metadata,
sampling_metadata,
greedy_samples,
multinomial_samples,
sample_results_dict,
) = (
sample_result_args.sample_metadata,
sample_result_args.sampling_metadata,
sample_result_args.greedy_samples,
sample_result_args.multinomial_samples,
sample_result_args.sample_results_dict,
)
for sampling_type in SamplingType:
if sampling_type not in sample_metadata:
continue
(seq_group_id, seq_groups) = sample_metadata[sampling_type]
if sampling_type == SamplingType.GREEDY:
sample_results = _greedy_sample(seq_groups, greedy_samples)
elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
sample_results = _random_sample(seq_groups,
multinomial_samples[sampling_type])
sample_results_dict.update(zip(seq_group_id, sample_results))
return [
sample_results_dict.get(i, ([], []))
for i in range(len(sampling_metadata.seq_groups))
]
\ No newline at end of file
from typing import Union
from vllm.sequence import Sequence
from typing import Sequence as GenericSequence
class ZeroOverheadSequence(Sequence):
def __init__(self, seq_id, inputs, block_size, eos_token_id = None, lora_request = None, prompt_adapter_request = None):
super().__init__(seq_id, inputs, block_size, eos_token_id, lora_request, prompt_adapter_request)
self.effective_output_len : int = 0
def fix_last_token_id(self, token_id: int) -> None:
effect_offset = self.effective_output_len - len(self.data.output_token_ids)
assert effect_offset < 0
self.data._output_token_ids[effect_offset] = token_id
if len(self.data._new_appended_tokens) >= effect_offset * -1:
self.data._new_appended_tokens[effect_offset] = token_id
self.data._cached_all_token_ids[effect_offset] = token_id
self.effective_output_len += 1
def zero_overhead_get_output_token_ids(self) -> tuple[int, ...]:
return self.data.output_token_ids[:self.effective_output_len]
def zero_overhead_get_output_len(self) -> int:
return self.effective_output_len
def zero_overhead_get_last_token_id(self) -> int:
if self.effective_output_len == 0:
return self.data._prompt_token_ids[-1]
return self.data._output_token_ids[self.effective_output_len - 1]
def zero_overhead_get_len(self) -> int:
return self.effective_output_len + len(self.data._prompt_token_ids)
def get_output_token_ids_to_return(
self, delta: bool) -> Union[GenericSequence[int], int]:
"""If delta is True, only new tokens since the last call to
this method are returned"""
if not delta:
return self.zero_overhead_get_output_token_ids()
output_len = self.zero_overhead_get_output_len()
# Get the number of new tokens
num_new_tokens = output_len - self._last_output_token_ids_offset
self._last_output_token_ids_offset = output_len
# Return new tokens
if num_new_tokens == 1:
# Optimization for single decode token case
# (which is what we have most of the time)
return self.data._cached_all_token_ids[self.effective_output_len - 1]
if num_new_tokens == 0:
return []
effect_offset = self.effective_output_len - len(self.data.output_token_ids)
return self.data._cached_all_token_ids[-num_new_tokens : effect_offset]
\ No newline at end of file
from typing import Optional
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import SequenceStatus
from vllm.zero_overhead.v0.sequence import ZeroOverheadSequence
class ZeroOverheadStopChecker(StopChecker):
def __init__(self, max_model_len, get_tokenizer_for_seq):
super().__init__(max_model_len, get_tokenizer_for_seq)
def maybe_stop_sequence(
self,
seq: ZeroOverheadSequence,
new_char_count: int,
sampling_params: SamplingParams,
lora_req: Optional[LoRARequest] = None,
) -> None:
"""Stop the finished sequences.
new_char_count is the number of chars added to the
sequence's output text for the newly generated token
"""
# Check if the minimum number of tokens has been generated yet;
# skip the stop string/token checks if not
if seq.zero_overhead_get_output_len() < sampling_params.min_tokens:
return
# Check if the sequence has generated the EOS token.
if ((not sampling_params.ignore_eos)
and seq.zero_overhead_get_last_token_id() == seq.eos_token_id):
# Remove the last EOS token unless explicitly specified
# This prevents unintended exposure of the EOS token
if new_char_count and (
not sampling_params.include_stop_str_in_output):
seq.output_text = seq.output_text[:-new_char_count]
seq.status = SequenceStatus.FINISHED_STOPPED
return
# Check if a stop token was encountered.
# This assumes a single token produced per step.
last_token_id = seq.get_last_token_id(self.zero_overhead)
if last_token_id in (sampling_params.stop_token_ids or ()):
if new_char_count and (
not sampling_params.include_stop_str_in_output):
# Remove last token
seq.output_text = seq.output_text[:-new_char_count]
seq.status = SequenceStatus.FINISHED_STOPPED
seq.stop_reason = last_token_id
return
# Check if any stop strings are matched.
stop = self.check_stop_strings(
seq.output_text, new_char_count, sampling_params.stop,
sampling_params.include_stop_str_in_output)
if stop is not None:
stop_str, truncate_to = stop
if truncate_to != -1:
seq.output_text = seq.output_text[:truncate_to]
seq.status = SequenceStatus.FINISHED_STOPPED
seq.stop_reason = stop_str
return
# Check if the sequence has reached max_model_len.
if seq.zero_overhead_get_len() > self._get_max_model_len(lora_req):
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return
# Check if the sequence has reached max_tokens.
if seq.zero_overhead_get_output_len() == sampling_params.max_tokens:
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return
\ No newline at end of file
from vllm.sampling_params import SamplingParams
from vllm.sequence import VLLM_INVALID_TOKEN_ID
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.detokenizer_utils import convert_prompt_ids_to_tokens, detokenize_incrementally
from vllm.zero_overhead.v0.sequence import ZeroOverheadSequence
class ZeroOverheadDetokenizer(Detokenizer):
def __init__(self, tokenizer_group):
super().__init__(tokenizer_group)
def decode_sequence_inplace(self, seq: ZeroOverheadSequence,
prms: SamplingParams) -> int:
"""Decodes the new token for a sequence. In-place operation.
Args:
seq: The sequence to decode.
prms: The sampling parameters used to generate the sequence.
Returns:
The number of characters added to the output text.
"""
eff_length = seq.get_prompt_len() + seq.effective_output_len
all_input_ids = seq.get_token_ids()[ : eff_length]
token_id_generated_this_iteration = all_input_ids[-1]
tokenizer = self.get_tokenizer_for_seq(seq)
# Convert prompt token IDs to tokens if necessary.
# Do it here so that we don't have to repeat this
# computation for each logprob.
if seq.tokens is None:
(seq.tokens, seq.prefix_offset,
seq.read_offset) = convert_prompt_ids_to_tokens(
tokenizer=tokenizer,
prompt_ids=all_input_ids[:-1],
skip_special_tokens=prms.skip_special_tokens,
)
(new_tokens, new_decoded_token_text, prefix_offset,
read_offset) = detokenize_incrementally(
tokenizer=tokenizer,
all_input_ids=all_input_ids,
prev_tokens=seq.tokens,
prefix_offset=seq.prefix_offset,
read_offset=seq.read_offset,
skip_special_tokens=prms.skip_special_tokens,
spaces_between_special_tokens=prms.spaces_between_special_tokens,
)
# Decode logprobs
logprobs = seq.output_logprobs[-1]
if logprobs:
previous_tokens = all_input_ids[:-1]
for token_id, sample_logprob in logprobs.items():
# If the token was generated this iteration,
# use the provided text.
if token_id == token_id_generated_this_iteration:
sample_logprob.decoded_token = new_decoded_token_text
continue
if (sample_logprob.decoded_token is None
and token_id != VLLM_INVALID_TOKEN_ID):
all_input_ids_with_logprob = previous_tokens + [token_id]
(_, new_text, _, _) = detokenize_incrementally(
tokenizer=tokenizer,
all_input_ids=all_input_ids_with_logprob,
prev_tokens=seq.tokens,
prefix_offset=seq.prefix_offset,
read_offset=seq.read_offset,
skip_special_tokens=prms.skip_special_tokens,
spaces_between_special_tokens=prms.
spaces_between_special_tokens,
)
sample_logprob.decoded_token = new_text
seq.tokens.extend(new_tokens)
seq.prefix_offset = prefix_offset
seq.read_offset = read_offset
seq.output_text += new_decoded_token_text
return len(new_decoded_token_text)
\ No newline at end of file
import torch
import triton
import triton.language as tl
@triton.jit
def _update_input_tokens(
sample_output,
seq_ids,
input_tokens,
input_seq_ids,
BATCH_SIZE1,
BATCH_SIZE2,
):
pid = tl.program_id(0)
if pid >= BATCH_SIZE2:
return
output_token = tl.load(input_tokens + pid)
_input_seq_id = tl.load(input_seq_ids + pid)
for i in range(BATCH_SIZE1):
_seq_ids = tl.load(seq_ids + i)
if _seq_ids == _input_seq_id:
output_token = tl.load(sample_output + i)
tl.store(input_tokens + pid, output_token)
def UpdateInputTokens(input_tokens, input_seq_ids, last_sample, last_ids):
grid = [input_seq_ids.shape[0], 1, 1]
_update_input_tokens[grid](last_sample, last_ids, input_tokens, input_seq_ids, last_ids.shape[0], input_seq_ids.shape[0])
\ No newline at end of file
import os
zero_overhead = os.environ.get('VLLM_ZERO_OVERHEAD') == '1'
def is_zero_overhead():
return zero_overhead
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment