Unverified Commit 394e17d0 authored by Xiaoyu Zhang's avatar Xiaoyu Zhang Committed by GitHub
Browse files

[Refactor] Refine nvrtc compile related check style (#945)

* unify nvrtc check style

* unify nvrtc check style

* unify nvrtc check style
parent c61971e8
......@@ -20,16 +20,13 @@ from .utils import is_cpu_target, is_cuda_target, is_hip_target
logger = logging.getLogger(__name__)
is_nvrtc_available = False
NVRTC_UNAVAILABLE_WARNING = "cuda-python is not available, nvrtc backend cannot be used. " \
"Please install cuda-python via `pip install cuda-python` " \
"if you want to use the nvrtc backend."
try:
from tilelang.jit.adapter.nvrtc import is_nvrtc_available
if is_nvrtc_available:
import cuda.bindings.driver as cuda
from tilelang.contrib.nvrtc import compile_cuda
is_nvrtc_available = True
except ImportError:
pass
is_nvrtc_available = False
class LibraryGenerator(object):
......@@ -194,7 +191,9 @@ class PyLibraryGenerator(LibraryGenerator):
def __init__(self, target: Target, verbose: bool = False):
if not is_nvrtc_available:
raise ImportError(NVRTC_UNAVAILABLE_WARNING)
raise ImportError("cuda-python is not available, nvrtc backend cannot be used. "
"Please install cuda-python via `pip install cuda-python` "
"if you want to use the nvrtc backend.")
super().__init__(target, verbose)
@staticmethod
......@@ -243,7 +242,7 @@ class PyLibraryGenerator(LibraryGenerator):
else:
tl_template_path = TILELANG_TEMPLATE_PATH
cuda_home = "/usr/local/cuda" if CUDA_HOME is None else CUDA_HOME
cuda_home = CUDA_HOME if CUDA_HOME else "/usr/local/cuda"
options = [f"-I{tl_template_path}", f"-I{cutlass_path}", f"-I{cuda_home}/include"]
if self.compile_flags:
......
from .adapter import NVRTCKernelAdapter # noqa: F401
"""NVRTC Backend for TileLang.
This module provides runtime compilation support using NVIDIA's NVRTC API.
"""
import logging
__all__ = ['NVRTCKernelAdapter', 'is_nvrtc_available', 'check_nvrtc_available']
logger = logging.getLogger(__name__)
# Check if cuda-python is available
is_nvrtc_available = False
NVRTC_UNAVAILABLE_MESSAGE = ("cuda-python is not available, NVRTC backend cannot be used. "
"Please install cuda-python via `pip install cuda-python` "
"if you want to use the NVRTC backend.")
try:
import cuda.bindings.driver as cuda # noqa: F401
import cuda.bindings.nvrtc as nvrtc # noqa: F401
is_nvrtc_available = True
except ImportError as e:
logger.debug(f"cuda-python import failed: {e}")
def check_nvrtc_available():
"""Check if NVRTC backend is available.
Raises
------
ImportError
If cuda-python is not installed or cannot be imported
"""
if not is_nvrtc_available:
raise ImportError(NVRTC_UNAVAILABLE_MESSAGE)
# Conditionally import the adapter
if is_nvrtc_available:
from .adapter import NVRTCKernelAdapter # noqa: F401
else:
# Provide a dummy class that raises error on instantiation
class NVRTCKernelAdapter:
"""Dummy NVRTCKernelAdapter that raises ImportError on instantiation."""
def __init__(self, *args, **kwargs):
raise ImportError(NVRTC_UNAVAILABLE_MESSAGE)
import logging
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
from tvm import tir
......@@ -11,20 +11,14 @@ from tilelang.jit.adapter.wrapper import TLPyWrapper
from tilelang.jit.adapter.libgen import PyLibraryGenerator
from tilelang.utils.language import retrieve_func_from_module
from tilelang.utils.target import determine_target
from ..base import BaseKernelAdapter
from tilelang.jit.adapter.base import BaseKernelAdapter
from tilelang.jit.adapter.nvrtc import is_nvrtc_available, check_nvrtc_available
logger = logging.getLogger(__name__)
is_nvrtc_available = False
NVRTC_UNAVAILABLE_WARNING = "cuda-python is not available, nvrtc backend cannot be used. " \
"Please install cuda-python via `pip install cuda-python` " \
"if you want to use the nvrtc backend."
try:
# Import cuda bindings if available
if is_nvrtc_available:
import cuda.bindings.driver as cuda
is_nvrtc_available = True
except ImportError:
pass
class NVRTCKernelAdapter(BaseKernelAdapter):
......@@ -43,8 +37,7 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
pass_configs: Optional[Dict[str, Any]] = None,
compile_flags: Optional[List[str]] = None):
if not is_nvrtc_available:
raise ImportError(NVRTC_UNAVAILABLE_WARNING)
check_nvrtc_available()
self.params = params
self.result_idx = self._legalize_result_idx(result_idx)
......@@ -150,11 +143,16 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
adapter._post_init()
return adapter
def _process_dynamic_symbolic(self):
def _process_dynamic_symbolic(self) -> Dict[tir.Var, Tuple[int, int]]:
"""Extract information about dynamic shapes from the TIR function.
Maps symbolic variables to their corresponding (buffer_index, shape_dimension)
for runtime shape resolution.
Returns
-------
Dict[tir.Var, Tuple[int, int]]
Mapping from symbolic variable to (buffer_index, shape_dimension)
"""
func = self.prim_func
params = func.params
......@@ -167,7 +165,14 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
dynamic_symbolic_map[shape] = (i, j)
return dynamic_symbolic_map
def get_kernel_source(self):
def get_kernel_source(self) -> Optional[str]:
"""Get the CUDA kernel source code.
Returns
-------
Optional[str]
The kernel source code, or None if not available
"""
return self.kernel_global_source
def _forward_from_prebuild_lib(self, *args, stream: Optional[int] = None):
......@@ -237,7 +242,14 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
else:
return [args[i] for i in self.result_idx]
def _convert_torch_func(self) -> Callable:
def _convert_torch_func(self) -> Callable[..., Union[torch.Tensor, List[torch.Tensor]]]:
"""Convert to a PyTorch-compatible function.
Returns
-------
Callable[..., Union[torch.Tensor, List[torch.Tensor]]]
A callable function that takes tensors and returns tensor(s)
"""
return self._wrap_forward_from_prebuild_lib
@property
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment