Unverified Commit 41b611f7 authored by Zeyu WANG's avatar Zeyu WANG Committed by GitHub
Browse files

Add more GPU architctures support (#76)



* Add more GPU architctures support

* Merge fmha and mla runner

* add varlen & non varlen support, and add incontiguous tensor support

* update readme

* add varlen api

---------
Co-authored-by: default avatardianzhangc <dianzhangc@nvidia.com>
parent 9edee0c0
......@@ -3,4 +3,7 @@ __version__ = "1.0.0"
from flash_mla.flash_mla_interface import (
get_mla_metadata,
flash_mla_with_kvcache,
flash_attn_varlen_func,
flash_attn_varlen_qkvpacked_func,
flash_attn_varlen_kvpacked_func,
)
......@@ -2,7 +2,9 @@ from typing import Optional, Tuple
import torch
import flash_mla_cuda
import flash_mla_sm90
import flash_mla_sm100
def get_mla_metadata(
......@@ -20,10 +22,10 @@ def get_mla_metadata(
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
num_splits: (batch_size + 1), dtype torch.int32.
"""
return flash_mla_cuda.get_mla_metadata(cache_seqlens, num_heads_per_head_k, num_heads_k)
return flash_mla_sm90.get_mla_metadata(cache_seqlens, num_heads_per_head_k, num_heads_k)
def flash_mla_with_kvcache(
def flash_mla_with_kvcache_sm90(
q: torch.Tensor,
k_cache: torch.Tensor,
block_table: torch.Tensor,
......@@ -52,7 +54,7 @@ def flash_mla_with_kvcache(
"""
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(
out, softmax_lse = flash_mla_sm90.fwd_kvcache_mla(
q,
k_cache,
head_dim_v,
......@@ -64,3 +66,264 @@ def flash_mla_with_kvcache(
num_splits,
)
return out, softmax_lse
def _flash_attn_varlen_forward(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens_qo: torch.Tensor,
cu_seqlens_kv: torch.Tensor,
max_seqlen_qo: int,
max_seqlen_kv: int,
out: Optional[torch.Tensor] = None,
lse: Optional[torch.Tensor] = None,
causal: bool = False,
softmax_scale: Optional[float] = None,
is_varlen: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
qo_total_len, num_qo_heads, head_dim_qk = q.shape
kv_total_len, num_kv_heads, head_dim_vo = v.shape
mask_mode_code = 1 if causal else 0
if softmax_scale is None:
softmax_scale = head_dim_qk ** (-0.5)
if out is None:
out = torch.empty(qo_total_len, num_qo_heads, head_dim_vo, device=q.device, dtype=q.dtype)
if lse is None:
# Make lse contiguous on seqlen dim
lse = torch.empty(num_qo_heads, qo_total_len, device=q.device, dtype=torch.float32).T
workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.uint8, device=q.device)
flash_mla_sm100.fwd(
workspace_buffer,
q,
k,
v,
cu_seqlens_qo,
cu_seqlens_kv,
out,
lse,
mask_mode_code,
softmax_scale,
max_seqlen_qo,
max_seqlen_kv,
is_varlen,
)
return out, lse
def _flash_attn_varlen_backward(
do: torch.Tensor,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
out: torch.Tensor,
lse: torch.Tensor,
cu_seqlens_qo: torch.Tensor,
cu_seqlens_kv: torch.Tensor,
max_seqlen_qo: int,
max_seqlen_kv: int,
dq: Optional[torch.Tensor] = None,
dk: Optional[torch.Tensor] = None,
dv: Optional[torch.Tensor] = None,
causal: bool = False,
softmax_scale: Optional[float] = None,
is_varlen: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
qo_total_len, num_qo_heads, head_dim_qk = q.shape
kv_total_len, num_kv_heads, head_dim_vo = v.shape
# TODO: fix bwd GQA
if num_qo_heads != num_kv_heads:
raise ValueError(f"SM100 bwd doesn't support GQA now. num_qo_heads: {num_qo_heads}, num_kv_heads: {num_kv_heads}.")
mask_mode_code = 1 if causal else 0
if softmax_scale is None:
softmax_scale = head_dim_qk ** (-0.5)
if dq is None:
dq = torch.empty(qo_total_len, num_qo_heads, head_dim_qk, device=q.device, dtype=q.dtype)
if dk is None:
dk = torch.empty(kv_total_len, num_kv_heads, head_dim_qk, device=q.device, dtype=q.dtype)
if dv is None:
dv = torch.empty(kv_total_len, num_kv_heads, head_dim_vo, device=q.device, dtype=q.dtype)
max_seqlen_qo_aligned = (max_seqlen_qo + 7) // 8 * 8
bs = cu_seqlens_qo.shape[0] - 1
workspace_bytes = 0
workspace_bytes += 4 * qo_total_len * num_qo_heads * head_dim_qk # dQ_acc
workspace_bytes += 4 * max_seqlen_qo_aligned * bs * num_qo_heads * 2 # sum_OdO and scaled_lse
if num_qo_heads != num_kv_heads:
workspace_bytes += 2 * kv_total_len * num_qo_heads * (head_dim_qk + head_dim_vo) # dKV_acc
workspace_buffer = torch.empty(workspace_bytes, dtype=torch.uint8, device=q.device)
flash_mla_sm100.bwd(
workspace_buffer,
do,
q,
k,
v,
out,
lse,
cu_seqlens_qo,
cu_seqlens_kv,
dq,
dk,
dv,
mask_mode_code,
softmax_scale,
max_seqlen_qo,
max_seqlen_kv,
is_varlen,
)
return dq, dk, dv
class FlashAttnVarlenFunc(torch.autograd.Function):
def forward(
ctx,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens_qo: torch.Tensor,
cu_seqlens_kv: torch.Tensor,
max_seqlen_qo: int,
max_seqlen_kv: int,
causal: bool = False,
softmax_scale: Optional[float] = None,
is_varlen: bool = True,
):
out, lse = _flash_attn_varlen_forward(
q, k, v,
cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv,
causal=causal, softmax_scale=softmax_scale,
is_varlen=is_varlen,
)
ctx.save_for_backward(q, k, v, out, lse, cu_seqlens_qo, cu_seqlens_kv)
ctx.max_seqlen_qo = max_seqlen_qo
ctx.max_seqlen_kv = max_seqlen_kv
ctx.causal = causal
ctx.softmax_scale = softmax_scale
ctx.is_varlen = is_varlen
return out, lse
def backward(
ctx,
do: torch.Tensor,
dlse: torch.Tensor,
):
del dlse # LSE doesn't support backward currently
q, k, v, out, lse, cu_seqlens_qo, cu_seqlens_kv = ctx.saved_tensors
dq, dk, dv = _flash_attn_varlen_backward(
do, q, k, v, out, lse,
cu_seqlens_qo, cu_seqlens_kv, ctx.max_seqlen_qo, ctx.max_seqlen_kv,
causal=ctx.causal, softmax_scale=ctx.softmax_scale,
is_varlen=ctx.is_varlen,
)
return dq, dk, dv, None, None, None, None, None, None, None
def flash_attn_varlen_func(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens_qo: torch.Tensor,
cu_seqlens_kv: torch.Tensor,
max_seqlen_qo: int,
max_seqlen_kv: int,
dropout_p: float = 0.0,
softmax_scale: Optional[float] = None,
causal: bool = False,
deterministic: bool = False,
is_varlen: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
assert dropout_p == 0.0
assert not deterministic
return FlashAttnVarlenFunc.apply(
q, k, v,
cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv,
causal, softmax_scale, is_varlen,
)
def flash_attn_varlen_qkvpacked_func(
qkv: torch.Tensor,
cu_seqlens: torch.Tensor,
max_seqlen: int,
head_dim_qk: int,
dropout_p: float = 0.0,
softmax_scale: Optional[float] = None,
causal: bool = False,
deterministic: bool = False,
is_varlen: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
assert dropout_p == 0.0
assert not deterministic
return FlashAttnVarlenFunc.apply(
qkv[:, :, :head_dim_qk], qkv[:, :, head_dim_qk:head_dim_qk * 2], qkv[:, :, head_dim_qk * 2:],
cu_seqlens, cu_seqlens, max_seqlen, max_seqlen,
causal, softmax_scale, is_varlen,
)
def flash_attn_varlen_kvpacked_func(
q: torch.Tensor,
kv: torch.Tensor,
cu_seqlens_qo: torch.Tensor,
cu_seqlens_kv: torch.Tensor,
max_seqlen_qo: int,
max_seqlen_kv: int,
head_dim_qk: int,
dropout_p: float = 0.0,
softmax_scale: Optional[float] = None,
causal: bool = False,
deterministic: bool = False,
is_varlen: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
assert dropout_p == 0.0
assert not deterministic
return FlashAttnVarlenFunc.apply(
q, kv[:, :, :head_dim_qk], kv[:, :, head_dim_qk:],
cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv,
causal, softmax_scale, is_varlen,
)
def flash_mla_with_kvcache_sm100(
q: torch.Tensor,
k_cache: torch.Tensor,
block_table: torch.Tensor,
cache_seqlens: torch.Tensor,
head_dim_v: int,
softmax_scale: Optional[float] = None,
causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
# TODO
pass
def flash_mla_with_kvcache(
q: torch.Tensor,
k_cache: torch.Tensor,
block_table: torch.Tensor,
cache_seqlens: torch.Tensor,
head_dim_v: int,
tile_scheduler_metadata: Optional[torch.Tensor] = None,
num_splits: Optional[torch.Tensor] = None,
softmax_scale: Optional[float] = None,
causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
capability = torch.cuda.get_device_capability(q.device.index)
if capability == (9, 0):
return flash_mla_with_kvcache_sm90(
q, k_cache, block_table, cache_seqlens, head_dim_v,
tile_scheduler_metadata, num_splits,
softmax_scale, causal,
)
elif capability == (10, 0):
raise ValueError(f"Unsupported device capability: {capability}")
else:
raise ValueError(f"Unsupported device capability: {capability}")
......@@ -27,9 +27,13 @@ def get_features_args():
subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"])
cc_flag = []
cc_flag.append("-gencode")
cc_flag.append("arch=compute_90a,code=sm_90a")
cc_flag_sm90 = []
cc_flag_sm90.append("-gencode")
cc_flag_sm90.append("arch=compute_90a,code=sm_90a")
cc_flag_sm100 = []
cc_flag_sm100.append("-gencode")
cc_flag_sm100.append("arch=compute_100a,code=sm_100a")
this_dir = os.path.dirname(os.path.abspath(__file__))
......@@ -41,12 +45,12 @@ else:
ext_modules = []
ext_modules.append(
CUDAExtension(
name="flash_mla_cuda",
name="flash_mla_sm90",
sources=[
"csrc/flash_api.cpp",
"csrc/kernels/get_mla_metadata.cu",
"csrc/kernels/mla_combine.cu",
"csrc/kernels/splitkv_mla.cu",
"csrc/sm90/flash_api.cpp",
"csrc/sm90/kernels/get_mla_metadata.cu",
"csrc/sm90/kernels/mla_combine.cu",
"csrc/sm90/kernels/splitkv_mla.cu",
],
extra_compile_args={
"cxx": cxx_args + get_features_args(),
......@@ -66,12 +70,49 @@ ext_modules.append(
"--use_fast_math",
"--ptxas-options=-v,--register-usage-level=10"
]
+ cc_flag
+ cc_flag_sm90
) + get_features_args(),
},
include_dirs=[
Path(this_dir) / "csrc",
Path(this_dir) / "csrc" / "sm90",
Path(this_dir) / "csrc" / "cutlass" / "include",
],
)
)
ext_modules.append(
CUDAExtension(
name="flash_mla_sm100",
sources=[
"csrc/sm100/pybind.cu",
"csrc/sm100/fmha_cutlass_fwd_sm100.cu",
"csrc/sm100/fmha_cutlass_bwd_sm100.cu",
],
extra_compile_args={
"cxx": ["-O3", "-std=c++17", "-DNDEBUG", "-Wno-deprecated-declarations"],
"nvcc": append_nvcc_threads(
[
"-O3",
"-std=c++17",
"-DNDEBUG",
"-Wno-deprecated-declarations",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"-U__CUDA_NO_HALF2_OPERATORS__",
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
"--expt-relaxed-constexpr",
"--expt-extended-lambda",
"--use_fast_math",
"-lineinfo",
"--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage",
]
+ cc_flag_sm100
),
},
include_dirs=[
Path(this_dir) / "csrc" / "sm100",
Path(this_dir) / "csrc" / "cutlass" / "include",
Path(this_dir) / "csrc" / "cutlass" / "tools" / "util" / "include",
],
)
)
......
import random
import torch
from torch.utils.checkpoint import checkpoint
import triton
from flash_mla import flash_attn_varlen_func
def get_window_size(causal, window):
if window > 0:
window_size = (window - 1, 0) if causal else (window - 1, window - 1)
else:
window_size = (-1, -1)
return window_size
def get_attn_bias(s_q, s_k, causal, window):
attn_bias = torch.zeros(s_q, s_k, dtype=torch.float32)
if causal:
temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
if window > 0:
temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q - window)
attn_bias.masked_fill_(temp_mask, float("-inf"))
temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q + window - 1)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
return attn_bias
def assert_close(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
x, y = x.double(), y.double()
RMSE = ((x - y) * (x - y)).mean().sqrt().item()
cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12)
amax_diff = (x - y).abs().max().item()
# print(f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}")
assert cos_diff < 1e-5, f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}"
def sdpa(query, key, value, attn_bias, softmax_scale=None):
key = key.repeat_interleave(h // h_k, dim=-3)
value = value.repeat_interleave(h // h_k, dim=-3)
if softmax_scale is None:
softmax_scale = query.shape[-1] ** (-0.5)
attn_weight = query @ key.transpose(-2, -1) * softmax_scale
attn_weight += attn_bias
lse = attn_weight.logsumexp(dim=-1)
attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
return attn_weight.to(query.dtype) @ value, lse
def sdpa_checkpoint(*args, **kwargs):
return checkpoint(sdpa, *args, use_reentrant=False, **kwargs)
def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, window, has_bwd):
print(f"{b=}, {mean_sq=}, {mean_sk=}, {varlen=}, {h=}, {h_k=}, {d=}, {dv=}, {causal=}")
torch.manual_seed(0)
random.seed(0)
seqlens_q = torch.full((b,), mean_sq, dtype=torch.int32)
seqlens_k = torch.full((b,), mean_sk, dtype=torch.int32)
if varlen:
for i in range(b):
seqlens_q[i] = max(random.normalvariate(mean_sq, mean_sq / 2), 1)
for i in range(b):
seqlens_k[i] = max(random.normalvariate(mean_sk, mean_sk / 2), seqlens_q[i].item())
cu_seqlens_q = torch.cumsum(torch.nn.functional.pad(seqlens_q, (1, 0)), 0, dtype=torch.int32)
cu_seqlens_k = torch.cumsum(torch.nn.functional.pad(seqlens_k, (1, 0)), 0, dtype=torch.int32)
total_q = seqlens_q.sum().item()
total_k = seqlens_k.sum().item()
max_seqlen_q = seqlens_q.max().item()
max_seqlen_k = seqlens_k.max().item()
total_attn_compute = sum([(get_attn_bias(seqlens_q[i].item(), seqlens_k[i].item(),
causal, window) == 0).sum().item() for i in range(b)])
# print(f"{total_q=}, {max_seqlen_q=}, {total_k=}, {max_seqlen_k=}, {total_attn_compute=}, {cu_seqlens_q.tolist()}, {cu_seqlens_k.tolist()}")
q = torch.randn(total_q, h, d)
k = torch.randn(total_k, h_k, d)
v = torch.randn(total_k, h_k, dv)
grad_out = torch.randn(total_q, h, dv)
softmax_scale = (d + 100) ** (-0.5)
offst_q = total_q
offst_kv = total_k
q1_with_buffer = torch.empty(total_q + total_q, h, d, device=device, dtype=dtype)
k1_with_buffer = torch.empty(offst_kv + total_k, h_k, d, device=device, dtype=dtype)
v1_with_buffer = torch.empty(offst_kv + total_k, h_k, dv, device=device, dtype=dtype)
q1_with_buffer[total_q:] = q
k1_with_buffer[offst_kv:] = k
v1_with_buffer[offst_kv:] = v
q1 = q1_with_buffer[offst_q:].requires_grad_()
k1 = k1_with_buffer[offst_kv:].requires_grad_()
v1 = v1_with_buffer[offst_kv:].requires_grad_()
q2 = q.clone().requires_grad_()
k2 = k.clone().requires_grad_()
v2 = v.clone().requires_grad_()
def flash_attn():
q1.grad = k1.grad = v1.grad = None
kwargs = {}
if causal:
kwargs["causal"] = causal
if window != 0:
kwargs["window_size"] = get_window_size(causal, window)
return flash_attn_varlen_func(q1, k1, v1, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
max_seqlen_k, softmax_scale=softmax_scale, is_varlen=varlen, **kwargs)
def torch_attn():
q2.grad = k2.grad = v2.grad = None
out = []
lse = []
for i in range(b):
OUT, LSE = sdpa_checkpoint(
q2[cu_seqlens_q[i].item(): cu_seqlens_q[i + 1].item()].float().transpose(-3, -2),
k2[cu_seqlens_k[i].item(): cu_seqlens_k[i + 1].item()].float().transpose(-3, -2),
v2[cu_seqlens_k[i].item(): cu_seqlens_k[i + 1].item()].float().transpose(-3, -2),
attn_bias=get_attn_bias(seqlens_q[i].item(), seqlens_k[i].item(), causal, window),
softmax_scale=softmax_scale,
)
out.append(OUT.transpose(-3, -2))
lse.append(LSE.transpose(-2, -1))
out = torch.cat(out)
lse = torch.cat(lse)
return out, lse
out_flash, lse_flash = flash_attn()
out_torch, lse_torch = torch_attn()
assert_close(out_flash, out_torch, "out")
assert_close(lse_flash, lse_torch, "lse")
if has_bwd:
out_flash.backward(grad_out, retain_graph=True)
out_torch.backward(grad_out, retain_graph=True)
assert_close(q1.grad, q2.grad, "dq")
assert_close(k1.grad, k2.grad, "dk")
assert_close(v1.grad, v2.grad, "dv")
dq1 = q1.grad.clone()
dk1 = k1.grad.clone()
dv1 = v1.grad.clone()
def forward():
return flash_attn()
def backward():
q1.grad = k1.grad = v1.grad = None
out_flash.backward(grad_out, retain_graph=True)
for _ in range(5):
out, lse = forward()
assert torch.equal(out, out_flash), "out deterministic check failed!"
assert torch.equal(lse, lse_flash), "lse deterministic check failed!"
if has_bwd:
backward()
# assert torch.equal(q1.grad, dq1), "dq deterministic check failed!"
assert torch.equal(k1.grad, dk1), "dk deterministic check failed!"
assert torch.equal(v1.grad, dv1), "dv deterministic check failed!"
# with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
# forward()
# if has_bwd:
# backward()
# print(prof.key_averages().table(sort_by="cuda_time_total", max_name_column_width=120))
def timer(func, name):
t = triton.testing.do_bench(func, warmup=2, rep=3)
FLOPS = total_attn_compute * h * 2 * ((d + dv) if name == "fwd" else ((d * 3 + dv * 2)))
print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOP/s, name: {name}")
return t
timer(forward, "fwd")
if has_bwd:
timer(backward, "bwd")
if __name__ == "__main__":
dtype = torch.bfloat16
torch.set_default_dtype(dtype)
device = torch.device("cuda:0")
torch.set_default_device(device)
torch.cuda.set_device(device)
b = 4
window = 0
has_bwd = False
for (mean_sq, mean_sk) in [(4096, 4096), (8192, 8192)]:
for varlen in [False, True]:
for (h, h_k) in [(32, 32), (32, 4)]:
if h != h_k:
has_bwd = False
else:
has_bwd = True
for (d, dv) in [(128, 128), (192, 128)]:
for causal in [False, True]:
test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, window, has_bwd)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment