Commit 31f6b24f authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge remote-tracking branch 'mirror/v0.8.2' into v0.8.2-ori

parents 89d1dd57 25f560a6
...@@ -38,6 +38,14 @@ except (ImportError, OSError): ...@@ -38,6 +38,14 @@ except (ImportError, OSError):
SafetensorsStreamer = runai_model_streamer.placeholder_attr( SafetensorsStreamer = runai_model_streamer.placeholder_attr(
"SafetensorsStreamer") "SafetensorsStreamer")
try:
from fastsafetensors import SafeTensorsFileLoader, SingleGroup
except ImportError:
fastsafetensors = PlaceholderModule("fastsafetensors")
SafeTensorsFileLoader = fastsafetensors.placeholder_attr(
"SafeTensorsFileLoader")
SingleGroup = fastsafetensors.placeholder_attr("SingleGroup")
logger = init_logger(__name__) logger = init_logger(__name__)
# use system-level temp directory for file locks, so that multiple users # use system-level temp directory for file locks, so that multiple users
...@@ -452,6 +460,45 @@ def runai_safetensors_weights_iterator( ...@@ -452,6 +460,45 @@ def runai_safetensors_weights_iterator(
yield from streamer.get_tensors() yield from streamer.get_tensors()
def fastsafetensors_weights_iterator(
hf_weights_files: List[str],
use_tqdm_on_load: bool,
) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model safetensor files
using fastsafetensor library."""
if torch.distributed.is_initialized():
pg = torch.distributed.group.WORLD
else:
pg = SingleGroup()
device = torch.device(f'cuda:{pg.rank()}')
weight_files_sub_lists = [
hf_weights_files[i:i + pg.size()]
for i in range(0, len(hf_weights_files), pg.size())
]
for f_list in tqdm(
weight_files_sub_lists,
desc="Loading safetensors using Fastsafetensor loader",
disable=not enable_tqdm(use_tqdm_on_load),
bar_format=_BAR_FORMAT,
):
loader = SafeTensorsFileLoader(pg, device)
rank_file_map = {i: [f] for i, f in enumerate(f_list)}
loader.add_filenames(rank_file_map)
try:
fb = loader.copy_files_to_device()
try:
keys = list(fb.key_to_rank_lidx.keys())
for k in keys:
t = fb.get_tensor(k)
yield k, t
finally:
fb.close()
finally:
loader.close()
def pt_weights_iterator( def pt_weights_iterator(
hf_weights_files: List[str], hf_weights_files: List[str],
use_tqdm_on_load: bool, use_tqdm_on_load: bool,
......
...@@ -15,21 +15,25 @@ ...@@ -15,21 +15,25 @@
# limitations under the License. # limitations under the License.
"""Wrapper around `transformers` models""" """Wrapper around `transformers` models"""
import re import re
from itertools import chain
from typing import Iterable, Literal, Optional, Union from typing import Iterable, Literal, Optional, Union
import torch import torch
from torch import nn from torch import nn
from transformers import AutoModel, PreTrainedModel from transformers import AutoModel, PretrainedConfig, PreTrainedModel
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from vllm.attention import Attention from vllm.attention import Attention
from vllm.config import VllmConfig from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
from vllm.distributed import get_tensor_model_parallel_world_size ParallelConfig, VllmConfig)
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.distributed.utils import get_pp_indices
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
ReplicatedLinear, ReplicatedLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
...@@ -37,8 +41,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -37,8 +41,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsQuant from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant
from .utils import maybe_prefix from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, maybe_prefix)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -53,7 +58,7 @@ def vllm_flash_attention_forward( ...@@ -53,7 +58,7 @@ def vllm_flash_attention_forward(
# Transformers kwargs # Transformers kwargs
scaling: Optional[float] = None, scaling: Optional[float] = None,
# vLLM kwargs # vLLM kwargs
attention_instances: Optional[list[Attention]] = None, attention_instances: Optional[dict[Attention]] = None,
**kwargs): **kwargs):
self_attn = attention_instances[module.layer_idx] self_attn = attention_instances[module.layer_idx]
if scaling is not None: if scaling is not None:
...@@ -72,13 +77,12 @@ def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module): ...@@ -72,13 +77,12 @@ def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module):
def replace_linear_class( def replace_linear_class(
linear: nn.Linear, linear: nn.Linear, style: Literal["colwise", "rowwise"],
style: Literal["colwise", "rowwise"], quant_config: QuantizationConfig
quant_config=None) -> Union[ColumnParallelLinear, RowParallelLinear]: ) -> Union[ColumnParallelLinear, RowParallelLinear]:
""" """
Replace nn.Linear with one of vLLM's tensor parallel linear classes. Replace nn.Linear with one of vLLM's tensor parallel linear classes.
`quant_config` is not yet supported.
Args: Args:
linear (nn.Linear): `nn.Linear` to be replaced. linear (nn.Linear): `nn.Linear` to be replaced.
style (str): Tensor parallel style of the new linear, e.g. "colwise". style (str): Tensor parallel style of the new linear, e.g. "colwise".
...@@ -105,7 +109,7 @@ def replace_linear_class( ...@@ -105,7 +109,7 @@ def replace_linear_class(
) )
class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA): class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
embedding_padding_modules = ["lm_head"] embedding_padding_modules = ["lm_head"]
embedding_modules = ["embed_tokens" embedding_modules = ["embed_tokens"
] # TODO transformers will have a util to get it ] # TODO transformers will have a util to get it
...@@ -114,31 +118,175 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA): ...@@ -114,31 +118,175 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA):
super().__init__() super().__init__()
logger.info("Using Transformers backend.") logger.info("Using Transformers backend.")
config = vllm_config.model_config.hf_config config: PretrainedConfig = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config cache_config: CacheConfig = vllm_config.cache_config
model_config = vllm_config.model_config device_config: DeviceConfig = vllm_config.device_config
parallel_config = vllm_config.parallel_config model_config: ModelConfig = vllm_config.model_config
parallel_config: ParallelConfig = vllm_config.parallel_config
quant_config: QuantizationConfig = vllm_config.quant_config
self.config = config self.config = config
self.cache_config = cache_config
self.device_config = device_config
self.model_config = model_config
self.parallel_config = parallel_config
self.quant_config = quant_config
self.vocab_size = model_config.get_vocab_size() self.vocab_size = model_config.get_vocab_size()
self.unpadded_vocab_size = model_config.get_vocab_size() self.unpadded_vocab_size = model_config.get_vocab_size()
self.model: PreTrainedModel = AutoModel.from_config( self.pp_group = get_pp_group()
self.config, self.pp_size = self.pp_group.world_size
attn_implementation="vllm", self.pp_rank = self.pp_group.rank_in_group
torch_dtype=vllm_config.model_config.dtype, self.tp_size = get_tensor_model_parallel_world_size()
trust_remote_code=vllm_config.model_config.trust_remote_code,
) # Use meta device to delay allocating GPU tensors
with torch.device("meta"):
self.model: PreTrainedModel = AutoModel.from_config(
config,
attn_implementation="vllm",
torch_dtype=model_config.dtype,
trust_remote_code=model_config.trust_remote_code,
)
prefix = self.model.base_model_prefix prefix = self.model.base_model_prefix
# MLP modifications self.pipeline_parallel()
self.apply_base_model_tp_plan(self.model) self.tensor_parallel()
# Input embeddings
if not isinstance(self.model.get_input_embeddings(), PPMissingLayer):
self.model.set_input_embeddings(
VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
quant_config=quant_config,
))
# Attention layers
self.attention_instances = self.create_attention_instances()
# Output embeddings
if not isinstance(getattr(self, "lm_head", None), PPMissingLayer):
self.unpadded_vocab_size = config.vocab_size
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
if config.tie_word_embeddings:
self.lm_head = self.lm_head.tie_weights(
self.model.get_input_embeddings())
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size,
logit_scale)
# Initialize buffers (e.g. rotary embedding inverse frequency)
self.init_buffers(self.model)
# Move remaining meta tensors to device (should happen last)
self.meta_to_empty(self.model)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size))
def pipeline_parallel(self):
"""
Apply the model's pipeline parallelization plan.
"""
if self.pp_size <= 1:
return
# Attention modifications (assumes 1 attention op per hidden layer) if not self.model.supports_pp_plan:
num_heads = model_config.get_num_attention_heads(parallel_config) raise ValueError(
head_size = model_config.get_head_size() f"{type(self.model)} does not support pipeline parallel yet!")
num_kv_heads = model_config.get_num_kv_heads(parallel_config)
self.attention_instances = [ module_lists = []
module_list_idx = None
pp_plan = list(self.model._pp_plan.keys())
for i, name in enumerate(pp_plan):
if isinstance(getattr(self.model, name), nn.ModuleList):
module_lists.append(name)
module_list_idx = i
if len(module_lists) > 1:
raise ValueError(
"Pipeline parallel of models with multiple `ModuleList`s "
"in the base model are not supported yet!")
if module_list_idx is None:
raise ValueError(
f"Could not find `ModuleList` in {type(self.model)}")
# Layers before module list
for name in pp_plan[:module_list_idx]:
if self.pp_group.is_first_rank or (self.config.tie_word_embeddings
and self.pp_group.is_last_rank):
continue
setattr(self.model, name, PPMissingLayer())
# Module list
start_layer, end_layer = get_pp_indices(self.config.num_hidden_layers,
self.pp_rank, self.pp_size)
layers_name = pp_plan[module_list_idx]
layers = getattr(self.model, layers_name)
for i in range(len(layers)):
if start_layer <= i and i < end_layer:
continue
layers[i] = PPMissingLayer(return_tuple=True)
# Layers after module list
for name in pp_plan[module_list_idx + 1:]:
# Modules that should be on last rank
if not self.pp_group.is_last_rank:
setattr(self.model, name, PPMissingLayer())
if not self.pp_group.is_last_rank:
self.lm_head = PPMissingLayer()
def tensor_parallel(self):
"""
Apply the model's tensor parallelization plan.
Currently only supports linear layers.
"""
if self.tp_size > 1 and self.config.base_model_tp_plan is None:
raise ValueError(
f"{type(self.model)} does not support tensor parallel yet!")
tp_plan = self.model._tp_plan
def _tensor_parallel(module: nn.Module, prefix: str = ""):
for child_name, child_module in module.named_children():
qual_name = maybe_prefix(prefix, child_name)
for pattern, style in tp_plan.items():
if re.match(pattern, qual_name) and isinstance(
child_module, nn.Linear):
new_module = replace_linear_class(
child_module, style, self.quant_config)
setattr(module, child_name, new_module)
log_replacement(qual_name, child_module, new_module)
else:
_tensor_parallel(child_module, prefix=qual_name)
_tensor_parallel(self.model)
def create_attention_instances(self) -> dict[int, Attention]:
"""
Create `Attention` instances to inform KV cache allocation.
"""
num_heads = self.model_config.get_num_attention_heads(
self.parallel_config)
head_size = self.model_config.get_head_size()
num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
start, end = get_pp_indices(self.config.num_hidden_layers,
self.pp_rank, self.pp_size)
return {
i:
Attention( Attention(
num_heads=num_heads, num_heads=num_heads,
head_size=head_size, head_size=head_size,
...@@ -146,77 +294,70 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA): ...@@ -146,77 +294,70 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA):
# Transformers, it's updated in vllm_flash_attention_forward # Transformers, it's updated in vllm_flash_attention_forward
scale=head_size**-0.5, scale=head_size**-0.5,
num_kv_heads=num_kv_heads, num_kv_heads=num_kv_heads,
cache_config=cache_config, cache_config=self.cache_config,
quant_config=self.quant_config, quant_config=self.quant_config,
prefix=f"{i}.attn") for i in range(config.num_hidden_layers) prefix=f"{i}.attn")
] for i in range(start, end)
}
# Model modifications
self.replace_vocab_embed_class(self.model)
# ForCausalLM modifications
self.lm_head = ParallelLMHead(self.vocab_size,
config.hidden_size,
quant_config=self.quant_config,
prefix=maybe_prefix(prefix, "lm_head"))
if config.tie_word_embeddings:
self.lm_head.weight = self.model.get_input_embeddings().weight
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
self.vocab_size, logit_scale)
self.sampler = get_sampler()
def apply_base_model_tp_plan(self, module: nn.Module, prefix: str = ""): def init_buffers(self, module: nn.Module):
""" """
Apply the base model tensor parallelization plan to a module. If a `buffer` is on the `meta` device, then its parent
Currently only supports linear layers. `module` is the original module created by:
```python
with torch.device("meta"):
self.model: PreTrainedModel = AutoModel.from_config(...)
```
This means that:
- `type(module)` is a class from `transformers`
- This class is constructed using a `PretrainedConfig`
""" """
if (self.config.base_model_tp_plan is None for name, buffer in module.named_buffers(recurse=False):
and get_tensor_model_parallel_world_size() > 1): if buffer.device == torch.device("meta"):
raise ValueError( new_buffer = getattr(type(module)(self.config), name)
"Trying to run tensor parallelization but the model does not " setattr(module, name, new_buffer)
"support it yet!") for child in module.children():
self.init_buffers(child)
for child_name, child_module in module.named_children():
qual_name = maybe_prefix(prefix, child_name) def meta_to_empty(self, module: nn.Module):
for pattern, style in self.config.base_model_tp_plan.items(): tensors = list(chain(module.buffers(), module.parameters()))
if re.match(pattern, qual_name) and isinstance( if tensors and all(t.device == torch.device("meta") for t in tensors):
child_module, nn.Linear): module.to_empty(device=self.device_config.device)
new_module = replace_linear_class(child_module, style, return # We can stop recursing because to_empty is recursive
self.quant_config) for child in module.children():
setattr(module, child_name, new_module) self.meta_to_empty(child)
log_replacement(qual_name, child_module, new_module)
else:
self.apply_base_model_tp_plan(child_module, prefix=qual_name)
def replace_vocab_embed_class(self, module: nn.Module):
# Use native set input embeddings
new_module = VocabParallelEmbedding(
self.vocab_size,
self.config.hidden_size,
org_num_embeddings=self.vocab_size,
quant_config=None,
)
log_replacement("input embedding", self.model.get_input_embeddings(),
new_module)
module.set_input_embeddings(new_module)
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: Optional[torch.Tensor],
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
model_output = self.model( if not get_pp_group().is_first_rank:
input_ids[None, ...], assert intermediate_tensors is not None
input_ids = None
inputs_embeds = intermediate_tensors["hidden_states"]
if input_ids is not None:
input_ids = input_ids[None, ...]
if inputs_embeds is not None:
inputs_embeds = inputs_embeds[None, ...]
hidden_states = self.model(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
use_cache=False, use_cache=False,
position_ids=positions[None, ...], position_ids=positions[None, ...],
intermediate_tensors=intermediate_tensors,
attention_instances=self.attention_instances, attention_instances=self.attention_instances,
return_dict=False)[0][0, ...] # we remove batch dimension for now return_dict=False)[0][0, ...] # we remove batch dimension for now
return model_output
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
return hidden_states
def compute_logits( def compute_logits(
self, self,
...@@ -238,8 +379,11 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA): ...@@ -238,8 +379,11 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA):
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params = set[str]() loaded_params = set[str]()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if name not in params_dict: # Necessary for some models which use remote code
name = f"{self.model.base_model_prefix}.{name}" if not name.startswith(prefix := self.model.base_model_prefix):
name = maybe_prefix(prefix, name)
if is_pp_missing_parameter(name, self):
continue
if name in params_dict: if name in params_dict:
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
......
...@@ -472,6 +472,16 @@ class PPMissingLayer(torch.nn.Identity): ...@@ -472,6 +472,16 @@ class PPMissingLayer(torch.nn.Identity):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__() super().__init__()
self.return_tuple = kwargs.get("return_tuple", False)
def forward(self, *args, **kwargs):
"""
Return the first arg from args or the first value from kwargs.
Wraps the input in a tuple if `self.return_tuple` is True.
"""
input = args[0] if args else next(iter(kwargs.values()))
return (input, ) if self.return_tuple else input
_CPU_OFFLOAD_BYTES = 0 _CPU_OFFLOAD_BYTES = 0
...@@ -650,4 +660,4 @@ def cast_overflow_tensors( ...@@ -650,4 +660,4 @@ def cast_overflow_tensors(
if tensors.isinf().any() or tensors.isnan().any(): if tensors.isinf().any() or tensors.isnan().any():
clamp_value = torch.finfo(tensors.dtype).max - offset clamp_value = torch.finfo(tensors.dtype).max - offset
tensors = torch.clamp(tensors, min=-clamp_value, max=clamp_value) tensors = torch.clamp(tensors, min=-clamp_value, max=clamp_value)
return tensors return tensors
\ No newline at end of file
...@@ -92,7 +92,7 @@ class CpuPlatform(Platform): ...@@ -92,7 +92,7 @@ class CpuPlatform(Platform):
if kv_cache_space == 0: if kv_cache_space == 0:
cache_config.cpu_kvcache_space_bytes = 4 * GiB_bytes # type: ignore cache_config.cpu_kvcache_space_bytes = 4 * GiB_bytes # type: ignore
logger.warning( logger.warning(
"Environment variable VLLM_CPU_KVCACHE_SPACE (GB) " "Environment variable VLLM_CPU_KVCACHE_SPACE (GiB) "
"for CPU backend is not set, using 4 by default.") "for CPU backend is not set, using 4 by default.")
else: else:
cache_config.cpu_kvcache_space_bytes = kv_cache_space * GiB_bytes # type: ignore # noqa cache_config.cpu_kvcache_space_bytes = kv_cache_space * GiB_bytes # type: ignore # noqa
......
...@@ -14,7 +14,6 @@ from typing_extensions import ParamSpec ...@@ -14,7 +14,6 @@ from typing_extensions import ParamSpec
# import custom ops, trigger op registration # import custom ops, trigger op registration
import vllm._C # noqa import vllm._C # noqa
import vllm.envs as envs import vllm.envs as envs
from vllm.fa_utils import get_flash_attn_version
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import import_pynvml from vllm.utils import import_pynvml
...@@ -258,7 +257,7 @@ class CudaPlatformBase(Platform): ...@@ -258,7 +257,7 @@ class CudaPlatformBase(Platform):
try: try:
import vllm.vllm_flash_attn # noqa: F401 import vllm.vllm_flash_attn # noqa: F401
from vllm.attention.backends.flash_attn import ( # noqa: F401 from vllm.attention.backends.flash_attn import ( # noqa: F401
FlashAttentionBackend) FlashAttentionBackend, flash_attn_supports_fp8)
supported_sizes = \ supported_sizes = \
FlashAttentionBackend.get_supported_head_sizes() FlashAttentionBackend.get_supported_head_sizes()
...@@ -269,10 +268,9 @@ class CudaPlatformBase(Platform): ...@@ -269,10 +268,9 @@ class CudaPlatformBase(Platform):
target_backend = _Backend.XFORMERS target_backend = _Backend.XFORMERS
fp8_kv_cache = (kv_cache_dtype is not None fp8_kv_cache = (kv_cache_dtype is not None
and kv_cache_dtype.startswith("fp8")) and kv_cache_dtype.startswith("fp8"))
if (fp8_kv_cache and get_flash_attn_version() != 3): if (fp8_kv_cache and not flash_attn_supports_fp8()):
logger.info( logger.info(
"Cannot use FlashAttention-2 backend for FP8 KV cache." "Cannot use FlashAttention backend for FP8 KV cache.")
)
logger.warning( logger.warning(
"Please use FlashInfer backend with FP8 KV Cache for " "Please use FlashInfer backend with FP8 KV Cache for "
"better performance by setting environment variable " "better performance by setting environment variable "
......
...@@ -369,8 +369,9 @@ class SamplingParams( ...@@ -369,8 +369,9 @@ class SamplingParams(
self.top_k = -1 self.top_k = -1
self.min_p = 0.0 self.min_p = 0.0
self._verify_greedy_sampling() self._verify_greedy_sampling()
# eos_token_id is added to this by the engine # eos_token_id is added to this by the engine
self._all_stop_token_ids = set(self.stop_token_ids) self._all_stop_token_ids.update(self.stop_token_ids)
def _verify_args(self) -> None: def _verify_args(self) -> None:
if not isinstance(self.n, int): if not isinstance(self.n, int):
......
...@@ -37,7 +37,7 @@ from collections.abc import (AsyncGenerator, Awaitable, Generator, Hashable, ...@@ -37,7 +37,7 @@ from collections.abc import (AsyncGenerator, Awaitable, Generator, Hashable,
from dataclasses import dataclass, field from dataclasses import dataclass, field
from functools import cache, lru_cache, partial, wraps from functools import cache, lru_cache, partial, wraps
from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple, from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple,
Optional, TypeVar, Union) Optional, Type, TypeVar, Union)
from uuid import uuid4 from uuid import uuid4
import cloudpickle import cloudpickle
...@@ -1544,9 +1544,9 @@ class LazyDict(Mapping[str, T], Generic[T]): ...@@ -1544,9 +1544,9 @@ class LazyDict(Mapping[str, T], Generic[T]):
return len(self._factory) return len(self._factory)
class ClassRegistry(UserDict[type[T], _V]): class ClassRegistry(UserDict[Type[T], _V]):
def __getitem__(self, key: type[T]) -> _V: def __getitem__(self, key: Type[T]) -> _V:
for cls in key.mro(): for cls in key.mro():
if cls in self.data: if cls in self.data:
return self.data[cls] return self.data[cls]
......
...@@ -11,10 +11,11 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, ...@@ -11,10 +11,11 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType, AttentionMetadata, AttentionType,
is_quantized_kv_cache) is_quantized_kv_cache)
from vllm.attention.ops.triton_merge_attn_states import merge_attn_states from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
from vllm.fa_utils import get_flash_attn_version
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import cdiv from vllm.utils import cdiv
from vllm.vllm_flash_attn.fa_utils import (flash_attn_supports_fp8,
get_flash_attn_version)
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
...@@ -182,9 +183,6 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -182,9 +183,6 @@ class FlashAttentionImpl(AttentionImpl):
else: else:
self.sliding_window = (sliding_window - 1, 0) self.sliding_window = (sliding_window - 1, 0)
self.kv_cache_dtype = kv_cache_dtype self.kv_cache_dtype = kv_cache_dtype
if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError(
"FlashAttention V1 with FP8 KV cache not yet supported")
if logits_soft_cap is None: if logits_soft_cap is None:
# In flash-attn, setting logits_soft_cap as 0 means no soft cap. # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
logits_soft_cap = 0 logits_soft_cap = 0
...@@ -206,6 +204,10 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -206,6 +204,10 @@ class FlashAttentionImpl(AttentionImpl):
"are not implemented for " "are not implemented for "
"FlashAttentionImpl") "FlashAttentionImpl")
self.vllm_flash_attn_version = get_flash_attn_version() self.vllm_flash_attn_version = get_flash_attn_version()
if is_quantized_kv_cache(self.kv_cache_dtype) \
and not flash_attn_supports_fp8():
raise NotImplementedError(
"FlashAttention does not support fp8 kv-cache on this device.")
def forward( def forward(
self, self,
......
...@@ -196,7 +196,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, ...@@ -196,7 +196,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
AttentionMetadata, AttentionMetadata,
MLAAttentionImpl) MLAAttentionImpl)
from vllm.attention.ops.triton_merge_attn_states import merge_attn_states from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
from vllm.fa_utils import get_flash_attn_version
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearBase, RowParallelLinear, LinearBase, RowParallelLinear,
...@@ -204,6 +203,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -204,6 +203,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import cdiv, round_down from vllm.utils import cdiv, round_down
from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version
try: try:
from vllm.vllm_flash_attn import flash_attn_varlen_func from vllm.vllm_flash_attn import flash_attn_varlen_func
......
...@@ -627,8 +627,7 @@ class Scheduler(SchedulerInterface): ...@@ -627,8 +627,7 @@ class Scheduler(SchedulerInterface):
# Get prompt logprobs for this request. # Get prompt logprobs for this request.
prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
# Transmit partial if chunked prefill & prompt logprobs is enabled if new_token_ids:
if new_token_ids or prompt_logprobs_tensors is not None:
# Add EngineCoreOutput for this Request. # Add EngineCoreOutput for this Request.
outputs.append( outputs.append(
EngineCoreOutput( EngineCoreOutput(
...@@ -639,6 +638,9 @@ class Scheduler(SchedulerInterface): ...@@ -639,6 +638,9 @@ class Scheduler(SchedulerInterface):
new_prompt_logprobs_tensors=prompt_logprobs_tensors, new_prompt_logprobs_tensors=prompt_logprobs_tensors,
stop_reason=request.stop_reason, stop_reason=request.stop_reason,
events=request.take_events())) events=request.take_events()))
else:
# Invariant: EngineCore returns no partial prefill outputs.
assert not prompt_logprobs_tensors
self.scheduled_req_ids.remove(request.request_id) self.scheduled_req_ids.remove(request.request_id)
if not stopped: if not stopped:
......
...@@ -21,14 +21,15 @@ from vllm.lora.request import LoRARequest ...@@ -21,14 +21,15 @@ from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import Device, cdiv, kill_process_tree from vllm.utils import Device, cdiv, kill_process_tree
from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core_client import EngineCoreClient from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.output_processor import OutputProcessor from vllm.v1.engine.output_processor import (OutputProcessor,
RequestOutputCollector)
from vllm.v1.engine.parallel_sampling import ParentRequest from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.engine.processor import Processor from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor from vllm.v1.executor.abstract import Executor
...@@ -176,11 +177,14 @@ class AsyncLLM(EngineClient): ...@@ -176,11 +177,14 @@ class AsyncLLM(EngineClient):
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0, priority: int = 0,
) -> asyncio.Queue[RequestOutput]: ) -> RequestOutputCollector:
"""Add new request to the AsyncLLM.""" """Add new request to the AsyncLLM."""
# Create a new output queue for the request. assert isinstance(params, SamplingParams), \
queue: asyncio.Queue[RequestOutput] = asyncio.Queue() "Pooling is not supported in V1"
# Create a new output collector for the request.
queue = RequestOutputCollector(output_kind=params.output_kind)
# Convert Input --> Request. # Convert Input --> Request.
request = self.processor.process_inputs(request_id, prompt, params, request = self.processor.process_inputs(request_id, prompt, params,
...@@ -189,17 +193,15 @@ class AsyncLLM(EngineClient): ...@@ -189,17 +193,15 @@ class AsyncLLM(EngineClient):
prompt_adapter_request, prompt_adapter_request,
priority) priority)
n = params.n if isinstance(params, SamplingParams) else 1 if params.n == 1:
if n == 1:
await self._add_request(request, None, 0, queue) await self._add_request(request, None, 0, queue)
return queue return queue
# Fan out child requests (for n>1). # Fan out child requests (for n>1).
parent_request = ParentRequest(request_id, params) parent_request = ParentRequest(request_id, params)
for idx in range(n): for idx in range(params.n):
request_id, params = parent_request.get_child_info(idx) request_id, params = parent_request.get_child_info(idx)
child_request = request if idx == n - 1 else copy(request) child_request = request if idx == params.n - 1 else copy(request)
child_request.request_id = request_id child_request.request_id = request_id
child_request.sampling_params = params child_request.sampling_params = params
await self._add_request(child_request, parent_request, idx, queue) await self._add_request(child_request, parent_request, idx, queue)
...@@ -207,7 +209,7 @@ class AsyncLLM(EngineClient): ...@@ -207,7 +209,7 @@ class AsyncLLM(EngineClient):
async def _add_request(self, request: EngineCoreRequest, async def _add_request(self, request: EngineCoreRequest,
parent_req: Optional[ParentRequest], index: int, parent_req: Optional[ParentRequest], index: int,
queue: asyncio.Queue[RequestOutput]): queue: RequestOutputCollector):
# Add the request to OutputProcessor (this process). # Add the request to OutputProcessor (this process).
self.output_processor.add_request(request, parent_req, index, queue) self.output_processor.add_request(request, parent_req, index, queue)
...@@ -272,15 +274,7 @@ class AsyncLLM(EngineClient): ...@@ -272,15 +274,7 @@ class AsyncLLM(EngineClient):
while not finished: while not finished:
# Note: drain queue without await if possible (avoids # Note: drain queue without await if possible (avoids
# task switching under load which helps performance). # task switching under load which helps performance).
out = q.get_nowait() if not q.empty() else await q.get() out = q.get_nowait() or await q.get()
# Coalesce any additional queued outputs
while not q.empty():
next_out = q.get_nowait()
if sampling_params.output_kind == RequestOutputKind.DELTA:
out.add(next_out)
else:
out = next_out
# Note: both OutputProcessor and EngineCore handle their # Note: both OutputProcessor and EngineCore handle their
# own request cleanup based on finished. # own request cleanup based on finished.
......
...@@ -115,7 +115,6 @@ class LogprobsProcessor: ...@@ -115,7 +115,6 @@ class LogprobsProcessor:
num_prompt_tokens, num_logprobs = logprobs.shape num_prompt_tokens, num_logprobs = logprobs.shape
# Pythonize the torch tensors. # Pythonize the torch tensors.
# TODO(rob): experiment with doing this in EngineCore?
prompt_token_ranks = ranks.tolist() prompt_token_ranks = ranks.tolist()
prompt_logprobs = logprobs.tolist() prompt_logprobs = logprobs.tolist()
token_ids = token_ids.tolist() token_ids = token_ids.tolist()
......
...@@ -17,6 +17,46 @@ from vllm.v1.metrics.stats import (IterationStats, LoRARequestStates, ...@@ -17,6 +17,46 @@ from vllm.v1.metrics.stats import (IterationStats, LoRARequestStates,
RequestStateStats) RequestStateStats)
class RequestOutputCollector:
"""
Collects streamed RequestOutputs per individual request,
for hand-off to the consuming asyncio generate task.
When streaming deltas, RequestOutputs are merged if the
producer gets ahead of the consumer.
"""
def __init__(self, output_kind: RequestOutputKind):
self.aggregate = output_kind == RequestOutputKind.DELTA
self.output: Optional[RequestOutput] = None
self.ready = asyncio.Event()
def put(self, output: RequestOutput) -> None:
if self.output is None:
self.output = output
self.ready.set()
elif self.aggregate:
# Coalesce the outputs in delta case.
self.output.add(output)
else:
# Just replace latest in non-delta case.
self.output = output
async def get(self) -> RequestOutput:
while (output := self.output) is None:
await self.ready.wait()
self.output = None
self.ready.clear()
return output
def get_nowait(self) -> Optional[RequestOutput]:
output = self.output
if output is not None:
self.output = None
self.ready.clear()
return output
@dataclass @dataclass
class OutputProcessorOutput: class OutputProcessorOutput:
...@@ -39,7 +79,7 @@ class RequestState: ...@@ -39,7 +79,7 @@ class RequestState:
detokenizer: IncrementalDetokenizer, detokenizer: IncrementalDetokenizer,
max_tokens_param: Optional[int], max_tokens_param: Optional[int],
arrival_time: float, arrival_time: float,
queue: Optional[asyncio.Queue[RequestOutput]], queue: Optional[RequestOutputCollector],
log_stats: bool, log_stats: bool,
): ):
self.request_id = request_id self.request_id = request_id
...@@ -66,7 +106,7 @@ class RequestState: ...@@ -66,7 +106,7 @@ class RequestState:
request: EngineCoreRequest, request: EngineCoreRequest,
parent_req: Optional[ParentRequest], parent_req: Optional[ParentRequest],
request_index: int, request_index: int,
queue: Optional[asyncio.Queue[RequestOutput]], queue: Optional[RequestOutputCollector],
log_stats: bool, log_stats: bool,
) -> "RequestState": ) -> "RequestState":
if not request.sampling_params.detokenize: if not request.sampling_params.detokenize:
...@@ -105,9 +145,7 @@ class RequestState: ...@@ -105,9 +145,7 @@ class RequestState:
finished = finish_reason is not None finished = finish_reason is not None
final_only = self.output_kind == RequestOutputKind.FINAL_ONLY final_only = self.output_kind == RequestOutputKind.FINAL_ONLY
# In follow up, we will switch to invariant where EngineCore if not finished and final_only:
# does not stream partial prefills.
if not finished and (self.is_prefilling or final_only):
# Only the final output is required in FINAL_ONLY mode. # Only the final output is required in FINAL_ONLY mode.
return None return None
...@@ -219,7 +257,7 @@ class OutputProcessor: ...@@ -219,7 +257,7 @@ class OutputProcessor:
request: EngineCoreRequest, request: EngineCoreRequest,
parent_req: Optional[ParentRequest] = None, parent_req: Optional[ParentRequest] = None,
request_index: int = 0, request_index: int = 0,
queue: Optional[asyncio.Queue[RequestOutput]] = None, queue: Optional[RequestOutputCollector] = None,
) -> None: ) -> None:
request_id = request.request_id request_id = request.request_id
if request_id in self.request_states: if request_id in self.request_states:
...@@ -285,19 +323,7 @@ class OutputProcessor: ...@@ -285,19 +323,7 @@ class OutputProcessor:
finish_reason = engine_core_output.finish_reason finish_reason = engine_core_output.finish_reason
stop_reason = engine_core_output.stop_reason stop_reason = engine_core_output.stop_reason
# TODO(andy): prompt logprobs + chunked prefill can req_state.is_prefilling = False
# result in engine core returning an output for a
# partial prefill (in order to send back partial
# prompt logprobs.) This breaks the invariant that
# process_outputs is only operating on engine core
# outputs associated with non-partial completions.
# Currently this is handled by having `is_prefilling`
# check for new decoded tokens, indicating that
# the completion is not partial.
#
# Follow up will aggregate partial prompt logprobs
# in the EngineCore.
req_state.is_prefilling = not new_token_ids
# 2) Detokenize the token ids into text and perform stop checks. # 2) Detokenize the token ids into text and perform stop checks.
stop_string = req_state.detokenizer.update( stop_string = req_state.detokenizer.update(
...@@ -306,8 +332,7 @@ class OutputProcessor: ...@@ -306,8 +332,7 @@ class OutputProcessor:
finish_reason = FinishReason.STOP finish_reason = FinishReason.STOP
stop_reason = stop_string stop_reason = stop_string
# 3) Compute sample and prompt logprobs for request, # 3) Compute sample and prompt logprobs for request, if required.
# if required.
req_state.logprobs_processor.update_from_output(engine_core_output) req_state.logprobs_processor.update_from_output(engine_core_output)
# 4) Create and handle RequestOutput objects. # 4) Create and handle RequestOutput objects.
...@@ -315,7 +340,7 @@ class OutputProcessor: ...@@ -315,7 +340,7 @@ class OutputProcessor:
new_token_ids, finish_reason, stop_reason): new_token_ids, finish_reason, stop_reason):
if req_state.queue is not None: if req_state.queue is not None:
# AsyncLLM: put into queue for handling by generate(). # AsyncLLM: put into queue for handling by generate().
req_state.queue.put_nowait(request_output) req_state.queue.put(request_output)
else: else:
# LLMEngine: return list of RequestOutputs. # LLMEngine: return list of RequestOutputs.
request_outputs.append(request_output) request_outputs.append(request_output)
......
...@@ -4,7 +4,6 @@ import time ...@@ -4,7 +4,6 @@ import time
from collections.abc import Mapping from collections.abc import Mapping
from typing import Optional, Union from typing import Optional, Union
import vllm.platforms
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs, from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
PromptType, SingletonInputsAdapter) PromptType, SingletonInputsAdapter)
...@@ -20,7 +19,10 @@ from vllm.prompt_adapter.request import PromptAdapterRequest ...@@ -20,7 +19,10 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine import EngineCoreRequest
from vllm.v1.structured_output.utils import validate_structured_output_request from vllm.v1.structured_output.backend_guidance import (
validate_guidance_grammar)
from vllm.v1.structured_output.utils import (
validate_structured_output_request_xgrammar)
class Processor: class Processor:
...@@ -120,7 +122,9 @@ class Processor: ...@@ -120,7 +122,9 @@ class Processor:
if not params.guided_decoding or not self.decoding_config: if not params.guided_decoding or not self.decoding_config:
return return
supported_backends = ["xgrammar", "xgrammar:disable-any-whitespace"] supported_backends = [
"xgrammar", "xgrammar:disable-any-whitespace", "guidance", "auto"
]
engine_level_backend = self.decoding_config.guided_decoding_backend engine_level_backend = self.decoding_config.guided_decoding_backend
if engine_level_backend not in supported_backends: if engine_level_backend not in supported_backends:
raise ValueError(f"Only {supported_backends} structured output is " raise ValueError(f"Only {supported_backends} structured output is "
...@@ -134,10 +138,31 @@ class Processor: ...@@ -134,10 +138,31 @@ class Processor:
else: else:
params.guided_decoding.backend = engine_level_backend params.guided_decoding.backend = engine_level_backend
if vllm.platforms.current_platform.is_tpu(): # Request content validation
raise ValueError("Structured output is not supported on TPU.")
if engine_level_backend == "xgrammar":
validate_structured_output_request(params) # xgrammar with no fallback
validate_structured_output_request_xgrammar(params)
params.guided_decoding.backend = "xgrammar"
elif engine_level_backend == "auto":
# "auto" is an opt-in to opinionated behavior where we try to
# choose a backend based on request contents. This is not the
# default as it is less predictable and subject to change
# between releases as feature support changes.
try:
validate_structured_output_request_xgrammar(params)
params.guided_decoding.backend = "xgrammar"
except ValueError:
# The request includes some jsonschema feature(s) that
# are not supported in xgrammar. Fall back to guidance.
params.guided_decoding.backend = "guidance"
if params.guided_decoding.backend == "guidance":
# TODO ideally we would have the LLTokenizer here as Lark syntax
# allows <|special_token|> and similar, see
# https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
# Without tokenizer these are disallowed in grammars.
validate_guidance_grammar(params, tokenizer=None)
def process_inputs( def process_inputs(
self, self,
......
...@@ -100,15 +100,8 @@ class IterationStats: ...@@ -100,15 +100,8 @@ class IterationStats:
num_new_generation_tokens = len(output.new_token_ids) num_new_generation_tokens = len(output.new_token_ids)
self.num_generation_tokens += num_new_generation_tokens self.num_generation_tokens += num_new_generation_tokens
if is_prefilling and num_new_generation_tokens > 0: if is_prefilling:
# TODO(andy): we used to assert that num_new_generation_tokens assert num_new_generation_tokens > 0
# > 0 with an invariant that EngineCore does not stream outputs
# for partially completed prefills (scheduler.update_from_output
# makes EngineCoreOutput iff num_computed_tokens == num_tokens).
# When prompt logprobs are enabled, we currently stream out the
# partially completed prompt.
# This will be reverted in a follow up PR and we should re-enable
# this assertion / invariant.
self.num_prompt_tokens += prompt_len self.num_prompt_tokens += prompt_len
first_token_latency = self._time_since(req_stats.arrival_time) first_token_latency = self._time_since(req_stats.arrival_time)
...@@ -123,16 +116,12 @@ class IterationStats: ...@@ -123,16 +116,12 @@ class IterationStats:
# Process the batch-level "new tokens" engine core event # Process the batch-level "new tokens" engine core event
if is_prefilling: if is_prefilling:
# TODO: re-enable no-output-for-partial-prefills invariant as above req_stats.first_token_ts = engine_core_timestamp
if num_new_generation_tokens > 0:
req_stats.first_token_ts = engine_core_timestamp
else: else:
tpot = engine_core_timestamp - req_stats.last_token_ts tpot = engine_core_timestamp - req_stats.last_token_ts
self.time_per_output_tokens_iter.append(tpot) self.time_per_output_tokens_iter.append(tpot)
# TODO: re-enable no-output-for-partial-prefills invariant as above req_stats.last_token_ts = engine_core_timestamp
if num_new_generation_tokens > 0:
req_stats.last_token_ts = engine_core_timestamp
def update_from_events(self, req_id: str, events: list["EngineCoreEvent"], def update_from_events(self, req_id: str, events: list["EngineCoreEvent"],
is_prefilling: bool, req_stats: RequestStateStats, is_prefilling: bool, req_stats: RequestStateStats,
......
...@@ -39,6 +39,25 @@ class LogprobsTensors(NamedTuple): ...@@ -39,6 +39,25 @@ class LogprobsTensors(NamedTuple):
self.selected_token_ranks.tolist(), self.selected_token_ranks.tolist(),
) )
@staticmethod
def empty_cpu(num_positions: int,
num_tokens_per_position: int) -> "LogprobsTensors":
"""Create empty LogprobsTensors on CPU."""
logprob_token_ids = torch.empty(
(num_positions, num_tokens_per_position),
dtype=torch.int32,
device="cpu")
logprobs = torch.empty_like(logprob_token_ids, dtype=torch.float32)
selected_token_ranks = torch.empty(num_positions,
dtype=torch.int32,
device="cpu")
return LogprobsTensors(
logprob_token_ids=logprob_token_ids,
logprobs=logprobs,
selected_token_ranks=selected_token_ranks,
)
@dataclass @dataclass
class SamplerOutput: class SamplerOutput:
......
# SPDX-License-Identifier: Apache-2.0
from typing import Union
import torch
def compiled_softmax(
logits: torch.Tensor,
temperature: Union[float, torch.Tensor] = 1.0,
) -> torch.Tensor:
"""Faster softmax kernel generated by torch.compile.
Args:
logits: [n, vocab_size]
temperature: [n] or float
"""
# NOTE(woosuk): Avoid recompilation by marking the first dim as dynamic.
torch._dynamo.mark_dynamic(logits, index=0)
if isinstance(temperature, torch.Tensor):
torch._dynamo.mark_dynamic(temperature, index=0)
return _softmax(logits, temperature)
@torch.compile
def _softmax(
logits: torch.Tensor,
temperature: Union[float, torch.Tensor],
) -> torch.Tensor:
logits = logits / temperature
return torch.softmax(logits, dim=-1, dtype=torch.float32)
...@@ -8,7 +8,7 @@ import triton.language as tl ...@@ -8,7 +8,7 @@ import triton.language as tl
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.utils import compiled_softmax from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -67,6 +67,7 @@ class RejectionSampler(nn.Module): ...@@ -67,6 +67,7 @@ class RejectionSampler(nn.Module):
Shape is [num_tokens, vocab_size]. Here, probabilities from Shape is [num_tokens, vocab_size]. Here, probabilities from
different requests are flattened into a single tensor because different requests are flattened into a single tensor because
this is the shape of the output logits. this is the shape of the output logits.
NOTE: `target_logits` can be updated in place to save memory.
bonus_token_ids_tensor (torch.Tensor): bonus_token_ids_tensor (torch.Tensor):
A tensor containing bonus tokens. Shape is [batch_size, 1]. A tensor containing bonus tokens. Shape is [batch_size, 1].
Bonus tokens are added to the end of the sequence if all Bonus tokens are added to the end of the sequence if all
...@@ -83,6 +84,8 @@ class RejectionSampler(nn.Module): ...@@ -83,6 +84,8 @@ class RejectionSampler(nn.Module):
''' '''
assert metadata.max_spec_len <= MAX_SPEC_LEN assert metadata.max_spec_len <= MAX_SPEC_LEN
# [num_tokens, vocab_size] # [num_tokens, vocab_size]
# NOTE(woosuk): `target_logits` can be updated in place inside the
# `compute_probs` function.
target_probs = compute_probs( target_probs = compute_probs(
target_logits, target_logits,
metadata.cu_num_draft_tokens, metadata.cu_num_draft_tokens,
...@@ -245,25 +248,80 @@ def compute_probs( ...@@ -245,25 +248,80 @@ def compute_probs(
return logits return logits
num_tokens = logits.shape[0] num_tokens = logits.shape[0]
batch_size = cu_num_draft_tokens.shape[0] temperature = expand_batch_to_tokens(
expanded_temperature = torch.empty(
(num_tokens, 1),
dtype=torch.float32,
device=logits.device,
)
expand_kernel[(batch_size, )](
expanded_temperature,
sampling_metadata.temperature, sampling_metadata.temperature,
cu_num_draft_tokens, cu_num_draft_tokens,
GREEDY_TEMPERATURE, # replace_from num_tokens,
1, # replace_to replace_from=GREEDY_TEMPERATURE,
MAX_NUM_TOKENS=MAX_SPEC_LEN, replace_to=1,
num_warps=1,
) )
output_prob = compiled_softmax(logits, expanded_temperature) # NOTE(woosuk): Update `logits` in place to avoid allocating a new tensor.
logits.div_(temperature.unsqueeze(-1))
# Get expanded top_k and top_p tensors.
top_k = None
if sampling_metadata.top_k is not None:
top_k = expand_batch_to_tokens(
sampling_metadata.top_k,
cu_num_draft_tokens,
num_tokens,
)
top_p = None
if sampling_metadata.top_p is not None:
top_p = expand_batch_to_tokens(
sampling_metadata.top_p,
cu_num_draft_tokens,
num_tokens,
)
# NOTE(woosuk): `apply_top_k_top_p` uses sorting to calculate the mask,
# which is slow for large vocab sizes. This may cause performance issues.
logits = apply_top_k_top_p(logits, top_k, top_p)
output_prob = logits.softmax(dim=-1, dtype=torch.float32)
return output_prob return output_prob
def expand_batch_to_tokens(
x: torch.Tensor, # [batch_size]
cu_num_tokens: torch.Tensor, # [batch_size]
num_tokens: int,
replace_from: int = 0,
replace_to: int = 0,
) -> torch.Tensor:
"""Expand [batch_size] tensor to [num_tokens] tensor based on the number of
tokens per batch in cu_num_tokens.
For example, if x = [a, b, c] and cu_num_tokens = [2, 5, 6], then
num_tokens = 6, and expanded_x = [a, a, b, b, b, c].
Args:
x: [batch_size] tensor to expand.
cu_num_tokens: [batch_size] tensor containing the cumulative number of
tokens per batch. Each element represents the total number of
tokens up to and including that batch.
num_tokens: Total number of tokens.
replace_from: int = 0
Value to be replaced if it is found in x.
replace_to: int = 0
Value to replace with when replace_from is found.
Returns:
expanded_x: [num_tokens] tensor.
"""
batch_size = x.shape[0]
assert cu_num_tokens.shape[0] == batch_size
expanded_x = x.new_empty(num_tokens)
expand_kernel[(batch_size, )](
expanded_x,
x,
cu_num_tokens,
replace_from,
replace_to,
MAX_NUM_TOKENS=MAX_SPEC_LEN, # To avoid recompilation.
num_warps=1,
)
return expanded_x
def generate_uniform_probs( def generate_uniform_probs(
num_tokens: int, num_tokens: int,
num_draft_tokens: list[int], num_draft_tokens: list[int],
......
...@@ -137,7 +137,7 @@ class Sampler(nn.Module): ...@@ -137,7 +137,7 @@ class Sampler(nn.Module):
Gather logprobs for topk and sampled/prompt token. Gather logprobs for topk and sampled/prompt token.
Args: Args:
logits: (num tokens) x (vocab) tensor logprobs: (num tokens) x (vocab) tensor
num_logprobs: minimum number of logprobs to num_logprobs: minimum number of logprobs to
retain per token retain per token
token_ids: prompt tokens (if prompt logprobs) token_ids: prompt tokens (if prompt logprobs)
......
...@@ -3,10 +3,7 @@ from vllm.v1.worker.gpu_input_batch import InputBatch ...@@ -3,10 +3,7 @@ from vllm.v1.worker.gpu_input_batch import InputBatch
def is_spec_decode_supported(req_id: str, input_batch: InputBatch) -> bool: def is_spec_decode_supported(req_id: str, input_batch: InputBatch) -> bool:
if req_id in input_batch.top_k_reqs or req_id in input_batch.top_p_reqs: if req_id in input_batch.min_p_reqs:
# Spec decode doesn't support top_p/top_k sampling.
return False
elif req_id in input_batch.min_p_reqs:
# Spec decode doesn't support min_p sampling. # Spec decode doesn't support min_p sampling.
return False return False
elif (req_id in input_batch.frequency_penalties_reqs elif (req_id in input_batch.frequency_penalties_reqs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment