Commit f2e99180 authored by Lei Wang, committed by LeiWang1999

[Refactor] Phaseout LLVM Dependency by Making it Optional (#247)

* remove llvm build

* [Refactor] Update kernel compilation and profiling in examples

- Replaced `tilelang.lower` with `tilelang.compile` in multiple example scripts to streamline kernel compilation.
- Updated profiling calls to utilize the new `get_profiler` method, enhancing performance measurement consistency.
- Adjusted assertions and benchmarking methods to align with the new profiling structure across various examples, ensuring correctness and clarity in performance evaluations.
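The bullets above describe moving examples from `tilelang.lower` to `tilelang.compile` and benchmarking through a profiler object. As a minimal, self-contained sketch of that calling pattern (the classes below are illustrative mocks, not the tilelang API; only the method names `get_profiler` and a `do_bench`-style benchmark are assumed from the notes above):

```python
import time


class Profiler:
    """Mock profiler wrapping a callable kernel."""

    def __init__(self, fn):
        self.fn = fn

    def do_bench(self, warmup=2, rep=5):
        """Return the average runtime of the wrapped callable in milliseconds."""
        for _ in range(warmup):  # warm up caches / JIT before timing
            self.fn()
        start = time.perf_counter()
        for _ in range(rep):
            self.fn()
        return (time.perf_counter() - start) / rep * 1e3


class CompiledKernel:
    """Mock of a compiled kernel object exposing get_profiler()."""

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, *args):
        return self.fn(*args)

    def get_profiler(self):
        return Profiler(self.fn)


kernel = CompiledKernel(lambda: sum(range(1000)))
latency_ms = kernel.get_profiler().do_bench()
```

The point of the refactor is that the kernel object itself hands out the profiler, so every example benchmarks through one consistent interface instead of ad-hoc timing code.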

* lint fix

* License Update

* [Refactor] Improve code formatting and documentation in CUDA header and HIP runtime files

- Adjusted formatting in `cuda.h` for better readability, including alignment of comments and struct fields.
- Cleaned up whitespace and improved comment clarity in `rt_mod_hip.cc` to enhance code maintainability.

* [Refactor] Enhance formatting and clarity in CUDA header and HIP runtime files

- Improved comment alignment and readability in `cuda.h`.
- Cleaned up whitespace and formatting in `rt_mod_hip.cc` to enhance maintainability.

* lint fix

* lint fix

* lint fix

* lint fix

* fix

* License update

* [Enhancement] Update JITKernel to use artifact for kernel source

- Assigned the generated artifact to `self.artifact` for better management.
- Updated kernel source references to use `artifact.kernel_source` for consistency in execution backend handling.

* lint fix

* Add @tilelang.testing.requires_llvm decorator to vectorization tests

* Enhance setup.py and env.py for library management

- Added functionality to remove original files after copying in CMakeBuild.
- Updated TVM_LIBRARY_PATH in env.py to include the PyPI build library path for better integration.

* Refactor TVM_LIBRARY_PATH assignment for improved readability in env.py

* Refactor CMakeBuild file handling in setup.py

- Added a check to ensure the target library directory exists before copying .so files.
- Improved the logic for creating the target directory and copying files to enhance robustness.
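The two bullets above amount to a standard "ensure the destination exists, then copy" pattern. A sketch of that logic under stated assumptions (`copy_shared_libs` is an illustrative helper name, not the actual setup.py code):

```python
import os
import shutil


def copy_shared_libs(src_dir: str, dst_dir: str, remove_original: bool = False):
    """Copy every .so file from src_dir into dst_dir, creating dst_dir if needed.

    Mirrors the setup.py behavior described above: check/create the target
    directory before copying, and optionally remove the originals afterwards.
    """
    os.makedirs(dst_dir, exist_ok=True)  # no-op if the directory already exists
    copied = []
    for name in os.listdir(src_dir):
        if name.endswith(".so"):
            src = os.path.join(src_dir, name)
            shutil.copy2(src, os.path.join(dst_dir, name))  # preserves metadata
            if remove_original:
                os.remove(src)
            copied.append(name)
    return copied
```

`os.makedirs(..., exist_ok=True)` is what makes the copy robust to a missing or pre-existing target directory without a race-prone exists-then-create check.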

* bugfix

* Rename BuildTLDebug to BuildTileLangCUDAWithoutCompile and update registration. Add @tilelang.testing.requires_llvm decorator to multiple tests for LLVM requirement.

* lint fix

* Enhance TileLang code generation by adding support for device code generation without compilation. Updated `host_codegen` and `device_codegen` to include the new transformations, and registered `tilelang_hip_without_compile`. Refactored the JIT kernel adapters to accommodate separate host and device modules, improving integration and flexibility.

* lint fix

* Add support for C target in device code generation

- Updated `device_codegen_without_compile` to include handling for the C target by registering the `tilelang_cpp` function.
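The dispatch structure being extended here is "look up a codegen builder by target kind, error out otherwise". A toy sketch of that pattern (the registry contents are illustrative strings; the real code resolves TVM global functions such as `target.build.tilelang_cpp`):

```python
# Illustrative registry mapping target kind -> codegen builder.
_CODEGEN_WITHOUT_COMPILE = {
    "cuda": lambda mod: f"tilelang_cuda_without_compile({mod})",
    "hip": lambda mod: f"tilelang_hip_without_compile({mod})",
    "c": lambda mod: f"tilelang_cpp({mod})",
}


def device_codegen_without_compile(mod: str, target_kind: str) -> str:
    """Dispatch to the registered builder, raising for unsupported targets."""
    try:
        builder = _CODEGEN_WITHOUT_COMPILE[target_kind]
    except KeyError:
        raise ValueError(f"Target {target_kind} is not supported") from None
    return builder(mod)
```

Adding C support is then a one-line registry entry, which is why the commit note above is so small.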

* [Enhancement] Implement auto-clear cache feature based on environment variable

* Added TILELANG_CLEAR_CACHE environment variable to control cache clearing.
* Updated CI workflow to set TILELANG_CLEAR_CACHE during testing.
* Modified cache initialization to clear cache if TILELANG_CLEAR_CACHE is set to true.
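The three bullets above can be sketched as a small, self-contained function (the helper name `maybe_clear_cache` and the set of accepted truthy values are assumptions; only the `TILELANG_CLEAR_CACHE` variable name comes from the commit):

```python
import os
import shutil


def maybe_clear_cache(cache_dir: str) -> bool:
    """Clear cache_dir when TILELANG_CLEAR_CACHE is set to a truthy value.

    Returns True if the cache was cleared. A sketch of the cache
    initialization behavior described above.
    """
    flag = os.environ.get("TILELANG_CLEAR_CACHE", "0").lower()
    if flag in ("1", "true", "yes", "on"):
        shutil.rmtree(cache_dir, ignore_errors=True)  # drop stale entries
        os.makedirs(cache_dir, exist_ok=True)         # recreate an empty cache
        return True
    return False
```

Setting the variable in CI (as the second bullet describes) guarantees each test run starts from an empty kernel cache.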

* [Refactor] Update kernel invocation and import paths in tests and cache

* Changed kernel invocation in `test_tilelang_kernel_dequantize_gemm.py` to return the result.
* Updated import statements in `test_tilelang_kernel_int4_gemm_mma.py` to use `bitblas` instead of `tilelang`.
* Refactored paths for artifact and parameters in `kernel_cache.py` for better maintainability.

* [Refactor] Clean up whitespace and improve code formatting in kernel_cache.py

* Removed unnecessary blank lines and adjusted spacing for better readability in the KernelCache class.
* Enhanced overall code formatting to align with project standards.

* [Enhancement] Add bfloat16 test case and improve kernel caching logic

* Introduced a new test case for bfloat16 matrix multiplication in `test_tilelang_kernel_gemm_mma_intrinsic.py`.
* Updated `KernelCache` to handle multiple kernel source files and improve error handling during saving and loading.
* Refactored `JITKernel` to support instantiation from a database, enhancing flexibility in kernel management.
* Adjusted `CtypesKernelAdapter` and `CythonKernelAdapter` to utilize the new kernel loading mechanism from the database.
* Improved code formatting and readability across several files.
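The caching changes above (multiple kernel source files per entry, plus loading kernels back from the store) follow a common shape: hash the kernel configuration into a directory key, write each source file plus a parameters file into it, and read them back on a hit. A sketch under stated assumptions (the on-disk layout, file names, and `save`/`load` signatures are illustrative, not tilelang's actual format):

```python
import hashlib
import json
import os


class KernelCache:
    """Toy cache storing several kernel source files per entry."""

    def __init__(self, root: str):
        self.root = root
        os.makedirs(root, exist_ok=True)

    def _entry_dir(self, key: dict) -> str:
        # Stable hash of the kernel config selects the entry directory.
        digest = hashlib.sha256(json.dumps(key, sort_keys=True).encode()).hexdigest()
        return os.path.join(self.root, digest)

    def save(self, key: dict, sources: dict, params: dict):
        entry = self._entry_dir(key)
        os.makedirs(entry, exist_ok=True)
        for filename, code in sources.items():  # e.g. host and device sources
            with open(os.path.join(entry, filename), "w") as f:
                f.write(code)
        with open(os.path.join(entry, "params.json"), "w") as f:
            json.dump(params, f)

    def load(self, key: dict):
        entry = self._entry_dir(key)
        if not os.path.isdir(entry):
            return None  # cache miss rather than an exception
        sources = {}
        for filename in os.listdir(entry):
            if filename == "params.json":
                continue
            with open(os.path.join(entry, filename)) as f:
                sources[filename] = f.read()
        with open(os.path.join(entry, "params.json")) as f:
            params = json.load(f)
        return sources, params
```

Keeping each entry as a directory is what lets one cached kernel carry several source files (host and device) instead of a single blob, which matches the second bullet above.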

* lint fix

* Update bfloat16 matrix multiplication test case to use larger dimensions for improved coverage
parent 43bd9d3e
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Policy for cuda core schedule""" """Policy for cuda core schedule"""
import functools import functools
import math import math
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Policy for tensorcore schedule""" """Policy for tensorcore schedule"""
import tvm import tvm
from typing import Dict, List, Tuple, Optional from typing import Dict, List, Tuple, Optional
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Rasteration Plan For L2 Cache Locality""" """Rasteration Plan For L2 Cache Locality"""
from typing import List from typing import List
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .tir import get_analyzer_by_tir  # noqa: F401
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from collections import OrderedDict
from typing import Dict, List
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import Dict, List, Tuple, Set, Mapping
from tvm.tir.schedule.schedule import BlockRV
from tvm.ir import structural_equal
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Template for the TileLang Carver.""" """Template for the TileLang Carver."""
from .base import BaseTemplate # noqa: F401 from .base import BaseTemplate # noqa: F401
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Import necessary modules and classes
from abc import ABC, abstractmethod  # For defining abstract base classes
from dataclasses import dataclass, field  # For defining data classes
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Import necessary modules
from dataclasses import dataclass  # Used for defining data classes
from .base import BaseTemplate  # Importing the base class for templates
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from dataclasses import dataclass
from .base import BaseTemplate
from tvm import te
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from dataclasses import dataclass
from .base import BaseTemplate
from tvm import te
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from dataclasses import dataclass
from .base import BaseTemplate
from tvm import te
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from dataclasses import dataclass
from .base import BaseTemplate
from tvm import te
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import List, Optional, Union
from tvm import tir, IRModule
from tvm.tir import PrimFunc
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .transform_kind import TransformKind  # noqa: F401
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Copied from bitblas
from enum import IntEnum
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .nvcc import compile_cuda  # noqa: F401
from .hipcc import compile_hip  # noqa: F401
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# pylint: disable=invalid-name
"""Utility to invoke hipcc compiler in the system"""
# File is copied from a modified version of hipcc.py to support
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# pylint: disable=invalid-name
# modified from apache tvm python/tvm/contrib/nvcc.py
"""Utility to invoke nvcc compiler in the system"""
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""The compiler for TL programs.""" """The compiler for TL programs."""
import os import os
@@ -10,7 +8,7 @@ from tvm import tir
from tvm.ir import CallingConv
from tvm.target import Target
from tilelang.contrib import hipcc, nvcc
-from tilelang.engine.param import KernelParam
+from tilelang.engine.param import KernelParam, CompiledArtifact
from tilelang.utils.target import determine_target
from tilelang.engine.phase import (
    LowerAndLegalize,
@@ -136,12 +134,76 @@ def canon_target_host(target: Union[str, Target], target_host: Optional[Union[st
    return target_host
+def host_codegen(host_mod: tvm.IRModule, target_host: Target) -> tvm.IRModule:
+    host_mod = tir.transform.BindTarget(target_host)(host_mod)
+    host_mod = tir.transform.FP8StorageLegalize()(host_mod)
+    host_mod = tir.transform.BF16StorageLegalize()(host_mod)
+    host_mod = tir.transform.LowerTVMBuiltin()(host_mod)
+    host_mod = tir.transform.LowerCustomDatatypes()(host_mod)
+    host_mod = tir.transform.LowerIntrin()(host_mod)
+    host_mod = tir.transform.LowerDeviceStorageAccessInfo()(host_mod)
+    host_mod = tir.transform.CombineContextCall()(host_mod)
+    if target_host.kind.name == "llvm":
+        host_mod = tvm._ffi.get_global_func("target.build.llvm")(host_mod, target_host)
+    elif target_host.kind.name == "c":
+        host_mod = tvm._ffi.get_global_func("target.build.c")(host_mod, target_host)
+    else:
+        raise ValueError(f"Target host {target_host.kind.name} is not supported")
+    return host_mod
+def device_codegen(device_mod: tvm.IRModule, target: Target) -> tvm.IRModule:
+    device_mod = tir.transform.LowerDeviceStorageAccessInfo()(device_mod)
+    device_mod = tir.transform.LowerIntrin()(device_mod)
+    device_mod = tir.transform.Simplify()(device_mod)
+    if target.kind.name == "cuda":
+        device_mod = tvm._ffi.get_global_func("target.build.tilelang_cuda")(device_mod, target)
+    elif target.kind.name == "hip":
+        device_mod = tvm._ffi.get_global_func("target.build.tilelang_hip")(device_mod, target)
+    else:
+        raise ValueError(f"Target {target.kind.name} is not supported")
+    return device_mod
+def device_codegen_without_compile(device_mod: tvm.IRModule, target: Target) -> tvm.IRModule:
+    device_mod = tir.transform.LowerDeviceStorageAccessInfo()(device_mod)
+    device_mod = tir.transform.LowerIntrin()(device_mod)
+    device_mod = tir.transform.Simplify()(device_mod)
+    if target.kind.name == "cuda":
+        device_mod = tvm._ffi.get_global_func("target.build.tilelang_cuda_without_compile")(
+            device_mod, target)
+    elif target.kind.name == "hip":
+        device_mod = tvm._ffi.get_global_func("target.build.tilelang_hip_without_compile")(
+            device_mod, target)
+    elif target.kind.name == "c":
+        device_mod = tvm._ffi.get_global_func("target.build.tilelang_cpp")(device_mod, target)
+    elif target.kind.name == "llvm":
+        device_mod = tvm._ffi.get_global_func("target.build.llvm")(device_mod, target)
+    elif target.kind.name == "webgpu":
+        device_mod = tvm._ffi.get_global_func("target.build.tilelang_webgpu")(device_mod, target)
+    else:
+        raise ValueError(f"Target {target.kind.name} is not supported")
+    return device_mod
def lower(
    func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
    target: Union[str, Target] = "auto",
    target_host: Optional[Union[str, Target]] = None,
    runtime_only=False,
-):
+    enable_host_codegen=False,
+    enable_device_compile=False,
+) -> CompiledArtifact:
+    '''
+    enable_host_codegen: whether to enable host codegen; default is False, as we have our
+    own host codegen implementation in jit.
+    enable_device_compile: whether to enable device compilation; default is False, as we have
+    our own device codegen implementation in jit.
+    '''
    mod = func_or_mod
    if isinstance(func_or_mod, tir.PrimFunc):
@@ -167,56 +229,14 @@ def lower(
    mod = OptimizeForTarget(mod, target)

    host_mod = tir.transform.Filter(_is_host_call)(mod)
-    host_mod = tir.transform.BindTarget(target_host)(host_mod)
-    host_mod = tir.transform.FP8StorageLegalize()(host_mod)
-    host_mod = tir.transform.BF16StorageLegalize()(host_mod)
-    host_mod = tir.transform.LowerTVMBuiltin()(host_mod)
-    host_mod = tir.transform.LowerCustomDatatypes()(host_mod)
-    host_mod = tir.transform.LowerIntrin()(host_mod)
-    host_mod = tir.transform.LowerDeviceStorageAccessInfo()(host_mod)
-    host_mod = tir.transform.CombineContextCall()(host_mod)
-    if target_host.kind.name == "llvm":
-        host_mod = tvm._ffi.get_global_func("target.build.llvm")(host_mod, target_host)
-    elif target_host.kind.name == "c":
-        if is_cpu_device_backend(target):
-            host_mod = tvm._ffi.get_global_func("target.build.tilelang_cpp")(host_mod, target_host)
-        else:
-            host_mod = tvm._ffi.get_global_func("target.build.c")(host_mod, target_host)
-    else:
-        raise ValueError(f"Target host {target_host.kind.name} is not supported")
    device_mod = tir.transform.Filter(_is_device_call)(mod)
-    device_mod = tir.transform.LowerDeviceStorageAccessInfo()(device_mod)
-    device_mod = tir.transform.LowerIntrin()(device_mod)
-    device_mod = tir.transform.Simplify()(device_mod)
-    if target.kind.name == "cuda":
-        # Debug comments to get the code
-        # code = tvm._ffi.get_global_func("target.build.tl_debug_codegen")(device_mod, target)
-        device_mod = tvm._ffi.get_global_func("target.build.tilelang_cuda")(device_mod, target)
-    elif target.kind.name == "hip":
-        device_mod = tvm._ffi.get_global_func("target.build.tilelang_hip")(device_mod, target)
-    elif target.kind.name == "c":
-        device_mod = tvm._ffi.get_global_func("target.build.tilelang_cpp")(device_mod, target)
-    elif target.kind.name == "llvm":
-        device_mod = tvm._ffi.get_global_func("target.build.llvm")(device_mod, target)
-    elif target.kind.name == "webgpu":
-        device_mod = tvm._ffi.get_global_func("target.build.tilelang_webgpu")(device_mod, target)
-    else:
-        raise ValueError(f"Target {target.kind.name} is not supported")
+    codegen_mod = device_codegen(
+        device_mod, target) if enable_device_compile else device_codegen_without_compile(
+            device_mod, target)
-    host_mod.import_module(device_mod)
+    if enable_host_codegen:
+        host_mod = host_codegen(host_mod, target_host)
+        host_mod.import_module(codegen_mod)
if target_host.kind.name == "c": return CompiledArtifact(host_mod, device_mod, params, codegen_mod.get_source())
# cpu host should be recompiled
# TODO(lei): this is a hack to make the C host backend work
temp_dir = tvm.contrib.utils.tempdir()
tmp_lib_path = temp_dir.relpath("tmp.so")
host_mod.export_library(tmp_lib_path)
host_mod = tvm.runtime.load_module(tmp_lib_path)
if runtime_only is True:
return host_mod
else:
return host_mod, params