Commit 01526645 authored by Lei Wang, committed by LeiWang1999

[Enhancement] Improve device handling in Cython kernel adapter (#220)

* [Enhancement] Improve device handling in Cython kernel adapter and wrapper

- Updated `CythonKernelAdapter` to support dynamic device assignment based on the target type (CUDA, HIP, or CPU); the mapping is sketched below.
- Enhanced `CythonKernelWrapper` with device management, ensuring tensors are allocated on the correct device.
- Added error handling for unsupported target types to improve robustness.
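For reference, the target-to-device mapping introduced by this change reduces to the following. This is a minimal Python sketch, not part of the patch: `resolve_device` is a hypothetical helper name, while the `is_*_target` predicates are the ones imported from `tilelang.jit.adapter.utils` in the diff below.

import torch
from tilelang.jit.adapter.utils import is_cuda_target, is_hip_target, is_cpu_target

def resolve_device(target) -> torch.device:
    # CUDA and HIP kernels both run through torch's "cuda" device,
    # CPU targets map to "cpu", and anything else is rejected up front.
    if is_cuda_target(target) or is_hip_target(target):
        return torch.device("cuda")
    if is_cpu_target(target):
        return torch.device("cpu")
    raise ValueError(f"Unsupported target: {target}")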

* [Enhancement] Add buffer device mapping in Cython kernel adapter and wrapper

- Introduced `buffer_device_map` in `CythonKernelAdapter` to associate buffer variables with their respective devices.
- Updated `CythonKernelWrapper` to use the new buffer device mapping for device checks during tensor allocation; the check is sketched below.
- Enhanced error handling for device mismatches so that tensors are allocated on the correct device, improving robustness and flexibility in device management.
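The runtime check added to `CythonKernelWrapper.forward` amounts to the following Python sketch. The function name `check_devices` and the example map contents are illustrative only; the actual implementation is the Cython code in the diff below.

import torch

# Example shape of the mapping: parameter name -> (argument index, expected device).
example_buffer_device_map = {
    "A": (0, torch.device("cuda")),
    "B": (1, torch.device("cuda")),
}

def check_devices(tensors, buffer_device_map):
    # Verify each mapped tensor lives on the expected device before launching the kernel.
    for param, (buffer_idx, device) in buffer_device_map.items():
        tensor_device = tensors[buffer_idx].device
        # The device index is only compared when both sides specify one, so a plain
        # torch.device("cuda") accepts tensors on any CUDA ordinal.
        if (tensor_device.type != device.type or
                (tensor_device.index is not None and device.index is not None and
                 tensor_device.index != device.index)):
            raise ValueError(
                f"Buffer device mismatch for parameter {param}: expected {device}, got {tensor_device}")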
parent 872f5613
@@ -9,6 +9,7 @@ from tvm.relay import TensorType
 from tvm import tir
 from tilelang.jit.adapter.wrapper import TLWrapper
 from tilelang.jit.adapter.libgen import LibraryGenerator
+from tilelang.jit.adapter.utils import is_cuda_target, is_hip_target, is_cpu_target
 from tilelang.utils.target import determine_target
 from tilelang.utils.language import retrieve_func_from_module
 from tilelang.utils.tensor import map_torch_type
@@ -128,7 +129,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
     """
     # Class attributes to store compiled kernel information
-    target: str = "cuda"
+    target: Union[str, Target] = "cuda"
     ir_module: Optional[tvm.IRModule] = None
     lib: Optional[ctypes.CDLL] = None  # Compiled library handle
     wrapped_source: Optional[str] = None  # Generated C++ wrapper code
@@ -141,6 +142,8 @@ class CythonKernelAdapter(BaseKernelAdapter):
     #     "A": [(0, 16), (1, 16)] -> represents A.shape = (16, 16)
     # }
     static_shape_map: Optional[Dict[tir.Var, Tuple[int, List[Tuple[int, int]]]]] = None
+    # Maps buffer variables to their corresponding devices
+    buffer_device_map: Optional[Dict[tir.Var, Tuple[int, torch.device]]] = None
     # Pass configs for the compiler
     pass_configs: Optional[Dict[str, Any]] = None
@@ -148,7 +151,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
                  rt_mod,
                  params: List[TensorType],
                  result_idx: List[int],
-                 target,
+                 target: Union[str, Target],
                  func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
                  verbose: bool = False,
                  pass_configs: Optional[Dict[str, Any]] = None):
@@ -171,11 +174,13 @@ class CythonKernelAdapter(BaseKernelAdapter):
         else:
             self.ir_module = func_or_mod
+        self.target = Target.canon_target(determine_target(target))
         self.dynamic_symbolic_map = self._process_dynamic_symbolic()
         self.buffer_dtype_map = self._process_buffer_dtype()
         self.static_shape_map = self._process_static_shape()
-        self.target = Target.canon_target(determine_target(target))
+        self.buffer_device_map = self._process_buffer_device()
         self.verbose = verbose
         self.wrapper = TLWrapper(self.target)
         self.lib_generator = LibraryGenerator(self.target)
@@ -198,7 +203,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
         self.cython_wrapper.set_dynamic_symbolic_map(self.dynamic_symbolic_map)
         self.cython_wrapper.set_buffer_dtype_map(self.buffer_dtype_map)
         self.cython_wrapper.set_static_shape_map(self.static_shape_map)
+        self.cython_wrapper.set_buffer_device_map(self.buffer_device_map)
         self._post_init()

     def _process_dynamic_symbolic(self) -> Dict[tir.Var, Tuple[int, int]]:
@@ -256,6 +261,30 @@ class CythonKernelAdapter(BaseKernelAdapter):
             static_shape_map[name] = (i, static_shape)
         return static_shape_map

+    def _process_buffer_device(self) -> Dict[tir.Var, Tuple[int, torch.device]]:
+        """Extract information about buffer devices from the TIR function.
+        Maps buffer variables to their corresponding devices.
+        """
+        func = self.prim_func
+        params = func.params
+        buffer_map = func.buffer_map
+        buffer_device_map = {}
+        device = None
+        if is_cuda_target(self.target) or is_hip_target(self.target):
+            device = torch.device("cuda")
+        elif is_cpu_target(self.target):
+            device = torch.device("cpu")
+        else:
+            raise ValueError(f"Unsupported target: {self.target}")
+        for i, param in enumerate(params):
+            if param in buffer_map:
+                buffer = buffer_map[param]
+                name = buffer.name
+                buffer_device_map[name] = (i, device)
+        return buffer_device_map
+
     def _forward_from_prebuild_lib(self, *args, stream: Optional[int] = None):
         """Low-level function to call the compiled CUDA kernel.
...
@@ -12,6 +12,7 @@ cdef class CythonKernelWrapper:
     # Class attributes to store kernel configuration and library reference
     cdef:
         object dynamic_symbolic_map  # Maps dynamic dimensions to their corresponding tensor indices
+        object buffer_device_map  # Maps buffer variables to their corresponding devices
         object buffer_dtype_map  # Maps buffer variables to their corresponding dtypes
         object static_shape_map  # Maps buffer variables to their corresponding static shapes
         list result_idx  # Indices of output tensors in the params list
@@ -53,6 +54,10 @@ cdef class CythonKernelWrapper:
         self.static_shape_map = static_shape_map
         return self

+    def set_buffer_device_map(self, buffer_device_map):
+        self.buffer_device_map = buffer_device_map
+        return self
+
     cpdef forward(self, list inputs, int64_t stream = -1):
         # Validate input dimensions and prepare for kernel execution
         cdef int total_params = len(self.params)
@@ -103,6 +108,14 @@ cdef class CythonKernelWrapper:
             else:
                 raise ValueError(f"Unsupported tensor type: {type(tensor_list[i])}")

+        # Check buffer device
+        for param, (buffer_idx, device) in self.buffer_device_map.items():
+            tensor_device = tensor_list[buffer_idx].device
+            # Compare device types and indices separately to handle both string and torch.device objects
+            if (tensor_device.type != device.type or
+                    (tensor_device.index is not None and device.index is not None and tensor_device.index != device.index)):
+                raise ValueError(f"Buffer device mismatch for parameter {param}: expected {device}, got {tensor_device}")
+
         # Check buffer dtype map
         for param, (buffer_idx, torch_dtype) in self.buffer_dtype_map.items():
             if tensor_list[buffer_idx].dtype != torch_dtype:
...