"vllm/vscode:/vscode.git/clone" did not exist on "b7cbc254169128a4203d111f3b87edaa17839a32"
Commit e661d594 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.5.4' into v0.5.4-dtk24.04.1

parents 6b16ea2e 4db5176d
......@@ -244,9 +244,15 @@ class ROCmFlashAttentionImpl(AttentionImpl):
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
) -> None:
assert blocksparse_params is None, ValueError(
"ROCFlashAttention does not support blocksparse attention.")
if blocksparse_params is not None:
raise ValueError(
"ROCmFlashAttention does not support blocksparse attention.")
if logits_soft_cap is not None:
raise ValueError(
"ROCmFlashAttention does not support attention logits soft "
"capping.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
......
......@@ -109,9 +109,13 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
) -> None:
assert blocksparse_params is None, ValueError(
"Torch SPDA does not support block-sparse attention.")
if blocksparse_params is not None:
raise ValueError(
"Torch SPDA does not support block-sparse attention.")
if logits_soft_cap is not None:
raise ValueError("Torch SPDA does not support logits soft cap.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
......
......@@ -149,6 +149,15 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int):
"""Build attention metadata with on-device tensors.
Args:
seq_lens: The maybe padded sequence lengths of the input sequences.
query_lens: The query lengths of the input sequences.
cuda_graph_pad_size: The padding size for cuda graph.
-1 if cuda graph is not used.
batch_size: The maybe padded batch size.
"""
for inter_data in self.input_builder.inter_data_list:
self._add_seq_group(inter_data,
self.input_builder.chunked_prefill_enabled)
......@@ -156,15 +165,6 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
device = self.runner.device
use_captured_graph = cuda_graph_pad_size != -1
logits_soft_cap = getattr(self.runner.model_config.hf_config,
"attn_logit_softcapping", None)
if logits_soft_cap is not None:
raise ValueError(
"Please use Flashinfer backend for models with logits_soft_cap "
"(i.e., Gemma-2). Otherwise, the output might be wrong. "
"Set Flashinfer backend by "
"export VLLM_ATTENTION_BACKEND=FLASHINFER.")
max_query_len = max(query_lens)
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
max_decode_seq_len = max(self.curr_seq_lens, default=0)
......@@ -173,7 +173,7 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
if use_captured_graph:
self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
self.block_tables.extend([] * cuda_graph_pad_size)
num_decode_tokens = batch_size + cuda_graph_pad_size
num_decode_tokens = batch_size
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
......
......@@ -408,9 +408,14 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
) -> None:
assert blocksparse_params is None, ValueError(
"XFormer does not support block-sparse attention.")
if blocksparse_params is not None:
raise ValueError(
"XFormers does not support block-sparse attention.")
if logits_soft_cap is not None:
raise ValueError(
"XFormers does not support attention logits soft capping.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
......
......@@ -34,6 +34,7 @@ class Attention(nn.Module):
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
prefix: str = "",
) -> None:
super().__init__()
......@@ -82,7 +83,7 @@ class Attention(nn.Module):
impl_cls = attn_backend.get_impl_cls()
self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
blocksparse_params)
blocksparse_params, logits_soft_cap)
def forward(
self,
......
......@@ -4,7 +4,10 @@ from typing import List, Optional, Tuple
import torch
from vllm import _custom_ops as ops
from vllm.attention.ops.prefix_prefill import context_attention_fwd
from vllm.triton_utils import HAS_TRITON
if HAS_TRITON:
from vllm.attention.ops.prefix_prefill import context_attention_fwd
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE = 512
......@@ -31,7 +34,7 @@ class PagedAttention:
@staticmethod
def get_supported_head_sizes() -> List[int]:
return [64, 80, 96, 112, 128, 192, 256]
return [64, 80, 96, 112, 120, 128, 192, 256]
@staticmethod
def get_kv_cache_shape(
......
......@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, ClassVar, List, Optional, Tuple, Type, Union
import torch
from transformers import PretrainedConfig
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.model_executor.models import ModelRegistry
......@@ -31,6 +32,7 @@ _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
_PP_SUPPORTED_MODELS = [
"AquilaModel",
"AquilaForCausalLM",
"DeepseekV2ForCausalLM",
"InternLMForCausalLM",
"LlamaForCausalLM",
"LLaMAForCausalLM",
......@@ -38,6 +40,10 @@ _PP_SUPPORTED_MODELS = [
"Phi3ForCausalLM",
"GPT2LMHeadModel",
"MixtralForCausalLM",
"NemotronForCausalLM",
"Qwen2ForCausalLM",
"Qwen2MoeForCausalLM",
"QWenLMHeadModel",
]
......@@ -195,13 +201,17 @@ class ModelConfig:
def _parse_quant_hf_config(self):
quant_cfg = getattr(self.hf_config, "quantization_config", None)
if quant_cfg is None:
# compress-tensors uses a "compression_config" key
# compressed-tensors uses a "compression_config" key
quant_cfg = getattr(self.hf_config, "compression_config", None)
return quant_cfg
def _verify_quantization(self) -> None:
supported_quantization = [*QUANTIZATION_METHODS]
rocm_supported_quantization = ["gptq", "squeezellm","awq"]
optimized_quantization_methods = [
"fp8", "marlin", "gptq_marlin_24", "gptq_marlin", "awq_marlin",
"fbgemm_fp8", "compressed_tensors", "compressed-tensors"
]
if self.quantization is not None:
self.quantization = self.quantization.lower()
......@@ -240,9 +250,7 @@ class ModelConfig:
raise ValueError(
f"{self.quantization} quantization is currently not "
f"supported in ROCm.")
if (self.quantization
not in ("fp8", "marlin", "gptq_marlin_24", "gptq_marlin",
"awq_marlin", "fbgemm_fp8", "compressed_tensors")):
if self.quantization not in optimized_quantization_methods:
logger.warning(
"%s quantization is not fully "
"optimized yet. The speed can be slower than "
......@@ -281,6 +289,10 @@ class ModelConfig:
raise ValueError(
"BitAndBytes quantization with TP or PP is not supported yet.")
if self.quantization == "bitsandbytes" and self.enforce_eager is False:
raise ValueError(
"BitAndBytes with enforce_eager = False is not supported yet.")
def get_hf_config_sliding_window(self) -> Optional[int]:
"""Get the sliding window size, or None if disabled."""
......@@ -590,9 +602,11 @@ class LoadConfig:
mainly for profiling.
"tensorizer" will use CoreWeave's tensorizer library for
fast weight loading.
"bitsandbytes" will load nf4 type weights.
ignore_patterns: The list of patterns to ignore when loading the model.
Default to "original/**/*" to avoid repeated loading of llama's
checkpoints.
"""
load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
......@@ -716,7 +730,7 @@ class ParallelConfig:
backend)
self._verify_args()
self.rank = 0
self.rank: int = 0
@property
def use_ray(self) -> bool:
......@@ -842,6 +856,7 @@ class SchedulerConfig:
class DeviceConfig:
device: Optional[torch.device]
def __init__(self, device: str = "auto") -> None:
if device == "auto":
......@@ -892,6 +907,7 @@ class SpeculativeConfig:
speculative_max_model_len: Optional[int],
enable_chunked_prefill: bool,
use_v2_block_manager: bool,
disable_log_stats: bool,
speculative_disable_by_batch_size: Optional[int],
ngram_prompt_lookup_max: Optional[int],
ngram_prompt_lookup_min: Optional[int],
......@@ -1053,7 +1069,7 @@ class SpeculativeConfig:
draft_parallel_config = (
SpeculativeConfig.create_draft_parallel_config(
target_parallel_config,
speculative_draft_tensor_parallel_size))
speculative_draft_tensor_parallel_size, draft_hf_config))
if num_speculative_tokens is None:
raise ValueError(
......@@ -1080,7 +1096,8 @@ class SpeculativeConfig:
typical_acceptance_sampler_posterior_threshold,
typical_acceptance_sampler_posterior_alpha=\
typical_acceptance_sampler_posterior_alpha,
disable_logprobs=disable_logprobs
disable_logprobs=disable_logprobs,
disable_log_stats=disable_log_stats,
)
@staticmethod
......@@ -1121,15 +1138,23 @@ class SpeculativeConfig:
@staticmethod
def create_draft_parallel_config(
target_parallel_config: ParallelConfig,
speculative_draft_tensor_parallel_size: Optional[int]
speculative_draft_tensor_parallel_size: Optional[int],
draft_hf_config: PretrainedConfig,
) -> ParallelConfig:
"""Create a parallel config for use by the draft worker.
This is mostly a copy of the target parallel config, except the tp_size.
"""
if speculative_draft_tensor_parallel_size is None:
speculative_draft_tensor_parallel_size = \
target_parallel_config.tensor_parallel_size
if draft_hf_config.model_type == "mlp_speculator":
speculative_draft_tensor_parallel_size = 1
if target_parallel_config.tensor_parallel_size > 1:
logger.warning(
"MLPSpeculator cannot currently be run with tp>1; "
"setting speculative_draft_tensor_parallel_size=1")
else:
speculative_draft_tensor_parallel_size = \
target_parallel_config.tensor_parallel_size
elif speculative_draft_tensor_parallel_size != 1:
# TODO(wooyeon): allow tp values larger than 1
raise ValueError(
......@@ -1166,6 +1191,7 @@ class SpeculativeConfig:
typical_acceptance_sampler_posterior_threshold: float,
typical_acceptance_sampler_posterior_alpha: float,
disable_logprobs: bool,
disable_log_stats: bool,
):
"""Create a SpeculativeConfig object.
......@@ -1198,6 +1224,8 @@ class SpeculativeConfig:
sampling, target sampling, and after accepted tokens are
determined. If set to False, log probabilities will be
returned.
disable_log_stats: Whether to disable periodic printing of stage
times in speculative decoding.
"""
self.draft_model_config = draft_model_config
self.draft_parallel_config = draft_parallel_config
......@@ -1212,6 +1240,7 @@ class SpeculativeConfig:
self.typical_acceptance_sampler_posterior_alpha = \
typical_acceptance_sampler_posterior_alpha
self.disable_logprobs = disable_logprobs
self.disable_log_stats = disable_log_stats
self._verify_args()
......@@ -1281,7 +1310,7 @@ class LoRAConfig:
long_lora_scaling_factors: Optional[Tuple[float]] = None
def __post_init__(self):
# Keep this in sync with csrc/punica/bgmv/bgmv_config.h
# TODO: Increase the range of rank
possible_max_ranks = (8, 16, 32, 64)
possible_lora_extra_vocab_size = (0, 256, 512)
if self.max_lora_rank not in possible_max_ranks:
......@@ -1527,15 +1556,21 @@ def _get_and_verify_max_len(
"Disabling sliding window is not supported for models "
"model_max_length in the config. Please raise an issue "
"so we can investigate.")
pass
else:
raise ValueError(
msg = (
f"User-specified max_model_len ({max_model_len}) is greater "
"than the derived max_model_len "
f"({max_len_key}={derived_max_model_len} or model_max_length="
f"than the derived max_model_len ({max_len_key}="
f"{derived_max_model_len} or model_max_length="
f"{model_max_length} in model's config.json). This may lead "
"to incorrect model outputs or CUDA errors. Make sure the "
"value is correct and within the model context size.")
"to incorrect model outputs or CUDA errors.")
if envs.VLLM_ALLOW_LONG_MAX_MODEL_LEN:
logger.warning(
"%s Make sure the value is correct and within the "
"model context size.", msg)
else:
raise ValueError(
f"{msg} To allow overriding this maximum, set "
"the env var VLLM_ALLOW_LONG_MAX_MODEL_LEN=1")
return int(max_model_len)
......
from pathlib import Path
from typing import Mapping, Optional
from typing import Mapping, MutableMapping, Optional
from urllib.parse import urlparse
import aiohttp
......@@ -40,7 +40,7 @@ class HTTPConnection:
raise ValueError("Invalid HTTP URL: A valid HTTP URL "
"must have scheme 'http' or 'https'.")
def _headers(self, **extras: str) -> Mapping[str, str]:
def _headers(self, **extras: str) -> MutableMapping[str, str]:
return {"User-Agent": f"vLLM/{VLLM_VERSION}", **extras}
def get_response(
......
......@@ -700,5 +700,5 @@ class BlockSpaceManagerV1(BlockSpaceManager):
def mark_blocks_as_computed(self, seq_group: SequenceGroup):
if self.enable_caching:
for seq in seq_group.seqs_dict.values():
for seq in seq_group.get_seqs():
self.compute_full_blocks_in_seq(seq)
from collections import deque
from typing import Deque
from vllm.sequence import SequenceGroup
class Policy:
def get_priority(
self,
now: float,
seq_group: SequenceGroup,
) -> float:
raise NotImplementedError
def sort_by_priority(
self,
now: float,
seq_groups: Deque[SequenceGroup],
) -> Deque[SequenceGroup]:
return deque(
sorted(
seq_groups,
key=lambda seq_group: self.get_priority(now, seq_group),
reverse=True,
))
class FCFS(Policy):
def get_priority(
self,
now: float,
seq_group: SequenceGroup,
) -> float:
return now - seq_group.metrics.arrival_time
class PolicyFactory:
_POLICY_REGISTRY = {'fcfs': FCFS}
@classmethod
def get_policy(cls, policy_name: str, **kwargs) -> Policy:
return cls._POLICY_REGISTRY[policy_name](**kwargs)
......@@ -8,7 +8,6 @@ from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.core.policy import Policy, PolicyFactory
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
......@@ -313,6 +312,7 @@ class Scheduler:
# Sequence groups finished requests ids since last step iteration.
# It lets the model know that any state associated with these requests
# can and must be released after the current step.
# This is used to evict the finished requests from the Mamba cache.
self._finished_requests_ids: List[str] = list()
# Time at previous scheduling step
self.prev_time = 0.0
......@@ -344,6 +344,16 @@ class Scheduler:
# Add sequence groups to the waiting queue.
self.waiting.append(seq_group)
def _add_seq_group_to_running(self, seq_group: SequenceGroup) -> None:
# Add sequence groups to the running queue.
# Only for testing purposes.
self.running.append(seq_group)
def _add_seq_group_to_swapped(self, seq_group: SequenceGroup) -> None:
# Add sequence groups to the swapped queue.
# Only for testing purposes.
self.swapped.append(seq_group)
def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
"""Aborts a sequence group with the given ID.
......@@ -374,6 +384,7 @@ class Scheduler:
for aborted_group in aborted_groups:
# Remove the sequence group from the state queue.
state_queue.remove(aborted_group)
# Remove the aborted request from the Mamba cache.
self._finished_requests_ids.append(aborted_group.request_id)
for seq in aborted_group.get_seqs():
if seq.is_finished():
......@@ -396,32 +407,26 @@ class Scheduler:
def _schedule_running(
self,
running_queue: deque,
budget: SchedulingBudget,
curr_loras: Optional[Set[int]],
policy: Policy,
enable_chunking: bool = False,
) -> Tuple[deque, SchedulerRunningOutputs]:
) -> SchedulerRunningOutputs:
"""Schedule sequence groups that are running.
Running queue should include decode and chunked prefill requests.
Args:
running_queue: The queue that contains running requests (i.e.,
decodes). The given arguments are NOT in-place modified.
budget: The scheduling budget. The argument is in-place updated
when any decodes are preempted.
curr_loras: Currently batched lora request ids. The argument is
in-place updated when any decodes are preempted.
policy: The sorting policy to sort running_queue.
enable_chunking: If True, seq group can be chunked and only a
chunked number of tokens are scheduled if
`budget.num_batched_tokens` has not enough capacity to schedule
all tokens.
Returns:
A tuple of remaining running queue (should be always 0) after
scheduling and SchedulerRunningOutputs.
SchedulerRunningOutputs.
"""
# Blocks that need to be swapped or copied before model execution.
blocks_to_swap_out: List[Tuple[int, int]] = []
......@@ -434,10 +439,9 @@ class Scheduler:
# NOTE(woosuk): Preemption happens only when there is no available slot
# to keep all the sequence groups in the RUNNING state.
# In this case, the policy is responsible for deciding which sequence
# groups to preempt.
now = time.time()
running_queue = policy.sort_by_priority(now, running_queue)
running_queue = self.running
while running_queue:
seq_group = running_queue[0]
num_running_tokens = self._get_num_new_tokens(
......@@ -501,7 +505,7 @@ class Scheduler:
if curr_loras is not None and seq_group.lora_int_id > 0:
curr_loras.add(seq_group.lora_int_id)
return running_queue, SchedulerRunningOutputs(
return SchedulerRunningOutputs(
decode_seq_groups=decode_seq_groups,
prefill_seq_groups=prefill_seq_groups,
preempted=preempted,
......@@ -513,12 +517,10 @@ class Scheduler:
def _schedule_swapped(
self,
swapped_queue: deque,
budget: SchedulingBudget,
curr_loras: Optional[Set[int]],
policy: Policy,
enable_chunking: bool = False,
) -> Tuple[deque, SchedulerSwappedInOutputs]:
) -> SchedulerSwappedInOutputs:
"""Schedule sequence groups that are swapped out.
It schedules swapped requests as long as it fits `budget` and
......@@ -526,20 +528,16 @@ class Scheduler:
`budget` and `curr_loras` are updated based on scheduled seq_groups.
Args:
swapped_queue: The queue that contains swapped out requests.
The given arguments are NOT in-place modified.
budget: The scheduling budget. The argument is in-place updated
when any requests are swapped in.
curr_loras: Currently batched lora request ids. The argument is
in-place updated when any requests are swapped in.
policy: The sorting policy to sort swapped_queue.
enable_chunking: If True, seq group can be chunked and only a
chunked number of tokens are scheduled if
`budget.num_batched_tokens` has not enough capacity to schedule
all tokens.
Returns:
A tuple of remaining swapped_queue after scheduling and
SchedulerSwappedInOutputs.
"""
# Blocks that need to be swapped or copied before model execution.
......@@ -547,10 +545,10 @@ class Scheduler:
blocks_to_copy: List[Tuple[int, int]] = []
decode_seq_groups: List[ScheduledSequenceGroup] = []
prefill_seq_groups: List[ScheduledSequenceGroup] = []
now = time.time()
swapped_queue = policy.sort_by_priority(now, swapped_queue)
infeasible_seq_groups: List[SequenceGroup] = []
swapped_queue = self.swapped
leftover_swapped: Deque[SequenceGroup] = deque()
while swapped_queue:
seq_group = swapped_queue[0]
......@@ -615,7 +613,7 @@ class Scheduler:
swapped_queue.extendleft(leftover_swapped)
return swapped_queue, SchedulerSwappedInOutputs(
return SchedulerSwappedInOutputs(
decode_seq_groups=decode_seq_groups,
prefill_seq_groups=prefill_seq_groups,
blocks_to_swap_in=blocks_to_swap_in,
......@@ -642,11 +640,10 @@ class Scheduler:
def _schedule_prefills(
self,
waiting_queue: deque,
budget: SchedulingBudget,
curr_loras: Optional[Set[int]],
enable_chunking: bool = False,
) -> Tuple[deque, SchedulerPrefillOutputs]:
) -> SchedulerPrefillOutputs:
"""Schedule sequence groups that are in prefill stage.
Note that the current scheduler treats PREEMPTED_FOR_RECOMPUTE
......@@ -658,8 +655,6 @@ class Scheduler:
`budget` and `curr_loras` are updated based on scheduled seq_groups.
Args:
waiting_queue: The queue that contains prefill requests.
The given arguments are NOT in-place modified.
budget: The scheduling budget. The argument is in-place updated
when any requests are scheduled.
curr_loras: Currently batched lora request ids. The argument is
......@@ -670,14 +665,12 @@ class Scheduler:
all tokens.
Returns:
A tuple of remaining waiting_queue after scheduling and
SchedulerSwappedInOutputs.
"""
ignored_seq_groups: List[SequenceGroup] = []
seq_groups: List[SequenceGroup] = []
# We don't sort waiting queue because we assume it is sorted.
# Copy the queue so that the input queue is not modified.
waiting_queue = deque([s for s in waiting_queue])
waiting_queue = self.waiting
leftover_waiting_sequences: Deque[SequenceGroup] = deque()
while self._passed_delay(time.time()) and waiting_queue:
......@@ -756,7 +749,7 @@ class Scheduler:
if len(seq_groups) > 0:
self.prev_prompt = True
return waiting_queue, SchedulerPrefillOutputs(
return SchedulerPrefillOutputs(
seq_groups=seq_groups,
ignored_seq_groups=ignored_seq_groups,
num_lookahead_slots=self._get_num_lookahead_slots(is_prefill=True))
......@@ -783,53 +776,43 @@ class Scheduler:
seq_group.lora_int_id for seq_group in self.running
if seq_group.lora_int_id > 0) if self.lora_enabled else None
remaining_waiting, prefills = (self.waiting,
SchedulerPrefillOutputs.create_empty())
remaining_running, running_scheduled = (
self.running, SchedulerRunningOutputs.create_empty())
remaining_swapped, swapped_in = (
self.swapped, SchedulerSwappedInOutputs.create_empty())
prefills = SchedulerPrefillOutputs.create_empty()
running_scheduled = SchedulerRunningOutputs.create_empty()
swapped_in = SchedulerSwappedInOutputs.create_empty()
# If any requests are swapped, prioritized swapped requests.
if not self.swapped:
remaining_waiting, prefills = self._schedule_prefills(
self.waiting, budget, curr_loras, enable_chunking=False)
prefills = self._schedule_prefills(budget,
curr_loras,
enable_chunking=False)
fcfs_policy = PolicyFactory.get_policy(policy_name="fcfs")
# Don't schedule decodes if prefills are scheduled.
# NOTE: If `_schedule_prefills` doesn't enable chunking, self.running
# only contains decode requests, not chunked prefills.
if len(prefills.seq_groups) == 0:
remaining_running, running_scheduled = self._schedule_running(
self.running,
budget,
curr_loras,
fcfs_policy,
enable_chunking=False)
running_scheduled = self._schedule_running(budget,
curr_loras,
enable_chunking=False)
# If any sequence group is preempted, do not swap in any sequence
# group. because it means there's no slot for new running requests.
if len(running_scheduled.preempted) + len(
running_scheduled.swapped_out) == 0:
remaining_swapped, swapped_in = self._schedule_swapped(
self.swapped, budget, curr_loras, fcfs_policy)
swapped_in = self._schedule_swapped(budget, curr_loras)
assert (budget.num_batched_tokens <=
self.scheduler_config.max_num_batched_tokens)
assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs
# Update waiting requests.
self.waiting = remaining_waiting
self.waiting.extendleft(running_scheduled.preempted)
# Update new running requests.
self.running = remaining_running
self.running.extend([s.seq_group for s in prefills.seq_groups])
self.running.extend(
[s.seq_group for s in running_scheduled.decode_seq_groups])
self.running.extend(
[s.seq_group for s in swapped_in.decode_seq_groups])
# Update swapped requests.
self.swapped = remaining_swapped
self.swapped.extend(running_scheduled.swapped_out)
preempted = (len(running_scheduled.preempted) +
len(running_scheduled.swapped_out))
......@@ -875,42 +858,32 @@ class Scheduler:
)
curr_loras: Set[int] = set()
remaining_waiting, prefills = (self.waiting,
SchedulerPrefillOutputs.create_empty())
remaining_running, running_scheduled = (
self.running, SchedulerRunningOutputs.create_empty())
remaining_swapped, swapped_in = (
self.swapped, SchedulerSwappedInOutputs.create_empty())
prefills = SchedulerPrefillOutputs.create_empty()
swapped_in = SchedulerSwappedInOutputs.create_empty()
# Decoding should be always scheduled first by fcfs.
fcfs_policy = PolicyFactory.get_policy(policy_name="fcfs")
remaining_running, running_scheduled = self._schedule_running(
self.running,
budget,
curr_loras,
fcfs_policy,
enable_chunking=True)
running_scheduled = self._schedule_running(budget,
curr_loras,
enable_chunking=True)
# Schedule swapped out requests.
# If preemption happens, it means we don't have space for swap-in.
if len(running_scheduled.preempted) + len(
running_scheduled.swapped_out) == 0:
remaining_swapped, swapped_in = self._schedule_swapped(
self.swapped, budget, curr_loras, fcfs_policy)
swapped_in = self._schedule_swapped(budget, curr_loras)
# Schedule new prefills.
remaining_waiting, prefills = self._schedule_prefills(
self.waiting, budget, curr_loras, enable_chunking=True)
prefills = self._schedule_prefills(budget,
curr_loras,
enable_chunking=True)
assert (budget.num_batched_tokens <=
self.scheduler_config.max_num_batched_tokens)
assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs
# Update waiting requests.
self.waiting = remaining_waiting
self.waiting.extendleft(running_scheduled.preempted)
# Update new running requests.
self.running = remaining_running
self.running.extend([s.seq_group for s in prefills.seq_groups])
self.running.extend(
[s.seq_group for s in running_scheduled.decode_seq_groups])
......@@ -921,7 +894,6 @@ class Scheduler:
self.running.extend(
[s.seq_group for s in swapped_in.prefill_seq_groups])
# Update swapped requests.
self.swapped = remaining_swapped
self.swapped.extend(running_scheduled.swapped_out)
return SchedulerOutputs(
scheduled_seq_groups=(prefills.seq_groups +
......@@ -1029,7 +1001,6 @@ class Scheduler:
token_chunk_size=token_chunk_size,
lora_request=seq_group.lora_request,
computed_block_nums=common_computed_block_nums,
state=seq_group.state,
# `multi_modal_data` will only be present for the 1st comm
# between engine and worker.
# the subsequent comms can still use delta, but
......@@ -1058,13 +1029,16 @@ class Scheduler:
self.block_manager.free(seq)
def free_finished_seq_groups(self) -> None:
for queue in [self.running, self.swapped, self.waiting]:
self._finished_requests_ids += [
seq_group.request_id for seq_group in queue
if seq_group.is_finished()
]
self.running = deque(seq_group for seq_group in self.running
if not seq_group.is_finished())
remaining: Deque[SequenceGroup] = deque()
for seq_group in self.running:
if seq_group.is_finished():
# Add the finished requests to the finished requests list.
# This list will be used to update the Mamba cache in the
# next step.
self._finished_requests_ids.append(seq_group.request_id)
else:
remaining.append(seq_group)
self.running = remaining
def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None:
self.block_manager.allocate(seq_group)
......
......@@ -4,9 +4,6 @@ convenient for use when we just need to call a few functions.
"""
import ctypes
import glob
import os
import sys
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
......@@ -36,24 +33,25 @@ class Function:
argtypes: List[Any]
def get_pytorch_default_cudart_library_path() -> str:
# code borrowed from https://github.com/pytorch/pytorch/blob/1cae60a87e5bdda8bcf55724a862eeed98a9747e/torch/__init__.py#L284 # noqa
lib_folder = "cuda_runtime"
lib_name = "libcudart.so.*[0-9]"
lib_path = None
for path in sys.path:
nvidia_path = os.path.join(path, "nvidia")
if not os.path.exists(nvidia_path):
continue
candidate_lib_paths = glob.glob(
os.path.join(nvidia_path, lib_folder, "lib", lib_name))
if candidate_lib_paths and not lib_path:
lib_path = candidate_lib_paths[0]
if lib_path:
break
if not lib_path:
raise ValueError(f"{lib_name} not found in the system path {sys.path}")
return lib_path
def find_loaded_library(lib_name) -> Optional[str]:
"""
According to according to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html,
the file `/proc/self/maps` contains the memory maps of the process, which includes the
shared libraries loaded by the process. We can use this file to find the path of the
a loaded library.
""" # noqa
found = False
with open("/proc/self/maps") as f:
for line in f:
if lib_name in line:
found = True
break
if not found:
# the library is not loaded in the current process
return None
start = line.index("/")
path = line[start:].strip()
return path
class CudaRTLibrary:
......@@ -100,7 +98,9 @@ class CudaRTLibrary:
def __init__(self, so_file: Optional[str] = None):
if so_file is None:
so_file = get_pytorch_default_cudart_library_path()
so_file = find_loaded_library("libcudart.so")
assert so_file is not None, \
"libcudart.so is not loaded in the current process"
if so_file not in CudaRTLibrary.path_to_library_cache:
lib = ctypes.CDLL(so_file)
CudaRTLibrary.path_to_library_cache[so_file] = lib
......
......@@ -145,6 +145,7 @@ def can_actually_p2p(
p_tgt.start()
p_src.join()
p_tgt.join()
assert p_src.exitcode == 0 and p_tgt.exitcode == 0
result: List[bool] = []
for src, tgt in zip(batch_src, batch_tgt):
a = result_queue.get()
......@@ -221,7 +222,8 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool:
# wrap raised exception to provide more information
raise RuntimeError(
f"Error happened when batch testing "
f"peer-to-peer access from {batch_src} to {batch_tgt}") from e
f"peer-to-peer access from {batch_src} to {batch_tgt}:\n"
f"{returned.stderr.decode()}") from e
result = pickle.loads(returned.stdout)
for _i, _j, r in zip(batch_src, batch_tgt, result):
cache[f"{_i}->{_j}"] = r
......
......@@ -9,7 +9,7 @@ from unittest.mock import patch
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from zmq import PUB, REP, REQ, SUB, SUBSCRIBE, Context # type: ignore
from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore
import vllm.envs as envs
from vllm.logger import init_logger
......@@ -153,9 +153,7 @@ class Handle:
buffer: Optional[ShmRingBuffer] = None
local_subscribe_port: Optional[int] = None
local_sync_port: Optional[int] = None
remote_subscribe_port: Optional[int] = None
remote_sync_port: Optional[int] = None
class MessageQueue:
......@@ -189,38 +187,36 @@ class MessageQueue:
self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes,
max_chunks)
self.local_socket = context.socket(PUB)
# XPUB is very similar to PUB,
# except that it can receive subscription messages
# to confirm the number of subscribers
self.local_socket = context.socket(XPUB)
# set the verbose option so that we can receive every subscription
# message. otherwise, we will only receive the first subscription
# see http://api.zeromq.org/3-3:zmq-setsockopt for more details
self.local_socket.setsockopt(XPUB_VERBOSE, True)
local_subscribe_port = get_open_port()
self.local_socket.bind(f"tcp://*:{local_subscribe_port}")
self.local_sync_socket = context.socket(REP)
local_sync_port = get_open_port()
self.local_sync_socket.bind(f"tcp://*:{local_sync_port}")
self.current_idx = 0
else:
self.buffer = None # type: ignore
local_subscribe_port = None
local_sync_port = None
self.local_socket = None
self.local_sync_socket = None
self.current_idx = -1
if n_remote_reader > 0:
# for remote readers, we will:
# create a publish-subscribe socket to communicate large data
self.remote_socket = context.socket(PUB)
self.remote_socket = context.socket(XPUB)
self.remote_socket.setsockopt(XPUB_VERBOSE, True)
remote_subscribe_port = get_open_port()
self.remote_socket.bind(f"tcp://*:{remote_subscribe_port}")
self.remote_sync_socket = context.socket(REP)
remote_sync_port = get_open_port()
self.remote_sync_socket.bind(f"tcp://*:{remote_sync_port}")
else:
remote_subscribe_port = None
remote_sync_port = None
self.remote_socket = None
self.remote_sync_socket = None
self._is_writer = True
self._is_local_reader = False
......@@ -233,9 +229,7 @@ class MessageQueue:
local_reader_ranks=local_reader_ranks,
buffer=self.buffer,
local_subscribe_port=local_subscribe_port,
local_sync_port=local_sync_port,
remote_subscribe_port=remote_subscribe_port,
remote_sync_port=remote_sync_port,
)
logger.info("vLLM message queue communication handle: %s", self.handle)
......@@ -264,12 +258,7 @@ class MessageQueue:
self.local_socket.connect(
f"tcp://{handle.connect_ip}:{handle.local_subscribe_port}")
self.local_sync_socket = context.socket(REQ)
self.local_sync_socket.connect(
f"tcp://{handle.connect_ip}:{handle.local_sync_port}")
self.remote_socket = None
self.remote_sync_socket = None
else:
self.buffer = None # type: ignore
self.current_idx = -1
......@@ -278,17 +267,12 @@ class MessageQueue:
self._is_remote_reader = True
self.local_socket = None
self.local_sync_socket = None
self.remote_socket = context.socket(SUB)
self.remote_socket.setsockopt_string(SUBSCRIBE, "")
self.remote_socket.connect(
f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}")
self.remote_sync_socket = context.socket(REQ)
self.remote_sync_socket.connect(
f"tcp://{handle.connect_ip}:{handle.remote_sync_port}")
return self
def wait_until_ready(self):
......@@ -300,29 +284,27 @@ class MessageQueue:
# local readers
for i in range(self.n_local_reader):
recv = self.local_sync_socket.recv()
assert recv == b"READY"
self.local_sync_socket.send(b"READY")
# wait for subscription messages from all local readers
self.local_socket.recv()
if self.n_local_reader > 0:
# send a message to all local readers
# to make sure the publish channel is working
self.local_socket.send(b"READY")
# remote readers
for i in range(self.n_remote_reader):
recv = self.remote_sync_socket.recv()
assert recv == b"READY"
self.remote_sync_socket.send(b"READY")
# wait for subscription messages from all remote readers
self.remote_socket.recv()
if self.n_remote_reader > 0:
# send a message to all remote readers
# to make sure the publish channel is working
self.remote_socket.send(b"READY")
elif self._is_local_reader:
self.local_sync_socket.send(b"READY")
recv = self.local_sync_socket.recv()
assert recv == b"READY"
# wait for the writer to send a message
recv = self.local_socket.recv()
assert recv == b"READY"
elif self._is_remote_reader:
self.remote_sync_socket.send(b"READY")
recv = self.remote_sync_socket.recv()
assert recv == b"READY"
# wait for the writer to send a message
recv = self.remote_socket.recv()
assert recv == b"READY"
......
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from vllm.platforms import current_platform
if current_platform.is_tpu():
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from torch_xla._internal import pjrt
class TpuCommunicator:
def __init__(self, group: ProcessGroup):
if not current_platform.is_tpu():
self.disabled = True
return
self.disabled = False
local_rank = dist.get_rank(group)
world_size = dist.get_world_size(group)
pjrt.initialize_multiprocess(local_rank, world_size)
xr._init_world_size_ordinal()
def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
return xm.all_reduce(xm.REDUCE_SUM, x)
def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
assert dim == -1, "TPUs only support dim=-1 for all-gather."
return xm.all_gather(x, dim=dim)
......@@ -45,22 +45,16 @@ TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
def _split_tensor_dict(
tensor_dict: Dict[str, Union[torch.Tensor, Any]],
prefix: str = "") -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
tensor_dict: Dict[str, Union[torch.Tensor, Any]]
) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
"""Split the tensor dictionary into two parts:
1. A list of (key, value) pairs. If the value is a tensor, it is replaced
by its metadata.
2. A list of tensors.
If the Tensor is nested under `tensor_dict["key1"]["key2"]`, the key of its
metadata will be "key1%key2".
"""
metadata_list: List[Tuple[str, Any]] = []
tensor_list = []
tensor_list: List[torch.Tensor] = []
for key, value in tensor_dict.items():
assert "%" not in key, (
"Avoid having '%' in key "
"as it is used as a separator for nested entries.")
if isinstance(value, torch.Tensor):
# Note: we cannot use `value.device` here,
# because it contains not only the device type but also the device
......@@ -68,31 +62,13 @@ def _split_tensor_dict(
# receiving side will set the device index.
device = value.device.type
metadata_list.append(
(prefix + key, TensorMetadata(device, value.dtype,
value.size())))
(key, TensorMetadata(device, value.dtype, value.size())))
tensor_list.append(value)
elif isinstance(value, dict):
if len(value) == 0:
metadata_list.append((prefix + key, value))
inner_metadata_list, inner_tensor_list = _split_tensor_dict(
value, prefix + key + "%")
metadata_list.extend(inner_metadata_list)
tensor_list.extend(inner_tensor_list)
else:
metadata_list.append((prefix + key, value))
metadata_list.append((key, value))
return metadata_list, tensor_list
def _update_nested_dict(nested_dict, flattened_key, value):
key_splits = flattened_key.split("%")
cur_dict = nested_dict
for k in key_splits[:-1]:
if k not in cur_dict:
cur_dict[k] = {}
cur_dict = cur_dict[k]
cur_dict[key_splits[-1]] = value
class GroupCoordinator:
"""
PyTorch ProcessGroup wrapper for a group of processes.
......@@ -133,6 +109,7 @@ class GroupCoordinator:
torch_distributed_backend: Union[str, Backend],
use_pynccl: bool,
use_custom_allreduce: bool,
use_tpu_communicator: bool,
use_message_queue_broadcaster: bool = False,
):
......@@ -164,6 +141,7 @@ class GroupCoordinator:
self.use_pynccl = use_pynccl
self.use_custom_allreduce = use_custom_allreduce
self.use_tpu_communicator = use_tpu_communicator
# lazy import to avoid documentation build error
from vllm.distributed.device_communicators.custom_all_reduce import (
......@@ -190,6 +168,12 @@ class GroupCoordinator:
else:
self.ca_comm = None
from vllm.distributed.device_communicators.tpu_communicator import (
TpuCommunicator)
self.tpu_communicator: Optional[TpuCommunicator]
if use_tpu_communicator and self.world_size > 1:
self.tpu_communicator = TpuCommunicator(group=self.cpu_group)
from vllm.distributed.device_communicators.shm_broadcast import (
MessageQueue)
self.mq_broadcaster: Optional[MessageQueue] = None
......@@ -243,6 +227,13 @@ class GroupCoordinator:
ca_comm = self.ca_comm
maybe_ca_context = nullcontext(
) if ca_comm is None else ca_comm.capture()
# ensure all initialization operations complete before attempting to
# capture the graph on another stream
curr_stream = torch.cuda.current_stream()
if curr_stream != stream:
stream.wait_stream(curr_stream)
with torch.cuda.stream(stream), maybe_ca_context:
# In graph mode, we have to be very careful about the collective
# operations. The current status is:
......@@ -282,6 +273,12 @@ class GroupCoordinator:
# Bypass the function if we are using only 1 GPU.
if self.world_size == 1:
return input_
# For TPUs, use TPU communicator.
tpu_comm = self.tpu_communicator
if tpu_comm is not None and not tpu_comm.disabled:
return tpu_comm.all_reduce(input_)
if ca_comm is not None:
out = ca_comm.custom_all_reduce(input_)
if out is not None:
......@@ -289,6 +286,9 @@ class GroupCoordinator:
pynccl_comm = self.pynccl_comm
if (pynccl_comm is not None and not pynccl_comm.disabled):
pynccl_comm.all_reduce(input_)
elif input_.is_cpu:
import intel_extension_for_pytorch as ipex
ipex.distributed.all_reduce(input_, group=self.device_group)
else:
torch.distributed.all_reduce(input_, group=self.device_group)
return input_
......@@ -300,6 +300,12 @@ class GroupCoordinator:
return input_
assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
# For TPUs, use TPU communicator.
tpu_comm = self.tpu_communicator
if tpu_comm is not None and not tpu_comm.disabled:
return tpu_comm.all_gather(input_, dim)
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
......@@ -536,7 +542,7 @@ class GroupCoordinator:
device=value.device)
if tensor.numel() == 0:
# Skip broadcasting empty tensors.
_update_nested_dict(tensor_dict, key, tensor)
tensor_dict[key] = tensor
continue
if tensor.is_cpu:
# use metadata_group for CPU tensors
......@@ -553,9 +559,9 @@ class GroupCoordinator:
group=group,
async_op=True)
async_handles.append(handle)
_update_nested_dict(tensor_dict, key, tensor)
tensor_dict[key] = tensor
else:
_update_nested_dict(tensor_dict, key, value)
tensor_dict[key] = value
for async_handle in async_handles:
async_handle.wait()
return tensor_dict
......@@ -563,7 +569,8 @@ class GroupCoordinator:
def send_tensor_dict(
self,
tensor_dict: Dict[str, Union[torch.Tensor, Any]],
dst: Optional[int] = None
dst: Optional[int] = None,
all_gather_group: Optional["GroupCoordinator"] = None,
) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
"""Send the input tensor dictionary.
NOTE: `dst` is the local rank of the source rank.
......@@ -572,6 +579,11 @@ class GroupCoordinator:
if not torch.distributed.is_initialized() or self.world_size == 1:
return tensor_dict
all_gather_size = (1 if all_gather_group is None else
all_gather_group.world_size)
all_gather_rank = (0 if all_gather_group is None else
all_gather_group.rank_in_group)
group = self.device_group
metadata_group = self.cpu_group
......@@ -592,6 +604,12 @@ class GroupCoordinator:
if tensor.numel() == 0:
# Skip sending empty tensors.
continue
# send-allgather: send only a slice, then do allgather.
if (all_gather_group is not None
and tensor.numel() % all_gather_size == 0):
tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]
if tensor.is_cpu:
# use metadata_group for CPU tensors
torch.distributed.send(tensor,
......@@ -606,7 +624,8 @@ class GroupCoordinator:
def recv_tensor_dict(
self,
src: Optional[int] = None
src: Optional[int] = None,
all_gather_group: Optional["GroupCoordinator"] = None,
) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
"""Recv the input tensor dictionary.
NOTE: `src` is the local rank of the source rank.
......@@ -615,6 +634,11 @@ class GroupCoordinator:
if not torch.distributed.is_initialized() or self.world_size == 1:
return None
all_gather_size = (1 if all_gather_group is None else
all_gather_group.world_size)
all_gather_rank = (0 if all_gather_group is None else
all_gather_group.rank_in_group)
group = self.device_group
metadata_group = self.cpu_group
......@@ -631,8 +655,18 @@ class GroupCoordinator:
device=value.device)
if tensor.numel() == 0:
# Skip broadcasting empty tensors.
_update_nested_dict(tensor_dict, key, tensor)
tensor_dict[key] = tensor
continue
# send-allgather: send only a slice, then do allgather.
use_all_gather = (all_gather_group is not None
and tensor.numel() % all_gather_size == 0)
if use_all_gather:
orig_shape = tensor.shape
tensor = tensor.reshape(all_gather_size,
-1)[all_gather_rank]
if tensor.is_cpu:
# use metadata_group for CPU tensors
torch.distributed.recv(tensor,
......@@ -643,9 +677,15 @@ class GroupCoordinator:
torch.distributed.recv(tensor,
src=self.ranks[src],
group=group)
_update_nested_dict(tensor_dict, key, tensor)
if use_all_gather:
# do the allgather
tensor = all_gather_group.all_gather( # type: ignore
tensor, dim=0)
tensor = tensor.reshape(orig_shape)
tensor_dict[key] = tensor
else:
_update_nested_dict(tensor_dict, key, value)
tensor_dict[key] = value
return tensor_dict
def barrier(self):
......@@ -673,8 +713,8 @@ class GroupCoordinator:
size: torch.Size,
dtype: torch.dtype,
src: Optional[int] = None) -> torch.Tensor:
"""Receives a tensor from the src rank."""
"""NOTE: `src` is the local rank of the destination rank."""
"""Receives a tensor from the source rank."""
"""NOTE: `src` is the local rank of the source rank."""
if src is None:
src = (self.rank_in_group - 1) % self.world_size
......@@ -717,6 +757,7 @@ def init_world_group(ranks: List[int], local_rank: int,
torch_distributed_backend=backend,
use_pynccl=False,
use_custom_allreduce=False,
use_tpu_communicator=False,
)
......@@ -735,6 +776,7 @@ def init_model_parallel_group(
torch_distributed_backend=backend,
use_pynccl=True,
use_custom_allreduce=use_custom_allreduce,
use_tpu_communicator=True,
use_message_queue_broadcaster=use_message_queue_broadcaster,
)
......
......@@ -6,6 +6,11 @@ from typing import Sequence, Tuple
import torch
import vllm.envs as envs
from vllm.logger import init_logger
logger = init_logger(__name__)
def ensure_divisibility(numerator, denominator):
"""Ensure that numerator is divisible by the denominator."""
......@@ -54,11 +59,28 @@ def get_pp_indices(num_hidden_layers: int, pp_rank: int,
If the number of layers is not divisible by the number of partitions,
the last partition will have the remaining layers.
"""
layers_per_partition = num_hidden_layers // pp_size
start_layer = pp_rank * layers_per_partition
end_layer = start_layer + layers_per_partition
partition_list_str = envs.VLLM_PP_LAYER_PARTITION
if partition_list_str is not None:
try:
partitions = [
int(layer) for layer in partition_list_str.split(",")
]
except ValueError as err:
raise ValueError("Invalid partition string: {}".format(
partition_list_str)) from err
if len(partitions) != pp_size:
raise ValueError(f"{len(partitions)=} does not match {pp_size=}.")
if sum(partitions) != num_hidden_layers:
raise ValueError(
f"{sum(partitions)=} does not match {num_hidden_layers=}.")
start_layer = sum(partitions[:pp_rank])
end_layer = start_layer + partitions[pp_rank]
else:
layers_per_partition = num_hidden_layers // pp_size
start_layer = pp_rank * layers_per_partition
end_layer = start_layer + layers_per_partition
if pp_rank == pp_size - 1:
end_layer = num_hidden_layers
if pp_rank == pp_size - 1:
end_layer = num_hidden_layers
return (start_layer, end_layer)
......@@ -632,9 +632,9 @@ class EngineArgs:
'--preemption-mode',
type=str,
default=None,
help='If \'recompute\', the engine performs preemption by block '
'swapping; If \'swap\', the engine performs preemption by block '
'swapping.')
help='If \'recompute\', the engine performs preemption by '
'recomputing; If \'swap\', the engine performs preemption by '
'block swapping.')
parser.add_argument(
"--served-model-name",
......@@ -676,8 +676,8 @@ class EngineArgs:
# bitsandbytes quantization needs a specific model loader
# so we make sure the quant method and the load format are consistent
if (self.quantization == "bitsandbytes" or
self.qlora_adapter_name_or_path is not None) and \
self.load_format != "bitsandbytes":
self.qlora_adapter_name_or_path is not None) and \
self.load_format != "bitsandbytes":
raise ValueError(
"BitsAndBytes quantization and QLoRA adapter only support "
f"'bitsandbytes' load format, but got {self.load_format}")
......@@ -754,10 +754,14 @@ class EngineArgs:
use_sliding_window = (model_config.get_sliding_window()
is not None)
use_spec_decode = self.speculative_model is not None
has_seqlen_agnostic_layers = (
model_config.contains_seqlen_agnostic_layers(
parallel_config))
if (is_gpu and not use_sliding_window and not use_spec_decode
and not self.enable_lora
and not self.enable_prompt_adapter
and not self.enable_prefix_caching):
and not self.enable_prefix_caching
and not has_seqlen_agnostic_layers):
self.enable_chunked_prefill = True
logger.warning(
"Chunked prefill is enabled by default for models with "
......@@ -788,6 +792,7 @@ class EngineArgs:
speculative_max_model_len=self.speculative_max_model_len,
enable_chunked_prefill=self.enable_chunked_prefill,
use_v2_block_manager=self.use_v2_block_manager,
disable_log_stats=self.disable_log_stats,
ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
draft_token_acceptance_method=\
......
......@@ -7,7 +7,8 @@ from typing import (AsyncIterator, Callable, Dict, Iterable, List, Mapping,
from transformers import PreTrainedTokenizer
import vllm.envs as envs
from vllm.config import DecodingConfig, EngineConfig, ModelConfig
from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
from vllm.core.scheduler import SchedulerOutputs
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_timeout import asyncio_timeout
......@@ -407,11 +408,15 @@ class AsyncLLMEngine:
from vllm.executor.neuron_executor import NeuronExecutorAsync
executor_class = NeuronExecutorAsync
elif engine_config.device_config.device_type == "tpu":
from vllm.executor.tpu_executor import TPUExecutorAsync
executor_class = TPUExecutorAsync
if distributed_executor_backend == "ray":
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_tpu_executor import RayTPUExecutorAsync
executor_class = RayTPUExecutorAsync
else:
assert distributed_executor_backend is None
from vllm.executor.tpu_executor import TPUExecutorAsync
executor_class = TPUExecutorAsync
elif engine_config.device_config.device_type == "cpu":
assert distributed_executor_backend is None, (
"Distributed execution is not supported with the CPU backend.")
from vllm.executor.cpu_executor import CPUExecutorAsync
executor_class = CPUExecutorAsync
elif engine_config.device_config.device_type == "openvino":
......@@ -924,6 +929,14 @@ class AsyncLLMEngine:
else:
return self.engine.get_model_config()
async def get_parallel_config(self) -> ParallelConfig:
"""Get the parallel configuration of the vLLM engine."""
if self.engine_use_ray:
return await self.engine.get_parallel_config.remote( # type: ignore
)
else:
return self.engine.get_parallel_config()
async def get_decoding_config(self) -> DecodingConfig:
"""Get the decoding configuration of the vLLM engine."""
if self.engine_use_ray:
......@@ -932,6 +945,22 @@ class AsyncLLMEngine:
else:
return self.engine.get_decoding_config()
async def get_scheduler_config(self) -> SchedulerConfig:
"""Get the scheduling configuration of the vLLM engine."""
if self.engine_use_ray:
return await self.engine.get_scheduler_config.remote( # type: ignore
)
else:
return self.engine.get_scheduler_config()
async def get_lora_config(self) -> LoRAConfig:
"""Get the lora configuration of the vLLM engine."""
if self.engine_use_ray:
return await self.engine.get_lora_config.remote( # type: ignore
)
else:
return self.engine.get_lora_config()
async def do_log_stats(
self,
scheduler_outputs: Optional[SchedulerOutputs] = None,
......
......@@ -5,8 +5,6 @@ from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List,
from typing import Sequence as GenericSequence
from typing import Set, Type, TypeVar, Union
from transformers import PreTrainedTokenizer
import vllm.envs as envs
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
......@@ -40,8 +38,8 @@ from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
init_tracer)
from vllm.transformers_utils.config import try_get_generation_config
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
get_tokenizer_group)
from vllm.transformers_utils.tokenizer_group import (
AnyTokenizer, BaseTokenizerGroup, init_tokenizer_from_configs)
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message)
from vllm.utils import Counter
......@@ -408,8 +406,14 @@ class LLMEngine:
from vllm.executor.neuron_executor import NeuronExecutor
executor_class = NeuronExecutor
elif engine_config.device_config.device_type == "tpu":
from vllm.executor.tpu_executor import TPUExecutor
executor_class = TPUExecutor
if distributed_executor_backend == "ray":
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_tpu_executor import RayTPUExecutor
executor_class = RayTPUExecutor
else:
assert distributed_executor_backend is None
from vllm.executor.tpu_executor import TPUExecutor
executor_class = TPUExecutor
elif engine_config.device_config.device_type == "cpu":
from vllm.executor.cpu_executor import CPUExecutor
executor_class = CPUExecutor
......@@ -485,29 +489,21 @@ class LLMEngine:
return self.tokenizer
def get_tokenizer(
self,
lora_request: Optional[LoRARequest] = None
) -> "PreTrainedTokenizer":
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
return self.get_tokenizer_group().get_lora_tokenizer(lora_request)
# def get_tokenizer_for_seq(self,
# sequence: Sequence) -> "PreTrainedTokenizer":
# def get_tokenizer_for_seq(self, sequence: Sequence) -> AnyTokenizer:
# return self.get_tokenizer_group().get_lora_tokenizer(
# sequence.lora_request)
def _init_tokenizer(self, **tokenizer_init_kwargs) -> BaseTokenizerGroup:
init_kwargs = dict(
tokenizer_id=self.model_config.tokenizer,
enable_lora=bool(self.lora_config),
max_num_seqs=self.scheduler_config.max_num_seqs,
max_input_length=None,
tokenizer_mode=self.model_config.tokenizer_mode,
trust_remote_code=self.model_config.trust_remote_code,
revision=self.model_config.tokenizer_revision)
init_kwargs.update(tokenizer_init_kwargs)
return get_tokenizer_group(self.parallel_config.tokenizer_pool_config,
**init_kwargs)
def _init_tokenizer(self) -> BaseTokenizerGroup:
return init_tokenizer_from_configs(
model_config=self.model_config,
scheduler_config=self.scheduler_config,
parallel_config=self.parallel_config,
enable_lora=bool(self.lora_config))
def _verify_args(self) -> None:
self.model_config.verify_with_parallel_config(self.parallel_config)
......@@ -769,10 +765,22 @@ class LLMEngine:
"""Gets the model configuration."""
return self.model_config
def get_parallel_config(self) -> ParallelConfig:
"""Gets the parallel configuration."""
return self.parallel_config
def get_decoding_config(self) -> DecodingConfig:
"""Gets the decoding configuration."""
return self.decoding_config
def get_scheduler_config(self) -> SchedulerConfig:
"""Gets the scheduler configuration."""
return self.scheduler_config
def get_lora_config(self) -> LoRAConfig:
"""Gets the LoRA configuration."""
return self.lora_config
def get_num_unfinished_requests(self) -> int:
"""Gets the number of unfinished requests."""
return sum(scheduler.get_num_unfinished_seq_groups()
......@@ -963,8 +971,9 @@ class LLMEngine:
model_output: Optional[List[SamplerOutput]] = None) -> None:
"""Forced log when no requests active."""
if self.log_stats:
stats = self._get_stats(scheduler_outputs, model_output)
for logger in self.stat_loggers.values():
logger.log(self._get_stats(scheduler_outputs, model_output))
logger.log(stats)
def _get_stats(
self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment