"git@developer.sourcefind.cn:OpenDAS/opencompass.git" did not exist on "1013dce60c198ff5217a5e4d8c384190f54bdc42"
Commit 889451eb authored by Yu Cheng, committed by LeiWang1999

[Refactor] Update kernel compilation and profiling in examples (#225)

- Replaced instances of `tilelang.lower` and `tilelang.Profiler` with `tilelang.compile` and the new profiler interface in multiple example files (see the migration sketch below).
- Enhanced the kernel compilation process to utilize the updated API, improving consistency and maintainability.
- Adjusted benchmarking logic to use the new profiler methods for better clarity and functionality in performance testing.
- Cleaned up whitespace and improved formatting for better readability across the modified files.
parent c5bbc608
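
As a reading aid, here is the before/after pattern that this commit applies in every example — a minimal sketch distilled from the hunks below, not taken verbatim from any one file. `program`, `ref_program`, and `total_flops` are placeholders for each example's own definitions, and `out_idx=[6]` copies the first hunk (other files use `[3]` or `[8]`).

```python
import tilelang

# Old interface, removed by this commit:
#   mod, params = tilelang.lower(program)
#   mod = tilelang.Profiler(mod, params, [6], tilelang.TensorSupplyType.Randn)
#   mod.assert_allclose(ref_program, rtol=0.01, atol=0.01)
#   latency = mod.do_bench(mod.func, warmup=500)

# New interface: compile the program into a kernel, then ask the kernel
# for a profiler. `program`, `ref_program`, and `total_flops` stand in
# for the per-example definitions.
kernel = tilelang.compile(program, out_idx=[6])
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)

profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)  # correctness vs. reference
latency = profiler.do_bench(warmup=500)                      # times the compiled kernel
ref_latency = profiler.do_bench(ref_program, warmup=500)     # times the reference impl
print(f"Latency: {latency} ms")
print(f"TFlops: {total_flops / latency * 1e-9}")
```

Note the benchmarking change in particular: `profiler.do_bench(warmup=500)` with no callable argument now times the compiled kernel itself, replacing the old `mod.do_bench(mod.func, ...)` and `profiler.do_bench(profiler.mod, ...)` forms.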
@@ -288,10 +288,9 @@ if __name__ == "__main__":
     num_split = 1
     program = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split)
-    mod, params = tilelang.lower(program)
-    mod = tilelang.Profiler(mod, params, [6], tilelang.TensorSupplyType.Randn)
-    mod.assert_allclose(ref_program, rtol=0.01, atol=0.01)
-    print("All close")
-    latency = mod.do_bench(mod.func, n_warmup=10, n_repeat=10, profiler="torch")
-    print("Tile-lang: {:.2f} ms".format(latency))
-    print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
+    kernel = tilelang.compile(program, out_idx=[6])
+    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
+    profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
+    latency = profiler.do_bench(warmup=500)
+    print(f"Latency: {latency} ms")
+    print(f"TFlops: {total_flops / latency * 1e-9} TFlops")

 import torch
-import torch.nn.functional as F
 import tilelang
 from tilelang.autotuner import *
 import tilelang.language as T
-from einops import rearrange, einsum
 import argparse
 from tilelang.profiler import do_bench
 import math

-def mla_decode_tilelang(batch, h_q, h_kv, max_seqlen_pad, dv, dpe, block_N, block_H, num_split, block_size):
+def mla_decode_tilelang(batch, h_q, h_kv, max_seqlen_pad, dv, dpe, block_N, block_H, num_split,
+                        block_size):
     scale = (1.0 / (dv + dpe))**0.5 * 1.44269504  # log2(e)
     dtype = "float16"
     accum_dtype = "float"
@@ -59,7 +59,8 @@ def mla_decode_tilelang(batch, h_q, h_kv, max_seqlen_pad, dv, dpe, block_N, bloc
             loop_range = T.ceildiv(CACHE_SEQLENS[bx], block_N)
             for kr in T.Pipelined(loop_range, num_stages=2):
                 k = loop_range - 1 - kr
-                kv_start = BLOCK_TABLE[bx, (k * block_N) // block_size] * block_size + (k * block_N) % block_size
+                kv_start = BLOCK_TABLE[bx, (k * block_N) //
+                                       block_size] * block_size + (k * block_N) % block_size
                 T.copy(KV[kv_start:kv_start + block_N, cur_kv_head, :], KV_shared)
                 T.copy(K_pe[kv_start:kv_start + block_N, cur_kv_head, :], K_pe_shared)
                 T.clear(acc_s)
@@ -75,7 +76,8 @@ def mla_decode_tilelang(batch, h_q, h_kv, max_seqlen_pad, dv, dpe, block_N, bloc
                 T.fill(scores_max, -T.infinity(accum_dtype))
                 if kr == 0:
                     for i, j in T.Parallel(block_H, block_N):
-                        acc_s[i, j] = T.if_then_else(k * block_N + j >= CACHE_SEQLENS[bx], -T.infinity(accum_dtype), acc_s[i, j])
+                        acc_s[i, j] = T.if_then_else(k * block_N + j >= CACHE_SEQLENS[bx],
+                                                     -T.infinity(accum_dtype), acc_s[i, j])
                 T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                 for i in T.Parallel(block_H):
                     scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
@@ -105,7 +107,8 @@ def mla_decode_tilelang(batch, h_q, h_kv, max_seqlen_pad, dv, dpe, block_N, bloc
             glse: T.Buffer([batch, h_q, num_split], dtype),
             Output_partial: T.Buffer([batch, h_q, num_split, dv], dtype),
     ):
-        with T.Kernel(batch, h_q // min(block_H, kv_group_num), num_split, threads=256) as (bx, by, bz):
+        with T.Kernel(
+                batch, h_q // min(block_H, kv_group_num), num_split, threads=256) as (bx, by, bz):
             Q_shared = T.alloc_shared([block_H, dv], dtype)
             S_shared = T.alloc_shared([block_H, block_N], dtype)
             Q_pe_shared = T.alloc_shared([block_H, dpe], dtype)
@@ -141,7 +144,8 @@ def mla_decode_tilelang(batch, h_q, h_kv, max_seqlen_pad, dv, dpe, block_N, bloc
             start = (blocks_per_split * bz + T.min(bz, remaining_blocks)) * block_N
             for k in T.Pipelined(loop_range, num_stages=2):
-                kv_start = BLOCK_TABLE[bx, (start + k * block_N) // block_size] * block_size + (k * block_N) % block_size
+                kv_start = BLOCK_TABLE[bx, (start + k * block_N) //
+                                       block_size] * block_size + (k * block_N) % block_size
                 T.copy(KV[kv_start:kv_start + block_N, cur_kv_head, :], KV_shared)
                 T.copy(K_pe[kv_start:kv_start + block_N, cur_kv_head, :], K_pe_shared)
                 T.clear(acc_s)
@@ -156,7 +160,8 @@ def mla_decode_tilelang(batch, h_q, h_kv, max_seqlen_pad, dv, dpe, block_N, bloc
                 T.copy(scores_max, scores_max_prev)
                 T.fill(scores_max, -T.infinity(accum_dtype))
                 for i, j in T.Parallel(block_H, block_N):
-                    acc_s[i, j] = T.if_then_else(start + k * block_N + j >= CACHE_SEQLENS[bx], -T.infinity(accum_dtype), acc_s[i, j])
+                    acc_s[i, j] = T.if_then_else(start + k * block_N + j >= CACHE_SEQLENS[bx],
+                                                 -T.infinity(accum_dtype), acc_s[i, j])
                 T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                 for i in T.Parallel(block_H):
                     scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
@@ -227,7 +232,8 @@ def mla_decode_tilelang(batch, h_q, h_kv, max_seqlen_pad, dv, dpe, block_N, bloc
             Output_partial: T.Buffer([batch, h_q, num_split, dv], dtype),
             Output: T.Buffer([batch, h_q, dv], dtype),
     ):
-        flash_mla_split_kv_kernel(Q, Q_pe, KV, K_pe, block_table, cache_seqlens, glse, Output_partial)
+        flash_mla_split_kv_kernel(Q, Q_pe, KV, K_pe, block_table, cache_seqlens, glse,
+                                  Output_partial)
         combine(glse, Output_partial, Output)

     @T.prim_func
@@ -249,6 +255,7 @@ def mla_decode_tilelang(batch, h_q, h_kv, max_seqlen_pad, dv, dpe, block_N, bloc
     else:
         return main_no_split

+
 def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
     query = query.float()
     key = key.float()
@@ -260,7 +267,8 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
         s_q = query.shape[-2]
         s_k = key.shape[-2]
         attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype, device=query.device)
-        temp_mask = torch.ones(s_q, s_k, dtype=torch.bool, device=query.device).tril(diagonal=s_k - s_q)
+        temp_mask = torch.ones(
+            s_q, s_k, dtype=torch.bool, device=query.device).tril(diagonal=s_k - s_q)
         attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
         attn_bias.to(query.dtype)
         attn_weight += attn_bias
@@ -270,7 +278,8 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
 @torch.inference_mode()
-def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
+def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q,
+                  h_kv, d, dv, causal, dtype):
     # q: [b, s_q, h_q, d]
     # block_table: [b, max_seqlen_pad // block_size]
     # blocked_k: [b * max_seqlen_pad // block_size, block_size, h_kv, d]
@@ -287,7 +296,8 @@ def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q,
             q[i].transpose(0, 1),
             blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1),
             blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
-            h_q, h_kv,
+            h_q,
+            h_kv,
             is_causal=causal,
         )
         out[i] = O.transpose(0, 1)
@@ -298,31 +308,33 @@ def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q,
     return out_torch

-def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
+def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens,
+                     h_q, h_kv, d, dv, causal, dtype):
     assert d > dv, "mla with rope dim should be larger than no rope dim"
     q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous()
-    blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous()
+    blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[...,
+                                                                               dv:].contiguous()
     dpe = d - dv
     num_kv_splits = 1
     BLOCK_N = 64
     BLOCK_H = 64
     out_partial = torch.empty(b, h_q, num_kv_splits, dv, dtype=dtype, device=q.device)
     glse = torch.empty(b, h_q, num_kv_splits, dtype=dtype, device=q.device)
-    out = torch.empty(b, h_q, dv, dtype=dtype, device=q.device)
-    program = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H, num_kv_splits, block_size)
-    mod, params = tilelang.lower(program)
-    mod = tilelang.Profiler(mod, params, [8], tilelang.TensorSupplyType.Randn)
+    program = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H,
+                                  num_kv_splits, block_size)
+    kernel = tilelang.compile(program, out_idx=[8])
+    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)

     def flash_mla_tilelang():
-        out = mod.func(
+        out = profiler.func(
             q_nope.view(-1, h_q, dv),
             q_pe.view(-1, h_q, dpe),
             blocked_k_nope.view(-1, h_kv, dv),
             blocked_k_pe.view(-1, h_kv, dpe),
             block_table,
             cache_seqlens,
             glse,
             out_partial,
@@ -331,11 +343,13 @@ def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s
     out_flash = flash_mla_tilelang()
     t = do_bench(flash_mla_tilelang)
-    out_ref = run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype)
+    out_ref = run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q,
+                            cache_seqlens, h_q, h_kv, d, dv, causal, dtype)
     torch.testing.assert_close(out_flash, out_ref, rtol=0.01, atol=0.01)
     print("All close")
     return out_flash, t

+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument('--batch', type=int, default=128, help='batch size')
@@ -349,10 +363,12 @@ if __name__ == "__main__":
     device = "cuda"
     dtype = torch.float16
     s_q = 1  # for decode, s_q = 1
     block_size = 64
-    cache_seqlens = torch.tensor([cache_seqlen + 2 * i for i in range(b)], dtype=torch.int32, device=device)
+    cache_seqlens = torch.tensor([cache_seqlen + 2 * i for i in range(b)],
+                                 dtype=torch.int32,
+                                 device=device)
     dpe = d - dv
     causal = True
@@ -364,9 +380,12 @@ if __name__ == "__main__":
     total_flops = s_q * total_seqlens * h_q * (d + dv) * 2
     q = torch.randn(b, s_q, h_q, d, dtype=dtype, device=device)
-    block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32, device=device).view(b, max_seqlen_pad // block_size)
+    block_table = torch.arange(
+        b * max_seqlen_pad // block_size, dtype=torch.int32,
+        device=device).view(b, max_seqlen_pad // block_size)
     blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d, dtype=dtype, device=device)
-    out_flash, latency = run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype)
+    out_flash, latency = run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b,
+                                          s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype)
     print("Tile-lang: {:.2f} ms".format(latency))
     print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
\ No newline at end of file

@@ -4,7 +4,6 @@
 import torch
 import torch.nn.functional as F
 import tilelang
-from tilelang import Profiler
 from tilelang.autotuner import *
 import tilelang.language as T
 import itertools
@@ -256,14 +255,14 @@ if __name__ == "__main__":
             batch, heads, seq_len, dim, is_causal, tune=args.tune, groups=groups)(
                 block_M=128, block_N=128, num_stages=2, threads=128)
         ref_program = partial(ref_program, is_causal=is_causal, groups=groups)
-        mod, params = tilelang.lower(program)
-        mod = Profiler(mod, params, [3], tilelang.TensorSupplyType.Normal)
-        mod.assert_allclose(ref_program, rtol=0.01, atol=0.01)
+        kernel = tilelang.compile(program, out_idx=[3])
+        profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
+        profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
         print("All checks pass.")
-        latency = mod.do_bench(ref_program, warmup=500)
+        latency = profiler.do_bench(ref_program, warmup=500)
         print("Ref: {:.2f} ms".format(latency))
         print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
-        latency = mod.do_bench(mod.func, warmup=500)
+        latency = profiler.do_bench(warmup=500)
         print("Tile-lang: {:.2f} ms".format(latency))
         print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
     else:
...

@@ -4,7 +4,6 @@
 import torch
 import torch.nn.functional as F
 import tilelang
-from tilelang import Profiler
 from tilelang.autotuner import *
 import tilelang.language as T
 import itertools
@@ -228,14 +227,14 @@ if __name__ == "__main__":
             batch, heads, seq_len, dim, is_causal, tune=args.tune, groups=groups)(
                 block_M=128, block_N=128, num_stages=2, threads=256)
         ref_program = partial(ref_program, is_causal=is_causal, groups=groups)
-        mod, params = tilelang.lower(program)
-        mod = Profiler(mod, params, [3], tilelang.TensorSupplyType.Normal)
-        mod.assert_allclose(ref_program, rtol=0.01, atol=0.01)
+        kernel = tilelang.compile(program, out_idx=[3])
+        profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
+        profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
         print("All checks pass.")
-        latency = mod.do_bench(ref_program, warmup=500)
+        latency = profiler.do_bench(ref_program, warmup=500)
         print("Ref: {:.2f} ms".format(latency))
         print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
-        latency = mod.do_bench(mod.func, warmup=500)
+        latency = profiler.do_bench(warmup=500)
         print("Tile-lang: {:.2f} ms".format(latency))
         print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
     else:
...

@@ -221,7 +221,7 @@ if __name__ == "__main__":
         latency = profiler.do_bench(ref_program, warmup=500)
         print("Ref: {:.2f} ms".format(latency))
         print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
-        latency = profiler.do_bench(profiler.mod, warmup=500)
+        latency = profiler.do_bench(warmup=500)
         print("Tile-lang: {:.2f} ms".format(latency))
         print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
     else:
...

@@ -4,7 +4,6 @@
 import torch
 import torch.nn.functional as F
 import tilelang
-from tilelang import Profiler
 from tilelang.autotuner import *
 import tilelang.language as T
 import itertools
@@ -209,14 +208,14 @@ if __name__ == "__main__":
             batch, heads, seq_len, dim, is_causal, tune=args.tune)(
                 block_M=128, block_N=128, num_stages=1, threads=128)
         ref_program = partial(ref_program, is_causal=is_causal)
-        mod, params = tilelang.lower(program)
-        mod = Profiler(mod, params, [3], tilelang.TensorSupplyType.Normal)
-        mod.assert_allclose(ref_program, rtol=0.01, atol=0.01)
+        kernel = tilelang.compile(program, out_idx=[3])
+        profiler = kernel.get_profiler()
+        profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
         print("All checks pass.")
-        latency = mod.do_bench(ref_program, warmup=500)
+        latency = profiler.do_bench(ref_program, warmup=500)
         print("Ref: {:.2f} ms".format(latency))
         print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
-        latency = mod.do_bench(mod.func, warmup=500)
+        latency = profiler.do_bench(warmup=500)
         print("Tile-lang: {:.2f} ms".format(latency))
         print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
     else:
...

@@ -4,7 +4,6 @@
 import torch
 import torch.nn.functional as F
 import tilelang
-from tilelang import Profiler
 from tilelang.autotuner import *
 import tilelang.language as T
 import itertools
@@ -214,14 +213,14 @@ if __name__ == "__main__":
             batch, heads, seq_len, dim, is_causal, tune=args.tune)(
                 block_M=128, block_N=128, num_stages=2, threads=256)
         ref_program = partial(ref_program, is_causal=is_causal)
-        mod, params = tilelang.lower(program)
-        mod = Profiler(mod, params, [3], tilelang.TensorSupplyType.Normal)
-        mod.assert_allclose(ref_program, rtol=0.01, atol=0.01)
+        kernel = tilelang.compile(program, out_idx=[3])
+        profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
+        profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
         print("All checks pass.")
-        latency = mod.do_bench(ref_program, warmup=500)
+        latency = profiler.do_bench(ref_program, warmup=500)
        print("Ref: {:.2f} ms".format(latency))
         print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
-        latency = mod.do_bench(mod.func, warmup=500)
+        latency = profiler.do_bench(warmup=500)
         print("Tile-lang: {:.2f} ms".format(latency))
         print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
     else:
...