Unverified commit e7e38355 authored by Lei Wang, committed by GitHub

[Refactor] Turn off `ENABLE_FAST_MATH` by default (#846)

* [Enhancement] Enable fast math optimization in tilelang JIT configurations

- Updated multiple examples and kernel functions to include `pass_configs` for enabling fast math optimization (see the sketch after this list).
- Added support for the `TL_ENABLE_FAST_MATH` configuration option in the built-in operations.
- Enhanced the `LibraryGenerator` to handle the new fast math configuration, ensuring compatibility with existing settings.
- Updated documentation to reflect the changes in fast math handling and deprecation of the `TL_DISABLE_FAST_MATH` option.
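
A minimal sketch of the opt-in pattern this PR establishes: fast math is now off by default, so a kernel enables it per-JIT via `pass_configs`. The decorator usage mirrors the diffs below; the tiny elementwise kernel itself (`exp_kernel`, shapes, block size) is illustrative and not part of this PR.

```python
import tilelang
import tilelang.language as T


# Fast math is off by default after this change; opt in per kernel.
@tilelang.jit(
    out_idx=[1], pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    })
def exp_kernel(M, block_M=128, dtype="float16"):

    @T.prim_func
    def main(A: T.Tensor([M], dtype), B: T.Tensor([M], dtype)):
        # Assumes M % block_M == 0 to keep the sketch short.
        with T.Kernel(T.ceildiv(M, block_M), threads=block_M) as bx:
            for i in T.Parallel(block_M):
                B[bx * block_M + i] = T.exp(A[bx * block_M + i])

    return main
```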

* lint fix

* [Refactor] Introduce deprecated_warning utility for improved deprecation handling

- Added a new `deprecated_warning` function to streamline deprecation messages.
- Updated the `LibraryGenerator` to utilize the new function for warning about the deprecated `TL_DISABLE_FAST_MATH` configuration.
- Enhanced the `deprecated` decorator to support phase-out version messaging, improving clarity for users (see the sketch below).
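
A behavior sketch of the deprecation utilities described above. The names (`deprecated_warning`, `deprecated`, phase-out version messaging) follow the commit message, but the bodies here are assumptions for illustration, not the PR's actual code.

```python
import functools
import warnings
from typing import Optional


def deprecated_warning(message: str) -> None:
    """Emit a DeprecationWarning pointing at the caller's call site."""
    warnings.warn(message, DeprecationWarning, stacklevel=2)


def deprecated(reason: str, phaseout_version: Optional[str] = None):
    """Mark a callable as deprecated, optionally naming the removal version."""

    def decorator(func):

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            msg = f"{func.__name__} is deprecated: {reason}"
            if phaseout_version is not None:
                msg += f" It will be removed in version {phaseout_version}."
            deprecated_warning(msg)
            return func(*args, **kwargs)

        return wrapper

    return decorator
```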
parent ebea77d9
@@ -29,7 +29,10 @@ def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=F
     return dense_mask
 
 
-@tilelang.jit(out_idx=[4])
+@tilelang.jit(
+    out_idx=[4], pass_configs={
+        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+    })
 def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal):
     block_M = 64
     block_N = 64
...
@@ -20,7 +20,10 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
     accum_dtype = "float"
     kv_group_num = heads // heads_kv
 
-    @tilelang.jit(out_idx=[-1])
+    @tilelang.jit(
+        out_idx=[-1], pass_configs={
+            tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+        })
     def kernel_func(block_N, block_H, page_block_size, num_split, num_stages, threads, num_pages,
                     max_num_blocks_per_seq, max_selected_blocks):
         shape_q = [batch, heads, dim]
...
@@ -15,7 +15,10 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
     accum_dtype = "float"
     kv_group_num = heads // heads_kv
 
-    @tilelang.jit(out_idx=[-1])
+    @tilelang.jit(
+        out_idx=[-1], pass_configs={
+            tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+        })
     def kernel_func(block_N, block_H, num_split, num_stages, threads, max_cache_seqlen,
                     max_selected_blocks):
         shape_q = [batch, heads, dim]
...
@@ -17,7 +17,10 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
     accum_dtype = "float"
     kv_group_num = heads // heads_kv
 
-    @tilelang.jit(out_idx=[-1])
+    @tilelang.jit(
+        out_idx=[-1], pass_configs={
+            tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+        })
     def kernel_func(block_N, block_H, num_split, num_stages, threads, max_cache_seqlen, num_blocks):
         shape_q = [batch, heads, dim]
         shape_k = [batch, max_cache_seqlen, heads_kv, dim]
...
@@ -9,7 +9,10 @@ import argparse
 tilelang.disable_cache()
 
 
-@tilelang.jit(out_idx=[6])
+@tilelang.jit(
+    out_idx=[6], pass_configs={
+        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+    })
 def flashmla_decode(batch,
                     heads,
                     kv_head_num,
...
@@ -7,7 +7,10 @@ from einops import rearrange, einsum
 import argparse
 
 
-@tilelang.jit(out_idx=[6])
+@tilelang.jit(
+    out_idx=[6], pass_configs={
+        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+    })
 def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split,
               softmax_scale):
     scale = float(softmax_scale * 1.44269504)  # log2(e)
...
@@ -7,7 +7,10 @@ from tilelang.profiler import do_bench
 import math
 
 
-@tilelang.jit(out_idx=[8])
+@tilelang.jit(
+    out_idx=[8], pass_configs={
+        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+    })
 def mla_decode_tilelang(batch, h_q, h_kv, max_seqlen_pad, dv, dpe, block_N, block_H, num_split,
                         block_size, softmax_scale):
     scale = float(softmax_scale * 1.44269504)  # log2(e)
...
@@ -8,7 +8,10 @@ from einops import rearrange, einsum
 import argparse
 
 
-@tilelang.jit(out_idx=[6])
+@tilelang.jit(
+    out_idx=[6], pass_configs={
+        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+    })
 def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split):
     scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504  # log2(e)
     dtype = "float16"
...
@@ -7,7 +7,10 @@ from einops import rearrange, einsum
 import argparse
 
 
-@tilelang.jit(out_idx=[-1])
+@tilelang.jit(
+    out_idx=[-1], pass_configs={
+        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+    })
 def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H):
     scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504  # log2(e)
     dtype = "float16"
...
@@ -17,7 +17,9 @@ from einops import rearrange
 import tilelang
 
 
-@tilelang.jit
+@tilelang.jit(pass_configs={
+    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+})
 def tilelang_kernel_fwd(
     batch,
     heads,
@@ -150,7 +152,9 @@ def tilelang_kernel_fwd(
     return native_sparse_attention
 
 
-@tilelang.jit
+@tilelang.jit(pass_configs={
+    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+})
 def tilelang_kernel_bwd_dkv(
     batch,
     heads,
@@ -314,7 +318,9 @@ def make_dq_layout(dQ):
     )
 
 
-@tilelang.jit
+@tilelang.jit(pass_configs={
+    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+})
 def tilelang_kernel_bwd_dqkv(
     batch,
     heads,
@@ -477,7 +483,10 @@ def tilelang_kernel_bwd_dqkv(
     return flash_bwd_dqkv
 
 
-@tilelang.jit(out_idx=[2])
+@tilelang.jit(
+    out_idx=[2], pass_configs={
+        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+    })
 def tilelang_kernel_preprocess(
     batch,
     heads,
@@ -514,7 +523,10 @@ def tilelang_kernel_preprocess(
     return flash_bwd_prep
 
 
-@tilelang.jit(out_idx=[2])
+@tilelang.jit(
+    out_idx=[2], pass_configs={
+        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+    })
 def tilelang_kernel_block_mask(
     batch,
     heads,
...
@@ -15,6 +15,7 @@ tilelang.testing.set_random_seed(42)
     pass_configs={
         tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
         tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
+        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
     })
 def native_sparse_attention(
     batch,
...
@@ -8,7 +8,10 @@ import tilelang.testing
 tilelang.testing.set_random_seed(0)
 
 
-@tilelang.jit(out_idx=[-1])
+@tilelang.jit(
+    out_idx=[-1], pass_configs={
+        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+    })
 def native_sparse_attention(batch,
                             heads,
                             seq_len,
...
@@ -16,7 +16,9 @@ from reference import naive_nsa
 from einops import rearrange
 
 
-@tilelang.jit
+@tilelang.jit(pass_configs={
+    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+})
 def native_sparse_attention_varlen(batch,
                                    heads,
                                    c_seq_len,
...
@@ -5,7 +5,10 @@ import tilelang.language as T
 import argparse
 
 
-@tilelang.jit(out_idx=[3, 4])
+@tilelang.jit(
+    out_idx=[3, 4], pass_configs={
+        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+    })
 def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1):
     scale = (1.0 / dim_qk)**0.5 * 1.44269504  # log2(e)
     head_kv = heads // groups
@@ -77,7 +80,10 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc
     return flash_fwd
 
 
-@tilelang.jit(out_idx=[2])
+@tilelang.jit(
+    out_idx=[2], pass_configs={
+        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+    })
 def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v):
     dtype = "float16"
     accum_dtype = "float"
@@ -113,7 +119,10 @@ def make_dq_layout(dQ):
                     lambda b, l, h, d: [b, l // 8, h, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2])
 
 
-@tilelang.jit(out_idx=[1])
+@tilelang.jit(
+    out_idx=[1], pass_configs={
+        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+    })
 def flashattn_bwd_postprocess(batch, heads, seq_len, dim_qk):
     dtype = "float16"
     accum_dtype = "float"
@@ -135,7 +144,9 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim_qk):
     return flash_bwd_post
 
 
-@tilelang.jit
+@tilelang.jit(pass_configs={
+    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+})
 def flashattn_bwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1):
     sm_scale = (1.0 / dim_qk)**0.5
     scale = (1.0 / dim_qk)**0.5 * 1.44269504  # log2(e)
...
@@ -58,7 +58,10 @@ def get_configs(user_config=None):
 
 
 @autotune(configs=get_configs(), warmup=10, rep=10)
-@tilelang.jit(out_idx=[3])
+@tilelang.jit(
+    out_idx=[3], pass_configs={
+        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+    })
 def flashattn(batch,
               heads,
               seq_len,
...
@@ -23,7 +23,10 @@ def get_configs():
     warmup=10,
     rep=10,
 )
-@tilelang.jit(out_idx=[3])
+@tilelang.jit(
+    out_idx=[3], pass_configs={
+        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+    })
 def flashattn(
     batch,
     heads,
...
@@ -6,7 +6,10 @@ import tilelang.language as T
 import argparse
 
 
-@tilelang.jit(out_idx=[3, 4])
+@tilelang.jit(
+    out_idx=[3, 4], pass_configs={
+        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+    })
 def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
     scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
     shape = [batch, seq_len, heads, dim]
@@ -79,7 +82,10 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
     return flash_fwd
 
 
-@tilelang.jit(out_idx=[2])
+@tilelang.jit(
+    out_idx=[2], pass_configs={
+        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+    })
 def flashattn_bwd_preprocess(batch, heads, seq_len, dim):
     dtype = "float16"
     accum_dtype = "float"
@@ -115,7 +121,10 @@ def make_dq_layout(dQ):
                     lambda b, l, h, d: [b, l // 8, h, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2])
 
 
-@tilelang.jit(out_idx=[1])
+@tilelang.jit(
+    out_idx=[1], pass_configs={
+        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+    })
 def flashattn_bwd_postprocess(batch, heads, seq_len, dim):
     dtype = "float16"
     accum_dtype = "float"
@@ -137,7 +146,9 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim):
     return flash_bwd_post
 
 
-@tilelang.jit
+@tilelang.jit(pass_configs={
+    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+})
 def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
     sm_scale = (1.0 / dim)**0.5
     scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
...
@@ -6,7 +6,10 @@ import tilelang.language as T
 import argparse
 
 
-@tilelang.jit(out_idx=[3, 4])
+@tilelang.jit(
+    out_idx=[3, 4], pass_configs={
+        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+    })
 def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
     scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
     shape = [batch, seq_len, heads, dim]
@@ -76,7 +79,10 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
     return flash_fwd
 
 
-@tilelang.jit(out_idx=[2])
+@tilelang.jit(
+    out_idx=[2], pass_configs={
+        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+    })
 def flashattn_bwd_preprocess(batch, heads, seq_len, dim):
     dtype = "float16"
     accum_dtype = "float"
@@ -112,7 +118,10 @@ def make_dq_layout(dQ):
                     lambda b, l, h, d: [b, l // 8, h, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2])
 
 
-@tilelang.jit(out_idx=[1])
+@tilelang.jit(
+    out_idx=[1], pass_configs={
+        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+    })
 def flashattn_bwd_postprocess(batch, heads, seq_len, dim):
     dtype = "float16"
     accum_dtype = "float"
@@ -134,7 +143,9 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim):
     return flash_bwd_post
 
 
-@tilelang.jit
+@tilelang.jit(pass_configs={
+    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+})
 def flashattn_bwd(batch, heads, seq_len, dim, is_casual, block_M, block_N):
     sm_scale = (1.0 / dim)**0.5
     scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
...
@@ -14,7 +14,10 @@ def get_configs():
 
 
 @autotune(configs=get_configs(), warmup=10, rep=10)
-@tilelang.jit(out_idx=[3])
+@tilelang.jit(
+    out_idx=[3], pass_configs={
+        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+    })
 def flashattn(batch,
               heads,
               seq_q,
...
@@ -14,7 +14,10 @@ def get_configs():
 
 
 @autotune(configs=get_configs(), warmup=10, rep=10)
-@tilelang.jit(out_idx=[3])
+@tilelang.jit(
+    out_idx=[3], pass_configs={
+        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+    })
 def flashattn(batch,
               heads,
               seq_q,
...