"...CHARM/git@developer.sourcefind.cn:OpenDAS/opencompass.git" did not exist on "608ff5810dd2fea1161acbd936cfb4d6bf4cfb28"
Commit cce6aed8 authored by Lei Wang, committed by LeiWang1999

[Bugfix] Fix input tensor compatibility checks in AutoTuner (#588)



* [Refactor] Remove cache existence check in kernel saving logic

- Eliminated redundant checks for existing cache paths in `AutotuneResult` and `AutoTunerCache` classes, simplifying the kernel saving process.
- Ensured that the cache directory is always created before saving kernel source code, improving reliability in kernel storage.
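For illustration, a minimal sketch of the simplified saving flow (the function name, argument layout, and the `kernel_source.cu` file name below are hypothetical, not the exact tilelang implementation):

```python
import os
import pickle


def save_kernel_to_disk(cache_path, kernel_source, params):
    # Illustrative flow only: no early return when the cache path already exists;
    # the directory is created idempotently and the artifacts are (re)written.
    os.makedirs(cache_path, exist_ok=True)  # ensure the cache directory exists

    # Save kernel source code (file name here is hypothetical)
    with open(os.path.join(cache_path, "kernel_source.cu"), "w") as f:
        f.write(kernel_source)

    # Save the serialized kernel parameters (params.pkl, as listed in the docstring)
    with open(os.path.join(cache_path, "params.pkl"), "wb") as f:
        pickle.dump(params, f)
```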

* [Enhancement] Improve input tensor compatibility checks in AutoTuner

- Enhanced the input tensor caching logic in the AutoTuner class to ensure compatibility between cached tensors and newly generated tensors during configuration trials.
- Added detailed logging to warn users about potential mismatches in tensor properties, including shape and dtype, when caching is enabled.
- Implemented a mechanism to regenerate input tensors if compatibility issues are detected, improving the robustness of the autotuning process.
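As a rough, self-contained sketch of this compatibility logic (the helper names `shapes_compatible` and `reuse_or_regenerate`, and the assumption that each expected parameter exposes `.shape` and `.dtype`, are illustrative rather than the exact AutoTuner code):

```python
import torch
from tvm.tir import Var  # symbolic dimension type used by TVM/tilelang


def shapes_compatible(expected, cached):
    # Shapes match if ranks agree and every dimension is either equal or symbolic (Var).
    if len(expected.shape) != len(cached.shape):
        return False
    return all(e == c or isinstance(e, Var) or isinstance(c, Var)
               for e, c in zip(expected.shape, cached.shape))


def reuse_or_regenerate(params, cached_tensors, regenerate, logger):
    # Reuse cached tensors only when every tensor matches the expected parameter's
    # dtype and shape; otherwise warn and regenerate fresh inputs for this trial.
    if cached_tensors is None or len(params) != len(cached_tensors):
        return regenerate()
    for p, c in zip(params, cached_tensors):
        if not isinstance(c, torch.Tensor):
            continue  # non-tensor inputs are not checked
        if p.dtype != c.dtype or not shapes_compatible(p, c):
            logger.warning("Cached input tensors do not match the current configuration "
                           "(shape/dtype mismatch); regenerating inputs for this trial.")
            return regenerate()
    return cached_tensors
```

If the warning fires repeatedly, the message itself suggests passing `cache_input_tensors=False` through `.set_compile_args(...)` so that fresh inputs are generated for every trial.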

* [Refactor] Update L2 persistent map initialization in CUDA wrapper

- Adjusted the L2 persistent map initialization template so that the persisting-L2 cache limit and the access-policy window byte count are formatted from a single size argument, instead of carrying an extra, unused size parameter, improving clarity and reducing potential errors in memory management.
- Simplified the formatting of the initialization function to enhance readability and maintainability of the code.
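For illustration, a trimmed sketch of the corrected template (only a few lines of the real string are shown, and the formatted values are hypothetical):

```python
# Simplified excerpt of the initialization template (not the full tilelang string):
# the persisting-L2 cache limit and the access-policy window now consume the same
# size placeholder {2}, so the two values cannot diverge.
L2_INIT_TEMPLATE = (
    "\tstream_attribute.accessPolicyWindow.hitRatio = {1};\n"
    "\tcudaDeviceSetLimit(cudaLimitPersistingL2CacheSize, {2});\n"
    "\tstream_attribute.accessPolicyWindow.base_ptr = (void*)({0});\n"
    "\tstream_attribute.accessPolicyWindow.num_bytes = {2};\n")

# Mirrors the wrapper-side call
# L2_PERSISTENT_MAP_INIT_FUNC.format(buffer_name, float(hit_ratio), pythonic_expr(num_bytes)),
# here with hypothetical values.
print(L2_INIT_TEMPLATE.format("A", 0.8, "size_in_bytes"))
```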

* Update tilelang/autotuner/__init__.py
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

---------
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent c15d35e4
@@ -6,7 +6,7 @@ and performance optimization through configuration search.
 import tilelang
 from tilelang import tvm as tvm
-from tvm.tir import PrimFunc
+from tvm.tir import PrimFunc, Var
 from tvm.target import Target
 import inspect
 from functools import partial
@@ -323,18 +323,38 @@ class AutoTuner:
         ref_input_tensors_supply = get_input_tensors_supply(with_output=False)
         if cache_input_tensors:
-            if supply_prog is not None:
-                logger.warning(
-                    "Incompatible input tensor properties detected between cached tensors and "
-                    "tensors regenerated for the current configuration trial. "
-                    "This can happen if different tuning configurations require different input shapes/dtypes "
-                    "and input tensor caching is enabled.\n"
-                    "To ensure fresh, compatible inputs are generated for every trial "
-                    "you can disable caching by setting:\n"
-                    " `cache_input_tensors=False`\n"
-                    "within your `.set_compile_args(...)` call.\n")
-            self.jit_input_tensors = jit_input_tensors_supply(
-            ) if self.jit_input_tensors is None else self.jit_input_tensors
+            params = profiler._get_params(with_output=False)
+            if self.jit_input_tensors is None:
+                self.jit_input_tensors = jit_input_tensors_supply()
+            else:
+                # check if the cached tensors are compatible with the current configuration
+                assert len(params) == len(
+                    self.jit_input_tensors), "len(params) != len(self.jit_input_tensors)"
+                for p, c in zip(params, self.jit_input_tensors):
+                    if not isinstance(c, torch.Tensor):
+                        # skip non-tensor inputs checking
+                        continue
+                    # Check tensor compatibility using generator expression
+                    if len(params) == len(self.jit_input_tensors):
+
+                        def shape_equal(a, b):
+                            if len(a.shape) != len(b.shape):
+                                return False
+                            return all(a_dim == b_dim or isinstance(a_dim, Var) or isinstance(b_dim, Var)
+                                       for a_dim, b_dim in zip(a.shape, b.shape))
+
+                        if p.dtype != c.dtype or not shape_equal(p, c):
+                            logger.warning(
+                                "\nIncompatible input tensor properties detected between cached tensors and "
+                                "tensors regenerated for the current configuration trial. "
+                                "This can happen if different tuning configurations require different input shapes/dtypes "
+                                "and input tensor caching is enabled.\n"
+                                "To ensure fresh, compatible inputs are generated for every trial "
+                                "you can disable caching by setting:\n"
+                                " `cache_input_tensors=False`\n"
+                                "within your `.set_compile_args(...)` call.\n")
+                            # otherwise, regenerate the input tensors for safety
+                            self.jit_input_tensors = jit_input_tensors_supply()
+                            break
         else:
             self.jit_input_tensors = jit_input_tensors_supply()
@@ -165,10 +165,6 @@ class AutotuneResult:
         - kernel_lib.so: The compiled kernel library
         - params.pkl: The serialized kernel parameters
         """
-        if os.path.exists(cache_path):
-            logger.info(f"Cache path {cache_path} already exists, skipping saving kernel to disk")
-            return
         os.makedirs(cache_path, exist_ok=True)  # Ensure directory exists
         # Save kernel source code
@@ -240,11 +240,6 @@ class AutoTunerCache:
         - params.pkl: The serialized kernel parameters
         """
         cache_path = self._get_cache_path(key)
-        if os.path.exists(cache_path):
-            self.logger.info(
-                f"Cache path {cache_path} already exists, skipping saving kernel to disk")
-            return
         os.makedirs(cache_path, exist_ok=True)  # Ensure directory exists
         # Save kernel source code
@@ -67,9 +67,9 @@ L2_PERSISTENT_MAP_INIT_FUNC = """
 \tstream_attribute.accessPolicyWindow.hitRatio = {1};
 \tstream_attribute.accessPolicyWindow.hitProp = cudaAccessPropertyPersisting;
 \tstream_attribute.accessPolicyWindow.missProp = cudaAccessPropertyStreaming;
-\tcudaDeviceSetLimit(cudaLimitPersistingL2CacheSize, {3});
+\tcudaDeviceSetLimit(cudaLimitPersistingL2CacheSize, {2});
 \tstream_attribute.accessPolicyWindow.base_ptr = (void*)({0});
-\tstream_attribute.accessPolicyWindow.num_bytes = {3};
+\tstream_attribute.accessPolicyWindow.num_bytes = {2};
 \tcudaStreamSetAttribute(stream, cudaStreamAttributeAccessPolicyWindow, &stream_attribute);
 """
@@ -359,9 +359,8 @@ class TLCUDASourceWrapper(object):
             except Exception:
                 # as size_in_bytes maybe a symbolic expression
                 num_bytes = persisting_l2_cache_max_size
             init_l2_persistent_map += L2_PERSISTENT_MAP_INIT_FUNC.format(
-                buffer_name, float(hit_ratio), size_in_bytes, num_bytes)
+                buffer_name, float(hit_ratio), pythonic_expr(num_bytes))
         return init_l2_persistent_map