"testing/python/jit/test_tilelang_jit_cutedsl.py" did not exist on "a7730272e4aeeed198b855b7f36ef7ac88cdd76b"
Commit d34601ab authored by Lei Wang, committed by GitHub

[Bugfix] Add dynamic shape support with out_idx in Cython JIT kernel compilation (#185)

* Optimize CMake build process with dynamic job count calculation

- Modify build_csrc function to use 90% of available CPU cores
- Ensure at least one job is used during compilation
- Improve build performance by dynamically adjusting parallel job count

* Optimize build_csrc function with multiprocessing module

- Replace os.cpu_count() with multiprocessing.cpu_count()
- Maintain existing 90% CPU utilization logic
- Improve CPU core count calculation for build process
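As a rough illustration of the job-count logic described above (the exact `build_csrc` body and CMake invocation are not shown in this commit, so the command below is only a placeholder):

```python
import multiprocessing
import subprocess


def build_csrc(build_dir="build"):
    # Use roughly 90% of the available cores, but never fewer than one job.
    num_jobs = max(1, int(multiprocessing.cpu_count() * 0.9))
    # Placeholder CMake invocation; the real command lives in the project's build script.
    subprocess.check_call(["cmake", "--build", build_dir, f"-j{num_jobs}"])
```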

* Add dynamic shape support with out_idx in Cython JIT kernel compilation

- Implement `run_cython_dynamic_shape_with_out_idx` function in test_tilelang_jit_gemm_cython.py
- Update Cython wrapper to handle dynamic symbolic shapes during tensor allocation
- Add support for resolving dynamic shape dimensions using input tensor references
- Enhance flexibility of JIT kernel compilation with symbolic shape handling
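Conceptually, the wrapper change (shown in the diff below) resolves each symbolic output dimension by looking it up in `dynamic_symbolic_map`, which maps a symbolic variable to an (input tensor index, dimension index) pair. A simplified, standalone sketch of that lookup, using plain strings instead of the `tir.Var` keys the real map uses:

```python
import torch

# Simplified stand-in for the wrapper's dynamic_symbolic_map; the real map is
# keyed by tir.Var objects collected at compile time, not by strings.
dynamic_symbolic_map = {"m": (0, 0)}  # symbol "m" -> dim 0 of input tensor 0


def resolve_output_shape(symbolic_shape, inputs):
    shape = []
    for s in symbolic_shape:
        if isinstance(s, str):
            # Symbolic dimension: read the concrete size from the referenced input.
            ref_tensor_idx, ref_shape_idx = dynamic_symbolic_map[s]
            shape.append(inputs[ref_tensor_idx].shape[ref_shape_idx])
        else:
            shape.append(int(s))
    return shape


a = torch.empty(512, 768, dtype=torch.float16)
print(resolve_output_shape(["m", 1024], [a]))  # -> [512, 1024]
```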

* Enhance error reporting for dynamic symbolic shape resolution in Cython JIT kernel

- Add detailed error message when a dynamic symbolic dimension is not found in dynamic_symbolic_map
- Improve debugging by providing context about missing symbolic dimensions
- Maintain existing dynamic shape resolution logic
parent c2192780
@@ -407,5 +407,62 @@ def test_cython_dynamic_shape():
        "float16", 128, 256, 32, 2)


def run_cython_dynamic_shape_with_out_idx(M,
                                          N,
                                          K,
                                          trans_A,
                                          trans_B,
                                          in_dtype,
                                          out_dtype,
                                          dtypeAccum,
                                          block_M,
                                          block_N,
                                          block_K,
                                          num_stages=3,
                                          num_threads=128):
    program = matmul(
        M,
        N,
        K,
        block_M,
        block_N,
        block_K,
        trans_A,
        trans_B,
        in_dtype,
        out_dtype,
        dtypeAccum,
        num_stages,
        num_threads,
    )
    matmul_kernel = tilelang.compile(program, execution_backend="cython", out_idx=-1)

    if isinstance(M, T.Var):
        M = 1024
    if isinstance(N, T.Var):
        N = 1024
    if isinstance(K, T.Var):
        K = 768

    tensor_a = torch.randn(M, K, dtype=torch.__getattribute__(in_dtype)).cuda()
    tensor_b = torch.randn(K, N, dtype=torch.__getattribute__(in_dtype)).cuda()
    if trans_A:
        tensor_a = tensor_a.T
    if trans_B:
        tensor_b = tensor_b.T

    tensor_c = matmul_kernel(tensor_a, tensor_b)
    tensor_ref_c = torch.matmul(tensor_a.to(torch.float), tensor_b.to(torch.float))
    tilelang.testing.torch_assert_close(
        tensor_c, tensor_ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)


def test_cython_dynamic_shape_with_out_idx():
    run_cython_dynamic_shape_with_out_idx(
        T.symbolic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)


if __name__ == "__main__":
    tilelang.testing.main()
@@ -7,6 +7,7 @@ cimport cython
import ctypes
from libc.stdint cimport int64_t, uintptr_t
from libc.stdlib cimport malloc, free
+from tvm import tir

cdef class CythonKernelWrapper:
    # Class attributes to store kernel configuration and library reference
@@ -62,7 +63,18 @@ cdef class CythonKernelWrapper:
            if i in self.result_idx:
                # Create empty output tensor with specified dtype and shape
                dtype = torch.__getattribute__(str(self.params[i].dtype))
-                shape = list(map(int, self.params[i].shape))
+                shape = []
+                for s in self.params[i].shape:
+                    if isinstance(s, tir.Var):
+                        # find the corresponding input tensor and shape dimension
+                        assert s in self.dynamic_symbolic_map, f"Dynamic symbolic dimension \
+                            {s} not found in dynamic_symbolic_map"
+                        ref_tensor_idx, ref_shape_idx = self.dynamic_symbolic_map[s]
+                        shape.append(tensor_list[ref_tensor_idx].shape[ref_shape_idx])
+                    elif isinstance(s, (tir.IntImm, int)):
+                        shape.append(int(s))
+                    else:
+                        raise ValueError(f"Unsupported shape type: {type(s)}")
                device = inputs[0].device if len(inputs) > 0 else torch.cuda.current_device()
                tensor = torch.empty(*shape, dtype=dtype, device=device)
            else:
......