Commit 20bbb91a authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Dev] Remove buffer flatten when debug print a shared buffer (#129)

* Add DeepSeek MLA decode example with Flash Attention implementation

* Add GEMM SplitK and StreamK example implementations

This commit introduces two new example scripts demonstrating advanced GEMM (matrix multiplication) techniques:
- `example_tilelang_gemm_splitk.py`: Implements a Split-K GEMM kernel using TileLang
- `example_tilelang_gemm_streamk.py`: Implements a Stream-K GEMM kernel using TileLang

Both examples showcase different parallel computation strategies for matrix multiplication, with comprehensive testing using PyTorch reference implementations.

* Refactor GEMM SplitK and StreamK example implementations

Clean up and improve code formatting for the SplitK and StreamK GEMM example scripts:
- Remove unused import (Profiler) in splitk example
- Simplify line breaks and improve code readability
- Standardize indentation and remove unnecessary whitespace
- Optimize atomic add and copy operations for better clarity

* Add block sparse attention benchmarks for multiple libraries

This commit introduces comprehensive block sparse attention benchmarks for different libraries:
- TileLang block sparse FMHA implementation
- Triton block sparse FMHA implementation
- PyTorch reference block sparse FMHA implementation
- FlashAttention dense FMHA reference implementation

The benchmarks include:
- Configurable benchmark parameters (batch size, heads, sequence length, etc.)
- Sparse mask generation using top-k and threshold methods
- Performance measurement for different sparse attention configurations
- Utility functions for mask generation and benchmarking

* Refactor block sparse attention benchmarks with code style improvements

- Add Ruff linter ignore comments to benchmark files
- Improve code formatting and line breaks
- Remove unused imports
- Standardize print statement formatting
- Enhance code readability across multiple library benchmarks

* lint fix

* Add CUDA atomic operations for BFLOAT16 and update function naming

- Implement AtomicAdd functions for BFLOAT16 and BFLOAT16x2 in CUDA common header
- Rename existing atomic add functions to use PascalCase (atomicAdd -> AtomicAdd)
- Add a new __pack_nv_bfloat162 function for packing BFLOAT16 values
- Update kernel and language customization to use new function names
- Add return type annotations in profiler module

* lint fix

* Add example for Group Query Attention (GQA) forward pass using Flash Attention in TileLang

This commit introduces a new example script `example_gqa_fwd_bshd.py` that demonstrates:
- Group Query Attention (GQA) implementation
- Flash Attention forward pass
- Performance benchmarking
- Configurable parameters for batch, heads, sequence length, and dimension
- Autotuning support
- Reference implementation comparison

* Refactor IR lowering pipeline into modular phases

This commit introduces a new module `phase.py` to modularize the IR lowering process by splitting the complex lowering pipeline into two distinct phases:
- `LowerAndLegalize`: Handles initial IR legalization and transformation
- `OptimizeForTarget`: Applies target-specific optimizations

The changes simplify the lowering logic in multiple files by extracting the transformation steps into reusable functions, improving code readability and maintainability.

* lintfix

* nas kernel

* Enhance Native Sparse Attention Examples with Code Improvements and Parameter Updates

- Updated example_tilelang_nsa.py and example_triton_nsa.py with code formatting and style improvements
- Increased default number of heads and selected blocks in TileLang NSA example
- Added Ruff linter ignore comments to reference.py
- Standardized function signatures and improved code readability across NSA implementations

* Add utility math functions for integer operations

- Implement `next_power_of_2()` to calculate the next power of 2 for an integer
- Add `cdiv()` function for ceiling division of integers

* Add utility math functions for integer operations

- Implement `next_power_of_2()` to calculate the next power of 2 for an integer
- Add `cdiv()` function for ceiling division of integers

* Refactor DeepSeek MLA Decode Example with Enhanced Flash Attention Implementation

- Update flash attention kernel to support positional embeddings (PE)
- Modify reference implementation to handle PE and group query attention
- Increase default batch size and adjust benchmarking parameters
- Improve kernel performance and readability
- Add einops and torch operations for more flexible tensor manipulation

* Update README.md with corrected Flash MLA Decoding example path

- Modify the example link for Flash MLA Decoding to point to the correct directory
- Ensure accurate navigation to the DeepSeek MLA decoding example

* Refactor Native Sparse Attention Kernel and Improve Utility Functions

This commit introduces several improvements:
- Simplified native sparse attention kernel by inlining macro functions in example_tilelang_nsa.py
- Enhanced error handling in loop_partition.cc with more informative error messages
- Updated print.py to support multi-dimensional buffer printing
- Improved torch_assert_close in testing/__init__.py with more detailed mismatch reporting
- Reduced default absolute tolerance in torch comparison from 1e-3 to 1e-2
- Added shape validation and detailed mismatch information in tensor comparison

* Refactor Code Formatting and Improve Utility Functions

This commit introduces several code formatting and utility improvements:
- Add Ruff linter ignore comment in example_tilelang_nsa.py
- Enhance code readability in loop_partition.cc and lower_tile_op.cc with improved line breaks
- Simplify print_flat_buffer_with_condition in print.py
- Refactor torch_assert_close in testing/__init__.py with improved line formatting
parent 0d873fcf
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ruff: noqa
import torch
from reference import naive_nsa
import tilelang
......@@ -40,77 +40,6 @@ def native_sparse_attention(batch,
def kernel_func(block_S, block_T, num_stages, threads):
@T.macro
def MMA0(
K: T.Buffer(kv_shape, dtype),
Q_shared: T.Buffer([G, BK], dtype),
K_shared: T.Buffer([BS, BK], dtype),
acc_s: T.Buffer([G, BS], accum_dtype),
i_s: T.int32,
i_b: T.int32,
i_h: T.int32,
i_t: T.int32,
):
T.copy(K[i_b, i_s * BS:(i_s + 1) * BS, i_h, :], K_shared)
if is_causal:
for i, j in T.Parallel(G, BS):
acc_s[i, j] = T.if_then_else(i_t >= (i_s * BS + j), 0, -T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def MMA1(
V: T.Buffer(kv_shape, dtype),
V_shared: T.Buffer([G, BV], dtype),
acc_s_cast: T.Buffer([G, BS], dtype),
acc_o: T.Buffer([G, BV], accum_dtype),
i_s: T.int32,
i_b: T.int32,
i_h: T.int32,
):
T.copy(V[i_b, i_s * BS:(i_s + 1) * BS, i_h, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def Softmax(
acc_s: T.Buffer([G, BS], accum_dtype),
acc_s_cast: T.Buffer([G, BS], dtype),
scores_max: T.Buffer([G], accum_dtype),
scores_max_prev: T.Buffer([G], accum_dtype),
scores_scale: T.Buffer([G], accum_dtype),
scores_sum: T.Buffer([G], accum_dtype),
logsum: T.Buffer([G], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=True)
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
# This process is called Check_inf in FlashAttention3 code, and it only need to be done
# in the first ceil_div(kBlockM, kBlockN) steps.
# for i in T.Parallel(block_M):
# scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
for i in T.Parallel(G):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(G, BS):
# Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
# max * log_2(e)) This allows the compiler to use the ffma
# instruction instead of fadd and fmul separately.
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(G):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
@T.macro
def Rescale(
acc_o: T.Buffer([G, BV], accum_dtype),
scores_scale: T.Buffer([G], accum_dtype),
):
for i, j in T.Parallel(G, BV):
acc_o[i, j] *= scores_scale[i]
@T.prim_func
def main(
Q: T.Buffer(q_shape, dtype),
......@@ -147,11 +76,51 @@ def native_sparse_attention(batch,
for i in T.Pipelined(NS, num_stages=num_stages):
i_s = BlockIndices[i_b, i_t, i_h, i]
if i_s <= i_t:
MMA0(K, Q_shared, K_shared, acc_s, i_s, i_b, i_h, i_t)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale,
scores_sum, logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, i_s, i_b, i_h)
# Q * K
T.copy(K[i_b, i_s * BS:(i_s + 1) * BS, i_h, :], K_shared)
if is_causal:
for i, j in T.Parallel(G, BS):
acc_s[i, j] = T.if_then_else(i_t >= (i_s * BS + j), 0,
-T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(
Q_shared,
K_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullRow)
# Softmax
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=True)
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
# This process is called Check_inf in FlashAttention3 code, and it only need to be done
# in the first ceil_div(kBlockM, kBlockN) steps.
# for i in T.Parallel(block_M):
# scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
for i in T.Parallel(G):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale -
scores_max[i] * scale)
for i, j in T.Parallel(G, BS):
# Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
# max * log_2(e)) This allows the compiler to use the ffma
# instruction instead of fadd and fmul separately.
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(G):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
# Rescale
for i, j in T.Parallel(G, BV):
acc_o[i, j] *= scores_scale[i]
# V * softmax(Q * K)
T.copy(V[i_b, i_s * BS:(i_s + 1) * BS, i_h, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(G, BV):
acc_o[i, j] /= logsum[i]
......
......@@ -125,7 +125,9 @@ public:
PrimExpr flattened = 0;
for (size_t i = 0; i < loop_vars_.size(); i++) {
auto ext_ptr = as_const_int(loop_vars_[i]->dom->extent);
ICHECK(ext_ptr);
ICHECK(ext_ptr)
<< "Loop partitioner only works with constant loop sizes, but got "
<< loop_vars_[i]->dom->extent;
int extent = *ext_ptr;
loop_size_full *= extent;
flattened = flattened * extent + loop_vars_[i]->var;
......
......@@ -30,6 +30,7 @@
#include "../layout/layout.h"
#include "../layout/utils.h"
#include "../op/op.h"
#include "arith/ir_mutator_with_analyzer.h"
#include "loop_partition.h"
......@@ -205,9 +206,13 @@ private:
}
PrimExpr VisitExpr_(const tir::CallNode *op) final {
if (!op->op.same_as(builtin::ptx_ldmatrix()) &&
!op->op.same_as(builtin::mma_store())) {
return Downcast<Call>(IRMutatorWithAnalyzer::VisitExpr_(op));
Array<RelayExpr> ptx_instructions = {builtin::ptx_ldmatrix(),
builtin::mma_store()};
if (std::find(ptx_instructions.begin(), ptx_instructions.end(), op->op) ==
ptx_instructions.end()) {
auto call = Downcast<Call>(IRMutatorWithAnalyzer::VisitExpr_(op));
return call;
} else {
is_ptx_ = true;
}
......@@ -252,6 +257,7 @@ private:
if (is_ptx_) {
return load;
}
if (buffer_remap_.count(load->buffer)) {
auto new_indices = layout_map_[load->buffer]->Forward(load->indices);
auto new_buffer = buffer_remap_[load->buffer];
......
......@@ -9,6 +9,7 @@ from tvm import tir
from typing import Any
from tilelang.language.kernel import get_thread_bindings
from tilelang.language import macro, serial
from tilelang.intrinsics.utils import index_to_coordinates
@macro
......@@ -62,7 +63,9 @@ def print_flat_buffer_with_condition(condition: tir.PrimExpr,
if condition:
# Iterate through the buffer elements and print each one.
for i in serial(elems):
tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, buffer[i])
coords = index_to_coordinates(i, buffer.shape)
tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i,
buffer[coords])
def print(obj: Any, msg: str = "") -> tir.PrimExpr:
......@@ -88,13 +91,14 @@ def print(obj: Any, msg: str = "") -> tir.PrimExpr:
tx, ty, tz = get_thread_bindings()
# Flatten the buffer for consistent printing. This assumes a 1D flattened buffer.
buffer = obj.get_flattened_buffer()
buffer = obj
if buffer.scope() == "local.fragment":
raise NotImplementedError("Printing fragment buffers currently is not supported.")
assert len(buffer.shape) == 1, "Buffer must be flattened into a 1D shape."
# Get the number of elements in the buffer.
elems = buffer.shape[-1]
elems = 1
for dim in buffer.shape:
elems *= dim
# Ensure only the first thread (tx=0, ty=0, tz=0) executes the print.
condition = (tx == 0 and ty == 0 and tz == 0)
......
......@@ -18,7 +18,7 @@ def main():
def torch_assert_close(tensor_a,
tensor_b,
rtol=1e-2,
atol=1e-3,
atol=1e-2,
max_mismatched_ratio=0.001,
verbose=False):
"""
......@@ -46,6 +46,9 @@ def torch_assert_close(tensor_a,
"""
import torch
# Assert shapes are the same
assert tensor_a.shape == tensor_b.shape, f"Tensor shapes must be the same, but got {tensor_a.shape} and {tensor_b.shape}"
# Compute the absolute difference between the two tensors
diff = torch.abs(tensor_a - tensor_b)
......@@ -69,12 +72,29 @@ def torch_assert_close(tensor_a,
print(f"Number of mismatched elements: {num_mismatched} / {total_elements} "
f"(allowed: {max_allowed_mismatched})")
# Check if the number of mismatched elements exceeds the allowed threshold
# If there are mismatched elements, print the first mismatch
if num_mismatched > 0:
# Find the first mismatch index
flat_idx = torch.argmax(mismatched.view(-1).int()).item()
idx = np.unravel_index(flat_idx, tensor_a.shape)
idx = [int(i) for i in idx]
a_val = tensor_a.view(-1)[flat_idx].item()
b_val = tensor_b.view(-1)[flat_idx].item()
abs_diff = abs(a_val - b_val)
rel_diff = abs_diff / (abs(b_val) + 1e-12)
mismatch_info = (f"\nFirst mismatch at index {idx}: "
f"lhs={a_val:.6f}, rhs={b_val:.6f}, "
f"abs_diff={abs_diff:.6f}, rel_diff={rel_diff:.6f}")
else:
mismatch_info = ""
# Modify the exception information
if num_mismatched > max_allowed_mismatched:
raise AssertionError(
f"Too many mismatched elements: {num_mismatched} > {max_allowed_mismatched} "
f"({max_mismatched_ratio * 100:.2f}% allowed, but get {num_mismatched / total_elements * 100:.2f}%). "
f"Greatest absolute difference: {diff.max().item()}, "
f"({max_mismatched_ratio * 100:.2f}% allowed, but get {num_mismatched / total_elements * 100:.2f}%)."
f"{mismatch_info}"
f"\nGreatest absolute difference: {diff.max().item()}, "
f"Greatest relative difference: {(diff / (torch.abs(tensor_b) + 1e-12)).max().item()}.")
else:
return True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment