"vscode:/vscode.git/clone" did not exist on "b4b78d63170ff0b1e5310c295473109d92ee51c2"
Unverified Commit cf069aa8 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Update deprecated Python 3.8 typing (#13971)

parent bf33700e
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Iterable, List, Tuple, Union from collections.abc import Iterable
from typing import Union
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage, DeltaMessage,
...@@ -12,7 +13,7 @@ from vllm.entrypoints.openai.tool_parsers import ToolParser ...@@ -12,7 +13,7 @@ from vllm.entrypoints.openai.tool_parsers import ToolParser
class StreamingToolReconstructor: class StreamingToolReconstructor:
def __init__(self, assert_one_tool_per_delta: bool = True): def __init__(self, assert_one_tool_per_delta: bool = True):
self.tool_calls: List[ToolCall] = [] self.tool_calls: list[ToolCall] = []
self.other_content: str = "" self.other_content: str = ""
self._assert_one_tool_per_delta = assert_one_tool_per_delta self._assert_one_tool_per_delta = assert_one_tool_per_delta
...@@ -72,7 +73,7 @@ def run_tool_extraction( ...@@ -72,7 +73,7 @@ def run_tool_extraction(
request: Union[ChatCompletionRequest, None] = None, request: Union[ChatCompletionRequest, None] = None,
streaming: bool = False, streaming: bool = False,
assert_one_tool_per_delta: bool = True, assert_one_tool_per_delta: bool = True,
) -> Tuple[Union[str, None], List[ToolCall]]: ) -> tuple[Union[str, None], list[ToolCall]]:
if streaming: if streaming:
reconstructor = run_tool_extraction_streaming( reconstructor = run_tool_extraction_streaming(
tool_parser, tool_parser,
...@@ -106,7 +107,7 @@ def run_tool_extraction_streaming( ...@@ -106,7 +107,7 @@ def run_tool_extraction_streaming(
reconstructor = StreamingToolReconstructor( reconstructor = StreamingToolReconstructor(
assert_one_tool_per_delta=assert_one_tool_per_delta) assert_one_tool_per_delta=assert_one_tool_per_delta)
previous_text = "" previous_text = ""
previous_tokens: List[int] = [] previous_tokens: list[int] = []
for delta in model_deltas: for delta in model_deltas:
token_delta = [ token_delta = [
tool_parser.vocab.get(token) tool_parser.vocab.get(token)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple, Union from typing import Optional, Union
import torch import torch
...@@ -19,7 +19,7 @@ def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor: ...@@ -19,7 +19,7 @@ def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor:
def ref_dynamic_per_token_quant(x: torch.tensor, def ref_dynamic_per_token_quant(x: torch.tensor,
quant_dtype: torch.dtype, quant_dtype: torch.dtype,
scale_ub: Optional[torch.tensor] = None) \ scale_ub: Optional[torch.tensor] = None) \
-> Tuple[torch.tensor, torch.tensor]: -> tuple[torch.tensor, torch.tensor]:
assert quant_dtype in [torch.int8, FP8_DTYPE] assert quant_dtype in [torch.int8, FP8_DTYPE]
if scale_ub is not None: if scale_ub is not None:
...@@ -68,7 +68,7 @@ def ref_dynamic_per_token_quant(x: torch.tensor, ...@@ -68,7 +68,7 @@ def ref_dynamic_per_token_quant(x: torch.tensor,
# ref_dynamic_per_token_quant, when we have a dynamic_per_tensor int8 quant # ref_dynamic_per_token_quant, when we have a dynamic_per_tensor int8 quant
# kernel # kernel
def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
-> Tuple[torch.tensor, torch.tensor]: -> tuple[torch.tensor, torch.tensor]:
fp8_traits = torch.finfo(FP8_DTYPE) fp8_traits = torch.finfo(FP8_DTYPE)
fp8_traits_max = ROCM_FP8_MAX if current_platform.is_rocm() \ fp8_traits_max = ROCM_FP8_MAX if current_platform.is_rocm() \
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import random import random
from typing import Type
import pytest import pytest
import torch import torch
...@@ -86,7 +85,7 @@ def test_act_and_mul( ...@@ -86,7 +85,7 @@ def test_act_and_mul(
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode() @torch.inference_mode()
def test_activation( def test_activation(
activation: Type[torch.nn.Module], activation: type[torch.nn.Module],
num_tokens: int, num_tokens: int,
d: int, d: int,
dtype: torch.dtype, dtype: torch.dtype,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import random import random
from typing import List, Optional, Tuple from typing import Optional
import pytest import pytest
import torch import torch
...@@ -85,8 +85,8 @@ def ref_single_query_cached_kv_attention( ...@@ -85,8 +85,8 @@ def ref_single_query_cached_kv_attention(
block_table = block_tables_lst[i] block_table = block_tables_lst[i]
seq_len = int(seq_lens_lst[i]) seq_len = int(seq_lens_lst[i])
keys_lst: List[torch.Tensor] = [] keys_lst: list[torch.Tensor] = []
values_lst: List[torch.Tensor] = [] values_lst: list[torch.Tensor] = []
for j in range(seq_len): for j in range(seq_len):
block_number = int(block_table[j // block_size]) block_number = int(block_table[j // block_size])
block_offset = j % block_size block_offset = j % block_size
...@@ -133,7 +133,7 @@ def test_paged_attention( ...@@ -133,7 +133,7 @@ def test_paged_attention(
kv_cache_factory, kv_cache_factory,
version: str, version: str,
num_seqs: int, num_seqs: int,
num_heads: Tuple[int, int], num_heads: tuple[int, int],
head_size: int, head_size: int,
use_alibi: bool, use_alibi: bool,
block_size: int, block_size: int,
...@@ -166,7 +166,7 @@ def test_paged_attention( ...@@ -166,7 +166,7 @@ def test_paged_attention(
# Create the block tables. # Create the block tables.
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
block_tables_lst: List[List[int]] = [] block_tables_lst: list[list[int]] = []
for _ in range(num_seqs): for _ in range(num_seqs):
block_table = [ block_table = [
random.randint(0, NUM_BLOCKS - 1) random.randint(0, NUM_BLOCKS - 1)
...@@ -334,7 +334,7 @@ def test_paged_attention( ...@@ -334,7 +334,7 @@ def test_paged_attention(
def ref_multi_query_kv_attention( def ref_multi_query_kv_attention(
cu_seq_lens: List[int], cu_seq_lens: list[int],
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
...@@ -342,7 +342,7 @@ def ref_multi_query_kv_attention( ...@@ -342,7 +342,7 @@ def ref_multi_query_kv_attention(
dtype: torch.dtype, dtype: torch.dtype,
) -> torch.Tensor: ) -> torch.Tensor:
num_seqs = len(cu_seq_lens) - 1 num_seqs = len(cu_seq_lens) - 1
ref_outputs: List[torch.Tensor] = [] ref_outputs: list[torch.Tensor] = []
for i in range(num_seqs): for i in range(num_seqs):
start_idx = cu_seq_lens[i] start_idx = cu_seq_lens[i]
end_idx = cu_seq_lens[i + 1] end_idx = cu_seq_lens[i + 1]
...@@ -378,7 +378,7 @@ def ref_multi_query_kv_attention( ...@@ -378,7 +378,7 @@ def ref_multi_query_kv_attention(
@torch.inference_mode() @torch.inference_mode()
def test_multi_query_kv_attention( def test_multi_query_kv_attention(
num_seqs: int, num_seqs: int,
num_heads: Tuple[int, int], num_heads: tuple[int, int],
head_size: int, head_size: int,
dtype: torch.dtype, dtype: torch.dtype,
seed: int, seed: int,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import random import random
from typing import List, Optional, Tuple from typing import Optional
import pytest import pytest
import torch import torch
...@@ -87,8 +87,8 @@ def ref_single_query_cached_kv_attention( ...@@ -87,8 +87,8 @@ def ref_single_query_cached_kv_attention(
block_table = block_tables_lst[i] block_table = block_tables_lst[i]
seq_len = int(seq_lens_lst[i]) seq_len = int(seq_lens_lst[i])
keys_lst: List[torch.Tensor] = [] keys_lst: list[torch.Tensor] = []
values_lst: List[torch.Tensor] = [] values_lst: list[torch.Tensor] = []
for j in range(seq_len): for j in range(seq_len):
block_number = int(block_table[j // block_size]) block_number = int(block_table[j // block_size])
block_offset = j % block_size block_offset = j % block_size
...@@ -162,7 +162,7 @@ def test_paged_attention( ...@@ -162,7 +162,7 @@ def test_paged_attention(
kv_cache_factory, kv_cache_factory,
version: str, version: str,
num_seqs: int, num_seqs: int,
num_heads: Tuple[int, int], num_heads: tuple[int, int],
head_size: int, head_size: int,
use_alibi: bool, use_alibi: bool,
block_size: int, block_size: int,
...@@ -331,7 +331,7 @@ def test_paged_attention( ...@@ -331,7 +331,7 @@ def test_paged_attention(
def ref_multi_query_kv_attention( def ref_multi_query_kv_attention(
cu_seq_lens: List[int], cu_seq_lens: list[int],
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
...@@ -376,7 +376,7 @@ def ref_multi_query_kv_attention( ...@@ -376,7 +376,7 @@ def ref_multi_query_kv_attention(
@torch.inference_mode() @torch.inference_mode()
def test_varlen_blocksparse_attention_prefill( def test_varlen_blocksparse_attention_prefill(
num_seqs: int, num_seqs: int,
num_heads: Tuple[int, int], num_heads: tuple[int, int],
head_size: int, head_size: int,
blocksparse_local_blocks: int, blocksparse_local_blocks: int,
blocksparse_vert_stride: int, blocksparse_vert_stride: int,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import random import random
from typing import List, Tuple
import pytest import pytest
import torch import torch
...@@ -74,7 +73,7 @@ def test_copy_blocks( ...@@ -74,7 +73,7 @@ def test_copy_blocks(
src_blocks = random.sample(range(num_blocks), num_mappings) src_blocks = random.sample(range(num_blocks), num_mappings)
remainig_blocks = list(set(range(num_blocks)) - set(src_blocks)) remainig_blocks = list(set(range(num_blocks)) - set(src_blocks))
dst_blocks = random.sample(remainig_blocks, 2 * num_mappings) dst_blocks = random.sample(remainig_blocks, 2 * num_mappings)
block_mapping: List[Tuple[int, int]] = [] block_mapping: list[tuple[int, int]] = []
for i in range(num_mappings): for i in range(num_mappings):
src = src_blocks[i] src = src_blocks[i]
dst1 = dst_blocks[2 * i] dst1 = dst_blocks[2 * i]
...@@ -342,7 +341,7 @@ def test_reshape_and_cache_flash( ...@@ -342,7 +341,7 @@ def test_reshape_and_cache_flash(
@torch.inference_mode() @torch.inference_mode()
def test_swap_blocks( def test_swap_blocks(
kv_cache_factory, kv_cache_factory,
direction: Tuple[str, str], direction: tuple[str, str],
num_mappings: int, num_mappings: int,
num_heads: int, num_heads: int,
head_size: int, head_size: int,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import List, Optional, Tuple from typing import Optional
import pytest import pytest
import torch import torch
...@@ -25,7 +25,7 @@ DTYPES = [torch.float16, torch.bfloat16] ...@@ -25,7 +25,7 @@ DTYPES = [torch.float16, torch.bfloat16]
@torch.inference_mode() @torch.inference_mode()
def test_merge_kernel( def test_merge_kernel(
num_tokens: int, num_tokens: int,
num_heads: Tuple[int, int], num_heads: tuple[int, int],
head_size: int, head_size: int,
dtype: torch.dtype, dtype: torch.dtype,
): ):
...@@ -85,8 +85,8 @@ CASES = [ ...@@ -85,8 +85,8 @@ CASES = [
@pytest.mark.parametrize("fa_version", [2, 3]) @pytest.mark.parametrize("fa_version", [2, 3])
@torch.inference_mode() @torch.inference_mode()
def test_cascade( def test_cascade(
seq_lens_and_common_prefix: Tuple[List[Tuple[int, int]], int], seq_lens_and_common_prefix: tuple[list[tuple[int, int]], int],
num_heads: Tuple[int, int], num_heads: tuple[int, int],
head_size: int, head_size: int,
dtype: torch.dtype, dtype: torch.dtype,
block_size: int, block_size: int,
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
Run `pytest tests/kernels/test_cutlass.py`. Run `pytest tests/kernels/test_cutlass.py`.
""" """
from typing import Type
import pytest import pytest
import torch import torch
...@@ -71,7 +70,7 @@ def cutlass_fp8_gemm_helper(m: int, ...@@ -71,7 +70,7 @@ def cutlass_fp8_gemm_helper(m: int,
a_scale_group_shape: tuple, a_scale_group_shape: tuple,
b_scale_group_shape: tuple, b_scale_group_shape: tuple,
use_bias: bool, use_bias: bool,
out_dtype: Type[torch.dtype] = torch.bfloat16, out_dtype: type[torch.dtype] = torch.bfloat16,
device: str = "cuda"): device: str = "cuda"):
# Test for a cutlass kernel with per-token activation quantization # Test for a cutlass kernel with per-token activation quantization
# and per-output channel weight quantization. # and per-output channel weight quantization.
...@@ -109,7 +108,7 @@ def cutlass_int8_gemm_helper(m: int, ...@@ -109,7 +108,7 @@ def cutlass_int8_gemm_helper(m: int,
a_scale_group_shape: tuple, a_scale_group_shape: tuple,
b_scale_group_shape: tuple, b_scale_group_shape: tuple,
use_bias: bool, use_bias: bool,
out_dtype: Type[torch.dtype] = torch.bfloat16, out_dtype: type[torch.dtype] = torch.bfloat16,
device: str = "cuda"): device: str = "cuda"):
# Test for a cutlass kernel with per-token activation quantization # Test for a cutlass kernel with per-token activation quantization
# and per-output channel weight quantization. # and per-output channel weight quantization.
...@@ -187,7 +186,7 @@ def test_cutlass_int8_gemm(m: int, n: int, k: int, a_scale_group_shape, ...@@ -187,7 +186,7 @@ def test_cutlass_int8_gemm(m: int, n: int, k: int, a_scale_group_shape,
@pytest.mark.parametrize("use_bias", [True, False]) @pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_int8_gemm_output_dtype(a_scale_group_shape, def test_cutlass_int8_gemm_output_dtype(a_scale_group_shape,
b_scale_group_shape, b_scale_group_shape,
out_dtype: Type[torch.dtype], out_dtype: type[torch.dtype],
use_bias: bool): use_bias: bool):
cutlass_int8_gemm_helper(512, cutlass_int8_gemm_helper(512,
512, 512,
...@@ -208,7 +207,7 @@ def test_cutlass_int8_gemm_output_dtype(a_scale_group_shape, ...@@ -208,7 +207,7 @@ def test_cutlass_int8_gemm_output_dtype(a_scale_group_shape,
reason="FP8 is not supported on this GPU type.") reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm_output_dtype(a_scale_group_shape, def test_cutlass_fp8_gemm_output_dtype(a_scale_group_shape,
b_scale_group_shape, b_scale_group_shape,
out_dtype: Type[torch.dtype], out_dtype: type[torch.dtype],
use_bias: bool): use_bias: bool):
cutlass_fp8_gemm_helper(512, cutlass_fp8_gemm_helper(512,
512, 512,
...@@ -227,7 +226,7 @@ def test_cutlass_fp8_gemm_output_dtype(a_scale_group_shape, ...@@ -227,7 +226,7 @@ def test_cutlass_fp8_gemm_output_dtype(a_scale_group_shape,
reason="FP8 blockwise is not supported on this GPU type.") reason="FP8 blockwise is not supported on this GPU type.")
def test_cutlass_fp8_blockwise_scale_gemm_dtype(a_scale_group_shape, def test_cutlass_fp8_blockwise_scale_gemm_dtype(a_scale_group_shape,
b_scale_group_shape, b_scale_group_shape,
out_dtype: Type[torch.dtype], out_dtype: type[torch.dtype],
use_bias: bool): use_bias: bool):
cutlass_fp8_gemm_helper(512, cutlass_fp8_gemm_helper(512,
512, 512,
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
Run `pytest tests/kernels/test_semi_structured.py`. Run `pytest tests/kernels/test_semi_structured.py`.
""" """
from typing import Tuple, Type
import pytest import pytest
import torch import torch
...@@ -79,7 +78,7 @@ def check_compress_decompress_invariance(dtype: torch.dtype, b: torch.Tensor, ...@@ -79,7 +78,7 @@ def check_compress_decompress_invariance(dtype: torch.dtype, b: torch.Tensor,
def make_rand_sparse_tensors( def make_rand_sparse_tensors(
dtype: torch.dtype, m: int, n: int, k: int dtype: torch.dtype, m: int, n: int, k: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
a = torch.randn((m, k), device='cuda') a = torch.randn((m, k), device='cuda')
b = torch.randn((n, k), device='cuda').t() b = torch.randn((n, k), device='cuda').t()
...@@ -167,7 +166,7 @@ MNK_FACTORS = [ ...@@ -167,7 +166,7 @@ MNK_FACTORS = [
@pytest.mark.parametrize("m, n, k", MNK_FACTORS) @pytest.mark.parametrize("m, n, k", MNK_FACTORS)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("use_bias", [True, False]) @pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_sparse_gemm(m: int, k: int, n: int, dtype: Type[torch.dtype], def test_cutlass_sparse_gemm(m: int, k: int, n: int, dtype: type[torch.dtype],
use_bias: bool): use_bias: bool):
# Create tensors # Create tensors
......
...@@ -243,7 +243,7 @@ def _decoder_attn_setup( ...@@ -243,7 +243,7 @@ def _decoder_attn_setup(
test_pt: TestPoint, test_pt: TestPoint,
test_rsrcs: TestResources, test_rsrcs: TestResources,
block_base_addr: int = 0, block_base_addr: int = 0,
) -> Tuple[QKVInputs, PhaseTestParameters, PhaseTestParameters, int]: ) -> tuple[QKVInputs, PhaseTestParameters, PhaseTestParameters, int]:
''' '''
Set up test vectors & data structures for self-attention test. Set up test vectors & data structures for self-attention test.
...@@ -421,7 +421,7 @@ def _enc_dec_cross_attn_setup_reuses_query( ...@@ -421,7 +421,7 @@ def _enc_dec_cross_attn_setup_reuses_query(
test_pt: TestPoint, test_pt: TestPoint,
test_rsrcs: TestResources, test_rsrcs: TestResources,
block_base_addr: int = 0, block_base_addr: int = 0,
) -> Tuple[PhaseTestParameters, PhaseTestParameters]: ) -> tuple[PhaseTestParameters, PhaseTestParameters]:
''' '''
Set up test vectors & data structures for cross-attention test. Set up test vectors & data structures for cross-attention test.
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import List, Optional, Tuple from typing import Optional
import pytest import pytest
import torch import torch
...@@ -24,8 +24,8 @@ def ref_paged_attn( ...@@ -24,8 +24,8 @@ def ref_paged_attn(
query: torch.Tensor, query: torch.Tensor,
key_cache: torch.Tensor, key_cache: torch.Tensor,
value_cache: torch.Tensor, value_cache: torch.Tensor,
query_lens: List[int], query_lens: list[int],
kv_lens: List[int], kv_lens: list[int],
block_tables: torch.Tensor, block_tables: torch.Tensor,
scale: float, scale: float,
sliding_window: Optional[int] = None, sliding_window: Optional[int] = None,
...@@ -35,7 +35,7 @@ def ref_paged_attn( ...@@ -35,7 +35,7 @@ def ref_paged_attn(
block_tables = block_tables.cpu().numpy() block_tables = block_tables.cpu().numpy()
_, block_size, num_kv_heads, head_size = key_cache.shape _, block_size, num_kv_heads, head_size = key_cache.shape
outputs: List[torch.Tensor] = [] outputs: list[torch.Tensor] = []
start_idx = 0 start_idx = 0
for i in range(num_seqs): for i in range(num_seqs):
query_len = query_lens[i] query_len = query_lens[i]
...@@ -88,8 +88,8 @@ def ref_paged_attn( ...@@ -88,8 +88,8 @@ def ref_paged_attn(
@torch.inference_mode() @torch.inference_mode()
def test_flash_attn_with_paged_kv( def test_flash_attn_with_paged_kv(
use_out: bool, use_out: bool,
kv_lens: List[int], kv_lens: list[int],
num_heads: Tuple[int, int], num_heads: tuple[int, int],
head_size: int, head_size: int,
dtype: torch.dtype, dtype: torch.dtype,
block_size: int, block_size: int,
...@@ -174,8 +174,8 @@ def test_flash_attn_with_paged_kv( ...@@ -174,8 +174,8 @@ def test_flash_attn_with_paged_kv(
@torch.inference_mode() @torch.inference_mode()
def test_varlen_with_paged_kv( def test_varlen_with_paged_kv(
use_out: bool, use_out: bool,
seq_lens: List[Tuple[int, int]], seq_lens: list[tuple[int, int]],
num_heads: Tuple[int, int], num_heads: tuple[int, int],
head_size: int, head_size: int,
sliding_window: Optional[int], sliding_window: Optional[int],
dtype: torch.dtype, dtype: torch.dtype,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import List, Optional, Tuple from typing import Optional
import flashinfer import flashinfer
import pytest import pytest
...@@ -19,8 +19,8 @@ def ref_paged_attn( ...@@ -19,8 +19,8 @@ def ref_paged_attn(
query: torch.Tensor, query: torch.Tensor,
key_cache: torch.Tensor, key_cache: torch.Tensor,
value_cache: torch.Tensor, value_cache: torch.Tensor,
query_lens: List[int], query_lens: list[int],
kv_lens: List[int], kv_lens: list[int],
block_tables: torch.Tensor, block_tables: torch.Tensor,
scale: float, scale: float,
sliding_window: Optional[int] = None, sliding_window: Optional[int] = None,
...@@ -30,7 +30,7 @@ def ref_paged_attn( ...@@ -30,7 +30,7 @@ def ref_paged_attn(
block_tables = block_tables.cpu().numpy() block_tables = block_tables.cpu().numpy()
_, block_size, num_kv_heads, head_size = key_cache.shape _, block_size, num_kv_heads, head_size = key_cache.shape
outputs: List[torch.Tensor] = [] outputs: list[torch.Tensor] = []
start_idx = 0 start_idx = 0
for i in range(num_seqs): for i in range(num_seqs):
query_len = query_lens[i] query_len = query_lens[i]
...@@ -78,8 +78,8 @@ def ref_paged_attn( ...@@ -78,8 +78,8 @@ def ref_paged_attn(
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0]) @pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
@torch.inference_mode @torch.inference_mode
def test_flashinfer_decode_with_paged_kv( def test_flashinfer_decode_with_paged_kv(
kv_lens: List[int], kv_lens: list[int],
num_heads: Tuple[int, int], num_heads: tuple[int, int],
head_size: int, head_size: int,
dtype: torch.dtype, dtype: torch.dtype,
block_size: int, block_size: int,
...@@ -168,8 +168,8 @@ def test_flashinfer_decode_with_paged_kv( ...@@ -168,8 +168,8 @@ def test_flashinfer_decode_with_paged_kv(
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0]) @pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
@torch.inference_mode @torch.inference_mode
def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]], def test_flashinfer_prefill_with_paged_kv(seq_lens: list[tuple[int, int]],
num_heads: Tuple[int, int], num_heads: tuple[int, int],
head_size: int, dtype: torch.dtype, head_size: int, dtype: torch.dtype,
block_size: int, block_size: int,
soft_cap: Optional[float]) -> None: soft_cap: Optional[float]) -> None:
...@@ -270,7 +270,7 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]], ...@@ -270,7 +270,7 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0]) @pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
def test_flashinfer_prefill_with_paged_fp8_kv( def test_flashinfer_prefill_with_paged_fp8_kv(
seq_lens: List[Tuple[int, int]], num_heads: Tuple[int, int], seq_lens: list[tuple[int, int]], num_heads: tuple[int, int],
head_size: int, dtype: torch.dtype, block_size: int, head_size: int, dtype: torch.dtype, block_size: int,
soft_cap: Optional[float]) -> None: soft_cap: Optional[float]) -> None:
pytest.skip("TODO: fix the accuracy issue") pytest.skip("TODO: fix the accuracy issue")
...@@ -378,8 +378,8 @@ def test_flashinfer_prefill_with_paged_fp8_kv( ...@@ -378,8 +378,8 @@ def test_flashinfer_prefill_with_paged_fp8_kv(
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0]) @pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
@torch.inference_mode @torch.inference_mode
def test_flashinfer_decode_with_paged_fp8_kv( def test_flashinfer_decode_with_paged_fp8_kv(
kv_lens: List[int], kv_lens: list[int],
num_heads: Tuple[int, int], num_heads: tuple[int, int],
head_size: int, head_size: int,
dtype: torch.dtype, dtype: torch.dtype,
block_size: int, block_size: int,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple, Union from typing import Optional, Union
import pytest import pytest
import torch import torch
...@@ -39,7 +39,7 @@ def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor: ...@@ -39,7 +39,7 @@ def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor:
def ref_rms_norm(rms_norm_layer: RMSNorm, def ref_rms_norm(rms_norm_layer: RMSNorm,
x: torch.Tensor, x: torch.Tensor,
residual: Optional[torch.Tensor]) \ residual: Optional[torch.Tensor]) \
-> Tuple[torch.Tensor, Optional[torch.Tensor]]: -> tuple[torch.Tensor, Optional[torch.Tensor]]:
if residual is not None: if residual is not None:
residual = residual.clone() residual = residual.clone()
out, residual = rms_norm_layer.forward_native(x, residual) out, residual = rms_norm_layer.forward_native(x, residual)
...@@ -54,7 +54,7 @@ def ref_dynamic_per_token_quant(rms_norm_layer: RMSNorm, ...@@ -54,7 +54,7 @@ def ref_dynamic_per_token_quant(rms_norm_layer: RMSNorm,
quant_dtype: torch.dtype, quant_dtype: torch.dtype,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
scale_ub: Optional[torch.Tensor]) \ scale_ub: Optional[torch.Tensor]) \
-> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
if scale_ub is not None: if scale_ub is not None:
assert quant_dtype == torch.float8_e4m3fn assert quant_dtype == torch.float8_e4m3fn
...@@ -78,7 +78,7 @@ def ref_impl(rms_norm_layer: RMSNorm, ...@@ -78,7 +78,7 @@ def ref_impl(rms_norm_layer: RMSNorm,
quant_dtype: torch.dtype, quant_dtype: torch.dtype,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
scale_ub: Optional[torch.Tensor]) \ scale_ub: Optional[torch.Tensor]) \
-> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
return ref_dynamic_per_token_quant(rms_norm_layer, x, quant_dtype, return ref_dynamic_per_token_quant(rms_norm_layer, x, quant_dtype,
residual, scale_ub) residual, scale_ub)
...@@ -88,7 +88,7 @@ def ops_dynamic_per_token_quant(weight: torch.Tensor, ...@@ -88,7 +88,7 @@ def ops_dynamic_per_token_quant(weight: torch.Tensor,
quant_dtype: torch.dtype, quant_dtype: torch.dtype,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
scale_ub: Optional[torch.Tensor]) \ scale_ub: Optional[torch.Tensor]) \
-> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
if residual is not None: if residual is not None:
residual = residual.clone() residual = residual.clone()
out, scales = ops.rms_norm_dynamic_per_token_quant(x, weight, EPS, out, scales = ops.rms_norm_dynamic_per_token_quant(x, weight, EPS,
...@@ -102,7 +102,7 @@ def ops_impl(weight: torch.Tensor, ...@@ -102,7 +102,7 @@ def ops_impl(weight: torch.Tensor,
quant_dtype: torch.dtype, quant_dtype: torch.dtype,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
scale_ub: Optional[torch.Tensor]) \ scale_ub: Optional[torch.Tensor]) \
-> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
return ops_dynamic_per_token_quant(weight, x, quant_dtype, residual, return ops_dynamic_per_token_quant(weight, x, quant_dtype, residual,
scale_ub) scale_ub)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from pathlib import Path from pathlib import Path
from typing import List
import pytest import pytest
import torch import torch
...@@ -16,7 +15,7 @@ GGUF_SAMPLE = snapshot_download("Isotr0py/test-gguf-sample") ...@@ -16,7 +15,7 @@ GGUF_SAMPLE = snapshot_download("Isotr0py/test-gguf-sample")
def get_gguf_sample_tensors( def get_gguf_sample_tensors(
hidden_size: int, hidden_size: int,
quant_type: GGMLQuantizationType) -> List[ReaderTensor]: quant_type: GGMLQuantizationType) -> list[ReaderTensor]:
sample_dir = GGUF_SAMPLE sample_dir = GGUF_SAMPLE
filename = f"Quant_{quant_type.name}_{hidden_size}.gguf" filename = f"Quant_{quant_type.name}_{hidden_size}.gguf"
sample_file = Path(sample_dir) / filename sample_file = Path(sample_dir) / filename
......
...@@ -6,7 +6,7 @@ Run `pytest tests/kernels/test_machete_mm.py`. ...@@ -6,7 +6,7 @@ Run `pytest tests/kernels/test_machete_mm.py`.
import math import math
from dataclasses import dataclass, fields from dataclasses import dataclass, fields
from typing import List, Optional, Tuple from typing import Optional
import pytest import pytest
import torch import torch
...@@ -45,7 +45,7 @@ MNK_SHAPES = [ ...@@ -45,7 +45,7 @@ MNK_SHAPES = [
(1024, 8192, 4096), (1024, 8192, 4096),
] ]
GROUP_SIZES_TO_TEST: List[Optional[int]] = [128, -1] GROUP_SIZES_TO_TEST: list[Optional[int]] = [128, -1]
@dataclass @dataclass
...@@ -75,7 +75,7 @@ class Tensors: ...@@ -75,7 +75,7 @@ class Tensors:
# Ch Scales Type, Tok Scales Type) # Ch Scales Type, Tok Scales Type)
# NOTE: None "Scale Type" means the act type is floating point # NOTE: None "Scale Type" means the act type is floating point
# None "Output Type" means the output type is the same as the act type # None "Output Type" means the output type is the same as the act type
TestTypeTuple = Tuple[List[torch.dtype], ScalarType, Optional[torch.dtype], TestTypeTuple = tuple[list[torch.dtype], ScalarType, Optional[torch.dtype],
Optional[torch.dtype], bool] Optional[torch.dtype], bool]
TEST_TYPES = [ TEST_TYPES = [
# GPTQ style # GPTQ style
...@@ -136,7 +136,7 @@ def maybe_convert_zeropoints(zps: Optional[torch.Tensor], s: torch.Tensor): ...@@ -136,7 +136,7 @@ def maybe_convert_zeropoints(zps: Optional[torch.Tensor], s: torch.Tensor):
return zps if zps is None else -1 * s * (zps.to(s.dtype)) return zps if zps is None else -1 * s * (zps.to(s.dtype))
def group_size_valid(shape: Tuple[int, int, int], def group_size_valid(shape: tuple[int, int, int],
group_size: Optional[int]) -> bool: group_size: Optional[int]) -> bool:
return group_size is None or group_size == -1 or group_size % shape[2] == 0 return group_size is None or group_size == -1 or group_size % shape[2] == 0
...@@ -166,7 +166,7 @@ def machete_quantize_and_pack(atype: torch.dtype, ...@@ -166,7 +166,7 @@ def machete_quantize_and_pack(atype: torch.dtype,
return w_ref, w_q_machete, w_s, w_zp return w_ref, w_q_machete, w_s, w_zp
def create_test_tensors(shape: Tuple[int, int, int], def create_test_tensors(shape: tuple[int, int, int],
types: TypeConfig, types: TypeConfig,
group_size: Optional[int], group_size: Optional[int],
subset_stride_factor: Optional[int] = None) -> Tensors: subset_stride_factor: Optional[int] = None) -> Tensors:
...@@ -265,7 +265,7 @@ def machete_mm_test_helper(types: TypeConfig, ...@@ -265,7 +265,7 @@ def machete_mm_test_helper(types: TypeConfig,
@pytest.mark.parametrize("types", TEST_TYPES) @pytest.mark.parametrize("types", TEST_TYPES)
def test_machete_all_schedules(shape, types: TypeConfig): def test_machete_all_schedules(shape, types: TypeConfig):
group_sizes: List[Optional[int]] = [] group_sizes: list[Optional[int]] = []
if types.group_scale_type is None: if types.group_scale_type is None:
group_sizes = [None] group_sizes = [None]
else: else:
...@@ -294,7 +294,7 @@ def test_machete_all_schedules(shape, types: TypeConfig): ...@@ -294,7 +294,7 @@ def test_machete_all_schedules(shape, types: TypeConfig):
ids=lambda x: "x".join(str(v) for v in x)) ids=lambda x: "x".join(str(v) for v in x))
@pytest.mark.parametrize("types", TEST_TYPES) @pytest.mark.parametrize("types", TEST_TYPES)
def test_machete_heuristic(shape, types: TypeConfig): def test_machete_heuristic(shape, types: TypeConfig):
group_sizes: List[Optional[int]] = [] group_sizes: list[Optional[int]] = []
if types.group_scale_type is None: if types.group_scale_type is None:
group_sizes = [None] group_sizes = [None]
else: else:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import unittest import unittest
from typing import Tuple
import pytest import pytest
import torch import torch
...@@ -29,7 +28,7 @@ from vllm.utils import update_environment_variables ...@@ -29,7 +28,7 @@ from vllm.utils import update_environment_variables
def test_mixer2_gated_norm_multi_gpu( def test_mixer2_gated_norm_multi_gpu(
batch_size: int, batch_size: int,
seq_len: int, seq_len: int,
hidden_size_n_groups: Tuple[int, int], hidden_size_n_groups: tuple[int, int],
dtype: torch.dtype, dtype: torch.dtype,
device: str = 'cuda', device: str = 'cuda',
): ):
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Dict, Tuple
import pytest import pytest
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -134,7 +132,7 @@ def generate_continous_batched_examples(example_lens_by_batch, ...@@ -134,7 +132,7 @@ def generate_continous_batched_examples(example_lens_by_batch,
# given a tuple of lengths for each example in the batch # given a tuple of lengths for each example in the batch
# e.g., example_lens=(8, 4) means take 8 samples from first eg, # e.g., example_lens=(8, 4) means take 8 samples from first eg,
# 4 examples from second eg, etc # 4 examples from second eg, etc
def get_continuous_batch(example_lens: Tuple[int, ...]): def get_continuous_batch(example_lens: tuple[int, ...]):
indices = [] indices = []
for i, x in enumerate(example_lens): for i, x in enumerate(example_lens):
...@@ -264,8 +262,8 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, ...@@ -264,8 +262,8 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
# hold state during the cutting process so we know if an # hold state during the cutting process so we know if an
# example has been exhausted and needs to cycle # example has been exhausted and needs to cycle
last_taken: Dict = {} # map: eg -> pointer to last taken sample last_taken: dict = {} # map: eg -> pointer to last taken sample
exhausted: Dict = {} # map: eg -> boolean indicating example is exhausted exhausted: dict = {} # map: eg -> boolean indicating example is exhausted
states = None states = None
for Y_min, cu_seqlens, sed_idx, (A, dt, X, B, for Y_min, cu_seqlens, sed_idx, (A, dt, X, B,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from itertools import accumulate, product from itertools import accumulate, product
from typing import Callable, Dict, List, Optional from typing import Callable, Optional
import pytest import pytest
import torch import torch
...@@ -179,7 +179,7 @@ def test_batched_rotary_embedding_multi_lora( ...@@ -179,7 +179,7 @@ def test_batched_rotary_embedding_multi_lora(
torch.set_default_device(device) torch.set_default_device(device)
if rotary_dim is None: if rotary_dim is None:
rotary_dim = head_size rotary_dim = head_size
scaling_factors: List[int] = [1, 2, 4] scaling_factors: list[int] = [1, 2, 4]
rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, { rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, {
"rope_type": "linear", "rope_type": "linear",
"factor": tuple(scaling_factors) "factor": tuple(scaling_factors)
...@@ -234,7 +234,7 @@ def test_rope_module_cache(): ...@@ -234,7 +234,7 @@ def test_rope_module_cache():
}) })
settings = (HEAD_SIZES, ROTARY_DIMS, MAX_POSITIONS, BASES, IS_NEOX_STYLE, settings = (HEAD_SIZES, ROTARY_DIMS, MAX_POSITIONS, BASES, IS_NEOX_STYLE,
ROPE_SCALINGS, DTYPES) ROPE_SCALINGS, DTYPES)
rope_setting_id_map: Dict[str, int] = {} rope_setting_id_map: dict[str, int] = {}
for setting in product(*settings): for setting in product(*settings):
head_size, rotary_dim, max_position, base, \ head_size, rotary_dim, max_position, base, \
is_neox_stype, rope_scaling, dtype = setting is_neox_stype, rope_scaling, dtype = setting
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
Run `pytest tests/kernels/test_triton_scaled_mm.py`. Run `pytest tests/kernels/test_triton_scaled_mm.py`.
""" """
import importlib import importlib
from typing import Optional, Type from typing import Optional
import pytest import pytest
import torch import torch
...@@ -18,7 +18,7 @@ def scaled_mm_torch(a: torch.Tensor, ...@@ -18,7 +18,7 @@ def scaled_mm_torch(a: torch.Tensor,
b: torch.Tensor, b: torch.Tensor,
scale_a: torch.Tensor, scale_a: torch.Tensor,
scale_b: torch.Tensor, scale_b: torch.Tensor,
out_dtype: Type[torch.dtype], out_dtype: type[torch.dtype],
bias: Optional[torch.Tensor] = None) -> torch.Tensor: bias: Optional[torch.Tensor] = None) -> torch.Tensor:
out = torch.mm(a.to(torch.float32), b.to(torch.float32)) out = torch.mm(a.to(torch.float32), b.to(torch.float32))
out = scale_a * out out = scale_a * out
......
...@@ -4,9 +4,9 @@ ...@@ -4,9 +4,9 @@
import itertools import itertools
import random import random
import unittest import unittest
from collections.abc import Sequence
from numbers import Number from numbers import Number
from typing import (Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, from typing import Any, NamedTuple, Optional, Union
Type, Union)
import pytest import pytest
import torch import torch
...@@ -20,13 +20,13 @@ from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL, ...@@ -20,13 +20,13 @@ from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL,
# For now, disable "test_aot_dispatch_dynamic" since there are some # For now, disable "test_aot_dispatch_dynamic" since there are some
# bugs related to this test in PyTorch 2.4. # bugs related to this test in PyTorch 2.4.
DEFAULT_OPCHECK_TEST_UTILS: Tuple[str, ...] = ( DEFAULT_OPCHECK_TEST_UTILS: tuple[str, ...] = (
"test_schema", "test_schema",
"test_autograd_registration", "test_autograd_registration",
"test_faketensor", "test_faketensor",
) )
ALL_OPCHECK_TEST_UTILS: Tuple[str, ...] = ( ALL_OPCHECK_TEST_UTILS: tuple[str, ...] = (
"test_schema", "test_schema",
"test_autograd_registration", "test_autograd_registration",
"test_faketensor", "test_faketensor",
...@@ -50,8 +50,8 @@ class QKVInputs(NamedTuple): ...@@ -50,8 +50,8 @@ class QKVInputs(NamedTuple):
query: torch.Tensor query: torch.Tensor
key: torch.Tensor key: torch.Tensor
value: torch.Tensor value: torch.Tensor
q_seq_lens: List[int] q_seq_lens: list[int]
kv_seq_lens: List[int] kv_seq_lens: list[int]
class QKVO(NamedTuple): class QKVO(NamedTuple):
...@@ -89,10 +89,10 @@ class PackedQKVInputs(NamedTuple): ...@@ -89,10 +89,10 @@ class PackedQKVInputs(NamedTuple):
query: torch.Tensor query: torch.Tensor
key: torch.Tensor key: torch.Tensor
value: torch.Tensor value: torch.Tensor
q_start_loc_list: Optional[List[int]] q_start_loc_list: Optional[list[int]]
kv_start_loc_list: Optional[List[int]] kv_start_loc_list: Optional[list[int]]
q_seq_lens: Optional[List[int]] q_seq_lens: Optional[list[int]]
kv_seq_lens: Optional[List[int]] kv_seq_lens: Optional[list[int]]
class PackedQKVO(NamedTuple): class PackedQKVO(NamedTuple):
...@@ -146,7 +146,7 @@ class PhaseTestParameters(NamedTuple): ...@@ -146,7 +146,7 @@ class PhaseTestParameters(NamedTuple):
def maybe_make_int_tensor( def maybe_make_int_tensor(
_list: Optional[List[int]], _list: Optional[list[int]],
device: Union[torch.device, str], device: Union[torch.device, str],
) -> torch.Tensor: ) -> torch.Tensor:
''' '''
...@@ -162,7 +162,7 @@ def maybe_make_int_tensor( ...@@ -162,7 +162,7 @@ def maybe_make_int_tensor(
def maybe_make_long_tensor( def maybe_make_long_tensor(
_list: Optional[List[int]], _list: Optional[list[int]],
device: Union[torch.device, str], device: Union[torch.device, str],
) -> torch.Tensor: ) -> torch.Tensor:
''' '''
...@@ -177,7 +177,7 @@ def maybe_make_long_tensor( ...@@ -177,7 +177,7 @@ def maybe_make_long_tensor(
_list, dtype=torch.long, device=device) _list, dtype=torch.long, device=device)
def maybe_max(_list: Optional[List]) -> Optional[Number]: def maybe_max(_list: Optional[list]) -> Optional[Number]:
''' '''
Returns: Returns:
...@@ -232,8 +232,8 @@ def ref_masked_attention(query: torch.Tensor, ...@@ -232,8 +232,8 @@ def ref_masked_attention(query: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
scale: float, scale: float,
custom_mask: Optional[torch.Tensor] = None, custom_mask: Optional[torch.Tensor] = None,
q_seq_lens: Optional[List] = None, q_seq_lens: Optional[list] = None,
kv_seq_lens: Optional[List] = None) -> torch.Tensor: kv_seq_lens: Optional[list] = None) -> torch.Tensor:
''' '''
"Golden" masked attention reference. Supports two types of masking: "Golden" masked attention reference. Supports two types of masking:
...@@ -295,10 +295,10 @@ def make_qkv( ...@@ -295,10 +295,10 @@ def make_qkv(
num_heads: int, num_heads: int,
head_size: int, head_size: int,
device: Union[torch.device, str], device: Union[torch.device, str],
force_kv_seq_lens: Optional[List[int]] = None, force_kv_seq_lens: Optional[list[int]] = None,
attn_type: AttentionType = AttentionType.ENCODER_DECODER, attn_type: AttentionType = AttentionType.ENCODER_DECODER,
force_max_len: bool = False, force_max_len: bool = False,
) -> Tuple[QKVInputs, QKVInputs, QKVInputs]: ) -> tuple[QKVInputs, QKVInputs, QKVInputs]:
''' '''
Construct QKV test tensors for self- and cross-attention. Construct QKV test tensors for self- and cross-attention.
...@@ -429,8 +429,8 @@ def make_qkv( ...@@ -429,8 +429,8 @@ def make_qkv(
def pack_tensor( def pack_tensor(
unpacked_tensor: torch.Tensor, seq_lens: List[int], unpacked_tensor: torch.Tensor, seq_lens: list[int],
device: Union[torch.device, str]) -> Tuple[torch.Tensor, List[int]]: device: Union[torch.device, str]) -> tuple[torch.Tensor, list[int]]:
''' '''
Pack a batch_size x padded_seq_len x num_heads x head_size tensor into an Pack a batch_size x padded_seq_len x num_heads x head_size tensor into an
unpadded number_of_tokens x num_heads x head_size tensor, where unpadded number_of_tokens x num_heads x head_size tensor, where
...@@ -537,11 +537,11 @@ def make_backend(backend_name: str) -> AttentionBackend: ...@@ -537,11 +537,11 @@ def make_backend(backend_name: str) -> AttentionBackend:
def _make_metadata_tensors( def _make_metadata_tensors(
seq_lens: Optional[List[int]], seq_lens: Optional[list[int]],
context_lens: Optional[List[int]], context_lens: Optional[list[int]],
encoder_seq_lens: Optional[List[int]], encoder_seq_lens: Optional[list[int]],
device: Union[torch.device, str], device: Union[torch.device, str],
) -> Tuple[torch.Tensor, torch.Tensor, Any, Any, Optional[torch.Tensor], ) -> tuple[torch.Tensor, torch.Tensor, Any, Any, Optional[torch.Tensor],
torch.Tensor, torch.Tensor, Optional[int]]: torch.Tensor, torch.Tensor, Optional[int]]:
''' '''
Build scalar & tensor values required to build attention metadata structure. Build scalar & tensor values required to build attention metadata structure.
...@@ -654,7 +654,7 @@ def make_empty_block_tables_tensor(device: Union[torch.device, str]): ...@@ -654,7 +654,7 @@ def make_empty_block_tables_tensor(device: Union[torch.device, str]):
return torch.tensor([], device=device) return torch.tensor([], device=device)
def split_slot_mapping(slot_mapping_list: torch.Tensor, seq_lens: List[int], def split_slot_mapping(slot_mapping_list: torch.Tensor, seq_lens: list[int],
device: Union[torch.device, str]): device: Union[torch.device, str]):
''' '''
Split a slot mapping into valid prefill- and decode-phase slot mappings. Split a slot mapping into valid prefill- and decode-phase slot mappings.
...@@ -682,9 +682,9 @@ def split_slot_mapping(slot_mapping_list: torch.Tensor, seq_lens: List[int], ...@@ -682,9 +682,9 @@ def split_slot_mapping(slot_mapping_list: torch.Tensor, seq_lens: List[int],
Arguments: Arguments:
* slot_mapping_list: Length-P 1D slot mapping (as List) reflecting all N * slot_mapping_list: Length-P 1D slot mapping (as list) reflecting all N
post-decode sequences post-decode sequences
* seq_lens: List of N post-decode sequence lengths (K_i + 1 in the * seq_lens: list of N post-decode sequence lengths (K_i + 1 in the
description above) description above)
* device: cuda, cpu, etc. * device: cuda, cpu, etc.
...@@ -712,9 +712,9 @@ def split_slot_mapping(slot_mapping_list: torch.Tensor, seq_lens: List[int], ...@@ -712,9 +712,9 @@ def split_slot_mapping(slot_mapping_list: torch.Tensor, seq_lens: List[int],
def make_block_tables_slot_mapping( def make_block_tables_slot_mapping(
block_size: int, block_size: int,
seq_lens: List[int], seq_lens: list[int],
device: Union[torch.device, str], device: Union[torch.device, str],
block_base_addr: int = 0) -> Tuple[torch.Tensor, List[int], int]: block_base_addr: int = 0) -> tuple[torch.Tensor, list[int], int]:
''' '''
Construct fake block tables & slot mappings. Construct fake block tables & slot mappings.
...@@ -794,7 +794,7 @@ def make_block_tables_slot_mapping( ...@@ -794,7 +794,7 @@ def make_block_tables_slot_mapping(
def make_test_metadata( def make_test_metadata(
attn_backend: _Backend, attn_backend: _Backend,
is_prompt: bool, is_prompt: bool,
seq_lens: Optional[List[int]], seq_lens: Optional[list[int]],
decoder_test_params: Optional[PhaseTestParameters], decoder_test_params: Optional[PhaseTestParameters],
device: Union[torch.device, str], device: Union[torch.device, str],
encoder_test_params: Optional[PhaseTestParameters] = None, encoder_test_params: Optional[PhaseTestParameters] = None,
...@@ -1043,7 +1043,7 @@ def fp8_allclose( ...@@ -1043,7 +1043,7 @@ def fp8_allclose(
# Marlin MoE test utils # Marlin MoE test utils
def stack_and_dev(tensors: List[torch.Tensor]): def stack_and_dev(tensors: list[torch.Tensor]):
dev = tensors[0].device dev = tensors[0].device
return torch.stack(tensors, dim=0).to(dev) return torch.stack(tensors, dim=0).to(dev)
...@@ -1090,12 +1090,12 @@ def torch_moe_single(a, w, score, topk): ...@@ -1090,12 +1090,12 @@ def torch_moe_single(a, w, score, topk):
# and a patched version of allclose that supports fp8 types. # and a patched version of allclose that supports fp8 types.
def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket, def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket,
torch._library.custom_ops.CustomOpDef], torch._library.custom_ops.CustomOpDef],
args: Tuple[Any, ...], args: tuple[Any, ...],
kwargs: Optional[Dict[str, Any]] = None, kwargs: Optional[dict[str, Any]] = None,
*, *,
test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS, test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS,
raise_exception: bool = True, raise_exception: bool = True,
cond: bool = True) -> Dict[str, str]: cond: bool = True) -> dict[str, str]:
with unittest.mock.patch('torch.allclose', new=fp8_allclose): with unittest.mock.patch('torch.allclose', new=fp8_allclose):
return torch.library.opcheck( return torch.library.opcheck(
op, op,
...@@ -1120,7 +1120,7 @@ def baseline_scaled_mm(a: torch.Tensor, ...@@ -1120,7 +1120,7 @@ def baseline_scaled_mm(a: torch.Tensor,
b: torch.Tensor, b: torch.Tensor,
scale_a: torch.Tensor, scale_a: torch.Tensor,
scale_b: torch.Tensor, scale_b: torch.Tensor,
out_dtype: Type[torch.dtype], out_dtype: type[torch.dtype],
bias: Optional[torch.Tensor] = None) -> torch.Tensor: bias: Optional[torch.Tensor] = None) -> torch.Tensor:
# We treat N-dimensional group scaling as extended numpy-style broadcasting # We treat N-dimensional group scaling as extended numpy-style broadcasting
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment