Commit 8c94de32 authored by Lei Wang, committed by LeiWang1999

Refactor matrix multiplication benchmark and autotuner logging (#263)

- Updated `ref_program` in `benchmark_matmul.py` to remove the unused parameter `C`, simplifying the function signature.
- Changed logging level in `autotuner/__init__.py` from `INFO` to `DEBUG` for more detailed logging during autotuning.
- Modified the error handling in the autotuner to provide clearer messages and log errors at the debug level.
- Enhanced error reporting in the JIT adapter by adding detailed context to error messages in `cython_wrapper.pyx` when kernel calls fail.
parent 927e50d9
`benchmark_matmul.py`:

```diff
@@ -11,7 +11,7 @@ logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
 
-def ref_program(A, B, C):
+def ref_program(A, B):
     """
     A reference matrix multiplication program, used to compare performance.
@@ -289,8 +289,7 @@ if __name__ == "__main__":
     args = parser.parse_args()
     M, N, K = args.m, args.n, args.k
-    # with_roller = args.with_roller
-    with_roller = True
+    with_roller = args.with_roller
     # Compute total floating-point operations to measure throughput
     total_flops = 2 * M * N * K
```
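The first hunk drops the unused output parameter, so the reference now returns the product directly. A minimal sketch of the updated function, assuming the benchmark's usual PyTorch reference (the real version may add dtype casts):

```python
import torch

def ref_program(A, B):
    # Reference matmul used to check correctness and to compare the
    # generated TileLang kernel's performance against.
    return A @ B
```

The second hunk re-enables the command-line switch: `with_roller` is read from the parsed arguments again instead of being hard-coded to `True`.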
`autotuner/__init__.py`:

```diff
@@ -17,7 +17,7 @@ logger = logging.getLogger(__name__)
 logging.basicConfig(
     filename='autotuner.log',
     filemode='w',
-    level=logging.INFO,
+    level=logging.DEBUG,
     format='%(asctime)s %(levelname)s:%(message)s')
@@ -149,15 +149,14 @@ class Autotuner:
         for i in progress_bar:
             jit_context, config = results_with_configs[i]
             try:
-                # Use ThreadPoolExecutor to enforce timeout on target_fn execution
-                with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
-                    future = executor.submit(target_fn, jit_context)
-                    latency, ref_latency = future.result(timeout=self.timeout)
-            except concurrent.futures.TimeoutError:
-                logger.debug(f"Timeout exceeded for config {config}. Skipping this configuration.")
-                continue
+                # Cannot use ThreadPoolExecutor to enforce a timeout on target_fn
+                # execution, because TMA init may behave strangely with one thread.
+                latency, ref_latency = target_fn(jit_context)
             except Exception as e:
-                logger.debug(f"An error occurred while testing config {config}: {e}")
+                logger.info(
+                    f"An error occurred while testing config {config}, check autotuner.log for more details"
+                )
+                logger.debug(f"Error: {e}")
                 continue
             logging.debug(f"Config {config} latency: {latency} at index {i}")
```
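Dropping the executor-based timeout is worth a note: a `ThreadPoolExecutor` timeout never cancels work that has already started, it merely stops waiting for the future while the worker thread runs on, and it also moves profiling off the main thread, which the new comment says interferes with TMA initialization. A small self-contained illustration of the first point, with a hypothetical `slow_fn` standing in for `target_fn`:

```python
import concurrent.futures
import time

def slow_fn():
    time.sleep(5)  # stands in for a long-running profiling call
    return "done"

with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
    future = executor.submit(slow_fn)
    try:
        future.result(timeout=0.1)
    except concurrent.futures.TimeoutError:
        # The timeout only stops the wait; slow_fn keeps running in the
        # worker thread, and the executor shutdown blocks until it ends.
        print("timed out, but the call is still running")
```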
`cython_wrapper.pyx`:

```diff
@@ -141,7 +141,10 @@ cdef class CythonKernelWrapper:
         call_args.append(ctypes.c_void_p(stream))
 
         # Execute the kernel
-        self.lib.call(*call_args)
+        result = self.lib.call(*call_args)
+        if result != 0:
+            error_msg = self.lib.get_last_error().decode('utf-8')
+            raise RuntimeError(f"Kernel call failed: {error_msg}")
 
         # Return output tensor(s)
         if len(self.result_idx) == 1:
```
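The wrapper-side check is one half of a status-code contract with the generated host code shown below: `call` now returns an `int`, and on failure a message is left in a buffer that `get_last_error` exposes. A hedged ctypes sketch of that contract (the real wrapper is Cython; the library path and argument handling here are placeholders):

```python
import ctypes

def call_kernel(lib_path, *call_args):
    # lib_path points at the generated shared library (placeholder here);
    # call_args would be the tensor pointers and stream, as in the wrapper.
    lib = ctypes.CDLL(lib_path)
    lib.call.restype = ctypes.c_int
    lib.get_last_error.restype = ctypes.c_char_p
    result = lib.call(*call_args)
    if result != 0:
        error_msg = lib.get_last_error().decode('utf-8')
        raise RuntimeError(f"Kernel call failed: {error_msg}")
```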
JIT adapter host-code templates (file not named on this page):

```diff
@@ -32,8 +32,9 @@ extern "C" int init() {{
 """
 
 PREDEF_HOST_FUNC = """
-extern "C" void call({}) {{
+extern "C" int call({}) {{
 {}
+return 0;
 }}
 """
@@ -50,12 +51,28 @@ TMA_DESC_INIT_FUNC = """
 \tCUtensorMapSwizzle {0}_swizzle= (CUtensorMapSwizzle){9};
 \tCUtensorMapL2promotion {0}_l2Promotion= (CUtensorMapL2promotion){10};
 \tCUtensorMapFloatOOBfill {0}_oobFill= (CUtensorMapFloatOOBfill){11};
-\tCUresult {0}_result = cuTensorMapEncodeTiled(
+\tCUresult {0}_result = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)(
 \t\t&{0}, {0}_type, {0}_tensorRank, {0}_globalAddress, {0}_globalDim, {0}_globalStride + 1, {0}_boxDim, {0}_elementStrides, {0}_interleave, {0}_swizzle, {0}_l2Promotion, {0}_oobFill);
 \tif ({0}_result != CUDA_SUCCESS) {{
-\t\tprintf("Failed to initialize the TMA descriptor {0} with error code %d\\n", {0}_result);
-\t\texit(-1);
-\t}}
+\t\tstd::stringstream ss;
+\t\tss << "TMA Desc Addr: " << &{0}
+\t\t   << "\\nformat " << {0}_type
+\t\t   << "\\ndim " << {0}_tensorRank
+\t\t   << "\\ngmem_address " << {0}_globalAddress
+\t\t   << "\\nglobalDim " << {0}_globalDim
+\t\t   << "\\nglobalStrides " << {0}_globalStride + 1
+\t\t   << "\\nboxDim " << {0}_boxDim
+\t\t   << "\\nelementStrides " << {0}_elementStrides
+\t\t   << "\\ninterleave " << {0}_interleave
+\t\t   << "\\nswizzle " << {0}_swizzle
+\t\t   << "\\nl2Promotion " << {0}_l2Promotion
+\t\t   << "\\noobFill " << {0}_oobFill
+\t\t   << "\\nError: Failed to initialize the TMA descriptor {0}";
+\t\tsnprintf(error_buf, ERROR_BUF_SIZE, "%s", ss.str().c_str());
+\t\treturn -1;
+\t}}
 }}
 """
```
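Because these templates go through `str.format`, literal braces in the generated C++ are doubled (`{{`, `}}`) and positional slots like `{0}` name the descriptor. A hypothetical expansion of the updated `PREDEF_HOST_FUNC`, with an illustrative argument list and body rather than the adapter's real output:

```python
PREDEF_HOST_FUNC = """
extern "C" int call({}) {{
{}
return 0;
}}
"""

# Illustrative only: the TileLang JIT adapter generates these strings.
args = "void* A, void* B, void* C, void* stream"
body = "\tmain_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(A, B, C);"
print(PREDEF_HOST_FUNC.format(args, body))
```

Note that the TMA template's failure path now writes the diagnostic into `error_buf` via `snprintf` and returns `-1` instead of exiting the process; `error_buf` and `ERROR_BUF_SIZE` are presumably defined in the generated prelude, with `get_last_error` returning that buffer to the Cython wrapper.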