Unverified commit d764dca8, authored by Wenhao Xie, committed by GitHub
Browse files

[Enhancement] Add compile_flags parameter to JIT kernel and adapter classes...


[Enhancement] Add compile_flags parameter to JIT kernel and adapter classes for improved compilation control (#656)

* [Enhancement] Add compile_flags parameter to JIT kernel and adapter classes for improved compilation control

* lint fix

* upd

* lint fix

* fix typo

* update typing

* update the use case of compile flags

* ci fix

* fix

* Fix CI workflow to correctly activate virtual environment from shared cache directory

* use local cache

* fix

* fix

* fix

---------
Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
parent 267d9b3b
name: CI name: CI
on: [pull_request] on: [pull_request]
env: env:
PYTHON_VERSION: '3.9' PYTHON_VERSION: '3.9'
VENV_DIR: ${{ runner.tool_cache }}/tilelang_ci VENV_DIR: tilelang_ci
jobs: jobs:
format-check: format-check:
...@@ -21,34 +20,33 @@ jobs: ...@@ -21,34 +20,33 @@ jobs:
with: with:
python-version: ${{ env.PYTHON_VERSION }} python-version: ${{ env.PYTHON_VERSION }}
- name: Show CI Worker Info - name: Ensure venv (local & persistent)
run: echo "tool_cache=${{ runner.tool_cache }}"
- name: Cache virtual environment
id: cache-venv
uses: actions/cache@v4
with:
path: ${{ env.VENV_DIR }}
key: ${{ runner.os }}-py${{ env.PYTHON_VERSION }}-venv-${{ hashFiles('**/requirements-dev.txt', '**/requirements-test.txt') }}
- name: Create / ensure virtual environment
if: steps.cache-venv.outputs.cache-hit != 'true'
run: | run: |
python -m venv ${{ env.VENV_DIR }} set -e
source ${{ env.VENV_DIR }}/bin/activate REQS_HASH=$(cat requirements-test.txt 2>/dev/null || true)
python -m pip install --upgrade pip --no-user MARKER="${{ runner.tool_cache }}/.venv_marker_${{ env.PYTHON_VERSION }}_${REQS_HASH:0:8}"
if [ -f requirements-test.txt ]; then
PIP_NO_BUILD_ISOLATION=1 \ if [[ -f "$MARKER" ]] && [[ -f "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate" ]]; then
python -m pip install -r requirements-test.txt --no-user echo "venv exists and hash matches – reuse it"
else
echo "venv stale or missing – recreating"
rm -rf "${{ runner.tool_cache }}/${{ env.VENV_DIR }}" "$MARKER"
python -m venv "${{ runner.tool_cache }}/${{ env.VENV_DIR }}"
# shellcheck source=/dev/null
source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate"
python -m pip install --upgrade pip --no-user
[[ -f requirements-test.txt ]] && \
PIP_NO_BUILD_ISOLATION=1 pip install -r requirements-test.txt --no-user
pip install . --no-user
touch "$MARKER"
fi fi
python -m pip install . --no-user
- name: Update submodules recursively - name: Update submodules
run: git submodule update --init --recursive run: git submodule update --init --recursive
- name: Run format check - name: Run format check
run: | run: |
source ${{ env.VENV_DIR }}/bin/activate source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate"
./format.sh ./format.sh
build-test: build-test:
...@@ -66,40 +64,41 @@ jobs: ...@@ -66,40 +64,41 @@ jobs:
with: with:
python-version: ${{ env.PYTHON_VERSION }} python-version: ${{ env.PYTHON_VERSION }}
- name: Cache virtual environment - name: Ensure venv (local & persistent)
id: cache-venv
uses: actions/cache@v4
with:
path: ${{ env.VENV_DIR }}
key: ${{ runner.os }}-py${{ env.PYTHON_VERSION }}-venv-${{ hashFiles('**/requirements-dev.txt', '**/requirements-test.txt') }}
- name: Create / ensure virtual environment
if: steps.cache-venv.outputs.cache-hit != 'true'
run: | run: |
python -m venv ${{ env.VENV_DIR }} set -e
source ${{ env.VENV_DIR }}/bin/activate REQS_HASH=$(cat requirements-test.txt 2>/dev/null || true)
python -m pip install --upgrade pip --no-user MARKER="${{ runner.tool_cache }}/.venv_marker_${{ env.PYTHON_VERSION }}_${REQS_HASH:0:8}"
if [ -f requirements-test.txt ]; then
PIP_NO_BUILD_ISOLATION=1 \ if [[ -f "$MARKER" ]] && [[ -f "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate" ]]; then
python -m pip install -r requirements-test.txt --no-user echo "venv exists and hash matches – reuse it"
else
echo "venv stale or missing – recreating"
rm -rf "${{ runner.tool_cache }}/${{ env.VENV_DIR }}" "$MARKER"
python -m venv "${{ runner.tool_cache }}/${{ env.VENV_DIR }}"
source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate"
python -m pip install --upgrade pip --no-user
[[ -f requirements-test.txt ]] && \
PIP_NO_BUILD_ISOLATION=1 pip install -r requirements-test.txt --no-user
pip install . --no-user
touch "$MARKER"
fi fi
python -m pip install . --no-user
- name: Install project in wheel mode - name: Install project (wheel form)
run: | run: |
source ${{ env.VENV_DIR }}/bin/activate source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate"
python -m pip install . --no-user pip install . --no-user
- name: Run examples - name: Run examples
run: | run: |
source ${{ env.VENV_DIR }}/bin/activate source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate"
cd examples cd examples
unset PYTHONPATH unset PYTHONPATH
python -m pytest -n 4 **/test*.py python -m pytest -n 4 **/test*.py
- name: Run tests - name: Run tests
run: | run: |
source ${{ env.VENV_DIR }}/bin/activate source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate"
cd testing/python cd testing/python
unset PYTHONPATH unset PYTHONPATH
python -m pytest -n 4 python -m pytest -n 4
\ No newline at end of file
import tilelang
import tilelang.language as T
# @tilelang.jit(compile_flags=["-O3", "--use_fast_math", "--expt-relaxed-constexpr"])
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
    """Build a tiled matrix-multiplication TileLang program.

    The returned prim_func declares ``A`` with shape ``(M, K)``, ``B`` with
    shape ``(K, N)`` and ``C`` with shape ``(M, N)``, all using ``dtype``.

    Args:
        M: Row count of ``A`` and ``C``.
        N: Column count of ``B`` and ``C``.
        K: Column count of ``A`` / row count of ``B``.
        block_M: Tile extent along ``M`` handled by one kernel block.
        block_N: Tile extent along ``N`` handled by one kernel block.
        block_K: Tile extent along ``K`` consumed per loop iteration.
        dtype: Element type of ``A``, ``B``, ``C`` and the staged tiles.
        accum_dtype: Element type of the ``C_local`` fragment.

    Returns:
        The ``T.prim_func``-decorated ``main`` function.
    """

    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        # Kernel context: ceildiv(N, block_N) x ceildiv(M, block_M) blocks,
        # 128 threads each. ``bx`` selects the tile along N and ``by`` the tile
        # along M (see the indexing of A, B and C below).
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            # Per-block staging tiles for the operands and the output fragment.
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            # Reset the output fragment before walking the K dimension.
            T.clear(C_local)

            # Step through K in block_K-sized slices; the loop is declared
            # with T.Pipelined(..., num_stages=3).
            for k_step in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                T.copy(A[by * block_M, k_step * block_K], A_shared)
                T.copy(B[k_step * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)

            # Write this block's fragment to its tile of C.
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main
# Problem size and tile shape used by this example.
M, N, K = 1024, 1024, 1024
block_M, block_N, block_K = 128, 128, 32

# Build the TileLang program for this configuration.
func = matmul(M, N, K, block_M, block_N, block_K)

# Compile for CUDA, passing the extra compiler flags as one space-separated
# string. out_idx=[2] presumably marks parameter index 2 (C) as the output:
# the kernel is called below with only (a, b) and its return value is used as c.
jit_kernel = tilelang.compile(
    func,
    out_idx=[2],
    target="cuda",
    compile_flags="-O3 --use_fast_math --expt-relaxed-constexpr",
)
# Alternative spellings of the same flags:
# or jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", compile_flags=["-O3", "--use_fast_math", "--expt-relaxed-constexpr"])
# or jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", compile_flags=["-O3 --use_fast_math --expt-relaxed-constexpr"])

import torch

# Random half-precision operands on the GPU.
a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(K, N, device="cuda", dtype=torch.float16)

# Run the compiled kernel and show its result.
c = jit_kernel(a, b)
print(c)

# Compare against the PyTorch matrix product within a loose tolerance.
ref_c = torch.matmul(a, b)
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel output matches PyTorch reference.")
...@@ -20,6 +20,7 @@ def cached( ...@@ -20,6 +20,7 @@ def cached(
execution_backend: Optional[Literal["dlpack", "ctypes", "cython", "nvrtc"]] = "cython", execution_backend: Optional[Literal["dlpack", "ctypes", "cython", "nvrtc"]] = "cython",
verbose: Optional[bool] = False, verbose: Optional[bool] = False,
pass_configs: Optional[dict] = None, pass_configs: Optional[dict] = None,
compile_flags: Optional[List[str]] = None,
) -> JITKernel: ) -> JITKernel:
""" """
Caches and reuses compiled kernels (using KernelCache class). Caches and reuses compiled kernels (using KernelCache class).
...@@ -33,7 +34,7 @@ def cached( ...@@ -33,7 +34,7 @@ def cached(
execution_backend=execution_backend, execution_backend=execution_backend,
verbose=verbose, verbose=verbose,
pass_configs=pass_configs, pass_configs=pass_configs,
) compile_flags=compile_flags)
def clear_cache(): def clear_cache():
......
...@@ -117,6 +117,7 @@ class KernelCache: ...@@ -117,6 +117,7 @@ class KernelCache:
execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython", execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
verbose: bool = False, verbose: bool = False,
pass_configs: dict = None, pass_configs: dict = None,
compile_flags: Optional[List[str]] = None,
) -> JITKernel: ) -> JITKernel:
""" """
Caches and reuses compiled kernels to avoid redundant compilation. Caches and reuses compiled kernels to avoid redundant compilation.
...@@ -140,6 +141,7 @@ class KernelCache: ...@@ -140,6 +141,7 @@ class KernelCache:
target_host=target_host, target_host=target_host,
verbose=verbose, verbose=verbose,
pass_configs=pass_configs, pass_configs=pass_configs,
compile_flags=compile_flags,
) )
key = self._generate_key( key = self._generate_key(
......
...@@ -37,6 +37,7 @@ def compile( ...@@ -37,6 +37,7 @@ def compile(
target_host: Union[str, Target] = None, target_host: Union[str, Target] = None,
verbose: bool = False, verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None, pass_configs: Optional[Dict[str, Any]] = None,
compile_flags: Optional[Union[List[str], str]] = None,
) -> JITKernel: ) -> JITKernel:
""" """
Compile the given TileLang PrimFunc with TVM and build a JITKernel. Compile the given TileLang PrimFunc with TVM and build a JITKernel.
...@@ -66,7 +67,8 @@ def compile( ...@@ -66,7 +67,8 @@ def compile(
"tl.disable_safe_memory_legalize": bool, default: False "tl.disable_safe_memory_legalize": bool, default: False
""" """
assert isinstance(func, PrimFunc), f"target function must be a PrimFunc but got {type(func)}" assert isinstance(func, PrimFunc), f"target function must be a PrimFunc but got {type(func)}"
if isinstance(compile_flags, str):
compile_flags = [compile_flags]
return cached( return cached(
func=func, func=func,
out_idx=out_idx, out_idx=out_idx,
...@@ -75,6 +77,7 @@ def compile( ...@@ -75,6 +77,7 @@ def compile(
target_host=target_host, target_host=target_host,
verbose=verbose, verbose=verbose,
pass_configs=pass_configs, pass_configs=pass_configs,
compile_flags=compile_flags,
) )
...@@ -87,6 +90,7 @@ class _JitImplementation: ...@@ -87,6 +90,7 @@ class _JitImplementation:
verbose: bool verbose: bool
pass_configs: Optional[Dict[str, Any]] pass_configs: Optional[Dict[str, Any]]
debug_root_path: Optional[str] debug_root_path: Optional[str]
compile_flags: Optional[List[str]]
def __init__(self, def __init__(self,
out_idx: Any = None, out_idx: Any = None,
...@@ -95,7 +99,8 @@ class _JitImplementation: ...@@ -95,7 +99,8 @@ class _JitImplementation:
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython", execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
verbose: bool = False, verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None, pass_configs: Optional[Dict[str, Any]] = None,
debug_root_path: Optional[str] = None): debug_root_path: Optional[str] = None,
compile_flags: Optional[List[str]] = None):
""" """
Initializes the JIT compiler decorator. Initializes the JIT compiler decorator.
...@@ -134,6 +139,7 @@ class _JitImplementation: ...@@ -134,6 +139,7 @@ class _JitImplementation:
self.target_host = target_host self.target_host = target_host
self.verbose = verbose self.verbose = verbose
self.pass_configs = pass_configs self.pass_configs = pass_configs
self.compile_flags = compile_flags
# Corrected debug_root_path handling # Corrected debug_root_path handling
self.debug_root_path = debug_root_path self.debug_root_path = debug_root_path
...@@ -176,6 +182,7 @@ class _JitImplementation: ...@@ -176,6 +182,7 @@ class _JitImplementation:
'target_host': self.target_host, 'target_host': self.target_host,
'verbose': self.verbose, 'verbose': self.verbose,
'pass_configs': self.pass_configs, 'pass_configs': self.pass_configs,
'compile_flags': self.compile_flags,
} }
return compile_args return compile_args
...@@ -202,6 +209,7 @@ class _JitImplementation: ...@@ -202,6 +209,7 @@ class _JitImplementation:
target_host=self.target_host, target_host=self.target_host,
verbose=self.verbose, verbose=self.verbose,
pass_configs=self.pass_configs, pass_configs=self.pass_configs,
compile_flags=self.compile_flags,
) )
if self.debug_root_path: if self.debug_root_path:
...@@ -230,7 +238,8 @@ def jit( # This is the new public interface ...@@ -230,7 +238,8 @@ def jit( # This is the new public interface
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython", execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
verbose: bool = False, verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None, pass_configs: Optional[Dict[str, Any]] = None,
debug_root_path: Optional[str] = None): debug_root_path: Optional[str] = None,
compile_flags: Optional[Union[List[str], str]] = None):
""" """
Just-In-Time (JIT) compiler decorator for TileLang functions. Just-In-Time (JIT) compiler decorator for TileLang functions.
...@@ -262,6 +271,9 @@ def jit( # This is the new public interface ...@@ -262,6 +271,9 @@ def jit( # This is the new public interface
Either a JIT-compiled wrapper around the input function, or a configured decorator Either a JIT-compiled wrapper around the input function, or a configured decorator
instance that can then be applied to a function. instance that can then be applied to a function.
""" """
if isinstance(compile_flags, str):
compile_flags = [compile_flags]
if callable(func): if callable(func):
# Case 1: Used as @jit (func_or_out_idx is the function, others are defaults) # Case 1: Used as @jit (func_or_out_idx is the function, others are defaults)
# Create a default _JitImplementation instance and apply it to the function. # Create a default _JitImplementation instance and apply it to the function.
...@@ -272,7 +284,8 @@ def jit( # This is the new public interface ...@@ -272,7 +284,8 @@ def jit( # This is the new public interface
execution_backend=execution_backend, execution_backend=execution_backend,
verbose=verbose, verbose=verbose,
pass_configs=pass_configs, pass_configs=pass_configs,
debug_root_path=debug_root_path) debug_root_path=debug_root_path,
compile_flags=compile_flags)
return default_decorator(func) return default_decorator(func)
elif isinstance(func, PrimFunc): elif isinstance(func, PrimFunc):
raise ValueError("Use tilelang.jit to decorate prim_func is not supported yet.") raise ValueError("Use tilelang.jit to decorate prim_func is not supported yet.")
...@@ -287,5 +300,6 @@ def jit( # This is the new public interface ...@@ -287,5 +300,6 @@ def jit( # This is the new public interface
execution_backend=execution_backend, execution_backend=execution_backend,
verbose=verbose, verbose=verbose,
pass_configs=pass_configs, pass_configs=pass_configs,
debug_root_path=debug_root_path) debug_root_path=debug_root_path,
compile_flags=compile_flags)
return configured_decorator return configured_decorator
...@@ -49,7 +49,8 @@ class CtypesKernelAdapter(BaseKernelAdapter): ...@@ -49,7 +49,8 @@ class CtypesKernelAdapter(BaseKernelAdapter):
device_mod: Optional[tvm.IRModule] = None, device_mod: Optional[tvm.IRModule] = None,
kernel_global_source: Optional[str] = None, kernel_global_source: Optional[str] = None,
verbose: bool = False, verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None): pass_configs: Optional[Dict[str, Any]] = None,
compile_flags: Optional[List[str]] = None):
"""Initialize the adapter with the given TIR function or module. """Initialize the adapter with the given TIR function or module.
Args: Args:
...@@ -89,6 +90,7 @@ class CtypesKernelAdapter(BaseKernelAdapter): ...@@ -89,6 +90,7 @@ class CtypesKernelAdapter(BaseKernelAdapter):
self.wrapper = TLWrapper(self.target) self.wrapper = TLWrapper(self.target)
self.lib_generator = LibraryGenerator(self.target) self.lib_generator = LibraryGenerator(self.target)
self.lib_generator.assign_pass_configs(pass_configs) self.lib_generator.assign_pass_configs(pass_configs)
self.lib_generator.assign_compile_flags(compile_flags)
self.wrapper.assign_optimized_module(self.ir_module) self.wrapper.assign_optimized_module(self.ir_module)
self.wrapper.assign_pass_configs(pass_configs) self.wrapper.assign_pass_configs(pass_configs)
...@@ -112,7 +114,8 @@ class CtypesKernelAdapter(BaseKernelAdapter): ...@@ -112,7 +114,8 @@ class CtypesKernelAdapter(BaseKernelAdapter):
kernel_global_source: str, kernel_global_source: str,
kernel_lib_path: str, kernel_lib_path: str,
verbose: bool = False, verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None): pass_configs: Optional[Dict[str, Any]] = None,
compile_flags: Optional[List[str]] = None):
adapter = cls.__new__(cls) adapter = cls.__new__(cls)
adapter.params = params adapter.params = params
adapter.result_idx = adapter._legalize_result_idx(result_idx) adapter.result_idx = adapter._legalize_result_idx(result_idx)
...@@ -145,6 +148,7 @@ class CtypesKernelAdapter(BaseKernelAdapter): ...@@ -145,6 +148,7 @@ class CtypesKernelAdapter(BaseKernelAdapter):
adapter.verbose = verbose adapter.verbose = verbose
adapter.lib_generator = LibraryGenerator(adapter.target) adapter.lib_generator = LibraryGenerator(adapter.target)
adapter.lib_generator.assign_pass_configs(pass_configs) adapter.lib_generator.assign_pass_configs(pass_configs)
adapter.lib_generator.assign_compile_flags(compile_flags)
adapter.lib = adapter.lib_generator.load_lib(lib_path=kernel_lib_path) adapter.lib = adapter.lib_generator.load_lib(lib_path=kernel_lib_path)
adapter.lib.init() adapter.lib.init()
......
...@@ -214,7 +214,8 @@ class CythonKernelAdapter(BaseKernelAdapter): ...@@ -214,7 +214,8 @@ class CythonKernelAdapter(BaseKernelAdapter):
device_mod: Optional[tvm.IRModule] = None, device_mod: Optional[tvm.IRModule] = None,
kernel_global_source: Optional[str] = None, kernel_global_source: Optional[str] = None,
verbose: bool = False, verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None): pass_configs: Optional[Dict[str, Any]] = None,
compile_flags: Optional[List[str]] = None):
"""Initialize the adapter with the given TIR function or module. """Initialize the adapter with the given TIR function or module.
Args: Args:
...@@ -245,6 +246,7 @@ class CythonKernelAdapter(BaseKernelAdapter): ...@@ -245,6 +246,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
self.wrapper = TLWrapper(self.target) self.wrapper = TLWrapper(self.target)
self.lib_generator = LibraryGenerator(self.target) self.lib_generator = LibraryGenerator(self.target)
self.lib_generator.assign_pass_configs(pass_configs) self.lib_generator.assign_pass_configs(pass_configs)
self.lib_generator.assign_compile_flags(compile_flags)
self.wrapper.assign_optimized_module(self.ir_module) self.wrapper.assign_optimized_module(self.ir_module)
self.wrapper.assign_pass_configs(pass_configs) self.wrapper.assign_pass_configs(pass_configs)
...@@ -280,7 +282,8 @@ class CythonKernelAdapter(BaseKernelAdapter): ...@@ -280,7 +282,8 @@ class CythonKernelAdapter(BaseKernelAdapter):
kernel_global_source: str, kernel_global_source: str,
kernel_lib_path: str, kernel_lib_path: str,
verbose: bool = False, verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None): pass_configs: Optional[Dict[str, Any]] = None,
compile_flags: Optional[List[str]] = None):
adapter = cls.__new__(cls) adapter = cls.__new__(cls)
adapter.params = params adapter.params = params
adapter.result_idx = adapter._legalize_result_idx(result_idx) adapter.result_idx = adapter._legalize_result_idx(result_idx)
...@@ -305,6 +308,7 @@ class CythonKernelAdapter(BaseKernelAdapter): ...@@ -305,6 +308,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
adapter.verbose = verbose adapter.verbose = verbose
adapter.lib_generator = LibraryGenerator(adapter.target) adapter.lib_generator = LibraryGenerator(adapter.target)
adapter.lib_generator.assign_pass_configs(pass_configs) adapter.lib_generator.assign_pass_configs(pass_configs)
adapter.lib_generator.assign_compile_flags(compile_flags)
adapter.lib = adapter.lib_generator.load_lib(lib_path=kernel_lib_path) adapter.lib = adapter.lib_generator.load_lib(lib_path=kernel_lib_path)
adapter.lib.get_last_error.restype = ctypes.c_char_p adapter.lib.get_last_error.restype = ctypes.c_char_p
......
...@@ -5,7 +5,7 @@ import os ...@@ -5,7 +5,7 @@ import os
import os.path as osp import os.path as osp
import subprocess import subprocess
import tempfile import tempfile
from typing import Any, Dict, Optional from typing import Any, Dict, Optional, List
from tvm.target import Target from tvm.target import Target
...@@ -36,6 +36,7 @@ class LibraryGenerator(object): ...@@ -36,6 +36,7 @@ class LibraryGenerator(object):
libpath: Optional[str] = None libpath: Optional[str] = None
lib_code: Optional[str] = None lib_code: Optional[str] = None
pass_configs: Optional[Dict[str, Any]] = None pass_configs: Optional[Dict[str, Any]] = None
compile_flags: Optional[List[str]] = None
def __init__(self, target: Target): def __init__(self, target: Target):
self.target = target self.target = target
...@@ -43,6 +44,11 @@ class LibraryGenerator(object): ...@@ -43,6 +44,11 @@ class LibraryGenerator(object):
def assign_pass_configs(self, pass_configs: Optional[Dict[str, Any]] = None): def assign_pass_configs(self, pass_configs: Optional[Dict[str, Any]] = None):
self.pass_configs = pass_configs self.pass_configs = pass_configs
def assign_compile_flags(self, compile_flags: Optional[List[str]] = None):
if compile_flags is None:
compile_flags = []
self.compile_flags = compile_flags
def update_lib_code(self, lib_code: str): def update_lib_code(self, lib_code: str):
self.lib_code = lib_code self.lib_code = lib_code
...@@ -75,7 +81,7 @@ class LibraryGenerator(object): ...@@ -75,7 +81,7 @@ class LibraryGenerator(object):
"-Xcudafe", "-Xcudafe",
"--diag_suppress=177", "--diag_suppress=177",
"--compiler-options", "--compiler-options",
"'-fPIC'", "-fPIC",
"-lineinfo", "-lineinfo",
"--shared", "--shared",
src.name, src.name,
...@@ -125,6 +131,12 @@ class LibraryGenerator(object): ...@@ -125,6 +131,12 @@ class LibraryGenerator(object):
command += [ command += [
"-I" + TILELANG_TEMPLATE_PATH, "-I" + TILELANG_TEMPLATE_PATH,
] ]
if self.compile_flags:
command += [
item for flag in self.compile_flags for item in flag.split() if item not in command
]
command += ["-o", libpath] command += ["-o", libpath]
src.write(self.lib_code) src.write(self.lib_code)
...@@ -217,11 +229,15 @@ class PyLibraryGenerator(LibraryGenerator): ...@@ -217,11 +229,15 @@ class PyLibraryGenerator(LibraryGenerator):
cuda_home = "/usr/local/cuda" if CUDA_HOME is None else CUDA_HOME cuda_home = "/usr/local/cuda" if CUDA_HOME is None else CUDA_HOME
options = [f"-I{tl_template_path}", f"-I{cutlass_path}", f"-I{cuda_home}/include"]
if self.compile_flags:
options += [
item for flag in self.compile_flags for item in flag.split()
if item not in options
]
cubin_bytes = compile_cuda( cubin_bytes = compile_cuda(
self.lib_code, self.lib_code, target_format="cubin", options=options, verbose=True)
target_format="cubin",
options=[f"-I{tl_template_path}", f"-I{cutlass_path}", f"-I{cuda_home}/include"],
verbose=True)
with open(libpath, "wb") as f: with open(libpath, "wb") as f:
f.write(cubin_bytes) f.write(cubin_bytes)
......
...@@ -40,7 +40,8 @@ class NVRTCKernelAdapter(BaseKernelAdapter): ...@@ -40,7 +40,8 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
device_mod: Optional[tvm.IRModule] = None, device_mod: Optional[tvm.IRModule] = None,
kernel_global_source: Optional[str] = None, kernel_global_source: Optional[str] = None,
verbose: bool = False, verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None): pass_configs: Optional[Dict[str, Any]] = None,
compile_flags: Optional[List[str]] = None):
if not is_nvrtc_available: if not is_nvrtc_available:
raise ImportError(NVRTC_UNAVAILABLE_WARNING) raise ImportError(NVRTC_UNAVAILABLE_WARNING)
...@@ -83,6 +84,7 @@ class NVRTCKernelAdapter(BaseKernelAdapter): ...@@ -83,6 +84,7 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
self.lib_generator = PyLibraryGenerator(self.target) self.lib_generator = PyLibraryGenerator(self.target)
self.lib_generator.update_lib_code(self.kernel_global_source) self.lib_generator.update_lib_code(self.kernel_global_source)
self.lib_generator.update_host_func(self.host_func) self.lib_generator.update_host_func(self.host_func)
self.lib_generator.assign_compile_flags(compile_flags)
self.lib_generator.compile_lib() self.lib_generator.compile_lib()
self.lib_generator.load_lib() self.lib_generator.load_lib()
self.libpath = self.lib_generator.libpath self.libpath = self.lib_generator.libpath
......
...@@ -45,6 +45,7 @@ class JITKernel(object): ...@@ -45,6 +45,7 @@ class JITKernel(object):
verbose: bool = False, verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None, pass_configs: Optional[Dict[str, Any]] = None,
from_database: bool = False, from_database: bool = False,
compile_flags: Optional[List[str]] = None,
): ):
""" """
Initializes a TorchFunction instance. Initializes a TorchFunction instance.
...@@ -82,6 +83,8 @@ class JITKernel(object): ...@@ -82,6 +83,8 @@ class JITKernel(object):
pass_configs = {} pass_configs = {}
self.pass_configs = pass_configs self.pass_configs = pass_configs
self.compile_flags = compile_flags
# If the target is specified as a string, validate it and convert it to a TVM Target. # If the target is specified as a string, validate it and convert it to a TVM Target.
if isinstance(target, str): if isinstance(target, str):
assert target in AVALIABLE_TARGETS, f"Invalid target: {target}" assert target in AVALIABLE_TARGETS, f"Invalid target: {target}"
...@@ -126,6 +129,7 @@ class JITKernel(object): ...@@ -126,6 +129,7 @@ class JITKernel(object):
out_idx: Union[List[int], int], out_idx: Union[List[int], int],
execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"], execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"],
pass_configs: Optional[Dict[str, Any]] = None, pass_configs: Optional[Dict[str, Any]] = None,
compile_flags: Optional[List[str]] = None,
): ):
""" """
Alternative constructor to create a TorchFunction directly from a database. Alternative constructor to create a TorchFunction directly from a database.
...@@ -138,6 +142,7 @@ class JITKernel(object): ...@@ -138,6 +142,7 @@ class JITKernel(object):
target_host=target_host, target_host=target_host,
pass_configs=pass_configs, pass_configs=pass_configs,
from_database=True, from_database=True,
compile_flags=compile_flags,
) )
instance.adapter = instance._create_adapter_from_database( instance.adapter = instance._create_adapter_from_database(
...@@ -148,6 +153,7 @@ class JITKernel(object): ...@@ -148,6 +153,7 @@ class JITKernel(object):
kernel_global_source=kernel_global_source, kernel_global_source=kernel_global_source,
kernel_lib_path=kernel_lib_path, kernel_lib_path=kernel_lib_path,
pass_configs=pass_configs, pass_configs=pass_configs,
compile_flags=compile_flags,
) )
instance.torch_function = instance.adapter.func instance.torch_function = instance.adapter.func
return instance return instance
...@@ -192,6 +198,8 @@ class JITKernel(object): ...@@ -192,6 +198,8 @@ class JITKernel(object):
execution_backend = self.execution_backend execution_backend = self.execution_backend
pass_configs = self.pass_configs pass_configs = self.pass_configs
compile_flags = self.compile_flags
# Compile the function with TVM, optimizing with shared memory lowering. # Compile the function with TVM, optimizing with shared memory lowering.
enable_host_codegen = execution_backend == "dlpack" enable_host_codegen = execution_backend == "dlpack"
enable_device_compile = execution_backend == "dlpack" enable_device_compile = execution_backend == "dlpack"
...@@ -224,6 +232,7 @@ class JITKernel(object): ...@@ -224,6 +232,7 @@ class JITKernel(object):
kernel_global_source=artifact.kernel_source, kernel_global_source=artifact.kernel_source,
verbose=verbose, verbose=verbose,
pass_configs=pass_configs, pass_configs=pass_configs,
compile_flags=compile_flags,
) )
elif execution_backend == "cython": elif execution_backend == "cython":
adapter = CythonKernelAdapter( adapter = CythonKernelAdapter(
...@@ -236,6 +245,7 @@ class JITKernel(object): ...@@ -236,6 +245,7 @@ class JITKernel(object):
kernel_global_source=artifact.kernel_source, kernel_global_source=artifact.kernel_source,
verbose=verbose, verbose=verbose,
pass_configs=pass_configs, pass_configs=pass_configs,
compile_flags=compile_flags,
) )
elif execution_backend == "nvrtc": elif execution_backend == "nvrtc":
adapter = NVRTCKernelAdapter( adapter = NVRTCKernelAdapter(
...@@ -248,6 +258,7 @@ class JITKernel(object): ...@@ -248,6 +258,7 @@ class JITKernel(object):
kernel_global_source=artifact.kernel_source, kernel_global_source=artifact.kernel_source,
verbose=verbose, verbose=verbose,
pass_configs=pass_configs, pass_configs=pass_configs,
compile_flags=compile_flags,
) )
else: else:
# Handle invalid backend. # Handle invalid backend.
...@@ -256,15 +267,15 @@ class JITKernel(object): ...@@ -256,15 +267,15 @@ class JITKernel(object):
return adapter return adapter
def _create_adapter_from_database( def _create_adapter_from_database(
self, self,
params: List[KernelParam], params: List[KernelParam],
result_idx: Union[List[int], int], result_idx: Union[List[int], int],
target: Union[str, Target], target: Union[str, Target],
func_or_mod: Union[PrimFunc, tvm.runtime.Module], func_or_mod: Union[PrimFunc, tvm.runtime.Module],
kernel_global_source: str, kernel_global_source: str,
kernel_lib_path: str, kernel_lib_path: str,
pass_configs: Optional[Dict[str, Any]] = None, pass_configs: Optional[Dict[str, Any]] = None,
) -> BaseKernelAdapter: compile_flags: Optional[List[str]] = None) -> BaseKernelAdapter:
target = self.target target = self.target
execution_backend = self.execution_backend execution_backend = self.execution_backend
...@@ -280,6 +291,7 @@ class JITKernel(object): ...@@ -280,6 +291,7 @@ class JITKernel(object):
kernel_global_source=kernel_global_source, kernel_global_source=kernel_global_source,
kernel_lib_path=kernel_lib_path, kernel_lib_path=kernel_lib_path,
pass_configs=pass_configs, pass_configs=pass_configs,
compile_flags=compile_flags,
) )
elif execution_backend == "cython": elif execution_backend == "cython":
adapter = CythonKernelAdapter.from_database( adapter = CythonKernelAdapter.from_database(
...@@ -300,6 +312,7 @@ class JITKernel(object): ...@@ -300,6 +312,7 @@ class JITKernel(object):
kernel_global_source=kernel_global_source, kernel_global_source=kernel_global_source,
kernel_lib_path=kernel_lib_path, kernel_lib_path=kernel_lib_path,
pass_configs=pass_configs, pass_configs=pass_configs,
compile_flags=compile_flags,
) )
else: else:
# Handle invalid backend. # Handle invalid backend.
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment