Commit 500b93c8 authored by zhuwenwen
Browse files

Merge tag 'v0.5.3.post1' into v0.5.3.post1-dtk24.04.1

parents 99426767 38c4b7e8
...@@ -43,6 +43,7 @@ from vllm.sequence import IntermediateTensors, SamplerOutput ...@@ -43,6 +43,7 @@ from vllm.sequence import IntermediateTensors, SamplerOutput
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip, from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
input_processor_for_clip) input_processor_for_clip)
from .interfaces import SupportsVision from .interfaces import SupportsVision
from .utils import merge_vision_embeddings
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -71,9 +72,8 @@ CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(dropout=0.0, ...@@ -71,9 +72,8 @@ CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(dropout=0.0,
class Phi3ImageEmbeddingBase(nn.Module): class Phi3ImageEmbeddingBase(nn.Module):
def __init__(self, wte=None) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
self.wte = wte
self.layer_idx: int self.layer_idx: int
self.type_feature: str self.type_feature: str
self.img_processor: CLIPVisionModel self.img_processor: CLIPVisionModel
...@@ -100,10 +100,9 @@ class Phi3ImageEmbeddingBase(nn.Module): ...@@ -100,10 +100,9 @@ class Phi3ImageEmbeddingBase(nn.Module):
class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase): class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
"""Phi3 Image embedding with HD transform.""" """Phi3 Image embedding with HD transform."""
def __init__(self, config: PretrainedConfig, wte=None) -> None: def __init__(self, config: PretrainedConfig) -> None:
super().__init__(wte) super().__init__()
self.image_token_id = _IMAGE_TOKEN_ID
# n_embed or hidden_size # n_embed or hidden_size
hidden_size = config.n_embd if hasattr( hidden_size = config.n_embd if hasattr(
config, 'n_embd') else config.hidden_size config, 'n_embd') else config.hidden_size
...@@ -149,118 +148,115 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase): ...@@ -149,118 +148,115 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
nn.Linear(dim_projection, dim_projection)]) nn.Linear(dim_projection, dim_projection)])
self.img_projection = nn.Sequential(*layers) self.img_projection = nn.Sequential(*layers)
self.vocab_size = config.vocab_size
self.type_feature = config.img_processor.get('type_feature', 'patch') self.type_feature = config.img_processor.get('type_feature', 'patch')
def forward(self, input_ids: torch.LongTensor, def forward(self, pixel_values: torch.FloatTensor,
pixel_values: torch.FloatTensor,
image_sizes: torch.Tensor) -> torch.FloatTensor: image_sizes: torch.Tensor) -> torch.FloatTensor:
"""process and merge text embeddings with image embeddings.""" """
process image and return vision embeddings.
# (batch_size, max_num_crops, 3, height, width)
img_embeds = pixel_values pixel_values: (num_images, num_crops, c, h, w)
output: (num_images, num_img_tokens, hidden_size)
# (batch_size, 2) """
img_sizes = image_sizes num_images, num_crops, c, h, w = pixel_values.shape
pixel_values = pixel_values.flatten(0, 1)
input_shape = input_ids.size() img_features = self.get_img_features(pixel_values)
input_ids = input_ids.view(-1, input_shape[-1]) img_features = img_features.reshape(num_images, num_crops, -1,
self.image_dim_out)
positions = torch.nonzero(input_ids == self.image_token_id) image_features_proj = self.hd_feature_transform(
img_features, image_sizes)
select = False return image_features_proj
target_dtype = self.img_projection[0].bias.dtype def hd_feature_transform(self, image_features, image_sizes):
"""
if len(positions.tolist()) > 0: image_features: (num_images, num_crops+1, 24*24, 1024)
# if self.use_hd_transform and img_sizes: """
# img_embeds: (num_images, max_num_crops, 3, H, W) assert (
# img_sizes: (num_images, 2).view(1, -1) self.hd_transform_order == 'sub_glb'
), f'hd_transform_order `{self.hd_transform_order}` not implemented'
bs = img_embeds.shape[0] if isinstance(self.img_projection, nn.Sequential):
# Nx(HW)xC target_device = self.img_projection[0].bias.device
img_features = self.get_img_features(img_embeds.flatten(0, 1)) target_dtype = self.img_projection[0].bias.dtype
base_feat_height = base_feat_width = int( else: # It's a single nn.Linear layer
img_features.shape[1]**0.5) target_device = self.img_projection.bias.device
target_dtype = self.img_projection.bias.dtype
# bs x max_num_crops x (24x24) x C
img_features = img_features.view( global_image_features = image_features[:,
bs, -1, base_feat_height * base_feat_width, self.image_dim_out) 0] # (num_images, 24*24, 1024)
C = self.image_dim_out # global feature can be viewed as a special HD case with num_crops 1x1
H = base_feat_height global_image_features_hd = self.reshape_hd_patches_2x2merge(
global_image_features, 1, 1)
output_imgs = [] global_image_features_hd_newline = self.add_image_newline(
output_len = [] global_image_features_hd)
for _bs in range(bs): all_image_embeddings = []
h, w = img_sizes[_bs] # need a for loop to process each image because of different image sizes
h = h // 336 # (patch arrangement is different for each image)
w = w // 336 for i, img_size in enumerate(image_sizes):
B_ = h * w h, w = img_size
h_crop = h // 336
# 1 x (24x24) x 1024 w_crop = w // 336
global_img_feature = img_features[_bs, :1] num_crops = h_crop * w_crop
# 1 x 12 x 12 x 4096 # NOTE: real num_crops is padded
glb_img = global_img_feature \ # (num_crops, 24*24, 1024)
.reshape(1, H // 2, 2, H // 2, 2,C) \ sub_image_features = image_features[i, 1:1 + num_crops]
.permute(0, 1, 3, 2, 4, 5) \ sub_image_features_hd = self.reshape_hd_patches_2x2merge(
.reshape(1, H // 2, H // 2, 4 * C) sub_image_features, h_crop, w_crop)
temp_glb_GN = self.sub_GN.repeat(1, H // 2, 1, 1) sub_image_features_hd_newline = self.add_image_newline(
sub_image_features_hd)
# 1 x 156 x 4096
glb_img = torch.cat([glb_img, temp_glb_GN], # [sub features, separator, global features]
dim=2).reshape(1, -1, 4 * C) all_image_embeddings.append(
torch.cat([
# (max_num_crops-1) x (12x12) x C sub_image_features_hd_newline.squeeze(
sub_img = img_features[_bs, 1:] 0), # (h_crop*12*(w_crop*12+1), 4096)
# 16x574x1024 self.glb_GN.squeeze(0),
# get rid of padding sub_img global_image_features_hd_newline[i],
sub_img = sub_img[:B_] ]))
sub_img = sub_img.reshape(B_, H // 2, 2, H // 2, 2, C) \ image_features_proj = self.img_projection(
.permute(0, 1, 3, 2, 4, 5).reshape(B_, -1, 4 * C) torch.stack(all_image_embeddings).to(target_device, target_dtype)
sub_img = sub_img.reshape(1, h, w, 12, 12, -1) \ ) # (num_images, (h_crop*12*(w_crop*12+1)+1), hidden_size)
.permute(0, 1, 3, 2, 4, 5) \
.reshape(1, h * 12, w * 12, 4 * C) return image_features_proj
temp_sub_GN = self.sub_GN.repeat(1, h * 12, 1, 1)
sub_img = torch.cat([sub_img, temp_sub_GN], def reshape_hd_patches_2x2merge(self, image_features, h_crop, w_crop):
dim=2).reshape(1, -1, 4 * C) """
# (1, num_img_tokens, 1024*4) image_features: (num_images*num_crops, 24*24, 1024)
output: (num_images, h_crop*12, w_crop*12, 4096)
# glb + sub where h_crop*w_crop == num_crops
if self.hd_transform_order == 'glb_sub': """
output_imgs.append( N, L, C = image_features.shape
torch.cat([glb_img, self.glb_GN, sub_img], dim=1)) assert L == 576 and C == 1024 and N % (h_crop * w_crop) == 0
elif self.hd_transform_order == 'sub_glb': num_images = N // (h_crop * w_crop)
output_imgs.append( H = int(L**0.5)
torch.cat([sub_img, self.glb_GN, glb_img], dim=1)) image_features_hd = (
image_features.reshape(N, H, H, C) # N, 24, 24, 1024
temp_len = int((h * w + 1) * 144 + 1 + (h + 1) * 12) .reshape(N, H // 2, 2, H // 2, 2, C) # N, 12, 2, 12, 2, 1024
output_len.append(temp_len) .permute(0, 1, 3, 2, 4, 5) # N, 12, 12, 2, 2, 1024
.reshape(N, -1, 4 * C) # N, 144, 4096
num_img_tokens = output_len .reshape(num_images, h_crop, w_crop, H // 2, H // 2,
img_set_tensor = [] -1) # n_img, h_crop, w_crop, 12, 12, 4096
for _output_img in output_imgs: .permute(0, 1, 3, 2, 4, 5) # n_img, h_crop, 12, w_crop, 12, 4096
img_feature_proj = self.img_projection( .reshape(num_images, h_crop * H // 2, w_crop * H // 2,
_output_img.to(target_dtype)) 4 * C) # n_img, h_crop*12, w_crop*12, 4096
img_set_tensor.append(img_feature_proj) )
select = True return image_features_hd
input_ids.clamp_min_(0).clamp_max_(self.vocab_size) def add_image_newline(self, image_features_hd):
"""
hidden_states = self.wte(input_ids) image_features_hd: (num_images, h_crop*12, w_crop*12, 4096)
output: (num_images, (h_crop*12) * (w_crop*12+1), 4096)
if select: """
idx = 0 num_images, h, w, hid_dim = image_features_hd.shape
for i, cnt in enumerate(num_img_tokens): # add the newline token to the HD image feature patches
hidden_states[positions[idx, 0], newline_embeddings = self.sub_GN.expand(num_images, h, -1,
positions[idx, 1]:positions[idx, 1] + -1) # (n_img, h, 1, hid_dim)
cnt] = (img_set_tensor[i].to( image_features_hd_newline = torch.cat(
hidden_states.dtype)) [image_features_hd, newline_embeddings],
idx += cnt dim=2).reshape(num_images, -1, hid_dim)
return image_features_hd_newline
return hidden_states.squeeze(0)
class Phi3VImagePixelInputs(TypedDict): class Phi3VImagePixelInputs(TypedDict):
...@@ -458,12 +454,12 @@ class Phi3VForCausalLM(nn.Module, SupportsVision): ...@@ -458,12 +454,12 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self.image_token_id = _IMAGE_TOKEN_ID
self.model = LlamaModel(config, cache_config, quant_config) self.model = LlamaModel(config, cache_config, quant_config)
# TODO: Optionally initializes this for supporting embeddings. # TODO: Optionally initializes this for supporting embeddings.
self.vision_embed_tokens = Phi3HDImageEmbedding( self.vision_embed_tokens = Phi3HDImageEmbedding(config)
config, self.model.embed_tokens)
self.lm_head = ParallelLMHead(config.vocab_size, self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size, config.hidden_size,
quant_config=quant_config) quant_config=quant_config)
...@@ -530,9 +526,12 @@ class Phi3VForCausalLM(nn.Module, SupportsVision): ...@@ -530,9 +526,12 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None: if image_input is not None:
inputs_embeds = self.vision_embed_tokens( vision_embeddings = self.vision_embed_tokens(
input_ids, image_input["data"], image_input["image_sizes"]) image_input["data"], image_input["image_sizes"])
inputs_embeds = self.model.get_input_embeddings(input_ids)
inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds,
vision_embeddings,
self.image_token_id)
input_ids = None input_ids = None
else: else:
inputs_embeds = None inputs_embeds = None
......
...@@ -45,10 +45,10 @@ from vllm.model_executor.layers.rotary_embedding import get_rope ...@@ -45,10 +45,10 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.utils import print_warning_once
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA
...@@ -392,18 +392,10 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA): ...@@ -392,18 +392,10 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
# Remapping the name of FP8 kv-scale. # Remapping the name of FP8 kv-scale.
if name.endswith("kv_scale"): name = maybe_remap_kv_scale_name(name, params_dict)
remapped_kv_scale_name = name.replace( if name is None:
".kv_scale", ".attn.kv_scale") continue
if remapped_kv_scale_name not in params_dict:
print_warning_once(
f"Found kv scale in the checkpoint (e.g. {name}), "
"but not found the expected name in the model "
f"(e.g. {remapped_kv_scale_name}). kv-scale is "
"not loaded.")
continue
else:
name = remapped_kv_scale_name
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
......
from typing import Callable, Dict, List, Tuple from typing import Dict, List, Protocol, Tuple
import torch import torch
from torch.func import functional_call
from vllm.multimodal import BatchedTensors from vllm.multimodal import BatchedTensors
from vllm.utils import is_pin_memory_available
def merge_vision_embeddings(input_ids: torch.Tensor, def merge_vision_embeddings(input_ids: torch.Tensor,
...@@ -43,6 +45,15 @@ def merge_vision_embeddings(input_ids: torch.Tensor, ...@@ -43,6 +45,15 @@ def merge_vision_embeddings(input_ids: torch.Tensor,
return inputs_embeds return inputs_embeds
class LayerFn(Protocol):
    """Structural type for a layer factory.

    A conforming callable accepts an optional ``prefix`` keyword argument
    (default ``""``) and returns a ``torch.nn.Module``.
    """

    def __call__(self, prefix="") -> torch.nn.Module:
        ...
class PPMissingLayer(torch.nn.Identity): class PPMissingLayer(torch.nn.Identity):
""" """
A placeholder layer for missing layers in a pipeline parallel model. A placeholder layer for missing layers in a pipeline parallel model.
...@@ -52,8 +63,74 @@ class PPMissingLayer(torch.nn.Identity): ...@@ -52,8 +63,74 @@ class PPMissingLayer(torch.nn.Identity):
super().__init__() super().__init__()
# Running total, in bytes, of parameter data that maybe_offload_to_cpu has
# moved to CPU memory so far.
_CPU_OFFLOAD_BYTES = 0
# Upper bound, in bytes, on the total above. Offloading stops once the running
# total reaches this value, so the initial value of 0 means nothing is
# offloaded until set_cpu_offload_max_bytes is called with a positive budget.
_CPU_OFFLOAD_MAX_BYTES = 0
def set_cpu_offload_max_bytes(max_bytes: int) -> None:
    """Install a new CPU-offload budget and reset the usage counter.

    Args:
        max_bytes: New value for the module-level offload budget, in bytes.
    """
    global _CPU_OFFLOAD_BYTES, _CPU_OFFLOAD_MAX_BYTES
    # Reset the running total together with the budget so accounting
    # restarts from zero.
    _CPU_OFFLOAD_BYTES, _CPU_OFFLOAD_MAX_BYTES = 0, max_bytes
def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
    """Move a module's parameters to CPU memory while offload budget remains.

    The module is returned untouched when it has no parameters, when its
    first parameter already lives on the CPU, or when the module-level
    budget (``_CPU_OFFLOAD_MAX_BYTES``) is already used up. Otherwise
    parameters are copied one at a time into CPU tensors (pinned when
    ``is_pin_memory_available()`` says so) until the budget is reached, and
    ``module.forward`` is replaced by a wrapper that copies the state-dict
    tensors back to the original device for the duration of each call.

    Args:
        module: The layer to consider for offloading. It is modified in
            place.

    Returns:
        The same ``module`` object that was passed in.
    """
    # The two-argument form of ``next`` avoids the StopIteration that a
    # parameterless module (e.g. ``torch.nn.Identity``) would otherwise raise.
    first_param = next(module.parameters(), None)
    if first_param is None:
        return module

    device = first_param.device

    if device == torch.device("cpu"):
        return module

    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES

    if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
        return module

    pin_memory = is_pin_memory_available()

    # offload parameters to CPU
    # use pin_memory if possible, which helps cudagraph capture speed
    for p in module.parameters():
        if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
            # we use per-parameter offloading
            # one module might have some parameters offloaded and some not
            break

        # `torch.empty_like` does not support `pin_memory` argument
        cpu_data = torch.empty(size=p.data.size(),
                               dtype=p.data.dtype,
                               layout=p.data.layout,
                               device='cpu',
                               pin_memory=pin_memory)
        cpu_data.copy_(p.data)
        p.data = cpu_data
        _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()

    state_dict: Dict[str, torch.Tensor] = module.state_dict()

    original_forward = module.forward

    def forward(*args, **kwargs):
        # Put the real forward back while the call runs so that
        # `functional_call` executes the original implementation instead of
        # re-entering this wrapper.
        module.forward = original_forward
        try:
            device_state = {
                # here we blindly call `to(device)`
                # if the parameter is already on the device, it will be a no-op
                k: v.to(device, non_blocking=True)
                for k, v in state_dict.items()
            }
            return functional_call(module,
                                   device_state,
                                   args=args,
                                   kwargs=kwargs)
        finally:
            # Re-install the wrapper even when the call raises; otherwise
            # later calls would run the original forward directly against
            # the CPU-resident parameters.
            module.forward = forward

    module.forward = forward

    return module
def make_layers( def make_layers(
num_hidden_layers: int, layer_fn: Callable[[], torch.nn.Module] num_hidden_layers: int,
layer_fn: LayerFn,
prefix: str,
) -> Tuple[int, int, torch.nn.ModuleList]: ) -> Tuple[int, int, torch.nn.ModuleList]:
"""Make a list of layers with the given layer function, taking """Make a list of layers with the given layer function, taking
pipeline parallelism into account. pipeline parallelism into account.
...@@ -64,9 +141,10 @@ def make_layers( ...@@ -64,9 +141,10 @@ def make_layers(
get_pp_group().rank_in_group, get_pp_group().rank_in_group,
get_pp_group().world_size) get_pp_group().world_size)
modules = torch.nn.ModuleList( modules = torch.nn.ModuleList(
[PPMissingLayer() for _ in range(start_layer)] + [PPMissingLayer() for _ in range(start_layer)] + [
[layer_fn() for _ in range(start_layer, end_layer)] + maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}"))
[PPMissingLayer() for _ in range(end_layer, num_hidden_layers)]) for idx in range(start_layer, end_layer)
] + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)])
return start_layer, end_layer, modules return start_layer, end_layer, modules
......
...@@ -8,7 +8,7 @@ from vllm.model_executor.layers.ops.sample import get_num_triton_sampler_splits ...@@ -8,7 +8,7 @@ from vllm.model_executor.layers.ops.sample import get_num_triton_sampler_splits
from vllm.sampling_params import SamplingParams, SamplingType from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SequenceData, SequenceGroupMetadata from vllm.sequence import SequenceData, SequenceGroupMetadata
from vllm.utils import (async_tensor_h2d, is_pin_memory_available, from vllm.utils import (async_tensor_h2d, is_pin_memory_available,
maybe_expand_dim) make_tensor_with_pad, maybe_expand_dim)
_SAMPLING_EPS = 1e-5 _SAMPLING_EPS = 1e-5
_SEED_0_REPLACEMENT = 3403598558 _SEED_0_REPLACEMENT = 3403598558
...@@ -86,6 +86,12 @@ class SamplingMetadata: ...@@ -86,6 +86,12 @@ class SamplingMetadata:
The first tuple is [1, 2] (sampled index within original logit), The first tuple is [1, 2] (sampled index within original logit),
and the second tuple is [0, 1] (sampled index within pruned logit). and the second tuple is [0, 1] (sampled index within pruned logit).
num_prompts: Number of prompt sequence groups in seq_groups. num_prompts: Number of prompt sequence groups in seq_groups.
skip_sampler_cpu_output: Indicates if we want to skip the GPU=>CPU
serialization of token outputs.
reuse_sampling_tensors: Indicates if we want to reuse sampling
tensors that are part of the sampler forward pass. Currently,
it is mainly used for multi-step decode.
""" """
def __init__( def __init__(
...@@ -94,11 +100,15 @@ class SamplingMetadata: ...@@ -94,11 +100,15 @@ class SamplingMetadata:
selected_token_indices: torch.Tensor, selected_token_indices: torch.Tensor,
categorized_sample_indices: Dict[SamplingType, torch.Tensor], categorized_sample_indices: Dict[SamplingType, torch.Tensor],
num_prompts: int, num_prompts: int,
skip_sampler_cpu_output: bool = False,
reuse_sampling_tensors: bool = False,
) -> None: ) -> None:
self.seq_groups = seq_groups self.seq_groups = seq_groups
self.selected_token_indices = selected_token_indices self.selected_token_indices = selected_token_indices
self.categorized_sample_indices = categorized_sample_indices self.categorized_sample_indices = categorized_sample_indices
self.num_prompts = num_prompts self.num_prompts = num_prompts
self.skip_sampler_cpu_output = skip_sampler_cpu_output
self.reuse_sampling_tensors = reuse_sampling_tensors
@staticmethod @staticmethod
def prepare( def prepare(
...@@ -455,18 +465,24 @@ class SamplingTensors: ...@@ -455,18 +465,24 @@ class SamplingTensors:
do_penalties = prompt_tokens or output_tokens do_penalties = prompt_tokens or output_tokens
if do_penalties: if do_penalties:
prompt_max_len = max([len(tokens) for tokens in prompt_tokens], prompt_t = make_tensor_with_pad(
default=0) prompt_tokens,
prompt_padded_tokens = [ vocab_size,
tokens + [vocab_size] * (prompt_max_len - len(tokens)) device="cpu",
for tokens in prompt_tokens dtype=torch.int64,
] pin_memory=pin_memory,
output_max_len = max([len(tokens) for tokens in output_tokens], )
default=0) output_t = make_tensor_with_pad(
output_padded_tokens = [ output_tokens,
tokens + [vocab_size] * (output_max_len - len(tokens)) vocab_size,
for tokens in output_tokens device="cpu",
] dtype=torch.int64,
pin_memory=pin_memory,
)
else:
empty_tensor = torch.empty(0, device=device, dtype=torch.long)
prompt_t = empty_tensor
output_t = empty_tensor
temperatures_t = torch.tensor( temperatures_t = torch.tensor(
temperatures, temperatures,
...@@ -516,22 +532,6 @@ class SamplingTensors: ...@@ -516,22 +532,6 @@ class SamplingTensors:
dtype=torch.long, dtype=torch.long,
pin_memory=pin_memory, pin_memory=pin_memory,
) )
if do_penalties:
prompt_tensor = torch.tensor(
prompt_padded_tokens,
device="cpu",
dtype=torch.long,
pin_memory=pin_memory,
)
output_tensor = torch.tensor(
output_padded_tokens,
device="cpu",
dtype=torch.long,
pin_memory=pin_memory,
)
else:
prompt_tensor = None
output_tensor = None
# need to transpose and make contiguous to # need to transpose and make contiguous to
# copy the tensor correctly. # copy the tensor correctly.
# [batch_size, n_seeds] -> [n_seeds, batch_size] # [batch_size, n_seeds] -> [n_seeds, batch_size]
...@@ -554,16 +554,6 @@ class SamplingTensors: ...@@ -554,16 +554,6 @@ class SamplingTensors:
extra_seeds_gpu = None extra_seeds_gpu = None
sampling_seeds_gpu = sampling_seeds_gpu[:num_base_seeds] sampling_seeds_gpu = sampling_seeds_gpu[:num_base_seeds]
if do_penalties:
prompt_tokens_gpu = prompt_tensor.to(device=device,
non_blocking=True)
output_tokens_gpu = output_tensor.to(device=device,
non_blocking=True)
else:
empty_tensor = torch.empty(0, device=device, dtype=torch.long)
prompt_tokens_gpu = empty_tensor
output_tokens_gpu = empty_tensor
return cls( return cls(
temperatures=temperatures_t.to(device=device, non_blocking=True), temperatures=temperatures_t.to(device=device, non_blocking=True),
top_ps=top_ps_t.to(device=device, non_blocking=True), top_ps=top_ps_t.to(device=device, non_blocking=True),
...@@ -575,8 +565,8 @@ class SamplingTensors: ...@@ -575,8 +565,8 @@ class SamplingTensors:
non_blocking=True), non_blocking=True),
repetition_penalties=repetition_penalties_t.to(device=device, repetition_penalties=repetition_penalties_t.to(device=device,
non_blocking=True), non_blocking=True),
prompt_tokens=prompt_tokens_gpu, prompt_tokens=prompt_t.to(device=device, non_blocking=True),
output_tokens=output_tokens_gpu, output_tokens=output_t.to(device=device, non_blocking=True),
sampling_seeds=sampling_seeds_gpu, sampling_seeds=sampling_seeds_gpu,
sample_indices=sample_indices_t.to(device=device, sample_indices=sample_indices_t.to(device=device,
non_blocking=True), non_blocking=True),
......
import base64 import base64
from io import BytesIO from io import BytesIO
from typing import Optional, Union from typing import Union
from urllib.parse import urlparse
import aiohttp
import requests
from PIL import Image from PIL import Image
from vllm.connections import global_http_connection
from vllm.envs import VLLM_IMAGE_FETCH_TIMEOUT from vllm.envs import VLLM_IMAGE_FETCH_TIMEOUT
from vllm.multimodal.base import MultiModalDataDict from vllm.multimodal.base import MultiModalDataDict
from vllm.version import __version__ as VLLM_VERSION
def _validate_remote_url(url: str, *, name: str):
    """Reject any URL whose scheme is not ``http`` or ``https``.

    Args:
        url: The URL to check.
        name: Label of the field being validated; it is interpolated into
            the error message.

    Raises:
        ValueError: If the parsed scheme is neither ``http`` nor ``https``.
    """
    scheme = urlparse(url).scheme
    if scheme == "http" or scheme == "https":
        return

    raise ValueError(f"Invalid '{name}': A valid '{name}' "
                     "must have scheme 'http' or 'https'.")
def _get_request_headers():
    """Build the HTTP headers used for remote image requests.

    Returns:
        A dict with a single ``User-Agent`` entry of the form
        ``vLLM/<VLLM_VERSION>``.
    """
    user_agent = f"vLLM/{VLLM_VERSION}"
    return {"User-Agent": user_agent}
def _load_image_from_bytes(b: bytes): def _load_image_from_bytes(b: bytes):
...@@ -42,13 +28,8 @@ def fetch_image(image_url: str, *, image_mode: str = "RGB") -> Image.Image: ...@@ -42,13 +28,8 @@ def fetch_image(image_url: str, *, image_mode: str = "RGB") -> Image.Image:
By default, the image is converted into RGB format. By default, the image is converted into RGB format.
""" """
if image_url.startswith('http'): if image_url.startswith('http'):
_validate_remote_url(image_url, name="image_url") image_raw = global_http_connection.get_bytes(
image_url, timeout=VLLM_IMAGE_FETCH_TIMEOUT)
headers = _get_request_headers()
with requests.get(url=image_url, headers=headers) as response:
response.raise_for_status()
image_raw = response.content
image = _load_image_from_bytes(image_raw) image = _load_image_from_bytes(image_raw)
elif image_url.startswith('data:image'): elif image_url.startswith('data:image'):
...@@ -60,55 +41,30 @@ def fetch_image(image_url: str, *, image_mode: str = "RGB") -> Image.Image: ...@@ -60,55 +41,30 @@ def fetch_image(image_url: str, *, image_mode: str = "RGB") -> Image.Image:
return image.convert(image_mode) return image.convert(image_mode)
class ImageFetchAiohttp: async def async_fetch_image(image_url: str,
aiohttp_client: Optional[aiohttp.ClientSession] = None *,
image_mode: str = "RGB") -> Image.Image:
@classmethod """
def get_aiohttp_client(cls) -> aiohttp.ClientSession: Asynchronously load a PIL image from a HTTP or base64 data URL.
if cls.aiohttp_client is None:
timeout = aiohttp.ClientTimeout(total=VLLM_IMAGE_FETCH_TIMEOUT)
connector = aiohttp.TCPConnector()
cls.aiohttp_client = aiohttp.ClientSession(timeout=timeout,
connector=connector)
return cls.aiohttp_client
@classmethod
async def fetch_image(
cls,
image_url: str,
*,
image_mode: str = "RGB",
) -> Image.Image:
"""
Asynchronously load a PIL image from a HTTP or base64 data URL.
By default, the image is converted into RGB format.
"""
if image_url.startswith('http'):
_validate_remote_url(image_url, name="image_url")
client = cls.get_aiohttp_client()
headers = _get_request_headers()
async with client.get(url=image_url, headers=headers) as response: By default, the image is converted into RGB format.
response.raise_for_status() """
image_raw = await response.read() if image_url.startswith('http'):
image = _load_image_from_bytes(image_raw) image_raw = await global_http_connection.async_get_bytes(
image_url, timeout=VLLM_IMAGE_FETCH_TIMEOUT)
image = _load_image_from_bytes(image_raw)
elif image_url.startswith('data:image'): elif image_url.startswith('data:image'):
image = _load_image_from_data_url(image_url) image = _load_image_from_data_url(image_url)
else: else:
raise ValueError( raise ValueError("Invalid 'image_url': A valid 'image_url' must start "
"Invalid 'image_url': A valid 'image_url' must start " "with either 'data:image' or 'http'.")
"with either 'data:image' or 'http'.")
return image.convert(image_mode) return image.convert(image_mode)
async def async_get_and_parse_image(image_url: str) -> MultiModalDataDict: async def async_get_and_parse_image(image_url: str) -> MultiModalDataDict:
image = await ImageFetchAiohttp.fetch_image(image_url) image = await async_fetch_image(image_url)
return {"image": image} return {"image": image}
......
...@@ -2,7 +2,9 @@ from typing import Optional ...@@ -2,7 +2,9 @@ from typing import Optional
import torch import torch
from .interface import Platform, PlatformEnum from vllm.utils import is_tpu
from .interface import Platform, PlatformEnum, UnspecifiedPlatform
current_platform: Optional[Platform] current_platform: Optional[Platform]
...@@ -12,7 +14,10 @@ if torch.version.cuda is not None: ...@@ -12,7 +14,10 @@ if torch.version.cuda is not None:
elif torch.version.hip is not None: elif torch.version.hip is not None:
from .rocm import RocmPlatform from .rocm import RocmPlatform
current_platform = RocmPlatform() current_platform = RocmPlatform()
elif is_tpu():
from .tpu import TpuPlatform
current_platform = TpuPlatform()
else: else:
current_platform = None current_platform = UnspecifiedPlatform()
__all__ = ['Platform', 'PlatformEnum', 'current_platform'] __all__ = ['Platform', 'PlatformEnum', 'current_platform']
import enum import enum
from typing import Tuple from typing import Tuple
import torch
class PlatformEnum(enum.Enum): class PlatformEnum(enum.Enum):
CUDA = enum.auto() CUDA = enum.auto()
ROCM = enum.auto() ROCM = enum.auto()
TPU = enum.auto()
UNSPECIFIED = enum.auto()
class Platform: class Platform:
...@@ -16,6 +20,23 @@ class Platform: ...@@ -16,6 +20,23 @@ class Platform:
def is_rocm(self) -> bool: def is_rocm(self) -> bool:
return self._enum == PlatformEnum.ROCM return self._enum == PlatformEnum.ROCM
def is_tpu(self) -> bool:
return self._enum == PlatformEnum.TPU
@staticmethod @staticmethod
def get_device_capability(device_id: int = 0) -> Tuple[int, int]: def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
raise NotImplementedError raise NotImplementedError
@staticmethod
def inference_mode():
    """A device-specific wrapper of `torch.inference_mode`.

    This wrapper is recommended because some hardware backends such as TPU
    do not support `torch.inference_mode`. In such a case, they will fall
    back to `torch.no_grad` by overriding this method.

    Returns:
        The object produced by `torch.inference_mode(mode=True)`.
    """
    return torch.inference_mode(mode=True)
class UnspecifiedPlatform(Platform):
    """Platform tagged with ``PlatformEnum.UNSPECIFIED``.

    It overrides nothing else, so every other behavior comes from
    ``Platform``.
    """

    _enum = PlatformEnum.UNSPECIFIED
from typing import Tuple
import torch
from .interface import Platform, PlatformEnum
class TpuPlatform(Platform):
    """Platform description for TPU devices."""

    _enum = PlatformEnum.TPU

    @staticmethod
    def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
        """Always raise, since there is no device capability on TPU.

        Raises:
            RuntimeError: On every call, regardless of ``device_id``.
        """
        raise RuntimeError("TPU does not have device capability.")

    @staticmethod
    def inference_mode():
        """Return ``torch.no_grad()`` as the inference context on TPU."""
        return torch.no_grad()
...@@ -8,7 +8,6 @@ import torch ...@@ -8,7 +8,6 @@ import torch
from pydantic import Field from pydantic import Field
from typing_extensions import Annotated from typing_extensions import Annotated
import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -189,18 +188,6 @@ class SamplingParams: ...@@ -189,18 +188,6 @@ class SamplingParams:
self._verify_args() self._verify_args()
if self.use_beam_search: if self.use_beam_search:
# Lazy import to avoid circular imports.
from vllm.usage.usage_lib import set_runtime_usage_data
set_runtime_usage_data("use_beam_search", True)
if not envs.VLLM_NO_DEPRECATION_WARNING:
logger.warning(
"[IMPORTANT] We plan to discontinue the support for beam "
"search in the next major release. Please refer to "
"https://github.com/vllm-project/vllm/issues/6226 for "
"more information. Set VLLM_NO_DEPRECATION_WARNING=1 to "
"suppress this warning.")
self._verify_beam_search() self._verify_beam_search()
else: else:
self._verify_non_beam_search() self._verify_non_beam_search()
......
...@@ -5,7 +5,8 @@ import math ...@@ -5,7 +5,8 @@ import math
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple,
Union)
import torch import torch
...@@ -438,7 +439,7 @@ class SequenceGroup: ...@@ -438,7 +439,7 @@ class SequenceGroup:
embeddings: Optional[List[float]] = None, embeddings: Optional[List[float]] = None,
pooling_params: Optional[PoolingParams] = None, pooling_params: Optional[PoolingParams] = None,
encoder_seq: Optional[Sequence] = None, encoder_seq: Optional[Sequence] = None,
trace_headers: Optional[Dict[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None: ) -> None:
self.request_id = request_id self.request_id = request_id
...@@ -457,24 +458,25 @@ class SequenceGroup: ...@@ -457,24 +458,25 @@ class SequenceGroup:
self.prompt_adapter_request = prompt_adapter_request self.prompt_adapter_request = prompt_adapter_request
self.encoder_seq = encoder_seq self.encoder_seq = encoder_seq
self.trace_headers = trace_headers self.trace_headers = trace_headers
self._first_seq = next(iter(self.seqs_dict.values()))
@property @property
def prompt(self) -> Optional[str]: def prompt(self) -> Optional[str]:
# All sequences in the group should have the same prompt. # All sequences in the group should have the same prompt.
# We use the prompt of an arbitrary sequence. # We use the prompt of an arbitrary sequence.
return next(iter(self.seqs_dict.values())).prompt return self._first_seq.prompt
@property @property
def prompt_token_ids(self) -> List[int]: def prompt_token_ids(self) -> List[int]:
# All sequences in the group should have the same prompt. # All sequences in the group should have the same prompt.
# We use the prompt of an arbitrary sequence. # We use the prompt of an arbitrary sequence.
return next(iter(self.seqs_dict.values())).prompt_token_ids return self._first_seq.prompt_token_ids
@property @property
def multi_modal_data(self) -> "MultiModalDataDict": def multi_modal_data(self) -> "MultiModalDataDict":
# All sequences in the group should have the same multi-modal data. # All sequences in the group should have the same multi-modal data.
# We use the multi-modal data of an arbitrary sequence. # We use the multi-modal data of an arbitrary sequence.
return next(iter(self.seqs_dict.values())).multi_modal_data return self._first_seq.multi_modal_data
@property @property
def lora_int_id(self) -> int: def lora_int_id(self) -> int:
......
...@@ -4,7 +4,8 @@ from typing import Iterator, List, Tuple ...@@ -4,7 +4,8 @@ from typing import Iterator, List, Tuple
import torch import torch
from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceData, from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceData,
SequenceGroupMetadata, get_all_seq_ids) SequenceGroupMetadata, SequenceGroupState,
get_all_seq_ids)
from vllm.spec_decode.interfaces import (SpeculativeProposals, from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores) SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.util import (nvtx_range, sampler_output_to_torch, from vllm.spec_decode.util import (nvtx_range, sampler_output_to_torch,
...@@ -292,6 +293,15 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): ...@@ -292,6 +293,15 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
for data in new_seq_data_dict.values(): for data in new_seq_data_dict.values():
data.update_num_computed_tokens(data.get_len() - 1) data.update_num_computed_tokens(data.get_len() - 1)
if (seq_group_metadata.state is not None
and seq_group_metadata.state.generator is not None):
generator = torch.Generator(
device=seq_group_metadata.state.generator.device)
generator.set_state(seq_group_metadata.state.generator.get_state())
state = SequenceGroupState(generator=generator)
else:
state = None
return SequenceGroupMetadata( return SequenceGroupMetadata(
request_id=seq_group_metadata.request_id, request_id=seq_group_metadata.request_id,
is_prompt=seq_group_metadata.is_prompt, is_prompt=seq_group_metadata.is_prompt,
...@@ -302,6 +312,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): ...@@ -302,6 +312,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
}, },
lora_request=None, lora_request=None,
token_chunk_size=1, token_chunk_size=1,
state=state,
) )
def _split_scoring_output( def _split_scoring_output(
......
...@@ -2,17 +2,33 @@ from typing import List, Optional ...@@ -2,17 +2,33 @@ from typing import List, Optional
import torch import torch
from vllm import _custom_ops as ops
try:
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
except ModuleNotFoundError:
# vllm_flash_attn is not installed, use the identical ROCm FA metadata
from vllm.attention.backends.rocm_flash_attn import (
ROCmFlashAttentionMetadata as FlashAttentionMetadata)
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig, ModelConfig, MultiModalConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig) PromptAdapterConfig, SchedulerConfig)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.sequence import (IntermediateTensors, SamplerOutput, from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
SequenceGroupMetadata) SamplerOutput)
from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata, from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata,
ModelRunner) ModelRunner)
logger = init_logger(__name__) logger = init_logger(__name__)
# A flag to enable debug prints for the updated input tensors
# before each step.
debug_advance_input = False
# A flag to allow GPU advance step for draft model runner.
# Set to False for debugging.
allow_gpu_advance_step = True
class TP1DraftModelRunner(ModelRunner): class TP1DraftModelRunner(ModelRunner):
"""Specialized model runner for speculative decoding draft model. """Specialized model runner for speculative decoding draft model.
...@@ -21,18 +37,9 @@ class TP1DraftModelRunner(ModelRunner): ...@@ -21,18 +37,9 @@ class TP1DraftModelRunner(ModelRunner):
we could get rid of most CPU-GPU synchronization and data transfer we could get rid of most CPU-GPU synchronization and data transfer
overheads by keeping model input and output tensors on GPU all the time. overheads by keeping model input and output tensors on GPU all the time.
This runner is still under development so there's no performance gain TODOs:
at this moment. Currently we adopt a temporary solution that caches the 1. Currently supports only flash-attn, add support for other attn_backends.
seq_group_metadata_list for multi-step execution, so that we can 2. Support TP > 1 (this requires some designs because we do not expect
leverage existing prepare_model_input to be compatible with the current
execution flow, but we plan to remove this cache and avoid calling
prepare_model_input in execute_model at all.
The detail development plan includes:
1. Use "update_model_input" to update existing model_input without
creating a new one.
2. Improve the performance of "update_model_input" with a GPU kernel.
3. Support TP > 1 (this requires some designs because we do not expect
any broadcasting inside execute_model). any broadcasting inside execute_model).
""" """
...@@ -71,51 +78,156 @@ class TP1DraftModelRunner(ModelRunner): ...@@ -71,51 +78,156 @@ class TP1DraftModelRunner(ModelRunner):
return_hidden_states=return_hidden_states, return_hidden_states=return_hidden_states,
) )
# TODO: Remove this cache when we are able to update model_input def _update_flash_attn_metadata(self, attn_metadata, num_seqs,
# directly in advance_step. num_queries):
self.cached_seq_group_metadata_list: Optional[ assert isinstance(attn_metadata, FlashAttentionMetadata)
List[SequenceGroupMetadata]] = None
def prepare_model_input( if num_seqs != num_queries:
self, assert num_seqs > num_queries
seq_group_metadata_list: List[SequenceGroupMetadata], assert attn_metadata.use_cuda_graph
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None assert attn_metadata.num_prefills == 0
) -> ModelInputForGPUWithSamplingMetadata: assert attn_metadata.num_prefill_tokens == 0
"""A temporary solution that caches the seq_group_metadata_list assert attn_metadata.num_decode_tokens == num_seqs
for multi-step execution. assert attn_metadata.slot_mapping.shape == (num_seqs, )
TODO: In-place update model_input and remove this function.
""" assert len(attn_metadata.seq_lens) == num_seqs
self.cached_seq_group_metadata_list = seq_group_metadata_list assert attn_metadata.seq_lens_tensor.shape == (num_seqs, )
return super().prepare_model_input( assert attn_metadata.max_query_len == 1
seq_group_metadata_list, assert attn_metadata.max_prefill_seq_len == 0
finished_requests_ids=finished_requests_ids) assert attn_metadata.max_decode_seq_len == max(attn_metadata.seq_lens)
assert attn_metadata.query_start_loc.shape == (num_queries + 1, )
assert attn_metadata.seq_start_loc.shape == (num_seqs + 1, )
assert attn_metadata.context_lens_tensor.shape == (num_queries, )
def update_model_input( assert attn_metadata.block_tables.shape[0] == num_seqs
# Update query lengths. Note that we update only queries and not seqs,
# since tensors may be padded due to captured cuda graph batch size
for i in range(num_queries):
attn_metadata.seq_lens[i] += 1
attn_metadata.max_decode_seq_len = max(attn_metadata.seq_lens)
def _update_sampling_metadata(self, sampling_metadata, num_seqs,
num_queries):
assert sampling_metadata.num_prompts == 0
assert len(sampling_metadata.seq_groups) == num_queries
assert sampling_metadata.selected_token_indices.shape == (
num_queries, )
# assert sampling_metadata.categorized_sample_indices == TODO: Add if needed # noqa: E501
# Verify that all sequences are decodes
for i in range(num_queries):
seq_group = sampling_metadata.seq_groups[i]
assert seq_group.is_prompt is False # No prompt
assert seq_group.prompt_logprob_indices == [] # No prompt
assert seq_group.sample_indices == [i] # Simple
assert seq_group.seq_len is None # Decode
assert seq_group.query_len is None # Decode
def _gpu_advance_step(
self, model_input: ModelInputForGPUWithSamplingMetadata, self, model_input: ModelInputForGPUWithSamplingMetadata,
last_output: SamplerOutput last_output: SamplerOutput
) -> ModelInputForGPUWithSamplingMetadata: ) -> ModelInputForGPUWithSamplingMetadata:
"""Prepare the model inputs for the next step. # Currently, we expect "decode mode" only
TODO: In-place update model_input instead of calling assert not model_input.is_prompt
prepare_model_input.
# Get num_seqs
num_seqs = len(model_input.seq_lens)
num_queries = len(model_input.query_lens)
# Get output tokens GPU tensor
sampled_token_ids = last_output.sampled_token_ids
assert sampled_token_ids is not None
# Update attn_metadata
attn_metadata = model_input.attn_metadata
assert isinstance(attn_metadata, FlashAttentionMetadata)
self._update_flash_attn_metadata(attn_metadata, num_seqs, num_queries)
# Update GPU tensors
ops.advance_step(num_seqs=num_seqs,
num_queries=num_queries,
block_size=self.block_size,
input_tokens=model_input.input_tokens,
sampled_token_ids=sampled_token_ids,
input_positions=model_input.input_positions,
seq_lens=attn_metadata.seq_lens_tensor,
slot_mapping=attn_metadata.slot_mapping,
block_tables=attn_metadata.block_tables)
# Update sampling_metadata
sampling_metadata = model_input.sampling_metadata
self._update_sampling_metadata(sampling_metadata, num_seqs,
num_queries)
# Create new input
new_model_input = self._model_input_cls(
input_tokens=model_input.input_tokens,
input_positions=model_input.input_positions,
attn_metadata=attn_metadata,
seq_lens=attn_metadata.seq_lens,
query_lens=model_input.query_lens,
lora_mapping=model_input.lora_mapping,
lora_requests=model_input.lora_requests,
multi_modal_kwargs=model_input.multi_modal_kwargs,
sampling_metadata=model_input.sampling_metadata,
is_prompt=False,
)
# Ensure we skip CPU samples
assert new_model_input.sampling_metadata.skip_sampler_cpu_output is True
# We can reuse sampling tensors since every decode iteration is the same
new_model_input.sampling_metadata.reuse_sampling_tensors = True
if debug_advance_input:
logger.debug("NEW INPUT: ")
logger.debug(" input_tokens = %s", new_model_input.input_tokens)
logger.debug(" input_positions = %s",
new_model_input.input_positions)
logger.debug(" seq_lens = %d", new_model_input.seq_lens)
logger.debug(" query_lens = %d", new_model_input.query_lens)
logger.debug(" attn_metadata:")
logger.debug(" seq_lens_tensor: %s",
attn_metadata.seq_lens_tensor)
logger.debug(" slot_mapping: %s", attn_metadata.slot_mapping)
logger.debug(" block_tables: %s", attn_metadata.block_tables)
return new_model_input
def supports_gpu_multi_step(self, execute_model_req: ExecuteModelRequest):
"""Determines if draft_model_runner GPU multi-step can be used.
Currently required conditions are:
1. Only decodes
2. Only flash-attn
3. No LORA
4. No prompt_adapter_config
""" """
if not allow_gpu_advance_step:
return False
# Append the output token to the sequence data. # We allow multi-step GPU only in decode mode
assert self.cached_seq_group_metadata_list is not None for seq_group in execute_model_req.seq_group_metadata_list:
for seq_group_metadata, sequence_group_outputs in zip( if seq_group.is_prompt:
self.cached_seq_group_metadata_list, last_output.outputs): return False
seq_group_metadata.is_prompt = False
for seq_output in sequence_group_outputs.samples: # TODO: Add support for other attn backends
seq = seq_group_metadata.seq_data[seq_output.parent_seq_id] if self.attn_backend.get_name() != "flash-attn":
return False
token_id = seq_output.output_token # TODO: Add support for LORA
token_logprob = seq_output.logprobs[token_id] if self.lora_config:
return False
seq.append_token_id(token_id, token_logprob.logprob) # TODO: Add soft-tuning prompt adapter support
seq.update_num_computed_tokens(1) if self.prompt_adapter_config:
return False
return self.prepare_model_input(self.cached_seq_group_metadata_list) return True
@torch.inference_mode() @torch.inference_mode()
def execute_model( def execute_model(
...@@ -125,42 +237,86 @@ class TP1DraftModelRunner(ModelRunner): ...@@ -125,42 +237,86 @@ class TP1DraftModelRunner(ModelRunner):
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1, num_steps: int = 1,
) -> Optional[List[SamplerOutput]]: ) -> Optional[List[SamplerOutput]]:
# Since we do not broadcast data inside execute_model anymore, """Executes num_steps forward passes with advancement of input tensors
# we need to figure out the best way to support TP > 1 in this on the GPU. Look at supports_gpu_multi_step(..) for pre-conditions.
# case, because we will at least need to broadcast the sampled
# tokens to all workers.
if not self.is_driver_worker:
raise ValueError("TP1DraftModelRunner only supports TP=1.")
if self.lora_config: Optimizations used:
assert model_input.lora_requests is not None 1. Input tensors are updated on the GPU directly
assert model_input.lora_mapping is not None 2. Skips GPU=>CPU serialization of sampler outputs (we don't need
self.set_active_loras(model_input.lora_requests, them since we do batch expansion later that uses GPU outputs)
model_input.lora_mapping) 3. Reuses sampling tensors (since we run only decodes and they have
a repeating sampling logic)
"""
if self.prompt_adapter_config: # When num_steps == 1, we execute the fallback here for the GPU
assert model_input.prompt_adapter_requests is not None # advance_step, which runs prepare_inputs on CPU and for each spec
assert model_input.prompt_adapter_mapping is not None # iteration invokes this function only once
self.set_active_prompt_adapters( # (Look at multi-step-worker code)
model_input.prompt_adapter_requests, is_fallback = num_steps == 1
model_input.prompt_adapter_mapping) if not is_fallback:
# Since we do not broadcast data inside execute_model anymore,
# we need to figure out the best way to support TP > 1 in this
# case, because we will at least need to broadcast the sampled
# tokens to all workers.
if not self.is_driver_worker:
raise ValueError("TP1DraftModelRunner only supports TP=1.")
# Sanity
if self.lora_config is not None:
raise ValueError("TP1DraftModelRunner has no support for LORA")
if self.prompt_adapter_config is not None:
raise ValueError("TP1DraftModelRunner has no support for "
"prompt_adapter_config")
if model_input.multi_modal_kwargs:
raise ValueError(
"TP1DraftModelRunner has no support for multi_modal_kwargs"
)
else:
if self.lora_config:
assert model_input.lora_requests is not None
assert model_input.lora_mapping is not None
self.set_active_loras(model_input.lora_requests,
model_input.lora_mapping)
if self.prompt_adapter_config:
assert model_input.prompt_adapter_requests is not None
assert model_input.prompt_adapter_mapping is not None
self.set_active_prompt_adapters(
model_input.prompt_adapter_requests,
model_input.prompt_adapter_mapping)
# Detect exec mode
assert model_input.attn_metadata is not None
use_cuda_graph = False
if model_input.attn_metadata.num_prefills > 0:
# In this case, execute_model(..) was called directly
if num_steps > 1:
raise ValueError(
"execute_model(..) of draft_model_runner can be called "
"directly only with a single-step prefill")
else:
# We can skip CPU samples for spec token generation.
# (We do allow CPU samples for num_steps == 1 to support the
# fallback case, where supports_gpu_multi_step(..) does not pass)
model_input.sampling_metadata.skip_sampler_cpu_output = (
not is_fallback)
# Attn attr defines if we use cuda graphs
use_cuda_graph = model_input.attn_metadata.use_cuda_graph
# Get model
if use_cuda_graph:
graph_batch_size = model_input.input_tokens.shape[0]
model_executable = (self.graph_runners[model_input.virtual_engine]
[graph_batch_size])
else:
model_executable = self.model
virtual_engine = model_input.virtual_engine
outputs: List[SamplerOutput] = [] outputs: List[SamplerOutput] = []
for step in range(num_steps): for step in range(num_steps):
# Currently cuda graph is only supported by the decode phase.
assert model_input.attn_metadata is not None
prefill_meta = model_input.attn_metadata.prefill_metadata
decode_meta = model_input.attn_metadata.decode_metadata
if prefill_meta is None and decode_meta.use_cuda_graph:
assert model_input.input_tokens is not None
graph_batch_size = model_input.input_tokens.shape[0]
model_executable = (
self.graph_runners[virtual_engine][graph_batch_size])
else:
model_executable = self.model
multi_modal_kwargs = model_input.multi_modal_kwargs or {} multi_modal_kwargs = model_input.multi_modal_kwargs or {}
# Run model
hidden_states = model_executable( hidden_states = model_executable(
input_ids=model_input.input_tokens, input_ids=model_input.input_tokens,
positions=model_input.input_positions, positions=model_input.input_positions,
...@@ -181,8 +337,8 @@ class TP1DraftModelRunner(ModelRunner): ...@@ -181,8 +337,8 @@ class TP1DraftModelRunner(ModelRunner):
sampling_metadata=model_input.sampling_metadata, sampling_metadata=model_input.sampling_metadata,
)) ))
# Prepare the inputs for the next step. # Prepare inputs for the next step
if step != num_steps - 1: if step != num_steps - 1:
model_input = self.update_model_input(model_input, outputs[-1]) model_input = self._gpu_advance_step(model_input, outputs[-1])
return outputs return outputs
...@@ -22,6 +22,9 @@ class SpeculativeProposals: ...@@ -22,6 +22,9 @@ class SpeculativeProposals:
# The valid length of each proposal; can be zero. # The valid length of each proposal; can be zero.
proposal_lens: torch.Tensor proposal_lens: torch.Tensor
# A flag to mark that there's no available proposals
no_proposals: bool = False
def __repr__(self): def __repr__(self):
return (f"SpeculativeProposals(" return (f"SpeculativeProposals("
f"proposal_token_ids={self.proposal_token_ids}, " f"proposal_token_ids={self.proposal_token_ids}, "
......
...@@ -145,6 +145,10 @@ class AsyncMetricsCollector: ...@@ -145,6 +145,10 @@ class AsyncMetricsCollector:
""" """
ready_event.synchronize() ready_event.synchronize()
# update time of last collection
self._last_metrics_collect_time = self._timer()
accepted_tokens = self._aggregate_num_accepted_tokens.item() accepted_tokens = self._aggregate_num_accepted_tokens.item()
emitted_tokens = self._aggregate_num_emitted_tokens.item() emitted_tokens = self._aggregate_num_emitted_tokens.item()
draft_tokens = self._aggregate_num_draft_tokens draft_tokens = self._aggregate_num_draft_tokens
......
...@@ -43,7 +43,7 @@ class MultiStepWorker(Worker, ProposerWorkerBase): ...@@ -43,7 +43,7 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
) )
def set_include_gpu_probs_tensor(self) -> None: def set_include_gpu_probs_tensor(self) -> None:
# Need include_gpu_probs_tensor for multi_step_worker # Need include_gpu_probs_tensor for MultiStepWorker
self.model_runner.model.sampler.include_gpu_probs_tensor = True self.model_runner.model.sampler.include_gpu_probs_tensor = True
@torch.inference_mode() @torch.inference_mode()
...@@ -67,14 +67,23 @@ class MultiStepWorker(Worker, ProposerWorkerBase): ...@@ -67,14 +67,23 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
expanded_request, indices_of_seq_with_bonus_tokens =\ expanded_request, indices_of_seq_with_bonus_tokens =\
self._expand_execute_model_request( self._expand_execute_model_request(
execute_model_req, seq_ids_with_bonus_token_in_last_step) execute_model_req, seq_ids_with_bonus_token_in_last_step)
# Run model sample_len times. # Run model sample_len times.
model_outputs: List[SamplerOutput] = [] model_outputs: List[SamplerOutput] = []
if isinstance(self.model_runner, TP1DraftModelRunner): if isinstance(
self.model_runner, TP1DraftModelRunner
) and self.model_runner.supports_gpu_multi_step(expanded_request):
# Here we run the draft_model_runner with multi-step prepare
# on the GPU directly
expanded_request.num_steps = sample_len expanded_request.num_steps = sample_len
model_outputs = self.execute_model( model_outputs = self.execute_model(
execute_model_req=expanded_request) execute_model_req=expanded_request)
else: else:
# TODO: Remove this branch once DraftModelRunner supports TP>1. # Here we run multi-step directly, with every step prepared
# on the CPU.
# TODO: Remove this branch once DraftModelRunner supports TP>1
# and other restrictions that are part of DraftModelRunner's
# supports_gpu_multi_step(..)
for _ in range(sample_len): for _ in range(sample_len):
model_output: List[SamplerOutput] = super().execute_model( model_output: List[SamplerOutput] = super().execute_model(
execute_model_req=expanded_request) execute_model_req=expanded_request)
...@@ -171,7 +180,7 @@ class MultiStepWorker(Worker, ProposerWorkerBase): ...@@ -171,7 +180,7 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
outputs=[ outputs=[
expanded_batch_output.outputs[i] expanded_batch_output.outputs[i]
for i in output_indices_to_retain for i in output_indices_to_retain
], ] if len(expanded_batch_output.outputs) > 0 else [],
sampled_token_probs=( sampled_token_probs=(
expanded_batch_output. expanded_batch_output.
sampled_token_probs[output_indices_to_retain] sampled_token_probs[output_indices_to_retain]
......
...@@ -13,7 +13,7 @@ from vllm.worker.worker_base import LoraNotSupportedWorkerBase ...@@ -13,7 +13,7 @@ from vllm.worker.worker_base import LoraNotSupportedWorkerBase
class NGramWorker(NonLLMProposerWorkerBase, LoraNotSupportedWorkerBase): class NGramWorker(NonLLMProposerWorkerBase, LoraNotSupportedWorkerBase):
"""NGramWorker provides a light drafter without need for model. """NGramWorker provides a light drafter without need for model.
Current NGramWorker only implement prompt lookup decoding, Current NGramWorker only implements prompt lookup decoding,
and in future we may also do RAG type drafter and other scenarios and in future we may also do RAG type drafter and other scenarios
which don't rely on LLM model to give proposals. which don't rely on LLM model to give proposals.
""" """
...@@ -37,7 +37,7 @@ class NGramWorker(NonLLMProposerWorkerBase, LoraNotSupportedWorkerBase): ...@@ -37,7 +37,7 @@ class NGramWorker(NonLLMProposerWorkerBase, LoraNotSupportedWorkerBase):
self.device = torch.device(f"cuda:{self.local_rank}") self.device = torch.device(f"cuda:{self.local_rank}")
self.load_model = lambda *args, **kwargs: None self.load_model = lambda *args, **kwargs: None
# Current only support Top1Proposer # Current NGramWorker only supports Top1Proposer
self._proposer = Top1Proposer( self._proposer = Top1Proposer(
weakref.proxy(self), # type: ignore[arg-type] weakref.proxy(self), # type: ignore[arg-type]
device=self.device, device=self.device,
......
...@@ -24,7 +24,7 @@ class ProposerWorkerBase(LoraNotSupportedWorkerBase, SpeculativeProposer): ...@@ -24,7 +24,7 @@ class ProposerWorkerBase(LoraNotSupportedWorkerBase, SpeculativeProposer):
) -> Tuple[Optional[List[SamplerOutput]], bool]: ) -> Tuple[Optional[List[SamplerOutput]], bool]:
raise NotImplementedError raise NotImplementedError
def set_include_gpu_probs_tensor(self): def set_include_gpu_probs_tensor(self) -> None:
"""Implementation optional""" """Implementation optional"""
pass pass
......
...@@ -9,12 +9,12 @@ from vllm.distributed.communication_op import broadcast_tensor_dict ...@@ -9,12 +9,12 @@ from vllm.distributed.communication_op import broadcast_tensor_dict
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.rejection_sampler import RejectionSampler from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.layers.spec_decode_base_sampler import ( from vllm.model_executor.layers.spec_decode_base_sampler import (
SpecDecodeBaseSampler) SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler)
from vllm.model_executor.layers.typical_acceptance_sampler import ( from vllm.model_executor.layers.typical_acceptance_sampler import (
TypicalAcceptanceSampler) TypicalAcceptanceSampler)
from vllm.sequence import (CompletionSequenceGroupOutput, ExecuteModelRequest, from vllm.sequence import (CompletionSequenceGroupOutput, ExecuteModelRequest,
HiddenStates, SamplerOutput, SequenceGroupMetadata, HiddenStates, SamplerOutput, SequenceGroupMetadata,
get_all_seq_ids_and_request_ids) get_all_seq_ids, get_all_seq_ids_and_request_ids)
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
from vllm.spec_decode.interfaces import (SpeculativeProposals, from vllm.spec_decode.interfaces import (SpeculativeProposals,
...@@ -26,6 +26,7 @@ from vllm.spec_decode.multi_step_worker import MultiStepWorker ...@@ -26,6 +26,7 @@ from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.ngram_worker import NGramWorker from vllm.spec_decode.ngram_worker import NGramWorker
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker
from vllm.spec_decode.target_model_runner import TargetModelRunner
from vllm.spec_decode.util import (create_sequence_group_output, from vllm.spec_decode.util import (create_sequence_group_output,
get_all_num_logprobs, get_all_num_logprobs,
get_sampled_token_logprobs, nvtx_range, get_sampled_token_logprobs, nvtx_range,
...@@ -44,9 +45,15 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker": ...@@ -44,9 +45,15 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
speculative_config: SpeculativeConfig = kwargs.get("speculative_config") speculative_config: SpeculativeConfig = kwargs.get("speculative_config")
assert speculative_config is not None assert speculative_config is not None
draft_worker_kwargs = kwargs.copy()
kwargs["model_runner_cls"] = TargetModelRunner
target_worker = Worker(*args, **kwargs) target_worker = Worker(*args, **kwargs)
# Set the disable_logprobs variable in the TargetModelRunner instance
# as per its value specified in the SpeculativeConfig.
target_worker.model_runner.disable_logprobs =\
speculative_config.disable_logprobs
draft_worker_kwargs = kwargs.copy()
# Override draft-model specific worker args. # Override draft-model specific worker args.
draft_worker_kwargs.update( draft_worker_kwargs.update(
model_config=speculative_config.draft_model_config, model_config=speculative_config.draft_model_config,
...@@ -67,7 +74,8 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker": ...@@ -67,7 +74,8 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
typical_acceptance_sampler_posterior_threshold=speculative_config. typical_acceptance_sampler_posterior_threshold=speculative_config.
typical_acceptance_sampler_posterior_threshold, typical_acceptance_sampler_posterior_threshold,
typical_acceptance_sampler_posterior_alpha=speculative_config. typical_acceptance_sampler_posterior_alpha=speculative_config.
typical_acceptance_sampler_posterior_alpha) typical_acceptance_sampler_posterior_alpha,
disable_logprobs=speculative_config.disable_logprobs)
return spec_decode_worker return spec_decode_worker
...@@ -107,8 +115,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -107,8 +115,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
draft_token_acceptance_method: str, draft_token_acceptance_method: str,
typical_acceptance_sampler_posterior_threshold: float, typical_acceptance_sampler_posterior_threshold: float,
typical_acceptance_sampler_posterior_alpha: float, typical_acceptance_sampler_posterior_alpha: float,
disable_logprobs: bool,
) -> "SpecDecodeWorker": ) -> "SpecDecodeWorker":
allow_zero_draft_token_step = True
ngram_prompt_lookup_max = ( ngram_prompt_lookup_max = (
draft_worker_kwargs.pop("ngram_prompt_lookup_max")) draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
ngram_prompt_lookup_min = ( ngram_prompt_lookup_min = (
...@@ -133,6 +143,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -133,6 +143,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
if draft_tp == 1: if draft_tp == 1:
draft_worker_kwargs[ draft_worker_kwargs[
"model_runner_cls"] = TP1DraftModelRunner "model_runner_cls"] = TP1DraftModelRunner
else:
allow_zero_draft_token_step = False
proposer_worker = MultiStepWorker(**draft_worker_kwargs) proposer_worker = MultiStepWorker(**draft_worker_kwargs)
proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker( proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
...@@ -155,18 +167,23 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -155,18 +167,23 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
logger.info("Configuring SpecDecodeWorker with sampler=%s", logger.info("Configuring SpecDecodeWorker with sampler=%s",
type(spec_decode_sampler)) type(spec_decode_sampler))
return SpecDecodeWorker(proposer_worker, return SpecDecodeWorker(
scorer_worker, proposer_worker,
disable_by_batch_size=disable_by_batch_size, scorer_worker,
spec_decode_sampler=spec_decode_sampler) disable_logprobs=disable_logprobs,
disable_by_batch_size=disable_by_batch_size,
spec_decode_sampler=spec_decode_sampler,
allow_zero_draft_token_step=allow_zero_draft_token_step)
def __init__( def __init__(
self, self,
proposer_worker: ProposerWorkerBase, proposer_worker: ProposerWorkerBase,
scorer_worker: WorkerBase, scorer_worker: WorkerBase,
spec_decode_sampler: SpecDecodeBaseSampler, spec_decode_sampler: SpecDecodeBaseSampler,
disable_logprobs: bool,
metrics_collector: Optional[AsyncMetricsCollector] = None, metrics_collector: Optional[AsyncMetricsCollector] = None,
disable_by_batch_size: Optional[int] = None, disable_by_batch_size: Optional[int] = None,
allow_zero_draft_token_step: Optional[bool] = True,
): ):
""" """
Create a SpecDecodeWorker. Create a SpecDecodeWorker.
...@@ -183,15 +200,22 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -183,15 +200,22 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
types of sampler namely RejectionSampler and types of sampler namely RejectionSampler and
TypicalAcceptanceSampler. 'spec_decode_sampler' is either an TypicalAcceptanceSampler. 'spec_decode_sampler' is either an
instance of RejectionSampler or TypicalAcceptanceSampler. instance of RejectionSampler or TypicalAcceptanceSampler.
disable_logprobs: If set to True, token log probabilities will
not be output by either the draft worker or the target worker.
If set to False, log probabilities will be output by both.
disable_by_batch_size: If the batch size is larger than this, disable_by_batch_size: If the batch size is larger than this,
disable speculative decoding for new incoming requests. disable speculative decoding for new incoming requests.
metrics_collector: Helper class for collecting metrics; can be set metrics_collector: Helper class for collecting metrics; can be set
for testing purposes. for testing purposes.
allow_zero_draft_token_step: whether to allow a step where the draft
model generates no draft token; should disallow when the tp of
draft model is larger than 1 (TODO: #5814)
""" """
self.proposer_worker = proposer_worker self.proposer_worker = proposer_worker
self.scorer_worker = scorer_worker self.scorer_worker = scorer_worker
self.disable_by_batch_size = disable_by_batch_size or float("inf") self.disable_by_batch_size = disable_by_batch_size or float("inf")
self.spec_decode_sampler = spec_decode_sampler self.spec_decode_sampler = spec_decode_sampler
self._allow_zero_draft_token_step = allow_zero_draft_token_step
self._metrics = AsyncMetricsCollector( self._metrics = AsyncMetricsCollector(
self.spec_decode_sampler self.spec_decode_sampler
) if metrics_collector is None else metrics_collector ) if metrics_collector is None else metrics_collector
...@@ -206,12 +230,13 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -206,12 +230,13 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self.probs_dtype = self.spec_decode_sampler.probs_dtype self.probs_dtype = self.spec_decode_sampler.probs_dtype
self.token_id_dtype = self.spec_decode_sampler.token_id_dtype self.token_id_dtype = self.spec_decode_sampler.token_id_dtype
# Lazy initiazliation. # Lazy initialization.
self.scorer: SpeculativeScorer self.scorer: SpeculativeScorer
# Hidden states from target model to pass to proposer # Hidden states from target model to pass to proposer
# in the subsequent step. # in the subsequent step.
self.previous_hidden_states: Optional[HiddenStates] = None self.previous_hidden_states: Optional[HiddenStates] = None
self._disable_logprobs = disable_logprobs
def init_device(self) -> None: def init_device(self) -> None:
"""Initialize both scorer and proposer models. """Initialize both scorer and proposer models.
...@@ -347,7 +372,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -347,7 +372,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
) == 0 or disable_all_speculation: ) == 0 or disable_all_speculation:
return self._run_no_spec(execute_model_req, return self._run_no_spec(execute_model_req,
skip_proposer=disable_all_speculation) skip_proposer=disable_all_speculation)
return self._run_speculative_decoding_step(execute_model_req, return self._run_speculative_decoding_step(execute_model_req,
num_lookahead_slots) num_lookahead_slots)
...@@ -381,6 +405,42 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -381,6 +405,42 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# this state within spec decode worker. # this state within spec decode worker.
seq_group_metadata.num_speculative_tokens = 0 seq_group_metadata.num_speculative_tokens = 0
def _serialize_sampler_output_no_logprobs(
        self, execute_model_req: ExecuteModelRequest,
        sampler_output: SamplerOutput) -> SamplerOutput:
    """
    Build a `SamplerOutput` that carries sampled token IDs only.

    The sampled token IDs are converted to a Python list and one
    `CompletionSequenceGroupOutput` is created per sequence id. Every
    log-probability related field is filled with a placeholder instead of
    being computed (rank -1, logprob 0.0, empty top-k lists).

    Args:
        execute_model_req (ExecuteModelRequest): The request that was
            executed; its `seq_group_metadata_list` supplies the sequence
            ids.
        sampler_output (SamplerOutput): Sampler result whose
            `sampled_token_ids` tensor is read.

    Returns:
        SamplerOutput: A new instance whose `outputs` hold one entry per
            sequence id, each populated with the sampled token id only.
    """
    all_seq_ids = get_all_seq_ids(execute_model_req.seq_group_metadata_list)
    token_rows = sampler_output.sampled_token_ids.tolist()

    # Row `position` belongs to the sequence at the same position; only the
    # first token id of each row is used.
    group_outputs: List[CompletionSequenceGroupOutput] = [
        create_sequence_group_output(
            token_id=token_rows[position][0],
            token_id_logprob_rank=-1,
            token_id_logprob=0.0,
            seq_id=current_seq_id,
            topk_token_ids=[],
            topk_logprobs=[],
        ) for position, current_seq_id in enumerate(all_seq_ids)
    ]
    return SamplerOutput(outputs=group_outputs)
@nvtx_range("spec_decode_worker._run_no_spec") @nvtx_range("spec_decode_worker._run_no_spec")
def _run_no_spec(self, execute_model_req: ExecuteModelRequest, def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
skip_proposer: bool) -> List[SamplerOutput]: skip_proposer: bool) -> List[SamplerOutput]:
...@@ -407,12 +467,17 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -407,12 +467,17 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self.previous_hidden_states.update( self.previous_hidden_states.update(
execute_model_req.seq_group_metadata_list, hidden_states) execute_model_req.seq_group_metadata_list, hidden_states)
sampler_output_to_return = (self._serialize_sampler_output_no_logprobs(
execute_model_req=execute_model_req, sampler_output=sampler_output)
if self._disable_logprobs else
sampler_output)
# Clear device tensors from sampler output. This reduces communication # Clear device tensors from sampler output. This reduces communication
# overhead when the engine runs in a different process than the workers. # overhead when the engine runs in a different process than the workers.
sampler_output.probs = None sampler_output.sampled_token_probs = None
sampler_output.sampled_tokens = None sampler_output.sampled_token_ids = None
sampler_output.logprobs = None sampler_output.logprobs = None
return [sampler_output] return [sampler_output_to_return]
def _run_non_driver_rank(self) -> bool: def _run_non_driver_rank(self) -> bool:
"""Run proposer and verifier model in non-driver workers. This is used """Run proposer and verifier model in non-driver workers. This is used
...@@ -461,11 +526,15 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -461,11 +526,15 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
proposals = self.proposer_worker.get_spec_proposals( proposals = self.proposer_worker.get_spec_proposals(
execute_model_req, self._seq_with_bonus_token_in_last_step) execute_model_req, self._seq_with_bonus_token_in_last_step)
if not self._allow_zero_draft_token_step and proposals.no_proposals:
#TODO: Fix it #5814
raise RuntimeError("Cannot handle cases where distributed draft "
"workers generate no tokens")
proposal_scores = self.scorer.score_proposals( proposal_scores = self.scorer.score_proposals(
execute_model_req, execute_model_req,
proposals, proposals,
) )
accepted_token_ids, target_logprobs = self._verify_tokens( accepted_token_ids, target_logprobs = self._verify_tokens(
execute_model_req.seq_group_metadata_list, proposal_scores, execute_model_req.seq_group_metadata_list, proposal_scores,
proposals, execute_model_req.num_lookahead_slots) proposals, execute_model_req.num_lookahead_slots)
...@@ -521,11 +590,28 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -521,11 +590,28 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# Get proposed tokens. # Get proposed tokens.
proposal_token_ids = proposals.proposal_token_ids[spec_indices] proposal_token_ids = proposals.proposal_token_ids[spec_indices]
# Sampler arguments
sampler_extra_kwargs = {}
if isinstance(self.spec_decode_sampler,
SpecDecodeStochasticBaseSampler):
# Get sequence group state
generators = []
for seq_group_metadata in seq_group_metadata_list:
if (seq_group_metadata.state is not None
and seq_group_metadata.state.generator is not None):
generators.append(seq_group_metadata.state.generator)
else:
generators.append(None)
sampler_extra_kwargs["generators"] = generators
accepted_token_ids = self.spec_decode_sampler( accepted_token_ids = self.spec_decode_sampler(
target_probs=proposal_verifier_probs, target_probs=proposal_verifier_probs,
bonus_token_ids=bonus_token_ids, bonus_token_ids=bonus_token_ids,
draft_probs=proposal_probs, draft_probs=proposal_probs,
draft_token_ids=proposal_token_ids, draft_token_ids=proposal_token_ids,
**sampler_extra_kwargs,
) )
# Append output tokens from non-speculative sequences to # Append output tokens from non-speculative sequences to
...@@ -569,25 +655,27 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -569,25 +655,27 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
the same number of outputs. the same number of outputs.
""" """
batch_size, num_steps = accepted_token_ids.shape batch_size, num_steps = accepted_token_ids.shape
# Organize input tensors by step instead of by sequence.
target_logprobs_by_step = target_logprobs.transpose(0, 1)
accepted_token_ids_by_step = accepted_token_ids.transpose(0, 1) accepted_token_ids_by_step = accepted_token_ids.transpose(0, 1)
if self._disable_logprobs:
# Get the logprobs/rank of the accepted tokens. # We are skipping the logprobs. Hence don't serialize the
(accepted_token_id_ranks_by_step, # logprobs related tensors from the GPU. Instead create
accepted_token_id_logprobs_by_step) = get_sampled_token_logprobs( # empty/dummy lists.
logprob_tensor=target_logprobs_by_step, (accepted_token_id_ranks_by_step,
sampled_token_ids=accepted_token_ids_by_step, accepted_token_id_logprobs_by_step,
) topk_logprobs_by_step, topk_indices_by_step) =\
self._create_dummy_logprob_lists(
# Get the top-k logprobs (which may or may not include the logprob of batch_size, num_steps,
# the accepted token). self.scorer_worker.model_config.max_logprobs)
(topk_logprobs_by_step, else:
topk_indices_by_step) = target_logprobs_by_step.topk( # Organize input tensors by step instead of by sequence.
k=self.scorer_worker.model_config.max_logprobs, target_logprobs_by_step = target_logprobs.transpose(0, 1)
dim=-1, # Serialize all tensors into Python lists.
) (accepted_token_id_ranks_by_step,
accepted_token_id_logprobs_by_step,
topk_logprobs_by_step, topk_indices_by_step) =\
self._create_logprob_lists_from_tensors(
target_logprobs_by_step, accepted_token_ids_by_step,
self.scorer_worker.model_config.max_logprobs)
# Get the sequence ids and num_logprobs (sampling parameter) in the # Get the sequence ids and num_logprobs (sampling parameter) in the
# batch. # batch.
...@@ -596,14 +684,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -596,14 +684,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list) num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list)
# Serialize all tensors to CPU Python lists. # Serialize tensor to CPU Python list.
accepted_token_ids_by_step = accepted_token_ids_by_step.tolist() accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()
accepted_token_id_ranks_by_step = (
accepted_token_id_ranks_by_step.tolist())
accepted_token_id_logprobs_by_step = (
accepted_token_id_logprobs_by_step.tolist())
topk_logprobs_by_step = topk_logprobs_by_step.tolist()
topk_indices_by_step = topk_indices_by_step.tolist()
# Construct the output on a per-step, per-sequence basis. # Construct the output on a per-step, per-sequence basis.
sampler_output_list: List[SamplerOutput] = [] sampler_output_list: List[SamplerOutput] = []
...@@ -645,6 +727,108 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -645,6 +727,108 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
0].spec_decode_worker_metrics = maybe_rejsample_metrics 0].spec_decode_worker_metrics = maybe_rejsample_metrics
return sampler_output_list return sampler_output_list
def _create_dummy_logprob_lists(
self,
batch_size: int,
num_steps: int,
num_top_k: int,
) -> Tuple[List[List[int]], List[List[float]],
List[List[List[Optional[float]]]],
List[List[List[Optional[int]]]]]:
"""
Creates and returns four dummy lists representing token probabilities
and their ranks.
This method initializes and returns:
- The ranks of the accepted tokens, shaped (num_steps, batch_size)
- The log probabilities of the accepted tokens,
shaped (num_steps, batch_size)
- The log probabilities of the top k tokens,
shaped (num_steps, batch_size, num_top_k)
- The token IDs of the top k tokens,
shaped (num_steps, batch_size, num_top_k)
Args:
batch_size (int): The size of the batch.
num_steps (int): The number of steps in the sequence.
num_top_k (int): The number of top-k token log probabilities to
return.
Returns:
A tuple containing four dummy lists as described above.
"""
accepted_token_id_ranks_by_step = [[-1] * batch_size
for _ in range(num_steps)]
accepted_token_id_logprobs_by_step = [[0.0] * batch_size
for _ in range(num_steps)]
topk_logprobs_by_step: List[List[List[Optional[float]]]] = [[
[None] * num_top_k for _ in range(batch_size)
] for _ in range(num_steps)]
topk_indices_by_step: List[List[List[Optional[int]]]] = [[
[None] * num_top_k for _ in range(batch_size)
] for _ in range(num_steps)]
return (accepted_token_id_ranks_by_step,
accepted_token_id_logprobs_by_step, topk_logprobs_by_step,
topk_indices_by_step)
def _create_logprob_lists_from_tensors(
    self,
    target_logprobs_by_step: torch.Tensor,
    accepted_token_ids_by_step: torch.Tensor,
    num_top_k: int,
) -> Tuple[List[List[int]], List[List[float]],
           List[List[List[Optional[float]]]],
           List[List[List[Optional[int]]]]]:
    """
    Compute logprob information from tensors and convert it to lists.

    The returned lists are:
        - ranks of the accepted tokens, shaped (num_steps, batch_size)
        - log probabilities of the accepted tokens,
            shaped (num_steps, batch_size)
        - log probabilities of the top k tokens,
            shaped (num_steps, batch_size, num_top_k)
        - token IDs of the top k tokens,
            shaped (num_steps, batch_size, num_top_k)

    Args:
        target_logprobs_by_step (torch.Tensor): Log probabilities of the
            target model, shaped (num_steps, batch_size, vocab_size).
        accepted_token_ids_by_step (torch.Tensor): Accepted token ids,
            shaped (num_steps, batch_size).
        num_top_k (int): How many top-k entries to extract per position.

    Returns:
        A tuple of the four lists, in the order listed above.
    """
    # Rank and logprob of each accepted token.
    rank_tensor, logprob_tensor = get_sampled_token_logprobs(
        logprob_tensor=target_logprobs_by_step,
        sampled_token_ids=accepted_token_ids_by_step,
    )

    # Top-k over the last dimension; the accepted token may or may not be
    # among these entries.
    topk_result = target_logprobs_by_step.topk(k=num_top_k, dim=-1)
    topk_value_tensor, topk_index_tensor = topk_result

    # Convert every tensor to a plain Python list before returning.
    return (
        rank_tensor.tolist(),
        logprob_tensor.tolist(),
        topk_value_tensor.tolist(),
        topk_index_tensor.tolist(),
    )
def _track_finished_requests(self, execute_model_req: ExecuteModelRequest): def _track_finished_requests(self, execute_model_req: ExecuteModelRequest):
""" """
Removes the finished requests and their associated sequence ids from Removes the finished requests and their associated sequence ids from
......
from typing import List, Optional
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig)
from vllm.sequence import SequenceGroupMetadata
from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata,
ModelRunner)
class TargetModelRunner(ModelRunner):
    """Model runner specialized for the speculative decoding target model.

    With speculative decoding, the tokens that are finally accepted may
    differ from the ones the target model's own sampling selected, so log
    probabilities computed during target sampling can end up unused. Since
    log probabilities are calculated after the accepted tokens are decided,
    skipping them in the target model makes decoding faster. This runner
    sets the SamplingMetadata flag that controls this, based on
    `disable_logprobs`.
    """

    def __init__(self,
                 model_config: ModelConfig,
                 parallel_config: ParallelConfig,
                 scheduler_config: SchedulerConfig,
                 device_config: DeviceConfig,
                 cache_config: CacheConfig,
                 load_config: LoadConfig,
                 lora_config: Optional[LoRAConfig],
                 kv_cache_dtype: Optional[str] = "auto",
                 is_driver_worker: bool = False,
                 prompt_adapter_config: Optional[PromptAdapterConfig] = None,
                 multimodal_config: Optional[MultiModalConfig] = None,
                 return_hidden_states: bool = False):
        # Whether token log probabilities should be skipped. It is assigned
        # before the base initializer runs and defaults to True.
        self.disable_logprobs = True

        # Forward everything else to the base runner unchanged.
        super().__init__(
            model_config=model_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            cache_config=cache_config,
            load_config=load_config,
            lora_config=lora_config,
            kv_cache_dtype=kv_cache_dtype,
            is_driver_worker=is_driver_worker,
            prompt_adapter_config=prompt_adapter_config,
            multimodal_config=multimodal_config,
            return_hidden_states=return_hidden_states,
        )

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForGPUWithSamplingMetadata:
        """Prepare the base model input and set the sampler CPU-output flag.

        The base class builds the model input; afterwards
        `sampling_metadata.skip_sampler_cpu_output` is set to the current
        value of `disable_logprobs`.
        """
        prepared: ModelInputForGPUWithSamplingMetadata = (
            super().prepare_model_input(seq_group_metadata_list,
                                        virtual_engine,
                                        finished_requests_ids))

        # With log probabilities disabled, sampler CPU output generation is
        # skipped and the GPU sampled_token_id tensors are serialized
        # directly as needed. With them enabled, all sampling related
        # tensors, including the logprobs tensors, are synchronized.
        skip_cpu_output = self.disable_logprobs
        prepared.sampling_metadata.skip_sampler_cpu_output = skip_cpu_output
        return prepared
...@@ -108,7 +108,7 @@ class Top1Proposer(SpeculativeProposer): ...@@ -108,7 +108,7 @@ class Top1Proposer(SpeculativeProposer):
proposal_token_ids=proposal_tokens, proposal_token_ids=proposal_tokens,
proposal_probs=proposal_probs, proposal_probs=proposal_probs,
proposal_lens=proposal_lens, proposal_lens=proposal_lens,
) no_proposals=maybe_sampler_output is None)
return proposals return proposals
...@@ -138,7 +138,7 @@ class Top1Proposer(SpeculativeProposer): ...@@ -138,7 +138,7 @@ class Top1Proposer(SpeculativeProposer):
# Currently only proposal lens of 0 or the global batch proposal len # Currently only proposal lens of 0 or the global batch proposal len
# are supported. # are supported.
# If max_proposal_len is defined, then we shall no exccess this # If max_proposal_len is defined, then we shall no exceed this
# quota for nonzero_proposal # quota for nonzero_proposal
new_k = 0 new_k = 0
if (self.max_proposal_len is None if (self.max_proposal_len is None
...@@ -219,7 +219,7 @@ class Top1Proposer(SpeculativeProposer): ...@@ -219,7 +219,7 @@ class Top1Proposer(SpeculativeProposer):
proposal_lens: List[int], proposal_lens: List[int],
nonzero_proposal_len_indices: List[int], nonzero_proposal_len_indices: List[int],
sampler_transposed: bool, sampler_transposed: bool,
) -> Tuple[torch.Tensor, torch.tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""After speculations are produced, merge the speculation results with """After speculations are produced, merge the speculation results with
the skipped sequences. the skipped sequences.
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment