Unverified commit fba12a5f, authored by Lei Wang and committed by GitHub
Browse files

[Bugfix] Convey `compile_flags` to ffi compilation path with pass_configs (#1434)

* [Enhancement] Add device compile flags support in pass configuration

* Introduced `kDeviceCompileFlags` option in the pass configuration to allow additional device compiler flags for CUDA compilation.
* Updated the `tilelang_callback_cuda_compile` function to merge extra flags from the pass configuration, enhancing flexibility in compiler options.
* Modified the `JITKernel` class to handle device compile flags appropriately, ensuring they are included during compilation.
* Documented the new pass configuration key for clarity on usage and expected input formats.

* lint fix

* [Refactor] Simplify compile_flags handling in JIT functions

* Removed redundant string check for compile_flags in the compile, jit, and lazy_jit functions, ensuring compile_flags is consistently treated as a list.
* Updated the JITKernel class to handle compile_flags as a list when a string is provided, enhancing code clarity and maintainability.

* lint fix

* fix
parent 87e9e170
...@@ -36,6 +36,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION(kDisableShuffleElect, Bool); ...@@ -36,6 +36,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION(kDisableShuffleElect, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kStorageRewriteDetectInplace, Bool); TVM_REGISTER_PASS_CONFIG_OPTION(kStorageRewriteDetectInplace, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kLayoutVisualizationEnable, Bool); TVM_REGISTER_PASS_CONFIG_OPTION(kLayoutVisualizationEnable, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kLayoutVisualizationFormats, String); TVM_REGISTER_PASS_CONFIG_OPTION(kLayoutVisualizationFormats, String);
TVM_REGISTER_PASS_CONFIG_OPTION(kDeviceCompileFlags, ffi::Array<ffi::String>);
DataType cuTensorMapType() { return DataType::UInt(8, 128); } DataType cuTensorMapType() { return DataType::UInt(8, 128); }
......
...@@ -55,6 +55,7 @@ static constexpr const char *kLayoutVisualizationEnable = ...@@ -55,6 +55,7 @@ static constexpr const char *kLayoutVisualizationEnable =
"tl.layout_visualization_enable"; "tl.layout_visualization_enable";
static constexpr const char *kLayoutVisualizationFormats = static constexpr const char *kLayoutVisualizationFormats =
"tl.layout_visualization_formats"; "tl.layout_visualization_formats";
static constexpr const char *kDeviceCompileFlags = "tl.device_compile_flags";
/*! /*!
* \brief Whether to disable dynamic tail split * \brief Whether to disable dynamic tail split
* *
......
...@@ -76,21 +76,37 @@ def tilelang_callback_cuda_compile(code, target, pass_config=None): ...@@ -76,21 +76,37 @@ def tilelang_callback_cuda_compile(code, target, pass_config=None):
# Read pass-config keys (string-valued) like in jit.adapter.libgen.compile_lib # Read pass-config keys (string-valued) like in jit.adapter.libgen.compile_lib
cfg = pass_config or {} cfg = pass_config or {}
if cfg.get(PassConfigKey.TL_DISABLE_FAST_MATH.value, False): if cfg.get(PassConfigKey.TL_DISABLE_FAST_MATH, False):
deprecated_warning("TL_DISABLE_FAST_MATH", "TL_ENABLE_FAST_MATH", "0.1.7") deprecated_warning("TL_DISABLE_FAST_MATH", "TL_ENABLE_FAST_MATH", "0.1.7")
disable_fast_math = bool(cfg.get(PassConfigKey.TL_DISABLE_FAST_MATH.value, True)) disable_fast_math = bool(cfg.get(PassConfigKey.TL_DISABLE_FAST_MATH, True))
enable_fast_math = not disable_fast_math enable_fast_math = not disable_fast_math
else: else:
enable_fast_math = bool(cfg.get(PassConfigKey.TL_ENABLE_FAST_MATH.value, False)) enable_fast_math = bool(cfg.get(PassConfigKey.TL_ENABLE_FAST_MATH, False))
ptxas_usage_level = cfg.get(PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL.value, None) ptxas_usage_level = cfg.get(PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL, None)
verbose_ptxas_output = bool(cfg.get(PassConfigKey.TL_ENABLE_PTXAS_VERBOSE_OUTPUT.value, False)) verbose_ptxas_output = bool(cfg.get(PassConfigKey.TL_ENABLE_PTXAS_VERBOSE_OUTPUT, False))
options = [ options = [
"-std=c++17", "-std=c++17",
"-I" + tl_template_path, "-I" + tl_template_path,
"-I" + cutlass_path, "-I" + cutlass_path,
] ]
# Merge extra device compiler flags from pass config, if provided
extra_flags = cfg.get(PassConfigKey.TL_DEVICE_COMPILE_FLAGS, None)
if extra_flags:
import shlex
if isinstance(extra_flags, str):
tokens = shlex.split(extra_flags)
else:
tokens = []
for flag in extra_flags:
if isinstance(flag, str):
tokens.extend(shlex.split(flag))
else:
tokens.append(str(flag))
options += tokens
if enable_fast_math: if enable_fast_math:
options.append("--use_fast_math") options.append("--use_fast_math")
if ptxas_usage_level is not None: if ptxas_usage_level is not None:
......
...@@ -80,9 +80,6 @@ def compile( ...@@ -80,9 +80,6 @@ def compile(
assert isinstance(func, PrimFunc), f"target function must be a PrimFunc but got {type(func)}" assert isinstance(func, PrimFunc), f"target function must be a PrimFunc but got {type(func)}"
if isinstance(compile_flags, str):
compile_flags = [compile_flags]
if hasattr(func, "out_idx_override"): if hasattr(func, "out_idx_override"):
if func.out_idx_override is not None and out_idx is not None: if func.out_idx_override is not None and out_idx is not None:
raise ValueError("Out index conflict: out_idx is specified and prim_func have returned `T.empty` tensors") raise ValueError("Out index conflict: out_idx is specified and prim_func have returned `T.empty` tensors")
...@@ -492,8 +489,6 @@ def jit( # This is the new public interface ...@@ -492,8 +489,6 @@ def jit( # This is the new public interface
Either a JIT-compiled wrapper around the input function, or a configured decorator Either a JIT-compiled wrapper around the input function, or a configured decorator
instance that can then be applied to a function. instance that can then be applied to a function.
""" """
if isinstance(compile_flags, str):
compile_flags = [compile_flags]
def decorator(func: Callable[_P, _T]) -> JITImpl[_P, _T]: def decorator(func: Callable[_P, _T]) -> JITImpl[_P, _T]:
if isinstance(func, (PrimFunc, PrimFuncCreater)): if isinstance(func, (PrimFunc, PrimFuncCreater)):
...@@ -550,9 +545,6 @@ def lazy_jit( ...@@ -550,9 +545,6 @@ def lazy_jit(
debug_root_path: str | None = None, debug_root_path: str | None = None,
compile_flags: list[str] | str | None = None, compile_flags: list[str] | str | None = None,
): ):
if isinstance(compile_flags, str):
compile_flags = [compile_flags]
compile_args = dict( compile_args = dict(
out_idx=None, out_idx=None,
execution_backend=execution_backend, execution_backend=execution_backend,
......
...@@ -19,6 +19,7 @@ from tilelang.jit.adapter import BaseKernelAdapter, CtypesKernelAdapter, CythonK ...@@ -19,6 +19,7 @@ from tilelang.jit.adapter import BaseKernelAdapter, CtypesKernelAdapter, CythonK
from tilelang.profiler import Profiler, TensorSupplyType from tilelang.profiler import Profiler, TensorSupplyType
from tilelang.utils.target import determine_target from tilelang.utils.target import determine_target
from tilelang.contrib import nvcc as tl_nvcc from tilelang.contrib import nvcc as tl_nvcc
from tilelang.transform import PassConfigKey
import logging import logging
import os import os
...@@ -96,7 +97,7 @@ class JITKernel(Generic[_P, _T]): ...@@ -96,7 +97,7 @@ class JITKernel(Generic[_P, _T]):
pass_configs = {} pass_configs = {}
self.pass_configs = pass_configs self.pass_configs = pass_configs
self.compile_flags = compile_flags self.compile_flags = [compile_flags] if isinstance(compile_flags, str) else compile_flags
# Ensure the target is always a valid TVM Target object. # Ensure the target is always a valid TVM Target object.
self.target = determine_target(target, return_object=True) self.target = determine_target(target, return_object=True)
...@@ -218,10 +219,16 @@ class JITKernel(Generic[_P, _T]): ...@@ -218,10 +219,16 @@ class JITKernel(Generic[_P, _T]):
target_host = self.target_host target_host = self.target_host
execution_backend = self.execution_backend execution_backend = self.execution_backend
pass_configs = self.pass_configs pass_configs = self.pass_configs or {}
compile_flags = self.compile_flags compile_flags = self.compile_flags
if compile_flags is not None:
compile_flags_cfg = pass_configs.get(PassConfigKey.TL_DEVICE_COMPILE_FLAGS)
pass_configs[PassConfigKey.TL_DEVICE_COMPILE_FLAGS] = (
compile_flags_cfg + compile_flags if compile_flags_cfg is not None else compile_flags
)
# Compile the function with TVM, optimizing with shared memory lowering. # Compile the function with TVM, optimizing with shared memory lowering.
enable_host_codegen = execution_backend == "tvm_ffi" enable_host_codegen = execution_backend == "tvm_ffi"
enable_device_compile = execution_backend == "tvm_ffi" enable_device_compile = execution_backend == "tvm_ffi"
......
...@@ -37,6 +37,20 @@ class PassConfigKey(str, Enum): ...@@ -37,6 +37,20 @@ class PassConfigKey(str, Enum):
TL_ENABLE_PTXAS_VERBOSE_OUTPUT = "tl.enable_ptxas_verbose_output" TL_ENABLE_PTXAS_VERBOSE_OUTPUT = "tl.enable_ptxas_verbose_output"
"""Enable ptxas verbose output. Default: False""" """Enable ptxas verbose output. Default: False"""
TL_DEVICE_COMPILE_FLAGS = "tl.device_compile_flags"
"""Additional device compiler flags passed to nvcc/NVRTC.
Accepts either a string (parsed with shell-like splitting) or a list of
strings. Typical usage is to provide extra include paths, defines or
ptxas options, e.g.:
- "-I/opt/include -DMY_SWITCH=1 --ptxas-options=--verbose"
- ["-I/opt/include", "-DMY_SWITCH=1", "--ptxas-options=--verbose"]
These flags are appended to the compiler options used in the tvm_ffi
CUDA compile callback. Default: None
"""
TL_CONFIG_INDEX_BITWIDTH = "tl.config_index_bitwidth" TL_CONFIG_INDEX_BITWIDTH = "tl.config_index_bitwidth"
"""Bitwidth for configuration indices. Default: 32""" """Bitwidth for configuration indices. Default: 32"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment