Unverified Commit 6c5af09b authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[V1] Implement vLLM V1 [1/N] (#9289)

parent 3ddbe255
...@@ -17,6 +17,7 @@ logger = init_logger(__name__) ...@@ -17,6 +17,7 @@ logger = init_logger(__name__)
class _Backend(enum.Enum): class _Backend(enum.Enum):
FLASH_ATTN = enum.auto() FLASH_ATTN = enum.auto()
FLASH_ATTN_VLLM_V1 = enum.auto()
XFORMERS = enum.auto() XFORMERS = enum.auto()
ROCM_FLASH = enum.auto() ROCM_FLASH = enum.auto()
TORCH_SDPA = enum.auto() TORCH_SDPA = enum.auto()
...@@ -110,6 +111,10 @@ def get_attn_backend( ...@@ -110,6 +111,10 @@ def get_attn_backend(
from vllm.attention.backends.flash_attn import ( # noqa: F401 from vllm.attention.backends.flash_attn import ( # noqa: F401
FlashAttentionBackend) FlashAttentionBackend)
return FlashAttentionBackend return FlashAttentionBackend
if backend == _Backend.FLASH_ATTN_VLLM_V1:
from vllm.v1.attention.backends.flash_attn import ( # noqa: F401
FlashAttentionBackend as FlashAttentionBackendV1)
return FlashAttentionBackendV1
if backend == _Backend.XFORMERS: if backend == _Backend.XFORMERS:
logger.info("Using XFormers backend.") logger.info("Using XFormers backend.")
from vllm.attention.backends.xformers import ( # noqa: F401 from vllm.attention.backends.xformers import ( # noqa: F401
...@@ -215,6 +220,9 @@ def which_attn_to_use( ...@@ -215,6 +220,9 @@ def which_attn_to_use(
logger.info("%s is not supported in AMD GPUs.", selected_backend) logger.info("%s is not supported in AMD GPUs.", selected_backend)
return _Backend.ROCM_FLASH return _Backend.ROCM_FLASH
if envs.VLLM_USE_V1:
return _Backend.FLASH_ATTN_VLLM_V1
# FlashAttn in NVIDIA GPUs. # FlashAttn in NVIDIA GPUs.
if selected_backend == _Backend.FLASH_ATTN: if selected_backend == _Backend.FLASH_ATTN:
if not current_platform.has_device_capability(80): if not current_platform.has_device_capability(80):
......
...@@ -8,7 +8,7 @@ from typing import Iterator, List, Optional, Union ...@@ -8,7 +8,7 @@ from typing import Iterator, List, Optional, Union
import cloudpickle import cloudpickle
import zmq import zmq
from vllm import AsyncEngineArgs, LLMEngine, SamplingParams from vllm import AsyncEngineArgs, SamplingParams
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig) ParallelConfig, SchedulerConfig)
# yapf conflicts with isort for this block # yapf conflicts with isort for this block
...@@ -21,12 +21,17 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, ...@@ -21,12 +21,17 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
RPCStartupRequest, RPCStartupResponse, RPCStartupRequest, RPCStartupResponse,
RPCUProfileRequest) RPCUProfileRequest)
# yapf: enable # yapf: enable
from vllm.envs import VLLM_RPC_TIMEOUT from vllm.envs import VLLM_RPC_TIMEOUT, VLLM_USE_V1
from vllm.executor.gpu_executor import GPUExecutor from vllm.executor.gpu_executor import GPUExecutor
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
if VLLM_USE_V1:
from vllm.v1.engine.llm_engine import LLMEngine
else:
from vllm.engine.llm_engine import LLMEngine
CONFIG_TYPE = Union[ModelConfig, DecodingConfig, ParallelConfig, CONFIG_TYPE = Union[ModelConfig, DecodingConfig, ParallelConfig,
SchedulerConfig, LoRAConfig] SchedulerConfig, LoRAConfig]
...@@ -136,14 +141,16 @@ class MQLLMEngine: ...@@ -136,14 +141,16 @@ class MQLLMEngine:
executor_class = LLMEngine._get_executor_cls(engine_config) executor_class = LLMEngine._get_executor_cls(engine_config)
return cls( use_async_sockets = (engine_config.model_config.use_async_output_proc
ipc_path=ipc_path, and not VLLM_USE_V1)
use_async_sockets=engine_config.model_config.use_async_output_proc,
**engine_config.to_dict(), return cls(ipc_path=ipc_path,
executor_class=executor_class, use_async_sockets=use_async_sockets,
log_requests=not engine_args.disable_log_requests, **engine_config.to_dict(),
log_stats=not engine_args.disable_log_stats, executor_class=executor_class,
usage_context=usage_context) log_requests=not engine_args.disable_log_requests,
log_stats=not engine_args.disable_log_stats,
usage_context=usage_context)
def start(self): def start(self):
try: try:
......
...@@ -6,10 +6,10 @@ from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple, ...@@ -6,10 +6,10 @@ from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple,
from tqdm import tqdm from tqdm import tqdm
from vllm import envs
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput, from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
BeamSearchSequence, get_beam_search_score) BeamSearchSequence, get_beam_search_score)
from vllm.engine.arg_utils import EngineArgs, TaskOption from vllm.engine.arg_utils import EngineArgs, TaskOption
from vllm.engine.llm_engine import LLMEngine
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam, from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
apply_hf_chat_template, apply_hf_chat_template,
apply_mistral_chat_template, apply_mistral_chat_template,
...@@ -31,6 +31,11 @@ from vllm.transformers_utils.tokenizer_group import TokenizerGroup ...@@ -31,6 +31,11 @@ from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, deprecate_args, deprecate_kwargs, is_list_of from vllm.utils import Counter, deprecate_args, deprecate_kwargs, is_list_of
if envs.VLLM_USE_V1:
from vllm.v1.engine.llm_engine import LLMEngine # type: ignore
else:
from vllm.engine.llm_engine import LLMEngine # type: ignore
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -68,6 +68,7 @@ if TYPE_CHECKING: ...@@ -68,6 +68,7 @@ if TYPE_CHECKING:
VLLM_TORCH_COMPILE_LEVEL: int = 0 VLLM_TORCH_COMPILE_LEVEL: int = 0
VLLM_CUSTOM_OPS: List[str] = [] VLLM_CUSTOM_OPS: List[str] = []
VLLM_DISABLED_KERNELS: List[str] = [] VLLM_DISABLED_KERNELS: List[str] = []
VLLM_USE_V1: bool = False
def get_default_cache_root(): def get_default_cache_root():
...@@ -450,6 +451,10 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -450,6 +451,10 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_DISABLED_KERNELS": "VLLM_DISABLED_KERNELS":
lambda: [] if "VLLM_DISABLED_KERNELS" not in os.environ else os.environ[ lambda: [] if "VLLM_DISABLED_KERNELS" not in os.environ else os.environ[
"VLLM_DISABLED_KERNELS"].split(","), "VLLM_DISABLED_KERNELS"].split(","),
# If set, use the V1 code path.
"VLLM_USE_V1":
lambda: bool(int(os.getenv("VLLM_USE_V1", "0"))),
} }
# end-env-vars-definition # end-env-vars-definition
......
...@@ -48,14 +48,15 @@ class LogitsProcessor(nn.Module): ...@@ -48,14 +48,15 @@ class LogitsProcessor(nn.Module):
self, self,
lm_head: VocabParallelEmbedding, lm_head: VocabParallelEmbedding,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: Optional[SamplingMetadata] = None,
embedding_bias: Optional[torch.Tensor] = None, embedding_bias: Optional[torch.Tensor] = None,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
if self.logits_as_input: if self.logits_as_input:
logits = hidden_states logits = hidden_states
else: else:
hidden_states = _prune_hidden_states(hidden_states, if sampling_metadata is not None:
sampling_metadata) hidden_states = _prune_hidden_states(hidden_states,
sampling_metadata)
# Get the logits for the next tokens. # Get the logits for the next tokens.
logits = self._get_logits(hidden_states, lm_head, embedding_bias) logits = self._get_logits(hidden_states, lm_head, embedding_bias)
...@@ -69,7 +70,8 @@ class LogitsProcessor(nn.Module): ...@@ -69,7 +70,8 @@ class LogitsProcessor(nn.Module):
logits *= self.scale logits *= self.scale
# Apply logits processors (if any). # Apply logits processors (if any).
logits = _apply_logits_processors(logits, sampling_metadata) if sampling_metadata is not None:
logits = _apply_logits_processors(logits, sampling_metadata)
return logits return logits
......
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional
from vllm.sequence import (VLLM_INVALID_TOKEN_ID, Logprob, SamplingParams, from vllm.sequence import (VLLM_INVALID_TOKEN_ID, Logprob, SamplingParams,
Sequence, SequenceGroup) Sequence, SequenceGroup)
from .detokenizer_utils import (convert_prompt_ids_to_tokens,
detokenize_incrementally)
from .tokenizer import AnyTokenizer from .tokenizer import AnyTokenizer
from .tokenizer_group import BaseTokenizerGroup from .tokenizer_group import BaseTokenizerGroup
...@@ -161,167 +163,3 @@ class Detokenizer: ...@@ -161,167 +163,3 @@ class Detokenizer:
seq.output_text += new_decoded_token_text seq.output_text += new_decoded_token_text
return len(new_decoded_token_text) return len(new_decoded_token_text)
def _replace_none_with_empty(tokens: List[Optional[str]]):
for i, token in enumerate(tokens):
if token is None:
tokens[i] = ""
def _convert_tokens_to_string_with_added_encoders(
tokenizer: AnyTokenizer,
output_tokens: List[str],
skip_special_tokens: bool,
spaces_between_special_tokens: bool,
) -> str:
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
# NOTE(woosuk): The following code is slow because it runs a for loop over
# the output_tokens. In Python, running a for loop over a list can be slow
# even when the loop body is very simple.
sub_texts: List[str] = []
current_sub_text: List[str] = []
all_special_tokens = set(tokenizer.all_special_tokens)
for token in output_tokens:
if skip_special_tokens and token in all_special_tokens:
continue
if token in tokenizer.get_added_vocab():
if current_sub_text:
sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
sub_texts.append(sub_text)
current_sub_text = []
sub_texts.append(token)
else:
current_sub_text.append(token)
if current_sub_text:
sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
sub_texts.append(sub_text)
if spaces_between_special_tokens:
return " ".join(sub_texts)
else:
return "".join(sub_texts)
# 5 is an arbitrary value that should work for all
# tokenizers (bigger = more conservative).
INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET = 5
def convert_prompt_ids_to_tokens(
tokenizer: AnyTokenizer,
prompt_ids: List[int],
skip_special_tokens: bool = False,
) -> Tuple[List[str], int, int]:
"""Converts the prompt ids to tokens and returns the tokens and offsets
for incremental detokenization.
Note that not all tokens are converted to strings. Only the tokens that
are necessary for incremental detokenization are converted to strings.
"""
# We do not need to convert the whole prompt to tokens.
# Offset a little more in case we have special tokens.
new_tokens = tokenizer.convert_ids_to_tokens(
prompt_ids[-INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET - 2:],
skip_special_tokens=skip_special_tokens)
read_offset = len(new_tokens)
prefix_offset = max(
read_offset - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET, 0)
# This is required to guard against out-of-vocab prompt token ids
_replace_none_with_empty(new_tokens) # type: ignore[arg-type]
return new_tokens, prefix_offset, read_offset
# Based on
# https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15
# under Apache 2.0 license
def detokenize_incrementally(
tokenizer: AnyTokenizer,
all_input_ids: List[int],
prev_tokens: Optional[List[str]],
prefix_offset: int,
read_offset: int,
skip_special_tokens: bool = False,
spaces_between_special_tokens: bool = True,
) -> Tuple[List[str], str, int, int]:
"""Detokenizes the input ids incrementally and returns the new tokens
and the new text.
If `prev_tokens` is None, this function will convert the input ids to
tokens and return the tokens and the new text. Otherwise, it will return the
new tokens and the new text.
This function will also return the new prefix offset and the new read
offset to be used in the next iteration.
The offsets are necessary to defeat cleanup algorithms in the decode which
decide to add a space or not depending on the surrounding ids.
Args:
tokenizer: The tokenizer to use.
all_input_ids: The input ids. The last id is the new token id.
prev_tokens: The previous tokens. If None, this function will convert
the input ids to tokens and return the tokens and the new text.
prefix_offset: The prefix offset.
read_offset: The read offset.
skip_special_tokens: Whether to skip special tokens.
spaces_between_special_tokens: Whether to add spaces between special
tokens.
"""
new_token_id = all_input_ids[-1]
# This is the first iteration for this sequence
is_first_iter = prev_tokens is None
if is_first_iter:
(prev_tokens, prefix_offset,
read_offset) = convert_prompt_ids_to_tokens(
tokenizer,
all_input_ids[:-1],
skip_special_tokens=skip_special_tokens)
assert prev_tokens is not None
# If the new token id is out of bounds, return an empty string.
if 0 <= new_token_id < len(tokenizer):
# Put new_token_id in a list so skip_special_tokens is respected
new_tokens = tokenizer.convert_ids_to_tokens(
[new_token_id], skip_special_tokens=skip_special_tokens)
if isinstance(new_tokens, str):
new_tokens = [new_tokens]
else:
new_tokens = [""]
output_tokens = prev_tokens + new_tokens
# If this is the first iteration, return all tokens.
if is_first_iter:
new_tokens = output_tokens
# The prefix text is necessary only to defeat cleanup algorithms in
# the decode which decide to add a space or not depending on the
# surrounding ids.
if tokenizer.is_fast or not tokenizer.get_added_vocab():
prefix_text = tokenizer.convert_tokens_to_string(
output_tokens[prefix_offset:read_offset])
new_text = tokenizer.convert_tokens_to_string(
output_tokens[prefix_offset:])
else:
prefix_text = _convert_tokens_to_string_with_added_encoders(
tokenizer,
output_tokens[prefix_offset:read_offset],
skip_special_tokens=skip_special_tokens,
spaces_between_special_tokens=spaces_between_special_tokens,
)
new_text = _convert_tokens_to_string_with_added_encoders(
tokenizer,
output_tokens[prefix_offset:],
skip_special_tokens=skip_special_tokens,
spaces_between_special_tokens=spaces_between_special_tokens,
)
if len(new_text) <= len(prefix_text) or new_text.endswith("�"):
# utf-8 char at the end means it's a potential unfinished byte sequence
# from byte fallback tokenization.
# If it's in the middle, it's probably a real invalid id generated
# by the model
return new_tokens, "", prefix_offset, read_offset
new_text = new_text[len(prefix_text):]
return new_tokens, new_text, read_offset, len(output_tokens)
from typing import List, Optional, Tuple
from .tokenizer import AnyTokenizer
def _replace_none_with_empty(tokens: List[Optional[str]]):
for i, token in enumerate(tokens):
if token is None:
tokens[i] = ""
def _convert_tokens_to_string_with_added_encoders(
tokenizer: AnyTokenizer,
output_tokens: List[str],
skip_special_tokens: bool,
spaces_between_special_tokens: bool,
) -> str:
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
# NOTE(woosuk): The following code is slow because it runs a for loop over
# the output_tokens. In Python, running a for loop over a list can be slow
# even when the loop body is very simple.
sub_texts: List[str] = []
current_sub_text: List[str] = []
all_special_tokens = set(tokenizer.all_special_tokens)
for token in output_tokens:
if skip_special_tokens and token in all_special_tokens:
continue
if token in tokenizer.get_added_vocab():
if current_sub_text:
sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
sub_texts.append(sub_text)
current_sub_text = []
sub_texts.append(token)
else:
current_sub_text.append(token)
if current_sub_text:
sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
sub_texts.append(sub_text)
if spaces_between_special_tokens:
return " ".join(sub_texts)
else:
return "".join(sub_texts)
# 5 is an arbitrary value that should work for all
# tokenizers (bigger = more conservative).
INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET = 5
def convert_prompt_ids_to_tokens(
tokenizer: AnyTokenizer,
prompt_ids: List[int],
skip_special_tokens: bool = False,
) -> Tuple[List[str], int, int]:
"""Converts the prompt ids to tokens and returns the tokens and offsets
for incremental detokenization.
Note that not all tokens are converted to strings. Only the tokens that
are necessary for incremental detokenization are converted to strings.
"""
# We do not need to convert the whole prompt to tokens.
# Offset a little more in case we have special tokens.
new_tokens = tokenizer.convert_ids_to_tokens(
prompt_ids[-INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET - 2:],
skip_special_tokens=skip_special_tokens)
read_offset = len(new_tokens)
prefix_offset = max(
read_offset - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET, 0)
# This is required to guard against out-of-vocab prompt token ids
_replace_none_with_empty(new_tokens) # type: ignore[arg-type]
return new_tokens, prefix_offset, read_offset
# Based on
# https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15
# under Apache 2.0 license
def detokenize_incrementally(
tokenizer: AnyTokenizer,
all_input_ids: List[int],
prev_tokens: Optional[List[str]],
prefix_offset: int,
read_offset: int,
skip_special_tokens: bool = False,
spaces_between_special_tokens: bool = True,
) -> Tuple[List[str], str, int, int]:
"""Detokenizes the input ids incrementally and returns the new tokens
and the new text.
If `prev_tokens` is None, this function will convert the input ids to
tokens and return the tokens and the new text. Otherwise, it will return the
new tokens and the new text.
This function will also return the new prefix offset and the new read
offset to be used in the next iteration.
The offsets are necessary to defeat cleanup algorithms in the decode which
decide to add a space or not depending on the surrounding ids.
Args:
tokenizer: The tokenizer to use.
all_input_ids: The input ids. The last id is the new token id.
prev_tokens: The previous tokens. If None, this function will convert
the input ids to tokens and return the tokens and the new text.
prefix_offset: The prefix offset.
read_offset: The read offset.
skip_special_tokens: Whether to skip special tokens.
spaces_between_special_tokens: Whether to add spaces between special
tokens.
"""
new_token_id = all_input_ids[-1]
# This is the first iteration for this sequence
is_first_iter = prev_tokens is None
if is_first_iter:
(prev_tokens, prefix_offset,
read_offset) = convert_prompt_ids_to_tokens(
tokenizer,
all_input_ids[:-1],
skip_special_tokens=skip_special_tokens)
assert prev_tokens is not None
# If the new token id is out of bounds, return an empty string.
if 0 <= new_token_id < len(tokenizer):
# Put new_token_id in a list so skip_special_tokens is respected
new_tokens = tokenizer.convert_ids_to_tokens(
[new_token_id], skip_special_tokens=skip_special_tokens)
if isinstance(new_tokens, str):
new_tokens = [new_tokens]
else:
new_tokens = [""]
output_tokens = prev_tokens + new_tokens
# If this is the first iteration, return all tokens.
if is_first_iter:
new_tokens = output_tokens
# The prefix text is necessary only to defeat cleanup algorithms in
# the decode which decide to add a space or not depending on the
# surrounding ids.
if tokenizer.is_fast or not tokenizer.get_added_vocab():
prefix_text = tokenizer.convert_tokens_to_string(
output_tokens[prefix_offset:read_offset])
new_text = tokenizer.convert_tokens_to_string(
output_tokens[prefix_offset:])
else:
prefix_text = _convert_tokens_to_string_with_added_encoders(
tokenizer,
output_tokens[prefix_offset:read_offset],
skip_special_tokens=skip_special_tokens,
spaces_between_special_tokens=spaces_between_special_tokens,
)
new_text = _convert_tokens_to_string_with_added_encoders(
tokenizer,
output_tokens[prefix_offset:],
skip_special_tokens=skip_special_tokens,
spaces_between_special_tokens=spaces_between_special_tokens,
)
if len(new_text) <= len(prefix_text) or new_text.endswith("�"):
# utf-8 char at the end means it's a potential unfinished byte sequence
# from byte fallback tokenization.
# If it's in the middle, it's probably a real invalid id generated
# by the model
return new_tokens, "", prefix_offset, read_offset
new_text = new_text[len(prefix_text):]
return new_tokens, new_text, read_offset, len(output_tokens)
"""Attention layer with FlashAttention."""
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type
import torch
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
from vllm.forward_context import get_forward_context
from vllm.vllm_flash_attn import flash_attn_varlen_func
class FlashAttentionBackend(AttentionBackend):
@staticmethod
def get_supported_head_sizes() -> List[int]:
return [32, 64, 96, 128, 160, 192, 224, 256]
@staticmethod
def get_name() -> str:
return "flash-attn-vllm-v1"
@staticmethod
def get_impl_cls() -> Type["FlashAttentionImpl"]:
return FlashAttentionImpl
@staticmethod
def get_metadata_cls() -> Type["AttentionMetadata"]:
return FlashAttentionMetadata
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.")
return (2, num_blocks, block_size, num_kv_heads, head_size)
@dataclass
class FlashAttentionMetadata:
# NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
max_query_len: int
query_start_loc: torch.Tensor
max_seq_len: int
seq_start_loc: torch.Tensor
block_table: torch.Tensor
slot_mapping: torch.Tensor
class FlashAttentionImpl(AttentionImpl):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
) -> None:
if blocksparse_params is not None:
raise ValueError(
"FlashAttention does not support block-sparse attention.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
self.sliding_window = ((sliding_window, sliding_window)
if sliding_window is not None else (-1, -1))
self.kv_cache_dtype = kv_cache_dtype
if logits_soft_cap is None:
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
logits_soft_cap = 0
self.logits_soft_cap = logits_soft_cap
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
if sliding_window is not None:
# NOTE(woosuk): flash-attn's sliding window does not work with
# paged KV cache.
raise ValueError(
"Sliding window is not supported in FlashAttention.")
support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
if head_size not in support_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by FlashAttention. "
f"Supported head sizes are: {support_head_sizes}.")
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: FlashAttentionMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor:
"""Forward pass with FlashAttention.
Args:
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"FlashAttentionImpl")
# NOTE(woosuk): FlashAttention does not support FP8 KV cache.
assert k_scale == 1.0 and v_scale == 1.0, (
"key/v_scale is not supported in FlashAttention.")
output = torch.ops.vllm.unified_flash_attention(
query,
key,
value,
self.num_heads,
self.head_size,
self.num_kv_heads,
kv_cache,
self.kv_cache_dtype,
k_scale,
v_scale,
self.scale,
self.sliding_window,
self.alibi_slopes,
self.logits_soft_cap,
)
return output
@torch.library.custom_op("vllm::unified_flash_attention",
mutates_args=["kv_cache"])
def unified_flash_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
num_heads: int,
head_size: int,
num_kv_heads: int,
kv_cache: torch.Tensor,
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
softmax_scale: float,
window_size: Optional[List[int]] = None,
alibi_slopes: Optional[torch.Tensor] = None,
logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
current_metadata = get_forward_context()
if current_metadata is None:
# Profiling run.
return torch.empty_like(query)
assert current_metadata is not None
assert isinstance(current_metadata, FlashAttentionMetadata)
attn_metadata: FlashAttentionMetadata = current_metadata
num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, num_heads, head_size)
key = key.view(-1, num_kv_heads, head_size)
value = value.view(-1, num_kv_heads, head_size)
# Reshape the input keys and values and store them in the cache.
key_cache = kv_cache[0]
value_cache = kv_cache[1]
torch.ops._C_cache_ops.reshape_and_cache_flash(
key,
value,
kv_cache[0],
kv_cache[1],
attn_metadata.slot_mapping,
kv_cache_dtype,
k_scale,
v_scale,
)
output = flash_attn_varlen_func(
q=query,
k=key_cache,
v=value_cache,
cu_seqlens_q=attn_metadata.query_start_loc,
max_seqlen_q=attn_metadata.max_query_len,
cu_seqlens_k=attn_metadata.seq_start_loc,
max_seqlen_k=attn_metadata.max_seq_len,
softmax_scale=softmax_scale,
causal=True,
alibi_slopes=alibi_slopes,
window_size=window_size,
block_table=attn_metadata.block_table,
softcap=logits_soft_cap,
)
return output.view(num_tokens, hidden_size)
@unified_flash_attention.register_fake
def _(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
num_heads: int,
head_size: int,
num_kv_heads: int,
kv_cache: torch.Tensor,
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
softmax_scale: float,
window_size: Optional[List[int]] = None,
alibi_slopes: Optional[torch.Tensor] = None,
logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
return torch.empty_like(query)
from typing import Dict, List, Optional
import numpy as np
from vllm.logger import init_logger
from vllm.utils import cdiv
from vllm.v1.request import Request
logger = init_logger(__name__)
class KVCacheManager:
def __init__(
self,
block_size: int,
num_gpu_blocks: int,
sliding_window: Optional[int] = None,
enable_caching: bool = True,
num_preallocate_tokens: int = 64,
) -> None:
self.block_size = block_size
self.num_gpu_blocks = num_gpu_blocks
self.sliding_window = sliding_window
self.enable_caching = enable_caching
# NOTE(woosuk): To avoid frequent block allocation, we preallocate some
# blocks for each request. For example, when a request reaches the end
# of its block table, we preallocate N blocks in advance. This way, we
# reduce the overhead of updating free_block_ids and ref_cnts for each
# request every step (at the cost of some memory waste).
# NOTE(woosuk): This is different from the "lookahead" slots since this
# does not guarantee that the request always has N empty blocks. After
# the request gets N empty blocks, it starts to use the blocks without
# further allocation. When it uses up all the N empty blocks, it gets
# N new empty blocks.
self.num_preallocate_tokens = num_preallocate_tokens
self.num_preallocate_blocks = cdiv(num_preallocate_tokens, block_size)
self.free_block_ids = list(range(num_gpu_blocks))
self.req_to_block_ids: Dict[str, List[int]] = {}
self.ref_cnts = np.zeros(num_gpu_blocks, dtype=np.int32)
def get_computed_blocks(self, request: Request) -> List[int]:
if not self.enable_caching:
# No prefix caching.
return []
# TODO(woosuk): Implement hash-based caching.
return []
def append_slots(
self,
request: Request,
num_tokens: int,
) -> Optional[List[int]]:
num_required_blocks = cdiv(request.num_computed_tokens + num_tokens,
self.block_size)
req_block_ids = self.req_to_block_ids[request.request_id]
if num_required_blocks <= len(req_block_ids):
# No new block is needed.
return []
num_new_blocks = num_required_blocks - len(req_block_ids)
num_free_blocks = len(self.free_block_ids)
if num_new_blocks > num_free_blocks:
# Cannot allocate new blocks.
return None
# Allocate new blocks.
num_new_blocks = min(num_new_blocks + self.num_preallocate_blocks,
num_free_blocks)
new_block_ids = self._get_new_blocks(num_new_blocks)
req_block_ids.extend(new_block_ids)
self.ref_cnts[new_block_ids] += 1
return new_block_ids
def allocate_slots(
self,
request: Request,
num_tokens: int,
computed_block_ids: List[int],
) -> Optional[List[int]]:
num_required_blocks = cdiv(num_tokens, self.block_size)
num_free_blocks = len(self.free_block_ids)
if num_required_blocks > num_free_blocks:
# Cannot allocate new blocks.
return None
num_new_blocks = min(num_required_blocks + self.num_preallocate_blocks,
num_free_blocks)
new_block_ids = self._get_new_blocks(num_new_blocks)
block_ids = computed_block_ids + new_block_ids
self.req_to_block_ids[request.request_id] = block_ids
self.ref_cnts[block_ids] += 1
return new_block_ids
def free(self, request: Request) -> None:
block_ids = self.req_to_block_ids.pop(request.request_id)
self.ref_cnts[block_ids] -= 1
for block_id in block_ids:
ref_cnt = self.ref_cnts[block_id]
if ref_cnt == 0:
self.free_block_ids.append(block_id)
def _get_new_blocks(self, num_blocks: int) -> List[int]:
assert num_blocks <= len(self.free_block_ids)
new_block_ids = self.free_block_ids[-num_blocks:]
self.free_block_ids = self.free_block_ids[:-num_blocks]
return new_block_ids
from collections import deque
from dataclasses import dataclass
from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.logger import init_logger
from vllm.multimodal import MultiModalDataDict
from vllm.sampling_params import SamplingParams
from vllm.v1.core.kv_cache_manager import KVCacheManager
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
logger = init_logger(__name__)
class Scheduler:
def __init__(
self,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig,
lora_config: Optional[LoRAConfig],
) -> None:
self.scheduler_config = scheduler_config
self.cache_config = cache_config
self.lora_config = lora_config
# TODO: Support LoRA.
assert lora_config is None, "V1 does not support LoRA yet."
num_gpu_blocks = cache_config.num_gpu_blocks
assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0
# Create the block space manager.
self.kv_cache_manager = KVCacheManager(
block_size=self.cache_config.block_size,
num_gpu_blocks=num_gpu_blocks,
sliding_window=self.cache_config.sliding_window,
enable_caching=True)
self.block_size = self.cache_config.block_size
# Scheduling constraints.
self.max_num_running_reqs = self.scheduler_config.max_num_seqs
self.max_num_scheduled_tokens = \
self.scheduler_config.max_num_batched_tokens
self.max_model_len = self.scheduler_config.max_model_len
# req_id -> Request
self.requests: Dict[str, Request] = {}
# Priority queues for requests.
self.waiting: Deque[Request] = deque()
self.running: List[Request] = []
# The request IDs that are finished in between the previous and the
# current steps. This is used to notify the workers about the finished
# requests so that they can free the cached states for those requests.
# This is flushed at the end of each scheduling step.
self.finished_req_ids: Set[str] = set()
# OPTIMIZATION: Cache the RunningRequestData objects to avoid creating
# them at each scheduling step.
# Request id -> RunningRequestData
self.running_reqs_data: Dict[str, RunningRequestData] = {}
def schedule(self) -> "SchedulerOutput":
scheduled_new_reqs: List[Request] = []
scheduled_resumed_reqs: List[Request] = []
scheduled_running_reqs: List[Request] = []
preempted_reqs: List[Request] = []
# NOTE(woosuk) on the scheduling algorithm:
# There's no "decoding phase" nor "prefill phase" in the scheduler.
# Each request just has the num_computed_tokens and num_tokens,
# which is equal to len(prompt_token_ids) + len(output_token_ids).
# At each step, the scheduler tries to assign tokens to the requests
# so that each request's num_computed_tokens can catch up its
# num_tokens. This is general enough to cover chunked prefills,
# prefix caching, and the "jump forward" optimization in the future.
req_to_new_block_ids: Dict[str, List[int]] = {}
num_scheduled_tokens: Dict[str, int] = {}
token_budget = self.max_num_scheduled_tokens
# First, schedule the RUNNING requests.
req_index = 0
while req_index < len(self.running):
if token_budget == 0:
break
request = self.running[req_index]
num_new_tokens = request.num_tokens - request.num_computed_tokens
num_new_tokens = min(num_new_tokens, token_budget)
assert num_new_tokens > 0
while True:
new_block_ids = self.kv_cache_manager.append_slots(
request, num_new_tokens)
if new_block_ids is None:
# The request cannot be scheduled.
# Preempt the lowest-priority request.
preempted_req = self.running.pop()
self.kv_cache_manager.free(preempted_req)
preempted_req.status = RequestStatus.PREEMPTED
preempted_req.num_computed_tokens = 0
self.waiting.appendleft(preempted_req)
preempted_reqs.append(preempted_req)
if preempted_req == request:
# No more request to preempt.
break
else:
# The request can be scheduled.
scheduled_running_reqs.append(request)
req_to_new_block_ids[request.request_id] = new_block_ids
num_scheduled_tokens[request.request_id] = num_new_tokens
token_budget -= num_new_tokens
req_index += 1
break
# Next, schedule the WAITING requests.
if not preempted_reqs:
while self.waiting:
if len(self.running) == self.max_num_running_reqs:
break
if token_budget == 0:
break
request = self.waiting[0]
# Get already-cached tokens.
computed_block_ids = self.kv_cache_manager.get_computed_blocks(
request)
# NOTE(woosuk): Since incomplete blocks are not eligible for
# sharing, `num_computed_tokens` is always a multiple of
# `block_size`.
num_computed_tokens = len(computed_block_ids) * self.block_size
# Number of tokens to be scheduled.
# We use `request.num_tokens` instead of
# `request.num_prompt_tokens` to consider the resumed requests,
# which have output tokens.
num_new_tokens = request.num_tokens - num_computed_tokens
num_new_tokens = min(num_new_tokens, token_budget)
assert num_new_tokens > 0
new_block_ids = self.kv_cache_manager.allocate_slots(
request, num_new_tokens, computed_block_ids)
if new_block_ids is None:
# The request cannot be scheduled.
break
request.num_computed_tokens = num_computed_tokens
self.waiting.popleft()
self.running.append(request)
if request.status == RequestStatus.WAITING:
scheduled_new_reqs.append(request)
elif request.status == RequestStatus.PREEMPTED:
scheduled_resumed_reqs.append(request)
else:
raise RuntimeError(
f"Invalid request status: {request.status}")
req_to_new_block_ids[request.request_id] = (
computed_block_ids + new_block_ids)
num_scheduled_tokens[request.request_id] = num_new_tokens
token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING
# Check if the scheduling constraints are satisfied.
total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
assert token_budget >= 0
assert len(self.running) <= self.max_num_running_reqs
assert (len(scheduled_new_reqs) + len(scheduled_resumed_reqs) +
len(scheduled_running_reqs) == len(self.running))
# Construct the scheduler output.
new_reqs_data = [
NewRequestData.from_request(req,
req_to_new_block_ids[req.request_id],
req.num_computed_tokens)
for req in scheduled_new_reqs
]
resumed_reqs_data = [
ResumedRequestData.from_request(
req, req_to_new_block_ids[req.request_id],
req.num_computed_tokens) for req in scheduled_resumed_reqs
]
running_reqs_data = [
self._make_running_request_data(
req, req_to_new_block_ids[req.request_id],
req.num_computed_tokens) for req in scheduled_running_reqs
]
preempted_req_ids = {req.request_id for req in preempted_reqs}
scheduler_output = SchedulerOutput(
scheduled_new_reqs=new_reqs_data,
scheduled_resumed_reqs=resumed_reqs_data,
scheduled_running_reqs=running_reqs_data,
num_scheduled_tokens=num_scheduled_tokens,
total_num_scheduled_tokens=total_num_scheduled_tokens,
preempted_req_ids=preempted_req_ids,
# finished_req_ids is an existing state in the scheduler,
# instead of being newly scheduled in this step.
# It contains the request IDs that are finished in between
# the previous and the current steps.
finished_req_ids=self.finished_req_ids,
)
self.finished_req_ids = set()
return scheduler_output
def _make_running_request_data(
self,
request: Request,
new_block_ids: List[int],
num_computed_tokens: int,
) -> "RunningRequestData":
# OPTIMIZATION: Cache the RunningRequestData objects to avoid creating
# them at each scheduling step.
if request.request_id in self.running_reqs_data:
req_data = self.running_reqs_data[request.request_id]
req_data.new_block_ids = new_block_ids
req_data.num_computed_tokens = num_computed_tokens
else:
req_data = RunningRequestData.from_request(request, new_block_ids,
num_computed_tokens)
self.running_reqs_data[request.request_id] = req_data
return req_data
def update_from_output(
self,
scheduler_output: "SchedulerOutput",
model_runner_output: "ModelRunnerOutput",
) -> List[Tuple[Request, int]]:
# NOTE(woosuk): This method doesn't consider speculative decoding.
sampled_token_ids = model_runner_output.sampled_token_ids_cpu.tolist()
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
new_running: List[Request] = []
# (request, num_sampled_tokens)
sampled: List[Tuple[Request, int]] = []
for request in self.running:
req_id = request.request_id
request.num_computed_tokens += num_scheduled_tokens[req_id]
# When the request's num_computed_tokens catches up its num_tokens,
# the request generates output tokens. Otherwise, we ignore the
# sampler output for the request.
assert request.num_computed_tokens <= request.num_tokens
if request.num_computed_tokens == request.num_tokens:
req_index = model_runner_output.req_id_to_index[req_id]
# NOTE(woosuk): Currently, we assume that each request
# generates at most one token at each step.
token_id = sampled_token_ids[req_index]
request.output_token_ids.append(token_id)
sampled.append((request, 1))
# TODO: Update the KV cache manager for prefix caching.
# Check if the request is finished.
stopped = self._check_stop(request)
if stopped:
continue
new_running.append(request)
self.running = new_running
return sampled
def _check_stop(self, request: Request) -> bool:
if (request.num_tokens >= self.max_model_len
or request.num_output_tokens >= request.max_tokens):
request.status = RequestStatus.FINISHED_LENGTH_CAPPED
self._free_request(request)
return True
sampling_params = request.sampling_params
last_token_id = request.output_token_ids[-1]
if (not sampling_params.ignore_eos
and last_token_id == request.eos_token_id):
request.status = RequestStatus.FINISHED_STOPPED
self._free_request(request)
return True
if last_token_id in (sampling_params.stop_token_ids or ()):
request.status = RequestStatus.FINISHED_STOPPED
request.stop_reason = last_token_id
self._free_request(request)
return True
return False
def add_request(self, request: Request) -> None:
self.waiting.append(request)
self.requests[request.request_id] = request
def finish_requests(
self,
request_ids: Union[str, Iterable[str]],
finished_status: RequestStatus,
) -> None:
"""Handles the finish signal from outside the scheduler.
For example, the API server can abort a request when the client
disconnects.
"""
assert RequestStatus.is_finished(finished_status)
if isinstance(request_ids, str):
request_ids = (request_ids, )
request_ids = set(request_ids)
for req_id in request_ids:
request = self.requests.get(req_id)
if request is None:
# Invalid request ID.
continue
if request.status == RequestStatus.RUNNING:
self.running.remove(request)
else:
self.waiting.remove(request)
request.status = finished_status
self._free_request(request)
def _free_request(self, request: Request) -> None:
assert request.is_finished()
self.kv_cache_manager.free(request)
self.running_reqs_data.pop(request.request_id, None)
del self.requests[request.request_id]
self.finished_req_ids.add(request.request_id)
def get_num_unfinished_requests(self) -> int:
return len(self.waiting) + len(self.running)
def has_unfinished_requests(self) -> bool:
return self.get_num_unfinished_requests() > 0
@dataclass
class NewRequestData:
req_id: str
prompt_token_ids: List[int]
prompt: Optional[str]
multi_modal_data: Optional[MultiModalDataDict]
sampling_params: SamplingParams
block_ids: List[int]
num_computed_tokens: int
@classmethod
def from_request(
cls,
request: Request,
block_ids: List[int],
num_computed_tokens: int,
) -> "NewRequestData":
return cls(
req_id=request.request_id,
prompt_token_ids=request.inputs["prompt_token_ids"],
prompt=request.inputs.get("prompt"),
multi_modal_data=request.inputs.get("multi_modal_data"),
sampling_params=request.sampling_params,
block_ids=block_ids,
num_computed_tokens=num_computed_tokens,
)
@dataclass
class ResumedRequestData:
req_id: str
block_ids: List[int]
num_computed_tokens: int
@classmethod
def from_request(
cls,
request: Request,
block_ids: List[int],
num_computed_tokens: int,
) -> "ResumedRequestData":
return cls(
req_id=request.request_id,
block_ids=block_ids,
num_computed_tokens=num_computed_tokens,
)
@dataclass
class RunningRequestData:
req_id: str
new_block_ids: List[int]
num_computed_tokens: int
@classmethod
def from_request(
cls,
request: Request,
new_block_ids: List[int],
num_computed_tokens: int,
) -> "RunningRequestData":
return cls(
req_id=request.request_id,
new_block_ids=new_block_ids,
num_computed_tokens=num_computed_tokens,
)
@dataclass
class SchedulerOutput:
scheduled_new_reqs: List[NewRequestData]
scheduled_resumed_reqs: List[ResumedRequestData]
scheduled_running_reqs: List[RunningRequestData]
num_scheduled_tokens: Dict[str, int]
total_num_scheduled_tokens: int
preempted_req_ids: Set[str]
finished_req_ids: Set[str]
This diff is collapsed.
import os
from typing import Optional, Tuple
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.logger import init_logger
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.gpu_worker import Worker
logger = init_logger(__name__)
class GPUExecutor:
def __init__(
self,
model_config: ModelConfig,
cache_config: CacheConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
speculative_config: Optional[SpeculativeConfig],
prompt_adapter_config: Optional[PromptAdapterConfig],
observability_config: Optional[ObservabilityConfig],
) -> None:
self.model_config = model_config
self.cache_config = cache_config
self.lora_config = lora_config
self.load_config = load_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.speculative_config = speculative_config
self.prompt_adapter_config = prompt_adapter_config
self.observability_config = observability_config
self.worker = self._create_worker()
self.worker.initialize()
self.worker.load_model()
def _create_worker(
self,
local_rank: int = 0,
rank: int = 0,
distributed_init_method: Optional[str] = None) -> Worker:
"""Return worker init args for a given rank."""
# see https://github.com/NVIDIA/nccl/issues/1234
os.environ['NCCL_CUMEM_ENABLE'] = '0'
if distributed_init_method is None:
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
return Worker(
model_config=self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
device_config=self.device_config,
cache_config=self.cache_config,
load_config=self.load_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
speculative_config=self.speculative_config,
prompt_adapter_config=self.prompt_adapter_config,
observability_config=self.observability_config,
)
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks by invoking the
underlying worker.
"""
return self.worker.determine_num_available_blocks()
def initialize_cache(self, num_gpu_blocks: int) -> None:
"""Initialize the KV cache by invoking the underlying worker.
"""
# NOTE: This is logged in the executor because there can be >1 worker
# with other executors. We could log in the engine level, but work
# remains to abstract away the device for non-GPU configurations.
logger.info("# GPU blocks: %d", num_gpu_blocks)
self.worker.initialize_cache(num_gpu_blocks)
self.worker.compile_or_warm_up_model()
def execute_model(
self,
scheduler_output,
) -> ModelRunnerOutput:
output = self.worker.execute_model(scheduler_output)
return output
def check_health(self) -> None:
# GPUExecutor will always be healthy as long as
# it's running.
return
from dataclasses import dataclass
from typing import Dict, List, Optional
import torch
@dataclass
class SamplerOutput:
# [num_reqs]
sampled_token_ids: torch.Tensor
# [num_reqs, max_num_logprobs + 1]
logprob_token_ids: Optional[torch.Tensor]
# [num_reqs, max_num_logprobs + 1]
logprobs: Optional[torch.Tensor]
# TODO: Support prompt logprobs.
prompt_logprob_token_ids: Optional[torch.Tensor]
prompt_logprobs: Optional[torch.Tensor]
@dataclass
class ModelRunnerOutput:
# [num_reqs]
req_ids: List[str]
# req_id -> index
req_id_to_index: Dict[str, int]
# [num_reqs]
sampled_token_ids_cpu: torch.Tensor
# [num_reqs, max_num_logprobs + 1]
logprob_token_ids_cpu: Optional[torch.Tensor]
# [num_reqs, max_num_logprobs + 1]
logprobs_cpu: Optional[torch.Tensor]
import enum
from typing import TYPE_CHECKING, List, Optional, Union
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import RequestMetrics
if TYPE_CHECKING:
from vllm.inputs import DecoderOnlyInputs
class Request:
def __init__(
self,
request_id: str,
inputs: "DecoderOnlyInputs",
sampling_params: SamplingParams,
eos_token_id: Optional[int],
arrival_time: float,
lora_request: Optional[LoRARequest] = None,
) -> None:
self.request_id = request_id
self.inputs = inputs
self.sampling_params = sampling_params
# Because of LoRA, the eos token id can be different for each request.
self.eos_token_id = eos_token_id
self.metrics = RequestMetrics(arrival_time=arrival_time,
last_token_time=arrival_time,
first_scheduled_time=None,
first_token_time=None,
time_in_queue=None)
self.lora_request = lora_request
self.status = RequestStatus.WAITING
self.stop_reason: Union[int, str, None] = None
assert sampling_params.max_tokens is not None
self.max_tokens = sampling_params.max_tokens
self.prompt = inputs.get("prompt")
self.prompt_token_ids = inputs["prompt_token_ids"]
self.num_prompt_tokens = len(self.prompt_token_ids)
self.output_token_ids: List[int] = []
self.output_text = ""
self.num_computed_tokens = 0
@property
def num_tokens(self) -> int:
return self.num_prompt_tokens + len(self.output_token_ids)
@property
def num_output_tokens(self) -> int:
return len(self.output_token_ids)
def is_finished(self) -> bool:
return RequestStatus.is_finished(self.status)
def get_finished_reason(self) -> Union[str, None]:
return RequestStatus.get_finished_reason(self.status)
class RequestStatus(enum.IntEnum):
"""Status of a sequence."""
WAITING = 0
RUNNING = 1
PREEMPTED = 2
# Note: anything after PREEMPTED (2) will be considered
# as a finished status.
FINISHED_STOPPED = 3
FINISHED_LENGTH_CAPPED = 4
FINISHED_ABORTED = 5
FINISHED_IGNORED = 6
@staticmethod
def is_finished(status: "RequestStatus") -> bool:
return status > RequestStatus.PREEMPTED
@staticmethod
def get_finished_reason(status: "RequestStatus") -> Union[str, None]:
return _FINISHED_REASON_MAP.get(status)
# Mapping of finished statuses to their finish reasons.
# NOTE: The ignored sequences are the sequences whose prompt lengths
# are longer than the model's length cap. Therefore, the stop
# reason should also be "length" as in OpenAI API.
_FINISHED_REASON_MAP = {
RequestStatus.FINISHED_STOPPED: "stop",
RequestStatus.FINISHED_LENGTH_CAPPED: "length",
RequestStatus.FINISHED_ABORTED: "abort",
RequestStatus.FINISHED_IGNORED: "length",
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment