Unverified Commit 6223dd81 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Update deprecated type hinting in `model_executor/layers` (#18056)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 906f0598
......@@ -80,7 +80,6 @@ exclude = [
"vllm/engine/**/*.py" = ["UP006", "UP035"]
"vllm/executor/**/*.py" = ["UP006", "UP035"]
"vllm/lora/**/*.py" = ["UP006", "UP035"]
"vllm/model_executor/layers/**/*.py" = ["UP006", "UP035"]
"vllm/model_executor/model_loader/**/*.py" = ["UP006", "UP035"]
"vllm/model_executor/models/**/*.py" = ["UP006", "UP035"]
"vllm/platforms/**/*.py" = ["UP006", "UP035"]
......
# SPDX-License-Identifier: Apache-2.0
from contextlib import contextmanager
from typing import Any, Dict, Optional
from typing import Any, Optional
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.triton_utils import HAS_TRITON
_config: Optional[Dict[str, Any]] = None
_config: Optional[dict[str, Any]] = None
@contextmanager
......@@ -19,7 +19,7 @@ def override_config(config):
_config = old_config
def get_config() -> Optional[Dict[str, Any]]:
def get_config() -> Optional[dict[str, Any]]:
return _config
......
# SPDX-License-Identifier: Apache-2.0
import importlib.util
from typing import Optional, Tuple
from typing import Optional
import torch
......@@ -61,7 +61,7 @@ def _moe_permute(
global_num_experts: int,
expert_map: Optional[torch.Tensor],
block_m: int,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor,
) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor,
Optional[torch.Tensor]]:
"""
Determine the sorted_token_ids, expert_ids for the given problem size.
......
......@@ -3,7 +3,7 @@
import functools
import json
import os
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Optional
import torch
......@@ -472,14 +472,14 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
mul_routed_weight: bool,
top_k: int,
config: Dict[str, Any],
config: dict[str, Any],
compute_type: tl.dtype,
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
use_int4_w4a16: bool,
per_channel_quant: bool,
block_shape: Optional[List[int]] = None) -> None:
block_shape: Optional[list[int]] = None) -> None:
assert topk_weights is not None or not mul_routed_weight
assert topk_weights is None or topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1
......@@ -622,7 +622,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
def get_config_file_name(E: int,
N: int,
dtype: Optional[str],
block_shape: Optional[List[int]] = None) -> str:
block_shape: Optional[list[int]] = None) -> str:
device_name = current_platform.get_device_name().replace(" ", "_")
dtype_selector = "" if not dtype else f",dtype={dtype}"
block_shape_selector = ("" if not block_shape or not all(block_shape) else
......@@ -638,7 +638,7 @@ def get_moe_configs(
dtype: Optional[str],
block_n: Optional[int] = None,
block_k: Optional[int] = None,
) -> Optional[Dict[int, Any]]:
) -> Optional[dict[int, Any]]:
"""
Return optimized configurations for the fused MoE kernel.
......@@ -670,7 +670,7 @@ def get_moe_configs(
return None
def get_moe_wna16_block_config(config: Dict[str,
def get_moe_wna16_block_config(config: dict[str,
int], use_moe_wna16_cuda: bool,
num_valid_tokens: int, size_k: int, size_n: int,
num_experts: int, group_size: int,
......@@ -742,8 +742,8 @@ def get_default_config(
topk: int,
dtype: Optional[str],
is_marlin: bool,
block_shape: Optional[List[int]] = None,
) -> Dict[str, int]:
block_shape: Optional[list[int]] = None,
) -> dict[str, int]:
if dtype == "fp8_w8a8" and block_shape is not None:
# Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0]
# BLOCK_SIZE_K must be divisible by block_shape[1]
......@@ -795,13 +795,13 @@ def get_default_config(
def try_get_optimal_moe_config(
w1_shape: Tuple[int, ...],
w2_shape: Tuple[int, ...],
w1_shape: tuple[int, ...],
w2_shape: tuple[int, ...],
top_k: int,
dtype: Optional[str],
M: int,
is_marlin: bool = False,
block_shape: Optional[List[int]] = None,
block_shape: Optional[list[int]] = None,
):
from vllm.model_executor.layers.fused_moe import get_config
override_config = get_config()
......@@ -855,7 +855,7 @@ def fused_topk(
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")
......@@ -895,7 +895,7 @@ def grouped_topk(
topk_group: int = 0,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")
......@@ -982,7 +982,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None) -> None:
block_shape: Optional[list[int]] = None) -> None:
fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
activation, apply_router_weight_on_input, use_fp8_w8a8,
use_int8_w8a8, use_int8_w8a16, use_int4_w4a16,
......@@ -1012,7 +1012,7 @@ def inplace_fused_experts_fake(
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None) -> None:
block_shape: Optional[list[int]] = None) -> None:
pass
......@@ -1046,7 +1046,7 @@ def outplace_fused_experts(
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None) -> torch.Tensor:
block_shape: Optional[list[int]] = None) -> torch.Tensor:
return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
False, activation, apply_router_weight_on_input,
use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16,
......@@ -1076,7 +1076,7 @@ def outplace_fused_experts_fake(
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None) -> torch.Tensor:
block_shape: Optional[list[int]] = None) -> torch.Tensor:
return torch.empty_like(hidden_states)
......@@ -1129,7 +1129,7 @@ def fused_experts(hidden_states: torch.Tensor,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
block_shape: Optional[list[int]] = None,
allow_deep_gemm: bool = False) -> torch.Tensor:
if (allow_deep_gemm and use_fp8_w8a8
and _valid_deep_gemm(hidden_states, w1, w2, expert_map)):
......@@ -1184,8 +1184,8 @@ def moe_kernel_prepare_input(
use_int8_w8a16: bool,
use_int4_w4a16: bool,
per_channel_quant: bool,
block_shape: Optional[List[int]] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
block_shape: Optional[list[int]] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
if use_fp8_w8a8:
assert B_scale is not None
if block_shape is None:
......@@ -1248,7 +1248,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None):
block_shape: Optional[list[int]] = None):
# Check constraints.
if use_int4_w4a16:
assert hidden_states.shape[1] // 2 == w1.shape[
......@@ -1452,7 +1452,7 @@ def fused_moe(
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
block_shape: Optional[list[int]] = None,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
......@@ -1497,7 +1497,7 @@ def fused_moe(
a1.
- a2_scale (Optional[torch.Tensor]): Optional scale to be used for
a2.
- block_shape: (Optional[List[int]]): Optional block size for block-wise
- block_shape: (Optional[list[int]]): Optional block size for block-wise
quantization.
Returns:
......
......@@ -2,7 +2,7 @@
from abc import abstractmethod
from enum import Enum
from typing import Callable, List, Optional, Tuple
from typing import Callable, Optional
import torch
import torch.nn.functional as F
......@@ -326,7 +326,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def determine_expert_map(
ep_size: int, ep_rank: int,
global_num_experts: int) -> Tuple[int, Optional[torch.Tensor]]:
global_num_experts: int) -> tuple[int, Optional[torch.Tensor]]:
"""
Calculates how many experts should be assigned to each rank for EP and
creates a mapping from global to local expert index. Experts are
......@@ -338,7 +338,7 @@ def determine_expert_map(
global_num_experts (int): The total number of experts in the model.
Returns:
Tuple[int, Optional[torch.Tensor]]: A tuple containing:
tuple[int, Optional[torch.Tensor]]: A tuple containing:
- local_num_experts (int): The number of experts assigned
to the current rank.
- expert_map (Optional[torch.Tensor]): A tensor of shape
......@@ -909,7 +909,7 @@ class FusedMoE(torch.nn.Module):
def make_expert_params_mapping(
cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str,
ckpt_up_proj_name: str,
num_experts: int) -> List[Tuple[str, str, int, str]]:
num_experts: int) -> list[tuple[str, str, int, str]]:
return [
# (param_name, weight_name, expert_id, shard_id)
......
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple
from typing import Optional
import torch
......@@ -153,7 +153,7 @@ def moe_align_block_size(
num_experts: int,
expert_map: Optional[torch.Tensor] = None,
pad_sorted_ids: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Aligns the token distribution across experts to be compatible with block
size for matrix multiplication.
......
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple
from typing import Optional
import torch
......@@ -15,7 +15,7 @@ def moe_permute(
expert_map: Optional[torch.Tensor] = None,
align_block_size: Optional[int] = None,
fill_invalid_expert: int = -1
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
This function expands and permutes activation to gather uncontinuous tokens
for each expert.
......
# SPDX-License-Identifier: Apache-2.0
from functools import cache
from typing import List, Optional, Tuple
from typing import Optional
import torch
......@@ -97,7 +97,7 @@ def rocm_aiter_fmoe_fp8_blockscale_g1u1_impl(
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
a1_scale: torch.Tensor,
block_shape: List[int],
block_shape: list[int],
smooth_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
from aiter import fmoe_fp8_blockscale_g1u1
from aiter.fused_moe_bf16_asm import moe_sorting_ck
......@@ -142,7 +142,7 @@ def rocm_aiter_fmoe_fp8_blockscale_g1u1_fake(
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
a1_scale: torch.Tensor,
block_shape: List[int],
block_shape: list[int],
smooth_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
return torch.empty_like(a1, dtype=hidden_states_dtype)
......@@ -280,7 +280,7 @@ def rocm_aiter_fused_experts(hidden_states: torch.Tensor,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
block_shape: Optional[list[int]] = None,
allow_deep_gemm: bool = False) -> torch.Tensor:
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
......@@ -372,14 +372,14 @@ def rocm_aiter_topk_softmax(topk_weights: torch.Tensor,
topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor,
gating_output: torch.Tensor,
renormalize: bool) -> Tuple[torch.Tensor, ...]:
renormalize: bool) -> tuple[torch.Tensor, ...]:
torch.ops.vllm.rocm_aiter_topk_softmax(topk_weights, topk_indices,
token_expert_indices, gating_output,
renormalize)
return topk_weights, topk_indices
def shuffle_weights(*tensors: torch.Tensor) -> Tuple[torch.Tensor, ...]:
def shuffle_weights(*tensors: torch.Tensor) -> tuple[torch.Tensor, ...]:
"""
Applies shuffle_weight function from AITER to each
input tensor and returns them.
......@@ -395,7 +395,7 @@ def shuffle_weights(*tensors: torch.Tensor) -> Tuple[torch.Tensor, ...]:
def expand_weights(*tensors: torch.Tensor,
expansion_dims: list[int]) -> Tuple[torch.Tensor, ...]:
expansion_dims: list[int]) -> tuple[torch.Tensor, ...]:
"""
Expands the dimensions of input tensors.
......
# SPDX-License-Identifier: Apache-2.0
from math import prod
from typing import List, Optional, Tuple
from typing import Optional
import torch
......@@ -10,7 +10,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
from vllm.utils import cdiv
def _resize_cache(x: torch.Tensor, v: Tuple[int, ...]) -> torch.Tensor:
def _resize_cache(x: torch.Tensor, v: tuple[int, ...]) -> torch.Tensor:
"""
Shrink the given tensor and apply the given view to it. This is
used to resize the intermediate fused_moe caches.
......@@ -22,8 +22,8 @@ def _resize_cache(x: torch.Tensor, v: Tuple[int, ...]) -> torch.Tensor:
def _fp8_quantize(
A: torch.Tensor,
A_scale: Optional[torch.Tensor],
block_shape: Optional[List[int]],
) -> Tuple[torch.Tensor, torch.Tensor]:
block_shape: Optional[list[int]],
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Perform fp8 quantization on the inputs. If a block_shape
is provided, the output will be blocked.
......
# SPDX-License-Identifier: Apache-2.0
"""Custom normalization layers."""
from typing import Optional, Tuple, Union
from typing import Optional, Union
import torch
import torch.nn as nn
......@@ -31,7 +31,7 @@ def rms_norm(x: torch.Tensor, weight: torch.Tensor,
def fused_add_rms_norm(
x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
variance_epsilon: float) -> Tuple[torch.Tensor, torch.Tensor]:
variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:
from vllm import _custom_ops as ops
ops.fused_add_rms_norm(
x,
......@@ -57,7 +57,7 @@ def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor,
def rocm_aiter_fused_add_rms_norm(
x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
variance_epsilon: float) -> Tuple[torch.Tensor, torch.Tensor]:
variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:
import aiter as rocm_aiter
......@@ -119,7 +119,7 @@ class RMSNorm(CustomOp):
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
"""PyTorch-native implementation equivalent to forward()."""
orig_dtype = x.dtype
x = x.to(torch.float32)
......@@ -157,7 +157,7 @@ class RMSNorm(CustomOp):
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if self.variance_size_override is not None:
return self.forward_native(x, residual)
......@@ -174,7 +174,7 @@ class RMSNorm(CustomOp):
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
from vllm_hpu_extension.kernels import rms_norm
HPUFusedRMSNorm = rms_norm()
if HPUFusedRMSNorm is None:
......@@ -194,7 +194,7 @@ class RMSNorm(CustomOp):
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if self.variance_size_override is not None:
return self.forward_native(x, residual)
......@@ -244,7 +244,7 @@ class GemmaRMSNorm(CustomOp):
variance_epsilon: float,
x: torch.Tensor,
residual: Optional[torch.Tensor],
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
"""PyTorch-native implementation equivalent to forward()."""
orig_dtype = x.dtype
if residual is not None:
......@@ -267,7 +267,7 @@ class GemmaRMSNorm(CustomOp):
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
"""PyTorch-native implementation equivalent to forward()."""
return self.forward_static(self.weight.data, self.variance_epsilon, x,
residual)
......@@ -276,7 +276,7 @@ class GemmaRMSNorm(CustomOp):
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if torch.compiler.is_compiling():
return self.forward_native(x, residual)
......
# SPDX-License-Identifier: Apache-2.0
from typing import List, Optional, Tuple, Union
from typing import Optional, Union
import torch
from torch import nn
......@@ -104,7 +104,7 @@ class Mixer2RMSNormGated(CustomOp):
self,
x: torch.Tensor,
gate: torch.Tensor,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if self.tp_size > 1 or self.n_groups != 1:
return self.forward_native(x, gate)
......@@ -136,7 +136,7 @@ def extra_groups_for_head_shards(ngroups: int, tp_size: int):
def mamba_v2_sharded_weight_loader(
shard_spec: List[Tuple[int, int, float]],
shard_spec: list[tuple[int, int, float]],
tp_size: int,
tp_rank: int,
) -> LoaderFunction:
......
# SPDX-License-Identifier: Apache-2.0
from enum import IntEnum
from typing import List, Optional, Union
from typing import Optional, Union
import torch
import torch.nn as nn
......@@ -46,7 +46,7 @@ class SimplePooler(nn.Module):
normalize: bool,
softmax: bool,
step_tag_id: Optional[int] = None,
returned_token_ids: Optional[List[int]] = None,
returned_token_ids: Optional[list[int]] = None,
) -> "SimplePooler":
if pooling_type == PoolingType.LAST:
assert step_tag_id is None and returned_token_ids is None
......@@ -174,7 +174,7 @@ class StepPool(SimplePooler):
normalize: bool,
softmax: bool,
step_tag_id: Optional[int] = None,
returned_token_ids: Optional[List[int]] = None,
returned_token_ids: Optional[list[int]] = None,
):
super().__init__(normalize=normalize, softmax=softmax)
......@@ -259,7 +259,7 @@ class Pooler(nn.Module):
normalize: bool,
softmax: bool,
step_tag_id: Optional[int] = None,
returned_token_ids: Optional[List[int]] = None,
returned_token_ids: Optional[list[int]] = None,
) -> SimplePooler:
return SimplePooler.from_pooling_type(
pooling_type=PoolingType[pooler_config.pooling_type]
......
# SPDX-License-Identifier: Apache-2.0
from typing import Literal, Type, get_args
from typing import Literal, get_args
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
......@@ -76,7 +76,7 @@ def register_quantization_config(quantization: str):
return _wrapper
def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
if quantization not in QUANTIZATION_METHODS:
raise ValueError(f"Invalid quantization method: {quantization}")
......@@ -110,7 +110,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
from .torchao import TorchAOConfig
from .tpu_int8 import Int8TpuConfig
method_to_config: dict[str, Type[QuantizationConfig]] = {
method_to_config: dict[str, type[QuantizationConfig]] = {
"aqlm": AQLMConfig,
"awq": AWQConfig,
"deepspeedfp": DeepSpeedFPConfig,
......
......@@ -4,7 +4,7 @@
# and https://arxiv.org/pdf/2401.06118.pdf
import math
from typing import Any, Dict, List, Optional
from typing import Any, Optional
import torch
import torch.nn.functional as F
......@@ -98,7 +98,7 @@ def generic_dequantize_gemm(
codebooks: torch.
Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
output_partition_sizes: List[int],
output_partition_sizes: list[int],
bias: Optional[torch.Tensor],
) -> torch.Tensor:
output_shape = input.shape[:-1] + (scales.shape[0], )
......@@ -136,7 +136,7 @@ def optimized_dequantize_gemm(
codebooks: torch.
Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
output_partition_sizes: List[int],
output_partition_sizes: list[int],
bias: Optional[torch.Tensor],
) -> torch.Tensor:
weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
......@@ -191,7 +191,7 @@ class AQLMConfig(QuantizationConfig):
return "aqlm"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half]
@classmethod
......@@ -199,11 +199,11 @@ class AQLMConfig(QuantizationConfig):
return 60
@classmethod
def get_config_filenames(cls) -> List[str]:
def get_config_filenames(cls) -> list[str]:
return [] # no extra configs.
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "AQLMConfig":
def from_config(cls, config: dict[str, Any]) -> "AQLMConfig":
in_group_size = cls.get_from_keys(config, ["in_group_size"])
nbits_per_codebook = cls.get_from_keys(config, ["nbits_per_codebook"])
num_code_books = cls.get_from_keys(config, ["num_codebooks"])
......@@ -230,7 +230,7 @@ class AQLMLinearMethod(LinearMethodBase):
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int,
output_partition_sizes: list[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
del output_size # Unused.
......
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional
from typing import Any, Optional
import torch
......@@ -25,7 +25,7 @@ class AWQConfig(QuantizationConfig):
weight_bits: int,
group_size: int,
zero_point: bool,
modules_to_not_convert: Optional[List[str]] = None,
modules_to_not_convert: Optional[list[str]] = None,
) -> None:
super().__init__()
self.weight_bits = weight_bits
......@@ -48,7 +48,7 @@ class AWQConfig(QuantizationConfig):
def get_name(self) -> QuantizationMethods:
return "awq"
def get_supported_act_dtypes(self) -> List[torch.dtype]:
def get_supported_act_dtypes(self) -> list[torch.dtype]:
return [torch.half]
@classmethod
......@@ -57,7 +57,7 @@ class AWQConfig(QuantizationConfig):
return 75
@staticmethod
def get_config_filenames() -> List[str]:
def get_config_filenames() -> list[str]:
return [
"quant_config.json", # E.g., casperhansen/vicuna-7b-v1.5-awq
# E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq
......@@ -65,7 +65,7 @@ class AWQConfig(QuantizationConfig):
]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
def from_config(cls, config: dict[str, Any]) -> "AWQConfig":
weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
zero_point = cls.get_from_keys(config, ["zero_point"])
......@@ -82,7 +82,7 @@ class AWQConfig(QuantizationConfig):
return None
def is_layer_skipped_awq(prefix: str, modules_to_not_convert: List[str]):
def is_layer_skipped_awq(prefix: str, modules_to_not_convert: list[str]):
return any(module_name in prefix for module_name in modules_to_not_convert)
......@@ -98,7 +98,7 @@ class AWQLinearMethod(LinearMethodBase):
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int,
output_partition_sizes: list[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
if input_size_per_partition % self.quant_config.group_size != 0:
......
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Optional
import torch
from torch.nn import Parameter
......@@ -46,8 +46,8 @@ class AWQMarlinConfig(QuantizationConfig):
def __init__(self, weight_bits: int, group_size: int, zero_point: bool,
lm_head_quantized: bool,
modules_to_not_convert: Optional[List[str]],
full_config: Dict[str, Any]) -> None:
modules_to_not_convert: Optional[list[str]],
full_config: dict[str, Any]) -> None:
super().__init__()
self.pack_factor = 32 // weight_bits # packed into int32
self.group_size = group_size
......@@ -79,7 +79,7 @@ class AWQMarlinConfig(QuantizationConfig):
return "awq_marlin"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half, torch.bfloat16]
@classmethod
......@@ -87,11 +87,11 @@ class AWQMarlinConfig(QuantizationConfig):
return 80
@classmethod
def get_config_filenames(cls) -> List[str]:
def get_config_filenames(cls) -> list[str]:
return ["quantize_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "AWQMarlinConfig":
def from_config(cls, config: dict[str, Any]) -> "AWQMarlinConfig":
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"])
zero_point = cls.get_from_keys(config, ["zero_point"])
......@@ -150,7 +150,7 @@ class AWQMarlinConfig(QuantizationConfig):
return None
@classmethod
def is_awq_marlin_compatible(cls, quant_config: Dict[str, Any]):
def is_awq_marlin_compatible(cls, quant_config: dict[str, Any]):
# Extract data from quant config.
quant_method = quant_config.get("quant_method", "").lower()
num_bits = quant_config.get("bits")
......@@ -189,7 +189,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
......
......@@ -2,7 +2,7 @@
import inspect
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
from typing import TYPE_CHECKING, Any, Optional
import torch
from torch import nn
......@@ -48,7 +48,7 @@ class QuantizeMethodBase(ABC):
def method_has_implemented_embedding(
method_class: Type[QuantizeMethodBase]) -> bool:
method_class: type[QuantizeMethodBase]) -> bool:
"""
Not all quant methods have embedding implemented, so we need to check that
it exists for our given method. We check this by making sure the function
......@@ -68,7 +68,7 @@ class QuantizationConfig(ABC):
def __init__(self):
super().__init__()
# mapping is updated by models as they initialize
self.packed_modules_mapping: Dict[str, List[str]] = dict()
self.packed_modules_mapping: dict[str, list[str]] = dict()
@abstractmethod
def get_name(self) -> QuantizationMethods:
......@@ -76,7 +76,7 @@ class QuantizationConfig(ABC):
raise NotImplementedError
@abstractmethod
def get_supported_act_dtypes(self) -> List[torch.dtype]:
def get_supported_act_dtypes(self) -> list[torch.dtype]:
"""List of supported activation dtypes."""
raise NotImplementedError
......@@ -93,13 +93,13 @@ class QuantizationConfig(ABC):
@staticmethod
@abstractmethod
def get_config_filenames() -> List[str]:
def get_config_filenames() -> list[str]:
"""List of filenames to search for in the model directory."""
raise NotImplementedError
@classmethod
@abstractmethod
def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig":
def from_config(cls, config: dict[str, Any]) -> "QuantizationConfig":
"""Create a config class from the model's quantization config."""
raise NotImplementedError
......@@ -115,7 +115,7 @@ class QuantizationConfig(ABC):
return None
@staticmethod
def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any:
def get_from_keys(config: dict[str, Any], keys: list[str]) -> Any:
"""Get a value from the model's quantization config."""
for key in keys:
if key in config:
......@@ -124,7 +124,7 @@ class QuantizationConfig(ABC):
"quantization config.")
@staticmethod
def get_from_keys_or(config: Dict[str, Any], keys: List[str],
def get_from_keys_or(config: dict[str, Any], keys: list[str],
default: Any) -> Any:
"""Get a optional value from the model's quantization config."""
try:
......
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional
from typing import Any, Optional
import torch
......@@ -105,7 +105,7 @@ class BitBLASConfig(QuantizationConfig):
return "bitblas"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half, torch.bfloat16]
@classmethod
......@@ -114,12 +114,12 @@ class BitBLASConfig(QuantizationConfig):
return 70
@classmethod
def get_config_filenames(cls) -> List[str]:
def get_config_filenames(cls) -> list[str]:
return ["quantize_config.json"]
@staticmethod
def get_from_keys(config: Dict[str, Any],
keys: List[str],
def get_from_keys(config: dict[str, Any],
keys: list[str],
default: Any = None) -> Any:
"""Get a value from the model's quantization config."""
for key in keys:
......@@ -128,7 +128,7 @@ class BitBLASConfig(QuantizationConfig):
return default
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "BitBLASConfig":
def from_config(cls, config: dict[str, Any]) -> "BitBLASConfig":
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"], -1)
desc_act = cls.get_from_keys(config, ["desc_act"], False)
......@@ -193,7 +193,7 @@ class BitBLASLinearMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
......@@ -329,7 +329,7 @@ class BitBLASLinearMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
......
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional
from typing import Any, Optional
import torch
......@@ -29,7 +29,7 @@ class BitsAndBytesConfig(QuantizationConfig):
bnb_4bit_use_double_quant: bool = False,
llm_int8_enable_fp32_cpu_offload: bool = False,
llm_int8_has_fp16_weight: bool = False,
llm_int8_skip_modules: Optional[List[str]] = None,
llm_int8_skip_modules: Optional[list[str]] = None,
llm_int8_threshold: float = 6.0,
) -> None:
super().__init__()
......@@ -61,7 +61,7 @@ class BitsAndBytesConfig(QuantizationConfig):
return "bitsandbytes"
@classmethod
def get_supported_act_dtypes(self) -> List[torch.dtype]:
def get_supported_act_dtypes(self) -> list[torch.dtype]:
return [torch.float32, torch.float16, torch.bfloat16]
@classmethod
......@@ -69,13 +69,13 @@ class BitsAndBytesConfig(QuantizationConfig):
return 70
@staticmethod
def get_config_filenames() -> List[str]:
def get_config_filenames() -> list[str]:
return [
"adapter_config.json",
]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "BitsAndBytesConfig":
def from_config(cls, config: dict[str, Any]) -> "BitsAndBytesConfig":
def get_safe_value(config, keys, default_value=None):
try:
......@@ -130,7 +130,7 @@ class BitsAndBytesConfig(QuantizationConfig):
return None
def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: List[str]):
def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: list[str]):
# Split the prefix into its dot-separated components
components = prefix.split('.')
......@@ -169,7 +169,7 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int,
output_partition_sizes: list[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
from bitsandbytes.nn import Int8Params
......
# SPDX-License-Identifier: Apache-2.0
from contextlib import suppress
from typing import Any, Dict, List, Literal, Optional, Tuple, cast
from typing import Any, Literal, Optional, cast
import torch
from compressed_tensors.config import (CompressionFormat,
......@@ -38,20 +38,20 @@ logger = init_logger(__name__)
__all__ = ["CompressedTensorsLinearMethod"]
SPARSITY_CONFIG_NAME: Literal["sparsity_config"] = "sparsity_config"
QUANTIZATION_SCHEME_MAP_TYPE = Dict[str, Optional[Dict[str, QuantizationArgs]]]
QUANTIZATION_SCHEME_MAP_TYPE = dict[str, Optional[dict[str, QuantizationArgs]]]
class CompressedTensorsConfig(QuantizationConfig):
def __init__(
self,
target_scheme_map: Dict[str, Any],
ignore: List[str],
target_scheme_map: dict[str, Any],
ignore: list[str],
quant_format: str,
sparsity_scheme_map: Dict[str, SparsityCompressionConfig],
sparsity_ignore_list: List[str],
kv_cache_scheme: Optional[Dict[str, Any]] = None,
config: Optional[Dict[str, Any]] = None,
sparsity_scheme_map: dict[str, SparsityCompressionConfig],
sparsity_ignore_list: list[str],
kv_cache_scheme: Optional[dict[str, Any]] = None,
config: Optional[dict[str, Any]] = None,
):
super().__init__()
self.ignore = ignore
......@@ -66,7 +66,7 @@ class CompressedTensorsConfig(QuantizationConfig):
def get_linear_method(self) -> "CompressedTensorsLinearMethod":
return CompressedTensorsLinearMethod(self)
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.float16, torch.bfloat16]
@classmethod
......@@ -102,8 +102,8 @@ class CompressedTensorsConfig(QuantizationConfig):
return None
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
ignore: List[str] = cast(List[str], config.get("ignore", []))
def from_config(cls, config: dict[str, Any]) -> "CompressedTensorsConfig":
ignore: list[str] = cast(list[str], config.get("ignore", []))
quant_format = cast(str, config.get("format"))
target_scheme_map = cls._quantization_scheme_map_from_config(
config=config)
......@@ -121,8 +121,8 @@ class CompressedTensorsConfig(QuantizationConfig):
@classmethod
def _parse_sparsity_config(
cls, config: Dict[str, Any]
) -> Tuple[Dict[str, SparsityCompressionConfig], List[str]]:
cls, config: dict[str, Any]
) -> tuple[dict[str, SparsityCompressionConfig], list[str]]:
"""
:param config: The `quantization_config` dictionary from config.json
:return: A tuple with two elements
......@@ -135,7 +135,7 @@ class CompressedTensorsConfig(QuantizationConfig):
sparsity_config = SparsityCompressionConfig.model_validate(
sparsity_config)
sparse_scheme_map: Dict[str, SparsityCompressionConfig] = {
sparse_scheme_map: dict[str, SparsityCompressionConfig] = {
target: sparsity_config
for target in sparsity_config.targets or list()
}
......@@ -144,13 +144,13 @@ class CompressedTensorsConfig(QuantizationConfig):
@classmethod
def _quantization_scheme_map_from_config(
cls, config: Dict[str, Any]) -> QUANTIZATION_SCHEME_MAP_TYPE:
cls, config: dict[str, Any]) -> QUANTIZATION_SCHEME_MAP_TYPE:
"""
:param config: The `quantization_config` dictionary from config.json
:return: A dictionary mapping target layer names to their corresponding
quantization_args for weights and input activations
"""
target_scheme_map: Dict[str, Any] = dict()
target_scheme_map: dict[str, Any] = dict()
quant_format = cast(str, config.get("format"))
# The quant_config has multiple config_groups, each containing
......@@ -188,7 +188,7 @@ class CompressedTensorsConfig(QuantizationConfig):
return target_scheme_map
@classmethod
def get_config_filenames(cls) -> List[str]:
def get_config_filenames(cls) -> list[str]:
return []
def _check_scheme_supported(self,
......@@ -565,7 +565,7 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int,
output_partition_sizes: list[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
"""
......@@ -611,7 +611,7 @@ class CompressedTensorsKVCacheMethod(BaseKVCacheMethod):
super().__init__(quant_config)
@staticmethod
def validate_kv_cache_scheme(kv_cache_scheme: Optional[Dict[str, Any]]):
def validate_kv_cache_scheme(kv_cache_scheme: Optional[dict[str, Any]]):
"""
Validator for the kv cache scheme. Useful for controlling the
kv cache quantization schemes, that are being supported in vLLM
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment