Unverified Commit cf069aa8 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Update deprecated Python 3.8 typing (#13971)

parent bf33700e
# SPDX-License-Identifier: Apache-2.0
from typing import Iterable, List, Tuple, Union
from collections.abc import Iterable
from typing import Union
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage,
......@@ -12,7 +13,7 @@ from vllm.entrypoints.openai.tool_parsers import ToolParser
class StreamingToolReconstructor:
def __init__(self, assert_one_tool_per_delta: bool = True):
self.tool_calls: List[ToolCall] = []
self.tool_calls: list[ToolCall] = []
self.other_content: str = ""
self._assert_one_tool_per_delta = assert_one_tool_per_delta
......@@ -72,7 +73,7 @@ def run_tool_extraction(
request: Union[ChatCompletionRequest, None] = None,
streaming: bool = False,
assert_one_tool_per_delta: bool = True,
) -> Tuple[Union[str, None], List[ToolCall]]:
) -> tuple[Union[str, None], list[ToolCall]]:
if streaming:
reconstructor = run_tool_extraction_streaming(
tool_parser,
......@@ -106,7 +107,7 @@ def run_tool_extraction_streaming(
reconstructor = StreamingToolReconstructor(
assert_one_tool_per_delta=assert_one_tool_per_delta)
previous_text = ""
previous_tokens: List[int] = []
previous_tokens: list[int] = []
for delta in model_deltas:
token_delta = [
tool_parser.vocab.get(token)
......
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple, Union
from typing import Optional, Union
import torch
......@@ -19,7 +19,7 @@ def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor:
def ref_dynamic_per_token_quant(x: torch.tensor,
quant_dtype: torch.dtype,
scale_ub: Optional[torch.tensor] = None) \
-> Tuple[torch.tensor, torch.tensor]:
-> tuple[torch.tensor, torch.tensor]:
assert quant_dtype in [torch.int8, FP8_DTYPE]
if scale_ub is not None:
......@@ -68,7 +68,7 @@ def ref_dynamic_per_token_quant(x: torch.tensor,
# ref_dynamic_per_token_quant, when we have a dynamic_per_tensor int8 quant
# kernel
def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
-> Tuple[torch.tensor, torch.tensor]:
-> tuple[torch.tensor, torch.tensor]:
fp8_traits = torch.finfo(FP8_DTYPE)
fp8_traits_max = ROCM_FP8_MAX if current_platform.is_rocm() \
......
# SPDX-License-Identifier: Apache-2.0
import random
from typing import Type
import pytest
import torch
......@@ -86,7 +85,7 @@ def test_act_and_mul(
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_activation(
activation: Type[torch.nn.Module],
activation: type[torch.nn.Module],
num_tokens: int,
d: int,
dtype: torch.dtype,
......
# SPDX-License-Identifier: Apache-2.0
import random
from typing import List, Optional, Tuple
from typing import Optional
import pytest
import torch
......@@ -85,8 +85,8 @@ def ref_single_query_cached_kv_attention(
block_table = block_tables_lst[i]
seq_len = int(seq_lens_lst[i])
keys_lst: List[torch.Tensor] = []
values_lst: List[torch.Tensor] = []
keys_lst: list[torch.Tensor] = []
values_lst: list[torch.Tensor] = []
for j in range(seq_len):
block_number = int(block_table[j // block_size])
block_offset = j % block_size
......@@ -133,7 +133,7 @@ def test_paged_attention(
kv_cache_factory,
version: str,
num_seqs: int,
num_heads: Tuple[int, int],
num_heads: tuple[int, int],
head_size: int,
use_alibi: bool,
block_size: int,
......@@ -166,7 +166,7 @@ def test_paged_attention(
# Create the block tables.
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
block_tables_lst: List[List[int]] = []
block_tables_lst: list[list[int]] = []
for _ in range(num_seqs):
block_table = [
random.randint(0, NUM_BLOCKS - 1)
......@@ -334,7 +334,7 @@ def test_paged_attention(
def ref_multi_query_kv_attention(
cu_seq_lens: List[int],
cu_seq_lens: list[int],
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
......@@ -342,7 +342,7 @@ def ref_multi_query_kv_attention(
dtype: torch.dtype,
) -> torch.Tensor:
num_seqs = len(cu_seq_lens) - 1
ref_outputs: List[torch.Tensor] = []
ref_outputs: list[torch.Tensor] = []
for i in range(num_seqs):
start_idx = cu_seq_lens[i]
end_idx = cu_seq_lens[i + 1]
......@@ -378,7 +378,7 @@ def ref_multi_query_kv_attention(
@torch.inference_mode()
def test_multi_query_kv_attention(
num_seqs: int,
num_heads: Tuple[int, int],
num_heads: tuple[int, int],
head_size: int,
dtype: torch.dtype,
seed: int,
......
# SPDX-License-Identifier: Apache-2.0
import random
from typing import List, Optional, Tuple
from typing import Optional
import pytest
import torch
......@@ -87,8 +87,8 @@ def ref_single_query_cached_kv_attention(
block_table = block_tables_lst[i]
seq_len = int(seq_lens_lst[i])
keys_lst: List[torch.Tensor] = []
values_lst: List[torch.Tensor] = []
keys_lst: list[torch.Tensor] = []
values_lst: list[torch.Tensor] = []
for j in range(seq_len):
block_number = int(block_table[j // block_size])
block_offset = j % block_size
......@@ -162,7 +162,7 @@ def test_paged_attention(
kv_cache_factory,
version: str,
num_seqs: int,
num_heads: Tuple[int, int],
num_heads: tuple[int, int],
head_size: int,
use_alibi: bool,
block_size: int,
......@@ -331,7 +331,7 @@ def test_paged_attention(
def ref_multi_query_kv_attention(
cu_seq_lens: List[int],
cu_seq_lens: list[int],
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
......@@ -376,7 +376,7 @@ def ref_multi_query_kv_attention(
@torch.inference_mode()
def test_varlen_blocksparse_attention_prefill(
num_seqs: int,
num_heads: Tuple[int, int],
num_heads: tuple[int, int],
head_size: int,
blocksparse_local_blocks: int,
blocksparse_vert_stride: int,
......
# SPDX-License-Identifier: Apache-2.0
import random
from typing import List, Tuple
import pytest
import torch
......@@ -74,7 +73,7 @@ def test_copy_blocks(
src_blocks = random.sample(range(num_blocks), num_mappings)
remainig_blocks = list(set(range(num_blocks)) - set(src_blocks))
dst_blocks = random.sample(remainig_blocks, 2 * num_mappings)
block_mapping: List[Tuple[int, int]] = []
block_mapping: list[tuple[int, int]] = []
for i in range(num_mappings):
src = src_blocks[i]
dst1 = dst_blocks[2 * i]
......@@ -342,7 +341,7 @@ def test_reshape_and_cache_flash(
@torch.inference_mode()
def test_swap_blocks(
kv_cache_factory,
direction: Tuple[str, str],
direction: tuple[str, str],
num_mappings: int,
num_heads: int,
head_size: int,
......
# SPDX-License-Identifier: Apache-2.0
from typing import List, Optional, Tuple
from typing import Optional
import pytest
import torch
......@@ -25,7 +25,7 @@ DTYPES = [torch.float16, torch.bfloat16]
@torch.inference_mode()
def test_merge_kernel(
num_tokens: int,
num_heads: Tuple[int, int],
num_heads: tuple[int, int],
head_size: int,
dtype: torch.dtype,
):
......@@ -85,8 +85,8 @@ CASES = [
@pytest.mark.parametrize("fa_version", [2, 3])
@torch.inference_mode()
def test_cascade(
seq_lens_and_common_prefix: Tuple[List[Tuple[int, int]], int],
num_heads: Tuple[int, int],
seq_lens_and_common_prefix: tuple[list[tuple[int, int]], int],
num_heads: tuple[int, int],
head_size: int,
dtype: torch.dtype,
block_size: int,
......
......@@ -3,7 +3,6 @@
Run `pytest tests/kernels/test_cutlass.py`.
"""
from typing import Type
import pytest
import torch
......@@ -71,7 +70,7 @@ def cutlass_fp8_gemm_helper(m: int,
a_scale_group_shape: tuple,
b_scale_group_shape: tuple,
use_bias: bool,
out_dtype: Type[torch.dtype] = torch.bfloat16,
out_dtype: type[torch.dtype] = torch.bfloat16,
device: str = "cuda"):
# Test for a cutlass kernel with per-token activation quantization
# and per-output channel weight quantization.
......@@ -109,7 +108,7 @@ def cutlass_int8_gemm_helper(m: int,
a_scale_group_shape: tuple,
b_scale_group_shape: tuple,
use_bias: bool,
out_dtype: Type[torch.dtype] = torch.bfloat16,
out_dtype: type[torch.dtype] = torch.bfloat16,
device: str = "cuda"):
# Test for a cutlass kernel with per-token activation quantization
# and per-output channel weight quantization.
......@@ -187,7 +186,7 @@ def test_cutlass_int8_gemm(m: int, n: int, k: int, a_scale_group_shape,
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_int8_gemm_output_dtype(a_scale_group_shape,
b_scale_group_shape,
out_dtype: Type[torch.dtype],
out_dtype: type[torch.dtype],
use_bias: bool):
cutlass_int8_gemm_helper(512,
512,
......@@ -208,7 +207,7 @@ def test_cutlass_int8_gemm_output_dtype(a_scale_group_shape,
reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm_output_dtype(a_scale_group_shape,
b_scale_group_shape,
out_dtype: Type[torch.dtype],
out_dtype: type[torch.dtype],
use_bias: bool):
cutlass_fp8_gemm_helper(512,
512,
......@@ -227,7 +226,7 @@ def test_cutlass_fp8_gemm_output_dtype(a_scale_group_shape,
reason="FP8 blockwise is not supported on this GPU type.")
def test_cutlass_fp8_blockwise_scale_gemm_dtype(a_scale_group_shape,
b_scale_group_shape,
out_dtype: Type[torch.dtype],
out_dtype: type[torch.dtype],
use_bias: bool):
cutlass_fp8_gemm_helper(512,
512,
......
......@@ -3,7 +3,6 @@
Run `pytest tests/kernels/test_semi_structured.py`.
"""
from typing import Tuple, Type
import pytest
import torch
......@@ -79,7 +78,7 @@ def check_compress_decompress_invariance(dtype: torch.dtype, b: torch.Tensor,
def make_rand_sparse_tensors(
dtype: torch.dtype, m: int, n: int, k: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
a = torch.randn((m, k), device='cuda')
b = torch.randn((n, k), device='cuda').t()
......@@ -167,7 +166,7 @@ MNK_FACTORS = [
@pytest.mark.parametrize("m, n, k", MNK_FACTORS)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_sparse_gemm(m: int, k: int, n: int, dtype: Type[torch.dtype],
def test_cutlass_sparse_gemm(m: int, k: int, n: int, dtype: type[torch.dtype],
use_bias: bool):
# Create tensors
......
......@@ -243,7 +243,7 @@ def _decoder_attn_setup(
test_pt: TestPoint,
test_rsrcs: TestResources,
block_base_addr: int = 0,
) -> Tuple[QKVInputs, PhaseTestParameters, PhaseTestParameters, int]:
) -> tuple[QKVInputs, PhaseTestParameters, PhaseTestParameters, int]:
'''
Set up test vectors & data structures for self-attention test.
......@@ -421,7 +421,7 @@ def _enc_dec_cross_attn_setup_reuses_query(
test_pt: TestPoint,
test_rsrcs: TestResources,
block_base_addr: int = 0,
) -> Tuple[PhaseTestParameters, PhaseTestParameters]:
) -> tuple[PhaseTestParameters, PhaseTestParameters]:
'''
Set up test vectors & data structures for cross-attention test.
......
# SPDX-License-Identifier: Apache-2.0
from typing import List, Optional, Tuple
from typing import Optional
import pytest
import torch
......@@ -24,8 +24,8 @@ def ref_paged_attn(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
query_lens: List[int],
kv_lens: List[int],
query_lens: list[int],
kv_lens: list[int],
block_tables: torch.Tensor,
scale: float,
sliding_window: Optional[int] = None,
......@@ -35,7 +35,7 @@ def ref_paged_attn(
block_tables = block_tables.cpu().numpy()
_, block_size, num_kv_heads, head_size = key_cache.shape
outputs: List[torch.Tensor] = []
outputs: list[torch.Tensor] = []
start_idx = 0
for i in range(num_seqs):
query_len = query_lens[i]
......@@ -88,8 +88,8 @@ def ref_paged_attn(
@torch.inference_mode()
def test_flash_attn_with_paged_kv(
use_out: bool,
kv_lens: List[int],
num_heads: Tuple[int, int],
kv_lens: list[int],
num_heads: tuple[int, int],
head_size: int,
dtype: torch.dtype,
block_size: int,
......@@ -174,8 +174,8 @@ def test_flash_attn_with_paged_kv(
@torch.inference_mode()
def test_varlen_with_paged_kv(
use_out: bool,
seq_lens: List[Tuple[int, int]],
num_heads: Tuple[int, int],
seq_lens: list[tuple[int, int]],
num_heads: tuple[int, int],
head_size: int,
sliding_window: Optional[int],
dtype: torch.dtype,
......
# SPDX-License-Identifier: Apache-2.0
from typing import List, Optional, Tuple
from typing import Optional
import flashinfer
import pytest
......@@ -19,8 +19,8 @@ def ref_paged_attn(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
query_lens: List[int],
kv_lens: List[int],
query_lens: list[int],
kv_lens: list[int],
block_tables: torch.Tensor,
scale: float,
sliding_window: Optional[int] = None,
......@@ -30,7 +30,7 @@ def ref_paged_attn(
block_tables = block_tables.cpu().numpy()
_, block_size, num_kv_heads, head_size = key_cache.shape
outputs: List[torch.Tensor] = []
outputs: list[torch.Tensor] = []
start_idx = 0
for i in range(num_seqs):
query_len = query_lens[i]
......@@ -78,8 +78,8 @@ def ref_paged_attn(
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
@torch.inference_mode
def test_flashinfer_decode_with_paged_kv(
kv_lens: List[int],
num_heads: Tuple[int, int],
kv_lens: list[int],
num_heads: tuple[int, int],
head_size: int,
dtype: torch.dtype,
block_size: int,
......@@ -168,8 +168,8 @@ def test_flashinfer_decode_with_paged_kv(
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
@torch.inference_mode
def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
num_heads: Tuple[int, int],
def test_flashinfer_prefill_with_paged_kv(seq_lens: list[tuple[int, int]],
num_heads: tuple[int, int],
head_size: int, dtype: torch.dtype,
block_size: int,
soft_cap: Optional[float]) -> None:
......@@ -270,7 +270,7 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
def test_flashinfer_prefill_with_paged_fp8_kv(
seq_lens: List[Tuple[int, int]], num_heads: Tuple[int, int],
seq_lens: list[tuple[int, int]], num_heads: tuple[int, int],
head_size: int, dtype: torch.dtype, block_size: int,
soft_cap: Optional[float]) -> None:
pytest.skip("TODO: fix the accuracy issue")
......@@ -378,8 +378,8 @@ def test_flashinfer_prefill_with_paged_fp8_kv(
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
@torch.inference_mode
def test_flashinfer_decode_with_paged_fp8_kv(
kv_lens: List[int],
num_heads: Tuple[int, int],
kv_lens: list[int],
num_heads: tuple[int, int],
head_size: int,
dtype: torch.dtype,
block_size: int,
......
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple, Union
from typing import Optional, Union
import pytest
import torch
......@@ -39,7 +39,7 @@ def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor:
def ref_rms_norm(rms_norm_layer: RMSNorm,
x: torch.Tensor,
residual: Optional[torch.Tensor]) \
-> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-> tuple[torch.Tensor, Optional[torch.Tensor]]:
if residual is not None:
residual = residual.clone()
out, residual = rms_norm_layer.forward_native(x, residual)
......@@ -54,7 +54,7 @@ def ref_dynamic_per_token_quant(rms_norm_layer: RMSNorm,
quant_dtype: torch.dtype,
residual: Optional[torch.Tensor],
scale_ub: Optional[torch.Tensor]) \
-> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
-> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
if scale_ub is not None:
assert quant_dtype == torch.float8_e4m3fn
......@@ -78,7 +78,7 @@ def ref_impl(rms_norm_layer: RMSNorm,
quant_dtype: torch.dtype,
residual: Optional[torch.Tensor],
scale_ub: Optional[torch.Tensor]) \
-> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
-> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
return ref_dynamic_per_token_quant(rms_norm_layer, x, quant_dtype,
residual, scale_ub)
......@@ -88,7 +88,7 @@ def ops_dynamic_per_token_quant(weight: torch.Tensor,
quant_dtype: torch.dtype,
residual: Optional[torch.Tensor],
scale_ub: Optional[torch.Tensor]) \
-> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
-> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
if residual is not None:
residual = residual.clone()
out, scales = ops.rms_norm_dynamic_per_token_quant(x, weight, EPS,
......@@ -102,7 +102,7 @@ def ops_impl(weight: torch.Tensor,
quant_dtype: torch.dtype,
residual: Optional[torch.Tensor],
scale_ub: Optional[torch.Tensor]) \
-> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
-> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
return ops_dynamic_per_token_quant(weight, x, quant_dtype, residual,
scale_ub)
......
# SPDX-License-Identifier: Apache-2.0
from pathlib import Path
from typing import List
import pytest
import torch
......@@ -16,7 +15,7 @@ GGUF_SAMPLE = snapshot_download("Isotr0py/test-gguf-sample")
def get_gguf_sample_tensors(
hidden_size: int,
quant_type: GGMLQuantizationType) -> List[ReaderTensor]:
quant_type: GGMLQuantizationType) -> list[ReaderTensor]:
sample_dir = GGUF_SAMPLE
filename = f"Quant_{quant_type.name}_{hidden_size}.gguf"
sample_file = Path(sample_dir) / filename
......
......@@ -6,7 +6,7 @@ Run `pytest tests/kernels/test_machete_mm.py`.
import math
from dataclasses import dataclass, fields
from typing import List, Optional, Tuple
from typing import Optional
import pytest
import torch
......@@ -45,7 +45,7 @@ MNK_SHAPES = [
(1024, 8192, 4096),
]
GROUP_SIZES_TO_TEST: List[Optional[int]] = [128, -1]
GROUP_SIZES_TO_TEST: list[Optional[int]] = [128, -1]
@dataclass
......@@ -75,7 +75,7 @@ class Tensors:
# Ch Scales Type, Tok Scales Type)
# NOTE: None "Scale Type" means the act type is floating point
# None "Output Type" means the output type is the same as the act type
TestTypeTuple = Tuple[List[torch.dtype], ScalarType, Optional[torch.dtype],
TestTypeTuple = tuple[list[torch.dtype], ScalarType, Optional[torch.dtype],
Optional[torch.dtype], bool]
TEST_TYPES = [
# GPTQ style
......@@ -136,7 +136,7 @@ def maybe_convert_zeropoints(zps: Optional[torch.Tensor], s: torch.Tensor):
return zps if zps is None else -1 * s * (zps.to(s.dtype))
def group_size_valid(shape: Tuple[int, int, int],
def group_size_valid(shape: tuple[int, int, int],
group_size: Optional[int]) -> bool:
return group_size is None or group_size == -1 or group_size % shape[2] == 0
......@@ -166,7 +166,7 @@ def machete_quantize_and_pack(atype: torch.dtype,
return w_ref, w_q_machete, w_s, w_zp
def create_test_tensors(shape: Tuple[int, int, int],
def create_test_tensors(shape: tuple[int, int, int],
types: TypeConfig,
group_size: Optional[int],
subset_stride_factor: Optional[int] = None) -> Tensors:
......@@ -265,7 +265,7 @@ def machete_mm_test_helper(types: TypeConfig,
@pytest.mark.parametrize("types", TEST_TYPES)
def test_machete_all_schedules(shape, types: TypeConfig):
group_sizes: List[Optional[int]] = []
group_sizes: list[Optional[int]] = []
if types.group_scale_type is None:
group_sizes = [None]
else:
......@@ -294,7 +294,7 @@ def test_machete_all_schedules(shape, types: TypeConfig):
ids=lambda x: "x".join(str(v) for v in x))
@pytest.mark.parametrize("types", TEST_TYPES)
def test_machete_heuristic(shape, types: TypeConfig):
group_sizes: List[Optional[int]] = []
group_sizes: list[Optional[int]] = []
if types.group_scale_type is None:
group_sizes = [None]
else:
......
# SPDX-License-Identifier: Apache-2.0
import unittest
from typing import Tuple
import pytest
import torch
......@@ -29,7 +28,7 @@ from vllm.utils import update_environment_variables
def test_mixer2_gated_norm_multi_gpu(
batch_size: int,
seq_len: int,
hidden_size_n_groups: Tuple[int, int],
hidden_size_n_groups: tuple[int, int],
dtype: torch.dtype,
device: str = 'cuda',
):
......
# SPDX-License-Identifier: Apache-2.0
from typing import Dict, Tuple
import pytest
import torch
import torch.nn.functional as F
......@@ -134,7 +132,7 @@ def generate_continous_batched_examples(example_lens_by_batch,
# given a tuple of lengths for each example in the batch
# e.g., example_lens=(8, 4) means take 8 samples from first eg,
# 4 examples from second eg, etc
def get_continuous_batch(example_lens: Tuple[int, ...]):
def get_continuous_batch(example_lens: tuple[int, ...]):
indices = []
for i, x in enumerate(example_lens):
......@@ -264,8 +262,8 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
# hold state during the cutting process so we know if an
# example has been exhausted and needs to cycle
last_taken: Dict = {} # map: eg -> pointer to last taken sample
exhausted: Dict = {} # map: eg -> boolean indicating example is exhausted
last_taken: dict = {} # map: eg -> pointer to last taken sample
exhausted: dict = {} # map: eg -> boolean indicating example is exhausted
states = None
for Y_min, cu_seqlens, sed_idx, (A, dt, X, B,
......
# SPDX-License-Identifier: Apache-2.0
from itertools import accumulate, product
from typing import Callable, Dict, List, Optional
from typing import Callable, Optional
import pytest
import torch
......@@ -179,7 +179,7 @@ def test_batched_rotary_embedding_multi_lora(
torch.set_default_device(device)
if rotary_dim is None:
rotary_dim = head_size
scaling_factors: List[int] = [1, 2, 4]
scaling_factors: list[int] = [1, 2, 4]
rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, {
"rope_type": "linear",
"factor": tuple(scaling_factors)
......@@ -234,7 +234,7 @@ def test_rope_module_cache():
})
settings = (HEAD_SIZES, ROTARY_DIMS, MAX_POSITIONS, BASES, IS_NEOX_STYLE,
ROPE_SCALINGS, DTYPES)
rope_setting_id_map: Dict[str, int] = {}
rope_setting_id_map: dict[str, int] = {}
for setting in product(*settings):
head_size, rotary_dim, max_position, base, \
is_neox_stype, rope_scaling, dtype = setting
......
......@@ -4,7 +4,7 @@
Run `pytest tests/kernels/test_triton_scaled_mm.py`.
"""
import importlib
from typing import Optional, Type
from typing import Optional
import pytest
import torch
......@@ -18,7 +18,7 @@ def scaled_mm_torch(a: torch.Tensor,
b: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: Type[torch.dtype],
out_dtype: type[torch.dtype],
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
out = torch.mm(a.to(torch.float32), b.to(torch.float32))
out = scale_a * out
......
......@@ -4,9 +4,9 @@
import itertools
import random
import unittest
from collections.abc import Sequence
from numbers import Number
from typing import (Any, Dict, List, NamedTuple, Optional, Sequence, Tuple,
Type, Union)
from typing import Any, NamedTuple, Optional, Union
import pytest
import torch
......@@ -20,13 +20,13 @@ from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL,
# For now, disable "test_aot_dispatch_dynamic" since there are some
# bugs related to this test in PyTorch 2.4.
DEFAULT_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
DEFAULT_OPCHECK_TEST_UTILS: tuple[str, ...] = (
"test_schema",
"test_autograd_registration",
"test_faketensor",
)
ALL_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
ALL_OPCHECK_TEST_UTILS: tuple[str, ...] = (
"test_schema",
"test_autograd_registration",
"test_faketensor",
......@@ -50,8 +50,8 @@ class QKVInputs(NamedTuple):
query: torch.Tensor
key: torch.Tensor
value: torch.Tensor
q_seq_lens: List[int]
kv_seq_lens: List[int]
q_seq_lens: list[int]
kv_seq_lens: list[int]
class QKVO(NamedTuple):
......@@ -89,10 +89,10 @@ class PackedQKVInputs(NamedTuple):
query: torch.Tensor
key: torch.Tensor
value: torch.Tensor
q_start_loc_list: Optional[List[int]]
kv_start_loc_list: Optional[List[int]]
q_seq_lens: Optional[List[int]]
kv_seq_lens: Optional[List[int]]
q_start_loc_list: Optional[list[int]]
kv_start_loc_list: Optional[list[int]]
q_seq_lens: Optional[list[int]]
kv_seq_lens: Optional[list[int]]
class PackedQKVO(NamedTuple):
......@@ -146,7 +146,7 @@ class PhaseTestParameters(NamedTuple):
def maybe_make_int_tensor(
_list: Optional[List[int]],
_list: Optional[list[int]],
device: Union[torch.device, str],
) -> torch.Tensor:
'''
......@@ -162,7 +162,7 @@ def maybe_make_int_tensor(
def maybe_make_long_tensor(
_list: Optional[List[int]],
_list: Optional[list[int]],
device: Union[torch.device, str],
) -> torch.Tensor:
'''
......@@ -177,7 +177,7 @@ def maybe_make_long_tensor(
_list, dtype=torch.long, device=device)
def maybe_max(_list: Optional[List]) -> Optional[Number]:
def maybe_max(_list: Optional[list]) -> Optional[Number]:
'''
Returns:
......@@ -232,8 +232,8 @@ def ref_masked_attention(query: torch.Tensor,
value: torch.Tensor,
scale: float,
custom_mask: Optional[torch.Tensor] = None,
q_seq_lens: Optional[List] = None,
kv_seq_lens: Optional[List] = None) -> torch.Tensor:
q_seq_lens: Optional[list] = None,
kv_seq_lens: Optional[list] = None) -> torch.Tensor:
'''
"Golden" masked attention reference. Supports two types of masking:
......@@ -295,10 +295,10 @@ def make_qkv(
num_heads: int,
head_size: int,
device: Union[torch.device, str],
force_kv_seq_lens: Optional[List[int]] = None,
force_kv_seq_lens: Optional[list[int]] = None,
attn_type: AttentionType = AttentionType.ENCODER_DECODER,
force_max_len: bool = False,
) -> Tuple[QKVInputs, QKVInputs, QKVInputs]:
) -> tuple[QKVInputs, QKVInputs, QKVInputs]:
'''
Construct QKV test tensors for self- and cross-attention.
......@@ -429,8 +429,8 @@ def make_qkv(
def pack_tensor(
unpacked_tensor: torch.Tensor, seq_lens: List[int],
device: Union[torch.device, str]) -> Tuple[torch.Tensor, List[int]]:
unpacked_tensor: torch.Tensor, seq_lens: list[int],
device: Union[torch.device, str]) -> tuple[torch.Tensor, list[int]]:
'''
Pack a batch_size x padded_seq_len x num_heads x head_size tensor into an
unpadded number_of_tokens x num_heads x head_size tensor, where
......@@ -537,11 +537,11 @@ def make_backend(backend_name: str) -> AttentionBackend:
def _make_metadata_tensors(
seq_lens: Optional[List[int]],
context_lens: Optional[List[int]],
encoder_seq_lens: Optional[List[int]],
seq_lens: Optional[list[int]],
context_lens: Optional[list[int]],
encoder_seq_lens: Optional[list[int]],
device: Union[torch.device, str],
) -> Tuple[torch.Tensor, torch.Tensor, Any, Any, Optional[torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor, Any, Any, Optional[torch.Tensor],
torch.Tensor, torch.Tensor, Optional[int]]:
'''
Build scalar & tensor values required to build attention metadata structure.
......@@ -654,7 +654,7 @@ def make_empty_block_tables_tensor(device: Union[torch.device, str]):
return torch.tensor([], device=device)
def split_slot_mapping(slot_mapping_list: torch.Tensor, seq_lens: List[int],
def split_slot_mapping(slot_mapping_list: torch.Tensor, seq_lens: list[int],
device: Union[torch.device, str]):
'''
Split a slot mapping into valid prefill- and decode-phase slot mappings.
......@@ -682,9 +682,9 @@ def split_slot_mapping(slot_mapping_list: torch.Tensor, seq_lens: List[int],
Arguments:
* slot_mapping_list: Length-P 1D slot mapping (as List) reflecting all N
* slot_mapping_list: Length-P 1D slot mapping (as list) reflecting all N
post-decode sequences
* seq_lens: List of N post-decode sequence lengths (K_i + 1 in the
* seq_lens: list of N post-decode sequence lengths (K_i + 1 in the
description above)
* device: cuda, cpu, etc.
......@@ -712,9 +712,9 @@ def split_slot_mapping(slot_mapping_list: torch.Tensor, seq_lens: List[int],
def make_block_tables_slot_mapping(
block_size: int,
seq_lens: List[int],
seq_lens: list[int],
device: Union[torch.device, str],
block_base_addr: int = 0) -> Tuple[torch.Tensor, List[int], int]:
block_base_addr: int = 0) -> tuple[torch.Tensor, list[int], int]:
'''
Construct fake block tables & slot mappings.
......@@ -794,7 +794,7 @@ def make_block_tables_slot_mapping(
def make_test_metadata(
attn_backend: _Backend,
is_prompt: bool,
seq_lens: Optional[List[int]],
seq_lens: Optional[list[int]],
decoder_test_params: Optional[PhaseTestParameters],
device: Union[torch.device, str],
encoder_test_params: Optional[PhaseTestParameters] = None,
......@@ -1043,7 +1043,7 @@ def fp8_allclose(
# Marlin MoE test utils
def stack_and_dev(tensors: List[torch.Tensor]):
def stack_and_dev(tensors: list[torch.Tensor]):
dev = tensors[0].device
return torch.stack(tensors, dim=0).to(dev)
......@@ -1090,12 +1090,12 @@ def torch_moe_single(a, w, score, topk):
# and a patched version of allclose that supports fp8 types.
def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket,
torch._library.custom_ops.CustomOpDef],
args: Tuple[Any, ...],
kwargs: Optional[Dict[str, Any]] = None,
args: tuple[Any, ...],
kwargs: Optional[dict[str, Any]] = None,
*,
test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS,
raise_exception: bool = True,
cond: bool = True) -> Dict[str, str]:
cond: bool = True) -> dict[str, str]:
with unittest.mock.patch('torch.allclose', new=fp8_allclose):
return torch.library.opcheck(
op,
......@@ -1120,7 +1120,7 @@ def baseline_scaled_mm(a: torch.Tensor,
b: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: Type[torch.dtype],
out_dtype: type[torch.dtype],
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
# We treat N-dimensional group scaling as extended numpy-style broadcasting
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment