Unverified commit 394e17d0, authored by Xiaoyu Zhang and committed by GitHub
Browse files

[Refactor] Refine nvrtc compile related check style (#945)

* unify nvrtc check style

* unify nvrtc check style

* unify nvrtc check style
parent c61971e8
...@@ -20,16 +20,13 @@ from .utils import is_cpu_target, is_cuda_target, is_hip_target ...@@ -20,16 +20,13 @@ from .utils import is_cpu_target, is_cuda_target, is_hip_target
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
is_nvrtc_available = False
NVRTC_UNAVAILABLE_WARNING = "cuda-python is not available, nvrtc backend cannot be used. " \
"Please install cuda-python via `pip install cuda-python` " \
"if you want to use the nvrtc backend."
try: try:
import cuda.bindings.driver as cuda from tilelang.jit.adapter.nvrtc import is_nvrtc_available
from tilelang.contrib.nvrtc import compile_cuda if is_nvrtc_available:
is_nvrtc_available = True import cuda.bindings.driver as cuda
from tilelang.contrib.nvrtc import compile_cuda
except ImportError: except ImportError:
pass is_nvrtc_available = False
class LibraryGenerator(object): class LibraryGenerator(object):
...@@ -194,7 +191,9 @@ class PyLibraryGenerator(LibraryGenerator): ...@@ -194,7 +191,9 @@ class PyLibraryGenerator(LibraryGenerator):
def __init__(self, target: Target, verbose: bool = False): def __init__(self, target: Target, verbose: bool = False):
if not is_nvrtc_available: if not is_nvrtc_available:
raise ImportError(NVRTC_UNAVAILABLE_WARNING) raise ImportError("cuda-python is not available, nvrtc backend cannot be used. "
"Please install cuda-python via `pip install cuda-python` "
"if you want to use the nvrtc backend.")
super().__init__(target, verbose) super().__init__(target, verbose)
@staticmethod @staticmethod
...@@ -243,7 +242,7 @@ class PyLibraryGenerator(LibraryGenerator): ...@@ -243,7 +242,7 @@ class PyLibraryGenerator(LibraryGenerator):
else: else:
tl_template_path = TILELANG_TEMPLATE_PATH tl_template_path = TILELANG_TEMPLATE_PATH
cuda_home = "/usr/local/cuda" if CUDA_HOME is None else CUDA_HOME cuda_home = CUDA_HOME if CUDA_HOME else "/usr/local/cuda"
options = [f"-I{tl_template_path}", f"-I{cutlass_path}", f"-I{cuda_home}/include"] options = [f"-I{tl_template_path}", f"-I{cutlass_path}", f"-I{cuda_home}/include"]
if self.compile_flags: if self.compile_flags:
......
from .adapter import NVRTCKernelAdapter # noqa: F401 """NVRTC Backend for TileLang.
This module provides runtime compilation support using NVIDIA's NVRTC API.
"""
import logging
__all__ = ['NVRTCKernelAdapter', 'is_nvrtc_available', 'check_nvrtc_available']

logger = logging.getLogger(__name__)

# Availability flag for the NVRTC backend. It stays False unless both
# cuda-python binding modules imported below load successfully.
is_nvrtc_available = False

# Error text raised whenever the NVRTC backend is requested but unusable.
NVRTC_UNAVAILABLE_MESSAGE = ("cuda-python is not available, NVRTC backend cannot be used. "
                             "Please install cuda-python via `pip install cuda-python` "
                             "if you want to use the NVRTC backend.")

try:
    import cuda.bindings.driver as cuda  # noqa: F401
    import cuda.bindings.nvrtc as nvrtc  # noqa: F401
except ImportError as e:
    # A failed import is not fatal here: the flag simply stays False and the
    # reason is recorded at debug level. Deferred %-style arguments avoid
    # building the message when debug logging is disabled.
    logger.debug("cuda-python import failed: %s", e)
else:
    # Only reached when both binding modules imported without error.
    is_nvrtc_available = True
def check_nvrtc_available():
    """Verify that the NVRTC backend is usable, raising if it is not.

    Returns silently when the module-level ``is_nvrtc_available`` flag is
    true.

    Raises
    ------
    ImportError
        If ``is_nvrtc_available`` is false; the error carries
        ``NVRTC_UNAVAILABLE_MESSAGE``.
    """
    if is_nvrtc_available:
        return
    raise ImportError(NVRTC_UNAVAILABLE_MESSAGE)
# Expose NVRTCKernelAdapter under the same name either way: a stand-in that
# fails on construction when the backend is unavailable, the real adapter
# otherwise.
if not is_nvrtc_available:

    class NVRTCKernelAdapter:
        """Placeholder adapter used when the NVRTC backend is unavailable.

        Instantiating it always raises ImportError carrying
        ``NVRTC_UNAVAILABLE_MESSAGE``.
        """

        def __init__(self, *args, **kwargs):
            raise ImportError(NVRTC_UNAVAILABLE_MESSAGE)
else:
    from .adapter import NVRTCKernelAdapter  # noqa: F401
import logging import logging
from typing import Any, Callable, Dict, List, Optional, Union from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch import torch
from tvm import tir from tvm import tir
...@@ -11,20 +11,14 @@ from tilelang.jit.adapter.wrapper import TLPyWrapper ...@@ -11,20 +11,14 @@ from tilelang.jit.adapter.wrapper import TLPyWrapper
from tilelang.jit.adapter.libgen import PyLibraryGenerator from tilelang.jit.adapter.libgen import PyLibraryGenerator
from tilelang.utils.language import retrieve_func_from_module from tilelang.utils.language import retrieve_func_from_module
from tilelang.utils.target import determine_target from tilelang.utils.target import determine_target
from tilelang.jit.adapter.base import BaseKernelAdapter
from ..base import BaseKernelAdapter from tilelang.jit.adapter.nvrtc import is_nvrtc_available, check_nvrtc_available
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
is_nvrtc_available = False # Import cuda bindings if available
NVRTC_UNAVAILABLE_WARNING = "cuda-python is not available, nvrtc backend cannot be used. " \ if is_nvrtc_available:
"Please install cuda-python via `pip install cuda-python` " \
"if you want to use the nvrtc backend."
try:
import cuda.bindings.driver as cuda import cuda.bindings.driver as cuda
is_nvrtc_available = True
except ImportError:
pass
class NVRTCKernelAdapter(BaseKernelAdapter): class NVRTCKernelAdapter(BaseKernelAdapter):
...@@ -43,8 +37,7 @@ class NVRTCKernelAdapter(BaseKernelAdapter): ...@@ -43,8 +37,7 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
pass_configs: Optional[Dict[str, Any]] = None, pass_configs: Optional[Dict[str, Any]] = None,
compile_flags: Optional[List[str]] = None): compile_flags: Optional[List[str]] = None):
if not is_nvrtc_available: check_nvrtc_available()
raise ImportError(NVRTC_UNAVAILABLE_WARNING)
self.params = params self.params = params
self.result_idx = self._legalize_result_idx(result_idx) self.result_idx = self._legalize_result_idx(result_idx)
...@@ -150,11 +143,16 @@ class NVRTCKernelAdapter(BaseKernelAdapter): ...@@ -150,11 +143,16 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
adapter._post_init() adapter._post_init()
return adapter return adapter
def _process_dynamic_symbolic(self): def _process_dynamic_symbolic(self) -> Dict[tir.Var, Tuple[int, int]]:
"""Extract information about dynamic shapes from the TIR function. """Extract information about dynamic shapes from the TIR function.
Maps symbolic variables to their corresponding (buffer_index, shape_dimension) Maps symbolic variables to their corresponding (buffer_index, shape_dimension)
for runtime shape resolution. for runtime shape resolution.
Returns
-------
Dict[tir.Var, Tuple[int, int]]
Mapping from symbolic variable to (buffer_index, shape_dimension)
""" """
func = self.prim_func func = self.prim_func
params = func.params params = func.params
...@@ -167,7 +165,14 @@ class NVRTCKernelAdapter(BaseKernelAdapter): ...@@ -167,7 +165,14 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
dynamic_symbolic_map[shape] = (i, j) dynamic_symbolic_map[shape] = (i, j)
return dynamic_symbolic_map return dynamic_symbolic_map
def get_kernel_source(self): def get_kernel_source(self) -> Optional[str]:
"""Get the CUDA kernel source code.
Returns
-------
Optional[str]
The kernel source code, or None if not available
"""
return self.kernel_global_source return self.kernel_global_source
def _forward_from_prebuild_lib(self, *args, stream: Optional[int] = None): def _forward_from_prebuild_lib(self, *args, stream: Optional[int] = None):
...@@ -237,7 +242,14 @@ class NVRTCKernelAdapter(BaseKernelAdapter): ...@@ -237,7 +242,14 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
else: else:
return [args[i] for i in self.result_idx] return [args[i] for i in self.result_idx]
def _convert_torch_func(self) -> Callable: def _convert_torch_func(self) -> Callable[..., Union[torch.Tensor, List[torch.Tensor]]]:
"""Convert to a PyTorch-compatible function.
Returns
-------
Callable[..., Union[torch.Tensor, List[torch.Tensor]]]
A callable function that takes tensors and returns tensor(s)
"""
return self._wrap_forward_from_prebuild_lib return self._wrap_forward_from_prebuild_lib
@property @property
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment