Commit d55386d1 authored by Lei Wang, committed by GitHub

[Debug] Support `T.print` for `fragment` scope (#130)

* Add DeepSeek MLA decode example with Flash Attention implementation

* Add GEMM SplitK and StreamK example implementations

This commit introduces two new example scripts demonstrating advanced GEMM (matrix multiplication) techniques:
- `example_tilelang_gemm_splitk.py`: Implements a Split-K GEMM kernel using TileLang
- `example_tilelang_gemm_streamk.py`: Implements a Stream-K GEMM kernel using TileLang

Both examples showcase different parallel computation strategies for matrix multiplication, with comprehensive testing using PyTorch reference implementations.
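
For intuition, here is a minimal PyTorch sketch of the Split-K idea only (not the TileLang kernel; `gemm_splitk_reference` is an illustrative name): the K dimension is partitioned into chunks, each chunk produces a partial product, and the partials are reduced at the end, which the kernel performs with atomic adds.

```python
import torch

def gemm_splitk_reference(A: torch.Tensor, B: torch.Tensor, split_k: int = 4) -> torch.Tensor:
    # Split K into `split_k` chunks; each chunk yields a partial GEMM.
    # Summing the partials mirrors the kernel's atomic-add reduction.
    partials = [a @ b for a, b in zip(A.chunk(split_k, dim=1), B.chunk(split_k, dim=0))]
    return sum(partials)

A, B = torch.randn(256, 512), torch.randn(512, 128)
assert torch.allclose(gemm_splitk_reference(A, B), A @ B, atol=1e-3)
```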

* Refactor GEMM SplitK and StreamK example implementations

Clean up and improve code formatting for the SplitK and StreamK GEMM example scripts:
- Remove unused import (Profiler) in splitk example
- Simplify line breaks and improve code readability
- Standardize indentation and remove unnecessary whitespace
- Optimize atomic add and copy operations for better clarity

* Add block sparse attention benchmarks for multiple libraries

This commit introduces comprehensive block sparse attention benchmarks for different libraries:
- TileLang block sparse FMHA implementation
- Triton block sparse FMHA implementation
- PyTorch reference block sparse FMHA implementation
- FlashAttention dense FMHA reference implementation

The benchmarks include:
- Configurable benchmark parameters (batch size, heads, sequence length, etc.)
- Sparse mask generation using top-k and threshold methods (sketched below)
- Performance measurement for different sparse attention configurations
- Utility functions for mask generation and benchmarking
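
As a rough illustration of the two mask-generation methods, consider this PyTorch sketch (`topk_block_mask` and `threshold_block_mask` are hypothetical helpers, not the benchmarks' actual API):

```python
import torch

def topk_block_mask(scores: torch.Tensor, k: int) -> torch.Tensor:
    # scores: [num_q_blocks, num_kv_blocks] block-importance estimates.
    # Keep only the k highest-scoring KV blocks per query block.
    idx = scores.topk(k, dim=-1).indices
    mask = torch.zeros_like(scores, dtype=torch.bool)
    mask.scatter_(-1, idx, True)
    return mask

def threshold_block_mask(scores: torch.Tensor, threshold: float) -> torch.Tensor:
    # Keep every block whose score clears the threshold.
    return scores > threshold
```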

* Refactor block sparse attention benchmarks with code style improvements

- Add Ruff linter ignore comments to benchmark files
- Improve code formatting and line breaks
- Remove unused imports
- Standardize print statement formatting
- Enhance code readability across multiple library benchmarks

* lint fix

* Add CUDA atomic operations for BFLOAT16 and update function naming

- Implement AtomicAdd functions for BFLOAT16 and BFLOAT16x2 in CUDA common header
- Rename existing atomic add functions to use PascalCase (atomicAdd -> AtomicAdd)
- Add a new __pack_nv_bfloat162 function for packing BFLOAT16 values
- Update kernel and language customization to use new function names
- Add return type annotations in profiler module

* lint fix

* Add example for Group Query Attention (GQA) forward pass using Flash Attention in TileLang

This commit introduces a new example script `example_gqa_fwd_bshd.py` that demonstrates:
- Group Query Attention (GQA) implementation
- Flash Attention forward pass
- Performance benchmarking
- Configurable parameters for batch, heads, sequence length, and dimension
- Autotuning support
- Reference implementation comparison (see the sketch below)
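
For reference, the core of a GQA forward pass fits in a few lines of PyTorch; this is a simplified stand-in for the example's reference implementation, not the TileLang kernel. Each group of query heads shares one KV head, so the KV heads are repeated before standard attention:

```python
import torch
import torch.nn.functional as F

def gqa_reference(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # q: [batch, heads_q, seq, dim]; k, v: [batch, heads_kv, seq, dim],
    # where heads_q must be a multiple of heads_kv.
    groups = q.shape[1] // k.shape[1]
    k = k.repeat_interleave(groups, dim=1)  # share each KV head across its query group
    v = v.repeat_interleave(groups, dim=1)
    return F.scaled_dot_product_attention(q, k, v)
```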

* Refactor IR lowering pipeline into modular phases

This commit introduces a new module `phase.py` to modularize the IR lowering process by splitting the complex lowering pipeline into two distinct phases:
- `LowerAndLegalize`: Handles initial IR legalization and transformation
- `OptimizeForTarget`: Applies target-specific optimizations

The changes simplify the lowering logic in multiple files by extracting the transformation steps into reusable functions, improving code readability and maintainability.
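
Structurally, the pipeline now reads as two composable stages. A generic sketch of the shape of such a pipeline (`run_pipeline` and its signature are illustrative assumptions, not tilelang's actual API):

```python
from typing import Callable, Sequence

Pass = Callable[[object], object]

def run_pipeline(mod: object, phases: Sequence[Sequence[Pass]]) -> object:
    # Each phase (here: LowerAndLegalize, then OptimizeForTarget) is an
    # ordered list of passes applied to the module in sequence.
    for phase in phases:
        for p in phase:
            mod = p(mod)
    return mod
```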

* lint fix

* NSA kernel

* Enhance Native Sparse Attention Examples with Code Improvements and Parameter Updates

- Updated example_tilelang_nsa.py and example_triton_nsa.py with code formatting and style improvements
- Increased default number of heads and selected blocks in TileLang NSA example
- Added Ruff linter ignore comments to reference.py
- Standardized function signatures and improved code readability across NSA implementations

* Add utility math functions for integer operations

- Implement `next_power_of_2()` to calculate the next power of 2 for an integer
- Add `cdiv()` function for ceiling division of integers (both helpers are sketched below)
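
Both helpers are one-liners in Python; a plausible implementation (a sketch, not necessarily the exact code added here):

```python
def next_power_of_2(n: int) -> int:
    # Smallest power of two >= n; assumes n >= 1.
    return 1 << (n - 1).bit_length()

def cdiv(a: int, b: int) -> int:
    # Ceiling division for positive integers, e.g. cdiv(10, 3) == 4.
    return -(a // -b)
```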

* Refactor DeepSeek MLA Decode Example with Enhanced Flash Attention Implementation

- Update flash attention kernel to support positional embeddings (PE)
- Modify reference implementation to handle PE and group query attention (see the sketch below)
- Increase default batch size and adjust benchmarking parameters
- Improve kernel performance and readability
- Add einops and torch operations for more flexible tensor manipulation
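
Conceptually, supporting PE in the reference means the attention logits become the sum of a content ("no-PE") term and a positional term. A hedged sketch with assumed tensor names:

```python
import torch

def mla_logits(q_nope: torch.Tensor, q_pe: torch.Tensor,
               k_nope: torch.Tensor, k_pe: torch.Tensor) -> torch.Tensor:
    # Attention logits = content contribution + positional (PE) contribution.
    return q_nope @ k_nope.transpose(-1, -2) + q_pe @ k_pe.transpose(-1, -2)
```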

* Update README.md with corrected Flash MLA Decoding example path

- Modify the example link for Flash MLA Decoding to point to the correct directory
- Ensure accurate navigation to the DeepSeek MLA decoding example

* Refactor Native Sparse Attention Kernel and Improve Utility Functions

This commit introduces several improvements:
- Simplified native sparse attention kernel by inlining macro functions in example_tilelang_nsa.py
- Enhanced error handling in loop_partition.cc with more informative error messages
- Updated print.py to support multi-dimensional buffer printing
- Improved torch_assert_close in testing/__init__.py with more detailed mismatch reporting
- Relaxed the default absolute tolerance in torch comparison from 1e-3 to 1e-2
- Added shape validation and detailed mismatch information in tensor comparison (sketched below)
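
The mismatch reporting follows the usual pattern: flag elements outside `atol + rtol * |expected|` and surface the count and worst offender. A hypothetical sketch (`assert_close_verbose` is illustrative, not the actual `torch_assert_close` signature):

```python
import torch

def assert_close_verbose(actual: torch.Tensor, expected: torch.Tensor,
                         atol: float = 1e-2, rtol: float = 1e-2) -> None:
    # Validate shapes first, then report how many elements fall outside tolerance.
    if actual.shape != expected.shape:
        raise AssertionError(f"shape mismatch: {actual.shape} vs {expected.shape}")
    diff = (actual - expected).abs()
    bad = diff > atol + rtol * expected.abs()
    if bad.any():
        raise AssertionError(
            f"{int(bad.sum())}/{bad.numel()} elements mismatch; "
            f"max abs diff {diff.max().item():.3e}")
```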

* Refactor Code Formatting and Improve Utility Functions

This commit introduces several code formatting and utility improvements:
- Add Ruff linter ignore comment in example_tilelang_nsa.py
- Enhance code readability in loop_partition.cc and lower_tile_op.cc with improved line breaks
- Simplify print_flat_buffer_with_condition in print.py
- Refactor torch_assert_close in testing/__init__.py with improved line formatting

* Enhance Buffer Printing Support for Fragment and Shared Memory Buffers

This commit improves the print functionality in print.py by:
- Adding support for printing fragment memory buffers
- Implementing a new print_fragment_buffer_with_condition macro
- Extending print_shared_buffer_with_condition for shared memory buffers
- Updating the generic print function to handle different buffer scopes

* Resolve merge conflict in print.py

Remove merge conflict marker and clean up whitespace in the print module
parent 20bbb91a
@@ -8,7 +8,7 @@ It includes functionality to print variables, print values in buffers, and condi
 from tvm import tir
 from typing import Any
 from tilelang.language.kernel import get_thread_bindings
-from tilelang.language import macro, serial
+from tilelang.language import copy, macro, serial, alloc_shared
 from tilelang.intrinsics.utils import index_to_coordinates
@@ -45,7 +45,7 @@ def print_var_with_condition(condition: tir.PrimExpr,
 @macro
-def print_flat_buffer_with_condition(condition: tir.PrimExpr,
+def print_shared_buffer_with_condition(condition: tir.PrimExpr,
                                        buffer: tir.Buffer,
                                        elems: int,
                                        msg: str = "") -> tir.PrimExpr:
@@ -68,6 +68,31 @@ def print_flat_buffer_with_condition(condition: tir.PrimExpr,
                             buffer[coords])
 
 
+@macro
+def print_fragment_buffer_with_condition(condition: tir.PrimExpr,
+                                         buffer: tir.Buffer,
+                                         elems: int,
+                                         msg: str = "") -> tir.PrimExpr:
+    """
+    Conditionally prints the values of a fragment-scope TIR buffer if the condition is True.
+
+    Parameters:
+        condition (tir.PrimExpr): A TIR expression representing the condition to check.
+        buffer (tir.Buffer): The fragment buffer whose values need to be printed.
+        elems (int): The number of elements in the buffer to print.
+
+    Returns:
+        tir.PrimExpr: The TIR expression for the debug print operation.
+    """
+    # Fragment data is distributed across threads, so stage it into shared
+    # memory first; a single thread can then read and print every element.
+    smem = alloc_shared(buffer.shape, buffer.dtype, "shared")
+    copy(buffer, smem)
+    if condition:
+        # Iterate through the buffer elements and print each one.
+        for i in serial(elems):
+            coords = index_to_coordinates(i, buffer.shape)
+            tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i,
+                            smem[coords])
+
+
 def print(obj: Any, msg: str = "") -> tir.PrimExpr:
     """
     A generic print function that handles both TIR buffers and primitive expressions.
@@ -93,8 +118,17 @@ def print(obj: Any, msg: str = "") -> tir.PrimExpr:
         # Flatten the buffer for consistent printing. This assumes a 1D flattened buffer.
         buffer = obj
         if buffer.scope() == "local.fragment":
-            raise NotImplementedError("Printing fragment buffers currently is not supported.")
+            # Get the number of elements in the buffer.
+            elems = 1
+            for dim in buffer.shape:
+                elems *= dim
+            # Ensure only the first thread (tx=0, ty=0, tz=0) executes the print.
+            condition = (tx == 0 and ty == 0 and tz == 0)
+            if not msg:
+                msg = f"buffer<{buffer.name}, {buffer.dtype}>"
+            return print_fragment_buffer_with_condition(condition, buffer, elems, msg)
         elif buffer.scope() in {"shared", "shared.dyn"}:
             # Get the number of elements in the buffer.
             elems = 1
             for dim in buffer.shape:
@@ -104,7 +138,7 @@ def print(obj: Any, msg: str = "") -> tir.PrimExpr:
             condition = (tx == 0 and ty == 0 and tz == 0)
             if not msg:
                 msg = f"buffer<{buffer.name}, {buffer.dtype}>"
-            return print_flat_buffer_with_condition(condition, buffer, elems, msg)
+            return print_shared_buffer_with_condition(condition, buffer, elems, msg)
     elif isinstance(obj, tir.PrimExpr):
         if not msg: