"git@developer.sourcefind.cn:OpenDAS/torch-spline-conv.git" did not exist on "bb853e004a46c91d9c62bf7823573e249ad11a23"
Commit f2e99180 authored by Lei Wang, committed by LeiWang1999

[Refactor] Phase Out LLVM Dependency by Making It Optional (#247)

* remove llvm build

* [Refactor] Update kernel compilation and profiling in examples

- Replaced `tilelang.lower` with `tilelang.compile` in multiple example scripts to streamline kernel compilation.
- Updated profiling calls to use the new `get_profiler` method, keeping performance measurement consistent across examples (see the sketch below).
- Adjusted assertions and benchmarking methods to align with the new profiling structure across various examples, ensuring correctness and clarity in performance evaluations.
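
For reference, the migration pattern applied throughout the examples looks roughly like this (a minimal sketch based on the convolution example further down in this diff; `program`, `ref_program`, and `out_idx=[2]` are taken from that example and stand in for whichever kernel is being tested):

```python
import tilelang

# Old two-step flow (removed):
#   mod, params = tilelang.lower(program)
#   mod = Profiler(mod, params, [2], tilelang.TensorSupplyType.Normal)

# New flow: compile once, then ask the compiled kernel for its profiler.
kernel = tilelang.compile(program, out_idx=[2])
profiler = kernel.get_profiler(tilelang.TensorSupplyType.Normal)

profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)  # correctness check
ref_latency = profiler.do_bench(ref_program, warmup=500)     # benchmark the reference
tl_latency = profiler.do_bench(warmup=500)                   # benchmark the compiled kernel
```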

* lint fix

* License Update

* [Refactor] Improve code formatting and documentation in CUDA header and HIP runtime files

- Adjusted formatting in `cuda.h` for better readability, including alignment of comments and struct fields.
- Cleaned up whitespace and improved comment clarity in `rt_mod_hip.cc` to enhance code maintainability.

* [Refactor] Enhance formatting and clarity in CUDA header and HIP runtime files

- Improved comment alignment and readability in `cuda.h`.
- Cleaned up whitespace and formatting in `rt_mod_hip.cc` to enhance maintainability.

* lint fix

* lint fix

* lint fix

* lint fix

* fix

* License update

* [Enhancement] Update JITKernel to use artifact for kernel source

- Assigned the generated artifact to `self.artifact` for better management.
- Updated kernel source references to use `artifact.kernel_source` for consistency in execution backend handling.

* lint fix

* Add @tilelang.testing.requires_llvm decorator to vectorization tests (usage illustrated below)
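
The decorator is applied in the usual way; a minimal illustration (the test body is a placeholder, not one of the actual vectorization tests):

```python
import tilelang.testing


@tilelang.testing.requires_llvm  # skipped when the build does not include the now-optional LLVM backend
def test_vectorize_on_llvm_target():
    pass  # the real tests build and check a vectorized kernel here
```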

* Enhance setup.py and env.py for library management

- Added functionality to remove original files after copying in CMakeBuild.
- Updated TVM_LIBRARY_PATH in env.py to also include the PyPI build library path for better integration (sketched below).
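
A minimal sketch of the env.py idea (the directory layout and separator handling here are assumptions, not the actual implementation):

```python
import os
import os.path as osp

# Assumed layout: a PyPI build drops the TVM shared libraries under <package>/lib.
_pypi_build_lib = osp.abspath(osp.join(osp.dirname(__file__), "lib"))

# Extend, rather than replace, any search path the user already provided.
TVM_LIBRARY_PATH = os.pathsep.join(
    p for p in (os.environ.get("TVM_LIBRARY_PATH", ""), _pypi_build_lib) if p)
```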

* Refactor TVM_LIBRARY_PATH assignment for improved readability in env.py

* Refactor CMakeBuild file handling in setup.py

- Added a check to ensure the target library directory exists before copying .so files.
- Improved the logic for creating the target directory and copying files to make the step more robust (see the sketch below).
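
The copy step amounts to something like the following sketch (helper name and paths are illustrative, not the actual setup.py code):

```python
import glob
import os
import shutil


def copy_built_libs(build_lib_dir: str, target_lib_dir: str) -> None:
    # Ensure the target library directory exists before copying.
    os.makedirs(target_lib_dir, exist_ok=True)
    for so_path in glob.glob(os.path.join(build_lib_dir, "*.so")):
        shutil.copy2(so_path, target_lib_dir)
        os.remove(so_path)  # remove the original after copying, as described above
```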

* bugfix

* Rename BuildTLDebug to BuildTileLangCUDAWithoutCompile and update its registration. Add the @tilelang.testing.requires_llvm decorator to the additional tests that require LLVM.

* lint fix

* Enhance TileLang code generation by adding support for device code generation without compilation. Updated `host_codegen` and `device_codegen` functions to include new transformations and registration for `tilelang_hip_without_compile`. Refactored JIT kernel adapters to accommodate host and device modules, improving overall integration and flexibility.

* lint fix

* Add support for C target in device code generation

- Updated `device_codegen_without_compile` to include handling for the C target by registering the `tilelang_cpp` function.

* [Enhancement] Implement auto-clear cache feature based on environment variable

* Added TILELANG_CLEAR_CACHE environment variable to control cache clearing.
* Updated CI workflow to set TILELANG_CLEAR_CACHE during testing.
* Modified cache initialization to clear the cache when TILELANG_CLEAR_CACHE is set to true (sketched below).
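
In effect, cache initialization now does something along these lines (a sketch; the real cache directory and helper names may differ):

```python
import os
import shutil


def maybe_clear_cache(cache_dir: str) -> None:
    # Honor TILELANG_CLEAR_CACHE: wipe the on-disk kernel cache when it is truthy.
    if os.environ.get("TILELANG_CLEAR_CACHE", "0").lower() in ("1", "true", "on"):
        shutil.rmtree(cache_dir, ignore_errors=True)
        os.makedirs(cache_dir, exist_ok=True)
```

In CI this is driven by `export TILELANG_CLEAR_CACHE=1` before running pytest, as the workflow change further below shows.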

* [Refactor] Update kernel invocation and import paths in tests and cache

* Changed kernel invocation in `test_tilelang_kernel_dequantize_gemm.py` to return the result.
* Updated import statements in `test_tilelang_kernel_int4_gemm_mma.py` to use `bitblas` instead of `tilelang`.
* Refactored paths for artifact and parameters in `kernel_cache.py` for better maintainability.

* [Refactor] Clean up whitespace and improve code formatting in kernel_cache.py

* Removed unnecessary blank lines and adjusted spacing for better readability in the KernelCache class.
* Enhanced overall code formatting to align with project standards.

* [Enhancement] Add bfloat16 test case and improve kernel caching logic

* Introduced a new test case for bfloat16 matrix multiplication in `test_tilelang_kernel_gemm_mma_intrinsic.py`.
* Updated `KernelCache` to handle multiple kernel source files and improved error handling during saving and loading (see the sketch after this list).
* Refactored `JITKernel` to support instantiation from a database, enhancing flexibility in kernel management.
* Adjusted `CtypesKernelAdapter` and `CythonKernelAdapter` to utilize the new kernel loading mechanism from the database.
* Improved code formatting and readability across several files.
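
The multi-file caching described above can be pictured as follows (purely illustrative structure; the file names, signatures, and error handling are assumptions, not the real `KernelCache` implementation):

```python
import os


def save_kernel_sources(cache_dir: str, sources: dict) -> None:
    """Persist every generated source file (e.g. host and device code) for one kernel."""
    try:
        os.makedirs(cache_dir, exist_ok=True)
        for filename, code in sources.items():
            with open(os.path.join(cache_dir, filename), "w") as f:
                f.write(code)
    except OSError as exc:
        # Saving is best-effort: fall back to recompilation instead of failing the run.
        print(f"Warning: could not persist kernel cache: {exc}")


def load_kernel_sources(cache_dir: str, filenames: list) -> dict:
    """Load whichever cached sources exist; missing files simply trigger recompilation."""
    sources = {}
    for filename in filenames:
        path = os.path.join(cache_dir, filename)
        if os.path.exists(path):
            with open(path) as f:
                sources[filename] = f.read()
    return sources
```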

* lint fix

* Update bfloat16 matrix multiplication test case to use larger dimensions for improved coverage
parent 43bd9d3e
@@ -67,4 +67,5 @@ jobs:
       run: |
         source tilelang_ci/bin/activate
         cd testing/python
+        export TILELANG_CLEAR_CACHE=1
         python -m pytest
-Subproject commit 2654ce86a8cda7d28eab73db7e9104c90511c072
+Subproject commit c1c2a08a53f24886d2f82839fe304f2f1b6d0973
-# Copyright(c) Microsoft Corporation.
-# Licensed under the MIT License.
 # Learn a lot from the MLC - LLM Project
 # https://github.com/mlc-ai/mlc-llm/blob/main/CMakeLists.txt
...
@@ -87,7 +87,7 @@ Or install locally:
 sudo apt-get update
 sudo apt-get install -y python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev
-pip install .  # with -e option if you want to install in editable mode
+pip install -e . -v  # remove -e option if you don't want to install in editable mode, -v for verbose output
 ```

 ### Method 2: Build from Source
...
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
 # ruff: noqa
 import torch
 from tilelang.profiler import do_bench
...
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
 # ruff: noqa
 import math
 import torch
...
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
 # ruff: noqa
 import math
 import torch
...
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
 # ruff: noqa
 import math
 import torch
...
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
 import argparse
 import logging
 import torch
...
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
 # ruff: noqa: E712
 import math
 import torch
...
 import torch
 import tilelang
-from tilelang import Profiler
 from tilelang.autotuner import *
 import tilelang.language as T
 import itertools
@@ -145,14 +144,14 @@ if __name__ == "__main__":
             N, C, H, W, F, K, S, D, P, tune=args.tune)(
                 block_M=256, block_N=128, block_K=64, num_stages=4, threads=256)
         ref_program = partial(ref_program, stride=S, padding=P, dilation=D)
-        mod, params = tilelang.lower(program)
-        mod = Profiler(mod, params, [2], tilelang.TensorSupplyType.Normal)
-        mod.assert_allclose(ref_program, rtol=0.01, atol=0.01)
+        kernel = tilelang.compile(program, out_idx=[2])
+        profiler = kernel.get_profiler(tilelang.TensorSupplyType.Normal)
+        profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
         print("All checks pass.")
-        latency = mod.do_bench(ref_program, warmup=500)
+        latency = profiler.do_bench(ref_program, warmup=500)
         print("Ref: {:.2f} ms".format(latency))
         print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
-        latency = mod.do_bench(mod.func, warmup=500)
+        latency = profiler.do_bench(warmup=500)
         print("Tile-lang: {:.2f} ms".format(latency))
         print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
     else:
...
@@ -145,8 +145,8 @@ def calc_diff(x, y):
 def assert_tl_gemm_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
     gemm = tl_gemm(M, N, K, in_dtype, out_dtype, accum_dtype)
-    mod, params = TL.lower(gemm)
-    src_code = mod.imported_modules[0].get_source()
+    kernel = TL.compile(gemm, out_idx=[])
+    src_code = kernel.get_kernel_source()
     # src_code is the generated cuda source
     assert src_code is not None
@@ -162,16 +162,15 @@ def assert_tl_gemm_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
     C = torch.zeros(M, N, device="cuda", dtype=out_dtype)

-    mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)
-    mod(A_fp8, B_fp8, C, A_scale, B_scale)
+    kernel(A_fp8, B_fp8, C, A_scale, B_scale)

     # Get Reference Result
     ref_c = ref_deepgemm_fp8(A_fp8, B_fp8, A_scale, B_scale, out_dtype)
     diff = calc_diff(C, ref_c)
     print(f"diff: {diff}")
     assert diff < 1e-3

-    latency = mod.do_bench(mod.func, warmup=25)
+    profiler = kernel.get_profiler()
+    latency = profiler.do_bench(warmup=25)
     # Ensure that the latency is not None
     assert latency is not None
     print(f"latency: {latency} ms")
...
 # This benchmark script is modified based on: https://github.com/deepseek-ai/FlashMLA/blob/main/benchmark/bench_flash_mla.py
-# ruff: noqa
 import argparse
 import math
 import random
@@ -16,6 +16,7 @@ import tilelang
 from tilelang.profiler import do_bench
 from example_mla_decode_paged import mla_decode_tilelang

+
 def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
     query = query.float()
     key = key.float()
@@ -37,7 +38,8 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):

 @torch.inference_mode()
-def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
+def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q,
+                  h_kv, d, dv, causal, dtype):
     blocked_v = blocked_k[..., :dv]

     def ref_mla():
@@ -50,7 +52,8 @@ def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q,
                 q[i].transpose(0, 1),
                 blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1),
                 blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
-                h_q, h_kv,
+                h_q,
+                h_kv,
                 is_causal=causal,
             )
             out[i] = O.transpose(0, 1)
@@ -61,16 +64,24 @@ def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q,
     t = triton.testing.do_bench(ref_mla)
     return out_torch, lse_torch, t

+
 @torch.inference_mode()
-def run_flash_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
+def run_flash_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q,
+                  h_kv, d, dv, causal, dtype):
     blocked_v = blocked_k[..., :dv]

     tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv)

     def flash_mla():
         return flash_mla_with_kvcache(
-            q, blocked_k, block_table, cache_seqlens, dv,
-            tile_scheduler_metadata, num_splits, causal=causal,
+            q,
+            blocked_k,
+            block_table,
+            cache_seqlens,
+            dv,
+            tile_scheduler_metadata,
+            num_splits,
+            causal=causal,
         )

     out_flash, lse_flash = flash_mla()
@@ -79,13 +90,14 @@ def run_flash_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q,

 @torch.inference_mode()
-def run_flash_infer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
+def run_flash_infer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens,
+                    h_q, h_kv, d, dv, causal, dtype):
     assert d > dv, "mla with rope dim should be larger than no rope dim"
     q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous()
-    blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous()
+    blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[...,
+                                                                               dv:].contiguous()
     kv_indptr = [0]
     kv_indices = []
     for i in range(b):
@@ -96,15 +108,13 @@ def run_flash_infer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_
         kv_indptr.append(kv_indptr[-1] + num_blocks)
     for seq_len in cache_seqlens[1:]:
         kv_indptr.append((seq_len + block_size - 1) // block_size + kv_indptr[-1])
     q_indptr = torch.arange(0, b + 1).int() * s_q
     kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
     kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
     mla_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
-        torch.empty(128 * 1024 * 1024, dtype=torch.int8),
-        backend="fa3"
-    )
+        torch.empty(128 * 1024 * 1024, dtype=torch.int8), backend="fa3")
     mla_wrapper.plan(
         q_indptr,
         kv_indptr,
@@ -112,7 +122,7 @@ def run_flash_infer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_
         cache_seqlens,
         h_q,
         dv,
-        d-dv,
+        d - dv,
         block_size,
         causal,
         1 / math.sqrt(d),
@@ -121,7 +131,12 @@ def run_flash_infer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_
     )

     def flash_infer():
-        output, lse = mla_wrapper.run(q_nope.view(-1, h_q, dv), q_pe.view(-1, h_q, d-dv), blocked_k_nope, blocked_k_pe, return_lse=True)
+        output, lse = mla_wrapper.run(
+            q_nope.view(-1, h_q, dv),
+            q_pe.view(-1, h_q, d - dv),
+            blocked_k_nope,
+            blocked_k_pe,
+            return_lse=True)
         return output.view(b, -1, h_q, dv), lse.view(b, h_q, 1)

     out_flash, lse_flash = flash_infer()
@@ -164,7 +179,8 @@ def _mla_attn_kernel(
     offs_d_ckv = tl.arange(0, HEAD_DIM_CKV)
     cur_head = cur_head_id * BLOCK_H + tl.arange(0, BLOCK_H)
-    offs_q_nope = cur_batch * stride_q_nope_bs + cur_head[:, None] * stride_q_nope_h + offs_d_ckv[None, :]
+    offs_q_nope = cur_batch * stride_q_nope_bs + cur_head[:, None] * stride_q_nope_h + offs_d_ckv[
+        None, :]
     q_nope = tl.load(Q_nope + offs_q_nope)
     offs_d_kpe = tl.arange(0, HEAD_DIM_KPE)
@@ -210,7 +226,9 @@ def _mla_attn_kernel(
         e_sum = e_sum * re_scale + tl.sum(p, 1)
         e_max = n_e_max

-    offs_o = cur_batch * stride_o_b + cur_head[:, None] * stride_o_h + split_kv_id * stride_o_s + offs_d_ckv[None, :]
+    offs_o = cur_batch * stride_o_b + cur_head[:,
+                                               None] * stride_o_h + split_kv_id * stride_o_s + offs_d_ckv[
+                                                   None, :]
     tl.store(O + offs_o, acc / e_sum[:, None])
     offs_o_1 = cur_batch * stride_o_b + cur_head * stride_o_h + split_kv_id * stride_o_s + HEAD_DIM_CKV
     tl.store(O + offs_o_1, e_max + tl.log(e_sum))
@@ -260,13 +278,14 @@ def _mla_attn(
         attn_logits.stride(1),
         attn_logits.stride(2),
         BLOCK_H=BLOCK_H,
         BLOCK_N=BLOCK_N,
         NUM_KV_SPLITS=num_kv_splits,
         PAGE_SIZE=page_size,
         HEAD_DIM_CKV=head_dim_ckv,
         HEAD_DIM_KPE=head_dim_kpe,
     )

+
 @triton.jit
 def _mla_softmax_reducev_kernel(
     Logits,
@@ -310,7 +329,7 @@ def _mla_softmax_reducev_kernel(
         e_sum = e_sum * old_scale + exp_logic
         e_max = n_e_max

     tl.store(
         O + cur_batch * stride_o_b + cur_head * stride_o_h + offs_d_ckv,
         acc / e_sum,
@@ -340,6 +359,7 @@ def _mla_softmax_reducev(
         num_stages=2,
     )

+
 def mla_decode_triton(
     q_nope,
     q_pe,
@@ -372,22 +392,27 @@ def mla_decode_triton(
         b_seq_len,
         num_kv_splits,
     )

+
 @torch.inference_mode()
-def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
+def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q,
+                         cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
     blocked_v = blocked_k[..., :dv]
     assert d > dv, "mla with rope dim should be larger than no rope dim"
     q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous()
-    blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous()
+    blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[...,
+                                                                               dv:].contiguous()

     def flash_mla_triton():
         num_kv_splits = 32
         o = torch.empty([b * s_q, h_q, dv])
         attn_logits = torch.empty([b * s_q, h_q, num_kv_splits, dv + 1])
-        mla_decode_triton(q_nope.view(-1, h_q, dv), q_pe.view(-1, h_q, d-dv), blocked_k_nope.view(-1, dv), blocked_k_pe.view(-1, d-dv), o, block_table, cache_seqlens, attn_logits, num_kv_splits, 1 / math.sqrt(d), block_size)
+        mla_decode_triton(
+            q_nope.view(-1, h_q, dv), q_pe.view(-1, h_q, d - dv), blocked_k_nope.view(-1, dv),
+            blocked_k_pe.view(-1, d - dv), o, block_table, cache_seqlens, attn_logits,
+            num_kv_splits, 1 / math.sqrt(d), block_size)
         return o.view([b, s_q, h_q, dv])

     out_flash = flash_mla_triton()
@@ -396,31 +421,32 @@ def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size,

 @torch.inference_mode()
-def run_flash_mla_tilelang(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
+def run_flash_mla_tilelang(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q,
+                           cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
     assert d > dv, "mla with rope dim should be larger than no rope dim"
     q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous()
-    blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous()
+    blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[...,
+                                                                               dv:].contiguous()
     dpe = d - dv
     num_kv_splits = 1
     BLOCK_N = 64
     BLOCK_H = 64
     out_partial = torch.empty(b, h_q, num_kv_splits, dv, dtype=dtype, device=q.device)
     glse = torch.empty(b, h_q, num_kv_splits, dtype=dtype, device=q.device)
-    out = torch.empty(b, h_q, dv, dtype=dtype, device=q.device)
-    program = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H, num_kv_splits, block_size)
-    mod, params = tilelang.lower(program)
-    mod = tilelang.Profiler(mod, params, [8], tilelang.TensorSupplyType.Randn)
+    program = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H,
+                                  num_kv_splits, block_size)
+    kernel = tilelang.compile(program, out_idx=[8])

     def flash_mla_tilelang():
-        out = mod.func(
+        out = kernel.func(
             q_nope.view(-1, h_q, dv),
             q_pe.view(-1, h_q, dpe),
             blocked_k_nope.view(-1, h_kv, dv),
             blocked_k_pe.view(-1, h_kv, dpe),
             block_table,
             cache_seqlens,
             glse,
             out_partial,
@@ -431,6 +457,7 @@ def run_flash_mla_tilelang(q, block_table, blocked_k, max_seqlen_pad, block_size
     t = do_bench(flash_mla_tilelang)
     return out_flash, None, t

+
 FUNC_TABLE = {
     "torch": run_torch_mla,
     "tilelang": run_flash_mla_tilelang,
@@ -438,9 +465,12 @@ FUNC_TABLE = {
     "flash_infer": run_flash_infer,
     "flash_mla_triton": run_flash_mla_triton,
 }

 def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
-    print(f"comparing {baseline} vs {target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}")
+    print(
+        f"comparing {baseline} vs {target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}"
+    )
     device = torch.device("cuda:0")
     torch.set_default_dtype(dtype)
     torch.set_default_device(device)
@@ -451,21 +481,23 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal
     assert target in FUNC_TABLE
     baseline_func = FUNC_TABLE[baseline]
     target_func = FUNC_TABLE[target]
     total_seqlens = cache_seqlens.sum().item()
-    mean_seqlens = cache_seqlens.float().mean().int().item()
     max_seqlen = cache_seqlens.max().item()
     max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256
     # print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}")
     q = torch.randn(b, s_q, h_q, d)
     block_size = 64
-    block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)
+    block_table = torch.arange(
+        b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)
     blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
-    out_a, lse_a, perf_a = baseline_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype)
-    out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype)
+    out_a, lse_a, perf_a = baseline_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b,
+                                         s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype)
+    out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b,
+                                       s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype)
     torch.testing.assert_close(out_b.float(), out_a.float(), atol=1e-2, rtol=1e-2), "out"
     if target not in ["flash_infer", "flash_mla_triton", "flash_mla_tilelang"]:
         # flash_infer has a different lse return value
@@ -473,14 +505,21 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal
         torch.testing.assert_close(lse_b.float(), lse_a.float(), atol=1e-2, rtol=1e-2), "lse"
     FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
-    bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8)
-    print(f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10 ** 9 / perf_a:.0f} TFLOPS, {bytes / 10 ** 6 / perf_a:.0f} GB/s")
-    print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s")
-    return bytes / 10 ** 6 / perf_a, bytes / 10 ** 6 / perf_b
+    bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (
+        torch.finfo(dtype).bits // 8)
+    print(
+        f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10 ** 9 / perf_a:.0f} TFLOPS, {bytes / 10 ** 6 / perf_a:.0f} GB/s"
+    )
+    print(
+        f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s"
+    )
+    return bytes / 10**6 / perf_a, bytes / 10**6 / perf_b

 def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
-    print(f"{target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}")
+    print(
+        f"{target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}"
+    )
     torch.set_default_dtype(dtype)
     device = torch.device("cuda:0")
     torch.set_default_device(device)
@@ -489,24 +528,28 @@ def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
     random.seed(0)
     assert target in FUNC_TABLE
     target_func = FUNC_TABLE[target]
     total_seqlens = cache_seqlens.sum().item()
-    mean_seqlens = cache_seqlens.float().mean().int().item()
     max_seqlen = cache_seqlens.max().item()
     max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256
     # print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}")
     q = torch.randn(b, s_q, h_q, d)
     block_size = 64
-    block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)
+    block_table = torch.arange(
+        b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)
     blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
-    out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype)
+    out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b,
+                                       s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype)
     FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
-    bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8)
-    print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s")
-    return bytes / 10 ** 6 / perf_b
+    bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (
+        torch.finfo(dtype).bits // 8)
+    print(
+        f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s"
+    )
+    return bytes / 10**6 / perf_b

 available_targets = [
@@ -517,10 +560,26 @@ available_targets = [
     "flash_mla_triton",
 ]

-shape_configs = [
-    {"b": batch, "s_q": 1, "cache_seqlens": torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device="cuda"), "h_q": head, "h_kv": 1, "d": 512+64, "dv": 512, "causal": True, "dtype": torch.float16}
-    for batch in [128] for seqlen in [1024, 2048, 4096, 8192, 16384, 32768] for head in [128]
-]
+shape_configs = [{
+    "b":
+        batch,
+    "s_q":
+        1,
+    "cache_seqlens":
+        torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device="cuda"),
+    "h_q":
+        head,
+    "h_kv":
+        1,
+    "d":
+        512 + 64,
+    "dv":
+        512,
+    "causal":
+        True,
+    "dtype":
+        torch.float16
+} for batch in [128] for seqlen in [1024, 2048, 4096, 8192, 16384, 32768] for head in [128]]

 def get_args():
@@ -533,7 +592,7 @@ def get_args():
     args = parser.parse_args()
     return args

 if __name__ == "__main__":
     args = get_args()
     benchmark_type = "all" if args.all else f"{args.baseline}_vs_{args.target}" if args.compare else args.target
@@ -542,12 +601,26 @@ if __name__ == "__main__":
     for shape in shape_configs:
         if args.all:
             for target in available_targets:
-                perf = compare_a(target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"])
-                fout.write(f'{target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n')
+                perf = compare_a(target, shape["b"], shape["s_q"], shape["cache_seqlens"],
+                                 shape["h_q"], shape["h_kv"], shape["d"], shape["dv"],
+                                 shape["causal"], shape["dtype"])
+                fout.write(
+                    f'{target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n'
+                )
         elif args.compare:
-            perfa, prefb = compare_ab(args.baseline, args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"])
-            fout.write(f'{args.baseline},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perfa:.0f}\n')
-            fout.write(f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{prefb:.0f}\n')
+            perfa, prefb = compare_ab(args.baseline, args.target, shape["b"], shape["s_q"],
+                                      shape["cache_seqlens"], shape["h_q"], shape["h_kv"],
+                                      shape["d"], shape["dv"], shape["causal"], shape["dtype"])
+            fout.write(
+                f'{args.baseline},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perfa:.0f}\n'
+            )
+            fout.write(
+                f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{prefb:.0f}\n'
+            )
         elif args.one:
-            perf = compare_a(args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"])
-            fout.write(f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n')
+            perf = compare_a(args.target, shape["b"], shape["s_q"], shape["cache_seqlens"],
+                             shape["h_q"], shape["h_kv"], shape["d"], shape["dv"],
+                             shape["causal"], shape["dtype"])
+            fout.write(
+                f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n'
+            )
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
...
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
 import torch
 import torch.backends
 import tilelang.testing
 from tilelang import tvm as tvm
 from tvm import DataType
-import tilelang as TL
 import tilelang.language as T

 tilelang.testing.set_random_seed(0)
@@ -115,10 +112,10 @@ def run_gemm(
         num_threads,
     )

-    mod, params = TL.lower(program)
-    mod = TL.Profiler(mod, params, [2], TL.TensorSupplyType.Integer)
-    out = mod.run_once()
+    kernel = tilelang.compile(program, out_idx=[2])
+    profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer)
+    out = profiler.run_once()

     assert out is not None

     def ref_program(A, qB):
@@ -134,7 +131,7 @@ def run_gemm(
         C = C.to(torch.__getattribute__(out_dtype))
         return C

-    mod.assert_allclose(ref_program)
+    profiler.assert_allclose(ref_program)


 @tvm.testing.requires_package("bitblas")
@@ -363,8 +360,9 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct
     matmul = tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
         M, N, K, in_dtype, out_dtype, accum_dtype, transform_b)

-    mod, params = TL.lower(matmul)
-    src_code = mod.imported_modules[0].get_source()
+    kernel = tilelang.compile(matmul, out_idx=[2])
+    src_code = kernel.get_kernel_source()
+    profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer)
     # src_code is the generated cuda source
     assert src_code is not None
@@ -402,11 +400,9 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct
     QLB = ladder_permutate(qB.cpu()).cuda()
     QLB = lop3_permutate(QLB.cpu()).cuda()

-    mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)
-    mod(A, QLB, C)
-    latency = mod.do_bench(mod.func, warmup=25)
+    kernel(A, QLB, C)
+    latency = profiler.do_bench(warmup=25)
     # Ensure that the latency is not None
     assert latency is not None
...
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
 import tilelang
-from tilelang import Profiler
 import tilelang.language as T
 from tilelang.autotuner import *
-from tilelang import tvm
 from tvm import tir
 import itertools
 import torch
 import argparse
-from functools import partial


 def _tir_u8_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
@@ -103,11 +97,10 @@ def test_fp4_fp16_convert_close():
         "float16",
     )

-    mod, params = tilelang.lower(program)
-    mod = Profiler(mod, params, [1], tilelang.TensorSupplyType.Integer)
+    kernel = tilelang.compile(program, out_idx=[1])
     B = torch.randint(0, 16, (N, K // 2), dtype=torch.uint8, device="cuda").to(torch.uint8)
-    tl_out = mod.func(B)
+    tl_out = kernel(B)
     ref_out = torch_convert(B)
     assert torch.allclose(tl_out, ref_out, rtol=0.01, atol=0.01), (tl_out, ref_out)
     print("Pass")
@@ -291,14 +284,14 @@ if __name__ == "__main__":
         program = matmul(
             M, N, K, "float16", "float16", "float32", num_bits=4, tune=args.tune)(
                 block_M=128, block_N=128, block_K=128, num_stages=2, threads=256, split=1)
-        mod, params = tilelang.lower(program)
-        mod = Profiler(mod, params, [2], tilelang.TensorSupplyType.Integer)
-        mod.assert_allclose(ref_program, rtol=0.01, atol=0.01)
+        kernel = tilelang.compile(program, out_idx=[2])
+        profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer)
+        profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
         print("All checks pass.")
-        latency = mod.do_bench(ref_program, warmup=500)
+        latency = profiler.do_bench(ref_program, warmup=500)
         print("Ref: {:.2f} ms".format(latency))
         print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
-        latency = mod.do_bench(mod.func, warmup=500)
+        latency = profiler.do_bench(warmup=500)
         print("Tile-lang: {:.2f} ms".format(latency))
         print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
     else:
...
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
 import torch
 import torch.nn.functional as F
 import tilelang
...
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
 import torch
 import torch.nn.functional as F
 import tilelang
...
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
 import torch
 import torch.nn.functional as F
 import tilelang
...
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
 import torch
 import torch.nn.functional as F
 import tilelang
...