"test/validate.cpp" did not exist on "dd465fab2cdbea287fc59f646166db2b180b5b58"
Commit f2e99180 authored by Lei Wang, committed by LeiWang1999

[Refactor] Phase Out LLVM Dependency by Making It Optional (#247)

* remove llvm build

* [Refactor] Update kernel compilation and profiling in examples

- Replaced `tilelang.lower` with `tilelang.compile` in multiple example scripts to streamline kernel compilation.
- Updated profiling calls to utilize the new `get_profiler` method, enhancing performance measurement consistency.
- Adjusted assertions and benchmarking methods to align with the new profiling structure across various examples, ensuring correctness and clarity in performance evaluations.
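As a rough illustration of the new pattern (the kernel definition, the benchmarking call, and the printed units are assumptions for this sketch, not the exact example code):

```python
import tilelang

# `matmul` stands in for any tilelang kernel function used in the examples.
kernel = tilelang.compile(matmul)     # previously: tilelang.lower(matmul)
profiler = kernel.get_profiler()      # new unified profiling entry point
latency = profiler.do_bench()         # assumed benchmarking helper on the profiler
print(f"kernel latency: {latency} ms")
```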

* lint fix

* License Update

* [Refactor] Improve code formatting and documentation in CUDA header and HIP runtime files

- Adjusted formatting in `cuda.h` for better readability, including alignment of comments and struct fields.
- Cleaned up whitespace and improved comment clarity in `rt_mod_hip.cc` to enhance code maintainability.

* [Refactor] Enhance formatting and clarity in CUDA header and HIP runtime files

- Improved comment alignment and readability in `cuda.h`.
- Cleaned up whitespace and formatting in `rt_mod_hip.cc` to enhance maintainability.

* lint fix

* lint fix

* lint fix

* lint fix

* fix

* License update

* [Enhancement] Update JITKernel to use artifact for kernel source

- Assigned the generated artifact to `self.artifact` for better management.
- Updated kernel source references to use `artifact.kernel_source` for consistency in execution backend handling.
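A simplified sketch of the described pattern (not the actual `JITKernel` implementation; everything except `self.artifact` and `artifact.kernel_source` is illustrative):

```python
from tilelang.engine import lower  # assumed import path for the lower() shown in the diff below


class JITKernel:  # sketch only
    def __init__(self, func, target="auto"):
        artifact = lower(func, target=target)  # returns a CompiledArtifact (see the lower() diff below)
        self.artifact = artifact               # keep the artifact for later management
        # Execution backends now read the generated source from the artifact.
        self.kernel_source = artifact.kernel_source
```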

* lint fix

* Add @tilelang.testing.requires_llvm decorator to vectorization tests
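Illustrative usage of the decorator (the test body is hypothetical):

```python
import tilelang.testing


@tilelang.testing.requires_llvm
def test_vectorize_add():
    # Skipped automatically when LLVM support is not available in the build.
    ...
```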

* Enhance setup.py and env.py for library management

- Added functionality to remove original files after copying in CMakeBuild.
- Updated TVM_LIBRARY_PATH in env.py to include the PyPI build library path for better integration.
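A hedged sketch of the kind of change described; the actual `env.py` layout may differ, and the `lib` directory relative to the package root is an assumption:

```python
import os

# Append the PyPI build library directory so the bundled runtime libraries are found;
# the "lib" directory name relative to the package root is an assumption.
_pypi_build_lib = os.path.join(os.path.dirname(os.path.abspath(__file__)), "lib")
TVM_LIBRARY_PATH = os.pathsep.join(
    p for p in (os.environ.get("TVM_LIBRARY_PATH", ""), _pypi_build_lib) if p)
```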

* Refactor TVM_LIBRARY_PATH assignment for improved readability in env.py

* Refactor CMakeBuild file handling in setup.py

- Added a check to ensure the target library directory exists before copying .so files.
- Improved the logic for creating the target directory and copying files to enhance robustness.
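A hedged sketch of the copy logic described above, using only the standard library (directory names and the helper name are placeholders):

```python
import os
import shutil


def copy_shared_libs(build_lib_dir: str, target_lib_dir: str) -> None:
    # Ensure the target library directory exists before copying .so files.
    os.makedirs(target_lib_dir, exist_ok=True)
    for name in os.listdir(build_lib_dir):
        if name.endswith(".so"):
            src = os.path.join(build_lib_dir, name)
            shutil.copy2(src, os.path.join(target_lib_dir, name))
            os.remove(src)  # remove the original after copying, as done in CMakeBuild
```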

* bugfix

* Rename `BuildTLDebug` to `BuildTileLangCUDAWithoutCompile` and update its registration. Add the `@tilelang.testing.requires_llvm` decorator to multiple tests that require LLVM.

* lint fix

* Enhance TileLang code generation by adding support for device code generation without compilation. Update the `host_codegen` and `device_codegen` functions with new transformations and register `tilelang_hip_without_compile`. Refactor the JIT kernel adapters to handle separate host and device modules, improving overall integration and flexibility.

* lint fix

* Add support for C target in device code generation

- Updated `device_codegen_without_compile` to include handling for the C target by registering the `tilelang_cpp` function.

* [Enhancement] Implement auto-clear cache feature based on environment variable

* Added TILELANG_CLEAR_CACHE environment variable to control cache clearing.
* Updated CI workflow to set TILELANG_CLEAR_CACHE during testing.
* Modified cache initialization to clear cache if TILELANG_CLEAR_CACHE is set to true.
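An illustrative sketch of the described behavior; only the environment variable name comes from the change, and the `clear_cache` helper is hypothetical:

```python
import os


def _env_truthy(name: str) -> bool:
    return os.environ.get(name, "").strip().lower() in ("1", "true", "yes", "on")


# During cache initialization: wipe the kernel cache when TILELANG_CLEAR_CACHE is set.
if _env_truthy("TILELANG_CLEAR_CACHE"):
    clear_cache()  # hypothetical helper that removes the on-disk cache directory
```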

* [Refactor] Update kernel invocation and import paths in tests and cache

* Changed kernel invocation in `test_tilelang_kernel_dequantize_gemm.py` to return the result.
* Updated import statements in `test_tilelang_kernel_int4_gemm_mma.py` to use `bitblas` instead of `tilelang`.
* Refactored paths for artifact and parameters in `kernel_cache.py` for better maintainability.

* [Refactor] Clean up whitespace and improve code formatting in kernel_cache.py

* Removed unnecessary blank lines and adjusted spacing for better readability in the KernelCache class.
* Enhanced overall code formatting to align with project standards.

* [Enhancement] Add bfloat16 test case and improve kernel caching logic

* Introduced a new test case for bfloat16 matrix multiplication in `test_tilelang_kernel_gemm_mma_intrinsic.py`.
* Updated `KernelCache` to handle multiple kernel source files and improve error handling during saving and loading.
* Refactored `JITKernel` to support instantiation from a database, enhancing flexibility in kernel management.
* Adjusted `CtypesKernelAdapter` and `CythonKernelAdapter` to utilize the new kernel loading mechanism from the database.
* Improved code formatting and readability across several files.

* lint fix

* Update bfloat16 matrix multiplication test case to use larger dimensions for improved coverage
parent 43bd9d3e
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Policy for cuda core schedule"""
import functools
import math
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Policy for tensorcore schedule"""
import tvm
from typing import Dict, List, Tuple, Optional
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Rasteration Plan For L2 Cache Locality"""
from typing import List
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .tir import get_analyzer_by_tir # noqa: F401
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from collections import OrderedDict
from typing import Dict, List
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import Dict, List, Tuple, Set, Mapping
from tvm.tir.schedule.schedule import BlockRV
from tvm.ir import structural_equal
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Template for the TileLang Carver."""
from .base import BaseTemplate # noqa: F401
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Import necessary modules and classes
from abc import ABC, abstractmethod # For defining abstract base classes
from dataclasses import dataclass, field # For defining data classes
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Import necessary modules
from dataclasses import dataclass # Used for defining data classes
from .base import BaseTemplate # Importing the base class for templates
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from dataclasses import dataclass
from .base import BaseTemplate
from tvm import te
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from dataclasses import dataclass
from .base import BaseTemplate
from tvm import te
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from dataclasses import dataclass
from .base import BaseTemplate
from tvm import te
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from dataclasses import dataclass
from .base import BaseTemplate
from tvm import te
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import List, Optional, Union
from tvm import tir, IRModule
from tvm.tir import PrimFunc
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .transform_kind import TransformKind # noqa: F401
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Copied from bitblas
from enum import IntEnum
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .nvcc import compile_cuda # noqa: F401
from .hipcc import compile_hip # noqa: F401
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# pylint: disable=invalid-name
"""Utility to invoke hipcc compiler in the system"""
# File is copied from a modified version of hipcc.py to support
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# pylint: disable=invalid-name
# modified from apache tvm python/tvm/contrib/nvcc.py
"""Utility to invoke nvcc compiler in the system"""
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""The compiler for TL programs."""
import os
......
@@ -10,7 +8,7 @@
from tvm import tir
from tvm.ir import CallingConv
from tvm.target import Target
from tilelang.contrib import hipcc, nvcc
-from tilelang.engine.param import KernelParam
+from tilelang.engine.param import KernelParam, CompiledArtifact
from tilelang.utils.target import determine_target
from tilelang.engine.phase import (
    LowerAndLegalize,
......
@@ -136,12 +134,76 @@ def canon_target_host(target: Union[str, Target], target_host: Optional[Union[st
    return target_host

+def host_codegen(host_mod: tvm.IRModule, target_host: Target) -> tvm.IRModule:
+    host_mod = tir.transform.BindTarget(target_host)(host_mod)
+    host_mod = tir.transform.FP8StorageLegalize()(host_mod)
+    host_mod = tir.transform.BF16StorageLegalize()(host_mod)
+    host_mod = tir.transform.LowerTVMBuiltin()(host_mod)
+    host_mod = tir.transform.LowerCustomDatatypes()(host_mod)
+    host_mod = tir.transform.LowerIntrin()(host_mod)
+    host_mod = tir.transform.LowerDeviceStorageAccessInfo()(host_mod)
+    host_mod = tir.transform.CombineContextCall()(host_mod)
+    if target_host.kind.name == "llvm":
+        host_mod = tvm._ffi.get_global_func("target.build.llvm")(host_mod, target_host)
+    elif target_host.kind.name == "c":
+        host_mod = tvm._ffi.get_global_func("target.build.c")(host_mod, target_host)
+    else:
+        raise ValueError(f"Target host {target_host.kind.name} is not supported")
+    return host_mod

+def device_codegen(device_mod: tvm.IRModule, target: Target) -> tvm.IRModule:
+    device_mod = tir.transform.LowerDeviceStorageAccessInfo()(device_mod)
+    device_mod = tir.transform.LowerIntrin()(device_mod)
+    device_mod = tir.transform.Simplify()(device_mod)
+    if target.kind.name == "cuda":
+        device_mod = tvm._ffi.get_global_func("target.build.tilelang_cuda")(device_mod, target)
+    elif target.kind.name == "hip":
+        device_mod = tvm._ffi.get_global_func("target.build.tilelang_hip")(device_mod, target)
+    else:
+        raise ValueError(f"Target {target.kind.name} is not supported")
+    return device_mod

+def device_codegen_without_compile(device_mod: tvm.IRModule, target: Target) -> tvm.IRModule:
+    device_mod = tir.transform.LowerDeviceStorageAccessInfo()(device_mod)
+    device_mod = tir.transform.LowerIntrin()(device_mod)
+    device_mod = tir.transform.Simplify()(device_mod)
+    if target.kind.name == "cuda":
+        device_mod = tvm._ffi.get_global_func("target.build.tilelang_cuda_without_compile")(
+            device_mod, target)
+    elif target.kind.name == "hip":
+        device_mod = tvm._ffi.get_global_func("target.build.tilelang_hip_without_compile")(
+            device_mod, target)
+    elif target.kind.name == "c":
+        device_mod = tvm._ffi.get_global_func("target.build.tilelang_cpp")(device_mod, target)
+    elif target.kind.name == "llvm":
+        device_mod = tvm._ffi.get_global_func("target.build.llvm")(device_mod, target)
+    elif target.kind.name == "webgpu":
+        device_mod = tvm._ffi.get_global_func("target.build.tilelang_webgpu")(device_mod, target)
+    else:
+        raise ValueError(f"Target {target.kind.name} is not supported")
+    return device_mod

def lower(
    func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
    target: Union[str, Target] = "auto",
    target_host: Optional[Union[str, Target]] = None,
    runtime_only=False,
-):
+    enable_host_codegen=False,
+    enable_device_compile=False,
+) -> CompiledArtifact:
+    '''
+    enable_host_codegen: whether to enable host codegen, default is False, as we have our
+    own host codegen implementation in jit.
+    enable_device_compile: whether to enable device codegen, default is False, as we have our
+    own device codegen implementation in jit.
+    '''
    mod = func_or_mod
    if isinstance(func_or_mod, tir.PrimFunc):
......
@@ -167,56 +229,14 @@ def lower(
    mod = OptimizeForTarget(mod, target)

    host_mod = tir.transform.Filter(_is_host_call)(mod)
-    host_mod = tir.transform.BindTarget(target_host)(host_mod)
-    host_mod = tir.transform.FP8StorageLegalize()(host_mod)
-    host_mod = tir.transform.BF16StorageLegalize()(host_mod)
-    host_mod = tir.transform.LowerTVMBuiltin()(host_mod)
-    host_mod = tir.transform.LowerCustomDatatypes()(host_mod)
-    host_mod = tir.transform.LowerIntrin()(host_mod)
-    host_mod = tir.transform.LowerDeviceStorageAccessInfo()(host_mod)
-    host_mod = tir.transform.CombineContextCall()(host_mod)
-    if target_host.kind.name == "llvm":
-        host_mod = tvm._ffi.get_global_func("target.build.llvm")(host_mod, target_host)
-    elif target_host.kind.name == "c":
-        if is_cpu_device_backend(target):
-            host_mod = tvm._ffi.get_global_func("target.build.tilelang_cpp")(host_mod, target_host)
-        else:
-            host_mod = tvm._ffi.get_global_func("target.build.c")(host_mod, target_host)
-    else:
-        raise ValueError(f"Target host {target_host.kind.name} is not supported")

    device_mod = tir.transform.Filter(_is_device_call)(mod)
-    device_mod = tir.transform.LowerDeviceStorageAccessInfo()(device_mod)
-    device_mod = tir.transform.LowerIntrin()(device_mod)
-    device_mod = tir.transform.Simplify()(device_mod)
-    if target.kind.name == "cuda":
-        # Debug comments to get the code
-        # code = tvm._ffi.get_global_func("target.build.tl_debug_codegen")(device_mod, target)
-        device_mod = tvm._ffi.get_global_func("target.build.tilelang_cuda")(device_mod, target)
-    elif target.kind.name == "hip":
-        device_mod = tvm._ffi.get_global_func("target.build.tilelang_hip")(device_mod, target)
-    elif target.kind.name == "c":
-        device_mod = tvm._ffi.get_global_func("target.build.tilelang_cpp")(device_mod, target)
-    elif target.kind.name == "llvm":
-        device_mod = tvm._ffi.get_global_func("target.build.llvm")(device_mod, target)
-    elif target.kind.name == "webgpu":
-        device_mod = tvm._ffi.get_global_func("target.build.tilelang_webgpu")(device_mod, target)
-    else:
-        raise ValueError(f"Target {target.kind.name} is not supported")
+    codegen_mod = device_codegen(
+        device_mod, target) if enable_device_compile else device_codegen_without_compile(
+            device_mod, target)

-    host_mod.import_module(device_mod)
+    if enable_host_codegen:
+        host_mod = host_codegen(host_mod, target_host)
+        host_mod.import_module(codegen_mod)

-    if target_host.kind.name == "c":
-        # cpu host should be recompiled
-        # TODO(lei): this is a hack to make the C host backend work
-        temp_dir = tvm.contrib.utils.tempdir()
-        tmp_lib_path = temp_dir.relpath("tmp.so")
-        host_mod.export_library(tmp_lib_path)
-        host_mod = tvm.runtime.load_module(tmp_lib_path)
-    if runtime_only is True:
-        return host_mod
-    else:
-        return host_mod, params
+    return CompiledArtifact(host_mod, device_mod, params, codegen_mod.get_source())
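A hedged usage sketch of the refactored entry point: `lower`, `CompiledArtifact`, and its constructor arguments are taken from the diff above, while the attribute names and the input function are assumptions.

```python
# Hypothetical consumption of the new return value; attribute names are assumed to
# mirror the CompiledArtifact constructor arguments shown above.
artifact = lower(my_prim_func, target="cuda")   # my_prim_func: any tir.PrimFunc
device_source = artifact.kernel_source          # source string from codegen_mod.get_source()
host_ir = artifact.host_mod                     # host IRModule, left uncompiled for the JIT
```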