Unverified Commit 551ac60d authored by Lei Wang, committed by GitHub

[Enhancement] Enhance CUDA compilation by integrating pass context configuration (#1283)

- Updated the `tilelang_callback_cuda_compile` function to accept a `pass_config` parameter, allowing for more flexible compilation options.
- Introduced handling for fast math and PTXAS options based on the provided pass configuration.
- Modified the CUDA build process in `rt_mod_cuda.cc` to utilize the current pass context, improving the integration of compilation settings.
- Refactored NVCC command construction to use a dedicated function for better clarity and maintainability.
parent cd681e63
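
For context, this change means any config set on the active `PassContext` now reaches the nvcc compile callback without extra plumbing. A minimal usage sketch, assuming TileLang registers these keys with TVM's `PassContext` (the kernel-building entry point is elided):

```python
# Hedged sketch: pass configs set on the active PassContext now reach the
# nvcc compile callback. PassConfigKey comes from the diff below; whether
# a given key is registered with TVM's PassContext is an assumption here.
import tvm
from tilelang.transform import PassConfigKey

with tvm.transform.PassContext(
        config={PassConfigKey.TL_ENABLE_FAST_MATH.value: True}):
    # Any CUDA build inside this scope hits BuildTileLangCUDA, which now
    # forwards PassContext::Current()->config to
    # tilelang_callback_cuda_compile, so nvcc gets --use_fast_math.
    ...
```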
rt_mod_cuda.cc
@@ -2,6 +2,7 @@
 #include "runtime/cuda/cuda_module.h"
 #include "runtime/pack_args.h"
 #include <tvm/ffi/reflection/registry.h>
+#include <tvm/ir/transform.h>

 namespace tvm {
 namespace codegen {
@@ -66,7 +67,10 @@ ffi::Module BuildTileLangCUDA(IRModule mod, Target target) {
   std::string ptx;
   if (const auto f =
           ffi::Function::GetGlobal("tilelang_callback_cuda_compile")) {
-    ptx = (*f)(code, target).cast<std::string>();
+    // Fetch current pass context config and pass into the compile callback
+    tvm::transform::PassContext pass_ctx =
+        tvm::transform::PassContext::Current();
+    ptx = (*f)(code, target, pass_ctx->config).cast<std::string>();
     if (ptx[0] != '/')
       fmt = "cubin";
   } else {
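
On the Python side of the FFI boundary, the callback now sees a third, dict-like argument. A hedged sketch of the contract (the callback name here is hypothetical; the real registration appears in the Python diff below):

```python
import tvm_ffi

# Hypothetical callback mirroring the new contract: the third argument is
# the Map taken from PassContext::Current()->config, dict-like in Python.
@tvm_ffi.register_global_func("demo_cuda_compile", override=True)
def demo_cuda_compile(code, target, pass_config=None):
    cfg = pass_config or {}
    fast_math = bool(cfg.get("tl.enable_fast_math", False))  # key string assumed
    # A real callback returns compiled ptx/cubin; returning the source
    # here (which starts with "//") keeps BuildTileLangCUDA's default fmt.
    return code
```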
tilelang/contrib/nvcc.py
@@ -78,7 +78,7 @@ def compile_cuda(code,
         out_file.write(code)

     file_target = path_target if path_target else temp_target
-    cmd = ["nvcc"]
+    cmd = [get_nvcc_compiler()]
     cmd += [f"--{target_format}", "-O3"]
     if kernels_output_dir is not None:
         cmd += ["-lineinfo"]
@@ -332,13 +332,6 @@ def get_cuda_version(cuda_path=None):
     raise RuntimeError("Cannot read cuda version file")


-@tvm_ffi.register_global_func("tilelang_callback_cuda_compile", override=True)
-def tilelang_callback_cuda_compile(code, target):  # pylint: disable=unused-argument
-    """use nvcc to generate fatbin code for better optimization"""
-    ptx = compile_cuda(code, target_format="fatbin")
-    return ptx
-
-
 @tvm_ffi.register_global_func("tilelang_callback_libdevice_path", override=True)
 def find_libdevice_path(arch):
     """Utility function to find libdevice
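
The `cmd = [get_nvcc_compiler()]` change above resolves a concrete compiler path instead of relying on a bare `nvcc` on `PATH`. The real helper lives in `tilelang.contrib.nvcc` and may behave differently; a hedged sketch of what such a resolver typically does:

```python
import os
import shutil

# Sketch only: the actual get_nvcc_compiler() in tilelang.contrib.nvcc
# may resolve differently; this shows the usual CUDA_HOME-then-PATH order.
def get_nvcc_compiler_sketch() -> str:
    cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
    if cuda_home:
        candidate = os.path.join(cuda_home, "bin", "nvcc")
        if os.path.isfile(candidate):
            return candidate
    return shutil.which("nvcc") or "nvcc"
```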
tilelang/engine/lower.py
@@ -11,6 +11,8 @@ import tvm_ffi
 from tvm.ir import CallingConv
 from tvm.target import Target
 from tilelang.contrib import hipcc, nvcc
+from tilelang.transform import PassConfigKey
+from tilelang.utils.deprecated import deprecated_warning
 from tilelang.engine.param import KernelParam, CompiledArtifact
 from tilelang.utils.target import determine_target
 from tilelang.engine.phase import (
@@ -54,7 +56,7 @@ def get_host_call(is_device_c: bool = False) -> Callable[[tir.PrimFunc], bool]:


 @tvm_ffi.register_global_func("tilelang_callback_cuda_compile", override=True)
-def tilelang_callback_cuda_compile(code, target):
+def tilelang_callback_cuda_compile(code, target, pass_config=None):
     project_root = osp.join(osp.dirname(__file__), "../..")
     if "TL_TEMPLATE_PATH" in os.environ:
         tl_template_path = os.environ["TL_TEMPLATE_PATH"]
@@ -69,21 +71,37 @@ def tilelang_callback_cuda_compile(code, target):
     target_arch = nvcc.get_target_arch(nvcc.get_target_compute_version(target))
     arch = [f"-arch=sm_{target_arch}"]
-    format = "cubin"
+    compile_format = "cubin"

-    # printing out number of registers
-    debug_option = "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage"
+    # Read pass-config keys (string-valued) like in jit.adapter.libgen.compile_lib
+    cfg = pass_config or {}
+    if cfg.get(PassConfigKey.TL_DISABLE_FAST_MATH.value, False):
+        deprecated_warning("TL_DISABLE_FAST_MATH", "TL_ENABLE_FAST_MATH", "0.1.7")
+        disable_fast_math = bool(cfg.get(PassConfigKey.TL_DISABLE_FAST_MATH.value, True))
+        enable_fast_math = not disable_fast_math
+    else:
+        enable_fast_math = bool(cfg.get(PassConfigKey.TL_ENABLE_FAST_MATH.value, False))
+
+    ptxas_usage_level = cfg.get(PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL.value, None)
+    verbose_ptxas_output = bool(cfg.get(PassConfigKey.TL_ENABLE_PTXAS_VERBOSE_OUTPUT.value, False))
+
+    options = [
+        "-std=c++17",
+        "-I" + tl_template_path,
+        "-I" + cutlass_path,
+    ]
+    if enable_fast_math:
+        options.append("--use_fast_math")
+    if ptxas_usage_level is not None:
+        options.append(f"--ptxas-options=--register-usage-level={ptxas_usage_level}")
+    if verbose_ptxas_output:
+        options.append("--ptxas-options=--verbose")
+
     ptx = nvcc.compile_cuda(
         code,
-        format,
+        compile_format,
         arch,
-        options=[
-            "-std=c++17",
-            debug_option,
-            "--use_fast_math",
-            "-I" + tl_template_path,
-            "-I" + cutlass_path,
-        ],
+        options=options,
         verbose=False,
     )
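
To see the flag assembly in isolation, here is a standalone restatement of the branches above with `pass_config` stubbed as a plain dict so it runs outside TVM. The key strings are assumptions standing in for the `PassConfigKey.*.value` constants:

```python
# Standalone restatement of the option-assembly logic from the diff above,
# with pass_config stubbed directly (key strings assumed).
pass_config = {
    "tl.enable_fast_math": True,
    "tl.ptxas_register_usage_level": "8",
    "tl.enable_ptxas_verbose_output": False,
}

options = ["-std=c++17"]
if pass_config.get("tl.enable_fast_math", False):
    options.append("--use_fast_math")
level = pass_config.get("tl.ptxas_register_usage_level")
if level is not None:
    options.append(f"--ptxas-options=--register-usage-level={level}")
if pass_config.get("tl.enable_ptxas_verbose_output", False):
    options.append("--ptxas-options=--verbose")

print(options)
# ['-std=c++17', '--use_fast_math', '--ptxas-options=--register-usage-level=8']
```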