Unverified Commit 6223dd81 authored by Harry Mellor, committed by GitHub
Browse files

Update deprecated type hinting in `model_executor/layers` (#18056)


Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
parent 906f0598
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Callable, Dict, Optional
from typing import Any, Callable, Optional
import torch
......@@ -45,7 +45,7 @@ class QuarkMoEMethod(FusedMoEMethodBase):
class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
def __init__(self, weight_config: Dict[str, Any], input_config: Dict[str,
def __init__(self, weight_config: dict[str, Any], input_config: dict[str,
Any]):
self.weight_quant = weight_config
self.input_quant = input_config
......
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Optional
import torch
import torch.nn.functional as F
......@@ -18,8 +18,8 @@ __all__ = ["QuarkW4A4MXFP4"]
class QuarkW4A4MXFP4(QuarkScheme):
def __init__(self, weight_quant_spec: Dict[str, Any],
input_quant_spec: Dict[str, Any]):
def __init__(self, weight_quant_spec: dict[str, Any],
input_quant_spec: dict[str, Any]):
self.out_dtype = torch.get_default_dtype()
self.qscheme = "per_group"
self.weight_quant_spec = weight_quant_spec
......@@ -74,7 +74,7 @@ class QuarkW4A4MXFP4(QuarkScheme):
torch.cuda.empty_cache()
def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
......
# SPDX-License-Identifier: Apache-2.0
from typing import Callable, List, Optional
from typing import Callable, Optional
import torch
from torch.nn import Parameter
......@@ -88,7 +88,7 @@ class QuarkW8A8Fp8(QuarkScheme):
layer.input_scale = None
def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
......
# SPDX-License-Identifier: Apache-2.0
from typing import Callable, List, Optional, Set
from typing import Callable, Optional
import torch
......@@ -17,7 +17,7 @@ logger = init_logger(__name__)
class QuarkW8A8Int8(QuarkScheme):
_kernel_backends_being_used: Set[str] = set()
_kernel_backends_being_used: set[str] = set()
def __init__(self, qscheme: str, is_static_input_scheme: Optional[bool],
input_symmetric: Optional[bool]):
......@@ -31,7 +31,7 @@ class QuarkW8A8Int8(QuarkScheme):
return 75
def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
......
# SPDX-License-Identifier: Apache-2.0
import re
from collections.abc import Iterable, Mapping
from types import MappingProxyType
from typing import Any, Iterable, List, Mapping, Optional
from typing import Any, Optional
def deep_compare(dict1: Any, dict2: Any) -> bool:
......@@ -21,7 +22,7 @@ def deep_compare(dict1: Any, dict2: Any) -> bool:
def should_ignore_layer(
layer_name: Optional[str],
ignore: Iterable[str],
fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
fused_mapping: Mapping[str, list[str]] = MappingProxyType({})
) -> bool:
if layer_name is None:
return False
......
......@@ -12,7 +12,7 @@ possible on ROCm), the model can be optionally augmented with KV cache
scaling factors.
"""
from typing import Dict, Optional
from typing import Optional
from pydantic import BaseModel, ConfigDict, ValidationInfo, model_validator
......@@ -23,7 +23,7 @@ class KVCacheQuantSchema(BaseModel):
# layer indices to their per-tensor KV cache scaling factor.
# TODO: Consider pulling this and its validation methods out into its
# own schema class (tricky as its members are variable)
scaling_factor: Dict[int, Dict[int, float]]
scaling_factor: dict[int, dict[int, float]]
@model_validator(mode="after")
def check_is_fp8(self) -> "KVCacheQuantSchema":
......
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional
from typing import Any, Optional
import torch
import torch.nn.functional as F
......@@ -24,7 +24,7 @@ class TorchAOConfig(QuantizationConfig):
def get_name(self) -> QuantizationMethods:
return "torchao"
def get_supported_act_dtypes(self) -> List[torch.dtype]:
def get_supported_act_dtypes(self) -> list[torch.dtype]:
return [torch.float32, torch.float16, torch.bfloat16]
@classmethod
......@@ -32,11 +32,11 @@ class TorchAOConfig(QuantizationConfig):
return 75
@staticmethod
def get_config_filenames() -> List[str]:
def get_config_filenames() -> list[str]:
return ["config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "TorchAOConfig":
def from_config(cls, config: dict[str, Any]) -> "TorchAOConfig":
"""Create the quant config from an hf model config"""
try:
from torchao.core.config import config_from_dict
......@@ -60,7 +60,7 @@ class TorchAOConfig(QuantizationConfig):
return TorchAOLinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
def get_scaled_act_names(self) -> list[str]:
return []
......@@ -97,7 +97,7 @@ class TorchAOLinearMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
......
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Optional
import torch
from torch.nn import Module
......@@ -31,7 +31,7 @@ class Int8TpuConfig(QuantizationConfig):
def get_name(self) -> QuantizationMethods:
return "tpu_int8"
def get_supported_act_dtypes(self) -> List[torch.dtype]:
def get_supported_act_dtypes(self) -> list[torch.dtype]:
return [torch.float16, torch.bfloat16]
@classmethod
......@@ -40,11 +40,11 @@ class Int8TpuConfig(QuantizationConfig):
"This function should not be called with TPU Backend")
@staticmethod
def get_config_filenames() -> List[str]:
def get_config_filenames() -> list[str]:
return []
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "Int8TpuConfig":
def from_config(cls, config: dict[str, Any]) -> "Int8TpuConfig":
activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
return cls(activation_scheme=activation_scheme)
......@@ -62,7 +62,7 @@ class TPUInt8LinearMethod(LinearMethodBase):
self.quant_config = quant_config
def create_weights(self, layer: Module, input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int,
output_partition_sizes: list[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
......@@ -77,7 +77,7 @@ class TPUInt8LinearMethod(LinearMethodBase):
layer.register_parameter("weight", weight)
def _quantize_weight(
self, weight: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
self, weight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
weight_dtype = weight.dtype
weight = weight.cpu().to(torch.float32)
n_bit = 8
......
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple
from typing import Optional
import torch
......@@ -51,7 +51,7 @@ def _check_bitblas_supported(
quant_type: ScalarType,
group_size: Optional[int],
has_zp: bool,
device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]:
device_capability: Optional[int] = None) -> tuple[bool, Optional[str]]:
if device_capability is None:
capability_tuple = current_platform.get_device_capability()
......@@ -133,7 +133,7 @@ def verify_bitblas_supports_shape(output_size_per_partition: int,
def check_bitblas_supports_shape(output_size_per_partition: int,
input_size_per_partition: int,
input_size: int, group_size: int) \
-> Tuple[bool, Optional[str]]:
-> tuple[bool, Optional[str]]:
try:
verify_bitblas_supports_shape(output_size_per_partition,
input_size_per_partition, input_size,
......@@ -166,7 +166,7 @@ def bitblas_make_empty_zp(device: torch.device) -> torch.Tensor:
def bitblas_sort_g_idx(
g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
g_idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)
return g_idx[g_idx_sort_indices], g_idx_sort_indices
......
......@@ -4,7 +4,7 @@
import functools
import json
import os
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Optional, Union
import torch
......@@ -32,7 +32,7 @@ def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool:
def apply_w8a8_block_fp8_linear(
input: torch.Tensor,
weight: torch.Tensor,
block_size: List[int],
block_size: list[int],
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
......@@ -95,7 +95,7 @@ def apply_w8a8_block_fp8_linear(
def apply_w8a8_block_fp8_linear_fake(
input: torch.Tensor,
weight: torch.Tensor,
block_size: List[int],
block_size: list[int],
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
......@@ -114,7 +114,7 @@ direct_register_custom_op(
def input_to_float8(
x: torch.Tensor,
dtype: Optional[torch.dtype] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
"""This function quantizes input values to float8 values "
"with tensor-wise quantization."""
dtype = current_platform.fp8_dtype() if dtype is None else dtype
......@@ -129,7 +129,7 @@ def input_to_float8(
def block_quant_to_tensor_quant(
x_q_block: torch.Tensor,
x_s: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
"""This function converts block-wise quantization to tensor-wise
quantization. The inputs are block-wise quantization tensor `x_q_block`,
block-wise quantization scale and the block size.
......@@ -247,7 +247,7 @@ def per_token_group_quant_fp8(
eps: float = 1e-10,
dtype: Optional[torch.dtype] = None,
column_major_scales: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
"""Function to perform per-token-group quantization on an input tensor `x`.
It converts the tensor values into signed float8 values and returns the
quantized tensor along with the scaling factor used for quantization.
......@@ -258,7 +258,7 @@ def per_token_group_quant_fp8(
dtype: The dtype of the output tensor. Note that only `torch.float8_e4m3fn`
is supported for now.
Returns:
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
scaling factor for quantization.
"""
dtype = current_platform.fp8_dtype() if dtype is None else dtype
......@@ -412,7 +412,7 @@ def _w8a8_block_fp8_matmul(
@functools.lru_cache
def get_w8a8_block_fp8_configs(N: int, K: int, block_n: int,
block_k: int) -> Optional[Dict[int, Any]]:
block_k: int) -> Optional[dict[int, Any]]:
"""
Return optimized configurations for the w8a8 block fp8 kernel.
The return value will be a dictionary that maps an irregular grid of
......@@ -452,7 +452,7 @@ def w8a8_block_fp8_matmul(
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
block_size: List[int],
block_size: list[int],
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
"""This function performs matrix multiplication with block-wise
......
# SPDX-License-Identifier: Apache-2.0
import re
from copy import deepcopy
from typing import Dict, Optional, Union
from typing import Optional, Union
import torch
......@@ -52,7 +52,7 @@ def get_dynamic_override(
layer_name: str,
key: Optional[str] = None,
default_value: Union[int, bool,
None] = None) -> Union[Dict, int, bool, None]:
None] = None) -> Union[dict, int, bool, None]:
for pattern, pattern_dict in config.dynamic.items():
# Negative match: matched modules are excluded from quantized init
if pattern.startswith("-:"):
......
......@@ -5,7 +5,7 @@ import functools
import json
import logging
import os
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Optional
import torch
......@@ -18,7 +18,7 @@ logger = logging.getLogger(__name__)
def apply_w8a8_block_int8_linear(
input: torch.Tensor,
weight: torch.Tensor,
block_size: List[int],
block_size: list[int],
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
......@@ -43,7 +43,7 @@ def apply_w8a8_block_int8_linear(
def input_to_int8(
x: torch.Tensor,
dtype: torch.dtype = torch.int8) -> Tuple[torch.Tensor, torch.Tensor]:
dtype: torch.dtype = torch.int8) -> tuple[torch.Tensor, torch.Tensor]:
"""This function quantizes input values to int8 values with
tensor-wise quantization."""
iinfo = torch.iinfo(dtype)
......@@ -58,7 +58,7 @@ def input_to_int8(
def block_dequant(
x_q_block: torch.Tensor,
x_s: torch.Tensor,
block_size: List[int],
block_size: list[int],
) -> torch.Tensor:
"""This function conducts block-wise dequantization.
The inputs are block-wise quantization tensor `x_q_block`,
......@@ -211,7 +211,7 @@ def per_token_group_quant_int8(
group_size: int,
eps: float = 1e-10,
dtype: torch.dtype = torch.int8,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
"""Function to perform per-token-group quantization on an input tensor `x`.
It converts the tensor values into signed int8 values and returns the
......@@ -225,7 +225,7 @@ def per_token_group_quant_int8(
is supported for now.
Returns:
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
scaling factor for quantization.
"""
assert (x.shape[-1] % group_size == 0
......@@ -358,7 +358,7 @@ def _w8a8_block_int8_matmul(
@functools.lru_cache
def get_w8a8_block_int8_configs(N: int, K: int, block_n: int,
block_k: int) -> Optional[Dict[int, Any]]:
block_k: int) -> Optional[dict[int, Any]]:
"""
Return optimized configurations for the w8a8 block int8 kernel.
......@@ -399,7 +399,7 @@ def w8a8_block_int8_matmul(
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
block_size: List[int],
block_size: list[int],
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
"""This function performs matrix multiplication with block-wise
......
# SPDX-License-Identifier: Apache-2.0
from typing import List, Optional, Tuple
from typing import Optional
import torch
......@@ -10,19 +10,19 @@ MACHETE_SUPPORTED_GROUP_SIZES = [-1, 128]
MACHETE_PREPACKED_BLOCK_SHAPE = [64, 128]
def query_machete_supported_quant_types(zero_points: bool) -> List[ScalarType]:
def query_machete_supported_quant_types(zero_points: bool) -> list[ScalarType]:
if zero_points:
return [scalar_types.uint4, scalar_types.uint8]
else:
return [scalar_types.uint4b8, scalar_types.uint8b128]
def query_machete_supported_act_types(zero_points: bool) -> List[ScalarType]:
def query_machete_supported_act_types(zero_points: bool) -> list[ScalarType]:
return [torch.float16, torch.bfloat16]
def check_machete_supports_shape(in_features: int, out_featrues: int) \
-> Tuple[bool, Optional[str]]:
-> tuple[bool, Optional[str]]:
if in_features % MACHETE_PREPACKED_BLOCK_SHAPE[0] != 0:
return False, "Input features size must be divisible by "\
f"{MACHETE_PREPACKED_BLOCK_SHAPE[0]}"
......
# SPDX-License-Identifier: Apache-2.0
from typing import List, Optional, Tuple
from typing import Optional
import numpy
import torch
......@@ -70,7 +70,7 @@ def _check_marlin_supported(
quant_type: ScalarType,
group_size: Optional[int],
has_zp: bool,
device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]:
device_capability: Optional[int] = None) -> tuple[bool, Optional[str]]:
if device_capability is None:
capability_tuple = current_platform.get_device_capability()
......@@ -143,7 +143,7 @@ def verify_marlin_supports_shape(output_size_per_partition: int,
def check_marlin_supports_shape(output_size_per_partition: int,
input_size_per_partition: int,
input_size: int, group_size: int) \
-> Tuple[bool, Optional[str]]:
-> tuple[bool, Optional[str]]:
try:
verify_marlin_supports_shape(output_size_per_partition,
input_size_per_partition, input_size,
......@@ -231,16 +231,16 @@ def marlin_make_empty_zp(device: torch.device) -> torch.Tensor:
def marlin_sort_g_idx(
g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
g_idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)
return g_idx[g_idx_sort_indices], g_idx_sort_indices
def get_scale_perms():
scale_perm: List[int] = []
scale_perm: list[int] = []
for i in range(8):
scale_perm.extend([i + 8 * j for j in range(8)])
scale_perm_single: List[int] = []
scale_perm_single: list[int] = []
for i in range(4):
scale_perm_single.extend(
[2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
......
# SPDX-License-Identifier: Apache-2.0
"""Utility functions used for tests and benchmarks"""
from typing import List, Optional
from typing import Optional
import numpy as np
import torch
......@@ -64,9 +64,9 @@ def marlin_weights(q_w, size_k, size_n, num_bits, perm):
def get_weight_perm(num_bits: int):
perm_list: List[int] = []
perm_list: list[int] = []
for i in range(32):
perm1: List[int] = []
perm1: list[int] = []
col = i // 4
for block in [0, 1]:
for row in [
......
......@@ -2,7 +2,6 @@
"""Utility functions used for tests and benchmarks"""
import random
from typing import List
import numpy
import torch
......@@ -373,19 +372,19 @@ def compress_quantized_24_weight(q_24, size_k, size_n, wtype: ScalarType):
def get_scale_perms_24():
scale_perm: List[int] = []
scale_perm: list[int] = []
for i in range(8):
scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]])
scale_perm_single: List[int] = []
scale_perm_single: list[int] = []
for i in range(8):
scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]])
return scale_perm, scale_perm_single
def get_weight_perm_24(num_bits: int):
perm_list: List[int] = []
perm_list: list[int] = []
for i in range(32):
perm1: List[int] = []
perm1: list[int] = []
col = i // 4
col_o = col // 2
for block in [0, 1]:
......
# SPDX-License-Identifier: Apache-2.0
from typing import List
import numpy
import torch
......@@ -34,10 +32,10 @@ def marlin_qqq_weights(q_w, size_k, size_n, num_bits, perm, group_size):
def get_qqq_scale_perms():
scale_perm: List[int] = []
scale_perm: list[int] = []
for i in range(8):
scale_perm.extend([i + 8 * j for j in range(8)])
scale_perm_single: List[int] = []
scale_perm_single: list[int] = []
for i in range(4):
scale_perm_single.extend(
[2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
......@@ -46,9 +44,9 @@ def get_qqq_scale_perms():
# NOTE(HandH1998): QQQ employs different perms for per-group and per-channel weight quantization. # noqa: E501
def get_qqq_weight_perm(num_bits: int, quant_type: str):
perm_list: List[int] = []
perm_list: list[int] = []
for i in range(32):
perm1: List[int] = []
perm1: list[int] = []
col = i // 4
for block in [0, 1]:
for row in [
......
# SPDX-License-Identifier: Apache-2.0
from typing import Tuple
import torch
......@@ -9,7 +8,7 @@ OCP_MX_BLOCK_SIZE = 32
def per_token_group_quant_mxfp4(x: torch.Tensor,
block_k: int,
scale_calculation_mode: str = "even"
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
try:
from quark.torch.kernel.hw_emulation.hw_emulation_interface import (
fake_quantize_fp4_fp6_per_group_with_scale)
......
# SPDX-License-Identifier: Apache-2.0
"""This file is used for /tests and /benchmarks"""
from collections.abc import Mapping
from types import MappingProxyType
from typing import List, Mapping, Optional, Tuple
from typing import Optional
import numpy
import torch
......@@ -15,7 +16,7 @@ SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
# Normalize the group_shape to the full extent for any dims that are -1
def _normalize_quant_group_shape(x: torch.Tensor, group_shape: Tuple[int,
def _normalize_quant_group_shape(x: torch.Tensor, group_shape: tuple[int,
int]):
# -1 means full extent
return (group_shape[0] if group_shape[0] > 0 else x.shape[-2],
......@@ -56,9 +57,9 @@ def group_broadcast(t, shape):
# (i.e. per-token-per-group)
def scaled_quantize(
x: torch.Tensor,
group_shape: Tuple[int, int],
group_shape: tuple[int, int],
quant_dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
group_shape = _normalize_quant_group_shape(x, group_shape)
assert quant_dtype.is_floating_point, \
"currently `scaled_quantize` only supports floating point dtypes " \
......@@ -97,9 +98,9 @@ def scaled_quantize(
def scaled_dequantize(
x_q: torch.Tensor,
x_s: torch.Tensor,
group_shape: Optional[Tuple[int, int]] = None,
group_shape: Optional[tuple[int, int]] = None,
out_dtype: torch.dtype = torch.float32,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
if group_shape is not None:
group_shape = _normalize_quant_group_shape(x_q, group_shape)
......@@ -173,8 +174,8 @@ def unpack_quantized_values_into_int32(w_q: torch.Tensor,
def is_layer_skipped(
prefix: str,
ignored_layers: List[str],
fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
ignored_layers: list[str],
fused_mapping: Mapping[str, list[str]] = MappingProxyType({})
) -> bool:
# prefix: model.layers.0.self_attn.q_proj
# proj_name: q_proj
......
# SPDX-License-Identifier: Apache-2.0
from typing import Callable, List, Optional, Tuple, Union
from typing import Callable, Optional, Union
import torch
......@@ -81,7 +81,7 @@ def all_close_1d(x: torch.Tensor) -> bool:
def convert_to_channelwise(
weight_scale: torch.Tensor,
logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
logical_widths: list[int]) -> tuple[torch.Tensor, torch.Tensor]:
# Create channelwise buffer
weight_scale_channel = torch.empty((sum(logical_widths), 1),
dtype=torch.float32,
......@@ -99,7 +99,7 @@ def convert_to_channelwise(
def requantize_with_max_scale(
weight: torch.Tensor, weight_scale: torch.Tensor,
logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
logical_widths: list[int]) -> tuple[torch.Tensor, torch.Tensor]:
# Max scale to be used for requantization.
max_w_scale = weight_scale.max()
......@@ -136,7 +136,7 @@ def maybe_create_device_identity():
def cutlass_w8a8_scaled_mm(*, qinput: torch.Tensor, weight: torch.Tensor,
out_dtype: torch.dtype, scale_a: torch.Tensor,
scale_b: torch.Tensor, bias: torch.Tensor,
output_shape: List, **kwargs) -> torch.Tensor:
output_shape: list, **kwargs) -> torch.Tensor:
# Fused GEMM_DQ
output = ops.cutlass_scaled_mm(qinput,
......@@ -154,7 +154,7 @@ def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor, bias: torch.Tensor,
input_2d: torch.Tensor,
output_shape: List) -> torch.Tensor:
output_shape: list) -> torch.Tensor:
from vllm.platforms.rocm import on_mi250_mi300
if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi250_mi300(
) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0:
......@@ -177,7 +177,7 @@ def torch_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor, bias: torch.Tensor,
input_2d: torch.Tensor,
output_shape: List) -> torch.Tensor:
output_shape: list) -> torch.Tensor:
output = torch._scaled_mm(qinput,
weight,
out_dtype=out_dtype,
......@@ -198,7 +198,7 @@ def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor, bias: torch.Tensor,
input_2d: torch.Tensor,
output_shape: List) -> torch.Tensor:
output_shape: list) -> torch.Tensor:
# Note: Callers of this function should check USE_ROWWISE_TORCH_SCALED_MM
# when using it.
# For now it has only been validated on ROCm platform.
......@@ -228,7 +228,7 @@ def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor, bias: torch.Tensor,
input_2d: torch.Tensor,
output_shape: List,
output_shape: list,
**kwargs) -> torch.Tensor:
# Use unfused DQ due to limitations with scaled_mm
......@@ -384,7 +384,7 @@ def normalize_e4m3fn_to_e4m3fnuz(
weight: torch.Tensor,
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
assert weight.dtype == torch.float8_e4m3fn
# The bits pattern 10000000(-128) represents zero in e4m3fn
# but NaN in e4m3fnuz. So here we set it to 0.
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment