Commit 0a8c8b99 authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

[Feature] Add TILELANG_CHECK_LAST_ERROR macro for improved error handling in CUDA and HIP (#450)

* [Feature] Add TILELANG_CHECK_LAST_ERROR macro for improved error handling in CUDA and HIP

* Introduced TILELANG_CHECK_LAST_ERROR macro to streamline error checking for kernel launches in both CUDA and HIP.
* Updated kernel launch code in wrapper.py to utilize the new macro, enhancing readability and maintainability.
* This change improves error reporting by providing detailed messages when kernel execution fails.

* [Refactor] Standardize error message formatting in TILELANG_CHECK_LAST_ERROR macro

* Updated the TILELANG_CHECK_LAST_ERROR macro in both CUDA and HIP implementations to ensure consistent formatting of error messages.
* Enhanced readability by aligning the error message structure across different platforms, improving maintainability of error handling code.
parent 025929d8
......@@ -35,6 +35,16 @@ using int4_t = int4;
} \
} while (0)
// Checks the CUDA runtime's last-error state, typically right after a kernel
// launch. On failure it writes "<kernel_name>: <error name> - <error string>"
// into `error_buf` (at most ERROR_BUF_SIZE bytes) and makes the *enclosing*
// function return -1.
//
// kernel_name: a C string (string literal or const char *) identifying the
//              kernel being checked.
//
// Requirements at the expansion site: `error_buf` and `ERROR_BUF_SIZE` must be
// in scope, and the enclosing function must be able to `return -1`.
//
// Note: `kernel_name` is passed as a "%s" argument on purpose. Macro
// parameters are NOT substituted inside string literals, so writing
// "kernel_name: ..." would print the literal text "kernel_name" rather than
// the kernel's actual name.
#define TILELANG_CHECK_LAST_ERROR(kernel_name) \
  do { \
    cudaError_t __err = cudaGetLastError(); \
    if (__err != cudaSuccess) { \
      snprintf(error_buf, ERROR_BUF_SIZE, "%s: %s - %s", (kernel_name), \
               cudaGetErrorName(__err), cudaGetErrorString(__err)); \
      return -1; \
    } \
  } while (0)
// abs function for bfloat_t and half_t since there is no implicit conversion
// method
TL_PATCH TL_DEVICE half_t __habs(const half_t x) {
......
......@@ -36,6 +36,16 @@
} \
} while (0)
// Checks the HIP runtime's last-error state, typically right after a kernel
// launch. On failure it writes "<kernel_name>: <error name> - <error string>"
// into `error_buf` (at most ERROR_BUF_SIZE bytes) and makes the *enclosing*
// function return -1.
//
// kernel_name: a C string (string literal or const char *) identifying the
//              kernel being checked.
//
// Requirements at the expansion site: `error_buf` and `ERROR_BUF_SIZE` must be
// in scope, and the enclosing function must be able to `return -1`.
//
// Note: `kernel_name` is passed as a "%s" argument on purpose. Macro
// parameters are NOT substituted inside string literals, so writing
// "kernel_name: ..." would print the literal text "kernel_name" rather than
// the kernel's actual name.
#define TILELANG_CHECK_LAST_ERROR(kernel_name) \
  do { \
    hipError_t __err = hipGetLastError(); \
    if (__err != hipSuccess) { \
      snprintf(error_buf, ERROR_BUF_SIZE, "%s: %s - %s", (kernel_name), \
               hipGetErrorName(__err), hipGetErrorString(__err)); \
      return -1; \
    } \
  } while (0)
#define half _Float16
#define __float2half_rn(x) half(x)
......
......@@ -220,11 +220,7 @@ class TLCUDASourceWrapper(object):
smem_str = 0 if dynamic_smem_buf is None else dynamic_smem_buf
kernel_launch_code += "\t{}<<<{}, {}, {}, stream>>>({});\n".format(
function_name, grid_str, block_str, smem_str, call_args)
kernel_launch_code += "\tcudaError_t err = cudaGetLastError();\n"
kernel_launch_code += "\tif (err != cudaSuccess) {{\n"
kernel_launch_code += f"\t\tsnprintf(error_buf, ERROR_BUF_SIZE, \"{function_name}: %s - %s\", cudaGetErrorName(err), cudaGetErrorString(err));\n"
kernel_launch_code += "\t\treturn -1;\n"
kernel_launch_code += "\t}}\n"
kernel_launch_code += "TILELANG_CHECK_LAST_ERROR(\"{}\");\n".format(function_name)
kernel_launch_code = self.generate_tma_descriptor_args(desc_name_map) + kernel_launch_code
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment