Commit 96ae75ad authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.6.6.post1' into v0.6.6.post1-dev

parents f9f4a735 2339d59f
from typing import List, Set, Tuple
import torch
from vllm.model_executor.layers.utils import apply_penalties
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
def apply_min_token_penalties(logits: torch.Tensor,
output_token_ids: List[List[int]],
stop_token_ids: List[Set[int]],
min_tokens: List[int]) -> None:
"""
Applies minimum token penalty by setting the logits of the stop tokens
to -inf.
"""
min_tokens_logits_to_penalize: List[Tuple[int, int]] = []
for index, min_token in enumerate(min_tokens):
if len(output_token_ids[index]) < min_token:
for stop_token_id in stop_token_ids[index]:
min_tokens_logits_to_penalize.append((index, stop_token_id))
if min_tokens_logits_to_penalize:
logits[tuple(zip(*min_tokens_logits_to_penalize))] = -float("inf")
def apply_all_penalties(
logits: torch.Tensor,
prompt_token_ids: torch.Tensor,
presence_penalties: torch.Tensor,
frequency_penalties: torch.Tensor,
repetition_penalties: torch.Tensor,
output_token_ids: List[List[int]],
) -> torch.Tensor:
"""
Applies presence, frequency and repetition penalties to the logits.
"""
_, vocab_size = logits.shape
output_tokens_t = _convert_to_tensors(output_token_ids, vocab_size,
logits.device)
return apply_penalties(logits, prompt_token_ids, output_tokens_t,
presence_penalties, frequency_penalties,
repetition_penalties)
def _convert_to_tensors(output_token_ids: List[List[int]], vocab_size: int,
device: torch.device) -> torch.Tensor:
"""
Convert the different list data structures to tensors.
"""
output_tokens_tensor = make_tensor_with_pad(
output_token_ids,
# Use the value of vocab_size as a pad since we don't have a
# token_id of this value.
pad=vocab_size,
device="cpu",
dtype=torch.int64,
pin_memory=is_pin_memory_available(),
)
return output_tokens_tensor.to(device, non_blocking=True)
from typing import Dict
import torch
import torch.nn as nn
from vllm import envs
from vllm.logger import init_logger
from vllm.platforms import current_platform
logger = init_logger(__name__)
try:
import flashinfer.sampling
is_flashinfer_available = True
except ImportError:
is_flashinfer_available = False
class TopKTopPSampler(nn.Module):
def __init__(self):
super().__init__()
if current_platform.is_cuda:
if is_flashinfer_available:
if envs.VLLM_USE_FLASHINFER_SAMPLER is not False:
# NOTE(woosuk): The V0 sampler doesn't use FlashInfer for
# sampling unless VLLM_USE_FLASHINFER_SAMPLER=1 (i.e., by
# default it is unused). For backward compatibility, we set
# `VLLM_USE_FLASHINFER_SAMPLER` as None by default and
# interpret it differently in V0 and V1 samplers: In V0,
# None means False, while in V1, None means True. This is
# why we use the condition
# `envs.VLLM_USE_FLASHINFER_SAMPLER is not False` here.
logger.info("Using FlashInfer for top-p & top-k sampling.")
self.forward = self.forward_cuda
else:
logger.warning(
"FlashInfer is available, but it is not enabled. "
"Falling back to the PyTorch-native implementation of "
"top-p & top-k sampling. For the best performance, "
"please set VLLM_USE_FLASHINFER_SAMPLER=1.")
self.forward = self.forward_native
else:
logger.warning(
"FlashInfer is not available. Falling back to the PyTorch-"
"native implementation of top-p & top-k sampling. For the "
"best performance, please install FalshInfer.")
self.forward = self.forward_native
else:
self.forward = self.forward_native
def forward_native(
self,
logits: torch.Tensor,
generators: Dict[int, torch.Generator],
no_top_k: bool,
k: torch.Tensor,
no_top_p: bool,
p: torch.Tensor,
) -> torch.Tensor:
"""PyTorch-native implementation of top-k and top-p sampling."""
logits = apply_top_k_top_p(logits, no_top_k, k, no_top_p, p)
probs = logits.softmax(dim=-1, dtype=torch.float32)
return random_sample(probs, generators)
def forward_cuda(
self,
logits: torch.Tensor,
generators: Dict[int, torch.Generator],
no_top_k: bool,
k: torch.Tensor,
no_top_p: bool,
p: torch.Tensor,
) -> torch.Tensor:
"""More optimized implementation for top-k and top-p sampling."""
probs = logits.softmax(dim=-1, dtype=torch.float32)
if no_top_k and no_top_p:
# We prefer `random_sample` over `flashinfer_sample` when sorting is
# not needed. This is because `random_sample` does not require
# CPU-GPU synchronization while `flashinfer_sample` does.
return random_sample(probs, generators)
return flashinfer_sample(probs, no_top_k, k, no_top_p, p, generators)
def apply_top_k_top_p(
logits: torch.Tensor,
no_top_k: bool,
k: torch.Tensor,
no_top_p: bool,
p: torch.Tensor,
) -> torch.Tensor:
"""Apply top-k and top-p masks to the logits.
This function sorts the logits tensor, which can be slow for large batches.
"""
if no_top_k and no_top_p:
return logits
logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
if not no_top_k:
# Apply top-k.
top_k_mask = logits_sort.size(1) - k.to(torch.long)
# Get all the top_k values.
top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
top_k_mask = logits_sort < top_k_mask
logits_sort.masked_fill_(top_k_mask, -float("inf"))
if not no_top_p:
# Apply top-p.
probs_sort = logits_sort.softmax(dim=-1)
probs_sum = probs_sort.cumsum(dim=-1)
top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
# at least one
top_p_mask[:, -1] = False
logits_sort.masked_fill_(top_p_mask, -float("inf"))
# Re-sort the probabilities.
logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
return logits
def random_sample(
probs: torch.Tensor,
generators: Dict[int, torch.Generator],
) -> torch.Tensor:
"""Randomly sample from the probabilities.
We use this function instead of torch.multinomial because torch.multinomial
causes CPU-GPU synchronization.
"""
q = torch.empty_like(probs)
# NOTE(woosuk): To batch-process the requests without their own seeds,
# which is the common case, we first assume that every request does
# not have its own seed. Then, we overwrite the values for the requests
# that have their own seeds.
if len(generators) != probs.shape[0]:
q.exponential_()
if generators:
# TODO(woosuk): This can be slow because we handle each request
# one by one. Optimize this.
for i, generator in generators.items():
q[i].exponential_(generator=generator)
return probs.div_(q).argmax(dim=-1).view(-1)
def flashinfer_sample(
probs: torch.Tensor,
no_top_k: bool,
k: torch.Tensor,
no_top_p: bool,
p: torch.Tensor,
generators: Dict[int, torch.Generator],
) -> torch.Tensor:
"""Sample from the probabilities using FlashInfer.
Statistically, this function is equivalent to the `random_sample` function.
However, this function is faster because it avoids sorting the logits tensor
via rejection sampling.
NOTE: The outputs of this function do not necessarily match the outputs of
the `random_sample` function. It only guarantees that the outputs are
statistically equivalent.
NOTE: This function includes CPU-GPU synchronization, while `random_sample`
does not. Call this function at the end of the forward pass to minimize
the synchronization overhead.
"""
assert not (no_top_k and no_top_p)
max_top_k_round = 32
batch_size = probs.shape[0]
uniform_samples = torch.empty((max_top_k_round, batch_size),
device=probs.device)
if len(generators) != batch_size:
uniform_samples.uniform_()
if generators:
for i, generator in generators.items():
uniform_samples[:, i].uniform_(generator=generator)
if no_top_k:
# Top-p only.
next_token_ids, success = flashinfer.sampling.top_p_sampling_from_probs(
probs, uniform_samples, p, deterministic=True)
elif no_top_p:
# Top-k only.
next_token_ids, success = flashinfer.sampling.top_k_sampling_from_probs(
probs, uniform_samples, k, deterministic=True)
else:
# Both top-k and top-p.
next_token_ids, success = (
flashinfer.sampling.top_k_top_p_sampling_from_probs(
probs, uniform_samples, k, p, deterministic=True))
# NOTE: CPU-GPU synchronization happens here.
if not success.all():
if not no_top_k:
probs = flashinfer.sampling.top_k_renorm_prob(probs, k)
if not no_top_p:
probs = flashinfer.sampling.top_p_renorm_prob(probs, p)
next_token_ids = flashinfer.sampling.sampling_from_probs(
probs, uniform_samples[0], deterministic=True)
return next_token_ids.view(-1)
"""A layer that samples the next tokens from the model's outputs.""" """A layer that samples the next tokens from the model's outputs."""
from typing import Dict from typing import Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.v1.outputs import SamplerOutput from vllm.v1.outputs import SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.penalties import (apply_all_penalties,
apply_min_token_penalties)
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
_SAMPLING_EPS = 1e-5 _SAMPLING_EPS = 1e-5
class Sampler(nn.Module): class Sampler(nn.Module):
def __init__(self):
super().__init__()
self.topk_topp_sampler = TopKTopPSampler()
def forward( def forward(
self, self,
logits: torch.Tensor, logits: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> SamplerOutput: ) -> SamplerOutput:
logits = self.apply_temperature(logits, sampling_metadata.temperature) needs_logprobs = sampling_metadata.max_num_logprobs > 0
logits = self.apply_top_k_top_p(logits, sampling_metadata) if needs_logprobs:
# NOTE(woosuk): Use the original logits (before any penalties or
probs = self.get_probs(logits) # temperature scaling) for the top-k logprobs.
sampled = self.sample(probs, sampling_metadata) # This is different from the V0 sampler, which uses the logits that
# Use int32 to reduce the tensor size. # is used for sampling (after penalties and temperature scaling).
sampled = sampled.to(torch.int32) # NOTE: We compute logprobs first because the below ops may
# modify the logits tensor in-place (and we don't want to clone
if sampling_metadata.max_num_logprobs > 0: # the logits tensor for memory efficiency).
logprobs = self.get_logprobs(logits) topk_logprobs, topk_indices = self.get_topk_logprobs(
# FIXME: Mask the sampled token_id, get topk logprobs, logits, sampling_metadata)
# and concatenate the topk with the sampled token_id.
topk_logprobs, topk_indices = torch.topk(
logprobs, sampling_metadata.max_num_logprobs, dim=-1)
# Use int32 to reduce the tensor size.
topk_indices = topk_indices.to(torch.int32)
else: else:
topk_logprobs = None topk_logprobs = None
topk_indices = None topk_indices = None
# Use float32 for the logits.
logits = logits.to(torch.float32)
# Apply penalties (e.g., min_tokens, freq_penalties).
logits = self.apply_penalties(logits, sampling_metadata)
# Apply temperature.
logits = self.apply_temperature(logits, sampling_metadata.temperature)
# Sample the next token.
sampled = self.sample(logits, sampling_metadata)
# Use int32 to reduce the tensor size.
sampled = sampled.to(torch.int32)
# NOTE: CPU-GPU synchronization happens here. # NOTE: CPU-GPU synchronization happens here.
sampler_output = SamplerOutput( sampler_output = SamplerOutput(
sampled_token_ids=sampled.tolist(), sampled_token_ids=sampled.tolist(),
...@@ -52,71 +65,37 @@ class Sampler(nn.Module): ...@@ -52,71 +65,37 @@ class Sampler(nn.Module):
logits: torch.Tensor, logits: torch.Tensor,
temp: torch.Tensor, temp: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
# Use float32 to apply temperature scaling.
logits = logits.to(torch.float32)
# Avoid division by zero. # Avoid division by zero.
temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp) temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
# Use in-place division to avoid creating a new tensor. # Use in-place division to avoid creating a new tensor.
logits.div_(temp.unsqueeze(dim=1)) logits.div_(temp.unsqueeze(dim=1))
return logits return logits
def apply_top_k_top_p( def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
return logits.argmax(dim=-1).view(-1)
def sample(
self, self,
logits: torch.Tensor, logits: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
return _apply_top_k_top_p( assert not (sampling_metadata.all_greedy
and sampling_metadata.all_random)
if sampling_metadata.all_greedy:
return self.greedy_sample(logits)
random_sampled = self.topk_topp_sampler(
logits, logits,
sampling_metadata.generators,
sampling_metadata.no_top_k, sampling_metadata.no_top_k,
sampling_metadata.top_k, sampling_metadata.top_k,
sampling_metadata.no_top_p, sampling_metadata.no_top_p,
sampling_metadata.top_p, sampling_metadata.top_p,
) )
def get_probs(self, logits: torch.Tensor) -> torch.Tensor:
return torch.softmax(logits, dim=-1, dtype=torch.float32)
def get_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
return torch.log_softmax(logits, dim=-1, dtype=torch.float32)
def greedy_sample(self, probs: torch.Tensor) -> torch.Tensor:
return probs.argmax(dim=-1).view(-1)
def random_sample(
self,
probs: torch.Tensor,
generators: Dict[int, torch.Generator],
) -> torch.Tensor:
q = torch.empty_like(probs)
# NOTE(woosuk): To batch-process the requests without their own seeds,
# which is the common case, we first assume that every request does
# not have its own seed. Then, we overwrite the values for the requests
# that have their own seeds.
if len(generators) != probs.shape[0]:
# This might still be done here unnecessarily if there are greedies
q.exponential_()
if generators:
# TODO(woosuk): This can be slow because we handle each request
# one by one. Optimize this.
for i, generator in generators.items():
q[i].exponential_(generator=generator)
return probs.div_(q).argmax(dim=-1).view(-1)
def sample(
self,
probs: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
assert not (sampling_metadata.all_greedy
and sampling_metadata.all_random)
if sampling_metadata.all_greedy:
return self.greedy_sample(probs)
if sampling_metadata.all_random: if sampling_metadata.all_random:
return self.random_sample(probs, sampling_metadata.generators) return random_sampled
greedy_sampled = self.greedy_sample(probs) greedy_sampled = self.greedy_sample(logits)
random_sampled = self.random_sample(probs,
sampling_metadata.generators)
sampled = torch.where( sampled = torch.where(
sampling_metadata.temperature < _SAMPLING_EPS, sampling_metadata.temperature < _SAMPLING_EPS,
greedy_sampled, greedy_sampled,
...@@ -124,36 +103,34 @@ class Sampler(nn.Module): ...@@ -124,36 +103,34 @@ class Sampler(nn.Module):
) )
return sampled return sampled
def get_topk_logprobs(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Tuple[torch.Tensor, torch.Tensor]:
logprobs = logits.log_softmax(dim=-1, dtype=torch.float32)
# FIXME: Mask the sampled token_id, get topk logprobs,
# and concatenate the topk with the sampled token_id.
topk_logprobs, topk_indices = torch.topk(
logprobs, sampling_metadata.max_num_logprobs, dim=-1)
# Use int32 to reduce the tensor size.
topk_indices = topk_indices.to(torch.int32)
return topk_logprobs, topk_indices
# TODO(woosuk): Optimize this with a custom kernel. def apply_penalties(
def _apply_top_k_top_p( self,
logits: torch.Tensor, logits: torch.Tensor,
no_top_k: bool, sampling_metadata: SamplingMetadata,
k: torch.Tensor, ) -> torch.Tensor:
no_top_p: bool, apply_min_token_penalties(logits, sampling_metadata.output_token_ids,
p: torch.Tensor, sampling_metadata.stop_token_ids,
) -> torch.Tensor: sampling_metadata.min_tokens)
if no_top_k and no_top_p: if not sampling_metadata.no_penalties:
assert sampling_metadata.prompt_token_ids is not None
logits = apply_all_penalties(
logits, sampling_metadata.prompt_token_ids,
sampling_metadata.presence_penalties,
sampling_metadata.frequency_penalties,
sampling_metadata.repetition_penalties,
sampling_metadata.output_token_ids)
return logits return logits
logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
if not no_top_k:
# Apply top-k.
top_k_mask = logits_sort.size(1) - k.to(torch.long)
# Get all the top_k values.
top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
top_k_mask = logits_sort < top_k_mask
logits_sort.masked_fill_(top_k_mask, -float("inf"))
if not no_top_p:
# Apply top-p.
probs_sort = logits_sort.softmax(dim=-1)
probs_sum = probs_sort.cumsum(dim=-1)
top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
# at least one
top_p_mask[:, -1] = False
logits_sort.masked_fill_(top_p_mask, -float("inf"))
# Re-sort the probabilities.
logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
return logits
from collections import OrderedDict
from collections.abc import Sequence from collections.abc import Sequence
from contextlib import contextmanager from contextlib import contextmanager
from typing import (Any, Generic, Iterator, List, Optional, TypeVar, Union, from typing import (Any, Generic, Iterator, List, Optional, TypeVar, Union,
...@@ -102,27 +101,3 @@ def make_zmq_socket( ...@@ -102,27 +101,3 @@ def make_zmq_socket(
finally: finally:
ctx.destroy(linger=0) ctx.destroy(linger=0)
K = TypeVar('K')
V = TypeVar('V')
class LRUDictCache(Generic[K, V]):
def __init__(self, size: int):
self.cache: OrderedDict[K, V] = OrderedDict()
self.size = size
def get(self, key: K, default=None) -> V:
if key not in self.cache:
return default
self.cache.move_to_end(key)
return self.cache[key]
def put(self, key: K, value: V):
self.cache[key] = value
self.cache.move_to_end(key)
if len(self.cache) > self.size:
self.cache.popitem(last=False)
...@@ -43,12 +43,14 @@ class InputBatch: ...@@ -43,12 +43,14 @@ class InputBatch:
max_num_blocks_per_req: int, max_num_blocks_per_req: int,
device: torch.device, device: torch.device,
pin_memory: bool, pin_memory: bool,
vocab_size: int,
): ):
self.max_num_reqs = max_num_reqs self.max_num_reqs = max_num_reqs
self.max_model_len = max_model_len self.max_model_len = max_model_len
self.max_num_blocks_per_req = max_num_blocks_per_req self.max_num_blocks_per_req = max_num_blocks_per_req
self.device = device self.device = device
self.pin_memory = pin_memory self.pin_memory = pin_memory
self.vocab_size = vocab_size
self.req_ids: List[Optional[str]] = [None] * max_num_reqs self.req_ids: List[Optional[str]] = [None] * max_num_reqs
self.req_id_to_index: Dict[str, int] = {} self.req_id_to_index: Dict[str, int] = {}
...@@ -63,6 +65,7 @@ class InputBatch: ...@@ -63,6 +65,7 @@ class InputBatch:
) )
self.token_ids_cpu = self.token_ids_cpu_tensor.numpy() self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32) self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32)
self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
# Attention-related. # Attention-related.
self.block_table = torch.zeros( self.block_table = torch.zeros(
...@@ -110,6 +113,50 @@ class InputBatch: ...@@ -110,6 +113,50 @@ class InputBatch:
self.top_k_cpu = self.top_k_cpu_tensor.numpy() self.top_k_cpu = self.top_k_cpu_tensor.numpy()
self.top_k_reqs: Set[str] = set() self.top_k_reqs: Set[str] = set()
# Frequency penalty related data structures
self.frequency_penalties = torch.empty((max_num_reqs, ),
dtype=torch.float,
device=device)
self.frequency_penalties_cpu_tensor = torch.empty(
(max_num_reqs, ),
dtype=torch.float,
device="cpu",
pin_memory=pin_memory)
self.frequency_penalties_cpu = \
self.frequency_penalties_cpu_tensor.numpy()
self.frequency_penalties_reqs: Set[str] = set()
# Presence penalty related data structures
self.presence_penalties = torch.empty((max_num_reqs, ),
dtype=torch.float,
device=device)
self.presence_penalties_cpu_tensor = torch.empty((max_num_reqs, ),
dtype=torch.float,
device="cpu",
pin_memory=pin_memory)
self.presence_penalties_cpu = \
self.presence_penalties_cpu_tensor.numpy()
self.presence_penalties_reqs: Set[str] = set()
# Repetition penalty related data structures
self.repetition_penalties = torch.empty((max_num_reqs, ),
dtype=torch.float,
device=device)
self.repetition_penalties_cpu_tensor = torch.empty(
(max_num_reqs, ),
dtype=torch.float,
device="cpu",
pin_memory=pin_memory)
self.repetition_penalties_cpu = \
self.repetition_penalties_cpu_tensor.numpy()
self.repetition_penalties_reqs: Set[str] = set()
self.min_tokens: List[int] = [0] * max_num_reqs
self.stop_token_ids: List[Set[int]] = [
set() for _ in range(max_num_reqs)
]
self.prompt_token_ids: Optional[torch.Tensor] = None
# req_index -> generator # req_index -> generator
# NOTE(woosuk): The indices of the requests that do not have their own # NOTE(woosuk): The indices of the requests that do not have their own
# generator should not be included in the dictionary. # generator should not be included in the dictionary.
...@@ -133,6 +180,7 @@ class InputBatch: ...@@ -133,6 +180,7 @@ class InputBatch:
# Copy the prompt token ids and output token ids. # Copy the prompt token ids and output token ids.
num_prompt_tokens = len(request.prompt_token_ids) num_prompt_tokens = len(request.prompt_token_ids)
self.num_prompt_tokens[req_index] = num_prompt_tokens
self.token_ids_cpu[ self.token_ids_cpu[
req_index, :num_prompt_tokens] = request.prompt_token_ids req_index, :num_prompt_tokens] = request.prompt_token_ids
start_idx = num_prompt_tokens start_idx = num_prompt_tokens
...@@ -157,6 +205,20 @@ class InputBatch: ...@@ -157,6 +205,20 @@ class InputBatch:
self.top_k_cpu[req_index] = sampling_params.top_k self.top_k_cpu[req_index] = sampling_params.top_k
if sampling_params.top_k > 0: if sampling_params.top_k > 0:
self.top_k_reqs.add(req_id) self.top_k_reqs.add(req_id)
self.frequency_penalties_cpu[req_index] = \
sampling_params.frequency_penalty
if sampling_params.frequency_penalty != 0.0:
self.frequency_penalties_reqs.add(req_id)
self.presence_penalties_cpu[req_index] = \
sampling_params.presence_penalty
if sampling_params.presence_penalty != 0.0:
self.presence_penalties_reqs.add(req_id)
self.repetition_penalties_cpu[req_index] = \
sampling_params.repetition_penalty
if sampling_params.repetition_penalty != 1.0:
self.repetition_penalties_reqs.add(req_id)
self.min_tokens[req_index] = sampling_params.min_tokens
self.stop_token_ids[req_index] = sampling_params.all_stop_token_ids
# NOTE(woosuk): self.generators should not include the requests that # NOTE(woosuk): self.generators should not include the requests that
# do not have their own generator. # do not have their own generator.
...@@ -179,6 +241,9 @@ class InputBatch: ...@@ -179,6 +241,9 @@ class InputBatch:
self.random_reqs.discard(req_id) self.random_reqs.discard(req_id)
self.top_p_reqs.discard(req_id) self.top_p_reqs.discard(req_id)
self.top_k_reqs.discard(req_id) self.top_k_reqs.discard(req_id)
self.frequency_penalties_reqs.discard(req_id)
self.presence_penalties_reqs.discard(req_id)
self.repetition_penalties_reqs.discard(req_id)
self.generators.pop(req_index, None) self.generators.pop(req_index, None)
self.num_logprobs.pop(req_id, None) self.num_logprobs.pop(req_id, None)
self.prompt_logprob_reqs.discard(req_id) self.prompt_logprob_reqs.discard(req_id)
...@@ -191,6 +256,9 @@ class InputBatch: ...@@ -191,6 +256,9 @@ class InputBatch:
self.random_reqs.clear() self.random_reqs.clear()
self.top_p_reqs.clear() self.top_p_reqs.clear()
self.top_k_reqs.clear() self.top_k_reqs.clear()
self.frequency_penalties_reqs.clear()
self.presence_penalties_reqs.clear()
self.repetition_penalties_reqs.clear()
self.generators.clear() self.generators.clear()
self.num_logprobs.clear() self.num_logprobs.clear()
self.prompt_logprob_reqs.clear() self.prompt_logprob_reqs.clear()
...@@ -224,6 +292,8 @@ class InputBatch: ...@@ -224,6 +292,8 @@ class InputBatch:
# block_table_cpu. # block_table_cpu.
self.token_ids_cpu[empty_index] = self.token_ids_cpu[ self.token_ids_cpu[empty_index] = self.token_ids_cpu[
last_req_index] last_req_index]
self.num_prompt_tokens[empty_index] = \
self.num_prompt_tokens[last_req_index]
self.num_computed_tokens_cpu[ self.num_computed_tokens_cpu[
empty_index] = self.num_computed_tokens_cpu[last_req_index] empty_index] = self.num_computed_tokens_cpu[last_req_index]
self.block_table_cpu[empty_index] = self.block_table_cpu[ self.block_table_cpu[empty_index] = self.block_table_cpu[
...@@ -232,6 +302,15 @@ class InputBatch: ...@@ -232,6 +302,15 @@ class InputBatch:
last_req_index] last_req_index]
self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index] self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index] self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
self.frequency_penalties_cpu[empty_index] = \
self.frequency_penalties_cpu[last_req_index]
self.presence_penalties_cpu[empty_index] = \
self.presence_penalties_cpu[last_req_index]
self.repetition_penalties_cpu[empty_index] = \
self.repetition_penalties_cpu[last_req_index]
self.min_tokens[empty_index] = self.min_tokens[last_req_index]
self.stop_token_ids[empty_index] = \
self.stop_token_ids[last_req_index]
generator = self.generators.pop(last_req_index, None) generator = self.generators.pop(last_req_index, None)
if generator is not None: if generator is not None:
self.generators[empty_index] = generator self.generators[empty_index] = generator
...@@ -241,6 +320,7 @@ class InputBatch: ...@@ -241,6 +320,7 @@ class InputBatch:
def make_sampling_metadata( def make_sampling_metadata(
self, self,
req_id_output_token_ids: Dict[str, List[int]],
skip_copy: bool = False, skip_copy: bool = False,
) -> SamplingMetadata: ) -> SamplingMetadata:
if not skip_copy: if not skip_copy:
...@@ -250,6 +330,37 @@ class InputBatch: ...@@ -250,6 +330,37 @@ class InputBatch:
self.top_p_cpu_tensor[:self.num_reqs], non_blocking=True) self.top_p_cpu_tensor[:self.num_reqs], non_blocking=True)
self.top_k[:self.num_reqs].copy_( self.top_k[:self.num_reqs].copy_(
self.top_k_cpu_tensor[:self.num_reqs], non_blocking=True) self.top_k_cpu_tensor[:self.num_reqs], non_blocking=True)
if not self.no_penalties:
# Since syncing these tensors is expensive only copy them
# if necessary i.e. if there are requests which require
# penalties to be applied during sampling.
self.frequency_penalties[:self.num_reqs].copy_(
self.frequency_penalties_cpu_tensor[:self.num_reqs],
non_blocking=True)
self.presence_penalties[:self.num_reqs].copy_(
self.presence_penalties_cpu_tensor[:self.num_reqs],
non_blocking=True)
self.repetition_penalties[:self.num_reqs].copy_(
self.repetition_penalties_cpu_tensor[:self.num_reqs],
non_blocking=True)
# The prompt tokens are used only for applying penalties during
# the sampling process. Hence copy these tensors only when
# there are requests which need penalties to be applied.
self.prompt_token_ids = self._make_prompt_token_ids_tensor()
output_token_ids: List[List[int]] = []
for req_id in self.req_ids[:self.num_reqs]:
assert req_id is not None
# Currently we create a tensor for output_token_ids from scratch
# at each step. However, for the penalties computation what we
# need is stats about the token ids present in the output. This
# stats can be maintained incrementally instead of computing it
# from scratch at each step.
# TODO - Replace this with incremental update to output token
# statistics.
output_token_ids.append(req_id_output_token_ids[req_id])
return SamplingMetadata( return SamplingMetadata(
temperature=self.temperature[:self.num_reqs], temperature=self.temperature[:self.num_reqs],
all_greedy=self.all_greedy, all_greedy=self.all_greedy,
...@@ -260,8 +371,33 @@ class InputBatch: ...@@ -260,8 +371,33 @@ class InputBatch:
no_top_k=self.no_top_k, no_top_k=self.no_top_k,
generators=self.generators, generators=self.generators,
max_num_logprobs=self.max_num_logprobs, max_num_logprobs=self.max_num_logprobs,
prompt_token_ids=self.prompt_token_ids,
frequency_penalties=self.frequency_penalties[:self.num_reqs],
presence_penalties=self.presence_penalties[:self.num_reqs],
repetition_penalties=self.repetition_penalties[:self.num_reqs],
output_token_ids=output_token_ids,
min_tokens=self.min_tokens[:self.num_reqs],
stop_token_ids=self.stop_token_ids[:self.num_reqs],
no_penalties=self.no_penalties,
) )
def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()
prompt_token_ids_cpu_tensor = torch.empty(
(self.num_reqs, max_prompt_len),
device="cpu",
dtype=torch.int64,
pin_memory=self.pin_memory)
prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
prompt_token_ids[:] = (
self.token_ids_cpu[:self.num_reqs, :max_prompt_len])
# Use the value of vocab_size as a pad since we don't have a
# token_id of this value.
for i in range(self.num_reqs):
prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size
return prompt_token_ids_cpu_tensor.to(device=self.device,
non_blocking=True)
@property @property
def num_reqs(self) -> int: def num_reqs(self) -> int:
return len(self.req_id_to_index) return len(self.req_id_to_index)
...@@ -282,6 +418,12 @@ class InputBatch: ...@@ -282,6 +418,12 @@ class InputBatch:
def no_top_k(self) -> bool: def no_top_k(self) -> bool:
return len(self.top_k_reqs) == 0 return len(self.top_k_reqs) == 0
@property
def no_penalties(self) -> bool:
return (len(self.presence_penalties_reqs) == 0
and len(self.frequency_penalties_reqs) == 0
and len(self.repetition_penalties_reqs) == 0)
@property @property
def max_num_logprobs(self) -> int: def max_num_logprobs(self) -> int:
return max(self.num_logprobs.values()) if self.num_logprobs else 0 return max(self.num_logprobs.values()) if self.num_logprobs else 0
......
...@@ -19,7 +19,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, ...@@ -19,7 +19,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
LayerBlockType, cdiv, is_pin_memory_available) LayerBlockType, cdiv, is_pin_memory_available)
from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend, from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
FlashAttentionMetadata) FlashAttentionMetadata)
from vllm.v1.engine.mm_input_mapper import MMInputMapperClient from vllm.v1.engine.mm_input_mapper import MMHasher, MMInputMapperClient
from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
...@@ -79,8 +79,14 @@ class GPUModelRunner: ...@@ -79,8 +79,14 @@ class GPUModelRunner:
# Multi-modal data support # Multi-modal data support
self.input_registry = INPUT_REGISTRY self.input_registry = INPUT_REGISTRY
self.mm_registry = MULTIMODAL_REGISTRY self.mm_registry = MULTIMODAL_REGISTRY
# NOTE: mm_input_mapper is only used for memory profiling.
self.mm_input_mapper = MMInputMapperClient(self.model_config) # NOTE: mm_input_mapper_client and mm_hasher are only used for memory
# profiling.
self.mm_input_mapper_client = MMInputMapperClient(self.model_config)
self.mm_hasher = MMHasher()
self.use_hash = (not model_config.disable_mm_preprocessor_cache) or \
cache_config.enable_prefix_caching
self.max_num_encoder_input_tokens = self.scheduler_config.max_num_encoder_input_tokens # noqa: E501 self.max_num_encoder_input_tokens = self.scheduler_config.max_num_encoder_input_tokens # noqa: E501
self.encoder_cache_size = self.scheduler_config.encoder_cache_size self.encoder_cache_size = self.scheduler_config.encoder_cache_size
...@@ -99,6 +105,7 @@ class GPUModelRunner: ...@@ -99,6 +105,7 @@ class GPUModelRunner:
max_num_blocks_per_req=self.max_num_blocks_per_req, max_num_blocks_per_req=self.max_num_blocks_per_req,
device=self.device, device=self.device,
pin_memory=self.pin_memory, pin_memory=self.pin_memory,
vocab_size=model_config.get_vocab_size(),
) )
self.use_cuda_graph = (self.vllm_config.compilation_config.level self.use_cuda_graph = (self.vllm_config.compilation_config.level
...@@ -377,7 +384,12 @@ class GPUModelRunner: ...@@ -377,7 +384,12 @@ class GPUModelRunner:
or scheduler_output.scheduled_resumed_reqs): or scheduler_output.scheduled_resumed_reqs):
skip_copy = False skip_copy = False
# Create the sampling metadata. # Create the sampling metadata.
sampling_metadata = self.input_batch.make_sampling_metadata(skip_copy) req_id_output_token_ids: Dict[str, List[int]] = \
{req_id: req.output_token_ids \
for req_id, req in self.requests.items()}
sampling_metadata = self.input_batch.make_sampling_metadata(
req_id_output_token_ids, skip_copy)
return sampling_metadata return sampling_metadata
def _execute_encoder(self, scheduler_output: "SchedulerOutput"): def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
...@@ -628,11 +640,6 @@ class GPUModelRunner: ...@@ -628,11 +640,6 @@ class GPUModelRunner:
mm_registry=self.mm_registry, mm_registry=self.mm_registry,
) )
dummy_mm_data = dummy_request_data.multi_modal_data dummy_mm_data = dummy_request_data.multi_modal_data
dummy_mm_kwargs, _ = self.mm_input_mapper.process_inputs(
mm_data=dummy_mm_data,
mm_hashes=None,
mm_processor_kwargs=None,
precomputed_mm_inputs=None)
# NOTE: Currently model is profiled with a single non-text # NOTE: Currently model is profiled with a single non-text
# modality even when it supports multiple. # modality even when it supports multiple.
...@@ -648,8 +655,39 @@ class GPUModelRunner: ...@@ -648,8 +655,39 @@ class GPUModelRunner:
# (e.g, multiple images) for a single request, therefore here we # (e.g, multiple images) for a single request, therefore here we
# always replicate first item by max_num_mm_items times since in V1 # always replicate first item by max_num_mm_items times since in V1
# they are scheduled to be processed separately. # they are scheduled to be processed separately.
# Case when models have a merged processor, their dummy data is
# already batched `MultiModalKwargs`, therefore we need to "unbatch"
# and take the first item in each batched tensor.
# TODO (ywang96): This is somewhat hacky. Refactor this to be
# consistent with the other case.
if isinstance(dummy_mm_data, MultiModalKwargs):
dummy_mm_kwargs = {
k: v[0].unsqueeze(0)
for k, v in dummy_mm_data.items()
}
# Case when models have dummy data explicitly defined as
# `MultiModalDataDict`, so they need to be processed through input
# mapper.
else:
# Compute MM hashes (if enabled)
mm_hashes = None
if self.use_hash:
mm_hashes = self.mm_hasher.hash_dummy_mm_data(
dummy_mm_data)
mm_kwargs_list = self.mm_input_mapper_client.process_inputs(
mm_data=dummy_mm_data,
mm_hashes=mm_hashes,
mm_processor_kwargs=None,
precomputed_mm_inputs=None)
# Take the first `MultiModalKwargs`
dummy_mm_kwargs = mm_kwargs_list[0]
batched_dummy_mm_inputs = MultiModalKwargs.batch( batched_dummy_mm_inputs = MultiModalKwargs.batch(
[dummy_mm_kwargs[0]] * max_num_mm_items) [dummy_mm_kwargs] * max_num_mm_items)
batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs( batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
batched_dummy_mm_inputs, device=self.device) batched_dummy_mm_inputs, device=self.device)
......
...@@ -202,7 +202,6 @@ class Worker: ...@@ -202,7 +202,6 @@ class Worker:
) -> ModelRunnerOutput: ) -> ModelRunnerOutput:
output = self.model_runner.execute_model(scheduler_output) output = self.model_runner.execute_model(scheduler_output)
return output if self.rank == 0 else None return output if self.rank == 0 else None
return output
def profile(self, is_start: bool = True): def profile(self, is_start: bool = True):
if self.profiler is None: if self.profiler is None:
......
...@@ -114,8 +114,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]): ...@@ -114,8 +114,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
def __init__(self, use_mrope: bool): def __init__(self, use_mrope: bool):
self.use_mrope = use_mrope self.use_mrope = use_mrope
self.input_tokens: List[int] = [] self.input_tokens: List[int] = []
self.input_positions: Optional[ self.input_positions: List[int] = []
List[int]] = [] if not self.use_mrope else None
self.token_type_ids: Optional[List[int]] = [] self.token_type_ids: Optional[List[int]] = []
self.seq_lens: List[int] = [] self.seq_lens: List[int] = []
self.query_lens: List[int] = [] self.query_lens: List[int] = []
...@@ -130,9 +129,8 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]): ...@@ -130,9 +129,8 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
self.multi_modal_placeholder_maps: Dict[ self.multi_modal_placeholder_maps: Dict[
str, MultiModalPlaceholderMap] = defaultdict( str, MultiModalPlaceholderMap] = defaultdict(
MultiModalPlaceholderMap) MultiModalPlaceholderMap)
self.input_mrope_positions: Optional[List[List[int]]] = [ self.input_mrope_positions: List[List[int]] = [[]
[] for _ in range(3) for _ in range(3)]
] if self.use_mrope else None
def __init__(self, def __init__(self,
runner: "CPUModelRunner", runner: "CPUModelRunner",
...@@ -167,7 +165,8 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]): ...@@ -167,7 +165,8 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
device="cpu") device="cpu")
input_positions = torch.tensor( input_positions = torch.tensor(
input_data.input_positions input_data.input_positions
if not input_data.use_mrope else input_data.input_mrope_positions, if not any(input_data.input_mrope_positions) else
input_data.input_mrope_positions,
dtype=torch.long, dtype=torch.long,
device="cpu") device="cpu")
token_type_ids = torch.tensor(input_data.token_type_ids, token_type_ids = torch.tensor(input_data.token_type_ids,
...@@ -236,7 +235,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]): ...@@ -236,7 +235,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
block_table = block_table[start_block:] block_table = block_table[start_block:]
# For MRotaryEmbedding # For MRotaryEmbedding
if data.input_positions is None: if seq_data.mrope_position_delta is not None:
next_pos = MRotaryEmbedding.get_next_input_positions( next_pos = MRotaryEmbedding.get_next_input_positions(
seq_data.mrope_position_delta, seq_data.mrope_position_delta,
context_len, context_len,
...@@ -309,8 +308,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]): ...@@ -309,8 +308,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
data.slot_mapping.extend(slot_mapping) data.slot_mapping.extend(slot_mapping)
# The MROPE positions are prepared in _compute_multi_modal_input # The MROPE positions are prepared in _compute_multi_modal_input
if data.input_positions is not None: data.input_positions.extend(token_positions)
data.input_positions.extend(token_positions)
if data.token_type_ids is not None: if data.token_type_ids is not None:
data.token_type_ids.extend(token_types if token_types else []) data.token_type_ids.extend(token_types if token_types else [])
......
...@@ -338,9 +338,8 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): ...@@ -338,9 +338,8 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
def prepare_worker_input( def prepare_worker_input(
self, execute_model_req: ExecuteModelRequest) -> WorkerInput: self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
assert execute_model_req is not None assert execute_model_req is not None
virtual_engine = execute_model_req.virtual_engine virtual_engine: int = execute_model_req.virtual_engine
num_seq_groups: int = len(execute_model_req.seq_group_metadata_list) num_seq_groups: int = len(execute_model_req.seq_group_metadata_list)
blocks_to_copy = execute_model_req.blocks_to_copy
blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy, blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
device="cpu", device="cpu",
dtype=torch.int64).view(-1, 2) dtype=torch.int64).view(-1, 2)
......
...@@ -13,7 +13,7 @@ import numpy as np ...@@ -13,7 +13,7 @@ import numpy as np
import torch import torch
import torch.distributed import torch.distributed
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F from tqdm import tqdm
import vllm.envs as envs import vllm.envs as envs
from vllm.attention import AttentionMetadata, get_attn_backend from vllm.attention import AttentionMetadata, get_attn_backend
...@@ -22,7 +22,8 @@ from vllm.attention.backends.utils import CommonAttentionState ...@@ -22,7 +22,8 @@ from vllm.attention.backends.utils import CommonAttentionState
from vllm.config import CompilationLevel, VllmConfig from vllm.config import CompilationLevel, VllmConfig
from vllm.core.scheduler import SchedulerOutputs from vllm.core.scheduler import SchedulerOutputs
from vllm.distributed import get_kv_transfer_group, get_pp_group from vllm.distributed import get_kv_transfer_group, get_pp_group
from vllm.distributed.parallel_state import graph_capture from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank,
graph_capture)
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -1416,8 +1417,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1416,8 +1417,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
logger.info("Capturing cudagraphs for decoding. This may lead to " logger.info("Capturing cudagraphs for decoding. This may lead to "
"unexpected consequences if the model is not static. To " "unexpected consequences if the model is not static. To "
"run the model in eager mode, set 'enforce_eager=True' or " "run the model in eager mode, set 'enforce_eager=True' or "
"use '--enforce-eager' in the CLI.") "use '--enforce-eager' in the CLI. "
logger.info("If out-of-memory error occurs during cudagraph capture," "If out-of-memory error occurs during cudagraph capture,"
" consider decreasing `gpu_memory_utilization` or " " consider decreasing `gpu_memory_utilization` or "
"switching to eager mode. You can also reduce the " "switching to eager mode. You can also reduce the "
"`max_num_seqs` as needed to decrease memory usage.") "`max_num_seqs` as needed to decrease memory usage.")
...@@ -1454,8 +1455,14 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1454,8 +1455,14 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
# memory usage of CUDA graph. # memory usage of CUDA graph.
for virtual_engine in range( for virtual_engine in range(
self.parallel_config.pipeline_parallel_size): self.parallel_config.pipeline_parallel_size):
for batch_size in \ # Only rank 0 should print progress bar during capture
self.vllm_config.compilation_config.capture_sizes: capture_sizes = (
tqdm(
self.vllm_config.compilation_config.capture_sizes,
desc="Capturing CUDA graph shapes",
) if get_tensor_model_parallel_rank() == 0 else
self.vllm_config.compilation_config.capture_sizes)
for batch_size in capture_sizes:
attn_metadata = ( attn_metadata = (
self.attn_state.graph_capture_get_metadata_for_batch( self.attn_state.graph_capture_get_metadata_for_batch(
batch_size, batch_size,
......
...@@ -406,8 +406,9 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]): ...@@ -406,8 +406,9 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
if not cont: if not cont:
break break
def _final_process_outputs(self, model_input: StatefulModelInput, def _final_process_outputs(
output_proc_callback: Optional[Callable]): self, model_input: StatefulModelInput,
output_proc_callback: Optional[Callable]) -> List[SamplerOutput]:
assert model_input.frozen_model_input is not None assert model_input.frozen_model_input is not None
has_async_callback = output_proc_callback is not None has_async_callback = output_proc_callback is not None
...@@ -594,8 +595,8 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]): ...@@ -594,8 +595,8 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
# should be [SamplerOutput] # should be [SamplerOutput]
return output return output
def _update_sampling_metadata(self, sampling_metadata, num_seqs, def _update_sampling_metadata(self, sampling_metadata: SamplingMetadata,
num_queries): num_seqs: Optional[int], num_queries: int):
assert sampling_metadata.num_prompts == 0 assert sampling_metadata.num_prompts == 0
assert len(sampling_metadata.seq_groups) == num_queries assert len(sampling_metadata.seq_groups) == num_queries
...@@ -820,7 +821,7 @@ def _pythonize_sampler_output( ...@@ -820,7 +821,7 @@ def _pythonize_sampler_output(
for sgdx, (seq_group, for sgdx, (seq_group,
sample_result) in enumerate(zip(seq_groups, samples_list)): sample_result) in enumerate(zip(seq_groups, samples_list)):
# Reminder: Please update docs/source/usage/compatibility_matrix.rst # Reminder: Please update docs/source/usage/compatibility_matrix.md
# If the feature combo become valid # If the feature combo become valid
# (Check for Guided Decoding) # (Check for Guided Decoding)
if seq_group.sampling_params.logits_processors: if seq_group.sampling_params.logits_processors:
...@@ -850,13 +851,13 @@ def _pythonize_sampler_output( ...@@ -850,13 +851,13 @@ def _pythonize_sampler_output(
seq_ids = seq_group.seq_ids seq_ids = seq_group.seq_ids
next_token_ids = sample_result next_token_ids = sample_result
parent_ids = [0] parent_ids = [0]
seq_outputs: List[SequenceOutput]
if cache is not None: if cache is not None:
completion_seq_group_output: CompletionSequenceGroupOutput = \ completion_seq_group_output: CompletionSequenceGroupOutput = \
cache.cached_completion_seq_group_output.get_object() cache.cached_completion_seq_group_output.get_object()
completion_seq_group_output.samples.clear() completion_seq_group_output.samples.clear()
seq_outputs: List[ seq_outputs = completion_seq_group_output.samples
SequenceOutput] = completion_seq_group_output.samples
else: else:
seq_outputs = [] seq_outputs = []
......
...@@ -91,6 +91,10 @@ class PoolingModelRunner( ...@@ -91,6 +91,10 @@ class PoolingModelRunner(
] ]
multi_modal_kwargs = model_input.multi_modal_kwargs or {} multi_modal_kwargs = model_input.multi_modal_kwargs or {}
seqlen_agnostic_kwargs = {
"finished_requests_ids": model_input.finished_requests_ids,
"request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
} if self.has_inner_state else {}
if (self.observability_config is not None if (self.observability_config is not None
and self.observability_config.collect_model_forward_time): and self.observability_config.collect_model_forward_time):
model_forward_start = torch.cuda.Event(enable_timing=True) model_forward_start = torch.cuda.Event(enable_timing=True)
...@@ -110,7 +114,8 @@ class PoolingModelRunner( ...@@ -110,7 +114,8 @@ class PoolingModelRunner(
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(multi_modal_kwargs, **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
device=self.device), device=self.device),
**cross_enc_kwargs) **cross_enc_kwargs,
**seqlen_agnostic_kwargs)
if (self.observability_config is not None if (self.observability_config is not None
and self.observability_config.collect_model_forward_time): and self.observability_config.collect_model_forward_time):
......
...@@ -13,7 +13,7 @@ def assert_enc_dec_mr_supported_scenario( ...@@ -13,7 +13,7 @@ def assert_enc_dec_mr_supported_scenario(
a supported scenario. a supported scenario.
''' '''
# Reminder: Please update docs/source/usage/compatibility_matrix.rst # Reminder: Please update docs/source/usage/compatibility_matrix.md
# If the feature combo become valid # If the feature combo become valid
if enc_dec_mr.cache_config.enable_prefix_caching: if enc_dec_mr.cache_config.enable_prefix_caching:
......
...@@ -485,7 +485,7 @@ class WorkerWrapperBase: ...@@ -485,7 +485,7 @@ class WorkerWrapperBase:
self.worker = worker_class(*args, **kwargs) self.worker = worker_class(*args, **kwargs)
assert self.worker is not None assert self.worker is not None
def execute_method(self, method, *args, **kwargs): def execute_method(self, method: str, *args, **kwargs):
try: try:
target = self if self.worker is None else self.worker target = self if self.worker is None else self.worker
executor = getattr(target, method) executor = getattr(target, method)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment