"...resnet50_tensorflow.git" did not exist on "5fdf878dba9ed5a4ef33c9af25d3654cc7874468"
Commit b1ba0cc8 authored by Lei Wang, committed by LeiWang1999

[Enhancement] Introduce pass_configs parameter for kernel caching (#452)

* [Enhancement] Introduce pass_configs parameter for kernel compilation

* Added a new `pass_configs` parameter to the `tilelang.compile` function, allowing kernel compilation to be configured more flexibly (see the usage sketch below).
* Updated related classes and methods to accommodate the new parameter, ensuring compatibility across the codebase.
* Enhanced the `torch_assert_close` function to include customizable tensor names for better debugging output.
* Refactored input handling in the example scripts to streamline obtaining inputs for kernel execution.

* Lint fixes
parent 0a8c8b99
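
Before the diff, a minimal sketch of the new parameter in use, modeled on the flash-attention example changed below. The `PassConfigKey.TL_DISABLE_TMA_LOWER` key and the `out_idx` value are taken verbatim from the diff; the `program` placeholder stands in for any TileLang PrimFunc:

```python
import tilelang

# `program` is assumed to be a TileLang PrimFunc, e.g. the flashattn(...)
# program constructed in the first example below.
kernel = tilelang.compile(
    program,
    out_idx=[5],
    # pass_configs now flows into compilation and into the cache key, so
    # kernels compiled with different configs are cached separately.
    pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True},
)
print(kernel.get_kernel_source())
```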
@@ -191,7 +191,6 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
             T.copy(po_shared, po_local)
             for i in T.Parallel(block_M):
                 lse_local_split[i] = lse_local[k, i]
-            # T.copy(lse_local[k, :], lse_local_split)
             for i in T.Parallel(block_M):
                 scale_local[i] = T.exp2(lse_local_split[i] - lse_logsum_local[i])
             for i, j in T.Parallel(block_M, dim):
@@ -305,7 +304,8 @@ if __name__ == "__main__":
     BLOCK_N = 64  # if D_HEAD <= 128 else 32
     program = flashattn(BATCH, H, Q_CTX, KV_CTX, D_HEAD, causal, BLOCK_M, BLOCK_N)
     ref_program = partial(ref_program, causal=causal)
-    kernel = tilelang.compile(program, out_idx=[5], target="cuda", execution_backend="dlpack")
+    kernel = tilelang.compile(
+        program, out_idx=[5], pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True})
     print(kernel.get_kernel_source())
     profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
     profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
......
@@ -183,32 +183,10 @@ if __name__ == "__main__":
     kernel = tilelang.compile(program, out_idx=[4])
     profiler = kernel.get_profiler(tilelang.TensorSupplyType.Normal)
-    ins = []
-    for i in range(len(kernel.params)):
-        if i not in kernel.result_idx:
-            shape = [int(x) for x in kernel.params[i].shape]
-            ins.append(torch.empty(shape, device="cuda", dtype=torch.float16).normal_(-0.1, 0.1))
+    ins = profiler._get_inputs()
-    ref_outs = ref_program(*ins)
-    torch.cuda.synchronize()
-    lib_outs = kernel(*ins)
-    torch.cuda.synchronize()
-    if isinstance(lib_outs, torch.Tensor):
-        lib_outs = [lib_outs]
-    if isinstance(ref_outs, torch.Tensor):
-        ref_outs = [ref_outs]
-    assert len(lib_outs) == len(ref_outs)
-    from tilelang.utils.tensor import torch_assert_close
-    for lhs, rhs in zip(lib_outs, ref_outs):
-        torch_assert_close(
-            lhs,
-            rhs,
-            rtol=0.01,
-            atol=0.01,
-            max_mismatched_ratio=0.01,
-        )
+    profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
     latency = profiler.do_bench(n_warmup=10, n_repeat=10, profiler="torch")
......
@@ -66,6 +66,7 @@ class KernelCache:
         args=None,
         target: Union[str, Target] = "auto",
         target_host: Union[str, Target] = None,
+        pass_configs: dict = None,
     ) -> str:
         """
         Generates a unique hash key for caching compiled kernels.
@@ -91,6 +92,7 @@
             "target": str(target),
             "target_host": str(target_host) if target_host else None,
             "execution_backend": execution_backend,
+            "pass_configs": pass_configs,
         }
         key_string = json.dumps(key_data, sort_keys=True)  # Sort keys to ensure consistency
         return sha256(key_string.encode()).hexdigest()  # Use SHA256 to generate hash key
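
Why include `pass_configs` in the hash? Two compilations that differ only in their pass configuration must not resolve to the same cached kernel. A standalone sketch of the keying scheme above (field names copied from the diff; the real `_generate_key` also hashes additional fields such as the function and args, and the config key string used here is only a placeholder):

```python
import json
from hashlib import sha256


def cache_key(target, target_host, execution_backend, pass_configs=None):
    # Mirrors the key_data construction above; sort_keys=True makes the
    # JSON stable regardless of dict insertion order.
    key_data = {
        "target": str(target),
        "target_host": str(target_host) if target_host else None,
        "execution_backend": execution_backend,
        "pass_configs": pass_configs,
    }
    return sha256(json.dumps(key_data, sort_keys=True).encode()).hexdigest()


# Differing only in pass_configs -> different cache entries.
assert cache_key("cuda", None, "cython") != cache_key(
    "cuda", None, "cython", pass_configs={"tl.disable_tma_lower": True})
```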
@@ -136,7 +138,9 @@ class KernelCache:
             execution_backend=execution_backend,
             args=args,
             target=target,
-            target_host=target_host)
+            target_host=target_host,
+            pass_configs=pass_configs,
+        )
         with self._lock:
             # First check in-memory cache
             if key in self._memory_cache:
......
@@ -117,6 +117,7 @@ class CtypesKernelAdapter(BaseKernelAdapter):
         adapter.result_idx = adapter._legalize_result_idx(result_idx)
         adapter.kernel_global_source = kernel_global_source
         adapter.wrapped_source = kernel_global_source
+        adapter.pass_configs = pass_configs
         if isinstance(func_or_mod, tir.PrimFunc):
             adapter.ir_module = tvm.IRModule({func_or_mod.attrs["global_symbol"]: func_or_mod})
......
@@ -285,6 +285,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
         adapter.result_idx = adapter._legalize_result_idx(result_idx)
         adapter.kernel_global_source = kernel_global_source
         adapter.wrapped_source = kernel_global_source
+        adapter.pass_configs = pass_configs
         if isinstance(func_or_mod, tir.PrimFunc):
             adapter.ir_module = tvm.IRModule({func_or_mod.attrs["global_symbol"]: func_or_mod})
......
@@ -144,6 +144,7 @@ class JITKernel(object):
             target=target,
             kernel_global_source=kernel_global_source,
             kernel_lib_path=kernel_lib_path,
+            pass_configs=pass_configs,
         )
         instance.torch_function = instance.adapter.func
         return instance
@@ -247,6 +248,7 @@
         func_or_mod: Union[PrimFunc, tvm.runtime.Module],
         kernel_global_source: str,
         kernel_lib_path: str,
+        pass_configs: Optional[Dict[str, Any]] = None,
     ) -> BaseKernelAdapter:
         target = self.target
         execution_backend = self.execution_backend
@@ -262,6 +264,7 @@
                 func_or_mod=func_or_mod,
                 kernel_global_source=kernel_global_source,
                 kernel_lib_path=kernel_lib_path,
+                pass_configs=pass_configs,
             )
         elif execution_backend == "cython":
             adapter = CythonKernelAdapter.from_database(
@@ -271,6 +274,7 @@
                 func_or_mod=func_or_mod,
                 kernel_global_source=kernel_global_source,
                 kernel_lib_path=kernel_lib_path,
+                pass_configs=pass_configs,
             )
         else:
             # Handle invalid backend.
......
@@ -117,6 +117,8 @@ class Profiler:
             rtol=rtol,
             atol=atol,
             max_mismatched_ratio=max_mismatched_ratio,
+            base_name="tilelang",
+            ref_name="ref",
         )

     def assert_consistent(self, repeat=10):
......
@@ -218,6 +218,8 @@ def torch_assert_close(
     check_dtype: bool = True,
     check_layout: bool = True,
     check_stride: bool = False,
+    base_name: str = "LHS",
+    ref_name: str = "RHS",
 ):
     """
     Custom function to assert that two tensors are "close enough," allowing a specified
@@ -293,7 +295,7 @@ def torch_assert_close(
             f"{mismatch_info}"
             f"\nGreatest absolute difference: {diff.max().item()}, "
             f"Greatest relative difference: {(diff / (torch.abs(tensor_b) + 1e-12)).max().item()}"
-            f"\nLHS: {tensor_a}"
-            f"\nRHS: {tensor_b}")
+            f"\n{base_name}: {tensor_a}"
+            f"\n{ref_name}: {tensor_b}")
     else:
         return True
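
A small sketch of the new naming knobs in action, matching how the `Profiler` calls this function above (assuming, as the message construction suggests, that a failed comparison raises an `AssertionError`):

```python
import torch
from tilelang.utils.tensor import torch_assert_close

a = torch.randn(64, 64)
b = a.clone()
b[0, 0] += 1.0  # one element far outside rtol/atol

try:
    torch_assert_close(
        a,
        b,
        rtol=0.01,
        atol=0.01,
        max_mismatched_ratio=1e-5,  # stricter than the 1/4096 mismatch above
        base_name="tilelang",  # labels tensor_a in the error message
        ref_name="ref",  # labels tensor_b
    )
except AssertionError as err:
    # The report now reads "tilelang: ..." / "ref: ..." instead of LHS/RHS.
    print(err)
```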