Commit 53076d70 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.2' into v0.8.2-ori

parents 322a0be6 9c5c81b0
...@@ -18,7 +18,7 @@ else: ...@@ -18,7 +18,7 @@ else:
def init_tokenizer_from_configs(model_config: ModelConfig, def init_tokenizer_from_configs(model_config: ModelConfig,
scheduler_config: SchedulerConfig, scheduler_config: SchedulerConfig,
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
lora_config: LoRAConfig): lora_config: Optional[LoRAConfig]):
init_kwargs = dict(tokenizer_id=model_config.tokenizer, init_kwargs = dict(tokenizer_id=model_config.tokenizer,
enable_lora=bool(lora_config), enable_lora=bool(lora_config),
max_num_seqs=scheduler_config.max_num_seqs, max_num_seqs=scheduler_config.max_num_seqs,
......
...@@ -33,7 +33,6 @@ class BaseTokenizerGroup(ABC): ...@@ -33,7 +33,6 @@ class BaseTokenizerGroup(ABC):
@abstractmethod @abstractmethod
def encode(self, def encode(self,
prompt: str, prompt: str,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
add_special_tokens: Optional[bool] = None) -> List[int]: add_special_tokens: Optional[bool] = None) -> List[int]:
"""Encode a prompt using the tokenizer group.""" """Encode a prompt using the tokenizer group."""
...@@ -43,7 +42,6 @@ class BaseTokenizerGroup(ABC): ...@@ -43,7 +42,6 @@ class BaseTokenizerGroup(ABC):
async def encode_async( async def encode_async(
self, self,
prompt: str, prompt: str,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
add_special_tokens: Optional[bool] = None) -> List[int]: add_special_tokens: Optional[bool] = None) -> List[int]:
"""Encode a prompt using the tokenizer group.""" """Encode a prompt using the tokenizer group."""
......
...@@ -113,7 +113,6 @@ class RayTokenizerGroupPool(BaseTokenizerGroup): ...@@ -113,7 +113,6 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
def encode(self, def encode(self,
prompt: str, prompt: str,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
add_special_tokens: Optional[bool] = None) -> List[int]: add_special_tokens: Optional[bool] = None) -> List[int]:
"""Encode a prompt using the tokenizer group. """Encode a prompt using the tokenizer group.
...@@ -133,8 +132,7 @@ class RayTokenizerGroupPool(BaseTokenizerGroup): ...@@ -133,8 +132,7 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
original_actor = actor original_actor = actor
try: try:
ret = ray.get( ret = ray.get(
actor.encode.remote(request_id=request_id, actor.encode.remote(prompt=prompt,
prompt=prompt,
lora_request=lora_request, lora_request=lora_request,
add_special_tokens=add_special_tokens)) add_special_tokens=add_special_tokens))
except ActorDiedError as e: except ActorDiedError as e:
...@@ -145,8 +143,7 @@ class RayTokenizerGroupPool(BaseTokenizerGroup): ...@@ -145,8 +143,7 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
actor = self._init_actor() actor = self._init_actor()
try: try:
ret = ray.get( ret = ray.get(
actor.encode.remote(request_id=request_id, actor.encode.remote(prompt=prompt,
prompt=prompt,
lora_request=lora_request, lora_request=lora_request,
add_special_tokens=add_special_tokens)) add_special_tokens=add_special_tokens))
except ActorDiedError as e: except ActorDiedError as e:
...@@ -164,7 +161,6 @@ class RayTokenizerGroupPool(BaseTokenizerGroup): ...@@ -164,7 +161,6 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
async def encode_async( async def encode_async(
self, self,
prompt: str, prompt: str,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
add_special_tokens: Optional[bool] = None) -> List[int]: add_special_tokens: Optional[bool] = None) -> List[int]:
"""Encode a prompt using the tokenizer group. """Encode a prompt using the tokenizer group.
...@@ -184,7 +180,6 @@ class RayTokenizerGroupPool(BaseTokenizerGroup): ...@@ -184,7 +180,6 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
original_actor = actor original_actor = actor
try: try:
ret = await actor.encode.remote( ret = await actor.encode.remote(
request_id=request_id,
prompt=prompt, prompt=prompt,
lora_request=lora_request, lora_request=lora_request,
add_special_tokens=add_special_tokens) add_special_tokens=add_special_tokens)
...@@ -196,7 +191,6 @@ class RayTokenizerGroupPool(BaseTokenizerGroup): ...@@ -196,7 +191,6 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
actor = self._init_actor() actor = self._init_actor()
try: try:
ret = await actor.encode.remote( ret = await actor.encode.remote(
request_id=request_id,
prompt=prompt, prompt=prompt,
lora_request=lora_request, lora_request=lora_request,
add_special_tokens=add_special_tokens) add_special_tokens=add_special_tokens)
......
...@@ -56,7 +56,6 @@ class TokenizerGroup(BaseTokenizerGroup): ...@@ -56,7 +56,6 @@ class TokenizerGroup(BaseTokenizerGroup):
def encode(self, def encode(self,
prompt: str, prompt: str,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
add_special_tokens: Optional[bool] = None) -> List[int]: add_special_tokens: Optional[bool] = None) -> List[int]:
tokenizer = self.get_lora_tokenizer(lora_request) tokenizer = self.get_lora_tokenizer(lora_request)
...@@ -69,7 +68,6 @@ class TokenizerGroup(BaseTokenizerGroup): ...@@ -69,7 +68,6 @@ class TokenizerGroup(BaseTokenizerGroup):
async def encode_async( async def encode_async(
self, self,
prompt: str, prompt: str,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
add_special_tokens: Optional[bool] = None) -> List[int]: add_special_tokens: Optional[bool] = None) -> List[int]:
tokenizer = await self.get_lora_tokenizer_async(lora_request) tokenizer = await self.get_lora_tokenizer_async(lora_request)
......
...@@ -153,6 +153,7 @@ STR_DTYPE_TO_TORCH_DTYPE = { ...@@ -153,6 +153,7 @@ STR_DTYPE_TO_TORCH_DTYPE = {
"fp8": torch.uint8, "fp8": torch.uint8,
"fp8_e4m3": torch.uint8, "fp8_e4m3": torch.uint8,
"fp8_e5m2": torch.uint8, "fp8_e5m2": torch.uint8,
"int8": torch.int8,
} }
TORCH_DTYPE_TO_NUMPY_DTYPE = { TORCH_DTYPE_TO_NUMPY_DTYPE = {
...@@ -411,6 +412,11 @@ async def merge_async_iterators( ...@@ -411,6 +412,11 @@ async def merge_async_iterators(
When it yields, it yields a tuple (i, item) where i is the index of the When it yields, it yields a tuple (i, item) where i is the index of the
iterator that yields the item. iterator that yields the item.
""" """
if len(iterators) == 1:
# Fast-path single iterator case.
async for item in iterators[0]:
yield 0, item
return
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
...@@ -2142,20 +2148,53 @@ def zmq_socket_ctx(path: str, socket_type: Any) -> Iterator[zmq.Socket]: ...@@ -2142,20 +2148,53 @@ def zmq_socket_ctx(path: str, socket_type: Any) -> Iterator[zmq.Socket]:
ctx.destroy(linger=0) ctx.destroy(linger=0)
def _check_multiproc_method(): def is_in_ray_actor():
if (cuda_is_initialized() """Check if we are in a Ray actor."""
and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"):
logger.warning("CUDA was previously initialized. We must use " try:
"the `spawn` multiprocessing start method. Setting " import ray
"VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. " return (ray.is_initialized()
"See https://docs.vllm.ai/en/latest/getting_started/" and ray.get_runtime_context().get_actor_id() is not None)
"troubleshooting.html#python-multiprocessing " except ImportError:
"for more information.") return False
def _maybe_force_spawn():
"""Check if we need to force the use of the `spawn` multiprocessing start
method.
"""
if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") == "spawn":
return
reason = None
if cuda_is_initialized():
reason = "CUDA is initialized"
elif is_in_ray_actor():
# even if we choose to spawn, we need to pass the ray address
# to the subprocess so that it knows how to connect to the ray cluster.
# env vars are inherited by subprocesses, even if we use spawn.
import ray
os.environ["RAY_ADDRESS"] = ray.get_runtime_context().gcs_address
reason = "In a Ray actor and can only be spawned"
if reason is not None:
logger.warning(
"We must use the `spawn` multiprocessing start method. "
"Overriding VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. "
"See https://docs.vllm.ai/en/latest/getting_started/"
"troubleshooting.html#python-multiprocessing "
"for more information. Reason: %s", reason)
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
def get_mp_context(): def get_mp_context():
_check_multiproc_method() """Get a multiprocessing context with a particular method (spawn or fork).
By default we follow the value of the VLLM_WORKER_MULTIPROC_METHOD to
determine the multiprocessing method (default is fork). However, under
certain conditions, we may enforce spawn and override the value of
VLLM_WORKER_MULTIPROC_METHOD.
"""
_maybe_force_spawn()
mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD
return multiprocessing.get_context(mp_method) return multiprocessing.get_context(mp_method)
...@@ -2355,3 +2394,51 @@ def swap_dict_values(obj: dict[_K, _V], key1: _K, key2: _K) -> None: ...@@ -2355,3 +2394,51 @@ def swap_dict_values(obj: dict[_K, _V], key1: _K, key2: _K) -> None:
obj[key1] = v2 obj[key1] = v2
else: else:
obj.pop(key1, None) obj.pop(key1, None)
@contextlib.contextmanager
def cprofile_context(save_file: Optional[str] = None):
"""Run a cprofile
Args:
save_file: path to save the profile result. "1" or
None will result in printing to stdout.
"""
import cProfile
prof = cProfile.Profile()
prof.enable()
try:
yield
finally:
prof.disable()
if save_file and save_file != "1":
prof.dump_stats(save_file)
else:
prof.print_stats(sort="cumtime")
def cprofile(save_file: Optional[str] = None, enabled: bool = True):
"""Decorator to profile a Python method using cProfile.
Args:
save_file: Path to save the profile result.
If "1", None, or "", results will be printed to stdout.
enabled: Set to false to turn this into a no-op
"""
def decorator(func: Callable):
@wraps(func)
def wrapper(*args, **kwargs):
if not enabled:
# If profiling is disabled, just call the function directly.
return func(*args, **kwargs)
with cprofile_context(save_file):
return func(*args, **kwargs)
return wrapper
return decorator
...@@ -6,17 +6,18 @@ from typing import TYPE_CHECKING, Any, Optional ...@@ -6,17 +6,18 @@ from typing import TYPE_CHECKING, Any, Optional
import numpy as np import numpy as np
import torch import torch
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType, AttentionMetadata, AttentionType,
is_quantized_kv_cache) is_quantized_kv_cache)
from vllm.attention.backends.utils import get_flash_attn_version
from vllm.attention.ops.triton_merge_attn_states import merge_attn_states from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
from vllm.fa_utils import get_flash_attn_version
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import cdiv from vllm.utils import cdiv
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.v1.core.scheduler_output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.gpu_model_runner import GPUModelRunner from vllm.v1.worker.gpu_model_runner import GPUModelRunner
...@@ -226,6 +227,9 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -226,6 +227,9 @@ class FlashAttentionImpl(AttentionImpl):
attn_metadata: Metadata for attention. attn_metadata: Metadata for attention.
Returns: Returns:
shape = [num_tokens, num_heads * head_size] shape = [num_tokens, num_heads * head_size]
NOTE: FP8 quantization, flash-attn expect the size of
{q,k,v}_descale to be (num_sequences, num_kv_heads).
We use torch's .expand() to avoid duplicating values
""" """
assert output is not None, "Output tensor must be provided." assert output is not None, "Output tensor must be provided."
...@@ -259,6 +263,17 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -259,6 +263,17 @@ class FlashAttentionImpl(AttentionImpl):
layer._k_scale, layer._k_scale,
layer._v_scale, layer._v_scale,
) )
descale_shape = (attn_metadata.query_start_loc.shape[0] - 1,
key.shape[1])
if self.kv_cache_dtype.startswith("fp8"):
key_cache = key_cache.view(torch.float8_e4m3fn)
value_cache = value_cache.view(torch.float8_e4m3fn)
num_tokens, num_heads, head_size = query.shape
query, _ = ops.scaled_fp8_quant(
query.reshape(
(num_tokens, num_heads * head_size)).contiguous(),
layer._q_scale)
query = query.reshape((num_tokens, num_heads, head_size))
# Compute attention and update output up to `num_actual_tokens`. # Compute attention and update output up to `num_actual_tokens`.
if not attn_metadata.use_cascade: if not attn_metadata.use_cascade:
...@@ -279,6 +294,9 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -279,6 +294,9 @@ class FlashAttentionImpl(AttentionImpl):
block_table=attn_metadata.block_table, block_table=attn_metadata.block_table,
softcap=self.logits_soft_cap, softcap=self.logits_soft_cap,
fa_version=self.vllm_flash_attn_version, fa_version=self.vllm_flash_attn_version,
q_descale=layer._q_scale.expand(descale_shape),
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
) )
return output return output
...@@ -301,6 +319,9 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -301,6 +319,9 @@ class FlashAttentionImpl(AttentionImpl):
block_table=attn_metadata.block_table, block_table=attn_metadata.block_table,
common_prefix_len=attn_metadata.common_prefix_len, common_prefix_len=attn_metadata.common_prefix_len,
fa_version=self.vllm_flash_attn_version, fa_version=self.vllm_flash_attn_version,
q_descale=layer._q_scale,
k_descale=layer._k_scale,
v_descale=layer._v_scale,
) )
return output return output
...@@ -391,6 +412,9 @@ def cascade_attention( ...@@ -391,6 +412,9 @@ def cascade_attention(
block_table: torch.Tensor, block_table: torch.Tensor,
common_prefix_len: int, common_prefix_len: int,
fa_version: int, fa_version: int,
q_descale: Optional[torch.Tensor] = None,
k_descale: Optional[torch.Tensor] = None,
v_descale: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert alibi_slopes is None, ("Cascade attention does not support ALiBi.") assert alibi_slopes is None, ("Cascade attention does not support ALiBi.")
# TODO: Support sliding window. # TODO: Support sliding window.
...@@ -402,6 +426,7 @@ def cascade_attention( ...@@ -402,6 +426,7 @@ def cascade_attention(
assert common_prefix_len % block_size == 0 assert common_prefix_len % block_size == 0
num_common_kv_blocks = common_prefix_len // block_size num_common_kv_blocks = common_prefix_len // block_size
assert num_common_kv_blocks > 0 assert num_common_kv_blocks > 0
descale_shape = (cu_prefix_query_lens.shape[0] - 1, key_cache.shape[-2])
# Process shared prefix. # Process shared prefix.
prefix_output, prefix_lse = flash_attn_varlen_func( prefix_output, prefix_lse = flash_attn_varlen_func(
...@@ -419,8 +444,16 @@ def cascade_attention( ...@@ -419,8 +444,16 @@ def cascade_attention(
softcap=logits_soft_cap, softcap=logits_soft_cap,
return_softmax_lse=True, return_softmax_lse=True,
fa_version=fa_version, fa_version=fa_version,
q_descale=q_descale.expand(descale_shape)
if q_descale is not None else None,
k_descale=k_descale.expand(descale_shape)
if k_descale is not None else None,
v_descale=v_descale.expand(descale_shape)
if v_descale is not None else None,
) )
descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2])
# Process suffix per query. # Process suffix per query.
suffix_output, suffix_lse = flash_attn_varlen_func( suffix_output, suffix_lse = flash_attn_varlen_func(
q=query, q=query,
...@@ -437,6 +470,12 @@ def cascade_attention( ...@@ -437,6 +470,12 @@ def cascade_attention(
softcap=logits_soft_cap, softcap=logits_soft_cap,
return_softmax_lse=True, return_softmax_lse=True,
fa_version=fa_version, fa_version=fa_version,
q_descale=q_descale.expand(descale_shape)
if q_descale is not None else None,
k_descale=k_descale.expand(descale_shape)
if k_descale is not None else None,
v_descale=v_descale.expand(descale_shape)
if v_descale is not None else None,
) )
# Merge prefix and suffix outputs, and store the result in output. # Merge prefix and suffix outputs, and store the result in output.
......
...@@ -195,8 +195,8 @@ from vllm import _custom_ops as ops ...@@ -195,8 +195,8 @@ from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
AttentionMetadata, AttentionMetadata,
MLAAttentionImpl) MLAAttentionImpl)
from vllm.attention.backends.utils import get_flash_attn_version
from vllm.attention.ops.triton_merge_attn_states import merge_attn_states from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
from vllm.fa_utils import get_flash_attn_version
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearBase, RowParallelLinear, LinearBase, RowParallelLinear,
...@@ -212,7 +212,7 @@ except ImportError: ...@@ -212,7 +212,7 @@ except ImportError:
from flash_attn import flash_attn_varlen_func from flash_attn import flash_attn_varlen_func
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.v1.core.scheduler_output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.gpu_model_runner import GPUModelRunner from vllm.v1.worker.gpu_model_runner import GPUModelRunner
......
...@@ -41,7 +41,7 @@ class PallasAttentionBackend(AttentionBackend): ...@@ -41,7 +41,7 @@ class PallasAttentionBackend(AttentionBackend):
num_kv_heads: int, num_kv_heads: int,
head_size: int, head_size: int,
) -> tuple[int, ...]: ) -> tuple[int, ...]:
return (num_blocks, block_size, num_kv_heads, head_size) return (num_blocks, block_size, num_kv_heads * head_size)
@staticmethod @staticmethod
def swap_blocks( def swap_blocks(
...@@ -142,8 +142,8 @@ class PallasAttentionBackendImpl(AttentionImpl): ...@@ -142,8 +142,8 @@ class PallasAttentionBackendImpl(AttentionImpl):
query: shape = [num_tokens, num_heads * head_size] query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size] key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = ([num_blocks, block_size, num_kv_heads, head_size], kv_cache = ([num_blocks, block_size, num_kv_heads * head_size],
[num_blocks, block_size, num_kv_heads, head_size]) [num_blocks, block_size, num_kv_heads * head_size])
attn_metadata: Metadata for attention. attn_metadata: Metadata for attention.
Returns: Returns:
shape = [num_tokens, num_heads * head_size] shape = [num_tokens, num_heads * head_size]
...@@ -157,8 +157,6 @@ class PallasAttentionBackendImpl(AttentionImpl): ...@@ -157,8 +157,6 @@ class PallasAttentionBackendImpl(AttentionImpl):
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0 assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
num_tokens, hidden_size = query.shape num_tokens, hidden_size = query.shape
query = query.view(num_tokens, self.num_heads, self.head_size) query = query.view(num_tokens, self.num_heads, self.head_size)
key = key.view(num_tokens, self.num_kv_heads, self.head_size)
value = value.view(num_tokens, self.num_kv_heads, self.head_size)
key_cache, value_cache = kv_cache key_cache, value_cache = kv_cache
if kv_cache[0].numel() > 0: if kv_cache[0].numel() > 0:
...@@ -192,10 +190,10 @@ def write_to_kv_cache( ...@@ -192,10 +190,10 @@ def write_to_kv_cache(
""" Write the key and values to the KV cache. """ Write the key and values to the KV cache.
Args: Args:
key: shape = [num_tokens, num_kv_heads, head_size] key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads, head_size] value: shape = [num_tokens, num_kv_heads * head_size]
k_cache = [num_blocks, block_size, num_kv_heads, head_size] k_cache = [num_blocks, block_size, num_kv_heads * head_size]
v_cache = [num_blocks, block_size, num_kv_heads, head_size] v_cache = [num_blocks, block_size, num_kv_heads * head_size]
""" """
torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True) torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True)
...@@ -203,6 +201,5 @@ def write_to_kv_cache( ...@@ -203,6 +201,5 @@ def write_to_kv_cache(
key_cache = key_cache.flatten(0, 1) key_cache = key_cache.flatten(0, 1)
value_cache = value_cache.flatten(0, 1) value_cache = value_cache.flatten(0, 1)
slot_mapping = slot_mapping.flatten()
key_cache.index_copy_(0, slot_mapping, key) key_cache.index_copy_(0, slot_mapping, key)
value_cache.index_copy_(0, slot_mapping, value) value_cache.index_copy_(0, slot_mapping, value)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""Attention layer with PagedAttention on rocm""" """Attention layer with PagedAttention and Triton prefix prefill."""
from typing import Any, Optional from typing import Any, Optional
import torch import torch
...@@ -16,7 +16,7 @@ from vllm.v1.attention.backends.flash_attn import ( ...@@ -16,7 +16,7 @@ from vllm.v1.attention.backends.flash_attn import (
logger = init_logger(__name__) logger = init_logger(__name__)
class ROCmAttentionBackend(AttentionBackend): class TritonAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True accept_output_buffer: bool = True
...@@ -26,11 +26,11 @@ class ROCmAttentionBackend(AttentionBackend): ...@@ -26,11 +26,11 @@ class ROCmAttentionBackend(AttentionBackend):
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
return "ROCM_ATTN_VLLM_V1" return "TRITON_ATTN_VLLM_V1"
@staticmethod @staticmethod
def get_impl_cls() -> type["ROCmAttentionImpl"]: def get_impl_cls() -> type["TritonAttentionImpl"]:
return ROCmAttentionImpl return TritonAttentionImpl
@staticmethod @staticmethod
def get_metadata_cls() -> type["AttentionMetadata"]: def get_metadata_cls() -> type["AttentionMetadata"]:
...@@ -56,7 +56,7 @@ class ROCmAttentionBackend(AttentionBackend): ...@@ -56,7 +56,7 @@ class ROCmAttentionBackend(AttentionBackend):
return FlashAttentionMetadataBuilder return FlashAttentionMetadataBuilder
class ROCmAttentionImpl(AttentionImpl): class TritonAttentionImpl(AttentionImpl):
def __init__( def __init__(
self, self,
...@@ -73,7 +73,7 @@ class ROCmAttentionImpl(AttentionImpl): ...@@ -73,7 +73,7 @@ class ROCmAttentionImpl(AttentionImpl):
) -> None: ) -> None:
if blocksparse_params is not None: if blocksparse_params is not None:
raise ValueError( raise ValueError(
"ROCmAttention does not support block-sparse attention.") "TritonAttention does not support block-sparse attention.")
self.num_heads = num_heads self.num_heads = num_heads
self.head_size = head_size self.head_size = head_size
self.scale = float(scale) self.scale = float(scale)
...@@ -90,17 +90,17 @@ class ROCmAttentionImpl(AttentionImpl): ...@@ -90,17 +90,17 @@ class ROCmAttentionImpl(AttentionImpl):
assert self.num_heads % self.num_kv_heads == 0 assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.num_queries_per_kv = self.num_heads // self.num_kv_heads
support_head_sizes = ROCmAttentionBackend.get_supported_head_sizes() support_head_sizes = TritonAttentionBackend.get_supported_head_sizes()
if head_size not in support_head_sizes: if head_size not in support_head_sizes:
raise ValueError( raise ValueError(
f"Head size {head_size} is not supported by ROCmAttention. " f"Head size {head_size} is not supported by TritonAttention. "
f"Supported head sizes are: {support_head_sizes}.") f"Supported head sizes are: {support_head_sizes}.")
if attn_type != AttentionType.DECODER: if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and " raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention " "encoder/decoder cross-attention "
"are not implemented for " "are not implemented for "
"ROCmAttentionImpl") "TritonAttentionImpl")
def forward( def forward(
self, self,
......
...@@ -7,8 +7,8 @@ from typing import Any, NamedTuple, Optional ...@@ -7,8 +7,8 @@ from typing import Any, NamedTuple, Optional
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.kv_cache_interface import (KVCacheConfig, KVCacheSpec, from vllm.v1.kv_cache_interface import (KVCacheConfig, KVCacheGroupSpec,
KVCacheTensor) KVCacheSpec, KVCacheTensor)
from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request from vllm.v1.request import Request
...@@ -449,7 +449,7 @@ def hash_request_tokens(block_size: int, ...@@ -449,7 +449,7 @@ def hash_request_tokens(block_size: int,
def check_enough_kv_cache_memory(vllm_config: VllmConfig, def check_enough_kv_cache_memory(vllm_config: VllmConfig,
kv_cache_spec: KVCacheSpec, kv_cache_spec: dict[str, KVCacheSpec],
available_memory: int): available_memory: int):
""" """
Checks whether `available_memory` is enough for the KV cache to hold at Checks whether `available_memory` is enough for the KV cache to hold at
...@@ -457,7 +457,7 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig, ...@@ -457,7 +457,7 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig,
Args: Args:
vllm_config: The global VllmConfig vllm_config: The global VllmConfig
kv_cache_spec: The kv cache spec of the model kv_cache_spec: The kv cache spec of each attention layer in the model
available_memory: Memory available for KV cache in bytes. available_memory: Memory available for KV cache in bytes.
Raises: Raises:
...@@ -484,12 +484,43 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig, ...@@ -484,12 +484,43 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig,
f"`max_model_len` when initializing the engine.") f"`max_model_len` when initializing the engine.")
def is_kv_cache_type_uniform(kv_cache_spec: KVCacheSpec) -> bool: def create_kv_cache_group_specs(
kv_cache_spec: dict[str, KVCacheSpec],
grouped_layer_names: list[list[str]]) -> list[KVCacheGroupSpec]:
"""
Create KVCacheGroupSpec object for each kv cache group layer.
The layers in the same group should share the same
KVCacheSpec.
Args:
kv_cache_spec:
A mapping from each layer name to its corresponding KVCacheSpec.
grouped_layer_names:
A list of kv cache groups, where each element is a list of layer
names that belong to the same group and should share the same
KVCacheSpec.
Returns:
A list of KVCacheGroupSpec objects, one for each group.
"""
kv_cache_groups = []
for layer_names_one_group in grouped_layer_names:
layer_spec = kv_cache_spec[layer_names_one_group[0]]
assert all(
kv_cache_spec[layer_name] == layer_spec
for layer_name in layer_names_one_group[1:]), (
"All layers in the same KV cache group must share the same "
"KVCacheSpec.")
kv_cache_groups.append(
KVCacheGroupSpec(layer_names_one_group, layer_spec))
return kv_cache_groups
def is_kv_cache_type_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
""" """
Whether all layers in the given KVCacheSpec have the same type of KV cache. Whether all layers in the given KVCacheSpec have the same type of KV cache.
Args: Args:
kv_cache_spec: The KVCacheSpec of the model kv_cache_spec: The kv cache spec of each attention layer in the model
Returns: Returns:
True if all layers have the same type, False otherwise. True if all layers have the same type, False otherwise.
...@@ -500,18 +531,16 @@ def is_kv_cache_type_uniform(kv_cache_spec: KVCacheSpec) -> bool: ...@@ -500,18 +531,16 @@ def is_kv_cache_type_uniform(kv_cache_spec: KVCacheSpec) -> bool:
def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
kv_cache_spec: KVCacheSpec, kv_cache_spec: dict[str, KVCacheSpec],
available_memory: int, available_memory: int) -> KVCacheConfig:
num_layers: int) -> KVCacheConfig:
""" """
Generates the KV cache configuration for a model with one type of KV cache. Generates the KV cache configuration for a model with one type of KV cache.
Divide the available memory equally among all layers. Divide the available memory equally among all layers.
Args: Args:
vllm_config: The global VllmConfig vllm_config: The global VllmConfig
kv_cache_spec: The kv cache spec of the model kv_cache_spec: The kv cache spec of each attention layer in the model
available_memory: Memory available for KV cache in bytes. available_memory: Memory available for KV cache in bytes.
num_layers: The number of layers in the model.
Returns: Returns:
The generated KVCacheConfig The generated KVCacheConfig
...@@ -521,7 +550,7 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, ...@@ -521,7 +550,7 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
assert len(page_sizes) == 1 assert len(page_sizes) == 1
page_size = page_sizes.pop() page_size = page_sizes.pop()
num_blocks = int(available_memory // page_size // num_layers) num_blocks = int(available_memory // page_size // len(kv_cache_spec))
num_blocks = max(num_blocks, 0) num_blocks = max(num_blocks, 0)
if vllm_config.cache_config.num_gpu_blocks_override is not None: if vllm_config.cache_config.num_gpu_blocks_override is not None:
...@@ -541,6 +570,9 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, ...@@ -541,6 +570,9 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
max_model_len_str, max_concurrency) max_model_len_str, max_concurrency)
per_layer_size = page_size * num_blocks per_layer_size = page_size * num_blocks
# All layers have the same KV cache spec, so we create one kv cache group
# for all layers.
grouped_layer_names = [list(kv_cache_spec.keys())]
kv_cache_config = KVCacheConfig( kv_cache_config = KVCacheConfig(
num_blocks=num_blocks, num_blocks=num_blocks,
...@@ -548,41 +580,69 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, ...@@ -548,41 +580,69 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
layer_name: KVCacheTensor(size=per_layer_size) layer_name: KVCacheTensor(size=per_layer_size)
for layer_name in kv_cache_spec for layer_name in kv_cache_spec
}, },
groups=[[layer_name for layer_name in kv_cache_spec]], kv_cache_groups=create_kv_cache_group_specs(kv_cache_spec,
kv_cache_spec=kv_cache_spec) grouped_layer_names),
)
return kv_cache_config return kv_cache_config
def get_kv_cache_configs(vllm_config: VllmConfig, def get_kv_cache_config(vllm_config: VllmConfig,
kv_cache_specs: list[KVCacheSpec], kv_cache_spec: dict[str, KVCacheSpec],
available_memory: int) -> list[KVCacheConfig]: available_memory: int) -> KVCacheConfig:
""" """
Generates the KV cache configuration for a model Generates the KV cache configuration for a model
TODO: support hybrid models with more than one type of KV cache. TODO: support hybrid models with more than one type of KV cache.
Args: Args:
vllm_config: The global VllmConfig vllm_config: The global VllmConfig
kv_cache_specs: The kv cache specs of the model kv_cache_spec: The kv cache spec of each attention layer in the model
available_memory: Memory available for KV cache in bytes. available_memory: Memory available for KV cache in bytes.
Returns: Returns:
The generated KVCacheConfigs The generated KVCacheConfigs
""" """
# Use the max number of layers to conservatively determine check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory)
# the number of blocks. if is_kv_cache_type_uniform(kv_cache_spec):
num_layers = max(len(kv_cache_spec) for kv_cache_spec in kv_cache_specs) # KV cache of all layers are the same, which is true for
kv_cache_configs = [] # most models. Allocate the same amount of memory for
for kv_cache_spec in kv_cache_specs: # each layer.
check_enough_kv_cache_memory(vllm_config, kv_cache_spec, return _get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec,
available_memory) available_memory)
if is_kv_cache_type_uniform(kv_cache_spec):
# KV cache of all layers are the same, which is true for raise NotImplementedError
# most models. Allocate the same amount of memory for
# each layer.
kv_cache_configs.append( def unify_kv_cache_configs(kv_cache_configs: list[KVCacheConfig]):
_get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec, """
available_memory, Make the KV cache configurations for each worker consistent, so that all
num_layers)) workers can be controlled by the same KVCacheManager.
else: This function verifies that the layer group of each worker are the same,
raise NotImplementedError and changes the num_blocks of each worker to the smallest among all workers.
Args:
kv_cache_configs: The KV cache configurations for each worker. Will be
in-place modified to make them consistent.
"""
# Sort the kv cache groups by the type_id of their KV cache spec.
# This can avoid the inconsistency caused by the order of groups.
for kv_cache_config in kv_cache_configs:
kv_cache_config.kv_cache_groups.sort(
key=lambda x: x.kv_cache_spec.type_id)
# Verify that the groups of each rank are the same.
for kv_cache_config in kv_cache_configs[1:]:
for group_rank_0, group_rank_i in zip(
kv_cache_configs[0].kv_cache_groups,
kv_cache_config.kv_cache_groups):
assert group_rank_0.kv_cache_spec == group_rank_i.kv_cache_spec
# Change the num_blocks of each rank to the smallest among all ranks. We
# do not need to shrink the tensor size because it is valid to only use the
# first `num_blocks` blocks of the tensor.
min_num_blocks = min(kv_cache_config.num_blocks
for kv_cache_config in kv_cache_configs)
for kv_cache_config in kv_cache_configs:
kv_cache_config.num_blocks = min_num_blocks
return kv_cache_configs return kv_cache_configs
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import TYPE_CHECKING, Optional, Union
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.engine import EngineCoreOutputs
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
class SchedulerInterface(ABC):
@abstractmethod
def schedule(self) -> "SchedulerOutput":
"""Schedule the requests to process in this scheduling step.
The scheduling decision is made at the iteration level. Each scheduling
step corresponds to a single forward pass of the model. Therefore, this
method is called repeatedly by a busy loop in the engine.
Essentially, the scheduler produces a dictionary of {req_id: num_tokens}
that specifies how many tokens to process for each request in this
scheduling step. For example, num_tokens can be as large as the number
of prompt tokens for new requests, or it can be 1 for the requests that
are auto-regressively generating new tokens one by one. Otherwise, it
can be somewhere in between in case of chunked prefills, prefix caching,
speculative decoding, etc.
Additionally, the scheduler also returns useful data about each request
or the batch as a whole. The model runner will use this information in
preparing inputs to the model.
Returns:
A SchedulerOutput object containing information about the scheduled
requests.
"""
raise NotImplementedError
@abstractmethod
def update_from_output(
self,
scheduler_output: "SchedulerOutput",
model_runner_output: "ModelRunnerOutput",
) -> "EngineCoreOutputs":
"""Update the scheduler state based on the model runner output.
This method is called after the model runner has processed the scheduled
requests. The model runner output includes generated token ids, draft
token ids for next step, etc. The scheduler uses this information to
update its states, checks the finished requests, and returns the output
for each request.
Returns:
A EngineCoreOutputs object containing the outputs for each request.
"""
raise NotImplementedError
@abstractmethod
def add_request(self, request: "Request") -> None:
"""Add a new request to the scheduler's internal queue.
Args:
request: The new request being added.
"""
raise NotImplementedError
@abstractmethod
def finish_requests(
self,
request_ids: Union[str, Iterable[str]],
finished_status: "RequestStatus",
) -> None:
"""Finish the requests in the scheduler's internal queue. If the request
is not in the queue, this method will do nothing.
This method is called in two cases:
1. When the request is aborted by the client.
2. When the frontend process detects a stop string of the request after
de-tokenizing its generated tokens.
Args:
request_ids: A single or a list of request IDs.
finished_status: The finished status of the given requests.
"""
raise NotImplementedError
@abstractmethod
def get_num_unfinished_requests(self) -> int:
"""Number of unfinished requests in the scheduler's internal queue."""
raise NotImplementedError
def has_unfinished_requests(self) -> bool:
"""Returns True if there are unfinished requests in the scheduler's
internal queue."""
return self.get_num_unfinished_requests() > 0
@abstractmethod
def has_finished_requests(self) -> bool:
"""Returns True if there are finished requests that need to be cleared.
NOTE: This is different from `not self.has_unfinished_requests()`.
The scheduler maintains an internal list of the requests finished in the
previous step. This list is returned from the next call to schedule(),
to be sent to the model runner in the next step to clear cached states
for these finished requests.
This method checks if this internal list of finished requests is
non-empty. This information is useful for DP attention.
"""
raise NotImplementedError
def has_requests(self) -> bool:
"""Returns True if there are unfinished requests, or finished requests
not yet returned in SchedulerOutputs."""
return self.has_unfinished_requests() or self.has_finished_requests()
@abstractmethod
def get_num_unscheduled_requests(self) -> int:
"""Number of requests that are not being processed by the executor."""
raise NotImplementedError
@abstractmethod
def reset_prefix_cache(self) -> bool:
"""Reset the prefix cache for KV cache.
This is particularly required when the model weights are live-updated.
"""
raise NotImplementedError
@abstractmethod
def make_stats(self) -> Optional["SchedulerStats"]:
"""Make a SchedulerStats object for logging.
The SchedulerStats object is created for every scheduling step.
"""
raise NotImplementedError
...@@ -13,8 +13,10 @@ from vllm.logger import init_logger ...@@ -13,8 +13,10 @@ from vllm.logger import init_logger
from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
compute_encoder_budget) compute_encoder_budget)
from vllm.v1.core.kv_cache_manager import KVCacheManager from vllm.v1.core.kv_cache_manager import KVCacheManager
from vllm.v1.core.scheduler_output import (CachedRequestData, NewRequestData, from vllm.v1.core.sched.interface import SchedulerInterface
SchedulerOutput) from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
SchedulerOutput)
from vllm.v1.core.sched.utils import check_stop
from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput, from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput,
EngineCoreOutputs) EngineCoreOutputs)
from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.metrics.stats import SchedulerStats
...@@ -25,7 +27,7 @@ from vllm.v1.structured_output import StructuredOutputManager ...@@ -25,7 +27,7 @@ from vllm.v1.structured_output import StructuredOutputManager
logger = init_logger(__name__) logger = init_logger(__name__)
class Scheduler: class Scheduler(SchedulerInterface):
def __init__( def __init__(
self, self,
...@@ -602,7 +604,7 @@ class Scheduler: ...@@ -602,7 +604,7 @@ class Scheduler:
# Check for stop and update request state. # Check for stop and update request state.
# This must be called before we make the EngineCoreOutput. # This must be called before we make the EngineCoreOutput.
stopped = self._check_stop(request) stopped = check_stop(request, self.max_model_len)
if stopped: if stopped:
self._free_request(request) self._free_request(request)
break break
...@@ -648,25 +650,6 @@ class Scheduler: ...@@ -648,25 +650,6 @@ class Scheduler:
scheduler_stats=self.make_stats(), scheduler_stats=self.make_stats(),
) )
def _check_stop(self, request: Request) -> bool:
if (request.num_tokens >= self.max_model_len
or request.num_output_tokens >= request.max_tokens):
request.status = RequestStatus.FINISHED_LENGTH_CAPPED
return True
sampling_params = request.sampling_params
last_token_id = request.output_token_ids[-1]
if (not sampling_params.ignore_eos
and last_token_id == request.eos_token_id):
request.status = RequestStatus.FINISHED_STOPPED
return True
if last_token_id in (sampling_params.stop_token_ids or ()):
request.status = RequestStatus.FINISHED_STOPPED
request.stop_reason = last_token_id
return True
return False
def add_request(self, request: Request) -> None: def add_request(self, request: Request) -> None:
self.waiting.append(request) self.waiting.append(request)
self.requests[request.request_id] = request self.requests[request.request_id] = request
...@@ -715,17 +698,9 @@ class Scheduler: ...@@ -715,17 +698,9 @@ class Scheduler:
def get_num_unfinished_requests(self) -> int: def get_num_unfinished_requests(self) -> int:
return len(self.waiting) + len(self.running) return len(self.waiting) + len(self.running)
def has_unfinished_requests(self) -> bool:
return self.get_num_unfinished_requests() > 0
def has_finished_requests(self) -> bool: def has_finished_requests(self) -> bool:
return len(self.finished_req_ids) > 0 return len(self.finished_req_ids) > 0
def has_requests(self):
"""Returns True if there are unfinished requests, or finished requests
not yet returned in SchedulerOutputs."""
return self.has_unfinished_requests() or self.has_finished_requests()
def get_num_unscheduled_requests(self) -> int: def get_num_unscheduled_requests(self) -> int:
"""Number of requests that are not being processed by the executor.""" """Number of requests that are not being processed by the executor."""
return self.get_num_unfinished_requests() - len(self.scheduled_req_ids) return self.get_num_unfinished_requests() - len(self.scheduled_req_ids)
......
# SPDX-License-Identifier: Apache-2.0
from vllm.v1.request import Request, RequestStatus
def check_stop(request: Request, max_model_len: int) -> bool:
if (request.num_tokens >= max_model_len
or request.num_output_tokens >= request.max_tokens):
request.status = RequestStatus.FINISHED_LENGTH_CAPPED
return True
sampling_params = request.sampling_params
last_token_id = request.output_token_ids[-1]
if (not sampling_params.ignore_eos
and last_token_id == request.eos_token_id):
request.status = RequestStatus.FINISHED_STOPPED
return True
if last_token_id in (sampling_params.stop_token_ids or ()):
request.status = RequestStatus.FINISHED_STOPPED
request.stop_reason = last_token_id
return True
return False
...@@ -4,6 +4,7 @@ import asyncio ...@@ -4,6 +4,7 @@ import asyncio
import logging import logging
import os import os
from collections.abc import AsyncGenerator, Mapping from collections.abc import AsyncGenerator, Mapping
from copy import copy
from typing import Optional, Union from typing import Optional, Union
import numpy as np import numpy as np
...@@ -24,7 +25,8 @@ from vllm.sampling_params import RequestOutputKind, SamplingParams ...@@ -24,7 +25,8 @@ from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import cdiv, kill_process_tree from vllm.utils import Device, cdiv, kill_process_tree
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core_client import EngineCoreClient from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.output_processor import OutputProcessor from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.engine.parallel_sampling import ParentRequest from vllm.v1.engine.parallel_sampling import ParentRequest
...@@ -177,33 +179,44 @@ class AsyncLLM(EngineClient): ...@@ -177,33 +179,44 @@ class AsyncLLM(EngineClient):
) -> asyncio.Queue[RequestOutput]: ) -> asyncio.Queue[RequestOutput]:
"""Add new request to the AsyncLLM.""" """Add new request to the AsyncLLM."""
# 1) Create a new output queue for the request. # Create a new output queue for the request.
queue: asyncio.Queue[RequestOutput] = asyncio.Queue() queue: asyncio.Queue[RequestOutput] = asyncio.Queue()
# 2) Fan out child requests (for n>1) # Convert Input --> Request.
parent_req = ParentRequest.from_params(request_id, params) request = self.processor.process_inputs(request_id, prompt, params,
arrival_time, lora_request,
trace_headers,
prompt_adapter_request,
priority)
n = params.n if isinstance(params, SamplingParams) else 1 n = params.n if isinstance(params, SamplingParams) else 1
for idx in range(n):
if parent_req is not None:
request_id, params = parent_req.get_child_info(idx)
# 3) Convert Input --> Request. if n == 1:
request = self.processor.process_inputs(request_id, prompt, params, await self._add_request(request, None, 0, queue)
arrival_time, lora_request, return queue
trace_headers,
prompt_adapter_request,
priority)
# 4) Add the request to OutputProcessor (this process). # Fan out child requests (for n>1).
self.output_processor.add_request(request, parent_req, idx, queue) parent_request = ParentRequest(request_id, params)
for idx in range(n):
request_id, params = parent_request.get_child_info(idx)
child_request = request if idx == n - 1 else copy(request)
child_request.request_id = request_id
child_request.sampling_params = params
await self._add_request(child_request, parent_request, idx, queue)
return queue
# 5) Add the EngineCoreRequest to EngineCore (separate process). async def _add_request(self, request: EngineCoreRequest,
await self.engine_core.add_request_async(request) parent_req: Optional[ParentRequest], index: int,
queue: asyncio.Queue[RequestOutput]):
if self.log_requests: # Add the request to OutputProcessor (this process).
logger.info("Added request %s.", request_id) self.output_processor.add_request(request, parent_req, index, queue)
return queue # Add the EngineCoreRequest to EngineCore (separate process).
await self.engine_core.add_request_async(request)
if self.log_requests:
logger.info("Added request %s.", request.request_id)
# TODO: we should support multiple prompts in one call, as you # TODO: we should support multiple prompts in one call, as you
# can do with LLM.generate. So that for multi-prompt completion # can do with LLM.generate. So that for multi-prompt completion
...@@ -398,7 +411,10 @@ class AsyncLLM(EngineClient): ...@@ -398,7 +411,10 @@ class AsyncLLM(EngineClient):
async def stop_profile(self) -> None: async def stop_profile(self) -> None:
await self.engine_core.profile_async(False) await self.engine_core.profile_async(False)
async def reset_prefix_cache(self) -> None: async def reset_prefix_cache(self,
device: Optional[Device] = None) -> None:
if device == Device.CPU:
raise ValueError("Not supported on CPU.")
await self.engine_core.reset_prefix_cache_async() await self.engine_core.reset_prefix_cache_async()
async def sleep(self, level: int = 1) -> None: async def sleep(self, level: int = 1) -> None:
......
...@@ -21,9 +21,10 @@ from vllm.transformers_utils.config import ( ...@@ -21,9 +21,10 @@ from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value) maybe_register_config_serialize_by_value)
from vllm.utils import (get_exception_traceback, resolve_obj_by_qualname, from vllm.utils import (get_exception_traceback, resolve_obj_by_qualname,
zmq_socket_ctx) zmq_socket_ctx)
from vllm.v1.core.kv_cache_utils import get_kv_cache_configs from vllm.v1.core.kv_cache_utils import (get_kv_cache_config,
from vllm.v1.core.scheduler import Scheduler as V1Scheduler unify_kv_cache_configs)
from vllm.v1.core.scheduler import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType, UtilityOutput) EngineCoreRequestType, UtilityOutput)
from vllm.v1.engine.mm_input_cache import MMInputCacheServer from vllm.v1.engine.mm_input_cache import MMInputCacheServer
...@@ -120,15 +121,27 @@ class EngineCore: ...@@ -120,15 +121,27 @@ class EngineCore:
# memory can be allocated for kv cache. # memory can be allocated for kv cache.
available_gpu_memory = self.model_executor.determine_available_memory() available_gpu_memory = self.model_executor.determine_available_memory()
assert len(kv_cache_specs) == len(available_gpu_memory)
# Get the kv cache tensor size # Get the kv cache tensor size
kv_cache_configs = get_kv_cache_configs(vllm_config, kv_cache_specs, kv_cache_configs = [
available_gpu_memory) get_kv_cache_config(vllm_config, kv_cache_spec_one_worker,
num_gpu_blocks_set = set(config.num_blocks available_gpu_memory_one_worker)
for config in kv_cache_configs) for kv_cache_spec_one_worker, available_gpu_memory_one_worker in
assert len(num_gpu_blocks_set) == 1, ( zip(kv_cache_specs, available_gpu_memory)
f"num_gpu_blocks need to be the same across workers, " ]
f"but they are different: {num_gpu_blocks_set}")
num_gpu_blocks = num_gpu_blocks_set.pop() # Since we use a shared centralized controller, we need the
# `kv_cache_config` to be consistent across all workers to make sure
# all the memory operators can be applied to all workers.
unify_kv_cache_configs(kv_cache_configs)
# All workers have the same kv_cache_config except layer names, so use
# an arbitrary one to get the number of blocks.
assert all([
cfg.num_blocks == kv_cache_configs[0].num_blocks
for cfg in kv_cache_configs
])
num_gpu_blocks = kv_cache_configs[0].num_blocks
num_cpu_blocks = 0 num_cpu_blocks = 0
# Initialize kv cache and warmup the execution # Initialize kv cache and warmup the execution
...@@ -179,16 +192,6 @@ class EngineCore: ...@@ -179,16 +192,6 @@ class EngineCore:
scheduler_stats=self.scheduler.make_stats(), scheduler_stats=self.scheduler.make_stats(),
) )
scheduler_output = self.scheduler.schedule() scheduler_output = self.scheduler.schedule()
# This case may occur when the only unfinished requests are
# structured output requests where the grammar has not finished
# compiling yet, so there's nothing to run.
if scheduler_output.total_num_scheduled_tokens == 0:
return EngineCoreOutputs(
outputs=[],
scheduler_stats=self.scheduler.make_stats(),
)
output = self.model_executor.execute_model(scheduler_output) output = self.model_executor.execute_model(scheduler_output)
engine_core_outputs = self.scheduler.update_from_output( engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, output) # type: ignore scheduler_output, output) # type: ignore
......
...@@ -212,9 +212,9 @@ class BackgroundResources: ...@@ -212,9 +212,9 @@ class BackgroundResources:
"""Used as a finalizer for clean shutdown, avoiding """Used as a finalizer for clean shutdown, avoiding
circular reference back to the client object.""" circular reference back to the client object."""
ctx: Union[zmq.Context] = None ctx: zmq.Context
output_socket: Union[zmq.Socket, zmq.asyncio.Socket] = None output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
input_socket: Union[zmq.Socket, zmq.asyncio.Socket] = None input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
proc_handle: Optional[BackgroundProcHandle] = None proc_handle: Optional[BackgroundProcHandle] = None
shutdown_path: Optional[str] = None shutdown_path: Optional[str] = None
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from collections.abc import Mapping from collections.abc import Mapping
from copy import copy
from typing import Optional, Union from typing import Optional, Union
from typing_extensions import TypeVar from typing_extensions import TypeVar
...@@ -20,6 +21,7 @@ from vllm.sampling_params import SamplingParams ...@@ -20,6 +21,7 @@ from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import ( from vllm.transformers_utils.tokenizer_group import (
BaseTokenizerGroup, init_tokenizer_from_configs) BaseTokenizerGroup, init_tokenizer_from_configs)
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import Device
from vllm.v1.engine.core_client import EngineCoreClient from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.output_processor import OutputProcessor from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.engine.parallel_sampling import ParentRequest from vllm.v1.engine.parallel_sampling import ParentRequest
...@@ -178,25 +180,34 @@ class LLMEngine: ...@@ -178,25 +180,34 @@ class LLMEngine:
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0, priority: int = 0,
) -> None: ) -> None:
# 1) Fan out child requests (for n>1) # Process raw inputs into the request.
parent_req = ParentRequest.from_params(request_id, params) request = self.processor.process_inputs(request_id, prompt, params,
arrival_time, lora_request,
trace_headers,
prompt_adapter_request,
priority)
n = params.n if isinstance(params, SamplingParams) else 1 n = params.n if isinstance(params, SamplingParams) else 1
for idx in range(n):
if parent_req is not None:
request_id, params = parent_req.get_child_info(idx)
# 2) Process raw inputs into the request. if n == 1:
request = self.processor.process_inputs(request_id, prompt, params, # Make a new RequestState and queue.
arrival_time, lora_request, self.output_processor.add_request(request, None, 0)
trace_headers, # Add the request to EngineCore.
prompt_adapter_request, self.engine_core.add_request(request)
priority) return
# 3) Make a new RequestState and queue. # Fan out child requests (for n>1).
self.output_processor.add_request(request, parent_req, idx) parent_req = ParentRequest(request_id, params)
for idx in range(n):
request_id, params = parent_req.get_child_info(idx)
child_request = request if idx == n - 1 else copy(request)
child_request.request_id = request_id
child_request.sampling_params = params
# 3) Add the request to EngineCore. # Make a new RequestState and queue.
self.engine_core.add_request(request) self.output_processor.add_request(child_request, parent_req, idx)
# Add the request to EngineCore.
self.engine_core.add_request(child_request)
def step(self) -> list[RequestOutput]: def step(self) -> list[RequestOutput]:
...@@ -226,7 +237,7 @@ class LLMEngine: ...@@ -226,7 +237,7 @@ class LLMEngine:
def stop_profile(self): def stop_profile(self):
self.engine_core.profile(False) self.engine_core.profile(False)
def reset_prefix_cache(self): def reset_prefix_cache(self, device: Optional[Device] = None):
self.engine_core.reset_prefix_cache() self.engine_core.reset_prefix_cache()
def sleep(self, level: int = 1): def sleep(self, level: int = 1):
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from copy import copy from copy import copy
from typing import Optional, Union from typing import Optional
from vllm.outputs import CompletionOutput from vllm.outputs import CompletionOutput
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.v1.metrics.stats import IterationStats from vllm.v1.metrics.stats import IterationStats
...@@ -43,16 +42,6 @@ class ParentRequest: ...@@ -43,16 +42,6 @@ class ParentRequest:
self.max_num_generation_tokens = 0 self.max_num_generation_tokens = 0
self.cached_child_sampling_params = None self.cached_child_sampling_params = None
@classmethod
def from_params(
cls,
request_id: str,
params: Union[SamplingParams, PoolingParams],
) -> Optional['ParentRequest']:
if not isinstance(params, SamplingParams) or params.n == 1:
return None
return cls(request_id, params)
def _get_child_sampling_params( def _get_child_sampling_params(
self, self,
index: int, index: int,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment