Commit 5248d7d2 authored by hly's avatar hly
Browse files

Import latest aicc hipcc fp8 pa snapshot.

Source: feature/aicc-hipcc-unified-attn-fp8-pa @ fc89765
parent c2a1b310
# This workflow will:
# - Create a new Github release
# - Build wheels for supported architectures
# - Deploy the wheels to the Github release
# - Release the static code to PyPi
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
name: Build wheels and deploy
on:
create:
tags:
- v*
jobs:
setup_release:
name: Create Release
runs-on: ubuntu-latest
steps:
- name: Get the tag version
id: extract_branch
run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/}
shell: bash
- name: Create Release
id: create_release
uses: actions/create-release@v1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
tag_name: ${{ steps.extract_branch.outputs.branch }}
release_name: ${{ steps.extract_branch.outputs.branch }}
build_wheels:
name: Build Wheel
needs: setup_release
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
# Using ubuntu-20.04 instead of 22.04 for more compatibility (glibc). Ideally we'd use the
# manylinux docker image, but I haven't figured out how to install CUDA on manylinux.
os: [ubuntu-20.04]
python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
torch-version: ['2.0.1', '2.1.2', '2.2.2', '2.3.1', '2.4.0.dev20240514']
cuda-version: ['11.8.0', '12.3.2']
# We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not.
# Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI.
# Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs)
# when building without C++11 ABI and using it on nvcr images.
cxx11_abi: ['FALSE', 'TRUE']
exclude:
# see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix
# Pytorch < 2.2 does not support Python 3.12
- torch-version: '2.0.1'
python-version: '3.12'
- torch-version: '2.1.2'
python-version: '3.12'
# Pytorch <= 2.0 only supports CUDA <= 11.8
- torch-version: '2.0.1'
cuda-version: '12.3.2'
steps:
- name: Checkout
uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Set CUDA and PyTorch versions
run: |
echo "MATRIX_CUDA_VERSION=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV
echo "MATRIX_TORCH_VERSION=$(echo ${{ matrix.torch-version }} | awk -F \. {'print $1 "." $2'})" >> $GITHUB_ENV
- name: Free up disk space
if: ${{ runner.os == 'Linux' }}
# https://github.com/easimon/maximize-build-space/blob/master/action.yml
# https://github.com/easimon/maximize-build-space/tree/test-report
run: |
sudo rm -rf /usr/share/dotnet
sudo rm -rf /opt/ghc
sudo rm -rf /opt/hostedtoolcache/CodeQL
- name: Set up swap space
if: runner.os == 'Linux'
uses: pierotofy/set-swap-space@v1.0
with:
swap-size-gb: 10
- name: Install CUDA ${{ matrix.cuda-version }}
if: ${{ matrix.cuda-version != 'cpu' }}
uses: Jimver/cuda-toolkit@v0.2.14
id: cuda-toolkit
with:
cuda: ${{ matrix.cuda-version }}
linux-local-args: '["--toolkit"]'
# default method is "local", and we're hitting some error with caching for CUDA 11.8 and 12.1
# method: ${{ (matrix.cuda-version == '11.8.0' || matrix.cuda-version == '12.1.0') && 'network' || 'local' }}
method: 'network'
# We need the cuda libraries (e.g. cuSparse, cuSolver) for compiling PyTorch extensions,
# not just nvcc
# sub-packages: '["nvcc"]'
- name: Install PyTorch ${{ matrix.torch-version }}+cu${{ matrix.cuda-version }}
run: |
pip install --upgrade pip
# If we don't install before installing Pytorch, we get error for torch 2.0.1
# ERROR: Could not find a version that satisfies the requirement setuptools>=40.8.0 (from versions: none)
pip install lit
# For some reason torch 2.2.0 on python 3.12 errors saying no setuptools
pip install setuptools
# We want to figure out the CUDA version to download pytorch
# e.g. we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116
# see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix
# This code is ugly, maybe there's a better way to do this.
export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \
minv = {'2.0': 117, '2.1': 118, '2.2': 118, '2.3': 118, '2.4': 118}[env['MATRIX_TORCH_VERSION']]; \
maxv = {'2.0': 118, '2.1': 121, '2.2': 121, '2.3': 121, '2.4': 121}[env['MATRIX_TORCH_VERSION']]; \
print(max(min(int(env['MATRIX_CUDA_VERSION']), maxv), minv))" \
)
if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then
pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}
else
pip install --no-cache-dir torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION}
fi
nvcc --version
python --version
python -c "import torch; print('PyTorch:', torch.__version__)"
python -c "import torch; print('CUDA:', torch.version.cuda)"
python -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)"
shell:
bash
- name: Build wheel
run: |
# We want setuptools >= 49.6.0 otherwise we can't compile the extension if system CUDA version is 11.7 and pytorch cuda version is 11.6
# https://github.com/pytorch/pytorch/blob/664058fa83f1d8eede5d66418abff6e20bd76ca8/torch/utils/cpp_extension.py#L810
# However this still fails so I'm using a newer version of setuptools
pip install setuptools==68.0.0
pip install ninja packaging wheel
export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH
export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
# Limit MAX_JOBS otherwise the github runner goes OOM
# CUDA 11.8 can compile with 2 jobs, but CUDA 12.3 goes OOM
MAX_JOBS=$([ "$MATRIX_CUDA_VERSION" == "123" ] && echo 1 || echo 2) FLASH_ATTENTION_FORCE_BUILD="TRUE" FLASH_ATTENTION_FORCE_CXX11_ABI=${{ matrix.cxx11_abi}} python setup.py bdist_wheel --dist-dir=dist
tmpname=cu${MATRIX_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ matrix.cxx11_abi }}
wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2")
ls dist/*whl |xargs -I {} mv {} dist/${wheel_name}
echo "wheel_name=${wheel_name}" >> $GITHUB_ENV
- name: Log Built Wheels
run: |
ls dist
- name: Get the tag version
id: extract_branch
run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/}
- name: Get Release with tag
id: get_current_release
uses: joutvhu/get-release@v1
with:
tag_name: ${{ steps.extract_branch.outputs.branch }}
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: Upload Release Asset
id: upload_release_asset
uses: actions/upload-release-asset@v1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
upload_url: ${{ steps.get_current_release.outputs.upload_url }}
asset_path: ./dist/${{env.wheel_name}}
asset_name: ${{env.wheel_name}}
asset_content_type: application/*
publish_package:
name: Publish package
needs: [build_wheels]
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: '3.10'
- name: Install dependencies
run: |
pip install ninja packaging setuptools wheel twine
# We don't want to download anything CUDA-related here
pip install torch --index-url https://download.pytorch.org/whl/cpu
- name: Build core package
env:
FLASH_ATTENTION_SKIP_CUDA_BUILD: "TRUE"
run: |
python setup.py sdist --dist-dir=dist
- name: Deploy
env:
TWINE_USERNAME: "__token__"
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
run: |
python -m twine upload dist/*
......@@ -28,6 +28,7 @@ venv
# benchmarks/
*.log
test_results/*.csv
# tests/
csrc/*/*.hip
......
import math
import time
import pytest
import torch
import random
import torch.nn.functional as F
import csv
from einops import rearrange, repeat
# from flash_attn import flash_attn_with_kvcache as _flash_attn_with_kvcache
from flash_attn import vllm_flash_attn_with_kvcache as _flash_attn_with_kvcache
max_seqlen=8192*5
# max_seqlen=4352
eager=True
# eager=False
def attention_ref(
q,
k,
v,
query_padding_mask=None,
key_padding_mask=None,
attn_bias=None,
dropout_p=0.0,
dropout_mask=None,
causal=False,
window_size=(-1, -1), # -1 means infinite window size
softcap=0.0,
upcast=True,
reorder_ops=False,
key_leftpad=None,
):
"""
Arguments:
q: (batch_size, seqlen_q, nheads, head_dim)
k: (batch_size, seqlen_k, nheads_k, head_dim)
v: (batch_size, seqlen_k, nheads_k, head_dim)
query_padding_mask: (batch_size, seqlen_q)
key_padding_mask: (batch_size, seqlen_k)
attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
dropout_p: float
dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
causal: whether to apply causal masking
window_size: (int, int), left and right window size
upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
output back to fp16/bf16.
reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
without changing the math. This is to estimate the numerical error from operation
reordering.
Output:
output: (batch_size, seqlen_q, nheads, head_dim)
attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout
"""
if causal:
window_size = (window_size[0], 0)
dtype_og = q.dtype
if upcast:
q, k, v = q.float(), k.float(), v.float()
seqlen_q, seqlen_k = q.shape[1], k.shape[1]
k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
d = q.shape[-1]
if not reorder_ops:
scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
else:
scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d))
if softcap > 0:
scores = scores / softcap
scores = scores.tanh()
scores = scores * softcap
if key_padding_mask is not None:
scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
if attn_bias is not None:
scores = scores + attn_bias
attention = torch.softmax(scores, dim=-1).to(v.dtype)
# Some rows might be completely masked out so we fill them with zero instead of NaN
if window_size[0] >= 0 or window_size[1] >= 0:
attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0)
# We want to mask here so that the attention matrix doesn't have any NaNs
# Otherwise we'll get NaN in dV
if query_padding_mask is not None:
attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
dropout_scaling = 1.0 / (1 - dropout_p)
# attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling
# output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
if dropout_mask is not None:
attention_drop = attention.masked_fill(~dropout_mask, 0.0)
else:
attention_drop = attention
output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
if query_padding_mask is not None:
output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
def test_flash_attn_kvcache(
seqlen_q,
seqlen_k,
d,
has_batch_idx,
has_leftpad,
paged_kv_block_size,
rotary_fraction,
rotary_interleaved,
seqlen_new_eq_seqlen_q,
causal,
local,
alibi,
new_kv,
dtype,
batch_size,
qhead,
kv_head,
prof=False,
):
# if seqlen_q > seqlen_k and new_kv:
# pytest.skip()
# if not new_kv and rotary_fraction > 0.0:
# pytest.skip()
# if has_batch_idx and paged_kv_block_size is not None:
# pytest.skip()
# if has_leftpad and paged_kv_block_size is not None:
# pytest.skip()
device = "cuda"
# set seed
torch.random.manual_seed(0)
# batch_size = 64
# nheads = 32
batch_size_cache = batch_size if not has_batch_idx else batch_size * 2
# rotary_dim must be a multiple of 16, and must be <= d
rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
q = torch.randn(batch_size, seqlen_q, qhead, d, device=device, dtype=dtype)
seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item()
nheads_k = kv_head
# alloc k v
if new_kv:
k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype)
v = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype)
else:
k, v = None, None
# 生成kvcache
if paged_kv_block_size is None:
k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)
v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)
block_table = None
else:
(
k_cache,
v_cache,
block_table,
k_cache_paged,
v_cache_paged,
num_blocks,
) = _generate_block_kvcache(
seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype
)
seq_lens = [seqlen_k for _ in range(batch_size)]
cache_seqlens = torch.tensor(seq_lens, dtype=torch.int, device=device)
if has_leftpad:
cache_leftpad = torch.cat([torch.randint(0, cache_seqlens[i].item(), (1,), dtype=torch.int32, device=device)
if cache_seqlens[i].item() > 0 else torch.zeros(1, dtype=torch.int32, device=device)
for i in range(batch_size)])
else:
cache_leftpad = None
arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s")
cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
key_padding_mask = arange < cache_seqlens_expanded + (seqlen_new if new_kv else 0)
if has_leftpad:
key_padding_mask = torch.logical_and(
key_padding_mask, arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k)
)
if has_batch_idx:
cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[
:batch_size
]
else:
cache_batch_idx = None
alibi_slopes, attn_bias = None, None
# cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device)
cos, sin = None, None
q_ro, k_ro = q, k
# k_cache[:, 64:] = -1
k_cache_ref = (
k_cache if not has_batch_idx else k_cache[cache_batch_idx.to(dtype=torch.long)]
).clone()
v_cache_ref = (
v_cache if not has_batch_idx else v_cache[cache_batch_idx.to(dtype=torch.long)]
).clone()
if new_kv:
update_mask = torch.logical_and(
cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + seqlen_new
)
k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...")
v_cache_ref[update_mask] = rearrange(v, "b s ... -> (b s) ...")
# k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
# v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads_k // nheads_k)
v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads_k // nheads_k)
q_scale = torch.tensor([0.5], dtype=torch.float32,device=device)
k_scale = torch.tensor([0.5], dtype=torch.float32,device=device)
v_scale = torch.tensor([0.25], dtype=torch.float32,device=device)
# new_type = torch.float8_e5m2
# new_type = torch.float8_e4m3fn
new_type = dtype
k_cache_paged = k_cache_paged.permute(0, 2, 1, 3).contiguous().to(new_type)
v_cache_paged = v_cache_paged.permute(0, 2, 3, 1).contiguous().to(new_type)
max_seqlen_k=seqlen_k
# max_seqlen_k=32768
# warm
for i in range(10):
out = _flash_attn_with_kvcache(
q,
k_cache if paged_kv_block_size is None else k_cache_paged,
v_cache if paged_kv_block_size is None else v_cache_paged,
cache_seqlens=cache_seqlens,
block_table=block_table,
causal=causal,
max_seqlen_k=max_seqlen_k,
q_scale=q_scale,
k_scale=k_scale,
v_scale=v_scale,
)
# prof time
torch.cuda.synchronize()
repeat_num = 100
start_time = time.time()
for i in range(repeat_num):
out = _flash_attn_with_kvcache(
q,
k_cache if paged_kv_block_size is None else k_cache_paged,
v_cache if paged_kv_block_size is None else v_cache_paged,
cache_seqlens=cache_seqlens,
block_table=block_table,
causal=causal,
max_seqlen_k=max_seqlen_k,
q_scale=q_scale,
k_scale=k_scale,
v_scale=v_scale,
)
torch.cuda.synchronize()
end_time = time.time()
fc1_espl = end_time - start_time
DCU_time = fc1_espl *1000*1000 / repeat_num
IO_bytes = batch_size*seqlen_k*kv_head*d*2*k_cache_paged.element_size() #kv cache size to read
IO_bytes += batch_size*qhead*d*q.element_size() #q size to read
IO_bytes += (seqlen_k//512+1)*batch_size*qhead*d*2*2 # temp to write and read
IO_bytes += batch_size*qhead*d*2 #output to write
IO_speed = IO_bytes/DCU_time/1024/1024/1024*1000*1000
print('FA_kvcache bs=', batch_size,' seqlen=',seqlen_k,' qhead=',qhead, ' kv_head=',kv_head, ' time is', '{:.2f}'.format(DCU_time), 'us Bandwidth=','{:.2f}'.format(IO_speed),'GB/s')
res_list = [paged_kv_block_size, batch_size, seqlen_k, d, qhead, kv_head, DCU_time,IO_speed]
# print('FA_kvcache bs=', batch_size,' seqlen=',seqlen_k,' qhead=',qhead, ' kv_head=',kv_head, ' time is', '{:.2f}'.format(DCU_time), 'us')
# res_list = [paged_kv_block_size, batch_size, seqlen_k, d, qhead, kv_head, DCU_time]
return res_list
# Check that FlashAttention's numerical error is at most twice the numerical error
# of a Pytorch implementation.
if new_kv:
if paged_kv_block_size is None:
k_cache_select = (
k_cache if not has_batch_idx else k_cache[cache_batch_idx.to(dtype=torch.long)]
)
v_cache_select = (
v_cache if not has_batch_idx else v_cache[cache_batch_idx.to(dtype=torch.long)]
)
else:
k_cache_select = rearrange(
k_cache_paged[block_table.to(dtype=torch.long).flatten()],
"(b nblocks) block_size ... -> b (nblocks block_size) ...",
b=batch_size,
)[:, :seqlen_k]
v_cache_select = rearrange(
v_cache_paged[block_table.to(dtype=torch.long).flatten()],
"(b nblocks) block_size ... -> b (nblocks block_size) ...",
b=batch_size,
)[:, :seqlen_k]
assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3)
assert torch.equal(v_cache_select, v_cache_ref)
mult = 3 if not alibi else 5
assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5
def _generate_block_kvcache(seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype):
num_blocks = 50000
k_cache_paged = torch.randn(
num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype
)
v_cache_paged = torch.randn(
num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype
)
if eager:
max_num_blocks_per_seq = (seqlen_k + paged_kv_block_size - 1) // paged_kv_block_size
else:
max_num_blocks_per_seq = (max_seqlen + paged_kv_block_size - 1) // paged_kv_block_size
block_tables = []
for _ in range(batch_size):
block_table = [
random.randint(0, num_blocks - 1)
for _ in range(max_num_blocks_per_seq)
]
block_tables.append(block_table)
block_tables = torch.tensor(block_tables, dtype=torch.int, device=device)
# # randperm torch.randperm
# block_table = rearrange(
# torch.randperm(batch_size*max_seqlen//paged_kv_block_size, dtype=torch.int32, device=device),
# "(b nblocks) -> b nblocks",
# b=batch_size,
# )
k_cache = rearrange(
# pytorch 1.12 doesn't have indexing with int32
k_cache_paged[block_tables.to(dtype=torch.long).flatten()],
"(b nblocks) block_size ... -> b (nblocks block_size) ...",
b=batch_size,
)[:, :seqlen_k]
v_cache = rearrange(
v_cache_paged[block_tables.to(dtype=torch.long).flatten()],
"(b nblocks) block_size ... -> b (nblocks block_size) ...",
b=batch_size,
)[:, :seqlen_k]
return k_cache, v_cache, block_tables, k_cache_paged, v_cache_paged, num_blocks
# mha
if __name__ == "__main__":
# HIP_VISIBLE_DEVICES=6 python test_kvcache.py
#config = [(1,16,16),(1,32,32),(1,32,4),(64,32,4),(1,52,4),(64,52,4),(1,16,2),(64,16,2),(1,26,2),(64,26,2),(1,8,1),(64,8,1),(1,13,1),(64,13,1)]
# config = [(120,6,1),(120,8,1),(120,28,4),(120,16,2),(120,20,4)]
# seq_lens=[600,1200,2400,4800]
random.seed(0)
torch.random.manual_seed(0)
# batchsize = [4,8,16,24,32,48,56,64,72,88,120]
# batchsize = [1,2,4,8,16,24,32,40,48,56,64,72,80,88,96,104]
batchsize = [1,8,32,128]
# batchsize = [128,256,512]
# batchsize = [16,24,32,40,48,56,64,72,80,88,96] #70B,235B
# batchsize = [24,32,40,48,56] #30B
# batchsize = [40,48,56,64,72,80,88,96] #8B
# head = [(32,2)]
# head = [(12,1)]
head = [(16,2),(32,8)]
# head = [(15,1),(16,1)]
# head = [(8,1),(9,1),(10,1),(11,1),(12,1),(13,1),(14,1),(15,1),(16,1),(17,1),(18,1),(19,1),(20,1),(21,1),(22,1),(23,1),(24,1),(25,1),(26,1),(27,1),(28,1),(29,1),(30,1),(31,1),(32,1)]
# head = [(4,1),(8,1),(12,1),(16,1),(24,1)]
# seq_lens=[100,400,700,1000,1300,1600,1900,2200,2500,2800,3100,3400,3700,4000,4300]
# seq_lens=[2000,2100,2200,2300,2400,2500,2600,2700]
seq_lens=[2048,8192,32768]
# seq_lens=[8192,128000]
# seq_lens=[1000,1100,1350,1500,1650,1800,2000,2300,2600,3000,3300,3500,3700,4000,4096,4100,4200,4300,4500,4700,5000]
# seq_lens=[3000,3300,3500,3800,4000,4300,4500,4800,5000]
# seq_lens=[500,700,1000,1300,2000,3000,4000,16000,18000,20000]
# seq_lens=[200,500,800,1100,1300,2000,3000,4000,5000,15000,16000,18000,20000]
# seq_lens=[200,500,800,1100,1300,2000,3000,4000,5000,16000,16500,17000,17500,18000,18500,19000,19500,20000]
# seq_lens=[16000,17000,18000,19000,20000,21000]
# heads = [8, 10, 16, 18, 20, 28, 30, 32, 38, 40, 48, 50, 58, 60, 64, 68, 70]
# batchs = [64]
# seq_lens=[1500]
dtype=torch.float16
# dtype=torch.bfloat16
print(dtype)
res_time = []
for qh,kh in head:
for bs in batchsize:
for seq in seq_lens:
# if (not (seq>=10000 and bs>16)) and seq<max_seqlen:
if True:
prof_time = test_flash_attn_kvcache(
seqlen_q=1,
seqlen_k=seq, #128 512
d=128, # 64 128 160 256
has_batch_idx=False,
has_leftpad=False,
paged_kv_block_size=64, #16 256
rotary_fraction=0.0,
rotary_interleaved=False,
seqlen_new_eq_seqlen_q=True,
causal=True, # 因果注意力机制
local=False, # 局部注意力
alibi=False,
new_kv=False,
dtype=dtype,
batch_size=bs,
qhead=qh,
kv_head=kh,
prof=False # 运行单次
)
res_time.append(prof_time)
with open('kvcache_time.csv', 'w', newline='') as csvfile:
writer = csv.writer(csvfile)
for row in res_time:
writer.writerow(row)
cutlass @ 7d49e6c7
Subproject commit 7d49e6c7e2f8896c47f586706e67e1fb215529dc
......@@ -147,7 +147,11 @@ hg_prefix_decode_varlen_fwd(
int window_size_right,
const float softcap,
const bool return_softmax,
const int layout);
const int layout,
c10::optional<at::Tensor> scales_q_,
c10::optional<at::Tensor> scales_k_,
c10::optional<at::Tensor> scales_v_,
const bool is_bf16_output);
std::vector<at::Tensor>
hg_bwd_bshd(const at::Tensor &dout,
......@@ -965,7 +969,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
alibi_slopes_, s_aux_,
skip_softmax_threshold_scale_factor,
is_causal, seqlen_q, seqlen_k,
window_size_left, window_size_right)&&(!is_bhsd)) {
window_size_left, window_size_right)) {
if (print_param || print_hg_path) {
printf("[flash_attn] HG PATH layout=%s q=(%d,%d,%d,%d) k=(%d,%d,%d,%d) v=(%d,%d,%d,%d)\n",
is_bhsd ? "bhsd" : "bshd",
......@@ -2023,7 +2027,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
if (can_use_hg_dense_bwd(
q.scalar_type(), alibi_slopes_,
head_size, head_size_value, is_causal, seqlen_q, seqlen_k,
window_size_left, window_size_right, p_dropout)&&(!is_bhsd)) {
window_size_left, window_size_right, p_dropout)) {
if (print_param || print_hg_path) {
printf("[flash_attn] HG BWD PATH layout=%s q=(%d,%d,%d,%d) k=(%d,%d,%d,%d) v=(%d,%d,%d,%d) dout=(%d,%d,%d,%d)\n",
is_bhsd ? "bhsd" : "bshd",
......@@ -4464,7 +4468,7 @@ TORCH_LIBRARY_IMPL(flash_attn2_c_op, CUDA, m) {
return std::make_tuple(results[0], results[1]);
});
}
at::Tensor mean_pool_fast(const at::Tensor &input,int blk,const c10::optional<at::Tensor> &mean);
// ============================================================================
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
......@@ -4489,7 +4493,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("varlen_bwd_attnmask", &mha_varlen_bwd_attnmask, "Backward pass (variable length), with explicit attention mask");
m.def("paged_attention", &paged_attention, "Forward pass, with KV-cache");
m.def("fwd_sparse", &mha_fwd_sparse, "Forward sparse pass");
m.def("fwd_sparse_mean_pool_fast", &mean_pool_fast, "before mha_fwd_sparse");
m.def("varlen_fwd_sparse", &mha_varlen_fwd_sparse, "Forward pass sparse (variable length)");
m.def("varlen_fwd_unified", &unified2D_attention_fwd, "Forward pass unified attn (variable length && block table)");
}
#include <torch/python.h>
#include <torch/nn/functional.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
template<typename scalar_t>
static __device__ inline void from_float(scalar_t &out ,float f){
if constexpr(std::is_same<scalar_t, _Float16>::value||std::is_same<scalar_t, float>::value){
out=f;
}
else{
uint32_t u = *(uint32_t*)(&f);
u += 0x7fff + ((u >> 16) & 1);
// u += 0x8000;
out = u>>16;
}
}
template<typename scalar_t>
static __device__ inline float to_float(scalar_t in){
if constexpr(std::is_same<scalar_t, _Float16>::value||std::is_same<scalar_t, float>::value){
return in;
}
else{
union{
uint32_t int32;
float fp32;
} u = {uint32_t(in) << 16};
return u.fp32;
}
}
#define Input_Type_SWITCH(SRC_DTYPE, ...) \
[&] { \
if (SRC_DTYPE == at::ScalarType::Half) { \
using scalar_t=_Float16; \
return __VA_ARGS__(); \
}else { \
using scalar_t=uint16_t; \
return __VA_ARGS__(); \
} \
}()
#define BLK_SWITCH(blk,...) \
[&] { \
if (blk==64){ \
constexpr static int BLK = 64; \
return __VA_ARGS__(); \
}else { \
constexpr static int BLK = 128; \
return __VA_ARGS__(); \
} \
}()
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
[&] { \
if (COND) { \
constexpr static bool CONST_NAME = true; \
return __VA_ARGS__(); \
} else { \
constexpr static bool CONST_NAME = false; \
return __VA_ARGS__(); \
} \
}()
template<typename scalar_t,int blocksize,int DIM,int BLK,bool has_mean>
__global__ void mean_pool_fast_kernel(scalar_t *out, const scalar_t *input,int L_BLOCKS,int b,int s,int h ,const scalar_t* mean){
int tid = threadIdx.x;
if(blockIdx.x<L_BLOCKS-1||s==L_BLOCKS*BLK){
const scalar_t* input_cur = input + blockIdx.z*s*h*DIM + blockIdx.y*DIM + (blockIdx.x*BLK+tid/16)*h*DIM + tid%16*8;
scalar_t* out_cur = out+blockIdx.z*h*L_BLOCKS*DIM + blockIdx.y*L_BLOCKS*DIM + blockIdx.x * DIM;
const scalar_t* mean_cur = has_mean? mean+blockIdx.z*h*DIM + blockIdx.y*DIM + tid%16*8:nullptr;
constexpr int n = DIM*BLK;
using half_vec= __attribute__( (__vector_size__(8 * sizeof(scalar_t)) )) scalar_t;
using float_vec= __attribute__( (__vector_size__(8 * sizeof(float)) )) float;
__shared__ float lds_ptr[blocksize*8];
{
float_vec sum={0.0f,0.0f,0.0f,0.0f,0.0f,0.0f,0.0f,0.0f};
half_vec mean_temp;
if constexpr(has_mean){
mean_temp = *reinterpret_cast<const half_vec*>(mean_cur);
// if(tid==0)printf("mean_temp =%.5f,%.5f,%.5f,%.5f, %.5f,%.5f,%.5f,%.5f,\n", to_float(mean_temp[0]), to_float(mean_temp[1]), to_float(mean_temp[2]), to_float(mean_temp[3])
// , to_float(mean_temp[4]), to_float(mean_temp[5]), to_float(mean_temp[6]), to_float(mean_temp[7]));
}
for(int i=0;i<n;i+=blocksize*8){
half_vec temp = *reinterpret_cast<const half_vec*>(input_cur+i*h);
for(int ii=0;ii<8;ii++){
if constexpr(has_mean){
sum[ii] += to_float(temp[ii]) - to_float(mean_temp[ii]);
}
else{
sum[ii] += to_float(temp[ii]);
}
}
}
*reinterpret_cast<float_vec*>(lds_ptr+tid*8)=sum;
__syncthreads();
}
float sum=0.0f;
for(int i=0;i<8;i++){
sum+=lds_ptr[tid+DIM*i];
}
sum/=BLK;
from_float(out_cur[tid],sum);
}
else{
int s_lenth = s % BLK;
const scalar_t* input_cur = input + blockIdx.z*s*h*DIM + blockIdx.y*DIM + (blockIdx.x*BLK)*h*DIM + tid;
scalar_t* out_cur = out+blockIdx.z*h*L_BLOCKS*DIM + blockIdx.y*L_BLOCKS*DIM + blockIdx.x * DIM;
const scalar_t* mean_cur = has_mean? mean+blockIdx.z*h*DIM + blockIdx.y*DIM + tid:nullptr;
float sum=0.0f;
float mean_temp=0.0f;
if constexpr(has_mean){
mean_temp = to_float(*(mean_cur));
}
for(int i=0;i<s_lenth;i++){
scalar_t temp = *(input_cur+i*h*DIM);
if constexpr(has_mean){
sum+=(to_float(temp)-mean_temp);
}
else{
sum+=to_float(temp);
}
}
sum /= s_lenth;
from_float(out_cur[tid],sum);
}
}
at::Tensor mean_pool_fast(const at::Tensor &input,int blk,const c10::optional<at::Tensor> &mean){
//assume dim=128
int b=input.size(0);
int s=input.size(1);
int h=input.size(2);
int d=input.size(3);
int L_BLOCKS = (s + blk - 1) / blk;
auto out = torch::empty({b, h, L_BLOCKS,d}, input.options());
auto stream = at::cuda::getCurrentCUDAStream();
dim3 grid(L_BLOCKS,h,b);
Input_Type_SWITCH(input.scalar_type(),[&]{
BLK_SWITCH(blk,[&]{
const scalar_t *mean_ptr = mean?reinterpret_cast<const scalar_t*>(mean.value().data_ptr()):nullptr;
BOOL_SWITCH(mean_ptr!=nullptr,has_mean,[&]{
const scalar_t *input_ptr = reinterpret_cast<const scalar_t*>(input.data_ptr());
scalar_t *out_ptr = reinterpret_cast<scalar_t*>(out.data_ptr());
mean_pool_fast_kernel<scalar_t,128,128,BLK,has_mean><<<grid,128,0,stream>>>(out_ptr,input_ptr,L_BLOCKS,b,s,h,mean_ptr);
});
});
});
return out;
}
\ No newline at end of file
......@@ -315,8 +315,8 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
v_scale=*v_scale_ptr;
}
const int num_queries_per_kv = num_heads / num_kv_heads;
const int head_idx=blockIdx.x*num_queries_per_kv;
const int kv_head_idx = blockIdx.x;
const int head_idx=num_queries_per_kv/mtp * kv_head_idx;
constexpr int reuse_group=(REUSE_KV_TIMES-1)/4+1;
constexpr int Mloop=(REUSE_KV_TIMES-1)/16+1;
extern __shared__ char shared_mem[];
......@@ -353,20 +353,13 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
q_zero.data[0]={0,0,0,0};
q_zero.data[1]={0,0,0,0};
scalar_t* s_q = reinterpret_cast<scalar_t*>(shared_mem);
{
int head_offset = HEAD_SIZE*num_queries_per_kv/mtp;
for(int i=thread_idx*8;i<num_queries_per_kv*HEAD_SIZE;i+=NUM_THREADS*8){
int qoffset=i/head_offset;
qoffset*=num_kv_heads*head_offset;
qoffset+=i%head_offset;
*reinterpret_cast<half4x2*>(s_q+i)=*reinterpret_cast<const half4x2*>(q_ptr+qoffset);
}
*reinterpret_cast<half4x2*>(s_q+i)=*reinterpret_cast<const half4x2*>(q_ptr+i);
}
__syncthreads();
for(int m=0;m<Mloop;m++){
int head_idx_=rowid+16*m;
for(int i=0;i<HEAD_SIZE/32;i++){
int head_idx_=rowid+16*m;
if(head_idx_<num_queries_per_kv)q_vec[m][i]=*reinterpret_cast<const half4x2*>(s_q+head_idx_*HEAD_SIZE+(i*4+rows)*8);
else q_vec[m][i]=q_zero;
}
......@@ -429,7 +422,7 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
else{
scalar_t temp;
if (mtp>1){
int casual = mtp - reuse_kv_idx * mtp / num_queries_per_kv ;
int casual = mtp - reuse_kv_idx * mtp / num_heads ;
if(token_idx+casual>seq_len)qk_vec[m][ii]=-INFINITY;
}
from_float(temp,qk_vec[m][ii]);
......@@ -650,7 +643,6 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
}
}
}
{
scalar_t* out_ptr_base;
int out_offset;
if(num_partitions>1){
......@@ -661,12 +653,10 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
out_offset=HEAD_SIZE;
out_ptr_base=out + seq_idx * num_heads * HEAD_SIZE + head_idx*HEAD_SIZE;
}
int head_offset = num_queries_per_kv/mtp;
for(int g=0;g<reuse_group;g++){
int reusekvid=g*4+rows;
if(reusekvid<num_queries_per_kv){
int out_head = reusekvid/head_offset*num_kv_heads*head_offset + reusekvid%head_offset;
scalar_t* out_ptr = out_ptr_base + out_head*out_offset;
scalar_t* out_ptr = out_ptr_base + reusekvid*out_offset;
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = rowid+16*warp_idx + i * WARP_SIZE;
from_float(*(out_ptr + row_idx), accs[reusekvid/16][i][g%4]*v_scale);
......@@ -675,14 +665,12 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
}
}
if (num_partitions>1&&thread_idx < num_queries_per_kv){
int out_head = thread_idx/head_offset*num_kv_heads*head_offset + thread_idx%head_offset;
int offset = seq_idx * num_heads * max_num_partitions + (head_idx+out_head) * max_num_partitions + partition_idx;
int offset = seq_idx * num_heads * max_num_partitions + (head_idx+thread_idx) * max_num_partitions + partition_idx;
float * exp_sums=reinterpret_cast<float*>(out_tmp);
float * max_logits=reinterpret_cast<float*>(out_tmp+max_tmp_offset);
*(exp_sums+offset)=expsum_out[thread_idx];
*(max_logits+offset)=max_out[thread_idx];
}
}
}
template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS>
......@@ -809,22 +797,19 @@ void paged_attention(
int num_kv_heads = key_cache.size(1);
int PARTITION_SIZE=512;
int reusekv=get_reusekv(num_heads,num_kv_heads);
if(reusekv>15)PARTITION_SIZE=256;
//if seq<10,the seq is invalid
if (max_seq_len<=10||(max_seq_len>=8192&&max_seq_len==max_num_blocks_per_seq*block_size)){
int meanseq = num_blocks*block_size/num_seqs+4096;
int meanseq = num_blocks*block_size/num_seqs+8192;
int maxseq = 100000000/num_seqs/headsize/num_heads*64;
if(reusekv<16) maxseq*=2;
if(reusekv<=8) maxseq*=2;
max_seq_len=MIN(max_num_blocks_per_seq*block_size,MIN(meanseq,maxseq));
}
else{
int real_reuse_times = num_heads/num_kv_heads;
int max_num_partitions=DIVIDE_ROUND_UP(max_seq_len,PARTITION_SIZE);
if(max_num_partitions*num_seqs*num_kv_heads<=160)PARTITION_SIZE=256;
if(max_num_partitions*num_seqs*num_kv_heads<=160||reusekv>15)PARTITION_SIZE=256;
if(num_seqs*num_kv_heads<=32&&max_seq_len<=32768)PARTITION_SIZE=256;
}
int real_reuse_times = num_heads/num_kv_heads;
// if(max_num_partitions*num_seqs*num_kv_heads>200&&real_reuse_times<6&&max_seq_len>30000)PARTITION_SIZE=1024;
if(PA_PARTITION_SIZE!=0)PARTITION_SIZE=PA_PARTITION_SIZE;
int max_num_partitions=DIVIDE_ROUND_UP(max_seq_len,PARTITION_SIZE);
max_num_partitions=DIVIDE_ROUND_UP(max_seq_len,PARTITION_SIZE);
static float* tmp_out_ptr = nullptr;
constexpr int temp_out_size = 110000000;
if(tmp_out_ptr == nullptr){
......@@ -896,7 +881,7 @@ void paged_attention(
int shared_mem_size=PARTITION_SIZE*2*real_reuse_times+other_use;
grid.z = max_num_partitions;
dim3 block(NUM_THREADS);
if(PA_PRINT_PARAM&&static_cast<int32_t>(query.get_device())==0)printf("is_fp8=%d,shared_mem_size=%d,HEAD_SIZE=%d,BLOCK_SIZE=%d,num_thread=%d,grid={%d,%d,%d},qhead=%d,kvhead=%d,seq=%d,batch=%d,PARTITION_SIZE=%d,max_num_partitions=%d\n",
if(PA_PRINT_PARAM)printf("is_fp8=%d,shared_mem_size=%d,HEAD_SIZE=%d,BLOCK_SIZE=%d,num_thread=%d,grid={%d,%d,%d},qhead=%d,kvhead=%d,seq=%d,batch=%d,PARTITION_SIZE=%d,max_num_partitions=%d\n",
(int)(sizeof(cache_t)==1),shared_mem_size,HEAD_SIZE,BLOCK_SIZE,NUM_THREADS,grid.x,grid.y,grid.z,num_heads,num_kv_heads,max_seq_len,num_seqs,PARTITION_SIZE,max_num_partitions);
paged_attention_kernel<scalar_t,cache_t,is_e4m3,HEAD_SIZE,BLOCK_SIZE,NUM_THREADS,REUSE_KV_TIMES><<<grid,block,shared_mem_size,stream>>>(
(scalar_t*)out_ptr,(scalar_t*)tmp_out_ptr, (scalar_t*)query_ptr,(cache_t*) key_cache_ptr, (cache_t*)value_cache_ptr,
......
......@@ -363,9 +363,8 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
}
k_scale*=q_scale;
const int num_queries_per_kv = num_heads / num_kv_heads;
const int head_idx=blockIdx.x*num_queries_per_kv;
const int kv_head_idx = blockIdx.x;
const int head_idx=num_queries_per_kv/mtp * kv_head_idx;
constexpr int reuse_group=(REUSE_KV_TIMES-1)/4+1;
constexpr int Mloop=(REUSE_KV_TIMES-1)/16+1;
extern __shared__ char shared_mem[];
......@@ -398,19 +397,12 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
}
intx4 q_vec[Mloop][HEAD_SIZE/64];
q_type* s_q = reinterpret_cast<q_type*>(shared_mem);
{
int head_offset = HEAD_SIZE*num_queries_per_kv/mtp;
for(int i=thread_idx*8;i<num_queries_per_kv*HEAD_SIZE;i+=NUM_THREADS*8){
int qoffset=i/head_offset;
qoffset*=num_kv_heads*head_offset;
qoffset+=i%head_offset;
if constexpr (q_is_fp8){
*reinterpret_cast<intx2*>(s_q+i)=*reinterpret_cast<const intx2*>(q_ptr+qoffset);
*reinterpret_cast<intx2*>(s_q+i)=*reinterpret_cast<const intx2*>(q_ptr+i);
}
else{
*reinterpret_cast<half4x2*>(s_q+i)=*reinterpret_cast<const half4x2*>(q_ptr+qoffset);
}
*reinterpret_cast<half4x2*>(s_q+i)=*reinterpret_cast<const half4x2*>(q_ptr+i);
}
}
__syncthreads();
......@@ -483,7 +475,7 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
else{
scalar_t temp;
if (mtp>1){
int casual = mtp - reuse_kv_idx * mtp / num_queries_per_kv ;
int casual = mtp - reuse_kv_idx * mtp / num_heads ;
if(token_idx+casual>seq_len)qk_vec[m][ii]=-INFINITY;
}
from_float(temp,qk_vec[m][ii]);
......@@ -688,7 +680,7 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
}
}
}
{
scalar_t* out_ptr_base;
int out_offset;
if(num_partitions>1){
......@@ -699,12 +691,10 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
out_offset=HEAD_SIZE;
out_ptr_base=out + seq_idx * num_heads * HEAD_SIZE + head_idx*HEAD_SIZE;
}
int head_offset = num_queries_per_kv/mtp;
for(int g=0;g<reuse_group;g++){
int reusekvid=g*4+rows;
if(reusekvid<num_queries_per_kv){
int out_head = reusekvid/head_offset*num_kv_heads*head_offset + reusekvid%head_offset;
scalar_t* out_ptr = out_ptr_base + out_head*out_offset;
scalar_t* out_ptr = out_ptr_base + reusekvid*out_offset;
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = rowid+16*warp_idx + i * WARP_SIZE;
from_float(*(out_ptr + row_idx), accs[reusekvid/16][i][g%4]*v_scale);
......@@ -713,14 +703,12 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
}
}
if (num_partitions>1&&thread_idx < num_queries_per_kv){
int out_head = thread_idx/head_offset*num_kv_heads*head_offset + thread_idx%head_offset;
int offset = seq_idx * num_heads * max_num_partitions + (head_idx+out_head) * max_num_partitions + partition_idx;
int offset = seq_idx * num_heads * max_num_partitions + (head_idx+thread_idx) * max_num_partitions + partition_idx;
float * exp_sums=reinterpret_cast<float*>(out_tmp);
float * max_logits=reinterpret_cast<float*>(out_tmp+max_tmp_offset);
*(exp_sums+offset)=expsum_out[thread_idx];
*(max_logits+offset)=max_out[thread_idx];
}
}
#endif
}
......
......@@ -353,11 +353,10 @@ void set_params_dropout(Flash_fwd_params &params, float p_dropout,
c10::optional<at::Generator> gen_,
at::TensorOptions opts,
at::Tensor &dropout_debug_count) {
if (p_dropout > 0) {
rng_state = at::empty({2}, opts.dtype(at::ScalarType::Long));
// Match the generic FlashAttention API contract: rng_state is returned as a
// tensor even when dropout is disabled.
// Forward kernel will populate memory with the seed and offset.
params.rng_state = reinterpret_cast<uint64_t *>(rng_state.data_ptr());
if (p_dropout > 0) {
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
gen_, at::cuda::detail::getDefaultCUDAGenerator());
// See Note [Acquire lock when using random generators]
......@@ -372,6 +371,8 @@ void set_params_dropout(Flash_fwd_params &params, float p_dropout,
params.dropout_debug_count =
reinterpret_cast<uint32_t *>(dropout_debug_count.data_ptr());
#endif
} else {
params.rng_state = nullptr;
}
}
......@@ -1636,11 +1637,16 @@ std::vector<at::Tensor> varlen_fwd_bhsd(
params.total_k = total_k;
at::Tensor rng_state;
auto options =
at::TensorOptions().dtype(at::ScalarType::Float).device(at::DeviceType::CUDA);
if (p_dropout > 0) {
auto options = at::TensorOptions()
.dtype(at::ScalarType::Float)
.device(at::DeviceType::CUDA);
rng_state = at::empty({2}, options.dtype(at::ScalarType::Long));
// Keep the return tuple compatible with the generic FlashAttention path.
// Forward kernel will populate memory with the seed and offset.
params.rng_state = reinterpret_cast<uint64_t *>(rng_state.data_ptr());
} else {
params.rng_state = nullptr;
}
set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
......@@ -1705,10 +1711,12 @@ std::vector<at::Tensor> hg_prefix_prefill_varlen_fwd(
auto q_dtype = q.dtype();
const bool int8_used = q_dtype == at::ScalarType::Char;
const bool fp8_used = q_dtype == at::ScalarType::Float8_e4m3fn;
TORCH_CHECK(q_dtype == at::ScalarType::Half ||
q_dtype == at::ScalarType::BFloat16 ||
q_dtype == at::ScalarType::Char,
"FlashAttention only support fp16 and bf16 and int8 data type");
q_dtype == at::ScalarType::Char ||
q_dtype == at::ScalarType::Float8_e4m3fn,
"FlashAttention only support fp16 and bf16 and int8 and fp8 data type");
TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
TORCH_CHECK(cu_seqlens_q.dtype() == at::ScalarType::Int,
......@@ -1752,6 +1760,13 @@ std::vector<at::Tensor> hg_prefix_prefill_varlen_fwd(
(head_size_og == 256 and head_size_value == 256),
"Prefix prefill only supports head dimension "
"128+128/192+128/192+192/256+256");
if (fp8_used) {
TORCH_CHECK(head_size_og == 128 and head_size_value == 128,
"FP8 prefix prefill only supports head dimension 128+128 on gfx938");
TORCH_CHECK(scales_q_.has_value() && scales_k_.has_value() &&
scales_v_.has_value(),
"FP8 prefix prefill requires q/k/v descale tensors");
}
TORCH_CHECK(
num_heads % num_heads_k == 0,
"Number of heads in key/value must divide number of heads in query");
......@@ -1798,9 +1813,13 @@ std::vector<at::Tensor> hg_prefix_prefill_varlen_fwd(
at::Tensor out;
if (out_.has_value()) {
out = out_.value();
if (!int8_used) {
if (!int8_used && !fp8_used) {
TORCH_CHECK(out.dtype() == q_dtype,
"Output must have the same dtype as inputs");
} else if (fp8_used) {
TORCH_CHECK(out.dtype() == at::ScalarType::Half ||
out.dtype() == at::ScalarType::BFloat16,
"FP8 prefix prefill output must be fp16 or bf16");
}
CHECK_DEVICE(out);
TORCH_CHECK(out.stride(-1) == 1,
......@@ -1810,7 +1829,7 @@ std::vector<at::Tensor> hg_prefix_prefill_varlen_fwd(
}
} else {
// for (bs)hd layout
if (int8_used) {
if (int8_used || fp8_used) {
auto int8_opts = is_bf16_output ? opts.dtype(at::ScalarType::BFloat16)
: opts.dtype(at::ScalarType::Half);
out = at::empty({query_size[0], query_size[1], head_size_v_rounded},
......@@ -1876,13 +1895,37 @@ std::vector<at::Tensor> hg_prefix_prefill_varlen_fwd(
params.scales_q_ptr = scales_q.data_ptr();
params.total_scale_q = scales_q.numel();
}
if (fp8_used) {
params.is_bf16 = out.dtype() == at::ScalarType::BFloat16;
params.is_e4m3 = true;
at::Tensor scales_q;
scales_q = scales_q_.value();
params.q_descale_ptr = reinterpret_cast<float*>(scales_q.data_ptr());
params.q_descale_batch_stride = scales_q.stride(0);
params.q_descale_head_stride = scales_q.stride(1);
at::Tensor scales_k;
scales_k = scales_k_.value();
params.k_descale_ptr = reinterpret_cast<float*>(scales_k.data_ptr());
params.k_descale_batch_stride = scales_k.stride(0);
params.k_descale_head_stride = scales_k.stride(1);
at::Tensor scales_v;
scales_v = scales_v_.value();
params.v_descale_ptr = reinterpret_cast<float*>(scales_v.data_ptr());
params.v_descale_batch_stride = scales_v.stride(0);
params.v_descale_head_stride = scales_v.stride(1);
}
at::Tensor rng_state;
auto options =
at::TensorOptions().dtype(at::ScalarType::Float).device(at::DeviceType::CUDA);
if (p_dropout > 0) {
auto options = at::TensorOptions()
.dtype(at::ScalarType::Float)
.device(at::DeviceType::CUDA);
rng_state = at::empty({2}, options.dtype(at::ScalarType::Long));
// Keep the return tuple compatible with the generic FlashAttention path.
// Forward kernel will populate memory with the seed and offset.
params.rng_state = reinterpret_cast<uint64_t *>(rng_state.data_ptr());
} else {
params.rng_state = nullptr;
}
set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
......@@ -3969,7 +4012,11 @@ std::vector<at::Tensor> hg_prefix_decode_varlen_fwd(
const int max_seqlen_q, const int max_seqlen_k, const float p_dropout,
const float softmax_scale, const bool zero_tensors, const bool is_causal,
int window_size_left, int window_size_right, const float softcap,
const bool return_softmax, const int layout) {
const bool return_softmax, const int layout,
c10::optional<at::Tensor> scales_q_ = c10::nullopt,
c10::optional<at::Tensor> scales_k_ = c10::nullopt,
c10::optional<at::Tensor> scales_v_ = c10::nullopt,
const bool is_bf16_output = false ) {
#if defined(BUILD_FA_KVCACHE)
const at::cuda::HIPGuardMasqueradingAsCUDA device_guard(q.device().index());
// TORCH_CHECK(is_causal == true, "For prefix decode, only causal mask = True
......@@ -3979,9 +4026,11 @@ std::vector<at::Tensor> hg_prefix_decode_varlen_fwd(
}
auto q_dtype = q.dtype();
const bool fp8_used = q_dtype == at::ScalarType::Float8_e4m3fn;
TORCH_CHECK(q_dtype == at::ScalarType::Half ||
q_dtype == at::ScalarType::BFloat16,
"For prefix decode, only support fp16 and bf16 data type");
q_dtype == at::ScalarType::BFloat16 ||
q_dtype == at::ScalarType::Float8_e4m3fn,
"For prefix decode, only support fp16/bf16/fp8_e4m3 data type");
TORCH_CHECK(k.dtype() == q_dtype,
"For prefix decode, query and key must have the same dtype");
TORCH_CHECK(v.dtype() == q_dtype,
......@@ -4030,6 +4079,14 @@ std::vector<at::Tensor> hg_prefix_decode_varlen_fwd(
(head_size_og == 256 and head_size_value == 256),
"For prefix decode, only supports head dimension "
"128+128/192+128/192+192/256+256");
if (fp8_used) {
TORCH_CHECK((head_size_og == 128 and head_size_value == 128) or
(head_size_og == 256 and head_size_value == 256),
"For fp8 prefix decode, only supports head dimension "
"128+128/256+256 on gfx938 MLS kernel");
TORCH_CHECK(scales_q_.has_value() && scales_k_.has_value() && scales_v_.has_value(),
"For fp8 prefix decode, q/k/v descale tensors must be provided");
}
TORCH_CHECK(
num_heads % num_heads_k == 0,
"Number of heads in key/value must divide number of heads in query");
......@@ -4096,6 +4153,13 @@ std::vector<at::Tensor> hg_prefix_decode_varlen_fwd(
bool output_allocated_outside = out_.has_value();
if (output_allocated_outside) {
out = out_.value();
if (!fp8_used){
TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
} else {
TORCH_CHECK(out.dtype() == at::ScalarType::Half ||
out.dtype() == at::ScalarType::BFloat16,
"For fp8 prefix decode, output must be fp16 or bf16");
}
if (out.is_contiguous()) {
out = out.view({q.size(0), q.size(1), -1});
CHECK_DEVICE(out);
......@@ -4106,8 +4170,13 @@ std::vector<at::Tensor> hg_prefix_decode_varlen_fwd(
}
} else {
// for (bs)hd layout
if (fp8_used) {
auto fp8_opts = is_bf16_output ? opts.dtype(at::ScalarType::BFloat16) : opts.dtype(at::ScalarType::Half);
out = at::empty({q.size(0), q.size(1), head_size_v_rounded}, fp8_opts);
} else {
out = at::empty({q.size(0), q.size(1), v_padded.size(-1)}, opts);
}
}
auto softmax_lse =
at::empty({num_heads * ngroups, total_q}, opts.dtype(at::kFloat));
......@@ -4138,10 +4207,30 @@ std::vector<at::Tensor> hg_prefix_decode_varlen_fwd(
params.page_block_size = page_block_size;
params.seqused_k = reinterpret_cast<int *>(seqused_k.data_ptr());
params.layout = 1; // only bshd (layout = 1) is supported yet
// params.mtp = 1; // only mtp = 1 is supported yet
params.mtp = max_seqlen_q;
params.seqlen_q *= ngroups;
params.ngroups = ngroups;
params.seqlenq_ngroups_swapped = ngroups > 1;
if (fp8_used) {
params.is_bf16 = out.dtype() == at::ScalarType::BFloat16;
params.is_e4m3 = true;
at::Tensor scales_q;
scales_q = scales_q_.value();
params.q_descale_ptr = reinterpret_cast<float*>(scales_q.data_ptr());
params.q_descale_batch_stride = scales_q.stride(0);
params.q_descale_head_stride = scales_q.stride(1);
at::Tensor scales_k;
scales_k = scales_k_.value();
params.k_descale_ptr = reinterpret_cast<float*>(scales_k.data_ptr());
params.k_descale_batch_stride = scales_k.stride(0);
params.k_descale_head_stride = scales_k.stride(1);
at::Tensor scales_v;
scales_v = scales_v_.value();
params.v_descale_ptr = reinterpret_cast<float*>(scales_v.data_ptr());
params.v_descale_batch_stride = scales_v.stride(0);
params.v_descale_head_stride = scales_v.stride(1);
}
set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
at::Tensor softmax_lseaccum;
......@@ -4149,6 +4238,8 @@ std::vector<at::Tensor> hg_prefix_decode_varlen_fwd(
hipDeviceProp_t props;
auto hipResult = hipGetDeviceProperties(&props, 0);
params.cu_count = props.multiProcessorCount;
params.num_splits = 1;
if (getArch() >= 938) {
if (batch_size * params.h < params.cu_count / 2 and
(head_size_value == 128 or head_size_value == 64)) {
params.partition_size = PA_FIX_PARTITION;
......@@ -4162,11 +4253,13 @@ std::vector<at::Tensor> hg_prefix_decode_varlen_fwd(
at::empty({params.num_splits, num_heads * ngroups, total_q},
opts.dtype(at::kFloat));
out_accum = at::empty(
{params.num_splits, out.size(0), out.size(1), out.size(2)}, opts);
{params.num_splits, out.size(0), out.size(1), out.size(2)},
fp8_used ? out.options() : opts);
params.softmax_lseaccum_ptr =
reinterpret_cast<float *>(softmax_lseaccum.data_ptr());
params.oaccum_ptr = out_accum.data_ptr();
}
}
const char *fa_debug = std::getenv("FA_DEBUG");
if (fa_debug != nullptr) {
......
......@@ -20,7 +20,15 @@ void run_mha_fwd(Flash_fwd_params &params, hipStream_t stream, bool force_split_
}
if (params.seqused_k != nullptr) {
// Prefix prefill attention
if (!params.is_int8){
if (params.is_e4m3) {
FP16_SWITCH(!params.is_bf16, [&] {
if (params.d == 128 and params.d_value == 128) {
run_fp8_mha_fwd_prefix_prefill_<elem_type, 128, 128>(params, stream);
} else {
assert(false && "FP8 prefix prefill only supports head_dim=128");
}
});
} else if (!params.is_int8){
FP16_SWITCH(!params.is_bf16, [&] {
if (params.d == 128 and params.d_value == 128) {
run_mha_fwd_prefix_prefill_<elem_type, 128, 128>(params, stream);
......
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment