Commit 8c94de32 authored by Lei Wang, committed by LeiWang1999
Browse files

Refactor matrix multiplication benchmark and autotuner logging (#263)

- Updated `ref_program` in `benchmark_matmul.py` to remove the unused parameter `C`, simplifying the function signature.
- Changed logging level in `autotuner/__init__.py` from `INFO` to `DEBUG` for more detailed logging during autotuning.
- Modified the error handling in the autotuner to provide clearer messages and log errors at the debug level.
- Enhanced error reporting in the JIT adapter by adding detailed context to error messages in `cython_wrapper.pyx` when kernel calls fail.
parent 927e50d9
......@@ -11,7 +11,7 @@ logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
def ref_program(A, B, C):
def ref_program(A, B):
"""
A reference matrix multiplication program, used to compare performance.
......@@ -289,8 +289,7 @@ if __name__ == "__main__":
args = parser.parse_args()
M, N, K = args.m, args.n, args.k
# with_roller = args.with_roller
with_roller = True
with_roller = args.with_roller
# Compute total floating-point operations to measure throughput
total_flops = 2 * M * N * K
......
......@@ -17,7 +17,7 @@ logger = logging.getLogger(__name__)
logging.basicConfig(
filename='autotuner.log',
filemode='w',
level=logging.INFO,
level=logging.DEBUG,
format='%(asctime)s %(levelname)s:%(message)s')
......@@ -149,15 +149,14 @@ class Autotuner:
for i in progress_bar:
jit_context, config = results_with_configs[i]
try:
# Use ThreadPoolExecutor to enforce timeout on target_fn execution
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(target_fn, jit_context)
latency, ref_latency = future.result(timeout=self.timeout)
except concurrent.futures.TimeoutError:
logger.debug(f"Timeout exceeded for config {config}. Skipping this configuration.")
continue
# Cannot ThreadPoolExecutor to enforce timeout on target_fn execution
# Because tma init may behave strangely with one thread
latency, ref_latency = target_fn(jit_context)
except Exception as e:
logger.debug(f"An error occurred while testing config {config}: {e}")
logger.info(
f"An error occurred while testing config {config}, checkout autotuner.log for more details"
)
logger.debug(f"Error: {e}")
continue
logging.debug(f"Config {config} latency: {latency} at index {i}")
......
......@@ -141,7 +141,10 @@ cdef class CythonKernelWrapper:
call_args.append(ctypes.c_void_p(stream))
# Execute the kernel
self.lib.call(*call_args)
result = self.lib.call(*call_args)
if result != 0:
error_msg = self.lib.get_last_error().decode('utf-8')
raise RuntimeError(f"Kernel call failed: {error_msg}")
# Return output tensor(s)
if len(self.result_idx) == 1:
......
......@@ -32,8 +32,9 @@ extern "C" int init() {{
"""
PREDEF_HOST_FUNC = """
extern "C" void call({}) {{
extern "C" int call({}) {{
{}
return 0;
}}
"""
......@@ -50,12 +51,28 @@ TMA_DESC_INIT_FUNC = """
\tCUtensorMapSwizzle {0}_swizzle= (CUtensorMapSwizzle){9};
\tCUtensorMapL2promotion {0}_l2Promotion= (CUtensorMapL2promotion){10};
\tCUtensorMapFloatOOBfill {0}_oobFill= (CUtensorMapFloatOOBfill){11};
\tCUresult {0}_result = cuTensorMapEncodeTiled(
\tCUresult {0}_result = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)(
&{0}, {0}_type, {0}_tensorRank, {0}_globalAddress, {0}_globalDim, {0}_globalStride + 1, {0}_boxDim, {0}_elementStrides, {0}_interleave, {0}_swizzle, {0}_l2Promotion, {0}_oobFill);
\tif ({0}_result != CUDA_SUCCESS) {{
\t\tprintf("Failed to initialize the TMA descriptor {0} with error code %d\\n", {0}_result);
\t\texit(-1);
\t}}
std::stringstream ss;
ss << "TMA Desc Addr: " << &{0}
<< "\\nformat " << {0}_type
<< "\\ndim " << {0}_tensorRank
<< "\\ngmem_address " << {0}_globalAddress
<< "\\nglobalDim " << {0}_globalDim
<< "\\nglobalStrides " << {0}_globalStride + 1
<< "\\nboxDim " << {0}_boxDim
<< "\\nelementStrides " << {0}_elementStrides
<< "\\ninterleave " << {0}_interleave
<< "\\nswizzle " << {0}_swizzle
<< "\\nl2Promotion " << {0}_l2Promotion
<< "\\noobFill " << {0}_oobFill
<< "\\nError: Failed to initialize the TMA descriptor {0}";
snprintf(error_buf, ERROR_BUF_SIZE, "%s", ss.str().c_str());
return -1;
}}
"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment