Commit b21f63d9 authored by Cunxiao Ni, committed by LeiWang1999

[Enhancement] Reduce CPU overhead during kernel execution (#437)

* [Enhancement] Reduce CPU overhead during kernel execution

* fix lint
parent dabe6e0a
@@ -22,7 +22,7 @@ cdef class CythonKernelWrapper:
         # Add new cache attributes
         list param_dtypes  # Cache for parameter dtypes
         list param_shapes  # Cache for parameter shapes as native Python lists
+        object get_current_device

     def __cinit__(self, result_idx, params, lib):
         # Initialize wrapper with kernel configuration
         self.result_idx = result_idx
@@ -32,6 +32,7 @@ cdef class CythonKernelWrapper:
         self.param_dtypes = [param.dtype for param in params]
         # Convert TVM shape arrays to native Python lists
         self.param_shapes = []
+        self.get_current_device = torch.cuda.current_device
         for param in params:
             native_shape = []
             for dim in param.shape:
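
The cached `get_current_device` binds `torch.cuda.current_device` once at construction time, so each launch pays a single attribute load instead of traversing the `torch -> cuda -> current_device` module chain. A minimal plain-Python sketch of the same pattern (the class name `_Wrapper` is illustrative, not the actual Cython type):

```python
import torch

class _Wrapper:
    def __init__(self):
        # Resolve the function once here; hot-path calls then skip the
        # repeated module-attribute lookups on every kernel launch.
        self.get_current_device = torch.cuda.current_device

    def launch(self):
        # Hot path: call the cached callable directly.
        return self.get_current_device()
```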
@@ -79,7 +80,10 @@ cdef class CythonKernelWrapper:
         # Use current CUDA stream if none specified
         if stream == -1:
             if torch.cuda.is_available():
-                stream = torch.cuda.current_stream().cuda_stream
+                try:
+                    stream = torch._C._cuda_getCurrentRawStream(torch.cuda.current_device())
+                except ImportError:
+                    stream = torch.cuda.current_stream().cuda_stream
             else:
                 stream = 0
@@ -126,11 +130,14 @@ cdef class CythonKernelWrapper:
                 raise ValueError(f"Unsupported tensor type: {type(tensor_list[i])}")

         # Check buffer device
+        # cdef str tensor_list_device_type = tensor_list[0].device.type
+        if isinstance(tensor_list[0], torch.Tensor):
+            tensor_list_device_type = tensor_list[0].device.type
         for param, (buffer_idx, device) in self.buffer_device_map.items():
             if isinstance(tensor_list[buffer_idx], torch.Tensor):
                 tensor_device = tensor_list[buffer_idx].device
                 # Compare device types and indices separately to handle both string and torch.device objects
-                if (tensor_device.type != device.type or
+                if (tensor_list_device_type != device.type or
                     (tensor_device.index is not None and device.index is not None and tensor_device.index != device.index)):
                     raise ValueError(f"Buffer device mismatch for parameter {param}: expected {device}, got {tensor_device}")
...