Commit 79ea77e8 authored by Lei Wang, committed by GitHub

[Enhancement] Simplify GEMM example with direct kernel compilation (#191)

* Optimize CMake build process with dynamic job count calculation

- Modify build_csrc function to use 90% of available CPU cores
- Ensure at least one job is used during compilation
- Improve build performance by dynamically adjusting parallel job count

* Optimize build_csrc function with multiprocessing module

- Replace os.cpu_count() with multiprocessing.cpu_count()
- Maintain existing 90% CPU utilization logic
- Improve CPU core count calculation for build process
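
The job-count rule described above can be sketched as follows. This is a minimal illustration, not the actual `build_csrc` code: the helper names `jobs_for` and `default_build_jobs` are hypothetical, and truncating toward zero is an assumption about how the 90% figure is rounded.

```python
import multiprocessing


def jobs_for(total_cores: int, fraction: float = 0.9) -> int:
    """Hypothetical helper: scale the core count, but never go below one job."""
    # int() truncates, so a single-core machine would yield 0 without max().
    return max(1, int(total_cores * fraction))


def default_build_jobs() -> int:
    """Job count for the current machine, using multiprocessing.cpu_count()."""
    return jobs_for(multiprocessing.cpu_count())


if __name__ == "__main__":
    # The result would typically be passed to the build tool's parallel-jobs option.
    print(f"parallel build jobs: {default_build_jobs()}")
```

Clamping with `max(1, ...)` is what guarantees "at least one job" on small machines.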

* Add dynamic shape support with out_idx in Cython JIT kernel compilation

- Implement `run_cython_dynamic_shape_with_out_idx` function in test_tilelang_jit_gemm_cython.py
- Update Cython wrapper to handle dynamic symbolic shapes during tensor allocation
- Add support for resolving dynamic shape dimensions using input tensor references
- Enhance flexibility of JIT kernel compilation with symbolic shape handling

* Enhance error reporting for dynamic symbolic shape resolution in Cython JIT kernel

- Add detailed error message when a dynamic symbolic dimension is not found in dynamic_symbolic_map
- Improve debugging by providing context about missing symbolic dimensions
- Maintain existing dynamic shape resolution logic
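
The dynamic shape resolution and its error reporting can be illustrated with a small pure-Python sketch. Only the name `dynamic_symbolic_map` comes from the commit message. The function name, the use of strings for symbolic dimensions, and the `(input index, axis)` layout of the map are assumptions made for illustration, not the Cython wrapper's actual code.

```python
def resolve_output_shape(out_shape, dynamic_symbolic_map, input_shapes):
    """Replace symbolic dimensions in out_shape with sizes read from the inputs.

    out_shape            -- tuple of ints (static) or strings (symbolic), assumed layout
    dynamic_symbolic_map -- symbol -> (input index, axis), assumed layout
    input_shapes         -- list of concrete shapes of the input tensors
    """
    resolved = []
    for dim in out_shape:
        if isinstance(dim, int):
            # Static dimension: use it as is.
            resolved.append(dim)
        elif dim in dynamic_symbolic_map:
            # Symbolic dimension: look up which input tensor and axis carry it.
            input_idx, axis = dynamic_symbolic_map[dim]
            resolved.append(input_shapes[input_idx][axis])
        else:
            # Report the missing symbol together with the known ones.
            raise ValueError(
                f"Cannot resolve dynamic dimension {dim!r}: not found in "
                f"dynamic_symbolic_map (known symbols: {sorted(dynamic_symbolic_map)})"
            )
    return tuple(resolved)


if __name__ == "__main__":
    # Output C has shape (m, 1024); m is taken from axis 0 of input 0.
    print(resolve_output_shape(("m", 1024), {"m": (0, 0)}, [(512, 256), (256, 1024)]))
```

Including the missing symbol and the known symbols in the message gives the debugging context the commit describes.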

* Fix Copy operation handling for scalar and multi-dimensional tensors

- Add special handling for scalar tensor copy operations
- Enhance error reporting in MakeIndices method with more detailed diagnostic information
- Improve SIMT loop generation to support zero-dimensional tensors
- Add explicit check and handling for scalar tensor scenarios
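
Why zero-dimensional tensors need explicit attention can be shown with a short Python sketch. It is only an analogy for the C++ `MakeSIMTLoop` logic, which is not reproduced here. The helpers `iter_indices` and `copy_buffer` are hypothetical, and buffers are modeled as dictionaries keyed by index tuples.

```python
import itertools


def iter_indices(shape):
    """Yield every index tuple of a tensor with the given shape."""
    # With an empty shape, product() yields exactly one empty tuple: a scalar
    # still has one element even though there is no loop variable to iterate.
    return itertools.product(*(range(extent) for extent in shape))


def copy_buffer(src, shape):
    """Element-wise copy of a buffer modeled as {index tuple: value}."""
    return {idx: src[idx] for idx in iter_indices(shape)}


if __name__ == "__main__":
    print(copy_buffer({(): 3.5}, ()))                    # scalar tensor
    print(copy_buffer({(0,): 1.0, (1,): 2.0}, (2,)))     # one-dimensional tensor
```

A loop generator that assumes at least one loop variable would emit no copy at all for the scalar case, which is the kind of situation the explicit check guards against.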

* Refactor Copy operation code formatting and improve readability

- Improve code formatting in MakeIndices and MakeSIMTLoop methods
- Add line breaks to enhance readability of complex ICHECK statements
- Simplify code structure in scalar tensor handling
- Remove unnecessary whitespace and improve code alignment

* Simplify GEMM example with direct kernel compilation

- Update copyright header to Tile-AI Corporation
- Remove Profiler import and usage
- Replace tilelang.lower() with tilelang.compile()
- Simplify kernel execution workflow
- Update kernel source retrieval method
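
In the diff for this commit, `result_idx=[2]` on the `Profiler` becomes `out_idx=-1` on `tilelang.compile()`. Assuming the GEMM kernel takes three parameters (A, B, C) and that negative indices count from the end as in Python, both forms select the same output parameter. The helper below is a hypothetical sketch of that normalization, not tilelang's implementation.

```python
def normalize_out_idx(out_idx, num_params):
    """Hypothetical helper: turn an int or list of (possibly negative) indices
    into a list of non-negative parameter positions."""
    indices = [out_idx] if isinstance(out_idx, int) else list(out_idx)
    normalized = []
    for idx in indices:
        # Negative indices count from the end, as in Python sequences.
        pos = idx + num_params if idx < 0 else idx
        if not 0 <= pos < num_params:
            raise IndexError(f"out_idx {idx} is out of range for {num_params} parameters")
        normalized.append(pos)
    return normalized


if __name__ == "__main__":
    print(normalize_out_idx(-1, 3))   # last of three parameters
    print(normalize_out_idx([2], 3))  # explicit position
```

With the output parameter marked this way, the compiled kernel is called with only the inputs (`kernel(a, b)`) and returns the output tensor.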
parent 454248c7
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import tilelang
-from tilelang import Profiler
import tilelang.language as T
@@ -34,16 +30,14 @@ func = matmul(1024, 1024, 1024, 128, 128, 32)
print(func)
-rt_mod, params = tilelang.lower(func)
-profiler = Profiler(rt_mod, params, result_idx=[2])
+kernel = tilelang.compile(func, out_idx=-1)
import torch
a = torch.randn(1024, 1024).cuda().half()
b = torch.randn(1024, 1024).cuda().half()
-c = profiler(a, b)
+c = kernel(a, b)
ref_c = a @ b
@@ -53,4 +47,4 @@ print(ref_c)
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
# Get CUDA Source
-print(rt_mod.imported_modules[0].get_source())
+print(kernel.get_kernel_source())