Unverified commit 551ac60d, authored by Lei Wang and committed by GitHub
Browse files

[Enhancement] Enhance CUDA compilation by integrating pass context configuration (#1283)

- Updated the `tilelang_callback_cuda_compile` function to accept a `pass_config` parameter, allowing for more flexible compilation options.
- Introduced handling for fast math and PTXAS options based on the provided pass configuration.
- Modified the CUDA build process in `rt_mod_cuda.cc` to utilize the current pass context, improving the integration of compilation settings.
- Refactored NVCC command construction to use a dedicated function for better clarity and maintainability.
parent cd681e63
......@@ -2,6 +2,7 @@
#include "runtime/cuda/cuda_module.h"
#include "runtime/pack_args.h"
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/transform.h>
namespace tvm {
namespace codegen {
......@@ -66,7 +67,10 @@ ffi::Module BuildTileLangCUDA(IRModule mod, Target target) {
std::string ptx;
if (const auto f =
ffi::Function::GetGlobal("tilelang_callback_cuda_compile")) {
ptx = (*f)(code, target).cast<std::string>();
// Fetch current pass context config and pass into the compile callback
tvm::transform::PassContext pass_ctx =
tvm::transform::PassContext::Current();
ptx = (*f)(code, target, pass_ctx->config).cast<std::string>();
if (ptx[0] != '/')
fmt = "cubin";
} else {
......
......@@ -78,7 +78,7 @@ def compile_cuda(code,
out_file.write(code)
file_target = path_target if path_target else temp_target
cmd = ["nvcc"]
cmd = [get_nvcc_compiler()]
cmd += [f"--{target_format}", "-O3"]
if kernels_output_dir is not None:
cmd += ["-lineinfo"]
......@@ -332,13 +332,6 @@ def get_cuda_version(cuda_path=None):
raise RuntimeError("Cannot read cuda version file")
@tvm_ffi.register_global_func("tilelang_callback_cuda_compile", override=True)
def tilelang_callback_cuda_compile(code, target): # pylint: disable=unused-argument
"""use nvcc to generate fatbin code for better optimization"""
ptx = compile_cuda(code, target_format="fatbin")
return ptx
@tvm_ffi.register_global_func("tilelang_callback_libdevice_path", override=True)
def find_libdevice_path(arch):
"""Utility function to find libdevice
......
......@@ -11,6 +11,8 @@ import tvm_ffi
from tvm.ir import CallingConv
from tvm.target import Target
from tilelang.contrib import hipcc, nvcc
from tilelang.transform import PassConfigKey
from tilelang.utils.deprecated import deprecated_warning
from tilelang.engine.param import KernelParam, CompiledArtifact
from tilelang.utils.target import determine_target
from tilelang.engine.phase import (
......@@ -54,7 +56,7 @@ def get_host_call(is_device_c: bool = False) -> Callable[[tir.PrimFunc], bool]:
@tvm_ffi.register_global_func("tilelang_callback_cuda_compile", override=True)
def tilelang_callback_cuda_compile(code, target):
def tilelang_callback_cuda_compile(code, target, pass_config=None):
project_root = osp.join(osp.dirname(__file__), "../..")
if "TL_TEMPLATE_PATH" in os.environ:
tl_template_path = os.environ["TL_TEMPLATE_PATH"]
......@@ -69,21 +71,37 @@ def tilelang_callback_cuda_compile(code, target):
target_arch = nvcc.get_target_arch(nvcc.get_target_compute_version(target))
arch = [f"-arch=sm_{target_arch}"]
format = "cubin"
compile_format = "cubin"
# Read pass-config keys (string-valued) like in jit.adapter.libgen.compile_lib
cfg = pass_config or {}
if cfg.get(PassConfigKey.TL_DISABLE_FAST_MATH.value, False):
deprecated_warning("TL_DISABLE_FAST_MATH", "TL_ENABLE_FAST_MATH", "0.1.7")
disable_fast_math = bool(cfg.get(PassConfigKey.TL_DISABLE_FAST_MATH.value, True))
enable_fast_math = not disable_fast_math
else:
enable_fast_math = bool(cfg.get(PassConfigKey.TL_ENABLE_FAST_MATH.value, False))
ptxas_usage_level = cfg.get(PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL.value, None)
verbose_ptxas_output = bool(cfg.get(PassConfigKey.TL_ENABLE_PTXAS_VERBOSE_OUTPUT.value, False))
options = [
"-std=c++17",
"-I" + tl_template_path,
"-I" + cutlass_path,
]
if enable_fast_math:
options.append("--use_fast_math")
if ptxas_usage_level is not None:
options.append(f"--ptxas-options=--register-usage-level={ptxas_usage_level}")
if verbose_ptxas_output:
options.append("--ptxas-options=--verbose")
# printing out number of registers
debug_option = "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage"
ptx = nvcc.compile_cuda(
code,
format,
compile_format,
arch,
options=[
"-std=c++17",
debug_option,
"--use_fast_math",
"-I" + tl_template_path,
"-I" + cutlass_path,
],
options=options,
verbose=False,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment