"docs/vscode:/vscode.git/clone" did not exist on "6b1490c9517d6c99d4a0c16bc870909f88bfcbd2"
Unverified Commit adcba275 authored by alex_xiao, committed by GitHub

Add Flash Attn example on AMD MI300 series (#682)



* [Enhancement] Refactor buffer index handling for improved precision and clarity (#668)

- Enhanced buffer index handling to address precision issues by removing redundant operations.
- Streamlined the logic for determining buffer overlaps, ensuring more accurate conflict detection.
- Updated related documentation to reflect changes in buffer management practices.

* Remove obsolete test script for AMD example, streamlining the examples directory.

* Remove unused dtype_size variable in AMD example script to streamline code.

* Add input configuration file and update AMD example script for enhanced flexibility

- Introduced a new input.txt file for configurable parameters.
- Modified the example_amd_flash_attn_fwd.py script to allow for a wider range of configurations, including additional options for num_stages, enable_rasterization, and k_pack.
- Streamlined the main function for better clarity and organization.
- Added a new test script to facilitate running the example with specified parameters.

* Remove input configuration file and obsolete test script; enhance AMD example with swizzle layout annotations

- Deleted input.txt and test.sh files as they are no longer needed.
- Updated example_amd_flash_attn_fwd.py to include swizzle layout annotations for shared memory, improving bank-conflict avoidance (a minimal sketch follows below).
- Reintroduced swizzle usage in the kernel for better performance.
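
A minimal sketch of what such an annotation can look like in tilelang, assuming the `T.annotate_layout` / `tilelang.layout.make_swizzled_layout` API; the buffer names are illustrative and not necessarily the ones used in the final script:

    # Hypothetical sketch: request swizzled layouts for shared-memory tiles to reduce bank conflicts.
    T.annotate_layout({
        Q_shared: tilelang.layout.make_swizzled_layout(Q_shared),
        K_shared: tilelang.layout.make_swizzled_layout(K_shared),
    })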

* Refactor AMD example script for FlashAttention-2

- Updated function names for clarity, changing `get_v2_configs` to `get_configs` and `fast_flashattn_v2` to `fast_flashattn`.
- Streamlined the main function by renaming `main_v2` to `main` and adjusting the corresponding calls.
- Removed outdated comments and improved code organization for better readability.

* Refactor formatting in AMD FlashAttention example script

- Improved code readability by adjusting line breaks and indentation in the `fast_flashattn` function.
- Streamlined the `main` function parameter formatting for consistency.
- Removed unnecessary blank lines to enhance overall code organization.

* Update example_amd_flash_attn_fwd.py

---------
Co-authored-by: xinxyxiao <xinyxiao@amd.com>
Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
parent 05f2fc6d
import torch
import torch.nn.functional as F
import tilelang
import tilelang.language as T
import itertools
import argparse
from functools import partial
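
# Reference attention implementation in plain PyTorch; `groups` > 1 exercises
# grouped-query attention by repeating K/V heads to match the Q head count.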
def ref_program(Q, K, V, is_causal, groups=1):
assert Q.size(
2) == K.size(2) * groups, f"Q heads {Q.size(2)} K heads {K.size(2)} groups {groups}"
assert Q.size(
2) == V.size(2) * groups, f"Q heads {Q.size(2)} V heads {V.size(2)} groups {groups}"
dim = Q.size(-1)
K = K.repeat_interleave(groups, dim=2)
V = V.repeat_interleave(groups, dim=2)
scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
if is_causal:
seq_len = Q.size(1)
mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float('-inf'))
attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
return output
def get_configs():
"""Generates configurations for the autotuner, tailored for FA-2 style parallelism."""
block_M = [64, 128, 256]
block_N = [32, 64, 128]
threads = [128, 256, 512]
num_split_q = [32, 64, 128]
num_stages = [0, 1, 2]
enable_rasterization = [True, False]
k_pack = [1, 2]
valid_configs = []
for m, n, s, t, stages, r, k in itertools.product(block_M, block_N, num_split_q, threads,
num_stages, enable_rasterization, k_pack):
valid_configs.append({
"block_M": m,
"block_N": n,
"num_split_q": s,
"threads": t,
"num_stages": stages,
"enable_rasterization": r,
"k_pack": k
})
valid_configs.append({
'block_M': 64,
'block_N': 64,
'num_split_q': 64,
'threads': 256,
'num_stages': 1,
'enable_rasterization': True,
'k_pack': 2
})
return valid_configs
@tilelang.autotune(configs=get_configs(), cache_input_tensors=True)
@tilelang.jit(out_idx=[3])
def fast_flashattn(
batch,
heads,
seq_len,
dim,
is_causal,
groups,
block_M: int,
block_N: int,
num_split_q: int,
threads: int,
num_stages: int,
enable_rasterization: bool,
k_pack: int,
):
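    # Softmax scale 1/sqrt(dim), multiplied by log2(e) so the kernel can use exp2
    # instead of exp in the online-softmax updates below.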
scale = (1.0 / dim)**0.5 * 1.44269504
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim]
kv_shape = [batch, seq_len, head_kv, dim]
dtype = "float16"
accum_dtype = "float"
v_vec_size = 4
vec_size = 4 * k_pack
@T.prim_func
def main(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype),
):
with T.Kernel(num_split_q, batch * heads, threads=threads) as (b_split, byz_combined):
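            # Optionally enable thread-block rasterization (panel size 10) to improve L2 locality.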
T.use_swizzle(10, enable=enable_rasterization)
bz = byz_combined // heads
by = byz_combined % heads
num_q_blocks = T.ceildiv(seq_len, block_M)
bx = T.alloc_var("int32")
bx[0] = b_split
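            # Persistent-block loop: this workgroup starts at query block b_split and
            # strides over query blocks in steps of num_split_q.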
with T.While(bx[0] < num_q_blocks):
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
m_i = T.alloc_fragment([block_M], accum_dtype)
l_i = T.alloc_fragment([block_M], accum_dtype)
T.fill(acc_o, 0)
T.fill(m_i, -T.infinity(accum_dtype))
T.fill(l_i, 0)
current_bx = bx[0]
q_block_offset = current_bx * block_M
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
P_shared = T.alloc_shared([block_M, block_N], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
m_prev = T.alloc_fragment([block_M], accum_dtype)
scale_factor = T.alloc_fragment([block_M], accum_dtype)
T.copy(
Q[bz, q_block_offset:q_block_offset + block_M, by, :],
Q_shared,
coalesced_width=vec_size)
loop_end_k = T.ceildiv(q_block_offset + block_M,
block_N) if is_causal else T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_end_k, num_stages=num_stages):
kv_idx = k * block_N
T.copy(
K[bz, kv_idx:kv_idx + block_N, by // groups, :],
K_shared,
coalesced_width=vec_size)
T.copy(
V[bz, kv_idx:kv_idx + block_N, by // groups, :],
V_shared,
coalesced_width=v_vec_size)
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, k_pack=k_pack)
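                    # Causal masking: positions where the key index exceeds the query index get -inf.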
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(q_block_offset + i >= kv_idx + j,
acc_s[i, j], -T.infinity(acc_s.dtype))
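                    # Online-softmax update: track the running row max, rescale the previous
                    # accumulator and row sum by exp2((m_prev - m_new) * scale), then accumulate.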
T.copy(m_i, m_prev)
T.reduce_max(acc_s, m_i, dim=1, clear=False)
for i in T.Parallel(block_M):
sf = T.exp2(m_prev[i] * scale - m_i[i] * scale)
l_i[i] *= sf
scale_factor[i] = sf
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scale_factor[i]
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - m_i[i] * scale)
row_sum = T.alloc_fragment([block_M], accum_dtype)
T.reduce_sum(acc_s, row_sum, dim=1)
for i in T.Parallel(block_M):
l_i[i] += row_sum[i]
T.copy(acc_s, P_shared)
T.sync_threads()
T.gemm(P_shared, V_shared, acc_o)
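                # Epilogue: normalize by the softmax denominator (guarding against all-masked rows)
                # and write the output tile back to global memory.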
l_inv = T.alloc_fragment([block_M], accum_dtype)
for i in T.Parallel(block_M):
safe_l = T.if_then_else(l_i[i] > 1e-6, l_i[i], 1.0)
l_inv[i] = 1.0 / safe_l
for i, j in T.Parallel(block_M, dim):
Output[bz, q_block_offset + i, by, j] = acc_o[i, j] * l_inv[i]
bx[0] = current_bx + num_split_q
return main
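
# Benchmark driver: autotunes fast_flashattn, verifies it against the PyTorch
# reference, and reports latency / TFLOPS for both.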
def main(batch: int = 1,
heads: int = 8,
seq_len: int = 4096,
dim: int = 128,
is_causal: bool = False,
groups: int = 1):
flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim
total_flops = 2 * flops_per_matmul
if is_causal:
total_flops *= 0.5
print("Starting autotuning for FlashAttention-V2...")
kernel = fast_flashattn(batch, heads, seq_len, dim, is_causal, groups=groups)
print(f"Autotuning finished. Best Configuration: {kernel.config}")
ref_program_processed = partial(ref_program, is_causal=is_causal, groups=groups)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
print("Verifying correctness...")
profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
print("All checks pass.")
latency = profiler.do_bench(ref_program_processed, warmup=100)
print(f"Reference (PyTorch): {latency:.2f} ms | {total_flops / latency * 1e-9:.2f} TFlops")
latency = profiler.do_bench(warmup=100)
print(
f"Fast Flash Attention V2 (Tile-lang): {latency:.2f} ms | {total_flops / latency * 1e-9:.2f} TFlops"
)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=1, help='batch size')
parser.add_argument('--heads', type=int, default=8, help='heads')
parser.add_argument('--seq_len', type=int, default=4096, help='sequence length')
parser.add_argument('--dim', type=int, default=128, help='dim')
parser.add_argument('--is_causal', action='store_true', help='causal')
parser.add_argument('--groups', type=int, default=1, help='groups')
args = parser.parse_args()
main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups)
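
# Example invocation (assumes a ROCm build of tilelang and an MI300-series GPU):
#   python example_amd_flash_attn_fwd.py --batch 1 --heads 8 --seq_len 4096 --dim 128 --is_causal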
@@ -22,7 +22,7 @@ struct MinOp {
   }
 };
 
-template <class Reducer, int threads, int scale> struct AllReduce {
+template <class Reducer, int threads, int scale, int thread_offset = 0> struct AllReduce {
   static_assert(threads == 1024 || threads == 512 || threads == 256 ||
                 threads == 128 || threads == 64 || threads == 32 ||
                 threads == 16 || threads == 8 || threads == 4 || threads == 2);
@@ -43,7 +43,7 @@ template <class Reducer, int threads, int scale> struct AllReduce {
     if constexpr (offset == scale) {
       return x;
     } else {
-      return AllReduce<Reducer, offset, scale>::run(x, red_buf);
+      return AllReduce<Reducer, offset, scale, thread_offset>::run(x, red_buf);
     }
   }
 };
...