Commit f9b6a92e authored by Lei Wang, committed by GitHub

[Tools] Introduce `plot_layout` to visualize the fragment layout (#68)

* [Enhancement] Add VectorizeLoop function and update imports for compatibility

* [CI][Test] Improve test cases for vectorization and fix typos in parser comments

* lint fix

* Fix incorrect module reference for VectorizeLoop transformation

* Refactor vectorize_loop transformation by removing unused extent mutation logic

* [Enhancement] Add support for FP8 data types and global barriers in CUDA codegen

* Fix formatting in CUDA FP8 header file for consistency

* Refactor CI workflow to use 'tilelang_ci' virtual environment and update CUDA type printing for better clarity

* Update submodule 'tvm' to latest commit for improved functionality

* Refactor execution backend references from 'dl_pack' to 'dlpack' for consistency and clarity; add apply_simplify function to simplify PrimFunc or IRModule.

* Refactor CUDA code for improved readability; clean up formatting and remove unnecessary whitespace in multiple files.

* Refactor import statement in test_tilelang_kernel_dequantize_gemm.py to use 'tilelang.language' for consistency

* Add CUDA requirements to FP8 test cases and update references for clarity

* Add a blank line for improved readability in test_tilelang_kernel_fp8_gemm_mma.py

* Fix data type in reference result calculation for consistency in test_tilelang_kernel_gemm_mma_intrinsic.py

* Add CUDA requirements and FP8 test cases for matmul and gemv simulations

* Remove debug print statements and use tilelang's testing assertion for result validation in test_tilelang_kernel_gemm_mma_intrinsic.py

* Remove outdated comment regarding FP8 tests in test_tilelang_kernel_gemv_simt.py

* Add BF16 support to matrix multiplication and introduce corresponding test cases

* Add a blank line for improved readability in BF16 GEMM test

* Update acknowledgements in README to include supervision by Zhi Yang at Peking University

* enhance acknowledgement

* Replace tutorial on memory layout optimization with new tutorial on writing high-performance kernels with thread primitives

* Update subproject commit for TVM dependency

* Update subproject commit for TVM dependency

* Add int4_t type and functions for packing char values in CUDA common header

* Add plot_layout example and implement GetForwardVars method in layout classes

* Refactor code for improved readability by adjusting line breaks and formatting in layout and test files

* Fix formatting by removing unnecessary line break in layout.h

* Refactor make_int4 function for improved readability by adjusting parameter formatting
parent 0677e542
......@@ -200,4 +200,4 @@ Welcome to join our Discord community for discussions, support, and collaboratio
## Acknowledgements
We learned a lot from the [TVM](https://github.com/apache/tvm) community and would like to thank them for their contributions. The initial version of this project is mainly contributed by [LeiWang1999](https://github.com/LeiWang1999), [chengyupku](https://github.com/chengyupku) and [nox-410](https://github.com/nox-410) under the supervision of [zhi yang](https://yangzhihome.github.io) at Peking university. Part of this work was done during the internship at Microsoft Research, under the supervision of Dr. Lingxiao Ma, Dr. Yuqing Xia, Dr. Jilong Xue, and Dr. Fan Yang.
We would like to express our gratitude to the [TVM](https://github.com/apache/tvm) community for their invaluable contributions. The initial version of this project was mainly developed by [LeiWang1999](https://github.com/LeiWang1999), [chengyupku](https://github.com/chengyupku) and [nox-410](https://github.com/nox-410) with supervision from Prof. [Zhi Yang](https://yangzhihome.github.io) at Peking University. Part of this work was carried out during an internship at Microsoft Research, where Dr. Lingxiao Ma, Dr. Yuqing Xia, Dr. Jilong Xue, and Dr. Fan Yang offered valuable advice and support. We deeply appreciate their mentorship and contributions.
Writing High-Performance Kernels with Thread Primitives
=======================================================
Annotating Memory Layout for Optimization
=========================================
The following example demonstrates how to generate and visualize a memory layout using tilelang's `plot_layout` tool.
Example Code
```python
from tilelang.tools import plot_layout
from tilelang.layouts import make_mma_load_base_layout # Ensure this function is imported
# Create a 16×16 matrix layout for ldmatrix operations
base_layout = make_mma_load_base_layout(dtype="float16", matrix="A", transposed=False)
# Print the layout structure (optional for debugging)
print(base_layout)
# Plot and save the layout visualization
plot_layout(base_layout, name="base_layout")
```
Output
![base_layout](./images/base_layout.png)
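The base layout can also be composed into larger layouts: `repeat` tiles the fragment over more rows/columns, and `replicate` spreads it across additional thread groups. The snippet below is a minimal sketch along those lines, following the (commented-out) warp/block examples in the accompanying script; the shapes and repeat factors are illustrative.

```python
from tilelang.tools import plot_layout
from tilelang.layouts import make_mma_load_base_layout  # illustrative import path, as above

# 16x16 base layout used by ldmatrix
base_layout = make_mma_load_base_layout(dtype="float16", matrix="A", transposed=False)

# Warp-level layout (32x16): tile 2x along rows on the thread dimension,
# then replicate across 2 thread groups.
warp_layout = base_layout.repeat([2, 1], repeat_on_thread=True).replicate(2)
plot_layout(warp_layout, name="warp_layout")

# Block-level layout (128x32): tile the warp layout 4x along rows and 2x along
# columns, repeating on the index (not thread) dimension.
block_layout = warp_layout.repeat([4, 2], repeat_on_thread=False, lower_dim_first=False)
plot_layout(block_layout, name="block_layout")
```

By default, `plot_layout` writes `<name>.png`, `<name>.pdf`, and `<name>.svg` into `./tmp`; the destination can be changed via `save_directory`.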
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import tilelang.language as T
from typing import Literal, Callable
from tvm import DataType
from tvm.tir import IndexMap
from tilelang.intrinsics.utils import get_mma_micro_size
def make_mma_load_base_layout(dtype: str = "float16",
matrix: Literal["A", "B"] = "A",
transposed: bool = False) -> T.Fragment:
"""
Create a layout function for storing MMA results into a fragment buffer.
This layout is used in conjunction with `inverse_mma_store_layout` to
map fragment indices to threads and local indices.
Parameters
----------
dtype : str
The data type of the matrix.
local_buf : tir.Buffer
The local buffer representing a fragment of a matrix.
Returns
-------
T.Fragment
A fragment object that describes how threads and indices
in `local_buf` are laid out.
Raises
------
AssertionError
If `local_buf` is not detected to be a fragment buffer.
"""
from tilelang.intrinsics.mma_layout import (
shared_16x16_to_mma_32x8_layout_sr,
shared_16x16_to_mma_32x8_layout_rs,
shared_16x32_to_mma_32x16_layout,
shared_32x16_to_mma_32x16_layout,
)
assert matrix in ["A", "B"], "matrix should be either A or B"
dtype_bits = DataType(dtype).bits
assert transposed is False, "transposed is not supported yet"
# s represents spatial axis
# r represents reduction axis
# sr represents the two dims are spatial + reduction
# rs represents the two dims are reduction + spatial
transform_func_sr: Callable = None
transform_func_rs: Callable = None
if dtype_bits == 16:
transform_func_sr = shared_16x16_to_mma_32x8_layout_sr
transform_func_rs = shared_16x16_to_mma_32x8_layout_rs
elif dtype_bits == 8:
transform_func_sr = shared_16x32_to_mma_32x16_layout
transform_func_rs = shared_32x16_to_mma_32x16_layout
else:
raise ValueError(f"Unsupported dtype {dtype}")
is_sr_conditions = [False]
is_sr_conditions.append(matrix == "A" and not transposed)
is_sr_conditions.append(matrix == "B" and transposed)
is_sr_axis_order = any(is_sr_conditions)
transform_func: Callable = transform_func_sr if is_sr_axis_order else transform_func_rs
micro_size_s, _, micro_size_r = get_mma_micro_size(dtype)
inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype="int32")
def forward_thread(i: int, j: int) -> int:
"""
Given the row index `i` and column index `j` in the fragment,
return the thread (lane) ID that owns that element.
"""
lane_id, _ = inverse_mma_load_layout.map_indices([i, j])
return lane_id
def forward_index(i: int, j: int) -> int:
"""
Given the row index `i` and column index `j` in the fragment,
return the local (per-thread) index of that element.
"""
_, local_id = inverse_mma_load_layout.map_indices([i, j])
return local_id
base_fragment = T.Fragment(
[micro_size_r, micro_size_s],
forward_thread_fn=forward_thread,
forward_index_fn=forward_index,
)
return base_fragment
block_rows = 2
block_cols = 2
warp_rows = 4
warp_cols = 4
chunk = 2
from tilelang.tools import plot_layout
# ldmatrix layout 16x16
base_layout = make_mma_load_base_layout(dtype="float16", matrix="A", transposed=False)
print(base_layout)
plot_layout(base_layout, name="base_layout")
# # warp layout 32x16
# warp_layout = base_layout.repeat([block_rows, 1],
# repeat_on_thread=True).replicate(block_cols)
# print(warp_layout)
# plot_layout(warp_layout, name="warp_layout")
# # block layout 128x32
# block_layout = warp_layout.repeat([warp_rows, chunk], repeat_on_thread=False, lower_dim_first=False)
# plot_layout(block_layout, name="block_layout")
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
......@@ -87,6 +87,14 @@ void LayoutNode::UpdateAnalyzer(arith::Analyzer *analyzer) const {
}
}
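// Collect the input placeholder variables that the layout's forward index expressions are defined over.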
Array<PrimExpr> LayoutNode::GetForwardVars() const {
Array<PrimExpr> vars;
for (size_t i = 0; i < InputDim(); i++) {
vars.push_back(InputPlaceholder(i));
}
return vars;
}
Array<PrimExpr> LayoutNode::OutputShape() const {
Array<PrimExpr> ret(OutputDim(), 1);
arith::Analyzer analyzer;
......@@ -307,6 +315,17 @@ PrimExpr FragmentNode::ThreadExtent() const {
return ist.max();
}
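// A fragment's forward variables include the replication placeholder (when replicated) followed by the input placeholders.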
Array<PrimExpr> FragmentNode::GetForwardVars() const {
Array<PrimExpr> vars;
if (*as_const_int(ReplicateExtent()) > 1) {
vars.push_back(ReplicationPlaceholder());
}
for (size_t i = 0; i < InputDim(); i++) {
vars.push_back(InputPlaceholder(i));
}
return vars;
}
PrimExpr FragmentNode::ForwardThread(const Array<PrimExpr> &vars,
const Optional<PrimExpr> &rep_var) const {
Map<Var, PrimExpr> vmap;
......@@ -396,6 +415,10 @@ TVM_REGISTER_GLOBAL("tl.Layout_index").set_body_typed([](Layout layout) {
return layout->GetForwardIndex();
});
TVM_REGISTER_GLOBAL("tl.Layout_forward_vars").set_body_typed([](Layout layout) {
return layout->GetForwardVars();
});
TVM_REGISTER_GLOBAL("tl.Fragment").set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = Fragment(args[0], args[1], args[2], args[3]);
});
......
......@@ -34,6 +34,8 @@ public:
Array<PrimExpr> GetForwardIndex() const { return forward_index_; }
virtual Array<PrimExpr> GetForwardVars() const;
virtual Array<PrimExpr> Forward(const Array<PrimExpr> &vars) const;
virtual Layout Inverse() const;
......@@ -72,6 +74,8 @@ public:
PrimExpr GetForwardThread() const { return forward_thread_; }
Array<PrimExpr> GetForwardVars() const final;
Layout Inverse() const final;
PrimExpr ThreadExtent() const;
......
......@@ -11,6 +11,8 @@ using cutlass::bfloat16_t;
using cutlass::half_t;
using cutlass::tfloat32_t;
using int4_t = int4;
#define hexp cutlass::fast_exp
#define hlog cutlass::fast_log
#define hsqrt cutlass::fast_sqrt
......@@ -44,6 +46,27 @@ TL_DEVICE unsigned __pack_half2(const bfloat16_t x, const bfloat16_t y) {
return (v1 << 16) | v0;
}
// Pack four signed char values into one int (x0 is the least-significant byte).
TL_DEVICE int make_int(signed char x0, signed char x1, signed char x2,
signed char x3) {
// Mask each byte to avoid sign extension from negative chars.
return ((x3 & 0xFF) << 24) | ((x2 & 0xFF) << 16) | ((x1 & 0xFF) << 8) |
(x0 & 0xFF);
}
// Pack sixteen signed char values into an int4_t (four packed 32-bit ints).
TL_DEVICE int4_t make_int4(signed char x0, signed char x1, signed char x2,
signed char x3, signed char y0, signed char y1,
signed char y2, signed char y3, signed char z0,
signed char z1, signed char z2, signed char z3,
signed char w0, signed char w1, signed char w2,
signed char w3) {
int4_t result;
result.x = make_int(x0, x1, x2, x3);
result.y = make_int(y0, y1, y2, y3);
result.z = make_int(z0, z1, z2, z3);
result.w = make_int(w0, w1, w2, w3);
return result;
}
// Helper to cast SMEM pointer to unsigned
TL_DEVICE uint32_t smem_ptr_to_uint(void const *const ptr) {
return static_cast<uint32_t>(__cvta_generic_to_shared(ptr));
......
......@@ -40,15 +40,15 @@ def matmul_ssr(
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if trans_A:
T.copy(A[k * block_K, by * block_M], A_shared)
T.copy(A[ko * block_K, by * block_M], A_shared)
else:
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(A[by * block_M, ko * block_K], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
T.copy(B[bx * block_N, ko * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.copy(B[ko * block_K, bx * block_N], B_shared)
P.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N])
......@@ -104,6 +104,10 @@ def run_matmul_ssr(
def test_gemm_f16f16f16_nt_ssr():
run_matmul_ssr(
16, 16, 16, False, True, "float16", "float16", "float16", 16, 16, 16, 0, num_threads=32)
run_matmul_ssr(
128, 128, 128, False, True, "float16", "float16", "float16", 32, 32, 32, 0, num_threads=64)
run_matmul_ssr(
1024,
1024,
......@@ -117,7 +121,7 @@ def test_gemm_f16f16f16_nt_ssr():
128,
32,
2,
)
num_threads=128)
def matmul_rsr(
......@@ -155,15 +159,15 @@ def matmul_rsr(
A_local = T.alloc_fragment(A_local_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if trans_A:
T.copy(A[k * block_K, by * block_M], A_shared)
T.copy(A[ko * block_K, by * block_M], A_shared)
else:
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(A[by * block_M, ko * block_K], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
T.copy(B[bx * block_N, ko * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.copy(B[ko * block_K, bx * block_N], B_shared)
T.copy(A_shared, A_local)
P.gemm(A_local, B_shared, C_local, trans_A, trans_B)
# T.gemm(A_local, B_shared, C_local, trans_A, trans_B)
......@@ -359,4 +363,19 @@ def run_matmul_rrr(
# )
if __name__ == "__main__":
tilelang.testing.main()
# tilelang.testing.main()
run_matmul_rsr(
128,
128,
128,
False,
True,
"float16",
"float16",
"float16",
128,
128,
32,
0,
num_threads=128,
)
......@@ -5,13 +5,23 @@
import tvm
from tvm.ir import Range
from tvm.tir import IterVar, Var
from tvm.tir import IterVar, Var, PrimExpr, IndexMap
from tilelang import _ffi_api
from tilelang.layout import Layout
from typing import List
@tvm._ffi.register_object("tl.Fragment")
class Fragment(Layout):
"""
A Fragment layout object that encapsulates iteration variables (forward_vars),
thread iteration variables (forward_thread), and index transformations
(forward_index). This class supports replication (thread_replicate) and
index mapping for fine-grained control over multi-dimensional data layouts.
"""
# Disable the linter warning about not calling super().__init__()
# because this object is created via TVM's FFI constructor mechanism.
# pylint: disable=super-init-not-called
def __init__(self,
shape,
......@@ -19,17 +29,51 @@ class Fragment(Layout):
forward_thread_fn=None,
replicate=1,
forward_index_fn=None):
"""
Initialize the Fragment with iteration variables and optional thread replication.
Parameters
----------
shape : list[int]
A list of integer sizes for each dimension of this fragment.
forward_fn : callable, optional
A function that takes the iteration variables, plus optionally a replicate
IterVar, and returns a tuple: (forward_thread, forward_index).
It is used when you want to compute both thread mapping and index mapping
from the shape variables.
forward_thread_fn : callable, optional
A function that takes iteration variables (plus optionally a replicate Var)
and returns an IterVar representing the thread index. This is used if
`forward_fn` is not provided, and only the thread mapping is derived
here while the index mapping is derived separately via `forward_index_fn`.
replicate : int, optional
How many times to replicate the iteration over the threads, typically
used for multi-threading or replication in the hardware threads. Defaults to 1.
forward_index_fn : callable, optional
A function that takes iteration variables and returns an index or list
of indices for this fragment. Used when `forward_fn` is None and
the index transformation is derived separately.
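Examples
--------
>>> # Illustrative sketch: build a fragment from separate thread/index callbacks,
>>> # mirroring how make_mma_load_base_layout constructs its base fragment.
>>> frag = Fragment([16, 16],
...                 forward_thread_fn=forward_thread,
...                 forward_index_fn=forward_index)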
"""
# Create a list of IterVar objects based on shape dimensions
# Each dimension is assigned a range from 0..size and a Var like i0, i1, etc.
forward_vars = []
for idx, size in enumerate(shape):
iv = IterVar(Range(0, size), Var(f"i{idx}", "int32"), 0)
forward_vars.append(iv)
# Collect the underlying variables (i.e., Var objects) from the IterVars
vars = [iv.var for iv in forward_vars]
# Initialize placeholders for optional outputs
forward_thread: IterVar = None
forward_index: tvm.ir.container.Array = None
thread_replicate: IterVar = None
# If a forward_fn is provided, use it to derive both thread mapping and indices
if forward_fn is not None:
# If replication is greater than 1, create a replicate IterVar
# and pass it to forward_fn
if replicate > 1:
thread_replicate = IterVar(Range(0, replicate), Var("rep", "int32"), 0)
forward_thread, forward_index = forward_fn(*vars, thread_replicate)
......@@ -37,7 +81,9 @@ class Fragment(Layout):
thread_replicate = None
forward_thread, forward_index = forward_fn(*vars)
else:
# If no forward_fn is provided, compute forward_index (if any) via forward_index_fn
forward_index = forward_index_fn(*vars) if forward_index_fn else None
# Then compute forward_thread via forward_thread_fn
if replicate > 1:
thread_replicate = IterVar(Range(0, replicate), Var("rep", "int32"), 0)
forward_thread = forward_thread_fn(*vars, thread_replicate.var)
......@@ -45,9 +91,11 @@ class Fragment(Layout):
thread_replicate = None
forward_thread = forward_thread_fn(*vars)
# Ensure forward_index is an array if it isn't None
if forward_index is not None and not isinstance(forward_index, tvm.ir.container.Array):
forward_index = [forward_index]
# Call TVM FFI constructor to set up internal data structures
self.__init_handle_by_constructor__(
_ffi_api.Fragment,
forward_vars,
......@@ -58,24 +106,104 @@ class Fragment(Layout):
@property
def thread(self):
"""
Returns the forward_thread (IterVar) of the Fragment, representing
the thread dimension or mapping.
"""
return _ffi_api.Fragment_thread(self)
def get_thread_size(self):
"""
Returns the extent (range size) of the thread dimension.
If the Fragment was replicated over threads, this will reflect
the number of threads.
"""
return _ffi_api.Fragment_thread_size(self)
def repeat(self,
repeats,
repeat_on_thread: bool = False,
lower_dim_first: bool = True) -> "Fragment":
"""
Returns a new Fragment that repeats the iteration space a given number of times.
Parameters
----------
repeats : list[int]
Number of times to repeat along each dimension (e.g. `[2, 1]` repeats
the first dimension twice and leaves the second unchanged).
repeat_on_thread : bool, optional
If set, the repeat will happen on the thread dimension.
lower_dim_first : bool, optional
If set to True, repeat on lower dimensions first.
Returns
-------
Fragment
A new Fragment with the repeated iteration space.
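Examples
--------
>>> # Illustrative: tile a 16x16 fragment 2x along rows, on the thread dimension
>>> warp_layout = base_layout.repeat([2, 1], repeat_on_thread=True)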
"""
return _ffi_api.Fragment_repeat(self, repeats, repeat_on_thread, lower_dim_first)
def replicate(self, replicate: int) -> "Fragment":
"""
Replicate the Fragment across a new thread dimension.
Parameters
----------
replicate : int
The replication factor or number of threads.
Returns
-------
Fragment
A new Fragment with an additional replicate dimension.
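Examples
--------
>>> # Illustrative: spread the fragment across 2 replicated thread groups
>>> replicated = warp_layout.replicate(2)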
"""
return _ffi_api.Fragment_replicate(self, replicate)
def condense_rep_var(self) -> "Fragment":
"""
Condense or fold the replicate variable into the existing iteration space.
This operation may be used to reduce dimensionality if the replicate variable
is no longer needed as a separate dimension.
Returns
-------
Fragment
A new Fragment where the replicate variable is condensed.
"""
return _ffi_api.Fragment_condense_rep_var(self)
def map_forward_thread(self, indices: List[PrimExpr]) -> List[PrimExpr]:
"""
Get the thread mapping expression for a given set of argument indices.
Parameters
----------
indices : list of PrimExpr
Indices for which to compute the thread mapping.
Returns
-------
List[PrimExpr]
The computed thread expression(s) for the provided indices.
"""
# Retrieve the forward iteration variables
forward_vars = self.get_forward_vars()
# The thread dimension (IterVar) is accessed via the `thread` property
forward_thread = self.thread
# Construct an IndexMap to map the provided args into the final thread index
index_map = IndexMap(
initial_indices=forward_vars, final_indices=[forward_thread], inverse_index_map=None)
return index_map.map_indices(indices)
def __repr__(self):
"""
String representation of the Fragment for debugging and logging.
Returns
-------
str
A string showing the thread dimension and the index dimension.
"""
return f"Fragment<thread={self.thread}, index={self.index}>"
......
......@@ -5,33 +5,129 @@
import tvm
from tvm.ir import Node, Range
from tvm.tir import IterVar, Var, PrimExpr
from tvm.tir import IterVar, Var, PrimExpr, IndexMap
from tilelang import _ffi_api
from typing import List
# Register the Layout class as a TVM object under the name "tl.Layout"
@tvm._ffi.register_object("tl.Layout")
class Layout(Node):
def __init__(self, shape, forward_fn):
forward_vars = []
"""
Initialize a Layout object.
Parameters
----------
shape : list of int
The shape of the layout, defining the number of elements along each dimension.
forward_fn : function
A function that maps index variables to their computed forward index.
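Examples
--------
>>> # Illustrative sketch: an 8x8 row-major layout
>>> row_major = Layout([8, 8], lambda i, j: i * 8 + j)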
"""
forward_vars = [] # List to store IterVars corresponding to each shape dimension
# Create an IterVar for each dimension in the shape
for idx, size in enumerate(shape):
# Define an IterVar over the range [0, size) with an associated variable name
iv = IterVar(Range(0, size), Var(f"i{idx}", "int32"), 0)
forward_vars.append(iv)
# Extract the variable references from the IterVars
vars = [iv.var for iv in forward_vars]
# Compute the forward index using the provided forward function
forward_index = forward_fn(*vars)
# Ensure forward_index is a list (to handle cases where a single expression is returned)
if isinstance(forward_index, PrimExpr):
forward_index = [forward_index]
# Call the FFI constructor to create the Layout object in C++ backend
self.__init_handle_by_constructor__(_ffi_api.Layout, forward_vars, forward_index)
@property
def index(self):
"""
Property to retrieve the forward index of the layout.
Returns
-------
PrimExpr or List[PrimExpr]
The computed forward index expression(s).
"""
return _ffi_api.Layout_index(self)
def get_input_shape(self):
"""
Get the input shape of the layout.
Returns
-------
List[int]
The shape of the input layout.
"""
return _ffi_api.Layout_input_shape(self)
def get_output_shape(self):
"""
Get the output shape of the layout.
Returns
-------
List[int]
The shape of the output layout.
"""
return _ffi_api.Layout_output_shape(self)
def get_forward_vars(self):
"""
Retrieve the iteration variables associated with the layout.
Returns
-------
List[IterVar]
A list of iteration variables that define the layout transformation.
"""
return _ffi_api.Layout_forward_vars(self)
def map_forward_index(self, indices: List[PrimExpr]) -> List[PrimExpr]:
"""
Compute the forward index mapping for a given set of input indices.
Parameters
----------
indices : list of PrimExpr
The input indices to be mapped to their corresponding output indices.
Returns
-------
List[PrimExpr]
The mapped index expression(s) for the provided input indices.
"""
# Retrieve the iteration variables used in the layout transformation
forward_vars = self.get_forward_vars()
# Retrieve the computed forward index expressions
forward_indexes = self.index
# Construct an IndexMap to map the input indices to the computed output indices
index_map = IndexMap(
initial_indices=forward_vars, # The original iteration variables
final_indices=forward_indexes, # The computed forward indices
inverse_index_map=None # No inverse mapping provided at this stage
)
# Map the provided indices using the constructed index mapping
return index_map.map_indices(indices)
def inverse(self) -> "Layout":
"""
Compute the inverse of the current layout transformation.
Returns
-------
Layout
A new Layout object representing the inverse transformation.
"""
return _ffi_api.Layout_inverse(self)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .plot_layout import plot_layout # noqa: F401
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import tilelang.language as T
def plot_layout(layout: T.Layout,
save_directory="./tmp",
name: str = "layout",
colormap: str = "RdPu",
verbose: bool = False) -> None:
"""
Plot the layout of a buffer.
Parameters
----------
layout : T.Layout
The layout object that describes how indices are mapped.
save_directory : str, optional
The directory where the output images will be saved (default is "./tmp").
name : str, optional
The base name of the output files (default is "layout").
colormap : str, optional
The colormap to use for visualization (default is "RdPu").
verbose : bool, optional
If True, prints additional information about the mapping (default is False).
Returns
-------
None
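Examples
--------
>>> # Illustrative (mirrors the tutorial): visualize a 16x16 ldmatrix base layout
>>> base_layout = make_mma_load_base_layout(dtype="float16", matrix="A", transposed=False)
>>> plot_layout(base_layout, name="base_layout")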
"""
import os
import pathlib
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
# Get the input shape of the layout and convert it to a list of integers
input_shape = layout.get_input_shape()
input_shape = [int(var) for var in input_shape]
# Get the total number of threads
num_threads = int(layout.get_thread_size())
import itertools
# Initialize a 2D array to store thread mappings
thread_map = np.zeros(input_shape, dtype=int)
# Iterate over all possible indices in the input shape
for idx in itertools.product(*[range(dim) for dim in input_shape]):
index = list(idx)
# If replication is enabled, adjust the index
if layout.replicate_size > 1:
index.insert(0, 0)
# Map the index to a thread ID
thread_id = layout.map_forward_thread(index)
assert len(thread_id) == 1 # Ensure a single-thread mapping
thread_map[idx] = int(thread_id[0]) # Store the thread ID
# Initialize a 2D array to store value mappings
value_map = np.zeros(input_shape, dtype=int)
# Iterate again to map values
for idx in itertools.product(*[range(dim) for dim in input_shape]):
index = list(idx)
if layout.replicate_size > 1:
index.insert(0, 0)
thread_id = layout.map_forward_thread(index)
value_id = layout.map_forward_index(index)
assert len(value_id) == 1 # Ensure a single-value mapping
value_map[idx] = int(value_id[0]) # Store the value ID
# Load the colormap with twice as many colors as the number of threads
cmap = plt.get_cmap(colormap, num_threads * 2)
# Generate a list of colors based on the colormap
raw_colors = [cmap(i) for i in range(num_threads)]
colors = raw_colors.copy()
# Determine the number of rows and columns in the input shape
nrows, ncols = input_shape
plt.figure(figsize=(ncols, nrows))  # Set the figure size as (width, height)
ax = plt.gca() # Get the current axis
font_size = 24  # Set font size for text annotations
# Iterate through each row and column
for i in range(nrows):
for j in range(ncols):
thread_id = thread_map[i, j] # Get the thread ID
local_id = value_map[i, j] # Get the value ID
if verbose:
print(f"thread_map[{i}, {j}] = {thread_id} value_map[{i}, {j}] = {local_id}")
color = colors[thread_id] # Select color based on thread ID
# Create a rectangle patch for visualization
rect = patches.Rectangle((j, i),
1,
1,
linewidth=0.5,
edgecolor='black',
facecolor=color)
ax.add_patch(rect) # Add the rectangle to the plot
# Add text annotations inside the rectangles
text = f"T{thread_id}\nL{local_id}"
ax.text(
j + 0.5, i + 0.5, text, ha='center', va='center', color='black', fontsize=font_size)
# Add row labels to the left side of the plot
for i in range(nrows):
text = f"row {i}"
ax.text(-0.75, i + 0.5, text, ha='center', va='center', color='black', fontsize=font_size)
# Add column labels at the top of the plot
for j in range(ncols):
text = f"col {j}"
ax.text(
j + 0.5,
-0.5,
text,
ha='center',
va='center',
color='black',
fontsize=font_size,
rotation=45)
# Set the plot limits
ax.set_xlim(0, ncols)
ax.set_ylim(0, nrows)
ax.invert_yaxis() # Invert the y-axis for proper visualization
plt.xticks([]) # Remove x-axis ticks
plt.yticks([]) # Remove y-axis ticks
# Create the output directory if it does not exist
tmp_directory = pathlib.Path(save_directory)
if not os.path.exists(tmp_directory):
os.makedirs(tmp_directory)
# Save the figure in multiple formats
plt.tight_layout()
# Save as PDF
pdf_path = tmp_directory / f"{name}.pdf"
plt.savefig(pdf_path, bbox_inches="tight")
# Save as PNG
png_path = tmp_directory / f"{name}.png"
plt.savefig(png_path, bbox_inches="tight", transparent=False, dpi=255)
# Save as SVG
svg_path = tmp_directory / f"{name}.svg"
plt.savefig(svg_path, bbox_inches="tight", format="svg")