Commit e32311b2, authored by Lei Wang and committed by GitHub
Browse files

[Debug] Improve Memory Layout Plot (#136)

* Change default log level from WARNING to INFO in TileLang initialization

* Refactor Flash Attention Variable-Length MHA Example with Cython Backend Support

- Update `example_mha_fwd_varlen.py` to use Cython backend for kernel compilation
- Remove unused imports and simplify function signature
- Modify `flashattn` function to handle max sequence length as a separate argument
- Update kernel call to include max sequence length parameter
- Improve code readability and remove commented-out code
- Add print statement to confirm successful assertion

* Refactor code formatting in TileLang lowering and example files

- Improve line breaks and code formatting in `lower.py`, `wrapper.py`, and `tensor.py`
- Simplify line breaks and reduce unnecessary whitespace
- Enhance code readability by adjusting indentation and line breaks
- Update example MHA forward pass script with cleaner tensor initialization

* Update TileLang kernel test with import path changes for MMA layout and macro generator

- Modify import statements in test_tilelang_kernel_dequantize_gemm.py
- Replace bitblas imports with tilelang.intrinsics imports for MMA-related utilities
- Update main function to use tilelang.testing.main()

* Add Block Sparse Attention Examples for TileLang and Triton

- Implement block sparse attention kernels for both TileLang and Triton
- Add utility functions for generating sparse attention masks using top-k and threshold methods
- Support causal and variable-length attention scenarios
- Include test cases for different sequence length configurations
- Demonstrate block-level sparse attention with configurable parameters

* Refactor Block Sparse Attention Examples with Code Style Improvements

- Improve code formatting in block_sparse_attn_tilelang.py and block_sparse_attn_triton.py
- Enhance readability by adjusting line breaks and indentation
- Simplify kernel and function calls with better formatting
- Add whitespace and line break improvements for better code clarity

* Enhance Layout Plotting with Multi-Replication and Dynamic Visualization

- Update plot_layout function to support multiple replications in thread and value mapping
- Improve thread and value mapping to handle replicated layouts
- Dynamically adjust figure size and legend positioning
- Add print statements for saved plot file paths
- Modify example fragment_mma_load_a.py to uncomment the warp and block layout construction and enable warp layout plotting (the block layout plot call remains commented out)
parent b70683b3
......@@ -102,15 +102,15 @@ from tilelang.tools import plot_layout
# ldmatrix layout 16x16
base_layout = make_mma_load_base_layout(dtype="float16", matrix="A", transposed=False)
print(base_layout)
plot_layout(base_layout, name="base_layout")
# # warp layout 32x16
# warp_layout = base_layout.repeat([block_rows, 1],
# repeat_on_thread=True).replicate(block_cols)
# print(warp_layout)
# plot_layout(warp_layout, name="warp_layout")
# warp layout 32x16
warp_layout = base_layout.repeat([block_rows, 1],
repeat_on_thread=True).replicate(block_cols)
print(warp_layout)
plot_layout(warp_layout, name="warp_layout")
# # block layout 128x32
# block_layout = warp_layout.repeat([warp_rows, chunk], repeat_on_thread=False, lower_dim_first=False)
# block layout 128x32
block_layout = warp_layout.repeat([warp_rows, chunk], repeat_on_thread=False, lower_dim_first=False)
print(block_layout)
# plot_layout(block_layout, name="block_layout")
......@@ -38,6 +38,7 @@ def plot_layout(layout: T.Layout,
# Get the input shape of the layout and convert it to a list of integers
input_shape = layout.get_input_shape()
input_shape = [int(var) for var in input_shape]
replicate_size = int(layout.replicate_size)
# Get the total number of threads
num_threads = int(layout.get_thread_size())
......@@ -45,34 +46,40 @@ def plot_layout(layout: T.Layout,
import itertools
# Initialize a 2D array to store thread mappings
thread_map = np.zeros(input_shape, dtype=int)
thread_map = np.empty(input_shape, dtype=object)
for idx in np.ndindex(thread_map.shape):
thread_map[idx] = []
# Initialize a 2D array to store value mappings
value_map = np.zeros(input_shape, dtype=object)
for idx in np.ndindex(value_map.shape):
value_map[idx] = []
# Iterate over all possible indices in the input shape
for idx in itertools.product(*[range(dim) for dim in input_shape]):
index = list(idx)
# If replication is enabled, adjust the index
if layout.replicate_size > 1:
index.insert(0, 0)
# Map the index to a thread ID
thread_id = layout.map_forward_thread(index)
assert len(thread_id) == 1 # Ensure a single-thread mapping
thread_map[idx] = int(thread_id[0]) # Store the thread ID
# Initialize a 2D array to store value mappings
value_map = np.zeros(input_shape, dtype=int)
for i in range(replicate_size):
for idx in itertools.product(*[range(dim) for dim in input_shape]):
index = list(idx)
# If replication is enabled, adjust the index
if replicate_size > 1:
index.insert(0, i)
# Map the index to a thread ID
thread_id = layout.map_forward_thread(index)
assert len(thread_id) == 1 # Ensure a single-thread mapping
thread_map[idx].append(int(thread_id[0])) # Store the thread ID
# Iterate again to map values
for idx in itertools.product(*[range(dim) for dim in input_shape]):
index = list(idx)
if layout.replicate_size > 1:
index.insert(0, 0)
thread_id = layout.map_forward_thread(index)
value_id = layout.map_forward_index(index)
assert len(value_id) == 1 # Ensure a single-value mapping
value_map[idx] = int(value_id[0]) # Store the value ID
for i in range(replicate_size):
for idx in itertools.product(*[range(dim) for dim in input_shape]):
index = list(idx)
if replicate_size > 1:
index.insert(0, i)
thread_id = layout.map_forward_thread(index)
value_id = layout.map_forward_index(index)
assert len(value_id) == 1 # Ensure a single-value mapping
value_map[idx].append(int(value_id[0])) # Store the value ID
# Load the colormap with twice as many colors as the number of threads
cmap = plt.get_cmap(colormap, num_threads * 2)
cmap = plt.get_cmap(colormap, num_threads * 2 // replicate_size)
# Generate a list of colors based on the colormap
raw_colors = [cmap(i) for i in range(num_threads)]
......@@ -80,19 +87,21 @@ def plot_layout(layout: T.Layout,
# Determine the number of rows and columns in the input shape
nrows, ncols = input_shape
plt.figure(figsize=(nrows, ncols)) # Set the figure size
# Adjust figure size to maintain square cells
cell_size = 1 # Base size for each cell
plt.figure(figsize=(cell_size * ncols, cell_size * nrows)) # Set the figure size proportionally
ax = plt.gca() # Get the current axis
font_size = 24 # Set font size for text annotatio
font_size = 24 # Set font size for text annotation
# Iterate through each row and column
for i in range(nrows):
for j in range(ncols):
thread_id = thread_map[i, j] # Get the thread ID
local_id = value_map[i, j] # Get the value ID
thread_ids = thread_map[i, j] # Get the thread ID
local_ids = value_map[i, j] # Get the value ID
if verbose:
print(f"thread_map[{i}, {j}] = {thread_id} value_map[{i}, {j}] = {local_id}")
print(f"thread_map[{i}, {j}] = {thread_ids} value_map[{i}, {j}] = {local_ids}")
color = colors[thread_id] # Select color based on thread ID
color = colors[thread_ids[0]] # Select color based on thread ID
# Create a rectangle patch for visualization
rect = patches.Rectangle((j, i),
1,
......@@ -103,9 +112,23 @@ def plot_layout(layout: T.Layout,
ax.add_patch(rect) # Add the rectangle to the plot
# Add text annotations inside the rectangles
text = f"T{thread_id}\nL{local_id}"
ax.text(
j + 0.5, i + 0.5, text, ha='center', va='center', color='black', fontsize=font_size)
thread_str = []
for thread_id in thread_ids:
thread_str.append(f"{thread_id}")
thread_str = "T" + "/".join(thread_str)
local_id = local_ids[0]
# assert local id in local_ids is equal
assert all(local_id == local_id for local_id in local_ids)
# Calculate thread font size based on string length
thread_fontsize = min(font_size, font_size * (4 / len(thread_str)))
# Add thread ID text with adjusted font size
ax.text(j + 0.5, i + 0.3, thread_str,
ha='center', va='center', color='black', fontsize=thread_fontsize)
# Add local ID text with original font size
ax.text(j + 0.5, i + 0.7, f"L{local_id}",
ha='center', va='center', color='black', fontsize=font_size)
# Add row labels to the left side of the plot
for i in range(nrows):
......@@ -132,6 +155,13 @@ def plot_layout(layout: T.Layout,
plt.xticks([]) # Remove x-axis ticks
plt.yticks([]) # Remove y-axis ticks
# Calculate legend position based on figure size
fig = plt.gcf()
fig_width = fig.get_size_inches()[0]
fig_height = fig.get_size_inches()[1]
legend_x = 1.0 + (0.5 / fig_width) # Adjust x position based on figure width
legend_y = 1.0 + (1.7 / fig_height) # Adjust y position based on figure height
legend_patches = [
patches.Patch(color='black', label="T: Thread ID"),
patches.Patch(color='black', label="L: Local ID")
......@@ -141,7 +171,7 @@ def plot_layout(layout: T.Layout,
loc="upper right",
fontsize=font_size - 4,
frameon=False,
bbox_to_anchor=(1.0, 1.12),
bbox_to_anchor=(legend_x, legend_y), # Dynamic position
ncols=2)
# Create the output directory if it does not exist
......@@ -155,11 +185,14 @@ def plot_layout(layout: T.Layout,
# Save as PDF
pdf_path = tmp_directory / f"{name}.pdf"
plt.savefig(pdf_path, bbox_inches="tight")
print(f"Saved pdf format into {pdf_path}")
# Save as PNG
png_path = tmp_directory / f"{name}.png"
plt.savefig(png_path, bbox_inches="tight", transparent=False, dpi=255)
print(f"Saved png format into {png_path}")
# Save as SVG
svg_path = tmp_directory / f"{name}.svg"
plt.savefig(svg_path, bbox_inches="tight", format="svg")
print(f"Saved svg format into {svg_path}")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment