Unverified Commit 0c7e7419 authored by Tong WU, committed by GitHub

[Cleanup] Remove `tilelang.disable_cache()` calls from examples and tests (#1088)

* [Cleanup] Remove `tilelang.disable_cache()` calls from example scripts

* lint

* lint
parent 42c267e8
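For context: `tilelang.disable_cache()` disables TileLang's compiled-kernel cache, forcing a full JIT recompilation on every run. Leaving the cache enabled lets the examples and tests reuse previously built kernels, which is the point of removing these calls. A minimal sketch of how the cache could still be bypassed locally, opt-in rather than hard-coded into every script; the `TL_DISABLE_CACHE` environment variable name is a hypothetical choice for illustration, not something tilelang itself reads:

import os

import tilelang

# Opt-in guard: only disable the kernel cache when explicitly requested,
# e.g. while debugging a suspected stale-cache problem.
# TL_DISABLE_CACHE is an assumed variable name, not part of the tilelang API.
if os.environ.get("TL_DISABLE_CACHE") == "1":
    tilelang.disable_cache()  # forces recompilation on every run

The diff below removes the unconditional calls; per the hunk counts (8 lines shrinking to 6), each removal takes the call plus its adjacent blank line.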
@@ -11,8 +11,6 @@ import math
 from heuristic import num_splits_heuristic
-tilelang.disable_cache()
-
 def flashattn(batch, heads, heads_kv, dim, dim_v):
     scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
...
@@ -4,8 +4,6 @@ import tilelang.language as T
 from typing import Tuple
 from tilelang.utils.tensor import torch_assert_close
-tilelang.disable_cache()
-
 @tilelang.jit(out_idx=[1, 2])
 def per_token_cast_to_fp8(M, N, blk_m):
...
@@ -5,8 +5,6 @@ import tilelang.language as T
 from einops import rearrange, einsum
 import argparse
-tilelang.disable_cache()
-
 def get_configs():
     import itertools
...
@@ -5,8 +5,6 @@ import tilelang
 import tilelang.language as T
 from tilelang.autotuner import AutoTuner
-tilelang.disable_cache()
-
 def ref_program(x, y):
     return x + y
...
@@ -7,7 +7,6 @@ import argparse
 from einops import rearrange, repeat
 from bert_padding import pad_input, unpad_input
-# tilelang.disable_cache()
 torch.manual_seed(1)
...
@@ -24,8 +24,6 @@ import torch.nn.functional as F
 torch.random.manual_seed(0)
 # torch.set_printoptions(profile="full")
-tilelang.disable_cache()
-
 from utils import *
...
@@ -32,8 +32,6 @@ from utils import *
 torch.random.manual_seed(0)
-tilelang.disable_cache()
-
 def prepare_input(
     B,
...
@@ -19,8 +19,6 @@ import torch
 torch.random.manual_seed(1)
-tilelang.disable_cache()
-
 def prepare_input(
     B,
...
@@ -26,8 +26,6 @@ from utils import *
 torch.random.manual_seed(0)
 # torch.set_printoptions(profile="full")
-tilelang.disable_cache()
-
 def prepare_input_fake(
     B,
...
@@ -20,8 +20,6 @@ import torch
 torch.set_printoptions(profile="full")
 torch.random.manual_seed(0)
-tilelang.disable_cache()
-
 def prepare_input(
     B,
...
@@ -18,8 +18,6 @@ except ImportError:
 import torch
-tilelang.disable_cache()
-
 @tilelang.jit(
     out_idx=[-1],
...
@@ -19,8 +19,6 @@ import torch
 torch.random.manual_seed(1)
-tilelang.disable_cache()
-
 def prepare_input(B, S, H, DK, DV, chunk_size, input_dtype, output_dtype, gate_dtype=torch.float32):
     BS = chunk_size
...
@@ -22,8 +22,6 @@ import torch.nn.functional as F
 torch.random.manual_seed(0)
 torch.set_printoptions(profile="full")
-tilelang.disable_cache()
-
 def prepare_input_fake(
     B,
...
@@ -2,8 +2,6 @@ import torch
 import tilelang
 import tilelang.language as T
-tilelang.disable_cache()
-
 def matmul(
     M,
...
@@ -4,8 +4,6 @@ import argparse
 import tilelang
 import tilelang.language as T
-tilelang.disable_cache()
-
 @tilelang.jit(
     out_idx=[2], pass_configs={
...
@@ -12,8 +12,6 @@ import tilelang
 import tilelang.language as T
 from tilelang.profiler import do_bench
-tilelang.disable_cache()
-
 @tilelang.jit(out_idx=[3])
 def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_size):
...
@@ -6,8 +6,6 @@ import tilelang.language as T
 from einops import rearrange, einsum
 import argparse
-tilelang.disable_cache()
-
 @tilelang.jit(out_idx=[6])
 def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split):
...
 import tilelang
 import tilelang.language as T
-tilelang.disable_cache()
-
 # add decorator @tilelang.jit if you want to return a torch function
 # @tilelang.jit
...
@@ -50,8 +50,6 @@ def main(M=16384, N=16384, K=16384):
     jit_kernel = matmul(M, N, K, block_M, block_N, block_K)
-    tilelang.disable_cache()
-
     # 3. Test the kernel in Python with PyTorch data
     import torch
...
@@ -17,8 +17,6 @@ from triton.language.extra import libdevice
 import tilelang
 import tilelang.language as T
-tilelang.disable_cache()
-
 from tilelang.contrib import nvcc
 from tilelang.utils.target import determine_target
...