Commit 01526645 authored by Lei Wang, committed by LeiWang1999

[Enhancement] Improve device handling in Cython kernel adapter (#220)

* [Enhancement] Improve device handling in Cython kernel adapter and wrapper

- Updated `CythonKernelAdapter` to support dynamic device assignment based on target type (CUDA, HIP, or CPU).
- Enhanced `CythonKernelWrapper` to include device management, ensuring tensors are allocated on the correct device.
- Added error handling for unsupported target types to improve robustness.
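
The device-resolution rule is roughly the following; this is a simplified standalone sketch (plain strings stand in for the `is_cuda_target`/`is_hip_target`/`is_cpu_target` helpers imported in the diff below, and the exact set of CPU target kinds is an assumption):

```python
import torch

def resolve_device(target_kind: str) -> torch.device:
    # HIP kernels also use torch's "cuda" device type, since ROCm builds
    # of PyTorch expose GPU tensors under torch.device("cuda").
    if target_kind in ("cuda", "hip"):
        return torch.device("cuda")
    # Treating "llvm"/"c" as CPU kinds is an assumption for this sketch.
    if target_kind in ("cpu", "llvm", "c"):
        return torch.device("cpu")
    raise ValueError(f"Unsupported target: {target_kind}")

# Every buffer of a CUDA-compiled kernel is expected on the GPU:
assert resolve_device("cuda") == torch.device("cuda")
```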

* [Enhancement] Add buffer device mapping in Cython kernel adapter and wrapper

- Introduced `buffer_device_map` in `CythonKernelAdapter` to associate buffer variables with their respective devices.
- Updated `CythonKernelWrapper` to consult the new buffer device mapping when validating input tensors in `forward`.
- Enhanced error handling for device mismatches: a clear error is now raised when an input tensor is not on the device the kernel was compiled for, improving robustness and flexibility in device management.
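
As a rough illustration of how the new map is consumed (the `{buffer_name: (param_index, expected_device)}` format follows `_process_buffer_device`, and the check mirrors the one added to `CythonKernelWrapper.forward` in the diff below; this is a sketch, not the wrapper's actual code path):

```python
import torch

# Shape of the map produced by CythonKernelAdapter._process_buffer_device:
# buffer name -> (parameter index, expected torch.device)
buffer_device_map = {"A": (0, torch.device("cuda")), "B": (1, torch.device("cuda"))}

def check_buffer_devices(tensors, buffer_device_map):
    """Raise if a tensor sits on a different device than the kernel expects."""
    for name, (idx, expected) in buffer_device_map.items():
        actual = tensors[idx].device
        # Indices are only compared when both sides carry one, so a plain
        # torch.device("cuda") still matches a tensor on "cuda:0".
        if actual.type != expected.type or (
                actual.index is not None and expected.index is not None
                and actual.index != expected.index):
            raise ValueError(f"Buffer device mismatch for parameter {name}: "
                             f"expected {expected}, got {actual}")

# A CPU tensor handed to a CUDA-compiled kernel is now rejected up front:
# check_buffer_devices([torch.empty(16), torch.empty(16)], buffer_device_map)  # ValueError
```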
parent 872f5613
@@ -9,6 +9,7 @@ from tvm.relay import TensorType
from tvm import tir
from tilelang.jit.adapter.wrapper import TLWrapper
from tilelang.jit.adapter.libgen import LibraryGenerator
from tilelang.jit.adapter.utils import is_cuda_target, is_hip_target, is_cpu_target
from tilelang.utils.target import determine_target
from tilelang.utils.language import retrieve_func_from_module
from tilelang.utils.tensor import map_torch_type
@@ -128,7 +129,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
    """
    # Class attributes to store compiled kernel information
    target: str = "cuda"
    target: Union[str, Target] = "cuda"
    ir_module: Optional[tvm.IRModule] = None
    lib: Optional[ctypes.CDLL] = None  # Compiled library handle
    wrapped_source: Optional[str] = None  # Generated C++ wrapper code
@@ -141,6 +142,8 @@ class CythonKernelAdapter(BaseKernelAdapter):
    #     "A": [(0, 16), (1, 16)] -> represents A.shape = (16, 16)
    # }
    static_shape_map: Optional[Dict[tir.Var, Tuple[int, List[Tuple[int, int]]]]] = None
    # Maps buffer variables to their corresponding devices
    buffer_device_map: Optional[Dict[tir.Var, Tuple[int, torch.device]]] = None
    # Pass configs for the compiler
    pass_configs: Optional[Dict[str, Any]] = None
@@ -148,7 +151,7 @@
                 rt_mod,
                 params: List[TensorType],
                 result_idx: List[int],
                 target,
                 target: Union[str, Target],
                 func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
                 verbose: bool = False,
                 pass_configs: Optional[Dict[str, Any]] = None):
@@ -171,11 +174,13 @@
        else:
            self.ir_module = func_or_mod
        self.target = Target.canon_target(determine_target(target))
        self.dynamic_symbolic_map = self._process_dynamic_symbolic()
        self.buffer_dtype_map = self._process_buffer_dtype()
        self.static_shape_map = self._process_static_shape()
        self.buffer_device_map = self._process_buffer_device()
        self.target = Target.canon_target(determine_target(target))
        self.verbose = verbose
        self.wrapper = TLWrapper(self.target)
        self.lib_generator = LibraryGenerator(self.target)
@@ -198,7 +203,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
        self.cython_wrapper.set_dynamic_symbolic_map(self.dynamic_symbolic_map)
        self.cython_wrapper.set_buffer_dtype_map(self.buffer_dtype_map)
        self.cython_wrapper.set_static_shape_map(self.static_shape_map)
        self.cython_wrapper.set_buffer_device_map(self.buffer_device_map)
        self._post_init()

    def _process_dynamic_symbolic(self) -> Dict[tir.Var, Tuple[int, int]]:
@@ -256,6 +261,30 @@
                static_shape_map[name] = (i, static_shape)
        return static_shape_map

    def _process_buffer_device(self) -> Dict[tir.Var, Tuple[int, torch.device]]:
        """Extract information about buffer devices from the TIR function.
        Maps buffer variables to their corresponding devices.
        """
        func = self.prim_func
        params = func.params
        buffer_map = func.buffer_map
        buffer_device_map = {}
        device = None
        if is_cuda_target(self.target) or is_hip_target(self.target):
            device = torch.device("cuda")
        elif is_cpu_target(self.target):
            device = torch.device("cpu")
        else:
            raise ValueError(f"Unsupported target: {self.target}")
        for i, param in enumerate(params):
            if param in buffer_map:
                buffer = buffer_map[param]
                name = buffer.name
                buffer_device_map[name] = (i, device)
        return buffer_device_map

    def _forward_from_prebuild_lib(self, *args, stream: Optional[int] = None):
        """Low-level function to call the compiled CUDA kernel.
......
@@ -12,6 +12,7 @@ cdef class CythonKernelWrapper:
    # Class attributes to store kernel configuration and library reference
    cdef:
        object dynamic_symbolic_map  # Maps dynamic dimensions to their corresponding tensor indices
        object buffer_device_map  # Maps buffer variables to their corresponding devices
        object buffer_dtype_map  # Maps buffer variables to their corresponding dtypes
        object static_shape_map  # Maps buffer variables to their corresponding static shapes
        list result_idx  # Indices of output tensors in the params list
@@ -53,6 +54,10 @@ cdef class CythonKernelWrapper:
        self.static_shape_map = static_shape_map
        return self

    def set_buffer_device_map(self, buffer_device_map):
        self.buffer_device_map = buffer_device_map
        return self

    cpdef forward(self, list inputs, int64_t stream = -1):
        # Validate input dimensions and prepare for kernel execution
        cdef int total_params = len(self.params)
@@ -103,6 +108,14 @@ cdef class CythonKernelWrapper:
            else:
                raise ValueError(f"Unsupported tensor type: {type(tensor_list[i])}")

        # Check buffer device
        for param, (buffer_idx, device) in self.buffer_device_map.items():
            tensor_device = tensor_list[buffer_idx].device
            # Compare device types and indices separately to handle both string and torch.device objects
            if (tensor_device.type != device.type or
                    (tensor_device.index is not None and device.index is not None and tensor_device.index != device.index)):
                raise ValueError(f"Buffer device mismatch for parameter {param}: expected {device}, got {tensor_device}")

        # Check buffer dtype map
        for param, (buffer_idx, torch_dtype) in self.buffer_dtype_map.items():
            if tensor_list[buffer_idx].dtype != torch_dtype:
......