"...git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "f84a17b86e2c546883da35b41d2e0d4f466d7c27"
Unverified commit e7e38355, authored by Lei Wang and committed via GitHub

[Refactor] Turn off `ENABLE_FAST_MATH` by default (#846)

* [Enhancement] Enable fast math optimization in tilelang JIT configurations

- Updated multiple examples and kernel functions to include `pass_configs` for enabling fast math optimization (see the decorator sketch below).
- Added support for the `TL_ENABLE_FAST_MATH` configuration option in the built-in operations.
- Enhanced the `LibraryGenerator` to handle the new fast math configuration, ensuring compatibility with existing settings.
- Updated documentation to reflect the changes in fast math handling and deprecation of the `TL_DISABLE_FAST_MATH` option.
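
Since fast math is now off by default, examples that still rely on it opt in per kernel through `pass_configs`, as the diffs below show. The following is a minimal, hedged sketch of that pattern; the tiny elementwise kernel is purely illustrative (not part of this commit), and only the decorator usage mirrors the committed changes:

```python
import tilelang
import tilelang.language as T


# Opt this one kernel back into fast math now that the global default is off.
@tilelang.jit(
    out_idx=[1], pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    })
def exp_kernel(N, block_N=256, dtype="float16"):
    # N is assumed to be a multiple of block_N to keep the sketch short.

    @T.prim_func
    def main(X: T.Tensor((N,), dtype), Y: T.Tensor((N,), dtype)):
        with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx:
            for i in T.Parallel(block_N):
                Y[bx * block_N + i] = T.exp(X[bx * block_N + i])

    return main


# Hypothetical usage: y = exp_kernel(4096)(x) for a CUDA float16 tensor x;
# out_idx=[1] lets the JIT wrapper allocate and return Y.
```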

* lint fix

* [Refactor] Introduce deprecated_warning utility for improved deprecation handling

- Added a new `deprecated_warning` function to streamline deprecation messages (a hedged sketch follows this list).
- Updated the `LibraryGenerator` to utilize the new function for warning about the deprecated `TL_DISABLE_FAST_MATH` configuration.
- Enhanced the `deprecated` decorator to support phase-out version messaging, improving clarity for users.
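
The commit adds `deprecated_warning` to tilelang's utilities; the snippet below is only a hedged sketch of what such a helper could look like, with an assumed signature and a hypothetical call site (the committed implementation may differ):

```python
import warnings


def deprecated_warning(message: str, stacklevel: int = 2) -> None:
    # Funnel all deprecation messages through one place so they share a format.
    warnings.warn(message, DeprecationWarning, stacklevel=stacklevel)


# Hypothetical call site, mirroring the LibraryGenerator change described above:
# deprecated_warning("TL_DISABLE_FAST_MATH is deprecated; "
#                    "set TL_ENABLE_FAST_MATH explicitly instead.")
```
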
Parent commit: ebea77d9
@@ -29,7 +29,10 @@ def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=F
     return dense_mask


-@tilelang.jit(out_idx=[4])
+@tilelang.jit(
+    out_idx=[4], pass_configs={
+        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+    })
 def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal):
     block_M = 64
     block_N = 64
......
@@ -20,7 +20,10 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
     accum_dtype = "float"
     kv_group_num = heads // heads_kv

-    @tilelang.jit(out_idx=[-1])
+    @tilelang.jit(
+        out_idx=[-1], pass_configs={
+            tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+        })
     def kernel_func(block_N, block_H, page_block_size, num_split, num_stages, threads, num_pages,
                     max_num_blocks_per_seq, max_selected_blocks):
         shape_q = [batch, heads, dim]
......
@@ -15,7 +15,10 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
     accum_dtype = "float"
     kv_group_num = heads // heads_kv

-    @tilelang.jit(out_idx=[-1])
+    @tilelang.jit(
+        out_idx=[-1], pass_configs={
+            tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+        })
     def kernel_func(block_N, block_H, num_split, num_stages, threads, max_cache_seqlen,
                     max_selected_blocks):
         shape_q = [batch, heads, dim]
......
@@ -17,7 +17,10 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
     accum_dtype = "float"
     kv_group_num = heads // heads_kv

-    @tilelang.jit(out_idx=[-1])
+    @tilelang.jit(
+        out_idx=[-1], pass_configs={
+            tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+        })
     def kernel_func(block_N, block_H, num_split, num_stages, threads, max_cache_seqlen, num_blocks):
         shape_q = [batch, heads, dim]
         shape_k = [batch, max_cache_seqlen, heads_kv, dim]
......
@@ -9,7 +9,10 @@ import argparse
 tilelang.disable_cache()


-@tilelang.jit(out_idx=[6])
+@tilelang.jit(
+    out_idx=[6], pass_configs={
+        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+    })
 def flashmla_decode(batch,
                     heads,
                     kv_head_num,
......
@@ -7,7 +7,10 @@ from einops import rearrange, einsum
 import argparse


-@tilelang.jit(out_idx=[6])
+@tilelang.jit(
+    out_idx=[6], pass_configs={
+        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+    })
 def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split,
               softmax_scale):
     scale = float(softmax_scale * 1.44269504)  # log2(e)
......
@@ -7,7 +7,10 @@ from tilelang.profiler import do_bench
 import math


-@tilelang.jit(out_idx=[8])
+@tilelang.jit(
+    out_idx=[8], pass_configs={
+        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+    })
 def mla_decode_tilelang(batch, h_q, h_kv, max_seqlen_pad, dv, dpe, block_N, block_H, num_split,
                         block_size, softmax_scale):
     scale = float(softmax_scale * 1.44269504)  # log2(e)
......
@@ -8,7 +8,10 @@ from einops import rearrange, einsum
 import argparse


-@tilelang.jit(out_idx=[6])
+@tilelang.jit(
+    out_idx=[6], pass_configs={
+        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+    })
 def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split):
     scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504  # log2(e)
     dtype = "float16"
......
@@ -7,7 +7,10 @@ from einops import rearrange, einsum
 import argparse


-@tilelang.jit(out_idx=[-1])
+@tilelang.jit(
+    out_idx=[-1], pass_configs={
+        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+    })
 def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H):
     scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504  # log2(e)
     dtype = "float16"
......
@@ -17,7 +17,9 @@ from einops import rearrange
 import tilelang


-@tilelang.jit
+@tilelang.jit(pass_configs={
+    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+})
 def tilelang_kernel_fwd(
         batch,
         heads,
@@ -150,7 +152,9 @@ def tilelang_kernel_fwd(
     return native_sparse_attention


-@tilelang.jit
+@tilelang.jit(pass_configs={
+    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+})
 def tilelang_kernel_bwd_dkv(
         batch,
         heads,
@@ -314,7 +318,9 @@ def make_dq_layout(dQ):
     )


-@tilelang.jit
+@tilelang.jit(pass_configs={
+    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+})
 def tilelang_kernel_bwd_dqkv(
         batch,
         heads,
@@ -477,7 +483,10 @@ def tilelang_kernel_bwd_dqkv(
     return flash_bwd_dqkv


-@tilelang.jit(out_idx=[2])
+@tilelang.jit(
+    out_idx=[2], pass_configs={
+        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+    })
 def tilelang_kernel_preprocess(
         batch,
         heads,
@@ -514,7 +523,10 @@ def tilelang_kernel_preprocess(
     return flash_bwd_prep


-@tilelang.jit(out_idx=[2])
+@tilelang.jit(
+    out_idx=[2], pass_configs={
+        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+    })
 def tilelang_kernel_block_mask(
         batch,
         heads,
......
@@ -15,6 +15,7 @@ tilelang.testing.set_random_seed(42)
     pass_configs={
         tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
         tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
+        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
     })
 def native_sparse_attention(
         batch,
......
@@ -8,7 +8,10 @@ import tilelang.testing
 tilelang.testing.set_random_seed(0)


-@tilelang.jit(out_idx=[-1])
+@tilelang.jit(
+    out_idx=[-1], pass_configs={
+        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+    })
 def native_sparse_attention(batch,
                             heads,
                             seq_len,
......
@@ -16,7 +16,9 @@ from reference import naive_nsa
 from einops import rearrange


-@tilelang.jit
+@tilelang.jit(pass_configs={
+    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+})
 def native_sparse_attention_varlen(batch,
                                    heads,
                                    c_seq_len,
......
@@ -5,7 +5,10 @@ import tilelang.language as T
 import argparse


-@tilelang.jit(out_idx=[3, 4])
+@tilelang.jit(
+    out_idx=[3, 4], pass_configs={
+        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+    })
 def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1):
     scale = (1.0 / dim_qk)**0.5 * 1.44269504  # log2(e)
     head_kv = heads // groups
@@ -77,7 +80,10 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc
     return flash_fwd


-@tilelang.jit(out_idx=[2])
+@tilelang.jit(
+    out_idx=[2], pass_configs={
+        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+    })
 def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v):
     dtype = "float16"
     accum_dtype = "float"
@@ -113,7 +119,10 @@ def make_dq_layout(dQ):
                     lambda b, l, h, d: [b, l // 8, h, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2])


-@tilelang.jit(out_idx=[1])
+@tilelang.jit(
+    out_idx=[1], pass_configs={
+        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+    })
 def flashattn_bwd_postprocess(batch, heads, seq_len, dim_qk):
     dtype = "float16"
     accum_dtype = "float"
@@ -135,7 +144,9 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim_qk):
     return flash_bwd_post


-@tilelang.jit
+@tilelang.jit(pass_configs={
+    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+})
 def flashattn_bwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1):
     sm_scale = (1.0 / dim_qk)**0.5
     scale = (1.0 / dim_qk)**0.5 * 1.44269504  # log2(e)
......
@@ -58,7 +58,10 @@ def get_configs(user_config=None):


 @autotune(configs=get_configs(), warmup=10, rep=10)
-@tilelang.jit(out_idx=[3])
+@tilelang.jit(
+    out_idx=[3], pass_configs={
+        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+    })
 def flashattn(batch,
               heads,
               seq_len,
......
@@ -23,7 +23,10 @@ def get_configs():
     warmup=10,
     rep=10,
 )
-@tilelang.jit(out_idx=[3])
+@tilelang.jit(
+    out_idx=[3], pass_configs={
+        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+    })
 def flashattn(
         batch,
         heads,
......
@@ -6,7 +6,10 @@ import tilelang.language as T
 import argparse


-@tilelang.jit(out_idx=[3, 4])
+@tilelang.jit(
+    out_idx=[3, 4], pass_configs={
+        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+    })
 def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
     scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
     shape = [batch, seq_len, heads, dim]
@@ -79,7 +82,10 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
     return flash_fwd


-@tilelang.jit(out_idx=[2])
+@tilelang.jit(
+    out_idx=[2], pass_configs={
+        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+    })
 def flashattn_bwd_preprocess(batch, heads, seq_len, dim):
     dtype = "float16"
     accum_dtype = "float"
@@ -115,7 +121,10 @@ def make_dq_layout(dQ):
                     lambda b, l, h, d: [b, l // 8, h, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2])


-@tilelang.jit(out_idx=[1])
+@tilelang.jit(
+    out_idx=[1], pass_configs={
+        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+    })
 def flashattn_bwd_postprocess(batch, heads, seq_len, dim):
     dtype = "float16"
     accum_dtype = "float"
@@ -137,7 +146,9 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim):
     return flash_bwd_post


-@tilelang.jit
+@tilelang.jit(pass_configs={
+    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+})
 def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
     sm_scale = (1.0 / dim)**0.5
     scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
......
@@ -6,7 +6,10 @@ import tilelang.language as T
 import argparse


-@tilelang.jit(out_idx=[3, 4])
+@tilelang.jit(
+    out_idx=[3, 4], pass_configs={
+        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+    })
 def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
     scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
     shape = [batch, seq_len, heads, dim]
@@ -76,7 +79,10 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
     return flash_fwd


-@tilelang.jit(out_idx=[2])
+@tilelang.jit(
+    out_idx=[2], pass_configs={
+        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+    })
 def flashattn_bwd_preprocess(batch, heads, seq_len, dim):
     dtype = "float16"
     accum_dtype = "float"
@@ -112,7 +118,10 @@ def make_dq_layout(dQ):
                     lambda b, l, h, d: [b, l // 8, h, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2])


-@tilelang.jit(out_idx=[1])
+@tilelang.jit(
+    out_idx=[1], pass_configs={
+        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+    })
 def flashattn_bwd_postprocess(batch, heads, seq_len, dim):
     dtype = "float16"
     accum_dtype = "float"
@@ -134,7 +143,9 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim):
     return flash_bwd_post


-@tilelang.jit
+@tilelang.jit(pass_configs={
+    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+})
 def flashattn_bwd(batch, heads, seq_len, dim, is_casual, block_M, block_N):
     sm_scale = (1.0 / dim)**0.5
     scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
......
@@ -14,7 +14,10 @@ def get_configs():


 @autotune(configs=get_configs(), warmup=10, rep=10)
-@tilelang.jit(out_idx=[3])
+@tilelang.jit(
+    out_idx=[3], pass_configs={
+        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+    })
 def flashattn(batch,
               heads,
               seq_q,
......
@@ -14,7 +14,10 @@ def get_configs():


 @autotune(configs=get_configs(), warmup=10, rep=10)
-@tilelang.jit(out_idx=[3])
+@tilelang.jit(
+    out_idx=[3], pass_configs={
+        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+    })
 def flashattn(batch,
               heads,
               seq_q,
......