Commit 2d0c4169 authored by Lei Wang, committed by LeiWang1999

[Enhancement] Support float variable as arguments (#250)

* [Enhancement] Add matrix multiplication functions for integer and float variables in Cython JIT

- Introduced `matmul_int_variable` and `matmul_float_variable` functions to support matrix multiplication with dynamic shapes and additional scalar parameters (see the usage sketch below).
- Implemented corresponding `run_matmul_int_variable` and `run_matmul_float_variable` functions for testing.
- Updated test cases to validate the new matrix multiplication implementations.
- Enhanced error handling in library initialization and compilation processes across various modules.
- Improved dynamic memory handling in CUDA kernel initialization to provide better error reporting.
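
As a rough usage sketch of the new scalar-argument path (a hedged example: `matmul_float_variable` is the test helper added further down in this diff; the shapes and the 2.5 offset are illustrative):

```python
import torch
import tilelang

# Sketch: compile the float-offset matmul added by this PR and pass a plain
# Python float as the trailing kernel argument; the Cython wrapper converts
# it to a ctypes.c_float before launching the kernel.
program = matmul_float_variable(1024, 1024, 1024, 128, 128, 32, False, False,
                                "float16", "float16", "float32", 0, 128)
kernel = tilelang.compile(program, execution_backend="cython", out_idx=2)

a = torch.randn(1024, 1024, dtype=torch.float16).cuda()
b = torch.randn(1024, 1024, dtype=torch.float16).cuda()
c = kernel(a, b, 2.5)  # 2.5 is forwarded to the `offset: T.float32` parameter
```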

* lint fix

* optimize
parent 4fcf6abe
@@ -476,5 +476,149 @@ def test_cython_dynamic_shape_with_out_idx():
        T.symbolic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)


def matmul_int_variable(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    accum_dtype,
    num_stages,
    threads,
):
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

    @T.prim_func
    def main(
            A: T.Buffer(A_shape, in_dtype),
            B: T.Buffer(B_shape, in_dtype),
            C: T.Buffer((M, N), out_dtype),
            offset: T.int32,
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                if trans_A:
                    T.copy(A[k * block_K, by * block_M], A_shared)
                else:
                    T.copy(A[by * block_M, k * block_K], A_shared)
                if trans_B:
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                else:
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
            for i, j in T.Parallel(block_M, block_N):
                C_local[i, j] = C_local[i, j] + offset
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def run_matmul_int_variable(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype,
                            out_dtype, dtypeAccum, num_stages, threads):
    program = matmul_int_variable(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype,
                                  out_dtype, dtypeAccum, num_stages, threads)
    matmul_kernel = tilelang.compile(program, execution_backend="cython", out_idx=2)

    in_dtype = map_torch_type(in_dtype)
    out_dtype = map_torch_type(out_dtype)

    tensor_a = torch.randn(M, K, dtype=in_dtype).cuda()
    tensor_b = torch.randn(K, N, dtype=in_dtype).cuda()

    tensor_c = matmul_kernel(tensor_a, tensor_b, 1)
    tensor_ref_c = torch.matmul(tensor_a, tensor_b).to(out_dtype) + 1
    tilelang.testing.torch_assert_close(tensor_c, tensor_ref_c, rtol=1e-2, atol=1e-2)


def test_matmul_int_variable():
    run_matmul_int_variable(1024, 1024, 1024, 128, 128, 32, False, False, "float16", "float16",
                            "float32", 0, 128)


def matmul_float_variable(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    accum_dtype,
    num_stages,
    threads,
):
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

    @T.prim_func
    def main(
            A: T.Buffer(A_shape, in_dtype),
            B: T.Buffer(B_shape, in_dtype),
            C: T.Buffer((M, N), out_dtype),
            offset: T.float32,
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                if trans_A:
                    T.copy(A[k * block_K, by * block_M], A_shared)
                else:
                    T.copy(A[by * block_M, k * block_K], A_shared)
                if trans_B:
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                else:
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
            for i, j in T.Parallel(block_M, block_N):
                C_local[i, j] = C_local[i, j] + offset
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def run_matmul_float_variable(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype,
                              out_dtype, dtypeAccum, num_stages, threads):
    program = matmul_float_variable(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype,
                                    out_dtype, dtypeAccum, num_stages, threads)
    matmul_kernel = tilelang.compile(program, execution_backend="cython", out_idx=2)

    in_dtype = map_torch_type(in_dtype)
    out_dtype = map_torch_type(out_dtype)

    tensor_a = torch.randn(M, K, dtype=in_dtype).cuda()
    tensor_b = torch.randn(K, N, dtype=in_dtype).cuda()

    tensor_c = matmul_kernel(tensor_a, tensor_b, 1.0)
    tensor_ref_c = torch.matmul(tensor_a, tensor_b).to(out_dtype) + 1.0
    tilelang.testing.torch_assert_close(tensor_c, tensor_ref_c, rtol=1e-2, atol=1e-2)


def test_matmul_float_variable():
    run_matmul_float_variable(1024, 1024, 1024, 128, 128, 32, False, False, "float16", "float16",
                              "float32", 0, 128)


if __name__ == "__main__":
    tilelang.testing.main()
@@ -238,5 +238,7 @@ def lower(
    if enable_host_codegen:
        host_mod = host_codegen(host_mod, target_host)
        host_mod.import_module(codegen_mod)

        return CompiledArtifact(
            host_mod, device_mod, params, codegen_mod.get_source(), rt_mod=host_mod)

    return CompiledArtifact(host_mod, device_mod, params, codegen_mod.get_source())
@@ -200,11 +200,11 @@ class CythonKernelAdapter(BaseKernelAdapter):
        self.lib_generator.compile_lib()
        self.lib = self.lib_generator.load_lib()

        try:
            self.lib.init()
        except Exception as e:
            raise Exception(
                f"Failed to initialize the compiled library for {self.target}: {e}") from e
        self.lib.get_last_error.restype = ctypes.c_char_p
        result = self.lib.init()
        if result != 0:
            error_msg = self.lib.get_last_error().decode('utf-8')
            raise RuntimeError(f"Initialization failed: {error_msg}")

        self.cython_wrapper = CythonKernelWrapper(self.result_idx, self.params, self.lib)
        self.cython_wrapper.set_dynamic_symbolic_map(self.dynamic_symbolic_map)
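
As an aside on the ctypes pattern above, a minimal standalone sketch (the `check_init` helper and its `lib` argument are illustrative, not part of this commit): ctypes assumes every foreign function returns a C `int`, so `get_last_error` must be declared as `c_char_p` before its result can be decoded.

```python
import ctypes

def check_init(lib: ctypes.CDLL) -> None:
    # Without setting restype, ctypes would interpret the returned
    # const char* as a plain int, so the message could not be read.
    lib.get_last_error.restype = ctypes.c_char_p
    if lib.init() != 0:
        raise RuntimeError(f"Initialization failed: {lib.get_last_error().decode('utf-8')}")
```
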
@@ -246,11 +246,11 @@ class CythonKernelAdapter(BaseKernelAdapter):
        adapter.lib_generator = LibraryGenerator(adapter.target)
        adapter.lib = adapter.lib_generator.load_lib(lib_path=kernel_lib_path)

        try:
            adapter.lib.init()
        except Exception as e:
            raise Exception(
                f"Failed to initialize the compiled library for {adapter.target}: {e}") from e
        adapter.lib.get_last_error.restype = ctypes.c_char_p
        result = adapter.lib.init()
        if result != 0:
            error_msg = adapter.lib.get_last_error().decode('utf-8')
            raise RuntimeError(f"Initialization failed: {error_msg}")

        adapter.cython_wrapper = CythonKernelWrapper(adapter.result_idx, adapter.params,
                                                     adapter.lib)
@@ -107,6 +107,10 @@ cdef class CythonKernelWrapper:
            elif isinstance(tensor_list[i], int):
                # Dynamic symbolics which are passed as integer arguments
                call_args.append(tensor_list[i])
            elif isinstance(tensor_list[i], float):
                call_args.append(ctypes.c_float(tensor_list[i]))
            elif isinstance(tensor_list[i], bool):
                call_args.append(ctypes.c_bool(tensor_list[i]))
            else:
                raise ValueError(f"Unsupported tensor type: {type(tensor_list[i])}")
@@ -90,12 +90,12 @@ class LibraryGenerator(object):
        src.flush()

        try:
            ret = subprocess.run(command, timeout=timeout)
        except subprocess.TimeoutExpired:
            logger.warning(f"Compilation Timeout! {command}")
            return None
        except Exception as e:
            raise RuntimeError(f"Compile kernel failed because of {e}") from e
        if ret.returncode != 0:
            logger.warning(f"Compilation Failed! {command}")
            return None
            raise RuntimeError(f"Compilation Failed! {command}")

        self.srcpath = src.name
        self.libpath = libpath
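
With this change a failed compiler invocation surfaces as an exception instead of a silent `None`; callers can handle it roughly as follows (a sketch; `lib_generator` is assumed to be a `LibraryGenerator` instance such as the one constructed in the adapter above):

```python
# Sketch: compilation errors from compile_lib() now propagate as RuntimeError.
try:
    lib_generator.compile_lib()
    lib = lib_generator.load_lib()
except RuntimeError as err:
    print(f"Kernel compilation failed: {err}")
    raise
```
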
@@ -9,12 +9,25 @@ import logging
import textwrap

PREDEF_ARRTIBUTE_SET_DYNAMIC_MEMORY = """
    cudaFuncSetAttribute({}, cudaFuncAttributeMaxDynamicSharedMemorySize, {});
    cudaError_t result = cudaFuncSetAttribute({0}, cudaFuncAttributeMaxDynamicSharedMemorySize, {1});
    if (result != CUDA_SUCCESS) {{
        snprintf(error_buf, ERROR_BUF_SIZE, "Failed to set the allowed dynamic shared memory size to %d with error: %s", {1}, cudaGetErrorString(result));
        return -1;
    }}
"""

PREDEF_INIT_FUNC = """
extern "C" void init() {{
{}

#define ERROR_BUF_SIZE 1024
static char error_buf[ERROR_BUF_SIZE];

extern "C" const char* get_last_error() {{
    return error_buf;
}}

extern "C" int init() {{
    error_buf[0] = '\\0';
    {0}
    return 0;
}}
"""
@@ -187,8 +187,15 @@ class JITKernel(object):
        pass_configs = self.pass_configs

        # Compile the function with TVM, optimizing with shared memory lowering.
        enable_host_codegen = execution_backend == "dlpack"
        enable_device_compile = execution_backend == "dlpack"
        with tvm.transform.PassContext(opt_level=3, config=pass_configs):
            artifact = tilelang.lower(tilelang_func, target=target, target_host=target_host)
            artifact = tilelang.lower(
                tilelang_func,
                target=target,
                target_host=target_host,
                enable_host_codegen=enable_host_codegen,
                enable_device_compile=enable_device_compile)

        self.artifact = artifact