Unverified commit 6223dd81, authored by Harry Mellor, committed by GitHub
Browse files

Update deprecated type hinting in `model_executor/layers` (#18056)


Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
parent 906f0598
...@@ -80,7 +80,6 @@ exclude = [ ...@@ -80,7 +80,6 @@ exclude = [
"vllm/engine/**/*.py" = ["UP006", "UP035"] "vllm/engine/**/*.py" = ["UP006", "UP035"]
"vllm/executor/**/*.py" = ["UP006", "UP035"] "vllm/executor/**/*.py" = ["UP006", "UP035"]
"vllm/lora/**/*.py" = ["UP006", "UP035"] "vllm/lora/**/*.py" = ["UP006", "UP035"]
"vllm/model_executor/layers/**/*.py" = ["UP006", "UP035"]
"vllm/model_executor/model_loader/**/*.py" = ["UP006", "UP035"] "vllm/model_executor/model_loader/**/*.py" = ["UP006", "UP035"]
"vllm/model_executor/models/**/*.py" = ["UP006", "UP035"] "vllm/model_executor/models/**/*.py" = ["UP006", "UP035"]
"vllm/platforms/**/*.py" = ["UP006", "UP035"] "vllm/platforms/**/*.py" = ["UP006", "UP035"]
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any, Dict, Optional from typing import Any, Optional
from vllm.model_executor.layers.fused_moe.layer import ( from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.triton_utils import HAS_TRITON from vllm.triton_utils import HAS_TRITON
_config: Optional[Dict[str, Any]] = None _config: Optional[dict[str, Any]] = None
@contextmanager @contextmanager
...@@ -19,7 +19,7 @@ def override_config(config): ...@@ -19,7 +19,7 @@ def override_config(config):
_config = old_config _config = old_config
def get_config() -> Optional[Dict[str, Any]]: def get_config() -> Optional[dict[str, Any]]:
return _config return _config
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import importlib.util import importlib.util
from typing import Optional, Tuple from typing import Optional
import torch import torch
...@@ -61,7 +61,7 @@ def _moe_permute( ...@@ -61,7 +61,7 @@ def _moe_permute(
global_num_experts: int, global_num_experts: int,
expert_map: Optional[torch.Tensor], expert_map: Optional[torch.Tensor],
block_m: int, block_m: int,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor, ) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor,
Optional[torch.Tensor]]: Optional[torch.Tensor]]:
""" """
Determine the sorted_token_ids, expert_ids for the given problem size. Determine the sorted_token_ids, expert_ids for the given problem size.
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import functools import functools
import json import json
import os import os
from typing import Any, Callable, Dict, List, Optional, Tuple from typing import Any, Callable, Optional
import torch import torch
...@@ -472,14 +472,14 @@ def invoke_fused_moe_kernel(A: torch.Tensor, ...@@ -472,14 +472,14 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
num_tokens_post_padded: torch.Tensor, num_tokens_post_padded: torch.Tensor,
mul_routed_weight: bool, mul_routed_weight: bool,
top_k: int, top_k: int,
config: Dict[str, Any], config: dict[str, Any],
compute_type: tl.dtype, compute_type: tl.dtype,
use_fp8_w8a8: bool, use_fp8_w8a8: bool,
use_int8_w8a8: bool, use_int8_w8a8: bool,
use_int8_w8a16: bool, use_int8_w8a16: bool,
use_int4_w4a16: bool, use_int4_w4a16: bool,
per_channel_quant: bool, per_channel_quant: bool,
block_shape: Optional[List[int]] = None) -> None: block_shape: Optional[list[int]] = None) -> None:
assert topk_weights is not None or not mul_routed_weight assert topk_weights is not None or not mul_routed_weight
assert topk_weights is None or topk_weights.stride(1) == 1 assert topk_weights is None or topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1 assert sorted_token_ids.stride(0) == 1
...@@ -622,7 +622,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, ...@@ -622,7 +622,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
def get_config_file_name(E: int, def get_config_file_name(E: int,
N: int, N: int,
dtype: Optional[str], dtype: Optional[str],
block_shape: Optional[List[int]] = None) -> str: block_shape: Optional[list[int]] = None) -> str:
device_name = current_platform.get_device_name().replace(" ", "_") device_name = current_platform.get_device_name().replace(" ", "_")
dtype_selector = "" if not dtype else f",dtype={dtype}" dtype_selector = "" if not dtype else f",dtype={dtype}"
block_shape_selector = ("" if not block_shape or not all(block_shape) else block_shape_selector = ("" if not block_shape or not all(block_shape) else
...@@ -638,7 +638,7 @@ def get_moe_configs( ...@@ -638,7 +638,7 @@ def get_moe_configs(
dtype: Optional[str], dtype: Optional[str],
block_n: Optional[int] = None, block_n: Optional[int] = None,
block_k: Optional[int] = None, block_k: Optional[int] = None,
) -> Optional[Dict[int, Any]]: ) -> Optional[dict[int, Any]]:
""" """
Return optimized configurations for the fused MoE kernel. Return optimized configurations for the fused MoE kernel.
...@@ -670,7 +670,7 @@ def get_moe_configs( ...@@ -670,7 +670,7 @@ def get_moe_configs(
return None return None
def get_moe_wna16_block_config(config: Dict[str, def get_moe_wna16_block_config(config: dict[str,
int], use_moe_wna16_cuda: bool, int], use_moe_wna16_cuda: bool,
num_valid_tokens: int, size_k: int, size_n: int, num_valid_tokens: int, size_k: int, size_n: int,
num_experts: int, group_size: int, num_experts: int, group_size: int,
...@@ -742,8 +742,8 @@ def get_default_config( ...@@ -742,8 +742,8 @@ def get_default_config(
topk: int, topk: int,
dtype: Optional[str], dtype: Optional[str],
is_marlin: bool, is_marlin: bool,
block_shape: Optional[List[int]] = None, block_shape: Optional[list[int]] = None,
) -> Dict[str, int]: ) -> dict[str, int]:
if dtype == "fp8_w8a8" and block_shape is not None: if dtype == "fp8_w8a8" and block_shape is not None:
# Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0] # Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0]
# BLOCK_SIZE_K must be divisible by block_shape[1] # BLOCK_SIZE_K must be divisible by block_shape[1]
...@@ -795,13 +795,13 @@ def get_default_config( ...@@ -795,13 +795,13 @@ def get_default_config(
def try_get_optimal_moe_config( def try_get_optimal_moe_config(
w1_shape: Tuple[int, ...], w1_shape: tuple[int, ...],
w2_shape: Tuple[int, ...], w2_shape: tuple[int, ...],
top_k: int, top_k: int,
dtype: Optional[str], dtype: Optional[str],
M: int, M: int,
is_marlin: bool = False, is_marlin: bool = False,
block_shape: Optional[List[int]] = None, block_shape: Optional[list[int]] = None,
): ):
from vllm.model_executor.layers.fused_moe import get_config from vllm.model_executor.layers.fused_moe import get_config
override_config = get_config() override_config = get_config()
...@@ -855,7 +855,7 @@ def fused_topk( ...@@ -855,7 +855,7 @@ def fused_topk(
gating_output: torch.Tensor, gating_output: torch.Tensor,
topk: int, topk: int,
renormalize: bool, renormalize: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
assert hidden_states.shape[0] == gating_output.shape[0], ( assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch") "Number of tokens mismatch")
...@@ -895,7 +895,7 @@ def grouped_topk( ...@@ -895,7 +895,7 @@ def grouped_topk(
topk_group: int = 0, topk_group: int = 0,
scoring_func: str = "softmax", scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None e_score_correction_bias: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
assert hidden_states.shape[0] == gating_output.shape[0], ( assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch") "Number of tokens mismatch")
...@@ -982,7 +982,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor, ...@@ -982,7 +982,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
w2_zp: Optional[torch.Tensor] = None, w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None) -> None: block_shape: Optional[list[int]] = None) -> None:
fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True, fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
activation, apply_router_weight_on_input, use_fp8_w8a8, activation, apply_router_weight_on_input, use_fp8_w8a8,
use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, use_int8_w8a8, use_int8_w8a16, use_int4_w4a16,
...@@ -1012,7 +1012,7 @@ def inplace_fused_experts_fake( ...@@ -1012,7 +1012,7 @@ def inplace_fused_experts_fake(
w2_zp: Optional[torch.Tensor] = None, w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None) -> None: block_shape: Optional[list[int]] = None) -> None:
pass pass
...@@ -1046,7 +1046,7 @@ def outplace_fused_experts( ...@@ -1046,7 +1046,7 @@ def outplace_fused_experts(
w2_zp: Optional[torch.Tensor] = None, w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None) -> torch.Tensor: block_shape: Optional[list[int]] = None) -> torch.Tensor:
return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
False, activation, apply_router_weight_on_input, False, activation, apply_router_weight_on_input,
use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16,
...@@ -1076,7 +1076,7 @@ def outplace_fused_experts_fake( ...@@ -1076,7 +1076,7 @@ def outplace_fused_experts_fake(
w2_zp: Optional[torch.Tensor] = None, w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None) -> torch.Tensor: block_shape: Optional[list[int]] = None) -> torch.Tensor:
return torch.empty_like(hidden_states) return torch.empty_like(hidden_states)
...@@ -1129,7 +1129,7 @@ def fused_experts(hidden_states: torch.Tensor, ...@@ -1129,7 +1129,7 @@ def fused_experts(hidden_states: torch.Tensor,
w2_zp: Optional[torch.Tensor] = None, w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None, block_shape: Optional[list[int]] = None,
allow_deep_gemm: bool = False) -> torch.Tensor: allow_deep_gemm: bool = False) -> torch.Tensor:
if (allow_deep_gemm and use_fp8_w8a8 if (allow_deep_gemm and use_fp8_w8a8
and _valid_deep_gemm(hidden_states, w1, w2, expert_map)): and _valid_deep_gemm(hidden_states, w1, w2, expert_map)):
...@@ -1184,8 +1184,8 @@ def moe_kernel_prepare_input( ...@@ -1184,8 +1184,8 @@ def moe_kernel_prepare_input(
use_int8_w8a16: bool, use_int8_w8a16: bool,
use_int4_w4a16: bool, use_int4_w4a16: bool,
per_channel_quant: bool, per_channel_quant: bool,
block_shape: Optional[List[int]] = None, block_shape: Optional[list[int]] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
if use_fp8_w8a8: if use_fp8_w8a8:
assert B_scale is not None assert B_scale is not None
if block_shape is None: if block_shape is None:
...@@ -1248,7 +1248,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1248,7 +1248,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
w2_zp: Optional[torch.Tensor] = None, w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None): block_shape: Optional[list[int]] = None):
# Check constraints. # Check constraints.
if use_int4_w4a16: if use_int4_w4a16:
assert hidden_states.shape[1] // 2 == w1.shape[ assert hidden_states.shape[1] // 2 == w1.shape[
...@@ -1452,7 +1452,7 @@ def fused_moe( ...@@ -1452,7 +1452,7 @@ def fused_moe(
w2_zp: Optional[torch.Tensor] = None, w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None, block_shape: Optional[list[int]] = None,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
This function computes a Mixture of Experts (MoE) layer using two sets of This function computes a Mixture of Experts (MoE) layer using two sets of
...@@ -1497,7 +1497,7 @@ def fused_moe( ...@@ -1497,7 +1497,7 @@ def fused_moe(
a1. a1.
- a2_scale (Optional[torch.Tensor]): Optional scale to be used for - a2_scale (Optional[torch.Tensor]): Optional scale to be used for
a2. a2.
- block_shape: (Optional[List[int]]): Optional block size for block-wise - block_shape: (Optional[list[int]]): Optional block size for block-wise
quantization. quantization.
Returns: Returns:
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
from abc import abstractmethod from abc import abstractmethod
from enum import Enum from enum import Enum
from typing import Callable, List, Optional, Tuple from typing import Callable, Optional
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -326,7 +326,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -326,7 +326,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def determine_expert_map( def determine_expert_map(
ep_size: int, ep_rank: int, ep_size: int, ep_rank: int,
global_num_experts: int) -> Tuple[int, Optional[torch.Tensor]]: global_num_experts: int) -> tuple[int, Optional[torch.Tensor]]:
""" """
Calculates how many experts should be assigned to each rank for EP and Calculates how many experts should be assigned to each rank for EP and
creates a mapping from global to local expert index. Experts are creates a mapping from global to local expert index. Experts are
...@@ -338,7 +338,7 @@ def determine_expert_map( ...@@ -338,7 +338,7 @@ def determine_expert_map(
global_num_experts (int): The total number of experts in the model. global_num_experts (int): The total number of experts in the model.
Returns: Returns:
Tuple[int, Optional[torch.Tensor]]: A tuple containing: tuple[int, Optional[torch.Tensor]]: A tuple containing:
- local_num_experts (int): The number of experts assigned - local_num_experts (int): The number of experts assigned
to the current rank. to the current rank.
- expert_map (Optional[torch.Tensor]): A tensor of shape - expert_map (Optional[torch.Tensor]): A tensor of shape
...@@ -909,7 +909,7 @@ class FusedMoE(torch.nn.Module): ...@@ -909,7 +909,7 @@ class FusedMoE(torch.nn.Module):
def make_expert_params_mapping( def make_expert_params_mapping(
cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str, cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str,
ckpt_up_proj_name: str, ckpt_up_proj_name: str,
num_experts: int) -> List[Tuple[str, str, int, str]]: num_experts: int) -> list[tuple[str, str, int, str]]:
return [ return [
# (param_name, weight_name, expert_id, shard_id) # (param_name, weight_name, expert_id, shard_id)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple from typing import Optional
import torch import torch
...@@ -153,7 +153,7 @@ def moe_align_block_size( ...@@ -153,7 +153,7 @@ def moe_align_block_size(
num_experts: int, num_experts: int,
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
pad_sorted_ids: bool = False pad_sorted_ids: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
""" """
Aligns the token distribution across experts to be compatible with block Aligns the token distribution across experts to be compatible with block
size for matrix multiplication. size for matrix multiplication.
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple from typing import Optional
import torch import torch
...@@ -15,7 +15,7 @@ def moe_permute( ...@@ -15,7 +15,7 @@ def moe_permute(
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
align_block_size: Optional[int] = None, align_block_size: Optional[int] = None,
fill_invalid_expert: int = -1 fill_invalid_expert: int = -1
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
""" """
This function expands and permutes activation to gather uncontinuous tokens This function expands and permutes activation to gather uncontinuous tokens
for each expert. for each expert.
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from functools import cache from functools import cache
from typing import List, Optional, Tuple from typing import Optional
import torch import torch
...@@ -97,7 +97,7 @@ def rocm_aiter_fmoe_fp8_blockscale_g1u1_impl( ...@@ -97,7 +97,7 @@ def rocm_aiter_fmoe_fp8_blockscale_g1u1_impl(
w1_scale: torch.Tensor, w1_scale: torch.Tensor,
w2_scale: torch.Tensor, w2_scale: torch.Tensor,
a1_scale: torch.Tensor, a1_scale: torch.Tensor,
block_shape: List[int], block_shape: list[int],
smooth_scale: Optional[torch.Tensor] = None) -> torch.Tensor: smooth_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
from aiter import fmoe_fp8_blockscale_g1u1 from aiter import fmoe_fp8_blockscale_g1u1
from aiter.fused_moe_bf16_asm import moe_sorting_ck from aiter.fused_moe_bf16_asm import moe_sorting_ck
...@@ -142,7 +142,7 @@ def rocm_aiter_fmoe_fp8_blockscale_g1u1_fake( ...@@ -142,7 +142,7 @@ def rocm_aiter_fmoe_fp8_blockscale_g1u1_fake(
w1_scale: torch.Tensor, w1_scale: torch.Tensor,
w2_scale: torch.Tensor, w2_scale: torch.Tensor,
a1_scale: torch.Tensor, a1_scale: torch.Tensor,
block_shape: List[int], block_shape: list[int],
smooth_scale: Optional[torch.Tensor] = None) -> torch.Tensor: smooth_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
return torch.empty_like(a1, dtype=hidden_states_dtype) return torch.empty_like(a1, dtype=hidden_states_dtype)
...@@ -280,7 +280,7 @@ def rocm_aiter_fused_experts(hidden_states: torch.Tensor, ...@@ -280,7 +280,7 @@ def rocm_aiter_fused_experts(hidden_states: torch.Tensor,
w2_zp: Optional[torch.Tensor] = None, w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None, block_shape: Optional[list[int]] = None,
allow_deep_gemm: bool = False) -> torch.Tensor: allow_deep_gemm: bool = False) -> torch.Tensor:
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
...@@ -372,14 +372,14 @@ def rocm_aiter_topk_softmax(topk_weights: torch.Tensor, ...@@ -372,14 +372,14 @@ def rocm_aiter_topk_softmax(topk_weights: torch.Tensor,
topk_indices: torch.Tensor, topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor, token_expert_indices: torch.Tensor,
gating_output: torch.Tensor, gating_output: torch.Tensor,
renormalize: bool) -> Tuple[torch.Tensor, ...]: renormalize: bool) -> tuple[torch.Tensor, ...]:
torch.ops.vllm.rocm_aiter_topk_softmax(topk_weights, topk_indices, torch.ops.vllm.rocm_aiter_topk_softmax(topk_weights, topk_indices,
token_expert_indices, gating_output, token_expert_indices, gating_output,
renormalize) renormalize)
return topk_weights, topk_indices return topk_weights, topk_indices
def shuffle_weights(*tensors: torch.Tensor) -> Tuple[torch.Tensor, ...]: def shuffle_weights(*tensors: torch.Tensor) -> tuple[torch.Tensor, ...]:
""" """
Applies shuffle_weight function from AITER to each Applies shuffle_weight function from AITER to each
input tensor and returns them. input tensor and returns them.
...@@ -395,7 +395,7 @@ def shuffle_weights(*tensors: torch.Tensor) -> Tuple[torch.Tensor, ...]: ...@@ -395,7 +395,7 @@ def shuffle_weights(*tensors: torch.Tensor) -> Tuple[torch.Tensor, ...]:
def expand_weights(*tensors: torch.Tensor, def expand_weights(*tensors: torch.Tensor,
expansion_dims: list[int]) -> Tuple[torch.Tensor, ...]: expansion_dims: list[int]) -> tuple[torch.Tensor, ...]:
""" """
Expands the dimensions of input tensors. Expands the dimensions of input tensors.
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from math import prod from math import prod
from typing import List, Optional, Tuple from typing import Optional
import torch import torch
...@@ -10,7 +10,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import ( ...@@ -10,7 +10,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
from vllm.utils import cdiv from vllm.utils import cdiv
def _resize_cache(x: torch.Tensor, v: Tuple[int, ...]) -> torch.Tensor: def _resize_cache(x: torch.Tensor, v: tuple[int, ...]) -> torch.Tensor:
""" """
Shrink the given tensor and apply the given view to it. This is Shrink the given tensor and apply the given view to it. This is
used to resize the intermediate fused_moe caches. used to resize the intermediate fused_moe caches.
...@@ -22,8 +22,8 @@ def _resize_cache(x: torch.Tensor, v: Tuple[int, ...]) -> torch.Tensor: ...@@ -22,8 +22,8 @@ def _resize_cache(x: torch.Tensor, v: Tuple[int, ...]) -> torch.Tensor:
def _fp8_quantize( def _fp8_quantize(
A: torch.Tensor, A: torch.Tensor,
A_scale: Optional[torch.Tensor], A_scale: Optional[torch.Tensor],
block_shape: Optional[List[int]], block_shape: Optional[list[int]],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
""" """
Perform fp8 quantization on the inputs. If a block_shape Perform fp8 quantization on the inputs. If a block_shape
is provided, the output will be blocked. is provided, the output will be blocked.
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""Custom normalization layers.""" """Custom normalization layers."""
from typing import Optional, Tuple, Union from typing import Optional, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -31,7 +31,7 @@ def rms_norm(x: torch.Tensor, weight: torch.Tensor, ...@@ -31,7 +31,7 @@ def rms_norm(x: torch.Tensor, weight: torch.Tensor,
def fused_add_rms_norm( def fused_add_rms_norm(
x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
variance_epsilon: float) -> Tuple[torch.Tensor, torch.Tensor]: variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
ops.fused_add_rms_norm( ops.fused_add_rms_norm(
x, x,
...@@ -57,7 +57,7 @@ def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor, ...@@ -57,7 +57,7 @@ def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor,
def rocm_aiter_fused_add_rms_norm( def rocm_aiter_fused_add_rms_norm(
x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
variance_epsilon: float) -> Tuple[torch.Tensor, torch.Tensor]: variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:
import aiter as rocm_aiter import aiter as rocm_aiter
...@@ -119,7 +119,7 @@ class RMSNorm(CustomOp): ...@@ -119,7 +119,7 @@ class RMSNorm(CustomOp):
self, self,
x: torch.Tensor, x: torch.Tensor,
residual: Optional[torch.Tensor] = None, residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
"""PyTorch-native implementation equivalent to forward().""" """PyTorch-native implementation equivalent to forward()."""
orig_dtype = x.dtype orig_dtype = x.dtype
x = x.to(torch.float32) x = x.to(torch.float32)
...@@ -157,7 +157,7 @@ class RMSNorm(CustomOp): ...@@ -157,7 +157,7 @@ class RMSNorm(CustomOp):
self, self,
x: torch.Tensor, x: torch.Tensor,
residual: Optional[torch.Tensor] = None, residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if self.variance_size_override is not None: if self.variance_size_override is not None:
return self.forward_native(x, residual) return self.forward_native(x, residual)
...@@ -174,7 +174,7 @@ class RMSNorm(CustomOp): ...@@ -174,7 +174,7 @@ class RMSNorm(CustomOp):
self, self,
x: torch.Tensor, x: torch.Tensor,
residual: Optional[torch.Tensor] = None, residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
from vllm_hpu_extension.kernels import rms_norm from vllm_hpu_extension.kernels import rms_norm
HPUFusedRMSNorm = rms_norm() HPUFusedRMSNorm = rms_norm()
if HPUFusedRMSNorm is None: if HPUFusedRMSNorm is None:
...@@ -194,7 +194,7 @@ class RMSNorm(CustomOp): ...@@ -194,7 +194,7 @@ class RMSNorm(CustomOp):
self, self,
x: torch.Tensor, x: torch.Tensor,
residual: Optional[torch.Tensor] = None, residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if self.variance_size_override is not None: if self.variance_size_override is not None:
return self.forward_native(x, residual) return self.forward_native(x, residual)
...@@ -244,7 +244,7 @@ class GemmaRMSNorm(CustomOp): ...@@ -244,7 +244,7 @@ class GemmaRMSNorm(CustomOp):
variance_epsilon: float, variance_epsilon: float,
x: torch.Tensor, x: torch.Tensor,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
"""PyTorch-native implementation equivalent to forward().""" """PyTorch-native implementation equivalent to forward()."""
orig_dtype = x.dtype orig_dtype = x.dtype
if residual is not None: if residual is not None:
...@@ -267,7 +267,7 @@ class GemmaRMSNorm(CustomOp): ...@@ -267,7 +267,7 @@ class GemmaRMSNorm(CustomOp):
self, self,
x: torch.Tensor, x: torch.Tensor,
residual: Optional[torch.Tensor] = None, residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
"""PyTorch-native implementation equivalent to forward().""" """PyTorch-native implementation equivalent to forward()."""
return self.forward_static(self.weight.data, self.variance_epsilon, x, return self.forward_static(self.weight.data, self.variance_epsilon, x,
residual) residual)
...@@ -276,7 +276,7 @@ class GemmaRMSNorm(CustomOp): ...@@ -276,7 +276,7 @@ class GemmaRMSNorm(CustomOp):
self, self,
x: torch.Tensor, x: torch.Tensor,
residual: Optional[torch.Tensor] = None, residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if torch.compiler.is_compiling(): if torch.compiler.is_compiling():
return self.forward_native(x, residual) return self.forward_native(x, residual)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import List, Optional, Tuple, Union from typing import Optional, Union
import torch import torch
from torch import nn from torch import nn
...@@ -104,7 +104,7 @@ class Mixer2RMSNormGated(CustomOp): ...@@ -104,7 +104,7 @@ class Mixer2RMSNormGated(CustomOp):
self, self,
x: torch.Tensor, x: torch.Tensor,
gate: torch.Tensor, gate: torch.Tensor,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if self.tp_size > 1 or self.n_groups != 1: if self.tp_size > 1 or self.n_groups != 1:
return self.forward_native(x, gate) return self.forward_native(x, gate)
...@@ -136,7 +136,7 @@ def extra_groups_for_head_shards(ngroups: int, tp_size: int): ...@@ -136,7 +136,7 @@ def extra_groups_for_head_shards(ngroups: int, tp_size: int):
def mamba_v2_sharded_weight_loader( def mamba_v2_sharded_weight_loader(
shard_spec: List[Tuple[int, int, float]], shard_spec: list[tuple[int, int, float]],
tp_size: int, tp_size: int,
tp_rank: int, tp_rank: int,
) -> LoaderFunction: ) -> LoaderFunction:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from enum import IntEnum from enum import IntEnum
from typing import List, Optional, Union from typing import Optional, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -46,7 +46,7 @@ class SimplePooler(nn.Module): ...@@ -46,7 +46,7 @@ class SimplePooler(nn.Module):
normalize: bool, normalize: bool,
softmax: bool, softmax: bool,
step_tag_id: Optional[int] = None, step_tag_id: Optional[int] = None,
returned_token_ids: Optional[List[int]] = None, returned_token_ids: Optional[list[int]] = None,
) -> "SimplePooler": ) -> "SimplePooler":
if pooling_type == PoolingType.LAST: if pooling_type == PoolingType.LAST:
assert step_tag_id is None and returned_token_ids is None assert step_tag_id is None and returned_token_ids is None
...@@ -174,7 +174,7 @@ class StepPool(SimplePooler): ...@@ -174,7 +174,7 @@ class StepPool(SimplePooler):
normalize: bool, normalize: bool,
softmax: bool, softmax: bool,
step_tag_id: Optional[int] = None, step_tag_id: Optional[int] = None,
returned_token_ids: Optional[List[int]] = None, returned_token_ids: Optional[list[int]] = None,
): ):
super().__init__(normalize=normalize, softmax=softmax) super().__init__(normalize=normalize, softmax=softmax)
...@@ -259,7 +259,7 @@ class Pooler(nn.Module): ...@@ -259,7 +259,7 @@ class Pooler(nn.Module):
normalize: bool, normalize: bool,
softmax: bool, softmax: bool,
step_tag_id: Optional[int] = None, step_tag_id: Optional[int] = None,
returned_token_ids: Optional[List[int]] = None, returned_token_ids: Optional[list[int]] = None,
) -> SimplePooler: ) -> SimplePooler:
return SimplePooler.from_pooling_type( return SimplePooler.from_pooling_type(
pooling_type=PoolingType[pooler_config.pooling_type] pooling_type=PoolingType[pooler_config.pooling_type]
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Literal, Type, get_args from typing import Literal, get_args
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
...@@ -76,7 +76,7 @@ def register_quantization_config(quantization: str): ...@@ -76,7 +76,7 @@ def register_quantization_config(quantization: str):
return _wrapper return _wrapper
def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
if quantization not in QUANTIZATION_METHODS: if quantization not in QUANTIZATION_METHODS:
raise ValueError(f"Invalid quantization method: {quantization}") raise ValueError(f"Invalid quantization method: {quantization}")
...@@ -110,7 +110,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: ...@@ -110,7 +110,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
from .torchao import TorchAOConfig from .torchao import TorchAOConfig
from .tpu_int8 import Int8TpuConfig from .tpu_int8 import Int8TpuConfig
method_to_config: dict[str, Type[QuantizationConfig]] = { method_to_config: dict[str, type[QuantizationConfig]] = {
"aqlm": AQLMConfig, "aqlm": AQLMConfig,
"awq": AWQConfig, "awq": AWQConfig,
"deepspeedfp": DeepSpeedFPConfig, "deepspeedfp": DeepSpeedFPConfig,
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
# and https://arxiv.org/pdf/2401.06118.pdf # and https://arxiv.org/pdf/2401.06118.pdf
import math import math
from typing import Any, Dict, List, Optional from typing import Any, Optional
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -98,7 +98,7 @@ def generic_dequantize_gemm( ...@@ -98,7 +98,7 @@ def generic_dequantize_gemm(
codebooks: torch. codebooks: torch.
Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size] Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
scales: torch.Tensor, # [num_out_groups, 1, 1, 1] scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
output_partition_sizes: List[int], output_partition_sizes: list[int],
bias: Optional[torch.Tensor], bias: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
output_shape = input.shape[:-1] + (scales.shape[0], ) output_shape = input.shape[:-1] + (scales.shape[0], )
...@@ -136,7 +136,7 @@ def optimized_dequantize_gemm( ...@@ -136,7 +136,7 @@ def optimized_dequantize_gemm(
codebooks: torch. codebooks: torch.
Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size] Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
scales: torch.Tensor, # [num_out_groups, 1, 1, 1] scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
output_partition_sizes: List[int], output_partition_sizes: list[int],
bias: Optional[torch.Tensor], bias: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes) weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
...@@ -191,7 +191,7 @@ class AQLMConfig(QuantizationConfig): ...@@ -191,7 +191,7 @@ class AQLMConfig(QuantizationConfig):
return "aqlm" return "aqlm"
@classmethod @classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]: def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half] return [torch.half]
@classmethod @classmethod
...@@ -199,11 +199,11 @@ class AQLMConfig(QuantizationConfig): ...@@ -199,11 +199,11 @@ class AQLMConfig(QuantizationConfig):
return 60 return 60
@classmethod @classmethod
def get_config_filenames(cls) -> List[str]: def get_config_filenames(cls) -> list[str]:
return [] # no extra configs. return [] # no extra configs.
@classmethod @classmethod
def from_config(cls, config: Dict[str, Any]) -> "AQLMConfig": def from_config(cls, config: dict[str, Any]) -> "AQLMConfig":
in_group_size = cls.get_from_keys(config, ["in_group_size"]) in_group_size = cls.get_from_keys(config, ["in_group_size"])
nbits_per_codebook = cls.get_from_keys(config, ["nbits_per_codebook"]) nbits_per_codebook = cls.get_from_keys(config, ["nbits_per_codebook"])
num_code_books = cls.get_from_keys(config, ["num_codebooks"]) num_code_books = cls.get_from_keys(config, ["num_codebooks"])
...@@ -230,7 +230,7 @@ class AQLMLinearMethod(LinearMethodBase): ...@@ -230,7 +230,7 @@ class AQLMLinearMethod(LinearMethodBase):
def create_weights(self, layer: torch.nn.Module, def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int, input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int, output_partition_sizes: list[int], input_size: int,
output_size: int, params_dtype: torch.dtype, output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs): **extra_weight_attrs):
del output_size # Unused. del output_size # Unused.
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional from typing import Any, Optional
import torch import torch
...@@ -25,7 +25,7 @@ class AWQConfig(QuantizationConfig): ...@@ -25,7 +25,7 @@ class AWQConfig(QuantizationConfig):
weight_bits: int, weight_bits: int,
group_size: int, group_size: int,
zero_point: bool, zero_point: bool,
modules_to_not_convert: Optional[List[str]] = None, modules_to_not_convert: Optional[list[str]] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.weight_bits = weight_bits self.weight_bits = weight_bits
...@@ -48,7 +48,7 @@ class AWQConfig(QuantizationConfig): ...@@ -48,7 +48,7 @@ class AWQConfig(QuantizationConfig):
def get_name(self) -> QuantizationMethods: def get_name(self) -> QuantizationMethods:
return "awq" return "awq"
def get_supported_act_dtypes(self) -> List[torch.dtype]: def get_supported_act_dtypes(self) -> list[torch.dtype]:
return [torch.half] return [torch.half]
@classmethod @classmethod
...@@ -57,7 +57,7 @@ class AWQConfig(QuantizationConfig): ...@@ -57,7 +57,7 @@ class AWQConfig(QuantizationConfig):
return 75 return 75
@staticmethod @staticmethod
def get_config_filenames() -> List[str]: def get_config_filenames() -> list[str]:
return [ return [
"quant_config.json", # E.g., casperhansen/vicuna-7b-v1.5-awq "quant_config.json", # E.g., casperhansen/vicuna-7b-v1.5-awq
# E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq
...@@ -65,7 +65,7 @@ class AWQConfig(QuantizationConfig): ...@@ -65,7 +65,7 @@ class AWQConfig(QuantizationConfig):
] ]
@classmethod @classmethod
def from_config(cls, config: Dict[str, Any]) -> "AWQConfig": def from_config(cls, config: dict[str, Any]) -> "AWQConfig":
weight_bits = cls.get_from_keys(config, ["w_bit", "bits"]) weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
group_size = cls.get_from_keys(config, ["q_group_size", "group_size"]) group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
zero_point = cls.get_from_keys(config, ["zero_point"]) zero_point = cls.get_from_keys(config, ["zero_point"])
...@@ -82,7 +82,7 @@ class AWQConfig(QuantizationConfig): ...@@ -82,7 +82,7 @@ class AWQConfig(QuantizationConfig):
return None return None
def is_layer_skipped_awq(prefix: str, modules_to_not_convert: List[str]): def is_layer_skipped_awq(prefix: str, modules_to_not_convert: list[str]):
return any(module_name in prefix for module_name in modules_to_not_convert) return any(module_name in prefix for module_name in modules_to_not_convert)
...@@ -98,7 +98,7 @@ class AWQLinearMethod(LinearMethodBase): ...@@ -98,7 +98,7 @@ class AWQLinearMethod(LinearMethodBase):
def create_weights(self, layer: torch.nn.Module, def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int, input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int, output_partition_sizes: list[int], input_size: int,
output_size: int, params_dtype: torch.dtype, output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs): **extra_weight_attrs):
if input_size_per_partition % self.quant_config.group_size != 0: if input_size_per_partition % self.quant_config.group_size != 0:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Any, Callable, Dict, List, Optional from typing import Any, Callable, Optional
import torch import torch
from torch.nn import Parameter from torch.nn import Parameter
...@@ -46,8 +46,8 @@ class AWQMarlinConfig(QuantizationConfig): ...@@ -46,8 +46,8 @@ class AWQMarlinConfig(QuantizationConfig):
def __init__(self, weight_bits: int, group_size: int, zero_point: bool, def __init__(self, weight_bits: int, group_size: int, zero_point: bool,
lm_head_quantized: bool, lm_head_quantized: bool,
modules_to_not_convert: Optional[List[str]], modules_to_not_convert: Optional[list[str]],
full_config: Dict[str, Any]) -> None: full_config: dict[str, Any]) -> None:
super().__init__() super().__init__()
self.pack_factor = 32 // weight_bits # packed into int32 self.pack_factor = 32 // weight_bits # packed into int32
self.group_size = group_size self.group_size = group_size
...@@ -79,7 +79,7 @@ class AWQMarlinConfig(QuantizationConfig): ...@@ -79,7 +79,7 @@ class AWQMarlinConfig(QuantizationConfig):
return "awq_marlin" return "awq_marlin"
@classmethod @classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]: def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half, torch.bfloat16] return [torch.half, torch.bfloat16]
@classmethod @classmethod
...@@ -87,11 +87,11 @@ class AWQMarlinConfig(QuantizationConfig): ...@@ -87,11 +87,11 @@ class AWQMarlinConfig(QuantizationConfig):
return 80 return 80
@classmethod @classmethod
def get_config_filenames(cls) -> List[str]: def get_config_filenames(cls) -> list[str]:
return ["quantize_config.json"] return ["quantize_config.json"]
@classmethod @classmethod
def from_config(cls, config: Dict[str, Any]) -> "AWQMarlinConfig": def from_config(cls, config: dict[str, Any]) -> "AWQMarlinConfig":
weight_bits = cls.get_from_keys(config, ["bits"]) weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"]) group_size = cls.get_from_keys(config, ["group_size"])
zero_point = cls.get_from_keys(config, ["zero_point"]) zero_point = cls.get_from_keys(config, ["zero_point"])
...@@ -150,7 +150,7 @@ class AWQMarlinConfig(QuantizationConfig): ...@@ -150,7 +150,7 @@ class AWQMarlinConfig(QuantizationConfig):
return None return None
@classmethod @classmethod
def is_awq_marlin_compatible(cls, quant_config: Dict[str, Any]): def is_awq_marlin_compatible(cls, quant_config: dict[str, Any]):
# Extract data from quant config. # Extract data from quant config.
quant_method = quant_config.get("quant_method", "").lower() quant_method = quant_config.get("quant_method", "").lower()
num_bits = quant_config.get("bits") num_bits = quant_config.get("bits")
...@@ -189,7 +189,7 @@ class AWQMarlinLinearMethod(LinearMethodBase): ...@@ -189,7 +189,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
input_size_per_partition: int, input_size_per_partition: int,
output_partition_sizes: List[int], output_partition_sizes: list[int],
input_size: int, input_size: int,
output_size: int, output_size: int,
params_dtype: torch.dtype, params_dtype: torch.dtype,
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
import inspect import inspect
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type from typing import TYPE_CHECKING, Any, Optional
import torch import torch
from torch import nn from torch import nn
...@@ -48,7 +48,7 @@ class QuantizeMethodBase(ABC): ...@@ -48,7 +48,7 @@ class QuantizeMethodBase(ABC):
def method_has_implemented_embedding( def method_has_implemented_embedding(
method_class: Type[QuantizeMethodBase]) -> bool: method_class: type[QuantizeMethodBase]) -> bool:
""" """
Not all quant methods have embedding implemented, so we need to check that Not all quant methods have embedding implemented, so we need to check that
it exists for our given method. We check this by making sure the function it exists for our given method. We check this by making sure the function
...@@ -68,7 +68,7 @@ class QuantizationConfig(ABC): ...@@ -68,7 +68,7 @@ class QuantizationConfig(ABC):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
# mapping is updated by models as they initialize # mapping is updated by models as they initialize
self.packed_modules_mapping: Dict[str, List[str]] = dict() self.packed_modules_mapping: dict[str, list[str]] = dict()
@abstractmethod @abstractmethod
def get_name(self) -> QuantizationMethods: def get_name(self) -> QuantizationMethods:
...@@ -76,7 +76,7 @@ class QuantizationConfig(ABC): ...@@ -76,7 +76,7 @@ class QuantizationConfig(ABC):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def get_supported_act_dtypes(self) -> List[torch.dtype]: def get_supported_act_dtypes(self) -> list[torch.dtype]:
"""List of supported activation dtypes.""" """List of supported activation dtypes."""
raise NotImplementedError raise NotImplementedError
...@@ -93,13 +93,13 @@ class QuantizationConfig(ABC): ...@@ -93,13 +93,13 @@ class QuantizationConfig(ABC):
@staticmethod @staticmethod
@abstractmethod @abstractmethod
def get_config_filenames() -> List[str]: def get_config_filenames() -> list[str]:
"""List of filenames to search for in the model directory.""" """List of filenames to search for in the model directory."""
raise NotImplementedError raise NotImplementedError
@classmethod @classmethod
@abstractmethod @abstractmethod
def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig": def from_config(cls, config: dict[str, Any]) -> "QuantizationConfig":
"""Create a config class from the model's quantization config.""" """Create a config class from the model's quantization config."""
raise NotImplementedError raise NotImplementedError
...@@ -115,7 +115,7 @@ class QuantizationConfig(ABC): ...@@ -115,7 +115,7 @@ class QuantizationConfig(ABC):
return None return None
@staticmethod @staticmethod
def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any: def get_from_keys(config: dict[str, Any], keys: list[str]) -> Any:
"""Get a value from the model's quantization config.""" """Get a value from the model's quantization config."""
for key in keys: for key in keys:
if key in config: if key in config:
...@@ -124,7 +124,7 @@ class QuantizationConfig(ABC): ...@@ -124,7 +124,7 @@ class QuantizationConfig(ABC):
"quantization config.") "quantization config.")
@staticmethod @staticmethod
def get_from_keys_or(config: Dict[str, Any], keys: List[str], def get_from_keys_or(config: dict[str, Any], keys: list[str],
default: Any) -> Any: default: Any) -> Any:
"""Get an optional value from the model's quantization config.""" """Get an optional value from the model's quantization config."""
try: try:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional from typing import Any, Optional
import torch import torch
...@@ -105,7 +105,7 @@ class BitBLASConfig(QuantizationConfig): ...@@ -105,7 +105,7 @@ class BitBLASConfig(QuantizationConfig):
return "bitblas" return "bitblas"
@classmethod @classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]: def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half, torch.bfloat16] return [torch.half, torch.bfloat16]
@classmethod @classmethod
...@@ -114,12 +114,12 @@ class BitBLASConfig(QuantizationConfig): ...@@ -114,12 +114,12 @@ class BitBLASConfig(QuantizationConfig):
return 70 return 70
@classmethod @classmethod
def get_config_filenames(cls) -> List[str]: def get_config_filenames(cls) -> list[str]:
return ["quantize_config.json"] return ["quantize_config.json"]
@staticmethod @staticmethod
def get_from_keys(config: Dict[str, Any], def get_from_keys(config: dict[str, Any],
keys: List[str], keys: list[str],
default: Any = None) -> Any: default: Any = None) -> Any:
"""Get a value from the model's quantization config.""" """Get a value from the model's quantization config."""
for key in keys: for key in keys:
...@@ -128,7 +128,7 @@ class BitBLASConfig(QuantizationConfig): ...@@ -128,7 +128,7 @@ class BitBLASConfig(QuantizationConfig):
return default return default
@classmethod @classmethod
def from_config(cls, config: Dict[str, Any]) -> "BitBLASConfig": def from_config(cls, config: dict[str, Any]) -> "BitBLASConfig":
weight_bits = cls.get_from_keys(config, ["bits"]) weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"], -1) group_size = cls.get_from_keys(config, ["group_size"], -1)
desc_act = cls.get_from_keys(config, ["desc_act"], False) desc_act = cls.get_from_keys(config, ["desc_act"], False)
...@@ -193,7 +193,7 @@ class BitBLASLinearMethod(LinearMethodBase): ...@@ -193,7 +193,7 @@ class BitBLASLinearMethod(LinearMethodBase):
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
input_size_per_partition: int, input_size_per_partition: int,
output_partition_sizes: List[int], output_partition_sizes: list[int],
input_size: int, input_size: int,
output_size: int, output_size: int,
params_dtype: torch.dtype, params_dtype: torch.dtype,
...@@ -329,7 +329,7 @@ class BitBLASLinearMethod(LinearMethodBase): ...@@ -329,7 +329,7 @@ class BitBLASLinearMethod(LinearMethodBase):
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
input_size_per_partition: int, input_size_per_partition: int,
output_partition_sizes: List[int], output_partition_sizes: list[int],
input_size: int, input_size: int,
output_size: int, output_size: int,
params_dtype: torch.dtype, params_dtype: torch.dtype,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional from typing import Any, Optional
import torch import torch
...@@ -29,7 +29,7 @@ class BitsAndBytesConfig(QuantizationConfig): ...@@ -29,7 +29,7 @@ class BitsAndBytesConfig(QuantizationConfig):
bnb_4bit_use_double_quant: bool = False, bnb_4bit_use_double_quant: bool = False,
llm_int8_enable_fp32_cpu_offload: bool = False, llm_int8_enable_fp32_cpu_offload: bool = False,
llm_int8_has_fp16_weight: bool = False, llm_int8_has_fp16_weight: bool = False,
llm_int8_skip_modules: Optional[List[str]] = None, llm_int8_skip_modules: Optional[list[str]] = None,
llm_int8_threshold: float = 6.0, llm_int8_threshold: float = 6.0,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -61,7 +61,7 @@ class BitsAndBytesConfig(QuantizationConfig): ...@@ -61,7 +61,7 @@ class BitsAndBytesConfig(QuantizationConfig):
return "bitsandbytes" return "bitsandbytes"
@classmethod @classmethod
def get_supported_act_dtypes(self) -> List[torch.dtype]: def get_supported_act_dtypes(self) -> list[torch.dtype]:
return [torch.float32, torch.float16, torch.bfloat16] return [torch.float32, torch.float16, torch.bfloat16]
@classmethod @classmethod
...@@ -69,13 +69,13 @@ class BitsAndBytesConfig(QuantizationConfig): ...@@ -69,13 +69,13 @@ class BitsAndBytesConfig(QuantizationConfig):
return 70 return 70
@staticmethod @staticmethod
def get_config_filenames() -> List[str]: def get_config_filenames() -> list[str]:
return [ return [
"adapter_config.json", "adapter_config.json",
] ]
@classmethod @classmethod
def from_config(cls, config: Dict[str, Any]) -> "BitsAndBytesConfig": def from_config(cls, config: dict[str, Any]) -> "BitsAndBytesConfig":
def get_safe_value(config, keys, default_value=None): def get_safe_value(config, keys, default_value=None):
try: try:
...@@ -130,7 +130,7 @@ class BitsAndBytesConfig(QuantizationConfig): ...@@ -130,7 +130,7 @@ class BitsAndBytesConfig(QuantizationConfig):
return None return None
def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: List[str]): def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: list[str]):
# Split the prefix into its dot-separated components # Split the prefix into its dot-separated components
components = prefix.split('.') components = prefix.split('.')
...@@ -169,7 +169,7 @@ class BitsAndBytesLinearMethod(LinearMethodBase): ...@@ -169,7 +169,7 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
def create_weights(self, layer: torch.nn.Module, def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int, input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int, output_partition_sizes: list[int], input_size: int,
output_size: int, params_dtype: torch.dtype, output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs): **extra_weight_attrs):
from bitsandbytes.nn import Int8Params from bitsandbytes.nn import Int8Params
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from contextlib import suppress from contextlib import suppress
from typing import Any, Dict, List, Literal, Optional, Tuple, cast from typing import Any, Literal, Optional, cast
import torch import torch
from compressed_tensors.config import (CompressionFormat, from compressed_tensors.config import (CompressionFormat,
...@@ -38,20 +38,20 @@ logger = init_logger(__name__) ...@@ -38,20 +38,20 @@ logger = init_logger(__name__)
__all__ = ["CompressedTensorsLinearMethod"] __all__ = ["CompressedTensorsLinearMethod"]
SPARSITY_CONFIG_NAME: Literal["sparsity_config"] = "sparsity_config" SPARSITY_CONFIG_NAME: Literal["sparsity_config"] = "sparsity_config"
QUANTIZATION_SCHEME_MAP_TYPE = Dict[str, Optional[Dict[str, QuantizationArgs]]] QUANTIZATION_SCHEME_MAP_TYPE = dict[str, Optional[dict[str, QuantizationArgs]]]
class CompressedTensorsConfig(QuantizationConfig): class CompressedTensorsConfig(QuantizationConfig):
def __init__( def __init__(
self, self,
target_scheme_map: Dict[str, Any], target_scheme_map: dict[str, Any],
ignore: List[str], ignore: list[str],
quant_format: str, quant_format: str,
sparsity_scheme_map: Dict[str, SparsityCompressionConfig], sparsity_scheme_map: dict[str, SparsityCompressionConfig],
sparsity_ignore_list: List[str], sparsity_ignore_list: list[str],
kv_cache_scheme: Optional[Dict[str, Any]] = None, kv_cache_scheme: Optional[dict[str, Any]] = None,
config: Optional[Dict[str, Any]] = None, config: Optional[dict[str, Any]] = None,
): ):
super().__init__() super().__init__()
self.ignore = ignore self.ignore = ignore
...@@ -66,7 +66,7 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -66,7 +66,7 @@ class CompressedTensorsConfig(QuantizationConfig):
def get_linear_method(self) -> "CompressedTensorsLinearMethod": def get_linear_method(self) -> "CompressedTensorsLinearMethod":
return CompressedTensorsLinearMethod(self) return CompressedTensorsLinearMethod(self)
def get_supported_act_dtypes(cls) -> List[torch.dtype]: def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.float16, torch.bfloat16] return [torch.float16, torch.bfloat16]
@classmethod @classmethod
...@@ -102,8 +102,8 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -102,8 +102,8 @@ class CompressedTensorsConfig(QuantizationConfig):
return None return None
@classmethod @classmethod
def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig": def from_config(cls, config: dict[str, Any]) -> "CompressedTensorsConfig":
ignore: List[str] = cast(List[str], config.get("ignore", [])) ignore: list[str] = cast(list[str], config.get("ignore", []))
quant_format = cast(str, config.get("format")) quant_format = cast(str, config.get("format"))
target_scheme_map = cls._quantization_scheme_map_from_config( target_scheme_map = cls._quantization_scheme_map_from_config(
config=config) config=config)
...@@ -121,8 +121,8 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -121,8 +121,8 @@ class CompressedTensorsConfig(QuantizationConfig):
@classmethod @classmethod
def _parse_sparsity_config( def _parse_sparsity_config(
cls, config: Dict[str, Any] cls, config: dict[str, Any]
) -> Tuple[Dict[str, SparsityCompressionConfig], List[str]]: ) -> tuple[dict[str, SparsityCompressionConfig], list[str]]:
""" """
:param config: The `quantization_config` dictionary from config.json :param config: The `quantization_config` dictionary from config.json
:return: A tuple with two elements :return: A tuple with two elements
...@@ -135,7 +135,7 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -135,7 +135,7 @@ class CompressedTensorsConfig(QuantizationConfig):
sparsity_config = SparsityCompressionConfig.model_validate( sparsity_config = SparsityCompressionConfig.model_validate(
sparsity_config) sparsity_config)
sparse_scheme_map: Dict[str, SparsityCompressionConfig] = { sparse_scheme_map: dict[str, SparsityCompressionConfig] = {
target: sparsity_config target: sparsity_config
for target in sparsity_config.targets or list() for target in sparsity_config.targets or list()
} }
...@@ -144,13 +144,13 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -144,13 +144,13 @@ class CompressedTensorsConfig(QuantizationConfig):
@classmethod @classmethod
def _quantization_scheme_map_from_config( def _quantization_scheme_map_from_config(
cls, config: Dict[str, Any]) -> QUANTIZATION_SCHEME_MAP_TYPE: cls, config: dict[str, Any]) -> QUANTIZATION_SCHEME_MAP_TYPE:
""" """
:param config: The `quantization_config` dictionary from config.json :param config: The `quantization_config` dictionary from config.json
:return: A dictionary mapping target layer names to their corresponding :return: A dictionary mapping target layer names to their corresponding
quantization_args for weights and input activations quantization_args for weights and input activations
""" """
target_scheme_map: Dict[str, Any] = dict() target_scheme_map: dict[str, Any] = dict()
quant_format = cast(str, config.get("format")) quant_format = cast(str, config.get("format"))
# The quant_config has multiple config_groups, each containing # The quant_config has multiple config_groups, each containing
...@@ -188,7 +188,7 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -188,7 +188,7 @@ class CompressedTensorsConfig(QuantizationConfig):
return target_scheme_map return target_scheme_map
@classmethod @classmethod
def get_config_filenames(cls) -> List[str]: def get_config_filenames(cls) -> list[str]:
return [] return []
def _check_scheme_supported(self, def _check_scheme_supported(self,
...@@ -565,7 +565,7 @@ class CompressedTensorsLinearMethod(LinearMethodBase): ...@@ -565,7 +565,7 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
def create_weights(self, layer: torch.nn.Module, def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int, input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int, output_partition_sizes: list[int], input_size: int,
output_size: int, params_dtype: torch.dtype, output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs): **extra_weight_attrs):
""" """
...@@ -611,7 +611,7 @@ class CompressedTensorsKVCacheMethod(BaseKVCacheMethod): ...@@ -611,7 +611,7 @@ class CompressedTensorsKVCacheMethod(BaseKVCacheMethod):
super().__init__(quant_config) super().__init__(quant_config)
@staticmethod @staticmethod
def validate_kv_cache_scheme(kv_cache_scheme: Optional[Dict[str, Any]]): def validate_kv_cache_scheme(kv_cache_scheme: Optional[dict[str, Any]]):
""" """
Validator for the kv cache scheme. Useful for controlling the Validator for the kv cache scheme. Useful for controlling the
kv cache quantization schemes, that are being supported in vLLM kv cache quantization schemes, that are being supported in vLLM
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment