Unverified Commit f07f31c1 authored by alex_xiao's avatar alex_xiao Committed by GitHub
Browse files

[AMD] Fix amd tir&add examples (#784)



* [Enhancement] Refactor buffer index handling for improved precision and clarity (#668)

- Enhanced buffer index handling to address precision issues by removing redundant operations.
- Streamlined the logic for determining buffer overlaps, ensuring more accurate conflict detection.
- Updated related documentation to reflect changes in buffer management practices.

* Remove obsolete test script for AMD example, streamlining the examples directory.

* Remove unused dtype_size variable in AMD example script to streamline code.

* Add input configuration file and update AMD example script for enhanced flexibility

- Introduced a new input.txt file for configurable parameters.
- Modified the example_amd_flash_attn_fwd.py script to allow for a wider range of configurations, including additional options for num_stages, enable_rasterization, and k_pack.
- Streamlined the main function for better clarity and organization.
- Added a new test script to facilitate running the example with specified parameters.

* Remove input configuration file and obsolete test script; enhance AMD example with swizzle layout annotations

- Deleted input.txt and test.sh files as they are no longer needed.
- Updated example_amd_flash_attn_fwd.py to include swizzle layout annotations for shared memory, improving bank conflict avoidance.
- Reintroduced swizzle usage in the kernel for better performance.

* Refactor AMD example script for FlashAttention-2

- Updated function names for clarity, changing `get_v2_configs` to `get_configs` and `fast_flashattn_v2` to `fast_flashattn`.
- Streamlined the main function by renaming `main_v2` to `main` and adjusting the corresponding calls.
- Removed outdated comments and improved code organization for better readability.

* Refactor formatting in AMD FlashAttention example script

- Improved code readability by adjusting line breaks and indentation in the `fast_flashattn` function.
- Streamlined the `main` function parameter formatting for consistency.
- Removed unnecessary blank lines to enhance overall code organization.

* Update example_amd_flash_attn_fwd.py

* Enhance AMD example script and update CI workflows

- Improved the `example_amd_flash_attn_fwd.py` script for better clarity and organization.
- Added new CI workflows for AMD and documentation publishing.
- Updated various requirements files to include necessary dependencies.
- Introduced new test cases and examples for better coverage and functionality.
- Refactored existing code for improved readability and maintainability.

* Remove redundant tool cache cleanup step in AMD CI workflow

* Remove `torch` dependency from `requirements-rocm.txt` to streamline requirements.

* Add new AMD FlashAttention example and test script

- Introduced `example_amd_flash_attn_bwd.py` for backward attention computation using TileLang.
- Added `test.sh` script to facilitate running the new example with specified parameters.
- Enhanced the overall structure and organization of the example for better clarity and usability.

* Update configurations in `example_amd_flash_attn_fwd.py` for autotuner

- Reduced the number of threads and `num_split_q` options for improved performance.
- Adjusted `panel_size` options to streamline configuration settings.

* Update submodule 'tvm' to commit 6ccc74f622c7ec4ac25d430d0f6546e7b9edb217

* Update submodule 'tvm' to commit 14ff70ab142b9e5a31bbf9c7923c8a697d41e86c

* Add example for AMD Flash Attention backward pass implementation

- Introduced a new example script `example_amd_flash_attn_bwd.py` demonstrating the forward and backward operations of Flash Attention using TileLang.
- Implemented JIT-compiled functions for both forward and backward passes, including preprocessing and postprocessing steps.
- Added a main function to facilitate testing and benchmarking of the attention mechanism with configurable parameters.
- Included reference implementation for validation against PyTorch's attention mechanism.

This addition enhances the examples directory by providing a comprehensive guide for users to understand and utilize Flash Attention in their applications.

* Enhance AMD Flash Attention example with additional testing capabilities

- Updated `example_amd_flash_attn_bwd.py` to include more comprehensive testing features for the Flash Attention implementation.
- Improved the main function to allow for better parameter configuration and benchmarking.
- Added validation checks against PyTorch's attention mechanism to ensure accuracy and reliability of the example.

This update aims to provide users with a more robust tool for understanding and utilizing Flash Attention in their applications.

* Update submodule TVM to commit a64a5926a6e59f5417ef2501f9d88b467337cf6a

* Refactor HIP intrinsic rules to CUDA

- Updated file name from `intrin_rule_hip.cc` to `intrin_rule_cuda.cc` to reflect the change in focus from HIP to CUDA intrinsic rules.
- Adjusted include paths for better organization and clarity in the code structure.

* Update AMD CI workflow to uninstall specific PyTorch packages before installation

- Removed the installation of `flash_attn==2.5.8` to streamline the CI process.
- Added a step to uninstall `torch`, `torchvision`, and `torchaudio` prior to installing pre-release versions, ensuring compatibility and reducing potential conflicts.

* Remove unused shared memory allocations in AMD Flash Attention backward example

- Eliminated the allocation of shared memory for `dv_shared` and `dk_shared` in `example_amd_flash_attn_bwd.py` to streamline memory usage and improve performance.
- This change focuses on optimizing the backward pass implementation by reducing unnecessary memory overhead.

* Remove unnecessary pip uninstall command from AMD CI workflow

- Eliminated the step to uninstall `torch`, `torchvision`, and `torchaudio` in the AMD CI workflow, as it is no longer required for the installation of pre-release versions.
- This change simplifies the CI process and reduces potential overhead during package management.

* Refactor DispatchHIPWarpActiveMask function in HIP intrinsic rules

- Updated the return statement to use std::string for concatenation in the case of 16-bit types, improving code clarity.
- Added a null check for the CallNode pointer in DispatchHIPWarpActiveMask to enhance robustness and prevent potential dereferencing issues.

* Refactor formatting of HIP intrinsic rule registrations

- Adjusted the formatting of TVM_REGISTER_OP calls for better readability by aligning method chaining.
- No functional changes were made; this update focuses on code style improvements to enhance maintainability.

* Update file name and documentation for HIP intrinsic rules

- Renamed the file from `intrin_rule_cuda.cc` to `intrin_rule_hip.cc` to accurately reflect the focus on HIP intrinsic rules.
- Updated the file documentation to clarify its purpose as related to HIP rather than CUDA.

* Enhance DispatchHIPShuffle function with clang-analyzer comments

- Added NOLINTBEGIN and NOLINTEND comments to the DispatchHIPShuffle function to suppress clang-analyzer warnings related to inner pointer usage.
- This change improves code clarity and maintains compliance with static analysis tools.

* lint fix

* fix

---------
Co-authored-by: default avatarxinxyxiao <xinyxiao@amd.com>
Co-authored-by: default avatarLei Wang <34334180+LeiWang1999@users.noreply.github.com>
Co-authored-by: default avatarLeiWang1999 <leiwang1999@outlook.com>
parent 3cfefc8e
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
import argparse
@tilelang.jit(out_idx=[3, 4])
def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1):
scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim_qk]
k_shape = [batch, seq_len, head_kv, dim_qk]
v_shape = [batch, seq_len, head_kv, dim_v]
dtype = "float16"
accum_dtype = "float"
@T.prim_func
def flash_fwd(
Q: T.Tensor(q_shape, dtype), # type: ignore
K: T.Tensor(k_shape, dtype), # type: ignore
V: T.Tensor(v_shape, dtype), # type: ignore
Output: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=256) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim_qk], dtype)
K_shared = T.alloc_shared([block_N, dim_qk], dtype)
V_shared = T.alloc_shared([block_N, dim_v], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim_v], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = (
T.ceildiv(
(bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N))
for k in T.Pipelined(loop_range, num_stages=1):
T.copy(K[bz, k * block_N:(k + 1) * block_N, by // groups, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared)
T.copy(scores_max, scores_max_prev)
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, dim_v):
acc_o[i, j] *= scores_scale[i]
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.copy(acc_s, acc_s_cast)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
for i, j in T.Parallel(block_M, dim_v):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])
for i in T.Parallel(block_M):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M])
return flash_fwd
@tilelang.jit(out_idx=[2])
def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v):
dtype = "float16"
accum_dtype = "float"
shape = [batch, seq_len, heads, dim_v]
blk = 32
@T.prim_func
def flash_bwd_prep(
O: T.Tensor(shape, dtype), # type: ignore
dO: T.Tensor(shape, dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz):
o = T.alloc_fragment([blk, blk], dtype)
do = T.alloc_fragment([blk, blk], dtype)
acc = T.alloc_fragment([blk, blk], accum_dtype)
delta = T.alloc_fragment([blk], accum_dtype)
T.clear(acc)
for k in range(T.ceildiv(dim_v, blk)):
T.copy(O[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o)
T.copy(dO[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], do)
for i, j in T.Parallel(blk, blk):
acc[i, j] += o[i, j] * do[i, j]
T.reduce_sum(acc, delta, 1)
T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk])
return flash_bwd_prep
def make_dq_layout(dQ):
# atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment
return T.Layout(dQ.shape,
lambda b, l, h, d: [b, l // 8, h, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2])
@tilelang.jit(out_idx=[1])
def flashattn_bwd_postprocess(batch, heads, seq_len, dim_qk):
dtype = "float16"
accum_dtype = "float"
shape = [batch, seq_len, heads, dim_qk]
blk = 64
@T.prim_func
def flash_bwd_post(
dQ: T.Tensor(shape, accum_dtype), # type: ignore
dQ_out: T.Tensor(shape, dtype), # type: ignore
):
with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz):
T.copy(
dQ[bz, bx * blk:(bx + 1) * blk, by, :],
dQ_out[bz, bx * blk:(bx + 1) * blk, by, :],
)
return flash_bwd_post
@tilelang.jit
def flashattn_bwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1):
sm_scale = (1.0 / dim_qk)**0.5
scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim_qk]
k_shape = [batch, seq_len, head_kv, dim_qk]
v_shape = [batch, seq_len, head_kv, dim_v]
dtype = "float16"
accum_dtype = "float"
@T.prim_func
def flash_bwd(
Q: T.Tensor(q_shape, dtype), # type: ignore
K: T.Tensor(k_shape, dtype), # type: ignore
V: T.Tensor(v_shape, dtype), # type: ignore
dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
dQ: T.Tensor(q_shape, accum_dtype), # type: ignore
dK: T.Tensor(k_shape, accum_dtype), # type: ignore
dV: T.Tensor(v_shape, accum_dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=128) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim_qk], dtype)
dsT_shared = T.alloc_shared([block_M, block_N], dtype)
q = T.alloc_shared([block_N, dim_qk], dtype)
V_shared = T.alloc_shared([block_M, dim_v], dtype)
qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
lse_shared = T.alloc_shared([block_N], accum_dtype)
delta = T.alloc_shared([block_N], accum_dtype)
do = T.alloc_shared([block_N, dim_v], dtype)
dv = T.alloc_fragment([block_M, dim_v], accum_dtype)
dk = T.alloc_fragment([block_M, dim_qk], accum_dtype)
dq = T.alloc_fragment([block_N, dim_qk], accum_dtype)
T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared)
T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // groups, :], V_shared)
T.clear(dv)
T.clear(dk)
loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
loop_ed = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed, num_stages=1):
T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q)
T.clear(qkT)
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared)
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
if is_causal:
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j],
0)
T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do)
T.clear(dsT)
T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(qkT, qkT_cast)
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta)
for i, j in T.Parallel(block_M, block_N):
dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow)
T.copy(dsT_cast, dsT_shared)
T.clear(dq)
T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
for i, j in T.Parallel(block_N, dim_qk):
if k * block_N + i < seq_len:
T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
for i, j in T.Parallel(block_M, dim_v):
T.atomic_add(dV[bz, by * block_M + i, bx // groups, j], dv[i, j])
for i, j in T.Parallel(block_M, dim_qk):
T.atomic_add(dK[bz, by * block_M + i, bx // groups, j], dk[i, j])
return flash_bwd
@torch.compile
class _attention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, causal, groups=1):
BATCH, N_CTX, H, D_HEAD_QK = q.shape
D_HEAD_V = v.shape[-1]
block_M = 128
block_N = 64
mod = flashattn_fwd(BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, causal, block_M, block_N, groups)
o, lse = mod(q, k, v)
ctx.save_for_backward(q, k, v, o, lse)
ctx.causal = causal
return o
@staticmethod
def backward(ctx, do):
q, k, v, o, lse = ctx.saved_tensors
BATCH, N_CTX, H, D_HEAD_QK = q.shape
HEAD_KV, D_HEAD_V, = v.shape[-2], v.shape[-1]
groups = H // HEAD_KV
def maybe_contiguous(x):
if x.stride(-1) != 1:
return x.contiguous()
return x
do, q, k, v, o = [maybe_contiguous(x) for x in (do, q, k, v, o)]
block_M = 64
block_N = 32
mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD_V)
mod_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD_QK)
delta = mod_prep(o, do)
kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N,
groups)
shape_q = [BATCH, N_CTX, H, D_HEAD_QK]
shape_k = [BATCH, N_CTX, HEAD_KV, D_HEAD_QK]
shape_v = [BATCH, N_CTX, HEAD_KV, D_HEAD_V]
dq = torch.zeros(shape_q, dtype=torch.float32, device=q.device)
dk = torch.zeros(shape_k, dtype=torch.float32, device=q.device)
dv = torch.zeros(shape_v, dtype=torch.float32, device=q.device)
kernel(q, k, v, do, lse, delta, dq, dk, dv)
dq = mod_post(dq)
return dq, dk, dv, None, None
attention = _attention.apply
def ref_program(Q, K, V, is_causal, groups=1):
# Q: [B, T, HQ, D_QK]
# K: [B, T, HK, D_QK]
# V: [B, T, HV, D_V]
# HQ = HKV * groups
assert Q.size(2) == K.size(
2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}"
assert Q.size(2) == V.size(
2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}"
dim_qk = Q.size(-1)
K = K.repeat_interleave(groups, dim=2)
V = V.repeat_interleave(groups, dim=2)
scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
scores = scores / torch.sqrt(torch.tensor(dim_qk, dtype=scores.dtype))
if is_causal:
seq_len = Q.size(1)
mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float('-inf'))
attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
return output
def main(BATCH: int = 1,
H: int = 32,
N_CTX: int = 256,
D_HEAD_QK: int = 192,
D_HEAD_V: int = 128,
groups: int = 16,
causal: bool = False):
flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK
flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V
total_flops = 3 * flops_per_qk + 2 * flops_per_v
if causal:
total_flops *= 0.5
Q = (
torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half,
device="cuda").normal_().requires_grad_())
head_kv = H // groups
K = (
torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half,
device="cuda").normal_().requires_grad_())
V = (
torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half,
device="cuda").normal_().requires_grad_())
dO = (
torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half,
device="cuda").normal_().requires_grad_())
O = attention(Q, K, V, causal, groups)
O.backward(dO, retain_graph=True)
dQ, Q.grad = Q.grad.clone(), None
dK, K.grad = K.grad.clone(), None
dV, V.grad = V.grad.clone(), None
O_ref = ref_program(Q, K, V, causal, groups)
O_ref.backward(dO, retain_graph=True)
dQ_ref, Q.grad = Q.grad.clone(), None
dK_ref, K.grad = K.grad.clone(), None
dV_ref, V.grad = V.grad.clone(), None
torch.testing.assert_close(O, O_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
def run():
O_ref.backward(dO, retain_graph=True)
def run1():
O.backward(dO, retain_graph=True)
from tilelang.profiler import do_bench
latency = do_bench(run, warmup=500)
print("torch: {:.2f} ms".format(latency))
print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = do_bench(run1, warmup=500)
print("tilelang: {:.2f} ms".format(latency))
print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='Batch size')
parser.add_argument('--h', type=int, default=32, help='Number of heads')
parser.add_argument('--n_ctx', type=int, default=1024, help='Context size')
parser.add_argument('--d_head_qk', type=int, default=192, help='Head dimension for Q/K')
parser.add_argument('--d_head_v', type=int, default=128, help='Head dimension for V')
parser.add_argument('--causal', type=bool, default=False, help='Causal flag')
parser.add_argument('--groups', type=int, default=16, help='groups')
args = parser.parse_args()
main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal)
...@@ -32,12 +32,12 @@ def get_configs(): ...@@ -32,12 +32,12 @@ def get_configs():
"""Generates configurations for the autotuner, tailored for FA-2 style parallelism.""" """Generates configurations for the autotuner, tailored for FA-2 style parallelism."""
block_M = [32, 64, 128, 256] block_M = [32, 64, 128, 256]
block_N = [32, 64, 128, 256] block_N = [32, 64, 128, 256]
threads = [64, 128, 192, 256, 512, 1024] threads = [128, 256, 512]
num_split_q = [32, 64, 128, 256, 256] num_split_q = [64, 128, 256]
num_stages = [0] num_stages = [0]
enable_rasterization = [True] enable_rasterization = [True]
k_pack = [2] k_pack = [2]
panel_size = [7, 8, 9, 10] panel_size = [7, 8]
qk_coalesced_width = [8] qk_coalesced_width = [8]
v_coalesced_width = [4] v_coalesced_width = [4]
......
/root/miniconda3/envs/py312/bin/python3 examples/amd/example_amd_flash_attn_fwd.py \
--batch 2 \
--heads 16 \
--seq_len 4096 \
--dim 128 \
--is_causal \
--groups 2
/root/composable_kernel/build/bin/tile_example_fmha_fwd \
-b=2 -h=16 -s=4096 -d=128 -mask=t -v=1 -warmup=5 -repeat=20
/*!
* \file intrin_rule_hip.cc
* \brief HIP intrinsic rules.
*/
#include <tvm/tir/builtin.h>
#include <tvm/tir/op_attr_types.h>
#include "target/intrin_rule.h"
namespace tvm {
namespace codegen {
namespace intrin {
// Add float suffix to the intrinsics, HIP fast math.
using tir::FLowerIntrinsic;
struct HIPMath {
std::string operator()(DataType t, std::string name) const {
if (t.is_float()) {
switch (t.bits()) {
case 64:
return name;
case 32:
return name + 'f';
case 16: {
if (name == "fabs") {
return "__habs";
} else if (name == "round") {
return "hrint";
} else {
return "h" + name;
}
}
default:
return "";
}
} else if (t.is_bfloat16()) {
if (name == "fabs") {
return "__habs";
} else if (name == "round") {
return "hrint";
} else {
return "h" + name;
}
} else if (t.is_int() || t.is_uint()) {
switch (t.bits()) {
case 32:
return "__" + name;
case 64:
return "__" + name + "ll";
default:
return "";
}
}
return "";
}
};
struct HIPFastMath : public HIPMath {
std::string operator()(DataType t, std::string name) const {
if (t.is_float() && t.bits() == 32) {
return "__" + name + 'f';
} else {
return HIPMath::operator()(t, name);
}
return "";
}
};
struct HIPFastMathTan : public HIPMath {
std::string operator()(DataType t, std::string name) const {
if (t.is_float()) {
switch (t.bits()) {
case 64:
return name;
case 32:
return name + 'f';
case 16:
return std::string("h") + name;
default:
return "";
}
}
return "";
}
};
struct HIPPopcount {
std::string operator()(DataType t, std::string name) const {
if (t.is_uint()) {
switch (t.bits()) {
case 32:
return "__popc";
case 64:
return "__popcll";
default:
return "";
}
}
return "";
}
};
struct HIPWarpIntrinsic {
const Op operator()(DataType t, const Op &orig_op) const {
if (orig_op.same_as(builtin::tvm_warp_shuffle())) {
return Op::Get("tir.hip.__shfl_sync");
} else if (orig_op.same_as(builtin::tvm_warp_shuffle_up())) {
return Op::Get("tir.hip.__shfl_up_sync");
} else {
ICHECK(orig_op.same_as(builtin::tvm_warp_shuffle_down()));
return Op::Get("tir.hip.__shfl_down_sync");
}
}
};
static PrimExpr DispatchHIPWarpActiveMask(const PrimExpr &e) {
const CallNode *call = e.as<CallNode>();
ICHECK(call != nullptr);
return Call(call->dtype, Op::Get("tir.hip.__activemask"), {});
}
template <typename T> static PrimExpr DispatchHIPShuffle(const PrimExpr &e) {
// NOLINTBEGIN(clang-analyzer-cplusplus.InnerPointer)
const CallNode *call = e.as<CallNode>();
ICHECK(call != nullptr);
ICHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size
Array<PrimExpr> hip_args{
{call->args[0], call->args[1], call->args[2], call->args[3]}};
return Call(call->dtype, T()(call->dtype, Downcast<Op>(call->op)), hip_args);
// NOLINTEND(clang-analyzer-cplusplus.InnerPointer)
}
TVM_REGISTER_OP("tir.clz").set_attr<FLowerIntrinsic>(
"hip.FLowerIntrinsic",
DispatchPureExtern<HIPMath, /*dtype_from_arg=*/true>);
TVM_REGISTER_OP("tir.floor")
.set_attr<FLowerIntrinsic>("hip.FLowerIntrinsic",
DispatchPureExtern<HIPMath>);
TVM_REGISTER_OP("tir.ceil")
.set_attr<FLowerIntrinsic>("hip.FLowerIntrinsic",
DispatchPureExtern<HIPMath>);
TVM_REGISTER_OP("tir.trunc")
.set_attr<FLowerIntrinsic>("hip.FLowerIntrinsic",
DispatchPureExtern<HIPMath>);
TVM_REGISTER_OP("tir.fabs")
.set_attr<FLowerIntrinsic>("hip.FLowerIntrinsic",
DispatchPureExtern<HIPMath>);
TVM_REGISTER_OP("tir.round")
.set_attr<FLowerIntrinsic>("hip.FLowerIntrinsic",
DispatchPureExtern<HIPMath>);
TVM_REGISTER_OP("tir.nearbyint")
.set_attr<FLowerIntrinsic>("hip.FLowerIntrinsic",
DispatchPureExtern<HIPMath>);
TVM_REGISTER_OP("tir.exp").set_attr<FLowerIntrinsic>(
"hip.FLowerIntrinsic", DispatchPureExtern<HIPFastMath>);
TVM_REGISTER_OP("tir.exp2")
.set_attr<FLowerIntrinsic>("hip.FLowerIntrinsic",
DispatchPureExtern<HIPMath>);
TVM_REGISTER_OP("tir.exp10")
.set_attr<FLowerIntrinsic>("hip.FLowerIntrinsic",
DispatchPureExtern<HIPFastMath>);
TVM_REGISTER_OP("tir.erf").set_attr<FLowerIntrinsic>(
"hip.FLowerIntrinsic", DispatchPureExtern<HIPMath>);
TVM_REGISTER_OP("tir.log").set_attr<FLowerIntrinsic>(
"hip.FLowerIntrinsic", DispatchPureExtern<HIPFastMath>);
TVM_REGISTER_OP("tir.log2")
.set_attr<FLowerIntrinsic>("hip.FLowerIntrinsic",
DispatchPureExtern<HIPFastMath>);
TVM_REGISTER_OP("tir.log10")
.set_attr<FLowerIntrinsic>("hip.FLowerIntrinsic",
DispatchPureExtern<HIPFastMath>);
TVM_REGISTER_OP("tir.tan").set_attr<FLowerIntrinsic>(
"hip.FLowerIntrinsic", DispatchPureExtern<HIPFastMathTan>);
TVM_REGISTER_OP("tir.cos").set_attr<FLowerIntrinsic>(
"hip.FLowerIntrinsic", DispatchPureExtern<HIPFastMath>);
TVM_REGISTER_OP("tir.cosh")
.set_attr<FLowerIntrinsic>("hip.FLowerIntrinsic",
DispatchPureExtern<HIPMath>);
TVM_REGISTER_OP("tir.sin").set_attr<FLowerIntrinsic>(
"hip.FLowerIntrinsic", DispatchPureExtern<HIPFastMath>);
TVM_REGISTER_OP("tir.sinh")
.set_attr<FLowerIntrinsic>("hip.FLowerIntrinsic",
DispatchPureExtern<HIPMath>);
TVM_REGISTER_OP("tir.atan")
.set_attr<FLowerIntrinsic>("hip.FLowerIntrinsic",
DispatchPureExtern<HIPMath>);
TVM_REGISTER_OP("tir.tanh")
.set_attr<FLowerIntrinsic>("hip.FLowerIntrinsic",
DispatchPureExtern<HIPMath>);
TVM_REGISTER_OP("tir.sqrt")
.set_attr<FLowerIntrinsic>("hip.FLowerIntrinsic",
DispatchPureExtern<HIPMath>);
TVM_REGISTER_OP("tir.pow").set_attr<FLowerIntrinsic>(
"hip.FLowerIntrinsic", DispatchPureExtern<HIPMath>);
TVM_REGISTER_OP("tir.popcount")
.set_attr<FLowerIntrinsic>("hip.FLowerIntrinsic",
DispatchPureExtern<HIPPopcount>);
TVM_REGISTER_OP("tir.tvm_warp_shuffle")
.set_attr<FLowerIntrinsic>("hip.FLowerIntrinsic",
DispatchHIPShuffle<HIPWarpIntrinsic>);
TVM_REGISTER_OP("tir.tvm_warp_shuffle_up")
.set_attr<FLowerIntrinsic>("hip.FLowerIntrinsic",
DispatchHIPShuffle<HIPWarpIntrinsic>);
TVM_REGISTER_OP("tir.tvm_warp_shuffle_down")
.set_attr<FLowerIntrinsic>("hip.FLowerIntrinsic",
DispatchHIPShuffle<HIPWarpIntrinsic>);
TVM_REGISTER_OP("tir.tvm_warp_activemask")
.set_attr<FLowerIntrinsic>("hip.FLowerIntrinsic",
DispatchHIPWarpActiveMask);
TVM_REGISTER_OP("tir.fmod")
.set_attr<FLowerIntrinsic>("hip.FLowerIntrinsic",
DispatchPureExtern<HIPMath>);
// Register low-level builtin ops.
TVM_REGISTER_OP("tir.hip.__shfl_sync")
.set_num_inputs(4)
.add_argument("mask", "Expr", "The thread mask.")
.add_argument("var", "Expr", "The variable to sync.")
.add_argument("lane", "Expr", "The source thread id.")
.add_argument("width", "Expr",
"The warp thread width, must be a power of 2.")
.set_attr<TGlobalSymbol>("TGlobalSymbol", "__shfl_sync")
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque))
.set_attr<bool>("hip.need_warp_shuffle", true);
TVM_REGISTER_OP("tir.hip.__shfl_up_sync")
.set_num_inputs(4)
.add_argument("mask", "Expr", "The thread mask.")
.add_argument("var", "Expr", "The variable to sync.")
.add_argument("delta", "Expr", "The source lane id offset to be added.")
.add_argument("width", "Expr",
"The warp thread width, must be a power of 2.")
.set_attr<TGlobalSymbol>("TGlobalSymbol", "__shfl_up_sync")
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque))
.set_attr<bool>("hip.need_warp_shuffle", true);
TVM_REGISTER_OP("tir.hip.__shfl_down_sync")
.set_num_inputs(4)
.add_argument("mask", "Expr", "The thread mask.")
.add_argument("var", "Expr", "The variable to sync.")
.add_argument("delta", "Expr",
"The source lane id offset to be subtracted.")
.add_argument("width", "Expr",
"The warp thread width, must be a power of 2.")
.set_attr<TGlobalSymbol>("TGlobalSymbol", "__shfl_down_sync")
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque))
.set_attr<bool>("hip.need_warp_shuffle", true);
TVM_REGISTER_OP("tir.hip.__activemask")
.set_num_inputs(0)
.set_attr<TGlobalSymbol>("TGlobalSymbol", "__activemask")
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kPure))
.set_attr<bool>("hip.need_warp_shuffle", true);
} // namespace intrin
} // namespace codegen
} // namespace tvm
\ No newline at end of file
...@@ -94,8 +94,6 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule: ...@@ -94,8 +94,6 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
# Infer memory layouts for fragments and shared memory # Infer memory layouts for fragments and shared memory
mod = tilelang.transform.LayoutInference()(mod) mod = tilelang.transform.LayoutInference()(mod)
# Lower high-level tile operations to low-level operations # Lower high-level tile operations to low-level operations
print("LowerTileOp")
print(mod.script())
mod = tilelang.transform.LowerTileOp()(mod) mod = tilelang.transform.LowerTileOp()(mod)
# Lower l2 persistent map # Lower l2 persistent map
mod = tilelang.transform.LowerL2Persistent()(mod) mod = tilelang.transform.LowerL2Persistent()(mod)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment