From 7ca9934fe773edf8680aed287b0a05cb195bd8e4 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Thu, 6 Feb 2025 04:02:14 -0500 Subject: [PATCH 001/253] [Misc] Update w2 scale loading for GPTQMarlinMoE (#12757) --- tests/weight_loading/models-large.txt | 2 ++ vllm/model_executor/layers/fused_moe/layer.py | 4 ++-- .../layers/quantization/gptq_marlin.py | 23 ++++++++++++++----- 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/tests/weight_loading/models-large.txt b/tests/weight_loading/models-large.txt index 8ab7f05d7..9c1c11da5 100644 --- a/tests/weight_loading/models-large.txt +++ b/tests/weight_loading/models-large.txt @@ -1,5 +1,7 @@ compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W8A16-quantized, main +compressed-tensors, nm-testing/test-w4a16-mixtral-actorder-group, main gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main +gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, gptq-8bit-128g-actorder_True awq_marlin, casperhansen/deepseek-coder-v2-instruct-awq, main \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 3c7ef5e00..f18c03133 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -302,8 +302,8 @@ class FusedMoE(torch.nn.Module): "weight_loader": self.weight_loader, } # need full intermediate size pre-sharding for WNA16 act order - if (self.quant_method.__class__.__name__ == - "CompressedTensorsWNA16MoEMethod"): + if (self.quant_method.__class__.__name__ + in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")): moe_quant_params["intermediate_size_full"] = intermediate_size self.quant_method.create_weights(layer=self, **moe_quant_params) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py 
b/vllm/model_executor/layers/quantization/gptq_marlin.py index 99ab29995..84c53b2c1 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -323,13 +323,18 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): params_dtype: torch.dtype, **extra_weight_attrs, ): - # Currently assuming is_k_full is always True - # (input size per partition is the same as full input size) - # Supports only sym for now (no zp) + intermediate_size_full = extra_weight_attrs.pop( + "intermediate_size_full") + + self.is_k_full = (not self.quant_config.desc_act) or ( + intermediate_size_per_partition == intermediate_size_full) + if self.quant_config.group_size != -1: scales_size13 = hidden_size // self.quant_config.group_size - scales_size2 = (intermediate_size_per_partition // - self.quant_config.group_size) + w2_scales_size = (intermediate_size_full + if self.quant_config.desc_act else + intermediate_size_per_partition) + scales_size2 = (w2_scales_size // self.quant_config.group_size) strategy = FusedMoeWeightScaleSupported.GROUP.value else: scales_size13 = 1 @@ -385,6 +390,9 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ) layer.register_parameter("w2_scales", w2_scales) set_weight_attrs(w2_scales, extra_weight_attrs) + # dont shard the w2 scales when running act order + set_weight_attrs(w2_scales, + {"load_full_w2": self.quant_config.desc_act}) # up_proj scales w13_qzeros = torch.nn.Parameter( torch.empty(num_experts, @@ -406,6 +414,9 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ) layer.register_parameter("w2_qzeros", w2_qzeros) set_weight_attrs(w2_qzeros, extra_weight_attrs) + # dont shard the w2 scales when running act order + set_weight_attrs(w2_qzeros, + {"load_full_w2": self.quant_config.desc_act}) w13_g_idx = torch.nn.Parameter( torch.empty( num_experts, @@ -575,4 +586,4 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): sort_indices1=layer.w13_g_idx_sort_indices, 
sort_indices2=layer.w2_g_idx_sort_indices, num_bits=self.quant_config.quant_type.size_bits, - ).to(orig_dtype) + is_k_full=self.is_k_full).to(orig_dtype) -- GitLab From cefd56ee354b915e2fff6b2b5eb1f8b55721fe7e Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Thu, 6 Feb 2025 01:02:38 -0800 Subject: [PATCH 002/253] [Docs] Add Google Cloud Slides (#12814) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 09c2c6d35..cd0b1c517 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,7 @@ Easy, fast, and cheap LLM serving for everyone *Latest News* 🔥 - [2025/01] We are excited to announce the alpha release of vLLM V1: A major architectural upgrade with 1.7x speedup! Clean code, optimized execution loop, zero-overhead prefix caching, enhanced multimodal support, and more. Please check out our blog post [here](https://blog.vllm.ai/2025/01/27/v1-alpha-release.html). -- [2025/01] We hosted [the eighth vLLM meetup](https://lu.ma/zep56hui) with Google Cloud! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1epVkt4Zu8Jz_S5OhEHPc798emsYh2BwYfRuDDVEF7u4/edit?usp=sharing). +- [2025/01] We hosted [the eighth vLLM meetup](https://lu.ma/zep56hui) with Google Cloud! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1epVkt4Zu8Jz_S5OhEHPc798emsYh2BwYfRuDDVEF7u4/edit?usp=sharing), and Google Cloud team [here](https://drive.google.com/file/d/1h24pHewANyRL11xy5dXUbvRC9F9Kkjix/view?usp=sharing). - [2024/12] vLLM joins [pytorch ecosystem](https://pytorch.org/blog/vllm-joins-pytorch)! Easy, Fast, and Cheap LLM Serving for Everyone! - [2024/11] We hosted [the seventh vLLM meetup](https://lu.ma/h0qvrajz) with Snowflake! 
Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1e3CxQBV3JsfGp30SwyvS3eM_tW-ghOhJ9PAJGK6KR54/edit?usp=sharing), and Snowflake team [here](https://docs.google.com/presentation/d/1qF3RkDAbOULwz9WK5TOltt2fE9t6uIc_hVNLFAaQX6A/edit?usp=sharing). - [2024/10] We have just created a developer slack ([slack.vllm.ai](https://slack.vllm.ai)) focusing on coordinating contributions and discussing features. Please feel free to join us there! -- GitLab From c786e757fae4519256e4ef88a7d4f56c3339d14d Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Thu, 6 Feb 2025 06:43:12 -0500 Subject: [PATCH 003/253] [Attention] Use FA3 for MLA on Hopper (#12807) Signed-off-by: Lucas Wilkinson --- vllm/attention/backends/flash_attn.py | 44 ++++++------------------ vllm/attention/backends/mla/utils.py | 2 ++ vllm/attention/backends/utils.py | 34 ++++++++++++++++++ vllm/v1/attention/backends/flash_attn.py | 30 +++------------- 4 files changed, 51 insertions(+), 59 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 6a82127ac..971fe4116 100755 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -14,19 +14,16 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadataBuilder, AttentionType) from vllm.attention.backends.utils import ( - PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping, - compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens, - get_seq_len_block_table_args, is_all_cross_attn_metadata_set, - is_all_encoder_attn_metadata_set, is_block_tables_empty) -from vllm.envs import VLLM_FLASH_ATTN_VERSION + PAD_SLOT_ID, VLLM_FLASH_ATTN_VERSION, CommonAttentionState, + compute_slot_mapping, compute_slot_mapping_start_idx, + get_num_prefill_decode_query_kv_tokens, get_seq_len_block_table_args, + is_all_cross_attn_metadata_set, is_all_encoder_attn_metadata_set, + is_block_tables_empty) from 
vllm.logger import init_logger from vllm.multimodal import MultiModalPlaceholderMap -from vllm.platforms import current_platform from vllm.utils import async_tensor_h2d, make_tensor_with_pad -from vllm.vllm_flash_attn import (fa_version_unsupported_reason, - flash_attn_varlen_func, - flash_attn_with_kvcache, - is_fa_version_supported) +from vllm.vllm_flash_attn import (flash_attn_varlen_func, + flash_attn_with_kvcache) if TYPE_CHECKING: from vllm.worker.model_runner import (ModelInputForGPUBuilder, @@ -644,25 +641,6 @@ class FlashAttentionImpl(AttentionImpl): f"Supported head sizes are: {support_head_sizes}.") self.attn_type = attn_type - # if hopper default to FA3, otherwise stick to FA2 for now - # TODO(lucas): profile FA3 on ampere to see if it makes sense to - # use FA3 as default for both - if current_platform.get_device_capability()[0] >= 9: - self.fa_version = 3 if is_fa_version_supported(3) else 2 - else: - self.fa_version = 2 - - if VLLM_FLASH_ATTN_VERSION is not None: - assert VLLM_FLASH_ATTN_VERSION in [2, 3] - self.fa_version = VLLM_FLASH_ATTN_VERSION - - if not is_fa_version_supported(self.fa_version): - logger.error("Cannot use FA version %d is not supported due to %s", - self.fa_version, - fa_version_unsupported_reason(self.fa_version)) - - assert is_fa_version_supported(self.fa_version) - def forward( self, layer: AttentionLayer, @@ -781,7 +759,7 @@ class FlashAttentionImpl(AttentionImpl): alibi_slopes=alibi_slopes, softcap=logits_soft_cap, out=prefill_output, - fa_version=self.fa_version, + fa_version=VLLM_FLASH_ATTN_VERSION, ) else: # prefix-enabled attention @@ -804,7 +782,7 @@ class FlashAttentionImpl(AttentionImpl): block_table=prefill_meta.block_tables, softcap=logits_soft_cap, out=prefill_output, - fa_version=self.fa_version, + fa_version=VLLM_FLASH_ATTN_VERSION, ) if decode_meta := attn_metadata.decode_metadata: @@ -833,7 +811,7 @@ class FlashAttentionImpl(AttentionImpl): softcap=logits_soft_cap, block_table=decode_meta.block_tables, 
out=decode_output, - fa_version=self.fa_version, + fa_version=VLLM_FLASH_ATTN_VERSION, ) else: # Use flash_attn_with_kvcache for normal decoding. @@ -854,7 +832,7 @@ class FlashAttentionImpl(AttentionImpl): alibi_slopes=alibi_slopes, softcap=logits_soft_cap, out=decode_output.unsqueeze(1), - fa_version=self.fa_version, + fa_version=VLLM_FLASH_ATTN_VERSION, ) return output diff --git a/vllm/attention/backends/mla/utils.py b/vllm/attention/backends/mla/utils.py index cd8c08e5a..e1285d1fa 100644 --- a/vllm/attention/backends/mla/utils.py +++ b/vllm/attention/backends/mla/utils.py @@ -12,6 +12,7 @@ from vllm import envs from vllm.attention.backends.abstract import (AttentionLayer, AttentionMetadata, MLAAttentionImpl, T) +from vllm.attention.backends.utils import VLLM_FLASH_ATTN_VERSION from vllm.distributed import (get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -533,6 +534,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): max_seqlen_k=max_prefill_seq_len, softmax_scale=self.scale, causal=True, + fa_version=VLLM_FLASH_ATTN_VERSION, ) attn_output = attn_output\ .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\ diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index ad53e4e70..3c5028a66 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -8,12 +8,17 @@ from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union import numpy as np import torch +from vllm import envs from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder, AttentionState) from vllm.attention.backends.abstract import AttentionType +from vllm.logger import logging from vllm.multimodal import MultiModalPlaceholderMap +from vllm.platforms import current_platform from vllm.utils import async_tensor_h2d, make_tensor_with_pad +logger = logging.getLogger(__name__) + if TYPE_CHECKING: from 
vllm.worker.model_runner_base import ModelRunnerBase @@ -580,3 +585,32 @@ def get_num_prefill_decode_query_kv_tokens( return (num_prefill_query_tokens, num_prefill_kv_tokens, num_decode_query_tokens) + + +try: + from vllm.vllm_flash_attn.flash_attn_interface import ( + fa_version_unsupported_reason, is_fa_version_supported) + + def flash_attn_version(): + # if hopper default to FA3, otherwise stick to FA2 for now + # TODO(lucas): profile FA3 on ampere to see if it makes sense to + # use FA3 as default for both + if current_platform.get_device_capability()[0] >= 9: + fa_version = 3 if is_fa_version_supported(3) else 2 + else: + fa_version = 2 + + if envs.VLLM_FLASH_ATTN_VERSION is not None: + assert envs.VLLM_FLASH_ATTN_VERSION in [2, 3] + fa_version = envs.VLLM_FLASH_ATTN_VERSION + + if not is_fa_version_supported(fa_version): + logger.error("Cannot use FA version %d is not supported due to %s", + fa_version, fa_version_unsupported_reason(fa_version)) + + assert is_fa_version_supported(fa_version) + return fa_version + + VLLM_FLASH_ATTN_VERSION = flash_attn_version() +except ImportError: + VLLM_FLASH_ATTN_VERSION = None diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 837d7faf4..204afc9f4 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -10,13 +10,10 @@ import triton.language as tl from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) -from vllm.envs import VLLM_FLASH_ATTN_VERSION +from vllm.attention.backends.utils import VLLM_FLASH_ATTN_VERSION from vllm.logger import init_logger -from vllm.platforms import current_platform from vllm.utils import cdiv -from vllm.vllm_flash_attn import (fa_version_unsupported_reason, - flash_attn_varlen_func, - is_fa_version_supported) +from vllm.vllm_flash_attn import flash_attn_varlen_func logger = init_logger(__name__) @@ -136,25 +133,6 @@ class 
FlashAttentionImpl(AttentionImpl): "are not implemented for " "FlashAttentionImpl") - # if hopper default to FA3, otherwise stick to FA2 for now - # TODO(lucas): profile FA3 on ampere to see if it makes sense to - # use FA3 as default for both - if current_platform.get_device_capability()[0] >= 9: - self.fa_version = 3 if is_fa_version_supported(3) else 2 - else: - self.fa_version = 2 - - if VLLM_FLASH_ATTN_VERSION is not None: - assert VLLM_FLASH_ATTN_VERSION in [2, 3] - self.fa_version = VLLM_FLASH_ATTN_VERSION - - if not is_fa_version_supported(self.fa_version): - logger.error("Cannot use FA version %d is not supported due to %s", - self.fa_version, - fa_version_unsupported_reason(self.fa_version)) - - assert is_fa_version_supported(self.fa_version) - def forward( self, layer: torch.nn.Module, @@ -227,7 +205,7 @@ class FlashAttentionImpl(AttentionImpl): window_size=self.sliding_window, block_table=attn_metadata.block_table, softcap=self.logits_soft_cap, - fa_version=self.fa_version, + fa_version=VLLM_FLASH_ATTN_VERSION, ) return output @@ -249,7 +227,7 @@ class FlashAttentionImpl(AttentionImpl): logits_soft_cap=self.logits_soft_cap, block_table=attn_metadata.block_table, common_prefix_len=attn_metadata.common_prefix_len, - fa_version=self.fa_version, + fa_version=VLLM_FLASH_ATTN_VERSION, ) return output -- GitLab From e152f295020ea2a7ca37be9cabadfaef78464274 Mon Sep 17 00:00:00 2001 From: "Kevin H. 
Luu" Date: Thu, 6 Feb 2025 06:59:18 -0800 Subject: [PATCH 004/253] [misc] Reduce number of config file requests to HuggingFace (#12797) Signed-off-by: EC2 Default User Signed-off-by: <> Co-authored-by: EC2 Default User --- vllm/transformers_utils/config.py | 36 ++++++++++++++++++++----------- 1 file changed, 23 insertions(+), 13 deletions(-) diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 1c0f20a6e..85056158b 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -7,7 +7,7 @@ from pathlib import Path from typing import Any, Dict, Optional, Type, Union import huggingface_hub -from huggingface_hub import (file_exists, hf_hub_download, +from huggingface_hub import (file_exists, hf_hub_download, list_repo_files, try_to_load_from_cache) from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError, LocalEntryNotFoundError, @@ -395,18 +395,28 @@ def get_sentence_transformer_tokenizer_config(model: str, - dict: A dictionary containing the configuration parameters for the Sentence Transformer BERT model. 
""" - for config_name in [ - "sentence_bert_config.json", - "sentence_roberta_config.json", - "sentence_distilbert_config.json", - "sentence_camembert_config.json", - "sentence_albert_config.json", - "sentence_xlm-roberta_config.json", - "sentence_xlnet_config.json", - ]: - encoder_dict = get_hf_file_to_dict(config_name, model, revision) - if encoder_dict: - break + sentence_transformer_config_files = [ + "sentence_bert_config.json", + "sentence_roberta_config.json", + "sentence_distilbert_config.json", + "sentence_camembert_config.json", + "sentence_albert_config.json", + "sentence_xlm-roberta_config.json", + "sentence_xlnet_config.json", + ] + try: + # If model is on HuggingfaceHub, get the repo files + repo_files = list_repo_files(model, revision=revision, token=HF_TOKEN) + except Exception as e: + logger.debug("Error getting repo files", e) + repo_files = [] + + encoder_dict = None + for config_name in sentence_transformer_config_files: + if config_name in repo_files or Path(model).exists(): + encoder_dict = get_hf_file_to_dict(config_name, model, revision) + if encoder_dict: + break if not encoder_dict: return None -- GitLab From 1e57b1ee6312325a9dab99918422693c38f2b203 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 7 Feb 2025 00:45:44 +0800 Subject: [PATCH 005/253] [Misc] Remove unnecessary decode call (#12833) --- vllm/inputs/preprocess.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 4d8f28cb0..53f89996f 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -260,8 +260,6 @@ class InputPreprocessor: mm_processor = self.mm_registry.create_processor( self.model_config, tokenizer) - if isinstance(prompt, list): - prompt = tokenizer.decode(prompt) if mm_processor_kwargs is None: mm_processor_kwargs = {} -- GitLab From 85ac82d228ef6af4e8fc6332d918133e783a0fdb Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Fri, 7 Feb 2025 00:46:13 +0800 Subject: [PATCH 006/253] [Kernel] Make 
rotary_embedding ops more flexible with input shape (#12777) --- csrc/pos_encoding_kernels.cu | 103 +++++++++++++++++++--- tests/kernels/test_pos_encoding.py | 31 +++++-- vllm/attention/backends/mla/utils.py | 25 +----- vllm/model_executor/models/deepseek_v2.py | 13 +-- 4 files changed, 115 insertions(+), 57 deletions(-) diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu index 97184a873..c085d31a3 100644 --- a/csrc/pos_encoding_kernels.cu +++ b/csrc/pos_encoding_kernels.cu @@ -124,18 +124,54 @@ __global__ void batched_rotary_embedding_kernel( void rotary_embedding( torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens] torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or - // [num_tokens, num_heads * head_size] + // [num_tokens, num_heads * head_size] or + // [batch_size, seq_len, num_heads, head_size] or + // [num_tokens, num_heads, head_size] torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or - // [num_tokens, num_kv_heads * head_size] + // [num_tokens, num_kv_heads * head_size] or + // [batch_size, seq_len, num_heads, head_size] or + // [num_tokens, num_heads, head_size] int64_t head_size, torch::Tensor& cos_sin_cache, // [max_position, rot_dim] bool is_neox) { - int64_t num_tokens = query.numel() / query.size(-1); + // num_tokens = batch_size * seq_len + int64_t num_tokens = positions.numel(); + int positions_ndim = positions.dim(); + + // Make sure num_tokens dim is consistent across positions, query, and key. 
+ TORCH_CHECK( + positions_ndim == 1 || positions_ndim == 2, + "positions must have shape [num_tokens] or [batch_size, seq_len]"); + if (positions_ndim == 1) { + TORCH_CHECK( + query.size(0) == positions.size(0) && key.size(0) == positions.size(0), + "query, key and positions must have the same number of tokens"); + } + if (positions_ndim == 2) { + TORCH_CHECK( + query.size(0) == positions.size(0) && + key.size(0) == positions.size(0) && + query.size(1) == positions.size(1) && + key.size(1) == positions.size(1), + "query, key and positions must have the same batch_size and seq_len"); + } + + // Make sure head_size is valid for query and key + // hidden_size = num_heads * head_size + int query_hidden_size = query.numel() / num_tokens; + int key_hidden_size = key.numel() / num_tokens; + TORCH_CHECK(query_hidden_size % head_size == 0); + TORCH_CHECK(key_hidden_size % head_size == 0); + + // Make sure query and key have consistent number of heads + int num_heads = query_hidden_size / head_size; + int num_kv_heads = key_hidden_size / head_size; + TORCH_CHECK(num_heads % num_kv_heads == 0); + int rot_dim = cos_sin_cache.size(1); - int num_heads = query.size(-1) / head_size; - int num_kv_heads = key.size(-1) / head_size; - int64_t query_stride = query.stride(-2); - int64_t key_stride = key.stride(-2); + int seq_dim_idx = positions_ndim - 1; + int64_t query_stride = query.stride(seq_dim_idx); + int64_t key_stride = key.stride(seq_dim_idx); dim3 grid(num_tokens); dim3 block(std::min(num_heads * rot_dim / 2, 512)); @@ -165,19 +201,58 @@ and process in batched manner. 
void batched_rotary_embedding( torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens] torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or - // [num_tokens, num_heads * head_size] + // [num_tokens, num_heads * head_size] or + // [batch_size, seq_len, num_heads, head_size] or + // [num_tokens, num_heads, head_size] torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or - // [num_tokens, num_kv_heads * head_size] + // [num_tokens, num_kv_heads * head_size] or + // [batch_size, seq_len, num_heads, head_size] or + // [num_tokens, num_heads, head_size] int64_t head_size, torch::Tensor& cos_sin_cache, // [max_position, rot_dim] bool is_neox, int64_t rot_dim, - torch::Tensor& cos_sin_cache_offsets // [num_tokens] + torch::Tensor& cos_sin_cache_offsets // [num_tokens] or [batch_size] ) { + // num_tokens = batch_size * seq_len int64_t num_tokens = cos_sin_cache_offsets.size(0); - int num_heads = query.size(-1) / head_size; - int num_kv_heads = key.size(-1) / head_size; - int64_t query_stride = query.stride(-2); - int64_t key_stride = key.stride(-2); + TORCH_CHECK( + positions.size(0) == num_tokens || positions.numel() == num_tokens, + "positions must have the same num_tokens or batch_size as " + "cos_sin_cache_offsets"); + + int positions_ndim = positions.dim(); + // Make sure num_tokens dim is consistent across positions, query, and key. 
+ TORCH_CHECK( + positions_ndim == 1 || positions_ndim == 2, + "positions must have shape [num_tokens] or [batch_size, seq_len]"); + if (positions_ndim == 1) { + TORCH_CHECK( + query.size(0) == positions.size(0) && key.size(0) == positions.size(0), + "query, key and positions must have the same number of tokens"); + } + if (positions_ndim == 2) { + TORCH_CHECK( + query.size(0) == positions.size(0) && + key.size(0) == positions.size(0) && + query.size(1) == positions.size(1) && + key.size(1) == positions.size(1), + "query, key and positions must have the same batch_size and seq_len"); + } + + // Make sure head_size is valid for query and key + int query_hidden_size = query.numel() / num_tokens; + int key_hidden_size = key.numel() / num_tokens; + TORCH_CHECK(query_hidden_size % head_size == 0); + TORCH_CHECK(key_hidden_size % head_size == 0); + + // Make sure query and key have concistent number of heads + int num_heads = query_hidden_size / head_size; + int num_kv_heads = key_hidden_size / head_size; + TORCH_CHECK(num_heads % num_kv_heads == 0); + + int seq_dim_idx = positions_ndim - 1; + int64_t query_stride = query.stride(seq_dim_idx); + int64_t key_stride = key.stride(seq_dim_idx); dim3 grid(num_tokens); dim3 block(std::min(num_heads * rot_dim / 2, 512)); diff --git a/tests/kernels/test_pos_encoding.py b/tests/kernels/test_pos_encoding.py index 5b7b0fda2..af9bfd2f0 100644 --- a/tests/kernels/test_pos_encoding.py +++ b/tests/kernels/test_pos_encoding.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from itertools import accumulate, product -from typing import Dict, List, Optional +from typing import Callable, Dict, List, Optional import pytest import torch @@ -24,7 +24,21 @@ CUDA_DEVICES = [ ] +def _get_flat_tensor_shape(batch_size: int, seq_len: int, num_heads: int, + head_size: int) -> tuple[int, ...]: + return (batch_size, seq_len, num_heads * head_size) + + +def _get_batch_tensor_shape(batch_size: int, seq_len: int, num_heads: int, + head_size: int) -> 
tuple[int, ...]: + return (batch_size, seq_len, num_heads, head_size) + + +TENSORS_SHAPES_FN = [_get_batch_tensor_shape, _get_flat_tensor_shape] + + @pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE) +@pytest.mark.parametrize("tensor_shape_fn", TENSORS_SHAPES_FN) @pytest.mark.parametrize("batch_size", BATCH_SIZES) @pytest.mark.parametrize("seq_len", SEQ_LENS) @pytest.mark.parametrize("num_heads", NUM_HEADS) @@ -36,6 +50,7 @@ CUDA_DEVICES = [ @torch.inference_mode() def test_rotary_embedding( is_neox_style: bool, + tensor_shape_fn: Callable[[int, int, int, int], tuple[int]], batch_size: int, seq_len: int, num_heads: int, @@ -58,10 +73,8 @@ def test_rotary_embedding( rope = rope.to(dtype=dtype) positions = torch.randint(0, max_position, (batch_size, seq_len)) - query = torch.randn(batch_size, - seq_len, - num_heads * head_size, - dtype=dtype) + query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size) + query = torch.randn(query_shape, dtype=dtype) key = torch.randn_like(query) # NOTE(woosuk): The reference implementation should be executed first @@ -80,6 +93,7 @@ def test_rotary_embedding( @pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE) +@pytest.mark.parametrize("tensor_shape_fn", TENSORS_SHAPES_FN) @pytest.mark.parametrize("batch_size", BATCH_SIZES) @pytest.mark.parametrize("seq_len", SEQ_LENS) @pytest.mark.parametrize("num_heads", NUM_HEADS) @@ -91,6 +105,7 @@ def test_rotary_embedding( @torch.inference_mode() def test_batched_rotary_embedding( is_neox_style: bool, + tensor_shape_fn: Callable[[int, int, int, int], tuple[int]], batch_size: int, seq_len: int, num_heads: int, @@ -113,10 +128,8 @@ def test_batched_rotary_embedding( rope = rope.to(dtype=dtype) positions = torch.randint(0, max_position, (batch_size, seq_len)) - query = torch.randn(batch_size, - seq_len, - num_heads * head_size, - dtype=dtype) + query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size) + query = torch.randn(query_shape, dtype=dtype) key = 
torch.randn_like(query) # NOTE(woosuk): The reference implementation should be executed first diff --git a/vllm/attention/backends/mla/utils.py b/vllm/attention/backends/mla/utils.py index e1285d1fa..c22f7e921 100644 --- a/vllm/attention/backends/mla/utils.py +++ b/vllm/attention/backends/mla/utils.py @@ -424,24 +424,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ) -> torch.Tensor: raise NotImplementedError - def apply_pure_rope( - self, - input_positions: torch.Tensor, - q_pe: torch.Tensor, - k_pe: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: - seq_len = input_positions.size(0) - ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape - - q_pe, k_pe = self.rotary_emb( - input_positions, - q_pe.reshape(seq_len, -1), - k_pe.reshape(seq_len, -1), - ) - q_pe, k_pe = q_pe.view(ori_q_pe_shape), k_pe.view(ori_k_pe_shape) - - return q_pe, k_pe - def forward( self, layer: AttentionLayer, @@ -466,14 +448,13 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): # Restore head dim (for rotary embedding) k_pe = k_pe.unsqueeze(1) assert hasattr(attn_metadata, "input_positions") - rope_fn = (self.rotary_emb - if self.use_yarn_rope else self.apply_pure_rope) if is_decode: q_nope = self._q_proj_and_k_up_proj(hidden_states_or_q_c) q_pe = torch.matmul(hidden_states_or_q_c, self.W_QR)\ .view(-1, self.num_heads, self.qk_rope_head_dim) - q_pe, k_pe = rope_fn(attn_metadata.input_positions, q_pe, k_pe) + q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions, q_pe, + k_pe) else: assert is_prefill q = self.q_proj(hidden_states_or_q_c)[0]\ @@ -481,7 +462,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): # TODO(lucas): there must be a nicer way to write this line q[..., self.qk_nope_head_dim:], k_pe = \ - rope_fn( + self.rotary_emb( attn_metadata.input_positions, q[..., self.qk_nope_head_dim:], k_pe) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 773f5abe7..0c6f07ce7 100644 --- 
a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -257,9 +257,7 @@ class DeepseekV2Attention(nn.Module): prefix=f"{prefix}.o_proj") if rope_scaling: rope_scaling["rope_type"] = 'deepseek_yarn' - self.use_normal_rope = False - else: - self.use_normal_rope = True + self.rotary_emb = get_rope(qk_rope_head_dim, rotary_dim=qk_rope_head_dim, max_position=max_position_embeddings, @@ -309,17 +307,8 @@ class DeepseekV2Attention(nn.Module): k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) k_pe = latent_cache[:, :, self.kv_lora_rank:] - if self.use_normal_rope: - seq_len = positions.size(0) - ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape - q_pe = q_pe.reshape(seq_len, -1) - k_pe = k_pe.reshape(seq_len, -1) - q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - if self.use_normal_rope: - q_pe, k_pe = q_pe.view(ori_q_pe_shape), k_pe.view(ori_k_pe_shape) - q[..., self.qk_nope_head_dim:] = q_pe k = torch.empty_like(q) k[..., :self.qk_nope_head_dim] = k_nope -- GitLab From 09b95e36abbce747b52b9c3e7ae4cceaf40076ad Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 7 Feb 2025 01:09:07 +0800 Subject: [PATCH 007/253] [torch.compile] PyTorch 2.6 and nightly compatibility (#12393) Signed-off-by: youkaichao --- tests/compile/piecewise/test_simple.py | 2 +- tests/compile/piecewise/test_toy_llama.py | 6 +- vllm/compilation/backends.py | 437 +++++++--------------- vllm/compilation/compiler_interface.py | 340 +++++++++++++++++ vllm/compilation/counter.py | 2 +- vllm/compilation/inductor_pass.py | 1 - vllm/compilation/pass_manager.py | 16 +- vllm/config.py | 9 - 8 files changed, 493 insertions(+), 320 deletions(-) create mode 100644 vllm/compilation/compiler_interface.py diff --git a/tests/compile/piecewise/test_simple.py b/tests/compile/piecewise/test_simple.py index 9d633ad25..143cb4969 100644 --- a/tests/compile/piecewise/test_simple.py +++ b/tests/compile/piecewise/test_simple.py @@ -92,7 +92,7 @@ def 
test_simple_piecewise_compile(): num_graphs_seen=1, # one graph for the model num_piecewise_graphs_seen=5, # 2 * num_layers + 1 num_piecewise_capturable_graphs_seen=3, # 1 + num_layers - num_inductor_compilations=3, # num_piecewise_capturable_graphs_seen + num_backend_compilations=3, # num_piecewise_capturable_graphs_seen num_cudagraph_caputured= 6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen ): diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py index 0404722ba..021bd4cc4 100644 --- a/tests/compile/piecewise/test_toy_llama.py +++ b/tests/compile/piecewise/test_toy_llama.py @@ -322,7 +322,7 @@ def test_toy_llama(): num_graphs_seen=0, num_piecewise_graphs_seen=0, num_piecewise_capturable_graphs_seen=0, - num_inductor_compilations=0, + num_backend_compilations=0, num_cudagraph_caputured=0, ): outputs.append(run_model(llama_config, use_compile=False)) @@ -332,7 +332,7 @@ def test_toy_llama(): num_graphs_seen=1, # one graph for the model num_piecewise_graphs_seen=1, num_piecewise_capturable_graphs_seen=1, - num_inductor_compilations=1, # num_piecewise_capturable_graphs_seen + num_backend_compilations=1, # num_piecewise_capturable_graphs_seen num_cudagraph_caputured= 2, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen ): @@ -345,7 +345,7 @@ def test_toy_llama(): 1, # 2 * num_layers + 1 num_piecewise_capturable_graphs_seen=1 + llama_config.num_layers, # 1 + num_layers - num_inductor_compilations=1 + + num_backend_compilations=1 + llama_config.num_layers, # num_piecewise_capturable_graphs_seen num_cudagraph_caputured=2 * (1 + llama_config.num_layers diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 979890170..b972f03c9 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -1,12 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 import ast -import copy import dataclasses import os import pprint import time -from collections import defaultdict from 
contextlib import ExitStack from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple from unittest.mock import patch @@ -19,6 +17,7 @@ from vllm.config import CompilationConfig, VllmConfig from vllm.logger import init_logger from vllm.utils import weak_ref_tensors +from .compiler_interface import EagerAdaptor, InductorAdaptor from .counter import compilation_counter from .inductor_pass import InductorPass from .monitor import end_monitoring_torch_compile @@ -27,306 +26,128 @@ from .pass_manager import PostGradPassManager logger = init_logger(__name__) -@dataclasses.dataclass -class InductorArtifact: - hash_str: str = "" - file_path: str = "" +class CompilerManager: + """ + A manager to manage the compilation process, including + caching the compiled graph, loading the compiled graph, + and compiling the graph. + The cache is a dict mapping + `(runtime_shape, graph_index, backend_name)` + to `any_data` returned from the compiler. -class InductorHashCache: + When serializing the cache, we save it to a Python file + for readability. We don't use json here because json doesn't + support int as key. """ - Disk format: a Python list of tuples, each tuple is - (runtime_shape, graph_index, hash_str, file_path) - We use list of tuple for readability. - In-memory format: a defaultdict of dict, where the key is - runtime_shape, and the value is a dict of graph_index to hash_str. + def __init__(self, use_inductor: bool): + self.cache: Dict[Tuple[Optional[int], int, str], Any] = dict() + cls = InductorAdaptor if use_inductor else EagerAdaptor + self.compiler = cls() - The data is essentially `Dict[Optional[int], Dict[int, InductorArtifact]]`, - we don't use json here because json doesn't support int as key. - - TODO: better off-the-shelf solution to serialize the data? 
- """ + def compute_hash(self, vllm_config: VllmConfig) -> str: + return self.compiler.compute_hash(vllm_config) - def __init__(self, cache_dir: str, disabled: bool = False): - self.cache: Dict[Optional[int], - Dict[int, InductorArtifact]] = defaultdict(dict) - self.disabled = disabled + def initialize_cache(self, cache_dir: str, disable_cache: bool = False): + self.disable_cache = disable_cache self.cache_dir = cache_dir - self.cache_file_path = os.path.join(cache_dir, - "inductor_hash_cache.py") - if disabled: - return - # set flags so that Inductor and Triton store their cache - # in the cache_dir, then users only need to copy the cache_dir - # to another machine to reuse the cache. - inductor_cache = os.path.join(cache_dir, "inductor_cache") - os.makedirs(inductor_cache, exist_ok=True) - os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache - triton_cache = os.path.join(cache_dir, "triton_cache") - os.makedirs(triton_cache, exist_ok=True) - os.environ["TRITON_CACHE_DIR"] = triton_cache - if os.path.exists(self.cache_file_path): + self.cache_file_path = os.path.join(cache_dir, "vllm_compile_cache.py") + + if not disable_cache and os.path.exists(self.cache_file_path): + # load the cache from the file with open(self.cache_file_path) as f: - self.deserialize(f.read()) - - def deserialize(self, data: str): - # we use ast.literal_eval to parse the data - # because it is a safe way to parse Python literals. - # do not use eval(), it is unsafe. - list_data = ast.literal_eval(data) - for item in list_data: - runtime_shape = item[0] - graph_index = item[1] - hash_str = item[2] - # for compatibility of old version, - # where we don't have file_path. - # NOTE: after running the new code, the file_path - # will be updated. 
- file_path = "" if len(item) == 3 else item[3] - self.cache[runtime_shape][graph_index] = InductorArtifact( - hash_str=hash_str, file_path=file_path) - - def serialize(self) -> str: - data = [] - for runtime_shape, value in self.cache.items(): - for graph_index, inductor_artifact in value.items(): - data.append( - (runtime_shape, graph_index, inductor_artifact.hash_str, - inductor_artifact.file_path)) - printer = pprint.PrettyPrinter(indent=4) - return printer.pformat(data) + # we use ast.literal_eval to parse the data + # because it is a safe way to parse Python literals. + # do not use eval(), it is unsafe. + self.cache = ast.literal_eval(f.read()) + + self.compiler.initialize_cache(cache_dir=cache_dir, + disable_cache=disable_cache) def save_to_file(self): - if self.disabled: + if self.disable_cache: return with open(self.cache_file_path, "w") as f: - f.write(self.serialize()) - - def __contains__(self, key: Tuple[Optional[int], int]) -> bool: - if self.disabled: - return False - runtime_shape, graph_index = key - return runtime_shape in self.cache and graph_index in self.cache[ - runtime_shape] - - def __getitem__(self, key: Tuple[Optional[int], int]) -> InductorArtifact: - if self.disabled: - raise KeyError("cannot read from disabled cache") - runtime_shape, graph_index = key - return self.cache[runtime_shape][graph_index] - - def __setitem__(self, key: Tuple[Optional[int], int], - value: InductorArtifact): - # setitem for disabled cache is fine, because we - # don't actually write to the disk - runtime_shape, graph_index = key - self.cache[runtime_shape][graph_index] = value - - -class AlwaysHitShapeEnv: - """ - Why do we need this class: - - For normal `torch.compile` usage, every compilation will have - one Dynamo bytecode compilation and one Inductor compilation. - The Inductor compilation happens under the context of the - Dynamo bytecode compilation, and that context is used to - determine the dynamic shape information, etc. 
- - For our use case, we only run Dynamo bytecode compilation once, - and run Inductor compilation multiple times with different shapes - plus a general shape. The compilation for specific shapes happens - outside of the context of the Dynamo bytecode compilation. At that - time, we don't have shape environment to provide to Inductor, and - it will fail the Inductor code cache lookup. - - By providing a dummy shape environment that always hits, we can - make the Inductor code cache lookup always hit, and we can - compile the graph for different shapes as needed. - - The following dummy methods are obtained by trial-and-error - until it works. - """ - - def __init__(self) -> None: - self.guards: List[Any] = [] - - def evaluate_guards_expression(self, *args, **kwargs): - return True - - def get_pruned_guards(self, *args, **kwargs): - return [] - - def produce_guards_expression(self, *args, **kwargs): - return "" - - -def wrap_inductor(graph: fx.GraphModule, - example_inputs, - additional_inductor_config, - compilation_config: CompilationConfig, - vllm_backend: "VllmBackend", - graph_index: int = 0, - num_graphs: int = 1, - runtime_shape: Optional[int] = None, - use_inductor: bool = True) -> Any: - if graph_index == 0: - # before compiling the first graph, record the start time - global compilation_start_time - compilation_start_time = time.time() - - if not use_inductor: - return graph - - compilation_counter.num_inductor_compilations += 1 - - from torch._inductor import config - current_config = config.get_config_copy() - from torch._inductor.compile_fx import compile_fx - - if additional_inductor_config is not None: - current_config.update(additional_inductor_config) - - if isinstance(runtime_shape, int): - # for a specific batchsize, tuning triton kernel parameters - # can be beneficial - current_config["max_autotune"] = True - current_config["coordinate_descent_tuning"] = True - - # inductor can inplace modify the graph, so we need to copy it - # see 
https://github.com/pytorch/pytorch/issues/138980 - graph = copy.deepcopy(graph) - - cache_data = vllm_backend.inductor_hash_cache - if (runtime_shape, graph_index) in cache_data: - # we compiled this graph before - # so we can directly lookup the compiled graph via hash - inductor_artifact = cache_data[(runtime_shape, graph_index)] - hash_str = inductor_artifact.hash_str - if graph_index == 0: - # adds some info logging for the first graph - logger.info( - "Directly lookup the graph for shape %s from the cache", - str(runtime_shape)) # noqa + printer = pprint.PrettyPrinter(indent=4) + data = printer.pformat(self.cache) + f.write(data) + + def load(self, + graph: fx.GraphModule, + example_inputs: List[Any], + graph_index: int, + runtime_shape: Optional[int] = None) -> Optional[Callable]: + if (runtime_shape, graph_index, self.compiler.name) not in self.cache: + return None + handle = self.cache[(runtime_shape, graph_index, self.compiler.name)] + compiled_graph = self.compiler.load(handle, graph, example_inputs, + graph_index, runtime_shape) logger.debug( - "directly lookup the %s-th graph for shape %s via hash %s", - graph_index, str(runtime_shape), hash_str) - from torch._inductor.codecache import FxGraphCache - with patch("torch._inductor.codecache.FxGraphCache._get_shape_env", - lambda *args, **kwargs: AlwaysHitShapeEnv()): - inductor_compiled_graph = FxGraphCache._lookup_graph( - hash_str, example_inputs, True, False) - assert inductor_compiled_graph is not None, ( - "Inductor cache lookup failed. Please remove" - f"the cache file {cache_data.cache_file_path} and try again." 
# noqa - ) - inductor_artifact.file_path = inductor_compiled_graph.current_callable.__code__.co_filename # noqa - - # Inductor calling convention (function signature): - # f(list) -> tuple - # Dynamo calling convention (function signature): - # f(*args) -> Any - - # need to know if the graph returns a tuple - from torch._inductor.compile_fx import graph_returns_tuple - returns_tuple = graph_returns_tuple(graph) - - # this is the callable we return to Dynamo to run - def compiled_graph(*args): - # convert args to list - list_args = list(args) - graph_output = inductor_compiled_graph(list_args) - # unpack the tuple if needed - if returns_tuple: - return graph_output - else: - return graph_output[0] - else: - # it's the first time we compile this graph - # the assumption is that we don't have nested Inductor compilation. - # compiled_fx_graph_hash will only be called once, and we can hook - # it to get the hash of the compiled graph directly. - - inductor_artifact = InductorArtifact() - from torch._inductor.codecache import (FxGraphCache, - compiled_fx_graph_hash) - original_load = FxGraphCache.load - - def hijack_load(*args, **kwargs): - inductor_compiled_graph = original_load(*args, **kwargs) - inductor_artifact.file_path = inductor_compiled_graph.current_callable.__code__.co_filename # noqa - return inductor_compiled_graph - - def hijack_compiled_fx_graph_hash(*args, **kwargs): - out = compiled_fx_graph_hash(*args, **kwargs) - inductor_artifact.hash_str = out[0] - return out - - def _check_can_cache(*args, **kwargs): - # no error means it can be cached. - # Inductor refuses to cache the graph outside of Dynamo - # tracing context, and also disables caching for graphs - # with high-order ops. - # For vLLM, in either case, we want to cache the graph. 
- # see https://github.com/pytorch/pytorch/blob/9f5ebf3fc609105a74eab4ccc24932d6353ff566/torch/_inductor/codecache.py#L1221 # noqa - return - - def _get_shape_env() -> AlwaysHitShapeEnv: - return AlwaysHitShapeEnv() - - with ExitStack() as stack: - if not cache_data.disabled: - # compilation cache is enabled, patch several functions - - # hijack to get the compiled graph itself - stack.enter_context( - patch("torch._inductor.codecache.FxGraphCache.load", - hijack_load)) - - # for hijacking the hash of the compiled graph - stack.enter_context( - patch("torch._inductor.codecache.compiled_fx_graph_hash", - hijack_compiled_fx_graph_hash)) - - # for providing a dummy shape environment - stack.enter_context( - patch( - "torch._inductor.codecache.FxGraphCache._get_shape_env", - _get_shape_env)) - - # for forcing the graph to be cached - stack.enter_context( - patch( - "torch._inductor.codecache.FxGraphCache._check_can_cache", - _check_can_cache)) - - compiled_graph = compile_fx(graph, - example_inputs, - config_patches=current_config) - # store the inductor_artifact in the cache - cache_data[(runtime_shape, graph_index)] = inductor_artifact + "Directly load the %s-th graph for shape %s from %s via " + "handle %s", graph_index, str(runtime_shape), self.compiler.name, + handle) + return compiled_graph + + def compile(self, + graph: fx.GraphModule, + example_inputs, + additional_inductor_config, + compilation_config: CompilationConfig, + graph_index: int = 0, + num_graphs: int = 1, + runtime_shape: Optional[int] = None) -> Any: if graph_index == 0: - # adds some info logging for the first graph - logger.info("Cache the graph of shape %s for later use", - str(runtime_shape)) - logger.debug( - "store the %s-th graph for shape %s via hash %s from file %s", - graph_index, str(runtime_shape), inductor_artifact.hash_str, - inductor_artifact.file_path) - # after compiling the last graph, record the end time - if graph_index == num_graphs - 1: - now = time.time() - elapsed = now - 
compilation_start_time - compilation_config.compilation_time += elapsed - if runtime_shape is None: - logger.info("Compiling a graph for general shape takes %.2f s", - elapsed) - else: - logger.info("Compiling a graph for shape %s takes %.2f s", - runtime_shape, elapsed) + # before compiling the first graph, record the start time + global compilation_start_time + compilation_start_time = time.time() + + compilation_counter.num_backend_compilations += 1 + + compiled_graph = None + + # try to load from the cache + compiled_graph = self.load(graph, example_inputs, graph_index, + runtime_shape) + if compiled_graph is not None: + if graph_index == 0: + # adds some info logging for the first graph + logger.info("Directly load the compiled graph for shape %s " + "from the cache", str(runtime_shape)) # noqa + return compiled_graph + + # no compiler cached the graph, or the cache is disabled, + # we need to compile it + compiled_graph, handle = self.compiler.compile( + graph, example_inputs, additional_inductor_config, runtime_shape) + + assert compiled_graph is not None, "Failed to compile the graph" + + # store the artifact in the cache + if handle is not None: + self.cache[(runtime_shape, graph_index, + self.compiler.name)] = handle + if graph_index == 0: + # adds some info logging for the first graph + logger.info("Cache the graph of shape %s for later use", + str(runtime_shape)) + logger.debug( + "store the %s-th graph for shape %s from %s via handle %s", + graph_index, str(runtime_shape), self.compiler.name, handle) + + # after compiling the last graph, record the end time + if graph_index == num_graphs - 1: + now = time.time() + elapsed = now - compilation_start_time + compilation_config.compilation_time += elapsed + if runtime_shape is None: + logger.info("Compiling a graph for general shape takes %.2f s", + elapsed) + else: + logger.info("Compiling a graph for shape %s takes %.2f s", + runtime_shape, elapsed) - return compiled_graph + return compiled_graph 
@dataclasses.dataclass @@ -436,16 +257,15 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): i for i, x in enumerate(args) if isinstance(x, torch.SymInt) ] global compilation_start_time - compiled_graph_for_general_shape = wrap_inductor( + compiled_graph_for_general_shape = self.vllm_backend.\ + compiler_manager.compile( submod, args, self.compilation_config.inductor_compile_config, self.compilation_config, - self.vllm_backend, graph_index=index, num_graphs=len(self.compile_submod_names), - runtime_shape=None, - use_inductor=self.compilation_config.use_inductor) + runtime_shape=None) self.module.__dict__[target] = PiecewiseBackend( submod, self.vllm_config, self.graph_pool, index, @@ -483,7 +303,7 @@ class VllmBackend: post_grad_passes: Sequence[Callable] sym_tensor_indices: List[int] input_buffers: List[torch.Tensor] - inductor_hash_cache: InductorHashCache + compiler_manager: CompilerManager def __init__( self, @@ -507,6 +327,9 @@ class VllmBackend: self.vllm_config = vllm_config self.compilation_config = vllm_config.compilation_config + self.compiler_manager: CompilerManager = CompilerManager( + self.compilation_config.use_inductor) + # `torch.compile` is JIT compiled, so we don't need to # do anything here @@ -533,9 +356,11 @@ class VllmBackend: # the cache dir will be the same so that we can reuse the compiled # graph. + factors = [] # 1. factors come from the vllm_config (it mainly summarizes how the # model is created) config_hash = vllm_config.compute_hash() + factors.append(config_hash) # 2. factors come from the code files that are traced by Dynamo ( # it mainly summarizes how the model is used in forward pass) @@ -553,10 +378,15 @@ class VllmBackend: import hashlib code_hash = hashlib.md5( "\n".join(hash_content).encode()).hexdigest() + factors.append(code_hash) + + # 3. 
compiler hash + compiler_hash = self.compiler_manager.compute_hash(vllm_config) + factors.append(compiler_hash) + + # combine all factors to generate the cache dir + hash_key = hashlib.md5(str(factors).encode()).hexdigest()[:10] - # combine the two hashes to generate the cache dir - hash_key = hashlib.md5( - f"{config_hash}_{code_hash}".encode()).hexdigest()[:10] cache_dir = os.path.join( envs.VLLM_CACHE_ROOT, "torch_compile_cache", @@ -570,15 +400,16 @@ class VllmBackend: cache_dir, f"rank_{vllm_config.parallel_config.rank}") self.compilation_config.local_cache_dir = local_cache_dir - disabled = envs.VLLM_DISABLE_COMPILE_CACHE - self.inductor_hash_cache: InductorHashCache = InductorHashCache( - local_cache_dir, disabled=disabled) - if disabled: + disable_cache = envs.VLLM_DISABLE_COMPILE_CACHE + + if disable_cache: logger.info("vLLM's torch.compile cache is disabled.") else: logger.info("Using cache directory: %s for vLLM's torch.compile", local_cache_dir) + self.compiler_manager.initialize_cache(local_cache_dir, disable_cache) + # when dynamo calls the backend, it means the bytecode # transform and analysis are done compilation_counter.num_graphs_seen += 1 @@ -759,7 +590,7 @@ class PiecewiseBackend: if self.is_last_graph and not self.to_be_compiled_sizes: # no specific sizes to compile # save the hash of the inductor graph for the next run - self.vllm_backend.inductor_hash_cache.save_to_file() + self.vllm_backend.compiler_manager.save_to_file() end_monitoring_torch_compile(self.vllm_config) def __call__(self, *args) -> Any: @@ -782,16 +613,14 @@ class PiecewiseBackend: entry.compiled = True self.to_be_compiled_sizes.remove(runtime_shape) # args are real arguments - entry.runnable = wrap_inductor( + entry.runnable = self.vllm_backend.compiler_manager.compile( self.graph, args, self.compilation_config.inductor_compile_config, self.compilation_config, - self.vllm_backend, graph_index=self.piecewise_compile_index, num_graphs=self.total_piecewise_compiles, - 
runtime_shape=runtime_shape, - use_inductor=self.compilation_config.use_inductor) + runtime_shape=runtime_shape) # finished compilations for all required shapes if self.is_last_graph and not self.to_be_compiled_sizes: diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py new file mode 100644 index 000000000..ac0544ad6 --- /dev/null +++ b/vllm/compilation/compiler_interface.py @@ -0,0 +1,340 @@ +# SPDX-License-Identifier: Apache-2.0 +import copy +import hashlib +import os +from contextlib import ExitStack +from typing import Any, Callable, Dict, List, Optional, Tuple +from unittest.mock import patch + +import torch +import torch._inductor.compile_fx +import torch.fx as fx + +from vllm.config import VllmConfig + + +class CompilerInterface: + """ + The interface for a compiler that can be used by vLLM. + """ + # The name of the compiler, e.g. inductor. + # This is a class-level attribute. + name: str + + def initialize_cache(self, cache_dir: str, disable_cache: bool = False): + """ + when the vLLM process uses `cache_dir` as the cache directory, + the compiler should initialize itself with the cache directory, + e.g. by re-directing its own cache directory to a sub-directory. + """ + pass + + def compute_hash(self, vllm_config: VllmConfig) -> str: + """ + Gather all the relevant information from the VLLM config, + to compute a hash so that we can cache the compiled model. + + See :meth:`VllmConfig.compute_hash` to check what information + is already considered by default. This function should only + consider the information that is specific to the compiler. + """ + return "" + + def compile( + self, + graph: fx.GraphModule, + example_inputs: List[Any], + compiler_config: Dict[str, Any], + runtime_shape: Optional[int] = None + ) -> Tuple[Optional[Callable], Optional[Any]]: + """ + Compile the graph with the given example inputs and compiler config, + with a runtime shape. 
If the `runtime_shape` is None, it means + the `example_inputs` have a dynamic shape. Otherwise, the + `runtime_shape` specifies the shape of the inputs. Right now we only + support one variable shape for all inputs, which is the batchsize + (number of tokens) during inference. + + Dynamo will make sure `graph(*example_inputs)` is valid. + + The function should return a compiled callable function, as well as + a handle that can be used to directly load the compiled function. + + The handle should be a plain Python object, preferably a string or a + file path for readability. + + If the compiler doesn't support caching, it should return None for the + handle. If the compiler fails to compile the graph, it should return + None for the compiled function as well. + """ + return None, None + + def load(self, + handle: Any, + graph: fx.GraphModule, + example_inputs: List[Any], + graph_index: int, + runtime_shape: Optional[int] = None) -> Callable: + """ + Load the compiled function from the handle. + Raises an error if the handle is invalid. + + The handle is the second return value of the `compile` function. + """ + raise NotImplementedError("caching is not supported") + + +class AlwaysHitShapeEnv: + """ + Why do we need this class: + + For normal `torch.compile` usage, every compilation will have + one Dynamo bytecode compilation and one Inductor compilation. + The Inductor compilation happens under the context of the + Dynamo bytecode compilation, and that context is used to + determine the dynamic shape information, etc. + + For our use case, we only run Dynamo bytecode compilation once, + and run Inductor compilation multiple times with different shapes + plus a general shape. The compilation for specific shapes happens + outside of the context of the Dynamo bytecode compilation. At that + time, we don't have shape environment to provide to Inductor, and + it will fail the Inductor code cache lookup. 
+ + By providing a dummy shape environment that always hits, we can + make the Inductor code cache lookup always hit, and we can + compile the graph for different shapes as needed. + + The following dummy methods are obtained by trial-and-error + until it works. + """ + + def __init__(self) -> None: + self.guards: List[Any] = [] + + def evaluate_guards_expression(self, *args, **kwargs): + return True + + def get_pruned_guards(self, *args, **kwargs): + return [] + + def produce_guards_expression(self, *args, **kwargs): + return "" + + +class InductorAdaptor(CompilerInterface): + """ + The adaptor for the Inductor compiler, version 2.5 and 2.6. + """ + name = "inductor" + + def compute_hash(self, vllm_config: VllmConfig) -> str: + factors: List[Any] = [] + # summarize system state + from torch._inductor.codecache import CacheBase + system_factors = CacheBase.get_system() + factors.append(system_factors) + + # summarize pytorch state + from torch._inductor.codecache import torch_key + torch_factors = torch_key() + factors.append(torch_factors) + hash_str = hashlib.md5(str(factors).encode()).hexdigest()[:10] + return hash_str + + def initialize_cache(self, cache_dir: str, disable_cache: bool = False): + if disable_cache: + return + # redirect the cache directory to a sub-directory + # set flags so that Inductor and Triton store their cache + # in the cache_dir, then users only need to copy the cache_dir + # to another machine to reuse the cache. 
+ inductor_cache = os.path.join(cache_dir, "inductor_cache") + os.makedirs(inductor_cache, exist_ok=True) + os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache + triton_cache = os.path.join(cache_dir, "triton_cache") + os.makedirs(triton_cache, exist_ok=True) + os.environ["TRITON_CACHE_DIR"] = triton_cache + + def compile( + self, + graph: fx.GraphModule, + example_inputs: List[Any], + compiler_config: Dict[str, Any], + runtime_shape: Optional[int] = None + ) -> Tuple[Optional[Callable], Optional[Any]]: + from torch._inductor import config + current_config = config.get_config_copy() + from torch._inductor.compile_fx import compile_fx + + # disable remote cache + current_config["fx_graph_cache"] = True + current_config["fx_graph_remote_cache"] = False + + if compiler_config is not None: + current_config.update(compiler_config) + + if isinstance(runtime_shape, int): + # for a specific batchsize, tuning triton kernel parameters + # can be beneficial + current_config["max_autotune"] = True + current_config["coordinate_descent_tuning"] = True + + # inductor can inplace modify the graph, so we need to copy it + # see https://github.com/pytorch/pytorch/issues/138980 + graph = copy.deepcopy(graph) + + # it's the first time we compile this graph + # the assumption is that we don't have nested Inductor compilation. + # compiled_fx_graph_hash will only be called once, and we can hook + # it to get the hash of the compiled graph directly. 
+ + hash_str, file_path = None, None + from torch._inductor.codecache import (FxGraphCache, + compiled_fx_graph_hash) + + if torch.__version__.startswith("2.5"): + original_load = FxGraphCache.load + original_load_name = "torch._inductor.codecache.FxGraphCache.load" + + def hijack_load(*args, **kwargs): + inductor_compiled_graph = original_load(*args, **kwargs) + nonlocal file_path + file_path = inductor_compiled_graph.current_callable.__code__.co_filename # noqa + return inductor_compiled_graph + + hijacked_compile_fx_inner = torch._inductor.compile_fx.compile_fx_inner # noqa + elif torch.__version__ >= "2.6": + # function renamed in 2.6 + original_load_name = None + + def hijacked_compile_fx_inner(*args, **kwargs): + output = torch._inductor.compile_fx.compile_fx_inner( + *args, **kwargs) + nonlocal hash_str + inductor_compiled_graph = output + if inductor_compiled_graph is not None: + nonlocal file_path + file_path = inductor_compiled_graph.current_callable.__code__.co_filename # noqa + hash_str = inductor_compiled_graph._fx_graph_cache_key + return output + + def hijack_compiled_fx_graph_hash(*args, **kwargs): + out = compiled_fx_graph_hash(*args, **kwargs) + nonlocal hash_str + hash_str = out[0] + return out + + def _check_can_cache(*args, **kwargs): + # no error means it can be cached. + # Inductor refuses to cache the graph outside of Dynamo + # tracing context, and also disables caching for graphs + # with high-order ops. + # For vLLM, in either case, we want to cache the graph. 
+ # see https://github.com/pytorch/pytorch/blob/9f5ebf3fc609105a74eab4ccc24932d6353ff566/torch/_inductor/codecache.py#L1221 # noqa + return + + def _get_shape_env() -> AlwaysHitShapeEnv: + return AlwaysHitShapeEnv() + + with ExitStack() as stack: + # hijack to get the compiled graph itself + if original_load_name is not None: + stack.enter_context(patch(original_load_name, hijack_load)) + + # for hijacking the hash of the compiled graph + stack.enter_context( + patch("torch._inductor.codecache.compiled_fx_graph_hash", + hijack_compiled_fx_graph_hash)) + + # for providing a dummy shape environment + stack.enter_context( + patch("torch._inductor.codecache.FxGraphCache._get_shape_env", + _get_shape_env)) + + # for forcing the graph to be cached + stack.enter_context( + patch( + "torch._inductor.codecache.FxGraphCache._check_can_cache", + _check_can_cache)) + + compiled_graph = compile_fx( + graph, + example_inputs, + inner_compile=hijacked_compile_fx_inner, + config_patches=current_config) + + assert hash_str is not None, ( + "failed to get the hash of the compiled graph") + assert file_path is not None, ( + "failed to get the file path of the compiled graph") + return compiled_graph, (hash_str, file_path) + + def load(self, + handle: Any, + graph: fx.GraphModule, + example_inputs: List[Any], + graph_index: int, + runtime_shape: Optional[int] = None) -> Callable: + assert isinstance(handle, tuple) + assert isinstance(handle[0], str) + assert isinstance(handle[1], str) + hash_str = handle[0] + + from torch._inductor.codecache import FxGraphCache + with patch("torch._inductor.codecache.FxGraphCache._get_shape_env", + lambda *args, **kwargs: AlwaysHitShapeEnv()): + if torch.__version__.startswith("2.5"): + inductor_compiled_graph = FxGraphCache._lookup_graph( + hash_str, example_inputs, True, False) + assert inductor_compiled_graph is not None, ( + "Inductor cache lookup failed. Please remove" + f"the cache directory and try again." 
# noqa + ) + elif torch.__version__ >= "2.6": + from torch._inductor.output_code import ( + CompiledFxGraphConstantsWithGm) + constants = CompiledFxGraphConstantsWithGm(graph) + inductor_compiled_graph, _ = FxGraphCache._lookup_graph( + hash_str, example_inputs, True, None, constants) + assert inductor_compiled_graph is not None, ( + "Inductor cache lookup failed. Please remove" + f"the cache directory and try again." # noqa + ) + + # Inductor calling convention (function signature): + # f(list) -> tuple + # Dynamo calling convention (function signature): + # f(*args) -> Any + + # need to know if the graph returns a tuple + from torch._inductor.compile_fx import graph_returns_tuple + returns_tuple = graph_returns_tuple(graph) + + # this is the callable we return to Dynamo to run + def compiled_graph(*args): + # convert args to list + list_args = list(args) + graph_output = inductor_compiled_graph(list_args) + # unpack the tuple if needed + if returns_tuple: + return graph_output + else: + return graph_output[0] + + return compiled_graph + + +class EagerAdaptor(CompilerInterface): + name = "eager" + + def compile( + self, + graph: fx.GraphModule, + example_inputs: List[Any], + compiler_config: Dict[str, Any], + runtime_shape: Optional[int] = None + ) -> Tuple[Optional[Callable], Optional[Any]]: + # we don't need to compile the graph, just return the graph itself. + # It does not support caching, return None for the handle. 
+ return graph, None diff --git a/vllm/compilation/counter.py b/vllm/compilation/counter.py index a6f11a3af..5be452593 100644 --- a/vllm/compilation/counter.py +++ b/vllm/compilation/counter.py @@ -13,7 +13,7 @@ class CompilationCounter: num_piecewise_graphs_seen: int = 0 # not including the splitting ops num_piecewise_capturable_graphs_seen: int = 0 - num_inductor_compilations: int = 0 + num_backend_compilations: int = 0 num_cudagraph_caputured: int = 0 def clone(self) -> "CompilationCounter": diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py index be663946f..1fea927aa 100644 --- a/vllm/compilation/inductor_pass.py +++ b/vllm/compilation/inductor_pass.py @@ -13,7 +13,6 @@ from torch import fx class InductorPass(ABC): """ General custom inductor pass interface. - TODO(torch==2.6) use torch._inductor.custom_graph_pass.CustomGraphPass """ @abstractmethod diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index c7387fb7c..52f8c3b1e 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -2,6 +2,7 @@ from typing import Any, Dict, List +import torch from torch import fx as fx from vllm.config import CompilationConfig @@ -15,7 +16,17 @@ from .reshapes import RedundantReshapesPass logger = init_logger(__name__) -class PostGradPassManager: +class PlaceHolder: + pass + + +if torch.__version__ < "2.6": + Parent = PlaceHolder # type: ignore +else: + Parent = torch._inductor.custom_graph_pass.CustomGraphPass # type: ignore + + +class PostGradPassManager(Parent): """ The pass manager for post-grad passes. It handles configuration, adding custom passes, and running passes. @@ -55,6 +66,9 @@ class PostGradPassManager: assert isinstance(pass_, InductorPass) self.passes.append(pass_) + def uuid(self): + return self.__getstate__() + def __getstate__(self) -> Dict[str, List[Any]]: """ Custom pickling for the pass manager, as some passes cannot be pickled. 
diff --git a/vllm/config.py b/vllm/config.py index 9ba497576..5579d6936 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3072,15 +3072,6 @@ class VllmConfig: the final hidden states. """ factors: List[Any] = [] - # summarize system state - from torch._inductor.codecache import CacheBase - system_factors = CacheBase.get_system() - factors.append(system_factors) - - # summarize pytorch state - from torch._inductor.codecache import torch_key - torch_factors = torch_key() - factors.append(torch_factors) # summarize vllm config vllm_factors: List[Any] = [] -- GitLab From afe74f7a969132ec17589e51edb645b894439c0a Mon Sep 17 00:00:00 2001 From: Jitse Klomp Date: Thu, 6 Feb 2025 18:17:55 +0100 Subject: [PATCH 008/253] [Doc] double quote cmake package in build.inc.md (#12840) --- docs/source/getting_started/installation/cpu/build.inc.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/getting_started/installation/cpu/build.inc.md b/docs/source/getting_started/installation/cpu/build.inc.md index f8d1044a0..2a8173803 100644 --- a/docs/source/getting_started/installation/cpu/build.inc.md +++ b/docs/source/getting_started/installation/cpu/build.inc.md @@ -10,7 +10,7 @@ Second, install Python packages for vLLM CPU backend building: ```console pip install --upgrade pip -pip install cmake>=3.26 wheel packaging ninja "setuptools-scm>=8" numpy +pip install "cmake>=3.26" wheel packaging ninja "setuptools-scm>=8" numpy pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu ``` -- GitLab From 8108ac841d66515b58252edc26ba63da6cf980e5 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Fri, 7 Feb 2025 01:18:22 +0800 Subject: [PATCH 009/253] [Bugfix] Fix unsupported FA version check for Turing GPU (#12828) --- vllm/attention/backends/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 3c5028a66..e8a344341 100644 --- 
a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -612,5 +612,5 @@ try: return fa_version VLLM_FLASH_ATTN_VERSION = flash_attn_version() -except ImportError: +except (ImportError, AssertionError): VLLM_FLASH_ATTN_VERSION = None -- GitLab From 467a96a5415dc896170cecc0bb83d9c49c2f3c5e Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Thu, 6 Feb 2025 23:02:51 +0530 Subject: [PATCH 010/253] [V1] LoRA Support (#10957) Signed-off-by: Varun Sundar Rabindranath Co-authored-by: Varun Sundar Rabindranath --- tests/lora/conftest.py | 17 +++ tests/lora/test_baichuan.py | 8 ++ tests/lora/test_chatglm3_tp.py | 13 ++ tests/lora/test_gemma.py | 8 ++ tests/lora/test_llama_tp.py | 12 ++ tests/lora/test_lora_bias_e2e.py | 11 ++ tests/lora/test_phi.py | 13 ++ tests/lora/test_quant_model.py | 8 ++ tests/v1/core/test_kv_cache_utils.py | 2 +- vllm/lora/layers.py | 8 +- .../model_executor/layers/logits_processor.py | 28 ++-- vllm/v1/core/kv_cache_utils.py | 101 ++++++++++---- vllm/v1/core/scheduler.py | 32 ++++- vllm/v1/worker/gpu_input_batch.py | 63 ++++++++- vllm/v1/worker/gpu_model_runner.py | 56 ++++++-- vllm/v1/worker/lora_model_runner_mixin.py | 129 ++++++++++++++++++ 16 files changed, 453 insertions(+), 56 deletions(-) create mode 100644 vllm/v1/worker/lora_model_runner_mixin.py diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 071cdbecc..5ea66518b 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -306,3 +306,20 @@ def llama_2_7b_engine_extra_embeddings(): def llama_2_7b_model_extra_embeddings(llama_2_7b_engine_extra_embeddings): yield (llama_2_7b_engine_extra_embeddings.model_executor.driver_worker. 
model_runner.model) + + +@pytest.fixture(params=[True, False]) +def run_with_both_engines_lora(request, monkeypatch): + # Automatically runs tests twice, once with V1 and once without + use_v1 = request.param + # Tests decorated with `@skip_v1` are only run without v1 + skip_v1 = request.node.get_closest_marker("skip_v1") + + if use_v1: + if skip_v1: + pytest.skip("Skipping test on vllm V1") + monkeypatch.setenv('VLLM_USE_V1', '1') + else: + monkeypatch.setenv('VLLM_USE_V1', '0') + + yield diff --git a/tests/lora/test_baichuan.py b/tests/lora/test_baichuan.py index 249f7619d..d39925948 100644 --- a/tests/lora/test_baichuan.py +++ b/tests/lora/test_baichuan.py @@ -42,6 +42,14 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: return generated_texts +@pytest.fixture(autouse=True) +def v1(run_with_both_engines_lora): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + def test_baichuan_lora(baichuan_lora_files): llm = vllm.LLM(MODEL_PATH, max_model_len=1024, diff --git a/tests/lora/test_chatglm3_tp.py b/tests/lora/test_chatglm3_tp.py index 0aa9fe7a9..ee09afe86 100644 --- a/tests/lora/test_chatglm3_tp.py +++ b/tests/lora/test_chatglm3_tp.py @@ -2,6 +2,8 @@ from typing import List +import pytest + import vllm from tests.utils import fork_new_process_for_each_test from vllm.lora.request import LoRARequest @@ -47,6 +49,15 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: return generated_texts +@pytest.fixture(autouse=True) +def v1(run_with_both_engines_lora): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + +@pytest.mark.skip_v1 @fork_new_process_for_each_test def test_chatglm3_lora(chatglm3_lora_files): llm = vllm.LLM(MODEL_PATH, @@ -66,6 +77,7 @@ def test_chatglm3_lora(chatglm3_lora_files): assert output2[i] == 
EXPECTED_LORA_OUTPUT[i] +@pytest.mark.skip_v1 @multi_gpu_test(num_gpus=4) @fork_new_process_for_each_test def test_chatglm3_lora_tp4(chatglm3_lora_files): @@ -87,6 +99,7 @@ def test_chatglm3_lora_tp4(chatglm3_lora_files): assert output2[i] == EXPECTED_LORA_OUTPUT[i] +@pytest.mark.skip_v1 @multi_gpu_test(num_gpus=4) @fork_new_process_for_each_test def test_chatglm3_lora_tp4_fully_sharded_loras(chatglm3_lora_files): diff --git a/tests/lora/test_gemma.py b/tests/lora/test_gemma.py index 8923aa221..a1b4c897c 100644 --- a/tests/lora/test_gemma.py +++ b/tests/lora/test_gemma.py @@ -33,6 +33,14 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: return generated_texts +@pytest.fixture(autouse=True) +def v1(run_with_both_engines_lora): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + @pytest.mark.xfail(current_platform.is_rocm(), reason="There can be output mismatch on ROCm") def test_gemma_lora(gemma_lora_files): diff --git a/tests/lora/test_llama_tp.py b/tests/lora/test_llama_tp.py index 39f779f40..564818f23 100644 --- a/tests/lora/test_llama_tp.py +++ b/tests/lora/test_llama_tp.py @@ -2,6 +2,7 @@ from typing import List +import pytest import ray import vllm @@ -73,6 +74,14 @@ def generate_and_test(llm, sql_lora_files): print("removing lora") +@pytest.fixture(autouse=True) +def v1(run_with_both_engines_lora): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + @fork_new_process_for_each_test def test_llama_lora(sql_lora_files): @@ -85,6 +94,9 @@ def test_llama_lora(sql_lora_files): generate_and_test(llm, sql_lora_files) +# Skipping for v1 as v1 doesn't have a good way to expose the num_gpu_blocks +# used by the engine yet. 
+@pytest.mark.skip_v1 @fork_new_process_for_each_test def test_llama_lora_warmup(sql_lora_files): """Test that the LLM initialization works with a warmup LORA path and diff --git a/tests/lora/test_lora_bias_e2e.py b/tests/lora/test_lora_bias_e2e.py index cbdd68831..3a7b39169 100644 --- a/tests/lora/test_lora_bias_e2e.py +++ b/tests/lora/test_lora_bias_e2e.py @@ -30,6 +30,17 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: return generated_texts +@pytest.fixture(autouse=True) +def v1(run_with_both_engines_lora): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + +# Skipping for V1 for now as we are hitting, +# "Head size 80 is not supported by FlashAttention." error. +@pytest.mark.skip_v1 @pytest.mark.parametrize("lora_bias", [True]) @pytest.mark.parametrize("fully_sharded", [True, False]) def test_lora_bias(lora_bias_files: str, lora_bias: bool, fully_sharded: bool): diff --git a/tests/lora/test_phi.py b/tests/lora/test_phi.py index 651c89ffc..8999e0cf3 100644 --- a/tests/lora/test_phi.py +++ b/tests/lora/test_phi.py @@ -2,6 +2,8 @@ from typing import List +import pytest + import vllm from vllm.lora.request import LoRARequest @@ -48,6 +50,17 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: return generated_texts +@pytest.fixture(autouse=True) +def v1(run_with_both_engines_lora): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + +# Skipping for V1 for now as we are hitting, +# "Head size 80 is not supported by FlashAttention." error. +@pytest.mark.skip_v1 def test_phi2_lora(phi2_lora_files): # We enable enforce_eager=True here to reduce VRAM usage for lora-test CI, # Otherwise, the lora-test will fail due to CUDA OOM. 
diff --git a/tests/lora/test_quant_model.py b/tests/lora/test_quant_model.py index 5702aa26b..7f687f563 100644 --- a/tests/lora/test_quant_model.py +++ b/tests/lora/test_quant_model.py @@ -70,6 +70,14 @@ def do_sample(llm: vllm.LLM, return generated_texts +@pytest.fixture(autouse=True) +def v1(run_with_both_engines_lora): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("tp_size", [1]) def test_quant_model_lora(tinyllama_lora_files, num_gpus_available, model, diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index 60cf4384d..8df4cbe1b 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -163,7 +163,7 @@ def test_generate_block_hash_extra_keys(): # Test with no overlap extra_keys, next_mm_idx = generate_block_hash_extra_keys(request, 6, 10, 0) - assert extra_keys == () + assert extra_keys is None assert next_mm_idx == 1 # Test with multiple extra keys diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 9f0297596..9826aeb9d 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -16,8 +16,7 @@ from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, split_tensor_along_last_dim, tensor_model_parallel_all_gather, - tensor_model_parallel_all_reduce, - tensor_model_parallel_gather) + tensor_model_parallel_all_reduce) from vllm.distributed.utils import divide # yapf: disable from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -1043,7 +1042,10 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA): logits = lm_head.linear_method.apply(lm_head, hidden_states) if embedding_bias is not None: logits += embedding_bias - logits = tensor_model_parallel_gather(logits) + + # Gather logits for TP + logits = self.base_layer._gather_logits(logits) + if logits is 
None: return None diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py index cdc67ca83..0565c6e8b 100644 --- a/vllm/model_executor/layers/logits_processor.py +++ b/vllm/model_executor/layers/logits_processor.py @@ -51,7 +51,6 @@ class LogitsProcessor(nn.Module): # Soft cap the logits. Used in Gemma 2. self.soft_cap = soft_cap # Whether to use gather or all-gather to gather the logits. - parallel_config = get_current_vllm_config().parallel_config self.use_all_gather = current_platform.is_tpu() \ or envs.VLLM_USE_V1 \ @@ -88,6 +87,20 @@ class LogitsProcessor(nn.Module): return logits + def _gather_logits(self, logits: torch.Tensor) -> torch.Tensor: + """gather/all-gather the logits tensor across model parallel group.""" + if self.use_all_gather: + # Gather is not supported for some devices such as TPUs. + # Use all-gather instead. + # NOTE(woosuk): Here, the outputs of every device should not be None + # because XLA requires strict SPMD among all devices. Every device + # should execute the same operations after gathering the logits. + logits = tensor_model_parallel_all_gather(logits) + else: + # None may be returned for rank > 0 + logits = tensor_model_parallel_gather(logits) + return logits + def _get_logits( self, hidden_states: torch.Tensor, @@ -99,16 +112,9 @@ class LogitsProcessor(nn.Module): hidden_states, bias=embedding_bias) - if self.use_all_gather: - # Gather is not supported for some devices such as TPUs. - # Use all-gather instead. - # NOTE(woosuk): Here, the outputs of every device should not be None - # because XLA requires strict SPMD among all devices. Every device - # should execute the same operations after gathering the logits. - logits = tensor_model_parallel_all_gather(logits) - else: - # None may be returned for rank > 0 - logits = tensor_model_parallel_gather(logits) + # Gather logits for TP + logits = self._gather_logits(logits) + # Remove paddings in vocab (if any). 
if logits is not None: logits = logits[..., :self.org_vocab_size] diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index e0976ba85..6888f1a3e 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -170,14 +170,28 @@ class FreeKVCacheBlockQueue: return ret -def generate_block_hash_extra_keys( - request: Request, start_token_idx: int, end_token_idx: int, - start_mm_idx: int) -> Tuple[Optional[Tuple[Any, ...]], int]: - """Generate extra keys for the block hash. The extra keys can come from - the multi-modal inputs and request specific metadata (e.g., LoRA ID). - For multi-modal inputs, the extra keys are (mm_hash, start_offset) that - indicate a mm input contained in the block and its starting offset in - the block tokens. +def need_extra_keys(request: Request) -> bool: + """Check whether the blocks allocated to this request need extra hash keys. + + Args: + request (Request): The request. + + Returns: + bool: Whether blocks allocated to this request need extra hash keys. + """ + + # Multimodal requests need to include the MM hash. + # LoRA requests need to include the LoRA ID. + return bool(request.mm_positions) or (request.lora_request is not None) + + +def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int, + end_token_idx: int, + start_mm_idx: int) -> Tuple[List[Any], int]: + """Generate extra keys related to MultiModal request for block hash + computation. For multi-modal inputs, the extra keys are + (mm_hash, start_offset) that indicate a mm input contained in the + block and its starting offset in the block tokens. Args: request: The request object. @@ -188,10 +202,11 @@ def generate_block_hash_extra_keys( Returns: A tuple of extra keys and the next multi-modal index. 
""" + extra_keys: List[Any] = [] mm_positions, mm_hashes = request.mm_positions, request.mm_hashes if not mm_positions: - return None, start_mm_idx + return extra_keys, start_mm_idx if mm_positions and len(mm_positions) != len(mm_hashes): raise ValueError( @@ -204,14 +219,13 @@ def generate_block_hash_extra_keys( # range. This usually happens in the late prefill phase and decoding phase. if mm_positions[-1]["offset"] + mm_positions[-1][ "length"] < start_token_idx: - return None, start_mm_idx + return extra_keys, start_mm_idx # Support start_mm_idx == -1 to indicate the last mm input. if start_mm_idx < 0: assert -start_mm_idx <= len(mm_positions) start_mm_idx = len(mm_positions) + start_mm_idx - extra_keys = [] curr_mm_idx = start_mm_idx while mm_positions and curr_mm_idx < len(mm_positions): assert mm_hashes[curr_mm_idx] is not None @@ -237,7 +251,50 @@ def generate_block_hash_extra_keys( else: # This block has not reached the current mm input. break - return tuple(extra_keys), curr_mm_idx + return extra_keys, curr_mm_idx + + +def _gen_lora_extra_hash_keys(request: Request) -> List[int]: + """Generate extra keys related to LoRA for block hash computation. + + Args: + request: The request object. + + Returns: + Return LoRA id of the request if it is a LoRA request. Return empty + list otherwise. + """ + if not request.lora_request: + return [] + return [request.lora_request.lora_int_id] + + +def generate_block_hash_extra_keys( + request: Request, start_token_idx: int, end_token_idx: int, + start_mm_idx: int) -> Tuple[Optional[Tuple[Any, ...]], int]: + """Generate extra keys for the block hash. The extra keys can come from + the multi-modal inputs and request specific metadata (e.g., LoRA ID). + + Args: + request: The request object. + start_token_idx: The start token index of the block. + end_token_idx: The end token index of the block. + start_mm_idx: The start multi-modal index of the block. + + Returns: + A tuple of extra keys and the next multi-modal index. 
+ """ + mm_extra_keys: List[Any] + mm_extra_keys, new_start_mm_idx = _gen_mm_extra_hash_keys( + request, start_token_idx, end_token_idx, start_mm_idx) + lora_extra_keys: List[int] = _gen_lora_extra_hash_keys(request) + + extra_keys: List[Any] = lora_extra_keys + mm_extra_keys + + if not extra_keys: + return None, new_start_mm_idx + + return tuple(extra_keys), new_start_mm_idx def hash_block_tokens( @@ -249,9 +306,6 @@ def hash_block_tokens( prefix caching. We use LRU cache for this function to avoid recomputing hash values for the same block contents. - TODO: Support arbitrary metadata so that we could support more - features such as LoRA adapter. - Args: parent_block_hash: The hash of the parent block. None if this is the first block. @@ -291,14 +345,9 @@ def hash_request_tokens(block_size: int, The list of computed hash values. """ token_ids = request.all_token_ids - mm_positions, mm_hashes = request.mm_positions, request.mm_hashes - if mm_positions and len(mm_positions) != len(mm_hashes): - raise ValueError( - "The number of multi-modal positions and hashes must match.") - # TODO: Extend this to support other features such as LoRA. - need_extra_keys = bool(mm_positions) - extra_keys = None + req_need_extra_keys = need_extra_keys(request) + req_extra_keys = None curr_mm_idx = 0 ret = [] @@ -310,13 +359,13 @@ def hash_request_tokens(block_size: int, if len(block_token_ids) < block_size: break - # Add extra keys if the block is a multi-modal block. - if need_extra_keys: - extra_keys, curr_mm_idx = generate_block_hash_extra_keys( + if req_need_extra_keys: + # MM and LoRA requests need extra keys for block-hash computation. 
+ req_extra_keys, curr_mm_idx = generate_block_hash_extra_keys( request, start, end, curr_mm_idx) block_hash = hash_block_tokens(parent_block_hash_value, - block_token_ids, extra_keys) + block_token_ids, req_extra_keys) ret.append(block_hash) parent_block_hash_value = block_hash.hash_value return ret diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index fb5e83fe0..6c44fec64 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -7,6 +7,7 @@ from typing import (TYPE_CHECKING, Deque, Dict, Iterable, List, Optional, Set, from vllm.config import CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig from vllm.logger import init_logger +from vllm.lora.request import LoRARequest from vllm.sampling_params import SamplingParams from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, compute_encoder_budget) @@ -35,8 +36,6 @@ class Scheduler: self.scheduler_config = scheduler_config self.cache_config = cache_config self.lora_config = lora_config - # TODO: Support LoRA. - assert lora_config is None, "V1 does not support LoRA yet." # Scheduling constraints. self.max_num_running_reqs = self.scheduler_config.max_num_seqs @@ -180,6 +179,14 @@ class Scheduler: self.encoder_cache_manager.allocate(request, i) encoder_budget = new_encoder_budget + # Record the LoRAs in scheduled_running_reqs + requested_loras: Set[int] = set() + if self.lora_config: + requested_loras = set( + req.lora_request.lora_int_id for req in scheduled_running_reqs + if req.lora_request and req.lora_request.lora_int_id > 0) + assert len(requested_loras) <= self.lora_config.max_loras + # Next, schedule the WAITING requests. if not preempted_reqs: while self.waiting and token_budget > 0: @@ -187,6 +194,23 @@ class Scheduler: break request = self.waiting[0] + + # Check that adding the request still respects the max_loras + # constraint. 
+ if self.lora_config and request.lora_request: + req_lora_id = request.lora_request.lora_int_id + if len(requested_loras) == self.lora_config.max_loras and ( + req_lora_id not in requested_loras): + # Cannot schedule. + # TODO (varun): This means all the other requests in + # the WAITING queue will be blocked by this request, + # even if, + # 1. these other requests do not use LoRA, or, + # 2. these other requests use the already requested + # LoRAs. + # This is too conservative and could be optimized. + break + # Get already-cached tokens. computed_blocks, num_computed_tokens = \ self.kv_cache_manager.get_computed_blocks(request) @@ -234,6 +258,8 @@ class Scheduler: raise RuntimeError( f"Invalid request status: {request.status}") + if self.lora_config and request.lora_request: + requested_loras.add(request.lora_request.lora_int_id) req_to_new_block_ids[request.request_id] = [ b.block_id for b in computed_blocks + new_blocks ] @@ -568,6 +594,7 @@ class NewRequestData: sampling_params: SamplingParams block_ids: List[int] num_computed_tokens: int + lora_request: Optional[LoRARequest] @classmethod def from_request( @@ -586,6 +613,7 @@ class NewRequestData: sampling_params=request.sampling_params, block_ids=block_ids, num_computed_tokens=num_computed_tokens, + lora_request=request.lora_request, ) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 39708f833..a31e88865 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -3,11 +3,12 @@ # Datastructures defining an input batch from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, List, Optional, Set +from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple import numpy as np import torch +from vllm.lora.request import LoRARequest from vllm.multimodal import MultiModalKwargs from vllm.sampling_params import SamplingParams, SamplingType from vllm.v1.sample.metadata import SamplingMetadata @@ -35,6 +36,8 @@ class 
CachedRequestState: mrope_positions: Optional[torch.Tensor] = None mrope_position_delta: Optional[int] = None + lora_request: Optional[LoRARequest] = None + @property def num_tokens(self) -> int: return len(self.prompt_token_ids) + len(self.output_token_ids) @@ -161,6 +164,12 @@ class InputBatch: ] self.prompt_token_ids: Optional[torch.Tensor] = None + # lora related + self.request_lora_mapping = np.zeros((self.max_num_reqs, ), + dtype=np.int32) + self.lora_id_to_request_ids: Dict[int, Set[str]] = {} + self.lora_id_to_lora_request: Dict[int, LoRARequest] = {} + # req_index -> generator # NOTE(woosuk): The indices of the requests that do not have their own # generator should not be included in the dictionary. @@ -235,6 +244,19 @@ class InputBatch: if sampling_params.prompt_logprobs: self.prompt_logprob_reqs.add(req_id) + # Add request lora ID + if request.lora_request: + lora_id = request.lora_request.lora_int_id + if lora_id not in self.lora_id_to_request_ids: + self.lora_id_to_request_ids[lora_id] = set() + + self.request_lora_mapping[req_index] = lora_id + self.lora_id_to_request_ids[lora_id].add(request.req_id) + self.lora_id_to_lora_request[lora_id] = request.lora_request + else: + # No LoRA + self.request_lora_mapping[req_index] = 0 + def remove_request(self, req_id: str) -> Optional[int]: req_index = self.req_id_to_index.pop(req_id, None) if req_index is None: @@ -251,6 +273,16 @@ class InputBatch: self.generators.pop(req_index, None) self.num_logprobs.pop(req_id, None) self.prompt_logprob_reqs.discard(req_id) + + # LoRA + lora_id = self.request_lora_mapping[req_index] + if lora_id != 0: + self.lora_id_to_request_ids[lora_id].discard(req_id) + if len(self.lora_id_to_request_ids[lora_id]) == 0: + self.lora_id_to_request_ids.pop(lora_id) + self.lora_id_to_lora_request.pop(lora_id) + self.request_lora_mapping[req_index] = 0 + return req_index def clear(self) -> None: @@ -266,6 +298,9 @@ class InputBatch: self.generators.clear() self.num_logprobs.clear() 
self.prompt_logprob_reqs.clear() + self.request_lora_mapping.fill(0) + self.lora_id_to_lora_request.clear() + self.lora_id_to_request_ids.clear() def condense(self, empty_req_indices: List[int]) -> None: if self.num_reqs == 0: @@ -318,6 +353,9 @@ class InputBatch: if generator is not None: self.generators[empty_index] = generator + self.request_lora_mapping[empty_index] = self.request_lora_mapping[ + last_req_index] + # Decrement last_req_index since it is now empty. last_req_index -= 1 @@ -401,6 +439,29 @@ class InputBatch: return prompt_token_ids_cpu_tensor.to(device=self.device, non_blocking=True) + def make_lora_inputs( + self, num_scheduled_tokens: np.ndarray + ) -> Tuple[Tuple[int, ...], Tuple[int, ...], Set[LoRARequest]]: + """ + Given the num_scheduled_tokens for each request in the batch, return + datastructures used to activate the current LoRAs. + Returns: + 1. prompt_lora_mapping: A tuple of size self.num_reqs where, + prompt_lora_mapping[i] is the LoRA id to use for the ith prompt. + 2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens) + where, token_lora_mapping[i] is the LoRA id to use for ith token. + 3. lora_requests: Set of relevant LoRA requests. 
+ """ + + req_lora_mapping = self.request_lora_mapping[:self.num_reqs] + prompt_lora_mapping = tuple(req_lora_mapping) + token_lora_mapping = tuple( + req_lora_mapping.repeat(num_scheduled_tokens)) + active_lora_requests: Set[LoRARequest] = set( + self.lora_id_to_lora_request.values()) + + return prompt_lora_mapping, token_lora_mapping, active_lora_requests + @property def num_reqs(self) -> int: return len(self.req_id_to_index) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index ec6d04cd4..bfc9d1ca8 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -33,6 +33,7 @@ from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.utils import bind_kv_cache from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch +from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin if TYPE_CHECKING: from vllm.v1.core.scheduler import SchedulerOutput @@ -40,7 +41,7 @@ if TYPE_CHECKING: logger = init_logger(__name__) -class GPUModelRunner: +class GPUModelRunner(LoRAModelRunnerMixin): def __init__( self, @@ -279,6 +280,7 @@ class GPUModelRunner: block_ids=new_req_data.block_ids, num_computed_tokens=new_req_data.num_computed_tokens, output_token_ids=[], + lora_request=new_req_data.lora_request, ) # Only relevant for models using M-RoPE (e.g, Qwen2-VL) @@ -372,15 +374,16 @@ class GPUModelRunner: # Get the number of scheduled tokens for each request. # TODO: The Python loop can be slow. Optimize. 
- num_scheduled_tokens = [] + num_scheduled_tokens_list: List[int] = [] max_num_scheduled_tokens = 0 for req_id in self.input_batch.req_ids[:num_reqs]: assert req_id is not None num_tokens = scheduler_output.num_scheduled_tokens[req_id] - num_scheduled_tokens.append(num_tokens) + num_scheduled_tokens_list.append(num_tokens) max_num_scheduled_tokens = max(max_num_scheduled_tokens, num_tokens) - num_scheduled_tokens = np.array(num_scheduled_tokens, dtype=np.int32) + num_scheduled_tokens: np.ndarray = np.array(num_scheduled_tokens_list, + dtype=np.int32) assert max_num_scheduled_tokens > 0 # Get request indices. @@ -565,6 +568,11 @@ class GPUModelRunner: prefix_kv_lens=prefix_kv_lens, suffix_kv_lens=suffix_kv_lens, ) + + # Hot-Swap lora model + if self.lora_config: + self.set_active_loras(self.input_batch, num_scheduled_tokens) + # NOTE(woosuk): Due to chunked prefills, the batch may contain partial # requests. While we should not sample any token from these partial # requests, we do so for simplicity. We will ignore the sampled @@ -867,6 +875,12 @@ class GPUModelRunner: logger.info("Starting to load model %s...", self.model_config.model) with DeviceMemoryProfiler() as m: # noqa: SIM117 self.model = get_model(vllm_config=self.vllm_config) + if self.lora_config: + self.model = self.load_lora_model(self.model, + self.model_config, + self.scheduler_config, + self.lora_config, + self.device) self.model_memory_usage = m.consumed_memory logger.info("Loading model weights took %.4f GB", @@ -1005,14 +1019,32 @@ class GPUModelRunner: # Cache the dummy encoder outputs. self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs)) - # Trigger compilation for general shape. - hidden_states = self._dummy_run(self.max_num_tokens, dummy_kv_caches) - logits = self.model.compute_logits(hidden_states, None) - logits = logits[:self.max_num_tokens] - # TODO(woosuk): Consider the memory usage of the sampler. 
- torch.cuda.synchronize() - del hidden_states, logits - self.encoder_cache.clear() + # For profile, have maximum num_reqs and that collectively have + # maximum num_tokens. + num_reqs = self.scheduler_config.max_num_seqs + num_tokens = self.max_num_tokens + min_tokens_per_req: int = num_tokens // num_reqs + + num_scheduled_tokens_list: List[int] = [min_tokens_per_req] * num_reqs + num_scheduled_tokens_list[-1] += num_tokens % num_reqs + assert sum(num_scheduled_tokens_list) == num_tokens + assert len(num_scheduled_tokens_list) == num_reqs + + num_scheduled_tokens: np.ndarray = np.array(num_scheduled_tokens_list, + dtype=np.int32) + logit_indices = np.cumsum(num_scheduled_tokens) - 1 + + with self.maybe_profile_with_lora(self.lora_config, + num_scheduled_tokens): + # Trigger compilation for general shape. + hidden_states = self._dummy_run(self.max_num_tokens, + dummy_kv_caches) + hidden_states = hidden_states[logit_indices] + logits = self.model.compute_logits(hidden_states, None) + # TODO(woosuk): Consider the memory usage of the sampler. + torch.cuda.synchronize() + del hidden_states, logits + self.encoder_cache.clear() gc.collect() def capture_model(self) -> None: diff --git a/vllm/v1/worker/lora_model_runner_mixin.py b/vllm/v1/worker/lora_model_runner_mixin.py new file mode 100644 index 000000000..e7501ad2e --- /dev/null +++ b/vllm/v1/worker/lora_model_runner_mixin.py @@ -0,0 +1,129 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Define LoRA functionality mixin for model runners. 
+""" + +from contextlib import contextmanager +from typing import Set, Tuple + +import numpy as np +import torch.nn as nn + +from vllm.config import LoRAConfig, ModelConfig, SchedulerConfig +from vllm.logger import init_logger +from vllm.lora.layers import LoRAMapping +from vllm.lora.request import LoRARequest +from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager +from vllm.model_executor.models import supports_lora, supports_multimodal +from vllm.v1.worker.gpu_input_batch import InputBatch + +logger = init_logger(__name__) + + +# Defined as a mixin for GPUModelRunner +class LoRAModelRunnerMixin: + + LORA_WARMUP_RANK = 8 + + def load_lora_model(self, model: nn.Module, model_config: ModelConfig, + scheduler_config: SchedulerConfig, + lora_config: LoRAConfig, device: str) -> nn.Module: + + assert supports_lora( + model), f"{model.__class__.__name__} does not support LoRA yet." + + if supports_multimodal(model): + logger.warning("Regarding multimodal models, vLLM currently " + "only supports adding LoRA to language model.") + + # It's necessary to distinguish between the max_position_embeddings + # of VLMs and LLMs. 
+ if hasattr(model.config, "max_position_embeddings"): + max_pos_embeddings = model.config.max_position_embeddings + else: + max_pos_embeddings = ( + model.config.text_config.max_position_embeddings) + + # Add LoRA Manager to the Model Runner + self.lora_manager = LRUCacheWorkerLoRAManager( + scheduler_config.max_num_seqs, + scheduler_config.max_num_batched_tokens, + model_config.get_vocab_size(), + lora_config, + device, + model.embedding_modules, + model.embedding_padding_modules, + max_position_embeddings=max_pos_embeddings, + ) + return self.lora_manager.create_lora_manager(model) + + def _set_active_loras(self, prompt_lora_mapping: Tuple[int, ...], + token_lora_mapping: Tuple[int, ...], + lora_requests: Set[LoRARequest]) -> None: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + + # We dont make any distinction between prefills and decodes in the + # scheduler. To that effect, set is_prefill to True so we use the + # sgmv punica kernels always. + lora_mapping = LoRAMapping(token_lora_mapping, + prompt_lora_mapping, + is_prefill=True) + self.lora_manager.set_active_adapters(lora_requests, lora_mapping) + + def set_active_loras(self, input_batch: InputBatch, + num_scheduled_tokens: np.ndarray) -> None: + + prompt_lora_mapping: Tuple[int, ...] # of size input_batch.num_reqs + token_lora_mapping: Tuple[int, + ...] 
# of size np.sum(num_scheduled_tokens) + lora_requests: Set[LoRARequest] + prompt_lora_mapping, token_lora_mapping, lora_requests = \ + input_batch.make_lora_inputs(num_scheduled_tokens) + return self._set_active_loras(prompt_lora_mapping, token_lora_mapping, + lora_requests) + + @contextmanager + def maybe_profile_with_lora(self, lora_config: LoRAConfig, + num_scheduled_tokens: np.ndarray): + if lora_config is None: + yield + else: + # __enter__ code + assert self.lora_manager is not None, "LoRA is not enabled" + + num_reqs = len(num_scheduled_tokens) + num_loras = lora_config.max_loras + + # Make prompt lora mapping + # Assign LoRA IDs cyclically to simulate a worst-case scenario. + prompt_lora_mapping = (np.arange(num_reqs, dtype=np.int32) % + num_loras) + 1 + + # Make token lora mapping + token_lora_mapping = np.repeat(prompt_lora_mapping, + num_scheduled_tokens) + + # Make dummy lora requests + lora_requests: Set[LoRARequest] = { + LoRARequest(lora_name=f"warmup_{lora_id}", + lora_int_id=lora_id, + lora_path="/not/a/real/path") + for lora_id in range(1, num_loras + 1) + } + + with self.lora_manager.dummy_lora_cache(): + # Add the dummy LoRAs here so _set_active_loras doesn't try to + # load from disk. 
+ for lr in lora_requests: + self.lora_manager.add_dummy_lora( + lr, rank=self.LORA_WARMUP_RANK) + + self._set_active_loras(tuple(prompt_lora_mapping), + tuple(token_lora_mapping), + lora_requests) + + yield + + # __exit__ code + self.lora_manager.remove_all_adapters() -- GitLab From aff404571b0d5aba342c46fdf5d7f8a251da9383 Mon Sep 17 00:00:00 2001 From: Yu Chin Fabian Lim Date: Fri, 7 Feb 2025 07:22:42 +0800 Subject: [PATCH 011/253] Add Bamba Model (#10909) Signed-off-by: Yu Chin Fabian Lim Signed-off-by: Tyler Michael Smith Co-authored-by: Tyler Michael Smith --- tests/kernels/test_mamba_mixer2.py | 125 +++ tests/kernels/test_mamba_ssm_ssd.py | 304 +++++++ .../{test_jamba.py => test_hybrid.py} | 35 +- tests/models/registry.py | 1 + vllm/attention/backends/placeholder_attn.py | 140 ++-- .../layers/mamba/mamba_mixer2.py | 534 +++++++++++++ .../layers/mamba/ops/mamba_ssm.py | 2 +- .../layers/mamba/ops/ssd_bmm.py | 261 ++++++ .../layers/mamba/ops/ssd_chunk_scan.py | 615 ++++++++++++++ .../layers/mamba/ops/ssd_chunk_state.py | 750 ++++++++++++++++++ .../layers/mamba/ops/ssd_combined.py | 223 ++++++ .../layers/mamba/ops/ssd_state_passing.py | 207 +++++ vllm/model_executor/models/bamba.py | 592 ++++++++++++++ vllm/model_executor/models/jamba.py | 11 +- vllm/model_executor/models/mamba.py | 10 +- vllm/model_executor/models/mamba_cache.py | 7 +- vllm/model_executor/models/registry.py | 1 + 17 files changed, 3706 insertions(+), 112 deletions(-) create mode 100644 tests/kernels/test_mamba_mixer2.py create mode 100644 tests/kernels/test_mamba_ssm_ssd.py rename tests/models/decoder_only/language/{test_jamba.py => test_hybrid.py} (91%) create mode 100644 vllm/model_executor/layers/mamba/mamba_mixer2.py create mode 100644 vllm/model_executor/layers/mamba/ops/ssd_bmm.py create mode 100644 vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py create mode 100644 vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py create mode 100644 
vllm/model_executor/layers/mamba/ops/ssd_combined.py create mode 100644 vllm/model_executor/layers/mamba/ops/ssd_state_passing.py create mode 100644 vllm/model_executor/models/bamba.py diff --git a/tests/kernels/test_mamba_mixer2.py b/tests/kernels/test_mamba_mixer2.py new file mode 100644 index 000000000..8c441fcbe --- /dev/null +++ b/tests/kernels/test_mamba_mixer2.py @@ -0,0 +1,125 @@ +# SPDX-License-Identifier: Apache-2.0 + +import unittest +from typing import Tuple + +import pytest +import torch + +from tests.utils import multi_gpu_test +from vllm.distributed.parallel_state import (init_distributed_environment, + initialize_model_parallel) +from vllm.model_executor.layers.mamba.mamba_mixer2 import Mixer2RMSNormGated +from vllm.platforms import current_platform +from vllm.utils import update_environment_variables + + +@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("seq_len", [128]) +@pytest.mark.parametrize( + "hidden_size_n_groups", + [ + (64, 1), + (64, 2), + (64, 4), # hidden_size be divisible by num_gpus + (100, 5), # and n_groups must divide hidden_size + ]) +@pytest.mark.parametrize("dtype", [torch.float16]) +def test_mixer2_gated_norm_multi_gpu( + batch_size: int, + seq_len: int, + hidden_size_n_groups: Tuple[int, int], + dtype: torch.dtype, + device: str = 'cuda', +): + hidden_size, n_groups = hidden_size_n_groups + num_processes = 2 + + def run_torch_spawn(fn, nprocs): + # need to use torch.mp.spawn otherwise will have problems with + # torch.distributed and cuda + torch.multiprocessing.spawn(fn, + args=( + num_processes, + batch_size, + seq_len, + hidden_size, + n_groups, + dtype, + device, + ), + nprocs=nprocs) + + run_torch_spawn(mixer2_gated_norm_tensor_parallel, 2) + + +def mixer2_gated_norm_tensor_parallel( + local_rank: int, + world_size: int, + batch_size: int, + seq_len: int, + hidden_size: int, + n_groups: int, + dtype: torch.dtype, + device: str, +): + current_platform.seed_everything(0) + 
+ device = torch.device(f"cuda:{local_rank}") + torch.cuda.set_device(device) + torch.set_default_device(device) + torch.set_default_dtype(dtype) + + update_environment_variables({ + 'RANK': str(local_rank), + 'LOCAL_RANK': str(local_rank), + 'WORLD_SIZE': str(world_size), + 'MASTER_ADDR': 'localhost', + 'MASTER_PORT': '12345', + }) + + # initialize distributed + init_distributed_environment() + initialize_model_parallel(tensor_model_parallel_size=world_size) + + # create random weights an inputs + weight = torch.rand((hidden_size, ), dtype=dtype, device=device) + hidden_states = torch.randn(batch_size, seq_len, hidden_size) + gate_states = torch.randn(batch_size, seq_len, hidden_size) + + # create gated-norm with TP + mixer = Mixer2RMSNormGated( + full_hidden_size=hidden_size, + full_n_groups=n_groups, + ) + mixer.weight.weight_loader(mixer.weight, weight) # load + + # create gated-norm without TP to compute reference + # - utilize mock patching to disable TP when + with (unittest.mock.patch( + "vllm.model_executor.layers.mamba.mamba_mixer2." + "get_tensor_model_parallel_world_size", + return_value=1), + unittest.mock.patch( + "vllm.model_executor.layers.mamba.mamba_mixer2." 
+ "get_tensor_model_parallel_rank", + return_value=0)): + mixer_single_gpu = Mixer2RMSNormGated( + full_hidden_size=hidden_size, + full_n_groups=n_groups, + ) + # assign weight to single-gpu mixer + mixer_single_gpu.weight.data = weight + + # generate and compare + N = hidden_size // world_size + output = mixer( + hidden_states[..., local_rank * N:(local_rank + 1) * N], + gate_states[..., local_rank * N:(local_rank + 1) * N], + ) + ref_output = mixer_single_gpu(hidden_states, gate_states) + torch.allclose(output, + ref_output[..., local_rank * N:(local_rank + 1) * N], + atol=1e-3, + rtol=1e-3) diff --git a/tests/kernels/test_mamba_ssm_ssd.py b/tests/kernels/test_mamba_ssm_ssd.py new file mode 100644 index 000000000..882513116 --- /dev/null +++ b/tests/kernels/test_mamba_ssm_ssd.py @@ -0,0 +1,304 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Dict, Tuple + +import pytest +import torch +import torch.nn.functional as F +from einops import rearrange, repeat + +from vllm.model_executor.layers.mamba.ops.ssd_combined import ( + mamba_chunk_scan_combined) +from vllm.platforms import current_platform + +# Added by the IBM Team, 2024 + +# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/modules/ssd_minimal.py + + +# this is the segsum implementation taken from above +def segsum(x): + """Calculates segment sum.""" + T = x.size(-1) + x = repeat(x, "... d -> ... 
d e", e=T) + mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), + diagonal=-1) + x = x.masked_fill(~mask, 0) + x_segsum = torch.cumsum(x, dim=-2) + mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), + diagonal=0) + x_segsum = x_segsum.masked_fill(~mask, -torch.inf) + return x_segsum + + +def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None): + """ + Arguments: + X: (batch, length, n_heads, d_head) + A: (batch, length, n_heads) + B: (batch, length, n_heads, d_state) + C: (batch, length, n_heads, d_state) + Return: + Y: (batch, length, n_heads, d_head) + """ + assert X.dtype == A.dtype == B.dtype == C.dtype + assert X.shape[1] % block_len == 0 + + # Rearrange into blocks/chunks + X, A, B, C = (rearrange(x, "b (c l) ... -> b c l ...", l=block_len) + for x in (X, A, B, C)) + + A = rearrange(A, "b c l h -> b h c l") + A_cumsum = torch.cumsum(A, dim=-1) + + # 1. Compute the output for each intra-chunk (diagonal blocks) + L = torch.exp(segsum(A)) + Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, X) + + # 2. Compute the state for each intra-chunk + # (right term of low-rank factorization of off-diagonal blocks; B terms) + decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum) + states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, X) + + # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at + # chunk boundaries + # (middle term of factorization of off-diag blocks; A terms) + if initial_states is None: + initial_states = torch.zeros_like(states[:, :1]) + states = torch.cat([initial_states, states], dim=1) + decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0)))) + new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states) + states, final_state = new_states[:, :-1], new_states[:, -1] + + # 4. 
Compute state -> output conversion per chunk + # (left term of low-rank factorization of off-diagonal blocks; C terms) + state_decay_out = torch.exp(A_cumsum) + Y_off = torch.einsum('bclhn,bchpn,bhcl->bclhp', C, states, state_decay_out) + + # Add output of intra-chunk and inter-chunk terms + # (diagonal and off-diagonal blocks) + Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p") + return Y, final_state + + +def generate_random_inputs(batch_size, + seqlen, + n_heads, + d_head, + itype, + device='cuda'): + + current_platform.seed_everything(0) + A = (-torch.exp(torch.rand(n_heads, dtype=itype, device=device))) + dt = F.softplus( + torch.randn(batch_size, seqlen, n_heads, dtype=itype, device=device) - + 4) + X = torch.randn((batch_size, seqlen, n_heads, d_head), + dtype=itype, + device=device) + B = torch.randn((batch_size, seqlen, n_heads, d_head), + dtype=itype, + device=device) + C = torch.randn((batch_size, seqlen, n_heads, d_head), + dtype=itype, + device=device) + + return A, dt, X, B, C + + +def generate_continous_batched_examples(example_lens_by_batch, + num_examples, + full_length, + last_taken, + exhausted, + n_heads, + d_head, + itype, + device='cuda'): + + # this function generates a random examples of certain length + # and then cut according to "example_lens_by_batch" and feed + # them in continuous batches to the kernels + + # generate the full-length example + A, dt, X, B, C = generate_random_inputs(num_examples, full_length, n_heads, + d_head, itype) + + Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1), + A * dt, + B, + C, + block_len=full_length // 4) + + # internal function that outputs a cont batch of examples + # given a tuple of lengths for each example in the batch + # e.g., example_lens=(8, 4) means take 8 samples from first eg, + # 4 examples from second eg, etc + def get_continuous_batch(example_lens: Tuple[int, ...]): + + indices = [] + for i, x in enumerate(example_lens): + c = last_taken.get(i, 0) + 
indices.append((c, c + x)) + last_taken[i] = (c + x) % full_length + exhausted[i] = last_taken[i] == 0 + + return (torch.concat([x[i, s:e] for i, (s, e) in enumerate(indices) + ]).unsqueeze(0) for x in (dt, X, B, C)) + + # internal function that maps "n" to the appropriate right boundary + # value when forming continuous batches from examples of length given + # by "full_length". + # - e.g., when n > full_length, returns n % full_length + # when n == full_length, returns full_length + def end_boundary(n: int): + return n - ((n - 1) // full_length) * full_length + + IND_E = None + for spec in example_lens_by_batch: + + # get the (maybe partial) example seen in this cont batch + dt2, X2, B2, C2 = get_continuous_batch(spec) + + # get the metadata + cu_seqlens = torch.tensor((0, ) + spec, device=device).cumsum(dim=0) + sed_idx = torch.zeros(cu_seqlens[-1], + dtype=torch.int32, + device=cu_seqlens.device) + for i, (srt, end) in enumerate(zip( + cu_seqlens, + cu_seqlens[1:], + )): + sed_idx[srt:end] = i + + # for cont batch + if IND_E is None: + IND_S = [0 for _ in range(len(spec))] + else: + IND_S = [x % full_length for x in IND_E] + IND_E = [end_boundary(x + y) for x, y in zip(IND_S, spec)] + + yield ([Y_min[s, IND_S[s]:IND_E[s]] for s in range(num_examples)], + cu_seqlens, sed_idx.unsqueeze(0), (A, dt2, X2, B2, C2)) + + +@pytest.mark.parametrize("itype", + [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("n_heads", [3, 4, 11, 16, 32]) +@pytest.mark.parametrize("d_head", [5, 8, 19, 32, 128]) +@pytest.mark.parametrize("seq_len_chunk_size", [(119, 17), (128, 32)]) +def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, + itype): + + # this tests the kernels on a single example (no batching) + + # set seed + batch_size = 1 # batch_size + # ssd_minimal_discrete requires chunk_size divide seqlen + # - this is only required for generating the reference seqs, + # it is not an operational limitation. 
+ seqlen, chunk_size = seq_len_chunk_size + + A, dt, X, B, C = generate_random_inputs(batch_size, seqlen, n_heads, + d_head, itype) + + Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1), A * dt, + B, C, chunk_size) + + Y, final_state = mamba_chunk_scan_combined(X, + dt, + A, + B, + C, + chunk_size, + D=None, + return_final_states=True) + + # just test the last in sequence + torch.allclose(Y[:, -1], Y_min[:, -1], atol=1e-3, rtol=1e-3) + + # just test the last head + # NOTE, in the kernel we always cast states to fp32 + torch.allclose(final_state[:, -1], + final_state_min[:, -1].to(torch.float32), + atol=1e-3, + rtol=1e-3) + + +@pytest.mark.parametrize("itype", [torch.float32, torch.float16]) +@pytest.mark.parametrize("n_heads", [4, 8, 13]) +@pytest.mark.parametrize("d_head", [5, 16, 21, 32]) +@pytest.mark.parametrize( + "seq_len_chunk_size_cases", + [ + + # small-ish chunk_size (8) + (64, 8, 2, [(64, 32), (64, 32)]), + (64, 8, 2, [(32, 32), (32, 32), (32, 32)]), + (64, 8, 2, [(8, 8), (8, 8), (8, 8)]), # chunk size boundary + (64, 8, 2, [(4, 4), (4, 4), (4, 4), + (4, 4)]), # chunk_size larger than cont batches + (64, 8, 5, [ + (64, 32, 16, 8, 8), + (8, 16, 32, 16, 8), + (8, 8, 16, 32, 16), + ]), # mode examples with varied lengths + + # odd chunk_size + (64, 29, 2, [(11, 4), (13, 23), (19, 22), + (21, 15)]), # irregular sizes + + # large-ish chunk_size (256) + (64, 256, 1, [(5, ), (1, ), (1, ), + (1, )]), # irregular sizes with small sequences + (64, 256, 2, [(5, 30), (1, 2), (1, 2), + (1, 2)]), # irregular sizes with small sequences + ]) +def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, + itype): + + # this test with multiple examples in a continuous batch + # (i.e. 
chunked prefill) + + seqlen, chunk_size, num_examples, cases = seq_len_chunk_size_cases + + # hold state during the cutting process so we know if an + # example has been exhausted and needs to cycle + last_taken: Dict = {} # map: eg -> pointer to last taken sample + exhausted: Dict = {} # map: eg -> boolean indicating example is exhausted + + states = None + for Y_min, cu_seqlens, sed_idx, (A, dt, X, B, + C) in generate_continous_batched_examples( + cases, num_examples, seqlen, + last_taken, exhausted, n_heads, + d_head, itype): + + Y, new_states = mamba_chunk_scan_combined( + X, + dt, + A, + B, + C, + chunk_size, + D=None, + cu_seqlens=cu_seqlens, + seq_idx=sed_idx, + return_varlen_states=True, + initial_states=states, + ) + + # just test the last in sequence + for i in range(num_examples): + + # just test one dim and dstate + Y_eg = Y[0, cu_seqlens[i]:cu_seqlens[i + 1], 0, 0] + Y_min_eg = Y_min[i][:, 0, 0] + torch.allclose(Y_eg, Y_min_eg, atol=1e-3, rtol=1e-3) + + # update states + states = new_states + for i, clear in exhausted.items(): + if clear: + states[i].fill_(0.) 
+ exhausted[i] = False diff --git a/tests/models/decoder_only/language/test_jamba.py b/tests/models/decoder_only/language/test_hybrid.py similarity index 91% rename from tests/models/decoder_only/language/test_jamba.py rename to tests/models/decoder_only/language/test_hybrid.py index cc98f1d7b..a39b11923 100644 --- a/tests/models/decoder_only/language/test_jamba.py +++ b/tests/models/decoder_only/language/test_hybrid.py @@ -8,7 +8,8 @@ from vllm.sampling_params import SamplingParams from ...utils import check_outputs_equal -MODELS = ["ai21labs/Jamba-tiny-dev"] +# This test is for the hybrid models +MODELS = ["ai21labs/Jamba-tiny-dev", "ibm-ai-platform/Bamba-9B"] @pytest.mark.parametrize("model", MODELS) @@ -23,6 +24,10 @@ def test_models( max_tokens: int, ) -> None: + # numeric error produces different generation + if 'Bamba' in model: + example_prompts.pop(3) + with hf_runner( model, dtype=dtype, @@ -108,15 +113,21 @@ def test_mamba_prefill_chunking_with_parallel_sampling( @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["bfloat16"]) -@pytest.mark.parametrize("max_tokens", [10]) +@pytest.mark.parametrize("max_tokens", [7]) def test_mamba_prefill_chunking(hf_runner, vllm_runner, example_prompts, model: str, dtype: str, max_tokens: int) -> None: # numeric error during prefill chucking produces different generation # compared to w/o prefill chunking for those examples, removed them for now - example_prompts.pop(7) - example_prompts.pop(2) - example_prompts.pop(1) + if 'Jamba' in model: + example_prompts.pop(7) + example_prompts.pop(2) + example_prompts.pop(1) + elif 'Bamba' in model: + example_prompts.pop(6) + example_prompts.pop(3) + example_prompts.pop(2) + dtype = "half" # use a different dtype for Bamba with hf_runner( model, @@ -145,7 +156,7 @@ def test_mamba_prefill_chunking(hf_runner, vllm_runner, example_prompts, @pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["bfloat16"]) 
+@pytest.mark.parametrize("dtype", ["float"]) @pytest.mark.parametrize("max_tokens", [15]) def test_parallel_sampling( vllm_runner, @@ -249,17 +260,17 @@ def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks( dtype: str, example_prompts, ) -> None: - # This test is for verifying that the Jamba inner state management doesn't + # This test is for verifying that the hybrid inner state management doesn't # collapse in case where the number of incoming requests and # finished_requests_ids is larger than the maximum mamba block capacity. - # This could generally happen due to the fact that Jamba does support + # This could generally happen due to the fact that hybrid does support # statelessness mechanism where it can cleanup new incoming requests in # a single step. try: with vllm_runner(model, dtype=dtype, max_num_seqs=10) as vllm_model: vllm_model.generate_greedy([example_prompts[0]] * 100, 10) except ValueError: - pytest.fail("Jamba inner state wasn't cleaned up properly between" + pytest.fail("Hybrid inner state wasn't cleaned up properly between" "steps finished requests registered unnecessarily ") @@ -271,14 +282,14 @@ def test_state_cleanup( dtype: str, example_prompts, ) -> None: - # This test is for verifying that the Jamba state is cleaned up between + # This test is for verifying that the Hybrid state is cleaned up between # steps, If its not cleaned, an error would be expected. 
try: with vllm_runner(model, dtype=dtype) as vllm_model: for _ in range(10): vllm_model.generate_greedy([example_prompts[0]] * 100, 1) except ValueError: - pytest.fail("Jamba inner state wasn't cleaned up between states, " + pytest.fail("Hybrid inner state wasn't cleaned up between states, " "could be related to finished_requests_ids") @@ -324,7 +335,7 @@ def test_multistep_correctness(vllm_runner, model: str, dtype: str, @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float"]) @pytest.mark.parametrize("max_tokens", [64]) -def test_jamba_distributed_produces_identical_generation( +def test_hybrid_distributed_produces_identical_generation( vllm_runner, model: str, dtype: str, max_tokens: int, example_prompts) -> None: diff --git a/tests/models/registry.py b/tests/models/registry.py index 20787fe00..3fd94b89c 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -102,6 +102,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { trust_remote_code=True), "BaichuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan2-7B-chat", trust_remote_code=True), + "BambaForCausalLM": _HfExamplesInfo("ibm-ai-platform/Bamba-9B"), "BloomForCausalLM": _HfExamplesInfo("bigscience/bloomz-1b1"), # ChatGLMModel supports multimodal "CohereForCausalLM": _HfExamplesInfo("CohereForAI/c4ai-command-r-v01", diff --git a/vllm/attention/backends/placeholder_attn.py b/vllm/attention/backends/placeholder_attn.py index 9f6e731af..f363ba0c1 100644 --- a/vllm/attention/backends/placeholder_attn.py +++ b/vllm/attention/backends/placeholder_attn.py @@ -2,6 +2,7 @@ from collections import defaultdict from dataclasses import dataclass +from itertools import accumulate from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type import torch @@ -15,6 +16,7 @@ from vllm.multimodal import MultiModalPlaceholderMap if TYPE_CHECKING: from vllm.worker.model_runner import (ModelInputForGPUBuilder, ModelInputForGPUWithSamplingMetadata) +from vllm.utils import 
async_tensor_h2d # Placeholder attention backend for models like Mamba and pooling models that # lack attention. @@ -77,43 +79,39 @@ class PlaceholderAttentionMetadata(AttentionMetadata): # seq_lens stored as a tensor. seq_lens_tensor: Optional[torch.Tensor] - # Maximum query length in the batch. - max_query_len: Optional[int] - - # Max number of query tokens among request in the batch. - max_decode_query_len: Optional[int] - # Maximum sequence length among prefill batch. 0 if there are decoding # requests only. max_prefill_seq_len: int # Maximum sequence length among decode batch. 0 if there are prefill # requests only. max_decode_seq_len: int - # (batch_size + 1,). The cumulative subquery lengths of the sequences in - # the batch, used to index into subquery. E.g., if the subquery length - # is [4, 6], it is [0, 4, 10]. - query_start_loc: Optional[torch.Tensor] - # (batch_size + 1,). The cumulative sequence lengths of the sequences in - # the batch, used to index into sequence. E.g., if the sequence length is - # [4, 6], it is [0, 4, 10]. - seq_start_loc: Optional[torch.Tensor] # (batch_size,) A tensor of context lengths (tokens that are computed # so far). context_lens_tensor: Optional[torch.Tensor] - # (batch_size, max_blocks_per_seq). - # Block addresses per sequence. (Seq id -> list of physical block) - # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks - # in the kv cache. Each block can contain up to block_size tokens. - # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph - # captured. - block_tables: Optional[torch.Tensor] - # Whether or not if cuda graph is enabled. # Cuda-graph is currently enabled for decoding only. # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. use_cuda_graph: bool + # Maximum query length in the batch. + max_query_len: Optional[int] + + # Max number of query tokens among request in the batch. + max_decode_query_len: Optional[int] + + # (batch_size + 1,). 
The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + query_start_loc: Optional[torch.Tensor] = None + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] = None + + # Placeholder. + block_tables: Optional[torch.Tensor] = None + _cached_prefill_metadata: Optional["PlaceholderAttentionMetadata"] = None _cached_decode_metadata: Optional["PlaceholderAttentionMetadata"] = None @@ -125,11 +123,17 @@ class PlaceholderAttentionMetadata(AttentionMetadata): if self._cached_prefill_metadata is not None: return self._cached_prefill_metadata - assert self.seq_lens is not None - assert self.seq_lens_tensor is not None - assert self.query_start_loc is not None - assert self.context_lens_tensor is not None - assert self.seq_start_loc is not None + # Compute some attn_metadata fields which default to None + query_start_loc = (None if self.query_start_loc is None else + self.query_start_loc[:self.num_prefills + 1]) + seq_lens = (None if self.seq_lens is None else + self.seq_lens[:self.num_prefills]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[:self.num_prefills]) + seq_start_loc = (None if self.seq_start_loc is None else + self.seq_start_loc[:self.num_prefills + 1]) + context_lens_tensor = (None if self.context_lens_tensor is None else + self.context_lens_tensor[:self.num_prefills]) # Placeholders slot_mapping = torch.empty(0) @@ -143,15 +147,15 @@ class PlaceholderAttentionMetadata(AttentionMetadata): multi_modal_placeholder_index_maps=self. 
multi_modal_placeholder_index_maps, enable_kv_scales_calculation=self.enable_kv_scales_calculation, - seq_lens=self.seq_lens[:self.num_prefills], - seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, max_decode_query_len=0, max_query_len=self.max_query_len, max_prefill_seq_len=self.max_prefill_seq_len, max_decode_seq_len=0, - query_start_loc=self.query_start_loc[:self.num_prefills + 1], - seq_start_loc=self.seq_start_loc[:self.num_prefills + 1], - context_lens_tensor=self.context_lens_tensor[:self.num_prefills], + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, block_tables=block_tables, use_cuda_graph=False, ) @@ -169,6 +173,8 @@ class PlaceholderAttentionMetadata(AttentionMetadata): # Placeholders slot_mapping = torch.empty(0) block_tables = torch.empty(0) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[self.num_prefills:]) self._cached_decode_metadata = PlaceholderAttentionMetadata( num_prefills=0, @@ -178,13 +184,16 @@ class PlaceholderAttentionMetadata(AttentionMetadata): multi_modal_placeholder_index_maps=None, enable_kv_scales_calculation=True, seq_lens=None, - seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], + seq_lens_tensor=seq_lens_tensor, max_decode_query_len=self.max_decode_query_len, max_query_len=None, max_prefill_seq_len=0, max_decode_seq_len=self.max_decode_seq_len, - query_start_loc=None, - seq_start_loc=None, + query_start_loc=(self.query_start_loc[self.num_prefills:] - + self.query_start_loc[self.num_prefills]) + if self.query_start_loc is not None else None, + seq_start_loc=self.seq_start_loc[self.num_prefills:] + if self.seq_start_loc is not None else None, context_lens_tensor=None, block_tables=block_tables, use_cuda_graph=self.use_cuda_graph, @@ -235,8 +244,6 @@ class PlaceholderAttentionMetadata(AttentionMetadata): assert self.context_lens_tensor is not None assert 
self.context_lens_tensor.shape == (num_queries, ) - assert self.block_tables is not None - # Update query lengths. Note that we update only queries and not seqs, # since tensors may be padded due to captured cuda graph batch size for i in range(num_queries): @@ -299,9 +306,6 @@ class PlaceholderAttentionMetadataBuilder( self.num_prefill_tokens += token_len self.prefill_seq_lens.append(seq_len) else: - assert query_len == 1, ( - "seq_len: {}, context_len: {}, query_len: {}".format( - seq_len, context_len, query_len)) self.num_decode_tokens += query_len self.curr_seq_lens.append(curr_seq_len) @@ -323,15 +327,6 @@ class PlaceholderAttentionMetadataBuilder( device = self.runner.device use_captured_graph = cuda_graph_pad_size != -1 - logits_soft_cap = getattr(self.runner.model_config.hf_config, - "attn_logit_softcapping", None) - if logits_soft_cap is not None: - raise ValueError( - "Please use Flashinfer backend for models with logits_soft_cap" - " (i.e., Gemma-2). Otherwise, the output might be wrong." 
- " Set Flashinfer backend by " - "export VLLM_ATTENTION_BACKEND=FLASHINFER.") - max_query_len = max(query_lens) decode_query_lens = query_lens[self.num_prefills:] if len(decode_query_lens) > 0: @@ -341,48 +336,37 @@ class PlaceholderAttentionMetadataBuilder( max_prefill_seq_len = max(self.prefill_seq_lens, default=0) max_decode_seq_len = max(self.curr_seq_lens, default=0) num_decode_tokens = self.num_decode_tokens + query_start_loc = list(accumulate(query_lens, initial=0)) + seq_start_loc = list(accumulate(seq_lens, initial=0)) if use_captured_graph: - num_decode_tokens = batch_size - + num_decode_tokens = batch_size - self.num_prefill_tokens assert max_query_len > 0, ("query_lens: {}".format(query_lens)) - context_lens_tensor = torch.tensor(self.context_lens, - dtype=torch.int, - device=device) - seq_lens_tensor = torch.tensor(seq_lens, - dtype=torch.int, - device=device) - query_lens_tensor = torch.tensor(query_lens, - dtype=torch.long, - device=device) - query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=device) - seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=device) + assert device is not None + context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int, + device, self.runner.pin_memory) + seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, + self.runner.pin_memory) + query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32, + device, + self.runner.pin_memory) + seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32, + device, self.runner.pin_memory) + placeholder_index_maps = { modality: placeholder_map.index_map() for modality, placeholder_map in self.multimodal_placeholder_maps.items() } - torch.cumsum(seq_lens_tensor, - dim=0, - dtype=seq_start_loc.dtype, - out=seq_start_loc[1:]) - torch.cumsum(query_lens_tensor, - dim=0, - dtype=query_start_loc.dtype, - out=query_start_loc[1:]) # Placeholders - slot_mapping = torch.empty(0) + 
slot_mapping_tensor = torch.empty(0) block_tables = torch.empty(0) return PlaceholderAttentionMetadata( num_prefills=self.num_prefills, - slot_mapping=slot_mapping, + slot_mapping=slot_mapping_tensor, multi_modal_placeholder_index_maps=placeholder_index_maps, enable_kv_scales_calculation=True, num_prefill_tokens=self.num_prefill_tokens, @@ -393,8 +377,8 @@ class PlaceholderAttentionMetadataBuilder( max_decode_query_len=max_decode_query_len, max_prefill_seq_len=max_prefill_seq_len, max_decode_seq_len=max_decode_seq_len, - query_start_loc=query_start_loc, - seq_start_loc=seq_start_loc, + query_start_loc=query_start_loc_tensor, + seq_start_loc=seq_start_loc_tensor, context_lens_tensor=context_lens_tensor, block_tables=block_tables, use_cuda_graph=use_captured_graph, diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py new file mode 100644 index 000000000..5fd126491 --- /dev/null +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -0,0 +1,534 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import List, Optional, Tuple, Union + +import torch +from torch import nn + +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.attention.backends.flash_attn import FlashAttentionMetadata +from vllm.attention.backends.placeholder_attn import ( + PlaceholderAttentionMetadata) +from vllm.attention.backends.xformers import XFormersMetadata +from vllm.distributed import (divide, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce) +from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( + causal_conv1d_fn, causal_conv1d_update) +from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( + selective_state_update) +from 
vllm.model_executor.layers.mamba.ops.ssd_combined import ( + mamba_chunk_scan_combined) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.weight_utils import ( + LoaderFunction, composed_weight_loader, sharded_weight_loader) +from vllm.model_executor.models.mamba_cache import MambaCacheParams +from vllm.model_executor.utils import set_weight_attrs + +# Added by the IBM Team, 2024 + + +# Adapted from transformers.models.mamba2.modeling_mamba2.MambaRMSNormGated +@CustomOp.register("mixer2_gated_rms_norm") +class Mixer2RMSNormGated(CustomOp): + + def __init__(self, full_hidden_size, full_n_groups, eps=1e-6): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.full_hidden_size = full_hidden_size + self.group_size = full_hidden_size // full_n_groups + self.per_rank_hidden_size = full_hidden_size // self.tp_size + self.n_groups = full_hidden_size // self.group_size + + self.variance_epsilon = eps + self.weight = nn.Parameter(torch.ones(self.per_rank_hidden_size)) + set_weight_attrs(self.weight, + {"weight_loader": sharded_weight_loader(0)}) + assert self.full_hidden_size % self.tp_size== 0,\ + "Tensor parallel world size must divide hidden size." + + def forward_native( + self, + x: torch.Tensor, + gate: torch.Tensor, + ): + # Three tensor-parallel cases: + # 1. n_groups is 1 + # In this case we parallelize along the reduction dim. + # Each rank computes a local sum of squares followed by AllReduce + # 2. tp_size divides n_groups + # Each rank only reduces within its local group(s). + # No collective ops necessary. + # 3. The general case can be pretty complicated so we AllGather + # the input and then redundantly compute the RMSNorm. 
+ input_dtype = x.dtype + x = x * nn.functional.silu(gate.to(torch.float32)) + + if self.n_groups == 1: + if self.tp_size > 1: + # Compute local sum and then reduce to obtain global sum + local_sums = x.pow(2).sum(dim=-1, keepdim=True) + global_sums = tensor_model_parallel_all_reduce(local_sums) + # Calculate the variance + count = self.tp_size * x.shape[-1] + variance = (global_sums / count) + + else: + variance = x.pow(2).mean(-1, keepdim=True) + x = x * torch.rsqrt(variance + self.variance_epsilon) + else: + redundant_tp: bool = self.n_groups % self.tp_size != 0 + if redundant_tp: + # To handle the general case, redundantly apply the variance + x = tensor_model_parallel_all_gather(x, -1) + + *prefix_dims, hidden_dim = x.shape + group_count = hidden_dim // self.group_size + x_grouped = x.view(*prefix_dims, group_count, self.group_size) + variance = x_grouped.pow(2).mean(-1, keepdim=True) + x_grouped = x_grouped * torch.rsqrt(variance + + self.variance_epsilon) + x = x_grouped.view(*prefix_dims, hidden_dim) + + if redundant_tp: + start = self.per_rank_hidden_size * self.tp_rank + end = start + self.per_rank_hidden_size + x = x[..., start:end] + + return self.weight * x.to(input_dtype) + + def forward_cuda( + self, + x: torch.Tensor, + gate: torch.Tensor, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + + if self.tp_size > 1 or self.n_groups != 1: + return self.forward_native(x, gate) + + from vllm import _custom_ops as ops + + # cast x and gate to float32 before silu + out = torch.empty_like(x) + y = x * nn.functional.silu(gate.to(torch.float32)) + ops.rms_norm( + out, + y.to(x.dtype), + self.weight.data, + self.variance_epsilon, + ) + return out + + +def extra_groups_for_head_shards(ngroups: int, tp_size: int): + """Compute the increase in group numbers to account for + replication in order to accompany the head shards.""" + + # in the case ngoups % tp_size == 0, this will be zero + if ngroups % tp_size == 0: + return 0 + + return tp_size - 
ngroups % tp_size + + +def mamba_v2_sharded_weight_loader( + shard_spec: List[Tuple[int, int, float]], + tp_size: int, + tp_rank: int, +) -> LoaderFunction: + """Create a weight loader for mamba v2. This ensures that the projections + are correctly sharded so that they can be split into x, B, C. It also + ensures the the all the groups corresponding to a head shard is placed + together with it. + """ + + def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: + + # - track boundary of (sharded) param, and loaded_weight, respectively + boundary, loaded_boundary = 0, 0 + + # - iterate over the shard specs + for full_dim, extra, ratio in shard_spec: + # - full dim is the model dim (before TP). + # - extra > 0, means there is expected overall increase + # of dimensions. This is so because of replication. + # - ratio is used map the tp_rank to the actual shard + # rank. This is useful when there is replication of + # groups to accompany head shards. + + # - size of the loaded shard + shard_size = full_dim // tp_size + + # - compute the rank into the loaded shard. + # - if there is replication, different TP shards will + # take from the same rank. + rank = tp_rank // ratio + + # - leftmost boundary index into loaded weight. + loaded_skip = rank * shard_size + loaded_start_idx = loaded_boundary + loaded_skip + + # - take these many dims from the loaded weight. + take = min(shard_size, full_dim - extra - loaded_skip) + + # - always shard on dim 0 + # - the ignore is for a mundane mypy error as it does not + # seem to handle slices well. + # https://github.com/python/mypy/issues/2410 + param.data[ + boundary:(boundary + take), # type: ignore[misc] + ...] 
= loaded_weight[loaded_start_idx:( # type: ignore[misc] + loaded_start_idx + take)] # type: ignore[misc] + + # move indexing boundaries + boundary += shard_size + loaded_boundary += (full_dim - extra) + + return loader + + +# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer +@CustomOp.register("mamba_mixer2") +class MambaMixer2(CustomOp): + """ + Compute ∆, A, B, C, and D the state space parameters and compute + the `contextualized_states`. A, D are input independent + (see Mamba paper [1] Section 3.5.2 "Interpretation of A" + for why A isn't selective) ∆, B, C are input-dependent + (this is a key difference between Mamba and the linear time + invariant S4, and is why Mamba is called + **selective** state spaces) + """ + + def __init__(self, + hidden_size: int, + ssm_state_size: int, + conv_kernel_size: int, + intermediate_size: int, + use_conv_bias: bool, + use_bias: bool, + n_groups: int = 1, + num_heads: int = 128, + head_dim: int = 64, + rms_norm_eps: float = 1e-5, + activation="silu", + chunk_size: int = 256, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + + # For TP, the sharding plan is as follows: + # - for the conv modules, since + # conv_dim = intermediate_size * 2 * n_groups * ssm_state_size, + # we shard intermediate_size and n_groups + # - since intermediate_size = n_heads * head_dim, sharding on + # intermediate_size is achieved by sharding on n_heads. + # - IF, world_size divides groups, then sharding + # (n_groups / world_size, n_heads / world_size) + # also maintains the invariant n_heads % n_groups == 0 + # - HOWEVER IF, world_size DOES NOT divide groups, then we need + # to allocate extra space in the shard, such that groups + # may be replicated to follow the head shard. + self.tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + + assert num_heads % self.tp_size == 0, \ + "Tensor parallel world size must divide num heads." 
+ + self.ssm_state_size = ssm_state_size + self.activation = activation + + self.chunk_size = chunk_size + self.intermediate_size = intermediate_size + self.head_dim = head_dim + self.num_heads = num_heads + + self.n_groups = n_groups + if n_groups % self.tp_size != 0: + # - for TP we shard conv_dim by sharding on n_groups, + # - but if n_groups cannot divide tp_size, we need to + # extend some extra groups + self.n_groups = n_groups + extra_groups_for_head_shards( + n_groups, self.tp_size) + + self.conv_dim = (intermediate_size + + 2 * self.n_groups * ssm_state_size) + self.conv1d = ColumnParallelLinear( + input_size=conv_kernel_size, + output_size=self.conv_dim, + bias=use_conv_bias, + quant_config=None, + ) + # unsqueeze to fit conv1d weights shape into the linear weights shape. + # Can't do this in `weight_loader` since it already exists in + # `ColumnParallelLinear` and `set_weight_attrs` + # doesn't allow to override it + self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) + + self.in_proj = ColumnParallelLinear(input_size=hidden_size, + output_size=intermediate_size + + self.conv_dim + self.num_heads, + bias=use_bias, + quant_config=quant_config) + + # - because in_proj is a concatenation of 3 weights, we + # need to interleave them before sharding + # - use the custom weight loader mamba_v2_sharded_weight_loader + # for conv1d.bias, covn1d.weight and in_proj.weight + # - need to set these settings, to assign the groups to the head shards + group_shard_settings = ( + self.n_groups * self.ssm_state_size, # expected model size + (self.n_groups - n_groups) * + self.ssm_state_size, # extra dims assigned + self.num_heads // + n_groups, # ratio for mapping back to original group + ) + intermediate_settings = (intermediate_size, 0, 1) + head_setings = (self.num_heads, 0, 1) + + # - the weight already has a "weight_loader" attribute + # which set_weight_attrs will raise if we do not + # delete before trying to override it + # - ditto for the otther two 
weights below + delattr(self.conv1d.bias, "weight_loader") + set_weight_attrs( + self.conv1d.bias, { + "weight_loader": + mamba_v2_sharded_weight_loader( + [ + intermediate_settings, + group_shard_settings, + group_shard_settings, + ], + self.tp_size, + tp_rank, + ) + }) + + delattr(self.conv1d.weight, "weight_loader") + set_weight_attrs( + self.conv1d.weight, { + "weight_loader": + mamba_v2_sharded_weight_loader([ + intermediate_settings, + group_shard_settings, + group_shard_settings, + ], self.tp_size, tp_rank) + }) + + delattr(self.in_proj.weight, "weight_loader") + set_weight_attrs( + self.in_proj.weight, + { + "weight_loader": + mamba_v2_sharded_weight_loader( + [ + intermediate_settings, # for gate + intermediate_settings, + group_shard_settings, + group_shard_settings, + head_setings, # for dt + ], + self.tp_size, + tp_rank) + }) + + # - these are TPed by heads to reduce the size of the + # temporal shape + self.A = nn.Parameter( + torch.empty( + divide(num_heads, self.tp_size), + dtype=torch.float32, + )) + self.D = nn.Parameter(torch.ones(num_heads // self.tp_size)) + self.dt_bias = nn.Parameter(torch.ones(num_heads // self.tp_size)) + + set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)}) + a_weight_loader = composed_weight_loader( + sharded_weight_loader(0), lambda x: -torch.exp(x.float())) + set_weight_attrs(self.A, {"weight_loader": a_weight_loader}) + set_weight_attrs(self.dt_bias, + {"weight_loader": sharded_weight_loader(0)}) + + self.out_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=use_bias, + input_is_parallel=True, + quant_config=quant_config) + + self.norm = Mixer2RMSNormGated(intermediate_size, + n_groups, + eps=rms_norm_eps) + + def forward_native(self, hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + conv_state: torch.Tensor, ssm_state: torch.Tensor): + pass + + def forward_cuda( + self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + mamba_cache_params: 
MambaCacheParams, + sequence_idx: Optional[torch.Tensor] = None, + ): + + seq_len, _ = hidden_states.shape + groups_time_state_size = self.n_groups * self.ssm_state_size + + # detect if there are prefills + has_prefill = attn_metadata.num_prefills > 0 + + # - also need flags to indicate if there are initial states + # - currently we really only support the FlashAttention backend + has_initial_states = None + if (isinstance(attn_metadata, + (FlashAttentionMetadata, XFormersMetadata, + PlaceholderAttentionMetadata)) + and attn_metadata.context_lens_tensor is not None): + has_initial_states = attn_metadata.context_lens_tensor > 0 + + # 1. Gated MLP's linear projection + projected_states, _ = self.in_proj(hidden_states) + gate, hidden_states_B_C, dt = torch.split( + projected_states, + [ + self.intermediate_size // self.tp_size, + self.conv_dim // self.tp_size, + self.num_heads // self.tp_size, + ], + dim=-1, + ) + + # 2. Convolution sequence transformation + conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), + self.conv1d.weight.size(2)) + + if has_prefill: + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + + # - "cache_indices" updates the conv_state cache in positions + # pointed to by "mamba_cache_params.state_indices_tensor" + hidden_states_B_C = causal_conv1d_fn( + hidden_states_B_C.transpose(0, 1), + conv_weights, + self.conv1d.bias, + activation=self.activation, + conv_states=mamba_cache_params.conv_state, + has_initial_state=has_initial_states, + cache_indices=mamba_cache_params.state_indices_tensor, + query_start_loc=attn_metadata.query_start_loc).transpose( + 0, 1)[:seq_len] + + # TODO: Why is this needed? 
+ hidden_states_B_C = hidden_states_B_C.contiguous() + else: + hidden_states_B_C = causal_conv1d_update( + hidden_states_B_C, + mamba_cache_params.conv_state, + conv_weights, + self.conv1d.bias, + self.activation, + conv_state_indices=mamba_cache_params.state_indices_tensor) + + # - get hidden_states, B and C after depthwise convolution. + hidden_states, B, C = torch.split( + hidden_states_B_C, + [ + self.intermediate_size // self.tp_size, + groups_time_state_size // self.tp_size, + groups_time_state_size // self.tp_size, + ], + dim=-1, + ) + + # 3. State Space Model sequence transformation + if has_prefill: + + initial_states = None + if has_initial_states is not None and any(has_initial_states): + for idx in mamba_cache_params.state_indices_tensor[ + ~has_initial_states]: + mamba_cache_params.ssm_state[idx].zero_() + initial_states = mamba_cache_params.ssm_state[ + mamba_cache_params.state_indices_tensor] + + scan_output, varlen_state = mamba_chunk_scan_combined( + hidden_states.view(1, seq_len, self.num_heads // self.tp_size, + self.head_dim), + dt.unsqueeze(0), + self.A, + B.view(1, seq_len, self.n_groups // self.tp_size, -1), + C.view(1, seq_len, self.n_groups // self.tp_size, -1), + chunk_size=self.chunk_size, + D=self.D, + z=None, + dt_bias=self.dt_bias, + seq_idx=sequence_idx, + cu_seqlens=attn_metadata.query_start_loc, + initial_states=initial_states, + return_varlen_states=True, + return_final_states=False, + dt_softplus=True, + dt_limit=(0.0, float("inf")), + ) + + # update ssm states + # - varlen state is a (batch, nheads, headdim, dstate) tensor + for i, idx in enumerate(mamba_cache_params.state_indices_tensor): + mamba_cache_params.ssm_state[idx].copy_(varlen_state[i]) + + # - reshape + hidden_states = scan_output.view(seq_len, -1) + else: + + n_groups = self.n_groups // self.tp_size + A = self.A[:, None, ...][:, :, None].expand( + -1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) + dt = dt[:, :, None].expand(-1, -1, self.head_dim) + 
dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim) + D = self.D[:, None, ...].expand(-1, self.head_dim) + B = B.view(-1, n_groups, B.shape[1] // n_groups) + C = C.view(-1, n_groups, C.shape[1] // n_groups) + hidden_states_reshaped = hidden_states.view( + -1, self.num_heads // self.tp_size, self.head_dim) + + # - the hidden is reshaped into number of current batches + # - in this case there is no more prefill, so the batches gen + # 1 token at a time + # - thus hidden will be (bs, num_heads, head_dim) + # - mamba_cache_params.ssm_state's slots will be selected + # using "mamba_cache_params.state_indices_tensor", just as + # above in the prefill case + + hidden_states = selective_state_update( + mamba_cache_params.ssm_state, + hidden_states_reshaped, + dt, + A, + B, + C, + D, + z=None, + dt_bias=dt_bias, + dt_softplus=True, + state_batch_indices=mamba_cache_params.state_indices_tensor, + ) + hidden_states = hidden_states.view( + -1, (self.num_heads // self.tp_size) * self.head_dim) + + # # 4. gated MLP + hidden_states = self.norm(hidden_states, gate) + + # # 5. Final linear projection + out, _ = self.out_proj(hidden_states) + return out diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py index 3c35f1ac0..b31b980fb 100644 --- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # Copyright (c) 2024, Tri Dao, Albert Gu. 
-# Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/selective_state_update.py +# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/selective_state_update.py import torch import triton diff --git a/vllm/model_executor/layers/mamba/ops/ssd_bmm.py b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py new file mode 100644 index 000000000..388a63327 --- /dev/null +++ b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py @@ -0,0 +1,261 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Copyright (c) 2024, Tri Dao, Albert Gu. +# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_bmm.py + +# ruff: noqa: E501,SIM102 + +import math + +import torch +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 64 + }, + num_stages=3, + num_warps=8), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 32, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=2), + ], + key=['chunk_size', 'K', 'IS_CAUSAL'], +) +@triton.jit +def 
_bmm_chunk_fwd_kernel( + # Pointers to matrices + a_ptr, + b_ptr, + out_ptr, + seq_idx_ptr, + # Matrix dimensions + seqlen, + chunk_size, + K, + ngroups, + stride_a_batch, + stride_a_seqlen, + stride_a_head, + stride_ak, + stride_b_batch, + stride_b_seqlen, + stride_b_head, + stride_bk, + stride_out_batch, + stride_out_chunk, + stride_out_head, + stride_outm, + stride_outn, + stride_seq_idx_batch, + stride_seq_idx_seqlen, + # Meta-parameters + IS_CAUSAL: tl.constexpr, + dot_dtype: tl.constexpr, + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + pid_b = tl.program_id(axis=1) + pid_ch = tl.program_id(axis=2).to(tl.int64) + pid_c = pid_ch // ngroups + pid_h = pid_ch - pid_c * ngroups + num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + if IS_CAUSAL: + if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M: + return + a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head + b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + pid_h * stride_b_head + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen + + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + + offs_n[None, :] * stride_b_seqlen) + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, + mask=(offs_m[:, None] < chunk_size_limit) & + (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0).to(dot_dtype) + b = tl.load(b_ptrs, + mask=(offs_k[:, None] < K - k * 
BLOCK_SIZE_K) & + (offs_n[None, :] < chunk_size_limit), + other=0.0).to(dot_dtype) + acc += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + if HAS_SEQ_IDX: + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, + mask=offs_m < chunk_size_limit, + other=-1) + seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen, + mask=offs_n < chunk_size_limit, + other=-2) + acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0) + out = acc.to(out_ptr.dtype.element_ty) + + out_ptr += pid_b * stride_out_batch + pid_c * stride_out_chunk + pid_h * stride_out_head + out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + + offs_n[None, :] * stride_outn) + tl.store(out_ptrs, + out, + mask=(offs_m[:, None] < chunk_size) & + (offs_n[None, :] < chunk_size)) + + +def _bmm_chunk_fwd(a, + b, + chunk_size, + seq_idx=None, + causal=False, + output_dtype=None): + """ + Argument: + a: (batch, seqlen, k) or (batch, seqlen, ngroups, k) + b: (batch, seqlen, k) or (batch, seqlen, ngroups, k) + seq_idx: (batch, seqlen) or None. out[i, j] for seq_idx[i] != seq_idx[j] will be zeroed out. + causal: if True, then out[i, j] for i > j will be arbitrary, only out[i, j] for i <= j are + guaranteed to be correct. + Return: + out: (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, ngroups, chunk_size, chunk_size) + """ + # Check constraints. 
+ has_groups = a.dim() == 4 + if not has_groups: + batch, seqlen, k = a.shape + else: + batch, seqlen, ngroups, k = a.shape + assert b.shape == a.shape + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + if a.stride(-1) != 1 and a.stride(1) != 1: + a = a.contiguous() + if b.stride(-1) != 1 and b.stride(1) != 1: + b = b.contiguous() + nchunks = math.ceil(seqlen / chunk_size) + # Allocates output. + out_dtype = a.dtype if output_dtype is None else output_dtype + out = torch.empty( + (batch, nchunks, chunk_size, chunk_size) if not has_groups else + (batch, nchunks, ngroups, chunk_size, chunk_size), + device=a.device, + dtype=out_dtype) + dot_dtype = (tl.bfloat16 + if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16 else + (tl.float16 if a.dtype == torch.float16 + or b.dtype == torch.float16 else tl.float32)) + grid = lambda META: (triton.cdiv( + chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv( + chunk_size, META['BLOCK_SIZE_N']), batch, nchunks + if not has_groups else nchunks * ngroups) + with torch.cuda.device(a.device.index): + _bmm_chunk_fwd_kernel[grid]( + a, + b, + out, + seq_idx, + seqlen, + chunk_size, + k, + ngroups if has_groups else 1, + a.stride(0), + a.stride(1), + 0 if not has_groups else a.stride(2), + a.stride(-1), + b.stride(0), + b.stride(1), + 0 if not has_groups else b.stride(2), + b.stride(-1), + out.stride(0), + out.stride(1), + 0 if not has_groups else out.stride(2), + out.stride(-2), + out.stride(-1), + *((seq_idx.stride(0), + seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + causal, + dot_dtype, + HAS_SEQ_IDX=seq_idx is not None, + ) + return out diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py new file mode 100644 index 000000000..722fbd714 --- /dev/null +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py @@ -0,0 +1,615 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Copyright (c) 2024, Tri Dao, Albert Gu. 
+# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_chunk_scan.py + +# ruff: noqa: E501,SIM102 + +import math + +import torch +import triton +import triton.language as tl +from packaging import version + +TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0') + + +@triton.autotune( + configs=[ + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 64 + }, + num_stages=3, + num_warps=8), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 64 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 64 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 32, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=2), + ], + key=['chunk_size', 'hdim', 'dstate', 'IS_CAUSAL'], +) +@triton.jit +def _chunk_scan_fwd_kernel( + # Pointers to matrices + cb_ptr, + x_ptr, + z_ptr, + out_ptr, + out_x_ptr, + dt_ptr, + dA_cumsum_ptr, + seq_idx_ptr, + C_ptr, + states_ptr, + D_ptr, + initstates_ptr, + 
chunk_indices_ptr, + chunk_offsets_ptr, + chunk_meta_num, + # Matrix dimensions + chunk_size, + hdim, + dstate, + batch, + seqlen, + nheads_ngroups_ratio, + # Strides + stride_cb_batch, + stride_cb_chunk, + stride_cb_head, + stride_cb_csize_m, + stride_cb_csize_k, + stride_x_batch, + stride_x_seqlen, + stride_x_head, + stride_x_hdim, + stride_z_batch, + stride_z_seqlen, + stride_z_head, + stride_z_hdim, + stride_out_batch, + stride_out_seqlen, + stride_out_head, + stride_out_hdim, + stride_dt_batch, + stride_dt_chunk, + stride_dt_head, + stride_dt_csize, + stride_dA_cs_batch, + stride_dA_cs_chunk, + stride_dA_cs_head, + stride_dA_cs_csize, + stride_seq_idx_batch, + stride_seq_idx_seqlen, + stride_C_batch, + stride_C_seqlen, + stride_C_head, + stride_C_dstate, + stride_states_batch, + stride_states_chunk, + stride_states_head, + stride_states_hdim, + stride_states_dstate, + stride_init_states_batch, + stride_init_states_head, + stride_init_states_hdim, + stride_init_states_dstate, + stride_D_head, + # Meta-parameters + IS_CAUSAL: tl.constexpr, + HAS_D: tl.constexpr, + D_HAS_HDIM: tl.constexpr, + HAS_Z: tl.constexpr, + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_DSTATE: tl.constexpr, + IS_TRITON_22: tl.constexpr, + HAS_INITSTATES: tl.constexpr, +): + pid_bc = tl.program_id(axis=1).to(tl.int64) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + if not HAS_INITSTATES: + c_idx = pid_c + c_off = 0 + else: + c_idx = tl.load(chunk_indices_ptr + pid_c, mask=pid_c > -1, other=0) + c_off = tl.load(chunk_offsets_ptr + pid_c, mask=pid_c > -1, other=0) + + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + cb_ptr += pid_b * stride_cb_batch + c_idx * stride_cb_chunk + ( + pid_h // nheads_ngroups_ratio) * stride_cb_head + x_ptr += pid_b * stride_x_batch + c_idx * chunk_size * 
stride_x_seqlen + pid_h * stride_x_head + dt_ptr += pid_b * stride_dt_batch + c_idx * stride_dt_chunk + pid_h * stride_dt_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + c_idx * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + C_ptr += pid_b * stride_C_batch + c_idx * chunk_size * stride_C_seqlen + ( + pid_h // nheads_ngroups_ratio) * stride_C_head + + # M-block offsets and prev states + # - logic in next block may override these if there is an active offset + offs_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M) + prev_states_ptr = states_ptr + pid_b * stride_states_batch + c_idx * stride_states_chunk + pid_h * stride_states_head + prev_states_hdim = stride_states_hdim + prev_states_dstate = stride_states_dstate + + chunk_size_limit = min(chunk_size, seqlen - c_idx * chunk_size) + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + c_idx * chunk_size * stride_seq_idx_seqlen + + # - we only need seq_idx_prev to be aligned to chunk boundary + seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, + mask=c_idx >= 1, + other=0) + + if HAS_INITSTATES: + # if there are init states, we only need seq_idx_m to point + # what is the current seq_idx + + # get current seq idx + if (pid_m * BLOCK_SIZE_M + c_off) < chunk_size_limit: + seq_idx_m = tl.load( + seq_idx_ptr + + (pid_m * BLOCK_SIZE_M + c_off) * stride_seq_idx_seqlen, ) + + # - recall that in ssd_state_passing, for the case c_off == 0 + # i.e., the very first sequence, we made states_ptr hold its initial state + # so this edge case is taken care of + if ((c_off == 0) and + (seq_idx_prev != seq_idx_m + ) # if a seq is changed exactly on boundary + or (c_off > 0) # implies a new example (pseudo chunk) + ): + + # - replace prev_states_ptr with init_states + prev_states_ptr = initstates_ptr + seq_idx_m * stride_init_states_batch + pid_h * stride_init_states_head + prev_states_hdim = stride_init_states_hdim # override strides + prev_states_dstate = stride_init_states_dstate + + offs_n = 
pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, + mask=offs_m < chunk_size, + other=0.0).to(tl.float32) + + # - handle chunk state limit + if HAS_INITSTATES: + + # have to split this if otherwise compilation will have problems + dA_cs_m_boundary = 0.0 + + # get the c_idx for the next (logica) chunk + c_idx_n = tl.load( + chunk_indices_ptr + (pid_c + 1), + mask=pid_c > -1 and (pid_c + 1) < chunk_meta_num, + other=-1 # to trigger different chunk + ) + + # - there are things to consider + # A. if c_off > 0 then we need to move the dA_cs boundary to ensure correct + # contribution of past states + # B. if c_off_n < chunk_size_limit, then we need to adjust this so as not to + # encroach into the next sequence, where c_off_n is the offset of the next + # (logical) chunk. + # An equivalent check for B is c_idx == c_idx_n, where there is repetition in + # (logical) chunk indices. + + if (c_idx == c_idx_n) or c_off > 0: + + # get the next offset + c_off_n = tl.load(chunk_offsets_ptr + (pid_c + 1), + mask=pid_c > -1 and (pid_c + 1) < chunk_meta_num, + other=chunk_size) + + # in this case, adjust down the chunk_size_limit + if c_idx == c_idx_n: + chunk_size_limit = min(c_off_n, chunk_size_limit) + + # get the cs at the offset boundary + # - c_off == 0 is a passthrough + dA_cs_m_boundary = tl.load( + dA_cumsum_ptr + + (pid_m * BLOCK_SIZE_M + c_off - 1) * stride_dA_cs_csize, + mask=(pid_m * BLOCK_SIZE_M + c_off - 1) > -1, + other=0.0).to(tl.float32) + + if HAS_SEQ_IDX: + # - handle seq idx when HAS_INITSTATES==False + if not HAS_INITSTATES: + seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, + mask=offs_m < chunk_size_limit, + other=-1) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Without the if (pid_c > -1), with Triton 2.1.0, I get + # Assertion `!(srcMmaLayout && dstMmaLayout) && "Unexpected mma -> mm a layout conversion"' failed. 
+ # With Triton 2.2.0, this works + if IS_TRITON_22 or c_idx > -1: + # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128 + offs_k_dstate = tl.arange( + 0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K) + C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + + offs_k_dstate[None, :] * stride_C_dstate) + + prev_states_ptrs = prev_states_ptr + ( + offs_n[None, :] * prev_states_hdim + + offs_k_dstate[:, None] * prev_states_dstate) + if HAS_SEQ_IDX: + + if not HAS_INITSTATES: + # - this is for continuous batching where there is no init states + scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), + 0.0) + else: + # - if there is initstates, we will rely on prev_states, no zeroing + # required. + scale_m = tl.exp(dA_cs_m - dA_cs_m_boundary) + else: + scale_m = tl.exp(dA_cs_m) + if BLOCK_SIZE_DSTATE <= 128: + C = tl.load(C_ptrs, + mask=(offs_m[:, None] < chunk_size_limit) & + (offs_k_dstate[None, :] < dstate), + other=0.0) + + prev_states = tl.load(prev_states_ptrs, + mask=(offs_k_dstate[:, None] < dstate) & + (offs_n[None, :] < hdim), + other=0.0) + prev_states = prev_states.to(C_ptr.dtype.element_ty) + acc = tl.dot(C, prev_states) * scale_m[:, None] + else: + for k in range(0, dstate, BLOCK_SIZE_K): + C = tl.load(C_ptrs, + mask=(offs_m[:, None] < chunk_size_limit) & + (offs_k_dstate[None, :] < dstate - k), + other=0.0) + # C = (C * scale_m[:, None]).to(C_ptr.dtype.element_ty) + prev_states = tl.load( + prev_states_ptrs, + mask=(offs_k_dstate[:, None] < dstate - k) & + (offs_n[None, :] < hdim), + other=0.0) + prev_states = prev_states.to(C_ptr.dtype.element_ty) + acc += tl.dot(C, prev_states) + C_ptrs += BLOCK_SIZE_K + prev_states_ptrs += BLOCK_SIZE_K + acc *= scale_m[:, None] + + offs_k = tl.arange(0, BLOCK_SIZE_K) + c_off + cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + + offs_k[None, :] * stride_cb_csize_k) + x_ptrs = x_ptr + (offs_k[:, None] * stride_x_seqlen + + offs_n[None, :] * stride_x_hdim) + 
dt_ptrs = dt_ptr + offs_k * stride_dt_csize + dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize + K_MAX = chunk_size_limit if not IS_CAUSAL else min( + (pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit) + for k in range(0, K_MAX, BLOCK_SIZE_K): + cb = tl.load(cb_ptrs, + mask=(offs_m[:, None] < chunk_size) & + (offs_k[None, :] < chunk_size - k), + other=0.0).to(tl.float32) + dA_cs_k = tl.load(dA_cumsum_ptrs, + mask=offs_k < chunk_size - k, + other=0.0).to(tl.float32) + # If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j]. + # So we don't need masking wrt seq_idx here. + cb *= tl.exp(dA_cs_m[:, None] - dA_cs_k[None, :]) + dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, + other=0.0).to(tl.float32) + cb *= dt_k + if IS_CAUSAL: + mask = offs_m[:, None] >= k + offs_k[None, :] + cb = tl.where(mask, cb, 0.0) + cb = cb.to(x_ptr.dtype.element_ty) + x = tl.load(x_ptrs, + mask=(offs_k[:, None] < chunk_size_limit - k) & + (offs_n[None, :] < hdim), + other=0.0) + acc += tl.dot(cb, x) + cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k + x_ptrs += BLOCK_SIZE_K * stride_x_seqlen + dt_ptrs += BLOCK_SIZE_K * stride_dt_csize + dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize + + offs_out_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M) + offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + if HAS_D: + if D_HAS_HDIM: + D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, + mask=offs_n < hdim, + other=0.0).to(tl.float32) + else: + D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) + x_residual = tl.load(x_ptr + (offs_m[:, None] * stride_x_seqlen + + offs_n[None, :] * stride_x_hdim), + mask=(offs_m[:, None] < chunk_size_limit) & + (offs_n[None, :] < hdim), + other=0.0).to(tl.float32) + acc += x_residual * D + + if HAS_Z: + out_x_ptr += pid_b * stride_out_batch + c_idx * chunk_size * stride_out_seqlen + pid_h * stride_out_head + out_x_ptrs = out_x_ptr + (stride_out_seqlen * offs_out_m[:, None] + + offs_out_n[None, :]) + 
tl.store(out_x_ptrs, + acc, + mask=(offs_out_m[:, None] < chunk_size_limit) & + (offs_out_n[None, :] < hdim)) + + z_ptr += pid_b * stride_z_batch + c_idx * chunk_size * stride_z_seqlen + pid_h * stride_z_head + z_ptrs = z_ptr + (stride_z_seqlen * offs_out_m[:, None] + + stride_z_hdim * offs_out_n[None, :]) + z = tl.load(z_ptrs, + mask=(offs_out_m[:, None] < chunk_size_limit) & + (offs_out_n[None, :] < hdim), + other=0.0).to(tl.float32) + acc *= z * tl.sigmoid(z) + + out_ptr += pid_b * stride_out_batch + c_idx * chunk_size * stride_out_seqlen + pid_h * stride_out_head + out_ptrs = out_ptr + (stride_out_seqlen * offs_out_m[:, None] + + offs_out_n[None, :] * stride_out_hdim) + tl.store(out_ptrs, + acc, + mask=(offs_out_m[:, None] < chunk_size_limit) & + (offs_out_n[None, :] < hdim)) + + +def _seq_idx_to_chunk_indices_offsets(seq_idx, chunk_size: int): + + # convert seq_idx to chunk indices and offsets + # - derive the cu_seqlens + _, cu_seqlens = torch.where(seq_idx.diff()) + cu_seqlens += 1 + + # outputs will have length expansion of chunks that do not divide + # chunk_size + N = math.ceil(seq_idx.shape[-1] / chunk_size) + (cu_seqlens % chunk_size + > 0).sum() + chunk_indices = torch.arange(N, dtype=torch.int, device=seq_idx.device) + chunk_offsets = torch.zeros((N, ), dtype=torch.int, device=seq_idx.device) + + cu_seqlens = cu_seqlens.tolist() + [seq_idx.shape[-1]] + p = 0 # num of insertions + for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]): + + # if does not divide chunk_size, then there is one chunk insertion + p += (s % chunk_size > 0) + + # get the dimensions + _s, _e = s // chunk_size + p, e // chunk_size + p + 1 + + # adjust inidces and offsets + chunk_indices[_s:_e] -= p + chunk_offsets[_s] = s % chunk_size + + return chunk_indices, chunk_offsets + + +def _chunk_scan_fwd( + cb, + x, + dt, + dA_cumsum, + C, + states, + D=None, + z=None, + seq_idx=None, + initial_states=None, +): + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = 
dt.shape + _, _, ngroups, dstate = C.shape + assert nheads % ngroups == 0 + assert C.shape == (batch, seqlen, ngroups, dstate) + assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size) + if z is not None: + assert z.shape == x.shape + if D is not None: + assert D.shape == (nheads, headdim) or D.shape == (nheads, ) + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) + assert states.shape == (batch, nchunks, nheads, headdim, dstate) + + chunk_indices, chunk_offsets = None, None + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + + if initial_states is not None: + # with initial states, we need to take care of how + # seq_idx crosses the boundaries + assert batch == 1, "chunk scan only supports initial states with batch 1" + assert initial_states.shape == (seq_idx[0].max() + 1, nheads, + headdim, dstate) + + if initial_states.shape[0] == 1: + # no in this case no point to use initial states + initial_states = None + else: + chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets( + seq_idx, chunk_size) + + # Allocates output. 
+ out = torch.empty(batch, + seqlen, + nheads, + headdim, + device=x.device, + dtype=x.dtype) + if z is not None: + out_x = torch.empty(batch, + seqlen, + nheads, + headdim, + device=x.device, + dtype=x.dtype) + assert out_x.stride() == out.stride() + else: + out_x = None + + grid = lambda META: ( + triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv( + headdim, META['BLOCK_SIZE_N']), batch * nchunks + if chunk_offsets is None else len(chunk_offsets), nheads) + z_strides = ((z.stride(0), z.stride(1), z.stride(2), + z.stride(3)) if z is not None else (0, 0, 0, 0)) + _chunk_scan_fwd_kernel[grid]( + cb, + x, + z, + out, + out_x, + dt, + dA_cumsum, + seq_idx, + C, + states, + D, + initial_states, + chunk_indices, + chunk_offsets, + len(chunk_indices) if chunk_indices is not None else 0, + chunk_size, + headdim, + dstate, + batch, + seqlen, + nheads // ngroups, + cb.stride(0), + cb.stride(1), + cb.stride(2), + cb.stride(3), + cb.stride(4), + x.stride(0), + x.stride(1), + x.stride(2), + x.stride(3), + z_strides[0], + z_strides[1], + z_strides[2], + z_strides[3], + out.stride(0), + out.stride(1), + out.stride(2), + out.stride(3), + dt.stride(0), + dt.stride(2), + dt.stride(1), + dt.stride(3), + dA_cumsum.stride(0), + dA_cumsum.stride(2), + dA_cumsum.stride(1), + dA_cumsum.stride(3), + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else + (0, 0)), + C.stride(0), + C.stride(1), + C.stride(2), + C.stride(3), + states.stride(0), + states.stride(1), + states.stride(2), + states.stride(3), + states.stride(4), + *((initial_states.stride(0), initial_states.stride(1), + initial_states.stride(2), + initial_states.stride(3)) if initial_states is not None else + (0, 0, 0, 0)), + D.stride(0) if D is not None else 0, + True, + D is not None, + D.dim() == 2 if D is not None else True, + BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16), + HAS_Z=z is not None, + HAS_SEQ_IDX=seq_idx is not None, + IS_TRITON_22=TRITON_22, + HAS_INITSTATES=initial_states is 
not None, + ) + return out, out_x diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py new file mode 100644 index 000000000..a970ac945 --- /dev/null +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py @@ -0,0 +1,750 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Copyright (c) 2024, Tri Dao, Albert Gu. +# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_chunk_state.py + +# ruff: noqa: E501 + +import math + +import torch +import triton +import triton.language as tl + +from .mamba_ssm import softplus + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_H': 1}), + triton.Config({'BLOCK_SIZE_H': 2}), + triton.Config({'BLOCK_SIZE_H': 4}), + triton.Config({'BLOCK_SIZE_H': 8}), + triton.Config({'BLOCK_SIZE_H': 16}), + triton.Config({'BLOCK_SIZE_H': 32}), + triton.Config({'BLOCK_SIZE_H': 64}), + ], + key=['chunk_size', 'nheads'], +) +@triton.jit +def _chunk_cumsum_fwd_kernel( + # Pointers to matrices + dt_ptr, + A_ptr, + dt_bias_ptr, + dt_out_ptr, + dA_cumsum_ptr, + # Matrix dimension + batch, + seqlen, + nheads, + chunk_size, + dt_min, + dt_max, + # Strides + stride_dt_batch, + stride_dt_seqlen, + stride_dt_head, + stride_A_head, + stride_dt_bias_head, + stride_dt_out_batch, + stride_dt_out_chunk, + stride_dt_out_head, + stride_dt_out_csize, + stride_dA_cs_batch, + stride_dA_cs_chunk, + stride_dA_cs_head, + stride_dA_cs_csize, + # Meta-parameters + DT_SOFTPLUS: tl.constexpr, + HAS_DT_BIAS: tl.constexpr, + BLOCK_SIZE_H: tl.constexpr, + BLOCK_SIZE_CHUNK: tl.constexpr, +): + pid_b = tl.program_id(axis=0) + + # if dt is long, may cause problems, so use 64 bit + # https://github.com/triton-lang/triton/issues/1058 + pid_c = tl.program_id(axis=1).to(tl.int64) + pid_h = tl.program_id(axis=2) + dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen + dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk + 
dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + + offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H) + offs_c = tl.arange(0, BLOCK_SIZE_CHUNK) + dt_ptrs = dt_ptr + (offs_h[:, None] * stride_dt_head + + offs_c[None, :] * stride_dt_seqlen) + A_ptrs = A_ptr + offs_h * stride_A_head + dt_out_ptrs = dt_out_ptr + (offs_h[:, None] * stride_dt_out_head + + offs_c[None, :] * stride_dt_out_csize) + dA_cs_ptrs = dA_cumsum_ptr + (offs_h[:, None] * stride_dA_cs_head + + offs_c[None, :] * stride_dA_cs_csize) + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + + dt = tl.load(dt_ptrs, + mask=(offs_h[:, None] < nheads) & + (offs_c[None, :] < chunk_size_limit), + other=0.0).to(tl.float32) + if HAS_DT_BIAS: + dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head, + mask=offs_h < nheads, + other=0.0).to(tl.float32) + dt += dt_bias[:, None] + if DT_SOFTPLUS: + dt = tl.where(dt <= 20.0, softplus(dt), dt) + # As of Triton 2.2.0, tl.clamp is not available yet + # dt = tl.clamp(dt, dt_min, dt_max) + dt = tl.minimum(tl.maximum(dt, dt_min), dt_max) + dt = tl.where( + (offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, + 0.0) + tl.store(dt_out_ptrs, + dt, + mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size)) + A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32) + dA = dt * A[:, None] + dA_cs = tl.cumsum(dA, axis=1) + tl.store(dA_cs_ptrs, + dA_cs, + mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size)) + + +@triton.autotune( + configs=[ + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 64 + }, + num_stages=3, + num_warps=8), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 64, + 
'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 32, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=2), + ], + key=['hdim', 'dstate', 'chunk_size'], +) +@triton.jit +def _chunk_state_fwd_kernel( + # Pointers to matrices + x_ptr, + b_ptr, + states_ptr, + dt_ptr, + dA_cumsum_ptr, + seq_idx_ptr, + # Matrix dimensions + hdim, + dstate, + chunk_size, + batch, + seqlen, + nheads_ngroups_ratio, + # Strides + stride_x_batch, + stride_x_seqlen, + stride_x_head, + stride_x_hdim, + stride_b_batch, + stride_b_seqlen, + stride_b_head, + stride_b_dstate, + stride_states_batch, + stride_states_chunk, + stride_states_head, + stride_states_hdim, + stride_states_dstate, + stride_dt_batch, + stride_dt_chunk, + stride_dt_head, + stride_dt_csize, + stride_dA_cs_batch, + stride_dA_cs_chunk, + stride_dA_cs_head, + stride_dA_cs_csize, + stride_seq_idx_batch, + stride_seq_idx_seqlen, + # Meta-parameters + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + pid_bc = tl.program_id(axis=1).to(tl.int64) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + ( + pid_h // nheads_ngroups_ratio) * stride_b_head + x_ptr += pid_b * 
stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim + + offs_k[None, :] * stride_x_seqlen) + b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate + + offs_k[:, None] * stride_b_seqlen) + dt_ptrs = dt_ptr + offs_k * stride_dt_csize + dA_cs_last = tl.load(dA_cumsum_ptr + + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32) + dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize + if HAS_SEQ_IDX: + seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + if HAS_SEQ_IDX: + seq_idx_last = tl.load(seq_idx_ptr + + (chunk_size_limit - 1) * stride_seq_idx_seqlen) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, chunk_size_limit, BLOCK_SIZE_K): + x = tl.load(x_ptrs, + mask=(offs_m[:, None] < hdim) & + (offs_k[None, :] < chunk_size_limit - k), + other=0.0) + b = tl.load(b_ptrs, + mask=(offs_k[:, None] < chunk_size_limit - k) & + (offs_n[None, :] < dstate), + other=0.0).to(tl.float32) + dA_cs_k = tl.load(dA_cumsum_ptrs, + mask=offs_k < chunk_size_limit - k, + other=0.0).to(tl.float32) + if HAS_SEQ_IDX: + seq_idx_k = tl.load(seq_idx_ptrs, + mask=offs_k < chunk_size_limit - k, + other=-1) + dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, + other=0.0).to(tl.float32) + if not HAS_SEQ_IDX: + scale = tl.exp(dA_cs_last - dA_cs_k) * dt_k + else: + scale = tl.where(seq_idx_k == seq_idx_last, + tl.exp(dA_cs_last - dA_cs_k) * 
dt_k, 0.0) + b *= scale[:, None] + b = b.to(x_ptr.dtype.element_ty) + acc += tl.dot(x, b) + x_ptrs += BLOCK_SIZE_K * stride_x_seqlen + b_ptrs += BLOCK_SIZE_K * stride_b_seqlen + dt_ptrs += BLOCK_SIZE_K * stride_dt_csize + dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize + if HAS_SEQ_IDX: + seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen + states = acc.to(states_ptr.dtype.element_ty) + + states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim + + offs_n[None, :] * stride_states_dstate) + c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate) + tl.store(states_ptrs, states, mask=c_mask) + + +@triton.autotune( + configs=[ + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 64 + }, + num_stages=3, + num_warps=8), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 32, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=2), + ], + 
key=['hdim', 'dstate', 'chunk_size'], +) +@triton.jit +def _chunk_state_varlen_kernel( + # Pointers to matrices + x_ptr, + b_ptr, + dt_ptr, + dA_cumsum_ptr, + chunk_states_ptr, + cu_seqlens_ptr, + states_ptr, + initstates_ptr, + # Matrix dimensions + hdim, + dstate, + chunk_size, + seqlen, + nheads_ngroups_ratio, + # Strides + stride_x_seqlen, + stride_x_head, + stride_x_hdim, + stride_b_seqlen, + stride_b_head, + stride_b_dstate, + stride_dt_chunk, + stride_dt_head, + stride_dt_csize, + stride_dA_cs_chunk, + stride_dA_cs_head, + stride_dA_cs_csize, + stride_chunk_states_chunk, + stride_chunk_states_head, + stride_chunk_states_hdim, + stride_chunk_states_dstate, + stride_states_batch, + stride_states_head, + stride_states_hdim, + stride_states_dstate, + stride_init_states_batch, + stride_init_states_head, + stride_init_states_hdim, + stride_init_states_dstate, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + HAS_INITSTATES: tl.constexpr, +): + pid_b = tl.program_id(axis=1) + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + end_idx = tl.load(cu_seqlens_ptr + pid_b + 1) + pid_c = (end_idx - 1) // chunk_size + b_ptr += pid_c * chunk_size * stride_b_seqlen + ( + pid_h // nheads_ngroups_ratio) * stride_b_head + x_ptr += pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head + dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + chunk_states_ptr += pid_c * stride_chunk_states_chunk + pid_h * stride_chunk_states_head + + if HAS_INITSTATES: + # if there are init states provided, we differentiate between states (which + # are boundary conditions at a chunk boundary) and initstates (which are boundary + # conditions when a new example in a cont batch starts) + initstates_ptr += pid_h * stride_init_states_head + + 
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim + + offs_k[None, :] * stride_x_seqlen) + b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate + + offs_k[:, None] * stride_b_seqlen) + dt_ptrs = dt_ptr + offs_k * stride_dt_csize + dA_cs_last = tl.load(dA_cumsum_ptr + (end_idx - pid_c * chunk_size - 1) * + stride_dA_cs_csize).to(tl.float32) + dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize + + chunk_size_limit = end_idx - pid_c * chunk_size + start_idx = tl.load(cu_seqlens_ptr + pid_b) + start_idx_cur = tl.maximum(start_idx - pid_c * chunk_size, 0) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, chunk_size_limit, BLOCK_SIZE_K): + x = tl.load(x_ptrs, + mask=(offs_m[:, None] < hdim) & + (offs_k[None, :] < chunk_size_limit - k) & + (offs_k[None, :] >= start_idx_cur - k), + other=0.0) + b = tl.load(b_ptrs, + mask=(offs_k[:, None] < chunk_size_limit - k) & + (offs_n[None, :] < dstate) & + (offs_k[:, None] >= start_idx_cur - k), + other=0.0).to(tl.float32) + dA_cs_k = tl.load(dA_cumsum_ptrs, + mask=offs_k < chunk_size_limit - k, + other=0.0).to(tl.float32) + dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, + other=0.0).to(tl.float32) + scale = tl.where( + (offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k), + tl.exp(dA_cs_last - dA_cs_k) * dt_k, 0.0) + b *= scale[:, None] + b = b.to(x_ptr.dtype.element_ty) + acc += tl.dot(x, b) + x_ptrs += BLOCK_SIZE_K * stride_x_seqlen + b_ptrs += BLOCK_SIZE_K * stride_b_seqlen + dt_ptrs += BLOCK_SIZE_K * stride_dt_csize + dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize + + # If the sequence starts after the last chunk idx, we don't need to add the contribution from the last chunk + # If HAS_INITSTATES==True need to consider two possiblties + # - if start_idx < pid_c * chunk_size, then we need to take 
the past_states_ptrs + # - if state_idx >= pid * chunk_size, then we need to insert initstates + if ((start_idx < pid_c * chunk_size) # first chunk + or (HAS_INITSTATES)): + + dA_cs_boundary = 0.0 # default + + if not HAS_INITSTATES: + past_states_ptrs = chunk_states_ptr + ( + offs_m[:, None] * stride_chunk_states_hdim + + offs_n[None, :] * stride_chunk_states_dstate) + else: + + # - this seems repetitve, buts its to help the compiler + if start_idx < pid_c * chunk_size: + past_states_ptrs = chunk_states_ptr + ( + offs_m[:, None] * stride_chunk_states_hdim + + offs_n[None, :] * stride_chunk_states_dstate) + else: + past_states_ptrs = initstates_ptr + ( + pid_b * stride_init_states_batch + + offs_m[:, None] * stride_init_states_hdim + + offs_n[None, :] * stride_init_states_dstate) + + # need to adjust the boundary + if start_idx > pid_c * chunk_size: + dA_cs_boundary = tl.load(dA_cumsum_ptr + + (start_idx - pid_c * chunk_size - + 1) * stride_dA_cs_csize).to( + tl.float32) + + past_states = tl.load(past_states_ptrs, + mask=(offs_m[:, None] < hdim) & + (offs_n[None, :] < dstate), + other=0.0).to(tl.float32) + + scale = tl.exp(dA_cs_last - dA_cs_boundary) + acc += past_states * scale + + states = acc.to(states_ptr.dtype.element_ty) + + states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim + + offs_n[None, :] * stride_states_dstate) + c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate) + tl.store(states_ptrs, states, mask=c_mask) + + +def _chunk_cumsum_fwd(dt, + A, + chunk_size, + dt_bias=None, + dt_softplus=False, + dt_limit=(0.0, float("inf"))): + batch, seqlen, nheads = dt.shape + assert A.shape == (nheads, ) + if dt_bias is not None: + assert dt_bias.shape == (nheads, ) + nchunks = math.ceil(seqlen / chunk_size) + dt_out = torch.empty(batch, + nheads, + 
nchunks, + chunk_size, + device=dt.device, + dtype=torch.float32) + dA_cumsum = torch.empty(batch, + nheads, + nchunks, + chunk_size, + device=dt.device, + dtype=torch.float32) + grid_chunk_cs = lambda META: (batch, nchunks, + triton.cdiv(nheads, META['BLOCK_SIZE_H'])) + with torch.cuda.device(dt.device.index): + _chunk_cumsum_fwd_kernel[grid_chunk_cs]( + dt, + A, + dt_bias, + dt_out, + dA_cumsum, + batch, + seqlen, + nheads, + chunk_size, + dt_limit[0], + dt_limit[1], + dt.stride(0), + dt.stride(1), + dt.stride(2), + A.stride(0), + dt_bias.stride(0) if dt_bias is not None else 0, + dt_out.stride(0), + dt_out.stride(2), + dt_out.stride(1), + dt_out.stride(3), + dA_cumsum.stride(0), + dA_cumsum.stride(2), + dA_cumsum.stride(1), + dA_cumsum.stride(3), + dt_softplus, + HAS_DT_BIAS=dt_bias is not None, + BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size), + ) + return dA_cumsum, dt_out + + +def _chunk_state_fwd(B, + x, + dt, + dA_cumsum, + seq_idx=None, + states=None, + states_in_fp32=True): + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + _, _, ngroups, dstate = B.shape + assert nheads % ngroups == 0 + assert B.shape == (batch, seqlen, ngroups, dstate) + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dA_cumsum.shape == dt.shape + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + if states is not None: + assert states.shape == (batch, nchunks, nheads, headdim, dstate) + else: + states_dtype = torch.float32 if states_in_fp32 else B.dtype + states = torch.empty((batch, nchunks, nheads, headdim, dstate), + device=x.device, + dtype=states_dtype) + grid = lambda META: ( + triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv( + dstate, META['BLOCK_SIZE_N']), batch * nchunks, nheads) + with torch.cuda.device(x.device.index): + _chunk_state_fwd_kernel[grid]( + x, + B, + states, + dt, + dA_cumsum, + seq_idx, + headdim, + dstate, + chunk_size, + batch, + seqlen, + nheads // ngroups, + x.stride(0), + 
x.stride(1), + x.stride(2), + x.stride(3), + B.stride(0), + B.stride(1), + B.stride(2), + B.stride(-1), + states.stride(0), + states.stride(1), + states.stride(2), + states.stride(3), + states.stride(4), + dt.stride(0), + dt.stride(2), + dt.stride(1), + dt.stride(3), + dA_cumsum.stride(0), + dA_cumsum.stride(2), + dA_cumsum.stride(1), + dA_cumsum.stride(3), + *((seq_idx.stride(0), + seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + HAS_SEQ_IDX=seq_idx is not None, + ) + return states + + +def chunk_state_varlen(B, + x, + dt, + dA_cumsum, + cu_seqlens, + chunk_states, + initial_states=None): + total_seqlen, nheads, headdim = x.shape + _, nchunks, chunk_size = dt.shape + _, ngroups, dstate = B.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + assert nheads % ngroups == 0 + assert B.shape == (total_seqlen, ngroups, dstate) + assert dt.shape == (nheads, nchunks, chunk_size) + assert dA_cumsum.shape == dt.shape + assert chunk_states.shape == (nchunks, nheads, headdim, dstate) + + if initial_states is not None: + assert initial_states.shape == (batch, nheads, headdim, dstate) + + states = torch.empty(batch, + nheads, + headdim, + dstate, + dtype=chunk_states.dtype, + device=chunk_states.device) + grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton. 
+ cdiv(dstate, META['BLOCK_SIZE_N']), batch, nheads) + with torch.cuda.device(x.device.index): + _chunk_state_varlen_kernel[grid]( + x, + B, + dt, + dA_cumsum, + chunk_states, + cu_seqlens, + states, + initial_states, + headdim, + dstate, + chunk_size, + total_seqlen, + nheads // ngroups, + x.stride(0), + x.stride(1), + x.stride(2), + B.stride(0), + B.stride(1), + B.stride(2), + dt.stride(1), + dt.stride(0), + dt.stride(2), + dA_cumsum.stride(1), + dA_cumsum.stride(0), + dA_cumsum.stride(2), + chunk_states.stride(0), + chunk_states.stride(1), + chunk_states.stride(2), + chunk_states.stride(3), + states.stride(0), + states.stride(1), + states.stride(2), + states.stride(3), + *((initial_states.stride(0), initial_states.stride(1), + initial_states.stride(2), + initial_states.stride(3)) if initial_states is not None else + (0, 0, 0, 0)), + HAS_INITSTATES=initial_states is not None) + return states diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py new file mode 100644 index 000000000..97cdb70b6 --- /dev/null +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -0,0 +1,223 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Copyright (c) 2024, Tri Dao, Albert Gu. 
+# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_combined.py
+
+# ruff: noqa: E501
+
+import torch
+import triton
+from einops import rearrange
+from packaging import version
+
+from .ssd_bmm import _bmm_chunk_fwd
+from .ssd_chunk_scan import _chunk_scan_fwd
+from .ssd_chunk_state import (_chunk_cumsum_fwd, _chunk_state_fwd,
+                              chunk_state_varlen)
+from .ssd_state_passing import _state_passing_fwd
+
+TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0')
+
+
+def _mamba_chunk_scan_combined_fwd(x,
+                                   dt,
+                                   A,
+                                   B,
+                                   C,
+                                   chunk_size,
+                                   D=None,
+                                   z=None,
+                                   dt_bias=None,
+                                   initial_states=None,
+                                   seq_idx=None,
+                                   cu_seqlens=None,
+                                   dt_softplus=False,
+                                   dt_limit=(0.0, float("inf"))):
+    batch, seqlen, nheads, headdim = x.shape
+    _, _, ngroups, dstate = B.shape
+    assert nheads % ngroups == 0
+    assert B.shape == (batch, seqlen, ngroups, dstate)
+    assert x.shape == (batch, seqlen, nheads, headdim)
+    assert dt.shape == (batch, seqlen, nheads)
+    assert A.shape == (nheads, )
+    assert C.shape == B.shape
+    if z is not None:
+        assert z.shape == x.shape
+    if D is not None:
+        assert D.shape == (nheads, headdim) or D.shape == (nheads, )
+    if seq_idx is not None:
+        assert seq_idx.shape == (batch, seqlen)
+    if B.stride(-1) != 1:
+        B = B.contiguous()
+    if C.stride(-1) != 1:
+        C = C.contiguous()
+    if x.stride(-1) != 1 and x.stride(
+            1) != 1:  # Either M or K dimension should be contiguous
+        x = x.contiguous()
+    if z is not None and z.stride(-1) != 1 and z.stride(
+            1) != 1:  # Either M or K dimension should be contiguous
+        z = z.contiguous()
+    if D is not None and D.stride(-1) != 1:
+        D = D.contiguous()
+    if initial_states is not None:
+        if cu_seqlens is None:
+            assert initial_states.shape == (batch, nheads, headdim, dstate)
+        else:
+            assert initial_states.shape == (len(cu_seqlens) - 1, nheads,
+                                            headdim, dstate)
+
+    # This function executes 5 sub-functions for computing mamba
+    # - a good resource is the blog https://goombalab.github.io/blog/2024/mamba2-part3-algorithm/
+    #   which has a minimal implementation to understand the below operations
+    # - as explained by the blog, mamba is a special case of causal attention
+    # - the idea is to chunk the attention matrix and compute each
+    #   submatrix separately using different optimizations.
+    # - see the blog and paper for a visualization of the submatrices
+    #   which we refer to in the comments below
+
+    # 1. Compute chunked cumsum of A * dt
+    # - here dt may go through a softplus activation
+    dA_cumsum, dt = _chunk_cumsum_fwd(dt,
+                                      A,
+                                      chunk_size,
+                                      dt_bias=dt_bias,
+                                      dt_softplus=dt_softplus,
+                                      dt_limit=dt_limit)
+
+    # 2. Compute the state for each intra-chunk
+    # (right term of low-rank factorization of off-diagonal blocks; B terms)
+    states = _chunk_state_fwd(B,
+                              x,
+                              dt,
+                              dA_cumsum,
+                              seq_idx=seq_idx,
+                              states_in_fp32=True)
+
+    # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
+    # (middle term of factorization of off-diag blocks; A terms)
+    # - for handling chunked prefill, this requires i) initial_states
+    #   ii) seq_idx and iii) cu_seqlens to be all specified.
+    # - When a new seq_idx is detected, we will stop passing the prev_state
+    #   and switch accordingly to the init_state corresponding to the new seq_idx.
+    # - this will ensure that states will be updated with the rightmost flushed seq_idx
+    #   of the previous chunk. This implies that the first chunk of states is either 0
+    #   or equal to init_states of the first example.
+    states, final_states = _state_passing_fwd(
+        rearrange(states, "... p n -> ... (p n)"),
+        dA_cumsum[:, :, :, -1],
+        initial_states=rearrange(initial_states, "... p n -> ... (p n)")
+        if initial_states is not None else None,
+        seq_idx=seq_idx,
+        chunk_size=chunk_size,
+        out_dtype=C.dtype,
+        is_cont_batched=cu_seqlens is not None)
+    states, final_states = (rearrange(t, "... (p n) -> ... p n", n=dstate)
+                            for t in [states, final_states])
+
+    # 4. Compute batched matrix multiply for C_j^T B_i terms
+    CB = _bmm_chunk_fwd(C,
+                        B,
+                        chunk_size,
+                        seq_idx=seq_idx,
+                        output_dtype=torch.float32)
+
+    # 5. Scan and compute the diagonal blocks, taking into
+    #    account past causal states.
+    # - if initial states are provided, then states information will be
+    #   augmented with initial_states.
+    # - to do this properly, we need to account for example changes in
+    #   the continuous batch, therefore we introduce pseudo chunks, which is
+    #   a chunk that is split up each time an example changes.
+    # - in each (pseudo) chunk, we detect if the previous (pseudo) chunk had
+    #   a seq_idx change, in which case we take states information from
+    #   init_states.
+    out, out_x = _chunk_scan_fwd(
+        CB,
+        x,
+        dt,
+        dA_cumsum,
+        C,
+        states,
+        D=D,
+        z=z,
+        seq_idx=seq_idx,
+        initial_states=initial_states,
+    )
+    if cu_seqlens is None:
+        return out, out_x, dt, dA_cumsum, states, final_states
+    else:
+        assert batch == 1, "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1"
+        varlen_states = chunk_state_varlen(
+            B.squeeze(0),
+            x.squeeze(0),
+            dt.squeeze(0),
+            dA_cumsum.squeeze(0),
+            cu_seqlens,
+            states.squeeze(0),
+            initial_states=initial_states,
+        )
+        return out, out_x, dt, dA_cumsum, states, final_states, varlen_states
+
+
+def mamba_chunk_scan_combined(x,
+                              dt,
+                              A,
+                              B,
+                              C,
+                              chunk_size,
+                              D=None,
+                              z=None,
+                              dt_bias=None,
+                              initial_states=None,
+                              seq_idx=None,
+                              cu_seqlens=None,
+                              dt_softplus=False,
+                              dt_limit=(0.0, float("inf")),
+                              return_final_states=False,
+                              return_varlen_states=False):
+    """
+    Arguments:
+        x: (batch, seqlen, nheads, headdim)
+        dt: (batch, seqlen, nheads)
+        A: (nheads,)
+        B: (batch, seqlen, ngroups, dstate)
+        C: (batch, seqlen, ngroups, dstate)
+        chunk_size: int
+        D: (nheads, headdim) or (nheads,)
+        z: (batch, seqlen, nheads, headdim)
+        dt_bias: (nheads,)
+        initial_states: (batch, nheads, headdim, dstate), or (num_sequences, nheads, headdim, dstate) when cu_seqlens is used
+        seq_idx: (batch, seqlen)
+        cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True
+        dt_softplus: Whether to apply softplus to dt
+    Return:
+        out: (batch, seqlen, nheads, headdim); returned in a tuple with final_states and/or varlen_states when the matching return_* flag is set
+    """
+
+    if not return_varlen_states:
+        cu_seqlens = None
+    else:
+        assert cu_seqlens is not None, "cu_seqlens must be provided if return_varlen_states is True"
+    out, out_x, dt_out, dA_cumsum, states, final_states, *rest = _mamba_chunk_scan_combined_fwd(
+        x,
+        dt,
+        A,
+        B,
+        C,
+        chunk_size,
+        D=D,
+        z=z,
+        dt_bias=dt_bias,
+        initial_states=initial_states,
+        seq_idx=seq_idx,
+        cu_seqlens=cu_seqlens,
+        dt_softplus=dt_softplus,
+        dt_limit=dt_limit)
+    if not return_varlen_states:
+        return out if not return_final_states else (out, final_states)
+    else:
+        varlen_states = rest[0]
+        return (out,
+                varlen_states) if not return_final_states else (out,
+                                                                final_states,
+                                                                varlen_states)
diff --git a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py
new file mode 100644
index 000000000..d8f87c113
--- /dev/null
+++ b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py
@@ -0,0 +1,207 @@
+# SPDX-License-Identifier: Apache-2.0
+
+# Copyright (c) 2024, Tri Dao, Albert Gu.
+# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_state_passing.py
+
+# ruff: noqa: E501
+
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({'BLOCK_SIZE': 64}),
+        triton.Config({'BLOCK_SIZE': 128}),
+        triton.Config({'BLOCK_SIZE': 256}),
+        triton.Config({'BLOCK_SIZE': 512}),
+        triton.Config({'BLOCK_SIZE': 1024}),
+        triton.Config({'BLOCK_SIZE': 2048}),
+    ],
+    key=['dim'],
+)
+@triton.jit
+def _state_passing_fwd_kernel(
+    # Pointers to matrices
+    states_ptr,
+    out_ptr,
+    final_states_ptr,
+    dA_cs_ptr,
+    initstates_ptr,
+    seq_idx_ptr,
+    # Matrix dimensions
+    dim,
+    nchunks,
+    seqlen,
+    chunk_size,
+    # Strides
+    stride_states_batch,
+    stride_states_chunk,
+    stride_states_head,
+    stride_states_dim,
+    stride_out_batch,
+    stride_out_chunk,
+    stride_out_head,
+    stride_out_dim,
+    stride_final_states_batch,
+    stride_final_states_head,
+    stride_final_states_dim,
+    stride_dA_cs_batch,
+    stride_dA_cs_chunk,
+    stride_dA_cs_head,
+    stride_initstates_batch,
+    stride_initstates_head,
+    stride_initstates_dim,
+    stride_seq_idx_batch,
+    stride_seq_idx_seqlen,
+    # Meta-parameters
+    HAS_INITSTATES: tl.constexpr,
+    HAS_SEQ_IDX: tl.constexpr,
+    IS_CONT_BATCHED: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid_b = tl.program_id(axis=1)
+    pid_h = tl.program_id(axis=2)
+    pid_m = tl.program_id(axis=0)
+    states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
+    dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head
+    out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
+    final_states_ptr += pid_b * stride_final_states_batch + pid_h * stride_final_states_head
+    if HAS_INITSTATES:
+        initstates_ptr += pid_h * stride_initstates_head
+        if not IS_CONT_BATCHED:
+            initstates_ptr += pid_b * stride_initstates_batch
+
+    if HAS_SEQ_IDX:
+        seq_idx_ptr += pid_b * stride_seq_idx_batch
+
+    offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    states_ptrs = states_ptr + offs_m * stride_states_dim
+    out_ptrs = out_ptr + offs_m * stride_out_dim
+    final_states_ptrs = final_states_ptr + offs_m * stride_final_states_dim
+
+    # - states will be the past state of the sequence that continues on the current chunk
+    if not HAS_INITSTATES:
+        states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)
+    else:
+        initstates_ptr += offs_m * stride_initstates_dim
+        initstates_ptrs = initstates_ptr
+        # - for cont batches, the first chunk starts from the first sequence's
+        #   init state (no batch offset was applied to initstates_ptr above)
+        states = tl.load(initstates_ptrs, mask=offs_m < dim,
+                         other=0.0).to(tl.float32)
+
+    tl.store(out_ptrs, states, mask=offs_m < dim)
+    out_ptrs += stride_out_chunk
+    seq_idx = 0
+    for c in range(nchunks):
+        new_states = tl.load(states_ptrs, mask=offs_m < dim,
+                             other=0.0).to(tl.float32)
+        dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
+        scale = tl.exp(dA_cs)
+        if HAS_SEQ_IDX:
+            # - the seq to pass forward is the one that is flushed to the right
+            #   boundary.
+            # - that is given by seq_idx_new below.
+            seq_idx_new = tl.load(seq_idx_ptr +
+                                  (min((c + 1) * chunk_size, seqlen) - 1) *
+                                  stride_seq_idx_seqlen)
+            if HAS_INITSTATES:
+                if IS_CONT_BATCHED and seq_idx != seq_idx_new:
+                    # this means in the current chunk the rightmost flushed seq
+                    # has changed.
+                    # - so we do not propagate the state from previous chunk
+                    # - but rather we load that sequence's init state
+                    initstates_ptrs = initstates_ptr + seq_idx_new * stride_initstates_batch
+
+                    # - update state with seq_idx_new's init state
+                    states = tl.load(initstates_ptrs,
+                                     mask=offs_m < dim,
+                                     other=0.0).to(tl.float32)
+            else:
+                scale = tl.where(seq_idx_new == seq_idx, scale, 0.0)
+
+            seq_idx = seq_idx_new
+        states = scale * states + new_states
+        if c < nchunks - 1:
+            tl.store(out_ptrs, states, mask=offs_m < dim)
+        else:
+            tl.store(final_states_ptrs, states, mask=offs_m < dim)
+        states_ptrs += stride_states_chunk
+        dA_cs_ptr += stride_dA_cs_chunk
+        out_ptrs += stride_out_chunk
+
+
+def _state_passing_fwd(
+    states,
+    dA_chunk_cumsum,
+    initial_states=None,
+    seq_idx=None,
+    chunk_size=None,
+    out_dtype=None,
+    is_cont_batched=False,
+):
+    batch, nchunks, nheads, dim = states.shape
+    assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)
+    if initial_states is not None:
+        if is_cont_batched:
+            # - if cu_seqlens is provided, then the initial states
+            #   are used for continuous batching. In which case we
+            #   require seq_idx to be provided
+            assert seq_idx is not None, "seq_idx must be provided when initial_states are used with continuous batching"
+            assert initial_states.shape == (seq_idx.max().item() + 1, nheads,
+                                            dim)
+        else:
+            # - this is the regular batching case, where initial
+            #   states are used for each example of the batch.
+            assert initial_states.shape == (batch, nheads, dim)
+
+    if seq_idx is not None:
+        assert chunk_size is not None
+        seqlen = seq_idx.shape[-1]
+        assert seq_idx.shape == (batch, seqlen)
+    out_dtype = states.dtype if out_dtype is None else out_dtype
+    out = torch.empty((batch, nchunks, nheads, dim),
+                      device=states.device,
+                      dtype=out_dtype)
+    final_states = torch.empty((batch, nheads, dim),
+                               device=states.device,
+                               dtype=torch.float32)
+    grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads)
+    with torch.cuda.device(states.device.index):
+        _state_passing_fwd_kernel[grid](
+            states,
+            out,
+            final_states,
+            dA_chunk_cumsum,
+            initial_states,
+            seq_idx,
+            dim,
+            nchunks,
+            seqlen if seq_idx is not None else 0,
+            chunk_size if seq_idx is not None else 0,
+            states.stride(0),
+            states.stride(1),
+            states.stride(2),
+            states.stride(3),
+            out.stride(0),
+            out.stride(1),
+            out.stride(2),
+            out.stride(3),
+            final_states.stride(0),
+            final_states.stride(1),
+            final_states.stride(2),
+            dA_chunk_cumsum.stride(0),
+            dA_chunk_cumsum.stride(2),
+            dA_chunk_cumsum.stride(1),
+            *((initial_states.stride(0), initial_states.stride(1),
+               initial_states.stride(2)) if initial_states is not None else
+              (0, 0, 0)),
+            *((seq_idx.stride(0),
+               seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
+            HAS_INITSTATES=initial_states is not None,
+            HAS_SEQ_IDX=seq_idx is not None,
+            IS_CONT_BATCHED=is_cont_batched,
+        )
+    return out, final_states
diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py
new file mode 100644
index 000000000..72b74e31b
--- /dev/null
+++ b/vllm/model_executor/models/bamba.py
@@ -0,0 +1,592 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Inference-only Bamba model."""
+# Added by the IBM Team, 2024
+from typing import Iterable, List, Optional, Set, Tuple
+
+import torch
+from torch import nn
+from transformers import BambaConfig
+
+from vllm.attention.backends.abstract import AttentionMetadata
+from vllm.attention.layer
import Attention
+from vllm.config import CacheConfig, VllmConfig
+from vllm.distributed import divide, get_tensor_model_parallel_world_size
+from vllm.distributed.parallel_state import get_pp_group
+from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
+                                               QKVParallelLinear,
+                                               RowParallelLinear)
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.mamba.mamba_mixer2 import (
+    MambaMixer2, extra_groups_for_head_shards)
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
+                                                    MambaCacheParams)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.sequence import IntermediateTensors
+from vllm.utils import LayerBlockType
+
+from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP
+from .utils import (is_pp_missing_parameter,
+                    make_empty_intermediate_tensors_factory, make_layers,
+                    maybe_prefix)
+
+KVCache = Tuple[torch.Tensor, torch.Tensor]
+
+
+class BambaMLP(nn.Module):
+
+    def __init__(
+        self,
+        config: BambaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        bias: bool = False,
+    ) -> None:
+        super().__init__()
+        self.gate_up_proj = MergedColumnParallelLinear(
+            input_size=config.hidden_size,
+            output_sizes=[config.intermediate_size] * 2,
+            bias=bias,
+            quant_config=quant_config,
+        )
+        self.down_proj = RowParallelLinear(
+            input_size=config.intermediate_size,
+            output_size=config.hidden_size,
+            bias=bias,
+            quant_config=quant_config,
+        )
+        if config.hidden_act != "silu":
+            raise ValueError(f"Unsupported activation: {config.hidden_act}. "
+                             "Only silu is supported for now.")
+        self.act_fn = SiluAndMul()
+
+    def forward(self, x):
+        x, _ = self.gate_up_proj(x)
+        x = self.act_fn(x)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class BambaMixerDecoderLayer(nn.Module):
+
+    def __init__(self,
+                 config: BambaConfig,
+                 layer_idx: int,
+                 cache_config: Optional[CacheConfig] = None,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = "") -> None:
+        super().__init__()
+        self.config = config
+        self.mamba = MambaMixer2(hidden_size=config.hidden_size,
+                                 ssm_state_size=config.mamba_d_state,
+                                 conv_kernel_size=config.mamba_d_conv,
+                                 intermediate_size=config.mamba_expand *
+                                 config.hidden_size,
+                                 use_conv_bias=config.mamba_conv_bias,
+                                 use_bias=config.mamba_proj_bias,
+                                 n_groups=config.mamba_n_groups,
+                                 num_heads=config.mamba_n_heads,
+                                 head_dim=config.mamba_d_head,
+                                 rms_norm_eps=config.rms_norm_eps,
+                                 activation=config.hidden_act,
+                                 chunk_size=config.mamba_chunk_size,
+                                 quant_config=quant_config)
+
+        self.feed_forward = BambaMLP(config, quant_config=quant_config)
+        self.input_layernorm = RMSNorm(config.hidden_size,
+                                       eps=config.rms_norm_eps)
+        self.pre_ff_layernorm = RMSNorm(config.hidden_size,
+                                        eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+        residual: Optional[torch.Tensor],
+        mamba_cache_params: MambaCacheParams,
+        sequence_idx: Optional[torch.Tensor] = None,
+        **kwargs,
+    ):
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(
+                hidden_states, residual)
+
+        hidden_states = self.mamba(hidden_states, attn_metadata,
+                                   mamba_cache_params, sequence_idx)
+        # Fully Connected
+        hidden_states, residual = self.pre_ff_layernorm(
+            hidden_states, residual)
+        hidden_states = self.feed_forward(hidden_states)
+        return hidden_states, residual
+
+
+class BambaAttentionDecoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: BambaConfig,
+        layer_idx: int,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        max_position_embeddings = getattr(config, "max_position_embeddings",
+                                          8192)
+        self.hidden_size = config.hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = config.num_attention_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = config.num_key_value_heads
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        self.head_dim = config.hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+
+        if hasattr(config, "partial_rotary_factor"):
+            rotary_dim = int(self.head_dim * config.partial_rotary_factor)
+        elif hasattr(config, "attn_rotary_emb"):
+            rotary_dim = config.attn_rotary_emb  # for backward compatibility
+        else:
+            rotary_dim = self.head_dim  # default
+
+        self.rotary_emb = get_rope(
+            head_size=self.head_dim,
+            rotary_dim=rotary_dim,
+            max_position=max_position_embeddings,
+            rope_scaling=rope_scaling,
+            base=rope_theta,
+            is_neox_style=True,
+            dtype=torch.get_default_dtype(),  # see impl of get_rope
+        )
+
+        self.qkv_proj = QKVParallelLinear(
+            config.hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=False,
+            quant_config=quant_config,
+        )
+        self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
+                                        config.hidden_size,
+                                        bias=False,
+                                        quant_config=quant_config)
+
+        self.attn = Attention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            cache_config=cache_config,
+            prefix=f"{prefix}.attn",
+        )
+
+        self.feed_forward = BambaMLP(config, quant_config=quant_config)
+        self.input_layernorm = RMSNorm(config.hidden_size,
+                                       eps=config.rms_norm_eps)
+        self.pre_ff_layernorm = RMSNorm(config.hidden_size,
+                                        eps=config.rms_norm_eps)
+
+    def self_attention(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+        **kwargs,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+        residual: Optional[torch.Tensor],
+        **kwargs,
+    ):
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(
+                hidden_states, residual)
+
+        hidden_states = self.self_attention(
+            positions=positions,
+            hidden_states=hidden_states,
+            kv_cache=kv_cache,
+            attn_metadata=attn_metadata,
+        )
+        # Fully Connected
+        hidden_states, residual = self.pre_ff_layernorm(
+            hidden_states, residual)
+        hidden_states = self.feed_forward(hidden_states)
+        return hidden_states, residual
+
+
+ALL_DECODER_LAYER_TYPES = {
+    "attention": BambaAttentionDecoderLayer,
+    "mamba": BambaMixerDecoderLayer
+}
+
+
+class BambaModel(nn.Module):
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+
+        config = vllm_config.model_config.hf_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
+        lora_config = vllm_config.lora_config
+
+        self.config = config
+        lora_vocab = ((lora_config.lora_extra_vocab_size *
+                       (lora_config.max_loras or 1)) if lora_config else 0)
+        self.vocab_size = config.vocab_size + lora_vocab
+        self.org_vocab_size = config.vocab_size
+
+        self.embed_tokens = VocabParallelEmbedding(
+            self.vocab_size,
+            config.hidden_size,
+            org_num_embeddings=config.vocab_size,
+        )
+
+        def get_layer(prefix: str):
+            layer_idx = int(prefix.rsplit(".", 1)[1])
+            layer_class = ALL_DECODER_LAYER_TYPES[
+                config.layers_block_type[layer_idx]]
+            return layer_class(
+                config,
+                layer_idx,
+                cache_config,
+                quant_config=quant_config,
+                prefix=prefix,
+            )
+
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers")
+        self.make_empty_intermediate_tensors = (
+            make_empty_intermediate_tensors_factory(
+                ["hidden_states", "residual"], config.hidden_size))
+
+        self.final_layernorm = RMSNorm(config.hidden_size,
+                                       eps=config.rms_norm_eps)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        mamba_cache_params: MambaCacheParams,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+
+        # pass a sequence index tensor, that is required for
+        # proper continuous batching computation including
+        # chunked prefill
+        seq_idx = None
+        if attn_metadata.num_prefills > 0:
+            seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
+            for i, (srt, end) in enumerate(
+                    zip(
+                        attn_metadata.query_start_loc,
+                        attn_metadata.query_start_loc[1:],
+                    )):
+                seq_idx[srt:end] = i
+            seq_idx.unsqueeze_(0)
+
+        if get_pp_group().is_first_rank:
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.get_input_embeddings(input_ids)
+            residual = None
+        else:
+            assert intermediate_tensors is not None
+            hidden_states = intermediate_tensors["hidden_states"]
+            residual = intermediate_tensors["residual"]
+
+        # keep `residual` as set above (later PP ranks read it from upstream)
+        num_attn = 0
+        for i in range(len(self.layers)):
+            layer = self.layers[i]
+            kv_cache = None
+            if isinstance(layer, BambaAttentionDecoderLayer):
+                kv_cache = kv_caches[num_attn]
+                num_attn += 1
+
+            layer_mamba_cache_params = None
+            if isinstance(layer, BambaMixerDecoderLayer):
+                layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
+                    i - num_attn)
+
+            hidden_states, residual = layer(
+                positions=positions,
+                hidden_states=hidden_states,
+                kv_cache=kv_cache,
+                attn_metadata=attn_metadata,
+                residual=residual,
+                mamba_cache_params=layer_mamba_cache_params,
+                sequence_idx=seq_idx,
+            )
+
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({
+                "hidden_states": hidden_states,
+                "residual": residual
+            })
+        hidden_states, _ = self.final_layernorm(hidden_states, residual)
+        return hidden_states
+
+
+class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
+                       IsHybrid):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": ["gate_proj", "up_proj"]
+    }
+
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "qkv_proj",
+        "o_proj",
+        "embed_tokens",
+        "lm_head",
+    ]
+    embedding_modules = {
+        "embed_tokens": "input_embeddings",
+        "lm_head": "output_embeddings",
+    }
+    embedding_padding_modules = ["lm_head"]
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        config = vllm_config.model_config.hf_config
+        self.vllm_config = vllm_config
+        self.model_config = vllm_config.model_config
+        cache_config = vllm_config.cache_config
+        lora_config = vllm_config.lora_config
+        scheduler_config = vllm_config.scheduler_config
+        assert not cache_config.enable_prefix_caching, \
+            "Bamba currently does not support prefix caching"
+
+        self.quant_config = vllm_config.quant_config
+
+        super().__init__()
+        self.config = config
+        self.scheduler_config = scheduler_config
+        self.model = BambaModel(vllm_config=vllm_config,
+                                prefix=maybe_prefix(prefix, "model"))
+        self.unpadded_vocab_size = config.vocab_size
+        if lora_config:
+            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
+        self.lm_head = ParallelLMHead(
+            self.unpadded_vocab_size,
+            config.hidden_size,
+            org_num_embeddings=config.vocab_size,
+            padding_size=DEFAULT_VOCAB_PADDING_SIZE
+            # We need bigger padding if using lora for kernel
+            # compatibility
+            if not lora_config else lora_config.lora_vocab_padding_size,
+        )
+        # Mamba state cache; created lazily on the first forward() call.
+        self.mamba_cache: Optional[MambaCacheManager] = None
+
+        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+                                                config.vocab_size)
+        self.sampler = get_sampler()
+
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors)
+
+        # follow jamba
+        if self.scheduler_config is not None and \
+            not self.model_config.enforce_eager:
+            # for compilation
+            if self.scheduler_config.max_num_seqs > \
+                vllm_config.compilation_config.max_capture_size:
+                self.max_batch_size = \
+                    vllm_config.compilation_config.max_capture_size
+            else:
+                self.max_batch_size = vllm_config.pad_for_cudagraph(
+                    self.scheduler_config.max_num_seqs)
+        elif self.scheduler_config is not None:
+            # for eager just take the scheduler_config if avail
+            self.max_batch_size = self.scheduler_config.max_num_seqs
+        else:
+            self.max_batch_size = 8192 + 2
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
+    def forward(self,
+                input_ids: torch.Tensor,
+                positions: torch.Tensor,
+                kv_caches: List[KVCache],
+                attn_metadata: AttentionMetadata,
+                intermediate_tensors: Optional[IntermediateTensors] = None,
+                inputs_embeds: Optional[torch.Tensor] = None,
+                **kwargs):
+        if self.mamba_cache is None:
+
+            num_mamba_layers = self.model_config.get_num_layers_by_block_type(
+                self.vllm_config.parallel_config, LayerBlockType.mamba)
+
+            self.mamba_cache = MambaCacheManager(
+                self.lm_head.weight.dtype, num_mamba_layers,
+                self.max_batch_size, *self._get_mamba_cache_shape())
+        mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
+        hidden_states = self.model(input_ids, positions, kv_caches,
+                                   attn_metadata, mamba_cache_params,
+                                   intermediate_tensors, inputs_embeds)
+
+        return hidden_states
+
+    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
+        return self.mamba_cache.copy_inputs_before_cuda_graphs(
+            input_buffers, **kwargs)
+
+    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
+        return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
+
+    def _get_mamba_cache_shape(
+            self) -> Tuple[Tuple[int, int], Tuple[int, int]]:
+        world_size = get_tensor_model_parallel_world_size()
+        hidden_size = self.config.hidden_size
+
+        conv_state_shape, temporal_state_shape = None, None
+
+        intermediate_size = self.config.mamba_expand * hidden_size
+
+        # if n_groups is not divisible by world_size, need to extend the shards
+        # to ensure all groups needed by a head are sharded along with it
+        n_groups = (self.config.mamba_n_groups + extra_groups_for_head_shards(
+            self.config.mamba_n_groups, world_size))
+
+        # - heads and n_groups are TP-ed
+        conv_dim = (intermediate_size +
+                    2 * n_groups * self.config.mamba_d_state)
+        conv_state_shape = (
+            divide(conv_dim, world_size),
+            self.config.mamba_d_conv - 1,
+        )
+
+        # Sharded over heads only (n_heads is divided by world_size below)
+        # - they are typically small
+        #   e.g., (h_heads, d_head, d_state) = (128, 64, 128)
+        temporal_state_shape = (
+            divide(self.config.mamba_n_heads, world_size),
+            self.config.mamba_d_head,
+            self.config.mamba_d_state,
+        )
+        return conv_state_shape, temporal_state_shape
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
+        logits = self.logits_processor(self.lm_head, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def sample(
+        self,
+        logits: Optional[torch.Tensor],
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
+
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+
+        params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+
+            if "A_log" in name:
+                name = name.replace("A_log", "A")
+
+            if ".self_attn." in name:
+                name = name.replace(".self_attn", "")
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                # Skip layers on other devices.
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if is_pp_missing_parameter(name, self):
+                    continue
+
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py
index d82c08152..f307f279d 100644
--- a/vllm/model_executor/models/jamba.py
+++ b/vllm/model_executor/models/jamba.py
@@ -455,14 +455,9 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
             self.mamba_cache = MambaCacheManager(
                 self.lm_head.weight.dtype, num_mamba_layers,
                 self.max_batch_size, *self._get_mamba_cache_shape())
-        (
-            mamba_cache_tensors,
-            state_indices_tensor,
-        ) = self.mamba_cache.current_run_tensors(input_ids, attn_metadata,
-                                                 **kwargs)
-        mamba_cache_params = MambaCacheParams(mamba_cache_tensors[0],
-                                              mamba_cache_tensors[1],
-                                              state_indices_tensor)
+
+        mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
+
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    attn_metadata, mamba_cache_params,
                                    intermediate_tensors, inputs_embeds)
diff --git
a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index 5034b3345..3bbc219e9 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -232,15 +232,7 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP): self.lm_head.weight.dtype, num_mamba_layers, self.max_batch_size, *self._get_mamba_cache_shape()) - ( - mamba_cache_tensors, - state_indices_tensor, - ) = self.mamba_cache.current_run_tensors(input_ids, attn_metadata, - **kwargs) - - mamba_cache_params = MambaCacheParams(mamba_cache_tensors[0], - mamba_cache_tensors[1], - state_indices_tensor) + mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) hidden_states = self.backbone(input_ids, positions, attn_metadata, mamba_cache_params, intermediate_tensors, diff --git a/vllm/model_executor/models/mamba_cache.py b/vllm/model_executor/models/mamba_cache.py index 353177f78..ce4197507 100644 --- a/vllm/model_executor/models/mamba_cache.py +++ b/vllm/model_executor/models/mamba_cache.py @@ -5,7 +5,6 @@ from typing import Dict, List import torch -from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.utils import PAD_SLOT_ID @@ -42,8 +41,7 @@ class MambaCacheManager: self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {} self.free_cache_indices = list(range(max_batch_size)) - def current_run_tensors(self, input_ids: torch.Tensor, - attn_metadata: AttentionMetadata, **kwargs): + def current_run_tensors(self, **kwargs) -> MambaCacheParams: """ Return the tensors for the current run's conv and ssm state. 
""" @@ -66,7 +64,8 @@ class MambaCacheManager: (mamba_cache_tensors, state_indices_tensor) = kwargs["seqlen_agnostic_capture_inputs"] - return (mamba_cache_tensors, state_indices_tensor) + return MambaCacheParams(mamba_cache_tensors[0], mamba_cache_tensors[1], + state_indices_tensor) def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): """ diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 3b2a7069e..c2d0fae70 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -37,6 +37,7 @@ _TEXT_GENERATION_MODELS = { "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-13b, lower case 'c' in the class name "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), + "BambaForCausalLM": ("bamba", "BambaForCausalLM"), "BloomForCausalLM": ("bloom", "BloomForCausalLM"), # ChatGLMModel supports multimodal "CohereForCausalLM": ("commandr", "CohereForCausalLM"), -- GitLab From 741429a4cd4443001264a2c89c0150c12c2bd750 Mon Sep 17 00:00:00 2001 From: Lu Fang <30275821+houseroad@users.noreply.github.com> Date: Thu, 6 Feb 2025 15:36:21 -0800 Subject: [PATCH 012/253] [MISC] Check space in the file names in the pre commit checks (#12804) Signed-off-by: Lu Fang --- .pre-commit-config.yaml | 6 ++++++ ...e=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json} | 0 ...e=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json} | 0 ...e=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json} | 0 ...e=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json} | 0 ...e=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json} | 0 ...e=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json} | 0 ...e=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json} | 0 ...e=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json} | 0 ...e=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json} | 0 ...e=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json} | 0 
...e=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json} | 0 ...e=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json} | 0 ...e=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json} | 0 14 files changed, 6 insertions(+) rename vllm/model_executor/layers/quantization/utils/configs/{N=24576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json => N=24576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json} (100%) rename vllm/model_executor/layers/quantization/utils/configs/{N=3072,K=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json => N=3072,K=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json} (100%) rename vllm/model_executor/layers/quantization/utils/configs/{N=3072,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json => N=3072,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json} (100%) rename vllm/model_executor/layers/quantization/utils/configs/{N=32768,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json => N=32768,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json} (100%) rename vllm/model_executor/layers/quantization/utils/configs/{N=4096,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json => N=4096,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json} (100%) rename vllm/model_executor/layers/quantization/utils/configs/{N=4608,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json => N=4608,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json} (100%) rename vllm/model_executor/layers/quantization/utils/configs/{N=512,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json => N=512,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json} (100%) rename vllm/model_executor/layers/quantization/utils/configs/{N=576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 
128].json => N=576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json} (100%) rename vllm/model_executor/layers/quantization/utils/configs/{N=7168,K=16384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json => N=7168,K=16384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json} (100%) rename vllm/model_executor/layers/quantization/utils/configs/{N=7168,K=18432,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json => N=7168,K=18432,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json} (100%) rename vllm/model_executor/layers/quantization/utils/configs/{N=7168,K=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json => N=7168,K=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json} (100%) rename vllm/model_executor/layers/quantization/utils/configs/{N=7168,K=2304,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json => N=7168,K=2304,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json} (100%) rename vllm/model_executor/layers/quantization/utils/configs/{N=7168,K=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json => N=7168,K=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json} (100%) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4568efcbb..0b1c4fdf2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -108,3 +108,9 @@ repos: language: system verbose: true pass_filenames: false + - id: check-filenames + name: Check for spaces in all filenames + entry: bash -c 'git ls-files | grep " " && echo "Filenames should not contain spaces!" 
&& exit 1 || exit 0' + language: system + always_run: true + pass_filenames: false diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json similarity index 100% rename from vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json rename to vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json similarity index 100% rename from vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json rename to vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json similarity index 100% rename from vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json rename to vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json diff --git 
a/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json similarity index 100% rename from vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json rename to vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json similarity index 100% rename from vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json rename to vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json similarity index 100% rename from vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json rename to vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json 
b/vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json similarity index 100% rename from vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json rename to vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json similarity index 100% rename from vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json rename to vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json similarity index 100% rename from vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json rename to vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json similarity index 100% rename from 
vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json rename to vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json similarity index 100% rename from vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json rename to vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json similarity index 100% rename from vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json rename to vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json similarity index 100% rename from vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json rename to 
vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json -- GitLab From b26078235722a434d92fe90dcea2023a5ae7294a Mon Sep 17 00:00:00 2001 From: "Kevin H. Luu" Date: Thu, 6 Feb 2025 16:29:12 -0800 Subject: [PATCH 013/253] [misc] Revert # 12833 (#12857) Signed-off-by: <> Co-authored-by: EC2 Default User --- vllm/inputs/preprocess.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 53f89996f..035e84cc0 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -260,6 +260,9 @@ class InputPreprocessor: mm_processor = self.mm_registry.create_processor( self.model_config, tokenizer) + if isinstance(prompt, list): + prompt = tokenizer.decode(prompt) + if mm_processor_kwargs is None: mm_processor_kwargs = {} -- GitLab From ef533d25fba4b5ef8b4da9369de718c0773b9bce Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Thu, 6 Feb 2025 22:54:07 -0500 Subject: [PATCH 014/253] [Bugfix] FA2 illegal memory access (#12848) --- CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index c823c9ff8..b99061dfd 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -581,7 +581,7 @@ else() FetchContent_Declare( vllm-flash-attn GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git - GIT_TAG d4e09037abf588af1ec47d0e966b237ee376876c + GIT_TAG 720c94869cf2e0ff5a706e9c7f1dce0939686ade GIT_PROGRESS TRUE # Don't share the vllm-flash-attn build between build types BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn -- GitLab From 433c4a49230a470f13657f06e7612cde86e4fb40 Mon Sep 17 00:00:00 2001 From: ZSL98 <36250440+ZSL98@users.noreply.github.com> Date: Fri, 7 Feb 2025 11:54:20 +0800 Subject: [PATCH 015/253] Make vllm compatible with verl (#12824) Co-authored-by: zhangshulai --- vllm/distributed/parallel_state.py | 7 ------- vllm/executor/uniproc_executor.py | 2 +- 2 files changed, 
1 insertion(+), 8 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 321902d11..bfc41703b 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -1024,13 +1024,6 @@ def initialize_model_parallel( backend = backend or torch.distributed.get_backend( get_world_group().device_group) - if (world_size - != tensor_model_parallel_size * pipeline_model_parallel_size): - raise RuntimeError( - f"world_size ({world_size}) is not equal to " - f"tensor_model_parallel_size ({tensor_model_parallel_size}) x " - f"pipeline_model_parallel_size ({pipeline_model_parallel_size})") - # Build the tensor model-parallel groups. num_tensor_model_parallel_groups: int = (world_size // tensor_model_parallel_size) diff --git a/vllm/executor/uniproc_executor.py b/vllm/executor/uniproc_executor.py index dcb4a8f27..e5464cafa 100644 --- a/vllm/executor/uniproc_executor.py +++ b/vllm/executor/uniproc_executor.py @@ -101,7 +101,7 @@ class ExecutorWithExternalLauncher(UniProcExecutor): # - MASTER_PORT distributed_init_method = "env://" rank = int(os.environ["RANK"]) - local_rank = rank + local_rank = int(os.environ["LOCAL_RANK"]) is_driver_worker = True kwargs = dict( vllm_config=self.vllm_config, -- GitLab From aa375dca9fbeff03904cd7b7dcc5014bfa19b0fb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Szymon=20O=C5=BC=C3=B3g?= <58388001+SzymonOzog@users.noreply.github.com> Date: Fri, 7 Feb 2025 06:35:09 +0100 Subject: [PATCH 016/253] [Bugfix] Missing quant_config in deepseek embedding layer (#12836) --- vllm/model_executor/models/deepseek_v2.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 0c6f07ce7..fd0e58fa1 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -581,7 +581,8 @@ class DeepseekV2Model(nn.Module): self.embed_tokens = VocabParallelEmbedding( 
config.vocab_size, config.hidden_size, - ) + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens") else: self.embed_tokens = PPMissingLayer() -- GitLab From 6e1fc61f0fb90c37f0d4a1a8f76235a6e4e1103c Mon Sep 17 00:00:00 2001 From: Maximilien de Bayser Date: Fri, 7 Feb 2025 02:37:41 -0300 Subject: [PATCH 017/253] Prevent unecessary requests to huggingface hub (#12837) --- .../offline_mode/test_offline_mode.py | 21 ++++ vllm/transformers_utils/config.py | 115 ++++++++++++------ 2 files changed, 96 insertions(+), 40 deletions(-) diff --git a/tests/entrypoints/offline_mode/test_offline_mode.py b/tests/entrypoints/offline_mode/test_offline_mode.py index eac76f2ba..85156d693 100644 --- a/tests/entrypoints/offline_mode/test_offline_mode.py +++ b/tests/entrypoints/offline_mode/test_offline_mode.py @@ -4,6 +4,7 @@ import importlib import sys import pytest +import urllib3 from vllm import LLM from vllm.distributed import cleanup_dist_env_and_memory @@ -28,6 +29,15 @@ MODEL_CONFIGS = [ "tensor_parallel_size": 1, "tokenizer_mode": "mistral", }, + { + "model": "sentence-transformers/all-MiniLM-L12-v2", + "enforce_eager": True, + "gpu_memory_utilization": 0.20, + "max_model_len": 64, + "max_num_batched_tokens": 64, + "max_num_seqs": 64, + "tensor_parallel_size": 1, + }, ] @@ -47,6 +57,16 @@ def test_offline_mode(monkeypatch): # Set HF to offline mode and ensure we can still construct an LLM try: monkeypatch.setenv("HF_HUB_OFFLINE", "1") + monkeypatch.setenv("VLLM_NO_USAGE_STATS", "1") + + def disable_connect(*args, **kwargs): + raise RuntimeError("No http calls allowed") + + monkeypatch.setattr(urllib3.connection.HTTPConnection, "connect", + disable_connect) + monkeypatch.setattr(urllib3.connection.HTTPSConnection, "connect", + disable_connect) + # Need to re-import huggingface_hub and friends to setup offline mode _re_import_modules() # Cached model files should be used in offline mode @@ -56,6 +76,7 @@ def test_offline_mode(monkeypatch): # Reset the environment after the 
test # NB: Assuming tests are run in online mode monkeypatch.delenv("HF_HUB_OFFLINE") + monkeypatch.delenv("VLLM_NO_USAGE_STATS") _re_import_modules() pass diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 85056158b..fb5cc3ec0 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -10,7 +10,7 @@ import huggingface_hub from huggingface_hub import (file_exists, hf_hub_download, list_repo_files, try_to_load_from_cache) from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError, - LocalEntryNotFoundError, + HFValidationError, LocalEntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError) from torch import nn @@ -265,49 +265,66 @@ def get_config( return config +def try_get_local_file(model: Union[str, Path], + file_name: str, + revision: Optional[str] = 'main') -> Optional[Path]: + file_path = Path(model) / file_name + if file_path.is_file(): + return file_path + else: + try: + cached_filepath = try_to_load_from_cache(repo_id=model, + filename=file_name, + revision=revision) + if isinstance(cached_filepath, str): + return Path(cached_filepath) + except HFValidationError: + ... + return None + + def get_hf_file_to_dict(file_name: str, model: Union[str, Path], revision: Optional[str] = 'main'): """ - Downloads a file from the Hugging Face Hub and returns + Downloads a file from the Hugging Face Hub and returns its contents as a dictionary. Parameters: - file_name (str): The name of the file to download. - model (str): The name of the model on the Hugging Face Hub. - - revision (str): The specific version of the model. + - revision (str): The specific version of the model. Returns: - - config_dict (dict): A dictionary containing + - config_dict (dict): A dictionary containing the contents of the downloaded file. 
""" - file_path = Path(model) / file_name - if file_or_path_exists(model=model, - config_name=file_name, - revision=revision): + file_path = try_get_local_file(model=model, + file_name=file_name, + revision=revision) - if not file_path.is_file(): - try: - hf_hub_file = hf_hub_download(model, - file_name, - revision=revision) - except (RepositoryNotFoundError, RevisionNotFoundError, - EntryNotFoundError, LocalEntryNotFoundError) as e: - logger.debug("File or repository not found in hf_hub_download", - e) - return None - except HfHubHTTPError as e: - logger.warning( - "Cannot connect to Hugging Face Hub. Skipping file " - "download for '%s':", - file_name, - exc_info=e) - return None - file_path = Path(hf_hub_file) + if file_path is None and file_or_path_exists( + model=model, config_name=file_name, revision=revision): + try: + hf_hub_file = hf_hub_download(model, file_name, revision=revision) + except (RepositoryNotFoundError, RevisionNotFoundError, + EntryNotFoundError, LocalEntryNotFoundError) as e: + logger.debug("File or repository not found in hf_hub_download", e) + return None + except HfHubHTTPError as e: + logger.warning( + "Cannot connect to Hugging Face Hub. 
Skipping file " + "download for '%s':", + file_name, + exc_info=e) + return None + file_path = Path(hf_hub_file) + if file_path is not None and file_path.is_file(): with open(file_path) as file: return json.load(file) + return None @@ -328,7 +345,12 @@ def get_pooling_config(model: str, revision: Optional[str] = 'main'): """ modules_file_name = "modules.json" - modules_dict = get_hf_file_to_dict(modules_file_name, model, revision) + + modules_dict = None + if file_or_path_exists(model=model, + config_name=modules_file_name, + revision=revision): + modules_dict = get_hf_file_to_dict(modules_file_name, model, revision) if modules_dict is None: return None @@ -382,17 +404,17 @@ def get_sentence_transformer_tokenizer_config(model: str, revision: Optional[str] = 'main' ): """ - Returns the tokenization configuration dictionary for a + Returns the tokenization configuration dictionary for a given Sentence Transformer BERT model. Parameters: - - model (str): The name of the Sentence Transformer + - model (str): The name of the Sentence Transformer BERT model. - revision (str, optional): The revision of the m odel to use. Defaults to 'main'. Returns: - - dict: A dictionary containing the configuration parameters + - dict: A dictionary containing the configuration parameters for the Sentence Transformer BERT model. 
""" sentence_transformer_config_files = [ @@ -404,20 +426,33 @@ def get_sentence_transformer_tokenizer_config(model: str, "sentence_xlm-roberta_config.json", "sentence_xlnet_config.json", ] - try: - # If model is on HuggingfaceHub, get the repo files - repo_files = list_repo_files(model, revision=revision, token=HF_TOKEN) - except Exception as e: - logger.debug("Error getting repo files", e) - repo_files = [] - encoder_dict = None - for config_name in sentence_transformer_config_files: - if config_name in repo_files or Path(model).exists(): - encoder_dict = get_hf_file_to_dict(config_name, model, revision) + + for config_file in sentence_transformer_config_files: + if try_get_local_file(model=model, + file_name=config_file, + revision=revision) is not None: + encoder_dict = get_hf_file_to_dict(config_file, model, revision) if encoder_dict: break + if not encoder_dict: + try: + # If model is on HuggingfaceHub, get the repo files + repo_files = list_repo_files(model, + revision=revision, + token=HF_TOKEN) + except Exception as e: + logger.debug("Error getting repo files", e) + repo_files = [] + + for config_name in sentence_transformer_config_files: + if config_name in repo_files: + encoder_dict = get_hf_file_to_dict(config_name, model, + revision) + if encoder_dict: + break + if not encoder_dict: return None -- GitLab From 1918aa1b8010c00443b71f8bb976d4db4acf3c18 Mon Sep 17 00:00:00 2001 From: Lu Fang <30275821+houseroad@users.noreply.github.com> Date: Fri, 7 Feb 2025 05:04:39 -0800 Subject: [PATCH 018/253] [MISC][EASY] Break check file names into entry and args in the pre-commit hooks (#12880) Signed-off-by: Lu Fang --- .pre-commit-config.yaml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0b1c4fdf2..3fb74ab9b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -110,7 +110,10 @@ repos: pass_filenames: false - id: check-filenames name: Check for spaces in all filenames - 
entry: bash -c 'git ls-files | grep " " && echo "Filenames should not contain spaces!" && exit 1 || exit 0' + entry: bash + args: + - -c + - 'git ls-files | grep " " && echo "Filenames should not contain spaces!" && exit 1 || exit 0' language: system always_run: true pass_filenames: false -- GitLab From ce26b16268ef8d7db5c1346c482b899f49dcd3cd Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 7 Feb 2025 22:21:17 +0800 Subject: [PATCH 019/253] [Misc] Remove unnecessary detokenization in multimodal processing (#12868) --- tests/entrypoints/openai/test_audio.py | 6 +++--- tests/entrypoints/openai/test_vision.py | 4 ++-- tests/entrypoints/openai/test_vision_embedding.py | 4 ++-- vllm/inputs/preprocess.py | 3 --- 4 files changed, 7 insertions(+), 10 deletions(-) diff --git a/tests/entrypoints/openai/test_audio.py b/tests/entrypoints/openai/test_audio.py index 6e206dfd9..3459f2483 100644 --- a/tests/entrypoints/openai/test_audio.py +++ b/tests/entrypoints/openai/test_audio.py @@ -83,7 +83,7 @@ async def test_single_chat_session_audio(client: openai.AsyncOpenAI, choice = chat_completion.choices[0] assert choice.finish_reason == "length" assert chat_completion.usage == openai.types.CompletionUsage( - completion_tokens=10, prompt_tokens=202, total_tokens=212) + completion_tokens=10, prompt_tokens=201, total_tokens=211) message = choice.message message = chat_completion.choices[0].message @@ -140,7 +140,7 @@ async def test_single_chat_session_audio_base64encoded( choice = chat_completion.choices[0] assert choice.finish_reason == "length" assert chat_completion.usage == openai.types.CompletionUsage( - completion_tokens=10, prompt_tokens=202, total_tokens=212) + completion_tokens=10, prompt_tokens=201, total_tokens=211) message = choice.message message = chat_completion.choices[0].message @@ -196,7 +196,7 @@ async def test_single_chat_session_input_audio( choice = chat_completion.choices[0] assert choice.finish_reason == "length" assert chat_completion.usage == 
openai.types.CompletionUsage( - completion_tokens=10, prompt_tokens=202, total_tokens=212) + completion_tokens=10, prompt_tokens=201, total_tokens=211) message = choice.message message = chat_completion.choices[0].message diff --git a/tests/entrypoints/openai/test_vision.py b/tests/entrypoints/openai/test_vision.py index 029c9b038..c954fca69 100644 --- a/tests/entrypoints/openai/test_vision.py +++ b/tests/entrypoints/openai/test_vision.py @@ -92,7 +92,7 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI, choice = chat_completion.choices[0] assert choice.finish_reason == "length" assert chat_completion.usage == openai.types.CompletionUsage( - completion_tokens=10, prompt_tokens=775, total_tokens=785) + completion_tokens=10, prompt_tokens=774, total_tokens=784) message = choice.message message = chat_completion.choices[0].message @@ -185,7 +185,7 @@ async def test_single_chat_session_image_base64encoded( choice = chat_completion.choices[0] assert choice.finish_reason == "length" assert chat_completion.usage == openai.types.CompletionUsage( - completion_tokens=10, prompt_tokens=775, total_tokens=785) + completion_tokens=10, prompt_tokens=774, total_tokens=784) message = choice.message message = chat_completion.choices[0].message diff --git a/tests/entrypoints/openai/test_vision_embedding.py b/tests/entrypoints/openai/test_vision_embedding.py index f2ff4a0b0..cee527456 100644 --- a/tests/entrypoints/openai/test_vision_embedding.py +++ b/tests/entrypoints/openai/test_vision_embedding.py @@ -93,5 +93,5 @@ async def test_image_embedding(server: RemoteOpenAIServer, model_name: str, assert len(embeddings.data) == 1 assert len(embeddings.data[0].embedding) == 3072 assert embeddings.usage.completion_tokens == 0 - assert embeddings.usage.prompt_tokens == 764 - assert embeddings.usage.total_tokens == 764 + assert embeddings.usage.prompt_tokens == 763 + assert embeddings.usage.total_tokens == 763 diff --git a/vllm/inputs/preprocess.py 
b/vllm/inputs/preprocess.py index 035e84cc0..53f89996f 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -260,9 +260,6 @@ class InputPreprocessor: mm_processor = self.mm_registry.create_processor( self.model_config, tokenizer) - if isinstance(prompt, list): - prompt = tokenizer.decode(prompt) - if mm_processor_kwargs is None: mm_processor_kwargs = {} -- GitLab From 538fab93cdd36e965ea1888143dab0df57c8ba84 Mon Sep 17 00:00:00 2001 From: Amit Garg Date: Fri, 7 Feb 2025 06:22:37 -0800 Subject: [PATCH 020/253] PR #12718 (#12718) --- vllm/model_executor/layers/rotary_embedding.py | 18 +++++++++++------- vllm/model_executor/models/llama.py | 5 ++++- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index b3b9b0e87..ec204b32f 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -509,15 +509,12 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module): ): super().__init__() - if rotary_dim != head_size: - raise ValueError( - f"`Phi3LongRoPEScaledRotaryEmbedding` does not support \ - rotary_dim != head_size ({rotary_dim}!={head_size}).") if is_neox_style is False: raise ValueError( "`Phi3LongRoPEScaledRotaryEmbedding` only supports neox_style." 
) + self.rotary_dim = rotary_dim self.head_size = head_size self.max_position_embeddings = max_position_embeddings self.original_max_position_embeddings = original_max_position_embeddings @@ -557,7 +554,7 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module): def _compute_inv_freq(self, rescale_factors: List[float]) -> torch.Tensor: rescale_factors = torch.tensor(rescale_factors, dtype=torch.float32) inv_freq = 1.0 / (rescale_factors * (self.base**(torch.arange( - 0, self.head_size, 2, dtype=torch.float) / self.head_size))) + 0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))) return inv_freq def _compute_cos_sin_cache( @@ -596,8 +593,15 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module): cos = cos.repeat(1, 2).unsqueeze(-2) sin = sin.repeat(1, 2).unsqueeze(-2) - query = query * cos + _rotate_neox(query) * sin - key = key * cos + _rotate_neox(key) * sin + query_rot = query[..., :self.rotary_dim] + query_pass = query[..., self.rotary_dim:] + query_rot = query_rot * cos + _rotate_neox(query_rot) * sin + query = torch.cat((query_rot, query_pass), dim=-1) + + key_rot = key[..., :self.rotary_dim] + key_pass = key[..., self.rotary_dim:] + key_rot = key_rot * cos + _rotate_neox(key_rot) * sin + key = torch.cat((key_rot, key_pass), dim=-1) return query.flatten(-2), key.flatten(-2) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index d91c8782a..866c69234 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -128,6 +128,9 @@ class LlamaAttention(nn.Module): # MistralConfig has an optional head_dim introduced by Mistral-Nemo self.head_dim = getattr(config, "head_dim", self.hidden_size // self.total_num_heads) + # Phi models introduced a partial_rotary_factor parameter in the config + partial_rotary_factor = getattr(config, "partial_rotary_factor", 1) + self.rotary_dim = int(partial_rotary_factor * self.head_dim) self.q_size = self.num_heads * self.head_dim self.kv_size = 
self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 @@ -159,7 +162,7 @@ class LlamaAttention(nn.Module): self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.head_dim, + rotary_dim=self.rotary_dim, max_position=max_position_embeddings, base=rope_theta, rope_scaling=rope_scaling, -- GitLab From 0630d4537a5fbab80cb1109a26170101cffb7f84 Mon Sep 17 00:00:00 2001 From: afeldman-nm <156691304+afeldman-nm@users.noreply.github.com> Date: Fri, 7 Feb 2025 10:26:20 -0500 Subject: [PATCH 021/253] [V1] Logprobs and prompt logprobs support (#9880) This PR is adding support for sample logprobs & prompt logprobs to vLLM v1. New behavior: - During model execution, model runner computes sample logprobs (if user-provided logprobs setting is not None) and prompt logprobs (if user-provided prompt_logprobs setting is not None). For both sample and prompt logprobs, the engine core returns 3 vectors: token ids, token logprob values, token ranks. Ranks reflect tokens' 1-indexed positions in the vocabulary vector after sorting the vocabulary by log probability in descending order. - In scheduler.update_from_output(), sample and prompt logprobs are incorporated into the EngineCoreOutput data structure which is transferred to the engine client. If multiprocessing is enabled, then sample and prompt logprobs will be (de)serialized when the EngineCoreOutput data structure is (de)serialized. - During output processing, the LogprobsProcessor transforms the triplet of token ids, token logprobs values, and token ranks into the OpenAI-compatible List[Dict[token id,Logprob]] format (for sample and prompt logprobs respectively.) - Each Logprob instance (whether sample- or prompt-) consists of a token's log-probability, rank, and detokenized string representation. Note that logprob detokenization is handled by the LogprobsProcessor not the detokenizer. 
Signed-off-by: Andrew Feldman Signed-off-by: Nick Hill Signed-off-by: rshaw@neuralmagic.com Co-authored-by: rshaw@neuralmagic.com Co-authored-by: Nick Hill --- tests/v1/core/test_scheduler.py | 4 +- tests/v1/engine/conftest.py | 90 +++ tests/v1/engine/test_async_llm.py | 49 +- tests/v1/engine/test_llm_engine.py | 23 + tests/v1/engine/test_output_processor.py | 553 +++++++++++++++--- tests/v1/engine/utils.py | 382 ++++++++++++ tests/v1/entrypoints/__init__.py | 0 tests/v1/entrypoints/conftest.py | 161 +++++ .../v1/entrypoints/openai/test_completion.py | 475 +++++++++++++++ tests/v1/sample/test_logprobs.py | 392 +++++++++++++ tests/v1/sample/test_logprobs_e2e.py | 52 ++ tests/v1/sample/utils.py | 120 ++++ vllm/outputs.py | 10 +- vllm/transformers_utils/detokenizer_utils.py | 19 + vllm/v1/core/scheduler.py | 43 +- vllm/v1/engine/__init__.py | 10 +- vllm/v1/engine/core.py | 5 +- vllm/v1/engine/core_client.py | 5 +- vllm/v1/engine/detokenizer.py | 54 +- vllm/v1/engine/llm_engine.py | 1 + vllm/v1/engine/logprobs.py | 194 ++++++ vllm/v1/engine/output_processor.py | 126 ++-- vllm/v1/engine/processor.py | 39 +- vllm/v1/metrics/stats.py | 19 +- vllm/v1/outputs.py | 54 +- vllm/v1/sample/metadata.py | 3 +- vllm/v1/sample/sampler.py | 94 ++- vllm/v1/serial_utils.py | 50 +- vllm/v1/worker/gpu_input_batch.py | 27 +- vllm/v1/worker/gpu_model_runner.py | 102 +++- 30 files changed, 2869 insertions(+), 287 deletions(-) create mode 100644 tests/v1/engine/conftest.py create mode 100644 tests/v1/engine/test_llm_engine.py create mode 100644 tests/v1/engine/utils.py create mode 100644 tests/v1/entrypoints/__init__.py create mode 100644 tests/v1/entrypoints/conftest.py create mode 100644 tests/v1/entrypoints/openai/test_completion.py create mode 100644 tests/v1/sample/test_logprobs.py create mode 100644 tests/v1/sample/test_logprobs_e2e.py create mode 100644 tests/v1/sample/utils.py create mode 100644 vllm/v1/engine/logprobs.py diff --git a/tests/v1/core/test_scheduler.py 
b/tests/v1/core/test_scheduler.py index 8eb08f3e8..0d29729a4 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -195,8 +195,8 @@ def test_schedule_partial_requests(): req_ids=[request.request_id for request in requests], req_id_to_index=req_to_index, sampled_token_ids=[0] * len(requests), - logprob_token_ids_cpu=None, - logprobs_cpu=None, + logprobs=None, + prompt_logprobs_dict={}, ) scheduler.update_from_output(output, model_runner_output) diff --git a/tests/v1/engine/conftest.py b/tests/v1/engine/conftest.py new file mode 100644 index 000000000..560dc3121 --- /dev/null +++ b/tests/v1/engine/conftest.py @@ -0,0 +1,90 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import List, Tuple + +import pytest +import torch +from transformers import AutoTokenizer + +from tests.v1.engine.utils import (NUM_PROMPT_LOGPROBS_UNDER_TEST, + NUM_SAMPLE_LOGPROBS_UNDER_TEST, PROMPT_LEN, + TOKENIZER_NAME, + DummyOutputProcessorTestVectors, + generate_dummy_prompt_logprobs_tensors, + generate_dummy_sample_logprobs) +from vllm.engine.arg_utils import EngineArgs +from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs + +from tests.v1.engine.utils import FULL_STRINGS # isort: skip + +EngineCoreSampleLogprobsType = List[Tuple[torch.Tensor, torch.Tensor]] +EngineCorePromptLogprobsType = Tuple[torch.Tensor, torch.Tensor] + + +def _build_test_vectors_no_logprobs() -> DummyOutputProcessorTestVectors: + """Generate output processor dummy test vectors, without logprobs + + Returns: + DummyOutputProcessorTestVectors instance with no logprobs + """ + + tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME) + vllm_config = EngineArgs(model=TOKENIZER_NAME).create_engine_config() + # Tokenize prompts under test & create dummy generated tokens + prompt_tokens = [ + tokenizer(text).input_ids[:PROMPT_LEN] for text in FULL_STRINGS + ] + generation_tokens = [ + tokenizer(text).input_ids[PROMPT_LEN:] for text in FULL_STRINGS + ] + # 
Generate prompt strings + prompt_strings = [ + tokenizer.decode(prompt_tokens, skip_special_tokens=True) + for prompt_tokens in prompt_tokens + ] + prompt_strings_len = [ + len(prompt_string) for prompt_string in prompt_strings + ] + return DummyOutputProcessorTestVectors( + tokenizer=tokenizer, + tokenizer_group=init_tokenizer_from_configs( + vllm_config.model_config, vllm_config.scheduler_config, + vllm_config.parallel_config, vllm_config.lora_config), + vllm_config=vllm_config, + full_tokens=[tokenizer(text).input_ids for text in FULL_STRINGS], + prompt_tokens=prompt_tokens, + generation_tokens=generation_tokens, + prompt_strings=prompt_strings, + prompt_strings_len=prompt_strings_len, + generation_strings=[ + text[prompt_len:] + for text, prompt_len in zip(FULL_STRINGS, prompt_strings_len) + ], + prompt_logprobs=[], + generation_logprobs=[]) + + +@pytest.fixture +def dummy_test_vectors() -> DummyOutputProcessorTestVectors: + """Generate output processor dummy test vectors, with logprobs + + Returns: + DummyOutputProcessorTestVectors instance with logprobs + """ + # Build dummy test vectors without logprobs + dtv = _build_test_vectors_no_logprobs() + # Inject logprobs into dummy test vectors + # data structure + dtv.generation_logprobs = [ + generate_dummy_sample_logprobs( + sampled_tokens_list=tokens_list, + num_logprobs=NUM_SAMPLE_LOGPROBS_UNDER_TEST, + tokenizer=dtv.tokenizer) for tokens_list in dtv.generation_tokens + ] + dtv.prompt_logprobs = [ + generate_dummy_prompt_logprobs_tensors( + prompt_tokens_list=tokens_list, + num_logprobs=NUM_PROMPT_LOGPROBS_UNDER_TEST, + tokenizer=dtv.tokenizer) for tokens_list in dtv.prompt_tokens + ] + return dtv diff --git a/tests/v1/engine/test_async_llm.py b/tests/v1/engine/test_async_llm.py index 4b5bc9ced..94e18289e 100644 --- a/tests/v1/engine/test_async_llm.py +++ b/tests/v1/engine/test_async_llm.py @@ -2,10 +2,11 @@ import asyncio from contextlib import ExitStack -from typing import List, Tuple +from typing import 
List, Optional, Tuple import pytest +from tests.v1.engine.utils import PLP_APC_UNSUPPORTED_MSG from vllm import SamplingParams from vllm.engine.arg_utils import AsyncEngineArgs from vllm.platforms import current_platform @@ -21,13 +22,19 @@ ENGINE_ARGS = AsyncEngineArgs(model="meta-llama/Llama-3.2-1B", disable_log_requests=True) -async def generate(engine: AsyncLLM, request_id: str, +async def generate(engine: AsyncLLM, + request_id: str, output_kind: RequestOutputKind, - max_tokens: int) -> Tuple[int, str]: + max_tokens: int, + prompt_logprobs: Optional[int] = None) -> Tuple[int, str]: + # Ensure generate doesn't complete too fast for cancellation test. + await asyncio.sleep(0.2) + count = 0 sampling_params = SamplingParams(max_tokens=max_tokens, output_kind=output_kind, - temperature=0) + temperature=0, + prompt_logprobs=prompt_logprobs) async for out in engine.generate(request_id=request_id, prompt="Hello my name is Robert and", sampling_params=sampling_params): @@ -43,6 +50,40 @@ async def generate(engine: AsyncLLM, request_id: str, return count, request_id +@pytest.mark.parametrize( + "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]) +@pytest.mark.asyncio +async def test_async_llm_refuses_prompt_logprobs_with_apc( + monkeypatch, output_kind: RequestOutputKind): + """Test passes if AsyncLLM raises an exception when it is configured + for automatic prefix caching and it receives a request with + prompt_logprobs enabled, which is incompatible.""" + # TODO(rickyx): Remove monkeypatch VLLM_USE_V1 setting once we have a + # better way to test V1 so that in the future when we switch, we don't + # have to change all the tests. 
+ monkeypatch.setenv("VLLM_USE_V1", "1") + # Create AsyncLLM engine with APC + apc_engine_args = AsyncEngineArgs(model="facebook/opt-125m", + enable_prefix_caching=True, + gpu_memory_utilization=0.8, + disable_log_requests=True) + engine = AsyncLLM.from_engine_args(apc_engine_args) + try: + with pytest.raises(ValueError) as excinfo: + # Issue a request with prompt logprobs enabled, which should fail + await asyncio.create_task( + generate(engine, + "request-0", + output_kind, + 10, + prompt_logprobs=5)) + # Validate exception string is correct + assert str(excinfo.value) == PLP_APC_UNSUPPORTED_MSG + finally: + # Shut down engine + engine.shutdown() + + @pytest.mark.parametrize( "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]) @pytest.mark.asyncio diff --git a/tests/v1/engine/test_llm_engine.py b/tests/v1/engine/test_llm_engine.py new file mode 100644 index 000000000..84b634316 --- /dev/null +++ b/tests/v1/engine/test_llm_engine.py @@ -0,0 +1,23 @@ +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from tests.v1.engine.utils import PLP_APC_UNSUPPORTED_MSG +from vllm import LLM, SamplingParams + + +def test_llm_engine_refuses_prompt_logprobs_with_apc(monkeypatch): + """Test passes if LLMEngine raises an exception when it is configured + for automatic prefix caching and it receives a request with + prompt_logprobs enabled, which is incompatible.""" + + monkeypatch.setenv("VLLM_USE_V1", "1") + # TODO(nick): Single-proc to work around a ZMQ shutdown hang for now. 
+ monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") + with pytest.raises(ValueError) as excinfo: + LLM(model="facebook/opt-125m", enable_prefix_caching=True).generate( + "Hello, my name is", + SamplingParams(temperature=0.8, top_p=0.95, prompt_logprobs=5)) + + # Validate exception string is correct + assert str(excinfo.value) == PLP_APC_UNSUPPORTED_MSG diff --git a/tests/v1/engine/test_output_processor.py b/tests/v1/engine/test_output_processor.py index 5782a249f..c8f43edb7 100644 --- a/tests/v1/engine/test_output_processor.py +++ b/tests/v1/engine/test_output_processor.py @@ -1,82 +1,47 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List +import math +from typing import Dict, List, Optional import pytest -from transformers import AutoTokenizer -from vllm.engine.arg_utils import EngineArgs +from tests.v1.engine.utils import (NUM_PROMPT_LOGPROBS_UNDER_TEST, + NUM_SAMPLE_LOGPROBS_UNDER_TEST, + STOP_STRINGS, + DummyOutputProcessorTestVectors, + MockEngineCore) from vllm.sampling_params import RequestOutputKind, SamplingParams -from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs -from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest +from vllm.sequence import PromptLogprobs, SampleLogprobs +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.output_processor import OutputProcessor -TOKENIZER_NAME = "mistralai/Mistral-7B-Instruct-v0.3" -VLLM_CONFIG = EngineArgs(model=TOKENIZER_NAME).create_engine_config() -TOKENIZER_GROUP = init_tokenizer_from_configs(VLLM_CONFIG.model_config, - VLLM_CONFIG.scheduler_config, - VLLM_CONFIG.parallel_config, - VLLM_CONFIG.lora_config) -tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME) - -FULL_STRINGS = [ - "My name is Robert from Neural Magic and I love working on vLLM so much!", - "Red Hat is the best open source company by far across Linux, K8s, and AI.", - "Nick is the name of my brother in addition 
to my colleague from Red Hat.", -] - -STOP_STRINGS = ["I love working on", "company by far", "brother in"] - -FULL_TOKENS = [tokenizer(text).input_ids for text in FULL_STRINGS] -PROMPT_LEN = 5 -PROMPT_TOKENS = [ - tokenizer(text).input_ids[:PROMPT_LEN] for text in FULL_STRINGS -] -GENERATION_TOKENS = [ - tokenizer(text).input_ids[PROMPT_LEN:] for text in FULL_STRINGS -] -PROMPT_STRINGS = [ - tokenizer.decode(prompt_tokens, skip_special_tokens=True) - for prompt_tokens in PROMPT_TOKENS -] -PROMPT_STRINGS_LEN = [len(prompt_string) for prompt_string in PROMPT_STRINGS] -GENERATION_STRINGS = [ - text[prompt_len:] - for text, prompt_len in zip(FULL_STRINGS, PROMPT_STRINGS_LEN) -] - - -class MockEngineCore: - """Mock outputs form premade tokens lists.""" - - def __init__(self, tokens_list: List[List[int]]): - self.tokens_list = tokens_list - self.current_idx = 0 - - def get_outputs(self) -> List[EngineCoreOutput]: - token_idx = self.current_idx - self.current_idx += 1 - - outputs = [] - for req_idx, token_ids in enumerate(self.tokens_list): - if len(token_ids) > token_idx: - output = EngineCoreOutput(request_id=f"request-{req_idx}", - new_token_ids=[token_ids[token_idx]], - finished=False) - if token_idx == len(token_ids) - 1: - output.finished = True - output.finish_reason = "stopped" - outputs.append(output) - - return outputs + +def _ref_convert_id_to_token( + tokenizer: AnyTokenizer, + token_id: int, +) -> str: + """Reference impl of logprobs detokenization. 
+ + Args: + tokenizer: tokenizer used by the model under test + token_id: convert this token id + + Returns: + String representation of input token id + """ + return tokenizer.convert_ids_to_tokens(token_id) or "" @pytest.mark.parametrize( "request_output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]) -def test_incremental_detokenization(request_output_kind: RequestOutputKind): - output_processor = OutputProcessor(TOKENIZER_GROUP, log_stats=False) - engine_core = MockEngineCore(GENERATION_TOKENS) +def test_incremental_detokenization(request_output_kind: RequestOutputKind, + dummy_test_vectors): + output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group, + log_stats=False) + engine_core = MockEngineCore( + tokens_list=dummy_test_vectors.generation_tokens) # Make N requests. requests = [ @@ -94,10 +59,10 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind): spaces_between_special_tokens=False, output_kind=request_output_kind, stop=[], - include_stop_str_in_output=False)) - for idx, ( - prompt, - prompt_tokens) in enumerate(zip(PROMPT_STRINGS, PROMPT_TOKENS)) + include_stop_str_in_output=False, + )) for idx, (prompt, prompt_tokens) in enumerate( + zip(dummy_test_vectors.prompt_strings, + dummy_test_vectors.prompt_tokens)) ] # Add requests to the detokenizer. @@ -113,7 +78,7 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind): break # Step the Detokenizer. - processed_outputs = output_processor.process_outputs(outputs, ) + processed_outputs = output_processor.process_outputs(outputs) request_outputs = processed_outputs.request_outputs requests_to_abort = processed_outputs.reqs_to_abort assert len(requests_to_abort) == 0 @@ -132,7 +97,8 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind): # Confirmed tracked values matches what we expected. 
for idx, (ref_gen_str, ref_gen_toks) in enumerate( - zip(GENERATION_STRINGS, GENERATION_TOKENS)): + zip(dummy_test_vectors.generation_strings, + dummy_test_vectors.generation_tokens)): gen_str = gen_strings[f"request-{idx}"] gen_toks = gen_tokens[f"request-{idx}"] @@ -143,15 +109,390 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind): assert not output_processor.has_unfinished_requests() +def _validate_logprobs( + gen_tokens: Dict[str, List[int]], + gen_logprobs: Dict[str, Optional[SampleLogprobs]], + gen_prompt_logprobs: Dict[str, Optional[PromptLogprobs]], + gen_cumulative_logprob: Dict[str, float], + dtv: DummyOutputProcessorTestVectors, + request_id_list: List[str], + num_sample_logprobs: Optional[int], + num_prompt_logprobs: Optional[int], +) -> None: + for req_idx, req_id in enumerate(request_id_list): + new_tokens = gen_tokens[req_id] + logprobs = gen_logprobs[req_id] + prompt_logprobs = gen_prompt_logprobs[req_id] + cumulative_logprob = gen_cumulative_logprob[req_id] + prompt_token_ids = dtv.prompt_tokens[req_idx] + ref_logprobs = dtv.generation_logprobs[req_idx] + ref_prompt_logprobs = dtv.prompt_logprobs[req_idx] + if num_sample_logprobs is not None: + # Validate sample logprobs + assert logprobs is not None, (f"Request {req_id} requires sample" + " logprobs but sample logprobs are" + " None.") + # Require num sampled tokens to match num + # sampled logprobs - especially important + # to check since the detokenizer can cause + # a request to finish early due to a stop + # string being hit + num_new_tokens = len(new_tokens) + len_sample_logprobs = len(logprobs) + assert num_new_tokens == len_sample_logprobs, ( + f"Request {req_id} has {num_new_tokens}" + " completion tokens but has" + f" {len_sample_logprobs} sample logprobs.") + ref_cumulative_logprob = 0.0 + for idx, (sampled_token, + pos_logprob_dict) in enumerate(zip(new_tokens, + logprobs)): + # Break out the reference log probability value & + # logprob token id tensors 
associated with this + # position in the completion. Also break out the + # sampled token ranks + (ref_pos_logprob_toks, ref_pos_logprob_vals, + ref_sampled_token_rank) = ref_logprobs[idx] + # For each position in the completion sequence, + # ensure the actual sampled token is among the + # logprobs + assert sampled_token in pos_logprob_dict, ( + f"Sampled token {sampled_token} not" + f" present in logprob at index {idx}") + + # Validate number of sample logprobs + num_lp_toks = len(pos_logprob_dict) + assert (num_lp_toks == num_sample_logprobs + or num_lp_toks == num_sample_logprobs + + 1), ("Valid numbers of sample logprobs are" + f" {num_sample_logprobs} or" + f" {num_sample_logprobs+1} but" + f" {num_lp_toks} logprobs found at" + f" position {idx}. Logprobs dict:" + f" {pos_logprob_dict}") + + # Validate sampled token logprob rank + smp_lp = pos_logprob_dict[sampled_token] + smp_lp_rank = smp_lp.rank + assert (ref_sampled_token_rank == smp_lp_rank), ( + "Sampled token logprob rank" + f" {smp_lp_rank} does not match" + " correct value" + f" {ref_sampled_token_rank}" + f" in Logprob {smp_lp}") + + # Validate that the logprob processor yields + # the correct log probabilities and valid + # rankings + rank_one_appears = False + for jdx in range(1, len(ref_pos_logprob_toks)): + # Iterate over the (logprob val,logprob tok id) + # pairs expected by the test fixture at this + # position in the completion. + ref_lp_val = ref_pos_logprob_vals[jdx] + ref_tok_id = ref_pos_logprob_toks[jdx] + assert ref_tok_id in pos_logprob_dict, ( + f"Expected token {ref_tok_id} to be" + f" in logprob dict but it is not.") + + # Extract actually-generated logprob + # info + lp = pos_logprob_dict[ref_tok_id] + lp_val = lp.logprob + lp_rank = lp.rank + + # A "top" (rank 1) logprob must be + # present + rank_one_appears = (True + if lp_rank == 1 else rank_one_appears) + + # Rank must be >= 1 + assert lp_rank >= 1, (f"Logprob {lp} has invalid" + f" rank {lp_rank} < 1." 
+ f" Logprob dict: {pos_logprob_dict}") + + # Validate log probability + assert math.isclose(lp_val, ref_lp_val), ( + f"Token id {ref_tok_id} appears in logprobs dict" + f" at position {idx} in completion with log" + f" probability {lp_val} but {ref_lp_val} was" + f" expected. Logprob: {lp}") + + assert rank_one_appears, (f"No Logprob has rank 1" + " in the following Logprob" + f" dict: {pos_logprob_dict}") + + # Validate logprobs detokenization + for lp_tok in pos_logprob_dict: + # Confirm that sample logprob decoded token matches + # the logprob token id at this sequence position + decoded_token = pos_logprob_dict[lp_tok].decoded_token + ref_decoded_token = _ref_convert_id_to_token( + dtv.tokenizer, lp_tok) + assert decoded_token == ref_decoded_token, ( + f"Sampled logprob token id {lp_tok} decodes to" + f" {ref_decoded_token} but Logprob decoded" + f" token is {decoded_token} instead" + f" (at position {idx})") + + ref_cumulative_logprob += pos_logprob_dict[ + sampled_token].logprob + # Assert that cumulative logprobs are correct + assert math.isclose(cumulative_logprob, ref_cumulative_logprob) + else: + # Sample logprobs disabled for this request + assert logprobs is None + assert cumulative_logprob is None + + if num_prompt_logprobs is not None: + # Validate prompt logprobs + assert prompt_logprobs is not None, ( + f"Request {req_id} requires prompt" + " logprobs but prompt logprobs are" + " None.") + # Require num prompt tokens to match num + # prompt logprobs + num_prompt_tokens = len(prompt_token_ids) + len_prompt_logprobs = len(prompt_logprobs) + assert num_prompt_tokens == len_prompt_logprobs, ( + f"Request {req_id} has {num_prompt_tokens}" + " prompt tokens but has" + f" {len_prompt_logprobs} prompt logprobs.") + # First prompt logprob is None + first_plp_dict = prompt_logprobs[0] + assert first_plp_dict is None, ( + f"Request {req_id} first prompt logprob" + f" should be None but has following value" + f" instead: {first_plp_dict}") + # Break out the 
reference prompt log prob value & + # logprob token id matrices for the whole prompt. + # Also break out the prompt token rank vector + (ref_prompt_logprob_toks, ref_prompt_logprob_vals, + ref_prompt_token_ranks) = ref_prompt_logprobs + for idx, (prompt_token, pos_logprob_dict) in enumerate( + zip(prompt_token_ids[1:], prompt_logprobs[1:])): + + # Break out the reference prompt log prob value + # vector, prompt logprob token id vector, and + # prompt token rank at the current position. + (ref_pos_prompt_logprob_toks, ref_pos_prompt_logprob_vals, + ref_pos_prompt_token_rank) = (ref_prompt_logprob_toks[idx, :], + ref_prompt_logprob_vals[idx, :], + ref_prompt_token_ranks[idx]) + + # For each position in the prompt sequence, + # ensure the actual prompt token is among the + # logprobs + assert prompt_token in pos_logprob_dict, ( + f"Prompt token {prompt_token} not" + f" present in logprob at index {idx}") + # Validate number of prompt logprobs + num_plp_toks = len(pos_logprob_dict) + assert (num_plp_toks == num_prompt_logprobs + or num_plp_toks == num_prompt_logprobs + + 1), ("Valid numbers of prompt logprobs are" + f" {num_prompt_logprobs} or" + f" {num_prompt_logprobs+1} but" + f" {num_plp_toks} logprobs found at" + f" position {idx}. 
Logprobs dict:" + f" {pos_logprob_dict}") + + # Validate prompt token logprob rank + prmpt_tok_lp = pos_logprob_dict[prompt_token] + prmpt_tok_lp_rank = prmpt_tok_lp.rank + ref_prmpt_tok_lp_rank = ref_pos_prompt_token_rank + assert (ref_prmpt_tok_lp_rank == prmpt_tok_lp_rank), ( + "Prompt token logprob rank" + f" {prmpt_tok_lp_rank} does not match" + " correct value" + f" {ref_prmpt_tok_lp_rank}" + f" in Logprob {prmpt_tok_lp}") + + # Validate that the logprob processor yields + # the correct prompt log probs and valid + # rankings + rank_one_appears = False + for jdx in range(1, len(ref_pos_prompt_logprob_toks)): + # Iterate over the (logprob val,logprob tok id) + # pairs expected by the test fixture at this + # position in the completion. + ref_plp_val = float(ref_pos_prompt_logprob_vals[jdx]) + ref_tok_id = int(ref_pos_prompt_logprob_toks[jdx]) + assert ref_tok_id in pos_logprob_dict, ( + f"Expected token {ref_tok_id} to be" + f" in logprob dict but it is not.") + + # Extract actually-generated logprob + # info + plp = pos_logprob_dict[ref_tok_id] + plp_val = plp.logprob + plp_rank = plp.rank + + # A "top" (rank 1) logprob must be + # present + rank_one_appears = (True + if plp_rank == 1 else rank_one_appears) + + # Rank must be >= 1 + assert plp_rank >= 1, ( + f"Logprob {plp} has invalid" + f" rank {plp_rank} < 1." + f" Logprob dict: {pos_logprob_dict}") + + # Validate log probability + assert math.isclose(plp_val, ref_plp_val), ( + f"Token id {ref_tok_id} appears in logprobs dict" + f" at position {idx} in completion with log" + f" probability {plp_val} but {ref_plp_val} was" + f" expected. 
Logprob: {plp}") + + assert rank_one_appears, (f"No Logprob has rank 1" + " in the following Logprob" + f" dict: {pos_logprob_dict}") + + # Validate prompt logprob detokenization + for plp_tok in pos_logprob_dict: + # Confirm that prompt logprob decoded token matches + # the logprob token id at this sequence position + decoded_token = pos_logprob_dict[plp_tok].decoded_token + ref_decoded_token = _ref_convert_id_to_token( + dtv.tokenizer, plp_tok) + assert decoded_token == ref_decoded_token, ( + f"Prompt logprob token id {plp_tok} decodes to" + f" {ref_decoded_token} but Logprob decoded" + f" token is {decoded_token} instead" + f" (at position {idx})") + else: + # Prompt logprobs disabled for this request + assert prompt_logprobs is None + + +@pytest.mark.parametrize( + "request_output_kind", + [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]) +@pytest.mark.parametrize("num_sample_logprobs", + [None, NUM_SAMPLE_LOGPROBS_UNDER_TEST]) +@pytest.mark.parametrize("num_prompt_logprobs", + [None, NUM_PROMPT_LOGPROBS_UNDER_TEST]) +def test_logprobs_processor(request_output_kind: RequestOutputKind, + num_sample_logprobs: Optional[int], + num_prompt_logprobs: Optional[int], + dummy_test_vectors): + output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group, + log_stats=False) + engine_core = MockEngineCore( + tokens_list=dummy_test_vectors.generation_tokens, + generated_logprobs_raw=None if num_sample_logprobs is None else + dummy_test_vectors.generation_logprobs, + prompt_logprobs_raw=None + if num_prompt_logprobs is None else dummy_test_vectors.prompt_logprobs) + + # Make N requests. 
+ request_id_list = [ + f"request-{idx}" + for idx in range(len(dummy_test_vectors.prompt_strings)) + ] + requests = [ + EngineCoreRequest(request_id=request_id_list[idx], + prompt=prompt, + prompt_token_ids=prompt_tokens, + arrival_time=0, + mm_inputs=None, + mm_hashes=None, + mm_placeholders=None, + eos_token_id=None, + lora_request=None, + sampling_params=SamplingParams( + skip_special_tokens=False, + spaces_between_special_tokens=False, + output_kind=request_output_kind, + stop=[], + include_stop_str_in_output=False, + logprobs=num_sample_logprobs, + prompt_logprobs=num_prompt_logprobs, + )) for idx, (prompt, prompt_tokens) in enumerate( + zip(dummy_test_vectors.prompt_strings, + dummy_test_vectors.prompt_tokens)) + ] + + # Add requests to the detokenizer. + for request in requests: + output_processor.add_request(request) + + gen_tokens = {} + gen_logprobs = {} + gen_prompt_logprobs = {} + gen_cumulative_logprobs = {} + while True: + # Mock output from the EngineCore. + outputs = engine_core.get_outputs() + if len(outputs) == 0: + break + + # Step the logprobs processor. + processed_outputs = output_processor.process_outputs(outputs) + request_outputs = processed_outputs.request_outputs + requests_to_abort = processed_outputs.reqs_to_abort + assert len(requests_to_abort) == 0 + + # Update tracking. 
+ for request_output in request_outputs: + request_id = request_output.request_id + new_tokens = request_output.outputs[0].token_ids + prompt_logprobs = request_output.prompt_logprobs + logprobs = request_output.outputs[0].logprobs + gen_cumulative_logprobs[request_id] = request_output.outputs[ + 0].cumulative_logprob + if request_id not in gen_logprobs: + # Start tracking sample and prompt logprobs for this request + gen_tokens[request_id] = new_tokens + gen_logprobs[request_id] = logprobs + gen_prompt_logprobs[request_id] = prompt_logprobs + else: + # Extend logprobs tracker + gen_tokens[request_id].extend(new_tokens) + lp = gen_logprobs[request_id] + plp = gen_prompt_logprobs[request_id] + if lp: + lp.extend(logprobs) + if plp: + plp.extend(prompt_logprobs) + + # Confirmed tracked logprobs match what we expect + _validate_logprobs(gen_tokens, gen_logprobs, gen_prompt_logprobs, + gen_cumulative_logprobs, dummy_test_vectors, + request_id_list, num_sample_logprobs, + num_prompt_logprobs) + + assert output_processor.get_num_unfinished_requests() == 0 + assert not output_processor.has_unfinished_requests() + + @pytest.mark.parametrize("include_stop_str_in_output", [True, False]) -def test_stop_string(include_stop_str_in_output: bool): - output_processor = OutputProcessor(TOKENIZER_GROUP, log_stats=False) - engine_core = MockEngineCore(GENERATION_TOKENS) +@pytest.mark.parametrize("num_sample_logprobs", + [None, NUM_SAMPLE_LOGPROBS_UNDER_TEST]) +@pytest.mark.parametrize("num_prompt_logprobs", + [None, NUM_PROMPT_LOGPROBS_UNDER_TEST]) +def test_stop_string(include_stop_str_in_output: bool, + num_sample_logprobs: Optional[int], + num_prompt_logprobs: Optional[int], dummy_test_vectors): + output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group, + log_stats=False) + engine_core = MockEngineCore( + tokens_list=dummy_test_vectors.generation_tokens, + generated_logprobs_raw=dummy_test_vectors.generation_logprobs + if num_sample_logprobs else None, + 
prompt_logprobs_raw=dummy_test_vectors.prompt_logprobs + if num_prompt_logprobs else None) # Make N requests. + request_id_list = [ + f"request-{idx}" + for idx in range(len(dummy_test_vectors.prompt_strings)) + ] requests = [ EngineCoreRequest( - request_id=f"request-{idx}", + request_id=request_id_list[idx], prompt=prompt, prompt_token_ids=prompt_tokens, arrival_time=0, @@ -166,9 +507,11 @@ def test_stop_string(include_stop_str_in_output: bool): output_kind=RequestOutputKind.DELTA, stop=STOP_STRINGS, include_stop_str_in_output=include_stop_str_in_output, - )) for idx, ( - prompt, - prompt_tokens) in enumerate(zip(PROMPT_STRINGS, PROMPT_TOKENS)) + logprobs=num_sample_logprobs, + prompt_logprobs=num_prompt_logprobs, + )) for idx, (prompt, prompt_tokens) in enumerate( + zip(dummy_test_vectors.prompt_strings, + dummy_test_vectors.prompt_tokens)) ] # Add requests to the detokenizer. @@ -176,6 +519,10 @@ def test_stop_string(include_stop_str_in_output: bool): output_processor.add_request(request) gen_strings = {} + gen_tokens = {} + gen_logprobs = {} + gen_prompt_logprobs = {} + gen_cumulative_logprobs = {} aborted = [] while True: # Mock output from the EngineCore. 
@@ -199,14 +546,29 @@ def test_stop_string(include_stop_str_in_output: bool): request_id = request_output.request_id new_text = request_output.outputs[0].text + new_tokens = request_output.outputs[0].token_ids + prompt_logprobs = request_output.prompt_logprobs + logprobs = request_output.outputs[0].logprobs + gen_cumulative_logprobs[request_id] = request_output.outputs[ + 0].cumulative_logprob if request_id not in gen_strings: gen_strings[request_id] = new_text + gen_tokens[request_id] = new_tokens + gen_logprobs[request_id] = logprobs + gen_prompt_logprobs[request_id] = prompt_logprobs else: gen_strings[request_id] += new_text + gen_tokens[request_id].extend(new_tokens) + lp = gen_logprobs[request_id] + plp = gen_prompt_logprobs[request_id] + if lp: + lp.extend(logprobs) + if plp: + plp.extend(prompt_logprobs) # Confirmed tracked values matches what we expected. - for idx, (ref_gen_str, - stop_str) in enumerate(zip(GENERATION_STRINGS, STOP_STRINGS)): + for idx, (ref_gen_str, stop_str) in enumerate( + zip(dummy_test_vectors.generation_strings, STOP_STRINGS)): # Request should be aborted. 
request_id = f"request-{idx}" @@ -227,13 +589,20 @@ def test_stop_string(include_stop_str_in_output: bool): assert gen_str == ref_str_exc_stop, ( f"{gen_str=}, {ref_str_exc_stop=}") + # Confirmed tracked logprobs match what we expect + _validate_logprobs(gen_tokens, gen_logprobs, gen_prompt_logprobs, + gen_cumulative_logprobs, dummy_test_vectors, + request_id_list, num_sample_logprobs, + num_prompt_logprobs) + assert output_processor.get_num_unfinished_requests() == 0 assert not output_processor.has_unfinished_requests() -def test_iteration_stats(): - output_processor = OutputProcessor(TOKENIZER_GROUP, log_stats=True) - engine_core = MockEngineCore(GENERATION_TOKENS) +def test_iteration_stats(dummy_test_vectors): + output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group, + log_stats=True) + engine_core = MockEngineCore(dummy_test_vectors.generation_tokens) # Make N requests. requests = [ @@ -248,13 +617,13 @@ def test_iteration_stats(): eos_token_id=None, lora_request=None, sampling_params=SamplingParams(), - ) for idx, ( - prompt, - prompt_tokens) in enumerate(zip(PROMPT_STRINGS, PROMPT_TOKENS)) + ) for idx, (prompt, prompt_tokens) in enumerate( + zip(dummy_test_vectors.prompt_strings, + dummy_test_vectors.prompt_tokens)) ] # Add all requests except one to the OutputProcessor. 
- num_active = len(GENERATION_TOKENS) - 1 + num_active = len(dummy_test_vectors.generation_tokens) - 1 for request in requests[:num_active]: output_processor.add_request(request) inactive_request = requests[num_active] @@ -263,8 +632,10 @@ def test_iteration_stats(): outputs = engine_core.get_outputs()[:num_active] processed_outputs = output_processor.process_outputs(outputs) iteration_stats = processed_outputs.iteration_stats - total_prompt_tokens = sum( - [len(prompt_tokens) for prompt_tokens in PROMPT_TOKENS[:num_active]]) + total_prompt_tokens = sum([ + len(prompt_tokens) + for prompt_tokens in dummy_test_vectors.prompt_tokens[:num_active] + ]) assert iteration_stats.num_prompt_tokens == total_prompt_tokens assert iteration_stats.num_generation_tokens == num_active @@ -283,7 +654,7 @@ def test_iteration_stats(): outputs = engine_core.get_outputs()[:num_active] processed_outputs = output_processor.process_outputs(outputs) iteration_stats = processed_outputs.iteration_stats - total_prompt_tokens = len(PROMPT_TOKENS[num_active - 1]) + total_prompt_tokens = len(dummy_test_vectors.prompt_tokens[num_active - 1]) assert iteration_stats.num_prompt_tokens == total_prompt_tokens assert iteration_stats.num_generation_tokens == num_active diff --git a/tests/v1/engine/utils.py b/tests/v1/engine/utils.py new file mode 100644 index 000000000..39248ce86 --- /dev/null +++ b/tests/v1/engine/utils.py @@ -0,0 +1,382 @@ +# SPDX-License-Identifier: Apache-2.0 + +import random +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast + +from vllm.engine.arg_utils import EngineArgs +from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import ( + BaseTokenizerGroup) +from vllm.v1.engine import EngineCoreOutput, FinishReason +from vllm.v1.outputs import LogprobsLists, LogprobsTensors + +GeneralTokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] + 
+# Number of sample logprobs to request when testing sample logprobs +NUM_SAMPLE_LOGPROBS_UNDER_TEST = 5 +# Number of prompt logprobs to request when testing prompt logprobs +NUM_PROMPT_LOGPROBS_UNDER_TEST = 7 + +TOKENIZER_NAME = "mistralai/Mistral-7B-Instruct-v0.3" + +FULL_STRINGS = [ + "My name is Robert from Neural Magic and I love working on vLLM so much!", + "Red Hat is the best open source company by far across Linux, K8s, and AI.", + "Nick is the name of my brother in addition to my colleague from Red Hat.", +] +STOP_STRINGS = ["I love working on", "company by far", "brother in"] +PROMPT_LEN = 5 + +PLP_APC_UNSUPPORTED_MSG = ("Prefix caching with prompt logprobs not yet " + "supported on VLLM V1.") + +random.seed(42) + + +def _create_random_top_logprob_test_vector( + num_logprobs: int, + lower: float, + upper: float, +) -> torch.Tensor: + """Create a random vector of top logprob float values. + + Use to create fake sample logprobs for testing. + + Note that a real production scenario would require + logprobs to be sorted in descending order, something + which is omitted in this function. + + Args: + num_logprobs: number of top logprobs + lower: lower range of logprob float values + upper: upper range of logprob float values + + Returns: + 1D length-`num_logprobs` torch Tensor of float logprob values + """ + return torch.rand(num_logprobs) * (upper - lower) + lower + + +def _create_random_top_logprob_test_matrix( + shape: Tuple, + lower: float, + upper: float, +) -> torch.Tensor: + """Create a random matrix of top logprob float values. + + Use to create fake prompt logprobs for testing. + + Note that a real production scenario would require + logprobs to be sorted in descending order along rows, + something which is omitted in this function. 
+ + Args: + shape: (num_tokens,num_logprobs) tuple representing + matrix shape + lower: lower range of logprob float values + upper: upper range of logprob float values + + Returns: + 2D num_tokens x num_logprobs torch Tensor of float logprob values + """ + return torch.rand(*shape) * (upper - lower) + lower + + +def _create_random_top_token_test_vector( + num_logprobs: int, + lower: int, + upper: int, + sampled_token_id: int, + adjust_num_logprobs: bool = True) -> Tuple[torch.Tensor, int]: + """Create a random vector of top logprob token indices + + Use to create fake sample logprobs for testing. The sampled token + ID must always be one of the top logprobs, which this dummy test + vector generator enforces. OpenAI API + compatible engines must be able to return an additional sample + logprob for the sampled token if the sampled token was not + among the top sample logprobs; `adjust_num_logprobs` emulates + this behavior by increasing the vector length by 1 if + `adjust_num_logprobs` is set. 
+ + Args: + num_logprobs: number of top logprobs + lower: lower range of token ids + upper: upper range of token ids + sampled_token_id: the token actually sampled + adjust_num_logprobs: if True, emulate situation where sampled + token logprob must be injected into top + logprobs + + Returns: + 1D length-x torch Tensor of token ids where x is + `num_logprobs+1` if `adjust_num_logprobs` and + `num_logprobs` otherwise + sampled_token_rank: the rank of sampled_token_id in the vocab + vector when sorted in descending order by + logprob + """ + + # Calculate the final number of logprobs required + total_logprobs = num_logprobs + 1 if adjust_num_logprobs else num_logprobs + + # Generate random indices using torch + choice_tensor = torch.randperm(upper - lower)[:total_logprobs] + lower + + # Ensure the sampled token ID is included in the tensor + choice_tensor[0] = sampled_token_id + + # Check if the sampled_token_id occurs in choice_tensor[1:] + if sampled_token_id in choice_tensor[1:]: + sampled_token_rank = (choice_tensor[1:] == sampled_token_id).nonzero( + as_tuple=True)[0].item() + else: + # If not found, assign a random int between num_logprobs and 50700 + sampled_token_rank = random.randint(num_logprobs, 50700) + + return choice_tensor, sampled_token_rank + + +def _create_random_top_token_test_matrix( + shape: Tuple[int, int], + lower: int, + upper: int, + tokens_list: List[int], +) -> Tuple[torch.Tensor, torch.Tensor]: + """Create a random matrix of top logprob token indices + + Use to create fake prompt logprobs for testing. + + Token ids are generated randomly and sampled without + replacement. 
+ + Args: + shape: (num_tokens, num_logprobs) tuple representing + matrix shape + lower: lower range of token ids + upper: upper range of token ids + + Returns: + Tuple containing: + - 2D num_tokens x num_logprobs+1 torch Tensor of token ids + - 1D tensor of ranks of prompt tokens in their respective + rows, or random values + """ + num_elements = shape[0] * shape[1] + choice_tensor = torch.randperm(upper - lower)[:num_elements] + lower + matrix = torch.cat( + (torch.tensor(tokens_list, dtype=torch.int).unsqueeze(-1), + choice_tensor.view(shape)), + dim=1) + + # Initialize the tensor for storing the ranks + prompt_token_ranks = torch.empty(shape[0], dtype=torch.int) + + # Iterate over each row to check presence of + # tokens_list[rdx] and determine its index + for rdx in range(shape[0]): + row = matrix[rdx, + 1:] # Skip the first column as it contains the token list + token_index = (row == tokens_list[rdx]).nonzero(as_tuple=True)[0] + if token_index.numel() > 0: + prompt_token_ranks[rdx] = token_index.item() + else: + prompt_token_ranks[rdx] = random.randint(shape[1], 50700) + + return matrix, prompt_token_ranks + + +def decode_token( + tok_id: int, + tokenizer: PreTrainedTokenizer, +) -> str: + """Reproduce the process of detokenizing a token for testing purposes. + + Args: + tok_id: token id to detokenize + tokenizer: tokenizer to use for detokenization + + Returns: + string representation of token + """ + return tokenizer.convert_ids_to_tokens(tok_id) + + +def generate_dummy_sample_logprobs( + sampled_tokens_list: List, + num_logprobs: int, + tokenizer: PreTrainedTokenizer, +) -> List[Tuple[List[int], List[float], int]]: + """Generate dummy sample logprobs + + Generate a test data structure which imitates the list of sample logprobs + which would be assembled in the engine core during decode phase. 
+ + Args: + sampled_tokens_list: list of sampled tokens + num_logprobs: return `num_logprobs` or `num_logprobs+1` logprobs per token + tokenizer: model tokenizer to use for detokenization + + Returns + List of (top token ids vector, logprobs vector, sampled token rank) + Python lists tuples; in each tuple the logprobs and top token ids + vectors have the same length which is either `num_logprobs` or + `num_logprobs+1`. Sampled token rank is the rank (index+1) of the + sampled token within the vocab vector when sorted by logprob in + descending order. + """ + res = [] + for sampled_token_id in sampled_tokens_list: + ( + token_vector, + sampled_token_rank, + ) = _create_random_top_token_test_vector(num_logprobs, 0, + len(tokenizer.vocab) - 1, + sampled_token_id) + + res.append( + (token_vector, + _create_random_top_logprob_test_vector(num_logprobs + 1, -100, + 0), sampled_token_rank)) + + # Convert tensors in the list tuples to Python lists + res_list_format = [ + (log_probs_tensor.tolist(), token_ids_tensor.tolist(), + sampled_token_rank) + for log_probs_tensor, token_ids_tensor, sampled_token_rank in res + ] + + return res_list_format + + +def generate_dummy_prompt_logprobs_tensors( + prompt_tokens_list: List, + num_logprobs: int, + tokenizer: PreTrainedTokenizer, +) -> LogprobsTensors: + """Generate dummy prompt logprobs tensors + + Generate a test data structure which imitates the torch Tensors of prompt + logprobs which would be assembled in the engine core during chunked + prefill. + + Args: + prompt_tokens_list: list of prompt tokens + num_logprobs: return `num_logprobs` logprobs per token + tokenizer: model tokenizer to use for detokenization + + Returns + Single Tuple of (logprobs matrix, top token ids matrix) torch Tensor, + where both matrices have dimensions + num_prompt_tokens x num_logprobs + """ + # For now, assume the whole prompt is processed in one chunk; thus, + # the number of non-`None` prompt logprobs is `len(prompt_tokens_list)-1`. 
+ # Prior to injecting `None` at the beginning of prompt logprobs (which + # happens later in the detokenizer, not here), the prompt logprobs in + # the ith position are predicting the probability distribution of the + # prompt token in (i+1)st position. Thus, we concat + # `prompt_tokens_list[1:]` to the dummy token ids, just as the engine + # would. + num_prompt_logprobs = len(prompt_tokens_list) - 1 + ( + token_vector, + prompt_token_ranks, + ) = _create_random_top_token_test_matrix( + (num_prompt_logprobs, num_logprobs), 0, + len(tokenizer.vocab) - 1, prompt_tokens_list[1:]) + return LogprobsTensors( + token_vector, + _create_random_top_logprob_test_matrix( + (num_prompt_logprobs, num_logprobs + 1), -100, 0), + prompt_token_ranks) + + +@dataclass +class DummyOutputProcessorTestVectors: + """Dummy test vectors for output processor tests""" + tokenizer: GeneralTokenizerType + tokenizer_group: BaseTokenizerGroup + vllm_config: EngineArgs + full_tokens: List[List[int]] # Prompt + generated tokens + prompt_tokens: List[List[int]] + generation_tokens: List[List[int]] + # Each request is associated with a tuple of + # (top tokens, top logprobs, ranks) prompt logprobs tensors + prompt_logprobs: List[LogprobsTensors] + # Each request is associated with a sample logprobs; a request's + # sample logprobs are a list of (top tokens, top logprobs, ranks) + # sample logprobs tensors at each sequence position + generation_logprobs: List[List[Tuple[List[int], List[float], int]]] + prompt_strings: List[str] + prompt_strings_len: List[int] + generation_strings: List[str] + + +class MockEngineCore: + """Mock engine core outputs from premade tokens lists.""" + + def __init__( + self, + tokens_list: List[List[int]], + # For each request, for each sampled token offset, + # a tuple of + # (list of topk token ids, list of sample logprob vals, rank) + generated_logprobs_raw: Optional[List[List[Tuple[List[int], + List[float], + int]]]] = None, + # For each request, a tuple of + # (prompt 
logprob val matrix, prompt logprob tok id matrix); + # each matrix has dimensions + # (num prompt toks) x (num prompt logprobs+1) + prompt_logprobs_raw: Optional[List[LogprobsTensors]] = None, + ) -> None: + self.tokens_list = tokens_list + self.current_idx = 0 + self.generated_logprobs_raw = generated_logprobs_raw + self.do_logprobs = generated_logprobs_raw is not None + self.prompt_logprobs_raw = prompt_logprobs_raw + self.do_prompt_logprobs = prompt_logprobs_raw is not None + + def get_outputs(self) -> List[EngineCoreOutput]: + do_logprobs = self.do_logprobs + do_prompt_logprobs = self.do_prompt_logprobs + token_idx = self.current_idx + + outputs = [] + for req_idx, token_ids in enumerate(self.tokens_list): + if len(token_ids) > token_idx: + if do_logprobs: + assert self.generated_logprobs_raw is not None + (logprobs_token_ids_, logprobs_, sampled_token_ranks_) = ( + self.generated_logprobs_raw[req_idx][token_idx]) + logprobs = LogprobsLists( + [logprobs_token_ids_], + [logprobs_], + [sampled_token_ranks_], + ) + else: + logprobs = None + if do_prompt_logprobs: + if self.current_idx == 0: + assert self.prompt_logprobs_raw is not None + prompt_logprobs = self.prompt_logprobs_raw[req_idx] + else: + prompt_logprobs = None + else: + prompt_logprobs = None + output = EngineCoreOutput( + request_id=f"request-{req_idx}", + new_token_ids=[token_ids[token_idx]], + new_logprobs=logprobs, + new_prompt_logprobs_tensors=prompt_logprobs, + ) + if token_idx == len(token_ids) - 1: + output.finish_reason = FinishReason.STOP + outputs.append(output) + + self.current_idx += 1 + return outputs diff --git a/tests/v1/entrypoints/__init__.py b/tests/v1/entrypoints/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/v1/entrypoints/conftest.py b/tests/v1/entrypoints/conftest.py new file mode 100644 index 000000000..b00e168db --- /dev/null +++ b/tests/v1/entrypoints/conftest.py @@ -0,0 +1,161 @@ +# SPDX-License-Identifier: Apache-2.0 + +import pytest + + 
+@pytest.fixture +def sample_prompts(): + return [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + +@pytest.fixture +def sample_token_ids(): + return [ + [0], + [0, 1], + [0, 2, 1], + [0, 3, 1, 2], + ] + + +@pytest.fixture +def sample_regex(): + return (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" + r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)") + + +@pytest.fixture +def sample_json_schema(): + return { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "age": { + "type": "integer" + }, + "skills": { + "type": "array", + "items": { + "type": "string", + "maxLength": 10 + }, + "minItems": 3 + }, + "work_history": { + "type": "array", + "items": { + "type": "object", + "properties": { + "company": { + "type": "string" + }, + "duration": { + "type": "number" + }, + "position": { + "type": "string" + } + }, + "required": ["company", "position"] + } + } + }, + "required": ["name", "age", "skills", "work_history"] + } + + +@pytest.fixture +def sample_complex_json_schema(): + return { + "type": "object", + "properties": { + "score": { + "type": "integer", + "minimum": 0, + "maximum": 100 # Numeric range + }, + "grade": { + "type": "string", + "pattern": "^[A-D]$" # Regex pattern + }, + "email": { + "type": "string", + "pattern": "^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$" + }, + "tags": { + "type": "array", + "items": { + "type": "string", + "pattern": + "^[a-z]{1,10}$" # Combining length and pattern restrictions + } + } + }, + "required": ["score", "grade", "email", "tags"] + } + + +@pytest.fixture +def sample_definition_json_schema(): + return { + '$defs': { + 'Step': { + 'properties': { + 'explanation': { + 'title': 'Explanation', + 'type': 'string' + }, + 'output': { + 'title': 'Output', + 'type': 'string' + } + }, + 'required': ['explanation', 'output'], + 'title': 'Step', + 'type': 'object' + } + }, + 'properties': { + 'steps': { + 'items': { + '$ref': '#/$defs/Step' 
+ }, + 'title': 'Steps', + 'type': 'array' + }, + 'final_answer': { + 'title': 'Final Answer', + 'type': 'string' + } + }, + 'required': ['steps', 'final_answer'], + 'title': 'MathReasoning', + 'type': 'object' + } + + +@pytest.fixture +def sample_guided_choice(): + return [ + "Python", "Java", "JavaScript", "C++", "C#", "PHP", "TypeScript", + "Ruby", "Swift", "Kotlin" + ] + + +@pytest.fixture +def sample_sql_statements(): + return (""" +start: select_statement +select_statement: "SELECT" column "from" table "where" condition +column: "col_1" | "col_2" +table: "table_1" | "table_2" +condition: column "=" number +number: "1" | "2" +""") diff --git a/tests/v1/entrypoints/openai/test_completion.py b/tests/v1/entrypoints/openai/test_completion.py new file mode 100644 index 000000000..ef46a16ef --- /dev/null +++ b/tests/v1/entrypoints/openai/test_completion.py @@ -0,0 +1,475 @@ +# SPDX-License-Identifier: Apache-2.0 + +import re +from typing import Dict, List, Optional + +import openai # use the official client for correctness check +import pytest +import pytest_asyncio +from openai import BadRequestError + +from tests.utils import RemoteOpenAIServer +from vllm.transformers_utils.tokenizer import get_tokenizer + +# any model with a chat template should work here +MODEL_NAME = "facebook/opt-125m" + + +@pytest.fixture(scope="module") +def default_server_args(): + return [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--max-num-seqs", + "128", + "--enforce-eager" + ] + + +@pytest.fixture(scope="module", + params=[["--no-enable-prefix-caching"], + [ + "--no-enable-prefix-caching", + "--disable-frontend-multiprocessing" + ]]) +def server(default_server_args, request): + if request.param: + default_server_args.extend(request.param) + with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def client(server): + async 
with server.get_async_client() as async_client: + yield async_client + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_single_completion(client: openai.AsyncOpenAI, + model_name: str) -> None: + completion = await client.completions.create(model=model_name, + prompt="Hello, my name is", + max_tokens=5, + temperature=0.0) + + assert completion.id is not None + assert completion.choices is not None and len(completion.choices) == 1 + + choice = completion.choices[0] + assert len(choice.text) >= 5 + assert choice.finish_reason == "length" + assert completion.usage == openai.types.CompletionUsage( + completion_tokens=5, prompt_tokens=6, total_tokens=11) + + # test using token IDs + completion = await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + ) + assert len(completion.choices[0].text) >= 1 + assert completion.choices[0].prompt_logprobs is None + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str): + # test using token IDs + completion = await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + logprobs=None, + ) + choice = completion.choices[0] + assert choice.logprobs is None + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_zero_logprobs(client: openai.AsyncOpenAI, model_name: str): + # test using token IDs + completion = await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + logprobs=0, + ) + choice = completion.choices[0] + assert choice.logprobs is not None + assert choice.logprobs.token_logprobs is not None + assert choice.logprobs.top_logprobs is not None + assert len(choice.logprobs.top_logprobs[0]) == 1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 
"model_name", + [MODEL_NAME], +) +async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str): + # test using token IDs + completion = await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + logprobs=5, + ) + choice = completion.choices[0] + assert choice.logprobs is not None + assert choice.logprobs.token_logprobs is not None + assert choice.logprobs.top_logprobs is not None + assert 5 <= len(choice.logprobs.top_logprobs[0]) <= 6 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI, + model_name: str) -> None: + + with pytest.raises( + (openai.BadRequestError, openai.APIError)): # test using token IDs + await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + # vLLM has higher default max_logprobs (20 instead of 5) to support + # both Completion API and Chat Completion API + logprobs=21, + ) + ... + with pytest.raises( + (openai.BadRequestError, openai.APIError)): # test using token IDs + stream = await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + # vLLM has higher default max_logprobs (20 instead of 5) to support + # both Completion API and Chat Completion API + logprobs=30, + stream=True, + ) + async for chunk in stream: + ... 
+ + # the server should still work afterwards + completion = await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + ) + assert len(completion.choices[0].text) >= 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name, prompt_logprobs", [(MODEL_NAME, -1), + (MODEL_NAME, 0), + (MODEL_NAME, 1), + (MODEL_NAME, None)]) +async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI, + model_name: str, + prompt_logprobs: Optional[int]): + params: Dict = { + "prompt": ["A robot may not injure another robot", "My name is"], + "model": model_name, + } + if prompt_logprobs is not None: + params["extra_body"] = {"prompt_logprobs": prompt_logprobs} + + if prompt_logprobs is not None and prompt_logprobs < 0: + with pytest.raises(BadRequestError): + await client.completions.create(**params) + else: + completion = await client.completions.create(**params) + if prompt_logprobs is not None: + assert completion.choices[0].prompt_logprobs is not None + assert len(completion.choices[0].prompt_logprobs) > 0 + + assert completion.choices[1].prompt_logprobs is not None + assert len(completion.choices[1].prompt_logprobs) > 0 + + else: + assert completion.choices[0].prompt_logprobs is None + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_completion_streaming(client: openai.AsyncOpenAI, + model_name: str) -> None: + prompt = "What is an LLM?" 
+ + single_completion = await client.completions.create( + model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + ) + single_output = single_completion.choices[0].text + stream = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True) + chunks: List[str] = [] + finish_reason_count = 0 + async for chunk in stream: + chunks.append(chunk.choices[0].text) + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + # finish reason should only return in last block + assert finish_reason_count == 1 + assert chunk.choices[0].finish_reason == "length" + assert chunk.choices[0].text + assert "".join(chunks) == single_output + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_completion_stream_options(client: openai.AsyncOpenAI, + model_name: str): + prompt = "What is the capital of France?" + + # Test stream=True, stream_options= + # {"include_usage": False, "continuous_usage_stats": False} + stream = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True, + stream_options={ + "include_usage": False, + "continuous_usage_stats": + False, + }) + + async for chunk in stream: + assert chunk.usage is None + + # Test stream=True, stream_options= + # {"include_usage": False, "continuous_usage_stats": True} + stream = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True, + stream_options={ + "include_usage": False, + "continuous_usage_stats": + True, + }) + async for chunk in stream: + assert chunk.usage is None + + # Test stream=True, stream_options= + # {"include_usage": True, "continuous_usage_stats": False} + stream = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True, + stream_options={ + "include_usage": True, + "continuous_usage_stats": + False, 
+ }) + async for chunk in stream: + if chunk.choices[0].finish_reason is None: + assert chunk.usage is None + else: + assert chunk.usage is None + final_chunk = await stream.__anext__() + assert final_chunk.usage is not None + assert final_chunk.usage.prompt_tokens > 0 + assert final_chunk.usage.completion_tokens > 0 + assert final_chunk.usage.total_tokens == ( + final_chunk.usage.prompt_tokens + + final_chunk.usage.completion_tokens) + assert final_chunk.choices == [] + + # Test stream=True, stream_options= + # {"include_usage": True, "continuous_usage_stats": True} + stream = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True, + stream_options={ + "include_usage": True, + "continuous_usage_stats": + True, + }) + async for chunk in stream: + assert chunk.usage is not None + assert chunk.usage.prompt_tokens > 0 + assert chunk.usage.completion_tokens > 0 + assert chunk.usage.total_tokens == (chunk.usage.prompt_tokens + + chunk.usage.completion_tokens) + if chunk.choices[0].finish_reason is not None: + final_chunk = await stream.__anext__() + assert final_chunk.usage is not None + assert final_chunk.usage.prompt_tokens > 0 + assert final_chunk.usage.completion_tokens > 0 + assert final_chunk.usage.total_tokens == ( + final_chunk.usage.prompt_tokens + + final_chunk.usage.completion_tokens) + assert final_chunk.choices == [] + + # Test stream=False, stream_options= + # {"include_usage": None} + with pytest.raises(BadRequestError): + await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=False, + stream_options={"include_usage": None}) + + # Test stream=False, stream_options= + # {"include_usage": True} + with pytest.raises(BadRequestError): + await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=False, + stream_options={"include_usage": True}) + + # Test stream=False, stream_options= + # 
{"continuous_usage_stats": None} + with pytest.raises(BadRequestError): + await client.completions.create( + model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=False, + stream_options={"continuous_usage_stats": None}) + + # Test stream=False, stream_options= + # {"continuous_usage_stats": True} + with pytest.raises(BadRequestError): + await client.completions.create( + model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=False, + stream_options={"continuous_usage_stats": True}) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str): + # test both text and token IDs + for prompts in (["Hello, my name is"] * 2, [[0, 0, 0, 0, 0]] * 2): + # test simple list + batch = await client.completions.create( + model=model_name, + prompt=prompts, + max_tokens=5, + temperature=0.0, + ) + assert len(batch.choices) == 2 + assert batch.choices[0].text == batch.choices[1].text + + # test n = 2 + batch = await client.completions.create( + model=model_name, + prompt=prompts, + n=2, + max_tokens=5, + temperature=0.0, + extra_body=dict( + # NOTE: this has to be true for n > 1 in vLLM, but + # not necessary for official client. 
+ use_beam_search=True), + ) + assert len(batch.choices) == 4 + assert batch.choices[0].text != batch.choices[ + 1].text, "beam search should be different" + assert batch.choices[0].text == batch.choices[ + 2].text, "two copies of the same prompt should be the same" + assert batch.choices[1].text == batch.choices[ + 3].text, "two copies of the same prompt should be the same" + + # test streaming + batch = await client.completions.create( + model=model_name, + prompt=prompts, + max_tokens=5, + temperature=0.0, + stream=True, + ) + texts = [""] * 2 + async for chunk in batch: + assert len(chunk.choices) == 1 + choice = chunk.choices[0] + texts[choice.index] += choice.text + assert texts[0] == texts[1] + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +@pytest.mark.parametrize("logprobs_arg", [1, 0]) +async def test_echo_logprob_completion(client: openai.AsyncOpenAI, + model_name: str, logprobs_arg: int): + tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) + # test using text and token IDs + for prompt in ("Hello, my name is", [0, 0, 0, 0, 0]): + completion = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + echo=True, + logprobs=logprobs_arg) + + prompt_text = tokenizer.decode(prompt) if isinstance(prompt, + list) else prompt + assert re.search(r"^" + prompt_text, completion.choices[0].text) + logprobs = completion.choices[0].logprobs + assert logprobs is not None + assert len(logprobs.text_offset) > 5 + assert (len(logprobs.token_logprobs) > 5 + and logprobs.token_logprobs[0] is None) + assert (len(logprobs.top_logprobs) > 5 + and logprobs.top_logprobs[0] is None) + for top_logprobs in logprobs.top_logprobs[1:]: + assert max(logprobs_arg, + 1) <= len(top_logprobs) <= logprobs_arg + 1 + assert len(logprobs.tokens) > 5 diff --git a/tests/v1/sample/test_logprobs.py b/tests/v1/sample/test_logprobs.py new file mode 100644 index 000000000..86c576cd7 --- /dev/null +++ 
b/tests/v1/sample/test_logprobs.py @@ -0,0 +1,392 @@ +# SPDX-License-Identifier: Apache-2.0 + +import itertools +from typing import List, Tuple + +import pytest +import torch + +from tests.kernels.utils import override_backend_env_variable +from tests.v1.sample.utils import ( + assert_incr_detok_str_matches_non_incr_detok_str, + compute_correct_cumulative_logprob, get_test_batch) +from vllm import SamplingParams + +from ...conftest import VllmRunner + +MODEL = "meta-llama/Llama-3.2-1B" +DTYPE = "half" + + +@pytest.fixture(scope="module") +def vllm_model(vllm_runner): + with vllm_runner( + MODEL, + dtype=DTYPE, + max_logprobs=7, + # Very small number of batched tokens to ensure + # that we test chunking. + max_num_batched_tokens=16, + max_num_seqs=16, + max_model_len=128, + enforce_eager=True, + #TODO: enable this once we support it for + # prompt logprobs. + enable_prefix_caching=False, + gpu_memory_utilization=0.5, + ) as vllm_model: + yield vllm_model + + +@pytest.fixture(scope="module") +def hf_model(hf_runner): + with hf_runner(MODEL, dtype=DTYPE) as hf_model: + yield hf_model + + +def _repeat_logprob_config( + test_prompts, + logprob_prompt_logprob_list: List[Tuple], +) -> List[Tuple]: + """Ensure each test prompt has a logprob config. + + A logprob config specifies the optional (i.e. + may-be-`None`) number of sample logprobs and + the optional number of prompt logprobs. + + If more test prompts than logprob configs are + provided, the provided logprob configs are + tiled to match the number of test prompts. + + If fewer test prompts than logprob configs + are provided, the list of logprob configs + is truncated to match the number of test + prompts. + + Otherwise, the list of logprob configs + is returned as-is. 
+ + Args: + test_prompts: list of prompts under test + logprob_prompt_logprob_list: list of + (optional num sample logprob, + optional num prompt logprob) + tuples + + Returns: + List of + (optional num sample logprob,optional num prompt logprob) + tuples which is either identical to + `logprob_prompt_logprob_list`, or else repeats + `logprob_prompt_logprob_list` enough times to match the + number of `test_prompts`, or else is truncated to match + the number of `test_prompts` + """ + num_test_prompts = len(test_prompts) + # Make sure there is a logprobs configuration for each test prompt + logprob_prompt_logprob_list = list( + itertools.islice(itertools.cycle(logprob_prompt_logprob_list), + num_test_prompts)) + # Now the number of prompts should match the number of sample params combos + assert num_test_prompts == len(logprob_prompt_logprob_list) + return logprob_prompt_logprob_list + + +def _test_case_get_logprobs_and_prompt_logprobs( + hf_model, + vllm_model, + batch_logprobs_composition: str, + temperature: float, + example_prompts, +) -> None: + test_prompts = example_prompts + + max_tokens = 5 + hf_outputs = hf_model.generate_greedy( + test_prompts, + max_tokens=max_tokens, + ) + hf_logprobs = hf_model.generate_greedy_logprobs( + test_prompts, + max_tokens=max_tokens, + ) + + # Batch has mixed sample params + # (different logprobs/prompt logprobs combos) + logprob_prompt_logprob_list = get_test_batch(batch_logprobs_composition) + + # Ensure that each test prompt has a logprob config for testing + logprob_prompt_logprob_list = _repeat_logprob_config( + test_prompts, logprob_prompt_logprob_list) + # Generate SamplingParams + vllm_sampling_params = [ + SamplingParams(max_tokens=max_tokens, + logprobs=num_lp, + prompt_logprobs=num_plp, + temperature=temperature, + seed=1984) + for num_lp, num_plp in logprob_prompt_logprob_list + ] + + vllm_results = vllm_model.model.generate( + test_prompts, sampling_params=vllm_sampling_params) + + for vllm_result, hf_logprob, 
hf_output, logprob_prompt_logprob in zip( + vllm_results, hf_logprobs, hf_outputs, + logprob_prompt_logprob_list): + + # Extract request-level (prompt)logprobs config + num_top_logprobs, num_top_prompt_logprobs = logprob_prompt_logprob + + # Test whether sampled token output is consistent between vLLM and HF + # vLLM prompt+completion should match HF output + if temperature == 0.0: + assert (vllm_result.prompt_token_ids + + vllm_result.outputs[0].token_ids == hf_output[0]) + else: + # Sampled tokens won't match if not greedy + assert (vllm_result.prompt_token_ids == hf_output[0] + [:len(vllm_result.prompt_token_ids)]) + + # Validate sample logprobs + if num_top_logprobs is not None: + assert num_top_logprobs is not None + # Confirm that the structure of the sample logprobs in the result is + # correct + assert vllm_result.outputs[0].logprobs is not None + assert len(vllm_result.outputs[0].logprobs) == max_tokens + for logprobs, token_id in zip(vllm_result.outputs[0].logprobs, + vllm_result.outputs[0].token_ids): + assert logprobs is not None + + # Confirm that the output token appears among the logprobs + assert token_id in logprobs + token_in_topk = logprobs[token_id].rank <= num_top_logprobs + + # If the output token is not included in the top K + # logprob, it can return 1 more data + if token_in_topk and num_top_logprobs != 0: + assert len(logprobs) == num_top_logprobs + else: + assert len(logprobs) == num_top_logprobs + 1 + + if num_top_logprobs > 0: + # We should have an entry for each of the topk ranks + all_ranks = {lp.rank for lp in logprobs.values()} + assert all(r in all_ranks + for r in range(1, num_top_logprobs + 1)) + + output_text = vllm_result.outputs[0].text + output_string_from_most_likely_tokens_lst: List[str] = [] + for top_logprobs in vllm_result.outputs[0].logprobs: + top_logprob = next(iter(top_logprobs.values())) + output_string_from_most_likely_tokens_lst.append( + top_logprob.decoded_token) + + output_string_from_most_likely_tokens = 
"".join( + output_string_from_most_likely_tokens_lst) + assert_incr_detok_str_matches_non_incr_detok_str( + output_text, output_string_from_most_likely_tokens, + "The output text from the top logprob for each token " + "position should be the same as the output text in the " + "result.") + + # Compare vLLM sample logprobs to HF + vllm_sample_logprobs = vllm_result.outputs[0].logprobs + for i, top_logprobs in enumerate(vllm_sample_logprobs): + for token_id, sample_logprob in top_logprobs.items(): + if temperature == 0.0 or i == 0: + logprob = sample_logprob.logprob + torch.testing.assert_close( + logprob, + hf_logprob[i][-1][token_id].item(), + atol=1e-2, + rtol=1e-2) + assert isinstance( + sample_logprob.decoded_token, + str), ("The token should be decoded by the time it is" + " returned to the user.") + + # At this point we know the sample logprobs are correct for this + # request. Validate that cumulative_logprob is actually the sum. + # For each request, assert that the returned cumulative logprob + # matches the correct value, which is computed below. 
+ torch.testing.assert_close( + vllm_result.outputs[0].cumulative_logprob, + compute_correct_cumulative_logprob(vllm_result.outputs[0]), + atol=1e-6, + rtol=1e-6) + else: + # Logprobs disabled for this request; should be None + assert vllm_result.outputs[0].logprobs is None + + # Validate prompt logprobs + if num_top_prompt_logprobs is not None: + # Confirm that structure of prompt logprobs in result is correct + assert vllm_result.prompt_logprobs is not None + # - The first prompt logprob is always None + assert vllm_result.prompt_logprobs[0] is None + # - Prompt logprobs are returned for all indices in + # the prompt + assert len(vllm_result.prompt_logprobs) == len( + vllm_result.prompt_token_ids) + for prompt_logprobs, prompt_token_id in zip( + vllm_result.prompt_logprobs[1:], + vllm_result.prompt_token_ids[1:]): + assert prompt_logprobs is not None + + # Confirm that the prompt token appears among the logprobs + assert prompt_token_id in prompt_logprobs + token_in_topk = prompt_logprobs[ + prompt_token_id].rank <= num_top_prompt_logprobs + + # If the prompt token is not included in the top K + # logprob, it can return 1 more data + if token_in_topk and num_top_prompt_logprobs != 0: + assert len(prompt_logprobs) == num_top_prompt_logprobs + else: + assert len(prompt_logprobs) == num_top_prompt_logprobs + 1 + + if num_top_prompt_logprobs > 0: + # We should have an entry for each of the topk ranks + all_ranks = {lp.rank for lp in prompt_logprobs.values()} + assert all(r in all_ranks + for r in range(1, num_top_prompt_logprobs + 1)) + + # Compare prompt logprobs to HF + # The first prompt logprob is always None, so we compare it from + # 1:. 
+ vllm_prompt_logprobs = vllm_result.prompt_logprobs[1:] + for i, vllm_prompt_logprob_dict in enumerate(vllm_prompt_logprobs): + for token_id, logprob in vllm_prompt_logprob_dict.items(): + torch.testing.assert_close( + logprob.logprob, + hf_logprob[0][i][token_id].item(), + atol=2e-2, + rtol=2e-2) + else: + assert vllm_result.prompt_logprobs is None + + +#@pytest.mark.skip_global_cleanup +@pytest.mark.parametrize("batch_logprobs_composition", + ["NONE", "SAMPLE", "PROMPT", "SAMPLE_PROMPT"]) +@pytest.mark.parametrize("temperature", [0.0, 2.0]) +def test_get_logprobs_and_prompt_logprobs( + hf_model, + vllm_model, + batch_logprobs_composition: str, + temperature: float, + example_prompts, +) -> None: + """Test V1 Engine logprobs & prompt logprobs + + Exercise a variety of combinations of `logprobs` and `prompt_logprobs` + settings and validate that + * The generated logprobs and prompt logprobs are consistent with the + configuration settings, in terms of whether or not the logprobs + (of either type) were requested and how many were requested + * The generated logprobs are consistent with the generated tokens + * The generated (prompt)logprobs are consistent with HuggingFace + (prompt)logprobs, as a reference + + batch_logprobs_composition controls the logprobs configurations for + requests in the batch under test. 
+ + Args: + hf_model + vllm_model + batch_logprobs_composition: logprobs configuration for test batch + example_prompts + monkeypatch + """ + _test_case_get_logprobs_and_prompt_logprobs( + hf_model=hf_model, + vllm_model=vllm_model, + batch_logprobs_composition=batch_logprobs_composition, + temperature=temperature, + example_prompts=example_prompts) + + +def test_max_logprobs(monkeypatch): + """vLLM v1 engine should fail a request with `logprobs > max_logprobs` + + Should also fail for `prompt_logprobs > max_logprobs` + + Args: + monkeypatch + """ + override_backend_env_variable(monkeypatch, "FLASH_ATTN") + + runner = VllmRunner("facebook/opt-125m", + max_logprobs=1, + enable_prefix_caching=False, + max_model_len=256) + vllm_sampling_params = SamplingParams(logprobs=1) + # should pass + runner.generate(["Hello world"], sampling_params=vllm_sampling_params) + + bad_sampling_params = SamplingParams(logprobs=2) + with pytest.raises(ValueError): + runner.generate(["Hello world"], sampling_params=bad_sampling_params) + + +def test_none_logprobs(vllm_model, example_prompts, monkeypatch): + """Engine should return `logprobs` and `prompt_logprobs` as `None` + + Args: + vllm_model: vLLM model fixture + example_prompts: list of example prompts (test fixture) + monkeypatch: supports editing env vars and rolling back changes + after the test + """ + max_tokens = 5 + + sampling_params_logprobs_none = SamplingParams(max_tokens=max_tokens, + logprobs=None, + prompt_logprobs=None, + temperature=0.0) + results_logprobs_none = vllm_model.model.generate( + example_prompts, sampling_params=sampling_params_logprobs_none) + + for i in range(len(results_logprobs_none)): + # Check sample logprobs are None + assert results_logprobs_none[i].outputs[0].logprobs is None + assert results_logprobs_none[i].outputs[0].cumulative_logprob is None + # Check prompt logprobs are None + assert results_logprobs_none[i].prompt_logprobs is None + + +def test_zero_logprobs(vllm_model, example_prompts, 
monkeypatch): + """Engine should return sampled token and prompt token logprobs + + Args: + vllm_model: vLLM model fixture + example_prompts: list of example prompts (test fixture) + monkeypatch: supports editing env vars and rolling back changes + after the test + """ + max_tokens = 5 + + sampling_params_logprobs_zero = SamplingParams(max_tokens=max_tokens, + logprobs=0, + prompt_logprobs=0, + temperature=0.0) + results_logprobs_zero = vllm_model.model.generate( + example_prompts, sampling_params=sampling_params_logprobs_zero) + + for i in range(len(results_logprobs_zero)): + # Check that there is one sample logprob dict for each + # sample token + logprobs = results_logprobs_zero[i].outputs[0].logprobs + prompt_logprobs = results_logprobs_zero[i].prompt_logprobs + sampled_token_ids = results_logprobs_zero[i].outputs[0].token_ids + prompt_token_ids = results_logprobs_zero[i].prompt_token_ids + assert logprobs is not None + assert len(sampled_token_ids) == len(logprobs) + assert results_logprobs_zero[i].outputs[ + 0].cumulative_logprob is not None + # Check that there is one prompt logprob dict for each + # prompt token + assert prompt_logprobs is not None + assert len(prompt_token_ids) == len(prompt_logprobs) diff --git a/tests/v1/sample/test_logprobs_e2e.py b/tests/v1/sample/test_logprobs_e2e.py new file mode 100644 index 000000000..28c177fd4 --- /dev/null +++ b/tests/v1/sample/test_logprobs_e2e.py @@ -0,0 +1,52 @@ +# SPDX-License-Identifier: Apache-2.0 + +import lm_eval + +from ...utils import RemoteOpenAIServer + +# arc-easy uses prompt_logprobs=1, logprobs=1 +TASK = "arc_easy" +FILTER = "acc_norm,none" +RTOL = 0.03 +EXPECTED_VALUE = 0.62 + +# FIXME(rob): enable prefix caching once supported. 
+MODEL = "meta-llama/Llama-3.2-1B" +MODEL_ARGS = f"pretrained={MODEL},enforce_eager=True,enable_prefix_caching=False" # noqa: E501 +SERVER_ARGS = [ + "--enforce_eager", "--no_enable_prefix_caching", "--disable-log-requests" +] +NUM_CONCURRENT = 100 + + +def test_prompt_logprobs_e2e(): + results = lm_eval.simple_evaluate(model="vllm", + model_args=MODEL_ARGS, + tasks=TASK, + batch_size="auto") + + measured_value = results["results"][TASK][FILTER] + assert (measured_value - RTOL < EXPECTED_VALUE + and measured_value + RTOL > EXPECTED_VALUE + ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" + + +def test_promt_logprobs_e2e_server(): + with RemoteOpenAIServer(MODEL, SERVER_ARGS) as remote_server: + url = f"{remote_server.url_for('v1')}/completions" + + model_args = ( + f"model={MODEL}," + f"base_url={url}," + f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False") + + results = lm_eval.simple_evaluate( + model="local-completions", + model_args=model_args, + tasks=TASK, + ) + + measured_value = results["results"][TASK][FILTER] + assert (measured_value - RTOL < EXPECTED_VALUE + and measured_value + RTOL > EXPECTED_VALUE + ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" diff --git a/tests/v1/sample/utils.py b/tests/v1/sample/utils.py new file mode 100644 index 000000000..e1465b123 --- /dev/null +++ b/tests/v1/sample/utils.py @@ -0,0 +1,120 @@ +# SPDX-License-Identifier: Apache-2.0 + +import re +from typing import List, Tuple + +from vllm import CompletionOutput + + +def get_test_batch(batch_logprobs_composition: str) -> List[Tuple]: + """Generate logprobs configs for a batch of requests + + A given request's logprobs configuration is (1) num_sample_logprobs and (2) + num_prompt_logprobs. The batch logprobs configuration is the list of request + logprobs configs. 
+ + batch_logprobs_composition == "NONE" yields a batch with no sample or prompt + logprobs + + batch_logprobs_composition == "SAMPLE" yields a batch with some requests + configured for sample logprobs only, and others configured for no logprobs + + batch_logprobs_composition == "PROMPT" yields a batch with some requests + configured for prompt logprobs only, and others configured for no logprobs + + batch_logprobs_composition == "SAMPLE_PROMPT" yields a batch with some + requests configured for sample logprobs and prompt logprobs, some configured + for only sample logprobs or only prompt logprobs, and some configured for + no logprobs + + Args: + batch_logprobs_composition: types of logprobs configs to include in batch + + Returns: + + List of (Optional[num_sample_logprobs], Optional[num_prompt_logprobs]) + tuples + """ + if batch_logprobs_composition == "NONE": + # No requests with sample or prompt logprobs + return [(None, None)] + elif batch_logprobs_composition == "SAMPLE": + # Requests requiring sample logprobs or no logprobs + return [ + (None, None), + (0, None), + (5, None), + (3, None), + ] + elif batch_logprobs_composition == "PROMPT": + # Requests requiring prompt logprobs or no logprobs + return [ + (None, None), + (None, 0), + (None, 6), + (None, 5), + ] + elif batch_logprobs_composition == "SAMPLE_PROMPT": + # Requests requiring either no logprobs, just + # sample logprobs, just prompt logprobs, or + # both sample and prompt logprobs + return [ + (None, None), + (0, None), + (5, None), + (3, None), + (0, 3), + (6, 0), + (6, 3), + (None, 6), + (None, 5), + (None, 0), + ] + else: + raise ValueError("Invalid logprobs batch configuration for test.") + + +def assert_incr_detok_str_matches_non_incr_detok_str( + incremental_detokenization_str: str, + non_incremental_detokenization_str: str, + msg: str, +) -> None: + """Compare incrementally detok. text to non-incrementally detok. 
text + + Fail if the strings mismatch after non-alphanumeric characters are stripped + out. + + Rationale: incremental detokenization in the text generation process allows + the tokenizer to adjust the next token text output based on the token's + context in the string. However, logprobs detokenization detokenizes each + token individually, and the resultant strings may include some + non-alphanumeric placeholder characters where there could be i.e. + whitespace. So, this function compares only the alphanumeric text + between two strings and fails if there is a mismatch, which helps + with validating logprobs detokenization. + + Args: + incremental_detokenization_str: incrementally-detokenized generated text + non_incremental_detokenization_str: non-incrementally-detokenized logprob + tokens + msg: error message if `assert` fails + """ + rgx = r'[^a-zA-Z0-9]+' + assert (re.sub(rgx, '', incremental_detokenization_str) == re.sub( + rgx, '', non_incremental_detokenization_str)), (msg) + + +def compute_correct_cumulative_logprob( + completion_output: CompletionOutput) -> float: + """Compute known-good value for evaluating cumulative logprob + + Args: + completion_output: completion output from engine + + Returns: + Known-good cumulative logprob value + """ + token_ids = completion_output.token_ids + logprobs = completion_output.logprobs + assert logprobs is not None + return sum([lp[tok_id].logprob for tok_id, lp in zip(token_ids, logprobs)]) diff --git a/vllm/outputs.py b/vllm/outputs.py index 786380c37..030119710 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -142,6 +142,9 @@ class RequestOutput: prompt_token_ids: Optional[List[int]], text: str, token_ids: List[int], + logprobs: Optional[SampleLogprobs], + prompt_logprobs: Optional[PromptLogprobs], + cumulative_logprob: Optional[float], finished: bool = False, ) -> "RequestOutput": """Initialize a new RequestOutput object.""" @@ -151,15 +154,14 @@ class RequestOutput: index=0, text=text, token_ids=token_ids, - 
cumulative_logprob=None, - logprobs=None, # TODO - ) + cumulative_logprob=cumulative_logprob, + logprobs=logprobs) return RequestOutput( request_id=request_id, prompt=prompt, prompt_token_ids=prompt_token_ids, - prompt_logprobs=None, # TODO + prompt_logprobs=prompt_logprobs, outputs=[completion_output], finished=finished, ) diff --git a/vllm/transformers_utils/detokenizer_utils.py b/vllm/transformers_utils/detokenizer_utils.py index 8160a35ff..a1fa27773 100644 --- a/vllm/transformers_utils/detokenizer_utils.py +++ b/vllm/transformers_utils/detokenizer_utils.py @@ -74,6 +74,25 @@ def convert_prompt_ids_to_tokens( return new_tokens, prefix_offset, read_offset +def convert_ids_list_to_tokens( + tokenizer: AnyTokenizer, + token_ids: List[int], +) -> List[str]: + """Detokenize the input ids individually. + + Args: + tokenizer: tokenizer used by model under test + token_ids: convert these tokens (Python list form) + + Returns: + Python list of token string representations + + """ + token_str_lst = tokenizer.convert_ids_to_tokens(token_ids) + _replace_none_with_empty(token_str_lst) # type: ignore + return token_str_lst + + # Based on # https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15 # under Apache 2.0 license diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 6c44fec64..35d9424f9 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -437,6 +437,8 @@ class Scheduler: ) -> EngineCoreOutputs: # NOTE(woosuk): This method doesn't consider speculative decoding. 
sampled_token_ids = model_runner_output.sampled_token_ids + logprobs = model_runner_output.logprobs + prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict num_scheduled_tokens = scheduler_output.num_scheduled_tokens new_running: List[Request] = [] outputs: List[EngineCoreOutput] = [] @@ -471,6 +473,13 @@ class Scheduler: self.encoder_cache_manager.free_encoder_input( request, input_id) + # Get prompt logprobs for this request. + prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) + + stopped = False + new_logprobs = None + new_token_ids = None + if request.num_computed_tokens == request.num_tokens: req_index = model_runner_output.req_id_to_index[req_id] # NOTE(woosuk): Currently, we assume that each request @@ -486,20 +495,30 @@ class Scheduler: if stopped: self._free_request(request) + # Extract sample logprobs if needed. + if request.sampling_params.logprobs is not None: + assert logprobs is not None + # NOTE: once we support N tokens per step (spec decode), + # the outer lists can be of length > 1. + new_logprobs = logprobs.slice(req_index, req_index + 1) + + new_token_ids = request.output_token_ids[-num_new_tokens:] + + # Transmit partial if chunked prefill & prompt logprobs is enabled + if new_token_ids or prompt_logprobs_tensors is not None: # Add EngineCoreOutput for this Request. - output = EngineCoreOutput( - request_id=req_id, - new_token_ids=request.output_token_ids[-num_new_tokens:], - finished=request.is_finished(), - finish_reason=request.get_finished_reason(), - stop_reason=request.stop_reason) - outputs.append(output) - - # Breakout of the loop. 
- if stopped: - continue + outputs.append( + EngineCoreOutput( + request_id=req_id, + new_token_ids=new_token_ids or [], + finish_reason=request.get_finished_reason(), + new_logprobs=new_logprobs, + new_prompt_logprobs_tensors=prompt_logprobs_tensors, + stop_reason=request.stop_reason)) + + if not stopped: + new_running.append(request) - new_running.append(request) self.running = new_running return EngineCoreOutputs( outputs=outputs, diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index d5933cac5..b05ef3cc8 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, List, Optional, Union import msgspec from vllm.v1.metrics.stats import SchedulerStats +from vllm.v1.outputs import LogprobsLists, LogprobsTensors if TYPE_CHECKING: from vllm.lora.request import LoRARequest @@ -67,10 +68,17 @@ class EngineCoreOutput( request_id: str new_token_ids: List[int] - finished: bool + + new_logprobs: Optional[LogprobsLists] = None + new_prompt_logprobs_tensors: Optional[LogprobsTensors] = None + finish_reason: Optional[FinishReason] = None stop_reason: Union[int, str, None] = None + @property + def finished(self) -> bool: + return self.finish_reason is not None + class EngineCoreOutputs( msgspec.Struct, diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 29a9ac186..f3d40aa1e 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -11,7 +11,6 @@ from typing import List, Tuple, Type import psutil import zmq import zmq.asyncio -from msgspec import msgpack from vllm.config import VllmConfig from vllm.logger import init_logger @@ -26,7 +25,7 @@ from vllm.v1.engine import (EngineCoreOutputs, EngineCoreProfile, from vllm.v1.engine.mm_input_mapper import MMInputMapperServer from vllm.v1.executor.abstract import Executor from vllm.v1.request import Request, RequestStatus -from vllm.v1.serial_utils import PickleEncoder +from vllm.v1.serial_utils import MsgpackEncoder, 
PickleEncoder from vllm.version import __version__ as VLLM_VERSION logger = init_logger(__name__) @@ -292,7 +291,7 @@ class EngineCoreProc(EngineCore): """Output socket IO thread.""" # Msgpack serialization encoding. - encoder = msgpack.Encoder() + encoder = MsgpackEncoder() # Reuse send buffer. buffer = bytearray() diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 247380ef7..cdc63acdb 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -7,7 +7,6 @@ import weakref from abc import ABC, abstractmethod from typing import List, Optional, Type -import msgspec import zmq import zmq.asyncio @@ -20,7 +19,7 @@ from vllm.v1.engine import (EngineCoreOutputs, EngineCoreProfile, EngineCoreRequestUnion, EngineCoreResetPrefixCache) from vllm.v1.engine.core import EngineCore, EngineCoreProc from vllm.v1.executor.abstract import Executor -from vllm.v1.serial_utils import PickleEncoder +from vllm.v1.serial_utils import MsgpackDecoder, PickleEncoder from vllm.v1.utils import BackgroundProcHandle logger = init_logger(__name__) @@ -163,7 +162,7 @@ class MPClient(EngineCoreClient): # Serialization setup. self.encoder = PickleEncoder() - self.decoder = msgspec.msgpack.Decoder(EngineCoreOutputs) + self.decoder = MsgpackDecoder(EngineCoreOutputs) # ZMQ setup. 
self.ctx = ( diff --git a/vllm/v1/engine/detokenizer.py b/vllm/v1/engine/detokenizer.py index 861fcb012..629da06f4 100644 --- a/vllm/v1/engine/detokenizer.py +++ b/vllm/v1/engine/detokenizer.py @@ -1,27 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass -from typing import List, Optional, Union +from typing import List, Optional from vllm.engine.output_processor.stop_checker import StopChecker from vllm.logger import init_logger -from vllm.sampling_params import RequestOutputKind from vllm.transformers_utils.detokenizer_utils import ( AnyTokenizer, convert_prompt_ids_to_tokens, detokenize_incrementally) -from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason +from vllm.v1.engine import EngineCoreRequest logger = init_logger(__name__) -@dataclass -class DetokenizerOutput: - output_text: str - token_ids: List[int] - finished: bool - finish_reason: Optional[FinishReason] = None - stop_reason: Union[int, str, None] = None - - @dataclass class IncrementalDetokenizer: @@ -42,7 +32,6 @@ class IncrementalDetokenizer: # Parameters for detokenization skip_special_tokens: bool spaces_between_special_tokens: bool - output_kind: RequestOutputKind # Tokenizer for this request tokenizer: AnyTokenizer @@ -90,25 +79,19 @@ class IncrementalDetokenizer: skip_special_tokens=request.sampling_params.skip_special_tokens, spaces_between_special_tokens=request.sampling_params. spaces_between_special_tokens, - output_kind=request.sampling_params.output_kind, prompt_len=len(request.prompt_token_ids), tokenizer=tokenizer, stop_buffer_length=stop_buffer_length, ) - def update_from_output( - self, - output: EngineCoreOutput, - ) -> Optional[DetokenizerOutput]: + def update(self, new_token_ids: List[int]) -> Optional[str]: """ Update RequestState for the request_id by: 1) Detokenize the new token ids incrementally. - 2) Update the RequestOutput with the new text. - """ + 2) Evaluate stop criteria. 
- new_token_ids = output.new_token_ids - finish_reason = output.finish_reason - stop_reason = output.stop_reason + Return matched stop string or None. + """ # 1) Detokenize the new token ids incrementally. # TODO(woosuk): This method becomes very inefficient when the number of @@ -131,11 +114,13 @@ class IncrementalDetokenizer: self.tokens.extend(new_tokens) self.prefix_offset = prefix_offset self.read_offset = read_offset - self.output_text += new_decoded_token_text decoded_text += new_decoded_token_text + self.output_text += decoded_text + # 2) Evaluate stop criteria. + stop_string = None if self.stop: stop = StopChecker.check_stop_strings( output_text=self.output_text, @@ -144,28 +129,13 @@ class IncrementalDetokenizer: include_in_output=self.include_stop_str_in_output, ) if stop is not None: - stop_str, truncate_to = stop + stop_string, truncate_to = stop if truncate_to != -1: self.output_text = self.output_text[:truncate_to] - finish_reason = FinishReason.STOP - stop_reason = stop_str - - # TODO: handle stop_token_ids here too? - - # 3) Update the RequestOutput object with the new text. 
- finished = finish_reason is not None - if self.output_kind == RequestOutputKind.FINAL_ONLY \ - and not finished: - return None - - delta = self.output_kind == RequestOutputKind.DELTA - output_text = self._get_next_output_text(finished, delta) - token_ids = new_token_ids if delta else self.output_token_ids - return DetokenizerOutput(output_text, token_ids, finished, - finish_reason, stop_reason) + return stop_string - def _get_next_output_text(self, finished: bool, delta: bool) -> str: + def get_next_output_text(self, finished: bool, delta: bool) -> str: """If delta is True, only new text since the last call to this method is returned""" diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index e0452bcad..3ef5a9706 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -45,6 +45,7 @@ class LLMEngine: multiprocess_mode: bool = False, ) -> None: self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config # Tokenizer (+ ensure liveness if running in another process). 
self.tokenizer = init_tokenizer_from_configs( diff --git a/vllm/v1/engine/logprobs.py b/vllm/v1/engine/logprobs.py new file mode 100644 index 000000000..4622cafa4 --- /dev/null +++ b/vllm/v1/engine/logprobs.py @@ -0,0 +1,194 @@ +# SPDX-License-Identifier: Apache-2.0 + +import itertools +from dataclasses import dataclass +from typing import Dict, List, Optional + +from vllm.logger import init_logger +from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs +from vllm.transformers_utils.detokenizer_utils import ( + AnyTokenizer, convert_ids_list_to_tokens) +from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest +from vllm.v1.outputs import LogprobsLists, LogprobsTensors + +logger = init_logger(__name__) + + +@dataclass +class LogprobsProcessor: + + # Tokenizer for this request + tokenizer: AnyTokenizer + + # Logprobs for this request + logprobs: Optional[SampleLogprobs] + prompt_logprobs: Optional[PromptLogprobs] + cumulative_logprob: Optional[float] + num_logprobs: Optional[int] + num_prompt_logprobs: Optional[int] + + @classmethod + def from_new_request( + cls, + tokenizer: AnyTokenizer, + request: EngineCoreRequest, + ) -> "LogprobsProcessor": + num_logprobs = request.sampling_params.logprobs + num_prompt_logprobs = request.sampling_params.prompt_logprobs + return cls( + tokenizer=tokenizer, + cumulative_logprob=(None if num_logprobs is None else 0.), + logprobs=(None if num_logprobs is None else []), + # NOTE: logprob of first prompt token is None. + prompt_logprobs=(None if num_prompt_logprobs is None else [None]), + num_prompt_logprobs=num_prompt_logprobs, + num_logprobs=num_logprobs, + ) + + def _update_sample_logprobs(self, logprobs_lists: LogprobsLists) -> None: + """Update with sample logprobs from EngineCore. + + Outer lists are only of len > 1 if EngineCore made + >1 tokens in prior step (e.g. in spec decoding). + + Args: + logprobs_lists: the lists of logprob tokens, logprobs, and ranks. 
+ + """ + + assert self.num_logprobs is not None + assert self.logprobs is not None + assert self.cumulative_logprob is not None + + token_ids_lst, logprobs_lst, ranks_lst = logprobs_lists + + for rank, logprobs, token_ids in zip(ranks_lst, logprobs_lst, + token_ids_lst): + + # Detokenize (non-incrementally). + decoded_tokens = convert_ids_list_to_tokens( + self.tokenizer, token_ids) + + # Sampler puts the sampled logprob in first. + sampled_token_logprob = logprobs[0] + self.cumulative_logprob += sampled_token_logprob + + # Update with the Logprob dictionary for this pos. + self.logprobs.append( + self._make_logprob_dict( + logprobs, + token_ids, + decoded_tokens, + rank, + self.num_logprobs, + )) + + def _update_prompt_logprobs( + self, + prompt_logprobs_tensors: LogprobsTensors, + ) -> None: + """Update with prompt logprobs from EngineCore. + + Args: + prompt_logprobs_tensors: tuple containing the prompt logprobs + tensors. + + """ + + # Prompt logprobs are enabled. + assert self.num_prompt_logprobs is not None + assert self.prompt_logprobs is not None + + token_ids, logprobs, ranks = prompt_logprobs_tensors + + # Detokenize non-incrementally. + # Output is flat: [num_tok, num_lps] -> [num_tok * num_lps] + decoded_tokens = convert_ids_list_to_tokens( + self.tokenizer, + token_ids.flatten().tolist()) + + # Recover shapes. + num_prompt_tokens, num_logprobs = logprobs.shape + + # Pythonize the torch tensors. + # TODO(rob): experiment with doing this in EngineCore? + prompt_token_ranks = ranks.tolist() + prompt_logprobs = logprobs.tolist() + token_ids = token_ids.tolist() + + # Make Logprob for each position. + for pos in range(num_prompt_tokens): + # Handle flattening. + offset = pos * num_logprobs + offset_end = offset + num_logprobs + decoded_tokens_for_pos = decoded_tokens[offset:offset_end] + + # Update with the Logprob dictionary for this pos. 
+ self.prompt_logprobs.append( + self._make_logprob_dict(prompt_logprobs[pos], token_ids[pos], + decoded_tokens_for_pos, + prompt_token_ranks[pos], + self.num_prompt_logprobs)) + + def pop_prompt_logprobs(self) -> Optional[PromptLogprobs]: + """Pop and return all request prompt logprobs + + The logprobs processor aggregates prompt chunk logprobs + over one or more prefill chunks. This method returns + all prompt logprobs at once and then forgets them. + Ensures correct RequestOutputKind.DELTA semantics + wherein all prompt logprobs are returned at once at + the end of prefill. + + Returns: + None if prompt logprobs are disabled for this request. + List of all prompt logprobs, otherwise. + """ + plp = self.prompt_logprobs + if plp: + self.prompt_logprobs = [] + return plp + + @staticmethod + def _make_logprob_dict( + logprobs: List[float], + logprob_token_ids: List[int], + decoded_tokens: List[str], + rank: int, + num_logprobs: int, + ) -> Dict[int, Logprob]: + """Make a Logprob dictionary for a position. + + Args: + logprobs: list of log probabilities + logprob_token_ids: list of top token ids + decoded_tokens: list of decoded top tokens + rank: rank of the sampled token + num_logprobs: number of logprobs requested + by the user (in addition to sampled logprob) + + Returns: + Dict[token id, Logprob] + """ + + # We do not need a special case for the sampled token + # being in the topk, since inserting duplicated data + # into a dictionary twice is the same as doing it once. 
+ topk_ranks = range(1, num_logprobs + 1) + ranks = itertools.chain((rank, ), topk_ranks) + + return { + token_id: Logprob( + logprob=logprob, + rank=rank, + decoded_token=token, + ) + for token_id, logprob, rank, token in zip( + logprob_token_ids, logprobs, ranks, decoded_tokens) + } + + def update_from_output(self, output: EngineCoreOutput) -> None: + if output.new_logprobs is not None: + self._update_sample_logprobs(output.new_logprobs) + if output.new_prompt_logprobs_tensors is not None: + self._update_prompt_logprobs(output.new_prompt_logprobs_tensors) diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 947366691..5dbf530ca 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -5,11 +5,12 @@ from dataclasses import dataclass from typing import Dict, List, Optional from vllm.outputs import RequestOutput -from vllm.transformers_utils.detokenizer_utils import AnyTokenizer +from vllm.sampling_params import RequestOutputKind +from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup -from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest -from vllm.v1.engine.detokenizer import (DetokenizerOutput, - IncrementalDetokenizer) +from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason +from vllm.v1.engine.detokenizer import IncrementalDetokenizer +from vllm.v1.engine.logprobs import LogprobsProcessor from vllm.v1.metrics.stats import IterationStats, RequestStateStats @@ -26,16 +27,20 @@ class RequestState: def __init__( self, request_id: str, + output_kind: RequestOutputKind, prompt: Optional[str], prompt_token_ids: List[int], + logprobs_processor: LogprobsProcessor, detokenizer: IncrementalDetokenizer, arrival_time: float, queue: Optional[asyncio.Queue[RequestOutput]], ): self.request_id = request_id + self.output_kind = output_kind self.prompt = prompt self.prompt_token_ids = prompt_token_ids 
self.prompt_len = len(prompt_token_ids) + self.logprobs_processor = logprobs_processor self.detokenizer = detokenizer self.is_prefilling = True self.queue = queue @@ -51,8 +56,13 @@ class RequestState: ) -> "RequestState": return cls( request_id=request.request_id, + output_kind=request.sampling_params.output_kind, prompt=request.prompt, prompt_token_ids=request.prompt_token_ids, + logprobs_processor=LogprobsProcessor.from_new_request( + tokenizer=tokenizer, + request=request, + ), detokenizer=IncrementalDetokenizer.from_new_request( tokenizer=tokenizer, request=request, @@ -127,13 +137,8 @@ class OutputProcessor: batch to ensure system overheads are minimized. This is the only function that should loop over EngineCoreOutputs. - If you need to touch every element of the batch, implement a - method called XXXClass.update_from_output() to be called - within the loop below. For examples, see: - * IterationStats.update_from_output() - * Detokenizer.update_from_output() - - TODO(rob): add Protocol makes update_from_output explicit. + If you need to touch every element of the batch, do it from + within the loop below. ********************************************************** """ @@ -154,17 +159,37 @@ class OutputProcessor: req_state.is_prefilling, req_state.prompt_len, req_state.stats) - req_state.is_prefilling = False - - # 2) Detokenize the token ids into text. - detokenizer_output = req_state.detokenizer.update_from_output( - engine_core_output) - - # 3) Create and handle RequestOutput objects. - if detokenizer_output is not None: - request_output = self._make_request_output( - req_state, detokenizer_output) + new_token_ids = engine_core_output.new_token_ids + finish_reason = engine_core_output.finish_reason + + # TODO(andy): prompt logprobs + chunked prefill can + # result in engine core returning an output for a + # partial prefill (in order to send back partial + # prompt logprobs.) 
This breaks the invariant that + # process_outputs is only operating on engine core + # outputs associated with non-partial completions. + # Currently this is handled by having `is_prefilling` + # check for new decoded tokens, indicating that + # the completion is not partial. + # + # Follow up will aggregate partial prompt logprobs + # in the EngineCore. + req_state.is_prefilling = not new_token_ids + + # 2) Detokenize the token ids into text and check for stop + # strings. + stop_reason = req_state.detokenizer.update(new_token_ids) + if stop_reason: + finish_reason = FinishReason.STOP + + # 3) Compute sample and prompt logprobs for request, + # if required. + req_state.logprobs_processor.update_from_output(engine_core_output) + + # 4) Create and handle RequestOutput objects. + if request_output := self._make_request_output( + req_state, new_token_ids, finish_reason, stop_reason): if req_state.queue is not None: # AsyncLLM: put into queue for handling by generate(). req_state.queue.put_nowait(request_output) @@ -174,18 +199,16 @@ class OutputProcessor: # Free completed requests. if request_output.finished: - assert detokenizer_output.finish_reason is not None - self.request_states.pop(req_id) if not engine_core_output.finished: # If req not finished in EngineCore, but Detokenizer # detected stop string, abort needed in EngineCore. reqs_to_abort.append(req_id) - # Track per-request stats + # Track per-request stats. 
+ assert finish_reason is not None iteration_stats.update_from_finished_request( - detokenizer_output.finish_reason, request_output, - req_state.stats) + finish_reason, request_output, req_state.stats) return OutputProcessorOutput( request_outputs=request_outputs, @@ -196,20 +219,47 @@ class OutputProcessor: @staticmethod def _make_request_output( request_state: RequestState, - detokenizer_output: DetokenizerOutput, - ) -> RequestOutput: + new_token_ids: List[int], + finish_reason: Optional[FinishReason], + stop_reason: Optional[str], + ) -> Optional[RequestOutput]: + + finished = finish_reason is not None + output_kind = request_state.output_kind + # In follow up, we will switch to invariant where EngineCore + # does not stream partial prefills. + if not finished and (request_state.is_prefilling + or output_kind == RequestOutputKind.FINAL_ONLY): + # Only the final output is required in FINAL_ONLY mode. + return None + + detokenizer = request_state.detokenizer + logprobs_processor = request_state.logprobs_processor + + delta = output_kind == RequestOutputKind.DELTA + logprobs = logprobs_processor.logprobs + if delta: + if logprobs: + logprobs = logprobs[-len(new_token_ids):] + # Side effect: logprobs processor forgets prompt logprobs + prompt_logprobs = logprobs_processor.pop_prompt_logprobs() + else: + prompt_logprobs = logprobs_processor.prompt_logprobs + request_output = RequestOutput.new( - request_state.request_id, - request_state.prompt, - request_state.prompt_token_ids, - detokenizer_output.output_text, - detokenizer_output.token_ids, - detokenizer_output.finished, + request_id=request_state.request_id, + prompt=request_state.prompt, + prompt_token_ids=request_state.prompt_token_ids, + text=detokenizer.get_next_output_text(finished, delta), + token_ids=new_token_ids if delta else detokenizer.output_token_ids, + logprobs=logprobs, + prompt_logprobs=prompt_logprobs, + cumulative_logprob=logprobs_processor.cumulative_logprob, + finished=finished, ) - if 
detokenizer_output.finished: + if finished: completion_output = request_output.outputs[0] - completion_output.finish_reason = str( - detokenizer_output.finish_reason) - completion_output.stop_reason = detokenizer_output.stop_reason + completion_output.finish_reason = str(finish_reason) + completion_output.stop_reason = stop_reason return request_output diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 366287951..70876b03a 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -33,6 +33,7 @@ class Processor: ): self.model_config = model_config + self.cache_config = cache_config self.lora_config = lora_config self.tokenizer = tokenizer @@ -51,6 +52,37 @@ class Processor: self.use_hash = (not model_config.disable_mm_preprocessor_cache) or \ cache_config.enable_prefix_caching + def _validate_logprobs( + self, + params: Union[SamplingParams, PoolingParams], + ) -> None: + if not isinstance(params, SamplingParams): + return + + max_logprobs = self.model_config.max_logprobs + # Validate sample logprobs. + if params.logprobs and params.logprobs > max_logprobs: + raise ValueError( + f"Requested sample logprobs of {params.logprobs}, " + f"which is greater than max allowed: {max_logprobs}") + + # Validate prompt logprobs. + if params.prompt_logprobs and params.prompt_logprobs > max_logprobs: + raise ValueError( + f"Requested prompt logprobs of {params.prompt_logprobs}, " + f"which is greater than max allowed: {max_logprobs}") + + # TODO(andy): enable this in follow up by recomputing. 
+ if (params.prompt_logprobs is not None + and self.cache_config.enable_prefix_caching): + raise ValueError("Prefix caching with prompt logprobs not yet " + "supported on VLLM V1.") + + def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None: + if lora_request is not None and not self.lora_config: + raise ValueError(f"Got lora_request {lora_request} but LoRA is " + "not enabled!") + def process_inputs( self, request_id: str, @@ -64,12 +96,11 @@ class Processor: ) -> EngineCoreRequest: # TODO(woosuk): Support pooling models. - # TODO(woosuk): Check max_logprobs # TODO(woosuk): Support encoder-decoder models. - if lora_request is not None and not self.lora_config: - raise ValueError(f"Got lora_request {lora_request} but LoRA is " - "not enabled!") + self._validate_logprobs(params) + self._validate_lora(lora_request) + if arrival_time is None: arrival_time = time.time() assert priority == 0, "vLLM V1 does not support priority at the moment." diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py index e3f1efcc9..5e588d35e 100644 --- a/vllm/v1/metrics/stats.py +++ b/vllm/v1/metrics/stats.py @@ -60,14 +60,17 @@ class IterationStats: self.num_generation_tokens += num_new_generation_tokens if is_prefilling: - # This relies on the invariant that EngineCore does - # not stream outputs for partially completed prefills - # (scheduler.update_from_output makes EngineCoreOutput - # iff num_computed_tokens == num_tokens). - assert (num_new_generation_tokens > 0) - self.num_prompt_tokens += prompt_len - - self.time_to_first_tokens_iter.append(last_token_latency) + # TODO(andy): we used to assert that num_new_generation_tokens + # > 0 with an invariant that EngineCore does not stream outputs + # for partially completed prefills (scheduler.update_from_output + # makes EngineCoreOutput iff num_computed_tokens == num_tokens). + # When prompt logprobs are enabled, we currently stream out the + # partially completed prompt. 
+ # This will be reverted in a follow up PR and we should re-enable + # this assertion / invariant. + if num_new_generation_tokens > 0: + self.num_prompt_tokens += prompt_len + self.time_to_first_tokens_iter.append(last_token_latency) else: self.time_per_output_tokens_iter.append(last_token_latency) diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index 6e82bffd7..27fd2dbda 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -1,25 +1,51 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass -from typing import Dict, List, Optional +from typing import Dict, List, NamedTuple, Optional import torch -@dataclass -class SamplerOutput: +class LogprobsLists(NamedTuple): + # [num_reqs, max_num_logprobs + 1] + logprob_token_ids: List[List[int]] + # [num_reqs, max_num_logprobs + 1] + logprobs: List[List[float]] # [num_reqs] - sampled_token_ids: torch.Tensor + sampled_token_ranks: List[int] + + def slice(self, start: int, end: int): + return LogprobsLists( + self.logprob_token_ids[start:end], + self.logprobs[start:end], + self.sampled_token_ranks[start:end], + ) + + +class LogprobsTensors(NamedTuple): # [num_reqs, max_num_logprobs + 1] - logprob_token_ids: Optional[torch.Tensor] + logprob_token_ids: torch.Tensor # [num_reqs, max_num_logprobs + 1] - logprobs: Optional[torch.Tensor] + logprobs: torch.Tensor + # [num_reqs] + selected_token_ranks: torch.Tensor - # TODO: Support prompt logprobs. - prompt_logprob_token_ids: Optional[torch.Tensor] - prompt_logprobs: Optional[torch.Tensor] + def tolists(self): + return LogprobsLists( + self.logprob_token_ids.tolist(), + self.logprobs.tolist(), + self.selected_token_ranks.tolist(), + ) + + +@dataclass +class SamplerOutput: + + # [num_reqs] + sampled_token_ids: torch.Tensor + logprobs_tensors: Optional[LogprobsTensors] # ModelRunnerOutput is serialized and sent to the scheduler process. 
@@ -36,6 +62,12 @@ class ModelRunnerOutput: sampled_token_ids: List[int] # [num_reqs, max_num_logprobs + 1] - logprob_token_ids_cpu: Optional[torch.Tensor] # [num_reqs, max_num_logprobs + 1] - logprobs_cpu: Optional[torch.Tensor] + # [num_reqs] + logprobs: Optional[LogprobsLists] + + # req_id -> (token_ids, logprobs, ranks) + # [prompt_len, num_prompt_logprobs] + # [prompt_len, num_prompt_logprobs] + # [prompt_len] + prompt_logprobs_dict: Dict[str, LogprobsTensors] diff --git a/vllm/v1/sample/metadata.py b/vllm/v1/sample/metadata.py index 8e54de345..1a2771bab 100644 --- a/vllm/v1/sample/metadata.py +++ b/vllm/v1/sample/metadata.py @@ -20,7 +20,8 @@ class SamplingMetadata: generators: Dict[int, torch.Generator] - max_num_logprobs: int + # None means no logprobs, 0 means sampled token logprobs only + max_num_logprobs: Optional[int] no_penalties: bool prompt_token_ids: Optional[torch.Tensor] diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index 3da7498e0..43fd64aaa 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -1,11 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 """A layer that samples the next tokens from the model's outputs.""" -from typing import Tuple import torch import torch.nn as nn -from vllm.v1.outputs import SamplerOutput +from vllm.v1.outputs import LogprobsTensors, SamplerOutput from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.ops.penalties import (apply_all_penalties, apply_min_token_penalties) @@ -25,20 +24,16 @@ class Sampler(nn.Module): logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> SamplerOutput: - needs_logprobs = sampling_metadata.max_num_logprobs > 0 - if needs_logprobs: - # NOTE(woosuk): Use the original logits (before any penalties or - # temperature scaling) for the top-k logprobs. - # This is different from the V0 sampler, which uses the logits that - # is used for sampling (after penalties and temperature scaling). 
- # NOTE: We compute logprobs first because the below ops may - # modify the logits tensor in-place (and we don't want to clone - # the logits tensor for memory efficiency). - topk_logprobs, topk_indices = self.get_topk_logprobs( - logits, sampling_metadata) - else: - topk_logprobs = None - topk_indices = None + + # NOTE(woosuk): Use the original logits (before any penalties or + # temperature scaling) for the top-k logprobs. + # This is different from the V0 sampler, which uses the logits that + # is used for sampling (after penalties and temperature scaling). + # TODO(rob): provide option for logprobs post sampling. + # See https://vllm-dev.slack.com/archives/C07UUL8E61Z/p1735907856007919 # noqa: E501 + num_logprobs = sampling_metadata.max_num_logprobs + if num_logprobs is not None: + raw_logprobs = self.compute_logprobs(logits) # Use float32 for the logits. logits = logits.to(torch.float32) @@ -48,15 +43,19 @@ class Sampler(nn.Module): logits = self.apply_temperature(logits, sampling_metadata.temperature) # Sample the next token. sampled = self.sample(logits, sampling_metadata) + + # Gather the logprobs of the topk and sampled token (if requested). + # Get logprobs and rank tensors (if requested) + logprobs_tensors = None if num_logprobs is None else \ + self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=sampled) + # Use int32 to reduce the tensor size. sampled = sampled.to(torch.int32) + # These are GPU tensors. 
sampler_output = SamplerOutput( sampled_token_ids=sampled, - logprob_token_ids=topk_indices, - logprobs=topk_logprobs, - prompt_logprob_token_ids=None, - prompt_logprobs=None, + logprobs_tensors=logprobs_tensors, ) return sampler_output @@ -103,19 +102,52 @@ class Sampler(nn.Module): ) return sampled - def get_topk_logprobs( + def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor: + return logits.log_softmax(dim=-1, dtype=torch.float32) + + def gather_logprobs( self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Tuple[torch.Tensor, torch.Tensor]: - logprobs = logits.log_softmax(dim=-1, dtype=torch.float32) - # FIXME: Mask the sampled token_id, get topk logprobs, - # and concatenate the topk with the sampled token_id. - topk_logprobs, topk_indices = torch.topk( - logprobs, sampling_metadata.max_num_logprobs, dim=-1) + logprobs: torch.Tensor, + num_logprobs: int, + token_ids: torch.Tensor, + ) -> LogprobsTensors: + """ + Gather logprobs for topk and sampled/prompt token. + + Args: + logits: (num tokens) x (vocab) tensor + num_logprobs: minimum number of logprobs to + retain per token + token_ids: prompt tokens (if prompt logprobs) + or sampled tokens (if sampled + logprobs); 1D token ID tensor + with (num tokens) elements + + Returns: + Top-k int indices tensor, (num tokens) x (num_logprobs + 1) + Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1) + Sampled token rank tensor, (num tokens) + """ + # Find the topK values. + topk_logprobs, topk_indices = torch.topk(logprobs, + num_logprobs, + dim=-1) + + # Get with the logprob of the prompt or sampled token. + token_ids = token_ids.unsqueeze(-1) + token_logprobs = logprobs.gather(-1, token_ids) + + # Compute the ranks of the actual token. + token_ranks = (logprobs >= token_logprobs).sum(-1) + + # Concatenate together with the topk. 
+ indices = torch.cat((token_ids, topk_indices), dim=1) + logprobs = torch.cat((token_logprobs, topk_logprobs), dim=1) + # Use int32 to reduce the tensor size. - topk_indices = topk_indices.to(torch.int32) - return topk_logprobs, topk_indices + indices = indices.to(torch.int32) + + return LogprobsTensors(indices, logprobs, token_ranks) def apply_penalties( self, diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index 1791dfa2b..a7fba65e7 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -1,12 +1,58 @@ # SPDX-License-Identifier: Apache-2.0 import pickle +from typing import Any + +import torch +from msgspec import msgpack + +CUSTOM_TYPE_CODE_PICKLE = 1 class PickleEncoder: - def encode(self, obj): + def encode(self, obj: Any): return pickle.dumps(obj) - def decode(self, data): + def decode(self, data: Any): return pickle.loads(data) + + +class MsgpackEncoder: + """Encoder with custom torch tensor serialization.""" + + def __init__(self): + self.encoder = msgpack.Encoder(enc_hook=custom_enc_hook) + + def encode(self, obj: Any) -> bytes: + return self.encoder.encode(obj) + + def encode_into(self, obj: Any, buf: bytearray) -> None: + self.encoder.encode_into(obj, buf) + + +class MsgpackDecoder: + """Decoder with custom torch tensor serialization.""" + + def __init__(self, t: Any): + self.decoder = msgpack.Decoder(t, ext_hook=custom_ext_hook) + + def decode(self, obj: Any): + return self.decoder.decode(obj) + + +def custom_enc_hook(obj: Any) -> Any: + if isinstance(obj, torch.Tensor): + # NOTE(rob): it is fastest to use numpy + pickle + # when serializing torch tensors. 
+ # https://gist.github.com/tlrmchlsmth/8067f1b24a82b6e2f90450e7764fa103 # noqa: E501 + return msgpack.Ext(CUSTOM_TYPE_CODE_PICKLE, pickle.dumps(obj.numpy())) + + raise NotImplementedError(f"Objects of type {type(obj)} are not supported") + + +def custom_ext_hook(code: int, data: memoryview) -> Any: + if code == CUSTOM_TYPE_CODE_PICKLE: + return torch.from_numpy(pickle.loads(data)) + + raise NotImplementedError(f"Extension type code {code} is not supported") diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index a31e88865..d5b8fd218 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -176,7 +176,9 @@ class InputBatch: self.generators: Dict[int, torch.Generator] = {} self.num_logprobs: Dict[str, int] = {} - self.prompt_logprob_reqs: Set[str] = set() + # NOTE(rob): num_prompt_logprobs only includes reqs + # that are currently in the prefill phase. + self.num_prompt_logprobs: Dict[str, int] = {} def add_request( self, @@ -238,11 +240,10 @@ class InputBatch: if request.generator is not None: self.generators[req_index] = request.generator - num_logprobs = sampling_params.logprobs - if num_logprobs is not None and num_logprobs > 0: - self.num_logprobs[req_id] = num_logprobs - if sampling_params.prompt_logprobs: - self.prompt_logprob_reqs.add(req_id) + if sampling_params.logprobs is not None: + self.num_logprobs[req_id] = sampling_params.logprobs + if sampling_params.prompt_logprobs is not None: + self.num_prompt_logprobs[req_id] = sampling_params.prompt_logprobs # Add request lora ID if request.lora_request: @@ -272,7 +273,7 @@ class InputBatch: self.repetition_penalties_reqs.discard(req_id) self.generators.pop(req_index, None) self.num_logprobs.pop(req_id, None) - self.prompt_logprob_reqs.discard(req_id) + self.num_prompt_logprobs.pop(req_id, None) # LoRA lora_id = self.request_lora_mapping[req_index] @@ -297,7 +298,7 @@ class InputBatch: self.repetition_penalties_reqs.clear() 
self.generators.clear() self.num_logprobs.clear() - self.prompt_logprob_reqs.clear() + self.num_prompt_logprobs.clear() self.request_lora_mapping.fill(0) self.lora_id_to_lora_request.clear() self.lora_id_to_request_ids.clear() @@ -489,13 +490,9 @@ class InputBatch: and len(self.repetition_penalties_reqs) == 0) @property - def max_num_logprobs(self) -> int: - return max(self.num_logprobs.values()) if self.num_logprobs else 0 - - @property - def no_logprob(self) -> bool: - return len(self.num_logprobs) == 0 + def max_num_logprobs(self) -> Optional[int]: + return max(self.num_logprobs.values()) if self.num_logprobs else None @property def no_prompt_logprob(self) -> bool: - return len(self.prompt_logprob_reqs) == 0 + return not self.num_prompt_logprobs diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index bfc9d1ca8..561c3cf39 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -29,7 +29,7 @@ from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.engine.mm_input_mapper import MMInputMapperClient from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheSpec) -from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.utils import bind_kv_cache from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch @@ -804,8 +804,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): inputs_embeds=inputs_embeds, ) hidden_states = hidden_states[:num_scheduled_tokens] - hidden_states = hidden_states[logits_indices] - logits = self.model.compute_logits(hidden_states, None) + sample_hidden_states = hidden_states[logits_indices] + logits = self.model.compute_logits(sample_hidden_states, None) # Sample the next token and get logprobs if needed. 
sampling_metadata = self._prepare_sampling(batch_changed) @@ -818,7 +818,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): # the requests one by one. Optimize. num_reqs = self.input_batch.num_reqs request_seq_lens: List[Tuple[int, CachedRequestState, int]] = [] - for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]): + for i, req_id in enumerate( # type: ignore[assignment] + self.input_batch.req_ids[:num_reqs]): assert req_id is not None req_state = self.requests[req_id] seq_len = (req_state.num_computed_tokens + @@ -847,27 +848,28 @@ class GPUModelRunner(LoRAModelRunnerMixin): # NOTE: GPU -> CPU Sync happens here. # Move as many CPU operations as possible before this sync point. sampled_token_ids = sampler_output.sampled_token_ids.tolist() + logprobs_tensors = sampler_output.logprobs_tensors + logprobs_lists = logprobs_tensors.tolists() \ + if logprobs_tensors is not None else None + + # Compute prompt logprobs if needed. + prompt_logprobs_dict = self._get_prompt_logprobs_dict( + hidden_states, + scheduler_output, + ) + # Update with the actual token ids for i, req_state, seq_len in request_seq_lens: token_id = sampled_token_ids[i] self.input_batch.token_ids_cpu[i, seq_len] = token_id req_state.output_token_ids[-1] = token_id - if sampler_output.logprob_token_ids is None: - logprob_token_ids = None - else: - logprob_token_ids = sampler_output.logprob_token_ids.cpu() - if sampler_output.logprobs is None: - logprobs = None - else: - logprobs = sampler_output.logprobs.cpu() - model_runner_output = ModelRunnerOutput( req_ids=req_ids, req_id_to_index=self.input_batch.req_id_to_index, sampled_token_ids=sampled_token_ids, - logprob_token_ids_cpu=logprob_token_ids, - logprobs_cpu=logprobs, + logprobs=logprobs_lists, + prompt_logprobs_dict=prompt_logprobs_dict, ) return model_runner_output @@ -886,6 +888,76 @@ class GPUModelRunner(LoRAModelRunnerMixin): logger.info("Loading model weights took %.4f GB", self.model_memory_usage / float(2**30)) + def 
_get_prompt_logprobs_dict( + self, + hidden_states: torch.Tensor, + scheduler_output: "SchedulerOutput", + ) -> Dict[str, LogprobsTensors]: + num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs + if not num_prompt_logprobs_dict: + return {} + + prompt_logprobs_dict: Dict[str, LogprobsTensors] = {} + + # Since prompt logprobs are a rare feature, prioritize simple, + # maintainable loop over optimal performance. + completed_prefill_reqs = [] + for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items(): + + num_tokens = scheduler_output.num_scheduled_tokens[req_id] + + # Get metadata for this request. + request = self.requests[req_id] + num_prompt_tokens = len(request.prompt_token_ids) + prompt_token_ids = torch.tensor(request.prompt_token_ids).to( + self.device, non_blocking=True) + + # Determine number of logits to retrieve. + start_tok = request.num_computed_tokens + 1 + num_remaining_tokens = num_prompt_tokens - start_tok + if num_tokens < num_remaining_tokens: + # This is a chunk, more tokens remain. + num_logits = num_tokens + else: + # This is the last chunk of prompt tokens to return. + num_logits = num_remaining_tokens + completed_prefill_reqs.append(req_id) + + # Get the logits corresponding to this req's prompt tokens. + # If this is a partial request (i.e. chunked prefill), + # then there is prompt logprob generated for each index. + req_idx = self.input_batch.req_id_to_index[req_id] + offset = self.query_start_loc_np[req_idx].item() + prompt_hidden_states = hidden_states[offset:offset + num_logits] + logits = self.model.compute_logits(prompt_hidden_states, None) + + # Get the "target" tokens for each index. For prompt at index i, + # the token at prompt index i+1 is the "sampled" token we want + # to gather the logprob for. + tgt_token_ids = prompt_token_ids[start_tok:start_tok + num_logits] + + # Compute prompt logprobs. 
+ logprobs = self.model.sampler.compute_logprobs(logits) + token_ids, logprobs, ranks = self.model.sampler.gather_logprobs( + logprobs, num_prompt_logprobs, tgt_token_ids) + + # Transfer GPU->CPU async. + prompt_logprobs_dict[req_id] = LogprobsTensors( + token_ids.to("cpu", non_blocking=True), + logprobs.to("cpu", non_blocking=True), + ranks.to("cpu", non_blocking=True), + ) + + # Remove requests that have completed prefill from the batch + # num_prompt_logprobs_dict. + for req_id in completed_prefill_reqs: + del num_prompt_logprobs_dict[req_id] + + # Must synchronize the non-blocking GPU->CPU transfers. + torch.cuda.synchronize() + + return prompt_logprobs_dict + @torch.inference_mode() def _dummy_run( self, -- GitLab From eaa92d443743830f9efd35320cf6d440e49283e3 Mon Sep 17 00:00:00 2001 From: TJian Date: Sat, 8 Feb 2025 00:13:43 +0800 Subject: [PATCH 022/253] [ROCm] [Feature] [Doc] [Dockerfile] [BugFix] Support Per-Token-Activation Per-Channel-Weight FP8 Quantization Inferencing (#12501) --- Dockerfile.rocm_base | 2 +- .../installation/gpu/rocm.inc.md | 64 ++++++--- tests/quantization/test_fp8.py | 49 +++++-- tests/quantization/test_ptpc_fp8.py | 55 ++++++++ .../layers/quantization/__init__.py | 3 + .../layers/quantization/ptpc_fp8.py | 125 ++++++++++++++++++ .../layers/quantization/utils/w8a8_utils.py | 27 ++++ vllm/platforms/rocm.py | 2 +- 8 files changed, 295 insertions(+), 32 deletions(-) create mode 100644 tests/quantization/test_ptpc_fp8.py create mode 100644 vllm/model_executor/layers/quantization/ptpc_fp8.py diff --git a/Dockerfile.rocm_base b/Dockerfile.rocm_base index 5bbe98b0c..e33e73b30 100644 --- a/Dockerfile.rocm_base +++ b/Dockerfile.rocm_base @@ -6,7 +6,7 @@ ARG RCCL_BRANCH="648a58d" ARG RCCL_REPO="https://github.com/ROCm/rccl" ARG TRITON_BRANCH="e5be006" ARG TRITON_REPO="https://github.com/triton-lang/triton.git" -ARG PYTORCH_BRANCH="8d4926e" +ARG PYTORCH_BRANCH="3a585126" ARG PYTORCH_VISION_BRANCH="v0.19.1" ARG 
PYTORCH_REPO="https://github.com/pytorch/pytorch.git" ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git" diff --git a/docs/source/getting_started/installation/gpu/rocm.inc.md b/docs/source/getting_started/installation/gpu/rocm.inc.md index c8fd11415..336d578de 100644 --- a/docs/source/getting_started/installation/gpu/rocm.inc.md +++ b/docs/source/getting_started/installation/gpu/rocm.inc.md @@ -1,6 +1,6 @@ # Installation -vLLM supports AMD GPUs with ROCm 6.2. +vLLM supports AMD GPUs with ROCm 6.3. :::{attention} There are no pre-built wheels for this device, so you must either use the pre-built Docker image or build vLLM from source. @@ -9,7 +9,7 @@ There are no pre-built wheels for this device, so you must either use the pre-bu ## Requirements - GPU: MI200s (gfx90a), MI300 (gfx942), Radeon RX 7900 series (gfx1100) -- ROCm 6.2 +- ROCm 6.3 ## Set up using Python @@ -24,9 +24,15 @@ Currently, there are no pre-built ROCm wheels. - [ROCm](https://rocm.docs.amd.com/en/latest/deploy/linux/index.html) - [PyTorch](https://pytorch.org/) - For installing PyTorch, you can start from a fresh docker image, e.g, `rocm/pytorch:rocm6.2_ubuntu20.04_py3.9_pytorch_release_2.3.0`, `rocm/pytorch-nightly`. + For installing PyTorch, you can start from a fresh docker image, e.g, `rocm/pytorch:rocm6.3_ubuntu24.04_py3.12_pytorch_release_2.4.0`, `rocm/pytorch-nightly`. If you are using docker image, you can skip to Step 3. - Alternatively, you can install PyTorch using PyTorch wheels. You can check PyTorch installation guide in PyTorch [Getting Started](https://pytorch.org/get-started/locally/) + Alternatively, you can install PyTorch using PyTorch wheels. You can check PyTorch installation guide in PyTorch [Getting Started](https://pytorch.org/get-started/locally/). Example: + + ```console + # Install PyTorch + $ pip uninstall torch -y + $ pip install --no-cache-dir --pre torch --index-url https://download.pytorch.org/whl/rocm6.3 + ``` 1. 
Install [Triton flash attention for ROCm](https://github.com/ROCm/triton) @@ -37,7 +43,7 @@ Currently, there are no pre-built ROCm wheels. pip uninstall -y triton git clone https://github.com/OpenAI/triton.git cd triton - git checkout e192dba + git checkout e5be006 cd python pip3 install . cd ../.. @@ -49,15 +55,15 @@ Currently, there are no pre-built ROCm wheels. 2. Optionally, if you choose to use CK flash attention, you can install [flash attention for ROCm](https://github.com/ROCm/flash-attention/tree/ck_tile) - Install ROCm's flash attention (v2.5.9.post1) following the instructions from [ROCm/flash-attention](https://github.com/ROCm/flash-attention/tree/ck_tile#amd-gpurocm-support) + Install ROCm's flash attention (v2.7.2) following the instructions from [ROCm/flash-attention](https://github.com/ROCm/flash-attention/tree/ck_tile#amd-gpurocm-support) Alternatively, wheels intended for vLLM use can be accessed under the releases. - For example, for ROCm 6.2, suppose your gfx arch is `gfx90a`. To get your gfx architecture, run `rocminfo |grep gfx`. + For example, for ROCm 6.3, suppose your gfx arch is `gfx90a`. To get your gfx architecture, run `rocminfo |grep gfx`. ```console git clone https://github.com/ROCm/flash-attention.git cd flash-attention - git checkout 3cea2fb + git checkout b7d29fb git submodule update --init GPU_ARCHS="gfx90a" python3 setup.py install cd .. @@ -67,20 +73,16 @@ Currently, there are no pre-built ROCm wheels. You might need to downgrade the "ninja" version to 1.10 it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`) ::: -3. Build vLLM. For example, vLLM on ROCM 6.2 can be built with the following steps: +3. Build vLLM. 
For example, vLLM on ROCM 6.3 can be built with the following steps: ```bash $ pip install --upgrade pip - # Install PyTorch - $ pip uninstall torch -y - $ pip install --no-cache-dir --pre torch --index-url https://download.pytorch.org/whl/rocm6.2 - # Build & install AMD SMI $ pip install /opt/rocm/share/amd_smi # Install dependencies - $ pip install --upgrade numba scipy huggingface-hub[cli] + $ pip install --upgrade numba scipy huggingface-hub[cli,hf_transfer] setuptools_scm $ pip install "numpy<2" $ pip install -r requirements-rocm.txt @@ -104,7 +106,7 @@ Currently, there are no pre-built ROCm wheels. For vLLM, please refer to [vLLM performance optimization](https://rocm.docs.amd.com/en/latest/how-to/tuning-guides/mi300x/workload.html#vllm-performance-optimization). ::: -## Set up using Docker +## Set up using Docker (Recommended) ### Pre-built images @@ -120,7 +122,12 @@ for instructions on how to use this prebuilt docker image. Building the Docker image from source is the recommended way to use vLLM with ROCm. -First, build a docker image from and launch a docker container from the image. +#### (Optional) Build an image with ROCm software stack + +Build a docker image from which setup ROCm software stack needed by the vLLM. +**This step is optional as this rocm_base image is usually prebuilt and store at [Docker Hub](https://hub.docker.com/r/rocm/vllm-dev) under tag `rocm/vllm-dev:base` to speed up user experience.** +If you choose to build this rocm_base image yourself, the steps are as follows. + It is important that the user kicks off the docker build using buildkit. Either the user put DOCKER_BUILDKIT=1 as environment variable when calling docker build command, or the user needs to setup buildkit in the docker daemon configuration /etc/docker/daemon.json as follows and restart the daemon: ```console @@ -131,7 +138,26 @@ It is important that the user kicks off the docker build using buildkit. 
Either } ``` - uses ROCm 6.2 by default, but also supports ROCm 5.7, 6.0 and 6.1 in older vLLM branches. +To build vllm on ROCm 6.3 for MI200 and MI300 series, you can use the default: + +```console +DOCKER_BUILDKIT=1 docker build -f Dockerfile.rocm_base -t rocm/vllm-dev:base . +``` + +#### Build an image with vLLM + +First, build a docker image from and launch a docker container from the image. +It is important that the user kicks off the docker build using buildkit. Either the user put `DOCKER_BUILDKIT=1` as environment variable when calling docker build command, or the user needs to setup buildkit in the docker daemon configuration /etc/docker/daemon.json as follows and restart the daemon: + +```console +{ + "features": { + "buildkit": true + } +} +``` + + uses ROCm 6.3 by default, but also supports ROCm 5.7, 6.0, 6.1, and 6.2, in older vLLM branches. It provides flexibility to customize the build of docker image using the following arguments: - `BASE_IMAGE`: specifies the base image used when running `docker build`. The default value `rocm/vllm-dev:base` is an image published and maintained by AMD. It is being built using @@ -141,13 +167,13 @@ It provides flexibility to customize the build of docker image using the followi Their values can be passed in when running `docker build` with `--build-arg` options. -To build vllm on ROCm 6.2 for MI200 and MI300 series, you can use the default: +To build vllm on ROCm 6.3 for MI200 and MI300 series, you can use the default: ```console DOCKER_BUILDKIT=1 docker build -f Dockerfile.rocm -t vllm-rocm . ``` -To build vllm on ROCm 6.2 for Radeon RX7900 series (gfx1100), you should pick the alternative base image: +To build vllm on ROCm 6.3 for Radeon RX7900 series (gfx1100), you should pick the alternative base image: ```console DOCKER_BUILDKIT=1 docker build --build-arg BASE_IMAGE="rocm/vllm-dev:navi_base" -f Dockerfile.rocm -t vllm-rocm . 
diff --git a/tests/quantization/test_fp8.py b/tests/quantization/test_fp8.py index 5616935eb..3a7f0a196 100644 --- a/tests/quantization/test_fp8.py +++ b/tests/quantization/test_fp8.py @@ -55,10 +55,21 @@ def test_kv_cache_model_load_and_run(vllm_runner, model_id: str): assert isinstance(attn.quant_method, Fp8KVCacheMethod) - # NOTE: it is valid for scales to be 1.0 (default value), but - # we know these checkpoints have scales < 1.0 - assert 0.0 < attn._k_scale < 1.0 - assert 0.0 < attn._v_scale < 1.0 + if not current_platform.is_rocm(): + # NOTE: This code path requires validation on Non-CUDA platform + # NOTE: it is valid for scales to be 1.0 (default value), but + # we know these checkpoints have scales < 1.0 + assert 0.0 < attn._k_scale < 1.0 + assert 0.0 < attn._v_scale < 1.0 + else: + # NOTE: This code path is for ROCm platform + # NOTE: it is valid for scales to be 1.0 (default value), but + # we know these checkpoints have scales < 1.0 + # However on ROCm platform, the _k_scale and _v_scale will be + # scaled by a factor of 2 as described in + # vllm/model_executor/layers/quantization/kv_cache.py + assert 0.0 < attn._k_scale < (1.0 * 2.0) + assert 0.0 < attn._v_scale < (1.0 * 2.0) llm.apply_model(check_model) @@ -91,13 +102,29 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool, assert attn._k_scale == 1.0 assert attn._v_scale == 1.0 - if current_platform.has_device_capability(89) and not force_marlin: - # For GPUs with hardware support, we keep weights in fp8 - assert fc1.weight.dtype == torch.float8_e4m3fn - else: - # For GPUs without hardware support, we pack the fp8 weights - # for weight-only quantization using Marlin kernels - assert fc1.weight.dtype == torch.int32 + if current_platform.is_cuda(): + if current_platform.has_device_capability( + 89) and not force_marlin: + # For GPUs with hardware support, we keep weights in fp8 + assert fc1.weight.dtype == torch.float8_e4m3fn + else: + # For GPUs without hardware support, 
we pack the fp8 weights + # for weight-only quantization using Marlin kernels + assert fc1.weight.dtype == torch.int32 + elif current_platform.is_rocm(): + # Only MI300 and above support quantization='fp8' + if current_platform.has_device_capability( + 94) and not force_marlin: + # For GPUs with hardware support, we keep weights in fp8 + assert fc1.weight.dtype == torch.float8_e4m3fnuz + else: # unsupported ROCm platform + pytest.skip( + "Skip `test_load_fp16_model`. " + "It only runs on ROCm platform with FP8 compute." + " e.g. MI300X and above.") + else: # unsupported platform + pytest.skip("Skip `test_load_fp16_model`. " + "It only runs on CUDA and ROCm platform.") llm.apply_model(check_model) diff --git a/tests/quantization/test_ptpc_fp8.py b/tests/quantization/test_ptpc_fp8.py new file mode 100644 index 000000000..9bbb5e327 --- /dev/null +++ b/tests/quantization/test_ptpc_fp8.py @@ -0,0 +1,55 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests whether PTPC w8a8 FP8 computation is enabled correctly. + +Run `pytest tests/quantization/test_ptpc_fp8.py --forked`. 
+""" +import pytest +import torch + +from tests.quantization.utils import is_quant_method_supported +from vllm.model_executor.layers.quantization.fp8 import Fp8KVCacheMethod +from vllm.model_executor.layers.quantization.ptpc_fp8 import ( + PTPCFp8LinearMethod) +from vllm.platforms import current_platform + + +@pytest.mark.skipif(not is_quant_method_supported("ptpc_fp8"), + reason="PTPC FP8 is not supported on this GPU type.") +@pytest.mark.skipif(not current_platform.is_rocm(), + reason="This test is for ROCm GPU.") +@pytest.mark.parametrize("dtype", ["auto", "bfloat16", "float16"]) +@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8", "fp8_e4m3"]) +def test_ptpc_fp8_rocm(vllm_runner, dtype: str, kv_cache_dtype: str) -> None: + + try: + with vllm_runner("facebook/opt-125m", + dtype=dtype, + quantization="ptpc_fp8", + kv_cache_dtype=kv_cache_dtype) as llm: + + model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501 + fc1 = model.model.decoder.layers[0].fc1 + assert isinstance(fc1.quant_method, PTPCFp8LinearMethod) + if kv_cache_dtype == "ptpc_fp8": + attn = model.model.decoder.layers[0].self_attn.attn + assert isinstance(attn.quant_method, Fp8KVCacheMethod) + assert attn._k_scale == 1.0 + assert attn._v_scale == 1.0 + + if current_platform.has_device_capability(94): + # For GPUs with hardware support, we keep weights in fp8 + assert fc1.weight.dtype == torch.float8_e4m3fnuz + else: + pytest.skip() + + output = llm.generate_greedy("Hello my name is", max_tokens=20) + assert output + except AssertionError as e: + if str( + e + ) == "Currently torch._scaled_mm (hipBLASLt) rowwise gemm only support output dtype of bfloat16. 
torch.float16 is specified.": # noqa: E501 + # If the error message matches, the test passes + pass + else: + # If the error message does not match, re-raise the exception + raise diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 6ded3874f..6cd508d05 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -11,6 +11,7 @@ QUANTIZATION_METHODS: List[str] = [ "deepspeedfp", "tpu_int8", "fp8", + "ptpc_fp8", "fbgemm_fp8", "modelopt", # The order of gptq methods is important for config.py iteration over @@ -99,6 +100,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: from .modelopt import ModelOptFp8Config from .moe_wna16 import MoeWNA16Config from .neuron_quant import NeuronQuantConfig + from .ptpc_fp8 import PTPCFp8Config from .qqq import QQQConfig from .tpu_int8 import Int8TpuConfig @@ -120,6 +122,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: "gptq": GPTQConfig, "compressed-tensors": CompressedTensorsConfig, "bitsandbytes": BitsAndBytesConfig, + "ptpc_fp8": PTPCFp8Config, "qqq": QQQConfig, "hqq": HQQMarlinConfig, "experts_int8": ExpertsInt8Config, diff --git a/vllm/model_executor/layers/quantization/ptpc_fp8.py b/vllm/model_executor/layers/quantization/ptpc_fp8.py new file mode 100644 index 000000000..1ded5389e --- /dev/null +++ b/vllm/model_executor/layers/quantization/ptpc_fp8.py @@ -0,0 +1,125 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Dict, List, Optional + +import torch +from torch.nn.parameter import Parameter + +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import (LinearBase, + UnquantizedLinearMethod) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizeMethodBase) +from vllm.model_executor.layers.quantization.fp8 import (Fp8Config, + 
Fp8KVCacheMethod, + Fp8LinearMethod) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + is_layer_skipped) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + apply_fp8_linear) +from vllm.platforms import current_platform + +ACTIVATION_SCHEMES = ["static", "dynamic"] + +logger = init_logger(__name__) + + +class PTPCFp8Config(Fp8Config): + """Config class for Per-Token-Per-Channel Dynamic Quantization Fp8.""" + + def __init__( + self, + activation_scheme: str = "dynamic", + ignored_layers: Optional[List[str]] = None, + ) -> None: + if not current_platform.is_rocm(): + raise ValueError( + "ptpc_fp8 quantization is supported only on ROCm.") + + if not current_platform.has_device_capability(94): + raise ValueError( + "ptpc_fp8 quantization is supported only on AMD Instinct MI300 GPUs and newer." # noqa: E501 + ) + if activation_scheme == "static": + raise ValueError( + "ptpc_fp8 as of now only support dynamic quantization.") + + super().__init__(is_checkpoint_fp8_serialized=False, + activation_scheme=activation_scheme, + ignored_layers=ignored_layers) + + @classmethod + def get_name(cls) -> str: + return "ptpc_fp8" + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "PTPCFp8Config": + activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) + ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None) + return cls(activation_scheme=activation_scheme, + ignored_layers=ignored_layers) + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: + from vllm.attention.layer import Attention # Avoid circular import + + if isinstance(layer, LinearBase): + if is_layer_skipped(prefix, self.ignored_layers): + return UnquantizedLinearMethod() + return PTPCFp8LinearMethod(self) + elif isinstance(layer, Attention): + return Fp8KVCacheMethod(self) + return None + + +class PTPCFp8LinearMethod(Fp8LinearMethod): + """Linear method for Per-Token and 
Per-Channel FP8 Quantization. + Only supports loading quantized BF16 model checkpoints with dynamic + activation scaling. To load FP16 model checkpoints, user must specify + to convert the FP16 model weight loading into BF16. + The weight scaling factor will be initialized after + the model weights are loaded. + + Limitations: + 1. Only support float8_e4m3fnuz data type due to the limitation of + torch._scaled_mm (https://github.com/ROCm/pytorch/blob/8c0504d7f3fb0ee4c278c096a5c3caedb01129fa/aten/src/ATen/native/cuda/Blas.cpp#L1041) + + Args: + quant_config: The quantization config. + """ + + def __init__(self, quant_config: PTPCFp8Config): + super().__init__(quant_config=quant_config) + # Force weight quantization + self.quant_config.is_checkpoint_fp8_serialized = False + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + layer.weight = torch.nn.Parameter(layer.weight.data, + requires_grad=False) + + assert layer.weight.data.dtype == torch.bfloat16, \ + f"Currently torch._scaled_mm (hipBLASLt) rowwise gemm only support output dtype of bfloat16. {str(layer.weight.data.dtype)} is specified." # noqa: E501 + # Quantize the weights. + qweight, weight_scale = ops.scaled_fp8_quant( + layer.weight, scale=None, use_per_token_if_dynamic=True) + + # Update the layer with the new values. 
+ layer.weight = Parameter( + qweight.t(), requires_grad=False) # Pretranspose the weight + layer.weight_scale = Parameter(weight_scale, requires_grad=False) + layer.input_scale = None + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + + return apply_fp8_linear(input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + input_scale=None, + input_scale_ub=None, + bias=bias, + cutlass_fp8_supported=False, + use_per_token_if_dynamic=True) diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index dedeb0c29..bea6390f7 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -11,6 +11,13 @@ from vllm.platforms import current_platform # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32) +# The condition to determine if it is on a platform that supports +# torch._scaled_mm rowwise feature. +# The condition is determined once as the operations +# are time consuming. +USE_ROWWISE_TORCH_SCALED_MM = (current_platform.is_rocm() + and current_platform.has_device_capability(94)) + def sparse_cutlass_supported() -> bool: if not current_platform.is_cuda(): @@ -172,6 +179,26 @@ def apply_fp8_linear( return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape) + elif (use_per_token_if_dynamic and not per_tensor_weights + and not per_tensor_activations and USE_ROWWISE_TORCH_SCALED_MM): + # For now validated on ROCm platform + # fp8 rowwise scaling in torch._scaled_mm is introduced in + # https://github.com/pytorch/pytorch/pull/144432 using + # hipBLASLt and ROCm 6.3, which only exists in torch 2.7 and above. 
+ # For CUDA platform please validate if the + # torch._scaled_mm support rowwise scaled GEMM + # Fused GEMM_DQ Rowwise GEMM + output = torch._scaled_mm(qinput, + weight, + out_dtype=input.dtype, + scale_a=x_scale, + scale_b=weight_scale.t(), + bias=bias) + + output = torch.narrow(output, 0, 0, input_2d.shape[0]) + output = output.view(*output_shape) + return output + else: # Fallback for channelwise case, where we use unfused DQ # due to limitations with scaled_mm diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 035766289..1f690b711 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -72,7 +72,7 @@ class RocmPlatform(Platform): supported_quantization: list[str] = [ "awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors", - "fbgemm_fp8", "gguf", "quark" + "fbgemm_fp8", "gguf", "quark", "ptpc_fp8" ] @classmethod -- GitLab From 932c6b74616d25199e87c96707e8cfea3ab045c0 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> Date: Fri, 7 Feb 2025 18:07:03 -0500 Subject: [PATCH 023/253] [V1] LM Eval With Streaming Integration Tests (#11590) --- .buildkite/test-pipeline.yaml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 7ef40564c..ab6a576b2 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -195,6 +195,9 @@ steps: # TODO: accuracy does not match, whether setting # VLLM_USE_FLASHINFER_SAMPLER or not on H100. - VLLM_USE_V1=1 pytest -v -s v1/e2e + # Integration test for streaming correctness (requires special branch). 
+ - pip install -U git+https://github.com/robertgshaw2-neuralmagic/lm-evaluation-harness.git@streaming-api + - pytest -v -s entrypoints/openai/test_accuracy.py::test_lm_eval_accuracy_v1_engine - label: Examples Test # 25min working_dir: "/vllm-workspace/examples" -- GitLab From 45cbc4991dcf405c959f774d07e66e7e9ac71f0c Mon Sep 17 00:00:00 2001 From: Lu Fang <30275821+houseroad@users.noreply.github.com> Date: Fri, 7 Feb 2025 16:39:50 -0800 Subject: [PATCH 024/253] [Bugfix] Fix disagg hang caused by the prefill and decode communication issues (#12723) Signed-off-by: Lu Fang --- .../kv_lookup_buffer/simple_buffer.py | 87 +++++++++---------- 1 file changed, 40 insertions(+), 47 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py index 5e1b62352..3462f7de0 100644 --- a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py @@ -10,7 +10,6 @@ stop the prefill instance when the decode instance is slow. """ import threading -import time from collections import deque from typing import Deque, List, Optional, Union @@ -29,13 +28,13 @@ class SimpleBuffer(KVLookupBufferBase): def __init__(self, signal_pipe: KVPipeBase, data_pipe: KVPipeBase, buffer_size_thresh: float): """ - signal_pipe: on CPU - - NOTE: on-device recv will block all threads in the process, making the - KV cache producer unable to listen to new request while transmitting - KV cache. Luckily CPU recv only blocks the current thread so we use + signal_pipe: on CPU + + NOTE: on-device recv will block all threads in the process, making the + KV cache producer unable to listen to new request while transmitting + KV cache. Luckily CPU recv only blocks the current thread so we use CPU recv to listen to new request. - + data_pipe: on device (e.g. 
GPU) """ @@ -43,7 +42,7 @@ class SimpleBuffer(KVLookupBufferBase): self.buffer_size = 0 self.buffer_size_threshold = buffer_size_thresh - self.buffer_lock = threading.Lock() + self.buffer_cv = threading.Condition() self.signal_pipe = signal_pipe self.data_pipe = data_pipe self.request_handling_thread: Optional[threading.Thread] = None @@ -116,11 +115,19 @@ class SimpleBuffer(KVLookupBufferBase): hidden = hidden.clone() buffer_item = [input_tokens, roi, key, value, hidden] + data_size = sum([self._get_element_size(data) for data in buffer_item]) + + with self.buffer_cv: + if self.buffer_size + data_size > self.buffer_size_threshold: + # log outside the while loop to avoid this message being logged + # repeatedly. + logger.debug("KV transfer buffer is full. Handling...") + while self.buffer_size + data_size > self.buffer_size_threshold: + self.buffer_cv.wait() - with self.buffer_lock: - for data in buffer_item: - self.buffer_size += self._get_element_size(data) + self.buffer_size += data_size self.buffer.append(buffer_item) + self.buffer_cv.notify() def _is_end_signal(self, signal): return signal is None @@ -143,35 +150,31 @@ class SimpleBuffer(KVLookupBufferBase): roi = (roi > 0.5) tokens_roi_recver = [input_tokens, roi] - matched_length = 0 - - # perform input tokens and roi matching - # FIXME: this matching is O(n), ideally it should be O(1) - # but this buffer size won't (and shouldn't) be too large so - # the fix is not urgent. - with self.buffer_lock: - + def is_buffer_available( + tokens_roi_recver: List[torch.Tensor], ) -> bool: + # perform input tokens and roi matching + # FIXME: this matching is O(n), ideally it should be O(1) + # but this buffer size won't (and shouldn't) be too large so + # the fix is not urgent. 
for _ in range(len(self.buffer)): - - temp_length = self._matches(self.buffer[0], - tokens_roi_recver) - if temp_length > 0: - matched_length = temp_length - break + if self._matches(self.buffer[0], + tokens_roi_recver) > 0: + return True # rotate the element we just accessed to the end self.buffer.rotate(-1) - - if matched_length > 0: - # need to clone the tensor - # in case the tensor is freed before sending finishes - matched_item = self.buffer.popleft() - for tensor in matched_item: - self._send_tensor_and_dec_size(tensor) - - else: - # no match, just send None - for _ in range(5): - self.data_pipe.send_tensor(None) + return False + + with self.buffer_cv: + while not is_buffer_available(tokens_roi_recver): + logger.debug( + "KV transfer buffer is not available. Waiting...") + self.buffer_cv.wait() + # need to clone the tensor + # in case the tensor is freed before sending finishes + matched_item = self.buffer.popleft() + for tensor in matched_item: + self._send_tensor_and_dec_size(tensor) + self.buffer_cv.notify() except RuntimeError as e: if 'Connection closed by peer' not in str(e): @@ -208,20 +211,10 @@ class SimpleBuffer(KVLookupBufferBase): return [input_tokens, roi, key, value, hidden] - def full_handler(self): - time.sleep(0.001) - def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, key: torch.Tensor, value: torch.Tensor, hidden: torch.Tensor) -> None: - if self.buffer_size > self.buffer_size_threshold: - # log outside the while loop to avoid this message being logged - # repeatedly. - logger.debug("KV transfer buffer is full. 
Handling...") - while self.buffer_size > self.buffer_size_threshold: - self.full_handler() - self._add_to_buffer(input_tokens, roi, key, value, hidden) # when calling the insert, the current process is a sender -- GitLab From b21f0f9d173e5094791815380cb213f9eed44bad Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 7 Feb 2025 19:07:37 -0800 Subject: [PATCH 025/253] [V1][Minor] Remove outdated comment (#12928) Signed-off-by: Woosuk Kwon --- vllm/v1/core/kv_cache_manager.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index de349ec12..df3dc6c28 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -299,9 +299,7 @@ class KVCacheManager: While all scheduled requests must be in the RUNNING state, the inverse is not necessarily true. There may be RUNNING requests that are not - scheduled in the current step. As of 1/1/2025, the scheduler does not - allow this case, but it is possible in the future, as we allow more - flexible scheduling. + scheduled in the current step. This can result in an edge case where the number of common prefix blocks is 0, even though all scheduled requests share a common prefix. 
This -- GitLab From 3243158336d377c8aced151722b5f8bbff2f905d Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 7 Feb 2025 19:14:10 -0800 Subject: [PATCH 026/253] [V1] Move KV block hashes from Request to KVCacheManager (#12922) Signed-off-by: Woosuk Kwon --- tests/v1/core/test_prefix_caching.py | 21 ++++++++++--------- vllm/v1/core/kv_cache_manager.py | 31 +++++++++++++++++++++------- vllm/v1/core/scheduler.py | 1 + vllm/v1/request.py | 13 ------------ 4 files changed, 35 insertions(+), 31 deletions(-) diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index a6c0162d3..d598d1257 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -51,7 +51,7 @@ def test_prefill(): all_token_ids = common_token_ids + unique_token_ids req0 = make_request("0", all_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) - assert len(req0.kv_block_hashes) == 3 + assert len(manager.req_to_block_hashes[req0.request_id]) == 3 assert not computed_blocks assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, 55, computed_blocks) @@ -76,7 +76,7 @@ def test_prefill(): unique_token_ids = [3] * 5 req1 = make_request("1", common_token_ids + unique_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) - assert len(req1.kv_block_hashes) == 3 + assert len(manager.req_to_block_hashes[req1.request_id]) == 3 assert [b.block_id for b in computed_blocks] == [0, 1, 2] assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 @@ -107,7 +107,7 @@ def test_prefill(): unique_token_ids = [3] * 6 req2 = make_request("2", common_token_ids + unique_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) - assert len(req2.kv_block_hashes) == 3 + assert len(manager.req_to_block_hashes[req2.request_id]) == 3 assert [b.block_id for b in computed_blocks] == [0, 1, 2] assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 
* 16 @@ -494,10 +494,11 @@ def test_mm_prefix_caching(): # Completed block should have hashes with extra keys. assert not computed_blocks assert num_computed_tokens == 0 - assert len(req0.kv_block_hashes) == 3 - assert req0.kv_block_hashes[0].extra_keys == ("aaa", ) - assert req0.kv_block_hashes[1].extra_keys == ("aaa", "bbb") - assert req0.kv_block_hashes[2].extra_keys == ("bbb", ) + block_hashes = manager.req_to_block_hashes[req0.request_id] + assert len(block_hashes) == 3 + assert block_hashes[0].extra_keys == ("aaa", ) + assert block_hashes[1].extra_keys == ("aaa", "bbb") + assert block_hashes[2].extra_keys == ("bbb", ) blocks = manager.allocate_slots(req0, 59, computed_blocks) assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4] @@ -510,8 +511,8 @@ def test_mm_prefix_caching(): assert new_blocks is not None and len(new_blocks) == 0 # The just completed block should have hashes with extra keys. - assert len(req0.kv_block_hashes) == 4 - assert req0.kv_block_hashes[3].extra_keys == ("ccc", ) + assert len(block_hashes) == 4 + assert block_hashes[3].extra_keys == ("ccc", ) # Cache hit. unique_token_ids = [-1] * 7 + [200] * 5 @@ -613,7 +614,7 @@ def test_reset_prefix_cache(): all_token_ids = full_block_token_ids + unique_token_ids req1 = make_request("1", all_token_ids) computed_blocks, _ = manager.get_computed_blocks(req1) - assert len(req1.kv_block_hashes) == 3 + assert len(manager.req_to_block_hashes[req1.request_id]) == 3 assert len(computed_blocks) == 3 blocks = manager.allocate_slots(req1, 7, computed_blocks) assert [b.block_id for b in blocks] == [4] diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index df3dc6c28..eefc2e19c 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -72,6 +72,12 @@ class KVCacheManager: self.req_to_blocks: DefaultDict[str, List[KVCacheBlock]] = defaultdict(list) + # Mapping from request ID to kv block hashes. 
+ # This is to avoid recomputing the block hashes for each call of + # `get_computed_blocks` or `allocate_slots`. + self.req_to_block_hashes: DefaultDict[ + str, List[BlockHashType]] = defaultdict(list) + @property def usage(self) -> float: return 1.0 - (self.free_block_queue.num_free_blocks / @@ -97,11 +103,11 @@ class KVCacheManager: computed_blocks = [] # The block hashes for the request may already be computed - # if the request was preempted and resumed. - if not request.kv_block_hashes: - request.set_kv_block_hashes( - hash_request_tokens(self.block_size, request)) - block_hashes = request.kv_block_hashes + # if the scheduler has tried to schedule the request before. + block_hashes = self.req_to_block_hashes[request.request_id] + if not block_hashes: + block_hashes = hash_request_tokens(self.block_size, request) + self.req_to_block_hashes[request.request_id] = block_hashes for block_hash in block_hashes: # block_hashes is a chain of block hashes. If a block hash is not @@ -435,7 +441,8 @@ class KVCacheManager: full_blocks: The list of blocks to update hash metadata. prev_block: The previous block in the chain. """ - num_cached_block_hashes = len(request.kv_block_hashes) + block_hashes = self.req_to_block_hashes[request.request_id] + num_cached_block_hashes = len(block_hashes) # Update the new blocks with the block hashes through the chain. prev_block_hash_value = None @@ -468,7 +475,7 @@ class KVCacheManager: # this request (either the prompt tokens or the previously # generated tokens with preemption). In this case we simply # reuse the block hash. - block_hash = request.kv_block_hashes[blk_idx] + block_hash = block_hashes[blk_idx] else: # Otherwise compute the block hash and cache it in the request # in case it will be preempted in the future. @@ -490,9 +497,17 @@ class KVCacheManager: # Compute the hash of the current block. 
block_hash = hash_block_tokens(prev_block_hash_value, block_tokens, extra_keys) - request.append_kv_block_hashes(block_hash) + block_hashes.append(block_hash) # Update and added the full block to the cache. blk.block_hash = block_hash self.cached_block_hash_to_block[block_hash][blk.block_id] = blk prev_block_hash_value = block_hash.hash_value + + def free_block_hashes(self, request: Request) -> None: + """Discard the block hashes for the request. + + NOTE: Unlike `free`, this method should be called only when the request + is finished, not when it is preempted. + """ + self.req_to_block_hashes.pop(request.request_id, None) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 35d9424f9..1aa34ee38 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -579,6 +579,7 @@ class Scheduler: def _free_request(self, request: Request) -> None: assert request.is_finished() self.kv_cache_manager.free(request) + self.kv_cache_manager.free_block_hashes(request) self.encoder_cache_manager.free(request) self._cached_reqs_data.pop(request.request_id, None) del self.requests[request.request_id] diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 89b39ea61..bb4d2c191 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -12,7 +12,6 @@ from vllm.v1.utils import ConstantList if TYPE_CHECKING: from vllm.multimodal import MultiModalKwargs from vllm.multimodal.inputs import PlaceholderRange - from vllm.v1.core.kv_cache_utils import BlockHashType class Request: @@ -63,11 +62,6 @@ class Request: if self.mm_hashes: assert len(self.mm_inputs) == len(self.mm_hashes) - # Cache the computed kv block hashes of the request to avoid - # recomputing. - self._kv_block_hashes: List[BlockHashType] = [] - self.kv_block_hashes = ConstantList(self._kv_block_hashes) - # Read-only views # Prevent directly appending to the these lists since # they should also be updated simultaneously. 
@@ -124,13 +118,6 @@ class Request: num_tokens = self.mm_positions[input_id]["length"] return num_tokens - def set_kv_block_hashes(self, value: List["BlockHashType"]) -> None: - self._kv_block_hashes = value - self.kv_block_hashes = ConstantList(self._kv_block_hashes) - - def append_kv_block_hashes(self, block_hash: "BlockHashType") -> None: - self._kv_block_hashes.append(block_hash) - class RequestStatus(enum.IntEnum): """Status of a request.""" -- GitLab From 306923da82535593bc508cc4e039bdec55159e9f Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Sat, 8 Feb 2025 13:02:53 +0800 Subject: [PATCH 027/253] [Bugfix] Fix Qwen2_5_VLForConditionalGeneration packed_modules_mapping (#12905) --- vllm/model_executor/models/qwen2_5_vl.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index e93cf46b9..1f350ab20 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -760,9 +760,12 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, "q_proj", "k_proj", "v_proj", - ] + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], } - # LoRA specific attributes, TODO: double check supported_lora_modules = [ "qkv_proj", -- GitLab From cc01223f3ba0434487a0179a2ccd2107bf3c93cb Mon Sep 17 00:00:00 2001 From: Ke Zhao Date: Sat, 8 Feb 2025 14:56:43 +0800 Subject: [PATCH 028/253] [Misc] Fix typo in the example file (#12896) Signed-off-by: Zhao Ke --- .../openai_chat_embedding_client_for_multimodal.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/online_serving/openai_chat_embedding_client_for_multimodal.py b/examples/online_serving/openai_chat_embedding_client_for_multimodal.py index f49d7a228..e41062037 100644 --- a/examples/online_serving/openai_chat_embedding_client_for_multimodal.py +++ b/examples/online_serving/openai_chat_embedding_client_for_multimodal.py @@ -44,7 +44,7 @@ 
def vlm2vec(): def dse_qwen2_vl(inp: dict): # Embedding an Image - if inp["dtype"] == "image": + if inp["type"] == "image": messages = [{ "role": "user", @@ -113,10 +113,10 @@ if __name__ == '__main__': vlm2vec() elif args.model == "dse_qwen2_vl": dse_qwen2_vl({ - "dtye": "image", + "type": "image", "image_url": image_url, }) dse_qwen2_vl({ - "dtype": "text", + "type": "text", "content": "What is the weather like today?", }) -- GitLab From d01f66b0394e62a11429c8f0afd9a56b7b2b7f0c Mon Sep 17 00:00:00 2001 From: zifeitong Date: Fri, 7 Feb 2025 23:04:34 -0800 Subject: [PATCH 029/253] [Bugfix] Fix multi-round chat error when mistral tokenizer is used (#12859) Signed-off-by: Zifei Tong Co-authored-by: Cyrus Leung --- vllm/transformers_utils/tokenizers/mistral.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py index 1550f978e..7a1dba424 100644 --- a/vllm/transformers_utils/tokenizers/mistral.py +++ b/vllm/transformers_utils/tokenizers/mistral.py @@ -291,6 +291,16 @@ class MistralTokenizer: from mistral_common.protocol.instruct.request import ( ChatCompletionRequest) + + # mistral-common requires AssistantMessage content to be string [1]. 
+ # + # [1]: https://github.com/mistralai/mistral-common/blob/f4a06998b75ed78bbf5aaf569590b772ea26c9f6/src/mistral_common/protocol/instruct/messages.py#L80 + for message in messages: + if message.get("role") == "assistant": + content = message.get("content") + if isinstance(content, list): + content = "\n".join(chunk.get("text") for chunk in content) + message["content"] = content request = ChatCompletionRequest(messages=messages, tools=tools) # type: ignore[type-var] encoded = self.mistral.encode_chat_completion(request) -- GitLab From 91dd8f7aa63a1923cc17868c7646d1277d64ed53 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 8 Feb 2025 16:17:08 +0800 Subject: [PATCH 030/253] [bugfix] respect distributed_executor_backend in world_size=1 (#12934) Signed-off-by: youkaichao --- ...st_custom_executor.py => test_executor.py} | 21 ++++++++- vllm/config.py | 3 ++ vllm/engine/llm_engine.py | 44 +++++++++---------- vllm/v1/executor/abstract.py | 17 ++++--- 4 files changed, 53 insertions(+), 32 deletions(-) rename tests/engine/{test_custom_executor.py => test_executor.py} (79%) diff --git a/tests/engine/test_custom_executor.py b/tests/engine/test_executor.py similarity index 79% rename from tests/engine/test_custom_executor.py rename to tests/engine/test_executor.py index 3e77faecb..84cc3ed63 100644 --- a/tests/engine/test_custom_executor.py +++ b/tests/engine/test_executor.py @@ -55,6 +55,7 @@ def test_custom_executor(model, tmp_path): engine_args = EngineArgs( model=model, distributed_executor_backend=CustomUniExecutor, + enforce_eager=True, # reduce test time ) engine = LLMEngine.from_engine_args(engine_args) sampling_params = SamplingParams(max_tokens=1) @@ -75,7 +76,10 @@ def test_custom_executor_async(model, tmp_path): assert not os.path.exists(".marker") engine_args = AsyncEngineArgs( - model=model, distributed_executor_backend=CustomUniExecutorAsync) + model=model, + distributed_executor_backend=CustomUniExecutorAsync, + enforce_eager=True, # reduce test time + ) 
engine = AsyncLLMEngine.from_engine_args(engine_args) sampling_params = SamplingParams(max_tokens=1) @@ -89,3 +93,18 @@ def test_custom_executor_async(model, tmp_path): assert os.path.exists(".marker") finally: os.chdir(cwd) + + +@pytest.mark.parametrize("model", ["facebook/opt-125m"]) +def test_respect_ray(model): + # even for TP=1 and PP=1, + # if users specify ray, we should use ray. + # users might do this if they want to manage the + # resources using ray. + engine_args = EngineArgs( + model=model, + distributed_executor_backend="ray", + enforce_eager=True, # reduce test time + ) + engine = LLMEngine.from_engine_args(engine_args) + assert engine.model_executor.uses_ray diff --git a/vllm/config.py b/vllm/config.py index 5579d6936..426ba3808 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1401,6 +1401,9 @@ class ParallelConfig: logger.info("Defaulting to use %s for distributed inference", backend) + if self.distributed_executor_backend is None and self.world_size == 1: + self.distributed_executor_backend = "uni" + self._verify_args() @property diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index d82d9ad9d..2e5bc75c6 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -434,6 +434,7 @@ class LLMEngine: @classmethod def _get_executor_cls(cls, engine_config: VllmConfig) -> Type[ExecutorBase]: + # distributed_executor_backend must be set in VllmConfig.__post_init__ distributed_executor_backend = ( engine_config.parallel_config.distributed_executor_backend) # Initialize the cluster and specify the executor class. @@ -443,30 +444,29 @@ class LLMEngine: "distributed_executor_backend must be a subclass of " f"ExecutorBase. 
Got {distributed_executor_backend}.") executor_class = distributed_executor_backend - elif engine_config.parallel_config.world_size > 1: - if distributed_executor_backend == "ray": - from vllm.executor.ray_distributed_executor import ( - RayDistributedExecutor) - executor_class = RayDistributedExecutor - elif distributed_executor_backend == "mp": - from vllm.executor.mp_distributed_executor import ( - MultiprocessingDistributedExecutor) - assert not envs.VLLM_USE_RAY_SPMD_WORKER, ( - "multiprocessing distributed executor backend does not " - "support VLLM_USE_RAY_SPMD_WORKER=1") - executor_class = MultiprocessingDistributedExecutor - elif distributed_executor_backend == "uni": - # JAX-style, single-process, multi-device executor. - from vllm.executor.uniproc_executor import UniProcExecutor - executor_class = UniProcExecutor - elif distributed_executor_backend == "external_launcher": - # executor with external launcher - from vllm.executor.uniproc_executor import ( # noqa - ExecutorWithExternalLauncher) - executor_class = ExecutorWithExternalLauncher - else: + elif distributed_executor_backend == "ray": + from vllm.executor.ray_distributed_executor import ( + RayDistributedExecutor) + executor_class = RayDistributedExecutor + elif distributed_executor_backend == "mp": + from vllm.executor.mp_distributed_executor import ( + MultiprocessingDistributedExecutor) + assert not envs.VLLM_USE_RAY_SPMD_WORKER, ( + "multiprocessing distributed executor backend does not " + "support VLLM_USE_RAY_SPMD_WORKER=1") + executor_class = MultiprocessingDistributedExecutor + elif distributed_executor_backend == "uni": + # JAX-style, single-process, multi-device executor. 
from vllm.executor.uniproc_executor import UniProcExecutor executor_class = UniProcExecutor + elif distributed_executor_backend == "external_launcher": + # executor with external launcher + from vllm.executor.uniproc_executor import ( # noqa + ExecutorWithExternalLauncher) + executor_class = ExecutorWithExternalLauncher + else: + raise ValueError("unrecognized distributed_executor_backend: " + f"{distributed_executor_backend}") return executor_class @classmethod diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py index ac10d43eb..093be09ae 100644 --- a/vllm/v1/executor/abstract.py +++ b/vllm/v1/executor/abstract.py @@ -25,15 +25,14 @@ class Executor(ExecutorBase): parallel_config = vllm_config.parallel_config distributed_executor_backend = ( parallel_config.distributed_executor_backend) - if distributed_executor_backend is None: - # If the user does not specify the distributed executor backend, - # we will choose the backend based on the world size. - if parallel_config.world_size > 1: - distributed_executor_backend = "mp" - else: - distributed_executor_backend = "uni" - - if distributed_executor_backend == "ray": + # distributed_executor_backend must be set in VllmConfig.__post_init__ + if isinstance(distributed_executor_backend, type): + if not issubclass(distributed_executor_backend, ExecutorBase): + raise TypeError( + "distributed_executor_backend must be a subclass of " + f"ExecutorBase. 
Got {distributed_executor_backend}.") + executor_class = distributed_executor_backend + elif distributed_executor_backend == "ray": executor_class = RayDistributedExecutor elif distributed_executor_backend == "mp": from vllm.v1.executor.multiproc_executor import MultiprocExecutor -- GitLab From e31498bdcbd70f91786fd2f23f4afabdd4256f1c Mon Sep 17 00:00:00 2001 From: Shaoting Date: Sat, 8 Feb 2025 02:38:20 -0600 Subject: [PATCH 031/253] [Misc] Add offline test for disaggregated prefill (#12418) --- .../disaggregated_prefill.py | 111 ++++++++++++++++++ 1 file changed, 111 insertions(+) create mode 100644 examples/offline_inference/disaggregated_prefill.py diff --git a/examples/offline_inference/disaggregated_prefill.py b/examples/offline_inference/disaggregated_prefill.py new file mode 100644 index 000000000..2e41cabac --- /dev/null +++ b/examples/offline_inference/disaggregated_prefill.py @@ -0,0 +1,111 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +This file demonstrates the example usage of disaggregated prefilling +We will launch 2 vllm instances (GPU 0 for prefill and GPU 1 for decode), +and then transfer the KV cache between them. +""" +import os +import time +from multiprocessing import Event, Process + +from vllm import LLM, SamplingParams +from vllm.config import KVTransferConfig + + +def run_prefill(prefill_done): + # We use GPU 0 for prefill node. + os.environ["CUDA_VISIBLE_DEVICES"] = "0" + + # The prefill node receives two requests, while the decode node receives + # three requests. So the decode node will only receive the KV Cache for + # requests 1 and 3. The decode node will use the KV Cache of requests 1 + # and 3 and do prefilling on request 2. + prompts = [ + "Hello, my name is", + # "Hi, your name is", + # The decode node will actually "prefill" this request. + "Tell me a very long story", + ] + sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1) + + # Using PyNcclConnector to transmit KV caches between vLLM instances. 
+ # This instance is the prefill node (kv_producer, rank 0). + # The number of parallel instances for KV cache transfer is set to 2, + # as required for PyNcclConnector. + ktc = KVTransferConfig.from_cli( + '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}' + ) + + # Set GPU memory utilization to 0.8 for an A6000 GPU with 40GB + # memory. You may need to adjust the value to fit your GPU. + llm = LLM(model="meta-llama/Meta-Llama-3.1-8B-Instruct", + kv_transfer_config=ktc, + max_model_len=2000, + gpu_memory_utilization=0.8) + + llm.generate(prompts, sampling_params) + print("Prefill node is finished.") + prefill_done.set() + + # To keep the prefill node running in case the decode node is not done; + # otherwise, the script might exit prematurely, causing incomplete decoding. + try: + while True: + time.sleep(1) + except KeyboardInterrupt: + print("Script stopped by user.") + + +def run_decode(prefill_done): + # We use GPU 1 for decode node. + os.environ["CUDA_VISIBLE_DEVICES"] = "1" + + prompts = [ + "Hello, my name is", + "Hi, your name is", + "Tell me a very long story", + ] + sampling_params = SamplingParams(temperature=0, top_p=0.95) + + # Using PyNcclConnector to transmit KV caches between vLLM instances. + # This instance is the decode node (kv_consumer, rank 1). + # The number of parallel instances for KV cache transfer is set to 2, + # as required for PyNcclConnector. + ktc = KVTransferConfig.from_cli( + '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}' + ) + + # Set GPU memory utilization to 0.8 for an A6000 GPU with 40GB + # memory. You may need to adjust the value to fit your GPU. 
+ llm = LLM(model="meta-llama/Meta-Llama-3.1-8B-Instruct", + kv_transfer_config=ktc, + max_model_len=2000, + gpu_memory_utilization=0.8) + + # Wait for the producer to start the pipe + print("Waiting for prefill node to finish...") + prefill_done.wait() + + # At this point when the prefill_done is set, the kv-cache should have been + # transferred to this decode node, so we can start decoding. + outputs = llm.generate(prompts, sampling_params) + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + +if __name__ == "__main__": + prefill_done = Event() + prefill_process = Process(target=run_prefill, args=(prefill_done, )) + decode_process = Process(target=run_decode, args=(prefill_done, )) + + # Start prefill node + prefill_process.start() + + # Start decode node + decode_process.start() + + # Terminate the prefill node when decode is finished + decode_process.join() + prefill_process.terminate() -- GitLab From 4ea48fb35cf67d61a1c3f18e3981c362e1d8e26f Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 8 Feb 2025 00:39:09 -0800 Subject: [PATCH 032/253] [V1][Minor] Move cascade attn logic outside _prepare_inputs (#12943) Signed-off-by: Woosuk Kwon --- vllm/v1/worker/gpu_model_runner.py | 150 +++++++++++++++++------------ 1 file changed, 89 insertions(+), 61 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 561c3cf39..e0a096a91 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -476,67 +476,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): self.device, non_blocking=True).long() # Prepare for cascade attention if needed. - common_prefix_len = (scheduler_output.num_common_prefix_blocks * - self.block_size) - if common_prefix_len == 0: - # Common case. 
- use_cascade = False - else: - # NOTE(woosuk): Cascade attention uses two attention kernels: one - # for the common prefix and the other for the rest. For the first - # kernel, we concatenate all the query tokens (possibly from - # different requests) and treat them as if they are from the same - # request. Then, we use bi-directional attention to process the - # common prefix in the KV cache. Importantly, this means that the - # first kernel does not do any masking. - - # Consider the following example: - # Request 1's input query: [D, E, X] - # Request 1's kv cache: [A, B, C, D, E, X] - # Request 1's num_computed_tokens: 3 (i.e., [A, B, C]) - # Request 2's input query: [E, Y] - # Request 2's kv cache: [A, B, C, D, E, Y] - # Request 2's num_computed_tokens: 4 (i.e., [A, B, C, D]) - - # If we use [A, B, C, D, E] as the common prefix, then the - # first kernel will compute the bi-directional attention between - # input query [D, E, X, E, Y] and common prefix [A, B, C, D, E]. - # However, this is wrong because D in Request 1 should not attend to - # E in the common prefix (i.e., we need masking). - # To avoid this, [A, B, C, D] should be the common prefix. - # That is, the common prefix should be capped by the minimum - # num_computed_tokens among the requests, and plus one to include - # the first token of the query. - - # In practice, we use [A, B, C] as the common prefix, instead of - # [A, B, C, D] (i.e., the common prefix is capped by the minimum - # num_computed_tokens, without plus one). - # This is because of an implementation detail: We want to always - # use two kernels for cascade attention. Let's imagine: - # Request 3's input query: [D] - # Request 3's kv cache: [A, B, C, D] - # Request 3's num_computed_tokens: 4 (i.e., [A, B, C, D]) - # If we use [A, B, C, D] as the common prefix for Request 1-3, - # then Request 3 will be processed only by the first kernel, - # and the second kernel will get an empty input. 
While this is not - # a fundamental problem, our current implementation does not support - # this case. - common_prefix_len = min( - common_prefix_len, - self.input_batch.num_computed_tokens_cpu[:num_reqs].min()) - # common_prefix_len should be a multiple of the block size. - common_prefix_len = (common_prefix_len // self.block_size * - self.block_size) - use_cascade = FlashAttentionBackend.use_cascade_attention( - common_prefix_len=common_prefix_len, - query_lens=num_scheduled_tokens, - num_query_heads=self.num_query_heads, - num_kv_heads=self.num_kv_heads, - use_alibi=False, # FIXME - use_sliding_window=self.sliding_window is not None, - num_sms=self.num_sms, - ) - + common_prefix_len = self._compute_cascade_attn_prefix_len( + num_scheduled_tokens, + scheduler_output.num_common_prefix_blocks, + ) + use_cascade = common_prefix_len > 0 if use_cascade: # TODO: Optimize. cu_prefix_query_lens = torch.tensor( @@ -581,6 +525,90 @@ class GPUModelRunner(LoRAModelRunnerMixin): logits_indices = query_start_loc[1:] - 1 return attn_metadata, logits_indices + def _compute_cascade_attn_prefix_len( + self, + num_scheduled_tokens: np.ndarray, + num_common_prefix_blocks: int, + ) -> int: + """Compute the length of the common prefix for cascade attention. + + NOTE(woosuk): The common prefix length returned by this function + represents the length used specifically for cascade attention, not the + actual number of tokens shared between requests. When cascade attention + is disabled (use_cascade=False), this function returns 0 even if + requests share common tokens. Additionally, the common prefix length is + truncated to a multiple of the block size and may be further truncated + due to implementation details explained below. + + Args: + num_scheduled_tokens: Number of tokens scheduled per request. + num_common_prefix_blocks: Number of shared KV cache blocks. + + Returns: + int: Length of common prefix in tokens. 
+ """ + common_prefix_len = num_common_prefix_blocks * self.block_size + if common_prefix_len == 0: + # Common case. + return 0 + + # NOTE(woosuk): Cascade attention uses two attention kernels: one + # for the common prefix and the other for the rest. For the first + # kernel, we concatenate all the query tokens (possibly from + # different requests) and treat them as if they are from the same + # request. Then, we use bi-directional attention to process the + # common prefix in the KV cache. Importantly, this means that the + # first kernel does not do any masking. + + # Consider the following example: + # Request 1's input query: [D, E, X] + # Request 1's kv cache: [A, B, C, D, E, X] + # Request 1's num_computed_tokens: 3 (i.e., [A, B, C]) + # Request 2's input query: [E, Y] + # Request 2's kv cache: [A, B, C, D, E, Y] + # Request 2's num_computed_tokens: 4 (i.e., [A, B, C, D]) + + # If we use [A, B, C, D, E] as the common prefix, then the + # first kernel will compute the bi-directional attention between + # input query [D, E, X, E, Y] and common prefix [A, B, C, D, E]. + # However, this is wrong because D in Request 1 should not attend to + # E in the common prefix (i.e., we need masking). + # To avoid this, [A, B, C, D] should be the common prefix. + # That is, the common prefix should be capped by the minimum + # num_computed_tokens among the requests, and plus one to include + # the first token of the query. + + # In practice, we use [A, B, C] as the common prefix, instead of + # [A, B, C, D] (i.e., the common prefix is capped by the minimum + # num_computed_tokens, without plus one). + # This is because of an implementation detail: We want to always + # use two kernels for cascade attention. 
Let's imagine: + # Request 3's input query: [D] + # Request 3's kv cache: [A, B, C, D] + # Request 3's num_computed_tokens: 4 (i.e., [A, B, C, D]) + # If we use [A, B, C, D] as the common prefix for Request 1-3, + # then Request 3 will be processed only by the first kernel, + # and the second kernel will get an empty input. While this is not + # a fundamental problem, our current implementation does not support + # this case. + num_reqs = len(num_scheduled_tokens) + common_prefix_len = min( + common_prefix_len, + self.input_batch.num_computed_tokens_cpu[:num_reqs].min()) + # common_prefix_len should be a multiple of the block size. + common_prefix_len = (common_prefix_len // self.block_size * + self.block_size) + use_cascade = FlashAttentionBackend.use_cascade_attention( + common_prefix_len=common_prefix_len, + query_lens=num_scheduled_tokens, + num_query_heads=self.num_query_heads, + num_kv_heads=self.num_kv_heads, + use_alibi=False, # FIXME + use_sliding_window=self.sliding_window is not None, + num_sms=self.num_sms, + ) + return common_prefix_len if use_cascade else 0 + def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"): mrope_pos_ptr = 0 num_reqs = self.input_batch.num_reqs -- GitLab From 407b5537db02da122abf863673fc6cb76795e8bd Mon Sep 17 00:00:00 2001 From: wangxiyuan Date: Sat, 8 Feb 2025 17:15:15 +0800 Subject: [PATCH 033/253] [Build] Make pypi install work on CPU platform (#12874) --- setup.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index a4043c43a..dc517dafa 100755 --- a/setup.py +++ b/setup.py @@ -47,6 +47,11 @@ elif not (sys.platform.startswith("linux") "Building on %s, " "so vLLM may not be able to run correctly", sys.platform) VLLM_TARGET_DEVICE = "empty" +elif (sys.platform.startswith("linux") and torch.version.cuda is None + and os.getenv("VLLM_TARGET_DEVICE") is None): + # if cuda is not available and VLLM_TARGET_DEVICE is not set, + # fallback to cpu + VLLM_TARGET_DEVICE = 
"cpu" MAIN_CUDA_VERSION = "12.1" @@ -482,7 +487,6 @@ def get_vllm_version() -> str: version = get_version( write_to="vllm/_version.py", # TODO: move this to pyproject.toml ) - sep = "+" if "+" not in version else "." # dev versions might contain + if _no_device(): @@ -520,7 +524,8 @@ def get_vllm_version() -> str: elif _is_tpu(): version += f"{sep}tpu" elif _is_cpu(): - version += f"{sep}cpu" + if envs.VLLM_TARGET_DEVICE == "cpu": + version += f"{sep}cpu" elif _is_xpu(): version += f"{sep}xpu" else: -- GitLab From 2880e21e3d2513c89bd63ac05b718e0c0a50e4e4 Mon Sep 17 00:00:00 2001 From: Sanju C Sudhakaran Date: Sat, 8 Feb 2025 14:45:30 +0530 Subject: [PATCH 034/253] [Hardware][Intel-Gaudi] Enable long-contexts + LoRA support for Intel Gaudi (#12812) Signed-off-by: Sanju C Sudhakaran --- vllm/lora/punica_wrapper/punica_hpu.py | 57 ++++++++++++++++++- .../model_executor/layers/rotary_embedding.py | 3 +- vllm/worker/hpu_model_runner.py | 17 +++++- 3 files changed, 73 insertions(+), 4 deletions(-) diff --git a/vllm/lora/punica_wrapper/punica_hpu.py b/vllm/lora/punica_wrapper/punica_hpu.py index 51e1bfab3..3661a7214 100644 --- a/vllm/lora/punica_wrapper/punica_hpu.py +++ b/vllm/lora/punica_wrapper/punica_hpu.py @@ -1,12 +1,18 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Optional, Tuple, Union, final +from typing import TYPE_CHECKING, List, Optional, Tuple, Union, final import torch from vllm_hpu_extension.ops import (dispatch_bgmv_embedding, dispatch_bgmv_linear) from .punica_base import PunicaWrapperBase +from .utils import convert_mapping + +if TYPE_CHECKING: + # avoid circuit import + from vllm.lora.layers import LoRAMapping + from vllm.lora.models import LongContextLoRAContext @final @@ -19,6 +25,55 @@ class PunicaWrapperHPU(PunicaWrapperBase): PunicaWrapperBase.__init__(self, 3 * max_num_batched_tokens, max_batches, device) + def _update_base_metadata( + self, + mapping: "LoRAMapping", + lora_index_to_id: List[Optional[int]], + max_loras: int, + 
vocab_size: int, + extra_vocab_size: int, + long_lora_context: Optional["LongContextLoRAContext"] = None, + ): + ( + base_indices, + sampler_indices, + sampler_indices_padded, + embeddings_indices, + long_lora_offsets_tensor, + indices_len, + ) = convert_mapping(mapping, lora_index_to_id, max_loras, vocab_size, + extra_vocab_size, self.device, None) + # Updating each element in `long_lora_offsets` with `lora_offset` slows + # down perf in HPU due to a series of `strided_insert` ops during lazy + # graph accumulation. Hence HPU appends `lora_offset` to a list and + # converts it to a tensor only after it is ready. + if long_lora_context: + index_mapping_indices: List[int] = list( + mapping.index_mapping).copy() + long_lora_offsets: List[int] = [] + for i in range(len(index_mapping_indices)): + lora_offset: int = long_lora_context.offsets_by_lora_id.get( + index_mapping_indices[i], 0) + long_lora_offsets.append(lora_offset) + long_lora_offsets_tensor = torch.tensor(long_lora_offsets, + device=self.device, + dtype=torch.long) + indices_len[-1] = long_lora_offsets_tensor.shape[-1] + + self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices) + self._sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices) + self._sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_( + sampler_indices_padded) + self._embeddings_indices[:embeddings_indices. 
+ shape[0], :embeddings_indices.shape[1]].copy_( + embeddings_indices) + if long_lora_offsets_tensor is not None: + self._long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_( + long_lora_offsets_tensor) + else: + self._long_lora_indices.zero_() + self.indices_len[:] = indices_len + def add_lora_embedding(self, y: torch.Tensor, x: torch.Tensor, diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index ec204b32f..5d7f9396c 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -206,9 +206,10 @@ class RotaryEmbedding(CustomOp): ) -> Tuple[torch.Tensor, torch.Tensor]: from habana_frameworks.torch.hpex.kernels import ( RotaryPosEmbeddingMode, apply_rotary_pos_emb) - positions = positions.flatten() if offsets is not None: + offsets = offsets.view(positions.shape[0], -1) positions = positions + offsets + positions = positions.flatten() num_tokens = positions.shape[0] cos_sin = self.cos_sin_cache.index_select(0, positions).view( num_tokens, 1, -1) diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index b846d4387..774049a52 100644 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -639,12 +639,25 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): "Bias support in LoRA is not enabled in HPU yet." assert not self.lora_config.fully_sharded_loras, \ "Fully sharded LoRAs is not enabled in HPU yet." + # It's necessary to distinguish between the + # max_position_embeddings of VLMs and LLMs. 
+ if hasattr(self.model.config, "max_position_embeddings"): + max_pos_embeddings = ( + self.model.config.max_position_embeddings) + else: + max_pos_embeddings = ( + self.model.config.text_config.max_position_embeddings) + self.lora_manager = LRUCacheWorkerLoRAManager( self.scheduler_config.max_num_seqs, self.scheduler_config.max_num_batched_tokens, - self.vocab_size, self.lora_config, self.device, + self.vocab_size, + self.lora_config, + self.device, self.model.embedding_modules, - self.model.embedding_padding_modules) + self.model.embedding_padding_modules, + max_position_embeddings=max_pos_embeddings, + ) self.model = self.lora_manager.create_lora_manager(self.model) if self.model_config.quantization == 'inc': -- GitLab From 7e1837676a3230b1b392d7699771cdb3f3407242 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Sat, 8 Feb 2025 14:45:44 +0530 Subject: [PATCH 035/253] [misc] Add LoRA to benchmark_serving (#12898) Signed-off-by: Varun Sundar Rabindranath Co-authored-by: Varun Sundar Rabindranath --- benchmarks/benchmark_serving.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index e934d228f..1044bef59 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -537,6 +537,7 @@ async def benchmark( ignore_eos: bool, goodput_config_dict: Dict[str, float], max_concurrency: Optional[int], + lora_modules: Optional[List[str]], ): if backend in ASYNC_REQUEST_FUNCS: request_func = ASYNC_REQUEST_FUNCS[backend] @@ -562,6 +563,7 @@ async def benchmark( multi_modal_content=test_mm_content, ignore_eos=ignore_eos, ) + test_output = await request_func(request_func_input=test_input) if not test_output.success: raise ValueError( @@ -570,6 +572,11 @@ async def benchmark( else: print("Initial test run completed. Starting main benchmark run...") + if lora_modules: + # For each input request, choose a LoRA module at random. 
+ lora_modules = iter( + [random.choice(lora_modules) for _ in range(len(input_requests))]) + if profile: print("Starting profiler...") profile_input = RequestFuncInput(model=model_id, @@ -616,8 +623,13 @@ async def benchmark( tasks: List[asyncio.Task] = [] async for request in get_request(input_requests, request_rate, burstiness): prompt, prompt_len, output_len, mm_content = request - request_func_input = RequestFuncInput(model=model_id, - model_name=model_name, + req_model_id, req_model_name = model_id, model_name + if lora_modules: + req_lora_module = next(lora_modules) + req_model_id, req_model_name = req_lora_module, req_lora_module + + request_func_input = RequestFuncInput(model=req_model_id, + model_name=req_model_name, prompt=prompt, api_url=api_url, prompt_len=prompt_len, @@ -900,6 +912,7 @@ def main(args: argparse.Namespace): ignore_eos=args.ignore_eos, goodput_config_dict=goodput_config_dict, max_concurrency=args.max_concurrency, + lora_modules=args.lora_modules, )) # Save config and results to json @@ -1237,5 +1250,12 @@ if __name__ == "__main__": "If not specified, the model name will be the " "same as the ``--model`` argument. ") + parser.add_argument("--lora-modules", + nargs='+', + default=None, + help="A subset of LoRA module names passed in when " + "launching the server. 
For each request, the " + "script chooses a LoRA module at random.") + args = parser.parse_args() main(args) -- GitLab From 011e612d92c25cb1a3cbfa1536cb8edd871d7715 Mon Sep 17 00:00:00 2001 From: Jun Duan Date: Sat, 8 Feb 2025 04:16:42 -0500 Subject: [PATCH 036/253] [Misc] Log time consumption on weight downloading (#12926) --- vllm/model_executor/model_loader/weight_utils.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index cade0a1dd..68ade319d 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -6,6 +6,7 @@ import hashlib import json import os import tempfile +import time from collections import defaultdict from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union @@ -14,7 +15,8 @@ import gguf import huggingface_hub.constants import numpy as np import torch -from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download +from huggingface_hub import (HfFileSystem, hf_hub_download, scan_cache_dir, + snapshot_download) from safetensors.torch import load_file, safe_open, save_file from tqdm.auto import tqdm @@ -253,6 +255,8 @@ def download_weights_from_hf( # Use file lock to prevent multiple processes from # downloading the same model weights at the same time. 
with get_lock(model_name_or_path, cache_dir): + start_size = scan_cache_dir().size_on_disk + start_time = time.perf_counter() hf_folder = snapshot_download( model_name_or_path, allow_patterns=allow_patterns, @@ -262,6 +266,11 @@ def download_weights_from_hf( revision=revision, local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, ) + end_time = time.perf_counter() + end_size = scan_cache_dir().size_on_disk + if end_size != start_size: + logger.info("Time took to download weights for %s: %.6f seconds", + model_name_or_path, end_time - start_time) return hf_folder -- GitLab From c45d398e6f0e1c78b958d2e0346b860c23444af9 Mon Sep 17 00:00:00 2001 From: Liangfu Chen Date: Sat, 8 Feb 2025 01:41:35 -0800 Subject: [PATCH 037/253] [CI] Resolve transformers-neuronx version conflict (#12925) --- .buildkite/run-neuron-test.sh | 3 --- Dockerfile.neuron | 8 +++++++- requirements-neuron.txt | 1 - setup.py | 7 +------ 4 files changed, 8 insertions(+), 11 deletions(-) diff --git a/.buildkite/run-neuron-test.sh b/.buildkite/run-neuron-test.sh index 1ad77cf50..55c374fcc 100644 --- a/.buildkite/run-neuron-test.sh +++ b/.buildkite/run-neuron-test.sh @@ -29,9 +29,6 @@ if [ -f /tmp/neuron-docker-build-timestamp ]; then docker image prune -f # Remove unused volumes / force the system prune for old images as well. 
docker volume prune -f && docker system prune -f - # Remove huggingface model artifacts and compiler cache - rm -rf "${HF_MOUNT:?}/*" - rm -rf "${NEURON_COMPILE_CACHE_MOUNT:?}/*" echo "$current_time" > /tmp/neuron-docker-build-timestamp fi else diff --git a/Dockerfile.neuron b/Dockerfile.neuron index e9cb82889..27658d836 100644 --- a/Dockerfile.neuron +++ b/Dockerfile.neuron @@ -23,10 +23,12 @@ WORKDIR ${APP_MOUNT}/vllm RUN python3 -m pip install --upgrade pip RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas RUN python3 -m pip install sentencepiece transformers==4.45.2 -U -RUN python3 -m pip install transformers-neuronx --extra-index-url=https://pip.repos.neuron.amazonaws.com -U RUN python3 -m pip install neuronx-cc==2.16.345.0 --extra-index-url=https://pip.repos.neuron.amazonaws.com -U RUN python3 -m pip install pytest +# uninstall transformers-neuronx package explicitly to avoid version conflict +RUN python3 -m pip uninstall -y transformers-neuronx + COPY . . ARG GIT_REPO_CHECK=0 RUN --mount=type=bind,source=.git,target=.git \ @@ -43,6 +45,10 @@ RUN --mount=type=bind,source=.git,target=.git \ # install development dependencies (for testing) RUN python3 -m pip install -e tests/vllm_test_utils +# install transformers-neuronx package as an optional dependencies (for V0) +# FIXME: `--no-deps` argument is temporarily added to resolve transformers package version conflict +RUN python3 -m pip install transformers-neuronx==0.13.* --extra-index-url=https://pip.repos.neuron.amazonaws.com -U --no-deps + # overwrite entrypoint to run bash script RUN echo "import subprocess; import sys; subprocess.check_call(sys.argv[1:])" > /usr/local/bin/dockerd-entrypoint.py diff --git a/requirements-neuron.txt b/requirements-neuron.txt index 5e08d101f..09820c73e 100644 --- a/requirements-neuron.txt +++ b/requirements-neuron.txt @@ -2,6 +2,5 @@ -r requirements-common.txt # Dependencies for Neuron devices -transformers-neuronx >= 0.13.0 torch-neuronx >= 2.5.0 
neuronx-cc diff --git a/setup.py b/setup.py index dc517dafa..3e2adadf6 100755 --- a/setup.py +++ b/setup.py @@ -374,12 +374,7 @@ def _is_hip() -> bool: def _is_neuron() -> bool: - torch_neuronx_installed = True - try: - subprocess.run(["neuron-ls"], capture_output=True, check=True) - except (FileNotFoundError, PermissionError, subprocess.CalledProcessError): - torch_neuronx_installed = False - return torch_neuronx_installed or VLLM_TARGET_DEVICE == "neuron" + return VLLM_TARGET_DEVICE == "neuron" def _is_tpu() -> bool: -- GitLab From 256a2d29dc2358d7c0a5d38c0faf152095335929 Mon Sep 17 00:00:00 2001 From: Jun Duan Date: Sat, 8 Feb 2025 04:42:15 -0500 Subject: [PATCH 038/253] [Doc] Correct HF repository for TeleChat2 models (#12949) --- docs/source/models/supported_models.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index 32f3e9def..38f36b54d 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -429,7 +429,7 @@ See [this page](#generative-models) for more information on how to use generativ * ✅︎ - * `TeleChat2ForCausalLM` * TeleChat2 - * `TeleAI/TeleChat2-3B`, `TeleAI/TeleChat2-7B`, `TeleAI/TeleChat2-35B`, etc. + * `Tele-AI/TeleChat2-3B`, `Tele-AI/TeleChat2-7B`, `Tele-AI/TeleChat2-35B`, etc. 
* ✅︎ * ✅︎ - * `XverseForCausalLM` -- GitLab From 4c8dd12ef3474e43c614a229dabe85cc47432cf8 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Sat, 8 Feb 2025 20:24:47 +0800 Subject: [PATCH 039/253] [Misc] Add qwen2.5-vl BNB support (#12944) --- vllm/model_executor/models/qwen2_5_vl.py | 59 ++++++++++++------------ 1 file changed, 29 insertions(+), 30 deletions(-) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 1f350ab20..d4c48dbda 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -40,7 +40,7 @@ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import ( from vllm.attention import AttentionMetadata from vllm.config import VllmConfig -from vllm.distributed import parallel_state +from vllm.distributed import parallel_state, tensor_model_parallel_all_gather from vllm.distributed import utils as dist_utils from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata @@ -207,11 +207,12 @@ class Qwen2_5_VisionAttention(nn.Module): ) -> None: super().__init__() # Per attention head and per partition values. - world_size = parallel_state.get_tensor_model_parallel_world_size() + self.tp_size = parallel_state.get_tensor_model_parallel_world_size() + self.tp_rank = parallel_state.get_tensor_model_parallel_rank() self.hidden_size_per_attention_head = dist_utils.divide( projection_size, num_heads) self.num_attention_heads_per_partition = dist_utils.divide( - num_heads, world_size) + num_heads, self.tp_size) self.qkv = ColumnParallelLinear(input_size=embed_dim, output_size=3 * projection_size, @@ -231,6 +232,29 @@ class Qwen2_5_VisionAttention(nn.Module): f"Qwen2.5-VL does not support {self.attn_backend} backend now." 
) + def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: + # [s, b, 3 * head * head_dim] + seq_len, bs, _ = qkv.shape + if self.tp_size > 1: + qkv = tensor_model_parallel_all_gather(qkv) + + # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim] + q, k, v = qkv.chunk(3, dim=2) + + # 3 * [s, b, head * head_dim] + if self.tp_size > 1: + splitter = partial(dist_utils.split_tensor_along_last_dim, + num_partitions=self.tp_size) + q = splitter(q)[self.tp_rank] + k = splitter(k)[self.tp_rank] + v = splitter(v)[self.tp_rank] + + # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim] + new_shape = (seq_len, bs, self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head) + q, k, v = (x.view(*new_shape) for x in (q, k, v)) + return q, k, v + def forward( self, x: torch.Tensor, @@ -240,15 +264,8 @@ class Qwen2_5_VisionAttention(nn.Module): # [s, b, c] --> [s, b, head * 3 * head_dim] x, _ = self.qkv(x) - # [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim] - new_x_shape = x.size()[:-1] + ( - self.num_attention_heads_per_partition, - 3 * self.hidden_size_per_attention_head, - ) - x = x.view(*new_x_shape) - - # [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim] - q, k, v = dist_utils.split_tensor_along_last_dim(x, 3) + # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim] + q, k, v = self.split_qkv(x) batch_size = q.shape[1] q, k, v = (rearrange(x, "s b ... 
-> b s ...").contiguous() @@ -665,24 +682,6 @@ class Qwen2_5_VisionTransformer(nn.Module): weight_loader(param, loaded_weight, shard_id) break else: - if name.endswith("qkv.weight"): - visual_num_heads = self.num_heads - visual_embed_dim = self.hidden_size - head_size = visual_embed_dim // visual_num_heads - loaded_weight = loaded_weight.view(3, visual_num_heads, - head_size, - visual_embed_dim) - loaded_weight = loaded_weight.transpose(0, 1) - loaded_weight = loaded_weight.reshape(-1, visual_embed_dim) - elif name.endswith("qkv.bias"): - visual_num_heads = self.num_heads - visual_embed_dim = self.hidden_size - head_size = visual_embed_dim // visual_num_heads - loaded_weight = loaded_weight.view(3, visual_num_heads, - head_size) - loaded_weight = loaded_weight.transpose(0, 1) - loaded_weight = loaded_weight.reshape(-1) - param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) -- GitLab From 8a69e0e20e72d429aaf379ae7647f0434a0e9c9e Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sat, 8 Feb 2025 20:25:15 +0800 Subject: [PATCH 040/253] [CI/Build] Auto-fix Markdown files (#12941) --- .buildkite/nightly-benchmarks/README.md | 46 ++++++--------- .../nightly-benchmarks/nightly-annotation.md | 21 ++++--- .../nightly-descriptions.md | 6 +- .../performance-benchmarks-descriptions.md | 10 +--- .github/PULL_REQUEST_TEMPLATE.md | 3 +- .pre-commit-config.yaml | 2 +- CODE_OF_CONDUCT.md | 1 - README.md | 14 +++-- benchmarks/README.md | 2 + csrc/quantization/cutlass_w8a8/Epilogues.md | 44 ++++++++++---- csrc/quantization/machete/Readme.md | 14 ++--- .../installation/gpu/rocm.inc.md | 9 ++- docs/source/serving/engine_args.md | 4 +- .../offline_inference/openai/openai_batch.md | 59 +++++++++---------- .../offline_inference/profiling_tpu/README.md | 6 +- examples/online_serving/chart-helm/README.md | 2 +- examples/online_serving/opentelemetry/Otel.md | 32 ++++++---- .../prometheus_grafana/README.md | 14 +++-- 
examples/other/logging_configuration.md | 5 -- vllm/distributed/kv_transfer/README.md | 5 +- 20 files changed, 158 insertions(+), 141 deletions(-) diff --git a/.buildkite/nightly-benchmarks/README.md b/.buildkite/nightly-benchmarks/README.md index fbf41eb10..d3f5fc5cd 100644 --- a/.buildkite/nightly-benchmarks/README.md +++ b/.buildkite/nightly-benchmarks/README.md @@ -1,15 +1,13 @@ # vLLM benchmark suite - ## Introduction This directory contains two sets of benchmark for vllm. + - Performance benchmark: benchmark vllm's performance under various workload, for **developers** to gain clarity on whether their PR improves/degrades vllm's performance - Nightly benchmark: compare vllm's performance against alternatives (tgi, trt-llm and lmdeploy), for **the public** to know when to choose vllm. - -See [vLLM performance dashboard](https://perf.vllm.ai) for the latest performance benchmark results and [vLLM GitHub README](https://github.com/vllm-project/vllm/blob/main/README.md) for latest nightly benchmark results. - +See [vLLM performance dashboard](https://perf.vllm.ai) for the latest performance benchmark results and [vLLM GitHub README](https://github.com/vllm-project/vllm/blob/main/README.md) for latest nightly benchmark results. ## Performance benchmark quick overview @@ -19,17 +17,14 @@ See [vLLM performance dashboard](https://perf.vllm.ai) for the latest performan **For benchmarking developers**: please try your best to constraint the duration of benchmarking to about 1 hr so that it won't take forever to run. - ## Nightly benchmark quick overview -**Benchmarking Coverage**: Fix-qps serving on A100 (the support for FP8 benchmark on H100 is coming!) on Llama-3 8B, 70B and Mixtral 8x7B. +**Benchmarking Coverage**: Fix-qps serving on A100 (the support for FP8 benchmark on H100 is coming!) on Llama-3 8B, 70B and Mixtral 8x7B. **Benchmarking engines**: vllm, TGI, trt-llm and lmdeploy. **Benchmarking Duration**: about 3.5hrs. 
- - ## Trigger the benchmark Performance benchmark will be triggered when: @@ -39,16 +34,11 @@ Performance benchmark will be triggered when: Nightly benchmark will be triggered when: - Every commit for those PRs with `perf-benchmarks` label and `nightly-benchmarks` label. - - - ## Performance benchmark details - See [performance-benchmarks-descriptions.md](performance-benchmarks-descriptions.md) for detailed descriptions, and use `tests/latency-tests.json`, `tests/throughput-tests.json`, `tests/serving-tests.json` to configure the test cases. - -#### Latency test +### Latency test Here is an example of one test inside `latency-tests.json`: @@ -68,23 +58,25 @@ Here is an example of one test inside `latency-tests.json`: ``` In this example: -- The `test_name` attributes is a unique identifier for the test. In `latency-tests.json`, it must start with `latency_`. -- The `parameters` attribute control the command line arguments to be used for `benchmark_latency.py`. Note that please use underline `_` instead of the dash `-` when specifying the command line arguments, and `run-performance-benchmarks.sh` will convert the underline to dash when feeding the arguments to `benchmark_latency.py`. For example, the corresponding command line arguments for `benchmark_latency.py` will be `--model meta-llama/Meta-Llama-3-8B --tensor-parallel-size 1 --load-format dummy --num-iters-warmup 5 --num-iters 15` + +- The `test_name` attributes is a unique identifier for the test. In `latency-tests.json`, it must start with `latency_`. +- The `parameters` attribute control the command line arguments to be used for `benchmark_latency.py`. Note that please use underline `_` instead of the dash `-` when specifying the command line arguments, and `run-performance-benchmarks.sh` will convert the underline to dash when feeding the arguments to `benchmark_latency.py`. 
For example, the corresponding command line arguments for `benchmark_latency.py` will be `--model meta-llama/Meta-Llama-3-8B --tensor-parallel-size 1 --load-format dummy --num-iters-warmup 5 --num-iters 15` Note that the performance numbers are highly sensitive to the value of the parameters. Please make sure the parameters are set correctly. WARNING: The benchmarking script will save json results by itself, so please do not configure `--output-json` parameter in the json file. +### Throughput test -#### Throughput test The tests are specified in `throughput-tests.json`. The syntax is similar to `latency-tests.json`, except for that the parameters will be fed forward to `benchmark_throughput.py`. The number of this test is also stable -- a slight change on the value of this number might vary the performance numbers by a lot. -#### Serving test +### Serving test + We test the throughput by using `benchmark_serving.py` with request rate = inf to cover the online serving overhead. The corresponding parameters are in `serving-tests.json`, and here is an example: -``` +```json [ { "test_name": "serving_llama8B_tp1_sharegpt", @@ -109,6 +101,7 @@ We test the throughput by using `benchmark_serving.py` with request rate = inf t ``` Inside this example: + - The `test_name` attribute is also a unique identifier for the test. It must start with `serving_`. - The `server-parameters` includes the command line arguments for vLLM server. - The `client-parameters` includes the command line arguments for `benchmark_serving.py`. @@ -118,36 +111,33 @@ The number of this test is less stable compared to the delay and latency benchma WARNING: The benchmarking script will save json results by itself, so please do not configure `--save-results` or other results-saving-related parameters in `serving-tests.json`. 
-#### Visualizing the results +### Visualizing the results + The `convert-results-json-to-markdown.py` helps you put the benchmarking results inside a markdown table, by formatting [descriptions.md](tests/descriptions.md) with real benchmarking results. You can find the result presented as a table inside the `buildkite/performance-benchmark` job page. If you do not see the table, please wait till the benchmark finish running. The json version of the table (together with the json version of the benchmark) will be also attached to the markdown file. The raw benchmarking results (in the format of json files) are in the `Artifacts` tab of the benchmarking. - - ## Nightly test details See [nightly-descriptions.md](nightly-descriptions.md) for the detailed description on test workload, models and docker containers of benchmarking other llm engines. +### Workflow -#### Workflow - -- The [nightly-pipeline.yaml](nightly-pipeline.yaml) specifies the docker containers for different LLM serving engines. +- The [nightly-pipeline.yaml](nightly-pipeline.yaml) specifies the docker containers for different LLM serving engines. - Inside each container, we run [run-nightly-suite.sh](run-nightly-suite.sh), which will probe the serving engine of the current container. - The `run-nightly-suite.sh` will redirect the request to `tests/run-[llm serving engine name]-nightly.sh`, which parses the workload described in [nightly-tests.json](tests/nightly-tests.json) and performs the benchmark. - At last, we run [scripts/plot-nightly-results.py](scripts/plot-nightly-results.py) to collect and plot the final benchmarking results, and update the results to buildkite. -#### Nightly tests +### Nightly tests In [nightly-tests.json](tests/nightly-tests.json), we include the command line arguments for benchmarking commands, together with the benchmarking test cases. The format is highly similar to performance benchmark. 
-#### Docker containers +### Docker containers The docker containers for benchmarking are specified in `nightly-pipeline.yaml`. WARNING: the docker versions are HARD-CODED and SHOULD BE ALIGNED WITH `nightly-descriptions.md`. The docker versions need to be hard-coded as there are several version-specific bug fixes inside `tests/run-[llm serving engine name]-nightly.sh`. WARNING: populating `trt-llm` to latest version is not easy, as it requires updating several protobuf files in [tensorrt-demo](https://github.com/neuralmagic/tensorrt-demo.git). - diff --git a/.buildkite/nightly-benchmarks/nightly-annotation.md b/.buildkite/nightly-benchmarks/nightly-annotation.md index 1e3379384..e43ea765f 100644 --- a/.buildkite/nightly-benchmarks/nightly-annotation.md +++ b/.buildkite/nightly-benchmarks/nightly-annotation.md @@ -9,20 +9,19 @@ This file contains the downloading link for benchmarking results. Please download the visualization scripts in the post - ## Results reproduction - Find the docker we use in `benchmarking pipeline` - Deploy the docker, and inside the docker: - - Download `nightly-benchmarks.zip`. - - In the same folder, run the following code -``` -export HF_TOKEN= -apt update -apt install -y git -unzip nightly-benchmarks.zip -VLLM_SOURCE_CODE_LOC=./ bash .buildkite/nightly-benchmarks/scripts/run-nightly-benchmarks.sh -``` + - Download `nightly-benchmarks.zip`. + - In the same folder, run the following code: -And the results will be inside `./benchmarks/results`. + ```console + export HF_TOKEN= + apt update + apt install -y git + unzip nightly-benchmarks.zip + VLLM_SOURCE_CODE_LOC=./ bash .buildkite/nightly-benchmarks/scripts/run-nightly-benchmarks.sh + ``` +And the results will be inside `./benchmarks/results`. 
diff --git a/.buildkite/nightly-benchmarks/nightly-descriptions.md b/.buildkite/nightly-benchmarks/nightly-descriptions.md index 7dec7a0fe..5f003f42f 100644 --- a/.buildkite/nightly-benchmarks/nightly-descriptions.md +++ b/.buildkite/nightly-benchmarks/nightly-descriptions.md @@ -2,6 +2,7 @@ # Nightly benchmark This benchmark aims to: + - Provide performance clarity: Provide clarity on which one (vllm, tensorrt-llm, lmdeploy and SGLang) leads in performance in what workload. - Be reproducible: one can run the exact same set of benchmarking commands inside the exact same docker by following reproducing instructions. @@ -9,7 +10,6 @@ Latest results: [results link](https://blog.vllm.ai/2024/09/05/perf-update.html) Latest reproduction guilde: [github issue link](https://github.com/vllm-project/vllm/issues/8176) - ## Setup - Docker images: @@ -33,7 +33,7 @@ Latest reproduction guilde: [github issue link](https://github.com/vllm-project/ - Queries are randomly sampled, and arrival patterns are determined via Poisson process, but all with fixed random seed. - Evaluation metrics: Throughput (higher the better), TTFT (time to the first token, lower the better), ITL (inter-token latency, lower the better). -# Known issues +## Known issues - TRT-LLM crashes with Llama 3.1 8B [issue](https://github.com/NVIDIA/TensorRT-LLM/issues/2105). -- TGI does not support `ignore-eos` flag. \ No newline at end of file +- TGI does not support `ignore-eos` flag. diff --git a/.buildkite/nightly-benchmarks/performance-benchmarks-descriptions.md b/.buildkite/nightly-benchmarks/performance-benchmarks-descriptions.md index da32d1f07..cacaef986 100644 --- a/.buildkite/nightly-benchmarks/performance-benchmarks-descriptions.md +++ b/.buildkite/nightly-benchmarks/performance-benchmarks-descriptions.md @@ -7,10 +7,8 @@ - Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B. - Evaluation metrics: end-to-end latency (mean, median, p99). 
- {latency_tests_markdown_table} - ## Throughput tests - Input length: randomly sample 200 prompts from ShareGPT dataset (with fixed random seed). @@ -19,10 +17,8 @@ - Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B. - Evaluation metrics: throughput. - {throughput_tests_markdown_table} - ## Serving tests - Input length: randomly sample 200 prompts from ShareGPT dataset (with fixed random seed). @@ -33,13 +29,11 @@ - We also added a speculative decoding test for llama-3 70B, under QPS 2 - Evaluation metrics: throughput, TTFT (time to the first token, with mean, median and p99), ITL (inter-token latency, with mean, median and p99). - {serving_tests_markdown_table} - ## json version of the benchmarking tables -This section contains the data of the markdown tables above in JSON format. +This section contains the data of the markdown tables above in JSON format. You can load the benchmarking tables into pandas dataframes as follows: ```python @@ -54,9 +48,9 @@ serving_results = pd.DataFrame.from_dict(benchmarking_results["serving"]) ``` The json string for all benchmarking tables: + ```json {benchmarking_results_in_json_string} ``` You can also check the raw experiment data in the Artifact tab of the Buildkite page. 
- diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 51a73c857..a20c5baf8 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -2,4 +2,5 @@ FILL IN THE PR DESCRIPTION HERE FIX #xxxx (*link existing issues this PR will resolve*) -**BEFORE SUBMITTING, PLEASE READ https://docs.vllm.ai/en/latest/contributing/overview.html ** + +**BEFORE SUBMITTING, PLEASE READ ** diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3fb74ab9b..118451593 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -33,7 +33,7 @@ repos: rev: v0.9.27 hooks: - id: pymarkdown - files: docs/.* + args: [fix] - repo: https://github.com/rhysd/actionlint rev: v1.7.7 hooks: diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index 1a9596841..5268ff135 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -125,4 +125,3 @@ Community Impact Guidelines were inspired by For answers to common questions about this code of conduct, see the [Contributor Covenant FAQ](https://www.contributor-covenant.org/faq). Translations are available at [Contributor Covenant translations](https://www.contributor-covenant.org/translations). - diff --git a/README.md b/README.md index cd0b1c517..f04acf09c 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,7 @@ Easy, fast, and cheap LLM serving for everyone --- *Latest News* 🔥 + - [2025/01] We are excited to announce the alpha release of vLLM V1: A major architectural upgrade with 1.7x speedup! Clean code, optimized execution loop, zero-overhead prefix caching, enhanced multimodal support, and more. Please check out our blog post [here](https://blog.vllm.ai/2025/01/27/v1-alpha-release.html). - [2025/01] We hosted [the eighth vLLM meetup](https://lu.ma/zep56hui) with Google Cloud! 
Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1epVkt4Zu8Jz_S5OhEHPc798emsYh2BwYfRuDDVEF7u4/edit?usp=sharing), and Google Cloud team [here](https://drive.google.com/file/d/1h24pHewANyRL11xy5dXUbvRC9F9Kkjix/view?usp=sharing). - [2024/12] vLLM joins [pytorch ecosystem](https://pytorch.org/blog/vllm-joins-pytorch)! Easy, Fast, and Cheap LLM Serving for Everyone! @@ -33,7 +34,9 @@ Easy, fast, and cheap LLM serving for everyone - [2023/06] We officially released vLLM! FastChat-vLLM integration has powered [LMSYS Vicuna and Chatbot Arena](https://chat.lmsys.org) since mid-April. Check out our [blog post](https://vllm.ai). --- + ## About + vLLM is a fast and easy-to-use library for LLM inference and serving. Originally developed in the [Sky Computing Lab](https://sky.cs.berkeley.edu) at UC Berkeley, vLLM has evolved into a community-driven project with contributions from both academia and industry. @@ -127,6 +130,7 @@ We also have an official fundraising venue through [OpenCollective](https://open ## Citation If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs/2309.06180): + ```bibtex @inproceedings{kwon2023efficient, title={Efficient Memory Management for Large Language Model Serving with PagedAttention}, @@ -138,11 +142,11 @@ If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs ## Contact Us -* For technical questions and feature requests, please use Github issues or discussions. -* For discussing with fellow users and coordinating contributions and development, please use Slack. -* For security disclosures, please use Github's security advisory feature. -* For collaborations and partnerships, please contact us at vllm-questions AT lists.berkeley.edu. +- For technical questions and feature requests, please use Github issues or discussions. +- For discussing with fellow users and coordinating contributions and development, please use Slack. 
+- For security disclosures, please use Github's security advisory feature. +- For collaborations and partnerships, please contact us at vllm-questions AT lists.berkeley.edu. ## Media Kit -* If you wish to use vLLM's logo, please refer to [our media kit repo](https://github.com/vllm-project/media-kit). +- If you wish to use vLLM's logo, please refer to [our media kit repo](https://github.com/vllm-project/media-kit). diff --git a/benchmarks/README.md b/benchmarks/README.md index 2aa4a2850..890a2525b 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -3,6 +3,7 @@ ## Downloading the ShareGPT dataset You can download the dataset by running: + ```bash wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json ``` @@ -11,6 +12,7 @@ wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/r The json file refers to several image datasets (coco, llava, etc.). The benchmark scripts will ignore a datapoint if the referred image is missing. + ```bash wget https://huggingface.co/datasets/Lin-Chen/ShareGPT4V/resolve/main/sharegpt4v_instruct_gpt4-vision_cap100k.json mkdir coco -p diff --git a/csrc/quantization/cutlass_w8a8/Epilogues.md b/csrc/quantization/cutlass_w8a8/Epilogues.md index aae04157b..a30e1fdf3 100644 --- a/csrc/quantization/cutlass_w8a8/Epilogues.md +++ b/csrc/quantization/cutlass_w8a8/Epilogues.md @@ -1,17 +1,19 @@ # CUTLASS Epilogues ## Introduction -This document describes the various CUTLASS epilogues implemented for fusing de-quantization operations onto GEMMs. + +This document describes the various CUTLASS epilogues implemented for fusing de-quantization operations onto GEMMs. Currently, we only support symmetric quantization for weights, and symmetric and asymmetric quantization for activations. Both can be quantized per-tensor or per-channel (weights) / per-token (activations). There are 4 epilogues: -1. 
ScaledEpilogue: symmetric quantization for activations, no bias. -1. ScaledEpilogueBias: symmetric quantization for activations, supports bias. -1. ScaledEpilogueAzp: asymmetric per-tensor quantization for activations, supports bias. -1. ScaledEpilogueAzpPerToken: asymmetric per-token quantization for activations, supports bias. + +1. `ScaledEpilogue`: symmetric quantization for activations, no bias. +1. `ScaledEpilogueBias`: symmetric quantization for activations, supports bias. +1. `ScaledEpilogueAzp`: asymmetric per-tensor quantization for activations, supports bias. +1. `ScaledEpilogueAzpPerToken`: asymmetric per-token quantization for activations, supports bias. We do not have epilogues for asymmetric quantization of activations without bias in order to reduce final binary size. Instead, if no bias is passed, the epilogue will use 0 as the bias. @@ -26,12 +28,15 @@ If $` \widehat X `$ is the quantized $` X `$, our matrices become the following ```math A = s_a (\widehat A - J_a z_a) ``` + ```math B = s_b \widehat B ``` + ```math D = A B + C ``` + ```math D = s_a s_b \widehat D + C ``` @@ -48,9 +53,11 @@ Expanding further, we can calculate $` \widehat D `$ as follows: ```math A B = s_a ( \widehat A - J_a z_a ) s_b \widehat B ``` + ```math A B = s_a s_b \left( \widehat A \widehat B - J_a z_a \widehat B \right) ``` + ```math \widehat D = \widehat A \widehat B - z_a J_a \widehat B ``` @@ -61,16 +68,19 @@ Each row of it is equal to $` \mathbf 1 \widehat B `$, which is a row-vector of ## Epilogues -### ScaledEpilogue +### `ScaledEpilogue` + This epilogue computes the symmetric quantization for activations without bias, meaning $` C = 0 `$ and $` z_a = 0 `$. The output of the GEMM is: ```math \widehat D = \widehat A \widehat B ``` + ```math D = s_a s_b \widehat D ``` + ```math D = s_a s_b \widehat A \widehat B ``` @@ -79,44 +89,51 @@ Epilogue parameters: - `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector). 
- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector). -### ScaledEpilogueBias +### `ScaledEpilogueBias` + This epilogue computes the symmetric quantization for activations with bias, meaning $` z_a = 0 `$. The output of the GEMM is: ```math \widehat D = \widehat A \widehat B ``` + ```math D = s_a s_b \widehat D + C ``` + ```math D = s_a s_b \widehat A \widehat B + C ``` - Epilogue parameters: + - `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector). - `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector). - `bias` is the bias, is always per-channel (row-vector). -### ScaledEpilogueAzp +### `ScaledEpilogueAzp` + This epilogue computes the asymmetric per-tensor quantization for activations with bias. The output of the GEMM is: ```math \widehat D = \widehat A \widehat B - z_a J_a \widehat B ``` + ```math D = s_a s_b \widehat D + C ``` + ```math D = s_a s_b \left( \widehat A \widehat B - z_a J_a \widehat B \right) + C ``` -Because $` z_a `$ is a scalar, the zero-point term $` z_a J_a \widehat B `$ has every row equal to $` z_a \mathbf 1 B `$. +Because $` z_a `$ is a scalar, the zero-point term $` z_a J_a \widehat B `$ has every row equal to $` z_a \mathbf 1 B `$. That is precomputed and stored in `azp_with_adj` as a row-vector. Epilogue parameters: + - `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector). - Generally this will be per-tensor as the zero-points are per-tensor. - `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector). @@ -125,13 +142,15 @@ Epilogue parameters: To use these kernels efficiently, users must precompute the `azp_with_adj` term offline and pass it to the kernel. -### ScaledEpilogueAzpPerToken +### `ScaledEpilogueAzpPerToken` + This epilogue computes the asymmetric per-token quantization for activations with bias. 
The output of the GEMM is the same as above, but the $` z_a `$ is a column-vector. That means the zero-point term $` z_a J_a \widehat B `$ becomes an outer product of $` z_a `$ and $` \mathbf 1 \widehat B `$. Epilogue parameters: + - `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector). - Generally this will be per-token as the zero-points are per-token. - `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector). @@ -142,6 +161,7 @@ Epilogue parameters: To use these kernels efficiently, users must precompute the `azp_adj` term offline and pass it to the kernel. The epilogue performs the following computation (where `Dq` is the raw quantized output of the GEMM): -``` + +```math out = scale_a * scale_b * (Dq - azp_adj * azp) + bias ``` diff --git a/csrc/quantization/machete/Readme.md b/csrc/quantization/machete/Readme.md index 9ddf8da99..6ffb2416b 100644 --- a/csrc/quantization/machete/Readme.md +++ b/csrc/quantization/machete/Readme.md @@ -6,25 +6,25 @@ Machete is a spiritual successor to the Marlin kernel but optimized for Hopper a Machete effectively performs -``` +```python scale_type = w_s.dtype compute_type = a.dtype out = (w_q.to(scale_type) * w_s - w_z.to(scale_type)) @ a ``` -Where `w_q` is a quantized weight matrix, `w_s` is the quantization scales, and +Where `w_q` is a quantized weight matrix, `w_s` is the quantization scales, and `w_z` is the quantization zeropoints. -> **_NOTE:_** `w_z` is added after the scales so we can +> **_NOTE:_** `w_z` is added after the scales so we can use FMA operations, but this means they must have the scales pre-applied if the -supplied zeropoints assume that they will be subtracted before the scales are +supplied zeropoints assume that they will be subtracted before the scales are applied. 
## API The main optimization within Machete is prepacking the weight matrix to more closely match the tensor core layouts, allowing for wider shared memory loads when loading the weight matrix. This means that the weight matrix must be prepacked before calling `machete_gemm`. The flow looks something like: -``` +```python from vllm import _custom_ops as ops ... @@ -40,6 +40,6 @@ output = ops.machete_gemm( ## Code Generation -Since Machete is based on Cutlass, we can generate multiple type pairs and different tile shapes using the same kernel template. We generate multiple instantiations of this template using `generate.py`. +Since Machete is based on Cutlass, we can generate multiple type pairs and different tile shapes using the same kernel template. We generate multiple instantiations of this template using `generate.py`. -New type pairs (`TypeConfig`s) can be appended to `impl_configs` (in `generate()`), and these will get automatically generated (assuming they can be supported without issues). For each `TypeConfig`, you must also provide an `ImplConfig`, which bundles a `TypeConfig` with a list of `ScheduleConfig`s, `Specialization`s, and a default heuristic. The `ScheduleConfig`s (which contain info on tile shapes, tile scheduler, etc.) can perform differently for different problem shapes, and there is almost never one `ScheduleConfig` that works well for all problem shapes, so it is generally beneficial to generate different `ScheduleConfig`s for different potential problem shapes. This is where the heuristic comes in. For each `TypeConfig`, a default heuristic should be provided. This maps different problem shapes to different `ScheduleConfig`s and is used when the user does not provide the `schedule` parameter to `machete_gemm`. The `Specialization`s define what feature combinations to generate, i.e., `with_zeropoints`, `with_scales`, etc. We can reduce compile times and the final binary size by limiting the set of feature combinations we generate. 
\ No newline at end of file +New type pairs (`TypeConfig`s) can be appended to `impl_configs` (in `generate()`), and these will get automatically generated (assuming they can be supported without issues). For each `TypeConfig`, you must also provide an `ImplConfig`, which bundles a `TypeConfig` with a list of `ScheduleConfig`s, `Specialization`s, and a default heuristic. The `ScheduleConfig`s (which contain info on tile shapes, tile scheduler, etc.) can perform differently for different problem shapes, and there is almost never one `ScheduleConfig` that works well for all problem shapes, so it is generally beneficial to generate different `ScheduleConfig`s for different potential problem shapes. This is where the heuristic comes in. For each `TypeConfig`, a default heuristic should be provided. This maps different problem shapes to different `ScheduleConfig`s and is used when the user does not provide the `schedule` parameter to `machete_gemm`. The `Specialization`s define what feature combinations to generate, i.e., `with_zeropoints`, `with_scales`, etc. We can reduce compile times and the final binary size by limiting the set of feature combinations we generate. diff --git a/docs/source/getting_started/installation/gpu/rocm.inc.md b/docs/source/getting_started/installation/gpu/rocm.inc.md index 336d578de..7004313c9 100644 --- a/docs/source/getting_started/installation/gpu/rocm.inc.md +++ b/docs/source/getting_started/installation/gpu/rocm.inc.md @@ -93,12 +93,11 @@ Currently, there are no pre-built ROCm wheels. This may take 5-10 minutes. Currently, `pip install .` does not work for ROCm installation. - :::{tip} - - Triton flash attention is used by default. For benchmarking purposes, it is recommended to run a warm up step before collecting perf numbers. - - Triton flash attention does not currently support sliding window attention. If using half precision, please use CK flash-attention for sliding window support. 
- - To use CK flash-attention or PyTorch naive attention, please use this flag `export VLLM_USE_TRITON_FLASH_ATTN=0` to turn off triton flash attention. - - The ROCm version of PyTorch, ideally, should match the ROCm driver version. + - Triton flash attention is used by default. For benchmarking purposes, it is recommended to run a warm up step before collecting perf numbers. + - Triton flash attention does not currently support sliding window attention. If using half precision, please use CK flash-attention for sliding window support. + - To use CK flash-attention or PyTorch naive attention, please use this flag `export VLLM_USE_TRITON_FLASH_ATTN=0` to turn off triton flash attention. + - The ROCm version of PyTorch, ideally, should match the ROCm driver version. ::: :::{tip} diff --git a/docs/source/serving/engine_args.md b/docs/source/serving/engine_args.md index 827c25b50..f4587b94e 100644 --- a/docs/source/serving/engine_args.md +++ b/docs/source/serving/engine_args.md @@ -4,7 +4,7 @@ Below, you can find an explanation of every engine argument for vLLM: - + ```{eval-rst} .. argparse:: :module: vllm.engine.arg_utils @@ -17,7 +17,7 @@ Below, you can find an explanation of every engine argument for vLLM: Below are the additional arguments related to the asynchronous engine: - + ```{eval-rst} .. argparse:: :module: vllm.engine.arg_utils diff --git a/examples/offline_inference/openai/openai_batch.md b/examples/offline_inference/openai/openai_batch.md index 953e6ef13..d271573aa 100644 --- a/examples/offline_inference/openai/openai_batch.md +++ b/examples/offline_inference/openai/openai_batch.md @@ -5,50 +5,49 @@ This is a guide to performing batch inference using the OpenAI batch file format ``` ## File Format - + The OpenAI batch file format consists of a series of json objects on new lines. 
- + [See here for an example file.](https://github.com/vllm-project/vllm/blob/main/examples/offline_inference/openai/openai_example_batch.jsonl) - + Each line represents a separate request. See the [OpenAI package reference](https://platform.openai.com/docs/api-reference/batch/requestInput) for more details. - + ```{note} We currently support `/v1/chat/completions`, `/v1/embeddings`, and `/v1/score` endpoints (completions coming soon). ``` - + ## Pre-requisites * The examples in this document use `meta-llama/Meta-Llama-3-8B-Instruct`. - Create a [user access token](https://huggingface.co/docs/hub/en/security-tokens) - Install the token on your machine (Run `huggingface-cli login`). - Get access to the gated model by [visiting the model card](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) and agreeing to the terms and conditions. - - + ## Example 1: Running with a local file ### Step 1: Create your batch file To follow along with this example, you can download the example batch, or create your own batch file in your working directory. 
-``` +```console wget https://raw.githubusercontent.com/vllm-project/vllm/main/examples/offline_inference/openai/openai_example_batch.jsonl ``` Once you've created your batch file it should look like this -``` +```console $ cat offline_inference/openai/openai_example_batch.jsonl {"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_completion_tokens": 1000}} {"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_completion_tokens": 1000}} ``` ### Step 2: Run the batch - + The batch running tool is designed to be used from the command line. You can run the batch with the following command, which will write its results to a file called `results.jsonl` -``` +```console python -m vllm.entrypoints.openai.run_batch -i offline_inference/openai/openai_example_batch.jsonl -o results.jsonl --model meta-llama/Meta-Llama-3-8B-Instruct ``` @@ -56,7 +55,7 @@ python -m vllm.entrypoints.openai.run_batch -i offline_inference/openai/openai_e You should now have your results at `results.jsonl`. You can check your results by running `cat results.jsonl` -``` +```console $ cat results.jsonl {"id":"vllm-383d1c59835645aeb2e07d004d62a826","custom_id":"request-1","response":{"id":"cmpl-61c020e54b964d5a98fa7527bfcdd378","object":"chat.completion","created":1715633336,"model":"meta-llama/Meta-Llama-3-8B-Instruct","choices":[{"index":0,"message":{"role":"assistant","content":"Hello! It's great to meet you! I'm here to help with any questions or tasks you may have. 
What's on your mind today?"},"logprobs":null,"finish_reason":"stop","stop_reason":null}],"usage":{"prompt_tokens":25,"total_tokens":56,"completion_tokens":31}},"error":null} {"id":"vllm-42e3d09b14b04568afa3f1797751a267","custom_id":"request-2","response":{"id":"cmpl-f44d049f6b3a42d4b2d7850bb1e31bcc","object":"chat.completion","created":1715633336,"model":"meta-llama/Meta-Llama-3-8B-Instruct","choices":[{"index":0,"message":{"role":"assistant","content":"*silence*"},"logprobs":null,"finish_reason":"stop","stop_reason":null}],"usage":{"prompt_tokens":27,"total_tokens":32,"completion_tokens":5}},"error":null} @@ -68,7 +67,7 @@ The batch runner supports remote input and output urls that are accessible via h For example, to run against our example input file located at `https://raw.githubusercontent.com/vllm-project/vllm/main/examples/offline_inference/openai/openai_example_batch.jsonl`, you can run -``` +```console python -m vllm.entrypoints.openai.run_batch -i https://raw.githubusercontent.com/vllm-project/vllm/main/examples/offline_inference/openai/openai_example_batch.jsonl -o results.jsonl --model meta-llama/Meta-Llama-3-8B-Instruct ``` @@ -80,7 +79,7 @@ To integrate with cloud blob storage, we recommend using presigned urls. ### Additional prerequisites -* [Create an S3 bucket](https://docs.aws.amazon.com/AmazonS3/latest/userguide/creating-bucket.html). +* [Create an S3 bucket](https://docs.aws.amazon.com/AmazonS3/latest/userguide/creating-bucket.html). * The `awscli` package (Run `pip install awscli`) to configure your credentials and interactively use s3. - [Configure your credentials](https://docs.aws.amazon.com/cli/latest/userguide/getting-started-quickstart.html). * The `boto3` python package (Run `pip install boto3`) to generate presigned urls. @@ -89,13 +88,13 @@ To integrate with cloud blob storage, we recommend using presigned urls. 
To follow along with this example, you can download the example batch, or create your own batch file in your working directory. -``` +```console wget https://raw.githubusercontent.com/vllm-project/vllm/main/examples/offline_inference/openai/openai_example_batch.jsonl ``` Once you've created your batch file it should look like this -``` +```console $ cat offline_inference/openai/openai_example_batch.jsonl {"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_completion_tokens": 1000}} {"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_completion_tokens": 1000}} @@ -103,7 +102,7 @@ $ cat offline_inference/openai/openai_example_batch.jsonl Now upload your batch file to your S3 bucket. -``` +```console aws s3 cp offline_inference/openai/openai_example_batch.jsonl s3://MY_BUCKET/MY_INPUT_FILE.jsonl ``` @@ -111,9 +110,9 @@ aws s3 cp offline_inference/openai/openai_example_batch.jsonl s3://MY_BUCKET/MY_ Presigned urls can only be generated via the SDK. You can run the following python script to generate your presigned urls. Be sure to replace the `MY_BUCKET`, `MY_INPUT_FILE.jsonl`, and `MY_OUTPUT_FILE.jsonl` placeholders with your bucket and file names. 
-(The script is adapted from https://github.com/awsdocs/aws-doc-sdk-examples/blob/main/python/example_code/s3/s3_basics/presigned_url.py) +(The script is adapted from <https://github.com/awsdocs/aws-doc-sdk-examples/blob/main/python/example_code/s3/s3_basics/presigned_url.py>) -``` +```python import boto3 from botocore.exceptions import ClientError @@ -149,7 +148,7 @@ print(f"{output_url=}") This script should output -``` +```text input_url='https://s3.us-west-2.amazonaws.com/MY_BUCKET/MY_INPUT_FILE.jsonl?AWSAccessKeyId=ABCDEFGHIJKLMNOPQRST&Signature=abcdefghijklmnopqrstuvwxyz12345&Expires=1715800091' output_url='https://s3.us-west-2.amazonaws.com/MY_BUCKET/MY_OUTPUT_FILE.jsonl?AWSAccessKeyId=ABCDEFGHIJKLMNOPQRST&Signature=abcdefghijklmnopqrstuvwxyz12345&Expires=1715800091' ``` @@ -158,7 +157,7 @@ output_url='https://s3.us-west-2.amazonaws.com/MY_BUCKET/MY_OUTPUT_FILE.jsonl?AW You can now run the batch runner, using the urls generated in the previous section. -``` +```console python -m vllm.entrypoints.openai.run_batch \ -i "https://s3.us-west-2.amazonaws.com/MY_BUCKET/MY_INPUT_FILE.jsonl?AWSAccessKeyId=ABCDEFGHIJKLMNOPQRST&Signature=abcdefghijklmnopqrstuvwxyz12345&Expires=1715800091" \ -o "https://s3.us-west-2.amazonaws.com/MY_BUCKET/MY_OUTPUT_FILE.jsonl?AWSAccessKeyId=ABCDEFGHIJKLMNOPQRST&Signature=abcdefghijklmnopqrstuvwxyz12345&Expires=1715800091" \ @@ -169,7 +168,7 @@ python -m vllm.entrypoints.openai.run_batch \ Your results are now on S3. You can view them in your terminal by running -``` +```console aws s3 cp s3://MY_BUCKET/MY_OUTPUT_FILE.jsonl - ``` @@ -180,10 +179,10 @@ aws s3 cp s3://MY_BUCKET/MY_OUTPUT_FILE.jsonl - * Ensure you are using `vllm >= 0.5.5`. ### Step 1: Create your batch file - + Add embedding requests to your batch file.
The following is an example: - -``` + +```text {"custom_id": "request-1", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/e5-mistral-7b-instruct", "input": "You are a helpful assistant."}} {"custom_id": "request-2", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/e5-mistral-7b-instruct", "input": "You are an unhelpful assistant."}} ``` @@ -198,7 +197,7 @@ You can run the batch using the same command as in earlier examples. You can check your results by running `cat results.jsonl` -``` +```console $ cat results.jsonl {"id":"vllm-db0f71f7dec244e6bce530e0b4ef908b","custom_id":"request-1","response":{"status_code":200,"request_id":"vllm-batch-3580bf4d4ae54d52b67eee266a6eab20","body":{"id":"embd-33ac2efa7996430184461f2e38529746","object":"list","created":444647,"model":"intfloat/e5-mistral-7b-instruct","data":[{"index":0,"object":"embedding","embedding":[0.016204833984375,0.0092010498046875,0.0018358230590820312,-0.0028228759765625,0.001422882080078125,-0.0031147003173828125,...]}],"usage":{"prompt_tokens":8,"total_tokens":8,"completion_tokens":0}}},"error":null} ... @@ -211,10 +210,10 @@ $ cat results.jsonl * Ensure you are using `vllm >= 0.7.0`. ### Step 1: Create your batch file - + Add score requests to your batch file. The following is an example: - -``` + +```text {"custom_id": "request-1", "method": "POST", "url": "/v1/score", "body": {"model": "BAAI/bge-reranker-v2-m3", "text_1": "What is the capital of France?", "text_2": ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]}} {"custom_id": "request-2", "method": "POST", "url": "/v1/score", "body": {"model": "BAAI/bge-reranker-v2-m3", "text_1": "What is the capital of France?", "text_2": ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]}} ``` @@ -229,7 +228,7 @@ You can run the batch using the same command as in earlier examples. 
You can check your results by running `cat results.jsonl` -``` +```console $ cat results.jsonl {"id":"vllm-f87c5c4539184f618e555744a2965987","custom_id":"request-1","response":{"status_code":200,"request_id":"vllm-batch-806ab64512e44071b37d3f7ccd291413","body":{"id":"score-4ee45236897b4d29907d49b01298cdb1","object":"list","created":1737847944,"model":"BAAI/bge-reranker-v2-m3","data":[{"index":0,"object":"score","score":0.0010900497436523438},{"index":1,"object":"score","score":1.0}],"usage":{"prompt_tokens":37,"total_tokens":37,"completion_tokens":0,"prompt_tokens_details":null}}},"error":null} {"id":"vllm-41990c51a26d4fac8419077f12871099","custom_id":"request-2","response":{"status_code":200,"request_id":"vllm-batch-73ce66379026482699f81974e14e1e99","body":{"id":"score-13f2ffe6ba40460fbf9f7f00ad667d75","object":"list","created":1737847944,"model":"BAAI/bge-reranker-v2-m3","data":[{"index":0,"object":"score","score":0.001094818115234375},{"index":1,"object":"score","score":1.0}],"usage":{"prompt_tokens":37,"total_tokens":37,"completion_tokens":0,"prompt_tokens_details":null}}},"error":null} diff --git a/examples/offline_inference/profiling_tpu/README.md b/examples/offline_inference/profiling_tpu/README.md index 08efa63dc..6595efec4 100644 --- a/examples/offline_inference/profiling_tpu/README.md +++ b/examples/offline_inference/profiling_tpu/README.md @@ -29,7 +29,6 @@ python3 profiling.py \ --profile-result-dir profiles ``` - ### Generate Decode Trace This example runs Llama 3.1 70B with a batch of 32 requests where each has 1 input token and 128 output tokens. This is set up in attempt to profile just the 32 decodes running in parallel by having an extremely small prefill of 1 token and setting `VLLM_TPU_PROFILE_DELAY_MS=1000` to skip the first second of inference (hopefully prefill). 
@@ -51,17 +50,18 @@ python3 profiling.py \ --max-model-len 2048 --tensor-parallel-size 8 ``` - ## Visualizing the profiles Once you have collected your profiles with this script, you can visualize them using [TensorBoard](https://cloud.google.com/tpu/docs/pytorch-xla-performance-profiling-tpu-vm). Here are most likely the dependencies you need to install: + ```bash pip install tensorflow-cpu tensorboard-plugin-profile etils importlib_resources ``` Then you just need to point TensorBoard to the directory where you saved the profiles and visit `http://localhost:6006/` in your browser: + ```bash tensorboard --logdir profiles/ --port 6006 -``` \ No newline at end of file +``` diff --git a/examples/online_serving/chart-helm/README.md b/examples/online_serving/chart-helm/README.md index 6aa126d4f..bfe81121d 100644 --- a/examples/online_serving/chart-helm/README.md +++ b/examples/online_serving/chart-helm/README.md @@ -18,4 +18,4 @@ This directory contains a Helm chart for deploying the vllm application. The cha - templates/poddisruptionbudget.yaml: Template for Pod Disruption Budget. - templates/pvc.yaml: Template for Persistent Volume Claims. - templates/secrets.yaml: Template for Kubernetes Secrets. -- templates/service.yaml: Template for creating Services. \ No newline at end of file +- templates/service.yaml: Template for creating Services. diff --git a/examples/online_serving/opentelemetry/Otel.md b/examples/online_serving/opentelemetry/Otel.md index 96d1f96bf..af0034007 100644 --- a/examples/online_serving/opentelemetry/Otel.md +++ b/examples/online_serving/opentelemetry/Otel.md @@ -1,7 +1,8 @@ # Setup OpenTelemetry POC 1. Install OpenTelemetry packages: - ``` + + ```console pip install \ 'opentelemetry-sdk>=1.26.0,<1.27.0' \ 'opentelemetry-api>=1.26.0,<1.27.0' \ @@ -10,7 +11,8 @@ ``` 1. 
Start Jaeger in a docker container: - ``` + ```console # From: https://www.jaegertracing.io/docs/1.57/getting-started/ docker run --rm --name jaeger \ -e COLLECTOR_ZIPKIN_HOST_PORT=:9411 \ @@ -28,19 +30,23 @@ ``` 1. In a new shell, export Jaeger IP: - ``` + + ```console export JAEGER_IP=$(docker inspect --format '{{ .NetworkSettings.IPAddress }}' jaeger) export OTEL_EXPORTER_OTLP_TRACES_ENDPOINT=grpc://$JAEGER_IP:4317 ``` + Then set vLLM's service name for OpenTelemetry, enable insecure connections to Jaeger and run vLLM: - ``` + + ```console export OTEL_SERVICE_NAME="vllm-server" export OTEL_EXPORTER_OTLP_TRACES_INSECURE=true vllm serve facebook/opt-125m --otlp-traces-endpoint="$OTEL_EXPORTER_OTLP_TRACES_ENDPOINT" ``` 1. In a new shell, send requests with trace context from a dummy client - ``` + + ```console export JAEGER_IP=$(docker inspect --format '{{ .NetworkSettings.IPAddress }}' jaeger) export OTEL_EXPORTER_OTLP_TRACES_ENDPOINT=grpc://$JAEGER_IP:4317 export OTEL_EXPORTER_OTLP_TRACES_INSECURE=true @@ -48,7 +54,7 @@ python dummy_client.py ``` -1. Open Jaeger webui: http://localhost:16686/ +1. Open Jaeger webui: <http://localhost:16686/> In the search pane, select `vllm-server` service and hit `Find Traces`. You should get a list of traces, one for each request. ![Traces](https://i.imgur.com/GYHhFjo.png) @@ -57,26 +63,32 @@ ![Spans details](https://i.imgur.com/OPf6CBL.png) ## Exporter Protocol + OpenTelemetry supports either `grpc` or `http/protobuf` as the transport protocol for trace data in the exporter. By default, `grpc` is used.
To set `http/protobuf` as the protocol, configure the `OTEL_EXPORTER_OTLP_TRACES_PROTOCOL` environment variable as follows: -``` + +```console export OTEL_EXPORTER_OTLP_TRACES_PROTOCOL=http/protobuf export OTEL_EXPORTER_OTLP_TRACES_ENDPOINT=http://$JAEGER_IP:4318/v1/traces vllm serve facebook/opt-125m --otlp-traces-endpoint="$OTEL_EXPORTER_OTLP_TRACES_ENDPOINT" ``` ## Instrumentation of FastAPI + OpenTelemetry allows automatic instrumentation of FastAPI. + 1. Install the instrumentation library - ``` + + ```console pip install opentelemetry-instrumentation-fastapi ``` 1. Run vLLM with `opentelemetry-instrument` - ``` + + ```console opentelemetry-instrument vllm serve facebook/opt-125m ``` 1. Send a request to vLLM and find its trace in Jaeger. It should contain spans from FastAPI. -![FastAPI Spans](https://i.imgur.com/hywvoOJ.png) \ No newline at end of file +![FastAPI Spans](https://i.imgur.com/hywvoOJ.png) diff --git a/examples/online_serving/prometheus_grafana/README.md b/examples/online_serving/prometheus_grafana/README.md index 4a85f953b..6df959451 100644 --- a/examples/online_serving/prometheus_grafana/README.md +++ b/examples/online_serving/prometheus_grafana/README.md @@ -1,14 +1,16 @@ -# Prometheus and Grafana +# Prometheus and Grafana -This is a simple example that shows you how to connect vLLM metric logging to the Prometheus/Grafana stack. For this example, we launch Prometheus and Grafana via Docker. You can checkout other methods through [Prometheus](https://prometheus.io/) and [Grafana](https://grafana.com/) websites. +This is a simple example that shows you how to connect vLLM metric logging to the Prometheus/Grafana stack. For this example, we launch Prometheus and Grafana via Docker. You can checkout other methods through [Prometheus](https://prometheus.io/) and [Grafana](https://grafana.com/) websites. 
+ +Install: -Install: - [`docker`](https://docs.docker.com/engine/install/) - [`docker compose`](https://docs.docker.com/compose/install/linux/#install-using-the-repository) ## Launch Prometheus metric logging is enabled by default in the OpenAI-compatible server. Launch via the entrypoint: + ```bash vllm serve mistralai/Mistral-7B-v0.1 \ --max-model-len 2048 \ @@ -16,11 +18,13 @@ vllm serve mistralai/Mistral-7B-v0.1 \ ``` Launch Prometheus and Grafana servers with `docker compose`: + ```bash docker compose up ``` Submit some sample requests to the server: + ```bash wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json @@ -41,13 +45,13 @@ Navigate to [`http://localhost:3000`](http://localhost:3000). Log in with the de ### Add Prometheus Data Source -Navigate to [`http://localhost:3000/connections/datasources/new`](http://localhost:3000/connections/datasources/new) and select Prometheus. +Navigate to [`http://localhost:3000/connections/datasources/new`](http://localhost:3000/connections/datasources/new) and select Prometheus. On Prometheus configuration page, we need to add the `Prometheus Server URL` in `Connection`. For this setup, Grafana and Prometheus are running in separate containers, but Docker creates DNS name for each containers. You can just use `http://prometheus:9090`. Click `Save & Test`. You should get a green check saying "Successfully queried the Prometheus API.". -### Import Dashboard +### Import Dashboard Navigate to [`http://localhost:3000/dashboard/import`](http://localhost:3000/dashboard/import), upload `grafana.json`, and select the `prometheus` datasource. 
You should see a screen that looks like the following: diff --git a/examples/other/logging_configuration.md b/examples/other/logging_configuration.md index 9ac8b13cd..acd9c1f2b 100644 --- a/examples/other/logging_configuration.md +++ b/examples/other/logging_configuration.md @@ -15,7 +15,6 @@ more-complex-and-more-flexible. - Leave `VLLM_CONFIGURE_LOGGING` unset or set `VLLM_CONFIGURE_LOGGING=1` and set `VLLM_LOGGING_CONFIG_PATH=` - ## Logging Configuration Environment Variables ### `VLLM_CONFIGURE_LOGGING` @@ -45,7 +44,6 @@ schema](https://docs.python.org/3/library/logging.config.html#dictionary-schema- If `VLLM_LOGGING_CONFIG_PATH` is specified, but `VLLM_CONFIGURE_LOGGING` is disabled, an error will occur while starting vLLM. - ## Examples ### Example 1: Customize vLLM root logger @@ -98,7 +96,6 @@ VLLM_LOGGING_CONFIG_PATH=/path/to/logging_config.json \ vllm serve mistralai/Mistral-7B-v0.1 --max-model-len 2048 ``` - ### Example 2: Silence a particular vLLM logger To silence a particular vLLM logger, it is necessary to provide custom logging @@ -153,7 +150,6 @@ VLLM_LOGGING_CONFIG_PATH=/path/to/logging_config.json \ vllm serve mistralai/Mistral-7B-v0.1 --max-model-len 2048 ``` - ### Example 3: Disable vLLM default logging configuration To disable vLLM's default logging configuration and silence all vLLM loggers, @@ -166,7 +162,6 @@ VLLM_CONFIGURE_LOGGING=0 \ vllm serve mistralai/Mistral-7B-v0.1 --max-model-len 2048 ``` - ## Additional resources - [`logging.config` Dictionary Schema Details](https://docs.python.org/3/library/logging.config.html#dictionary-schema-details) diff --git a/vllm/distributed/kv_transfer/README.md b/vllm/distributed/kv_transfer/README.md index e20c992a3..c408d4a67 100644 --- a/vllm/distributed/kv_transfer/README.md +++ b/vllm/distributed/kv_transfer/README.md @@ -14,8 +14,8 @@ The KV cache transfer contains three layer of abstractions: Why we need KV lookup buffer: FIFO pipe itself is not enough as prefill vLLM worker may process requests 
in a different order compared to decode vLLM worker. Say the QPS is really high, prefill worker may handle requests in order A -> B -> C, but the decode worker may process request C first. This is not the case that can be naturally handled by FIFO pipe, so we provide KV lookup buffer to help translate a FIFO pipe to a lookup buffer. -NOTE: KV pipe layer is bypassible: you can skip this layer if your distributed -communication service already supports key-value-based lookup (like redis or +NOTE: KV pipe layer is bypassible: you can skip this layer if your distributed +communication service already supports key-value-based lookup (like redis or RDMA database). NOTE: If you want to not only transfer KV caches, but adjust the model execution flow of vLLM as well (for example, allow vLLM to receive KV caches on some tokens and do prefill on the remaining tokens), you can bypass both KV pipe layer and KV lookup buffer layer, and directly implement on KV connector layer. Bear in mind that as vLLM's model input is constantly changing, this implementation will likely be broken when vLLM has new updates. @@ -27,4 +27,3 @@ The example usage is in [this file](../../../examples/online_serving/disaggregat Here is the diagram of how we run disaggretgated prefilling. 
![Disaggregated prefill workflow](./disagg_prefill_workflow.jpg) - -- GitLab From 913df14da3014d9432bfd8a5114f845ab567b1c6 Mon Sep 17 00:00:00 2001 From: shangmingc Date: Sat, 8 Feb 2025 22:46:19 +0800 Subject: [PATCH 041/253] [Bugfix] Remove unused seq_group_metadata_list from ModelInputForGPU (#12935) Signed-off-by: Shangming Cai --- vllm/worker/model_runner.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 12baecde6..c7814f173 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -98,7 +98,6 @@ class ModelInputForGPU(ModelRunnerInputBase): finished_requests_ids: Optional[List[str]] = None virtual_engine: int = 0 async_callback: Optional[Callable] = None - seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None scheduler_outputs: Optional[SchedulerOutputs] = None def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: -- GitLab From fe743b798dfa56aea3e2cb7182365ba3495489ee Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 9 Feb 2025 00:06:56 +0800 Subject: [PATCH 042/253] [bugfix] fix early import of flash attention (#12959) Signed-off-by: youkaichao --- vllm/attention/backends/flash_attn.py | 13 +++++++------ vllm/attention/backends/mla/utils.py | 5 +++-- vllm/attention/backends/utils.py | 14 ++++++-------- vllm/v1/attention/backends/flash_attn.py | 7 ++++--- 4 files changed, 20 insertions(+), 19 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 971fe4116..5aca10079 100755 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -14,8 +14,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadataBuilder, AttentionType) from vllm.attention.backends.utils import ( - PAD_SLOT_ID, VLLM_FLASH_ATTN_VERSION, CommonAttentionState, - compute_slot_mapping, compute_slot_mapping_start_idx, + PAD_SLOT_ID, CommonAttentionState, 
compute_slot_mapping, + compute_slot_mapping_start_idx, get_flash_attn_version, get_num_prefill_decode_query_kv_tokens, get_seq_len_block_table_args, is_all_cross_attn_metadata_set, is_all_encoder_attn_metadata_set, is_block_tables_empty) @@ -640,6 +640,7 @@ class FlashAttentionImpl(AttentionImpl): f"Head size {head_size} is not supported by FlashAttention. " f"Supported head sizes are: {support_head_sizes}.") self.attn_type = attn_type + self.vllm_flash_attn_version = get_flash_attn_version() def forward( self, @@ -759,7 +760,7 @@ class FlashAttentionImpl(AttentionImpl): alibi_slopes=alibi_slopes, softcap=logits_soft_cap, out=prefill_output, - fa_version=VLLM_FLASH_ATTN_VERSION, + fa_version=self.vllm_flash_attn_version, ) else: # prefix-enabled attention @@ -782,7 +783,7 @@ class FlashAttentionImpl(AttentionImpl): block_table=prefill_meta.block_tables, softcap=logits_soft_cap, out=prefill_output, - fa_version=VLLM_FLASH_ATTN_VERSION, + fa_version=self.vllm_flash_attn_version, ) if decode_meta := attn_metadata.decode_metadata: @@ -811,7 +812,7 @@ class FlashAttentionImpl(AttentionImpl): softcap=logits_soft_cap, block_table=decode_meta.block_tables, out=decode_output, - fa_version=VLLM_FLASH_ATTN_VERSION, + fa_version=self.vllm_flash_attn_version, ) else: # Use flash_attn_with_kvcache for normal decoding. 
@@ -832,7 +833,7 @@ class FlashAttentionImpl(AttentionImpl): alibi_slopes=alibi_slopes, softcap=logits_soft_cap, out=decode_output.unsqueeze(1), - fa_version=VLLM_FLASH_ATTN_VERSION, + fa_version=self.vllm_flash_attn_version, ) return output diff --git a/vllm/attention/backends/mla/utils.py b/vllm/attention/backends/mla/utils.py index c22f7e921..a41140ec8 100644 --- a/vllm/attention/backends/mla/utils.py +++ b/vllm/attention/backends/mla/utils.py @@ -12,7 +12,7 @@ from vllm import envs from vllm.attention.backends.abstract import (AttentionLayer, AttentionMetadata, MLAAttentionImpl, T) -from vllm.attention.backends.utils import VLLM_FLASH_ATTN_VERSION +from vllm.attention.backends.utils import get_flash_attn_version from vllm.distributed import (get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -181,6 +181,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): self.q_proj = q_proj self.kv_b_proj = kv_b_proj self.o_proj = o_proj + self.vllm_flash_attn_version = get_flash_attn_version() def _v_up_proj_and_o_proj(self, x): if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION: @@ -515,7 +516,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): max_seqlen_k=max_prefill_seq_len, softmax_scale=self.scale, causal=True, - fa_version=VLLM_FLASH_ATTN_VERSION, + fa_version=self.vllm_flash_attn_version, ) attn_output = attn_output\ .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\ diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index e8a344341..5c1f9916e 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -587,11 +587,11 @@ def get_num_prefill_decode_query_kv_tokens( num_decode_query_tokens) -try: - from vllm.vllm_flash_attn.flash_attn_interface import ( - fa_version_unsupported_reason, is_fa_version_supported) +def get_flash_attn_version(): + try: + from vllm.vllm_flash_attn.flash_attn_interface import ( + 
fa_version_unsupported_reason, is_fa_version_supported) - def flash_attn_version(): # if hopper default to FA3, otherwise stick to FA2 for now # TODO(lucas): profile FA3 on ampere to see if it makes sense to # use FA3 as default for both @@ -610,7 +610,5 @@ try: assert is_fa_version_supported(fa_version) return fa_version - - VLLM_FLASH_ATTN_VERSION = flash_attn_version() -except (ImportError, AssertionError): - VLLM_FLASH_ATTN_VERSION = None + except (ImportError, AssertionError): + return None diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 204afc9f4..5cb1e2fd2 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -10,7 +10,7 @@ import triton.language as tl from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) -from vllm.attention.backends.utils import VLLM_FLASH_ATTN_VERSION +from vllm.attention.backends.utils import get_flash_attn_version from vllm.logger import init_logger from vllm.utils import cdiv from vllm.vllm_flash_attn import flash_attn_varlen_func @@ -132,6 +132,7 @@ class FlashAttentionImpl(AttentionImpl): "encoder/decoder cross-attention " "are not implemented for " "FlashAttentionImpl") + self.vllm_flash_attn_version = get_flash_attn_version() def forward( self, @@ -205,7 +206,7 @@ class FlashAttentionImpl(AttentionImpl): window_size=self.sliding_window, block_table=attn_metadata.block_table, softcap=self.logits_soft_cap, - fa_version=VLLM_FLASH_ATTN_VERSION, + fa_version=self.vllm_flash_attn_version, ) return output @@ -227,7 +228,7 @@ class FlashAttentionImpl(AttentionImpl): logits_soft_cap=self.logits_soft_cap, block_table=attn_metadata.block_table, common_prefix_len=attn_metadata.common_prefix_len, - fa_version=VLLM_FLASH_ATTN_VERSION, + fa_version=self.vllm_flash_attn_version, ) return output -- GitLab From 86222a3dab50b66bb0bff17a94b629aa59c3ed57 Mon Sep 17 00:00:00 2001 From: Jee 
Jee Li Date: Sun, 9 Feb 2025 04:32:16 +0800 Subject: [PATCH 043/253] [VLM] Merged multi-modal processor for GLM4V (#12449) Signed-off-by: Jee Jee Li --- docs/source/models/supported_models.md | 2 +- examples/offline_inference/vision_language.py | 4 +- .../multimodal/processing/test_common.py | 1 + vllm/model_executor/models/chatglm.py | 382 ++++++++++-------- 4 files changed, 222 insertions(+), 167 deletions(-) diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index 38f36b54d..91e6c42d5 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -719,7 +719,7 @@ See [this page](#generative-models) for more information on how to use generativ * `THUDM/glm-4v-9b` etc. * ✅︎ * ✅︎ - * + * ✅︎ - * `H2OVLChatModel` * H2OVL * T + IE+ diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index 436c36570..9a4183106 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -106,7 +106,9 @@ def run_glm4v(question: str, modality: str): trust_remote_code=True, enforce_eager=True, disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache) - prompt = question + prompt = f"<|user|>\n<|begin_of_image|><|endoftext|><|end_of_image|>\ + {question}<|assistant|>" + stop_token_ids = [151329, 151336, 151338] return llm, prompt, stop_token_ids diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index 77cf3442d..8658e60bc 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -147,6 +147,7 @@ def _test_processing_correctness( "facebook/chameleon-7b", "deepseek-ai/deepseek-vl2-tiny", "adept/fuyu-8b", + "THUDM/glm-4v-9b", "h2oai/h2ovl-mississippi-800m", "OpenGVLab/InternVL2-1B", "HuggingFaceM4/Idefics3-8B-Llama3", diff --git a/vllm/model_executor/models/chatglm.py 
b/vllm/model_executor/models/chatglm.py index a31648675..9ee9e9ca8 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -4,20 +4,21 @@ # https://github.com/THUDM/CogAgent """Inference-only CogAgent model compatible with THUDM weights.""" from argparse import Namespace -from array import array -from typing import (Dict, Iterable, List, Mapping, Optional, Set, Tuple, - TypedDict) +from typing import (Iterable, List, Mapping, Optional, Sequence, Set, Tuple, + TypedDict, Union) import torch -from PIL import Image from torch import nn from torch.nn import LayerNorm +from torchvision import transforms +from torchvision.transforms import InterpolationMode +from transformers import PreTrainedTokenizer, TensorType +from transformers.image_utils import ImageInput +from transformers.tokenization_utils_base import TextInput from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, - InputContext, token_inputs) from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm @@ -35,73 +36,55 @@ from vllm.model_executor.models.glm4_vision_encoder import EVA2CLIPModel from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (ModalityData, MultiModalKwargs, - NestedTensors) -from vllm.multimodal.utils import cached_get_tokenizer -from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, - SequenceData) +from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors +from vllm.multimodal.parse import ImageSize, MultiModalDataItems +from vllm.multimodal.processing import 
(BaseMultiModalProcessor, + BaseProcessingInfo, BatchFeature, + BoundPromptReplacement, + MultiModalFieldConfig, + PlaceholderFeaturesInfo, + PromptReplacement) +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import ChatGLMConfig from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) + maybe_prefix, merge_multimodal_embeddings) logger = init_logger(__name__) +IMAGE_TOKEN_ID = 151329 -def calculate_image_placeholder(vision_config): - return (vision_config["image_size"] // vision_config["patch_size"] // 2)**2 +def build_normalization_transform(image_size: int) -> transforms.Compose: + """ + Build a normalization transform which can be applied to one or + more input images from which we want to extract visual features. + + Args: + image_size: size of the image to be processed for visual embeddings. + + Returns: + Callable transform for normalizing and resizing one RGB image. 
+ """ -def mm_input_mapper_for_glmv( - ctx: InputContext, - data: ModalityData[object], -) -> Dict: - model_config = ctx.model_config - tokenizer = cached_get_tokenizer( - model_config.tokenizer, - trust_remote_code=model_config.trust_remote_code) - if tokenizer is None: - raise RuntimeError("No HuggingFace processor is available " - "to process the image object") - try: - raw_batch_data = tokenizer.apply_chat_template( - conversation=[{ - "role": "user", - "image": data - }], - add_generation_prompt=True, - tokenize=True, - return_tensors="pt", - return_dict=True).data - except Exception: - logger.error("Failed to process image (%s)", data) - raise - pixel_values = raw_batch_data['images'] - - return MultiModalKwargs({'pixel_values': pixel_values}) - - -def merge_glm_vision_embeddings( - input_ids: torch.Tensor, - inputs_embeds: torch.Tensor, - vision_embeddings: torch.Tensor, - boi_token_id: int, - eoi_token_id: int, -) -> torch.Tensor: - - boi_positions = (input_ids == boi_token_id).nonzero(as_tuple=True)[0] - eoi_positions = (input_ids == eoi_token_id).nonzero(as_tuple=True)[0] - - mask = torch.zeros_like(input_ids, dtype=torch.bool) - - for boi_pos, eoi_pos in zip(boi_positions, eoi_positions): - assert boi_pos < eoi_pos - mask[boi_pos:eoi_pos + 1] = True - inputs_embeds[mask] = vision_embeddings.view(-1, - vision_embeddings.shape[-1]) - return inputs_embeds + return transforms.Compose([ + transforms.Resize( + (image_size, image_size), + interpolation=InterpolationMode.BICUBIC, + ), + transforms.ToTensor(), + transforms.Normalize( + (0.48145466, 0.4578275, 0.40821073), + (0.26862954, 0.26130258, 0.27577711), + ), + ]) + + +def calculate_image_placeholder(vision_config): + return (vision_config["image_size"] // vision_config["patch_size"] // 2)**2 class GLMImagePixelInputs(TypedDict): @@ -109,120 +92,177 @@ class GLMImagePixelInputs(TypedDict): """Shape: `(batch_size, num_channels, height, width)`""" -def get_max_glmv_image_tokens(ctx: InputContext): - 
hf_config = ctx.get_hf_config(ChatGLMConfig) +class GLM4VProcessor: + """ + This model doesn't define its own HF processor, + so we implement our own one here. - vision_config = getattr(hf_config, 'vision_config', None) - if vision_config is None: - return 1 - elif isinstance(vision_config, dict): - return calculate_image_placeholder(vision_config) + """ - msg = f"Unsupported vision config: {type(vision_config)}" - raise NotImplementedError(msg) + def __init__( + self, + config: ChatGLMConfig, + tokenizer: PreTrainedTokenizer, + ) -> None: + super().__init__() + self.config = config + self.tokenizer = tokenizer -def dummy_data_for_glmv(ctx: InputContext, seq_len: int, - mm_counts: Mapping[str, int]) -> DummyData: - hf_config = ctx.get_hf_config(ChatGLMConfig) - vision_config = getattr(hf_config, 'vision_config', None) + if hasattr(self.config, "vision_config"): + self.image_transform = build_normalization_transform( + config.vision_config["image_size"]) + else: + self.image_transform = None - if vision_config is None: - token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * seq_len) - seq_data = SequenceData(token_ids) - return DummyData(seq_data, None) - elif isinstance(vision_config, dict): - image_size = vision_config["image_size"] - image_placeholder_length = calculate_image_placeholder(vision_config) - token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [hf_config.boi_token_id] + - [0] * image_placeholder_length + - [hf_config.eoi_token_id]) - token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, - [0] * (seq_len - image_placeholder_length - 2)) - seq_data = SequenceData(token_ids) + def __call__( + self, + text: Optional[Union[TextInput, list[TextInput]]] = None, + images: Optional[Union[ImageInput, list[ImageInput]]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + ) -> BatchFeature: + if text is None: + text = [] + if not isinstance(text, list): + text = [text] + if images is None: + images = [] + if not isinstance(images, list): + images = [images] + 
text_inputs = self.tokenizer(text) + if len(images) == 0: + image_inputs = {} + else: + if self.image_transform is None: + raise ValueError("This model does not support image inputs") + + pixel_values = [self.image_transform(image) for image in images] + image_inputs = {"pixel_values": torch.stack(pixel_values)} + + return BatchFeature( + { + **text_inputs, + **image_inputs, + }, + tensor_type=return_tensors, + ) - mm_data = { - "image": Image.new("RGB", (image_size, image_size), color=0) - } - return DummyData(seq_data, mm_data) +class GLM4VProcessingInfo(BaseProcessingInfo): - msg = f"Unsupported vision config: {type(vision_config)}" - raise NotImplementedError(msg) + def __init__(self, ctx): + super().__init__(ctx) + self._pre_calculate() + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": 1} -def find_all_positions(input_ids: List[int], target: int) -> List[int]: - return [index for index, value in enumerate(input_ids) if value == target] + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + return {"image": self.image_token_num + 2} -def input_processor_for_glmv(ctx: InputContext, inputs: DecoderOnlyInputs): - multi_modal_data = inputs.get("multi_modal_data") - if multi_modal_data is None or "image" not in multi_modal_data: - return inputs + def _pre_calculate(self): + hf_config = self.get_hf_config() + vision_config = hf_config.vision_config + self.image_token_num = calculate_image_placeholder(vision_config) + self.image_size = vision_config["image_size"] - hf_config = ctx.get_hf_config(ChatGLMConfig) - vision_config = getattr(hf_config, 'vision_config', None) + def get_num_image_tokens(self) -> int: + return self.image_token_num + 2 - if vision_config is None: - return inputs - elif isinstance(vision_config, dict): - image_placeholder_length = calculate_image_placeholder(vision_config) - else: - msg = f"Unsupported vision config: {type(vision_config)}" - raise 
NotImplementedError(msg) + def get_image_size(self) -> ImageSize: - input_ids = inputs["prompt_token_ids"] + return ImageSize(height=self.image_size, width=self.image_size) - tokenizer = cached_get_tokenizer( - ctx.model_config.model, - trust_remote_code=ctx.model_config.trust_remote_code) + def get_hf_processor(self) -> GLM4VProcessor: + return GLM4VProcessor( + self.get_hf_config(), + self.get_tokenizer(), + ) - try: - raw_batch_data = tokenizer.apply_chat_template( - conversation=[{ - "role": "user", - "image": multi_modal_data["image"], - "content": inputs['prompt'], - }], - add_generation_prompt=True, - tokenize=True, - return_tensors="pt", - return_dict=True, - ).data - except Exception: - logger.error("Failed to process content (%s)", inputs['prompt']) - raise - input_ids = raw_batch_data['input_ids'][0].tolist() - boi_token_id = hf_config.boi_token_id - eoi_token_id = hf_config.eoi_token_id - boi_positions = find_all_positions(input_ids, boi_token_id) - eoi_positions = find_all_positions(input_ids, eoi_token_id) +class GLM4VDummyInputsBuilder(BaseDummyInputsBuilder[GLM4VProcessingInfo]): - assert len(boi_positions) == len(eoi_positions) + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + num_images = mm_counts.get("image", 0) + target_width, target_height = self.info.get_image_size() - new_input_ids = [] - final_processed_position = 0 + mm_data = { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images) + } + text = "<|begin_of_image|><|endoftext|><|end_of_image|>" + return ProcessorInputs( + prompt_text=text, + mm_data=mm_data, + ) - for boi_position, eoi_position in zip(boi_positions, eoi_positions): - assert boi_position < eoi_position - new_input_ids.extend(input_ids[final_processed_position:boi_position + - 1]) - new_input_ids.extend([input_ids[boi_position + 1]] * - image_placeholder_length) - final_processed_position = eoi_position - 
new_input_ids.extend(input_ids[final_processed_position:]) +class GLM4VMultiModalProcessor(BaseMultiModalProcessor[GLM4VProcessingInfo]): + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict(pixel_values=MultiModalFieldConfig.batched("image")) + + def _get_prompt_replacements( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> list[PromptReplacement]: + + def get_replacement(item_idx: int): + image_tokens = self.info.image_token_num + return [IMAGE_TOKEN_ID] * image_tokens + + return [ + PromptReplacement( + modality="image", + target=[IMAGE_TOKEN_ID], + replacement=get_replacement, + ), + ] - prompt = inputs.get("prompt") - if prompt is None: - prompt = tokenizer.decode(new_input_ids) + def _apply_prompt_replacements( + self, + token_ids: list[int], + mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]], + mm_item_counts: Mapping[str, int], + ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]: + token_ids, text, placeholders = super()._apply_prompt_replacements( + token_ids=token_ids, + mm_prompt_repls=mm_prompt_repls, + mm_item_counts=mm_item_counts, + ) + hf_config = self.info.get_hf_config() + boi_token_id = hf_config.boi_token_id + eoi_token_id = hf_config.eoi_token_id + placeholders = { + modality: [ + PlaceholderFeaturesInfo( + modality=p.modality, + item_idx=p.item_idx, + start_idx=p.start_idx - 1, + tokens=[boi_token_id] + p.tokens + [eoi_token_id], + ) for p in ps + ] + for modality, ps in placeholders.items() + } - return token_inputs( - prompt_token_ids=new_input_ids, - prompt=prompt, - multi_modal_data=multi_modal_data, - ) + return token_ids, text, placeholders class GLMAttention(nn.Module): @@ -572,12 +612,16 @@ class ChatGLMModel(nn.Module): ) -> torch.Tensor: inputs_embeds = self.embedding(input_ids) if 
multimodal_embeddings is not None: - inputs_embeds = merge_glm_vision_embeddings( + inputs_embeds = merge_multimodal_embeddings( input_ids=input_ids, inputs_embeds=inputs_embeds, - vision_embeddings=multimodal_embeddings, - boi_token_id=self.config.boi_token_id, - eoi_token_id=self.config.eoi_token_id) + multimodal_embeddings=multimodal_embeddings, + placeholder_token_id=[ + self.config.boi_token_id, + IMAGE_TOKEN_ID, + self.config.eoi_token_id, + ], + ) return inputs_embeds def forward( @@ -593,14 +637,12 @@ class ChatGLMModel(nn.Module): # NOTE: In v1, inputs_embeds is always generated at model runner, this # condition is for v0 compatibility. - if intermediate_tensors is None and inputs_embeds is None: + if intermediate_tensors is not None: + inputs_embeds = intermediate_tensors["hidden_states"] + elif inputs_embeds is None: vision_embeddings = self.get_multimodal_embeddings(**kwargs) inputs_embeds = self.get_input_embeddings(input_ids, vision_embeddings) - input_ids = None - else: - inputs_embeds = intermediate_tensors["hidden_states"] - # Run encoder. 
hidden_states = self.encoder( hidden_states=inputs_embeds, @@ -763,11 +805,21 @@ class ChatGLMV(ChatGLMBaseModel, SupportsMultiModal): connector="transformer.vision.linear_proj", tower_model="transformer.vision.transformer") + def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: + return self.transformer.get_multimodal_embeddings(**kwargs) + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[NestedTensors] = None, + ) -> torch.Tensor: + return self.transformer.get_input_embeddings(input_ids, + multimodal_embeddings) + -@MULTIMODAL_REGISTRY.register_image_input_mapper(mm_input_mapper_for_glmv) -@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_glmv_image_tokens) -@INPUT_REGISTRY.register_dummy_data(dummy_data_for_glmv) -@INPUT_REGISTRY.register_input_processor(input_processor_for_glmv) +@MULTIMODAL_REGISTRY.register_processor(GLM4VMultiModalProcessor, + info=GLM4VProcessingInfo, + dummy_inputs=GLM4VDummyInputsBuilder) class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP, SupportsMultiModal): # Ensure that the LoRA support check passes when the class is not -- GitLab From 870c37481e4d9dbcd548344e1eee6bd83993388a Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 8 Feb 2025 12:48:30 -0800 Subject: [PATCH 044/253] [V1][Minor] Remove outdated comment (#12968) Signed-off-by: Woosuk Kwon --- vllm/v1/core/kv_cache_manager.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index eefc2e19c..f8d08d0e4 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -205,8 +205,6 @@ class KVCacheManager: # Should not exceed the maximum number of blocks per request. # This is especially because the block table has the shape # [..., max_num_blocks_per_req]. - # TODO(woosuk): Check and reject requests if - # num_prompt_tokens + max_tokens > max_model_len. 
self.max_num_blocks_per_req - len(req_blocks), ) assert num_new_blocks > 0 -- GitLab From d366ccc4e391ab772711f0832e8ea61d90d5def3 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sat, 8 Feb 2025 22:12:53 +0100 Subject: [PATCH 045/253] [RFC] [Mistral] FP8 format (#10130) Signed-off-by: mgoin Co-authored-by: mgoin --- vllm/model_executor/models/llama.py | 20 ++++++++-- vllm/model_executor/models/pixtral.py | 7 +++- vllm/transformers_utils/config.py | 37 ++++++++++++++++--- vllm/transformers_utils/tokenizers/mistral.py | 3 +- 4 files changed, 55 insertions(+), 12 deletions(-) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 866c69234..2ff52dd78 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -467,6 +467,9 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): mistral_mapping = { "layers": "model.layers", "attention": "self_attn", + "qscale_act": "input_scale", + "qscale_weight": "weight_scale", + "kv_fake_quantizer.qscale_act": "kv_scale", "wq": "q_proj", "wk": "k_proj", "wv": "v_proj", @@ -590,15 +593,24 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): modules = name.split(".") # rotary embeds should be sliced - if "wk" in modules: + if "wk" in modules and modules[-1] == "weight": loaded_weight = permute(loaded_weight, self.config.num_key_value_heads) - elif "wq" in modules: + elif "wq" in modules and modules[-1] == "weight": loaded_weight = permute(loaded_weight, self.config.num_attention_heads) - for item in modules: - if item in mapping and mapping[item] not in name: + num_modules = len(modules) + for i in range(num_modules): + item = modules[i] + next_item = modules[i + 1] if i < num_modules - 1 else None + + combined_item = (f"{item}.{next_item}" + if next_item is not None else None) + + if combined_item in mapping: + name = name.replace(combined_item, mapping[combined_item]) + elif item in mapping and mapping[item] not in name: name = 
name.replace(item, mapping[item]) return name, loaded_weight diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 003e9c84c..e78e8d62c 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -54,8 +54,11 @@ def get_max_pixtral_image_tokens(ctx: InputContext): tokenizer_mode=ctx.model_config.tokenizer_mode) mm_encoder = tokenizer.instruct.mm_encoder - max_image_size = mm_encoder.mm_config.max_image_size - image_patch_size = mm_encoder.mm_config.image_patch_size + image_config = mm_encoder.mm_config if hasattr( + mm_encoder, "mm_config") else mm_encoder.image_config + + max_image_size = image_config.max_image_size + image_patch_size = image_config.image_patch_size return ((max_image_size // image_patch_size)**2) diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index fb5cc3ec0..42b45e10e 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -4,7 +4,7 @@ import enum import json import os from pathlib import Path -from typing import Any, Dict, Optional, Type, Union +from typing import Any, Dict, Literal, Optional, Type, Union import huggingface_hub from huggingface_hub import (file_exists, hf_hub_download, list_repo_files, @@ -554,7 +554,8 @@ def load_params_config(model: Union[str, Path], revision: Optional[str], for key, value in elem.items(): key = config_mapping.get(key, key) config_dict[key] = recurse_elems(value) - return PretrainedConfig(**config_dict) + + return config_dict else: return elem @@ -566,12 +567,30 @@ def load_params_config(model: Union[str, Path], revision: Optional[str], config_dict["max_position_embeddings"] = config_dict.get( "max_position_embeddings", 128_000) + if config_dict.get("quantization") is not None: + quantization = config_dict.get("quantization", {}) + if quantization.get("qformat_weight") == "fp8_e4m3": + # This maps to the FP8 static per-tensor quantization scheme + 
quantization_config = { + "quant_method": "fp8", + "activation_scheme": "static" + } + else: + raise ValueError( + f"Found unknown quantization='{quantization}' in config") + + config_dict["quantization_config"] = quantization_config + + config_type: Literal["text", + "multimodal"] = "multimodal" if config_dict.get( + "vision_encoder") is not None else "text" + if config_dict.get("moe") is not None: config_dict["architectures"] = ["MixtralForCausalLM"] else: config_dict["architectures"] = ["MistralForCausalLM"] - if config_dict.get("vision_encoder") is not None: + if config_type == "multimodal": multimodal_config = config_dict.pop("vision_encoder") config_dict = { @@ -583,8 +602,16 @@ def load_params_config(model: Union[str, Path], revision: Optional[str], config_dict.update(kwargs) - config = recurse_elems(config_dict) - return config + config_dict = recurse_elems(config_dict) + + # transform to HF config format + if config_type == "multimodal": + config_dict["text_config"] = PretrainedConfig( + **config_dict["text_config"]) + config_dict["vision_config"] = PretrainedConfig( + **config_dict["vision_config"]) + + return PretrainedConfig(**config_dict) def get_hf_image_processor_config( diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py index 7a1dba424..8d96fcd27 100644 --- a/vllm/transformers_utils/tokenizers/mistral.py +++ b/vllm/transformers_utils/tokenizers/mistral.py @@ -88,7 +88,8 @@ def list_local_repo_files(repo_id: str, revision: Optional[str]) -> List[str]: def find_tokenizer_file(files: List[str]): - file_pattern = re.compile(r"^tokenizer\.model\.v.*$|^tekken\.json$") + file_pattern = re.compile( + r"^tokenizer\.model\.v.*$|^tekken\.json$|^tokenizer\.mm\.model\.v.*$") matched_files = [file for file in files if file_pattern.match(file)] if len(matched_files) > 1: -- GitLab From 24700c346bee5760f015bf41cdc6fd9ffb5d6aaf Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 8 Feb 2025 15:32:32 -0800 
Subject: [PATCH 046/253] [V1] Cache `uses_mrope` in GPUModelRunner (#12969) --- vllm/v1/worker/gpu_model_runner.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index e0a096a91..fdbca70bd 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -92,6 +92,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): # Multi-modal data support self.input_registry = INPUT_REGISTRY self.mm_registry = MULTIMODAL_REGISTRY + self.uses_mrope = model_config.uses_mrope # NOTE: Initialized input mapper is only used for processing dummy # multimodal data into multimodal kwargs for GPU memory profiling. @@ -147,7 +148,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): device=self.device) # Only relevant for models using M-RoPE (e.g, Qwen2-VL) - if self.model_config.uses_mrope: + if self.uses_mrope: # NOTE: `mrope_positions` is implemented with one additional dummy # position on purpose to make it non-contiguous so that it can work # with torch compile. @@ -284,7 +285,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ) # Only relevant for models using M-RoPE (e.g, Qwen2-VL) - if self.model_config.uses_mrope: + if self.uses_mrope: image_grid_thw = [] video_grid_thw = [] second_per_grid_ts = [] @@ -411,7 +412,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): # Calculate M-RoPE positions. # Only relevant for models using M-RoPE (e.g, Qwen2-VL) - if self.model_config.uses_mrope: + if self.uses_mrope: self._calc_mrope_positions(scheduler_output) # Get token indices. @@ -458,7 +459,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): # Copy the tensors to the GPU. 
self.input_ids[:total_num_scheduled_tokens].copy_( self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True) - if self.model_config.uses_mrope: + if self.uses_mrope: # Only relevant for models using M-RoPE (e.g, Qwen2-VL) self.mrope_positions[:, :total_num_scheduled_tokens].copy_( self.mrope_positions_cpu[:, :total_num_scheduled_tokens], @@ -817,13 +818,14 @@ class GPUModelRunner(LoRAModelRunnerMixin): # then the embedding layer is not included in the CUDA graph. input_ids = self.input_ids[:num_input_tokens] inputs_embeds = None + if self.uses_mrope: + positions = self.mrope_positions[:, :num_input_tokens] + else: + positions = self.positions[:num_input_tokens] # Run the decoder. # Use persistent buffers for CUDA graphs. with set_forward_context(attn_metadata, self.vllm_config): - positions = self.mrope_positions[:, :num_input_tokens] \ - if self.model_config.uses_mrope \ - else self.positions[:num_input_tokens] hidden_states = self.model( input_ids=input_ids, positions=positions, @@ -1001,10 +1003,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): else: input_ids = self.input_ids[:num_tokens] inputs_embeds = None + if self.uses_mrope: + positions = self.mrope_positions[:, :num_tokens] + else: + positions = self.positions[:num_tokens] with set_forward_context(None, self.vllm_config): - positions = self.mrope_positions[:, :num_tokens] \ - if self.model_config.uses_mrope \ - else self.positions[:num_tokens] hidden_states = model( input_ids=input_ids, positions=positions, -- GitLab From cf797aa856995a474eec310884f2a71a3826c0f3 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 9 Feb 2025 15:00:00 +0800 Subject: [PATCH 047/253] [core] port pynvml into vllm codebase (#12963) Signed-off-by: youkaichao --- .pre-commit-config.yaml | 20 +- requirements-cuda.txt | 1 - tests/utils.py | 5 +- vllm/third_party/__init__.py | 0 vllm/third_party/pynvml.py | 6139 ++++++++++++++++++++++++++++++++++ vllm/utils.py | 39 +- 6 files changed, 6169 insertions(+), 35 deletions(-) 
create mode 100644 vllm/third_party/__init__.py create mode 100644 vllm/third_party/pynvml.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 118451593..352eb2df0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,25 +8,28 @@ repos: - id: yapf args: [--in-place, --verbose] additional_dependencies: [toml] # TODO: Remove when yapf is upgraded + exclude: 'vllm/third_party/.*' - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.9.3 hooks: - id: ruff args: [--output-format, github] + exclude: 'vllm/third_party/.*' - repo: https://github.com/codespell-project/codespell rev: v2.4.0 hooks: - id: codespell - exclude: 'benchmarks/sonnet.txt|(build|tests/(lora/data|models/fixtures|prompts))/.*' + exclude: 'benchmarks/sonnet.txt|(build|tests/(lora/data|models/fixtures|prompts))/.*|vllm/third_party/.*' - repo: https://github.com/PyCQA/isort rev: 5.13.2 hooks: - id: isort + exclude: 'vllm/third_party/.*' - repo: https://github.com/pre-commit/mirrors-clang-format rev: v19.1.7 hooks: - id: clang-format - exclude: 'csrc/(moe/topk_softmax_kernels.cu|quantization/gguf/(ggml-common.h|dequantize.cuh|vecdotq.cuh|mmq.cuh|mmvq.cuh))' + exclude: 'csrc/(moe/topk_softmax_kernels.cu|quantization/gguf/(ggml-common.h|dequantize.cuh|vecdotq.cuh|mmq.cuh|mmvq.cuh))|vllm/third_party/.*' types_or: [c++, cuda] args: [--style=file, --verbose] - repo: https://github.com/jackdewinter/pymarkdown @@ -34,10 +37,12 @@ repos: hooks: - id: pymarkdown args: [fix] + exclude: 'vllm/third_party/.*' - repo: https://github.com/rhysd/actionlint rev: v1.7.7 hooks: - id: actionlint + exclude: 'vllm/third_party/.*' - repo: local hooks: - id: mypy-local @@ -47,6 +52,7 @@ repos: types: [python] additional_dependencies: &mypy_deps [mypy==1.11.1, types-setuptools, types-PyYAML, types-requests] stages: [pre-commit] # Don't run in CI + exclude: 'vllm/third_party/.*' - id: mypy-3.9 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward name: 
Run mypy for Python 3.9 entry: tools/mypy.sh 1 "3.9" @@ -54,6 +60,7 @@ repos: types: [python] additional_dependencies: *mypy_deps stages: [manual] # Only run in CI + exclude: 'vllm/third_party/.*' - id: mypy-3.10 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward name: Run mypy for Python 3.10 entry: tools/mypy.sh 1 "3.10" @@ -61,6 +68,7 @@ repos: types: [python] additional_dependencies: *mypy_deps stages: [manual] # Only run in CI + exclude: 'vllm/third_party/.*' - id: mypy-3.11 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward name: Run mypy for Python 3.11 entry: tools/mypy.sh 1 "3.11" @@ -68,6 +76,7 @@ repos: types: [python] additional_dependencies: *mypy_deps stages: [manual] # Only run in CI + exclude: 'vllm/third_party/.*' - id: mypy-3.12 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward name: Run mypy for Python 3.12 entry: tools/mypy.sh 1 "3.12" @@ -75,16 +84,19 @@ repos: types: [python] additional_dependencies: *mypy_deps stages: [manual] # Only run in CI + exclude: 'vllm/third_party/.*' - id: shellcheck name: Lint shell scripts entry: tools/shellcheck.sh language: script types: [shell] + exclude: 'vllm/third_party/.*' - id: png-lint name: Lint PNG exports from excalidraw entry: tools/png-lint.sh language: script types: [png] + exclude: 'vllm/third_party/.*' - id: signoff-commit name: Sign-off Commit entry: bash @@ -97,17 +109,20 @@ repos: language: system verbose: true stages: [commit-msg] + exclude: 'vllm/third_party/.*' - id: check-spdx-header name: Check SPDX headers entry: python tools/check_spdx_header.py language: python types: [python] + exclude: 'vllm/third_party/.*' - id: suggestion name: Suggestion entry: bash -c 'echo "To bypass pre-commit hooks, add --no-verify to git commit."' language: system verbose: true pass_filenames: false + exclude: 'vllm/third_party/.*' - id: check-filenames name: Check for spaces in all filenames entry: bash 
@@ -117,3 +132,4 @@ repos: language: system always_run: true pass_filenames: false + exclude: 'vllm/third_party/.*' diff --git a/requirements-cuda.txt b/requirements-cuda.txt index 78fa360f2..0e7217fb3 100644 --- a/requirements-cuda.txt +++ b/requirements-cuda.txt @@ -3,7 +3,6 @@ # Dependencies for NVIDIA GPUs ray[default] >= 2.9 -nvidia-ml-py >= 12.560.30 # for pynvml package torch == 2.5.1 torchaudio==2.5.1 # These must be updated alongside torch diff --git a/tests/utils.py b/tests/utils.py index 3b32052fe..f39cbe7ed 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -46,8 +46,9 @@ if current_platform.is_rocm(): finally: amdsmi_shut_down() elif current_platform.is_cuda(): - from pynvml import (nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo, - nvmlInit, nvmlShutdown) + from vllm.third_party.pynvml import (nvmlDeviceGetHandleByIndex, + nvmlDeviceGetMemoryInfo, nvmlInit, + nvmlShutdown) @contextmanager def _nvml(): diff --git a/vllm/third_party/__init__.py b/vllm/third_party/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/vllm/third_party/pynvml.py b/vllm/third_party/pynvml.py new file mode 100644 index 000000000..0a4be23a0 --- /dev/null +++ b/vllm/third_party/pynvml.py @@ -0,0 +1,6139 @@ +# SPDX-License-Identifier: Apache-2.0 +# copied from https://pypi.org/project/nvidia-ml-py +# version 12.570.86 + +##### +# Copyright (c) 2011-2023, NVIDIA Corporation. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. 
+# * Neither the name of the NVIDIA Corporation nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +# THE POSSIBILITY OF SUCH DAMAGE. +##### + +## +# Python bindings for the NVML library +## +from ctypes import * +from ctypes.util import find_library +from functools import wraps +import sys +import os +import threading +import string + +## C Type mappings ## +## Enums +_nvmlEnableState_t = c_uint +NVML_FEATURE_DISABLED = 0 +NVML_FEATURE_ENABLED = 1 + +_nvmlBrandType_t = c_uint +NVML_BRAND_UNKNOWN = 0 +NVML_BRAND_QUADRO = 1 +NVML_BRAND_TESLA = 2 +NVML_BRAND_NVS = 3 +NVML_BRAND_GRID = 4 # Deprecated from API reporting. Keeping definition for backward compatibility. +NVML_BRAND_GEFORCE = 5 +NVML_BRAND_TITAN = 6 +NVML_BRAND_NVIDIA_VAPPS = 7 # NVIDIA Virtual Applications +NVML_BRAND_NVIDIA_VPC = 8 # NVIDIA Virtual PC +NVML_BRAND_NVIDIA_VCS = 9 # NVIDIA Virtual Compute Server +NVML_BRAND_NVIDIA_VWS = 10 # NVIDIA RTX Virtual Workstation +NVML_BRAND_NVIDIA_CLOUD_GAMING = 11 # NVIDIA Cloud Gaming +NVML_BRAND_NVIDIA_VGAMING = NVML_BRAND_NVIDIA_CLOUD_GAMING # Deprecated from API reporting. 
Keeping definition for backward compatibility. +NVML_BRAND_QUADRO_RTX = 12 +NVML_BRAND_NVIDIA_RTX = 13 +NVML_BRAND_NVIDIA = 14 +NVML_BRAND_GEFORCE_RTX = 15 # Unused +NVML_BRAND_TITAN_RTX = 16 # Unused +NVML_BRAND_COUNT = 17 + +_nvmlTemperatureThresholds_t = c_uint +NVML_TEMPERATURE_THRESHOLD_SHUTDOWN = 0 +NVML_TEMPERATURE_THRESHOLD_SLOWDOWN = 1 +NVML_TEMPERATURE_THRESHOLD_MEM_MAX = 2 +NVML_TEMPERATURE_THRESHOLD_GPU_MAX = 3 +NVML_TEMPERATURE_THRESHOLD_ACOUSTIC_MIN = 4 +NVML_TEMPERATURE_THRESHOLD_ACOUSTIC_CURR = 5 +NVML_TEMPERATURE_THRESHOLD_ACOUSTIC_MAX = 6 +NVML_TEMPERATURE_THRESHOLD_GPS_CURR = 7 +NVML_TEMPERATURE_THRESHOLD_COUNT = 8 + +_nvmlTemperatureSensors_t = c_uint +NVML_TEMPERATURE_GPU = 0 +NVML_TEMPERATURE_COUNT = 1 + + +_nvmlComputeMode_t = c_uint +NVML_COMPUTEMODE_DEFAULT = 0 +NVML_COMPUTEMODE_EXCLUSIVE_THREAD = 1 ## Support Removed +NVML_COMPUTEMODE_PROHIBITED = 2 +NVML_COMPUTEMODE_EXCLUSIVE_PROCESS = 3 +NVML_COMPUTEMODE_COUNT = 4 + +_nvmlMemoryLocation_t = c_uint +NVML_MEMORY_LOCATION_L1_CACHE = 0 +NVML_MEMORY_LOCATION_L2_CACHE = 1 +NVML_MEMORY_LOCATION_DEVICE_MEMORY = 2 +NVML_MEMORY_LOCATION_DRAM = 2 +NVML_MEMORY_LOCATION_REGISTER_FILE = 3 +NVML_MEMORY_LOCATION_TEXTURE_MEMORY = 4 +NVML_MEMORY_LOCATION_TEXTURE_SHM = 5 +NVML_MEMORY_LOCATION_CBU = 6 +NVML_MEMORY_LOCATION_SRAM = 7 +NVML_MEMORY_LOCATION_COUNT = 8 + +NVML_NVLINK_MAX_LINKS = 18 + +# For backwards compatibility, maintain the incorrectly-named "LANES" define +NVML_NVLINK_MAX_LANES = NVML_NVLINK_MAX_LINKS + +_nvmlNvLinkErrorCounter_t = c_uint +NVML_NVLINK_ERROR_DL_REPLAY = 0 +NVML_NVLINK_ERROR_DL_RECOVERY = 1 +NVML_NVLINK_ERROR_DL_CRC_FLIT = 2 +NVML_NVLINK_ERROR_DL_CRC_DATA = 3 +NVML_NVLINK_ERROR_DL_ECC_DATA = 4 +NVML_NVLINK_ERROR_COUNT = 5 + +_nvmlNvLinkEccLaneErrorCounter_t = c_uint +NVML_NVLINK_ERROR_DL_ECC_LANE0 = 0 +NVML_NVLINK_ERROR_DL_ECC_LANE1 = 1 +NVML_NVLINK_ERROR_DL_ECC_LANE2 = 2 +NVML_NVLINK_ERROR_DL_ECC_LANE3 = 3 +NVML_NVLINK_ERROR_DL_ECC_COUNT = 5 + +_nvmlNvLinkCapability_t = 
c_uint +NVML_NVLINK_CAP_P2P_SUPPORTED = 0 +NVML_NVLINK_CAP_SYSMEM_ACCESS = 1 +NVML_NVLINK_CAP_P2P_ATOMICS = 2 +NVML_NVLINK_CAP_SYSMEM_ATOMICS= 3 +NVML_NVLINK_CAP_SLI_BRIDGE = 4 +NVML_NVLINK_CAP_VALID = 5 +NVML_NVLINK_CAP_COUNT = 6 + +_nvmlNvLinkUtilizationCountPktTypes_t = c_uint +NVML_NVLINK_COUNTER_PKTFILTER_NOP = 0x1 +NVML_NVLINK_COUNTER_PKTFILTER_READ = 0x2 +NVML_NVLINK_COUNTER_PKTFILTER_WRITE = 0x4 +NVML_NVLINK_COUNTER_PKTFILTER_RATOM = 0x8 +NVML_NVLINK_COUNTER_PKTFILTER_NRATOM = 0x10 +NVML_NVLINK_COUNTER_PKTFILTER_FLUSH = 0x20 +NVML_NVLINK_COUNTER_PKTFILTER_RESPDATA = 0x40 +NVML_NVLINK_COUNTER_PKTFILTER_RESPNODATA = 0x80 +NVML_NVLINK_COUNTER_PKTFILTER_ALL = 0xFF + +_nvmlNvLinkUtilizationCountUnits_t = c_uint +NVML_NVLINK_COUNTER_UNIT_CYCLES = 0 +NVML_NVLINK_COUNTER_UNIT_PACKETS = 1 +NVML_NVLINK_COUNTER_UNIT_BYTES = 2 +NVML_NVLINK_COUNTER_UNIT_RESERVED = 3 +NVML_NVLINK_COUNTER_UNIT_COUNT = 4 + +_nvmlNvLinkDeviceType_t = c_uint +NVML_NVLINK_DEVICE_TYPE_GPU = 0x00 +NVML_NVLINK_DEVICE_TYPE_IBMNPU = 0x01 +NVML_NVLINK_DEVICE_TYPE_SWITCH = 0x02 +NVML_NVLINK_DEVICE_TYPE_UNKNOWN = 0xFF + +# These are deprecated, instead use _nvmlMemoryErrorType_t +_nvmlEccBitType_t = c_uint +NVML_SINGLE_BIT_ECC = 0 +NVML_DOUBLE_BIT_ECC = 1 +NVML_ECC_ERROR_TYPE_COUNT = 2 + +_nvmlEccCounterType_t = c_uint +NVML_VOLATILE_ECC = 0 +NVML_AGGREGATE_ECC = 1 +NVML_ECC_COUNTER_TYPE_COUNT = 2 + +_nvmlMemoryErrorType_t = c_uint +NVML_MEMORY_ERROR_TYPE_CORRECTED = 0 +NVML_MEMORY_ERROR_TYPE_UNCORRECTED = 1 +NVML_MEMORY_ERROR_TYPE_COUNT = 2 + +_nvmlClockType_t = c_uint +NVML_CLOCK_GRAPHICS = 0 +NVML_CLOCK_SM = 1 +NVML_CLOCK_MEM = 2 +NVML_CLOCK_VIDEO = 3 +NVML_CLOCK_COUNT = 4 + +_nvmlClockId_t = c_uint +NVML_CLOCK_ID_CURRENT = 0 +NVML_CLOCK_ID_APP_CLOCK_TARGET = 1 +NVML_CLOCK_ID_APP_CLOCK_DEFAULT = 2 +NVML_CLOCK_ID_CUSTOMER_BOOST_MAX = 3 +NVML_CLOCK_ID_COUNT = 4 + +_nvmlDriverModel_t = c_uint +NVML_DRIVER_WDDM = 0 +NVML_DRIVER_WDM = 1 +NVML_DRIVER_MCDM = 2 + +NVML_MAX_GPU_PERF_PSTATES = 16 + 
+_nvmlPstates_t = c_uint +NVML_PSTATE_0 = 0 +NVML_PSTATE_1 = 1 +NVML_PSTATE_2 = 2 +NVML_PSTATE_3 = 3 +NVML_PSTATE_4 = 4 +NVML_PSTATE_5 = 5 +NVML_PSTATE_6 = 6 +NVML_PSTATE_7 = 7 +NVML_PSTATE_8 = 8 +NVML_PSTATE_9 = 9 +NVML_PSTATE_10 = 10 +NVML_PSTATE_11 = 11 +NVML_PSTATE_12 = 12 +NVML_PSTATE_13 = 13 +NVML_PSTATE_14 = 14 +NVML_PSTATE_15 = 15 +NVML_PSTATE_UNKNOWN = 32 + +_nvmlInforomObject_t = c_uint +NVML_INFOROM_OEM = 0 +NVML_INFOROM_ECC = 1 +NVML_INFOROM_POWER = 2 +NVML_INFOROM_DEN = 3 +NVML_INFOROM_COUNT = 4 + +_nvmlReturn_t = c_uint +NVML_SUCCESS = 0 +NVML_ERROR_UNINITIALIZED = 1 +NVML_ERROR_INVALID_ARGUMENT = 2 +NVML_ERROR_NOT_SUPPORTED = 3 +NVML_ERROR_NO_PERMISSION = 4 +NVML_ERROR_ALREADY_INITIALIZED = 5 +NVML_ERROR_NOT_FOUND = 6 +NVML_ERROR_INSUFFICIENT_SIZE = 7 +NVML_ERROR_INSUFFICIENT_POWER = 8 +NVML_ERROR_DRIVER_NOT_LOADED = 9 +NVML_ERROR_TIMEOUT = 10 +NVML_ERROR_IRQ_ISSUE = 11 +NVML_ERROR_LIBRARY_NOT_FOUND = 12 +NVML_ERROR_FUNCTION_NOT_FOUND = 13 +NVML_ERROR_CORRUPTED_INFOROM = 14 +NVML_ERROR_GPU_IS_LOST = 15 +NVML_ERROR_RESET_REQUIRED = 16 +NVML_ERROR_OPERATING_SYSTEM = 17 +NVML_ERROR_LIB_RM_VERSION_MISMATCH = 18 +NVML_ERROR_IN_USE = 19 +NVML_ERROR_MEMORY = 20 +NVML_ERROR_NO_DATA = 21 +NVML_ERROR_VGPU_ECC_NOT_SUPPORTED = 22 +NVML_ERROR_INSUFFICIENT_RESOURCES = 23 +NVML_ERROR_FREQ_NOT_SUPPORTED = 24 +NVML_ERROR_ARGUMENT_VERSION_MISMATCH = 25 +NVML_ERROR_DEPRECATED = 26 +NVML_ERROR_NOT_READY = 27 +NVML_ERROR_GPU_NOT_FOUND = 28 +NVML_ERROR_INVALID_STATE = 29 +NVML_ERROR_UNKNOWN = 999 + +_nvmlFanState_t = c_uint +NVML_FAN_NORMAL = 0 +NVML_FAN_FAILED = 1 + +_nvmlFanControlPolicy_t = c_uint +NVML_FAN_POLICY_TEMPERATURE_CONTINOUS_SW = 0 +NVML_FAN_POLICY_MANUAL = 1 + +_nvmlLedColor_t = c_uint +NVML_LED_COLOR_GREEN = 0 +NVML_LED_COLOR_AMBER = 1 + +_nvmlGpuOperationMode_t = c_uint +NVML_GOM_ALL_ON = 0 +NVML_GOM_COMPUTE = 1 +NVML_GOM_LOW_DP = 2 + +_nvmlPageRetirementCause_t = c_uint +NVML_PAGE_RETIREMENT_CAUSE_MULTIPLE_SINGLE_BIT_ECC_ERRORS = 0 
+NVML_PAGE_RETIREMENT_CAUSE_DOUBLE_BIT_ECC_ERROR = 1 +NVML_PAGE_RETIREMENT_CAUSE_COUNT = 2 + +_nvmlRestrictedAPI_t = c_uint +NVML_RESTRICTED_API_SET_APPLICATION_CLOCKS = 0 +NVML_RESTRICTED_API_SET_AUTO_BOOSTED_CLOCKS = 1 +NVML_RESTRICTED_API_COUNT = 2 + +_nvmlBridgeChipType_t = c_uint +NVML_BRIDGE_CHIP_PLX = 0 +NVML_BRIDGE_CHIP_BRO4 = 1 +NVML_MAX_PHYSICAL_BRIDGE = 128 + +_nvmlValueType_t = c_uint +NVML_VALUE_TYPE_DOUBLE = 0 +NVML_VALUE_TYPE_UNSIGNED_INT = 1 +NVML_VALUE_TYPE_UNSIGNED_LONG = 2 +NVML_VALUE_TYPE_UNSIGNED_LONG_LONG = 3 +NVML_VALUE_TYPE_SIGNED_LONG_LONG = 4 +NVML_VALUE_TYPE_SIGNED_INT = 5 +NVML_VALUE_TYPE_UNSIGNED_SHORT = 6 +NVML_VALUE_TYPE_COUNT = 7 + +_nvmlNvlinkVersion_t = c_uint +NVML_NVLINK_VERSION_INVALID = 0 +NVML_NVLINK_VERSION_1_0 = 1 +NVML_NVLINK_VERSION_2_0 = 2 +NVML_NVLINK_VERSION_2_2 = 3 +NVML_NVLINK_VERSION_3_0 = 4 +NVML_NVLINK_VERSION_3_1 = 5 +NVML_NVLINK_VERSION_4_0 = 6 +NVML_NVLINK_VERSION_5_0 = 7 + +_nvmlPerfPolicyType_t = c_uint +NVML_PERF_POLICY_POWER = 0 +NVML_PERF_POLICY_THERMAL = 1 +NVML_PERF_POLICY_SYNC_BOOST = 2 +NVML_PERF_POLICY_BOARD_LIMIT = 3 +NVML_PERF_POLICY_LOW_UTILIZATION = 4 +NVML_PERF_POLICY_RELIABILITY = 5 +NVML_PERF_POLICY_TOTAL_APP_CLOCKS = 10 +NVML_PERF_POLICY_TOTAL_BASE_CLOCKS = 11 +NVML_PERF_POLICY_COUNT = 12 + +_nvmlEncoderQueryType_t = c_uint +NVML_ENCODER_QUERY_H264 = 0 +NVML_ENCODER_QUERY_HEVC = 1 +NVML_ENCODER_QUERY_AV1 = 2 +NVML_ENCODER_QUERY_UNKNOWN = 255 + +_nvmlFBCSessionType_t = c_uint +NVML_FBC_SESSION_TYPE_UNKNOWN = 0 +NVML_FBC_SESSION_TYPE_TOSYS = 1 +NVML_FBC_SESSION_TYPE_CUDA = 2 +NVML_FBC_SESSION_TYPE_VID = 3 +NVML_FBC_SESSION_TYPE_HWENC = 4 + +_nvmlDetachGpuState_t = c_uint +NVML_DETACH_GPU_KEEP = 0 +NVML_DETACH_GPU_REMOVE = 1 + +_nvmlPcieLinkState_t = c_uint +NVML_PCIE_LINK_KEEP = 0 +NVML_PCIE_LINK_SHUT_DOWN = 1 + +_nvmlSamplingType_t = c_uint +NVML_TOTAL_POWER_SAMPLES = 0 +NVML_GPU_UTILIZATION_SAMPLES = 1 +NVML_MEMORY_UTILIZATION_SAMPLES = 2 +NVML_ENC_UTILIZATION_SAMPLES = 3 
+NVML_DEC_UTILIZATION_SAMPLES = 4 +NVML_PROCESSOR_CLK_SAMPLES = 5 +NVML_MEMORY_CLK_SAMPLES = 6 +NVML_MODULE_POWER_SAMPLES = 7 +NVML_JPG_UTILIZATION_SAMPLES = 8 +NVML_OFA_UTILIZATION_SAMPLES = 9 +NVML_SAMPLINGTYPE_COUNT = 10 + +_nvmlPcieUtilCounter_t = c_uint +NVML_PCIE_UTIL_TX_BYTES = 0 +NVML_PCIE_UTIL_RX_BYTES = 1 +NVML_PCIE_UTIL_COUNT = 2 + +_nvmlGpuTopologyLevel_t = c_uint +NVML_TOPOLOGY_INTERNAL = 0 +NVML_TOPOLOGY_SINGLE = 10 +NVML_TOPOLOGY_MULTIPLE = 20 +NVML_TOPOLOGY_HOSTBRIDGE = 30 +NVML_TOPOLOGY_NODE = 40 +NVML_TOPOLOGY_CPU = NVML_TOPOLOGY_NODE +NVML_TOPOLOGY_SYSTEM = 50 + +_nvmlGpuP2PCapsIndex_t = c_uint +NVML_P2P_CAPS_INDEX_READ = 0 # must be a bare int like the other indices, not a 1-tuple +NVML_P2P_CAPS_INDEX_WRITE = 1 +NVML_P2P_CAPS_INDEX_NVLINK =2 +NVML_P2P_CAPS_INDEX_ATOMICS = 3 +# +# NVML_P2P_CAPS_INDEX_PROP is deprecated. +# Use NVML_P2P_CAPS_INDEX_PCI instead. +# +NVML_P2P_CAPS_INDEX_PROP = 4 +NVML_P2P_CAPS_INDEX_PCI = 4 +NVML_P2P_CAPS_INDEX_UNKNOWN = 5 + +_nvmlGpuP2PStatus_t = c_uint +NVML_P2P_STATUS_OK = 0 +NVML_P2P_STATUS_CHIPSET_NOT_SUPPORED = 1 +NVML_P2P_STATUS_CHIPSET_NOT_SUPPORTED = NVML_P2P_STATUS_CHIPSET_NOT_SUPPORED +NVML_P2P_STATUS_GPU_NOT_SUPPORTED = 2 +NVML_P2P_STATUS_IOH_TOPOLOGY_NOT_SUPPORTED =3 +NVML_P2P_STATUS_DISABLED_BY_REGKEY =4 +NVML_P2P_STATUS_NOT_SUPPORTED =5 +NVML_P2P_STATUS_UNKNOWN =6 + +_nvmlDeviceArchitecture_t = c_uint +NVML_DEVICE_ARCH_KEPLER = 2 +NVML_DEVICE_ARCH_MAXWELL = 3 +NVML_DEVICE_ARCH_PASCAL = 4 +NVML_DEVICE_ARCH_VOLTA = 5 +NVML_DEVICE_ARCH_TURING = 6 +NVML_DEVICE_ARCH_AMPERE = 7 +NVML_DEVICE_ARCH_ADA = 8 +NVML_DEVICE_ARCH_HOPPER = 9 +NVML_DEVICE_ARCH_BLACKWELL = 10 +NVML_DEVICE_ARCH_T23X = 11 +NVML_DEVICE_ARCH_UNKNOWN = 0xffffffff + +# PCI bus Types +_nvmlBusType_t = c_uint +NVML_BUS_TYPE_UNKNOWN = 0 +NVML_BUS_TYPE_PCI = 1 +NVML_BUS_TYPE_PCIE = 2 +NVML_BUS_TYPE_FPCI = 3 +NVML_BUS_TYPE_AGP = 4 + +_nvmlPowerSource_t = c_uint +NVML_POWER_SOURCE_AC = 0x00000000 +NVML_POWER_SOURCE_BATTERY = 0x00000001 +NVML_POWER_SOURCE_UNDERSIZED = 0x00000002 + 
+_nvmlAdaptiveClockInfoStatus_t = c_uint +NVML_ADAPTIVE_CLOCKING_INFO_STATUS_DISABLED = 0x00000000 +NVML_ADAPTIVE_CLOCKING_INFO_STATUS_ENABLED = 0x00000001 + +_nvmlClockLimitId_t = c_uint +NVML_CLOCK_LIMIT_ID_RANGE_START = 0xffffff00 +NVML_CLOCK_LIMIT_ID_TDP = 0xffffff01 +NVML_CLOCK_LIMIT_ID_UNLIMITED = 0xffffff02 + +_nvmlPcieLinkMaxSpeed_t = c_uint +NVML_PCIE_LINK_MAX_SPEED_INVALID = 0x00000000 +NVML_PCIE_LINK_MAX_SPEED_2500MBPS = 0x00000001 +NVML_PCIE_LINK_MAX_SPEED_5000MBPS = 0x00000002 +NVML_PCIE_LINK_MAX_SPEED_8000MBPS = 0x00000003 +NVML_PCIE_LINK_MAX_SPEED_16000MBPS = 0x00000004 +NVML_PCIE_LINK_MAX_SPEED_32000MBPS = 0x00000005 +NVML_PCIE_LINK_MAX_SPEED_64000MBPS = 0x00000006 + +_nvmlPcieAtomicsCapability_t = c_uint +NVML_PCIE_ATOMICS_CAP_FETCHADD32 = 0x01 +NVML_PCIE_ATOMICS_CAP_FETCHADD64 = 0x02 +NVML_PCIE_ATOMICS_CAP_SWAP32 = 0x04 +NVML_PCIE_ATOMICS_CAP_SWAP64 = 0x08 +NVML_PCIE_ATOMICS_CAP_CAS32 = 0x10 +NVML_PCIE_ATOMICS_CAP_CAS64 = 0x20 +NVML_PCIE_ATOMICS_CAP_CAS128 = 0x40 +NVML_PCIE_ATOMICS_OPS_MAX = 7 + +_nvmlAffinityScope_t = c_uint +NVML_AFFINITY_SCOPE_NODE = 0 +NVML_AFFINITY_SCOPE_SOCKET = 1 + +_nvmlDeviceGpuRecoveryAction_t = c_uint +NVML_GPU_RECOVERY_ACTION_NONE = 0 +NVML_GPU_RECOVERY_ACTION_GPU_RESET = 1 +NVML_GPU_RECOVERY_ACTION_NODE_REBOOT = 2 +NVML_GPU_RECOVERY_ACTION_DRAIN_P2P = 3 +NVML_GPU_RECOVERY_ACTION_DRAIN_AND_RESET = 4 + +# C preprocessor defined values +nvmlFlagDefault = 0 +nvmlFlagForce = 1 +NVML_INIT_FLAG_NO_GPUS = 1 +NVML_INIT_FLAG_NO_ATTACH = 2 + +NVML_MAX_GPC_COUNT = 32 + +# buffer size +NVML_DEVICE_INFOROM_VERSION_BUFFER_SIZE = 16 +NVML_DEVICE_UUID_BUFFER_SIZE = 80 +NVML_DEVICE_UUID_V2_BUFFER_SIZE = 96 +NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE = 80 +NVML_SYSTEM_NVML_VERSION_BUFFER_SIZE = 80 +NVML_DEVICE_NAME_BUFFER_SIZE = 64 +NVML_DEVICE_NAME_V2_BUFFER_SIZE = 96 +NVML_DEVICE_SERIAL_BUFFER_SIZE = 30 +NVML_DEVICE_PART_NUMBER_BUFFER_SIZE = 80 +NVML_DEVICE_GPU_PART_NUMBER_BUFFER_SIZE = 80 +NVML_DEVICE_VBIOS_VERSION_BUFFER_SIZE = 32 
+NVML_DEVICE_PCI_BUS_ID_BUFFER_SIZE = 32 +NVML_DEVICE_PCI_BUS_ID_BUFFER_V2_SIZE = 16 +NVML_GRID_LICENSE_BUFFER_SIZE = 128 +NVML_VGPU_NAME_BUFFER_SIZE = 64 +NVML_GRID_LICENSE_FEATURE_MAX_COUNT = 3 +NVML_VGPU_METADATA_OPAQUE_DATA_SIZE = sizeof(c_uint) + 256 +NVML_VGPU_PGPU_METADATA_OPAQUE_DATA_SIZE = 256 +NVML_DEVICE_GPU_FRU_PART_NUMBER_BUFFER_SIZE = 0x14 # NV2080_GPU_MAX_PRODUCT_PART_NUMBER_LENGTH +NVML_PERF_MODES_BUFFER_SIZE = 2048 + +# Format strings +NVML_DEVICE_PCI_BUS_ID_LEGACY_FMT = "%04X:%02X:%02X.0" +NVML_DEVICE_PCI_BUS_ID_FMT = "%08X:%02X:%02X.0" + +NVML_VALUE_NOT_AVAILABLE_ulonglong = c_ulonglong(-1) +NVML_VALUE_NOT_AVAILABLE_uint = c_uint(-1) + +''' + Field Identifiers. + + All Identifiers pertain to a device. Each ID is only used once and is guaranteed never to change. +''' +NVML_FI_DEV_ECC_CURRENT = 1 # Current ECC mode. 1=Active. 0=Inactive +NVML_FI_DEV_ECC_PENDING = 2 # Pending ECC mode. 1=Active. 0=Inactive + +#ECC Count Totals +NVML_FI_DEV_ECC_SBE_VOL_TOTAL = 3 # Total single bit volatile ECC errors +NVML_FI_DEV_ECC_DBE_VOL_TOTAL = 4 # Total double bit volatile ECC errors +NVML_FI_DEV_ECC_SBE_AGG_TOTAL = 5 # Total single bit aggregate (persistent) ECC errors +NVML_FI_DEV_ECC_DBE_AGG_TOTAL = 6 # Total double bit aggregate (persistent) ECC errors +#Individual ECC locations +NVML_FI_DEV_ECC_SBE_VOL_L1 = 7 # L1 cache single bit volatile ECC errors +NVML_FI_DEV_ECC_DBE_VOL_L1 = 8 # L1 cache double bit volatile ECC errors +NVML_FI_DEV_ECC_SBE_VOL_L2 = 9 # L2 cache single bit volatile ECC errors +NVML_FI_DEV_ECC_DBE_VOL_L2 = 10 # L2 cache double bit volatile ECC errors +NVML_FI_DEV_ECC_SBE_VOL_DEV = 11 # Device memory single bit volatile ECC errors +NVML_FI_DEV_ECC_DBE_VOL_DEV = 12 # Device memory double bit volatile ECC errors +NVML_FI_DEV_ECC_SBE_VOL_REG = 13 # Register file single bit volatile ECC errors +NVML_FI_DEV_ECC_DBE_VOL_REG = 14 # Register file double bit volatile ECC errors +NVML_FI_DEV_ECC_SBE_VOL_TEX = 15 # Texture memory single bit volatile 
ECC errors +NVML_FI_DEV_ECC_DBE_VOL_TEX = 16 # Texture memory double bit volatile ECC errors +NVML_FI_DEV_ECC_DBE_VOL_CBU = 17 # CBU double bit volatile ECC errors +NVML_FI_DEV_ECC_SBE_AGG_L1 = 18 # L1 cache single bit aggregate (persistent) ECC errors +NVML_FI_DEV_ECC_DBE_AGG_L1 = 19 # L1 cache double bit aggregate (persistent) ECC errors +NVML_FI_DEV_ECC_SBE_AGG_L2 = 20 # L2 cache single bit aggregate (persistent) ECC errors +NVML_FI_DEV_ECC_DBE_AGG_L2 = 21 # L2 cache double bit aggregate (persistent) ECC errors +NVML_FI_DEV_ECC_SBE_AGG_DEV = 22 # Device memory single bit aggregate (persistent) ECC errors +NVML_FI_DEV_ECC_DBE_AGG_DEV = 23 # Device memory double bit aggregate (persistent) ECC errors +NVML_FI_DEV_ECC_SBE_AGG_REG = 24 # Register File single bit aggregate (persistent) ECC errors +NVML_FI_DEV_ECC_DBE_AGG_REG = 25 # Register File double bit aggregate (persistent) ECC errors +NVML_FI_DEV_ECC_SBE_AGG_TEX = 26 # Texture memory single bit aggregate (persistent) ECC errors +NVML_FI_DEV_ECC_DBE_AGG_TEX = 27 # Texture memory double bit aggregate (persistent) ECC errors +NVML_FI_DEV_ECC_DBE_AGG_CBU = 28 # CBU double bit aggregate ECC errors + +# Page Retirement +NVML_FI_DEV_RETIRED_SBE = 29 # Number of retired pages because of single bit errors +NVML_FI_DEV_RETIRED_DBE = 30 # Number of retired pages because of double bit errors +NVML_FI_DEV_RETIRED_PENDING = 31 # If any pages are pending retirement. 1=yes. 0=no. 
+ +# NvLink Flit Error Counters +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L0 = 32 # NVLink flow control CRC Error Counter for Lane 0 +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L1 = 33 # NVLink flow control CRC Error Counter for Lane 1 +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L2 = 34 # NVLink flow control CRC Error Counter for Lane 2 +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L3 = 35 # NVLink flow control CRC Error Counter for Lane 3 +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L4 = 36 # NVLink flow control CRC Error Counter for Lane 4 +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L5 = 37 # NVLink flow control CRC Error Counter for Lane 5 +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_TOTAL = 38 # NVLink flow control CRC Error Counter total for all Lanes + +# NvLink CRC Data Error Counters +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L0 = 39 # NVLink data CRC Error Counter for Lane 0 +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L1 = 40 # NVLink data CRC Error Counter for Lane 1 +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L2 = 41 # NVLink data CRC Error Counter for Lane 2 +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L3 = 42 # NVLink data CRC Error Counter for Lane 3 +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L4 = 43 # NVLink data CRC Error Counter for Lane 4 +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L5 = 44 # NVLink data CRC Error Counter for Lane 5 +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_TOTAL = 45 # NvLink data CRC Error Counter total for all Lanes + +# NvLink Replay Error Counters +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L0 = 46 # NVLink Replay Error Counter for Lane 0 +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L1 = 47 # NVLink Replay Error Counter for Lane 1 +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L2 = 48 # NVLink Replay Error Counter for Lane 2 +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L3 = 49 # NVLink Replay Error Counter for Lane 3 +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L4 = 50 # NVLink Replay Error Counter for Lane 4 +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L5 = 51 # NVLink Replay Error Counter for Lane 5 
+NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_TOTAL = 52 # NVLink Replay Error Counter total for all Lanes + +# NvLink Recovery Error Counters +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L0 = 53 # NVLink Recovery Error Counter for Lane 0 +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L1 = 54 # NVLink Recovery Error Counter for Lane 1 +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L2 = 55 # NVLink Recovery Error Counter for Lane 2 +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L3 = 56 # NVLink Recovery Error Counter for Lane 3 +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L4 = 57 # NVLink Recovery Error Counter for Lane 4 +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L5 = 58 # NVLink Recovery Error Counter for Lane 5 +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_TOTAL = 59 # NVLink Recovery Error Counter total for all Lanes + +# NvLink Bandwidth Counters +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_L0 = 60 # NVLink Bandwidth Counter for Counter Set 0, Lane 0 +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_L1 = 61 # NVLink Bandwidth Counter for Counter Set 0, Lane 1 +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_L2 = 62 # NVLink Bandwidth Counter for Counter Set 0, Lane 2 +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_L3 = 63 # NVLink Bandwidth Counter for Counter Set 0, Lane 3 +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_L4 = 64 # NVLink Bandwidth Counter for Counter Set 0, Lane 4 +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_L5 = 65 # NVLink Bandwidth Counter for Counter Set 0, Lane 5 +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_TOTAL = 66 # NVLink Bandwidth Counter Total for Counter Set 0, All Lanes + +# NvLink Bandwidth Counters +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_L0 = 67 # NVLink Bandwidth Counter for Counter Set 1, Lane 0 +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_L1 = 68 # NVLink Bandwidth Counter for Counter Set 1, Lane 1 +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_L2 = 69 # NVLink Bandwidth Counter for Counter Set 1, Lane 2 +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_L3 = 70 # NVLink Bandwidth Counter for Counter Set 1, Lane 3 +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_L4 = 71 # NVLink Bandwidth Counter for Counter Set 1, Lane 4 
+NVML_FI_DEV_NVLINK_BANDWIDTH_C1_L5 = 72 # NVLink Bandwidth Counter for Counter Set 1, Lane 5 +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_TOTAL = 73 # NVLink Bandwidth Counter Total for Counter Set 1, All Lanes + +# Perf Policy Counters +NVML_FI_DEV_PERF_POLICY_POWER = 74 # Perf Policy Counter for Power Policy +NVML_FI_DEV_PERF_POLICY_THERMAL = 75 # Perf Policy Counter for Thermal Policy +NVML_FI_DEV_PERF_POLICY_SYNC_BOOST = 76 # Perf Policy Counter for Sync boost Policy +NVML_FI_DEV_PERF_POLICY_BOARD_LIMIT = 77 # Perf Policy Counter for Board Limit +NVML_FI_DEV_PERF_POLICY_LOW_UTILIZATION = 78 # Perf Policy Counter for Low GPU Utilization Policy +NVML_FI_DEV_PERF_POLICY_RELIABILITY = 79 # Perf Policy Counter for Reliability Policy +NVML_FI_DEV_PERF_POLICY_TOTAL_APP_CLOCKS = 80 # Perf Policy Counter for Total App Clock Policy +NVML_FI_DEV_PERF_POLICY_TOTAL_BASE_CLOCKS = 81 # Perf Policy Counter for Total Base Clocks Policy + +# Memory temperatures +NVML_FI_DEV_MEMORY_TEMP = 82 # Memory temperature for the device + +# Energy Counter +NVML_FI_DEV_TOTAL_ENERGY_CONSUMPTION = 83 # Total energy consumption for the GPU in mJ since the driver was last reloaded + +# NVLink Speed +NVML_FI_DEV_NVLINK_SPEED_MBPS_L0 = 84 +NVML_FI_DEV_NVLINK_SPEED_MBPS_L1 = 85 +NVML_FI_DEV_NVLINK_SPEED_MBPS_L2 = 86 +NVML_FI_DEV_NVLINK_SPEED_MBPS_L3 = 87 +NVML_FI_DEV_NVLINK_SPEED_MBPS_L4 = 88 +NVML_FI_DEV_NVLINK_SPEED_MBPS_L5 = 89 +NVML_FI_DEV_NVLINK_SPEED_MBPS_COMMON = 90 + +# NVLink Link Count +NVML_FI_DEV_NVLINK_LINK_COUNT = 91 + +# Page Retirement pending fields +NVML_FI_DEV_RETIRED_PENDING_SBE = 92 +NVML_FI_DEV_RETIRED_PENDING_DBE = 93 + +# PCIe replay and replay rollover counters +NVML_FI_DEV_PCIE_REPLAY_COUNTER = 94 +NVML_FI_DEV_PCIE_REPLAY_ROLLOVER_COUNTER = 95 + +# NvLink Flit Error Counters +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L6 = 96 # NVLink flow control CRC Error Counter for Lane 6 +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L7 = 97 # NVLink flow control CRC Error Counter for Lane 7 
+NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L8 = 98 # NVLink flow control CRC Error Counter for Lane 8 +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L9 = 99 # NVLink flow control CRC Error Counter for Lane 9 +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L10 = 100 # NVLink flow control CRC Error Counter for Lane 10 +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L11 = 101 # NVLink flow control CRC Error Counter for Lane 11 + +# NvLink CRC Data Error Counters +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L6 = 102 # NVLink data CRC Error Counter for Lane 6 +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L7 = 103 # NVLink data CRC Error Counter for Lane 7 +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L8 = 104 # NVLink data CRC Error Counter for Lane 8 +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L9 = 105 # NVLink data CRC Error Counter for Lane 9 +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L10 = 106 # NVLink data CRC Error Counter for Lane 10 +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L11 = 107 # NVLink data CRC Error Counter for Lane 11 + +# NvLink Replay Error Counters +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L6 = 108 # NVLink Replay Error Counter for Lane 6 +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L7 = 109 # NVLink Replay Error Counter for Lane 7 +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L8 = 110 # NVLink Replay Error Counter for Lane 8 +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L9 = 111 # NVLink Replay Error Counter for Lane 9 +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L10 = 112 # NVLink Replay Error Counter for Lane 10 +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L11 = 113 # NVLink Replay Error Counter for Lane 11 + +# NvLink Recovery Error Counters +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L6 = 114 # NVLink Recovery Error Counter for Lane 6 +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L7 = 115 # NVLink Recovery Error Counter for Lane 7 +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L8 = 116 # NVLink Recovery Error Counter for Lane 8 +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L9 = 117 # NVLink Recovery Error Counter for Lane 9 
+NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L10 = 118 # NVLink Recovery Error Counter for Lane 10 +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L11 = 119 # NVLink Recovery Error Counter for Lane 11 + +# NvLink Bandwidth Counters +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_L6 = 120 # NVLink Bandwidth Counter for Counter Set 0, Lane 6 +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_L7 = 121 # NVLink Bandwidth Counter for Counter Set 0, Lane 7 +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_L8 = 122 # NVLink Bandwidth Counter for Counter Set 0, Lane 8 +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_L9 = 123 # NVLink Bandwidth Counter for Counter Set 0, Lane 9 +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_L10 = 124 # NVLink Bandwidth Counter for Counter Set 0, Lane 10 +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_L11 = 125 # NVLink Bandwidth Counter for Counter Set 0, Lane 11 + +# NvLink Bandwidth Counters +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_L6 = 126 # NVLink Bandwidth Counter for Counter Set 1, Lane 6 +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_L7 = 127 # NVLink Bandwidth Counter for Counter Set 1, Lane 7 +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_L8 = 128 # NVLink Bandwidth Counter for Counter Set 1, Lane 8 +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_L9 = 129 # NVLink Bandwidth Counter for Counter Set 1, Lane 9 +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_L10 = 130 # NVLink Bandwidth Counter for Counter Set 1, Lane 10 +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_L11 = 131 # NVLink Bandwidth Counter for Counter Set 1, Lane 11 + +# NVLink Speed +NVML_FI_DEV_NVLINK_SPEED_MBPS_L6 = 132 +NVML_FI_DEV_NVLINK_SPEED_MBPS_L7 = 133 +NVML_FI_DEV_NVLINK_SPEED_MBPS_L8 = 134 +NVML_FI_DEV_NVLINK_SPEED_MBPS_L9 = 135 +NVML_FI_DEV_NVLINK_SPEED_MBPS_L10 = 136 +NVML_FI_DEV_NVLINK_SPEED_MBPS_L11 = 137 + +# NVLink Throughput Counters +NVML_FI_DEV_NVLINK_THROUGHPUT_DATA_TX = 138 # NVLink TX Data throughput in KiB +NVML_FI_DEV_NVLINK_THROUGHPUT_DATA_RX = 139 # NVLink RX Data throughput in KiB +NVML_FI_DEV_NVLINK_THROUGHPUT_RAW_TX = 140 # NVLink TX Data + protocol overhead in KiB +NVML_FI_DEV_NVLINK_THROUGHPUT_RAW_RX = 141 # NVLink RX 
Data + protocol overhead in KiB + +# Row Remapper +NVML_FI_DEV_REMAPPED_COR = 142 +NVML_FI_DEV_REMAPPED_UNC = 143 +NVML_FI_DEV_REMAPPED_PENDING = 144 +NVML_FI_DEV_REMAPPED_FAILURE = 145 + +#Remote device NVLink ID +NVML_FI_DEV_NVLINK_REMOTE_NVLINK_ID = 146 + +# Number of NVLinks connected to NVSwitch +NVML_FI_DEV_NVSWITCH_CONNECTED_LINK_COUNT = 147 + +# NvLink ECC Data Error Counters +NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L0 = 148 #< NVLink data ECC Error Counter for Link 0 +NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L1 = 149 #< NVLink data ECC Error Counter for Link 1 +NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L2 = 150 #< NVLink data ECC Error Counter for Link 2 +NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L3 = 151 #< NVLink data ECC Error Counter for Link 3 +NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L4 = 152 #< NVLink data ECC Error Counter for Link 4 +NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L5 = 153 #< NVLink data ECC Error Counter for Link 5 +NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L6 = 154 #< NVLink data ECC Error Counter for Link 6 +NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L7 = 155 #< NVLink data ECC Error Counter for Link 7 +NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L8 = 156 #< NVLink data ECC Error Counter for Link 8 +NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L9 = 157 #< NVLink data ECC Error Counter for Link 9 +NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L10 = 158 #< NVLink data ECC Error Counter for Link 10 +NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L11 = 159 #< NVLink data ECC Error Counter for Link 11 +NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_TOTAL = 160 #< NvLink data ECC Error Counter total for all Links + +NVML_FI_DEV_NVLINK_ERROR_DL_REPLAY = 161 +NVML_FI_DEV_NVLINK_ERROR_DL_RECOVERY = 162 +NVML_FI_DEV_NVLINK_ERROR_DL_CRC = 163 +NVML_FI_DEV_NVLINK_GET_SPEED = 164 +NVML_FI_DEV_NVLINK_GET_STATE = 165 +NVML_FI_DEV_NVLINK_GET_VERSION = 166 + +NVML_FI_DEV_NVLINK_GET_POWER_STATE = 167 +NVML_FI_DEV_NVLINK_GET_POWER_THRESHOLD = 168 + +NVML_FI_DEV_PCIE_L0_TO_RECOVERY_COUNTER = 
169 + +NVML_FI_DEV_C2C_LINK_COUNT = 170 +NVML_FI_DEV_C2C_LINK_GET_STATUS = 171 +NVML_FI_DEV_C2C_LINK_GET_MAX_BW = 172 + +NVML_FI_DEV_PCIE_COUNT_CORRECTABLE_ERRORS = 173 +NVML_FI_DEV_PCIE_COUNT_NAKS_RECEIVED = 174 +NVML_FI_DEV_PCIE_COUNT_RECEIVER_ERROR = 175 +NVML_FI_DEV_PCIE_COUNT_BAD_TLP = 176 +NVML_FI_DEV_PCIE_COUNT_NAKS_SENT = 177 +NVML_FI_DEV_PCIE_COUNT_BAD_DLLP = 178 +NVML_FI_DEV_PCIE_COUNT_NON_FATAL_ERROR = 179 +NVML_FI_DEV_PCIE_COUNT_FATAL_ERROR = 180 +NVML_FI_DEV_PCIE_COUNT_UNSUPPORTED_REQ = 181 +NVML_FI_DEV_PCIE_COUNT_LCRC_ERROR = 182 +NVML_FI_DEV_PCIE_COUNT_LANE_ERROR = 183 + +NVML_FI_DEV_IS_RESETLESS_MIG_SUPPORTED = 184 + +NVML_FI_DEV_POWER_AVERAGE = 185 +NVML_FI_DEV_POWER_INSTANT = 186 +NVML_FI_DEV_POWER_MIN_LIMIT = 187 +NVML_FI_DEV_POWER_MAX_LIMIT = 188 +NVML_FI_DEV_POWER_DEFAULT_LIMIT = 189 +NVML_FI_DEV_POWER_CURRENT_LIMIT = 190 +NVML_FI_DEV_ENERGY = 191 +NVML_FI_DEV_POWER_REQUESTED_LIMIT = 192 + +NVML_FI_DEV_TEMPERATURE_SHUTDOWN_TLIMIT = 193 +NVML_FI_DEV_TEMPERATURE_SLOWDOWN_TLIMIT = 194 +NVML_FI_DEV_TEMPERATURE_MEM_MAX_TLIMIT = 195 +NVML_FI_DEV_TEMPERATURE_GPU_MAX_TLIMIT = 196 + +NVML_FI_DEV_PCIE_COUNT_TX_BYTES = 197 +NVML_FI_DEV_PCIE_COUNT_RX_BYTES = 198 + +NVML_FI_DEV_IS_MIG_MODE_INDEPENDENT_MIG_QUERY_CAPABLE = 199 + +NVML_FI_DEV_NVLINK_GET_POWER_THRESHOLD_MAX = 200 + +NVML_FI_DEV_NVLINK_COUNT_XMIT_PACKETS = 201 +NVML_FI_DEV_NVLINK_COUNT_XMIT_BYTES = 202 +NVML_FI_DEV_NVLINK_COUNT_RCV_PACKETS = 203 +NVML_FI_DEV_NVLINK_COUNT_RCV_BYTES = 204 +NVML_FI_DEV_NVLINK_COUNT_VL15_DROPPED = 205 # Deprecated, do not use +NVML_FI_DEV_NVLINK_COUNT_MALFORMED_PACKET_ERRORS = 206 +NVML_FI_DEV_NVLINK_COUNT_BUFFER_OVERRUN_ERRORS = 207 +NVML_FI_DEV_NVLINK_COUNT_RCV_ERRORS = 208 +NVML_FI_DEV_NVLINK_COUNT_RCV_REMOTE_ERRORS = 209 +NVML_FI_DEV_NVLINK_COUNT_RCV_GENERAL_ERRORS = 210 +NVML_FI_DEV_NVLINK_COUNT_LOCAL_LINK_INTEGRITY_ERRORS = 211 +NVML_FI_DEV_NVLINK_COUNT_XMIT_DISCARDS = 212 + +NVML_FI_DEV_NVLINK_COUNT_LINK_RECOVERY_SUCCESSFUL_EVENTS = 213 
# NVLink link-recovery event counters (field IDs continue the NVML_FI_* table).
NVML_FI_DEV_NVLINK_COUNT_LINK_RECOVERY_FAILED_EVENTS = 214
NVML_FI_DEV_NVLINK_COUNT_LINK_RECOVERY_EVENTS = 215

# NVLink bit-error-rate / symbol-error counters.
NVML_FI_DEV_NVLINK_COUNT_RAW_BER_LANE0 = 216  # Deprecated, do not use
NVML_FI_DEV_NVLINK_COUNT_RAW_BER_LANE1 = 217  # Deprecated, do not use
NVML_FI_DEV_NVLINK_COUNT_RAW_BER = 218  # Deprecated, do not use
NVML_FI_DEV_NVLINK_COUNT_EFFECTIVE_ERRORS = 219
NVML_FI_DEV_NVLINK_COUNT_EFFECTIVE_BER = 220
NVML_FI_DEV_NVLINK_COUNT_SYMBOL_ERRORS = 221
NVML_FI_DEV_NVLINK_COUNT_SYMBOL_BER = 222

NVML_FI_DEV_NVLINK_GET_POWER_THRESHOLD_MIN = 223
NVML_FI_DEV_NVLINK_GET_POWER_THRESHOLD_UNITS = 224  # Values are in the form NVML_NVLINK_LOW_POWER_THRESHOLD_UNIT_*
NVML_FI_DEV_NVLINK_GET_POWER_THRESHOLD_SUPPORTED = 225

NVML_FI_DEV_RESET_STATUS = 226  # Deprecated use NVML_FI_DEV_GET_GPU_RECOVERY_ACTION instead
NVML_FI_DEV_DRAIN_AND_RESET_STATUS = 227  # Deprecated use NVML_FI_DEV_GET_GPU_RECOVERY_ACTION instead
NVML_FI_DEV_PCIE_OUTBOUND_ATOMICS_MASK = 228
NVML_FI_DEV_PCIE_INBOUND_ATOMICS_MASK = 229
NVML_FI_DEV_GET_GPU_RECOVERY_ACTION = 230

# NOTE: field IDs 231-234 are not defined in this table; the numbering
# resumes at 235 with the 16 NVLink FEC history counters.
NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_0 = 235
NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_1 = 236
NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_2 = 237
NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_3 = 238
NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_4 = 239
NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_5 = 240
NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_6 = 241
NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_7 = 242
NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_8 = 243
NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_9 = 244
NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_10 = 245
NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_11 = 246
NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_12 = 247
NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_13 = 248
NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_14 = 249
NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_15 = 250

# Power-smoothing field IDs (note the NVML_FI_PWR_ prefix, not NVML_FI_DEV_).
NVML_FI_PWR_SMOOTHING_ENABLED = 251  # Enablement (0/DISABLED or 1/ENABLED)
NVML_FI_PWR_SMOOTHING_PRIV_LVL = 252  # Current privilege level
NVML_FI_PWR_SMOOTHING_IMM_RAMP_DOWN_ENABLED = 253  # Immediate ramp down enablement (0/DISABLED or 1/ENABLED)
NVML_FI_PWR_SMOOTHING_APPLIED_TMP_CEIL = 254  # Applied TMP ceiling value
NVML_FI_PWR_SMOOTHING_APPLIED_TMP_FLOOR = 255  # Applied TMP floor value
NVML_FI_PWR_SMOOTHING_MAX_PERCENT_TMP_FLOOR_SETTING = 256  # Max % TMP Floor value
NVML_FI_PWR_SMOOTHING_MIN_PERCENT_TMP_FLOOR_SETTING = 257  # Min % TMP Floor value
NVML_FI_PWR_SMOOTHING_HW_CIRCUITRY_PERCENT_LIFETIME_REMAINING = 258  # HW Circuitry % lifetime remaining
NVML_FI_PWR_SMOOTHING_MAX_NUM_PRESET_PROFILES = 259  # Max number of preset profiles
NVML_FI_PWR_SMOOTHING_PROFILE_PERCENT_TMP_FLOOR = 260  # % TMP floor for a given profile
NVML_FI_PWR_SMOOTHING_PROFILE_RAMP_UP_RATE = 261  # Ramp up rate in mW/s for a given profile
NVML_FI_PWR_SMOOTHING_PROFILE_RAMP_DOWN_RATE = 262  # Ramp down rate in mW/s for a given profile
NVML_FI_PWR_SMOOTHING_PROFILE_RAMP_DOWN_HYST_VAL = 263  # Ramp down hysteresis value in ms for a given profile
NVML_FI_PWR_SMOOTHING_ACTIVE_PRESET_PROFILE = 264  # Active preset profile number
# NOTE(review): the four ADMIN_OVERRIDE comments below repeat the per-profile
# wording of IDs 260-263; presumably they describe the admin-override values
# instead -- confirm against the NVML header.
NVML_FI_PWR_SMOOTHING_ADMIN_OVERRIDE_PERCENT_TMP_FLOOR = 265  # % TMP floor for a given profile
NVML_FI_PWR_SMOOTHING_ADMIN_OVERRIDE_RAMP_UP_RATE = 266  # Ramp up rate in mW/s for a given profile
NVML_FI_PWR_SMOOTHING_ADMIN_OVERRIDE_RAMP_DOWN_RATE = 267  # Ramp down rate in mW/s for a given profile
NVML_FI_PWR_SMOOTHING_ADMIN_OVERRIDE_RAMP_DOWN_HYST_VAL = 268  # Ramp down hysteresis value in ms for a given profile

# Keep in sync when adding field IDs: must stay at (largest ID + 1).
NVML_FI_MAX = 269  # One greater than the largest field ID defined above

# NVML_FI_DEV_NVLINK_GET_STATE state enums
NVML_NVLINK_STATE_INACTIVE = 0x0
NVML_NVLINK_STATE_ACTIVE = 0x1
NVML_NVLINK_STATE_SLEEP = 0x2

# Unit values reported for NVML_FI_DEV_NVLINK_GET_POWER_THRESHOLD_UNITS.
NVML_NVLINK_LOW_POWER_THRESHOLD_UNIT_100US = 0  # NVML_FI_DEV_NVLINK_GET_POWER_THRESHOLD_UNITS
NVML_NVLINK_LOW_POWER_THRESHOLD_UNIT_50US = 1  # NVML_FI_DEV_NVLINK_GET_POWER_THRESHOLD_UNITS

## Enums needed for the method nvmlDeviceGetVirtualizationMode and nvmlDeviceSetVirtualizationMode
NVML_GPU_VIRTUALIZATION_MODE_NONE = 0  # Represents Bare Metal GPU
NVML_GPU_VIRTUALIZATION_MODE_PASSTHROUGH = 1  # Device is associated with GPU-Passthrough
NVML_GPU_VIRTUALIZATION_MODE_VGPU = 2  # Device is associated with vGPU inside virtual machine.
NVML_GPU_VIRTUALIZATION_MODE_HOST_VGPU = 3  # Device is associated with VGX hypervisor in vGPU mode
NVML_GPU_VIRTUALIZATION_MODE_HOST_VSGA = 4  # Device is associated with VGX hypervisor in vSGA mode

## Lib loading ##
# Module-level state for the loaded NVML shared library.
nvmlLib = None  # handle to the loaded library; None until it has been loaded
libLoadLock = threading.Lock()  # guards access to nvmlLib
_nvmlLib_refcount = 0  # Incremented on each nvmlInit and decremented on nvmlShutdown

## vGPU Management
# ctypes aliases for the C typedefs used by the vGPU API.
_nvmlVgpuTypeId_t = c_uint
_nvmlVgpuInstance_t = c_uint

_nvmlVgpuVmIdType_t = c_uint
NVML_VGPU_VM_ID_DOMAIN_ID = 0
NVML_VGPU_VM_ID_UUID = 1

_nvmlGridLicenseFeatureCode_t = c_uint
NVML_GRID_LICENSE_FEATURE_CODE_UNKNOWN = 0
NVML_GRID_LICENSE_FEATURE_CODE_VGPU = 1
NVML_GRID_LICENSE_FEATURE_CODE_NVIDIA_RTX = 2
# Same numeric value as NVIDIA_RTX above; kept as a backward-compatible alias.
NVML_GRID_LICENSE_FEATURE_CODE_VWORKSTATION = 2  # deprecated, use NVML_GRID_LICENSE_FEATURE_CODE_NVIDIA_RTX.
NVML_GRID_LICENSE_FEATURE_CODE_GAMING = 3
NVML_GRID_LICENSE_FEATURE_CODE_COMPUTE = 4

# Status values for the `status` member of the license-expiry structures.
# These must be plain ints: the previous definitions carried a trailing comma
# (copied from the C enum), which made every constant a 1-tuple such as (0,)
# and caused comparisons against the integer status field to always be False.
_nvmlGridLicenseExpiryStatus_t = c_uint8
NVML_GRID_LICENSE_EXPIRY_NOT_AVAILABLE = 0  # Expiry information not available
NVML_GRID_LICENSE_EXPIRY_INVALID = 1  # Invalid expiry or error fetching expiry
NVML_GRID_LICENSE_EXPIRY_VALID = 2  # Valid expiry
NVML_GRID_LICENSE_EXPIRY_NOT_APPLICABLE = 3  # Expiry not applicable
NVML_GRID_LICENSE_EXPIRY_PERMANENT = 4  # Permanent expiry

# Per-vGPU-type capability identifiers.
_nvmlVgpuCapability_t = c_uint
NVML_VGPU_CAP_NVLINK_P2P = 0  # vGPU P2P over NVLink is supported
NVML_VGPU_CAP_GPUDIRECT = 1  # GPUDirect capability is supported
NVML_VGPU_CAP_MULTI_VGPU_EXCLUSIVE = 2  # vGPU profile cannot be mixed with other vGPU profiles in same VM
NVML_VGPU_CAP_EXCLUSIVE_TYPE = 3  # vGPU profile cannot run on a GPU alongside other profiles of different type
NVML_VGPU_CAP_EXCLUSIVE_SIZE = 4  # vGPU profile cannot run on a GPU alongside other profiles of different size
NVML_VGPU_CAP_COUNT = 5  # number of capability identifiers above

# vGPU host-driver capability identifiers.
_nvmlVgpuDriverCapability_t = c_uint
NVML_VGPU_DRIVER_CAP_HETEROGENEOUS_MULTI_VGPU = 0  # Supports mixing of different vGPU profiles within one guest VM
NVML_VGPU_DRIVER_CAP_WARM_UPDATE = 1  # Supports FSR and warm update of vGPU host driver without terminating the running guest VM
NVML_VGPU_DRIVER_CAP_COUNT = 2  # number of capability identifiers above

# Per-device vGPU capability identifiers (the list continues after this span).
_nvmlDeviceVgpuCapability_t = c_uint
NVML_DEVICE_VGPU_CAP_FRACTIONAL_MULTI_VGPU = 0  # Query whether the fractional vGPU profiles on this GPU can be used in multi-vGPU configurations
NVML_DEVICE_VGPU_CAP_HETEROGENEOUS_TIMESLICE_PROFILES = 1  # Query whether the GPU supports concurrent execution of timesliced vGPU profiles of differing types
NVML_DEVICE_VGPU_CAP_HETEROGENEOUS_TIMESLICE_SIZES = 2  # Query whether the GPU supports concurrent execution of timesliced vGPU profiles of differing framebuffer sizes
NVML_DEVICE_VGPU_CAP_READ_DEVICE_BUFFER_BW = 3  # Query the GPU's read_device_buffer expected bandwidth capacity in megabytes per second
# Continuation of the per-device vGPU capability identifiers.
NVML_DEVICE_VGPU_CAP_WRITE_DEVICE_BUFFER_BW = 4  # Query the GPU's write_device_buffer expected bandwidth capacity in megabytes per second
NVML_DEVICE_VGPU_CAP_DEVICE_STREAMING = 5  # Query whether the vGPU profiles on the GPU supports migration data streaming
NVML_DEVICE_VGPU_CAP_MINI_QUARTER_GPU = 6  # Set/Get support of mini-quarter vGPU profiles
NVML_DEVICE_VGPU_CAP_COMPUTE_MEDIA_ENGINE_GPU = 7  # Set/Get support for compute media engine vGPU profiles
NVML_DEVICE_VGPU_CAP_WARM_UPDATE = 8  # Query whether the GPU supports FSR and warm update
NVML_DEVICE_VGPU_CAP_HOMOGENEOUS_PLACEMENTS = 9  # Query whether the GPU supports reporting of placements of timesliced vGPU profiles with identical framebuffer sizes
NVML_DEVICE_VGPU_CAP_COUNT = 10  # number of capability identifiers (0-9) in this list

_nvmlVgpuGuestInfoState_t = c_uint
NVML_VGPU_INSTANCE_GUEST_INFO_STATE_UNINITIALIZED = 0
NVML_VGPU_INSTANCE_GUEST_INFO_STATE_INITIALIZED = 1

# The values below are distinct powers of two, so they can be OR-ed together.
_nvmlVgpuVmCompatibility_t = c_uint
NVML_VGPU_VM_COMPATIBILITY_NONE = 0x0
NVML_VGPU_VM_COMPATIBILITY_COLD = 0x1
NVML_VGPU_VM_COMPATIBILITY_HIBERNATE = 0x2
NVML_VGPU_VM_COMPATIBILITY_SLEEP = 0x4
NVML_VGPU_VM_COMPATIBILITY_LIVE = 0x8

# Also single-bit values; OTHER occupies the top bit of a 32-bit word.
_nvmlVgpuPgpuCompatibilityLimitCode_t = c_uint
NVML_VGPU_COMPATIBILITY_LIMIT_NONE = 0x0
NVML_VGPU_COMPATIBILITY_LIMIT_HOST_DRIVER = 0x1
NVML_VGPU_COMPATIBILITY_LIMIT_GUEST_DRIVER = 0x2
NVML_VGPU_COMPATIBILITY_LIMIT_GPU = 0x4
NVML_VGPU_COMPATIBILITY_LIMIT_OTHER = 0x80000000

_nvmlHostVgpuMode_t = c_uint
NVML_HOST_VGPU_MODE_NON_SRIOV = 0
NVML_HOST_VGPU_MODE_SRIOV = 1

# Confidential-compute (CC) enumerations.
_nvmlConfComputeGpusReadyState_t = c_uint
NVML_CC_ACCEPTING_CLIENT_REQUESTS_FALSE = 0
NVML_CC_ACCEPTING_CLIENT_REQUESTS_TRUE = 1

_nvmlConfComputeGpuCaps_t = c_uint
NVML_CC_SYSTEM_GPUS_CC_NOT_CAPABLE = 0
NVML_CC_SYSTEM_GPUS_CC_CAPABLE = 1

_nvmlConfComputeCpuCaps_t = c_uint
NVML_CC_SYSTEM_CPU_CAPS_NONE = 0
NVML_CC_SYSTEM_CPU_CAPS_AMD_SEV = 1
NVML_CC_SYSTEM_CPU_CAPS_INTEL_TDX = 2
NVML_CC_SYSTEM_CPU_CAPS_AMD_SEV_SNP = 3
NVML_CC_SYSTEM_CPU_CAPS_AMD_SNP_VTOM = 4

_nvmlConfComputeDevToolsMode_t = c_uint
NVML_CC_SYSTEM_DEVTOOLS_MODE_OFF = 0
NVML_CC_SYSTEM_DEVTOOLS_MODE_ON = 1

NVML_CC_SYSTEM_MULTIGPU_NONE = 0
NVML_CC_SYSTEM_MULTIGPU_PROTECTED_PCIE = 1

NVML_CC_SYSTEM_ENVIRONMENT_UNAVAILABLE = 0
NVML_CC_SYSTEM_ENVIRONMENT_SIM = 1
NVML_CC_SYSTEM_ENVIRONMENT_PROD = 2

_nvmlConfComputeCcFeature_t = c_uint
NVML_CC_SYSTEM_FEATURE_DISABLED = 0
NVML_CC_SYSTEM_FEATURE_ENABLED = 1

_nvmlConfComputeCcKeyRotationThreshAttackerAdv_t = c_uint
NVML_CC_KEY_ROTATION_THRESH_ATTACKER_ADVANTAGE_MIN = 50
NVML_CC_KEY_ROTATION_THRESH_ATTACKER_ADVANTAGE_MAX = 65

# GSP firmware
NVML_GSP_FIRMWARE_VERSION_BUF_SIZE = 0x40

class NVMLLibraryMismatchError(Exception):
    """Separate exception type; it is not an NVMLError and carries no error code."""
    pass

## Error Checking ##
class NVMLError(Exception):
    """
    Exception carrying an NVML return code in ``self.value``.

    ``NVMLError(code)`` returns an instance of the code-specific subclass
    registered in ``_valClassMapping`` (see _extractNVMLErrorsAsClasses) when
    one exists, otherwise a plain NVMLError.
    """
    # error code -> generated NVMLError subclass
    _valClassMapping = dict()
    # List of currently known error codes
    _errcode_to_string = {
        NVML_ERROR_UNINITIALIZED:       "Uninitialized",
        NVML_ERROR_INVALID_ARGUMENT:    "Invalid Argument",
        NVML_ERROR_NOT_SUPPORTED:       "Not Supported",
        NVML_ERROR_NO_PERMISSION:       "Insufficient Permissions",
        NVML_ERROR_ALREADY_INITIALIZED: "Already Initialized",
        NVML_ERROR_NOT_FOUND:           "Not Found",
        NVML_ERROR_INSUFFICIENT_SIZE:   "Insufficient Size",
        NVML_ERROR_INSUFFICIENT_POWER:  "Insufficient External Power",
        NVML_ERROR_DRIVER_NOT_LOADED:   "Driver Not Loaded",
        NVML_ERROR_TIMEOUT:             "Timeout",
        NVML_ERROR_IRQ_ISSUE:           "Interrupt Request Issue",
        NVML_ERROR_LIBRARY_NOT_FOUND:   "NVML Shared Library Not Found",
        NVML_ERROR_FUNCTION_NOT_FOUND:  "Function Not Found",
        NVML_ERROR_CORRUPTED_INFOROM:   "Corrupted infoROM",
        NVML_ERROR_GPU_IS_LOST:         "GPU is lost",
        NVML_ERROR_RESET_REQUIRED:      "GPU requires restart",
        NVML_ERROR_OPERATING_SYSTEM:    "The operating system has blocked the request.",
        NVML_ERROR_LIB_RM_VERSION_MISMATCH: "RM has detected an NVML/RM version mismatch.",
        NVML_ERROR_MEMORY:              "Insufficient Memory",
        NVML_ERROR_UNKNOWN:             "Unknown Error",
    }
    def __new__(typ, value):
        '''
        Maps value to a proper subclass of NVMLError.
        See _extractNVMLErrorsAsClasses function for more details
        '''
        if typ == NVMLError:
            typ = NVMLError._valClassMapping.get(value, typ)
        obj = Exception.__new__(typ)
        obj.value = value
        return obj
    def __str__(self):
        """
        Return the message for ``self.value``.

        Codes missing from ``_errcode_to_string`` are resolved through
        nvmlErrorString() and cached; if that lookup itself raises NVMLError,
        a generic message containing the numeric code is returned.
        """
        try:
            if self.value not in NVMLError._errcode_to_string:
                NVMLError._errcode_to_string[self.value] = str(nvmlErrorString(self.value))
            return NVMLError._errcode_to_string[self.value]
        except NVMLError:
            return "NVML Error with code %d" % self.value
    def __eq__(self, other):
        """Two errors are equal when their codes match; operands without a
        ``value`` attribute yield NotImplemented instead of AttributeError."""
        try:
            other_value = other.value
        except AttributeError:
            return NotImplemented
        return self.value == other_value
    def __hash__(self):
        # Defining __eq__ alone sets __hash__ to None on Python 3, which made
        # these exceptions unhashable. Hash on the same key __eq__ compares.
        return hash(self.value)

def nvmlExceptionClass(nvmlErrorCode):
    """
    Return the NVMLError subclass registered for ``nvmlErrorCode``.

    Raises ValueError when no subclass was generated for that code.
    """
    if nvmlErrorCode not in NVMLError._valClassMapping:
        raise ValueError('nvmlErrorCode %s is not valid' % nvmlErrorCode)
    return NVMLError._valClassMapping[nvmlErrorCode]

def _extractNVMLErrorsAsClasses():
    '''
    Generates a hierarchy of classes on top of NVMLError class.

    Each NVML Error gets a new NVMLError subclass. This way try,except blocks can filter appropriate
    exceptions more easily.

    NVMLError is a parent class. Each NVML_ERROR_* gets its own subclass.
    e.g. NVML_ERROR_ALREADY_INITIALIZED will be turned into NVMLError_AlreadyInitialized
    '''
    this_module = sys.modules[__name__]
    nvmlErrorsNames = [x for x in dir(this_module) if x.startswith("NVML_ERROR_")]
    for err_name in nvmlErrorsNames:
        # e.g. Turn NVML_ERROR_ALREADY_INITIALIZED into NVMLError_AlreadyInitialized
        class_name = "NVMLError_" + string.capwords(err_name.replace("NVML_ERROR_", ""), "_").replace("_", "")
        err_val = getattr(this_module, err_name)
        # gen_new binds err_val per iteration so each subclass can be
        # instantiated without arguments and still carries its own code.
        def gen_new(val):
            def new(typ):
                obj = NVMLError.__new__(typ, val)
                return obj
            return new
        new_error_class = type(class_name, (NVMLError,), {'__new__': gen_new(err_val)})
        new_error_class.__module__ = __name__
        setattr(this_module, class_name, new_error_class)
        NVMLError._valClassMapping[err_val] = new_error_class
_extractNVMLErrorsAsClasses()

def _nvmlCheckReturn(ret):
    """Raise NVMLError(ret) unless ``ret`` is NVML_SUCCESS; otherwise return ``ret``."""
    if (ret != NVML_SUCCESS):
        raise NVMLError(ret)
    return ret

## Function access ##
_nvmlGetFunctionPointer_cache = dict() # function pointers are cached to prevent unnecessary libLoadLock locking
def _nvmlGetFunctionPointer(name):
    """
    Return the attribute ``name`` of the loaded NVML library, caching the result.

    Raises NVMLError(NVML_ERROR_UNINITIALIZED) when the library has not been
    loaded and NVMLError(NVML_ERROR_FUNCTION_NOT_FOUND) when the library has
    no such symbol.
    """
    # Fast path: cached pointers are returned without taking the lock.
    if name in _nvmlGetFunctionPointer_cache:
        return _nvmlGetFunctionPointer_cache[name]

    # The context manager releases the lock on every exit path.
    with libLoadLock:
        # ensure library was loaded
        if nvmlLib is None:
            raise NVMLError(NVML_ERROR_UNINITIALIZED)
        try:
            _nvmlGetFunctionPointer_cache[name] = getattr(nvmlLib, name)
            return _nvmlGetFunctionPointer_cache[name]
        except AttributeError:
            raise NVMLError(NVML_ERROR_FUNCTION_NOT_FOUND)

## Alternative object
# Allows the object to be printed
# Allows mismatched types to be assigned
#  - like None when the Structure variant requires c_uint
class nvmlFriendlyObject(object):
    """Plain attribute bag built from a dict; prints as that dict."""
    def __init__(self, dictionary):
        for x in dictionary:
            setattr(self, x, dictionary[x])
    def __str__(self):
        return self.__dict__.__str__()

def nvmlStructToFriendlyObject(struct):
    """Copy every ``_fields_`` member of ``struct`` into a new nvmlFriendlyObject,
    decoding bytes values to str."""
    d = {}
    for x in struct._fields_:
        key = x[0]
        value = getattr(struct, key)
        # only need to convert from bytes if bytes, no need to check python version.
        d[key] = value.decode() if isinstance(value, bytes) else value
    obj = nvmlFriendlyObject(d)
    return obj

# pack the object so it can be passed to the NVML library
def nvmlFriendlyObjectToStruct(obj, model):
    """
    Copy the attributes of ``obj`` named in ``model._fields_`` onto ``model``
    and return ``model``.

    Raises KeyError when ``obj`` lacks one of the fields.
    """
    for x in model._fields_:
        key = x[0]
        value = obj.__dict__[key]
        # any c_char_p in python3 needs to be bytes, default encoding works fine.
        # Only text is encoded: ints and bytes have no .encode() on Python 3,
        # so encoding unconditionally raised AttributeError for those fields.
        if sys.version_info >= (3,) and isinstance(value, str):
            value = value.encode()
        setattr(model, key, value)
    return model

## Unit structures
class struct_c_nvmlUnit_t(Structure):
    pass # opaque handle
c_nvmlUnit_t = POINTER(struct_c_nvmlUnit_t)

class _PrintableStructure(Structure):
    """
    Abstract class that produces nicer __str__ output than ctypes.Structure.
    e.g. instead of the default object representation of
      >>> print str(obj)
    this class will print
      class_name(field_name: formatted_value, field_name: formatted_value)

    _fmt_ is a dictionary mapping a field name to a format string
    e.g. class that has _field_ 'hex_value', c_uint could be formatted with
      _fmt_ = {"hex_value" : "%08X"}
    to produce nicer output.
    Default formatting string for all fields can be set with key "" like:
      _fmt_ = {"" : "%d MHz"} # e.g all values are numbers in MHz.
    If not set it's assumed to be just "%s"

    Exact format of returned str from this class is subject to change in the future.
    """
    _fmt_ = {}
    def __str__(self):
        result = []
        for x in self._fields_:
            key = x[0]
            value = getattr(self, key)
            # Per-field format wins, then the "" default, then plain "%s".
            fmt = "%s"
            if key in self._fmt_:
                fmt = self._fmt_[key]
            elif "" in self._fmt_:
                fmt = self._fmt_[""]
            result.append(("%s: " + fmt) % (key, value))
        return self.__class__.__name__ + "(" + ", ".join(result) + ")"

    def __getattribute__(self, name):
        res = super(_PrintableStructure, self).__getattribute__(name)
        # need to convert bytes to unicode for python3 don't need to for python2
        # Python 2 strings are of both str and bytes
        # Python 3 strings are not of type bytes
        # ctypes should convert everything to the correct values otherwise
        if isinstance(res, bytes):
            if isinstance(res, str):
                return res
            return res.decode()
        return res

    def __setattr__(self, name, value):
        if isinstance(value, str):
            # encoding a python2 string returns the same value, since python2 strings are bytes already
            # bytes passed in python3 will be ignored.
            value = value.encode()
        super(_PrintableStructure, self).__setattr__(name, value)

class c_nvmlUnitInfo_t(_PrintableStructure):
    _fields_ = [
        ('name', c_char * 96),
        ('id', c_char * 96),
        ('serial', c_char * 96),
        ('firmwareVersion', c_char * 96),
    ]

class c_nvmlC2cModeInfo_v1_t(_PrintableStructure):
    _fields_ = [
        ('isC2cEnabled', c_uint)
    ]

nvmlC2cModeInfo_v1 = 0x1000008

class c_nvmlLedState_t(_PrintableStructure):
    _fields_ = [
        ('cause', c_char * 256),
        ('color', _nvmlLedColor_t),
    ]

class c_nvmlPSUInfo_t(_PrintableStructure):
    _fields_ = [
        ('state', c_char * 256),
        ('current', c_uint),
        ('voltage', c_uint),
        ('power', c_uint),
    ]

class c_nvmlUnitFanInfo_t(_PrintableStructure):
    _fields_ = [
        ('speed', c_uint),
        ('state', _nvmlFanState_t),
    ]

class c_nvmlUnitFanSpeeds_t(_PrintableStructure):
    _fields_ = [
        ('fans', c_nvmlUnitFanInfo_t * 24),
        ('count', c_uint)
    ]

## Device structures
class struct_c_nvmlDevice_t(Structure):
    pass # opaque handle
c_nvmlDevice_t = POINTER(struct_c_nvmlDevice_t)

class nvmlPciInfoExt_v1_t(_PrintableStructure):
    _fields_ = [
        ('version', c_uint),
        ('domain', c_uint),
        ('bus', c_uint),
        ('device', c_uint),
        ('pciDeviceId', c_uint),
        ('pciSubSystemId', c_uint),
        ('baseClass', c_uint),
        ('subClass', c_uint),
        ('busId', c_char * NVML_DEVICE_PCI_BUS_ID_BUFFER_SIZE),
    ]
    _fmt_ = {
            'version'        : "0x%04X",
            'domain'         : "0x%04X",
            'bus'            : "0x%02X",
            'device'         : "0x%02X",
            'pciDeviceId'    : "0x%08X",
            'pciSubSystemId' : "0x%08X",
            'baseClass'      : "0x%01X",
            'subClass'       : "0x%01X",
            }

nvmlPciInfoExt_v1 = 0x1000040

# Legacy pciInfo used for _v1 and _v2
class nvmlPciInfo_v2_t(_PrintableStructure):
    _fields_ = [
        ('busId', c_char * NVML_DEVICE_PCI_BUS_ID_BUFFER_V2_SIZE),
        ('domain', c_uint),
        ('bus', c_uint),
        ('device', c_uint),
        ('pciDeviceId', c_uint),

        # Added in 2.285
        ('pciSubSystemId', c_uint),
        ('reserved0', c_uint),
        ('reserved1', c_uint),
        ('reserved2', c_uint),
        ('reserved3', c_uint),
    ]
    _fmt_ = {
            'domain'         : "0x%04X",
            'bus'            : "0x%02X",
            'device'         : "0x%02X",
            'pciDeviceId'    : "0x%08X",
            'pciSubSystemId' : "0x%08X",
            }

class nvmlPciInfo_t(_PrintableStructure):
    _fields_ = [
        # Moved to the new busId location below
        ('busIdLegacy', c_char * NVML_DEVICE_PCI_BUS_ID_BUFFER_V2_SIZE),
        ('domain', c_uint),
        ('bus', c_uint),
        ('device', c_uint),
        ('pciDeviceId', c_uint),

        # Added in 2.285
        ('pciSubSystemId', c_uint),
        # New busId replaced the long deprecated and reserved fields with a
        # field of the same size in 9.0
        ('busId', c_char * NVML_DEVICE_PCI_BUS_ID_BUFFER_SIZE),
    ]
    _fmt_ = {
            'domain'         : "0x%08X",
            'bus'            : "0x%02X",
            'device'         : "0x%02X",
            'pciDeviceId'    : "0x%08X",
            'pciSubSystemId' : "0x%08X",
            }

class c_nvmlSystemDriverBranchInfo_v1_t(_PrintableStructure):
    _fields_ = [
        ('version', c_uint),
        ("branch", c_char * NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE),
    ]

SystemDriverBranchInfo_v1 = 0x1000054

class c_nvmlExcludedDeviceInfo_t(_PrintableStructure):
    _fields_ = [
        ('pci', nvmlPciInfo_t),
        ('uuid', c_char * NVML_DEVICE_UUID_BUFFER_SIZE)
    ]

class nvmlNvLinkUtilizationControl_t(_PrintableStructure):
    _fields_ = [
        ('units', _nvmlNvLinkUtilizationCountUnits_t),
        ('pktfilter', _nvmlNvLinkUtilizationCountPktTypes_t),
    ]

class c_nvmlMemory_t(_PrintableStructure):
    _fields_ = [
        ('total', c_ulonglong),
        ('free', c_ulonglong),
        ('used', c_ulonglong),
    ]
    _fmt_ = {'': "%d B"}

class c_nvmlMemory_v2_t(_PrintableStructure):
    _fields_ = [
        ('version', c_uint),
        ('total', c_ulonglong),
        ('reserved', c_ulonglong),
        ('free', c_ulonglong),
        ('used', c_ulonglong),
    ]
    _fmt_ = {'': "%d B"}

nvmlMemory_v2 = 0x02000028

class c_nvmlBAR1Memory_t(_PrintableStructure):
    _fields_ = [
        ('bar1Total', c_ulonglong),
        ('bar1Free', c_ulonglong),
        ('bar1Used', c_ulonglong),
    ]
    _fmt_ = {'': "%d B"}

class nvmlClkMonFaultInfo_t(Structure):
    _fields_ = [("clkApiDomain", c_uint),
                ("clkDomainFaultMask", c_uint)
    ]

MAX_CLK_DOMAINS = 32

class nvmlClkMonStatus_t(Structure):
    _fields_ = [("bGlobalStatus", c_uint),
                ("clkMonListSize", c_uint),
                ("clkMonList", nvmlClkMonFaultInfo_t * MAX_CLK_DOMAINS)
    ]

# On Windows with the WDDM driver, usedGpuMemory is reported as None
# Code that processes this structure should check for None, I.E.
#
# if (info.usedGpuMemory == None):
#     # TODO handle the error
#     pass
# else:
#    print("Using %d MiB of memory" % (info.usedGpuMemory / 1024 / 1024))
# endif
#
# See NVML documentation for more information
class c_nvmlProcessInfo_v2_t(_PrintableStructure):
    _fields_ = [
        ('pid', c_uint),
        ('usedGpuMemory', c_ulonglong),
        ('gpuInstanceId', c_uint),
        ('computeInstanceId', c_uint),
    ]
    _fmt_ = {'usedGpuMemory': "%d B"}

# v3 and the unversioned name are aliases of the v2 layout.
c_nvmlProcessInfo_v3_t = c_nvmlProcessInfo_v2_t

c_nvmlProcessInfo_t = c_nvmlProcessInfo_v3_t

_nvmlProcessMode_t = c_uint
NVML_PROCESS_MODE_COMPUTE = 0
NVML_PROCESS_MODE_GRAPHICS = 1
NVML_PROCESS_MODE_MPS = 2

class c_nvmlProcessDetail_v1_t(Structure):
    _fields_ = [
        ('pid', c_uint),
        ('usedGpuMemory', c_ulonglong),
        ('gpuInstanceId', c_uint),
        ('computeInstanceId', c_uint),
        ('usedGpuCcProtectedMemory', c_ulonglong),
    ]

class c_nvmlProcessDetailList_v1_t(_PrintableStructure):
    _fields_ = [
        ('version', c_uint),
        ('mode', _nvmlProcessMode_t),
        ('numProcArrayEntries', c_uint),
        ('procArray', POINTER(c_nvmlProcessDetail_v1_t)),
    ]
    # NOTE(review): the field name suggests an entry count, yet it is printed
    # with a " B" (bytes) suffix -- looks copied from a memory struct; confirm.
    _fmt_ = {'numProcArrayEntries': "%d B"}

c_nvmlProcessDetailList_t = c_nvmlProcessDetailList_v1_t

# Value for the `version` field of the struct above.
# NOTE(review): the *_v1 / *_v2 constants in this section presumably encode
# (struct version << 24) | sizeof(struct) -- confirm against nvml.h.
nvmlProcessDetailList_v1 = 0x1000018

class c_nvmlBridgeChipInfo_t(_PrintableStructure):
    _fields_ = [
        ('type', _nvmlBridgeChipType_t),
        ('fwVersion', c_uint),
    ]

class c_nvmlBridgeChipHierarchy_t(_PrintableStructure):
    _fields_ = [
        ('bridgeCount', c_uint),
        ('bridgeChipInfo', c_nvmlBridgeChipInfo_t * 128),
    ]

class c_nvmlEccErrorCounts_t(_PrintableStructure):
    _fields_ = [
        ('l1Cache', c_ulonglong),
        ('l2Cache', c_ulonglong),
        ('deviceMemory', c_ulonglong),
        ('registerFile', c_ulonglong),
    ]

class c_nvmlUtilization_t(_PrintableStructure):
    _fields_ = [
        ('gpu', c_uint),
        ('memory', c_uint),
    ]
    # The '' key is the default format applied to every field when printing.
    _fmt_ = {'': "%d %%"}

# Added in 2.285
class c_nvmlHwbcEntry_t(_PrintableStructure):
    _fields_ = [
        ('hwbcId', c_uint),
        ('firmwareVersion', c_char * 32),
    ]

# Union: all members share the same storage; which one is meaningful is
# indicated elsewhere (e.g. the valueType field of c_nvmlFieldValue_t below).
class c_nvmlValue_t(Union):
    _fields_ = [
        ('dVal', c_double),
        ('uiVal', c_uint),
        ('ulVal', c_ulong),
        ('ullVal', c_ulonglong),
        ('sllVal', c_longlong),
        ('siVal', c_int),
        ('usVal', c_ushort),
    ]

class c_nvmlSample_t(_PrintableStructure):
    _fields_ = [
        ('timeStamp', c_ulonglong),
        ('sampleValue', c_nvmlValue_t),
    ]

class c_nvmlViolationTime_t(_PrintableStructure):
    _fields_ = [
        ('referenceTime', c_ulonglong),
        ('violationTime', c_ulonglong),
    ]

class c_nvmlFieldValue_t(_PrintableStructure):
    _fields_ = [
        ('fieldId', c_uint32),
        ('scopeId', c_uint32),
        ('timestamp', c_int64),
        ('latencyUsec', c_int64),
        ('valueType', _nvmlValueType_t),
        ('nvmlReturn', _nvmlReturn_t),
        ('value', c_nvmlValue_t)
    ]

# Length of the bwModes array in c_nvmlNvlinkSupportedBwModes_v1_t.
NVML_NVLINK_TOTAL_SUPPORTED_BW_MODES = 23

nvmlNvlinkSupportedBwModes_v1 = 0x100001c
class c_nvmlNvlinkSupportedBwModes_v1_t(_PrintableStructure):
    _fields_ = [
        ('version', c_uint),
        ('bwModes', c_uint8 * NVML_NVLINK_TOTAL_SUPPORTED_BW_MODES),
        ('totalBwModes', c_uint8)
    ]

    def __init__(self):
        # Pre-fill the version field so callers do not have to set it.
        super(c_nvmlNvlinkSupportedBwModes_v1_t, self).__init__(version=nvmlNvlinkSupportedBwModes_v1)

nvmlNvlinkGetBwMode_v1 = 0x100000c
class c_nvmlNvlinkGetBwMode_v1_t(_PrintableStructure):
    _fields_ = [
        ('version', c_uint),
        ('bIsBest', c_uint),
        ('bwMode', c_uint8)
    ]

    def __init__(self):
        # Pre-fill the version field so callers do not have to set it.
        super(c_nvmlNvlinkGetBwMode_v1_t, self).__init__(version=nvmlNvlinkGetBwMode_v1)

# Same numeric value as nvmlNvlinkGetBwMode_v1; the two structs have the same field types.
nvmlNvlinkSetBwMode_v1 = 0x100000c
class c_nvmlNvlinkSetBwMode_v1_t(_PrintableStructure):
    _fields_ = [
        ('version', c_uint),
        ('bSetBest', c_uint),
        ('bwMode', c_uint8)
    ]

    def __init__(self):
        # Pre-fill the version field so callers do not have to set it.
        super(c_nvmlNvlinkSetBwMode_v1_t, self).__init__(version=nvmlNvlinkSetBwMode_v1)

# Unlike the NVLink structs above, the vGPU structs below define no __init__,
# so `version` is not pre-filled for them.
class c_nvmlVgpuHeterogeneousMode_v1_t(_PrintableStructure):
    _fields_ = [
        ('version', c_uint),
        ('mode', c_uint),
    ]

VgpuHeterogeneousMode_v1 = 0x1000008

class c_nvmlVgpuPlacementId_v1_t(_PrintableStructure):
    _fields_ = [
        ('version', c_uint),
        ('placementId', c_uint),
    ]

VgpuPlacementId_v1 = 0x1000008

class c_nvmlVgpuPlacementList_v1_t(_PrintableStructure):
    _fields_ = [
        ('version', c_uint),
        ('count', c_uint),
        ('placementSize', c_uint),
        ('placementIds', POINTER(c_uint)),
    ]

VgpuPlacementList_v1 = 0x1000018

NVML_VGPU_PGPU_HETEROGENEOUS_MODE = 0
NVML_VGPU_PGPU_HOMOGENEOUS_MODE = 1

# NOTE: v2 lists placementSize before count (v1 has count first) and adds `mode`.
class c_nvmlVgpuPlacementList_v2_t(_PrintableStructure):
    _fields_ = [
        ('version', c_uint),
        ('placementSize', c_uint),
        ('count', c_uint),
        ('placementIds', POINTER(c_uint)),
        ('mode', c_uint),
    ]

VgpuPlacementList_v2 = 0x2000020

class c_nvmlVgpuTypeBar1Info_v1_t(_PrintableStructure):
    _fields_ = [
        ('version', c_uint),
        ('bar1Size', c_ulonglong),
    ]

VgpuTypeBar1Info_v1 = 0x1000010

class c_nvmlVgpuInstanceUtilizationSample_t(_PrintableStructure):
    _fields_ = [
        ('vgpuInstance', _nvmlVgpuInstance_t),
        ('timeStamp', c_ulonglong),
        ('smUtil', c_nvmlValue_t),
        ('memUtil', c_nvmlValue_t),
        ('encUtil', c_nvmlValue_t),
        ('decUtil', c_nvmlValue_t),
    ]

# Compared with the sample struct above: timeStamp comes first and
# jpgUtil / ofaUtil are added.
class c_nvmlVgpuInstanceUtilizationInfo_v1_t(_PrintableStructure):
    _fields_ = [
        ('timeStamp', c_ulonglong),
        ('vgpuInstance', _nvmlVgpuInstance_t),
        ('smUtil', c_nvmlValue_t),
        ('memUtil', c_nvmlValue_t),
        ('encUtil', c_nvmlValue_t),
        ('decUtil', c_nvmlValue_t),
        ('jpgUtil', c_nvmlValue_t),
        ('ofaUtil', c_nvmlValue_t),
    ]

class c_nvmlVgpuInstancesUtilizationInfo_v1_t(_PrintableStructure):
    _fields_ = [
        ('version', c_uint),
        ('sampleValType', _nvmlValueType_t),
        ('vgpuInstanceCount', c_uint),
        ('lastSeenTimeStamp', c_ulonglong),
        ('vgpuUtilArray', POINTER(c_nvmlVgpuInstanceUtilizationInfo_v1_t)),
    ]

VgpuInstancesUtilizationInfo_v1 = 0x01000020

class c_nvmlVgpuProcessUtilizationSample_t(_PrintableStructure):
    _fields_ = [
        ('vgpuInstance', _nvmlVgpuInstance_t),
        ('pid', c_uint),
        ('processName', c_char * NVML_VGPU_NAME_BUFFER_SIZE),
        ('timeStamp', c_ulonglong),
        ('smUtil', c_uint),
        ('memUtil', c_uint),
        ('encUtil', c_uint),
        ('decUtil', c_uint),
    ]

class c_nvmlVgpuProcessUtilizationInfo_v1_t(_PrintableStructure):
    _fields_ = [
        ('processName', c_char * NVML_VGPU_NAME_BUFFER_SIZE),
        ('timeStamp', c_ulonglong),
        ('vgpuInstance', _nvmlVgpuInstance_t),
        ('pid', c_uint),
        ('smUtil', c_uint),
        ('memUtil', c_uint),
        ('encUtil', c_uint),
        ('decUtil', c_uint),
        ('jpgUtil', c_uint),
        ('ofaUtil', c_uint),
    ]

class c_nvmlVgpuProcessesUtilizationInfo_v1_t(_PrintableStructure):
    _fields_ = [
        ('version', c_uint),
        ('vgpuProcessCount', c_uint),
        ('lastSeenTimeStamp', c_ulonglong),
        ('vgpuProcUtilArray', POINTER(c_nvmlVgpuProcessUtilizationInfo_v1_t)),
    ]

VgpuProcessesUtilizationInfo_v1 = 0x01000018

class nvmlVgpuRuntimeState_v1_t(_PrintableStructure):
    _fields_ = [
        ('version', c_uint),
        ('size', c_ulonglong),
    ]

VgpuRuntimeState_v1 = 0x1000010

class c_nvmlVgpuLicenseExpiry_t(_PrintableStructure):
    _fields_ = [
        ('year',    c_uint32),
        ('month',   c_uint16),
        ('day',     c_uint16),
        ('hour',    c_uint16),
        ('min',     c_uint16),
        ('sec',     c_uint16),
        ('status',  c_uint8),
    ]

NVML_GRID_LICENSE_STATE_UNKNOWN = 0
NVML_GRID_LICENSE_STATE_UNINITIALIZED = 1
NVML_GRID_LICENSE_STATE_UNLICENSED_UNRESTRICTED = 2
NVML_GRID_LICENSE_STATE_UNLICENSED_RESTRICTED = 3
NVML_GRID_LICENSE_STATE_UNLICENSED = 4
NVML_GRID_LICENSE_STATE_LICENSED = 5

class c_nvmlVgpuLicenseInfo_t(_PrintableStructure):
    _fields_ = [
        ('isLicensed',      c_uint8),
        ('licenseExpiry',   c_nvmlVgpuLicenseExpiry_t),
        # presumably one of the NVML_GRID_LICENSE_STATE_* values above -- confirm
        ('currentState',    c_uint),
    ]

class c_nvmlEncoderSession_t(_PrintableStructure):
    _fields_ = [
        ('sessionId', c_uint),
        ('pid', c_uint),
        ('vgpuInstance', _nvmlVgpuInstance_t),
        ('codecType', c_uint),
        ('hResolution', c_uint),
        ('vResolution', c_uint),
        ('averageFps', c_uint),
        ('encodeLatency', c_uint),
    ]

class c_nvmlProcessUtilizationSample_t(_PrintableStructure):
    _fields_ = [
        ('pid', c_uint),
        ('timeStamp', c_ulonglong),
        ('smUtil', c_uint),
        ('memUtil', c_uint),
        ('encUtil', c_uint),
        ('decUtil', c_uint),
    ]

class c_nvmlProcessUtilizationInfo_v1_t(_PrintableStructure):
    _fields_ = [
        ('timeStamp', c_ulonglong),
        ('pid', c_uint),
        ('smUtil', c_uint),
        ('memUtil', c_uint),
        ('encUtil', c_uint),
        ('decUtil', c_uint),
        ('jpgUtil', c_uint),
        ('ofaUtil', c_uint),
    ]

class c_nvmlProcessesUtilizationInfo_v1_t(_PrintableStructure):
    _fields_ = [
        ('version', c_uint),
        ('processSamplesCount', c_uint),
        ('lastSeenTimeStamp', c_ulonglong),
        ('procUtilArray', POINTER(c_nvmlProcessUtilizationInfo_v1_t)),
    ]

ProcessesUtilizationInfo_v1 = 0x01000018

# Same field list as c_nvmlVgpuLicenseExpiry_t above.
class c_nvmlGridLicenseExpiry_t(_PrintableStructure):
    _fields_ = [
        ('year',    c_uint32),
        ('month',   c_uint16),
        ('day',     c_uint16),
        ('hour',    c_uint16),
        ('min',     c_uint16),
        ('sec',     c_uint16),
        ('status',  c_uint8),
    ]

# The licensable-feature structs come in four generations; each newer one
# appends fields to the previous layout (v1: licenseInfo, v2: +productName,
# v3: +featureEnabled, v4: +licenseExpiry).
class c_nvmlGridLicensableFeature_v4_t(_PrintableStructure):
    _fields_ = [
        ('featureCode',    _nvmlGridLicenseFeatureCode_t),
        ('featureState',   c_uint),
        ('licenseInfo',    c_char * NVML_GRID_LICENSE_BUFFER_SIZE),
        ('productName',    c_char * NVML_GRID_LICENSE_BUFFER_SIZE),
        ('featureEnabled', c_uint),
        ('licenseExpiry',  c_nvmlGridLicenseExpiry_t),
    ]

class c_nvmlGridLicensableFeatures_v4_t(_PrintableStructure):
    _fields_ = [
        ('isGridLicenseSupported',  c_int),
        ('licensableFeaturesCount', c_uint),
        ('gridLicensableFeatures',  c_nvmlGridLicensableFeature_v4_t * NVML_GRID_LICENSE_FEATURE_MAX_COUNT),
    ]

class c_nvmlGridLicensableFeature_v3_t(_PrintableStructure):
    _fields_ = [
        ('featureCode',    _nvmlGridLicenseFeatureCode_t),
        ('featureState',   c_uint),
        ('licenseInfo',    c_char * NVML_GRID_LICENSE_BUFFER_SIZE),
        ('productName',    c_char * NVML_GRID_LICENSE_BUFFER_SIZE),
        ('featureEnabled', c_uint),
    ]

class c_nvmlGridLicensableFeatures_v3_t(_PrintableStructure):
    _fields_ = [
        ('isGridLicenseSupported',  c_int),
        ('licensableFeaturesCount', c_uint),
        ('gridLicensableFeatures',  c_nvmlGridLicensableFeature_v3_t * NVML_GRID_LICENSE_FEATURE_MAX_COUNT),
    ]

class c_nvmlGridLicensableFeature_v2_t(_PrintableStructure):
    _fields_ = [
        ('featureCode',    _nvmlGridLicenseFeatureCode_t),
        ('featureState',   c_uint),
        ('licenseInfo',    c_char * NVML_GRID_LICENSE_BUFFER_SIZE),
        ('productName',    c_char * NVML_GRID_LICENSE_BUFFER_SIZE),
    ]

class c_nvmlGridLicensableFeatures_v2_t(_PrintableStructure):
    _fields_ = [
        ('isGridLicenseSupported',  c_int),
        ('licensableFeaturesCount', c_uint),
        ('gridLicensableFeatures',  c_nvmlGridLicensableFeature_v2_t * NVML_GRID_LICENSE_FEATURE_MAX_COUNT),
    ]

class c_nvmlGridLicensableFeature_t(_PrintableStructure):
    _fields_ = [
        ('featureCode',    _nvmlGridLicenseFeatureCode_t),
        ('featureState',   c_uint),
        ('licenseInfo',    c_char * NVML_GRID_LICENSE_BUFFER_SIZE),
    ]

class c_nvmlGridLicensableFeatures_t(_PrintableStructure):
    _fields_ = [
        ('isGridLicenseSupported',  c_int),
        ('licensableFeaturesCount', c_uint),
        ('gridLicensableFeatures',  c_nvmlGridLicensableFeature_t * NVML_GRID_LICENSE_FEATURE_MAX_COUNT),
    ]

class c_nvmlMarginTemperature_v1_t(_PrintableStructure):
    _fields_ = [
        ('version', c_uint),
        ('marginTemperature', c_int),
    ]

nvmlMarginTemperature_v1 = 0x1000008

## Event structures
class struct_c_nvmlEventSet_t(Structure):
    pass # opaque handle
c_nvmlEventSet_t = POINTER(struct_c_nvmlEventSet_t)

# Event-type flags: each is a distinct single bit so they can be OR-ed into a
# mask. Bits 0x20 and 0x40 are not assigned in this list.
nvmlEventTypeSingleBitEccError     = 0x0000000000000001
nvmlEventTypeDoubleBitEccError     = 0x0000000000000002
nvmlEventTypePState                = 0x0000000000000004
nvmlEventTypeXidCriticalError      = 0x0000000000000008
nvmlEventTypeClock                 = 0x0000000000000010
nvmlEventTypePowerSourceChange     = 0x0000000000000080
nvmlEventMigConfigChange           = 0x0000000000000100
nvmlEventTypeSingleBitEccErrorStorm = 0x0000000000000200
nvmlEventTypeDramRetirementEvent   = 0x0000000000000400
nvmlEventTypeDramRetirementFailure = 0x0000000000000800
nvmlEventTypeNonFatalPoisonError   = 0x0000000000001000
+nvmlEventTypeFatalPoisonError = 0x0000000000002000 +nvmlEventTypeGpuUnavailableError = 0x0000000000004000 +nvmlEventTypeGpuRecoveryAction = 0x0000000000008000 +nvmlEventTypeNone = 0x0000000000000000 +nvmlEventTypeAll = ( + nvmlEventTypeNone + | nvmlEventTypeSingleBitEccError + | nvmlEventTypeDoubleBitEccError + | nvmlEventTypePState + | nvmlEventTypeClock + | nvmlEventTypePowerSourceChange + | nvmlEventTypeXidCriticalError + | nvmlEventMigConfigChange + | nvmlEventTypeSingleBitEccErrorStorm + | nvmlEventTypeDramRetirementEvent + | nvmlEventTypeDramRetirementFailure + | nvmlEventTypeNonFatalPoisonError + | nvmlEventTypeFatalPoisonError + | nvmlEventTypeGpuUnavailableError + | nvmlEventTypeGpuRecoveryAction + ) + +## Clock Event Reasons defines +nvmlClocksEventReasonGpuIdle = 0x0000000000000001 +nvmlClocksEventReasonApplicationsClocksSetting = 0x0000000000000002 +nvmlClocksEventReasonUserDefinedClocks = nvmlClocksEventReasonApplicationsClocksSetting # deprecated, use nvmlClocksEventReasonApplicationsClocksSetting +nvmlClocksEventReasonSwPowerCap = 0x0000000000000004 +nvmlClocksEventReasonHwSlowdown = 0x0000000000000008 +nvmlClocksEventReasonSyncBoost = 0x0000000000000010 +nvmlClocksEventReasonSwThermalSlowdown = 0x0000000000000020 +nvmlClocksEventReasonHwThermalSlowdown = 0x0000000000000040 +nvmlClocksEventReasonHwPowerBrakeSlowdown = 0x0000000000000080 +nvmlClocksEventReasonDisplayClockSetting = 0x0000000000000100 +nvmlClocksEventReasonNone = 0x0000000000000000 +nvmlClocksEventReasonAll = ( + nvmlClocksEventReasonNone | + nvmlClocksEventReasonGpuIdle | + nvmlClocksEventReasonApplicationsClocksSetting | + nvmlClocksEventReasonSwPowerCap | + nvmlClocksEventReasonHwSlowdown | + nvmlClocksEventReasonSyncBoost | + nvmlClocksEventReasonSwThermalSlowdown | + nvmlClocksEventReasonHwThermalSlowdown | + nvmlClocksEventReasonHwPowerBrakeSlowdown | + nvmlClocksEventReasonDisplayClockSetting + ) + +## Following have been deprecated +nvmlClocksThrottleReasonGpuIdle = 
0x0000000000000001 +nvmlClocksThrottleReasonApplicationsClocksSetting = 0x0000000000000002 +nvmlClocksThrottleReasonUserDefinedClocks = nvmlClocksThrottleReasonApplicationsClocksSetting # deprecated, use nvmlClocksThrottleReasonApplicationsClocksSetting +nvmlClocksThrottleReasonSwPowerCap = 0x0000000000000004 +nvmlClocksThrottleReasonHwSlowdown = 0x0000000000000008 +nvmlClocksThrottleReasonSyncBoost = 0x0000000000000010 +nvmlClocksThrottleReasonSwThermalSlowdown = 0x0000000000000020 +nvmlClocksThrottleReasonHwThermalSlowdown = 0x0000000000000040 +nvmlClocksThrottleReasonHwPowerBrakeSlowdown = 0x0000000000000080 +nvmlClocksThrottleReasonDisplayClockSetting = 0x0000000000000100 +nvmlClocksThrottleReasonNone = 0x0000000000000000 +nvmlClocksThrottleReasonAll = ( + nvmlClocksThrottleReasonNone | + nvmlClocksThrottleReasonGpuIdle | + nvmlClocksThrottleReasonApplicationsClocksSetting | + nvmlClocksThrottleReasonSwPowerCap | + nvmlClocksThrottleReasonHwSlowdown | + nvmlClocksThrottleReasonSyncBoost | + nvmlClocksThrottleReasonSwThermalSlowdown | + nvmlClocksThrottleReasonHwThermalSlowdown | + nvmlClocksThrottleReasonHwPowerBrakeSlowdown | + nvmlClocksThrottleReasonDisplayClockSetting + ) + +class c_nvmlEventData_t(_PrintableStructure): + _fields_ = [ + ('device', c_nvmlDevice_t), + ('eventType', c_ulonglong), + ('eventData', c_ulonglong), + ('gpuInstanceId', c_uint), + ('computeInstanceId', c_uint) + ] + _fmt_ = {'eventType': "0x%08X"} + +class c_nvmlAccountingStats_t(_PrintableStructure): + _fields_ = [ + ('gpuUtilization', c_uint), + ('memoryUtilization', c_uint), + ('maxMemoryUsage', c_ulonglong), + ('time', c_ulonglong), + ('startTime', c_ulonglong), + ('isRunning', c_uint), + ('reserved', c_uint * 5) + ] + +class c_nvmlVgpuVersion_t(Structure): + _fields_ = [("minVersion", c_uint), + ("maxVersion", c_uint) + ] + +class c_nvmlVgpuMetadata_t(_PrintableStructure): + _fields_ = [("version", c_uint), + ("revision", c_uint), + ("guestInfoState", _nvmlVgpuGuestInfoState_t), 
+ ("guestDriverVersion", c_char * NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE), + ("hostDriverVersion", c_char * NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE), + ("reserved", c_uint * 6), + ("vgpuVirtualizationCaps", c_uint), + ("guestVgpuVersion", c_uint), + ("opaqueDataSize", c_uint), + ("opaqueData", c_char * NVML_VGPU_METADATA_OPAQUE_DATA_SIZE) + ] + +class c_nvmlVgpuPgpuMetadata_t(_PrintableStructure): + _fields_ = [("version", c_uint), + ("revision", c_uint), + ("hostDriverVersion", c_char * NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE), + ("pgpuVirtualizationCaps", c_uint), + ("reserved", c_uint * 5), + ("hostSupportedVgpuRange", c_nvmlVgpuVersion_t), + ("opaqueDataSize", c_uint), + ("opaqueData", c_char * NVML_VGPU_PGPU_METADATA_OPAQUE_DATA_SIZE) + ] + +class c_nvmlVgpuPgpuCompatibility_t(Structure): + _fields_ = [("vgpuVmCompatibility", _nvmlVgpuVmCompatibility_t), + ("compatibilityLimitCode", _nvmlVgpuPgpuCompatibilityLimitCode_t) + ] + +## vGPU scheduler policy defines +NVML_VGPU_SCHEDULER_POLICY_UNKNOWN = 0 +NVML_VGPU_SCHEDULER_POLICY_BEST_EFFORT = 1 +NVML_VGPU_SCHEDULER_POLICY_EQUAL_SHARE = 2 +NVML_VGPU_SCHEDULER_POLICY_FIXED_SHARE = 3 + +## Supported vGPU scheduler policy count +NVML_SUPPORTED_VGPU_SCHEDULER_POLICY_COUNT = 3 + +NVML_SCHEDULER_SW_MAX_LOG_ENTRIES = 200 + +NVML_VGPU_SCHEDULER_ARR_DEFAULT = 0 +NVML_VGPU_SCHEDULER_ARR_DISABLE = 1 +NVML_VGPU_SCHEDULER_ARR_ENABLE = 2 + +class c_nvmlVgpuSchedDataWithARR_t(_PrintableStructure): + _fields_ = [ + ('avgFactor', c_uint), + ('timeslice', c_uint), + ] + +class c_nvmlVgpuSchedData_t(_PrintableStructure): + _fields_ = [ + ('timeslice', c_uint), + ] + +class c_nvmlVgpuSchedulerParams_t(Union): + _fields_ = [ + ('vgpuSchedDataWithARR', c_nvmlVgpuSchedDataWithARR_t), + ('vgpuSchedData', c_nvmlVgpuSchedData_t), + ] + +class c_nvmlVgpuSchedulerLogEntry_t(_PrintableStructure): + _fields_ = [ + ('timestamp', c_ulonglong), + ('timeRunTotal', c_ulonglong), + ('timeRun', c_ulonglong), + ('swRunlistId', c_uint), + 
('targetTimeSlice', c_ulonglong), + ('cumulativePreemptionTime', c_ulonglong), + ] + +class c_nvmlVgpuSchedulerLog_t(_PrintableStructure): + _fields_ = [ + ('engineId', c_uint), + ('schedulerPolicy', c_uint), + ('arrMode', c_uint), + ('schedulerParams', c_nvmlVgpuSchedulerParams_t), + ('entriesCount', c_uint), + ('logEntries', c_nvmlVgpuSchedulerLogEntry_t * NVML_SCHEDULER_SW_MAX_LOG_ENTRIES), + ] + +class c_nvmlVgpuSchedulerGetState_t(_PrintableStructure): + _fields_ = [ + ('schedulerPolicy', c_uint), + ('arrMode', c_uint), + ('schedulerParams', c_nvmlVgpuSchedulerParams_t), + ] + +class c_nvmlVgpuSchedSetDataWithARR_t(_PrintableStructure): + _fields_ = [ + ('avgFactor', c_uint), + ('frequency', c_uint), + ] + +class c_nvmlVgpuSchedSetData_t(_PrintableStructure): + _fields_ = [ + ('timeslice', c_uint), + ] + +class c_nvmlVgpuSchedulerSetParams_t(Union): + _fields_ = [ + ('vgpuSchedDataWithARR', c_nvmlVgpuSchedSetDataWithARR_t), + ('vgpuSchedData', c_nvmlVgpuSchedSetData_t), + ] + +class c_nvmlVgpuSchedulerSetState_t(_PrintableStructure): + _fields_ = [ + ('schedulerPolicy', c_uint), + ('enableARRMode', c_uint), + ('schedulerParams', c_nvmlVgpuSchedulerSetParams_t), + ] + +class c_nvmlVgpuSchedulerCapabilities_t(_PrintableStructure): + _fields_ = [ + ('supportedSchedulers', c_uint * NVML_SUPPORTED_VGPU_SCHEDULER_POLICY_COUNT), + ('maxTimeslice', c_uint), + ('minTimeslice', c_uint), + ('isArrModeSupported', c_uint), + ('maxFrequencyForARR', c_uint), + ('minFrequencyForARR', c_uint), + ('maxAvgFactorForARR', c_uint), + ('minAvgFactorForARR', c_uint), + ] + +class c_nvmlFBCStats_t(Structure): + _fields_ = [("sessionsCount", c_uint), + ("averageFPS", c_uint), + ("averageLatency", c_uint) + ] + +class c_nvmlFBCSession_t(_PrintableStructure): + _fields_ = [ + ('sessionId', c_uint), + ('pid', c_uint), + ('vgpuInstance', _nvmlVgpuInstance_t), + ('displayOrdinal', c_uint), + ('sessionType', c_uint), + ('sessionFlags', c_uint), + ('hMaxResolution', c_uint), + 
('vMaxResolution', c_uint), + ('hResolution', c_uint), + ('vResolution', c_uint), + ('averageFPS', c_uint), + ('averageLatency', c_uint), + ] + +NVML_DEVICE_MIG_DISABLE = 0x0 +NVML_DEVICE_MIG_ENABLE = 0x1 + +NVML_GPU_INSTANCE_PROFILE_1_SLICE = 0x0 +NVML_GPU_INSTANCE_PROFILE_2_SLICE = 0x1 +NVML_GPU_INSTANCE_PROFILE_3_SLICE = 0x2 +NVML_GPU_INSTANCE_PROFILE_4_SLICE = 0x3 +NVML_GPU_INSTANCE_PROFILE_7_SLICE = 0x4 +NVML_GPU_INSTANCE_PROFILE_8_SLICE = 0x5 +NVML_GPU_INSTANCE_PROFILE_6_SLICE = 0x6 +NVML_GPU_INSTANCE_PROFILE_1_SLICE_REV1 = 0x7 +NVML_GPU_INSTANCE_PROFILE_2_SLICE_REV1 = 0x8 +NVML_GPU_INSTANCE_PROFILE_1_SLICE_REV2 = 0x9 +NVML_GPU_INSTANCE_PROFILE_1_SLICE_GFX = 0xA +NVML_GPU_INSTANCE_PROFILE_2_SLICE_GFX = 0xB +NVML_GPU_INSTANCE_PROFILE_4_SLICE_GFX = 0xC +NVML_GPU_INSTANCE_PROFILE_COUNT = 0xD + +class c_nvmlGpuInstancePlacement_t(Structure): + _fields_ = [("start", c_uint), + ("size", c_uint) + ] + +class c_nvmlGpuInstanceProfileInfo_t(Structure): + _fields_ = [("id", c_uint), + ("isP2pSupported", c_uint), + ("sliceCount", c_uint), + ("instanceCount", c_uint), + ("multiprocessorCount", c_uint), + ("copyEngineCount", c_uint), + ("decoderCount", c_uint), + ("encoderCount", c_uint), + ("jpegCount", c_uint), + ("ofaCount", c_uint), + ("memorySizeMB", c_ulonglong), + ] + +nvmlGpuInstanceProfileInfo_v2 = 0x02000098 + +class c_nvmlGpuInstanceProfileInfo_v2_t(_PrintableStructure): + _fields_ = [("version", c_uint), + ("id", c_uint), + ("isP2pSupported", c_uint), + ("sliceCount", c_uint), + ("instanceCount", c_uint), + ("multiprocessorCount", c_uint), + ("copyEngineCount", c_uint), + ("decoderCount", c_uint), + ("encoderCount", c_uint), + ("jpegCount", c_uint), + ("ofaCount", c_uint), + ("memorySizeMB", c_ulonglong), + ("name", c_char * NVML_DEVICE_NAME_V2_BUFFER_SIZE) + ] + + def __init__(self): + super(c_nvmlGpuInstanceProfileInfo_v2_t, self).__init__(version=nvmlGpuInstanceProfileInfo_v2) + +class c_nvmlGpuInstanceInfo_t(Structure): + _fields_ = [("device", 
c_nvmlDevice_t), + ("id", c_uint), + ("profileId", c_uint), + ("placement", c_nvmlGpuInstancePlacement_t) + ] + +class struct_c_nvmlGpuInstance_t(Structure): + pass # opaque handle +c_nvmlGpuInstance_t = POINTER(struct_c_nvmlGpuInstance_t) + +NVML_COMPUTE_INSTANCE_PROFILE_1_SLICE = 0x0 +NVML_COMPUTE_INSTANCE_PROFILE_2_SLICE = 0x1 +NVML_COMPUTE_INSTANCE_PROFILE_3_SLICE = 0x2 +NVML_COMPUTE_INSTANCE_PROFILE_4_SLICE = 0x3 +NVML_COMPUTE_INSTANCE_PROFILE_7_SLICE = 0x4 +NVML_COMPUTE_INSTANCE_PROFILE_8_SLICE = 0x5 +NVML_COMPUTE_INSTANCE_PROFILE_6_SLICE = 0x6 +NVML_COMPUTE_INSTANCE_PROFILE_1_SLICE_REV1 = 0x7 +NVML_COMPUTE_INSTANCE_PROFILE_COUNT = 0x8 + +NVML_COMPUTE_INSTANCE_ENGINE_PROFILE_SHARED = 0x0 +NVML_COMPUTE_INSTANCE_ENGINE_PROFILE_COUNT = 0x1 + +class c_nvmlComputeInstancePlacement_t(Structure): + _fields_ = [("start", c_uint), + ("size", c_uint) + ] + +class c_nvmlComputeInstanceProfileInfo_t(Structure): + _fields_ = [("id", c_uint), + ("sliceCount", c_uint), + ("instanceCount", c_uint), + ("multiprocessorCount", c_uint), + ("sharedCopyEngineCount", c_uint), + ("sharedDecoderCount", c_uint), + ("sharedEncoderCount", c_uint), + ("sharedJpegCount", c_uint), + ("sharedOfaCount", c_uint) + ] + +nvmlComputeInstanceProfileInfo_v2 = 0x02000088 + +class c_nvmlComputeInstanceProfileInfo_v2_t(_PrintableStructure): + _fields_ = [("version", c_uint), + ("id", c_uint), + ("sliceCount", c_uint), + ("instanceCount", c_uint), + ("multiprocessorCount", c_uint), + ("sharedCopyEngineCount", c_uint), + ("sharedDecoderCount", c_uint), + ("sharedEncoderCount", c_uint), + ("sharedJpegCount", c_uint), + ("sharedOfaCount", c_uint), + ("name", c_char * NVML_DEVICE_NAME_V2_BUFFER_SIZE) + ] + + def __init__(self): + super(c_nvmlComputeInstanceProfileInfo_v2_t, self).__init__(version=nvmlComputeInstanceProfileInfo_v2) + +class c_nvmlComputeInstanceInfo_t(Structure): + _fields_ = [("device", c_nvmlDevice_t), + ("gpuInstance", c_nvmlGpuInstance_t), + ("id", c_uint), + ("profileId", c_uint), + 
("placement", c_nvmlComputeInstancePlacement_t) + ] + +NVML_MAX_GPU_UTILIZATIONS = 8 +NVML_GPU_UTILIZATION_DOMAIN_GPU = 0 +NVML_GPU_UTILIZATION_DOMAIN_FB = 1 +NVML_GPU_UTILIZATION_DOMAIN_VID = 2 +NVML_GPU_UTILIZATION_DOMAIN_BUS = 3 +class c_nvmlGpuDynamicPstatesUtilization_t(Structure): + _fields_ = [("bIsPresent", c_uint, 1), + ("percentage", c_uint), + ("incThreshold", c_uint), + ("decThreshold", c_uint)] +class c_nvmlGpuDynamicPstatesInfo_t(Structure): + _fields_ = [("flags", c_uint), + ("utilization", c_nvmlGpuDynamicPstatesUtilization_t * NVML_MAX_GPU_UTILIZATIONS)] + +NVML_MAX_THERMAL_SENSORS_PER_GPU = 3 + +NVML_THERMAL_TARGET_NONE = 0 +NVML_THERMAL_TARGET_GPU = 1 +NVML_THERMAL_TARGET_MEMORY = 2 +NVML_THERMAL_TARGET_POWER_SUPPLY = 4 +NVML_THERMAL_TARGET_BOARD = 8 +NVML_THERMAL_TARGET_VCD_BOARD = 9 +NVML_THERMAL_TARGET_VCD_INLET = 10 +NVML_THERMAL_TARGET_VCD_OUTLET = 11 +NVML_THERMAL_TARGET_ALL = 15 +NVML_THERMAL_TARGET_UNKNOWN = -1 + +NVML_THERMAL_CONTROLLER_NONE = 0 +NVML_THERMAL_CONTROLLER_GPU_INTERNAL = 1 +NVML_THERMAL_CONTROLLER_ADM1032 = 2 +NVML_THERMAL_CONTROLLER_ADT7461 = 3 +NVML_THERMAL_CONTROLLER_MAX6649 = 4 +NVML_THERMAL_CONTROLLER_MAX1617 = 5 +NVML_THERMAL_CONTROLLER_LM99 = 6 +NVML_THERMAL_CONTROLLER_LM89 = 7 +NVML_THERMAL_CONTROLLER_LM64 = 8 +NVML_THERMAL_CONTROLLER_G781 = 9 +NVML_THERMAL_CONTROLLER_ADT7473 = 10 +NVML_THERMAL_CONTROLLER_SBMAX6649 = 11 +NVML_THERMAL_CONTROLLER_VBIOSEVT = 12 +NVML_THERMAL_CONTROLLER_OS = 13 +NVML_THERMAL_CONTROLLER_NVSYSCON_CANOAS = 14 +NVML_THERMAL_CONTROLLER_NVSYSCON_E551 = 15 +NVML_THERMAL_CONTROLLER_MAX6649R = 16 +NVML_THERMAL_CONTROLLER_ADT7473S = 17 +NVML_THERMAL_CONTROLLER_UNKNOWN = -1 + +class c_nvmlGpuThermalSensor_t(Structure): + _fields_ = [("controller", c_int), + ("defaultMinTemp", c_int), + ("defaultMaxTemp", c_int), + ("currentTemp", c_int), + ("target", c_int)] +class c_nvmlGpuThermalSettings_t(Structure): + _fields_ = [("count", c_uint), + ("sensor", c_nvmlGpuThermalSensor_t * 
NVML_MAX_THERMAL_SENSORS_PER_GPU)] + +_nvmlCoolerControl_t = c_uint +NVML_THERMAL_COOLER_SIGNAL_NONE = 0 +NVML_THERMAL_COOLER_SIGNAL_TOGGLE = 1 +NVML_THERMAL_COOLER_SIGNAL_VARIABLE = 2 +NVML_THERMAL_COOLER_SIGNAL_COUNT = 3 + +_nvmlCoolerTarget_t = c_uint +NVML_THERMAL_COOLER_TARGET_NONE = (1 << 0) +NVML_THERMAL_COOLER_TARGET_GPU = (1 << 1) +NVML_THERMAL_COOLER_TARGET_MEMORY = (1 << 2) +NVML_THERMAL_COOLER_TARGET_POWER_SUPPLY = (1 << 3) +NVML_THERMAL_COOLER_TARGET_GPU_RELATED = (NVML_THERMAL_COOLER_TARGET_GPU | NVML_THERMAL_COOLER_TARGET_MEMORY | NVML_THERMAL_COOLER_TARGET_POWER_SUPPLY) + +class c_nvmlCoolerInfo_t(_PrintableStructure): + _fields_ = [("version", c_uint), + ("index", c_uint), + ("coolerControlType", _nvmlCoolerControl_t), + ("coolerTarget", _nvmlCoolerTarget_t) + ] + +nvmlCoolerInfo_v1 = 0x1000010 + +def nvmlDeviceGetCoolerInfo(handle): + c_coolerInfo = c_nvmlCoolerInfo_t() + c_coolerInfo.version = nvmlCoolerInfo_v1 + c_coolerInfo.index = 0 + fn = _nvmlGetFunctionPointer("nvmlDeviceGetCoolerInfo") + ret = fn(handle, byref(c_coolerInfo)) + _nvmlCheckReturn(ret) + return [c_coolerInfo.coolerControlType, c_coolerInfo.coolerTarget] + +class struct_c_nvmlComputeInstance_t(Structure): + pass # opaque handle +c_nvmlComputeInstance_t = POINTER(struct_c_nvmlComputeInstance_t) + +class c_nvmlDeviceAttributes(Structure): + _fields_ = [("multiprocessorCount", c_uint), + ("sharedCopyEngineCount", c_uint), + ("sharedDecoderCount", c_uint), + ("sharedEncoderCount", c_uint), + ("sharedJpegCount", c_uint), + ("sharedOfaCount", c_uint), + ("gpuInstanceSliceCount", c_uint), + ("computeInstanceSliceCount", c_uint), + ("memorySizeMB", c_ulonglong), + ] + +class c_nvmlRowRemapperHistogramValues(Structure): + _fields_ = [("max", c_uint), + ("high", c_uint), + ("partial", c_uint), + ("low", c_uint), + ("none", c_uint) + ] + +NVML_GPU_CERT_CHAIN_SIZE = 0x1000 +NVML_GPU_ATTESTATION_CERT_CHAIN_SIZE = 0x1400 +NVML_CC_GPU_CEC_NONCE_SIZE = 0x20 +NVML_CC_GPU_ATTESTATION_REPORT_SIZE 
= 0x2000 +NVML_CC_GPU_CEC_ATTESTATION_REPORT_SIZE = 0x1000 +NVML_CC_CEC_ATTESTATION_REPORT_NOT_PRESENT = 0 +NVML_CC_CEC_ATTESTATION_REPORT_PRESENT = 1 + +class c_nvmlConfComputeSystemState_t(Structure): + _fields_ = [('environment', c_uint), + ('ccFeature', c_uint), + ('devToolsMode', c_uint), + ] + +nvmlSystemConfComputeSettings_v1 = 0x1000014 + +class c_nvmlSystemConfComputeSettings_v1_t(Structure): + _fields_ = [('version', c_uint), + ('environment', c_uint), + ('ccFeature', c_uint), + ('devToolsMode', c_uint), + ('multiGpuMode', c_uint), + ] + def __init__(self): + super(c_nvmlSystemConfComputeSettings_v1_t, self).__init__(version=nvmlSystemConfComputeSettings_v1) + +class c_nvmlConfComputeSystemCaps_t(Structure): + _fields_ = [('cpuCaps', c_uint), + ('gpusCaps', c_uint), + ] + +class c_nvmlConfComputeMemSizeInfo_t(Structure): + _fields_ = [('protectedMemSizeKib', c_ulonglong), + ('unprotectedMemSizeKib', c_ulonglong), + ] + +class c_nvmlConfComputeGpuCertificate_t(Structure): + _fields_ = [('certChainSize', c_uint), + ('attestationCertChainSize', c_uint), + ('certChain', c_uint8 * NVML_GPU_CERT_CHAIN_SIZE), + ('attestationCertChain', c_uint8 * NVML_GPU_ATTESTATION_CERT_CHAIN_SIZE), + ] + +class c_nvmlConfComputeGpuAttestationReport_t(Structure): + _fields_ = [('isCecAttestationReportPresent', c_uint), + ('attestationReportSize', c_uint), + ('cecAttestationReportSize', c_uint), + ('nonce', c_uint8 * NVML_CC_GPU_CEC_NONCE_SIZE), + ('attestationReport', c_uint8 * NVML_CC_GPU_ATTESTATION_REPORT_SIZE), + ('cecAttestationReport', c_uint8 * NVML_CC_GPU_CEC_ATTESTATION_REPORT_SIZE), + ] + +class c_nvmlConfComputeSetKeyRotationThresholdInfo_t(Structure): + _fields_ = [('version', c_uint), + ('maxAttackerAdvantage', c_ulong), + ] +ConfComputeSetKeyRotationThresholdInfo_v1 = 0x1000010 + +class c_nvmlConfComputeGetKeyRotationThresholdInfo_t(Structure): + _fields_ = [('version', c_uint), + ('attackerAdvantage', c_ulong), + ] +ConfComputeGetKeyRotationThresholdInfo_v1 = 
0x1000010 + + +## string/bytes conversion for ease of use +def convertStrBytes(func): + ''' + In python 3, strings are unicode instead of bytes, and need to be converted for ctypes + Args from caller: (1, 'string', <__main__.c_nvmlDevice_t at 0xFFFFFFFF>) + Args passed to function: (1, b'string', <__main__.c_nvmlDevice_t at 0xFFFFFFFF)> + ---- + Returned from function: b'returned string' + Returned to caller: 'returned string' + ''' + @wraps(func) + def wrapper(*args, **kwargs): + # encoding a str returns bytes in python 2 and 3 + args = [arg.encode() if isinstance(arg, str) else arg for arg in args] + res = func(*args, **kwargs) + # In python 2, str and bytes are the same + # In python 3, str is unicode and should be decoded. + # Ctypes handles most conversions, this only effects c_char and char arrays. + if isinstance(res, bytes): + if isinstance(res, str): + return res + return res.decode() + return res + + if sys.version_info >= (3,): + return wrapper + return func + +def throwOnVersionMismatch(func): + @wraps(func) + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except NVMLError_FunctionNotFound: + raise NVMLLibraryMismatchError("Unversioned function called and the " + "pyNVML version does not match the NVML lib version. 
" + "Either use matching pyNVML and NVML lib versions or " + "use a versioned function such as " + func.__name__ + "_v2") + return wrapper + +## C function wrappers ## +def nvmlInitWithFlags(flags): + _LoadNvmlLibrary() + + # + # Initialize the library + # + fn = _nvmlGetFunctionPointer("nvmlInitWithFlags") + ret = fn(flags) + _nvmlCheckReturn(ret) + + # Atomically update refcount + global _nvmlLib_refcount + libLoadLock.acquire() + _nvmlLib_refcount += 1 + libLoadLock.release() + return None + +def nvmlInit(): + nvmlInitWithFlags(0) + return None + +def _LoadNvmlLibrary(): + ''' + Load the library if it isn't loaded already + ''' + global nvmlLib + + if (nvmlLib == None): + # lock to ensure only one caller loads the library + libLoadLock.acquire() + + try: + # ensure the library still isn't loaded + if (nvmlLib == None): + try: + if (sys.platform[:3] == "win"): + # cdecl calling convention + try: + # Check for nvml.dll in System32 first for DCH drivers + nvmlLib = CDLL(os.path.join(os.getenv("WINDIR", "C:/Windows"), "System32/nvml.dll")) + except OSError as ose: + # If nvml.dll is not found in System32, it should be in ProgramFiles + # load nvml.dll from %ProgramFiles%/NVIDIA Corporation/NVSMI/nvml.dll + nvmlLib = CDLL(os.path.join(os.getenv("ProgramFiles", "C:/Program Files"), "NVIDIA Corporation/NVSMI/nvml.dll")) + else: + # assume linux + nvmlLib = CDLL("libnvidia-ml.so.1") + except OSError as ose: + _nvmlCheckReturn(NVML_ERROR_LIBRARY_NOT_FOUND) + if (nvmlLib == None): + _nvmlCheckReturn(NVML_ERROR_LIBRARY_NOT_FOUND) + finally: + # lock is always freed + libLoadLock.release() + +def nvmlShutdown(): + # + # Leave the library loaded, but shutdown the interface + # + fn = _nvmlGetFunctionPointer("nvmlShutdown") + ret = fn() + _nvmlCheckReturn(ret) + + # Atomically update refcount + global _nvmlLib_refcount + libLoadLock.acquire() + if (0 < _nvmlLib_refcount): + _nvmlLib_refcount -= 1 + libLoadLock.release() + return None + +# Added in 2.285 +@convertStrBytes +def 
nvmlErrorString(result): + fn = _nvmlGetFunctionPointer("nvmlErrorString") + fn.restype = c_char_p # otherwise return is an int + ret = fn(result) + return ret + +# Added in 2.285 +@convertStrBytes +def nvmlSystemGetNVMLVersion(): + c_version = create_string_buffer(NVML_SYSTEM_NVML_VERSION_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlSystemGetNVMLVersion") + ret = fn(c_version, c_uint(NVML_SYSTEM_NVML_VERSION_BUFFER_SIZE)) + _nvmlCheckReturn(ret) + return c_version.value + +def nvmlSystemGetCudaDriverVersion(): + c_cuda_version = c_int() + fn = _nvmlGetFunctionPointer("nvmlSystemGetCudaDriverVersion") + ret = fn(byref(c_cuda_version)) + _nvmlCheckReturn(ret) + return c_cuda_version.value + +def nvmlSystemGetCudaDriverVersion_v2(): + c_cuda_version = c_int() + fn = _nvmlGetFunctionPointer("nvmlSystemGetCudaDriverVersion_v2") + ret = fn(byref(c_cuda_version)) + _nvmlCheckReturn(ret) + return c_cuda_version.value + +# Added in 2.285 +@convertStrBytes +def nvmlSystemGetProcessName(pid): + c_name = create_string_buffer(1024) + fn = _nvmlGetFunctionPointer("nvmlSystemGetProcessName") + ret = fn(c_uint(pid), c_name, c_uint(1024)) + _nvmlCheckReturn(ret) + return c_name.value + +@convertStrBytes +def nvmlSystemGetDriverVersion(): + c_version = create_string_buffer(NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlSystemGetDriverVersion") + ret = fn(c_version, c_uint(NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE)) + _nvmlCheckReturn(ret) + return c_version.value + +# Added in 2.285 +def nvmlSystemGetHicVersion(): + c_count = c_uint(0) + hics = None + fn = _nvmlGetFunctionPointer("nvmlSystemGetHicVersion") + + # get the count + ret = fn(byref(c_count), None) + + # this should only fail with insufficient size + if ((ret != NVML_SUCCESS) and + (ret != NVML_ERROR_INSUFFICIENT_SIZE)): + raise NVMLError(ret) + + # If there are no hics + if (c_count.value == 0): + return [] + + hic_array = c_nvmlHwbcEntry_t * c_count.value + hics = hic_array() + ret = 
fn(byref(c_count), hics) + _nvmlCheckReturn(ret) + return hics + +def nvmlSystemGetDriverBranch(): + c_branchInfo = c_nvmlSystemDriverBranchInfo_v1_t(0) + c_branchInfo.version = SystemDriverBranchInfo_v1 + fn = _nvmlGetFunctionPointer("nvmlSystemGetDriverBranch") + ret = fn(byref(c_branchInfo), c_uint(NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE)) + _nvmlCheckReturn(ret) + return c_branchInfo + +## Unit get functions +def nvmlUnitGetCount(): + c_count = c_uint() + fn = _nvmlGetFunctionPointer("nvmlUnitGetCount") + ret = fn(byref(c_count)) + _nvmlCheckReturn(ret) + return c_count.value + +def nvmlUnitGetHandleByIndex(index): + c_index = c_uint(index) + unit = c_nvmlUnit_t() + fn = _nvmlGetFunctionPointer("nvmlUnitGetHandleByIndex") + ret = fn(c_index, byref(unit)) + _nvmlCheckReturn(ret) + return unit + +def nvmlUnitGetUnitInfo(unit): + c_info = c_nvmlUnitInfo_t() + fn = _nvmlGetFunctionPointer("nvmlUnitGetUnitInfo") + ret = fn(unit, byref(c_info)) + _nvmlCheckReturn(ret) + return c_info + +def nvmlUnitGetLedState(unit): + c_state = c_nvmlLedState_t() + fn = _nvmlGetFunctionPointer("nvmlUnitGetLedState") + ret = fn(unit, byref(c_state)) + _nvmlCheckReturn(ret) + return c_state + +def nvmlUnitGetPsuInfo(unit): + c_info = c_nvmlPSUInfo_t() + fn = _nvmlGetFunctionPointer("nvmlUnitGetPsuInfo") + ret = fn(unit, byref(c_info)) + _nvmlCheckReturn(ret) + return c_info + +def nvmlUnitGetTemperature(unit, type): + c_temp = c_uint() + fn = _nvmlGetFunctionPointer("nvmlUnitGetTemperature") + ret = fn(unit, c_uint(type), byref(c_temp)) + _nvmlCheckReturn(ret) + return c_temp.value + +def nvmlUnitGetFanSpeedInfo(unit): + c_speeds = c_nvmlUnitFanSpeeds_t() + fn = _nvmlGetFunctionPointer("nvmlUnitGetFanSpeedInfo") + ret = fn(unit, byref(c_speeds)) + _nvmlCheckReturn(ret) + return c_speeds + +# added to API +def nvmlUnitGetDeviceCount(unit): + c_count = c_uint(0) + # query the unit to determine device count + fn = _nvmlGetFunctionPointer("nvmlUnitGetDevices") + ret = fn(unit, 
byref(c_count), None) + if (ret == NVML_ERROR_INSUFFICIENT_SIZE): + ret = NVML_SUCCESS + _nvmlCheckReturn(ret) + return c_count.value + +def nvmlUnitGetDevices(unit): + c_count = c_uint(nvmlUnitGetDeviceCount(unit)) + device_array = c_nvmlDevice_t * c_count.value + c_devices = device_array() + fn = _nvmlGetFunctionPointer("nvmlUnitGetDevices") + ret = fn(unit, byref(c_count), c_devices) + _nvmlCheckReturn(ret) + return c_devices + +## Device get functions +def nvmlDeviceGetCount(): + c_count = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetCount_v2") + ret = fn(byref(c_count)) + _nvmlCheckReturn(ret) + return c_count.value + +def nvmlDeviceGetHandleByIndex(index): + c_index = c_uint(index) + device = c_nvmlDevice_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetHandleByIndex_v2") + ret = fn(c_index, byref(device)) + _nvmlCheckReturn(ret) + return device + +@convertStrBytes +def nvmlDeviceGetHandleBySerial(serial): + c_serial = c_char_p(serial) + device = c_nvmlDevice_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetHandleBySerial") + ret = fn(c_serial, byref(device)) + _nvmlCheckReturn(ret) + return device + +@convertStrBytes +def nvmlDeviceGetHandleByUUID(uuid): + c_uuid = c_char_p(uuid) + device = c_nvmlDevice_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetHandleByUUID") + ret = fn(c_uuid, byref(device)) + _nvmlCheckReturn(ret) + return device + +@convertStrBytes +def nvmlDeviceGetHandleByPciBusId(pciBusId): + c_busId = c_char_p(pciBusId) + device = c_nvmlDevice_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetHandleByPciBusId_v2") + ret = fn(c_busId, byref(device)) + _nvmlCheckReturn(ret) + return device + +@convertStrBytes +def nvmlDeviceGetName(handle): + c_name = create_string_buffer(NVML_DEVICE_NAME_V2_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetName") + ret = fn(handle, c_name, c_uint(NVML_DEVICE_NAME_V2_BUFFER_SIZE)) + _nvmlCheckReturn(ret) + return c_name.value + +class c_nvmlDevicePerfModes_v1_t(_PrintableStructure): + _fields_ = [ 
+ ('version', c_uint), + ('str', c_char * NVML_PERF_MODES_BUFFER_SIZE), + ] + +nvmlDevicePerfModes_v1 = 0x1000804 + +@convertStrBytes +def nvmlDeviceGetPerformanceModes(handle): + perfModes = c_nvmlDevicePerfModes_v1_t() + perfModes.version = nvmlDevicePerfModes_v1 + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPerformanceModes") + ret = fn(handle, byref(perfModes)) + _nvmlCheckReturn(ret) + return perfModes.str + +class c_nvmlDeviceCurrentClockFreqs_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('str', c_char * NVML_PERF_MODES_BUFFER_SIZE), + ] + +nvmlDeviceCurrentClockFreqs_v1 = 0x1000804 + +@convertStrBytes +def nvmlDeviceGetCurrentClockFreqs(handle): + currentClockFreqs = c_nvmlDeviceCurrentClockFreqs_v1_t() + currentClockFreqs.version = nvmlDeviceCurrentClockFreqs_v1 + fn = _nvmlGetFunctionPointer("nvmlDeviceGetCurrentClockFreqs") + ret = fn(handle, byref(currentClockFreqs)) + _nvmlCheckReturn(ret) + return currentClockFreqs.str + +def nvmlDeviceGetBoardId(handle): + c_id = c_uint(); + fn = _nvmlGetFunctionPointer("nvmlDeviceGetBoardId") + ret = fn(handle, byref(c_id)) + _nvmlCheckReturn(ret) + return c_id.value + +def nvmlDeviceGetMultiGpuBoard(handle): + c_multiGpu = c_uint(); + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMultiGpuBoard") + ret = fn(handle, byref(c_multiGpu)) + _nvmlCheckReturn(ret) + return c_multiGpu.value + +def nvmlDeviceGetBrand(handle): + c_type = _nvmlBrandType_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetBrand") + ret = fn(handle, byref(c_type)) + _nvmlCheckReturn(ret) + return c_type.value + +def nvmlDeviceGetC2cModeInfoV1(handle): + c_info = c_nvmlC2cModeInfo_v1_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetC2cModeInfoV") + ret = fn(handle, byref(c_info)) + _nvmlCheckReturn(ret) + return c_info + +def nvmlDeviceGetC2cModeInfoV(handle): + return nvmlDeviceGetC2cModeInfoV1(handle) + +@convertStrBytes +def nvmlDeviceGetBoardPartNumber(handle): + c_part_number = 
create_string_buffer(NVML_DEVICE_PART_NUMBER_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetBoardPartNumber") + ret = fn(handle, c_part_number, c_uint(NVML_DEVICE_PART_NUMBER_BUFFER_SIZE)) + _nvmlCheckReturn(ret) + return c_part_number.value + +@convertStrBytes +def nvmlDeviceGetSerial(handle): + c_serial = create_string_buffer(NVML_DEVICE_SERIAL_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetSerial") + ret = fn(handle, c_serial, c_uint(NVML_DEVICE_SERIAL_BUFFER_SIZE)) + _nvmlCheckReturn(ret) + return c_serial.value + +def nvmlDeviceGetModuleId(handle, moduleId=c_uint()): + isReference = type(moduleId) is not c_uint + moduleIdRef = moduleId if isReference else byref(moduleId) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetModuleId") + ret = fn(handle, moduleIdRef) + if isReference: + return ret + else: + _nvmlCheckReturn(ret) + return moduleId.value + +def nvmlDeviceGetMemoryAffinity(handle, nodeSetSize, scope): + affinity_array = c_ulonglong * nodeSetSize + c_affinity = affinity_array() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMemoryAffinity") + ret = fn(handle, nodeSetSize, byref(c_affinity), _nvmlAffinityScope_t(scope)) + _nvmlCheckReturn(ret) + return c_affinity + +def nvmlDeviceGetCpuAffinityWithinScope(handle, cpuSetSize, scope): + affinity_array = c_ulonglong * cpuSetSize + c_affinity = affinity_array() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetCpuAffinityWithinScope") + ret = fn(handle, cpuSetSize, byref(c_affinity), _nvmlAffinityScope_t(scope)) + _nvmlCheckReturn(ret) + return c_affinity + +def nvmlDeviceGetCpuAffinity(handle, cpuSetSize): + affinity_array = c_ulonglong * cpuSetSize + c_affinity = affinity_array() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetCpuAffinity") + ret = fn(handle, cpuSetSize, byref(c_affinity)) + _nvmlCheckReturn(ret) + return c_affinity + +def nvmlDeviceSetCpuAffinity(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetCpuAffinity") + ret = fn(handle) + _nvmlCheckReturn(ret) + return None + +def 
nvmlDeviceClearCpuAffinity(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceClearCpuAffinity") + ret = fn(handle) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceGetNumaNodeId(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNumaNodeId") + node = c_int() + ret = fn(handle, byref(node)) + _nvmlCheckReturn(ret) + return node.value + +def nvmlDeviceGetMinorNumber(handle): + c_minor_number = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMinorNumber") + ret = fn(handle, byref(c_minor_number)) + _nvmlCheckReturn(ret) + return c_minor_number.value + +@convertStrBytes +def nvmlDeviceGetUUID(handle): + c_uuid = create_string_buffer(NVML_DEVICE_UUID_V2_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetUUID") + ret = fn(handle, c_uuid, c_uint(NVML_DEVICE_UUID_V2_BUFFER_SIZE)) + _nvmlCheckReturn(ret) + return c_uuid.value + +@convertStrBytes +def nvmlDeviceGetInforomVersion(handle, infoRomObject): + c_version = create_string_buffer(NVML_DEVICE_INFOROM_VERSION_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetInforomVersion") + ret = fn(handle, _nvmlInforomObject_t(infoRomObject), + c_version, c_uint(NVML_DEVICE_INFOROM_VERSION_BUFFER_SIZE)) + _nvmlCheckReturn(ret) + return c_version.value + +# Added in 4.304 +@convertStrBytes +def nvmlDeviceGetInforomImageVersion(handle): + c_version = create_string_buffer(NVML_DEVICE_INFOROM_VERSION_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetInforomImageVersion") + ret = fn(handle, c_version, c_uint(NVML_DEVICE_INFOROM_VERSION_BUFFER_SIZE)) + _nvmlCheckReturn(ret) + return c_version.value + +# Added in 4.304 +def nvmlDeviceGetInforomConfigurationChecksum(handle): + c_checksum = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetInforomConfigurationChecksum") + ret = fn(handle, byref(c_checksum)) + _nvmlCheckReturn(ret) + return c_checksum.value + +# Added in 4.304 +def nvmlDeviceValidateInforom(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceValidateInforom") + ret = fn(handle) + 
_nvmlCheckReturn(ret) + return None + +def nvmlDeviceGetLastBBXFlushTime(handle): + c_timestamp = c_ulonglong() + c_durationUs = c_ulong() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetLastBBXFlushTime") + ret = fn(handle, byref(c_timestamp), byref(c_durationUs)) + _nvmlCheckReturn(ret) + return [c_timestamp.value, c_durationUs.value] + +def nvmlDeviceGetDisplayMode(handle): + c_mode = _nvmlEnableState_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetDisplayMode") + ret = fn(handle, byref(c_mode)) + _nvmlCheckReturn(ret) + return c_mode.value + +def nvmlDeviceGetDisplayActive(handle): + c_mode = _nvmlEnableState_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetDisplayActive") + ret = fn(handle, byref(c_mode)) + _nvmlCheckReturn(ret) + return c_mode.value + + +def nvmlDeviceGetPersistenceMode(handle): + c_state = _nvmlEnableState_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPersistenceMode") + ret = fn(handle, byref(c_state)) + _nvmlCheckReturn(ret) + return c_state.value + +def nvmlDeviceGetPciInfoExt(handle, c_info): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPciInfoExt") + ret = fn(handle, c_info) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceGetPciInfo_v3(handle): + c_info = nvmlPciInfo_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPciInfo_v3") + ret = fn(handle, byref(c_info)) + _nvmlCheckReturn(ret) + return c_info + +def nvmlDeviceGetPciInfo(handle): + return nvmlDeviceGetPciInfo_v3(handle) + +def nvmlDeviceGetClockInfo(handle, type): + c_clock = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetClockInfo") + ret = fn(handle, _nvmlClockType_t(type), byref(c_clock)) + _nvmlCheckReturn(ret) + return c_clock.value + +# Added in 2.285 +def nvmlDeviceGetMaxClockInfo(handle, type): + c_clock = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMaxClockInfo") + ret = fn(handle, _nvmlClockType_t(type), byref(c_clock)) + _nvmlCheckReturn(ret) + return c_clock.value + +# Added in 4.304 +def nvmlDeviceGetApplicationsClock(handle, type): + c_clock 
= c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetApplicationsClock") + ret = fn(handle, _nvmlClockType_t(type), byref(c_clock)) + _nvmlCheckReturn(ret) + return c_clock.value + +def nvmlDeviceGetMaxCustomerBoostClock(handle, type): + c_clock = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMaxCustomerBoostClock") + ret = fn(handle, _nvmlClockType_t(type), byref(c_clock)) + _nvmlCheckReturn(ret) + return c_clock.value + +def nvmlDeviceGetClock(handle, type, id): + c_clock = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetClock") + ret = fn(handle, _nvmlClockType_t(type), _nvmlClockId_t(id), byref(c_clock)) + _nvmlCheckReturn(ret) + return c_clock.value + +# Added in 5.319 +def nvmlDeviceGetDefaultApplicationsClock(handle, type): + c_clock = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetDefaultApplicationsClock") + ret = fn(handle, _nvmlClockType_t(type), byref(c_clock)) + _nvmlCheckReturn(ret) + return c_clock.value + +# Added in 4.304 +def nvmlDeviceGetSupportedMemoryClocks(handle): + # first call to get the size + c_count = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetSupportedMemoryClocks") + ret = fn(handle, byref(c_count), None) + + if (ret == NVML_SUCCESS): + # special case, no clocks + return [] + elif (ret == NVML_ERROR_INSUFFICIENT_SIZE): + # typical case + clocks_array = c_uint * c_count.value + c_clocks = clocks_array() + + # make the call again + ret = fn(handle, byref(c_count), c_clocks) + _nvmlCheckReturn(ret) + + procs = [] + for i in range(c_count.value): + procs.append(c_clocks[i]) + + return procs + else: + # error case + raise NVMLError(ret) + +# Added in 4.304 +def nvmlDeviceGetSupportedGraphicsClocks(handle, memoryClockMHz): + # first call to get the size + c_count = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetSupportedGraphicsClocks") + ret = fn(handle, c_uint(memoryClockMHz), byref(c_count), None) + + if (ret == NVML_SUCCESS): + # special case, no clocks + return [] + elif (ret == 
NVML_ERROR_INSUFFICIENT_SIZE): + # typical case + clocks_array = c_uint * c_count.value + c_clocks = clocks_array() + + # make the call again + ret = fn(handle, c_uint(memoryClockMHz), byref(c_count), c_clocks) + _nvmlCheckReturn(ret) + + procs = [] + for i in range(c_count.value): + procs.append(c_clocks[i]) + + return procs + else: + # error case + raise NVMLError(ret) + +def nvmlDeviceGetFanSpeed(handle): + c_speed = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetFanSpeed") + ret = fn(handle, byref(c_speed)) + _nvmlCheckReturn(ret) + return c_speed.value + +def nvmlDeviceGetFanSpeed_v2(handle, fan): + c_speed = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetFanSpeed_v2") + ret = fn(handle, fan, byref(c_speed)) + _nvmlCheckReturn(ret) + return c_speed.value + +class c_nvmlFanSpeedInfo_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('fan', c_uint), + ('speed', c_uint), + ] + +nvmlFanSpeedInfo_v1 = 0x100000C + +def nvmlDeviceGetFanSpeedRPM(handle): + c_fanSpeed = c_nvmlFanSpeedInfo_t() + c_fanSpeed.fan = 0 + c_fanSpeed.version = nvmlFanSpeedInfo_v1 + fn = _nvmlGetFunctionPointer("nvmlDeviceGetFanSpeedRPM") + ret = fn(handle, byref(c_fanSpeed)) + _nvmlCheckReturn(ret) + return c_fanSpeed.speed + +def nvmlDeviceGetTargetFanSpeed(handle, fan): + c_speed = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetTargetFanSpeed") + ret = fn(handle, fan, byref(c_speed)) + _nvmlCheckReturn(ret) + return c_speed.value + +def nvmlDeviceGetNumFans(device): + c_numFans = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNumFans") + ret = fn(device, byref(c_numFans)) + _nvmlCheckReturn(ret) + return c_numFans.value + +def nvmlDeviceSetDefaultFanSpeed_v2(handle, index): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetDefaultFanSpeed_v2"); + ret = fn(handle, index) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlDeviceGetMinMaxFanSpeed(handle, minSpeed=c_uint(), maxSpeed=c_uint()): + isReference = (type(minSpeed) is not c_uint) or 
(type(maxSpeed) is not c_uint) + minSpeedRef = minSpeed if isReference else byref(minSpeed) + maxSpeedRef = maxSpeed if isReference else byref(maxSpeed) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMinMaxFanSpeed") + ret = fn(handle, minSpeedRef, maxSpeedRef) + _nvmlCheckReturn(ret) + return NVML_SUCCESS if isReference else [minSpeed.value, maxSpeed.value] + +def nvmlDeviceGetFanControlPolicy_v2(handle, fan, fanControlPolicy=c_uint()): + isReference = type(fanControlPolicy) is not c_uint + fanControlPolicyRef = fanControlPolicy if isReference else byref(fanControlPolicy) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetFanControlPolicy_v2") + ret = fn(handle, fan, fanControlPolicyRef) + _nvmlCheckReturn(ret) + return NVML_SUCCESS if isReference else fanControlPolicy.value + +def nvmlDeviceSetFanControlPolicy(handle, fan, fanControlPolicy): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetFanControlPolicy") + ret = fn(handle, fan, _nvmlFanControlPolicy_t(fanControlPolicy)) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +class c_nvmlTemperature_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('sensorType', _nvmlTemperatureSensors_t), + ('temperature', c_int), + ] +nvmlTemperature_v1 = 0x100000C + +def nvmlDeviceGetTemperatureV1(handle, sensor): + c_temp = c_nvmlTemperature_v1_t() + c_temp.version = nvmlTemperature_v1 + c_temp.sensorType = _nvmlTemperatureSensors_t(sensor) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetTemperatureV") + ret = fn(handle, byref(c_temp)) + _nvmlCheckReturn(ret) + return c_temp.temperature + +def nvmlDeviceGetTemperatureV(handle, sensor, version=nvmlTemperature_v1): + if version == nvmlTemperature_v1: + return nvmlDeviceGetTemperatureV1(handle, sensor) + else: + raise NVMLError(NVML_ERROR_ARGUMENT_VERSION_MISMATCH) + +# DEPRECATED use nvmlDeviceGetTemperatureV instead +def nvmlDeviceGetTemperature(handle, sensor): + c_temp = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetTemperature") + ret = fn(handle, 
_nvmlTemperatureSensors_t(sensor), byref(c_temp)) + _nvmlCheckReturn(ret) + return c_temp.value + +def nvmlDeviceGetTemperatureThreshold(handle, threshold): + c_temp = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetTemperatureThreshold") + ret = fn(handle, _nvmlTemperatureThresholds_t(threshold), byref(c_temp)) + _nvmlCheckReturn(ret) + return c_temp.value + +def nvmlDeviceSetTemperatureThreshold(handle, threshold, temp): + c_temp = c_uint() + c_temp.value = temp + fn = _nvmlGetFunctionPointer("nvmlDeviceSetTemperatureThreshold") + ret = fn(handle, _nvmlTemperatureThresholds_t(threshold), byref(c_temp)) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceGetMarginTemperature(handle): + c_marginTempInfo = c_nvmlMarginTemperature_v1_t() + c_marginTempInfo.version = nvmlMarginTemperature_v1 + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMarginTemperature") + ret = fn(handle, byref(c_marginTempInfo)) + _nvmlCheckReturn(ret) + return c_marginTempInfo.marginTemperature + +# DEPRECATED use nvmlDeviceGetPerformanceState +def nvmlDeviceGetPowerState(handle): + c_pstate = _nvmlPstates_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPowerState") + ret = fn(handle, byref(c_pstate)) + _nvmlCheckReturn(ret) + return c_pstate.value + +def nvmlDeviceGetPerformanceState(handle): + c_pstate = _nvmlPstates_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPerformanceState") + ret = fn(handle, byref(c_pstate)) + _nvmlCheckReturn(ret) + return c_pstate.value + +def nvmlDeviceGetPowerManagementMode(handle): + c_pcapMode = _nvmlEnableState_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPowerManagementMode") + ret = fn(handle, byref(c_pcapMode)) + _nvmlCheckReturn(ret) + return c_pcapMode.value + +def nvmlDeviceGetPowerManagementLimit(handle): + c_limit = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPowerManagementLimit") + ret = fn(handle, byref(c_limit)) + _nvmlCheckReturn(ret) + return c_limit.value + +# Added in 4.304 +def 
nvmlDeviceGetPowerManagementLimitConstraints(handle): + c_minLimit = c_uint() + c_maxLimit = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPowerManagementLimitConstraints") + ret = fn(handle, byref(c_minLimit), byref(c_maxLimit)) + _nvmlCheckReturn(ret) + return [c_minLimit.value, c_maxLimit.value] + +# Added in 4.304 +def nvmlDeviceGetPowerManagementDefaultLimit(handle): + c_limit = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPowerManagementDefaultLimit") + ret = fn(handle, byref(c_limit)) + _nvmlCheckReturn(ret) + return c_limit.value + + +# Added in 331 +def nvmlDeviceGetEnforcedPowerLimit(handle): + c_limit = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetEnforcedPowerLimit") + ret = fn(handle, byref(c_limit)) + _nvmlCheckReturn(ret) + return c_limit.value + +def nvmlDeviceGetPowerUsage(handle): + c_watts = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPowerUsage") + ret = fn(handle, byref(c_watts)) + _nvmlCheckReturn(ret) + return c_watts.value + +def nvmlDeviceGetTotalEnergyConsumption(handle): + c_millijoules = c_uint64() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetTotalEnergyConsumption") + ret = fn(handle, byref(c_millijoules)) + _nvmlCheckReturn(ret) + return c_millijoules.value + +# Added in 4.304 +def nvmlDeviceGetGpuOperationMode(handle): + c_currState = _nvmlGpuOperationMode_t() + c_pendingState = _nvmlGpuOperationMode_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpuOperationMode") + ret = fn(handle, byref(c_currState), byref(c_pendingState)) + _nvmlCheckReturn(ret) + return [c_currState.value, c_pendingState.value] + +# Added in 4.304 +def nvmlDeviceGetCurrentGpuOperationMode(handle): + return nvmlDeviceGetGpuOperationMode(handle)[0] + +# Added in 4.304 +def nvmlDeviceGetPendingGpuOperationMode(handle): + return nvmlDeviceGetGpuOperationMode(handle)[1] + +def nvmlDeviceGetMemoryInfo(handle, version=None): + if not version: + c_memory = c_nvmlMemory_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMemoryInfo") + 
else: + c_memory = c_nvmlMemory_v2_t() + c_memory.version = version + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMemoryInfo_v2") + ret = fn(handle, byref(c_memory)) + _nvmlCheckReturn(ret) + return c_memory + +def nvmlDeviceGetBAR1MemoryInfo(handle): + c_bar1_memory = c_nvmlBAR1Memory_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetBAR1MemoryInfo") + ret = fn(handle, byref(c_bar1_memory)) + _nvmlCheckReturn(ret) + return c_bar1_memory + +def nvmlDeviceGetComputeMode(handle): + c_mode = _nvmlComputeMode_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetComputeMode") + ret = fn(handle, byref(c_mode)) + _nvmlCheckReturn(ret) + return c_mode.value + +def nvmlDeviceGetCudaComputeCapability(handle): + c_major = c_int() + c_minor = c_int() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetCudaComputeCapability") + ret = fn(handle, byref(c_major), byref(c_minor)) + _nvmlCheckReturn(ret) + return (c_major.value, c_minor.value) + +def nvmlDeviceGetEccMode(handle): + c_currState = _nvmlEnableState_t() + c_pendingState = _nvmlEnableState_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetEccMode") + ret = fn(handle, byref(c_currState), byref(c_pendingState)) + _nvmlCheckReturn(ret) + return [c_currState.value, c_pendingState.value] + +# added to API +def nvmlDeviceGetCurrentEccMode(handle): + return nvmlDeviceGetEccMode(handle)[0] + +# added to API +def nvmlDeviceGetPendingEccMode(handle): + return nvmlDeviceGetEccMode(handle)[1] + +def nvmlDeviceGetDefaultEccMode(handle): + c_defaultState = _nvmlEnableState_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetDefaultEccMode") + ret = fn(handle, byref(c_defaultState)) + _nvmlCheckReturn(ret) + return [c_defaultState.value] + +def nvmlDeviceGetTotalEccErrors(handle, errorType, counterType): + c_count = c_ulonglong() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetTotalEccErrors") + ret = fn(handle, _nvmlMemoryErrorType_t(errorType), + _nvmlEccCounterType_t(counterType), byref(c_count)) + _nvmlCheckReturn(ret) + return c_count.value + +# This 
is deprecated, instead use nvmlDeviceGetMemoryErrorCounter +def nvmlDeviceGetDetailedEccErrors(handle, errorType, counterType): + c_counts = c_nvmlEccErrorCounts_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetDetailedEccErrors") + ret = fn(handle, _nvmlMemoryErrorType_t(errorType), + _nvmlEccCounterType_t(counterType), byref(c_counts)) + _nvmlCheckReturn(ret) + return c_counts + +# Added in 4.304 +def nvmlDeviceGetMemoryErrorCounter(handle, errorType, counterType, locationType): + c_count = c_ulonglong() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMemoryErrorCounter") + ret = fn(handle, + _nvmlMemoryErrorType_t(errorType), + _nvmlEccCounterType_t(counterType), + _nvmlMemoryLocation_t(locationType), + byref(c_count)) + _nvmlCheckReturn(ret) + return c_count.value + +def nvmlDeviceGetUtilizationRates(handle): + c_util = c_nvmlUtilization_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetUtilizationRates") + ret = fn(handle, byref(c_util)) + _nvmlCheckReturn(ret) + return c_util + +def nvmlDeviceGetEncoderUtilization(handle): + c_util = c_uint() + c_samplingPeriod = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetEncoderUtilization") + ret = fn(handle, byref(c_util), byref(c_samplingPeriod)) + _nvmlCheckReturn(ret) + return [c_util.value, c_samplingPeriod.value] + +def nvmlDeviceGetDecoderUtilization(handle): + c_util = c_uint() + c_samplingPeriod = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetDecoderUtilization") + ret = fn(handle, byref(c_util), byref(c_samplingPeriod)) + _nvmlCheckReturn(ret) + return [c_util.value, c_samplingPeriod.value] + +def nvmlDeviceGetJpgUtilization(handle): + c_util = c_uint() + c_samplingPeriod = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetJpgUtilization") + ret = fn(handle, byref(c_util), byref(c_samplingPeriod)) + _nvmlCheckReturn(ret) + return [c_util.value, c_samplingPeriod.value] + +def nvmlDeviceGetOfaUtilization(handle): + c_util = c_uint() + c_samplingPeriod = c_uint() + fn = 
_nvmlGetFunctionPointer("nvmlDeviceGetOfaUtilization") + ret = fn(handle, byref(c_util), byref(c_samplingPeriod)) + _nvmlCheckReturn(ret) + return [c_util.value, c_samplingPeriod.value] + +def nvmlDeviceGetPcieReplayCounter(handle): + c_replay = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPcieReplayCounter") + ret = fn(handle, byref(c_replay)) + _nvmlCheckReturn(ret) + return c_replay.value + +def nvmlDeviceGetDriverModel(handle): + c_currModel = _nvmlDriverModel_t() + c_pendingModel = _nvmlDriverModel_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetDriverModel") + ret = fn(handle, byref(c_currModel), byref(c_pendingModel)) + _nvmlCheckReturn(ret) + return [c_currModel.value, c_pendingModel.value] + +# added to API +def nvmlDeviceGetCurrentDriverModel(handle): + return nvmlDeviceGetDriverModel(handle)[0] + +# added to API +def nvmlDeviceGetPendingDriverModel(handle): + return nvmlDeviceGetDriverModel(handle)[1] + +# Added in 2.285 +@convertStrBytes +def nvmlDeviceGetVbiosVersion(handle): + c_version = create_string_buffer(NVML_DEVICE_VBIOS_VERSION_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVbiosVersion") + ret = fn(handle, c_version, c_uint(NVML_DEVICE_VBIOS_VERSION_BUFFER_SIZE)) + _nvmlCheckReturn(ret) + return c_version.value + +# Added in 2.285 +def nvmlDeviceGetComputeRunningProcesses_v2(handle): + # first call to get the size + c_count = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetComputeRunningProcesses_v2") + ret = fn(handle, byref(c_count), None) + if (ret == NVML_SUCCESS): + # special case, no running processes + return [] + elif (ret == NVML_ERROR_INSUFFICIENT_SIZE): + # typical case + # oversize the array incase more processes are created + c_count.value = c_count.value * 2 + 5 + proc_array = c_nvmlProcessInfo_v2_t * c_count.value + c_procs = proc_array() + # make the call again + ret = fn(handle, byref(c_count), c_procs) + _nvmlCheckReturn(ret) + procs = [] + for i in range(c_count.value): + # use an alternative 
struct for this object + obj = nvmlStructToFriendlyObject(c_procs[i]) + if (obj.usedGpuMemory == NVML_VALUE_NOT_AVAILABLE_ulonglong.value): + # special case for WDDM on Windows, see comment above + obj.usedGpuMemory = None + procs.append(obj) + return procs + else: + # error case + raise NVMLError(ret) + +# Added in 2.285 +def nvmlDeviceGetComputeRunningProcesses_v3(handle): + # first call to get the size + c_count = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetComputeRunningProcesses_v3") + ret = fn(handle, byref(c_count), None) + + if (ret == NVML_SUCCESS): + # special case, no running processes + return [] + elif (ret == NVML_ERROR_INSUFFICIENT_SIZE): + # typical case + # oversize the array incase more processes are created + c_count.value = c_count.value * 2 + 5 + proc_array = c_nvmlProcessInfo_v3_t * c_count.value + c_procs = proc_array() + + # make the call again + ret = fn(handle, byref(c_count), c_procs) + _nvmlCheckReturn(ret) + + procs = [] + for i in range(c_count.value): + # use an alternative struct for this object + obj = nvmlStructToFriendlyObject(c_procs[i]) + if (obj.usedGpuMemory == NVML_VALUE_NOT_AVAILABLE_ulonglong.value): + # special case for WDDM on Windows, see comment above + obj.usedGpuMemory = None + procs.append(obj) + + return procs + else: + # error case + raise NVMLError(ret) + +@throwOnVersionMismatch +def nvmlDeviceGetComputeRunningProcesses(handle): + return nvmlDeviceGetComputeRunningProcesses_v3(handle) + +def nvmlDeviceGetGraphicsRunningProcesses_v2(handle): + # first call to get the size + c_count = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGraphicsRunningProcesses_v2") + ret = fn(handle, byref(c_count), None) + if (ret == NVML_SUCCESS): + # special case, no running processes + return [] + elif (ret == NVML_ERROR_INSUFFICIENT_SIZE): + # typical case + # oversize the array incase more processes are created + c_count.value = c_count.value * 2 + 5 + proc_array = c_nvmlProcessInfo_v2_t * c_count.value + 
c_procs = proc_array() + # make the call again + ret = fn(handle, byref(c_count), c_procs) + _nvmlCheckReturn(ret) + procs = [] + for i in range(c_count.value): + # use an alternative struct for this object + obj = nvmlStructToFriendlyObject(c_procs[i]) + if (obj.usedGpuMemory == NVML_VALUE_NOT_AVAILABLE_ulonglong.value): + # special case for WDDM on Windows, see comment above + obj.usedGpuMemory = None + procs.append(obj) + return procs + else: + # error case + raise NVMLError(ret) + +def nvmlDeviceGetGraphicsRunningProcesses_v3(handle): + # first call to get the size + c_count = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGraphicsRunningProcesses_v3") + ret = fn(handle, byref(c_count), None) + + if (ret == NVML_SUCCESS): + # special case, no running processes + return [] + elif (ret == NVML_ERROR_INSUFFICIENT_SIZE): + # typical case + # oversize the array incase more processes are created + c_count.value = c_count.value * 2 + 5 + proc_array = c_nvmlProcessInfo_v3_t * c_count.value + c_procs = proc_array() + + # make the call again + ret = fn(handle, byref(c_count), c_procs) + _nvmlCheckReturn(ret) + + procs = [] + for i in range(c_count.value): + # use an alternative struct for this object + obj = nvmlStructToFriendlyObject(c_procs[i]) + if (obj.usedGpuMemory == NVML_VALUE_NOT_AVAILABLE_ulonglong.value): + # special case for WDDM on Windows, see comment above + obj.usedGpuMemory = None + procs.append(obj) + + return procs + else: + # error case + raise NVMLError(ret) + +@throwOnVersionMismatch +def nvmlDeviceGetGraphicsRunningProcesses(handle): + return nvmlDeviceGetGraphicsRunningProcesses_v3(handle) + +@throwOnVersionMismatch +def nvmlDeviceGetMPSComputeRunningProcesses(handle): + return nvmlDeviceGetMPSComputeRunningProcesses_v3(handle) + +def nvmlDeviceGetMPSComputeRunningProcesses_v2(handle): + # first call to get the size + c_count = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMPSComputeRunningProcesses_v2") + ret = fn(handle, 
byref(c_count), None) + + if (ret == NVML_SUCCESS): + # special case, no running processes + return [] + elif (ret == NVML_ERROR_INSUFFICIENT_SIZE): + # typical case + # oversize the array incase more processes are created + c_count.value = c_count.value * 2 + 5 + proc_array = c_nvmlProcessInfo_v2_t * c_count.value + c_procs = proc_array() + + # make the call again + ret = fn(handle, byref(c_count), c_procs) + _nvmlCheckReturn(ret) + + procs = [] + for i in range(c_count.value): + # use an alternative struct for this object + obj = nvmlStructToFriendlyObject(c_procs[i]) + if (obj.usedGpuMemory == NVML_VALUE_NOT_AVAILABLE_ulonglong.value): + # special case for WDDM on Windows, see comment above + obj.usedGpuMemory = None + procs.append(obj) + + return procs + else: + # error case + raise NVMLError(ret) + +def nvmlDeviceGetMPSComputeRunningProcesses_v3(handle): + # first call to get the size + c_count = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMPSComputeRunningProcesses_v3") + ret = fn(handle, byref(c_count), None) + + if (ret == NVML_SUCCESS): + # special case, no running processes + return [] + elif (ret == NVML_ERROR_INSUFFICIENT_SIZE): + # typical case + # oversize the array incase more processes are created + c_count.value = c_count.value * 2 + 5 + proc_array = c_nvmlProcessInfo_v3_t * c_count.value + c_procs = proc_array() + + # make the call again + ret = fn(handle, byref(c_count), c_procs) + _nvmlCheckReturn(ret) + + procs = [] + for i in range(c_count.value): + # use an alternative struct for this object + obj = nvmlStructToFriendlyObject(c_procs[i]) + if (obj.usedGpuMemory == NVML_VALUE_NOT_AVAILABLE_ulonglong.value): + # special case for WDDM on Windows, see comment above + obj.usedGpuMemory = None + procs.append(obj) + + return procs + else: + # error case + raise NVMLError(ret) + +def nvmlDeviceGetRunningProcessDetailList(handle, version, mode): + c_processDetailList = c_nvmlProcessDetailList_t() + c_processDetailList.version = version + 
c_processDetailList.mode = mode + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetRunningProcessDetailList") + + # first call to get the size + ret = fn(handle, byref(c_processDetailList)) + if (ret == NVML_SUCCESS): + # special case, no running processes + return [] + elif (ret == NVML_ERROR_INSUFFICIENT_SIZE): + c_procs = c_nvmlProcessDetail_v1_t * c_processDetailList.numProcArrayEntries + c_processDetailList.procArray = cast((c_procs)(), POINTER(c_nvmlProcessDetail_v1_t)) + + # make the call again + ret = fn(handle, byref(c_processDetailList)) + _nvmlCheckReturn(ret) + + procs = [] + for i in range(c_processDetailList.numProcArrayEntries): + # use an alternative struct for this object + obj = c_processDetailList.procArray[i] + if (obj.usedGpuMemory == NVML_VALUE_NOT_AVAILABLE_ulonglong.value): + obj.usedGpuMemory = None + if (obj.usedGpuCcProtectedMemory == NVML_VALUE_NOT_AVAILABLE_ulonglong.value): + obj.usedGpuCcProtectedMemory = None + procs.append(obj) + + return procs + else: + # error case + raise NVMLError(ret) + +def nvmlDeviceGetAutoBoostedClocksEnabled(handle): + c_isEnabled = _nvmlEnableState_t() + c_defaultIsEnabled = _nvmlEnableState_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetAutoBoostedClocksEnabled") + ret = fn(handle, byref(c_isEnabled), byref(c_defaultIsEnabled)) + _nvmlCheckReturn(ret) + return [c_isEnabled.value, c_defaultIsEnabled.value] + #Throws NVML_ERROR_NOT_SUPPORTED if hardware doesn't support setting auto boosted clocks + +## Set functions +def nvmlUnitSetLedState(unit, color): + fn = _nvmlGetFunctionPointer("nvmlUnitSetLedState") + ret = fn(unit, _nvmlLedColor_t(color)) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceSetPersistenceMode(handle, mode): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetPersistenceMode") + ret = fn(handle, _nvmlEnableState_t(mode)) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceSetComputeMode(handle, mode): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetComputeMode") + ret = fn(handle, 
_nvmlComputeMode_t(mode)) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceSetEccMode(handle, mode): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetEccMode") + ret = fn(handle, _nvmlEnableState_t(mode)) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceClearEccErrorCounts(handle, counterType): + fn = _nvmlGetFunctionPointer("nvmlDeviceClearEccErrorCounts") + ret = fn(handle, _nvmlEccCounterType_t(counterType)) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceSetDriverModel(handle, model): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetDriverModel") + ret = fn(handle, _nvmlDriverModel_t(model)) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceSetAutoBoostedClocksEnabled(handle, enabled): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetAutoBoostedClocksEnabled") + ret = fn(handle, _nvmlEnableState_t(enabled)) + _nvmlCheckReturn(ret) + return None + #Throws NVML_ERROR_NOT_SUPPORTED if hardware doesn't support setting auto boosted clocks + +def nvmlDeviceSetDefaultAutoBoostedClocksEnabled(handle, enabled, flags): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetDefaultAutoBoostedClocksEnabled") + ret = fn(handle, _nvmlEnableState_t(enabled), c_uint(flags)) + _nvmlCheckReturn(ret) + return None + #Throws NVML_ERROR_NOT_SUPPORTED if hardware doesn't support setting auto boosted clocks + +def nvmlDeviceSetGpuLockedClocks(handle, minGpuClockMHz, maxGpuClockMHz): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetGpuLockedClocks") + ret = fn(handle, c_uint(minGpuClockMHz), c_uint(maxGpuClockMHz)) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceResetGpuLockedClocks(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceResetGpuLockedClocks") + ret = fn(handle) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceSetMemoryLockedClocks(handle, minMemClockMHz, maxMemClockMHz): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetMemoryLockedClocks") + ret = fn(handle, c_uint(minMemClockMHz), c_uint(maxMemClockMHz)) + _nvmlCheckReturn(ret) + return None + +def 
nvmlDeviceResetMemoryLockedClocks(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceResetMemoryLockedClocks") + ret = fn(handle) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceGetClkMonStatus(handle, c_clkMonInfo=nvmlClkMonStatus_t()): + isReference = type(c_clkMonInfo) is not nvmlClkMonStatus_t + c_clkMonInfoRef = c_clkMonInfo if isReference else byref(c_clkMonInfo) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetClkMonStatus") + ret = fn(handle, c_clkMonInfoRef) + _nvmlCheckReturn(ret) + return NVML_SUCCESS if isReference else c_clkMonInfo + +# Added in 4.304 +def nvmlDeviceSetApplicationsClocks(handle, maxMemClockMHz, maxGraphicsClockMHz): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetApplicationsClocks") + ret = fn(handle, c_uint(maxMemClockMHz), c_uint(maxGraphicsClockMHz)) + _nvmlCheckReturn(ret) + return None + +# Added in 4.304 +def nvmlDeviceResetApplicationsClocks(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceResetApplicationsClocks") + ret = fn(handle) + _nvmlCheckReturn(ret) + return None + +# Added in 4.304 +def nvmlDeviceSetPowerManagementLimit(handle, limit): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetPowerManagementLimit") + ret = fn(handle, c_uint(limit)) + _nvmlCheckReturn(ret) + return None + +# Added in 4.304 +def nvmlDeviceSetGpuOperationMode(handle, mode): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetGpuOperationMode") + ret = fn(handle, _nvmlGpuOperationMode_t(mode)) + _nvmlCheckReturn(ret) + return None + +# Added in 2.285 +def nvmlEventSetCreate(): + fn = _nvmlGetFunctionPointer("nvmlEventSetCreate") + eventSet = c_nvmlEventSet_t() + ret = fn(byref(eventSet)) + _nvmlCheckReturn(ret) + return eventSet + +# Added in 2.285 +def nvmlDeviceRegisterEvents(handle, eventTypes, eventSet): + fn = _nvmlGetFunctionPointer("nvmlDeviceRegisterEvents") + ret = fn(handle, c_ulonglong(eventTypes), eventSet) + _nvmlCheckReturn(ret) + return None + +# Added in 2.285 +def nvmlDeviceGetSupportedEventTypes(handle): + c_eventTypes = c_ulonglong() + fn 
= _nvmlGetFunctionPointer("nvmlDeviceGetSupportedEventTypes") + ret = fn(handle, byref(c_eventTypes)) + _nvmlCheckReturn(ret) + return c_eventTypes.value + +# raises NVML_ERROR_TIMEOUT exception on timeout +def nvmlEventSetWait_v2(eventSet, timeoutms): + fn = _nvmlGetFunctionPointer("nvmlEventSetWait_v2") + data = c_nvmlEventData_t() + ret = fn(eventSet, byref(data), c_uint(timeoutms)) + _nvmlCheckReturn(ret) + return data + +def nvmlEventSetWait(eventSet, timeoutms): + return nvmlEventSetWait_v2(eventSet, timeoutms) + +# Added in 2.285 +def nvmlEventSetFree(eventSet): + fn = _nvmlGetFunctionPointer("nvmlEventSetFree") + ret = fn(eventSet) + _nvmlCheckReturn(ret) + return None + +# Added in 3.295 +def nvmlDeviceOnSameBoard(handle1, handle2): + fn = _nvmlGetFunctionPointer("nvmlDeviceOnSameBoard") + onSameBoard = c_int() + ret = fn(handle1, handle2, byref(onSameBoard)) + _nvmlCheckReturn(ret) + return (onSameBoard.value != 0) + +# Added in 3.295 +def nvmlDeviceGetCurrPcieLinkGeneration(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetCurrPcieLinkGeneration") + gen = c_uint() + ret = fn(handle, byref(gen)) + _nvmlCheckReturn(ret) + return gen.value + +# Added in 3.295 +def nvmlDeviceGetMaxPcieLinkGeneration(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMaxPcieLinkGeneration") + gen = c_uint() + ret = fn(handle, byref(gen)) + _nvmlCheckReturn(ret) + return gen.value + +# Added in 3.295 +def nvmlDeviceGetCurrPcieLinkWidth(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetCurrPcieLinkWidth") + width = c_uint() + ret = fn(handle, byref(width)) + _nvmlCheckReturn(ret) + return width.value + +# Added in 3.295 +def nvmlDeviceGetMaxPcieLinkWidth(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMaxPcieLinkWidth") + width = c_uint() + ret = fn(handle, byref(width)) + _nvmlCheckReturn(ret) + return width.value + +def nvmlDeviceGetGpuMaxPcieLinkGeneration(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpuMaxPcieLinkGeneration") + gen = c_uint() + ret 
= fn(handle, byref(gen)) + _nvmlCheckReturn(ret) + return gen.value + +# Added in 4.304 +def nvmlDeviceGetSupportedClocksThrottleReasons(handle): + c_reasons= c_ulonglong() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetSupportedClocksThrottleReasons") + ret = fn(handle, byref(c_reasons)) + _nvmlCheckReturn(ret) + return c_reasons.value + +def nvmlDeviceGetSupportedClocksEventReasons(handle): + c_reasons= c_ulonglong() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetSupportedClocksEventReasons") + ret = fn(handle, byref(c_reasons)) + _nvmlCheckReturn(ret) + return c_reasons.value + +# Added in 4.304 +def nvmlDeviceGetCurrentClocksThrottleReasons(handle): + c_reasons= c_ulonglong() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetCurrentClocksThrottleReasons") + ret = fn(handle, byref(c_reasons)) + _nvmlCheckReturn(ret) + return c_reasons.value + +def nvmlDeviceGetCurrentClocksEventReasons(handle): + c_reasons= c_ulonglong() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetCurrentClocksEventReasons") + ret = fn(handle, byref(c_reasons)) + _nvmlCheckReturn(ret) + return c_reasons.value + +# Added in 5.319 +def nvmlDeviceGetIndex(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetIndex") + c_index = c_uint() + ret = fn(handle, byref(c_index)) + _nvmlCheckReturn(ret) + return c_index.value + +# Added in 5.319 +def nvmlDeviceGetAccountingMode(handle): + c_mode = _nvmlEnableState_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetAccountingMode") + ret = fn(handle, byref(c_mode)) + _nvmlCheckReturn(ret) + return c_mode.value + +def nvmlDeviceSetAccountingMode(handle, mode): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetAccountingMode") + ret = fn(handle, _nvmlEnableState_t(mode)) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceClearAccountingPids(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceClearAccountingPids") + ret = fn(handle) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceGetAccountingStats(handle, pid): + stats = c_nvmlAccountingStats_t() + fn = 
_nvmlGetFunctionPointer("nvmlDeviceGetAccountingStats") + ret = fn(handle, c_uint(pid), byref(stats)) + _nvmlCheckReturn(ret) + if (stats.maxMemoryUsage == NVML_VALUE_NOT_AVAILABLE_ulonglong.value): + # special case for WDDM on Windows, see comment above + stats.maxMemoryUsage = None + return stats + +def nvmlDeviceGetAccountingPids(handle): + count = c_uint(nvmlDeviceGetAccountingBufferSize(handle)) + pids = (c_uint * count.value)() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetAccountingPids") + ret = fn(handle, byref(count), pids) + _nvmlCheckReturn(ret) + return list(map(int, pids[0:count.value])) + +def nvmlDeviceGetAccountingBufferSize(handle): + bufferSize = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetAccountingBufferSize") + ret = fn(handle, byref(bufferSize)) + _nvmlCheckReturn(ret) + return int(bufferSize.value) + +def nvmlDeviceGetRetiredPages(device, sourceFilter): + c_source = _nvmlPageRetirementCause_t(sourceFilter) + c_count = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetRetiredPages") + + # First call will get the size + ret = fn(device, c_source, byref(c_count), None) + + # this should only fail with insufficient size + if ((ret != NVML_SUCCESS) and + (ret != NVML_ERROR_INSUFFICIENT_SIZE)): + raise NVMLError(ret) + + # call again with a buffer + # oversize the array for the rare cases where additional pages + # are retired between NVML calls + c_count.value = c_count.value * 2 + 5 + page_array = c_ulonglong * c_count.value + c_pages = page_array() + ret = fn(device, c_source, byref(c_count), c_pages) + _nvmlCheckReturn(ret) + return list(map(int, c_pages[0:c_count.value])) + +def nvmlDeviceGetRetiredPages_v2(device, sourceFilter): + c_source = _nvmlPageRetirementCause_t(sourceFilter) + c_count = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetRetiredPages_v2") + + # First call will get the size + ret = fn(device, c_source, byref(c_count), None) + + # this should only fail with insufficient size + if ((ret != NVML_SUCCESS) 
and + (ret != NVML_ERROR_INSUFFICIENT_SIZE)): + raise NVMLError(ret) + + # call again with a buffer + # oversize the array for the rare cases where additional pages + # are retired between NVML calls + c_count.value = c_count.value * 2 + 5 + page_array = c_ulonglong * c_count.value + c_pages = page_array() + times_array = c_ulonglong * c_count.value + c_times = times_array() + ret = fn(device, c_source, byref(c_count), c_pages, c_times) + _nvmlCheckReturn(ret) + return [ { 'address': int(c_pages[i]), 'timestamp': int(c_times[i]) } for i in range(c_count.value) ]; + +def nvmlDeviceGetRetiredPagesPendingStatus(device): + c_pending = _nvmlEnableState_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetRetiredPagesPendingStatus") + ret = fn(device, byref(c_pending)) + _nvmlCheckReturn(ret) + return int(c_pending.value) + +def nvmlDeviceGetAPIRestriction(device, apiType): + c_permission = _nvmlEnableState_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetAPIRestriction") + ret = fn(device, _nvmlRestrictedAPI_t(apiType), byref(c_permission)) + _nvmlCheckReturn(ret) + return int(c_permission.value) + +def nvmlDeviceSetAPIRestriction(handle, apiType, isRestricted): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetAPIRestriction") + ret = fn(handle, _nvmlRestrictedAPI_t(apiType), _nvmlEnableState_t(isRestricted)) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceGetBridgeChipInfo(handle): + bridgeHierarchy = c_nvmlBridgeChipHierarchy_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetBridgeChipInfo") + ret = fn(handle, byref(bridgeHierarchy)) + _nvmlCheckReturn(ret) + return bridgeHierarchy + +def nvmlDeviceGetSamples(device, sampling_type, timeStamp): + c_sampling_type = _nvmlSamplingType_t(sampling_type) + c_time_stamp = c_ulonglong(timeStamp) + c_sample_count = c_uint(0) + c_sample_value_type = _nvmlValueType_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetSamples") + + ## First Call gets the size + ret = fn(device, c_sampling_type, c_time_stamp, byref(c_sample_value_type), 
byref(c_sample_count), None) + + # Stop if this fails + if (ret != NVML_SUCCESS): + raise NVMLError(ret) + + sampleArray = c_sample_count.value * c_nvmlSample_t + c_samples = sampleArray() + ret = fn(device, c_sampling_type, c_time_stamp, byref(c_sample_value_type), byref(c_sample_count), c_samples) + _nvmlCheckReturn(ret) + return (c_sample_value_type.value, c_samples[0:c_sample_count.value]) + +def nvmlDeviceGetViolationStatus(device, perfPolicyType): + c_perfPolicy_type = _nvmlPerfPolicyType_t(perfPolicyType) + c_violTime = c_nvmlViolationTime_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetViolationStatus") + + ## Invoke the method to get violation time + ret = fn(device, c_perfPolicy_type, byref(c_violTime)) + _nvmlCheckReturn(ret) + return c_violTime + +def nvmlDeviceGetPcieThroughput(device, counter): + c_util = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPcieThroughput") + ret = fn(device, _nvmlPcieUtilCounter_t(counter), byref(c_util)) + _nvmlCheckReturn(ret) + return c_util.value + +def nvmlSystemGetTopologyGpuSet(cpuNumber): + c_count = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlSystemGetTopologyGpuSet") + + # First call will get the size + ret = fn(cpuNumber, byref(c_count), None) + + if ret != NVML_SUCCESS: + raise NVMLError(ret) + # call again with a buffer + device_array = c_nvmlDevice_t * c_count.value + c_devices = device_array() + ret = fn(cpuNumber, byref(c_count), c_devices) + _nvmlCheckReturn(ret) + return list(c_devices[0:c_count.value]) + +def nvmlDeviceGetTopologyNearestGpus(device, level): + c_count = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetTopologyNearestGpus") + + # First call will get the size + ret = fn(device, level, byref(c_count), None) + + if ret != NVML_SUCCESS: + raise NVMLError(ret) + + # call again with a buffer + device_array = c_nvmlDevice_t * c_count.value + c_devices = device_array() + ret = fn(device, level, byref(c_count), c_devices) + _nvmlCheckReturn(ret) + return 
list(c_devices[0:c_count.value]) + +def nvmlDeviceGetTopologyCommonAncestor(device1, device2): + c_level = _nvmlGpuTopologyLevel_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetTopologyCommonAncestor") + ret = fn(device1, device2, byref(c_level)) + _nvmlCheckReturn(ret) + return c_level.value + +def nvmlDeviceGetNvLinkUtilizationCounter(device, link, counter): + c_rxcounter = c_ulonglong() + c_txcounter = c_ulonglong() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNvLinkUtilizationCounter") + ret = fn(device, link, counter, byref(c_rxcounter), byref(c_txcounter)) + _nvmlCheckReturn(ret) + return (c_rxcounter.value, c_txcounter.value) + +def nvmlDeviceFreezeNvLinkUtilizationCounter(device, link, counter, freeze): + fn = _nvmlGetFunctionPointer("nvmlDeviceFreezeNvLinkUtilizationCounter") + ret = fn(device, link, counter, freeze) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceResetNvLinkUtilizationCounter(device, link, counter): + fn = _nvmlGetFunctionPointer("nvmlDeviceResetNvLinkUtilizationCounter") + ret = fn(device, link, counter) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceSetNvLinkUtilizationControl(device, link, counter, control, reset): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetNvLinkUtilizationControl") + ret = fn(device, link, counter, byref(control), reset) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceGetNvLinkUtilizationControl(device, link, counter): + c_control = nvmlNvLinkUtilizationControl_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNvLinkUtilizationControl") + ret = fn(device, link, counter, byref(c_control)) + _nvmlCheckReturn(ret) + return c_control + +def nvmlDeviceGetNvLinkCapability(device, link, capability): + c_capResult = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNvLinkCapability") + ret = fn(device, link, capability, byref(c_capResult)) + _nvmlCheckReturn(ret) + return c_capResult.value + +def nvmlDeviceGetNvLinkErrorCounter(device, link, counter): + c_result = c_ulonglong() + fn = 
_nvmlGetFunctionPointer("nvmlDeviceGetNvLinkErrorCounter") + ret = fn(device, link, counter, byref(c_result)) + _nvmlCheckReturn(ret) + return c_result.value + +def nvmlDeviceResetNvLinkErrorCounters(device, link): + fn = _nvmlGetFunctionPointer("nvmlDeviceResetNvLinkErrorCounters") + ret = fn(device, link) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceGetNvLinkRemotePciInfo(device, link): + c_pci = nvmlPciInfo_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNvLinkRemotePciInfo_v2") + ret = fn(device, link, byref(c_pci)) + _nvmlCheckReturn(ret) + return c_pci + +def nvmlDeviceGetNvLinkRemoteDeviceType(handle, link): + c_type = _nvmlNvLinkDeviceType_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNvLinkRemoteDeviceType") + ret = fn(handle, link, byref(c_type)) + _nvmlCheckReturn(ret) + return c_type.value + +def nvmlDeviceGetNvLinkState(device, link): + c_isActive = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNvLinkState") + ret = fn(device, link, byref(c_isActive)) + _nvmlCheckReturn(ret) + return c_isActive.value + +def nvmlDeviceGetNvLinkVersion(device, link): + c_version = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNvLinkVersion") + ret = fn(device, link, byref(c_version)) + _nvmlCheckReturn(ret) + return c_version.value + +def nvmlDeviceModifyDrainState(pciInfo, newState): + fn = _nvmlGetFunctionPointer("nvmlDeviceModifyDrainState") + ret = fn(pointer(pciInfo), newState) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceQueryDrainState(pciInfo): + c_newState = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceQueryDrainState") + ret = fn(pointer(pciInfo), byref(c_newState)) + _nvmlCheckReturn(ret) + return c_newState.value + +def nvmlDeviceRemoveGpu(pciInfo): + fn = _nvmlGetFunctionPointer("nvmlDeviceRemoveGpu") + ret = fn(pointer(pciInfo)) + _nvmlCheckReturn(ret) + return None + +def nvmlDeviceDiscoverGpus(pciInfo): + fn = _nvmlGetFunctionPointer("nvmlDeviceDiscoverGpus") + ret = fn(pointer(pciInfo)) + 
_nvmlCheckReturn(ret) + return None + +def nvmlDeviceGetFieldValues(handle, fieldIds): + values_arr = c_nvmlFieldValue_t * len(fieldIds) + values = values_arr() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetFieldValues") + + for i, fieldId in enumerate(fieldIds): + try: + (values[i].fieldId, values[i].scopeId) = fieldId + except TypeError: + values[i].fieldId = fieldId + + ret = fn(handle, c_int32(len(fieldIds)), byref(values)) + _nvmlCheckReturn(ret) + return values + +def nvmlDeviceClearFieldValues(handle, fieldIds): + values_arr = c_nvmlFieldValue_t * len(fieldIds) + values = values_arr() + fn = _nvmlGetFunctionPointer("nvmlDeviceClearFieldValues") + + for i, fieldId in enumerate(fieldIds): + try: + (values[i].fieldId, values[i].scopeId) = fieldId + except TypeError: + values[i].fieldId = fieldId + + ret = fn(handle, c_int32(len(fieldIds)), byref(values)) + _nvmlCheckReturn(ret) + return values + +def nvmlDeviceGetVirtualizationMode(handle): + c_virtualization_mode = c_ulonglong() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVirtualizationMode") + ret = fn(handle, byref(c_virtualization_mode)) + _nvmlCheckReturn(ret) + return c_virtualization_mode.value + +def nvmlDeviceSetVirtualizationMode(handle, virtualization_mode): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetVirtualizationMode") + return fn(handle, virtualization_mode) + +def nvmlDeviceGetVgpuHeterogeneousMode(handle): + c_vgpuHeterogeneousMode = c_nvmlVgpuHeterogeneousMode_v1_t(0) + c_vgpuHeterogeneousMode.version = VgpuHeterogeneousMode_v1 + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVgpuHeterogeneousMode") + ret = fn(handle, byref(c_vgpuHeterogeneousMode)) + _nvmlCheckReturn(ret) + return c_vgpuHeterogeneousMode.mode + +def nvmlDeviceSetVgpuHeterogeneousMode(handle, heterogeneous_mode): + c_vgpuHeterogeneousMode = c_nvmlVgpuHeterogeneousMode_v1_t(0) + c_vgpuHeterogeneousMode.version = VgpuHeterogeneousMode_v1 + c_vgpuHeterogeneousMode.mode = heterogeneous_mode + fn = 
_nvmlGetFunctionPointer("nvmlDeviceSetVgpuHeterogeneousMode") + ret = fn(handle, byref(c_vgpuHeterogeneousMode)) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlVgpuInstanceGetPlacementId(vgpuInstance): + c_placement = c_nvmlVgpuPlacementId_v1_t(0) + c_placement.version = VgpuPlacementId_v1 + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetPlacementId") + ret = fn(vgpuInstance, byref(c_placement)) + _nvmlCheckReturn(ret) + return c_placement.placementId + +def nvmlDeviceGetVgpuTypeSupportedPlacements(handle, vgpuTypeId, mode=0, version=1): + c_max_instances = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetMaxInstances") + ret = fn(handle, vgpuTypeId, byref(c_max_instances)) + _nvmlCheckReturn(ret) + + if version == 2: + c_vgpu_placements = c_nvmlVgpuPlacementList_v2_t() + c_vgpu_placements.version = VgpuPlacementList_v2 + c_vgpu_placements.count = c_max_instances.value + c_vgpu_placements.mode = mode + elif version == 1: + c_vgpu_placements = c_nvmlVgpuPlacementList_v1_t() + c_vgpu_placements.version = VgpuPlacementList_v1 + else: + raise NVMLError(NVML_ERROR_ARGUMENT_VERSION_MISMATCH) + + c_placements = c_uint * c_max_instances.value + c_vgpu_placements.placementIds = c_placements() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVgpuTypeSupportedPlacements") + ret = fn(handle, vgpuTypeId, byref(c_vgpu_placements)) + _nvmlCheckReturn(ret) + return c_vgpu_placements + +def nvmlDeviceGetVgpuTypeCreatablePlacements(handle, vgpuTypeId, version=1): + c_max_instances = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetMaxInstances") + ret = fn(handle, vgpuTypeId, byref(c_max_instances)) + _nvmlCheckReturn(ret) + + if version == 2: + c_vgpu_placements = c_nvmlVgpuPlacementList_v2_t() + c_vgpu_placements.version = VgpuPlacementList_v2 + c_vgpu_placements.count = c_max_instances.value + elif version == 1: + c_vgpu_placements = c_nvmlVgpuPlacementList_v1_t() + c_vgpu_placements.version = VgpuPlacementList_v1 + + c_placements = c_uint * 
c_max_instances.value + c_vgpu_placements.placementIds = c_placements() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVgpuTypeCreatablePlacements") + ret = fn(handle, vgpuTypeId, byref(c_vgpu_placements)) + _nvmlCheckReturn(ret) + return c_vgpu_placements + +def nvmlGetVgpuDriverCapabilities(capability): + c_capResult = c_uint() + fn = _nvmlGetFunctionPointer("nvmlGetVgpuDriverCapabilities") + ret = fn(_nvmlVgpuDriverCapability_t(capability), byref(c_capResult)) + _nvmlCheckReturn(ret) + return c_capResult.value + +def nvmlDeviceGetVgpuCapabilities(handle, capability): + c_capResult = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVgpuCapabilities") + ret = fn(handle, _nvmlDeviceVgpuCapability_t(capability), byref(c_capResult)) + _nvmlCheckReturn(ret) + return c_capResult.value + +def nvmlDeviceSetVgpuCapabilities(handle, capability, state): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetVgpuCapabilities") + ret = fn(handle, _nvmlDeviceVgpuCapability_t(capability), state) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlDeviceGetSupportedVgpus(handle): + # first call to get the size + c_vgpu_count = c_uint(0) + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetSupportedVgpus") + ret = fn(handle, byref(c_vgpu_count), None) + + if (ret == NVML_SUCCESS): + # special case, no supported vGPUs + return [] + elif (ret == NVML_ERROR_INSUFFICIENT_SIZE): + # typical case + vgpu_type_ids_array = _nvmlVgpuTypeId_t * c_vgpu_count.value + c_vgpu_type_ids = vgpu_type_ids_array() + + # make the call again + ret = fn(handle, byref(c_vgpu_count), c_vgpu_type_ids) + _nvmlCheckReturn(ret) + vgpus = [] + for i in range(c_vgpu_count.value): + vgpus.append(c_vgpu_type_ids[i]) + return vgpus + else: + # error case + raise NVMLError(ret) + +def nvmlDeviceGetCreatableVgpus(handle): + # first call to get the size + c_vgpu_count = c_uint(0) + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetCreatableVgpus") + ret = fn(handle, byref(c_vgpu_count), None) + + if (ret == NVML_SUCCESS): + # 
special case, no supported vGPUs + return [] + elif (ret == NVML_ERROR_INSUFFICIENT_SIZE): + # typical case + vgpu_type_ids_array = _nvmlVgpuTypeId_t * c_vgpu_count.value + c_vgpu_type_ids = vgpu_type_ids_array() + + # make the call again + ret = fn(handle, byref(c_vgpu_count), c_vgpu_type_ids) + _nvmlCheckReturn(ret) + vgpus = [] + for i in range(c_vgpu_count.value): + vgpus.append(c_vgpu_type_ids[i]) + return vgpus + else: + # error case + raise NVMLError(ret) + +def nvmlVgpuTypeGetGpuInstanceProfileId(vgpuTypeId): + c_profile_id = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetGpuInstanceProfileId") + ret = fn(vgpuTypeId, byref(c_profile_id)) + _nvmlCheckReturn(ret) + return (c_profile_id.value) + +@convertStrBytes +def nvmlVgpuTypeGetClass(vgpuTypeId): + c_class = create_string_buffer(NVML_DEVICE_NAME_BUFFER_SIZE) + c_buffer_size = c_uint(NVML_DEVICE_NAME_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetClass") + ret = fn(vgpuTypeId, c_class, byref(c_buffer_size)) + _nvmlCheckReturn(ret) + return c_class.value + +@convertStrBytes +def nvmlVgpuTypeGetName(vgpuTypeId): + c_name = create_string_buffer(NVML_DEVICE_NAME_BUFFER_SIZE) + c_buffer_size = c_uint(NVML_DEVICE_NAME_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetName") + ret = fn(vgpuTypeId, c_name, byref(c_buffer_size)) + _nvmlCheckReturn(ret) + return c_name.value + +def nvmlVgpuTypeGetDeviceID(vgpuTypeId): + c_device_id = c_ulonglong(0) + c_subsystem_id = c_ulonglong(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetDeviceID") + ret = fn(vgpuTypeId, byref(c_device_id), byref(c_subsystem_id)) + _nvmlCheckReturn(ret) + return (c_device_id.value, c_subsystem_id.value) + +def nvmlVgpuTypeGetFramebufferSize(vgpuTypeId): + c_fb_size = c_ulonglong(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetFramebufferSize") + ret = fn(vgpuTypeId, byref(c_fb_size)) + _nvmlCheckReturn(ret) + return c_fb_size.value + +def nvmlVgpuTypeGetNumDisplayHeads(vgpuTypeId): + c_num_heads = c_uint(0) 
+ fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetNumDisplayHeads") + ret = fn(vgpuTypeId, byref(c_num_heads)) + _nvmlCheckReturn(ret) + return c_num_heads.value + +def nvmlVgpuTypeGetResolution(vgpuTypeId): + c_xdim = c_uint(0) + c_ydim = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetResolution") + ret = fn(vgpuTypeId, 0, byref(c_xdim), byref(c_ydim)) + _nvmlCheckReturn(ret) + return (c_xdim.value, c_ydim.value) + +@convertStrBytes +def nvmlVgpuTypeGetLicense(vgpuTypeId): + c_license = create_string_buffer(NVML_GRID_LICENSE_BUFFER_SIZE) + c_buffer_size = c_uint(NVML_GRID_LICENSE_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetLicense") + ret = fn(vgpuTypeId, c_license, c_buffer_size) + _nvmlCheckReturn(ret) + return c_license.value + +def nvmlVgpuTypeGetFrameRateLimit(vgpuTypeId): + c_frl_config = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetFrameRateLimit") + ret = fn(vgpuTypeId, byref(c_frl_config)) + _nvmlCheckReturn(ret) + return c_frl_config.value + +def nvmlVgpuTypeGetGspHeapSize(vgpuTypeId): + c_gsp_heap = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetGspHeapSize") + ret = fn(vgpuTypeId, byref(c_gsp_heap)) + _nvmlCheckReturn(ret) + return c_gsp_heap.value + +def nvmlVgpuTypeGetFbReservation(vgpuTypeId): + c_fb_reservation = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetFbReservation") + ret = fn(vgpuTypeId, byref(c_fb_reservation)) + _nvmlCheckReturn(ret) + return c_fb_reservation.value + +def nvmlVgpuInstanceGetRuntimeStateSize(vgpuInstance): + c_runtime_state = nvmlVgpuRuntimeState_v1_t() + c_runtime_state.version = VgpuRuntimeState_v1 + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetRuntimeStateSize") + ret = fn(vgpuInstance, byref(c_runtime_state)) + _nvmlCheckReturn(ret) + return c_runtime_state + +def nvmlVgpuTypeGetMaxInstances(handle, vgpuTypeId): + c_max_instances = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetMaxInstances") + ret = fn(handle, vgpuTypeId, byref(c_max_instances)) + 
_nvmlCheckReturn(ret) + return c_max_instances.value + +def nvmlVgpuTypeGetMaxInstancesPerVm(vgpuTypeId): + c_max_instances_per_vm = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetMaxInstancesPerVm") + ret = fn(vgpuTypeId, byref(c_max_instances_per_vm)) + _nvmlCheckReturn(ret) + return c_max_instances_per_vm.value + +def nvmlVgpuTypeGetBAR1Info(vgpuTypeId): + c_bar1Info = c_nvmlVgpuTypeBar1Info_v1_t(0) + c_bar1Info.version = VgpuTypeBar1Info_v1 + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetBAR1Info") + ret = fn(vgpuTypeId, byref(c_bar1Info)) + _nvmlCheckReturn(ret) + return c_bar1Info + +def nvmlDeviceGetActiveVgpus(handle): + # first call to get the size + c_vgpu_count = c_uint(0) + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetActiveVgpus") + ret = fn(handle, byref(c_vgpu_count), None) + + if (ret == NVML_SUCCESS): + # special case, no active vGPUs + return [] + elif (ret == NVML_ERROR_INSUFFICIENT_SIZE): + # typical case + vgpu_instance_array = _nvmlVgpuInstance_t * c_vgpu_count.value + c_vgpu_instances = vgpu_instance_array() + + # make the call again + ret = fn(handle, byref(c_vgpu_count), c_vgpu_instances) + _nvmlCheckReturn(ret) + vgpus = [] + for i in range(c_vgpu_count.value): + vgpus.append(c_vgpu_instances[i]) + return vgpus + else: + # error case + raise NVMLError(ret) + +@convertStrBytes +def nvmlVgpuInstanceGetVmID(vgpuInstance): + c_vm_id = create_string_buffer(NVML_DEVICE_UUID_BUFFER_SIZE) + c_buffer_size = c_uint(NVML_GRID_LICENSE_BUFFER_SIZE) + c_vm_id_type = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetVmID") + ret = fn(vgpuInstance, byref(c_vm_id), c_buffer_size, byref(c_vm_id_type)) + _nvmlCheckReturn(ret) + return (c_vm_id.value, c_vm_id_type.value) + +@convertStrBytes +def nvmlVgpuInstanceGetUUID(vgpuInstance): + c_uuid = create_string_buffer(NVML_DEVICE_UUID_BUFFER_SIZE) + c_buffer_size = c_uint(NVML_DEVICE_UUID_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetUUID") + ret = fn(vgpuInstance, 
byref(c_uuid), c_buffer_size) + _nvmlCheckReturn(ret) + return c_uuid.value + +@convertStrBytes +def nvmlVgpuInstanceGetMdevUUID(vgpuInstance): + c_uuid = create_string_buffer(NVML_DEVICE_UUID_BUFFER_SIZE) + c_buffer_size = c_uint(NVML_DEVICE_UUID_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetMdevUUID") + ret = fn(vgpuInstance, byref(c_uuid), c_buffer_size) + _nvmlCheckReturn(ret) + return c_uuid.value + +@convertStrBytes +def nvmlVgpuInstanceGetVmDriverVersion(vgpuInstance): + c_driver_version = create_string_buffer(NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE) + c_buffer_size = c_uint(NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetVmDriverVersion") + ret = fn(vgpuInstance, byref(c_driver_version), c_buffer_size) + _nvmlCheckReturn(ret) + return c_driver_version.value + +def nvmlVgpuInstanceGetLicenseStatus(vgpuInstance): + c_license_status = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetLicenseStatus") + ret = fn(vgpuInstance, byref(c_license_status)) + _nvmlCheckReturn(ret) + return c_license_status.value + +def nvmlVgpuInstanceGetLicenseInfo_v2(vgpuInstance): + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetLicenseInfo_v2") + c_license_info = c_nvmlVgpuLicenseInfo_t() + ret = fn(vgpuInstance, byref(c_license_info)) + _nvmlCheckReturn(ret) + return c_license_info + +def nvmlVgpuInstanceGetLicenseInfo(vgpuInstance): + return nvmlVgpuInstanceGetLicenseInfo_v2(vgpuInstance) + +def nvmlVgpuInstanceGetFrameRateLimit(vgpuInstance): + c_frl = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetFrameRateLimit") + ret = fn(vgpuInstance, byref(c_frl)) + _nvmlCheckReturn(ret) + return c_frl.value + +def nvmlVgpuInstanceGetEccMode(vgpuInstance): + c_mode = _nvmlEnableState_t() + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetEccMode") + ret = fn(vgpuInstance, byref(c_mode)) + _nvmlCheckReturn(ret) + return c_mode.value + +def nvmlVgpuInstanceGetType(vgpuInstance): + c_vgpu_type = c_uint(0) + 
fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetType") + ret = fn(vgpuInstance, byref(c_vgpu_type)) + _nvmlCheckReturn(ret) + return c_vgpu_type.value + +def nvmlVgpuInstanceGetEncoderCapacity(vgpuInstance): + c_encoder_capacity = c_ulonglong(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetEncoderCapacity") + ret = fn(vgpuInstance, byref(c_encoder_capacity)) + _nvmlCheckReturn(ret) + return c_encoder_capacity.value + +def nvmlVgpuInstanceSetEncoderCapacity(vgpuInstance, encoder_capacity): + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceSetEncoderCapacity") + return fn(vgpuInstance, encoder_capacity) + +def nvmlVgpuInstanceGetFbUsage(vgpuInstance): + c_fb_usage = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetFbUsage") + ret = fn(vgpuInstance, byref(c_fb_usage)) + _nvmlCheckReturn(ret) + return c_fb_usage.value + +def nvmlVgpuTypeGetCapabilities(vgpuTypeId, capability): + c_cap_result = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetCapabilities") + ret = fn(vgpuTypeId, _nvmlVgpuCapability_t(capability), byref(c_cap_result)) + _nvmlCheckReturn(ret) + return (c_cap_result.value) + +def nvmlVgpuInstanceGetGpuInstanceId(vgpuInstance): + c_id = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetGpuInstanceId") + ret = fn(vgpuInstance, byref(c_id)) + _nvmlCheckReturn(ret) + return (c_id.value) + +@convertStrBytes +def nvmlVgpuInstanceGetGpuPciId(vgpuInstance): + c_vgpuPciId = create_string_buffer(NVML_DEVICE_PCI_BUS_ID_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetGpuPciId") + ret = fn(vgpuInstance, c_vgpuPciId, byref(c_uint(NVML_DEVICE_PCI_BUS_ID_BUFFER_SIZE))) + _nvmlCheckReturn(ret) + return c_vgpuPciId.value + +def nvmlDeviceGetVgpuUtilization(handle, timeStamp): + # first call to get the size + c_vgpu_count = c_uint(0) + c_time_stamp = c_ulonglong(timeStamp) + c_sample_value_type = _nvmlValueType_t() + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVgpuUtilization") + ret = fn(handle, c_time_stamp, 
byref(c_sample_value_type), byref(c_vgpu_count), None) + + if (ret == NVML_SUCCESS): + # special case, no active vGPUs + return [] + elif (ret == NVML_ERROR_INSUFFICIENT_SIZE): + # typical case + sampleArray = c_vgpu_count.value * c_nvmlVgpuInstanceUtilizationSample_t + c_samples = sampleArray() + + # make the call again + ret = fn(handle, c_time_stamp, byref(c_sample_value_type), byref(c_vgpu_count), c_samples) + _nvmlCheckReturn(ret) + + return c_samples[0:c_vgpu_count.value] + else: + # error case + raise NVMLError(ret) + +def nvmlDeviceGetVgpuInstancesUtilizationInfo(handle, timeStamp): + # first call to get the size + c_time_stamp = c_ulonglong(timeStamp) + c_vgpuUtilInfo = c_nvmlVgpuInstancesUtilizationInfo_v1_t(0) + c_vgpuUtilInfo.version = VgpuInstancesUtilizationInfo_v1 + c_vgpuUtilInfo.sampleValType = _nvmlValueType_t() + c_vgpuUtilInfo.vgpuInstanceCount = c_uint(0) + c_vgpuUtilInfo.lastSeenTimeStamp = c_time_stamp + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVgpuInstancesUtilizationInfo") + ret = fn(handle, byref(c_vgpuUtilInfo)) + + if (ret == NVML_SUCCESS): + # special case, no active vGPUs + return [] + elif (ret == NVML_ERROR_INSUFFICIENT_SIZE): + # typical case + sampleArray = c_vgpuUtilInfo.vgpuInstanceCount * c_nvmlVgpuInstanceUtilizationInfo_v1_t + c_samples = sampleArray() + c_vgpuUtilInfo.vgpuUtilArray = c_samples + + # make the call again + ret = fn(handle, byref(c_vgpuUtilInfo)) + _nvmlCheckReturn(ret) + + return c_samples[0:c_vgpuUtilInfo.vgpuInstanceCount] + else: + # error case + raise NVMLError(ret) + +def nvmlDeviceGetP2PStatus(device1, device2, p2pIndex): + c_p2pstatus = _nvmlGpuP2PStatus_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetP2PStatus") + ret = fn(device1, device2,p2pIndex, byref(c_p2pstatus)) + _nvmlCheckReturn(ret) + return c_p2pstatus.value + +def nvmlDeviceGetGridLicensableFeatures_v4(handle): + c_get_grid_licensable_features = c_nvmlGridLicensableFeatures_v4_t() + fn = 
_nvmlGetFunctionPointer("nvmlDeviceGetGridLicensableFeatures_v4") + ret = fn(handle, byref(c_get_grid_licensable_features)) + _nvmlCheckReturn(ret) + + return (c_get_grid_licensable_features) + +def nvmlDeviceGetGridLicensableFeatures(handle): + return nvmlDeviceGetGridLicensableFeatures_v4(handle) + +def nvmlDeviceGetGspFirmwareVersion(handle, version=None): + isUserDefined = version is not None + if not isUserDefined: + version = (c_char * NVML_GSP_FIRMWARE_VERSION_BUF_SIZE)() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGspFirmwareVersion") + ret = fn(handle, version) + _nvmlCheckReturn(ret) + return NVML_SUCCESS if isUserDefined else version.value + +def nvmlDeviceGetGspFirmwareMode(handle, isEnabled=c_uint(), defaultMode=c_uint()): + isReference = type(isEnabled) is not c_uint + isEnabledRef = isEnabled if isReference else byref(isEnabled) + defaultModeRef = defaultMode if isReference else byref(defaultMode) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGspFirmwareMode") + ret = fn(handle, isEnabledRef, defaultModeRef) + _nvmlCheckReturn(ret) + return NVML_SUCCESS if isReference else [isEnabled.value, defaultMode.value] + +def nvmlDeviceGetEncoderCapacity(handle, encoderQueryType): + c_encoder_capacity = c_ulonglong(0) + c_encoderQuery_type = _nvmlEncoderQueryType_t(encoderQueryType) + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetEncoderCapacity") + ret = fn(handle, c_encoderQuery_type, byref(c_encoder_capacity)) + _nvmlCheckReturn(ret) + return c_encoder_capacity.value + +def nvmlDeviceGetVgpuProcessUtilization(handle, timeStamp): + # first call to get the size + c_vgpu_count = c_uint(0) + c_time_stamp = c_ulonglong(timeStamp) + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVgpuProcessUtilization") + ret = fn(handle, c_time_stamp, byref(c_vgpu_count), None) + + if (ret == NVML_SUCCESS): + # special case, no active vGPUs + return [] + elif (ret == NVML_ERROR_INSUFFICIENT_SIZE): + # typical case + sampleArray = c_vgpu_count.value * 
c_nvmlVgpuProcessUtilizationSample_t + c_samples = sampleArray() + + # make the call again + ret = fn(handle, c_time_stamp, byref(c_vgpu_count), c_samples) + _nvmlCheckReturn(ret) + + return c_samples[0:c_vgpu_count.value] + else: + # error case + raise NVMLError(ret) + +def nvmlDeviceGetVgpuProcessesUtilizationInfo(handle, timeStamp): + # first call to get the size + c_time_stamp = c_ulonglong(timeStamp) + c_vgpuProcUtilInfo = c_nvmlVgpuProcessesUtilizationInfo_v1_t(0) + c_vgpuProcUtilInfo.version = VgpuProcessesUtilizationInfo_v1 + c_vgpuProcUtilInfo.vgpuProcessCount = c_uint(0) + c_vgpuProcUtilInfo.lastSeenTimeStamp = c_time_stamp + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVgpuProcessesUtilizationInfo") + ret = fn(handle, byref(c_vgpuProcUtilInfo)) + + if (ret == NVML_SUCCESS): + # special case, no active vGPUs + return [] + elif (ret == NVML_ERROR_INSUFFICIENT_SIZE): + # typical case + sampleArray = c_vgpuProcUtilInfo.vgpuProcessCount * c_nvmlVgpuProcessUtilizationInfo_v1_t + c_samples = sampleArray() + c_vgpuProcUtilInfo.vgpuProcUtilArray = c_samples + + # make the call again + ret = fn(handle, byref(c_vgpuProcUtilInfo)) + _nvmlCheckReturn(ret) + + return c_samples[0:c_vgpuProcUtilInfo.vgpuProcessCount] + else: + # error case + raise NVMLError(ret) + +def nvmlDeviceGetEncoderStats(handle): + c_encoderCount = c_ulonglong(0) + c_encodeFps = c_ulonglong(0) + c_encoderLatency = c_ulonglong(0) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetEncoderStats") + ret = fn(handle, byref(c_encoderCount), byref(c_encodeFps), byref(c_encoderLatency)) + _nvmlCheckReturn(ret) + return (c_encoderCount.value, c_encodeFps.value, c_encoderLatency.value) + +def nvmlDeviceGetEncoderSessions(handle): + # first call to get the size + c_session_count = c_uint(0) + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetEncoderSessions") + ret = fn(handle, byref(c_session_count), None) + + if (ret == NVML_SUCCESS): + if (c_session_count.value != 0): + # typical case + session_array = 
c_nvmlEncoderSession_t * c_session_count.value + c_sessions = session_array() + + # make the call again + ret = fn(handle, byref(c_session_count), c_sessions) + _nvmlCheckReturn(ret) + sessions = [] + for i in range(c_session_count.value): + sessions.append(c_sessions[i]) + return sessions + else: + return [] # no active sessions + else: + # error case + raise NVMLError(ret) + +def nvmlDeviceGetFBCStats(handle): + c_fbcStats = c_nvmlFBCStats_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetFBCStats") + ret = fn(handle, byref(c_fbcStats)) + _nvmlCheckReturn(ret) + return c_fbcStats + +def nvmlDeviceGetFBCSessions(handle): + # first call to get the size + c_session_count = c_uint(0) + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetFBCSessions") + ret = fn(handle, byref(c_session_count), None) + + if (ret == NVML_SUCCESS): + if (c_session_count.value != 0): + # typical case + session_array = c_nvmlFBCSession_t * c_session_count.value + c_sessions = session_array() + + # make the call again + ret = fn(handle, byref(c_session_count), c_sessions) + _nvmlCheckReturn(ret) + sessions = [] + for i in range(c_session_count.value): + sessions.append(c_sessions[i]) + return sessions + else: + return [] # no active sessions + else: + # error case + raise NVMLError(ret) + +def nvmlVgpuInstanceGetEncoderStats(vgpuInstance): + c_encoderCount = c_ulonglong(0) + c_encodeFps = c_ulonglong(0) + c_encoderLatency = c_ulonglong(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetEncoderStats") + ret = fn(vgpuInstance, byref(c_encoderCount), byref(c_encodeFps), byref(c_encoderLatency)) + _nvmlCheckReturn(ret) + return (c_encoderCount.value, c_encodeFps.value, c_encoderLatency.value) + +def nvmlVgpuInstanceGetEncoderSessions(vgpuInstance): + # first call to get the size + c_session_count = c_uint(0) + + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetEncoderSessions") + ret = fn(vgpuInstance, byref(c_session_count), None) + + if (ret == NVML_SUCCESS): + if (c_session_count.value != 0): + # 
typical case + session_array = c_nvmlEncoderSession_t * c_session_count.value + c_sessions = session_array() + + # make the call again + ret = fn(vgpuInstance, byref(c_session_count), c_sessions) + _nvmlCheckReturn(ret) + sessions = [] + for i in range(c_session_count.value): + sessions.append(c_sessions[i]) + return sessions + else: + return [] # no active sessions + else: + # error case + raise NVMLError(ret) + +def nvmlVgpuInstanceGetFBCStats(vgpuInstance): + c_fbcStats = c_nvmlFBCStats_t() + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetFBCStats") + ret = fn(vgpuInstance, byref(c_fbcStats)) + _nvmlCheckReturn(ret) + return c_fbcStats + +def nvmlVgpuInstanceGetFBCSessions(vgpuInstance): + # first call to get the size + c_session_count = c_uint(0) + + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetFBCSessions") + ret = fn(vgpuInstance, byref(c_session_count), None) + + if (ret == NVML_SUCCESS): + if (c_session_count.value != 0): + # typical case + session_array = c_nvmlFBCSession_t * c_session_count.value + c_sessions = session_array() + + # make the call again + ret = fn(vgpuInstance, byref(c_session_count), c_sessions) + _nvmlCheckReturn(ret) + sessions = [] + for i in range(c_session_count.value): + sessions.append(c_sessions[i]) + return sessions + else: + return [] # no active sessions + else: + # error case + raise NVMLError(ret) + +def nvmlDeviceGetProcessUtilization(handle, timeStamp): + # first call to get the size + c_count = c_uint(0) + c_time_stamp = c_ulonglong(timeStamp) + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetProcessUtilization") + ret = fn(handle, None, byref(c_count), c_time_stamp) + + if (ret == NVML_ERROR_INSUFFICIENT_SIZE): + # typical case + sampleArray = c_count.value * c_nvmlProcessUtilizationSample_t + c_samples = sampleArray() + + # make the call again + ret = fn(handle, c_samples, byref(c_count), c_time_stamp) + _nvmlCheckReturn(ret) + + return c_samples[0:c_count.value] + else: + # error case + raise NVMLError(ret) + +def 
nvmlDeviceGetProcessesUtilizationInfo(handle, timeStamp): + # first call to get the size + c_time_stamp = c_ulonglong(timeStamp) + c_processesUtilInfo = c_nvmlProcessesUtilizationInfo_v1_t(0) + c_processesUtilInfo.version = ProcessesUtilizationInfo_v1 + c_processesUtilInfo.processSamplesCount = c_uint(0) + c_processesUtilInfo.lastSeenTimeStamp = c_time_stamp + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetProcessesUtilizationInfo") + ret = fn(handle, byref(c_processesUtilInfo)) + + if (ret == NVML_ERROR_INSUFFICIENT_SIZE): + # typical case + sampleArray = c_processesUtilInfo.processSamplesCount * c_nvmlProcessUtilizationInfo_v1_t + c_samples = sampleArray() + c_processesUtilInfo.procUtilArray = c_samples + + # make the call again + ret = fn(handle, byref(c_processesUtilInfo)) + _nvmlCheckReturn(ret) + + return c_samples[0:c_processesUtilInfo.processSamplesCount] + else: + # error case + raise NVMLError(ret) + +def nvmlVgpuInstanceGetMetadata(vgpuInstance): + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetMetadata") + c_vgpuMetadata = c_nvmlVgpuMetadata_t() + c_bufferSize = c_uint(0) + # Make the first NVML API call to get the c_bufferSize value. + # We have already allocated required buffer above. + ret = fn(vgpuInstance, byref(c_vgpuMetadata), byref(c_bufferSize)) + if (ret == NVML_ERROR_INSUFFICIENT_SIZE): + ret = fn(vgpuInstance, byref(c_vgpuMetadata), byref(c_bufferSize)) + _nvmlCheckReturn(ret) + else: + raise NVMLError(ret) + return c_vgpuMetadata + +def nvmlDeviceGetVgpuMetadata(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVgpuMetadata") + c_vgpuPgpuMetadata = c_nvmlVgpuPgpuMetadata_t() + c_bufferSize = c_uint(0) + # Make the first NVML API call to get the c_bufferSize value. + # We have already allocated required buffer above. 
+ ret = fn(handle, byref(c_vgpuPgpuMetadata), byref(c_bufferSize)) + if (ret == NVML_ERROR_INSUFFICIENT_SIZE): + ret = fn(handle, byref(c_vgpuPgpuMetadata), byref(c_bufferSize)) + _nvmlCheckReturn(ret) + else: + raise NVMLError(ret) + return c_vgpuPgpuMetadata + +def nvmlGetVgpuCompatibility(vgpuMetadata, pgpuMetadata): + fn = _nvmlGetFunctionPointer("nvmlGetVgpuCompatibility") + c_vgpuPgpuCompatibility = c_nvmlVgpuPgpuCompatibility_t() + ret = fn(byref(vgpuMetadata), byref(pgpuMetadata), byref(c_vgpuPgpuCompatibility)) + _nvmlCheckReturn(ret) + return c_vgpuPgpuCompatibility + +@convertStrBytes +def nvmlDeviceGetPgpuMetadataString(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPgpuMetadataString") + c_pgpuMetadata = create_string_buffer(NVML_VGPU_PGPU_METADATA_OPAQUE_DATA_SIZE) + c_bufferSize = c_uint(0) + # Make the first NVML API call to get the c_bufferSize value. + # We have already allocated required buffer above. + ret = fn(handle, byref(c_pgpuMetadata), byref(c_bufferSize)) + if (ret == NVML_ERROR_INSUFFICIENT_SIZE): + ret = fn(handle, byref(c_pgpuMetadata), byref(c_bufferSize)) + _nvmlCheckReturn(ret) + else: + raise NVMLError(ret) + return (c_pgpuMetadata.value, c_bufferSize.value) + +def nvmlDeviceGetVgpuSchedulerLog(handle): + c_vgpu_sched_log = c_nvmlVgpuSchedulerLog_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVgpuSchedulerLog") + ret = fn(handle, byref(c_vgpu_sched_log)) + _nvmlCheckReturn(ret) + return c_vgpu_sched_log + +def nvmlDeviceGetVgpuSchedulerState(handle): + c_vgpu_sched_state = c_nvmlVgpuSchedulerGetState_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVgpuSchedulerState") + ret = fn(handle, byref(c_vgpu_sched_state)) + _nvmlCheckReturn(ret) + return c_vgpu_sched_state + +def nvmlDeviceGetVgpuSchedulerCapabilities(handle): + c_vgpu_sched_caps = c_nvmlVgpuSchedulerCapabilities_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVgpuSchedulerCapabilities") + ret = fn(handle, byref(c_vgpu_sched_caps)) + _nvmlCheckReturn(ret) + 
return c_vgpu_sched_caps + +def nvmlDeviceSetVgpuSchedulerState(handle, sched_state): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetVgpuSchedulerState") + ret = fn(handle, byref(sched_state)) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlSetVgpuVersion(vgpuVersion): + fn = _nvmlGetFunctionPointer("nvmlSetVgpuVersion") + ret = fn(byref(vgpuVersion)) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlGetVgpuVersion(supported=None, current=None): + isUserDefined = (supported is not None) or (current is not None) + if not isUserDefined: + supported = c_nvmlVgpuVersion_t() + current = c_nvmlVgpuVersion_t() + fn = _nvmlGetFunctionPointer("nvmlGetVgpuVersion") + ret = fn(byref(supported), byref(current)) + _nvmlCheckReturn(ret) + return NVML_SUCCESS if isUserDefined else [(supported.minVersion, + supported.maxVersion), + (current.minVersion, + current.maxVersion)] + +def nvmlVgpuInstanceGetAccountingMode(vgpuInstance): + c_mode = _nvmlEnableState_t() + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetAccountingMode") + ret = fn(vgpuInstance, byref(c_mode)) + _nvmlCheckReturn(ret) + return c_mode.value + +def nvmlVgpuInstanceGetAccountingPids(vgpuInstance): + c_pidCount = c_uint() + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetAccountingPids") + ret = fn(vgpuInstance, byref(c_pidCount), None) + if (ret == NVML_ERROR_INSUFFICIENT_SIZE): + sampleArray = c_pidCount.value * c_uint + c_pidArray = sampleArray() + ret = fn(vgpuInstance, byref(c_pidCount), byref(c_pidArray)) + _nvmlCheckReturn(ret) + else: + raise NVMLError(ret) + return (c_pidCount, c_pidArray) + +def nvmlVgpuInstanceGetAccountingStats(vgpuInstance, pid): + c_accountingStats = c_nvmlAccountingStats_t() + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetAccountingStats") + ret = fn(vgpuInstance, pid, byref(c_accountingStats)) + _nvmlCheckReturn(ret) + return c_accountingStats + +def nvmlVgpuInstanceClearAccountingPids(vgpuInstance): + fn = 
_nvmlGetFunctionPointer("nvmlVgpuInstanceClearAccountingPids") + ret = fn(vgpuInstance) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlGetExcludedDeviceCount(): + c_count = c_uint() + fn = _nvmlGetFunctionPointer("nvmlGetExcludedDeviceCount") + ret = fn(byref(c_count)) + _nvmlCheckReturn(ret) + return c_count.value + +def nvmlGetExcludedDeviceInfoByIndex(index): + c_index = c_uint(index) + info = c_nvmlExcludedDeviceInfo_t() + fn = _nvmlGetFunctionPointer("nvmlGetExcludedDeviceInfoByIndex") + ret = fn(c_index, byref(info)) + _nvmlCheckReturn(ret) + return info + +def nvmlDeviceGetHostVgpuMode(handle): + c_host_vgpu_mode = _nvmlHostVgpuMode_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetHostVgpuMode") + ret = fn(handle, byref(c_host_vgpu_mode)) + _nvmlCheckReturn(ret) + return c_host_vgpu_mode.value + +def nvmlDeviceSetMigMode(device, mode): + c_activationStatus = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceSetMigMode") + ret = fn(device, mode, byref(c_activationStatus)) + _nvmlCheckReturn(ret) + return c_activationStatus.value + +def nvmlDeviceGetMigMode(device): + c_currentMode = c_uint() + c_pendingMode = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMigMode") + ret = fn(device, byref(c_currentMode), byref(c_pendingMode)) + _nvmlCheckReturn(ret) + return [c_currentMode.value, c_pendingMode.value] + +def nvmlDeviceGetGpuInstanceProfileInfo(device, profile, version=2): + if version == 2: + c_info = c_nvmlGpuInstanceProfileInfo_v2_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpuInstanceProfileInfoV") + elif version == 1: + c_info = c_nvmlGpuInstanceProfileInfo_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpuInstanceProfileInfo") + else: + raise NVMLError(NVML_ERROR_FUNCTION_NOT_FOUND) + ret = fn(device, profile, byref(c_info)) + _nvmlCheckReturn(ret) + return c_info + +# Define function alias for the API exposed by NVML +nvmlDeviceGetGpuInstanceProfileInfoV = nvmlDeviceGetGpuInstanceProfileInfo + +def 
nvmlDeviceGetGpuInstanceRemainingCapacity(device, profileId): + c_count = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpuInstanceRemainingCapacity") + ret = fn(device, profileId, byref(c_count)) + _nvmlCheckReturn(ret) + return c_count.value + +def nvmlDeviceGetGpuInstancePossiblePlacements(device, profileId, placementsRef, countRef): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpuInstancePossiblePlacements_v2") + ret = fn(device, profileId, placementsRef, countRef) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlDeviceCreateGpuInstance(device, profileId): + c_instance = c_nvmlGpuInstance_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceCreateGpuInstance") + ret = fn(device, profileId, byref(c_instance)) + _nvmlCheckReturn(ret) + return c_instance + +def nvmlDeviceCreateGpuInstanceWithPlacement(device, profileId, placement): + c_instance = c_nvmlGpuInstance_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceCreateGpuInstanceWithPlacement") + ret = fn(device, profileId, placement, byref(c_instance)) + _nvmlCheckReturn(ret) + return c_instance + +def nvmlGpuInstanceDestroy(gpuInstance): + fn = _nvmlGetFunctionPointer("nvmlGpuInstanceDestroy") + ret = fn(gpuInstance) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlDeviceGetGpuInstances(device, profileId, gpuInstancesRef, countRef): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpuInstances") + ret = fn(device, profileId, gpuInstancesRef, countRef) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlDeviceGetGpuInstanceById(device, gpuInstanceId): + c_instance = c_nvmlGpuInstance_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpuInstanceById") + ret = fn(device, gpuInstanceId, byref(c_instance)) + _nvmlCheckReturn(ret) + return c_instance + +def nvmlGpuInstanceGetInfo(gpuInstance): + c_info = c_nvmlGpuInstanceInfo_t() + fn = _nvmlGetFunctionPointer("nvmlGpuInstanceGetInfo") + ret = fn(gpuInstance, byref(c_info)) + _nvmlCheckReturn(ret) + return c_info + +def 
nvmlGpuInstanceGetComputeInstanceProfileInfo(device, profile, engProfile, version=2): + if version == 2: + c_info = c_nvmlComputeInstanceProfileInfo_v2_t() + fn = _nvmlGetFunctionPointer("nvmlGpuInstanceGetComputeInstanceProfileInfoV") + elif version == 1: + c_info = c_nvmlComputeInstanceProfileInfo_t() + fn = _nvmlGetFunctionPointer("nvmlGpuInstanceGetComputeInstanceProfileInfo") + else: + raise NVMLError(NVML_ERROR_FUNCTION_NOT_FOUND) + ret = fn(device, profile, engProfile, byref(c_info)) + _nvmlCheckReturn(ret) + return c_info + +# Define function alias for the API exposed by NVML +nvmlGpuInstanceGetComputeInstanceProfileInfoV = nvmlGpuInstanceGetComputeInstanceProfileInfo + +def nvmlGpuInstanceGetComputeInstanceRemainingCapacity(gpuInstance, profileId): + c_count = c_uint() + fn = _nvmlGetFunctionPointer("nvmlGpuInstanceGetComputeInstanceRemainingCapacity") + ret = fn(gpuInstance, profileId, byref(c_count)) + _nvmlCheckReturn(ret) + return c_count.value + +def nvmlGpuInstanceGetComputeInstancePossiblePlacements(gpuInstance, profileId, placementsRef, countRef): + fn = _nvmlGetFunctionPointer("nvmlGpuInstanceGetComputeInstancePossiblePlacements") + ret = fn(gpuInstance, profileId, placementsRef, countRef) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlGpuInstanceCreateComputeInstance(gpuInstance, profileId): + c_instance = c_nvmlComputeInstance_t() + fn = _nvmlGetFunctionPointer("nvmlGpuInstanceCreateComputeInstance") + ret = fn(gpuInstance, profileId, byref(c_instance)) + _nvmlCheckReturn(ret) + return c_instance + +def nvmlGpuInstanceCreateComputeInstanceWithPlacement(gpuInstance, profileId, placement): + c_instance = c_nvmlComputeInstance_t() + fn = _nvmlGetFunctionPointer("nvmlGpuInstanceCreateComputeInstanceWithPlacement") + ret = fn(gpuInstance, profileId, placement, byref(c_instance)) + _nvmlCheckReturn(ret) + return c_instance + +def nvmlComputeInstanceDestroy(computeInstance): + fn = _nvmlGetFunctionPointer("nvmlComputeInstanceDestroy") + ret = 
fn(computeInstance) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlGpuInstanceGetComputeInstances(gpuInstance, profileId, computeInstancesRef, countRef): + fn = _nvmlGetFunctionPointer("nvmlGpuInstanceGetComputeInstances") + ret = fn(gpuInstance, profileId, computeInstancesRef, countRef) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlGpuInstanceGetComputeInstanceById(gpuInstance, computeInstanceId): + c_instance = c_nvmlComputeInstance_t() + fn = _nvmlGetFunctionPointer("nvmlGpuInstanceGetComputeInstanceById") + ret = fn(gpuInstance, computeInstanceId, byref(c_instance)) + _nvmlCheckReturn(ret) + return c_instance + +def nvmlComputeInstanceGetInfo_v2(computeInstance): + c_info = c_nvmlComputeInstanceInfo_t() + fn = _nvmlGetFunctionPointer("nvmlComputeInstanceGetInfo_v2") + ret = fn(computeInstance, byref(c_info)) + _nvmlCheckReturn(ret) + return c_info + +def nvmlComputeInstanceGetInfo(computeInstance): + return nvmlComputeInstanceGetInfo_v2(computeInstance) + +def nvmlDeviceIsMigDeviceHandle(device): + c_isMigDevice = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceIsMigDeviceHandle") + ret = fn(device, byref(c_isMigDevice)) + _nvmlCheckReturn(ret) + return c_isMigDevice + +def nvmlDeviceGetGpuInstanceId(device): + c_gpuInstanceId = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpuInstanceId") + ret = fn(device, byref(c_gpuInstanceId)) + _nvmlCheckReturn(ret) + return c_gpuInstanceId.value + +def nvmlDeviceGetComputeInstanceId(device): + c_computeInstanceId = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetComputeInstanceId") + ret = fn(device, byref(c_computeInstanceId)) + _nvmlCheckReturn(ret) + return c_computeInstanceId.value + +def nvmlDeviceGetMaxMigDeviceCount(device): + c_count = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMaxMigDeviceCount") + ret = fn(device, byref(c_count)) + _nvmlCheckReturn(ret) + return c_count.value + +def nvmlDeviceGetMigDeviceHandleByIndex(device, index): + c_index = c_uint(index) + 
migDevice = c_nvmlDevice_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMigDeviceHandleByIndex") + ret = fn(device, c_index, byref(migDevice)) + _nvmlCheckReturn(ret) + return migDevice + +def nvmlDeviceGetDeviceHandleFromMigDeviceHandle(migDevice): + device = c_nvmlDevice_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetDeviceHandleFromMigDeviceHandle") + ret = fn(migDevice, byref(device)) + _nvmlCheckReturn(ret) + return device + +def nvmlDeviceGetAttributes_v2(device): + c_attrs = c_nvmlDeviceAttributes() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetAttributes_v2") + ret = fn(device, byref(c_attrs)) + _nvmlCheckReturn(ret) + return c_attrs + +def nvmlDeviceGetAttributes(device): + return nvmlDeviceGetAttributes_v2(device) + +def nvmlDeviceGetRemappedRows(device): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetRemappedRows") + c_corr = c_uint() + c_unc = c_uint() + c_bpending = c_uint() + c_bfailure = c_uint() + ret = fn(device, byref(c_corr), byref(c_unc), byref(c_bpending), byref(c_bfailure)) + _nvmlCheckReturn(ret) + return (c_corr.value, c_unc.value, c_bpending.value, c_bfailure.value) + +def nvmlDeviceGetRowRemapperHistogram(device): + c_vals = c_nvmlRowRemapperHistogramValues() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetRowRemapperHistogram") + ret = fn(device, byref(c_vals)) + _nvmlCheckReturn(ret) + return c_vals + +def nvmlDeviceGetArchitecture(device): + arch = _nvmlDeviceArchitecture_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetArchitecture") + ret = fn(device, byref(arch)) + _nvmlCheckReturn(ret) + return arch.value + +def nvmlDeviceGetBusType(device): + c_busType = _nvmlBusType_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetBusType") + ret = fn(device, byref(c_busType)) + _nvmlCheckReturn(ret) + return c_busType.value + +def nvmlDeviceGetIrqNum(device): + c_irqNum = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetIrqNum") + ret = fn(device, byref(c_irqNum)) + _nvmlCheckReturn(ret) + return c_irqNum.value + +def 
nvmlDeviceGetNumGpuCores(device): + c_numCores = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNumGpuCores") + ret = fn(device, byref(c_numCores)) + _nvmlCheckReturn(ret) + return c_numCores.value + +def nvmlDeviceGetPowerSource(device): + c_powerSource = _nvmlPowerSource_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPowerSource") + ret = fn(device, byref(c_powerSource)) + _nvmlCheckReturn(ret) + return c_powerSource.value + +def nvmlDeviceGetMemoryBusWidth(device): + c_memBusWidth = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMemoryBusWidth") + ret = fn(device, byref(c_memBusWidth)) + _nvmlCheckReturn(ret) + return c_memBusWidth.value + +def nvmlDeviceGetPcieLinkMaxSpeed(device): + c_speed = _nvmlPcieLinkMaxSpeed_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPcieLinkMaxSpeed") + ret = fn(device, byref(c_speed)) + _nvmlCheckReturn(ret) + return c_speed.value + +def nvmlDeviceGetAdaptiveClockInfoStatus(device): + c_adaptiveClockInfoStatus = _nvmlAdaptiveClockInfoStatus_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetAdaptiveClockInfoStatus") + ret = fn(device, byref(c_adaptiveClockInfoStatus)) + _nvmlCheckReturn(ret) + return c_adaptiveClockInfoStatus.value + +def nvmlDeviceGetPcieSpeed(device): + c_speed = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPcieSpeed") + ret = fn(device, byref(c_speed)) + _nvmlCheckReturn(ret) + return c_speed.value + +def nvmlDeviceGetDynamicPstatesInfo(device, c_dynamicpstatesinfo=c_nvmlGpuDynamicPstatesInfo_t()): + isReference = type(c_dynamicpstatesinfo) is not c_nvmlGpuDynamicPstatesInfo_t + dynamicpstatesinfoRef = c_dynamicpstatesinfo if isReference else byref(c_dynamicpstatesinfo) + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetDynamicPstatesInfo"); + ret = fn(device, dynamicpstatesinfoRef) + _nvmlCheckReturn(ret) + return NVML_SUCCESS if isReference else c_dynamicpstatesinfo + +def nvmlDeviceSetFanSpeed_v2(handle, index, speed): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetFanSpeed_v2"); + ret = 
fn(handle, index, speed) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlDeviceGetThermalSettings(device, sensorindex, c_thermalsettings=c_nvmlGpuThermalSettings_t()): + isReference = type(c_thermalsettings) is not c_nvmlGpuThermalSettings_t + thermalsettingsRef = c_thermalsettings if isReference else byref(c_thermalsettings) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetThermalSettings"); + ret = fn(device, sensorindex, thermalsettingsRef) + _nvmlCheckReturn(ret) + return NVML_SUCCESS if isReference else c_thermalsettings.sensor[:] + +def nvmlDeviceGetMinMaxClockOfPState(device, clockType, pstate, minClockMHz=c_uint(), maxClockMHz=c_uint()): + isReference = (type(minClockMHz) is not c_uint) or (type(maxClockMHz) is not c_uint) + minClockMHzRef = minClockMHz if isReference else byref(minClockMHz) + maxClockMHzRef = maxClockMHz if isReference else byref(maxClockMHz) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMinMaxClockOfPState"); + ret = fn(device, _nvmlClockType_t(clockType), _nvmlClockType_t(pstate), minClockMHzRef, maxClockMHzRef) + _nvmlCheckReturn(ret) + return NVML_SUCCESS if isReference else (minClockMHz.value, maxClockMHz.value) + +class c_nvmlClockOffset_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('type', _nvmlClockType_t), + ('pstate', _nvmlPstates_t), + ('clockOffsetMHz', c_int), + ('minClockOffsetMHz', c_int), + ('maxClockOffsetMHz', c_int), + ] + +nvmlClockOffset_v1 = 0x1000018 + +def nvmlDeviceGetClockOffsets(device, info): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetClockOffsets"); + ret = fn(device, info) + return NVML_SUCCESS + +def nvmlDeviceSetClockOffsets(device, info): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetClockOffsets"); + ret = fn(device, info) + return NVML_SUCCESS + +def nvmlDeviceGetSupportedPerformanceStates(device): + pstates = [] + c_count = c_uint(NVML_MAX_GPU_PERF_PSTATES) + c_size = sizeof(c_uint)*c_count.value + + # NOTE: use 'c_uint' to represent the size of the nvmlPstate_t enumeration. 
+ pstates_array = _nvmlPstates_t * c_count.value + c_pstates = pstates_array() + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetSupportedPerformanceStates") + ret = fn(device, c_pstates, c_size) + _nvmlCheckReturn(ret) + + for value in c_pstates: + if value != NVML_PSTATE_UNKNOWN: + pstates.append(value) + + return pstates + +def nvmlDeviceGetGpcClkVfOffset(device): + offset = c_int32() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpcClkVfOffset") + ret = fn(device, byref(offset)) + _nvmlCheckReturn(ret) + return offset.value + +def nvmlDeviceSetGpcClkVfOffset(device, offset): + c_offset = c_int32(offset) + fn = _nvmlGetFunctionPointer("nvmlDeviceSetGpcClkVfOffset") + ret = fn(device, c_offset) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlDeviceGetGpcClkMinMaxVfOffset(device, minOffset=c_int(), maxOffset=c_int()): + isReference = (type(minOffset) is not c_int) or (type(maxOffset) is not c_int) + minOffsetRef = minOffset if isReference else byref(minOffset) + maxOffsetRef = maxOffset if isReference else byref(maxOffset) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpcClkMinMaxVfOffset") + ret = fn(device, minOffsetRef, maxOffsetRef) + _nvmlCheckReturn(ret) + return NVML_SUCCESS if isReference else (minOffset.value, maxOffset.value) + +def nvmlDeviceGetMemClkVfOffset(device): + offset = c_int32() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMemClkVfOffset") + ret = fn(device, byref(offset)) + _nvmlCheckReturn(ret) + return offset.value + +def nvmlDeviceSetMemClkVfOffset(device, offset): + c_offset = c_int32(offset) + fn = _nvmlGetFunctionPointer("nvmlDeviceSetMemClkVfOffset") + ret = fn(device, c_offset) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlDeviceGetMemClkMinMaxVfOffset(device, minOffset=c_int(), maxOffset=c_int()): + isReference = (type(minOffset) is not c_int) or (type(maxOffset) is not c_int) + minOffsetRef = minOffset if isReference else byref(minOffset) + maxOffsetRef = maxOffset if isReference else byref(maxOffset) + + fn = 
_nvmlGetFunctionPointer("nvmlDeviceGetMemClkMinMaxVfOffset") + ret = fn(device, minOffsetRef, maxOffsetRef) + _nvmlCheckReturn(ret) + return NVML_SUCCESS if isReference else (minOffset.value, maxOffset.value) + +def nvmlSystemSetConfComputeGpusReadyState(state): + c_state = c_uint(state) + fn = _nvmlGetFunctionPointer("nvmlSystemSetConfComputeGpusReadyState") + ret = fn(c_state) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlSystemGetConfComputeGpusReadyState(): + c_state = c_uint() + fn = _nvmlGetFunctionPointer("nvmlSystemGetConfComputeGpusReadyState") + ret = fn(byref(c_state)) + _nvmlCheckReturn(ret) + return c_state.value + +def nvmlSystemGetConfComputeCapabilities(): + c_ccSysCaps = c_nvmlConfComputeSystemCaps_t() + fn = _nvmlGetFunctionPointer("nvmlSystemGetConfComputeCapabilities") + ret = fn(byref(c_ccSysCaps)) + _nvmlCheckReturn(ret) + return c_ccSysCaps + +def nvmlSystemGetConfComputeState(): + c_state = c_nvmlConfComputeSystemState_t() + fn = _nvmlGetFunctionPointer("nvmlSystemGetConfComputeState") + ret = fn(byref(c_state)) + _nvmlCheckReturn(ret) + return c_state + +def nvmlSystemGetConfComputeSettings(settings): + fn = _nvmlGetFunctionPointer("nvmlSystemGetConfComputeSettings") + return fn(settings) + +def nvmlDeviceSetConfComputeUnprotectedMemSize(device, c_ccMemSize): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetConfComputeUnprotectedMemSize") + ret = fn(device, c_ccMemSize) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlDeviceGetConfComputeMemSizeInfo(device): + c_ccMemSize = c_nvmlConfComputeMemSizeInfo_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetConfComputeMemSizeInfo") + ret = fn(device, byref(c_ccMemSize)) + _nvmlCheckReturn(ret) + return c_ccMemSize + +def nvmlDeviceGetConfComputeProtectedMemoryUsage(device): + c_memory = c_nvmlMemory_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetConfComputeProtectedMemoryUsage") + ret = fn(device, byref(c_memory)) + _nvmlCheckReturn(ret) + return c_memory + +def 
nvmlDeviceGetConfComputeGpuCertificate(device): + c_cert = c_nvmlConfComputeGpuCertificate_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetConfComputeGpuCertificate") + ret = fn(device, byref(c_cert)) + _nvmlCheckReturn(ret) + return c_cert + +def nvmlDeviceGetConfComputeGpuAttestationReport(device, c_nonce): + c_attestReport = c_nvmlConfComputeGpuAttestationReport_t() + c_nonce_arr = (c_uint8 * len(c_nonce))(*(c_nonce)) + setattr(c_attestReport, 'nonce', c_nonce_arr) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetConfComputeGpuAttestationReport") + ret = fn(device, byref(c_attestReport)) + _nvmlCheckReturn(ret) + return c_attestReport + +def nvmlSystemSetConfComputeKeyRotationThresholdInfo(max_atk_adv): + c_keyRotationThrInfo = c_nvmlConfComputeSetKeyRotationThresholdInfo_t(0) + c_keyRotationThrInfo.version = ConfComputeSetKeyRotationThresholdInfo_v1 + c_keyRotationThrInfo.maxAttackerAdvantage = max_atk_adv + fn = _nvmlGetFunctionPointer("nvmlSystemSetConfComputeKeyRotationThresholdInfo") + ret = fn(byref(c_keyRotationThrInfo)) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlSystemGetConfComputeKeyRotationThresholdInfo(): + c_keyRotationThrInfo = c_nvmlConfComputeGetKeyRotationThresholdInfo_t(0) + c_keyRotationThrInfo.version = ConfComputeGetKeyRotationThresholdInfo_v1 + fn = _nvmlGetFunctionPointer("nvmlSystemGetConfComputeKeyRotationThresholdInfo") + ret = fn(byref(c_keyRotationThrInfo)) + _nvmlCheckReturn(ret) + return c_keyRotationThrInfo + +## GPM ## +######### + +## Enums/defines + +#### GPM Metric Identifiers +NVML_GPM_METRIC_GRAPHICS_UTIL = 1 # Percentage of time any compute/graphics app was active on the GPU. 0.0 - 100.0 +NVML_GPM_METRIC_SM_UTIL = 2 # Percentage of SMs that were busy. 0.0 - 100.0 +NVML_GPM_METRIC_SM_OCCUPANCY = 3 # Percentage of warps that were active vs theoretical maximum. 0.0 - 100.0 +NVML_GPM_METRIC_INTEGER_UTIL = 4 # Percentage of time the GPU's SMs were doing integer operations. 
0.0 - 100.0 +NVML_GPM_METRIC_ANY_TENSOR_UTIL = 5 # Percentage of time the GPU's SMs were doing ANY tensor operations. 0.0 - 100.0 +NVML_GPM_METRIC_DFMA_TENSOR_UTIL = 6 # Percentage of time the GPU's SMs were doing DFMA tensor operations. 0.0 - 100.0 +NVML_GPM_METRIC_HMMA_TENSOR_UTIL = 7 # Percentage of time the GPU's SMs were doing HMMA tensor operations. 0.0 - 100.0 +NVML_GPM_METRIC_IMMA_TENSOR_UTIL = 9 # Percentage of time the GPU's SMs were doing IMMA tensor operations. 0.0 - 100.0 +NVML_GPM_METRIC_DRAM_BW_UTIL = 10 # Percentage of DRAM bw used vs theoretical maximum. 0.0 - 100.0 +NVML_GPM_METRIC_FP64_UTIL = 11 # Percentage of time the GPU's SMs were doing non-tensor FP64 math. 0.0 - 100.0 +NVML_GPM_METRIC_FP32_UTIL = 12 # Percentage of time the GPU's SMs were doing non-tensor FP32 math. 0.0 - 100.0 +NVML_GPM_METRIC_FP16_UTIL = 13 # Percentage of time the GPU's SMs were doing non-tensor FP16 math. 0.0 - 100.0 +NVML_GPM_METRIC_PCIE_TX_PER_SEC = 20 # PCIe traffic from this GPU in MiB/sec +NVML_GPM_METRIC_PCIE_RX_PER_SEC = 21 # PCIe traffic to this GPU in MiB/sec +NVML_GPM_METRIC_NVDEC_0_UTIL = 30 # Percent utilization of NVDEC 0. 0.0 - 100.0 +NVML_GPM_METRIC_NVDEC_1_UTIL = 31 # Percent utilization of NVDEC 1. 0.0 - 100.0 +NVML_GPM_METRIC_NVDEC_2_UTIL = 32 # Percent utilization of NVDEC 2. 0.0 - 100.0 +NVML_GPM_METRIC_NVDEC_3_UTIL = 33 # Percent utilization of NVDEC 3. 0.0 - 100.0 +NVML_GPM_METRIC_NVDEC_4_UTIL = 34 # Percent utilization of NVDEC 4. 0.0 - 100.0 +NVML_GPM_METRIC_NVDEC_5_UTIL = 35 # Percent utilization of NVDEC 5. 0.0 - 100.0 +NVML_GPM_METRIC_NVDEC_6_UTIL = 36 # Percent utilization of NVDEC 6. 0.0 - 100.0 +NVML_GPM_METRIC_NVDEC_7_UTIL = 37 # Percent utilization of NVDEC 7. 0.0 - 100.0 +NVML_GPM_METRIC_NVJPG_0_UTIL = 40 # Percent utilization of NVJPG 0. 0.0 - 100.0 +NVML_GPM_METRIC_NVJPG_1_UTIL = 41 # Percent utilization of NVJPG 1. 0.0 - 100.0 +NVML_GPM_METRIC_NVJPG_2_UTIL = 42 # Percent utilization of NVJPG 2. 
0.0 - 100.0 +NVML_GPM_METRIC_NVJPG_3_UTIL = 43 # Percent utilization of NVJPG 3. 0.0 - 100.0 +NVML_GPM_METRIC_NVJPG_4_UTIL = 44 # Percent utilization of NVJPG 4. 0.0 - 100.0 +NVML_GPM_METRIC_NVJPG_5_UTIL = 45 # Percent utilization of NVJPG 5. 0.0 - 100.0 +NVML_GPM_METRIC_NVJPG_6_UTIL = 46 # Percent utilization of NVJPG 6. 0.0 - 100.0 +NVML_GPM_METRIC_NVJPG_7_UTIL = 47 # Percent utilization of NVJPG 7. 0.0 - 100.0 +NVML_GPM_METRIC_NVOFA_0_UTIL = 50 # Percent utilization of NVOFA 0. 0.0 - 100.0 +NVML_GPM_METRIC_NVOFA_1_UTIL = 51 # Percent utilization of NVOFA 1. 0.0 - 100.0 +NVML_GPM_METRIC_NVLINK_TOTAL_RX_PER_SEC = 60 # NvLink read bandwidth for all links in MiB/sec +NVML_GPM_METRIC_NVLINK_TOTAL_TX_PER_SEC = 61 # NvLink write bandwidth for all links in MiB/sec +NVML_GPM_METRIC_NVLINK_L0_RX_PER_SEC = 62 # NvLink read bandwidth for link 0 in MiB/sec +NVML_GPM_METRIC_NVLINK_L0_TX_PER_SEC = 63 # NvLink write bandwidth for link 0 in MiB/sec +NVML_GPM_METRIC_NVLINK_L1_RX_PER_SEC = 64 # NvLink read bandwidth for link 1 in MiB/sec +NVML_GPM_METRIC_NVLINK_L1_TX_PER_SEC = 65 # NvLink write bandwidth for link 1 in MiB/sec +NVML_GPM_METRIC_NVLINK_L2_RX_PER_SEC = 66 # NvLink read bandwidth for link 2 in MiB/sec +NVML_GPM_METRIC_NVLINK_L2_TX_PER_SEC = 67 # NvLink write bandwidth for link 2 in MiB/sec +NVML_GPM_METRIC_NVLINK_L3_RX_PER_SEC = 68 # NvLink read bandwidth for link 3 in MiB/sec +NVML_GPM_METRIC_NVLINK_L3_TX_PER_SEC = 69 # NvLink write bandwidth for link 3 in MiB/sec +NVML_GPM_METRIC_NVLINK_L4_RX_PER_SEC = 70 # NvLink read bandwidth for link 4 in MiB/sec +NVML_GPM_METRIC_NVLINK_L4_TX_PER_SEC = 71 # NvLink write bandwidth for link 4 in MiB/sec +NVML_GPM_METRIC_NVLINK_L5_RX_PER_SEC = 72 # NvLink read bandwidth for link 5 in MiB/sec +NVML_GPM_METRIC_NVLINK_L5_TX_PER_SEC = 73 # NvLink write bandwidth for link 5 in MiB/sec +NVML_GPM_METRIC_NVLINK_L6_RX_PER_SEC = 74 # NvLink read bandwidth for link 6 in MiB/sec +NVML_GPM_METRIC_NVLINK_L6_TX_PER_SEC = 75 # NvLink write 
bandwidth for link 6 in MiB/sec +NVML_GPM_METRIC_NVLINK_L7_RX_PER_SEC = 76 # NvLink read bandwidth for link 7 in MiB/sec +NVML_GPM_METRIC_NVLINK_L7_TX_PER_SEC = 77 # NvLink write bandwidth for link 7 in MiB/sec +NVML_GPM_METRIC_NVLINK_L8_RX_PER_SEC = 78 # NvLink read bandwidth for link 8 in MiB/sec +NVML_GPM_METRIC_NVLINK_L8_TX_PER_SEC = 79 # NvLink write bandwidth for link 8 in MiB/sec +NVML_GPM_METRIC_NVLINK_L9_RX_PER_SEC = 80 # NvLink read bandwidth for link 9 in MiB/sec +NVML_GPM_METRIC_NVLINK_L9_TX_PER_SEC = 81 # NvLink write bandwidth for link 9 in MiB/sec +NVML_GPM_METRIC_NVLINK_L10_RX_PER_SEC = 82 # NvLink read bandwidth for link 10 in MiB/sec +NVML_GPM_METRIC_NVLINK_L10_TX_PER_SEC = 83 # NvLink write bandwidth for link 10 in MiB/sec +NVML_GPM_METRIC_NVLINK_L11_RX_PER_SEC = 84 # NvLink read bandwidth for link 11 in MiB/sec +NVML_GPM_METRIC_NVLINK_L11_TX_PER_SEC = 85 # NvLink write bandwidth for link 11 in MiB/sec +NVML_GPM_METRIC_NVLINK_L12_RX_PER_SEC = 86 # NvLink read bandwidth for link 12 in MiB/sec +NVML_GPM_METRIC_NVLINK_L12_TX_PER_SEC = 87 # NvLink write bandwidth for link 12 in MiB/sec +NVML_GPM_METRIC_NVLINK_L13_RX_PER_SEC = 88 # NvLink read bandwidth for link 13 in MiB/sec +NVML_GPM_METRIC_NVLINK_L13_TX_PER_SEC = 89 # NvLink write bandwidth for link 13 in MiB/sec +NVML_GPM_METRIC_NVLINK_L14_RX_PER_SEC = 90 # NvLink read bandwidth for link 14 in MiB/sec +NVML_GPM_METRIC_NVLINK_L14_TX_PER_SEC = 91 # NvLink write bandwidth for link 14 in MiB/sec +NVML_GPM_METRIC_NVLINK_L15_RX_PER_SEC = 92 # NvLink read bandwidth for link 15 in MiB/sec +NVML_GPM_METRIC_NVLINK_L15_TX_PER_SEC = 93 # NvLink write bandwidth for link 15 in MiB/sec +NVML_GPM_METRIC_NVLINK_L16_RX_PER_SEC = 94 # NvLink read bandwidth for link 16 in MiB/sec +NVML_GPM_METRIC_NVLINK_L16_TX_PER_SEC = 95 # NvLink write bandwidth for link 16 in MiB/sec +NVML_GPM_METRIC_NVLINK_L17_RX_PER_SEC = 96 # NvLink read bandwidth for link 17 in MiB/sec +NVML_GPM_METRIC_NVLINK_L17_TX_PER_SEC = 97 # NvLink write 
bandwidth for link 17 in MiB/sec +NVML_GPM_METRIC_MAX = 98 + +## Structs + +class c_nvmlUnitInfo_t(_PrintableStructure): + _fields_ = [ + ('name', c_char * 96), + ('id', c_char * 96), + ('serial', c_char * 96), + ('firmwareVersion', c_char * 96), + ] + +class struct_c_nvmlGpmSample_t(Structure): + pass # opaque handle +c_nvmlGpmSample_t = POINTER(struct_c_nvmlGpmSample_t) + +class c_metricInfo_t(Structure): + _fields_ = [ + ("shortName", c_char_p), + ("longName", c_char_p), + ("unit", c_char_p), + ] + +class c_nvmlGpmMetric_t(_PrintableStructure): + _fields_ = [ + ('metricId', c_uint), + ('nvmlReturn', _nvmlReturn_t), + ('value', c_double), + ('metricInfo', c_metricInfo_t) + ] + +class c_nvmlGpmMetricsGet_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('numMetrics', c_uint), + ('sample1', c_nvmlGpmSample_t), + ('sample2', c_nvmlGpmSample_t), + ('metrics', c_nvmlGpmMetric_t * NVML_GPM_METRIC_MAX) + ] + +NVML_GPM_METRICS_GET_VERSION = 1 + +class c_nvmlGpmSupport_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('isSupportedDevice', c_uint), + ] + +NVML_GPM_SUPPORT_VERSION = 1 + +## Functions + +def nvmlGpmMetricsGet(metricsGet): + fn = _nvmlGetFunctionPointer("nvmlGpmMetricsGet") + ret = fn(byref(metricsGet)) + _nvmlCheckReturn(ret) + return metricsGet + +def nvmlGpmSampleFree(gpmSample): + fn = _nvmlGetFunctionPointer("nvmlGpmSampleFree") + ret = fn(gpmSample) + _nvmlCheckReturn(ret) + return + +def nvmlGpmSampleAlloc(): + gpmSample = c_nvmlGpmSample_t() + fn = _nvmlGetFunctionPointer("nvmlGpmSampleAlloc") + ret = fn(byref(gpmSample)) + _nvmlCheckReturn(ret) + return gpmSample + +def nvmlGpmSampleGet(device, gpmSample): + fn = _nvmlGetFunctionPointer("nvmlGpmSampleGet") + ret = fn(device, gpmSample) + _nvmlCheckReturn(ret) + return gpmSample + +def nvmlGpmMigSampleGet(device, gpuInstanceId, gpmSample): + fn = _nvmlGetFunctionPointer("nvmlGpmMigSampleGet") + ret = fn(device, gpuInstanceId, gpmSample) + _nvmlCheckReturn(ret) + return 
gpmSample + +def nvmlGpmQueryDeviceSupport(device): + gpmSupport = c_nvmlGpmSupport_t() + gpmSupport.version = NVML_GPM_SUPPORT_VERSION + fn = _nvmlGetFunctionPointer("nvmlGpmQueryDeviceSupport") + ret = fn(device, byref(gpmSupport)) + _nvmlCheckReturn(ret) + return gpmSupport + +def nvmlGpmSetStreamingEnabled(device, state): + c_state = c_uint(state) + fn = _nvmlGetFunctionPointer("nvmlGpmSetStreamingEnabled") + ret = fn(device, c_state) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlGpmQueryIfStreamingEnabled(device): + c_state = c_uint() + fn = _nvmlGetFunctionPointer("nvmlGpmQueryIfStreamingEnabled") + ret = fn(device, byref(c_state)) + _nvmlCheckReturn(ret) + return c_state.value + +# Low Power Structure and Function + +NVML_NVLINK_POWER_STATE_HIGH_SPEED = 0x0 +NVML_NVLINK_POWER_STATE_LOW = 0x1 + +NVML_NVLINK_LOW_POWER_THRESHOLD_MIN = 0x1 +NVML_NVLINK_LOW_POWER_THRESHOLD_MAX = 0x1FFF +NVML_NVLINK_LOW_POWER_THRESHOLD_RESET = 0xFFFFFFFF +NVML_NVLINK_LOW_POWER_THRESHOLD_DEFAULT = NVML_NVLINK_LOW_POWER_THRESHOLD_RESET + +class c_nvmlNvLinkPowerThres_t(Structure): + _fields_ = [ + ("lowPwrThreshold", c_uint), + ] + +def nvmlDeviceSetNvLinkDeviceLowPowerThreshold(device, l1threshold): + c_info = c_nvmlNvLinkPowerThres_t() + c_info.lowPwrThreshold = l1threshold + fn = _nvmlGetFunctionPointer("nvmlDeviceSetNvLinkDeviceLowPowerThreshold") + ret = fn(device, byref(c_info)) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +NVML_GPU_FABRIC_UUID_LEN = 16 + +_nvmlGpuFabricState_t = c_uint +NVML_GPU_FABRIC_STATE_NOT_SUPPORTED = 0 +NVML_GPU_FABRIC_STATE_NOT_STARTED = 1 +NVML_GPU_FABRIC_STATE_IN_PROGRESS = 2 +NVML_GPU_FABRIC_STATE_COMPLETED = 3 + +class c_nvmlGpuFabricInfo_t(_PrintableStructure): + _fields_ = [ + ("clusterUuid", c_char * NVML_DEVICE_UUID_BUFFER_SIZE), + ("status", _nvmlReturn_t), + ("cliqueId", c_uint32), + ("state", _nvmlGpuFabricState_t) + ] + +NVML_GPU_FABRIC_HEALTH_MASK_DEGRADED_BW_NOT_SUPPORTED = 0 +NVML_GPU_FABRIC_HEALTH_MASK_DEGRADED_BW_TRUE = 
1 +NVML_GPU_FABRIC_HEALTH_MASK_DEGRADED_BW_FALSE = 2 +NVML_GPU_FABRIC_HEALTH_MASK_SHIFT_DEGRADED_BW = 0 +NVML_GPU_FABRIC_HEALTH_MASK_WIDTH_DEGRADED_BW = 0x11 + +NVML_GPU_FABRIC_HEALTH_MASK_ROUTE_RECOVERY_NOT_SUPPORTED = 0 +NVML_GPU_FABRIC_HEALTH_MASK_ROUTE_RECOVERY_TRUE = 1 +NVML_GPU_FABRIC_HEALTH_MASK_ROUTE_RECOVERY_FALSE = 2 +NVML_GPU_FABRIC_HEALTH_MASK_SHIFT_ROUTE_RECOVERY = 2 +NVML_GPU_FABRIC_HEALTH_MASK_WIDTH_ROUTE_RECOVERY = 0x11 + +NVML_GPU_FABRIC_HEALTH_MASK_ROUTE_UNHEALTHY_NOT_SUPPORTED = 0 +NVML_GPU_FABRIC_HEALTH_MASK_ROUTE_UNHEALTHY_TRUE = 1 +NVML_GPU_FABRIC_HEALTH_MASK_ROUTE_UNHEALTHY_FALSE = 2 +NVML_GPU_FABRIC_HEALTH_MASK_SHIFT_ROUTE_UNHEALTHY = 4 +NVML_GPU_FABRIC_HEALTH_MASK_WIDTH_ROUTE_UNHEALTHY = 0x11 + +NVML_GPU_FABRIC_HEALTH_MASK_ACCESS_TIMEOUT_RECOVERY_NOT_SUPPORTED = 0 +NVML_GPU_FABRIC_HEALTH_MASK_ACCESS_TIMEOUT_RECOVERY_TRUE = 1 +NVML_GPU_FABRIC_HEALTH_MASK_ACCESS_TIMEOUT_RECOVERY_FALSE = 2 +NVML_GPU_FABRIC_HEALTH_MASK_SHIFT_ACCESS_TIMEOUT_RECOVERY = 6 +NVML_GPU_FABRIC_HEALTH_MASK_WIDTH_ACCESS_TIMEOUT_RECOVERY = 0x11 + +nvmlGpuFabricInfo_v2 = 0x02000024 + +class c_nvmlGpuFabricInfoV_t(_PrintableStructure): + _fields_ = [ + ("version", c_uint), + ("clusterUuid", c_char * NVML_GPU_FABRIC_UUID_LEN), + ("status", _nvmlReturn_t), + ("cliqueId", c_uint32), + ("state", _nvmlGpuFabricState_t), + ("healthMask", c_uint32) + ] + + def __init__(self): + super(c_nvmlGpuFabricInfoV_t, self).__init__(version=nvmlGpuFabricInfo_v2) + +def nvmlDeviceGetGpuFabricInfo(device, gpuFabricInfo): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpuFabricInfo"); + ret = fn(device, gpuFabricInfo) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlDeviceGetGpuFabricInfoV(device, gpuFabricInfo): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpuFabricInfoV"); + ret = fn(device, gpuFabricInfo) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +###################### +## Enums/defines +#### NVML GPU NVLINK BW MODE +NVML_GPU_NVLINK_BW_MODE_FULL = 0x0 
+NVML_GPU_NVLINK_BW_MODE_OFF = 0x1 +NVML_GPU_NVLINK_BW_MODE_MIN = 0x2 +NVML_GPU_NVLINK_BW_MODE_HALF = 0x3 +NVML_GPU_NVLINK_BW_MODE_3QUARTER = 0x4 +NVML_GPU_NVLINK_BW_MODE_COUNT = 0x5 + +def nvmlSystemSetNvlinkBwMode(mode): + fn = _nvmlGetFunctionPointer("nvmlSystemSetNvlinkBwMode") + ret = fn(mode) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlSystemGetNvlinkBwMode(): + mode = c_uint() + fn = _nvmlGetFunctionPointer("nvmlSystemGetNvlinkBwMode") + ret = fn(byref(mode)) + _nvmlCheckReturn(ret) + return mode.value + +_nvmlPowerScopeType_t = c_uint +NVML_POWER_SCOPE_GPU = 0 +NVML_POWER_SCOPE_MODULE = 1 +NVML_POWER_SCOPE_MEMORY = 2 + +class c_nvmlPowerValue_v2_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('powerScope', _nvmlPowerScopeType_t), + ('powerValueMw', c_uint), + ] + _fmt_ = {'': "%d B"} + +nvmlPowerValue_v2 = 0x0200000C + +def nvmlDeviceSetPowerManagementLimit_v2(device, powerScope, powerLimit, version=nvmlPowerValue_v2): + c_powerScope = _nvmlPowerScopeType_t(powerScope) + c_powerValue = c_nvmlPowerValue_v2_t() + c_powerValue.version = c_uint(version) + c_powerValue.powerScope = c_powerScope + c_powerValue.powerValueMw = c_uint(powerLimit) + fn = _nvmlGetFunctionPointer("nvmlDeviceSetPowerManagementLimit_v2") + ret = fn(device, byref(c_powerValue)) + return NVML_SUCCESS + +class c_nvmlEccSramErrorStatus_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('aggregateUncParity', c_ulonglong), + ('aggregateUncSecDed', c_ulonglong), + ('aggregateCor', c_ulonglong), + ('volatileUncParity', c_ulonglong), + ('volatileUncSecDed', c_ulonglong), + ('volatileCor', c_ulonglong), + ('aggregateUncBucketL2', c_ulonglong), + ('aggregateUncBucketSm', c_ulonglong), + ('aggregateUncBucketPcie', c_ulonglong), + ('aggregateUncBucketMcu', c_ulonglong), + ('aggregateUncBucketOther', c_ulonglong), + ('bThresholdExceeded', c_uint) + ] + + def __init__(self): + super(c_nvmlEccSramErrorStatus_v1_t, 
self).__init__(version=nvmlEccSramErrorStatus_v1) + +nvmlEccSramErrorStatus_v1 = 0x1000068 +def nvmlDeviceGetSramEccErrorStatus(device, status): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetSramEccErrorStatus") + ret = fn(device, status) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +NVML_DEV_CAP_EGM = (1 << 0) +nvmlDeviceCapabilities_v1 = 0x1000008 + +class c_nvmlDeviceCapabilities_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('capMask', c_uint), + ] + + def __init__(self): + super(c_nvmlDeviceCapabilities_v1_t, self).__init__(version=nvmlDeviceCapabilities_v1) + + +def nvmlDeviceGetCapabilities(device, caps): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetCapabilities") + return fn(device, caps) + +class c_nvmlPlatformInfo_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('ibGuid', c_char * 16), + ('rackGuid', c_char * 16), + ('chassisPhysicalSlotNumber', c_char), + ('computeSlotIndex', c_char), + ('nodeIndex', c_char), + ('peerType', c_char), + ('moduleId', c_char) + ] + + def __init__(self): + super(c_nvmlPlatformInfo_v1_t, self).__init__(version=nvmlPlatformInfo_v1) + +nvmlPlatformInfo_v1 = 0x100002c +def nvmlDeviceGetPlatformInfo(device, platformInfo): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPlatformInfo") + ret = fn(device, platformInfo) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +class c_nvmlMask255_t(_PrintableStructure): + _fields_ = [ + ('mask', c_uint * 8), + ] + +NVML_WORKLOAD_POWER_MAX_PROFILES = 255 +NVML_POWER_PROFILE_MAX_P = 0 +NVML_POWER_PROFILE_MAX_Q = 1 +NVML_POWER_PROFILE_COMPUTE = 2 +NVML_POWER_PROFILE_MEMORY_BOUND = 3 +NVML_POWER_PROFILE_NETWORK = 4 +NVML_POWER_PROFILE_BALANCED = 5 +NVML_POWER_PROFILE_LLM_INFERENCE = 6 +NVML_POWER_PROFILE_LLM_TRAINING = 7 +NVML_POWER_PROFILE_RBM = 8 +NVML_POWER_PROFILE_DCPCIE = 9 +NVML_POWER_PROFILE_HMMA_SPARSE = 10 +NVML_POWER_PROFILE_HMMA_DENSE = 11 +NVML_POWER_PROFILE_SYNC_BALANCED = 12 +NVML_POWER_PROFILE_HPC = 13 +NVML_POWER_PROFILE_MIG = 14 
+NVML_POWER_PROFILE_MAX = 15 + +nvmlWorkloadPowerProfileInfo_v1 = 0x100002c +class c_nvmlWorkloadPowerProfileInfo_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('profileId', c_uint), + ('priority', c_uint), + ('conflictingmask', c_nvmlMask255_t) + ] + + def __init__(self): + super(c_nvmlWorkloadPowerProfileInfo_v1_t, self).__init__(version=nvmlWorkloadPowerProfileInfo_v1) + +nvmlWorkloadPowerProfileProfilesInfo_v1 = 0x1002bf8 +class c_nvmlWorkloadPowerProfileProfilesInfo_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('perfProfilesMask', c_nvmlMask255_t), + ('perfProfile', c_nvmlWorkloadPowerProfileInfo_v1_t * NVML_WORKLOAD_POWER_MAX_PROFILES) + ] + + def __init__(self): + super(c_nvmlWorkloadPowerProfileProfilesInfo_v1_t, self).__init__(version=nvmlWorkloadPowerProfileProfilesInfo_v1) + +nvmlWorkloadPowerProfileCurrentProfiles_v1 = 0x1000064 +class c_nvmlWorkloadPowerProfileCurrentProfiles_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('perfProfilesMask', c_nvmlMask255_t), + ('requestedProfilesMask', c_nvmlMask255_t), + ('enforcedProfilesMask', c_nvmlMask255_t) + ] + + def __init__(self): + super(c_nvmlWorkloadPowerProfileCurrentProfiles_v1_t, self).__init__(version=nvmlWorkloadPowerProfileCurrentProfiles_v1) + +nvmlWorkloadPowerProfileRequestedProfiles_v1 = 0x1000024 +class c_nvmlWorkloadPowerProfileRequestedProfiles_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('requestedProfilesMask', c_nvmlMask255_t), + ] + + def __init__(self): + super(c_nvmlWorkloadPowerProfileRequestedProfiles_v1_t, self).__init__(version=nvmlWorkloadPowerProfileRequestedProfiles_v1) + +def nvmlDeviceWorkloadPowerProfileGetProfilesInfo(device, profilesInfo): + fn = _nvmlGetFunctionPointer("nvmlDeviceWorkloadPowerProfileGetProfilesInfo") + ret = fn(device, profilesInfo) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlDeviceWorkloadPowerProfileGetCurrentProfiles(device, currentProfiles): + fn = 
_nvmlGetFunctionPointer("nvmlDeviceWorkloadPowerProfileGetCurrentProfiles") + ret = fn(device, currentProfiles) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlDeviceWorkloadPowerProfileSetRequestedProfiles(device, requestedProfiles): + fn = _nvmlGetFunctionPointer("nvmlDeviceWorkloadPowerProfileSetRequestedProfiles") + ret = fn(device, requestedProfiles) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlDeviceWorkloadPowerProfileClearRequestedProfiles(device, requestedProfiles): + fn = _nvmlGetFunctionPointer("nvmlDeviceWorkloadPowerProfileClearRequestedProfiles") + ret = fn(device, requestedProfiles) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlDeviceGetNvlinkSupportedBwModes(device, supportedBwModes): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNvlinkSupportedBwModes") + ret = fn(device, supportedBwModes) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlDeviceGetNvlinkBwMode(device, getBwMode): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNvlinkBwMode") + ret = fn(device, getBwMode) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +def nvmlDeviceSetNvlinkBwMode(device, setBwMode): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetNvlinkBwMode") + ret = fn(device, setBwMode) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + +nvmlDramEncryptionInfo_v1 = 0x01000008 + +class c_nvmlDramEncryptionInfo_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('encryptionState', _nvmlEnableState_t), + ] + + def __init__(self): + super(c_nvmlDramEncryptionInfo_t, self).__init__(version=nvmlDramEncryptionInfo_v1) + +def nvmlDeviceGetDramEncryptionMode(handle): + c_currState = c_nvmlDramEncryptionInfo_t() + c_pendingState = c_nvmlDramEncryptionInfo_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetDramEncryptionMode") + ret = fn(handle, byref(c_currState), byref(c_pendingState)) + _nvmlCheckReturn(ret) + return [c_currState.encryptionState, c_pendingState.encryptionState] + +# added to API +def 
nvmlDeviceGetCurrentDramEncryptionMode(handle): + return nvmlDeviceGetDramEncryptionMode(handle)[0] + +# added to API +def nvmlDeviceGetPendingDramEncryptionMode(handle): + return nvmlDeviceGetDramEncryptionMode(handle)[1] + +def nvmlDeviceSetDramEncryptionMode(handle, mode): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetDramEncryptionMode") + c_dramEncryptionMode = c_nvmlDramEncryptionInfo_t() + c_dramEncryptionMode.encryptionState = mode; + ret = fn(handle, byref(c_dramEncryptionMode)) + _nvmlCheckReturn(ret) + return None + +# Power Smoothing defines +NVML_POWER_SMOOTHING_MAX_NUM_PROFILES = 5 +NVML_POWER_SMOOTHING_ADMIN_OVERRIDE_NOT_SET = 0xFFFFFFFF +NVML_POWER_SMOOTHING_PROFILE_PARAM_PERCENT_TMP_FLOOR = 0 +NVML_POWER_SMOOTHING_PROFILE_PARAM_RAMP_UP_RATE = 1 +NVML_POWER_SMOOTHING_PROFILE_PARAM_RAMP_DOWN_RATE = 2 +NVML_POWER_SMOOTHING_PROFILE_PARAM_RAMP_DOWN_HYSTERESIS = 3 + +nvmlPowerSmoothingState_v1=0x1000008 +class c_nvmlPowerSmoothingState_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('state', c_uint), + ] + + def __init__(self): + super(c_nvmlPowerSmoothingState_v1_t, self).__init__(version=nvmlPowerSmoothingState_v1) + +nvmlPowerSmoothingProfile_v1=0x1000018 +class c_nvmlPowerSmoothingProfile_v1_t(_PrintableStructure): + _fields_ = [ + ('version', c_uint), + ('profileId', c_uint), + ('paramId', c_uint), + ('value', c_double), + ] + + def __init__(self): + super(c_nvmlPowerSmoothingProfile_v1_t, self).__init__(version=nvmlPowerSmoothingProfile_v1) + +def nvmlDevicePowerSmoothingActivatePresetProfile(device, profile): + fn = _nvmlGetFunctionPointer("nvmlDevicePowerSmoothingActivatePresetProfile") + ret = fn(device, profile) + _nvmlCheckReturn(ret) + +def nvmlDevicePowerSmoothingUpdatePresetProfileParam(device, profile): + fn = _nvmlGetFunctionPointer("nvmlDevicePowerSmoothingUpdatePresetProfileParam") + ret = fn(device, profile) + _nvmlCheckReturn(ret) + +def nvmlDevicePowerSmoothingSetState(device, state): + fn = 
_nvmlGetFunctionPointer("nvmlDevicePowerSmoothingSetState") + ret = fn(device, state) + _nvmlCheckReturn(ret) + diff --git a/vllm/utils.py b/vllm/utils.py index 8b9269598..e16875276 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -2239,34 +2239,13 @@ def import_pynvml(): This causes errors when both of them are installed. Starting from version 12.0, it migrates to a new module named `pynvml_utils` to avoid the conflict. - - TL;DR: if users have pynvml<12.0 installed, it will cause problems. - Otherwise, `import pynvml` will import the correct module. - We take the safest approach here, to manually import the correct - `pynvml.py` module from the `nvidia-ml-py` package. + It is so confusing that many packages in the community use the + unofficial one by mistake, and we have to handle this case. + For example, `nvcr.io/nvidia/pytorch:24.12-py3` uses the unofficial + one, and it will cause errors, see the issue + https://github.com/vllm-project/vllm/issues/12847 for example. + After all the troubles, we decide to copy the official `pynvml` + module to our codebase, and use it directly. """ - if TYPE_CHECKING: - import pynvml - return pynvml - if "pynvml" in sys.modules: - import pynvml - if pynvml.__file__.endswith("__init__.py"): - # this is pynvml < 12.0 - raise RuntimeError( - "You are using a deprecated `pynvml` package. " - "Please uninstall `pynvml` or upgrade to at least" - " version 12.0. 
See https://pypi.org/project/pynvml " - "for more information.") - return sys.modules["pynvml"] - import importlib.util - import os - import site - for site_dir in site.getsitepackages(): - pynvml_path = os.path.join(site_dir, "pynvml.py") - if os.path.exists(pynvml_path): - spec = importlib.util.spec_from_file_location( - "pynvml", pynvml_path) - pynvml = importlib.util.module_from_spec(spec) - sys.modules["pynvml"] = pynvml - spec.loader.exec_module(pynvml) - return pynvml + import vllm.third_party.pynvml as pynvml + return pynvml -- GitLab From 29f1d47e73de3764c944a0af0ff10bbc8ce244f4 Mon Sep 17 00:00:00 2001 From: Lu Fang <30275821+houseroad@users.noreply.github.com> Date: Sun, 9 Feb 2025 02:56:40 -0800 Subject: [PATCH 048/253] [MISC] Always import version library first in the vllm package (#12979) Signed-off-by: Lu Fang --- vllm/__init__.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/__init__.py b/vllm/__init__.py index 566c5116d..457780824 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -1,5 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 """vLLM: a high-throughput and memory-efficient inference engine for LLMs""" +# The version.py should be independent library, and we always import the +# version library first. Such assumption is critical for some customization. +from .version import __version__, __version_tuple__ # isort:skip + import os import torch @@ -19,8 +23,6 @@ from vllm.outputs import (ClassificationOutput, ClassificationRequestOutput, from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams -from .version import __version__, __version_tuple__ - # set some common config/environment variables that should be set # for all processes created by vllm and all processes # that interact with vllm workers. 
-- GitLab From 59fff4a01ae0f5c887cc547af6b49a9b028b4c70 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 10 Feb 2025 09:38:57 +0800 Subject: [PATCH 049/253] [core] improve error handling when wake up from sleep mode (#12981) Signed-off-by: youkaichao --- csrc/cumem_allocator.cpp | 63 ++++++++++++++++++++++----- tests/basic_correctness/test_cumem.py | 27 ++++++++++++ 2 files changed, 78 insertions(+), 12 deletions(-) diff --git a/csrc/cumem_allocator.cpp b/csrc/cumem_allocator.cpp index e8555d853..fab6ca36d 100644 --- a/csrc/cumem_allocator.cpp +++ b/csrc/cumem_allocator.cpp @@ -12,15 +12,21 @@ extern "C" { #include #include -#define CUDA_CHECK(condition) \ - do { \ - CUresult error = condition; \ - if (error != 0) { \ - char* error_string; \ - cuGetErrorString(error, (const char**)&error_string); \ - std::cerr << "CUDA Error: " << error_string << " at " << __FILE__ << ":" \ - << __LINE__ << std::endl; \ - } \ +char error_msg[10240]; // 10KB buffer to store error messages +CUresult no_error = CUresult(0); +CUresult error_code = no_error; // store error code + +#define CUDA_CHECK(condition) \ + do { \ + CUresult error = condition; \ + if (error != 0) { \ + error_code = error; \ + char* error_string; \ + cuGetErrorString(error, (const char**)&error_string); \ + snprintf(error_msg, sizeof(error_msg), "CUDA Error: %s at %s:%d", \ + error_string, __FILE__, __LINE__); \ + std::cerr << error_msg << std::endl; \ + } \ } while (0) // Global references to Python callables @@ -54,14 +60,22 @@ void create_and_map(unsigned long long device, ssize_t size, CUdeviceptr d_mem, // Allocate memory using cuMemCreate CUDA_CHECK(cuMemCreate(p_memHandle, size, &prop, 0)); + if (error_code != 0) { + return; + } CUDA_CHECK(cuMemMap(d_mem, size, 0, *p_memHandle, 0)); - + if (error_code != 0) { + return; + } CUmemAccessDesc accessDesc = {}; accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE; accessDesc.location.id = device; accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; 
CUDA_CHECK(cuMemSetAccess(d_mem, size, &accessDesc, 1)); + if (error_code != 0) { + return; + } // std::cout << "create_and_map: device=" << device << ", size=" << size << ", // d_mem=" << d_mem << ", p_memHandle=" << p_memHandle << std::endl; } @@ -73,7 +87,13 @@ void unmap_and_release(unsigned long long device, ssize_t size, // ", d_mem=" << d_mem << ", p_memHandle=" << p_memHandle << std::endl; ensure_context(device); CUDA_CHECK(cuMemUnmap(d_mem, size)); + if (error_code != 0) { + return; + } CUDA_CHECK(cuMemRelease(*p_memHandle)); + if (error_code != 0) { + return; + } } PyObject* create_tuple_from_c_integers(unsigned long long a, @@ -121,12 +141,16 @@ void* my_malloc(ssize_t size, int device, CUstream stream) { size_t granularity; CUDA_CHECK(cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM)); - + if (error_code != 0) { + return nullptr; + } size_t alignedSize = ((size + granularity - 1) / granularity) * granularity; CUdeviceptr d_mem; CUDA_CHECK(cuMemAddressReserve(&d_mem, alignedSize, 0, 0, 0)); - + if (error_code != 0) { + return nullptr; + } // allocate the CUmemGenericAllocationHandle CUmemGenericAllocationHandle* p_memHandle = (CUmemGenericAllocationHandle*)malloc( @@ -208,6 +232,9 @@ void my_free(void* ptr, ssize_t size, int device, CUstream stream) { // free address and the handle CUDA_CHECK(cuMemAddressFree(d_mem, size)); + if (error_code != 0) { + return; + } free(p_memHandle); } @@ -258,6 +285,12 @@ static PyObject* python_unmap_and_release(PyObject* self, PyObject* args) { unmap_and_release(recv_device, recv_size, d_mem_ptr, p_memHandle); + if (error_code != 0) { + error_code = no_error; + PyErr_SetString(PyExc_RuntimeError, error_msg); + return nullptr; + } + Py_RETURN_NONE; } @@ -282,6 +315,12 @@ static PyObject* python_create_and_map(PyObject* self, PyObject* args) { create_and_map(recv_device, recv_size, d_mem_ptr, p_memHandle); + if (error_code != 0) { + error_code = no_error; + 
PyErr_SetString(PyExc_RuntimeError, error_msg); + return nullptr; + } + Py_RETURN_NONE; } diff --git a/tests/basic_correctness/test_cumem.py b/tests/basic_correctness/test_cumem.py index da9239b09..4e9f1bf1c 100644 --- a/tests/basic_correctness/test_cumem.py +++ b/tests/basic_correctness/test_cumem.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 +import pytest import torch from vllm import LLM, SamplingParams @@ -9,6 +10,32 @@ from vllm.utils import GiB_bytes from ..utils import fork_new_process_for_each_test +@fork_new_process_for_each_test +def test_python_error(): + """ + Test if Python error occurs when there's low-level + error happening from the C++ side. + """ + allocator = CuMemAllocator.get_instance() + total_bytes = torch.cuda.mem_get_info()[1] + alloc_bytes = int(total_bytes * 0.7) + tensors = [] + with allocator.use_memory_pool(): + # allocate 70% of the total memory + x = torch.empty(alloc_bytes, dtype=torch.uint8, device='cuda') + tensors.append(x) + # release the memory + allocator.sleep() + + # allocate more memory than the total memory + y = torch.empty(alloc_bytes, dtype=torch.uint8, device='cuda') + tensors.append(y) + with pytest.raises(RuntimeError): + # when the allocator is woken up, it should raise an error + # because we don't have enough memory + allocator.wake_up() + + @fork_new_process_for_each_test def test_basic_cumem(): # some tensors from default memory pool -- GitLab From aa0ca5ebb7936587b4acde66cc466495b358be04 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 10 Feb 2025 10:28:59 +0800 Subject: [PATCH 050/253] [core][rlhf] add colocate example for RLHF (#12984) Signed-off-by: youkaichao --- .buildkite/test-pipeline.yaml | 4 +- .../{ray_placement.py => rlhf_colocate.py} | 84 +++++++++++++++++-- 2 files changed, 78 insertions(+), 10 deletions(-) rename examples/offline_inference/{ray_placement.py => rlhf_colocate.py} (56%) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 
ab6a576b2..948eab97f 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -128,7 +128,7 @@ steps: - tests/spec_decode/e2e/test_integration_dist_tp4 - tests/compile - examples/offline_inference/rlhf.py - - examples/offline_inference/ray_placement.py + - examples/offline_inference/rlhf_colocate.py commands: - pytest -v -s distributed/test_utils.py - pytest -v -s compile/test_basic_correctness.py @@ -137,7 +137,7 @@ steps: # TODO: create a dedicated test section for multi-GPU example tests # when we have multiple distributed example tests - python3 ../examples/offline_inference/rlhf.py - - RAY_DEDUP_LOGS=0 python3 ../examples/offline_inference/ray_placement.py + - RAY_DEDUP_LOGS=0 python3 ../examples/offline_inference/rlhf_colocate.py - label: Metrics, Tracing Test # 10min num_gpus: 2 diff --git a/examples/offline_inference/ray_placement.py b/examples/offline_inference/rlhf_colocate.py similarity index 56% rename from examples/offline_inference/ray_placement.py rename to examples/offline_inference/rlhf_colocate.py index cd801a3c0..b921bc71f 100644 --- a/examples/offline_inference/ray_placement.py +++ b/examples/offline_inference/rlhf_colocate.py @@ -1,13 +1,18 @@ # SPDX-License-Identifier: Apache-2.0 """ -a simple demonstration to show how to control -the placement of the vLLM workers with Ray. -The key is to set VLLM_RAY_PER_WORKER_GPUS and -VLLM_RAY_BUNDLE_INDICES properly. +a simple demonstration to show how to co-locate +vLLM worker with training actors on the same GPUs, +for RLHF-like applications. +The key points: +- Control the placement of the vLLM workers with Ray, by setting + VLLM_RAY_PER_WORKER_GPUS and VLLM_RAY_BUNDLE_INDICES properly. +- Use cuda-ipc to pass tensors, since NCCL does not work when we have + multiple processes on the same GPU. 
""" import os import ray +import torch from ray.util.placement_group import placement_group from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy @@ -19,7 +24,33 @@ class MyWorker(Worker): def report_device_id(self) -> str: from vllm.platforms import current_platform - return current_platform.get_device_uuid(self.device.index) + self.device_uuid = current_platform.get_device_uuid(self.device.index) + return self.device_uuid + + def update_weights_from_ipc_handles(self, ipc_handles): + handles = ipc_handles[self.device_uuid] + device_id = self.device.index + weights = [] + for name, handle in handles.items(): + func, args = handle + list_args = list(args) + # the key is to change device id to the current device id + # in case two processes have different CUDA_VISIBLE_DEVICES + list_args[6] = device_id + tensor = func(*list_args) + weights.append((name, tensor)) + self.model_runner.model.load_weights(weights=weights) + torch.cuda.synchronize() + + def check_weights_changed(self): + """ + Check if the weights are updated to 0. + """ + weights_updated = True + for name, p in self.model_runner.model.named_parameters(): + weights_updated = weights_updated and torch.allclose( + p, torch.zeros_like(p)) + return weights_updated class MyLLM(LLM): @@ -40,12 +71,32 @@ class MyLLM(LLM): class RayTrainingActor: - def report_device_id(self) -> str: + def __init__(self): + # ray will set CUDA_VISIBLE_DEVICES to the assigned GPUs + from transformers import AutoModelForCausalLM + self.model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m") + self.model.to("cuda:0") + for name, p in self.model.named_parameters(): + p.data.zero_() + torch.cuda.synchronize() # the argument for get_device_uuid is the index # of the GPU in the visible devices. 
- # ray will set CUDA_VISIBLE_DEVICES to the assigned GPUs from vllm.platforms import current_platform - return current_platform.get_device_uuid(0) + self.device_uuid = current_platform.get_device_uuid(0) + + def report_device_id(self) -> str: + return self.device_uuid + + def get_weight_ipc_handles(self): + from torch.multiprocessing.reductions import reduce_tensor + data = {} + for name, p in self.model.named_parameters(): + # the training actor might only have a subset of the weights + # and need to all-gather the weights from all the actors. + # for demonstration, here we assume all training actors have + # the full weights. + data[name] = reduce_tensor(p.detach()) + return {self.device_uuid: data} # ray manages 4 GPUs @@ -78,6 +129,8 @@ for bundle_index in [0, 1, 2, 3]: ), )(RayTrainingActor).remote() training_actors.append(training_actor) + +for bundle_index, training_actor in enumerate(training_actors): device_id = ray.get(training_actor.report_device_id.remote()) print(f"training actor {bundle_index} is on {device_id}") training_actor_device_ids.append(device_id) @@ -119,3 +172,18 @@ assert training_actor_device_ids[:2] == inference_engine_device_ids[0] # the last two training actors should be # on the same GPUs as the second inference engine assert training_actor_device_ids[2:] == inference_engine_device_ids[1] + +print("gather all the IPC handles from the training actors") +ipc_handles = {} +for actor in training_actors: + ipc_handles.update(ray.get(actor.get_weight_ipc_handles.remote())) + +print("update the weights of the inference engines") +for llm in inference_engines: + ray.get( + llm.collective_rpc.remote("update_weights_from_ipc_handles", + args=(ipc_handles, ))) +print("check if the weights are updated") +for llm in inference_engines: + assert ray.get( + llm.collective_rpc.remote("check_weights_changed", args=tuple())) -- GitLab From 67c4637ccfd1f1b4e4aa2b645a5635096cf6d1fe Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Sun, 9 Feb 2025 19:35:56 
-0800 Subject: [PATCH 051/253] [V1] Use msgpack for core request serialization (#12918) Signed-off-by: Nick Hill --- vllm/v1/engine/__init__.py | 42 ++++++++---------------- vllm/v1/engine/core.py | 61 +++++++++++++++-------------------- vllm/v1/engine/core_client.py | 27 +++++++--------- vllm/v1/serial_utils.py | 27 +++++++--------- 4 files changed, 62 insertions(+), 95 deletions(-) diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index b05ef3cc8..30e118501 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -1,20 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 import enum -from dataclasses import dataclass -from typing import TYPE_CHECKING, List, Optional, Union +from typing import List, Optional, Union import msgspec +from vllm.lora.request import LoRARequest +from vllm.multimodal import MultiModalKwargs +from vllm.multimodal.inputs import PlaceholderRange +from vllm.sampling_params import SamplingParams from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.outputs import LogprobsLists, LogprobsTensors -if TYPE_CHECKING: - from vllm.lora.request import LoRARequest - from vllm.multimodal import MultiModalKwargs - from vllm.multimodal.inputs import PlaceholderRange - from vllm.sampling_params import SamplingParams - # These are possible values of RequestOutput.finish_reason, # so form part of the external API. FINISH_REASON_STRINGS = ("stop", "length", "abort") @@ -39,8 +36,11 @@ class FinishReason(enum.IntEnum): return FINISH_REASON_STRINGS[self.value] -@dataclass -class EngineCoreRequest: +class EngineCoreRequest( + msgspec.Struct, + array_like=True, # type: ignore[call-arg] + omit_defaults=True, # type: ignore[call-arg] + gc=False): # type: ignore[call-arg] # NOTE: prompt and prompt_token_ids should be DecoderOnlyInput, # but this object is currently not playing well with msgspec @@ -51,13 +51,13 @@ class EngineCoreRequest: # Detokenizer, but set to None when it is added to EngineCoreClient. 
prompt: Optional[str] prompt_token_ids: List[int] - mm_inputs: Optional[List[Optional["MultiModalKwargs"]]] + mm_inputs: Optional[List[Optional[MultiModalKwargs]]] mm_hashes: Optional[List[str]] - mm_placeholders: Optional[List["PlaceholderRange"]] - sampling_params: "SamplingParams" + mm_placeholders: Optional[List[PlaceholderRange]] + sampling_params: SamplingParams eos_token_id: Optional[int] arrival_time: float - lora_request: Optional["LoRARequest"] + lora_request: Optional[LoRARequest] class EngineCoreOutput( @@ -94,16 +94,6 @@ class EngineCoreOutputs( scheduler_stats: SchedulerStats -@dataclass -class EngineCoreProfile: - is_start: bool - - -@dataclass -class EngineCoreResetPrefixCache: - pass - - class EngineCoreRequestType(enum.Enum): """ Request types defined as hex byte strings, so it can be sent over sockets @@ -113,7 +103,3 @@ class EngineCoreRequestType(enum.Enum): ABORT = b'\x01' PROFILE = b'\x02' RESET_PREFIX_CACHE = b'\x03' - - -EngineCoreRequestUnion = Union[EngineCoreRequest, EngineCoreProfile, - EngineCoreResetPrefixCache, List[str]] diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index f3d40aa1e..c90667ba0 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -1,12 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 -import pickle import queue import signal import threading import time from multiprocessing.connection import Connection -from typing import List, Tuple, Type +from typing import Any, List, Tuple, Type import psutil import zmq @@ -19,13 +18,12 @@ from vllm.transformers_utils.config import ( from vllm.utils import get_exception_traceback, zmq_socket_ctx from vllm.v1.core.kv_cache_utils import get_kv_cache_config from vllm.v1.core.scheduler import Scheduler -from vllm.v1.engine import (EngineCoreOutputs, EngineCoreProfile, - EngineCoreRequest, EngineCoreRequestType, - EngineCoreRequestUnion, EngineCoreResetPrefixCache) +from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, + EngineCoreRequestType) 
from vllm.v1.engine.mm_input_mapper import MMInputMapperServer from vllm.v1.executor.abstract import Executor from vllm.v1.request import Request, RequestStatus -from vllm.v1.serial_utils import MsgpackEncoder, PickleEncoder +from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder from vllm.version import __version__ as VLLM_VERSION logger = init_logger(__name__) @@ -161,7 +159,8 @@ class EngineCoreProc(EngineCore): # and to overlap some serialization/deserialization with the # model forward pass. # Threads handle Socket <-> Queues and core_busy_loop uses Queue. - self.input_queue: queue.Queue[EngineCoreRequestUnion] = queue.Queue() + self.input_queue: queue.Queue[Tuple[EngineCoreRequestType, + Any]] = queue.Queue() self.output_queue: queue.Queue[EngineCoreOutputs] = queue.Queue() threading.Thread(target=self.process_input_socket, args=(input_path, ), @@ -223,7 +222,7 @@ class EngineCoreProc(EngineCore): while True: try: req = self.input_queue.get(timeout=POLLING_TIMEOUT_S) - self._handle_client_request(req) + self._handle_client_request(*req) break except queue.Empty: logger.debug("EngineCore busy loop waiting.") @@ -233,10 +232,10 @@ class EngineCoreProc(EngineCore): except BaseException: raise - # 2) Handle any new client requests (Abort or Add). + # 2) Handle any new client requests. while not self.input_queue.empty(): req = self.input_queue.get_nowait() - self._handle_client_request(req) + self._handle_client_request(*req) # 3) Step the engine core. outputs = self.step() @@ -244,48 +243,40 @@ class EngineCoreProc(EngineCore): # 5) Put EngineCoreOutputs into the output queue. 
self.output_queue.put_nowait(outputs) - def _handle_client_request(self, request: EngineCoreRequestUnion) -> None: - """Handle EngineCoreRequest or EngineCoreABORT from Client.""" + def _handle_client_request(self, request_type: EngineCoreRequestType, + request: Any) -> None: + """Dispatch request from client.""" - if isinstance(request, EngineCoreRequest): + if request_type == EngineCoreRequestType.ADD: self.add_request(request) - elif isinstance(request, EngineCoreProfile): - self.model_executor.profile(request.is_start) - elif isinstance(request, EngineCoreResetPrefixCache): - self.reset_prefix_cache() - else: - # TODO: make an EngineCoreAbort wrapper - assert isinstance(request, list) + elif request_type == EngineCoreRequestType.ABORT: self.abort_requests(request) + elif request_type == EngineCoreRequestType.RESET_PREFIX_CACHE: + self.reset_prefix_cache() + elif request_type == EngineCoreRequestType.PROFILE: + self.model_executor.profile(request) def process_input_socket(self, input_path: str): """Input socket IO thread.""" # Msgpack serialization decoding. - decoder_add_req = PickleEncoder() - decoder_abort_req = PickleEncoder() + add_request_decoder = MsgpackDecoder(EngineCoreRequest) + generic_decoder = MsgpackDecoder() with zmq_socket_ctx(input_path, zmq.constants.PULL) as socket: while True: # (RequestType, RequestData) type_frame, data_frame = socket.recv_multipart(copy=False) - request_type = type_frame.buffer - request_data = data_frame.buffer + request_type = EngineCoreRequestType(bytes(type_frame.buffer)) # Deserialize the request data. 
- if request_type == EngineCoreRequestType.ADD.value: - request = decoder_add_req.decode(request_data) - elif request_type == EngineCoreRequestType.ABORT.value: - request = decoder_abort_req.decode(request_data) - elif request_type in ( - EngineCoreRequestType.PROFILE.value, - EngineCoreRequestType.RESET_PREFIX_CACHE.value): - request = pickle.loads(request_data) - else: - raise ValueError(f"Unknown RequestType: {request_type}") + decoder = add_request_decoder if ( + request_type + == EngineCoreRequestType.ADD) else generic_decoder + request = decoder.decode(data_frame.buffer) # Push to input queue for core busy loop. - self.input_queue.put_nowait(request) + self.input_queue.put_nowait((request_type, request)) def process_output_socket(self, output_path: str): """Output socket IO thread.""" diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index cdc63acdb..2d7d6b42c 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -5,7 +5,7 @@ import os import signal import weakref from abc import ABC, abstractmethod -from typing import List, Optional, Type +from typing import Any, List, Optional, Type import zmq import zmq.asyncio @@ -14,12 +14,11 @@ from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.utils import (get_open_zmq_ipc_path, kill_process_tree, make_zmq_socket) -from vllm.v1.engine import (EngineCoreOutputs, EngineCoreProfile, - EngineCoreRequest, EngineCoreRequestType, - EngineCoreRequestUnion, EngineCoreResetPrefixCache) +from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, + EngineCoreRequestType) from vllm.v1.engine.core import EngineCore, EngineCoreProc from vllm.v1.executor.abstract import Executor -from vllm.v1.serial_utils import MsgpackDecoder, PickleEncoder +from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder from vllm.v1.utils import BackgroundProcHandle logger = init_logger(__name__) @@ -161,7 +160,7 @@ class MPClient(EngineCoreClient): 
signal.signal(signal.SIGUSR1, sigusr1_handler) # Serialization setup. - self.encoder = PickleEncoder() + self.encoder = MsgpackEncoder() self.decoder = MsgpackDecoder(EngineCoreOutputs) # ZMQ setup. @@ -220,7 +219,7 @@ class SyncMPClient(MPClient): return self.decoder.decode(frame.buffer) def _send_input(self, request_type: EngineCoreRequestType, - request: EngineCoreRequestUnion) -> None: + request: Any) -> None: # (RequestType, SerializedRequest) msg = (request_type.value, self.encoder.encode(request)) @@ -237,12 +236,10 @@ class SyncMPClient(MPClient): self._send_input(EngineCoreRequestType.ABORT, request_ids) def profile(self, is_start: bool = True) -> None: - self._send_input(EngineCoreRequestType.PROFILE, - EngineCoreProfile(is_start)) + self._send_input(EngineCoreRequestType.PROFILE, is_start) def reset_prefix_cache(self) -> None: - self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE, - EngineCoreResetPrefixCache()) + self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE, None) class AsyncMPClient(MPClient): @@ -277,7 +274,7 @@ class AsyncMPClient(MPClient): return self.decoder.decode(await self.outputs_queue.get()) async def _send_input(self, request_type: EngineCoreRequestType, - request: EngineCoreRequestUnion) -> None: + request: Any) -> None: msg = (request_type.value, self.encoder.encode(request)) await self.input_socket.send_multipart(msg, copy=False) @@ -293,9 +290,7 @@ class AsyncMPClient(MPClient): await self._send_input(EngineCoreRequestType.ABORT, request_ids) async def profile_async(self, is_start: bool = True) -> None: - await self._send_input(EngineCoreRequestType.PROFILE, - EngineCoreProfile(is_start)) + await self._send_input(EngineCoreRequestType.PROFILE, is_start) async def reset_prefix_cache_async(self) -> None: - await self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE, - EngineCoreResetPrefixCache()) + await self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE, None) diff --git a/vllm/v1/serial_utils.py 
b/vllm/v1/serial_utils.py index a7fba65e7..3f000abcd 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -1,21 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 import pickle -from typing import Any +from typing import Any, Optional import torch from msgspec import msgpack -CUSTOM_TYPE_CODE_PICKLE = 1 - - -class PickleEncoder: - - def encode(self, obj: Any): - return pickle.dumps(obj) - - def decode(self, data: Any): - return pickle.loads(data) +CUSTOM_TYPE_TENSOR = 1 +CUSTOM_TYPE_PICKLE = 2 class MsgpackEncoder: @@ -34,8 +26,9 @@ class MsgpackEncoder: class MsgpackDecoder: """Decoder with custom torch tensor serialization.""" - def __init__(self, t: Any): - self.decoder = msgpack.Decoder(t, ext_hook=custom_ext_hook) + def __init__(self, t: Optional[Any] = None): + args = () if t is None else (t, ) + self.decoder = msgpack.Decoder(*args, ext_hook=custom_ext_hook) def decode(self, obj: Any): return self.decoder.decode(obj) @@ -46,13 +39,15 @@ def custom_enc_hook(obj: Any) -> Any: # NOTE(rob): it is fastest to use numpy + pickle # when serializing torch tensors. 
# https://gist.github.com/tlrmchlsmth/8067f1b24a82b6e2f90450e7764fa103 # noqa: E501 - return msgpack.Ext(CUSTOM_TYPE_CODE_PICKLE, pickle.dumps(obj.numpy())) + return msgpack.Ext(CUSTOM_TYPE_TENSOR, pickle.dumps(obj.numpy())) - raise NotImplementedError(f"Objects of type {type(obj)} are not supported") + return msgpack.Ext(CUSTOM_TYPE_PICKLE, pickle.dumps(obj)) def custom_ext_hook(code: int, data: memoryview) -> Any: - if code == CUSTOM_TYPE_CODE_PICKLE: + if code == CUSTOM_TYPE_TENSOR: return torch.from_numpy(pickle.loads(data)) + if code == CUSTOM_TYPE_PICKLE: + return pickle.loads(data) raise NotImplementedError(f"Extension type code {code} is not supported") -- GitLab From 44607e07d3baf297efe56d77b3b1ddfbf16dad88 Mon Sep 17 00:00:00 2001 From: Yuan Tang Date: Sun, 9 Feb 2025 22:45:07 -0500 Subject: [PATCH 052/253] Check if selected backend is None in get_attn_backend_cls() (#12975) Signed-off-by: Yuan Tang --- vllm/platforms/cpu.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index 4e0683b8a..179ee6a7d 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -35,7 +35,7 @@ class CpuPlatform(Platform): dtype: torch.dtype, kv_cache_dtype: Optional[str], block_size: int, use_v1: bool, use_mla: bool) -> str: - if selected_backend != _Backend.TORCH_SDPA: + if selected_backend and selected_backend != _Backend.TORCH_SDPA: logger.info("Cannot use %s backend on CPU.", selected_backend) logger.info("Using Torch SDPA backend.") return "vllm.attention.backends.torch_sdpa.TorchSDPABackend" -- GitLab From b2496bb07fdf9318e7d9a8065356941fef380bac Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 10 Feb 2025 13:03:43 +0800 Subject: [PATCH 053/253] [core] fix sleep mode and pytorch checkpoint compatibility (#13001) Signed-off-by: youkaichao --- tests/basic_correctness/test_cumem.py | 10 ++++++++-- vllm/model_executor/model_loader/weight_utils.py | 1 - 2 files changed, 8 insertions(+), 3 deletions(-) 
diff --git a/tests/basic_correctness/test_cumem.py b/tests/basic_correctness/test_cumem.py index 4e9f1bf1c..3ac948799 100644 --- a/tests/basic_correctness/test_cumem.py +++ b/tests/basic_correctness/test_cumem.py @@ -115,10 +115,16 @@ def test_cumem_with_cudagraph(): @fork_new_process_for_each_test -def test_end_to_end(): +@pytest.mark.parametrize( + "model", + [ + "meta-llama/Llama-3.2-1B", # sleep mode with safetensors + "facebook/opt-125m" # sleep mode with pytorch checkpoint + ]) +def test_end_to_end(model): free, total = torch.cuda.mem_get_info() used_bytes_baseline = total - free # in case other process is running - llm = LLM("meta-llama/Llama-3.2-1B", enable_sleep_mode=True) + llm = LLM(model, enable_sleep_mode=True) prompt = "How are you?" sampling_params = SamplingParams(temperature=0, max_tokens=10) output = llm.generate(prompt, sampling_params) diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 68ade319d..8b2c5610f 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -462,7 +462,6 @@ def pt_weights_iterator( state = torch.load(bin_file, map_location="cpu", weights_only=True) yield from state.items() del state - torch.cuda.empty_cache() def get_gguf_extra_tensor_names( -- GitLab From 243137143c81f738db17cfcd93d991f6dd842e27 Mon Sep 17 00:00:00 2001 From: Yuan Tang Date: Mon, 10 Feb 2025 01:09:33 -0500 Subject: [PATCH 054/253] [Doc] Add link to tool_choice tracking issue in tool_calling.md (#13003) Signed-off-by: Yuan Tang --- docs/source/features/tool_calling.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/features/tool_calling.md b/docs/source/features/tool_calling.md index 027ddb6d5..85a9e0373 100644 --- a/docs/source/features/tool_calling.md +++ b/docs/source/features/tool_calling.md @@ -1,6 +1,6 @@ # Tool Calling -vLLM currently supports named function calling, as well as the `auto` and 
`none` options for the `tool_choice` field in the chat completion API. The `tool_choice` option `required` is **not yet supported** but on the roadmap. +vLLM currently supports named function calling, as well as the `auto` and `none` options for the `tool_choice` field in the chat completion API. The `tool_choice` option `required` is **not yet supported** but [on the roadmap](gh-issue:13002). ## Quickstart -- GitLab From fde71262e0c235fa5ad80677b3ba65df7f5110de Mon Sep 17 00:00:00 2001 From: "Kevin H. Luu" Date: Mon, 10 Feb 2025 01:15:02 -0800 Subject: [PATCH 055/253] [misc] Add retries with exponential backoff for HF file existence check (#13008) --- vllm/transformers_utils/config.py | 61 ++++++++++++++++++++++++------- 1 file changed, 48 insertions(+), 13 deletions(-) diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 42b45e10e..aade28610 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -3,6 +3,7 @@ import enum import json import os +import time from pathlib import Path from typing import Any, Dict, Literal, Optional, Type, Union @@ -100,15 +101,33 @@ def file_or_path_exists(model: Union[str, Path], config_name: str, # NB: file_exists will only check for the existence of the config file on # hf_hub. This will fail in offline mode. - try: - return file_exists(model, - config_name, - revision=revision, - token=HF_TOKEN) - except huggingface_hub.errors.OfflineModeIsEnabled: - # Don't raise in offline mode, all we know is that we don't have this - # file cached. - return False + + # Call HF to check if the file exists + # 2 retries and exponential backoff + max_retries = 2 + retry_delay = 2 + for attempt in range(max_retries): + try: + return file_exists(model, + config_name, + revision=revision, + token=HF_TOKEN) + except huggingface_hub.errors.OfflineModeIsEnabled: + # Don't raise in offline mode, + # all we know is that we don't have this + # file cached. 
+ return False + except Exception as e: + logger.error( + "Error checking file existence: %s, retrying %d of %d", e, + attempt + 1, max_retries) + if attempt == max_retries - 1: + logger.error("Error checking file existence: %s", e) + raise + time.sleep(retry_delay) + retry_delay *= 2 + continue + return False def patch_rope_scaling(config: PretrainedConfig) -> None: @@ -193,10 +212,26 @@ def get_config( # raise an offline mode error to indicate to the user that they # don't have files cached and may need to go online. # This is conveniently triggered by calling file_exists(). - file_exists(model, - HF_CONFIG_NAME, - revision=revision, - token=HF_TOKEN) + + # Call HF to check if the file exists + # 2 retries and exponential backoff + max_retries = 2 + retry_delay = 2 + for attempt in range(max_retries): + try: + file_exists(model, + HF_CONFIG_NAME, + revision=revision, + token=HF_TOKEN) + except Exception as e: + logger.error( + "Error checking file existence: %s, retrying %d of %d", + e, attempt + 1, max_retries) + if attempt == max_retries: + logger.error("Error checking file existence: %s", e) + raise e + time.sleep(retry_delay) + retry_delay *= 2 raise ValueError(f"No supported config format found in {model}") -- GitLab From 51f0b5f7f6ec4aa8199f12bb7df08c9cb5e025db Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Mon, 10 Feb 2025 18:45:21 +0800 Subject: [PATCH 056/253] [Bugfix] Clean up and fix multi-modal processors (#13012) Signed-off-by: DarkLight1337 --- docs/source/features/compatibility_matrix.md | 2 +- .../decoder_only/language/test_models.py | 10 ++ .../multimodal/processing/test_common.py | 2 +- tests/multimodal/utils.py | 3 - vllm/model_executor/models/chatglm.py | 160 +++++++----------- vllm/model_executor/models/qwen.py | 91 +++++----- vllm/model_executor/models/qwen2_vl.py | 10 +- 7 files changed, 124 insertions(+), 154 deletions(-) diff --git a/docs/source/features/compatibility_matrix.md b/docs/source/features/compatibility_matrix.md index 
b0018ebcc..ee5db70c7 100644 --- a/docs/source/features/compatibility_matrix.md +++ b/docs/source/features/compatibility_matrix.md @@ -297,7 +297,7 @@ Check the '✗' with links to see tracking issue for unsupported feature/hardwar * ✅ * ✅ * ? - * [✗](gh-issue:7968>) + * [✗](gh-issue:7968) * ? * ✅ * diff --git a/tests/models/decoder_only/language/test_models.py b/tests/models/decoder_only/language/test_models.py index 1ad562415..c6d524431 100644 --- a/tests/models/decoder_only/language/test_models.py +++ b/tests/models/decoder_only/language/test_models.py @@ -26,6 +26,9 @@ from ...utils import check_logprobs_close "google/gemma-1.1-2b-it", # gemma marks=[pytest.mark.core_model, pytest.mark.cpu_model], ), + pytest.param( + "THUDM/chatglm3-6b", # ChatGLM (text-only) + ), pytest.param( "meta-llama/Llama-3.2-1B-Instruct", # llama marks=[pytest.mark.core_model, pytest.mark.cpu_model], @@ -43,6 +46,9 @@ from ...utils import check_logprobs_close "microsoft/phi-2", # phi marks=[pytest.mark.core_model], ), + pytest.param( + "Qwen/Qwen-7B", # qwen (text-only) + ), pytest.param( "Qwen/Qwen2.5-0.5B-Instruct", # qwen2 marks=[pytest.mark.core_model], @@ -68,6 +74,10 @@ def test_models( ) -> None: with hf_runner(model, dtype=dtype) as hf_model: + if model.startswith("THUDM/chatglm3"): + hf_model.model.get_output_embeddings = lambda: \ + hf_model.model.transformer.output_layer + hf_outputs = hf_model.generate_greedy_logprobs_limit( example_prompts, max_tokens, num_logprobs) diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index 8658e60bc..a56a9e2be 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -89,7 +89,7 @@ def _test_processing_correctness( mm_data = { k: [(input_to_hit[k] if rng.rand() < hit_rate else input_factory[k]()) - for _ in range(rng.randint(limit))] + for _ in range(rng.randint(limit + 1))] for k, limit in limit_mm_per_prompt.items() 
} diff --git a/tests/multimodal/utils.py b/tests/multimodal/utils.py index 9a336b7e6..40fcfeeea 100644 --- a/tests/multimodal/utils.py +++ b/tests/multimodal/utils.py @@ -17,10 +17,7 @@ def random_video( min_wh: int, max_wh: int, ): - # Temporary workaround for https://github.com/huggingface/transformers/issues/35412 num_frames = rng.randint(min_frames, max_frames) - num_frames = (num_frames // 2) * 2 - w, h = rng.randint(min_wh, max_wh, size=(2, )) return rng.randint(0, 255, size=(num_frames, w, h, 3), dtype=np.uint8) diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 9ee9e9ca8..153c85cfb 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -4,8 +4,8 @@ # https://github.com/THUDM/CogAgent """Inference-only CogAgent model compatible with THUDM weights.""" from argparse import Namespace -from typing import (Iterable, List, Mapping, Optional, Sequence, Set, Tuple, - TypedDict, Union) +from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict, + Union) import torch from torch import nn @@ -19,7 +19,6 @@ from transformers.tokenization_utils_base import TextInput from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size -from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -37,12 +36,10 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors -from vllm.multimodal.parse import ImageSize, MultiModalDataItems +from vllm.multimodal.parse import MultiModalDataItems from 
vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, BatchFeature, - BoundPromptReplacement, MultiModalFieldConfig, - PlaceholderFeaturesInfo, PromptReplacement) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors @@ -53,39 +50,6 @@ from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix, merge_multimodal_embeddings) -logger = init_logger(__name__) - -IMAGE_TOKEN_ID = 151329 - - -def build_normalization_transform(image_size: int) -> transforms.Compose: - """ - Build a normalization transform which can be applied to one or - more input images from which we want to extract visual features. - - Args: - image_size: size of the image to be processed for visual embeddings. - - Returns: - Callable transform for normalizing and resizing one RGB image. - """ - - return transforms.Compose([ - transforms.Resize( - (image_size, image_size), - interpolation=InterpolationMode.BICUBIC, - ), - transforms.ToTensor(), - transforms.Normalize( - (0.48145466, 0.4578275, 0.40821073), - (0.26862954, 0.26130258, 0.27577711), - ), - ]) - - -def calculate_image_placeholder(vision_config): - return (vision_config["image_size"] // vision_config["patch_size"] // 2)**2 - class GLMImagePixelInputs(TypedDict): pixel_values: torch.Tensor @@ -109,9 +73,20 @@ class GLM4VProcessor: self.config = config self.tokenizer = tokenizer - if hasattr(self.config, "vision_config"): - self.image_transform = build_normalization_transform( - config.vision_config["image_size"]) + if vision_config := getattr(config, "vision_config", None): + image_size = vision_config["image_size"] + + self.image_transform = transforms.Compose([ + transforms.Resize( + (image_size, image_size), + interpolation=InterpolationMode.BICUBIC, + ), + transforms.ToTensor(), + transforms.Normalize( + mean=(0.48145466, 0.4578275, 0.40821073), + 
std=(0.26862954, 0.26130258, 0.27577711), + ), + ]) else: self.image_transform = None @@ -150,9 +125,19 @@ class GLM4VProcessor: class GLM4VProcessingInfo(BaseProcessingInfo): - def __init__(self, ctx): - super().__init__(ctx) - self._pre_calculate() + def get_tokenizer(self): + tokenizer = self.ctx.tokenizer + assert isinstance(tokenizer, PreTrainedTokenizer) + return tokenizer + + def get_hf_config(self): + return self.ctx.get_hf_config(ChatGLMConfig) + + def get_hf_processor(self) -> GLM4VProcessor: + return GLM4VProcessor( + self.get_hf_config(), + self.get_tokenizer(), + ) def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": 1} @@ -162,27 +147,21 @@ class GLM4VProcessingInfo(BaseProcessingInfo): seq_len: int, mm_counts: Mapping[str, int], ) -> Mapping[str, int]: + return {"image": self.get_num_image_feature_tokens()} - return {"image": self.image_token_num + 2} - - def _pre_calculate(self): + def get_num_image_tokens(self) -> int: hf_config = self.get_hf_config() - vision_config = hf_config.vision_config - self.image_token_num = calculate_image_placeholder(vision_config) - self.image_size = vision_config["image_size"] + if not (vision_config := getattr(hf_config, "vision_config", None)): + return 0 - def get_num_image_tokens(self) -> int: - return self.image_token_num + 2 + image_size = vision_config["image_size"] + patch_size = vision_config["patch_size"] + grid_length = image_size // patch_size // 2 + return grid_length * grid_length - def get_image_size(self) -> ImageSize: - - return ImageSize(height=self.image_size, width=self.image_size) - - def get_hf_processor(self) -> GLM4VProcessor: - return GLM4VProcessor( - self.get_hf_config(), - self.get_tokenizer(), - ) + def get_num_image_feature_tokens(self) -> int: + # EVA2CLIPModel has embeddings for boi and eoi tokens as well + return self.get_num_image_tokens() + 2 class GLM4VDummyInputsBuilder(BaseDummyInputsBuilder[GLM4VProcessingInfo]): @@ -192,8 +171,12 @@ class 
GLM4VDummyInputsBuilder(BaseDummyInputsBuilder[GLM4VProcessingInfo]): seq_len: int, mm_counts: Mapping[str, int], ) -> ProcessorInputs: + hf_config = self.info.get_hf_config() + if not (vision_config := getattr(hf_config, "vision_config", None)): + return ProcessorInputs(prompt_text="", mm_data={}) + + target_width = target_height = vision_config["image_size"] num_images = mm_counts.get("image", 0) - target_width, target_height = self.info.get_image_size() mm_data = { "image": @@ -201,9 +184,11 @@ class GLM4VDummyInputsBuilder(BaseDummyInputsBuilder[GLM4VProcessingInfo]): height=target_height, num_images=num_images) } - text = "<|begin_of_image|><|endoftext|><|end_of_image|>" + + base_text = "<|begin_of_image|><|endoftext|><|end_of_image|>" + return ProcessorInputs( - prompt_text=text, + prompt_text=base_text * num_images, mm_data=mm_data, ) @@ -223,47 +208,28 @@ class GLM4VMultiModalProcessor(BaseMultiModalProcessor[GLM4VProcessingInfo]): hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, ) -> list[PromptReplacement]: + hf_config = self.info.get_hf_config() + if not hasattr(hf_config, "vision_config"): + return [] + + boi_token_id = hf_config.boi_token_id + image_token_id = hf_config.pad_token_id + eoi_token_id = hf_config.eoi_token_id def get_replacement(item_idx: int): - image_tokens = self.info.image_token_num - return [IMAGE_TOKEN_ID] * image_tokens + num_image_tokens = self.info.get_num_image_tokens() + image_tokens = [image_token_id] * num_image_tokens + + return [boi_token_id] + image_tokens + [eoi_token_id] return [ PromptReplacement( modality="image", - target=[IMAGE_TOKEN_ID], + target=[boi_token_id, image_token_id, eoi_token_id], replacement=get_replacement, ), ] - def _apply_prompt_replacements( - self, - token_ids: list[int], - mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]], - mm_item_counts: Mapping[str, int], - ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]: - token_ids, text, 
placeholders = super()._apply_prompt_replacements( - token_ids=token_ids, - mm_prompt_repls=mm_prompt_repls, - mm_item_counts=mm_item_counts, - ) - hf_config = self.info.get_hf_config() - boi_token_id = hf_config.boi_token_id - eoi_token_id = hf_config.eoi_token_id - placeholders = { - modality: [ - PlaceholderFeaturesInfo( - modality=p.modality, - item_idx=p.item_idx, - start_idx=p.start_idx - 1, - tokens=[boi_token_id] + p.tokens + [eoi_token_id], - ) for p in ps - ] - for modality, ps in placeholders.items() - } - - return token_ids, text, placeholders - class GLMAttention(nn.Module): @@ -618,7 +584,7 @@ class ChatGLMModel(nn.Module): multimodal_embeddings=multimodal_embeddings, placeholder_token_id=[ self.config.boi_token_id, - IMAGE_TOKEN_ID, + self.config.pad_token_id, self.config.eoi_token_id, ], ) diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 897066124..4b8aeaddb 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -63,18 +63,6 @@ from .utils import (flatten_bn, is_pp_missing_parameter, logger = init_logger(__name__) -# NOTE: Qwen models have a few other special tags, e.g., ref, bbox, quad; -# for the time being, these tags are not considered as special at encoding -# time. This may change as VLLMs multimodal API changes in the future. -IMG_START = "" -IMG_END = "" -IMG_PAD = "" -# Image context is fixed at 256 for all images -MAX_QWEN_IMG_TOKENS = 256 -# Image normalization params -CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073) -CLIP_STD = (0.26862954, 0.26130258, 0.27577711) - class QwenImagePixelInputs(TypedDict): type: Literal["pixel_values"] @@ -622,25 +610,6 @@ class QWenModel(nn.Module): return hidden_states -def build_normalization_transform(image_size: int) -> transforms.Compose: - """ - Build a normalization transform which can be applied to one or - more input images from which we want to extract visual features. 
- - Args: - image_size: size of the image to be processed for visual embeddings. - - Returns: - Callable transform for normalizing and resizing one RGB image. - """ - return transforms.Compose([ - transforms.Resize((image_size, image_size), - interpolation=InterpolationMode.BICUBIC), - transforms.ToTensor(), - transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD), - ]) - - @lru_cache(maxsize=1) def _get_tokenizer_without_image_pad( tokenizer: PreTrainedTokenizer) -> PreTrainedTokenizer: @@ -716,16 +685,34 @@ class QWenVLProcessor: self.config = config self.tokenizer = tokenizer - if hasattr(self.config, "visual"): - self.image_transform = build_normalization_transform( - config.visual["image_size"]) + if vision_config := getattr(self.config, "visual", None): + image_size = vision_config["image_size"] + + self.image_transform = transforms.Compose([ + transforms.Resize( + (image_size, image_size), + interpolation=InterpolationMode.BICUBIC, + ), + transforms.ToTensor(), + transforms.Normalize( + mean=(0.48145466, 0.4578275, 0.40821073), + std=(0.26862954, 0.26130258, 0.27577711), + ), + ]) else: self.image_transform = None - special_tokens: dict[str, - int] = tokenizer.special_tokens # type: ignore - self.img_start_id = special_tokens[IMG_START] - self.img_end_id = special_tokens[IMG_END] + @property + def image_start_tag(self) -> str: + return self.tokenizer.image_start_tag # type: ignore + + @property + def image_end_tag(self) -> str: + return self.tokenizer.image_end_tag # type: ignore + + @property + def image_pad_tag(self) -> str: + return self.tokenizer.image_pad_tag # type: ignore def __call__( self, @@ -787,7 +774,14 @@ class QWenVLProcessingInfo(BaseProcessingInfo): return {"image": self.get_num_image_tokens()} def get_num_image_tokens(self) -> int: - return MAX_QWEN_IMG_TOKENS + hf_config = self.get_hf_config() + if not (vision_config := getattr(hf_config, "visual", None)): + return 0 + + image_size = vision_config["image_size"] + patch_size = 
vision_config["patch_size"] + grid_length = image_size // patch_size // 2 + return grid_length * grid_length class QWenVLDummyInputsBuilder(BaseDummyInputsBuilder[QWenVLProcessingInfo]): @@ -798,10 +792,12 @@ class QWenVLDummyInputsBuilder(BaseDummyInputsBuilder[QWenVLProcessingInfo]): mm_counts: Mapping[str, int], ) -> ProcessorInputs: hf_config = self.info.get_hf_config() - if not hasattr(hf_config, "visual"): + if not (vision_config := getattr(hf_config, "visual", None)): return ProcessorInputs(prompt_text="", mm_data={}) - vision_config = hf_config.visual + processor = self.info.get_hf_processor() + img_start = processor.image_start_tag + img_end = processor.image_end_tag target_width = target_height = vision_config["image_size"] num_images = mm_counts.get("image", 0) @@ -814,7 +810,7 @@ class QWenVLDummyInputsBuilder(BaseDummyInputsBuilder[QWenVLProcessingInfo]): } return ProcessorInputs( - prompt_text="".join(f"Picture {i}: {IMG_START}{IMG_END}\n" + prompt_text="".join(f"Picture {i}: {img_start}{img_end}\n" for i in range(1, num_images + 1)), mm_data=mm_data, ) @@ -869,13 +865,18 @@ class QWenVLMultiModalProcessor(BaseMultiModalProcessor[QWenVLProcessingInfo]): hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, ) -> list[PromptReplacement]: + hf_config = self.info.get_hf_config() + if not hasattr(hf_config, "visual"): + return [] + tokenizer = self.info.get_tokenizer() special_tokens: dict[str, int] = tokenizer.special_tokens # type: ignore - img_start_id = special_tokens[IMG_START] - img_end_id = special_tokens[IMG_END] - img_pad_id = special_tokens[IMG_PAD] + processor = self.info.get_hf_processor() + img_start_id = special_tokens[processor.image_start_tag] + img_end_id = special_tokens[processor.image_end_tag] + img_pad_id = special_tokens[processor.image_pad_tag] num_image_tokens = self.info.get_num_image_tokens() image_tokens = [img_pad_id] * num_image_tokens diff --git a/vllm/model_executor/models/qwen2_vl.py 
b/vllm/model_executor/models/qwen2_vl.py index 34ae7b8c9..f2071eaff 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -885,14 +885,10 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo): max_image_tokens = self.get_max_image_tokens() * max_images max_total_frames = self._get_max_video_frames(seq_len - max_image_tokens) - num_frames = min(max(max_total_frames // max(max_videos, 1), 1), - _MAX_FRAMES_PER_VIDEO) + max_frames_per_video = min(max_total_frames // max(max_videos, 1), + _MAX_FRAMES_PER_VIDEO) - # Temporary workaround for https://github.com/huggingface/transformers/issues/35412 - if num_frames > 1 and num_frames % 2 == 1: - num_frames += 1 - - return num_frames + return max(max_frames_per_video, 1) def get_max_video_tokens(self, seq_len: int) -> int: target_width, target_height = self.get_image_size_with_most_features() -- GitLab From 2ae889052c6d0205ca677052ddb41db96a2a2620 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E0=AE=AE=E0=AE=A9=E0=AF=8B=E0=AE=9C=E0=AF=8D=E0=AE=95?= =?UTF-8?q?=E0=AF=81=E0=AE=AE=E0=AE=BE=E0=AE=B0=E0=AF=8D=20=E0=AE=AA?= =?UTF-8?q?=E0=AE=B4=E0=AE=A9=E0=AE=BF=E0=AE=9A=E0=AF=8D=E0=AE=9A=E0=AE=BE?= =?UTF-8?q?=E0=AE=AE=E0=AE=BF?= Date: Mon, 10 Feb 2025 20:56:50 +0530 Subject: [PATCH 057/253] Fix seed parameter behavior in vLLM (#13007) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: மனோஜ்குமார் பழனிச்சாமி --- docs/seed_parameter_behavior.md | 51 +++++++++++++++++++++++++++++++++ tests/test_seed_behavior.py | 39 +++++++++++++++++++++++++ vllm/platforms/interface.py | 9 +++--- 3 files changed, 95 insertions(+), 4 deletions(-) create mode 100644 docs/seed_parameter_behavior.md create mode 100644 tests/test_seed_behavior.py diff --git a/docs/seed_parameter_behavior.md b/docs/seed_parameter_behavior.md new file mode 100644 index 000000000..ff17525cf --- /dev/null +++ b/docs/seed_parameter_behavior.md @@ -0,0 +1,51 @@ +# Seed Parameter Behavior 
in vLLM + +## Overview + +The `seed` parameter in vLLM is used to control the random states for various random number generators. This parameter can affect the behavior of random operations in user code, especially when working with models in vLLM. + +## Default Behavior + +By default, the `seed` parameter is set to `None`. When the `seed` parameter is `None`, the global random states for `random`, `np.random`, and `torch.manual_seed` are not set. This means that the random operations will behave as expected, without any fixed random states. + +## Specifying a Seed + +If a specific seed value is provided, the global random states for `random`, `np.random`, and `torch.manual_seed` will be set accordingly. This can be useful for reproducibility, as it ensures that the random operations produce the same results across multiple runs. + +## Example Usage + +### Without Specifying a Seed + +```python +import random +from vllm import LLM + +# Initialize a vLLM model without specifying a seed +model = LLM(model="Qwen/Qwen2.5-0.5B-Instruct") + +# Try generating random numbers +print(random.randint(0, 100)) # Outputs different numbers across runs +``` + +### Specifying a Seed + +```python +import random +from vllm import LLM + +# Initialize a vLLM model with a specific seed +model = LLM(model="Qwen/Qwen2.5-0.5B-Instruct", seed=42) + +# Try generating random numbers +print(random.randint(0, 100)) # Outputs the same number across runs +``` + +## Important Notes + +- If the `seed` parameter is not specified, the behavior of global random states remains unaffected. +- If a specific seed value is provided, the global random states for `random`, `np.random`, and `torch.manual_seed` will be set to that value. +- This behavior can be useful for reproducibility but may lead to non-intuitive behavior if the user is not explicitly aware of it. 
+ +## Conclusion + +Understanding the behavior of the `seed` parameter in vLLM is crucial for ensuring the expected behavior of random operations in your code. By default, the `seed` parameter is set to `None`, which means that the global random states are not affected. However, specifying a seed value can help achieve reproducibility in your experiments. diff --git a/tests/test_seed_behavior.py b/tests/test_seed_behavior.py new file mode 100644 index 000000000..7e4e71563 --- /dev/null +++ b/tests/test_seed_behavior.py @@ -0,0 +1,39 @@ +# SPDX-License-Identifier: Apache-2.0 +import random + +import numpy as np +import torch + +from vllm.platforms.interface import Platform + + +def test_seed_behavior(): + # Test with seed=None + Platform.seed_everything(None) + random_value_1 = random.randint(0, 100) + np_random_value_1 = np.random.randint(0, 100) + torch_random_value_1 = torch.randint(0, 100, (1, )).item() + + Platform.seed_everything(None) + random_value_2 = random.randint(0, 100) + np_random_value_2 = np.random.randint(0, 100) + torch_random_value_2 = torch.randint(0, 100, (1, )).item() + + assert random_value_1 != random_value_2 + assert np_random_value_1 != np_random_value_2 + assert torch_random_value_1 != torch_random_value_2 + + # Test with a specific seed + Platform.seed_everything(42) + random_value_3 = random.randint(0, 100) + np_random_value_3 = np.random.randint(0, 100) + torch_random_value_3 = torch.randint(0, 100, (1, )).item() + + Platform.seed_everything(42) + random_value_4 = random.randint(0, 100) + np_random_value_4 = np.random.randint(0, 100) + torch_random_value_4 = torch.randint(0, 100, (1, )).item() + + assert random_value_3 == random_value_4 + assert np_random_value_3 == np_random_value_4 + assert torch_random_value_3 == torch_random_value_4 diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 211e288b1..645d98a1b 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -211,16 +211,17 @@ class 
Platform: return torch.inference_mode(mode=True) @classmethod - def seed_everything(cls, seed: int) -> None: + def seed_everything(cls, seed: Optional[int] = None) -> None: """ Set the seed of each random module. `torch.manual_seed` will set seed on all devices. Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20 """ - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) + if seed is not None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) @classmethod def check_and_update_config(cls, vllm_config: VllmConfig) -> None: -- GitLab From 08b2d845d6261309bfdb46933f872eebe4e2bb31 Mon Sep 17 00:00:00 2001 From: Farzad Abdolhosseini Date: Mon, 10 Feb 2025 14:02:48 -0800 Subject: [PATCH 058/253] [Model] Ultravox Model: Support v0.5 Release (#12912) Signed-off-by: Farzad Abdolhosseini --- docs/source/models/supported_models.md | 2 +- docs/source/serving/multimodal_inputs.md | 4 +-- examples/offline_inference/audio_language.py | 4 +-- ...i_chat_completion_client_for_multimodal.py | 2 +- tests/distributed/test_pipeline_parallel.py | 4 +-- tests/entrypoints/openai/test_audio.py | 2 +- tests/entrypoints/test_chat_utils.py | 2 +- .../audio_language/test_ultravox.py | 2 +- .../multimodal/processing/test_common.py | 2 +- tests/models/registry.py | 2 +- vllm/model_executor/models/ultravox.py | 26 ++++++++++++------- vllm/transformers_utils/configs/ultravox.py | 6 +++++ 12 files changed, 36 insertions(+), 22 deletions(-) diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index 91e6c42d5..55b3f5235 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -856,7 +856,7 @@ See [this page](#generative-models) for more information on how to use generativ - * `UltravoxModel` * Ultravox * T + AE+ - * `fixie-ai/ultravox-v0_3` + * `fixie-ai/ultravox-v0_5-llama-3_2-1b` * ✅︎ * ✅︎ * ✅︎ diff --git 
a/docs/source/serving/multimodal_inputs.md b/docs/source/serving/multimodal_inputs.md index 217b531e8..ade59e377 100644 --- a/docs/source/serving/multimodal_inputs.md +++ b/docs/source/serving/multimodal_inputs.md @@ -359,12 +359,12 @@ export VLLM_VIDEO_FETCH_TIMEOUT= ### Audio Audio input is supported according to [OpenAI Audio API](https://platform.openai.com/docs/guides/audio?audio-generation-quickstart-example=audio-in). -Here is a simple example using Ultravox-v0.3. +Here is a simple example using Ultravox-v0.5-1B. First, launch the OpenAI-compatible server: ```bash -vllm serve fixie-ai/ultravox-v0_3 +vllm serve fixie-ai/ultravox-v0_5-llama-3_2-1b ``` Then, you can use the OpenAI client as follows: diff --git a/examples/offline_inference/audio_language.py b/examples/offline_inference/audio_language.py index 707ca9f87..3e3034a02 100644 --- a/examples/offline_inference/audio_language.py +++ b/examples/offline_inference/audio_language.py @@ -24,9 +24,9 @@ question_per_audio_count = { # Unless specified, these settings have been tested to work on a single L4. 
-# Ultravox 0.3 +# Ultravox 0.5-1B def run_ultravox(question: str, audio_count: int): - model_name = "fixie-ai/ultravox-v0_3" + model_name = "fixie-ai/ultravox-v0_5-llama-3_2-1b" tokenizer = AutoTokenizer.from_pretrained(model_name) messages = [{ diff --git a/examples/online_serving/openai_chat_completion_client_for_multimodal.py b/examples/online_serving/openai_chat_completion_client_for_multimodal.py index d5f798a8d..ecfcf05a9 100644 --- a/examples/online_serving/openai_chat_completion_client_for_multimodal.py +++ b/examples/online_serving/openai_chat_completion_client_for_multimodal.py @@ -12,7 +12,7 @@ vllm serve microsoft/Phi-3.5-vision-instruct --task generate \ --trust-remote-code --max-model-len 4096 --limit-mm-per-prompt image=2 (audio inference with Ultravox) -vllm serve fixie-ai/ultravox-v0_3 --max-model-len 4096 +vllm serve fixie-ai/ultravox-v0_5-llama-3_2-1b --max-model-len 4096 """ import base64 diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 5b6741d74..5d7cb9e40 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -215,7 +215,7 @@ MULTIMODAL_MODELS = { "Qwen/Qwen-VL-Chat": PPTestSettings.fast(trust_remote_code=True), "Qwen/Qwen2-Audio-7B-Instruct": PPTestSettings.fast(), "Qwen/Qwen2-VL-2B-Instruct": PPTestSettings.fast(), - "fixie-ai/ultravox-v0_3": PPTestSettings.fast(trust_remote_code=True), + "fixie-ai/ultravox-v0_5-llama-3_2-1b": PPTestSettings.fast(trust_remote_code=True), # noqa: E501 # [Encoder-decoder] # TODO: Implement PP # "meta-llama/Llama-3.2-11B-Vision-Instruct": PPTestSettings.fast(), @@ -234,7 +234,7 @@ TEST_MODELS = [ # [MULTIMODAL GENERATION] "OpenGVLab/InternVL2-1B", "microsoft/Phi-3-vision-128k-instruct", - "fixie-ai/ultravox-v0_3", + "fixie-ai/ultravox-v0_5-llama-3_2-1b", # [LANGUAGE GENERATION - HYBRID ARCH] "ai21labs/Jamba-tiny-dev", ] diff --git a/tests/entrypoints/openai/test_audio.py 
b/tests/entrypoints/openai/test_audio.py index 3459f2483..fe7299a48 100644 --- a/tests/entrypoints/openai/test_audio.py +++ b/tests/entrypoints/openai/test_audio.py @@ -11,7 +11,7 @@ from vllm.multimodal.utils import encode_audio_base64, fetch_audio from ...utils import RemoteOpenAIServer -MODEL_NAME = "fixie-ai/ultravox-v0_3" +MODEL_NAME = "fixie-ai/ultravox-v0_5-llama-3_2-1b" TEST_AUDIO_URLS = [ AudioAsset("winning_call").url, ] diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py index 5c469007a..c52fa905c 100644 --- a/tests/entrypoints/test_chat_utils.py +++ b/tests/entrypoints/test_chat_utils.py @@ -21,7 +21,7 @@ from ..utils import VLLM_PATH EXAMPLES_DIR = VLLM_PATH / "examples" PHI3V_MODEL_ID = "microsoft/Phi-3.5-vision-instruct" -ULTRAVOX_MODEL_ID = "fixie-ai/ultravox-v0_3" +ULTRAVOX_MODEL_ID = "fixie-ai/ultravox-v0_5-llama-3_2-1b" QWEN2VL_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct" MLLAMA_MODEL_ID = "meta-llama/Llama-3.2-11B-Vision-Instruct" LLAMA_GUARD_MODEL_ID = "meta-llama/Llama-Guard-3-1B" diff --git a/tests/models/decoder_only/audio_language/test_ultravox.py b/tests/models/decoder_only/audio_language/test_ultravox.py index fe9361d12..d1f643a8f 100644 --- a/tests/models/decoder_only/audio_language/test_ultravox.py +++ b/tests/models/decoder_only/audio_language/test_ultravox.py @@ -15,7 +15,7 @@ from ....conftest import HfRunner, VllmRunner from ....utils import RemoteOpenAIServer from ...utils import check_logprobs_close -MODEL_NAME = "fixie-ai/ultravox-v0_3" +MODEL_NAME = "fixie-ai/ultravox-v0_5-llama-3_2-1b" AudioTuple = Tuple[np.ndarray, int] diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index a56a9e2be..6244056c7 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -164,7 +164,7 @@ def _test_processing_correctness( "Qwen/Qwen2-VL-2B-Instruct", "Qwen/Qwen2.5-VL-3B-Instruct", 
"Qwen/Qwen2-Audio-7B-Instruct", - "fixie-ai/ultravox-v0_3", + "fixie-ai/ultravox-v0_5-llama-3_2-1b", ]) @pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0]) @pytest.mark.parametrize("num_batches", [32]) diff --git a/tests/models/registry.py b/tests/models/registry.py index 3fd94b89c..66b7d3c2e 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -267,7 +267,7 @@ _MULTIMODAL_EXAMPLE_MODELS = { "Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"), # noqa: E501 "Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct", # noqa: E501 min_transformers_version="4.49"), # noqa: E501 - "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_3", + "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b", trust_remote_code=True), # [Encoder-decoder] "MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"), # noqa: E501 diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 9da0682cf..063997a14 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -258,27 +258,35 @@ class UltravoxProjector(nn.Module): super().__init__() self.hidden_dim = config.hidden_size self._pad_and_stack = StackAudioFrames(config.stack_factor) - dim = config.audio_config.hidden_size * config.stack_factor - self.ln_pre = RMSNorm(dim) - self.linear_1 = nn.Linear(dim, self.hidden_dim, bias=False) - dim = self.hidden_dim + dim_in = config.audio_config.hidden_size * config.stack_factor + self.ln_pre = RMSNorm(dim_in) + self.linear_1 = nn.Linear(dim_in, self.hidden_dim, bias=False) + dim_mid = self.hidden_dim if config.projector_act == "swiglu": self.act = MulAndSilu() - dim = dim // 2 + dim_mid = dim_mid // 2 else: self.act = get_act_fn(config.projector_act) - self.linear_2 = nn.Linear(dim, - config.text_config.hidden_size, - bias=False) - self.ln_post = RMSNorm(config.text_config.hidden_size) + 
dim_out = config.text_config.hidden_size + self.linear_2 = nn.Linear(dim_mid, dim_out, bias=False) + + # Ultravox v0.4.1 and below use layer_norm after the second linear layer + # while v0.5.0 and above uses layer_norm after the first linear layer. + if config.projector_ln_mid: + self.ln_mid: nn.Module = RMSNorm(dim_mid) + self.ln_post = nn.Identity() + else: + self.ln_mid = nn.Identity() + self.ln_post = RMSNorm(dim_out) def forward(self, audio_features: torch.Tensor) -> torch.Tensor: audio_features = self._pad_and_stack(audio_features) audio_features = self.ln_pre(audio_features) hidden_states = self.linear_1(audio_features) hidden_states = self.act(hidden_states) + hidden_states = self.ln_mid(hidden_states) hidden_states = self.linear_2(hidden_states) hidden_states = self.ln_post(hidden_states) return hidden_states diff --git a/vllm/transformers_utils/configs/ultravox.py b/vllm/transformers_utils/configs/ultravox.py index 99715ba6d..6b2765db9 100644 --- a/vllm/transformers_utils/configs/ultravox.py +++ b/vllm/transformers_utils/configs/ultravox.py @@ -37,6 +37,10 @@ class UltravoxConfig(transformers.PretrainedConfig): The LoRA configuration for finetuning the text model. audio_model_lora_config (`LoraConfigSimplified`, *optional*): The LoRA configuration for finetuning the audio model. + projector_ln_mid (`bool`, *optional*, defaults to `False`): + Whether to apply layer normalization at the middle of the + projector or at the end. Versions v0.4.1 and below + use `False`, but v0.5 and above use `True`. 
""" model_type = "ultravox" @@ -56,6 +60,7 @@ class UltravoxConfig(transformers.PretrainedConfig): projector_act: str = "swiglu", text_model_lora_config: Optional[Dict[str, Any]] = None, audio_model_lora_config: Optional[Dict[str, Any]] = None, + projector_ln_mid: bool = False, **kwargs, ): self.ignore_index = ignore_index @@ -68,6 +73,7 @@ class UltravoxConfig(transformers.PretrainedConfig): self.stack_factor = stack_factor self.norm_init = norm_init self.projector_act = projector_act + self.projector_ln_mid = projector_ln_mid if text_model_id is not None: # Avoid circular import -- GitLab From 91e876750eace8e899ab25cd5d93fc365906c07b Mon Sep 17 00:00:00 2001 From: "Kevin H. Luu" Date: Mon, 10 Feb 2025 18:06:16 -0800 Subject: [PATCH 059/253] [misc] Fix setup.py condition to avoid AMD from being mistaken with CPU (#13022) Signed-off-by: kevin --- setup.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 3e2adadf6..27e5aab76 100755 --- a/setup.py +++ b/setup.py @@ -48,8 +48,9 @@ elif not (sys.platform.startswith("linux") "so vLLM may not be able to run correctly", sys.platform) VLLM_TARGET_DEVICE = "empty" elif (sys.platform.startswith("linux") and torch.version.cuda is None - and os.getenv("VLLM_TARGET_DEVICE") is None): - # if cuda is not available and VLLM_TARGET_DEVICE is not set, + and os.getenv("VLLM_TARGET_DEVICE") is None + and torch.version.hip is None): + # if cuda or hip is not available and VLLM_TARGET_DEVICE is not set, # fallback to cpu VLLM_TARGET_DEVICE = "cpu" -- GitLab From 2ff4857678044407a959398178a7a04a9530919a Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 10 Feb 2025 18:10:06 -0800 Subject: [PATCH 060/253] [V1][Minor] Move scheduler outputs to a separate file (#13062) Signed-off-by: Woosuk Kwon --- vllm/v1/core/scheduler.py | 89 +----------------------- vllm/v1/core/scheduler_output.py | 108 +++++++++++++++++++++++++++++ vllm/v1/worker/gpu_model_runner.py | 2 +- 
vllm/v1/worker/gpu_worker.py | 3 +- 4 files changed, 113 insertions(+), 89 deletions(-) create mode 100644 vllm/v1/core/scheduler_output.py diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 1aa34ee38..1c54914d1 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -1,26 +1,20 @@ # SPDX-License-Identifier: Apache-2.0 from collections import deque -from dataclasses import dataclass -from typing import (TYPE_CHECKING, Deque, Dict, Iterable, List, Optional, Set, - Tuple, Union) +from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union from vllm.config import CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig from vllm.logger import init_logger -from vllm.lora.request import LoRARequest -from vllm.sampling_params import SamplingParams from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, compute_encoder_budget) from vllm.v1.core.kv_cache_manager import KVCacheManager +from vllm.v1.core.scheduler_output import (CachedRequestData, NewRequestData, + SchedulerOutput) from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus -if TYPE_CHECKING: - from vllm.multimodal import MultiModalKwargs - from vllm.multimodal.base import PlaceholderRange - logger = init_logger(__name__) @@ -600,80 +594,3 @@ class Scheduler: num_waiting_reqs=len(self.waiting), gpu_cache_usage=self.kv_cache_manager.usage, ) - - -@dataclass -class NewRequestData: - - req_id: str - prompt_token_ids: List[int] - prompt: Optional[str] - mm_inputs: List["MultiModalKwargs"] - mm_hashes: List[str] - mm_positions: List["PlaceholderRange"] - sampling_params: SamplingParams - block_ids: List[int] - num_computed_tokens: int - lora_request: Optional[LoRARequest] - - @classmethod - def from_request( - cls, - request: Request, - block_ids: List[int], - num_computed_tokens: int, - ) -> 
"NewRequestData": - return cls( - req_id=request.request_id, - prompt_token_ids=request.prompt_token_ids, - prompt=request.prompt, - mm_inputs=request.mm_inputs, - mm_hashes=request.mm_hashes, - mm_positions=request.mm_positions, - sampling_params=request.sampling_params, - block_ids=block_ids, - num_computed_tokens=num_computed_tokens, - lora_request=request.lora_request, - ) - - -@dataclass -class CachedRequestData: - - req_id: str - # If resumed_from_preemption is False, new_block_ids will be appended to - # the request's block IDs. If True, new_block_ids will be used as the - # request's block IDs instead of appending to the existing block IDs. - resumed_from_preemption: bool - new_block_ids: List[int] - num_computed_tokens: int - - @classmethod - def from_request( - cls, - request: Request, - resumed_from_preemption: bool, - new_block_ids: List[int], - num_computed_tokens: int, - ) -> "CachedRequestData": - return cls( - req_id=request.request_id, - resumed_from_preemption=resumed_from_preemption, - new_block_ids=new_block_ids, - num_computed_tokens=num_computed_tokens, - ) - - -@dataclass -class SchedulerOutput: - - scheduled_new_reqs: List[NewRequestData] - scheduled_cached_reqs: List[CachedRequestData] - - num_scheduled_tokens: Dict[str, int] - total_num_scheduled_tokens: int - scheduled_encoder_inputs: Dict[str, List[int]] - num_common_prefix_blocks: int - - finished_req_ids: Set[str] - free_encoder_input_ids: List[Tuple[str, int]] diff --git a/vllm/v1/core/scheduler_output.py b/vllm/v1/core/scheduler_output.py new file mode 100644 index 000000000..990b3dd0e --- /dev/null +++ b/vllm/v1/core/scheduler_output.py @@ -0,0 +1,108 @@ +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple + +if TYPE_CHECKING: + from vllm.lora.request import LoRARequest + from vllm.multimodal import MultiModalKwargs + from vllm.multimodal.base import PlaceholderRange + from 
vllm.sampling_params import SamplingParams + from vllm.v1.request import Request + + +@dataclass +class NewRequestData: + + req_id: str + prompt_token_ids: List[int] + prompt: Optional[str] + mm_inputs: List["MultiModalKwargs"] + mm_hashes: List[str] + mm_positions: List["PlaceholderRange"] + sampling_params: "SamplingParams" + block_ids: List[int] + num_computed_tokens: int + lora_request: Optional["LoRARequest"] + + @classmethod + def from_request( + cls, + request: "Request", + block_ids: List[int], + num_computed_tokens: int, + ) -> "NewRequestData": + return cls( + req_id=request.request_id, + prompt_token_ids=request.prompt_token_ids, + prompt=request.prompt, + mm_inputs=request.mm_inputs, + mm_hashes=request.mm_hashes, + mm_positions=request.mm_positions, + sampling_params=request.sampling_params, + block_ids=block_ids, + num_computed_tokens=num_computed_tokens, + lora_request=request.lora_request, + ) + + +@dataclass +class CachedRequestData: + + req_id: str + # If resumed_from_preemption is False, new_block_ids will be appended to + # the request's block IDs. If True, new_block_ids will be used as the + # request's block IDs instead of appending to the existing block IDs. + resumed_from_preemption: bool + new_block_ids: List[int] + num_computed_tokens: int + + @classmethod + def from_request( + cls, + request: "Request", + resumed_from_preemption: bool, + new_block_ids: List[int], + num_computed_tokens: int, + ) -> "CachedRequestData": + return cls( + req_id=request.request_id, + resumed_from_preemption=resumed_from_preemption, + new_block_ids=new_block_ids, + num_computed_tokens=num_computed_tokens, + ) + + +@dataclass +class SchedulerOutput: + + # List of the requests that are scheduled for the first time. + # We cache the request's data in each worker process, so that we don't + # need to re-send it every scheduling step. + scheduled_new_reqs: List[NewRequestData] + # List of the requests that have been scheduled before. 
+ # Since the request's data is already cached in the worker processes, + # we only send the diff to minimize the communication cost. + scheduled_cached_reqs: List[CachedRequestData] + + # req_id -> num_scheduled_tokens + # Number of tokens scheduled for each request. + num_scheduled_tokens: Dict[str, int] + # Total number of tokens scheduled for all requests. + # Equal to sum(num_scheduled_tokens.values()) + total_num_scheduled_tokens: int + # req_id -> encoder input indices that need processing. + # E.g., if a request has [0, 1], it could mean the vision encoder needs + # to process that the request's 0-th and 1-th images in the current step. + scheduled_encoder_inputs: Dict[str, List[int]] + # Number of common prefix blocks for all requests. + # This can be used for cascade attention. + num_common_prefix_blocks: int + + # Request IDs that are finished in between the previous and the current + # steps. This is used to notify the workers about the finished requests + # so that they can free the cached states for those requests. + finished_req_ids: Set[str] + # List of (req_id, encoder_input_index) tuples. + # Used to free the encoder cache. 
+ free_encoder_input_ids: List[Tuple[str, int]] diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index fdbca70bd..9b1eab613 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -36,7 +36,7 @@ from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin if TYPE_CHECKING: - from vllm.v1.core.scheduler import SchedulerOutput + from vllm.v1.core.scheduler_output import SchedulerOutput logger = init_logger(__name__) diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 0adb69073..ad53f90b8 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -18,7 +18,6 @@ from vllm.logger import init_logger from vllm.model_executor import set_random_seed from vllm.platforms import current_platform from vllm.utils import GiB_bytes -from vllm.v1.core.scheduler import SchedulerOutput from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.worker.gpu_model_runner import GPUModelRunner @@ -26,7 +25,7 @@ from vllm.v1.worker.gpu_model_runner import GPUModelRunner logger = init_logger(__name__) if TYPE_CHECKING: - from vllm.v1.core.scheduler import SchedulerOutput + from vllm.v1.core.scheduler_output import SchedulerOutput class Worker: -- GitLab From 2c0f58203c111bcc331f931664400acfc94cb9bc Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Mon, 10 Feb 2025 18:24:29 -0800 Subject: [PATCH 061/253] [Docs] Annouce Meta Meetup (#13065) Signed-off-by: simon-mo --- README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/README.md b/README.md index f04acf09c..f22a1f9c5 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,10 @@ Easy, fast, and cheap LLM serving for everyone --- +We are excited to invite you to our Menlo Park meetup with Meta, evening of Thursday, February 27! 
Meta engineers will discuss the improvements on top of vLLM, and vLLM contributors will share updates from the v0.7.x series of releases. [Register Now](https://lu.ma/h7g3kuj9) + +--- + *Latest News* 🔥 - [2025/01] We are excited to announce the alpha release of vLLM V1: A major architectural upgrade with 1.7x speedup! Clean code, optimized execution loop, zero-overhead prefix caching, enhanced multimodal support, and more. Please check out our blog post [here](https://blog.vllm.ai/2025/01/27/v1-alpha-release.html). -- GitLab From cb080f32e38e87beda897d0602bf6a0d0c79d00f Mon Sep 17 00:00:00 2001 From: Florian Greinacher Date: Tue, 11 Feb 2025 04:33:33 +0100 Subject: [PATCH 062/253] [Bugfix] Support missing tool parameters in mistral tokenizer (#12884) Signed-off-by: Florian Greinacher --- tests/tokenization/test_mistral_tokenizer.py | 50 ++++++++++++++++ vllm/transformers_utils/tokenizers/mistral.py | 57 ++++++++++++------- 2 files changed, 88 insertions(+), 19 deletions(-) create mode 100644 tests/tokenization/test_mistral_tokenizer.py diff --git a/tests/tokenization/test_mistral_tokenizer.py b/tests/tokenization/test_mistral_tokenizer.py new file mode 100644 index 000000000..03e1f1fad --- /dev/null +++ b/tests/tokenization/test_mistral_tokenizer.py @@ -0,0 +1,50 @@ +# SPDX-License-Identifier: Apache-2.0 + +import pytest +from mistral_common.protocol.instruct.messages import UserMessage +from mistral_common.protocol.instruct.request import ChatCompletionRequest +from mistral_common.protocol.instruct.tool_calls import Function, Tool + +from vllm.transformers_utils.tokenizers.mistral import ( + make_mistral_chat_completion_request) + + +# yapf: enable +@pytest.mark.parametrize( + "openai_request,expected_mistral_request", + [( + { + "messages": [{ + "role": "user", + "content": "What is the current local date and time?", + }], + "tools": [{ + "type": "function", + "function": { + "description": "Fetch the current local date and time.", + "name": "get_current_time", + 
}, + }], + }, + ChatCompletionRequest( + messages=[ + UserMessage(content="What is the current local date and time?") + ], + tools=[ + Tool( + type="function", + function=Function( + name="get_current_time", + description="Fetch the current local date and time.", + parameters={}, + ), + ) + ], + ), + )], +) +def test_make_mistral_chat_completion_request(openai_request, + expected_mistral_request): + assert (make_mistral_chat_completion_request( + openai_request["messages"], + openai_request["tools"]) == expected_mistral_request) diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py index 8d96fcd27..f08923e74 100644 --- a/vllm/transformers_utils/tokenizers/mistral.py +++ b/vllm/transformers_utils/tokenizers/mistral.py @@ -104,6 +104,42 @@ def find_tokenizer_file(files: List[str]): return matched_files[0] +def make_mistral_chat_completion_request( + messages: List["ChatCompletionMessageParam"], + tools: Optional[List[Dict[str, + Any]]] = None) -> "ChatCompletionRequest": + last_message = cast(Dict[str, Any], messages[-1]) + if last_message["role"] == "assistant": + last_message["prefix"] = True + + last_message = cast(Dict[str, Any], messages[-1]) + if last_message["role"] == "assistant": + last_message["prefix"] = True + + # mistral-common requires AssistantMessage content to be string [1]. + # + # [1]: https://github.com/mistralai/mistral-common/blob/f4a06998b75ed78bbf5aaf569590b772ea26c9f6/src/mistral_common/protocol/instruct/messages.py#L80 + for message in messages: + if message.get("role") == "assistant": + content = message.get("content") + if isinstance(content, list): + content = "\n".join(chunk.get("text") for chunk in content) + message["content"] = content + + # The Mistral client, in comparison to the OpenAI client, requires the + # "parameters" dict to be present, even if it's empty. 
+ if tools: + for function in [ + tool["function"] for tool in tools + if tool["type"] == "function" + ]: + function.setdefault("parameters", {}) + + from mistral_common.protocol.instruct.request import ChatCompletionRequest + return ChatCompletionRequest(messages=messages, + tools=tools) # type: ignore[type-var] + + class MistralTokenizer: def __init__(self, tokenizer: "PublicMistralTokenizer") -> None: @@ -283,27 +319,10 @@ class MistralTokenizer: def apply_chat_template(self, messages: List["ChatCompletionMessageParam"], - tools: Optional[Dict[str, Any]] = None, + tools: Optional[List[Dict[str, Any]]] = None, **kwargs) -> List[int]: - last_message = cast(Dict[str, Any], messages[-1]) - if last_message["role"] == "assistant": - last_message["prefix"] = True - - from mistral_common.protocol.instruct.request import ( - ChatCompletionRequest) - - # mistral-common requires AssistantMessage content to be string [1]. - # - # [1]: https://github.com/mistralai/mistral-common/blob/f4a06998b75ed78bbf5aaf569590b772ea26c9f6/src/mistral_common/protocol/instruct/messages.py#L80 - for message in messages: - if message.get("role") == "assistant": - content = message.get("content") - if isinstance(content, list): - content = "\n".join(chunk.get("text") for chunk in content) - message["content"] = content - request = ChatCompletionRequest(messages=messages, - tools=tools) # type: ignore[type-var] + request = make_mistral_chat_completion_request(messages, tools) encoded = self.mistral.encode_chat_completion(request) # encode-decode to get clean prompt -- GitLab From 58047c6f0410fc7a86b64c88c092a246984b2342 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 10 Feb 2025 21:25:30 -0800 Subject: [PATCH 063/253] [Benchmark] Add BurstGPT to benchmark_serving (#13063) Signed-off-by: Woosuk Kwon Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com> --- benchmarks/README.md | 8 +++++++ benchmarks/benchmark_serving.py | 40 ++++++++++++++++++++++++++++++++- 2 files 
changed, 47 insertions(+), 1 deletion(-) diff --git a/benchmarks/README.md b/benchmarks/README.md index 890a2525b..367ef9345 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -19,3 +19,11 @@ mkdir coco -p wget http://images.cocodataset.org/zips/train2017.zip -O coco/train2017.zip unzip coco/train2017.zip -d coco/ ``` + +# Downloading the BurstGPT dataset + +You can download the BurstGPT v1.1 dataset by running: + +```bash +wget https://github.com/HPMLL/BurstGPT/releases/download/v1.1/BurstGPT_without_fails_2.csv +``` diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 1044bef59..0c8923842 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -38,6 +38,7 @@ from datetime import datetime from typing import Any, AsyncGenerator, Collection, Dict, List, Optional, Tuple import numpy as np +import pandas as pd from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput, RequestFuncOutput) from datasets import load_dataset @@ -131,6 +132,35 @@ def sample_sharegpt_requests( return filtered_dataset +def sample_burstgpt_requests( + dataset_path: str, + num_requests: int, + random_seed: int, + tokenizer: PreTrainedTokenizerBase, +) -> List[Tuple[str, int, int, None]]: + df = pd.read_csv(dataset_path) + gpt4_df = df[df["Model"] == "GPT-4"] + # Remove the failed requests (i.e., response length is 0) + gpt4_df = gpt4_df[gpt4_df["Response tokens"] > 0] + # Randomly sample num_requests from the dataset + if num_requests <= len(gpt4_df): + gpt4_df = gpt4_df.sample(n=num_requests, random_state=random_seed) + else: + gpt4_df = gpt4_df.sample(n=num_requests, + random_state=random_seed, + replace=True) + # Convert the dataframe to a list of tuples + dataset = gpt4_df.values.tolist() + input_requests = [] + for i in range(num_requests): + input_len = int(dataset[i][2]) + output_len = int(dataset[i][3]) + prompt = tokenizer.decode([(i + j) % tokenizer.vocab_size + for j in range(input_len)]) + 
input_requests.append((prompt, input_len, output_len, None)) + return input_requests + + def sample_sonnet_requests( dataset_path: str, num_requests: int, @@ -830,6 +860,14 @@ def main(args: argparse.Namespace): fixed_output_len=args.sharegpt_output_len, ) + elif args.dataset_name == "burstgpt": + input_requests = sample_burstgpt_requests( + dataset_path=args.dataset_path, + num_requests=args.num_prompts, + random_seed=args.seed, + tokenizer=tokenizer, + ) + elif args.dataset_name == "sonnet": # Do not format the prompt, pass to message directly if args.backend == "openai-chat": @@ -995,7 +1033,7 @@ if __name__ == "__main__": "--dataset-name", type=str, default="sharegpt", - choices=["sharegpt", "sonnet", "random", "hf"], + choices=["sharegpt", "burstgpt", "sonnet", "random", "hf"], help="Name of the dataset to benchmark on.", ) parser.add_argument("--dataset-path", -- GitLab From c320ca8edd5c4c19e7581703e428dd566b068756 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Tue, 11 Feb 2025 02:25:25 -0500 Subject: [PATCH 064/253] [Core] Don't do platform detection at import time (#12933) Signed-off-by: Russell Bryant --- vllm/executor/executor_base.py | 6 +++--- vllm/executor/ray_utils.py | 6 +++--- vllm/platforms/cuda.py | 4 ++-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index fb76276bb..242690f8e 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -8,11 +8,11 @@ from typing import (Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple, import torch.nn as nn from typing_extensions import TypeVar +import vllm.platforms from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.platforms import current_platform from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import ExecuteModelRequest, 
PoolerOutput from vllm.utils import make_async @@ -108,8 +108,8 @@ class ExecutorBase(ABC): """ # NOTE: This is logged in the executor because there can be >1 workers. logger.info("# %s blocks: %d, # CPU blocks: %d", - current_platform.dispatch_key, num_gpu_blocks, - num_cpu_blocks) + vllm.platforms.current_platform.dispatch_key, + num_gpu_blocks, num_cpu_blocks) max_concurrency = (num_gpu_blocks * self.cache_config.block_size / self.model_config.max_model_len) logger.info("Maximum concurrency for %s tokens per request: %.2fx", diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py index 7b3015597..33c0a2580 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -7,10 +7,10 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import msgspec +import vllm.platforms from vllm.config import ParallelConfig from vllm.executor.msgspec_utils import decode_hook, encode_hook from vllm.logger import init_logger -from vllm.platforms import current_platform from vllm.sequence import ExecuteModelRequest, IntermediateTensors from vllm.utils import get_ip from vllm.worker.worker_base import WorkerWrapperBase @@ -54,10 +54,10 @@ try: def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]: node_id = ray.get_runtime_context().get_node_id() - device_key = current_platform.ray_device_key + device_key = vllm.platforms.current_platform.ray_device_key if not device_key: raise RuntimeError("current platform %s does not support ray.", - current_platform.device_name) + vllm.platforms.current_platform.device_name) gpu_ids = ray.get_runtime_context().get_accelerator_ids( )[device_key] return node_id, gpu_ids diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 991d55ac8..9deb02946 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -334,10 +334,10 @@ class NvmlCudaPlatform(CudaPlatformBase): if (len(set(device_names)) > 1 and os.environ.get("CUDA_DEVICE_ORDER") != "PCI_BUS_ID"): logger.warning( - "Detected 
different devices in the system: \n%s\nPlease" + "Detected different devices in the system: %s. Please" " make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to " "avoid unexpected behavior.", - "\n".join(device_names), + ", ".join(device_names), ) -- GitLab From 78a141d768a18edc8c598a57d992e6aa56a33259 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Tue, 11 Feb 2025 12:56:03 +0530 Subject: [PATCH 065/253] [Misc] LoRA - Refactor Punica ops tests (#12970) Signed-off-by: Varun Sundar Rabindranath Co-authored-by: Varun Sundar Rabindranath --- tests/lora/test_punica_ops.py | 652 ++++++++++++++++++++++++ tests/lora/test_punica_ops_sizes.py | 401 --------------- tests/lora/test_punica_ops_variation.py | 317 ------------ tests/lora/utils.py | 41 +- 4 files changed, 686 insertions(+), 725 deletions(-) create mode 100644 tests/lora/test_punica_ops.py delete mode 100644 tests/lora/test_punica_ops_sizes.py delete mode 100644 tests/lora/test_punica_ops_variation.py diff --git a/tests/lora/test_punica_ops.py b/tests/lora/test_punica_ops.py new file mode 100644 index 000000000..032e20470 --- /dev/null +++ b/tests/lora/test_punica_ops.py @@ -0,0 +1,652 @@ +# SPDX-License-Identifier: Apache-2.0 +from threading import Lock +from typing import List + +import pytest +import torch + +import vllm.lora.ops.triton_ops # noqa: F401 +from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice, + bgmv_shrink, sgmv_expand, + sgmv_expand_slice, sgmv_shrink) +from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT +from vllm.platforms import current_platform + +from .utils import (PunicaTensors, assert_close, generate_data, + generate_data_for_expand_nslices, + generate_data_for_nslices) + + +# Utility shrink and expand operations used as reference implementations. 
+def sgmv_shrink_for_nslices( + nslices: int, inputs_tensor: torch.Tensor, + lora_weights_lst: List[torch.Tensor], out_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, seq_len_tensor: torch.Tensor, + prompt_lora_mapping: torch.Tensor, batches: int, max_seq_length: int, + num_tokens: int, scaling: float): + """ + Wrapper around sgmv_shrink that handles any nslices. + """ + for index in range(nslices): + sgmv_shrink( + inputs_tensor, + lora_weights_lst[index], + out_tensor[index], + b_seq_start_loc, + seq_len_tensor, + prompt_lora_mapping, + batches, + max_seq_length, + num_tokens, + scaling, + ) + + +def sgmv_expand_for_nslices(nslices: int, hidden_size: int, + inputs_tensor: torch.Tensor, + lora_weights_lst: List[torch.Tensor], + out_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + prompt_lora_mapping: torch.Tensor, batches: int, + max_seq_length: int, num_tokens: int, + add_inputs: bool) -> None: + """ + Wrapper around sgmv_expand that handles any nslices. + """ + if nslices == 1: + # Verify the torch's sgmv_expand op + sgmv_expand( + inputs_tensor[0], + lora_weights_lst[0], + out_tensor, + b_seq_start_loc, + seq_len_tensor, + prompt_lora_mapping, + batches, + max_seq_length, + num_tokens, + add_inputs=add_inputs, + ) + else: + slice_offset = 0 + for index in range(nslices): + lora_weights = lora_weights_lst[index] + sgmv_expand_slice( + inputs_tensor[index], + lora_weights, + out_tensor, + b_seq_start_loc, + seq_len_tensor, + prompt_lora_mapping, + batches, + max_seq_length, + num_tokens, + slice_offset, + hidden_size, + add_inputs=add_inputs, + ) + slice_offset += hidden_size + + +_dict_lock = Lock() + + +def check_sgmv_shrink(batches: int, num_loras: int, rank: int, + hidden_size: int, nslices: int, dtype: torch.dtype, + device: str, seq_length: int, scaling: float): + """ + Compare outputs of vllm.sgmv_shrink kernel against a reference + implementation. 
+ """ + data: PunicaTensors = generate_data_for_nslices( + batches, + hidden_size, + num_loras, + rank, + seq_length, + nslices, + dtype, + "shrink", + device, + ) + max_seq_length, token_nums = data.meta() + + # Preventing cache error pointer. + with _dict_lock: + _LORA_A_PTR_DICT.clear() + torch.ops.vllm.sgmv_shrink( + data.inputs_tensor, + data.lora_weights, + data.our_out_tensor, + data.b_seq_start_loc, + data.seq_len_tensor, + data.prompt_lora_mapping, + batches, + max_seq_length, + token_nums, + scaling, + ) + + sgmv_shrink_for_nslices( + nslices, + data.inputs_tensor, + data.lora_weights, + data.ref_out_tensor, + data.b_seq_start_loc, + data.seq_len_tensor, + data.prompt_lora_mapping, + batches, + max_seq_length, + token_nums, + scaling, + ) + assert_close(data.our_out_tensor, data.ref_out_tensor) + + +def check_sgmv_expand(batches: int, num_loras: int, rank: int, + hidden_size: int, nslices: int, dtype: torch.dtype, + device: str, seq_length: int, add_inputs: bool): + """ + Compare outputs of vllm.sgmv_expand kernel against a reference + implementation. 
+ """ + data: PunicaTensors = generate_data_for_nslices( + batches, + hidden_size, + num_loras, + rank, + seq_length, + nslices, + dtype, + "expand", + device, + ) + + max_seq_length, token_nums = data.meta() + + with _dict_lock: + _LORA_B_PTR_DICT.clear() + torch.ops.vllm.sgmv_expand( + data.inputs_tensor, + data.lora_weights, + data.our_out_tensor, + data.b_seq_start_loc, + data.seq_len_tensor, + data.prompt_lora_mapping, + batches, + max_seq_length, + token_nums, + offset_start=0, + add_inputs=add_inputs, + ) + + sgmv_expand_for_nslices(nslices, + hidden_size, + data.inputs_tensor, + data.lora_weights, + data.ref_out_tensor, + data.b_seq_start_loc, + data.seq_len_tensor, + data.prompt_lora_mapping, + batches, + max_seq_length, + token_nums, + add_inputs=add_inputs) + + assert_close(data.our_out_tensor, data.ref_out_tensor) + + +def check_bgmv_shrink(batches: int, num_loras: int, rank: int, + hidden_size: int, dtype: torch.dtype, device: str, + scaling: float): + """ + Compare vllm.bgmv_shrink against a reference implementation. + """ + seq_length = 1 + data: PunicaTensors = generate_data( + batches, + hidden_size, + num_loras, + rank, + seq_length, + dtype, + "shrink", + device, + ) + + torch.ops.vllm.bgmv_shrink( + data.inputs_tensor, + data.lora_weights, + data.our_out_tensor, + data.token_lora_mapping, + scaling, + ) + + bgmv_shrink( + data.inputs_tensor, + data.lora_weights, + data.ref_out_tensor, + data.token_lora_mapping, + scaling, + ) + + data.ref_out_tensor = data.ref_out_tensor.to(torch.float32) + assert_close(data.our_out_tensor, data.ref_out_tensor) + + +def check_bgmv_expand(batches: int, num_loras: int, rank: int, + hidden_size: int, dtype: torch.dtype, device: str, + add_inputs: bool): + """ + Compare vllm.bgmv_expand against a reference implementation. 
+ """ + seq_length = 1 + data: PunicaTensors = generate_data( + batches, + hidden_size, + num_loras, + rank, + seq_length, + dtype, + "expand", + device, + ) + + torch.ops.vllm.bgmv_expand( + data.inputs_tensor, + data.lora_weights, + data.our_out_tensor, + data.token_lora_mapping, + add_inputs=add_inputs, + ) + bgmv_expand( + data.inputs_tensor, + data.lora_weights, + data.ref_out_tensor, + data.token_lora_mapping, + add_inputs=add_inputs, + ) + assert_close(data.our_out_tensor, data.ref_out_tensor) + + +def check_bgmv_expand_slice(batches: int, num_loras: int, rank: int, + hidden_size: int, nslices: int, dtype: torch.dtype, + device: str, add_inputs: bool): + """ + Compare vllm.bgmv_expand_slice against a reference implementation. + """ + seq_length = 1 + data: PunicaTensors = generate_data_for_expand_nslices( + batches, + hidden_size, + num_loras, + rank, + seq_length, + dtype, + nslices, + device, + ) + + slice_offset = 0 + for index in range(nslices): + torch.ops.vllm.bgmv_expand_slice( + data.inputs_tensor, + data.lora_weights[index], + data.our_out_tensor, + data.token_lora_mapping, + slice_offset, + slice_size=hidden_size, + add_inputs=add_inputs, + ) + bgmv_expand_slice( + data.inputs_tensor, + data.lora_weights[index], + data.ref_out_tensor, + data.token_lora_mapping, + slice_offset, + slice_size=hidden_size, + add_inputs=add_inputs, + ) + + slice_offset += hidden_size + assert_close(data.our_out_tensor, data.ref_out_tensor) + + +# Tests +# We test the punica kernels along 2 verticals mainly. +# 1. Variations in hidden_dim size +# 2. Variations in all other parameters like (batch_size, max_rank, num_loras +# etc.) + +# We have collected the hidden_sizes included in the LoRA models +# currently supported by vLLM. It tests whether the corresponding Triton +# kernel can run normally when tensor parallelism is set to +# [1, 2, 4, 8, 16, 32, 64]. 
+HIDDEN_SIZES = [ + 128, + 256, + 512, + 896, + 1024, + 1152, + 1216, + 1280, + 1536, + 1664, + 2048, + 2240, + 2304, + 2368, + 2432, + 2560, + 2752, + 3072, + 3328, + 3456, + 3584, + 3712, + 4096, + 4480, + 4608, + 4736, + 4864, + 5120, + 5504, + 5632, + 5888, + 6144, + 6400, + 6848, + 6912, + 7168, + 7424, + 8192, + 8960, + 9216, + 9472, + 10240, + 11008, + 11264, + 13824, + 14336, + 14784, + 14848, + 15360, + 18944, + 22016, + 22528, + 24576, + 27392, + 27648, + 29568, + 29696, + 32000, + 32256, + 32512, + 32768, + 33024, + 36864, + 43264, + 49152, + 49408, + 60544, + 60672, + 64000, + 64256, + 102400, + 102656, + 128000, + 128256, +] +#The size of TP +divisibility = [1, 2, 8, 16, 64] + +all_hidden_size = [] +for div in divisibility: + for hidden_size in HIDDEN_SIZES: + all_hidden_size.append(hidden_size // div) + +HIDDEN_SIZES = list(set(all_hidden_size)) + +# Test params that focuses on hidden_size variation. +hs_test_params = { + "hidden_sizes": HIDDEN_SIZES, + "batches": [4], + "num_loras": [4], + "max_ranks": [32], +} + +# General tests params that tests for variations in all dimensions +# except hidden_size. 
+test_params = { + "hidden_sizes": [2049], + "batches": [1, 4, 16, 32], + "num_loras": [1, 8, 32, 128], + "max_ranks": [1, 4, 8, 16, 32, 64, 128, 256], +} + +DTYPES = [torch.float16, torch.bfloat16] +DEVICES = [f"cuda:{0}"] +SEED = [0] + + +@pytest.mark.parametrize("batches", test_params['batches']) +@pytest.mark.parametrize("num_loras", test_params['num_loras']) +@pytest.mark.parametrize("rank", test_params['max_ranks']) +@pytest.mark.parametrize("hidden_size", test_params['hidden_sizes']) +@pytest.mark.parametrize("nslices", [1, 2, 3]) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("device", DEVICES) +@pytest.mark.parametrize("seed", SEED) +@pytest.mark.parametrize("op_type", ["shrink", "expand"]) +def test_punica_sgmv( + batches: int, + num_loras: int, + rank: int, + hidden_size: int, + nslices: int, + dtype: torch.dtype, + device: str, + seed: int, + op_type: str, +): + torch.set_default_device(device) + current_platform.seed_everything(seed) + + if op_type == "shrink": + check_sgmv_shrink(batches=batches, + num_loras=num_loras, + rank=rank, + hidden_size=hidden_size, + nslices=nslices, + dtype=dtype, + device=device, + seq_length=128, + scaling=0.5) + else: + check_sgmv_expand(batches=batches, + num_loras=num_loras, + rank=rank, + hidden_size=hidden_size, + nslices=nslices, + dtype=dtype, + device=device, + seq_length=128, + add_inputs=True) + + +@pytest.mark.parametrize("batches", hs_test_params['batches']) +@pytest.mark.parametrize("num_loras", hs_test_params['num_loras']) +@pytest.mark.parametrize("rank", hs_test_params['max_ranks']) +@pytest.mark.parametrize("hidden_size", hs_test_params['hidden_sizes']) +@pytest.mark.parametrize("nslices", [1, 2, 3]) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("device", DEVICES) +@pytest.mark.parametrize("seed", SEED) +@pytest.mark.parametrize("op_type", ["shrink", "expand"]) +def test_punica_sgmv_hidden_size( + batches: int, + num_loras: int, + rank: int, + hidden_size: int, 
+ nslices: int, + dtype: torch.dtype, + device: str, + seed: int, + op_type: str, +): + torch.set_default_device(device) + current_platform.seed_everything(seed) + + if op_type == "shrink": + check_sgmv_shrink(batches=batches, + num_loras=num_loras, + rank=rank, + hidden_size=hidden_size, + nslices=nslices, + dtype=dtype, + device=device, + seq_length=128, + scaling=0.5) + else: + check_sgmv_expand(batches=batches, + num_loras=num_loras, + rank=rank, + hidden_size=hidden_size, + nslices=nslices, + dtype=dtype, + device=device, + seq_length=128, + add_inputs=True) + + +@pytest.mark.parametrize("batches", test_params['batches']) +@pytest.mark.parametrize("num_loras", test_params['num_loras']) +@pytest.mark.parametrize("rank", test_params['max_ranks']) +@pytest.mark.parametrize("hidden_size", test_params['hidden_sizes']) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("device", DEVICES) +@pytest.mark.parametrize("seed", SEED) +@pytest.mark.parametrize("op_type", ["shrink", "expand"]) +def test_punica_bgmv( + batches: int, + num_loras: int, + rank: int, + hidden_size: int, + dtype: torch.dtype, + device: str, + seed: int, + op_type: str, +): + torch.set_default_device(device) + current_platform.seed_everything(seed) + + if op_type == "shrink": + check_bgmv_shrink(batches=batches, + num_loras=num_loras, + rank=rank, + hidden_size=hidden_size, + dtype=dtype, + device=device, + scaling=0.5) + else: + check_bgmv_expand(batches=batches, + num_loras=num_loras, + rank=rank, + hidden_size=hidden_size, + dtype=dtype, + device=device, + add_inputs=True) + + +@pytest.mark.parametrize("batches", hs_test_params['batches']) +@pytest.mark.parametrize("num_loras", hs_test_params['num_loras']) +@pytest.mark.parametrize("rank", hs_test_params['max_ranks']) +@pytest.mark.parametrize("hidden_size", hs_test_params['hidden_sizes']) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("device", DEVICES) +@pytest.mark.parametrize("seed", SEED) 
+@pytest.mark.parametrize("op_type", ["shrink", "expand"]) +def test_punica_bgmv_hidden_size( + batches: int, + num_loras: int, + rank: int, + hidden_size: int, + dtype: torch.dtype, + device: str, + seed: int, + op_type: str, +): + torch.set_default_device(device) + current_platform.seed_everything(seed) + + if op_type == "shrink": + check_bgmv_shrink(batches=batches, + num_loras=num_loras, + rank=rank, + hidden_size=hidden_size, + dtype=dtype, + device=device, + scaling=0.5) + else: + check_bgmv_expand(batches=batches, + num_loras=num_loras, + rank=rank, + hidden_size=hidden_size, + dtype=dtype, + device=device, + add_inputs=True) + + +@pytest.mark.parametrize("batches", test_params['batches']) +@pytest.mark.parametrize("num_loras", test_params['num_loras']) +@pytest.mark.parametrize("rank", test_params['max_ranks']) +@pytest.mark.parametrize("hidden_size", test_params['hidden_sizes']) +@pytest.mark.parametrize("nslices", [2, 3]) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("device", DEVICES) +@pytest.mark.parametrize("seed", SEED) +def test_punica_bgmv_expand_nslices(batches: int, num_loras: int, rank: int, + hidden_size: int, nslices: int, + dtype: torch.dtype, device: str, + seed: int): + + torch.set_default_device(device) + current_platform.seed_everything(seed) + + check_bgmv_expand_slice(batches=batches, + num_loras=num_loras, + rank=rank, + hidden_size=hidden_size, + nslices=nslices, + dtype=dtype, + device=device, + add_inputs=True) + + +@pytest.mark.parametrize("batches", hs_test_params['batches']) +@pytest.mark.parametrize("num_loras", hs_test_params['num_loras']) +@pytest.mark.parametrize("rank", hs_test_params['max_ranks']) +@pytest.mark.parametrize("hidden_size", hs_test_params['hidden_sizes']) +@pytest.mark.parametrize("nslices", [2, 3]) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("device", DEVICES) +@pytest.mark.parametrize("seed", SEED) +def test_punica_bgmv_expand_nslices_hidden_size(batches: int, 
num_loras: int, + rank: int, hidden_size: int, + nslices: int, + dtype: torch.dtype, + device: str, seed: int): + + torch.set_default_device(device) + current_platform.seed_everything(seed) + + check_bgmv_expand_slice(batches=batches, + num_loras=num_loras, + rank=rank, + hidden_size=hidden_size, + nslices=nslices, + dtype=dtype, + device=device, + add_inputs=True) diff --git a/tests/lora/test_punica_ops_sizes.py b/tests/lora/test_punica_ops_sizes.py deleted file mode 100644 index ecd3bc497..000000000 --- a/tests/lora/test_punica_ops_sizes.py +++ /dev/null @@ -1,401 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -""" -This script is mainly used to tests various hidden_sizes. We have collected the -hidden_sizes included in the LoRA models currently supported by vLLM. It tests -whether the corresponding Triton kernel can run normally when tensor parallelism -is set to [1, 2, 4, 8, 16, 32, 64]. -""" -from threading import Lock - -import pytest -import torch - -import vllm.lora.ops.triton_ops # noqa: F401 -from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice, - bgmv_shrink, sgmv_expand, - sgmv_expand_slice, sgmv_shrink) -from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT -from vllm.platforms import current_platform - -from .utils import (assert_close, generate_data, - generate_data_for_expand_nslices, - generate_data_for_nslices) - -HIDDEN_SIZES = [ - 128, - 256, - 512, - 896, - 1024, - 1152, - 1216, - 1280, - 1536, - 1664, - 2048, - 2240, - 2304, - 2368, - 2432, - 2560, - 2752, - 3072, - 3328, - 3456, - 3584, - 3712, - 4096, - 4480, - 4608, - 4736, - 4864, - 5120, - 5504, - 5632, - 5888, - 6144, - 6400, - 6848, - 6912, - 7168, - 7424, - 8192, - 8960, - 9216, - 9472, - 10240, - 11008, - 11264, - 13824, - 14336, - 14784, - 14848, - 15360, - 18944, - 22016, - 22528, - 24576, - 27392, - 27648, - 29568, - 29696, - 32000, - 32256, - 32512, - 32768, - 33024, - 36864, - 43264, - 49152, - 49408, - 60544, - 60672, - 64000, - 64256, - 
102400, - 102656, - 128000, - 128256, -] -#The size of TP -divisibility = [1, 2, 8, 16, 64] - -all_hidden_size = [] -for div in divisibility: - for hidden_size in HIDDEN_SIZES: - all_hidden_size.append(hidden_size // div) - -HIDDEN_SIZES = list(set(all_hidden_size)) - -BATCHES = [4] -NUM_LORA = [4] -DTYPES = [torch.float16, torch.bfloat16] -MAX_RANKS = [32] -SCALES = [0.5] -SEED = [0] -DEVICES = [f"cuda:{0}"] - -_dict_lock = Lock() - - -@pytest.mark.parametrize("batches", BATCHES) -@pytest.mark.parametrize("num_loras", NUM_LORA) -@pytest.mark.parametrize("rank", MAX_RANKS) -@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) -@pytest.mark.parametrize("scaling", SCALES) -@pytest.mark.parametrize("nslices", [1, 2, 3]) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("op_type", ["shrink", "expand"]) -@pytest.mark.parametrize("seed", SEED) -@pytest.mark.parametrize("device", DEVICES) -def test_punica_sgmv( - batches: int, - num_loras: int, - rank: int, - hidden_size: int, - scaling: float, - nslices: int, - dtype: torch.dtype, - op_type: str, - seed: int, - device: str, -): - torch.set_default_device(device) - current_platform.seed_everything(seed) - - seq_length = 128 - ( - inputs_tensor, - lora_weights_lst, - our_out_tensor, - ref_out_tensor, - b_seq_start_loc, - lora_indices_tensor, - seq_len_tensor, - indices, - ) = generate_data_for_nslices( - batches, - hidden_size, - num_loras, - rank, - seq_length, - nslices, - dtype, - op_type, - device, - ) - max_seq_length = seq_len_tensor.max() - token_nums = seq_len_tensor.sum().item() - if isinstance(max_seq_length, tuple): - max_seq_length = max_seq_length[0].item() - else: - max_seq_length = max_seq_length.item() - if op_type == "shrink": - # Preventing cache error pointer. 
- with _dict_lock: - _LORA_A_PTR_DICT.clear() - torch.ops.vllm.sgmv_shrink( - inputs_tensor, - lora_weights_lst, - our_out_tensor, - b_seq_start_loc, - seq_len_tensor, - lora_indices_tensor, - batches, - max_seq_length, - token_nums, - scaling, - ) - for index in range(nslices): - sgmv_shrink( - inputs_tensor, - lora_weights_lst[index], - ref_out_tensor[index], - b_seq_start_loc, - seq_len_tensor, - lora_indices_tensor, - batches, - max_seq_length, - token_nums, - scaling, - ) - - else: - with _dict_lock: - _LORA_B_PTR_DICT.clear() - torch.ops.vllm.sgmv_expand( - inputs_tensor, - lora_weights_lst, - our_out_tensor, - b_seq_start_loc, - seq_len_tensor, - lora_indices_tensor, - batches, - max_seq_length, - token_nums, - offset_start=0, - add_inputs=True, - ) - if nslices == 1: - # Verify the torch's sgmv_expand op - sgmv_expand( - inputs_tensor[0], - lora_weights_lst[0], - ref_out_tensor, - b_seq_start_loc, - seq_len_tensor, - lora_indices_tensor, - batches, - max_seq_length, - token_nums, - add_inputs=True, - ) - else: - slice_offset = 0 - for index in range(nslices): - lora_weights = lora_weights_lst[index] - sgmv_expand_slice( - inputs_tensor[index], - lora_weights, - ref_out_tensor, - b_seq_start_loc, - seq_len_tensor, - lora_indices_tensor, - batches, - max_seq_length, - token_nums, - slice_offset, - hidden_size, - add_inputs=True, - ) - slice_offset += hidden_size - - assert_close(our_out_tensor, ref_out_tensor) - - -@pytest.mark.parametrize("batches", BATCHES) -@pytest.mark.parametrize("num_loras", NUM_LORA) -@pytest.mark.parametrize("rank", MAX_RANKS) -@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) -@pytest.mark.parametrize("scaling", SCALES) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("op_type", ["shrink", "expand"]) -@pytest.mark.parametrize("seed", SEED) -@pytest.mark.parametrize("device", DEVICES) -def test_punica_bgmv( - batches: int, - num_loras: int, - rank: int, - hidden_size: int, - scaling: float, - dtype: 
torch.dtype, - op_type: str, - seed: int, - device: str, -): - torch.set_default_device(device) - current_platform.seed_everything(seed) - - seq_length = 1 - ( - inputs_tensor, - lora_weights, - our_out_tensor, - ref_out_tensor, - b_seq_start_loc, - lora_indices_tensor, - seq_len_tensor, - indices, - ) = generate_data( - batches, - hidden_size, - num_loras, - rank, - seq_length, - dtype, - op_type, - device, - ) - if op_type == "shrink": - torch.ops.vllm.bgmv_shrink( - inputs_tensor, - lora_weights, - our_out_tensor, - indices, - scaling, - ) - - bgmv_shrink( - inputs_tensor, - lora_weights, - ref_out_tensor, - indices, - scaling, - ) - - else: - torch.ops.vllm.bgmv_expand( - inputs_tensor, - lora_weights, - our_out_tensor, - indices, - add_inputs=True, - ) - bgmv_expand( - inputs_tensor, - lora_weights, - ref_out_tensor, - indices, - add_inputs=True, - ) - - if op_type == "shrink": - ref_out_tensor = ref_out_tensor.to(torch.float32) - assert_close(our_out_tensor, ref_out_tensor) - - -@pytest.mark.parametrize("batches", BATCHES) -@pytest.mark.parametrize("num_loras", NUM_LORA) -@pytest.mark.parametrize("rank", MAX_RANKS) -@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) -@pytest.mark.parametrize("nslices", [2, 3]) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("seed", SEED) -@pytest.mark.parametrize("device", DEVICES) -def test_punica_bgmv_expand_nslices( - batches: int, - num_loras: int, - rank: int, - hidden_size: int, - nslices: int, - dtype: torch.dtype, - seed: int, - device: str, -): - torch.set_default_device(device) - current_platform.seed_everything(seed) - - seq_length = 1 - ( - inputs_tensor, - lora_weights_lst, - our_outputs, - ref_outputs, - b_seq_start_loc, - lora_indices_tensor, - seq_len_tensor, - indices, - ) = generate_data_for_expand_nslices( - batches, - hidden_size, - num_loras, - rank, - seq_length, - dtype, - nslices, - device, - ) - slice_offset = 0 - for index in range(nslices): - lora_weights = 
lora_weights_lst[index] - torch.ops.vllm.bgmv_expand_slice( - inputs_tensor, - lora_weights, - our_outputs, - indices, - slice_offset, - slice_size=hidden_size, - add_inputs=True, - ) - bgmv_expand_slice( - inputs_tensor, - lora_weights, - ref_outputs, - indices, - slice_offset, - slice_size=hidden_size, - add_inputs=True, - ) - - slice_offset += hidden_size - assert_close(our_outputs, ref_outputs) diff --git a/tests/lora/test_punica_ops_variation.py b/tests/lora/test_punica_ops_variation.py deleted file mode 100644 index 6d1d3c943..000000000 --- a/tests/lora/test_punica_ops_variation.py +++ /dev/null @@ -1,317 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -""" -This script is mainly used to test whether trtion kernels can run normally -under different conditions, including various batches, numbers of LoRA , and -maximum ranks. -""" -from threading import Lock - -import pytest -import torch - -# Enable custom op register -import vllm.lora.ops.triton_ops # noqa: F401 -from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice, - bgmv_shrink, sgmv_expand, - sgmv_expand_slice, sgmv_shrink) -from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT -from vllm.platforms import current_platform - -from .utils import (assert_close, generate_data, - generate_data_for_expand_nslices, - generate_data_for_nslices) - -HIDDEN_SIZES = [2049] - -BATCHES = [1, 4, 16, 32] -NUM_LORA = [1, 8, 32, 128] -DTYPES = [torch.float16, torch.bfloat16] -MAX_RANKS = [1, 4, 8, 16, 32, 64, 128, 256] -SCALES = [0.5] -SEED = [0] -DEVICES = [f"cuda:{0}"] - -_dict_lock = Lock() - - -@pytest.mark.parametrize("batches", BATCHES) -@pytest.mark.parametrize("num_loras", NUM_LORA) -@pytest.mark.parametrize("rank", MAX_RANKS) -@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) -@pytest.mark.parametrize("scaling", SCALES) -@pytest.mark.parametrize("nslices", [1, 2, 3]) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("op_type", ["shrink", "expand"]) 
-@pytest.mark.parametrize("seed", SEED) -@pytest.mark.parametrize("device", DEVICES) -def test_punica_sgmv( - batches: int, - num_loras: int, - rank: int, - hidden_size: int, - scaling: float, - nslices: int, - dtype: torch.dtype, - op_type: str, - seed: int, - device: str, -): - torch.set_default_device(device) - current_platform.seed_everything(seed) - - seq_length = 128 - ( - inputs_tensor, - lora_weights_lst, - our_out_tensor, - ref_out_tensor, - b_seq_start_loc, - lora_indices_tensor, - seq_len_tensor, - indices, - ) = generate_data_for_nslices( - batches, - hidden_size, - num_loras, - rank, - seq_length, - nslices, - dtype, - op_type, - device, - ) - max_seq_length = seq_len_tensor.max() - token_nums = seq_len_tensor.sum().item() - if isinstance(max_seq_length, tuple): - max_seq_length = max_seq_length[0].item() - else: - max_seq_length = max_seq_length.item() - if op_type == "shrink": - # Preventing cache error pointer. - with _dict_lock: - _LORA_A_PTR_DICT.clear() - torch.ops.vllm.sgmv_shrink( - inputs_tensor, - lora_weights_lst, - our_out_tensor, - b_seq_start_loc, - seq_len_tensor, - lora_indices_tensor, - batches, - max_seq_length, - token_nums, - scaling, - ) - for index in range(nslices): - sgmv_shrink( - inputs_tensor, - lora_weights_lst[index], - ref_out_tensor[index], - b_seq_start_loc, - seq_len_tensor, - lora_indices_tensor, - batches, - max_seq_length, - token_nums, - scaling, - ) - - else: - with _dict_lock: - _LORA_B_PTR_DICT.clear() - torch.ops.vllm.sgmv_expand( - inputs_tensor, - lora_weights_lst, - our_out_tensor, - b_seq_start_loc, - seq_len_tensor, - lora_indices_tensor, - batches, - max_seq_length, - token_nums, - offset_start=0, - add_inputs=True, - ) - slice_offset = 0 - if nslices == 1: - # Verify the torch's sgmv_expand op - sgmv_expand( - inputs_tensor[0], - lora_weights_lst[0], - ref_out_tensor, - b_seq_start_loc, - seq_len_tensor, - lora_indices_tensor, - batches, - max_seq_length, - token_nums, - add_inputs=True, - ) - else: - for 
index in range(nslices): - lora_weights = lora_weights_lst[index] - sgmv_expand_slice( - inputs_tensor[index], - lora_weights, - ref_out_tensor, - b_seq_start_loc, - seq_len_tensor, - lora_indices_tensor, - batches, - max_seq_length, - token_nums, - slice_offset, - hidden_size, - add_inputs=True, - ) - slice_offset += hidden_size - - assert_close(our_out_tensor, ref_out_tensor) - - -@pytest.mark.parametrize("batches", BATCHES) -@pytest.mark.parametrize("num_loras", NUM_LORA) -@pytest.mark.parametrize("rank", MAX_RANKS) -@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) -@pytest.mark.parametrize("scaling", SCALES) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("op_type", ["shrink", "expand"]) -@pytest.mark.parametrize("seed", SEED) -@pytest.mark.parametrize("device", DEVICES) -def test_punica_bgmv( - batches: int, - num_loras: int, - rank: int, - hidden_size: int, - scaling: float, - dtype: torch.dtype, - op_type: str, - seed: int, - device: str, -): - torch.set_default_device(device) - current_platform.seed_everything(seed) - - seq_length = 1 - ( - inputs_tensor, - lora_weights, - our_out_tensor, - ref_out_tensor, - b_seq_start_loc, - lora_indices_tensor, - seq_len_tensor, - indices, - ) = generate_data( - batches, - hidden_size, - num_loras, - rank, - seq_length, - dtype, - op_type, - device, - ) - if op_type == "shrink": - torch.ops.vllm.bgmv_shrink( - inputs_tensor, - lora_weights, - our_out_tensor, - indices, - scaling, - ) - - bgmv_shrink( - inputs_tensor, - lora_weights, - ref_out_tensor, - indices, - scaling, - ) - - else: - torch.ops.vllm.bgmv_expand( - inputs_tensor, - lora_weights, - our_out_tensor, - indices, - add_inputs=True, - ) - bgmv_expand( - inputs_tensor, - lora_weights, - ref_out_tensor, - indices, - add_inputs=True, - ) - - if op_type == "shrink": - ref_out_tensor = ref_out_tensor.to(torch.float32) - assert_close(our_out_tensor, ref_out_tensor) - - -@pytest.mark.parametrize("batches", BATCHES) 
-@pytest.mark.parametrize("num_loras", NUM_LORA) -@pytest.mark.parametrize("rank", MAX_RANKS) -@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) -@pytest.mark.parametrize("nslices", [2, 3]) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("seed", SEED) -@pytest.mark.parametrize("device", DEVICES) -def test_punica_bgmv_expand_nslices( - batches: int, - num_loras: int, - rank: int, - hidden_size: int, - nslices: int, - dtype: torch.dtype, - seed: int, - device: str, -): - torch.set_default_device(device) - current_platform.seed_everything(seed) - - seq_length = 1 - ( - inputs_tensor, - lora_weights_lst, - our_outputs, - ref_outputs, - b_seq_start_loc, - lora_indices_tensor, - seq_len_tensor, - indices, - ) = generate_data_for_expand_nslices( - batches, - hidden_size, - num_loras, - rank, - seq_length, - dtype, - nslices, - device, - ) - slice_offset = 0 - for index in range(nslices): - lora_weights = lora_weights_lst[index] - torch.ops.vllm.bgmv_expand_slice( - inputs_tensor, - lora_weights, - our_outputs, - indices, - slice_offset, - slice_size=hidden_size, - add_inputs=True, - ) - bgmv_expand_slice( - inputs_tensor, - lora_weights, - ref_outputs, - indices, - slice_offset, - slice_size=hidden_size, - add_inputs=True, - ) - - slice_offset += hidden_size - assert_close(our_outputs, ref_outputs) diff --git a/tests/lora/utils.py b/tests/lora/utils.py index bda00e081..1e163fbf9 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Dict, List, Optional +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple, Union import torch @@ -106,6 +107,31 @@ def assert_close(a, b): torch.testing.assert_close(a, b, rtol=rtol, atol=atol) +@dataclass +class PunicaTensors: + inputs_tensor: torch.Tensor + lora_weights: Union[torch.Tensor, List[torch.Tensor]] + our_out_tensor: torch.Tensor + ref_out_tensor: torch.Tensor + b_seq_start_loc: torch.Tensor + 
prompt_lora_mapping: torch.Tensor + seq_len_tensor: torch.Tensor + token_lora_mapping: torch.Tensor + + def meta(self) -> Tuple[int, int]: + """ + Infer max_seq_length and token_nums from the tensors + and return them. + """ + max_seq_length = self.seq_len_tensor.max() + token_nums = self.seq_len_tensor.sum().item() + if isinstance(max_seq_length, tuple): + max_seq_length = max_seq_length[0].item() + else: + max_seq_length = max_seq_length.item() + return max_seq_length, token_nums + + def generate_data( batches, hidden_size, @@ -115,7 +141,7 @@ def generate_data( dtype, op_type, device, -): +) -> PunicaTensors: seq_len_tensor = torch.randint(seq_length, seq_length + 1, (batches, )).to(device) b_seq_start_loc = torch.cumsum( @@ -164,7 +190,8 @@ def generate_data( indices[current_offset:current_offset + seq_len_tensor[b_id]].copy_(lora_index) current_offset += seq_len_tensor[b_id].item() - return ( + + return PunicaTensors( inputs_tensor, lora_weights, our_out_tensor, @@ -185,7 +212,7 @@ def generate_data_for_expand_nslices( dtype, nslices, device, -): +) -> PunicaTensors: seq_len_tensor = torch.randint(seq_length, seq_length + 1, (batches, )).to(device) b_seq_start_loc = torch.cumsum( @@ -222,7 +249,7 @@ def generate_data_for_expand_nslices( current_offset += seq_len_tensor[b_id].item() lora_indices_tensor = lora_indices_tensor.to(device) - return ( + return PunicaTensors( inputs_tensor, lora_weights_lst, our_out_tensor, @@ -244,7 +271,7 @@ def generate_data_for_nslices( dtype, op_type, device, -): +) -> PunicaTensors: seq_len_tensor = torch.randint(seq_length, seq_length + 1, (batches, )).to(device) b_seq_start_loc = torch.cumsum( @@ -302,7 +329,7 @@ def generate_data_for_nslices( current_offset += seq_len_tensor[b_id].item() lora_indices_tensor = lora_indices_tensor.to(device) - return ( + return PunicaTensors( inputs_tensor, lora_weights_lst, our_out_tensor, -- GitLab From fc6485d27750076642e99a1ef2df0e6375958bb4 Mon Sep 17 00:00:00 2001 From: Ce Gao Date: Tue, 
11 Feb 2025 15:49:03 +0800 Subject: [PATCH 066/253] [Bugfix]: Reasoning output bug according to the chat template change (#13025) Signed-off-by: Ce Gao --- .../openai_chat_completion_with_reasoning.py | 8 +- .../test_deepseekr1_reasoning_parser.py | 108 +++++++++++++++--- .../deepseek_r1_reasoning_parser.py | 58 ++++++---- 3 files changed, 129 insertions(+), 45 deletions(-) diff --git a/examples/online_serving/openai_chat_completion_with_reasoning.py b/examples/online_serving/openai_chat_completion_with_reasoning.py index a88c8adb5..b5dbed120 100644 --- a/examples/online_serving/openai_chat_completion_with_reasoning.py +++ b/examples/online_serving/openai_chat_completion_with_reasoning.py @@ -36,8 +36,8 @@ response = client.chat.completions.create(model=model, messages=messages) reasoning_content = response.choices[0].message.reasoning_content content = response.choices[0].message.content -print("reasoning_content:", reasoning_content) -print("content:", content) +print("reasoning_content for Round 1:", reasoning_content) +print("content for Round 1:", content) # Round 2 messages.append({"role": "assistant", "content": content}) @@ -50,5 +50,5 @@ response = client.chat.completions.create(model=model, messages=messages) reasoning_content = response.choices[0].message.reasoning_content content = response.choices[0].message.content -print("reasoning_content:", reasoning_content) -print("content:", content) +print("reasoning_content for Round 2:", reasoning_content) +print("content for Round 2:", content) diff --git a/tests/entrypoints/openai/reasoning_parsers/test_deepseekr1_reasoning_parser.py b/tests/entrypoints/openai/reasoning_parsers/test_deepseekr1_reasoning_parser.py index f7b81be48..fdadb2e21 100644 --- a/tests/entrypoints/openai/reasoning_parsers/test_deepseekr1_reasoning_parser.py +++ b/tests/entrypoints/openai/reasoning_parsers/test_deepseekr1_reasoning_parser.py @@ -15,32 +15,62 @@ start_token = "" end_token = "" SIMPLE_REASONING = { - "output": "This is a 
reasoning sectionThis is the rest", + "output": "This is a reasoning sectionThis is the rest", "reasoning_content": "This is a reasoning section", "content": "This is the rest", } COMPLETE_REASONING = { - "output": "This is a reasoning section", + "output": "This is a reasoning section", "reasoning_content": "This is a reasoning section", "content": None, } NO_REASONING = { - "output": "This is a reasoning section", + "output": "This is content", "reasoning_content": None, - "content": "This is a reasoning section", + "content": "This is content", +} +NO_REASONING_STREAMING = { + "output": "This is a reasoning section", + "reasoning_content": "This is a reasoning section", + "content": None, } MULTIPLE_LINES = { - "output": "This\nThatThis is the rest\nThat", + "output": "This\nThatThis is the rest\nThat", "reasoning_content": "This\nThat", "content": "This is the rest\nThat", } SHORTEST_REASONING_NO_STREAMING = { - "output": "This is the rest", + "output": "This is the rest", "reasoning_content": "", "content": "This is the rest", } SHORTEST_REASONING = { - "output": "This is the rest", + "output": "This is the rest", + "reasoning_content": None, + "content": "This is the rest", +} +REASONING_WITH_THINK = { + "output": "This is a reasoning sectionThis is the rest", + "reasoning_content": "This is a reasoning section", + "content": "This is the rest", +} +COMPLETE_REASONING_WITH_THINK = { + "output": "This is a reasoning section", + "reasoning_content": "This is a reasoning section", + "content": None, +} +MULTIPLE_LINES_WITH_THINK = { + "output": "This\nThatThis is the rest\nThat", + "reasoning_content": "This\nThat", + "content": "This is the rest\nThat", +} +SHORTEST_REASONING_NO_STREAMING_WITH_THINK = { + "output": "This is the rest", + "reasoning_content": "", + "content": "This is the rest", +} +SHORTEST_REASONING_WITH_THINK = { + "output": "This is the rest", "reasoning_content": None, "content": "This is the rest", } @@ -49,37 +79,37 @@ TEST_CASES = [ 
pytest.param( False, SIMPLE_REASONING, - id="simple_streaming", + id="simple_reasoning", ), pytest.param( True, SIMPLE_REASONING, - id="simple_streaming", + id="simple_reasoning_streaming", ), pytest.param( False, COMPLETE_REASONING, - id="complete_streaming", + id="complete_reasoning", ), pytest.param( True, COMPLETE_REASONING, - id="complete_streaming", + id="complete_reasoning_streaming", ), pytest.param( False, NO_REASONING, - id="no_streaming", + id="no_reasoning_token", ), pytest.param( True, - NO_REASONING, - id="no_streaming", + NO_REASONING_STREAMING, + id="no_reasoning_token_streaming", ), pytest.param( False, MULTIPLE_LINES, - id="multiple_lines_streaming", + id="multiple_lines", ), pytest.param( True, @@ -89,23 +119,65 @@ TEST_CASES = [ pytest.param( True, SHORTEST_REASONING, - id="shortest_streaming", + id="shortest", ), pytest.param( False, SHORTEST_REASONING_NO_STREAMING, id="shortest_streaming", ), + pytest.param( + False, + REASONING_WITH_THINK, + id="reasoning_with_think", + ), + pytest.param( + True, + REASONING_WITH_THINK, + id="reasoning_with_think_streaming", + ), + pytest.param( + False, + COMPLETE_REASONING_WITH_THINK, + id="complete_reasoning_with_think", + ), + pytest.param( + True, + COMPLETE_REASONING_WITH_THINK, + id="complete_reasoning_with_think_streaming", + ), + pytest.param( + False, + MULTIPLE_LINES_WITH_THINK, + id="multiple_lines_with_think", + ), + pytest.param( + True, + MULTIPLE_LINES_WITH_THINK, + id="multiple_lines_with_think_streaming", + ), + pytest.param( + False, + SHORTEST_REASONING_NO_STREAMING_WITH_THINK, + id="shortest_with_think", + ), + pytest.param( + True, + SHORTEST_REASONING_WITH_THINK, + id="shortest_with_think_streaming", + ), ] +# Global tokenizer initialization to avoid repeated loading +tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m") +tokenizer.add_tokens([start_token, end_token]) + @pytest.mark.parametrize("streaming, param_dict", TEST_CASES) def test_reasoning( streaming: bool, 
param_dict: dict, ): - tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m") - tokenizer.add_tokens([start_token, end_token]) output = tokenizer.tokenize(param_dict["output"]) # decode everything to tokens output_tokens: List[str] = [ diff --git a/vllm/entrypoints/openai/reasoning_parsers/deepseek_r1_reasoning_parser.py b/vllm/entrypoints/openai/reasoning_parsers/deepseek_r1_reasoning_parser.py index 5c19888d4..33bba0488 100644 --- a/vllm/entrypoints/openai/reasoning_parsers/deepseek_r1_reasoning_parser.py +++ b/vllm/entrypoints/openai/reasoning_parsers/deepseek_r1_reasoning_parser.py @@ -67,6 +67,8 @@ class DeepSeekR1ReasoningParser(ReasoningParser): ]): return None + # Check if is present in previous or delta. + # Keep compatibility with models that don't generate tokens. if self.think_start_token_id in previous_token_ids: if self.think_end_token_id in delta_token_ids: # in previous, in delta, @@ -85,7 +87,6 @@ class DeepSeekR1ReasoningParser(ReasoningParser): # reasoning content continues return DeltaMessage(reasoning_content=delta_text) elif self.think_start_token_id in delta_token_ids: - logger.info(delta_text) if self.think_end_token_id in delta_token_ids: # in delta, in delta, extract reasoning content start_index = delta_text.find(self.think_start_token) @@ -101,35 +102,46 @@ class DeepSeekR1ReasoningParser(ReasoningParser): # reasoning content continues return DeltaMessage(reasoning_content=delta_text) else: - # No in previous or delta, reasoning content continues. - return DeltaMessage(content=delta_text) + # No in previous or delta, also need to check for . 
+ # Because the model may have generated without + # Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f + if self.think_end_token_id in delta_token_ids: + # in delta with more tokens, + # extract reasoning content and content + end_index = delta_text.find(self.think_end_token) + reasoning_content = delta_text[:end_index] + content = delta_text[end_index + len(self.think_end_token):] + return DeltaMessage(reasoning_content=reasoning_content, + content=content if content else None) + elif self.think_end_token_id in previous_token_ids: + # in previous, thinking content ends + return DeltaMessage(content=delta_text) + else: + # no in previous or delta, reasoning content continues + return DeltaMessage(reasoning_content=delta_text) def extract_reasoning_content( self, model_output: str, request: ChatCompletionRequest ) -> Tuple[Optional[str], Optional[str]]: - # Check if the model output contains the tokens. - if (self.think_start_token not in model_output - or self.think_end_token not in model_output): + # DeepSeek R1 doesn't generate now. + # Thus we assume the reasoning content is always at the start. + # Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f + if self.think_end_token not in model_output: return None, model_output else: + # Add a start token if it's missing to keep compatibility. + if self.think_start_token not in model_output: + model_output = f"{self.think_start_token}{model_output}" # Use a regex to find the reasoning content reasoning_content = self.reasoning_regex.findall(model_output)[0] - # Remove the reasoning content from the model output - # Although deepseek's token is always at the - # beginning of the line, we cannot guarantee that the - # other models will follow this convention. - # Therefore, we need to add :start_index. 
- start_index = model_output.find(self.think_start_token) - if start_index != -1: - end_index = start_index + len( - f"{self.think_start_token}{reasoning_content}{self.think_end_token}" - ) - model_output = model_output[:start_index] + \ - model_output[end_index:] - - if len(model_output) == 0: - return reasoning_content, None - - return reasoning_content, model_output + end_index = len( + f"{self.think_start_token}{reasoning_content}{self.think_end_token}" + ) + final_output = model_output[end_index:] + + if len(final_output) == 0: + return reasoning_content, None + + return reasoning_content, final_output -- GitLab From 41c5dd45b98d5a6facad328a1ce534b9a94763a2 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Tue, 11 Feb 2025 00:27:25 -0800 Subject: [PATCH 067/253] [V1][Metrics] Add GPU prefix cache hit rate % gauge (#12592) --- tests/entrypoints/openai/test_metrics.py | 2 + tests/v1/core/test_kv_cache_utils.py | 39 ++++++++++++++- vllm/v1/core/kv_cache_manager.py | 24 +++++++++ vllm/v1/core/kv_cache_utils.py | 64 ++++++++++++++++++++++++ vllm/v1/core/scheduler.py | 1 + vllm/v1/metrics/loggers.py | 29 ++++++++++- vllm/v1/metrics/stats.py | 20 +++++++- 7 files changed, 174 insertions(+), 5 deletions(-) diff --git a/tests/entrypoints/openai/test_metrics.py b/tests/entrypoints/openai/test_metrics.py index de2333901..8c1bb1a89 100644 --- a/tests/entrypoints/openai/test_metrics.py +++ b/tests/entrypoints/openai/test_metrics.py @@ -203,6 +203,8 @@ EXPECTED_METRICS_V1 = [ "vllm:num_requests_running", "vllm:num_requests_waiting", "vllm:gpu_cache_usage_perc", + "vllm:gpu_prefix_cache_queries", + "vllm:gpu_prefix_cache_hits", "vllm:prompt_tokens_total", "vllm:generation_tokens_total", "vllm:request_success_total", diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index 8df4cbe1b..ba08b83ec 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -5,10 +5,11 @@ import pytest from vllm.multimodal.inputs 
import MultiModalKwargs from vllm.sampling_params import SamplingParams from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue, - KVCacheBlock, + KVCacheBlock, PrefixCachingMetrics, generate_block_hash_extra_keys, hash_block_tokens, hash_request_tokens) +from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.request import Request @@ -277,3 +278,39 @@ def test_hash_request_tokens_no_mm_inputs(): assert block_hashes[0].extra_keys is None assert block_hashes[1].token_ids == (3, 4, 5) assert block_hashes[1].extra_keys is None + + +def test_metrics(): + """ + Test the prefix caching metrics. + """ + + def stats(requests, queries, hits): + return PrefixCacheStats(requests=requests, queries=queries, hits=hits) + + metrics = PrefixCachingMetrics(interval=5) + assert metrics.hit_rate == 0.0 + + metrics.observe(stats(1, 20, 9)) + # 9 / 20 = 0.45 + assert metrics.hit_rate == 0.45 + + metrics.observe(stats(4, 80, 16)) + + # 25 / 100 = 0.25 + assert metrics.hit_rate == 0.25 + + metrics.observe(stats(1, 10, 2)) + + # Remove (20, 9) and add (10, 2): 18 / 90 = 0.2 + assert metrics.aggregated_requests == 5 + assert metrics.aggregated_query_total == 90 + assert metrics.aggregated_query_hit == 18 + assert metrics.hit_rate == 0.2 + + metrics.reset() + assert metrics.hit_rate == 0.0 + assert metrics.aggregated_requests == 0 + assert metrics.aggregated_query_total == 0 + assert metrics.aggregated_query_hit == 0 + assert not metrics.query_queue diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index f8d08d0e4..f75d31f54 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -10,6 +10,7 @@ from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue, generate_block_hash_extra_keys, hash_block_tokens, hash_request_tokens) +from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.request import Request, RequestStatus logger = init_logger(__name__) @@ -78,11 +79,28 @@ 
class KVCacheManager: self.req_to_block_hashes: DefaultDict[ str, List[BlockHashType]] = defaultdict(list) + self.prefix_cache_stats = PrefixCacheStats() + @property def usage(self) -> float: + """Get the KV cache usage. + + Returns: + The KV cache usage (between 0.0 and 1.0). + """ return 1.0 - (self.free_block_queue.num_free_blocks / self.num_gpu_blocks) + def make_prefix_cache_stats(self) -> PrefixCacheStats: + """Get (and reset) the prefix cache stats. + + Returns: + The current prefix caching stats. + """ + stats = self.prefix_cache_stats + self.prefix_cache_stats = PrefixCacheStats() + return stats + def get_computed_blocks( self, request: Request) -> Tuple[List[KVCacheBlock], int]: """Get the computed (cached) blocks for the request. @@ -118,6 +136,10 @@ class KVCacheManager: else: break + self.prefix_cache_stats.requests += 1 + self.prefix_cache_stats.queries += len(block_hashes) + self.prefix_cache_stats.hits += len(computed_blocks) + # NOTE(woosuk): Since incomplete blocks are not eligible for # sharing, `num_computed_tokens` is always a multiple of # `block_size`. 
@@ -280,6 +302,8 @@ class KVCacheManager: for block in self.block_pool: block.reset_hash() + self.prefix_cache_stats.reset = True + logger.info("Successfully reset prefix cache") return True diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 6888f1a3e..bddb482d2 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 """KV-Cache Utilities.""" +from collections import deque from collections.abc import Sequence from dataclasses import dataclass from typing import Any, List, NamedTuple, Optional, Tuple @@ -8,6 +9,7 @@ from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.v1.kv_cache_interface import (KVCacheConfig, KVCacheSpec, KVCacheTensor) +from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.request import Request logger = init_logger(__name__) @@ -28,6 +30,68 @@ class BlockHashType(NamedTuple): extra_keys: Optional[Any] = None +class PrefixCachingMetrics: + """Metrics for prefix caching with a hit rate of the most recent N requests. + + Args: + interval: The number of the most recent requests to aggregate. + Defaults to 1000. + """ + + def __init__(self, interval: int = 1000): + self.interval = interval + # The current aggregated values. + self.aggregated_requests = 0 + self.aggregated_query_total = 0 + self.aggregated_query_hit = 0 + # A deque of (requests, queries, hits) for the most recent requests. + self.query_queue: deque[Tuple[int, int, int]] = deque() + + def observe(self, stats: PrefixCacheStats): + """Observe the prefix caching for a set of requests. + + This function is called with information gathered when new requests + are being scheduled and are looking for computed blocks. + + When there are more than `interval` requests, the oldest set of + requestsare removed from the metrics. + + Args: + stats: The prefix cache stats. + """ + # reset_prefix_cache was invoked before the current update. 
+ # Reset the metrics before aggregating the current stats. + if stats.reset: + self.reset() + + # Update the metrics. + self.query_queue.append((stats.requests, stats.queries, stats.hits)) + self.aggregated_requests += stats.requests + self.aggregated_query_total += stats.queries + self.aggregated_query_hit += stats.hits + + # Remove the oldest stats if the number of requests exceeds. + if self.aggregated_requests > self.interval: + old_requests, old_queries, old_hits = self.query_queue.popleft() + self.aggregated_requests -= old_requests + self.aggregated_query_total -= old_queries + self.aggregated_query_hit -= old_hits + + def reset(self): + """Reset the metrics.""" + self.aggregated_requests = 0 + self.aggregated_query_total = 0 + self.aggregated_query_hit = 0 + self.query_queue.clear() + + @property + def hit_rate(self) -> float: + """Calculate the hit rate for the past N requests.""" + if self.aggregated_query_total == 0: + return 0.0 + return self.aggregated_query_hit / self.aggregated_query_total + + @dataclass class KVCacheBlock: """KV-cache block metadata.""" diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 1c54914d1..985fcf01b 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -593,4 +593,5 @@ class Scheduler: num_running_reqs=len(self.running), num_waiting_reqs=len(self.waiting), gpu_cache_usage=self.kv_cache_manager.usage, + prefix_cache_stats=self.kv_cache_manager.make_prefix_cache_stats(), ) diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index eb1acf584..3472761dc 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -9,6 +9,7 @@ import prometheus_client from vllm.config import ModelConfig from vllm.logger import init_logger +from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics from vllm.v1.engine import FinishReason from vllm.v1.metrics.stats import IterationStats, SchedulerStats @@ -37,6 +38,9 @@ class LoggingStatLogger(StatLoggerBase): 
self.num_prompt_tokens: List[int] = [] self.num_generation_tokens: List[int] = [] + # Prefix cache metrics. TODO: Make the interval configurable. + self.prefix_caching_metrics = PrefixCachingMetrics() + def _local_interval_elapsed(self, now: float) -> bool: # Log every _LOCAL_LOGGING_INTERVAL_SEC. elapsed_time = now - self.last_log_time @@ -58,6 +62,8 @@ class LoggingStatLogger(StatLoggerBase): self._track_iteration_stats(iteration_stats) + self.prefix_caching_metrics.observe(scheduler_stats.prefix_cache_stats) + now = time.monotonic() if not self._local_interval_elapsed(now): return @@ -72,13 +78,15 @@ class LoggingStatLogger(StatLoggerBase): logger.info( "Avg prompt throughput: %.1f tokens/s, " "Avg generation throughput: %.1f tokens/s, " - "Running: %d reqs, Waiting: %d reqs " - "GPU KV cache usage: %.1f%%.", + "Running: %d reqs, Waiting: %d reqs, " + "GPU KV cache usage: %.1f%%, " + "Prefix cache hit rate: %.1f%%", prompt_throughput, generation_throughput, scheduler_stats.num_running_reqs, scheduler_stats.num_waiting_reqs, scheduler_stats.gpu_cache_usage * 100, + self.prefix_caching_metrics.hit_rate * 100, ) @@ -107,6 +115,18 @@ class PrometheusStatLogger(StatLoggerBase): documentation="GPU KV-cache usage. 
1 means 100 percent usage.", labelnames=labelnames).labels(*labelvalues) + self.counter_gpu_prefix_cache_queries = prometheus_client.Counter( + name="vllm:gpu_prefix_cache_queries", + documentation= + "GPU prefix cache queries, in terms of number of queried blocks.", + labelnames=labelnames).labels(*labelvalues) + + self.counter_gpu_prefix_cache_hits = prometheus_client.Counter( + name="vllm:gpu_prefix_cache_hits", + documentation= + "GPU prefix cache hits, in terms of number of cached blocks.", + labelnames=labelnames).labels(*labelvalues) + self.counter_prompt_tokens = prometheus_client.Counter( name="vllm:prompt_tokens_total", documentation="Number of prefill tokens processed.", @@ -170,6 +190,11 @@ class PrometheusStatLogger(StatLoggerBase): self.gauge_gpu_cache_usage.set(scheduler_stats.gpu_cache_usage) + self.counter_gpu_prefix_cache_queries.inc( + scheduler_stats.prefix_cache_stats.queries) + self.counter_gpu_prefix_cache_hits.inc( + scheduler_stats.prefix_cache_stats.hits) + self.counter_prompt_tokens.inc(iteration_stats.num_prompt_tokens) self.counter_generation_tokens.inc( iteration_stats.num_generation_tokens) diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py index 5e588d35e..f806b0adf 100644 --- a/vllm/v1/metrics/stats.py +++ b/vllm/v1/metrics/stats.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import time -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import TYPE_CHECKING, List if TYPE_CHECKING: @@ -9,6 +9,20 @@ if TYPE_CHECKING: from vllm.v1.engine import EngineCoreOutput, FinishReason +@dataclass +class PrefixCacheStats: + """Stores prefix cache hit statistics.""" + # Whether reset_prefix_cache was invoked. + reset: bool = False + # The number of requests in this update. + requests: int = 0 + # The number of queries in these requests. Note that "queries" here + # means the number of blocks that were queried from the cache. + queries: int = 0 + # The number of hits in these requests. 
+ hits: int = 0 + + @dataclass class SchedulerStats: """Stats associated with the scheduler.""" @@ -17,7 +31,9 @@ class SchedulerStats: num_waiting_reqs: int = 0 gpu_cache_usage: float = 0.0 - # gpu_prefix_cache_hit_rate: float = 0.0 + + prefix_cache_stats: PrefixCacheStats = field( + default_factory=PrefixCacheStats) @dataclass -- GitLab From 9cf4759493919580011f03812abf16387eafe18c Mon Sep 17 00:00:00 2001 From: Mengqing Cao Date: Tue, 11 Feb 2025 21:20:53 +0800 Subject: [PATCH 068/253] [executor] init `local_rank` as device index (#13027) Signed-off-by: Mengqing Cao --- vllm/executor/uniproc_executor.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/vllm/executor/uniproc_executor.py b/vllm/executor/uniproc_executor.py index e5464cafa..94db23224 100644 --- a/vllm/executor/uniproc_executor.py +++ b/vllm/executor/uniproc_executor.py @@ -28,6 +28,11 @@ class UniProcExecutor(ExecutorBase): distributed_init_method = get_distributed_init_method( get_ip(), get_open_port()) local_rank = 0 + # set local rank as the device index if specified + device_info = self.vllm_config.device_config.device.__str__().split( + ":") + if len(device_info) > 1: + local_rank = int(device_info[1]) rank = 0 kwargs = dict( vllm_config=self.vllm_config, -- GitLab From 7539bbc6a6715dc8e5e71730e2377219b0e69e21 Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com> Date: Tue, 11 Feb 2025 08:47:10 -0500 Subject: [PATCH 069/253] [ROCm] Using a more precise memory profiling (#12624) Signed-off-by: Gregory Shtrasberg --- vllm/platforms/rocm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 1f690b711..13aebc605 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -169,4 +169,5 @@ class RocmPlatform(Platform): device: Optional[torch.types.Device] = None ) -> float: torch.cuda.reset_peak_memory_stats(device) - return torch.cuda.max_memory_allocated(device) + return 
torch.cuda.mem_get_info(device)[1] - torch.cuda.mem_get_info( + device)[0] -- GitLab From da317197dd9352a7718d9bf697c2a5aeb9d42b41 Mon Sep 17 00:00:00 2001 From: Yuhong Guo Date: Tue, 11 Feb 2025 21:55:57 +0800 Subject: [PATCH 070/253] [Build] Fix cuda link target of cumem_allocator in CPU env (#12863) Signed-off-by: YuhongGuo Co-authored-by: Tyler Michael Smith --- CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index b99061dfd..a0fd346c6 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -192,7 +192,7 @@ set_gencode_flags_for_srcs( if(VLLM_GPU_LANG STREQUAL "CUDA") message(STATUS "Enabling cumem allocator extension.") # link against cuda driver library - list(APPEND CUMEM_LIBS cuda) + list(APPEND CUMEM_LIBS CUDA::cuda_driver) define_gpu_extension_target( cumem_allocator DESTINATION vllm -- GitLab From 2e3b969ec0d46e2cfff041a07f29a2ca4bb82bbd Mon Sep 17 00:00:00 2001 From: wangxiyuan Date: Tue, 11 Feb 2025 22:06:46 +0800 Subject: [PATCH 071/253] [Platform] add pre_register_and_update function (#12432) Signed-off-by: wangxiyuan --- vllm/config.py | 3 ++- vllm/engine/arg_utils.py | 21 +++++++++++++++++++++ vllm/platforms/interface.py | 18 ++++++++++++++++++ 3 files changed, 41 insertions(+), 1 deletion(-) diff --git a/vllm/config.py b/vllm/config.py index 426ba3808..1d8c42dd2 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3057,7 +3057,8 @@ class VllmConfig: kv_transfer_config: KVTransferConfig = field(default=None, init=True) # type: ignore # some opaque config, only used to provide additional information - # for the hash computation, mainly used for testing and debugging. + # for the hash computation, mainly used for testing, debugging or out of + # tree config registration. 
additional_config: SupportsHash = field(default=None, init=True) # type: ignore instance_id: str = "" diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 40c6fb456..4232ad920 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -20,6 +20,7 @@ from vllm.config import (CacheConfig, CompilationConfig, ConfigFormat, from vllm.executor.executor_base import ExecutorBase from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS +from vllm.plugins import load_general_plugins from vllm.transformers_utils.utils import check_gguf_file from vllm.usage.usage_lib import UsageContext from vllm.utils import FlexibleArgumentParser, StoreBoolean @@ -203,6 +204,8 @@ class EngineArgs: calculate_kv_scales: Optional[bool] = None + additional_config: Optional[Dict[str, Any]] = None + def __post_init__(self): if not self.tokenizer: self.tokenizer = self.model @@ -984,6 +987,14 @@ class EngineArgs: 'be loaded from the model checkpoint if available. ' 'Otherwise, the scales will default to 1.0.') + parser.add_argument( + "--additional-config", + type=json.loads, + default=None, + help="Additional config for specified platform in JSON format. " + "Different platforms may support different configs. Make sure the " + "configs are valid for the platform you are using. 
The input format" + " is like '{\"config_key\":\"config_value\"}'") return parser @classmethod @@ -1044,6 +1055,9 @@ class EngineArgs: def create_engine_config(self, usage_context: Optional[UsageContext] = None ) -> VllmConfig: + from vllm.platforms import current_platform + current_platform.pre_register_and_update() + if envs.VLLM_USE_V1: self._override_v1_engine_args(usage_context) @@ -1287,6 +1301,7 @@ class EngineArgs: prompt_adapter_config=prompt_adapter_config, compilation_config=self.compilation_config, kv_transfer_config=self.kv_transfer_config, + additional_config=self.additional_config, ) if envs.VLLM_USE_V1: @@ -1347,6 +1362,12 @@ class AsyncEngineArgs(EngineArgs): parser.add_argument('--disable-log-requests', action='store_true', help='Disable logging requests.') + # Initialize plugin to update the parser, for example, the plugin may + # add a new kind of quantization method to --quantization argument or + # a new device to --device argument. + load_general_plugins() + from vllm.platforms import current_platform + current_platform.pre_register_and_update(parser) return parser diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 645d98a1b..61673b085 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -13,8 +13,10 @@ from vllm.logger import init_logger if TYPE_CHECKING: from vllm.config import VllmConfig + from vllm.utils import FlexibleArgumentParser else: VllmConfig = None + FlexibleArgumentParser = None logger = init_logger(__name__) @@ -223,6 +225,22 @@ class Platform: np.random.seed(seed) torch.manual_seed(seed) + @classmethod + def pre_register_and_update(cls, + parser: Optional[FlexibleArgumentParser] = None + ) -> None: + """ + Do some pre-registration or update action for the current platform. + + This function is called before global VllmConfig is initialized or cli + arguments are parsed. It's used for out-of-tree platforms to register or + update the configuration. 
+ + For example, the out-of-tree quantization config can be imported and + registered here dynamically. + """ + pass + @classmethod def check_and_update_config(cls, vllm_config: VllmConfig) -> None: """ -- GitLab From 110f59a33e22aaa16a1d0278bb19f76e4fe5f5a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E0=AE=AE=E0=AE=A9=E0=AF=8B=E0=AE=9C=E0=AF=8D=E0=AE=95?= =?UTF-8?q?=E0=AF=81=E0=AE=AE=E0=AE=BE=E0=AE=B0=E0=AF=8D=20=E0=AE=AA?= =?UTF-8?q?=E0=AE=B4=E0=AE=A9=E0=AE=BF=E0=AE=9A=E0=AF=8D=E0=AE=9A=E0=AE=BE?= =?UTF-8?q?=E0=AE=AE=E0=AE=BF?= Date: Tue, 11 Feb 2025 20:11:20 +0530 Subject: [PATCH 072/253] [Bugfix] fix flaky test (#13089) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: மனோஜ்குமார் பழனிச்சாமி --- tests/test_seed_behavior.py | 27 ++++++--------------------- 1 file changed, 6 insertions(+), 21 deletions(-) diff --git a/tests/test_seed_behavior.py b/tests/test_seed_behavior.py index 7e4e71563..c45ed6926 100644 --- a/tests/test_seed_behavior.py +++ b/tests/test_seed_behavior.py @@ -8,32 +8,17 @@ from vllm.platforms.interface import Platform def test_seed_behavior(): - # Test with seed=None - Platform.seed_everything(None) + # Test with a specific seed + Platform.seed_everything(42) random_value_1 = random.randint(0, 100) np_random_value_1 = np.random.randint(0, 100) torch_random_value_1 = torch.randint(0, 100, (1, )).item() - Platform.seed_everything(None) + Platform.seed_everything(42) random_value_2 = random.randint(0, 100) np_random_value_2 = np.random.randint(0, 100) torch_random_value_2 = torch.randint(0, 100, (1, )).item() - assert random_value_1 != random_value_2 - assert np_random_value_1 != np_random_value_2 - assert torch_random_value_1 != torch_random_value_2 - - # Test with a specific seed - Platform.seed_everything(42) - random_value_3 = random.randint(0, 100) - np_random_value_3 = np.random.randint(0, 100) - torch_random_value_3 = torch.randint(0, 100, (1, )).item() - - Platform.seed_everything(42) - 
random_value_4 = random.randint(0, 100) - np_random_value_4 = np.random.randint(0, 100) - torch_random_value_4 = torch.randint(0, 100, (1, )).item() - - assert random_value_3 == random_value_4 - assert np_random_value_3 == np_random_value_4 - assert torch_random_value_3 == torch_random_value_4 + assert random_value_1 == random_value_2 + assert np_random_value_1 == np_random_value_2 + assert torch_random_value_1 == torch_random_value_2 -- GitLab From 75e6e145164c8e47a97b6e29654fe81b2fbc1ff5 Mon Sep 17 00:00:00 2001 From: Mark McLoughlin Date: Tue, 11 Feb 2025 15:14:00 +0000 Subject: [PATCH 073/253] [V1][Metrics] Add several request timing histograms (#12644) Signed-off-by: Mark McLoughlin --- tests/entrypoints/openai/test_metrics.py | 31 +++++++ tests/v1/core/test_scheduler.py | 3 +- tests/v1/engine/test_engine_core.py | 6 +- tests/v1/engine/test_engine_core_client.py | 2 + tests/v1/engine/test_output_processor.py | 23 +++-- vllm/v1/core/kv_cache_manager.py | 3 + vllm/v1/core/scheduler.py | 33 +++++++- vllm/v1/engine/__init__.py | 33 +++++++- vllm/v1/engine/async_llm.py | 24 +++--- vllm/v1/engine/core.py | 10 ++- vllm/v1/engine/core_client.py | 19 +++-- vllm/v1/engine/llm_engine.py | 1 + vllm/v1/engine/output_processor.py | 59 +++++++++---- vllm/v1/metrics/loggers.py | 49 +++++++++++ vllm/v1/metrics/stats.py | 97 +++++++++++++++++----- vllm/v1/request.py | 25 ++++-- 16 files changed, 334 insertions(+), 84 deletions(-) diff --git a/tests/entrypoints/openai/test_metrics.py b/tests/entrypoints/openai/test_metrics.py index 8c1bb1a89..34b648b6e 100644 --- a/tests/entrypoints/openai/test_metrics.py +++ b/tests/entrypoints/openai/test_metrics.py @@ -85,6 +85,10 @@ EXPECTED_VALUES = { "vllm:time_per_output_token_seconds": [("_count", _NUM_REQUESTS * (_NUM_GENERATION_TOKENS_PER_REQUEST - 1))], "vllm:e2e_request_latency_seconds": [("_count", _NUM_REQUESTS)], + "vllm:request_queue_time_seconds": [("_count", _NUM_REQUESTS)], + "vllm:request_inference_time_seconds": [("_count", 
_NUM_REQUESTS)], + "vllm:request_prefill_time_seconds": [("_count", _NUM_REQUESTS)], + "vllm:request_decode_time_seconds": [("_count", _NUM_REQUESTS)], "vllm:request_prompt_tokens": [("_sum", _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST), ("_count", _NUM_REQUESTS)], @@ -169,6 +173,18 @@ EXPECTED_METRICS = [ "vllm:e2e_request_latency_seconds_sum", "vllm:e2e_request_latency_seconds_bucket", "vllm:e2e_request_latency_seconds_count", + "vllm:request_queue_time_seconds_sum", + "vllm:request_queue_time_seconds_bucket", + "vllm:request_queue_time_seconds_count", + "vllm:request_inference_time_seconds_sum", + "vllm:request_inference_time_seconds_bucket", + "vllm:request_inference_time_seconds_count", + "vllm:request_prefill_time_seconds_sum", + "vllm:request_prefill_time_seconds_bucket", + "vllm:request_prefill_time_seconds_count", + "vllm:request_decode_time_seconds_sum", + "vllm:request_decode_time_seconds_bucket", + "vllm:request_decode_time_seconds_count", "vllm:request_prompt_tokens_sum", "vllm:request_prompt_tokens_bucket", "vllm:request_prompt_tokens_count", @@ -220,6 +236,21 @@ EXPECTED_METRICS_V1 = [ "vllm:time_per_output_token_seconds_sum", "vllm:time_per_output_token_seconds_bucket", "vllm:time_per_output_token_seconds_count", + "vllm:e2e_request_latency_seconds_sum", + "vllm:e2e_request_latency_seconds_bucket", + "vllm:e2e_request_latency_seconds_count", + "vllm:request_queue_time_seconds_sum", + "vllm:request_queue_time_seconds_bucket", + "vllm:request_queue_time_seconds_count", + "vllm:request_inference_time_seconds_sum", + "vllm:request_inference_time_seconds_bucket", + "vllm:request_inference_time_seconds_count", + "vllm:request_prefill_time_seconds_sum", + "vllm:request_prefill_time_seconds_bucket", + "vllm:request_prefill_time_seconds_count", + "vllm:request_decode_time_seconds_sum", + "vllm:request_decode_time_seconds_bucket", + "vllm:request_decode_time_seconds_count", ] diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py 
index 0d29729a4..8aba46aec 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -38,7 +38,8 @@ def create_scheduler( return Scheduler(scheduler_config, model_config, cache_config, - lora_config=None) + lora_config=None, + log_stats=True) def create_requests( diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py index 6a91f1901..36b31550d 100644 --- a/tests/v1/engine/test_engine_core.py +++ b/tests/v1/engine/test_engine_core.py @@ -50,7 +50,8 @@ def test_engine_core(monkeypatch): executor_class = Executor.get_class(vllm_config) engine_core = EngineCore(vllm_config=vllm_config, - executor_class=executor_class) + executor_class=executor_class, + log_stats=True) """Test basic request lifecycle.""" # First request. @@ -157,7 +158,8 @@ def test_engine_core_advanced_sampling(monkeypatch): executor_class = Executor.get_class(vllm_config) engine_core = EngineCore(vllm_config=vllm_config, - executor_class=executor_class) + executor_class=executor_class, + log_stats=True) """Test basic request lifecycle.""" # First request. 
request: EngineCoreRequest = make_request() diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py index b2539132f..45080be8e 100644 --- a/tests/v1/engine/test_engine_core_client.py +++ b/tests/v1/engine/test_engine_core_client.py @@ -94,6 +94,7 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool): asyncio_mode=False, vllm_config=vllm_config, executor_class=executor_class, + log_stats=False, ) MAX_TOKENS = 20 @@ -163,6 +164,7 @@ async def test_engine_core_client_asyncio(monkeypatch): asyncio_mode=True, vllm_config=vllm_config, executor_class=executor_class, + log_stats=True, ) MAX_TOKENS = 20 diff --git a/tests/v1/engine/test_output_processor.py b/tests/v1/engine/test_output_processor.py index c8f43edb7..1d47df417 100644 --- a/tests/v1/engine/test_output_processor.py +++ b/tests/v1/engine/test_output_processor.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import math +import time from typing import Dict, List, Optional import pytest @@ -15,6 +16,7 @@ from vllm.sequence import PromptLogprobs, SampleLogprobs from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.output_processor import OutputProcessor +from vllm.v1.metrics.stats import IterationStats def _ref_convert_id_to_token( @@ -603,6 +605,7 @@ def test_iteration_stats(dummy_test_vectors): output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group, log_stats=True) engine_core = MockEngineCore(dummy_test_vectors.generation_tokens) + engine_core_timestamp = time.monotonic() # Make N requests. requests = [ @@ -630,8 +633,9 @@ def test_iteration_stats(dummy_test_vectors): # First iteration has 2 prefills. 
outputs = engine_core.get_outputs()[:num_active] - processed_outputs = output_processor.process_outputs(outputs) - iteration_stats = processed_outputs.iteration_stats + iteration_stats = IterationStats() + output_processor.process_outputs(outputs, engine_core_timestamp, + iteration_stats) total_prompt_tokens = sum([ len(prompt_tokens) for prompt_tokens in dummy_test_vectors.prompt_tokens[:num_active] @@ -642,8 +646,9 @@ def test_iteration_stats(dummy_test_vectors): # Just decodes in this step. outputs = engine_core.get_outputs()[:num_active] - processed_outputs = output_processor.process_outputs(outputs) - iteration_stats = processed_outputs.iteration_stats + iteration_stats = IterationStats() + output_processor.process_outputs(outputs, engine_core_timestamp, + iteration_stats) assert iteration_stats.num_prompt_tokens == 0 assert iteration_stats.num_generation_tokens == num_active @@ -652,8 +657,9 @@ def test_iteration_stats(dummy_test_vectors): output_processor.add_request(inactive_request) num_active += 1 outputs = engine_core.get_outputs()[:num_active] - processed_outputs = output_processor.process_outputs(outputs) - iteration_stats = processed_outputs.iteration_stats + iteration_stats = IterationStats() + output_processor.process_outputs(outputs, engine_core_timestamp, + iteration_stats) total_prompt_tokens = len(dummy_test_vectors.prompt_tokens[num_active - 1]) assert iteration_stats.num_prompt_tokens == total_prompt_tokens @@ -661,8 +667,9 @@ def test_iteration_stats(dummy_test_vectors): # Just decodes in this step. 
outputs = engine_core.get_outputs()[:num_active] - processed_outputs = output_processor.process_outputs(outputs) - iteration_stats = processed_outputs.iteration_stats + iteration_stats = IterationStats() + output_processor.process_outputs(outputs, engine_core_timestamp, + iteration_stats) assert iteration_stats.num_prompt_tokens == 0 assert iteration_stats.num_generation_tokens == num_active diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index f75d31f54..0381e5cdd 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -26,6 +26,7 @@ class KVCacheManager: sliding_window: Optional[int] = None, enable_caching: bool = True, num_preallocate_tokens: int = 64, + log_stats: bool = False, ) -> None: self.block_size = block_size self.num_gpu_blocks = num_gpu_blocks @@ -33,6 +34,8 @@ class KVCacheManager: self.max_num_blocks_per_req = cdiv(max_model_len, block_size) self.sliding_window = sliding_window self.enable_caching = enable_caching + # FIXME: make prefix cache stats conditional on log_stats + self.log_stats = log_stats # NOTE(woosuk): To avoid frequent block allocation, we preallocate some # blocks for each request. For example, when a request reaches the end # of its block table, we preallocate N blocks in advance. 
This way, we diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 985fcf01b..e32e557ae 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 +import time from collections import deque from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union @@ -10,7 +11,8 @@ from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, from vllm.v1.core.kv_cache_manager import KVCacheManager from vllm.v1.core.scheduler_output import (CachedRequestData, NewRequestData, SchedulerOutput) -from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs +from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType, + EngineCoreOutput, EngineCoreOutputs) from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus @@ -26,10 +28,12 @@ class Scheduler: model_config: ModelConfig, cache_config: CacheConfig, lora_config: Optional[LoRAConfig], + log_stats: bool, ) -> None: self.scheduler_config = scheduler_config self.cache_config = cache_config self.lora_config = lora_config + self.log_stats = log_stats # Scheduling constraints. self.max_num_running_reqs = self.scheduler_config.max_num_seqs @@ -45,7 +49,8 @@ class Scheduler: num_gpu_blocks=num_gpu_blocks, max_model_len=self.max_model_len, sliding_window=self.cache_config.sliding_window, - enable_caching=self.cache_config.enable_prefix_caching) + enable_caching=self.cache_config.enable_prefix_caching, + log_stats=self.log_stats) self.block_size = self.cache_config.block_size # req_id -> Request @@ -107,6 +112,8 @@ class Scheduler: scheduled_encoder_inputs: Dict[str, List[int]] = {} encoder_budget = self.max_num_encoder_input_tokens + scheduled_timestamp = time.monotonic() + # First, schedule the RUNNING requests. 
req_index = 0 while req_index < len(self.running) and token_budget > 0: @@ -246,6 +253,7 @@ class Scheduler: self.running.append(request) if request.status == RequestStatus.WAITING: scheduled_new_reqs.append(request) + self.request_scheduled(request, scheduled_timestamp) elif request.status == RequestStatus.PREEMPTED: scheduled_resumed_reqs.append(request) else: @@ -508,7 +516,8 @@ class Scheduler: finish_reason=request.get_finished_reason(), new_logprobs=new_logprobs, new_prompt_logprobs_tensors=prompt_logprobs_tensors, - stop_reason=request.stop_reason)) + stop_reason=request.stop_reason, + events=request.take_events())) if not stopped: new_running.append(request) @@ -541,6 +550,7 @@ class Scheduler: def add_request(self, request: Request) -> None: self.waiting.append(request) self.requests[request.request_id] = request + self.request_queued(request) def finish_requests( self, @@ -588,7 +598,22 @@ class Scheduler: def reset_prefix_cache(self) -> bool: return self.kv_cache_manager.reset_prefix_cache() - def make_stats(self) -> SchedulerStats: + def request_queued(self, request: Request): + if not self.log_stats: + return + request.events.append( + EngineCoreEvent.new_event(EngineCoreEventType.QUEUED)) + + def request_scheduled(self, request: Request, timestamp: float): + if not self.log_stats: + return + request.events.append( + EngineCoreEvent.new_event(EngineCoreEventType.SCHEDULED, + timestamp)) + + def make_stats(self) -> Optional[SchedulerStats]: + if not self.log_stats: + return None return SchedulerStats( num_running_reqs=len(self.running), num_waiting_reqs=len(self.waiting), diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index 30e118501..782fdcee3 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import enum +import time from typing import List, Optional, Union import msgspec @@ -60,6 +61,30 @@ class EngineCoreRequest( lora_request: Optional[LoRARequest] 
+class EngineCoreEventType(enum.IntEnum): + """The type of engine core request event.""" + QUEUED = 1 + SCHEDULED = 2 + + +class EngineCoreEvent(msgspec.Struct): + """A timestamped engine core event associated with a request. + + The timestamp is a monotonic timestamp and is used by the engine + frontend to calculate intervals between engine core events. These + timestamps should not be compared with timestamps from other processes. + """ + type: EngineCoreEventType + timestamp: float + + @classmethod + def new_event(cls, + event_type: EngineCoreEventType, + timestamp: Optional[float] = None) -> "EngineCoreEvent": + timestamp = time.monotonic() if timestamp is None else timestamp + return cls(event_type, timestamp) + + class EngineCoreOutput( msgspec.Struct, array_like=True, # type: ignore[call-arg] @@ -74,6 +99,7 @@ class EngineCoreOutput( finish_reason: Optional[FinishReason] = None stop_reason: Union[int, str, None] = None + events: Optional[List[EngineCoreEvent]] = None @property def finished(self) -> bool: @@ -91,7 +117,12 @@ class EngineCoreOutputs( # [num_reqs] outputs: List[EngineCoreOutput] - scheduler_stats: SchedulerStats + scheduler_stats: Optional[SchedulerStats] + timestamp: float = 0.0 + + def __post_init__(self): + if self.timestamp == 0.0: + self.timestamp = time.monotonic() class EngineCoreRequestType(enum.Enum): diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 3c4e35e4a..f19d2ed8b 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -53,10 +53,12 @@ class AsyncLLM(EngineClient): self.log_requests = log_requests self.log_stats = log_stats - self.stat_loggers: List[StatLoggerBase] = [ - LoggingStatLogger(), - PrometheusStatLogger(vllm_config.model_config), - ] + self.stat_loggers: List[StatLoggerBase] = [] + if self.log_stats: + self.stat_loggers.extend([ + LoggingStatLogger(), + PrometheusStatLogger(vllm_config.model_config), + ]) # Tokenizer (+ ensure liveness if running in another
process). self.tokenizer = init_tokenizer_from_configs( @@ -85,6 +87,7 @@ class AsyncLLM(EngineClient): asyncio_mode=True, vllm_config=vllm_config, executor_class=executor_class, + log_stats=self.log_stats, ) self.output_handler: Optional[asyncio.Task] = None @@ -246,6 +249,8 @@ class AsyncLLM(EngineClient): # 1) Pull EngineCoreOutputs from the EngineCore. outputs = await self.engine_core.get_output_async() + iteration_stats = IterationStats() if self.log_stats else None + # Split outputs into chunks of at most # VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the # event loop for too long. @@ -257,14 +262,12 @@ class AsyncLLM(EngineClient): outputs.outputs, cdiv(num_outputs, VLLM_V1_OUTPUT_PROC_CHUNK_SIZE)) - iteration_stats = None for i, outputs_slice in enumerate(slices): # 2) Process EngineCoreOutputs. processed_outputs = self.output_processor.process_outputs( - outputs_slice, iteration_stats) + outputs_slice, outputs.timestamp, iteration_stats) # NOTE: RequestOutputs are pushed to their queues. assert not processed_outputs.request_outputs - iteration_stats = processed_outputs.iteration_stats # Allow other asyncio tasks to run between chunks if i + 1 < len(slices): @@ -277,7 +280,6 @@ class AsyncLLM(EngineClient): # 4) Logging. # TODO(rob): make into a coroutine and launch it in # background thread once Prometheus overhead is non-trivial. 
- assert iteration_stats is not None self._log_stats( scheduler_stats=outputs.scheduler_stats, iteration_stats=iteration_stats, @@ -299,12 +301,14 @@ class AsyncLLM(EngineClient): def _log_stats( self, - scheduler_stats: SchedulerStats, - iteration_stats: IterationStats, + scheduler_stats: Optional[SchedulerStats], + iteration_stats: Optional[IterationStats], ): if not self.log_stats: return + assert scheduler_stats is not None + assert iteration_stats is not None for logger in self.stat_loggers: logger.log(scheduler_stats=scheduler_stats, iteration_stats=iteration_stats) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index c90667ba0..e4677681b 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -38,12 +38,15 @@ class EngineCore: self, vllm_config: VllmConfig, executor_class: Type[Executor], + log_stats: bool, ): assert vllm_config.model_config.runner_type != "pooling" logger.info("Initializing a V1 LLM engine (v%s) with config: %s", VLLM_VERSION, vllm_config) + self.log_stats = log_stats + # Setup Model. self.model_executor = executor_class(vllm_config) @@ -59,6 +62,7 @@ class EngineCore: model_config=vllm_config.model_config, cache_config=vllm_config.cache_config, lora_config=vllm_config.lora_config, + log_stats=self.log_stats, ) self.mm_input_mapper_server = MMInputMapperServer( @@ -148,11 +152,9 @@ class EngineCoreProc(EngineCore): ready_pipe: Connection, vllm_config: VllmConfig, executor_class: Type[Executor], - log_stats: bool = False, + log_stats: bool, ): - super().__init__(vllm_config, executor_class) - - self.log_stats = log_stats + super().__init__(vllm_config, executor_class, log_stats) # Background Threads and Queues for IO. 
These enable us to # overlap ZMQ socket IO with GPU since they release the GIL, diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 2d7d6b42c..b3de5cdc2 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -41,6 +41,7 @@ class EngineCoreClient(ABC): asyncio_mode: bool, vllm_config: VllmConfig, executor_class: Type[Executor], + log_stats: bool, ) -> "EngineCoreClient": # TODO: support this for debugging purposes. @@ -50,12 +51,12 @@ class EngineCoreClient(ABC): "is not currently supported.") if multiprocess_mode and asyncio_mode: - return AsyncMPClient(vllm_config, executor_class) + return AsyncMPClient(vllm_config, executor_class, log_stats) if multiprocess_mode and not asyncio_mode: - return SyncMPClient(vllm_config, executor_class) + return SyncMPClient(vllm_config, executor_class, log_stats) - return InprocClient(vllm_config, executor_class) + return InprocClient(vllm_config, executor_class, log_stats) @abstractmethod def shutdown(self): @@ -204,13 +205,13 @@ class MPClient(EngineCoreClient): class SyncMPClient(MPClient): """Synchronous client for multi-proc EngineCore.""" - def __init__(self, vllm_config: VllmConfig, - executor_class: Type[Executor]): + def __init__(self, vllm_config: VllmConfig, executor_class: Type[Executor], + log_stats: bool): super().__init__( asyncio_mode=False, vllm_config=vllm_config, executor_class=executor_class, - log_stats=False, + log_stats=log_stats, ) def get_output(self) -> EngineCoreOutputs: @@ -245,13 +246,13 @@ class SyncMPClient(MPClient): class AsyncMPClient(MPClient): """Asyncio-compatible client for multi-proc EngineCore.""" - def __init__(self, vllm_config: VllmConfig, - executor_class: Type[Executor]): + def __init__(self, vllm_config: VllmConfig, executor_class: Type[Executor], + log_stats: bool): super().__init__( asyncio_mode=True, vllm_config=vllm_config, executor_class=executor_class, - log_stats=True, + log_stats=log_stats, ) self.outputs_queue: 
Optional[asyncio.Queue[bytes]] = None diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 3ef5a9706..c9a4c5369 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -73,6 +73,7 @@ class LLMEngine: asyncio_mode=False, vllm_config=vllm_config, executor_class=executor_class, + log_stats=False, # FIXME: implement ) @classmethod diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 5dbf530ca..7973c62c3 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -19,7 +19,6 @@ class OutputProcessorOutput: request_outputs: List[RequestOutput] reqs_to_abort: List[str] - iteration_stats: IterationStats class RequestState: @@ -34,6 +33,7 @@ class RequestState: detokenizer: IncrementalDetokenizer, arrival_time: float, queue: Optional[asyncio.Queue[RequestOutput]], + log_stats: bool, ): self.request_id = request_id self.output_kind = output_kind @@ -45,14 +45,16 @@ class RequestState: self.is_prefilling = True self.queue = queue - self.stats = RequestStateStats(last_token_time=arrival_time) + self.stats = RequestStateStats( + arrival_time=arrival_time) if log_stats else None @classmethod def from_new_request( cls, tokenizer: AnyTokenizer, request: EngineCoreRequest, - queue: Optional[asyncio.Queue[RequestOutput]] = None, + queue: Optional[asyncio.Queue[RequestOutput]], + log_stats: bool, ) -> "RequestState": return cls( request_id=request.request_id, @@ -69,6 +71,7 @@ class RequestState: ), arrival_time=request.arrival_time, queue=queue, + log_stats=log_stats, ) @@ -112,11 +115,13 @@ class OutputProcessor: self.request_states[request_id] = RequestState.from_new_request( tokenizer=self.tokenizer.get_lora_tokenizer(request.lora_request), request=request, - queue=queue) + queue=queue, + log_stats=self.log_stats) def process_outputs( self, engine_core_outputs: List[EngineCoreOutput], + engine_core_timestamp: Optional[float] = None, iteration_stats: 
Optional[IterationStats] = None, ) -> OutputProcessorOutput: """ @@ -145,8 +150,6 @@ class OutputProcessor: request_outputs: List[RequestOutput] = [] reqs_to_abort: List[str] = [] - if not iteration_stats: - iteration_stats = IterationStats(self.log_stats) for engine_core_output in engine_core_outputs: req_id = engine_core_output.request_id req_state = self.request_states.get(req_id) @@ -155,10 +158,9 @@ class OutputProcessor: continue # 1) Compute stats for this iteration. - iteration_stats.update_from_output(engine_core_output, - req_state.is_prefilling, - req_state.prompt_len, - req_state.stats) + self._update_stats_from_output(req_state, engine_core_output, + engine_core_timestamp, + iteration_stats) new_token_ids = engine_core_output.new_token_ids finish_reason = engine_core_output.finish_reason @@ -205,17 +207,44 @@ class OutputProcessor: # detected stop string, abort needed in EngineCore. reqs_to_abort.append(req_id) - # Track per-request stats. - assert finish_reason is not None - iteration_stats.update_from_finished_request( - finish_reason, request_output, req_state.stats) + # Track per-request stats + self._update_stats_from_finished(req_state, request_output, + finish_reason, + iteration_stats) return OutputProcessorOutput( request_outputs=request_outputs, reqs_to_abort=reqs_to_abort, - iteration_stats=iteration_stats, ) + def _update_stats_from_output(self, req_state: RequestState, + engine_core_output: EngineCoreOutput, + engine_core_timestamp: Optional[float], + iteration_stats: Optional[IterationStats]): + if iteration_stats is None: + return + + assert engine_core_timestamp is not None + assert req_state.stats is not None + iteration_stats.update_from_output(engine_core_output, + engine_core_timestamp, + req_state.is_prefilling, + req_state.prompt_len, + req_state.stats) + + def _update_stats_from_finished(self, req_state: RequestState, + request_output: RequestOutput, + finish_reason: Optional[FinishReason], + iteration_stats: 
Optional[IterationStats]): + if iteration_stats is None: + return + + assert finish_reason is not None + assert req_state.stats is not None + iteration_stats.update_from_finished_request(finish_reason, + request_output, + req_state.stats) + @staticmethod def _make_request_output( request_state: RequestState, diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index 3472761dc..439be38a3 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -182,6 +182,45 @@ class PrometheusStatLogger(StatLoggerBase): ], labelnames=labelnames).labels(*labelvalues) + request_latency_buckets = [ + 0.3, 0.5, 0.8, 1.0, 1.5, 2.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, + 40.0, 50.0, 60.0 + ] + self.histogram_e2e_time_request = \ + prometheus_client.Histogram( + name="vllm:e2e_request_latency_seconds", + documentation="Histogram of e2e request latency in seconds.", + buckets=request_latency_buckets, + labelnames=labelnames).labels(*labelvalues) + self.histogram_queue_time_request = \ + prometheus_client.Histogram( + name="vllm:request_queue_time_seconds", + documentation= + "Histogram of time spent in WAITING phase for request.", + buckets=request_latency_buckets, + labelnames=labelnames).labels(*labelvalues) + self.histogram_inference_time_request = \ + prometheus_client.Histogram( + name="vllm:request_inference_time_seconds", + documentation= + "Histogram of time spent in RUNNING phase for request.", + buckets=request_latency_buckets, + labelnames=labelnames).labels(*labelvalues) + self.histogram_prefill_time_request = \ + prometheus_client.Histogram( + name="vllm:request_prefill_time_seconds", + documentation= + "Histogram of time spent in PREFILL phase for request.", + buckets=request_latency_buckets, + labelnames=labelnames).labels(*labelvalues) + self.histogram_decode_time_request = \ + prometheus_client.Histogram( + name="vllm:request_decode_time_seconds", + documentation= + "Histogram of time spent in DECODE phase for request.", + 
buckets=request_latency_buckets, + labelnames=labelnames).labels(*labelvalues) + def log(self, scheduler_stats: SchedulerStats, iteration_stats: IterationStats): """Log to prometheus.""" @@ -201,6 +240,12 @@ class PrometheusStatLogger(StatLoggerBase): for finished_request in iteration_stats.finished_requests: self.counter_request_success[finished_request.finish_reason].inc() + self.histogram_e2e_time_request.observe( + finished_request.e2e_latency) + self.histogram_inference_time_request.observe( + finished_request.inference_time) + self.histogram_decode_time_request.observe( + finished_request.decode_time) self.histogram_num_prompt_tokens_request.observe( finished_request.num_prompt_tokens) self.histogram_num_generation_tokens_request.observe( @@ -210,6 +255,10 @@ class PrometheusStatLogger(StatLoggerBase): self.histogram_time_to_first_token.observe(ttft) for tpot in iteration_stats.time_per_output_tokens_iter: self.histogram_time_per_output_token.observe(tpot) + for queue_time in iteration_stats.queue_times_iter: + self.histogram_queue_time_request.observe(queue_time) + for prefill_time in iteration_stats.prefill_times_iter: + self.histogram_prefill_time_request.observe(prefill_time) @staticmethod def _unregister_vllm_metrics(): diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py index f806b0adf..a0e620492 100644 --- a/vllm/v1/metrics/stats.py +++ b/vllm/v1/metrics/stats.py @@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, List if TYPE_CHECKING: from vllm.outputs import RequestOutput - from vllm.v1.engine import EngineCoreOutput, FinishReason + from vllm.v1.engine import EngineCoreEvent, EngineCoreOutput, FinishReason @dataclass @@ -41,7 +41,15 @@ class RequestStateStats: """Stats that need to be tracked across delta updates.""" num_generation_tokens: int = 0 - last_token_time: float = 0.0 + + # This is an engine frontend timestamp (wall-clock) + arrival_time: float = 0.0 + + # These are engine core timestamps (monotonic) + queued_ts: float = 0.0 + 
scheduled_ts: float = 0.0 + first_token_ts: float = 0.0 + last_token_ts: float = 0.0 @dataclass @@ -49,33 +57,37 @@ class FinishedRequestStats: """Stats associated with a finished request.""" finish_reason: "FinishReason" + e2e_latency: float = 0.0 num_prompt_tokens: int = 0 num_generation_tokens: int = 0 + inference_time: float = 0.0 + decode_time: float = 0.0 class IterationStats: """Stats associated with a single set of EngineCoreOutputs.""" - def __init__(self, log_stats: bool): - self.log_stats = log_stats + def __init__(self): + self.iteration_timestamp = time.time() self.num_generation_tokens = 0 self.num_prompt_tokens = 0 self.finished_requests: List[FinishedRequestStats] = [] self.time_to_first_tokens_iter: List[float] = [] self.time_per_output_tokens_iter: List[float] = [] + self.queue_times_iter: List[float] = [] + self.prefill_times_iter: List[float] = [] - def update_from_output(self, output: "EngineCoreOutput", - is_prefilling: bool, prompt_len: int, - request_state_stats: RequestStateStats): - if not self.log_stats: - return + def _time_since(self, start: float) -> float: + """Calculate an interval relative to this iteration's timestamp.""" + return self.iteration_timestamp - start + def update_from_output(self, output: "EngineCoreOutput", + engine_core_timestamp: float, is_prefilling: bool, + prompt_len: int, req_stats: RequestStateStats): num_new_generation_tokens = len(output.new_token_ids) - now = time.time() - last_token_latency = now - request_state_stats.last_token_time self.num_generation_tokens += num_new_generation_tokens - if is_prefilling: + if is_prefilling and num_new_generation_tokens > 0: # TODO(andy): we used to assert that num_new_generation_tokens # > 0 with an invariant that EngineCore does not stream outputs # for partially completed prefills (scheduler.update_from_output @@ -84,19 +96,58 @@ class IterationStats: # partially completed prompt. 
# This will be reverted in a follow up PR and we should re-enable # this assertion / invariant. + self.num_prompt_tokens += prompt_len + + first_token_latency = self._time_since(req_stats.arrival_time) + self.time_to_first_tokens_iter.append(first_token_latency) + + req_stats.num_generation_tokens += num_new_generation_tokens + + # Process request-level engine core events + if output.events is not None: + self.update_from_events(output.events, is_prefilling, req_stats) + + # Process the batch-level "new tokens" engine core event + if is_prefilling: + # TODO: re-enable no-output-for-partial-prefills invariant as above if num_new_generation_tokens > 0: - self.num_prompt_tokens += prompt_len - self.time_to_first_tokens_iter.append(last_token_latency) + prefill_interval = \ + engine_core_timestamp - req_stats.scheduled_ts + self.prefill_times_iter.append(prefill_interval) + req_stats.first_token_ts = engine_core_timestamp else: - self.time_per_output_tokens_iter.append(last_token_latency) - - request_state_stats.num_generation_tokens += num_new_generation_tokens - request_state_stats.last_token_time = now + tpot = engine_core_timestamp - req_stats.last_token_ts + self.time_per_output_tokens_iter.append(tpot) + + # TODO: re-enable no-output-for-partial-prefills invariant as above + if num_new_generation_tokens > 0: + req_stats.last_token_ts = engine_core_timestamp + + def update_from_events(self, events: List["EngineCoreEvent"], + is_prefilling: bool, req_stats: RequestStateStats): + # Avoid circular dependency + from vllm.v1.engine import EngineCoreEventType + for event in events: + if event.type == EngineCoreEventType.QUEUED: + req_stats.queued_ts = event.timestamp + elif event.type == EngineCoreEventType.SCHEDULED: + queued_interval = event.timestamp - req_stats.queued_ts + self.queue_times_iter.append(queued_interval) + req_stats.scheduled_ts = event.timestamp def update_from_finished_request(self, finish_reason: "FinishReason", request_output: "RequestOutput", - 
request_state_stats: RequestStateStats): - self.finished_requests.append( - FinishedRequestStats(finish_reason, - len(request_output.prompt_token_ids), - request_state_stats.num_generation_tokens)) + req_stats: RequestStateStats): + e2e_latency = self._time_since(req_stats.arrival_time) + + inference_time = req_stats.last_token_ts - req_stats.scheduled_ts + decode_time = req_stats.last_token_ts - req_stats.first_token_ts + + finished_req = \ + FinishedRequestStats(finish_reason=finish_reason, + e2e_latency=e2e_latency, + num_prompt_tokens=len(request_output.prompt_token_ids), + num_generation_tokens=req_stats.num_generation_tokens, + inference_time=inference_time, + decode_time=decode_time) + self.finished_requests.append(finished_req) diff --git a/vllm/v1/request.py b/vllm/v1/request.py index bb4d2c191..0ebaa71ce 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -5,8 +5,8 @@ from typing import TYPE_CHECKING, List, Optional, Union from vllm.lora.request import LoRARequest from vllm.sampling_params import SamplingParams -from vllm.sequence import RequestMetrics -from vllm.v1.engine import EngineCoreRequest, FinishReason +from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType, + EngineCoreRequest, FinishReason) from vllm.v1.utils import ConstantList if TYPE_CHECKING: @@ -33,14 +33,10 @@ class Request: self.sampling_params = sampling_params # Because of LoRA, the eos token id can be different for each request. 
self.eos_token_id = eos_token_id - self.metrics = RequestMetrics(arrival_time=arrival_time, - last_token_time=arrival_time, - first_scheduled_time=None, - first_token_time=None, - time_in_queue=None) self.lora_request = lora_request self.status = RequestStatus.WAITING + self.events: List[EngineCoreEvent] = [] self.stop_reason: Union[int, str, None] = None assert sampling_params.max_tokens is not None self.max_tokens = sampling_params.max_tokens @@ -83,6 +79,21 @@ class Request: lora_request=request.lora_request, ) + def queued(self, timestamp: Optional[float] = None) -> None: + self.events.append( + EngineCoreEvent.new_event(EngineCoreEventType.QUEUED, timestamp)) + + def scheduled(self, timestamp: Optional[float] = None) -> None: + self.events.append( + EngineCoreEvent.new_event(EngineCoreEventType.SCHEDULED, + timestamp)) + + def take_events(self) -> Optional[List[EngineCoreEvent]]: + if not self.events: + return None + events, self.events = self.events, [] + return events + def append_output_token_ids( self, token_ids: Union[int, List[int]], -- GitLab From ad9776353e6b00d019415e94fd17c78ad4575ff7 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Tue, 11 Feb 2025 15:51:19 +0000 Subject: [PATCH 074/253] Set `torch_dtype` in `TransformersModel` (#13088) Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/model_executor/models/transformers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 43d2c88d3..1605467bc 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -143,6 +143,7 @@ class TransformersModel(nn.Module): self.model: PreTrainedModel = AutoModel.from_config( self.config, attn_implementation="vllm", + torch_dtype=vllm_config.model_config.dtype, trust_remote_code=vllm_config.model_config.trust_remote_code, ) prefix = 
self.model.base_model_prefix -- GitLab From bf3e05215c7f20baf9fcd82d8877d8453dcebf6e Mon Sep 17 00:00:00 2001 From: Jewon Lee <105219284+je1lee@users.noreply.github.com> Date: Wed, 12 Feb 2025 01:20:37 +0900 Subject: [PATCH 075/253] [Misc] Fix typo at comments at metrics.py (#13024) --- vllm/engine/metrics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index ce806b4a9..7c55d66e5 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -237,7 +237,7 @@ class Metrics: documentation="Count of successfully processed requests.", labelnames=labelnames + [Metrics.labelname_finish_reason]) - # Speculatie decoding stats + # Speculative decoding stats self.gauge_spec_decode_draft_acceptance_rate = self._gauge_cls( name="vllm:spec_decode_draft_acceptance_rate", documentation="Speulative token acceptance rate.", -- GitLab From 21f5d50fa557f431e9c76d432771337f5399c420 Mon Sep 17 00:00:00 2001 From: MoonRide303 <130458190+MoonRide303@users.noreply.github.com> Date: Tue, 11 Feb 2025 17:21:18 +0100 Subject: [PATCH 076/253] [Bugfix] Do not use resource module on Windows (#12858) (#13029) --- vllm/utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm/utils.py b/vllm/utils.py index e16875276..6a41afff8 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -15,7 +15,6 @@ import ipaddress import multiprocessing import os import re -import resource import signal import socket import subprocess @@ -2070,6 +2069,11 @@ def memory_profiling( # Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L630 # noqa: E501 def set_ulimit(target_soft_limit=65535): + if sys.platform.startswith('win'): + logger.info("Windows detected, skipping ulimit adjustment.") + return + + import resource resource_type = resource.RLIMIT_NOFILE current_soft, current_hard = resource.getrlimit(resource_type) -- GitLab From 6c4dbe23eb85e5d1da00ccaf4923a275d8769a7f Mon Sep 17 
00:00:00 2001 From: =?UTF-8?q?=E2=84=8D=F0=9D=95=A0=F0=9D=95=9D=F0=9D=95=9D=F0=9D=95=A0?= =?UTF-8?q?=F0=9D=95=A8=20=F0=9D=95=84=F0=9D=95=92=F0=9D=95=9F?= Date: Tue, 11 Feb 2025 18:21:50 +0200 Subject: [PATCH 077/253] [BugFix] Pop instead of del CUDA_VISIBLE_DEVICES (#12962) Signed-off-by: Hollow Man --- examples/offline_inference/rlhf.py | 2 +- examples/offline_inference/rlhf_colocate.py | 2 +- tests/distributed/test_comm_ops.py | 10 +++++----- tests/distributed/test_custom_all_reduce.py | 4 ++-- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/examples/offline_inference/rlhf.py b/examples/offline_inference/rlhf.py index 5000251c0..172d18cbc 100644 --- a/examples/offline_inference/rlhf.py +++ b/examples/offline_inference/rlhf.py @@ -92,7 +92,7 @@ class MyLLM(LLM): # a hack to make the script work. # stop ray from manipulating CUDA_VISIBLE_DEVICES # at the top-level - del os.environ["CUDA_VISIBLE_DEVICES"] + os.environ.pop("CUDA_VISIBLE_DEVICES", None) super().__init__(*args, **kwargs) diff --git a/examples/offline_inference/rlhf_colocate.py b/examples/offline_inference/rlhf_colocate.py index b921bc71f..15dc7edc1 100644 --- a/examples/offline_inference/rlhf_colocate.py +++ b/examples/offline_inference/rlhf_colocate.py @@ -59,7 +59,7 @@ class MyLLM(LLM): # a hack to make the script work. # stop ray from manipulating CUDA_VISIBLE_DEVICES # at the top-level - del os.environ["CUDA_VISIBLE_DEVICES"] + os.environ.pop("CUDA_VISIBLE_DEVICES", None) # every worker will use 0.4 GPU, so that we can schedule # 2 instances on the same GPUs. 
os.environ["VLLM_RAY_PER_WORKER_GPUS"] = "0.4" diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py index bc916e8de..7b0346b8a 100644 --- a/tests/distributed/test_comm_ops.py +++ b/tests/distributed/test_comm_ops.py @@ -22,7 +22,7 @@ def all_reduce_test_worker(tp_size: int, pp_size: int, rank: int, # it is important to delete the CUDA_VISIBLE_DEVICES environment variable # so that each worker can see all the GPUs # they will be able to set the device to the correct GPU - del os.environ["CUDA_VISIBLE_DEVICES"] + os.environ.pop("CUDA_VISIBLE_DEVICES", None) device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) init_test_distributed_environment(tp_size, pp_size, rank, @@ -44,7 +44,7 @@ def all_gather_test_worker(tp_size: int, pp_size: int, rank: int, # it is important to delete the CUDA_VISIBLE_DEVICES environment variable # so that each worker can see all the GPUs # they will be able to set the device to the correct GPU - del os.environ["CUDA_VISIBLE_DEVICES"] + os.environ.pop("CUDA_VISIBLE_DEVICES", None) device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) init_test_distributed_environment(tp_size, pp_size, rank, @@ -72,7 +72,7 @@ def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int, # it is important to delete the CUDA_VISIBLE_DEVICES environment variable # so that each worker can see all the GPUs # they will be able to set the device to the correct GPU - del os.environ["CUDA_VISIBLE_DEVICES"] + os.environ.pop("CUDA_VISIBLE_DEVICES", None) device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) init_test_distributed_environment(tp_size, pp_size, rank, @@ -108,7 +108,7 @@ def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int, @ray.remote(num_gpus=1, max_calls=1) def send_recv_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int, distributed_init_port: str): - del os.environ["CUDA_VISIBLE_DEVICES"] + os.environ.pop("CUDA_VISIBLE_DEVICES", 
None) device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) init_test_distributed_environment(tp_size, pp_size, rank, @@ -148,7 +148,7 @@ def send_recv_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int, @ray.remote(num_gpus=1, max_calls=1) def send_recv_test_worker(tp_size: int, pp_size: int, rank: int, distributed_init_port: str): - del os.environ["CUDA_VISIBLE_DEVICES"] + os.environ.pop("CUDA_VISIBLE_DEVICES", None) device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) init_test_distributed_environment(tp_size, pp_size, rank, diff --git a/tests/distributed/test_custom_all_reduce.py b/tests/distributed/test_custom_all_reduce.py index 46887bca4..4928690be 100644 --- a/tests/distributed/test_custom_all_reduce.py +++ b/tests/distributed/test_custom_all_reduce.py @@ -24,7 +24,7 @@ for i, v in enumerate(test_sizes): @ray.remote(num_gpus=1, max_calls=1) def graph_allreduce(tp_size, pp_size, rank, distributed_init_port): - del os.environ["CUDA_VISIBLE_DEVICES"] + os.environ.pop("CUDA_VISIBLE_DEVICES", None) device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) init_test_distributed_environment(tp_size, pp_size, rank, @@ -80,7 +80,7 @@ def graph_allreduce(tp_size, pp_size, rank, distributed_init_port): @ray.remote(num_gpus=1, max_calls=1) def eager_allreduce(tp_size, pp_size, rank, distributed_init_port): - del os.environ["CUDA_VISIBLE_DEVICES"] + os.environ.pop("CUDA_VISIBLE_DEVICES", None) device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) init_test_distributed_environment(tp_size, pp_size, rank, -- GitLab From 2b25b7d2e1bd915dde2890e7a923958c8d1eb8e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Szymon=20O=C5=BC=C3=B3g?= <58388001+SzymonOzog@users.noreply.github.com> Date: Tue, 11 Feb 2025 17:38:48 +0100 Subject: [PATCH 078/253] Fix initializing GGUF weights for ColumnParallelLinear when using tensor parallel > 1 (#13023) --- vllm/model_executor/layers/linear.py | 19 ++++++++++++------- 1 file changed, 
12 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index da8db08fe..dad161120 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -335,6 +335,12 @@ class ColumnParallelLinear(LinearBase): tp_rank = get_tensor_model_parallel_rank() output_dim = getattr(param, "output_dim", None) + is_sharded_weight = getattr(param, "is_sharded_weight", False) + use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False) + # bitsandbytes loads the weights of the specific portion + # no need to narrow + is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit + # Special case for GGUF is_gguf_weight = getattr(param, "is_gguf_weight", False) is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False) @@ -343,13 +349,12 @@ class ColumnParallelLinear(LinearBase): # Materialize GGUF UninitializedParameter if is_gguf_weight and isinstance(param, UninitializedParameter): - param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype) - - use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False) - is_sharded_weight = getattr(param, "is_sharded_weight", False) - # bitsandbytes loads the weights of the specific portion - # no need to narrow - is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit + final_shape = list(loaded_weight.shape) + if output_dim is not None: + tp_size = get_tensor_model_parallel_world_size() + assert final_shape[output_dim] % tp_size == 0 + final_shape[output_dim] = final_shape[output_dim] // tp_size + param.materialize(final_shape, dtype=loaded_weight.dtype) param_data = param.data if output_dim is not None and not is_sharded_weight: -- GitLab From 565c1efa65358f43a78a52296d658651dd2b8f36 Mon Sep 17 00:00:00 2001 From: "Li, Jiang" Date: Wed, 12 Feb 2025 00:55:56 +0800 Subject: [PATCH 079/253] [CI/Build][Bugfix] Fix CPU backend default threads num (#13077) --- vllm/platforms/cpu.py | 3 +++ 1 file changed, 3 
insertions(+) diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index 179ee6a7d..a9216c232 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -115,6 +115,9 @@ class CpuPlatform(Platform): # Environment variables for CPU executor # + # Set default threads num for OpenMP parallel + os.environ["OMP_NUM_THREADS"] = str(torch.get_num_threads()) + # Disable torch async compiling which won't work with daemonic processes os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1" -- GitLab From deb6c1c6b4469984eb2a032099081f7f9e4ec8a8 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Tue, 11 Feb 2025 18:02:46 +0000 Subject: [PATCH 080/253] [Doc] Improve OpenVINO installation doc (#13102) Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- .../installation/ai_accelerator/openvino.inc.md | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/docs/source/getting_started/installation/ai_accelerator/openvino.inc.md b/docs/source/getting_started/installation/ai_accelerator/openvino.inc.md index 112e8d4d9..4f25252d9 100644 --- a/docs/source/getting_started/installation/ai_accelerator/openvino.inc.md +++ b/docs/source/getting_started/installation/ai_accelerator/openvino.inc.md @@ -19,17 +19,19 @@ Currently, there are no pre-built OpenVINO wheels. ### Build wheel from source -First, install Python. For example, on Ubuntu 22.04, you can run: +First, install Python and ensure you have the latest pip.
For example, on Ubuntu 22.04, you can run: ```console sudo apt-get update -y sudo apt-get install python3 +pip install --upgrade pip ``` -Second, install prerequisites vLLM OpenVINO backend installation: +Second, clone vLLM and install prerequisites for the vLLM OpenVINO backend installation: ```console -pip install --upgrade pip +git clone https://github.com/vllm-project/vllm.git +cd vllm pip install -r requirements-build.txt --extra-index-url https://download.pytorch.org/whl/cpu ``` -- GitLab From 14ecab5be21b2af4ab1bbc6309d558ec620badc6 Mon Sep 17 00:00:00 2001 From: Yuan Tang Date: Tue, 11 Feb 2025 13:17:44 -0500 Subject: [PATCH 081/253] [Bugfix] Guided decoding falls back to outlines when fails to import xgrammar (#12976) Signed-off-by: Yuan Tang --- vllm/model_executor/guided_decoding/__init__.py | 9 +++++++++ vllm/model_executor/guided_decoding/xgrammar_decoding.py | 2 ++ 2 files changed, 11 insertions(+) diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py index cf96461a5..3eb7d186e 100644 --- a/vllm/model_executor/guided_decoding/__init__.py +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -40,6 +40,8 @@ def maybe_backend_fallback( guided_params.backend = "outlines" if guided_params.backend == "xgrammar": + from vllm.model_executor.guided_decoding.xgrammar_decoding import ( + xgr_installed) # xgrammar only has x86 wheels for linux, fallback to outlines from vllm.platforms import current_platform if current_platform.get_cpu_architecture() is not CpuArchEnum.X86: @@ -77,6 +79,13 @@ def maybe_backend_fallback( "Falling back to use outlines instead.") guided_params.backend = "outlines" + # If the xgrammar module cannot be imported successfully, + # we should still allow users to use guided decoding with a fallback. + elif not xgr_installed: + logger.warning("xgrammar module cannot be imported successfully. 
" + "Falling back to use outlines instead.") + guided_params.backend = "outlines" + if (guided_params.backend == "outlines" and guided_params.json_object is not None): # outlines doesn't support json_object, fallback to xgrammar diff --git a/vllm/model_executor/guided_decoding/xgrammar_decoding.py b/vllm/model_executor/guided_decoding/xgrammar_decoding.py index c01bd3af1..fc3a4cd4b 100644 --- a/vllm/model_executor/guided_decoding/xgrammar_decoding.py +++ b/vllm/model_executor/guided_decoding/xgrammar_decoding.py @@ -14,7 +14,9 @@ from transformers import PreTrainedTokenizerFast try: import xgrammar as xgr from xgrammar.base import _core as xgr_core + xgr_installed = True except ImportError: + xgr_installed = False pass from vllm.model_executor.guided_decoding.utils import (convert_lark_to_gbnf, -- GitLab From 72c2b68dc9d4fb20eb135c22ee8c86caca48d28b Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Tue, 11 Feb 2025 17:34:16 -0500 Subject: [PATCH 082/253] [Misc] Move pre-commit suggestion back to the end (#13114) Signed-off-by: Russell Bryant --- .pre-commit-config.yaml | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 352eb2df0..22b51afdc 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -116,13 +116,6 @@ repos: language: python types: [python] exclude: 'vllm/third_party/.*' - - id: suggestion - name: Suggestion - entry: bash -c 'echo "To bypass pre-commit hooks, add --no-verify to git commit."' - language: system - verbose: true - pass_filenames: false - exclude: 'vllm/third_party/.*' - id: check-filenames name: Check for spaces in all filenames entry: bash @@ -133,3 +126,12 @@ repos: always_run: true pass_filenames: false exclude: 'vllm/third_party/.*' + # Keep `suggestion` last + - id: suggestion + name: Suggestion + entry: bash -c 'echo "To bypass pre-commit hooks, add --no-verify to git commit."' + language: system + verbose: true + pass_filenames: 
false + exclude: 'vllm/third_party/.*' + # Insert new entries above the `suggestion` entry -- GitLab From 3ee696a63dd0c2acee44809a3bedec33ea27dfa0 Mon Sep 17 00:00:00 2001 From: Keyun Tong Date: Tue, 11 Feb 2025 20:25:58 -0800 Subject: [PATCH 083/253] [RFC][vllm-API] Support tokenizer registry for customized tokenizer in vLLM (#12518) Signed-off-by: Keyun Tong --- benchmarks/benchmark_serving.py | 5 +- tests/tokenization/test_tokenizer_registry.py | 123 +++++++++++++++ vllm/config.py | 9 +- vllm/engine/arg_utils.py | 6 +- vllm/entrypoints/llm.py | 31 ++-- vllm/entrypoints/openai/serving_engine.py | 3 +- vllm/entrypoints/openai/serving_score.py | 2 +- vllm/logits_process.py | 2 +- vllm/transformers_utils/tokenizer.py | 18 ++- vllm/transformers_utils/tokenizer_base.py | 146 ++++++++++++++++++ vllm/transformers_utils/tokenizers/mistral.py | 39 +++-- 11 files changed, 343 insertions(+), 41 deletions(-) create mode 100644 tests/tokenization/test_tokenizer_registry.py create mode 100644 vllm/transformers_utils/tokenizer_base.py diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 0c8923842..90eb05239 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -1275,11 +1275,12 @@ if __name__ == "__main__": '--tokenizer-mode', type=str, default="auto", - choices=['auto', 'slow', 'mistral'], + choices=['auto', 'slow', 'mistral', 'custom'], help='The tokenizer mode.\n\n* "auto" will use the ' 'fast tokenizer if available.\n* "slow" will ' 'always use the slow tokenizer. \n* ' - '"mistral" will always use the `mistral_common` tokenizer.') + '"mistral" will always use the `mistral_common` tokenizer. 
\n*' + '"custom" will use --tokenizer to select the preregistered tokenizer.') parser.add_argument("--served-model-name", type=str, diff --git a/tests/tokenization/test_tokenizer_registry.py b/tests/tokenization/test_tokenizer_registry.py new file mode 100644 index 000000000..793d38f9c --- /dev/null +++ b/tests/tokenization/test_tokenizer_registry.py @@ -0,0 +1,123 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union + +from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm.transformers_utils.tokenizer_base import (TokenizerBase, + TokenizerRegistry) + +if TYPE_CHECKING: + from vllm.entrypoints.chat_utils import ChatCompletionMessageParam + + +class TestTokenizer(TokenizerBase): + + @classmethod + def from_pretrained(cls, *args, **kwargs) -> "TestTokenizer": + return TestTokenizer() + + @property + def all_special_tokens_extended(self) -> List[str]: + raise NotImplementedError() + + @property + def all_special_tokens(self) -> List[str]: + raise NotImplementedError() + + @property + def all_special_ids(self) -> List[int]: + raise NotImplementedError() + + @property + def bos_token_id(self) -> int: + return 0 + + @property + def eos_token_id(self) -> int: + return 1 + + @property + def sep_token(self) -> str: + raise NotImplementedError() + + @property + def pad_token(self) -> str: + raise NotImplementedError() + + @property + def is_fast(self) -> bool: + raise NotImplementedError() + + @property + def vocab_size(self) -> int: + raise NotImplementedError() + + @property + def max_token_id(self) -> int: + raise NotImplementedError() + + def __call__( + self, + text: Union[str, List[str], List[int]], + text_pair: Optional[str] = None, + add_special_tokens: bool = False, + truncation: bool = False, + max_length: Optional[int] = None, + ): + raise NotImplementedError() + + def get_vocab(self) -> Dict[str, int]: + raise NotImplementedError() + + def get_added_vocab(self) -> Dict[str, int]: + raise 
NotImplementedError() + + def encode_one( + self, + text: str, + truncation: bool = False, + max_length: Optional[int] = None, + ) -> List[int]: + raise NotImplementedError() + + def encode(self, + text: str, + add_special_tokens: Optional[bool] = None) -> List[int]: + raise NotImplementedError() + + def apply_chat_template(self, + messages: List["ChatCompletionMessageParam"], + tools: Optional[List[Dict[str, Any]]] = None, + **kwargs) -> List[int]: + raise NotImplementedError() + + def convert_tokens_to_string(self, tokens: List[str]) -> str: + raise NotImplementedError() + + def decode(self, + ids: Union[List[int], int], + skip_special_tokens: bool = True) -> str: + raise NotImplementedError() + + def convert_ids_to_tokens( + self, + ids: List[int], + skip_special_tokens: bool = True, + ) -> List[str]: + raise NotImplementedError() + + +def test_customized_tokenizer(): + TokenizerRegistry.register("test_tokenizer", + "tests.tokenization.test_tokenizer_registry", + "TestTokenizer") + + tokenizer = TokenizerRegistry.get_tokenizer("test_tokenizer") + assert isinstance(tokenizer, TestTokenizer) + assert tokenizer.bos_token_id == 0 + assert tokenizer.eos_token_id == 1 + + tokenizer = get_tokenizer("test_tokenizer", tokenizer_mode="custom") + assert isinstance(tokenizer, TestTokenizer) + assert tokenizer.bos_token_id == 0 + assert tokenizer.eos_token_id == 1 diff --git a/vllm/config.py b/vllm/config.py index 1d8c42dd2..1740871e7 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -102,8 +102,9 @@ class ModelConfig: it; otherwise, you must specify explicitly which task to use. tokenizer: Name or path of the huggingface tokenizer to use. tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if - available, "slow" will always use the slow tokenizer, and - "mistral" will always use the tokenizer from `mistral_common`. 
+ available, "slow" will always use the slow tokenizer, + "mistral" will always use the tokenizer from `mistral_common`, and + "custom" will use --tokenizer to select the preregistered tokenizer. trust_remote_code: Trust remote code (e.g., from HuggingFace) when downloading the model and tokenizer. allowed_local_media_path: Allowing API requests to read local images or @@ -467,10 +468,10 @@ class ModelConfig: def _verify_tokenizer_mode(self) -> None: tokenizer_mode = self.tokenizer_mode.lower() - if tokenizer_mode not in ["auto", "slow", "mistral"]: + if tokenizer_mode not in ["auto", "slow", "mistral", "custom"]: raise ValueError( f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be " - "either 'auto', 'slow' or 'mistral'.") + "either 'auto', 'slow', 'mistral' or 'custom'.") self.tokenizer_mode = tokenizer_mode def _get_preferred_task( diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 4232ad920..83ee6b97f 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -284,11 +284,13 @@ class EngineArgs: '--tokenizer-mode', type=str, default=EngineArgs.tokenizer_mode, - choices=['auto', 'slow', 'mistral'], + choices=['auto', 'slow', 'mistral', 'custom'], help='The tokenizer mode.\n\n* "auto" will use the ' 'fast tokenizer if available.\n* "slow" will ' 'always use the slow tokenizer. \n* ' - '"mistral" will always use the `mistral_common` tokenizer.') + '"mistral" will always use the `mistral_common` tokenizer. 
\n* ' + '"custom" will use --tokenizer to select the ' + 'preregistered tokenizer.') parser.add_argument('--trust-remote-code', action='store_true', help='Trust remote code from huggingface.') diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index d071a0b3c..73593f0c6 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1051,9 +1051,9 @@ class LLM: def _cross_encoding_score( self, - tokenizer: Union[AnyTokenizer], - text_1: List[Union[str, TextPrompt, TokensPrompt]], - text_2: List[Union[str, TextPrompt, TokensPrompt]], + tokenizer: AnyTokenizer, + text_1: List[str], + text_2: List[str], truncate_prompt_tokens: Optional[int] = None, use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, @@ -1176,29 +1176,36 @@ class LLM: if isinstance(text_1, (str, dict)): # Convert a single prompt to a list. text_1 = [text_1] - text_1 = [ensure_str(t) for t in text_1] + input_text_1: List[str] = [ensure_str(t) for t in text_1] if isinstance(text_2, (str, dict)): # Convert a single prompt to a list. 
text_2 = [text_2] - text_2 = [ensure_str(t) for t in text_2] + input_text_2: List[str] = [ensure_str(t) for t in text_2] - if len(text_1) > 1 and len(text_1) != len(text_2): + if len(input_text_1) > 1 and len(input_text_1) != len(input_text_2): raise ValueError("Input lengths must be either 1:1, 1:N or N:N") - if len(text_1) == 0: + if len(input_text_1) == 0: raise ValueError("At least one text element must be given") - if len(text_2) == 0: + if len(input_text_2) == 0: raise ValueError("At least one text_pair element must be given") if self.llm_engine.model_config.is_cross_encoder: - return self._cross_encoding_score(tokenizer, text_1, text_2, + return self._cross_encoding_score(tokenizer, input_text_1, + input_text_2, truncate_prompt_tokens, use_tqdm, lora_request, prompt_adapter_request) else: - return self._embedding_score(tokenizer, text_1, text_2, - truncate_prompt_tokens, use_tqdm, - lora_request, prompt_adapter_request) + + return self._embedding_score( + tokenizer, + input_text_1, # type: ignore[arg-type] + input_text_2, # type: ignore[arg-type] + truncate_prompt_tokens, + use_tqdm, + lora_request, + prompt_adapter_request) def start_profile(self) -> None: self.llm_engine.start_profile() diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 8d39fdcb7..9efb5e6fa 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -400,8 +400,7 @@ class OpenAIServing: _chat_template_kwargs.update(chat_template_kwargs or {}) request_prompt: Union[str, List[int]] - is_mistral_tokenizer = isinstance(tokenizer, MistralTokenizer) - if is_mistral_tokenizer: + if isinstance(tokenizer, MistralTokenizer): request_prompt = apply_mistral_chat_template( tokenizer, messages=messages, diff --git a/vllm/entrypoints/openai/serving_score.py b/vllm/entrypoints/openai/serving_score.py index 832aa8516..c7597808f 100644 --- a/vllm/entrypoints/openai/serving_score.py +++ 
b/vllm/entrypoints/openai/serving_score.py @@ -121,7 +121,7 @@ class OpenAIServingScores(OpenAIServing): tokenize_async = make_async(tokenizer.__call__, executor=self._tokenizer_executor) - prompt_inputs = await tokenize_async(text=q, + prompt_inputs = await tokenize_async(q, text_pair=t, **tokenization_kwargs) diff --git a/vllm/logits_process.py b/vllm/logits_process.py index d02072e8f..a810be7bc 100644 --- a/vllm/logits_process.py +++ b/vllm/logits_process.py @@ -31,7 +31,7 @@ def get_bad_words_logits_processors( if isinstance(tokenizer, MistralTokenizer): # Mistral tokenizers should not add special tokens - prompt_token_ids = tokenizer.encode(prompt=prompt) + prompt_token_ids = tokenizer.encode(text=prompt) else: prompt_token_ids = tokenizer.encode(text=prompt, add_special_tokens=False) diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index 520870b56..0c0f68ac1 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -14,6 +14,8 @@ from transformers import (AutoTokenizer, PreTrainedTokenizer, from vllm.envs import VLLM_USE_MODELSCOPE from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.transformers_utils.tokenizer_base import (TokenizerBase, + TokenizerRegistry) from vllm.transformers_utils.tokenizers import MistralTokenizer from vllm.transformers_utils.utils import check_gguf_file from vllm.utils import make_async @@ -21,7 +23,7 @@ from vllm.utils import make_async logger = init_logger(__name__) AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast, - MistralTokenizer] + TokenizerBase] def decode_tokens( @@ -47,11 +49,7 @@ def encode_tokens( Backend-agnostic equivalent of HF's :code:`tokenizer.encode(text, add_special_tokens=...)`. 
""" - if isinstance(tokenizer, MistralTokenizer): - return tokenizer.tokenizer.encode(text, - bos=add_special_tokens, - eos=add_special_tokens) - elif add_special_tokens is not None: + if add_special_tokens is not None: return tokenizer.encode(text, add_special_tokens=add_special_tokens) return tokenizer.encode(text) @@ -183,9 +181,17 @@ def get_tokenizer( 'encoding and decoding.', FutureWarning, stacklevel=2) + + tokenizer: AnyTokenizer if tokenizer_mode == "mistral": tokenizer = MistralTokenizer.from_pretrained(str(tokenizer_name), revision=revision) + elif tokenizer_mode == "custom": + tokenizer = TokenizerRegistry.get_tokenizer(str(tokenizer_name), + *args, + revision=revision, + download_dir=download_dir, + **kwargs) else: try: tokenizer = AutoTokenizer.from_pretrained( diff --git a/vllm/transformers_utils/tokenizer_base.py b/vllm/transformers_utils/tokenizer_base.py new file mode 100644 index 000000000..bb5ddaf88 --- /dev/null +++ b/vllm/transformers_utils/tokenizer_base.py @@ -0,0 +1,146 @@ +# SPDX-License-Identifier: Apache-2.0 + +import importlib +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union + +if TYPE_CHECKING: + from vllm.entrypoints.chat_utils import ChatCompletionMessageParam + + +class TokenizerBase(ABC): + + @property + @abstractmethod + def all_special_tokens_extended(self) -> List[str]: + raise NotImplementedError() + + @property + @abstractmethod + def all_special_tokens(self) -> List[str]: + raise NotImplementedError() + + @property + @abstractmethod + def all_special_ids(self) -> List[int]: + raise NotImplementedError() + + @property + @abstractmethod + def bos_token_id(self) -> int: + raise NotImplementedError() + + @property + @abstractmethod + def eos_token_id(self) -> int: + raise NotImplementedError() + + @property + @abstractmethod + def sep_token(self) -> str: + raise NotImplementedError() + + @property + @abstractmethod + def pad_token(self) -> str: + raise 
NotImplementedError() + + @property + @abstractmethod + def is_fast(self) -> bool: + raise NotImplementedError() + + @property + @abstractmethod + def vocab_size(self) -> int: + raise NotImplementedError() + + @property + @abstractmethod + def max_token_id(self) -> int: + raise NotImplementedError() + + def __len__(self) -> int: + return self.vocab_size + + @abstractmethod + def __call__( + self, + text: Union[str, List[str], List[int]], + text_pair: Optional[str] = None, + add_special_tokens: bool = False, + truncation: bool = False, + max_length: Optional[int] = None, + ): + raise NotImplementedError() + + @abstractmethod + def get_vocab(self) -> Dict[str, int]: + raise NotImplementedError() + + @abstractmethod + def get_added_vocab(self) -> Dict[str, int]: + raise NotImplementedError() + + @abstractmethod + def encode_one( + self, + text: str, + truncation: bool = False, + max_length: Optional[int] = None, + ) -> List[int]: + raise NotImplementedError() + + @abstractmethod + def encode(self, + text: str, + add_special_tokens: Optional[bool] = None) -> List[int]: + raise NotImplementedError() + + @abstractmethod + def apply_chat_template(self, + messages: List["ChatCompletionMessageParam"], + tools: Optional[List[Dict[str, Any]]] = None, + **kwargs) -> List[int]: + raise NotImplementedError() + + @abstractmethod + def convert_tokens_to_string(self, tokens: List[str]) -> str: + raise NotImplementedError() + + @abstractmethod + def decode(self, + ids: Union[List[int], int], + skip_special_tokens: bool = True) -> str: + raise NotImplementedError() + + @abstractmethod + def convert_ids_to_tokens( + self, + ids: List[int], + skip_special_tokens: bool = True, + ) -> List[str]: + raise NotImplementedError() + + +class TokenizerRegistry: + # Tokenizer name -> (tokenizer module, tokenizer class) + REGISTRY: Dict[str, Tuple[str, str]] = {} + + @staticmethod + def register(name: str, module: str, class_name: str) -> None: + TokenizerRegistry.REGISTRY[name] = (module, 
class_name) + + @staticmethod + def get_tokenizer( + tokenizer_name: str, + *args, + **kwargs, + ) -> TokenizerBase: + tokenizer_cls = TokenizerRegistry.REGISTRY.get(tokenizer_name) + if tokenizer_cls is None: + raise ValueError(f"Tokenizer {tokenizer_name} not found.") + + tokenizer_module = importlib.import_module(tokenizer_cls[0]) + class_ = getattr(tokenizer_module, tokenizer_cls[1]) + return class_.from_pretrained(*args, **kwargs) diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py index f08923e74..59131a9d7 100644 --- a/vllm/transformers_utils/tokenizers/mistral.py +++ b/vllm/transformers_utils/tokenizers/mistral.py @@ -10,6 +10,7 @@ import huggingface_hub from huggingface_hub import HfApi, hf_hub_download from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer_base import TokenizerBase from vllm.utils import is_list_of if TYPE_CHECKING: @@ -140,7 +141,7 @@ def make_mistral_chat_completion_request( tools=tools) # type: ignore[type-var] -class MistralTokenizer: +class MistralTokenizer(TokenizerBase): def __init__(self, tokenizer: "PublicMistralTokenizer") -> None: self.mistral = tokenizer @@ -251,6 +252,14 @@ class MistralTokenizer: def eos_token_id(self) -> int: return self.tokenizer.eos_id + @property + def sep_token(self) -> str: + raise NotImplementedError() + + @property + def pad_token(self) -> str: + raise NotImplementedError() + @property def is_fast(self) -> bool: return True @@ -268,25 +277,26 @@ class MistralTokenizer: def __call__( self, - prompt: Union[str, List[str], List[int]], + text: Union[str, List[str], List[int]], + text_pair: Optional[str] = None, add_special_tokens: bool = False, truncation: bool = False, max_length: Optional[int] = None, ): input_ids: Union[List[int], List[List[int]]] # For List[str], original prompt text - if is_list_of(prompt, str): + if is_list_of(text, str): input_ids_: List[List[int]] = [] - for p in prompt: + for p in text: each_input_ids = 
self.encode_one(p, truncation, max_length) input_ids_.append(each_input_ids) input_ids = input_ids_ # For List[int], apply chat template output, already tokens. - elif is_list_of(prompt, int): - input_ids = prompt + elif is_list_of(text, int): + input_ids = text # For str, single prompt text else: - input_ids = self.encode_one(prompt, truncation, max_length) + input_ids = self.encode_one(text, truncation, max_length) return Encoding(input_ids=input_ids) def get_vocab(self) -> Dict[str, int]: @@ -300,22 +310,29 @@ class MistralTokenizer: def encode_one( self, - prompt: str, + text: str, truncation: bool = False, max_length: Optional[int] = None, ) -> List[int]: # Mistral Tokenizers should not add special tokens - input_ids = self.encode(prompt) + input_ids = self.encode(text) if truncation: input_ids = input_ids[:max_length] return input_ids - def encode(self, prompt: str) -> List[int]: + def encode(self, + text: str, + add_special_tokens: Optional[bool] = None) -> List[int]: # `encode` should only be used for prompt completion # it should never be used for chat_completion. 
# For chat completion use `apply_chat_template` - return self.tokenizer.encode(prompt, bos=True, eos=False) + if add_special_tokens is not None: + return self.tokenizer.encode(text, + bos=add_special_tokens, + eos=add_special_tokens) + else: + return self.tokenizer.encode(text, bos=True, eos=False) def apply_chat_template(self, messages: List["ChatCompletionMessageParam"], -- GitLab From 974dfd497149e871e59e35b677a85cca66ec3bae Mon Sep 17 00:00:00 2001 From: Christian Pinto Date: Wed, 12 Feb 2025 04:34:30 +0000 Subject: [PATCH 084/253] [Model] IBM/NASA Prithvi Geospatial model (#12830) --- .../prithvi_geospatial_mae.py | 530 ++++++++++++++++++ tests/models/registry.py | 4 + vllm/attention/backends/placeholder_attn.py | 11 +- vllm/inputs/preprocess.py | 22 +- .../models/prithvi_geospatial_mae.py | 238 ++++++++ vllm/model_executor/models/registry.py | 4 + vllm/worker/pooling_model_runner.py | 11 +- 7 files changed, 811 insertions(+), 9 deletions(-) create mode 100644 examples/offline_inference/prithvi_geospatial_mae.py create mode 100644 vllm/model_executor/models/prithvi_geospatial_mae.py diff --git a/examples/offline_inference/prithvi_geospatial_mae.py b/examples/offline_inference/prithvi_geospatial_mae.py new file mode 100644 index 000000000..298f08019 --- /dev/null +++ b/examples/offline_inference/prithvi_geospatial_mae.py @@ -0,0 +1,530 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +This is a demo script showing how to use the +PrithviGeospatialMAE model with vLLM +This script is based on: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11/blob/main/inference.py # noqa + +Target model weights: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11/resolve/main/Prithvi-EO-V2-300M-TL-Sen1Floods11.pt # noqa + +The requirements for running this script are: +- Installing [terratorch, albumentations, rasterio] in your python environment +- downloading the model weights in a 'model' folder local to the script + 
(temporary measure until the proper config.json file is uploaded to HF) +- download an input example image (India_900498_S2Hand.tif) and place it in + the same folder with the script (or specify with the --data_file argument) + +Run the example: +python prithvi_geospatial_mae.py + +""" # noqa: E501 +import argparse +import datetime +import os +import re +from typing import List, Union + +import albumentations +import numpy as np +import rasterio +import torch +from einops import rearrange +from terratorch.datamodules import Sen1Floods11NonGeoDataModule + +from vllm import LLM + +NO_DATA = -9999 +NO_DATA_FLOAT = 0.0001 +OFFSET = 0 +PERCENTILE = 99 + +model_config = """{ + "architectures": ["PrithviGeoSpatialMAE"], + "num_classes": 0, + "pretrained_cfg": { + "task_args": { + "task": "SemanticSegmentationTask", + "model_factory": "EncoderDecoderFactory", + "loss": "ce", + "ignore_index": -1, + "lr": 0.001, + "freeze_backbone": false, + "freeze_decoder": false, + "plot_on_val": 10, + "optimizer": "AdamW", + "scheduler": "CosineAnnealingLR" + }, + "model_args": { + "backbone_pretrained": false, + "backbone": "prithvi_eo_v2_300_tl", + "decoder": "UperNetDecoder", + "decoder_channels": 256, + "decoder_scale_modules": true, + "num_classes": 2, + "rescale": true, + "backbone_bands": [ + "BLUE", + "GREEN", + "RED", + "NIR_NARROW", + "SWIR_1", + "SWIR_2" + ], + "head_dropout": 0.1, + "necks": [ + { + "name": "SelectIndices", + "indices": [ + 5, + 11, + 17, + 23 + ] + }, + { + "name": "ReshapeTokensToImage" + } + ] + }, + "optimizer_params" : { + "lr": 5.0e-05, + "betas": [0.9, 0.999], + "eps": [1.0e-08], + "weight_decay": 0.05, + "amsgrad": false, + "maximize": false, + "capturable": false, + "differentiable": false + }, + "scheduler_params" : { + "T_max": 50, + "eta_min": 0, + "last_epoch": -1, + "verbose": "deprecated" + } + }, + + + "torch_dtype": "float32" +} +""" + +# Temporarily creating the "config.json" for the model. 
+# This is going to disappear once the correct config.json is available on HF +with open(os.path.join(os.path.dirname(__file__), "./model/config.json"), + 'w') as config_file: + config_file.write(model_config) + +datamodule_config = { + 'bands': ['BLUE', 'GREEN', 'RED', 'NIR_NARROW', 'SWIR_1', 'SWIR_2'], + 'batch_size': + 16, + 'constant_scale': + 0.0001, + 'data_root': + '/dccstor/geofm-finetuning/datasets/sen1floods11', + 'drop_last': + True, + 'no_data_replace': + 0.0, + 'no_label_replace': + -1, + 'num_workers': + 8, + 'test_transform': [ + albumentations.Resize(always_apply=False, + height=448, + interpolation=1, + p=1, + width=448), + albumentations.pytorch.ToTensorV2(transpose_mask=False, + always_apply=True, + p=1.0) + ], +} + + +class PrithviMAE: + + def __init__(self): + print("Initializing PrithviMAE model") + self.model = LLM(model=os.path.join(os.path.dirname(__file__), + "./model"), + skip_tokenizer_init=True, + dtype="float32") + + def run(self, input_data, location_coords): + print("################ Running inference on vLLM ##############") + # merge the inputs into one data structure + mm_data = { + "pixel_values": + torch.empty(0) if input_data is None else input_data, + "location_coords": + torch.empty(0) if location_coords is None else location_coords + } + + prompt = {"prompt_token_ids": [1], "multi_modal_data": mm_data} + + outputs = self.model.encode(prompt, use_tqdm=False) + print( + "################ Inference done (it took seconds) ##############" + ) + + return outputs[0].outputs.data + + +def generate_datamodule(): + datamodule = Sen1Floods11NonGeoDataModule( + data_root=datamodule_config['data_root'], + batch_size=datamodule_config["batch_size"], + num_workers=datamodule_config["num_workers"], + bands=datamodule_config["bands"], + drop_last=datamodule_config["drop_last"], + test_transform=datamodule_config["test_transform" + ""]) + + return datamodule + + +def process_channel_group(orig_img, channels): + """ + Args: + orig_img: 
torch.Tensor representing original image (reference) + with shape = (bands, H, W). + channels: list of indices representing RGB channels. + + Returns: + torch.Tensor with shape (num_channels, height, width) for original image + """ + + orig_img = orig_img[channels, ...] + valid_mask = torch.ones_like(orig_img, dtype=torch.bool) + valid_mask[orig_img == NO_DATA_FLOAT] = False + + # Rescale (enhancing contrast) + max_value = max(3000, np.percentile(orig_img[valid_mask], PERCENTILE)) + min_value = OFFSET + + orig_img = torch.clamp((orig_img - min_value) / (max_value - min_value), 0, + 1) + + # No data as zeros + orig_img[~valid_mask] = 0 + + return orig_img + + +def read_geotiff(file_path: str): + """Read all bands from *file_path* and return image + meta info. + + Args: + file_path: path to image file. + + Returns: + np.ndarray with shape (bands, height, width) + meta info dict + """ + + with rasterio.open(file_path) as src: + img = src.read() + meta = src.meta + try: + coords = src.lnglat() + except Exception: + # Cannot read coords + coords = None + + return img, meta, coords + + +def save_geotiff(image, output_path: str, meta: dict): + """Save multi-band image in Geotiff file. + + Args: + image: np.ndarray with shape (bands, height, width) + output_path: path where to save the image + meta: dict with meta info. + """ + + with rasterio.open(output_path, "w", **meta) as dest: + for i in range(image.shape[0]): + dest.write(image[i, :, :], i + 1) + + return + + +def _convert_np_uint8(float_image: torch.Tensor): + image = float_image.numpy() * 255.0 + image = image.astype(dtype=np.uint8) + + return image + + +def load_example( + file_paths: List[str], + mean: List[float] = None, + std: List[float] = None, + indices: Union[list[int], None] = None, +): + """Build an input example by loading images in *file_paths*. + + Args: + file_paths: list of file paths . + mean: list containing mean values for each band in the images + in *file_paths*. 
+ std: list containing std values for each band in the images + in *file_paths*. + + Returns: + np.array containing created example + list of meta info for each image in *file_paths* + """ + + imgs = [] + metas = [] + temporal_coords = [] + location_coords = [] + + for file in file_paths: + img, meta, coords = read_geotiff(file) + + # Rescaling (don't normalize on nodata) + img = np.moveaxis(img, 0, -1) # channels last for rescaling + if indices is not None: + img = img[..., indices] + if mean is not None and std is not None: + img = np.where(img == NO_DATA, NO_DATA_FLOAT, (img - mean) / std) + + imgs.append(img) + metas.append(meta) + if coords is not None: + location_coords.append(coords) + + try: + match = re.search(r'(\d{7,8}T\d{6})', file) + if match: + year = int(match.group(1)[:4]) + julian_day = match.group(1).split('T')[0][4:] + if len(julian_day) == 3: + julian_day = int(julian_day) + else: + julian_day = datetime.datetime.strptime( + julian_day, '%m%d').timetuple().tm_yday + temporal_coords.append([year, julian_day]) + except Exception as e: + print(f'Could not extract timestamp for {file} ({e})') + + imgs = np.stack(imgs, axis=0) # num_frames, H, W, C + imgs = np.moveaxis(imgs, -1, 0).astype("float32") + imgs = np.expand_dims(imgs, axis=0) # add batch dim + + return imgs, temporal_coords, location_coords, metas + + +def run_model(input_data, + temporal_coords, + location_coords, + model, + datamodule, + img_size, + lightning_model=None): + # Reflect pad if not divisible by img_size + original_h, original_w = input_data.shape[-2:] + pad_h = (img_size - (original_h % img_size)) % img_size + pad_w = (img_size - (original_w % img_size)) % img_size + input_data = np.pad(input_data, + ((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)), + mode="reflect") + + # Build sliding window + batch_size = 1 + batch = torch.tensor(input_data, device="cpu") + windows = (batch.unfold(3, img_size, + img_size).unfold(4, img_size, img_size)) + h1, w1 = windows.shape[3:5] +
windows = rearrange(windows, + "b c t h1 w1 h w -> (b h1 w1) c t h w", + h=img_size, + w=img_size) + + # Split into batches if number of windows > batch_size + num_batches = windows.shape[0] // batch_size if windows.shape[ + 0] > batch_size else 1 + windows = torch.tensor_split(windows, num_batches, dim=0) + + if torch.cuda.is_available(): + device = torch.device('cuda') + else: + device = torch.device('cpu') + + if temporal_coords: + temporal_coords = torch.tensor(temporal_coords, + device=device).unsqueeze(0) + else: + temporal_coords = None + if location_coords: + location_coords = torch.tensor(location_coords[0], + device=device).unsqueeze(0) + else: + location_coords = None + + # Run model + pred_imgs = [] + for x in windows: + # Apply standardization + x = datamodule.test_transform( + image=x.squeeze().numpy().transpose(1, 2, 0)) + x = datamodule.aug(x)['image'] + + with torch.no_grad(): + x = x.to(device) + pred = model.run(x, location_coords=location_coords) + if lightning_model: + pred_lightning = lightning_model( + x, + temporal_coords=temporal_coords, + location_coords=location_coords) + pred_lightning = pred_lightning.output.detach().cpu() + if not torch.equal(pred, pred_lightning): + print("Inference output is not equal") + y_hat = pred.argmax(dim=1) + + y_hat = torch.nn.functional.interpolate(y_hat.unsqueeze(1).float(), + size=img_size, + mode="nearest") + + pred_imgs.append(y_hat) + + pred_imgs = torch.concat(pred_imgs, dim=0) + + # Build images from patches + pred_imgs = rearrange( + pred_imgs, + "(b h1 w1) c h w -> b c (h1 h) (w1 w)", + h=img_size, + w=img_size, + b=1, + c=1, + h1=h1, + w1=w1, + ) + + # Cut padded area back to original size + pred_imgs = pred_imgs[..., :original_h, :original_w] + + # Squeeze (batch size 1) + pred_imgs = pred_imgs[0] + + return pred_imgs + + +def main( + data_file: str, + output_dir: str, + rgb_outputs: bool, + input_indices: list[int] = None, +): + os.makedirs(output_dir, exist_ok=True) + + # Load model 
--------------------------------------------------------------- + + model_obj = PrithviMAE() + datamodule = generate_datamodule() + img_size = 256 # Size of Sen1Floods11 + + # Loading data ------------------------------------------------------------- + + input_data, temporal_coords, location_coords, meta_data = load_example( + file_paths=[data_file], + indices=input_indices, + ) + + meta_data = meta_data[0] # only one image + + if input_data.mean() > 1: + input_data = input_data / 10000 # Convert to range 0-1 + + # Running model ------------------------------------------------------------ + + channels = [ + datamodule_config['bands'].index(b) for b in ["RED", "GREEN", "BLUE"] + ] # BGR -> RGB + + pred = run_model(input_data, temporal_coords, location_coords, model_obj, + datamodule, img_size) + + # Save pred + meta_data.update(count=1, dtype="uint8", compress="lzw", nodata=0) + pred_file = os.path.join( + output_dir, + f"pred_{os.path.splitext(os.path.basename(data_file))[0]}.tiff") + save_geotiff(_convert_np_uint8(pred), pred_file, meta_data) + + # Save image + pred + meta_data.update(count=3, dtype="uint8", compress="lzw", nodata=0) + + if input_data.mean() < 1: + input_data = input_data * 10000 # Scale to 0-10000 + + rgb_orig = process_channel_group( + orig_img=torch.Tensor(input_data[0, :, 0, ...]), + channels=channels, + ) + + pred[pred == 0.] 
= np.nan + img_pred = rgb_orig * 0.7 + pred * 0.3 + img_pred[img_pred.isnan()] = rgb_orig[img_pred.isnan()] + + img_pred_file = os.path.join( + output_dir, + f"rgb_pred_{os.path.splitext(os.path.basename(data_file))[0]}.tiff") + save_geotiff( + image=_convert_np_uint8(img_pred), + output_path=img_pred_file, + meta=meta_data, + ) + + # Save image rgb + if rgb_outputs: + rgb_file = os.path.join( + output_dir, "original_rgb_" + f"{os.path.splitext(os.path.basename(data_file))[0]}.tiff") + save_geotiff( + image=_convert_np_uint8(rgb_orig), + output_path=rgb_file, + meta=meta_data, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("MAE run inference", add_help=False) + + parser.add_argument( + "--data_file", + type=str, + default="./India_900498_S2Hand.tif", + help="Path to the file.", + ) + parser.add_argument( + "--output_dir", + type=str, + default="output", + help="Path to the directory where to save outputs.", + ) + parser.add_argument( + "--input_indices", + default=[1, 2, 3, 8, 11, 12], + type=int, + nargs="+", + help= + "0-based indices of the six Prithvi channels to be selected from the " + "input. By default selects [1,2,3,8,11,12] for S2L1C data.", + ) + parser.add_argument( + "--rgb_outputs", + action="store_true", + help="If present, output files will only contain RGB channels. 
" + "Otherwise, all bands will be saved.", + ) + args = parser.parse_args() + + main(**vars(args)) diff --git a/tests/models/registry.py b/tests/models/registry.py index 66b7d3c2e..7b1db5549 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -214,6 +214,10 @@ _EMBEDDING_EXAMPLE_MODELS = { "Phi3VForCausalLM": _HfExamplesInfo("TIGER-Lab/VLM2Vec-Full", trust_remote_code=True), "Qwen2VLForConditionalGeneration": _HfExamplesInfo("MrLight/dse-qwen2-2b-mrl-v1"), # noqa: E501 + # The model on Huggingface is currently being updated, + # hence I temporarily mark it as not available online + "PrithviGeoSpatialMAE": _HfExamplesInfo("ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11", # noqa: E501 + is_available_online=False), } _CROSS_ENCODER_EXAMPLE_MODELS = { diff --git a/vllm/attention/backends/placeholder_attn.py b/vllm/attention/backends/placeholder_attn.py index f363ba0c1..f1def25c8 100644 --- a/vllm/attention/backends/placeholder_attn.py +++ b/vllm/attention/backends/placeholder_attn.py @@ -320,9 +320,14 @@ class PlaceholderAttentionMetadataBuilder( -1 if cuda graph is not used. batch_size: The maybe padded batch size. """ - for inter_data in self.input_builder.inter_data_list: - self._add_seq_group(inter_data, - self.input_builder.chunked_prefill_enabled) + + # Some input builders such as ModelInputForCPUBuilder do not have the + # "inter_data_list" attribute. + # Let's check inter_data_list exists before we reference it. 
+ if hasattr(self.input_builder, "inter_data_list"): + for inter_data in self.input_builder.inter_data_list: + self._add_seq_group(inter_data, + self.input_builder.chunked_prefill_enabled) device = self.runner.device use_captured_graph = cuda_graph_pad_size != -1 diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 53f89996f..656f2f2b7 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -254,8 +254,14 @@ class InputPreprocessor: Apply the model's multi-modal processor to a multi-modal prompt, returning the corresponding token IDs and metadata. """ - tokenizer_group = self.get_tokenizer_group() - tokenizer = tokenizer_group.get_lora_tokenizer(lora_request) + # At the moment one model (PrithviGeoSpatialMAE) requires to be + # initialized without a tokenizer while also using multi-modal + # input. + if not self.tokenizer: + tokenizer = None + else: + tokenizer_group = self.get_tokenizer_group() + tokenizer = tokenizer_group.get_lora_tokenizer(lora_request) mm_processor = self.mm_registry.create_processor( self.model_config, tokenizer) @@ -273,9 +279,15 @@ class InputPreprocessor: lora_request: Optional[LoRARequest], ) -> MultiModalInputs: """Async version of :meth:`_process_multimodal`.""" - tokenizer_group = self.get_tokenizer_group() - tokenizer = await tokenizer_group.get_lora_tokenizer_async(lora_request - ) + # At the moment one model (PrithviGeoSpatialMAE) requires to be + # initialized without a tokenizer while also using multi-modal + # input.
+ if not self.tokenizer: + tokenizer = None + else: + tokenizer_group = self.get_tokenizer_group() + tokenizer = await tokenizer_group.get_lora_tokenizer_async( + lora_request) mm_processor = self.mm_registry.create_processor( self.model_config, tokenizer) diff --git a/vllm/model_executor/models/prithvi_geospatial_mae.py b/vllm/model_executor/models/prithvi_geospatial_mae.py new file mode 100644 index 000000000..9383cbae1 --- /dev/null +++ b/vllm/model_executor/models/prithvi_geospatial_mae.py @@ -0,0 +1,238 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Copyright 2025 The vLLM team. +# Copyright 2025 IBM. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only IBM/NASA Prithvi Geospatial model.""" +from typing import Iterable, List, Mapping, Optional, Set, Tuple, Union + +import torch +import torch.nn as nn +from transformers import BatchFeature + +from vllm.attention import AttentionMetadata +from vllm.config import VllmConfig +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.interfaces import (IsAttentionFree, + SupportsMultiModal) +from vllm.model_executor.models.utils import AutoWeightsLoader +from vllm.model_executor.pooling_metadata import PoolingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalInputs, MultiModalKwargs) +from vllm.multimodal.parse import MultiModalDataItems +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement) +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.sequence import (IntermediateTensors, PoolerOutput, + PoolingSequenceGroupOutput) + + +class PrithviGeoSpatialMAEProcessingInfo(BaseProcessingInfo): + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None} + + def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: + pass + + +class PrithviGeoSpatialMAEInputBuilder( + BaseDummyInputsBuilder[PrithviGeoSpatialMAEProcessingInfo]): + + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + return ProcessorInputs( + prompt_text="", + # This model input is fixed and is in the form of a torch Tensor. + # The size of pixel_values might change in the cases where we resize + # the input but never exceeds the dimensions below. 
+ mm_data={ + "pixel_values": torch.full((1, 6, 512, 512), 1.0), + "location_coords": torch.full((1, 2), 1.0) + }) + + +class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor): + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + pixel_values=MultiModalFieldConfig.batched("image"), + location_coords=MultiModalFieldConfig.batched("image"), + ) + + def _get_prompt_replacements( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> list[PromptReplacement]: + pass + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + pass + + def apply( + self, + prompt: Union[str, list[int]], + mm_data: MultiModalDataDict, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> MultiModalInputs: + mm_kwargs = {} + + for k, v in mm_data.items(): + mm_kwargs[k] = v + + return MultiModalInputs( + type="multimodal", + prompt=prompt, + prompt_token_ids=[1], + mm_kwargs=MultiModalKwargs(mm_kwargs), + mm_placeholders={}, + ) + + +@MULTIMODAL_REGISTRY.register_processor( + PrithviGeoSpatialMAEMultiModalProcessor, + info=PrithviGeoSpatialMAEProcessingInfo, + dummy_inputs=PrithviGeoSpatialMAEInputBuilder) +class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal): + """ Prithvi Masked Autoencoder""" + + def _instantiate_model(self, config: dict) -> nn.Module | None: + + # We might be able/need to support different tasks with this same model + if config["task_args"]["task"] == "SemanticSegmentationTask": + from terratorch.cli_tools import SemanticSegmentationTask + task = SemanticSegmentationTask( + config["model_args"], + config["task_args"]["model_factory"], + loss=config["task_args"]["loss"], + lr=config["task_args"]["lr"], + 
ignore_index=config["task_args"]["ignore_index"], + optimizer=config["task_args"]["optimizer"], + optimizer_hparams=config["optimizer_params"], + scheduler=config["task_args"]["scheduler"], + scheduler_hparams=config["scheduler_params"], + plot_on_val=config["task_args"]["plot_on_val"], + freeze_decoder=config["task_args"]["freeze_decoder"], + freeze_backbone=config["task_args"]["freeze_backbone"]) + + return task.model + else: + return None + + def __init__(self, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + # the actual model is dynamically instantiated using terratorch + # allowing us to perform changes to the model architecture + # at startup time (e.g., change the model decoder class.) + self.model = self._instantiate_model( + vllm_config.model_config.hf_config.to_dict()["pretrained_cfg"]) + if self.model is None: + raise ValueError( + "Unsupported task." + "Only SemanticSegmentationTask is supported for now" + "by PrithviGeospatialMAE.") + + def _parse_and_validate_multimodal_data( + self, **kwargs) -> Tuple[torch.Tensor, torch.Tensor | None]: + + pixel_values = kwargs.pop("pixel_values", None) + if not isinstance(pixel_values, torch.Tensor): + raise ValueError(f"Incorrect type of pixel_values. " + f"Got type: {type(pixel_values)}") + pixel_values = torch.unbind(pixel_values, dim=0)[0] + + location_coords = kwargs.pop("location_coords", None) + if not isinstance(location_coords, torch.Tensor): + raise ValueError(f"Incorrect type of location_coords. 
" + f"Got type: {type(location_coords)}") + location_coords = torch.unbind(location_coords, dim=0)[0] + if location_coords.shape == torch.Size([0]): + location_coords = None + + return pixel_values, location_coords + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ): + + pixel_values, location_coords = ( + self._parse_and_validate_multimodal_data(**kwargs)) + model_output = self.model(pixel_values, + location_coords=location_coords) + + return model_output.output + + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + return PoolerOutput([PoolingSequenceGroupOutput(hidden_states)]) + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + params_list = [] + model_buffers = dict(self.named_buffers()) + loaded_buffers = [] + for key, value in weights: + if key == "state_dict": + weights_to_parse = value + for name, weight in weights_to_parse.items(): + if "pos_embed" in name: + continue + + if "_timm_module." in name: + name = name.replace("_timm_module.", "") + + # this model requires a couple of buffers to be loaded + # that are not loadable with the AutoWeightsLoader + if name in model_buffers: + if "_timm_module." 
in name: + name = name.replace("_timm_module.", "") + buffer = model_buffers[name] + weight_loader = getattr(buffer, "weight_loader", + default_weight_loader) + weight_loader(buffer, weight) + loaded_buffers.append(name) + else: + params_list.append((name, weight)) + break + + # Load the remaining model parameters + loader = AutoWeightsLoader(self) + autoloaded_weights = loader.load_weights(params_list) + + return autoloaded_weights.union(set(loaded_buffers)) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index c2d0fae70..ebf6a88f2 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -137,6 +137,10 @@ _EMBEDDING_MODELS = { "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501 # [Auto-converted (see adapters.py)] "Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForCausalLM"), + # Technically PrithviGeoSpatialMAE is a model that works on images, both in + # input and output. I am adding it here because it piggy-backs on embedding + # models for the time being. + "PrithviGeoSpatialMAE": ("prithvi_geospatial_mae", "PrithviGeoSpatialMAE"), } _CROSS_ENCODER_MODELS = { diff --git a/vllm/worker/pooling_model_runner.py b/vllm/worker/pooling_model_runner.py index f43085b0e..4cbe5db44 100644 --- a/vllm/worker/pooling_model_runner.py +++ b/vllm/worker/pooling_model_runner.py @@ -74,7 +74,16 @@ class PoolingModelRunner( prefill_meta = model_input.attn_metadata.prefill_metadata decode_meta = model_input.attn_metadata.decode_metadata virtual_engine = model_input.virtual_engine - if prefill_meta is None and decode_meta.use_cuda_graph: + # Pooling models are (ab-)used also to integrate non text models that + # are not autoregressive (PrithviGeosaptialMAE). + # These model might not use attention and do not really have a prefill + # and decode phase. 
The model input is processed in one shot and both + # decode_metadata and prefill_metadata would be None for such models. + # See the PlaceholderAttentionMetadata class. + # TODO: Figure out if cuda_graph is of any use for these models and + # explore how to leverage it. + if (prefill_meta is None and decode_meta is not None + and decode_meta.use_cuda_graph): assert model_input.input_tokens is not None graph_batch_size = model_input.input_tokens.shape[0] model_executable = self.graph_runners[virtual_engine][ -- GitLab From 842b0fd402574f49f8828fcff1b8dacc3bcab5fa Mon Sep 17 00:00:00 2001 From: "Kevin H. Luu" Date: Tue, 11 Feb 2025 20:38:10 -0800 Subject: [PATCH 085/253] [ci] Add more source file dependencies for some tests (#13123) Signed-off-by: <> Co-authored-by: EC2 Default User --- .buildkite/test-pipeline.yaml | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 948eab97f..e26b1bf38 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -107,6 +107,10 @@ steps: mirror_hardwares: [amd] source_file_dependencies: - vllm/ + - tests/entrypoints/llm + - tests/entrypoints/openai + - tests/entrypoints/test_chat_utils + - tests/entrypoints/offline_mode commands: - pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_generate_multiple_loras.py --ignore=entrypoints/llm/test_guided_generate.py --ignore=entrypoints/llm/test_collective_rpc.py - pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process @@ -124,9 +128,10 @@ steps: source_file_dependencies: - vllm/distributed/ - vllm/core/ - - tests/distributed + - tests/distributed/test_utils + - tests/distributed/test_pynccl - tests/spec_decode/e2e/test_integration_dist_tp4 - - tests/compile + - tests/compile/test_basic_correctness - examples/offline_inference/rlhf.py - 
examples/offline_inference/rlhf_colocate.py commands: @@ -174,6 +179,9 @@ steps: - vllm/ - tests/engine - tests/tokenization + - tests/test_sequence + - tests/test_config + - tests/test_logger commands: - pytest -v -s engine test_sequence.py test_config.py test_logger.py # OOM in the CI unless we run this separately -- GitLab From e92694b6fe264a85371317295bca6643508034ef Mon Sep 17 00:00:00 2001 From: Lingfan Yu Date: Tue, 11 Feb 2025 21:12:37 -0800 Subject: [PATCH 086/253] [Neuron][Kernel] Support Longer Sequences in NKI-based Flash PagedAttention and Improve Efficiency (#12921) Signed-off-by: Lingfan Yu --- tests/neuron/test_prefix_prefill.py | 118 ++++++++------- vllm/attention/ops/nki_flash_attn.py | 216 +++++++++++---------------- 2 files changed, 154 insertions(+), 180 deletions(-) diff --git a/tests/neuron/test_prefix_prefill.py b/tests/neuron/test_prefix_prefill.py index dfbcfc15e..04d1bd3f0 100644 --- a/tests/neuron/test_prefix_prefill.py +++ b/tests/neuron/test_prefix_prefill.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -import random from typing import Optional import pytest @@ -171,12 +170,22 @@ def ref_context_attention( return output +@pytest.mark.parametrize( + "block_size, large_tile_size", + [ + (32, 2048), # 64 blocks + (32, 4096), # 128 blocks + (32, 8192), # 256 blocks + (64, 8192), # 128 blocks + ], +) @pytest.mark.parametrize( "num_heads,num_queries_per_kv,head_size,mixed_precision", [ (4, 2, 8, False), (4, 2, 8, True), (32, 8, 64, True), + (16, 2, 128, True), ], ) @torch.inference_mode() @@ -184,6 +193,8 @@ def test_contexted_kv_attention( num_heads: int, num_queries_per_kv: int, head_size: int, + block_size: int, + large_tile_size, mixed_precision: bool, ) -> None: import os @@ -192,40 +203,46 @@ def test_contexted_kv_attention( from vllm.attention.ops.nki_flash_attn import flash_attn_varlen_nkifunc + assert large_tile_size % block_size == 0 + device = xm.xla_device() - os.environ["NEURON_CC_FLAGS"] = ( - " 
--model-type=transformer -O1 " - " --internal-hlo2tensorizer-options='--verify-hlo' ") + compiler_flags = [ + "--model-type=transformer -O1", + "--internal-hlo2tensorizer-options='--verify-hlo'", + "--retry_failed_compilation", + ] + compiler_flags_str = " ".join(compiler_flags) + os.environ["NEURON_CC_FLAGS"] = compiler_flags_str - random.seed(0) torch.manual_seed(0) torch.set_printoptions(sci_mode=False) - min_ctx_len = 2 - max_ctx_len = 64 - min_query_len = 2 - max_query_len = 64 - prefill_batch_size = 2 - decode_batch_size = 6 + min_ctx_len = 32 + max_ctx_len = 1024 + min_query_len = 16 + max_query_len = 512 + prefill_batch_size = 4 + decode_batch_size = 12 batch_size = prefill_batch_size + decode_batch_size - block_size = 32 max_model_len = (max_query_len + max_ctx_len) * 4 max_block_per_request = max_model_len // block_size dtype = torch.float32 cache_size = (batch_size * max_block_per_request) + 2 - ctx_lens = [ - random.randint(min_ctx_len, max_ctx_len) - for _ in range(prefill_batch_size) - ] + [ - random.randint(min_ctx_len, max_ctx_len) - for _ in range(decode_batch_size) - ] - query_lens = [ - random.randint(min_query_len, max_query_len) - for _ in range(prefill_batch_size) - ] + [1 for _ in range(decode_batch_size)] + prefill_ctx_lens = torch.randint(min_ctx_len, + max_ctx_len + 1, (prefill_batch_size, ), + dtype=torch.long).tolist() + decode_ctx_lens = torch.randint(min_ctx_len, + max_ctx_len + 1, (decode_batch_size, ), + dtype=torch.long).tolist() + ctx_lens = prefill_ctx_lens + decode_ctx_lens + query_lens = torch.randint( + min_query_len, + max_query_len + 1, + (prefill_batch_size, ), + dtype=torch.long, + ).tolist() + [1 for _ in range(decode_batch_size)] seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)] num_kv_heads = num_heads // num_queries_per_kv @@ -254,7 +271,6 @@ def test_contexted_kv_attention( values = values[torch.randperm(cache_size)] block_table = values[:batch_size * max_block_per_request].view( batch_size, 
max_block_per_request) - torch.tensor(seq_lens, dtype=torch.long) b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long) b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1], dtype=torch.long), @@ -311,9 +327,7 @@ def test_contexted_kv_attention( # build neuron program return_debug_tensors = False B_P_SIZE = 128 - LARGE_TILE_SZ = 2048 - max_num_queries = ( - (sum(query_lens) + block_size - 1) // block_size) * block_size + LARGE_TILE_SZ = large_tile_size def get_active_block_tables(block_tables, query_lens, seq_lens, block_size, num_blocks): @@ -332,26 +346,28 @@ def test_contexted_kv_attention( 0, ) - def shift_bit_length(x): - return 1 << (x - 1).bit_length() + def ceil_div(a, b): + return (a + b - 1) // b + + def pad_to_multiple(a, b): + return ceil_div(a, b) * b + + def pad_to_next_power_of_2(a): + assert a > 0 + return 2**int(a - 1).bit_length() # calculate input shapes - max_num_queries_shifted = shift_bit_length(max_num_queries) - max_num_queries_factor = B_P_SIZE // max_num_queries_shifted - max_num_queries_padded = max_num_queries_shifted * max_num_queries_factor - assert (max_num_queries_padded == B_P_SIZE - ), "invalid {max_num_queries_padded=}" + max_num_queries = pad_to_multiple(sum(query_lens), block_size) + max_num_queries = pad_to_next_power_of_2(max_num_queries) head_size_padded = B_P_SIZE + assert head_size_padded >= head_size context_lens = torch.tensor(seq_lens) - torch.tensor(query_lens) - num_active_blocks_shifted = shift_bit_length( - ((context_lens + block_size - 1) // block_size).sum().item()) - num_active_blocks_factor = (LARGE_TILE_SZ // block_size // - num_active_blocks_shifted) - num_active_blocks = num_active_blocks_shifted * num_active_blocks_factor - assert (num_active_blocks * - block_size) == LARGE_TILE_SZ, "invalid {num_active_blocks=}" + num_active_blocks = ceil_div(context_lens, block_size).sum().item() + num_active_blocks = pad_to_multiple(num_active_blocks, + LARGE_TILE_SZ // block_size) context_kv_len = 
num_active_blocks * block_size - assert context_kv_len == LARGE_TILE_SZ, f"invalid {context_kv_len=}" + assert (context_kv_len % + LARGE_TILE_SZ == 0), f"invalid context_kv_len={context_kv_len}" # pad QKV tensors pad_dims = ( @@ -360,7 +376,7 @@ def test_contexted_kv_attention( 0, 0, 0, - max_num_queries_padded - query.shape[0], + max_num_queries - query.shape[0], ) query = F.pad(query, pad_dims, "constant", 0) k = F.pad(k, pad_dims, "constant", 0) @@ -397,7 +413,7 @@ def test_contexted_kv_attention( 0, context_kv_len - prior_mask.shape[1], 0, - B_P_SIZE - prior_mask.shape[0], + max_num_queries - prior_mask.shape[0], ), "constant", 0, @@ -406,9 +422,9 @@ def test_contexted_kv_attention( active_mask, ( 0, - B_P_SIZE - active_mask.shape[1], + max_num_queries - active_mask.shape[1], 0, - B_P_SIZE - active_mask.shape[0], + max_num_queries - active_mask.shape[0], ), "constant", 0, @@ -430,6 +446,8 @@ def test_contexted_kv_attention( n_kv_head=num_kv_heads, head_size=head_size, mixed_precision=mixed_precision, + LARGE_TILE_SZ=LARGE_TILE_SZ, + return_debug_tensors=return_debug_tensors, ) if return_debug_tensors: @@ -439,17 +457,15 @@ def test_contexted_kv_attention( output_nki = flash_attn_varlen_nkifunc(*input_args, **input_kwargs) debug_tensors = [] - output_nki = torch.tensor(output_nki).cpu() debug_tensors = [torch.tensor(dt).cpu() for dt in debug_tensors] num_actual_tokens = sum(query_lens) - print(f"{num_actual_tokens=}") # - o: shape (bs, n_heads, seq_q, d) -> (bs, seq_q, n_heads, d) - output_nki = output_nki.permute( - 0, 2, 1, 3)[:, :, :, :head_size].cpu()[0, :num_actual_tokens, :, :] + output_nki = output_nki.cpu().permute(0, 2, 1, 3)[:, :, :, :head_size] + output_nki = output_nki[0, :num_actual_tokens, :, :] output_ref_padded = F.pad( output_ref, - (0, 0, 0, 0, 0, 0, 0, max_num_queries_padded - output_ref.shape[0]), + (0, 0, 0, 0, 0, 0, 0, max_num_queries - output_ref.shape[0]), "constant", 0, ) diff --git a/vllm/attention/ops/nki_flash_attn.py 
b/vllm/attention/ops/nki_flash_attn.py index 68aa63f5a..5e2a1f7e6 100644 --- a/vllm/attention/ops/nki_flash_attn.py +++ b/vllm/attention/ops/nki_flash_attn.py @@ -28,7 +28,6 @@ class FlashConfig: def transpose_p_local(p_local_transposed, p_local, LARGE_TILE_SZ, - forward_mask, B_F_SIZE=512): for i in nl.affine_range(LARGE_TILE_SZ // B_F_SIZE): if nisa.get_nc_version() == nisa.nc_version.gen3: @@ -46,13 +45,13 @@ def transpose_p_local(p_local_transposed, if nisa.get_nc_version() == nisa.nc_version.gen3: p_local_t_tmp[:, j_128_slice] = nisa.dma_transpose( - p_local[:, i_j_128_slice], mask=forward_mask) + p_local[:, i_j_128_slice]) else: p_local_t_tmp[:, j_128_slice] = nisa.nc_transpose( - p_local[:, i_j_128_slice], mask=forward_mask) + p_local[:, i_j_128_slice]) p_local_transposed[:, nl.ds(i * B_F_SIZE, B_F_SIZE)] = nl.copy( - p_local_t_tmp, dtype=p_local_transposed.dtype, mask=forward_mask) + p_local_t_tmp, dtype=p_local_transposed.dtype) @nki.jit @@ -60,36 +59,25 @@ def _flash_attention_core( q_local_tile, k, v, - q_h_per_k_h, - seqlen_q, - nheads, o_buffer, l_buffer, m_buffer, - batch_id, - head_id, - gqa_head_idx, q_tile_idx, - local_k_large_tile_idx, kernel_dtype, acc_type, flash_config: FlashConfig, - use_causal_mask=False, - continuous_batching_mask=None, + use_causal_mask, + tile_mask, initialize=False, B_P_SIZE=128, B_F_SIZE=512, B_D_SIZE=128, - dropout_p=0.0, - dropout_p_tensor=None, - seed_tensor=None, - logit_bias_tile=None, qk_res_buffer=None, ): """ The flash attention core function to calculate self attention between a tile of q and a block of K and V. - The q_local_tile has (B_P_SIZE, B_F_SIZE), which is loaded into the SBUF + The q_local_tile has (B_P_SIZE, B_F_SIZE), which is loaded into the SBUF already. The block size of K and V is defined in the seq_tile_size of the flash_config. 
The results are stored in the following three buffers @@ -99,24 +87,9 @@ def _flash_attention_core( """ LARGE_TILE_SZ = flash_config.seq_tile_size num_k_tile_per_large_tile = LARGE_TILE_SZ // B_F_SIZE - seqlen_k = k.shape[-1] - seqlen_q // B_P_SIZE - seqlen_k // B_F_SIZE - - # TODO : support logit_bias with continuous_batching_mask - assert not use_causal_mask, "causal mask is not supported." - assert (continuous_batching_mask - is not None), "continuous_batching_mask input is required." - if continuous_batching_mask is not None: - assert ( - logit_bias_tile - is None), "continuous_batching_mask does not support logit_bias!" # mask are used to only apply computation to the lower half of the matrix, # which reduce the arithmetic intensity by half - forward_mask = (q_tile_idx * B_P_SIZE >= local_k_large_tile_idx * - LARGE_TILE_SZ if use_causal_mask else None) - qk_res_buf = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ), buffer=nl.sbuf, dtype=acc_type) @@ -125,20 +98,27 @@ def _flash_attention_core( for k_i in nl.affine_range(num_k_tile_per_large_tile): k_i_b_f_slice = nl.ds(k_i * B_F_SIZE, B_F_SIZE) - qk_psum = nl.zeros((par_dim(B_P_SIZE), B_F_SIZE), - dtype=np.float32, - buffer=nl.psum) # (128, 512) - qk_psum[:, :] = nl.matmul(q_local_tile, - k[:, k_i_b_f_slice], - transpose_x=True, - mask=None) # (p(128), 512) - - qk_res_buf[:, k_i_b_f_slice] = nl.where( - continuous_batching_mask[:, k_i_b_f_slice], - qk_psum[:, nl.ds(0, B_F_SIZE)], - -9984.0, - dtype=acc_type, - ) + if use_causal_mask: + multiplication_required_selection = (q_tile_idx * B_P_SIZE + >= k_i * B_F_SIZE) + else: + multiplication_required_selection = True + + if multiplication_required_selection: + qk_psum = nl.ndarray((par_dim(B_P_SIZE), B_F_SIZE), + dtype=np.float32, + buffer=nl.psum) # (128, 512) + qk_psum[:, :] = nl.matmul(q_local_tile, + k[:, k_i_b_f_slice], + transpose_x=True) # (p(128), 512) + qk_res_buf[:, k_i_b_f_slice] = nl.where( + tile_mask[:, k_i_b_f_slice], + qk_psum[:, nl.ds(0, 
B_F_SIZE)], + -9984.0, + dtype=acc_type, + ) + else: + qk_res_buf[:, k_i_b_f_slice] = -9984.0 # Calculate max of the current tile max_local[:, k_i] = nisa.tensor_reduce( @@ -147,7 +127,6 @@ def _flash_attention_core( axis=(1, ), dtype=acc_type, negate=False, - mask=forward_mask, ) if qk_res_buffer is not None: @@ -159,7 +138,6 @@ def _flash_attention_core( axis=(1, ), dtype=acc_type, negate=False, - mask=forward_mask, ) o_previous_scaled = nl.ndarray((par_dim(B_P_SIZE), B_D_SIZE), @@ -170,8 +148,7 @@ def _flash_attention_core( m_current = max_ else: m_previous = nl.copy(m_buffer[:, 0]) - m_buffer[:, 0] = nl.maximum(m_previous, max_, - mask=forward_mask) # (128,1) + m_buffer[:, 0] = nl.maximum(m_previous, max_) # (128,1) m_current = m_buffer[:, 0] # Compute scaling factor @@ -180,11 +157,8 @@ def _flash_attention_core( m_previous, bias=-1 * m_current, scale=1.0, - mask=forward_mask, ) - o_previous_scaled[...] = nl.multiply(o_buffer[:, :], - alpha, - mask=forward_mask) + o_previous_scaled[...] 
= nl.multiply(o_buffer[:, :], alpha) p_local = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ), dtype=kernel_dtype) @@ -207,10 +181,9 @@ def _flash_attention_core( reduce_op=nl.add, reduce_res=p_partial_sum[:, k_r_i], dtype=kernel_dtype, - mask=forward_mask, ) - ps = nl.sum(p_partial_sum, axis=1, dtype=acc_type, mask=forward_mask) + ps = nl.sum(p_partial_sum, axis=1, dtype=acc_type) p_local_transposed = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ), dtype=kernel_dtype) @@ -218,7 +191,6 @@ def _flash_attention_core( p_local_transposed=p_local_transposed, p_local=p_local, LARGE_TILE_SZ=LARGE_TILE_SZ, - forward_mask=forward_mask, B_F_SIZE=B_F_SIZE, ) @@ -230,27 +202,20 @@ def _flash_attention_core( p_local_transposed[:, nl.ds(k_i * B_P_SIZE, B_P_SIZE)], v[k_i, :, :], transpose_x=True, - mask=forward_mask, ) # (128, 128) (p(Br), d) if initialize: o_buffer[:, :] = nl.copy(pv_psum[:, :]) l_buffer[:, 0] = nl.add(nl.log(ps), max_) else: - o_buffer[:, :] = nl.add(o_previous_scaled, pv_psum, mask=forward_mask) + o_buffer[:, :] = nl.add(o_previous_scaled, pv_psum) l_prev = l_buffer[:, 0] l_exp = nl.add( - nl.exp( - nl.subtract(l_prev, m_current, mask=forward_mask), - mask=forward_mask, - ), + nl.exp(nl.subtract(l_prev, m_current)), ps, - mask=forward_mask, ) - l_buffer[:, 0] = nl.add(m_current, - nl.log(l_exp, mask=forward_mask), - mask=forward_mask) + l_buffer[:, 0] = nl.add(m_current, nl.log(l_exp)) @nki.jit @@ -279,6 +244,21 @@ def load_v_tile(v_hbm_tile, cur_v_tile, j, v_i, config): ) +@nki.jit +def load_block_tables(block_tables_hbm, num_tiles): + (num_blocks, ) = block_tables_hbm.shape + assert num_blocks % num_tiles == 0 + num_blocks_per_tile = num_blocks // num_tiles + block_tables_hbm = block_tables_hbm.reshape( + (num_tiles, num_blocks_per_tile)) + block_tables_buffer = nl.load(block_tables_hbm, dtype=nl.int32) + return block_tables_buffer + + +def is_power_of_2(x): + return x > 0 and (x & (x - 1)) == 0 + + @nki.jit def flash_paged_attention( query, @@ -316,24 +296,24 
@@ def flash_paged_attention( - We use paged cache blocks (key_cache, value_cache) to store KV cache. IO tensor dtypes: - - This kernel assumes all IO tensors have the same dtype except for + - This kernel assumes all IO tensors have the same dtype except for block_tables (int32) and mask (int32) - - If mixed_percision is True, then all Tensor Engine operation will be - performed in bfloat16 and accumulation will be performed in float32. + - If mixed_percision is True, then all Tensor Engine operation will be + performed in bfloat16 and accumulation will be performed in float32. Otherwise the intermediates will be in the same type as the inputs. Compile-time Constants: - softmax_scale: scaling for softmax, is None, default is `1.0/(d**0.5)` - mixed_precision: flag to set non-matmul ops in fp32 precision, default - is set to `true`, if false, we use same precision as input types + is set to `true`, if false, we use same precision as input types - config: Instance of dataclass :class:`nki.kernels.attention.FlashConfig` with Performance config parameters for flash attention with default values - seq_tile_size: `default=2048`, size of the kv tile size for attention + seq_tile_size: `default=2048`, size of the kv tile size for attention computation reduction GQA support Notes: - the spmd kernel for launching kernel should be on kv_heads instead of + the spmd kernel for launching kernel should be on kv_heads instead of nheads Example usage: @@ -415,18 +395,13 @@ def flash_paged_attention( ), f"Need B_P_SIZE ({B_P_SIZE}) to be divisible by {block_size=}" num_large_k_tile = context_kv_len // LARGE_TILE_SZ num_blocks_per_large_tile = LARGE_TILE_SZ // block_size - assert (num_blocks_per_large_tile <= B_P_SIZE - ), f"The number of blocks in each large tile " \ - f"({num_blocks_per_large_tile}) shouldn't exceed partition size {B_P_SIZE}" - - block_tables_sbuf = nl.full((par_dim(B_P_SIZE), num_large_k_tile), - 0, - dtype=np.int32, - buffer=nl.sbuf) - for j in 
nl.affine_range(num_large_k_tile): - i_p = nl.arange(num_blocks_per_large_tile)[:, None] - block_tables_sbuf[i_p, j] = nl.load( - block_tables[j * num_blocks_per_large_tile + i_p], dtype=np.int32) + assert block_size % 32 == 0, "block_size is expected to be a multiple of 32" + assert is_power_of_2( + num_blocks_per_large_tile + ), "The number of blocks in each large tile is expected of be power of 2" + assert is_power_of_2(seqlen_q), "seqlen_q is expected to be power of 2" + + block_tables_sbuf = load_block_tables(block_tables, num_large_k_tile) # Global Flash Attention accumulators o_buffer = nl.zeros( @@ -457,7 +432,7 @@ def flash_paged_attention( ) for k_i in nl.affine_range(num_blocks_per_large_tile): - loaded = nl.load(key_cache[block_tables_sbuf[k_i, j], :, + loaded = nl.load(key_cache[block_tables_sbuf[j, k_i], :, head_id, :]) cur_k_tile[:, nl.ds(k_i * block_size, block_size)] = nl.transpose(loaded) @@ -469,7 +444,7 @@ def flash_paged_attention( num_blocks_per_partition): v_i = (partition_idx * num_blocks_per_partition + block_in_partition) - loaded_v = nl.load(value_cache[block_tables_sbuf[v_i, j], :, + loaded_v = nl.load(value_cache[block_tables_sbuf[j, v_i], :, head_id, :]) cur_v_tile[ partition_idx, @@ -477,14 +452,15 @@ def flash_paged_attention( :, ] = loaded_v - cur_mask = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ), - dtype=mask.dtype) - for m_i in nl.affine_range(LARGE_TILE_SZ // B_F_SIZE): - cur_mask[:, nl.ds(m_i * B_F_SIZE, B_F_SIZE)] = nl.load( - mask[:, nl.ds(j * LARGE_TILE_SZ + m_i * B_F_SIZE, B_F_SIZE)]) - - for i_q_h in nl.affine_range(q_h_per_k_h): - for i in nl.affine_range(n_tile_q): + for i in nl.affine_range(n_tile_q): + cur_mask = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ), + dtype=mask.dtype) + for m_i in nl.affine_range(LARGE_TILE_SZ // B_F_SIZE): + cur_mask[:, nl.ds(m_i * B_F_SIZE, B_F_SIZE)] = nl.load(mask[ + nl.ds(i * B_P_SIZE, B_P_SIZE), + nl.ds(j * LARGE_TILE_SZ + m_i * B_F_SIZE, B_F_SIZE), + ]) + for i_q_h in 
nl.affine_range(q_h_per_k_h): q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype) q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h] q_sbuf_tile = nl.load( @@ -497,35 +473,24 @@ def flash_paged_attention( q_local_tile=q_tile, k=cur_k_tile, v=cur_v_tile, - q_h_per_k_h=q_h_per_k_h, - seqlen_q=seqlen_q, - nheads=h, o_buffer=o_buffer[i, i_q_h], l_buffer=l_buffer[:, i, i_q_h], m_buffer=m_buffer[i, i_q_h], - batch_id=batch_id, - head_id=head_id, - gqa_head_idx=i_q_h, q_tile_idx=i, - local_k_large_tile_idx=j, kernel_dtype=kernel_dtype, acc_type=acc_type, flash_config=config, use_causal_mask=False, - continuous_batching_mask=cur_mask, + tile_mask=cur_mask, initialize=j == 0, B_P_SIZE=B_P_SIZE, B_F_SIZE=B_F_SIZE, B_D_SIZE=B_D_SIZE, - dropout_p=0.0, - dropout_p_tensor=None, - seed_tensor=None, - logit_bias_tile=None, ) # compute attention between input query, key and value if key is not None and value is not None: - B_F_SIZE = seqlen_q + B_F_SIZE = min(seqlen_q, B_F_SIZE) LARGE_TILE_SZ = seqlen_q active_config = FlashConfig( seq_tile_size=LARGE_TILE_SZ, @@ -552,11 +517,16 @@ def flash_paged_attention( config=active_config, ) - cur_mask = nl.ndarray((par_dim(B_P_SIZE), B_F_SIZE), dtype=mask.dtype) - cur_mask[:, :] = nl.load(mask[:, nl.ds(context_kv_len, B_F_SIZE)]) + for i in nl.affine_range(n_tile_q): + cur_mask = nl.load( + mask[ + nl.ds(i * B_P_SIZE, B_P_SIZE), + nl.ds(context_kv_len, LARGE_TILE_SZ), + ], + dtype=mask.dtype, + ) + for i_q_h in nl.affine_range(q_h_per_k_h): - for i_q_h in nl.affine_range(q_h_per_k_h): - for i in nl.affine_range(n_tile_q): q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype) q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h] q_sbuf_tile = nl.load( @@ -568,32 +538,21 @@ def flash_paged_attention( q_local_tile=q_tile, k=cur_k_tile, v=cur_v_tile, - q_h_per_k_h=q_h_per_k_h, - seqlen_q=seqlen_q, - nheads=h, o_buffer=o_buffer[i, i_q_h], l_buffer=l_buffer[:, i, i_q_h], m_buffer=m_buffer[i, i_q_h], - 
batch_id=batch_id, - head_id=head_id, - gqa_head_idx=i_q_h, q_tile_idx=i, - local_k_large_tile_idx=0, kernel_dtype=kernel_dtype, acc_type=acc_type, flash_config=active_config, - use_causal_mask=False, - continuous_batching_mask=cur_mask, + use_causal_mask=True, + tile_mask=cur_mask, initialize=False, B_P_SIZE=B_P_SIZE, B_F_SIZE=B_F_SIZE, B_D_SIZE=B_D_SIZE, - dropout_p=0.0, - dropout_p_tensor=None, - seed_tensor=None, - logit_bias_tile=None, - qk_res_buffer=qk_res_buffer[i, i_q_h] - if qk_res_buffer is not None else None, + qk_res_buffer=(qk_res_buffer[i, i_q_h] + if qk_res_buffer is not None else None), ) # -- -- -- -- write output to buffer on HBM -- -- -- -- -- -- # @@ -652,7 +611,6 @@ def flash_attn_varlen_nkifunc( attn_mask, n_kv_head=None, head_size=None, - B_P_SIZE=128, LARGE_TILE_SZ=2048, return_debug_tensors=False, mixed_precision=True, -- GitLab From a0597c6b7534c383accab86f9967176a7ece4aae Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 12 Feb 2025 00:40:19 -0800 Subject: [PATCH 087/253] Bump helm/kind-action from 1.10.0 to 1.12.0 (#11612) --- .github/workflows/lint-and-deploy.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/lint-and-deploy.yaml b/.github/workflows/lint-and-deploy.yaml index 556b60d2f..9d2e54ce9 100644 --- a/.github/workflows/lint-and-deploy.yaml +++ b/.github/workflows/lint-and-deploy.yaml @@ -47,7 +47,7 @@ jobs: aws --endpoint-url http://127.0.0.1:9000/ s3 cp opt-125m/ s3://testbucket/opt-125m --recursive - name: Create kind cluster - uses: helm/kind-action@0025e74a8c7512023d06dc019c617aa3cf561fde # v1.10.0 + uses: helm/kind-action@a1b0e391336a6ee6713a0583f8c6240d70863de3 # v1.12.0 - name: Build the Docker image vllm cpu run: docker buildx build -f Dockerfile.cpu -t vllm-cpu-env . 
-- GitLab From dd3b4a01f84131c1c3d94d0bf0cbf95a98eec586 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 12 Feb 2025 00:40:25 -0800 Subject: [PATCH 088/253] Bump actions/stale from 9.0.0 to 9.1.0 (#12462) --- .github/workflows/stale.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index 81e7c9b05..656f3d3fa 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -13,7 +13,7 @@ jobs: actions: write runs-on: ubuntu-latest steps: - - uses: actions/stale@28ca1036281a5e5922ead5184a1bbf96e5fc984e # v9.0.0 + - uses: actions/stale@5bef64f19d7facfb25b37b414482c7164d639639 # v9.1.0 with: # Increasing this value ensures that changes to this workflow # propagate to all issues and PRs in days rather than months -- GitLab From 0c7d9effce8121d769e932a22e5753749e826f60 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 12 Feb 2025 16:41:06 +0800 Subject: [PATCH 089/253] Bump helm/chart-testing-action from 2.6.1 to 2.7.0 (#12463) --- .github/workflows/lint-and-deploy.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/lint-and-deploy.yaml b/.github/workflows/lint-and-deploy.yaml index 9d2e54ce9..99365c67c 100644 --- a/.github/workflows/lint-and-deploy.yaml +++ b/.github/workflows/lint-and-deploy.yaml @@ -22,7 +22,7 @@ jobs: python-version: '3.13' - name: Set up chart-testing - uses: helm/chart-testing-action@e6669bcd63d7cb57cb4380c33043eebe5d111992 # v2.6.1 + uses: helm/chart-testing-action@0d28d3144d3a25ea2cc349d6e59901c4ff469b3b # v2.7.0 with: version: v3.10.1 -- GitLab From d59def47305487ca523379dd97073f5ab037d663 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 12 Feb 2025 16:41:22 +0800 Subject: [PATCH 090/253] Bump actions/setup-python from 5.3.0 to 5.4.0 
(#12672) --- .github/workflows/cleanup_pr_body.yml | 2 +- .github/workflows/lint-and-deploy.yaml | 2 +- .github/workflows/pre-commit.yml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/cleanup_pr_body.yml b/.github/workflows/cleanup_pr_body.yml index 0085a1cc2..50fea0c43 100644 --- a/.github/workflows/cleanup_pr_body.yml +++ b/.github/workflows/cleanup_pr_body.yml @@ -16,7 +16,7 @@ jobs: uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 + uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 with: python-version: '3.12' diff --git a/.github/workflows/lint-and-deploy.yaml b/.github/workflows/lint-and-deploy.yaml index 99365c67c..a4e9acc41 100644 --- a/.github/workflows/lint-and-deploy.yaml +++ b/.github/workflows/lint-and-deploy.yaml @@ -17,7 +17,7 @@ jobs: version: v3.14.4 #Python is required because ct lint runs Yamale and yamllint which require Python. 
- - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 + - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 with: python-version: '3.13' diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 06564969d..dc10b9116 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -10,7 +10,7 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 + - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 with: python-version: "3.12" - run: echo "::add-matcher::.github/workflows/matchers/actionlint.json" -- GitLab From 7c4033acd46749f9da48aa8e347648d47e2f4876 Mon Sep 17 00:00:00 2001 From: Maximilien de Bayser Date: Wed, 12 Feb 2025 07:34:09 -0300 Subject: [PATCH 091/253] Further reduce the HTTP calls to huggingface.co (#13107) --- vllm/transformers_utils/config.py | 135 +++++++++++++++++------------- 1 file changed, 79 insertions(+), 56 deletions(-) diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index aade28610..4b76509e4 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -4,12 +4,14 @@ import enum import json import os import time +from functools import cache from pathlib import Path -from typing import Any, Dict, Literal, Optional, Type, Union +from typing import Any, Callable, Dict, Literal, Optional, Type, Union import huggingface_hub -from huggingface_hub import (file_exists, hf_hub_download, list_repo_files, - try_to_load_from_cache) +from huggingface_hub import hf_hub_download +from huggingface_hub import list_repo_files as hf_list_repo_files +from huggingface_hub import try_to_load_from_cache from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError, HFValidationError, LocalEntryNotFoundError, 
RepositoryNotFoundError, @@ -86,6 +88,65 @@ class ConfigFormat(str, enum.Enum): MISTRAL = "mistral" +def with_retry(func: Callable[[], Any], + log_msg: str, + max_retries: int = 2, + retry_delay: int = 2): + for attempt in range(max_retries): + try: + return func() + except Exception as e: + if attempt == max_retries - 1: + logger.error("%s: %s", log_msg, e) + raise + logger.error("%s: %s, retrying %d of %d", log_msg, e, attempt + 1, + max_retries) + time.sleep(retry_delay) + retry_delay *= 2 + + +# @cache doesn't cache exceptions +@cache +def list_repo_files( + repo_id: str, + *, + revision: Optional[str] = None, + repo_type: Optional[str] = None, + token: Union[str, bool, None] = None, +) -> list[str]: + + def lookup_files(): + try: + return hf_list_repo_files(repo_id, + revision=revision, + repo_type=repo_type, + token=token) + except huggingface_hub.errors.OfflineModeIsEnabled: + # Don't raise in offline mode, + # all we know is that we don't have this + # file cached. + return [] + + return with_retry(lookup_files, "Error retrieving file list") + + +def file_exists( + repo_id: str, + file_name: str, + *, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + token: Union[str, bool, None] = None, +) -> bool: + + file_list = list_repo_files(repo_id, + repo_type=repo_type, + revision=revision, + token=token) + return file_name in file_list + + +# In offline mode the result can be a false negative def file_or_path_exists(model: Union[str, Path], config_name: str, revision: Optional[str]) -> bool: if Path(model).exists(): @@ -103,31 +164,10 @@ def file_or_path_exists(model: Union[str, Path], config_name: str, # hf_hub. This will fail in offline mode. 
# Call HF to check if the file exists - # 2 retries and exponential backoff - max_retries = 2 - retry_delay = 2 - for attempt in range(max_retries): - try: - return file_exists(model, - config_name, - revision=revision, - token=HF_TOKEN) - except huggingface_hub.errors.OfflineModeIsEnabled: - # Don't raise in offline mode, - # all we know is that we don't have this - # file cached. - return False - except Exception as e: - logger.error( - "Error checking file existence: %s, retrying %d of %d", e, - attempt + 1, max_retries) - if attempt == max_retries - 1: - logger.error("Error checking file existence: %s", e) - raise - time.sleep(retry_delay) - retry_delay *= 2 - continue - return False + return file_exists(str(model), + config_name, + revision=revision, + token=HF_TOKEN) def patch_rope_scaling(config: PretrainedConfig) -> None: @@ -208,32 +248,7 @@ def get_config( revision=revision): config_format = ConfigFormat.MISTRAL else: - # If we're in offline mode and found no valid config format, then - # raise an offline mode error to indicate to the user that they - # don't have files cached and may need to go online. - # This is conveniently triggered by calling file_exists(). 
- - # Call HF to check if the file exists - # 2 retries and exponential backoff - max_retries = 2 - retry_delay = 2 - for attempt in range(max_retries): - try: - file_exists(model, - HF_CONFIG_NAME, - revision=revision, - token=HF_TOKEN) - except Exception as e: - logger.error( - "Error checking file existence: %s, retrying %d of %d", - e, attempt + 1, max_retries) - if attempt == max_retries: - logger.error("Error checking file existence: %s", e) - raise e - time.sleep(retry_delay) - retry_delay *= 2 - - raise ValueError(f"No supported config format found in {model}") + raise ValueError(f"No supported config format found in {model}.") if config_format == ConfigFormat.HF: config_dict, _ = PretrainedConfig.get_config_dict( @@ -339,10 +354,11 @@ def get_hf_file_to_dict(file_name: str, file_name=file_name, revision=revision) - if file_path is None and file_or_path_exists( - model=model, config_name=file_name, revision=revision): + if file_path is None: try: hf_hub_file = hf_hub_download(model, file_name, revision=revision) + except huggingface_hub.errors.OfflineModeIsEnabled: + return None except (RepositoryNotFoundError, RevisionNotFoundError, EntryNotFoundError, LocalEntryNotFoundError) as e: logger.debug("File or repository not found in hf_hub_download", e) @@ -363,6 +379,7 @@ def get_hf_file_to_dict(file_name: str, return None +@cache def get_pooling_config(model: str, revision: Optional[str] = 'main'): """ This function gets the pooling and normalize @@ -390,6 +407,8 @@ def get_pooling_config(model: str, revision: Optional[str] = 'main'): if modules_dict is None: return None + logger.info("Found sentence-transformers modules configuration.") + pooling = next((item for item in modules_dict if item["type"] == "sentence_transformers.models.Pooling"), None) @@ -408,6 +427,7 @@ def get_pooling_config(model: str, revision: Optional[str] = 'main'): if pooling_type_name is not None: pooling_type_name = get_pooling_config_name(pooling_type_name) + logger.info("Found 
pooling configuration.") return {"pooling_type": pooling_type_name, "normalize": normalize} return None @@ -435,6 +455,7 @@ def get_pooling_config_name(pooling_name: str) -> Union[str, None]: return None +@cache def get_sentence_transformer_tokenizer_config(model: str, revision: Optional[str] = 'main' ): @@ -491,6 +512,8 @@ def get_sentence_transformer_tokenizer_config(model: str, if not encoder_dict: return None + logger.info("Found sentence-transformers tokenize configuration.") + if all(k in encoder_dict for k in ("max_seq_length", "do_lower_case")): return encoder_dict return None -- GitLab From f1042e86f05cfe93bcadac445e78671ed2e8fddb Mon Sep 17 00:00:00 2001 From: Shiyan Deng <842974287@qq.com> Date: Wed, 12 Feb 2025 02:36:10 -0800 Subject: [PATCH 092/253] [Misc] AMD Build Improvements (#12923) --- csrc/moe/moe_align_sum_kernels.cu | 2 +- csrc/rocm/attention.cu | 2 +- vllm/model_executor/models/registry.py | 15 +++++++++++---- vllm/transformers_utils/configs/__init__.py | 2 +- 4 files changed, 14 insertions(+), 7 deletions(-) diff --git a/csrc/moe/moe_align_sum_kernels.cu b/csrc/moe/moe_align_sum_kernels.cu index 01dac4044..c072744f0 100644 --- a/csrc/moe/moe_align_sum_kernels.cu +++ b/csrc/moe/moe_align_sum_kernels.cu @@ -3,7 +3,7 @@ #include #include -#include +#include #include "../cuda_compat.h" #include "../dispatch_utils.h" diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu index ffa9d4461..366b3cdc2 100644 --- a/csrc/rocm/attention.cu +++ b/csrc/rocm/attention.cu @@ -1122,4 +1122,4 @@ void paged_attention( #undef WARP_SIZE #undef MAX #undef MIN -#undef DIVIDE_ROUND_UP \ No newline at end of file +#undef DIVIDE_ROUND_UP diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index ebf6a88f2..198b6d134 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -205,6 +205,14 @@ _VLLM_MODELS = { **_FALLBACK_MODEL, } +# This variable is used as the args for 
subprocess.run(). We +# can modify this variable to alter the args if needed. e.g. +# when we use par format to pack things together, sys.executable +# might not be the target we want to run. +_SUBPROCESS_COMMAND = [ + sys.executable, "-m", "vllm.model_executor.models.registry" +] + @dataclass(frozen=True) class _ModelInfo: @@ -502,10 +510,9 @@ def _run_in_subprocess(fn: Callable[[], _T]) -> _T: # cannot use `sys.executable __file__` here because the script # contains relative imports - returned = subprocess.run( - [sys.executable, "-m", "vllm.model_executor.models.registry"], - input=input_bytes, - capture_output=True) + returned = subprocess.run(_SUBPROCESS_COMMAND, + input=input_bytes, + capture_output=True) # check if the subprocess is successful try: diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index c484a755a..906056559 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -45,4 +45,4 @@ __all__ = [ "SolarConfig", "Telechat2Config", "UltravoxConfig", -] \ No newline at end of file +] -- GitLab From f4d97e4fc276b13e1a4ec18f35239fd48695667d Mon Sep 17 00:00:00 2001 From: bnellnm <49004751+bnellnm@users.noreply.github.com> Date: Wed, 12 Feb 2025 05:39:16 -0500 Subject: [PATCH 093/253] [Bug] [V1] Try fetching stop_reason from EngineOutput before checking the request (#13108) --- vllm/v1/engine/output_processor.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 7973c62c3..1438f9d5a 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -2,7 +2,7 @@ import asyncio from dataclasses import dataclass -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union from vllm.outputs import RequestOutput from vllm.sampling_params import RequestOutputKind @@ -164,6 +164,7 @@ class 
OutputProcessor: new_token_ids = engine_core_output.new_token_ids finish_reason = engine_core_output.finish_reason + stop_reason = engine_core_output.stop_reason # TODO(andy): prompt logprobs + chunked prefill can # result in engine core returning an output for a @@ -181,9 +182,10 @@ class OutputProcessor: # 2) Detokenize the token ids into text and check for stop # strings. - stop_reason = req_state.detokenizer.update(new_token_ids) - if stop_reason: + stop_string = req_state.detokenizer.update(new_token_ids) + if stop_string and finish_reason != FinishReason.STOP: finish_reason = FinishReason.STOP + stop_reason = stop_string # 3) Compute sample and prompt logprobs for request, # if required. @@ -250,7 +252,7 @@ class OutputProcessor: request_state: RequestState, new_token_ids: List[int], finish_reason: Optional[FinishReason], - stop_reason: Optional[str], + stop_reason: Union[int, str, None], ) -> Optional[RequestOutput]: finished = finish_reason is not None -- GitLab From 985b4a2b1989b2879809d6a3b84c11ac9e1171a3 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Wed, 12 Feb 2025 19:55:23 +0800 Subject: [PATCH 094/253] [Bugfix] Fix num video tokens calculation for Qwen2-VL (#13148) Signed-off-by: DarkLight1337 --- vllm/model_executor/models/qwen2_vl.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index f2071eaff..d3294a4d4 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -800,7 +800,11 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo): preprocessed_size = ImageSize(width=image_width, height=image_height) - grid_t = max(num_frames // temporal_patch_size, 1) + # NOTE: Frames are padded to be divisible by `temporal_patch_size` + # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py#L294 + padded_num_frames = num_frames + num_frames % 
temporal_patch_size + + grid_t = max(padded_num_frames // temporal_patch_size, 1) grid_h = preprocessed_size.height // patch_size grid_w = preprocessed_size.width // patch_size -- GitLab From 314cfade02b28d50349c4df1a7ea0bbdaef589f1 Mon Sep 17 00:00:00 2001 From: Rafael Vasquez Date: Wed, 12 Feb 2025 11:29:56 -0500 Subject: [PATCH 095/253] [Frontend] Generate valid tool call IDs when using `tokenizer-mode=mistral` (#12332) --- tests/mistral_tool_use/__init__.py | 0 tests/mistral_tool_use/conftest.py | 40 +++++++++++++++++++ .../test_mistral_tool_calls.py | 29 ++++++++++++++ tests/mistral_tool_use/utils.py | 33 +++++++++++++++ vllm/entrypoints/openai/serving_chat.py | 16 +++++--- .../tool_parsers/mistral_tool_parser.py | 2 +- .../transformers_utils/tokenizers/__init__.py | 7 +++- vllm/transformers_utils/tokenizers/mistral.py | 30 ++++++++++++++ 8 files changed, 149 insertions(+), 8 deletions(-) create mode 100644 tests/mistral_tool_use/__init__.py create mode 100644 tests/mistral_tool_use/conftest.py create mode 100644 tests/mistral_tool_use/test_mistral_tool_calls.py create mode 100644 tests/mistral_tool_use/utils.py diff --git a/tests/mistral_tool_use/__init__.py b/tests/mistral_tool_use/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/mistral_tool_use/conftest.py b/tests/mistral_tool_use/conftest.py new file mode 100644 index 000000000..39ab01c9b --- /dev/null +++ b/tests/mistral_tool_use/conftest.py @@ -0,0 +1,40 @@ +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import pytest_asyncio +from huggingface_hub import snapshot_download + +from tests.utils import RemoteOpenAIServer +from vllm.platforms import current_platform + +from .utils import ARGS, CONFIGS, ServerConfig + + +# for each server config, download the model and return the config +@pytest.fixture(scope="session", params=CONFIGS.keys()) +def server_config(request): + config = CONFIGS[request.param] + + if current_platform.is_rocm() and not 
config.get("supports_rocm", True): + pytest.skip("The {} model can't be tested on the ROCm platform".format( + config["model"])) + + # download model and tokenizer using transformers + snapshot_download(config["model"]) + yield CONFIGS[request.param] + + +# run this for each server config +@pytest.fixture(scope="session") +def server(request, server_config: ServerConfig): + model = server_config["model"] + args_for_model = server_config["arguments"] + with RemoteOpenAIServer(model, ARGS + args_for_model, + max_wait_seconds=480) as server: + yield server + + +@pytest_asyncio.fixture +async def client(server: RemoteOpenAIServer): + async with server.get_async_client() as async_client: + yield async_client diff --git a/tests/mistral_tool_use/test_mistral_tool_calls.py b/tests/mistral_tool_use/test_mistral_tool_calls.py new file mode 100644 index 000000000..bbb3a0789 --- /dev/null +++ b/tests/mistral_tool_use/test_mistral_tool_calls.py @@ -0,0 +1,29 @@ +# SPDX-License-Identifier: Apache-2.0 + +import openai +import pytest + +from tests.tool_use.utils import MESSAGES_ASKING_FOR_TOOLS, WEATHER_TOOL + + +# test: a tool_choice with mistral-tokenizer results in an ID of length 9 +@pytest.mark.asyncio +async def test_tool_call_with_tool_choice(client: openai.AsyncOpenAI): + models = await client.models.list() + model_name: str = models.data[0].id + chat_completion = await client.chat.completions.create( + messages=MESSAGES_ASKING_FOR_TOOLS, + temperature=0, + max_completion_tokens=100, + model=model_name, + tools=[WEATHER_TOOL], + tool_choice=WEATHER_TOOL, + logprobs=False) + + choice = chat_completion.choices[0] + + assert choice.finish_reason != "tool_calls" # "stop" or "length" + assert choice.message.role == "assistant" + assert choice.message.tool_calls is None \ + or len(choice.message.tool_calls) == 1 + assert len(choice.message.tool_calls[0].id) == 9 # length of 9 for mistral diff --git a/tests/mistral_tool_use/utils.py b/tests/mistral_tool_use/utils.py new file mode 
100644 index 000000000..971ed55ca --- /dev/null +++ b/tests/mistral_tool_use/utils.py @@ -0,0 +1,33 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Dict, List, Optional + +from typing_extensions import TypedDict + + +class ServerConfig(TypedDict, total=False): + model: str + arguments: List[str] + system_prompt: Optional[str] + supports_parallel: Optional[bool] + supports_rocm: Optional[bool] + + +ARGS: List[str] = ["--max-model-len", "1024"] + +CONFIGS: Dict[str, ServerConfig] = { + "mistral": { + "model": + "mistralai/Mistral-7B-Instruct-v0.3", + "arguments": [ + "--tokenizer-mode", "mistral", + "--ignore-patterns=\"consolidated.safetensors\"" + ], + "system_prompt": + "You are a helpful assistant with access to tools. If a tool" + " that you have would be helpful to answer a user query, " + "call the tool. Otherwise, answer the user's query directly " + "without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT " + "to the user's question - just respond to it normally." 
+ }, +} diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 107220d54..934bd2a95 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -28,12 +28,15 @@ from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser, from vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager +from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import ( + MistralToolCall) from vllm.logger import init_logger from vllm.outputs import CompletionOutput, RequestOutput from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.sequence import Logprob from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer -from vllm.transformers_utils.tokenizers import maybe_serialize_tool_calls +from vllm.transformers_utils.tokenizers import (maybe_serialize_tool_calls, + truncate_tool_call_ids) logger = init_logger(__name__) @@ -150,11 +153,12 @@ class OpenAIServingChat(OpenAIServing): return self.create_error_response( "tool_choice = \"required\" is not supported!") - # because of issues with pydantic we need to potentially - # re-serialize the tool_calls field of the request - # for more info: see comment in `maybe_serialize_tool_calls` if isinstance(tokenizer, MistralTokenizer): + # because of issues with pydantic we need to potentially + # re-serialize the tool_calls field of the request + # for more info: see comment in `maybe_serialize_tool_calls` maybe_serialize_tool_calls(request) + truncate_tool_call_ids(request) if (request.tool_choice == "auto" and not (self.enable_auto_tools and tool_parser is not None) @@ -745,11 +749,13 @@ class OpenAIServingChat(OpenAIServing): elif request.tool_choice and type( request.tool_choice) is ChatCompletionNamedToolChoiceParam: + tool_call_class = 
MistralToolCall if isinstance( + tokenizer, MistralTokenizer) else ToolCall message = ChatMessage( role=role, content="", tool_calls=[ - ToolCall(function=FunctionCall( + tool_call_class(function=FunctionCall( name=request.tool_choice.function.name, arguments=output.text)) ]) diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index 51354f7c9..4f0480882 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -33,7 +33,7 @@ class MistralToolCall(ToolCall): @staticmethod def generate_random_id(): - # Mistral Tool Call Ids must be alphanumeric with a maximum length of 9. + # Mistral Tool Call Ids must be alphanumeric with a length of 9. # https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299 return "".join(choices(ALPHANUMERIC, k=9)) diff --git a/vllm/transformers_utils/tokenizers/__init__.py b/vllm/transformers_utils/tokenizers/__init__.py index 2b64f3fc7..c12388d9b 100644 --- a/vllm/transformers_utils/tokenizers/__init__.py +++ b/vllm/transformers_utils/tokenizers/__init__.py @@ -1,5 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 -from .mistral import MistralTokenizer, maybe_serialize_tool_calls +from .mistral import (MistralTokenizer, maybe_serialize_tool_calls, + truncate_tool_call_ids) -__all__ = ["MistralTokenizer", "maybe_serialize_tool_calls"] +__all__ = [ + "MistralTokenizer", "maybe_serialize_tool_calls", "truncate_tool_call_ids" +] diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py index 59131a9d7..4e76f2dc8 100644 --- a/vllm/transformers_utils/tokenizers/mistral.py +++ b/vllm/transformers_utils/tokenizers/mistral.py @@ -68,6 +68,36 @@ def maybe_serialize_tool_calls(request: "ChatCompletionRequest"): request.messages[i]["tool_calls"] = 
validated_tool_calls +def truncate_tool_call_ids(request: "ChatCompletionRequest"): + """Truncates tool call IDs for Mistral's ID requirements.""" + for i, message in enumerate(request.messages): + if message.get("role") == 'assistant': + tool_calls = message.get("tool_calls", []) + for tool_call in tool_calls: + if len(tool_call["id"]) > 9: + logger.warning( + "Truncating tool call ID: %s to %s", + tool_call["id"], + tool_call["id"][-9:], + ) + tool_call["id"] = tool_call["id"][-9:] + + request.messages[i]["tool_calls"] = tool_calls + + elif message.get("role") in {"tool_results", "tool"}: + if "tool_call_id" in message: + tool_call_id = message["tool_call_id"] + + if len(tool_call_id) > 9: + logger.warning( + "Truncating tool_call_id: %s to %s", + tool_call_id, + tool_call_id[-9:], + ) + tool_call_id = tool_call_id[-9:] + request.messages[i]["tool_call_id"] = tool_call_id + + def list_local_repo_files(repo_id: str, revision: Optional[str]) -> List[str]: repo_cache = os.path.join( huggingface_hub.constants.HF_HUB_CACHE, -- GitLab From 82cabf53a32be91ec08f214e97de06b99d0eef18 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Thu, 13 Feb 2025 00:58:24 +0800 Subject: [PATCH 096/253] [Misc] Delete unused LoRA modules (#13151) --- tests/lora/test_lora_manager.py | 18 ++++++++++++------ vllm/lora/models.py | 8 +++++++- vllm/lora/punica_wrapper/punica_base.py | 2 +- 3 files changed, 20 insertions(+), 8 deletions(-) diff --git a/tests/lora/test_lora_manager.py b/tests/lora/test_lora_manager.py index 6666f54fd..9fecd11f5 100644 --- a/tests/lora/test_lora_manager.py +++ b/tests/lora/test_lora_manager.py @@ -606,20 +606,26 @@ def test_packed_loras(dist_init, dummy_model_gate_up, device): assert isinstance(model.get_submodule("gate_up_proj"), MergedColumnParallelLinearWithLoRA) + # Verify packed lora is correct + model_lora_clone = model_lora.clone(1) + model_lora_clone1 = model_lora1.clone(1) assert manager.add_adapter(model_lora) assert manager.add_adapter(model_lora1) + 
assert model_lora.get_lora("gate_proj") is None + assert model_lora.get_lora("up_proj") is None + assert model_lora1.get_lora("up_proj") is None packed_lora = model_lora.get_lora("gate_up_proj") assert packed_lora and isinstance(packed_lora, PackedLoRALayerWeights) torch.testing.assert_close(packed_lora.lora_a[0], - model_lora.get_lora("gate_proj").lora_a) + model_lora_clone.get_lora("gate_proj").lora_a) torch.testing.assert_close(packed_lora.lora_b[0], - model_lora.get_lora("gate_proj").lora_b) + model_lora_clone.get_lora("gate_proj").lora_b) torch.testing.assert_close(packed_lora.lora_a[1], - model_lora.get_lora("up_proj").lora_a) + model_lora_clone.get_lora("up_proj").lora_a) torch.testing.assert_close(packed_lora.lora_b[1], - model_lora.get_lora("up_proj").lora_b) + model_lora_clone.get_lora("up_proj").lora_b) packed_lora1 = model_lora1.get_lora("gate_up_proj") assert packed_lora1 and isinstance(packed_lora1, PackedLoRALayerWeights) @@ -627,6 +633,6 @@ def test_packed_loras(dist_init, dummy_model_gate_up, device): assert packed_lora1.lora_a[0] is None assert packed_lora1.lora_b[0] is None torch.testing.assert_close(packed_lora1.lora_a[1], - model_lora1.get_lora("up_proj").lora_a) + model_lora_clone1.get_lora("up_proj").lora_a) torch.testing.assert_close(packed_lora1.lora_b[1], - model_lora1.get_lora("up_proj").lora_b) + model_lora_clone1.get_lora("up_proj").lora_b) diff --git a/vllm/lora/models.py b/vllm/lora/models.py index ef77fd4b7..b7403980d 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -5,7 +5,8 @@ import math import os import re from dataclasses import dataclass, field -from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union +from typing import (Any, Callable, Dict, List, Optional, Sequence, Set, Type, + Union) import safetensors.torch import torch @@ -619,12 +620,14 @@ class LoRAModelManager(AdapterModelManager): def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None: for module_name, new_module_names 
in self.packed_modules.items(): replacement_loras: List[Optional[LoRALayerWeights]] = [] + replaced_module: Set[str] = set() has_replacement = False for r in new_module_names: lora = lora_model.get_lora(r) replacement_loras.append(lora) if lora: has_replacement = True + replaced_module.add(r) if not has_replacement: continue for i in range(len(replacement_loras)): @@ -633,6 +636,9 @@ class LoRAModelManager(AdapterModelManager): replacement_loras[i] = None lora_model.loras[module_name] = PackedLoRALayerWeights.pack( replacement_loras) + # Remove the modules that have been replaced. + for module in replaced_module: + lora_model.loras.pop(module, None) def deactivate_adapter(self, adapter_id: int) -> bool: return deactivate_adapter(adapter_id, self._active_adapters, diff --git a/vllm/lora/punica_wrapper/punica_base.py b/vllm/lora/punica_wrapper/punica_base.py index 1a2282ae9..dad98f8e2 100644 --- a/vllm/lora/punica_wrapper/punica_base.py +++ b/vllm/lora/punica_wrapper/punica_base.py @@ -147,7 +147,7 @@ class PunicaWrapperBase(PunicaWrapperABC): dtype=torch.long, device=device) - # 5 is the number of indicies tensors. + # 5 is the number of indices tensors. 
# base_indices, sampler_indices, sampler_indices_padded, # embeddings_indices,long_lora_indices self.indices_len: List[Optional[int]] = [None] * 5 -- GitLab From 042c3419fad1a89c32a27abe8089af6de960bfce Mon Sep 17 00:00:00 2001 From: Lu Fang <30275821+houseroad@users.noreply.github.com> Date: Wed, 12 Feb 2025 09:06:13 -0800 Subject: [PATCH 097/253] Introduce VLLM_CUDART_SO_PATH to allow users specify the .so path (#12998) Signed-off-by: Lu Fang --- .../device_communicators/cuda_wrapper.py | 32 ++++++++++++++++++- vllm/envs.py | 6 ++++ 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/vllm/distributed/device_communicators/cuda_wrapper.py b/vllm/distributed/device_communicators/cuda_wrapper.py index 010caf7eb..bc2cfbf32 100644 --- a/vllm/distributed/device_communicators/cuda_wrapper.py +++ b/vllm/distributed/device_communicators/cuda_wrapper.py @@ -5,12 +5,14 @@ convenient for use when we just need to call a few functions. """ import ctypes +import glob from dataclasses import dataclass from typing import Any, Dict, List, Optional # this line makes it possible to directly load `libcudart.so` using `ctypes` import torch # noqa +import vllm.envs as envs from vllm.logger import init_logger logger = init_logger(__name__) @@ -60,6 +62,29 @@ def find_loaded_library(lib_name) -> Optional[str]: return path +def get_cudart_lib_path_from_env() -> Optional[str]: + """ + In some system, find_loaded_library() may not work. So we allow users to + specify the path through environment variable VLLM_CUDART_SO_PATH. 
+ """ + cudart_so_env = envs.VLLM_CUDART_SO_PATH + if cudart_so_env is not None: + cudart_paths = [ + cudart_so_env, + ] + for path in cudart_paths: + file_paths = glob.glob(path) + if len(file_paths) > 0: + logger.info( + "Found cudart library at %s through env var" + "VLLM_CUDART_SO_PATH=%s", + file_paths[0], + cudart_so_env, + ) + return file_paths[0] + return None + + class CudaRTLibrary: exported_functions = [ # ​cudaError_t cudaSetDevice ( int device ) @@ -105,8 +130,13 @@ class CudaRTLibrary: def __init__(self, so_file: Optional[str] = None): if so_file is None: so_file = find_loaded_library("libcudart") + if so_file is None: + so_file = get_cudart_lib_path_from_env() assert so_file is not None, \ - "libcudart is not loaded in the current process" + ( + "libcudart is not loaded in the current process, " + "try setting VLLM_CUDART_SO_PATH" + ) if so_file not in CudaRTLibrary.path_to_library_cache: lib = ctypes.CDLL(so_file) CudaRTLibrary.path_to_library_cache[so_file] = lib diff --git a/vllm/envs.py b/vllm/envs.py index 745b068b7..d99c794e6 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -87,6 +87,7 @@ if TYPE_CHECKING: VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: bool = False VLLM_RAY_PER_WORKER_GPUS: float = 1.0 VLLM_RAY_BUNDLE_INDICES: str = "" + VLLM_CUDART_SO_PATH: Optional[str] = None def get_default_cache_root(): @@ -572,6 +573,11 @@ environment_variables: Dict[str, Callable[[], Any]] = { # models the alignment is already naturally aligned to 256 bytes. "VLLM_CUDA_MEM_ALIGN_KV_CACHE": lambda: bool(int(os.getenv("VLLM_CUDA_MEM_ALIGN_KV_CACHE", "1"))), + + # In some system, find_loaded_library() may not work. So we allow users to + # specify the path through environment variable VLLM_CUDART_SO_PATH. 
+ "VLLM_CUDART_SO_PATH": + lambda: os.getenv("VLLM_CUDART_SO_PATH", None), } # end-env-vars-definition -- GitLab From 2c2b560f4829b9dfc91308628c5d6f6928247a0e Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Wed, 12 Feb 2025 12:12:22 -0500 Subject: [PATCH 098/253] [CI/Build] Use mypy matcher for pre-commit CI job (#13162) Signed-off-by: Russell Bryant --- .github/workflows/pre-commit.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index dc10b9116..6ab63a402 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -14,6 +14,7 @@ jobs: with: python-version: "3.12" - run: echo "::add-matcher::.github/workflows/matchers/actionlint.json" + - run: echo "::add-matcher::.github/workflows/matchers/mypy.json" - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1 with: extra_args: --all-files --hook-stage manual -- GitLab From 36a08630e80b6176489eef45e15b07724d95a944 Mon Sep 17 00:00:00 2001 From: Qubitium-ModelCloud Date: Thu, 13 Feb 2025 01:19:43 +0800 Subject: [PATCH 099/253] [CORE] [QUANT] Support for GPTQModel's `dynamic` quantization per module override/control (#7086) --- tests/quantization/test_gptq_dynamic.py | 68 ++++++++++++++ tests/quantization/test_lm_head.py | 25 +++-- vllm/lora/layers.py | 2 +- .../model_executor/layers/logits_processor.py | 6 +- .../layers/quantization/gptq.py | 47 ++++++++-- .../layers/quantization/gptq_marlin.py | 59 +++++++++--- .../layers/quantization/utils/gptq_utils.py | 94 +++++++++++++++++++ .../layers/vocab_parallel_embedding.py | 36 +++---- 8 files changed, 281 insertions(+), 56 deletions(-) create mode 100644 tests/quantization/test_gptq_dynamic.py create mode 100644 vllm/model_executor/layers/quantization/utils/gptq_utils.py diff --git a/tests/quantization/test_gptq_dynamic.py b/tests/quantization/test_gptq_dynamic.py new file mode 100644 index 000000000..c6f34fef2 --- /dev/null +++ 
b/tests/quantization/test_gptq_dynamic.py @@ -0,0 +1,68 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests whether gptq models with dynamic quantized can be loaded. + +Run `pytest tests/quantization/test_gptq_dynamic.py --forked`. +""" + +import pytest +import torch + +from vllm.model_executor.layers.linear import UnquantizedLinearMethod +from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod +from vllm.model_executor.layers.quantization.gptq_marlin import ( + GPTQMarlinLinearMethod) +from vllm.model_executor.layers.quantization.utils.gptq_utils import ( + get_dynamic_override) + +PROMPT = "On the surface of Mars, we found" + +# The first layer is quantized using bits=4, group_size=128 +# The second layer is quantized using bits=8, group_size=32 +# All other layers (layer index >= 2) are not quantized +MODEL_QUANT = [ + ("ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head-symTrue", + True), + ("ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head-symFalse", + False), +] + + +@pytest.mark.parametrize("model_id, use_marlin_kernel", MODEL_QUANT) +def test_gptq_with_dynamic(vllm_runner, model_id: str, + use_marlin_kernel: bool): + + vllm_model = vllm_runner(model_id, dtype=torch.float16, max_model_len=2048) + + linear_method_cls = GPTQMarlinLinearMethod if use_marlin_kernel else ( + GPTQLinearMethod) + + for name, submodule in (vllm_model.model.llm_engine.model_executor. 
+ driver_worker.model_runner.model.named_modules()): + if name == "lm_head": + assert isinstance(submodule.quant_method, linear_method_cls) + elif name == 'model.layers.0.self_attn.qkv_proj': + # The first layer is quantized using bits=4, group_size=128 + # desc_act=True + assert isinstance(submodule.quant_method, linear_method_cls) + config = submodule.quant_method.quant_config + assert config.weight_bits == 4 + assert config.group_size == 128 + assert config.desc_act + elif name == 'model.layers.1.self_attn.qkv_proj': + # The second layer is quantized using bits=8, group_size=32 + # desc_act=False + assert isinstance(submodule.quant_method, linear_method_cls) + config = submodule.quant_method.quant_config + assert get_dynamic_override(config, layer_name=name, + key="bits") == 8 + assert get_dynamic_override(config, + layer_name=name, + key="group_size") == 32 + assert not get_dynamic_override( + config, layer_name=name, key="desc_act") + elif (name == 'model.layers.2.self_attn.qkv_proj' + or name == 'model.layers.2.mlp.gate_up_proj'): + # All other layers (layer index >= 2) are not quantized + assert isinstance(submodule.quant_method, UnquantizedLinearMethod) + + del vllm_model diff --git a/tests/quantization/test_lm_head.py b/tests/quantization/test_lm_head.py index ec60d8a57..20435a287 100644 --- a/tests/quantization/test_lm_head.py +++ b/tests/quantization/test_lm_head.py @@ -3,7 +3,6 @@ Run `pytest tests/quantization/test_quant_lm_head_true.py --forked`. 
""" -from typing import Tuple import pytest import torch @@ -17,31 +16,31 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( PROMPT = "On the surface of Mars, we found" -MODELS_QUANT = [( - "LnL-AI/TinyLlama-1.1B-intermediate-step-1341k-3T-autoround-lm_head-symFalse", - True), ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", False), - ("neuralmagic/Meta-Llama-3-8B-Instruct-FP8", False)] +MODELS_QUANT = [ + ("ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head", True), + ("ModelCloud/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit-10-25-2024", False), + ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", False), + ("neuralmagic/Meta-Llama-3-8B-Instruct-FP8", False) +] -@pytest.mark.parametrize("model_lm_head_quant", MODELS_QUANT) +@pytest.mark.parametrize("model_id, lm_head_quantized", MODELS_QUANT) def test_lm_head( vllm_runner, - model_lm_head_quant: Tuple[str, bool], + model_id: str, + lm_head_quantized: bool, ) -> None: - model, lm_head_quantized = model_lm_head_quant - - with vllm_runner(model, dtype=torch.float16, + with vllm_runner(model_id, dtype=torch.float16, max_model_len=2048) as vllm_model: def check_model(model): lm_head_layer = model.lm_head - if lm_head_quantized: - assert isinstance(lm_head_layer.linear_method, + assert isinstance(lm_head_layer.quant_method, (GPTQLinearMethod, GPTQMarlinLinearMethod, MarlinLinearMethod)) else: - assert isinstance(lm_head_layer.linear_method, + assert isinstance(lm_head_layer.quant_method, UnquantizedEmbeddingMethod) vllm_model.apply_model(check_model) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 9826aeb9d..7f68dae97 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -1039,7 +1039,7 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA): embedding_bias: Optional[torch.Tensor] = None, ) -> Optional[torch.Tensor]: # Get the logits for the next tokens. 
- logits = lm_head.linear_method.apply(lm_head, hidden_states) + logits = lm_head.quant_method.apply(lm_head, hidden_states) if embedding_bias is not None: logits += embedding_bias diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py index 0565c6e8b..9b1742998 100644 --- a/vllm/model_executor/layers/logits_processor.py +++ b/vllm/model_executor/layers/logits_processor.py @@ -108,9 +108,9 @@ class LogitsProcessor(nn.Module): embedding_bias: Optional[torch.Tensor], ) -> Optional[torch.Tensor]: # Get the logits for the next tokens. - logits = lm_head.linear_method.apply(lm_head, - hidden_states, - bias=embedding_bias) + logits = lm_head.quant_method.apply(lm_head, + hidden_states, + bias=embedding_bias) # Gather logits for TP logits = self._gather_logits(logits) diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index 0cb77a754..6d1f0cc2e 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -3,16 +3,17 @@ import enum from enum import Enum from fractions import Fraction -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union import torch from torch.nn.parameter import Parameter from vllm import _custom_ops as ops -from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.linear import LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.layers.quantization.utils.gptq_utils import ( + get_linear_quant_method) from vllm.model_executor.parameter import (ChannelQuantScaleParameter, GroupQuantScaleParameter, PackedColumnParameter, @@ -32,7 +33,33 @@ class GPTQConfig(QuantizationConfig): group_size: int, desc_act: bool, lm_head_quantized: bool, + 
dynamic: Dict[str, Dict[str, Union[int, bool]]], ) -> None: + # GPTQModel use `dynamic` config property to allow per module + # quantization config so each module can be individually optimized. + # Format is Dict[str, Dict] where key is a regex string that can + # perform both positive ("+:" prefixed) or negative ("-:" prefixed) + # matching of a module. + # Default to positive match, override base quant config mode, if no + # prefix is used. Value is in dict format of field key and override + # value. + # Negative matching will skip quantization init for this module + # entirely: + # non-quantized inference. More details and quantization examples can be + # found at: https://github.com/ModelCloud/GPTQModel + # Example: + # # last 1/2 of the layers 10-21 has 8bit vs 4bit for 0-9 + # # last 1/4 of the layers 16-21 has 8bit and group_size 64 + # dynamic = { + # #`.*\.` matches the layers_node prefix + # # positive match layer 10-15 + # r"+:.*\.(?:1[0-5])\..*": {"bits": 8,}, + # # positive match layer 16-21 + # r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,}, + # r"-:.*\.moe\..*": {}, # negative match (skip) all `moe` layers + # } + self.dynamic = dynamic + self.weight_bits = weight_bits self.group_size = group_size self.desc_act = desc_act @@ -47,7 +74,8 @@ class GPTQConfig(QuantizationConfig): return (f"GPTQConfig(weight_bits={self.weight_bits}, " f"group_size={self.group_size}, " f"desc_act={self.desc_act})," - f"lm_head_quantized={self.lm_head_quantized}") + f"lm_head_quantized={self.lm_head_quantized}), " + f"dynamic={self.dynamic}") @classmethod def get_name(cls) -> str: @@ -68,19 +96,20 @@ class GPTQConfig(QuantizationConfig): @classmethod def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig": + dynamic = cls.get_from_keys_or(config, ["dynamic"], default={}) + dynamic = {} if dynamic is None else dynamic + weight_bits = cls.get_from_keys(config, ["bits"]) group_size = cls.get_from_keys(config, ["group_size"]) desc_act = 
cls.get_from_keys(config, ["desc_act"]) lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False) - return cls(weight_bits, group_size, desc_act, lm_head_quantized) + return cls(weight_bits, group_size, desc_act, lm_head_quantized, + dynamic) def get_quant_method(self, layer: torch.nn.Module, prefix: str) -> Optional["GPTQLinearMethod"]: - if (isinstance(layer, LinearBase) or - (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)): - return GPTQLinearMethod(self) - return None + return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod) class ExllamaState(Enum): diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 84c53b2c1..0a9d86b00 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -9,17 +9,21 @@ from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) -from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, +from vllm.model_executor.layers.linear import (LinearMethodBase, + UnquantizedLinearMethod, set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.quantization.kernels.mixed_precision import ( MPLinearLayerConfig, choose_mp_linear_kernel) from vllm.model_executor.layers.quantization.utils import replace_parameter +from vllm.model_executor.layers.quantization.utils.gptq_utils import ( + get_linear_quant_method) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( check_marlin_supported, marlin_moe_permute_scales, marlin_repeat_scales_on_all_ranks, verify_marlin_supported) -from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.layers.vocab_parallel_embedding import 
( + UnquantizedEmbeddingMethod) from vllm.model_executor.parameter import (ChannelQuantScaleParameter, GroupQuantScaleParameter, PackedColumnParameter, @@ -47,12 +51,41 @@ class GPTQMarlinConfig(QuantizationConfig): desc_act: bool, is_sym: bool, lm_head_quantized: bool, + dynamic: Dict[str, Dict[str, Union[int, bool]]], ) -> None: if desc_act and group_size == -1: # In this case, act_order == True is the same as act_order == False # (since we have only one group per output channel) desc_act = False + # GPTQModel use `dynamic` config property to allow per module + # quantization config so each module can be individually optimized. + # Format is Dict[str, Dict] where key is a regex string that can + # perform both positive ("+:" prefixed) or negative ("-:" prefixed) + # matching of a module. + # Default to positive match, override base quant config mode, if no + # prefix is used. Value is in dict format of field key and override + # value. + # Negative matching will skip quantization init for this module + # entirely: + # non-quantized inference. 
More details and quantization examples can be + # found at: https://github.com/ModelCloud/GPTQModel + # Example: + # # last 1/2 of the layers 10-21 has 8bit vs 4bit for 0-9 + # # last 1/4 of the layers 16-21 has 8bit and group_size 64 + # dynamic = { + # #`.*\.` matches the layers_node prefix + # # positive match layer 10-15 + # r"+:.*\.(?:1[0-5])\..*": {"bits": 8,}, + # # positive match layer 16-21 + # r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,}, + # r"-:.*\.moe\..*": {}, # negative match (skip) all `moe` layers + # } + self.dynamic = dynamic + + self.weight_bits = weight_bits + self.is_sym = is_sym + self.pack_factor = 32 // weight_bits # packed into int32 self.group_size = group_size self.desc_act = desc_act @@ -68,7 +101,8 @@ class GPTQMarlinConfig(QuantizationConfig): return (f"GPTQMarlinConfig(quant_type={self.quant_type}, " f"group_size={self.group_size}, " f"desc_act={self.desc_act}, " - f"lm_head_quantized={self.lm_head_quantized})") + f"lm_head_quantized={self.lm_head_quantized}), " + f"dynamic={self.dynamic}") @classmethod def get_name(cls) -> str: @@ -88,6 +122,9 @@ class GPTQMarlinConfig(QuantizationConfig): @classmethod def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig": + dynamic = cls.get_from_keys_or(config, ["dynamic"], default={}) + dynamic = {} if dynamic is None else dynamic + weight_bits = cls.get_from_keys(config, ["bits"]) group_size = cls.get_from_keys(config, ["group_size"]) desc_act = cls.get_from_keys(config, ["desc_act"]) @@ -95,7 +132,7 @@ class GPTQMarlinConfig(QuantizationConfig): lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False) return cls(weight_bits, group_size, desc_act, is_sym, - lm_head_quantized) + lm_head_quantized, dynamic) @classmethod def override_quantization_method(cls, hf_quant_cfg, @@ -120,17 +157,15 @@ class GPTQMarlinConfig(QuantizationConfig): def get_quant_method( self, layer: torch.nn.Module, prefix: str - ) -> Optional[Union["GPTQMarlinLinearMethod", 
"GPTQMarlinMoEMethod"]]: - if isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead) - and self.lm_head_quantized): - return GPTQMarlinLinearMethod(self) - elif isinstance(layer, FusedMoE): + ) -> Optional[Union["GPTQMarlinLinearMethod", "GPTQMarlinMoEMethod", + UnquantizedLinearMethod, UnquantizedEmbeddingMethod]]: + if isinstance(layer, FusedMoE): return GPTQMarlinMoEMethod(self) - return None + return get_linear_quant_method(self, layer, prefix, + GPTQMarlinLinearMethod) @classmethod def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]): - # Extract data from quant config. quant_method = quant_config.get("quant_method", "").lower() num_bits = quant_config.get("bits") group_size = quant_config.get("group_size") @@ -143,7 +178,7 @@ class GPTQMarlinConfig(QuantizationConfig): if quant_method != "gptq": return False - # If we cannot find the info needed in the config, cannot convert. + # Marlin conversion is only valid if required properties are found if (num_bits is None or group_size is None or sym is None or desc_act is None): return False diff --git a/vllm/model_executor/layers/quantization/utils/gptq_utils.py b/vllm/model_executor/layers/quantization/utils/gptq_utils.py new file mode 100644 index 000000000..5b0e6299f --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/gptq_utils.py @@ -0,0 +1,94 @@ +# SPDX-License-Identifier: Apache-2.0 +import re +from copy import deepcopy +from typing import Dict, Optional, Union + +import torch + +from vllm.config import QuantizationConfig +from vllm.model_executor.layers.linear import (LinearBase, + UnquantizedLinearMethod) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, UnquantizedEmbeddingMethod) + + +# Match dynamic rules with module name (prefix) and override quantize +# config if module (prefix) matches a rule +def override_config(config: QuantizationConfig, prefix: str): + weight_bits = get_dynamic_override(config, prefix, "bits", + 
config.weight_bits) + if isinstance(weight_bits, int): + config.weight_bits = weight_bits + group_size = get_dynamic_override(config, prefix, "group_size", + config.group_size) + if isinstance(group_size, int): + config.group_size = group_size + desc_act = get_dynamic_override(config, prefix, "desc_act", + config.desc_act) + if isinstance(desc_act, bool): + config.desc_act = desc_act + + config.pack_factor = 32 // config.weight_bits # packed into int32 + if config.get_name() == "gptq_marlin": + is_sym = get_dynamic_override(config, prefix, "sym", config.is_sym) + if isinstance(is_sym, bool): + config.is_sym = is_sym + + if (config.weight_bits, config.is_sym) not in config.TYPE_MAP: + raise ValueError("Unsupported quantization config: " + f"bits={config.weight_bits}, sym={config.is_sym}") + + config.quant_type = config.TYPE_MAP[(config.weight_bits, + config.is_sym)] + elif config.get_name() == "gptq": + if config.weight_bits not in [2, 3, 4, 8]: + raise ValueError( + "Currently, only 2/3/4/8-bit weight quantization is " + f"supported for GPTQ, but got {config.weight_bits} bits.") + + +def get_dynamic_override( + config: QuantizationConfig, + layer_name: str, + key: Optional[str] = None, + default_value: Union[int, bool, + None] = None) -> Union[Dict, int, bool, None]: + for pattern, pattern_dict in config.dynamic.items(): + # Negative match: matched modules are excluded from quantized init + if pattern.startswith("-:"): + if re.match(pattern.removeprefix("-:"), layer_name): + return False + # Positive match: matched modules have quant properties overrides + # base quant config + elif re.match(pattern.removeprefix("+:"), layer_name): + if key is None: + return pattern_dict + else: + return pattern_dict.get(key, default_value) + return default_value + + +def get_linear_quant_method( + config: QuantizationConfig, + layer: torch.nn.Module, + prefix: str, + linear_method_cls: type, +): + cloned_config = deepcopy(config) + parallel_lm_head_quantized = isinstance( + layer, 
ParallelLMHead) and cloned_config.lm_head_quantized + if isinstance(layer, LinearBase) or parallel_lm_head_quantized: + # False = skip module, None = no override, else = Positive match + if get_dynamic_override( # noqa: E712 + cloned_config, # noqa: E712 + layer_name=prefix) == False: # noqa: E712 + if parallel_lm_head_quantized: + return UnquantizedEmbeddingMethod() + return UnquantizedLinearMethod() + + if prefix: + # Dynamic per module/layer rules may override base config + override_config(cloned_config, prefix=prefix) + + return linear_method_cls(cloned_config) + return None diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index e409094dd..f65dfc3cb 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -226,24 +226,24 @@ class VocabParallelEmbedding(torch.nn.Module): self.tp_size) self.embedding_dim = embedding_dim - linear_method = None + quant_method = None if quant_config is not None: - linear_method = quant_config.get_quant_method(self, prefix=prefix) - if linear_method is None: - linear_method = UnquantizedEmbeddingMethod() + quant_method = quant_config.get_quant_method(self, prefix=prefix) + if quant_method is None: + quant_method = UnquantizedEmbeddingMethod() # If we are making an embedding layer, then our quantization linear # method must implement the embedding operation. If we are another # layer type like ParallelLMHead, this is not important. 
is_embedding_layer = type(self.__class__) is VocabParallelEmbedding - linear_method_implements_embedding = method_has_implemented_embedding( - type(linear_method)) - if is_embedding_layer and not linear_method_implements_embedding: + quant_method_implements_embedding = method_has_implemented_embedding( + type(quant_method)) + if is_embedding_layer and not quant_method_implements_embedding: raise NotImplementedError( - f"The class {type(linear_method).__name__} must implement " + f"The class {type(quant_method).__name__} must implement " "the 'embedding' method, see UnquantizedEmbeddingMethod.") - self.linear_method: QuantizeMethodBase = linear_method + self.quant_method: QuantizeMethodBase = quant_method if params_dtype is None: params_dtype = torch.get_default_dtype() @@ -260,13 +260,13 @@ class VocabParallelEmbedding(torch.nn.Module): self.shard_indices.added_vocab_end_index - self.shard_indices.added_vocab_start_index) - self.linear_method.create_weights(self, - self.embedding_dim, - [self.num_embeddings_per_partition], - self.embedding_dim, - self.num_embeddings_padded, - params_dtype=params_dtype, - weight_loader=self.weight_loader) + self.quant_method.create_weights(self, + self.embedding_dim, + [self.num_embeddings_per_partition], + self.embedding_dim, + self.num_embeddings_padded, + params_dtype=params_dtype, + weight_loader=self.weight_loader) @classmethod def _get_indices(cls, vocab_size_padded: int, org_vocab_size_padded: int, @@ -412,8 +412,8 @@ class VocabParallelEmbedding(torch.nn.Module): else: masked_input = input_ # Get the embeddings. - output_parallel = self.linear_method.embedding(self, - masked_input.long()) + output_parallel = self.quant_method.embedding(self, + masked_input.long()) # Mask the output embedding. 
if self.tp_size > 1: output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0) -- GitLab From 09972e716c4a90bfd4385540c9f478e18b4efb2d Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Wed, 12 Feb 2025 12:19:53 -0500 Subject: [PATCH 100/253] [Bugfix] Allow fallback to AWQ from AWQMarlin at per-layer granularity (#13119) --- vllm/model_executor/layers/linear.py | 35 ++++++++++--------- .../layers/quantization/awq_marlin.py | 28 +++++++++------ .../layers/quantization/moe_wna16.py | 9 +++-- .../layers/quantization/utils/marlin_utils.py | 15 ++++++++ 4 files changed, 58 insertions(+), 29 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index dad161120..521724765 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -290,29 +290,30 @@ class ColumnParallelLinear(LinearBase): quant_config: Optional[QuantizationConfig] = None, output_sizes: Optional[list[int]] = None, prefix: str = ""): - super().__init__(input_size, output_size, skip_bias_add, params_dtype, - quant_config, prefix) - - self.gather_output = gather_output - # Divide the weight matrix along the last dimension. - tp_size = get_tensor_model_parallel_world_size() - assert self.quant_method is not None - self.output_size_per_partition = divide(self.output_size, tp_size) + self.tp_size = get_tensor_model_parallel_world_size() + self.input_size_per_partition = input_size + self.output_size_per_partition = divide(output_size, self.tp_size) self.output_partition_sizes = [self.output_size_per_partition] # If QKV or MergedColumn, use output size of each partition. 
if hasattr(self, "output_sizes"): self.output_partition_sizes = [ - divide(output_size, tp_size) + divide(output_size, self.tp_size) for output_size in self.output_sizes ] + super().__init__(input_size, output_size, skip_bias_add, params_dtype, + quant_config, prefix) + + self.gather_output = gather_output + if output_sizes is None: output_sizes = [output_size] + assert self.quant_method is not None self.quant_method.create_weights( layer=self, - input_size_per_partition=self.input_size, + input_size_per_partition=self.input_size_per_partition, output_partition_sizes=self.output_partition_sizes, input_size=self.input_size, output_size=self.output_size, @@ -1044,22 +1045,24 @@ class RowParallelLinear(LinearBase): reduce_results: bool = True, quant_config: Optional[QuantizationConfig] = None, prefix: str = ""): + # Divide the weight matrix along the first dimension. + self.tp_rank = get_tensor_model_parallel_rank() + self.tp_size = get_tensor_model_parallel_world_size() + self.input_size_per_partition = divide(input_size, self.tp_size) + self.output_size_per_partition = output_size + self.output_partition_sizes = [output_size] + super().__init__(input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix) self.input_is_parallel = input_is_parallel self.reduce_results = reduce_results - # Divide the weight matrix along the last dimension. 
- self.tp_rank = get_tensor_model_parallel_rank() - self.tp_size = get_tensor_model_parallel_world_size() - self.input_size_per_partition = divide(input_size, self.tp_size) assert self.quant_method is not None - self.quant_method.create_weights( layer=self, input_size_per_partition=self.input_size_per_partition, - output_partition_sizes=[self.output_size], + output_partition_sizes=self.output_partition_sizes, input_size=self.input_size, output_size=self.output_size, params_dtype=self.params_dtype, diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index 8849ba292..a43b2e597 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -13,15 +13,17 @@ from vllm.model_executor.layers.fused_moe.layer import ( from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, UnquantizedLinearMethod, set_weight_attrs) -from vllm.model_executor.layers.quantization.awq import is_layer_skipped_awq +from vllm.model_executor.layers.quantization.awq import (AWQConfig, + is_layer_skipped_awq) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils.marlin_utils import ( apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported, - marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales, - marlin_permute_scales, moe_awq_to_marlin_zero_points, - verify_marlin_supported, verify_marlin_supports_shape) + check_marlin_supports_layer, marlin_make_empty_g_idx, + marlin_make_workspace, marlin_moe_permute_scales, marlin_permute_scales, + moe_awq_to_marlin_zero_points, verify_marlin_supported, + verify_marlin_supports_shape) from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from 
vllm.model_executor.parameter import (GroupQuantScaleParameter, PackedvLLMParameter) @@ -40,18 +42,17 @@ class AWQMarlinConfig(QuantizationConfig): 8: scalar_types.uint8, } - def __init__(self, - weight_bits: int, - group_size: int, - zero_point: bool, + def __init__(self, weight_bits: int, group_size: int, zero_point: bool, lm_head_quantized: bool, - modules_to_not_convert: Optional[List[str]] = None) -> None: + modules_to_not_convert: Optional[List[str]], + full_config: Dict[str, Any]) -> None: self.pack_factor = 32 // weight_bits # packed into int32 self.group_size = group_size self.zero_point = zero_point self.lm_head_quantized = lm_head_quantized self.weight_bits = weight_bits self.modules_to_not_convert = modules_to_not_convert or [] + self.full_config = full_config if self.weight_bits not in self.TYPE_MAP: raise ValueError(f"Unsupported num_bits = {self.weight_bits}. " @@ -96,7 +97,7 @@ class AWQMarlinConfig(QuantizationConfig): modules_to_not_convert = cls.get_from_keys_or( config, ["modules_to_not_convert"], None) return cls(weight_bits, group_size, zero_point, lm_head_quantized, - modules_to_not_convert) + modules_to_not_convert, config) @classmethod def override_quantization_method(cls, hf_quant_cfg, @@ -124,6 +125,13 @@ class AWQMarlinConfig(QuantizationConfig): (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)): if is_layer_skipped_awq(prefix, self.modules_to_not_convert): return UnquantizedLinearMethod() + # Check if the layer is supported by AWQMarlin. + if not check_marlin_supports_layer(layer, self.group_size): + logger.warning_once( + f"Layer '{prefix}' is not supported by AWQMarlin. 
" + "Falling back to unoptimized AWQ kernels.") + return AWQConfig.from_config( + self.full_config).get_quant_method(layer, prefix) return AWQMarlinLinearMethod(self) elif isinstance(layer, FusedMoE): return AWQMoEMethod(self) diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py index 56fa597e2..b9460e7d7 100644 --- a/vllm/model_executor/layers/quantization/moe_wna16.py +++ b/vllm/model_executor/layers/quantization/moe_wna16.py @@ -16,6 +16,8 @@ from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.gptq_marlin import ( GPTQMarlinConfig) +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + check_marlin_supports_layer) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform @@ -87,8 +89,8 @@ class MoeWNA16Config(QuantizationConfig): modules_to_not_convert = [] elif linear_quant_method == "awq": has_zp = cls.get_from_keys(config, ["zero_point"]) - modules_to_not_convert = cls.get_from_keys( - config, ["modules_to_not_convert"]) + modules_to_not_convert = cls.get_from_keys_or( + config, ["modules_to_not_convert"], None) else: raise ValueError("moe_wna16 only support gptq and awq.") @@ -135,7 +137,8 @@ class MoeWNA16Config(QuantizationConfig): return GPTQConfig.from_config( self.full_config).get_quant_method(layer, prefix) elif self.linear_quant_method == "awq": - if self.use_marlin: + if self.use_marlin and check_marlin_supports_layer( + layer, self.group_size): return AWQMarlinConfig.from_config( self.full_config).get_quant_method(layer, prefix) else: diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index 3beba3083..05e37251a 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ 
b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -6,6 +6,7 @@ import numpy import torch from vllm import _custom_ops as ops +from vllm.model_executor.layers.linear import LinearBase from vllm.platforms import current_platform from vllm.scalar_type import ScalarType, scalar_types @@ -135,6 +136,20 @@ def check_marlin_supports_shape(output_size_per_partition: int, return True, None +def check_marlin_supports_layer(layer: LinearBase, group_size: int) \ + -> bool: + output_size_per_partition = getattr(layer, "output_size_per_partition", + None) or layer.output_size + input_size_per_partition = getattr(layer, "input_size_per_partition", + None) or layer.input_size + + return check_marlin_supports_shape( + output_size_per_partition=output_size_per_partition, + input_size_per_partition=input_size_per_partition, + input_size=layer.input_size, + group_size=group_size)[0] + + def marlin_make_workspace(output_size_per_partition: int, device: torch.device) -> torch.Tensor: max_workspace_size = (output_size_per_partition // -- GitLab From 14b7899d10217b31d98ba78ade7768fbd735ca4e Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Wed, 12 Feb 2025 14:16:06 -0500 Subject: [PATCH 101/253] [CI] Fix failing FP8 cpu offload test (#13170) Signed-off-by: mgoin --- tests/quantization/test_cpu_offload.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/quantization/test_cpu_offload.py b/tests/quantization/test_cpu_offload.py index 29a5721ef..de03d37a7 100644 --- a/tests/quantization/test_cpu_offload.py +++ b/tests/quantization/test_cpu_offload.py @@ -1,5 +1,5 @@ -# SPDX-License-Identifier: Apache-2.0 - +# SPDX-License-Identifier: Apache-2.0 + # Expanded quantized model tests for CPU offloading # Base tests: tests/basic_correctness/test_cpu_offload.py @@ -14,13 +14,13 @@ from ..utils import compare_two_settings reason="fp8 is not supported on this GPU type.") def test_cpu_offload_fp8(): # Test quantization of an unquantized checkpoint - 
compare_two_settings("meta-llama/Meta-Llama-3-8B-Instruct", + compare_two_settings("meta-llama/Llama-3.2-1B-Instruct", ["--quantization", "fp8"], - ["--quantization", "fp8", "--cpu-offload-gb", "2"], + ["--quantization", "fp8", "--cpu-offload-gb", "1"], max_wait_seconds=480) # Test loading a quantized checkpoint - compare_two_settings("neuralmagic/Meta-Llama-3-8B-Instruct-FP8", [], - ["--cpu-offload-gb", "2"], + compare_two_settings("neuralmagic/Qwen2-1.5B-Instruct-FP8", [], + ["--cpu-offload-gb", "1"], max_wait_seconds=480) -- GitLab From 4c0d93f4b2de241336f4732cb5799cee8fedcb52 Mon Sep 17 00:00:00 2001 From: Murali Andoorveedu <37849411+andoorve@users.noreply.github.com> Date: Wed, 12 Feb 2025 12:58:11 -0800 Subject: [PATCH 102/253] [V1][Bugfix] Copy encoder input ids to fix set iteration issue during VLM abort (#13173) Signed-off-by: andoorve <37849411+andoorve@users.noreply.github.com> --- vllm/v1/core/encoder_cache_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/core/encoder_cache_manager.py b/vllm/v1/core/encoder_cache_manager.py index 651bc01aa..13ad14e45 100644 --- a/vllm/v1/core/encoder_cache_manager.py +++ b/vllm/v1/core/encoder_cache_manager.py @@ -54,7 +54,7 @@ class EncoderCacheManager: def free(self, request: Request) -> None: """Free all cached input ids for the request.""" - input_ids = self.get_cached_input_ids(request) + input_ids = self.get_cached_input_ids(request).copy() for input_id in input_ids: self.free_encoder_input(request, input_id) -- GitLab From 8eafe5eaeadfbe91274a0a0915ee6903990e3fed Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Wed, 12 Feb 2025 22:48:31 -0500 Subject: [PATCH 103/253] [CI/Build] Ignore ruff warning up007 (#13182) Signed-off-by: Russell Bryant --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 9892967b8..849e8781e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,7 +59,8 @@ ignore = [ "UP032", # 
Python 3.8 typing "UP006", "UP035", - + # Can remove once 3.10+ is the minimum Python version + "UP007", ] [tool.mypy] -- GitLab From 9f9704dca6dda7d4af556b133d5e42c360dd2fb0 Mon Sep 17 00:00:00 2001 From: "Kevin H. Luu" Date: Wed, 12 Feb 2025 19:51:33 -0800 Subject: [PATCH 104/253] [perf-benchmark] cleanup unused Docker images and volumes in H100 benchmark instance (#12706) --- .buildkite/nightly-benchmarks/benchmark-pipeline.yaml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml b/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml index 679abf181..df95e46d6 100644 --- a/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml +++ b/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml @@ -70,6 +70,12 @@ steps: #key: block-h100 #depends_on: ~ + - label: "Cleanup H100" + agents: + queue: H100 + depends_on: ~ + command: docker system prune -a --volumes --force + - label: "H100" # skip: "use this flag to conditionally skip the benchmark step, useful for PR testing" agents: -- GitLab From 4fc5c23bb64719bda4f2b24e0b95637ccd530fab Mon Sep 17 00:00:00 2001 From: Kaixi Hou Date: Wed, 12 Feb 2025 19:51:51 -0800 Subject: [PATCH 105/253] [NVIDIA] Support nvfp4 quantization (#12784) --- CMakeLists.txt | 18 + cmake/utils.cmake | 18 +- csrc/cuda_utils.h | 12 + csrc/cuda_utils_kernels.cu | 22 +- csrc/ops.h | 4 + csrc/quantization/fp4/nvfp4_quant_entry.cu | 32 ++ csrc/quantization/fp4/nvfp4_quant_kernels.cu | 379 +++++++++++++++++++ csrc/torch_bindings.cpp | 6 + tests/kernels/test_nvfp4_quant.py | 149 ++++++++ tests/test_scalartype.py | 1 + vllm/_custom_ops.py | 57 +++ vllm/scalar_type.py | 3 + 12 files changed, 688 insertions(+), 13 deletions(-) create mode 100644 csrc/quantization/fp4/nvfp4_quant_entry.cu create mode 100644 csrc/quantization/fp4/nvfp4_quant_kernels.cu create mode 100644 tests/kernels/test_nvfp4_quant.py diff --git a/CMakeLists.txt b/CMakeLists.txt index a0fd346c6..244ceb721 100755 --- a/CMakeLists.txt 
+++ b/CMakeLists.txt @@ -264,6 +264,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/custom_all_reduce.cu" "csrc/permute_cols.cu" "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu" + "csrc/quantization/fp4/nvfp4_quant_entry.cu" "csrc/sparse/cutlass/sparse_scaled_mm_entry.cu" "csrc/sparse/cutlass/sparse_compressor_entry.cu" "csrc/cutlass_extensions/common.cpp") @@ -377,6 +378,23 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") endif() endif() + # FP4 Archs and flags + cuda_archs_loose_intersection(FP4_ARCHS "10.0a" "${CUDA_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND FP4_ARCHS) + set(SRCS + "csrc/quantization/fp4/nvfp4_quant_kernels.cu" + ) + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${FP4_ARCHS}") + list(APPEND VLLM_EXT_SRC "${SRCS}") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4=1") + message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}") + else() + message(STATUS "Not building NVFP4 as no compatible archs were found.") + # clear FP4_ARCHS + set(FP4_ARCHS) + endif() # # Machete kernels diff --git a/cmake/utils.cmake b/cmake/utils.cmake index 1c1c53981..c9cd099b8 100644 --- a/cmake/utils.cmake +++ b/cmake/utils.cmake @@ -257,9 +257,9 @@ endmacro() # where `<=` is the version comparison operator. # In other words, for each version in `TGT_CUDA_ARCHS` find the highest version # in `SRC_CUDA_ARCHS` that is less or equal to the version in `TGT_CUDA_ARCHS`. -# We have special handling for 9.0a, if 9.0a is in `SRC_CUDA_ARCHS` and 9.0 is -# in `TGT_CUDA_ARCHS` then we should remove 9.0a from `SRC_CUDA_ARCHS` and add -# 9.0a to the result (and remove 9.0 from TGT_CUDA_ARCHS). +# We have special handling for x.0a, if x.0a is in `SRC_CUDA_ARCHS` and x.0 is +# in `TGT_CUDA_ARCHS` then we should remove x.0a from `SRC_CUDA_ARCHS` and add +# x.0a to the result (and remove x.0 from TGT_CUDA_ARCHS). # The result is stored in `OUT_CUDA_ARCHS`. 
# # Example: @@ -272,8 +272,8 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR list(REMOVE_DUPLICATES SRC_CUDA_ARCHS) set(TGT_CUDA_ARCHS_ ${TGT_CUDA_ARCHS}) - # if 9.0a is in SRC_CUDA_ARCHS and 9.0 is in CUDA_ARCHS then we should - # remove 9.0a from SRC_CUDA_ARCHS and add 9.0a to _CUDA_ARCHS + # if x.0a is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should + # remove x.0a from SRC_CUDA_ARCHS and add x.0a to _CUDA_ARCHS set(_CUDA_ARCHS) if ("9.0a" IN_LIST SRC_CUDA_ARCHS) list(REMOVE_ITEM SRC_CUDA_ARCHS "9.0a") @@ -283,6 +283,14 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR endif() endif() + if ("10.0a" IN_LIST SRC_CUDA_ARCHS) + list(REMOVE_ITEM SRC_CUDA_ARCHS "10.0a") + if ("10.0" IN_LIST TGT_CUDA_ARCHS) + list(REMOVE_ITEM TGT_CUDA_ARCHS_ "10.0") + set(_CUDA_ARCHS "10.0a") + endif() + endif() + list(SORT SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING) # for each ARCH in TGT_CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that diff --git a/csrc/cuda_utils.h b/csrc/cuda_utils.h index c35224218..6f79d2b74 100644 --- a/csrc/cuda_utils.h +++ b/csrc/cuda_utils.h @@ -1,5 +1,7 @@ #pragma once +#include + #if defined(__CUDACC__) || defined(_NVHPC_CUDA) #define HOST_DEVICE_INLINE __forceinline__ __host__ __device__ #define DEVICE_INLINE __forceinline__ __device__ @@ -10,6 +12,16 @@ #define HOST_INLINE inline #endif +#define CUDA_CHECK(cmd) \ + do { \ + cudaError_t e = cmd; \ + if (e != cudaSuccess) { \ + printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, \ + cudaGetErrorString(e)); \ + exit(EXIT_FAILURE); \ + } \ + } while (0) + int64_t get_device_attribute(int64_t attribute, int64_t device_id); int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id); diff --git a/csrc/cuda_utils_kernels.cu b/csrc/cuda_utils_kernels.cu index d6f9eb646..0627a4267 100644 --- a/csrc/cuda_utils_kernels.cu +++ b/csrc/cuda_utils_kernels.cu @@ -1,16 +1,22 @@ +#include "cuda_utils.h" 
#ifdef USE_ROCM #include #include #endif + int64_t get_device_attribute(int64_t attribute, int64_t device_id) { - int device, value; - if (device_id < 0) { - cudaGetDevice(&device); - } else { - device = device_id; - } - cudaDeviceGetAttribute(&value, static_cast(attribute), - device); + // Return the cached value on subsequent calls + static int value = [=]() { + int device = static_cast(device_id); + if (device < 0) { + CUDA_CHECK(cudaGetDevice(&device)); + } + int value; + CUDA_CHECK(cudaDeviceGetAttribute( + &value, static_cast(attribute), device)); + return static_cast(value); + }(); + return value; } diff --git a/csrc/ops.h b/csrc/ops.h index e39d4ef31..70e864cc6 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -195,6 +195,10 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit); +void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input, + torch::Tensor& output_scale, + torch::Tensor const& input_scale); + void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input, torch::Tensor const& scale); diff --git a/csrc/quantization/fp4/nvfp4_quant_entry.cu b/csrc/quantization/fp4/nvfp4_quant_entry.cu new file mode 100644 index 000000000..b1426c43b --- /dev/null +++ b/csrc/quantization/fp4/nvfp4_quant_entry.cu @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +#if defined ENABLE_NVFP4 && ENABLE_NVFP4 +void scaled_fp4_quant_sm100a(torch::Tensor const& output, + torch::Tensor const& input, + torch::Tensor const& output_sf, + torch::Tensor const& input_sf); +#endif + +void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input, + torch::Tensor& output_sf, torch::Tensor const& input_sf) { +#if defined ENABLE_NVFP4 && ENABLE_NVFP4 + return scaled_fp4_quant_sm100a(output, input, output_sf, input_sf); +#endif + TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 quantization"); +} diff --git a/csrc/quantization/fp4/nvfp4_quant_kernels.cu b/csrc/quantization/fp4/nvfp4_quant_kernels.cu new file mode 100644 index 000000000..c3b8e9b3e --- /dev/null +++ b/csrc/quantization/fp4/nvfp4_quant_kernels.cu @@ -0,0 +1,379 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +#include +#include + +#include +#include + +#include + +#include "cuda_utils.h" + +// Get type2 from type or vice versa (applied to half and bfloat16) +template +struct TypeConverter { + using Type = half2; +}; // keep for generality + +template <> +struct TypeConverter { + using Type = half; +}; + +template <> +struct TypeConverter { + using Type = half2; +}; + +template <> +struct TypeConverter<__nv_bfloat162> { + using Type = __nv_bfloat16; +}; + +template <> +struct TypeConverter<__nv_bfloat16> { + using Type = __nv_bfloat162; +}; + +#define ELTS_PER_THREAD 8 + +constexpr int CVT_FP4_ELTS_PER_THREAD = 8; +constexpr int CVT_FP4_SF_VEC_SIZE = 16; + +// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t). +inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + uint32_t val; + asm volatile( + "{\n" + ".reg .b8 byte0;\n" + ".reg .b8 byte1;\n" + ".reg .b8 byte2;\n" + ".reg .b8 byte3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" + "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" + "}" + : "=r"(val) + : "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]), + "f"(array[4]), "f"(array[5]), "f"(array[6]), "f"(array[7])); + return val; +#else + return 0; +#endif +} + +// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t). 
+inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + uint32_t val; + asm volatile( + "{\n" + ".reg .b8 byte0;\n" + ".reg .b8 byte1;\n" + ".reg .b8 byte2;\n" + ".reg .b8 byte3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" + "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" + "}" + : "=r"(val) + : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y), + "f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y)); + return val; +#else + return 0; +#endif +} + +// Fast reciprocal. +inline __device__ float reciprocal_approximate_ftz(float a) { + float b; + asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a)); + return b; +} + +template +__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx, + int numCols, + SFType* SFout) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 || + CVT_FP4_NUM_THREADS_PER_SF == 2); + + // One pair of threads write one SF to global memory. + // TODO: stage through smem for packed STG.32 + // is it better than STG.8 from 4 threads ? + if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) { + // SF vector index (16 elements share one SF in the K dimension). + int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF; + int32_t mIdx = rowIdx; + + // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)] + // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx] + + int32_t mTileIdx = mIdx / (32 * 4); + // SF vector size 16. + int factor = CVT_FP4_SF_VEC_SIZE * 4; + int32_t numKTiles = (numCols + factor - 1) / factor; + int64_t mTileStride = numKTiles * 32 * 4 * 4; + + int32_t kTileIdx = (kIdx / 4); + int64_t kTileStride = 32 * 4 * 4; + + // M tile layout [32, 4] is column-major. 
+ int32_t outerMIdx = (mIdx % 32); + int64_t outerMStride = 4 * 4; + + int32_t innerMIdx = (mIdx % (32 * 4)) / 32; + int64_t innerMStride = 4; + + int32_t innerKIdx = (kIdx % 4); + int64_t innerKStride = 1; + + // Compute the global offset. + int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride + + outerMIdx * outerMStride + innerMIdx * innerMStride + + innerKIdx * innerKStride; + + return reinterpret_cast(SFout) + SFOffset; + } +#endif + return nullptr; +} + +// Define a 16 bytes packed data type. +template +struct PackedVec { + typename TypeConverter::Type elts[4]; +}; + +template <> +struct PackedVec<__nv_fp8_e4m3> { + __nv_fp8x2_e4m3 elts[8]; +}; + +// Quantizes the provided PackedVec into the uint32_t output +template +__device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec& vec, float SFScaleVal, + uint8_t* SFout) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + // Get absolute maximum values among the local 8 values. + auto localMax = __habs2(vec.elts[0]); + + // Local maximum value. + #pragma unroll + for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { + localMax = __hmax2(localMax, __habs2(vec.elts[i])); + } + + // Get the absolute maximum among all 16 values (two threads). + localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax); + // Get the final absolute maximum values. + float vecMax = float(__hmax(localMax.x, localMax.y)); + + // Get the SF (max value of the vector / max value of e2m1). + // maximum value of e2m1 = 6.0. + // TODO: use half as compute data type. + float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f)); + // 8 bits representation of the SF. + uint8_t fp8SFVal; + // Write the SF to global memory (STG.8). + if constexpr (UE8M0_SF) { + // Extract the 8 exponent bits from float32. + // float 32bits = 1 sign bit + 8 exponent bits + 23 mantissa bits. + uint32_t tmp = reinterpret_cast(SFValue) >> 23; + fp8SFVal = tmp & 0xff; + // Convert back to fp32. 
+ reinterpret_cast(SFValue) = tmp << 23; + } else { + // Here SFValue is always positive, so E4M3 is the same as UE4M3. + __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue); + reinterpret_cast<__nv_fp8_e4m3&>(fp8SFVal) = tmp; + // Convert back to fp32. + SFValue = float(tmp); + } + // Get the output scale. + // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) * + // reciprocal(SFScaleVal)) + float outputScale = + SFValue != 0 ? reciprocal_approximate_ftz( + SFValue * reciprocal_approximate_ftz(SFScaleVal)) + : 0.0f; + + if (SFout) { + // Write the SF to global memory (STG.8). + *SFout = fp8SFVal; + } + + // Convert the input to float. + float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2]; + + #pragma unroll + for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { + if constexpr (std::is_same_v) { + fp2Vals[i] = __half22float2(vec.elts[i]); + } else { + fp2Vals[i] = __bfloat1622float2(vec.elts[i]); + } + fp2Vals[i].x *= outputScale; + fp2Vals[i].y *= outputScale; + } + + // Convert to e2m1 values. + uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals); + + // Write the e2m1 values to global memory. + return e2m1Vec; +#else + return 0; +#endif +} + +// Use UE4M3 by default. +template +__global__ void +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +__launch_bounds__(512, 4) cvt_fp16_to_fp4( +#else +cvt_fp16_to_fp4( +#endif + int32_t numRows, int32_t numCols, Type const* in, float const* SFScale, + uint32_t* out, uint32_t* SFout) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + using PackedVec = PackedVec; + static constexpr int CVT_FP4_NUM_THREADS_PER_SF = + (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); + static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, + "Vec size is not matched."); + + // Get the global scaling factor, which will be applied to the SF. + // Note SFScale is the same as next GEMM's alpha, which is + // (448.f / (Alpha_A / 6.f)). + float const SFScaleVal = SFScale == nullptr ? 
1.0f : SFScale[0]; + + // Input tensor row/col loops. + for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) { + for (int colIdx = threadIdx.x; colIdx < numCols / CVT_FP4_ELTS_PER_THREAD; + colIdx += blockDim.x) { + int64_t inOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx; + PackedVec in_vec = reinterpret_cast(in)[inOffset]; + // Get the output tensor offset. + // Same as inOffset because 8 elements are packed into one uint32_t. + int64_t outOffset = inOffset; + auto& out_pos = out[outOffset]; + + auto sf_out = + cvt_quant_to_fp4_get_sf_out_offset( + rowIdx, colIdx, numCols, SFout); + + out_pos = + cvt_warp_fp16_to_fp4(in_vec, SFScaleVal, sf_out); + } + } +#endif +} + +template +void invokeFP4Quantization(int m, int n, T const* input, float const* SFScale, + int64_t* output, int32_t* SFOuput, bool useUE8M0, + int multiProcessorCount, cudaStream_t stream) { + // Grid, Block size. + // Each thread converts 8 values. + dim3 block(std::min(int(n / ELTS_PER_THREAD), 512)); + // Get number of blocks per SM (assume we can fully utilize the SM). + int const numBlocksPerSM = 2048 / block.x; + dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM)); + + // Launch the cvt kernel. + if (useUE8M0) { + cvt_fp16_to_fp4<<>>( + m, n, input, SFScale, reinterpret_cast(output), + reinterpret_cast(SFOuput)); + } else { + cvt_fp16_to_fp4<<>>( + m, n, input, SFScale, reinterpret_cast(output), + reinterpret_cast(SFOuput)); + } +} + +// Instantiate the function. 
+template void invokeFP4Quantization(int m, int n, half const* input, + float const* SFScale, int64_t* output, + int32_t* SFOuput, bool useUE8M0, + int multiProcessorCount, + cudaStream_t stream); + +template void invokeFP4Quantization(int m, int n, __nv_bfloat16 const* input, + float const* SFScale, int64_t* output, + int32_t* SFOuput, bool useUE8M0, + int multiProcessorCount, + cudaStream_t stream); + +void scaled_fp4_quant_sm100a(torch::Tensor const& output, + torch::Tensor const& input, + torch::Tensor const& output_sf, + torch::Tensor const& input_sf) { + int32_t m = input.size(0); + int32_t n = input.size(1); + + TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16."); + + int multiProcessorCount = + get_device_attribute(cudaDevAttrMultiProcessorCount, -1); + + auto input_sf_ptr = static_cast(input_sf.data_ptr()); + auto sf_out = static_cast(output_sf.data_ptr()); + auto output_ptr = static_cast(output.data_ptr()); + at::cuda::CUDAGuard device_guard{(char)input.get_device()}; + auto stream = at::cuda::getStreamFromPool(false, input.get_device()); + if (stream == nullptr) { + std::cerr << "Warning: Null CUDA stream" << std::endl; + } + + // We don't support e8m0 scales at this moment. 
+ bool useUE8M0 = false; + + switch (input.scalar_type()) { + case torch::kHalf: { + auto input_ptr = reinterpret_cast(input.data_ptr()); + invokeFP4Quantization(m, n, input_ptr, input_sf_ptr, output_ptr, sf_out, + useUE8M0, multiProcessorCount, stream); + break; + } + case torch::kBFloat16: { + auto input_ptr = reinterpret_cast<__nv_bfloat16 const*>(input.data_ptr()); + invokeFP4Quantization(m, n, input_ptr, input_sf_ptr, output_ptr, sf_out, + useUE8M0, multiProcessorCount, stream); + break; + } + default: { + std::cerr << "Observing: " << input.scalar_type() + << " for the input datatype which is invalid"; + throw std::runtime_error( + "Unsupported input data type for quantize_to_fp4."); + } + } +} diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index c03806f43..784ded262 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -423,6 +423,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA, &dynamic_per_token_scaled_fp8_quant); + // Compute NVFP4 block quantized tensor. + ops.def( + "scaled_fp4_quant(Tensor! output, Tensor input," + " Tensor! output_scale, Tensor input_scale) -> ()"); + ops.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant); + // Compute int8 quantized tensor for given scaling factor. ops.def( "static_scaled_int8_quant(Tensor! 
result, Tensor input, Tensor scale," diff --git a/tests/kernels/test_nvfp4_quant.py b/tests/kernels/test_nvfp4_quant.py new file mode 100644 index 000000000..93735fc09 --- /dev/null +++ b/tests/kernels/test_nvfp4_quant.py @@ -0,0 +1,149 @@ +# SPDX-License-Identifier: Apache-2.0 +import pytest +import torch + +from vllm import _custom_ops as ops +from vllm.platforms import current_platform +from vllm.scalar_type import scalar_types + +if not current_platform.has_device_capability(100): + pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.", + allow_module_level=True) + +DTYPES = [torch.float16, torch.bfloat16] +SHAPES = [(128, 64), (128, 128), (256, 64), (256, 128)] +PAD_SHAPES = [(90, 64), (150, 64), (128, 48), (128, 80), (150, 80), (90, 48), + (90, 128), (150, 128), (150, 48), (90, 80)] +SEEDS = [42] +CUDA_DEVICES = ['cuda:0'] + +FLOAT4_E2M1_MAX = scalar_types.float4_e2m1fn.max() +FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max + +# E2M1 to float +# 0111 -> 6 +# 0110 -> 4 +# 0101 -> 3 +# 0100 -> 2 +# 0011 -> 1.5 +# 0010 -> 1 +# 0001 -> 0.5 +# 0000 -> 0 +E2M1_TO_FLOAT32 = [ + 0., 0.5, 1., 1.5, 2., 3., 4., 6., 0., -0.5, -1., -1.5, -2., -3., -4., -6. 
+] +BLOCK_SIZE = 16 + + +def cast_from_fp4(x, m, n): + # The fp4 values are packed in uint8 as [v_1st | v_2nd] + v_2nd = x & 0xF + v_1st = (x >> 4) & 0xF + c = torch.stack((v_2nd, v_1st), dim=-1) + out = torch.tensor([E2M1_TO_FLOAT32[x] for x in c.flatten()]) + out = out.reshape(m, n).to(torch.float32) + return out + + +def cast_to_fp4(x): + sign = torch.sign(x) + x = torch.abs(x) + x[(x >= 0.0) & (x <= 0.25)] = 0.0 + x[(x > 0.25) & (x < 0.75)] = 0.5 + x[(x >= 0.75) & (x <= 1.25)] = 1.0 + x[(x > 1.25) & (x < 1.75)] = 1.5 + x[(x >= 1.75) & (x <= 2.5)] = 2.0 + x[(x > 2.5) & (x < 3.5)] = 3.0 + x[(x >= 3.5) & (x <= 5.0)] = 4.0 + x[x > 5.0] = 6.0 + return x * sign + + +def get_reciprocal(x): + if isinstance(x, torch.Tensor): + return torch.where(x == 0, torch.tensor(0.0, dtype=x.dtype), 1.0 / x) + elif isinstance(x, (float, int)): + return 0.0 if x == 0 else 1.0 / x + else: + raise TypeError("Input must be a float, int, or a torch.Tensor.") + + +def ref_nvfp4_quant(x, global_scale): + assert global_scale.dtype == torch.float32 + assert x.ndim == 2 + m, n = x.shape + x = torch.reshape(x, (m, n // BLOCK_SIZE, BLOCK_SIZE)) + vec_max = torch.max(torch.abs(x), dim=-1, + keepdim=True)[0].to(torch.float32) + scale = global_scale * (vec_max * get_reciprocal(FLOAT4_E2M1_MAX)) + scale = scale.to(torch.float8_e4m3fn).to(torch.float32) + output_scale = get_reciprocal(scale * get_reciprocal(global_scale)) + + scaled_x = x.to(torch.float32) * output_scale + clipped_x = torch.clamp(scaled_x, -6.0, 6.0).reshape(m, n) + return cast_to_fp4(clipped_x), scale.squeeze(-1) + + +def recover_swizzled_scales(scale, m, n): + round_up = lambda x, y: (x + y - 1) // y * y + rounded_m = round_up(m, 128) + scale_n = n // BLOCK_SIZE + rounded_n = round_up(scale_n, 4) + # Recover the swizzled scaling factor to linear layout + tmp = torch.reshape(scale, (1, rounded_m // 128, rounded_n // 4, 32, 4, 4)) + tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5)) + result = torch.reshape(tmp, (rounded_m, 
rounded_n)).to(torch.float32) + return result[:m, :scale_n] + + +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("shape", SHAPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_quantize_to_fp4( + dtype: torch.dtype, + shape: tuple[int, int], + seed: int, + device: str, +) -> None: + current_platform.seed_everything(seed) + torch.set_default_device(device) + + m, n = shape + + x = torch.randn((m, n), dtype=dtype) + tensor_amax = torch.abs(x).max().to(torch.float32) + global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax + out_ref, scale_ref = ref_nvfp4_quant(x, global_scale) + + out, out_scale = ops.scaled_fp4_quant(x, global_scale) + scale_ans = recover_swizzled_scales(out_scale, m, n) + out_ans = cast_from_fp4(out, m, n) + + torch.testing.assert_close(out_ans, out_ref) + torch.testing.assert_close(scale_ans, scale_ref) + + +@pytest.mark.parametrize("pad_shape", PAD_SHAPES) +@torch.inference_mode() +def test_quantize_to_fp4_padded(pad_shape: tuple[int, int]) -> None: + dtype = torch.float16 + current_platform.seed_everything(42) + torch.set_default_device('cuda:0') + + m, n = pad_shape + + x = torch.randn((m, n), dtype=dtype) + + tensor_amax = torch.abs(x).max().to(torch.float32) + global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax + out_ref, scale_ref = ref_nvfp4_quant(x, global_scale) + + out, out_scale = ops.scaled_fp4_quant(x, global_scale) + + scale_ans = recover_swizzled_scales(out_scale, m, n) + out_ans = cast_from_fp4(out, m, n) + + torch.testing.assert_close(out_ans, out_ref) + torch.testing.assert_close(scale_ans, scale_ref) diff --git a/tests/test_scalartype.py b/tests/test_scalartype.py index 6e36f2c33..d0e57ea86 100644 --- a/tests/test_scalartype.py +++ b/tests/test_scalartype.py @@ -11,6 +11,7 @@ from vllm.scalar_type import scalar_types (0, 15, scalar_types.uint4), (-8, 7, scalar_types.uint4b8), (-128, 127, scalar_types.uint8b128), 
+ (-6., 6., scalar_types.float4_e2m1fn), (-28., 28., scalar_types.float6_e3m2f), (torch.int8, scalar_types.int8), (torch.uint8, scalar_types.uint8), diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index a68235016..67843c177 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -765,6 +765,63 @@ def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor: return torch.ops._C.permute_cols(a, perm) +# fp4 +def scaled_fp4_quant( + input: torch.Tensor, + input_global_scale: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Quantize input tensor to FP4 and return quantized tensor and scale. + + This function quantizes the last dimension of the given tensor `input`. For + every 16 consecutive elements, a single dynamically computed scaling factor + is shared. This scaling factor is quantized using the `input_global_scale` + and is stored in a swizzled layout (see + https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x). + + Args: + input: The input tensor to be quantized to FP4 + input_global_scale: A scalar scaling factor for the entire tensor. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP4 but every + two values are packed into a uint8 and float8_e4m3 scaling factors + in the sizzled layout. + """ + assert input.ndim >= 1, ( + f'input.ndim needs to be >= 1, but got {input.ndim}.') + other_dims = 1 if input.ndim == 1 else -1 + input = input.reshape(other_dims, input.shape[-1]) + m, n = input.shape + block_size = 16 + device = input.device + + assert n % block_size == 0, ( + f'last dim has to be multiple of 16, but got {n}.') + assert input.dtype in (torch.float16, torch.bfloat16), ( + f'input.dtype needs to be fp16 or bf16 but got {input.dtype}.') + + # Two fp4 values will be packed into an uint8. + output = torch.empty((m, n // 2), device=device, dtype=torch.uint8) + + # We use the rounded values to store the swizzled values. 
Due to the + # requirement of the Tensor Core, the minimum tile is 128x4 for the scales. + # So, we first pad the scales to multiples of 128 and 4. Then, the scales + # (in float8_e4m3fn) are packed into an int32 for every 4 values. More: + # https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x + round_up = lambda x, y: (x + y - 1) // y * y + rounded_m = round_up(m, 128) + scale_n = n // block_size + rounded_n = round_up(scale_n, 4) + output_scale = torch.empty((rounded_m, rounded_n // 4), + device=device, + dtype=torch.int32) + + torch.ops._C.scaled_fp4_quant(output, input, output_scale, + input_global_scale) + output_scale = output_scale.view(torch.float8_e4m3fn) + return output, output_scale + + # fp8 def scaled_fp8_quant( input: torch.Tensor, diff --git a/vllm/scalar_type.py b/vllm/scalar_type.py index 9f6e85920..1d7675dda 100644 --- a/vllm/scalar_type.py +++ b/vllm/scalar_type.py @@ -321,6 +321,9 @@ class scalar_types: # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE) + # fp4, https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf + float4_e2m1fn = ScalarType.float_(2, 1, True, NanRepr.NONE) + # "gptq" types uint2b2 = ScalarType.uint(2, 2) uint3b4 = ScalarType.uint(3, 4) -- GitLab From d88c8666a167c97193cb3d48fd2de3d90b5b8d85 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Wed, 12 Feb 2025 22:52:11 -0500 Subject: [PATCH 106/253] [Bugfix][Example] Fix GCed profiling server for TPU (#12792) Signed-off-by: mgoin --- examples/offline_inference/profiling_tpu/profiling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/offline_inference/profiling_tpu/profiling.py b/examples/offline_inference/profiling_tpu/profiling.py index b1fe829b3..d54117d62 100644 --- a/examples/offline_inference/profiling_tpu/profiling.py +++ b/examples/offline_inference/profiling_tpu/profiling.py @@ -24,7 +24,7 @@ def main(args: 
argparse.Namespace): engine_args = EngineArgs.from_cli_args(args) llm = LLM(**dataclasses.asdict(engine_args)) - _ = xp.start_server(9012) + server = xp.start_server(9012) # noqa: F841 sampling_params = SamplingParams( temperature=0.0, -- GitLab From bc55d13070a9f3f2b8524901817d406cb8a37a3b Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Thu, 13 Feb 2025 12:26:21 +0800 Subject: [PATCH 107/253] [VLM] Implement merged multimodal processor for Mllama (#11427) --- .../vision_language/test_mllama.py | 71 ++- .../multimodal/processing/test_common.py | 13 +- vllm/inputs/preprocess.py | 90 +++- vllm/inputs/registry.py | 3 +- vllm/model_executor/models/mllama.py | 408 +++++++++--------- vllm/multimodal/inputs.py | 16 + vllm/multimodal/processing.py | 60 ++- vllm/multimodal/profiling.py | 28 +- 8 files changed, 456 insertions(+), 233 deletions(-) diff --git a/tests/models/encoder_decoder/vision_language/test_mllama.py b/tests/models/encoder_decoder/vision_language/test_mllama.py index 4cd2dbdb4..202516f4c 100644 --- a/tests/models/encoder_decoder/vision_language/test_mllama.py +++ b/tests/models/encoder_decoder/vision_language/test_mllama.py @@ -7,11 +7,11 @@ import torch from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer, BatchEncoding) +from vllm import LLM, SamplingParams from vllm.attention.backends.flash_attn import FlashAttentionMetadata from vllm.attention.selector import (_Backend, _cached_get_attn_backend, global_force_attn_backend_context_manager) -from vllm.model_executor.models.mllama import (MLLAMA_IMAGE_TOKEN_ID, - MllamaForConditionalGeneration) +from vllm.model_executor.models.mllama import MllamaForConditionalGeneration from vllm.multimodal.image import rescale_image_size from vllm.sequence import SampleLogprobs @@ -21,6 +21,7 @@ from ....utils import large_gpu_test from ...utils import check_logprobs_close _LIMIT_IMAGE_PER_PROMPT = 3 +MLLAMA_IMAGE_TOKEN_ID = 128256 LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, 
_Backend.FLASH_ATTN] @@ -396,6 +397,64 @@ def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model, ) +@large_gpu_test(min_gb=48) +@pytest.mark.core_model +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("max_tokens", [32]) +def test_explicit_implicit_prompt( + image_assets: _ImageAssets, + model: str, + dtype: str, + max_tokens: int, +): + stop_sign = image_assets[0].pil_image + # yapf: disable + prompts = [ + # explicit prompt + { + "encoder_prompt": { + "prompt": "<|image|>", + "multi_modal_data": {"image": stop_sign}, + }, + "decoder_prompt": { + "prompt_token_ids": [128000, 791, 2262, 315, 279, 2217, 220, 128256, 374], # noqa: E501 + } + }, + { + "encoder_prompt": "Not <|image|>", + "decoder_prompt": "The color of the sky is blue but sometimes it can also be", # noqa: E501 + }, + # implicit prompt + { + "prompt": "<|begin_of_text|>The content of the image <|image|> is", # noqa: E501 + "multi_modal_data": {"image": stop_sign}, + }, + { + "prompt": "The color of the sky is blue but sometimes it can also be", # noqa: E501 + }, + ] + # yapf: enable + llm = LLM( + model=model, + dtype=dtype, + max_model_len=4096, + max_num_seqs=2, + tensor_parallel_size=1, + enforce_eager=True, + ) + sampling_params = SamplingParams( + temperature=0, + max_tokens=max_tokens, + ) + outputs = llm.generate(prompts, sampling_params) + n_prompts = len(prompts) + explicit_outputs = outputs[:n_prompts // 2] + implicit_outputs = outputs[n_prompts // 2:] + for exp_output, imp_output in zip(explicit_outputs, implicit_outputs): + assert exp_output.outputs[0].text == imp_output.outputs[0].text + + @large_gpu_test(min_gb=48) @pytest.mark.core_model @pytest.mark.parametrize("model", models) @@ -458,6 +517,10 @@ def test_regression(vllm_runner, image_assets, model, dtype, max_tokens, images=images) +class DummyModel: + image_token_id = MLLAMA_IMAGE_TOKEN_ID + + @pytest.mark.core_model 
@pytest.mark.parametrize( "input_indices_and_output", @@ -499,7 +562,7 @@ def test_get_cross_attention_mask(input_indices_and_output) -> None: use_cuda_graph=False, ) - dummy: dict[str, str] = {} + dummy = DummyModel() cross_attention_mask, kv_range_for_decode = MllamaForConditionalGeneration\ .get_cross_attention_mask(dummy, @@ -556,7 +619,7 @@ def test_get_full_text_row_masked_out_mask(input_indices) -> None: use_cuda_graph=False, ) - dummy: dict[str, str] = {} + dummy = DummyModel() full_text_row_masked_out_mask = MllamaForConditionalGeneration\ .get_full_text_row_masked_out_mask(dummy, diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index 6244056c7..67ef8b17a 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -85,6 +85,14 @@ def _test_processing_correctness( partial(random_audio, rng, min_len=512, max_len=1024, sr=16000), } + tokenizer_encode_kwargs = {} + if model_config.hf_config.model_type == "mllama": + # For Mllama, tokenizer will always add bos_token at the beginning of + # prompt by default, causing hf_processor outputs incorrect token ids. + # So we need use `add_special_tokens=False` here to leave bos_token + # to be added by the processor. 
+ tokenizer_encode_kwargs = {"add_special_tokens": False} + for batch_idx in range(num_batches): mm_data = { k: @@ -122,7 +130,7 @@ def _test_processing_correctness( f"Failed ({batch_idx=}, {prompt=}, {mm_data=})") baseline_tokenized_result = baseline_processor.apply( - tokenizer.encode(prompt), + tokenizer.encode(prompt, **tokenizer_encode_kwargs), mm_data=mm_data, hf_processor_mm_kwargs={}, ) @@ -131,7 +139,7 @@ def _test_processing_correctness( f"Failed ({batch_idx=}, {prompt=}, {mm_data=})") cached_tokenized_result = cached_processor.apply( - tokenizer.encode(prompt), + tokenizer.encode(prompt, **tokenizer_encode_kwargs), mm_data=mm_data, hf_processor_mm_kwargs={}, ) @@ -155,6 +163,7 @@ def _test_processing_correctness( "llava-hf/llava-v1.6-mistral-7b-hf", "llava-hf/LLaVA-NeXT-Video-7B-hf", "llava-hf/llava-onevision-qwen2-0.5b-ov-hf", + "meta-llama/Llama-3.2-11B-Vision-Instruct", "TIGER-Lab/Mantis-8B-siglip-llama3", "mistral-community/pixtral-12b", "openbmb/MiniCPM-o-2_6", diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 656f2f2b7..bc5856990 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import asyncio -from typing import List, Mapping, Optional, Union +from typing import List, Mapping, Optional, Tuple, Union, cast from typing_extensions import assert_never @@ -9,7 +9,8 @@ from vllm.config import ModelConfig from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry -from vllm.multimodal.inputs import MultiModalDataDict, MultiModalInputs +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs, + MultiModalInputs) from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup @@ -495,6 +496,51 @@ class InputPreprocessor: decoder=decoder_inputs, ) + def 
_separate_enc_dec_inputs_from_mm_processor_outputs( + self, + inputs: SingletonInputs, + decoder_inputs_to_override: Optional[SingletonInputs] = None, + ) -> Tuple[SingletonInputs, SingletonInputs]: + """ + For encoder/decoder models only: + Separate Encoder/Decoder inputs from a MultiModalEncDecInputs + """ + encoder_inputs: SingletonInputs + decoder_inputs: SingletonInputs + if inputs["type"] == "multimodal": + # Multimodal data inputs + assert ("encoder_prompt" in inputs + and "encoder_prompt_token_ids" in inputs) + inputs = cast(MultiModalEncDecInputs, inputs) + encoder_inputs = token_inputs( + prompt=inputs["encoder_prompt"], + prompt_token_ids=inputs["encoder_prompt_token_ids"], + ) + if decoder_inputs_to_override is not None: + decoder_inputs = MultiModalInputs( + type="multimodal", + prompt=decoder_inputs_to_override.get("prompt", ""), + prompt_token_ids=decoder_inputs_to_override[ + "prompt_token_ids"], + mm_kwargs=inputs["mm_kwargs"], + mm_placeholders=inputs["mm_placeholders"], + ) + else: + decoder_inputs = MultiModalInputs( + type="multimodal", + prompt=inputs["prompt"], + prompt_token_ids=inputs["prompt_token_ids"], + mm_kwargs=inputs["mm_kwargs"], + mm_placeholders=inputs["mm_placeholders"], + ) + elif inputs["type"] == "token": + # Text-only inputs + encoder_inputs = token_inputs(prompt="", prompt_token_ids=[]) + decoder_inputs = decoder_inputs_to_override or inputs + else: + assert_never(inputs) # type: ignore[arg-type] + return encoder_inputs, decoder_inputs + def _process_encoder_decoder_prompt( self, prompt: PromptType, @@ -539,7 +585,6 @@ class InputPreprocessor: prompt["encoder_prompt"], request_id=request_id, ) - if (decoder_input := prompt["decoder_prompt"]) is None: decoder_inputs = None else: @@ -547,13 +592,28 @@ class InputPreprocessor: decoder_input, request_id=request_id, ) + # For multimodal model, override decoder prompt from processor + # with explicit decoder prompt. 
+ if self.model_config.is_multimodal_model and ( + self._can_process_multimodal()): + encoder_inputs, decoder_inputs = ( + self._separate_enc_dec_inputs_from_mm_processor_outputs( + encoder_inputs, decoder_inputs)) else: - encoder_inputs = self._prompt_to_llm_inputs( + inputs = self._prompt_to_llm_inputs( prompt, request_id=request_id, ) + if self.model_config.is_multimodal_model and ( + self._can_process_multimodal()): + # Encoder-Decoder Multimodal model + encoder_inputs, decoder_inputs = ( + self._separate_enc_dec_inputs_from_mm_processor_outputs( + inputs)) + else: + encoder_inputs = inputs - decoder_inputs = None + decoder_inputs = None return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs) @@ -583,13 +643,29 @@ class InputPreprocessor: encoder_inputs, decoder_inputs = await asyncio.gather( encoder_task, decoder_task) + + # For multimodal model, override decoder prompt from processor + # with explicit decoder prompt. + if self.model_config.is_multimodal_model and ( + self._can_process_multimodal()): + encoder_inputs, decoder_inputs = ( + self._separate_enc_dec_inputs_from_mm_processor_outputs( + encoder_inputs, decoder_inputs)) else: - encoder_inputs = await self._prompt_to_llm_inputs_async( + inputs = await self._prompt_to_llm_inputs_async( prompt, request_id=request_id, ) + if self.model_config.is_multimodal_model and ( + self._can_process_multimodal()): + # Encoder-Decoder Multimodal model + encoder_inputs, decoder_inputs = ( + self._separate_enc_dec_inputs_from_mm_processor_outputs( + inputs)) + else: + encoder_inputs = inputs - decoder_inputs = None + decoder_inputs = None return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index cd4214439..87b7a7631 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -350,7 +350,8 @@ class InputRegistry: ) processor = mm_registry.create_processor(model_config, tokenizer) profiler = 
MultiModalProfiler(processor) - dummy_data = profiler.get_dummy_data(seq_len) + dummy_data = profiler.get_dummy_data( + seq_len, is_encoder_data=is_encoder_data) else: model_cls, _ = get_model_architecture(model_config) if is_encoder_data: diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index d1cb04cdb..3ca22d346 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -23,14 +23,15 @@ import torch import torch.nn.functional as F import torch.utils.checkpoint import transformers.models.mllama.configuration_mllama as config_mllama -from PIL import Image +from PIL.Image import Image from torch import nn +from transformers import BatchFeature, MllamaConfig from transformers.modeling_outputs import (BaseModelOutput, CausalLMOutputWithPast) from transformers.models.mllama.image_processing_mllama import ( get_optimal_tiled_canvas) from transformers.models.mllama.processing_mllama import ( - get_cross_attention_token_mask) + MllamaProcessor, get_cross_attention_token_mask) import vllm.distributed.parallel_state as ps from vllm.attention import Attention, AttentionMetadata, AttentionType @@ -38,8 +39,6 @@ from vllm.attention.ops.paged_attn import PagedAttention from vllm.attention.selector import _Backend from vllm.config import VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.inputs import (INPUT_REGISTRY, DummyData, EncoderDecoderInputs, - InputContext, TokenInputs, token_inputs) from vllm.logger import init_logger from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -54,8 +53,13 @@ from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.sequence import SequenceData -from vllm.utils import is_list_of 
+from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs +from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, + MultiModalDataDict, MultiModalDataItems) +from vllm.multimodal.processing import (BaseProcessingInfo, + EncDecMultiModalProcessor, + PromptReplacement) +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from .clip import CLIPMLP from .interfaces import SupportsMultiModal @@ -63,8 +67,6 @@ from .llama import LlamaDecoderLayer, LlamaMLP from .utils import maybe_prefix logger = init_logger(__name__) -MLLAMA_IMAGE_TOKEN_ID = 128256 -MLLAMA_IMAGE_TOKEN = "<|image|>" class MllamaImagePixelInputs(TypedDict): @@ -81,158 +83,191 @@ class MllamaImagePixelInputs(TypedDict): # TODO: support LlamaImageEmbeddingInputs -def _get_num_image_in_last_group(prompt_token_ids: List[int]) -> int: - num_images = 0 - for token_id in prompt_token_ids[::-1]: - if token_id == MLLAMA_IMAGE_TOKEN_ID: - num_images += 1 - elif num_images > 0: - break - return num_images - - -def input_processor_for_mllama( - ctx: InputContext, - inputs: EncoderDecoderInputs, -) -> EncoderDecoderInputs: - # Example input to processor: - # { - # 'encoder': { - # 'type': 'token', - # 'prompt_token_ids': [128000, 128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30], # noqa: E501 - # 'prompt': '<|image|><|begin_of_text|>What is the content of this image?', # noqa: E501 - # 'multi_modal_data': {'image': }, # noqa: E501 - # }, - # 'decoder': { - # 'type': 'token', - # 'prompt_token_ids': [128000], - # }, - # } - - # move encoder prompt to decoder - dec_inputs = TokenInputs(**inputs["encoder"]) - - multi_modal_data = dec_inputs.get("multi_modal_data") - if multi_modal_data is None or "image" not in multi_modal_data: - # text-only - return EncoderDecoderInputs( - encoder=token_inputs([]), - decoder=dec_inputs, +def calc_token_per_chunk(image_size: int) -> int: + assert image_size % 14 == 0, "chunk size should be multiple of 14" + token_per_chunk 
= (image_size // 14)**2 + 1 + return token_per_chunk + + +class MllamaProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self) -> MllamaConfig: + return self.ctx.get_hf_config(MllamaConfig) + + def get_hf_processor(self) -> MllamaProcessor: + return self.ctx.get_hf_processor(MllamaProcessor) + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None} + + def get_token_per_chunk_from_config(self) -> int: + image_size = self.get_hf_config().vision_config.image_size + return calc_token_per_chunk(image_size) + + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + vision_config = self.get_hf_config().vision_config + token_per_chunk = self.get_token_per_chunk_from_config() + mm_max_tokens = vision_config.max_num_tiles * token_per_chunk + return {"image": mm_max_tokens} + + def get_num_tiles_per_image(self, image_height: int, + image_width: int) -> int: + vision_config = self.get_hf_config().vision_config + max_num_tiles = vision_config.max_num_tiles + image_size = vision_config.image_size + tiled_height, tiled_width = get_optimal_tiled_canvas( + image_height, + image_width, + max_num_tiles, + tile_size=image_size, + ) + num_tiles_height = tiled_height // image_size + num_tiles_width = tiled_width // image_size + return num_tiles_height * num_tiles_width + + def get_image_size_with_most_features(self) -> ImageSize: + vision_config = self.get_hf_config().vision_config + image_size = vision_config.image_size + max_num_tiles = vision_config.max_num_tiles + # Result in the max possible feature size (h:w = 16:1) + return ImageSize(height=max_num_tiles * image_size, width=image_size) + + +class MllamaDummyInputsBuilder(BaseDummyInputsBuilder[MllamaProcessingInfo]): + + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + num_images = mm_counts.get("image", 0) + + target_width, target_height = \ + 
self.info.get_image_size_with_most_features() + + mm_data = { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images) + } + + hf_processor = self.info.get_hf_processor() + image_token: str = hf_processor.image_token + + return ProcessorInputs( + prompt_text=image_token * num_images, + mm_data=mm_data, ) - image_data = multi_modal_data["image"] - if isinstance(image_data, Image.Image): - image_data = [image_data] - - assert is_list_of(image_data, Image.Image) - - num_image_tokens = dec_inputs['prompt_token_ids'].count( - MLLAMA_IMAGE_TOKEN_ID) - if num_image_tokens != len(image_data): - raise ValueError( - f"The number of image tokens ({num_image_tokens}) must be" - f" the same as the number of images ({len(image_data)})") - - # Since only the last group of consecutive images - # are attended by the decoded tokens, we only need to - # get the number of tiles for those images. - num_decode_images = _get_num_image_in_last_group( - dec_inputs["prompt_token_ids"]) - - hf_config = ctx.model_config.hf_config - vision_config = hf_config.vision_config - - num_tiles = 0 - for image in image_data[::-1]: - width, height = image.size - tile_size = vision_config.image_size - canvas_height, canvas_width = get_optimal_tiled_canvas( - image_height=height, - image_width=width, - max_image_tiles=vision_config.max_num_tiles, - tile_size=tile_size, + +class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo] + ): + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + tokenizer = self.info.get_tokenizer() + if mm_data: + num_tiles = [ + self.info.get_num_tiles_per_image(img.height, img.width) + for img in mm_data["images"] + ] + processed_outputs = super()._call_hf_processor( + prompt, mm_data, mm_kwargs) + processed_outputs["num_tiles"] = torch.tensor(num_tiles) + for k in ('pixel_values', 'aspect_ratio_ids', "aspect_ratio_mask"): 
+ processed_outputs[k] = processed_outputs[k].squeeze(0) + # Example input to encoder and decoder: + # { + # 'encoder': { + # 'type': 'token', + # 'prompt_token_ids': [128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30], # noqa: E501 + # 'prompt': '<|image|><|begin_of_text|>What is the content of this image?', # noqa: E501 + # 'multi_modal_data': {'image': }, # noqa: E501 + # }, + # 'decoder': { + # 'type': 'token', + # 'prompt_token_ids': [128000], + # }, + # } + processed_token_ids = processed_outputs.pop("input_ids") + start_idx, end_idx = 0, processed_token_ids.size(1) + processed_prompt_text = tokenizer.decode(processed_token_ids[0]) + + hf_processor = self.info.get_hf_processor() + bos_token = hf_processor.bos_token + # Remove the bos_token from the start of prompt, + # because we all know there would be image_token. + if processed_prompt_text.startswith(bos_token): + start_idx += 1 + # Remove the bos_token from the end of prompt, + # because text is empty in this case. + if processed_prompt_text.endswith(bos_token): + end_idx -= 1 + processed_outputs[ + "input_ids"] = processed_token_ids[:, start_idx:end_idx] + else: + processed_outputs = tokenizer(prompt, + add_special_tokens=False, + return_tensors="pt") + return processed_outputs + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + pixel_values=MultiModalFieldConfig.batched("image"), + aspect_ratio_ids=MultiModalFieldConfig.batched("image"), + aspect_ratio_mask=MultiModalFieldConfig.batched("image"), + num_tiles=MultiModalFieldConfig.batched("image"), ) - num_tiles_height = canvas_height // tile_size - num_tiles_width = canvas_width // tile_size - num_tiles += num_tiles_height * num_tiles_width - num_decode_images -= 1 - if num_decode_images == 0: - break - - # Set encoder prompt length based on the number of tiles. 
- # This tells the block manager to allocate correct number - # of slots for encoder tokens. - assert vision_config.image_size % 14 == 0, \ - "chunk size should be multiple of 14" - token_per_chunk = (vision_config.image_size // 14)**2 + 1 - num_tokens = num_tiles * token_per_chunk - - # Example output from processor: - # { - # 'encoder': { - # 'type': 'token', - # 'prompt_token_ids': [128256, 128256, ..., 128256], - # 'prompt': '<|image|><|image|>...<|image|>', - # 'multi_modal_data': {'image': }, # noqa: E501 - # }, - # 'decoder': { - # 'type': 'token', - # 'prompt_token_ids': [128000, 128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30], # noqa: E501 - # 'prompt': '<|image|><|begin_of_text|>What is the content of this image?', # noqa: E501 - # 'multi_modal_data': {'image': }, # noqa: E501 - # }, - # } - return EncoderDecoderInputs( - encoder=token_inputs( - prompt_token_ids=[MLLAMA_IMAGE_TOKEN_ID] * num_tokens, - prompt=MLLAMA_IMAGE_TOKEN * num_tokens, - multi_modal_data=multi_modal_data, - ), - decoder=dec_inputs, - ) - - -def get_max_mllama_image_tokens(ctx: InputContext) -> int: - hf_config = ctx.model_config.hf_config - token_per_chunk = (hf_config.vision_config.image_size // 14)**2 + 1 - return hf_config.vision_config.max_num_tiles * token_per_chunk - - -def dummy_decoder_seq_data(seq_len: int, num_images: int): - # <|image|> * num_images + 0 * (seq_len - num_images) - assert seq_len >= num_images, \ - "seq_len should be greater than or equal to num_images" - - return SequenceData.from_prompt_token_counts( - (MLLAMA_IMAGE_TOKEN_ID, num_images), - (0, seq_len - num_images), - ) - - -def dummy_encoder_seq_data(ctx: InputContext, num_images: int): - num_tokens = get_max_mllama_image_tokens(ctx) * num_images - - return SequenceData.from_prompt_token_counts( - (MLLAMA_IMAGE_TOKEN_ID, num_tokens)) - - -def dummy_image(num_images: int, ): - width = height = 1024 - image = Image.new("RGB", (width, height), color=0) - return {"image": image if num_images == 1 
else [image] * num_images} - - -def dummy_decoder_data_for_mllama(ctx: InputContext, seq_len: int, - mm_counts: Mapping[str, int]): - num_images = mm_counts["image"] - return DummyData(dummy_decoder_seq_data(seq_len, num_images)) - - -def dummy_encoder_data_for_mllama(ctx: InputContext, seq_len: int, - mm_counts: Mapping[str, int]): - num_images = mm_counts["image"] - return DummyData(dummy_encoder_seq_data(ctx, num_images), - dummy_image(num_images)) + + def create_encoder_prompt( + self, + prompt: Union[str, list[int]], + mm_data: MultiModalDataDict, + ) -> Union[str, list[int]]: + data = mm_data.get("image", []) + num_images = 1 if isinstance(data, Image) else len(data) + image_token_id = self.info.get_hf_config().image_token_index + return [image_token_id] * num_images + + def _get_prompt_replacements( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> list[PromptReplacement]: + token_per_chunk = self.info.get_token_per_chunk_from_config() + image_token_id = self.info.get_hf_config().image_token_index + + def get_replacement_mllama(item_idx): + images = mm_items.get_items("image", ImageProcessorItems) + image_size = images.get_image_size(item_idx) + num_tile = self.info.get_num_tiles_per_image( + image_height=image_size.height, + image_width=image_size.width, + ) + num_tokens = num_tile * token_per_chunk + return [image_token_id] * num_tokens + + return [ + PromptReplacement( + modality="image", + target=[image_token_id], + replacement=get_replacement_mllama, + ) + ] def _prepare_aspect_ratio_attention_mask( @@ -1107,11 +1142,9 @@ class MllamaForCausalLM(nn.Module): return hidden_states -@MULTIMODAL_REGISTRY.register_image_input_mapper() -@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_mllama_image_tokens) -@INPUT_REGISTRY.register_dummy_data(dummy_decoder_data_for_mllama) -@INPUT_REGISTRY.register_dummy_encoder_data(dummy_encoder_data_for_mllama) 
-@INPUT_REGISTRY.register_input_processor(input_processor_for_mllama) +@MULTIMODAL_REGISTRY.register_processor(MllamaMultiModalProcessor, + info=MllamaProcessingInfo, + dummy_inputs=MllamaDummyInputsBuilder) class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal): packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], @@ -1120,7 +1153,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() - config = vllm_config.model_config.hf_config + config: MllamaConfig = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config self.quant_config = quant_config self.vocab_size = config.text_config.vocab_size @@ -1130,6 +1163,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal): self.pad_token_id = \ config.pad_token_id if config.pad_token_id is not None else -1 self.image_size = config.vision_config.image_size + self.image_token_id = config.image_token_index self.vision_model = MllamaVisionModel(config.vision_config, quant_config, @@ -1204,48 +1238,12 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal): if pixel_values is not None: assert aspect_ratio_ids is not None assert aspect_ratio_mask is not None - max_num_images = max([len(x[0]) for x in pixel_values]) - if max_num_images == 0: - raise ValueError("No images provided.") - max_num_tiles = max( - max([len(x) for x in y[0]]) for y in pixel_values) - device = next(self.multi_modal_projector.parameters()).device - bsz = len(pixel_values) - out_num_tiles = [] - out_images = torch.zeros( - bsz, - max_num_images, - max_num_tiles, - 3, - self.image_size, - self.image_size, - dtype=torch.float32, - device=device, - ) - out_ar_ids = torch.ones(bsz, - max_num_images, - dtype=torch.int64, - device=device) - out_ar_mask = torch.zeros(bsz, - max_num_images, - max_num_tiles, - dtype=torch.int64, - device=device) - for b in range(len(pixel_values)): - 
_num_tiles = [] - for i in range(len(pixel_values[b][0])): - img = pixel_values[b][0][i] - out_images[b, i, :img.shape[0]] = img - out_ar_ids[b, i] = aspect_ratio_ids[b][0][i] - out_ar_mask[b, i] = aspect_ratio_mask[b][0][i] - _num_tiles.append(img.shape[0]) - out_num_tiles.append(_num_tiles) return MllamaImagePixelInputs( type="pixel_values", - data=out_images, - aspect_ratio_ids=out_ar_ids, - aspect_ratio_mask=out_ar_mask, + data=pixel_values, + aspect_ratio_ids=aspect_ratio_ids, + aspect_ratio_mask=aspect_ratio_mask, ) if image_embeds is not None: @@ -1312,7 +1310,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal): batch_token_ids.append(token_ids[start:start + seq_len]) start += seq_len sparse_mask = [ - get_cross_attention_token_mask(t, MLLAMA_IMAGE_TOKEN_ID) + get_cross_attention_token_mask(t, self.image_token_id) for t in batch_token_ids ] @@ -1384,8 +1382,8 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal): # block manager to allocate blocks for those images only. # See input_processor_for_mllama() for more details. num_tiles_tensor = kwargs.pop("num_tiles") - num_tiles = [t[0].tolist() for t in num_tiles_tensor] - num_tokens_per_tile = (self.image_size // 14)**2 + 1 + num_tiles = [t.tolist() for t in num_tiles_tensor] + num_tokens_per_tile = calc_token_per_chunk(self.image_size) actual_encoder_seq_lens = [ sum(num_tile) * num_tokens_per_tile for num_tile in num_tiles ] diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 5f9593ee8..25ca8d1e7 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -739,3 +739,19 @@ class MultiModalInputs(TypedDict): For each modality, information about the placeholder tokens in :code:`prompt_token_ids`. """ + + +class MultiModalEncDecInputs(MultiModalInputs): + """ + Represents the outputs of :class:`vllm.multimodal.EncDecMultiModalProcessor` + ready to be passed to vLLM internals. 
+ """ + + encoder_prompt: str + """The processed encoder prompt text.""" + + encoder_prompt_token_ids: list[int] + """The processed token IDs of the encoder prompt.""" + + encoder_token_type_ids: NotRequired[list[int]] + """The token type IDs of the encoder prompt.""" diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index d704fa59b..74479f5ff 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -20,9 +20,9 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens, from vllm.utils import LRUCache, flatten_2d_lists, full_groupby from .hasher import MultiModalHasher -from .inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalInputs, MultiModalKwargs, MultiModalKwargsItem, - PlaceholderRange) +from .inputs import (MultiModalDataDict, MultiModalEncDecInputs, + MultiModalFieldConfig, MultiModalInputs, MultiModalKwargs, + MultiModalKwargsItem, PlaceholderRange) from .parse import MultiModalDataItems, MultiModalDataParser if TYPE_CHECKING: @@ -1293,3 +1293,57 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): mm_hashes=mm_hashes, mm_placeholders=mm_placeholder_ranges, ) + + +class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): + + @abstractmethod + def create_encoder_prompt( + self, + prompt: Union[str, list[int]], + mm_data: MultiModalDataDict, + ) -> Union[str, list[int]]: + """Create input prompt for the encoder.""" + raise NotImplementedError + + def apply( + self, + prompt: Union[str, list[int]], + mm_data: MultiModalDataDict, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> MultiModalEncDecInputs: + """ + Process multi-modal inputs to be used in vLLM. + The main processing steps are modified to fit encoder-decoder model: + 1. Create encoder prompt from input prompt text. + 2. Apply the HF processor on encoder prompt. + 3. Copy the input prompt text as decoder prompt inputs. 
+ """ + encoder_prompt = self.create_encoder_prompt(prompt, mm_data) + encoder_inputs = super().apply( + encoder_prompt, + mm_data, + hf_processor_mm_kwargs, + ) + + # We assumed the decoder prompt text is copied from + # the original encoder prompt without extra process + tokenizer = self.info.get_tokenizer() + if isinstance(prompt, str): + decoder_prompt = prompt + decoder_prompt_ids = encode_tokens(tokenizer, + prompt, + add_special_tokens=False) + else: + decoder_prompt = decode_tokens(tokenizer, prompt) + decoder_prompt_ids = prompt + + mm_inputs = MultiModalEncDecInputs( + encoder_prompt=encoder_inputs["prompt"], + encoder_prompt_token_ids=encoder_inputs["prompt_token_ids"], + **encoder_inputs) + mm_inputs.update({ + "prompt": decoder_prompt, + "prompt_token_ids": decoder_prompt_ids + }) + return mm_inputs diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py index 5dd754854..81c92b38f 100644 --- a/vllm/multimodal/profiling.py +++ b/vllm/multimodal/profiling.py @@ -144,7 +144,11 @@ class MultiModalProfiler(Generic[_I]): hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs, ) - def get_dummy_data(self, seq_len: int) -> DummyData: + def get_dummy_data( + self, + seq_len: int, + is_encoder_data: bool = False, + ) -> DummyData: # Avoid circular import from vllm.sequence import SequenceData @@ -183,16 +187,18 @@ class MultiModalProfiler(Generic[_I]): total_len = len(prompt_token_ids) # V0 does not support chunked prefill. - if total_len > seq_len and not envs.VLLM_USE_V1: - logger.warning( - "The context length (%d) of the model is too short " - "to hold the multi-modal embeddings in the worst case " - "(%d tokens in total, out of which %s are reserved for " - "multi-modal embeddings). This may cause certain multi-modal " - "inputs to fail during inference, even when the input text is " - "short. 
To avoid this, you should increase `max_model_len`, " - "reduce `max_num_seqs`, and/or reduce `mm_counts`.", seq_len, - total_len, total_placeholders_by_modality) + if (total_len > seq_len and not envs.VLLM_USE_V1) or is_encoder_data: + if total_len > seq_len: + logger.warning( + "The context length (%d) of the model is too short " + "to hold the multi-modal embeddings in the worst case " + "(%d tokens in total, out of which %s are reserved for " + "multi-modal embeddings). This may cause certain " + "multi-modal inputs to fail during inference, even when " + "the input text is short. To avoid this, you should " + "increase `max_model_len`, reduce `max_num_seqs`, " + "and/or reduce `mm_counts`.", seq_len, total_len, + total_placeholders_by_modality) return DummyData( seq_data=SequenceData.from_prompt_token_counts((0, seq_len)), -- GitLab From 009439caeb3ae27d1d6c94e550eee13bbd0520af Mon Sep 17 00:00:00 2001 From: Lu Fang <30275821+houseroad@users.noreply.github.com> Date: Wed, 12 Feb 2025 21:52:41 -0800 Subject: [PATCH 108/253] Simplify logic of locating CUDART so file path (#13203) Signed-off-by: Lu Fang --- .../device_communicators/cuda_wrapper.py | 26 +------------------ 1 file changed, 1 insertion(+), 25 deletions(-) diff --git a/vllm/distributed/device_communicators/cuda_wrapper.py b/vllm/distributed/device_communicators/cuda_wrapper.py index bc2cfbf32..1d53b1c5b 100644 --- a/vllm/distributed/device_communicators/cuda_wrapper.py +++ b/vllm/distributed/device_communicators/cuda_wrapper.py @@ -5,7 +5,6 @@ convenient for use when we just need to call a few functions. """ import ctypes -import glob from dataclasses import dataclass from typing import Any, Dict, List, Optional @@ -62,29 +61,6 @@ def find_loaded_library(lib_name) -> Optional[str]: return path -def get_cudart_lib_path_from_env() -> Optional[str]: - """ - In some system, find_loaded_library() may not work. So we allow users to - specify the path through environment variable VLLM_CUDART_SO_PATH. 
- """ - cudart_so_env = envs.VLLM_CUDART_SO_PATH - if cudart_so_env is not None: - cudart_paths = [ - cudart_so_env, - ] - for path in cudart_paths: - file_paths = glob.glob(path) - if len(file_paths) > 0: - logger.info( - "Found cudart library at %s through env var" - "VLLM_CUDART_SO_PATH=%s", - file_paths[0], - cudart_so_env, - ) - return file_paths[0] - return None - - class CudaRTLibrary: exported_functions = [ # ​cudaError_t cudaSetDevice ( int device ) @@ -131,7 +107,7 @@ class CudaRTLibrary: if so_file is None: so_file = find_loaded_library("libcudart") if so_file is None: - so_file = get_cudart_lib_path_from_env() + so_file = envs.VLLM_CUDART_SO_PATH # fallback to env var assert so_file is not None, \ ( "libcudart is not loaded in the current process, " -- GitLab From 60c68df6d1a4421480710c2303c8814724f90bb5 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Wed, 12 Feb 2025 23:10:28 -0800 Subject: [PATCH 109/253] [Build] Automatically use the wheel of the base commit with Python-only build (#13178) --- .../installation/gpu/cuda.inc.md | 16 ++++++++--- setup.py | 27 ++++++++++++++++--- 2 files changed, 36 insertions(+), 7 deletions(-) diff --git a/docs/source/getting_started/installation/gpu/cuda.inc.md b/docs/source/getting_started/installation/gpu/cuda.inc.md index 5c2ea30db..948bdbffb 100644 --- a/docs/source/getting_started/installation/gpu/cuda.inc.md +++ b/docs/source/getting_started/installation/gpu/cuda.inc.md @@ -89,12 +89,22 @@ cd vllm VLLM_USE_PRECOMPILED=1 pip install --editable . ``` -This will download the [latest nightly wheel](https://wheels.vllm.ai/nightly/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl) and use the compiled libraries from there in the installation. +This command will do the following: +1. Look for the current branch in your vLLM clone. +2. Identify the corresponding base commit in the main branch. +3. Download the pre-built wheel of the base commit. +4. Use its compiled libraries in the installation. 
-The `VLLM_PRECOMPILED_WHEEL_LOCATION` environment variable can be used instead of `VLLM_USE_PRECOMPILED` to specify a custom path or URL to the wheel file. For example, to use the [0.6.1.post1 PyPi wheel](https://pypi.org/project/vllm/#files): +:::{note} +1. If you change C++ or kernel code, you cannot use Python-only build; otherwise you will see an import error about library not found or undefined symbol. +2. If you rebase your dev branch, it is recommended to uninstall vllm and re-run the above command to make sure your libraries are up to date. +::: + +In case you see an error about wheel not found when running the above command, it might be because the commit you based on in the main branch was just merged and the wheel is being built. In this case, you can wait for around an hour to try again, or manually assign the previous commit in the installation using the `VLLM_PRECOMPILED_WHEEL_LOCATION` environment variable. ```console -export VLLM_PRECOMPILED_WHEEL_LOCATION=https://files.pythonhosted.org/packages/4a/4c/ee65ba33467a4c0de350ce29fbae39b9d0e7fcd887cc756fa993654d1228/vllm-0.6.3.post1-cp38-abi3-manylinux1_x86_64.whl +export VLLM_COMMIT=72d9c316d3f6ede485146fe5aabd4e61dbc59069 # use full commit hash from the main branch +export VLLM_PRECOMPILED_WHEEL_LOCATION=https://wheels.vllm.ai/${VLLM_COMMIT}/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl pip install --editable . 
``` diff --git a/setup.py b/setup.py index 27e5aab76..5a74a44c0 100755 --- a/setup.py +++ b/setup.py @@ -268,15 +268,34 @@ class cmake_build_ext(build_ext): class repackage_wheel(build_ext): """Extracts libraries and other files from an existing wheel.""" - default_wheel = "https://wheels.vllm.ai/nightly/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl" - def run(self) -> None: - wheel_location = os.getenv("VLLM_PRECOMPILED_WHEEL_LOCATION", - self.default_wheel) + def get_base_commit_in_main_branch(self) -> str: + import subprocess + + try: + current_branch = subprocess.check_output( + ["git", "branch", "--show-current"]).decode("utf-8").strip() + + base_commit = subprocess.check_output( + ["git", "merge-base", "main", + current_branch]).decode("utf-8").strip() + return base_commit + except Exception as err: + logger.warning( + "Failed to get the base commit in the main branch. " + "Using the nightly wheel. The libraries in this " + "wheel may not be compatible with your dev branch: %s", err) + return "nightly" + def run(self) -> None: assert _is_cuda( ), "VLLM_USE_PRECOMPILED is only supported for CUDA builds" + wheel_location = os.getenv("VLLM_PRECOMPILED_WHEEL_LOCATION", None) + if wheel_location is None: + base_commit = self.get_base_commit_in_main_branch() + wheel_location = f"https://wheels.vllm.ai/{base_commit}/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl" + import zipfile if os.path.isfile(wheel_location): -- GitLab From 04f50ad9d1e1fd3f2d9594d5dbfd1d51d37bfea0 Mon Sep 17 00:00:00 2001 From: LikeSundayLikeRain Date: Thu, 13 Feb 2025 02:11:26 -0500 Subject: [PATCH 110/253] [Bugfix] deepseek_r1_reasoning_parser put reason content in wrong field in certain edge case (#13097) --- .../test_deepseekr1_reasoning_parser.py | 10 +++++----- .../reasoning_parsers/deepseek_r1_reasoning_parser.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/entrypoints/openai/reasoning_parsers/test_deepseekr1_reasoning_parser.py 
b/tests/entrypoints/openai/reasoning_parsers/test_deepseekr1_reasoning_parser.py index fdadb2e21..ea504f3d0 100644 --- a/tests/entrypoints/openai/reasoning_parsers/test_deepseekr1_reasoning_parser.py +++ b/tests/entrypoints/openai/reasoning_parsers/test_deepseekr1_reasoning_parser.py @@ -24,10 +24,10 @@ COMPLETE_REASONING = { "reasoning_content": "This is a reasoning section", "content": None, } -NO_REASONING = { +NO_CONTENT = { "output": "This is content", - "reasoning_content": None, - "content": "This is content", + "reasoning_content": "This is content", + "content": None, } NO_REASONING_STREAMING = { "output": "This is a reasoning section", @@ -98,8 +98,8 @@ TEST_CASES = [ ), pytest.param( False, - NO_REASONING, - id="no_reasoning_token", + NO_CONTENT, + id="no_content_token", ), pytest.param( True, diff --git a/vllm/entrypoints/openai/reasoning_parsers/deepseek_r1_reasoning_parser.py b/vllm/entrypoints/openai/reasoning_parsers/deepseek_r1_reasoning_parser.py index 33bba0488..e5ab6e6b2 100644 --- a/vllm/entrypoints/openai/reasoning_parsers/deepseek_r1_reasoning_parser.py +++ b/vllm/entrypoints/openai/reasoning_parsers/deepseek_r1_reasoning_parser.py @@ -128,7 +128,7 @@ class DeepSeekR1ReasoningParser(ReasoningParser): # Thus we assume the reasoning content is always at the start. # Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f if self.think_end_token not in model_output: - return None, model_output + return model_output, None else: # Add a start token if it's missing to keep compatibility. 
if self.think_start_token not in model_output: -- GitLab From d46d490c275091b4900ce10aa21032c222e85180 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Thu, 13 Feb 2025 02:12:21 -0500 Subject: [PATCH 111/253] [Frontend] Move CLI code into vllm.cmd package (#12971) --- docs/source/design/arch_overview.md | 2 +- setup.py | 2 +- vllm/entrypoints/cli/__init__.py | 0 vllm/entrypoints/cli/main.py | 79 ++++++++++ vllm/entrypoints/cli/openai.py | 172 +++++++++++++++++++++ vllm/entrypoints/cli/serve.py | 63 ++++++++ vllm/entrypoints/cli/types.py | 24 +++ vllm/entrypoints/openai/api_server.py | 3 +- vllm/scripts.py | 208 +------------------------- 9 files changed, 348 insertions(+), 205 deletions(-) create mode 100644 vllm/entrypoints/cli/__init__.py create mode 100644 vllm/entrypoints/cli/main.py create mode 100644 vllm/entrypoints/cli/openai.py create mode 100644 vllm/entrypoints/cli/serve.py create mode 100644 vllm/entrypoints/cli/types.py diff --git a/docs/source/design/arch_overview.md b/docs/source/design/arch_overview.md index 04886e598..7bed0a001 100644 --- a/docs/source/design/arch_overview.md +++ b/docs/source/design/arch_overview.md @@ -66,7 +66,7 @@ This server can be started using the `vllm serve` command. vllm serve <model> ``` -The code for the `vllm` CLI can be found in <gh-file:vllm/scripts.py>. +The code for the `vllm` CLI can be found in <gh-file:vllm/entrypoints/cli/main.py>. Sometimes you may see the API server entrypoint used directly instead of via the `vllm` CLI command.
For example: diff --git a/setup.py b/setup.py index 5a74a44c0..7243a2ab3 100755 --- a/setup.py +++ b/setup.py @@ -689,7 +689,7 @@ setup( package_data=package_data, entry_points={ "console_scripts": [ - "vllm=vllm.scripts:main", + "vllm=vllm.entrypoints.cli.main:main", ], }, ) diff --git a/vllm/entrypoints/cli/__init__.py b/vllm/entrypoints/cli/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/vllm/entrypoints/cli/main.py b/vllm/entrypoints/cli/main.py new file mode 100644 index 000000000..e94d9a056 --- /dev/null +++ b/vllm/entrypoints/cli/main.py @@ -0,0 +1,79 @@ +# SPDX-License-Identifier: Apache-2.0 + +# The CLI entrypoint to vLLM. +import os +import signal +import sys + +import vllm.entrypoints.cli.openai +import vllm.entrypoints.cli.serve +import vllm.version +from vllm.logger import init_logger +from vllm.utils import FlexibleArgumentParser + +logger = init_logger(__name__) + +CMD_MODULES = [ + vllm.entrypoints.cli.openai, + vllm.entrypoints.cli.serve, +] + + +def register_signal_handlers(): + + def signal_handler(sig, frame): + sys.exit(0) + + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTSTP, signal_handler) + + +def env_setup(): + # The safest multiprocessing method is `spawn`, as the default `fork` method + # is not compatible with some accelerators. The default method will be + # changing in future versions of Python, so we should use it explicitly when + # possible. + # + # We only set it here in the CLI entrypoint, because changing to `spawn` + # could break some existing code using vLLM as a library. `spawn` will cause + # unexpected behavior if the code is not protected by + # `if __name__ == "__main__":`. 
+ # + # References: + # - https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods + # - https://pytorch.org/docs/stable/notes/multiprocessing.html#cuda-in-multiprocessing + # - https://pytorch.org/docs/stable/multiprocessing.html#sharing-cuda-tensors + # - https://docs.habana.ai/en/latest/PyTorch/Getting_Started_with_PyTorch_and_Gaudi/Getting_Started_with_PyTorch.html?highlight=multiprocessing#torch-multiprocessing-for-dataloaders + if "VLLM_WORKER_MULTIPROC_METHOD" not in os.environ: + logger.debug("Setting VLLM_WORKER_MULTIPROC_METHOD to 'spawn'") + os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + + +def main(): + env_setup() + + parser = FlexibleArgumentParser(description="vLLM CLI") + parser.add_argument('-v', + '--version', + action='version', + version=vllm.version.__version__) + subparsers = parser.add_subparsers(required=False, dest="subparser") + cmds = {} + for cmd_module in CMD_MODULES: + new_cmds = cmd_module.cmd_init() + for cmd in new_cmds: + cmd.subparser_init(subparsers).set_defaults( + dispatch_function=cmd.cmd) + cmds[cmd.name] = cmd + args = parser.parse_args() + if args.subparser in cmds: + cmds[args.subparser].validate(args) + + if hasattr(args, "dispatch_function"): + args.dispatch_function(args) + else: + parser.print_help() + + +if __name__ == "__main__": + main() diff --git a/vllm/entrypoints/cli/openai.py b/vllm/entrypoints/cli/openai.py new file mode 100644 index 000000000..73df900f6 --- /dev/null +++ b/vllm/entrypoints/cli/openai.py @@ -0,0 +1,172 @@ +# SPDX-License-Identifier: Apache-2.0 +# Commands that act as an interactive OpenAI API client + +import argparse +import os +import signal +import sys +from typing import List, Optional, Tuple + +from openai import OpenAI +from openai.types.chat import ChatCompletionMessageParam + +from vllm.entrypoints.cli.types import CLISubcommand +from vllm.utils import FlexibleArgumentParser + + +def _register_signal_handlers(): + + def signal_handler(sig, frame): + 
sys.exit(0) + + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTSTP, signal_handler) + + +def _interactive_cli(args: argparse.Namespace) -> Tuple[str, OpenAI]: + _register_signal_handlers() + + base_url = args.url + api_key = args.api_key or os.environ.get("OPENAI_API_KEY", "EMPTY") + openai_client = OpenAI(api_key=api_key, base_url=base_url) + + if args.model_name: + model_name = args.model_name + else: + available_models = openai_client.models.list() + model_name = available_models.data[0].id + + print(f"Using model: {model_name}") + + return model_name, openai_client + + +def chat(system_prompt: Optional[str], model_name: str, + client: OpenAI) -> None: + conversation: List[ChatCompletionMessageParam] = [] + if system_prompt is not None: + conversation.append({"role": "system", "content": system_prompt}) + + print("Please enter a message for the chat model:") + while True: + try: + input_message = input("> ") + except EOFError: + return + conversation.append({"role": "user", "content": input_message}) + + chat_completion = client.chat.completions.create(model=model_name, + messages=conversation) + + response_message = chat_completion.choices[0].message + output = response_message.content + + conversation.append(response_message) # type: ignore + print(output) + + +def _add_query_options( + parser: FlexibleArgumentParser) -> FlexibleArgumentParser: + parser.add_argument( + "--url", + type=str, + default="http://localhost:8000/v1", + help="url of the running OpenAI-Compatible RESTful API server") + parser.add_argument( + "--model-name", + type=str, + default=None, + help=("The model name used in prompt completion, default to " + "the first model in list models API call.")) + parser.add_argument( + "--api-key", + type=str, + default=None, + help=( + "API key for OpenAI services. If provided, this api key " + "will overwrite the api key obtained through environment variables." 
+ )) + return parser + + +class ChatCommand(CLISubcommand): + """The `chat` subcommand for the vLLM CLI. """ + + def __init__(self): + self.name = "chat" + super().__init__() + + @staticmethod + def cmd(args: argparse.Namespace) -> None: + model_name, client = _interactive_cli(args) + system_prompt = args.system_prompt + conversation: List[ChatCompletionMessageParam] = [] + if system_prompt is not None: + conversation.append({"role": "system", "content": system_prompt}) + + print("Please enter a message for the chat model:") + while True: + try: + input_message = input("> ") + except EOFError: + return + conversation.append({"role": "user", "content": input_message}) + + chat_completion = client.chat.completions.create( + model=model_name, messages=conversation) + + response_message = chat_completion.choices[0].message + output = response_message.content + + conversation.append(response_message) # type: ignore + print(output) + + def subparser_init( + self, + subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser: + chat_parser = subparsers.add_parser( + "chat", + help="Generate chat completions via the running API server", + usage="vllm chat [options]") + _add_query_options(chat_parser) + chat_parser.add_argument( + "--system-prompt", + type=str, + default=None, + help=("The system prompt to be added to the chat template, " + "used for models that support system prompts.")) + return chat_parser + + +class CompleteCommand(CLISubcommand): + """The `complete` subcommand for the vLLM CLI. 
""" + + def __init__(self): + self.name = "complete" + super().__init__() + + @staticmethod + def cmd(args: argparse.Namespace) -> None: + model_name, client = _interactive_cli(args) + print("Please enter prompt to complete:") + while True: + input_prompt = input("> ") + completion = client.completions.create(model=model_name, + prompt=input_prompt) + output = completion.choices[0].text + print(output) + + def subparser_init( + self, + subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser: + complete_parser = subparsers.add_parser( + "complete", + help=("Generate text completions based on the given prompt " + "via the running API server"), + usage="vllm complete [options]") + _add_query_options(complete_parser) + return complete_parser + + +def cmd_init() -> List[CLISubcommand]: + return [ChatCommand(), CompleteCommand()] diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py new file mode 100644 index 000000000..1afead8a1 --- /dev/null +++ b/vllm/entrypoints/cli/serve.py @@ -0,0 +1,63 @@ +# SPDX-License-Identifier: Apache-2.0 + +import argparse +from typing import List + +import uvloop + +from vllm.engine.arg_utils import EngineArgs +from vllm.entrypoints.cli.types import CLISubcommand +from vllm.entrypoints.openai.api_server import run_server +from vllm.entrypoints.openai.cli_args import (make_arg_parser, + validate_parsed_serve_args) +from vllm.utils import FlexibleArgumentParser + + +class ServeSubcommand(CLISubcommand): + """The `serve` subcommand for the vLLM CLI. """ + + def __init__(self): + self.name = "serve" + super().__init__() + + @staticmethod + def cmd(args: argparse.Namespace) -> None: + # The default value of `--model` + if args.model != EngineArgs.model: + raise ValueError( + "With `vllm serve`, you should provide the model as a " + "positional argument instead of via the `--model` option.") + + # EngineArgs expects the model name to be passed as --model. 
+ args.model = args.model_tag + + uvloop.run(run_server(args)) + + def validate(self, args: argparse.Namespace) -> None: + validate_parsed_serve_args(args) + + def subparser_init( + self, + subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser: + serve_parser = subparsers.add_parser( + "serve", + help="Start the vLLM OpenAI Compatible API server", + usage="vllm serve [options]") + serve_parser.add_argument("model_tag", + type=str, + help="The model tag to serve") + serve_parser.add_argument( + "--config", + type=str, + default='', + required=False, + help="Read CLI options from a config file." + "Must be a YAML with the following options:" + "https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#cli-reference" + ) + + return make_arg_parser(serve_parser) + + +def cmd_init() -> List[CLISubcommand]: + return [ServeSubcommand()] diff --git a/vllm/entrypoints/cli/types.py b/vllm/entrypoints/cli/types.py new file mode 100644 index 000000000..f739a68c5 --- /dev/null +++ b/vllm/entrypoints/cli/types.py @@ -0,0 +1,24 @@ +# SPDX-License-Identifier: Apache-2.0 + +import argparse + +from vllm.utils import FlexibleArgumentParser + + +class CLISubcommand: + """Base class for CLI argument handlers.""" + + name: str + + @staticmethod + def cmd(args: argparse.Namespace) -> None: + raise NotImplementedError("Subclasses should implement this method") + + def validate(self, args: argparse.Namespace) -> None: + # No validation by default + pass + + def subparser_init( + self, + subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser: + raise NotImplementedError("Subclasses should implement this method") diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index b8f54d6c7..127ee9414 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -901,7 +901,8 @@ async def run_server(args, **uvicorn_kwargs) -> None: if __name__ == "__main__": # NOTE(simon): - # This section 
should be in sync with vllm/scripts.py for CLI entrypoints. + # This section should be in sync with vllm/entrypoints/cli/main.py for CLI + # entrypoints. parser = FlexibleArgumentParser( description="vLLM OpenAI-Compatible RESTful API server.") parser = make_arg_parser(parser) diff --git a/vllm/scripts.py b/vllm/scripts.py index 467cab28f..7e569d2d2 100644 --- a/vllm/scripts.py +++ b/vllm/scripts.py @@ -1,210 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 -# The CLI entrypoint to vLLM. -import argparse -import os -import signal -import sys -from typing import List, Optional - -import uvloop -from openai import OpenAI -from openai.types.chat import ChatCompletionMessageParam - -import vllm.version -from vllm.engine.arg_utils import EngineArgs -from vllm.entrypoints.openai.api_server import run_server -from vllm.entrypoints.openai.cli_args import (make_arg_parser, - validate_parsed_serve_args) +from vllm.entrypoints.cli.main import main as vllm_main from vllm.logger import init_logger -from vllm.utils import FlexibleArgumentParser logger = init_logger(__name__) -def register_signal_handlers(): - - def signal_handler(sig, frame): - sys.exit(0) - - signal.signal(signal.SIGINT, signal_handler) - signal.signal(signal.SIGTSTP, signal_handler) - - -def serve(args: argparse.Namespace) -> None: - # The default value of `--model` - if args.model != EngineArgs.model: - raise ValueError( - "With `vllm serve`, you should provide the model as a " - "positional argument instead of via the `--model` option.") - - # EngineArgs expects the model name to be passed as --model. 
- args.model = args.model_tag - - uvloop.run(run_server(args)) - - -def interactive_cli(args: argparse.Namespace) -> None: - register_signal_handlers() - - base_url = args.url - api_key = args.api_key or os.environ.get("OPENAI_API_KEY", "EMPTY") - openai_client = OpenAI(api_key=api_key, base_url=base_url) - - if args.model_name: - model_name = args.model_name - else: - available_models = openai_client.models.list() - model_name = available_models.data[0].id - - print(f"Using model: {model_name}") - - if args.command == "complete": - complete(model_name, openai_client) - elif args.command == "chat": - chat(args.system_prompt, model_name, openai_client) - - -def complete(model_name: str, client: OpenAI) -> None: - print("Please enter prompt to complete:") - while True: - input_prompt = input("> ") - - completion = client.completions.create(model=model_name, - prompt=input_prompt) - output = completion.choices[0].text - print(output) - - -def chat(system_prompt: Optional[str], model_name: str, - client: OpenAI) -> None: - conversation: List[ChatCompletionMessageParam] = [] - if system_prompt is not None: - conversation.append({"role": "system", "content": system_prompt}) - - print("Please enter a message for the chat model:") - while True: - input_message = input("> ") - conversation.append({"role": "user", "content": input_message}) - - chat_completion = client.chat.completions.create(model=model_name, - messages=conversation) - - response_message = chat_completion.choices[0].message - output = response_message.content - - conversation.append(response_message) # type: ignore - print(output) - - -def _add_query_options( - parser: FlexibleArgumentParser) -> FlexibleArgumentParser: - parser.add_argument( - "--url", - type=str, - default="http://localhost:8000/v1", - help="url of the running OpenAI-Compatible RESTful API server") - parser.add_argument( - "--model-name", - type=str, - default=None, - help=("The model name used in prompt completion, default to " - "the 
first model in list models API call.")) - parser.add_argument( - "--api-key", - type=str, - default=None, - help=( - "API key for OpenAI services. If provided, this api key " - "will overwrite the api key obtained through environment variables." - )) - return parser - - -def env_setup(): - # The safest multiprocessing method is `spawn`, as the default `fork` method - # is not compatible with some accelerators. The default method will be - # changing in future versions of Python, so we should use it explicitly when - # possible. - # - # We only set it here in the CLI entrypoint, because changing to `spawn` - # could break some existing code using vLLM as a library. `spawn` will cause - # unexpected behavior if the code is not protected by - # `if __name__ == "__main__":`. - # - # References: - # - https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods - # - https://pytorch.org/docs/stable/notes/multiprocessing.html#cuda-in-multiprocessing - # - https://pytorch.org/docs/stable/multiprocessing.html#sharing-cuda-tensors - # - https://docs.habana.ai/en/latest/PyTorch/Getting_Started_with_PyTorch_and_Gaudi/Getting_Started_with_PyTorch.html?highlight=multiprocessing#torch-multiprocessing-for-dataloaders - if "VLLM_WORKER_MULTIPROC_METHOD" not in os.environ: - logger.debug("Setting VLLM_WORKER_MULTIPROC_METHOD to 'spawn'") - os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" - - +# Backwards compatibility for the move from vllm.scripts to +# vllm.entrypoints.cli.main def main(): - env_setup() - - parser = FlexibleArgumentParser(description="vLLM CLI") - parser.add_argument('-v', - '--version', - action='version', - version=vllm.version.__version__) - - subparsers = parser.add_subparsers(required=True, dest="subparser") - - serve_parser = subparsers.add_parser( - "serve", - help="Start the vLLM OpenAI Compatible API server", - usage="vllm serve [options]") - serve_parser.add_argument("model_tag", - type=str, - help="The model tag to serve") - 
serve_parser.add_argument( - "--config", - type=str, - default='', - required=False, - help="Read CLI options from a config file." - "Must be a YAML with the following options:" - "https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#cli-reference" - ) - - serve_parser = make_arg_parser(serve_parser) - serve_parser.set_defaults(dispatch_function=serve) - - complete_parser = subparsers.add_parser( - "complete", - help=("Generate text completions based on the given prompt " - "via the running API server"), - usage="vllm complete [options]") - _add_query_options(complete_parser) - complete_parser.set_defaults(dispatch_function=interactive_cli, - command="complete") - - chat_parser = subparsers.add_parser( - "chat", - help="Generate chat completions via the running API server", - usage="vllm chat [options]") - _add_query_options(chat_parser) - chat_parser.add_argument( - "--system-prompt", - type=str, - default=None, - help=("The system prompt to be added to the chat template, " - "used for models that support system prompts.")) - chat_parser.set_defaults(dispatch_function=interactive_cli, command="chat") - - args = parser.parse_args() - if args.subparser == "serve": - validate_parsed_serve_args(args) - - # One of the sub commands should be executed. - if hasattr(args, "dispatch_function"): - args.dispatch_function(args) - else: - parser.print_help() - - -if __name__ == "__main__": - main() + logger.warning("vllm.scripts.main() is deprecated. 
Please re-install " + "vllm or use vllm.entrypoints.cli.main.main() instead.") + vllm_main() -- GitLab From cb944d5818398281d5e777824ee394d5455d8b8b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 23:13:08 -0800 Subject: [PATCH 112/253] Allow Unsloth Dynamic 4bit BnB quants to work (#12974) --- .../layers/quantization/bitsandbytes.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py index 889eda009..49d992d4c 100644 --- a/vllm/model_executor/layers/quantization/bitsandbytes.py +++ b/vllm/model_executor/layers/quantization/bitsandbytes.py @@ -133,8 +133,16 @@ def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: List[str]): components = prefix.split('.') # Check if any of the skip modules exactly matches any component - return any(module_name in components - for module_name in llm_int8_skip_modules) + substr_check = any(module_name in components + for module_name in llm_int8_skip_modules) + + # Allow certain layers to not be quantized + set_components = set(".".join(components[:i + 1]) + for i in range(len(components))) + set_llm_int8_skip_modules = set(llm_int8_skip_modules) + prefix_check = len(set_llm_int8_skip_modules & set_components) != 0 + + return substr_check or prefix_check class BitsAndBytesLinearMethod(LinearMethodBase): -- GitLab From 0ccd8769fbab6a22372dd77608293c1a80921812 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Thu, 13 Feb 2025 02:45:38 -0500 Subject: [PATCH 113/253] [CI/Build] Allow ruff to auto-fix some issues (#13180) Signed-off-by: Russell Bryant --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 22b51afdc..f664b4c55 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,7 +13,7 @@ repos: rev: v0.9.3 hooks: - id: ruff - args: [--output-format, github] + 
args: [--output-format, github, --fix] exclude: 'vllm/third_party/.*' - repo: https://github.com/codespell-project/codespell rev: v2.4.0 -- GitLab From 9605c1256ece238c5702c7020e6806eb2c2ebeb0 Mon Sep 17 00:00:00 2001 From: Rui Qiao <161574667+ruisearch42@users.noreply.github.com> Date: Thu, 13 Feb 2025 00:02:46 -0800 Subject: [PATCH 114/253] [V1][core] Implement pipeline parallel on Ray (#12996) --- tests/distributed/test_pipeline_parallel.py | 51 ++++++++++++++++----- vllm/executor/ray_utils.py | 11 ++++- vllm/v1/core/kv_cache_utils.py | 41 +++++++++++------ vllm/v1/engine/core.py | 19 +++++--- vllm/v1/executor/abstract.py | 12 ++--- vllm/v1/worker/gpu_model_runner.py | 16 ++++++- vllm/v1/worker/gpu_worker.py | 5 +- 7 files changed, 110 insertions(+), 45 deletions(-) diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 5d7cb9e40..6a54fb74b 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -40,10 +40,23 @@ class PPTestOptions(NamedTuple): @dataclass class PPTestSettings: parallel_setups: List[ParallelSetup] + # NOTE: the length of distributed_backends and + # vllm_major_versions should be the same, and they + # are first zipped together to iterate over all + # test settings. 
distributed_backends: List[str] + # vllm major version: "0" for V0, "1" for V1 + vllm_major_versions: List[str] task: TaskOption test_options: PPTestOptions + def __post_init__(self): + if len(self.distributed_backends) != len(self.vllm_major_versions): + raise ValueError( + f"Length mismatch: distributed_backends " + f"({len(self.distributed_backends)}) != " + f"vllm_major_versions ({len(self.vllm_major_versions)})") + @staticmethod def detailed( *, @@ -79,7 +92,9 @@ class PPTestSettings: eager_mode=True, chunked_prefill=False), ], - distributed_backends=["mp", "ray"], + # only ray is supported for V1 + distributed_backends=["mp", "ray", "ray"], + vllm_major_versions=["0", "0", "1"], task=task, test_options=PPTestOptions(multi_node_only=multi_node_only, trust_remote_code=trust_remote_code, @@ -108,6 +123,7 @@ class PPTestSettings: chunked_prefill=False), ], distributed_backends=["mp"], + vllm_major_versions=["0"], task=task, test_options=PPTestOptions(multi_node_only=multi_node_only, trust_remote_code=trust_remote_code, @@ -120,8 +136,9 @@ class PPTestSettings: opts = self.test_options for parallel_setup in self.parallel_setups: - for distributed_backend in self.distributed_backends: - yield (model_name, parallel_setup, distributed_backend, + for backend, vllm_major_version in zip(self.distributed_backends, + self.vllm_major_versions): + yield (model_name, parallel_setup, backend, vllm_major_version, self.task, opts) @@ -244,6 +261,7 @@ def _compare_tp( model_name: str, parallel_setup: ParallelSetup, distributed_backend: str, + vllm_major_version: str, task: TaskOption, test_options: PPTestOptions, num_gpus_available: int, @@ -296,10 +314,13 @@ def _compare_tp( if hf_overrides: common_args.extend(["--hf-overrides", hf_overrides]) - if (distributed_backend == "ray" and tp_size == 2 and pp_size == 2 - and chunked_prefill): - # Test Ray ADAG for a subset of the tests + specific_case = tp_size == 2 and pp_size == 2 and chunked_prefill + if distributed_backend == "ray" 
and (vllm_major_version == "1" + or specific_case): + # For V1, test Ray ADAG for all the tests + # For V0, test Ray ADAG for a subset of the tests pp_env = { + "VLLM_USE_V1": vllm_major_version, "VLLM_USE_RAY_COMPILED_DAG": "1", "VLLM_USE_RAY_SPMD_WORKER": "1", "VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL": "1", @@ -348,8 +369,8 @@ def _compare_tp( @pytest.mark.parametrize( - ("model_name", "parallel_setup", "distributed_backend", "task", - "test_options"), + ("model_name", "parallel_setup", "distributed_backend", + "vllm_major_version", "task", "test_options"), [ params for model_name, settings in TEXT_GENERATION_MODELS.items() for params in settings.iter_params(model_name) @@ -361,6 +382,7 @@ def test_tp_language_generation( model_name: str, parallel_setup: ParallelSetup, distributed_backend: str, + vllm_major_version: str, task: TaskOption, test_options: PPTestOptions, num_gpus_available, @@ -368,6 +390,7 @@ def test_tp_language_generation( _compare_tp(model_name, parallel_setup, distributed_backend, + vllm_major_version, task, test_options, num_gpus_available, @@ -375,8 +398,8 @@ def test_tp_language_generation( @pytest.mark.parametrize( - ("model_name", "parallel_setup", "distributed_backend", "task", - "test_options"), + ("model_name", "parallel_setup", "distributed_backend", + "vllm_major_version", "task", "test_options"), [ params for model_name, settings in EMBEDDING_MODELS.items() for params in settings.iter_params(model_name) @@ -388,6 +411,7 @@ def test_tp_language_embedding( model_name: str, parallel_setup: ParallelSetup, distributed_backend: str, + vllm_major_version: str, task: TaskOption, test_options: PPTestOptions, num_gpus_available, @@ -395,6 +419,7 @@ def test_tp_language_embedding( _compare_tp(model_name, parallel_setup, distributed_backend, + vllm_major_version, task, test_options, num_gpus_available, @@ -402,8 +427,8 @@ def test_tp_language_embedding( @pytest.mark.parametrize( - ("model_name", "parallel_setup", "distributed_backend", "task", - 
"test_options"), + ("model_name", "parallel_setup", "distributed_backend", + "vllm_major_version", "task", "test_options"), [ params for model_name, settings in MULTIMODAL_MODELS.items() for params in settings.iter_params(model_name) @@ -415,6 +440,7 @@ def test_tp_multimodal_generation( model_name: str, parallel_setup: ParallelSetup, distributed_backend: str, + vllm_major_version: str, task: TaskOption, test_options: PPTestOptions, num_gpus_available, @@ -422,6 +448,7 @@ def test_tp_multimodal_generation( _compare_tp(model_name, parallel_setup, distributed_backend, + vllm_major_version, task, test_options, num_gpus_available, diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py index 33c0a2580..8ad466a55 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -35,7 +35,7 @@ try: class RayWorkerWrapper(WorkerWrapperBase): """Ray wrapper for vllm.worker.Worker, allowing Worker to be - lazliy initialized after Ray sets CUDA_VISIBLE_DEVICES.""" + lazily initialized after Ray sets CUDA_VISIBLE_DEVICES.""" def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) @@ -118,7 +118,14 @@ try: ) -> "ModelRunnerOutput": self.setup_device_if_necessary() assert self.worker is not None, "Worker is not initialized" - output = self.worker.model_runner.execute_model(scheduler_output) + if isinstance(scheduler_output, tuple): + scheduler_output, intermediate_tensors = scheduler_output + else: + scheduler_output, intermediate_tensors = scheduler_output, None + output = self.worker.model_runner.execute_model( + scheduler_output, intermediate_tensors) + if isinstance(output, IntermediateTensors): + output = scheduler_output, output return output def override_env_vars(self, vars: Dict[str, str]): diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index bddb482d2..6dec87d4d 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -488,7 +488,8 @@ def 
is_kv_cache_type_uniform(kv_cache_spec: KVCacheSpec) -> bool: def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, kv_cache_spec: KVCacheSpec, - available_memory: int) -> KVCacheConfig: + available_memory: int, + num_layers: int) -> KVCacheConfig: """ Generates the KV cache configuration for a model with one type of KV cache. Divide the available memory equally among all layers. @@ -497,6 +498,7 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, vllm_config: The global VllmConfig kv_cache_spec: The kv cache spec of the model available_memory: Memory available for KV cache in bytes. + num_layers: The number of layers in the model. Returns: The generated KVCacheConfig @@ -506,7 +508,7 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, assert len(page_sizes) == 1 page_size = page_sizes.pop() - num_blocks = int(available_memory // page_size // len(kv_cache_spec)) + num_blocks = int(available_memory // page_size // num_layers) num_blocks = max(num_blocks, 0) if vllm_config.cache_config.num_gpu_blocks_override is not None: @@ -536,25 +538,36 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, return kv_cache_config -def get_kv_cache_config(vllm_config: VllmConfig, kv_cache_spec: KVCacheSpec, - available_memory: int) -> KVCacheConfig: +def get_kv_cache_configs(vllm_config: VllmConfig, + kv_cache_specs: List[KVCacheSpec], + available_memory: int) -> List[KVCacheConfig]: """ Generates the KV cache configuration for a model TODO: support hybrid models with more than one type of KV cache. Args: vllm_config: The global VllmConfig - kv_cache_spec: The kv cache spec of the model + kv_cache_specs: The kv cache specs of the model available_memory: Memory available for KV cache in bytes. 
Returns: - The generated KVCacheConfig + The generated KVCacheConfigs """ - check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory) - if is_kv_cache_type_uniform(kv_cache_spec): - # KV cache of all layers are the same, which is true for most models. - # Allocate the same amount of memory for each layer. - return _get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec, - available_memory) - else: - raise NotImplementedError + # Use the max number of layers to conservatively determine + # the number of blocks. + num_layers = max(len(kv_cache_spec) for kv_cache_spec in kv_cache_specs) + kv_cache_configs = [] + for kv_cache_spec in kv_cache_specs: + check_enough_kv_cache_memory(vllm_config, kv_cache_spec, + available_memory) + if is_kv_cache_type_uniform(kv_cache_spec): + # KV cache of all layers are the same, which is true for + # most models. Allocate the same amount of memory for + # each layer. + kv_cache_configs.append( + _get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec, + available_memory, + num_layers)) + else: + raise NotImplementedError + return kv_cache_configs diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index e4677681b..e19680355 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -16,7 +16,7 @@ from vllm.logger import init_logger from vllm.transformers_utils.config import ( maybe_register_config_serialize_by_value) from vllm.utils import get_exception_traceback, zmq_socket_ctx -from vllm.v1.core.kv_cache_utils import get_kv_cache_config +from vllm.v1.core.kv_cache_utils import get_kv_cache_configs from vllm.v1.core.scheduler import Scheduler from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, EngineCoreRequestType) @@ -73,20 +73,25 @@ class EngineCore: start = time.time() # Get all kv cache needed by the model - kv_cache_spec = self.model_executor.get_kv_cache_spec() + kv_cache_specs = self.model_executor.get_kv_cache_specs() # Profiles the peak memory usage of the model to 
determine how much # memory can be allocated for kv cache. - availble_gpu_memory = self.model_executor.determine_available_memory() + available_gpu_memory = self.model_executor.determine_available_memory() # Get the kv cache tensor size - kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec, - availble_gpu_memory) - num_gpu_blocks = kv_cache_config.num_blocks + kv_cache_configs = get_kv_cache_configs(vllm_config, kv_cache_specs, + available_gpu_memory) + num_gpu_blocks_set = set(config.num_blocks + for config in kv_cache_configs) + assert len(num_gpu_blocks_set) == 1, ( + f"num_gpu_blocks need to be the same across workers, " + f"but they are different: {num_gpu_blocks_set}") + num_gpu_blocks = num_gpu_blocks_set.pop() num_cpu_blocks = 0 # Initialize kv cache and warmup the execution - self.model_executor.initialize(kv_cache_config) + self.model_executor.initialize(kv_cache_configs) elapsed = time.time() - start logger.info(("init engine (profile, create kv cache, " diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py index 093be09ae..d1ffc891a 100644 --- a/vllm/v1/executor/abstract.py +++ b/vllm/v1/executor/abstract.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Type +from typing import List, Type from vllm.config import VllmConfig from vllm.executor.executor_base import ExecutorBase @@ -48,12 +48,12 @@ class Executor(ExecutorBase): f"{distributed_executor_backend}") return executor_class - def initialize(self, kv_cache_config: KVCacheConfig) -> None: + def initialize(self, kv_cache_configs: List[KVCacheConfig]) -> None: """ Initialize the KV caches and begin the model execution loop of the underlying workers. 
""" - self.collective_rpc("initialize_cache", args=(kv_cache_config, )) + self.collective_rpc("initialize_cache", args=(kv_cache_configs, )) self.collective_rpc("compile_or_warm_up_model") def determine_available_memory(self) -> int: # in bytes @@ -63,11 +63,9 @@ class Executor(ExecutorBase): # operators can be applied to all workers. return min(output) - def get_kv_cache_spec(self) -> KVCacheSpec: + def get_kv_cache_specs(self) -> List[KVCacheSpec]: output = self.collective_rpc("get_kv_cache_spec") - for x in output: - assert x == output[0] - return output[0] + return output def execute_model( self, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 9b1eab613..5d8da7545 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -12,7 +12,7 @@ import torch.nn as nn from vllm.attention.backends.abstract import AttentionType from vllm.attention.layer import Attention from vllm.config import CompilationLevel, VllmConfig -from vllm.distributed.parallel_state import graph_capture +from vllm.distributed.parallel_state import get_pp_group, graph_capture from vllm.forward_context import set_forward_context from vllm.inputs import INPUT_REGISTRY from vllm.logger import init_logger @@ -21,6 +21,7 @@ from vllm.model_executor.model_loader import get_model from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal.utils import group_mm_inputs_by_modality from vllm.sampling_params import SamplingType +from vllm.sequence import IntermediateTensors from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, LayerBlockType, cdiv, is_pin_memory_available) from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend, @@ -773,6 +774,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): def execute_model( self, scheduler_output: "SchedulerOutput", + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> ModelRunnerOutput: batch_changed = 
self._update_states(scheduler_output) @@ -831,8 +833,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): positions=positions, kv_caches=self.kv_caches, attn_metadata=None, + intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, ) + if not get_pp_group().is_last_rank: + return hidden_states hidden_states = hidden_states[:num_scheduled_tokens] sample_hidden_states = hidden_states[logits_indices] logits = self.model.compute_logits(sample_hidden_states, None) @@ -1007,12 +1012,19 @@ class GPUModelRunner(LoRAModelRunnerMixin): positions = self.mrope_positions[:, :num_tokens] else: positions = self.positions[:num_tokens] + intermediate_tensors = None + if not get_pp_group().is_first_rank: + intermediate_tensors = self.model.make_empty_intermediate_tensors( + batch_size=num_tokens, + dtype=self.model_config.dtype, + device=self.device) with set_forward_context(None, self.vllm_config): hidden_states = model( input_ids=input_ids, positions=positions, kv_caches=kv_caches, attn_metadata=None, + intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, ) return hidden_states @@ -1142,6 +1154,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): # Trigger compilation for general shape. hidden_states = self._dummy_run(self.max_num_tokens, dummy_kv_caches) + if not get_pp_group().is_last_rank: + return hidden_states hidden_states = hidden_states[logit_indices] logits = self.model.compute_logits(hidden_states, None) # TODO(woosuk): Consider the memory usage of the sampler. 
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index ad53f90b8..beedca05c 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -2,7 +2,7 @@ """A GPU worker class.""" import gc import os -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, List, Optional import torch import torch.distributed @@ -194,8 +194,9 @@ class Worker: def get_kv_cache_spec(self) -> KVCacheSpec: return self.model_runner.get_kv_cache_spec() - def initialize_cache(self, kv_cache_config: KVCacheConfig) -> None: + def initialize_cache(self, kv_cache_configs: List[KVCacheConfig]) -> None: """Allocate GPU KV cache with the specified kv_cache_config.""" + kv_cache_config = kv_cache_configs[self.rank] if self.vllm_config.model_config.enable_sleep_mode: allocator = CuMemAllocator.get_instance() context = allocator.use_memory_pool(tag="kv_cache") -- GitLab From fa253f1a702511e35dbbe14ac7b144def7f12d7d Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Thu, 13 Feb 2025 16:31:37 +0800 Subject: [PATCH 115/253] [VLM] Remove input processor from clip and siglip (#13165) --- vllm/model_executor/models/clip.py | 149 ++------------------------- vllm/model_executor/models/siglip.py | 74 +------------ 2 files changed, 10 insertions(+), 213 deletions(-) diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index 1e784f5b4..547f62447 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -1,156 +1,24 @@ # SPDX-License-Identifier: Apache-2.0 """Minimal implementation of CLIPVisionModel intended to be only used within a vision language model.""" -from typing import Iterable, List, Optional, Set, Tuple, Union +from typing import Iterable, Optional, Set, Tuple, Union -import numpy as np import torch import torch.nn as nn -from PIL import Image from transformers import CLIPVisionConfig from vllm.attention.layer import MultiHeadAttention -from vllm.config import ModelConfig from 
vllm.distributed import divide, get_tensor_model_parallel_world_size -from vllm.inputs import DecoderOnlyInputs, token_inputs from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.multimodal.utils import (cached_get_tokenizer, - consecutive_placeholder_ranges, - repeat_and_pad_placeholder_tokens) -from vllm.sequence import SequenceData from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs -def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int: - assert image_size % patch_size == 0 - return image_size // patch_size - - -def get_clip_num_patches(*, image_size: int, patch_size: int) -> int: - grid_length = get_clip_patch_grid_length(image_size=image_size, - patch_size=patch_size) - return grid_length * grid_length - - -def get_clip_image_feature_size(hf_config: CLIPVisionConfig) -> int: - return get_clip_num_patches(image_size=hf_config.image_size, - patch_size=hf_config.patch_size) + 1 - - -def get_max_clip_image_tokens(hf_config: CLIPVisionConfig) -> int: - return get_clip_image_feature_size(hf_config) - - -def dummy_seq_data_for_clip(hf_config: CLIPVisionConfig, - seq_len: int, - num_images: int, - *, - image_token_id: int, - image_feature_size_override: Optional[int] = None, - mm_key: str = "image"): - if image_feature_size_override is None: - image_feature_size = get_clip_image_feature_size(hf_config) - else: - image_feature_size = image_feature_size_override - - return SequenceData.from_prompt_token_counts( - (image_token_id, image_feature_size * num_images), - (0, seq_len - image_feature_size * num_images), - ), { - mm_key: - consecutive_placeholder_ranges(num_items=num_images, - item_size=image_feature_size) - } - - -def dummy_image_for_clip( - 
hf_config: CLIPVisionConfig, - num_images: int, - *, - image_width_override: Optional[int] = None, - image_height_override: Optional[int] = None, -): - width = height = hf_config.image_size - if image_width_override is not None: - width = image_width_override - if image_height_override is not None: - height = image_height_override - - image = Image.new("RGB", (width, height), color=0) - return {"image": image if num_images == 1 else [image] * num_images} - - -def dummy_video_for_clip( - hf_config: CLIPVisionConfig, - num_frames: int, - num_videos: int = 1, - *, - image_width_override: Optional[int] = None, - image_height_override: Optional[int] = None, -): - pil_frame = dummy_image_for_clip( - hf_config, - num_images=1, - image_width_override=image_width_override, - image_height_override=image_height_override) - np_frame = np.array(pil_frame["image"]) - mm_data_per_video = np.repeat([np_frame], num_frames, axis=0) - video_data = [mm_data_per_video] * num_videos - mm_data = {"video": video_data} - return mm_data - - -def input_processor_for_clip( - model_config: ModelConfig, - hf_config: CLIPVisionConfig, - inputs: DecoderOnlyInputs, - *, - image_token_id: int, - image_feature_size_override: Optional[Union[int, List[int]]] = None, -): - multi_modal_data = inputs.get("multi_modal_data") - if multi_modal_data is None or "image" not in multi_modal_data: - return inputs - - if "multi_modal_placeholders" in inputs and "image" in inputs[ - "multi_modal_placeholders"]: - # The inputs already have placeholders. 
- return inputs - - tokenizer = cached_get_tokenizer(model_config.tokenizer) - - if image_feature_size_override is None: - image_data = multi_modal_data["image"] - if isinstance(image_data, Image.Image): - image_feature_size = get_clip_image_feature_size(hf_config) - elif isinstance(image_data, torch.Tensor): - num_images, image_feature_size, hidden_size = image_data.shape - else: - raise TypeError(f"Invalid image type: {type(image_data)}") - else: - image_feature_size = image_feature_size_override - - new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens( - tokenizer, - inputs.get("prompt"), - inputs["prompt_token_ids"], - placeholder_token_id=image_token_id, - repeat_count=image_feature_size, - ) - - # NOTE: Create a defensive copy of the original inputs - return token_inputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data, - multi_modal_placeholders={"image": ranges}) - - class CLIPEncoderInfo(VisionEncoderInfo[CLIPVisionConfig]): def get_num_image_tokens( @@ -159,10 +27,10 @@ class CLIPEncoderInfo(VisionEncoderInfo[CLIPVisionConfig]): image_width: int, image_height: int, ) -> int: - return get_clip_image_feature_size(self.vision_config) + return self.get_patch_grid_length()**2 + 1 def get_max_image_tokens(self) -> int: - return get_max_clip_image_tokens(self.vision_config) + return self.get_patch_grid_length()**2 + 1 def get_image_size(self) -> int: return self.vision_config.image_size @@ -171,10 +39,9 @@ class CLIPEncoderInfo(VisionEncoderInfo[CLIPVisionConfig]): return self.vision_config.patch_size def get_patch_grid_length(self) -> int: - return get_clip_patch_grid_length( - image_size=self.vision_config.image_size, - patch_size=self.vision_config.patch_size, - ) + image_size, patch_size = self.get_image_size(), self.get_patch_size() + assert image_size % patch_size == 0 + return image_size // patch_size # Adapted from 
https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa @@ -186,6 +53,7 @@ class CLIPVisionEmbeddings(nn.Module): self.embed_dim = config.hidden_size self.image_size = config.image_size self.patch_size = config.patch_size + assert self.image_size % self.patch_size == 0 self.class_embedding = nn.Parameter(torch.randn(self.embed_dim)) @@ -197,8 +65,7 @@ class CLIPVisionEmbeddings(nn.Module): bias=False, ) - self.num_patches = get_clip_num_patches(image_size=self.image_size, - patch_size=self.patch_size) + self.num_patches = (self.image_size // self.patch_size)**2 self.num_positions = self.num_patches + 1 self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index a81462f6f..ddae78d77 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -3,18 +3,15 @@ within a vision language model.""" import math -from typing import Iterable, List, Optional, Set, Tuple, Union +from typing import Iterable, Optional, Set, Tuple, Union -import numpy as np import torch from PIL import Image from torch import nn from transformers import SiglipVisionConfig from vllm.attention.layer import MultiHeadAttention -from vllm.config import ModelConfig from vllm.distributed import divide, get_tensor_model_parallel_world_size -from vllm.inputs import DecoderOnlyInputs, token_inputs from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, @@ -23,9 +20,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.multimodal.utils import (cached_get_tokenizer, - consecutive_placeholder_ranges, - 
repeat_and_pad_placeholder_tokens) +from vllm.multimodal.utils import consecutive_placeholder_ranges from vllm.sequence import SequenceData from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs @@ -93,71 +88,6 @@ def dummy_image_for_siglip( return {"image": image if num_images == 1 else [image] * num_images} -def dummy_video_for_siglip( - hf_config: SiglipVisionConfig, - num_frames: int, - num_videos: int = 1, - *, - image_width_override: Optional[int] = None, - image_height_override: Optional[int] = None, -): - pil_frame = dummy_image_for_siglip( - hf_config, - num_images=1, - image_width_override=image_width_override, - image_height_override=image_height_override) - np_frame = np.array(pil_frame["image"]) - mm_data_per_video = np.repeat([np_frame], num_frames, axis=0) - video_data = [mm_data_per_video] * num_videos - mm_data = {"video": video_data} - return mm_data - - -def input_processor_for_siglip( - model_config: ModelConfig, - hf_config: SiglipVisionConfig, - inputs: DecoderOnlyInputs, - *, - image_token_id: int, - image_feature_size_override: Optional[Union[int, List[int]]] = None, -): - multi_modal_data = inputs.get("multi_modal_data") - if multi_modal_data is None or "image" not in multi_modal_data: - return inputs - - if "multi_modal_placeholders" in inputs and "image" in inputs[ - "multi_modal_placeholders"]: - # The inputs already have placeholders. 
- return inputs - - tokenizer = cached_get_tokenizer(model_config.tokenizer) - - if image_feature_size_override is None: - image_data = multi_modal_data["image"] - if isinstance(image_data, Image.Image): - image_feature_size = get_siglip_image_feature_size(hf_config) - elif isinstance(image_data, torch.Tensor): - num_images, image_feature_size, hidden_size = image_data.shape - else: - raise TypeError(f"Invalid image type: {type(image_data)}") - else: - image_feature_size = image_feature_size_override - - new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens( - tokenizer, - inputs.get("prompt"), - inputs["prompt_token_ids"], - placeholder_token_id=image_token_id, - repeat_count=image_feature_size, - ) - - # NOTE: Create a defensive copy of the original inputs - return token_inputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data, - multi_modal_placeholders={"image": ranges}) - - class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]): def get_num_image_tokens( -- GitLab From 578087e56c21b2c940135a7c059c8dd88c8961f5 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Thu, 13 Feb 2025 03:51:46 -0500 Subject: [PATCH 116/253] [Frontend] Pass pre-created socket to uvicorn (#13113) --- vllm/entrypoints/api_server.py | 1 + vllm/entrypoints/launcher.py | 9 ++++++--- vllm/entrypoints/openai/api_server.py | 13 ++++++++++--- 3 files changed, 17 insertions(+), 6 deletions(-) diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py index 96818507d..00793d4b9 100644 --- a/vllm/entrypoints/api_server.py +++ b/vllm/entrypoints/api_server.py @@ -127,6 +127,7 @@ async def run_server(args: Namespace, shutdown_task = await serve_http( app, + sock=None, host=args.host, port=args.port, log_level=args.log_level, diff --git a/vllm/entrypoints/launcher.py b/vllm/entrypoints/launcher.py index 351a39525..79946a498 100644 --- a/vllm/entrypoints/launcher.py +++ b/vllm/entrypoints/launcher.py @@ -2,8 +2,9 @@ 
import asyncio import signal +import socket from http import HTTPStatus -from typing import Any +from typing import Any, Optional import uvicorn from fastapi import FastAPI, Request, Response @@ -17,7 +18,8 @@ from vllm.utils import find_process_using_port logger = init_logger(__name__) -async def serve_http(app: FastAPI, **uvicorn_kwargs: Any): +async def serve_http(app: FastAPI, sock: Optional[socket.socket], + **uvicorn_kwargs: Any): logger.info("Available routes are:") for route in app.routes: methods = getattr(route, "methods", None) @@ -34,7 +36,8 @@ async def serve_http(app: FastAPI, **uvicorn_kwargs: Any): loop = asyncio.get_running_loop() - server_task = loop.create_task(server.serve()) + server_task = loop.create_task( + server.serve(sockets=[sock] if sock else None)) def signal_handler() -> None: # prevents the uvicorn signal handler to exit early diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 127ee9414..588a7781c 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -10,7 +10,6 @@ import os import re import signal import socket -import sys import tempfile import uuid from argparse import Namespace @@ -831,6 +830,7 @@ def create_server_socket(addr: Tuple[str, int]) -> socket.socket: sock = socket.socket(family=family, type=socket.SOCK_STREAM) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) sock.bind(addr) return sock @@ -878,8 +878,17 @@ async def run_server(args, **uvicorn_kwargs) -> None: model_config = await engine_client.get_model_config() await init_app_state(engine_client, model_config, app.state, args) + def _listen_addr(a: str) -> str: + if is_valid_ipv6_address(a): + return '[' + a + ']' + return a or "0.0.0.0" + + logger.info("Starting vLLM API server on http://%s:%d", + _listen_addr(sock_addr[0]), sock_addr[1]) + shutdown_task = await serve_http( app, + sock=sock, host=args.host, 
port=args.port, log_level=args.uvicorn_log_level, @@ -888,8 +897,6 @@ async def run_server(args, **uvicorn_kwargs) -> None: ssl_certfile=args.ssl_certfile, ssl_ca_certs=args.ssl_ca_certs, ssl_cert_reqs=args.ssl_cert_reqs, - # Workaround to work on macOS - fd=sock.fileno() if sys.platform.startswith("darwin") else None, **uvicorn_kwargs, ) -- GitLab From fdcf64d3c6ad5ee2339a669214b43e4b735d0895 Mon Sep 17 00:00:00 2001 From: Roger Wang <136131678+ywang96@users.noreply.github.com> Date: Thu, 13 Feb 2025 03:43:24 -0800 Subject: [PATCH 117/253] [V1] Clarify input processing and multimodal feature caching logic (#13211) --- vllm/v1/engine/core.py | 16 +++++----- .../{mm_input_mapper.py => mm_input_cache.py} | 29 ++++++++++++------- vllm/v1/engine/processor.py | 20 +++++++++---- vllm/v1/worker/gpu_model_runner.py | 7 +++-- 4 files changed, 45 insertions(+), 27 deletions(-) rename vllm/v1/engine/{mm_input_mapper.py => mm_input_cache.py} (82%) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index e19680355..4642ac177 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -20,7 +20,7 @@ from vllm.v1.core.kv_cache_utils import get_kv_cache_configs from vllm.v1.core.scheduler import Scheduler from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, EngineCoreRequestType) -from vllm.v1.engine.mm_input_mapper import MMInputMapperServer +from vllm.v1.engine.mm_input_cache import MMInputCacheServer from vllm.v1.executor.abstract import Executor from vllm.v1.request import Request, RequestStatus from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder @@ -65,7 +65,7 @@ class EngineCore: log_stats=self.log_stats, ) - self.mm_input_mapper_server = MMInputMapperServer( + self.mm_input_cache_server = MMInputCacheServer( vllm_config.model_config) def _initialize_kv_caches(self, @@ -102,13 +102,13 @@ class EngineCore: """Add request to the scheduler.""" if request.mm_hashes is not None: - # Here, if hash exists for an image, then it will be 
fetched - # from the cache, else it will be added to the cache. - # Note that the cache here is mirrored with the client side of the - # MM mapper, so anything that has a hash must have a HIT cache - # entry here as well. + # Here, if hash exists for a multimodal input, then it will be + # fetched from the cache, else it will be added to the cache. + # Note that the cache here is mirrored with the client cache, so + # anything that has a hash must have a HIT cache entry here + # as well. assert request.mm_inputs is not None - request.mm_inputs = self.mm_input_mapper_server.process_inputs( + request.mm_inputs = self.mm_input_cache_server.get_and_update( request.mm_inputs, request.mm_hashes) req = Request.from_engine_core_request(request) diff --git a/vllm/v1/engine/mm_input_mapper.py b/vllm/v1/engine/mm_input_cache.py similarity index 82% rename from vllm/v1/engine/mm_input_mapper.py rename to vllm/v1/engine/mm_input_cache.py index 83a0d9db1..e1b6679c2 100644 --- a/vllm/v1/engine/mm_input_mapper.py +++ b/vllm/v1/engine/mm_input_cache.py @@ -10,12 +10,18 @@ from vllm.utils import LRUCache logger = init_logger(__name__) -# The idea of MM preprocessor caching is based on having a client and a server, -# where the client executes in the frontend process (=P0) and the server in the -# core process (=P1). +# The idea of multimodal preprocessing caching is based on having a client and +# a server, where the client executes in the frontend process (=P0) and the +# server in the core process (=P1). # -# -- Client: Executes the MM mapper and performs caching of the results. -# -- Server: Performs caching of the results +# -- Client: +# - Apply legacy input_mapper (if one exists) to generate MultiModalKwargs. +# - Perform caching of the generated MultiModalKwargs. +# - This client can be deprecated once all multimodal models migrate to use +# merged preprocessor with built-in caching functionality. +# +# -- Server: +# - Perform caching of the received MultiModalKwargs. 
# # The caching for both client and server is mirrored/similar, and this allows us # to avoid the serialization of "mm_inputs" (like pixel values) between @@ -27,7 +33,9 @@ logger = init_logger(__name__) MM_CACHE_SIZE = 256 -class MMInputMapperClient: +# TODO(ywang96): Deprecate this class once all multimodal models migrate to use +# merged preprocessor with built-in caching functionality. +class MMInputCacheClient: def __init__( self, @@ -54,7 +62,8 @@ class MMInputMapperClient: logger.debug("MMInputMapper: cache_hit_ratio = %.2f ", self.mm_cache_hits / self.mm_cache_total) - # TODO: Support modalities beyond image. + # NOTE: process_inputs only supports image inputs since all multimodal + # models with other modalities have migrated to use merged preprocessor. def process_inputs( self, mm_data: MultiModalDataDict, @@ -95,7 +104,7 @@ class MMInputMapperClient: # Reuse precomputed input (for merged preprocessor) mm_input = precomputed_mm_inputs[input_id] else: - # Apply MM mapper + # Apply legacy input_mapper mm_input = self.multi_modal_input_mapper( {"image": [image_inputs[input_id]]}, mm_processor_kwargs=mm_processor_kwargs, @@ -114,13 +123,13 @@ class MMInputMapperClient: return ret_inputs -class MMInputMapperServer: +class MMInputCacheServer: def __init__(self, model_config): self.use_cache = not model_config.disable_mm_preprocessor_cache self.mm_cache = LRUCache[str, MultiModalKwargs](MM_CACHE_SIZE) - def process_inputs( + def get_and_update( self, mm_inputs: List[Optional[MultiModalKwargs]], mm_hashes: List[str], diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 70876b03a..b7eee5a39 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -17,7 +17,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup from vllm.v1.engine import EngineCoreRequest -from vllm.v1.engine.mm_input_mapper 
import MMInputMapperClient +from vllm.v1.engine.mm_input_cache import MMInputCacheClient class Processor: @@ -46,7 +46,7 @@ class Processor: model_config) # Multi-modal (huggingface) input mapper - self.mm_input_mapper_client = MMInputMapperClient(model_config) + self.mm_input_cache_client = MMInputCacheClient(model_config) # Multi-modal hasher (for images) self.use_hash = (not model_config.disable_mm_preprocessor_cache) or \ @@ -106,16 +106,24 @@ class Processor: assert priority == 0, "vLLM V1 does not support priority at the moment." assert trace_headers is None, "vLLM V1 does not support tracing yet." - # Process inputs. + # Process inputs, which includes: + # 1. Tokenize text prompt, with LoRA request if one exists. + # 2. For multimodal models with a merged preprocessor, preprocess + # multimodal data and expand prompt token ids accordingly. + # 3. Apply prompt adapter to prompt token ids if one exists. preprocessed_inputs = self.input_preprocessor.preprocess( prompt, request_id=request_id, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, ) + eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request) + + # Process prompt and prompt token ids. + # Only applicable to multimodal models with legacy input processor. processed_inputs = self.input_processor(preprocessed_inputs) + self._validate_model_inputs(processed_inputs) - eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request) if is_encoder_decoder_inputs(processed_inputs): decoder_inputs = SingletonInputsAdapter( @@ -200,8 +208,8 @@ class Processor: key=lambda mm_input: modality_order_dict[list( mm_input.modalities)[0]]) - # Apply mm input cache update (and input mapper if necessary). - sorted_mm_inputs = self.mm_input_mapper_client.process_inputs( + # Apply mm input cache update and legacy input mapper if one exists. 
+ sorted_mm_inputs = self.mm_input_cache_client.process_inputs( mm_data=decoder_mm_data, mm_hashes=sorted_mm_hashes, mm_processor_kwargs=decoder_inputs.mm_processor_kwargs, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 5d8da7545..fa4bd81a2 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -27,7 +27,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend, FlashAttentionMetadata) from vllm.v1.core.encoder_cache_manager import compute_encoder_budget -from vllm.v1.engine.mm_input_mapper import MMInputMapperClient +from vllm.v1.engine.mm_input_cache import MMInputCacheClient from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheSpec) from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput @@ -95,9 +95,10 @@ class GPUModelRunner(LoRAModelRunnerMixin): self.mm_registry = MULTIMODAL_REGISTRY self.uses_mrope = model_config.uses_mrope - # NOTE: Initialized input mapper is only used for processing dummy + # NOTE: Initialized client is only used for processing dummy # multimodal data into multimodal kwargs for GPU memory profiling. - self.mm_input_mapper_profiling = MMInputMapperClient(self.model_config) + # Only applicable to multimodal models with legacy input mapper. 
+ self.mm_input_mapper_profiling = MMInputCacheClient(self.model_config) self.mm_input_mapper_profiling.use_cache = False encoder_compute_budget, encoder_cache_size = compute_encoder_budget( -- GitLab From c9d3ecf016deb619a3bce6d59049224e8eb12364 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 13 Feb 2025 20:34:00 +0800 Subject: [PATCH 118/253] [VLM] Merged multi-modal processor for Molmo (#12966) --- docs/source/models/supported_models.md | 2 +- .../decoder_only/language/test_models.py | 2 +- .../vision_language/test_models.py | 5 +- .../vision_language/vlm_utils/model_utils.py | 98 +- .../multimodal/processing/test_common.py | 2 + tests/models/registry.py | 1 + vllm/model_executor/models/molmo.py | 1023 +++++++++++------ vllm/multimodal/inputs.py | 80 +- vllm/utils.py | 35 +- 9 files changed, 750 insertions(+), 498 deletions(-) diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index 55b3f5235..86b746178 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -793,7 +793,7 @@ See [this page](#generative-models) for more information on how to use generativ - * `MolmoForCausalLM` * Molmo * T + I - * `allenai/Molmo-7B-D-0924`, `allenai/Molmo-72B-0924`, etc. + * `allenai/Molmo-7B-D-0924`, `allenai/Molmo-7B-O-0924`, etc. 
* ✅︎ * ✅︎ * ✅︎ diff --git a/tests/models/decoder_only/language/test_models.py b/tests/models/decoder_only/language/test_models.py index c6d524431..71e4a9f11 100644 --- a/tests/models/decoder_only/language/test_models.py +++ b/tests/models/decoder_only/language/test_models.py @@ -27,7 +27,7 @@ from ...utils import check_logprobs_close marks=[pytest.mark.core_model, pytest.mark.cpu_model], ), pytest.param( - "THUDM/chatglm3-6b", # ChatGLM (text-only) + "THUDM/chatglm3-6b", # chatglm (text-only) ), pytest.param( "meta-llama/Llama-3.2-1B-Instruct", # llama diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py index b00ec6fa6..4ed61cfc9 100644 --- a/tests/models/decoder_only/vision_language/test_models.py +++ b/tests/models/decoder_only/vision_language/test_models.py @@ -404,11 +404,10 @@ VLM_TEST_SETTINGS = { "molmo": VLMTestInfo( models=["allenai/Molmo-7B-D-0924"], test_type=(VLMTestType.IMAGE), - prompt_formatter=lambda img_prompt:"User: " + img_prompt + " Assistant:", # noqa: E501 + prompt_formatter=identity, max_model_len=4096, max_num_seqs=2, - image_size_factors=[(),(1.0, 1.0, 1.0)], - patch_hf_runner=model_utils.mlomo_patch_hf_runner, + patch_hf_runner=model_utils.molmo_patch_hf_runner, postprocess_inputs=model_utils.molmo_post_processor, ), # Tests for phi3v currently live in another file because of a bug in diff --git a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py index ced891e1e..408ce9cfe 100644 --- a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py +++ b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py @@ -6,7 +6,7 @@ typically specific to a small subset of models. 
import re import types from pathlib import PosixPath -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch from PIL.Image import Image @@ -17,9 +17,7 @@ from vllm.sequence import SampleLogprobs from vllm.transformers_utils.tokenizer import patch_padding_side from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE -from .....conftest import (HfRunner, ImageAsset, PromptAudioInput, - PromptImageInput, PromptVideoInput, _ImageAssets) -from ....utils import TokensTextLogprobs +from .....conftest import HfRunner, ImageAsset, _ImageAssets from .types import RunnerOutput @@ -522,74 +520,7 @@ def minicpmo_patch_hf_runner(hf_model: HfRunner) -> HfRunner: return hf_model -def _generate_greedy_logprobs_limit( - self, - prompts: List[str], - max_tokens: int, - num_logprobs: int, - images: Optional[PromptImageInput] = None, - audios: Optional[PromptAudioInput] = None, - videos: Optional[PromptVideoInput] = None, - **kwargs: Any, -) -> List[TokensTextLogprobs]: - all_inputs = self.get_inputs(prompts, - images=images, - videos=videos, - audios=audios) - - # Process in batches for inference. 
- if len(all_inputs): - input_ids_lst = [] - images_lst = [] - images_input_idx_lst = [] - imges_masks_lst = [] - for inputs in all_inputs: - input_ids_lst.append(inputs["input_ids"]) - images_lst.append(inputs["images"]) - images_input_idx_lst.append(inputs["image_input_idx"]) - imges_masks_lst.append(inputs["image_masks"]) - batch_inputs = {} - batch_inputs['input_ids'] = torch.cat(input_ids_lst, dim=0) - batch_inputs['images'] = torch.cat(images_lst, dim=0) - batch_inputs['image_input_idx'] = torch.cat(images_input_idx_lst, - dim=0) - batch_inputs['image_masks'] = torch.cat(imges_masks_lst, dim=0) - - outputs = self.model.generate_from_batch( - batch=self.wrap_device(batch_inputs, - device=self.model.device.type), - generation_config=GenerationConfig( - max_new_tokens=max_tokens, - stop_strings="<|endoftext|>", - do_sample=False, - ), - tokenizer=self.tokenizer, - output_hidden_states=True, - return_dict_in_generate=True, - ) - - all_logprobs: List[List[Dict[int, float]]] = [] - all_output_ids: List[List[int]] = [] - all_output_strs: List[str] = [] - - for index in range(len(all_inputs)): - ( - seq_logprobs_lst, - output_len, - ) = self._hidden_states_to_logprobs(outputs.hidden_states, - num_logprobs) - all_logprobs.append(seq_logprobs_lst) - seq_ids = outputs.sequences[index] - output_ids = seq_ids[-output_len:] - all_output_ids.append(output_ids.tolist()) - all_output_strs.append(self.tokenizer.decode(output_ids)) - outputs = zip(all_output_ids, all_output_strs, all_logprobs) - return [(output_ids, output_str, output_logprobs) - for output_ids, output_str, output_logprobs in outputs] - - -####### Molmo-specific HuggingFace runner patchers -def mlomo_patch_hf_runner(hf_model: HfRunner) -> HfRunner: +def molmo_patch_hf_runner(hf_model: HfRunner) -> HfRunner: """Patches and returns an instance of the HfRunner to use for Molmo.""" hf_processor = hf_model.processor @@ -598,10 +529,23 @@ def mlomo_patch_hf_runner(hf_model: HfRunner) -> HfRunner: hf_model.processor = 
_processor - setattr( # noqa: B010 - hf_model, - "generate_greedy_logprobs_limit", - types.MethodType(_generate_greedy_logprobs_limit, hf_model), - ) + def _generate(self, max_new_tokens=None, do_sample=None, **kwargs): + batch = { + k: kwargs.pop(k) + for k in ("input_ids", "images", "image_input_idx", "image_masks") + if k in kwargs + } + + return self.generate_from_batch( + batch, + generation_config=GenerationConfig( + max_new_tokens=max_new_tokens, + stop_strings="<|endoftext|>", + do_sample=do_sample, + ), + **kwargs, + ) + + hf_model.model.generate = types.MethodType(_generate, hf_model.model) return hf_model diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index 67ef8b17a..88dcc32f4 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -168,6 +168,8 @@ def _test_processing_correctness( "mistral-community/pixtral-12b", "openbmb/MiniCPM-o-2_6", "openbmb/MiniCPM-V-2_6", + "allenai/Molmo-7B-D-0924", + "allenai/Molmo-7B-O-0924", "nvidia/NVLM-D-72B", "Qwen/Qwen-VL-Chat", "Qwen/Qwen2-VL-2B-Instruct", diff --git a/tests/models/registry.py b/tests/models/registry.py index 7b1db5549..66a487ca6 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -256,6 +256,7 @@ _MULTIMODAL_EXAMPLE_MODELS = { "MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-V-2_6", trust_remote_code=True), "MolmoForCausalLM": _HfExamplesInfo("allenai/Molmo-7B-D-0924", + extras={"olmo": "allenai/Molmo-7B-O-0924"}, # noqa: E501 trust_remote_code=True), "NVLM_D": _HfExamplesInfo("nvidia/NVLM-D-72B", trust_remote_code=True), diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index b524a1497..feb585022 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -1,18 +1,20 @@ # SPDX-License-Identifier: Apache-2.0 import math -import re -from array import array from dataclasses import 
dataclass -from functools import lru_cache, partial -from typing import Iterable, List, Mapping, Optional, Set, Tuple, TypedDict +from functools import cached_property, partial +from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict, + Union, cast) +import numpy as np import torch +import torch.nn as nn +import torch.nn.functional as F from einops import rearrange -from PIL import Image -from torch import nn -from torch.nn import functional as F -from transformers import PretrainedConfig +from transformers import (BatchFeature, PretrainedConfig, ProcessorMixin, + TensorType) +from transformers.image_utils import ImageInput +from transformers.tokenization_utils_base import TextInput from vllm.attention import Attention, AttentionMetadata from vllm.attention.layer import MultiHeadAttention @@ -22,8 +24,6 @@ from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, split_tensor_along_last_dim, tensor_model_parallel_all_gather) -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, - InputContext, token_inputs) from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.activation import (MulAndSilu, QuickGELU, SiluAndMul) @@ -40,15 +40,21 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs -from vllm.multimodal.inputs import NestedTensors, PlaceholderRange -from vllm.multimodal.utils import cached_get_tokenizer -from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, - SequenceData) -from vllm.transformers_utils.processor import get_processor +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, + 
NestedTensors) +from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, + MultiModalDataItems) +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement, + PromptReplacementDetails) +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.sequence import IntermediateTensors +from vllm.utils import JSONTree, json_map_leaves from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP -from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter, +from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, + is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix, merge_multimodal_embeddings) @@ -56,38 +62,39 @@ from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter, VIT_LAYERS = [-2, -9] NUM_PREFIX_TOKENS = 1 ADDITIONAL_VOCAB_SIZE = 128 -DEFAULT_IMAGE_PATCH_TOKEN_ID = 152066 -DEFAULT_IM_START_TOKEN_ID = 152067 -DEFAULT_IM_END_TOKEN_ID = 152064 -DEFAULT_IM_COL_TOKEN_ID = 152065 +IMAGE_PATCH_TOKEN = "" +IM_COL_TOKEN = "" +IM_START_TOKEN = "" +IM_END_TOKEN = "" +POOLING_SIZE = 2 class MolmoImageInputs(TypedDict): - images: torch.Tensor - """Shape: - `(batch_size, num_crops, num_patch, patch_dim)` - """ + images: Union[torch.Tensor, List[torch.Tensor]] + """Shape: `(batch_size, num_crops, num_patch, patch_dim)`""" - image_input_idx: torch.Tensor - """Shape: - `(batch_size, num_crops, num_patch)` - """ + image_masks: Optional[Union[torch.Tensor, List[torch.Tensor]]] + """Shape: `(batch_size, num_crops, num_patch)`""" - seq_len: torch.Tensor - """Shape: - `(batch_size, )` + feat_is_patch: Union[torch.Tensor, List[torch.Tensor]] """ + A boolean mask indicating which image features correspond + to patch tokens. 
- image_masks: Optional[torch.Tensor] - """Shape: - `(batch_size, num_crops, num_patch)` + Shape: `(batch_size, num_crops, num_patch)` """ - image_start_end: Tuple[int, int] - """Starting and ending index of placeholder - tokens + embed_is_patch: Union[torch.Tensor, List[torch.Tensor]] + """ + A boolean mask indicating which image embeddings correspond + to patch tokens. + + Shape: `(batch_size, num_embeds)` """ + num_crops: torch.Tensor + """Shape: `(batch_size, num_images)`""" + @dataclass class VisionBackboneConfig: @@ -335,7 +342,7 @@ class VisionTransformer(nn.Module): def forward(self, x: torch.Tensor, - patch_num: int = None) -> List[torch.Tensor]: + patch_num: Optional[int] = None) -> List[torch.Tensor]: """ : param x: (batch_size, num_patch, n_pixels) """ @@ -465,7 +472,7 @@ class MolmoAttention(nn.Module): return output -class LanuageModelMLP(nn.Module): +class LanguageModelMLP(nn.Module): """Molmo's LLM mlp.""" def __init__(self, @@ -559,7 +566,7 @@ class MolmoDecoderLayer(nn.Module): prefix=f"{prefix}.self_attn") # MLP block. 
- self.mlp = LanuageModelMLP(config, quant_config=quant_config) + self.mlp = LanguageModelMLP(config, quant_config=quant_config) # LayerNorm assert config.layer_norm_type == "rms" @@ -638,8 +645,8 @@ class MolmoVisionBackbone(nn.Module): self.vit_layers = VIT_LAYERS self.image_num_patch = vision_config.image_num_patch self.llm_patches_per_crop = ( - (self.image_num_patch[0] + 1) // 2, - (self.image_num_patch[1] + 1) // 2, + (self.image_num_patch[0] + 1) // POOLING_SIZE, + (self.image_num_patch[1] + 1) // POOLING_SIZE, ) self.image_vit = VisionTransformer(vision_config, quant_config=quant_config) @@ -723,19 +730,19 @@ class MolmoVisionBackbone(nn.Module): image_features = image_features.reshape( (batch_size, num_image) + self.image_num_patch + (-1, ), ) - if self.image_num_patch[0] % 2 == 1: - # Pad so we can still pool 2x2 patches + if (missing_w := self.image_num_patch[0] % POOLING_SIZE): + # Padding for image pooling (see below) image_features = F.pad( image_features, - (0, 0, 0, 1, 0, 1, 0, 0, 0, 0), + (0, 0, 0, missing_w, 0, missing_w, 0, 0, 0, 0), ) # image pooling image_features = rearrange( image_features, 'b n (h dh) (w dw) c -> (b n h w) (dh dw) c', - dh=2, - dw=2, + dh=POOLING_SIZE, + dw=POOLING_SIZE, ) query = image_features.mean(-2, keepdim=True) @@ -888,249 +895,513 @@ class MolmoModel(nn.Module): return loaded_params -cached_get_processor = lru_cache(get_processor) +def _lowest_multiple(x: int, k: int) -> int: + return (x // k) * k + +def get_num_patches( + num_tiles: int, + *, + crop_patches: int, + left_margin: int, + right_margin: int, + pooling_size: int, +) -> int: + if num_tiles == 1: + return _lowest_multiple(crop_patches + pooling_size - 1, pooling_size) -def get_num_patches(num_tiles: int, crop_patches: int, left_margin: int, - right_margin: int, pooling_size: int) -> int: crop_window_patches = crop_patches - (left_margin + right_margin) - if num_tiles > 1: - left_crop_window_patches = (crop_window_patches + left_margin + - pooling_size - - 
1) // pooling_size * pooling_size - middle_crop_window_patches = (crop_window_patches + pooling_size - - 1) // pooling_size * pooling_size - right_crop_window_patches = (crop_window_patches + right_margin + - pooling_size - - 1) // pooling_size * pooling_size - return left_crop_window_patches + ( - num_tiles - - 2) * middle_crop_window_patches + right_crop_window_patches - else: - single_crop_window_patches = (crop_patches + pooling_size - - 1) // pooling_size * pooling_size - return single_crop_window_patches - - -def get_tokens(tiling_h: int, tiling_w: int, crop_patches: int, - left_margin: int, right_margin: int, pooling_size: int) -> int: - h = get_num_patches(tiling_h, crop_patches, left_margin, right_margin, - pooling_size) - w = get_num_patches(tiling_w, crop_patches, left_margin, right_margin, - pooling_size) - per_row = w // pooling_size + 1 - joint = per_row * (h // pooling_size) + 2 - image_token_length = (crop_patches + pooling_size - 1) // pooling_size - resize = (image_token_length + 1) * image_token_length + 2 - return resize + joint - - -def get_max_tokens(max_crops: int, crop_patches: int, left_margin: int, - right_margin: int, pooling_size: int) -> int: - tilings = [] - for i in range(1, max_crops + 1): - for j in range(1, max_crops + 1): - if i * j <= max_crops: - tilings.append((i, j)) - tokens = [ - get_tokens(tilings[i][0], tilings[i][1], crop_patches, left_margin, - right_margin, pooling_size) for i in range(len(tilings)) - ] - return max(tokens) - - -def get_max_molmo_image_tokens(ctx: InputContext) -> int: - processor = cached_get_processor( - ctx.model_config.model, - trust_remote_code=ctx.model_config.trust_remote_code, - revision=ctx.model_config.code_revision) - image_processor = processor.image_processor - max_llm_image_tokens = get_max_tokens( - image_processor.max_crops, - image_processor.base_image_input_size[0] // - image_processor.image_patch_size, - image_processor.overlap_margins[0], - image_processor.overlap_margins[1], - 2, + 
+ left_num = _lowest_multiple( + crop_window_patches + left_margin + pooling_size - 1, + pooling_size, + ) + middle_num = _lowest_multiple( + crop_window_patches + pooling_size - 1, + pooling_size, + ) + right_num = _lowest_multiple( + crop_window_patches + right_margin + pooling_size - 1, + pooling_size, ) - return max_llm_image_tokens + return left_num + (num_tiles - 2) * middle_num + right_num + + +def get_patches_grid_size( + *, + tiling_h: int, + tiling_w: int, + crop_patches: int, + left_margin: int, + right_margin: int, + pooling_size: int, +) -> tuple[int, int]: + nrows = get_num_patches( + tiling_h, + crop_patches=crop_patches, + left_margin=left_margin, + right_margin=right_margin, + pooling_size=pooling_size, + ) + ncols = get_num_patches( + tiling_w, + crop_patches=crop_patches, + left_margin=left_margin, + right_margin=right_margin, + pooling_size=pooling_size, + ) -# NOTE: preprocessing for the image data has been included in the -# 'input_processor_for_molmo' function -def image_input_mapper_for_molmo( - ctx: InputContext, - data: object, -): - if isinstance(data, list): - assert len(data) == 1, "Molmo supports only one image per prompt." 
- data = data[0] - - return MultiModalKwargs(data) - - -def dummy_data_for_molmo(ctx: InputContext, seq_len: int, - mm_counts: Mapping[str, int]): - processor = cached_get_processor( - ctx.model_config.model, - trust_remote_code=ctx.model_config.trust_remote_code, - revision=ctx.model_config.code_revision) - image_processor = processor.image_processor - - base_image_input_d = image_processor.image_patch_size - left_margin, right_margin = image_processor.overlap_margins - max_crops = image_processor.max_crops - - # Assume: prompt_token_ids always starts with bos_token_id followed image tokens # noqa: E501 - max_llm_image_tokens = get_max_molmo_image_tokens(ctx) - if seq_len - max_llm_image_tokens - 1 < 0: - raise RuntimeError( - f"Molmo cannot process {max_crops} crops in a prompt, " - "please increase max_model_len or reduce number of crops") - - # The vertical image has the maximum number of image tokens due to column tokens. # noqa: E501 - tiling = (max_crops, 1) - total_margin_pixels = base_image_input_d * (right_margin + left_margin) - crop_patches = image_processor.base_image_input_size[ - 0] // base_image_input_d - crop_window_patches = crop_patches - (right_margin + left_margin) - crop_window_size = crop_window_patches * base_image_input_d - - h = crop_window_size * tiling[0] + total_margin_pixels - w = crop_window_size * tiling[1] + total_margin_pixels - - dummy_image = Image.new("RGB", (w, h), color="red") - - out = processor.process("dummy prompt", dummy_image) - - token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, - out["input_ids"][:1 + max_llm_image_tokens]) - token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, - [0]) * (seq_len - max_llm_image_tokens - 1) - dummy_seqdata = SequenceData(token_ids) - dummy_imgdata = { - "images": out["images"], - "image_input_idx": out["image_input_idx"], - } - if "image_masks" in out: - dummy_imgdata["image_masks"] = out["image_masks"] - dummy_imgdata["seq_len"] = torch.tensor(seq_len, dtype=torch.long) - size = 0 - offset = -1 - 
for i in range(len(token_ids)): - if token_ids[i] in (DEFAULT_IMAGE_PATCH_TOKEN_ID, - DEFAULT_IM_START_TOKEN_ID, DEFAULT_IM_END_TOKEN_ID, - DEFAULT_IM_COL_TOKEN_ID): - if offset < 0: - offset = i - size += 1 - dummy_imgdata["image_start_end"] = (offset, offset + size) - return DummyData(seq_data=dummy_seqdata, - multi_modal_data={"image": dummy_imgdata}, - multi_modal_placeholders={ - "image": - [PlaceholderRange(offset=offset, length=size)] - }) - - -def pad_images( - max_total_crops: int, - images: torch.Tensor, - image_input_idx: torch.Tensor, - image_masks: Optional[torch.Tensor] = None, + return nrows, ncols + + +def get_candidate_tilings(max_num: int) -> list[tuple[int, int]]: + tilings = [(i, j) for i in range(1, max_num + 1) + for j in range(1, max_num + 1) if i * j <= max_num] + return sorted(tilings, key=lambda x: x[0] * x[1]) + + +def select_tiling( + *, + height: int, + width: int, + patch_size: int, + max_num_patches: int, ): - n = max_total_crops - images.shape[0] - images = F.pad(images, (0, 0, 0, 0, 0, n), value=-1) - image_input_idx = F.pad(image_input_idx, (0, 0, 0, n), value=-1) - if image_masks is not None: - image_masks = F.pad(image_masks, (0, 0, 0, n), value=-1) - return images, image_input_idx, image_masks - - -def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs): - prompt = inputs.get("prompt") - multi_modal_data = inputs.get("multi_modal_data") - image = None if multi_modal_data is None else multi_modal_data.get("image") - - model_config = ctx.model_config - processor = cached_get_processor( - ctx.model_config.model, - trust_remote_code=model_config.trust_remote_code, - revision=ctx.model_config.code_revision) - tokenizer = cached_get_tokenizer( - model_config.tokenizer, - trust_remote_code=model_config.trust_remote_code) - - # NOTE: message formatting for raw text prompt is only applied for - # offline inference; for online serving, the prompt is always in - # instruction format and tokenized. 
- if prompt is not None and re.match(r"^User:[\s\S]*?(Assistant:)*$", - prompt): - out = processor.process(prompt, image, message_format="none") - elif prompt is not None: - out = processor.process(prompt, image) + tilings = get_candidate_tilings(max_num_patches) + candidate_tilings = np.array(tilings, dtype=np.int32) + candidate_resolutions = candidate_tilings * patch_size + + original_size = np.array([height, width], dtype=np.float32) + required_scale_d = candidate_resolutions.astype(np.float32) / original_size + required_scale = required_scale_d.min(axis=-1, keepdims=True) + + if (required_scale < 1).all(): + ix = required_scale.argmax() else: - out = processor.process(None, image, tokens=inputs["prompt_token_ids"]) - - # If there is no image, return directly. - if image is None: - new_prompt_token_ids = out["input_ids"].tolist() - prompt = inputs.get("prompt") - if prompt is None: - prompt = tokenizer.decode(new_prompt_token_ids) - return token_inputs( - prompt_token_ids=new_prompt_token_ids, - prompt=prompt, + ix = np.where(required_scale < 1.0, 10e9, required_scale).argmin() + + return candidate_tilings[ix] + + +class MolmoProcessorWrapper: + """ + Wraps :class:`MolmoProcessor` so that it can be called directly. 
+ + The original definition can be found here: + https://huggingface.co/allenai/Molmo-7B-D-0924/blob/main/preprocessing_molmo.py + """ + + def __init__(self, processor: ProcessorMixin): + super().__init__() + + self.processor = processor + + @cached_property + def vocab(self) -> dict[str, int]: + return self.processor.tokenizer.vocab # type: ignore + + @cached_property + def max_crops(self) -> int: + image_processor = self.processor.image_processor # type: ignore + + max_crops = image_processor.max_crops + assert isinstance(max_crops, int) + + return max_crops + + @cached_property + def base_image_input_size(self) -> tuple[int, int]: + image_processor = self.processor.image_processor # type: ignore + + base_image_input_size = image_processor.base_image_input_size + if isinstance(base_image_input_size, int): + return base_image_input_size, base_image_input_size + + return tuple(base_image_input_size) + + @cached_property + def image_patch_size(self) -> int: + image_processor = self.processor.image_processor # type: ignore + + image_patch_size = image_processor.image_patch_size + assert isinstance(image_patch_size, int) + + return image_patch_size + + @cached_property + def overlap_margins(self) -> tuple[int, int]: + image_processor = self.processor.image_processor # type: ignore + + left_margin, right_margin = image_processor.overlap_margins + assert isinstance(left_margin, int) + assert isinstance(right_margin, int) + + return left_margin, right_margin + + @cached_property + def image_token_length_w(self) -> int: + image_processor = self.processor.image_processor # type: ignore + + image_token_length_w = image_processor.image_token_length_w + assert isinstance(image_token_length_w, int) + + return image_token_length_w + + @cached_property + def image_token_length_h(self) -> int: + image_processor = self.processor.image_processor # type: ignore + + image_token_length_h = image_processor.image_token_length_h + assert isinstance(image_token_length_h, int) + + return 
image_token_length_h + + @property + def message_format(self) -> Optional[str]: + return "role" + + @property + def always_start_with_space(self) -> bool: + return True + + @cached_property + def image_patch_id(self) -> int: + return self.vocab[IMAGE_PATCH_TOKEN] + + @cached_property + def im_col_id(self) -> int: + return self.vocab[IM_COL_TOKEN] + + @cached_property + def im_start_id(self) -> int: + return self.vocab[IM_START_TOKEN] + + @cached_property + def im_end_id(self) -> int: + return self.vocab[IM_END_TOKEN] + + @property + def pooling_size(self) -> int: + return POOLING_SIZE + + def select_tiling( + self, + *, + image_width: int, + image_height: int, + ) -> tuple[int, int]: + max_crops = self.max_crops + left_margin, right_margin = self.overlap_margins + base_image_input_size = self.base_image_input_size + base_image_input_d = self.image_patch_size + + total_margin_pixels = base_image_input_d * (right_margin + left_margin) + crop_patches = base_image_input_size[0] // base_image_input_d + crop_window_patches = crop_patches - (right_margin + left_margin) + crop_window_size = crop_window_patches * base_image_input_d + tiling_h, tiling_w = select_tiling( + height=image_height - total_margin_pixels, + width=image_width - total_margin_pixels, + patch_size=crop_window_size, + max_num_patches=max_crops, ) - image_processor = processor.image_processor - max_total_crops = 1 + image_processor.max_crops - images, image_input_idx, image_masks = pad_images( - max_total_crops, - out["images"], - out["image_input_idx"], - out.get("image_masks"), - ) - image_data = dict( - images=images, - image_input_idx=image_input_idx, - ) - if image_masks is not None: - image_data["image_masks"] = image_masks - - new_prompt_token_ids = out["input_ids"].tolist() - image_data["seq_len"] = torch.tensor(len(new_prompt_token_ids), - dtype=torch.long) - - multi_modal_data = dict(image=image_data) - size = 0 - offset = -1 - for i in range(len(new_prompt_token_ids)): - if 
new_prompt_token_ids[i] in (DEFAULT_IMAGE_PATCH_TOKEN_ID, - DEFAULT_IM_START_TOKEN_ID, - DEFAULT_IM_END_TOKEN_ID, - DEFAULT_IM_COL_TOKEN_ID): - if offset < 0: - offset = i - size += 1 - image_data["image_start_end"] = (offset, offset + size) - prompt = inputs.get("prompt") - if prompt is None: - prompt = tokenizer.decode(new_prompt_token_ids) - return token_inputs( - prompt_token_ids=new_prompt_token_ids, - prompt=prompt, - multi_modal_data=multi_modal_data, - multi_modal_placeholders={ - "image": [PlaceholderRange(offset=offset, length=size)] - }, - ) + return tiling_w, tiling_h + + def get_patches_grid_size( + self, + *, + image_width: int, + image_height: int, + ) -> tuple[int, int]: + left_margin, right_margin = self.overlap_margins + base_image_input_size = self.base_image_input_size + base_image_input_d = self.image_patch_size + pooling_size = self.pooling_size + + crop_patches = base_image_input_size[0] // base_image_input_d + tiling_w, tiling_h = self.select_tiling( + image_height=image_height, + image_width=image_width, + ) + + nrows, ncols = get_patches_grid_size( + tiling_h=tiling_h, + tiling_w=tiling_w, + crop_patches=crop_patches, + left_margin=left_margin, + right_margin=right_margin, + pooling_size=pooling_size, + ) + + return ncols, nrows + + def __call__( + self, + text: Optional[Union[TextInput, list[TextInput]]] = None, + images: Optional[Union[ImageInput, list[ImageInput]]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + **kwargs, + ) -> BatchFeature: + outputs = self.processor.process( # type: ignore + text, images, **kwargs) + + if images is None: + images = [] + if not isinstance(images, list): + images = [images] + + input_ids: torch.Tensor = outputs.pop("input_ids") + outputs["input_ids"] = input_ids.unsqueeze(0) + + image_input_idx = outputs.pop("image_input_idx", None) + if image_input_idx is not None: + input_is_patch = input_ids == self.image_patch_id + image_input_idx_flat: torch.Tensor = image_input_idx.view(-1) + 
image_valid_flat = image_input_idx_flat >= 0 + feat_is_patch_flat = image_valid_flat.clone() + feat_is_patch_flat[image_valid_flat] = ( + input_is_patch[image_input_idx_flat[image_valid_flat]]) + feat_is_patch = feat_is_patch_flat.view(*image_input_idx.shape) + + input_is_embed = torch.isin( + input_ids, + torch.tensor([ + self.image_patch_id, + self.im_col_id, + self.im_start_id, + self.im_end_id, + ]), + ) + embed_ids = input_ids[input_is_embed] + embed_is_patch = embed_ids == self.image_patch_id + assert embed_is_patch.sum() == feat_is_patch.sum() + tilings = [ + self.select_tiling( + image_width=image.size[0], + image_height=image.size[1], + ) for image in images + ] + # For each image: tiling_h * tiling_w + extra + num_crops = torch.tensor(tilings).prod(-1) + 1 + assert num_crops.sum() == len(feat_is_patch) -@MULTIMODAL_REGISTRY.register_image_input_mapper(image_input_mapper_for_molmo) -@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_molmo_image_tokens) -@INPUT_REGISTRY.register_dummy_data(dummy_data_for_molmo) -@INPUT_REGISTRY.register_input_processor(input_processor_for_molmo) + outputs["feat_is_patch"] = feat_is_patch + outputs["embed_is_patch"] = embed_is_patch + outputs["num_crops"] = num_crops + outputs["img_patch_id"] = self.image_patch_id + + return BatchFeature(outputs, tensor_type=return_tensors) + + +class MolmoProcessingInfo(BaseProcessingInfo): + + def get_hf_processor(self) -> MolmoProcessorWrapper: + processor = self.ctx.get_hf_processor() + return MolmoProcessorWrapper(processor) + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": 1} + + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + return {"image": self.get_max_image_tokens()} + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + processor: Optional[MolmoProcessorWrapper], + ) -> int: + if processor is None: + processor = self.get_hf_processor() + + 
ncols, nrows = processor.get_patches_grid_size( + image_width=image_width, + image_height=image_height, + ) + pooling_size = processor.pooling_size + + base_image_input_size = processor.base_image_input_size + base_image_input_d = processor.image_patch_size + + crop_patches = base_image_input_size[0] // base_image_input_d + + per_row = ncols // pooling_size + 1 + joint = per_row * (nrows // pooling_size) + 2 + image_token_length = (crop_patches + pooling_size - 1) // pooling_size + resize = (image_token_length + 1) * image_token_length + 2 + + return resize + joint + + def get_max_image_tokens(self) -> int: + target_width, target_height = self.get_image_size_with_most_features() + + return self.get_num_image_tokens( + image_width=target_width, + image_height=target_height, + processor=None, + ) + + def get_image_size_with_most_features(self) -> ImageSize: + processor = self.get_hf_processor() + + tilings = get_candidate_tilings(processor.max_crops) + base_h, base_w = processor.base_image_input_size + + largest_feature_size, largest_feature_pinpoint = 0, None + for wr, hr in tilings: + width, height = base_w * wr, base_h * hr + + feat_size = self.get_num_image_tokens( + image_width=width, + image_height=height, + processor=processor, + ) + if feat_size > largest_feature_size: + largest_feature_size = feat_size + largest_feature_pinpoint = ImageSize(width=width, + height=height) + + if largest_feature_size == 0 or largest_feature_pinpoint is None: + raise ValueError("Cannot have a largest feature size of 0!") + + return largest_feature_pinpoint + + +class MolmoDummyInputsBuilder(BaseDummyInputsBuilder[MolmoProcessingInfo]): + + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + target_width, target_height = \ + self.info.get_image_size_with_most_features() + num_images = mm_counts.get("image", 0) + + mm_data = { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + 
num_images=num_images) + } + + return ProcessorInputs( + prompt_text="", + mm_data=mm_data, + ) + + +class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]): + + def _apply_hf_processor_tokens_only( + self, + prompt_tokens: list[int], + ) -> list[int]: + processor = self.info.get_hf_processor() + + # Apply the chat template to the tokens + tokens = processor.processor.get_tokens_input( # type: ignore + self.info.get_tokenizer().decode(prompt_tokens), + message_format=processor.message_format, + always_start_with_space=processor.always_start_with_space, + ) + + processed_data = self.info.ctx.call_hf_processor( + processor, # type: ignore + dict(tokens=tokens), + ) + prompt_ids, = processed_data.pop("input_ids").tolist() + + return prompt_ids + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + num_crops = hf_inputs.get("num_crops", torch.empty(0)) + num_images = len(num_crops) + + return dict( + images=MultiModalFieldConfig.flat_from_sizes("image", num_crops), + image_masks=MultiModalFieldConfig.flat_from_sizes( + "image", num_crops), + feat_is_patch=MultiModalFieldConfig.flat_from_sizes( + "image", num_crops), + embed_is_patch=MultiModalFieldConfig.shared("image", num_images), + num_crops=MultiModalFieldConfig.batched("image"), + img_patch_id=MultiModalFieldConfig.shared("image", num_images), + ) + + def _get_prompt_replacements( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> list[PromptReplacement]: + processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + tokenizer = self.info.get_tokenizer() + + image_token_length_w = processor.image_token_length_w + image_token_length_h = processor.image_token_length_h + pooling_size = processor.pooling_size + + user_str = "User:" + if processor.always_start_with_space: + user_str = " " + user_str + + 
user_tokens = tokenizer.encode(user_str, add_special_tokens=False) + + img_patch_id = processor.image_patch_id + img_col_id = processor.im_col_id + img_start_id = processor.im_start_id + img_end_id = processor.im_end_id + + extra_row = [img_patch_id] * image_token_length_w + [img_col_id] + extra_joint = ([img_start_id] + extra_row * image_token_length_h + + [img_end_id]) + + def get_replacement_molmo(item_idx: int): + images = mm_items.get_items("image", ImageProcessorItems) + image_size = images.get_image_size(item_idx) + + ncols, nrows = processor.get_patches_grid_size( + image_width=image_size.width, + image_height=image_size.height, + ) + + joint_row = ([img_patch_id] * ((ncols + 1) // pooling_size) + + [img_col_id]) + joint = ([img_start_id] + joint_row * + ((nrows + 1) // pooling_size) + [img_end_id]) + + image_tokens = extra_joint + joint + + return PromptReplacementDetails( + full=image_tokens + user_tokens, + features=image_tokens, + ) + + return [ + PromptReplacement( + modality="image", + target=user_str, + replacement=get_replacement_molmo, + ) + ] + + +@MULTIMODAL_REGISTRY.register_processor(MolmoMultiModalProcessor, + info=MolmoProcessingInfo, + dummy_inputs=MolmoDummyInputsBuilder) class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): hf_to_vllm_mapper = WeightsMapper( @@ -1202,6 +1473,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, quant_config) self.model = MolmoModel(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) + self.img_patch_id = None if self.config.weight_tying: self.lm_head = self.model.transformer.wte @@ -1224,85 +1496,143 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, **kwargs: object, ) -> Optional[MolmoImageInputs]: images = kwargs.pop("images", None) - image_masks = kwargs.pop("image_masks", None) - image_start_end = kwargs.pop("image_start_end", None) if images is None: return None - image_input_idx = kwargs.pop("image_input_idx", None) - seq_len = 
kwargs.pop("seq_len", None) - if image_input_idx is None: - raise ValueError("image_input_idx is required for Molmo model.") - if seq_len is None: - raise ValueError("seq_len is required for Molmo model.") - if not isinstance(seq_len, torch.Tensor): - seq_len = torch.tensor(seq_len) + if not isinstance(images, (torch.Tensor, list)): + raise ValueError("Incorrect type of images. " + f"Got type: {type(images)}") + + image_masks = kwargs.pop("image_masks", None) + if not (image_masks is None or isinstance(image_masks, + (torch.Tensor, list))): + raise ValueError("Incorrect type of image_masks. " + f"Got type: {type(image_masks)}") + + feat_is_patch = kwargs.pop("feat_is_patch", None) + if not isinstance(feat_is_patch, (torch.Tensor, list)): + raise ValueError("Incorrect type of feat_is_patch. " + f"Got type: {type(feat_is_patch)}") + + embed_is_patch = kwargs.pop("embed_is_patch", None) + if not isinstance(embed_is_patch, (torch.Tensor, list)): + raise ValueError("Incorrect type of embed_is_patch. " + f"Got type: {type(embed_is_patch)}") + + num_crops = kwargs.pop("num_crops", None) + if not isinstance(num_crops, torch.Tensor): + raise ValueError("Incorrect type of num_crops. " + f"Got type: {type(num_crops)}") + + img_patch_id = kwargs.pop("img_patch_id", None) + if not isinstance(img_patch_id, torch.Tensor): + raise ValueError("Incorrect type of num_crops. 
" + f"Got type: {type(num_crops)}") + self.img_patch_id = img_patch_id.flatten().unique().item() return MolmoImageInputs( images=images, - image_input_idx=image_input_idx, - seq_len=seq_len, image_masks=image_masks, - image_start_end=image_start_end, + feat_is_patch=feat_is_patch, + embed_is_patch=embed_is_patch, + num_crops=num_crops, ) def _process_image_input( self, image_input: MolmoImageInputs, - ) -> torch.Tensor: - - image_features = self.vision_backbone( - images=image_input["images"], - image_masks=image_input["image_masks"], - ) + ) -> Union[torch.Tensor, List[torch.Tensor]]: + if isinstance(image_input["images"], list): + # Call the vision backbone on the whole batch at once + images_flat = flatten_bn(image_input["images"], concat=True) + image_masks_flat = (None if (image_masks := + image_input["image_masks"]) is None + else flatten_bn(image_masks, concat=True)) + + image_features_flat = self.vision_backbone( + images=images_flat.unsqueeze(0), + image_masks=(None if image_masks_flat is None else + image_masks_flat.unsqueeze(0)), + ).squeeze(0) + + # Reconstruct the batch dimension + image_features = image_features_flat.split( + image_input["num_crops"].sum(-1).tolist()) + else: + image_features = self.vision_backbone( + images=image_input["images"], + image_masks=image_input["image_masks"], + ) return image_features + def _get_mm_embeds( + self, + features: torch.Tensor, # Shape: (num_crop, num_patch, d) + feat_is_patch: torch.Tensor, # Shape: (num_crop, num_patch) + num_crops: torch.Tensor, # Shape: (num_images,) + embed_is_patch: torch.Tensor, # Shape: (num_embeds,) + ) -> list[torch.Tensor]: + """ + Scatter the patch features into a contiguous tensor that corresponds + to the embedding tokens defined by the multimodal processor. + + Note: + The original code only considers patch tokens as feature + tokens, but our processor considers all image-related tokens + as feature tokens because the feature tokens need to be + consecutive in `input_ids`. 
+ + Example: + A simplified example for one item in the batch: + + .. code-block:: + + Embedding tokens (from HF processor): + [ ] + + embed_is_patch (from HF processor): + [ False True True False True True False False ] + + Encoder outputs (from model): + [ p1 p2 0 p3 p4 0 ] + + feat_is_patch (from HF processor): + [ True True False True True False ] + + The resulting embedding tensor is: + [ nan p1 p2 nan p3 p4 nan nan ] + """ + num_crops_per_image = num_crops.tolist() + feats_per_image = features.split(num_crops_per_image) + f_is_patch_per_image = feat_is_patch.split(num_crops_per_image) + + _, _, embed_dim = features.shape + (num_embeds, ) = embed_is_patch.shape + + embeds_in_batch = list[torch.Tensor]() + for feats, f_is_patch in zip(feats_per_image, f_is_patch_per_image): + embeds = feats.new_full((num_embeds, embed_dim), torch.nan) + embeds[embed_is_patch] = feats[f_is_patch] + embeds_in_batch.append(embeds) + + return embeds_in_batch + def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return None + image_features = self._process_image_input(image_input) - image_input_idx = image_input["image_input_idx"] - seq_len = image_input["seq_len"] - batch_size, num_image, num_patch = image_features.shape[:3] - assert image_input_idx.shape == (batch_size, num_image, num_patch) - - # insert the image feature into the embedding. 
- image_features = image_features.view(batch_size, num_image * num_patch, - -1) - image_input_idx = image_input_idx.view(batch_size, - num_image * num_patch) - - valid = image_input_idx >= 0 - image_features = image_features * valid[:, :, None].to( - image_features.dtype) - image_features = image_features.view( - batch_size * num_image * num_patch, -1).contiguous() - - image_input_idx = image_input_idx * valid.to(image_input_idx.dtype) - offset = torch.cat([seq_len.new_zeros(1), - seq_len.cumsum(dim=0)[:-1]], - dim=0)[:, None] - image_input_idx = image_input_idx + offset.to(image_input_idx.dtype) - image_input_idx = image_input_idx.flatten()[:, None] - mat = image_input_idx == torch.arange( - seq_len.sum().item(), device=image_features.device)[None, :] - mat = mat.to(image_features.dtype) - - # Note: In this original implementation from AI2, the final - # vision_embeddings will be always be the same length - # of input embeddings. - vision_embeddings = torch.einsum('nd,nm->md', image_features, mat) - - # Split by the sizes of the input sequences. For each full embedding, - # extract the actual vision embeddings to be merged. 
- vision_embeddings = list(vision_embeddings.split(seq_len.tolist())) - for i in range(len(vision_embeddings)): - start, end = image_input['image_start_end'][i] - vision_embeddings[i] = vision_embeddings[i][start:end] - - return vision_embeddings + + return [ + self._get_mm_embeds(*args) for args in zip( + image_features, + image_input["feat_is_patch"], + image_input["num_crops"], + image_input["embed_is_patch"], + ) + ] def get_input_embeddings( self, @@ -1311,11 +1641,20 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, ) -> torch.Tensor: inputs_embeds = self.model.get_input_embeddings(input_ids) if multimodal_embeddings is not None: + assert self.img_patch_id is not None + + # Extract the patch tokens scattered in _get_mm_embeds + patch_embeddings = json_map_leaves( + lambda x: x[~x.isnan()].view(-1, *x.shape[1:]), + cast(JSONTree[torch.Tensor], multimodal_embeddings), + ) + inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, [ - DEFAULT_IMAGE_PATCH_TOKEN_ID, DEFAULT_IM_START_TOKEN_ID, - DEFAULT_IM_END_TOKEN_ID, DEFAULT_IM_COL_TOKEN_ID - ]) + input_ids, + inputs_embeds, + cast(NestedTensors, patch_embeddings), + self.img_patch_id, + ) return inputs_embeds def forward( diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 25ca8d1e7..e93fa24a6 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -353,17 +353,17 @@ class MultiModalFieldConfig: Example: - .. code-block:: + .. code-block:: - Input: - Data: [[AAAA] - [BBBB] - [CCCC]] + Input: + Data: [[AAAA] + [BBBB] + [CCCC]] - Output: - Element 1: [AAAA] - Element 2: [BBBB] - Element 3: [CCCC] + Output: + Element 1: [AAAA] + Element 2: [BBBB] + Element 3: [CCCC] """ return MultiModalFieldConfig( field=MultiModalBatchedField(), @@ -384,18 +384,18 @@ class MultiModalFieldConfig: Example: - .. code-block:: - - Given: - slices: [slice(0, 3), slice(3, 7), slice(7, 9)] + .. 
code-block:: + + Given: + slices: [slice(0, 3), slice(3, 7), slice(7, 9)] - Input: - Data: [AAABBBBCC] + Input: + Data: [AAABBBBCC] - Output: - Element 1: [AAA] - Element 2: [BBBB] - Element 3: [CC] + Output: + Element 1: [AAA] + Element 2: [BBBB] + Element 3: [CC] """ return MultiModalFieldConfig( field=MultiModalFlatField(slices=slices), @@ -416,18 +416,18 @@ class MultiModalFieldConfig: Example: - .. code-block:: - - Given: - size_per_item: [3, 4, 2] + .. code-block:: + + Given: + size_per_item: [3, 4, 2] - Input: - Data: [AAABBBBCC] + Input: + Data: [AAABBBBCC] - Output: - Element 1: [AAA] - Element 2: [BBBB] - Element 3: [CC] + Output: + Element 1: [AAA] + Element 2: [BBBB] + Element 3: [CC] See also: :func:`MultiModalFieldConfig.flat` @@ -456,19 +456,19 @@ class MultiModalFieldConfig: Example: - .. code-block:: - - Given: - batch_size: 4 + .. code-block:: + + Given: + batch_size: 4 - Input: - Data: [XYZ] + Input: + Data: [XYZ] - Output: - Element 1: [XYZ] - Element 2: [XYZ] - Element 3: [XYZ] - Element 4: [XYZ] + Output: + Element 1: [XYZ] + Element 2: [XYZ] + Element 3: [XYZ] + Element 4: [XYZ] """ return MultiModalFieldConfig( field=MultiModalSharedField(batch_size), diff --git a/vllm/utils.py b/vllm/utils.py index 6a41afff8..79981fa09 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -33,8 +33,7 @@ from dataclasses import dataclass, field from functools import cache, lru_cache, partial, wraps from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Dict, Generator, Generic, Iterator, List, Literal, - NamedTuple, Optional, Tuple, Type, TypeVar, Union, - overload) + NamedTuple, Optional, Tuple, Type, TypeVar, Union) from uuid import uuid4 import cloudpickle @@ -826,38 +825,6 @@ JSONTree = Union[Dict[str, "JSONTree[T]"], List["JSONTree[T]"], """A nested JSON structure where the leaves need not be JSON-serializable.""" -@overload -def json_map_leaves( - func: Callable[[T], U], - value: Dict[str, JSONTree[T]], -) -> Dict[str, JSONTree[U]]: - 
... - - -@overload -def json_map_leaves( - func: Callable[[T], U], - value: List[JSONTree[T]], -) -> List[JSONTree[U]]: - ... - - -@overload -def json_map_leaves( - func: Callable[[T], U], - value: Tuple[JSONTree[T], ...], -) -> Tuple[JSONTree[U], ...]: - ... - - -@overload -def json_map_leaves( - func: Callable[[T], U], - value: JSONTree[T], -) -> JSONTree[U]: - ... - - def json_map_leaves(func: Callable[[T], U], value: JSONTree[T]) -> JSONTree[U]: if isinstance(value, dict): return {k: json_map_leaves(func, v) for k, v in value.items()} -- GitLab From 2092a6fa7d58e398bc1a234edcef62ce555faafe Mon Sep 17 00:00:00 2001 From: Aoyu Date: Thu, 13 Feb 2025 20:35:18 +0800 Subject: [PATCH 119/253] [V1][Core] Add worker_base for v1 worker (#12816) Signed-off-by: Aoyu Signed-off-by: youkaichao Co-authored-by: Aoyu Co-authored-by: youkaichao --- vllm/utils.py | 43 +++++++++++++++++++++ vllm/v1/worker/gpu_worker.py | 28 +++++--------- vllm/v1/worker/worker_base.py | 63 +++++++++++++++++++++++++++++++ vllm/worker/worker_base.py | 71 +++++++++++++++++++---------------- 4 files changed, 153 insertions(+), 52 deletions(-) create mode 100644 vllm/v1/worker/worker_base.py diff --git a/vllm/utils.py b/vllm/utils.py index 79981fa09..1d7fbd4a7 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -2220,3 +2220,46 @@ def import_pynvml(): """ import vllm.third_party.pynvml as pynvml return pynvml + + +def warn_for_unimplemented_methods(cls: Type[T]) -> Type[T]: + """ + A replacement for `abc.ABC`. + When we use `abc.ABC`, subclasses will fail to instantiate + if they do not implement all abstract methods. + Here, we only require `raise NotImplementedError` in the + base class, and log a warning if the method is not implemented + in the subclass. 
+ """ + + original_init = cls.__init__ + + def find_unimplemented_methods(self: object): + unimplemented_methods = [] + for attr_name in dir(self): + # bypass inner method + if attr_name.startswith('_'): + continue + + try: + attr = getattr(self, attr_name) + # get the func of callable method + if callable(attr): + attr_func = attr.__func__ + except AttributeError: + continue + src = inspect.getsource(attr_func) + if "NotImplementedError" in src: + unimplemented_methods.append(attr_name) + if unimplemented_methods: + method_names = ','.join(unimplemented_methods) + msg = (f"Methods {method_names} not implemented in {self}") + logger.warning(msg) + + @wraps(original_init) + def wrapped_init(self, *args, **kwargs) -> None: + original_init(self, *args, **kwargs) + find_unimplemented_methods(self) + + type.__setattr__(cls, '__init__', wrapped_init) + return cls diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index beedca05c..8f2ffe5f1 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -21,6 +21,7 @@ from vllm.utils import GiB_bytes from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.worker.gpu_model_runner import GPUModelRunner +from vllm.v1.worker.worker_base import WorkerBase logger = init_logger(__name__) @@ -28,7 +29,7 @@ if TYPE_CHECKING: from vllm.v1.core.scheduler_output import SchedulerOutput -class Worker: +class Worker(WorkerBase): def __init__( self, @@ -39,23 +40,11 @@ class Worker: is_driver_worker: bool = False, ): - # TODO: use WorkerBase.__init__(self, vllm_config=vllm_config) - self.vllm_config = vllm_config - self.model_config = vllm_config.model_config - self.cache_config = vllm_config.cache_config - self.lora_config = vllm_config.lora_config - self.load_config = vllm_config.load_config - self.parallel_config = vllm_config.parallel_config - self.scheduler_config = vllm_config.scheduler_config - self.device_config = 
vllm_config.device_config - self.speculative_config = vllm_config.speculative_config - self.prompt_adapter_config = vllm_config.prompt_adapter_config - self.observability_config = vllm_config.observability_config - - self.parallel_config.rank = rank - self.local_rank = local_rank - self.rank = rank - self.distributed_init_method = distributed_init_method + super().__init__(vllm_config=vllm_config, + local_rank=local_rank, + rank=rank, + distributed_init_method=distributed_init_method, + is_driver_worker=is_driver_worker) if self.model_config.trust_remote_code: # note: lazy import to avoid importing torch before initializing @@ -126,7 +115,8 @@ class Worker: set_random_seed(self.model_config.seed) # Construct the model runner - self.model_runner = GPUModelRunner(self.vllm_config, self.device) + self.model_runner: GPUModelRunner = GPUModelRunner( + self.vllm_config, self.device) def load_model(self) -> None: if self.vllm_config.model_config.enable_sleep_mode: diff --git a/vllm/v1/worker/worker_base.py b/vllm/v1/worker/worker_base.py new file mode 100644 index 000000000..bc7e76c38 --- /dev/null +++ b/vllm/v1/worker/worker_base.py @@ -0,0 +1,63 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional + +import torch +import torch.nn as nn + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.v1.kv_cache_interface import KVCacheSpec +from vllm.worker.worker_base import WorkerBase as WorkerBaseV0 + +logger = init_logger(__name__) + + +class WorkerBase(WorkerBaseV0): + """ + Abstract class for v1 worker, mainly define some methods for v1. + For methods shared by v0 and v1, define them in v0 WorkerBase + """ + + def __init__( + self, + vllm_config: VllmConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + is_driver_worker: bool = False, + ): + """ + Initialize common worker components. 
+ + Args: + vllm_config: Complete vLLM configuration + local_rank: Local device index + rank: Global rank in distributed setup + distributed_init_method: Distributed initialization method + is_driver_worker: Whether this worker handles driver + responsibilities + """ + # Configuration storage + super().__init__(vllm_config=vllm_config) + + self.local_rank = local_rank + self.rank = rank + self.distributed_init_method = distributed_init_method + self.is_driver_worker = is_driver_worker + + # Device and model state + self.device: Optional[torch.device] = None + self.model_runner: Optional[nn.Module] = None + + def get_kv_cache_spec(self) -> KVCacheSpec: + """Get specifications for KV cache implementation.""" + raise NotImplementedError + + def compile_or_warm_up_model(self) -> None: + """Prepare model for execution through compilation/warmup.""" + raise NotImplementedError + + def check_health(self) -> None: + """Basic health check (override for device-specific checks).""" + return diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 819b81fbf..83fcf0865 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -3,7 +3,7 @@ import dataclasses import os import time -from abc import ABC, abstractmethod +from abc import abstractmethod from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union import cloudpickle @@ -19,7 +19,8 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.sequence import ExecuteModelRequest, IntermediateTensors from vllm.utils import (enable_trace_function_call_for_thread, resolve_obj_by_qualname, run_method, - update_environment_variables) + update_environment_variables, + warn_for_unimplemented_methods) from vllm.worker.model_runner_base import (BroadcastableModelInput, ModelRunnerBase, ModelRunnerInputBase) @@ -27,7 +28,8 @@ from vllm.worker.model_runner_base import (BroadcastableModelInput, logger = init_logger(__name__) -class WorkerBase(ABC): 
+@warn_for_unimplemented_methods +class WorkerBase: """Worker interface that allows vLLM to cleanly separate implementations for different hardware. Also abstracts control plane communication, e.g., to communicate request metadata to other workers. @@ -53,35 +55,31 @@ class WorkerBase(ABC): from vllm.platforms import current_platform self.current_platform = current_platform - @abstractmethod def init_device(self) -> None: """Initialize device state, such as loading the model or other on-device memory allocations. """ raise NotImplementedError - @abstractmethod - def determine_num_available_blocks(self) -> Tuple[int, int]: - """Determine the number of available blocks for the GPU KV cache and - swappable CPU KV cache. - - The implementation may run profiling or other heuristics to determine - the size of caches. - - Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks - are blocks that are "active" on the device and can be appended to. - num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be - appended to. - """ - raise NotImplementedError - - @abstractmethod def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None: """Initialize the KV cache with the given size in blocks. """ raise NotImplementedError + def get_model(self) -> nn.Module: + raise NotImplementedError + + def load_model(self) -> None: + """Load model onto target device.""" + raise NotImplementedError + + def execute_model( + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> Optional[List[SamplerOutput]]: + raise NotImplementedError + def start_worker_execution_loop(self) -> None: """Execute model loop in parallel worker. @@ -94,40 +92,43 @@ class WorkerBase(ABC): if output is None: return None - @abstractmethod - def get_model(self) -> nn.Module: - raise NotImplementedError + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Determine the number of available blocks for the GPU KV cache and + swappable CPU KV cache. 
- @abstractmethod - def execute_model( - self, - execute_model_req: Optional[ExecuteModelRequest] = None - ) -> Optional[List[SamplerOutput]]: + The implementation may run profiling or other heuristics to determine + the size of caches. + + Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks + are blocks that are "active" on the device and can be appended to. + num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be + appended to. + """ raise NotImplementedError - @abstractmethod def get_cache_block_size_bytes(self) -> int: """Return the size of a single cache block, in bytes. Used in speculative decoding. """ raise NotImplementedError - @abstractmethod def add_lora(self, lora_request: LoRARequest) -> bool: raise NotImplementedError - @abstractmethod def remove_lora(self, lora_id: int) -> bool: raise NotImplementedError - @abstractmethod def pin_lora(self, lora_id: int) -> bool: raise NotImplementedError - @abstractmethod def list_loras(self) -> Set[int]: raise NotImplementedError + @property + def vocab_size(self) -> int: + """Get vocabulary size from model configuration.""" + return self.model_config.get_vocab_size() + class DelegateWorkerBase(WorkerBase): """ @@ -156,6 +157,10 @@ class DelegateWorkerBase(WorkerBase): num_cpu_blocks: int) -> None: self.worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) + def load_model(self) -> None: + """Load model onto target device.""" + self.worker.load_model() + def get_model(self) -> nn.Module: return self.worker.get_model() -- GitLab From 02ed8a1fbe41e3ad1bc04fd29b754facd28e329f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=87=83?= Date: Thu, 13 Feb 2025 22:17:57 +0800 Subject: [PATCH 120/253] [Misc] Qwen2.5-VL Optimization (#13155) --- vllm/model_executor/models/qwen2_5_vl.py | 61 ++++++++++-------------- vllm/model_executor/models/qwen2_vl.py | 37 ++++++++------ 2 files changed, 47 insertions(+), 51 deletions(-) diff --git a/vllm/model_executor/models/qwen2_5_vl.py 
b/vllm/model_executor/models/qwen2_5_vl.py index d4c48dbda..6aec99b3f 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -45,6 +45,7 @@ from vllm.distributed import utils as dist_utils from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY +from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig @@ -271,8 +272,13 @@ class Qwen2_5_VisionAttention(nn.Module): q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)) if rotary_pos_emb is not None: - q = apply_rotary_pos_emb_vision(q, rotary_pos_emb) - k = apply_rotary_pos_emb_vision(k, rotary_pos_emb) + use_flash_attn = self.attn_backend == _Backend.FLASH_ATTN + q = apply_rotary_pos_emb_vision(q, + rotary_pos_emb, + use_flash_attn=use_flash_attn) + k = apply_rotary_pos_emb_vision(k, + rotary_pos_emb, + use_flash_attn=use_flash_attn) if self.attn_backend == _Backend.FLASH_ATTN: # from vllm_flash_attn.flash_attn_interface import ( @@ -296,20 +302,23 @@ class Qwen2_5_VisionAttention(nn.Module): "(b s) ... -> b s ...", b=batch_size) elif self.attn_backend == _Backend.TORCH_SDPA: - seq_length = q.size(1) - q, k, v = (rearrange(x, "b s h d -> b h s d") for x in [q, k, v]) - attention_mask = torch.zeros([1, seq_length, seq_length], - device=q.device, - dtype=torch.bool) + # Execute attention entry by entry for speed & less VRAM. 
+ outputs = [] for i in range(1, len(cu_seqlens)): - attention_mask[..., cu_seqlens[i - 1]:cu_seqlens[i], - cu_seqlens[i - 1]:cu_seqlens[i]] = True - output = F.scaled_dot_product_attention(q, - k, - v, - attention_mask, - dropout_p=0.0) - context_layer = rearrange(output, "b h s d -> b s h d ") + start_idx = cu_seqlens[i - 1] + end_idx = cu_seqlens[i] + q_i = q[:, start_idx:end_idx] + k_i = k[:, start_idx:end_idx] + v_i = v[:, start_idx:end_idx] + q_i, k_i, v_i = (rearrange(x, "b s h d -> b h s d") + for x in [q_i, k_i, v_i]) + output_i = F.scaled_dot_product_attention(q_i, + k_i, + v_i, + dropout_p=0.0) + output_i = rearrange(output_i, "b h s d -> b s h d ") + outputs.append(output_i) + context_layer = torch.cat(outputs, dim=1) elif self.attn_backend == _Backend.XFORMERS: from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalMask @@ -327,25 +336,6 @@ class Qwen2_5_VisionAttention(nn.Module): return output -class Qwen2RMSNorm(nn.Module): - - def __init__(self, hidden_size, eps=1e-6): - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + - self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - def extra_repr(self): - return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" - - class Qwen2_5_VisionBlock(nn.Module): def __init__( @@ -516,8 +506,7 @@ class Qwen2_5_VisionTransformer(nn.Module): hidden_size=self.hidden_size, ) - # NOTE: We use torch native RMSNorm here for precision purposes. 
- norm_layer = partial(Qwen2RMSNorm, eps=norm_eps) + norm_layer = partial(RMSNorm, eps=norm_eps) head_dim = self.hidden_size // self.num_heads self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index d3294a4d4..961f53cef 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -226,11 +226,15 @@ def apply_rotary_emb_torch(x: torch.Tensor, def apply_rotary_pos_emb_vision(t: torch.Tensor, - freqs: torch.Tensor) -> torch.Tensor: + freqs: torch.Tensor, + use_flash_attn=False) -> torch.Tensor: t_ = t.float() cos = freqs.cos() sin = freqs.sin() - output = apply_rotary_emb_torch(t_, cos, sin).type_as(t) + apply_rotary_emb = apply_rotary_emb_torch + if use_flash_attn: + from flash_attn.layers.rotary import apply_rotary_emb + output = apply_rotary_emb(t_, cos, sin).type_as(t) return output @@ -336,20 +340,23 @@ class Qwen2VisionAttention(nn.Module): "(b s) ... -> b s ...", b=batch_size) elif self.attn_backend == _Backend.TORCH_SDPA: - seq_length = q.size(1) - q, k, v = (rearrange(x, "b s h d -> b h s d") for x in [q, k, v]) - attention_mask = torch.zeros([1, seq_length, seq_length], - device=q.device, - dtype=torch.bool) + # Execute attention entry by entry for speed & less VRAM. 
+ outputs = [] for i in range(1, len(cu_seqlens)): - attention_mask[..., cu_seqlens[i - 1]:cu_seqlens[i], - cu_seqlens[i - 1]:cu_seqlens[i]] = True - output = F.scaled_dot_product_attention(q, - k, - v, - attention_mask, - dropout_p=0.0) - context_layer = rearrange(output, "b h s d -> b s h d ") + start_idx = cu_seqlens[i - 1] + end_idx = cu_seqlens[i] + q_i = q[:, start_idx:end_idx] + k_i = k[:, start_idx:end_idx] + v_i = v[:, start_idx:end_idx] + q_i, k_i, v_i = (rearrange(x, "b s h d -> b h s d") + for x in [q_i, k_i, v_i]) + output_i = F.scaled_dot_product_attention(q_i, + k_i, + v_i, + dropout_p=0.0) + output_i = rearrange(output_i, "b h s d -> b s h d ") + outputs.append(output_i) + context_layer = torch.cat(outputs, dim=1) elif self.attn_backend == _Backend.XFORMERS: from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalMask -- GitLab From 1bc3b5e71b26d82e655f09d4d1733b7706ea3ff6 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 13 Feb 2025 22:19:15 +0800 Subject: [PATCH 121/253] [VLM] Separate text-only and vision variants of the same model architecture (#13157) --- docs/source/models/supported_models.md | 17 +- examples/offline_inference/vision_language.py | 3 + .../vision_language_multi_image.py | 5 +- tests/distributed/test_pipeline_parallel.py | 171 ++-- .../vision_language/test_models.py | 11 +- .../vision_language/vlm_utils/core.py | 62 +- .../vision_language/vlm_utils/types.py | 10 +- tests/models/registry.py | 37 +- tests/models/test_initialization.py | 3 +- vllm/model_executor/models/chatglm.py | 420 ++------- .../models/glm4_vision_encoder.py | 312 ------- vllm/model_executor/models/glm4v.py | 662 ++++++++++++++ vllm/model_executor/models/qwen.py | 856 +----------------- vllm/model_executor/models/qwen_vl.py | 794 ++++++++++++++++ vllm/model_executor/models/registry.py | 9 +- 15 files changed, 1729 insertions(+), 1643 deletions(-) delete mode 100644 vllm/model_executor/models/glm4_vision_encoder.py create mode 
100644 vllm/model_executor/models/glm4v.py create mode 100644 vllm/model_executor/models/qwen_vl.py diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index 86b746178..e498efc22 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -699,10 +699,10 @@ See [this page](#generative-models) for more information on how to use generativ * * ✅︎ * ✅︎ -- * `DeepseekVLV2ForCausalLM` +- * `DeepseekVLV2ForCausalLM`^ * DeepSeek-VL2 * T + I+ - * `deepseek-ai/deepseek-vl2-tiny`, `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2` etc. (see note) + * `deepseek-ai/deepseek-vl2-tiny`, `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2` etc. * * ✅︎ * ✅︎ @@ -713,10 +713,10 @@ See [this page](#generative-models) for more information on how to use generativ * * ✅︎ * ✅︎ -- * `ChatGLMModel` +- * `GLM4VForCausalLM`^ * GLM-4V * T + I - * `THUDM/glm-4v-9b` etc. + * `THUDM/glm-4v-9b`, `THUDM/cogagent-9b-20241220` etc. * ✅︎ * ✅︎ * ✅︎ @@ -825,7 +825,7 @@ See [this page](#generative-models) for more information on how to use generativ * * ✅︎ * ✅︎ -- * `QWenLMHeadModel` +- * `QwenVLForConditionalGeneration`^ * Qwen-VL * T + IE+ * `Qwen/Qwen-VL`, `Qwen/Qwen-VL-Chat`, etc. @@ -862,13 +862,12 @@ See [this page](#generative-models) for more information on how to use generativ * ✅︎ ::: +^ You need to set the architecture name via `--hf-overrides` to match the one in vLLM. +    • For example, to use DeepSeek-VL2 series models: +      `--hf-overrides '{"architectures": ["DeepseekVLV2ForCausalLM"]}'` E Pre-computed embeddings can be inputted for this modality. + Multiple items can be inputted per text prompt for this modality. -:::{note} -To use DeepSeek-VL2 series models, you have to pass `--hf_overrides '{"architectures": ["DeepseekVLV2ForCausalLM"]}'` when running vLLM. -::: - :::{note} H2O-VL series models will be available in V1 once we support backends other than FlashAttention. 
::: diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index 9a4183106..b9963669a 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -105,7 +105,9 @@ def run_glm4v(question: str, modality: str): max_num_seqs=2, trust_remote_code=True, enforce_eager=True, + hf_overrides={"architectures": ["GLM4VForCausalLM"]}, disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache) + prompt = f"<|user|>\n<|begin_of_image|><|endoftext|><|end_of_image|>\ {question}<|assistant|>" @@ -495,6 +497,7 @@ def run_qwen_vl(question: str, modality: str): trust_remote_code=True, max_model_len=1024, max_num_seqs=2, + hf_overrides={"architectures": ["QwenVLForConditionalGeneration"]}, disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, ) diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py index 8d2172a60..1a5ea0c70 100644 --- a/examples/offline_inference/vision_language_multi_image.py +++ b/examples/offline_inference/vision_language_multi_image.py @@ -77,7 +77,7 @@ def load_deepseek_vl2(question: str, image_urls: List[str]): ) -def load_h2onvl(question: str, image_urls: List[str]) -> ModelRequestData: +def load_h2ovl(question: str, image_urls: List[str]) -> ModelRequestData: model_name = "h2oai/h2ovl-mississippi-2b" llm = LLM( @@ -302,6 +302,7 @@ def load_qwen_vl_chat(question: str, trust_remote_code=True, max_model_len=1024, max_num_seqs=2, + hf_overrides={"architectures": ["QwenVLForConditionalGeneration"]}, limit_mm_per_prompt={"image": len(image_urls)}, ) placeholders = "".join(f"Picture {i}: \n" @@ -452,7 +453,7 @@ def load_qwen2_5_vl(question, image_urls: List[str]) -> ModelRequestData: model_example_map = { "aria": load_aria, "deepseek_vl_v2": load_deepseek_vl2, - "h2ovl_chat": load_h2onvl, + "h2ovl_chat": load_h2ovl, "idefics3": load_idefics3, "internvl_chat": 
load_internvl, "mllama": load_mllama, diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 6a54fb74b..eb9cd5db9 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -6,6 +6,7 @@ WARNING: This test runs in both single-node (4 GPUs) and multi-node all workers in a node other than the head node, which can cause the test to fail. """ +import json import os from dataclasses import dataclass from typing import List, Literal, NamedTuple, Optional @@ -15,6 +16,7 @@ import pytest from vllm.config import TaskOption from vllm.logger import init_logger +from ..models.registry import HF_EXAMPLE_MODELS from ..utils import compare_two_settings, fork_new_process_for_each_test logger = init_logger("test_pipeline_parallel") @@ -31,10 +33,7 @@ class ParallelSetup(NamedTuple): class PPTestOptions(NamedTuple): multi_node_only: bool - trust_remote_code: bool - tokenizer_mode: Optional[str] load_format: Optional[str] = None - hf_overrides: Optional[str] = None @dataclass @@ -64,10 +63,7 @@ class PPTestSettings: pp_base: int = 2, multi_node_only: bool = False, task: TaskOption = "auto", - trust_remote_code: bool = False, - tokenizer_mode: Optional[str] = None, load_format: Optional[str] = None, - hf_overrides: Optional[str] = None, ): return PPTestSettings( parallel_setups=[ @@ -97,10 +93,7 @@ class PPTestSettings: vllm_major_versions=["0", "0", "1"], task=task, test_options=PPTestOptions(multi_node_only=multi_node_only, - trust_remote_code=trust_remote_code, - tokenizer_mode=tokenizer_mode, - load_format=load_format, - hf_overrides=hf_overrides), + load_format=load_format), ) @staticmethod @@ -110,10 +103,7 @@ class PPTestSettings: pp_base: int = 2, task: TaskOption = "auto", multi_node_only: bool = False, - trust_remote_code: bool = False, - tokenizer_mode: Optional[str] = None, load_format: Optional[str] = None, - hf_overrides: Optional[str] = None, ): return PPTestSettings( 
parallel_setups=[ @@ -126,19 +116,16 @@ class PPTestSettings: vllm_major_versions=["0"], task=task, test_options=PPTestOptions(multi_node_only=multi_node_only, - trust_remote_code=trust_remote_code, - tokenizer_mode=tokenizer_mode, - load_format=load_format, - hf_overrides=hf_overrides), + load_format=load_format), ) - def iter_params(self, model_name: str): + def iter_params(self, model_id: str): opts = self.test_options for parallel_setup in self.parallel_setups: for backend, vllm_major_version in zip(self.distributed_backends, self.vllm_major_versions): - yield (model_name, parallel_setup, backend, vllm_major_version, + yield (model_id, parallel_setup, backend, vllm_major_version, self.task, opts) @@ -150,16 +137,16 @@ TEXT_GENERATION_MODELS = { # [Decoder-only] # Uses Llama # "BAAI/AquilaChat-7B": PPTestSettings.fast(), - "Snowflake/snowflake-arctic-instruct": PPTestSettings.fast(tp_base=8, trust_remote_code=True), # noqa: E501 - "baichuan-inc/Baichuan-7B": PPTestSettings.fast(trust_remote_code=True), - "baichuan-inc/Baichuan2-13B-Chat": PPTestSettings.fast(trust_remote_code=True), # noqa: E501 + "Snowflake/snowflake-arctic-instruct": PPTestSettings.fast(load_format="dummy"), # noqa: E501 + "baichuan-inc/Baichuan-7B": PPTestSettings.fast(), + "baichuan-inc/Baichuan2-13B-Chat": PPTestSettings.fast(), "bigscience/bloomz-1b1": PPTestSettings.fast(), - "THUDM/chatglm3-6b": PPTestSettings.fast(trust_remote_code=True), - "CohereForAI/c4ai-command-r-v01": PPTestSettings.fast(tp_base=2, trust_remote_code=True), # noqa: E501 - "databricks/dbrx-instruct": PPTestSettings.fast(tp_base=8), - "Deci/DeciLM-7B-instruct": PPTestSettings.fast(trust_remote_code=True), + "THUDM/chatglm3-6b": PPTestSettings.fast(), + "CohereForAI/c4ai-command-r-v01": PPTestSettings.fast(load_format="dummy"), + "databricks/dbrx-instruct": PPTestSettings.fast(load_format="dummy"), + "Deci/DeciLM-7B-instruct": PPTestSettings.fast(), "deepseek-ai/deepseek-llm-7b-chat": PPTestSettings.fast(), - 
"deepseek-ai/DeepSeek-V2-Lite-Chat": PPTestSettings.fast(trust_remote_code=True), # noqa: E501 + "deepseek-ai/DeepSeek-V2-Lite-Chat": PPTestSettings.fast(), "LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct": PPTestSettings.fast(), "tiiuae/falcon-7b": PPTestSettings.fast(), "google/gemma-2b": PPTestSettings.fast(), @@ -172,36 +159,36 @@ TEXT_GENERATION_MODELS = { "ibm/PowerMoE-3b": PPTestSettings.fast(), # Uses Llama # "internlm/internlm-chat-7b": PPTestSettings.fast(), - "internlm/internlm2-chat-7b": PPTestSettings.fast(trust_remote_code=True), + "internlm/internlm2-chat-7b": PPTestSettings.fast(), "inceptionai/jais-13b-chat": PPTestSettings.fast(), "ai21labs/Jamba-tiny-dev": PPTestSettings.fast(), "meta-llama/Meta-Llama-3-8B": PPTestSettings.detailed(), - "openbmb/MiniCPM-2B-sft-bf16": PPTestSettings.fast(trust_remote_code=True), - "openbmb/MiniCPM3-4B": PPTestSettings.fast(trust_remote_code=True), + "openbmb/MiniCPM-2B-sft-bf16": PPTestSettings.fast(), + "openbmb/MiniCPM3-4B": PPTestSettings.fast(), # Uses Llama # "mistralai/Mistral-7B-Instruct-v0.1": PPTestSettings.fast(), "state-spaces/mamba-130m-hf": PPTestSettings.fast(), - "mistralai/Mixtral-8x7B-Instruct-v0.1": PPTestSettings.fast(tp_base=4), + "mistralai/Mixtral-8x7B-Instruct-v0.1": PPTestSettings.fast(load_format="dummy"), # noqa: E501 "mosaicml/mpt-7b": PPTestSettings.fast(), "nvidia/Minitron-8B-Base": PPTestSettings.fast(), "allenai/OLMo-1B-hf": PPTestSettings.fast(), "shanearora/OLMo-7B-1124-hf": PPTestSettings.fast(), "allenai/OLMoE-1B-7B-0924-Instruct": PPTestSettings.fast(), "facebook/opt-iml-max-1.3b": PPTestSettings.fast(), - "OrionStarAI/Orion-14B-Chat": PPTestSettings.fast(trust_remote_code=True), + "OrionStarAI/Orion-14B-Chat": PPTestSettings.fast(), "adept/persimmon-8b-chat": PPTestSettings.fast(), "microsoft/phi-2": PPTestSettings.fast(), - "microsoft/Phi-3-small-8k-instruct": PPTestSettings.fast(trust_remote_code=True), # noqa: E501 - "microsoft/Phi-3.5-MoE-instruct": 
PPTestSettings.detailed(trust_remote_code=True, multi_node_only=True, load_format="dummy", hf_overrides='{"num_hidden_layers": 4, "hidden_size": 512, "intermediate_size": 800, "num_attention_heads": 4, "num_key_value_heads": 1}'), # noqa: E501 - "Qwen/Qwen-7B-Chat": PPTestSettings.fast(trust_remote_code=True), + "microsoft/Phi-3-small-8k-instruct": PPTestSettings.fast(), + "microsoft/Phi-3.5-MoE-instruct": PPTestSettings.detailed(multi_node_only=True, load_format="dummy"), # noqa: E501 + "Qwen/Qwen-7B-Chat": PPTestSettings.fast(), "Qwen/Qwen2-7B-Instruct": PPTestSettings.fast(), "Qwen/Qwen1.5-MoE-A2.7B-Chat": PPTestSettings.fast(), "stabilityai/stablelm-3b-4e1t": PPTestSettings.fast(), "bigcode/starcoder2-3b": PPTestSettings.fast(), - "upstage/solar-pro-preview-instruct": PPTestSettings.fast(tp_base=2), + "upstage/solar-pro-preview-instruct": PPTestSettings.fast(load_format="dummy"), # noqa: E501 # FIXME: Cannot load tokenizer in latest transformers version. # Need to use tokenizer from `meta-llama/Llama-2-7b-chat-hf` - # "xverse/XVERSE-7B-Chat": PPTestSettings.fast(trust_remote_code=True), + # "xverse/XVERSE-7B-Chat": PPTestSettings.fast(), # [Encoder-only] # TODO: Implement PP # "facebook/bart-base": PPTestSettings.fast(), @@ -211,7 +198,7 @@ EMBEDDING_MODELS = { # type: ignore[var-annotated] # [Text-only] "intfloat/e5-mistral-7b-instruct": PPTestSettings.fast(), "BAAI/bge-multilingual-gemma2": PPTestSettings.fast(), - "Qwen/Qwen2.5-Math-RM-72B": PPTestSettings.fast(tp_base=4, trust_remote_code=True), # noqa: E501 + "Qwen/Qwen2.5-Math-RM-72B": PPTestSettings.fast(load_format="dummy"), } MULTIMODAL_MODELS = { @@ -219,20 +206,20 @@ MULTIMODAL_MODELS = { "Salesforce/blip2-opt-2.7b": PPTestSettings.fast(), "facebook/chameleon-7b": PPTestSettings.fast(), "adept/fuyu-8b": PPTestSettings.fast(), - "THUDM/glm-4v-9b": PPTestSettings.fast(trust_remote_code=True), - "OpenGVLab/InternVL2-1B": PPTestSettings.fast(trust_remote_code=True), + "THUDM/glm-4v-9b": 
PPTestSettings.fast(), + "OpenGVLab/InternVL2-1B": PPTestSettings.fast(), "llava-hf/llava-1.5-7b-hf": PPTestSettings.fast(), "llava-hf/llava-v1.6-mistral-7b-hf": PPTestSettings.fast(), "llava-hf/LLaVA-NeXT-Video-7B-hf": PPTestSettings.fast(), "llava-hf/llava-onevision-qwen2-0.5b-ov-hf": PPTestSettings.fast(), - "openbmb/MiniCPM-Llama3-V-2_5": PPTestSettings.fast(trust_remote_code=True), - "allenai/Molmo-7B-D-0924": PPTestSettings.fast(trust_remote_code=True), - "microsoft/Phi-3-vision-128k-instruct": PPTestSettings.fast(trust_remote_code=True), # noqa: E501 - "mistralai/Pixtral-12B-2409": PPTestSettings.fast(tp_base=2, tokenizer_mode="mistral"), # noqa: E501 - "Qwen/Qwen-VL-Chat": PPTestSettings.fast(trust_remote_code=True), + "openbmb/MiniCPM-Llama3-V-2_5": PPTestSettings.fast(), + "allenai/Molmo-7B-D-0924": PPTestSettings.fast(), + "microsoft/Phi-3-vision-128k-instruct": PPTestSettings.fast(), + "mistralai/Pixtral-12B-2409": PPTestSettings.fast(load_format="dummy"), + "Qwen/Qwen-VL-Chat": PPTestSettings.fast(), "Qwen/Qwen2-Audio-7B-Instruct": PPTestSettings.fast(), "Qwen/Qwen2-VL-2B-Instruct": PPTestSettings.fast(), - "fixie-ai/ultravox-v0_5-llama-3_2-1b": PPTestSettings.fast(trust_remote_code=True), # noqa: E501 + "fixie-ai/ultravox-v0_5-llama-3_2-1b": PPTestSettings.fast(), # [Encoder-decoder] # TODO: Implement PP # "meta-llama/Llama-3.2-11B-Vision-Instruct": PPTestSettings.fast(), @@ -258,7 +245,7 @@ TEST_MODELS = [ def _compare_tp( - model_name: str, + model_id: str, parallel_setup: ParallelSetup, distributed_backend: str, vllm_major_version: str, @@ -267,6 +254,7 @@ def _compare_tp( num_gpus_available: int, *, method: Literal["generate", "encode"], + is_multimodal: bool, ): ( tp_size, @@ -274,13 +262,32 @@ def _compare_tp( eager_mode, chunked_prefill, ) = parallel_setup - ( - multi_node_only, - trust_remote_code, - tokenizer_mode, - load_format, - hf_overrides, - ) = test_options + + multi_node_only, load_format = test_options + + model_info = 
HF_EXAMPLE_MODELS.find_hf_info(model_id) + model_info.check_transformers_version(on_fail="skip") + + trust_remote_code = model_info.trust_remote_code + tokenizer_mode = model_info.tokenizer_mode + hf_overrides = model_info.hf_overrides + + if load_format == "dummy": + # Avoid OOM + text_overrides = { + "num_layers": 1, + "num_hidden_layers": 1, + "num_experts": 2, + "num_experts_per_tok": 2, + "num_local_experts": 2, + } + + if is_multimodal: + hf_overrides.update({"text_config": text_overrides}) + else: + hf_overrides.update(text_overrides) + else: + model_info.check_available_online(on_fail="skip") if num_gpus_available < tp_size * pp_size: pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs") @@ -312,7 +319,7 @@ def _compare_tp( if load_format: common_args.extend(["--load-format", load_format]) if hf_overrides: - common_args.extend(["--hf-overrides", hf_overrides]) + common_args.extend(["--hf-overrides", json.dumps(hf_overrides)]) specific_case = tp_size == 2 and pp_size == 2 and chunked_prefill if distributed_backend == "ray" and (vllm_major_version == "1" @@ -355,11 +362,7 @@ def _compare_tp( ] try: - compare_two_settings(model_name, - pp_args, - tp_args, - pp_env, - method=method) + compare_two_settings(model_id, pp_args, tp_args, pp_env, method=method) except Exception: if pp_env is None: raise @@ -369,17 +372,16 @@ def _compare_tp( @pytest.mark.parametrize( - ("model_name", "parallel_setup", "distributed_backend", - "vllm_major_version", "task", "test_options"), + ("model_id", "parallel_setup", "distributed_backend", "vllm_major_version", + "task", "test_options"), [ - params for model_name, settings in TEXT_GENERATION_MODELS.items() - for params in settings.iter_params(model_name) - if model_name in TEST_MODELS + params for model_id, settings in TEXT_GENERATION_MODELS.items() + for params in settings.iter_params(model_id) if model_id in TEST_MODELS ], ) @fork_new_process_for_each_test def test_tp_language_generation( - model_name: str, + model_id: str, 
parallel_setup: ParallelSetup, distributed_backend: str, vllm_major_version: str, @@ -387,28 +389,28 @@ def test_tp_language_generation( test_options: PPTestOptions, num_gpus_available, ): - _compare_tp(model_name, + _compare_tp(model_id, parallel_setup, distributed_backend, vllm_major_version, task, test_options, num_gpus_available, - method="generate") + method="generate", + is_multimodal=False) @pytest.mark.parametrize( - ("model_name", "parallel_setup", "distributed_backend", - "vllm_major_version", "task", "test_options"), + ("model_id", "parallel_setup", "distributed_backend", "vllm_major_version", + "task", "test_options"), [ - params for model_name, settings in EMBEDDING_MODELS.items() - for params in settings.iter_params(model_name) - if model_name in TEST_MODELS + params for model_id, settings in EMBEDDING_MODELS.items() + for params in settings.iter_params(model_id) if model_id in TEST_MODELS ], ) @fork_new_process_for_each_test def test_tp_language_embedding( - model_name: str, + model_id: str, parallel_setup: ParallelSetup, distributed_backend: str, vllm_major_version: str, @@ -416,28 +418,28 @@ def test_tp_language_embedding( test_options: PPTestOptions, num_gpus_available, ): - _compare_tp(model_name, + _compare_tp(model_id, parallel_setup, distributed_backend, vllm_major_version, task, test_options, num_gpus_available, - method="encode") + method="encode", + is_multimodal=False) @pytest.mark.parametrize( - ("model_name", "parallel_setup", "distributed_backend", - "vllm_major_version", "task", "test_options"), + ("model_id", "parallel_setup", "distributed_backend", "vllm_major_version", + "task", "test_options"), [ - params for model_name, settings in MULTIMODAL_MODELS.items() - for params in settings.iter_params(model_name) - if model_name in TEST_MODELS + params for model_id, settings in MULTIMODAL_MODELS.items() + for params in settings.iter_params(model_id) if model_id in TEST_MODELS ], ) @fork_new_process_for_each_test def 
test_tp_multimodal_generation( - model_name: str, + model_id: str, parallel_setup: ParallelSetup, distributed_backend: str, vllm_major_version: str, @@ -445,11 +447,12 @@ def test_tp_multimodal_generation( test_options: PPTestOptions, num_gpus_available, ): - _compare_tp(model_name, + _compare_tp(model_id, parallel_setup, distributed_backend, vllm_major_version, task, test_options, num_gpus_available, - method="generate") + method="generate", + is_multimodal=True) diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py index 4ed61cfc9..2c66edb53 100644 --- a/tests/models/decoder_only/vision_language/test_models.py +++ b/tests/models/decoder_only/vision_language/test_models.py @@ -155,10 +155,7 @@ VLM_TEST_SETTINGS = { auto_cls=AutoModelForVision2Seq, vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output, image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)], - marks=[pytest.mark.skipif( - TRANSFORMERS_VERSION < "4.49.0", - reason="HF model requires transformers>=4.49.0", - ), pytest.mark.core_model, pytest.mark.cpu_model], + marks=[pytest.mark.core_model, pytest.mark.cpu_model], ), #### Extended model tests "aria": VLMTestInfo( @@ -215,7 +212,6 @@ VLM_TEST_SETTINGS = { "cherry_blossom": "\nPlease infer the season with reason in details.", # noqa: E501 }), multi_image_prompt="image_1:\nimage_2:\nWhich image can we see the car and the tower?", # noqa: E501 - vllm_runner_kwargs={"hf_overrides": {"architectures": ["DeepseekVLV2ForCausalLM"]}}, # noqa: E501 patch_hf_runner=model_utils.deepseekvl2_patch_hf_runner, postprocess_inputs=model_utils.cast_dtype_post_processor("images"), hf_output_post_proc=model_utils.deepseekvl2_trunc_hf_output, @@ -240,7 +236,7 @@ VLM_TEST_SETTINGS = { num_logprobs=10, image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)], ), - "glm4": VLMTestInfo( + "glm4v": VLMTestInfo( models=["THUDM/glm-4v-9b"], test_type=VLMTestType.IMAGE, 
prompt_formatter=identity, @@ -351,7 +347,6 @@ VLM_TEST_SETTINGS = { postprocess_inputs=model_utils.cast_dtype_post_processor( "pixel_values" ), - vllm_runner_kwargs={"hf_overrides": {"architectures": ["MantisForConditionalGeneration"]}}, # noqa: E501 get_stop_token_ids=lambda tok: [128009], auto_cls=AutoModelForVision2Seq, vllm_output_post_proc=model_utils.mantis_vllm_to_hf_output, @@ -437,7 +432,7 @@ VLM_TEST_SETTINGS = { auto_cls=AutoModelForVision2Seq, marks=[large_gpu_mark(min_gb=48)], ), - "qwen": VLMTestInfo( + "qwen_vl": VLMTestInfo( models=["Qwen/Qwen-VL"], test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), prompt_formatter=identity, diff --git a/tests/models/decoder_only/vision_language/vlm_utils/core.py b/tests/models/decoder_only/vision_language/vlm_utils/core.py index 0aed26769..f2260f567 100644 --- a/tests/models/decoder_only/vision_language/vlm_utils/core.py +++ b/tests/models/decoder_only/vision_language/vlm_utils/core.py @@ -4,12 +4,14 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union import torch from PIL.Image import Image -from transformers import AutoTokenizer, BatchEncoding, PreTrainedTokenizerBase +from transformers import BatchEncoding from transformers.models.auto.auto_factory import _BaseAutoModelClass from vllm.config import TaskOption +from vllm.transformers_utils.tokenizer import AnyTokenizer from .....conftest import HfRunner, VllmRunner +from ....registry import HF_EXAMPLE_MODELS from .types import RunnerOutput @@ -31,10 +33,8 @@ def run_test( use_tokenizer_eos: bool, postprocess_inputs: Callable[[BatchEncoding], BatchEncoding], comparator: Callable[..., None], - get_stop_token_ids: Optional[Callable[[PreTrainedTokenizerBase], - List[int]]], + get_stop_token_ids: Optional[Callable[[AnyTokenizer], list[int]]], stop_str: Optional[List[str]], - tokenizer_mode: str, limit_mm_per_prompt: Dict[str, int], vllm_runner_kwargs: Optional[Dict[str, Any]], hf_model_kwargs: Optional[Dict[str, Any]], @@ -48,7 +48,10 
@@ def run_test( """Modality agnostic test test executor for comparing HF/vLLM outputs.""" # In the case of embeddings, vLLM takes separate input tensors vllm_inputs = vllm_embeddings if vllm_embeddings is not None else inputs - tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True) + + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + model_info.check_available_online(on_fail="skip") + model_info.check_transformers_version(on_fail="skip") vllm_outputs_per_mm = [] hf_outputs_per_mm = [] @@ -57,17 +60,19 @@ def run_test( # vLLM needs a fresh new process without cuda initialization. # if we run HF first, the cuda initialization will be done and it # will hurt multiprocessing backend with fork method (the default method). - vllm_kwargs: Dict[str, Any] = {} - if get_stop_token_ids is not None: - vllm_kwargs["stop_token_ids"] = get_stop_token_ids(tokenizer) - if stop_str: - vllm_kwargs["stop"] = stop_str - if vllm_runner_kwargs is None: - vllm_runner_kwargs = {} + vllm_runner_kwargs_: Dict[str, Any] = {} + if model_info.tokenizer: + vllm_runner_kwargs_["tokenizer"] = model_info.tokenizer + if model_info.tokenizer_mode: + vllm_runner_kwargs_["tokenizer_mode"] = model_info.tokenizer_mode + if model_info.hf_overrides: + vllm_runner_kwargs_["hf_overrides"] = model_info.hf_overrides + + if vllm_runner_kwargs: + vllm_runner_kwargs_.update(vllm_runner_kwargs) with vllm_runner(model, - tokenizer_mode=tokenizer_mode, max_model_len=max_model_len, max_num_seqs=max_num_seqs, dtype=dtype, @@ -76,7 +81,15 @@ def run_test( distributed_executor_backend=distributed_executor_backend, enforce_eager=enforce_eager, task=task, - **vllm_runner_kwargs) as vllm_model: + **vllm_runner_kwargs_) as vllm_model: + tokenizer = vllm_model.model.get_tokenizer() + + vllm_kwargs: Dict[str, Any] = {} + if get_stop_token_ids is not None: + vllm_kwargs["stop_token_ids"] = get_stop_token_ids(tokenizer) + if stop_str: + vllm_kwargs["stop"] = stop_str + for prompts, media in vllm_inputs: 
vllm_kwargs[runner_mm_key] = media vllm_output = vllm_model.generate_greedy_logprobs( @@ -93,16 +106,19 @@ def run_test( if patch_hf_runner is not None: hf_model = patch_hf_runner(hf_model) - # Some models need to explicitly pass the eos_token_id off the tokenizer or - # processor for a good comparison; currently assume processor/tokenizer - # agree on the EOS, and pull it off the tokenizer if requested. - hf_kwargs = {} - if use_tokenizer_eos: - hf_kwargs["eos_token_id"] = tokenizer.eos_token_id - if stop_str: - hf_kwargs["stop_strings"] = stop_str - with hf_model, torch.no_grad(): + tokenizer = hf_model.tokenizer + + # Some models need to explicitly pass the eos_token_id off the tokenizer + # or processor for a good comparison; + # currently assume processor/tokenizer agree on the EOS, and pull it off + # the tokenizer if requested. + hf_kwargs = {} + if use_tokenizer_eos: + hf_kwargs["eos_token_id"] = tokenizer.eos_token_id + if stop_str: + hf_kwargs["stop_strings"] = stop_str + for prompts, media in inputs: hf_kwargs[runner_mm_key] = media hf_output = hf_model.generate_greedy_logprobs_limit( diff --git a/tests/models/decoder_only/vision_language/vlm_utils/types.py b/tests/models/decoder_only/vision_language/vlm_utils/types.py index ae3b9d59b..ecb86609c 100644 --- a/tests/models/decoder_only/vision_language/vlm_utils/types.py +++ b/tests/models/decoder_only/vision_language/vlm_utils/types.py @@ -8,12 +8,12 @@ from typing import (Any, Callable, Dict, Iterable, List, NamedTuple, Optional, import torch from PIL.Image import Image from pytest import MarkDecorator -from transformers import (AutoModelForCausalLM, BatchEncoding, - PreTrainedTokenizerBase) +from transformers import AutoModelForCausalLM, BatchEncoding from transformers.models.auto.auto_factory import _BaseAutoModelClass from vllm.config import TaskOption from vllm.sequence import SampleLogprobs +from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import identity from .....conftest 
import IMAGE_ASSETS, HfRunner, ImageAsset, _ImageAssets @@ -100,8 +100,7 @@ class VLMTestInfo(NamedTuple): vllm_runner_kwargs: Optional[Dict[str, Any]] = None # Optional callable which gets a list of token IDs from the model tokenizer - get_stop_token_ids: Optional[Callable[[PreTrainedTokenizerBase], - List[int]]] = None + get_stop_token_ids: Optional[Callable[[AnyTokenizer], list[int]]] = None # Optional list of strings to stop generation, useful when stop tokens are # not special tokens in the tokenizer stop_str: Optional[List[str]] = None @@ -156,8 +155,6 @@ class VLMTestInfo(NamedTuple): marks: Optional[List[MarkDecorator]] = None - tokenizer_mode: str = "auto" - def get_non_parametrized_runner_kwargs(self): """Returns a dictionary of expandable kwargs for items that are used in all test types, which are NOT used when creating the parametrized @@ -180,7 +177,6 @@ class VLMTestInfo(NamedTuple): "hf_model_kwargs": self.hf_model_kwargs, "stop_str": self.stop_str, "patch_hf_runner": self.patch_hf_runner, - "tokenizer_mode": self.tokenizer_mode } diff --git a/tests/models/registry.py b/tests/models/registry.py index 66a487ca6..9c0e6b337 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -104,7 +104,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { trust_remote_code=True), "BambaForCausalLM": _HfExamplesInfo("ibm-ai-platform/Bamba-9B"), "BloomForCausalLM": _HfExamplesInfo("bigscience/bloomz-1b1"), - # ChatGLMModel supports multimodal + "ChatGLMModel": _HfExamplesInfo("THUDM/chatglm3-6b", + trust_remote_code=True), "CohereForCausalLM": _HfExamplesInfo("CohereForAI/c4ai-command-r-v01", trust_remote_code=True), "Cohere2ForCausalLM": _HfExamplesInfo("CohereForAI/c4ai-command-r7b-12-2024", # noqa: E501 @@ -138,7 +139,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { "InternLM3ForCausalLM": _HfExamplesInfo("internlm/internlm3-8b-instruct", trust_remote_code=True), "JAISLMHeadModel": _HfExamplesInfo("inceptionai/jais-13b-chat"), - "JambaForCausalLM": 
_HfExamplesInfo("ai21labs/AI21-Jamba-1.5-Mini"), + "JambaForCausalLM": _HfExamplesInfo("ai21labs/AI21-Jamba-1.5-Mini", + extras={"tiny": "ai21labs/Jamba-tiny-dev"}), # noqa: E501 "LlamaForCausalLM": _HfExamplesInfo("meta-llama/Meta-Llama-3-8B"), "LLaMAForCausalLM": _HfExamplesInfo("decapoda-research/llama-7b-hf", is_available_online=False), @@ -167,7 +169,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { trust_remote_code=True), "PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct", trust_remote_code=True), - # QWenLMHeadModel supports multimodal + "QWenLMHeadModel": _HfExamplesInfo("Qwen/Qwen-7B-Chat", + trust_remote_code=True), "Qwen2ForCausalLM": _HfExamplesInfo("Qwen/Qwen2-7B-Instruct"), "Qwen2MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen1.5-MoE-A2.7B-Chat"), "RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b", @@ -232,18 +235,19 @@ _MULTIMODAL_EXAMPLE_MODELS = { "AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria"), "Blip2ForConditionalGeneration": _HfExamplesInfo("Salesforce/blip2-opt-2.7b"), # noqa: E501 "ChameleonForConditionalGeneration": _HfExamplesInfo("facebook/chameleon-7b"), # noqa: E501 - "ChatGLMModel": _HfExamplesInfo("THUDM/glm-4v-9b", - extras={"text_only": "THUDM/chatglm3-6b"}, - trust_remote_code=True), - "ChatGLMForConditionalGeneration": _HfExamplesInfo("chatglm2-6b", - is_available_online=False), "DeepseekVLV2ForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-vl2-tiny", # noqa: E501 hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}), # noqa: E501 "FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"), - "H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m"), + "GLM4VForCausalLM": _HfExamplesInfo("THUDM/glm-4v-9b", + trust_remote_code=True, + hf_overrides={"architectures": ["GLM4VForCausalLM"]}), # noqa: E501 + "H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m", + extras={"2b": "h2oai/h2ovl-mississippi-2b"}), # noqa: E501 "InternVLChatModel": _HfExamplesInfo("OpenGVLab/InternVL2-1B", + 
extras={"2B": "OpenGVLab/InternVL2-2B"}, # noqa: E501 trust_remote_code=True), - "Idefics3ForConditionalGeneration": _HfExamplesInfo("HuggingFaceM4/Idefics3-8B-Llama3"), # noqa: E501 + "Idefics3ForConditionalGeneration": _HfExamplesInfo("HuggingFaceM4/Idefics3-8B-Llama3", # noqa: E501 + {"tiny": "HuggingFaceTB/SmolVLM-256M-Instruct"}), # noqa: E501 "LlavaForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-1.5-7b-hf", extras={"mistral": "mistral-community/pixtral-12b"}), # noqa: E501 "LlavaNextForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-v1.6-mistral-7b-hf"), # noqa: E501 @@ -253,21 +257,24 @@ _MULTIMODAL_EXAMPLE_MODELS = { hf_overrides={"architectures": ["MantisForConditionalGeneration"]}), # noqa: E501 "MiniCPMO": _HfExamplesInfo("openbmb/MiniCPM-o-2_6", trust_remote_code=True), - "MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-V-2_6", + "MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-Llama3-V-2_5", + extras={"2.6": "openbmb/MiniCPM-V-2_6"}, # noqa: E501 trust_remote_code=True), "MolmoForCausalLM": _HfExamplesInfo("allenai/Molmo-7B-D-0924", extras={"olmo": "allenai/Molmo-7B-O-0924"}, # noqa: E501 trust_remote_code=True), "NVLM_D": _HfExamplesInfo("nvidia/NVLM-D-72B", trust_remote_code=True), - "PaliGemmaForConditionalGeneration": _HfExamplesInfo("google/paligemma-3b-pt-224"), # noqa: E501 + "PaliGemmaForConditionalGeneration": _HfExamplesInfo("google/paligemma-3b-mix-224", # noqa: E501 + extras={"v2": "google/paligemma2-3b-ft-docci-448"}), # noqa: E501 "Phi3VForCausalLM": _HfExamplesInfo("microsoft/Phi-3-vision-128k-instruct", trust_remote_code=True), "PixtralForConditionalGeneration": _HfExamplesInfo("mistralai/Pixtral-12B-2409", # noqa: E501 tokenizer_mode="mistral"), - "QWenLMHeadModel": _HfExamplesInfo("Qwen/Qwen-VL-Chat", - extras={"text_only": "Qwen/Qwen-7B-Chat"}, # noqa: E501 - trust_remote_code=True), + "QwenVLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen-VL", + extras={"chat": "Qwen/Qwen-VL-Chat"}, # noqa: E501 + 
trust_remote_code=True, + hf_overrides={"architectures": ["QwenVLForConditionalGeneration"]}), # noqa: E501 "Qwen2AudioForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-Audio-7B-Instruct"), # noqa: E501 "Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"), # noqa: E501 "Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct", # noqa: E501 diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py index 64928a65d..c58c63723 100644 --- a/tests/models/test_initialization.py +++ b/tests/models/test_initialization.py @@ -18,8 +18,7 @@ def test_can_initialize(model_arch): # Avoid OOM def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig: - if hf_config.model_type == "deepseek_vl_v2": - hf_config.update({"architectures": ["DeepseekVLV2ForCausalLM"]}) + hf_config.update(model_info.hf_overrides) if hasattr(hf_config, "text_config"): text_config: PretrainedConfig = hf_config.text_config diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 153c85cfb..26b4a95c5 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -1,20 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 - # Adapted from -# https://github.com/THUDM/CogAgent -"""Inference-only CogAgent model compatible with THUDM weights.""" -from argparse import Namespace -from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict, - Union) +# https://github.com/THUDM/ChatGLM2-6B +"""Inference-only ChatGLM model compatible with THUDM weights.""" +from typing import Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn from torch.nn import LayerNorm -from torchvision import transforms -from torchvision.transforms import InterpolationMode -from transformers import PreTrainedTokenizer, TensorType -from transformers.image_utils import ImageInput -from transformers.tokenization_utils_base import TextInput from 
vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, VllmConfig @@ -31,204 +23,14 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.glm4_vision_encoder import EVA2CLIPModel -from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors -from vllm.multimodal.parse import MultiModalDataItems -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, BatchFeature, - MultiModalFieldConfig, - PromptReplacement) -from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import ChatGLMConfig -from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP +from .interfaces import SupportsLoRA, SupportsPP from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix, merge_multimodal_embeddings) - - -class GLMImagePixelInputs(TypedDict): - pixel_values: torch.Tensor - """Shape: `(batch_size, num_channels, height, width)`""" - - -class GLM4VProcessor: - """ - This model doesn't define its own HF processor, - so we implement our own one here. 
- - """ - - def __init__( - self, - config: ChatGLMConfig, - tokenizer: PreTrainedTokenizer, - ) -> None: - super().__init__() - - self.config = config - self.tokenizer = tokenizer - - if vision_config := getattr(config, "vision_config", None): - image_size = vision_config["image_size"] - - self.image_transform = transforms.Compose([ - transforms.Resize( - (image_size, image_size), - interpolation=InterpolationMode.BICUBIC, - ), - transforms.ToTensor(), - transforms.Normalize( - mean=(0.48145466, 0.4578275, 0.40821073), - std=(0.26862954, 0.26130258, 0.27577711), - ), - ]) - else: - self.image_transform = None - - def __call__( - self, - text: Optional[Union[TextInput, list[TextInput]]] = None, - images: Optional[Union[ImageInput, list[ImageInput]]] = None, - return_tensors: Optional[Union[str, TensorType]] = None, - ) -> BatchFeature: - if text is None: - text = [] - if not isinstance(text, list): - text = [text] - if images is None: - images = [] - if not isinstance(images, list): - images = [images] - text_inputs = self.tokenizer(text) - if len(images) == 0: - image_inputs = {} - else: - if self.image_transform is None: - raise ValueError("This model does not support image inputs") - - pixel_values = [self.image_transform(image) for image in images] - image_inputs = {"pixel_values": torch.stack(pixel_values)} - - return BatchFeature( - { - **text_inputs, - **image_inputs, - }, - tensor_type=return_tensors, - ) - - -class GLM4VProcessingInfo(BaseProcessingInfo): - - def get_tokenizer(self): - tokenizer = self.ctx.tokenizer - assert isinstance(tokenizer, PreTrainedTokenizer) - return tokenizer - - def get_hf_config(self): - return self.ctx.get_hf_config(ChatGLMConfig) - - def get_hf_processor(self) -> GLM4VProcessor: - return GLM4VProcessor( - self.get_hf_config(), - self.get_tokenizer(), - ) - - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: - return {"image": 1} - - def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: 
Mapping[str, int], - ) -> Mapping[str, int]: - return {"image": self.get_num_image_feature_tokens()} - - def get_num_image_tokens(self) -> int: - hf_config = self.get_hf_config() - if not (vision_config := getattr(hf_config, "vision_config", None)): - return 0 - - image_size = vision_config["image_size"] - patch_size = vision_config["patch_size"] - grid_length = image_size // patch_size // 2 - return grid_length * grid_length - - def get_num_image_feature_tokens(self) -> int: - # EVA2CLIPModel has embeddings for boi and eoi tokens as well - return self.get_num_image_tokens() + 2 - - -class GLM4VDummyInputsBuilder(BaseDummyInputsBuilder[GLM4VProcessingInfo]): - - def get_dummy_processor_inputs( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> ProcessorInputs: - hf_config = self.info.get_hf_config() - if not (vision_config := getattr(hf_config, "vision_config", None)): - return ProcessorInputs(prompt_text="", mm_data={}) - - target_width = target_height = vision_config["image_size"] - num_images = mm_counts.get("image", 0) - - mm_data = { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images) - } - - base_text = "<|begin_of_image|><|endoftext|><|end_of_image|>" - - return ProcessorInputs( - prompt_text=base_text * num_images, - mm_data=mm_data, - ) - - -class GLM4VMultiModalProcessor(BaseMultiModalProcessor[GLM4VProcessingInfo]): - - def _get_mm_fields_config( - self, - hf_inputs: BatchFeature, - hf_processor_mm_kwargs: Mapping[str, object], - ) -> Mapping[str, MultiModalFieldConfig]: - return dict(pixel_values=MultiModalFieldConfig.batched("image")) - - def _get_prompt_replacements( - self, - mm_items: MultiModalDataItems, - hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, - ) -> list[PromptReplacement]: - hf_config = self.info.get_hf_config() - if not hasattr(hf_config, "vision_config"): - return [] - - boi_token_id = hf_config.boi_token_id - image_token_id = 
hf_config.pad_token_id - eoi_token_id = hf_config.eoi_token_id - - def get_replacement(item_idx: int): - num_image_tokens = self.info.get_num_image_tokens() - image_tokens = [image_token_id] * num_image_tokens - - return [boi_token_id] + image_tokens + [eoi_token_id] - - return [ - PromptReplacement( - modality="image", - target=[boi_token_id, image_token_id, eoi_token_id], - replacement=get_replacement, - ), - ] + maybe_prefix) class GLMAttention(nn.Module): @@ -489,7 +291,7 @@ class GLMTransformer(nn.Module): position_ids: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: for i in range(self.start_layer, self.end_layer): layer = self.layers[i] hidden_states = layer( @@ -498,8 +300,12 @@ class GLMTransformer(nn.Module): kv_cache=kv_caches[i - self.start_layer], attn_metadata=attn_metadata, ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) + # Final layer norm. 
- if get_pp_group().is_last_rank and self.post_layer_norm: + if self.post_layer_norm: hidden_states = self.final_layernorm(hidden_states) return hidden_states @@ -534,61 +340,11 @@ class ChatGLMModel(nn.Module): quant_config=quant_config, prefix=f"{prefix}.output_layer") - vision_config_flag = getattr(config, 'vision_config', None) - if vision_config_flag is not None: - self.vision_config = Namespace(**config.vision_config) - self.vision = EVA2CLIPModel(self.config, - quant_config, - prefix=f"{prefix}.vision") - else: - self.vision = None - self.make_empty_intermediate_tensors = ( self.encoder.make_empty_intermediate_tensors) - def _parse_and_validate_image_input( - self, **kwargs: object) -> GLMImagePixelInputs: - - pixel_values = kwargs.pop("pixel_values", None) - if pixel_values is not None and self.vision is not None: - if isinstance(pixel_values, torch.Tensor): - if pixel_values.ndim > 2: - pixel_values = torch.concat(list(pixel_values)) - elif isinstance(pixel_values, list): - return torch.concat(pixel_values) - else: - raise TypeError("""pixel_values must be a torch.Tensor - or a list of torch.Tensor - """) - return GLMImagePixelInputs(pixel_values=pixel_values) - - def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: - image_input = self._parse_and_validate_image_input(**kwargs) - if image_input["pixel_values"] is None: - return None - pixel_values = image_input["pixel_values"].to( - dtype=self.config.torch_dtype) - vision_embeddings = self.vision(pixel_values) - return vision_embeddings - - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[NestedTensors] = None, - ) -> torch.Tensor: - inputs_embeds = self.embedding(input_ids) - if multimodal_embeddings is not None: - inputs_embeds = merge_multimodal_embeddings( - input_ids=input_ids, - inputs_embeds=inputs_embeds, - multimodal_embeddings=multimodal_embeddings, - placeholder_token_id=[ - self.config.boi_token_id, - self.config.pad_token_id, - 
self.config.eoi_token_id, - ], - ) - return inputs_embeds + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embedding(input_ids) def forward( self, @@ -599,26 +355,24 @@ class ChatGLMModel(nn.Module): intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] - # NOTE: In v1, inputs_embeds is always generated at model runner, this - # condition is for v0 compatibility. - if intermediate_tensors is not None: - inputs_embeds = intermediate_tensors["hidden_states"] - elif inputs_embeds is None: - vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) # Run encoder. 
hidden_states = self.encoder( - hidden_states=inputs_embeds, + hidden_states=hidden_states, position_ids=positions, kv_caches=kv_caches, attn_metadata=attn_metadata, ) - if not get_pp_group().is_last_rank: - return IntermediateTensors({"hidden_states": hidden_states}) return hidden_states def load_weights(self, weights: Iterable[Tuple[str, @@ -660,12 +414,18 @@ class ChatGLMModel(nn.Module): return loaded_params -class ChatGLMBaseModel(nn.Module, SupportsLoRA, SupportsPP): +class ChatGLMBaseModel(nn.Module): hf_to_vllm_mapper = WeightsMapper( orig_to_new_substr={".word_embeddings": ""}, ) - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + transformer_type: type[ChatGLMModel] = ChatGLMModel, + ) -> None: super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config @@ -678,27 +438,17 @@ class ChatGLMBaseModel(nn.Module, SupportsLoRA, SupportsPP): self.quant_config = quant_config self.max_position_embeddings = getattr(config, "max_sequence_length", 8192) - self.transformer = ChatGLMModel(vllm_config=vllm_config, - prefix=maybe_prefix( - prefix, "transformer")) + self.transformer = transformer_type(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "transformer")) if self.config.tie_word_embeddings: self.transformer.output_layer.weight = ( self.transformer.embedding.weight) self.lm_head = self.transformer.output_layer self.logits_processor = LogitsProcessor(config.padded_vocab_size) self.sampler = get_sampler() - - def forward(self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - intermediate_tensors: Optional[IntermediateTensors] = None, - **kwargs) -> torch.Tensor: - hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, - **kwargs) - return hidden_states + self.make_empty_intermediate_tensors = ( + 
self.transformer.make_empty_intermediate_tensors) def compute_logits( self, @@ -722,7 +472,7 @@ class ChatGLMBaseModel(nn.Module, SupportsLoRA, SupportsPP): return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) -class ChatGLM(ChatGLMBaseModel): +class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP): packed_modules_mapping = { "query_key_value": ["query_key_value"], "dense_h_to_4h": ["dense_h_to_4h"] @@ -738,82 +488,28 @@ class ChatGLM(ChatGLMBaseModel): embedding_modules = {} embedding_padding_modules = [] + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + config = vllm_config.model_config.hf_config + if hasattr(config, "vision_config"): + hf_overrides = {"architectures": ["GLM4VForCausalLM"]} + raise RuntimeError( + "The configuration of this model indicates that it supports " + "vision inputs, but you instantiated the text-only version " + "of this model. Please use the vision model by setting " + f"`--hf-overrides {hf_overrides!r}`") -class ChatGLMV(ChatGLMBaseModel, SupportsMultiModal): - - packed_modules_mapping = { - "query_key_value": ["query_key_value"], - "dense_h_to_4h": ["dense_h_to_4h"], - "merged_proj": ["gate_proj", "dense_h_to_4h"] - } - # LoRA specific attributes - supported_lora_modules = [ - "query_key_value", - "dense", - "dense_h_to_4h", - "dense_4h_to_h", - # vision - "fc1", - "fc2", - "merged_proj", - "linear_proj" - ] - - embedding_modules = {} - embedding_padding_modules = [] - - def get_mm_mapping(self) -> MultiModelKeys: - """ - Get the module prefix in multimodal models - """ - return MultiModelKeys.from_string_field( - language_model="transformer.encoder", - connector="transformer.vision.linear_proj", - tower_model="transformer.vision.transformer") - - def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: - return self.transformer.get_multimodal_embeddings(**kwargs) + super().__init__(vllm_config=vllm_config, prefix=prefix) - def get_input_embeddings( + def forward( self, 
input_ids: torch.Tensor, - multimodal_embeddings: Optional[NestedTensors] = None, - ) -> torch.Tensor: - return self.transformer.get_input_embeddings(input_ids, - multimodal_embeddings) - - -@MULTIMODAL_REGISTRY.register_processor(GLM4VMultiModalProcessor, - info=GLM4VProcessingInfo, - dummy_inputs=GLM4VDummyInputsBuilder) -class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP, - SupportsMultiModal): - # Ensure that the LoRA support check passes when the class is not - # initialized, but set all these attributes to empty. - # These will be updated when an instance class is selected - packed_modules_mapping = {} - supported_lora_modules = [] - embedding_modules = {} - embedding_padding_modules = [] - - def __new__( - cls, - vllm_config: VllmConfig, - prefix: str = "", - ) -> None: - config = vllm_config.model_config.hf_config - - # Initialize VL - if hasattr(config, "vision_config"): # noqa: SIM108 - instance_cls = ChatGLMV - # Initialize LLM - else: - instance_cls = ChatGLM - - # quant_config references base class members, - # so update values before init is called - cls.packed_modules_mapping.update(instance_cls.packed_modules_mapping) - cls.supported_lora_modules += instance_cls.supported_lora_modules - cls.embedding_modules.update(instance_cls.embedding_modules) - cls.embedding_padding_modules += instance_cls.embedding_padding_modules - return instance_cls(vllm_config=vllm_config, prefix=prefix) + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.transformer(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors, + inputs_embeds) + return hidden_states diff --git a/vllm/model_executor/models/glm4_vision_encoder.py b/vllm/model_executor/models/glm4_vision_encoder.py deleted file mode 100644 index 
2facd1353..000000000 --- a/vllm/model_executor/models/glm4_vision_encoder.py +++ /dev/null @@ -1,312 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -# Adapted from -# https://github.com/THUDM/GLM-4 -"""Inference-only GLM-4v model visual encoder compatible with THUDM weights.""" -from argparse import Namespace -from typing import Optional - -import torch -from torch import nn -from torch.nn import LayerNorm - -from vllm.attention.layer import MultiHeadAttention -from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) - - -class PatchEmbedding(nn.Module): - - def __init__(self, config): - super().__init__() - self.proj = nn.Conv2d(config.in_channels, - config.hidden_size, - kernel_size=config.patch_size, - stride=config.patch_size) - self.cls_embedding = nn.Parameter(torch.zeros(1, config.hidden_size)) - self.position_embedding = nn.Embedding(config.num_positions, - config.hidden_size) - - def forward(self, images: torch.Tensor) -> torch.Tensor: - """ - Parameters: - images : torch.Tensor - Input image tensor with shape (B, C, H, W) - - Returns: - torch.Tensor - Transformed tensor with shape (B, L, D) - """ - images = images.to(device=self.proj.weight.device, - dtype=self.proj.weight.dtype) - x = self.proj(images) - x = x.flatten(2).transpose(1, 2) - cls_token = self.cls_embedding.expand(x.shape[0], -1, -1) - x = torch.cat((cls_token, x), dim=1) - x += self.position_embedding.weight.unsqueeze(0) - return x - - -class Attention(nn.Module): - - def __init__( - self, - config, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = '', - ): - super().__init__() - self.hidden_size = config.hidden_size - self.tp_size = 
get_tensor_model_parallel_world_size() - self.num_heads_per_rank = config.num_heads // self.tp_size - self.head_dim = config.hidden_size // config.num_heads - self.scale = self.head_dim**-0.5 - - self.query_key_value = QKVParallelLinear( - config.hidden_size, - self.head_dim, - config.num_heads, - quant_config=quant_config, - prefix=f"{prefix}.query_key_value", - ) - self.dense = RowParallelLinear( - config.hidden_size, - config.hidden_size, - quant_config=quant_config, - prefix=f"{prefix}.dense", - ) - - self.attn = MultiHeadAttention(self.num_heads_per_rank, self.head_dim, - self.scale) - self.output_dropout = torch.nn.Dropout(config.dropout_prob) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - qkv, _ = self.query_key_value(x) # B, L, 3 * H * D - q, k, v = qkv.chunk(3, dim=-1) - - out = self.attn(q, k, v) - output, _ = self.dense(out) - output = self.output_dropout(output) - return output - - -class MLP(nn.Module): - - def __init__( - self, - config, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = '', - ): - super().__init__() - self.config = config - self.activation_fn = get_act_fn(config.hidden_act) - self.fc1 = ColumnParallelLinear( - config.hidden_size, - config.intermediate_size, - quant_config=quant_config, - prefix=f"{prefix}.fc1", - ) - self.fc2 = RowParallelLinear( - config.intermediate_size, - config.hidden_size, - quant_config=quant_config, - prefix=f"{prefix}.fc2", - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x, _ = self.fc1(x) - x = self.activation_fn(x) - x, _ = self.fc2(x) - return x - - -class TransformerLayer(nn.Module): - - def __init__( - self, - config, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = '', - ): - super().__init__() - self.input_layernorm = LayerNorm(config.hidden_size, - eps=config.layer_norm_eps) - self.attention = Attention(config, - quant_config=quant_config, - prefix=f"{prefix}.attention") - self.mlp = MLP(config, - quant_config=quant_config, - 
prefix=f"{prefix}.mlp") - self.post_attention_layernorm = LayerNorm(config.hidden_size, - eps=config.layer_norm_eps) - - def forward(self, hidden_states): - attention_input = hidden_states - attention_output = self.input_layernorm( - self.attention(attention_input)) - hidden_states = attention_input + attention_output - mlp_input = hidden_states - mlp_output = self.post_attention_layernorm(self.mlp(mlp_input)) - output = mlp_input + mlp_output - return output - - -class Transformer(nn.Module): - - def __init__( - self, - config, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = '', - ): - super().__init__() - self.layers = nn.ModuleList([ - TransformerLayer(config, - quant_config=quant_config, - prefix=f"{prefix}.layers.{layer_idx}") - for layer_idx in range(config.num_hidden_layers) - ]) - - def forward(self, hidden_states): - for layer_module in self.layers: - hidden_states = layer_module(hidden_states) - return hidden_states - - -class GLU(nn.Module): - - def __init__( - self, - config, - in_features, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = '', - ): - """ - The original implementation is the same as: - ```python - self.dense_h_to_4h = ColumnParallelLinear( - config.hidden_size, - config.ffn_hidden_size, - bias=False, - quant_config=quant_config - ) - - self.gate_proj = ColumnParallelLinear( - config.hidden_size, - config.ffn_hidden_size, - bias=False, - quant_config=quant_config - ) - ``` - ``` - gate_proj_output, _ = self.gate_proj(x) - dense_h_to_4h_output, _ = self.dense_h_to_4h(x) - x = torch.cat([gate_proj_output, dense_h_to_4h_output], dim=-1) - ``` - - We merge two ColumnParallelLinear into one MergedColumnParallelLinear: - ``` - self.merged_proj = MergedColumnParallelLinear( - config.hidden_size, - [config.ffn_hidden_size] * 2, - bias=False, - quant_config=quant_config - ) - ``` - ``` - x, _ = self.merged_proj(x) - ``` - """ - super().__init__() - self.linear_proj = ReplicatedLinear(in_features, - 
config.hidden_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.linear_proj") - self.norm1 = nn.LayerNorm(config.hidden_size) - self.act1 = nn.GELU() - self.act2 = SiluAndMul() - - self.merged_proj = MergedColumnParallelLinear( - config.hidden_size, [config.ffn_hidden_size] * 2, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.merged_proj") - - self.dense_4h_to_h = RowParallelLinear( - config.ffn_hidden_size, - config.hidden_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.dense_4h_to_h") - - def forward(self, x): - x, _ = self.linear_proj(x) - x = self.act1(self.norm1(x)) - x, _ = self.merged_proj(x) - x = self.act2(x) - x, _ = self.dense_4h_to_h(x) - return x - - -class EVA2CLIPModel(nn.Module): - - def __init__( - self, - config, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = '', - ): - super().__init__() - vision_config = Namespace(**config.vision_config) - self.patch_embedding = PatchEmbedding(vision_config) - self.transformer = Transformer(vision_config, - quant_config=quant_config, - prefix=f"{prefix}.transformer") - self.linear_proj = GLU(config, - in_features=config.hidden_size, - quant_config=quant_config, - prefix=f"{prefix}.linear_proj") - self.conv = nn.Conv2d(in_channels=vision_config.hidden_size, - out_channels=config.hidden_size, - kernel_size=2, - stride=2) - self.boi = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) - self.eoi = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) - self.scaling_factor = vision_config.scaling_factor - - def forward(self, images: torch.Tensor) -> torch.Tensor: - """ - Parameters: - images : torch.Tensor - Input image tensor with shape (B, C, H, W) - - Returns: - torch.Tensor - Transformed tensor with shape (B, L, D) - """ - x = self.patch_embedding(images) - x = self.transformer(x) - x = x[:, 1:] - - b, s, h = x.shape - grid_size = int(s**0.5) - x = x.view(b, grid_size, grid_size, h).permute(0, 3, 1, 2) - x = self.conv(x) - - x = 
x.flatten(2).transpose(1, 2) - x = self.linear_proj(x) - boi = self.boi.expand(x.shape[0], -1, -1) - eoi = self.eoi.expand(x.shape[0], -1, -1) - x = torch.cat((boi, x, eoi), dim=1) - x = x / self.scaling_factor - return x diff --git a/vllm/model_executor/models/glm4v.py b/vllm/model_executor/models/glm4v.py new file mode 100644 index 000000000..67f19841f --- /dev/null +++ b/vllm/model_executor/models/glm4v.py @@ -0,0 +1,662 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/THUDM/CogAgent +"""Inference-only CogAgent model compatible with THUDM weights.""" +from argparse import Namespace +from typing import List, Literal, Mapping, Optional, TypedDict, Union + +import torch +from torch import nn +from torch.nn import LayerNorm +from torchvision import transforms +from torchvision.transforms import InterpolationMode +from transformers import PreTrainedTokenizer, TensorType +from transformers.image_utils import ImageInput +from transformers.tokenization_utils_base import TextInput + +from vllm.attention import AttentionMetadata +from vllm.attention.layer import MultiHeadAttention +from vllm.config import VllmConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors +from vllm.multimodal.parse import MultiModalDataItems +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, BatchFeature, + MultiModalFieldConfig, + PromptReplacement) +from vllm.multimodal.profiling import BaseDummyInputsBuilder, 
ProcessorInputs +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs import ChatGLMConfig + +from .chatglm import ChatGLMBaseModel, ChatGLMModel +from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP +from .utils import flatten_bn, merge_multimodal_embeddings + + +class GLMVImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: torch.Tensor + """Shape: `(batch_size, num_channels, height, width)`""" + + +class EVA2CLIPPatchEmbedding(nn.Module): + + def __init__(self, config): + super().__init__() + self.proj = nn.Conv2d(config.in_channels, + config.hidden_size, + kernel_size=config.patch_size, + stride=config.patch_size) + self.cls_embedding = nn.Parameter(torch.zeros(1, config.hidden_size)) + self.position_embedding = nn.Embedding(config.num_positions, + config.hidden_size) + + def forward(self, images: torch.Tensor) -> torch.Tensor: + """ + Parameters: + images : torch.Tensor + Input image tensor with shape (B, C, H, W) + + Returns: + torch.Tensor + Transformed tensor with shape (B, L, D) + """ + images = images.to(device=self.proj.weight.device, + dtype=self.proj.weight.dtype) + x = self.proj(images) + x = x.flatten(2).transpose(1, 2) + cls_token = self.cls_embedding.expand(x.shape[0], -1, -1) + x = torch.cat((cls_token, x), dim=1) + x += self.position_embedding.weight.unsqueeze(0) + return x + + +class EVA2CLIPAttention(nn.Module): + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = '', + ): + super().__init__() + self.hidden_size = config.hidden_size + self.tp_size = get_tensor_model_parallel_world_size() + self.num_heads_per_rank = config.num_heads // self.tp_size + self.head_dim = config.hidden_size // config.num_heads + self.scale = self.head_dim**-0.5 + + self.query_key_value = QKVParallelLinear( + config.hidden_size, + self.head_dim, + config.num_heads, + quant_config=quant_config, + prefix=f"{prefix}.query_key_value", + ) + self.dense = 
RowParallelLinear( + config.hidden_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.dense", + ) + + self.attn = MultiHeadAttention(self.num_heads_per_rank, self.head_dim, + self.scale) + self.output_dropout = torch.nn.Dropout(config.dropout_prob) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + qkv, _ = self.query_key_value(x) # B, L, 3 * H * D + q, k, v = qkv.chunk(3, dim=-1) + + out = self.attn(q, k, v) + output, _ = self.dense(out) + output = self.output_dropout(output) + return output + + +class EVA2CLIPMLP(nn.Module): + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = '', + ): + super().__init__() + self.config = config + self.activation_fn = get_act_fn(config.hidden_act) + self.fc1 = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + quant_config=quant_config, + prefix=f"{prefix}.fc1", + ) + self.fc2 = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.fc2", + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, _ = self.fc1(x) + x = self.activation_fn(x) + x, _ = self.fc2(x) + return x + + +class EVA2CLIPTransformerLayer(nn.Module): + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = '', + ): + super().__init__() + self.input_layernorm = LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.attention = EVA2CLIPAttention(config, + quant_config=quant_config, + prefix=f"{prefix}.attention") + self.mlp = EVA2CLIPMLP(config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + self.post_attention_layernorm = LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + + def forward(self, hidden_states): + attention_input = hidden_states + attention_output = self.input_layernorm( + self.attention(attention_input)) + hidden_states = attention_input + attention_output + mlp_input = hidden_states + mlp_output = 
self.post_attention_layernorm(self.mlp(mlp_input)) + output = mlp_input + mlp_output + return output + + +class EVA2CLIPTransformer(nn.Module): + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = '', + ): + super().__init__() + self.layers = nn.ModuleList([ + EVA2CLIPTransformerLayer(config, + quant_config=quant_config, + prefix=f"{prefix}.layers.{layer_idx}") + for layer_idx in range(config.num_hidden_layers) + ]) + + def forward(self, hidden_states): + for layer_module in self.layers: + hidden_states = layer_module(hidden_states) + return hidden_states + + +class EVA2CLIPGLU(nn.Module): + + def __init__( + self, + config, + in_features, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = '', + ): + """ + The original implementation is the same as: + ```python + self.dense_h_to_4h = ColumnParallelLinear( + config.hidden_size, + config.ffn_hidden_size, + bias=False, + quant_config=quant_config + ) + + self.gate_proj = ColumnParallelLinear( + config.hidden_size, + config.ffn_hidden_size, + bias=False, + quant_config=quant_config + ) + ``` + ``` + gate_proj_output, _ = self.gate_proj(x) + dense_h_to_4h_output, _ = self.dense_h_to_4h(x) + x = torch.cat([gate_proj_output, dense_h_to_4h_output], dim=-1) + ``` + + We merge two ColumnParallelLinear into one MergedColumnParallelLinear: + ``` + self.merged_proj = MergedColumnParallelLinear( + config.hidden_size, + [config.ffn_hidden_size] * 2, + bias=False, + quant_config=quant_config + ) + ``` + ``` + x, _ = self.merged_proj(x) + ``` + """ + super().__init__() + self.linear_proj = ReplicatedLinear(in_features, + config.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.linear_proj") + self.norm1 = nn.LayerNorm(config.hidden_size) + self.act1 = nn.GELU() + self.act2 = SiluAndMul() + + self.merged_proj = MergedColumnParallelLinear( + config.hidden_size, [config.ffn_hidden_size] * 2, + bias=False, + quant_config=quant_config, + 
prefix=f"{prefix}.merged_proj") + + self.dense_4h_to_h = RowParallelLinear( + config.ffn_hidden_size, + config.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.dense_4h_to_h") + + def forward(self, x): + x, _ = self.linear_proj(x) + x = self.act1(self.norm1(x)) + x, _ = self.merged_proj(x) + x = self.act2(x) + x, _ = self.dense_4h_to_h(x) + return x + + +class EVA2CLIPModel(nn.Module): + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = '', + ): + super().__init__() + vision_config = Namespace(**config.vision_config) + self.patch_embedding = EVA2CLIPPatchEmbedding(vision_config) + self.transformer = EVA2CLIPTransformer(vision_config, + quant_config=quant_config, + prefix=f"{prefix}.transformer") + self.linear_proj = EVA2CLIPGLU(config, + in_features=config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.linear_proj") + self.conv = nn.Conv2d(in_channels=vision_config.hidden_size, + out_channels=config.hidden_size, + kernel_size=2, + stride=2) + self.boi = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) + self.eoi = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) + self.scaling_factor = vision_config.scaling_factor + + def forward(self, images: torch.Tensor) -> torch.Tensor: + """ + Parameters: + images : torch.Tensor + Input image tensor with shape (B, C, H, W) + + Returns: + torch.Tensor + Transformed tensor with shape (B, L, D) + """ + x = self.patch_embedding(images) + x = self.transformer(x) + x = x[:, 1:] + + b, s, h = x.shape + grid_size = int(s**0.5) + x = x.view(b, grid_size, grid_size, h).permute(0, 3, 1, 2) + x = self.conv(x) + + x = x.flatten(2).transpose(1, 2) + x = self.linear_proj(x) + boi = self.boi.expand(x.shape[0], -1, -1) + eoi = self.eoi.expand(x.shape[0], -1, -1) + x = torch.cat((boi, x, eoi), dim=1) + x = x / self.scaling_factor + return x + + +class GLM4VModel(ChatGLMModel): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): 
+ super().__init__(vllm_config=vllm_config, prefix=prefix) + + quant_config = vllm_config.quant_config + + self.vision = EVA2CLIPModel(self.config, + quant_config, + prefix=f"{prefix}.vision") + + +class GLM4VProcessor: + """ + This model doesn't define its own HF processor, + so we implement our own one here. + """ + + def __init__( + self, + config: ChatGLMConfig, + tokenizer: PreTrainedTokenizer, + ) -> None: + super().__init__() + + self.config = config + self.tokenizer = tokenizer + + vision_config = config.vision_config + image_size = vision_config["image_size"] + + self.image_transform = transforms.Compose([ + transforms.Resize( + (image_size, image_size), + interpolation=InterpolationMode.BICUBIC, + ), + transforms.ToTensor(), + transforms.Normalize( + mean=(0.48145466, 0.4578275, 0.40821073), + std=(0.26862954, 0.26130258, 0.27577711), + ), + ]) + + def __call__( + self, + text: Optional[Union[TextInput, list[TextInput]]] = None, + images: Optional[Union[ImageInput, list[ImageInput]]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + ) -> BatchFeature: + if text is None: + text = [] + if not isinstance(text, list): + text = [text] + if images is None: + images = [] + if not isinstance(images, list): + images = [images] + + text_inputs = self.tokenizer(text) + + if len(images) == 0: + image_inputs = {} + else: + pixel_values = [self.image_transform(image) for image in images] + image_inputs = {"pixel_values": torch.stack(pixel_values)} + + return BatchFeature( + { + **text_inputs, + **image_inputs, + }, + tensor_type=return_tensors, + ) + + +class GLM4VProcessingInfo(BaseProcessingInfo): + + def get_tokenizer(self): + tokenizer = self.ctx.tokenizer + assert isinstance(tokenizer, PreTrainedTokenizer) + return tokenizer + + def get_hf_config(self): + return self.ctx.get_hf_config(ChatGLMConfig) + + def get_hf_processor(self) -> GLM4VProcessor: + return GLM4VProcessor( + self.get_hf_config(), + self.get_tokenizer(), + ) + + def 
get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": 1} + + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + return {"image": self.get_num_image_feature_tokens()} + + def get_num_image_tokens(self) -> int: + hf_config = self.get_hf_config() + vision_config = hf_config.vision_config + + image_size = vision_config["image_size"] + patch_size = vision_config["patch_size"] + grid_length = image_size // patch_size // 2 + return grid_length * grid_length + + def get_num_image_feature_tokens(self) -> int: + # EVA2CLIPModel has embeddings for boi and eoi tokens as well + return self.get_num_image_tokens() + 2 + + +class GLM4VDummyInputsBuilder(BaseDummyInputsBuilder[GLM4VProcessingInfo]): + + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + hf_config = self.info.get_hf_config() + vision_config = hf_config.vision_config + + target_width = target_height = vision_config["image_size"] + num_images = mm_counts.get("image", 0) + + mm_data = { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images) + } + + base_text = "<|begin_of_image|><|endoftext|><|end_of_image|>" + + return ProcessorInputs( + prompt_text=base_text * num_images, + mm_data=mm_data, + ) + + +class GLM4VMultiModalProcessor(BaseMultiModalProcessor[GLM4VProcessingInfo]): + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict(pixel_values=MultiModalFieldConfig.batched("image")) + + def _get_prompt_replacements( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> list[PromptReplacement]: + hf_config = self.info.get_hf_config() + + boi_token_id = hf_config.boi_token_id + image_token_id = hf_config.pad_token_id + 
eoi_token_id = hf_config.eoi_token_id + + def get_replacement(item_idx: int): + num_image_tokens = self.info.get_num_image_tokens() + image_tokens = [image_token_id] * num_image_tokens + + return [boi_token_id] + image_tokens + [eoi_token_id] + + return [ + PromptReplacement( + modality="image", + target=[boi_token_id, image_token_id, eoi_token_id], + replacement=get_replacement, + ), + ] + + +@MULTIMODAL_REGISTRY.register_processor(GLM4VMultiModalProcessor, + info=GLM4VProcessingInfo, + dummy_inputs=GLM4VDummyInputsBuilder) +class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP, + SupportsMultiModal): + + packed_modules_mapping = { + "query_key_value": ["query_key_value"], + "dense_h_to_4h": ["dense_h_to_4h"], + "merged_proj": ["gate_proj", "dense_h_to_4h"] + } + # LoRA specific attributes + supported_lora_modules = [ + "query_key_value", + "dense", + "dense_h_to_4h", + "dense_4h_to_h", + # vision + "fc1", + "fc2", + "merged_proj", + "linear_proj" + ] + + embedding_modules = {} + embedding_padding_modules = [] + + def get_mm_mapping(self) -> MultiModelKeys: + """ + Get the module prefix in multimodal models + """ + return MultiModelKeys.from_string_field( + language_model="transformer.encoder", + connector="transformer.vision.linear_proj", + tower_model="transformer.vision.transformer") + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + transformer_type: type[GLM4VModel] = GLM4VModel, + ) -> None: + super().__init__( + vllm_config=vllm_config, + prefix=prefix, + transformer_type=transformer_type, + ) + + self.transformer: GLM4VModel + + def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: + h = w = self.config.vision_config["image_size"] + expected_dims = (3, h, w) + actual_dims = tuple(data.shape[1:]) + + if actual_dims != expected_dims: + expected_expr = ("batch_size", *map(str, expected_dims)) + raise ValueError( + f"The expected shape of pixel values is {expected_expr}. 
" + f"You supplied {tuple(data.shape)}.") + + return data + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[GLMVImagePixelInputs]: + pixel_values = kwargs.pop("pixel_values", None) + + if pixel_values is not None: + if not isinstance(pixel_values, torch.Tensor): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") + + return GLMVImagePixelInputs( + type="pixel_values", + data=self._validate_pixel_values( + flatten_bn(pixel_values, concat=True)), + ) + + return None + + def _process_image_input( + self, image_input: GLMVImagePixelInputs) -> torch.Tensor: + pixel_values = image_input["data"].to(dtype=self.config.torch_dtype) + + return self.transformer.vision(pixel_values) + + def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + + vision_embeddings = self._process_image_input(image_input) + return vision_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[NestedTensors] = None, + ) -> torch.Tensor: + inputs_embeds = self.transformer.get_input_embeddings(input_ids) + + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + multimodal_embeddings=multimodal_embeddings, + placeholder_token_id=[ + self.config.boi_token_id, + self.config.pad_token_id, + self.config.eoi_token_id, + ], + ) + + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always 
generated at model runner, this + # condition is for v0 compatibility. + elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None + + hidden_states = self.transformer(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors, + inputs_embeds) + + return hidden_states diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 4b8aeaddb..a45e9463a 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -6,381 +6,35 @@ # LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE """Inference-only QWen model compatible with HuggingFace weights.""" -import copy -import math -import re -import unicodedata -from functools import lru_cache, partial -from typing import (AbstractSet, Any, Callable, Collection, Dict, Iterable, - List, Literal, Mapping, Optional, Set, Tuple, TypedDict, - Union) +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn -from torchvision import transforms -from torchvision.transforms import InterpolationMode -from transformers import (BatchFeature, PretrainedConfig, PreTrainedTokenizer, - TensorType) -from transformers.image_utils import ImageInput -from transformers.tokenization_utils_base import TextInput +from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size -from vllm.logger import init_logger -from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn +from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (ColumnParallelLinear, 
- MergedColumnParallelLinear, +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, - ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.resampler import Resampler2, get_abs_pos from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, - NestedTensors) -from vllm.multimodal.parse import MultiModalDataItems -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement, - PromptReplacementDetails) -from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors -from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP -from .utils import (flatten_bn, is_pp_missing_parameter, +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix, merge_multimodal_embeddings) - -logger = init_logger(__name__) - - -class QwenImagePixelInputs(TypedDict): - type: Literal["pixel_values"] - data: torch.Tensor - """ - Shape: `(batch_size * num_images, 3, image_size, image_size)` - - Note that image_size is the value in the vision config to which we resize - the image to in the normalization transform. 
Currently multi-image support - can only be leveraged by passing image embeddings directly. - """ - - -class QwenImageEmbeddingInputs(TypedDict): - type: Literal["image_embeds"] - data: torch.Tensor - """Shape: `(batch_size * num_images, 256, hidden_size)` - - `hidden_size` must match the hidden size of the language model backbone - and is stored in the visual config of the model if we have one. - """ - - -QwenImageInputs = Union[QwenImagePixelInputs, QwenImageEmbeddingInputs] - - -class VisualAttention(nn.Module): - """self-attention layer class. - Self-attention layer takes input with size [s, b, h] - and returns output of the same size. - """ - - def __init__( - self, - embed_dim: int, - num_heads: int, - bias: bool = True, - kdim: Optional[int] = None, - vdim: Optional[int] = None, - ): - super().__init__() - self.embed_dim = embed_dim - self.kdim = kdim if kdim is not None else embed_dim - self.vdim = vdim if vdim is not None else embed_dim - self._qkv_same_embed_dim = self.kdim == embed_dim \ - and self.vdim == embed_dim - - self.num_heads = num_heads - - # Per attention head and per partition values. - assert embed_dim % num_heads == 0 - self.hidden_size_per_attention_head = embed_dim // num_heads - self.num_attention_heads_per_partition = num_heads - self.hidden_size_per_partition = embed_dim - - # Strided linear layer. 
- assert self._qkv_same_embed_dim, \ - 'Visual Attention implementation only supports self-attention' - self.in_proj = ReplicatedLinear(embed_dim, 3 * embed_dim) - self.out_proj = ReplicatedLinear(embed_dim, embed_dim) - self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) - - def forward( - self, - x: torch.Tensor, - attn_mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - # query/key/value: [sq, b, h] - sq, b, _ = x.size() - mixed_x_layer, _ = self.in_proj(x) - - # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn] - new_tensor_shape = mixed_x_layer.size()[:-1] + \ - (self.num_attention_heads_per_partition, - 3 * self.hidden_size_per_attention_head) - mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) - - # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] - query_layer, key_layer, value_layer = mixed_x_layer.split( - self.hidden_size_per_attention_head, dim=-1) - - # [sq, b, np, hn] -> [sq, b * np, hn] - query_layer = query_layer.view( - sq, b * self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head).transpose(0, 1) - # [sk, b, np, hn] -> [sk, b * np, hn] - key_layer = key_layer.view( - sq, b * self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head).transpose(0, 1) - - q_scaled = query_layer / self.norm_factor - if attn_mask is not None: - attention_probs = torch.baddbmm(attn_mask, q_scaled, - key_layer.transpose(-2, -1)) - else: - attention_probs = torch.bmm(q_scaled, key_layer.transpose(-2, -1)) - attention_probs = attention_probs.softmax(dim=-1) - - value_layer = value_layer.view( - sq, b * self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head).transpose(0, 1) - - # matmul: [b * np, sq, hn] - context_layer = torch.bmm(attention_probs, value_layer) - - # change view [b, np, sq, hn] - context_layer = context_layer.view( - b, self.num_attention_heads_per_partition, sq, - self.hidden_size_per_attention_head) - - # [b, np, sq, hn] --> [sq, b, np, hn] - context_layer = 
context_layer.permute(2, 0, 1, 3).contiguous() - - # [sq, b, np, hn] --> [sq, b, hp] - new_context_layer_shape = context_layer.size()[:-2] + \ - (self.hidden_size_per_partition,) - context_layer = context_layer.view(*new_context_layer_shape) - - output, _ = self.out_proj(context_layer) - - return output - - -class QwenVMLP(nn.Module): - """MLP for the visual component of the Qwen model.""" - - def __init__( - self, - hidden_size: int, - intermediate_size: int, - quant_config: Optional[QuantizationConfig] = None, - ): - super().__init__() - self.c_fc = ColumnParallelLinear(hidden_size, - intermediate_size, - bias=True, - quant_config=quant_config) - self.act_fn = get_act_fn("gelu") - self.c_proj = RowParallelLinear( - intermediate_size, - hidden_size, - bias=True, - quant_config=quant_config, - ) - - def forward(self, x): - x, _ = self.c_fc(x) - x = self.act_fn(x) - x, _ = self.c_proj(x) - return x - - -class VisualAttentionBlock(nn.Module): - - def __init__( - self, - d_model: int, - n_head: int, - mlp_ratio: float = 4.0, - norm_layer: Callable[[int], nn.Module] = nn.LayerNorm, - quant_config: Optional[QuantizationConfig] = None, - ): - super().__init__() - - self.ln_1 = norm_layer(d_model) - self.ln_2 = norm_layer(d_model) - mlp_width = int(d_model * mlp_ratio) - self.attn = VisualAttention(d_model, n_head) - self.mlp = QwenVMLP( - hidden_size=d_model, - intermediate_size=mlp_width, - quant_config=quant_config, - ) - - def attention( - self, - x: torch.Tensor, - attn_mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - attn_mask = attn_mask.to(x.dtype) if attn_mask is not None else None - return self.attn(x, attn_mask=attn_mask) - - def forward( - self, - x: torch.Tensor, - attn_mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - x = x + self.attention(self.ln_1(x), attn_mask=attn_mask) - x = x + self.mlp(self.ln_2(x)) - return x - - -class TransformerBlock(nn.Module): - - def __init__( - self, - width: int, - layers: int, - heads: int, - mlp_ratio: 
float = 4.0, - norm_layer: Callable[[int], nn.Module] = nn.LayerNorm, - quant_config: Optional[QuantizationConfig] = None, - ): - super().__init__() - self.width = width - self.layers = layers - - self.resblocks = nn.ModuleList([ - VisualAttentionBlock(width, - heads, - mlp_ratio, - norm_layer=norm_layer, - quant_config=quant_config) - for _ in range(layers) - ]) - - def get_cast_dtype(self) -> torch.dtype: - return self.resblocks[0].mlp.c_fc.weight.dtype - - def get_cast_device(self) -> torch.device: - return self.resblocks[0].mlp.c_fc.weight.device - - def forward(self, - x: torch.Tensor, - attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor: - for r in self.resblocks: - x = r(x, attn_mask=attn_mask) - return x - - -class VisionTransformer(nn.Module): - - def __init__(self, - image_size: int, - patch_size: int, - width: int, - layers: int, - heads: int, - mlp_ratio: float, - n_queries: int = 256, - output_dim: int = 512, - image_start_id: int = 151857, - quant_config: Optional[QuantizationConfig] = None, - **kwargs): - super().__init__() - image_height, image_width = self.image_size = (image_size, image_size) - patch_height, patch_width = self.patch_size = (patch_size, patch_size) - self.grid_size = (image_height // patch_height, - image_width // patch_width) - self.output_dim = output_dim - self.conv1 = nn.Conv2d(in_channels=3, - out_channels=width, - kernel_size=patch_size, - stride=patch_size, - bias=False) - - # class embeddings and positional embeddings - scale = width**-0.5 - self.positional_embedding = nn.Parameter(scale * - torch.randn(256, width)) - - norm_layer = partial(nn.LayerNorm, eps=1e-6) - - self.ln_pre = norm_layer(width) - self.transformer = TransformerBlock(width, - layers, - heads, - mlp_ratio, - norm_layer=norm_layer, - quant_config=quant_config) - - self.attn_pool = Resampler2( - grid_size=int(math.sqrt(n_queries)), - embed_dim=output_dim, - num_heads=output_dim // 128, - kv_dim=width, - norm_layer=norm_layer, - adaptive=False, - 
do_post_projection=False, - ).to( - device=self.positional_embedding.device, - dtype=self.positional_embedding.dtype, - ) - - self.ln_post = norm_layer(output_dim) - self.proj = nn.Parameter( - (output_dim**-0.5) * torch.randn(output_dim, output_dim)) - - self.image_start_id = image_start_id - self.image_end_id = image_start_id + 1 - self.image_pad_id = image_start_id + 2 - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = x.to( - dtype=self.transformer.get_cast_dtype(), - device=self.transformer.get_cast_device(), - ) - - # to patches - x = self.conv1(x) # shape = [*, width, grid, grid] - x = x.reshape(x.shape[0], x.shape[1], - -1) # shape = [*, width, grid ** 2] - x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] - - x = x + get_abs_pos(self.positional_embedding, int(math.sqrt( - x.size(1)))) - - x = self.ln_pre(x) - - x = x.permute(1, 0, 2) # NLD -> LND - x = self.transformer(x) - x = x.permute(1, 0, 2) # LND -> NLD - - x = self.attn_pool(x) - x = self.ln_post(x) - x = x @ self.proj - - return x + maybe_prefix) class QWenMLP(nn.Module): @@ -564,12 +218,6 @@ class QWenModel(nn.Module): make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) - if (vision_config := getattr(config, "visual", None)): - self.visual = VisionTransformer(**vision_config, - quant_config=quant_config) - else: - self.visual = None - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.wte(input_ids) @@ -592,6 +240,7 @@ class QWenModel(nn.Module): assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] + for i in range(self.start_layer, self.end_layer): layer = self.h[i] hidden_states, residual = layer( @@ -610,302 +259,25 @@ class QWenModel(nn.Module): return hidden_states -@lru_cache(maxsize=1) -def _get_tokenizer_without_image_pad( - tokenizer: PreTrainedTokenizer) -> PreTrainedTokenizer: - """ - The logic of adding image pad 
tokens should only be applied in - :class:`QWenVLProcessor`, so they are patched out here. - - The definition of the wrapped tokenizer can be found here: - https://huggingface.co/Qwen/Qwen-VL/blob/main/tokenization_qwen.py - """ - new_tokenizer = copy.deepcopy(tokenizer) - - class TokenizerWithoutImagePad(tokenizer.__class__): # type: ignore - - def tokenize( - self, - text: str, - allowed_special: Union[AbstractSet[str], str] = "all", - disallowed_special: Union[Collection[str], str] = (), - **kwargs, - ) -> list[Union[bytes, str]]: - text = unicodedata.normalize("NFC", text) - - return [ - self.decoder[t] for t in self.tokenizer.encode( - text, - allowed_special=allowed_special, - disallowed_special=disallowed_special, - ) - ] - - def _decode( - self, - token_ids: Union[int, List[int]], - skip_special_tokens: bool = False, - errors: Optional[str] = None, - **kwargs, - ) -> str: - if isinstance(token_ids, int): - token_ids = [token_ids] - - return self.tokenizer.decode( - token_ids, - errors=errors or self.errors, - ) - - TokenizerWithoutImagePad.__name__ = \ - f"{tokenizer.__class__.__name__}WithoutImagePad" - - new_tokenizer.__class__ = TokenizerWithoutImagePad - return new_tokenizer - - -class QWenVLProcessor: - """ - This model doesn't define its own HF processor, - so we implement our own one here. 
- - We call the wrapped tokenizer to automatically insert image pad tokens: - https://huggingface.co/Qwen/Qwen-VL/blob/main/tokenization_qwen.py#L245 - - The image processor is defined here: - https://huggingface.co/Qwen/Qwen-VL/blob/main/visual.py#L354 - """ +class QWenBaseModel(nn.Module): def __init__( self, - config: PretrainedConfig, - tokenizer: PreTrainedTokenizer, + *, + vllm_config: VllmConfig, + prefix: str = "", + transformer_type: type[QWenModel] = QWenModel, ) -> None: super().__init__() - - self.config = config - self.tokenizer = tokenizer - - if vision_config := getattr(self.config, "visual", None): - image_size = vision_config["image_size"] - - self.image_transform = transforms.Compose([ - transforms.Resize( - (image_size, image_size), - interpolation=InterpolationMode.BICUBIC, - ), - transforms.ToTensor(), - transforms.Normalize( - mean=(0.48145466, 0.4578275, 0.40821073), - std=(0.26862954, 0.26130258, 0.27577711), - ), - ]) - else: - self.image_transform = None - - @property - def image_start_tag(self) -> str: - return self.tokenizer.image_start_tag # type: ignore - - @property - def image_end_tag(self) -> str: - return self.tokenizer.image_end_tag # type: ignore - - @property - def image_pad_tag(self) -> str: - return self.tokenizer.image_pad_tag # type: ignore - - def __call__( - self, - text: Optional[Union[TextInput, list[TextInput]]] = None, - images: Optional[Union[ImageInput, list[ImageInput]]] = None, - return_tensors: Optional[Union[str, TensorType]] = None, - ) -> BatchFeature: - if text is None: - text = [] - if not isinstance(text, list): - text = [text] - if images is None: - images = [] - if not isinstance(images, list): - images = [images] - - text_inputs = self.tokenizer(text) - - if len(images) == 0: - image_inputs = {} - else: - if self.image_transform is None: - raise ValueError("This model does not support image inputs") - - pixel_values = [self.image_transform(image) for image in images] - image_inputs = {"pixel_values": 
torch.stack(pixel_values)} - - return BatchFeature( - { - **text_inputs, - **image_inputs, - }, - tensor_type=return_tensors, - ) - - -class QWenVLProcessingInfo(BaseProcessingInfo): - - def get_tokenizer(self) -> PreTrainedTokenizer: - tokenizer = self.ctx.tokenizer - assert isinstance(tokenizer, PreTrainedTokenizer) - - return _get_tokenizer_without_image_pad(tokenizer) - - def get_hf_processor(self) -> QWenVLProcessor: - tokenizer = self.ctx.tokenizer - assert isinstance(tokenizer, PreTrainedTokenizer) - - return QWenVLProcessor(self.get_hf_config(), tokenizer) - - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: - return {"image": None} - - def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - return {"image": self.get_num_image_tokens()} - - def get_num_image_tokens(self) -> int: - hf_config = self.get_hf_config() - if not (vision_config := getattr(hf_config, "visual", None)): - return 0 - - image_size = vision_config["image_size"] - patch_size = vision_config["patch_size"] - grid_length = image_size // patch_size // 2 - return grid_length * grid_length - - -class QWenVLDummyInputsBuilder(BaseDummyInputsBuilder[QWenVLProcessingInfo]): - - def get_dummy_processor_inputs( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> ProcessorInputs: - hf_config = self.info.get_hf_config() - if not (vision_config := getattr(hf_config, "visual", None)): - return ProcessorInputs(prompt_text="", mm_data={}) - - processor = self.info.get_hf_processor() - img_start = processor.image_start_tag - img_end = processor.image_end_tag - - target_width = target_height = vision_config["image_size"] - num_images = mm_counts.get("image", 0) - - mm_data = { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images) - } - - return ProcessorInputs( - prompt_text="".join(f"Picture {i}: {img_start}{img_end}\n" - for i in range(1, num_images + 1)), - 
mm_data=mm_data, - ) - - -class QWenVLMultiModalProcessor(BaseMultiModalProcessor[QWenVLProcessingInfo]): - - def _call_hf_processor( - self, - prompt: str, - mm_data: Mapping[str, object], - mm_kwargs: Mapping[str, object], - ) -> BatchFeature: - # Drops anything between / tags; encoding with the tokenizer - # will automatically add the image pads for the context. - prompt, num_matched_images = re.subn( - r"(Picture \d*: ).*?(<\/img>\n)", - r"\1\2", - prompt, - ) - - image_data = mm_data.get("images") - if image_data is not None: - assert isinstance(image_data, list) - - num_images = len(image_data) - if num_matched_images != num_images: - logger.warning( - "Number of matched image placeholders %s doesn't match " - "the number of expected images %s; check your placeholder " - "formatting.", num_matched_images, num_images) - - return super()._call_hf_processor( - prompt=prompt, - mm_data=mm_data, - mm_kwargs=mm_kwargs, - ) - - def _get_mm_fields_config( - self, - hf_inputs: BatchFeature, - hf_processor_mm_kwargs: Mapping[str, object], - ) -> Mapping[str, MultiModalFieldConfig]: - return dict( - pixel_values=MultiModalFieldConfig.batched("image"), - image_embeds=MultiModalFieldConfig.batched("image"), - ) - - def _get_prompt_replacements( - self, - mm_items: MultiModalDataItems, - hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, - ) -> list[PromptReplacement]: - hf_config = self.info.get_hf_config() - if not hasattr(hf_config, "visual"): - return [] - - tokenizer = self.info.get_tokenizer() - special_tokens: dict[str, - int] = tokenizer.special_tokens # type: ignore - - processor = self.info.get_hf_processor() - img_start_id = special_tokens[processor.image_start_tag] - img_end_id = special_tokens[processor.image_end_tag] - img_pad_id = special_tokens[processor.image_pad_tag] - - num_image_tokens = self.info.get_num_image_tokens() - image_tokens = [img_pad_id] * num_image_tokens - - return [ - PromptReplacement( - modality="image", - 
target=[img_start_id, img_end_id], - replacement=PromptReplacementDetails( - full=[img_start_id] + image_tokens + [img_end_id], - features=image_tokens, - ), - ) - ] - - -class QWenBaseModel(nn.Module, SupportsPP, SupportsLoRA): - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config multimodal_config = vllm_config.model_config.multimodal_config self.config = config self.multimodal_config = multimodal_config self.quant_config = quant_config - self.transformer = QWenModel(vllm_config=vllm_config, - prefix=maybe_prefix( - prefix, "transformer")) + self.transformer = transformer_type(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "transformer")) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, quant_config=quant_config) @@ -916,104 +288,6 @@ class QWenBaseModel(nn.Module, SupportsPP, SupportsLoRA): self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) - def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: - h = w = self.config.visual["image_size"] - expected_dims = (3, h, w) - actual_dims = tuple(data.shape[1:]) - - if actual_dims != expected_dims: - expected_expr = ("batch_size", *map(str, expected_dims)) - raise ValueError( - f"The expected shape of pixel values is {expected_expr}. " - f"You supplied {tuple(data.shape)}.") - - return data - - def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[QwenImageInputs]: - pixel_values = kwargs.pop("pixel_values", None) - image_embeds = kwargs.pop("image_embeds", None) - - if pixel_values is not None: - if not isinstance(pixel_values, torch.Tensor): - raise ValueError("Incorrect type of pixel values. 
" - f"Got type: {type(pixel_values)}") - - return QwenImagePixelInputs( - type="pixel_values", - data=self._validate_pixel_values( - flatten_bn(pixel_values, concat=True)), - ) - - if image_embeds is not None: - if not isinstance(image_embeds, torch.Tensor): - raise ValueError("Incorrect type of image embeddings. " - f"Got type: {type(image_embeds)}") - - return QwenImageEmbeddingInputs( - type="image_embeds", - data=flatten_bn(image_embeds), - ) - - return None - - def _process_image_input(self, - image_input: QwenImageInputs) -> torch.Tensor: - if image_input["type"] == "image_embeds": - return image_input["data"] - - assert self.transformer.visual is not None - return self.transformer.visual(image_input["data"]) - - def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: - image_input = self._parse_and_validate_image_input(**kwargs) - if image_input is None: - return None - - vision_embeddings = self._process_image_input(image_input) - return vision_embeddings - - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[NestedTensors] = None, - ) -> torch.Tensor: - inputs_embeds = self.transformer.get_input_embeddings(input_ids) - - if multimodal_embeddings is not None: - assert self.transformer.visual is not None - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - self.transformer.visual.image_pad_id) - - return inputs_embeds - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs: object, - ) -> Union[torch.Tensor, IntermediateTensors]: - if intermediate_tensors is not None: - inputs_embeds = None - - # NOTE: In v1, inputs_embeds is always generated at model runner, this - # condition is for v0 compatibility. 
- elif inputs_embeds is None: - vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) - input_ids = None - - hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, - inputs_embeds) - return hidden_states - def compute_logits( self, hidden_states: torch.Tensor, @@ -1072,26 +346,7 @@ class QWenBaseModel(nn.Module, SupportsPP, SupportsLoRA): return loaded_params -class QWenLLM(QWenBaseModel): - packed_modules_mapping = { - "c_attn": ["c_attn"], - "gate_up_proj": [ - "w2", - "w1", - ], - } - # LoRA specific attributes - supported_lora_modules = [ - "c_attn", - "gate_up_proj", - "c_proj", - ] - - embedding_modules = {} - embedding_padding_modules = [] - - -class QWenVL(QWenBaseModel, SupportsMultiModal): +class QWenLMHeadModel(QWenBaseModel, SupportsPP, SupportsLoRA): packed_modules_mapping = { "c_attn": ["c_attn"], "gate_up_proj": [ @@ -1104,62 +359,35 @@ class QWenVL(QWenBaseModel, SupportsMultiModal): "c_attn", "gate_up_proj", "c_proj", - # visual module - "out_proj", - "in_proj", - "c_fc", - # resampler - "kv_proj", ] embedding_modules = {} embedding_padding_modules = [] - def get_mm_mapping(self) -> MultiModelKeys: - """ - Get the module prefix in multimodal models - """ - return MultiModelKeys.from_string_field( - language_model="transformer.h", - connector="transformer.visual.attn_pool", - tower_model="transformer.visual.transformer") - - -@MULTIMODAL_REGISTRY.register_processor(QWenVLMultiModalProcessor, - info=QWenVLProcessingInfo, - dummy_inputs=QWenVLDummyInputsBuilder) -class QWenLMHeadModel(QWenBaseModel, SupportsMultiModal, SupportsLoRA): - """ - QWenLMHeadModel is not only applicable to LLM but also to VL, which is not - conducive to the current integration logic of LoRA in vLLM. Therefore, it - is necessary to separate them. 
- """ - # Ensure that the LoRA support check passes when the class is not - # initialized, but set all these attributes to empty. - # These will be updated when an instance class is selected - packed_modules_mapping = {} - supported_lora_modules = [] - embedding_modules = {} - embedding_padding_modules = [] - - def __new__( - cls, - vllm_config: VllmConfig, - prefix: str = "", - ) -> QWenBaseModel: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config + if hasattr(config, "visual"): + hf_overrides = { + "architectures": ["QwenVLForConditionalGeneration"] + } + raise RuntimeError( + "The configuration of this model indicates that it supports " + "vision inputs, but you instantiated the text-only version " + "of this model. Please use the vision model by setting " + f"`--hf-overrides {hf_overrides!r}`") + + super().__init__(vllm_config=vllm_config, prefix=prefix) - # Initialize VL - if hasattr(config, "visual"): # noqa: SIM108 - instance_cls = QWenVL - # Initialize LLM - else: - instance_cls = QWenLLM - - # quant_config references base class members, - # so update values before init is called - cls.packed_modules_mapping.update(instance_cls.packed_modules_mapping) - cls.supported_lora_modules += instance_cls.supported_lora_modules - cls.embedding_modules.update(instance_cls.embedding_modules) - cls.embedding_padding_modules += instance_cls.embedding_padding_modules - return instance_cls(vllm_config=vllm_config, prefix=prefix) + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.transformer(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors, + inputs_embeds) + return hidden_states diff --git 
a/vllm/model_executor/models/qwen_vl.py b/vllm/model_executor/models/qwen_vl.py new file mode 100644 index 000000000..5316eb7e0 --- /dev/null +++ b/vllm/model_executor/models/qwen_vl.py @@ -0,0 +1,794 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://huggingface.co/Qwen/Qwen-VL/blob/main/modeling_qwen.py +# Copyright (c) Alibaba Cloud. +"""Inference-only Qwen-VL model compatible with HuggingFace weights.""" + +import copy +import math +import re +import unicodedata +from functools import lru_cache, partial +from typing import (AbstractSet, Callable, Collection, List, Literal, Mapping, + Optional, TypedDict, Union) + +import torch +from torch import nn +from torchvision import transforms +from torchvision.transforms import InterpolationMode +from transformers import (BatchFeature, PretrainedConfig, PreTrainedTokenizer, + TensorType) +from transformers.image_utils import ImageInput +from transformers.tokenization_utils_base import TextInput + +from vllm.attention import AttentionMetadata +from vllm.config import VllmConfig +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.resampler import Resampler2, get_abs_pos +from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, + NestedTensors) +from vllm.multimodal.parse import MultiModalDataItems +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement, + PromptReplacementDetails) +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP +from .qwen import 
QWenBaseModel, QWenModel +from .utils import flatten_bn, merge_multimodal_embeddings + + +class QwenImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: torch.Tensor + """ + Shape: `(batch_size * num_images, 3, image_size, image_size)` + + Note that image_size is the value in the vision config to which we resize + the image to in the normalization transform. Currently multi-image support + can only be leveraged by passing image embeddings directly. + """ + + +class QwenImageEmbeddingInputs(TypedDict): + type: Literal["image_embeds"] + data: torch.Tensor + """Shape: `(batch_size * num_images, 256, hidden_size)` + + `hidden_size` must match the hidden size of the language model backbone + and is stored in the visual config of the model if we have one. + """ + + +QwenImageInputs = Union[QwenImagePixelInputs, QwenImageEmbeddingInputs] + + +class VisualAttention(nn.Module): + """self-attention layer class. + Self-attention layer takes input with size [s, b, h] + and returns output of the same size. + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + bias: bool = True, + kdim: Optional[int] = None, + vdim: Optional[int] = None, + ): + super().__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self._qkv_same_embed_dim = self.kdim == embed_dim \ + and self.vdim == embed_dim + + self.num_heads = num_heads + + # Per attention head and per partition values. + assert embed_dim % num_heads == 0 + self.hidden_size_per_attention_head = embed_dim // num_heads + self.num_attention_heads_per_partition = num_heads + self.hidden_size_per_partition = embed_dim + + # Strided linear layer. 
+ assert self._qkv_same_embed_dim, \ + 'Visual Attention implementation only supports self-attention' + self.in_proj = ReplicatedLinear(embed_dim, 3 * embed_dim) + self.out_proj = ReplicatedLinear(embed_dim, embed_dim) + self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) + + def forward( + self, + x: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # query/key/value: [sq, b, h] + sq, b, _ = x.size() + mixed_x_layer, _ = self.in_proj(x) + + # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn] + new_tensor_shape = mixed_x_layer.size()[:-1] + \ + (self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head) + mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + + # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] + query_layer, key_layer, value_layer = mixed_x_layer.split( + self.hidden_size_per_attention_head, dim=-1) + + # [sq, b, np, hn] -> [sq, b * np, hn] + query_layer = query_layer.view( + sq, b * self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head).transpose(0, 1) + # [sk, b, np, hn] -> [sk, b * np, hn] + key_layer = key_layer.view( + sq, b * self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head).transpose(0, 1) + + q_scaled = query_layer / self.norm_factor + if attn_mask is not None: + attention_probs = torch.baddbmm(attn_mask, q_scaled, + key_layer.transpose(-2, -1)) + else: + attention_probs = torch.bmm(q_scaled, key_layer.transpose(-2, -1)) + attention_probs = attention_probs.softmax(dim=-1) + + value_layer = value_layer.view( + sq, b * self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head).transpose(0, 1) + + # matmul: [b * np, sq, hn] + context_layer = torch.bmm(attention_probs, value_layer) + + # change view [b, np, sq, hn] + context_layer = context_layer.view( + b, self.num_attention_heads_per_partition, sq, + self.hidden_size_per_attention_head) + + # [b, np, sq, hn] --> [sq, b, np, hn] + context_layer = 
context_layer.permute(2, 0, 1, 3).contiguous() + + # [sq, b, np, hn] --> [sq, b, hp] + new_context_layer_shape = context_layer.size()[:-2] + \ + (self.hidden_size_per_partition,) + context_layer = context_layer.view(*new_context_layer_shape) + + output, _ = self.out_proj(context_layer) + + return output + + +class QwenVLMLP(nn.Module): + """MLP for the visual component of the Qwen model.""" + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.c_fc = ColumnParallelLinear(hidden_size, + intermediate_size, + bias=True, + quant_config=quant_config) + self.act_fn = get_act_fn("gelu") + self.c_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=True, + quant_config=quant_config, + ) + + def forward(self, x): + x, _ = self.c_fc(x) + x = self.act_fn(x) + x, _ = self.c_proj(x) + return x + + +class VisualAttentionBlock(nn.Module): + + def __init__( + self, + d_model: int, + n_head: int, + mlp_ratio: float = 4.0, + norm_layer: Callable[[int], nn.Module] = nn.LayerNorm, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + + self.ln_1 = norm_layer(d_model) + self.ln_2 = norm_layer(d_model) + mlp_width = int(d_model * mlp_ratio) + self.attn = VisualAttention(d_model, n_head) + self.mlp = QwenVLMLP( + hidden_size=d_model, + intermediate_size=mlp_width, + quant_config=quant_config, + ) + + def attention( + self, + x: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + attn_mask = attn_mask.to(x.dtype) if attn_mask is not None else None + return self.attn(x, attn_mask=attn_mask) + + def forward( + self, + x: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + x = x + self.attention(self.ln_1(x), attn_mask=attn_mask) + x = x + self.mlp(self.ln_2(x)) + return x + + +class TransformerBlock(nn.Module): + + def __init__( + self, + width: int, + layers: int, + heads: int, + 
mlp_ratio: float = 4.0, + norm_layer: Callable[[int], nn.Module] = nn.LayerNorm, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.width = width + self.layers = layers + + self.resblocks = nn.ModuleList([ + VisualAttentionBlock(width, + heads, + mlp_ratio, + norm_layer=norm_layer, + quant_config=quant_config) + for _ in range(layers) + ]) + + def get_cast_dtype(self) -> torch.dtype: + return self.resblocks[0].mlp.c_fc.weight.dtype + + def get_cast_device(self) -> torch.device: + return self.resblocks[0].mlp.c_fc.weight.device + + def forward(self, + x: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + for r in self.resblocks: + x = r(x, attn_mask=attn_mask) + return x + + +class VisionTransformer(nn.Module): + + def __init__(self, + image_size: int, + patch_size: int, + width: int, + layers: int, + heads: int, + mlp_ratio: float, + n_queries: int = 256, + output_dim: int = 512, + image_start_id: int = 151857, + quant_config: Optional[QuantizationConfig] = None, + **kwargs): + super().__init__() + image_height, image_width = self.image_size = (image_size, image_size) + patch_height, patch_width = self.patch_size = (patch_size, patch_size) + self.grid_size = (image_height // patch_height, + image_width // patch_width) + self.output_dim = output_dim + self.conv1 = nn.Conv2d(in_channels=3, + out_channels=width, + kernel_size=patch_size, + stride=patch_size, + bias=False) + + # class embeddings and positional embeddings + scale = width**-0.5 + self.positional_embedding = nn.Parameter(scale * + torch.randn(256, width)) + + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.ln_pre = norm_layer(width) + self.transformer = TransformerBlock(width, + layers, + heads, + mlp_ratio, + norm_layer=norm_layer, + quant_config=quant_config) + + self.attn_pool = Resampler2( + grid_size=int(math.sqrt(n_queries)), + embed_dim=output_dim, + num_heads=output_dim // 128, + kv_dim=width, + norm_layer=norm_layer, + 
adaptive=False, + do_post_projection=False, + ).to( + device=self.positional_embedding.device, + dtype=self.positional_embedding.dtype, + ) + + self.ln_post = norm_layer(output_dim) + self.proj = nn.Parameter( + (output_dim**-0.5) * torch.randn(output_dim, output_dim)) + + self.image_start_id = image_start_id + self.image_end_id = image_start_id + 1 + self.image_pad_id = image_start_id + 2 + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.to( + dtype=self.transformer.get_cast_dtype(), + device=self.transformer.get_cast_device(), + ) + + # to patches + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], + -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + + x = x + get_abs_pos(self.positional_embedding, int(math.sqrt( + x.size(1)))) + + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + x = self.attn_pool(x) + x = self.ln_post(x) + x = x @ self.proj + + return x + + +class QwenVLModel(QWenModel): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + + self.visual = VisionTransformer(**config.visual, + quant_config=quant_config) + + +@lru_cache(maxsize=1) +def _get_tokenizer_without_image_pad( + tokenizer: PreTrainedTokenizer) -> PreTrainedTokenizer: + """ + The logic of adding image pad tokens should only be applied in + :class:`QwenVLProcessor`, so they are patched out here. 
+ + The definition of the wrapped tokenizer can be found here: + https://huggingface.co/Qwen/Qwen-VL/blob/main/tokenization_qwen.py + """ + new_tokenizer = copy.deepcopy(tokenizer) + + class TokenizerWithoutImagePad(tokenizer.__class__): # type: ignore + + def tokenize( + self, + text: str, + allowed_special: Union[AbstractSet[str], str] = "all", + disallowed_special: Union[Collection[str], str] = (), + **kwargs, + ) -> list[Union[bytes, str]]: + text = unicodedata.normalize("NFC", text) + + return [ + self.decoder[t] for t in self.tokenizer.encode( + text, + allowed_special=allowed_special, + disallowed_special=disallowed_special, + ) + ] + + def _decode( + self, + token_ids: Union[int, List[int]], + skip_special_tokens: bool = False, + errors: Optional[str] = None, + **kwargs, + ) -> str: + if isinstance(token_ids, int): + token_ids = [token_ids] + + return self.tokenizer.decode( + token_ids, + errors=errors or self.errors, + ) + + TokenizerWithoutImagePad.__name__ = \ + f"{tokenizer.__class__.__name__}WithoutImagePad" + + new_tokenizer.__class__ = TokenizerWithoutImagePad + return new_tokenizer + + +class QwenVLProcessor: + """ + This model doesn't define its own HF processor, + so we implement our own one here. 
+ + We call the wrapped tokenizer to automatically insert image pad tokens: + https://huggingface.co/Qwen/Qwen-VL/blob/main/tokenization_qwen.py#L245 + + The image processor is defined here: + https://huggingface.co/Qwen/Qwen-VL/blob/main/visual.py#L354 + """ + + def __init__( + self, + config: PretrainedConfig, + tokenizer: PreTrainedTokenizer, + ) -> None: + super().__init__() + + self.config = config + self.tokenizer = tokenizer + + vision_config = config.visual + image_size = vision_config["image_size"] + + self.image_transform = transforms.Compose([ + transforms.Resize( + (image_size, image_size), + interpolation=InterpolationMode.BICUBIC, + ), + transforms.ToTensor(), + transforms.Normalize( + mean=(0.48145466, 0.4578275, 0.40821073), + std=(0.26862954, 0.26130258, 0.27577711), + ), + ]) + + @property + def image_start_tag(self) -> str: + return self.tokenizer.image_start_tag # type: ignore + + @property + def image_end_tag(self) -> str: + return self.tokenizer.image_end_tag # type: ignore + + @property + def image_pad_tag(self) -> str: + return self.tokenizer.image_pad_tag # type: ignore + + def __call__( + self, + text: Optional[Union[TextInput, list[TextInput]]] = None, + images: Optional[Union[ImageInput, list[ImageInput]]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + ) -> BatchFeature: + if text is None: + text = [] + if not isinstance(text, list): + text = [text] + if images is None: + images = [] + if not isinstance(images, list): + images = [images] + + text_inputs = self.tokenizer(text) + + if len(images) == 0: + image_inputs = {} + else: + pixel_values = [self.image_transform(image) for image in images] + image_inputs = {"pixel_values": torch.stack(pixel_values)} + + return BatchFeature( + { + **text_inputs, + **image_inputs, + }, + tensor_type=return_tensors, + ) + + +class QwenVLProcessingInfo(BaseProcessingInfo): + + def get_tokenizer(self) -> PreTrainedTokenizer: + tokenizer = self.ctx.tokenizer + assert 
isinstance(tokenizer, PreTrainedTokenizer) + + return _get_tokenizer_without_image_pad(tokenizer) + + def get_hf_processor(self) -> QwenVLProcessor: + tokenizer = self.ctx.tokenizer + assert isinstance(tokenizer, PreTrainedTokenizer) + + return QwenVLProcessor(self.get_hf_config(), tokenizer) + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None} + + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + return {"image": self.get_num_image_tokens()} + + def get_num_image_tokens(self) -> int: + hf_config = self.get_hf_config() + vision_config = hf_config.visual + + image_size = vision_config["image_size"] + patch_size = vision_config["patch_size"] + grid_length = image_size // patch_size // 2 + return grid_length * grid_length + + +class QwenVLDummyInputsBuilder(BaseDummyInputsBuilder[QwenVLProcessingInfo]): + + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + hf_config = self.info.get_hf_config() + vision_config = hf_config.visual + + processor = self.info.get_hf_processor() + img_start = processor.image_start_tag + img_end = processor.image_end_tag + + target_width = target_height = vision_config["image_size"] + num_images = mm_counts.get("image", 0) + + mm_data = { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images) + } + + return ProcessorInputs( + prompt_text="".join(f"Picture {i}: {img_start}{img_end}\n" + for i in range(1, num_images + 1)), + mm_data=mm_data, + ) + + +class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]): + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + # Drops anything between / tags; encoding with the tokenizer + # will automatically add the image pads for the context. 
+ prompt, num_matched_images = re.subn( + r"(Picture \d*: ).*?(<\/img>\n)", + r"\1\2", + prompt, + ) + + image_data = mm_data.get("images") + if image_data is not None: + assert isinstance(image_data, list) + + num_images = len(image_data) + assert num_matched_images == num_images + + return super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + ) + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + pixel_values=MultiModalFieldConfig.batched("image"), + image_embeds=MultiModalFieldConfig.batched("image"), + ) + + def _get_prompt_replacements( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> list[PromptReplacement]: + tokenizer = self.info.get_tokenizer() + special_tokens: dict[str, + int] = tokenizer.special_tokens # type: ignore + + processor = self.info.get_hf_processor() + img_start_id = special_tokens[processor.image_start_tag] + img_end_id = special_tokens[processor.image_end_tag] + img_pad_id = special_tokens[processor.image_pad_tag] + + num_image_tokens = self.info.get_num_image_tokens() + image_tokens = [img_pad_id] * num_image_tokens + + return [ + PromptReplacement( + modality="image", + target=[img_start_id, img_end_id], + replacement=PromptReplacementDetails( + full=[img_start_id] + image_tokens + [img_end_id], + features=image_tokens, + ), + ) + ] + + +@MULTIMODAL_REGISTRY.register_processor(QwenVLMultiModalProcessor, + info=QwenVLProcessingInfo, + dummy_inputs=QwenVLDummyInputsBuilder) +class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA, + SupportsMultiModal): + packed_modules_mapping = { + "c_attn": ["c_attn"], + "gate_up_proj": [ + "w2", + "w1", + ], + } + # LoRA specific attributes + supported_lora_modules = [ + "c_attn", + "gate_up_proj", + "c_proj", + # visual module + "out_proj", + 
"in_proj", + "c_fc", + # resampler + "kv_proj", + ] + + embedding_modules = {} + embedding_padding_modules = [] + + def get_mm_mapping(self) -> MultiModelKeys: + """ + Get the module prefix in multimodal models + """ + return MultiModelKeys.from_string_field( + language_model="transformer.h", + connector="transformer.visual.attn_pool", + tower_model="transformer.visual.transformer") + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + transformer_type: type[QwenVLModel] = QwenVLModel, + ) -> None: + super().__init__( + vllm_config=vllm_config, + prefix=prefix, + transformer_type=transformer_type, + ) + + self.transformer: QwenVLModel + + def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: + h = w = self.config.visual["image_size"] + expected_dims = (3, h, w) + actual_dims = tuple(data.shape[1:]) + + if actual_dims != expected_dims: + expected_expr = ("batch_size", *map(str, expected_dims)) + raise ValueError( + f"The expected shape of pixel values is {expected_expr}. " + f"You supplied {tuple(data.shape)}.") + + return data + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[QwenImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_embeds = kwargs.pop("image_embeds", None) + + if pixel_values is not None: + if not isinstance(pixel_values, torch.Tensor): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") + + return QwenImagePixelInputs( + type="pixel_values", + data=self._validate_pixel_values( + flatten_bn(pixel_values, concat=True)), + ) + + if image_embeds is not None: + if not isinstance(image_embeds, torch.Tensor): + raise ValueError("Incorrect type of image embeddings. 
" + f"Got type: {type(image_embeds)}") + + return QwenImageEmbeddingInputs( + type="image_embeds", + data=flatten_bn(image_embeds), + ) + + return None + + def _process_image_input(self, + image_input: QwenImageInputs) -> torch.Tensor: + if image_input["type"] == "image_embeds": + return image_input["data"] + + return self.transformer.visual(image_input["data"]) + + def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + + vision_embeddings = self._process_image_input(image_input) + return vision_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[NestedTensors] = None, + ) -> torch.Tensor: + inputs_embeds = self.transformer.get_input_embeddings(input_ids) + + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + self.transformer.visual.image_pad_id) + + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. 
+ elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None + + hidden_states = self.transformer(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors, + inputs_embeds) + return hidden_states diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 198b6d134..08c4642b4 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -39,7 +39,7 @@ _TEXT_GENERATION_MODELS = { "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), "BambaForCausalLM": ("bamba", "BambaForCausalLM"), "BloomForCausalLM": ("bloom", "BloomForCausalLM"), - # ChatGLMModel supports multimodal + "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"), "CohereForCausalLM": ("commandr", "CohereForCausalLM"), "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"), "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"), @@ -90,7 +90,7 @@ _TEXT_GENERATION_MODELS = { "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"), "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"), "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"), - # QWenLMHeadModel supports multimodal + "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"), "RWForCausalLM": ("falcon", "FalconForCausalLM"), @@ -156,10 +156,9 @@ _MULTIMODAL_MODELS = { "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"), "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"), "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"), # noqa: E501 - "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"), - "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"), "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"), "FuyuForCausalLM": ("fuyu", 
"FuyuForCausalLM"), + "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"), "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"), "InternVLChatModel": ("internvl", "InternVLChatModel"), "Idefics3ForConditionalGeneration":("idefics3","Idefics3ForConditionalGeneration"), @@ -175,7 +174,7 @@ _MULTIMODAL_MODELS = { "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"), # noqa: E501 "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"), # noqa: E501 - "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), + "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"), # noqa: E501 "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501 "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"), # noqa: E501 "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"), # noqa: E501 -- GitLab From 37dfa6003782d872a375cc41429ca16fe241f83f Mon Sep 17 00:00:00 2001 From: Vaibhav Jain Date: Thu, 13 Feb 2025 20:22:22 +0530 Subject: [PATCH 122/253] [Bugfix] Missing Content Type returns 500 Internal Server Error (#13193) --- tests/entrypoints/openai/test_basic.py | 16 ++++++++++ vllm/entrypoints/openai/api_server.py | 42 +++++++++++++++++--------- 2 files changed, 43 insertions(+), 15 deletions(-) diff --git a/tests/entrypoints/openai/test_basic.py b/tests/entrypoints/openai/test_basic.py index 0d44a7611..a970981b7 100644 --- a/tests/entrypoints/openai/test_basic.py +++ b/tests/entrypoints/openai/test_basic.py @@ -156,3 +156,19 @@ async def test_request_cancellation(server: RemoteOpenAIServer): max_tokens=10) assert len(response.choices) == 1 + + +@pytest.mark.asyncio +async def test_request_wrong_content_type(server: RemoteOpenAIServer): + + chat_input = [{"role": "user", "content": "Write a long story"}] + client = server.get_async_client() + + with 
pytest.raises(openai.APIStatusError): + await client.chat.completions.create( + messages=chat_input, + model=MODEL_NAME, + max_tokens=10000, + extra_headers={ + "Content-Type": "application/x-www-form-urlencoded" + }) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 588a7781c..b50a72f3a 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -19,7 +19,7 @@ from http import HTTPStatus from typing import AsyncIterator, Dict, Optional, Set, Tuple, Union import uvloop -from fastapi import APIRouter, FastAPI, HTTPException, Request +from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, Response, StreamingResponse @@ -252,6 +252,15 @@ async def build_async_engine_client_from_engine_args( multiprocess.mark_process_dead(engine_process.pid) +async def validate_json_request(raw_request: Request): + content_type = raw_request.headers.get("content-type", "").lower() + if content_type != "application/json": + raise HTTPException( + status_code=HTTPStatus.UNSUPPORTED_MEDIA_TYPE, + detail="Unsupported Media Type: Only 'application/json' is allowed" + ) + + router = APIRouter() @@ -335,7 +344,7 @@ async def ping(raw_request: Request) -> Response: return await health(raw_request) -@router.post("/tokenize") +@router.post("/tokenize", dependencies=[Depends(validate_json_request)]) @with_cancellation async def tokenize(request: TokenizeRequest, raw_request: Request): handler = tokenization(raw_request) @@ -350,7 +359,7 @@ async def tokenize(request: TokenizeRequest, raw_request: Request): assert_never(generator) -@router.post("/detokenize") +@router.post("/detokenize", dependencies=[Depends(validate_json_request)]) @with_cancellation async def detokenize(request: DetokenizeRequest, raw_request: Request): handler = 
tokenization(raw_request) @@ -379,7 +388,8 @@ async def show_version(): return JSONResponse(content=ver) -@router.post("/v1/chat/completions") +@router.post("/v1/chat/completions", + dependencies=[Depends(validate_json_request)]) @with_cancellation async def create_chat_completion(request: ChatCompletionRequest, raw_request: Request): @@ -400,7 +410,7 @@ async def create_chat_completion(request: ChatCompletionRequest, return StreamingResponse(content=generator, media_type="text/event-stream") -@router.post("/v1/completions") +@router.post("/v1/completions", dependencies=[Depends(validate_json_request)]) @with_cancellation async def create_completion(request: CompletionRequest, raw_request: Request): handler = completion(raw_request) @@ -418,7 +428,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request): return StreamingResponse(content=generator, media_type="text/event-stream") -@router.post("/v1/embeddings") +@router.post("/v1/embeddings", dependencies=[Depends(validate_json_request)]) @with_cancellation async def create_embedding(request: EmbeddingRequest, raw_request: Request): handler = embedding(raw_request) @@ -464,7 +474,7 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request): assert_never(generator) -@router.post("/pooling") +@router.post("/pooling", dependencies=[Depends(validate_json_request)]) @with_cancellation async def create_pooling(request: PoolingRequest, raw_request: Request): handler = pooling(raw_request) @@ -482,7 +492,7 @@ async def create_pooling(request: PoolingRequest, raw_request: Request): assert_never(generator) -@router.post("/score") +@router.post("/score", dependencies=[Depends(validate_json_request)]) @with_cancellation async def create_score(request: ScoreRequest, raw_request: Request): handler = score(raw_request) @@ -500,7 +510,7 @@ async def create_score(request: ScoreRequest, raw_request: Request): assert_never(generator) -@router.post("/v1/score") +@router.post("/v1/score", 
dependencies=[Depends(validate_json_request)]) @with_cancellation async def create_score_v1(request: ScoreRequest, raw_request: Request): logger.warning( @@ -510,7 +520,7 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request): return await create_score(request, raw_request) -@router.post("/rerank") +@router.post("/rerank", dependencies=[Depends(validate_json_request)]) @with_cancellation async def do_rerank(request: RerankRequest, raw_request: Request): handler = rerank(raw_request) @@ -527,7 +537,7 @@ async def do_rerank(request: RerankRequest, raw_request: Request): assert_never(generator) -@router.post("/v1/rerank") +@router.post("/v1/rerank", dependencies=[Depends(validate_json_request)]) @with_cancellation async def do_rerank_v1(request: RerankRequest, raw_request: Request): logger.warning_once( @@ -538,7 +548,7 @@ async def do_rerank_v1(request: RerankRequest, raw_request: Request): return await do_rerank(request, raw_request) -@router.post("/v2/rerank") +@router.post("/v2/rerank", dependencies=[Depends(validate_json_request)]) @with_cancellation async def do_rerank_v2(request: RerankRequest, raw_request: Request): return await do_rerank(request, raw_request) @@ -582,7 +592,7 @@ if envs.VLLM_SERVER_DEV_MODE: return Response(status_code=200) -@router.post("/invocations") +@router.post("/invocations", dependencies=[Depends(validate_json_request)]) async def invocations(raw_request: Request): """ For SageMaker, routes requests to other handlers based on model `task`. @@ -632,7 +642,8 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING: "Lora dynamic loading & unloading is enabled in the API server. 
" "This should ONLY be used for local development!") - @router.post("/v1/load_lora_adapter") + @router.post("/v1/load_lora_adapter", + dependencies=[Depends(validate_json_request)]) async def load_lora_adapter(request: LoadLoraAdapterRequest, raw_request: Request): handler = models(raw_request) @@ -643,7 +654,8 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING: return Response(status_code=200, content=response) - @router.post("/v1/unload_lora_adapter") + @router.post("/v1/unload_lora_adapter", + dependencies=[Depends(validate_json_request)]) async def unload_lora_adapter(request: UnloadLoraAdapterRequest, raw_request: Request): handler = models(raw_request) -- GitLab From d84cef76eb9e16190cfdd97ae24511c8c819f179 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicol=C3=B2=20Lucchesi?= Date: Thu, 13 Feb 2025 16:23:45 +0100 Subject: [PATCH 123/253] [Frontend] Add `/v1/audio/transcriptions` OpenAI API endpoint (#12909) --- .buildkite/test-pipeline.yaml | 12 +- .../serving/openai_compatible_server.md | 13 + .../openai_transcription_client.py | 23 ++ requirements-common.txt | 7 +- requirements-test.in | 1 + requirements-test.txt | 5 + .../openai/correctness/__init__.py | 0 .../test_lmeval.py} | 2 +- .../test_transcription_api_correctness.py | 166 ++++++++++ .../openai/test_transcription_validation.py | 122 +++++++ tests/test_config.py | 1 + vllm/assets/audio.py | 5 + vllm/config.py | 11 +- vllm/entrypoints/openai/api_server.py | 43 ++- vllm/entrypoints/openai/protocol.py | 163 +++++++++- vllm/entrypoints/openai/serving_engine.py | 6 +- .../openai/serving_transcription.py | 305 ++++++++++++++++++ vllm/model_executor/models/interfaces.py | 27 ++ vllm/model_executor/models/registry.py | 12 +- vllm/model_executor/models/whisper.py | 5 +- 20 files changed, 910 insertions(+), 19 deletions(-) create mode 100644 examples/online_serving/openai_transcription_client.py create mode 100644 tests/entrypoints/openai/correctness/__init__.py rename tests/entrypoints/openai/{test_accuracy.py => 
correctness/test_lmeval.py} (98%) create mode 100644 tests/entrypoints/openai/correctness/test_transcription_api_correctness.py create mode 100644 tests/entrypoints/openai/test_transcription_validation.py create mode 100644 vllm/entrypoints/openai/serving_transcription.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index e26b1bf38..9991060a3 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -117,7 +117,7 @@ steps: - pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process - pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process - pytest -v -s entrypoints/llm/test_guided_generate.py # it needs a clean process - - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py + - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/correctness/ - pytest -v -s entrypoints/test_chat_utils.py - pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests @@ -205,7 +205,7 @@ steps: - VLLM_USE_V1=1 pytest -v -s v1/e2e # Integration test for streaming correctness (requires special branch). 
- pip install -U git+https://github.com/robertgshaw2-neuralmagic/lm-evaluation-harness.git@streaming-api - - pytest -v -s entrypoints/openai/test_accuracy.py::test_lm_eval_accuracy_v1_engine + - pytest -v -s entrypoints/openai/correctness/test_lmeval.py::test_lm_eval_accuracy_v1_engine - label: Examples Test # 25min working_dir: "/vllm-workspace/examples" @@ -339,6 +339,14 @@ steps: - export VLLM_WORKER_MULTIPROC_METHOD=spawn - bash ./run-tests.sh -c configs/models-small.txt -t 1 +- label: OpenAI API correctness + source_file_dependencies: + - csrc/ + - vllm/entrypoints/openai/ + - vllm/model_executor/models/whisper.py + commands: # LMEval+Transcription WER check + - pytest -s entrypoints/openai/correctness/ + - label: Encoder Decoder tests # 5min source_file_dependencies: - vllm/ diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index 82ef54c16..64439475f 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -41,6 +41,8 @@ We currently support the following OpenAI APIs: - *Note: `parallel_tool_calls` and `user` parameters are ignored.* - [Embeddings API](#embeddings-api) (`/v1/embeddings`) - Only applicable to [embedding models](../models/pooling_models.md) (`--task embed`). +- [Transcriptions API](#transcriptions-api) (`/v1/audio/transcriptions`) + - Only applicable to Automatic Speech Recognition (ASR) models (OpenAI Whisper) (`--task generate`). In addition, we have the following custom APIs: @@ -296,6 +298,17 @@ For chat-like input (i.e. 
if `messages` is passed), these extra parameters are s :end-before: end-chat-embedding-extra-params ::: +(transcriptions-api)= + +### Transcriptions API + +Our Transcriptions API is compatible with [OpenAI's Transcriptions API](https://platform.openai.com/docs/api-reference/audio/createTranscription); +you can use the [official OpenAI Python client](https://github.com/openai/openai-python) to interact with it. + + + +Code example: + (tokenizer-api)= ### Tokenizer API diff --git a/examples/online_serving/openai_transcription_client.py b/examples/online_serving/openai_transcription_client.py new file mode 100644 index 000000000..bd3c02a8a --- /dev/null +++ b/examples/online_serving/openai_transcription_client.py @@ -0,0 +1,23 @@ +# SPDX-License-Identifier: Apache-2.0 +from openai import OpenAI + +from vllm.assets.audio import AudioAsset + +mary_had_lamb = AudioAsset('mary_had_lamb').get_local_path() +winning_call = AudioAsset('winning_call').get_local_path() + +# Modify OpenAI's API key and API base to use vLLM's API server. +openai_api_key = "EMPTY" +openai_api_base = "http://localhost:8000/v1" +client = OpenAI( + api_key=openai_api_key, + base_url=openai_api_base, +) +with open(str(mary_had_lamb), "rb") as f: + transcription = client.audio.transcriptions.create( + file=f, + model="openai/whisper-large-v3", + language="en", + response_format="text", + temperature=0.0) + print("transcription result:", transcription) diff --git a/requirements-common.txt b/requirements-common.txt index cfa020256..0b7253cc1 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -8,12 +8,11 @@ py-cpuinfo transformers >= 4.48.2 # Required for Bamba model and Transformers backend. tokenizers >= 0.19.1 # Required for Llama 3. protobuf # Required by LlamaTokenizer. 
-fastapi >= 0.107.0, < 0.113.0; python_version < '3.9' -fastapi >= 0.107.0, != 0.113.*, != 0.114.0; python_version >= '3.9' +fastapi[standard] >= 0.107.0, < 0.113.0; python_version < '3.9' +fastapi[standard] >= 0.107.0, != 0.113.*, != 0.114.0; python_version >= '3.9' aiohttp openai >= 1.52.0 # Ensure modern openai package (ensure types module present and max_completion_tokens field support) -uvicorn[standard] -pydantic >= 2.9 # Required for fastapi >= 0.113.0 +pydantic >= 2.9 prometheus_client >= 0.18.0 pillow # Required for image processing prometheus-fastapi-instrumentator >= 7.0.0 diff --git a/requirements-test.in b/requirements-test.in index 229d743ec..ecf874ecc 100644 --- a/requirements-test.in +++ b/requirements-test.in @@ -19,6 +19,7 @@ pqdm ray[adag]==2.40.0 sentence-transformers # required for embedding tests soundfile # required for audio tests +jiwer # required for audio tests timm # required for internvl test torch==2.5.1 torchaudio==2.5.1 diff --git a/requirements-test.txt b/requirements-test.txt index e032aac71..648a2626c 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -66,6 +66,7 @@ charset-normalizer==3.4.0 click==8.1.7 # via # black + # jiwer # nltk # ray colorama==0.4.6 @@ -187,6 +188,8 @@ jinja2==3.1.4 # via # datamodel-code-generator # torch +jiwer==3.0.5 + # via -r requirements-test.in jmespath==1.0.1 # via # boto3 @@ -470,6 +473,8 @@ pyyaml==6.0.2 # timm # transformers # vocos +rapidfuzz==3.12.1 + # via jiwer ray[adag]==2.40.0 # via -r requirements-test.in redis==5.2.0 diff --git a/tests/entrypoints/openai/correctness/__init__.py b/tests/entrypoints/openai/correctness/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/entrypoints/openai/test_accuracy.py b/tests/entrypoints/openai/correctness/test_lmeval.py similarity index 98% rename from tests/entrypoints/openai/test_accuracy.py rename to tests/entrypoints/openai/correctness/test_lmeval.py index df25780cd..ebb2ea4d9 100644 --- 
a/tests/entrypoints/openai/test_accuracy.py +++ b/tests/entrypoints/openai/correctness/test_lmeval.py @@ -13,7 +13,7 @@ import pytest from vllm.platforms import current_platform -from ...utils import RemoteOpenAIServer +from ....utils import RemoteOpenAIServer MODEL_NAME = "Qwen/Qwen2-1.5B-Instruct" NUM_CONCURRENT = 500 diff --git a/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py b/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py new file mode 100644 index 000000000..19d4735b9 --- /dev/null +++ b/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py @@ -0,0 +1,166 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Evaluate Transcription API correctness by computing Word Error Rate (WER) +on a given ASR dataset. When provided, it will also compare the WER against +a baseline. +This simulates real work usage of the API and makes sure that the frontend and +AsyncLLMEngine are working correctly. +""" +import asyncio +import io +import time +from statistics import mean, median +from typing import List + +import librosa +import pytest +import soundfile +import torch +from datasets import load_dataset +from evaluate import load +from transformers import AutoTokenizer + +from ....utils import RemoteOpenAIServer + + +def to_bytes(y, sr): + buffer = io.BytesIO() + soundfile.write(buffer, y, sr, format="WAV") + buffer.seek(0) + return buffer + + +async def transcribe_audio(client, tokenizer, y, sr): + # Send loaded audio directly instead of loading from disk, + # dont account for that time though + with to_bytes(y, sr) as f: + start_time = time.perf_counter() + transcription = await client.audio.transcriptions.create( + file=f, + model=tokenizer.name_or_path, + language="en", + temperature=0.0, + ) + end_time = time.perf_counter() + # NOTE there's no streaming in transcriptions, can't measure ttft + latency = end_time - start_time + num_output_tokens = len( + tokenizer(transcription.text, 
add_special_tokens=False).input_ids) + return latency, num_output_tokens, transcription.text + + +async def bound_transcribe(model_name, sem, client, audio, reference): + tokenizer = AutoTokenizer.from_pretrained(model_name) + # Use semaphore to limit concurrent requests. + async with sem: + result = await transcribe_audio(client, tokenizer, *audio) + # Normalize *english* output/reference for evaluation. + out = tokenizer.normalize(result[2]) + ref = tokenizer.normalize(reference) + return result[:2] + (out, ref) + + +async def process_dataset(model, client, data, concurrent_request): + sem = asyncio.Semaphore(concurrent_request) + + # Warmup call as the first `librosa.load` server-side is quite slow. + audio, sr = data[0]["audio"]["array"], data[0]["audio"]["sampling_rate"] + _ = await bound_transcribe(model, sem, client, (audio, sr), "") + + tasks: List[asyncio.Task] = [] + for sample in data: + audio, sr = sample["audio"]["array"], sample["audio"]["sampling_rate"] + task = asyncio.create_task( + bound_transcribe(model, sem, client, (audio, sr), sample["text"])) + tasks.append(task) + return await asyncio.gather(*tasks) + + +def print_performance_metrics(results, total_time): + latencies = [res[0] for res in results] + total_tokens = sum([res[1] for res in results]) + + total = len(results) + print(f"Total Requests: {total}") + print(f"Successful Requests: {len(latencies)}") + print(f"Average Latency: {mean(latencies):.4f} seconds") + print(f"Median Latency: {median(latencies):.4f} seconds") + perc = sorted(latencies)[int(len(latencies) * 0.95) - 1] + print(f"95th Percentile Latency: {perc:.4f} seconds") + # Throughput + req_throughput = len(latencies) / total_time + print(f"Estimated req_Throughput: {req_throughput:.2f} requests/s") + throughput = total_tokens / total_time + print(f"Estimated Throughput: {throughput:.2f} tok/s") + + +def add_duration(sample): + y, sr = sample['audio']["array"], sample['audio']["sampling_rate"] + sample['duration_ms'] = 
librosa.get_duration(y=y, sr=sr) * 1000 + return sample + + +def load_hf_dataset(dataset_repo: str, split='validation', **hf_kwargs): + ## Load and filter the dataset + dataset = load_dataset(dataset_repo, split=split, **hf_kwargs) + if 'duration_ms' not in dataset[0]: + # compute duration to filter + dataset = dataset.map(add_duration) + + # Whisper max supported duration + dataset = dataset.filter(lambda example: example['duration_ms'] < 30000) + return dataset + + +def run_evaluation(model: str, + client, + dataset, + max_concurrent_reqs: int, + n_examples: int = -1, + print_metrics: bool = True): + if n_examples > 0: + dataset = dataset.select(range(n_examples)) + start = time.perf_counter() + results = asyncio.run( + process_dataset(model, client, dataset, max_concurrent_reqs)) + end = time.perf_counter() + total_time = end - start + print(f"Total Test Time: {total_time:.4f} seconds") + if print_metrics: + print_performance_metrics(results, total_time) + # Compute WER + predictions = [res[2] for res in results] + references = [res[3] for res in results] + wer = load("wer") + wer_score = 100 * wer.compute(references=references, + predictions=predictions) + print("WER:", wer_score) + return wer_score + + +# alternatives "openai/whisper-large-v2", "openai/whisper-large-v3-turbo".. +@pytest.mark.parametrize("model_name", ["openai/whisper-large-v3"]) +# Original dataset is 20GB+ in size, hence we use a pre-filtered slice. +@pytest.mark.parametrize( + "dataset_repo", ["D4nt3/esb-datasets-earnings22-validation-tiny-filtered"]) +# NOTE: Expected WER measured with equivalent hf.transformers args: +# whisper-large-v3 + esb-datasets-earnings22-validation-tiny-filtered. 
+@pytest.mark.parametrize("expected_wer", [12.744980]) +def test_wer_correctness(model_name, + dataset_repo, + expected_wer, + n_examples=-1, + max_concurrent_request=None): + with RemoteOpenAIServer(model_name, ['--enforce-eager']) as remote_server: + dataset = load_hf_dataset(dataset_repo) + + if not max_concurrent_request: + # No max concurrency + max_concurrent_request = n_examples if n_examples > 0\ + else len(dataset) + + client = remote_server.get_async_client() + wer = run_evaluation(model_name, client, dataset, + max_concurrent_request, n_examples) + if expected_wer: + torch.testing.assert_close(wer, expected_wer, atol=1e-1, rtol=1e-2) diff --git a/tests/entrypoints/openai/test_transcription_validation.py b/tests/entrypoints/openai/test_transcription_validation.py new file mode 100644 index 000000000..5d4a5de4b --- /dev/null +++ b/tests/entrypoints/openai/test_transcription_validation.py @@ -0,0 +1,122 @@ +# SPDX-License-Identifier: Apache-2.0 + +# imports for guided decoding tests +import io +import json + +import librosa +import numpy as np +import openai +import pytest +import soundfile as sf + +from vllm.assets.audio import AudioAsset + +from ...utils import RemoteOpenAIServer + + +@pytest.fixture +def mary_had_lamb(): + path = AudioAsset('mary_had_lamb').get_local_path() + with open(str(path), "rb") as f: + yield f + + +@pytest.fixture +def winning_call(): + path = AudioAsset('winning_call').get_local_path() + with open(str(path), "rb") as f: + yield f + + +@pytest.mark.asyncio +async def test_basic_audio(mary_had_lamb): + model_name = "openai/whisper-large-v3-turbo" + server_args = ["--enforce-eager"] + # Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb. 
+ prompt = "THE FIRST WORDS I SPOKE" + with RemoteOpenAIServer(model_name, server_args) as remote_server: + client = remote_server.get_async_client() + transcription = await client.audio.transcriptions.create( + model=model_name, + file=mary_had_lamb, + language="en", + response_format="text", + temperature=0.0) + out = json.loads(transcription)['text'] + assert "Mary had a little lamb," in out + # This should "force" whisper to continue prompt in all caps + transcription_wprompt = await client.audio.transcriptions.create( + model=model_name, + file=mary_had_lamb, + language="en", + response_format="text", + prompt=prompt, + temperature=0.0) + out_capital = json.loads(transcription_wprompt)['text'] + assert prompt not in out_capital + + +@pytest.mark.asyncio +async def test_bad_requests(mary_had_lamb): + model_name = "openai/whisper-small" + server_args = ["--enforce-eager"] + with RemoteOpenAIServer(model_name, server_args) as remote_server: + client = remote_server.get_async_client() + + # invalid language + with pytest.raises(openai.BadRequestError): + await client.audio.transcriptions.create(model=model_name, + file=mary_had_lamb, + language="hh", + temperature=0.0) + + # Expect audio too long: repeat the timeseries + mary_had_lamb.seek(0) + audio, sr = librosa.load(mary_had_lamb) + repeated_audio = np.tile(audio, 10) + # Repeated audio to buffer + buffer = io.BytesIO() + sf.write(buffer, repeated_audio, sr, format='WAV') + buffer.seek(0) + with pytest.raises(openai.BadRequestError): + await client.audio.transcriptions.create(model=model_name, + file=buffer, + language="en", + temperature=0.0) + + +@pytest.mark.asyncio +async def test_non_asr_model(winning_call): + # text to text model + model_name = "JackFram/llama-68m" + server_args = ["--enforce-eager"] + with RemoteOpenAIServer(model_name, server_args) as remote_server: + client = remote_server.get_async_client() + res = await client.audio.transcriptions.create(model=model_name, + file=winning_call, + 
language="en", + temperature=0.0) + assert res.code == 400 and not res.text + assert res.message == "The model does not support Transcriptions API" + + +@pytest.mark.asyncio +async def test_completion_endpoints(): + # text to text model + model_name = "openai/whisper-small" + server_args = ["--enforce-eager"] + with RemoteOpenAIServer(model_name, server_args) as remote_server: + client = remote_server.get_async_client() + res = await client.chat.completions.create( + model=model_name, + messages=[{ + "role": "system", + "content": "You are a helpful assistant." + }]) + assert res.code == 400 + assert res.message == "The model does not support Chat Completions API" + + res = await client.completions.create(model=model_name, prompt="Hello") + assert res.code == 400 + assert res.message == "The model does not support Completions API" diff --git a/tests/test_config.py b/tests/test_config.py index 2dfae218b..3fb83b4c0 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -17,6 +17,7 @@ from vllm.platforms import current_platform ("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify"), ("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "score"), ("Qwen/Qwen2.5-Math-RM-72B", "pooling", "reward"), + ("openai/whisper-small", "transcription", "transcription"), ], ) def test_auto_task(model_id, expected_runner_type, expected_task): diff --git a/vllm/assets/audio.py b/vllm/assets/audio.py index d9e51082e..0203dc092 100644 --- a/vllm/assets/audio.py +++ b/vllm/assets/audio.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass +from pathlib import Path from typing import Literal from urllib.parse import urljoin @@ -28,6 +29,10 @@ class AudioAsset: s3_prefix=ASSET_DIR) return librosa.load(audio_path, sr=None) + def get_local_path(self) -> Path: + return get_vllm_public_assets(filename=f"{self.name}.ogg", + s3_prefix=ASSET_DIR) + @property def url(self) -> str: return urljoin(VLLM_S3_BUCKET_URL, f"{ASSET_DIR}/{self.name}.ogg") diff 
--git a/vllm/config.py b/vllm/config.py index 1740871e7..10004b8f6 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -54,17 +54,18 @@ _POOLING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768 _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120 TaskOption = Literal["auto", "generate", "embedding", "embed", "classify", - "score", "reward"] + "score", "reward", "transcription"] _ResolvedTask = Literal["generate", "embed", "classify", "score", "reward", - "draft"] + "draft", "transcription"] -RunnerType = Literal["generate", "pooling", "draft"] +RunnerType = Literal["generate", "pooling", "draft", "transcription"] _RUNNER_TASKS: Dict[RunnerType, List[_ResolvedTask]] = { "generate": ["generate"], "pooling": ["embed", "classify", "score", "reward"], "draft": ["draft"], + "transcription": ["transcription"], } _TASK_RUNNER: Dict[_ResolvedTask, RunnerType] = { @@ -484,6 +485,8 @@ class ModelConfig: return "embed" if ModelRegistry.is_cross_encoder_model(architectures): return "score" + if ModelRegistry.is_transcription_model(architectures): + return "transcription" suffix_to_preferred_task: List[Tuple[str, _ResolvedTask]] = [ # Other models follow this pattern @@ -516,6 +519,8 @@ class ModelConfig: runner_support: Dict[RunnerType, bool] = { # NOTE: Listed from highest to lowest priority, # in case the model supports multiple of them + "transcription": + ModelRegistry.is_transcription_model(architectures), "generate": ModelRegistry.is_text_generation_model(architectures), "pooling": ModelRegistry.is_pooling_model(architectures), } diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index b50a72f3a..ad391d673 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -16,10 +16,10 @@ from argparse import Namespace from contextlib import asynccontextmanager from functools import partial from http import HTTPStatus -from typing import AsyncIterator, Dict, Optional, Set, Tuple, Union +from typing import Annotated, 
AsyncIterator, Dict, Optional, Set, Tuple, Union import uvloop -from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request +from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, Response, StreamingResponse @@ -61,6 +61,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, ScoreRequest, ScoreResponse, TokenizeRequest, TokenizeResponse, + TranscriptionRequest, + TranscriptionResponse, UnloadLoraAdapterRequest) from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager # yapf: enable @@ -75,6 +77,8 @@ from vllm.entrypoints.openai.serving_rerank import JinaAIServingRerank from vllm.entrypoints.openai.serving_score import OpenAIServingScores from vllm.entrypoints.openai.serving_tokenization import ( OpenAIServingTokenization) +from vllm.entrypoints.openai.serving_transcription import ( + OpenAIServingTranscription) from vllm.entrypoints.openai.tool_parsers import ToolParserManager from vllm.entrypoints.utils import with_cancellation from vllm.logger import init_logger @@ -327,6 +331,10 @@ def tokenization(request: Request) -> OpenAIServingTokenization: return request.app.state.openai_serving_tokenization +def transcription(request: Request) -> OpenAIServingTranscription: + return request.app.state.openai_serving_transcription + + def engine_client(request: Request) -> EngineClient: return request.app.state.engine_client @@ -520,6 +528,31 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request): return await create_score(request, raw_request) +@router.post("/v1/audio/transcriptions") +@with_cancellation +async def create_transcriptions(request: Annotated[TranscriptionRequest, + Form()], + raw_request: Request): + + handler = transcription(raw_request) + if handler is None: + return base(raw_request).create_error_response( + 
message="The model does not support Transcriptions API") + + audio_data = await request.file.read() + generator = await handler.create_transcription(audio_data, request, + raw_request) + + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + + elif isinstance(generator, TranscriptionResponse): + return JSONResponse(content=generator.model_dump()) + + return StreamingResponse(content=generator, media_type="text/event-stream") + + @router.post("/rerank", dependencies=[Depends(validate_json_request)]) @with_cancellation async def do_rerank(request: RerankRequest, raw_request: Request): @@ -832,6 +865,12 @@ async def init_app_state( chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, ) + state.openai_serving_transcription = OpenAIServingTranscription( + engine_client, + model_config, + state.openai_serving_models, + request_logger=request_logger, + ) if model_config.runner_type == "transcription" else None state.task = model_config.task diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 83b841826..2bcfdc235 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -8,9 +8,10 @@ from argparse import Namespace from typing import Any, ClassVar, Dict, List, Literal, Optional, Set, Union import torch +from fastapi import UploadFile from pydantic import (BaseModel, ConfigDict, Field, TypeAdapter, ValidationInfo, field_validator, model_validator) -from typing_extensions import Annotated +from typing_extensions import Annotated, TypeAlias from vllm.entrypoints.chat_utils import ChatCompletionMessageParam from vllm.logger import init_logger @@ -1426,3 +1427,163 @@ class LoadLoraAdapterRequest(BaseModel): class UnloadLoraAdapterRequest(BaseModel): lora_name: str lora_int_id: Optional[int] = Field(default=None) + + +## Protocols for Audio +AudioResponseFormat: TypeAlias = 
Literal["json", "text", "srt", "verbose_json", + "vtt"] + + +class TranscriptionRequest(OpenAIBaseModel): + # Ordered by official OpenAI API documentation + #https://platform.openai.com/docs/api-reference/audio/createTranscription + + file: UploadFile + """ + The audio file object (not file name) to transcribe, in one of these + formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm. + """ + + model: str + """ID of the model to use. + """ + + language: Optional[str] = None + """The language of the input audio. + + Supplying the input language in + [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format + will improve accuracy and latency. + """ + + prompt: str = Field(default="") + """An optional text to guide the model's style or continue a previous audio + segment. + + The [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting) + should match the audio language. + """ + + response_format: AudioResponseFormat = Field(default="json") + """ + The format of the output, in one of these options: `json`, `text`, `srt`, + `verbose_json`, or `vtt`. + """ + + ## TODO (varun) : Support if set to 0, certain thresholds are met !! + temperature: float = Field(default=0.0) + """The sampling temperature, between 0 and 1. + + Higher values like 0.8 will make the output more random, while lower values + like 0.2 will make it more focused / deterministic. If set to 0, the model + will use [log probability](https://en.wikipedia.org/wiki/Log_probability) + to automatically increase the temperature until certain thresholds are hit. + """ + + timestamp_granularities: List[Literal["word", "segment"]] = Field( + alias="timestamp_granularities[]", default=[]) + """The timestamp granularities to populate for this transcription. + + `response_format` must be set `verbose_json` to use timestamp granularities. + Either or both of these options are supported: `word`, or `segment`. 
Note: + There is no additional latency for segment timestamps, but generating word + timestamps incurs additional latency. + """ + + # Default sampling parameters for transcription requests. + _DEFAULT_SAMPLING_PARAMS: dict = { + "temperature": 0, + } + + def to_sampling_params( + self, + default_max_tokens: int, + default_sampling_params: Optional[dict] = None) -> SamplingParams: + # TODO(#9845): remove max_tokens when field is removed from OpenAI API + max_tokens = default_max_tokens + + if default_sampling_params is None: + default_sampling_params = {} + # Default parameters + if (temperature := self.temperature) is None: + temperature = default_sampling_params.get( + "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]) + + return SamplingParams.from_optional(temperature=temperature, + max_tokens=max_tokens) + + +# Transcription response objects +class TranscriptionResponse(OpenAIBaseModel): + text: str + """The transcribed text.""" + + +class TranscriptionWord(OpenAIBaseModel): + end: float + """End time of the word in seconds.""" + + start: float + """Start time of the word in seconds.""" + + word: str + """The text content of the word.""" + + +class TranscriptionSegment(OpenAIBaseModel): + id: int + """Unique identifier of the segment.""" + + avg_logprob: float + """Average logprob of the segment. + + If the value is lower than -1, consider the logprobs failed. + """ + + compression_ratio: float + """Compression ratio of the segment. + + If the value is greater than 2.4, consider the compression failed. + """ + + end: float + """End time of the segment in seconds.""" + + no_speech_prob: float + """Probability of no speech in the segment. + + If the value is higher than 1.0 and the `avg_logprob` is below -1, consider + this segment silent. 
+ """ + + seek: int + """Seek offset of the segment.""" + + start: float + """Start time of the segment in seconds.""" + + temperature: float + """Temperature parameter used for generating the segment.""" + + text: str + """Text content of the segment.""" + + tokens: List[int] + """Array of token IDs for the text content.""" + + +class TranscriptionResponseVerbose(OpenAIBaseModel): + duration: str + """The duration of the input audio.""" + + language: str + """The language of the input audio.""" + + text: str + """The transcribed text.""" + + segments: Optional[List[TranscriptionSegment]] = None + """Segments of the transcribed text and their corresponding details.""" + + words: Optional[List[TranscriptionWord]] = None + """Extracted words and their corresponding timestamps.""" diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 9efb5e6fa..785117ca1 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -31,7 +31,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, ErrorResponse, RerankRequest, ScoreRequest, TokenizeChatRequest, - TokenizeCompletionRequest) + TokenizeCompletionRequest, + TranscriptionRequest) from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.tool_parsers import ToolParser # yapf: enable @@ -57,7 +58,8 @@ CompletionLikeRequest = Union[CompletionRequest, DetokenizeRequest, ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest, TokenizeChatRequest] -AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest] +AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest, + TranscriptionRequest] class TextTokensPrompt(TypedDict): diff --git a/vllm/entrypoints/openai/serving_transcription.py b/vllm/entrypoints/openai/serving_transcription.py new file mode 100644 index 000000000..da4930e0e --- /dev/null +++ b/vllm/entrypoints/openai/serving_transcription.py @@ -0,0 +1,305 
@@ +# SPDX-License-Identifier: Apache-2.0 +import asyncio +import io +from typing import AsyncGenerator, Optional, Union, cast + +from fastapi import Request + +from vllm.config import ModelConfig +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.openai.protocol import (ErrorResponse, + RequestResponseMetadata, + TranscriptionRequest, + TranscriptionResponse, + TranscriptionResponseVerbose) +from vllm.entrypoints.openai.serving_engine import OpenAIServing +from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.inputs.data import PromptType +from vllm.logger import init_logger +from vllm.outputs import RequestOutput +from vllm.utils import PlaceholderModule + +try: + import librosa +except ImportError: + librosa = PlaceholderModule("librosa") # type: ignore[assignment] + +logger = init_logger(__name__) + +# From https://platform.openai.com/docs/guides/speech-to-text/supported-languages#supported-languages +# TODO these configs should live somewhere with the model so we can support +# additional ones + +ISO639_1_SUPPORTED_LANGS = { + "af": "Afrikaans", + "ar": "Arabic", + "hy": "Armenian", + "az": "Azerbaijani", + "be": "Belarusian", + "bs": "Bosnian", + "bg": "Bulgarian", + "ca": "Catalan", + "zh": "Chinese", + "hr": "Croatian", + "cs": "Czech", + "da": "Danish", + "nl": "Dutch", + "en": "English", + "et": "Estonian", + "fi": "Finnish", + "fr": "French", + "gl": "Galician", + "de": "German", + "el": "Greek", + "he": "Hebrew", + "hi": "Hindi", + "hu": "Hungarian", + "is": "Icelandic", + "id": "Indonesian", + "it": "Italian", + "ja": "Japanese", + "kn": "Kannada", + "kk": "Kazakh", + "ko": "Korean", + "lv": "Latvian", + "lt": "Lithuanian", + "mk": "Macedonian", + "ms": "Malay", + "mr": "Marathi", + "mi": "Maori", + "ne": "Nepali", + "no": "Norwegian", + "fa": "Persian", + "pl": "Polish", + "pt": "Portuguese", + "ro": "Romanian", + "ru": "Russian", + "sr": "Serbian", + 
"sk": "Slovak", + "sl": "Slovenian", + "es": "Spanish", + "sw": "Swahili", + "sv": "Swedish", + "tl": "Tagalog", + "ta": "Tamil", + "th": "Thai", + "tr": "Turkish", + "uk": "Ukrainian", + "ur": "Urdu", + "vi": "Vietnamese", + "cy": "Welsh" +} +ISO639_1_OTHER_LANGS = { + "lo": "Lao", + "jw": "Javanese", + "tk": "Turkmen", + "yi": "Yiddish", + "so": "Somali", + "bn": "Bengali", + "nn": "Norwegian Nynorsk", + "si": "Sinhala", + "yo": "Yoruba", + "sa": "Sanskrit", + "mi": "Māori", + "fo": "Faroese", # codespell:ignore + "mt": "Maltese", + "tg": "Tajik", + "mg": "Malagasy", + "haw": "Hawaiian", + "km": "Khmer", + "br": "Breton", + "ps": "Pashto", + "ln": "Lingala", + "la": "Latin", + "ml": "Malayalam", + "sq": "Albanian", + "su": "Sundanese", + "eu": "Basque", + "ka": "Georgian", + "uz": "Uzbek", + "sn": "Shona", + "ht": "Haitian", + "as": "Assamese", + "mn": "Mongolian", + "te": "Telugu", + "pa": "Panjabi", + "tt": "Tatar", + "gu": "Gujarati", + "oc": "Occitan", + "ha": "Hausa", + "ba": "Bashkir", + "my": "Burmese", + "sd": "Sindhi", + "am": "Amharic", + "lb": "Luxembourgish", + "bo": "Tibetan" +} + +# As per https://platform.openai.com/docs/guides/speech-to-text#overview. 
+# TODO configurable +MAX_AUDIO_CLIP_FILESIZE_MB = 25 +# TODO get from processor.feature_extractor.chunk_length +MAX_AUDIO_CLIP_DURATION_S = 30 + + +class OpenAIServingTranscription(OpenAIServing): + + def __init__( + self, + engine_client: EngineClient, + model_config: ModelConfig, + models: OpenAIServingModels, + *, + request_logger: Optional[RequestLogger], + return_tokens_as_token_ids: bool = False, + ): + super().__init__(engine_client=engine_client, + model_config=model_config, + models=models, + request_logger=request_logger, + return_tokens_as_token_ids=return_tokens_as_token_ids) + + diff_sampling_param = self.model_config.get_diff_sampling_param() + if diff_sampling_param: + logger.info( + "Overwriting default completion sampling param with: %s", + diff_sampling_param) + + async def _preprocess_transcription( + self, + request: TranscriptionRequest, + audio_data: bytes, + ) -> PromptType: + # Validate request + # TODO language should be optional and can be guessed. + # For now we default to en. See + # https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520 + lang_token = f"<|{request.language}|>" if request.language else "<|en|>" + if request.language: + if request.language in ISO639_1_SUPPORTED_LANGS: + pass + elif request.language in ISO639_1_OTHER_LANGS: + logger.warning( + "The selected language %s has limited accuracy with" + " reported WER>=0.5. Results may be less accurate " + "for this choice.", request.language) + else: + raise ValueError( + f"Unsupported language: {request.language}." 
+ " Language should be one of:" + + f" {list(ISO639_1_SUPPORTED_LANGS.values())}" + + f" or {list(ISO639_1_OTHER_LANGS.values())}") + + if len(audio_data) / 1024**2 > MAX_AUDIO_CLIP_FILESIZE_MB: + raise ValueError("Maximum file size exceeded.") + + with io.BytesIO(audio_data) as bytes_: + y, sr = librosa.load(bytes_) + if librosa.get_duration(y=y, sr=sr) > MAX_AUDIO_CLIP_DURATION_S: + raise ValueError( + f"Maximum clip duration ({MAX_AUDIO_CLIP_DURATION_S}s) " + "exceeded.") + + prompt = { + "encoder_prompt": { + "prompt": "", + "multi_modal_data": { + "audio": (y, sr), + }, + }, + "decoder_prompt": + f"<|startoftranscript|>{lang_token}<|transcribe|><|notimestamps|>{request.prompt}" + } + return cast(PromptType, prompt) + + # TODO (varun) : Make verbose response work ! + async def create_transcription( + self, audio_data: bytes, request: TranscriptionRequest, + raw_request: Request + ) -> Union[TranscriptionResponse, TranscriptionResponseVerbose, + ErrorResponse]: + """Transcription API similar to OpenAI's API. + + See https://platform.openai.com/docs/api-reference/audio/createTranscription + for the API specification. This API mimics the OpenAI transcription API. + """ + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + return error_check_ret + + # If the engine is dead, raise the engine's DEAD_ERROR. + # This is required for the streaming case, where we return a + # success status before we actually start generating text :). + if self.engine_client.errored: + raise self.engine_client.dead_error + + if request.response_format not in ['text', 'json']: + return self.create_error_response( + "Currently only support response_format `text` or `json`") + + # TODO cmpl->transcription? 
+ request_id = f"cmpl-{self._base_request_id(raw_request)}" + + request_metadata = RequestResponseMetadata(request_id=request_id) + if raw_request: + raw_request.state.request_metadata = request_metadata + + try: + ( + lora_request, + prompt_adapter_request, + ) = self._maybe_get_adapters(request) + + if lora_request: + return self.create_error_response( + "Currently do not support LoRA for Transcription.") + if prompt_adapter_request: + return self.create_error_response( + "Currently do not support PromptAdapter for Transcription." + ) + + prompt = await self._preprocess_transcription( + request=request, + audio_data=audio_data, + ) + + except ValueError as e: + logger.exception("Error in preprocessing prompt inputs") + return self.create_error_response(str(e)) + + result_generator: Optional[AsyncGenerator[RequestOutput, None]] = None + try: + # TODO(rob): subtract len of tokenized prompt. + default_max_tokens = self.model_config.max_model_len + default_params = self.model_config.get_diff_sampling_param() + sampling_params = request.to_sampling_params( + default_max_tokens, default_params) + + self._log_inputs( + request_id, + prompt['decoder_prompt'], # type: ignore + params=sampling_params, + lora_request=None, + prompt_adapter_request=None) + + result_generator = self.engine_client.generate( + prompt, + sampling_params, + request_id, + ) + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + # TODO(rob): figure out a way to pipe streaming in. + # Non-streaming response. 
+ try: + async for op in result_generator: + result = op + return TranscriptionResponse(text=result.outputs[0].text) + except asyncio.CancelledError: + return self.create_error_response("Client disconnected") + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 0fc5c4db1..a0a1b69ad 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -441,3 +441,30 @@ def supports_cross_encoding( model: Union[Type[object], object], ) -> Union[TypeIs[Type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]: return is_pooling_model(model) and _supports_cross_encoding(model) + + +@runtime_checkable +class SupportsTranscription(Protocol): + """The interface required for all models that support transcription.""" + + supports_transcription: ClassVar[Literal[True]] = True + + +@overload +def supports_transcription( + model: Type[object]) -> TypeIs[Type[SupportsTranscription]]: + ... + + +@overload +def supports_transcription(model: object) -> TypeIs[SupportsTranscription]: + ... 
+ + +def supports_transcription( + model: Union[Type[object], object], +) -> Union[TypeIs[Type[SupportsTranscription]], TypeIs[SupportsTranscription]]: + if isinstance(model, type): + return isinstance(model, SupportsTranscription) + + return isinstance(model, SupportsTranscription) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 08c4642b4..7260d973b 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -22,7 +22,7 @@ from vllm.logger import init_logger from .interfaces import (has_inner_state, is_attention_free, is_hybrid, supports_cross_encoding, supports_multimodal, - supports_pp) + supports_pp, supports_transcription) from .interfaces_base import is_text_generation_model logger = init_logger(__name__) @@ -224,6 +224,7 @@ class _ModelInfo: has_inner_state: bool is_attention_free: bool is_hybrid: bool + supports_transcription: bool @staticmethod def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo": @@ -237,7 +238,7 @@ class _ModelInfo: has_inner_state=has_inner_state(model), is_attention_free=is_attention_free(model), is_hybrid=is_hybrid(model), - ) + supports_transcription=supports_transcription(model)) class _BaseRegisteredModel(ABC): @@ -485,6 +486,13 @@ class _ModelRegistry: model_cls, _ = self.inspect_model_cls(architectures) return model_cls.is_hybrid + def is_transcription_model( + self, + architectures: Union[str, List[str]], + ) -> bool: + model_cls, _ = self.inspect_model_cls(architectures) + return model_cls.supports_transcription + ModelRegistry = _ModelRegistry({ model_arch: diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 0a3011d36..0b5060720 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -31,7 +31,7 @@ from vllm.multimodal.audio import resample_audio from vllm.sequence import SequenceData from vllm.transformers_utils.processor import 
cached_get_processor -from .interfaces import SupportsMultiModal +from .interfaces import SupportsMultiModal, SupportsTranscription from .utils import AutoWeightsLoader, WeightsMapper, make_layers logger = init_logger(__name__) @@ -637,7 +637,8 @@ def input_mapper_for_whisper( @MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_whisper) @MULTIMODAL_REGISTRY.register_max_multimodal_tokens( "audio", get_max_whisper_audio_tokens) -class WhisperForConditionalGeneration(nn.Module, SupportsMultiModal): +class WhisperForConditionalGeneration(nn.Module, SupportsTranscription, + SupportsMultiModal): packed_modules_mapping = { "self_attn.qkv_proj": [ "self_attn.q_proj", -- GitLab From bffddd9a05a6d0d3ea04b7baad1966e4f57f94c7 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Thu, 13 Feb 2025 20:51:30 +0000 Subject: [PATCH 124/253] Add label if pre-commit passes (#12527) Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- .github/workflows/add_label_precommit.yml | 38 +++++++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 .github/workflows/add_label_precommit.yml diff --git a/.github/workflows/add_label_precommit.yml b/.github/workflows/add_label_precommit.yml new file mode 100644 index 000000000..a88b44f03 --- /dev/null +++ b/.github/workflows/add_label_precommit.yml @@ -0,0 +1,38 @@ +name: Add label on pre-commit success +on: + workflow_run: + workflows: [pre-commit] + types: [requested, completed] +jobs: + add-label-on-pre-commit-success: + runs-on: ubuntu-latest + if: ${{ github.event.workflow_run.conclusion == 'success' }} + steps: + - name: Add label + uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1 + with: + script: | + github.rest.issues.addLabels({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.issue.number, + labels: ['pre-commit-passed'] + }) + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + 
remove-label-on-pre-commit-not-success: + runs-on: ubuntu-latest + if: ${{ github.event.workflow_run.conclusion != 'success' }} + steps: + - name: Remove label + uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1 + with: + script: | + github.rest.issues.removeLabel({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.issue.number, + name: 'pre-commit-passed' + }) + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} -- GitLab From 2344192a55022d84ba3bb8d33b1ec38724f54fed Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Thu, 13 Feb 2025 18:43:37 -0500 Subject: [PATCH 125/253] Optimize moe_align_block_size for deepseek_v3 (#12850) Signed-off-by: mgoin --- csrc/moe/moe_align_sum_kernels.cu | 52 +++++++++++++------ .../layers/fused_moe/fused_moe.py | 3 +- 2 files changed, 39 insertions(+), 16 deletions(-) diff --git a/csrc/moe/moe_align_sum_kernels.cu b/csrc/moe/moe_align_sum_kernels.cu index c072744f0..d7be76945 100644 --- a/csrc/moe/moe_align_sum_kernels.cu +++ b/csrc/moe/moe_align_sum_kernels.cu @@ -198,26 +198,27 @@ __global__ void moe_align_block_size_global_mem_kernel( } // taken from -// https://github.com/sgl-project/sglang/commit/ded9fcd09a43d5e7d5bb31a2bc3e9fc21bf65d2a +// https://github.com/sgl-project/sglang/commit/cdae77b03dfc6fec3863630550b45bbfc789f957 template __global__ void sgl_moe_align_block_size_kernel( scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids, int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts, int32_t block_size, size_t numel, int32_t* cumsum) { __shared__ int32_t shared_counts[32][8]; - __shared__ int32_t local_offsets[256]; const int warp_id = threadIdx.x / 32; - const int lane_id = threadIdx.x % 32; const int experts_per_warp = 8; const int my_expert_start = warp_id * experts_per_warp; + // Initialize shared_counts for this warp's experts for (int i = 0; i < experts_per_warp; ++i) { if (my_expert_start + i < num_experts) { 
shared_counts[warp_id][i] = 0; } } + __syncthreads(); + const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); const size_t start_idx = threadIdx.x * tokens_per_thread; @@ -230,6 +231,7 @@ __global__ void sgl_moe_align_block_size_kernel( __syncthreads(); + // Single thread computes cumulative sum and total tokens if (threadIdx.x == 0) { cumsum[0] = 0; for (int i = 1; i <= num_experts; ++i) { @@ -246,19 +248,28 @@ __global__ void sgl_moe_align_block_size_kernel( __syncthreads(); + // Assign expert IDs to blocks if (threadIdx.x < num_experts) { for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; i += block_size) { expert_ids[i / block_size] = threadIdx.x; } - local_offsets[threadIdx.x] = cumsum[threadIdx.x]; } +} - __syncthreads(); - - for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { +// taken from +// https://github.com/sgl-project/sglang/commit/cdae77b03dfc6fec3863630550b45bbfc789f957 +template +__global__ void sgl_moe_token_sort_kernel(scalar_t* __restrict__ topk_ids, + int32_t* sorted_token_ids, + int32_t* cumsum_buffer, + size_t numel) { + const size_t tid = blockIdx.x * blockDim.x + threadIdx.x; + const size_t stride = blockDim.x * gridDim.x; + + for (size_t i = tid; i < numel; i += stride) { int32_t expert_id = topk_ids[i]; - int32_t rank_post_pad = atomicAdd(&local_offsets[expert_id], 1); + int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1); sorted_token_ids[rank_post_pad] = i; } } @@ -377,23 +388,34 @@ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad) { const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + TORCH_CHECK(num_experts == 256, + "sgl_moe_align_block_size kernel only supports deepseek v3."); + VLLM_DISPATCH_INTEGRAL_TYPES( topk_ids.scalar_type(), "sgl_moe_align_block_size_kernel", [&] { - // calc needed amount of shared mem for `tokens_cnts` and `cumsum` - // tensors + // calc needed amount of 
shared mem for `cumsum` tensors auto options_int = torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device()); - // torch::Tensor token_cnts_buffer = - // torch::empty({(num_experts + 1) * num_experts}, options_int); torch::Tensor cumsum_buffer = - torch::empty({num_experts + 1}, options_int); + torch::zeros({num_experts + 1}, options_int); - auto kernel = vllm::moe::sgl_moe_align_block_size_kernel; - kernel<<<1, 1024, 0, stream>>>( + auto align_kernel = + vllm::moe::sgl_moe_align_block_size_kernel; + align_kernel<<<1, 1024, 0, stream>>>( topk_ids.data_ptr(), sorted_token_ids.data_ptr(), experts_ids.data_ptr(), num_tokens_post_pad.data_ptr(), num_experts, block_size, topk_ids.numel(), cumsum_buffer.data_ptr()); + + const int block_threads = 256; + const int num_blocks = + (topk_ids.numel() + block_threads - 1) / block_threads; + const int max_blocks = 65535; + const int actual_blocks = std::min(num_blocks, max_blocks); + auto sort_kernel = vllm::moe::sgl_moe_token_sort_kernel; + sort_kernel<<>>( + topk_ids.data_ptr(), sorted_token_ids.data_ptr(), + cumsum_buffer.data_ptr(), topk_ids.numel()); }); } diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index f14200e02..d0b6249e1 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -596,7 +596,7 @@ def moe_align_block_size( dtype=torch.int32, device=topk_ids.device) if num_experts >= 224: - if envs.VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: + if envs.VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON or num_experts != 256: moe_align_block_size_triton( topk_ids, num_experts, @@ -606,6 +606,7 @@ def moe_align_block_size( num_tokens_post_pad, ) else: + # Currently requires num_experts=256 ops.sgl_moe_align_block_size( topk_ids, num_experts, -- GitLab From c1e37bf71b99f3c8cb4923c0b620918a93170206 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Thu, 13 Feb 2025 19:01:14 -0500 Subject: 
[PATCH 126/253] [Kernel][Bugfix] Refactor and Fix CUTLASS 2:4 Sparse Kernels (#13198) Signed-off-by: Tyler Michael Smith --- CMakeLists.txt | 10 +- .../epilogue/scaled_mm_epilogues_c3x.hpp | 69 ++++- csrc/ops.h | 3 +- .../cutlass_w8a8/c3x/scaled_mm.cuh | 13 +- .../cutlass_w8a8/scaled_mm_c2x.cuh | 11 +- csrc/sparse/cutlass/sparse_compressor_c3x.cu | 165 ----------- csrc/sparse/cutlass/sparse_compressor_c3x.cuh | 90 ++++++ .../sparse/cutlass/sparse_compressor_entry.cu | 42 --- csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu | 264 +++++++++--------- csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh | 232 +++++++++------ csrc/sparse/cutlass/sparse_scaled_mm_entry.cu | 30 ++ csrc/torch_bindings.cpp | 6 +- tests/kernels/test_cutlass_2of4_sparse.py | 81 +++++- vllm/_custom_ops.py | 17 +- .../compressed_tensors/compressed_tensors.py | 7 - .../schemes/compressed_tensors_24.py | 9 +- 16 files changed, 576 insertions(+), 473 deletions(-) delete mode 100644 csrc/sparse/cutlass/sparse_compressor_c3x.cu create mode 100644 csrc/sparse/cutlass/sparse_compressor_c3x.cuh delete mode 100644 csrc/sparse/cutlass/sparse_compressor_entry.cu diff --git a/CMakeLists.txt b/CMakeLists.txt index 244ceb721..8e8f7adf6 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -228,7 +228,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library") # Set CUTLASS_REVISION manually -- its revision detection doesn't work in this case. - set(CUTLASS_REVISION "v3.6.0" CACHE STRING "CUTLASS revision to use") + # Please keep this in sync with FetchContent_Declare line below. 
+ set(CUTLASS_REVISION "v3.7.0" CACHE STRING "CUTLASS revision to use") # Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR}) @@ -245,6 +246,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") FetchContent_Declare( cutlass GIT_REPOSITORY https://github.com/nvidia/cutlass.git + # Please keep this in sync with CUTLASS_REVISION line above. GIT_TAG v3.7.0 GIT_PROGRESS TRUE @@ -266,7 +268,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu" "csrc/quantization/fp4/nvfp4_quant_entry.cu" "csrc/sparse/cutlass/sparse_scaled_mm_entry.cu" - "csrc/sparse/cutlass/sparse_compressor_entry.cu" "csrc/cutlass_extensions/common.cpp") set_gencode_flags_for_srcs( @@ -359,8 +360,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor # require CUDA 12.2 or later (and only work on Hopper, 9.0a for now). if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS) - set(SRCS "csrc/sparse/cutlass/sparse_compressor_c3x.cu" - "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu") + set(SRCS "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu") set_gencode_flags_for_srcs( SRCS "${SRCS}" CUDA_ARCHS "${SCALED_MM_3X_ARCHS}") @@ -476,7 +476,7 @@ define_gpu_extension_target( SOURCES ${VLLM_EXT_SRC} COMPILE_FLAGS ${VLLM_GPU_FLAGS} ARCHITECTURES ${VLLM_GPU_ARCHES} - INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} + INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR} USE_SABI 3 WITH_SOABI) diff --git a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp index c590c66a6..583fa3c45 100644 --- a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp +++ b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp @@ -16,6 +16,30 @@ namespace vllm::c3x { using namespace cute; +template +struct identity { + CUTLASS_HOST_DEVICE + T operator()(T 
lhs) const { return lhs; } +}; + +template +struct TrivialEpilogue { + private: + using Accum = cutlass::epilogue::fusion::Sm90AccFetch; + using Compute = cutlass::epilogue::fusion::Sm90Compute< + cutlass::epilogue::thread::Identity, ElementD, ElementAcc, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = cutlass::epilogue::fusion::Sm90EVT; + using ArgumentType = typename EVTCompute::Arguments; + + template + static ArgumentType prepare_args(Args... args) { + return {}; + } +}; + /* * This class provides the common load descriptors for the * ScaledEpilogue[...] classes @@ -174,6 +198,49 @@ struct ScaledEpilogueBias } }; +/* + * This epilogue performs the same operation as ScaledEpilogueBias, but the + * bias is a column vector instead of a row vector. Useful e.g. if we are + * computing a GEMM via C^T += B^T A^T. This happens in the 2:4 sparse kernels. + */ +template +struct ScaledEpilogueColumnBias + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template ColLoad; + + using Compute0 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::fusion::Sm90EVT; + + using Compute1 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + + using ArgumentType = typename EVTCompute::Arguments; + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template 
args_from_tensor(bias); + + typename EVTCompute0::Arguments evt0_args{b_args}; + return ArgumentType{a_args, evt0_args, bias_args}; + } +}; + /* * This epilogue directly supports per-tensor azp in int32 form. * As opposed to the per-token epilogue below, this epilogue only has an azp_adj @@ -314,4 +381,4 @@ struct ScaledEpilogueBiasAzpToken } }; -}; // namespace vllm::c3x \ No newline at end of file +}; // namespace vllm::c3x diff --git a/csrc/ops.h b/csrc/ops.h index 70e864cc6..460078896 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -176,8 +176,7 @@ void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b_scales, std::optional const& bias); -bool cutlass_sparse_compress_entry(torch::Tensor& a_compressed, - torch::Tensor& e, torch::Tensor const& a); +std::vector cutlass_sparse_compress(torch::Tensor const& a); #endif void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh index 9227ebb73..d2f43e2b7 100644 --- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh @@ -53,12 +53,17 @@ struct cutlass_3x_gemm { using EVTCompute = typename Epilogue::EVTCompute; + // These are the minimum alignments needed for the kernels to compile + static constexpr int AlignmentAB = + 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentCD = 4; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, - ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4, - EpilogueSchedule, EVTCompute>::CollectiveOp; + ElementAcc, float, ElementC, StrideC, AlignmentCD, ElementD, StrideD, + AlignmentCD, EpilogueSchedule, EVTCompute>::CollectiveOp; static constexpr size_t CEStorageSize = sizeof(typename 
CollectiveEpilogue::SharedStorage); @@ -69,8 +74,8 @@ struct cutlass_3x_gemm { using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, - ElementAB, cutlass::layout::RowMajor, 16, - ElementAB, cutlass::layout::ColumnMajor, 16, + ElementAB, cutlass::layout::RowMajor, AlignmentAB, + ElementAB, cutlass::layout::ColumnMajor, AlignmentAB, ElementAcc, TileShape, ClusterShape, Stages, KernelSchedule>::CollectiveOp; diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh index f2fae4b66..ce7cf2f35 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh @@ -103,14 +103,19 @@ struct cutlass_2x_gemm { using EVTD = cutlass::epilogue::threadblock::Sm80EVT; + // These are the minimum alignments needed for the kernels to compile + static constexpr int AlignmentAB = + 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentCD = 4; + // clang-format off using RowMajor = typename cutlass::layout::RowMajor; using ColumnMajor = typename cutlass::layout::ColumnMajor; using KernelType = ArchGuard - -#if defined CUDA_VERSION && CUDA_VERSION >= 12020 -#include "sparse_scaled_mm_c3x.cuh" - -#include "cutlass/numeric_conversion.h" -#include "cutlass/transform/device/transform_universal_adapter.hpp" -#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp" -#include "cutlass/epilogue/collective/default_epilogue.hpp" - -#include "cutlass/util/host_tensor.h" -#include "cutlass/util/packed_stride.hpp" -// clang-format on - -using namespace cute; -using namespace vllm; - -/// Make A structured sparse by replacing elements with 0 and compress it -template -bool cutlass_sparse_compress(torch::Tensor& a_nzs, torch::Tensor& a_meta, - torch::Tensor const& a) { - // Checks for conformality - TORCH_CHECK(a.dtype() == torch::kInt8 || a.dtype() == torch::kFloat8_e4m3fn || - a.dtype() == 
torch::kFloat16 || a.dtype() == torch::kBFloat16); - TORCH_CHECK(a.dim() == 2) - // Check for strides and alignment - TORCH_CHECK(a.stride(0) % 4 == 0) // Required for semi-structured sparsity - TORCH_CHECK(a.stride(1) == 1) - - int m = a.size(0); - int k = a.size(1); - - // Sparse kernel setup; this kernel is not used for matmul, - // but just for setting up the compressor utility - // A matrix configuration - using ElementA = ElementA_; - using LayoutTagA = cutlass::layout::RowMajor; - constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; - // B matrix configuration - using ElementB = ElementA; - using LayoutTagB = cutlass::layout::ColumnMajor; - constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; - // C/D matrix configuration - using ElementC = float; - using LayoutTagC = cutlass::layout::ColumnMajor; - constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; - // Core kernel configurations - using ElementAccumulator = ElementAcc_; - using TileShape = Shape<_128, _128, _128>; - using TileShapeRef = Shape<_128, _128, _64>; - using ClusterShape = Shape<_1, _2, _1>; - using KernelSchedule = typename std::conditional< - std::is_same_v, - cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum, - cutlass::gemm::KernelTmaWarpSpecialized>::type; - - using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; - using ProblemShape = Shape; - - using CollectiveEpilogue = - typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, - ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, - ElementAccumulator, ElementAccumulator, ElementC, LayoutTagC, - AlignmentC, ElementC, LayoutTagC, AlignmentC, - EpilogueSchedule>::CollectiveOp; - - using CollectiveMainloop = - typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm90, cutlass::arch::OpClassSparseTensorOp, ElementA, - LayoutTagA, AlignmentA, ElementB, LayoutTagB, AlignmentB, - ElementAccumulator, 
TileShape, ClusterShape, - cutlass::gemm::collective::StageCountAutoCarveout( - sizeof(typename CollectiveEpilogue::SharedStorage))>, - KernelSchedule>::CollectiveOp; - - using GemmKernel = - cutlass::gemm::kernel::GemmUniversal; - - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - - using StrideA = cutlass::gemm::TagToStrideA_t; - using StrideE = StrideA; - - using StrideA = Stride, int64_t>; - - // The n (=1) dimension does not matter for the compressor - typename GemmKernel::ProblemShape prob_shape{m, 1, k, 1}; - - using LayoutA = typename GemmKernel::CollectiveMainloop::LayoutA; - using LayoutE = typename GemmKernel::CollectiveMainloop::LayoutE; - - using ElementE = typename GemmKernel::CollectiveMainloop::ElementE; - using SparseConfig = typename GemmKernel::CollectiveMainloop::SparseConfig; - - // Offline compressor kernel - using CompressorUtility = - cutlass::transform::kernel::StructuredSparseCompressorUtility< - ProblemShape, ElementA, LayoutTagA, SparseConfig>; - - using CompressorKernel = - cutlass::transform::kernel::StructuredSparseCompressor< - ProblemShape, ElementA, LayoutTagA, SparseConfig, - cutlass::arch::Sm90>; - - using Compressor = - cutlass::transform::device::TransformUniversalAdapter; - - auto [M, N, K, L] = prob_shape; - - StrideA stride_A; - stride_A = - cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); - - CompressorUtility compressor_utility(prob_shape, stride_A); - - int ME = compressor_utility.get_metadata_m_physical(); - int KE = compressor_utility.get_metadata_k_physical(); - int KC = compressor_utility.get_tensorA_k_physical(); - - auto a_ptr = static_cast(a.data_ptr()); - - auto a_nzs_ptr = static_cast(a_nzs.data_ptr()); - auto a_meta_ptr = static_cast( - a_meta.data_ptr()); - - cutlass::KernelHardwareInfo hw_info; - hw_info.device_id = 0; - hw_info.sm_count = - cutlass::KernelHardwareInfo::query_device_multiprocessor_count( - hw_info.device_id); - typename Compressor::Arguments arguments{ - 
prob_shape, {a_ptr, stride_A, a_nzs_ptr, a_meta_ptr}, {hw_info}}; - - Compressor compressor_op; - size_t workspace_size = Compressor::get_workspace_size(arguments); - cutlass::device_memory::allocation workspace(workspace_size); - - CUTLASS_CHECK(compressor_op.can_implement(arguments)); - CUTLASS_CHECK(compressor_op.initialize(arguments, workspace.get())); - CUTLASS_CHECK(compressor_op.run()); - CUDA_CHECK(cudaDeviceSynchronize()); - - return true; -} - -bool cutlass_sparse_compress_sm90(torch::Tensor& a_nzs, torch::Tensor& a_meta, - torch::Tensor const& a) { - if (a.dtype() == torch::kBFloat16) { - return cutlass_sparse_compress(a_nzs, a_meta, - a); - } else if (a.dtype() == torch::kFloat16) { - return cutlass_sparse_compress(a_nzs, a_meta, a); - } else if (a.dtype() == torch::kFloat8_e4m3fn) { - return cutlass_sparse_compress(a_nzs, a_meta, - a); - } else if (a.dtype() == torch::kInt8) { - return cutlass_sparse_compress(a_nzs, a_meta, a); - } - return false; -} -#endif diff --git a/csrc/sparse/cutlass/sparse_compressor_c3x.cuh b/csrc/sparse/cutlass/sparse_compressor_c3x.cuh new file mode 100644 index 000000000..2cc235f3a --- /dev/null +++ b/csrc/sparse/cutlass/sparse_compressor_c3x.cuh @@ -0,0 +1,90 @@ +#pragma once + +// clang-format will break include orders +// clang-format off +#include + +#if defined CUDA_VERSION && CUDA_VERSION >= 12020 +#include "sparse_scaled_mm_c3x.cuh" + +#include "cutlass/numeric_conversion.h" +#include "cutlass/transform/device/transform_universal_adapter.hpp" +#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" + +// clang-format on + +using namespace cute; +using namespace vllm; + +using CompressorResult = std::tuple; +/// Make A structured sparse by replacing elements with 0 and compress it +template +CompressorResult cutlass_sparse_compress(torch::Tensor const& a) { + // Checks for conformality + TORCH_CHECK(a.dtype() == torch::kInt8 || a.dtype() == 
torch::kFloat8_e4m3fn || + a.dtype() == torch::kFloat16 || a.dtype() == torch::kBFloat16); + TORCH_CHECK(a.dim() == 2) + // Check for strides and alignment + TORCH_CHECK(a.stride(0) % 4 == 0) // Required for semi-structured sparsity + TORCH_CHECK(a.stride(1) == 1) + + using GemmKernel = typename Gemm::KernelType; + using ElementA = typename Gemm::ElementAB; + using ElementE = typename GemmKernel::CollectiveMainloop::ElementE; + + int m = a.size(0); + int k = a.size(1); + using ProblemShape = typename GemmKernel::ProblemShape; + ProblemShape prob_shape{m, 1, k, 1}; + + int64_t lda = a.stride(0); + using StrideA = Stride, int64_t>; + StrideA a_stride{lda, Int<1>{}, 0}; + + using CompressorUtility = typename Gemm::CompressorUtility; + CompressorUtility compressor_utility(prob_shape, a_stride); + + // Allocate buffers for the metadata E and the compressed matrix A + int ME = compressor_utility.get_metadata_m_physical(); + int KE = compressor_utility.get_metadata_k_physical(); + int MC = compressor_utility.get_tensorA_m_physical(); + int KC = compressor_utility.get_tensorA_k_physical(); + + auto const a_meta_options = + torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); + auto const a_nzs_options = + torch::TensorOptions().dtype(a.dtype()).device(a.device()); + + auto a_meta = torch::zeros({ME, KE}, a_meta_options); + auto a_nzs = torch::zeros({MC, KC}, a_nzs_options); + + auto a_ptr = static_cast(a.data_ptr()); + auto a_nzs_ptr = static_cast(a_nzs.data_ptr()); + auto a_meta_ptr = static_cast(a_meta.data_ptr()); + + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = a.device().index(); + hw_info.sm_count = + cutlass::KernelHardwareInfo::query_device_multiprocessor_count( + hw_info.device_id); + + using Compressor = typename Gemm::Compressor; + typename Compressor::Arguments arguments{ + prob_shape, {a_ptr, a_stride, a_nzs_ptr, a_meta_ptr}, {hw_info}}; + + Compressor compressor_op; + size_t workspace_size = Compressor::get_workspace_size(arguments); 
+ auto const workspace_options = + torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); + auto workspace = torch::empty(workspace_size, workspace_options); + + CUTLASS_CHECK(compressor_op.can_implement(arguments)); + CUTLASS_CHECK(compressor_op.initialize(arguments, workspace.data_ptr())); + CUTLASS_CHECK(compressor_op.run()); + CUDA_CHECK(cudaDeviceSynchronize()); + + return {a_meta, a_nzs}; +} + +#endif diff --git a/csrc/sparse/cutlass/sparse_compressor_entry.cu b/csrc/sparse/cutlass/sparse_compressor_entry.cu deleted file mode 100644 index 3401761c1..000000000 --- a/csrc/sparse/cutlass/sparse_compressor_entry.cu +++ /dev/null @@ -1,42 +0,0 @@ -#include - -#include -#include - -#include "cutlass_extensions/common.hpp" - -#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X -bool cutlass_sparse_compress_sm90(torch::Tensor& a_nzs, torch::Tensor& a_meta, - torch::Tensor const& a); -#endif - -bool cutlass_sparse_compress_entry(torch::Tensor& a_nzs, torch::Tensor& a_meta, - torch::Tensor const& a) { - // Checks for conformality - TORCH_CHECK(a.dim() == 2 && a_meta.dim() == 2 && a_nzs.dim() == 2); - TORCH_CHECK(a.size(0) == a_nzs.size(0) && a.size(0) == a_meta.size(0) && - a_nzs.size(1) * 2 == a.size(1) && - a_meta.size(1) * 2 * 4 == a.size(1)); - // Considering elemsPerMetaElem = 8b / 2b_per_nz = 4 - - // Check for strides and alignment - TORCH_CHECK(a.stride(1) == 1 && a_nzs.stride(1) == 1 && - a_meta.stride(1) == 1); // Row-major - TORCH_CHECK(a.stride(0) % 8 == 0); // 8 Byte Alignment for Compression - - at::cuda::OptionalCUDAGuard const device_guard(device_of(a)); - int32_t version_num = get_sm_version_num(); - - // Guard against compilation issues for sm90 kernels -#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X - if (version_num >= 90) { - return cutlass_sparse_compress_sm90(a_nzs, a_meta, a); - } -#endif - - TORCH_CHECK_NOT_IMPLEMENTED( - false, - "No compiled cutlass_scaled_sparse_mm for a compute capability 
less than " - "CUDA device capability: ", - version_num); -} diff --git a/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu index 5a1879787..3dcaa6373 100644 --- a/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu +++ b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu @@ -9,17 +9,30 @@ using namespace cute; using namespace vllm; +struct GemmCallerTraits { + using return_type = void; + + template + static return_type invoke(Args&&... args) { + return cutlass_sparse_gemm_caller(std::forward(args)...); + } +}; + +struct GemmCompressorTraits { + using return_type = CompressorResult; + + template + static return_type invoke(Args&&... args) { + return cutlass_sparse_compress(std::forward(args)...); + } +}; + template typename Epilogue, - typename... EpilogueArgs> -void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& bt_nzs, - torch::Tensor const& bt_meta, - EpilogueArgs&&... args) { - static_assert(std::is_same()); - TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); - TORCH_CHECK(bt_meta.dtype() == torch::kUInt8); - TORCH_CHECK(bt_nzs.dtype() == torch::kFloat8_e4m3fn); + typename DispatchFunc, typename... Args> +typename DispatchFunc::return_type cutlass_gemm_sm90_fp8_dispatch( + uint32_t m, uint32_t n, Args&&... 
args) { + static_assert(std::is_same_v); using Cutlass3xGemmDefault = typename sm90_config_default::Cutlass3xGemm; @@ -49,122 +62,87 @@ void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a, using Cutlass3xGemm8 = typename sm90_fp8_config_8::Cutlass3xGemm; - uint32_t const n = bt_nzs.size(0); - uint32_t const m = a.size(0); // Batch size uint32_t const mp2 = std::max(static_cast(64), next_pow_2(m)); // next power of 2 if (mp2 <= 64) { if (n == 28672) { - return cutlass_sparse_gemm_caller( - out, a, bt_nzs, bt_meta, std::forward(args)...); + return DispatchFunc::template invoke( + std::forward(args)...); } else if (n == 4096 || n == 6144) { - return cutlass_sparse_gemm_caller( - out, a, bt_nzs, bt_meta, std::forward(args)...); + return DispatchFunc::template invoke( + std::forward(args)...); } } else if (mp2 <= 128) { if (n == 4096) { - return cutlass_sparse_gemm_caller( - out, a, bt_nzs, bt_meta, std::forward(args)...); + return DispatchFunc::template invoke( + std::forward(args)...); } else if (n == 28672) { - return cutlass_sparse_gemm_caller( - out, a, bt_nzs, bt_meta, std::forward(args)...); + return DispatchFunc::template invoke( + std::forward(args)...); } else if (n == 6144) { - return cutlass_sparse_gemm_caller( - out, a, bt_nzs, bt_meta, std::forward(args)...); + return DispatchFunc::template invoke( + std::forward(args)...); } } else if (mp2 <= 256) { if (n == 4096) { - return cutlass_sparse_gemm_caller( - out, a, bt_nzs, bt_meta, std::forward(args)...); + return DispatchFunc::template invoke( + std::forward(args)...); } else if (n == 28672) { - return cutlass_sparse_gemm_caller( - out, a, bt_nzs, bt_meta, std::forward(args)...); + return DispatchFunc::template invoke( + std::forward(args)...); } else if (n == 6144) { - return cutlass_sparse_gemm_caller( - out, a, bt_nzs, bt_meta, std::forward(args)...); + return DispatchFunc::template invoke( + std::forward(args)...); } } else { if (n == 6144 || n == 28672) { - return 
cutlass_sparse_gemm_caller( - out, a, bt_nzs, bt_meta, std::forward(args)...); + return DispatchFunc::template invoke( + std::forward(args)...); } else if (n == 4096) { - return cutlass_sparse_gemm_caller( - out, a, bt_nzs, bt_meta, std::forward(args)...); + return DispatchFunc::template invoke( + std::forward(args)...); } } // Otherwise the default heuristic if (mp2 <= 64) { // n in [1, 64] - return cutlass_sparse_gemm_caller( - out, a, bt_nzs, bt_meta, std::forward(args)...); + return DispatchFunc::template invoke( + std::forward(args)...); } else if (mp2 <= 128) { // n in (64, 128] - return cutlass_sparse_gemm_caller( - out, a, bt_nzs, bt_meta, std::forward(args)...); + return DispatchFunc::template invoke( + std::forward(args)...); } else if (mp2 <= 256) { // n in (128, 256] - return cutlass_sparse_gemm_caller( - out, a, bt_nzs, bt_meta, std::forward(args)...); + return DispatchFunc::template invoke( + std::forward(args)...); } else { // n in (256, inf) - return cutlass_sparse_gemm_caller( - out, a, bt_nzs, bt_meta, std::forward(args)...); + return DispatchFunc::template invoke( + std::forward(args)...); } } template typename Epilogue, - typename... EpilogueArgs> -void cutlass_gemm_sm90_fp16_dispatch(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& bt_nzs, - torch::Tensor const& bt_meta, - EpilogueArgs&&... args) { - static_assert(std::is_same()); - TORCH_CHECK(a.dtype() == torch::kFloat16); - TORCH_CHECK(bt_meta.dtype() == torch::kUInt8); - TORCH_CHECK(bt_nzs.dtype() == torch::kFloat16); - - using Cutlass3xGemmDefault = - typename sm90_config_default::Cutlass3xGemm; - - // m in (128, inf) - return cutlass_sparse_gemm_caller( - out, a, bt_nzs, bt_meta, std::forward(args)...); -} - -template typename Epilogue, - typename... EpilogueArgs> -void cutlass_gemm_sm90_bf16_dispatch(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& bt_nzs, - torch::Tensor const& bt_meta, - EpilogueArgs&&... 
args) { - static_assert(std::is_same()); - TORCH_CHECK(a.dtype() == torch::kBFloat16); - TORCH_CHECK(bt_meta.dtype() == torch::kUInt8); - TORCH_CHECK(bt_nzs.dtype() == torch::kBFloat16); - + typename DispatchFunc, typename... Args> +typename DispatchFunc::return_type cutlass_gemm_sm90_16bit_dispatch( + uint32_t m, uint32_t n, Args&&... args) { using Cutlass3xGemmDefault = typename sm90_config_default::Cutlass3xGemm; - // m in (128, inf) - return cutlass_sparse_gemm_caller( - out, a, bt_nzs, bt_meta, std::forward(args)...); + return DispatchFunc::template invoke( + std::forward(args)...); } template typename Epilogue, - typename... EpilogueArgs> -void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& bt_nzs, - torch::Tensor const& bt_meta, - EpilogueArgs&&... args) { - static_assert(std::is_same()); - TORCH_CHECK(a.dtype() == torch::kInt8); - TORCH_CHECK(bt_meta.dtype() == torch::kUInt8); - TORCH_CHECK(bt_nzs.dtype() == torch::kInt8); + typename DispatchFunc, typename... Args> +typename DispatchFunc::return_type cutlass_gemm_sm90_int8_dispatch( + uint32_t m, uint32_t n, Args&&... 
args) { + static_assert(std::is_same_v); using Cutlass3xGemmDefault = typename sm90_config_default::Cutlass3xGemm; @@ -179,37 +157,35 @@ void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out, torch::Tensor const& a, typename sm90_int8_config_M32_NSmall::Cutlass3xGemm; - uint32_t const n = out.size(1); bool const is_small_n = n < 8192; - - uint32_t const m = a.size(0); uint32_t const mp2 = std::max(static_cast(32), next_pow_2(m)); // next power of 2 if (mp2 <= 32) { // m in [1, 32] if (is_small_n) { - return cutlass_sparse_gemm_caller( - out, a, bt_nzs, bt_meta, std::forward(args)...); + return DispatchFunc::template invoke( + std::forward(args)...); } else { - return cutlass_sparse_gemm_caller( - out, a, bt_nzs, bt_meta, std::forward(args)...); + return DispatchFunc::template invoke( + std::forward(args)...); } } else if (mp2 <= 64) { // m in (32, 64] - return cutlass_sparse_gemm_caller( - out, a, bt_nzs, bt_meta, std::forward(args)...); + return DispatchFunc::template invoke( + std::forward(args)...); } else if (mp2 <= 128) { // m in (64, 128] - return cutlass_sparse_gemm_caller( - out, a, bt_nzs, bt_meta, std::forward(args)...); + return DispatchFunc::template invoke( + std::forward(args)...); } else { // m in (128, inf) - return cutlass_sparse_gemm_caller( - out, a, bt_nzs, bt_meta, std::forward(args)...); + return DispatchFunc::template invoke( + std::forward(args)...); } } +// Dispatch to GEMM implementations based on element types template