Commit e2d25ba8 authored by Lei Wang, committed by LeiWang1999

[Warp Specialize] Implicit Warp Specialization Programming Model (#605)

* [Enhancement] Improve memory access condition checks in GlobalMemChecker

- Updated the condition checks in GlobalMemChecker to use symbolic bounds in the `CanProve` calls, improving the accuracy of memory access validation.
- Both the upper-bound and lower-bound conditions are now evaluated with the stronger proof strength, making the memory access analysis more robust.

* lintfix

* [Enhancement] Add legality checks for shared memory and global range in LowerBulkCopy

- Implemented checks to ensure that the shared memory range and global range are legal during the bulk copy operation.
- Added assertions to validate that the extents of global and shared ranges match, improving the robustness of memory access validation in the LowerBulkCopy function.

* [Refactor] Update barrier and clear operations in warp specialization examples

- Replaced `mbarrier_wait_parity` and `mbarrier_arrive` with `barrier_wait` and `barrier_arrive` for clearer, more consistent synchronization; see the sketch after this list.
- Reordered the `clear` operations on local fragments in `example_warp_specialize_gemm_copy_1_gemm_0` so that each warp group clears its own fragment.
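For reference, a minimal sketch of the new barrier API in a warp-specialized GEMM, pieced together from the updated warp-specialization examples in the diff below. The `matmul_ws` wrapper and the `T.Kernel` grid setup are assumptions added for self-containedness; the barrier allocation, the `barrier_wait`/`barrier_arrive` parity scheme, and the `T.ws` warp-group split follow the examples.

```python
import tilelang.language as T


def matmul_ws(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            # Each barrier is armed by one 128-thread warp group per phase.
            data_is_ready = T.alloc_barrier(arrive_count=128)
            compute_is_done = T.alloc_barrier(arrive_count=128)
            with T.ws(1):
                T.clear(C_local)
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0):
                with T.ws(0):  # producer warp group: load tiles into shared memory
                    T.barrier_wait(compute_is_done, (ko + 1) % 2)
                    T.copy(A[by * block_M, ko * block_K], A_shared)
                    T.copy(B[ko * block_K, bx * block_N], B_shared)
                    T.barrier_arrive(data_is_ready)
                with T.ws(1):  # consumer warp group: compute on the loaded tiles
                    T.barrier_wait(data_is_ready, ko % 2)
                    T.gemm(A_shared, B_shared, C_local)
                    T.barrier_arrive(compute_is_done)
            with T.ws(1):
                T.copy(C_local, C[by * block_M, bx * block_N])

    return main
```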

* [Enhancement] Implement thread partial synchronization and improve shared memory allocation handling

- Added support for thread-partial barrier synchronization in CUDA (named barriers via `bar.sync`), allowing subsets of the thread block to synchronize.
- Extended `MergeSharedMemoryAllocations` to accept an alignment in bytes so merged allocations can meet target-specific requirements; see the sketch after this list.
- Updated the `Lower` methods of the `Copy` and `Fill` operators to wrap the generated loop in the thread predicate when one is defined, giving better control over which threads execute it.
- Refactored the `print` helper to include warp group and warp IDs for more detailed debugging output.
- Improved handling of dynamic shared memory allocations in `LowerAndLegalize` to match target-specific alignment requirements.
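A minimal sketch of the alignment-aware merge pass, assuming it is invoked the same way as in the updated `phase.py`; the `merge_shared_memory` wrapper and the `target_has_tma` flag are placeholders for the pipeline's own `have_tma(target)` check.

```python
import tilelang


def merge_shared_memory(mod, target_has_tma: bool, enable_aggressive_merge: bool = False):
    # Hopper swizzling requires dynamic shared memory addresses to be aligned
    # to 1024 bytes; other targets keep the 16-byte default.
    smem_align_bytes = 1024 if target_has_tma else 16
    return tilelang.transform.MergeSharedMemoryAllocations(
        enable_aggressive_merge=enable_aggressive_merge,
        align_bytes=smem_align_bytes)(mod)
```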

* [Enhancement] Add support for disabling TMA in Copy operations

- Introduced a `disable_tma` parameter on the `Copy` operator to control whether the bulk-copy (TMA) path is used.
- Updated the `Lower` method to attempt the bulk copy lowering only when `disable_tma` is false.
- Extended the `copy` function to accept the `disable_tma` argument for more flexible memory copies; see the sketch after this list.
- `coalesced_width` now defaults to -1 when not provided, and non-positive values are ignored, making the copy lowering more robust.
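A hedged sketch of the new flag, assuming the Python-level `T.copy` accepts `disable_tma` as a keyword argument; the `copy_kernel` wrapper is illustrative only.

```python
import tilelang.language as T


def copy_kernel(M, N, block_M, block_N, dtype="float16"):

    @T.prim_func
    def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_N), dtype)
            # disable_tma=True skips the bulk-copy (TMA) lowering in Copy::Lower
            # and falls back to a vectorized SIMT loop, even on TMA-capable targets.
            T.copy(A[by * block_M, bx * block_N], A_shared, disable_tma=True)
            # Default: the TMA path is still attempted when the target supports it.
            T.copy(A_shared, B[by * block_M, bx * block_N])

    return main
```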

* [Refactor] Clean up whitespace and formatting in multiple files

- Removed unnecessary blank lines and adjusted line breaks for improved code readability in `example_mla_decode.py`, `example_warp_specialize_gemm_copy_gemm_0_1.py`, `phase.py`, and `copy.py`.
- Ensured consistent formatting across functions to enhance maintainability and clarity of the codebase.

* [Enhancement] Refactor flash attention implementation for improved performance and configurability

- Split the shared-memory allocations for the query and key/value tiles into halves along the head dimension so each warp group streams and accumulates its own half; see the sketch after this list.
- Introduced command-line arguments for batch size, number of heads, and head dimensions, making the example easier to configure.
- Updated kernel launch parameters to improve thread management and synchronization.
- Restructured the flash attention function for better readability and maintainability.
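A small, illustrative sketch of the split-allocation pattern; the buffer names and `h_dim = dim // 2` follow the updated example in the diff below, while the `split_q_load` wrapper exists only to make the snippet self-contained.

```python
import tilelang.language as T


def split_q_load(batch, heads, dim, block_H, dtype="float16"):
    h_dim = dim // 2  # split the head dimension in half

    @T.prim_func
    def main(Q: T.Tensor((batch, heads, dim), dtype)):
        with T.Kernel(heads // block_H, batch, threads=256) as (hid, bid):
            # Two half-width buffers instead of one [block_H, dim] buffer, so the
            # two 128-thread warp groups can each work on their own half.
            Q_shared_l = T.alloc_shared([block_H, h_dim], dtype)
            Q_shared_r = T.alloc_shared([block_H, h_dim], dtype)
            T.copy(Q[bid, hid * block_H:(hid + 1) * block_H, :h_dim], Q_shared_l)
            T.copy(Q[bid, hid * block_H:(hid + 1) * block_H, h_dim:], Q_shared_r)

    return main
```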

* fix

* Update layout inference in ParallelOp to account for thread bounds; remove debug print in OptimizeForTarget

* Refactor barrier handling and update example configurations

- Replaced commented-out barrier creation with new barrier allocation in GEMM example.
- Updated kernel configuration in warp specialization example to include async copy settings.
- Enhanced barrier management in the phase optimization process to improve synchronization handling.
- Introduced new barrier allocation function for better memory management in shared contexts.

* Refactor barrier handling in LowerAndLegalize and OptimizeForTarget

- Reintroduced barrier lowering in OptimizeForTarget to enhance synchronization.
- Removed commented-out barrier lowering in LowerAndLegalize for cleaner code.
- Added exit() call in OptimizeForTarget to halt execution after barrier lowering.

* Enhance CMake configuration and clean up example scripts

- Enabled compile command export in CMakeLists.txt for better build integration.
- Removed unnecessary print statement in the warp specialization example.
- Cleaned up commented-out code in GEMM example for improved readability.
- Updated barrier handling in shared memory allocation transformations for better synchronization.

* Refactor barrier handling in warp specialization examples

- Replaced commented-out mbarrier code with new barrier allocation using T.alloc_barrier for improved synchronization.
- Updated barrier wait and arrive calls to align with the new allocation method across multiple example scripts.
- Enhanced code readability by removing unnecessary comments and ensuring consistent barrier management.

* Update lower_shared_barrier.cc

* Update phase.py

* Update warp specialization example and Cython wrapper

- Removed commented-out pass configuration options in the warp specialization example for clarity.
- Added functionality to write the generated kernel source to a file named "kernel.cu".
- Enhanced Cython wrapper to support boolean type conversion for improved type handling.

* Add storage synchronization call in shared barrier transformation

- Added an `Evaluate` statement that calls TVM's storage sync intrinsic with "shared" after the barrier-initialization branch in the shared barrier lowering, ensuring all threads synchronize before the barriers are used.

* remove debug files

* Remove kernel source output to file in warp specialization example

* remove comments

* Refactor tensor handling and update test execution in TileLang

- Changed `Buffer` to `Tensor` in `customize.py` for better type consistency.
- Updated `mbarrier_wait_parity` and `mbarrier_arrive` functions in `builtin.py` to use `tir.BufferLoad` instead of `BufferLoad`.
- Commented out the main testing function in `test_tilelang_language_reshape.py` and replaced it with a direct call to `run_reshape_smem` for streamlined testing.
- Removed unnecessary NVCC compiler flags in `libgen.py` to reduce verbosity.

* Update test_tilelang_language_reshape.py
parent 68989d80
......@@ -9,6 +9,9 @@ if(NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type")
endif()
# Enable compile command export
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
# Define a custom macro for globbing files with conditional CONFIGURE_DEPENDS
if(${CMAKE_VERSION} VERSION_GREATER_EQUAL "3.12.0")
macro(tilelang_file_glob glob variable)
......
......@@ -24,7 +24,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
):
with T.Kernel(batch, heads // min(block_H, kv_group_num), threads=256) as (bx, by):
with T.Kernel(heads // min(block_H, kv_group_num), batch, threads=256) as (hid, bid):
Q_shared = T.alloc_shared([block_H, dim], dtype)
S_shared = T.alloc_shared([block_H, block_N], dtype)
Q_pe_shared = T.alloc_shared([block_H, pe_dim], dtype)
......@@ -39,25 +39,28 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
scores_sum = T.alloc_fragment([block_H], accum_dtype)
logsum = T.alloc_fragment([block_H], accum_dtype)
cur_kv_head = by // (kv_group_num // block_H)
T.use_swizzle(10)
cur_kv_head = hid // (kv_group_num // block_H)
T.annotate_layout({
O_shared: tilelang.layout.make_swizzled_layout(O_shared),
})
T.copy(Q[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_shared)
T.copy(Q_pe[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_pe_shared)
T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_shared)
T.copy(Q_pe[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_pe_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = T.ceildiv(seqlen_kv, block_N)
for k in T.Pipelined(loop_range, num_stages=2):
T.copy(KV[bx, k * block_N:(k + 1) * block_N, cur_kv_head, :], KV_shared)
T.copy(K_pe[bx, k * block_N:(k + 1) * block_N, cur_kv_head, :], K_pe_shared)
T.clear(acc_s)
T.copy(KV[bid, k * block_N:(k + 1) * block_N, cur_kv_head, :], KV_shared)
T.copy(K_pe[bid, k * block_N:(k + 1) * block_N, cur_kv_head, :], K_pe_shared)
T.gemm(
Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
Q_shared,
KV_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullCol,
clear_accum=True)
T.gemm(
Q_pe_shared,
K_pe_shared,
......@@ -81,7 +84,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :])
T.copy(O_shared, Output[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :])
@T.macro
def flash_attn_split(
......@@ -93,7 +96,8 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
):
with T.Kernel(
batch, heads // min(block_H, kv_group_num), num_split, threads=256) as (bx, by, bz):
batch, heads // min(block_H, kv_group_num), num_split,
threads=256) as (bid, hid, bz):
Q_shared = T.alloc_shared([block_H, dim], dtype)
S_shared = T.alloc_shared([block_H, block_N], dtype)
Q_pe_shared = T.alloc_shared([block_H, pe_dim], dtype)
......@@ -109,15 +113,15 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
scores_sum = T.alloc_fragment([block_H], accum_dtype)
logsum = T.alloc_fragment([block_H], accum_dtype)
cur_kv_head = by // (kv_group_num // block_H)
cur_kv_head = hid // (kv_group_num // block_H)
T.use_swizzle(10)
T.annotate_layout({
O_shared: tilelang.layout.make_swizzled_layout(O_shared),
S_shared: tilelang.layout.make_swizzled_layout(S_shared),
})
T.copy(Q[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_shared)
T.copy(Q_pe[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_pe_shared)
T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_shared)
T.copy(Q_pe[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_pe_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
......@@ -126,8 +130,8 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
for k in T.Pipelined(loop_range, num_stages=2):
kv_start = (seqlen_kv // num_split) * bz + k * block_N
kv_end = (seqlen_kv // num_split) * bz + (k + 1) * block_N
T.copy(KV[bx, kv_start:kv_end, cur_kv_head, :], KV_shared)
T.copy(K_pe[bx, kv_start:kv_end, cur_kv_head, :], K_pe_shared)
T.copy(KV[bid, kv_start:kv_end, cur_kv_head, :], KV_shared)
T.copy(K_pe[bid, kv_start:kv_end, cur_kv_head, :], K_pe_shared)
T.clear(acc_s)
T.gemm(
Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
......@@ -156,9 +160,10 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
acc_o[i, j] /= logsum[i]
for i in T.Parallel(block_H):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(logsum, glse[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, bz])
T.copy(logsum, glse[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, bz])
T.copy(acc_o, O_shared)
T.copy(O_shared, Output_partial[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, bz, :])
T.copy(O_shared, Output_partial[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H,
bz, :])
@T.macro
def combine(
......@@ -166,7 +171,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
):
with T.Kernel(heads, batch, threads=128) as (by, bz):
with T.Kernel(heads, batch, threads=128) as (hid, bz):
po_local = T.alloc_fragment([dim], dtype)
o_accum_local = T.alloc_fragment([dim], accum_dtype)
lse_local_split = T.alloc_local([1], accum_dtype)
......@@ -182,20 +187,20 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
T.clear(o_accum_local)
lse_max_local[0] = -T.infinity(accum_dtype)
for k in T.serial(num_split):
lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k])
lse_max_local[0] = T.max(lse_max_local[0], glse[bz, hid, k])
for k in T.Pipelined(num_split, num_stages=1):
lse_local_split[0] = glse[bz, by, k]
lse_local_split[0] = glse[bz, hid, k]
lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0])
lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0]
for k in T.serial(num_split):
for i in T.Parallel(dim):
po_local[i] = Output_partial[bz, by, k, i]
lse_local_split[0] = glse[bz, by, k]
po_local[i] = Output_partial[bz, hid, k, i]
lse_local_split[0] = glse[bz, hid, k]
scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0])
for i in T.Parallel(dim):
o_accum_local[i] += po_local[i] * scale_local[0]
for i in T.Parallel(dim):
Output[bz, by, i] = o_accum_local[i]
Output[bz, hid, i] = o_accum_local[i]
@T.prim_func
def main_split(
......@@ -273,7 +278,7 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=128, help='batch size')
parser.add_argument('--batch', type=int, default=132, help='batch size')
parser.add_argument('--heads', type=int, default=128, help='q heads number')
parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number')
parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length')
......@@ -290,7 +295,7 @@ def main():
kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
profiler.assert_allclose(ref_program, rtol=1e-4, atol=1e-4)
latency = profiler.do_bench(warmup=500)
print(f"Latency: {latency} ms")
print(f"TFlops: {total_flops / latency * 1e-9} TFlops")
......
# use default stage 1 template, not the optimal
# schedule, please checkout examples/deepseek_mla
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
from einops import rearrange, einsum
import argparse
tilelang.disable_cache()
@tilelang.jit(out_idx=[6])
def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split):
scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504 # log2(e)
dtype = "float16"
......@@ -15,6 +16,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
kv_group_num = heads // kv_head_num
VALID_BLOCK_H = min(block_H, kv_group_num)
assert kv_head_num == 1, "kv_head_num must be 1"
h_dim = dim // 2
@T.macro
def flash_attn(
......@@ -24,81 +26,312 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
):
with T.Kernel(batch, heads // min(block_H, kv_group_num), threads=384) as (bx, by):
Q_shared = T.alloc_shared([block_H, dim], dtype)
S_shared = T.alloc_shared([block_H, block_N], dtype)
with T.Kernel(heads // min(block_H, kv_group_num), batch, threads=256) as (hid, bid):
Q_shared_l = T.alloc_shared([block_H, h_dim], dtype)
Q_shared_r = T.alloc_shared([block_H, h_dim], dtype)
Q_pe_shared = T.alloc_shared([block_H, pe_dim], dtype)
KV_shared = T.alloc_shared([block_N, dim], dtype)
K_pe_shared = T.alloc_shared([block_N, pe_dim], dtype)
O_shared = T.alloc_shared([block_H, dim], dtype)
acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
acc_o = T.alloc_fragment([block_H, dim], accum_dtype)
scores_max = T.alloc_fragment([block_H], accum_dtype)
scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
scores_scale = T.alloc_fragment([block_H], accum_dtype)
scores_sum = T.alloc_fragment([block_H], accum_dtype)
logsum = T.alloc_fragment([block_H], accum_dtype)
cur_kv_head = by // (kv_group_num // block_H)
T.use_swizzle(10)
KV_shared_0_l = T.alloc_shared([block_N, h_dim], dtype)
KV_shared_0_r = T.alloc_shared([block_N, h_dim], dtype)
KV_shared_1_l = T.alloc_shared([block_N, h_dim], dtype)
KV_shared_1_r = T.alloc_shared([block_N, h_dim], dtype)
K_pe_shared_0 = T.alloc_shared([block_N, pe_dim], dtype)
K_pe_shared_1 = T.alloc_shared([block_N, pe_dim], dtype)
O_shared_l = Q_shared_l
O_shared_r = Q_shared_r
S_shared = K_pe_shared_0
S_shared_ = K_pe_shared_1
acc_s_0 = T.alloc_fragment([block_H, block_N], accum_dtype)
acc_s_1 = T.alloc_fragment([block_H, block_N], accum_dtype)
acc_o_l = T.alloc_fragment([block_H, h_dim], accum_dtype)
acc_o_r = T.alloc_fragment([block_H, h_dim], accum_dtype)
scores_max_0 = T.alloc_fragment([block_H], accum_dtype)
scores_max_1 = T.alloc_fragment([block_H], accum_dtype)
scores_max = T.alloc_shared([block_H], accum_dtype)
scores_max_prev_0 = T.alloc_fragment([block_H], accum_dtype)
scores_max_prev_1 = T.alloc_fragment([block_H], accum_dtype)
scores_scale_0 = T.alloc_shared([block_H], accum_dtype)
scores_scale_1 = T.alloc_shared([block_H], accum_dtype)
scores_sum_0 = T.alloc_fragment([block_H], accum_dtype)
scores_sum_1 = T.alloc_fragment([block_H], accum_dtype)
logsum_0 = T.alloc_fragment([block_H], accum_dtype)
logsum_1 = T.alloc_fragment([block_H], accum_dtype)
logsum = T.alloc_shared([block_H], accum_dtype)
cur_kv_head = hid // (kv_group_num // block_H)
T.annotate_layout({
O_shared: tilelang.layout.make_swizzled_layout(O_shared),
O_shared_l: tilelang.layout.make_swizzled_layout(O_shared_l),
O_shared_r: tilelang.layout.make_swizzled_layout(O_shared_r),
})
T.create_list_of_mbarrier(128, 128, 256, 128)
loop_range = T.ceildiv(seqlen_kv, block_N)
with T.ws(2):
T.dec_max_nreg(24)
T.copy(Q[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_shared)
T.copy(Q_pe[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_pe_shared)
T.barrier_arrive(barrier_id=3)
kv_shared_0_l_is_ready = T.alloc_barrier(arrive_count=128)
kv_shared_0_r_is_ready = T.alloc_barrier(arrive_count=128)
kv_shared_0_pe_is_ready = T.alloc_barrier(arrive_count=128)
kv_shared_1_l_is_ready = T.alloc_barrier(arrive_count=128)
kv_shared_1_r_is_ready = T.alloc_barrier(arrive_count=128)
kv_shared_1_pe_is_ready = T.alloc_barrier(arrive_count=128)
score_max_0_ready_barrier = T.alloc_barrier(arrive_count=128)
scale_1_ready_barrier = T.alloc_barrier(arrive_count=128)
p0_1_1_ready_barrier = T.alloc_barrier(arrive_count=128)
lse_0_ready_barrier = T.alloc_barrier(arrive_count=128)
lse_1_ready_barrier = T.alloc_barrier(arrive_count=128)
s_shared_ready_barrier = T.alloc_barrier(arrive_count=128)
q_shared_ready_barrier = T.alloc_barrier(arrive_count=256)
k_pe_shared_1_free_barrier = T.alloc_barrier(arrive_count=128)
k_pe_shared_0_free_barrier = T.alloc_barrier(arrive_count=128)
s_shared_ready_barrier = T.alloc_barrier(arrive_count=128)
k_shared_1_l_free_barrier = T.alloc_barrier(arrive_count=128)
tx = T.get_thread_binding()
T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :h_dim], Q_shared_l)
T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, h_dim:], Q_shared_r)
T.copy(Q_pe[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_pe_shared)
T.barrier_arrive(q_shared_ready_barrier)
T.barrier_wait(q_shared_ready_barrier, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = T.ceildiv(seqlen_kv, (block_N * 2))
if tx < 128:
T.fill(acc_o_l, 0)
T.fill(logsum_0, 0)
T.copy(KV[bid, block_N:2 * block_N, cur_kv_head, :h_dim], KV_shared_1_l)
T.barrier_arrive(kv_shared_1_l_is_ready)
T.copy(KV[bid, block_N:2 * block_N, cur_kv_head, h_dim:], KV_shared_1_r)
T.barrier_arrive(kv_shared_1_r_is_ready)
T.copy(K_pe[bid, block_N:2 * block_N, cur_kv_head, :], K_pe_shared_1)
T.barrier_arrive(kv_shared_1_pe_is_ready)
for k in T.serial(loop_range):
T.barrier_wait(barrier_id=(k % 1) + 2, parity=(k % 2) ^ 1)
T.copy(KV[bx, k * block_N:(k + 1) * block_N, cur_kv_head, :], KV_shared)
T.barrier_arrive(k % 1)
T.copy(K_pe[bx, k * block_N:(k + 1) * block_N, cur_kv_head, :], K_pe_shared)
T.barrier_arrive(k % 1 + 1)
with T.ws(0, 1):
T.inc_max_nreg(240)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
T.barrier_wait(3, 0)
T.barrier_wait(kv_shared_0_l_is_ready, k % 2)
T.gemm(
Q_shared_l,
KV_shared_0_l,
acc_s_0,
transpose_B=True,
policy=T.GemmWarpPolicy.FullCol,
clear_accum=True,
wg_wait=-1)
T.barrier_wait(kv_shared_0_r_is_ready, k % 2)
T.gemm(
Q_shared_r,
KV_shared_0_r,
acc_s_0,
transpose_B=True,
policy=T.GemmWarpPolicy.FullCol,
wg_wait=-1)
T.barrier_wait(kv_shared_0_pe_is_ready, k % 2)
T.gemm(
Q_pe_shared,
K_pe_shared_0,
acc_s_0,
transpose_B=True,
policy=T.GemmWarpPolicy.FullCol,
wg_wait=-1)
T.wait_wgmma(0)
# Step 3.
T.copy(scores_max, scores_max_0)
T.copy(scores_max_0, scores_max_prev_0)
T.fill(scores_max_0, -T.infinity(accum_dtype))
T.reduce_max(acc_s_0, scores_max_0, dim=1, clear=False)
T.copy(scores_max_0, scores_max)
# Step 4.
for i, j in T.Parallel(block_H, block_N):
acc_s_0[i, j] = T.exp2(acc_s_0[i, j] * scale - scores_max[i] * scale)
for i in T.Parallel(block_H):
scores_scale_0[i] = T.exp2(scores_max_prev_0[i] * scale -
scores_max[i] * scale)
T.reduce_sum(acc_s_0, scores_sum_0, dim=1)
# Step 5.
T.copy(acc_s_0, S_shared)
for i, j in T.Parallel(block_H, h_dim):
acc_o_l[i, j] *= scores_scale_0[i]
for i in T.Parallel(block_H):
logsum_0[i] = logsum_0[i] * scores_scale_0[i] + scores_sum_0[i]
# Step 6.
T.gemm(S_shared, KV_shared_0_l, acc_o_l, policy=T.GemmWarpPolicy.FullCol)
T.barrier_arrive(score_max_0_ready_barrier)
T.barrier_wait(scale_1_ready_barrier, k % 2)
if k < loop_range - 1:
T.copy(
KV[bid, (2 * k + 2) * block_N:(2 * k + 3) * block_N,
cur_kv_head, :h_dim], KV_shared_0_l)
T.barrier_arrive(kv_shared_0_l_is_ready)
# Step 11.
for i, j in T.Parallel(block_H, block_N):
S_shared_[i, j] = acc_s_0[i, j] * scores_scale_1[i]
T.barrier_arrive(p0_1_1_ready_barrier)
# Step 13.
for i, j in T.Parallel(block_H, h_dim):
acc_o_l[i, j] *= scores_scale_1[i]
for i in T.Parallel(block_H):
logsum_0[i] = logsum_0[i] * scores_scale_1[i]
T.barrier_wait(s_shared_ready_barrier, k % 2)
# Step 14.
T.gemm(S_shared, KV_shared_1_l, acc_o_l, policy=T.GemmWarpPolicy.FullCol)
T.barrier_arrive(k_pe_shared_0_free_barrier)
T.barrier_arrive(k_shared_1_l_free_barrier)
if k < loop_range - 1:
T.barrier_wait(k_shared_1_l_free_barrier, k % 2)
T.copy(
KV[bid, (2 * k + 3) * block_N:(2 * k + 4) * block_N,
cur_kv_head, :h_dim], KV_shared_1_l)
T.barrier_arrive(kv_shared_1_l_is_ready)
T.barrier_wait(k_pe_shared_1_free_barrier, k % 2)
T.copy(
K_pe[bid, (2 * k + 3) * block_N:(2 * k + 4) * block_N, cur_kv_head, :],
K_pe_shared_1)
T.barrier_arrive(kv_shared_1_pe_is_ready)
T.copy(logsum_0, logsum)
T.barrier_arrive(lse_0_ready_barrier)
T.barrier_wait(lse_1_ready_barrier, 0)
for i, j in T.Parallel(block_H, h_dim):
acc_o_l[i, j] /= logsum[i]
T.copy(acc_o_l, O_shared_l)
T.copy(O_shared_l, Output[bid,
hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :h_dim])
else:
T.fill(acc_o_r, 0)
T.fill(logsum_1, 0)
T.copy(KV[bid, :block_N, cur_kv_head, :h_dim], KV_shared_0_l)
T.barrier_arrive(kv_shared_0_l_is_ready)
T.copy(KV[bid, :block_N, cur_kv_head, h_dim:], KV_shared_0_r)
T.barrier_arrive(kv_shared_0_r_is_ready)
T.copy(K_pe[bid, :block_N, cur_kv_head, :], K_pe_shared_0)
T.barrier_arrive(kv_shared_0_pe_is_ready)
for k in T.serial(loop_range):
T.clear(acc_s)
T.barrier_wait(barrier_id=k % 1, parity=(k // 1) % 2)
# Step 2.
T.barrier_wait(kv_shared_1_l_is_ready, k % 2)
T.gemm(
Q_shared,
KV_shared,
acc_s,
Q_shared_l,
KV_shared_1_l,
acc_s_1,
transpose_B=True,
policy=T.GemmWarpPolicy.FullCol)
T.barrier_wait(barrier_id=k % 1 + 1, parity=(k // 1) % 2)
policy=T.GemmWarpPolicy.FullCol,
clear_accum=True,
wg_wait=-1)
T.barrier_wait(kv_shared_1_r_is_ready, k % 2)
T.gemm(
Q_shared_r,
KV_shared_1_r,
acc_s_1,
transpose_B=True,
policy=T.GemmWarpPolicy.FullCol,
wg_wait=-1)
T.barrier_wait(kv_shared_1_pe_is_ready, k % 2)
T.gemm(
Q_pe_shared,
K_pe_shared,
acc_s,
K_pe_shared_1,
acc_s_1,
transpose_B=True,
policy=T.GemmWarpPolicy.FullCol)
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
policy=T.GemmWarpPolicy.FullCol,
wg_wait=-1)
T.wait_wgmma(0)
# Step 7.
T.barrier_wait(score_max_0_ready_barrier, k % 2)
T.copy(scores_max, scores_max_prev_1)
T.fill(scores_max_1, -T.infinity(accum_dtype))
T.reduce_max(acc_s_1, scores_max_1, dim=1, clear=False)
T.copy(scores_max_1, scores_max)
for i in T.Parallel(block_H):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
scores_scale_1[i] = T.exp2(scores_max_prev_1[i] * scale -
scores_max[i] * scale)
# Step 8.
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
T.copy(acc_s, S_shared)
acc_s_1[i, j] = T.exp2(acc_s_1[i, j] * scale - scores_max[i] * scale)
# Step 9.
T.reduce_sum(acc_s_1, scores_sum_1, dim=1)
for i, j in T.Parallel(block_H, h_dim):
acc_o_r[i, j] = acc_o_r[i, j] * (scores_scale_0[i] * scores_scale_1[i])
for i in T.Parallel(block_H):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] *= scores_scale[i]
T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)
T.barrier_arrive(barrier_id=k % 1 + 2)
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :])
logsum_1[i] = logsum_1[i] * scores_scale_1[i] * scores_scale_0[
i] + scores_sum_1[i]
T.barrier_arrive(scale_1_ready_barrier)
# Step 10. compute O1 with KV_shared_1_rd
T.copy(acc_s_1, S_shared)
T.barrier_arrive(s_shared_ready_barrier)
T.gemm(
S_shared,
KV_shared_1_r,
acc_o_r,
policy=T.GemmWarpPolicy.FullCol,
wg_wait=-1)
if k < loop_range - 1:
T.copy(
KV[bid, (2 * k + 3) * block_N:(2 * k + 4) * block_N, cur_kv_head,
h_dim:], KV_shared_1_r)
T.barrier_arrive(kv_shared_1_r_is_ready)
T.barrier_wait(p0_1_1_ready_barrier, k % 2)
# Step 12.
T.gemm(S_shared_, KV_shared_0_r, acc_o_r, policy=T.GemmWarpPolicy.FullCol)
T.barrier_arrive(k_pe_shared_1_free_barrier)
if k < loop_range - 1:
T.copy(
KV[bid, (2 * k + 2) * block_N:(2 * k + 3) * block_N, cur_kv_head,
h_dim:], KV_shared_0_r)
T.barrier_arrive(kv_shared_0_r_is_ready)
T.barrier_wait(k_pe_shared_0_free_barrier, k % 2)
T.copy(
K_pe[bid, (2 * k + 2) * block_N:(2 * k + 3) * block_N, cur_kv_head, :],
K_pe_shared_0)
T.barrier_arrive(kv_shared_0_pe_is_ready)
T.barrier_wait(lse_0_ready_barrier, 0)
for i in T.Parallel(block_H):
logsum[i] += logsum_1[i]
T.barrier_arrive(lse_1_ready_barrier)
for i, j in T.Parallel(block_H, h_dim):
acc_o_r[i, j] /= logsum[i]
T.copy(acc_o_r, O_shared_r)
T.copy(O_shared_r, Output[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H,
h_dim:])
@T.prim_func
def main_no_split(
......@@ -159,13 +392,15 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
def main():
batch = 128
heads = 128
kv_heads = 1
kv_ctx = 8192
dim = 512
pe_dim = 64
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=132, help='batch size')
parser.add_argument('--heads', type=int, default=128, help='q heads number')
parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number')
parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length')
parser.add_argument('--dim', type=int, default=512, help='head dim')
parser.add_argument('--pe_dim', type=int, default=64, help='pe head dim')
args = parser.parse_args()
batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim
qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim)
pv_flops = 2 * batch * heads * kv_ctx * dim
total_flops = qk_flops + pv_flops
......@@ -173,9 +408,8 @@ def main():
BLOCK_H = 64
num_split = 1
kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split)
print(kernel.get_kernel_source())
program = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split)
kernel = tilelang.compile(program, out_idx=[6])
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
latency = profiler.do_bench(warmup=500)
......
......@@ -25,23 +25,22 @@ def matmul_warp_specialize_copy_0_gemm_1(M,
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
# create mbarrier for tma
T.create_list_of_mbarrier(128, 128)
data_is_ready = T.alloc_barrier(arrive_count=128)
compute_is_done = T.alloc_barrier(arrive_count=128)
with T.ws(1):
T.clear(C_local)
for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0):
with T.ws(0):
T.mbarrier_wait_parity(1, (ko & 1) ^ 1)
T.barrier_wait(compute_is_done, (ko + 1) % 2)
T.copy(A[by * block_M, ko * block_K], A_shared)
T.copy(B[ko * block_K, bx * block_N], B_shared)
T.mbarrier_arrive(0)
T.barrier_arrive(data_is_ready)
with T.ws(1):
T.mbarrier_wait_parity(0, ko & 1)
T.barrier_wait(data_is_ready, ko % 2)
T.gemm(A_shared, B_shared, C_local)
T.mbarrier_arrive(1)
T.barrier_arrive(compute_is_done)
with T.ws(1):
T.copy(C_local, C[by * block_M, bx * block_N])
......
......@@ -26,22 +26,22 @@ def matmul_warp_specialize_copy_1_gemm_0(M,
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
# create mbarrier for tma
T.create_list_of_mbarrier(128, 128)
data_is_ready = T.alloc_barrier(arrive_count=128)
compute_is_done = T.alloc_barrier(arrive_count=128)
with T.ws(0):
T.clear(C_local)
for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=2):
with T.ws(1):
T.mbarrier_wait_parity(1, (ko & 1) ^ 1)
T.barrier_wait(compute_is_done, (ko + 1) % 2)
T.copy(A[by * block_M, ko * block_K], A_shared)
T.copy(B[ko * block_K, bx * block_N], B_shared)
T.mbarrier_arrive(0)
T.barrier_arrive(data_is_ready)
with T.ws(0):
T.mbarrier_wait_parity(0, ko & 1)
T.barrier_wait(data_is_ready, ko % 2)
T.gemm(A_shared, B_shared, C_local)
T.mbarrier_arrive(1)
T.barrier_arrive(compute_is_done)
with T.ws(0):
T.copy(C_local, C[by * block_M, bx * block_N])
......
......@@ -39,8 +39,10 @@ def matmul_warp_specialize_copy_1_gemm_0(M,
C_local_g0 = T.alloc_fragment((block_M, block_N // warp_group_num), accum_dtype)
C_local_g1 = T.alloc_fragment((block_M, block_N // warp_group_num), accum_dtype)
T.clear(C_local_g0)
T.clear(C_local_g1)
with T.ws(1):
T.clear(C_local_g1)
with T.ws(0):
T.clear(C_local_g0)
for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0):
T.copy(A[by * block_M, ko * block_K], A_shared)
......@@ -51,8 +53,10 @@ def matmul_warp_specialize_copy_1_gemm_0(M,
T.copy(B[ko * block_K, bx * block_N + block_N // warp_group_num], B_shared_g0)
T.gemm(A_shared, B_shared_g0, C_local_g0)
T.copy(C_local_g1, C[by * block_M, bx * block_N])
T.copy(C_local_g0, C[by * block_M, bx * block_N + block_N // warp_group_num])
with T.ws(1):
T.copy(C_local_g1, C[by * block_M, bx * block_N])
with T.ws(0):
T.copy(C_local_g0, C[by * block_M, bx * block_N + block_N // warp_group_num])
return main
......
......@@ -20,21 +20,22 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
# create mbarrier for tma
T.create_list_of_mbarrier(128, 128)
data_is_ready = T.alloc_barrier(arrive_count=128)
compute_is_done = T.alloc_barrier(arrive_count=128)
with T.ws(0):
T.clear(C_local)
for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=2):
with T.ws(1):
T.mbarrier_wait_parity(1, (ko & 1) ^ 1)
T.barrier_wait(compute_is_done, (ko + 1) % 2)
T.copy(A[by * block_M, ko * block_K], A_shared)
T.copy(B[ko * block_K, bx * block_N], B_shared)
T.mbarrier_arrive(0)
T.barrier_arrive(data_is_ready)
with T.ws(0):
T.mbarrier_wait_parity(0, ko & 1)
T.barrier_wait(data_is_ready, ko % 2)
T.gemm(A_shared, B_shared, C_local)
T.mbarrier_arrive(1)
T.barrier_arrive(compute_is_done)
with T.ws(0):
T.copy(C_local, C[by * block_M, bx * block_N])
......
......@@ -36,7 +36,14 @@ Copy::Copy(Array<PrimExpr> args, BufferMap vmap) : args_(args) {
std::tie(this->src, this->dst) = std::tie(bf[0], bf[1]);
std::tie(this->src_range, this->dst_range) = std::tie(rgs[0], rgs[1]);
if (args.size() >= 3) {
coalesced_width = Downcast<IntImm>(args[2]);
auto coalesced_width = Downcast<IntImm>(args[2]);
if (coalesced_width->value > 0) {
this->coalesced_width = coalesced_width;
}
}
if (args.size() >= 4) {
auto disable_tma = Downcast<Bool>(args[3]);
this->disable_tma = disable_tma;
}
}
......@@ -159,9 +166,12 @@ Stmt Copy::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
if (ldsm_stmt.defined())
return ldsm_stmt;
Stmt bulk_copy_stmt = LowerBulkCopy(T, analyzer);
if (bulk_copy_stmt.defined())
return bulk_copy_stmt;
if (!disable_tma) {
Stmt bulk_copy_stmt = LowerBulkCopy(T, analyzer);
if (bulk_copy_stmt.defined())
return bulk_copy_stmt;
}
auto simt_loop = MakeSIMTLoop(analyzer);
auto fused_loop = Downcast<For>(ParallelLoopFuser::Fuse(simt_loop));
......@@ -191,7 +201,6 @@ Stmt Copy::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
vectorized_thread_loop);
}
return vectorized_thread_loop;
}
......@@ -348,6 +357,13 @@ Stmt Copy::LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer) const {
For for_node =
For(local_iter, 0, FloorDiv(extent, 2 * num), ForKind::kSerial, body);
for_node = LoopPragmaUnroll(for_node);
auto range = T.thread_bounds;
if (range.defined()) {
auto thread_var = T.thread_var;
auto thread_var_with_offset = thread_var - range->min;
for_node.CopyOnWrite()->body =
Substitute(for_node->body, {{thread_var, thread_var_with_offset}});
}
return for_node;
}
......@@ -464,6 +480,10 @@ Stmt Fill::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer,
par_op->GetLoopLayout());
auto vectorized_thread_loop = VectorizeLoop(thread_loop);
if (par_op->GetPredicate(T.thread_var).defined()) {
return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
vectorized_thread_loop);
}
return vectorized_thread_loop;
} else {
LOG(FATAL) << "Unsupported scope " << dst.scope();
......
......@@ -42,6 +42,7 @@ protected:
Buffer src, dst;
Array<Range> src_range, dst_range;
IntImm coalesced_width;
Bool disable_tma = Bool(false);
std::unique_ptr<ParallelOp> par_op_;
};
......
......@@ -216,7 +216,7 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
Array<PrimExpr> fwd;
for (size_t i = 0; i < loop_layout_->OutputDim(); i++)
fwd.push_back(0);
fwd.push_back(InputPlaceholder(0));
fwd.push_back(InputPlaceholder(0) - T.thread_bounds->min);
auto rep = inv->Forward(fwd).back();
AddPredicate(EQ(rep, 0));
}
......
......@@ -220,6 +220,7 @@ Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
ICHECK(scale != nullptr && extent != nullptr);
if (*extent == 1)
continue;
int reducing_threads = (*extent) * (*scale);
std::stringstream ss;
......@@ -231,7 +232,8 @@ Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
<< ">::run_hopper";
} else {
ss << "tl::AllReduce<" << this->MakeCodegenReducer() << ", "
<< reducing_threads << ", " << (*scale) << ">::run";
<< reducing_threads << ", " << (*scale) << ", "
<< (T.thread_bounds->min) << ">::run";
}
Array<PrimExpr> thread_reduce_args = {
StringImm(ss.str()), BufferLoad(clear_buffer, dst_indices)};
......
......@@ -558,12 +558,26 @@ void CodeGenTileLangCUDA::PrintVecElemStore(const std::string &vec, DataType t,
}
void CodeGenTileLangCUDA::PrintStorageSync(const CallNode *op) {
const std::string &sync = op->args[0].as<StringImmNode>()->value;
auto args = op->args;
const std::string &sync = args[0].as<StringImmNode>()->value;
if (sync == "warp") {
// DO nothing.
} else if (sync == "shared" || sync == "shared.dyn") {
this->PrintIndent();
this->stream << "__syncthreads();\n";
if (args.size() == 1) {
this->stream << "__syncthreads();\n";
} else if (args.size() == 2) {
auto barrier_id = args[1].as<IntImmNode>()->value;
this->stream << "tl::__sync_thread_partial<" << barrier_id << ">();\n";
} else if (args.size() == 3) {
auto barrier_id = args[1].as<IntImmNode>()->value;
auto thread_count = args[2].as<IntImmNode>()->value;
this->stream << "tl::__sync_thread_partial<" << barrier_id << ", "
<< thread_count << ">();\n";
} else {
LOG(FATAL) << "Invalid number of arguments for storage sync: "
<< args.size();
}
} else if (sync == "global") {
if (!need_global_barrier_) {
need_global_barrier_ = true;
......
......@@ -234,4 +234,11 @@ template <int y = 1, typename T> TL_DEVICE T pow_of_int(T x) {
return result;
}
// Thread partial barrier synchronization
// https://docs.nvidia.com/cuda/parallel-thread-execution/#memory-consistency-model
template <int barrier_id = 0, int thread_count = 0>
TL_DEVICE void __sync_thread_partial() {
asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "r"(thread_count));
}
} // namespace tl
......@@ -22,7 +22,8 @@ struct MinOp {
}
};
template <class Reducer, int threads, int scale, int all_threads = threads>
template <class Reducer, int threads, int scale, int thread_offset = 0,
int all_threads = threads>
struct AllReduce {
static_assert(threads == 1024 or threads == 512 or threads == 256 or
threads == 128 or threads == 64 or threads == 32 or
......@@ -32,9 +33,9 @@ struct AllReduce {
constexpr int offset = threads / 2;
if constexpr (offset >= 32) {
__syncthreads();
red_buf[threadIdx.x] = x;
red_buf[threadIdx.x - thread_offset] = x;
__syncthreads();
x = Reducer()(x, red_buf[threadIdx.x ^ offset]);
x = Reducer()(x, red_buf[(threadIdx.x - thread_offset) ^ offset]);
} else {
x = Reducer()(x, T(__shfl_xor_sync(uint32_t(-1), x, offset)));
}
......@@ -53,7 +54,7 @@ struct AllReduce {
red_buf[threadIdx.x] = x;
// TODO(lei): maybe we can merge the two bar.sync into one?
asm volatile("bar.sync %0, %1;" : : "r"(2), "r"(all_threads));
x = Reducer()(x, red_buf[threadIdx.x ^ offset]);
x = Reducer()(x, red_buf[(threadIdx.x - thread_offset) ^ offset]);
} else {
x = Reducer()(x, T(__shfl_xor_sync(uint32_t(-1), x, offset)));
}
......
......@@ -26,6 +26,12 @@ public:
return checker.is_valid_;
}
static IterVar GetThreadVar(const Stmt &body) {
ThreadTagChecker checker;
checker(body);
return checker.thread_var_;
}
private:
void VisitStmt_(const AttrStmtNode *op) final {
if (op->attr_key == tir::attr::thread_extent) {
......@@ -45,6 +51,9 @@ private:
if (op->kind == ForKind::kThreadBinding) {
ICHECK(op->thread_binding.defined());
String thread_tag = op->thread_binding.value()->thread_tag;
if (thread_tag == "threadIdx.x") {
thread_var_ = Downcast<IterVar>(op->thread_binding);
}
bool is_y_or_z =
thread_tag == "threadIdx.y" || thread_tag == "threadIdx.z";
if (!thread_tag.empty() && is_y_or_z) {
......@@ -57,7 +66,7 @@ private:
}
StmtExprVisitor::VisitStmt_(op);
}
IterVar thread_var_;
bool is_valid_ = true;
};
......
......@@ -243,6 +243,8 @@ public:
if (const auto *call = node.as<CallNode>()) {
if (call->op.same_as(create_list_of_mbarrier())) {
has_create_list_of_mbarrier = true;
} else if (call->op.same_as(builtin::ptx_init_barrier_thread_count())) {
has_create_list_of_mbarrier = true;
}
}
});
......
/*!
* \file lower_shared_barrier.cc
* \brief Convert shared.barrier buffers to plain shared + ptx init.
*/
#include "tvm/ir/type.h"
#include "tvm/tir/expr.h"
#include "tvm/tir/stmt.h"
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
namespace tvm {
namespace tl {
using namespace tir;
class SharedBarrierRewriter : public StmtExprMutator {
public:
static Stmt Rewrite(Stmt body) {
SharedBarrierRewriter rewriter;
return rewriter(body);
}
private:
Stmt VisitStmt_(const BlockNode *op) final {
Block block = GetRef<Block>(op);
Array<Buffer> alloc_buffers = op->alloc_buffers;
// Record the mapping from buffer data var to buffer for later lookup
for (auto buffer : alloc_buffers) {
buffer_map_.insert({buffer->data, buffer});
}
for (auto match_buffer : op->match_buffers) {
buffer_map_.insert({match_buffer->buffer->data, match_buffer->buffer});
}
Array<Buffer> barrier_buffers;
for (auto [data, buffer] : buffer_map_) {
const auto *ptr_type =
buffer->data->type_annotation.as<PointerTypeNode>();
auto storage_scope = ptr_type->storage_scope;
ICHECK(ptr_type) << "Buffer Var's type annotation must be of PointerType";
if (storage_scope == "shared.barrier") {
barrier_buffers.push_back(buffer);
}
}
if (barrier_buffers.size() == 0) {
return StmtExprMutator::VisitStmt_(op);
}
ICHECK(thread_var_.defined()) << "thread_var_ is not defined";
for (auto buffer : barrier_buffers) {
buffer_data_to_buffer_.Set(buffer->data, buffer);
}
/*
Transform the barrier buffers to new allocations
transform:
data_is_ready = T.alloc_buffer((128,), "uint64", scope="shared.barrier")
compute_is_done = T.alloc_buffer((128,), "uint64",
scope="shared.barrier")
into:
data_is_ready = T.alloc_buffer((1,), "uint64", scope="shared")
compute_is_done = T.alloc_buffer((1,), "uint64", scope="shared")
if tx == 0:
T.ptx_init_barrier_thread_count(data_is_ready[0], 128)
T.ptx_init_barrier_thread_count(compute_is_done[0], 128)
*/
// 1. create new data vars
Array<Var> new_data_vars;
for (auto buffer : barrier_buffers) {
auto data = buffer->data;
auto ptr_type = data->type_annotation.as<PointerTypeNode>();
auto new_data =
Var(data->name_hint, PointerType(ptr_type->element_type, "shared"));
var_remap_.Set(data, new_data);
new_data_vars.push_back(new_data);
}
// 2. create new buffers
Array<Buffer> new_buffers;
for (auto buffer : barrier_buffers) {
auto data = buffer->data;
ICHECK(var_remap_.find(data) != var_remap_.end())
<< "data not found in var_remap_";
auto new_data = var_remap_.at(data);
auto new_buffer = Buffer(new_data, buffer->dtype, Array<PrimExpr>({1}),
Array<PrimExpr>({1}), PrimExpr(0), buffer->name,
buffer->data_alignment, buffer->offset_factor,
buffer->buffer_type);
new_buffers.push_back(new_buffer);
buffer_remap_.Set(buffer, new_buffer);
}
// remove the barrier buffers
alloc_buffers.MutateByApply([this](Buffer buf) {
if (buffer_remap_.find(buf) != buffer_remap_.end()) {
return buffer_remap_.at(buf);
}
return buf;
});
if (!alloc_buffers.same_as(op->alloc_buffers)) {
block.CopyOnWrite()->alloc_buffers = alloc_buffers;
} else {
return StmtExprMutator::VisitStmt_(op);
}
// 3. create init calls for new buffers
Array<Stmt> init_mbarrier_calls_;
for (auto buffer : barrier_buffers) {
auto data = buffer->data;
auto old_buffer = buffer_data_to_buffer_.at(data);
auto new_buffer = buffer_remap_.at(old_buffer);
auto count = old_buffer->shape[0];
auto call =
Call(DataType::Handle(), builtin::ptx_init_barrier_thread_count(),
{BufferLoad(new_buffer, {0}), PrimExpr(count)});
init_mbarrier_calls_.push_back(Evaluate(call));
}
Array<Stmt> new_body;
new_body.push_back(IfThenElse(EQ(thread_var_->var, 0),
SeqStmt(init_mbarrier_calls_), Stmt()));
new_body.push_back(
Evaluate(Call(DataType::Handle(), builtin::tvm_storage_sync(),
{StringImm("shared")})));
new_body.push_back(block->body);
block.CopyOnWrite()->body = SeqStmt(new_body);
return StmtExprMutator::VisitStmt_(block.get());
}
PrimExpr VisitExpr_(const BufferLoadNode *op) final {
auto load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
auto buffer = load->buffer;
if (buffer_remap_.count(buffer)) {
auto new_buffer = buffer_remap_[load->buffer];
return BufferLoad(new_buffer, load->indices);
} else if (var_remap_.count(buffer->data)) {
auto new_buffer = Buffer(
var_remap_[buffer->data], buffer->dtype, buffer->shape,
buffer->strides, buffer->elem_offset, buffer->name,
buffer->data_alignment, buffer->offset_factor, buffer->buffer_type);
return BufferLoad(new_buffer, load->indices);
}
return load;
}
Stmt VisitStmt_(const BufferStoreNode *op) final {
auto store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
auto buffer = store->buffer;
if (buffer_remap_.count(buffer)) {
auto new_buffer = buffer_remap_[store->buffer];
return BufferStore(new_buffer, store->value, store->indices);
} else if (var_remap_.count(buffer->data)) {
auto new_buffer = Buffer(
var_remap_[buffer->data], buffer->dtype, buffer->shape,
buffer->strides, buffer->elem_offset, buffer->name,
buffer->data_alignment, buffer->offset_factor, buffer->buffer_type);
return BufferStore(new_buffer, store->value, store->indices);
}
return store;
}
Stmt VisitStmt_(const AttrStmtNode *op) final {
if (op->attr_key == tir::attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
if (iv->thread_tag == "threadIdx.x") {
ICHECK(iv->dom->extent.as<IntImmNode>());
thread_var_ = iv;
}
}
return StmtExprMutator::VisitStmt_(op);
}
// This is a workaround for cpu backend,
// we need to define a thread_var for the serial loop.
IterVar thread_var_;
Map<Var, Var> var_remap_;
Map<Var, Buffer> buffer_data_to_buffer_;
Map<Buffer, Buffer> buffer_remap_;
// Mapping from data Var of a Buffer to Buffer, for lookup
std::unordered_map<Var, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_map_;
};
PrimFunc LowerSharedBarrier(PrimFunc f) {
SharedBarrierRewriter rewriter;
f.CopyOnWrite()->body = rewriter.Rewrite(f->body);
return f;
}
namespace transform {
using namespace tir::transform;
tvm::transform::Pass LowerSharedBarrier() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
return tl::LowerSharedBarrier(std::move(f));
};
return CreatePrimFuncPass(pass_func, 0, "tl.LowerSharedBarrier", {});
}
TVM_REGISTER_GLOBAL("tl.transform.LowerSharedBarrier")
.set_body_typed(LowerSharedBarrier);
} // namespace transform
} // namespace tl
} // namespace tvm
......@@ -323,9 +323,9 @@ public:
explicit SharedMemoryRewriter(
const std::unordered_map<const VarNode *, const AllocateNode *>
&shmem_allocs,
bool is_dynamic = true, bool verbose = false)
: is_dynamic_{is_dynamic}, shmem_allocs_{shmem_allocs}, verbose_{
verbose} {
bool is_dynamic = true, bool verbose = false, int align_bytes = 0)
: is_dynamic_{is_dynamic}, shmem_allocs_{shmem_allocs}, verbose_{verbose},
align_bytes_{align_bytes} {
if (!is_dynamic) {
merged_buf_var_ =
Var("buf_shmem", PointerType(PrimType(DataType::UInt(8)), "shared"));
......@@ -370,6 +370,7 @@ private:
const AllocateNode *alloc = shmem_allocs_[buffer];
align[i] =
std::max(align[i], alloc->dtype.bytes() * alloc->dtype.lanes());
align[i] = std::max(align[i], align_bytes_);
}
}
}
......@@ -980,6 +981,8 @@ private:
// Whether enable verbose logging.
bool verbose_{false};
// The alignment bytes for the merged buffer
int align_bytes_{16};
// The var for the merged buffer
Var merged_buf_var_{"buf_dyn_shmem",
PointerType(PrimType(DataType::UInt(8)), "shared.dyn")};
......@@ -1008,17 +1011,18 @@ private:
Stmt MergeSharedMemoryAllocations(Stmt stmt, bool merge_static_smem,
bool enable_aggressive_merge,
bool verbose = false) {
int align_bytes = 16, bool verbose = false) {
AllocateCollector collector;
collector(stmt);
if (collector.dyn_shmem_allocs_.size() > 1) {
SharedMemoryRewriter rewriter(collector.dyn_shmem_allocs_, true, verbose);
SharedMemoryRewriter rewriter(collector.dyn_shmem_allocs_, true, verbose,
align_bytes);
rewriter.PlanReuse(stmt, true, enable_aggressive_merge);
stmt = rewriter(std::move(stmt));
}
if (merge_static_smem && collector.static_shmem_allocs_.size() > 1) {
SharedMemoryRewriter rewriter(collector.static_shmem_allocs_, false,
verbose);
verbose, align_bytes);
rewriter.PlanReuse(stmt, false, enable_aggressive_merge);
stmt = rewriter(std::move(stmt));
}
......@@ -1029,9 +1033,10 @@ using namespace tir::transform;
namespace transform {
Pass MergeSharedMemoryAllocations(bool enable_aggressive_merge = false) {
auto pass_func = [enable_aggressive_merge](PrimFunc f, IRModule m,
PassContext ctx) {
Pass MergeSharedMemoryAllocations(bool enable_aggressive_merge = false,
int align_bytes = 16) {
auto pass_func = [enable_aggressive_merge,
align_bytes](PrimFunc f, IRModule m, PassContext ctx) {
bool merge_static_smem =
ctx->GetConfig<Bool>("tir.merge_static_smem", Bool(false)).value();
bool debug_merge_shared_memory_allocations =
......@@ -1040,7 +1045,7 @@ Pass MergeSharedMemoryAllocations(bool enable_aggressive_merge = false) {
auto *n = f.CopyOnWrite();
n->body = tl::MergeSharedMemoryAllocations(
std::move(n->body), merge_static_smem, enable_aggressive_merge,
debug_merge_shared_memory_allocations);
align_bytes, debug_merge_shared_memory_allocations);
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tl.MergeSharedMemoryAllocations",
......
......@@ -31,13 +31,49 @@
#include <unordered_set>
#include "./storage_access.h"
#include "arith/ir_mutator_with_analyzer.h"
#include "runtime/thread_storage_scope.h"
#include "tir/transforms/ir_utils.h"
struct ThreadBoundKey {
int64_t tx_min, tx_max, ty_min, ty_max, tz_min, tz_max;
bool operator==(const ThreadBoundKey &other) const {
return tx_min == other.tx_min && tx_max == other.tx_max &&
ty_min == other.ty_min && ty_max == other.ty_max &&
tz_min == other.tz_min && tz_max == other.tz_max;
}
};
namespace std {
template <> struct hash<ThreadBoundKey> {
size_t operator()(const ThreadBoundKey &k) const {
size_t h = std::hash<int64_t>()(k.tx_min);
h = h * 31 + std::hash<int64_t>()(k.tx_max);
h = h * 31 + std::hash<int64_t>()(k.ty_min);
h = h * 31 + std::hash<int64_t>()(k.ty_max);
h = h * 31 + std::hash<int64_t>()(k.tz_min);
h = h * 31 + std::hash<int64_t>()(k.tz_max);
return h;
}
};
} // namespace std
namespace tvm {
namespace tl {
// There are 16 Named Barriers provided by Hardware starting in Hopper
// Their IDs are in the range 0-15
// Number of threads syncing using the barrier must be a multiple of warp-size
// ID 0 should not be used for safety, as other driver APIs (i.e. __syncthreads)
// may use it and conflict with other uses.
enum class ReservedNamedBarriers {
kSyncThreads = 0,
kReduce_0 = 1,
kReduce_1 = 2,
kFirstUsedBarrier = kReduce_1 + 1
};
using namespace tir;
using arith::IRMutatorWithAnalyzer;
class TileLangThreadSyncPlanner : public TileLangStorageAccessVisitor {
public:
......@@ -527,15 +563,164 @@ private:
PrimExpr is_lead_;
};
class ThreadPartialSyncRewriter : public IRMutatorWithAnalyzer {
public:
static Stmt Rewrite(Stmt stmt) {
arith::Analyzer analyzer;
ThreadPartialSyncRewriter rewriter(&analyzer);
return rewriter(std::move(stmt));
}
private:
explicit ThreadPartialSyncRewriter(arith::Analyzer *analyzer)
: IRMutatorWithAnalyzer(analyzer) {}
Stmt VisitStmt_(const EvaluateNode *op) final {
const CallNode *call = nullptr;
if (op->value->IsInstance<CallNode>()) {
call = static_cast<const CallNode *>(op->value.get());
if (call->op.same_as(builtin::tvm_storage_sync())) {
const auto &args = call->args;
ICHECK(args.size() > 0);
const auto *scope_node = args[0].as<StringImmNode>();
ICHECK(scope_node != nullptr);
const std::string &scope = scope_node->value;
if (args.size() != 1 || (scope != "shared" && scope != "shared.dyn")) {
return IRMutatorWithAnalyzer::VisitStmt_(op);
}
return ProcessSharedSync(call, scope);
}
}
return IRMutatorWithAnalyzer::VisitStmt_(op);
}
Stmt ProcessSharedSync(const CallNode *op, const std::string &scope) {
// Get thread bounds
auto bound_tx = analyzer_->const_int_bound(tx_);
auto bound_ty = analyzer_->const_int_bound(ty_);
auto bound_tz = analyzer_->const_int_bound(tz_);
// Check if all threads are participating (full extent)
if (IsFullThreadExtent(tx_, bound_tx) &&
IsFullThreadExtent(ty_, bound_ty) &&
IsFullThreadExtent(tz_, bound_tz)) {
return Evaluate(IRMutatorWithAnalyzer::VisitExpr_(op));
}
// Calculate thread extents
auto extent_tx = CalculateThreadExtent(tx_, bound_tx);
auto extent_ty = CalculateThreadExtent(ty_, bound_ty);
auto extent_tz = CalculateThreadExtent(tz_, bound_tz);
// Create or get barrier info
ThreadBoundKey key{bound_tx->min_value, bound_tx->max_value,
bound_ty->min_value, bound_ty->max_value,
bound_tz->min_value, bound_tz->max_value};
auto [barrier_id, thread_count] =
GetOrCreateBarrier(key, extent_tx, extent_ty, extent_tz);
if (thread_count % 32 != 0) {
// TODO(lei): This is a workaround for the case where the thread count is
// not a multiple of 32. we should enhance the pass to analysis index
// instead of buffer expression etc.
return Stmt();
}
// Create new sync call with barrier info
Array<PrimExpr> new_args = {StringImm(scope),
IntImm(DataType::Int(32), barrier_id),
IntImm(DataType::Int(32), thread_count)};
return Evaluate(Call(op->dtype, op->op, new_args));
}
std::pair<size_t, size_t> GetOrCreateBarrier(const ThreadBoundKey &key,
size_t extent_tx,
size_t extent_ty,
size_t extent_tz) {
if (barrier_id_map_.count(key)) {
return {barrier_id_map_[key], thread_count_map_[key]};
}
size_t barrier_id =
barrier_id_map_.size() +
static_cast<size_t>(ReservedNamedBarriers::kFirstUsedBarrier);
size_t thread_count = extent_tx * extent_ty * extent_tz;
barrier_id_map_[key] = barrier_id;
thread_count_map_[key] = thread_count;
return {barrier_id, thread_count};
}
size_t CalculateThreadExtent(const IterVar &iv,
const arith::ConstIntBound &bound) {
if (!analyzer_->const_int_bound.IsBound(iv->var)) {
return 1;
}
return bound->max_value - bound->min_value + 1;
}
Stmt VisitStmt_(const AttrStmtNode *op) final {
if (op->attr_key == tvm::tir::attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
if (iv->thread_tag == "threadIdx.x") {
tx_ = iv;
} else if (iv->thread_tag == "threadIdx.y") {
ty_ = iv;
} else if (iv->thread_tag == "threadIdx.z") {
tz_ = iv;
}
}
return IRMutatorWithAnalyzer::VisitStmt_(op);
}
bool IsFullThreadExtent(const IterVar &iv,
const arith::ConstIntBound &bound) {
if (!analyzer_->const_int_bound.IsBound(iv->var)) {
return true;
}
if (!iv->dom.defined()) {
return true;
}
const auto *min_node = iv->dom->min.as<IntImmNode>();
const auto *extent_node = iv->dom->extent.as<IntImmNode>();
int64_t min = min_node->value;
int64_t extent = extent_node->value;
int64_t max = min + extent - 1;
return min == bound->min_value && max == bound->max_value;
}
// Member variables
IterVar tx_ =
IterVar(Range::FromMinExtent(0, 1), Var("tx"), IterVarType::kDataPar);
IterVar ty_ =
IterVar(Range::FromMinExtent(0, 1), Var("ty"), IterVarType::kDataPar);
IterVar tz_ =
IterVar(Range::FromMinExtent(0, 1), Var("tz"), IterVarType::kDataPar);
std::unordered_map<ThreadBoundKey, size_t> barrier_id_map_;
std::unordered_map<ThreadBoundKey, size_t> thread_count_map_;
};
Stmt TileLangThreadSync(Stmt stmt, std::string storage_scope) {
StorageScope sync_scope = StorageScope::Create(storage_scope);
if (sync_scope.rank == StorageRank::kShared && sync_scope.tag == "") {
stmt = ThreadSyncAfterWaitQueueInserter(sync_scope)(stmt);
}
TileLangThreadSyncPlanner planner(sync_scope);
planner(stmt);
return ThreadSyncInserter(sync_scope, planner.syncs_inserted_,
stmt = ThreadSyncInserter(sync_scope, planner.syncs_inserted_,
planner.partial_syncs_inserted_)(std::move(stmt));
return ThreadPartialSyncRewriter::Rewrite(std::move(stmt));
}
using namespace tir::transform;
......
......@@ -80,13 +80,6 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
mod = tilelang.transform.LegalizeVectorizedLoop()(mod)
# Add safety checks for memory accesses
mod = tilelang.transform.LegalizeSafeMemoryAccess()(mod)
# Align dynamic shared memory allocations
if have_tma(target):
# Hopper Swizzling requires dynamic shared memory address to be aligned to 1024 bytes
mod = tilelang.transform.AlignDynamicSharedMemoryAllocations(1024)(mod)
else:
# For other devices, we align to 16 bytes
mod = tilelang.transform.AlignDynamicSharedMemoryAllocations(16)(mod)
# Simplify again to clean up any duplicated conditions
# that may have been introduced by safety checks
# use an enhanced pass to simplify the dynamic symbolics
......@@ -100,6 +93,9 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
pass_ctx = tilelang.transform.get_pass_context()
# Lower the barrier.arrive into specific initialization slot
mod = tilelang.transform.LowerSharedBarrier()(mod)
# which may be introduced by the LegalizeSafeMemoryAccess
if allow_tma_and_warp_specialized(pass_ctx=pass_ctx, target=target):
mod = tilelang.transform.IfStmtBinding()(mod)
......@@ -157,6 +153,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod = tilelang.transform.ThreadPartialSync("shared.dyn")(mod)
mod = tir.transform.InferFragment()(mod)
mod = tir.transform.LowerThreadAllreduce()(mod)
mod = tilelang.transform.LowerHopperIntrin()(mod)
# Global Barrier Synchronization must be applied before
......@@ -166,13 +163,24 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod = tilelang.transform.AnnotateDeviceRegions()(mod)
mod = tir.transform.SplitHostDevice()(mod)
mod = tilelang.transform.MergeSharedMemoryAllocations(
enable_aggressive_merge=should_enable_aggressive_merge(pass_ctx=pass_ctx, target=target))(
mod)
enable_aggressive_merge = should_enable_aggressive_merge(pass_ctx=pass_ctx, target=target)
# Hopper Swizzling requires dynamic shared memory address to be aligned to 1024 bytes
# For other devices, we align to 16 bytes
smem_align_bytes = 1024 if have_tma(target) else 16
if enable_aggressive_merge:
# Workaround, wait for a element wise synchronization pass
mod = tilelang.transform.MergeSharedMemoryAllocations(
enable_aggressive_merge=enable_aggressive_merge, align_bytes=smem_align_bytes)(
mod)
mod = tilelang.transform.ThreadSync("shared")(mod)
mod = tilelang.transform.ThreadSync("shared.dyn")(mod)
else:
mod = tilelang.transform.ThreadSync("shared")(mod)
mod = tilelang.transform.ThreadSync("shared.dyn")(mod)
mod = tilelang.transform.MergeSharedMemoryAllocations(
enable_aggressive_merge=enable_aggressive_merge, align_bytes=smem_align_bytes)(
mod)
mod = tilelang.transform.ThreadSync("shared")(mod)
mod = tilelang.transform.ThreadSync("shared.dyn")(mod)
mod = tilelang.transform.EliminateStorageSyncForMBarrier()(mod)
# Inject PTX async copy must behind the thread sync pass
# as ptx async copy won't be recognized as a valid buffer load
mod = tilelang.transform.InjectPTXAsyncCopy()(mod)
......