Commit 7ccec53b authored by Lei Wang, committed by GitHub

[Feature] Support Async Pipeline inference within if scope (#198)

* Optimize CMake build process with dynamic job count calculation

- Modify build_csrc function to use 90% of available CPU cores
- Ensure at least one job is used during compilation
- Improve build performance by dynamically adjusting parallel job count

* Optimize build_csrc function with multiprocessing module

- Replace os.cpu_count() with multiprocessing.cpu_count()
- Maintain existing 90% CPU utilization logic
- Improve CPU core count calculation for the build process (see the sketch below)
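
A minimal sketch of the job-count logic these two items describe; the function signature and the cmake invocation are illustrative assumptions, not the actual build_csrc implementation:

import multiprocessing
import subprocess

def build_csrc(build_dir: str = "build") -> None:
    # Use roughly 90% of the available cores, but never fewer than one job.
    num_jobs = max(1, int(multiprocessing.cpu_count() * 0.9))
    subprocess.check_call(["cmake", "--build", build_dir, "--parallel", str(num_jobs)])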

* Add dynamic shape support with out_idx in Cython JIT kernel compilation

- Implement `run_cython_dynamic_shape_with_out_idx` function in test_tilelang_jit_gemm_cython.py
- Update Cython wrapper to handle dynamic symbolic shapes during tensor allocation
- Add support for resolving dynamic shape dimensions using input tensor references
- Enhance flexibility of JIT kernel compilation with symbolic shape handling

* Enhance error reporting for dynamic symbolic shape resolution in Cython JIT kernel

- Add detailed error message when a dynamic symbolic dimension is not found in dynamic_symbolic_map
- Improve debugging by providing context about missing symbolic dimensions
- Maintain existing dynamic shape resolution logic (a sketch of the resolution idea follows below)
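
A hedged sketch of the resolution idea behind these two items; the function name and the exact layout of dynamic_symbolic_map are illustrative, not the real Cython wrapper internals:

def resolve_output_shape(symbolic_shape, dynamic_symbolic_map, inputs):
    """Replace symbolic dims with concrete sizes taken from the input tensors."""
    resolved = []
    for dim in symbolic_shape:
        if isinstance(dim, int):
            resolved.append(dim)
        elif dim in dynamic_symbolic_map:
            # Each symbol maps to (input tensor index, dimension index).
            tensor_idx, dim_idx = dynamic_symbolic_map[dim]
            resolved.append(inputs[tensor_idx].shape[dim_idx])
        else:
            raise KeyError(f"Dynamic symbolic dimension {dim!r} not found in "
                           f"dynamic_symbolic_map; known symbols: {list(dynamic_symbolic_map)}")
    return resolved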

* Fix Copy operation handling for scalar and multi-dimensional tensors

- Add special handling for scalar tensor copy operations
- Enhance error reporting in MakeIndices method with more detailed diagnostic information
- Improve SIMT loop generation to support zero-dimensional tensors
- Add explicit check and handling for scalar tensor scenarios

* Refactor Copy operation code formatting and improve readability

- Improve code formatting in MakeIndices and MakeSIMTLoop methods
- Add line breaks to enhance readability of complex ICHECK statements
- Simplify code structure in scalar tensor handling
- Remove unnecessary whitespace and improve code alignment

* Simplify GEMM example with direct kernel compilation

- Update copyright header to Tile-AI Corporation
- Remove Profiler import and usage
- Replace tilelang.lower() with tilelang.compile()
- Simplify kernel execution workflow
- Update kernel source retrieval method (see the workflow sketch below)
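
A minimal sketch of the simplified workflow, assuming `program`, `a`, and `b` are a TileLang prim_func and its input tensors as in the GEMM examples further down in this diff:

import tilelang

kernel = tilelang.compile(program, out_idx=-1)  # replaces the old tilelang.lower() + Profiler path
c = kernel(a, b)                                # invoke the JIT kernel directly
print(kernel.get_kernel_source())               # retrieve the generated kernel source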

* Enhance block sparse attention implementation

- Update `blocksparse_flashattn` to use 2 stages for improved performance.
- Change `block_mask_dtype` from `int8` to `bool` for better memory efficiency.
- Modify condition checks in the kernel to utilize boolean values.
- Introduce a new example for top-k sparse attention and a benchmark for native sparse attention.
- Add support for asynchronous copy in PTX and improve pipeline planning with condition handling (illustrated by the condensed example below).
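
A condensed, runnable illustration of the feature, adapted from the block-sparse GEMM example added in this commit (the full file appears later in this diff): the boolean block mask gates each pipelined K-step inside an if scope.

import tilelang
import tilelang.language as T
import torch

def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
    block_mask_shape = (M // block_M, N // block_N, K // block_K)

    @T.prim_func
    def main(
            A: T.Buffer((M, K), dtype),
            B: T.Buffer((K, N), dtype),
            BlockMask: T.Buffer(block_mask_shape, "bool"),
            C: T.Buffer((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=2):
                # The async pipeline is now planned inside this if scope.
                if BlockMask[by, bx, k]:
                    T.copy(A[by * block_M, k * block_K], A_shared)
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                    T.gemm(A_shared, B_shared, C_local)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main

kernel = tilelang.compile(matmul(1024, 1024, 1024, 128, 128, 32), out_idx=-1)
a = torch.randn(1024, 1024).cuda().half()
b = torch.randn(1024, 1024).cuda().half()
block_mask = torch.randint(0, 2, (1024 // 128, 1024 // 128, 1024 // 32)).cuda().bool()
c = kernel(a, b, block_mask)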

* Refactor and clean up code formatting across multiple files

- Add whitespace for improved readability in `example_blocksparse_gemm.py`, `example_tilelang_nsa_fwd.py`, and `benchmark_nsa_fwd.py`.
- Enhance code structure and alignment in `inject_ptx_async_copy.cc` and `pipeline_planning.cc`.
- Update comments and documentation for clarity in `__init__.py` and `phase.py`.
- Ensure consistent formatting and style across the codebase.
parent 20f19611
@@ -39,7 +39,7 @@ def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=F
def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal):
    block_M = 64
    block_N = 64
-    num_stages = 0
+    num_stages = 2
    threads = 128
    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
    shape = [batch, heads, seq_len, dim]
@@ -47,7 +47,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
    dtype = "float16"
    accum_dtype = "float"
-    block_mask_dtype = "int8"
+    block_mask_dtype = "bool"
    def kernel_func(block_M, block_N, num_stages, threads):
@@ -159,7 +159,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
                    (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N))
                for k in T.Pipelined(loop_range, num_stages=num_stages):
-                    if block_mask[k] != 0:
+                    if block_mask[k]:
                        MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
                        Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale,
                                scores_sum, logsum)
@@ -187,8 +187,6 @@ def benchmark_topk_sparse_attention():
    k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
    v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
-    sm_scale = 1.0 / (D_HEAD**0.5)
    # Create sparse mask (downsampled to block level)
    downsample_factor = BLOCK
    downsample_len = math.ceil(SEQ_LEN / downsample_factor)
...
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
import math
import torch
@@ -34,7 +32,7 @@ def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=F
def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal):
    block_M = 64
    block_N = 64
-    num_stages = 0
+    num_stages = 1
    threads = 128
    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
    shape = [batch, heads, seq_len, dim]
@@ -42,7 +40,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
    dtype = "float16"
    accum_dtype = "float"
-    block_mask_dtype = "int8"
+    block_mask_dtype = "bool"
    def kernel_func(block_M, block_N, num_stages, threads):
@@ -196,7 +194,7 @@ def test_topk_sparse_attention():
    # Run Triton kernel
    program = blocksparse_flashattn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, downsample_len, is_causal=True)
    kernel = tilelang.compile(program, out_idx=[4])
+    print(kernel.get_kernel_source())
    tilelang_output = kernel(q, k, v, block_mask)
    # Compute reference
@@ -215,8 +213,7 @@ def test_topk_sparse_attention():
    print("tilelang_output", tilelang_output)
    # Verify accuracy
-    assert torch.allclose(tilelang_output, ref_output, atol=1e-2, rtol=1e-2), \
-        "TileLang output doesn't match reference"
+    torch.testing.assert_close(tilelang_output, ref_output, atol=1e-2, rtol=1e-2)
    print("Pass topk sparse attention test with qlen == klen")
...
import tilelang
import tilelang.language as T
import torch
torch.random.manual_seed(0)
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
block_mask_shape = (M // block_M, N // block_N, K // block_K)
@T.prim_func
def main(
A: T.Buffer((M, K), dtype),
B: T.Buffer((K, N), dtype),
BlockMask: T.Buffer(block_mask_shape, "bool"),
C: T.Buffer((M, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=2):
if BlockMask[by, bx, k]:
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
func = matmul(1024, 1024, 1024, 128, 128, 32)
print(func)
kernel = tilelang.compile(func, out_idx=-1)
a = torch.randn(1024, 1024).cuda().half()
b = torch.randn(1024, 1024).cuda().half()
# block_mask = torch.zeros(1024 // 128, 1024 // 128, 1024 // 32).cuda().bool()
# block_mask = torch.ones(1024 // 128, 1024 // 128, 1024 // 32).cuda().bool()
# random mask
block_mask = torch.randint(0, 2, (1024 // 128, 1024 // 128, 1024 // 32)).cuda().bool()
c = kernel(a, b, block_mask)
ref_c = torch.zeros_like(c)
for i in range(1024 // 128):
for j in range(1024 // 128):
accu = torch.zeros((128, 128), dtype=torch.float32, device=a.device)
for k in range(1024 // 32):
if block_mask[i, j, k]:
accu += (
a[i * 128:(i + 1) * 128, k * 32:(k + 1) * 32].to(torch.float32)
@ b[k * 32:(k + 1) * 32, j * 128:(j + 1) * 128].to(torch.float32))
ref_c[i * 128:(i + 1) * 128, j * 128:(j + 1) * 128] = accu.to(torch.float16)
# ref_c = a @ b
print(c)
print(ref_c)
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print(kernel.get_kernel_source())
@@ -211,9 +211,10 @@ if __name__ == "__main__":
    if (not args.tune):
        program = flashattn(
            batch, heads, seq_q, seq_kv, dim, is_causal, tune=args.tune)(
-                block_M=64, block_N=64, num_stages=0, threads=128)
+                block_M=64, block_N=64, num_stages=1, threads=128)
        ref_program = partial(ref_program, is_causal=is_causal)
        kernel = tilelang.compile(program, out_idx=[3])
        profiler = kernel.get_profiler()
        profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
        print("All checks pass.")
...
@@ -40,7 +40,7 @@ def native_sparse_attention(batch,
    G = groups
    BS = block_S
    BK = BV = block_T
-    num_stages = 0
+    num_stages = 2
    threads = 32
    @T.prim_func
@@ -140,6 +140,7 @@ if __name__ == "__main__":
        scale=scale,
    )
    kernel = tilelang.compile(program, out_idx=-1)
    torch.random.manual_seed(0)
    Q = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device='cuda').requires_grad_(True)
    K = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device='cuda').requires_grad_(True)
...
@@ -158,84 +158,6 @@ def naive_nsa(q: torch.Tensor,
    return o_slc.to(dtype) + o_swa.to(dtype) if o_swa is not None else o_slc.to(dtype)
def naive_nsa_simple_inference(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
block_indices: torch.LongTensor,
block_counts: torch.LongTensor,
block_size: int = 64,
) -> torch.Tensor:
r"""
Args:
q (torch.Tensor):
queries of shape `[B, 1, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`.
k (torch.Tensor):
keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16.
v (torch.Tensor):
values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
block_indices (torch.LongTensor):
Block indices of shape `[B, 1, H, S]` if `head_first=False` else `[B, H, T, S]`.
`S` is the maximum number of selected blocks for each query token, which is set to 16 in the paper.
block_counts (torch.LongTensor):
Block counts of shape `[B, 1, H]` if `head_first=False` else `[B, H, T]`.
block_size (int):
Selected block size. Default: 64.
Returns:
o (torch.Tensor):
Outputs of shape `[B, 1, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
"""
scale = k.shape[-1]**-0.5
dtype = q.dtype
HQ = q.shape[2]
H = k.shape[2]
D = k.shape[-1]
G = HQ // H
BS = block_size
S = block_indices.shape[-1]
SELECTED_BLOCKS_SIZE = S * BS
k, v, block_indices = (repeat(x, 'b t h d -> b t (h g) d', g=G) for x in (k, v, block_indices))
block_counts = repeat(block_counts, 'b t h -> b t (h g)', g=G)
c = torch.arange(S).repeat_interleave(BS).unsqueeze(1).expand(-1, q.shape[2]).to(q.device)
q, k, v = map(lambda x: x.float(), (q, k, v))
o = torch.zeros_like(q)
B, T = q.shape[:2]
for i in range(B):
q_b, k_b, v_b, i_b, s_b = q[i], k[i], v[i], block_indices[i], block_counts[i]
# [T, HQ, S, BS] -> [T, HQ, S*BS]
i_b = i_b.unsqueeze(-1) * BS + i_b.new_tensor(range(BS))
# [T, HQ, S*BS] -> [T, S*BS, HQ]
i_b = i_b.view(T, block_indices.shape[2], -1).transpose(1, 2)
# [HQ, D]
q_i = q_b[0] * scale
# [S*BS, HQ] -> represents selected blocks for each query token
i_i = i_b[0]
# [HQ] -> represents the number of selected blocks for each query token
s_i = s_b[0]
k_i = torch.zeros((S * BS, HQ, D), device=k_b.device, dtype=k_b.dtype)
v_i = torch.zeros((S * BS, HQ, D), device=v_b.device, dtype=v_b.dtype)
for h in range(HQ):
for t in range(SELECTED_BLOCKS_SIZE):
selected_block_index = i_i[t, h]
k_i[t, h] = k_b[selected_block_index, h, :]
v_i[t, h] = v_b[selected_block_index, h, :]
# [S*BS, HQ]
attn = torch.einsum('h d, n h d -> n h', q_i, k_i)
attn = attn.masked_fill((c >= s_i), float('-inf'))
attn = torch.softmax(attn, dim=0)
o[i, 0] = torch.einsum('n h, n h v -> h v', attn, v_i)
return o.to(dtype)
def naive_nsa_simple(
        q: torch.Tensor,
        k: torch.Tensor,
...
@@ -227,11 +227,12 @@ class PipelineRewriter : public StmtExprMutator {
public:
  PipelineRewriter(Map<Var, Buffer> buffer_data_to_buffer,
                   const Array<Buffer> &pipeline_allocs,
-                  const For &pipeline_loop, const PipelineInfo &pipeline_info)
+                  const For &pipeline_loop, const PipelineInfo &pipeline_info,
+                  PrimExpr predicate_condition = PrimExpr())
      : buffer_data_to_buffer_(std::move(buffer_data_to_buffer)),
        pipeline_allocs_(pipeline_allocs), pipeline_loop_(pipeline_loop),
-        pipeline_info_(pipeline_info) {}
+        pipeline_info_(pipeline_info),
+        predicate_condition_(predicate_condition) {}
  Stmt BuildPipeline() {
    // Step 1: Analyze accesses to the buffers in the pipeline and compute the
@@ -636,6 +637,7 @@ private:
    // Async related
    std::map<int, AsyncStateLocal> async_states_local;
+    PrimExpr normalized_access_index;
    for (const Block &block : ordered_stmts_) {
      int stage = pipeline_info_.at(block).stage;
@@ -658,7 +660,7 @@ private:
      // - "producer_head" if this stage is an async producer
      // - "consumer_head" if this stage reads from asynchronously written
      // buffers.
-      PrimExpr normalized_access_index =
+      normalized_access_index =
          is_unit_loop ? skewed_loop_var : skewed_loop_var + delta;
      // Adjust the block predicate and the body according to the final loop
@@ -668,10 +670,15 @@ private:
        Var loop_iter = Downcast<Var>(new_loop_var);
        inbound = Substitute(inbound, {{loop_iter, loop_iter + delta}});
      }
      new_block = Downcast<Block>(Substitute(
          new_block, {{pipeline_loop_->loop_var, normalized_access_index}}));
+      if (predicate_condition_.defined()) {
+        BlockNode *n = new_block.CopyOnWrite();
+        n->body = IfThenElse(
+            Substitute(predicate_condition_,
+                       {{pipeline_loop_->loop_var, normalized_access_index}}),
+            n->body);
+      }
      if (pipeline_info_[block].async) {
        auto &local_state = async_states_local[stage];
        local_state.producer_head = normalized_access_index;
@@ -687,7 +694,6 @@ private:
    PopulateWaitCounts(new_blocks, &async_states_local);
    auto stmts = CompletePipelineLoopStatements(new_blocks, async_states_local);
    Stmt new_loop{nullptr};
    if (stmts.empty()) {
@@ -713,7 +719,6 @@ private:
                   unroll_loop ? ForKind::kUnrolled : pipeline_loop_->kind,
                   std::move(new_loop), NullOpt, preserved_annotations);
    }
    // Update producer heads in the global async states.
    for (const auto &[stage_id, state] : async_states_local) {
      async_states[stage_id].producer_head += extent;
@@ -728,6 +733,7 @@ private:
  Array<Buffer> pipeline_allocs_;
  For pipeline_loop_;
  PipelineInfo pipeline_info_;
+  PrimExpr predicate_condition_;
  int max_stage_ = -1;
  Map<Buffer, Buffer> buffer_remap_;
  Array<Block> ordered_stmts_;
@@ -842,6 +848,7 @@ private:
    // can be direct child of the for-loop. If the for-loop has BlockRealize as
    // its child, the pipeline body will be the child of the block.
    Stmt pipeline_body{nullptr};
+    PrimExpr predicate_condition{nullptr};
    Array<Buffer> pipeline_allocs;
    if (const auto *realize = for_node->body.as<BlockRealizeNode>()) {
      const auto &block = realize->block;
@@ -849,7 +856,15 @@ private:
        ICHECK(buffer->IsInstance<BufferNode>());
        buffer_data_to_buffer_.Set(buffer->data, buffer);
      }
+      if (const auto *if_then_else = block->body.as<IfThenElseNode>()) {
+        ICHECK(!if_then_else->else_case.defined())
+            << "Pipeline_Planning: Can't handle the body of the loop because "
+               "it is not a SeqStmt";
+        pipeline_body = if_then_else->then_case;
+        predicate_condition = if_then_else->condition;
+      } else {
        pipeline_body = block->body;
+      }
      pipeline_allocs = block->alloc_buffers;
    } else {
      pipeline_body = for_node->body;
@@ -927,8 +942,9 @@ private:
    ValidatePipelineBody(pipeline_info, original_order);
    // Step 4: Rewrite the pipeline body.
-    Stmt pipeline = PipelineRewriter(buffer_data_to_buffer_, pipeline_allocs,
-                                     GetRef<For>(op), pipeline_info)
+    Stmt pipeline =
+        PipelineRewriter(buffer_data_to_buffer_, pipeline_allocs,
+                         GetRef<For>(op), pipeline_info, predicate_condition)
            .BuildPipeline();
    if (const auto *realize = op->body.as<BlockRealizeNode>()) {
...
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \brief Replace copy from global to shared with async copy
* \file inject_ptx_async_copy.cc
*/
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include "storage_access.h"
#include "tir/ir/buffer_common.h"
#include "tvm/tir/stmt.h"
namespace tvm {
namespace tl {
using namespace tir;
class PTXAsyncCopyInjector : public StmtMutator {
public:
Stmt VisitStmt_(const AttrStmtNode *attr) {
if (attr->attr_key == tir::attr::async_scope) {
ICHECK(in_async == false) << "Nested async scopes not supported";
in_async = true;
auto body = this->VisitStmt(attr->body);
in_async = false;
return body;
}
return StmtMutator::VisitStmt_(attr);
}
Stmt InjectPTX(const BufferLoadNode *load, const BufferStoreNode *store,
bool predicated = false,
PrimExpr predicate_value = PrimExpr()) {
if (load->buffer.scope() == "global") {
ICHECK(load->indices.size() == 1 && store->indices.size() == 1);
ICHECK(load->indices[0]->dtype.lanes() ==
store->indices[0]->dtype.lanes())
<< load->indices[0] << " vs. " << store->indices[0] << " with lanes "
<< load->indices[0]->dtype.lanes() << " vs. "
<< store->indices[0]->dtype.lanes();
const int indices_lanes = load->indices[0]->dtype.lanes();
const int bytes = indices_lanes * load->buffer->dtype.bytes();
if (bytes == 4 || bytes == 8 || bytes == 16) {
auto dst_elem_type =
GetPointerType(store->buffer->data->type_annotation);
auto src_elem_type =
GetPointerType(load->buffer->data->type_annotation);
ICHECK(dst_elem_type.has_value() && src_elem_type.has_value())
<< "Both store and load buffer should have a pointer type "
"annotation.";
int index_factor = 1;
if (dst_elem_type.value() != src_elem_type.value()) {
// The only case where src and dst have different dtypes is when the
// dst shared memory is a byte buffer generated by merging dynamic
// shared memory.
ICHECK(store->buffer.scope() == "shared.dyn" ||
store->buffer.scope() == "shared");
ICHECK(dst_elem_type.value() == DataType::UInt(8));
// BufferStore/Load have the "pointer reinterpret" semantics according
// to their "value" dtype. Their "indices" are supposed to be applied
// after such pointer cast, for example:
// ((*float16)(byte_buffer))[buffer->indices] = fp16_value; To replace
// BufferStore/Load with cp.async, we need to multiply the store index
// by the byte size of the "value" dtype, to get the correct offset
// into the byte buffer.
index_factor = src_elem_type->bytes();
}
if (indices_lanes == 1) {
auto src_offset = load->indices[0];
auto dst_offset = store->indices[0];
Array<PrimExpr> args = {
store->buffer->data, mul(dst_offset, PrimExpr(index_factor)),
load->buffer->data, src_offset, PrimExpr(bytes)};
// use arguments size to indicate whether or not to use predicated
// cp.async
if (predicated) {
args.push_back(predicate_value);
}
return Evaluate(Call(store->buffer->dtype,
tvm::tir::builtin::ptx_cp_async(), args));
}
// Predicated loads don't support vectorized indexing.
if (!predicated) {
// Only some vectorized indexing patterns are supported for now.
auto src_offset = [=]() -> PrimExpr {
if (load->indices[0]->IsInstance<RampNode>()) {
return load->indices[0].as<RampNode>()->base;
}
return PrimExpr();
}();
auto dst_offset = [=]() -> PrimExpr {
if (store->indices[0].as<RampNode>()) {
return store->indices[0].as<RampNode>()->base;
} else if (store->indices[0].as<AddNode>()) {
// The case where the dst buffer is a byte buffer generated by
// merging dynamic shared memory. A_shared.dyn[(ramp(...), 1, 8) +
// x8(17408))] = A_global[ramp(...),1, 8)]
auto *add = store->indices[0].as<AddNode>();
if (!add->a->IsInstance<RampNode>())
return PrimExpr();
if (!add->b->IsInstance<BroadcastNode>())
return PrimExpr();
return tir::Add(add->a.as<RampNode>()->base,
add->b.as<BroadcastNode>()->value);
}
return PrimExpr();
}();
if (src_offset.defined() && dst_offset.defined()) {
return Evaluate(Call(
store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(),
{store->buffer->data, mul(dst_offset, PrimExpr(index_factor)),
load->buffer->data, src_offset, PrimExpr(bytes)}));
}
} else {
// Only some vectorized indexing patterns are supported for now.
auto src_offset = [=]() -> PrimExpr {
if (load->indices[0]->IsInstance<RampNode>()) {
return load->indices[0].as<RampNode>()->base;
}
return PrimExpr();
}();
auto dst_offset = [=]() -> PrimExpr {
if (store->indices[0].as<RampNode>()) {
return store->indices[0].as<RampNode>()->base;
} else if (store->indices[0].as<AddNode>()) {
// The case where the dst buffer is a byte buffer generated by
// merging dynamic shared memory. A_shared.dyn[(ramp(...), 1, 8) +
// x8(17408))] = A_global[ramp(...),1, 8)]
auto *add = store->indices[0].as<AddNode>();
if (!add->a->IsInstance<RampNode>())
return PrimExpr();
if (!add->b->IsInstance<BroadcastNode>())
return PrimExpr();
return tir::Add(add->a.as<RampNode>()->base,
add->b.as<BroadcastNode>()->value);
}
return PrimExpr();
}();
if (src_offset.defined() && dst_offset.defined()) {
return Evaluate(Call(
store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(),
{store->buffer->data, mul(dst_offset, PrimExpr(index_factor)),
load->buffer->data, src_offset, PrimExpr(bytes),
predicate_value}));
}
}
}
}
return StmtMutator::VisitStmt_(store);
}
Stmt VisitStmt_(const BufferStoreNode *store) {
if (in_async && (store->buffer.scope() == "shared" ||
store->buffer.scope() == "shared.dyn")) {
if (auto *load = store->value.as<BufferLoadNode>()) {
return InjectPTX(load, store);
} else if (auto *call = store->value.as<CallNode>()) {
// tir.if_then_else is a call to tir::builtin::if_then_else()
if (call->op.same_as(builtin::if_then_else()) &&
call->args.size() == 3) {
if (auto *load = call->args[1].as<BufferLoadNode>()) {
// Only default value of 0 is supported since 0 is the default value
// used by cp.async ptx. @see section 9.7.8.22.3. of
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-memory-operations
bool else_value_is_zero = false;
if (auto *b = call->args[2].as<BroadcastNode>()) {
if (auto *f = b->value.as<FloatImmNode>()) {
else_value_is_zero = f->value == 0.0f;
} else if (auto *i = b->value.as<IntImmNode>()) {
else_value_is_zero = i->value == 0;
}
}
if (auto *f = call->args[2].as<FloatImmNode>()) {
else_value_is_zero = f->value == 0.0f;
} else if (auto *i = call->args[2].as<IntImmNode>()) {
else_value_is_zero = i->value == 0;
}
if (else_value_is_zero) {
return InjectPTX(load, store, true, call->args[0]);
}
}
}
}
}
return StmtMutator::VisitStmt_(store);
}
private:
bool in_async{false};
};
using namespace tir::transform;
tvm::transform::Pass InjectPTXAsyncCopy() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
auto *n = f.CopyOnWrite();
n->body = PTXAsyncCopyInjector()(n->body);
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tl.InjectPTXAsyncCopy", {});
}
TVM_REGISTER_GLOBAL("tl.transform.InjectPTXAsyncCopy")
.set_body_typed(InjectPTXAsyncCopy);
} // namespace tl
} // namespace tvm
@@ -34,8 +34,6 @@ namespace tl {
using namespace tir;
-namespace {
/*!
 * \brief Check whether two regions have intersections.
 * \param region1 The first region.
@@ -56,8 +54,6 @@ bool MayConflict(Region region1, Region region2) {
  return true;
}
-} // namespace
class PipelinePlanner : public StmtExprMutator {
public:
  static Stmt Substitute(const PrimFunc &f) {
@@ -88,20 +84,24 @@ private:
                      /*body*/ stmt);
    Array<Array<BufferRegion>> access =
        GetBlockReadWriteRegion(block, buffer_data_to_buffer_);
    PipelineStageInfo pinfo;
    pinfo.reads = std::move(access[0]);
    pinfo.writes = std::move(access[1]);
    pinfo.original_order = idx;
    // copy stage should only have one reads and one writes
-    if (pinfo.reads.size() == 1 && pinfo.writes.size() == 1) {
+    bool write_to_shared = false;
+    bool read_from_global = false;
    for (auto region : pinfo.reads)
      if (region->buffer.scope() == "global")
-        pinfo.copy_stage = true;
+        read_from_global = true;
    for (auto region : pinfo.writes)
-      if (region->buffer.scope() == "global")
-        pinfo.copy_stage = true;
-    }
+      if (region->buffer.scope() == "shared" ||
+          region->buffer.scope() == "shared.dyn")
+        write_to_shared = true;
+    pinfo.copy_stage = write_to_shared && read_from_global;
    return std::move(pinfo);
  }
@@ -118,14 +118,26 @@ private:
      ICHECK(buffer->IsInstance<BufferNode>());
      buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
+    if (const auto *seq_stmt = block->body.as<SeqStmtNode>()) {
      pipeline_body = block->body;
+    } else if (const auto *if_then_else = block->body.as<IfThenElseNode>()) {
+      // should assert else case is nullptr
+      ICHECK(!if_then_else->else_case.defined())
+          << "Pipeline_Planning: Can't handle the body of the loop because "
+             "it is not a SeqStmt";
+      pipeline_body = if_then_else->then_case;
+    } else {
+      LOG(FATAL) << "Pipeline_Planning: Can't handle the body of the loop "
+                    "because it is not a SeqStmt or IfThenElse";
+    }
  } else {
    pipeline_body = loop->body;
  }
  const SeqStmtNode *pipeline_body_seq = pipeline_body.as<SeqStmtNode>();
-  CHECK(pipeline_body_seq) << "ValueError: The body of the software pipeline "
-                              "should be SeqStmt, got "
-                           << loop->body->GetTypeKey();
+  CHECK(pipeline_body_seq)
+      << "ValueError: The body of the software pipeline "
+         "should be SeqStmt, got "
+      << pipeline_body->GetTypeKey() << " " << pipeline_body;
  CHECK(num_stages >= 1);
  CHECK(loop->kind == ForKind::kSerial);
@@ -156,10 +168,12 @@ private:
            return r->buffer == write->buffer &&
                   MayConflict(r->region, write->region);
          }) != pinfo.writes.end()) {
-        CHECK(false) << "Can't handle multiple write on overlap buffer "
-                        "region in the pipeline "
-                        "planning pass: "
-                     << pipeline_body_seq->seq[pinfo.original_order];
+        LOG(FATAL) << "Pipeline planning error: Multiple writes to "
+                      "overlapping buffer regions detected. "
+                   << "Stage " << pinfo.original_order << " and stage " << i
+                   << " are both writing to buffer '" << write->buffer->name
+                   << "' with overlapping regions. This is not supported "
+                      "in pipeline planning.";
      }
    }
  }
...
@@ -44,6 +44,8 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
    mod = tilelang.transform.PipelinePlanning()(mod)
    mod = tilelang.transform.InjectSoftwarePipeline()(mod)
+    # TODO(lei): may need a pass to fuse the if-then-else in the
+    # pipeline loop when we meet dynamic branch.
    mod = tir.transform.LowerOpaqueBlock()(mod)
    mod = tir.transform.FlattenBuffer()(mod)
    mod = tir.transform.NarrowDataType(32)(mod)
@@ -74,7 +76,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
    mod = tilelang.transform.LowerHopperIntrin()(mod)
    mod = tilelang.transform.ThreadSync("shared")(mod)
    mod = tilelang.transform.ThreadSync("shared.dyn")(mod)
-    mod = tir.transform.InjectPTXAsyncCopy()(mod)
+    mod = tilelang.transform.InjectPTXAsyncCopy()(mod)
    mod = tilelang.transform.AnnotateDeviceRegions()(mod)
    mod = tir.transform.SplitHostDevice()(mod)
...
@@ -225,3 +225,14 @@ def VectorizeLoop(enable_vectorize: bool = True):
        The result pass
    """
    return _ffi_api.VectorizeLoop(enable_vectorize)  # type: ignore
+
+
+def InjectPTXAsyncCopy():
+    """Rewrite global to shared memory copy on CUDA with asynchronous copy.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.InjectPTXAsyncCopy()  # type: ignore