Commit 029da5e8 authored by zhuwenwen's avatar zhuwenwen
Browse files

update List

parent 09396f62
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
import functools import functools
import json import json
import os import os
from typing import Any, Callable, Optional, Union from typing import Any, Callable, Optional, Union, List
import torch import torch
...@@ -34,7 +34,7 @@ def cutlass_scaled_mm( ...@@ -34,7 +34,7 @@ def cutlass_scaled_mm(
B: torch.Tensor, B: torch.Tensor,
As: torch.Tensor, As: torch.Tensor,
Bs: torch.Tensor, Bs: torch.Tensor,
block_size: list[int], block_size: List[int],
output_dtype: torch.dtype = torch.float16, output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor: ) -> torch.Tensor:
return ops.cutlass_scaled_mm(A, return ops.cutlass_scaled_mm(A,
...@@ -49,7 +49,7 @@ def rocm_aiter_gemm_w8a8_blockscale_impl( ...@@ -49,7 +49,7 @@ def rocm_aiter_gemm_w8a8_blockscale_impl(
B: torch.Tensor, B: torch.Tensor,
As: torch.Tensor, As: torch.Tensor,
Bs: torch.Tensor, Bs: torch.Tensor,
block_size: list[int], block_size: List[int],
output_dtype: torch.dtype = torch.float16, output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor: ) -> torch.Tensor:
import aiter as rocm_aiter import aiter as rocm_aiter
...@@ -62,7 +62,7 @@ def rocm_aiter_gemm_w8a8_blockscale_fake( ...@@ -62,7 +62,7 @@ def rocm_aiter_gemm_w8a8_blockscale_fake(
B: torch.Tensor, B: torch.Tensor,
As: torch.Tensor, As: torch.Tensor,
Bs: torch.Tensor, Bs: torch.Tensor,
block_size: list[int], block_size: List[int],
output_dtype: torch.dtype = torch.float16, output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -89,7 +89,7 @@ def dispatch_w8a8_blockscale_func( ...@@ -89,7 +89,7 @@ def dispatch_w8a8_blockscale_func(
torch.Tensor, torch.Tensor,
torch.Tensor, torch.Tensor,
torch.Tensor, torch.Tensor,
list[int], List[int],
torch.dtype, torch.dtype,
], torch.Tensor]: ], torch.Tensor]:
if use_cutlass: if use_cutlass:
...@@ -117,7 +117,7 @@ def should_use_deepgemm(output_dtype: torch.dtype, weight: torch.Tensor): ...@@ -117,7 +117,7 @@ def should_use_deepgemm(output_dtype: torch.dtype, weight: torch.Tensor):
def apply_w8a8_block_fp8_linear( def apply_w8a8_block_fp8_linear(
input: torch.Tensor, input: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,
block_size: list[int], block_size: List[int],
weight_scale: torch.Tensor, weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None, input_scale: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None,
...@@ -190,7 +190,7 @@ def apply_w8a8_block_fp8_linear( ...@@ -190,7 +190,7 @@ def apply_w8a8_block_fp8_linear(
def apply_w8a8_block_fp8_linear_fake( def apply_w8a8_block_fp8_linear_fake(
input: torch.Tensor, input: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,
block_size: list[int], block_size: List[int],
weight_scale: torch.Tensor, weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None, input_scale: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None,
...@@ -571,7 +571,7 @@ def w8a8_block_fp8_matmul( ...@@ -571,7 +571,7 @@ def w8a8_block_fp8_matmul(
B: torch.Tensor, B: torch.Tensor,
As: torch.Tensor, As: torch.Tensor,
Bs: torch.Tensor, Bs: torch.Tensor,
block_size: list[int], block_size: List[int],
output_dtype: torch.dtype = torch.float16, output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor: ) -> torch.Tensor:
"""This function performs matrix multiplication with block-wise """This function performs matrix multiplication with block-wise
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment