Unverified Commit 6610c7b9 authored by Gabriel Wu's avatar Gabriel Wu Committed by GitHub
Browse files

fix: NVRTC backend (#717)



* fix: NVRTC backend

* fix: CI

---------
Co-authored-by: default avatarLeiWang1999 <leiwang1999@outlook.com>
parent 084ab9ee
...@@ -235,7 +235,7 @@ def jit( # This is the new public interface ...@@ -235,7 +235,7 @@ def jit( # This is the new public interface
out_idx: Any = None, out_idx: Any = None,
target: Union[str, Target] = "auto", target: Union[str, Target] = "auto",
target_host: Union[str, Target] = None, target_host: Union[str, Target] = None,
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython", execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
verbose: bool = False, verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None, pass_configs: Optional[Dict[str, Any]] = None,
debug_root_path: Optional[str] = None, debug_root_path: Optional[str] = None,
......
...@@ -181,10 +181,10 @@ class PyLibraryGenerator(LibraryGenerator): ...@@ -181,10 +181,10 @@ class PyLibraryGenerator(LibraryGenerator):
culib = None culib = None
pymodule = None pymodule = None
def __init__(self, target: Target): def __init__(self, target: Target, verbose: bool = False):
if not is_nvrtc_available: if not is_nvrtc_available:
raise ImportError(NVRTC_UNAVAILABLE_WARNING) raise ImportError(NVRTC_UNAVAILABLE_WARNING)
super().__init__(target) super().__init__(target, verbose)
@staticmethod @staticmethod
def import_from_file(module_name, file_path): def import_from_file(module_name, file_path):
......
...@@ -81,7 +81,7 @@ class NVRTCKernelAdapter(BaseKernelAdapter): ...@@ -81,7 +81,7 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
self.wrapper.assign_device_module(device_mod) self.wrapper.assign_device_module(device_mod)
self.host_func, self.function_names = self.wrapper.wrap(kernel_global_source) self.host_func, self.function_names = self.wrapper.wrap(kernel_global_source)
self.lib_generator = PyLibraryGenerator(self.target) self.lib_generator = PyLibraryGenerator(self.target, self.verbose)
self.lib_generator.update_lib_code(self.kernel_global_source) self.lib_generator.update_lib_code(self.kernel_global_source)
self.lib_generator.update_host_func(self.host_func) self.lib_generator.update_host_func(self.host_func)
self.lib_generator.assign_compile_flags(compile_flags) self.lib_generator.assign_compile_flags(compile_flags)
...@@ -105,7 +105,8 @@ class NVRTCKernelAdapter(BaseKernelAdapter): ...@@ -105,7 +105,8 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
kernel_global_source: str, kernel_global_source: str,
kernel_lib_path: str, kernel_lib_path: str,
verbose: bool = False, verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None): pass_configs: Optional[Dict[str, Any]] = None,
compile_flags: Optional[List[str]] = None):
adapter = cls.__new__(cls) adapter = cls.__new__(cls)
adapter.params = params adapter.params = params
adapter.result_idx = adapter._legalize_result_idx(result_idx) adapter.result_idx = adapter._legalize_result_idx(result_idx)
...@@ -135,7 +136,8 @@ class NVRTCKernelAdapter(BaseKernelAdapter): ...@@ -135,7 +136,8 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
adapter.target = Target.canon_target(determine_target(target)) adapter.target = Target.canon_target(determine_target(target))
adapter.verbose = verbose adapter.verbose = verbose
adapter.lib_generator = PyLibraryGenerator(adapter.target) adapter.lib_generator = PyLibraryGenerator(adapter.target, adapter.verbose)
adapter.lib_generator.assign_compile_flags(compile_flags)
adapter.lib_generator.load_lib(lib_path=kernel_lib_path) adapter.lib_generator.load_lib(lib_path=kernel_lib_path)
adapter.pymodule = adapter.lib_generator.pymodule adapter.pymodule = adapter.lib_generator.pymodule
adapter.function_names = adapter.pymodule._function_names adapter.function_names = adapter.pymodule._function_names
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment