"examples/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "c5c0fa2e2dede71d2797a8bafa85c90f59d311f8"
Commit d55386d1 authored by Lei Wang, committed by GitHub

[Debug] Support `T.print` for `fragment` scope (#130)

* Add DeepSeek MLA decode example with Flash Attention implementation

* Add GEMM SplitK and StreamK example implementations

This commit introduces two new example scripts demonstrating advanced GEMM (matrix multiplication) techniques:
- `example_tilelang_gemm_splitk.py`: Implements a Split-K GEMM kernel using TileLang
- `example_tilelang_gemm_streamk.py`: Implements a Stream-K GEMM kernel using TileLang

Both examples showcase different parallel computation strategies for matrix multiplication, with comprehensive testing using PyTorch reference implementations.
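The split-K partitioning idea can be sketched in plain NumPy (an illustration of the strategy only, not the TileLang kernel itself):

```python
import numpy as np

def gemm_splitk(A: np.ndarray, B: np.ndarray, split_k: int = 4) -> np.ndarray:
    """Split-K GEMM sketch: partition the K (reduction) dimension into
    `split_k` chunks, compute an independent partial product per chunk,
    then reduce across chunks. On a GPU each chunk would be assigned to
    its own thread block, with the final reduction done via atomic adds
    or a separate reduction kernel."""
    M, K = A.shape
    K2, N = B.shape
    assert K == K2, "inner dimensions must match"
    # Chunk boundaries over K (the last chunk absorbs any remainder).
    bounds = np.linspace(0, K, split_k + 1, dtype=int)
    partials = np.zeros((split_k, M, N), dtype=A.dtype)
    for s in range(split_k):
        k0, k1 = bounds[s], bounds[s + 1]
        partials[s] = A[:, k0:k1] @ B[k0:k1, :]  # independent partial GEMM
    return partials.sum(axis=0)  # the cross-chunk reduction
```

Stream-K generalizes this by assigning variable-sized K ranges to a fixed number of workers so that all workers finish at roughly the same time.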

* Refactor GEMM SplitK and StreamK example implementations

Clean up and improve code formatting for the SplitK and StreamK GEMM example scripts:
- Remove unused import (Profiler) in splitk example
- Simplify line breaks and improve code readability
- Standardize indentation and remove unnecessary whitespace
- Optimize atomic add and copy operations for better clarity

* Add block sparse attention benchmarks for multiple libraries

This commit introduces comprehensive block sparse attention benchmarks for different libraries:
- TileLang block sparse FMHA implementation
- Triton block sparse FMHA implementation
- PyTorch reference block sparse FMHA implementation
- FlashAttention dense FMHA reference implementation

The benchmarks include:
- Configurable benchmark parameters (batch size, heads, sequence length, etc.)
- Sparse mask generation using top-k and threshold methods
- Performance measurement for different sparse attention configurations
- Utility functions for mask generation and benchmarking
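The two mask-generation methods can be sketched as follows (hypothetical helper names; the benchmark scripts' actual signatures may differ):

```python
import numpy as np

def topk_block_mask(scores: np.ndarray, k: int) -> np.ndarray:
    """Keep the k highest-scoring key blocks per query block.
    `scores` has shape (num_query_blocks, num_key_blocks)."""
    idx = np.argsort(scores, axis=-1)[:, -k:]  # indices of the top-k per row
    mask = np.zeros_like(scores, dtype=bool)
    np.put_along_axis(mask, idx, True, axis=-1)
    return mask

def threshold_block_mask(scores: np.ndarray, tau: float) -> np.ndarray:
    """Keep every key block whose softmax-normalized score exceeds tau."""
    probs = np.exp(scores - scores.max(axis=-1, keepdims=True))  # stable softmax
    probs /= probs.sum(axis=-1, keepdims=True)
    return probs > tau
```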

* Refactor block sparse attention benchmarks with code style improvements

- Add Ruff linter ignore comments to benchmark files
- Improve code formatting and line breaks
- Remove unused imports
- Standardize print statement formatting
- Enhance code readability across multiple library benchmarks

* lint fix

* Add CUDA atomic operations for BFLOAT16 and update function naming

- Implement AtomicAdd functions for BFLOAT16 and BFLOAT16x2 in CUDA common header
- Rename existing atomic add functions to use PascalCase (atomicAdd -> AtomicAdd)
- Add a new __pack_nv_bfloat162 function for packing BFLOAT16 values
- Update kernel and language customization to use new function names
- Add return type annotations in profiler module
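Packing two BFLOAT16 values into one 32-bit word is what lets a single 32-bit atomic update both halves at once. The bit manipulation can be sketched in Python (this mirrors what a helper like `__pack_nv_bfloat162` would conceptually do; the real CUDA code operates on `nv_bfloat162` registers):

```python
import numpy as np

def float_to_bf16_bits(x: float) -> int:
    """BFLOAT16 keeps the upper 16 bits of the float32 bit pattern,
    with round-to-nearest-even applied to the truncated lower half."""
    bits = np.float32(x).view(np.uint32)
    rounding = 0x7FFF + ((bits >> 16) & 1)  # round to nearest, ties to even
    return int((bits + rounding) >> 16)

def pack_bf16x2(lo: float, hi: float) -> int:
    """Pack two BFLOAT16 values into one 32-bit word (hi in the upper
    16 bits, lo in the lower 16 bits)."""
    return (float_to_bf16_bits(hi) << 16) | float_to_bf16_bits(lo)

def unpack_bf16x2(word: int) -> tuple:
    """Recover the two values by re-expanding each half to float32."""
    lo = np.uint32((word & 0xFFFF) << 16).view(np.float32)
    hi = np.uint32((word >> 16) << 16).view(np.float32)
    return float(lo), float(hi)
```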

* lint fix

* Add example for Group Query Attention (GQA) forward pass using Flash Attention in TileLang

This commit introduces a new example script `example_gqa_fwd_bshd.py` that demonstrates:
- Group Query Attention (GQA) implementation
- Flash Attention forward pass
- Performance benchmarking
- Configurable parameters for batch, heads, sequence length, and dimension
- Autotuning support
- Reference implementation comparison
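A minimal NumPy reference for the GQA forward pass (illustrative only; the example script implements this with TileLang kernels and a torch reference):

```python
import numpy as np

def gqa_reference(q, k, v):
    """GQA forward reference in BSHD layout (batch, seq, heads, dim).
    Each group of query heads shares one KV head, so the KV heads are
    repeated to match the query head count before standard scaled
    dot-product attention."""
    B, S, HQ, D = q.shape
    HKV = k.shape[2]
    groups = HQ // HKV
    k = np.repeat(k, groups, axis=2)  # expand KV heads to match Q heads
    v = np.repeat(v, groups, axis=2)
    scores = np.einsum("bqhd,bkhd->bhqk", q, k) / np.sqrt(D)
    scores -= scores.max(axis=-1, keepdims=True)  # numerically stable softmax
    p = np.exp(scores)
    p /= p.sum(axis=-1, keepdims=True)
    return np.einsum("bhqk,bkhd->bqhd", p, v)
```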

* Refactor IR lowering pipeline into modular phases

This commit introduces a new module `phase.py` to modularize the IR lowering process by splitting the complex lowering pipeline into two distinct phases:
- `LowerAndLegalize`: Handles initial IR legalization and transformation
- `OptimizeForTarget`: Applies target-specific optimizations

The changes simplify the lowering logic in multiple files by extracting the transformation steps into reusable functions, improving code readability and maintainability.
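The two-phase structure can be sketched as follows (strings stand in for IR modules, and the pass lists are placeholders rather than the actual TVM passes):

```python
from typing import Callable, List

# A "pass" is any IRModule -> IRModule transform; strings stand in for IR here.
Pass = Callable[[str], str]

def _apply(mod: str, passes: List[Pass]) -> str:
    for p in passes:
        mod = p(mod)
    return mod

def lower_and_legalize(mod: str) -> str:
    """Phase 1: legalize and normalize the IR (placeholder passes)."""
    return _apply(mod, [lambda m: m + "|legalized",
                        lambda m: m + "|simplified"])

def optimize_for_target(mod: str, target: str) -> str:
    """Phase 2: target-specific optimization (placeholder passes)."""
    return _apply(mod, [lambda m: f"{m}|opt({target})"])

def lower(mod: str, target: str) -> str:
    # The full pipeline is now just the two reusable phases composed,
    # so callers that only need one phase can invoke it directly.
    return optimize_for_target(lower_and_legalize(mod), target)
```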

* lint fix

* nsa kernel

* Enhance Native Sparse Attention Examples with Code Improvements and Parameter Updates

- Updated example_tilelang_nsa.py and example_triton_nsa.py with code formatting and style improvements
- Increased default number of heads and selected blocks in TileLang NSA example
- Added Ruff linter ignore comments to reference.py
- Standardized function signatures and improved code readability across NSA implementations

* Add utility math functions for integer operations

- Implement `next_power_of_2()` to calculate the next power of 2 for an integer
- Add `cdiv()` function for ceiling division of integers
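A straightforward implementation of the two helpers (a sketch; the module's actual code may differ in edge-case handling):

```python
def next_power_of_2(n: int) -> int:
    """Smallest power of two greater than or equal to n (1 for n <= 1)."""
    if n <= 1:
        return 1
    return 1 << (n - 1).bit_length()

def cdiv(a: int, b: int) -> int:
    """Ceiling division: the number of size-b tiles needed to cover a elements."""
    return (a + b - 1) // b
```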

* Refactor DeepSeek MLA Decode Example with Enhanced Flash Attention Implementation

- Update flash attention kernel to support positional embeddings (PE)
- Modify reference implementation to handle PE and group query attention
- Increase default batch size and adjust benchmarking parameters
- Improve kernel performance and readability
- Add einops and torch operations for more flexible tensor manipulation

* Update README.md with corrected Flash MLA Decoding example path

- Modify the example link for Flash MLA Decoding to point to the correct directory
- Ensure accurate navigation to the DeepSeek MLA decoding example

* Refactor Native Sparse Attention Kernel and Improve Utility Functions

This commit introduces several improvements:
- Simplified native sparse attention kernel by inlining macro functions in example_tilelang_nsa.py
- Enhanced error handling in loop_partition.cc with more informative error messages
- Updated print.py to support multi-dimensional buffer printing
- Improved torch_assert_close in testing/__init__.py with more detailed mismatch reporting
- Relaxed the default absolute tolerance in torch comparison from 1e-3 to 1e-2
- Added shape validation and detailed mismatch information in tensor comparison
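The improved comparison can be sketched in NumPy (names and defaults here are illustrative, not the actual `torch_assert_close` signature):

```python
import numpy as np

def assert_close(actual, expected, rtol=1e-2, atol=1e-2, max_mismatched_ratio=0.0):
    """Sketch of a torch_assert_close-style check: validate shapes first,
    then report how many elements mismatch and where, instead of failing
    with a bare boolean."""
    if actual.shape != expected.shape:
        raise AssertionError(f"shape mismatch: {actual.shape} vs {expected.shape}")
    diff = np.abs(actual - expected)
    tol = atol + rtol * np.abs(expected)
    mismatched = diff > tol
    ratio = mismatched.mean()
    if ratio > max_mismatched_ratio:
        # Point at the single worst offender for easier debugging.
        worst = np.unravel_index(np.argmax(diff - tol), actual.shape)
        raise AssertionError(
            f"{mismatched.sum()}/{mismatched.size} elements mismatch "
            f"({ratio:.2%}); worst at index {worst}: "
            f"{actual[worst]} vs {expected[worst]}")
```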

* Refactor Code Formatting and Improve Utility Functions

This commit introduces several code formatting and utility improvements:
- Add Ruff linter ignore comment in example_tilelang_nsa.py
- Enhance code readability in loop_partition.cc and lower_tile_op.cc with improved line breaks
- Simplify print_flat_buffer_with_condition in print.py
- Refactor torch_assert_close in testing/__init__.py with improved line formatting

* Enhance Buffer Printing Support for Fragment and Shared Memory Buffers

This commit improves the print functionality in print.py by:
- Adding support for printing fragment memory buffers
- Implementing a new print_fragment_buffer_with_condition macro
- Extending print_shared_buffer_with_condition for shared memory buffers
- Updating the generic print function to handle different buffer scopes

* Resolve merge conflict in print.py

Remove merge conflict marker and clean up whitespace in the print module
parent 20bbb91a
The resulting change to `print.py`, reconstructed from the side-by-side diff view (the page truncates the final hunk):

```diff
@@ -8,7 +8,7 @@ It includes functionality to print variables, print values in buffers, and condi
 from tvm import tir
 from typing import Any
 from tilelang.language.kernel import get_thread_bindings
-from tilelang.language import macro, serial
+from tilelang.language import copy, macro, serial, alloc_shared
 from tilelang.intrinsics.utils import index_to_coordinates
@@ -45,10 +45,10 @@ def print_var_with_condition(condition: tir.PrimExpr,
 @macro
-def print_flat_buffer_with_condition(condition: tir.PrimExpr,
-                                     buffer: tir.Buffer,
-                                     elems: int,
-                                     msg: str = "") -> tir.PrimExpr:
+def print_shared_buffer_with_condition(condition: tir.PrimExpr,
+                                       buffer: tir.Buffer,
+                                       elems: int,
+                                       msg: str = "") -> tir.PrimExpr:
     """
     Conditionally prints the values of a flattened TIR buffer if the condition is True.
@@ -68,6 +68,31 @@ def print_flat_buffer_with_condition(condition: tir.PrimExpr,
                              buffer[coords])
 
+
+@macro
+def print_fragment_buffer_with_condition(condition: tir.PrimExpr,
+                                         buffer: tir.Buffer,
+                                         elems: int,
+                                         msg: str = "") -> tir.PrimExpr:
+    """
+    Conditionally prints the values of a flattened TIR buffer if the condition is True.
+
+    Parameters:
+        condition (tir.PrimExpr): A TIR expression representing the condition to check.
+        buffer (tir.Buffer): The buffer whose values need to be printed.
+        elems (int): The number of elements in the buffer to print.
+
+    Returns:
+        tir.PrimExpr: The TIR expression for the debug print operation.
+    """
+    smem = alloc_shared(buffer.shape, buffer.dtype, "shared")
+    copy(buffer, smem)
+    if condition:
+        # Iterate through the buffer elements and print each one.
+        for i in serial(elems):
+            coords = index_to_coordinates(i, buffer.shape)
+            tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, smem[coords])
+
+
 def print(obj: Any, msg: str = "") -> tir.PrimExpr:
     """
     A generic print function that handles both TIR buffers and primitive expressions.
@@ -93,18 +118,27 @@ def print(obj: Any, msg: str = "") -> tir.PrimExpr:
         # Flatten the buffer for consistent printing. This assumes a 1D flattened buffer.
         buffer = obj
         if buffer.scope() == "local.fragment":
-            raise NotImplementedError("Printing fragment buffers currently is not supported.")
-
-        # Get the number of elements in the buffer.
-        elems = 1
-        for dim in buffer.shape:
-            elems *= dim
-
-        # Ensure only the first thread (tx=0, ty=0, tz=0) executes the print.
-        condition = (tx == 0 and ty == 0 and tz == 0)
-        if not msg:
-            msg = f"buffer<{buffer.name}, {buffer.dtype}>"
-        return print_flat_buffer_with_condition(condition, buffer, elems, msg)
+            # Get the number of elements in the buffer.
+            elems = 1
+            for dim in buffer.shape:
+                elems *= dim
+            # Ensure only the first thread (tx=0, ty=0, tz=0) executes the print.
+            condition = (tx == 0 and ty == 0 and tz == 0)
+            if not msg:
+                msg = f"buffer<{buffer.name}, {buffer.dtype}>"
+            return print_fragment_buffer_with_condition(condition, buffer, elems, msg)
+        elif buffer.scope() in {"shared", "shared.dyn"}:
+            # Get the number of elements in the buffer.
+            elems = 1
+            for dim in buffer.shape:
+                elems *= dim
+            # Ensure only the first thread (tx=0, ty=0, tz=0) executes the print.
+            condition = (tx == 0 and ty == 0 and tz == 0)
+            if not msg:
+                msg = f"buffer<{buffer.name}, {buffer.dtype}>"
+            return print_shared_buffer_with_condition(condition, buffer, elems, msg)
     elif isinstance(obj, tir.PrimExpr):
         if not msg:
 ...
```