Commit 4b535e64 authored by zhangshao's avatar zhangshao
Browse files

update

parent 34e67b1e
import math
import time
import pytest
import torch
import random
import torch.nn.functional as F
import csv
from einops import rearrange, repeat
# from flash_attn import flash_attn_with_kvcache as _flash_attn_with_kvcache
from flash_attn import vllm_flash_attn_with_kvcache as _flash_attn_with_kvcache
# Sequence length that the block table is sized for when `eager` is False
# (see _generate_block_kvcache). A previous setting was 4352.
max_seqlen = 5 * 8192
# True: the block table covers exactly seqlen_k tokens.
# False: the block table is sized for max_seqlen tokens instead.
eager = True
def _construct_local_mask(
    seqlen_q,
    seqlen_k,
    window_size=(-1, -1),
    query_padding_mask=None,
    key_padding_mask=None,
    device=None,
    key_leftpad=None,
):
    """Build the boolean mask of positions that must NOT be attended to.

    Arguments:
        seqlen_q, seqlen_k: padded query / key lengths.
        window_size: (left, right) window; a negative entry means unbounded.
        query_padding_mask: optional (batch_size, seqlen_q) bool, True = valid.
        key_padding_mask: optional (batch_size, seqlen_k) bool, True = valid.
        device: device on which the index tensors are created.
        key_leftpad: optional (batch_size,) number of left-padding keys.
    Output:
        Bool tensor broadcastable to (batch_size, nheads, seqlen_q, seqlen_k);
        True marks a masked-out position.
    """
    row_idx = torch.arange(seqlen_q, device=device, dtype=torch.long)[:, None]
    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
    if key_leftpad is not None:
        key_leftpad = key_leftpad[:, None, None, None]
        col_idx = col_idx.expand(key_leftpad.shape[0], 1, 1, seqlen_k)
        # Columns inside the left padding get a huge index so they always fall
        # outside the window.
        col_idx = torch.where(
            col_idx >= key_leftpad,
            col_idx - key_leftpad,
            torch.full_like(col_idx, 2**32),
        )
    # Per-batch number of valid keys / queries (bottom-right aligned window).
    sk = (
        seqlen_k
        if key_padding_mask is None
        else key_padding_mask.sum(-1)[:, None, None, None]
    )
    sq = (
        seqlen_q
        if query_padding_mask is None
        else query_padding_mask.sum(-1)[:, None, None, None]
    )
    if window_size[0] < 0:
        return col_idx > row_idx + sk - sq + window_size[1]
    sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
    return torch.logical_or(
        col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),
        col_idx < row_idx + sk - sq - window_size[0],
    )


def attention_ref(
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    attn_bias=None,
    dropout_p=0.0,
    dropout_mask=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite window size
    softcap=0.0,
    upcast=True,
    reorder_ops=False,
    key_leftpad=None,
):
    """Reference (pure PyTorch) attention used to validate the fused kernels.

    Arguments:
        q: (batch_size, seqlen_q, nheads, head_dim)
        k: (batch_size, seqlen_k, nheads_k, head_dim)
        v: (batch_size, seqlen_k, nheads_k, head_dim)
        query_padding_mask: (batch_size, seqlen_q)
        key_padding_mask: (batch_size, seqlen_k)
        attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
        dropout_p: float
        dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
        causal: whether to apply causal masking
        window_size: (int, int), left and right window size
        softcap: if > 0, scores are replaced by softcap * tanh(scores / softcap)
        upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
            output back to fp16/bf16.
        reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
            without changing the math. This is to estimate the numerical error from operation
            reordering.
        key_leftpad: optional (batch_size,) count of left-padding keys, used only
            when a causal/local mask is built.
    Output:
        output: (batch_size, seqlen_q, nheads, head_dim)
        attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout
    """
    if causal:
        window_size = (window_size[0], 0)
    dtype_og = q.dtype
    if upcast:
        q, k, v = q.float(), k.float(), v.float()
    seqlen_q, seqlen_k = q.shape[1], k.shape[1]
    # Expand the KV heads so that every query head has a matching key/value head
    # (each KV head is repeated consecutively, i.e. "b s h d -> b s (h g) d").
    k = k.repeat_interleave(q.shape[2] // k.shape[2], dim=2)
    v = v.repeat_interleave(q.shape[2] // v.shape[2], dim=2)
    d = q.shape[-1]
    if not reorder_ops:
        scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
    else:
        scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d))
    if softcap > 0:
        scores = scores / softcap
        scores = scores.tanh()
        scores = scores * softcap
    if key_padding_mask is not None:
        scores.masked_fill_((~key_padding_mask)[:, None, None, :], float("-inf"))
    # The causal / sliding-window mask must be built and applied here; without
    # it `local_mask` below is undefined and causal calls fail with NameError.
    local_mask = None
    if window_size[0] >= 0 or window_size[1] >= 0:
        local_mask = _construct_local_mask(
            seqlen_q,
            seqlen_k,
            window_size,
            query_padding_mask,
            key_padding_mask,
            q.device,
            key_leftpad=key_leftpad,
        )
        scores.masked_fill_(local_mask, float("-inf"))
    if attn_bias is not None:
        scores = scores + attn_bias
    attention = torch.softmax(scores, dim=-1).to(v.dtype)
    # Some rows might be completely masked out so we fill them with zero instead of NaN
    if local_mask is not None:
        attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0)
    # We want to mask here so that the attention matrix doesn't have any NaNs
    # Otherwise we'll get NaN in dV
    if query_padding_mask is not None:
        attention = attention.masked_fill((~query_padding_mask)[:, None, :, None], 0.0)
    dropout_scaling = 1.0 / (1 - dropout_p)
    if dropout_mask is not None:
        attention_drop = attention.masked_fill(~dropout_mask, 0.0)
    else:
        attention_drop = attention
    # Scaling v instead of the attention matrix is mathematically equivalent.
    output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
    if query_padding_mask is not None:
        output.masked_fill_((~query_padding_mask)[:, :, None, None], 0.0)
    return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
