Commit 3de9f13c authored by Lei Wang, committed by LeiWang1999

[Refactor] Introduce KernelParam integration across modules (#223)

* [Refactor] Update KernelParam integration across modules

- Replaced instances of TensorType with KernelParam in various modules to standardize parameter handling (the idea is sketched below).
- Updated JITKernel, BaseKernelAdapter, and CythonKernelAdapter to use KernelParam for improved type consistency.
- Enhanced the Profiler class to accept KernelParam in its parameters, ensuring better integration with the new parameter structure.
- Adjusted tensor handling in utility functions to accommodate the new KernelParam type, improving overall code clarity and maintainability.
- Updated copyright headers to reflect the correct organization.
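As context, a minimal, self-contained sketch of the idea: one descriptor type that kernels, adapters, and the profiler can all share, instead of passing raw TensorType objects around. All names and fields below are illustrative, not tilelang's actual definition.

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class KernelParam:
    """Illustrative stand-in for a shared kernel-parameter descriptor."""
    dtype: str                  # e.g. "float16"
    shape: List[int]            # static dimensions of the buffer
    name: Optional[str] = None  # optional name for diagnostics

# Adapters and the profiler can then consume one uniform list:
params = [
    KernelParam("float16", [1024, 1024], "A"),
    KernelParam("float16", [1024, 1024], "B"),
    KernelParam("float16", [1024, 1024], "C"),
]
```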

* [Refactor] Clean up whitespace in kernel, profiler, and tensor modules

- Added blank lines for improved readability in kernel.py, __init__.py, and tensor.py.
- Enhanced code clarity by ensuring consistent formatting across these modules.

* [Enhancement] Add detailed docstrings to KernelParam and Profiler classes

- Enhanced the KernelParam class with comprehensive docstrings explaining its purpose and methods.
- Updated the Profiler class with detailed docstrings for its attributes and methods, improving code documentation and usability.
- Removed the unused do_bench function to streamline the profiler module and improve clarity.

* [Refactor] Update type hints in the do_bench function and clean up whitespace in the profiler module

- Changed the type hints of the grad_to_none and quantiles parameters of do_bench to Optional for better clarity.
- Added a blank line in __init__.py for improved readability and consistency in the profiler module.

* [Refactor] Update type hint in the do_bench function for consistency

- Changed the return type hint of do_bench from a union type to a more explicit List type for better clarity and consistency in type annotations.

* [Refactor] Update return type hint in the do_bench function for clarity

- Changed the return type hint of do_bench to Union[float, List[float]] for improved clarity and consistency in type annotations (the resulting signature is sketched below).
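Taken together, the do_bench bullets above imply a signature along these lines. This is a minimal sketch, not tilelang's implementation: the `fn` and `rep` parameters and all defaults are assumptions, and the timing loop uses a plain CPU timer rather than CUDA events; only the hinted parameters and return type come from these commits.

```python
import time
from typing import Callable, List, Optional, Union

def do_bench(
    fn: Callable[[], object],
    warmup: int = 25,                         # assumed default
    rep: int = 100,                           # assumed default
    grad_to_none: Optional[List] = None,      # Optional[...] per this commit
    quantiles: Optional[List[float]] = None,  # Optional[...] per this commit
) -> Union[float, List[float]]:               # return type per this commit
    """Time `fn`, returning mean latency in ms, or quantile latencies."""
    for _ in range(warmup):
        fn()
    times = []
    for _ in range(rep):
        if grad_to_none is not None:
            # Clear gradients so backward-heavy fns don't accumulate state.
            for tensor in grad_to_none:
                tensor.grad = None
        start = time.perf_counter()
        fn()
        times.append((time.perf_counter() - start) * 1e3)
    times.sort()
    if quantiles is not None:
        return [times[int(q * (len(times) - 1))] for q in quantiles]
    return sum(times) / len(times)
```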

* [Enhancement] Add func property to Profiler class for adapter access

- Introduced a new `func` property in the Profiler class that returns the adapter, asserting that the adapter has been set before retrieval. This makes the adapter easier to reach from the Profiler (see the sketch below).
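A property of the shape described here might look like the following sketch (class scaffolding and attribute names are assumed, not taken from the tilelang source):

```python
class Profiler:
    """Sketch of the adapter-guard property; not the real tilelang Profiler."""

    def __init__(self, adapter=None):
        self.adapter = adapter

    @property
    def func(self):
        # The adapter must be attached before the callable can be retrieved.
        assert self.adapter is not None, "Profiler.adapter is not set"
        return self.adapter
```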

* [Refactor] Update kernel compilation and profiling in tests

- Replaced instances of `TL.lower` and `TL.Profiler` with `tilelang.compile` and the new profiler interface across multiple test files (the before/after shape is shown below).
- Ported the kernel compilation in the tests to the updated API, improving consistency and maintainability of the testing framework.
- Updated assertions to use the new profiler methods for better clarity and functionality in performance testing.
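The before/after shape of this migration, quoted from the hunks in the diff below:

```python
# Old API: lower to a module, then wrap module + params in a Profiler.
mod, params = TL.lower(program)
mod = TL.Profiler(mod, params, [2], TL.TensorSupplyType.Integer)
mod.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
latency = mod.do_bench(mod.func, warmup=25)

# New API: compile once, then ask the compiled kernel for its profiler.
kernel = tilelang.compile(program, out_idx=[2])
profiler = kernel.get_profiler()
profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
latency = profiler.do_bench()
```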

* [Refactor] Simplify kernel invocation and remove unused parameters in tests

- Updated the kernel invocation in `test_tilelang_dynamic_symbolic.py` to directly assign the result to `C`, improving clarity.
- Removed the `execution_backend` parameter from `tilelang.compile` calls in `test_tilelang_jit_callback.py` and `test_tilelang_jit_gemm.py` for consistency with the updated API.
- Commented out the call to `tilelang.testing.main()` in `test_tilelang_jit_callback.py` and replaced it with a direct call to `test_gemm_jit_kernel()` to streamline test execution.
- Adjusted the dtype mapping in `TorchDLPackKernelAdapter` to use the parameter's dtype directly, enhancing code simplicity.

* [Refactor] Remove unused imports in test files for cleaner code

- Eliminated unnecessary imports of `tilelang` as `TL` in various test files to enhance code clarity and maintainability.
- Updated multiple test files to streamline the codebase and reduce potential confusion from unused references.

* [Refactor] Simplify kernel invocation in tilelang kernel test

- Updated the kernel invocation in `test_tilelang_kernel_bf16_gemm_mma.py` to directly assign the result to `C`, enhancing code clarity and consistency with recent changes in the API.

* [Refactor] Simplify kernel invocation in tilelang kernel tests

- Updated kernel invocations in multiple test files to directly assign the result to `C`, improving code clarity and consistency with the updated API (as shown below).
- Removed the now-unnecessary initialization of `C` as a zero tensor, streamlining the code further.
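Concretely, the calling convention changes as follows (both forms quoted from the hunks below; `out_idx` tells `tilelang.compile` which argument the kernel should allocate and return):

```python
# Old: the caller pre-allocates the output and the kernel writes into it.
C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype))
mod(A, B, C)

# New: with out_idx=[2] set at compile time, the kernel returns C directly.
C = kernel(A, B)
```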

* [Refactor] Update kernel invocation in tilelang transform tests

- Replaced the use of `TL.Profiler` with `tilelang.compile` in `test_tilelang_transform_simplify.py`, enhancing code clarity and consistency with the updated API.
- Streamlined the kernel invocation process by directly assigning the result to `C`, improving readability and maintainability of the test code.
parent 6bc8d6d3
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.backends
import tilelang.testing
from tilelang import tvm as tvm
import tilelang as TL
import tilelang.language as T
from tilelang.intrinsics import make_mfma_swizzle_layout as make_swizzle_layout
from tilelang.intrinsics.mfma_macro_generator import (
@@ -166,8 +162,8 @@ def tl_matmul(
 def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype="float32"):
     matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype)
-    mod, params = TL.lower(matmul)
-    src_code = mod.imported_modules[0].get_source()
+    kernel = tilelang.compile(matmul)
+    src_code = kernel.get_kernel_source()
     # src_code is the generated cuda source
     assert src_code is not None
@@ -180,11 +176,11 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype="floa
     C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype))
-    mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)
-    mod(A, B, C)
+    kernel(A, B, C)
+    profiler = kernel.get_profiler()
-    latency = mod.do_bench(mod.func, warmup=25)
+    latency = profiler.do_bench()
     # Ensure that the latency is not None
     assert latency is not None
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from tilelang import tvm as tvm
import tilelang as tl
import tilelang.language as T
@@ -84,8 +81,8 @@ def run_gemm(
         num_threads,
         k_pack=k_pack,
     )
-    mod, params = tl.lower(program)
-    mod = tl.Profiler(mod, params, [2], tl.TensorSupplyType.Integer)
+    kernel = tl.compile(program, out_idx=[2])
+    profiler = kernel.get_profiler()

     def ref_program(A, B):
         import torch
@@ -98,7 +95,7 @@ def run_gemm(
         C = C.to(torch.__getattribute__(out_dtype))
         return C

-    mod.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
+    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)

 @tilelang.testing.requires_rocm
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.backends
from tilelang import tvm as tvm
import tilelang.testing
from tvm import DataType
import tilelang as TL
import tilelang.language as T
from tilelang.intrinsics.utils import get_swizzle_layout
from tilelang.intrinsics.mma_macro_generator import (TensorCoreIntrinEmitter)
@@ -173,8 +169,8 @@ def tl_matmul_macro(
 def assert_tl_matmul_macro_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
     matmul = tl_matmul_macro(N, K, in_dtype, out_dtype, accum_dtype)
-    mod, params = TL.lower(matmul)
-    src_code = mod.imported_modules[0].get_source()
+    kernel = tilelang.compile(matmul, out_idx=[2])
+    src_code = kernel.get_kernel_source()
     # src_code is the generated cuda source
     assert src_code is not None
@@ -183,9 +179,7 @@ def assert_tl_matmul_macro_correctness(M, N, K, in_dtype, out_dtype, accum_dtype
     B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype))
-    C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype))
-    mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)
-    mod(A, B, C)
+    C = kernel(A, B)
     # Get Reference Result
     ref_c = torch.matmul(A, B.T).to(getattr(torch, accum_dtype))
@@ -264,14 +258,14 @@ def assert_tl_matmul_block_correctness(
         num_stages,
         num_threads,
     )
-    mod, params = TL.lower(program)
+    kernel = tilelang.compile(program, out_idx=[2])
     A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
     B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype))
-    C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype))
-    mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)
-    mod(A, B, C)
+    C = kernel(A, B)

     def ref_program(A, B):
         import torch
......
@@ -93,7 +93,7 @@ def run_gemm(
         code = f"// {stramp}\n" + code
         return code

-    matmul_kernel = tilelang.compile(program, out_idx=-1, execution_backend="dlpack")
+    matmul_kernel = tilelang.compile(program, out_idx=-1)
     kernel_source = matmul_kernel.get_kernel_source()
@@ -196,7 +196,7 @@ def run_gemm_jit_kernel(
         num_threads,
     )
-    matmul_kernel = tilelang.compile(program, out_idx=-1, execution_backend="dlpack")
+    matmul_kernel = tilelang.compile(program, out_idx=-1)
     A = torch.randn(M, K, dtype=torch.__getattribute__(in_dtype)).cuda()
     B = torch.randn(K, N, dtype=torch.__getattribute__(in_dtype)).cuda()
@@ -236,4 +236,5 @@ def test_gemm_jit_kernel():
 if __name__ == "__main__":
-    tilelang.testing.main()
+    # tilelang.testing.main()
+    test_gemm_jit_kernel()
@@ -31,7 +31,6 @@ def matmul(
     @tilelang.jit(
         out_idx=-1,  # create the output tensor during runtime
-        execution_backend="dlpack",
     )
     @T.prim_func
     def main(
@@ -206,7 +205,7 @@ def run_gemm_jit_kernel(
         num_threads,
     )
-    matmul_kernel = tilelang.compile(program, out_idx=-1, execution_backend="dlpack")
+    matmul_kernel = tilelang.compile(program, out_idx=-1)
     A = torch.randn(M, K, dtype=torch.__getattribute__(in_dtype)).cuda()
     B = torch.randn(K, N, dtype=torch.__getattribute__(in_dtype)).cuda()
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.backends
from tilelang import tvm as tvm
import tilelang.testing
from tvm import DataType
import tilelang as TL
import tilelang.language as T
from tilelang.intrinsics import get_swizzle_layout
from tilelang.intrinsics.mma_macro_generator import (
@@ -182,8 +178,10 @@ def tl_matmul(
 def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
     matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype)
-    mod, params = TL.lower(matmul)
-    src_code = mod.imported_modules[0].get_source()
+    kernel = tilelang.compile(matmul, out_idx=[2])
+    profiler = kernel.get_profiler()
+    src_code = kernel.get_kernel_source()
     # src_code is the generated cuda source
     assert src_code is not None
@@ -203,11 +201,9 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
-    C = torch.zeros(M, N, device="cuda", dtype=accum_dtype)
-    mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)
-    mod(A, B, C)
+    C = kernel(A, B)
-    latency = mod.do_bench(mod.func, warmup=25)
+    latency = profiler.do_bench()
     # Ensure that the latency is not None
     assert latency is not None
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from tilelang import tvm as tvm
import tilelang.testing
import tilelang as tl
import tilelang.language as T
@@ -69,8 +65,8 @@ def run_conv(N,
     program = convolution(N, C, H, W, F, K, S, D, P, in_dtype, out_dtype, dtypeAccum, block_M,
                           block_N, block_K, num_stages, threads)
-    mod, params = tl.lower(program)
-    mod = tl.Profiler(mod, params, [2], tl.TensorSupplyType.Integer)
+    kernel = tilelang.compile(program, out_idx=[2])
+    profiler = kernel.get_profiler()

     def ref_program(A, B):
         import torch
@@ -81,7 +77,7 @@ def run_conv(N,
         C = C.permute(0, 2, 3, 1)  # N, C, H, W -> N, H, W, C
         return C.to(torch.__getattribute__(out_dtype))

-    mod.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
+    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)

 def test_conv_f16f16f16_k3s1d1p1():
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.backends
import tilelang.testing
from tilelang import tvm as tvm
from tvm import DataType, tir
import tilelang as TL
import tilelang.language as T
from tilelang import JITKernel, Profiler

tilelang.testing.set_random_seed(0)
@@ -102,11 +98,10 @@ def test_fp4_fp16_convert_close():
         "float16",
     )
-    mod, params = tilelang.lower(program)
-    mod = Profiler(mod, params, [1], tilelang.TensorSupplyType.Integer)
+    kernel = tilelang.compile(program, out_idx=[1])
     B = torch.randint(0, 16, (N, K // 2), dtype=torch.uint8, device="cuda").to(torch.uint8)
-    tl_out = mod.func(B)
+    tl_out = kernel(B)
     ref_out = torch_convert(B)
     assert torch.allclose(tl_out, ref_out, rtol=0.01, atol=0.01), (tl_out, ref_out)
     print("Pass")
@@ -202,7 +197,7 @@ def assert_simple_impl_float16xfp4_gemm(M,
     func = matmul_fp16xfp4(M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K,
                            num_stages, threads)
-    torch_func = JITKernel(func, [2])
+    torch_func = tilelang.compile(func, out_idx=[2])
     profiler = torch_func.get_profiler()
     profiler.assert_allclose(ref_program)
@@ -318,10 +313,10 @@ def run_gemm(
         num_threads,
     )
-    mod, params = TL.lower(program)
-    mod = TL.Profiler(mod, params, [2], TL.TensorSupplyType.Integer)
+    kernel = tilelang.compile(program, out_idx=[2])
+    profiler = kernel.get_profiler()

-    out = mod.run_once()
+    out = profiler.run_once()
     assert out is not None

     def ref_program(A, qB):
@@ -337,7 +332,7 @@ def run_gemm(
         C = C.to(torch.__getattribute__(out_dtype))
         return C

-    mod.assert_allclose(ref_program)
+    profiler.assert_allclose(ref_program)

 @tvm.testing.requires_package("bitblas")
@@ -566,8 +561,10 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct
     matmul = tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
         M, N, K, in_dtype, out_dtype, accum_dtype, transform_b)
-    mod, params = TL.lower(matmul)
-    src_code = mod.imported_modules[0].get_source()
+    kernel = tilelang.compile(matmul, out_idx=[2])
+    profiler = kernel.get_profiler()
+    src_code = kernel.get_kernel_source()
     # src_code is the generated cuda source
     assert src_code is not None
@@ -605,11 +602,9 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct
     QLB = ladder_permutate(qB.cpu()).cuda()
     QLB = lop3_permutate(QLB.cpu()).cuda()
-    mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)
-    mod(A, QLB, C)
+    kernel(A, QLB, C)
-    latency = mod.do_bench(mod.func, warmup=25)
+    latency = profiler.do_bench()
     # Ensure that the latency is not None
     assert latency is not None
......
from tilelang import tvm as tvm
import tilelang.testing
import tilelang as tl
import torch
@@ -53,15 +52,15 @@ def run_elementwise_add(
         num_threads,
     )
-    mod, params = tl.lower(program)
-    mod = tl.Profiler(mod, params, [2], tl.TensorSupplyType.Integer)
+    kernel = tilelang.compile(program, out_idx=[2])
+    profiler = kernel.get_profiler()

     def ref_program(A, B):
         C = torch.add(A, B)
         C = C.to(torch.__getattribute__(out_dtype))
         return C

-    mod.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
+    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)

 def test_elementwise_add_f32():
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from tilelang import tvm as tvm
import tilelang.testing
import tilelang as tl
import tilelang.language as T

tilelang.testing.set_random_seed(0)
@@ -130,8 +126,8 @@ def run_chunk_scan(batch,
     program = chunk_scan_fwd(batch, seqlen, chunk_size, ngroups, nheads, headdim, dstate, block_M,
                              block_N, block_K, block_Dstate, num_stages, threads)
-    mod, params = tl.lower(program)
-    mod = tl.Profiler(mod, params, [7], tl.TensorSupplyType.Integer)
+    kernel = tilelang.compile(program, out_idx=[7])
+    profiler = kernel.get_profiler()

     def ref_program(cb, x, dt, dA_cumsum, C, prev_states, D):
         import torch
@@ -182,7 +178,7 @@ def run_chunk_scan(batch,
         out = out + x * D
         return out

-    mod.assert_allclose(ref_program, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
+    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)

 def chunk_state_fwd(batch,
@@ -275,8 +271,8 @@ def run_chunk_state(batch,
     program = chunk_state_fwd(batch, seqlen, chunk_size, ngroups, nheads, headdim, dstate, block_M,
                               block_N, block_K, num_stages, threads)
-    mod, params = tl.lower(program)
-    mod = tl.Profiler(mod, params, [4], tl.TensorSupplyType.Integer)
+    kernel = tilelang.compile(program, out_idx=[4])
+    profiler = kernel.get_profiler()

     def ref_program(B, x, dt, dA_cumsum):
         """
@@ -313,7 +309,7 @@ def run_chunk_state(batch,
         return torch.einsum("bclhn,bhcl,bhcl,bclhp->bchpn", B.to(x.dtype), decay_states.to(x.dtype),
                             dt.to(x.dtype), x)

-    mod.assert_allclose(ref_program, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
+    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)

 def test_chunk_scan():
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.backends
from tilelang import tvm as tvm
import tilelang.testing
from tvm import DataType
import tilelang as TL
import tilelang.language as T
from tilelang.intrinsics import get_swizzle_layout
from tilelang.intrinsics.mma_macro_generator import (
@@ -181,8 +177,10 @@ def tl_matmul(
 def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
     matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype)
-    mod, params = TL.lower(matmul)
-    src_code = mod.imported_modules[0].get_source()
+    kernel = tilelang.compile(matmul, out_idx=[2])
+    profiler = kernel.get_profiler()
+    src_code = kernel.get_kernel_source()
     print(src_code)
     # src_code is the generated cuda source
     assert src_code is not None
@@ -201,13 +199,9 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
     A = torch.randn(M, K).to(in_dtype).cuda() - 0.5
     B = torch.randn(N, K).to(in_dtype).cuda() - 0.5
-    C = torch.zeros(M, N, device="cuda", dtype=accum_dtype)
-    mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)
-    mod(A, B, C)
+    C = kernel(A, B)
-    latency = mod.do_bench(mod.func, warmup=25)
+    latency = profiler.do_bench()
     # Ensure that the latency is not None
     assert latency is not None
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from tilelang import tvm as tvm
import tilelang.testing
import tilelang as tl

def matmul(
@@ -85,8 +81,8 @@ def run_gemm(
         num_threads,
     )
-    mod, params = tl.lower(program)
-    mod = tl.Profiler(mod, params, [2], tl.TensorSupplyType.Integer)
+    kernel = tilelang.compile(program, out_idx=[2])
+    profiler = kernel.get_profiler()

     def ref_program(A, B):
         import torch
@@ -99,7 +95,7 @@ def run_gemm(
         C = C.to(torch.__getattribute__(out_dtype))
         return C

-    mod.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
+    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)

 def test_gemm_f16f16f16_nn():
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.backends
from tilelang import tvm as tvm
import tilelang.testing
from tvm import DataType
import tilelang as TL
import tilelang.language as T
from tilelang.intrinsics import get_swizzle_layout
from tilelang.intrinsics.mma_macro_generator import (
@@ -182,8 +178,10 @@ def tl_matmul(
 def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
     matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype)
-    mod, params = TL.lower(matmul)
-    src_code = mod.imported_modules[0].get_source()
+    kernel = tilelang.compile(matmul, out_idx=[2])
+    profiler = kernel.get_profiler()
+    src_code = kernel.get_kernel_source()
     # src_code is the generated cuda source
     assert src_code is not None
@@ -201,13 +199,9 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
     A = torch.randn(M, K).to(in_dtype).cuda() - 0.5
     B = torch.randn(N, K).to(in_dtype).cuda() - 0.5
-    C = torch.zeros(M, N, device="cuda", dtype=accum_dtype)
-    mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)
-    mod(A, B, C)
+    C = kernel(A, B)
-    latency = mod.do_bench(mod.func, warmup=25)
+    latency = profiler.do_bench()
     # Ensure that the latency is not None
     assert latency is not None
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.backends
import tilelang.testing
from tilelang import tvm as tvm
from tvm import DataType
import tilelang as TL
import tilelang.language as T
from tilelang.intrinsics import get_swizzle_layout
from tilelang.transform import simplify_prim_func
@@ -142,8 +138,10 @@ def tl_matmul_simt(
 def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
     matmul = tl_matmul_simt(M, N, K, in_dtype, out_dtype, accum_dtype)
-    mod, params = TL.lower(matmul)
-    src_code = mod.imported_modules[0].get_source()
+    kernel = tilelang.compile(matmul, out_idx=[2])
+    profiler = kernel.get_profiler()
+    src_code = kernel.get_kernel_source()
     print(src_code)
     # src_code is the generated cuda source
     assert src_code is not None
@@ -155,13 +153,9 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
     A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
     B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype))
-    C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype))
-    mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)
-    mod(A, B, C)
+    C = kernel(A, B)
-    latency = mod.do_bench(mod.func, warmup=25)
+    latency = profiler.do_bench()
     # Ensure that the latency is not None
     assert latency is not None
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.backends
import tilelang
from tilelang import tvm as tvm
import tilelang.testing
import tilelang as TL
import tilelang.language as T
from tilelang.intrinsics import (
    make_mma_swizzle_layout as make_swizzle_layout,)
@@ -168,21 +164,21 @@ def tl_matmul(
 def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
     matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype)
-    mod, params = TL.lower(matmul)
-    src_code = mod.imported_modules[0].get_source()
+    kernel = tilelang.compile(matmul, out_idx=[2])
+    profiler = kernel.get_profiler()
+    src_code = kernel.get_kernel_source()
     # src_code is the generated cuda source
     assert src_code is not None

     A = torch.randint(0, 4, (M, K), device="cuda", dtype=getattr(torch, in_dtype))
     B = torch.randint(0, 4, (N, K), device="cuda", dtype=getattr(torch, in_dtype))
     C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype))

     compressed_A = (A[:, ::2] & 0x0F) + ((A[:, 1::2] & 0x0F) << 4)
     compressed_B = (B[:, ::2] & 0x0F) + ((B[:, 1::2] & 0x0F) << 4)
-    mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)
-    mod(compressed_A, compressed_B, C)
+    C = kernel(compressed_A, compressed_B)
     print(C)
-    latency = mod.do_bench(mod.func, warmup=25)
+    latency = profiler.do_bench()
     print(latency)
     # Ensure that the latency is not None
     assert latency is not None
@@ -358,15 +354,16 @@ def tl_matmul_weight_only_transform(
 def assert_tl_matmul_weight_only_transform_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
     matmul = tl_matmul_weight_only_transform(M, N, K, in_dtype, out_dtype, accum_dtype)
-    mod, params = TL.lower(matmul)
-    src_code = mod.imported_modules[0].get_source()
+    kernel = tilelang.compile(matmul, out_idx=[2])
+    profiler = kernel.get_profiler()
+    src_code = kernel.get_kernel_source()
     # src_code is the generated cuda source
     assert src_code is not None

     transform_b = 3

     A = torch.randint(0, 4, (M, K), device="cuda", dtype=getattr(torch, in_dtype))
     B = torch.randint(0, 4, (N, K), device="cuda", dtype=getattr(torch, in_dtype))
     C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype))

     compressed_A = (A[:, ::2] & 0x0F) + ((A[:, 1::2] & 0x0F) << 4)
     compressed_B = (B[:, ::2] & 0x0F) + ((B[:, 1::2] & 0x0F) << 4)
@@ -380,13 +377,10 @@ def assert_tl_matmul_weight_only_transform_correctness(M, N, K, in_dtype, out_dt
     )
     ladder_permutate = tilelang.ops.LadderPermutate(ladder_permutate_config)
-    mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)
     LB = ladder_permutate(compressed_B.cpu()).cuda()
-    mod(compressed_A, LB, C)
+    C = kernel(compressed_A, LB)
-    latency = mod.do_bench(mod.func, warmup=25)
+    latency = profiler.do_bench()
     print(f"Latency: {latency}")
     # Ensure that the latency is not None
     assert latency is not None
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from tilelang import tvm as tvm
import tilelang.testing
import tilelang as tl
import tilelang.language as T
@@ -132,8 +128,8 @@ def run_mha(batch, heads, seq_len, dim, is_causal, block_M, block_N, num_stages=
     program = flashattn(batch, heads, seq_len, dim, is_causal, block_M, block_N, num_stages,
                         threads)
-    mod, params = tl.lower(program)
-    mod = tl.Profiler(mod, params, [3], tl.TensorSupplyType.Integer)
+    kernel = tilelang.compile(program, out_idx=[3])
+    profiler = kernel.get_profiler()

     def ref_program(Q, K, V):
         import torch
@@ -150,7 +146,7 @@ def run_mha(batch, heads, seq_len, dim, is_causal, block_M, block_N, num_stages=
         output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
         return output

-    mod.assert_allclose(ref_program, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
+    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)

 def test_mha_causal_dim64():
......
import tilelang.testing
import tilelang as tl

def clamp(
@@ -35,8 +34,8 @@ def run_clamp(
 ):
     program = clamp(N, block_N, dtype, min, max)
-    mod, params = tl.lower(program)
-    profiler = tl.Profiler(mod, params, [1], tl.TensorSupplyType.Integer)
+    kernel = tilelang.compile(program, out_idx=[1])
+    profiler = kernel.get_profiler()

     def ref_program(A):
         import torch
@@ -86,8 +85,8 @@ def run_clamp_v2(
         block_N,
         dtype,
     )
-    mod, params = tl.lower(program)
-    profiler = tl.Profiler(mod, params, [1], tl.TensorSupplyType.Integer)
+    kernel = tilelang.compile(program, out_idx=[1])
+    profiler = kernel.get_profiler()

     def ref_program(A):
         import torch
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from tilelang import tvm as tvm
import tilelang.testing
import tilelang as tl
from tilelang import primitives as P
@@ -85,9 +81,9 @@ def run_matmul_ssr(
         num_stages,
         num_threads,
     )
-    mod, params = tl.lower(program)
-    mod = tl.Profiler(mod, params, [2], tl.TensorSupplyType.Integer)
-    print(mod.get_kernel_source())
+    kernel = tilelang.compile(program, out_idx=[2])
+    profiler = kernel.get_profiler()
+    print(kernel.get_kernel_source())

     def ref_program(A, B):
         import torch
@@ -100,7 +96,7 @@ def run_matmul_ssr(
         C = C.to(torch.__getattribute__(out_dtype))
         return C

-    mod.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
+    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)

 def test_gemm_f16f16f16_nt_ssr():
@@ -206,9 +202,9 @@ def run_matmul_rsr(
         num_stages,
         num_threads,
     )
-    mod, params = tl.lower(program)
-    mod = tl.Profiler(mod, params, [2], tl.TensorSupplyType.Integer)
-    print(mod.get_kernel_source())
+    kernel = tilelang.compile(program, out_idx=[2])
+    profiler = kernel.get_profiler()
+    print(kernel.get_kernel_source())

     def ref_program(A, B):
         import torch
@@ -221,7 +217,7 @@ def run_matmul_rsr(
         C = C.to(torch.__getattribute__(out_dtype))
         return C

-    mod.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
+    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)

 # TODO(lei): Fix the test case in future release
@@ -329,8 +325,8 @@ def run_matmul_rrr(
         num_stages,
         num_threads,
     )
-    mod, params = tl.lower(program)
-    mod = tl.Profiler(mod, params, [2], tl.TensorSupplyType.Integer)
+    kernel = tilelang.compile(program, out_idx=[2])
+    profiler = kernel.get_profiler()

     def ref_program(A, B):
         import torch
@@ -343,7 +339,7 @@ def run_matmul_rrr(
         C = C.to(torch.__getattribute__(out_dtype))
         return C

-    mod.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
+    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)

 # def test_gemm_f16f16f16_nt_rrr():
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from tilelang import tvm as tvm
import tilelang.testing
import tilelang as tl

def matmul(
@@ -85,8 +81,8 @@ def run_gemm(
         num_threads,
     )
-    mod, params = tl.lower(program)
-    profiler = tl.Profiler(mod, params, [2], tl.TensorSupplyType.Integer)
+    kernel = tilelang.compile(program, out_idx=[2])
+    profiler = kernel.get_profiler()

     def ref_program(A, B):
         import torch
@@ -226,8 +222,8 @@ def run_gemm_rs(
         num_threads,
     )
-    mod, params = tl.lower(program)
-    profiler = tl.Profiler(mod, params, [2], tl.TensorSupplyType.Integer)
+    kernel = tilelang.compile(program, out_idx=[2])
+    profiler = kernel.get_profiler()

     def ref_program(A, B):
         import torch
......
@@ -76,15 +76,12 @@ def test_matmul():
     func = matmul(1024, 1024, 1024, 128, 128, 32)
     mod = tvm.IRModule({func.attrs["global_symbol"]: func})
     mod = tl.transform.Simplify()(mod)
-    rt_mod, params = tl.lower(mod.functions_items()[0][1], runtime_only=False)
-    # TODO Profiler only support TensorType, not dynamic variable
-    profiler = tl.Profiler(rt_mod, params, result_idx=[2])
+    kernel = tl.compile(mod["main"], out_idx=[2])

     import torch
     a = torch.randn(1024, 1024, dtype=torch.float16).cuda().half()
     b = torch.randn(1024, 1024, dtype=torch.float16).cuda().half()
-    c = profiler(a, b)
+    c = kernel(a, b)

     ref_c = a @ b
     ref_c = ref_c.float()
......