Commit e661d594 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.5.4' into v0.5.4-dtk24.04.1

parents 6b16ea2e 4db5176d
......@@ -38,9 +38,11 @@ class MLPSpeculatorWorker(NonLLMProposerWorkerBase, MultiStepWorker):
(input_tokens, seq_lens,
query_lens) = self._prepare_input_tensors(seq_group_metadata_list)
generators = self.model_runner.get_generators(
execute_model_req.finished_requests_ids)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list, seq_lens, query_lens, self.device,
self.model_runner.pin_memory)
self.model_runner.pin_memory, generators)
model_outputs = self.model_runner.model.generate_proposals(
input_ids=input_tokens,
......
......@@ -7,10 +7,9 @@ from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker_base import LoraNotSupportedWorkerBase
class NGramWorker(NonLLMProposerWorkerBase, LoraNotSupportedWorkerBase):
class NGramWorker(NonLLMProposerWorkerBase):
"""NGramWorker provides a light drafter without need for model.
Current NGramWorker only implements prompt lookup decoding,
......
......@@ -27,7 +27,7 @@ from vllm.spec_decode.ngram_worker import NGramWorker
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker
from vllm.spec_decode.target_model_runner import TargetModelRunner
from vllm.spec_decode.util import (create_sequence_group_output,
from vllm.spec_decode.util import (Timer, create_sequence_group_output,
get_all_num_logprobs,
get_sampled_token_logprobs, nvtx_range,
split_batch_by_proposal_len)
......@@ -75,7 +75,9 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
typical_acceptance_sampler_posterior_threshold,
typical_acceptance_sampler_posterior_alpha=speculative_config.
typical_acceptance_sampler_posterior_alpha,
disable_logprobs=speculative_config.disable_logprobs)
disable_logprobs=speculative_config.disable_logprobs,
disable_log_stats=speculative_config.disable_log_stats,
)
return spec_decode_worker
......@@ -116,6 +118,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
typical_acceptance_sampler_posterior_threshold: float,
typical_acceptance_sampler_posterior_alpha: float,
disable_logprobs: bool,
disable_log_stats: bool,
) -> "SpecDecodeWorker":
allow_zero_draft_token_step = True
......@@ -171,6 +174,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
proposer_worker,
scorer_worker,
disable_logprobs=disable_logprobs,
disable_log_stats=disable_log_stats,
disable_by_batch_size=disable_by_batch_size,
spec_decode_sampler=spec_decode_sampler,
allow_zero_draft_token_step=allow_zero_draft_token_step)
......@@ -180,7 +184,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
proposer_worker: ProposerWorkerBase,
scorer_worker: WorkerBase,
spec_decode_sampler: SpecDecodeBaseSampler,
disable_logprobs: bool,
disable_logprobs: bool = False,
disable_log_stats: bool = False,
metrics_collector: Optional[AsyncMetricsCollector] = None,
disable_by_batch_size: Optional[int] = None,
allow_zero_draft_token_step: Optional[bool] = True,
......@@ -203,6 +208,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
disable_logprobs: If set to True, token log probabilities will
not be output in both the draft worker and the target worker.
If set to False, log probabilities will be output by both.
disable_log_stats: If set to True, disable periodic printing of
speculative stage times.
disable_by_batch_size: If the batch size is larger than this,
disable speculative decoding for new incoming requests.
metrics_collector: Helper class for collecting metrics; can be set
......@@ -213,6 +220,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
"""
self.proposer_worker = proposer_worker
self.scorer_worker = scorer_worker
scorer_runner = getattr(self.scorer_worker, "model_runner", None)
self.generators = scorer_runner.get_generators(
) if scorer_runner else None
self.disable_by_batch_size = disable_by_batch_size or float("inf")
self.spec_decode_sampler = spec_decode_sampler
self._allow_zero_draft_token_step = allow_zero_draft_token_step
......@@ -237,6 +247,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# in the subsequent step.
self.previous_hidden_states: Optional[HiddenStates] = None
self._disable_logprobs = disable_logprobs
self._disable_log_stats = disable_log_stats
def init_device(self) -> None:
"""Initialize both scorer and proposer models.
......@@ -484,7 +495,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
for both speculation cases (num_lookahead_slots>0) and non-speculation
cases (e.g. prefill).
Returns True iff there are remaining sequences to process.
Returns True if there are remaining sequences to process.
"""
assert self.rank != self._driver_rank
......@@ -522,28 +533,37 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
execute_model_req.previous_hidden_states = self.previous_hidden_states
self.previous_hidden_states = None
# Generate proposals using draft worker.
proposals = self.proposer_worker.get_spec_proposals(
execute_model_req, self._seq_with_bonus_token_in_last_step)
with Timer() as proposal_timer:
# Generate proposals using draft worker.
proposals = self.proposer_worker.get_spec_proposals(
execute_model_req, self._seq_with_bonus_token_in_last_step)
if not self._allow_zero_draft_token_step and proposals.no_proposals:
#TODO: Fix it #5814
raise RuntimeError("Cannot handle cases where distributed draft "
"workers generate no tokens")
proposal_scores = self.scorer.score_proposals(
execute_model_req,
proposals,
)
accepted_token_ids, target_logprobs = self._verify_tokens(
execute_model_req.seq_group_metadata_list, proposal_scores,
proposals, execute_model_req.num_lookahead_slots)
with Timer() as scoring_timer:
proposal_scores = self.scorer.score_proposals(
execute_model_req,
proposals,
)
with Timer() as verification_timer:
accepted_token_ids, target_logprobs = self._verify_tokens(
execute_model_req.seq_group_metadata_list, proposal_scores,
proposals, execute_model_req.num_lookahead_slots)
stage_times = (proposal_timer.elapsed_time_ms / num_lookahead_slots,
scoring_timer.elapsed_time_ms,
verification_timer.elapsed_time_ms)
return self._create_output_sampler_list(
execute_model_req.seq_group_metadata_list,
accepted_token_ids,
target_logprobs=target_logprobs,
k=execute_model_req.num_lookahead_slots)
k=execute_model_req.num_lookahead_slots,
stage_times=stage_times)
@nvtx_range("spec_decode_worker._verify_tokens")
def _verify_tokens(
......@@ -591,20 +611,14 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
proposal_token_ids = proposals.proposal_token_ids[spec_indices]
# Sampler arguments
sampler_extra_kwargs = {}
if isinstance(self.spec_decode_sampler,
SpecDecodeStochasticBaseSampler):
# Get sequence group state
generators = []
for seq_group_metadata in seq_group_metadata_list:
if (seq_group_metadata.state is not None
and seq_group_metadata.state.generator is not None):
generators.append(seq_group_metadata.state.generator)
else:
generators.append(None)
sampler_extra_kwargs["generators"] = generators
sampler_extra_kwargs: Dict[str, Any] = {}
if self.generators and isinstance(self.spec_decode_sampler,
SpecDecodeStochasticBaseSampler):
sampler_extra_kwargs["seeded_seqs"] = {
idx: self.generators[sgm.request_id]
for idx, sgm in enumerate(seq_group_metadata_list)
if sgm.sampling_params.seed is not None
}
accepted_token_ids = self.spec_decode_sampler(
target_probs=proposal_verifier_probs,
......@@ -648,6 +662,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
accepted_token_ids: torch.Tensor, # shape: [batch_size, k+1]
target_logprobs: torch.Tensor, # shape: [batch_size, k+1, vocab_size]
k: int,
stage_times: Tuple[float, float, float],
) -> List[SamplerOutput]:
"""Given the accepted token ids, create a list of SamplerOutput.
......@@ -725,8 +740,30 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
if maybe_rejsample_metrics is not None:
sampler_output_list[
0].spec_decode_worker_metrics = maybe_rejsample_metrics
# Log time spent in each stage periodically.
# This is periodic because the rejection sampler emits metrics
# periodically.
self._maybe_log_stage_times(*stage_times)
return sampler_output_list
def _maybe_log_stage_times(self, average_time_per_proposal_tok_ms: float,
scoring_time_ms: float,
verification_time_ms: float) -> None:
"""Log the speculative stage times. If stat logging is disabled, do
nothing.
"""
if self._disable_log_stats:
return
logger.info(
"SpecDecodeWorker stage times: "
"average_time_per_proposal_tok_ms=%.02f "
"scoring_time_ms=%.02f verification_time_ms=%.02f",
average_time_per_proposal_tok_ms, scoring_time_ms,
verification_time_ms)
def _create_dummy_logprob_lists(
self,
batch_size: int,
......
import time
from contextlib import contextmanager
from typing import Dict, List, Optional, Tuple
......@@ -214,3 +215,17 @@ def nvtx_range(msg, *args, **kwargs):
yield
finally:
torch.cuda.nvtx.range_pop()
class Timer:
"""Basic timer context manager for measuring CPU time.
"""
def __enter__(self):
self.start_time = time.time()
return self
def __exit__(self, exc_type, exc_value, traceback):
self.end_time = time.time()
self.elapsed_time_s = self.end_time - self.start_time
self.elapsed_time_ms = self.elapsed_time_s * 1000
......@@ -15,7 +15,7 @@ try:
OTEL_EXPORTER_OTLP_TRACES_PROTOCOL)
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.semconv.ai import SpanAttributes as BaseSpanAttributes
from opentelemetry.semconv_ai import SpanAttributes as BaseSpanAttributes
from opentelemetry.trace import SpanKind, Tracer, set_tracer_provider
from opentelemetry.trace.propagation.tracecontext import (
TraceContextTextMapPropagator)
......@@ -60,7 +60,7 @@ def get_span_exporter(endpoint):
OTLPSpanExporter)
elif protocol == "http/protobuf":
from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
OTLPSpanExporter)
OTLPSpanExporter) # type: ignore
else:
raise ValueError(
f"Unsupported OTLP protocol '{protocol}' is configured")
......
......@@ -5,10 +5,11 @@ from transformers import GenerationConfig, PretrainedConfig
from vllm.envs import VLLM_USE_MODELSCOPE
from vllm.logger import init_logger
from vllm.transformers_utils.configs import (ChameleonConfig, ChatGLMConfig,
DbrxConfig, JAISConfig,
from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
InternVLChatConfig, JAISConfig,
MedusaConfig, MLPSpeculatorConfig,
MPTConfig, RWConfig)
MPTConfig, NemotronConfig,
RWConfig)
if VLLM_USE_MODELSCOPE:
from modelscope import AutoConfig
......@@ -18,7 +19,6 @@ else:
logger = init_logger(__name__)
_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
"chameleon": ChameleonConfig,
"chatglm": ChatGLMConfig,
"dbrx": DbrxConfig,
"mpt": MPTConfig,
......@@ -27,6 +27,8 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
"jais": JAISConfig,
"mlp_speculator": MLPSpeculatorConfig,
"medusa": MedusaConfig,
"internvl_chat": InternVLChatConfig,
"nemotron": NemotronConfig,
}
for name, cls in _CONFIG_REGISTRY.items():
......
from vllm.transformers_utils.configs.chameleon import (ChameleonConfig,
ChameleonVQVAEConfig)
from vllm.transformers_utils.configs.chatglm import ChatGLMConfig
from vllm.transformers_utils.configs.dbrx import DbrxConfig
# RWConfig is for the original tiiuae/falcon-40b(-instruct) and
# tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
# `FalconConfig` class from the official HuggingFace transformers library.
from vllm.transformers_utils.configs.falcon import RWConfig
from vllm.transformers_utils.configs.internvl import InternVLChatConfig
from vllm.transformers_utils.configs.jais import JAISConfig
from vllm.transformers_utils.configs.medusa import MedusaConfig
from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig
from vllm.transformers_utils.configs.mpt import MPTConfig
from vllm.transformers_utils.configs.nemotron import NemotronConfig
__all__ = [
"ChameleonConfig",
"ChameleonVQVAEConfig",
"ChatGLMConfig",
"DbrxConfig",
"MPTConfig",
"RWConfig",
"InternVLChatConfig",
"JAISConfig",
"MedusaConfig",
"MLPSpeculatorConfig",
"NemotronConfig",
]
from typing import List, Optional
from transformers import PretrainedConfig
#TODO (ywang96): Remove this file and import it from
# transformers once the new release with Chameleon support
# is available.
class ChameleonConfig(PretrainedConfig):
model_type = "chameleon"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=65536,
hidden_size=4096,
intermediate_size=11008,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=32,
hidden_act="silu",
max_position_embeddings=4096,
initializer_range=0.02,
rms_norm_eps=1e-05,
use_cache=True,
pad_token_id=None,
bos_token_id=1,
eos_token_id=2,
tie_word_embeddings=False,
rope_theta=10000.0,
rope_scaling=None,
attention_bias=False,
attention_dropout=0.0,
model_parallel_size=1,
swin_norm=False,
vq_config=None,
vocabulary_map=None,
mlp_bias=False,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.mlp_bias = mlp_bias
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self._rope_scaling_validation()
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.model_parallel_size = model_parallel_size
self.swin_norm = swin_norm
if vq_config is None:
vq_config = {}
self.vq_config = ChameleonVQVAEConfig(**vq_config)
self.vocabulary_map = vocabulary_map
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
def _rope_scaling_validation(self):
"""
Validate the `rope_scaling` configuration.
"""
if self.rope_scaling is None:
return
if not isinstance(self.rope_scaling,
dict) or len(self.rope_scaling) != 2:
raise ValueError(
"`rope_scaling` must be a dictionary with with two fields, "
f"`type` and `factor`, got {self.rope_scaling}")
rope_scaling_type = self.rope_scaling.get("type", None)
rope_scaling_factor = self.rope_scaling.get("factor", None)
if rope_scaling_type is None or rope_scaling_type not in [
"linear", "dynamic"
]:
raise ValueError(
"`rope_scaling`'s type field must be one of ['linear', "
f"'dynamic'], got {rope_scaling_type}")
if rope_scaling_factor is None or not isinstance(
rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
raise ValueError(
"`rope_scaling`'s factor field must be a float > 1, "
f"got {rope_scaling_factor}")
class ChameleonVQVAEConfig(PretrainedConfig):
model_type = "chameleon_vqgan"
def __init__(
self,
embed_dim: int = 256,
num_embeddings: int = 8192,
double_latent: bool = False,
latent_channels: int = 256,
resolution: int = 512,
in_channels: int = 3,
base_channels: int = 128,
channel_multiplier: List[int] = [1, 1, 2, 2, 4], #noqa
num_res_blocks: int = 2,
attn_resolutions: Optional[List[int]] = None,
dropout: float = 0.0,
attn_type: str = "vanilla",
initializer_range=0.02,
**kwargs,
):
super().__init__(**kwargs)
self.embed_dim = embed_dim
self.num_embeddings = num_embeddings
self.double_latent = double_latent
self.latent_channels = latent_channels
self.resolution = resolution
self.in_channels = in_channels
self.base_channels = base_channels
self.channel_multiplier = channel_multiplier
self.num_res_blocks = num_res_blocks
self.attn_resolutions = attn_resolutions
self.dropout = dropout
self.attn_type = attn_type
self.initializer_range = initializer_range
# Adapted from
# https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/configuration_internvl_chat.py
# --------------------------------------------------------
# InternVL
# Copyright (c) 2024 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from transformers.configuration_utils import PretrainedConfig
class InternVLChatConfig(PretrainedConfig):
model_type = 'internvl_chat'
is_composition = True
def __init__(self,
vision_config=None,
llm_config=None,
use_backbone_lora=0,
use_llm_lora=0,
select_layer=-1,
force_image_size=None,
downsample_ratio=0.5,
template=None,
dynamic_image_size=False,
use_thumbnail=False,
ps_version='v1',
min_dynamic_patch=1,
max_dynamic_patch=6,
**kwargs):
super().__init__(**kwargs)
if vision_config is None:
vision_config = {}
if llm_config is None:
llm_config = {}
self.vision_config = PretrainedConfig(**vision_config)
self.text_config = PretrainedConfig(**llm_config)
self.use_backbone_lora = use_backbone_lora
self.use_llm_lora = use_llm_lora
self.select_layer = select_layer
self.force_image_size = force_image_size
self.downsample_ratio = downsample_ratio
self.template = template
self.dynamic_image_size = dynamic_image_size
self.use_thumbnail = use_thumbnail
self.ps_version = ps_version # pixel shuffle version
self.min_dynamic_patch = min_dynamic_patch
self.max_dynamic_patch = max_dynamic_patch
# coding=utf-8
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Nemotron model configuration"""
from transformers import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
class NemotronConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a
[`NemotronModel`]. It is used to instantiate an Nemotron model
according to the specified arguments, defining the model architecture.
Instantiating a configuration with the defaults will yield a similar
configuration to that of the Nemotron-8B.
Configuration objects inherit from [`PretrainedConfig`] and can be
used to control the model outputs. Read the documentation from
[`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 32000):
Vocabulary size of the Nemotron model. Defines the number of
different tokens that can be represented by the
`inputs_ids` passed when calling [`NemotronModel`]
hidden_size (`int`, *optional*, defaults to 4096):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 11008):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the
Transformer decoder.
head_dim (`int`, *optional*, defaults to None):
Projection weights dimension in multi-head attention. Set to
hidden_size // num_attention_heads if None
num_key_value_heads (`int`, *optional*):
This is the number of key_value heads that should be used to
implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use
Multi Head Attention (MHA), if
`num_key_value_heads=1 the model will use Multi Query Attention
(MQA) otherwise GQA is used. When converting a multi-head
checkpoint to a GQA checkpoint, each group key and value
head should be constructed by meanpooling all the original
heads within that group. For more details checkout
[this paper](https://arxiv.org/pdf/2305.13245.pdf). If it
is not specified, will default to `num_attention_heads`.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the
decoder.
max_position_embeddings (`int`, *optional*, defaults to 2048):
The maximum sequence length that this model might ever be used
with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for
initializing all weight matrices.
norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values
attentions (not used by all models). Only relevant if
`config.is_decoder=True`.
pad_token_id (`int`, *optional*):
Padding token id.
bos_token_id (`int`, *optional*, defaults to 1):
Beginning of stream token id.
eos_token_id (`int`, *optional*, defaults to 2):
End of stream token id.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether to tie weight embeddings
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE
embeddings. Currently supports two scaling strategies: linear
and dynamic. Their scaling factor must be a float greater than 1.
The expected format is `{"type": strategy name,
"factor": scaling factor}`. When using this flag, don't update
`max_position_embeddings` to the expected new maximum.
attention_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output
projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
mlp_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in up_proj and down_proj layers in the MLP
layers.
```python
>>> from transformers import NemotronModel, NemotronConfig
>>> # Initializing a Nemotron nemotron-15b style configuration
>>> configuration = NemotronConfig()
>>> # Initializing a model from the nemotron-15b style configuration
>>> model = NemotronModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "nemotron"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=256000,
hidden_size=6144,
intermediate_size=24576,
num_hidden_layers=32,
num_attention_heads=48,
head_dim=None,
num_key_value_heads=None,
hidden_act="relu2",
max_position_embeddings=4096,
initializer_range=0.0134,
norm_eps=1e-5,
use_cache=True,
pad_token_id=None,
bos_token_id=2,
eos_token_id=3,
tie_word_embeddings=False,
rope_theta=10000.0,
rope_scaling=None,
rope_percent=0.5,
attention_bias=False,
attention_dropout=0.0,
mlp_bias=False,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
head_dim = head_dim or kwargs.get("kv_channels", None)
self.head_dim = head_dim if head_dim is not None else (
hidden_size // num_attention_heads)
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.norm_eps = norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
rope_percent = rope_percent or kwargs.get("rope_percentage", None)
self.rope_percent = rope_percent
self._rope_scaling_validation()
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.mlp_bias = mlp_bias
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
def _rope_scaling_validation(self):
"""
Validate the `rope_scaling` configuration.
"""
if self.rope_scaling is None:
return
if not isinstance(self.rope_scaling,
dict) or len(self.rope_scaling) != 2:
raise ValueError(
"`rope_scaling` must be a dictionary with two fields, "
f"`type` and `factor`, got {self.rope_scaling}")
rope_scaling_type = self.rope_scaling.get("type", None)
rope_scaling_factor = self.rope_scaling.get("factor", None)
if rope_scaling_type is None or rope_scaling_type not in [
"linear", "dynamic"
]:
raise ValueError(
"`rope_scaling`'s type field must be one of ['linear', "
f"'dynamic'], got {rope_scaling_type}")
if rope_scaling_factor is None or not isinstance(
rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
raise ValueError(
"`rope_scaling`'s factor field must be a float > 1, got "
f"{rope_scaling_factor}")
......@@ -37,8 +37,10 @@ class Detokenizer:
The prompt logprobs with the decoded tokens.
"""
prms = seq_group.sampling_params
assert prms is not None
# We can pick any sequence for the prompt.
seq = next(iter(seq_group.seqs_dict.values()))
seq = seq_group.get_seqs()[0]
# Only prompt, without the generated token.
all_token_ids = seq.get_token_ids()
prompt_token_ids = all_token_ids[:-1]
......
from typing import Optional, Type
from vllm.config import TokenizerPoolConfig
from vllm.config import (ModelConfig, ParallelConfig, SchedulerConfig,
TokenizerPoolConfig)
from vllm.executor.ray_utils import ray
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
BaseTokenizerGroup)
from vllm.transformers_utils.tokenizer_group.tokenizer_group import (
TokenizerGroup)
from .base_tokenizer_group import AnyTokenizer, BaseTokenizerGroup
from .tokenizer_group import TokenizerGroup
if ray:
from vllm.transformers_utils.tokenizer_group.ray_tokenizer_group import (
......@@ -14,6 +14,22 @@ else:
RayTokenizerGroupPool = None # type: ignore
def init_tokenizer_from_configs(model_config: ModelConfig,
scheduler_config: SchedulerConfig,
parallel_config: ParallelConfig,
enable_lora: bool):
init_kwargs = dict(tokenizer_id=model_config.tokenizer,
enable_lora=enable_lora,
max_num_seqs=scheduler_config.max_num_seqs,
max_input_length=None,
tokenizer_mode=model_config.tokenizer_mode,
trust_remote_code=model_config.trust_remote_code,
revision=model_config.tokenizer_revision)
return get_tokenizer_group(parallel_config.tokenizer_pool_config,
**init_kwargs)
def get_tokenizer_group(tokenizer_pool_config: Optional[TokenizerPoolConfig],
**init_kwargs) -> BaseTokenizerGroup:
tokenizer_cls: Type[BaseTokenizerGroup]
......@@ -34,4 +50,4 @@ def get_tokenizer_group(tokenizer_pool_config: Optional[TokenizerPoolConfig],
return tokenizer_cls.from_config(tokenizer_pool_config, **init_kwargs)
__all__ = ["get_tokenizer_group", "BaseTokenizerGroup"]
__all__ = ["AnyTokenizer", "get_tokenizer_group", "BaseTokenizerGroup"]
from abc import ABC, abstractmethod
from typing import List, Optional
from typing import List, Optional, Union
from transformers import PreTrainedTokenizer
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from vllm.config import TokenizerPoolConfig
from vllm.lora.request import LoRARequest
AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
class BaseTokenizerGroup(ABC):
"""A group of tokenizers that can be used for LoRA adapters."""
......@@ -47,17 +49,17 @@ class BaseTokenizerGroup(ABC):
@abstractmethod
def get_lora_tokenizer(
self,
lora_request: Optional[LoRARequest] = None
) -> "PreTrainedTokenizer":
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
"""Get a tokenizer for a LoRA request."""
pass
@abstractmethod
async def get_lora_tokenizer_async(
self,
lora_request: Optional[LoRARequest] = None
) -> "PreTrainedTokenizer":
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
"""Get a tokenizer for a LoRA request."""
pass
......
......@@ -3,21 +3,19 @@ import os
from typing import List, Optional
try:
from ray.exceptions import ActorDiedError
from ray.exceptions import ActorDiedError # type: ignore
except ImportError:
# For older versions of Ray
from ray.exceptions import RayActorError as ActorDiedError
from ray.exceptions import RayActorError as ActorDiedError # type: ignore
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
from transformers import PreTrainedTokenizer
from vllm.config import TokenizerPoolConfig
from vllm.executor.ray_utils import ray
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
BaseTokenizerGroup)
from vllm.transformers_utils.tokenizer_group.tokenizer_group import (
TokenizerGroup)
from .base_tokenizer_group import AnyTokenizer, BaseTokenizerGroup
from .tokenizer_group import TokenizerGroup
logger = init_logger(__name__)
......@@ -67,7 +65,7 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
**self._tokenizer_config, )
self._ray_tokenizer_group_cls = ray.remote(
self._worker_cls).options(**ray_actor_options)
self._worker_cls).options(**ray_actor_options) # type: ignore
self.tokenizer_actors = [self._init_actor() for _ in range(num_actors)]
self._idle_actors: Optional[asyncio.Queue] = None
......@@ -83,8 +81,10 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
return len(self.tokenizer_actors)
def ping(self):
return ray.get(
[actor.ping.remote() for actor in self.tokenizer_actors])
return ray.get([
actor.ping.remote() # type: ignore
for actor in self.tokenizer_actors
])
def _ensure_queue_initialized(self):
if self._idle_actors is None:
......@@ -208,15 +208,15 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
return self._local_tokenizer_group.get_max_input_len(lora_request)
def get_lora_tokenizer(
self,
lora_request: Optional[LoRARequest] = None
) -> "PreTrainedTokenizer":
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
return self._local_tokenizer_group.get_lora_tokenizer(lora_request)
async def get_lora_tokenizer_async(
self,
lora_request: Optional[LoRARequest] = None
) -> "PreTrainedTokenizer":
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
return await self._local_tokenizer_group.get_lora_tokenizer_async(
lora_request)
......
from typing import List, Optional
from transformers import PreTrainedTokenizer
from vllm.config import TokenizerPoolConfig
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizer import (get_lora_tokenizer,
get_lora_tokenizer_async,
get_tokenizer)
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
BaseTokenizerGroup)
from vllm.utils import LRUCache
from .base_tokenizer_group import AnyTokenizer, BaseTokenizerGroup
class TokenizerGroup(BaseTokenizerGroup):
"""A group of tokenizers that can be used for LoRA adapters."""
......@@ -22,8 +20,8 @@ class TokenizerGroup(BaseTokenizerGroup):
self.enable_lora = enable_lora
self.max_input_length = max_input_length
self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config)
self.lora_tokenizers = LRUCache[PreTrainedTokenizer](
capacity=max_num_seqs) if enable_lora else None
self.lora_tokenizers = LRUCache[AnyTokenizer](
capacity=max_num_seqs if enable_lora else 0)
@classmethod
def from_config(cls, tokenizer_pool_config: Optional[TokenizerPoolConfig],
......@@ -41,7 +39,7 @@ class TokenizerGroup(BaseTokenizerGroup):
return self.max_input_length
def _raise_if_input_too_long(self,
encoded_tokens: List[str],
encoded_tokens: List[int],
lora_request: Optional[LoRARequest] = None):
input_length = len(encoded_tokens)
if lora_request:
......@@ -72,9 +70,9 @@ class TokenizerGroup(BaseTokenizerGroup):
return ret
def get_lora_tokenizer(
self,
lora_request: Optional[LoRARequest] = None
) -> "PreTrainedTokenizer":
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
if not lora_request or not self.enable_lora:
return self.tokenizer
if lora_request.lora_int_id not in self.lora_tokenizers:
......@@ -83,12 +81,12 @@ class TokenizerGroup(BaseTokenizerGroup):
self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
return tokenizer
else:
return self.lora_tokenizers.get(lora_request.lora_int_id)
return self.lora_tokenizers[lora_request.lora_int_id]
async def get_lora_tokenizer_async(
self,
lora_request: Optional[LoRARequest] = None
) -> "PreTrainedTokenizer":
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
if not lora_request or not self.enable_lora:
return self.tokenizer
if lora_request.lora_int_id not in self.lora_tokenizers:
......@@ -97,4 +95,4 @@ class TokenizerGroup(BaseTokenizerGroup):
self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
return tokenizer
else:
return self.lora_tokenizers.get(lora_request.lora_int_id)
return self.lora_tokenizers[lora_request.lora_int_id]
from vllm.triton_utils.custom_cache_manager import (
maybe_set_triton_cache_manager)
from vllm.triton_utils.importing import HAS_TRITON
__all__ = [
"maybe_set_triton_cache_manager",
]
__all__ = ["HAS_TRITON"]
if HAS_TRITON:
from vllm.triton_utils.custom_cache_manager import (
maybe_set_triton_cache_manager)
from vllm.triton_utils.libentry import libentry
__all__ += ["maybe_set_triton_cache_manager", "libentry"]
from importlib.util import find_spec
from vllm.logger import init_logger
logger = init_logger(__name__)
HAS_TRITON = find_spec("triton") is not None
if not HAS_TRITON:
logger.info("Triton not installed; certain GPU-related functions"
" will be not be available.")
# Copied From https://github.com/FlagOpen/FlagGems
import inspect
import triton
class LibEntry(triton.KernelInterface):
def __init__(
self,
fn,
):
self.fn = fn
self.arg_names = fn.arg_names
self.divisibility = 16
self.kernel_cache = dict()
fn = self.fn
while not isinstance(fn, triton.runtime.JITFunction):
fn = fn.fn
self.jit_function: triton.runtime.JITFunction = fn
self.specialize_indices = [
p.num for p in self.jit_function.params
if not p.is_constexpr and not p.do_not_specialize
]
self.do_not_specialize_indices = [
p.num for p in self.jit_function.params
if not p.is_constexpr and p.do_not_specialize
]
def key(self, spec_args, dns_args, const_args):
spec_key = [(arg.dtype, arg.data_ptr() %
self.divisibility == 0) if hasattr(arg, "data_ptr") else
(type(arg), arg) for arg in spec_args]
dns_key = [
arg.dtype if hasattr(
arg, "data_ptr") else type(arg) if not isinstance(arg, int)
else "i32" if -(2**31) <= arg and arg <= 2**31 -
1 else "u64" if 2**63 <= arg and arg <= 2**64 - 1 else "i64"
for arg in dns_args
]
# const args passed by position
return tuple(spec_key + dns_key + const_args)
def run(self, *args, **kwargs):
grid = kwargs["grid"]
# collect all the arguments
spec_args = [] # specialize arguments
dns_args = [] # do not specialize arguments
const_args = [] # constexpr arguments
k_args = [] # kernel arguments
for i, arg in enumerate(args):
if i in self.specialize_indices:
k_args.append(arg)
spec_args.append(arg)
elif i in self.do_not_specialize_indices:
k_args.append(arg)
dns_args.append(arg)
else:
const_args.append(arg)
for p in self.jit_function.params[len(args):]:
if p.name in kwargs:
val = kwargs[p.name]
elif p.default is inspect._empty:
continue
else:
val = p.default
if p.is_constexpr:
const_args.append(val)
elif p.do_not_specialize:
dns_args.append(val)
k_args.append(val)
else:
spec_args.append(val)
k_args.append(val)
entry_key = self.key(spec_args, dns_args, const_args)
if entry_key not in self.kernel_cache:
# compile the kernel also completes the related computations
kernel = self.fn.run(*args, **kwargs)
fn = self.fn
# collect constexpr arguments for grid computation
constexprs = {}
while not isinstance(fn, triton.runtime.JITFunction):
if isinstance(fn, triton.runtime.Autotuner):
config = fn.best_config
constexprs["num_warps"] = config.num_warps
constexprs["num_stages"] = config.num_stages
constexprs["num_ctas"] = config.num_ctas
constexprs = {**constexprs, **config.kwargs}
elif isinstance(fn, triton.runtime.Heuristics):
for v, heur in fn.values.items():
constexprs[v] = heur({
**dict(zip(fn.arg_names, args)),
**kwargs,
**constexprs,
})
else:
raise RuntimeError("Invalid Runtime Function")
fn = fn.fn
# In vLLM, certain kernels like fused_moe_kernel get the
# best_config(as kwargs) from a configuration json file, rather
# than using Autotuner & Heuristics. Therefore, all their constexprs
# (tl.constexpr) are assigned values through the following loop.
for p in self.jit_function.params:
if p.is_constexpr and p.name not in constexprs:
constexprs[p.name] = p.default #default=inspect._empty
self.kernel_cache[entry_key] = (kernel, constexprs)
else:
# load kernel from cache directly
kernel, constexprs = self.kernel_cache[entry_key]
if callable(grid):
# collect all arguments to the grid fn,ie:
# 1. args,
# 2. kwargs,
# 3. all all other captured arguments in CompiledKernel from
# Autotunner & Heuristics when kwargs & captured args conflict,
# captured args have higher priority
# 4. We must filter out captured args with default value firstly
constexprs = {
k: v
for k, v in constexprs.items() if v is not inspect._empty
}
meta = {
**dict(zip(self.arg_names, args)),
**kwargs,
**constexprs,
}
grid = grid(meta)
if isinstance(grid, tuple):
grid = grid + (1, 1)
elif isinstance(grid, list):
grid = grid + [1, 1]
kernel[grid[0:3]](*k_args)
# maintaining the same return type as the JITFunction.run
return kernel
def libentry():
"""
Decorator for triton library entries.
Motivation:
The runtime overhead of Triton kernels is the reason for the lower
performance of small kernels, particularly evident with smaller models.
Using this decorator can reduce Triton runtime overhead.
How:
The `run` function of JITFunction needs to accomplish:
- Parameter binding using inspect
- KernelArg type wrapping
- Cache key calculation
When dealing with small size, these steps can become bottlenecks in
Triton runtime. Libentry simplifies these steps to reduce runtime
overhead, thereby improving the runtime expenses of small kernels.
NOTE:
When Triton is upgraded to version 3.0.0, libentry can be removed,
see: https://github.com/vllm-project/vllm/pull/5036#issuecomment-2243396245
"""
def decorator(fn):
return LibEntry(fn)
return decorator
import math
# This is a hardcoded limit in Triton (max block size).
MAX_TRITON_N_COLS = 131072
def get_num_triton_sampler_splits(n_cols: int) -> int:
"""Get the number of splits to use for Triton sampling.
Triton has a limit on the number of columns it can handle, so we need to
split the tensor and call the kernel multiple times if it's too large.
"""
return math.ceil(n_cols / MAX_TRITON_N_COLS)
import argparse
import asyncio
import contextlib
import datetime
import enum
import gc
......@@ -17,7 +16,7 @@ from functools import lru_cache, partial, wraps
from platform import uname
from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
Hashable, List, Optional, OrderedDict, Set, Tuple, TypeVar,
Union)
Union, overload)
import numpy as np
import numpy.typing as npt
......@@ -53,6 +52,7 @@ TORCH_DTYPE_TO_NUMPY_DTYPE = {
P = ParamSpec('P')
K = TypeVar("K")
T = TypeVar("T")
U = TypeVar("U")
class _Sentinel:
......@@ -94,8 +94,10 @@ class LRUCache(Generic[T]):
def __len__(self) -> int:
return len(self.cache)
def __getitem__(self, key: Hashable) -> Optional[T]:
return self.get(key)
def __getitem__(self, key: Hashable) -> T:
value = self.cache[key] # Raise KeyError if not exists
self.cache.move_to_end(key)
return value
def __setitem__(self, key: Hashable, value: T) -> None:
self.put(key, value)
......@@ -109,8 +111,9 @@ class LRUCache(Generic[T]):
def get(self,
key: Hashable,
default_value: Optional[T] = None) -> Optional[T]:
value: Optional[T]
if key in self.cache:
value: Optional[T] = self.cache[key]
value = self.cache[key]
self.cache.move_to_end(key)
else:
value = default_value
......@@ -287,6 +290,10 @@ def make_async(func: Callable[P, T]) -> Callable[P, Awaitable[T]]:
return _async_wrapper
class ProducerFinished:
pass
def merge_async_iterators(
*iterators: AsyncIterator[T]) -> AsyncIterator[Tuple[int, T]]:
"""Merge multiple asynchronous iterators into a single iterator.
......@@ -295,9 +302,10 @@ def merge_async_iterators(
When it yields, it yields a tuple (i, item) where i is the index of the
iterator that yields the item.
"""
queue: asyncio.Queue[Union[Tuple[int, T], Exception]] = asyncio.Queue()
queue: asyncio.Queue[Union[Tuple[int, T], ProducerFinished,
Exception]] = asyncio.Queue()
finished = [False] * len(iterators)
producers = len(iterators)
async def producer(i: int, iterator: AsyncIterator[T]):
try:
......@@ -305,7 +313,8 @@ def merge_async_iterators(
await queue.put((i, item))
except Exception as e:
await queue.put(e)
finished[i] = True
# Signal to the consumer that we've finished
await queue.put(ProducerFinished())
_tasks = [
asyncio.create_task(producer(i, iterator))
......@@ -313,9 +322,17 @@ def merge_async_iterators(
]
async def consumer():
remaining = producers
try:
while not all(finished) or not queue.empty():
while remaining or not queue.empty():
# we think there is a race condition here
item = await queue.get()
if isinstance(item, ProducerFinished):
# Signal that a producer finished- not a real item
remaining -= 1
continue
if isinstance(item, Exception):
raise item
yield item
......@@ -371,8 +388,10 @@ def get_distributed_init_method(ip: str, port: int) -> str:
return f"tcp://[{ip}]:{port}" if ":" in ip else f"tcp://{ip}:{port}"
def get_open_port() -> int:
port = envs.VLLM_PORT
def get_open_port(port: Optional[int] = None) -> int:
if port is None:
# Default behavior here is to return a port for multi-gpu communication
port = envs.VLLM_PORT
if port is not None:
while True:
try:
......@@ -404,27 +423,6 @@ def update_environment_variables(envs: Dict[str, str]):
os.environ[k] = v
def init_kmp_env():
if not is_cpu():
return
ld_prealod_str = os.getenv("LD_PRELOAD", "")
if "libiomp5.so" not in ld_prealod_str:
return
# The time(milliseconds) that a thread should wait after completing the
# execution of a parallel region, before sleeping.
os.environ['KMP_BLOCKTIME'] = "1"
# dump settings on start up
os.environ['KMP_SETTINGS'] = "1"
# Prevents the CPU to run into low performance state
os.environ['KMP_TPAUSE'] = "0"
# Provides fine granularity parallelism
os.environ['KMP_FORKJOIN_BARRIER_PATTERN'] = "dist,dist"
os.environ['KMP_PLAIN_BARRIER_PATTERN'] = "dist,dist"
os.environ['KMP_REDUCTION_BARRIER_PATTERN'] = "dist,dist"
def chunk_list(lst: List[T], chunk_size: int):
"""Yield successive chunk_size chunks from lst."""
for i in range(0, len(lst), chunk_size):
......@@ -491,7 +489,6 @@ def create_kv_caches_with_random_flash(
seed: int = 0,
device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
assert cache_dtype != "fp8"
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
......@@ -507,7 +504,13 @@ def create_kv_caches_with_random_flash(
key_value_cache = torch.empty(size=key_value_cache_shape,
dtype=torch_dtype,
device=device)
key_value_cache.uniform_(-scale, scale)
if cache_dtype in ["auto", "half", "bfloat16", "float"]:
key_value_cache.uniform_(-scale, scale)
elif cache_dtype == 'fp8':
_generate_random_fp8(key_value_cache, -scale, scale)
else:
raise ValueError(
f"Does not support key cache of type {cache_dtype}")
key_caches.append(key_value_cache[:, 0])
value_caches.append(key_value_cache[:, 1])
return key_caches, value_caches
......@@ -524,6 +527,12 @@ def create_kv_caches_with_random(
seed: int = 0,
device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
if cache_dtype == "fp8" and head_size % 16:
raise ValueError(
f"Does not support key cache of type fp8 with head_size {head_size}"
)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
......@@ -600,8 +609,8 @@ class CudaMemoryProfiler:
torch.cuda.reset_peak_memory_stats(self.device)
mem = torch.cuda.max_memory_allocated(self.device)
elif is_xpu():
torch.xpu.reset_peak_memory_stats(self.device)
mem = torch.xpu.max_memory_allocated(self.device)
torch.xpu.reset_peak_memory_stats(self.device) # type: ignore
mem = torch.xpu.max_memory_allocated(self.device) # type: ignore
return mem
def __enter__(self):
......@@ -719,6 +728,54 @@ def merge_dicts(dict1: Dict[K, List[T]],
return dict(merged_dict)
JSONTree = Union[Dict[str, "JSONTree[T]"], List["JSONTree[T]"],
Tuple["JSONTree[T]", ...], T]
"""A nested JSON structure where the leaves need not be JSON-serializable."""
@overload
def json_map_leaves(
func: Callable[[T], U],
value: Dict[str, JSONTree[T]],
) -> Dict[str, JSONTree[U]]:
...
@overload
def json_map_leaves(
func: Callable[[T], U],
value: List[JSONTree[T]],
) -> List[JSONTree[U]]:
...
@overload
def json_map_leaves(
func: Callable[[T], U],
value: Tuple[JSONTree[T], ...],
) -> Tuple[JSONTree[U], ...]:
...
@overload
def json_map_leaves(
func: Callable[[T], U],
value: JSONTree[T],
) -> JSONTree[U]:
...
def json_map_leaves(func: Callable[[T], U], value: JSONTree[T]) -> JSONTree[U]:
if isinstance(value, dict):
return {k: json_map_leaves(func, v) for k, v in value.items()}
elif isinstance(value, list):
return [json_map_leaves(func, v) for v in value]
elif isinstance(value, tuple):
return tuple(json_map_leaves(func, v) for v in value)
else:
return func(value)
def flatten_2d_lists(lists: List[List[T]]) -> List[T]:
"""Flatten a list of lists to a single list."""
return [item for sublist in lists for item in sublist]
......@@ -881,27 +938,6 @@ def cuda_device_count_stateless() -> int:
return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES)
def error_on_invalid_device_count_status():
cache_entries = 0
with contextlib.suppress(Exception):
# future pytorch will fix the issue, device_count will not be cached
# at that time, `.cache_info().currsize` will error out
cache_entries = torch.cuda.device_count.cache_info().currsize
if cache_entries != 0:
# the function is already called, and the result is cached
remembered = torch.cuda.device_count()
current = cuda_device_count_stateless()
if remembered > current:
raise RuntimeError(
"The number of CUDA devices has changed since the first "
"call to torch.cuda.device_count(). This is not allowed "
"and may result in undefined behavior. Please check out "
"https://github.com/vllm-project/vllm/issues/6056 to "
"find the first call to torch.cuda.device_count() "
"and defer it until the engine is up. Or you can set "
"CUDA_VISIBLE_DEVICES to the GPUs you want to use.")
# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
# all the related functions work on real physical device ids.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment