Commit f2e99180 authored by Lei Wang, committed by LeiWang1999

[Refactor] Phase Out LLVM Dependency by Making It Optional (#247)

* remove llvm build

* [Refactor] Update kernel compilation and profiling in examples

- Replaced `tilelang.lower` with `tilelang.compile` in multiple example scripts to streamline kernel compilation.
- Updated profiling calls to utilize the new `get_profiler` method, enhancing performance measurement consistency.
- Adjusted assertions and benchmarking methods to align with the new profiling structure across the examples, ensuring correct and clear performance evaluations (see the before/after sketch below).
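A condensed before/after of this API migration, taken from the example diffs further down in this commit; `program`, `ref_program`, and the output index are placeholders from those examples:

```python
import tilelang

# Before: lower the program and wrap the result in a Profiler.
#   mod, params = tilelang.lower(program)
#   mod = tilelang.Profiler(mod, params, [2], tilelang.TensorSupplyType.Normal)
#   mod.assert_allclose(ref_program, rtol=0.01, atol=0.01)
#   latency = mod.do_bench(mod.func, warmup=500)

# After: compile the kernel directly and obtain a profiler from it.
kernel = tilelang.compile(program, out_idx=[2])
profiler = kernel.get_profiler(tilelang.TensorSupplyType.Normal)
profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)

latency_ref = profiler.do_bench(ref_program, warmup=500)  # benchmark the reference
latency_tl = profiler.do_bench(warmup=500)                 # benchmark the compiled kernel
```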

* lint fix

* License Update

* [Refactor] Improve code formatting and documentation in CUDA header and HIP runtime files

- Adjusted formatting in `cuda.h` for better readability, including alignment of comments and struct fields.
- Cleaned up whitespace and improved comment clarity in `rt_mod_hip.cc` to enhance code maintainability.

* [Refactor] Enhance formatting and clarity in CUDA header and HIP runtime files

- Improved comment alignment and readability in `cuda.h`.
- Cleaned up whitespace and formatting in `rt_mod_hip.cc` to enhance maintainability.

* lint fix

* lint fix

* lint fix

* lint fix

* fix

* License update

* [Enhancement] Update JITKernel to use artifact for kernel source

- Assigned the generated artifact to `self.artifact` for better management.
- Updated kernel source references to use `artifact.kernel_source` so all execution backends read from a single source of truth (a minimal sketch follows).
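A hypothetical sketch of the restructuring described above; only `self.artifact` and `artifact.kernel_source` come from this commit, the `Artifact` fields and method names are illustrative:

```python
from dataclasses import dataclass

@dataclass
class Artifact:
    kernel_source: str        # generated device source (e.g. CUDA C)
    host_mod: object = None   # host-side module, if any
    device_mod: object = None # device-side module, if any

class JITKernel:
    def __init__(self, artifact: Artifact):
        self.artifact = artifact                     # keep the whole artifact for reuse
        self.kernel_source = artifact.kernel_source  # backends read the source from here

    def get_kernel_source(self) -> str:
        return self.artifact.kernel_source
```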

* lint fix

* Add the @tilelang.testing.requires_llvm decorator to vectorization tests (usage sketched below)
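A minimal usage sketch of the decorator on a test, assuming the usual skip-when-unavailable semantics of `tilelang.testing` decorators; the test body is a placeholder:

```python
import tilelang.testing

@tilelang.testing.requires_llvm
def test_vectorize_add():
    # Runs only when a TVM build with the LLVM backend is available;
    # otherwise the test is skipped rather than failing at import time.
    ...
```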

* Enhance setup.py and env.py for library management

- Added functionality to remove original files after copying in CMakeBuild.
- Updated TVM_LIBRARY_PATH in env.py to include the PyPI build library path so bundled libraries are found (see the sketch below).
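A hedged sketch of what the env.py change could look like; the in-package `lib` directory and the exact composition of the search path are assumptions:

```python
import os

# Directory the PyPI wheel build copies the built .so files into (layout assumed).
_PYPI_BUILD_LIB = os.path.join(os.path.dirname(__file__), "lib")

# Extend the existing search path rather than replacing it.
TVM_LIBRARY_PATH = os.pathsep.join(
    p for p in (os.environ.get("TVM_LIBRARY_PATH", ""), _PYPI_BUILD_LIB) if p)
```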

* Refactor TVM_LIBRARY_PATH assignment for improved readability in env.py

* Refactor CMakeBuild file handling in setup.py

- Added a check to ensure the target library directory exists before copying .so files.
- Improved the logic for creating the target directory and copying files to make the build step more robust (sketched below).
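A minimal sketch of the copy step described above, assuming CMakeBuild collects built `.so` files from a build directory; the function and directory names are illustrative:

```python
import os
import shutil

def copy_built_libraries(build_dir: str, target_lib_dir: str) -> None:
    """Copy built .so files into the package, creating the target dir first."""
    os.makedirs(target_lib_dir, exist_ok=True)  # ensure the target directory exists
    for name in os.listdir(build_dir):
        if name.endswith(".so"):
            src = os.path.join(build_dir, name)
            shutil.copy2(src, os.path.join(target_lib_dir, name))
            os.remove(src)  # remove the original after copying, as described above
```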

* bugfix

* Rename BuildTLDebug to BuildTileLangCUDAWithoutCompile and update its registration. Add the @tilelang.testing.requires_llvm decorator to multiple tests that require LLVM.

* lint fix

* Enhance TileLang code generation by adding support for device code generation without compilation. Update `host_codegen` and `device_codegen` to include the new transformations and to register `tilelang_hip_without_compile`. Refactor the JIT kernel adapters to accept separate host and device modules, improving integration and flexibility.

* lint fix

* Add support for C target in device code generation

- Updated `device_codegen_without_compile` to include handling for the C target by registering the `tilelang_cpp` function.

* [Enhancement] Implement auto-clear cache feature based on environment variable

* Added TILELANG_CLEAR_CACHE environment variable to control cache clearing.
* Updated CI workflow to set TILELANG_CLEAR_CACHE during testing.
* Modified cache initialization to clear the cache when TILELANG_CLEAR_CACHE is set to true (see the sketch below).
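A minimal sketch of how the flag could be honored at cache initialization; only the TILELANG_CLEAR_CACHE variable name comes from this commit, the cache-directory handling is an assumption:

```python
import os
import shutil

def _maybe_clear_cache(cache_dir: str) -> None:
    # TILELANG_CLEAR_CACHE=1 (or "true") wipes the on-disk kernel cache before
    # it is used, which is what the CI workflow change below relies on.
    if os.environ.get("TILELANG_CLEAR_CACHE", "0").lower() in ("1", "true"):
        shutil.rmtree(cache_dir, ignore_errors=True)
    os.makedirs(cache_dir, exist_ok=True)
```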

* [Refactor] Update kernel invocation and import paths in tests and cache

* Changed kernel invocation in `test_tilelang_kernel_dequantize_gemm.py` to return the result.
* Updated import statements in `test_tilelang_kernel_int4_gemm_mma.py` to use `bitblas` instead of `tilelang`.
* Refactored paths for artifact and parameters in `kernel_cache.py` for better maintainability.

* [Refactor] Clean up whitespace and improve code formatting in kernel_cache.py

* Removed unnecessary blank lines and adjusted spacing for better readability in the KernelCache class.
* Enhanced overall code formatting to align with project standards.

* [Enhancement] Add bfloat16 test case and improve kernel caching logic

* Introduced a new test case for bfloat16 matrix multiplication in `test_tilelang_kernel_gemm_mma_intrinsic.py`.
* Updated `KernelCache` to handle multiple kernel source files and to improve error handling during saving and loading (a hedged sketch follows after this list).
* Refactored `JITKernel` to support instantiation from a database, enhancing flexibility in kernel management.
* Adjusted `CtypesKernelAdapter` and `CythonKernelAdapter` to utilize the new kernel loading mechanism from the database.
* Improved code formatting and readability across several files.
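A hedged sketch of a cache entry that stores several named kernel source files plus parameters, matching the `KernelCache` bullet above; the file layout, JSON serialization, and function names are assumptions:

```python
import json
import os

def save_kernel_entry(entry_dir: str, sources: dict, params: dict) -> None:
    """Persist one cached kernel: each named source file plus its parameters."""
    os.makedirs(entry_dir, exist_ok=True)
    for filename, text in sources.items():  # e.g. host and device source files
        with open(os.path.join(entry_dir, filename), "w") as f:
            f.write(text)
    with open(os.path.join(entry_dir, "params.json"), "w") as f:
        json.dump(params, f)

def load_kernel_entry(entry_dir: str):
    """Reload sources and parameters; a JITKernel can then be rebuilt from them."""
    sources = {}
    for name in os.listdir(entry_dir):
        if name != "params.json":
            with open(os.path.join(entry_dir, name)) as f:
                sources[name] = f.read()
    with open(os.path.join(entry_dir, "params.json")) as f:
        params = json.load(f)
    return sources, params
```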

* lint fix

* Update bfloat16 matrix multiplication test case to use larger dimensions for improved coverage
parent 43bd9d3e
......@@ -67,4 +67,5 @@ jobs:
run: |
source tilelang_ci/bin/activate
cd testing/python
export TILELANG_CLEAR_CACHE=1
python -m pytest
Subproject commit 2654ce86a8cda7d28eab73db7e9104c90511c072
Subproject commit c1c2a08a53f24886d2f82839fe304f2f1b6d0973
# Copyright(c) Microsoft Corporation.
# Licensed under the MIT License.
# Learn a lot from the MLC-LLM project
# https://github.com/mlc-ai/mlc-llm/blob/main/CMakeLists.txt
......
......@@ -87,7 +87,7 @@ Or install locally:
sudo apt-get update
sudo apt-get install -y python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev
pip install . # with -e option if you want to install in editable mode
pip install -e . -v # remove -e option if you don't want to install in editable mode, -v for verbose output
```
### Method 2: Build from Source
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ruff: noqa
import torch
from tilelang.profiler import do_bench
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ruff: noqa
import math
import torch
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ruff: noqa
import math
import torch
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ruff: noqa
import math
import torch
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import argparse
import logging
import torch
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ruff: noqa: E712
import math
import torch
......
import torch
import tilelang
from tilelang import Profiler
from tilelang.autotuner import *
import tilelang.language as T
import itertools
......@@ -145,14 +144,14 @@ if __name__ == "__main__":
N, C, H, W, F, K, S, D, P, tune=args.tune)(
block_M=256, block_N=128, block_K=64, num_stages=4, threads=256)
ref_program = partial(ref_program, stride=S, padding=P, dilation=D)
mod, params = tilelang.lower(program)
mod = Profiler(mod, params, [2], tilelang.TensorSupplyType.Normal)
mod.assert_allclose(ref_program, rtol=0.01, atol=0.01)
kernel = tilelang.compile(program, out_idx=[2])
profiler = kernel.get_profiler(tilelang.TensorSupplyType.Normal)
profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
print("All checks pass.")
latency = mod.do_bench(ref_program, warmup=500)
latency = profiler.do_bench(ref_program, warmup=500)
print("Ref: {:.2f} ms".format(latency))
print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = mod.do_bench(mod.func, warmup=500)
latency = profiler.do_bench(warmup=500)
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
else:
......
......@@ -145,8 +145,8 @@ def calc_diff(x, y):
def assert_tl_gemm_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
gemm = tl_gemm(M, N, K, in_dtype, out_dtype, accum_dtype)
mod, params = TL.lower(gemm)
src_code = mod.imported_modules[0].get_source()
kernel = TL.compile(gemm, out_idx=[])
src_code = kernel.get_kernel_source()
# src_code is the generated cuda source
assert src_code is not None
......@@ -162,16 +162,15 @@ def assert_tl_gemm_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
C = torch.zeros(M, N, device="cuda", dtype=out_dtype)
mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)
mod(A_fp8, B_fp8, C, A_scale, B_scale)
kernel(A_fp8, B_fp8, C, A_scale, B_scale)
# Get Reference Result
ref_c = ref_deepgemm_fp8(A_fp8, B_fp8, A_scale, B_scale, out_dtype)
diff = calc_diff(C, ref_c)
print(f"diff: {diff}")
assert diff < 1e-3
latency = mod.do_bench(mod.func, warmup=25)
profiler = kernel.get_profiler()
latency = profiler.do_bench(warmup=25)
# Ensure that the latency is not None
assert latency is not None
print(f"latency: {latency} ms")
......
# This benchmark script is modified based on: https://github.com/deepseek-ai/FlashMLA/blob/main/benchmark/bench_flash_mla.py
# ruff: noqa
import argparse
import math
import random
......@@ -16,6 +16,7 @@ import tilelang
from tilelang.profiler import do_bench
from example_mla_decode_paged import mla_decode_tilelang
def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
query = query.float()
key = key.float()
......@@ -37,7 +38,8 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
@torch.inference_mode()
def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q,
h_kv, d, dv, causal, dtype):
blocked_v = blocked_k[..., :dv]
def ref_mla():
......@@ -50,7 +52,8 @@ def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q,
q[i].transpose(0, 1),
blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1),
blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
h_q, h_kv,
h_q,
h_kv,
is_causal=causal,
)
out[i] = O.transpose(0, 1)
......@@ -61,16 +64,24 @@ def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q,
t = triton.testing.do_bench(ref_mla)
return out_torch, lse_torch, t
@torch.inference_mode()
def run_flash_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
def run_flash_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q,
h_kv, d, dv, causal, dtype):
blocked_v = blocked_k[..., :dv]
tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv)
def flash_mla():
return flash_mla_with_kvcache(
q, blocked_k, block_table, cache_seqlens, dv,
tile_scheduler_metadata, num_splits, causal=causal,
q,
blocked_k,
block_table,
cache_seqlens,
dv,
tile_scheduler_metadata,
num_splits,
causal=causal,
)
out_flash, lse_flash = flash_mla()
......@@ -79,13 +90,14 @@ def run_flash_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q,
@torch.inference_mode()
def run_flash_infer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
def run_flash_infer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens,
h_q, h_kv, d, dv, causal, dtype):
assert d > dv, "mla with rope dim should be larger than no rope dim"
q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous()
blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous()
blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[...,
dv:].contiguous()
kv_indptr = [0]
kv_indices = []
for i in range(b):
......@@ -96,15 +108,13 @@ def run_flash_infer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_
kv_indptr.append(kv_indptr[-1] + num_blocks)
for seq_len in cache_seqlens[1:]:
kv_indptr.append((seq_len + block_size - 1) // block_size + kv_indptr[-1])
q_indptr = torch.arange(0, b + 1).int() * s_q
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
mla_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
torch.empty(128 * 1024 * 1024, dtype=torch.int8),
backend="fa3"
)
torch.empty(128 * 1024 * 1024, dtype=torch.int8), backend="fa3")
mla_wrapper.plan(
q_indptr,
kv_indptr,
......@@ -112,7 +122,7 @@ def run_flash_infer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_
cache_seqlens,
h_q,
dv,
d-dv,
d - dv,
block_size,
causal,
1 / math.sqrt(d),
......@@ -121,7 +131,12 @@ def run_flash_infer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_
)
def flash_infer():
output, lse = mla_wrapper.run(q_nope.view(-1, h_q, dv), q_pe.view(-1, h_q, d-dv), blocked_k_nope, blocked_k_pe, return_lse=True)
output, lse = mla_wrapper.run(
q_nope.view(-1, h_q, dv),
q_pe.view(-1, h_q, d - dv),
blocked_k_nope,
blocked_k_pe,
return_lse=True)
return output.view(b, -1, h_q, dv), lse.view(b, h_q, 1)
out_flash, lse_flash = flash_infer()
......@@ -164,7 +179,8 @@ def _mla_attn_kernel(
offs_d_ckv = tl.arange(0, HEAD_DIM_CKV)
cur_head = cur_head_id * BLOCK_H + tl.arange(0, BLOCK_H)
offs_q_nope = cur_batch * stride_q_nope_bs + cur_head[:, None] * stride_q_nope_h + offs_d_ckv[None, :]
offs_q_nope = cur_batch * stride_q_nope_bs + cur_head[:, None] * stride_q_nope_h + offs_d_ckv[
None, :]
q_nope = tl.load(Q_nope + offs_q_nope)
offs_d_kpe = tl.arange(0, HEAD_DIM_KPE)
......@@ -210,7 +226,9 @@ def _mla_attn_kernel(
e_sum = e_sum * re_scale + tl.sum(p, 1)
e_max = n_e_max
offs_o = cur_batch * stride_o_b + cur_head[:, None] * stride_o_h + split_kv_id * stride_o_s + offs_d_ckv[None, :]
offs_o = cur_batch * stride_o_b + cur_head[:,
None] * stride_o_h + split_kv_id * stride_o_s + offs_d_ckv[
None, :]
tl.store(O + offs_o, acc / e_sum[:, None])
offs_o_1 = cur_batch * stride_o_b + cur_head * stride_o_h + split_kv_id * stride_o_s + HEAD_DIM_CKV
tl.store(O + offs_o_1, e_max + tl.log(e_sum))
......@@ -260,13 +278,14 @@ def _mla_attn(
attn_logits.stride(1),
attn_logits.stride(2),
BLOCK_H=BLOCK_H,
BLOCK_N=BLOCK_N,
BLOCK_N=BLOCK_N,
NUM_KV_SPLITS=num_kv_splits,
PAGE_SIZE=page_size,
HEAD_DIM_CKV=head_dim_ckv,
HEAD_DIM_KPE=head_dim_kpe,
)
@triton.jit
def _mla_softmax_reducev_kernel(
Logits,
......@@ -310,7 +329,7 @@ def _mla_softmax_reducev_kernel(
e_sum = e_sum * old_scale + exp_logic
e_max = n_e_max
tl.store(
O + cur_batch * stride_o_b + cur_head * stride_o_h + offs_d_ckv,
acc / e_sum,
......@@ -340,6 +359,7 @@ def _mla_softmax_reducev(
num_stages=2,
)
def mla_decode_triton(
q_nope,
q_pe,
......@@ -372,22 +392,27 @@ def mla_decode_triton(
b_seq_len,
num_kv_splits,
)
@torch.inference_mode()
def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q,
cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
blocked_v = blocked_k[..., :dv]
assert d > dv, "mla with rope dim should be larger than no rope dim"
q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous()
blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous()
blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[...,
dv:].contiguous()
def flash_mla_triton():
num_kv_splits = 32
o = torch.empty([b * s_q, h_q, dv])
attn_logits = torch.empty([b * s_q, h_q, num_kv_splits, dv + 1])
mla_decode_triton(q_nope.view(-1, h_q, dv), q_pe.view(-1, h_q, d-dv), blocked_k_nope.view(-1, dv), blocked_k_pe.view(-1, d-dv), o, block_table, cache_seqlens, attn_logits, num_kv_splits, 1 / math.sqrt(d), block_size)
mla_decode_triton(
q_nope.view(-1, h_q, dv), q_pe.view(-1, h_q, d - dv), blocked_k_nope.view(-1, dv),
blocked_k_pe.view(-1, d - dv), o, block_table, cache_seqlens, attn_logits,
num_kv_splits, 1 / math.sqrt(d), block_size)
return o.view([b, s_q, h_q, dv])
out_flash = flash_mla_triton()
......@@ -396,31 +421,32 @@ def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size,
@torch.inference_mode()
def run_flash_mla_tilelang(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
def run_flash_mla_tilelang(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q,
cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
assert d > dv, "mla with rope dim should be larger than no rope dim"
q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous()
blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous()
blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[...,
dv:].contiguous()
dpe = d - dv
num_kv_splits = 1
BLOCK_N = 64
BLOCK_H = 64
out_partial = torch.empty(b, h_q, num_kv_splits, dv, dtype=dtype, device=q.device)
glse = torch.empty(b, h_q, num_kv_splits, dtype=dtype, device=q.device)
out = torch.empty(b, h_q, dv, dtype=dtype, device=q.device)
program = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H, num_kv_splits, block_size)
mod, params = tilelang.lower(program)
mod = tilelang.Profiler(mod, params, [8], tilelang.TensorSupplyType.Randn)
program = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H,
num_kv_splits, block_size)
kernel = tilelang.compile(program, out_idx=[8])
def flash_mla_tilelang():
out = mod.func(
q_nope.view(-1, h_q, dv),
q_pe.view(-1, h_q, dpe),
blocked_k_nope.view(-1, h_kv, dv),
blocked_k_pe.view(-1, h_kv, dpe),
block_table,
out = kernel.func(
q_nope.view(-1, h_q, dv),
q_pe.view(-1, h_q, dpe),
blocked_k_nope.view(-1, h_kv, dv),
blocked_k_pe.view(-1, h_kv, dpe),
block_table,
cache_seqlens,
glse,
out_partial,
......@@ -431,6 +457,7 @@ def run_flash_mla_tilelang(q, block_table, blocked_k, max_seqlen_pad, block_size
t = do_bench(flash_mla_tilelang)
return out_flash, None, t
FUNC_TABLE = {
"torch": run_torch_mla,
"tilelang": run_flash_mla_tilelang,
......@@ -438,9 +465,12 @@ FUNC_TABLE = {
"flash_infer": run_flash_infer,
"flash_mla_triton": run_flash_mla_triton,
}
def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
print(f"comparing {baseline} vs {target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}")
print(
f"comparing {baseline} vs {target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}"
)
device = torch.device("cuda:0")
torch.set_default_dtype(dtype)
torch.set_default_device(device)
......@@ -451,21 +481,23 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal
assert target in FUNC_TABLE
baseline_func = FUNC_TABLE[baseline]
target_func = FUNC_TABLE[target]
total_seqlens = cache_seqlens.sum().item()
mean_seqlens = cache_seqlens.float().mean().int().item()
max_seqlen = cache_seqlens.max().item()
max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256
# print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}")
q = torch.randn(b, s_q, h_q, d)
block_size = 64
block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)
block_table = torch.arange(
b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)
blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
out_a, lse_a, perf_a = baseline_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype)
out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype)
out_a, lse_a, perf_a = baseline_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b,
s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype)
out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b,
s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype)
torch.testing.assert_close(out_b.float(), out_a.float(), atol=1e-2, rtol=1e-2), "out"
if target not in ["flash_infer", "flash_mla_triton", "flash_mla_tilelang"]:
# flash_infer has a different lse return value
......@@ -473,14 +505,21 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal
torch.testing.assert_close(lse_b.float(), lse_a.float(), atol=1e-2, rtol=1e-2), "lse"
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8)
print(f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10 ** 9 / perf_a:.0f} TFLOPS, {bytes / 10 ** 6 / perf_a:.0f} GB/s")
print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s")
return bytes / 10 ** 6 / perf_a, bytes / 10 ** 6 / perf_b
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (
torch.finfo(dtype).bits // 8)
print(
f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10 ** 9 / perf_a:.0f} TFLOPS, {bytes / 10 ** 6 / perf_a:.0f} GB/s"
)
print(
f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s"
)
return bytes / 10**6 / perf_a, bytes / 10**6 / perf_b
def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
print(f"{target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}")
print(
f"{target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}"
)
torch.set_default_dtype(dtype)
device = torch.device("cuda:0")
torch.set_default_device(device)
......@@ -489,24 +528,28 @@ def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
random.seed(0)
assert target in FUNC_TABLE
target_func = FUNC_TABLE[target]
total_seqlens = cache_seqlens.sum().item()
mean_seqlens = cache_seqlens.float().mean().int().item()
max_seqlen = cache_seqlens.max().item()
max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256
# print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}")
q = torch.randn(b, s_q, h_q, d)
block_size = 64
block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)
block_table = torch.arange(
b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)
blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype)
out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b,
s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype)
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8)
print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s")
return bytes / 10 ** 6 / perf_b
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (
torch.finfo(dtype).bits // 8)
print(
f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s"
)
return bytes / 10**6 / perf_b
available_targets = [
......@@ -517,10 +560,26 @@ available_targets = [
"flash_mla_triton",
]
shape_configs = [
{"b": batch, "s_q": 1, "cache_seqlens": torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device="cuda"), "h_q": head, "h_kv": 1, "d": 512+64, "dv": 512, "causal": True, "dtype": torch.float16}
for batch in [128] for seqlen in [1024, 2048, 4096, 8192, 16384, 32768] for head in [128]
]
shape_configs = [{
"b":
batch,
"s_q":
1,
"cache_seqlens":
torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device="cuda"),
"h_q":
head,
"h_kv":
1,
"d":
512 + 64,
"dv":
512,
"causal":
True,
"dtype":
torch.float16
} for batch in [128] for seqlen in [1024, 2048, 4096, 8192, 16384, 32768] for head in [128]]
def get_args():
......@@ -533,7 +592,7 @@ def get_args():
args = parser.parse_args()
return args
if __name__ == "__main__":
args = get_args()
benchmark_type = "all" if args.all else f"{args.baseline}_vs_{args.target}" if args.compare else args.target
......@@ -542,12 +601,26 @@ if __name__ == "__main__":
for shape in shape_configs:
if args.all:
for target in available_targets:
perf = compare_a(target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"])
fout.write(f'{target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n')
perf = compare_a(target, shape["b"], shape["s_q"], shape["cache_seqlens"],
shape["h_q"], shape["h_kv"], shape["d"], shape["dv"],
shape["causal"], shape["dtype"])
fout.write(
f'{target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n'
)
elif args.compare:
perfa, prefb = compare_ab(args.baseline, args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"])
fout.write(f'{args.baseline},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perfa:.0f}\n')
fout.write(f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{prefb:.0f}\n')
perfa, prefb = compare_ab(args.baseline, args.target, shape["b"], shape["s_q"],
shape["cache_seqlens"], shape["h_q"], shape["h_kv"],
shape["d"], shape["dv"], shape["causal"], shape["dtype"])
fout.write(
f'{args.baseline},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perfa:.0f}\n'
)
fout.write(
f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{prefb:.0f}\n'
)
elif args.one:
perf = compare_a(args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"])
fout.write(f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n')
\ No newline at end of file
perf = compare_a(args.target, shape["b"], shape["s_q"], shape["cache_seqlens"],
shape["h_q"], shape["h_kv"], shape["d"], shape["dv"],
shape["causal"], shape["dtype"])
fout.write(
f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n'
)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.backends
import tilelang.testing
from tilelang import tvm as tvm
from tvm import DataType
import tilelang as TL
import tilelang.language as T
tilelang.testing.set_random_seed(0)
......@@ -115,10 +112,10 @@ def run_gemm(
num_threads,
)
mod, params = TL.lower(program)
mod = TL.Profiler(mod, params, [2], TL.TensorSupplyType.Integer)
kernel = tilelang.compile(program, out_idx=[2])
profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer)
out = mod.run_once()
out = profiler.run_once()
assert out is not None
def ref_program(A, qB):
......@@ -134,7 +131,7 @@ def run_gemm(
C = C.to(torch.__getattribute__(out_dtype))
return C
mod.assert_allclose(ref_program)
profiler.assert_allclose(ref_program)
@tvm.testing.requires_package("bitblas")
......@@ -363,8 +360,9 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct
matmul = tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
M, N, K, in_dtype, out_dtype, accum_dtype, transform_b)
mod, params = TL.lower(matmul)
src_code = mod.imported_modules[0].get_source()
kernel = tilelang.compile(matmul, out_idx=[2])
src_code = kernel.get_kernel_source()
profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer)
# src_code is the generated cuda source
assert src_code is not None
......@@ -402,11 +400,9 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct
QLB = ladder_permutate(qB.cpu()).cuda()
QLB = lop3_permutate(QLB.cpu()).cuda()
mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)
kernel(A, QLB, C)
mod(A, QLB, C)
latency = mod.do_bench(mod.func, warmup=25)
latency = profiler.do_bench(warmup=25)
# Ensure that the latency is not None
assert latency is not None
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import tilelang
from tilelang import Profiler
import tilelang.language as T
from tilelang.autotuner import *
from tilelang import tvm
from tvm import tir
import itertools
import torch
import argparse
from functools import partial
def _tir_u8_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
......@@ -103,11 +97,10 @@ def test_fp4_fp16_convert_close():
"float16",
)
mod, params = tilelang.lower(program)
mod = Profiler(mod, params, [1], tilelang.TensorSupplyType.Integer)
kernel = tilelang.compile(program, out_idx=[1])
B = torch.randint(0, 16, (N, K // 2), dtype=torch.uint8, device="cuda").to(torch.uint8)
tl_out = mod.func(B)
tl_out = kernel(B)
ref_out = torch_convert(B)
assert torch.allclose(tl_out, ref_out, rtol=0.01, atol=0.01), (tl_out, ref_out)
print("Pass")
......@@ -291,14 +284,14 @@ if __name__ == "__main__":
program = matmul(
M, N, K, "float16", "float16", "float32", num_bits=4, tune=args.tune)(
block_M=128, block_N=128, block_K=128, num_stages=2, threads=256, split=1)
mod, params = tilelang.lower(program)
mod = Profiler(mod, params, [2], tilelang.TensorSupplyType.Integer)
mod.assert_allclose(ref_program, rtol=0.01, atol=0.01)
kernel = tilelang.compile(program, out_idx=[2])
profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer)
profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
print("All checks pass.")
latency = mod.do_bench(ref_program, warmup=500)
latency = profiler.do_bench(ref_program, warmup=500)
print("Ref: {:.2f} ms".format(latency))
print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = mod.do_bench(mod.func, warmup=500)
latency = profiler.do_bench(warmup=500)
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
else:
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.nn.functional as F
import tilelang
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.nn.functional as F
import tilelang
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.nn.functional as F
import tilelang
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.nn.functional as F
import tilelang
......