Commit ca796e19 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.1' into v0.8.1-ori

parents e983c804 61c7a1b8
{
"1": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 16,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"2": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 16,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"4": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 16,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"8": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 32,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"16": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 8,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"24": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 16,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"32": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 32,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"48": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 8,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"64": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 8,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"96": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"GROUP_SIZE_M": 1,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"128": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"GROUP_SIZE_M": 1,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"256": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 1,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"512": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 1,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"1024": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 32,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"1536": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 1,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"2048": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 32,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"3072": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 32,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"4096": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 32,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
}
}
\ No newline at end of file
{
"1": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 16,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"2": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 16,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"4": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 16,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"8": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 32,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"16": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 8,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"24": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 16,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"32": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 32,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"48": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 8,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"64": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 8,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"96": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"GROUP_SIZE_M": 1,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"128": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"GROUP_SIZE_M": 1,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"256": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 1,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"512": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 1,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"1024": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 32,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"1536": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 1,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"2048": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 32,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"3072": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 32,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"4096": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 32,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
}
}
\ No newline at end of file
{
"1": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 1,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"2": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 8,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"4": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 16,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"8": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 16,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"16": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 16,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"24": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 8,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"32": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 8,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"48": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 8,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"64": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 16,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"96": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 1,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"128": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"GROUP_SIZE_M": 1,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"256": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 1,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"512": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 1,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"1024": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 32,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"1536": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 32,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"2048": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 32,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"3072": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 32,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"4096": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 32,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
}
}
\ No newline at end of file
{
"1": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 1,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"2": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 8,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"4": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 16,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"8": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 16,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"16": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 16,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"24": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 8,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"32": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 8,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"48": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 8,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"64": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 16,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"96": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 1,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"128": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"GROUP_SIZE_M": 1,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"256": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 1,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"512": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 1,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"1024": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 32,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"1536": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 32,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"2048": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 32,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"3072": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 32,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"4096": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 32,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
}
}
\ No newline at end of file
......@@ -762,7 +762,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
model_name_or_path: str,
allowed_patterns: List[str],
revision: Optional[str] = None,
) -> Tuple[List[str], str]:
) -> Tuple[str, List[str], str]:
"""Retrieve weight files. Download the files if necessary.
Return the weight files and the file pattern."""
......@@ -773,7 +773,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
weight_files = glob.glob(
os.path.join(model_name_or_path, pattern))
if weight_files:
return weight_files, pattern
return model_name_or_path, weight_files, pattern
else:
hf_api = HfApi()
repo_files = hf_api.list_repo_files(repo_id=model_name_or_path)
......@@ -787,7 +787,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
revision,
ignore_patterns=self.load_config.ignore_patterns,
)
return glob.glob(os.path.join(hf_folder, pattern)), pattern
return hf_folder, glob.glob(
os.path.join(hf_folder, pattern)), pattern
raise RuntimeError(
f"No model weights found in: `{model_name_or_path}`")
......@@ -798,10 +799,28 @@ class BitsAndBytesModelLoader(BaseModelLoader):
allowed_patterns = ["*.safetensors", "*.bin", "*.pt"]
hf_weights_files, matched_pattern = self._get_weight_files(
hf_folder, hf_weights_files, matched_pattern = self._get_weight_files(
model_name_or_path, allowed_patterns, revision)
if matched_pattern != "*.safetensors":
use_safetensors = matched_pattern == "*.safetensors"
is_local = os.path.isdir(model_name_or_path)
index_file = SAFE_WEIGHTS_INDEX_NAME
if use_safetensors:
# For models like Mistral-7B-Instruct-v0.3
# there are both sharded safetensors files and a consolidated
# safetensors file. Using both breaks.
# Here, we download the `model.safetensors.index.json` and filter
# any files not found in the index.
if not is_local:
download_safetensors_index_file_from_hf(
model_name_or_path,
index_file,
self.load_config.download_dir,
revision,
)
hf_weights_files = filter_duplicate_safetensors_files(
hf_weights_files, hf_folder, index_file)
else:
hf_weights_files = filter_files_not_needed_for_inference(
hf_weights_files)
......@@ -809,7 +828,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
raise RuntimeError(
f"Cannot find any model weights with `{model_name_or_path}`")
return hf_weights_files, matched_pattern == "*.safetensors"
return hf_weights_files, use_safetensors
def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
if use_safetensors:
......
# SPDX-License-Identifier: Apache-2.0
import math
from typing import (Any, Iterable, Literal, Mapping, Optional, Sequence, Set,
Tuple, TypedDict, Union)
from collections.abc import Iterable, Mapping, Sequence
from typing import Any, Literal, Optional, Set, Tuple, TypedDict, Union
import torch
from torch import nn
from transformers import BatchFeature, Gemma3Config, Gemma3Processor
from transformers.models.gemma3.processing_gemma3 import Gemma3ProcessorKwargs
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import GemmaRMSNorm
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.inputs import MultiModalFieldConfig
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems)
# yapf: disable
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate, encode_tokens)
BaseProcessingInfo, BoundPromptUpdate,
PlaceholderFeaturesInfo,
PromptReplacement, PromptTargetMatch,
PromptUpdate, PromptUpdateDetails,
encode_tokens, find_mm_placeholders,
replace_token_matches)
# yapf: enable
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.utils import flatten_2d_lists
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP, SupportsV0Only)
SupportsMultiModal, SupportsPP)
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
from .vision import scatter_patch_features, select_patch_features
logger = init_logger(__name__)
......@@ -37,13 +46,25 @@ class Gemma3ImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
pixel_values: torch.Tensor
"""
Shape: `(num_crops_total, num_channels, height, width)`
Shape: `(num_patches_total, num_channels, height, width)`
`num_crops_total` is the total number of crops
`num_patches_total` is the total number of patches
over each image over each prompt in the batch.
"""
num_crops: torch.Tensor
"""Shape: `(batch_size * num_images,)`"""
num_patches: torch.Tensor
"""Shape: `(batch_size * num_images)`"""
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size, num_images, num_embeds)`
"""
num_embeds: Union[torch.Tensor, list[torch.Tensor]]
"""Shape: `(batch_size, num_images)`"""
Gemma3ImageInputs = Gemma3ImagePixelInputs
......@@ -51,6 +72,9 @@ Gemma3ImageInputs = Gemma3ImagePixelInputs
class Gemma3ProcessingInfo(BaseProcessingInfo):
def get_hf_config(self):
return self.ctx.get_hf_config(Gemma3Config)
def get_hf_processor(self, **kwargs: object):
return self.ctx.get_hf_processor(Gemma3Processor, **kwargs)
......@@ -114,6 +138,11 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
if not do_pan_and_scan:
return 0
if envs.VLLM_USE_V1:
logger.warning_once(
"`do_pan_and_scan=True` has suboptimal results on V1 "
"because of the simplified attention pattern being used.")
# Based on Gemma3ImageProcessor.pan_and_scan
if image_width >= image_height:
if image_width / image_height < pan_and_scan_min_ratio_to_activate:
......@@ -154,7 +183,7 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
image_width: int,
image_height: int,
processor: Optional[Gemma3Processor],
) -> str:
) -> PromptUpdateDetails:
if processor is None:
processor = self.get_hf_processor()
......@@ -175,7 +204,11 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
f"Here is the original image {image_token} and here are some "
f"crops to help you see better {crops_image_tokens}")
return image_text.replace(image_token, processor.full_image_sequence)
repl_full = image_text.replace(image_token,
processor.full_image_sequence)
repl_features = repl_full.strip("\n")
return PromptUpdateDetails(full=repl_full, features=repl_features)
def get_num_image_tokens(
self,
......@@ -193,7 +226,7 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
image_repl_tokens = encode_tokens(
tokenizer,
image_repl,
image_repl.features,
add_special_tokens=False,
)
return len(image_repl_tokens)
......@@ -240,12 +273,8 @@ class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]):
num_images=num_images)
}
# NOTE: We need to separate the image tokens here because
# encode("\n\n\n\n") != encode("\n\n") * 2, which interferes
# with the detection of prompt updates when the image tokens are
# right next to each other
return ProcessorInputs(
prompt_text=" ".join([image_token] * num_images),
prompt_text=image_token * num_images,
mm_data=mm_data,
)
......@@ -278,13 +307,39 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
]
hf_processor = self.info.get_hf_processor(**mm_kwargs)
image_repl_features = [
self.info.get_image_repl(image_width=size.width,
image_height=size.height,
processor=hf_processor).features
for size in image_sizes
]
tokenizer = self.info.get_tokenizer()
image_repls_feature_tokens = [
tokenizer.encode(image_repl, add_special_tokens=False)
for image_repl in image_repl_features
]
num_embeds = [
len(image_repl_feature_tokens)
for image_repl_feature_tokens in image_repls_feature_tokens
]
processed_outputs["num_embeds"] = torch.tensor(num_embeds)
vocab = tokenizer.get_vocab()
image_token_id = vocab[tokenizer.image_token]
embed_is_patch = [
torch.tensor(image_repl_tokens) == image_token_id
for image_repl_tokens in image_repls_feature_tokens
]
processed_outputs["embed_is_patch"] = embed_is_patch
num_crops = [
self.info.get_num_crops(image_width=size.width,
image_height=size.height,
processor=hf_processor)
for size in image_sizes
]
processed_outputs["num_crops"] = torch.tensor(num_crops)
return processed_outputs
......@@ -300,6 +355,8 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
pixel_values=MultiModalFieldConfig.flat_from_sizes(
"image", num_crops + 1),
num_crops=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
num_embeds=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_updates(
......@@ -329,6 +386,91 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
)
]
def _apply_token_matches(
self,
prompt: list[int],
mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
mm_item_counts: Mapping[str, int],
) -> list[int]:
token_ids = super()._apply_token_matches(
prompt,
mm_matches,
mm_item_counts,
)
# "\n\n\n" and "\n\n\n\n" are single tokens
# Since our replacement can insert "\n\n" next to "\n"
# tokens, we have to combine them to be consistent with
# the output of the tokenizer
tokenizer = self.info.get_tokenizer()
vocab = tokenizer.get_vocab()
newline_1 = vocab["\n"]
newline_2 = vocab["\n\n"]
newline_3 = vocab["\n\n\n"]
newline_4 = vocab["\n\n\n\n"]
token_ids = replace_token_matches(
token_ids,
[newline_1, newline_2],
[newline_3],
)
token_ids = replace_token_matches(
token_ids,
[newline_2, newline_1],
[newline_3],
)
token_ids = replace_token_matches(
token_ids,
[newline_2, newline_2],
[newline_4],
)
return token_ids
def _find_mm_placeholders(
self,
mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
new_token_ids: list[int],
mm_item_counts: Mapping[str, int],
) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
# We need to detect "\n\n" inside "\n\n\n" and "\n\n\n\n"
tokenizer = self.info.get_tokenizer()
vocab = tokenizer.get_vocab()
newline_1 = vocab["\n"]
newline_2 = vocab["\n\n"]
newline_3 = vocab["\n\n\n"]
newline_4 = vocab["\n\n\n\n"]
def get_repl_toks(tok: int) -> list[int]:
if tok == newline_3:
return [newline_1, newline_2]
if tok == newline_4:
return [newline_2, newline_2]
return [tok]
repl_token_ids = list[int]()
repl_orig_idxs = list[int]()
for orig_idx, orig_tok in enumerate(new_token_ids):
repl_toks = get_repl_toks(orig_tok)
repl_token_ids.extend(repl_toks)
repl_orig_idxs.extend(orig_idx for _ in range(len(repl_toks)))
repls = find_mm_placeholders(mm_prompt_updates, repl_token_ids,
mm_item_counts)
return {
modality: [
PlaceholderFeaturesInfo(
modality=p.modality,
item_idx=p.item_idx,
start_idx=repl_orig_idxs[p.start_idx],
tokens=p.tokens,
) for p in placeholders
]
for modality, placeholders in repls.items()
}
class Gemma3MultiModalProjector(nn.Module):
......@@ -374,7 +516,7 @@ class Gemma3MultiModalProjector(nn.Module):
info=Gemma3ProcessingInfo,
dummy_inputs=Gemma3DummyInputsBuilder)
class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
SupportsLoRA, SupportsV0Only):
SupportsLoRA):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
......@@ -415,6 +557,10 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
@property
def dtype(self):
return next(self.parameters()).dtype
@property
def sampler(self):
return self.language_model.sampler
......@@ -438,6 +584,8 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
self, **kwargs: object) -> Optional[Gemma3ImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
num_crops = kwargs.pop("num_crops", None)
embed_is_patch = kwargs.pop("embed_is_patch", None)
num_embeds = kwargs.pop("num_embeds", None)
image_embeds = kwargs.pop("image_embeds", None)
assert image_embeds is None, "Gemma3 does not support image_embeds."
if pixel_values is None:
......@@ -448,16 +596,26 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
f"Got type: {type(pixel_values)}")
if not isinstance(num_crops, (torch.Tensor, list)):
raise ValueError("Incorrect type of num_crops values. "
raise ValueError("Incorrect type of num_crops. "
f"Got type: {type(num_crops)}")
if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
if not isinstance(num_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of num_embeds. "
f"Got type: {type(num_embeds)}")
pixel_values = flatten_bn(pixel_values, concat=True)
num_crops = flatten_bn(num_crops, concat=True)
return Gemma3ImagePixelInputs(
type="pixel_values",
pixel_values=self._validate_pixel_values(pixel_values),
num_crops=num_crops,
num_patches=num_crops + 1,
embed_is_patch=embed_is_patch,
num_embeds=num_embeds,
)
def _image_pixels_to_features(
......@@ -472,36 +630,51 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
def _process_image_input(
self,
image_input: Gemma3ImageInputs,
) -> torch.Tensor:
) -> tuple[torch.Tensor, ...]:
assert self.vision_tower is not None
pixel_values = image_input["pixel_values"]
vision_outputs = self._image_pixels_to_features(
num_patches = image_input["num_patches"]
image_features = self._image_pixels_to_features(
self.vision_tower,
pixel_values,
)
return self.multi_modal_projector(vision_outputs)
image_embeds = self.multi_modal_projector(image_features)
return image_embeds.split(num_patches.tolist())
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
vision_embeddings = self._process_image_input(image_input)
return vision_embeddings
image_features = self._process_image_input(image_input)
if kwargs.get("v0_path", False):
return image_features
return flatten_2d_lists(
scatter_patch_features(*args) for args in zip(
image_features,
image_input["num_embeds"],
image_input["embed_is_patch"],
))
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
if multimodal_embeddings is None:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
else:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.config.image_token_index)
input_ids,
inputs_embeds,
select_patch_features(multimodal_embeddings),
self.config.image_token_index,
)
return inputs_embeds
def forward(self,
......@@ -516,6 +689,7 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility.
elif inputs_embeds is None:
kwargs.update({"v0_path": True})
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
......@@ -524,8 +698,9 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
kwargs = self.prepare_attn_masks(
input_ids,
positions,
mask_dtype=vision_embeddings.dtype,
**kwargs)
mask_dtype=self.dtype,
**kwargs,
)
input_ids = None
hidden_states = self.language_model.model(input_ids,
......
......@@ -18,7 +18,7 @@ from transformers.models.pixtral import PixtralProcessor
from vllm.config import VllmConfig
from vllm.inputs import InputProcessingContext
from vllm.jsontree import JSONTree, json_map_leaves
from vllm.jsontree import json_map_leaves
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
......@@ -27,8 +27,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputs, MultiModalKwargs,
NestedTensors)
MultiModalInputs, MultiModalKwargs)
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
......@@ -44,7 +43,8 @@ from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
from .vision import get_vision_encoder_info
from .vision import (get_vision_encoder_info, scatter_patch_features,
select_patch_features)
class LlavaImagePixelInputs(TypedDict):
......@@ -76,7 +76,7 @@ class PixtralHFImagePixelInputs(TypedDict):
Shape: `(batch_size, num_images, num_embeds)`
"""
num_patches: Union[torch.Tensor, list[torch.Tensor]]
num_embeds: Union[torch.Tensor, list[torch.Tensor]]
"""Shape: `(batch_size, num_images)`"""
......@@ -352,15 +352,15 @@ class PixtralHFMultiModalProcessor(
image_height=pixel_value.shape[-2],
) for pixel_value in processed_outputs["pixel_values"]
]
num_patches = torch.tensor([(ncols + 1) * nrows
for ncols, nrows in tile_sizes])
num_embeds = torch.tensor([(ncols + 1) * nrows
for ncols, nrows in tile_sizes])
# Each image may result to masks of different sizes, so we need to
# later use `num_patches` to get per-image masks.
# later use `num_embeds` to get per-image masks.
embed_is_patch = [
torch.tensor(([True] * ncols + [False]) * nrows)
for ncols, nrows in tile_sizes
]
processed_outputs["num_patches"] = num_patches
processed_outputs["num_embeds"] = num_embeds
processed_outputs["embed_is_patch"] = embed_is_patch
return processed_outputs
......@@ -372,7 +372,7 @@ class PixtralHFMultiModalProcessor(
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
num_patches=MultiModalFieldConfig.batched("image"),
num_embeds=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
)
......@@ -621,16 +621,16 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
num_patches = kwargs.pop("num_patches")
if not isinstance(num_patches, (torch.Tensor, list)):
raise ValueError("Incorrect type of num_patches. "
f"Got type: {type(num_patches)}")
num_embeds = kwargs.pop("num_embeds")
if not isinstance(num_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of num_embeds. "
f"Got type: {type(num_embeds)}")
return PixtralHFImagePixelInputs(
type="pixel_values_pixtral",
pixel_values=flatten_bn(pixel_values),
embed_is_patch=embed_is_patch,
num_patches=num_patches,
num_embeds=num_embeds,
)
return LlavaImagePixelInputs(
......@@ -716,33 +716,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
image_embeds = torch.split(image_embeds, feature_sizes)
return image_embeds
def _get_mm_embeds(
self,
features: torch.Tensor, # Shape: (num_patch, d)
num_patches: torch.Tensor, # Shape: (num_images,)
embed_is_patch: torch.Tensor, # Shape: (num_images, num_embeds)
) -> tuple[torch.Tensor, ...]:
"""Scatter the patch features into a contiguous tensor that corresponds
to the embedding tokens defined by the multimodal processor.
Mostly copied from `Molmo._get_mm_embeds`. See following fixme comment.
"""
# Insert columns of nan values according to `embed_is_patch`. This work
# ideally should be done in `_process_image_input`, but
# `_process_image_input` is used in both V0 and V1 path. It's safer to
# put the logic here.
# FIXME: Move this logic to `_process_image_input` when v0 is
# deprecated. Merge this function with `Molmo._get_mm_embeds`.
num_patches_per_image: list[int] = num_patches.tolist()
embeds_flat = features.new_full(
(sum(num_patches_per_image), *features.shape[1:]),
fill_value=torch.nan,
)
embeds_flat[embed_is_patch.view(-1)] = features
return embeds_flat.split(num_patches_per_image)
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
......@@ -757,9 +730,9 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
return vision_embeddings
return flatten_2d_lists(
self._get_mm_embeds(*args) for args in zip(
scatter_patch_features(*args) for args in zip(
vision_embeddings,
image_input["num_patches"],
image_input["num_embeds"],
image_input["embed_is_patch"],
))
......@@ -770,16 +743,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
# Extract the patch tokens
patch_embeddings = json_map_leaves(
lambda x: x[~x.isnan()].view(-1, *x.shape[1:]),
cast(JSONTree[torch.Tensor], multimodal_embeddings),
)
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
cast(NestedTensors, patch_embeddings),
select_patch_features(multimodal_embeddings),
self.config.image_token_index,
)
return inputs_embeds
......
......@@ -1070,8 +1070,8 @@ class MllamaTextModel(nn.Module):
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embeds
for decoder_layer in self.layers:
if isinstance(decoder_layer, MllamaCrossAttentionDecoderLayer):
for idx, decoder_layer in enumerate(self.layers):
if idx in self.cross_attention_layers:
if not skip_cross_attention:
hidden_states = decoder_layer(
hidden_states=hidden_states,
......@@ -1081,16 +1081,13 @@ class MllamaTextModel(nn.Module):
full_text_row_masked_out_mask=
full_text_row_masked_out_mask,
)
elif isinstance(decoder_layer, LlamaDecoderLayer):
else:
hidden_states, residual = decoder_layer(
positions=positions,
hidden_states=hidden_states,
residual=None,
)
hidden_states = hidden_states + residual
else:
raise ValueError(
f"Unknown decoder layer type {type(decoder_layer)}")
hidden_states = self.norm(hidden_states)
return hidden_states
......@@ -1551,4 +1548,4 @@ def convert_dense_cross_attention_mask_to_tensor(
full_text_mask = ((mask != ninf).any(dim=-1).type_as(mask)[..., None])
mask *= full_text_mask
# (num_prompt_tokens, num_encoder_tokens)
return mask
\ No newline at end of file
return mask
......@@ -4,7 +4,7 @@ import math
from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass
from functools import cached_property, partial
from typing import List, Optional, Set, Tuple, TypedDict, Union, cast
from typing import List, Optional, Set, Tuple, TypedDict, Union
import numpy as np
import torch
......@@ -24,7 +24,6 @@ from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather)
from vllm.jsontree import JSONTree, json_map_leaves
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.activation import (MulAndSilu, QuickGELU,
SiluAndMul)
......@@ -42,8 +41,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
......@@ -59,6 +57,7 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix, merge_multimodal_embeddings)
from .vision import select_patch_features
# TODO: hard-coded for now. Consider making it configurable.
VIT_LAYERS = [-2, -9]
......@@ -1602,16 +1601,10 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
if multimodal_embeddings is not None:
assert self.img_patch_id is not None
# Extract the patch tokens scattered in _get_mm_embeds
patch_embeddings = json_map_leaves(
lambda x: x[~x.isnan()].view(-1, *x.shape[1:]),
cast(JSONTree[torch.Tensor], multimodal_embeddings),
)
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
cast(NestedTensors, patch_embeddings),
select_patch_features(multimodal_embeddings),
self.img_patch_id,
)
return inputs_embeds
......
......@@ -4,7 +4,7 @@ import math
from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass, fields
from functools import cached_property
from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union, cast
from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union
import torch
import torch.nn as nn
......@@ -22,7 +22,6 @@ from transformers.tokenization_utils_base import TextInput
from vllm.config import VllmConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.jsontree import JSONTree, json_map_leaves
from vllm.model_executor.layers.activation import get_act_and_mul_fn
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
......@@ -48,7 +47,8 @@ from vllm.utils import flatten_2d_lists
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs
from .vision import (VisionEncoderInfo, resolve_visual_encoder_outputs,
scatter_patch_features, select_patch_features)
try:
from xformers import ops as xops
......@@ -77,7 +77,7 @@ class PixtralImagePixelInputs(TypedDict):
Shape: `(batch_size, num_images, num_embeds)`
"""
num_patches: Union[torch.Tensor, list[torch.Tensor]]
num_embeds: Union[torch.Tensor, list[torch.Tensor]]
"""Shape: `(batch_size, num_images)`"""
......@@ -153,7 +153,7 @@ class PixtralProcessorAdapter:
images_processed = list[torch.Tensor]()
images_tokens = list[torch.Tensor]()
images_embed_is_patch = list[torch.Tensor]()
images_num_patches = list[int]()
images_num_embeds = list[int]()
for image in images:
image_inputs = self.image_processor(ImageChunk(image=image))
......@@ -163,13 +163,13 @@ class PixtralProcessorAdapter:
images_processed.append(image_processed)
images_tokens.append(image_tokens)
images_embed_is_patch.append(image_tokens == image_token_id)
images_num_patches.append(len(image_tokens))
images_num_embeds.append(len(image_tokens))
return {
"input_ids": torch.cat(images_tokens)[None].expand(len(text), -1),
"images": images_processed,
"embed_is_patch": images_embed_is_patch,
"num_patches": torch.tensor(images_num_patches),
"num_embeds": torch.tensor(images_num_embeds),
}
......@@ -273,7 +273,7 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
return dict(
images=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
num_patches=MultiModalFieldConfig.batched("image"),
num_embeds=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_updates(
......@@ -365,16 +365,7 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
spatial_merge_size=self.vision_args.spatial_merge_size,
use_mlp_bias=False,
)
if self.vision_args.add_pre_mm_projector_layer_norm:
self.pre_mm_projector_norm = RMSNorm(self.vision_args.hidden_size,
eps=1e-5)
if self.vision_args.mm_projector_id == PATCH_MERGE:
self.patch_merger = PatchMerger(
vision_encoder_dim=self.vision_args.hidden_size,
spatial_merge_size=self.vision_args.spatial_merge_size,
use_mlp_bias=False,
)
self.vision_language_adapter = VisionLanguageAdapter(
self.vision_args, dim=config.text_config.hidden_size)
......@@ -403,16 +394,16 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
num_patches = kwargs.pop("num_patches")
if not isinstance(num_patches, (torch.Tensor, list)):
raise ValueError("Incorrect type of num_patches. "
f"Got type: {type(num_patches)}")
num_embeds = kwargs.pop("num_embeds")
if not isinstance(num_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of num_embeds. "
f"Got type: {type(num_embeds)}")
return PixtralImagePixelInputs(
type="pixel_values",
images=flatten_bn(images),
embed_is_patch=embed_is_patch,
num_patches=num_patches,
num_embeds=num_embeds,
)
def _process_image_input(
......@@ -442,33 +433,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
image_embeds = torch.split(image_embeds, feature_sizes)
return image_embeds
def _get_mm_embeds(
self,
features: torch.Tensor, # Shape: (num_patch, d)
num_patches: torch.Tensor, # Shape: (num_images,)
embed_is_patch: torch.Tensor, # Shape: (num_images, num_embeds)
) -> tuple[torch.Tensor, ...]:
"""Scatter the patch features into a contiguous tensor that corresponds
to the embedding tokens defined by the multimodal processor.
Mostly copied from `Molmo._get_mm_embeds`. See following fixme comment.
"""
# Insert columns of nan values according to `embed_is_patch`. This work
# ideally should be done in `_process_image_input`, but
# `_process_image_input` is used in both V0 and V1 path. It's safer to
# put the logic here.
# FIXME: Move this logic to `_process_image_input` when v0 is
# deprecated. Merge this function with `Molmo._get_mm_embeds`.
num_patches_per_image: list[int] = num_patches.tolist()
embeds_flat = features.new_full(
(sum(num_patches_per_image), *features.shape[1:]),
fill_value=torch.nan,
)
embeds_flat[embed_is_patch.view(-1)] = features
return embeds_flat.split(num_patches_per_image)
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
......@@ -481,9 +445,9 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
return image_features
return flatten_2d_lists(
self._get_mm_embeds(*args) for args in zip(
scatter_patch_features(*args) for args in zip(
image_features,
image_input["num_patches"],
image_input["num_embeds"],
image_input["embed_is_patch"],
))
......@@ -494,15 +458,10 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
# Extract the patch tokens
patch_embeddings = json_map_leaves(
lambda x: x[~x.isnan()].view(-1, *x.shape[1:]),
cast(JSONTree[torch.Tensor], multimodal_embeddings),
)
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
cast(NestedTensors, patch_embeddings),
select_patch_features(multimodal_embeddings),
self.vision_args.image_token_id,
)
return inputs_embeds
......
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from typing import Final, Generic, Optional, Protocol, TypeVar, Union
from typing import Final, Generic, Optional, Protocol, TypeVar, Union, cast
import torch
from transformers import PretrainedConfig
......@@ -9,9 +9,12 @@ from transformers import PretrainedConfig
import vllm.envs as envs
from vllm.attention.selector import (backend_name_to_enum,
get_global_forced_attn_backend)
from vllm.jsontree import JSONTree, json_map_leaves
from vllm.logger import init_logger
from vllm.platforms import _Backend, current_platform
from .interfaces import MultiModalEmbeddings
logger = init_logger(__name__)
_C = TypeVar("_C", bound=PretrainedConfig)
......@@ -148,3 +151,48 @@ def resolve_visual_encoder_outputs(
if post_layer_norm is not None and uses_last_layer:
hs_pool[-1] = post_layer_norm(encoder_outputs)
return torch.cat(hs_pool, dim=-1)
def scatter_patch_features(
features: torch.Tensor,
num_embeds: torch.Tensor,
embed_is_patch: torch.Tensor,
) -> tuple[torch.Tensor, ...]:
"""
Scatter the patch features into a contiguous tensor that corresponds
to the embedding tokens defined by the multimodal processor.
The rest of the values in the tensor are set to NaN so that they
can be filtered out by :func`select_patch_features`.
Args:
features: The patch features, concatenated across each image.
Shape: `(num_patch, feature_depth)`
num_embeds: The number of image embeddings for each image.
Shape: `(num_images,)`
embed_is_patch: A boolean mask indicating which image embeddings
correspond to patch tokens for each image.
Shape: `(num_images, num_embeds)`
"""
num_embeds_per_image: list[int] = num_embeds.tolist()
embeds_flat = features.new_full(
(sum(num_embeds_per_image), features.shape[-1]),
fill_value=torch.nan,
)
embeds_flat[embed_is_patch.view(-1)] = features.flatten(0, -2)
return embeds_flat.split(num_embeds_per_image)
def select_patch_features(
multimodal_embeddings: MultiModalEmbeddings) -> MultiModalEmbeddings:
"""
Given the outputs of :func:`scatter_patch_features`, return only
the values that correspond to patch features.
"""
selected_features = json_map_leaves(
lambda x: x[~x.isnan()].view(-1, *x.shape[1:]),
cast(JSONTree[torch.Tensor], multimodal_embeddings),
)
return cast(MultiModalEmbeddings, selected_features)
......@@ -26,7 +26,7 @@ from vllm.utils import GiB_bytes, flatten_2d_lists, full_groupby
from .hasher import MultiModalHasher
from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
MultiModalFieldConfig, MultiModalInputs, MultiModalKwargs,
MultiModalKwargsItem, PlaceholderRange)
MultiModalKwargsItem, NestedTensors, PlaceholderRange)
from .parse import (DictEmbeddingItems, EmbeddingItems, MultiModalDataItems,
MultiModalDataParser)
......@@ -511,8 +511,35 @@ def iter_token_matches(
start_idx += 1
def replace_token_matches(
token_ids: list[int],
match_ids: list[int],
new_ids: list[int],
) -> list[int]:
"""
Replace each occurrence of :code:`match_ids` in :code:`token_ids`
with :code:`new_ids`.
Note that empty matches are ignored.
"""
out_seqs = list[list[int]]()
prev_end_idx = 0
for match in iter_token_matches(token_ids, match_ids):
start_idx = match.start_idx
end_idx = match.end_idx
out_seqs.append(token_ids[prev_end_idx:start_idx])
out_seqs.append(new_ids)
prev_end_idx = end_idx
out_seqs.append(token_ids[prev_end_idx:])
return flatten_2d_lists(out_seqs)
@dataclass(repr=False)
class _PromptTargetMatch(ABC):
class PromptTargetMatch(ABC):
_origin: BoundPromptUpdate
@property
......@@ -535,7 +562,7 @@ class _PromptTargetMatch(ABC):
@dataclass(repr=False)
class _PromptTargetIndexMatch(_PromptTargetMatch):
class _PromptTargetIndexMatch(PromptTargetMatch):
match_idx: int
@property
......@@ -548,7 +575,7 @@ class _PromptTargetIndexMatch(_PromptTargetMatch):
@dataclass(repr=False)
class _PromptTargetTokenMatch(_PromptTargetMatch):
class _PromptTargetTokenMatch(PromptTargetMatch):
match: _TokenMatch
@property
......@@ -561,7 +588,7 @@ class _PromptTargetTokenMatch(_PromptTargetMatch):
@dataclass(repr=False)
class _PromptTargetTextMatch(_PromptTargetMatch):
class _PromptTargetTextMatch(PromptTargetMatch):
match: re.Match[str]
@property
......@@ -594,7 +621,7 @@ class PlaceholderFeaturesInfo:
def find_token_matches(
prompt: list[int],
prompt_updates: Sequence[BoundPromptUpdate],
) -> Sequence[_PromptTargetMatch]:
) -> Sequence[PromptTargetMatch]:
"""Return each target of :code:`prompt_updates` found in :code:`prompt`."""
def get_matches(update: BoundPromptUpdate):
......@@ -620,7 +647,7 @@ def find_token_matches(
def find_text_matches(
prompt: str,
prompt_updates: Sequence[BoundPromptUpdate],
) -> Sequence[_PromptTargetMatch]:
) -> Sequence[PromptTargetMatch]:
"""Return each target of :code:`prompt_updates` found in :code:`prompt`."""
def get_matches(update: BoundPromptUpdate):
......@@ -645,15 +672,15 @@ def find_text_matches(
def _resolve_matches(
prompt: PromptSeq,
mm_matches: Mapping[str, Sequence[_PromptTargetMatch]],
) -> list[_PromptTargetMatch]:
mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
) -> list[PromptTargetMatch]:
"""
Resolve :code:`mm_matches` to ensure that there are no overlapping matches,
and sort them such that earlier matches take priority over later ones.
"""
matches = [m for matches in mm_matches.values() for m in matches]
seen_matches: list[Optional[_PromptTargetMatch]] = [None] * len(prompt)
seen_matches: list[Optional[PromptTargetMatch]] = [None] * len(prompt)
for match in matches:
for idx in range(match.start_idx, match.end_idx):
......@@ -669,7 +696,7 @@ def _resolve_matches(
def _apply_matches(
prompt: _S,
mm_matches: Mapping[str, Sequence[_PromptTargetMatch]],
mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
mm_item_counts: Mapping[str, int],
) -> list[_S]:
"""Apply the updates in :code:`mm_matches` to :code:`prompt`."""
......@@ -718,7 +745,7 @@ def _apply_matches(
def apply_token_matches(
prompt: list[int],
mm_matches: Mapping[str, Sequence[_PromptTargetMatch]],
mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
mm_item_counts: Mapping[str, int],
) -> list[int]:
"""Apply the updates in :code:`mm_matches` to :code:`prompt`."""
......@@ -732,7 +759,7 @@ def apply_token_matches(
def apply_text_matches(
prompt: str,
mm_matches: Mapping[str, Sequence[_PromptTargetMatch]],
mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
mm_item_counts: Mapping[str, int],
) -> str:
"""Apply the updates in :code:`mm_matches` to :code:`prompt`."""
......@@ -826,33 +853,62 @@ class ProcessingCache:
@staticmethod
def get_lru_cache(
capacity_gb: int,
capacity_gb: float,
value_type: type[_V],
*,
debug: bool = False,
) -> LRUCache[str, _V]:
def get_size(leaf: object) -> int:
def get_leaf_size(leaf: object) -> int:
# MultiModalKwargs is not a subclass of dict
if isinstance(leaf, MultiModalKwargs):
return get_item_size(leaf.data)
# MultiModalKwargsItem is not a subclass of dict
if isinstance(leaf, MultiModalKwargsItem):
leaf_data = {k: v.data for k, v in leaf.items()}
return get_item_size(leaf_data)
# sys.getsizeof doesn't work for tensors
if isinstance(leaf, torch.Tensor):
return leaf.nbytes # sys.getsizeof doesn't work for tensors
return leaf.nbytes
return sys.getsizeof(leaf)
return LRUCache[str, _V](
GiB_bytes * capacity_gb,
getsizeof=lambda x: json_reduce_leaves(
def get_item_size(
value: Union[MultiModalKwargs, MultiModalKwargsItem,
Mapping[str, NestedTensors]]
) -> int:
size = json_reduce_leaves(
lambda a, b: a + b,
json_map_leaves(get_size, x),
),
)
json_map_leaves(get_leaf_size, value),
)
if debug:
logger.debug("Calculated size of %s to be %.2f GiB",
type(value), size / GiB_bytes)
return size
def __init__(self, capacity_gb: int) -> None:
return LRUCache(GiB_bytes * capacity_gb, getsizeof=get_item_size)
def __init__(
self,
capacity_gb: float,
*,
debug_cache_hit_ratio_steps: Optional[int] = None,
) -> None:
super().__init__()
# DEBUG: Set to None to disable
self.debug_cache_hit_ratio_steps: Optional[int] = None
self.debug_cache_hit_ratio_steps = debug_cache_hit_ratio_steps
self.debug_cache_hits = 0
self.debug_cache_total = 0
self._cache = self.get_lru_cache(capacity_gb, MultiModalKwargsItem)
self._cache = self.get_lru_cache(
capacity_gb,
MultiModalKwargsItem,
debug=bool(debug_cache_hit_ratio_steps),
)
def _maybe_log_cache_stats(self) -> None:
steps = self.debug_cache_hit_ratio_steps
......@@ -863,6 +919,9 @@ class ProcessingCache:
if total > 0 and total % steps == 0:
logger.debug("ProcessingCache: hit_ratio = %.2f",
self.debug_cache_hits / total)
logger.debug("ProcessingCache: size = %.2f / %.2f GiB",
self._cache.currsize / GiB_bytes,
self._cache.maxsize / GiB_bytes)
def get(
self,
......@@ -1055,14 +1114,14 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
Given the original multi-modal items for this modality
and HF-processed data, output the updates to perform.
Notes:
- You should not assume that HF processor always performs prompt
updates: in :meth:`_apply_hf_processor_missing`, this method
is called on text-only and multimodal-only inputs separately,
instead of passing them in the same call.
- The update information returned by this method is also used to
determine the placeholder token positions for each multi-modal
item.
The information returned by this method is used to update token inputs
which bypass the HF processor. It is also used to update the output of
HF processor if the HF process does not apply prompt updates to text
inputs.
Moreover, this information is critical to determine the token positions
in order to construct :class:`~vllm-multimodal.input.PlaceholderRange`
for each multi-modal item.
"""
raise NotImplementedError
......@@ -1357,6 +1416,22 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
it = (update.bind(tokenizer) for update in prompt_updates)
return dict(full_groupby_modality(it))
def _apply_token_matches(
self,
prompt: list[int],
mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
mm_item_counts: Mapping[str, int],
) -> list[int]:
return apply_token_matches(prompt, mm_matches, mm_item_counts)
def _apply_text_matches(
self,
prompt: str,
mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
mm_item_counts: Mapping[str, int],
) -> str:
return apply_text_matches(prompt, mm_matches, mm_item_counts)
def _apply_prompt_updates(
self,
token_ids: list[int],
......@@ -1388,7 +1463,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_match_counts.get(modality, 0) >= item_count
for modality, item_count in mm_item_counts.items()
): # yapf: disable
token_ids = apply_token_matches(
token_ids = self._apply_token_matches(
token_ids,
mm_token_matches,
mm_item_counts,
......@@ -1406,7 +1481,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
modality: find_text_matches(text, updates)
for modality, updates in mm_prompt_updates.items()
}
text = apply_text_matches(
text = self._apply_text_matches(
text,
mm_text_matches,
mm_item_counts,
......
......@@ -73,7 +73,7 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]):
height: int,
num_images: int,
) -> list[Image.Image]:
image = Image.new("RGB", (width, height), color=0)
image = Image.new("RGB", (width, height), color=255)
return [image] * num_images
def _get_dummy_videos(
......@@ -84,7 +84,7 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]):
num_frames: int,
num_videos: int,
) -> list[npt.NDArray]:
video = np.zeros((num_frames, width, height, 3))
video = np.full((num_frames, width, height, 3), 255)
return [video] * num_videos
......
......@@ -143,10 +143,6 @@ def make_mistral_chat_completion_request(
if last_message["role"] == "assistant":
last_message["prefix"] = True
last_message = cast(Dict[str, Any], messages[-1])
if last_message["role"] == "assistant":
last_message["prefix"] = True
# mistral-common requires AssistantMessage content to be string [1].
#
# [1]: https://github.com/mistralai/mistral-common/blob/f4a06998b75ed78bbf5aaf569590b772ea26c9f6/src/mistral_common/protocol/instruct/messages.py#L80
......
......@@ -2,11 +2,4 @@
from vllm.triton_utils.importing import HAS_TRITON
__all__ = ["HAS_TRITON"]
if HAS_TRITON:
from vllm.triton_utils.custom_cache_manager import (
maybe_set_triton_cache_manager)
__all__ += ["maybe_set_triton_cache_manager"]
__all__ = ["HAS_TRITON"]
\ No newline at end of file
# SPDX-License-Identifier: Apache-2.0
import os
from triton.runtime.cache import (FileCacheManager, default_cache_dir,
default_dump_dir, default_override_dir)
from vllm.logger import init_logger
logger = init_logger(__name__)
def maybe_set_triton_cache_manager() -> None:
"""Set environment variable to tell Triton to use a
custom cache manager"""
cache_manger = os.environ.get("TRITON_CACHE_MANAGER", None)
if cache_manger is None:
manager = "vllm.triton_utils.custom_cache_manager:CustomCacheManager"
logger.info("Setting Triton cache manager to: %s", manager)
os.environ["TRITON_CACHE_MANAGER"] = manager
class CustomCacheManager(FileCacheManager):
"""Re-implements Triton's cache manager, ensuring that a
unique cache directory is created for each process. This is
needed to avoid collisions when running with tp>1 and
using multi-processing as the distributed backend.
Note this issue was fixed by triton-lang/triton/pull/4295,
but the fix is not yet included in triton==v3.0.0. However,
it should be included in the subsequent version.
"""
def __init__(self, key, override=False, dump=False):
self.key = key
self.lock_path = None
if dump:
self.cache_dir = default_dump_dir()
self.cache_dir = os.path.join(self.cache_dir, self.key)
self.lock_path = os.path.join(self.cache_dir, "lock")
os.makedirs(self.cache_dir, exist_ok=True)
elif override:
self.cache_dir = default_override_dir()
self.cache_dir = os.path.join(self.cache_dir, self.key)
else:
# create cache directory if it doesn't exist
self.cache_dir = os.getenv("TRITON_CACHE_DIR",
"").strip() or default_cache_dir()
if self.cache_dir:
self.cache_dir = f"{self.cache_dir}_{os.getpid()}"
self.cache_dir = os.path.join(self.cache_dir, self.key)
self.lock_path = os.path.join(self.cache_dir, "lock")
os.makedirs(self.cache_dir, exist_ok=True)
else:
raise RuntimeError("Could not create or locate cache dir")
......@@ -119,16 +119,21 @@ class Processor:
def _validate_structured_output(self, params: SamplingParams) -> None:
if not params.guided_decoding or not self.decoding_config:
return
if self.decoding_config.guided_decoding_backend != "xgrammar":
raise ValueError(
"Only xgrammar structured output is supported in V1.")
if (params.guided_decoding.backend
and params.guided_decoding.backend != 'xgrammar'):
raise ValueError(
"Only xgrammar structured output is supported in V1.")
if self.vllm_config.speculative_config:
raise ValueError("Structured output is not supported with "
"speculative decoding.")
supported_backends = ["xgrammar"]
engine_level_backend = self.decoding_config.guided_decoding_backend
if engine_level_backend not in supported_backends:
raise ValueError(f"Only {supported_backends} structured output is "
"supported in V1.")
if params.guided_decoding.backend:
if params.guided_decoding.backend != engine_level_backend:
raise ValueError("Request-level structured output backend "
"must match engine-level backend. "
f"{params.guided_decoding.backend}"
f" != {engine_level_backend}")
else:
params.guided_decoding.backend = engine_level_backend
if vllm.platforms.current_platform.is_tpu():
raise ValueError("Structured output is not supported on TPU.")
......
......@@ -46,7 +46,7 @@ class SamplerOutput:
# [num_reqs, max_num_generated_tokens]
# Different requests can have different number of generated tokens.
# All requests are padded to max_num_generated_tokens.
# INVALID_TOKEN_ID (-1 by default) is used for padding.
# PLACEHOLDER_TOKEN_ID (-1 by default) is used for padding.
sampled_token_ids: torch.Tensor
logprobs_tensors: Optional[LogprobsTensors]
......
# SPDX-License-Identifier: Apache-2.0
from typing import Union
import torch
def compiled_softmax(
logits: torch.Tensor,
temperature: Union[float, torch.Tensor] = 1.0,
) -> torch.Tensor:
"""Faster softmax kernel generated by torch.compile.
Args:
logits: [n, vocab_size]
temperature: [n] or float
"""
# NOTE(woosuk): Avoid recompilation by marking the first dim as dynamic.
torch._dynamo.mark_dynamic(logits, index=0)
if isinstance(temperature, torch.Tensor):
torch._dynamo.mark_dynamic(temperature, index=0)
return _softmax(logits, temperature)
@torch.compile
def _softmax(
logits: torch.Tensor,
temperature: Union[float, torch.Tensor],
) -> torch.Tensor:
logits = logits / temperature
return torch.softmax(logits, dim=-1, dtype=torch.float32)
......@@ -3,25 +3,32 @@ from typing import Optional
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
import triton
import triton.language as tl
from vllm.logger import init_logger
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.utils import random_sample
from vllm.v1.sample.ops.utils import compiled_softmax
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
logger = init_logger(__name__)
INVALID_TOKEN_ID = -1
PLACEHOLDER_TOKEN_ID: tl.constexpr = -1
GREEDY_TEMPERATURE: tl.constexpr = -1
# Maximum number of speculative draft tokens allowed per request in a single
# step. This value is chosen to be large enough to handle typical use cases.
MAX_SPEC_LEN = 32
class RejectionSampler(nn.Module):
"""
The implementation strictly follows the algorithm described in
The implementation strictly follows the algorithm described in
https://arxiv.org/abs/2211.17192.
However, we want to clarify the terminology used in the implementation:
accepted tokens: tokens that are accepted based on the relationship
accepted tokens: tokens that are accepted based on the relationship
between the "raw" draft and target probabilities.
recovered tokens: tokens that are sampled based on the adjusted probability
distribution, which is derived from both the draft and target
distribution, which is derived from both the draft and target
probabilities.
bonus tokens:
If all proposed tokens are accepted, the bonus token is added to the
......@@ -31,48 +38,42 @@ class RejectionSampler(nn.Module):
sampling process. For example, we can use top_p, top_k sampling for
bonus tokens, while spec decode does not support these sampling
strategies.
output tokens:
Tokens are finally generated with the rejection sampler.
output tokens:
Tokens are finally generated with the rejection sampler.
output tokens = accepted tokens + recovered tokens + bonus tokens
"""
def __init__(self):
super().__init__()
def forward(
self,
draft_token_ids: list[list[int]],
metadata: SpecDecodeMetadata,
# [num_tokens, vocab_size]
draft_probs: Optional[torch.Tensor],
bonus_token_ids_tensor: torch.Tensor, # [batch_size, 1]
target_probs: torch.Tensor, # [num_total_tokens, vocab_size]
# [num_tokens, vocab_size]
target_logits: torch.Tensor,
# [batch_size, 1]
bonus_token_ids: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
'''
Args:
draft_token_ids (List[List[int]]):
A 2D list of token IDs for each request in the batch.
Each request might have different number of draft tokens.
It may also contain empty lists for requests that have
no draft tokens.
metadata:
Metadata for spec decoding.
draft_probs (Optional[torch.Tensor]):
Probability distribution for the draft tokens. Shape is
[batch_size, max_spec_len, vocab_size]. Can be None if
probabilities are not provided, which is the case for
ngram spec decode.
[num_tokens, vocab_size]. Can be None if probabilities are
not provided, which is the case for ngram spec decode.
target_logits (torch.Tensor):
Target model's logits probability distribution.
Shape is [num_tokens, vocab_size]. Here, probabilities from
different requests are flattened into a single tensor because
this is the shape of the output logits.
bonus_token_ids_tensor (torch.Tensor):
A tensor containing bonus tokens. Shape is [batch_size, 1].
Bonus tokens are added to the end of the sequence if all
proposed tokens are accepted. We generate the bonus tokens
outside of the rejection sampler with the default sampling
strategy. It allows for more flexibility in the sampling
A tensor containing bonus tokens. Shape is [batch_size, 1].
Bonus tokens are added to the end of the sequence if all
proposed tokens are accepted. We generate the bonus tokens
outside of the rejection sampler with the default sampling
strategy. It allows for more flexibility in the sampling
process such as top_p, top_k sampling.
target_probs (torch.Tensor):
Target model probability distribution.
Shape is [num_total_tokens, vocab_size]. num_total_tokens
is the total number of tokens from all requests. Here,
probabilities from different requests are flattened into
a single tensor because this is the shape of the output
logits.
sampling_metadata (SamplingMetadata):
Additional metadata needed for sampling, such as temperature,
top-k/top-p parameters, or other relevant information.
......@@ -80,268 +81,481 @@ class RejectionSampler(nn.Module):
output_token_ids (torch.Tensor):
A tensor containing the final output token IDs.
'''
# NOTE: The following input preparationg can be moved
# to the model runner with a persistent manner for better
# performance.
# Convert draft token IDs to a tensor, split by sample_lens, then pad.
draft_token_ids = [
torch.tensor(x, dtype=int, device='cpu') for x in draft_token_ids
]
draft_token_ids_tensor = pad_sequence(draft_token_ids,
batch_first=True,
padding_value=INVALID_TOKEN_ID)
# NOTE: CPU <-> GPU synchronization happens here.
draft_token_ids_tensor = draft_token_ids_tensor.to(target_probs.device)
# Create one-hot tensor for draft token ids.
# This is used for ngram where we don't have draft_probs.
if draft_probs is None and not sampling_metadata.all_greedy:
vocab_size = target_probs.size(-1)
draft_probs = _create_greedy_token_probs(draft_token_ids_tensor,
vocab_size,
target_probs.device)
sample_lens = [len(x) + 1 for x in draft_token_ids]
target_probs = _convert_2d_probs(target_probs, sample_lens)
return self.forward_native(draft_token_ids_tensor, draft_probs,
bonus_token_ids_tensor, target_probs,
sampling_metadata)
# TODO: The following method can be optimized for better performance.
def forward_native(
self,
draft_token_ids_tensor: torch.Tensor,
# [batch_size, max_spec_len, vocab_size]
draft_probs: Optional[torch.Tensor],
bonus_token_ids_tensor: torch.Tensor,
# [batch_size, max_spec_len + 1, vocab_size]
target_probs: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
# Add 1 to include the 'bonus' token.
if sampling_metadata.all_greedy:
# Produce a mask that remains 1 (True) until the first
# mismatch (cumprod turns 0 after a mismatch).
target_token_ids_tensor = target_probs.argmax(dim=-1)
accept_mask = (target_token_ids_tensor[:, :-1] ==
draft_token_ids_tensor).cumprod(dim=1)
# Identify valid positions (non-padding).
valid_mask = target_token_ids_tensor != INVALID_TOKEN_ID
# Generate mask with bonus token.
generate_mask = torch.cat([
accept_mask,
torch.zeros(accept_mask.size(0), 1, device=accept_mask.device)
],
dim=1).to(torch.bool) & valid_mask
zeros_mask = (generate_mask == 0)
first_zero_idx = zeros_mask.float().argmax(dim=1)
# Figure out which rows actually contain at least one zero.
rows_with_zero = zeros_mask.any(dim=1)
# Use indexing to set the first zero in each of those rows to 1.
generate_mask[rows_with_zero, first_zero_idx[rows_with_zero]] = 1
output_token_ids = target_token_ids_tensor
output_token_ids[~generate_mask] = INVALID_TOKEN_ID
else:
# Reference: https://arxiv.org/pdf/2211.17192
# 1. Extract the probabilities of the draft tokens.
# [batch_size, max_spec_len]
batch_size = draft_token_ids_tensor.size(0)
max_spec_len = draft_token_ids_tensor.size(1)
invalid_idx = draft_token_ids_tensor == INVALID_TOKEN_ID
draft_token_ids_tensor[invalid_idx] = 0
assert draft_probs is not None
draft_token_probs = draft_probs.gather(
dim=-1, index=draft_token_ids_tensor.unsqueeze(-1)).squeeze(-1)
target_token_probs = target_probs.gather(
dim=-1, index=draft_token_ids_tensor.unsqueeze(-1)).squeeze(-1)
# Force the probabilities of invalid tokens to inf
# so that they are not accepted.
draft_token_probs[invalid_idx] = float('inf')
# 2. Generate uniform samples.
# [batch_size, max_spec_len + 1]
uniform_samples = _create_uniform_samples(
sampling_metadata.generators, batch_size, max_spec_len,
target_probs.device)
# 3. Accept or reject the samples.
# [batch_size, max_spec_len]
# If the draft token probabilities are 0, set them to the smallest
# positive normal value representable by float32.
safe_draft_probs = torch.where(draft_token_probs > 0,
draft_token_probs,
torch.finfo(torch.float32).tiny)
accepted = uniform_samples <= target_token_probs / safe_draft_probs
accept_mask = accepted.cumprod(dim=1)
# Set the token ids to the draft token ids if accepted, otherwise
# set them to INVALID_TOKEN_ID.
accepted_token_ids = (draft_token_ids_tensor * accept_mask +
INVALID_TOKEN_ID * (1 - accept_mask))
# 4. Adjust the distribution for the recovered tokens.
# Clamp the bonus probabilities to the smallest positive normal
# value representable by float32.
bonus_prob = torch.clamp(target_probs[:, :-1, :] - draft_probs,
min=torch.finfo(torch.float32).tiny)
normalized_bonus_prob = bonus_prob / bonus_prob.sum(dim=-1,
keepdim=True)
# 5. Sample recovered token ids.
recovered_token_ids = random_sample(
normalized_bonus_prob,
sampling_metadata.generators).reshape(batch_size, max_spec_len)
# 6. Get the final output token ids.
# output_token_ids = accepted_token_ids +
# recovered_token_ids +
# bonus_token_id
recovered_bonus_token_ids = torch.cat(
[recovered_token_ids, bonus_token_ids_tensor], dim=1)
# Generate mask with bonus tokens.
generate_mask = torch.cat([
accept_mask,
torch.zeros(batch_size, 1, device=accept_mask.device)
],
dim=1).to(torch.bool)
zeros_mask = (generate_mask == 0)
first_zero_idx = zeros_mask.float().argmax(dim=1)
output_token_ids = torch.cat([
accepted_token_ids,
torch.full((batch_size, 1),
fill_value=INVALID_TOKEN_ID,
device=accept_mask.device)
],
dim=1)
output_token_ids[torch.arange(batch_size),
first_zero_idx] = recovered_bonus_token_ids[
torch.arange(batch_size), first_zero_idx]
assert metadata.max_spec_len <= MAX_SPEC_LEN
# [num_tokens, vocab_size]
target_probs = compute_probs(
target_logits,
metadata.cu_num_draft_tokens,
sampling_metadata,
)
output_token_ids = rejection_sample(
metadata.draft_token_ids,
metadata.num_draft_tokens,
metadata.max_spec_len,
metadata.cu_num_draft_tokens,
draft_probs,
target_probs,
bonus_token_ids,
sampling_metadata,
)
return output_token_ids
def compute_probs(self, logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
sample_lens: list[int]) -> torch.Tensor:
"""
Compute probability distribution from logits based on sampling metadata.
This function applies temperature scaling to the logits and converts
them to probabilities using softmax. Note that division by
temperature is not performed inplace to preserve the original logits
tensor, which will be used by the original sampler to get bonus tokens.
Args:
logits: Input logits tensor to be converted to probabilities
sampling_metadata: Metadata containing sampling parameters such
as temperature and whether greedy sampling is used
sample_lens: List of sample lengths used for repeating
temperature values
Returns:
torch.Tensor: Probability distribution (softmax of scaled logits)
if non-greedy sampling is used, otherwise returns the
original logits
"""
@staticmethod
def parse_output(
output_token_ids: torch.Tensor,
vocab_size: int,
) -> list[list[int]]:
output_token_ids_np = output_token_ids.cpu().numpy()
# Create mask for valid tokens.
valid_mask = ((output_token_ids_np != PLACEHOLDER_TOKEN_ID) &
(output_token_ids_np < vocab_size))
outputs = [
row[valid_mask[i]].tolist()
for i, row in enumerate(output_token_ids_np)
]
return outputs
def rejection_sample(
# [num_tokens]
draft_token_ids: torch.Tensor,
# [batch_size]
num_draft_tokens: list[int],
max_spec_len: int,
# [batch_size]
cu_num_draft_tokens: torch.Tensor,
# [num_tokens, vocab_size]
draft_probs: Optional[torch.Tensor],
# [num_tokens, vocab_size]
target_probs: torch.Tensor,
# [batch_size, 1]
bonus_token_ids: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
assert draft_token_ids.ndim == 1
assert draft_probs is None or draft_probs.ndim == 2
assert cu_num_draft_tokens.ndim == 1
assert target_probs.ndim == 2
batch_size = len(num_draft_tokens)
num_tokens = draft_token_ids.shape[0]
vocab_size = target_probs.shape[-1]
device = target_probs.device
assert draft_token_ids.is_contiguous()
assert draft_probs is None or draft_probs.is_contiguous()
assert target_probs.is_contiguous()
assert bonus_token_ids.is_contiguous()
assert target_probs.shape == (num_tokens, vocab_size)
# Create output buffer.
output_token_ids = torch.empty(
(batch_size, max_spec_len + 1),
dtype=torch.int32, # Consistent with SamplerOutput.sampled_token_ids.
device=device,
)
output_token_ids.fill_(PLACEHOLDER_TOKEN_ID)
if sampling_metadata.all_greedy:
is_greedy = None
else:
is_greedy = sampling_metadata.temperature == GREEDY_TEMPERATURE
if not sampling_metadata.all_random:
# Rejection sampling for greedy sampling requests.
target_argmax = target_probs.argmax(dim=-1)
rejection_greedy_sample_kernel[(batch_size, )](
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
target_argmax,
bonus_token_ids,
is_greedy,
max_spec_len,
num_warps=1,
)
if sampling_metadata.all_greedy:
return logits
assert sampling_metadata.temperature is not None
# We should optimize the following code as
# it will cause CPU -> GPU synchronization.
temperature = torch.repeat_interleave(
sampling_metadata.temperature,
torch.tensor(sample_lens,
device=sampling_metadata.temperature.device))
temperature = temperature.unsqueeze(dim=1)
logits = logits / temperature
return logits.softmax(dim=-1, dtype=torch.float32)
def _create_greedy_token_probs(
token_ids: torch.Tensor,
vocab_size: int,
out_device: torch.device,
return output_token_ids
# Generate uniform probabilities for rejection sampling.
# [num_tokens]
uniform_probs = generate_uniform_probs(
num_tokens,
num_draft_tokens,
sampling_metadata.generators,
device,
)
# Sample recovered tokens for each position.
# [num_tokens]
recovered_token_ids = sample_recovered_tokens(
max_spec_len,
num_draft_tokens,
cu_num_draft_tokens,
draft_token_ids,
draft_probs,
target_probs,
sampling_metadata,
device,
)
# Rejection sampling for random sampling requests.
rejection_random_sample_kernel[(batch_size, )](
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
draft_probs,
target_probs,
bonus_token_ids,
recovered_token_ids,
uniform_probs,
is_greedy,
max_spec_len,
vocab_size,
IS_NGRAM=draft_probs is None,
num_warps=1,
)
return output_token_ids
def compute_probs(
logits: torch.Tensor, # [num_tokens, vocab_size]
cu_num_draft_tokens: torch.Tensor, # [batch_size]
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
batch_size, num_tokens = token_ids.shape
token_probs = torch.zeros(batch_size,
num_tokens,
vocab_size,
dtype=torch.float,
device=out_device)
# Ignore INVALID_TOKEN_ID.
valid_mask = (token_ids != INVALID_TOKEN_ID)
valid_indices = token_ids.clone()
valid_indices[~valid_mask] = 0
token_probs.scatter_(dim=2,
index=valid_indices.unsqueeze(-1),
src=valid_mask.unsqueeze(-1).float())
return token_probs
def _convert_2d_probs(
probs: torch.Tensor, # [num_total_tokens, vocab_size]
sample_lens: list[int]) -> torch.Tensor:
"""Compute probability distribution from logits based on sampling metadata.
This function applies temperature scaling to the logits and converts
them to probabilities using softmax. For greedy decoding, it returns
the original logits.
Args:
logits: Input logits tensor to be converted to probabilities.
cu_num_draft_tokens: Cumulative number of draft tokens.
sampling_metadata: Metadata containing sampling parameters such as
temperature and whether greedy sampling is used.
Returns:
torch.Tensor: Probability distribution (softmax of scaled logits)
if non-greedy sampling is used, otherwise returns the
original logits.
"""
Converts a 2D tensor of probabilities to a 3D tensor with padding.
[num_total_tokens, vocab_size] ->
[batch_size, max_spec_len + 1, vocab_size]
assert logits.ndim == 2
assert cu_num_draft_tokens.ndim == 1
if sampling_metadata.all_greedy:
return logits
num_tokens = logits.shape[0]
batch_size = cu_num_draft_tokens.shape[0]
expanded_temperature = torch.empty(
(num_tokens, 1),
dtype=torch.float32,
device=logits.device,
)
expand_kernel[(batch_size, )](
expanded_temperature,
sampling_metadata.temperature,
cu_num_draft_tokens,
GREEDY_TEMPERATURE, # replace_from
1, # replace_to
MAX_NUM_TOKENS=MAX_SPEC_LEN,
num_warps=1,
)
output_prob = compiled_softmax(logits, expanded_temperature)
return output_prob
def generate_uniform_probs(
num_tokens: int,
num_draft_tokens: list[int],
generators: dict[int, torch.Generator],
device: torch.device,
) -> torch.Tensor:
"""
cumulative_lens = torch.cumsum(torch.tensor(sample_lens,
device=probs.device),
dim=0)
split_indices = cumulative_lens[:-1].tolist() # Exclude last index
# Split into chunks without loops
chunks = torch.tensor_split(probs, split_indices, dim=0)
# Pad all sequences to maximum length
padded_probs = pad_sequence(chunks, batch_first=True, padding_value=0.0)
return padded_probs
def _create_uniform_samples(seeded_seqs: dict[int, torch.Generator],
batch_size: int, k: int,
device: torch.device) -> torch.Tensor:
Generates a batch of uniform random samples, with optional seeding
if available.
This method creates a tensor of shape `(num_tokens, )` filled
with uniform random values in the range [0, 1). If `generators` is provided,
the requests with their own seeds will use the provided `torch.Generator`
for reproducibility. The samples for the other requests will be generated
without a seed.
Args:
num_tokens : int
Total number of tokens.
num_draft_tokens : List[List[int]]
Number of draft tokens per request.
generators : Optional[Dict[int, torch.Generator]]
A dictionary mapping indices in the batch to
`torch.Generator` objects.
device : torch.device
The device on which to allocate the tensor.
Returns:
uniform_rand : torch.Tensor
A tensor of shape `(num_tokens, )` containing uniform
random values in the range [0, 1).
"""
Generates a batch of uniform random samples, with optional seeding
for specific sequences.
This method creates a tensor of shape `(batch_size, k)` filled
with uniform random values in the range [0, 1). If `seeded_seqs`
is provided, the sequences corresponding to specific indices
will be generated using the provided `torch.Generator` for
reproducibility. The other sequences will be generated without
a seed.
Args:
seeded_seqs : Optional[Dict[int, torch.Generator]]
A dictionary mapping indices in the batch to
`torch.Generator` objects.
batch_size : int
The number of sequences to generate.
k : int
The number of random samples per sequence.
device : torch.device
The device on which to allocate the tensor.
Returns:
uniform_rand : torch.Tensor
A tensor of shape `(batch_size, k)` containing uniform
random values in the range [0, 1).
"""
uniform_rand = torch.rand(batch_size,
k,
dtype=torch.float32,
device=device)
# Apply seeded generators only where needed
if seeded_seqs:
for idx, generator in seeded_seqs.items():
uniform_rand[idx].uniform_(0, 1, generator=generator)
return uniform_rand
uniform_probs = torch.rand(
(num_tokens, ),
dtype=torch.float32,
device=device,
)
start_idx = 0
for req_idx, n in enumerate(num_draft_tokens):
# Do not generate random numbers for requests with no draft tokens.
# This can be important for reproducibility.
if n == 0:
continue
end_idx = start_idx + n
generator = generators.get(req_idx)
if generator is not None:
uniform_probs[start_idx:end_idx].uniform_(generator=generator)
start_idx = end_idx
return uniform_probs
def sample_recovered_tokens(
max_spec_len: int,
num_draft_tokens: list[int],
# [batch_size]
cu_num_draft_tokens: torch.Tensor,
# [num_tokens]
draft_token_ids: torch.Tensor,
# [num_tokens, vocab_size]
draft_probs: Optional[torch.Tensor],
# [num_tokens, vocab_size]
target_probs: torch.Tensor,
sampling_metadata: SamplingMetadata,
device: torch.device,
) -> torch.Tensor:
# NOTE(woosuk): Create only one distribution for each request.
batch_size = len(num_draft_tokens)
vocab_size = target_probs.shape[-1]
q = torch.empty(
(batch_size, vocab_size),
dtype=torch.float32,
device=device,
)
q.exponential_()
for i, generator in sampling_metadata.generators.items():
# Do not generate random numbers for requests with no draft tokens.
# This can be important for reproducibility.
if num_draft_tokens[i] > 0:
q[i].exponential_(generator=generator)
recovered_token_ids = torch.empty_like(draft_token_ids)
sample_recovered_tokens_kernel[(batch_size, max_spec_len)](
recovered_token_ids,
cu_num_draft_tokens,
draft_token_ids,
draft_probs,
target_probs,
q,
vocab_size,
triton.next_power_of_2(vocab_size),
IS_NGRAM=draft_probs is None,
)
return recovered_token_ids
# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation.
@triton.jit(do_not_specialize=["max_spec_len"])
def rejection_greedy_sample_kernel(
output_token_ids_ptr, # [batch_size, max_spec_len + 1]
cu_num_draft_tokens_ptr, # [batch_size]
draft_token_ids_ptr, # [num_tokens]
target_argmax_ptr, # [num_tokens]
bonus_token_ids_ptr, # [batch_size]
is_greedy_ptr, # [batch_size] or None
max_spec_len,
):
req_idx = tl.program_id(0)
# FIXME(woosuk): Because is_greedy_ptr is not None at profiling run,
# re-compilation may happen during runtime when is_greedy_ptr is None.
if is_greedy_ptr is None:
is_greedy = True
else:
is_greedy = tl.load(is_greedy_ptr + req_idx)
if not is_greedy:
# Early exit for non-greedy sampling requests.
return
if req_idx == 0:
start_idx = 0
else:
start_idx = tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
num_draft_tokens = end_idx - start_idx
rejected = False
for pos in range(num_draft_tokens):
if not rejected:
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
target_argmax_id = tl.load(target_argmax_ptr + start_idx + pos)
tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
target_argmax_id)
if draft_token_id != target_argmax_id:
# Reject.
rejected = True
if not rejected:
# If all tokens are accepted, append the bonus token.
bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
tl.store(
output_token_ids_ptr + req_idx * (max_spec_len + 1) +
num_draft_tokens, bonus_token_id)
# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation.
@triton.jit(do_not_specialize=["max_spec_len"])
def rejection_random_sample_kernel(
output_token_ids_ptr, # [batch_size, max_spec_len + 1]
cu_num_draft_tokens_ptr, # [batch_size]
draft_token_ids_ptr, # [num_tokens]
draft_probs_ptr, # [num_tokens, vocab_size] or None
target_probs_ptr, # [num_tokens, vocab_size]
bonus_token_ids_ptr, # [batch_size]
recovered_token_ids_ptr, # [num_tokens]
uniform_probs_ptr, # [num_tokens]
is_greedy_ptr, # [batch_size]
max_spec_len,
vocab_size,
IS_NGRAM: tl.constexpr,
):
req_idx = tl.program_id(0)
is_greedy = tl.load(is_greedy_ptr + req_idx)
if is_greedy:
# Early exit for greedy sampling requests.
return
if req_idx == 0:
start_idx = 0
else:
start_idx = tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
num_draft_tokens = end_idx - start_idx
rejected = False
for pos in range(num_draft_tokens):
if not rejected:
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
if IS_NGRAM:
draft_prob = 1
else:
draft_prob = tl.load(draft_probs_ptr +
(start_idx + pos) * vocab_size +
draft_token_id)
target_prob = tl.load(target_probs_ptr +
(start_idx + pos) * vocab_size +
draft_token_id)
uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
# NOTE(woosuk): While the draft probability should never be 0,
# we check it to avoid NaNs. If it happens to be 0, we reject.
if draft_prob > 0 and target_prob / draft_prob >= uniform_prob:
# Accept.
token_id = draft_token_id
else:
# Reject. Use recovered token.
rejected = True
token_id = tl.load(recovered_token_ids_ptr + start_idx + pos)
tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
token_id)
if not rejected:
# If all tokens are accepted, append the bonus token.
bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
tl.store(
output_token_ids_ptr + req_idx * (max_spec_len + 1) +
num_draft_tokens, bonus_token_id)
# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation.
@triton.jit(do_not_specialize=["replace_from", "replace_to"])
def expand_kernel(
output_ptr, # [num_tokens]
input_ptr, # [batch_size]
cu_num_tokens_ptr, # [batch_size]
replace_from,
replace_to,
MAX_NUM_TOKENS: tl.constexpr,
):
req_idx = tl.program_id(0)
if req_idx == 0: # noqa: SIM108
start_idx = 0
else:
start_idx = tl.load(cu_num_tokens_ptr + req_idx - 1)
end_idx = tl.load(cu_num_tokens_ptr + req_idx)
num_tokens = end_idx - start_idx
src_val = tl.load(input_ptr + req_idx)
src_val = tl.where(src_val == replace_from, replace_to, src_val)
offset = tl.arange(0, MAX_NUM_TOKENS)
tl.store(output_ptr + start_idx + offset,
src_val,
mask=offset < num_tokens)
@triton.jit
def sample_recovered_tokens_kernel(
output_token_ids_ptr, # [num_tokens]
cu_num_draft_tokens_ptr, # [batch_size]
draft_token_ids_ptr, # [num_tokens]
draft_probs_ptr, # [num_tokens, vocab_size] or None
target_probs_ptr, # [num_tokens, vocab_size]
q_ptr, # [batch_size, vocab_size]
vocab_size,
PADDED_VOCAB_SIZE: tl.constexpr,
IS_NGRAM: tl.constexpr,
):
req_idx = tl.program_id(0)
if req_idx == 0:
start_idx = 0
else:
start_idx = tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
num_draft_tokens = end_idx - start_idx
# Early exit for out-of-range positions.
pos = tl.program_id(1)
if pos >= num_draft_tokens:
return
vocab_offset = tl.arange(0, PADDED_VOCAB_SIZE)
if IS_NGRAM:
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
orig_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size +
draft_token_id)
# Temporarily zero out the probability of the draft token.
# This is essentially the same as target_prob - draft_prob, except that
# n-gram does not have draft_prob. We regard it as 1.
tl.store(
target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id,
0)
prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size +
vocab_offset,
mask=vocab_offset < vocab_size,
other=0)
else:
draft_prob = tl.load(draft_probs_ptr + (start_idx + pos) * vocab_size +
vocab_offset,
mask=vocab_offset < vocab_size,
other=0)
target_prob = tl.load(target_probs_ptr +
(start_idx + pos) * vocab_size + vocab_offset,
mask=vocab_offset < vocab_size,
other=0)
prob = tl.maximum(target_prob - draft_prob, 0)
# NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here because
# `tl.argmax` will select the maximum value.
q = tl.load(q_ptr + req_idx * vocab_size + vocab_offset,
mask=vocab_offset < vocab_size,
other=float("-inf"))
recovered_id = tl.argmax(prob / q, axis=-1)
tl.store(output_token_ids_ptr + start_idx + pos, recovered_id)
if IS_NGRAM:
# Restore the original probability.
tl.store(
target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id,
orig_prob)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment