Unverified Commit e7e38355 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Refactor] Turn off `ENABLE_FAST_MATH` by default (#846)

* [Enhancement] Enable fast math optimization in tilelang JIT configurations

- Updated multiple examples and kernel functions to include `pass_configs` for enabling fast math optimization.
- Added support for the `TL_ENABLE_FAST_MATH` configuration option in the built-in operations.
- Enhanced the `LibraryGenerator` to handle the new fast math configuration, ensuring compatibility with existing settings.
- Updated documentation to reflect the changes in fast math handling and deprecation of the `TL_DISABLE_FAST_MATH` option.

* lint fix

* [Refactor] Introduce deprecated_warning utility for improved deprecation handling

- Added a new `deprecated_warning` function to streamline deprecation messages.
- Updated the `LibraryGenerator` to utilize the new function for warning about the deprecated `TL_DISABLE_FAST_MATH` configuration.
- Enhanced the `deprecated` decorator to support phaseout version messaging, improving clarity for users.
parent ebea77d9
...@@ -14,7 +14,10 @@ def get_configs(): ...@@ -14,7 +14,10 @@ def get_configs():
@autotune(configs=get_configs(), warmup=10, rep=10) @autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[3]) @tilelang.jit(
out_idx=[3], pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn(batch, def flashattn(batch,
heads, heads,
seq_len, seq_len,
......
...@@ -14,7 +14,10 @@ def get_configs(): ...@@ -14,7 +14,10 @@ def get_configs():
@autotune(configs=get_configs(), warmup=10, rep=10) @autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[3]) @tilelang.jit(
out_idx=[3], pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn(batch, def flashattn(batch,
heads, heads,
seq_len, seq_len,
......
...@@ -218,7 +218,10 @@ def attention_ref( ...@@ -218,7 +218,10 @@ def attention_ref(
return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
@tilelang.jit(out_idx=[6]) @tilelang.jit(
out_idx=[6], pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn(batch_size, def flashattn(batch_size,
UQ, UQ,
UKV, UKV,
......
...@@ -29,7 +29,10 @@ def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=F ...@@ -29,7 +29,10 @@ def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=F
return dense_mask return dense_mask
@tilelang.jit(out_idx=[4]) @tilelang.jit(
out_idx=[4], pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def blocksparse_flashattn(batch, heads, seq_q, seq_kv, dim, downsample_len, is_causal): def blocksparse_flashattn(batch, heads, seq_q, seq_kv, dim, downsample_len, is_causal):
block_M = 64 block_M = 64
block_N = 64 block_N = 64
......
#!/usr/bin/env bash
# Usage: # Usage:
# # Do work and commit your work. # # Do work and commit your work.
......
...@@ -25,6 +25,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION(kDisableDynamicTailSplit, Bool); ...@@ -25,6 +25,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION(kDisableDynamicTailSplit, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDynamicAlignment, Integer); TVM_REGISTER_PASS_CONFIG_OPTION(kDynamicAlignment, Integer);
TVM_REGISTER_PASS_CONFIG_OPTION(kEnableAggressiveSharedMemoryMerge, Bool); TVM_REGISTER_PASS_CONFIG_OPTION(kEnableAggressiveSharedMemoryMerge, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableFastMath, Bool); TVM_REGISTER_PASS_CONFIG_OPTION(kDisableFastMath, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kEnableFastMath, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kPtxasRegisterUsageLevel, Integer); TVM_REGISTER_PASS_CONFIG_OPTION(kPtxasRegisterUsageLevel, Integer);
TVM_REGISTER_PASS_CONFIG_OPTION(kEnablePTXASVerboseOutput, Bool); TVM_REGISTER_PASS_CONFIG_OPTION(kEnablePTXASVerboseOutput, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableShuffleElect, Bool); TVM_REGISTER_PASS_CONFIG_OPTION(kDisableShuffleElect, Bool);
......
...@@ -40,6 +40,7 @@ static constexpr const char *kConfigIndexBitwidth = "tl.config_index_bitwidth"; ...@@ -40,6 +40,7 @@ static constexpr const char *kConfigIndexBitwidth = "tl.config_index_bitwidth";
static constexpr const char *kEnableAggressiveSharedMemoryMerge = static constexpr const char *kEnableAggressiveSharedMemoryMerge =
"tl.enable_aggressive_shared_memory_merge"; "tl.enable_aggressive_shared_memory_merge";
static constexpr const char *kDisableFastMath = "tl.disable_fast_math"; static constexpr const char *kDisableFastMath = "tl.disable_fast_math";
static constexpr const char *kEnableFastMath = "tl.enable_fast_math";
static constexpr const char *kPtxasRegisterUsageLevel = static constexpr const char *kPtxasRegisterUsageLevel =
"tl.ptxas_register_usage_level"; "tl.ptxas_register_usage_level";
static constexpr const char *kEnablePTXASVerboseOutput = static constexpr const char *kEnablePTXASVerboseOutput =
......
...@@ -14,6 +14,7 @@ from tilelang.transform import PassConfigKey ...@@ -14,6 +14,7 @@ from tilelang.transform import PassConfigKey
from tilelang.contrib.nvcc import get_nvcc_compiler, get_target_arch, get_target_compute_version from tilelang.contrib.nvcc import get_nvcc_compiler, get_target_arch, get_target_compute_version
from tilelang.contrib.rocm import find_rocm_path, get_rocm_arch from tilelang.contrib.rocm import find_rocm_path, get_rocm_arch
from tilelang.env import TILELANG_TEMPLATE_PATH from tilelang.env import TILELANG_TEMPLATE_PATH
from tilelang.utils.deprecated import deprecated_warning
from .utils import is_cpu_target, is_cuda_target, is_hip_target from .utils import is_cpu_target, is_cuda_target, is_hip_target
...@@ -70,7 +71,17 @@ class LibraryGenerator(object): ...@@ -70,7 +71,17 @@ class LibraryGenerator(object):
target_arch = get_target_arch(get_target_compute_version(target)) target_arch = get_target_arch(get_target_compute_version(target))
libpath = src.name.replace(".cu", ".so") libpath = src.name.replace(".cu", ".so")
disable_fast_math = self.pass_configs.get(PassConfigKey.TL_DISABLE_FAST_MATH, False) if self.pass_configs.get(PassConfigKey.TL_DISABLE_FAST_MATH):
deprecated_warning(
"TL_DISABLE_FAST_MATH",
"TL_ENABLE_FAST_MATH",
"0.1.7",
)
enable_fast_math = not self.pass_configs.get(PassConfigKey.TL_DISABLE_FAST_MATH,
True)
else:
enable_fast_math = self.pass_configs.get(PassConfigKey.TL_ENABLE_FAST_MATH, False)
ptxas_usage_level = self.pass_configs.get(PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL, ptxas_usage_level = self.pass_configs.get(PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL,
None) None)
verbose_ptxas_output = self.pass_configs.get( verbose_ptxas_output = self.pass_configs.get(
...@@ -91,7 +102,7 @@ class LibraryGenerator(object): ...@@ -91,7 +102,7 @@ class LibraryGenerator(object):
"-gencode", "-gencode",
f"arch=compute_{target_arch},code=sm_{target_arch}", f"arch=compute_{target_arch},code=sm_{target_arch}",
] ]
if not disable_fast_math: if enable_fast_math:
command += ["--use_fast_math"] command += ["--use_fast_math"]
if ptxas_usage_level is not None: if ptxas_usage_level is not None:
command += [f"--ptxas-options=--register-usage-level={ptxas_usage_level}"] command += [f"--ptxas-options=--register-usage-level={ptxas_usage_level}"]
......
...@@ -19,7 +19,15 @@ class PassConfigKey(str, Enum): ...@@ -19,7 +19,15 @@ class PassConfigKey(str, Enum):
"""Disable warp specialization optimization. Default: False""" """Disable warp specialization optimization. Default: False"""
TL_DISABLE_FAST_MATH = "tl.disable_fast_math" TL_DISABLE_FAST_MATH = "tl.disable_fast_math"
"""Disable fast math optimization. Default: False""" """Disable fast math optimization. Default: True
will be deprecated in the 0.1.7 release
"""
TL_ENABLE_FAST_MATH = "tl.enable_fast_math"
"""
Enable fast math optimization. Default: False
if enabled, --use_fast_math will be passed to nvcc
"""
TL_PTXAS_REGISTER_USAGE_LEVEL = "tl.ptxas_register_usage_level" TL_PTXAS_REGISTER_USAGE_LEVEL = "tl.ptxas_register_usage_level"
"""The PTXAS register usage level in [0, 10], which controls the """The PTXAS register usage level in [0, 10], which controls the
......
def deprecated_warning(method_name: str, new_method_name: str, phaseout_version: str = None):
"""A function to indicate that a method is deprecated
"""
import warnings # pylint: disable=import-outside-toplevel, import-error
warnings.warn(
f"{method_name} is deprecated, use {new_method_name} instead" +
(f" and will be removed in {phaseout_version}" if phaseout_version else ""),
DeprecationWarning,
stacklevel=2,
)
def deprecated( def deprecated(
method_name: str, method_name: str,
new_method_name: str, new_method_name: str,
phaseout_version: str = None,
): ):
"""A decorator to indicate that a method is deprecated """A decorator to indicate that a method is deprecated
...@@ -10,19 +24,16 @@ def deprecated( ...@@ -10,19 +24,16 @@ def deprecated(
The name of the method to deprecate The name of the method to deprecate
new_method_name : str new_method_name : str
The name of the new method to use instead The name of the new method to use instead
phaseout_version : str
The version to phase out the method
""" """
import functools # pylint: disable=import-outside-toplevel import functools # pylint: disable=import-outside-toplevel
import warnings # pylint: disable=import-outside-toplevel
def _deprecate(func): def _deprecate(func):
@functools.wraps(func) @functools.wraps(func)
def _wrapper(*args, **kwargs): def _wrapper(*args, **kwargs):
warnings.warn( deprecated_warning(method_name, new_method_name, phaseout_version)
f"{method_name} is deprecated, use {new_method_name} instead",
DeprecationWarning,
stacklevel=2,
)
return func(*args, **kwargs) return func(*args, **kwargs)
return _wrapper return _wrapper
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment