Commit f2e99180 authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

[Refactor] Phaseout LLVM Dependency by Making it Optional (#247)

* remove llvm build

* [Refactor] Update kernel compilation and profiling in examples

- Replaced `tilelang.lower` with `tilelang.compile` in multiple example scripts to streamline kernel compilation.
- Updated profiling calls to utilize the new `get_profiler` method, enhancing performance measurement consistency.
- Adjusted assertions and benchmarking methods to align with the new profiling structure across various examples, ensuring correctness and clarity in performance evaluations.

* lint fix

* License Update

* [Refactor] Improve code formatting and documentation in CUDA header and HIP runtime files

- Adjusted formatting in `cuda.h` for better readability, including alignment of comments and struct fields.
- Cleaned up whitespace and improved comment clarity in `rt_mod_hip.cc` to enhance code maintainability.

* [Refactor] Enhance formatting and clarity in CUDA header and HIP runtime files

- Improved comment alignment and readability in `cuda.h`.
- Cleaned up whitespace and formatting in `rt_mod_hip.cc` to enhance maintainability.

* lint fix

* lint fix

* lint fix

* lint fix

* fix

* License update

* [Enhancement] Update JITKernel to use artifact for kernel source

- Assigned the generated artifact to `self.artifact` for better management.
- Updated kernel source references to use `artifact.kernel_source` for consistency in execution backend handling.

* lint fix

* Add @tilelang.testing.requires_llvm decorator to vectorization tests

* Enhance setup.py and env.py for library management

- Added functionality to remove original files after copying in CMakeBuild.
- Updated TVM_LIBRARY_PATH in env.py to include the PyPI build library path for better integration.

* Refactor TVM_LIBRARY_PATH assignment for improved readability in env.py

* Refactor CMakeBuild file handling in setup.py

- Added a check to ensure the target library directory exists before copying .so files.
- Improved the logic for creating the target directory and copying files to enhance robustness.

* bugfix

* Rename BuildTLDebug to BuildTileLangCUDAWithoutCompile and update registration. Add @tilelang.testing.requires_llvm decorator to multiple tests for LLVM requirement.

* lint fix

* Enhance TileLang code generation by adding support for device code generation without compilation. Updated `host_codegen` and `device_codegen` functions to include new transformations and registration for `tilelang_hip_without_compile`. Refactored JIT kernel adapters to accommodate host and device modules, improving overall integration and flexibility.

* lint fix

* Add support for C target in device code generation

- Updated `device_codegen_without_compile` to include handling for the C target by registering the `tilelang_cpp` function.

* [Enhancement] Implement auto-clear cache feature based on environment variable

* Added TILELANG_CLEAR_CACHE environment variable to control cache clearing.
* Updated CI workflow to set TILELANG_CLEAR_CACHE during testing.
* Modified cache initialization to clear cache if TILELANG_CLEAR_CACHE is set to true.

* [Refactor] Update kernel invocation and import paths in tests and cache

* Changed kernel invocation in `test_tilelang_kernel_dequantize_gemm.py` to return the result.
* Updated import statements in `test_tilelang_kernel_int4_gemm_mma.py` to use `bitblas` instead of `tilelang`.
* Refactored paths for artifact and parameters in `kernel_cache.py` for better maintainability.

* [Refactor] Clean up whitespace and improve code formatting in kernel_cache.py

* Removed unnecessary blank lines and adjusted spacing for better readability in the KernelCache class.
* Enhanced overall code formatting to align with project standards.

* [Enhancement] Add bfloat16 test case and improve kernel caching logic

* Introduced a new test case for bfloat16 matrix multiplication in `test_tilelang_kernel_gemm_mma_intrinsic.py`.
* Updated `KernelCache` to handle multiple kernel source files and improve error handling during saving and loading.
* Refactored `JITKernel` to support instantiation from a database, enhancing flexibility in kernel management.
* Adjusted `CtypesKernelAdapter` and `CythonKernelAdapter` to utilize the new kernel loading mechanism from the database.
* Improved code formatting and readability across several files.

* lint fix

* Update bfloat16 matrix multiplication test case to use larger dimensions for improved coverage
parent 43bd9d3e
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import tilelang
from tilelang import tvm as tvm
import tilelang.testing
......@@ -46,11 +43,10 @@ def assert_gemm_codegen(
accum_dtype="float",
):
func = matmul(M, N, K, block_M, block_N, block_K, dtype=dtype, accum_dtype=accum_dtype)
print(func)
rt_mod, _ = tilelang.lower(func, target="webgpu")
artifact = tilelang.lower(func, target="webgpu")
src_code = rt_mod.imported_modules[0].get_source()
src_code = artifact.kernel_source
assert src_code is not None
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import sys
import os
import ctypes
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""FFI APIs for tilelang"""
import tvm._ffi
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""The auto-tune module for tilelang programs."""
import tilelang
......
......@@ -5,6 +5,7 @@ from tvm.target import Target
from tvm.tir import PrimFunc
from tilelang.jit import JITKernel
from .kernel_cache import KernelCache
from tilelang.env import TILELANG_CLEAR_CACHE
# Create singleton instance of KernelCache
_kernel_cache_instance = KernelCache()
......@@ -40,3 +41,7 @@ def clear_cache():
Clears the entire kernel cache (using KernelCache class).
"""
_kernel_cache_instance.clear_cache()
if TILELANG_CLEAR_CACHE.lower() in ("1", "true", "yes", "on"):
clear_cache()
......@@ -4,20 +4,31 @@ import os
import json
import shutil
from hashlib import sha256
from typing import Callable, List, Literal, Union
from typing import Callable, List, Literal, Union, Optional
from tvm.target import Target
from tvm.tir import PrimFunc
from tilelang.jit import JITKernel
from tilelang.engine.param import KernelParam
import threading
import cloudpickle
import logging
from tilelang.env import TILELANG_CACHE_DIR # noqa: F401
from tilelang.env import TILELANG_CACHE_DIR
KERNEL_PATH = "kernel.cu"
WRAPPED_KERNEL_PATH = "warpped_kernel.cu"
KERNEL_LIB_PATH = "kernel_lib.so"
PARAMS_PATH = "params.pkl"
class KernelCache:
"""
Caches compiled kernels using a class and database persistence to avoid redundant compilation.
Cache files:
kernel.cu: The compiled kernel source code
warpped_kernel.cu: The compiled wrapped kernel source code
kernel_lib.so: The compiled kernel library
params.pkl: The compiled kernel parameters
"""
_instance = None # For implementing singleton pattern
_lock = threading.Lock() # For thread safety
......@@ -126,17 +137,34 @@ class KernelCache:
cache_path = self._get_cache_path(key)
os.makedirs(cache_path, exist_ok=True) # Ensure directory exists
# Save rt_mod as a str
# Save kernel source code
try:
kernel_path = os.path.join(cache_path, KERNEL_PATH)
with open(kernel_path, "w") as f:
f.write(kernel.artifact.kernel_source)
except Exception as e:
self.logger.error(f"Error saving kernel source code to disk: {e}")
# Save wrapped kernel source code
try:
artifact_path = os.path.join(cache_path, "tvm_tmp_mod.txt")
with open(artifact_path, "w") as f:
f.write(kernel.rt_mod.imported_modules[0].get_source())
wrapped_kernel_path = os.path.join(cache_path, WRAPPED_KERNEL_PATH)
with open(wrapped_kernel_path, "w") as f:
f.write(kernel.adapter.get_kernel_source())
except Exception as e:
self.logger.error(f"Error saving kernel module to disk: {e}")
self.logger.error(f"Error saving wrapped kernel source code to disk: {e}")
# Save kernel library
try:
dump_path = os.path.join(cache_path, "tvm_params.pkl")
with open(dump_path, "wb") as f:
kernel_lib_path = os.path.join(cache_path, KERNEL_LIB_PATH)
src_lib_path = kernel.adapter.libpath
shutil.copy(src_lib_path, kernel_lib_path)
except Exception as e:
self.logger.error(f"Error saving kernel library to disk: {e}")
# Save kernel parameters
try:
params_path = os.path.join(cache_path, PARAMS_PATH)
with open(params_path, "wb") as f:
cloudpickle.dump(kernel.params, f)
except Exception as e:
self.logger.error(f"Error saving kernel parameters to disk: {e}")
......@@ -155,31 +183,38 @@ class KernelCache:
cache_path = self._get_cache_path(key)
if not os.path.exists(cache_path):
return None
rt_module = None
rt_params = None
kernel_global_source: Optional[str] = None
kernel_params: Optional[List[KernelParam]] = None
try:
artifact_path = os.path.join(cache_path, "tvm_tmp_mod.txt")
with open(artifact_path, "r") as f:
rt_module = f.read()
wrapped_kernel_path = os.path.join(cache_path, WRAPPED_KERNEL_PATH)
with open(wrapped_kernel_path, "r") as f:
kernel_global_source = f.read()
except Exception as e:
self.logger.error(f"Error loading kernel module from disk: {e}")
self.logger.error(f"Error loading wrapped kernel source code from disk: {e}")
kernel_lib_path = os.path.join(cache_path, KERNEL_LIB_PATH)
# Load kernel parameters
try:
dump_path = os.path.join(cache_path, "tvm_params.pkl")
with open(dump_path, "rb") as f:
rt_params = cloudpickle.load(f)
params_path = os.path.join(cache_path, PARAMS_PATH)
with open(params_path, "rb") as f:
kernel_params = cloudpickle.load(f)
except Exception as e:
self.logger.error(f"Error loading kernel parameters from disk: {e}")
if rt_module and rt_params:
return JITKernel(
rt_module_src=rt_module,
rt_params=rt_params,
execution_backend=execution_backend,
if kernel_global_source and kernel_params:
return JITKernel.from_database(
func=func,
kernel_global_source=kernel_global_source,
kernel_lib_path=kernel_lib_path,
params=kernel_params,
target=target,
target_host=target_host,
out_idx=out_idx,
execution_backend=execution_backend,
pass_configs=pass_configs,
func=func,
)
else:
return None
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Base infra"""
from .analysis import (
BlockInfo, # noqa: F401
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Analysis on TIR blocks, loops and functions."""
from typing import List, Optional, Set, Union
from typing_extensions import Literal
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .arch_base import TileDevice
from .cuda import CUDA
from .cpu import CPU
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import List
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import tvm
from tvm.target import Target
from .arch_base import TileDevice
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import tvm
from tvm.target import Target
from .arch_base import TileDevice
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import tvm
from tvm.target import Target
from .arch_base import TileDevice
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# pylint: disable=missing-docstring, invalid-name
"""A GEMM schedule rule for GPU operators."""
from dataclasses import dataclass
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .node import PrimFuncNode, OutputNode, Edge # noqa: F401
from .rasterization import NoRasterization, Rasterization2DRow, Rasterization2DColumn # noqa: F401
from .hint import Hint # noqa: F401
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Benefit For Carver Schedule"""
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Hint definition for schedule"""
from tvm import DataType
from typing import Dict, List, Tuple
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""PrimFunc Wrapper and Block information Analaysis"""
import tvm
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .default import DefaultPolicy # noqa: F401
from .tensorcore import TensorCorePolicy # noqa: F401
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import List
import numpy as np
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment