Unverified Commit 924225ed authored by Cunxiao Ni's avatar Cunxiao Ni Committed by GitHub
Browse files

[Tool] Provide layout visualization tool (#1353)

* Provide layout visualization tool

Adds a layout visualization tool to TileLang, which helps users understand and debug the layout transformations applied during compilation.

This tool visualizes the memory layout of tensors at different stages of the compilation process, allowing developers to identify potential inefficiencies and optimize their code for better performance.

The visualization can be enabled via a pass config option.

* format

* add layout visual example

* Adds vis extra with matplotlib dependency

* rafactor pass config name

* fix lint

* Enables configurable layout visualization formats

Allows users to specify the output formats (png, pdf, svg) for layout visualization through a pass config option.

This change provides more flexibility in how layout visualizations are generated, allowing users to choose the formats that best suit their needs.

It also fixes a bug where layout visualization was not correctly disabled when the config option was set to "false".

* Adds visual layout inference tool docs

* fix lint

* fix lint

* Rafactor configurable layout visualization formats

* fix lint

* fix typo

* add some comments

* fix lints

* add some warnings for user

* Moves layout visualization

* Refactors layout visualization pass configuration

Updates the layout visualization pass configuration to use boolean flag for enabling and a string for specifying formats.

* Enables multiple layout visualization formats

* Updates layout visualization docs

* Moves layout visualization to analysis
parent f8e7fef5
......@@ -171,6 +171,32 @@ The output messages will include something like:
msg='hello world' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): 0
```
### Visual Layout Inference For TileLang
The **Visual Layout Inference** tool automatically generates visual diagrams that illustrate the mapping between logical indices, thread IDs, and register file locations.
When TileLang performs layout inference, it determines how fragment buffers are distributed across threads. The visual layout tool captures this information and generates:
1. **Textual output**: A human-readable description of the layout mapping
2. **Visual diagrams**: Color-coded plots showing the thread-to-data mapping
The visual layout inference tool is controlled through the `TL_LAYOUT_VISUALIZATION_ENABLE` and `TL_LAYOUT_VISUALIZATION_FORMATS` pass configuration. By default, `TL_LAYOUT_VISUALIZATION_ENABLE` is **disabled** to avoid performance overhead during compilation.
When enabled, `TL_LAYOUT_VISUALIZATION_FORMATS` accepts string values to control output formats:
- "txt": Text output only (same as default)
- "all": Generates all formats (TXT, PDF, PNG, SVG)
- "png": Generate PNG format only
- "pdf": Generate PDF format only
- "svg": Generate SVG format only
- "txt,svg": Generate multiple formats (comma-separated) in addition to text output
The output messages of "txt" will include something like:
```
C_local inferenced layout:
Shape: [32, 32] -> [8]
Thread: _j // 16 * 64 + _i // 16 * 32 + _i % 8 * 4 + _j % 8 // 2
Index: [_j % 16 // 8 * 4 + _i % 16 // 8 * 2 + _j % 2]
```
## Conclusion
By carefully examining intermediate representations (IR) before final code generation—and by leveraging runtime printing through `T.print`—one can quickly diagnose where index calculations, copy logic, or other kernel operations deviate from the intended behavior. This two-pronged approach (inspecting IR transformations and using runtime prints) is often sufficient for resolving generation and correctness issues in TileLang programs.
......
import tilelang
import tilelang.language as T
# use pass_configs to enable layout visualization
@tilelang.jit(
out_idx=[-1],
pass_configs={
tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_ENABLE: True,
tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_FORMATS: "svg"
})
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
@T.prim_func
def gemm(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local)
T.copy(C_local, C[by * block_M, bx * block_N])
return gemm
def main():
kernel = matmul(128, 128, 128, 32, 32, 32)
import torch
a = torch.randn(128, 128).cuda().half()
b = torch.randn(128, 128).cuda().half()
c = kernel(a, b)
ref_c = a @ b
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("All check passed.")
# print the layout visualization result and save figures to ./tmp.
'''
C_local inferenced layout:
Shape: [32, 32] -> [8]
Thread: _j // 16 * 64 + _i // 16 * 32 + _i % 8 * 4 + _j % 8 // 2
Index: [_j % 16 // 8 * 4 + _i % 16 // 8 * 2 + _j % 2]
'''
if __name__ == "__main__":
main()
......@@ -42,6 +42,8 @@ dependencies = [
# mldtypes should be greater than 0.5.1
# if you want to enable fp4
fp4 = ["ml-dtypes>=0.5.1"]
# if you want to enable layout inference visualization
vis = ["matplotlib"]
[build-system]
requires = ["cython>=3.0.0", "scikit-build-core"]
......
......@@ -34,6 +34,8 @@ TVM_REGISTER_PASS_CONFIG_OPTION(kDisableVectorize256, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableWGMMA, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableShuffleElect, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kStorageRewriteDetectInplace, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kLayoutVisualizationEnable, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kLayoutVisualizationFormats, String);
DataType cuTensorMapType() { return DataType::UInt(8, 128); }
......
......@@ -51,6 +51,10 @@ static constexpr const char *kDisableWGMMA = "tl.disable_wgmma";
static constexpr const char *kDisableShuffleElect = "tl.disable_shuffle_elect";
static constexpr const char *kStorageRewriteDetectInplace =
"tl.storage_rewrite_detect_inplace";
static constexpr const char *kLayoutVisualizationEnable =
"tl.layout_visualization_enable";
static constexpr const char *kLayoutVisualizationFormats =
"tl.layout_visualization_formats";
/*!
* \brief Whether to disable dynamic tail split
*
......
......@@ -137,6 +137,7 @@ from . import (
transform, # noqa: F401
language, # noqa: F401
engine, # noqa: F401
tools, # noqa: F401
)
from .autotuner import autotune # noqa: F401
from .transform import PassConfigKey # noqa: F401
......
......@@ -3,3 +3,4 @@
from .ast_printer import ASTPrinter # noqa: F401
from .nested_loop_checker import NestedLoopChecker # noqa: F401
from .fragment_loop_checker import FragmentLoopChecker # noqa: F401
from .layout_visual import LayoutVisual # noqa: F401
import tilelang.language as T
from tvm import tir
from tvm.tir import PyStmtExprVisitor
from tvm.tir.transform import prim_func_pass
from tilelang.tools.plot_layout import plot_layout
def print_fragment_format(layout: T.Fragment) -> str:
"""
Format fragment layout information into a human-readable string.
Parameters
----------
layout : T.Fragment
The fragment layout to format
Returns
-------
str
Formatted string showing shape, thread mapping, and index mapping
"""
if isinstance(layout, T.Fragment):
input_shape = layout.get_input_shape()
output_shape = layout.get_output_shape()
lines = [
f" Shape: {input_shape} -> {output_shape}", f" Thread: {layout.forward_thread}",
f" Index: {layout.forward_index}"
]
print("\n".join(lines))
else:
raise ValueError(f"Expected T.Fragment, but got {type(layout).__name__}")
@tir.functor.visitor
class _LayoutVisualVisitor(PyStmtExprVisitor):
"""
User-friendly pass which visualizes fragment layouts inferred during compilation.
In TileLang, Fragment layouts describe:
- How logical indices (e.g., [i, j]) map to thread IDs
- How logical indices map to register file locations within each thread
- The shape transformation from input dimensions to output dimensions
This pass generates two types of output:
1. Textual output: A human-readable description printed to console
2. Visual diagrams: Color-coded plots saved to files (PDF, PNG, SVG formats)
Configuration:
The pass is controlled by the TL_ENABLE_LAYOUT_VISUALIZATION configuration option.
The configuration accepts string values:
- Empty string or not set: Pass does nothing (default, disabled)
- "png": Generate PNG format only (recommended for quick inspection)
- "pdf": Generate PDF format only (recommended for documentation)
- "svg": Generate SVG format only (recommended for web/vector graphics)
- "all": Generate all formats (PDF, PNG, SVG)
- "png,svg": Generate multiple formats (comma-separated)
"""
def __init__(self, formats: list[str] = ""):
super().__init__()
self.layout_found = []
self.processed_layouts = set()
self.formats_list = [f for f in formats if f != "txt"]
def visit_block_(self, op: tir.Block) -> None:
if "layout_map" in op.annotations:
layout_map = op.annotations["layout_map"]
for key, layout in layout_map.items():
if isinstance(layout, T.Fragment):
layout_id = str(layout)
if layout_id not in self.processed_layouts:
print(f"{key} inferenced layout:")
print_fragment_format(layout)
for fmt in self.formats_list:
plot_layout(layout, name=f"{key}_layout", formats=fmt)
self.processed_layouts.add(layout_id)
# super().visit_block_(op)
def LayoutVisual(formats: str = ""):
def pass_fn(func: tir.PrimFunc, mod, ctx):
_LayoutVisualVisitor(formats=formats).visit_stmt(func.body)
return func
return prim_func_pass(pass_fn, opt_level=0)
......@@ -67,6 +67,48 @@ def should_force_let_inline(pass_ctx: PassContext | None = None) -> bool:
return bool(pass_ctx and pass_ctx.config.get(tilelang.PassConfigKey.TL_FORCE_LET_INLINE, False))
def should_enable_layout_visual(pass_ctx: PassContext | None = None) -> bool:
if pass_ctx is None:
pass_ctx = tilelang.transform.get_pass_context()
enabled = pass_ctx.config.get(tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_ENABLE, False)
return enabled
def get_layout_visual_formats(pass_ctx: PassContext | None = None) -> list[str]:
if pass_ctx is None:
pass_ctx = tilelang.transform.get_pass_context()
formats_value = pass_ctx.config.get(tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_FORMATS, "")
if not formats_value:
return ["txt"]
formats_str = formats_value.strip().lower()
valid_formats = ["txt", "png", "pdf", "svg", "all"]
if formats_str == "all":
return ["txt", "png", "pdf", "svg"]
if "," in formats_str:
formats_list = [f.strip() for f in formats_str.split(',')]
else:
formats_list = [formats_str]
invalid_formats = [f for f in formats_list if f not in valid_formats]
if invalid_formats:
raise ValueError(
f"Invalid formats for TL_LAYOUT_VISUALIZATION_FORMATS: {invalid_formats}. "
f"Valid formats are: {valid_formats}. "
f"You can choose one of the valid formats or a comma-separated list of formats.(e.g., 'txt,png,pdf')"
)
return formats_list
def LayoutVisual(mod: IRModule) -> None:
"""Apply layout visualization pass if enabled."""
if should_enable_layout_visual():
formats = get_layout_visual_formats()
tilelang.analysis.LayoutVisual(formats=formats)(mod)
def PreLowerSemanticCheck(mod: IRModule) -> None:
"""
Check whether the module is valid before lowering. If not, raise a user-friendly error
......@@ -121,6 +163,8 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
mod = tilelang.transform.LayoutReducer()(mod)
# Infer memory layouts for fragments and shared memory
mod = tilelang.transform.LayoutInference()(mod)
# Visualize the layout
LayoutVisual(mod)
# Lower high-level tile operations to low-level operations
mod = tilelang.transform.LowerTileOp()(mod)
# Lower l2 persistent map
......
from __future__ import annotations
import tilelang.language as T
def plot_layout(layout: T.Layout,
def plot_layout(layout: T.Fragment,
save_directory="./tmp",
name: str = "layout",
colormap: str = "RdPu",
verbose: bool = False) -> None:
verbose: bool = False,
formats: str | list[str] = "png") -> None:
"""
Plot the layout of a buffer.
......@@ -21,7 +23,8 @@ def plot_layout(layout: T.Layout,
The colormap to use for visualization (default is "RdPu").
verbose : bool, optional
If True, prints additional information about the mapping (default is False).
formats : str | list[str], optional
The formats to save the image in (default is "png").
Returns
-------
None
......@@ -82,6 +85,21 @@ def plot_layout(layout: T.Layout,
raw_colors = [cmap(i) for i in range(num_threads)]
colors = raw_colors.copy()
# Show the distribution of registers in each thread of a warp.
warp_size = 32
# Warn if the number of threads is less than the warp size
if num_threads < warp_size:
import warnings
warnings.warn(
f"Layout visualization has {num_threads} threads, which is less than the warp size ({warp_size}). "
f"For the best viewing experience, it is recommended to have at least {warp_size} threads.",
UserWarning,
stacklevel=2)
spectral_camp = plt.get_cmap("hsv", warp_size * 6)
for i in range(min(warp_size, num_threads)):
colors[i] = spectral_camp(i * 6)
# Determine the number of rows and columns in the input shape
nrows, ncols = input_shape
# Adjust figure size to maintain square cells
......@@ -191,17 +209,30 @@ def plot_layout(layout: T.Layout,
# Save the figure in multiple formats
plt.tight_layout()
# Save as PDF
if isinstance(formats, str):
formats_str = formats.strip().lower()
if formats_str == 'all':
formats_list = ['pdf', 'png', 'svg']
elif "," in formats_str:
formats_list = [f.strip() for f in formats_str.split(',')]
else:
formats_list = [formats_str]
else:
raise TypeError(f"Expected str, but got {type(formats).__name__}. "
f"Please pass a string like 'png', 'pdf', 'svg', 'all', or 'png,pdf'.")
# Save the figure
if 'pdf' in formats_list:
pdf_path = tmp_directory / f"{name}.pdf"
plt.savefig(pdf_path, bbox_inches="tight")
print(f"Saved pdf format into {pdf_path}")
# Save as PNG
if 'png' in formats_list:
png_path = tmp_directory / f"{name}.png"
plt.savefig(png_path, bbox_inches="tight", transparent=False, dpi=255)
print(f"Saved png format into {png_path}")
# Save as SVG
if 'svg' in formats_list:
svg_path = tmp_directory / f"{name}.svg"
plt.savefig(svg_path, bbox_inches="tight", format="svg")
print(f"Saved svg format into {svg_path}")
......@@ -69,6 +69,15 @@ class PassConfigKey(str, Enum):
TL_FORCE_LET_INLINE = "tl.force_let_inline"
"""Force TileLang to inline let bindings during simplification. Default: False"""
TL_LAYOUT_VISUALIZATION_ENABLE = "tl.layout_visualization_enable"
"""Enable layout inference visualization. Default: False"""
TL_LAYOUT_VISUALIZATION_FORMATS = "tl.layout_visualization_formats"
"""Layout visualization formats.
Acceptable values: "pdf", "png", "svg", "all"
"""
TL_STORAGE_REWRITE_DETECT_INPLACE = "tl.storage_rewrite_detect_inplace"
"""Control StorageRewrite inplace detection.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment