Unverified commit 5e529522 authored by Yichen Yan, committed by GitHub

[Lint] Add ruff config to check for useless spaces (#807)

* update lint config

* Remove spaces for blank line

* update
parent 4d54854b
@@ -133,7 +133,7 @@ def run_autotune(M: int, N: int, K: int):
def test_autotune_matmul():
    """
    Run the autotuning validation for the matmul kernel on a 1024x1024x1024 problem.
    This test constructs random CUDA tensors, autotunes the JIT-compiled block-level matrix-multiplication kernel,
    executes it, and asserts the result matches a reference CPU implementation within tolerances.
    """
...
@@ -55,4 +55,4 @@ def test_lower_hopper_intrin_barrier():
if __name__ == "__main__":
    tilelang.testing.main()
\ No newline at end of file
@@ -118,4 +118,4 @@ def test_warp_specialized():
if __name__ == "__main__":
    tilelang.testing.main()
\ No newline at end of file
@@ -713,7 +713,7 @@ def autotune( # This is the new public interface
    This decorator can be used without arguments (e.g., `@tilelang.jit`):
    Applies JIT compilation with default settings.
    Tips:
    - If you want to skip the auto-tuning process, you can override the tunable parameters in the function signature.
    ```python
...
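The tip above is terse, so here is a minimal, hypothetical sketch of what it describes: pinning the tunable parameters when calling the decorated kernel factory so that no tuning search is launched. The decorator stacking, the parameter names (`block_M`, `block_N`, `block_K`), and the GEMM sizes are illustrative assumptions, not code from this repository.

```python
import tilelang
import tilelang.language as T  # assumed frontend import; not shown in this diff


@tilelang.autotune  # the docstring says the decorator works without arguments
@tilelang.jit
def gemm(M, N, K, block_M=128, block_N=128, block_K=32):
    # Kernel body elided in this sketch; block_M/block_N/block_K are the
    # tunable parameters referred to by the tip above.
    ...


# Supplying explicit values for every tunable parameter in the call skips the
# auto-tuning search, e.g.:
#   kernel = gemm(4096, 4096, 4096, block_M=64, block_N=64, block_K=32)
```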
@@ -78,7 +78,7 @@ __device__ __inline__ dim3 rasterization2DColumn(const int panel_width) {
  const auto bx = (panelIdx & 1) ? gridDim.x - (baseBlockIdx - panelIdx * panel_width * gridDim.x) / strideLd - 1 : (baseBlockIdx - panelIdx * panel_width * gridDim.x) / strideLd;
  const auto by = (baseBlockIdx - panelIdx * panel_width * gridDim.x) % strideLd + panelIdx * panel_width;
  const auto bz = blockIdx.z;
  dim3 blockIdx(bx, by, bz);
  return blockIdx;
}
...
@@ -6,4 +6,4 @@ from .gemv import GEMVTemplate  # noqa: F401
from .elementwise import ElementwiseTemplate  # noqa: F401
from .general_reduce import GeneralReductionTemplate  # noqa: F401
from .flashattention import FlashAttentionTemplate  # noqa: F401
from .conv import ConvTemplate  # noqa: F401
\ No newline at end of file
@@ -12,8 +12,8 @@ from tvm.tir import PrimFunc  # Import PrimFunc for handling tensor IR functions
@dataclass
class BaseTemplate(ABC):
    """
    Base class template for hardware-aware configurations.
    This serves as an abstract base class (ABC) that defines the structure
    for subclasses implementing hardware-specific optimizations.
    """
@@ -30,9 +30,9 @@ class BaseTemplate(ABC):
    def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> List[Hint]:
        """
        Abstract method that must be implemented by subclasses.
        It should return a list of hardware-aware configurations (hints)
        based on the specified architecture.
        Args:
            arch (TileDevice, optional): The target architecture. Defaults to None.
            topk (int, optional): Number of top configurations to return. Defaults to 10.
@@ -104,7 +104,7 @@ class BaseTemplate(ABC):
        """
        Placeholder method that should be implemented by subclasses.
        This method is responsible for initializing the function.
        Raises:
            NotImplementedError: If not implemented in the subclass.
        """
...
@@ -62,8 +62,8 @@ class ConvTemplate(BaseTemplate):
        """
        Defines and initializes the convolution computation.
        This method sets up placeholders for input matrices, computes
        the convolution using TVM's compute API,
        and optionally applies bias and type casting.
        Raises:
...
@@ -44,8 +44,8 @@ class FlashAttentionTemplate(BaseTemplate):
        """
        Defines and initializes the matrix multiplication computation.
        This method sets up placeholders for input matrices, computes
        the matrix multiplication using TVM's compute API,
        and optionally applies bias and type casting.
        Raises:
...
@@ -12,7 +12,7 @@ class GEMVTemplate(BaseTemplate):
    """
    A template for Generalized Matrix-Vector Multiplication (GEMV).
    This template defines the computation for a matrix-vector multiplication
    with configurable parameters such as transposition, data types, and bias addition.
    """
@@ -43,8 +43,8 @@ class GEMVTemplate(BaseTemplate):
        """
        Defines and initializes the GEMV computation function.
        This method sets up placeholders for input matrices, computes
        the matrix-vector multiplication using TVM's compute API,
        and optionally applies bias and type casting.
        """
        M: int = 1  # Fixed M value, representing a single batch dimension
...
@@ -56,8 +56,8 @@ class MatmulTemplate(BaseTemplate):
        """
        Defines and initializes the matrix multiplication computation.
        This method sets up placeholders for input matrices, computes
        the matrix multiplication using TVM's compute API,
        and optionally applies bias and type casting.
        Raises:
...
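The matmul, GEMV, conv, and attention docstrings above all describe the same pattern: declare placeholders for the inputs, then express the result with TVM's compute API, with optional bias addition and type casting. As a point of reference, here is a minimal sketch of that pattern for a plain matmul using TVM's `te` API; the shapes, dtypes, and float32 accumulation are illustrative assumptions rather than the templates' actual parameters.

```python
from tvm import te

M = N = K = 1024
A = te.placeholder((M, K), name="A", dtype="float16")  # input placeholder
B = te.placeholder((K, N), name="B", dtype="float16")  # input placeholder
k = te.reduce_axis((0, K), name="k")

# C = A @ B expressed with te.compute; the bias addition and final cast that
# the templates mention would be added as further te.compute stages.
C = te.compute(
    (M, N),
    lambda i, j: te.sum(A[i, k].astype("float32") * B[k, j].astype("float32"), axis=k),
    name="C",
)
```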
@@ -126,7 +126,7 @@ def compile_cuda(code,
def find_cuda_path():
    """Utility function to find cuda path
    Returns
    -------
    path : str
...
@@ -5,7 +5,7 @@ from tvm.target import Target
def register_cuda_postproc(func: Callable[[str, Target], str], override: bool = True):
    """Register a post-processing function for CUDA code generation.
    Args:
        func: A callable that takes generated code (str) and target (Target) as input,
            and returns the processed code (str).
@@ -16,7 +16,7 @@ def register_cuda_postproc(func: Callable[[str, Target], str], override: bool =
def register_hip_postproc(func: Callable[[str, Target], str], override: bool = True):
    """Register a post-processing function for HIP code generation.
    Args:
        func: A callable that takes generated code (str) and target (Target) as input,
            and returns the processed code (str).
@@ -27,17 +27,17 @@ def register_hip_postproc(func: Callable[[str, Target], str], override: bool = T
def register_cuda_postproc_callback(func: Union[Callable, bool] = None, override: bool = True):
    """Decorator for registering CUDA post-processing callback function.
    Can be used with or without parentheses:
        @register_cuda_postproc_callback
        def func(code, target): ...
        @register_cuda_postproc_callback()
        def func(code, target): ...
        @register_cuda_postproc_callback(override=False)
        def func(code, target): ...
    Args:
        func: The function to be decorated or a boolean override flag
        override: Whether to override existing registered function. Defaults to True.
@@ -60,17 +60,17 @@ def register_cuda_postproc_callback(func: Union[Callable, bool] = None, override
def register_hip_postproc_callback(func: Union[Callable, bool] = None, override: bool = True):
    """Decorator for registering HIP post-processing callback function.
    Can be used with or without parentheses:
        @register_hip_postproc_callback
        def func(code, target): ...
        @register_hip_postproc_callback()
        def func(code, target): ...
        @register_hip_postproc_callback(override=False)
        def func(code, target): ...
    Args:
        func: The function to be decorated or a boolean override flag
        override: Whether to override existing registered function. Defaults to True.
...
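Since the docstrings above only list the decorator forms, a small end-to-end illustration may help. This is a hedged sketch: the top-level `tilelang` import location for the callback decorator is an assumption not confirmed by this hunk; only the callback signature `(code, target) -> str` and the decorator forms come from the docstrings.

```python
# Assumed export location; the hunk does not show which module re-exports this.
from tilelang import register_cuda_postproc_callback


@register_cuda_postproc_callback
def add_banner(code: str, target) -> str:
    # Receives the generated CUDA source and the tvm Target, and must return
    # the (possibly modified) source string.
    return "// post-processed by a registered callback\n" + code
```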
@@ -21,13 +21,13 @@ class KernelParam:
    def from_buffer(cls, buffer: Buffer):
        """
        Creates a KernelParam instance from a TVM Buffer object.
        Args:
            buffer: TVM Buffer object containing dtype and shape information
        Returns:
            KernelParam instance with converted dtype and shape
        Raises:
            ValueError: If dimension type is not supported (not IntImm or Var)
        """
@@ -47,10 +47,10 @@ class KernelParam:
        """
        Creates a KernelParam instance from a TVM Variable object.
        Used for scalar parameters.
        Args:
            var: TVM Variable object containing dtype information
        Returns:
            KernelParam instance representing a scalar (empty shape)
        """
@@ -60,7 +60,7 @@ class KernelParam:
    def is_scalar(self) -> bool:
        """
        Checks if the parameter represents a scalar value.
        Returns:
            bool: True if parameter has no dimensions (empty shape), False otherwise
        """
@@ -69,7 +69,7 @@ class KernelParam:
    def is_unsigned(self) -> bool:
        """
        Checks if the parameter represents an unsigned integer type.
        Returns:
            bool: True if parameter is an unsigned integer type, False otherwise
        """
@@ -81,7 +81,7 @@ class KernelParam:
    def is_float8(self) -> bool:
        """
        Checks if the parameter represents a float8 type.
        Returns:
            bool: True if parameter is a float8 type, False otherwise
        """
@@ -93,7 +93,7 @@ class KernelParam:
    def is_boolean(self) -> bool:
        """
        Checks if the parameter represents a boolean type.
        Returns:
            bool: True if parameter is a boolean type, False otherwise
        """
...
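A short, hedged usage sketch of the constructor documented above. The import path for `KernelParam` is an assumption (this hunk does not name its module); `tvm.tir.decl_buffer` is standard TVM API and supplies the dtype and shape information that `from_buffer` converts.

```python
import tvm
# Assumed import path for KernelParam; the defining module is not shown in this hunk.
from tilelang.engine.param import KernelParam

# A 1024x1024 float16 buffer; from_buffer reads its dtype and shape.
buf = tvm.tir.decl_buffer((1024, 1024), dtype="float16", name="A")
param = KernelParam.from_buffer(buf)
# `param` then exposes the documented predicates (scalar / unsigned / float8 / boolean checks).
```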
@@ -65,7 +65,7 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
    # Bind the target device information to the module
    """
    Bind target information and progressively legalize and lower frontend Tile IR into a form suitable for downstream optimization and codegen.
    This pass pipeline:
    - Binds the provided target to the module.
    - Legalizes frontend Tile IR into TVM-compatible constructs.
@@ -75,11 +75,11 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
    - Legalizes vectorized loops and inserts safety checks for memory accesses.
    - Re-simplifies to remove redundancies introduced by safety checks.
    - Attempts loop vectorization for dynamic-shaped loops.
    Parameters:
        mod (IRModule): The input IR module containing frontend Tile IR.
        target (Target): Target device information to bind into the module.
    Returns:
        IRModule: The transformed module, ready for target-specific optimization passes.
    """
...
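For orientation, a minimal, hedged sketch of calling this phase. The module path `tilelang.engine.phase` is an assumption (only the signature appears in this hunk), and the empty `IRModule` is a stand-in for a module built from a frontend Tile IR program.

```python
import tvm
from tvm.target import Target
# Assumed module path; only the function signature appears in this hunk.
from tilelang.engine.phase import LowerAndLegalize

mod = tvm.IRModule()  # stand-in; normally produced from a frontend Tile IR program
target = Target("cuda")
legalized = LowerAndLegalize(mod, target)  # returns an IRModule ready for target-specific passes
```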
@@ -91,14 +91,14 @@ def get_mma_micro_size(dtype: Literal["float16", "int8"]):
    # Basic Tensor Core Matrix Multiply operation Unit
    """
    Return the MMA (Tensor Core) micro-tile dimensions for a given data type.
    This function returns the micro tile sizes (x, y, k) used by MMA/Tensor Core operations.
    - x: tile width in the output/result dimension
    - y: tile height in the output/result dimension
    - k: tile depth in the reduction/K dimension
    Accepted dtype strings include "float16", "int8" and some FP8 identifiers ("float8_e4m3", "float8_e5m2"). For FP8 and int8 types the reduction depth (`k`) is 32; for float16 it is 16.
    Returns:
        tuple[int, int, int]: (micro_size_x, micro_size_y, micro_size_k)
    """
...
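The rule the docstring states can be summarized with a tiny standalone sketch. This illustrates the documented mapping only and is not the library's implementation; in particular, the 16x16 output micro-tile is a common Tensor Core size assumed here, since the hunk only specifies the reduction depth `k`.

```python
def mma_micro_size_sketch(dtype: str) -> tuple[int, int, int]:
    # Documented rule: k = 32 for int8 and FP8 dtypes, k = 16 for float16.
    micro_k = 32 if dtype in ("int8", "float8_e4m3", "float8_e5m2") else 16
    # 16x16 output micro-tile assumed for illustration only.
    return 16, 16, micro_k


micro_x, micro_y, micro_k = mma_micro_size_sketch("float16")  # (16, 16, 16)
assert mma_micro_size_sketch("int8")[2] == 32
```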
""" """
This module provides an auto-tuning infrastructure for TileLang (tl) programs. This module provides an auto-tuning infrastructure for TileLang (tl) programs.
It includes functionality to JIT-compile TileLang programs into a runnable It includes functionality to JIT-compile TileLang programs into a runnable
kernel adapter using TVM. kernel adapter using TVM.
""" """
......
@@ -2,4 +2,4 @@ from .base import BaseKernelAdapter  # noqa: F401
from .dlpack import TorchDLPackKernelAdapter  # noqa: F401
from .ctypes import CtypesKernelAdapter  # noqa: F401
from .cython import CythonKernelAdapter  # noqa: F401
from .nvrtc import NVRTCKernelAdapter  # noqa: F401
\ No newline at end of file
@@ -16,7 +16,7 @@ from tilelang.utils.language import retrieve_func_from_module
class CtypesKernelAdapter(BaseKernelAdapter):
    """Adapter class that converts TVM/TIR functions to callable CUDA kernels using ctypes.
    This adapter handles:
    1. Converting TIR functions to compiled CUDA libraries
    2. Managing dynamic shapes in tensor operations
@@ -52,7 +52,7 @@ class CtypesKernelAdapter(BaseKernelAdapter):
                 pass_configs: Optional[Dict[str, Any]] = None,
                 compile_flags: Optional[List[str]] = None):
        """Initialize the adapter with the given TIR function or module.
        Args:
            params: List of tensor types for inputs/outputs
            result_idx: Indices of output tensors
@@ -157,7 +157,7 @@ class CtypesKernelAdapter(BaseKernelAdapter):
    def _process_dynamic_symbolic(self) -> Dict[tir.Var, Tuple[int, int, int]]:
        """Extract information about dynamic shapes from the TIR function.
        Maps symbolic variables to their corresponding (id, buffer_index, dimension)
        for runtime shape resolution.
        id represents shape or stride, 0 represents shape, 1 represents stride
@@ -184,7 +184,7 @@ class CtypesKernelAdapter(BaseKernelAdapter):
    def _forward_from_prebuild_lib(self, *args, stream: Optional[int] = None):
        """Low-level function to call the compiled CUDA kernel.
        Converts PyTorch tensor pointers to C void pointers for ctypes interface.
        """
        ctypes_args = [
@@ -197,17 +197,17 @@ class CtypesKernelAdapter(BaseKernelAdapter):
                 *ins: List[torch.Tensor],
                 stream: Optional[int] = None):
        """High-level wrapper for kernel execution.
        Handles:
        1. Input validation
        2. Output tensor allocation
        3. Dynamic shape resolution
        4. CUDA stream management
        Args:
            ins: Input PyTorch tensors
            stream: Optional CUDA stream for asynchronous execution
        Returns:
            Single tensor or list of tensors containing the kernel results
        """
...
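The low-level forward path described above ultimately hands raw device pointers to the compiled library through ctypes. A hedged, standalone illustration of that conversion (the tensor and variable names are made up; this is not the adapter's actual code):

```python
import ctypes
import torch

x = torch.empty(1024, dtype=torch.float16, device="cuda")
# torch.Tensor.data_ptr() returns the device address as an int; wrapping it in
# ctypes.c_void_p is the "tensor pointer -> C void pointer" step the docstring mentions.
x_ptr = ctypes.c_void_p(x.data_ptr())
```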
@@ -176,7 +176,7 @@ from cython_wrapper import CythonKernelWrapper
class CythonKernelAdapter(BaseKernelAdapter):
    """Adapter class that converts TVM/TIR functions to callable CUDA kernels using ctypes.
    This adapter handles:
    1. Converting TIR functions to compiled CUDA libraries
    2. Managing dynamic shapes in tensor operations
@@ -222,7 +222,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
                 pass_configs: Optional[Dict[str, Any]] = None,
                 compile_flags: Optional[List[str]] = None):
        """Initialize the adapter with the given TIR function or module.
        Args:
            params: List of tensor types for inputs/outputs
            result_idx: Indices of output tensors
@@ -347,7 +347,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
    def _process_dynamic_symbolic(self) -> Dict[tir.Var, Tuple[int, int, int]]:
        """Extract information about dynamic shapes from the TIR function.
        Maps symbolic variables to their corresponding (id, buffer_index, dimension)
        for runtime shape resolution.
        id represents shape or stride, 0 represents shape, 1 represents stride
@@ -374,7 +374,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
    def _process_buffer_dtype(self) -> Dict[tir.Var, Tuple[int, torch.dtype]]:
        """Extract information about buffer dtypes from the TIR function.
        Maps buffer variables to their corresponding dtypes.
        """
        func = self.prim_func
@@ -390,7 +390,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
    def _process_ptr_map(self) -> Dict[int, str]:
        """Extract information about pointer arguments from the TIR function.
        Maps pointer arguments to their corresponding (buffer_index, shape_dimension)
        for runtime shape resolution.
        """
@@ -407,7 +407,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
                   Dict[tir.Var, Tuple[int, List[Tuple[int, int]]]],
                   List[Tuple[tir.Var]]]:
        """Extract information about static shapes from the TIR function.
        Maps buffer variables to their corresponding static shapes.
        """
        func = self.prim_func
@@ -438,7 +438,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
    def _process_buffer_device(self) -> Dict[tir.Var, Tuple[int, torch.device]]:
        """Extract information about buffer devices from the TIR function.
        Maps buffer variables to their corresponding devices.
        """
        func = self.prim_func
@@ -462,7 +462,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
    def _forward_from_prebuild_lib(self, *args, stream: Optional[int] = None):
        """Low-level function to call the compiled CUDA kernel.
        Converts PyTorch tensor pointers to C void pointers for ctypes interface.
        """
        ctypes_args = [
...