".github/workflows/vscode:/vscode.git/clone" did not exist on "441bae7f6600c5a48d359c69b83676657c1fa961"
Unverified Commit 64f5f847 authored by shenggan's avatar shenggan Committed by GitHub
Browse files

[fastnn] add triton flash attention (#109)

* add triton flash attention

* add fallback for flash attention

* add pytest skip mark for attention kernel
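For context, here is a minimal usage sketch of the new `fused_attention_core` entry point added in this PR. The import path and tensor shapes follow the unit test at the bottom of the diff; the sizes are illustrative, and a CUDA device with Triton installed is assumed:

```python
import torch
from fastfold.model.fastnn.kernel import fused_attention_core

# illustrative sizes: (batch, chunk, heads, seq_len, head_dim), as in the test below
b1, b2, h, n, d = 1, 8, 4, 256, 32
q = torch.randn(b1, b2, h, n, d, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)
mask = torch.ones(b1, b2, n, dtype=torch.float16, device="cuda")     # 1 = keep, 0 = masked out
bias = torch.randn(b1, h, n, n, dtype=torch.float16, device="cuda")  # pair bias, shared across chunks

out = fused_attention_core(q, k, v, mask, bias)
print(out.shape)  # torch.Size([1, 8, 256, 128]) -> (b1, b2, n, h * d)
```

If Triton cannot be imported, the same call transparently falls back to the plain PyTorch implementation shown in the module below.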
from .jit.fused_ops import bias_dropout_add, bias_sigmod_ele, bias_ele_dropout_residual
from .layer_norm import FusedLayerNorm as LayerNorm
from .softmax import fused_softmax
from .attention_core import fused_attention_core

__all__ = [
    "bias_dropout_add",
    "bias_sigmod_ele",
    "bias_ele_dropout_residual",
    "LayerNorm",
    "fused_softmax",
    "fused_attention_core",
]
import math
import logging

import torch
from einops import rearrange

_triton_available = True
try:
    from .triton.attention_core import attention_core_triton_kernel_wrapper
except ImportError:
    logging.warning("Triton is not available, falling back to the PyTorch attention kernel.")
    _triton_available = False


def _torch_attention_core(q, k, v, mask, bias):
    # Plain PyTorch reference path, used when Triton cannot be imported.
    scaling = 1. / math.sqrt(q.size(-1))
    q = q * scaling
    logits = torch.matmul(q, k.transpose(-1, -2))
    logits += bias
    # mask is 0/1 over key positions; masked positions get a large negative logit
    logits += (1e20 * (mask - 1))[..., :, None, None, :]
    weights = torch.nn.functional.softmax(logits.float(), -1).to(dtype=q.dtype)
    weighted_avg = torch.matmul(weights, v)
    weighted_avg = rearrange(weighted_avg, 'b1 b2 h n d -> b1 b2 n (h d)')
    return weighted_avg


class FusedAttenionCoreFunc(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, k, v, mask=None, bias=None):
        if _triton_available:
            o = attention_core_triton_kernel_wrapper(q, k, v, mask, bias)
        else:
            o = _torch_attention_core(q, k, v, mask, bias)
        # backward is not implemented yet; the tensors needed for it would be saved here, e.g.
        # ctx.save_for_backward(q, k, v, o, L, m, mask, bias)
        return o


fused_attention_core = FusedAttenionCoreFunc.apply
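The Triton kernel below follows the FlashAttention-style streaming softmax: each program keeps a per-row running max (`m_i`), running normalizer (`l_i`), and a rescaled output accumulator (`acc`), so the full `N_CTX x N_CTX` score matrix is never materialized. Here is a minimal PyTorch sketch of the same update rule; the function name and the single-head `(m, d)` / `(n, d)` layout are illustrative and not part of the module:

```python
import torch

def streaming_softmax_matmul(q, k, v, block=128):
    """Online-softmax accumulation, mirroring m_i / l_i / acc in the kernel below.

    q: (m, d), k/v: (n, d). Processes k/v in column blocks while keeping a running
    max, running normalizer, and a rescaled accumulator per query row.
    """
    m_len, d = q.shape
    m_i = torch.full((m_len,), float("-inf"))   # running max of the scores
    l_i = torch.zeros(m_len)                    # running softmax normalizer
    acc = torch.zeros(m_len, d)                 # running (already normalized) output
    for start in range(0, k.shape[0], block):
        kb, vb = k[start:start + block], v[start:start + block]
        qk = q @ kb.T / d ** 0.5                # scores for this block
        m_ij = qk.max(dim=1).values             # block-local max
        p = torch.exp(qk - m_ij[:, None])
        l_ij = p.sum(dim=1)
        m_new = torch.maximum(m_i, m_ij)        # new running max
        alpha = torch.exp(m_i - m_new)          # rescales the old statistics
        beta = torch.exp(m_ij - m_new)          # rescales the new block
        l_new = alpha * l_i + beta * l_ij
        acc = acc * (l_i / l_new * alpha)[:, None] + (p * (beta / l_new)[:, None]) @ vb
        m_i, l_i = m_new, l_new
    return acc

# sanity check against the naive softmax (float32, seq length not a multiple of block)
q, k, v = (torch.randn(300, 32) for _ in range(3))
ref = torch.softmax(q @ k.T / 32 ** 0.5, dim=-1) @ v
assert torch.allclose(streaming_softmax_matmul(q, k, v), ref, atol=1e-4)
```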
import math

import torch
from einops import rearrange

import triton
import triton.language as tl

# CREDITS: initially inspired by the Triton fused attention tutorial.


@triton.jit
def _attention_core(Q, K, V, mask, bias, sm_scale, TMP, Out, stride_qz, stride_qh, stride_qm,
                    stride_qk, stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh,
                    stride_vn, stride_vk, stride_oz, stride_oh, stride_om, stride_on, Z, H, N_CTX,
                    BATCH, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
                    use_mask: tl.constexpr, use_bias: tl.constexpr):
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    # initialize offsets (q, k and v share the same layout, so the q strides are reused)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
    off_k = off_hz * stride_qh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk
    off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk
    # initialize pointers to Q, K, V
    q_ptrs = Q + off_q
    k_ptrs = K + off_k
    v_ptrs = V + off_v
    # initialize offsets into bias and mask
    if use_bias:
        # bias is laid out as (BATCH, H, N_CTX, N_CTX) and shared across the chunk dimension
        batch_2 = Z // BATCH
        off_hz_bias = (off_hz // (batch_2 * H) * H) + (off_hz % H)
        offs_base_bias = off_hz_bias * (N_CTX * N_CTX) + offs_m[:, None] * N_CTX + offs_n[None, :]
    if use_mask:
        # mask is laid out as (BATCH * chunk, N_CTX): one 0/1 value per key position
        off_hz_mask = (off_hz // H)
        offs_base_mask = off_hz_mask * N_CTX
    # initialize the running max (m_i), running normalizer (l_i) and output accumulator
    t_ptrs = TMP + off_hz * N_CTX + offs_m
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # load q: it will stay in SRAM throughout
    q_load_mask = offs_m[:, None] < N_CTX
    q = tl.load(q_ptrs, mask=q_load_mask, other=0.0)
    # loop over k, v and update the accumulator
    for start_n in range(0, N_CTX, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        load_mask = (start_n + offs_n)[:, None] < N_CTX
        # -- compute qk ----
        k = tl.load(k_ptrs + start_n * stride_kn, mask=load_mask, other=0.0)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k, trans_b=True)
        qk *= sm_scale
        # out-of-range rows/columns get a large negative score
        qk = tl.where(offs_m[:, None] >= N_CTX, float("-1e20"), qk)
        qk = tl.where((start_n + offs_n)[None, :] >= N_CTX, float("-1e20"), qk)
        if use_bias:
            bias_load_mask = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            bias_load_mask = tl.where(offs_m[:, None] >= N_CTX, 1., bias_load_mask)
            bias_load_mask = tl.where((start_n + offs_n)[None, :] >= N_CTX, 1., bias_load_mask)
            bias_data = tl.load(bias + offs_base_bias + start_n,
                                mask=(bias_load_mask == 0.),
                                other=0.)
            qk += bias_data
        if use_mask:
            mask_data = tl.load(mask + offs_base_mask + offs_n + start_n,
                                mask=(start_n + offs_n) < N_CTX,
                                other=0.)
            qk = tl.where(mask_data[None, :] == 0., float("-1e20"), qk)
        # -- compute m_ij, p, l_ij
        m_ij = tl.max(qk, 1)
        p = tl.exp(qk - m_ij[:, None])
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i
        m_i_new = tl.maximum(m_i, m_ij)
        alpha = tl.exp(m_i - m_i_new)
        beta = tl.exp(m_ij - m_i_new)
        l_i_new = alpha * l_i + beta * l_ij
        # -- update output accumulator --
        # scale p
        p_scale = beta / l_i_new
        p = p * p_scale[:, None]
        # scale acc
        acc_scale = l_i / l_i_new * alpha
        tl.store(t_ptrs, acc_scale, mask=(offs_m < N_CTX))
        acc_scale = tl.load(TMP + off_hz * N_CTX + start_m * BLOCK_M + tl.arange(0, BLOCK_M),
                            mask=(start_m * BLOCK_M + tl.arange(0, BLOCK_M) < N_CTX),
                            other=float(0.))  # BUG: have to store and immediately load
        acc = acc * acc_scale[:, None]
        # update acc
        load_mask = (start_n + offs_n)[:, None] < N_CTX
        v = tl.load(v_ptrs + start_n * stride_vn, mask=load_mask, other=0.)
        p = p.to(Q.dtype.element_ty)
        acc += tl.dot(p, v)
        # update m_i and l_i
        l_i = l_i_new
        m_i = m_i_new
    # rematerialize offsets to save registers
    start_m = tl.program_id(0)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # write back l and m (only needed once a backward pass is implemented)
    # l_ptrs = L + off_hz * N_CTX + offs_m
    # m_ptrs = M + off_hz * N_CTX + offs_m
    # tl.store(l_ptrs, l_i)
    # tl.store(m_ptrs, m_i)
    # initialize pointers to the output
    offs_n = tl.arange(0, BLOCK_DMODEL)
    off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
    out_ptrs = Out + off_o
    out_store_mask = offs_m[:, None] < N_CTX
    tl.store(out_ptrs, acc, mask=out_store_mask)


def attention_core_triton_kernel_wrapper(q, k, v, mask, bias):
    assert (q.dtype in [torch.float16,
                        torch.bfloat16]), "triton flash attention only supports float16/bfloat16 for now"
    q_ori_size = list(q.size())
    batch = q_ori_size[0]
    if len(q_ori_size) == 5:
        # fold (batch, chunk) into a single leading dimension
        q = rearrange(q, 'b1 b2 h n d -> (b1 b2) h n d')
        k = rearrange(k, 'b1 b2 h n d -> (b1 b2) h n d')
        v = rearrange(v, 'b1 b2 h n d -> (b1 b2) h n d')
    sm_scale = 1. / math.sqrt(q.size(-1))
    BLOCK = 128
    # shape constraints
    Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
    assert Lq == Lk and Lk == Lv
    assert Lk in {16, 32, 64, 128}
    o = torch.empty_like(q)
    # one program per (query block, batch*head) pair
    grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])
    tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
    num_warps = 4 if Lk <= 64 else 8
    _attention_core[grid](
        q,
        k,
        v,
        mask,
        bias,
        sm_scale,
        tmp,
        o,
        q.stride(0),
        q.stride(1),
        q.stride(2),
        q.stride(3),
        k.stride(0),
        k.stride(1),
        k.stride(2),
        k.stride(3),
        v.stride(0),
        v.stride(1),
        v.stride(2),
        v.stride(3),
        o.stride(0),
        o.stride(1),
        o.stride(2),
        o.stride(3),
        q.shape[0],
        q.shape[1],
        q.shape[2],
        batch,
        BLOCK_M=BLOCK,
        BLOCK_N=BLOCK,
        BLOCK_DMODEL=Lk,
        use_mask=(mask is not None),
        use_bias=(bias is not None),
        num_warps=num_warps,
        num_stages=1,
    )
    if len(q_ori_size) == 5:
        o = rearrange(o, '(b1 b2) h n d -> b1 b2 n (h d)', b1=batch)
    return o
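For reference, the mask and bias enter the attention logits with different broadcasting rules: the 0/1 mask covers key positions and is turned into an additive `-1e20` term, while the pair bias of shape `(batch, heads, n, n)` is shared across the chunk dimension (the Triton kernel indexes it by `(batch, head)` explicitly; the PyTorch reference relies on broadcasting with `batch == 1`, as in the test below). A small shape sketch with hypothetical sizes:

```python
import torch

b1, b2, h, n = 1, 8, 4, 16
logits = torch.zeros(b1, b2, h, n, n)
mask = torch.randint(0, 2, (b1, b2, n)).float()  # 1 = keep, 0 = masked out
bias = torch.randn(b1, h, n, n)                  # pair bias, shared across chunks

# mask -> (b1, b2, 1, 1, n): adds -1e20 to every logit whose key position is masked
masked = logits + (1e20 * (mask - 1))[..., :, None, None, :]
# bias -> broadcast to (b1, b2, h, n, n); with batch == 1 its leading dim spreads across chunks
biased = masked + bias

print(masked.shape, biased.shape)  # both torch.Size([1, 8, 4, 16, 16])
```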
import math

import pytest
import torch
from einops import rearrange

TEST_TRITON = True
try:
    from fastfold.model.fastnn.kernel import fused_attention_core
except ImportError:
    print("Skip triton attention test!")
    TEST_TRITON = False


def torch_core_attention(q, k, v, mask, bias):
    # float32 reference implementation of the fused attention core
    scaling = 1. / math.sqrt(q.size(-1))
    q = q * scaling
    logits = torch.matmul(q.float(), k.float().transpose(-1, -2))
    logits += bias.float()
    logits += (1e20 * (mask - 1))[..., :, None, None, :]
    weights = torch.nn.functional.softmax(logits.float(), -1).to(dtype=q.dtype)
    weighted_avg = torch.matmul(weights, v)
    weighted_avg = rearrange(weighted_avg, 'b1 b2 h n d -> b1 b2 n (h d)')
    return weighted_avg


@pytest.mark.skipif(not TEST_TRITON, reason="triton is not available")
def test_fused_attention_core():
    batch_, chunk_, head_, d_head = 1, 8, 4, 32
    # include sequence lengths that are not multiples of the kernel's BLOCK size
    test_seq_ = [32, 256, 370, 500, 512, 700, 1024, 1600]
    test_dtype = [torch.float16, torch.bfloat16]
    test_device = torch.device("cuda")
    tolerance_eps = {torch.float16: 1e-4, torch.bfloat16: 1e-4}
    for seq_ in test_seq_:
        for dtype in test_dtype:
            q = torch.empty((batch_, chunk_, head_, seq_, d_head), dtype=dtype,
                            device="cuda").normal_(mean=0, std=.5).requires_grad_()
            k = torch.empty((batch_, chunk_, head_, seq_, d_head), dtype=dtype,
                            device="cuda").normal_(mean=0, std=.5).requires_grad_()
            v = torch.empty((batch_, chunk_, head_, seq_, d_head), dtype=dtype,
                            device="cuda").normal_(mean=0, std=.5).requires_grad_()
            mask = torch.empty(
                (batch_, chunk_, seq_), device="cuda").normal_(mean=0, std=.5) > 0
            mask = mask.to(device=test_device, dtype=dtype).requires_grad_(False)
            bias = torch.randn(batch_, head_, seq_, seq_).to(device=test_device,
                                                             dtype=dtype).requires_grad_(True)
            ref_out = torch_core_attention(q, k, v, mask, bias)
            tri_out = fused_attention_core(q, k, v, mask, bias)
            # compare
            assert torch.allclose(ref_out, tri_out, atol=tolerance_eps[dtype])


if __name__ == "__main__":
    test_fused_attention_core()