Commit 31f6b24f authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge remote-tracking branch 'mirror/v0.8.2' into v0.8.2-ori

parents 89d1dd57 25f560a6
......@@ -38,6 +38,14 @@ except (ImportError, OSError):
SafetensorsStreamer = runai_model_streamer.placeholder_attr(
"SafetensorsStreamer")
try:
from fastsafetensors import SafeTensorsFileLoader, SingleGroup
except ImportError:
fastsafetensors = PlaceholderModule("fastsafetensors")
SafeTensorsFileLoader = fastsafetensors.placeholder_attr(
"SafeTensorsFileLoader")
SingleGroup = fastsafetensors.placeholder_attr("SingleGroup")
logger = init_logger(__name__)
# use system-level temp directory for file locks, so that multiple users
......@@ -452,6 +460,45 @@ def runai_safetensors_weights_iterator(
yield from streamer.get_tensors()
def fastsafetensors_weights_iterator(
hf_weights_files: List[str],
use_tqdm_on_load: bool,
) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model safetensor files
using fastsafetensor library."""
if torch.distributed.is_initialized():
pg = torch.distributed.group.WORLD
else:
pg = SingleGroup()
device = torch.device(f'cuda:{pg.rank()}')
weight_files_sub_lists = [
hf_weights_files[i:i + pg.size()]
for i in range(0, len(hf_weights_files), pg.size())
]
for f_list in tqdm(
weight_files_sub_lists,
desc="Loading safetensors using Fastsafetensor loader",
disable=not enable_tqdm(use_tqdm_on_load),
bar_format=_BAR_FORMAT,
):
loader = SafeTensorsFileLoader(pg, device)
rank_file_map = {i: [f] for i, f in enumerate(f_list)}
loader.add_filenames(rank_file_map)
try:
fb = loader.copy_files_to_device()
try:
keys = list(fb.key_to_rank_lidx.keys())
for k in keys:
t = fb.get_tensor(k)
yield k, t
finally:
fb.close()
finally:
loader.close()
def pt_weights_iterator(
hf_weights_files: List[str],
use_tqdm_on_load: bool,
......
......@@ -15,21 +15,25 @@
# limitations under the License.
"""Wrapper around `transformers` models"""
import re
from itertools import chain
from typing import Iterable, Literal, Optional, Union
import torch
from torch import nn
from transformers import AutoModel, PreTrainedModel
from transformers import AutoModel, PretrainedConfig, PreTrainedModel
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from vllm.attention import Attention
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, VllmConfig)
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.distributed.utils import get_pp_indices
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
......@@ -37,8 +41,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsQuant
from .utils import maybe_prefix
from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant
from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, maybe_prefix)
logger = init_logger(__name__)
......@@ -53,7 +58,7 @@ def vllm_flash_attention_forward(
# Transformers kwargs
scaling: Optional[float] = None,
# vLLM kwargs
attention_instances: Optional[list[Attention]] = None,
attention_instances: Optional[dict[Attention]] = None,
**kwargs):
self_attn = attention_instances[module.layer_idx]
if scaling is not None:
......@@ -72,13 +77,12 @@ def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module):
def replace_linear_class(
linear: nn.Linear,
style: Literal["colwise", "rowwise"],
quant_config=None) -> Union[ColumnParallelLinear, RowParallelLinear]:
linear: nn.Linear, style: Literal["colwise", "rowwise"],
quant_config: QuantizationConfig
) -> Union[ColumnParallelLinear, RowParallelLinear]:
"""
Replace nn.Linear with one of vLLM's tensor parallel linear classes.
`quant_config` is not yet supported.
Args:
linear (nn.Linear): `nn.Linear` to be replaced.
style (str): Tensor parallel style of the new linear, e.g. "colwise".
......@@ -105,7 +109,7 @@ def replace_linear_class(
)
class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA):
class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
embedding_padding_modules = ["lm_head"]
embedding_modules = ["embed_tokens"
] # TODO transformers will have a util to get it
......@@ -114,31 +118,175 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA):
super().__init__()
logger.info("Using Transformers backend.")
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
model_config = vllm_config.model_config
parallel_config = vllm_config.parallel_config
config: PretrainedConfig = vllm_config.model_config.hf_config
cache_config: CacheConfig = vllm_config.cache_config
device_config: DeviceConfig = vllm_config.device_config
model_config: ModelConfig = vllm_config.model_config
parallel_config: ParallelConfig = vllm_config.parallel_config
quant_config: QuantizationConfig = vllm_config.quant_config
self.config = config
self.cache_config = cache_config
self.device_config = device_config
self.model_config = model_config
self.parallel_config = parallel_config
self.quant_config = quant_config
self.vocab_size = model_config.get_vocab_size()
self.unpadded_vocab_size = model_config.get_vocab_size()
self.model: PreTrainedModel = AutoModel.from_config(
self.config,
attn_implementation="vllm",
torch_dtype=vllm_config.model_config.dtype,
trust_remote_code=vllm_config.model_config.trust_remote_code,
)
self.pp_group = get_pp_group()
self.pp_size = self.pp_group.world_size
self.pp_rank = self.pp_group.rank_in_group
self.tp_size = get_tensor_model_parallel_world_size()
# Use meta device to delay allocating GPU tensors
with torch.device("meta"):
self.model: PreTrainedModel = AutoModel.from_config(
config,
attn_implementation="vllm",
torch_dtype=model_config.dtype,
trust_remote_code=model_config.trust_remote_code,
)
prefix = self.model.base_model_prefix
# MLP modifications
self.apply_base_model_tp_plan(self.model)
self.pipeline_parallel()
self.tensor_parallel()
# Input embeddings
if not isinstance(self.model.get_input_embeddings(), PPMissingLayer):
self.model.set_input_embeddings(
VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
quant_config=quant_config,
))
# Attention layers
self.attention_instances = self.create_attention_instances()
# Output embeddings
if not isinstance(getattr(self, "lm_head", None), PPMissingLayer):
self.unpadded_vocab_size = config.vocab_size
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
if config.tie_word_embeddings:
self.lm_head = self.lm_head.tie_weights(
self.model.get_input_embeddings())
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size,
logit_scale)
# Initialize buffers (e.g. rotary embedding inverse frequency)
self.init_buffers(self.model)
# Move remaining meta tensors to device (should happen last)
self.meta_to_empty(self.model)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size))
def pipeline_parallel(self):
"""
Apply the model's pipeline parallelization plan.
"""
if self.pp_size <= 1:
return
# Attention modifications (assumes 1 attention op per hidden layer)
num_heads = model_config.get_num_attention_heads(parallel_config)
head_size = model_config.get_head_size()
num_kv_heads = model_config.get_num_kv_heads(parallel_config)
self.attention_instances = [
if not self.model.supports_pp_plan:
raise ValueError(
f"{type(self.model)} does not support pipeline parallel yet!")
module_lists = []
module_list_idx = None
pp_plan = list(self.model._pp_plan.keys())
for i, name in enumerate(pp_plan):
if isinstance(getattr(self.model, name), nn.ModuleList):
module_lists.append(name)
module_list_idx = i
if len(module_lists) > 1:
raise ValueError(
"Pipeline parallel of models with multiple `ModuleList`s "
"in the base model are not supported yet!")
if module_list_idx is None:
raise ValueError(
f"Could not find `ModuleList` in {type(self.model)}")
# Layers before module list
for name in pp_plan[:module_list_idx]:
if self.pp_group.is_first_rank or (self.config.tie_word_embeddings
and self.pp_group.is_last_rank):
continue
setattr(self.model, name, PPMissingLayer())
# Module list
start_layer, end_layer = get_pp_indices(self.config.num_hidden_layers,
self.pp_rank, self.pp_size)
layers_name = pp_plan[module_list_idx]
layers = getattr(self.model, layers_name)
for i in range(len(layers)):
if start_layer <= i and i < end_layer:
continue
layers[i] = PPMissingLayer(return_tuple=True)
# Layers after module list
for name in pp_plan[module_list_idx + 1:]:
# Modules that should be on last rank
if not self.pp_group.is_last_rank:
setattr(self.model, name, PPMissingLayer())
if not self.pp_group.is_last_rank:
self.lm_head = PPMissingLayer()
def tensor_parallel(self):
"""
Apply the model's tensor parallelization plan.
Currently only supports linear layers.
"""
if self.tp_size > 1 and self.config.base_model_tp_plan is None:
raise ValueError(
f"{type(self.model)} does not support tensor parallel yet!")
tp_plan = self.model._tp_plan
def _tensor_parallel(module: nn.Module, prefix: str = ""):
for child_name, child_module in module.named_children():
qual_name = maybe_prefix(prefix, child_name)
for pattern, style in tp_plan.items():
if re.match(pattern, qual_name) and isinstance(
child_module, nn.Linear):
new_module = replace_linear_class(
child_module, style, self.quant_config)
setattr(module, child_name, new_module)
log_replacement(qual_name, child_module, new_module)
else:
_tensor_parallel(child_module, prefix=qual_name)
_tensor_parallel(self.model)
def create_attention_instances(self) -> dict[int, Attention]:
"""
Create `Attention` instances to inform KV cache allocation.
"""
num_heads = self.model_config.get_num_attention_heads(
self.parallel_config)
head_size = self.model_config.get_head_size()
num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
start, end = get_pp_indices(self.config.num_hidden_layers,
self.pp_rank, self.pp_size)
return {
i:
Attention(
num_heads=num_heads,
head_size=head_size,
......@@ -146,77 +294,70 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA):
# Transformers, it's updated in vllm_flash_attention_forward
scale=head_size**-0.5,
num_kv_heads=num_kv_heads,
cache_config=cache_config,
cache_config=self.cache_config,
quant_config=self.quant_config,
prefix=f"{i}.attn") for i in range(config.num_hidden_layers)
]
# Model modifications
self.replace_vocab_embed_class(self.model)
# ForCausalLM modifications
self.lm_head = ParallelLMHead(self.vocab_size,
config.hidden_size,
quant_config=self.quant_config,
prefix=maybe_prefix(prefix, "lm_head"))
if config.tie_word_embeddings:
self.lm_head.weight = self.model.get_input_embeddings().weight
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
self.vocab_size, logit_scale)
self.sampler = get_sampler()
prefix=f"{i}.attn")
for i in range(start, end)
}
def apply_base_model_tp_plan(self, module: nn.Module, prefix: str = ""):
def init_buffers(self, module: nn.Module):
"""
Apply the base model tensor parallelization plan to a module.
Currently only supports linear layers.
If a `buffer` is on the `meta` device, then its parent
`module` is the original module created by:
```python
with torch.device("meta"):
self.model: PreTrainedModel = AutoModel.from_config(...)
```
This means that:
- `type(module)` is a class from `transformers`
- This class is constructed using a `PretrainedConfig`
"""
if (self.config.base_model_tp_plan is None
and get_tensor_model_parallel_world_size() > 1):
raise ValueError(
"Trying to run tensor parallelization but the model does not "
"support it yet!")
for child_name, child_module in module.named_children():
qual_name = maybe_prefix(prefix, child_name)
for pattern, style in self.config.base_model_tp_plan.items():
if re.match(pattern, qual_name) and isinstance(
child_module, nn.Linear):
new_module = replace_linear_class(child_module, style,
self.quant_config)
setattr(module, child_name, new_module)
log_replacement(qual_name, child_module, new_module)
else:
self.apply_base_model_tp_plan(child_module, prefix=qual_name)
def replace_vocab_embed_class(self, module: nn.Module):
# Use native set input embeddings
new_module = VocabParallelEmbedding(
self.vocab_size,
self.config.hidden_size,
org_num_embeddings=self.vocab_size,
quant_config=None,
)
log_replacement("input embedding", self.model.get_input_embeddings(),
new_module)
module.set_input_embeddings(new_module)
for name, buffer in module.named_buffers(recurse=False):
if buffer.device == torch.device("meta"):
new_buffer = getattr(type(module)(self.config), name)
setattr(module, name, new_buffer)
for child in module.children():
self.init_buffers(child)
def meta_to_empty(self, module: nn.Module):
tensors = list(chain(module.buffers(), module.parameters()))
if tensors and all(t.device == torch.device("meta") for t in tensors):
module.to_empty(device=self.device_config.device)
return # We can stop recursing because to_empty is recursive
for child in module.children():
self.meta_to_empty(child)
def forward(
self,
input_ids: torch.Tensor,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
model_output = self.model(
input_ids[None, ...],
if not get_pp_group().is_first_rank:
assert intermediate_tensors is not None
input_ids = None
inputs_embeds = intermediate_tensors["hidden_states"]
if input_ids is not None:
input_ids = input_ids[None, ...]
if inputs_embeds is not None:
inputs_embeds = inputs_embeds[None, ...]
hidden_states = self.model(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
use_cache=False,
position_ids=positions[None, ...],
intermediate_tensors=intermediate_tensors,
attention_instances=self.attention_instances,
return_dict=False)[0][0, ...] # we remove batch dimension for now
return model_output
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
return hidden_states
def compute_logits(
self,
......@@ -238,8 +379,11 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA):
params_dict = dict(self.named_parameters())
loaded_params = set[str]()
for name, loaded_weight in weights:
if name not in params_dict:
name = f"{self.model.base_model_prefix}.{name}"
# Necessary for some models which use remote code
if not name.startswith(prefix := self.model.base_model_prefix):
name = maybe_prefix(prefix, name)
if is_pp_missing_parameter(name, self):
continue
if name in params_dict:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
......
......@@ -472,6 +472,16 @@ class PPMissingLayer(torch.nn.Identity):
def __init__(self, *args, **kwargs):
super().__init__()
self.return_tuple = kwargs.get("return_tuple", False)
def forward(self, *args, **kwargs):
"""
Return the first arg from args or the first value from kwargs.
Wraps the input in a tuple if `self.return_tuple` is True.
"""
input = args[0] if args else next(iter(kwargs.values()))
return (input, ) if self.return_tuple else input
_CPU_OFFLOAD_BYTES = 0
......@@ -650,4 +660,4 @@ def cast_overflow_tensors(
if tensors.isinf().any() or tensors.isnan().any():
clamp_value = torch.finfo(tensors.dtype).max - offset
tensors = torch.clamp(tensors, min=-clamp_value, max=clamp_value)
return tensors
\ No newline at end of file
return tensors
......@@ -92,7 +92,7 @@ class CpuPlatform(Platform):
if kv_cache_space == 0:
cache_config.cpu_kvcache_space_bytes = 4 * GiB_bytes # type: ignore
logger.warning(
"Environment variable VLLM_CPU_KVCACHE_SPACE (GB) "
"Environment variable VLLM_CPU_KVCACHE_SPACE (GiB) "
"for CPU backend is not set, using 4 by default.")
else:
cache_config.cpu_kvcache_space_bytes = kv_cache_space * GiB_bytes # type: ignore # noqa
......
......@@ -14,7 +14,6 @@ from typing_extensions import ParamSpec
# import custom ops, trigger op registration
import vllm._C # noqa
import vllm.envs as envs
from vllm.fa_utils import get_flash_attn_version
from vllm.logger import init_logger
from vllm.utils import import_pynvml
......@@ -258,7 +257,7 @@ class CudaPlatformBase(Platform):
try:
import vllm.vllm_flash_attn # noqa: F401
from vllm.attention.backends.flash_attn import ( # noqa: F401
FlashAttentionBackend)
FlashAttentionBackend, flash_attn_supports_fp8)
supported_sizes = \
FlashAttentionBackend.get_supported_head_sizes()
......@@ -269,10 +268,9 @@ class CudaPlatformBase(Platform):
target_backend = _Backend.XFORMERS
fp8_kv_cache = (kv_cache_dtype is not None
and kv_cache_dtype.startswith("fp8"))
if (fp8_kv_cache and get_flash_attn_version() != 3):
if (fp8_kv_cache and not flash_attn_supports_fp8()):
logger.info(
"Cannot use FlashAttention-2 backend for FP8 KV cache."
)
"Cannot use FlashAttention backend for FP8 KV cache.")
logger.warning(
"Please use FlashInfer backend with FP8 KV Cache for "
"better performance by setting environment variable "
......
......@@ -369,8 +369,9 @@ class SamplingParams(
self.top_k = -1
self.min_p = 0.0
self._verify_greedy_sampling()
# eos_token_id is added to this by the engine
self._all_stop_token_ids = set(self.stop_token_ids)
self._all_stop_token_ids.update(self.stop_token_ids)
def _verify_args(self) -> None:
if not isinstance(self.n, int):
......
......@@ -37,7 +37,7 @@ from collections.abc import (AsyncGenerator, Awaitable, Generator, Hashable,
from dataclasses import dataclass, field
from functools import cache, lru_cache, partial, wraps
from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple,
Optional, TypeVar, Union)
Optional, Type, TypeVar, Union)
from uuid import uuid4
import cloudpickle
......@@ -1544,9 +1544,9 @@ class LazyDict(Mapping[str, T], Generic[T]):
return len(self._factory)
class ClassRegistry(UserDict[type[T], _V]):
class ClassRegistry(UserDict[Type[T], _V]):
def __getitem__(self, key: type[T]) -> _V:
def __getitem__(self, key: Type[T]) -> _V:
for cls in key.mro():
if cls in self.data:
return self.data[cls]
......
......@@ -11,10 +11,11 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType,
is_quantized_kv_cache)
from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
from vllm.fa_utils import get_flash_attn_version
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import cdiv
from vllm.vllm_flash_attn.fa_utils import (flash_attn_supports_fp8,
get_flash_attn_version)
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
......@@ -182,9 +183,6 @@ class FlashAttentionImpl(AttentionImpl):
else:
self.sliding_window = (sliding_window - 1, 0)
self.kv_cache_dtype = kv_cache_dtype
if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError(
"FlashAttention V1 with FP8 KV cache not yet supported")
if logits_soft_cap is None:
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
logits_soft_cap = 0
......@@ -206,6 +204,10 @@ class FlashAttentionImpl(AttentionImpl):
"are not implemented for "
"FlashAttentionImpl")
self.vllm_flash_attn_version = get_flash_attn_version()
if is_quantized_kv_cache(self.kv_cache_dtype) \
and not flash_attn_supports_fp8():
raise NotImplementedError(
"FlashAttention does not support fp8 kv-cache on this device.")
def forward(
self,
......
......@@ -196,7 +196,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
AttentionMetadata,
MLAAttentionImpl)
from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
from vllm.fa_utils import get_flash_attn_version
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearBase, RowParallelLinear,
......@@ -204,6 +203,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.platforms import current_platform
from vllm.utils import cdiv, round_down
from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version
try:
from vllm.vllm_flash_attn import flash_attn_varlen_func
......
......@@ -627,8 +627,7 @@ class Scheduler(SchedulerInterface):
# Get prompt logprobs for this request.
prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
# Transmit partial if chunked prefill & prompt logprobs is enabled
if new_token_ids or prompt_logprobs_tensors is not None:
if new_token_ids:
# Add EngineCoreOutput for this Request.
outputs.append(
EngineCoreOutput(
......@@ -639,6 +638,9 @@ class Scheduler(SchedulerInterface):
new_prompt_logprobs_tensors=prompt_logprobs_tensors,
stop_reason=request.stop_reason,
events=request.take_events()))
else:
# Invariant: EngineCore returns no partial prefill outputs.
assert not prompt_logprobs_tensors
self.scheduled_req_ids.remove(request.request_id)
if not stopped:
......
......@@ -21,14 +21,15 @@ from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Device, cdiv, kill_process_tree
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.engine.output_processor import (OutputProcessor,
RequestOutputCollector)
from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor
......@@ -176,11 +177,14 @@ class AsyncLLM(EngineClient):
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> asyncio.Queue[RequestOutput]:
) -> RequestOutputCollector:
"""Add new request to the AsyncLLM."""
# Create a new output queue for the request.
queue: asyncio.Queue[RequestOutput] = asyncio.Queue()
assert isinstance(params, SamplingParams), \
"Pooling is not supported in V1"
# Create a new output collector for the request.
queue = RequestOutputCollector(output_kind=params.output_kind)
# Convert Input --> Request.
request = self.processor.process_inputs(request_id, prompt, params,
......@@ -189,17 +193,15 @@ class AsyncLLM(EngineClient):
prompt_adapter_request,
priority)
n = params.n if isinstance(params, SamplingParams) else 1
if n == 1:
if params.n == 1:
await self._add_request(request, None, 0, queue)
return queue
# Fan out child requests (for n>1).
parent_request = ParentRequest(request_id, params)
for idx in range(n):
for idx in range(params.n):
request_id, params = parent_request.get_child_info(idx)
child_request = request if idx == n - 1 else copy(request)
child_request = request if idx == params.n - 1 else copy(request)
child_request.request_id = request_id
child_request.sampling_params = params
await self._add_request(child_request, parent_request, idx, queue)
......@@ -207,7 +209,7 @@ class AsyncLLM(EngineClient):
async def _add_request(self, request: EngineCoreRequest,
parent_req: Optional[ParentRequest], index: int,
queue: asyncio.Queue[RequestOutput]):
queue: RequestOutputCollector):
# Add the request to OutputProcessor (this process).
self.output_processor.add_request(request, parent_req, index, queue)
......@@ -272,15 +274,7 @@ class AsyncLLM(EngineClient):
while not finished:
# Note: drain queue without await if possible (avoids
# task switching under load which helps performance).
out = q.get_nowait() if not q.empty() else await q.get()
# Coalesce any additional queued outputs
while not q.empty():
next_out = q.get_nowait()
if sampling_params.output_kind == RequestOutputKind.DELTA:
out.add(next_out)
else:
out = next_out
out = q.get_nowait() or await q.get()
# Note: both OutputProcessor and EngineCore handle their
# own request cleanup based on finished.
......
......@@ -115,7 +115,6 @@ class LogprobsProcessor:
num_prompt_tokens, num_logprobs = logprobs.shape
# Pythonize the torch tensors.
# TODO(rob): experiment with doing this in EngineCore?
prompt_token_ranks = ranks.tolist()
prompt_logprobs = logprobs.tolist()
token_ids = token_ids.tolist()
......
......@@ -17,6 +17,46 @@ from vllm.v1.metrics.stats import (IterationStats, LoRARequestStates,
RequestStateStats)
class RequestOutputCollector:
"""
Collects streamed RequestOutputs per individual request,
for hand-off to the consuming asyncio generate task.
When streaming deltas, RequestOutputs are merged if the
producer gets ahead of the consumer.
"""
def __init__(self, output_kind: RequestOutputKind):
self.aggregate = output_kind == RequestOutputKind.DELTA
self.output: Optional[RequestOutput] = None
self.ready = asyncio.Event()
def put(self, output: RequestOutput) -> None:
if self.output is None:
self.output = output
self.ready.set()
elif self.aggregate:
# Coalesce the outputs in delta case.
self.output.add(output)
else:
# Just replace latest in non-delta case.
self.output = output
async def get(self) -> RequestOutput:
while (output := self.output) is None:
await self.ready.wait()
self.output = None
self.ready.clear()
return output
def get_nowait(self) -> Optional[RequestOutput]:
output = self.output
if output is not None:
self.output = None
self.ready.clear()
return output
@dataclass
class OutputProcessorOutput:
......@@ -39,7 +79,7 @@ class RequestState:
detokenizer: IncrementalDetokenizer,
max_tokens_param: Optional[int],
arrival_time: float,
queue: Optional[asyncio.Queue[RequestOutput]],
queue: Optional[RequestOutputCollector],
log_stats: bool,
):
self.request_id = request_id
......@@ -66,7 +106,7 @@ class RequestState:
request: EngineCoreRequest,
parent_req: Optional[ParentRequest],
request_index: int,
queue: Optional[asyncio.Queue[RequestOutput]],
queue: Optional[RequestOutputCollector],
log_stats: bool,
) -> "RequestState":
if not request.sampling_params.detokenize:
......@@ -105,9 +145,7 @@ class RequestState:
finished = finish_reason is not None
final_only = self.output_kind == RequestOutputKind.FINAL_ONLY
# In follow up, we will switch to invariant where EngineCore
# does not stream partial prefills.
if not finished and (self.is_prefilling or final_only):
if not finished and final_only:
# Only the final output is required in FINAL_ONLY mode.
return None
......@@ -219,7 +257,7 @@ class OutputProcessor:
request: EngineCoreRequest,
parent_req: Optional[ParentRequest] = None,
request_index: int = 0,
queue: Optional[asyncio.Queue[RequestOutput]] = None,
queue: Optional[RequestOutputCollector] = None,
) -> None:
request_id = request.request_id
if request_id in self.request_states:
......@@ -285,19 +323,7 @@ class OutputProcessor:
finish_reason = engine_core_output.finish_reason
stop_reason = engine_core_output.stop_reason
# TODO(andy): prompt logprobs + chunked prefill can
# result in engine core returning an output for a
# partial prefill (in order to send back partial
# prompt logprobs.) This breaks the invariant that
# process_outputs is only operating on engine core
# outputs associated with non-partial completions.
# Currently this is handled by having `is_prefilling`
# check for new decoded tokens, indicating that
# the completion is not partial.
#
# Follow up will aggregate partial prompt logprobs
# in the EngineCore.
req_state.is_prefilling = not new_token_ids
req_state.is_prefilling = False
# 2) Detokenize the token ids into text and perform stop checks.
stop_string = req_state.detokenizer.update(
......@@ -306,8 +332,7 @@ class OutputProcessor:
finish_reason = FinishReason.STOP
stop_reason = stop_string
# 3) Compute sample and prompt logprobs for request,
# if required.
# 3) Compute sample and prompt logprobs for request, if required.
req_state.logprobs_processor.update_from_output(engine_core_output)
# 4) Create and handle RequestOutput objects.
......@@ -315,7 +340,7 @@ class OutputProcessor:
new_token_ids, finish_reason, stop_reason):
if req_state.queue is not None:
# AsyncLLM: put into queue for handling by generate().
req_state.queue.put_nowait(request_output)
req_state.queue.put(request_output)
else:
# LLMEngine: return list of RequestOutputs.
request_outputs.append(request_output)
......
......@@ -4,7 +4,6 @@ import time
from collections.abc import Mapping
from typing import Optional, Union
import vllm.platforms
from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
PromptType, SingletonInputsAdapter)
......@@ -20,7 +19,10 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.structured_output.utils import validate_structured_output_request
from vllm.v1.structured_output.backend_guidance import (
validate_guidance_grammar)
from vllm.v1.structured_output.utils import (
validate_structured_output_request_xgrammar)
class Processor:
......@@ -120,7 +122,9 @@ class Processor:
if not params.guided_decoding or not self.decoding_config:
return
supported_backends = ["xgrammar", "xgrammar:disable-any-whitespace"]
supported_backends = [
"xgrammar", "xgrammar:disable-any-whitespace", "guidance", "auto"
]
engine_level_backend = self.decoding_config.guided_decoding_backend
if engine_level_backend not in supported_backends:
raise ValueError(f"Only {supported_backends} structured output is "
......@@ -134,10 +138,31 @@ class Processor:
else:
params.guided_decoding.backend = engine_level_backend
if vllm.platforms.current_platform.is_tpu():
raise ValueError("Structured output is not supported on TPU.")
validate_structured_output_request(params)
# Request content validation
if engine_level_backend == "xgrammar":
# xgrammar with no fallback
validate_structured_output_request_xgrammar(params)
params.guided_decoding.backend = "xgrammar"
elif engine_level_backend == "auto":
# "auto" is an opt-in to opinionated behavior where we try to
# choose a backend based on request contents. This is not the
# default as it is less predictable and subject to change
# between releases as feature support changes.
try:
validate_structured_output_request_xgrammar(params)
params.guided_decoding.backend = "xgrammar"
except ValueError:
# The request includes some jsonschema feature(s) that
# are not supported in xgrammar. Fall back to guidance.
params.guided_decoding.backend = "guidance"
if params.guided_decoding.backend == "guidance":
# TODO ideally we would have the LLTokenizer here as Lark syntax
# allows <|special_token|> and similar, see
# https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
# Without tokenizer these are disallowed in grammars.
validate_guidance_grammar(params, tokenizer=None)
def process_inputs(
self,
......
......@@ -100,15 +100,8 @@ class IterationStats:
num_new_generation_tokens = len(output.new_token_ids)
self.num_generation_tokens += num_new_generation_tokens
if is_prefilling and num_new_generation_tokens > 0:
# TODO(andy): we used to assert that num_new_generation_tokens
# > 0 with an invariant that EngineCore does not stream outputs
# for partially completed prefills (scheduler.update_from_output
# makes EngineCoreOutput iff num_computed_tokens == num_tokens).
# When prompt logprobs are enabled, we currently stream out the
# partially completed prompt.
# This will be reverted in a follow up PR and we should re-enable
# this assertion / invariant.
if is_prefilling:
assert num_new_generation_tokens > 0
self.num_prompt_tokens += prompt_len
first_token_latency = self._time_since(req_stats.arrival_time)
......@@ -123,16 +116,12 @@ class IterationStats:
# Process the batch-level "new tokens" engine core event
if is_prefilling:
# TODO: re-enable no-output-for-partial-prefills invariant as above
if num_new_generation_tokens > 0:
req_stats.first_token_ts = engine_core_timestamp
req_stats.first_token_ts = engine_core_timestamp
else:
tpot = engine_core_timestamp - req_stats.last_token_ts
self.time_per_output_tokens_iter.append(tpot)
# TODO: re-enable no-output-for-partial-prefills invariant as above
if num_new_generation_tokens > 0:
req_stats.last_token_ts = engine_core_timestamp
req_stats.last_token_ts = engine_core_timestamp
def update_from_events(self, req_id: str, events: list["EngineCoreEvent"],
is_prefilling: bool, req_stats: RequestStateStats,
......
......@@ -39,6 +39,25 @@ class LogprobsTensors(NamedTuple):
self.selected_token_ranks.tolist(),
)
@staticmethod
def empty_cpu(num_positions: int,
num_tokens_per_position: int) -> "LogprobsTensors":
"""Create empty LogprobsTensors on CPU."""
logprob_token_ids = torch.empty(
(num_positions, num_tokens_per_position),
dtype=torch.int32,
device="cpu")
logprobs = torch.empty_like(logprob_token_ids, dtype=torch.float32)
selected_token_ranks = torch.empty(num_positions,
dtype=torch.int32,
device="cpu")
return LogprobsTensors(
logprob_token_ids=logprob_token_ids,
logprobs=logprobs,
selected_token_ranks=selected_token_ranks,
)
@dataclass
class SamplerOutput:
......
# SPDX-License-Identifier: Apache-2.0
from typing import Union
import torch
def compiled_softmax(
logits: torch.Tensor,
temperature: Union[float, torch.Tensor] = 1.0,
) -> torch.Tensor:
"""Faster softmax kernel generated by torch.compile.
Args:
logits: [n, vocab_size]
temperature: [n] or float
"""
# NOTE(woosuk): Avoid recompilation by marking the first dim as dynamic.
torch._dynamo.mark_dynamic(logits, index=0)
if isinstance(temperature, torch.Tensor):
torch._dynamo.mark_dynamic(temperature, index=0)
return _softmax(logits, temperature)
@torch.compile
def _softmax(
logits: torch.Tensor,
temperature: Union[float, torch.Tensor],
) -> torch.Tensor:
logits = logits / temperature
return torch.softmax(logits, dim=-1, dtype=torch.float32)
......@@ -8,7 +8,7 @@ import triton.language as tl
from vllm.logger import init_logger
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.utils import compiled_softmax
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
logger = init_logger(__name__)
......@@ -67,6 +67,7 @@ class RejectionSampler(nn.Module):
Shape is [num_tokens, vocab_size]. Here, probabilities from
different requests are flattened into a single tensor because
this is the shape of the output logits.
NOTE: `target_logits` can be updated in place to save memory.
bonus_token_ids_tensor (torch.Tensor):
A tensor containing bonus tokens. Shape is [batch_size, 1].
Bonus tokens are added to the end of the sequence if all
......@@ -83,6 +84,8 @@ class RejectionSampler(nn.Module):
'''
assert metadata.max_spec_len <= MAX_SPEC_LEN
# [num_tokens, vocab_size]
# NOTE(woosuk): `target_logits` can be updated in place inside the
# `compute_probs` function.
target_probs = compute_probs(
target_logits,
metadata.cu_num_draft_tokens,
......@@ -245,25 +248,80 @@ def compute_probs(
return logits
num_tokens = logits.shape[0]
batch_size = cu_num_draft_tokens.shape[0]
expanded_temperature = torch.empty(
(num_tokens, 1),
dtype=torch.float32,
device=logits.device,
)
expand_kernel[(batch_size, )](
expanded_temperature,
temperature = expand_batch_to_tokens(
sampling_metadata.temperature,
cu_num_draft_tokens,
GREEDY_TEMPERATURE, # replace_from
1, # replace_to
MAX_NUM_TOKENS=MAX_SPEC_LEN,
num_warps=1,
num_tokens,
replace_from=GREEDY_TEMPERATURE,
replace_to=1,
)
output_prob = compiled_softmax(logits, expanded_temperature)
# NOTE(woosuk): Update `logits` in place to avoid allocating a new tensor.
logits.div_(temperature.unsqueeze(-1))
# Get expanded top_k and top_p tensors.
top_k = None
if sampling_metadata.top_k is not None:
top_k = expand_batch_to_tokens(
sampling_metadata.top_k,
cu_num_draft_tokens,
num_tokens,
)
top_p = None
if sampling_metadata.top_p is not None:
top_p = expand_batch_to_tokens(
sampling_metadata.top_p,
cu_num_draft_tokens,
num_tokens,
)
# NOTE(woosuk): `apply_top_k_top_p` uses sorting to calculate the mask,
# which is slow for large vocab sizes. This may cause performance issues.
logits = apply_top_k_top_p(logits, top_k, top_p)
output_prob = logits.softmax(dim=-1, dtype=torch.float32)
return output_prob
def expand_batch_to_tokens(
x: torch.Tensor, # [batch_size]
cu_num_tokens: torch.Tensor, # [batch_size]
num_tokens: int,
replace_from: int = 0,
replace_to: int = 0,
) -> torch.Tensor:
"""Expand [batch_size] tensor to [num_tokens] tensor based on the number of
tokens per batch in cu_num_tokens.
For example, if x = [a, b, c] and cu_num_tokens = [2, 5, 6], then
num_tokens = 6, and expanded_x = [a, a, b, b, b, c].
Args:
x: [batch_size] tensor to expand.
cu_num_tokens: [batch_size] tensor containing the cumulative number of
tokens per batch. Each element represents the total number of
tokens up to and including that batch.
num_tokens: Total number of tokens.
replace_from: int = 0
Value to be replaced if it is found in x.
replace_to: int = 0
Value to replace with when replace_from is found.
Returns:
expanded_x: [num_tokens] tensor.
"""
batch_size = x.shape[0]
assert cu_num_tokens.shape[0] == batch_size
expanded_x = x.new_empty(num_tokens)
expand_kernel[(batch_size, )](
expanded_x,
x,
cu_num_tokens,
replace_from,
replace_to,
MAX_NUM_TOKENS=MAX_SPEC_LEN, # To avoid recompilation.
num_warps=1,
)
return expanded_x
def generate_uniform_probs(
num_tokens: int,
num_draft_tokens: list[int],
......
......@@ -137,7 +137,7 @@ class Sampler(nn.Module):
Gather logprobs for topk and sampled/prompt token.
Args:
logits: (num tokens) x (vocab) tensor
logprobs: (num tokens) x (vocab) tensor
num_logprobs: minimum number of logprobs to
retain per token
token_ids: prompt tokens (if prompt logprobs)
......
......@@ -3,10 +3,7 @@ from vllm.v1.worker.gpu_input_batch import InputBatch
def is_spec_decode_supported(req_id: str, input_batch: InputBatch) -> bool:
if req_id in input_batch.top_k_reqs or req_id in input_batch.top_p_reqs:
# Spec decode doesn't support top_p/top_k sampling.
return False
elif req_id in input_batch.min_p_reqs:
if req_id in input_batch.min_p_reqs:
# Spec decode doesn't support min_p sampling.
return False
elif (req_id in input_batch.frequency_penalties_reqs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment