Unverified Commit 7e5b1cd2 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Utils] Add source export, NVCC-based PTX/SASS dump, logging (#1216)

* [Enhancement] Add NVCC support for PTX and SASS generation in TileLang

* Introduced functions to compile CUDA C++ source to PTX and SASS formats, enhancing the ability to generate intermediate representations for CUDA kernels.
* Added default compile options for NVCC, including paths for TileLang templates, CUTLASS, and CUDA includes.
* Implemented methods to export and display generated PTX and SASS code, improving usability for developers working with CUDA targets.
* Updated JITKernel class to integrate new NVCC functionalities for PTX and SASS handling, ensuring compatibility with existing workflows.

* [Fix] Improve error handling in get_sass_from_source function

* Added contextlib to suppress exceptions when removing temporary files, enhancing robustness.
* Fixed formatting of error message for clarity when CUDA tools are not found, ensuring better user feedback.

* [Enhancement] Preserve user flags in NVCC compile options

* Updated the default_compile_options function to preserve user-specified compile flags, including repeated tokens, by utilizing shlex for proper tokenization.
* This enhancement improves the flexibility and accuracy of NVCC compile options, ensuring that all user inputs are correctly handled.
parent 2bc45bc3
...@@ -7,7 +7,10 @@ from __future__ import annotations ...@@ -7,7 +7,10 @@ from __future__ import annotations
import os import os
import subprocess import subprocess
import warnings import warnings
from tilelang.env import CUDA_HOME import contextlib
from tilelang.env import CUDA_HOME, CUTLASS_INCLUDE_DIR, TILELANG_TEMPLATE_PATH
import shutil
import tempfile
import tvm_ffi import tvm_ffi
from tilelang import tvm as tvm from tilelang import tvm as tvm
from tvm.target import Target from tvm.target import Target
...@@ -125,6 +128,154 @@ def compile_cuda(code, ...@@ -125,6 +128,154 @@ def compile_cuda(code,
return data return data
def default_compile_options(compile_flags: list[str] | None = None) -> list[str]:
"""
Build a set of default NVCC compile options for TileLang generated sources.
Includes C++ standard and common include paths (TileLang templates, CUTLASS,
CUDA include). Merges user-provided compile flags if given.
Parameters
----------
compile_flags : Optional[List[str]]
Additional flags to include. Items are split on whitespace.
Returns
-------
List[str]
A list of flags suitable for NVCC's command line.
"""
options: list[str] = ["-std=c++17"]
try:
if TILELANG_TEMPLATE_PATH:
options.append(f"-I{TILELANG_TEMPLATE_PATH}")
except Exception:
pass
try:
if CUTLASS_INCLUDE_DIR:
options.append(f"-I{CUTLASS_INCLUDE_DIR}")
except Exception:
pass
try:
if CUDA_HOME:
options.append(f"-I{os.path.join(CUDA_HOME, 'include')}")
except Exception:
pass
# Preserve user flags exactly, including repeated tokens required by NVCC
# (e.g., multiple "-gencode" pairs or repeated "-Xcompiler" entries).
if compile_flags:
import shlex
for flag in compile_flags:
# Split each string like a shell would, preserving quoted args
tokens = shlex.split(flag) if isinstance(flag, str) else [str(flag)]
options.extend(tokens)
return options
def get_ptx_from_source(code: str,
compile_flags: list[str] | None = None,
verbose: bool = False) -> str:
"""
Compile CUDA C++ source to PTX using NVCC and return as text.
Parameters
----------
code : str
CUDA C++ kernel source code.
compile_flags : Optional[List[str]]
Additional flags merged with defaults.
verbose : bool
Print NVCC output when True.
Returns
-------
str
PTX text.
"""
opts = default_compile_options(compile_flags)
ptx_bytes = compile_cuda(code, target_format="ptx", options=opts, verbose=verbose)
try:
return ptx_bytes.decode("utf-8")
except Exception:
return str(ptx_bytes)
def _find_tool(name: str) -> str | None:
"""Find a CUDA binary in PATH or under CUDA_HOME/bin."""
path = shutil.which(name)
if path:
return path
if CUDA_HOME:
candidate = os.path.join(CUDA_HOME, "bin", name)
if os.path.exists(candidate):
return candidate
return None
def get_sass_from_source(code: str,
compile_flags: list[str] | None = None,
verbose: bool = False) -> str:
"""
Compile CUDA C++ source to CUBIN and disassemble to SASS.
Uses nvdisasm if available; otherwise falls back to cuobjdump.
Parameters
----------
code : str
CUDA C++ kernel source code.
compile_flags : Optional[List[str]]
Additional flags merged with defaults.
verbose : bool
Print tool outputs when True.
Returns
-------
str
SASS text.
"""
opts = default_compile_options(compile_flags)
cubin_bytes = compile_cuda(code, target_format="cubin", options=opts, verbose=verbose)
# Write to a temp .cubin file
with tempfile.NamedTemporaryFile(suffix=".cubin", delete=False) as tmp:
tmp.write(cubin_bytes)
cubin_path = tmp.name
# Try disassembly tools (prefer nvdisasm, fallback cuobjdump)
cand_nvdisasm = _find_tool("nvdisasm")
cand_cuobjdump = _find_tool("cuobjdump")
if not cand_nvdisasm and not cand_cuobjdump:
raise RuntimeError(
"Cannot find 'nvdisasm' or 'cuobjdump'. Please ensure CUDA toolkit is installed and in PATH."
)
last_err: str | None = None
try:
# Attempt nvdisasm first
tools_to_try = []
if cand_nvdisasm:
tools_to_try.append(("nvdisasm", [cand_nvdisasm, cubin_path]))
if cand_cuobjdump:
tools_to_try.append(("cuobjdump", [cand_cuobjdump, "--dump-sass", cubin_path]))
for tool_name, cmd in tools_to_try:
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
out, _ = proc.communicate()
text = py_str(out)
if verbose:
print(f"[{tool_name}] output:\n{text}")
if proc.returncode == 0 and text.strip():
return text
last_err = f"{tool_name} rc={proc.returncode}, output:\n{text}"
# If we reach here, all attempts failed
raise RuntimeError(f"SASS disassembly failed. Tried tools: "
f"{', '.join(name for name, _ in tools_to_try)}\n{last_err or ''}")
finally:
with contextlib.suppress(Exception):
os.remove(cubin_path)
def find_cuda_path(): def find_cuda_path():
"""Utility function to find cuda path """Utility function to find cuda path
......
...@@ -6,7 +6,7 @@ try: ...@@ -6,7 +6,7 @@ try:
except ImportError: # Python < 3.10 except ImportError: # Python < 3.10
from typing_extensions import ParamSpec from typing_extensions import ParamSpec
from tilelang.jit.adapter.utils import is_metal_target from tilelang.jit.adapter.utils import is_metal_target, is_cuda_target
from tvm.target import Target from tvm.target import Target
from tvm.tir import PrimFunc from tvm.tir import PrimFunc
...@@ -18,7 +18,9 @@ from tilelang.jit.adapter import (BaseKernelAdapter, CtypesKernelAdapter, Cython ...@@ -18,7 +18,9 @@ from tilelang.jit.adapter import (BaseKernelAdapter, CtypesKernelAdapter, Cython
NVRTCKernelAdapter, TorchDLPackKernelAdapter, MetalKernelAdapter) NVRTCKernelAdapter, TorchDLPackKernelAdapter, MetalKernelAdapter)
from tilelang.profiler import Profiler, TensorSupplyType from tilelang.profiler import Profiler, TensorSupplyType
from tilelang.utils.target import determine_target from tilelang.utils.target import determine_target
from tilelang.contrib import nvcc as tl_nvcc
import logging import logging
import os
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -412,6 +414,110 @@ class JITKernel(Generic[_P, _T]): ...@@ -412,6 +414,110 @@ class JITKernel(Generic[_P, _T]):
def run_once(self, func: Callable | None = None) -> None: def run_once(self, func: Callable | None = None) -> None:
return self.get_profiler().run_once(func) return self.get_profiler().run_once(func)
def show_source(self, which: Literal["kernel", "host", "both"] = "kernel") -> None:
"""
Print generated source code to stdout.
Parameters
----------
which : Literal["kernel", "host", "both"], optional
Select which source to print. Defaults to "kernel".
Examples
--------
>>> jit_kernel.show_source() # print kernel source
>>> jit_kernel.show_source("host") # print host source
>>> jit_kernel.show_source("both") # print both sources
"""
try:
if which == "kernel":
src = self.get_kernel_source()
print(src)
elif which == "host":
src = self.get_host_source()
# Host is generally C/C++
print(src)
elif which == "both":
print("===== Kernel Source =====")
ksrc = self.get_kernel_source()
print(ksrc)
print("===== Host Source =====")
hsrc = self.get_host_source()
print(hsrc)
else:
raise ValueError(f"Unknown option for 'which': {which}")
except Exception as e:
logger.error(f"Failed to show source code: {e}")
def export_sources(self, kernel_path: str | None = None, host_path: str | None = None) -> None:
"""
Export generated source code to files.
Parameters
----------
kernel_path : Optional[str]
Destination file path to write the kernel source. If None, skips writing kernel code.
host_path : Optional[str]
Destination file path to write the host source. If None, skips writing host code.
Examples
--------
>>> jit_kernel.export_sources(kernel_path="/tmp/kernel.cu")
>>> jit_kernel.export_sources(host_path="/tmp/host.cc")
>>> jit_kernel.export_sources(
... kernel_path="/tmp/kernel.cu",
... host_path="/tmp/host.cc",
... )
"""
if kernel_path is None and host_path is None:
raise ValueError("At least one of kernel_path or host_path must be provided.")
try:
if kernel_path is not None:
dir_path = os.path.dirname(kernel_path)
if dir_path:
os.makedirs(dir_path, exist_ok=True)
with open(kernel_path, 'w') as f:
f.write(self.get_kernel_source())
if host_path is not None:
dir_path = os.path.dirname(host_path)
if dir_path:
os.makedirs(dir_path, exist_ok=True)
with open(host_path, 'w') as f:
f.write(self.get_host_source())
except Exception as e:
logger.error(f"Failed to export sources: {e}")
# Backward compatibility alias (deprecated)
def print_source_code(self,
which: Literal["kernel", "host", "both"] = "kernel",
file: str | None = None) -> None:
"""
Deprecated: use show_source() or export_sources() instead.
Parameters
----------
which : Literal["kernel", "host", "both"], optional
Kept for backward compatibility with printing behavior.
file : Optional[str]
If provided, behaves like export_sources(kernel_path=file).
Examples
--------
>>> # New API (preferred)
>>> jit_kernel.show_source("both")
>>> jit_kernel.export_sources(kernel_path="/tmp/kernel.cu")
>>> # Old API (still works but deprecated)
>>> jit_kernel.print_source_code(file="/tmp/kernel.cu")
"""
logger.warning(
"print_source_code is deprecated; use show_source() or export_sources() instead.")
if file is not None:
# Historical behavior wrote only kernel source when file provided
self.export_sources(kernel_path=file)
else:
self.show_source(which=which)
def update_tuner_result(self, latency: float, config: dict[str, Any], def update_tuner_result(self, latency: float, config: dict[str, Any],
ref_latency: float) -> JITKernel: ref_latency: float) -> JITKernel:
""" """
...@@ -491,3 +597,131 @@ class JITKernel(Generic[_P, _T]): ...@@ -491,3 +597,131 @@ class JITKernel(Generic[_P, _T]):
# Export the compiled kernel function to a shared library file. # Export the compiled kernel function to a shared library file.
self.rt_module.export_library(kernel_file) self.rt_module.export_library(kernel_file)
def _get_ptx(self, verbose: bool | None = None) -> str:
"""
Compile and return PTX for the current kernel (CUDA only).
Parameters
----------
verbose : Optional[bool]
Whether to enable verbose NVRTC logs. Defaults to self.verbose.
Returns
-------
str
The compiled PTX text.
"""
if not is_cuda_target(self.target):
raise ValueError("PTX is only available for CUDA targets.")
# Prefer NVCC for PTX generation via contrib helper
code = self.get_kernel_source()
if verbose is None:
verbose = self.verbose
# Ensure target is set so nvcc picks correct arch via Target.current()
with self.target:
return tl_nvcc.get_ptx_from_source(
code, compile_flags=self.compile_flags, verbose=verbose)
def show_ptx(self) -> None:
"""
Print compiled PTX for the kernel (CUDA only).
Examples
--------
>>> jit_kernel.show_ptx()
"""
try:
ptx = self._get_ptx()
print(ptx)
except Exception as e:
logger.error(f"Failed to show PTX: {e}")
def export_ptx(self, path: str) -> None:
"""
Export compiled PTX to a file (CUDA only).
Parameters
----------
path : str
Destination file path to write PTX.
Examples
--------
>>> jit_kernel.export_ptx("/tmp/kernel.ptx")
"""
if not path:
raise ValueError("path must be provided to export PTX")
try:
ptx = self._get_ptx()
dir_path = os.path.dirname(path)
if dir_path:
os.makedirs(dir_path, exist_ok=True)
with open(path, "w") as f:
f.write(ptx)
logger.info(f"PTX saved to {os.path.abspath(path)}")
except Exception as e:
logger.error(f"Failed to export PTX: {e}")
def _get_sass(self, verbose: bool | None = None) -> str:
"""
Compile and return SASS for the current kernel (CUDA only).
Parameters
----------
verbose : Optional[bool]
Whether to enable verbose tool logs. Defaults to self.verbose.
Returns
-------
str
The disassembled SASS text.
"""
if not is_cuda_target(self.target):
raise ValueError("SASS is only available for CUDA targets.")
code = self.get_kernel_source()
if verbose is None:
verbose = self.verbose
with self.target:
return tl_nvcc.get_sass_from_source(
code, compile_flags=self.compile_flags, verbose=verbose)
def show_sass(self) -> None:
"""
Print disassembled SASS for the kernel (CUDA only).
Examples
--------
>>> jit_kernel.show_sass()
"""
try:
sass = self._get_sass()
print(sass)
except Exception as e:
logger.error(f"Failed to show SASS: {e}")
def export_sass(self, path: str) -> None:
"""
Export disassembled SASS to a file (CUDA only).
Parameters
----------
path : str
Destination file path to write SASS.
Examples
--------
>>> jit_kernel.export_sass("/tmp/kernel.sass")
"""
if not path:
raise ValueError("path must be provided to export SASS")
try:
sass = self._get_sass()
dir_path = os.path.dirname(path)
if dir_path:
os.makedirs(dir_path, exist_ok=True)
with open(path, "w") as f:
f.write(sass)
logger.info(f"SASS saved to {os.path.abspath(path)}")
except Exception as e:
logger.error(f"Failed to export SASS: {e}")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment