Commit e661d594 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.5.4' into v0.5.4-dtk24.04.1

parents 6b16ea2e 4db5176d
...@@ -244,9 +244,15 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -244,9 +244,15 @@ class ROCmFlashAttentionImpl(AttentionImpl):
sliding_window: Optional[int], sliding_window: Optional[int],
kv_cache_dtype: str, kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None, blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
) -> None: ) -> None:
assert blocksparse_params is None, ValueError( if blocksparse_params is not None:
"ROCFlashAttention does not support blocksparse attention.") raise ValueError(
"ROCmFlashAttention does not support blocksparse attention.")
if logits_soft_cap is not None:
raise ValueError(
"ROCmFlashAttention does not support attention logits soft "
"capping.")
self.num_heads = num_heads self.num_heads = num_heads
self.head_size = head_size self.head_size = head_size
self.scale = float(scale) self.scale = float(scale)
......
...@@ -109,9 +109,13 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): ...@@ -109,9 +109,13 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
sliding_window: Optional[int], sliding_window: Optional[int],
kv_cache_dtype: str, kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None, blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
) -> None: ) -> None:
assert blocksparse_params is None, ValueError( if blocksparse_params is not None:
"Torch SPDA does not support block-sparse attention.") raise ValueError(
"Torch SPDA does not support block-sparse attention.")
if logits_soft_cap is not None:
raise ValueError("Torch SPDA does not support logits soft cap.")
self.num_heads = num_heads self.num_heads = num_heads
self.head_size = head_size self.head_size = head_size
self.scale = float(scale) self.scale = float(scale)
......
...@@ -149,6 +149,15 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]): ...@@ -149,6 +149,15 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
def build(self, seq_lens: List[int], query_lens: List[int], def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int): cuda_graph_pad_size: int, batch_size: int):
"""Build attention metadata with on-device tensors.
Args:
seq_lens: The maybe padded sequence lengths of the input sequences.
query_lens: The query lengths of the input sequences.
cuda_graph_pad_size: The padding size for cuda graph.
-1 if cuda graph is not used.
batch_size: The maybe padded batch size.
"""
for inter_data in self.input_builder.inter_data_list: for inter_data in self.input_builder.inter_data_list:
self._add_seq_group(inter_data, self._add_seq_group(inter_data,
self.input_builder.chunked_prefill_enabled) self.input_builder.chunked_prefill_enabled)
...@@ -156,15 +165,6 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]): ...@@ -156,15 +165,6 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
device = self.runner.device device = self.runner.device
use_captured_graph = cuda_graph_pad_size != -1 use_captured_graph = cuda_graph_pad_size != -1
logits_soft_cap = getattr(self.runner.model_config.hf_config,
"attn_logit_softcapping", None)
if logits_soft_cap is not None:
raise ValueError(
"Please use Flashinfer backend for models with logits_soft_cap "
"(i.e., Gemma-2). Otherwise, the output might be wrong. "
"Set Flashinfer backend by "
"export VLLM_ATTENTION_BACKEND=FLASHINFER.")
max_query_len = max(query_lens) max_query_len = max(query_lens)
max_prefill_seq_len = max(self.prefill_seq_lens, default=0) max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
max_decode_seq_len = max(self.curr_seq_lens, default=0) max_decode_seq_len = max(self.curr_seq_lens, default=0)
...@@ -173,7 +173,7 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]): ...@@ -173,7 +173,7 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
if use_captured_graph: if use_captured_graph:
self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
self.block_tables.extend([] * cuda_graph_pad_size) self.block_tables.extend([] * cuda_graph_pad_size)
num_decode_tokens = batch_size + cuda_graph_pad_size num_decode_tokens = batch_size
# The shape of graph_block_tables is # The shape of graph_block_tables is
# [max batch size, max context len // block size]. # [max batch size, max context len // block size].
......
...@@ -408,9 +408,14 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]): ...@@ -408,9 +408,14 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
sliding_window: Optional[int], sliding_window: Optional[int],
kv_cache_dtype: str, kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None, blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
) -> None: ) -> None:
assert blocksparse_params is None, ValueError( if blocksparse_params is not None:
"XFormer does not support block-sparse attention.") raise ValueError(
"XFormers does not support block-sparse attention.")
if logits_soft_cap is not None:
raise ValueError(
"XFormers does not support attention logits soft capping.")
self.num_heads = num_heads self.num_heads = num_heads
self.head_size = head_size self.head_size = head_size
self.scale = float(scale) self.scale = float(scale)
......
...@@ -34,6 +34,7 @@ class Attention(nn.Module): ...@@ -34,6 +34,7 @@ class Attention(nn.Module):
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
blocksparse_params: Optional[Dict[str, Any]] = None, blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -82,7 +83,7 @@ class Attention(nn.Module): ...@@ -82,7 +83,7 @@ class Attention(nn.Module):
impl_cls = attn_backend.get_impl_cls() impl_cls = attn_backend.get_impl_cls()
self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype, alibi_slopes, sliding_window, kv_cache_dtype,
blocksparse_params) blocksparse_params, logits_soft_cap)
def forward( def forward(
self, self,
......
...@@ -4,7 +4,10 @@ from typing import List, Optional, Tuple ...@@ -4,7 +4,10 @@ from typing import List, Optional, Tuple
import torch import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.attention.ops.prefix_prefill import context_attention_fwd from vllm.triton_utils import HAS_TRITON
if HAS_TRITON:
from vllm.attention.ops.prefix_prefill import context_attention_fwd
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE = 512 _PARTITION_SIZE = 512
...@@ -31,7 +34,7 @@ class PagedAttention: ...@@ -31,7 +34,7 @@ class PagedAttention:
@staticmethod @staticmethod
def get_supported_head_sizes() -> List[int]: def get_supported_head_sizes() -> List[int]:
return [64, 80, 96, 112, 128, 192, 256] return [64, 80, 96, 112, 120, 128, 192, 256]
@staticmethod @staticmethod
def get_kv_cache_shape( def get_kv_cache_shape(
......
...@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, ClassVar, List, Optional, Tuple, Type, Union ...@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, ClassVar, List, Optional, Tuple, Type, Union
import torch import torch
from transformers import PretrainedConfig from transformers import PretrainedConfig
import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models import ModelRegistry
...@@ -31,6 +32,7 @@ _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768 ...@@ -31,6 +32,7 @@ _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
_PP_SUPPORTED_MODELS = [ _PP_SUPPORTED_MODELS = [
"AquilaModel", "AquilaModel",
"AquilaForCausalLM", "AquilaForCausalLM",
"DeepseekV2ForCausalLM",
"InternLMForCausalLM", "InternLMForCausalLM",
"LlamaForCausalLM", "LlamaForCausalLM",
"LLaMAForCausalLM", "LLaMAForCausalLM",
...@@ -38,6 +40,10 @@ _PP_SUPPORTED_MODELS = [ ...@@ -38,6 +40,10 @@ _PP_SUPPORTED_MODELS = [
"Phi3ForCausalLM", "Phi3ForCausalLM",
"GPT2LMHeadModel", "GPT2LMHeadModel",
"MixtralForCausalLM", "MixtralForCausalLM",
"NemotronForCausalLM",
"Qwen2ForCausalLM",
"Qwen2MoeForCausalLM",
"QWenLMHeadModel",
] ]
...@@ -195,13 +201,17 @@ class ModelConfig: ...@@ -195,13 +201,17 @@ class ModelConfig:
def _parse_quant_hf_config(self): def _parse_quant_hf_config(self):
quant_cfg = getattr(self.hf_config, "quantization_config", None) quant_cfg = getattr(self.hf_config, "quantization_config", None)
if quant_cfg is None: if quant_cfg is None:
# compress-tensors uses a "compression_config" key # compressed-tensors uses a "compression_config" key
quant_cfg = getattr(self.hf_config, "compression_config", None) quant_cfg = getattr(self.hf_config, "compression_config", None)
return quant_cfg return quant_cfg
def _verify_quantization(self) -> None: def _verify_quantization(self) -> None:
supported_quantization = [*QUANTIZATION_METHODS] supported_quantization = [*QUANTIZATION_METHODS]
rocm_supported_quantization = ["gptq", "squeezellm","awq"] rocm_supported_quantization = ["gptq", "squeezellm","awq"]
optimized_quantization_methods = [
"fp8", "marlin", "gptq_marlin_24", "gptq_marlin", "awq_marlin",
"fbgemm_fp8", "compressed_tensors", "compressed-tensors"
]
if self.quantization is not None: if self.quantization is not None:
self.quantization = self.quantization.lower() self.quantization = self.quantization.lower()
...@@ -240,9 +250,7 @@ class ModelConfig: ...@@ -240,9 +250,7 @@ class ModelConfig:
raise ValueError( raise ValueError(
f"{self.quantization} quantization is currently not " f"{self.quantization} quantization is currently not "
f"supported in ROCm.") f"supported in ROCm.")
if (self.quantization if self.quantization not in optimized_quantization_methods:
not in ("fp8", "marlin", "gptq_marlin_24", "gptq_marlin",
"awq_marlin", "fbgemm_fp8", "compressed_tensors")):
logger.warning( logger.warning(
"%s quantization is not fully " "%s quantization is not fully "
"optimized yet. The speed can be slower than " "optimized yet. The speed can be slower than "
...@@ -281,6 +289,10 @@ class ModelConfig: ...@@ -281,6 +289,10 @@ class ModelConfig:
raise ValueError( raise ValueError(
"BitAndBytes quantization with TP or PP is not supported yet.") "BitAndBytes quantization with TP or PP is not supported yet.")
if self.quantization == "bitsandbytes" and self.enforce_eager is False:
raise ValueError(
"BitAndBytes with enforce_eager = False is not supported yet.")
def get_hf_config_sliding_window(self) -> Optional[int]: def get_hf_config_sliding_window(self) -> Optional[int]:
"""Get the sliding window size, or None if disabled.""" """Get the sliding window size, or None if disabled."""
...@@ -590,9 +602,11 @@ class LoadConfig: ...@@ -590,9 +602,11 @@ class LoadConfig:
mainly for profiling. mainly for profiling.
"tensorizer" will use CoreWeave's tensorizer library for "tensorizer" will use CoreWeave's tensorizer library for
fast weight loading. fast weight loading.
"bitsandbytes" will load nf4 type weights.
ignore_patterns: The list of patterns to ignore when loading the model. ignore_patterns: The list of patterns to ignore when loading the model.
Default to "original/**/*" to avoid repeated loading of llama's Default to "original/**/*" to avoid repeated loading of llama's
checkpoints. checkpoints.
""" """
load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
...@@ -716,7 +730,7 @@ class ParallelConfig: ...@@ -716,7 +730,7 @@ class ParallelConfig:
backend) backend)
self._verify_args() self._verify_args()
self.rank = 0 self.rank: int = 0
@property @property
def use_ray(self) -> bool: def use_ray(self) -> bool:
...@@ -842,6 +856,7 @@ class SchedulerConfig: ...@@ -842,6 +856,7 @@ class SchedulerConfig:
class DeviceConfig: class DeviceConfig:
device: Optional[torch.device]
def __init__(self, device: str = "auto") -> None: def __init__(self, device: str = "auto") -> None:
if device == "auto": if device == "auto":
...@@ -892,6 +907,7 @@ class SpeculativeConfig: ...@@ -892,6 +907,7 @@ class SpeculativeConfig:
speculative_max_model_len: Optional[int], speculative_max_model_len: Optional[int],
enable_chunked_prefill: bool, enable_chunked_prefill: bool,
use_v2_block_manager: bool, use_v2_block_manager: bool,
disable_log_stats: bool,
speculative_disable_by_batch_size: Optional[int], speculative_disable_by_batch_size: Optional[int],
ngram_prompt_lookup_max: Optional[int], ngram_prompt_lookup_max: Optional[int],
ngram_prompt_lookup_min: Optional[int], ngram_prompt_lookup_min: Optional[int],
...@@ -1053,7 +1069,7 @@ class SpeculativeConfig: ...@@ -1053,7 +1069,7 @@ class SpeculativeConfig:
draft_parallel_config = ( draft_parallel_config = (
SpeculativeConfig.create_draft_parallel_config( SpeculativeConfig.create_draft_parallel_config(
target_parallel_config, target_parallel_config,
speculative_draft_tensor_parallel_size)) speculative_draft_tensor_parallel_size, draft_hf_config))
if num_speculative_tokens is None: if num_speculative_tokens is None:
raise ValueError( raise ValueError(
...@@ -1080,7 +1096,8 @@ class SpeculativeConfig: ...@@ -1080,7 +1096,8 @@ class SpeculativeConfig:
typical_acceptance_sampler_posterior_threshold, typical_acceptance_sampler_posterior_threshold,
typical_acceptance_sampler_posterior_alpha=\ typical_acceptance_sampler_posterior_alpha=\
typical_acceptance_sampler_posterior_alpha, typical_acceptance_sampler_posterior_alpha,
disable_logprobs=disable_logprobs disable_logprobs=disable_logprobs,
disable_log_stats=disable_log_stats,
) )
@staticmethod @staticmethod
...@@ -1121,15 +1138,23 @@ class SpeculativeConfig: ...@@ -1121,15 +1138,23 @@ class SpeculativeConfig:
@staticmethod @staticmethod
def create_draft_parallel_config( def create_draft_parallel_config(
target_parallel_config: ParallelConfig, target_parallel_config: ParallelConfig,
speculative_draft_tensor_parallel_size: Optional[int] speculative_draft_tensor_parallel_size: Optional[int],
draft_hf_config: PretrainedConfig,
) -> ParallelConfig: ) -> ParallelConfig:
"""Create a parallel config for use by the draft worker. """Create a parallel config for use by the draft worker.
This is mostly a copy of the target parallel config, except the tp_size. This is mostly a copy of the target parallel config, except the tp_size.
""" """
if speculative_draft_tensor_parallel_size is None: if speculative_draft_tensor_parallel_size is None:
speculative_draft_tensor_parallel_size = \ if draft_hf_config.model_type == "mlp_speculator":
target_parallel_config.tensor_parallel_size speculative_draft_tensor_parallel_size = 1
if target_parallel_config.tensor_parallel_size > 1:
logger.warning(
"MLPSpeculator cannot currently be run with tp>1; "
"setting speculative_draft_tensor_parallel_size=1")
else:
speculative_draft_tensor_parallel_size = \
target_parallel_config.tensor_parallel_size
elif speculative_draft_tensor_parallel_size != 1: elif speculative_draft_tensor_parallel_size != 1:
# TODO(wooyeon): allow tp values larger than 1 # TODO(wooyeon): allow tp values larger than 1
raise ValueError( raise ValueError(
...@@ -1166,6 +1191,7 @@ class SpeculativeConfig: ...@@ -1166,6 +1191,7 @@ class SpeculativeConfig:
typical_acceptance_sampler_posterior_threshold: float, typical_acceptance_sampler_posterior_threshold: float,
typical_acceptance_sampler_posterior_alpha: float, typical_acceptance_sampler_posterior_alpha: float,
disable_logprobs: bool, disable_logprobs: bool,
disable_log_stats: bool,
): ):
"""Create a SpeculativeConfig object. """Create a SpeculativeConfig object.
...@@ -1198,6 +1224,8 @@ class SpeculativeConfig: ...@@ -1198,6 +1224,8 @@ class SpeculativeConfig:
sampling, target sampling, and after accepted tokens are sampling, target sampling, and after accepted tokens are
determined. If set to False, log probabilities will be determined. If set to False, log probabilities will be
returned. returned.
disable_log_stats: Whether to disable periodic printing of stage
times in speculative decoding.
""" """
self.draft_model_config = draft_model_config self.draft_model_config = draft_model_config
self.draft_parallel_config = draft_parallel_config self.draft_parallel_config = draft_parallel_config
...@@ -1212,6 +1240,7 @@ class SpeculativeConfig: ...@@ -1212,6 +1240,7 @@ class SpeculativeConfig:
self.typical_acceptance_sampler_posterior_alpha = \ self.typical_acceptance_sampler_posterior_alpha = \
typical_acceptance_sampler_posterior_alpha typical_acceptance_sampler_posterior_alpha
self.disable_logprobs = disable_logprobs self.disable_logprobs = disable_logprobs
self.disable_log_stats = disable_log_stats
self._verify_args() self._verify_args()
...@@ -1281,7 +1310,7 @@ class LoRAConfig: ...@@ -1281,7 +1310,7 @@ class LoRAConfig:
long_lora_scaling_factors: Optional[Tuple[float]] = None long_lora_scaling_factors: Optional[Tuple[float]] = None
def __post_init__(self): def __post_init__(self):
# Keep this in sync with csrc/punica/bgmv/bgmv_config.h # TODO: Increase the range of rank
possible_max_ranks = (8, 16, 32, 64) possible_max_ranks = (8, 16, 32, 64)
possible_lora_extra_vocab_size = (0, 256, 512) possible_lora_extra_vocab_size = (0, 256, 512)
if self.max_lora_rank not in possible_max_ranks: if self.max_lora_rank not in possible_max_ranks:
...@@ -1527,15 +1556,21 @@ def _get_and_verify_max_len( ...@@ -1527,15 +1556,21 @@ def _get_and_verify_max_len(
"Disabling sliding window is not supported for models " "Disabling sliding window is not supported for models "
"model_max_length in the config. Please raise an issue " "model_max_length in the config. Please raise an issue "
"so we can investigate.") "so we can investigate.")
pass
else: else:
raise ValueError( msg = (
f"User-specified max_model_len ({max_model_len}) is greater " f"User-specified max_model_len ({max_model_len}) is greater "
"than the derived max_model_len " f"than the derived max_model_len ({max_len_key}="
f"({max_len_key}={derived_max_model_len} or model_max_length=" f"{derived_max_model_len} or model_max_length="
f"{model_max_length} in model's config.json). This may lead " f"{model_max_length} in model's config.json). This may lead "
"to incorrect model outputs or CUDA errors. Make sure the " "to incorrect model outputs or CUDA errors.")
"value is correct and within the model context size.") if envs.VLLM_ALLOW_LONG_MAX_MODEL_LEN:
logger.warning(
"%s Make sure the value is correct and within the "
"model context size.", msg)
else:
raise ValueError(
f"{msg} To allow overriding this maximum, set "
"the env var VLLM_ALLOW_LONG_MAX_MODEL_LEN=1")
return int(max_model_len) return int(max_model_len)
......
from pathlib import Path from pathlib import Path
from typing import Mapping, Optional from typing import Mapping, MutableMapping, Optional
from urllib.parse import urlparse from urllib.parse import urlparse
import aiohttp import aiohttp
...@@ -40,7 +40,7 @@ class HTTPConnection: ...@@ -40,7 +40,7 @@ class HTTPConnection:
raise ValueError("Invalid HTTP URL: A valid HTTP URL " raise ValueError("Invalid HTTP URL: A valid HTTP URL "
"must have scheme 'http' or 'https'.") "must have scheme 'http' or 'https'.")
def _headers(self, **extras: str) -> Mapping[str, str]: def _headers(self, **extras: str) -> MutableMapping[str, str]:
return {"User-Agent": f"vLLM/{VLLM_VERSION}", **extras} return {"User-Agent": f"vLLM/{VLLM_VERSION}", **extras}
def get_response( def get_response(
......
...@@ -700,5 +700,5 @@ class BlockSpaceManagerV1(BlockSpaceManager): ...@@ -700,5 +700,5 @@ class BlockSpaceManagerV1(BlockSpaceManager):
def mark_blocks_as_computed(self, seq_group: SequenceGroup): def mark_blocks_as_computed(self, seq_group: SequenceGroup):
if self.enable_caching: if self.enable_caching:
for seq in seq_group.seqs_dict.values(): for seq in seq_group.get_seqs():
self.compute_full_blocks_in_seq(seq) self.compute_full_blocks_in_seq(seq)
from collections import deque
from typing import Deque
from vllm.sequence import SequenceGroup
class Policy:
def get_priority(
self,
now: float,
seq_group: SequenceGroup,
) -> float:
raise NotImplementedError
def sort_by_priority(
self,
now: float,
seq_groups: Deque[SequenceGroup],
) -> Deque[SequenceGroup]:
return deque(
sorted(
seq_groups,
key=lambda seq_group: self.get_priority(now, seq_group),
reverse=True,
))
class FCFS(Policy):
def get_priority(
self,
now: float,
seq_group: SequenceGroup,
) -> float:
return now - seq_group.metrics.arrival_time
class PolicyFactory:
_POLICY_REGISTRY = {'fcfs': FCFS}
@classmethod
def get_policy(cls, policy_name: str, **kwargs) -> Policy:
return cls._POLICY_REGISTRY[policy_name](**kwargs)
...@@ -8,7 +8,6 @@ from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union ...@@ -8,7 +8,6 @@ from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.core.policy import Policy, PolicyFactory
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
...@@ -313,6 +312,7 @@ class Scheduler: ...@@ -313,6 +312,7 @@ class Scheduler:
# Sequence groups finished requests ids since last step iteration. # Sequence groups finished requests ids since last step iteration.
# It lets the model know that any state associated with these requests # It lets the model know that any state associated with these requests
# can and must be released after the current step. # can and must be released after the current step.
# This is used to evict the finished requests from the Mamba cache.
self._finished_requests_ids: List[str] = list() self._finished_requests_ids: List[str] = list()
# Time at previous scheduling step # Time at previous scheduling step
self.prev_time = 0.0 self.prev_time = 0.0
...@@ -344,6 +344,16 @@ class Scheduler: ...@@ -344,6 +344,16 @@ class Scheduler:
# Add sequence groups to the waiting queue. # Add sequence groups to the waiting queue.
self.waiting.append(seq_group) self.waiting.append(seq_group)
def _add_seq_group_to_running(self, seq_group: SequenceGroup) -> None:
# Add sequence groups to the running queue.
# Only for testing purposes.
self.running.append(seq_group)
def _add_seq_group_to_swapped(self, seq_group: SequenceGroup) -> None:
# Add sequence groups to the swapped queue.
# Only for testing purposes.
self.swapped.append(seq_group)
def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None: def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
"""Aborts a sequence group with the given ID. """Aborts a sequence group with the given ID.
...@@ -374,6 +384,7 @@ class Scheduler: ...@@ -374,6 +384,7 @@ class Scheduler:
for aborted_group in aborted_groups: for aborted_group in aborted_groups:
# Remove the sequence group from the state queue. # Remove the sequence group from the state queue.
state_queue.remove(aborted_group) state_queue.remove(aborted_group)
# Remove the aborted request from the Mamba cache.
self._finished_requests_ids.append(aborted_group.request_id) self._finished_requests_ids.append(aborted_group.request_id)
for seq in aborted_group.get_seqs(): for seq in aborted_group.get_seqs():
if seq.is_finished(): if seq.is_finished():
...@@ -396,32 +407,26 @@ class Scheduler: ...@@ -396,32 +407,26 @@ class Scheduler:
def _schedule_running( def _schedule_running(
self, self,
running_queue: deque,
budget: SchedulingBudget, budget: SchedulingBudget,
curr_loras: Optional[Set[int]], curr_loras: Optional[Set[int]],
policy: Policy,
enable_chunking: bool = False, enable_chunking: bool = False,
) -> Tuple[deque, SchedulerRunningOutputs]: ) -> SchedulerRunningOutputs:
"""Schedule sequence groups that are running. """Schedule sequence groups that are running.
Running queue should include decode and chunked prefill requests. Running queue should include decode and chunked prefill requests.
Args: Args:
running_queue: The queue that contains running requests (i.e.,
decodes). The given arguments are NOT in-place modified.
budget: The scheduling budget. The argument is in-place updated budget: The scheduling budget. The argument is in-place updated
when any decodes are preempted. when any decodes are preempted.
curr_loras: Currently batched lora request ids. The argument is curr_loras: Currently batched lora request ids. The argument is
in-place updated when any decodes are preempted. in-place updated when any decodes are preempted.
policy: The sorting policy to sort running_queue.
enable_chunking: If True, seq group can be chunked and only a enable_chunking: If True, seq group can be chunked and only a
chunked number of tokens are scheduled if chunked number of tokens are scheduled if
`budget.num_batched_tokens` has not enough capacity to schedule `budget.num_batched_tokens` has not enough capacity to schedule
all tokens. all tokens.
Returns: Returns:
A tuple of remaining running queue (should be always 0) after SchedulerRunningOutputs.
scheduling and SchedulerRunningOutputs.
""" """
# Blocks that need to be swapped or copied before model execution. # Blocks that need to be swapped or copied before model execution.
blocks_to_swap_out: List[Tuple[int, int]] = [] blocks_to_swap_out: List[Tuple[int, int]] = []
...@@ -434,10 +439,9 @@ class Scheduler: ...@@ -434,10 +439,9 @@ class Scheduler:
# NOTE(woosuk): Preemption happens only when there is no available slot # NOTE(woosuk): Preemption happens only when there is no available slot
# to keep all the sequence groups in the RUNNING state. # to keep all the sequence groups in the RUNNING state.
# In this case, the policy is responsible for deciding which sequence
# groups to preempt. running_queue = self.running
now = time.time()
running_queue = policy.sort_by_priority(now, running_queue)
while running_queue: while running_queue:
seq_group = running_queue[0] seq_group = running_queue[0]
num_running_tokens = self._get_num_new_tokens( num_running_tokens = self._get_num_new_tokens(
...@@ -501,7 +505,7 @@ class Scheduler: ...@@ -501,7 +505,7 @@ class Scheduler:
if curr_loras is not None and seq_group.lora_int_id > 0: if curr_loras is not None and seq_group.lora_int_id > 0:
curr_loras.add(seq_group.lora_int_id) curr_loras.add(seq_group.lora_int_id)
return running_queue, SchedulerRunningOutputs( return SchedulerRunningOutputs(
decode_seq_groups=decode_seq_groups, decode_seq_groups=decode_seq_groups,
prefill_seq_groups=prefill_seq_groups, prefill_seq_groups=prefill_seq_groups,
preempted=preempted, preempted=preempted,
...@@ -513,12 +517,10 @@ class Scheduler: ...@@ -513,12 +517,10 @@ class Scheduler:
def _schedule_swapped( def _schedule_swapped(
self, self,
swapped_queue: deque,
budget: SchedulingBudget, budget: SchedulingBudget,
curr_loras: Optional[Set[int]], curr_loras: Optional[Set[int]],
policy: Policy,
enable_chunking: bool = False, enable_chunking: bool = False,
) -> Tuple[deque, SchedulerSwappedInOutputs]: ) -> SchedulerSwappedInOutputs:
"""Schedule sequence groups that are swapped out. """Schedule sequence groups that are swapped out.
It schedules swapped requests as long as it fits `budget` and It schedules swapped requests as long as it fits `budget` and
...@@ -526,20 +528,16 @@ class Scheduler: ...@@ -526,20 +528,16 @@ class Scheduler:
`budget` and `curr_loras` are updated based on scheduled seq_groups. `budget` and `curr_loras` are updated based on scheduled seq_groups.
Args: Args:
swapped_queue: The queue that contains swapped out requests.
The given arguments are NOT in-place modified.
budget: The scheduling budget. The argument is in-place updated budget: The scheduling budget. The argument is in-place updated
when any requests are swapped in. when any requests are swapped in.
curr_loras: Currently batched lora request ids. The argument is curr_loras: Currently batched lora request ids. The argument is
in-place updated when any requests are swapped in. in-place updated when any requests are swapped in.
policy: The sorting policy to sort swapped_queue.
enable_chunking: If True, seq group can be chunked and only a enable_chunking: If True, seq group can be chunked and only a
chunked number of tokens are scheduled if chunked number of tokens are scheduled if
`budget.num_batched_tokens` has not enough capacity to schedule `budget.num_batched_tokens` has not enough capacity to schedule
all tokens. all tokens.
Returns: Returns:
A tuple of remaining swapped_queue after scheduling and
SchedulerSwappedInOutputs. SchedulerSwappedInOutputs.
""" """
# Blocks that need to be swapped or copied before model execution. # Blocks that need to be swapped or copied before model execution.
...@@ -547,10 +545,10 @@ class Scheduler: ...@@ -547,10 +545,10 @@ class Scheduler:
blocks_to_copy: List[Tuple[int, int]] = [] blocks_to_copy: List[Tuple[int, int]] = []
decode_seq_groups: List[ScheduledSequenceGroup] = [] decode_seq_groups: List[ScheduledSequenceGroup] = []
prefill_seq_groups: List[ScheduledSequenceGroup] = [] prefill_seq_groups: List[ScheduledSequenceGroup] = []
now = time.time()
swapped_queue = policy.sort_by_priority(now, swapped_queue)
infeasible_seq_groups: List[SequenceGroup] = [] infeasible_seq_groups: List[SequenceGroup] = []
swapped_queue = self.swapped
leftover_swapped: Deque[SequenceGroup] = deque() leftover_swapped: Deque[SequenceGroup] = deque()
while swapped_queue: while swapped_queue:
seq_group = swapped_queue[0] seq_group = swapped_queue[0]
...@@ -615,7 +613,7 @@ class Scheduler: ...@@ -615,7 +613,7 @@ class Scheduler:
swapped_queue.extendleft(leftover_swapped) swapped_queue.extendleft(leftover_swapped)
return swapped_queue, SchedulerSwappedInOutputs( return SchedulerSwappedInOutputs(
decode_seq_groups=decode_seq_groups, decode_seq_groups=decode_seq_groups,
prefill_seq_groups=prefill_seq_groups, prefill_seq_groups=prefill_seq_groups,
blocks_to_swap_in=blocks_to_swap_in, blocks_to_swap_in=blocks_to_swap_in,
...@@ -642,11 +640,10 @@ class Scheduler: ...@@ -642,11 +640,10 @@ class Scheduler:
def _schedule_prefills( def _schedule_prefills(
self, self,
waiting_queue: deque,
budget: SchedulingBudget, budget: SchedulingBudget,
curr_loras: Optional[Set[int]], curr_loras: Optional[Set[int]],
enable_chunking: bool = False, enable_chunking: bool = False,
) -> Tuple[deque, SchedulerPrefillOutputs]: ) -> SchedulerPrefillOutputs:
"""Schedule sequence groups that are in prefill stage. """Schedule sequence groups that are in prefill stage.
Note that the current scheduler treats PREEMPTED_FOR_RECOMPUTE Note that the current scheduler treats PREEMPTED_FOR_RECOMPUTE
...@@ -658,8 +655,6 @@ class Scheduler: ...@@ -658,8 +655,6 @@ class Scheduler:
`budget` and `curr_loras` are updated based on scheduled seq_groups. `budget` and `curr_loras` are updated based on scheduled seq_groups.
Args: Args:
waiting_queue: The queue that contains prefill requests.
The given arguments are NOT in-place modified.
budget: The scheduling budget. The argument is in-place updated budget: The scheduling budget. The argument is in-place updated
when any requests are scheduled. when any requests are scheduled.
curr_loras: Currently batched lora request ids. The argument is curr_loras: Currently batched lora request ids. The argument is
...@@ -670,14 +665,12 @@ class Scheduler: ...@@ -670,14 +665,12 @@ class Scheduler:
all tokens. all tokens.
Returns: Returns:
A tuple of remaining waiting_queue after scheduling and
SchedulerSwappedInOutputs. SchedulerSwappedInOutputs.
""" """
ignored_seq_groups: List[SequenceGroup] = [] ignored_seq_groups: List[SequenceGroup] = []
seq_groups: List[SequenceGroup] = [] seq_groups: List[SequenceGroup] = []
# We don't sort waiting queue because we assume it is sorted.
# Copy the queue so that the input queue is not modified. waiting_queue = self.waiting
waiting_queue = deque([s for s in waiting_queue])
leftover_waiting_sequences: Deque[SequenceGroup] = deque() leftover_waiting_sequences: Deque[SequenceGroup] = deque()
while self._passed_delay(time.time()) and waiting_queue: while self._passed_delay(time.time()) and waiting_queue:
...@@ -756,7 +749,7 @@ class Scheduler: ...@@ -756,7 +749,7 @@ class Scheduler:
if len(seq_groups) > 0: if len(seq_groups) > 0:
self.prev_prompt = True self.prev_prompt = True
return waiting_queue, SchedulerPrefillOutputs( return SchedulerPrefillOutputs(
seq_groups=seq_groups, seq_groups=seq_groups,
ignored_seq_groups=ignored_seq_groups, ignored_seq_groups=ignored_seq_groups,
num_lookahead_slots=self._get_num_lookahead_slots(is_prefill=True)) num_lookahead_slots=self._get_num_lookahead_slots(is_prefill=True))
...@@ -783,53 +776,43 @@ class Scheduler: ...@@ -783,53 +776,43 @@ class Scheduler:
seq_group.lora_int_id for seq_group in self.running seq_group.lora_int_id for seq_group in self.running
if seq_group.lora_int_id > 0) if self.lora_enabled else None if seq_group.lora_int_id > 0) if self.lora_enabled else None
remaining_waiting, prefills = (self.waiting, prefills = SchedulerPrefillOutputs.create_empty()
SchedulerPrefillOutputs.create_empty()) running_scheduled = SchedulerRunningOutputs.create_empty()
remaining_running, running_scheduled = ( swapped_in = SchedulerSwappedInOutputs.create_empty()
self.running, SchedulerRunningOutputs.create_empty())
remaining_swapped, swapped_in = (
self.swapped, SchedulerSwappedInOutputs.create_empty())
# If any requests are swapped, prioritized swapped requests. # If any requests are swapped, prioritized swapped requests.
if not self.swapped: if not self.swapped:
remaining_waiting, prefills = self._schedule_prefills( prefills = self._schedule_prefills(budget,
self.waiting, budget, curr_loras, enable_chunking=False) curr_loras,
enable_chunking=False)
fcfs_policy = PolicyFactory.get_policy(policy_name="fcfs")
# Don't schedule decodes if prefills are scheduled. # Don't schedule decodes if prefills are scheduled.
# NOTE: If `_schedule_prefills` doesn't enable chunking, self.running # NOTE: If `_schedule_prefills` doesn't enable chunking, self.running
# only contains decode requests, not chunked prefills. # only contains decode requests, not chunked prefills.
if len(prefills.seq_groups) == 0: if len(prefills.seq_groups) == 0:
remaining_running, running_scheduled = self._schedule_running( running_scheduled = self._schedule_running(budget,
self.running, curr_loras,
budget, enable_chunking=False)
curr_loras,
fcfs_policy,
enable_chunking=False)
# If any sequence group is preempted, do not swap in any sequence # If any sequence group is preempted, do not swap in any sequence
# group. because it means there's no slot for new running requests. # group. because it means there's no slot for new running requests.
if len(running_scheduled.preempted) + len( if len(running_scheduled.preempted) + len(
running_scheduled.swapped_out) == 0: running_scheduled.swapped_out) == 0:
remaining_swapped, swapped_in = self._schedule_swapped( swapped_in = self._schedule_swapped(budget, curr_loras)
self.swapped, budget, curr_loras, fcfs_policy)
assert (budget.num_batched_tokens <= assert (budget.num_batched_tokens <=
self.scheduler_config.max_num_batched_tokens) self.scheduler_config.max_num_batched_tokens)
assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs
# Update waiting requests. # Update waiting requests.
self.waiting = remaining_waiting
self.waiting.extendleft(running_scheduled.preempted) self.waiting.extendleft(running_scheduled.preempted)
# Update new running requests. # Update new running requests.
self.running = remaining_running
self.running.extend([s.seq_group for s in prefills.seq_groups]) self.running.extend([s.seq_group for s in prefills.seq_groups])
self.running.extend( self.running.extend(
[s.seq_group for s in running_scheduled.decode_seq_groups]) [s.seq_group for s in running_scheduled.decode_seq_groups])
self.running.extend( self.running.extend(
[s.seq_group for s in swapped_in.decode_seq_groups]) [s.seq_group for s in swapped_in.decode_seq_groups])
# Update swapped requests. # Update swapped requests.
self.swapped = remaining_swapped
self.swapped.extend(running_scheduled.swapped_out) self.swapped.extend(running_scheduled.swapped_out)
preempted = (len(running_scheduled.preempted) + preempted = (len(running_scheduled.preempted) +
len(running_scheduled.swapped_out)) len(running_scheduled.swapped_out))
...@@ -875,42 +858,32 @@ class Scheduler: ...@@ -875,42 +858,32 @@ class Scheduler:
) )
curr_loras: Set[int] = set() curr_loras: Set[int] = set()
remaining_waiting, prefills = (self.waiting, prefills = SchedulerPrefillOutputs.create_empty()
SchedulerPrefillOutputs.create_empty()) swapped_in = SchedulerSwappedInOutputs.create_empty()
remaining_running, running_scheduled = (
self.running, SchedulerRunningOutputs.create_empty())
remaining_swapped, swapped_in = (
self.swapped, SchedulerSwappedInOutputs.create_empty())
# Decoding should be always scheduled first by fcfs. # Decoding should be always scheduled first by fcfs.
fcfs_policy = PolicyFactory.get_policy(policy_name="fcfs") running_scheduled = self._schedule_running(budget,
remaining_running, running_scheduled = self._schedule_running( curr_loras,
self.running, enable_chunking=True)
budget,
curr_loras,
fcfs_policy,
enable_chunking=True)
# Schedule swapped out requests. # Schedule swapped out requests.
# If preemption happens, it means we don't have space for swap-in. # If preemption happens, it means we don't have space for swap-in.
if len(running_scheduled.preempted) + len( if len(running_scheduled.preempted) + len(
running_scheduled.swapped_out) == 0: running_scheduled.swapped_out) == 0:
remaining_swapped, swapped_in = self._schedule_swapped( swapped_in = self._schedule_swapped(budget, curr_loras)
self.swapped, budget, curr_loras, fcfs_policy)
# Schedule new prefills. # Schedule new prefills.
remaining_waiting, prefills = self._schedule_prefills( prefills = self._schedule_prefills(budget,
self.waiting, budget, curr_loras, enable_chunking=True) curr_loras,
enable_chunking=True)
assert (budget.num_batched_tokens <= assert (budget.num_batched_tokens <=
self.scheduler_config.max_num_batched_tokens) self.scheduler_config.max_num_batched_tokens)
assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs
# Update waiting requests. # Update waiting requests.
self.waiting = remaining_waiting
self.waiting.extendleft(running_scheduled.preempted) self.waiting.extendleft(running_scheduled.preempted)
# Update new running requests. # Update new running requests.
self.running = remaining_running
self.running.extend([s.seq_group for s in prefills.seq_groups]) self.running.extend([s.seq_group for s in prefills.seq_groups])
self.running.extend( self.running.extend(
[s.seq_group for s in running_scheduled.decode_seq_groups]) [s.seq_group for s in running_scheduled.decode_seq_groups])
...@@ -921,7 +894,6 @@ class Scheduler: ...@@ -921,7 +894,6 @@ class Scheduler:
self.running.extend( self.running.extend(
[s.seq_group for s in swapped_in.prefill_seq_groups]) [s.seq_group for s in swapped_in.prefill_seq_groups])
# Update swapped requests. # Update swapped requests.
self.swapped = remaining_swapped
self.swapped.extend(running_scheduled.swapped_out) self.swapped.extend(running_scheduled.swapped_out)
return SchedulerOutputs( return SchedulerOutputs(
scheduled_seq_groups=(prefills.seq_groups + scheduled_seq_groups=(prefills.seq_groups +
...@@ -1029,7 +1001,6 @@ class Scheduler: ...@@ -1029,7 +1001,6 @@ class Scheduler:
token_chunk_size=token_chunk_size, token_chunk_size=token_chunk_size,
lora_request=seq_group.lora_request, lora_request=seq_group.lora_request,
computed_block_nums=common_computed_block_nums, computed_block_nums=common_computed_block_nums,
state=seq_group.state,
# `multi_modal_data` will only be present for the 1st comm # `multi_modal_data` will only be present for the 1st comm
# between engine and worker. # between engine and worker.
# the subsequent comms can still use delta, but # the subsequent comms can still use delta, but
...@@ -1058,13 +1029,16 @@ class Scheduler: ...@@ -1058,13 +1029,16 @@ class Scheduler:
self.block_manager.free(seq) self.block_manager.free(seq)
def free_finished_seq_groups(self) -> None: def free_finished_seq_groups(self) -> None:
for queue in [self.running, self.swapped, self.waiting]: remaining: Deque[SequenceGroup] = deque()
self._finished_requests_ids += [ for seq_group in self.running:
seq_group.request_id for seq_group in queue if seq_group.is_finished():
if seq_group.is_finished() # Add the finished requests to the finished requests list.
] # This list will be used to update the Mamba cache in the
self.running = deque(seq_group for seq_group in self.running # next step.
if not seq_group.is_finished()) self._finished_requests_ids.append(seq_group.request_id)
else:
remaining.append(seq_group)
self.running = remaining
def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None: def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None:
self.block_manager.allocate(seq_group) self.block_manager.allocate(seq_group)
......
...@@ -4,9 +4,6 @@ convenient for use when we just need to call a few functions. ...@@ -4,9 +4,6 @@ convenient for use when we just need to call a few functions.
""" """
import ctypes import ctypes
import glob
import os
import sys
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
...@@ -36,24 +33,25 @@ class Function: ...@@ -36,24 +33,25 @@ class Function:
argtypes: List[Any] argtypes: List[Any]
def get_pytorch_default_cudart_library_path() -> str: def find_loaded_library(lib_name) -> Optional[str]:
# code borrowed from https://github.com/pytorch/pytorch/blob/1cae60a87e5bdda8bcf55724a862eeed98a9747e/torch/__init__.py#L284 # noqa """
lib_folder = "cuda_runtime" According to according to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html,
lib_name = "libcudart.so.*[0-9]" the file `/proc/self/maps` contains the memory maps of the process, which includes the
lib_path = None shared libraries loaded by the process. We can use this file to find the path of the
for path in sys.path: a loaded library.
nvidia_path = os.path.join(path, "nvidia") """ # noqa
if not os.path.exists(nvidia_path): found = False
continue with open("/proc/self/maps") as f:
candidate_lib_paths = glob.glob( for line in f:
os.path.join(nvidia_path, lib_folder, "lib", lib_name)) if lib_name in line:
if candidate_lib_paths and not lib_path: found = True
lib_path = candidate_lib_paths[0] break
if lib_path: if not found:
break # the library is not loaded in the current process
if not lib_path: return None
raise ValueError(f"{lib_name} not found in the system path {sys.path}") start = line.index("/")
return lib_path path = line[start:].strip()
return path
class CudaRTLibrary: class CudaRTLibrary:
...@@ -100,7 +98,9 @@ class CudaRTLibrary: ...@@ -100,7 +98,9 @@ class CudaRTLibrary:
def __init__(self, so_file: Optional[str] = None): def __init__(self, so_file: Optional[str] = None):
if so_file is None: if so_file is None:
so_file = get_pytorch_default_cudart_library_path() so_file = find_loaded_library("libcudart.so")
assert so_file is not None, \
"libcudart.so is not loaded in the current process"
if so_file not in CudaRTLibrary.path_to_library_cache: if so_file not in CudaRTLibrary.path_to_library_cache:
lib = ctypes.CDLL(so_file) lib = ctypes.CDLL(so_file)
CudaRTLibrary.path_to_library_cache[so_file] = lib CudaRTLibrary.path_to_library_cache[so_file] = lib
......
...@@ -145,6 +145,7 @@ def can_actually_p2p( ...@@ -145,6 +145,7 @@ def can_actually_p2p(
p_tgt.start() p_tgt.start()
p_src.join() p_src.join()
p_tgt.join() p_tgt.join()
assert p_src.exitcode == 0 and p_tgt.exitcode == 0
result: List[bool] = [] result: List[bool] = []
for src, tgt in zip(batch_src, batch_tgt): for src, tgt in zip(batch_src, batch_tgt):
a = result_queue.get() a = result_queue.get()
...@@ -221,7 +222,8 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool: ...@@ -221,7 +222,8 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool:
# wrap raised exception to provide more information # wrap raised exception to provide more information
raise RuntimeError( raise RuntimeError(
f"Error happened when batch testing " f"Error happened when batch testing "
f"peer-to-peer access from {batch_src} to {batch_tgt}") from e f"peer-to-peer access from {batch_src} to {batch_tgt}:\n"
f"{returned.stderr.decode()}") from e
result = pickle.loads(returned.stdout) result = pickle.loads(returned.stdout)
for _i, _j, r in zip(batch_src, batch_tgt, result): for _i, _j, r in zip(batch_src, batch_tgt, result):
cache[f"{_i}->{_j}"] = r cache[f"{_i}->{_j}"] = r
......
...@@ -9,7 +9,7 @@ from unittest.mock import patch ...@@ -9,7 +9,7 @@ from unittest.mock import patch
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
from zmq import PUB, REP, REQ, SUB, SUBSCRIBE, Context # type: ignore from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore
import vllm.envs as envs import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -153,9 +153,7 @@ class Handle: ...@@ -153,9 +153,7 @@ class Handle:
buffer: Optional[ShmRingBuffer] = None buffer: Optional[ShmRingBuffer] = None
local_subscribe_port: Optional[int] = None local_subscribe_port: Optional[int] = None
local_sync_port: Optional[int] = None
remote_subscribe_port: Optional[int] = None remote_subscribe_port: Optional[int] = None
remote_sync_port: Optional[int] = None
class MessageQueue: class MessageQueue:
...@@ -189,38 +187,36 @@ class MessageQueue: ...@@ -189,38 +187,36 @@ class MessageQueue:
self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes, self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes,
max_chunks) max_chunks)
self.local_socket = context.socket(PUB) # XPUB is very similar to PUB,
# except that it can receive subscription messages
# to confirm the number of subscribers
self.local_socket = context.socket(XPUB)
# set the verbose option so that we can receive every subscription
# message. otherwise, we will only receive the first subscription
# see http://api.zeromq.org/3-3:zmq-setsockopt for more details
self.local_socket.setsockopt(XPUB_VERBOSE, True)
local_subscribe_port = get_open_port() local_subscribe_port = get_open_port()
self.local_socket.bind(f"tcp://*:{local_subscribe_port}") self.local_socket.bind(f"tcp://*:{local_subscribe_port}")
self.local_sync_socket = context.socket(REP)
local_sync_port = get_open_port()
self.local_sync_socket.bind(f"tcp://*:{local_sync_port}")
self.current_idx = 0 self.current_idx = 0
else: else:
self.buffer = None # type: ignore self.buffer = None # type: ignore
local_subscribe_port = None local_subscribe_port = None
local_sync_port = None
self.local_socket = None self.local_socket = None
self.local_sync_socket = None
self.current_idx = -1 self.current_idx = -1
if n_remote_reader > 0: if n_remote_reader > 0:
# for remote readers, we will: # for remote readers, we will:
# create a publish-subscribe socket to communicate large data # create a publish-subscribe socket to communicate large data
self.remote_socket = context.socket(PUB) self.remote_socket = context.socket(XPUB)
self.remote_socket.setsockopt(XPUB_VERBOSE, True)
remote_subscribe_port = get_open_port() remote_subscribe_port = get_open_port()
self.remote_socket.bind(f"tcp://*:{remote_subscribe_port}") self.remote_socket.bind(f"tcp://*:{remote_subscribe_port}")
self.remote_sync_socket = context.socket(REP)
remote_sync_port = get_open_port()
self.remote_sync_socket.bind(f"tcp://*:{remote_sync_port}")
else: else:
remote_subscribe_port = None remote_subscribe_port = None
remote_sync_port = None
self.remote_socket = None self.remote_socket = None
self.remote_sync_socket = None
self._is_writer = True self._is_writer = True
self._is_local_reader = False self._is_local_reader = False
...@@ -233,9 +229,7 @@ class MessageQueue: ...@@ -233,9 +229,7 @@ class MessageQueue:
local_reader_ranks=local_reader_ranks, local_reader_ranks=local_reader_ranks,
buffer=self.buffer, buffer=self.buffer,
local_subscribe_port=local_subscribe_port, local_subscribe_port=local_subscribe_port,
local_sync_port=local_sync_port,
remote_subscribe_port=remote_subscribe_port, remote_subscribe_port=remote_subscribe_port,
remote_sync_port=remote_sync_port,
) )
logger.info("vLLM message queue communication handle: %s", self.handle) logger.info("vLLM message queue communication handle: %s", self.handle)
...@@ -264,12 +258,7 @@ class MessageQueue: ...@@ -264,12 +258,7 @@ class MessageQueue:
self.local_socket.connect( self.local_socket.connect(
f"tcp://{handle.connect_ip}:{handle.local_subscribe_port}") f"tcp://{handle.connect_ip}:{handle.local_subscribe_port}")
self.local_sync_socket = context.socket(REQ)
self.local_sync_socket.connect(
f"tcp://{handle.connect_ip}:{handle.local_sync_port}")
self.remote_socket = None self.remote_socket = None
self.remote_sync_socket = None
else: else:
self.buffer = None # type: ignore self.buffer = None # type: ignore
self.current_idx = -1 self.current_idx = -1
...@@ -278,17 +267,12 @@ class MessageQueue: ...@@ -278,17 +267,12 @@ class MessageQueue:
self._is_remote_reader = True self._is_remote_reader = True
self.local_socket = None self.local_socket = None
self.local_sync_socket = None
self.remote_socket = context.socket(SUB) self.remote_socket = context.socket(SUB)
self.remote_socket.setsockopt_string(SUBSCRIBE, "") self.remote_socket.setsockopt_string(SUBSCRIBE, "")
self.remote_socket.connect( self.remote_socket.connect(
f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}") f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}")
self.remote_sync_socket = context.socket(REQ)
self.remote_sync_socket.connect(
f"tcp://{handle.connect_ip}:{handle.remote_sync_port}")
return self return self
def wait_until_ready(self): def wait_until_ready(self):
...@@ -300,29 +284,27 @@ class MessageQueue: ...@@ -300,29 +284,27 @@ class MessageQueue:
# local readers # local readers
for i in range(self.n_local_reader): for i in range(self.n_local_reader):
recv = self.local_sync_socket.recv() # wait for subscription messages from all local readers
assert recv == b"READY" self.local_socket.recv()
self.local_sync_socket.send(b"READY")
if self.n_local_reader > 0: if self.n_local_reader > 0:
# send a message to all local readers
# to make sure the publish channel is working
self.local_socket.send(b"READY") self.local_socket.send(b"READY")
# remote readers # remote readers
for i in range(self.n_remote_reader): for i in range(self.n_remote_reader):
recv = self.remote_sync_socket.recv() # wait for subscription messages from all remote readers
assert recv == b"READY" self.remote_socket.recv()
self.remote_sync_socket.send(b"READY")
if self.n_remote_reader > 0: if self.n_remote_reader > 0:
# send a message to all remote readers
# to make sure the publish channel is working
self.remote_socket.send(b"READY") self.remote_socket.send(b"READY")
elif self._is_local_reader: elif self._is_local_reader:
self.local_sync_socket.send(b"READY") # wait for the writer to send a message
recv = self.local_sync_socket.recv()
assert recv == b"READY"
recv = self.local_socket.recv() recv = self.local_socket.recv()
assert recv == b"READY" assert recv == b"READY"
elif self._is_remote_reader: elif self._is_remote_reader:
self.remote_sync_socket.send(b"READY") # wait for the writer to send a message
recv = self.remote_sync_socket.recv()
assert recv == b"READY"
recv = self.remote_socket.recv() recv = self.remote_socket.recv()
assert recv == b"READY" assert recv == b"READY"
......
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from vllm.platforms import current_platform
if current_platform.is_tpu():
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from torch_xla._internal import pjrt
class TpuCommunicator:
def __init__(self, group: ProcessGroup):
if not current_platform.is_tpu():
self.disabled = True
return
self.disabled = False
local_rank = dist.get_rank(group)
world_size = dist.get_world_size(group)
pjrt.initialize_multiprocess(local_rank, world_size)
xr._init_world_size_ordinal()
def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
return xm.all_reduce(xm.REDUCE_SUM, x)
def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
assert dim == -1, "TPUs only support dim=-1 for all-gather."
return xm.all_gather(x, dim=dim)
...@@ -45,22 +45,16 @@ TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"]) ...@@ -45,22 +45,16 @@ TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
def _split_tensor_dict( def _split_tensor_dict(
tensor_dict: Dict[str, Union[torch.Tensor, Any]], tensor_dict: Dict[str, Union[torch.Tensor, Any]]
prefix: str = "") -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]: ) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
"""Split the tensor dictionary into two parts: """Split the tensor dictionary into two parts:
1. A list of (key, value) pairs. If the value is a tensor, it is replaced 1. A list of (key, value) pairs. If the value is a tensor, it is replaced
by its metadata. by its metadata.
2. A list of tensors. 2. A list of tensors.
If the Tensor is nested under `tensor_dict["key1"]["key2"]`, the key of its
metadata will be "key1%key2".
""" """
metadata_list: List[Tuple[str, Any]] = [] metadata_list: List[Tuple[str, Any]] = []
tensor_list = [] tensor_list: List[torch.Tensor] = []
for key, value in tensor_dict.items(): for key, value in tensor_dict.items():
assert "%" not in key, (
"Avoid having '%' in key "
"as it is used as a separator for nested entries.")
if isinstance(value, torch.Tensor): if isinstance(value, torch.Tensor):
# Note: we cannot use `value.device` here, # Note: we cannot use `value.device` here,
# because it contains not only the device type but also the device # because it contains not only the device type but also the device
...@@ -68,31 +62,13 @@ def _split_tensor_dict( ...@@ -68,31 +62,13 @@ def _split_tensor_dict(
# receiving side will set the device index. # receiving side will set the device index.
device = value.device.type device = value.device.type
metadata_list.append( metadata_list.append(
(prefix + key, TensorMetadata(device, value.dtype, (key, TensorMetadata(device, value.dtype, value.size())))
value.size())))
tensor_list.append(value) tensor_list.append(value)
elif isinstance(value, dict):
if len(value) == 0:
metadata_list.append((prefix + key, value))
inner_metadata_list, inner_tensor_list = _split_tensor_dict(
value, prefix + key + "%")
metadata_list.extend(inner_metadata_list)
tensor_list.extend(inner_tensor_list)
else: else:
metadata_list.append((prefix + key, value)) metadata_list.append((key, value))
return metadata_list, tensor_list return metadata_list, tensor_list
def _update_nested_dict(nested_dict, flattened_key, value):
key_splits = flattened_key.split("%")
cur_dict = nested_dict
for k in key_splits[:-1]:
if k not in cur_dict:
cur_dict[k] = {}
cur_dict = cur_dict[k]
cur_dict[key_splits[-1]] = value
class GroupCoordinator: class GroupCoordinator:
""" """
PyTorch ProcessGroup wrapper for a group of processes. PyTorch ProcessGroup wrapper for a group of processes.
...@@ -133,6 +109,7 @@ class GroupCoordinator: ...@@ -133,6 +109,7 @@ class GroupCoordinator:
torch_distributed_backend: Union[str, Backend], torch_distributed_backend: Union[str, Backend],
use_pynccl: bool, use_pynccl: bool,
use_custom_allreduce: bool, use_custom_allreduce: bool,
use_tpu_communicator: bool,
use_message_queue_broadcaster: bool = False, use_message_queue_broadcaster: bool = False,
): ):
...@@ -164,6 +141,7 @@ class GroupCoordinator: ...@@ -164,6 +141,7 @@ class GroupCoordinator:
self.use_pynccl = use_pynccl self.use_pynccl = use_pynccl
self.use_custom_allreduce = use_custom_allreduce self.use_custom_allreduce = use_custom_allreduce
self.use_tpu_communicator = use_tpu_communicator
# lazy import to avoid documentation build error # lazy import to avoid documentation build error
from vllm.distributed.device_communicators.custom_all_reduce import ( from vllm.distributed.device_communicators.custom_all_reduce import (
...@@ -190,6 +168,12 @@ class GroupCoordinator: ...@@ -190,6 +168,12 @@ class GroupCoordinator:
else: else:
self.ca_comm = None self.ca_comm = None
from vllm.distributed.device_communicators.tpu_communicator import (
TpuCommunicator)
self.tpu_communicator: Optional[TpuCommunicator]
if use_tpu_communicator and self.world_size > 1:
self.tpu_communicator = TpuCommunicator(group=self.cpu_group)
from vllm.distributed.device_communicators.shm_broadcast import ( from vllm.distributed.device_communicators.shm_broadcast import (
MessageQueue) MessageQueue)
self.mq_broadcaster: Optional[MessageQueue] = None self.mq_broadcaster: Optional[MessageQueue] = None
...@@ -243,6 +227,13 @@ class GroupCoordinator: ...@@ -243,6 +227,13 @@ class GroupCoordinator:
ca_comm = self.ca_comm ca_comm = self.ca_comm
maybe_ca_context = nullcontext( maybe_ca_context = nullcontext(
) if ca_comm is None else ca_comm.capture() ) if ca_comm is None else ca_comm.capture()
# ensure all initialization operations complete before attempting to
# capture the graph on another stream
curr_stream = torch.cuda.current_stream()
if curr_stream != stream:
stream.wait_stream(curr_stream)
with torch.cuda.stream(stream), maybe_ca_context: with torch.cuda.stream(stream), maybe_ca_context:
# In graph mode, we have to be very careful about the collective # In graph mode, we have to be very careful about the collective
# operations. The current status is: # operations. The current status is:
...@@ -282,6 +273,12 @@ class GroupCoordinator: ...@@ -282,6 +273,12 @@ class GroupCoordinator:
# Bypass the function if we are using only 1 GPU. # Bypass the function if we are using only 1 GPU.
if self.world_size == 1: if self.world_size == 1:
return input_ return input_
# For TPUs, use TPU communicator.
tpu_comm = self.tpu_communicator
if tpu_comm is not None and not tpu_comm.disabled:
return tpu_comm.all_reduce(input_)
if ca_comm is not None: if ca_comm is not None:
out = ca_comm.custom_all_reduce(input_) out = ca_comm.custom_all_reduce(input_)
if out is not None: if out is not None:
...@@ -289,6 +286,9 @@ class GroupCoordinator: ...@@ -289,6 +286,9 @@ class GroupCoordinator:
pynccl_comm = self.pynccl_comm pynccl_comm = self.pynccl_comm
if (pynccl_comm is not None and not pynccl_comm.disabled): if (pynccl_comm is not None and not pynccl_comm.disabled):
pynccl_comm.all_reduce(input_) pynccl_comm.all_reduce(input_)
elif input_.is_cpu:
import intel_extension_for_pytorch as ipex
ipex.distributed.all_reduce(input_, group=self.device_group)
else: else:
torch.distributed.all_reduce(input_, group=self.device_group) torch.distributed.all_reduce(input_, group=self.device_group)
return input_ return input_
...@@ -300,6 +300,12 @@ class GroupCoordinator: ...@@ -300,6 +300,12 @@ class GroupCoordinator:
return input_ return input_
assert -input_.dim() <= dim < input_.dim(), ( assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
# For TPUs, use TPU communicator.
tpu_comm = self.tpu_communicator
if tpu_comm is not None and not tpu_comm.disabled:
return tpu_comm.all_gather(input_, dim)
if dim < 0: if dim < 0:
# Convert negative dim to positive. # Convert negative dim to positive.
dim += input_.dim() dim += input_.dim()
...@@ -536,7 +542,7 @@ class GroupCoordinator: ...@@ -536,7 +542,7 @@ class GroupCoordinator:
device=value.device) device=value.device)
if tensor.numel() == 0: if tensor.numel() == 0:
# Skip broadcasting empty tensors. # Skip broadcasting empty tensors.
_update_nested_dict(tensor_dict, key, tensor) tensor_dict[key] = tensor
continue continue
if tensor.is_cpu: if tensor.is_cpu:
# use metadata_group for CPU tensors # use metadata_group for CPU tensors
...@@ -553,9 +559,9 @@ class GroupCoordinator: ...@@ -553,9 +559,9 @@ class GroupCoordinator:
group=group, group=group,
async_op=True) async_op=True)
async_handles.append(handle) async_handles.append(handle)
_update_nested_dict(tensor_dict, key, tensor) tensor_dict[key] = tensor
else: else:
_update_nested_dict(tensor_dict, key, value) tensor_dict[key] = value
for async_handle in async_handles: for async_handle in async_handles:
async_handle.wait() async_handle.wait()
return tensor_dict return tensor_dict
...@@ -563,7 +569,8 @@ class GroupCoordinator: ...@@ -563,7 +569,8 @@ class GroupCoordinator:
def send_tensor_dict( def send_tensor_dict(
self, self,
tensor_dict: Dict[str, Union[torch.Tensor, Any]], tensor_dict: Dict[str, Union[torch.Tensor, Any]],
dst: Optional[int] = None dst: Optional[int] = None,
all_gather_group: Optional["GroupCoordinator"] = None,
) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
"""Send the input tensor dictionary. """Send the input tensor dictionary.
NOTE: `dst` is the local rank of the source rank. NOTE: `dst` is the local rank of the source rank.
...@@ -572,6 +579,11 @@ class GroupCoordinator: ...@@ -572,6 +579,11 @@ class GroupCoordinator:
if not torch.distributed.is_initialized() or self.world_size == 1: if not torch.distributed.is_initialized() or self.world_size == 1:
return tensor_dict return tensor_dict
all_gather_size = (1 if all_gather_group is None else
all_gather_group.world_size)
all_gather_rank = (0 if all_gather_group is None else
all_gather_group.rank_in_group)
group = self.device_group group = self.device_group
metadata_group = self.cpu_group metadata_group = self.cpu_group
...@@ -592,6 +604,12 @@ class GroupCoordinator: ...@@ -592,6 +604,12 @@ class GroupCoordinator:
if tensor.numel() == 0: if tensor.numel() == 0:
# Skip sending empty tensors. # Skip sending empty tensors.
continue continue
# send-allgather: send only a slice, then do allgather.
if (all_gather_group is not None
and tensor.numel() % all_gather_size == 0):
tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]
if tensor.is_cpu: if tensor.is_cpu:
# use metadata_group for CPU tensors # use metadata_group for CPU tensors
torch.distributed.send(tensor, torch.distributed.send(tensor,
...@@ -606,7 +624,8 @@ class GroupCoordinator: ...@@ -606,7 +624,8 @@ class GroupCoordinator:
def recv_tensor_dict( def recv_tensor_dict(
self, self,
src: Optional[int] = None src: Optional[int] = None,
all_gather_group: Optional["GroupCoordinator"] = None,
) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
"""Recv the input tensor dictionary. """Recv the input tensor dictionary.
NOTE: `src` is the local rank of the source rank. NOTE: `src` is the local rank of the source rank.
...@@ -615,6 +634,11 @@ class GroupCoordinator: ...@@ -615,6 +634,11 @@ class GroupCoordinator:
if not torch.distributed.is_initialized() or self.world_size == 1: if not torch.distributed.is_initialized() or self.world_size == 1:
return None return None
all_gather_size = (1 if all_gather_group is None else
all_gather_group.world_size)
all_gather_rank = (0 if all_gather_group is None else
all_gather_group.rank_in_group)
group = self.device_group group = self.device_group
metadata_group = self.cpu_group metadata_group = self.cpu_group
...@@ -631,8 +655,18 @@ class GroupCoordinator: ...@@ -631,8 +655,18 @@ class GroupCoordinator:
device=value.device) device=value.device)
if tensor.numel() == 0: if tensor.numel() == 0:
# Skip broadcasting empty tensors. # Skip broadcasting empty tensors.
_update_nested_dict(tensor_dict, key, tensor) tensor_dict[key] = tensor
continue continue
# send-allgather: send only a slice, then do allgather.
use_all_gather = (all_gather_group is not None
and tensor.numel() % all_gather_size == 0)
if use_all_gather:
orig_shape = tensor.shape
tensor = tensor.reshape(all_gather_size,
-1)[all_gather_rank]
if tensor.is_cpu: if tensor.is_cpu:
# use metadata_group for CPU tensors # use metadata_group for CPU tensors
torch.distributed.recv(tensor, torch.distributed.recv(tensor,
...@@ -643,9 +677,15 @@ class GroupCoordinator: ...@@ -643,9 +677,15 @@ class GroupCoordinator:
torch.distributed.recv(tensor, torch.distributed.recv(tensor,
src=self.ranks[src], src=self.ranks[src],
group=group) group=group)
_update_nested_dict(tensor_dict, key, tensor) if use_all_gather:
# do the allgather
tensor = all_gather_group.all_gather( # type: ignore
tensor, dim=0)
tensor = tensor.reshape(orig_shape)
tensor_dict[key] = tensor
else: else:
_update_nested_dict(tensor_dict, key, value) tensor_dict[key] = value
return tensor_dict return tensor_dict
def barrier(self): def barrier(self):
...@@ -673,8 +713,8 @@ class GroupCoordinator: ...@@ -673,8 +713,8 @@ class GroupCoordinator:
size: torch.Size, size: torch.Size,
dtype: torch.dtype, dtype: torch.dtype,
src: Optional[int] = None) -> torch.Tensor: src: Optional[int] = None) -> torch.Tensor:
"""Receives a tensor from the src rank.""" """Receives a tensor from the source rank."""
"""NOTE: `src` is the local rank of the destination rank.""" """NOTE: `src` is the local rank of the source rank."""
if src is None: if src is None:
src = (self.rank_in_group - 1) % self.world_size src = (self.rank_in_group - 1) % self.world_size
...@@ -717,6 +757,7 @@ def init_world_group(ranks: List[int], local_rank: int, ...@@ -717,6 +757,7 @@ def init_world_group(ranks: List[int], local_rank: int,
torch_distributed_backend=backend, torch_distributed_backend=backend,
use_pynccl=False, use_pynccl=False,
use_custom_allreduce=False, use_custom_allreduce=False,
use_tpu_communicator=False,
) )
...@@ -735,6 +776,7 @@ def init_model_parallel_group( ...@@ -735,6 +776,7 @@ def init_model_parallel_group(
torch_distributed_backend=backend, torch_distributed_backend=backend,
use_pynccl=True, use_pynccl=True,
use_custom_allreduce=use_custom_allreduce, use_custom_allreduce=use_custom_allreduce,
use_tpu_communicator=True,
use_message_queue_broadcaster=use_message_queue_broadcaster, use_message_queue_broadcaster=use_message_queue_broadcaster,
) )
......
...@@ -6,6 +6,11 @@ from typing import Sequence, Tuple ...@@ -6,6 +6,11 @@ from typing import Sequence, Tuple
import torch import torch
import vllm.envs as envs
from vllm.logger import init_logger
logger = init_logger(__name__)
def ensure_divisibility(numerator, denominator): def ensure_divisibility(numerator, denominator):
"""Ensure that numerator is divisible by the denominator.""" """Ensure that numerator is divisible by the denominator."""
...@@ -54,11 +59,28 @@ def get_pp_indices(num_hidden_layers: int, pp_rank: int, ...@@ -54,11 +59,28 @@ def get_pp_indices(num_hidden_layers: int, pp_rank: int,
If the number of layers is not divisible by the number of partitions, If the number of layers is not divisible by the number of partitions,
the last partition will have the remaining layers. the last partition will have the remaining layers.
""" """
layers_per_partition = num_hidden_layers // pp_size partition_list_str = envs.VLLM_PP_LAYER_PARTITION
start_layer = pp_rank * layers_per_partition if partition_list_str is not None:
end_layer = start_layer + layers_per_partition try:
partitions = [
int(layer) for layer in partition_list_str.split(",")
]
except ValueError as err:
raise ValueError("Invalid partition string: {}".format(
partition_list_str)) from err
if len(partitions) != pp_size:
raise ValueError(f"{len(partitions)=} does not match {pp_size=}.")
if sum(partitions) != num_hidden_layers:
raise ValueError(
f"{sum(partitions)=} does not match {num_hidden_layers=}.")
start_layer = sum(partitions[:pp_rank])
end_layer = start_layer + partitions[pp_rank]
else:
layers_per_partition = num_hidden_layers // pp_size
start_layer = pp_rank * layers_per_partition
end_layer = start_layer + layers_per_partition
if pp_rank == pp_size - 1: if pp_rank == pp_size - 1:
end_layer = num_hidden_layers end_layer = num_hidden_layers
return (start_layer, end_layer) return (start_layer, end_layer)
...@@ -632,9 +632,9 @@ class EngineArgs: ...@@ -632,9 +632,9 @@ class EngineArgs:
'--preemption-mode', '--preemption-mode',
type=str, type=str,
default=None, default=None,
help='If \'recompute\', the engine performs preemption by block ' help='If \'recompute\', the engine performs preemption by '
'swapping; If \'swap\', the engine performs preemption by block ' 'recomputing; If \'swap\', the engine performs preemption by '
'swapping.') 'block swapping.')
parser.add_argument( parser.add_argument(
"--served-model-name", "--served-model-name",
...@@ -676,8 +676,8 @@ class EngineArgs: ...@@ -676,8 +676,8 @@ class EngineArgs:
# bitsandbytes quantization needs a specific model loader # bitsandbytes quantization needs a specific model loader
# so we make sure the quant method and the load format are consistent # so we make sure the quant method and the load format are consistent
if (self.quantization == "bitsandbytes" or if (self.quantization == "bitsandbytes" or
self.qlora_adapter_name_or_path is not None) and \ self.qlora_adapter_name_or_path is not None) and \
self.load_format != "bitsandbytes": self.load_format != "bitsandbytes":
raise ValueError( raise ValueError(
"BitsAndBytes quantization and QLoRA adapter only support " "BitsAndBytes quantization and QLoRA adapter only support "
f"'bitsandbytes' load format, but got {self.load_format}") f"'bitsandbytes' load format, but got {self.load_format}")
...@@ -754,10 +754,14 @@ class EngineArgs: ...@@ -754,10 +754,14 @@ class EngineArgs:
use_sliding_window = (model_config.get_sliding_window() use_sliding_window = (model_config.get_sliding_window()
is not None) is not None)
use_spec_decode = self.speculative_model is not None use_spec_decode = self.speculative_model is not None
has_seqlen_agnostic_layers = (
model_config.contains_seqlen_agnostic_layers(
parallel_config))
if (is_gpu and not use_sliding_window and not use_spec_decode if (is_gpu and not use_sliding_window and not use_spec_decode
and not self.enable_lora and not self.enable_lora
and not self.enable_prompt_adapter and not self.enable_prompt_adapter
and not self.enable_prefix_caching): and not self.enable_prefix_caching
and not has_seqlen_agnostic_layers):
self.enable_chunked_prefill = True self.enable_chunked_prefill = True
logger.warning( logger.warning(
"Chunked prefill is enabled by default for models with " "Chunked prefill is enabled by default for models with "
...@@ -788,6 +792,7 @@ class EngineArgs: ...@@ -788,6 +792,7 @@ class EngineArgs:
speculative_max_model_len=self.speculative_max_model_len, speculative_max_model_len=self.speculative_max_model_len,
enable_chunked_prefill=self.enable_chunked_prefill, enable_chunked_prefill=self.enable_chunked_prefill,
use_v2_block_manager=self.use_v2_block_manager, use_v2_block_manager=self.use_v2_block_manager,
disable_log_stats=self.disable_log_stats,
ngram_prompt_lookup_max=self.ngram_prompt_lookup_max, ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
ngram_prompt_lookup_min=self.ngram_prompt_lookup_min, ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
draft_token_acceptance_method=\ draft_token_acceptance_method=\
......
...@@ -7,7 +7,8 @@ from typing import (AsyncIterator, Callable, Dict, Iterable, List, Mapping, ...@@ -7,7 +7,8 @@ from typing import (AsyncIterator, Callable, Dict, Iterable, List, Mapping,
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
import vllm.envs as envs import vllm.envs as envs
from vllm.config import DecodingConfig, EngineConfig, ModelConfig from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
from vllm.core.scheduler import SchedulerOutputs from vllm.core.scheduler import SchedulerOutputs
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_timeout import asyncio_timeout from vllm.engine.async_timeout import asyncio_timeout
...@@ -407,11 +408,15 @@ class AsyncLLMEngine: ...@@ -407,11 +408,15 @@ class AsyncLLMEngine:
from vllm.executor.neuron_executor import NeuronExecutorAsync from vllm.executor.neuron_executor import NeuronExecutorAsync
executor_class = NeuronExecutorAsync executor_class = NeuronExecutorAsync
elif engine_config.device_config.device_type == "tpu": elif engine_config.device_config.device_type == "tpu":
from vllm.executor.tpu_executor import TPUExecutorAsync if distributed_executor_backend == "ray":
executor_class = TPUExecutorAsync initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_tpu_executor import RayTPUExecutorAsync
executor_class = RayTPUExecutorAsync
else:
assert distributed_executor_backend is None
from vllm.executor.tpu_executor import TPUExecutorAsync
executor_class = TPUExecutorAsync
elif engine_config.device_config.device_type == "cpu": elif engine_config.device_config.device_type == "cpu":
assert distributed_executor_backend is None, (
"Distributed execution is not supported with the CPU backend.")
from vllm.executor.cpu_executor import CPUExecutorAsync from vllm.executor.cpu_executor import CPUExecutorAsync
executor_class = CPUExecutorAsync executor_class = CPUExecutorAsync
elif engine_config.device_config.device_type == "openvino": elif engine_config.device_config.device_type == "openvino":
...@@ -924,6 +929,14 @@ class AsyncLLMEngine: ...@@ -924,6 +929,14 @@ class AsyncLLMEngine:
else: else:
return self.engine.get_model_config() return self.engine.get_model_config()
async def get_parallel_config(self) -> ParallelConfig:
"""Get the parallel configuration of the vLLM engine."""
if self.engine_use_ray:
return await self.engine.get_parallel_config.remote( # type: ignore
)
else:
return self.engine.get_parallel_config()
async def get_decoding_config(self) -> DecodingConfig: async def get_decoding_config(self) -> DecodingConfig:
"""Get the decoding configuration of the vLLM engine.""" """Get the decoding configuration of the vLLM engine."""
if self.engine_use_ray: if self.engine_use_ray:
...@@ -932,6 +945,22 @@ class AsyncLLMEngine: ...@@ -932,6 +945,22 @@ class AsyncLLMEngine:
else: else:
return self.engine.get_decoding_config() return self.engine.get_decoding_config()
async def get_scheduler_config(self) -> SchedulerConfig:
"""Get the scheduling configuration of the vLLM engine."""
if self.engine_use_ray:
return await self.engine.get_scheduler_config.remote( # type: ignore
)
else:
return self.engine.get_scheduler_config()
async def get_lora_config(self) -> LoRAConfig:
"""Get the lora configuration of the vLLM engine."""
if self.engine_use_ray:
return await self.engine.get_lora_config.remote( # type: ignore
)
else:
return self.engine.get_lora_config()
async def do_log_stats( async def do_log_stats(
self, self,
scheduler_outputs: Optional[SchedulerOutputs] = None, scheduler_outputs: Optional[SchedulerOutputs] = None,
......
...@@ -5,8 +5,6 @@ from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List, ...@@ -5,8 +5,6 @@ from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List,
from typing import Sequence as GenericSequence from typing import Sequence as GenericSequence
from typing import Set, Type, TypeVar, Union from typing import Set, Type, TypeVar, Union
from transformers import PreTrainedTokenizer
import vllm.envs as envs import vllm.envs as envs
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
EngineConfig, LoadConfig, LoRAConfig, ModelConfig, EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
...@@ -40,8 +38,8 @@ from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context, ...@@ -40,8 +38,8 @@ from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
init_tracer) init_tracer)
from vllm.transformers_utils.config import try_get_generation_config from vllm.transformers_utils.config import try_get_generation_config
from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup, from vllm.transformers_utils.tokenizer_group import (
get_tokenizer_group) AnyTokenizer, BaseTokenizerGroup, init_tokenizer_from_configs)
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message) usage_message)
from vllm.utils import Counter from vllm.utils import Counter
...@@ -408,8 +406,14 @@ class LLMEngine: ...@@ -408,8 +406,14 @@ class LLMEngine:
from vllm.executor.neuron_executor import NeuronExecutor from vllm.executor.neuron_executor import NeuronExecutor
executor_class = NeuronExecutor executor_class = NeuronExecutor
elif engine_config.device_config.device_type == "tpu": elif engine_config.device_config.device_type == "tpu":
from vllm.executor.tpu_executor import TPUExecutor if distributed_executor_backend == "ray":
executor_class = TPUExecutor initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_tpu_executor import RayTPUExecutor
executor_class = RayTPUExecutor
else:
assert distributed_executor_backend is None
from vllm.executor.tpu_executor import TPUExecutor
executor_class = TPUExecutor
elif engine_config.device_config.device_type == "cpu": elif engine_config.device_config.device_type == "cpu":
from vllm.executor.cpu_executor import CPUExecutor from vllm.executor.cpu_executor import CPUExecutor
executor_class = CPUExecutor executor_class = CPUExecutor
...@@ -485,29 +489,21 @@ class LLMEngine: ...@@ -485,29 +489,21 @@ class LLMEngine:
return self.tokenizer return self.tokenizer
def get_tokenizer( def get_tokenizer(
self, self,
lora_request: Optional[LoRARequest] = None lora_request: Optional[LoRARequest] = None,
) -> "PreTrainedTokenizer": ) -> AnyTokenizer:
return self.get_tokenizer_group().get_lora_tokenizer(lora_request) return self.get_tokenizer_group().get_lora_tokenizer(lora_request)
# def get_tokenizer_for_seq(self, # def get_tokenizer_for_seq(self, sequence: Sequence) -> AnyTokenizer:
# sequence: Sequence) -> "PreTrainedTokenizer":
# return self.get_tokenizer_group().get_lora_tokenizer( # return self.get_tokenizer_group().get_lora_tokenizer(
# sequence.lora_request) # sequence.lora_request)
def _init_tokenizer(self, **tokenizer_init_kwargs) -> BaseTokenizerGroup: def _init_tokenizer(self) -> BaseTokenizerGroup:
init_kwargs = dict( return init_tokenizer_from_configs(
tokenizer_id=self.model_config.tokenizer, model_config=self.model_config,
enable_lora=bool(self.lora_config), scheduler_config=self.scheduler_config,
max_num_seqs=self.scheduler_config.max_num_seqs, parallel_config=self.parallel_config,
max_input_length=None, enable_lora=bool(self.lora_config))
tokenizer_mode=self.model_config.tokenizer_mode,
trust_remote_code=self.model_config.trust_remote_code,
revision=self.model_config.tokenizer_revision)
init_kwargs.update(tokenizer_init_kwargs)
return get_tokenizer_group(self.parallel_config.tokenizer_pool_config,
**init_kwargs)
def _verify_args(self) -> None: def _verify_args(self) -> None:
self.model_config.verify_with_parallel_config(self.parallel_config) self.model_config.verify_with_parallel_config(self.parallel_config)
...@@ -769,10 +765,22 @@ class LLMEngine: ...@@ -769,10 +765,22 @@ class LLMEngine:
"""Gets the model configuration.""" """Gets the model configuration."""
return self.model_config return self.model_config
def get_parallel_config(self) -> ParallelConfig:
"""Gets the parallel configuration."""
return self.parallel_config
def get_decoding_config(self) -> DecodingConfig: def get_decoding_config(self) -> DecodingConfig:
"""Gets the decoding configuration.""" """Gets the decoding configuration."""
return self.decoding_config return self.decoding_config
def get_scheduler_config(self) -> SchedulerConfig:
"""Gets the scheduler configuration."""
return self.scheduler_config
def get_lora_config(self) -> LoRAConfig:
"""Gets the LoRA configuration."""
return self.lora_config
def get_num_unfinished_requests(self) -> int: def get_num_unfinished_requests(self) -> int:
"""Gets the number of unfinished requests.""" """Gets the number of unfinished requests."""
return sum(scheduler.get_num_unfinished_seq_groups() return sum(scheduler.get_num_unfinished_seq_groups()
...@@ -963,8 +971,9 @@ class LLMEngine: ...@@ -963,8 +971,9 @@ class LLMEngine:
model_output: Optional[List[SamplerOutput]] = None) -> None: model_output: Optional[List[SamplerOutput]] = None) -> None:
"""Forced log when no requests active.""" """Forced log when no requests active."""
if self.log_stats: if self.log_stats:
stats = self._get_stats(scheduler_outputs, model_output)
for logger in self.stat_loggers.values(): for logger in self.stat_loggers.values():
logger.log(self._get_stats(scheduler_outputs, model_output)) logger.log(stats)
def _get_stats( def _get_stats(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment