Unverified Commit a7c9a8b9 authored by Siyuan Feng's avatar Siyuan Feng Committed by GitHub
Browse files

Refactor to support upstream tvm (#595)

**Summary of part of the rebase PR:**

1. **Support T.thread_return() → CUDA return syntax**  
   Added support for translating `T.thread_return()` to CUDA's native `return` statement.

2. **Dynamic type support for function inputs**  
   Functions now accept dynamically typed parameters using `typing`:
   ```python
   dyn_type = T.int32  # or T.float — any scalar type alias may be used here
   @T.prim_func
   def main(
       a: dyn_type,
   )
   ```

3. **Device Function Codegen**  
   Added support for generating `__device__` functions in CUDA:
   ```python
   @I.ir_module
   class Module:
       @T.prim_func(private=True)
       def add(a: T.int32, b: T.int32) -> T.int32:
           return a + b

       @T.prim_func
       def main(
           A: T.Buffer((128, 128), "int32"),
           B: T.Buffer((128, 128), "int32"),
           C: T.Buffer((128, 128), "int32"),
       ):
           T.func_attr({"global_symbol": "main"})
           length: T.int32 = Module.add(64, 64)  # Host call
           for bx in...
parent 8edd6941
......@@ -56,8 +56,8 @@ def assert_matmul_correctness(M, N, K, block_M, block_N, block_K, in_dtype, out_
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version(9)
def test_assert_matmul():
assert_matmul_correctness(1024, 1024, 1024, 128, 128, 64, "e4m3_float8", "float32", "float32")
assert_matmul_correctness(1024, 1024, 1024, 128, 128, 64, "e5m2_float8", "float32", "float32")
assert_matmul_correctness(1024, 1024, 1024, 128, 128, 64, "float8_e4m3", "float32", "float32")
assert_matmul_correctness(1024, 1024, 1024, 128, 128, 64, "float8_e5m2", "float32", "float32")
if __name__ == "__main__":
......
......@@ -39,8 +39,8 @@ def tl_matmul(
):
assert in_dtype in [
"float16",
"e4m3_float8",
"e5m2_float8",
"float8_e4m3",
"float8_e5m2",
"int8",
], "Currently only float16 and int8 are supported"
assert out_dtype in [
......@@ -51,7 +51,7 @@ def tl_matmul(
micro_size_x = micro_size_y = micro_size_k = 16
is_float8 = in_dtype in ["e4m3_float8", "e5m2_float8"]
is_float8 = in_dtype in ["float8_e4m3", "float8_e5m2"]
if out_dtype == "int32" or is_float8:
micro_size_k = 32
......@@ -216,8 +216,8 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version(8, 9)
def test_assert_tl_matmul():
assert_tl_matmul_correctness(128, 128, 128, "e4m3_float8", "float32", "float32")
assert_tl_matmul_correctness(128, 128, 128, "e5m2_float8", "float32", "float32")
assert_tl_matmul_correctness(128, 128, 128, "float8_e4m3", "float32", "float32")
assert_tl_matmul_correctness(128, 128, 128, "float8_e5m2", "float32", "float32")
if __name__ == "__main__":
......
......@@ -166,8 +166,8 @@ def evaluate_gemv_simt(
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version(8, 9)
def test_gemv_simt():
evaluate_gemv_simt(1, 1024, 1024, "e4m3_float8", "float32", "float32", with_bias=False)
evaluate_gemv_simt(1, 1024, 1024, "e5m2_float8", "float32", "float32", with_bias=False)
evaluate_gemv_simt(1, 1024, 1024, "float8_e4m3", "float32", "float32", with_bias=False)
evaluate_gemv_simt(1, 1024, 1024, "float8_e5m2", "float32", "float32", with_bias=False)
if __name__ == "__main__":
......
......@@ -40,8 +40,8 @@ def tl_matmul(
assert in_dtype in [
"float16",
"bfloat16",
"e4m3_float8",
"e5m2_float8",
"float8_e4m3",
"float8_e5m2",
"int8",
], "Currently only float16 and int8 are supported"
assert out_dtype in [
......@@ -52,7 +52,7 @@ def tl_matmul(
micro_size_x = micro_size_y = micro_size_k = 16
is_float8 = in_dtype in ["e4m3_float8", "e5m2_float8"]
is_float8 = in_dtype in ["float8_e4m3", "float8_e5m2"]
if out_dtype == "int32" or is_float8:
micro_size_k = 32
......@@ -228,8 +228,8 @@ def test_assert_tl_matmul_bfloat16():
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version(8, 9)
def test_assert_tl_matmul_fp8():
assert_tl_matmul_correctness(128, 128, 128, "e4m3_float8", "float32", "float32")
assert_tl_matmul_correctness(128, 128, 128, "e5m2_float8", "float32", "float32")
assert_tl_matmul_correctness(128, 128, 128, "float8_e4m3", "float32", "float32")
assert_tl_matmul_correctness(128, 128, 128, "float8_e5m2", "float32", "float32")
if __name__ == "__main__":
......
......@@ -173,8 +173,8 @@ def test_gemv_simt():
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version(8, 9)
def test_gemv_simt_fp8():
evaluate_gemv_simt(1, 1024, 1024, "e4m3_float8", "float32", "float32", with_bias=False)
evaluate_gemv_simt(1, 1024, 1024, "e5m2_float8", "float32", "float32", with_bias=False)
evaluate_gemv_simt(1, 1024, 1024, "float8_e4m3", "float32", "float32", with_bias=False)
evaluate_gemv_simt(1, 1024, 1024, "float8_e5m2", "float32", "float32", with_bias=False)
if __name__ == "__main__":
......
......@@ -14,9 +14,10 @@ from tilelang.intrinsics.mma_macro_generator import (
from tilelang.transform import simplify_prim_func
tilelang.testing.set_random_seed(42)
tilelang.disable_cache()
@simplify_prim_func
# @simplify_prim_func
def tl_matmul(
M,
N,
......@@ -164,7 +165,13 @@ def tl_matmul(
def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype)
kernel = tilelang.compile(matmul, out_idx=[2])
kernel = tilelang.compile(
matmul,
out_idx=[2],
pass_configs={
tilelang.PassConfigKey.TL_DEBUG_MERGE_SHARED_MEMORY_ALLOCATIONS: True,
})
print(kernel.get_kernel_source())
profiler = kernel.get_profiler()
src_code = kernel.get_kernel_source()
......@@ -400,4 +407,5 @@ def test_assert_tl_matmul_weight_only_transform():
if __name__ == "__main__":
tilelang.testing.main()
# tilelang.testing.main()
assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", "int32")
......@@ -27,7 +27,9 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0):
# Copy tile of A
# This is a sugar syntax for parallelized copy
T.copy(A[by * block_M, ko * block_K], X_shared)
aliased_offset = T.int32()
T.let(aliased_offset, ko * block_K)
T.copy(A[by * block_M, aliased_offset], X_shared)
# Demonstrate parallelized copy from global to shared for B
T.copy(B[bx * block_N, ko * block_K], B_shared[:block_N, :block_K])
......
......@@ -39,7 +39,6 @@ def run_tilelang_copy(M=1024, N=1024, block_M=128, block_N=128, dtype="float16",
"tl.disable_warp_specialized": True,
"tl.disable_tma_lower": True
})
print(kernel.get_kernel_source())
a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype))
b = kernel(a)
ref_b = torch.zeros_like(a)
......
import tilelang
import tilelang.language as T
import torch
import tilelang.testing
# add decorator @tilelang.jit if you want to return a torch function
......
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment