Commit 500b93c8 authored by zhuwenwen
Browse files

Merge tag 'v0.5.3.post1' into v0.5.3.post1-dtk24.04.1

parents 99426767 38c4b7e8
......@@ -43,6 +43,7 @@ from vllm.sequence import IntermediateTensors, SamplerOutput
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
input_processor_for_clip)
from .interfaces import SupportsVision
from .utils import merge_vision_embeddings
logger = init_logger(__name__)
......@@ -71,9 +72,8 @@ CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(dropout=0.0,
class Phi3ImageEmbeddingBase(nn.Module):
def __init__(self, wte=None) -> None:
def __init__(self) -> None:
super().__init__()
self.wte = wte
self.layer_idx: int
self.type_feature: str
self.img_processor: CLIPVisionModel
......@@ -100,10 +100,9 @@ class Phi3ImageEmbeddingBase(nn.Module):
class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
"""Phi3 Image embedding with HD transform."""
def __init__(self, config: PretrainedConfig, wte=None) -> None:
super().__init__(wte)
def __init__(self, config: PretrainedConfig) -> None:
super().__init__()
self.image_token_id = _IMAGE_TOKEN_ID
# n_embed or hidden_size
hidden_size = config.n_embd if hasattr(
config, 'n_embd') else config.hidden_size
......@@ -149,118 +148,115 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
nn.Linear(dim_projection, dim_projection)])
self.img_projection = nn.Sequential(*layers)
self.vocab_size = config.vocab_size
self.type_feature = config.img_processor.get('type_feature', 'patch')
def forward(self, input_ids: torch.LongTensor,
pixel_values: torch.FloatTensor,
def forward(self, pixel_values: torch.FloatTensor,
image_sizes: torch.Tensor) -> torch.FloatTensor:
"""process and merge text embeddings with image embeddings."""
# (batch_size, max_num_crops, 3, height, width)
img_embeds = pixel_values
# (batch_size, 2)
img_sizes = image_sizes
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
positions = torch.nonzero(input_ids == self.image_token_id)
select = False
target_dtype = self.img_projection[0].bias.dtype
if len(positions.tolist()) > 0:
# if self.use_hd_transform and img_sizes:
# img_embeds: (num_images, max_num_crops, 3, H, W)
# img_sizes: (num_images, 2).view(1, -1)
bs = img_embeds.shape[0]
# Nx(HW)xC
img_features = self.get_img_features(img_embeds.flatten(0, 1))
base_feat_height = base_feat_width = int(
img_features.shape[1]**0.5)
# bs x max_num_crops x (24x24) x C
img_features = img_features.view(
bs, -1, base_feat_height * base_feat_width, self.image_dim_out)
C = self.image_dim_out
H = base_feat_height
output_imgs = []
output_len = []
for _bs in range(bs):
h, w = img_sizes[_bs]
h = h // 336
w = w // 336
B_ = h * w
# 1 x (24x24) x 1024
global_img_feature = img_features[_bs, :1]
# 1 x 12 x 12 x 4096
glb_img = global_img_feature \
.reshape(1, H // 2, 2, H // 2, 2,C) \
.permute(0, 1, 3, 2, 4, 5) \
.reshape(1, H // 2, H // 2, 4 * C)
temp_glb_GN = self.sub_GN.repeat(1, H // 2, 1, 1)
# 1 x 156 x 4096
glb_img = torch.cat([glb_img, temp_glb_GN],
dim=2).reshape(1, -1, 4 * C)
# (max_num_crops-1) x (12x12) x C
sub_img = img_features[_bs, 1:]
# 16x574x1024
# get rid of padding sub_img
sub_img = sub_img[:B_]
sub_img = sub_img.reshape(B_, H // 2, 2, H // 2, 2, C) \
.permute(0, 1, 3, 2, 4, 5).reshape(B_, -1, 4 * C)
sub_img = sub_img.reshape(1, h, w, 12, 12, -1) \
.permute(0, 1, 3, 2, 4, 5) \
.reshape(1, h * 12, w * 12, 4 * C)
temp_sub_GN = self.sub_GN.repeat(1, h * 12, 1, 1)
sub_img = torch.cat([sub_img, temp_sub_GN],
dim=2).reshape(1, -1, 4 * C)
# (1, num_img_tokens, 1024*4)
# glb + sub
if self.hd_transform_order == 'glb_sub':
output_imgs.append(
torch.cat([glb_img, self.glb_GN, sub_img], dim=1))
elif self.hd_transform_order == 'sub_glb':
output_imgs.append(
torch.cat([sub_img, self.glb_GN, glb_img], dim=1))
temp_len = int((h * w + 1) * 144 + 1 + (h + 1) * 12)
output_len.append(temp_len)
num_img_tokens = output_len
img_set_tensor = []
for _output_img in output_imgs:
img_feature_proj = self.img_projection(
_output_img.to(target_dtype))
img_set_tensor.append(img_feature_proj)
select = True
input_ids.clamp_min_(0).clamp_max_(self.vocab_size)
hidden_states = self.wte(input_ids)
if select:
idx = 0
for i, cnt in enumerate(num_img_tokens):
hidden_states[positions[idx, 0],
positions[idx, 1]:positions[idx, 1] +
cnt] = (img_set_tensor[i].to(
hidden_states.dtype))
idx += cnt
return hidden_states.squeeze(0)
"""
process image and return vision embeddings.
pixel_values: (num_images, num_crops, c, h, w)
output: (num_images, num_img_tokens, hidden_size)
"""
num_images, num_crops, c, h, w = pixel_values.shape
pixel_values = pixel_values.flatten(0, 1)
img_features = self.get_img_features(pixel_values)
img_features = img_features.reshape(num_images, num_crops, -1,
self.image_dim_out)
image_features_proj = self.hd_feature_transform(
img_features, image_sizes)
return image_features_proj
def hd_feature_transform(self, image_features, image_sizes):
    """
    Rearrange per-crop features into the HD layout and project them.

    For every image the sequence is built as
    [sub-image features + newlines, `glb_GN` separator,
    global features + newlines]; the per-image sequences are then
    stacked and passed through `self.img_projection`.

    image_features: (num_images, num_crops+1, 24*24, 1024)
    image_sizes: one (h, w) pair per image; each side is floor-divided
        by 336 to obtain that image's crop grid
    """
    assert (
        self.hd_transform_order == 'sub_glb'
    ), f'hd_transform_order `{self.hd_transform_order}` not implemented'
    # Take device/dtype from the projection's bias so the stacked
    # features can be converted to match before they are projected.
    if isinstance(self.img_projection, nn.Sequential):
        target_device = self.img_projection[0].bias.device
        target_dtype = self.img_projection[0].bias.dtype
    else:  # It's a single nn.Linear layer
        target_device = self.img_projection.bias.device
        target_dtype = self.img_projection.bias.dtype
    global_image_features = image_features[:,
                                           0]  # (num_images, 24*24, 1024)
    # global feature can be viewed as a special HD case with num_crops 1x1
    global_image_features_hd = self.reshape_hd_patches_2x2merge(
        global_image_features, 1, 1)
    global_image_features_hd_newline = self.add_image_newline(
        global_image_features_hd)
    all_image_embeddings = []
    # need a for loop to process each image because of different image sizes
    # (patch arrangement is different for each image)
    for i, img_size in enumerate(image_sizes):
        h, w = img_size
        # crop grid of this image: number of 336-pixel tiles per side
        h_crop = h // 336
        w_crop = w // 336
        num_crops = h_crop * w_crop
        # NOTE: real num_crops is padded
        # (num_crops, 24*24, 1024)
        # index 0 is the global view, so the real crops start at index 1
        sub_image_features = image_features[i, 1:1 + num_crops]
        sub_image_features_hd = self.reshape_hd_patches_2x2merge(
            sub_image_features, h_crop, w_crop)
        sub_image_features_hd_newline = self.add_image_newline(
            sub_image_features_hd)
        # [sub features, separator, global features]
        all_image_embeddings.append(
            torch.cat([
                sub_image_features_hd_newline.squeeze(
                    0),  # (h_crop*12*(w_crop*12+1), 4096)
                self.glb_GN.squeeze(0),
                global_image_features_hd_newline[i],
            ]))
    # NOTE(review): torch.stack needs every per-image sequence to have the
    # same length, i.e. the same h_crop/w_crop grid for all images in the
    # batch — confirm that callers guarantee this.
    image_features_proj = self.img_projection(
        torch.stack(all_image_embeddings).to(target_device, target_dtype)
    )  # (num_images, h_crop*12*(w_crop*12+1) + 1 + 12*(12+1), hidden_size)
    return image_features_proj
def reshape_hd_patches_2x2merge(self, image_features, h_crop, w_crop):
    """
    Fold each 2x2 neighbourhood of patch features into the channel axis,
    then tile the crops of every image onto a single 2D feature grid.

    image_features: (num_images*num_crops, 24*24, 1024)
    output: (num_images, h_crop*12, w_crop*12, 4096)
    where h_crop*w_crop == num_crops
    """
    num_tiles, seq_len, channels = image_features.shape
    assert (seq_len == 576 and channels == 1024
            and num_tiles % (h_crop * w_crop) == 0)
    image_count = num_tiles // (h_crop * w_crop)
    side = int(seq_len**0.5)
    half = side // 2

    # N, 24, 24, 1024 -> N, 12, 2, 12, 2, 1024 (rows/cols split into pairs)
    cells = image_features.reshape(num_tiles, side, side, channels)
    cells = cells.reshape(num_tiles, half, 2, half, 2, channels)
    # N, 12, 12, 2, 2, 1024 -> N, 144, 4096 (2x2 cell merged into channels)
    merged = cells.permute(0, 1, 3, 2, 4, 5)
    merged = merged.reshape(num_tiles, -1, 4 * channels)
    # n_img, h_crop, w_crop, 12, 12, 4096 -> n_img, h_crop, 12, w_crop, 12, 4096
    tiled = merged.reshape(image_count, h_crop, w_crop, half, half, -1)
    tiled = tiled.permute(0, 1, 3, 2, 4, 5)
    # n_img, h_crop*12, w_crop*12, 4096
    return tiled.reshape(image_count, h_crop * side // 2,
                         w_crop * side // 2, 4 * channels)
def add_image_newline(self, image_features_hd):
    """
    Append the `sub_GN` newline embedding at the end of every feature row
    and flatten the rows into one sequence per image.

    image_features_hd: (num_images, h_crop*12, w_crop*12, 4096)
    output: (num_images, (h_crop*12) * (w_crop*12+1), 4096)
    """
    batch, rows, _cols, width = image_features_hd.shape
    # one newline embedding per row: (n_img, h, 1, hid_dim)
    row_terminators = self.sub_GN.expand(batch, rows, -1, -1)
    with_newlines = torch.cat((image_features_hd, row_terminators), dim=2)
    return with_newlines.reshape(batch, -1, width)
class Phi3VImagePixelInputs(TypedDict):
......@@ -458,12 +454,12 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
self.config = config
self.multimodal_config = multimodal_config
self.image_token_id = _IMAGE_TOKEN_ID
self.model = LlamaModel(config, cache_config, quant_config)
# TODO: Optionally initializes this for supporting embeddings.
self.vision_embed_tokens = Phi3HDImageEmbedding(
config, self.model.embed_tokens)
self.vision_embed_tokens = Phi3HDImageEmbedding(config)
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
......@@ -530,9 +526,12 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None:
inputs_embeds = self.vision_embed_tokens(
input_ids, image_input["data"], image_input["image_sizes"])
vision_embeddings = self.vision_embed_tokens(
image_input["data"], image_input["image_sizes"])
inputs_embeds = self.model.get_input_embeddings(input_ids)
inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds,
vision_embeddings,
self.image_token_id)
input_ids = None
else:
inputs_embeds = None
......
......@@ -45,10 +45,10 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.utils import print_warning_once
from .interfaces import SupportsLoRA
......@@ -392,18 +392,10 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale.
if name.endswith("kv_scale"):
remapped_kv_scale_name = name.replace(
".kv_scale", ".attn.kv_scale")
if remapped_kv_scale_name not in params_dict:
print_warning_once(
f"Found kv scale in the checkpoint (e.g. {name}), "
"but not found the expected name in the model "
f"(e.g. {remapped_kv_scale_name}). kv-scale is "
"not loaded.")
continue
else:
name = remapped_kv_scale_name
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
......
from typing import Callable, Dict, List, Tuple
from typing import Dict, List, Protocol, Tuple
import torch
from torch.func import functional_call
from vllm.multimodal import BatchedTensors
from vllm.utils import is_pin_memory_available
def merge_vision_embeddings(input_ids: torch.Tensor,
......@@ -43,6 +45,15 @@ def merge_vision_embeddings(input_ids: torch.Tensor,
return inputs_embeds
class LayerFn(Protocol):
    """Structural type for a layer factory.

    Any callable that accepts a `prefix` keyword argument and returns a
    `torch.nn.Module` satisfies this protocol.
    """

    def __call__(
        self,
        prefix: str = "",
    ) -> torch.nn.Module:
        ...
class PPMissingLayer(torch.nn.Identity):
"""
A placeholder layer for missing layers in a pipeline parallel model.
......@@ -52,8 +63,74 @@ class PPMissingLayer(torch.nn.Identity):
super().__init__()
_CPU_OFFLOAD_BYTES = 0
_CPU_OFFLOAD_MAX_BYTES = 0
def set_cpu_offload_max_bytes(max_bytes: int) -> None:
    """Set the CPU-offload byte budget and reset the offloaded-bytes counter."""
    global _CPU_OFFLOAD_BYTES, _CPU_OFFLOAD_MAX_BYTES
    # New budget first, then restart the accounting from zero.
    _CPU_OFFLOAD_MAX_BYTES = max_bytes
    _CPU_OFFLOAD_BYTES = 0
def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
    """Offload a module's parameters to CPU within the global byte budget.

    Parameters are copied to CPU one at a time until the running total
    `_CPU_OFFLOAD_BYTES` reaches `_CPU_OFFLOAD_MAX_BYTES`. The module's
    `forward` is then replaced by a wrapper that moves the state back to
    the original device for each call.

    Args:
        module: the module to (partially) offload.

    Returns:
        The same module object. It is returned untouched when it has no
        parameters, already lives on CPU, or the budget is exhausted.
    """
    first_param = next(module.parameters(), None)
    if first_param is None:
        # A parameter-less module has nothing to offload; a bare
        # `next(...)` would raise StopIteration here.
        return module
    device = first_param.device

    if device == torch.device("cpu"):
        return module

    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
    if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
        return module

    pin_memory = is_pin_memory_available()

    # offload parameters to CPU
    # use pin_memory if possible, which helps cudagraph capture speed
    for p in module.parameters():
        if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
            # we use per-parameter offloading
            # one module might have some parameters offloaded and some not
            break

        # `torch.empty_like` does not support `pin_memory` argument
        cpu_data = torch.empty(size=p.data.size(),
                               dtype=p.data.dtype,
                               layout=p.data.layout,
                               device='cpu',
                               pin_memory=pin_memory)
        cpu_data.copy_(p.data)
        p.data = cpu_data
        _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()

    state_dict: Dict[str, torch.Tensor] = module.state_dict()
    original_forward = module.forward

    def forward(*args, **kwargs):
        # Put the real forward back so that `functional_call` runs the
        # module's own forward instead of re-entering this wrapper.
        module.forward = original_forward
        try:
            device_state = {
                # here we blindly call `to(device)`
                # if the parameter is already on the device, it will be a no-op
                k: v.to(device, non_blocking=True)
                for k, v in state_dict.items()
            }
            return functional_call(module,
                                   device_state,
                                   args=args,
                                   kwargs=kwargs)
        finally:
            # Reinstall the wrapper even when the call raises; otherwise
            # later calls would run the original forward directly on the
            # offloaded parameters.
            module.forward = forward

    module.forward = forward
    return module
def make_layers(
num_hidden_layers: int, layer_fn: Callable[[], torch.nn.Module]
num_hidden_layers: int,
layer_fn: LayerFn,
prefix: str,
) -> Tuple[int, int, torch.nn.ModuleList]:
"""Make a list of layers with the given layer function, taking
pipeline parallelism into account.
......@@ -64,9 +141,10 @@ def make_layers(
get_pp_group().rank_in_group,
get_pp_group().world_size)
modules = torch.nn.ModuleList(
[PPMissingLayer() for _ in range(start_layer)] +
[layer_fn() for _ in range(start_layer, end_layer)] +
[PPMissingLayer() for _ in range(end_layer, num_hidden_layers)])
[PPMissingLayer() for _ in range(start_layer)] + [
maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}"))
for idx in range(start_layer, end_layer)
] + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)])
return start_layer, end_layer, modules
......
......@@ -8,7 +8,7 @@ from vllm.model_executor.layers.ops.sample import get_num_triton_sampler_splits
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SequenceData, SequenceGroupMetadata
from vllm.utils import (async_tensor_h2d, is_pin_memory_available,
maybe_expand_dim)
make_tensor_with_pad, maybe_expand_dim)
_SAMPLING_EPS = 1e-5
_SEED_0_REPLACEMENT = 3403598558
......@@ -86,6 +86,12 @@ class SamplingMetadata:
The first tuple is [1, 2] (sampled index within original logit),
and the second tuple is [0, 1] (sampled index within pruned logit).
num_prompts: Number of prompt sequence groups in seq_groups.
skip_sampler_cpu_output: Indicates if we want to skip the GPU=>CPU
serialization of token outputs.
reuse_sampling_tensors: Indicates if we want to reuse sampling
tensors that are part of the sampler forward pass. Currently,
it is mainly used for multi-step decode.
"""
def __init__(
......@@ -94,11 +100,15 @@ class SamplingMetadata:
selected_token_indices: torch.Tensor,
categorized_sample_indices: Dict[SamplingType, torch.Tensor],
num_prompts: int,
skip_sampler_cpu_output: bool = False,
reuse_sampling_tensors: bool = False,
) -> None:
self.seq_groups = seq_groups
self.selected_token_indices = selected_token_indices
self.categorized_sample_indices = categorized_sample_indices
self.num_prompts = num_prompts
self.skip_sampler_cpu_output = skip_sampler_cpu_output
self.reuse_sampling_tensors = reuse_sampling_tensors
@staticmethod
def prepare(
......@@ -455,18 +465,24 @@ class SamplingTensors:
do_penalties = prompt_tokens or output_tokens
if do_penalties:
prompt_max_len = max([len(tokens) for tokens in prompt_tokens],
default=0)
prompt_padded_tokens = [
tokens + [vocab_size] * (prompt_max_len - len(tokens))
for tokens in prompt_tokens
]
output_max_len = max([len(tokens) for tokens in output_tokens],
default=0)
output_padded_tokens = [
tokens + [vocab_size] * (output_max_len - len(tokens))
for tokens in output_tokens
]
prompt_t = make_tensor_with_pad(
prompt_tokens,
vocab_size,
device="cpu",
dtype=torch.int64,
pin_memory=pin_memory,
)
output_t = make_tensor_with_pad(
output_tokens,
vocab_size,
device="cpu",
dtype=torch.int64,
pin_memory=pin_memory,
)
else:
empty_tensor = torch.empty(0, device=device, dtype=torch.long)
prompt_t = empty_tensor
output_t = empty_tensor
temperatures_t = torch.tensor(
temperatures,
......@@ -516,22 +532,6 @@ class SamplingTensors:
dtype=torch.long,
pin_memory=pin_memory,
)
if do_penalties:
prompt_tensor = torch.tensor(
prompt_padded_tokens,
device="cpu",
dtype=torch.long,
pin_memory=pin_memory,
)
output_tensor = torch.tensor(
output_padded_tokens,
device="cpu",
dtype=torch.long,
pin_memory=pin_memory,
)
else:
prompt_tensor = None
output_tensor = None
# need to transpose and make contiguous to
# copy the tensor correctly.
# [batch_size, n_seeds] -> [n_seeds, batch_size]
......@@ -554,16 +554,6 @@ class SamplingTensors:
extra_seeds_gpu = None
sampling_seeds_gpu = sampling_seeds_gpu[:num_base_seeds]
if do_penalties:
prompt_tokens_gpu = prompt_tensor.to(device=device,
non_blocking=True)
output_tokens_gpu = output_tensor.to(device=device,
non_blocking=True)
else:
empty_tensor = torch.empty(0, device=device, dtype=torch.long)
prompt_tokens_gpu = empty_tensor
output_tokens_gpu = empty_tensor
return cls(
temperatures=temperatures_t.to(device=device, non_blocking=True),
top_ps=top_ps_t.to(device=device, non_blocking=True),
......@@ -575,8 +565,8 @@ class SamplingTensors:
non_blocking=True),
repetition_penalties=repetition_penalties_t.to(device=device,
non_blocking=True),
prompt_tokens=prompt_tokens_gpu,
output_tokens=output_tokens_gpu,
prompt_tokens=prompt_t.to(device=device, non_blocking=True),
output_tokens=output_t.to(device=device, non_blocking=True),
sampling_seeds=sampling_seeds_gpu,
sample_indices=sample_indices_t.to(device=device,
non_blocking=True),
......
import base64
from io import BytesIO
from typing import Optional, Union
from urllib.parse import urlparse
from typing import Union
import aiohttp
import requests
from PIL import Image
from vllm.connections import global_http_connection
from vllm.envs import VLLM_IMAGE_FETCH_TIMEOUT
from vllm.multimodal.base import MultiModalDataDict
from vllm.version import __version__ as VLLM_VERSION
def _validate_remote_url(url: str, *, name: str):
parsed_url = urlparse(url)
if parsed_url.scheme not in ["http", "https"]:
raise ValueError(f"Invalid '{name}': A valid '{name}' "
"must have scheme 'http' or 'https'.")
def _get_request_headers():
    """Return HTTP request headers with a User-Agent of `vLLM/<version>`."""
    return {"User-Agent": f"vLLM/{VLLM_VERSION}"}
def _load_image_from_bytes(b: bytes):
......@@ -42,13 +28,8 @@ def fetch_image(image_url: str, *, image_mode: str = "RGB") -> Image.Image:
By default, the image is converted into RGB format.
"""
if image_url.startswith('http'):
_validate_remote_url(image_url, name="image_url")
headers = _get_request_headers()
with requests.get(url=image_url, headers=headers) as response:
response.raise_for_status()
image_raw = response.content
image_raw = global_http_connection.get_bytes(
image_url, timeout=VLLM_IMAGE_FETCH_TIMEOUT)
image = _load_image_from_bytes(image_raw)
elif image_url.startswith('data:image'):
......@@ -60,55 +41,30 @@ def fetch_image(image_url: str, *, image_mode: str = "RGB") -> Image.Image:
return image.convert(image_mode)
class ImageFetchAiohttp:
aiohttp_client: Optional[aiohttp.ClientSession] = None
@classmethod
def get_aiohttp_client(cls) -> aiohttp.ClientSession:
if cls.aiohttp_client is None:
timeout = aiohttp.ClientTimeout(total=VLLM_IMAGE_FETCH_TIMEOUT)
connector = aiohttp.TCPConnector()
cls.aiohttp_client = aiohttp.ClientSession(timeout=timeout,
connector=connector)
return cls.aiohttp_client
@classmethod
async def fetch_image(
cls,
image_url: str,
*,
image_mode: str = "RGB",
) -> Image.Image:
"""
Asynchronously load a PIL image from a HTTP or base64 data URL.
By default, the image is converted into RGB format.
"""
if image_url.startswith('http'):
_validate_remote_url(image_url, name="image_url")
client = cls.get_aiohttp_client()
headers = _get_request_headers()
async def async_fetch_image(image_url: str,
*,
image_mode: str = "RGB") -> Image.Image:
"""
Asynchronously load a PIL image from a HTTP or base64 data URL.
async with client.get(url=image_url, headers=headers) as response:
response.raise_for_status()
image_raw = await response.read()
image = _load_image_from_bytes(image_raw)
By default, the image is converted into RGB format.
"""
if image_url.startswith('http'):
image_raw = await global_http_connection.async_get_bytes(
image_url, timeout=VLLM_IMAGE_FETCH_TIMEOUT)
image = _load_image_from_bytes(image_raw)
elif image_url.startswith('data:image'):
image = _load_image_from_data_url(image_url)
else:
raise ValueError(
"Invalid 'image_url': A valid 'image_url' must start "
"with either 'data:image' or 'http'.")
elif image_url.startswith('data:image'):
image = _load_image_from_data_url(image_url)
else:
raise ValueError("Invalid 'image_url': A valid 'image_url' must start "
"with either 'data:image' or 'http'.")
return image.convert(image_mode)
return image.convert(image_mode)
async def async_get_and_parse_image(image_url: str) -> MultiModalDataDict:
image = await ImageFetchAiohttp.fetch_image(image_url)
image = await async_fetch_image(image_url)
return {"image": image}
......
......@@ -2,7 +2,9 @@ from typing import Optional
import torch
from .interface import Platform, PlatformEnum
from vllm.utils import is_tpu
from .interface import Platform, PlatformEnum, UnspecifiedPlatform
current_platform: Optional[Platform]
......@@ -12,7 +14,10 @@ if torch.version.cuda is not None:
elif torch.version.hip is not None:
from .rocm import RocmPlatform
current_platform = RocmPlatform()
elif is_tpu():
from .tpu import TpuPlatform
current_platform = TpuPlatform()
else:
current_platform = None
current_platform = UnspecifiedPlatform()
__all__ = ['Platform', 'PlatformEnum', 'current_platform']
import enum
from typing import Tuple
import torch
class PlatformEnum(enum.Enum):
    """Identifiers that a `Platform` subclass stores in `_enum`."""
    CUDA = enum.auto()
    ROCM = enum.auto()
    TPU = enum.auto()
    # Used by `UnspecifiedPlatform`.
    UNSPECIFIED = enum.auto()
class Platform:
......@@ -16,6 +20,23 @@ class Platform:
def is_rocm(self) -> bool:
    """Return True when this platform's `_enum` is `PlatformEnum.ROCM`."""
    return self._enum == PlatformEnum.ROCM
def is_tpu(self) -> bool:
    """Return True when this platform's `_enum` is `PlatformEnum.TPU`."""
    return self._enum == PlatformEnum.TPU
@staticmethod
def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
    """Return the capability of device `device_id` as a pair of ints.

    The base class does not implement this and raises
    NotImplementedError; subclasses are expected to override it.
    """
    # NOTE(review): presumably the pair is (major, minor) — confirm
    # against the CUDA/ROCm implementations.
    raise NotImplementedError
@staticmethod
def inference_mode():
    """A device-specific wrapper of `torch.inference_mode`.

    This wrapper is recommended because some hardware backends such as TPU
    do not support `torch.inference_mode`. In such a case, they will fall
    back to `torch.no_grad` by overriding this method.

    The default implementation returns `torch.inference_mode(mode=True)`.
    """
    return torch.inference_mode(mode=True)
class UnspecifiedPlatform(Platform):
    """`Platform` whose `_enum` is `PlatformEnum.UNSPECIFIED`."""
    _enum = PlatformEnum.UNSPECIFIED
from typing import Tuple
import torch
from .interface import Platform, PlatformEnum
class TpuPlatform(Platform):
    """`Platform` implementation that identifies itself as TPU."""
    _enum = PlatformEnum.TPU

    @staticmethod
    def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
        """Always raise RuntimeError: no device capability exists for TPU."""
        raise RuntimeError("TPU does not have device capability.")

    @staticmethod
    def inference_mode():
        """Return `torch.no_grad()` as the inference context for TPU."""
        # NOTE(review): presumably used because `torch.inference_mode` is
        # not supported on TPU — confirm.
        return torch.no_grad()
......@@ -8,7 +8,6 @@ import torch
from pydantic import Field
from typing_extensions import Annotated
import vllm.envs as envs
from vllm.logger import init_logger
logger = init_logger(__name__)
......@@ -189,18 +188,6 @@ class SamplingParams:
self._verify_args()
if self.use_beam_search:
# Lazy import to avoid circular imports.
from vllm.usage.usage_lib import set_runtime_usage_data
set_runtime_usage_data("use_beam_search", True)
if not envs.VLLM_NO_DEPRECATION_WARNING:
logger.warning(
"[IMPORTANT] We plan to discontinue the support for beam "
"search in the next major release. Please refer to "
"https://github.com/vllm-project/vllm/issues/6226 for "
"more information. Set VLLM_NO_DEPRECATION_WARNING=1 to "
"suppress this warning.")
self._verify_beam_search()
else:
self._verify_non_beam_search()
......
......@@ -5,7 +5,8 @@ import math
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union
from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple,
Union)
import torch
......@@ -438,7 +439,7 @@ class SequenceGroup:
embeddings: Optional[List[float]] = None,
pooling_params: Optional[PoolingParams] = None,
encoder_seq: Optional[Sequence] = None,
trace_headers: Optional[Dict[str, str]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None:
self.request_id = request_id
......@@ -457,24 +458,25 @@ class SequenceGroup:
self.prompt_adapter_request = prompt_adapter_request
self.encoder_seq = encoder_seq
self.trace_headers = trace_headers
self._first_seq = next(iter(self.seqs_dict.values()))
@property
def prompt(self) -> Optional[str]:
# All sequences in the group should have the same prompt.
# We use the prompt of an arbitrary sequence.
return next(iter(self.seqs_dict.values())).prompt
return self._first_seq.prompt
@property
def prompt_token_ids(self) -> List[int]:
# All sequences in the group should have the same prompt.
# We use the prompt of an arbitrary sequence.
return next(iter(self.seqs_dict.values())).prompt_token_ids
return self._first_seq.prompt_token_ids
@property
def multi_modal_data(self) -> "MultiModalDataDict":
# All sequences in the group should have the same multi-modal data.
# We use the multi-modal data of an arbitrary sequence.
return next(iter(self.seqs_dict.values())).multi_modal_data
return self._first_seq.multi_modal_data
@property
def lora_int_id(self) -> int:
......
......@@ -4,7 +4,8 @@ from typing import Iterator, List, Tuple
import torch
from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceData,
SequenceGroupMetadata, get_all_seq_ids)
SequenceGroupMetadata, SequenceGroupState,
get_all_seq_ids)
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.util import (nvtx_range, sampler_output_to_torch,
......@@ -292,6 +293,15 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
for data in new_seq_data_dict.values():
data.update_num_computed_tokens(data.get_len() - 1)
if (seq_group_metadata.state is not None
and seq_group_metadata.state.generator is not None):
generator = torch.Generator(
device=seq_group_metadata.state.generator.device)
generator.set_state(seq_group_metadata.state.generator.get_state())
state = SequenceGroupState(generator=generator)
else:
state = None
return SequenceGroupMetadata(
request_id=seq_group_metadata.request_id,
is_prompt=seq_group_metadata.is_prompt,
......@@ -302,6 +312,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
},
lora_request=None,
token_chunk_size=1,
state=state,
)
def _split_scoring_output(
......
......@@ -2,17 +2,33 @@ from typing import List, Optional
import torch
from vllm import _custom_ops as ops
try:
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
except ModuleNotFoundError:
# vllm_flash_attn is not installed, use the identical ROCm FA metadata
from vllm.attention.backends.rocm_flash_attn import (
ROCmFlashAttentionMetadata as FlashAttentionMetadata)
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig)
from vllm.logger import init_logger
from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata)
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
SamplerOutput)
from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata,
ModelRunner)
logger = init_logger(__name__)
# A flag to enable debug prints for the updated input tensors
# before each step.
debug_advance_input = False
# A flag to allow GPU advance step for draft model runner.
# Set to False for debugging.
allow_gpu_advance_step = True
class TP1DraftModelRunner(ModelRunner):
"""Specialized model runner for speculative decoding draft model.
......@@ -21,18 +37,9 @@ class TP1DraftModelRunner(ModelRunner):
we could get rid of most CPU-GPU synchronization and data transfer
overheads by keeping model input and output tensors on GPU all the time.
This runner is still under development so there's no performance gain
at this moment. Currently we adopt a temporary solution that caches the
seq_group_metadata_list for multi-step execution, so that we can
leverage existing prepare_model_input to be compatible with the current
execution flow, but we plan to remove this cache and avoid calling
prepare_model_input in execute_model at all.
The detail development plan includes:
1. Use "update_model_input" to update existing model_input without
creating a new one.
2. Improve the performance of "update_model_input" with a GPU kernel.
3. Support TP > 1 (this requires some designs because we do not expect
TODOs:
1. Currently supports only flash-attn, add support for other attn_backends.
2. Support TP > 1 (this requires some designs because we do not expect
any broadcasting inside execute_model).
"""
......@@ -71,51 +78,156 @@ class TP1DraftModelRunner(ModelRunner):
return_hidden_states=return_hidden_states,
)
# TODO: Remove this cache when we are able to update model_input
# directly in advance_step.
self.cached_seq_group_metadata_list: Optional[
List[SequenceGroupMetadata]] = None
def _update_flash_attn_metadata(self, attn_metadata, num_seqs,
num_queries):
assert isinstance(attn_metadata, FlashAttentionMetadata)
def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None
) -> ModelInputForGPUWithSamplingMetadata:
"""A temporary solution that caches the seq_group_metadata_list
for multi-step execution.
TODO: In-place update model_input and remove this function.
"""
self.cached_seq_group_metadata_list = seq_group_metadata_list
return super().prepare_model_input(
seq_group_metadata_list,
finished_requests_ids=finished_requests_ids)
if num_seqs != num_queries:
assert num_seqs > num_queries
assert attn_metadata.use_cuda_graph
assert attn_metadata.num_prefills == 0
assert attn_metadata.num_prefill_tokens == 0
assert attn_metadata.num_decode_tokens == num_seqs
assert attn_metadata.slot_mapping.shape == (num_seqs, )
assert len(attn_metadata.seq_lens) == num_seqs
assert attn_metadata.seq_lens_tensor.shape == (num_seqs, )
assert attn_metadata.max_query_len == 1
assert attn_metadata.max_prefill_seq_len == 0
assert attn_metadata.max_decode_seq_len == max(attn_metadata.seq_lens)
assert attn_metadata.query_start_loc.shape == (num_queries + 1, )
assert attn_metadata.seq_start_loc.shape == (num_seqs + 1, )
assert attn_metadata.context_lens_tensor.shape == (num_queries, )
def update_model_input(
assert attn_metadata.block_tables.shape[0] == num_seqs
# Update query lengths. Note that we update only queries and not seqs,
# since tensors may be padded due to captured cuda graph batch size
for i in range(num_queries):
attn_metadata.seq_lens[i] += 1
attn_metadata.max_decode_seq_len = max(attn_metadata.seq_lens)
def _update_sampling_metadata(self, sampling_metadata, num_seqs,
num_queries):
assert sampling_metadata.num_prompts == 0
assert len(sampling_metadata.seq_groups) == num_queries
assert sampling_metadata.selected_token_indices.shape == (
num_queries, )
# assert sampling_metadata.categorized_sample_indices == TODO: Add if needed # noqa: E501
# Verify that all sequences are decodes
for i in range(num_queries):
seq_group = sampling_metadata.seq_groups[i]
assert seq_group.is_prompt is False # No prompt
assert seq_group.prompt_logprob_indices == [] # No prompt
assert seq_group.sample_indices == [i] # Simple
assert seq_group.seq_len is None # Decode
assert seq_group.query_len is None # Decode
def _gpu_advance_step(
self, model_input: ModelInputForGPUWithSamplingMetadata,
last_output: SamplerOutput
) -> ModelInputForGPUWithSamplingMetadata:
"""Prepare the model inputs for the next step.
TODO: In-place update model_input instead of calling
prepare_model_input.
# Currently, we expect "decode mode" only
assert not model_input.is_prompt
# Get num_seqs
num_seqs = len(model_input.seq_lens)
num_queries = len(model_input.query_lens)
# Get output tokens GPU tensor
sampled_token_ids = last_output.sampled_token_ids
assert sampled_token_ids is not None
# Update attn_metadata
attn_metadata = model_input.attn_metadata
assert isinstance(attn_metadata, FlashAttentionMetadata)
self._update_flash_attn_metadata(attn_metadata, num_seqs, num_queries)
# Update GPU tensors
ops.advance_step(num_seqs=num_seqs,
num_queries=num_queries,
block_size=self.block_size,
input_tokens=model_input.input_tokens,
sampled_token_ids=sampled_token_ids,
input_positions=model_input.input_positions,
seq_lens=attn_metadata.seq_lens_tensor,
slot_mapping=attn_metadata.slot_mapping,
block_tables=attn_metadata.block_tables)
# Update sampling_metadata
sampling_metadata = model_input.sampling_metadata
self._update_sampling_metadata(sampling_metadata, num_seqs,
num_queries)
# Create new input
new_model_input = self._model_input_cls(
input_tokens=model_input.input_tokens,
input_positions=model_input.input_positions,
attn_metadata=attn_metadata,
seq_lens=attn_metadata.seq_lens,
query_lens=model_input.query_lens,
lora_mapping=model_input.lora_mapping,
lora_requests=model_input.lora_requests,
multi_modal_kwargs=model_input.multi_modal_kwargs,
sampling_metadata=model_input.sampling_metadata,
is_prompt=False,
)
# Ensure we skip CPU samples
assert new_model_input.sampling_metadata.skip_sampler_cpu_output is True
# We can reuse sampling tensors since every decode iteration is the same
new_model_input.sampling_metadata.reuse_sampling_tensors = True
if debug_advance_input:
logger.debug("NEW INPUT: ")
logger.debug(" input_tokens = %s", new_model_input.input_tokens)
logger.debug(" input_positions = %s",
new_model_input.input_positions)
logger.debug(" seq_lens = %d", new_model_input.seq_lens)
logger.debug(" query_lens = %d", new_model_input.query_lens)
logger.debug(" attn_metadata:")
logger.debug(" seq_lens_tensor: %s",
attn_metadata.seq_lens_tensor)
logger.debug(" slot_mapping: %s", attn_metadata.slot_mapping)
logger.debug(" block_tables: %s", attn_metadata.block_tables)
return new_model_input
def supports_gpu_multi_step(self, execute_model_req: ExecuteModelRequest):
"""Determines if draft_model_runner GPU multi-step can be used.
Currently required conditions are:
1. Only decodes
2. Only flash-attn
3. No LORA
4. No prompt_adapter_config
"""
if not allow_gpu_advance_step:
return False
# Append the output token to the sequence data.
assert self.cached_seq_group_metadata_list is not None
for seq_group_metadata, sequence_group_outputs in zip(
self.cached_seq_group_metadata_list, last_output.outputs):
seq_group_metadata.is_prompt = False
# We allow multi-step GPU only in decode mode
for seq_group in execute_model_req.seq_group_metadata_list:
if seq_group.is_prompt:
return False
for seq_output in sequence_group_outputs.samples:
seq = seq_group_metadata.seq_data[seq_output.parent_seq_id]
# TODO: Add support for other attn backends
if self.attn_backend.get_name() != "flash-attn":
return False
token_id = seq_output.output_token
token_logprob = seq_output.logprobs[token_id]
# TODO: Add support for LORA
if self.lora_config:
return False
seq.append_token_id(token_id, token_logprob.logprob)
seq.update_num_computed_tokens(1)
# TODO: Add soft-tuning prompt adapter support
if self.prompt_adapter_config:
return False
return self.prepare_model_input(self.cached_seq_group_metadata_list)
return True
@torch.inference_mode()
def execute_model(
......@@ -125,42 +237,86 @@ class TP1DraftModelRunner(ModelRunner):
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
) -> Optional[List[SamplerOutput]]:
# Since we do not broadcast data inside execute_model anymore,
# we need to figure out the best way to support TP > 1 in this
# case, because we will at least need to broadcast the sampled
# tokens to all workers.
if not self.is_driver_worker:
raise ValueError("TP1DraftModelRunner only supports TP=1.")
"""Executes num_steps forward passes with advacement of input tensors
on the GPU. Look at supports_gpu_multi_step(..) for pre-conditions.
if self.lora_config:
assert model_input.lora_requests is not None
assert model_input.lora_mapping is not None
self.set_active_loras(model_input.lora_requests,
model_input.lora_mapping)
Optimizations used:
1. Input tensors are updated on the GPU directly
2. Skips GPU=>CPU serialization of sampler outputs (we don't need
them since we do batch expansion later that uses GPU outputs)
3. Reuses sampling tensors (since we run only decodes and they have
a repeating sampling logic)
"""
if self.prompt_adapter_config:
assert model_input.prompt_adapter_requests is not None
assert model_input.prompt_adapter_mapping is not None
self.set_active_prompt_adapters(
model_input.prompt_adapter_requests,
model_input.prompt_adapter_mapping)
# When num_steps == 1, we execute the fallback here for the GPU
# advance_step, which runs prepare_inputs on CPU and for each spec
# iteration invokes this function only once
# (Look at multi-step-worker code)
is_fallback = num_steps == 1
if not is_fallback:
# Since we do not broadcast data inside execute_model anymore,
# we need to figure out the best way to support TP > 1 in this
# case, because we will at least need to broadcast the sampled
# tokens to all workers.
if not self.is_driver_worker:
raise ValueError("TP1DraftModelRunner only supports TP=1.")
# Sanity
if self.lora_config is not None:
raise ValueError("TP1DraftModelRunner has no support for LORA")
if self.prompt_adapter_config is not None:
raise ValueError("TP1DraftModelRunner has no support for "
"prompt_adapter_config")
if model_input.multi_modal_kwargs:
raise ValueError(
"TP1DraftModelRunner has no support for multi_modal_kwargs"
)
else:
if self.lora_config:
assert model_input.lora_requests is not None
assert model_input.lora_mapping is not None
self.set_active_loras(model_input.lora_requests,
model_input.lora_mapping)
if self.prompt_adapter_config:
assert model_input.prompt_adapter_requests is not None
assert model_input.prompt_adapter_mapping is not None
self.set_active_prompt_adapters(
model_input.prompt_adapter_requests,
model_input.prompt_adapter_mapping)
# Detect exec mode
assert model_input.attn_metadata is not None
use_cuda_graph = False
if model_input.attn_metadata.num_prefills > 0:
# In this case, execute_model(..) was called directly
if num_steps > 1:
raise ValueError(
"execute_model(..) of draft_model_runner can be called "
"directly only with a single-step prefill")
else:
# We can skip CPU samples for spec token generation.
# (We do allow CPU samples for num_steps == 1 to support the
# fallback case, where supports_gpu_multi_step(..) does not pass)
model_input.sampling_metadata.skip_sampler_cpu_output = (
not is_fallback)
# Attn attr defines if we use cuda graphs
use_cuda_graph = model_input.attn_metadata.use_cuda_graph
# Get model
if use_cuda_graph:
graph_batch_size = model_input.input_tokens.shape[0]
model_executable = (self.graph_runners[model_input.virtual_engine]
[graph_batch_size])
else:
model_executable = self.model
virtual_engine = model_input.virtual_engine
outputs: List[SamplerOutput] = []
for step in range(num_steps):
# Currently cuda graph is only supported by the decode phase.
assert model_input.attn_metadata is not None
prefill_meta = model_input.attn_metadata.prefill_metadata
decode_meta = model_input.attn_metadata.decode_metadata
if prefill_meta is None and decode_meta.use_cuda_graph:
assert model_input.input_tokens is not None
graph_batch_size = model_input.input_tokens.shape[0]
model_executable = (
self.graph_runners[virtual_engine][graph_batch_size])
else:
model_executable = self.model
multi_modal_kwargs = model_input.multi_modal_kwargs or {}
# Run model
hidden_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
......@@ -181,8 +337,8 @@ class TP1DraftModelRunner(ModelRunner):
sampling_metadata=model_input.sampling_metadata,
))
# Prepare the inputs for the next step.
# Prepare inputs for the next step
if step != num_steps - 1:
model_input = self.update_model_input(model_input, outputs[-1])
model_input = self._gpu_advance_step(model_input, outputs[-1])
return outputs
......@@ -22,6 +22,9 @@ class SpeculativeProposals:
# The valid length of each proposal; can be zero.
proposal_lens: torch.Tensor
# A flag to mark that there's no available proposals
no_proposals: bool = False
def __repr__(self):
return (f"SpeculativeProposals("
f"proposal_token_ids={self.proposal_token_ids}, "
......
......@@ -145,6 +145,10 @@ class AsyncMetricsCollector:
"""
ready_event.synchronize()
# update time of last collection
self._last_metrics_collect_time = self._timer()
accepted_tokens = self._aggregate_num_accepted_tokens.item()
emitted_tokens = self._aggregate_num_emitted_tokens.item()
draft_tokens = self._aggregate_num_draft_tokens
......
......@@ -43,7 +43,7 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
)
def set_include_gpu_probs_tensor(self) -> None:
# Need include_gpu_probs_tensor for multi_step_worker
# Need include_gpu_probs_tensor for MultiStepWorker
self.model_runner.model.sampler.include_gpu_probs_tensor = True
@torch.inference_mode()
......@@ -67,14 +67,23 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
expanded_request, indices_of_seq_with_bonus_tokens =\
self._expand_execute_model_request(
execute_model_req, seq_ids_with_bonus_token_in_last_step)
# Run model sample_len times.
model_outputs: List[SamplerOutput] = []
if isinstance(self.model_runner, TP1DraftModelRunner):
if isinstance(
self.model_runner, TP1DraftModelRunner
) and self.model_runner.supports_gpu_multi_step(expanded_request):
# Here we run the draft_model_runner with multi-step prepare
# on the GPU directly
expanded_request.num_steps = sample_len
model_outputs = self.execute_model(
execute_model_req=expanded_request)
else:
# TODO: Remove this branch once DraftModelRunner supports TP>1.
# Here we run multi-step directly, with every step prepared
# on the CPU.
# TODO: Remove this branch once DraftModelRunner supports TP>1
# and other restrictions that are part of DraftModelRunner's
# supports_gpu_multi_step(..)
for _ in range(sample_len):
model_output: List[SamplerOutput] = super().execute_model(
execute_model_req=expanded_request)
......@@ -171,7 +180,7 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
outputs=[
expanded_batch_output.outputs[i]
for i in output_indices_to_retain
],
] if len(expanded_batch_output.outputs) > 0 else [],
sampled_token_probs=(
expanded_batch_output.
sampled_token_probs[output_indices_to_retain]
......
......@@ -13,7 +13,7 @@ from vllm.worker.worker_base import LoraNotSupportedWorkerBase
class NGramWorker(NonLLMProposerWorkerBase, LoraNotSupportedWorkerBase):
"""NGramWorker provides a light drafter without need for model.
Current NGramWorker only implement prompt lookup decoding,
Current NGramWorker only implements prompt lookup decoding,
and in future we may also do RAG type drafter and other scenarios
which don't rely on LLM model to give proposals.
"""
......@@ -37,7 +37,7 @@ class NGramWorker(NonLLMProposerWorkerBase, LoraNotSupportedWorkerBase):
self.device = torch.device(f"cuda:{self.local_rank}")
self.load_model = lambda *args, **kwargs: None
# Current only support Top1Proposer
# Current NGramWorker only supports Top1Proposer
self._proposer = Top1Proposer(
weakref.proxy(self), # type: ignore[arg-type]
device=self.device,
......
......@@ -24,7 +24,7 @@ class ProposerWorkerBase(LoraNotSupportedWorkerBase, SpeculativeProposer):
) -> Tuple[Optional[List[SamplerOutput]], bool]:
raise NotImplementedError
def set_include_gpu_probs_tensor(self):
def set_include_gpu_probs_tensor(self) -> None:
"""Implementation optional"""
pass
......
......@@ -9,12 +9,12 @@ from vllm.distributed.communication_op import broadcast_tensor_dict
from vllm.logger import init_logger
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.layers.spec_decode_base_sampler import (
SpecDecodeBaseSampler)
SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler)
from vllm.model_executor.layers.typical_acceptance_sampler import (
TypicalAcceptanceSampler)
from vllm.sequence import (CompletionSequenceGroupOutput, ExecuteModelRequest,
HiddenStates, SamplerOutput, SequenceGroupMetadata,
get_all_seq_ids_and_request_ids)
get_all_seq_ids, get_all_seq_ids_and_request_ids)
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
from vllm.spec_decode.interfaces import (SpeculativeProposals,
......@@ -26,6 +26,7 @@ from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.ngram_worker import NGramWorker
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker
from vllm.spec_decode.target_model_runner import TargetModelRunner
from vllm.spec_decode.util import (create_sequence_group_output,
get_all_num_logprobs,
get_sampled_token_logprobs, nvtx_range,
......@@ -44,9 +45,15 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
speculative_config: SpeculativeConfig = kwargs.get("speculative_config")
assert speculative_config is not None
draft_worker_kwargs = kwargs.copy()
kwargs["model_runner_cls"] = TargetModelRunner
target_worker = Worker(*args, **kwargs)
# Set the disable_logprobs variable in the TargetModelRunner instance
# as per its value specified in the SpeculativeConfig.
target_worker.model_runner.disable_logprobs =\
speculative_config.disable_logprobs
draft_worker_kwargs = kwargs.copy()
# Override draft-model specific worker args.
draft_worker_kwargs.update(
model_config=speculative_config.draft_model_config,
......@@ -67,7 +74,8 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
typical_acceptance_sampler_posterior_threshold=speculative_config.
typical_acceptance_sampler_posterior_threshold,
typical_acceptance_sampler_posterior_alpha=speculative_config.
typical_acceptance_sampler_posterior_alpha)
typical_acceptance_sampler_posterior_alpha,
disable_logprobs=speculative_config.disable_logprobs)
return spec_decode_worker
......@@ -107,8 +115,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
draft_token_acceptance_method: str,
typical_acceptance_sampler_posterior_threshold: float,
typical_acceptance_sampler_posterior_alpha: float,
disable_logprobs: bool,
) -> "SpecDecodeWorker":
allow_zero_draft_token_step = True
ngram_prompt_lookup_max = (
draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
ngram_prompt_lookup_min = (
......@@ -133,6 +143,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
if draft_tp == 1:
draft_worker_kwargs[
"model_runner_cls"] = TP1DraftModelRunner
else:
allow_zero_draft_token_step = False
proposer_worker = MultiStepWorker(**draft_worker_kwargs)
proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
......@@ -155,18 +167,23 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
logger.info("Configuring SpecDecodeWorker with sampler=%s",
type(spec_decode_sampler))
return SpecDecodeWorker(proposer_worker,
scorer_worker,
disable_by_batch_size=disable_by_batch_size,
spec_decode_sampler=spec_decode_sampler)
return SpecDecodeWorker(
proposer_worker,
scorer_worker,
disable_logprobs=disable_logprobs,
disable_by_batch_size=disable_by_batch_size,
spec_decode_sampler=spec_decode_sampler,
allow_zero_draft_token_step=allow_zero_draft_token_step)
def __init__(
self,
proposer_worker: ProposerWorkerBase,
scorer_worker: WorkerBase,
spec_decode_sampler: SpecDecodeBaseSampler,
disable_logprobs: bool,
metrics_collector: Optional[AsyncMetricsCollector] = None,
disable_by_batch_size: Optional[int] = None,
allow_zero_draft_token_step: Optional[bool] = True,
):
"""
Create a SpecDecodeWorker.
......@@ -183,15 +200,22 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
types of sampler namely RejectionSampler and
TypicalAcceptanceSampler. 'spec_decode_sampler' is either an
instance of RejectionSampler or TypicalAcceptanceSampler.
disable_logprobs: If set to True, token log probabilities will
not be output in both the draft worker and the target worker.
If set to False, log probabilities will be output by both.
disable_by_batch_size: If the batch size is larger than this,
disable speculative decoding for new incoming requests.
metrics_collector: Helper class for collecting metrics; can be set
for testing purposes.
allow_zero_draft_token_step: whether to allow a step where the draft
model generates no draft token; should disallow when the tp of
draft model is larger than 1 (TODO: #5814)
"""
self.proposer_worker = proposer_worker
self.scorer_worker = scorer_worker
self.disable_by_batch_size = disable_by_batch_size or float("inf")
self.spec_decode_sampler = spec_decode_sampler
self._allow_zero_draft_token_step = allow_zero_draft_token_step
self._metrics = AsyncMetricsCollector(
self.spec_decode_sampler
) if metrics_collector is None else metrics_collector
......@@ -206,12 +230,13 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self.probs_dtype = self.spec_decode_sampler.probs_dtype
self.token_id_dtype = self.spec_decode_sampler.token_id_dtype
# Lazy initiazliation.
# Lazy initialization.
self.scorer: SpeculativeScorer
# Hidden states from target model to pass to proposer
# in the subsequent step.
self.previous_hidden_states: Optional[HiddenStates] = None
self._disable_logprobs = disable_logprobs
def init_device(self) -> None:
"""Initialize both scorer and proposer models.
......@@ -347,7 +372,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
) == 0 or disable_all_speculation:
return self._run_no_spec(execute_model_req,
skip_proposer=disable_all_speculation)
return self._run_speculative_decoding_step(execute_model_req,
num_lookahead_slots)
......@@ -381,6 +405,42 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# this state within spec decode worker.
seq_group_metadata.num_speculative_tokens = 0
def _serialize_sampler_output_no_logprobs(
        self, execute_model_req: ExecuteModelRequest,
        sampler_output: SamplerOutput) -> SamplerOutput:
    """Build a `SamplerOutput` that carries sampled token IDs only.

    The sampled token IDs are converted to a Python list and wrapped, one
    per sequence, in the objects returned by
    `create_sequence_group_output`. Every logprob-related field is filled
    with a placeholder (rank -1, logprob 0.0, empty top-k lists).

    Args:
        execute_model_req (ExecuteModelRequest): The request that was
            executed; its `seq_group_metadata_list` supplies the sequence
            ids.
        sampler_output (SamplerOutput): Sampler output whose
            `sampled_token_ids` tensor is populated.

    Returns:
        SamplerOutput: A new instance whose `outputs` holds one entry per
        sequence id, in the order returned by `get_all_seq_ids`.
    """
    seq_ids = get_all_seq_ids(execute_model_req.seq_group_metadata_list)
    token_rows = sampler_output.sampled_token_ids.tolist()

    # Row `row` of the sampled ids belongs to the `row`-th sequence id; only
    # the first element of each row is used.
    group_outputs: List[CompletionSequenceGroupOutput] = [
        create_sequence_group_output(
            token_id=token_rows[row][0],
            token_id_logprob_rank=-1,
            token_id_logprob=0.0,
            seq_id=seq_id,
            topk_token_ids=[],
            topk_logprobs=[],
        ) for row, seq_id in enumerate(seq_ids)
    ]
    return SamplerOutput(outputs=group_outputs)
@nvtx_range("spec_decode_worker._run_no_spec")
def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
skip_proposer: bool) -> List[SamplerOutput]:
......@@ -407,12 +467,17 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self.previous_hidden_states.update(
execute_model_req.seq_group_metadata_list, hidden_states)
sampler_output_to_return = (self._serialize_sampler_output_no_logprobs(
execute_model_req=execute_model_req, sampler_output=sampler_output)
if self._disable_logprobs else
sampler_output)
# Clear device tensors from sampler output. This reduces communication
# overhead when the engine runs in a different process than the workers.
sampler_output.probs = None
sampler_output.sampled_tokens = None
sampler_output.sampled_token_probs = None
sampler_output.sampled_token_ids = None
sampler_output.logprobs = None
return [sampler_output]
return [sampler_output_to_return]
def _run_non_driver_rank(self) -> bool:
"""Run proposer and verifier model in non-driver workers. This is used
......@@ -461,11 +526,15 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
proposals = self.proposer_worker.get_spec_proposals(
execute_model_req, self._seq_with_bonus_token_in_last_step)
if not self._allow_zero_draft_token_step and proposals.no_proposals:
#TODO: Fix it #5814
raise RuntimeError("Cannot handle cases where distributed draft "
"workers generate no tokens")
proposal_scores = self.scorer.score_proposals(
execute_model_req,
proposals,
)
accepted_token_ids, target_logprobs = self._verify_tokens(
execute_model_req.seq_group_metadata_list, proposal_scores,
proposals, execute_model_req.num_lookahead_slots)
......@@ -521,11 +590,28 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# Get proposed tokens.
proposal_token_ids = proposals.proposal_token_ids[spec_indices]
# Sampler arguments
sampler_extra_kwargs = {}
if isinstance(self.spec_decode_sampler,
SpecDecodeStochasticBaseSampler):
# Get sequence group state
generators = []
for seq_group_metadata in seq_group_metadata_list:
if (seq_group_metadata.state is not None
and seq_group_metadata.state.generator is not None):
generators.append(seq_group_metadata.state.generator)
else:
generators.append(None)
sampler_extra_kwargs["generators"] = generators
accepted_token_ids = self.spec_decode_sampler(
target_probs=proposal_verifier_probs,
bonus_token_ids=bonus_token_ids,
draft_probs=proposal_probs,
draft_token_ids=proposal_token_ids,
**sampler_extra_kwargs,
)
# Append output tokens from non-speculative sequences to
......@@ -569,25 +655,27 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
the same number of outputs.
"""
batch_size, num_steps = accepted_token_ids.shape
# Organize input tensors by step instead of by sequence.
target_logprobs_by_step = target_logprobs.transpose(0, 1)
accepted_token_ids_by_step = accepted_token_ids.transpose(0, 1)
# Get the logprobs/rank of the accepted tokens.
(accepted_token_id_ranks_by_step,
accepted_token_id_logprobs_by_step) = get_sampled_token_logprobs(
logprob_tensor=target_logprobs_by_step,
sampled_token_ids=accepted_token_ids_by_step,
)
# Get the top-k logprobs (which may or may not include the logprob of
# the accepted token).
(topk_logprobs_by_step,
topk_indices_by_step) = target_logprobs_by_step.topk(
k=self.scorer_worker.model_config.max_logprobs,
dim=-1,
)
if self._disable_logprobs:
# We are skipping the logprobs. Hence don't serialize the
# logprobs related tensors from the GPU. Instead create
# empty/dummy lists.
(accepted_token_id_ranks_by_step,
accepted_token_id_logprobs_by_step,
topk_logprobs_by_step, topk_indices_by_step) =\
self._create_dummy_logprob_lists(
batch_size, num_steps,
self.scorer_worker.model_config.max_logprobs)
else:
# Organize input tensors by step instead of by sequence.
target_logprobs_by_step = target_logprobs.transpose(0, 1)
# Serialize all tensors into Python lists.
(accepted_token_id_ranks_by_step,
accepted_token_id_logprobs_by_step,
topk_logprobs_by_step, topk_indices_by_step) =\
self._create_logprob_lists_from_tensors(
target_logprobs_by_step, accepted_token_ids_by_step,
self.scorer_worker.model_config.max_logprobs)
# Get the sequence ids and num_logprobs (sampling parameter) in the
# batch.
......@@ -596,14 +684,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list)
# Serialize all tensors to CPU Python lists.
# Serialize tensor to CPU Python list.
accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()
accepted_token_id_ranks_by_step = (
accepted_token_id_ranks_by_step.tolist())
accepted_token_id_logprobs_by_step = (
accepted_token_id_logprobs_by_step.tolist())
topk_logprobs_by_step = topk_logprobs_by_step.tolist()
topk_indices_by_step = topk_indices_by_step.tolist()
# Construct the output on a per-step, per-sequence basis.
sampler_output_list: List[SamplerOutput] = []
......@@ -645,6 +727,108 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
0].spec_decode_worker_metrics = maybe_rejsample_metrics
return sampler_output_list
def _create_dummy_logprob_lists(
    self,
    batch_size: int,
    num_steps: int,
    num_top_k: int,
) -> Tuple[List[List[int]], List[List[float]],
           List[List[List[Optional[float]]]],
           List[List[List[Optional[int]]]]]:
    """Build placeholder logprob structures without touching any tensor.

    The four returned lists are, in order:
        - ranks of the accepted tokens, shaped (num_steps, batch_size),
          filled with -1
        - logprobs of the accepted tokens, shaped (num_steps, batch_size),
          filled with 0.0
        - top-k logprobs, shaped (num_steps, batch_size, num_top_k),
          filled with None
        - top-k token ids, shaped (num_steps, batch_size, num_top_k),
          filled with None

    Args:
        batch_size (int): Number of sequences in the batch.
        num_steps (int): Number of steps.
        num_top_k (int): Number of top-k slots per token.

    Returns:
        A tuple with the four placeholder lists described above.
    """
    ranks: List[List[int]] = []
    logprobs: List[List[float]] = []
    topk_logprobs: List[List[List[Optional[float]]]] = []
    topk_indices: List[List[List[Optional[int]]]] = []

    # Each step gets freshly allocated rows so no inner list is shared.
    for _ in range(num_steps):
        ranks.append([-1] * batch_size)
        logprobs.append([0.0] * batch_size)
        topk_logprobs.append([[None] * num_top_k
                              for _ in range(batch_size)])
        topk_indices.append([[None] * num_top_k
                             for _ in range(batch_size)])

    return (ranks, logprobs, topk_logprobs, topk_indices)
def _create_logprob_lists_from_tensors(
    self,
    target_logprobs_by_step: torch.Tensor,
    accepted_token_ids_by_step: torch.Tensor,
    num_top_k: int,
) -> Tuple[List[List[int]], List[List[float]],
           List[List[List[Optional[float]]]],
           List[List[List[Optional[int]]]]]:
    """Compute logprob information for accepted tokens as Python lists.

    The four returned lists are, in order:
        - ranks of the accepted tokens, shaped (num_steps, batch_size)
        - logprobs of the accepted tokens, shaped (num_steps, batch_size)
        - logprobs of the top-k tokens,
          shaped (num_steps, batch_size, num_top_k)
        - token ids of the top-k tokens,
          shaped (num_steps, batch_size, num_top_k)

    Args:
        target_logprobs_by_step (torch.Tensor): Target-model log
            probabilities, shaped (num_steps, batch_size, vocab_size).
        accepted_token_ids_by_step (torch.Tensor): Accepted token ids,
            shaped (num_steps, batch_size).
        num_top_k (int): How many top-k entries to take along the last
            dimension.

    Returns:
        A tuple with the four lists described above.
    """
    # Rank and logprob of each accepted token.
    rank_tensor, accepted_logprob_tensor = get_sampled_token_logprobs(
        logprob_tensor=target_logprobs_by_step,
        sampled_token_ids=accepted_token_ids_by_step,
    )

    # Top-k over the last dimension; the accepted token may or may not be
    # among these entries.
    topk_value_tensor, topk_index_tensor = target_logprobs_by_step.topk(
        k=num_top_k,
        dim=-1,
    )

    # Convert every tensor into a Python list for the caller.
    return (
        rank_tensor.tolist(),
        accepted_logprob_tensor.tolist(),
        topk_value_tensor.tolist(),
        topk_index_tensor.tolist(),
    )
def _track_finished_requests(self, execute_model_req: ExecuteModelRequest):
"""
Removes the finished requests and their associated sequence ids from
......
from typing import List, Optional
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig)
from vllm.sequence import SequenceGroupMetadata
from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata,
ModelRunner)
class TargetModelRunner(ModelRunner):
    """Model runner used for the target model in speculative decoding.

    With speculative decoding, the log probabilities that are finally
    reported may differ from the ones picked by the target model's own
    sampling step, because they are computed only after the accepted tokens
    are known. Computing them inside the target model is therefore wasted
    time, and turning them off makes decoding faster. This runner sets the
    relevant SamplingMetadata flag depending on whether log probabilities
    are disabled.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        prompt_adapter_config: Optional[PromptAdapterConfig] = None,
        multimodal_config: Optional[MultiModalConfig] = None,
        return_hidden_states: bool = False,
    ):
        """Store the logprob switch, then defer to ModelRunner.__init__."""
        # Flag telling whether token log probabilities are skipped. It
        # starts as True and is a plain attribute, so it can be reassigned
        # after construction.
        self.disable_logprobs = True
        super().__init__(
            model_config=model_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            cache_config=cache_config,
            load_config=load_config,
            lora_config=lora_config,
            kv_cache_dtype=kv_cache_dtype,
            is_driver_worker=is_driver_worker,
            prompt_adapter_config=prompt_adapter_config,
            multimodal_config=multimodal_config,
            return_hidden_states=return_hidden_states,
        )

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForGPUWithSamplingMetadata:
        """Prepare the model input and set the sampler CPU-output flag.

        The input is built by the parent class; afterwards
        `sampling_metadata.skip_sampler_cpu_output` is set to the current
        value of `self.disable_logprobs`.
        """
        prepared_input = super().prepare_model_input(
            seq_group_metadata_list, virtual_engine, finished_requests_ids)
        # With log probabilities disabled, generating the sampler's CPU
        # output is skipped and the GPU sampled_token_id tensors are
        # serialized directly as needed. With them enabled, all sampling
        # related tensors, including the logprobs tensors, are synchronized.
        prepared_input.sampling_metadata.skip_sampler_cpu_output = (
            self.disable_logprobs)
        return prepared_input
......@@ -108,7 +108,7 @@ class Top1Proposer(SpeculativeProposer):
proposal_token_ids=proposal_tokens,
proposal_probs=proposal_probs,
proposal_lens=proposal_lens,
)
no_proposals=maybe_sampler_output is None)
return proposals
......@@ -138,7 +138,7 @@ class Top1Proposer(SpeculativeProposer):
# Currently only proposal lens of 0 or the global batch proposal len
# are supported.
# If max_proposal_len is defined, then we shall no exccess this
# If max_proposal_len is defined, then we shall not exceed this
# quota for nonzero_proposal
new_k = 0
if (self.max_proposal_len is None
......@@ -219,7 +219,7 @@ class Top1Proposer(SpeculativeProposer):
proposal_lens: List[int],
nonzero_proposal_len_indices: List[int],
sampler_transposed: bool,
) -> Tuple[torch.Tensor, torch.tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""After speculations are produced, merge the speculation results with
the skipped sequences.
"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment