Commit d7aebf4d authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

[Enhancement] Introduce option `TL_DISABLE_FAST_MATH` and `TL_ENABLE_PTXAS_VERBOSE_OUTPUT` (#609)

* [Enhancement] Introduce new PassConfig options for fast math and PTXAS verbosity

- Added `kDisableFastMath` and `kEnablePTXASVerboseOutput` configuration options to enhance control over compilation settings.
- Updated `LibraryGenerator` to utilize these new pass configurations, allowing for more flexible compilation behavior based on user preferences.
- Enhanced `PassConfigKey` enumeration to include the new options, ensuring they can be configured appropriately in the pass context.

* [Refactor] Update PTXAS verbosity configuration key in LibraryGenerator

- Changed the configuration key for PTXAS verbosity from `TL_VERBOSE_PTXAS_OUTPUT` to `TL_ENABLE_PTXAS_VERBOSE_OUTPUT` to align with the new naming convention introduced in recent enhancements.
- This update ensures consistency in the configuration options used within the `LibraryGenerator` class, improving clarity and maintainability of the code.

* lint fix
parent 0ff81755
......@@ -24,6 +24,8 @@ TVM_REGISTER_PASS_CONFIG_OPTION(kConfigIndexBitwidth, Integer);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableDynamicTailSplit, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDynamicAlignment, Integer);
TVM_REGISTER_PASS_CONFIG_OPTION(kEnableAggressiveSharedMemoryMerge, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableFastMath, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kEnablePTXASVerboseOutput, Bool);
#define TIR_DEFINE_TL_BUILTIN(OpName) \
const Op &OpName() { \
......
......@@ -27,6 +27,9 @@ static constexpr const char *kDisableWarpSpecialized =
static constexpr const char *kConfigIndexBitwidth = "tl.config_index_bitwidth";
static constexpr const char *kEnableAggressiveSharedMemoryMerge =
"tl.enable_aggressive_shared_memory_merge";
static constexpr const char *kDisableFastMath = "tl.disable_fast_math";
static constexpr const char *kEnablePTXASVerboseOutput =
"tl.enable_ptxas_verbose_output";
/*!
* \brief Whether to disable dynamic tail split
......
......@@ -88,6 +88,7 @@ class CtypesKernelAdapter(BaseKernelAdapter):
self.verbose = verbose
self.wrapper = TLWrapper(self.target)
self.lib_generator = LibraryGenerator(self.target)
self.lib_generator.assign_pass_configs(pass_configs)
self.wrapper.assign_optimized_module(self.ir_module)
self.wrapper.assign_pass_configs(pass_configs)
......@@ -143,6 +144,7 @@ class CtypesKernelAdapter(BaseKernelAdapter):
adapter.target = Target.canon_target(determine_target(target))
adapter.verbose = verbose
adapter.lib_generator = LibraryGenerator(adapter.target)
adapter.lib_generator.assign_pass_configs(pass_configs)
adapter.lib = adapter.lib_generator.load_lib(lib_path=kernel_lib_path)
adapter.lib.init()
......
......@@ -244,6 +244,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
self.verbose = verbose
self.wrapper = TLWrapper(self.target)
self.lib_generator = LibraryGenerator(self.target)
self.lib_generator.assign_pass_configs(pass_configs)
self.wrapper.assign_optimized_module(self.ir_module)
self.wrapper.assign_pass_configs(pass_configs)
......@@ -303,6 +304,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
adapter.verbose = verbose
adapter.lib_generator = LibraryGenerator(adapter.target)
adapter.lib_generator.assign_pass_configs(pass_configs)
adapter.lib = adapter.lib_generator.load_lib(lib_path=kernel_lib_path)
adapter.lib.get_last_error.restype = ctypes.c_char_p
......
......@@ -5,11 +5,12 @@ import os
import os.path as osp
import subprocess
import tempfile
from typing import Optional
from typing import Any, Dict, Optional
from tvm.target import Target
from tilelang import tvm as tvm
from tilelang.transform import PassConfigKey
from tilelang.contrib.nvcc import get_nvcc_compiler, get_target_compute_version
from tilelang.contrib.rocm import find_rocm_path, get_rocm_arch
from tilelang.env import TILELANG_TEMPLATE_PATH
......@@ -34,10 +35,14 @@ class LibraryGenerator(object):
srcpath: Optional[str] = None
libpath: Optional[str] = None
lib_code: Optional[str] = None
pass_configs: Optional[Dict[str, Any]] = None
def __init__(self, target: Target):
self.target = target
def assign_pass_configs(self, pass_configs: Optional[Dict[str, Any]] = None):
self.pass_configs = pass_configs
def update_lib_code(self, lib_code: str):
self.lib_code = lib_code
......@@ -59,6 +64,10 @@ class LibraryGenerator(object):
compute_version = "90a"
libpath = src.name.replace(".cu", ".so")
disable_fast_math = self.pass_configs.get(PassConfigKey.TL_DISABLE_FAST_MATH, False)
verbose_ptxas_output = self.pass_configs.get(
PassConfigKey.TL_ENABLE_PTXAS_VERBOSE_OUTPUT, False)
command = [
get_nvcc_compiler(),
"-std=c++17",
......@@ -74,6 +83,10 @@ class LibraryGenerator(object):
"-gencode",
f"arch=compute_{compute_version},code=sm_{compute_version}",
]
if not disable_fast_math:
command += ["--use_fast_math"]
if verbose_ptxas_output:
command += ["--ptxas_options", "-v"]
command += [
"-I" + CUTLASS_INCLUDE_DIR,
]
......
......@@ -18,6 +18,12 @@ class PassConfigKey(str, Enum):
TL_DISABLE_WARP_SPECIALIZED = "tl.disable_warp_specialized"
"""Disable warp specialization optimization. Default: False"""
TL_DISABLE_FAST_MATH = "tl.disable_fast_math"
"""Disable fast math optimization. Default: False"""
TL_ENABLE_PTXAS_VERBOSE_OUTPUT = "tl.enable_ptxas_verbose_output"
"""Enable ptxas verbose output. Default: False"""
TL_CONFIG_INDEX_BITWIDTH = "tl.config_index_bitwidth"
"""Bitwidth for configuration indices. Default: 32"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment