"src/vscode:/vscode.git/clone" did not exist on "f4a828f6ba004f4d1165e6d46ac8b42e25f736fd"
Unverified Commit d764dca8 authored by Wenhao Xie, committed by GitHub

[Enhancement] Add compile_flags parameter to JIT kernel and adapter classes for improved compilation control (#656)

* [Enhancement] Add compile_flags parameter to JIT kernel and adapter classes for improved compilation control

* lint fix

* upd

* lint fix

* fix typo

* update typing

* update the use case of compile flags

* ci fix

* fix

* Fix CI workflow to correctly activate virtual environment from shared cache directory

* use local cache

* fix

* fix

* fix

---------
Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
parent 267d9b3b
name: CI
on: [pull_request]
env:
PYTHON_VERSION: '3.9'
VENV_DIR: ${{ runner.tool_cache }}/tilelang_ci
VENV_DIR: tilelang_ci
jobs:
format-check:
......@@ -21,34 +20,33 @@ jobs:
with:
python-version: ${{ env.PYTHON_VERSION }}
- name: Show CI Worker Info
run: echo "tool_cache=${{ runner.tool_cache }}"
- name: Cache virtual environment
id: cache-venv
uses: actions/cache@v4
with:
path: ${{ env.VENV_DIR }}
key: ${{ runner.os }}-py${{ env.PYTHON_VERSION }}-venv-${{ hashFiles('**/requirements-dev.txt', '**/requirements-test.txt') }}
- name: Create / ensure virtual environment
if: steps.cache-venv.outputs.cache-hit != 'true'
- name: Ensure venv (local & persistent)
run: |
python -m venv ${{ env.VENV_DIR }}
source ${{ env.VENV_DIR }}/bin/activate
python -m pip install --upgrade pip --no-user
if [ -f requirements-test.txt ]; then
PIP_NO_BUILD_ISOLATION=1 \
python -m pip install -r requirements-test.txt --no-user
set -e
REQS_HASH=$(cat requirements-test.txt 2>/dev/null || true)
MARKER="${{ runner.tool_cache }}/.venv_marker_${{ env.PYTHON_VERSION }}_${REQS_HASH:0:8}"
if [[ -f "$MARKER" ]] && [[ -f "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate" ]]; then
echo "venv exists and hash matches – reuse it"
else
echo "venv stale or missing – recreating"
rm -rf "${{ runner.tool_cache }}/${{ env.VENV_DIR }}" "$MARKER"
python -m venv "${{ runner.tool_cache }}/${{ env.VENV_DIR }}"
# shellcheck source=/dev/null
source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate"
python -m pip install --upgrade pip --no-user
[[ -f requirements-test.txt ]] && \
PIP_NO_BUILD_ISOLATION=1 pip install -r requirements-test.txt --no-user
pip install . --no-user
touch "$MARKER"
fi
python -m pip install . --no-user
- name: Update submodules recursively
- name: Update submodules
run: git submodule update --init --recursive
- name: Run format check
run: |
source ${{ env.VENV_DIR }}/bin/activate
source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate"
./format.sh
build-test:
......@@ -66,40 +64,41 @@ jobs:
with:
python-version: ${{ env.PYTHON_VERSION }}
- name: Cache virtual environment
id: cache-venv
uses: actions/cache@v4
with:
path: ${{ env.VENV_DIR }}
key: ${{ runner.os }}-py${{ env.PYTHON_VERSION }}-venv-${{ hashFiles('**/requirements-dev.txt', '**/requirements-test.txt') }}
- name: Create / ensure virtual environment
if: steps.cache-venv.outputs.cache-hit != 'true'
- name: Ensure venv (local & persistent)
run: |
python -m venv ${{ env.VENV_DIR }}
source ${{ env.VENV_DIR }}/bin/activate
python -m pip install --upgrade pip --no-user
if [ -f requirements-test.txt ]; then
PIP_NO_BUILD_ISOLATION=1 \
python -m pip install -r requirements-test.txt --no-user
set -e
REQS_HASH=$(cat requirements-test.txt 2>/dev/null || true)
MARKER="${{ runner.tool_cache }}/.venv_marker_${{ env.PYTHON_VERSION }}_${REQS_HASH:0:8}"
if [[ -f "$MARKER" ]] && [[ -f "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate" ]]; then
echo "venv exists and hash matches – reuse it"
else
echo "venv stale or missing – recreating"
rm -rf "${{ runner.tool_cache }}/${{ env.VENV_DIR }}" "$MARKER"
python -m venv "${{ runner.tool_cache }}/${{ env.VENV_DIR }}"
source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate"
python -m pip install --upgrade pip --no-user
[[ -f requirements-test.txt ]] && \
PIP_NO_BUILD_ISOLATION=1 pip install -r requirements-test.txt --no-user
pip install . --no-user
touch "$MARKER"
fi
python -m pip install . --no-user
- name: Install project in wheel mode
- name: Install project (wheel form)
run: |
source ${{ env.VENV_DIR }}/bin/activate
python -m pip install . --no-user
source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate"
pip install . --no-user
- name: Run examples
run: |
source ${{ env.VENV_DIR }}/bin/activate
source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate"
cd examples
unset PYTHONPATH
python -m pytest -n 4 **/test*.py
- name: Run tests
run: |
source ${{ env.VENV_DIR }}/bin/activate
source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate"
cd testing/python
unset PYTHONPATH
python -m pytest -n 4
python -m pytest -n 4
\ No newline at end of file
import tilelang
import tilelang.language as T
# @tilelang.jit(compile_flags=["-O3", "--use_fast_math", "--expt-relaxed-constexpr"])
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
T.copy(A[by * block_M, ko * block_K], A_shared)
T.copy(B[ko * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
M = 1024
N = 1024
K = 1024
block_M = 128
block_N = 128
block_K = 32
func = matmul(M, N, K, block_M, block_N, block_K)
jit_kernel = tilelang.compile(
func, out_idx=[2], target="cuda", compile_flags="-O3 --use_fast_math --expt-relaxed-constexpr")
# or jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", compile_flags=["-O3", "--use_fast_math", "--expt-relaxed-constexpr"])
# or jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", compile_flags=["-O3 --use_fast_math --expt-relaxed-constexpr"])
import torch
a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(K, N, device="cuda", dtype=torch.float16)
c = jit_kernel(a, b)
print(c)
ref_c = a @ b
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel output matches PyTorch reference.")
......@@ -20,6 +20,7 @@ def cached(
execution_backend: Optional[Literal["dlpack", "ctypes", "cython", "nvrtc"]] = "cython",
verbose: Optional[bool] = False,
pass_configs: Optional[dict] = None,
compile_flags: Optional[List[str]] = None,
) -> JITKernel:
"""
Caches and reuses compiled kernels (using KernelCache class).
......@@ -33,7 +34,7 @@ def cached(
execution_backend=execution_backend,
verbose=verbose,
pass_configs=pass_configs,
)
compile_flags=compile_flags)
def clear_cache():
......
......@@ -117,6 +117,7 @@ class KernelCache:
execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
verbose: bool = False,
pass_configs: dict = None,
compile_flags: Optional[List[str]] = None,
) -> JITKernel:
"""
Caches and reuses compiled kernels to avoid redundant compilation.
......@@ -140,6 +141,7 @@ class KernelCache:
target_host=target_host,
verbose=verbose,
pass_configs=pass_configs,
compile_flags=compile_flags,
)
key = self._generate_key(
......
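
# The cache layer gains the same parameter, so kernels built with specific flags can be
# requested through it directly. A minimal sketch follows, assuming cached is importable
# from tilelang.cache and accepts target the way compile() forwards it; cached_kernel and
# c3 are hypothetical names, and func, a, b reuse the example above.
from tilelang.cache import cached

cached_kernel = cached(
    func=func,        # PrimFunc produced by matmul(...) in the example above
    out_idx=[2],
    target="cuda",    # assumption: cached takes target just as compile() passes it through
    compile_flags=["-O3", "--use_fast_math", "--expt-relaxed-constexpr"],
)
c3 = cached_kernel(a, b)
# Repeating the call with the same function and flags reuses the compiled kernel
# instead of recompiling, which is the purpose of the KernelCache class above.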
......@@ -37,6 +37,7 @@ def compile(
target_host: Union[str, Target] = None,
verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None,
compile_flags: Optional[Union[List[str], str]] = None,
) -> JITKernel:
"""
Compile the given TileLang PrimFunc with TVM and build a JITKernel.
......@@ -66,7 +67,8 @@ def compile(
"tl.disable_safe_memory_legalize": bool, default: False
"""
assert isinstance(func, PrimFunc), f"target function must be a PrimFunc but got {type(func)}"
if isinstance(compile_flags, str):
compile_flags = [compile_flags]
return cached(
func=func,
out_idx=out_idx,
......@@ -75,6 +77,7 @@ def compile(
target_host=target_host,
verbose=verbose,
pass_configs=pass_configs,
compile_flags=compile_flags,
)
......@@ -87,6 +90,7 @@ class _JitImplementation:
verbose: bool
pass_configs: Optional[Dict[str, Any]]
debug_root_path: Optional[str]
compile_flags: Optional[List[str]]
def __init__(self,
out_idx: Any = None,
......@@ -95,7 +99,8 @@ class _JitImplementation:
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None,
debug_root_path: Optional[str] = None):
debug_root_path: Optional[str] = None,
compile_flags: Optional[List[str]] = None):
"""
Initializes the JIT compiler decorator.
......@@ -134,6 +139,7 @@ class _JitImplementation:
self.target_host = target_host
self.verbose = verbose
self.pass_configs = pass_configs
self.compile_flags = compile_flags
# Corrected debug_root_path handling
self.debug_root_path = debug_root_path
......@@ -176,6 +182,7 @@ class _JitImplementation:
'target_host': self.target_host,
'verbose': self.verbose,
'pass_configs': self.pass_configs,
'compile_flags': self.compile_flags,
}
return compile_args
......@@ -202,6 +209,7 @@ class _JitImplementation:
target_host=self.target_host,
verbose=self.verbose,
pass_configs=self.pass_configs,
compile_flags=self.compile_flags,
)
if self.debug_root_path:
......@@ -230,7 +238,8 @@ def jit( # This is the new public interface
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None,
debug_root_path: Optional[str] = None):
debug_root_path: Optional[str] = None,
compile_flags: Optional[Union[List[str], str]] = None):
"""
Just-In-Time (JIT) compiler decorator for TileLang functions.
......@@ -262,6 +271,9 @@ def jit( # This is the new public interface
Either a JIT-compiled wrapper around the input function, or a configured decorator
instance that can then be applied to a function.
"""
if isinstance(compile_flags, str):
compile_flags = [compile_flags]
if callable(func):
# Case 1: Used as @jit (func_or_out_idx is the function, others are defaults)
# Create a default _JitImplementation instance and apply it to the function.
......@@ -272,7 +284,8 @@ def jit( # This is the new public interface
execution_backend=execution_backend,
verbose=verbose,
pass_configs=pass_configs,
debug_root_path=debug_root_path)
debug_root_path=debug_root_path,
compile_flags=compile_flags)
return default_decorator(func)
elif isinstance(func, PrimFunc):
raise ValueError("Use tilelang.jit to decorate prim_func is not supported yet.")
......@@ -287,5 +300,6 @@ def jit( # This is the new public interface
execution_backend=execution_backend,
verbose=verbose,
pass_configs=pass_configs,
debug_root_path=debug_root_path)
debug_root_path=debug_root_path,
compile_flags=compile_flags)
return configured_decorator
......@@ -49,7 +49,8 @@ class CtypesKernelAdapter(BaseKernelAdapter):
device_mod: Optional[tvm.IRModule] = None,
kernel_global_source: Optional[str] = None,
verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None):
pass_configs: Optional[Dict[str, Any]] = None,
compile_flags: Optional[List[str]] = None):
"""Initialize the adapter with the given TIR function or module.
Args:
......@@ -89,6 +90,7 @@ class CtypesKernelAdapter(BaseKernelAdapter):
self.wrapper = TLWrapper(self.target)
self.lib_generator = LibraryGenerator(self.target)
self.lib_generator.assign_pass_configs(pass_configs)
self.lib_generator.assign_compile_flags(compile_flags)
self.wrapper.assign_optimized_module(self.ir_module)
self.wrapper.assign_pass_configs(pass_configs)
......@@ -112,7 +114,8 @@ class CtypesKernelAdapter(BaseKernelAdapter):
kernel_global_source: str,
kernel_lib_path: str,
verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None):
pass_configs: Optional[Dict[str, Any]] = None,
compile_flags: Optional[List[str]] = None):
adapter = cls.__new__(cls)
adapter.params = params
adapter.result_idx = adapter._legalize_result_idx(result_idx)
......@@ -145,6 +148,7 @@ class CtypesKernelAdapter(BaseKernelAdapter):
adapter.verbose = verbose
adapter.lib_generator = LibraryGenerator(adapter.target)
adapter.lib_generator.assign_pass_configs(pass_configs)
adapter.lib_generator.assign_compile_flags(compile_flags)
adapter.lib = adapter.lib_generator.load_lib(lib_path=kernel_lib_path)
adapter.lib.init()
......
......@@ -214,7 +214,8 @@ class CythonKernelAdapter(BaseKernelAdapter):
device_mod: Optional[tvm.IRModule] = None,
kernel_global_source: Optional[str] = None,
verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None):
pass_configs: Optional[Dict[str, Any]] = None,
compile_flags: Optional[List[str]] = None):
"""Initialize the adapter with the given TIR function or module.
Args:
......@@ -245,6 +246,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
self.wrapper = TLWrapper(self.target)
self.lib_generator = LibraryGenerator(self.target)
self.lib_generator.assign_pass_configs(pass_configs)
self.lib_generator.assign_compile_flags(compile_flags)
self.wrapper.assign_optimized_module(self.ir_module)
self.wrapper.assign_pass_configs(pass_configs)
......@@ -280,7 +282,8 @@ class CythonKernelAdapter(BaseKernelAdapter):
kernel_global_source: str,
kernel_lib_path: str,
verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None):
pass_configs: Optional[Dict[str, Any]] = None,
compile_flags: Optional[List[str]] = None):
adapter = cls.__new__(cls)
adapter.params = params
adapter.result_idx = adapter._legalize_result_idx(result_idx)
......@@ -305,6 +308,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
adapter.verbose = verbose
adapter.lib_generator = LibraryGenerator(adapter.target)
adapter.lib_generator.assign_pass_configs(pass_configs)
adapter.lib_generator.assign_compile_flags(compile_flags)
adapter.lib = adapter.lib_generator.load_lib(lib_path=kernel_lib_path)
adapter.lib.get_last_error.restype = ctypes.c_char_p
......
......@@ -5,7 +5,7 @@ import os
import os.path as osp
import subprocess
import tempfile
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, List
from tvm.target import Target
......@@ -36,6 +36,7 @@ class LibraryGenerator(object):
libpath: Optional[str] = None
lib_code: Optional[str] = None
pass_configs: Optional[Dict[str, Any]] = None
compile_flags: Optional[List[str]] = None
def __init__(self, target: Target):
self.target = target
......@@ -43,6 +44,11 @@ class LibraryGenerator(object):
def assign_pass_configs(self, pass_configs: Optional[Dict[str, Any]] = None):
self.pass_configs = pass_configs
def assign_compile_flags(self, compile_flags: Optional[List[str]] = None):
if compile_flags is None:
compile_flags = []
self.compile_flags = compile_flags
def update_lib_code(self, lib_code: str):
self.lib_code = lib_code
......@@ -75,7 +81,7 @@ class LibraryGenerator(object):
"-Xcudafe",
"--diag_suppress=177",
"--compiler-options",
"'-fPIC'",
"-fPIC",
"-lineinfo",
"--shared",
src.name,
......@@ -125,6 +131,12 @@ class LibraryGenerator(object):
command += [
"-I" + TILELANG_TEMPLATE_PATH,
]
if self.compile_flags:
command += [
item for flag in self.compile_flags for item in flag.split() if item not in command
]
command += ["-o", libpath]
src.write(self.lib_code)
......@@ -217,11 +229,15 @@ class PyLibraryGenerator(LibraryGenerator):
cuda_home = "/usr/local/cuda" if CUDA_HOME is None else CUDA_HOME
options = [f"-I{tl_template_path}", f"-I{cutlass_path}", f"-I{cuda_home}/include"]
if self.compile_flags:
options += [
item for flag in self.compile_flags for item in flag.split()
if item not in options
]
cubin_bytes = compile_cuda(
self.lib_code,
target_format="cubin",
options=[f"-I{tl_template_path}", f"-I{cutlass_path}", f"-I{cuda_home}/include"],
verbose=True)
self.lib_code, target_format="cubin", options=options, verbose=True)
with open(libpath, "wb") as f:
f.write(cubin_bytes)
......
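
# Both library generators above fold compile_flags into the build invocation with the same
# split-and-filter pattern. The standalone Python illustration below shows that behavior;
# the command and flag values are hypothetical.
# Each compile_flags entry may contain several space-separated flags; flags already
# present in the base command are skipped rather than duplicated.
command = ["nvcc", "-std=c++17", "-O3"]  # hypothetical base command
compile_flags = ["-O3 --use_fast_math", "--expt-relaxed-constexpr"]
if compile_flags:
    command += [
        item for flag in compile_flags for item in flag.split() if item not in command
    ]
print(command)
# ['nvcc', '-std=c++17', '-O3', '--use_fast_math', '--expt-relaxed-constexpr']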
......@@ -40,7 +40,8 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
device_mod: Optional[tvm.IRModule] = None,
kernel_global_source: Optional[str] = None,
verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None):
pass_configs: Optional[Dict[str, Any]] = None,
compile_flags: Optional[List[str]] = None):
if not is_nvrtc_available:
raise ImportError(NVRTC_UNAVAILABLE_WARNING)
......@@ -83,6 +84,7 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
self.lib_generator = PyLibraryGenerator(self.target)
self.lib_generator.update_lib_code(self.kernel_global_source)
self.lib_generator.update_host_func(self.host_func)
self.lib_generator.assign_compile_flags(compile_flags)
self.lib_generator.compile_lib()
self.lib_generator.load_lib()
self.libpath = self.lib_generator.libpath
......
......@@ -45,6 +45,7 @@ class JITKernel(object):
verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None,
from_database: bool = False,
compile_flags: Optional[List[str]] = None,
):
"""
Initializes a TorchFunction instance.
......@@ -82,6 +83,8 @@ class JITKernel(object):
pass_configs = {}
self.pass_configs = pass_configs
self.compile_flags = compile_flags
# If the target is specified as a string, validate it and convert it to a TVM Target.
if isinstance(target, str):
assert target in AVALIABLE_TARGETS, f"Invalid target: {target}"
......@@ -126,6 +129,7 @@ class JITKernel(object):
out_idx: Union[List[int], int],
execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"],
pass_configs: Optional[Dict[str, Any]] = None,
compile_flags: Optional[List[str]] = None,
):
"""
Alternative constructor to create a TorchFunction directly from a database.
......@@ -138,6 +142,7 @@ class JITKernel(object):
target_host=target_host,
pass_configs=pass_configs,
from_database=True,
compile_flags=compile_flags,
)
instance.adapter = instance._create_adapter_from_database(
......@@ -148,6 +153,7 @@ class JITKernel(object):
kernel_global_source=kernel_global_source,
kernel_lib_path=kernel_lib_path,
pass_configs=pass_configs,
compile_flags=compile_flags,
)
instance.torch_function = instance.adapter.func
return instance
......@@ -192,6 +198,8 @@ class JITKernel(object):
execution_backend = self.execution_backend
pass_configs = self.pass_configs
compile_flags = self.compile_flags
# Compile the function with TVM, optimizing with shared memory lowering.
enable_host_codegen = execution_backend == "dlpack"
enable_device_compile = execution_backend == "dlpack"
......@@ -224,6 +232,7 @@ class JITKernel(object):
kernel_global_source=artifact.kernel_source,
verbose=verbose,
pass_configs=pass_configs,
compile_flags=compile_flags,
)
elif execution_backend == "cython":
adapter = CythonKernelAdapter(
......@@ -236,6 +245,7 @@ class JITKernel(object):
kernel_global_source=artifact.kernel_source,
verbose=verbose,
pass_configs=pass_configs,
compile_flags=compile_flags,
)
elif execution_backend == "nvrtc":
adapter = NVRTCKernelAdapter(
......@@ -248,6 +258,7 @@ class JITKernel(object):
kernel_global_source=artifact.kernel_source,
verbose=verbose,
pass_configs=pass_configs,
compile_flags=compile_flags,
)
else:
# Handle invalid backend.
......@@ -256,15 +267,15 @@ class JITKernel(object):
return adapter
def _create_adapter_from_database(
self,
params: List[KernelParam],
result_idx: Union[List[int], int],
target: Union[str, Target],
func_or_mod: Union[PrimFunc, tvm.runtime.Module],
kernel_global_source: str,
kernel_lib_path: str,
pass_configs: Optional[Dict[str, Any]] = None,
) -> BaseKernelAdapter:
self,
params: List[KernelParam],
result_idx: Union[List[int], int],
target: Union[str, Target],
func_or_mod: Union[PrimFunc, tvm.runtime.Module],
kernel_global_source: str,
kernel_lib_path: str,
pass_configs: Optional[Dict[str, Any]] = None,
compile_flags: Optional[List[str]] = None) -> BaseKernelAdapter:
target = self.target
execution_backend = self.execution_backend
......@@ -280,6 +291,7 @@ class JITKernel(object):
kernel_global_source=kernel_global_source,
kernel_lib_path=kernel_lib_path,
pass_configs=pass_configs,
compile_flags=compile_flags,
)
elif execution_backend == "cython":
adapter = CythonKernelAdapter.from_database(
......@@ -300,6 +312,7 @@ class JITKernel(object):
kernel_global_source=kernel_global_source,
kernel_lib_path=kernel_lib_path,
pass_configs=pass_configs,
compile_flags=compile_flags,
)
else:
# Handle invalid backend.
......