def test_flash_attn_kvcache(
    seqlen_q,
    seqlen_k,
    d,
    has_batch_idx,
    has_leftpad,
    paged_kv_block_size,
    rotary_fraction,
    rotary_interleaved,
    seqlen_new_eq_seqlen_q,
    causal,
    local,
    alibi,
    new_kv,
    dtype,
    batch_size,
    qhead,
    kv_head,
    prof=False,
):
    """Benchmark one decode-style call of the KV-cache attention kernel.

    Random inputs are generated on the GPU, the kernel is warmed up, then timed
    over 100 launches. The mean latency and an estimated memory bandwidth are
    printed and returned.

    Only q, the K/V cache, cache_seqlens, block_table, causal, max_seqlen_k and
    the q/k/v scales are passed to the kernel. The remaining flags
    (has_leftpad, rotary_*, local, alibi, new_kv, prof) do not change the
    kernel call; has_batch_idx only doubles the dense cache batch size.

    The numerical comparison that used to follow the ``return`` was unreachable
    and referenced undefined tensors; it and the reference copies that only it
    consumed have been removed.

    Arguments:
        seqlen_q: query length.
        seqlen_k: number of cached tokens per sequence.
        d: head dimension.
        paged_kv_block_size: tokens per cache block, or None for a dense cache.
        causal: forwarded to the kernel.
        dtype: dtype of q and of the KV cache.
        batch_size, qhead, kv_head: batch size and query / KV head counts.
    Output:
        [paged_kv_block_size, batch_size, seqlen_k, d, qhead, kv_head,
         time per call in microseconds, estimated bandwidth in GB/s]
    """
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    paged = paged_kv_block_size is not None
    batch_size_cache = batch_size if not has_batch_idx else batch_size * 2

    # The next few values are not forwarded to the kernel; they are still drawn
    # in the original order so the random inputs stay the same as before.
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, qhead, d, device=device, dtype=dtype)
    seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item()
    if new_kv:
        k = torch.randn(batch_size, seqlen_new, kv_head, d, device=device, dtype=dtype)
        v = torch.randn(batch_size, seqlen_new, kv_head, d, device=device, dtype=dtype)
    else:
        k, v = None, None

    # Cache element type; torch.float8_e5m2 / torch.float8_e4m3fn were tried here.
    new_type = dtype

    # Build the KV cache.
    if not paged:
        k_arg = torch.randn(batch_size_cache, seqlen_k, kv_head, d, device=device, dtype=dtype)
        v_arg = torch.randn(batch_size_cache, seqlen_k, kv_head, d, device=device, dtype=dtype)
        block_table = None
    else:
        (
            _k_dense,
            _v_dense,
            block_table,
            k_cache_paged,
            v_cache_paged,
            _num_blocks,
        ) = _generate_block_kvcache(
            seqlen_k, paged_kv_block_size, batch_size, kv_head, d, device, dtype
        )
        # Free the dense views early; only the paged pools are benchmarked.
        del _k_dense, _v_dense
        # Reorder the pools from (block, token, head, dim) to the layouts given
        # to the kernel: K as (block, head, token, dim), V as (block, head, dim, token).
        k_arg = k_cache_paged.permute(0, 2, 1, 3).contiguous().to(new_type)
        v_arg = v_cache_paged.permute(0, 2, 3, 1).contiguous().to(new_type)

    # Every sequence holds exactly seqlen_k cached tokens.
    cache_seqlens = torch.full((batch_size,), seqlen_k, dtype=torch.int, device=device)

    q_scale = torch.tensor([0.5], dtype=torch.float32, device=device)
    k_scale = torch.tensor([0.5], dtype=torch.float32, device=device)
    v_scale = torch.tensor([0.25], dtype=torch.float32, device=device)
    max_seqlen_k = seqlen_k

    def run_kernel():
        return _flash_attn_with_kvcache(
            q,
            k_arg,
            v_arg,
            cache_seqlens=cache_seqlens,
            block_table=block_table,
            causal=causal,
            max_seqlen_k=max_seqlen_k,
            q_scale=q_scale,
            k_scale=k_scale,
            v_scale=v_scale,
        )

    # Warm-up launches.
    for _ in range(10):
        out = run_kernel()

    # Timed launches; synchronize on both sides so only kernel time is measured.
    torch.cuda.synchronize()
    repeat_num = 100
    start_time = time.time()
    for _ in range(repeat_num):
        out = run_kernel()
    torch.cuda.synchronize()
    end_time = time.time()

    fc1_espl = end_time - start_time
    DCU_time = fc1_espl *1000*1000 / repeat_num  # microseconds per launch
    IO_bytes = batch_size*seqlen_k*kv_head*d*2*k_arg.element_size()  # kv cache size to read
    IO_bytes += batch_size*qhead*d*q.element_size()  # q size to read
    IO_bytes += (seqlen_k//512+1)*batch_size*qhead*d*2*2  # temp to write and read
    IO_bytes += batch_size*qhead*d*2  # output to write
    IO_speed = IO_bytes/DCU_time/1024/1024/1024*1000*1000
    print('FA_kvcache bs=', batch_size,' seqlen=',seqlen_k,' qhead=',qhead, ' kv_head=',kv_head, ' time is', '{:.2f}'.format(DCU_time), 'us Bandwidth=','{:.2f}'.format(IO_speed),'GB/s')
    res_list = [paged_kv_block_size, batch_size, seqlen_k, d, qhead, kv_head, DCU_time,IO_speed]
    return res_list
def _generate_block_kvcache(seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype):
    """Create a random paged KV cache plus a random block table.

    A pool of 50000 blocks is filled with random K and V values. Every sequence
    gets a row of block ids drawn uniformly (with replacement) from that pool.
    The row length covers seqlen_k tokens when the module flag `eager` is true,
    otherwise max_seqlen tokens.

    Returns:
        (k_cache, v_cache, block_tables, k_cache_paged, v_cache_paged, num_blocks)
        where k_cache / v_cache are the dense per-sequence views gathered through
        the block table and truncated to seqlen_k tokens.
    """
    num_blocks = 50000
    pool_shape = (num_blocks, paged_kv_block_size, nheads_k, d)
    k_cache_paged = torch.randn(*pool_shape, device=device, dtype=dtype)
    v_cache_paged = torch.randn(*pool_shape, device=device, dtype=dtype)

    covered_len = seqlen_k if eager else max_seqlen
    # Ceil-divide so a partially filled last block is still counted.
    max_num_blocks_per_seq = (covered_len + paged_kv_block_size - 1) // paged_kv_block_size

    table_rows = [
        [random.randint(0, num_blocks - 1) for _ in range(max_num_blocks_per_seq)]
        for _ in range(batch_size)
    ]
    block_tables = torch.tensor(table_rows, dtype=torch.int, device=device)

    # Indexing needs int64 ids (pytorch 1.12 doesn't have indexing with int32).
    flat_ids = block_tables.to(dtype=torch.long).flatten()

    def gather_dense(pool):
        # Pull each sequence's blocks and lay them out contiguously along the token axis.
        stacked = rearrange(
            pool[flat_ids],
            "(b nblocks) block_size ... -> b (nblocks block_size) ...",
            b=batch_size,
        )
        return stacked[:, :seqlen_k]

    k_cache = gather_dense(k_cache_paged)
    v_cache = gather_dense(v_cache_paged)
    return k_cache, v_cache, block_tables, k_cache_paged, v_cache_paged, num_blocks
# mha
if __name__ == "__main__":
    # Example launch: HIP_VISIBLE_DEVICES=6 python test_kvcache.py
    random.seed(0)
    torch.random.manual_seed(0)

    # Sweep axes. Other presets used in the past:
    #   batchsize = [16,24,32,40,48,56,64,72,80,88,96]  # 70B, 235B
    #   batchsize = [24,32,40,48,56]                    # 30B
    #   batchsize = [40,48,56,64,72,80,88,96]           # 8B
    #   head = [(32,2)] / [(12,1)] / [(4,1),(8,1),(12,1),(16,1),(24,1)]
    #   seq_lens = [8192,128000] / [600,1200,2400,4800]
    batchsize = [1, 8, 32, 128]
    head = [(16, 2), (32, 8)]  # (query heads, kv heads)
    seq_lens = [2048, 8192, 32768]
    dtype = torch.float16  # torch.bfloat16 is the alternative
    print(dtype)

    # Arguments shared by every point of the sweep.
    fixed_args = dict(
        seqlen_q=1,
        d=128,  # 64 128 160 256
        has_batch_idx=False,
        has_leftpad=False,
        paged_kv_block_size=64,  # 16 256
        rotary_fraction=0.0,
        rotary_interleaved=False,
        seqlen_new_eq_seqlen_q=True,
        causal=True,  # causal attention
        local=False,  # local (sliding-window) attention
        alibi=False,
        new_kv=False,
        dtype=dtype,
        prof=False,  # single run
    )

    # Same nesting order as a triple loop: heads, then batch size, then sequence length.
    res_time = [
        test_flash_attn_kvcache(
            seqlen_k=seq,  # 128 512
            batch_size=bs,
            qhead=qh,
            kv_head=kh,
            **fixed_args,
        )
        for qh, kh in head
        for bs in batchsize
        for seq in seq_lens
    ]

    # One CSV row per benchmarked configuration.
    with open('kvcache_time.csv', 'w', newline='') as csvfile:
        csv.writer(csvfile).writerows(res_time)
...@@ -225,6 +225,7 @@ hg_varlen_bwd_bshd(const at::Tensor &dout, ...@@ -225,6 +225,7 @@ hg_varlen_bwd_bshd(const at::Tensor &dout,
static const bool print_param = get_env_("FLASH_ATTENTION_PRINT_PARAM"); static const bool print_param = get_env_("FLASH_ATTENTION_PRINT_PARAM");
static const bool print_hg_path = get_env_("FLASH_ATTENTION_PRINT_HG"); static const bool print_hg_path = get_env_("FLASH_ATTENTION_PRINT_HG");
static const bool disable_varlen_tiny_dim64 = get_env_("FLASH_ATTENTION_DISABLE_VARLEN_TINY_DIM64"); static const bool disable_varlen_tiny_dim64 = get_env_("FLASH_ATTENTION_DISABLE_VARLEN_TINY_DIM64");
static const bool enable_hg_varlen = get_env_("FLASH_ATTENTION_ENABLE_HG_VARLEN");
#ifdef HAS_HG_DISPATCH #ifdef HAS_HG_DISPATCH
...@@ -741,20 +742,23 @@ void run_mha_fwd_prefix_kv_fp8(Flash_fwd_params &params, cudaStream_t stream, bo ...@@ -741,20 +742,23 @@ void run_mha_fwd_prefix_kv_fp8(Flash_fwd_params &params, cudaStream_t stream, bo
void run_mha_fwd_unified(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) { void run_mha_fwd_unified(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
FP16_SWITCH(!params.is_bf16, [&] { FP16_SWITCH(!params.is_bf16, [&] {
HEADDIM_SWITCH(params.d, [&]
{
// using elem_type = cutlass::half_t; // using elem_type = cutlass::half_t;
// using elem_type = cutlass::float_e5m2_t; // using elem_type = cutlass::float_e5m2_t;
// HEADDIM_SWITCH_FP8(params.d, [&] { // HEADDIM_SWITCH_FP8(params.d, [&] {
constexpr static int kHeadDim = 256;
BOOL_SWITCH(params.is_causal, Is_causal, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] {
if (params.d != 256) { if (params.d != 256 && params.d != 128) {
TORCH_CHECK(false, "unified attn only support dim=256"); TORCH_CHECK(false, "unified attn only support dim=128/256");
} }
run_mha_fwd_unified_dispatch<elem_type, kHeadDim, Is_causal>(params, stream); run_mha_fwd_unified_dispatch<elem_type, kHeadDim, Is_causal>(params, stream);
}); });
// }); // });
}); });
});
} }
void run_mha_fwd_mla(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) { void run_mha_fwd_mla(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
params.num_splits=1; params.num_splits=1;
if(params.is_fp8==true) if(params.is_fp8==true)
...@@ -961,7 +965,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size ...@@ -961,7 +965,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
alibi_slopes_, s_aux_, alibi_slopes_, s_aux_,
skip_softmax_threshold_scale_factor, skip_softmax_threshold_scale_factor,
is_causal, seqlen_q, seqlen_k, is_causal, seqlen_q, seqlen_k,
window_size_left, window_size_right)) { window_size_left, window_size_right)&&(!is_bhsd)) {
if (print_param || print_hg_path) { if (print_param || print_hg_path) {
printf("[flash_attn] HG PATH layout=%s q=(%d,%d,%d,%d) k=(%d,%d,%d,%d) v=(%d,%d,%d,%d)\n", printf("[flash_attn] HG PATH layout=%s q=(%d,%d,%d,%d) k=(%d,%d,%d,%d) v=(%d,%d,%d,%d)\n",
is_bhsd ? "bhsd" : "bshd", is_bhsd ? "bhsd" : "bshd",
...@@ -2019,7 +2023,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si ...@@ -2019,7 +2023,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
if (can_use_hg_dense_bwd( if (can_use_hg_dense_bwd(
q.scalar_type(), alibi_slopes_, q.scalar_type(), alibi_slopes_,
head_size, head_size_value, is_causal, seqlen_q, seqlen_k, head_size, head_size_value, is_causal, seqlen_q, seqlen_k,
window_size_left, window_size_right, p_dropout)) { window_size_left, window_size_right, p_dropout)&&(!is_bhsd)) {
if (print_param || print_hg_path) { if (print_param || print_hg_path) {
printf("[flash_attn] HG BWD PATH layout=%s q=(%d,%d,%d,%d) k=(%d,%d,%d,%d) v=(%d,%d,%d,%d) dout=(%d,%d,%d,%d)\n", printf("[flash_attn] HG BWD PATH layout=%s q=(%d,%d,%d,%d) k=(%d,%d,%d,%d) v=(%d,%d,%d,%d) dout=(%d,%d,%d,%d)\n",
is_bhsd ? "bhsd" : "bshd", is_bhsd ? "bhsd" : "bshd",
...@@ -2308,7 +2312,8 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size ...@@ -2308,7 +2312,8 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
CHECK_SHAPE(cu_seqlens_k, batch_size + 1); CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
#ifdef HAS_HG_DISPATCH #ifdef HAS_HG_DISPATCH
if (can_use_hg_varlen_bwd( if (enable_hg_varlen
&& can_use_hg_varlen_bwd(
q.scalar_type(), alibi_slopes_, q.scalar_type(), alibi_slopes_,
head_size, head_size_value, total_q, total_k, max_seqlen_k, head_size, head_size_value, total_q, total_k, max_seqlen_k,
window_size_left, window_size_right, p_dropout)) { window_size_left, window_size_right, p_dropout)) {
...@@ -4459,7 +4464,7 @@ TORCH_LIBRARY_IMPL(flash_attn2_c_op, CUDA, m) { ...@@ -4459,7 +4464,7 @@ TORCH_LIBRARY_IMPL(flash_attn2_c_op, CUDA, m) {
return std::make_tuple(results[0], results[1]); return std::make_tuple(results[0], results[1]);
}); });
} }
at::Tensor mean_pool_fast(const at::Tensor &input,int blk,const c10::optional<at::Tensor> &mean);
// ============================================================================ // ============================================================================
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
...@@ -4484,6 +4489,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -4484,6 +4489,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("varlen_bwd_attnmask", &mha_varlen_bwd_attnmask, "Backward pass (variable length), with explicit attention mask"); m.def("varlen_bwd_attnmask", &mha_varlen_bwd_attnmask, "Backward pass (variable length), with explicit attention mask");
m.def("paged_attention", &paged_attention, "Forward pass, with KV-cache"); m.def("paged_attention", &paged_attention, "Forward pass, with KV-cache");
m.def("fwd_sparse", &mha_fwd_sparse, "Forward sparse pass"); m.def("fwd_sparse", &mha_fwd_sparse, "Forward sparse pass");
m.def("fwd_sparse_mean_pool_fast", &mean_pool_fast, "before mha_fwd_sparse");
m.def("varlen_fwd_sparse", &mha_varlen_fwd_sparse, "Forward pass sparse (variable length)"); m.def("varlen_fwd_sparse", &mha_varlen_fwd_sparse, "Forward pass sparse (variable length)");
m.def("varlen_fwd_unified", &unified2D_attention_fwd, "Forward pass unified attn (variable length && block table)"); m.def("varlen_fwd_unified", &unified2D_attention_fwd, "Forward pass unified attn (variable length && block table)");
} }
...@@ -311,7 +311,7 @@ struct Dropout { ...@@ -311,7 +311,7 @@ struct Dropout {
for (int i = 0; i < size<1>(tensor); ++i) for (int i = 0; i < size<1>(tensor); ++i)
{ {
const int row_idx_base = block_row_start + i * block_row_stride + (threadIdx.x / 64) * 16; const int row_idx_base = block_row_start + i * block_row_stride + (threadIdx.x / 64) * 16 + lane_id % 16;
const int row_idx = row_idx_base; const int row_idx = row_idx_base;
uint2 rowcol = make_uint2(row_idx, col_idx_offset); uint2 rowcol = make_uint2(row_idx, col_idx_offset);
uint4 random_uint4 = flash::philox(seed, reinterpret_cast<unsigned long long&>(rowcol), offset); uint4 random_uint4 = flash::philox(seed, reinterpret_cast<unsigned long long&>(rowcol), offset);
...@@ -344,7 +344,7 @@ struct Dropout { ...@@ -344,7 +344,7 @@ struct Dropout {
} }
}; };
const int lane_id = threadIdx.x % 64; const int lane_id = threadIdx.x % 64;
const int col_idx_offset = block_col_start + (threadIdx.x / 64) * 16; const int col_idx_offset = block_col_start + (threadIdx.x / 64) * 16 + lane_id % 16;
extern __shared__ char smem_[]; extern __shared__ char smem_[];
uint8_t *p_rand_8 = reinterpret_cast<uint8_t *>(smem_ + 16384); uint8_t *p_rand_8 = reinterpret_cast<uint8_t *>(smem_ + 16384);
...@@ -369,8 +369,10 @@ struct Dropout { ...@@ -369,8 +369,10 @@ struct Dropout {
uint8_t (&rnd_8)[16] = reinterpret_cast<uint8_t (&)[16]>(random_uint4); uint8_t (&rnd_8)[16] = reinterpret_cast<uint8_t (&)[16]>(random_uint4);
*reinterpret_cast<uint4*>(&p_rand_8[row_ * RAND_STRIDE + col_]) = random_uint4; *reinterpret_cast<uint4*>(&p_rand_8[row_ * RAND_STRIDE + col_]) = random_uint4;
__syncthreads(); // __syncthreads();
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier \n\t");
__builtin_amdgcn_sched_barrier(0);
#pragma unroll #pragma unroll
for (int j = 0; j < size<2>(tensor); ++j) { for (int j = 0; j < size<2>(tensor); ++j) {
#pragma unroll #pragma unroll
......
This diff is collapsed.
...@@ -384,7 +384,7 @@ void run_flash_bwd_separate_prefetch(Flash_bwd_params &params, cudaStream_t stre ...@@ -384,7 +384,7 @@ void run_flash_bwd_separate_prefetch(Flash_bwd_params &params, cudaStream_t stre
const bool is_even_K = params.d == Kernel_traits::kHeadDim; const bool is_even_K = params.d == Kernel_traits::kHeadDim;
constexpr int smem_size_dropout = Kernel_trans_traits::kBlockM * Kernel_trans_traits::kBlockN; constexpr int smem_size_dropout = Kernel_trans_traits::kBlockM * Kernel_trans_traits::kBlockN;
constexpr int smem_size_dk_dv = Kernel_trans_traits::kSmemPrefetchSize; constexpr int smem_size_dk_dv = Kernel_trans_traits::kSmemPrefetchSize;
constexpr int smem_size_dk_dv_total = (Kernel_trans_traits::kHeadDim == 128) ? (smem_size_dk_dv + smem_size_dropout) : (smem_size_dk_dv); constexpr int smem_size_dk_dv_total = (Kernel_trans_traits::kHeadDim == 128 || Kernel_trans_traits::kHeadDim == 64) ? (smem_size_dk_dv + smem_size_dropout) : (smem_size_dk_dv);
constexpr int smem_size_dq = Kernel_traits::kSmemPrefetchSize; constexpr int smem_size_dq = Kernel_traits::kSmemPrefetchSize;
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
// constexpr static bool IsEvenMNConst = false; // constexpr static bool IsEvenMNConst = false;
...@@ -561,7 +561,7 @@ void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) { ...@@ -561,7 +561,7 @@ void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if (get_device_name() == "gfx936" || get_device_name() == "gfx938") if (get_device_name() == "gfx936" || get_device_name() == "gfx938")
{ {
using kernel_trans_traits = Flash_bwd_kernel_trans_16x64_prefetch_traits<Headdim, /*kBlockM_*/64, /*kBlockN_*/128, /*kNWarps_*/4, T, 3>; using kernel_trans_traits = Flash_bwd_kernel_trans_16x64_prefetch_traits<Headdim, /*kBlockM_*/64, /*kBlockN_*/Is_dropout ? 64 : 128, /*kNWarps_*/4, T, 3>;
using kernel_traits = Flash_bwd_kernel_dq_16x64_prefetch_traits<Headdim, /*kBlockM_*/128, /*kBlockN_*/64, /*kNWarps_*/4, using kernel_traits = Flash_bwd_kernel_dq_16x64_prefetch_traits<Headdim, /*kBlockM_*/128, /*kBlockN_*/64, /*kNWarps_*/4,
/*AtomLayoutMSdP_*/4, /*AtomLayoutNdKV*/1, /*AtomLayoutMdQ*/4, /*Is_V_in_regs_*/false, /*AtomLayoutMSdP_*/4, /*AtomLayoutNdKV*/1, /*AtomLayoutMdQ*/4, /*Is_V_in_regs_*/false,
/*No_double_buffer_*/true, /*Is_Q_in_regs_*/false, /*Share_Q_K_smem_*/true, T, 3>; /*No_double_buffer_*/true, /*Is_Q_in_regs_*/false, /*Share_Q_K_smem_*/true, T, 3>;
......
This diff is collapsed.
...@@ -770,8 +770,15 @@ void run_mha_fwd_unified_dispatch(Flash_fwd_params &params, cudaStream_t stream) ...@@ -770,8 +770,15 @@ void run_mha_fwd_unified_dispatch(Flash_fwd_params &params, cudaStream_t stream)
using combine_kernel_traits = Flash_fwd_kernel_16x64_traits_splitkv<256, kBlockM, kBlockN, 4, false, false, T, 256>; using combine_kernel_traits = Flash_fwd_kernel_16x64_traits_splitkv<256, kBlockM, kBlockN, 4, false, false, T, 256>;
run_flash_splitkv_fwd_16x64_unified_prefetch<prefetch_kernel_traits, combine_kernel_traits, Is_causal>(params, stream); run_flash_splitkv_fwd_16x64_unified_prefetch<prefetch_kernel_traits, combine_kernel_traits, Is_causal>(params, stream);
} }
}else{ } else if constexpr (Headdim == 128) {
assert(false && "unified attn only supported headdim=256"); if (get_device_name() == "gfx936"||get_device_name() == "gfx938") {
assert(params.knew_ptr == nullptr && params.block_table != nullptr);
using prefetch_kernel_traits = Flash_fwd_kernel_16x64_splitkv_prefetch_unified_traits<128, 64, 64, 4, T, 3, 128>;
using combine_kernel_traits = Flash_fwd_kernel_16x64_traits_splitkv<128, kBlockM, kBlockN, 4, false, false, T, 128>;
run_flash_splitkv_fwd_16x64_unified_prefetch<prefetch_kernel_traits, combine_kernel_traits, Is_causal>(params, stream);
}
} else {
assert(false && "unified attn only supported headdim=128/256");
} }
} }
...@@ -797,7 +804,7 @@ void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) { ...@@ -797,7 +804,7 @@ void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
if (params.seqlen_q <= 64||params.h * params.b * mblocks< 4*sm_count) { if (params.seqlen_q <= 64||params.h * params.b * mblocks< 4*sm_count) {
run_flash_fwd_16x64_prefetch<Flash_fwd_kernel_16x64_prefetch_traits_dim64<Headdim, 64, 64, 4, T>, Is_dropout, Is_causal>(params, stream); run_flash_fwd_16x64_prefetch<Flash_fwd_kernel_16x64_prefetch_traits_dim64<Headdim, 64, 64, 4, T>, Is_dropout, Is_causal>(params, stream);
} else { } else {
run_flash_fwd_16x64_prefetch<Flash_fwd_kernel_16x64_prefetch_traits_dim64<Headdim, 256, 64, 4, T>, Is_dropout, Is_causal>(params, stream); run_flash_fwd_16x64_prefetch<Flash_fwd_kernel_16x64_prefetch_traits_dim64<Headdim, Is_dropout ? 128 : 256, 64, 4, T>, Is_dropout, Is_causal>(params, stream);
} }
} else { } else {
run_flash_fwd_16x64<Flash_fwd_kernel_16x64_traits<Headdim, 256, 64, 4, /*Is_Q_use_smem_=*/false, /*Share_K_V_smem_=*/false, T>, Is_dropout, Is_causal>(params, stream); run_flash_fwd_16x64<Flash_fwd_kernel_16x64_traits<Headdim, 256, 64, 4, /*Is_Q_use_smem_=*/false, /*Share_K_V_smem_=*/false, T>, Is_dropout, Is_causal>(params, stream);
......
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
// Explicit instantiation: bf16 elements, head dim 128, Is_causal = true.
template void run_mha_fwd_unified_dispatch<cutlass::bfloat16_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream);
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
// Explicit instantiation: bf16 elements, head dim 128, Is_causal = false.
template void run_mha_fwd_unified_dispatch<cutlass::bfloat16_t, 128, false>(Flash_fwd_params &params, cudaStream_t stream);
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
// Explicit instantiation: fp16 elements, head dim 128, Is_causal = true.
template void run_mha_fwd_unified_dispatch<cutlass::half_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream);
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
// Explicit instantiation: fp16 elements, head dim 128, Is_causal = false.
template void run_mha_fwd_unified_dispatch<cutlass::half_t, 128, false>(Flash_fwd_params &params, cudaStream_t stream);
#include <torch/python.h>
#include <torch/nn/functional.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
// Store the fp32 value `f` into `out`.
//  - scalar_t == _Float16 or float: plain conversion by assignment.
//  - any other scalar_t: `out` receives the upper 16 bits of the fp32 pattern
//    (bf16 bit layout) after round-to-nearest-even.
// NaN inputs are handled separately: the rounding add would otherwise carry
// into the exponent/sign bits and turn some NaNs into +/-Inf or -0.0
// (e.g. 0x7F800001 -> 0x7F80, 0x7FFFFFFF -> 0x8000).
template<typename scalar_t>
static __device__ inline void from_float(scalar_t &out ,float f){
    if constexpr(std::is_same<scalar_t, _Float16>::value||std::is_same<scalar_t, float>::value){
        out=f;
    }
    else{
        // Same union-based bit reinterpretation that to_float uses.
        union{
            float fp32;
            uint32_t int32;
        } bits = {f};
        uint32_t u = bits.int32;
        if((u & 0x7fffffffu) > 0x7f800000u){
            // NaN: keep sign and top payload bits, force the quiet bit so the
            // truncated mantissa can never become zero.
            out = (u >> 16) | 0x0040u;
            return;
        }
        // Round to nearest, ties to even: add 0x7fff plus the LSB of the kept half.
        u += 0x7fffu + ((u >> 16) & 1u);
        out = u>>16;
    }
}
// Widen a stored element to fp32.
//  - scalar_t == _Float16 or float: implicit conversion.
//  - any other scalar_t: the value is taken as the upper 16 bits of an fp32
//    pattern (bf16 bit layout) and the lower 16 bits are zero-filled.
template<typename scalar_t>
static __device__ inline float to_float(scalar_t in){
    constexpr bool converts_directly =
        std::is_same<scalar_t, _Float16>::value || std::is_same<scalar_t, float>::value;
    if constexpr(converts_directly){
        return in;
    }
    else{
        union{
            uint32_t int32;
            float fp32;
        } widened = {uint32_t(in) << 16};
        return widened.fp32;
    }
}
// Runtime-to-compile-time dispatch on the tensor dtype.
// Invokes the callable passed as the trailing argument with a local type alias
// `scalar_t` in scope: _Float16 when SRC_DTYPE is at::ScalarType::Half,
// uint16_t for every other dtype (no further dtype check is made here;
// to_float/from_float treat uint16_t as the upper half of an fp32 pattern).
// The macro expands to an immediately-invoked lambda, so it is an expression.
#define Input_Type_SWITCH(SRC_DTYPE, ...) \
[&] { \
if (SRC_DTYPE == at::ScalarType::Half) { \
using scalar_t=_Float16; \
return __VA_ARGS__(); \
}else { \
using scalar_t=uint16_t; \
return __VA_ARGS__(); \
} \
}()
// Runtime-to-compile-time dispatch on the pooling block length.
// Invokes the callable with a constexpr `BLK` in scope: 64 when blk == 64,
// and 128 for EVERY other value of blk (values other than 64/128 are not
// rejected here, so callers must validate blk themselves).
#define BLK_SWITCH(blk,...) \
[&] { \
if (blk==64){ \
constexpr static int BLK = 64; \
return __VA_ARGS__(); \
}else { \
constexpr static int BLK = 128; \
return __VA_ARGS__(); \
} \
}()
// Turns the runtime condition COND into a compile-time boolean.
// Invokes the callable with a constexpr bool named CONST_NAME in scope,
// true in the first branch and false in the second, so the callable can use
// it as a template argument.
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
[&] { \
if (COND) { \
constexpr static bool CONST_NAME = true; \
return __VA_ARGS__(); \
} else { \
constexpr static bool CONST_NAME = false; \
return __VA_ARGS__(); \
} \
}()
// Block-wise mean pooling along the sequence axis.
//
// Indexing used below implies these flat layouts:
//   input : [b][s][h][DIM]
//   out   : [b][h][L_BLOCKS][DIM]
//   mean  : [b][h][DIM]           (only read when has_mean is true)
// Grid mapping: blockIdx.x = sequence block, blockIdx.y = head, blockIdx.z = batch.
//
// For each (batch, head, block) the kernel writes the mean over the block's
// rows of `input` (minus `mean` when has_mean). Full blocks are divided by
// BLK; the trailing partial block is divided by s % BLK.
//
// All element offsets are computed in 64-bit: the previous 32-bit arithmetic
// wrapped around once b*s*h*DIM exceeded 2^32 elements.
//
// `b` is unused but kept so the launch signature stays unchanged.
template<typename scalar_t,int blocksize,int DIM,int BLK,bool has_mean>
__global__ void mean_pool_fast_kernel(scalar_t *out, const scalar_t *input,int L_BLOCKS,int b,int s,int h ,const scalar_t* mean){
    constexpr int kVec = 8;                                  // elements per vector load
    constexpr int kThreadsPerRow = DIM / kVec;               // threads covering one row
    constexpr int kRowsPerPass = blocksize / kThreadsPerRow; // rows loaded per loop step
    static_assert(blocksize == DIM, "one thread per output column is required");
    static_assert(DIM % kVec == 0, "DIM must be a multiple of the vector width");
    static_assert(BLK % kRowsPerPass == 0, "BLK must be a multiple of the rows loaded per pass");

    const int tid = threadIdx.x;
    const int64_t batch = blockIdx.z;
    const int64_t head = blockIdx.y;
    const int64_t blk_idx = blockIdx.x;
    const int64_t row_stride = int64_t(h) * DIM;             // elements between consecutive sequence rows

    scalar_t* out_cur = out + ((batch * h + head) * L_BLOCKS + blk_idx) * DIM;
    const int64_t in_base = (batch * s + blk_idx * BLK) * row_stride + head * DIM;
    const int64_t mean_base = (batch * h + head) * DIM;

    if(blockIdx.x<L_BLOCKS-1||s==L_BLOCKS*BLK){
        // Full block: vectorised loads, then a shared-memory reduction over rows.
        const int row = tid / kThreadsPerRow;
        const int col = (tid % kThreadsPerRow) * kVec;
        const scalar_t* input_cur = input + in_base + row * row_stride + col;

        using half_vec= __attribute__( (__vector_size__(kVec * sizeof(scalar_t)) )) scalar_t;
        using float_vec= __attribute__( (__vector_size__(kVec * sizeof(float)) )) float;
        __shared__ float lds_ptr[blocksize*kVec];
        {
            float_vec sum={0.0f,0.0f,0.0f,0.0f,0.0f,0.0f,0.0f,0.0f};
            half_vec mean_temp;
            if constexpr(has_mean){
                mean_temp = *reinterpret_cast<const half_vec*>(mean + mean_base + col);
            }
            // Each pass covers kRowsPerPass consecutive rows of the block.
            for(int r=0;r<BLK;r+=kRowsPerPass){
                half_vec temp = *reinterpret_cast<const half_vec*>(input_cur + r*row_stride);
                for(int ii=0;ii<kVec;ii++){
                    if constexpr(has_mean){
                        sum[ii] += to_float(temp[ii]) - to_float(mean_temp[ii]);
                    }
                    else{
                        sum[ii] += to_float(temp[ii]);
                    }
                }
            }
            // Thread (row, col) stores at row*DIM + col, i.e. a [kRowsPerPass][DIM] tile.
            *reinterpret_cast<float_vec*>(lds_ptr+tid*kVec)=sum;
            __syncthreads();
        }
        // Thread `tid` now owns output column `tid` and folds the tile rows.
        float sum=0.0f;
        for(int i=0;i<kRowsPerPass;i++){
            sum+=lds_ptr[tid+DIM*i];
        }
        sum/=BLK;
        from_float(out_cur[tid],sum);
    }
    else{
        // Trailing partial block: one thread per column, scalar loads.
        // This branch is only taken when s is not a multiple of BLK, so s_lenth > 0.
        const int s_lenth = s % BLK;
        const scalar_t* input_cur = input + in_base + tid;
        float sum=0.0f;
        float mean_temp=0.0f;
        if constexpr(has_mean){
            mean_temp = to_float(mean[mean_base + tid]);
        }
        for(int i=0;i<s_lenth;i++){
            scalar_t temp = input_cur[i*row_stride];
            if constexpr(has_mean){
                sum+=(to_float(temp)-mean_temp);
            }
            else{
                sum+=to_float(temp);
            }
        }
        sum /= s_lenth;
        from_float(out_cur[tid],sum);
    }
}
// Host entry point for block-wise mean pooling along the sequence axis.
//
// Arguments:
//   input : 4-D tensor laid out as (b, s, h, d); d must be 128 (the only head
//           dimension the kernel is instantiated for). Non-contiguous inputs are
//           made contiguous before launch because the kernel indexes densely.
//   blk   : number of sequence rows pooled per output row; must be 64 or 128.
//   mean  : optional tensor with b*h*d elements and the same dtype/device as
//           `input`; when present it is subtracted from every row before pooling.
// Returns:
//   Tensor of shape (b, h, ceil(s / blk), d) with the same options as `input`.
// Raises (via TORCH_CHECK) on unsupported shapes, block sizes or mismatched `mean`.
at::Tensor mean_pool_fast(const at::Tensor &input,int blk,const c10::optional<at::Tensor> &mean){
    constexpr int kHeadDim = 128;  // must match the DIM/blocksize used in the launch below
    TORCH_CHECK(input.is_cuda(), "mean_pool_fast: input must be a CUDA tensor");
    TORCH_CHECK(input.dim() == 4, "mean_pool_fast: input must be 4-D (b, s, h, d), got ", input.dim(), "-D");
    // BLK_SWITCH maps every value other than 64 to 128, which would disagree with L_BLOCKS.
    TORCH_CHECK(blk == 64 || blk == 128, "mean_pool_fast: blk must be 64 or 128, got ", blk);
    TORCH_CHECK(input.size(3) == kHeadDim, "mean_pool_fast: head dim must be ", kHeadDim, ", got ", input.size(3));

    const at::Tensor input_c = input.contiguous();
    int b=input_c.size(0);
    int s=input_c.size(1);
    int h=input_c.size(2);
    int d=input_c.size(3);
    int L_BLOCKS = (s + blk - 1) / blk;
    auto out = torch::empty({b, h, L_BLOCKS,d}, input.options());
    // A zero-sized grid is an invalid launch configuration; nothing to compute anyway.
    if(out.numel() == 0){
        return out;
    }
    TORCH_CHECK(h <= 65535 && b <= 65535, "mean_pool_fast: h and b must each fit in a grid dimension (<= 65535)");

    at::Tensor mean_c;
    if(mean.has_value()){
        const at::Tensor &m = mean.value();
        TORCH_CHECK(m.scalar_type() == input.scalar_type(), "mean_pool_fast: mean must have the same dtype as input");
        TORCH_CHECK(m.device() == input.device(), "mean_pool_fast: mean must be on the same device as input");
        TORCH_CHECK(m.numel() == static_cast<int64_t>(b) * h * d, "mean_pool_fast: mean must have b*h*d elements");
        mean_c = m.contiguous();
    }

    auto stream = at::cuda::getCurrentCUDAStream();
    dim3 grid(L_BLOCKS,h,b);
    Input_Type_SWITCH(input.scalar_type(),[&]{
        BLK_SWITCH(blk,[&]{
            const scalar_t *mean_ptr = mean_c.defined()?reinterpret_cast<const scalar_t*>(mean_c.data_ptr()):nullptr;
            BOOL_SWITCH(mean_ptr!=nullptr,has_mean,[&]{
                const scalar_t *input_ptr = reinterpret_cast<const scalar_t*>(input_c.data_ptr());
                scalar_t *out_ptr = reinterpret_cast<scalar_t*>(out.data_ptr());
                mean_pool_fast_kernel<scalar_t,128,128,BLK,has_mean><<<grid,128,0,stream>>>(out_ptr,input_ptr,L_BLOCKS,b,s,h,mean_ptr);
            });
        });
    });
    return out;
}
\ No newline at end of file
...@@ -1711,6 +1711,159 @@ struct Flash_fwd_kernel_16x64_splitkv_prefetch_mla_traits : public Base { ...@@ -1711,6 +1711,159 @@ struct Flash_fwd_kernel_16x64_splitkv_prefetch_mla_traits : public Base {
Layout<Shape<Int<kGmemRowsPerThread>, _8>, Stride<_8, _1>>{})); // Val layout, 8 vals per load Layout<Shape<Int<kGmemRowsPerThread>, _8>, Stride<_8, _1>>{})); // Val layout, 8 vals per load
}; };
// Compile-time traits for the "16x64 split-KV prefetch unified" forward kernel
// (purpose inferred from the struct name). The struct only declares types and
// constexpr constants: tile sizes, MMA atoms / tiled MMAs, shared-memory
// layouts and tiled global-memory copies. It inherits Element, ElementAccum,
// index_t, the shared-memory copy atoms and ValLayoutMNK from `Base`.
//
// Template parameters:
//   kHeadDim_  : head dimension of Q/K (must be a multiple of 32).
//   kBlockM_   : tile size along the query axis.
//   kBlockN_   : tile size along the key/value axis (must be a multiple of 64).
//   kNWarps_   : warps per thread block.
//   elem_type  : element type; selects the F16 or BF16 MMA atoms below.
//   kStages_   : exposed as kStages (not otherwise used inside this struct).
//   kHeadDimV_ : head dimension of V/O (must be a multiple of 32).
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type=cutlass::half_t,
int kStages_=1, int kHeadDimV_ = kHeadDim_, typename Base=Flash_kernel_traits<kHeadDim_, kBlockM_, kBlockN_, kNWarps_, elem_type> >
struct Flash_fwd_kernel_16x64_splitkv_prefetch_unified_traits : public Base {
using Element = typename Base::Element;
using ElementAccum = typename Base::ElementAccum;
using index_t = typename Base::index_t;
static constexpr bool Has_cp_async = Base::Has_cp_async;
using SmemCopyAtom = typename Base::SmemCopyAtom;
using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;
static constexpr bool Share_Q_K_smem = true;
// The number of threads.
static constexpr int kNWarps = kNWarps_;
// 64 threads per warp are assumed here.
static constexpr int kNThreads = kNWarps * 64;
static constexpr int kBlockM = kBlockM_;
static constexpr int kBlockN = kBlockN_;
static constexpr int kHeadDim = kHeadDim_;
static constexpr int kHeadDimV = kHeadDimV_;
static_assert(kBlockN % 64 == 0);
static_assert(kHeadDim % 32 == 0);
static_assert(kHeadDimV % 32 == 0);
static constexpr int kStages = kStages_;
// Width (in elements) of the swizzled shared-memory atom: 64 when kHeadDim allows it, else 32.
static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
// Largest of 128 / 64 / 32 that divides kHeadDim.
static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
// Swizzle bits paired with kBlockKSmem (2 for a 32-wide atom, 3 for a 64-wide atom).
static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
// MMA atoms: the F16 variant is chosen for cutlass::half_t, otherwise the BF16 variant.
using MMA_Atom_Arch_16x64 = std::conditional_t<
std::is_same_v<elem_type, cutlass::half_t>,
MMA_Atom<GFX928_16x64x32_F32F16F16F32_NT>,
MMA_Atom<GFX928_16x64x32_F32BF16BF16F32_NT>
>;
using MMA_Atom_Arch_16x64_BLayout = std::conditional_t<
std::is_same_v<elem_type, cutlass::half_t>,
MMA_Atom<GFX928_16x64x32_F32F16F16F32_NT_BLayout>,
MMA_Atom<GFX928_16x64x32_F32BF16BF16F32_NT_BLayout>
>;
using MMA_Atom_Arch_16x32 = std::conditional_t<
std::is_same_v<elem_type, cutlass::half_t>,
MMA_Atom<GFX928_16x32x16_F32F16F16F32_NT>,
MMA_Atom<GFX928_16x32x16_F32BF16BF16F32_NT>
>;
// Tiled MMAs: each atom is replicated kNWarps times along M.
using TiledMma = TiledMMA<
typename Base::MMA_Atom_Arch,
Layout<Shape<Int<kNWarps>,_1,_1>>, // 4x1x1 or 8x1x1 thread group
typename Base::ValLayoutMNK>;
using TiledMma16x64 = TiledMMA<
MMA_Atom_Arch_16x64,
Layout<Shape<Int<kNWarps>,_1,_1>>, // 4x1x1 or 8x1x1 thread group
typename Base::ValLayoutMNK>;
using TiledMma16x64BLayout = TiledMMA<
MMA_Atom_Arch_16x64_BLayout,
Layout<Shape<Int<kNWarps>,_1,_1>>,
typename Base::ValLayoutMNK>;
using TiledMma16x32 = TiledMMA<
MMA_Atom_Arch_16x32,
Layout<Shape<Int<kNWarps>,_1,_1>>, // 4x1x1 or 8x1x1 thread group
typename Base::ValLayoutMNK>;
// Q tile in shared memory: swizzled 8 x kBlockKSmem atom tiled to kBlockM x kHeadDim.
using SmemLayoutAtomQ = decltype(
composition(Swizzle<kSwizzle, 3, 3>{},
Layout<Shape<_8, Int<kBlockKSmem>>,
Stride<Int<kBlockKSmem>, _1>>{}));
using SmemLayoutQ = decltype(tile_to_shape(
SmemLayoutAtomQ{},
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
// NOTE(review): the K/V layouts below hard-code a 64-row block and a 128-wide row
// instead of deriving them from kBlockN / kHeadDim -- presumably this traits struct
// is only instantiated with kHeadDim == 128; confirm against the launch code.
static constexpr uint32_t LayoutBlock = 64;
static constexpr uint32_t LayoutDim = 128;
// Un-swizzled row-major kBlockN x 128 K/V tile.
using SmemLayoutAtomK = Layout<Shape<Int<kBlockN>, Int<128>>, Stride<Int<128>, _1>>;
using SmemLayoutKV = decltype(tile_to_shape(SmemLayoutAtomK{},Shape<Int<kBlockN>, Int<128>>{}));
// Same number of elements as SmemLayoutKV, viewed as (2 * kBlockN) x 64 row-major.
using SmemLayoutK = Layout<Shape<Int<kBlockN*(128/64)>, Int<64>>, Stride<Int<64>, _1>>;
// O tile in shared memory: same swizzled atom as Q, tiled to kBlockM x kHeadDimV.
using SmemLayoutAtomO = decltype(composition(Swizzle<kSwizzle, 3, 3>{}, Layout<Shape<Int<8>, Int<kBlockKSmem>>,Stride<Int<kBlockKSmem>, _1>>{}));
using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, Shape<Int<kBlockM>, Int<kHeadDimV>>{}));
using SmemCopyAtomO = Copy_Atom<DefaultCopy, Element>;
using SmemCopyAtomOaccum = Copy_Atom<DefaultCopy, ElementAccum>;
// V layouts: a 16 x 32 row-major atom tiled to LayoutBlock x LayoutDim, plus
// transposed and "split" (16 x 4*LayoutDim) views of the same atom.
using SmemLayoutAtomV = Layout<Shape<Int<16>, Int<32>>, Stride<Int<32>, _1>>;
using SmemLayoutV = decltype(tile_to_shape(SmemLayoutAtomV{}, Shape<Int<LayoutBlock>, Int<LayoutDim>>{}));
using SmemLayoutVtransposed = decltype(composition(SmemLayoutV{}, make_layout(Shape<Int<LayoutDim>, Int<LayoutBlock>>{}, GenRowMajor{})));
using SmemLayoutVsplit = decltype(tile_to_shape(SmemLayoutAtomV{}, Shape<Int<16>, Int<4*LayoutDim>>{}));
using SmemLayoutVtransSplit = decltype(composition(SmemLayoutVsplit{}, make_layout(Shape<Int<4*LayoutDim>, Int<16>>{}, GenRowMajor{})));
// Shared-memory footprints in bytes. kSmemKVSize covers two KV-sized tiles,
// kSmemKSize a single one.
static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element);
static constexpr int kSmemKSize = size(SmemLayoutKV{}) * sizeof(Element);
static constexpr int kSmemOSize = size(SmemLayoutO{}) * sizeof(Element);
// static constexpr int kSmemOSize = size(SmemLayoutO{}) * sizeof(Element);
// static constexpr int kSmemSize = std::max(kSmemKSize, kSmemOSize);
// Only the single K tile is counted; the max(K, O) alternative above is disabled.
// NOTE(review): this is only safe if the O tile never needs more shared memory
// than one K tile for the instantiated shapes -- confirm.
static constexpr int kSmemSize = kSmemKSize;
// Number of elements moved by one 128-bit global-memory access.
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
// Threads cooperating on one row: 16 for 512-thread blocks, otherwise one
// smem atom row (kBlockKSmem elements) split into 128-bit accesses.
static constexpr int kGmemThreadsPerRow = kNThreads == 512 ? 16 : kBlockKSmem / kGmemElemsPerLoad;
static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
// The active branch derives the thread layout from kNThreads; the #else branch
// is a disabled, fixed 64 x 4 alternative.
#if 1
using GmemLayoutAtom = Layout<Shape <Int<kNThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
Stride<Int<kGmemThreadsPerRow>, _1>>;
#else
using GmemLayoutAtom = Layout<Shape <_64, _4>,
Stride< _4, _1>>;
#endif
// Use the cp.async-style copy only when the base traits report support for it.
using Gmem_copy_struct = std::conditional_t<
Has_cp_async,
SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
DefaultCopy
>;
using GmemTiledCopyQKV = decltype(
make_tiled_copy(Copy_Atom<Gmem_copy_struct, Element>{},
GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
using GmemTiledCopyO = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
// Thread layout for storing the accumulator-typed output, chosen by kBlockKSmem.
using GmemLayoutAtomOaccum = std::conditional_t<
kBlockKSmem == 32,
Layout<Shape <_32, _8>, // Thread layout, 8 threads per row
Stride< _8, _1>>,
Layout<Shape <_16, _16>, // Thread layout, 16 threads per row
Stride< _16, _1>>
>;
using GmemTiledCopyOaccum = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
GmemLayoutAtomOaccum{},
Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store
// Rows of a kBlockN-row tile handled by each thread in the "Paged" copies below.
static constexpr int kGmemRowsPerThread = kBlockN / (kNThreads / kGmemThreadsPerRow);
using GmemTiledCopyQKVPaged = decltype(
make_tiled_copy(Copy_Atom<UniversalCopy<uint128_t>, Element>{},
GmemLayoutAtom{},
Layout<Shape<Int<kGmemRowsPerThread>, _8>, Stride<_8, _1>>{}));
// Tiled copies for the rotary cos/sin tables (naming suggests rotary embedding):
// 64-bit (4 values) and 128-bit / default (8 values) variants, each with a
// multi-row "Paged" counterpart.
using GmemLayoutAtomRotcossin = GmemLayoutAtom;
using GmemTiledCopyRotcossin = decltype(
make_tiled_copy(Copy_Atom<UniversalCopy<uint64_t>, Element>{},
GmemLayoutAtomRotcossin{},
Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per load
using GmemTiledCopyRotcossinCont = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
GmemLayoutAtomRotcossin{},
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per load
using GmemTiledCopyRotcossinPaged = decltype(
make_tiled_copy(Copy_Atom<UniversalCopy<uint64_t>, Element>{},
GmemLayoutAtomRotcossin{},
Layout<Shape<Int<kGmemRowsPerThread>, _4>, Stride<_4, _1>>{})); // Val layout, 4 vals per load
using GmemTiledCopyRotcossinContPaged = decltype(
make_tiled_copy(Copy_Atom<UniversalCopy<uint128_t>, Element>{},
GmemLayoutAtomRotcossin{},
Layout<Shape<Int<kGmemRowsPerThread>, _8>, Stride<_8, _1>>{})); // Val layout, 8 vals per load
};
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type=cutlass::half_t, template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type=cutlass::half_t,
int kStages_=1, int kHeadDimV_ = kHeadDim_, typename Base=Flash_kernel_traits<kHeadDim_, kBlockM_, kBlockN_, kNWarps_, elem_type> > int kStages_=1, int kHeadDimV_ = kHeadDim_, typename Base=Flash_kernel_traits<kHeadDim_, kBlockM_, kBlockN_, kNWarps_, elem_type> >
......
...@@ -315,8 +315,8 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel( ...@@ -315,8 +315,8 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
v_scale=*v_scale_ptr; v_scale=*v_scale_ptr;
} }
const int num_queries_per_kv = num_heads / num_kv_heads; const int num_queries_per_kv = num_heads / num_kv_heads;
const int head_idx=blockIdx.x*num_queries_per_kv;
const int kv_head_idx = blockIdx.x; const int kv_head_idx = blockIdx.x;
const int head_idx=num_queries_per_kv/mtp * kv_head_idx;
constexpr int reuse_group=(REUSE_KV_TIMES-1)/4+1; constexpr int reuse_group=(REUSE_KV_TIMES-1)/4+1;
constexpr int Mloop=(REUSE_KV_TIMES-1)/16+1; constexpr int Mloop=(REUSE_KV_TIMES-1)/16+1;
extern __shared__ char shared_mem[]; extern __shared__ char shared_mem[];
...@@ -353,13 +353,20 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel( ...@@ -353,13 +353,20 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
q_zero.data[0]={0,0,0,0}; q_zero.data[0]={0,0,0,0};
q_zero.data[1]={0,0,0,0}; q_zero.data[1]={0,0,0,0};
scalar_t* s_q = reinterpret_cast<scalar_t*>(shared_mem); scalar_t* s_q = reinterpret_cast<scalar_t*>(shared_mem);
{
int head_offset = HEAD_SIZE*num_queries_per_kv/mtp;
for(int i=thread_idx*8;i<num_queries_per_kv*HEAD_SIZE;i+=NUM_THREADS*8){ for(int i=thread_idx*8;i<num_queries_per_kv*HEAD_SIZE;i+=NUM_THREADS*8){
*reinterpret_cast<half4x2*>(s_q+i)=*reinterpret_cast<const half4x2*>(q_ptr+i); int qoffset=i/head_offset;
qoffset*=num_kv_heads*head_offset;
qoffset+=i%head_offset;
*reinterpret_cast<half4x2*>(s_q+i)=*reinterpret_cast<const half4x2*>(q_ptr+qoffset);
}
} }
__syncthreads(); __syncthreads();
for(int m=0;m<Mloop;m++){ for(int m=0;m<Mloop;m++){
for(int i=0;i<HEAD_SIZE/32;i++){
int head_idx_=rowid+16*m; int head_idx_=rowid+16*m;
for(int i=0;i<HEAD_SIZE/32;i++){
if(head_idx_<num_queries_per_kv)q_vec[m][i]=*reinterpret_cast<const half4x2*>(s_q+head_idx_*HEAD_SIZE+(i*4+rows)*8); if(head_idx_<num_queries_per_kv)q_vec[m][i]=*reinterpret_cast<const half4x2*>(s_q+head_idx_*HEAD_SIZE+(i*4+rows)*8);
else q_vec[m][i]=q_zero; else q_vec[m][i]=q_zero;
} }
...@@ -422,7 +429,7 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel( ...@@ -422,7 +429,7 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
else{ else{
scalar_t temp; scalar_t temp;
if (mtp>1){ if (mtp>1){
int casual = mtp - reuse_kv_idx * mtp / num_heads ; int casual = mtp - reuse_kv_idx * mtp / num_queries_per_kv ;
if(token_idx+casual>seq_len)qk_vec[m][ii]=-INFINITY; if(token_idx+casual>seq_len)qk_vec[m][ii]=-INFINITY;
} }
from_float(temp,qk_vec[m][ii]); from_float(temp,qk_vec[m][ii]);
...@@ -643,6 +650,7 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel( ...@@ -643,6 +650,7 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
} }
} }
} }
{
scalar_t* out_ptr_base; scalar_t* out_ptr_base;
int out_offset; int out_offset;
if(num_partitions>1){ if(num_partitions>1){
...@@ -653,10 +661,12 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel( ...@@ -653,10 +661,12 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
out_offset=HEAD_SIZE; out_offset=HEAD_SIZE;
out_ptr_base=out + seq_idx * num_heads * HEAD_SIZE + head_idx*HEAD_SIZE; out_ptr_base=out + seq_idx * num_heads * HEAD_SIZE + head_idx*HEAD_SIZE;
} }
int head_offset = num_queries_per_kv/mtp;
for(int g=0;g<reuse_group;g++){ for(int g=0;g<reuse_group;g++){
int reusekvid=g*4+rows; int reusekvid=g*4+rows;
if(reusekvid<num_queries_per_kv){ if(reusekvid<num_queries_per_kv){
scalar_t* out_ptr = out_ptr_base + reusekvid*out_offset; int out_head = reusekvid/head_offset*num_kv_heads*head_offset + reusekvid%head_offset;
scalar_t* out_ptr = out_ptr_base + out_head*out_offset;
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = rowid+16*warp_idx + i * WARP_SIZE; const int row_idx = rowid+16*warp_idx + i * WARP_SIZE;
from_float(*(out_ptr + row_idx), accs[reusekvid/16][i][g%4]*v_scale); from_float(*(out_ptr + row_idx), accs[reusekvid/16][i][g%4]*v_scale);
...@@ -665,12 +675,14 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel( ...@@ -665,12 +675,14 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
} }
} }
if (num_partitions>1&&thread_idx < num_queries_per_kv){ if (num_partitions>1&&thread_idx < num_queries_per_kv){
int offset = seq_idx * num_heads * max_num_partitions + (head_idx+thread_idx) * max_num_partitions + partition_idx; int out_head = thread_idx/head_offset*num_kv_heads*head_offset + thread_idx%head_offset;
int offset = seq_idx * num_heads * max_num_partitions + (head_idx+out_head) * max_num_partitions + partition_idx;
float * exp_sums=reinterpret_cast<float*>(out_tmp); float * exp_sums=reinterpret_cast<float*>(out_tmp);
float * max_logits=reinterpret_cast<float*>(out_tmp+max_tmp_offset); float * max_logits=reinterpret_cast<float*>(out_tmp+max_tmp_offset);
*(exp_sums+offset)=expsum_out[thread_idx]; *(exp_sums+offset)=expsum_out[thread_idx];
*(max_logits+offset)=max_out[thread_idx]; *(max_logits+offset)=max_out[thread_idx];
} }
}
} }
template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS> template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS>
...@@ -797,19 +809,22 @@ void paged_attention( ...@@ -797,19 +809,22 @@ void paged_attention(
int num_kv_heads = key_cache.size(1); int num_kv_heads = key_cache.size(1);
int PARTITION_SIZE=512; int PARTITION_SIZE=512;
int reusekv=get_reusekv(num_heads,num_kv_heads); int reusekv=get_reusekv(num_heads,num_kv_heads);
if(reusekv>15)PARTITION_SIZE=256;
//if seq<10,the seq is invalid
if (max_seq_len<=10||(max_seq_len>=8192&&max_seq_len==max_num_blocks_per_seq*block_size)){ if (max_seq_len<=10||(max_seq_len>=8192&&max_seq_len==max_num_blocks_per_seq*block_size)){
int meanseq = num_blocks*block_size/num_seqs+8192; int meanseq = num_blocks*block_size/num_seqs+4096;
int maxseq = 100000000/num_seqs/headsize/num_heads*64; int maxseq = 100000000/num_seqs/headsize/num_heads*64;
if(reusekv<=8) maxseq*=2; if(reusekv<16) maxseq*=2;
max_seq_len=MIN(max_num_blocks_per_seq*block_size,MIN(meanseq,maxseq)); max_seq_len=MIN(max_num_blocks_per_seq*block_size,MIN(meanseq,maxseq));
} }
int real_reuse_times = num_heads/num_kv_heads; else{
int max_num_partitions=DIVIDE_ROUND_UP(max_seq_len,PARTITION_SIZE); int max_num_partitions=DIVIDE_ROUND_UP(max_seq_len,PARTITION_SIZE);
if(max_num_partitions*num_seqs*num_kv_heads<=160||reusekv>15)PARTITION_SIZE=256; if(max_num_partitions*num_seqs*num_kv_heads<=160)PARTITION_SIZE=256;
if(num_seqs*num_kv_heads<=32&&max_seq_len<=32768)PARTITION_SIZE=256; if(num_seqs*num_kv_heads<=32&&max_seq_len<=32768)PARTITION_SIZE=256;
// if(max_num_partitions*num_seqs*num_kv_heads>200&&real_reuse_times<6&&max_seq_len>30000)PARTITION_SIZE=1024; }
int real_reuse_times = num_heads/num_kv_heads;
if(PA_PARTITION_SIZE!=0)PARTITION_SIZE=PA_PARTITION_SIZE; if(PA_PARTITION_SIZE!=0)PARTITION_SIZE=PA_PARTITION_SIZE;
max_num_partitions=DIVIDE_ROUND_UP(max_seq_len,PARTITION_SIZE); int max_num_partitions=DIVIDE_ROUND_UP(max_seq_len,PARTITION_SIZE);
static float* tmp_out_ptr = nullptr; static float* tmp_out_ptr = nullptr;
constexpr int temp_out_size = 110000000; constexpr int temp_out_size = 110000000;
if(tmp_out_ptr == nullptr){ if(tmp_out_ptr == nullptr){
...@@ -881,7 +896,7 @@ void paged_attention( ...@@ -881,7 +896,7 @@ void paged_attention(
int shared_mem_size=PARTITION_SIZE*2*real_reuse_times+other_use; int shared_mem_size=PARTITION_SIZE*2*real_reuse_times+other_use;
grid.z = max_num_partitions; grid.z = max_num_partitions;
dim3 block(NUM_THREADS); dim3 block(NUM_THREADS);
if(PA_PRINT_PARAM)printf("is_fp8=%d,shared_mem_size=%d,HEAD_SIZE=%d,BLOCK_SIZE=%d,num_thread=%d,grid={%d,%d,%d},qhead=%d,kvhead=%d,seq=%d,batch=%d,PARTITION_SIZE=%d,max_num_partitions=%d\n", if(PA_PRINT_PARAM&&static_cast<int32_t>(query.get_device())==0)printf("is_fp8=%d,shared_mem_size=%d,HEAD_SIZE=%d,BLOCK_SIZE=%d,num_thread=%d,grid={%d,%d,%d},qhead=%d,kvhead=%d,seq=%d,batch=%d,PARTITION_SIZE=%d,max_num_partitions=%d\n",
(int)(sizeof(cache_t)==1),shared_mem_size,HEAD_SIZE,BLOCK_SIZE,NUM_THREADS,grid.x,grid.y,grid.z,num_heads,num_kv_heads,max_seq_len,num_seqs,PARTITION_SIZE,max_num_partitions); (int)(sizeof(cache_t)==1),shared_mem_size,HEAD_SIZE,BLOCK_SIZE,NUM_THREADS,grid.x,grid.y,grid.z,num_heads,num_kv_heads,max_seq_len,num_seqs,PARTITION_SIZE,max_num_partitions);
paged_attention_kernel<scalar_t,cache_t,is_e4m3,HEAD_SIZE,BLOCK_SIZE,NUM_THREADS,REUSE_KV_TIMES><<<grid,block,shared_mem_size,stream>>>( paged_attention_kernel<scalar_t,cache_t,is_e4m3,HEAD_SIZE,BLOCK_SIZE,NUM_THREADS,REUSE_KV_TIMES><<<grid,block,shared_mem_size,stream>>>(
(scalar_t*)out_ptr,(scalar_t*)tmp_out_ptr, (scalar_t*)query_ptr,(cache_t*) key_cache_ptr, (cache_t*)value_cache_ptr, (scalar_t*)out_ptr,(scalar_t*)tmp_out_ptr, (scalar_t*)query_ptr,(cache_t*) key_cache_ptr, (cache_t*)value_cache_ptr,
......
...@@ -363,8 +363,9 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel( ...@@ -363,8 +363,9 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
} }
k_scale*=q_scale; k_scale*=q_scale;
const int num_queries_per_kv = num_heads / num_kv_heads; const int num_queries_per_kv = num_heads / num_kv_heads;
const int head_idx=blockIdx.x*num_queries_per_kv;
const int kv_head_idx = blockIdx.x; const int kv_head_idx = blockIdx.x;
const int head_idx=num_queries_per_kv/mtp * kv_head_idx;
constexpr int reuse_group=(REUSE_KV_TIMES-1)/4+1; constexpr int reuse_group=(REUSE_KV_TIMES-1)/4+1;
constexpr int Mloop=(REUSE_KV_TIMES-1)/16+1; constexpr int Mloop=(REUSE_KV_TIMES-1)/16+1;
extern __shared__ char shared_mem[]; extern __shared__ char shared_mem[];
...@@ -397,12 +398,19 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel( ...@@ -397,12 +398,19 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
} }
intx4 q_vec[Mloop][HEAD_SIZE/64]; intx4 q_vec[Mloop][HEAD_SIZE/64];
q_type* s_q = reinterpret_cast<q_type*>(shared_mem); q_type* s_q = reinterpret_cast<q_type*>(shared_mem);
{
int head_offset = HEAD_SIZE*num_queries_per_kv/mtp;
for(int i=thread_idx*8;i<num_queries_per_kv*HEAD_SIZE;i+=NUM_THREADS*8){ for(int i=thread_idx*8;i<num_queries_per_kv*HEAD_SIZE;i+=NUM_THREADS*8){
int qoffset=i/head_offset;
qoffset*=num_kv_heads*head_offset;
qoffset+=i%head_offset;
if constexpr (q_is_fp8){ if constexpr (q_is_fp8){
*reinterpret_cast<intx2*>(s_q+i)=*reinterpret_cast<const intx2*>(q_ptr+i); *reinterpret_cast<intx2*>(s_q+i)=*reinterpret_cast<const intx2*>(q_ptr+qoffset);
} }
else{ else{
*reinterpret_cast<half4x2*>(s_q+i)=*reinterpret_cast<const half4x2*>(q_ptr+i); *reinterpret_cast<half4x2*>(s_q+i)=*reinterpret_cast<const half4x2*>(q_ptr+qoffset);
}
} }
} }
__syncthreads(); __syncthreads();
...@@ -475,7 +483,7 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel( ...@@ -475,7 +483,7 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
else{ else{
scalar_t temp; scalar_t temp;
if (mtp>1){ if (mtp>1){
int casual = mtp - reuse_kv_idx * mtp / num_heads ; int casual = mtp - reuse_kv_idx * mtp / num_queries_per_kv ;
if(token_idx+casual>seq_len)qk_vec[m][ii]=-INFINITY; if(token_idx+casual>seq_len)qk_vec[m][ii]=-INFINITY;
} }
from_float(temp,qk_vec[m][ii]); from_float(temp,qk_vec[m][ii]);
...@@ -680,7 +688,7 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel( ...@@ -680,7 +688,7 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
} }
} }
} }
{
scalar_t* out_ptr_base; scalar_t* out_ptr_base;
int out_offset; int out_offset;
if(num_partitions>1){ if(num_partitions>1){
...@@ -691,10 +699,12 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel( ...@@ -691,10 +699,12 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
out_offset=HEAD_SIZE; out_offset=HEAD_SIZE;
out_ptr_base=out + seq_idx * num_heads * HEAD_SIZE + head_idx*HEAD_SIZE; out_ptr_base=out + seq_idx * num_heads * HEAD_SIZE + head_idx*HEAD_SIZE;
} }
int head_offset = num_queries_per_kv/mtp;
for(int g=0;g<reuse_group;g++){ for(int g=0;g<reuse_group;g++){
int reusekvid=g*4+rows; int reusekvid=g*4+rows;
if(reusekvid<num_queries_per_kv){ if(reusekvid<num_queries_per_kv){
scalar_t* out_ptr = out_ptr_base + reusekvid*out_offset; int out_head = reusekvid/head_offset*num_kv_heads*head_offset + reusekvid%head_offset;
scalar_t* out_ptr = out_ptr_base + out_head*out_offset;
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = rowid+16*warp_idx + i * WARP_SIZE; const int row_idx = rowid+16*warp_idx + i * WARP_SIZE;
from_float(*(out_ptr + row_idx), accs[reusekvid/16][i][g%4]*v_scale); from_float(*(out_ptr + row_idx), accs[reusekvid/16][i][g%4]*v_scale);
...@@ -703,12 +713,14 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel( ...@@ -703,12 +713,14 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
} }
} }
if (num_partitions>1&&thread_idx < num_queries_per_kv){ if (num_partitions>1&&thread_idx < num_queries_per_kv){
int offset = seq_idx * num_heads * max_num_partitions + (head_idx+thread_idx) * max_num_partitions + partition_idx; int out_head = thread_idx/head_offset*num_kv_heads*head_offset + thread_idx%head_offset;
int offset = seq_idx * num_heads * max_num_partitions + (head_idx+out_head) * max_num_partitions + partition_idx;
float * exp_sums=reinterpret_cast<float*>(out_tmp); float * exp_sums=reinterpret_cast<float*>(out_tmp);
float * max_logits=reinterpret_cast<float*>(out_tmp+max_tmp_offset); float * max_logits=reinterpret_cast<float*>(out_tmp+max_tmp_offset);
*(exp_sums+offset)=expsum_out[thread_idx]; *(exp_sums+offset)=expsum_out[thread_idx];
*(max_logits+offset)=max_out[thread_idx]; *(max_logits+offset)=max_out[thread_idx];
} }
}
#endif #endif
} }
......
...@@ -675,7 +675,66 @@ __forceinline__ __device__ void gemm_rr(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tC ...@@ -675,7 +675,66 @@ __forceinline__ __device__ void gemm_rr(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tC
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
} }
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
// Loads the fragment at logical coordinate (0, row, col) of the shared-memory
// tensor `src` with the ds_read_m32x16f16_alt builtin and stores the resulting
// eight 16-bit values into dst(0, row, col) onwards.
//
// `row` / `col` are compile-time coordinates so the source offset is a constant.
template<int row, int col, typename Tensor0, typename Tensor1>
__forceinline__ __device__ static void __ds_read_m32x16_row_col_alt(Tensor0& src, Tensor1& dst)
{
    // A local copy of the layout is needed so the offset can be a constant expression.
    auto src_layout = src.layout();
    // Element offset scaled by 2 (presumably bytes per 16-bit element -- confirm).
    constexpr short read_offset = src_layout(0, row, col) * 2;

    auto smem_base = reinterpret_cast<__fp16 *>(src.data().get());
    auto fragment = __builtin_amdgcn_ds_read_m32x16f16_alt((__attribute__((address_space(3))) __fp16*)(smem_base), read_offset);

    // Move the result as raw 16-bit words so the element type of `dst` does not matter.
    uint16_t * frag_words = reinterpret_cast<uint16_t*>(&fragment);
    uint16_t * dst_words = reinterpret_cast<uint16_t*>(&(dst(0, row, col)));
    #pragma unroll
    for (int w = 0; w < 8; ++w) {
        dst_words[w] = frag_words[w];
    }
}
// One k-slice of a register(A) x shared(B) GEMM step.
//
// Loads the k_idx-th slice of B from shared memory (`tCsB`) into the register
// fragment `tCrB` through __ds_read_m32x16_row_col_alt, one call per row of
// mode 1 of tCsB, and then accumulates tCrA(_, _, k_idx) * tCrB(_, _, k_idx)
// into `acc` with `tiled_mma`. The hand-written loads stand in for a
// cute::copy through `smem_tiled_copy_B`, which is therefore unused here.
//
// Mode 1 of tCsB must have extent 2, 3, 4 or 6 (checked at compile time).
template<int k_idx, typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3,
         typename TiledMma, typename TiledCopy, typename ThrCopy>
__forceinline__ __device__ void gemm_k_rs_ds_read_m32x16_alt(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
                                                             TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
                                                             ThrCopy smem_thr_copy_B) {
    CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));    // MMA_M
    CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));    // MMA_N
    CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB));   // MMA_K
    Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
    CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view));    // N

    auto smem_shape = tCsB.shape();
    constexpr int rows = get<1>(smem_shape);
    static_assert(rows == 6 || rows == 4 || rows == 3 || rows == 2);

    // Rows 0 and 1 exist for every supported extent; the remaining rows are
    // added cumulatively, which yields the same call sequence for 2/3/4/6.
    __ds_read_m32x16_row_col_alt<0, k_idx>(tCsB, tCrB_copy_view);
    __ds_read_m32x16_row_col_alt<1, k_idx>(tCsB, tCrB_copy_view);
    if constexpr (rows >= 3) {
        __ds_read_m32x16_row_col_alt<2, k_idx>(tCsB, tCrB_copy_view);
    }
    if constexpr (rows >= 4) {
        __ds_read_m32x16_row_col_alt<3, k_idx>(tCsB, tCrB_copy_view);
    }
    if constexpr (rows == 6) {
        __ds_read_m32x16_row_col_alt<4, k_idx>(tCsB, tCrB_copy_view);
        __ds_read_m32x16_row_col_alt<5, k_idx>(tCsB, tCrB_copy_view);
    }

    cute::gemm(tiled_mma, tCrA(_, _, k_idx), tCrB(_, _, k_idx), acc);
}
template<int row, int col, typename Tensor0, typename Tensor1> template<int row, int col, typename Tensor0, typename Tensor1>
__forceinline__ __device__ static void __ds_read_m32x16_row_col(Tensor0& src, Tensor1& dst) __forceinline__ __device__ static void __ds_read_m32x16_row_col(Tensor0& src, Tensor1& dst)
......
...@@ -353,10 +353,11 @@ void set_params_dropout(Flash_fwd_params &params, float p_dropout, ...@@ -353,10 +353,11 @@ void set_params_dropout(Flash_fwd_params &params, float p_dropout,
c10::optional<at::Generator> gen_, c10::optional<at::Generator> gen_,
at::TensorOptions opts, at::TensorOptions opts,
at::Tensor &dropout_debug_count) { at::Tensor &dropout_debug_count) {
if (p_dropout > 0) {
rng_state = at::empty({2}, opts.dtype(at::ScalarType::Long)); rng_state = at::empty({2}, opts.dtype(at::ScalarType::Long));
// Forward kernel will populate memory with the seed and offset. // Match the generic FlashAttention API contract: rng_state is returned as a
// tensor even when dropout is disabled.
params.rng_state = reinterpret_cast<uint64_t *>(rng_state.data_ptr()); params.rng_state = reinterpret_cast<uint64_t *>(rng_state.data_ptr());
if (p_dropout > 0) {
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>( auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
gen_, at::cuda::detail::getDefaultCUDAGenerator()); gen_, at::cuda::detail::getDefaultCUDAGenerator());
// See Note [Acquire lock when using random generators] // See Note [Acquire lock when using random generators]
...@@ -371,8 +372,6 @@ void set_params_dropout(Flash_fwd_params &params, float p_dropout, ...@@ -371,8 +372,6 @@ void set_params_dropout(Flash_fwd_params &params, float p_dropout,
params.dropout_debug_count = params.dropout_debug_count =
reinterpret_cast<uint32_t *>(dropout_debug_count.data_ptr()); reinterpret_cast<uint32_t *>(dropout_debug_count.data_ptr());
#endif #endif
} else {
params.rng_state = nullptr;
} }
} }
...@@ -1637,16 +1636,11 @@ std::vector<at::Tensor> varlen_fwd_bhsd( ...@@ -1637,16 +1636,11 @@ std::vector<at::Tensor> varlen_fwd_bhsd(
params.total_k = total_k; params.total_k = total_k;
at::Tensor rng_state; at::Tensor rng_state;
if (p_dropout > 0) { auto options =
auto options = at::TensorOptions() at::TensorOptions().dtype(at::ScalarType::Float).device(at::DeviceType::CUDA);
.dtype(at::ScalarType::Float)
.device(at::DeviceType::CUDA);
rng_state = at::empty({2}, options.dtype(at::ScalarType::Long)); rng_state = at::empty({2}, options.dtype(at::ScalarType::Long));
// Forward kernel will populate memory with the seed and offset. // Keep the return tuple compatible with the generic FlashAttention path.
params.rng_state = reinterpret_cast<uint64_t *>(rng_state.data_ptr()); params.rng_state = reinterpret_cast<uint64_t *>(rng_state.data_ptr());
} else {
params.rng_state = nullptr;
}
set_params_alibi(params, alibi_slopes_, batch_size, num_heads); set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
...@@ -1884,16 +1878,11 @@ std::vector<at::Tensor> hg_prefix_prefill_varlen_fwd( ...@@ -1884,16 +1878,11 @@ std::vector<at::Tensor> hg_prefix_prefill_varlen_fwd(
} }
at::Tensor rng_state; at::Tensor rng_state;
if (p_dropout > 0) { auto options =
auto options = at::TensorOptions() at::TensorOptions().dtype(at::ScalarType::Float).device(at::DeviceType::CUDA);
.dtype(at::ScalarType::Float)
.device(at::DeviceType::CUDA);
rng_state = at::empty({2}, options.dtype(at::ScalarType::Long)); rng_state = at::empty({2}, options.dtype(at::ScalarType::Long));
// Forward kernel will populate memory with the seed and offset. // Keep the return tuple compatible with the generic FlashAttention path.
params.rng_state = reinterpret_cast<uint64_t *>(rng_state.data_ptr()); params.rng_state = reinterpret_cast<uint64_t *>(rng_state.data_ptr());
} else {
params.rng_state = nullptr;
}
set_params_alibi(params, alibi_slopes_, batch_size, num_heads); set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
......
...@@ -26,6 +26,7 @@ if torch.cuda.is_available(): ...@@ -26,6 +26,7 @@ if torch.cuda.is_available():
flash_attn_varlen_with_mask_func, flash_attn_varlen_with_mask_func,
# unified attn functions # unified attn functions
varlen_fwd_unified, varlen_fwd_unified,
fwd_sparse_mean_pool_fast,
) )
# triton fa interface # triton fa interface
from flash_attn.flash_attn_triton_interface import flash_attn_func as triton_flash_attn_func from flash_attn.flash_attn_triton_interface import flash_attn_func as triton_flash_attn_func
......
...@@ -161,7 +161,7 @@ def _flash_attn_varlen_forward( ...@@ -161,7 +161,7 @@ def _flash_attn_varlen_forward(
# breakpoint() # breakpoint()
return out, softmax_lse, S_dmask, rng_state return out, softmax_lse, S_dmask, rng_state
@torch.library.register_fake("flash_attn2_c_op::varlen_fwd") @_torch_register_fake_wrapper("flash_attn2_c_op::varlen_fwd")
def varlen_fwd_fake( def varlen_fwd_fake(
q: torch.Tensor, q: torch.Tensor,
k: torch.Tensor, k: torch.Tensor,
...@@ -2008,7 +2008,7 @@ def vllm_flash_attn_varlen_func( ...@@ -2008,7 +2008,7 @@ def vllm_flash_attn_varlen_func(
# if mtp, k head must be 1. # if mtp, k head must be 1.
# todo : support k head >1 # todo : support k head >1
is_mtp = (max_seqlen_q*bs==total_q and max_seqlen_q>1 and max_seqlen_q<5) is_mtp = (max_seqlen_q*bs==total_q and max_seqlen_q>1 and max_seqlen_q<5)
if (max_seqlen_q==1 or (is_mtp and k.shape[1]==1)) and real_window_size[0]==-1: if (max_seqlen_q==1 or is_mtp ) and real_window_size[0]==-1:
if out==None: if out==None:
if q.dtype == torch.float8_e4m3fn or q.dtype == torch.float8_e5m2: if q.dtype == torch.float8_e4m3fn or q.dtype == torch.float8_e5m2:
out = torch.empty(q.size(),device = q.device,dtype=torch.bfloat16) out = torch.empty(q.size(),device = q.device,dtype=torch.bfloat16)
...@@ -3816,6 +3816,22 @@ def spas_fa2_attn_meansim_topk_varlen_cuda( ...@@ -3816,6 +3816,22 @@ def spas_fa2_attn_meansim_topk_varlen_cuda(
) )
def fwd_sparse_mean_pool_fast(x, BLK, mean=None):
    """Call ``flash_attn_cuda.fwd_sparse_mean_pool_fast`` and return its result.

    The three arguments are forwarded positionally and unmodified; this
    wrapper adds no validation or conversion of its own.

    Arguments:
        x: input handed to the kernel (presumably a tensor to be mean-pooled
            in blocks -- confirm against the extension source).
        BLK: block size argument forwarded to the kernel.
        mean: optional third argument; ``None`` is forwarded as-is when the
            caller does not supply one.
    Output:
        Whatever the underlying ``flash_attn_cuda`` op returns.
    """
    pooled = flash_attn_cuda.fwd_sparse_mean_pool_fast(x, BLK, mean)
    return pooled
def get_block_map_fast(q, k, topk_ratio, BLKQ=128, BLKK=64):
    """Select, per pooled query block, the highest-scoring pooled key blocks.

    Arguments:
        q: query tensor, pooled with block size ``BLKQ``.
        k: key tensor, pooled with block size ``BLKK``. Its mean over
            dim -3 is computed here and passed to the pooling kernel.
            (Callers visible in this file pass q in
            (batch, seqlen, nheads, headdim) layout; k presumably matches,
            making dim -3 the sequence axis -- confirm.)
        topk_ratio: fraction of key blocks to keep per query block.
        BLKQ: query block size forwarded to the pooling kernel.
        BLKK: key block size forwarded to the pooling kernel.
    Output:
        sparse_map: int8 tensor shaped like the pooled score matrix, 1 at
            every selected position and 0 elsewhere.
        lut: indices of the selected entries along the last dim (unsorted).
        topk: number of entries selected per row.
    """
    # NOTE(review): how the kernel uses the mean (e.g. subtracting it from k
    # before pooling) is not visible here -- confirm in the extension source.
    k_mean = k.mean(dim=-3, keepdim=True)
    k_blocks = fwd_sparse_mean_pool_fast(k, BLKK, k_mean)
    q_blocks = fwd_sparse_mean_pool_fast(q, BLKQ)

    # Block-level similarity between pooled queries and pooled keys.
    block_scores = torch.matmul(q_blocks, k_blocks.transpose(-1, -2))

    # Keep a ratio of the key blocks, never more than exist.
    num_k_blocks = block_scores.shape[-1]
    topk = min(num_k_blocks, int(topk_ratio * num_k_blocks))
    lut = block_scores.topk(topk, dim=-1, sorted=False).indices

    # scatter_ works in place and returns the same tensor, so chaining is safe.
    sparse_map = torch.zeros_like(block_scores, dtype=torch.int8).scatter_(-1, lut, 1)
    return sparse_map, lut, topk
class SparseLinearAttention(nn.Module): class SparseLinearAttention(nn.Module):
def __init__(self, head_dim, topk, feature_map='softmax', use_bf16=True, use_fp8=False, tie_feature_map_qk=True): def __init__(self, head_dim, topk, feature_map='softmax', use_bf16=True, use_fp8=False, tie_feature_map_qk=True):
R''' R'''
...@@ -3872,19 +3888,15 @@ class SparseLinearAttention(nn.Module): ...@@ -3872,19 +3888,15 @@ class SparseLinearAttention(nn.Module):
''' '''
B, seqlen_q, H, headdim = q.shape B, seqlen_q, H, headdim = q.shape
q_bhld = q.transpose(1, 2).contiguous() # (B, H, L, D)
k_bhld = k.transpose(1, 2).contiguous()
# v_bhld = v.transpose(1, 2).contiguous()
# import pdb
# pdb.set_trace()
if headdim == 64: if headdim == 64:
block_m = 64 if seqlen_q <= 2048 else 128 block_m = 64 if seqlen_q <= 2048 else 128
elif headdim == 128: elif headdim == 128:
block_m = 64 if seqlen_q <= 2048 else 128 block_m = 64 if seqlen_q <= 2048 else 128
block_k = 64 block_k = 64
sparse_map, lut, real_topk = get_block_map(q_bhld, k_bhld, topk_ratio=self.topk, BLKQ=block_m, BLKK=block_k) if headdim == 64:
sparse_map, lut, real_topk = get_block_map(q.transpose(1, 2).contiguous(), k.transpose(1, 2).contiguous(), topk_ratio=self.topk, BLKQ=block_m, BLKK=block_k)
else:
sparse_map, lut, real_topk = get_block_map_fast(q, k, topk_ratio=self.topk, BLKQ=block_m, BLKK=block_k)
q = q.to(self.dtype) q = q.to(self.dtype)
k = k.to(self.dtype) k = k.to(self.dtype)
...@@ -3900,10 +3912,10 @@ class SparseLinearAttention(nn.Module): ...@@ -3900,10 +3912,10 @@ class SparseLinearAttention(nn.Module):
seqlen_k = k.size(1) seqlen_k = k.size(1)
num_blocks_q = (seqlen_q + block_m - 1) // block_m num_blocks_q = (seqlen_q + block_m - 1) // block_m
num_blocks_k = (seqlen_k + block_k - 1) // block_k num_blocks_k = (seqlen_k + block_k - 1) // block_k
column_count = torch.zeros( column_count = torch.empty(
(B, H, num_blocks_q), dtype=torch.int32, device=q.device (B, H, num_blocks_q), dtype=torch.int32, device=q.device
) )
column_index = torch.zeros( column_index = torch.empty(
(B, H, num_blocks_q, 1), dtype=torch.int32, device=q.device (B, H, num_blocks_q, 1), dtype=torch.int32, device=q.device
) )
...@@ -3972,14 +3984,56 @@ def sparse_attn_with_sla( ...@@ -3972,14 +3984,56 @@ def sparse_attn_with_sla(
maybe_contiguous = lambda x: x.contiguous() if x is not None and x.stride(-1) != 1 else x maybe_contiguous = lambda x: x.contiguous() if x is not None and x.stride(-1) != 1 else x
q, k, v = [maybe_contiguous(x) for x in (q, k, v)] q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
attn = SparseLinearAttention( dtype = torch.bfloat16 if use_bf16 else torch.float16
head_dim=q.size(-1), dtype = torch.float8_e4m3fn if use_fp8 else dtype
topk=topk, # = 1 - sparsity B, seqlen_q, H, headdim = q.shape
feature_map=feature_map, # options: elu, relu, softmax assert not (use_bf16 and use_fp8), "Only one of bf16 and fp8 can be used."
use_bf16=use_bf16, assert headdim in (64, 128), "Dtype fp16/bf16 only support dim (64, 128)."
use_fp8=use_fp8, assert not (use_fp8 and headdim==64), "Dtype fp8 only support dim 128."
).cuda() if headdim == 64:
return attn(q, k, v, return_sparsity=return_sparsity) block_m = 64 if seqlen_q <= 2048 else 128
elif headdim == 128:
block_m = 64 if seqlen_q <= 2048 else 128
block_k = 64
if headdim == 64:
sparse_map, lut, real_topk = get_block_map(q.transpose(1, 2).contiguous(), k.transpose(1, 2).contiguous(), topk_ratio=topk, BLKQ=block_m, BLKK=block_k)
else:
sparse_map, lut, real_topk = get_block_map_fast(q, k, topk_ratio=topk, BLKQ=block_m, BLKK=block_k)
q = q.to(dtype)
k = k.to(dtype)
v = v.to(dtype)
########## SPARGE BEGIN ##########
headdim = q.size(-1)
block_offset, block_count = block_map_to_block_offset_triton(sparse_map)
block_offset = block_offset * block_k
softmax_scale = 1.0 / (headdim ** 0.5)
assert headdim in [64, 128], "headdim should be in [64, 128]. For other headdim, you can use padding and specify the softmax scale."
seqlen_k = k.size(1)
num_blocks_q = (seqlen_q + block_m - 1) // block_m
num_blocks_k = (seqlen_k + block_k - 1) // block_k
column_count = torch.empty(
(B, H, num_blocks_q), dtype=torch.int32, device=q.device
)
column_index = torch.empty(
(B, H, num_blocks_q, 1), dtype=torch.int32, device=q.device
)
o_s = sparse_attn_func(
q, k, v, # Use original BLHD layout
block_count=block_count,
block_offset=block_offset,
column_count=column_count,
column_index=column_index,
softmax_scale=softmax_scale,
is_sla=True,
)
if return_sparsity:
return o_s, real_topk / sparse_map.shape[-1]
else:
return o_s
def _require_hg_varlen_symbol(name: str): def _require_hg_varlen_symbol(name: str):
......
...@@ -870,10 +870,15 @@ if not SKIP_CUDA_BUILD: ...@@ -870,10 +870,15 @@ if not SKIP_CUDA_BUILD:
"csrc/flash_attn/src/flash_fwd_split_hdim256_fp8_outfp16_e5m2_causal_sm80.cu", "csrc/flash_attn/src/flash_fwd_split_hdim256_fp8_outfp16_e5m2_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim256_q_bf16_kv_e5m2_causal_sm80.cu", "csrc/flash_attn/src/flash_fwd_split_hdim256_q_bf16_kv_e5m2_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim256_q_fp16_kv_e5m2_causal_sm80.cu", "csrc/flash_attn/src/flash_fwd_split_hdim256_q_fp16_kv_e5m2_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_unified_hdim128_fp16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_unified_hdim128_fp16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_unified_hdim128_bf16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_unified_hdim128_bf16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_unified_hdim256_fp16_causal_sm80.cu", "csrc/flash_attn/src/flash_fwd_unified_hdim256_fp16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_unified_hdim256_fp16_sm80.cu", "csrc/flash_attn/src/flash_fwd_unified_hdim256_fp16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_unified_hdim256_bf16_causal_sm80.cu", "csrc/flash_attn/src/flash_fwd_unified_hdim256_bf16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_unified_hdim256_bf16_sm80.cu", "csrc/flash_attn/src/flash_fwd_unified_hdim256_bf16_sm80.cu",
"csrc/flash_attn/src/flash_sparse_util.cu"
], ],
extra_compile_args={ extra_compile_args={
"cxx": ["-O3", "-w", "-std=c++17", "cxx": ["-O3", "-w", "-std=c++17",
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment