Commit e32311b2 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Debug] Improve Memory Layout Plot (#136)

* Change default log level from WARNING to INFO in TileLang initialization

* Refactor Flash Attention Variable-Length MHA Example with Cython Backend Support

- Update `example_mha_fwd_varlen.py` to use Cython backend for kernel compilation
- Remove unused imports and simplify function signature
- Modify `flashattn` function to handle max sequence length as a separate argument
- Update kernel call to include max sequence length parameter
- Improve code readability and remove commented-out code
- Add print statement to confirm successful assertion

* Refactor code formatting in TileLang lowering and example files

- Improve line breaks and code formatting in `lower.py`, `wrapper.py`, and `tensor.py`
- Simplify line breaks and reduce unnecessary whitespace
- Enhance code readability by adjusting indentation and line breaks
- Update example MHA forward pass script with cleaner tensor initialization

* Update TileLang kernel test with import path changes for MMA layout and macro generator

- Modify import statements in test_tilelang_kernel_dequantize_gemm.py
- Replace bitblas imports with tilelang.intrinsics imports for MMA-related utilities
- Update main function to use tilelang.testing.main()

* Add Block Sparse Attention Examples for TileLang and Triton

- Implement block sparse attention kernels for both TileLang and Triton
- Add utility functions for generating sparse attention masks using top-k and threshold methods
- Support causal and variable-length attention scenarios
- Include test cases for different sequence length configurations
- Demonstrate block-level sparse attention with configurable parameters

* Refactor Block Sparse Attention Examples with Code Style Improvements

- Improve code formatting in block_sparse_attn_tilelang.py and block_sparse_attn_triton.py
- Enhance readability by adjusting line breaks and indentation
- Simplify kernel and function calls with better formatting
- Add whitespace and line break improvements for better code clarity

* Enhance Layout Plotting with Multi-Replication and Dynamic Visualization

- Update plot_layout function to support multiple replications in thread and value mapping
- Improve thread and value mapping to handle replicated layouts
- Dynamically adjust figure size and legend positioning
- Add print statements for saved plot file paths
- Modify example fragment_mma_load_a.py to uncomment and enable warp and block layout plotting
parent b70683b3
...@@ -102,15 +102,15 @@ from tilelang.tools import plot_layout ...@@ -102,15 +102,15 @@ from tilelang.tools import plot_layout
# ldmatrix layout 16x16 # ldmatrix layout 16x16
base_layout = make_mma_load_base_layout(dtype="float16", matrix="A", transposed=False) base_layout = make_mma_load_base_layout(dtype="float16", matrix="A", transposed=False)
print(base_layout) print(base_layout)
plot_layout(base_layout, name="base_layout") plot_layout(base_layout, name="base_layout")
# # warp layout 32x16 # warp layout 32x16
# warp_layout = base_layout.repeat([block_rows, 1], warp_layout = base_layout.repeat([block_rows, 1],
# repeat_on_thread=True).replicate(block_cols) repeat_on_thread=True).replicate(block_cols)
# print(warp_layout) print(warp_layout)
# plot_layout(warp_layout, name="warp_layout") plot_layout(warp_layout, name="warp_layout")
# # block layout 128x32 # block layout 128x32
# block_layout = warp_layout.repeat([warp_rows, chunk], repeat_on_thread=False, lower_dim_first=False) block_layout = warp_layout.repeat([warp_rows, chunk], repeat_on_thread=False, lower_dim_first=False)
print(block_layout)
# plot_layout(block_layout, name="block_layout") # plot_layout(block_layout, name="block_layout")
...@@ -38,6 +38,7 @@ def plot_layout(layout: T.Layout, ...@@ -38,6 +38,7 @@ def plot_layout(layout: T.Layout,
# Get the input shape of the layout and convert it to a list of integers # Get the input shape of the layout and convert it to a list of integers
input_shape = layout.get_input_shape() input_shape = layout.get_input_shape()
input_shape = [int(var) for var in input_shape] input_shape = [int(var) for var in input_shape]
replicate_size = int(layout.replicate_size)
# Get the total number of threads # Get the total number of threads
num_threads = int(layout.get_thread_size()) num_threads = int(layout.get_thread_size())
...@@ -45,34 +46,40 @@ def plot_layout(layout: T.Layout, ...@@ -45,34 +46,40 @@ def plot_layout(layout: T.Layout,
import itertools import itertools
# Initialize a 2D array to store thread mappings # Initialize a 2D array to store thread mappings
thread_map = np.zeros(input_shape, dtype=int) thread_map = np.empty(input_shape, dtype=object)
for idx in np.ndindex(thread_map.shape):
thread_map[idx] = []
# Initialize a 2D array to store value mappings
value_map = np.zeros(input_shape, dtype=object)
for idx in np.ndindex(value_map.shape):
value_map[idx] = []
# Iterate over all possible indices in the input shape # Iterate over all possible indices in the input shape
for idx in itertools.product(*[range(dim) for dim in input_shape]): for i in range(replicate_size):
index = list(idx) for idx in itertools.product(*[range(dim) for dim in input_shape]):
# If replication is enabled, adjust the index index = list(idx)
if layout.replicate_size > 1: # If replication is enabled, adjust the index
index.insert(0, 0) if replicate_size > 1:
# Map the index to a thread ID index.insert(0, i)
thread_id = layout.map_forward_thread(index) # Map the index to a thread ID
assert len(thread_id) == 1 # Ensure a single-thread mapping thread_id = layout.map_forward_thread(index)
thread_map[idx] = int(thread_id[0]) # Store the thread ID assert len(thread_id) == 1 # Ensure a single-thread mapping
thread_map[idx].append(int(thread_id[0])) # Store the thread ID
# Initialize a 2D array to store value mappings
value_map = np.zeros(input_shape, dtype=int)
# Iterate again to map values # Iterate again to map values
for idx in itertools.product(*[range(dim) for dim in input_shape]): for i in range(replicate_size):
index = list(idx) for idx in itertools.product(*[range(dim) for dim in input_shape]):
if layout.replicate_size > 1: index = list(idx)
index.insert(0, 0) if replicate_size > 1:
thread_id = layout.map_forward_thread(index) index.insert(0, i)
value_id = layout.map_forward_index(index) thread_id = layout.map_forward_thread(index)
assert len(value_id) == 1 # Ensure a single-value mapping value_id = layout.map_forward_index(index)
value_map[idx] = int(value_id[0]) # Store the value ID assert len(value_id) == 1 # Ensure a single-value mapping
value_map[idx].append(int(value_id[0])) # Store the value ID
# Load the colormap with twice as many colors as the number of threads # Load the colormap with twice as many colors as the number of threads
cmap = plt.get_cmap(colormap, num_threads * 2) cmap = plt.get_cmap(colormap, num_threads * 2 // replicate_size)
# Generate a list of colors based on the colormap # Generate a list of colors based on the colormap
raw_colors = [cmap(i) for i in range(num_threads)] raw_colors = [cmap(i) for i in range(num_threads)]
...@@ -80,19 +87,21 @@ def plot_layout(layout: T.Layout, ...@@ -80,19 +87,21 @@ def plot_layout(layout: T.Layout,
# Determine the number of rows and columns in the input shape # Determine the number of rows and columns in the input shape
nrows, ncols = input_shape nrows, ncols = input_shape
plt.figure(figsize=(nrows, ncols)) # Set the figure size # Adjust figure size to maintain square cells
cell_size = 1 # Base size for each cell
plt.figure(figsize=(cell_size * ncols, cell_size * nrows)) # Set the figure size proportionally
ax = plt.gca() # Get the current axis ax = plt.gca() # Get the current axis
font_size = 24 # Set font size for text annotatio font_size = 24 # Set font size for text annotation
# Iterate through each row and column # Iterate through each row and column
for i in range(nrows): for i in range(nrows):
for j in range(ncols): for j in range(ncols):
thread_id = thread_map[i, j] # Get the thread ID thread_ids = thread_map[i, j] # Get the thread ID
local_id = value_map[i, j] # Get the value ID local_ids = value_map[i, j] # Get the value ID
if verbose: if verbose:
print(f"thread_map[{i}, {j}] = {thread_id} value_map[{i}, {j}] = {local_id}") print(f"thread_map[{i}, {j}] = {thread_ids} value_map[{i}, {j}] = {local_ids}")
color = colors[thread_id] # Select color based on thread ID color = colors[thread_ids[0]] # Select color based on thread ID
# Create a rectangle patch for visualization # Create a rectangle patch for visualization
rect = patches.Rectangle((j, i), rect = patches.Rectangle((j, i),
1, 1,
...@@ -103,9 +112,23 @@ def plot_layout(layout: T.Layout, ...@@ -103,9 +112,23 @@ def plot_layout(layout: T.Layout,
ax.add_patch(rect) # Add the rectangle to the plot ax.add_patch(rect) # Add the rectangle to the plot
# Add text annotations inside the rectangles # Add text annotations inside the rectangles
text = f"T{thread_id}\nL{local_id}" thread_str = []
ax.text( for thread_id in thread_ids:
j + 0.5, i + 0.5, text, ha='center', va='center', color='black', fontsize=font_size) thread_str.append(f"{thread_id}")
thread_str = "T" + "/".join(thread_str)
local_id = local_ids[0]
# assert local id in local_ids is equal
assert all(local_id == local_id for local_id in local_ids)
# Calculate thread font size based on string length
thread_fontsize = min(font_size, font_size * (4 / len(thread_str)))
# Add thread ID text with adjusted font size
ax.text(j + 0.5, i + 0.3, thread_str,
ha='center', va='center', color='black', fontsize=thread_fontsize)
# Add local ID text with original font size
ax.text(j + 0.5, i + 0.7, f"L{local_id}",
ha='center', va='center', color='black', fontsize=font_size)
# Add row labels to the left side of the plot # Add row labels to the left side of the plot
for i in range(nrows): for i in range(nrows):
...@@ -132,6 +155,13 @@ def plot_layout(layout: T.Layout, ...@@ -132,6 +155,13 @@ def plot_layout(layout: T.Layout,
plt.xticks([]) # Remove x-axis ticks plt.xticks([]) # Remove x-axis ticks
plt.yticks([]) # Remove y-axis ticks plt.yticks([]) # Remove y-axis ticks
# Calculate legend position based on figure size
fig = plt.gcf()
fig_width = fig.get_size_inches()[0]
fig_height = fig.get_size_inches()[1]
legend_x = 1.0 + (0.5 / fig_width) # Adjust x position based on figure width
legend_y = 1.0 + (1.7 / fig_height) # Adjust y position based on figure height
legend_patches = [ legend_patches = [
patches.Patch(color='black', label="T: Thread ID"), patches.Patch(color='black', label="T: Thread ID"),
patches.Patch(color='black', label="L: Local ID") patches.Patch(color='black', label="L: Local ID")
...@@ -141,7 +171,7 @@ def plot_layout(layout: T.Layout, ...@@ -141,7 +171,7 @@ def plot_layout(layout: T.Layout,
loc="upper right", loc="upper right",
fontsize=font_size - 4, fontsize=font_size - 4,
frameon=False, frameon=False,
bbox_to_anchor=(1.0, 1.12), bbox_to_anchor=(legend_x, legend_y), # Dynamic position
ncols=2) ncols=2)
# Create the output directory if it does not exist # Create the output directory if it does not exist
...@@ -155,11 +185,14 @@ def plot_layout(layout: T.Layout, ...@@ -155,11 +185,14 @@ def plot_layout(layout: T.Layout,
# Save as PDF # Save as PDF
pdf_path = tmp_directory / f"{name}.pdf" pdf_path = tmp_directory / f"{name}.pdf"
plt.savefig(pdf_path, bbox_inches="tight") plt.savefig(pdf_path, bbox_inches="tight")
print(f"Saved pdf format into {pdf_path}")
# Save as PNG # Save as PNG
png_path = tmp_directory / f"{name}.png" png_path = tmp_directory / f"{name}.png"
plt.savefig(png_path, bbox_inches="tight", transparent=False, dpi=255) plt.savefig(png_path, bbox_inches="tight", transparent=False, dpi=255)
print(f"Saved png format into {png_path}")
# Save as SVG # Save as SVG
svg_path = tmp_directory / f"{name}.svg" svg_path = tmp_directory / f"{name}.svg"
plt.savefig(svg_path, bbox_inches="tight", format="svg") plt.savefig(svg_path, bbox_inches="tight", format="svg")
print(f"Saved svg format into {svg_path}")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment