Commit 3de9f13c authored by Lei Wang, committed by LeiWang1999

[Refactor] Introduce KernelParam integration across modules (#223)

* [Refactor] Update KernelParam integration across modules

- Replaced instances of TensorType with KernelParam in various modules to standardize parameter handling.
- Updated JITKernel, BaseKernelAdapter, and CythonKernelAdapter to utilize KernelParam for improved type consistency.
- Enhanced Profiler class to include KernelParam in its parameters, ensuring better integration with the new parameter structure.
- Adjusted tensor handling in utility functions to accommodate the new KernelParam type, improving overall code clarity and maintainability.
- Updated copyright headers to reflect the correct organization.

* [Refactor] Clean up whitespace in kernel, profiler, and tensor modules

- Added blank lines for improved readability in kernel.py, __init__.py, and tensor.py.
- Enhanced code clarity by ensuring consistent formatting across these modules.

* [Enhancement] Add detailed docstrings to KernelParam and Profiler classes

- Enhanced KernelParam class with comprehensive docstrings for better understanding of its purpose and methods.
- Updated Profiler class to include detailed docstrings for its attributes and methods, improving code documentation and usability.
- Removed unused do_bench function to streamline the profiler module and improve clarity.

* [Refactor] Update type hints in do_bench function and clean up whitespace in profiler module

- Changed type hints for grad_to_none and quantiles parameters in do_bench function to use Optional for better clarity.
- Added a blank line in __init__.py for improved readability and consistency in the profiler module.

* [Refactor] Update type hint in do_bench function for consistency

- Changed the return type hint in the do_bench function from a union type to a more explicit List type for better clarity and consistency in type annotations.

* [Refactor] Update return type hint in do_bench function for clarity

- Changed the return type hint in the do_bench function from a union type to Union[float, List[float]] for improved clarity and consistency in type annotations.
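
A rough sketch of the resulting signature (only the `grad_to_none`/`quantiles` hints and the return type are taken from this change; the remaining parameters and defaults are illustrative assumptions):

```python
from typing import Any, Callable, List, Optional, Union

# Illustrative sketch only: parameter names other than grad_to_none and
# quantiles, and all defaults, are assumptions rather than the actual
# tilelang signature.
def do_bench(
    fn: Callable[[], Any],
    warmup: int = 25,
    rep: int = 100,
    grad_to_none: Optional[List] = None,
    quantiles: Optional[List[float]] = None,
) -> Union[float, List[float]]:
    ...
```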

* [Enhancement] Add func property to Profiler class for adapter access

- Introduced a `func` property on the Profiler class that exposes the underlying adapter and asserts the adapter is set before access, making the adapter easier to reach from profiling code.
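
Roughly, the property behaves like the following sketch (the internal attribute name `adapter` and the error message are assumptions for illustration, not taken from the source):

```python
# Sketch of the new Profiler.func property: return the adapter, but assert
# it has been set first. The attribute name is assumed.
@property
def func(self):
    assert self.adapter is not None, "Profiler adapter is not initialized"
    return self.adapter
```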

* [Refactor] Update kernel compilation and profiling in tests

- Replaced instances of `TL.lower` and `TL.Profiler` with `tilelang.compile` and the new profiler interface across multiple test files.
- Enhanced the kernel compilation process to utilize the updated API, improving consistency and maintainability in the testing framework.
- Updated assertions to use the new profiler methods for better clarity and functionality in performance testing.
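
Condensed before/after of the test migration (assuming `program` is a TileLang prim_func and `ref_program` a reference implementation, as in the tests below):

```python
import tilelang
import tilelang as TL

# Old API: lower the program, then wrap it in a Profiler by hand.
mod, params = TL.lower(program)
mod = TL.Profiler(mod, params, [2], TL.TensorSupplyType.Integer)
mod.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
latency = mod.do_bench(mod.func, warmup=25)

# New API: compile once, then ask the kernel for a profiler.
kernel = tilelang.compile(program, out_idx=[2])
profiler = kernel.get_profiler()
profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
latency = profiler.do_bench()
```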

* [Refactor] Simplify kernel invocation and remove unused parameters in tests

- Updated the kernel invocation in `test_tilelang_dynamic_symbolic.py` to directly assign the result to `C`, improving clarity.
- Removed the `execution_backend` parameter from `tilelang.compile` calls in `test_tilelang_jit_callback.py` and `test_tilelang_jit_gemm.py` for consistency with the updated API.
- Commented out the call to `tilelang.testing.main()` in `test_tilelang_jit_callback.py` and replaced it with a direct call to `test_gemm_jit_kernel()` to streamline test execution.
- Adjusted the dtype mapping in `TorchDLPackKernelAdapter` to use the parameter's dtype directly, enhancing code simplicity.
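
For reference, the updated invocation style: compiling with `out_idx` lets the kernel allocate and return the output, so tests no longer pre-allocate `C` with `torch.zeros` (condensed from the diff below):

```python
# The output tensor selected by out_idx is created by the kernel and
# returned directly, instead of being passed in as a pre-allocated buffer.
kernel = tilelang.compile(matmul, out_idx=[2])
C = kernel(A, B)
```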

* [Refactor] Remove unused imports in test files for cleaner code

- Eliminated unnecessary imports of `tilelang` as `TL` in various test files to enhance code clarity and maintainability.
- Updated multiple test files to streamline the codebase and reduce potential confusion from unused references.

* [Refactor] Simplify kernel invocation in tilelang kernel test

- Updated the kernel invocation in `test_tilelang_kernel_bf16_gemm_mma.py` to directly assign the result to `C`, enhancing code clarity and consistency with recent changes in the API.

* [Refactor] Simplify kernel invocation in tilelang kernel tests

- Updated kernel invocations in multiple test files to directly assign the result to `C`, improving code clarity and consistency with the updated API.
- Removed unnecessary initialization of `C` as a zero tensor, streamlining the code further.

* [Refactor] Update kernel invocation in tilelang transform tests

- Replaced the use of `TL.Profiler` with `tilelang.compile` in `test_tilelang_transform_simplify.py`, enhancing code clarity and consistency with the updated API.
- Streamlined the kernel invocation process by directly assigning the result to `C`, improving readability and maintainability of the test code.
parent 6bc8d6d3
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
 import torch
 import torch.backends
 import tilelang.testing
 from tilelang import tvm as tvm
-import tilelang as TL
 import tilelang.language as T
 from tilelang.intrinsics import make_mfma_swizzle_layout as make_swizzle_layout
 from tilelang.intrinsics.mfma_macro_generator import (
@@ -166,8 +162,8 @@ def tl_matmul(
 def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype="float32"):
     matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype)
-    mod, params = TL.lower(matmul)
-    src_code = mod.imported_modules[0].get_source()
+    kernel = tilelang.compile(matmul)
+    src_code = kernel.get_kernel_source()
     # src_code is the generated cuda source
     assert src_code is not None
@@ -180,11 +176,11 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype="floa
     C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype))
-    mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)
-    mod(A, B, C)
-    latency = mod.do_bench(mod.func, warmup=25)
+    kernel(A, B, C)
+    profiler = kernel.get_profiler()
+    latency = profiler.do_bench()
     # Ensure that the latency is not None
     assert latency is not None
...
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
 from tilelang import tvm as tvm
 import tilelang as tl
 import tilelang.language as T
@@ -84,8 +81,8 @@ def run_gemm(
         num_threads,
         k_pack=k_pack,
     )
-    mod, params = tl.lower(program)
-    mod = tl.Profiler(mod, params, [2], tl.TensorSupplyType.Integer)
+    kernel = tl.compile(program, out_idx=[2])
+    profiler = kernel.get_profiler()
     def ref_program(A, B):
         import torch
@@ -98,7 +95,7 @@ def run_gemm(
         C = C.to(torch.__getattribute__(out_dtype))
         return C
-    mod.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
+    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
 @tilelang.testing.requires_rocm
...
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
 import torch
 import torch.backends
 from tilelang import tvm as tvm
 import tilelang.testing
 from tvm import DataType
-import tilelang as TL
 import tilelang.language as T
 from tilelang.intrinsics.utils import get_swizzle_layout
 from tilelang.intrinsics.mma_macro_generator import (TensorCoreIntrinEmitter)
@@ -173,8 +169,8 @@ def tl_matmul_macro(
 def assert_tl_matmul_macro_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
     matmul = tl_matmul_macro(N, K, in_dtype, out_dtype, accum_dtype)
-    mod, params = TL.lower(matmul)
-    src_code = mod.imported_modules[0].get_source()
+    kernel = tilelang.compile(matmul, out_idx=[2])
+    src_code = kernel.get_kernel_source()
     # src_code is the generated cuda source
     assert src_code is not None
@@ -183,9 +179,7 @@ def assert_tl_matmul_macro_correctness(M, N, K, in_dtype, out_dtype, accum_dtype
     B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype))
     C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype))
-    mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)
-    mod(A, B, C)
+    C = kernel(A, B)
     # Get Reference Result
     ref_c = torch.matmul(A, B.T).to(getattr(torch, accum_dtype))
@@ -264,14 +258,14 @@ def assert_tl_matmul_block_correctness(
         num_stages,
         num_threads,
     )
-    mod, params = TL.lower(program)
+    kernel = tilelang.compile(program, out_idx=[2])
     A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
     B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype))
     C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype))
-    mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)
-    mod(A, B, C)
+    C = kernel(A, B)
     def ref_program(A, B):
         import torch
...
@@ -93,7 +93,7 @@ def run_gemm(
         code = f"// {stramp}\n" + code
         return code
-    matmul_kernel = tilelang.compile(program, out_idx=-1, execution_backend="dlpack")
+    matmul_kernel = tilelang.compile(program, out_idx=-1)
     kernel_source = matmul_kernel.get_kernel_source()
@@ -196,7 +196,7 @@ def run_gemm_jit_kernel(
         num_threads,
     )
-    matmul_kernel = tilelang.compile(program, out_idx=-1, execution_backend="dlpack")
+    matmul_kernel = tilelang.compile(program, out_idx=-1)
     A = torch.randn(M, K, dtype=torch.__getattribute__(in_dtype)).cuda()
     B = torch.randn(K, N, dtype=torch.__getattribute__(in_dtype)).cuda()
@@ -236,4 +236,5 @@ def test_gemm_jit_kernel():
 if __name__ == "__main__":
-    tilelang.testing.main()
+    # tilelang.testing.main()
+    test_gemm_jit_kernel()
@@ -31,7 +31,6 @@ def matmul(
     @tilelang.jit(
         out_idx=-1,  # create the output tensor during runtime
-        execution_backend="dlpack",
     )
     @T.prim_func
     def main(
@@ -206,7 +205,7 @@ def run_gemm_jit_kernel(
         num_threads,
     )
-    matmul_kernel = tilelang.compile(program, out_idx=-1, execution_backend="dlpack")
+    matmul_kernel = tilelang.compile(program, out_idx=-1)
     A = torch.randn(M, K, dtype=torch.__getattribute__(in_dtype)).cuda()
     B = torch.randn(K, N, dtype=torch.__getattribute__(in_dtype)).cuda()
...
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
 import torch
 import torch.backends
 from tilelang import tvm as tvm
 import tilelang.testing
 from tvm import DataType
-import tilelang as TL
 import tilelang.language as T
 from tilelang.intrinsics import get_swizzle_layout
 from tilelang.intrinsics.mma_macro_generator import (
@@ -182,8 +178,10 @@ def tl_matmul(
 def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
     matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype)
-    mod, params = TL.lower(matmul)
-    src_code = mod.imported_modules[0].get_source()
+    kernel = tilelang.compile(matmul, out_idx=[2])
+    profiler = kernel.get_profiler()
+    src_code = kernel.get_kernel_source()
     # src_code is the generated cuda source
     assert src_code is not None
@@ -203,11 +201,9 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
     C = torch.zeros(M, N, device="cuda", dtype=accum_dtype)
-    mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)
-    mod(A, B, C)
-    latency = mod.do_bench(mod.func, warmup=25)
+    C = kernel(A, B)
+    latency = profiler.do_bench()
     # Ensure that the latency is not None
     assert latency is not None
...
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
 from tilelang import tvm as tvm
 import tilelang.testing
-import tilelang as tl
 import tilelang.language as T
@@ -69,8 +65,8 @@ def run_conv(N,
     program = convolution(N, C, H, W, F, K, S, D, P, in_dtype, out_dtype, dtypeAccum, block_M,
                           block_N, block_K, num_stages, threads)
-    mod, params = tl.lower(program)
-    mod = tl.Profiler(mod, params, [2], tl.TensorSupplyType.Integer)
+    kernel = tilelang.compile(program, out_idx=[2])
+    profiler = kernel.get_profiler()
     def ref_program(A, B):
         import torch
@@ -81,7 +77,7 @@ def run_conv(N,
         C = C.permute(0, 2, 3, 1)  # N, C, H, W -> N, H, W, C
         return C.to(torch.__getattribute__(out_dtype))
-    mod.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
+    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
 def test_conv_f16f16f16_k3s1d1p1():
...
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
 import torch
 import torch.backends
 import tilelang.testing
 from tilelang import tvm as tvm
 from tvm import DataType, tir
-import tilelang as TL
 import tilelang.language as T
-from tilelang import JITKernel, Profiler
 tilelang.testing.set_random_seed(0)
@@ -102,11 +98,10 @@ def test_fp4_fp16_convert_close():
         "float16",
     )
-    mod, params = tilelang.lower(program)
-    mod = Profiler(mod, params, [1], tilelang.TensorSupplyType.Integer)
+    kernel = tilelang.compile(program, out_idx=[1])
     B = torch.randint(0, 16, (N, K // 2), dtype=torch.uint8, device="cuda").to(torch.uint8)
-    tl_out = mod.func(B)
+    tl_out = kernel(B)
     ref_out = torch_convert(B)
     assert torch.allclose(tl_out, ref_out, rtol=0.01, atol=0.01), (tl_out, ref_out)
     print("Pass")
@@ -202,7 +197,7 @@ def assert_simple_impl_float16xfp4_gemm(M,
     func = matmul_fp16xfp4(M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K,
                            num_stages, threads)
-    torch_func = JITKernel(func, [2])
+    torch_func = tilelang.compile(func, out_idx=[2])
     profiler = torch_func.get_profiler()
     profiler.assert_allclose(ref_program)
@@ -318,10 +313,10 @@ def run_gemm(
         num_threads,
     )
-    mod, params = TL.lower(program)
-    mod = TL.Profiler(mod, params, [2], TL.TensorSupplyType.Integer)
-    out = mod.run_once()
+    kernel = tilelang.compile(program, out_idx=[2])
+    profiler = kernel.get_profiler()
+    out = profiler.run_once()
     assert out is not None
     def ref_program(A, qB):
@@ -337,7 +332,7 @@ def run_gemm(
         C = C.to(torch.__getattribute__(out_dtype))
         return C
-    mod.assert_allclose(ref_program)
+    profiler.assert_allclose(ref_program)
 @tvm.testing.requires_package("bitblas")
@@ -566,8 +561,10 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct
     matmul = tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
         M, N, K, in_dtype, out_dtype, accum_dtype, transform_b)
-    mod, params = TL.lower(matmul)
-    src_code = mod.imported_modules[0].get_source()
+    kernel = tilelang.compile(matmul, out_idx=[2])
+    profiler = kernel.get_profiler()
+    src_code = kernel.get_kernel_source()
     # src_code is the generated cuda source
     assert src_code is not None
@@ -605,11 +602,9 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct
     QLB = ladder_permutate(qB.cpu()).cuda()
     QLB = lop3_permutate(QLB.cpu()).cuda()
-    mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)
-    mod(A, QLB, C)
-    latency = mod.do_bench(mod.func, warmup=25)
+    kernel(A, QLB, C)
+    latency = profiler.do_bench()
     # Ensure that the latency is not None
     assert latency is not None
...
 from tilelang import tvm as tvm
 import tilelang.testing
-import tilelang as tl
 import torch
@@ -53,15 +52,15 @@ def run_elementwise_add(
         num_threads,
     )
-    mod, params = tl.lower(program)
-    mod = tl.Profiler(mod, params, [2], tl.TensorSupplyType.Integer)
+    kernel = tilelang.compile(program, out_idx=[2])
+    profiler = kernel.get_profiler()
     def ref_program(A, B):
         C = torch.add(A, B)
         C = C.to(torch.__getattribute__(out_dtype))
         return C
-    mod.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
+    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
 def test_elementwise_add_f32():
...
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
 from tilelang import tvm as tvm
 import tilelang.testing
-import tilelang as tl
 import tilelang.language as T
 tilelang.testing.set_random_seed(0)
@@ -130,8 +126,8 @@ def run_chunk_scan(batch,
     program = chunk_scan_fwd(batch, seqlen, chunk_size, ngroups, nheads, headdim, dstate, block_M,
                              block_N, block_K, block_Dstate, num_stages, threads)
-    mod, params = tl.lower(program)
-    mod = tl.Profiler(mod, params, [7], tl.TensorSupplyType.Integer)
+    kernel = tilelang.compile(program, out_idx=[7])
+    profiler = kernel.get_profiler()
     def ref_program(cb, x, dt, dA_cumsum, C, prev_states, D):
         import torch
@@ -182,7 +178,7 @@ def run_chunk_scan(batch,
         out = out + x * D
         return out
-    mod.assert_allclose(ref_program, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
+    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
 def chunk_state_fwd(batch,
@@ -275,8 +271,8 @@ def run_chunk_state(batch,
     program = chunk_state_fwd(batch, seqlen, chunk_size, ngroups, nheads, headdim, dstate, block_M,
                               block_N, block_K, num_stages, threads)
-    mod, params = tl.lower(program)
-    mod = tl.Profiler(mod, params, [4], tl.TensorSupplyType.Integer)
+    kernel = tilelang.compile(program, out_idx=[4])
+    profiler = kernel.get_profiler()
     def ref_program(B, x, dt, dA_cumsum):
         """
@@ -313,7 +309,7 @@ def run_chunk_state(batch,
         return torch.einsum("bclhn,bhcl,bhcl,bclhp->bchpn", B.to(x.dtype), decay_states.to(x.dtype),
                             dt.to(x.dtype), x)
-    mod.assert_allclose(ref_program, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
+    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
 def test_chunk_scan():
...
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
 import torch
 import torch.backends
 from tilelang import tvm as tvm
 import tilelang.testing
 from tvm import DataType
-import tilelang as TL
 import tilelang.language as T
 from tilelang.intrinsics import get_swizzle_layout
 from tilelang.intrinsics.mma_macro_generator import (
@@ -181,8 +177,10 @@ def tl_matmul(
 def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
     matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype)
-    mod, params = TL.lower(matmul)
-    src_code = mod.imported_modules[0].get_source()
+    kernel = tilelang.compile(matmul, out_idx=[2])
+    profiler = kernel.get_profiler()
+    src_code = kernel.get_kernel_source()
     print(src_code)
     # src_code is the generated cuda source
     assert src_code is not None
@@ -201,13 +199,9 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
     A = torch.randn(M, K).to(in_dtype).cuda() - 0.5
     B = torch.randn(N, K).to(in_dtype).cuda() - 0.5
-    C = torch.zeros(M, N, device="cuda", dtype=accum_dtype)
-    mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)
-    mod(A, B, C)
-    latency = mod.do_bench(mod.func, warmup=25)
+    C = kernel(A, B)
+    latency = profiler.do_bench()
     # Ensure that the latency is not None
     assert latency is not None
...
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
 from tilelang import tvm as tvm
 import tilelang.testing
-import tilelang as tl
 def matmul(
@@ -85,8 +81,8 @@ def run_gemm(
         num_threads,
     )
-    mod, params = tl.lower(program)
-    mod = tl.Profiler(mod, params, [2], tl.TensorSupplyType.Integer)
+    kernel = tilelang.compile(program, out_idx=[2])
+    profiler = kernel.get_profiler()
     def ref_program(A, B):
         import torch
@@ -99,7 +95,7 @@ def run_gemm(
         C = C.to(torch.__getattribute__(out_dtype))
         return C
-    mod.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
+    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
 def test_gemm_f16f16f16_nn():
...
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
 import torch
 import torch.backends
 from tilelang import tvm as tvm
 import tilelang.testing
 from tvm import DataType
-import tilelang as TL
 import tilelang.language as T
 from tilelang.intrinsics import get_swizzle_layout
 from tilelang.intrinsics.mma_macro_generator import (
@@ -182,8 +178,10 @@ def tl_matmul(
 def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
     matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype)
-    mod, params = TL.lower(matmul)
-    src_code = mod.imported_modules[0].get_source()
+    kernel = tilelang.compile(matmul, out_idx=[2])
+    profiler = kernel.get_profiler()
+    src_code = kernel.get_kernel_source()
     # src_code is the generated cuda source
     assert src_code is not None
@@ -201,13 +199,9 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
     A = torch.randn(M, K).to(in_dtype).cuda() - 0.5
     B = torch.randn(N, K).to(in_dtype).cuda() - 0.5
-    C = torch.zeros(M, N, device="cuda", dtype=accum_dtype)
-    mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)
-    mod(A, B, C)
-    latency = mod.do_bench(mod.func, warmup=25)
+    C = kernel(A, B)
+    latency = profiler.do_bench()
     # Ensure that the latency is not None
     assert latency is not None
...
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
 import torch
 import torch.backends
 import tilelang.testing
 from tilelang import tvm as tvm
 from tvm import DataType
-import tilelang as TL
 import tilelang.language as T
 from tilelang.intrinsics import get_swizzle_layout
 from tilelang.transform import simplify_prim_func
@@ -142,8 +138,10 @@ def tl_matmul_simt(
 def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
     matmul = tl_matmul_simt(M, N, K, in_dtype, out_dtype, accum_dtype)
-    mod, params = TL.lower(matmul)
-    src_code = mod.imported_modules[0].get_source()
+    kernel = tilelang.compile(matmul, out_idx=[2])
+    profiler = kernel.get_profiler()
+    src_code = kernel.get_kernel_source()
     print(src_code)
     # src_code is the generated cuda source
     assert src_code is not None
@@ -155,13 +153,9 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
     A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
     B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype))
-    C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype))
-    mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)
-    mod(A, B, C)
-    latency = mod.do_bench(mod.func, warmup=25)
+    C = kernel(A, B)
+    latency = profiler.do_bench()
    # Ensure that the latency is not None
     assert latency is not None
...
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
 import torch
 import torch.backends
 import tilelang
 from tilelang import tvm as tvm
 import tilelang.testing
-import tilelang as TL
 import tilelang.language as T
 from tilelang.intrinsics import (
     make_mma_swizzle_layout as make_swizzle_layout,)
@@ -168,21 +164,21 @@ def tl_matmul(
 def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
     matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype)
-    mod, params = TL.lower(matmul)
-    src_code = mod.imported_modules[0].get_source()
+    kernel = tilelang.compile(matmul, out_idx=[2])
+    profiler = kernel.get_profiler()
+    src_code = kernel.get_kernel_source()
     # src_code is the generated cuda source
     assert src_code is not None
     A = torch.randint(0, 4, (M, K), device="cuda", dtype=getattr(torch, in_dtype))
     B = torch.randint(0, 4, (N, K), device="cuda", dtype=getattr(torch, in_dtype))
-    C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype))
     compressed_A = (A[:, ::2] & 0x0F) + ((A[:, 1::2] & 0x0F) << 4)
     compressed_B = (B[:, ::2] & 0x0F) + ((B[:, 1::2] & 0x0F) << 4)
-    mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)
-    mod(compressed_A, compressed_B, C)
+    C = kernel(compressed_A, compressed_B)
     print(C)
-    latency = mod.do_bench(mod.func, warmup=25)
+    latency = profiler.do_bench()
     print(latency)
     # Ensure that the latency is not None
     assert latency is not None
@@ -358,15 +354,16 @@ def tl_matmul_weight_only_transform(
 def assert_tl_matmul_weight_only_transform_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
     matmul = tl_matmul_weight_only_transform(M, N, K, in_dtype, out_dtype, accum_dtype)
-    mod, params = TL.lower(matmul)
-    src_code = mod.imported_modules[0].get_source()
+    kernel = tilelang.compile(matmul, out_idx=[2])
+    profiler = kernel.get_profiler()
+    src_code = kernel.get_kernel_source()
     # src_code is the generated cuda source
     assert src_code is not None
     transform_b = 3
     A = torch.randint(0, 4, (M, K), device="cuda", dtype=getattr(torch, in_dtype))
     B = torch.randint(0, 4, (N, K), device="cuda", dtype=getattr(torch, in_dtype))
-    C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype))
     compressed_A = (A[:, ::2] & 0x0F) + ((A[:, 1::2] & 0x0F) << 4)
     compressed_B = (B[:, ::2] & 0x0F) + ((B[:, 1::2] & 0x0F) << 4)
@@ -380,13 +377,10 @@ def assert_tl_matmul_weight_only_transform_correctness(M, N, K, in_dtype, out_dt
     )
     ladder_permutate = tilelang.ops.LadderPermutate(ladder_permutate_config)
-    mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)
     LB = ladder_permutate(compressed_B.cpu()).cuda()
-    mod(compressed_A, LB, C)
-    latency = mod.do_bench(mod.func, warmup=25)
+    C = kernel(compressed_A, LB)
+    latency = profiler.do_bench()
     print(f"Latency: {latency}")
     # Ensure that the latency is not None
     assert latency is not None
...
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
 from tilelang import tvm as tvm
 import tilelang.testing
-import tilelang as tl
 import tilelang.language as T
@@ -132,8 +128,8 @@ def run_mha(batch, heads, seq_len, dim, is_causal, block_M, block_N, num_stages=
     program = flashattn(batch, heads, seq_len, dim, is_causal, block_M, block_N, num_stages,
                         threads)
-    mod, params = tl.lower(program)
-    mod = tl.Profiler(mod, params, [3], tl.TensorSupplyType.Integer)
+    kernel = tilelang.compile(program, out_idx=[3])
+    profiler = kernel.get_profiler()
     def ref_program(Q, K, V):
         import torch
@@ -150,7 +146,7 @@ def run_mha(batch, heads, seq_len, dim, is_causal, block_M, block_N, num_stages=
         output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
         return output
-    mod.assert_allclose(ref_program, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
+    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
 def test_mha_causal_dim64():
...
 import tilelang.testing
-import tilelang as tl
 def clamp(
@@ -35,8 +34,8 @@ def run_clamp(
 ):
     program = clamp(N, block_N, dtype, min, max)
-    mod, params = tl.lower(program)
-    profiler = tl.Profiler(mod, params, [1], tl.TensorSupplyType.Integer)
+    kernel = tilelang.compile(program, out_idx=[1])
+    profiler = kernel.get_profiler()
     def ref_program(A):
         import torch
@@ -86,8 +85,8 @@ def run_clamp_v2(
         block_N,
         dtype,
     )
-    mod, params = tl.lower(program)
-    profiler = tl.Profiler(mod, params, [1], tl.TensorSupplyType.Integer)
+    kernel = tilelang.compile(program, out_idx=[1])
+    profiler = kernel.get_profiler()
     def ref_program(A):
         import torch
...
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
 from tilelang import tvm as tvm
 import tilelang.testing
-import tilelang as tl
 from tilelang import primitives as P
@@ -85,9 +81,9 @@ def run_matmul_ssr(
         num_stages,
         num_threads,
     )
-    mod, params = tl.lower(program)
-    mod = tl.Profiler(mod, params, [2], tl.TensorSupplyType.Integer)
-    print(mod.get_kernel_source())
+    kernel = tilelang.compile(program, out_idx=[2])
+    profiler = kernel.get_profiler()
+    print(kernel.get_kernel_source())
     def ref_program(A, B):
         import torch
@@ -100,7 +96,7 @@ def run_matmul_ssr(
         C = C.to(torch.__getattribute__(out_dtype))
         return C
-    mod.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
+    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
 def test_gemm_f16f16f16_nt_ssr():
@@ -206,9 +202,9 @@ def run_matmul_rsr(
         num_stages,
         num_threads,
     )
-    mod, params = tl.lower(program)
-    mod = tl.Profiler(mod, params, [2], tl.TensorSupplyType.Integer)
-    print(mod.get_kernel_source())
+    kernel = tilelang.compile(program, out_idx=[2])
+    profiler = kernel.get_profiler()
+    print(kernel.get_kernel_source())
     def ref_program(A, B):
         import torch
@@ -221,7 +217,7 @@ def run_matmul_rsr(
         C = C.to(torch.__getattribute__(out_dtype))
         return C
-    mod.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
+    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
 # TODO(lei): Fix the test case in future release
@@ -329,8 +325,8 @@ def run_matmul_rrr(
         num_stages,
         num_threads,
     )
-    mod, params = tl.lower(program)
-    mod = tl.Profiler(mod, params, [2], tl.TensorSupplyType.Integer)
+    kernel = tilelang.compile(program, out_idx=[2])
+    profiler = kernel.get_profiler()
     def ref_program(A, B):
         import torch
@@ -343,7 +339,7 @@ def run_matmul_rrr(
         C = C.to(torch.__getattribute__(out_dtype))
         return C
-    mod.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
+    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
 # def test_gemm_f16f16f16_nt_rrr():
...
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
 from tilelang import tvm as tvm
 import tilelang.testing
-import tilelang as tl
 def matmul(
@@ -85,8 +81,8 @@ def run_gemm(
         num_threads,
     )
-    mod, params = tl.lower(program)
-    profiler = tl.Profiler(mod, params, [2], tl.TensorSupplyType.Integer)
+    kernel = tilelang.compile(program, out_idx=[2])
+    profiler = kernel.get_profiler()
     def ref_program(A, B):
         import torch
@@ -226,8 +222,8 @@ def run_gemm_rs(
         num_threads,
     )
-    mod, params = tl.lower(program)
-    profiler = tl.Profiler(mod, params, [2], tl.TensorSupplyType.Integer)
+    kernel = tilelang.compile(program, out_idx=[2])
+    profiler = kernel.get_profiler()
     def ref_program(A, B):
         import torch
...
@@ -76,15 +76,12 @@ def test_matmul():
     func = matmul(1024, 1024, 1024, 128, 128, 32)
     mod = tvm.IRModule({func.attrs["global_symbol"]: func})
     mod = tl.transform.Simplify()(mod)
-    rt_mod, params = tl.lower(mod.functions_items()[0][1], runtime_only=False)
-    # TODO Profiler only support TensorType, not dynamic variable
-    profiler = tl.Profiler(rt_mod, params, result_idx=[2])
+    kernel = tl.compile(mod["main"], out_idx=[2])
     import torch
     a = torch.randn(1024, 1024, dtype=torch.float16).cuda().half()
     b = torch.randn(1024, 1024, dtype=torch.float16).cuda().half()
-    c = profiler(a, b)
+    c = kernel(a, b)
     ref_c = a @ b
     ref_c = ref_c.float()
...