"docs/vscode:/vscode.git/clone" did not exist on "388829db6007f65d7deadb3a3039f46f973b76da"
Unverified Commit f737fa97 authored by Tong WU's avatar Tong WU Committed by GitHub
Browse files

[Enhancement] Include compile flags into the hash key of cached kernels (#911)

* [Cache] Add compile_flags parameter to KernelCache hash keys

* [Cache] Update compile_flags parameter to accept both List[str] and str types

* lint

* [Refactor] Update compile_flags parameter to accept Union[List[str], str] type
parent a35ac496
...@@ -20,7 +20,7 @@ def cached( ...@@ -20,7 +20,7 @@ def cached(
execution_backend: Optional[Literal["dlpack", "ctypes", "cython", "nvrtc"]] = "cython", execution_backend: Optional[Literal["dlpack", "ctypes", "cython", "nvrtc"]] = "cython",
verbose: Optional[bool] = False, verbose: Optional[bool] = False,
pass_configs: Optional[dict] = None, pass_configs: Optional[dict] = None,
compile_flags: Optional[List[str]] = None, compile_flags: Optional[Union[List[str], str]] = None,
) -> JITKernel: ) -> JITKernel:
""" """
Caches and reuses compiled kernels (using KernelCache class). Caches and reuses compiled kernels (using KernelCache class).
......
...@@ -73,6 +73,7 @@ class KernelCache: ...@@ -73,6 +73,7 @@ class KernelCache:
target: Union[str, Target] = "auto", target: Union[str, Target] = "auto",
target_host: Union[str, Target] = None, target_host: Union[str, Target] = None,
pass_configs: dict = None, pass_configs: dict = None,
compile_flags: Optional[Union[List[str], str]] = None,
) -> str: ) -> str:
""" """
Generates a unique hash key for caching compiled kernels. Generates a unique hash key for caching compiled kernels.
...@@ -101,6 +102,7 @@ class KernelCache: ...@@ -101,6 +102,7 @@ class KernelCache:
"target_host": str(target_host) if target_host else None, "target_host": str(target_host) if target_host else None,
"execution_backend": execution_backend, "execution_backend": execution_backend,
"pass_configs": pass_configs, "pass_configs": pass_configs,
"compile_flags": compile_flags,
} }
# Sort keys to ensure consistency # Sort keys to ensure consistency
key_string = json.dumps(key_data, sort_keys=True) key_string = json.dumps(key_data, sort_keys=True)
...@@ -117,7 +119,7 @@ class KernelCache: ...@@ -117,7 +119,7 @@ class KernelCache:
execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython", execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
verbose: bool = False, verbose: bool = False,
pass_configs: dict = None, pass_configs: dict = None,
compile_flags: Optional[List[str]] = None, compile_flags: Optional[Union[List[str], str]] = None,
) -> JITKernel: ) -> JITKernel:
""" """
Caches and reuses compiled kernels to avoid redundant compilation. Caches and reuses compiled kernels to avoid redundant compilation.
...@@ -152,6 +154,7 @@ class KernelCache: ...@@ -152,6 +154,7 @@ class KernelCache:
target=target, target=target,
target_host=target_host, target_host=target_host,
pass_configs=pass_configs, pass_configs=pass_configs,
compile_flags=compile_flags,
) )
with self._lock: with self._lock:
# First check in-memory cache # First check in-memory cache
...@@ -165,7 +168,8 @@ class KernelCache: ...@@ -165,7 +168,8 @@ class KernelCache:
# Then check disk cache # Then check disk cache
kernel = self._load_kernel_from_disk(key, target, target_host, out_idx, kernel = self._load_kernel_from_disk(key, target, target_host, out_idx,
execution_backend, pass_configs, func, verbose) execution_backend, pass_configs, compile_flags,
func, verbose)
if kernel is not None: if kernel is not None:
if verbose: if verbose:
self.logger.debug( self.logger.debug(
...@@ -185,6 +189,7 @@ class KernelCache: ...@@ -185,6 +189,7 @@ class KernelCache:
target_host=target_host, target_host=target_host,
verbose=verbose, verbose=verbose,
pass_configs=pass_configs, pass_configs=pass_configs,
compile_flags=compile_flags,
) )
if execution_backend == "dlpack": if execution_backend == "dlpack":
self.logger.warning("DLPack backend does not support cache saving to disk.") self.logger.warning("DLPack backend does not support cache saving to disk.")
...@@ -322,6 +327,7 @@ class KernelCache: ...@@ -322,6 +327,7 @@ class KernelCache:
out_idx: List[int] = None, out_idx: List[int] = None,
execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython", execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
pass_configs: dict = None, pass_configs: dict = None,
compile_flags: Optional[Union[List[str], str]] = None,
func: Callable = None, func: Callable = None,
verbose: bool = False, verbose: bool = False,
) -> Optional[JITKernel]: ) -> Optional[JITKernel]:
...@@ -382,6 +388,7 @@ class KernelCache: ...@@ -382,6 +388,7 @@ class KernelCache:
out_idx=out_idx, out_idx=out_idx,
execution_backend=execution_backend, execution_backend=execution_backend,
pass_configs=pass_configs, pass_configs=pass_configs,
compile_flags=compile_flags,
) )
else: else:
return None return None
......
...@@ -95,7 +95,7 @@ class _JitImplementation: ...@@ -95,7 +95,7 @@ class _JitImplementation:
verbose: bool verbose: bool
pass_configs: Optional[Dict[str, Any]] pass_configs: Optional[Dict[str, Any]]
debug_root_path: Optional[str] debug_root_path: Optional[str]
compile_flags: Optional[List[str]] compile_flags: Optional[Union[List[str], str]]
def __init__(self, def __init__(self,
out_idx: Any = None, out_idx: Any = None,
...@@ -105,7 +105,7 @@ class _JitImplementation: ...@@ -105,7 +105,7 @@ class _JitImplementation:
verbose: bool = False, verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None, pass_configs: Optional[Dict[str, Any]] = None,
debug_root_path: Optional[str] = None, debug_root_path: Optional[str] = None,
compile_flags: Optional[List[str]] = None): compile_flags: Optional[Union[List[str], str]] = None):
""" """
Initializes the JIT compiler decorator. Initializes the JIT compiler decorator.
...@@ -137,6 +137,9 @@ class _JitImplementation: ...@@ -137,6 +137,9 @@ class _JitImplementation:
If None, no debug information is saved (default: None). If None, no debug information is saved (default: None).
If a relative path is given, it's made absolute relative to the project root If a relative path is given, it's made absolute relative to the project root
or current working directory. or current working directory.
compile_flags : Optional[Union[List[str], str]], optional
Additional compilation flags to pass to the compiler.
If None, no additional compilation flags are passed (default: None).
""" """
self.out_idx = out_idx self.out_idx = out_idx
self.execution_backend = execution_backend self.execution_backend = execution_backend
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment