Unverified Commit b3601da6 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

[Mypy] Fix mypy for `vllm/model_executor` (except `vllm/model_executor/layers`) (#37904)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent dc78c2c9
...@@ -29,7 +29,7 @@ SEPARATE_GROUPS = [ ...@@ -29,7 +29,7 @@ SEPARATE_GROUPS = [
"tests", "tests",
# v0 related # v0 related
"vllm/lora", "vllm/lora",
"vllm/model_executor", "vllm/model_executor/layers",
] ]
# TODO(woosuk): Include the code from Megatron and HuggingFace. # TODO(woosuk): Include the code from Megatron and HuggingFace.
......
...@@ -96,6 +96,7 @@ def sparse_attn_indexer( ...@@ -96,6 +96,7 @@ def sparse_attn_indexer(
topk_indices_buffer[: hidden_states.shape[0]] = -1 topk_indices_buffer[: hidden_states.shape[0]] = -1
if has_prefill: if has_prefill:
prefill_metadata = attn_metadata.prefill prefill_metadata = attn_metadata.prefill
assert prefill_metadata is not None
# Get the full shared workspace buffers once (will allocate on first use) # Get the full shared workspace buffers once (will allocate on first use)
workspace_manager = current_workspace_manager() workspace_manager = current_workspace_manager()
...@@ -170,6 +171,8 @@ def sparse_attn_indexer( ...@@ -170,6 +171,8 @@ def sparse_attn_indexer(
if has_decode: if has_decode:
decode_metadata = attn_metadata.decode decode_metadata = attn_metadata.decode
assert decode_metadata is not None
# kv_cache shape [
# kv_cache size requirement [num_block, block_size, n_head, head_dim], # kv_cache size requirement [num_block, block_size, n_head, head_dim],
# we only have [num_block, block_size, head_dim], # we only have [num_block, block_size, head_dim],
kv_cache = kv_cache.unsqueeze(-2) kv_cache = kv_cache.unsqueeze(-2)
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os import os
from collections.abc import Generator from collections.abc import Generator
from typing import TYPE_CHECKING, cast
import gguf import gguf
import regex as re import regex as re
...@@ -27,6 +28,9 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -27,6 +28,9 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.transformers_utils.gguf_utils import detect_gguf_multimodal from vllm.transformers_utils.gguf_utils import detect_gguf_multimodal
from vllm.utils.torch_utils import set_default_torch_dtype from vllm.utils.torch_utils import set_default_torch_dtype
if TYPE_CHECKING:
from vllm.model_executor.layers.quantization.gguf import GGUFConfig
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -350,10 +354,9 @@ class GGUFModelLoader(BaseModelLoader): ...@@ -350,10 +354,9 @@ class GGUFModelLoader(BaseModelLoader):
for name, weight_type in weight_type_map.items() for name, weight_type in weight_type_map.items()
if weight_type in ("F32", "F16", "BF16") and name.endswith(".weight") if weight_type in ("F32", "F16", "BF16") and name.endswith(".weight")
] ]
logger.debug( logger.debug("GGUF unquantized modules: %s", unquant_names)
"GGUF unquantized modules: %s", if TYPE_CHECKING:
unquant_names, vllm_config.quant_config = cast(GGUFConfig, vllm_config.quant_config)
)
vllm_config.quant_config.unquantized_modules.extend(unquant_names) vllm_config.quant_config.unquantized_modules.extend(unquant_names)
target_device = torch.device(device_config.device) target_device = torch.device(device_config.device)
......
...@@ -27,28 +27,16 @@ class RunaiModelStreamerLoader(BaseModelLoader): ...@@ -27,28 +27,16 @@ class RunaiModelStreamerLoader(BaseModelLoader):
def __init__(self, load_config: LoadConfig): def __init__(self, load_config: LoadConfig):
super().__init__(load_config) super().__init__(load_config)
self._is_distributed = False self._is_distributed: bool = False
if load_config.model_loader_extra_config: if load_config.model_loader_extra_config:
extra_config = load_config.model_loader_extra_config extra_config = load_config.model_loader_extra_config
if "distributed" in extra_config and isinstance( if isinstance(distributed := extra_config.get("distributed"), bool):
extra_config.get("distributed"), bool self._is_distributed = distributed
): if isinstance(concurrency := extra_config.get("concurrency"), int):
self._is_distributed = extra_config.get("distributed") os.environ["RUNAI_STREAMER_CONCURRENCY"] = str(concurrency)
if isinstance(memory_limit := extra_config.get("memory_limit"), int):
if "concurrency" in extra_config and isinstance( os.environ["RUNAI_STREAMER_MEMORY_LIMIT"] = str(memory_limit)
extra_config.get("concurrency"), int
):
os.environ["RUNAI_STREAMER_CONCURRENCY"] = str(
extra_config.get("concurrency")
)
if "memory_limit" in extra_config and isinstance(
extra_config.get("memory_limit"), int
):
os.environ["RUNAI_STREAMER_MEMORY_LIMIT"] = str(
extra_config.get("memory_limit")
)
runai_streamer_s3_endpoint = os.getenv("RUNAI_STREAMER_S3_ENDPOINT") runai_streamer_s3_endpoint = os.getenv("RUNAI_STREAMER_S3_ENDPOINT")
aws_endpoint_url = os.getenv("AWS_ENDPOINT_URL") aws_endpoint_url = os.getenv("AWS_ENDPOINT_URL")
...@@ -93,7 +81,7 @@ class RunaiModelStreamerLoader(BaseModelLoader): ...@@ -93,7 +81,7 @@ class RunaiModelStreamerLoader(BaseModelLoader):
return hf_weights_files return hf_weights_files
def _get_weights_iterator( def _get_weights_iterator(
self, model_or_path: str, revision: str self, model_or_path: str, revision: str | None
) -> Generator[tuple[str, torch.Tensor], None, None]: ) -> Generator[tuple[str, torch.Tensor], None, None]:
"""Get an iterator for the model weights based on the load format.""" """Get an iterator for the model weights based on the load format."""
hf_weights_files = self._prepare_weights(model_or_path, revision) hf_weights_files = self._prepare_weights(model_or_path, revision)
......
...@@ -6,6 +6,7 @@ import glob ...@@ -6,6 +6,7 @@ import glob
import os import os
import time import time
from collections.abc import Generator from collections.abc import Generator
from copy import copy
from typing import Any from typing import Any
import torch import torch
...@@ -42,7 +43,7 @@ class ShardedStateLoader(BaseModelLoader): ...@@ -42,7 +43,7 @@ class ShardedStateLoader(BaseModelLoader):
extra_config = ( extra_config = (
{} {}
if load_config.model_loader_extra_config is None if load_config.model_loader_extra_config is None
else load_config.model_loader_extra_config.copy() else copy(load_config.model_loader_extra_config)
) )
self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN) self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN)
if extra_config: if extra_config:
......
...@@ -674,7 +674,8 @@ def serialize_vllm_model( ...@@ -674,7 +674,8 @@ def serialize_vllm_model(
key = f.read() key = f.read()
encryption_params = EncryptionParams(key=key) encryption_params = EncryptionParams(key=key)
output_file = tensorizer_args.tensorizer_uri if (output_file := tensorizer_args.tensorizer_uri) is None:
raise ValueError("tensorizer_uri must be specified for serialization.")
if tensorizer_config._is_sharded: if tensorizer_config._is_sharded:
from vllm.distributed import get_tensor_model_parallel_rank from vllm.distributed import get_tensor_model_parallel_rank
......
...@@ -121,6 +121,7 @@ class TensorizerLoader(BaseModelLoader): ...@@ -121,6 +121,7 @@ class TensorizerLoader(BaseModelLoader):
if parallel_config.tensor_parallel_size > 1: if parallel_config.tensor_parallel_size > 1:
from vllm.distributed import get_tensor_model_parallel_rank from vllm.distributed import get_tensor_model_parallel_rank
assert self.tensorizer_config.tensorizer_uri is not None
self.tensorizer_config.tensorizer_uri = ( self.tensorizer_config.tensorizer_uri = (
self.tensorizer_config.tensorizer_uri % get_tensor_model_parallel_rank() self.tensorizer_config.tensorizer_uri % get_tensor_model_parallel_rank()
) )
......
...@@ -6,6 +6,7 @@ import inspect ...@@ -6,6 +6,7 @@ import inspect
import warnings import warnings
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any
import torch import torch
from torch import nn from torch import nn
...@@ -71,7 +72,7 @@ def initialize_model( ...@@ -71,7 +72,7 @@ def initialize_model(
model_class, model_class,
) )
# try to be compatible with old-style model class # try to be compatible with old-style model class
kwargs = {} kwargs: dict[str, Any] = {}
if "prefix" in all_params: if "prefix" in all_params:
kwargs["prefix"] = prefix kwargs["prefix"] = prefix
if "config" in all_params: if "config" in all_params:
......
...@@ -257,6 +257,8 @@ def convert_bin_to_safetensor_file( ...@@ -257,6 +257,8 @@ def convert_bin_to_safetensor_file(
def get_quant_config( def get_quant_config(
model_config: ModelConfig, load_config: LoadConfig model_config: ModelConfig, load_config: LoadConfig
) -> QuantizationConfig: ) -> QuantizationConfig:
if model_config.quantization is None:
raise ValueError("Model quantization method is not specified in the config.")
quant_cls = get_quantization_config(model_config.quantization) quant_cls = get_quantization_config(model_config.quantization)
# GGUF doesn't have config file # GGUF doesn't have config file
...@@ -307,6 +309,11 @@ def get_quant_config( ...@@ -307,6 +309,11 @@ def get_quant_config(
# if hf_quant_config is None, we will try to get config from # if hf_quant_config is None, we will try to get config from
# hf_overrides # hf_overrides
hf_overrides = model_config.hf_overrides hf_overrides = model_config.hf_overrides
if not isinstance(hf_overrides, dict):
raise ValueError(
"hf_overrides must be a dict for get_quant_config "
"to get the quantization config from it."
)
quantization_config_file = hf_overrides.get("quantization_config_file", None) quantization_config_file = hf_overrides.get("quantization_config_file", None)
if quantization_config_file is not None: if quantization_config_file is not None:
if hasattr(quant_cls, "from_config_file"): if hasattr(quant_cls, "from_config_file"):
...@@ -1087,7 +1094,7 @@ def multi_thread_pt_weights_iterator( ...@@ -1087,7 +1094,7 @@ def multi_thread_pt_weights_iterator(
def get_gguf_extra_tensor_names( def get_gguf_extra_tensor_names(
gguf_file: str, gguf_to_hf_name_map: dict[str, str] gguf_file: str | Path, gguf_to_hf_name_map: dict[str, str]
) -> list[str]: ) -> list[str]:
reader = gguf.GGUFReader(gguf_file) reader = gguf.GGUFReader(gguf_file)
expected_gguf_keys = set(gguf_to_hf_name_map.keys()) expected_gguf_keys = set(gguf_to_hf_name_map.keys())
...@@ -1097,7 +1104,7 @@ def get_gguf_extra_tensor_names( ...@@ -1097,7 +1104,7 @@ def get_gguf_extra_tensor_names(
def get_gguf_weight_type_map( def get_gguf_weight_type_map(
gguf_file: str, gguf_to_hf_name_map: dict[str, str] gguf_file: str | Path, gguf_to_hf_name_map: dict[str, str]
) -> dict[str, str]: ) -> dict[str, str]:
""" """
Return GGUF mapped weight's name and its quant type Return GGUF mapped weight's name and its quant type
...@@ -1111,7 +1118,7 @@ def get_gguf_weight_type_map( ...@@ -1111,7 +1118,7 @@ def get_gguf_weight_type_map(
def gguf_quant_weights_iterator( def gguf_quant_weights_iterator(
gguf_file: str, gguf_to_hf_name_map: dict[str, str] gguf_file: str | Path, gguf_to_hf_name_map: dict[str, str]
) -> Generator[tuple[str, torch.Tensor], None, None]: ) -> Generator[tuple[str, torch.Tensor], None, None]:
""" """
Iterate over the quant weights in the model gguf files and convert Iterate over the quant weights in the model gguf files and convert
......
...@@ -154,8 +154,8 @@ class _ColumnvLLMParameter(BasevLLMParameter): ...@@ -154,8 +154,8 @@ class _ColumnvLLMParameter(BasevLLMParameter):
self.data.copy_(loaded_weight) self.data.copy_(loaded_weight)
def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs): def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
shard_offset = kwargs.get("shard_offset") shard_offset: int = kwargs["shard_offset"]
shard_size = kwargs.get("shard_size") shard_size: int = kwargs["shard_size"]
# TODO: move these to PackedColumnParameter and PackedvLLMParameter # TODO: move these to PackedColumnParameter and PackedvLLMParameter
if ( if (
...@@ -176,10 +176,10 @@ class _ColumnvLLMParameter(BasevLLMParameter): ...@@ -176,10 +176,10 @@ class _ColumnvLLMParameter(BasevLLMParameter):
param_data.copy_(loaded_weight) param_data.copy_(loaded_weight)
def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs): def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
shard_offset = kwargs.get("shard_offset") shard_offset: int = kwargs["shard_offset"]
shard_size = kwargs.get("shard_size") shard_size: int = kwargs["shard_size"]
shard_id = kwargs.get("shard_id") shard_id: str = kwargs["shard_id"]
num_heads = kwargs.get("num_heads") num_heads: int = kwargs["num_heads"]
# TODO: move these to PackedColumnParameter and PackedvLLMParameter # TODO: move these to PackedColumnParameter and PackedvLLMParameter
if ( if (
...@@ -191,10 +191,10 @@ class _ColumnvLLMParameter(BasevLLMParameter): ...@@ -191,10 +191,10 @@ class _ColumnvLLMParameter(BasevLLMParameter):
) )
param_data = self.data param_data = self.data
shard_id = self.tp_rank if shard_id == "q" else self.tp_rank // num_heads shard_id_int = self.tp_rank if shard_id == "q" else self.tp_rank // num_heads
param_data = param_data.narrow(self.output_dim, shard_offset, shard_size) param_data = param_data.narrow(self.output_dim, shard_offset, shard_size)
loaded_weight = loaded_weight.narrow( loaded_weight = loaded_weight.narrow(
self.output_dim, shard_id * shard_size, shard_size self.output_dim, shard_id_int * shard_size, shard_size
) )
assert param_data.shape == loaded_weight.shape assert param_data.shape == loaded_weight.shape
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment