Commit 4f83cf8f authored by Junxian

[release] v0.0.1

# This workflow will:
# - Create a new Github release
# - Build wheels for supported architectures
# - Deploy the wheels to the Github release
# - Release the static code to PyPi
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
name: Build wheels and deploy
on:
create:
tags:
- v*
jobs:
setup_release:
name: Create Release
runs-on: ubuntu-latest
steps:
- name: Get the tag version
id: extract_branch
run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/}
shell: bash
- name: Create Release
id: create_release
uses: actions/create-release@v1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
tag_name: ${{ steps.extract_branch.outputs.branch }}
release_name: ${{ steps.extract_branch.outputs.branch }}
build_wheels:
name: Build Wheel
needs: setup_release
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
# Using ubuntu-20.04 instead of 22.04 for more compatibility (glibc). Ideally we'd use the
# manylinux docker image, but I haven't figured out how to install CUDA on manylinux.
os: [ubuntu-20.04]
python-version: ['3.7', '3.8', '3.9', '3.10', '3.11']
torch-version: ['1.12.1', '1.13.1', '2.0.1', '2.1.1', '2.2.0.dev20231106']
cuda-version: ['11.8.0', '12.2.0']
# We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not.
# Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI.
# Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs)
# when building without C++11 ABI and using it on nvcr images.
cxx11_abi: ['FALSE', 'TRUE']
exclude:
# Pytorch <= 1.12 does not support Python 3.11
- torch-version: '1.12.1'
python-version: '3.11'
# Pytorch >= 2.0 only supports Python >= 3.8
- torch-version: '2.0.1'
python-version: '3.7'
- torch-version: '2.1.1'
python-version: '3.7'
- torch-version: '2.2.0.dev20231106'
python-version: '3.7'
# Pytorch <= 2.0 only supports CUDA <= 11.8
- torch-version: '1.12.1'
cuda-version: '12.2.0'
- torch-version: '1.13.1'
cuda-version: '12.2.0'
- torch-version: '2.0.1'
cuda-version: '12.2.0'
steps:
- name: Checkout
uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Set CUDA and PyTorch versions
run: |
echo "MATRIX_CUDA_VERSION=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV
echo "MATRIX_TORCH_VERSION=$(echo ${{ matrix.torch-version }} | awk -F \. {'print $1 "." $2'})" >> $GITHUB_ENV
- name: Free up disk space
if: ${{ runner.os == 'Linux' }}
# https://github.com/easimon/maximize-build-space/blob/master/action.yml
# https://github.com/easimon/maximize-build-space/tree/test-report
run: |
sudo rm -rf /usr/share/dotnet
sudo rm -rf /opt/ghc
sudo rm -rf /opt/hostedtoolcache/CodeQL
- name: Set up swap space
if: runner.os == 'Linux'
uses: pierotofy/set-swap-space@v1.0
with:
swap-size-gb: 10
- name: Install CUDA ${{ matrix.cuda-version }}
if: ${{ matrix.cuda-version != 'cpu' }}
uses: Jimver/cuda-toolkit@v0.2.11
id: cuda-toolkit
with:
cuda: ${{ matrix.cuda-version }}
linux-local-args: '["--toolkit"]'
# default method is "local", and we're hitting some error with caching for CUDA 11.8 and 12.1
# method: ${{ (matrix.cuda-version == '11.8.0' || matrix.cuda-version == '12.1.0') && 'network' || 'local' }}
method: 'network'
# We need the cuda libraries (e.g. cuSparse, cuSolver) for compiling PyTorch extensions,
# not just nvcc
# sub-packages: '["nvcc"]'
- name: Install PyTorch ${{ matrix.torch-version }}+cu${{ matrix.cuda-version }}
run: |
pip install --upgrade pip
# If we don't install lit before installing Pytorch, we get an error for torch 2.0.1
# ERROR: Could not find a version that satisfies the requirement setuptools>=40.8.0 (from versions: none)
pip install lit
# We want to figure out the CUDA version to download pytorch
# e.g. we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116
# This code is ugly, maybe there's a better way to do this.
export TORCH_CUDA_VERSION=$(python -c "import os; minv = {'1.12': 113, '1.13': 116, '2.0': 117, '2.1': 118, '2.2': 118}[os.environ['MATRIX_TORCH_VERSION']]; maxv = {'1.12': 116, '1.13': 117, '2.0': 118, '2.1': 121, '2.2': 121}[os.environ['MATRIX_TORCH_VERSION']]; print(max(min(int(os.environ['MATRIX_CUDA_VERSION']), maxv), minv))")
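# Example (illustrative): torch 2.1 with system CUDA 12.2 gives MATRIX_TORCH_VERSION=2.1 and
# MATRIX_CUDA_VERSION=122, so TORCH_CUDA_VERSION is clamped into [118, 121] and resolves to 121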
if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then
pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}
else
pip install --no-cache-dir torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION}
fi
nvcc --version
python --version
python -c "import torch; print('PyTorch:', torch.__version__)"
python -c "import torch; print('CUDA:', torch.version.cuda)"
python -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)"
shell:
bash
- name: Build wheel
run: |
# We want setuptools >= 49.6.0 otherwise we can't compile the extension if system CUDA version is 11.7 and pytorch cuda version is 11.6
# https://github.com/pytorch/pytorch/blob/664058fa83f1d8eede5d66418abff6e20bd76ca8/torch/utils/cpp_extension.py#L810
# However this still fails so I'm using a newer version of setuptools
pip install setuptools==68.0.0
pip install ninja packaging wheel
export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH
export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
# Limit MAX_JOBS otherwise the github runner goes OOM
MAX_JOBS=2 FLASH_ATTENTION_FORCE_BUILD="TRUE" FLASH_ATTENTION_FORCE_CXX11_ABI=${{ matrix.cxx11_abi}} python setup.py bdist_wheel --dist-dir=dist
tmpname=cu${MATRIX_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ matrix.cxx11_abi }}
wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2")
ls dist/*whl |xargs -I {} mv {} dist/${wheel_name}
echo "wheel_name=${wheel_name}" >> $GITHUB_ENV
- name: Log Built Wheels
run: |
ls dist
- name: Get the tag version
id: extract_branch
run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/}
- name: Get Release with tag
id: get_current_release
uses: joutvhu/get-release@v1
with:
tag_name: ${{ steps.extract_branch.outputs.branch }}
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: Upload Release Asset
id: upload_release_asset
uses: actions/upload-release-asset@v1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
upload_url: ${{ steps.get_current_release.outputs.upload_url }}
asset_path: ./dist/${{env.wheel_name}}
asset_name: ${{env.wheel_name}}
asset_content_type: application/*
publish_package:
name: Publish package
needs: [build_wheels]
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: '3.10'
- name: Install dependencies
run: |
pip install ninja packaging setuptools wheel twine
# We don't want to download anything CUDA-related here
pip install torch --index-url https://download.pytorch.org/whl/cpu
- name: Build core package
env:
FLASH_ATTENTION_SKIP_CUDA_BUILD: "TRUE"
run: |
python setup.py sdist --dist-dir=dist
- name: Deploy
env:
TWINE_USERNAME: "__token__"
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
run: |
python -m twine upload dist/*
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
# C extensions
*.so
# Distribution / packaging
bin/
build/
develop-eggs/
dist/
eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg
# IDE-related
.idea/
# Dev
venv
block_sparse_tests/test_correctness/bwd_test/log/*
block_sparse_tests/test_correctness/fwd_test/log/*
block_sparse_tests/test_correctness/full_test/log/*
block_sparse_tests/test_correctness/bwd_test/log/
block_sparse_tests/test_correctness/fwd_test/log/
block_sparse_tests/test_correctness/full_test/log/
block_sparse_tests/test_correctness/full_test/log3.txt
block_sparse_tests/test_correctness/full_test/test
block_sparse_tests/test_correctness/full_test/test.cpp
block_sparse_tests/test_correctness/full_test/tmp.txt
block_sparse_tests/test_correctness/full_test/tmp2.txt
block_sparse_tests/test_correctness/full_test/arxiv/block_bwd.cu
block_sparse_tests/test_correctness/full_test/arxiv/block.cu
block_sparse_tests/test_correctness/full_test/arxiv/mask_compress.ipynb
block_sparse_tests/test_correctness/full_test/arxiv/tmp.py
block_sparse_tests/test_performance/fwd_test/log_0013.txt
block_sparse_tests/test_performance/fwd_test/log_split_mask_process_all_sparse_a11.txt
block_sparse_tests/test_performance/fwd_test/log_split_mask_process_all_sparse_l21.txt
block_sparse_tests/test_performance/fwd_test/log_split_mask_process_l21.txt
block_sparse_tests/test_performance/fwd_test/log_using_olog_all_dense_l21.txt
block_sparse_tests/test_performance/fwd_test/log_using_olog_l21.txt
block_sparse_tests/test_performance/fwd_test/fig/hdim128_nheads32_bts1_fwd.png
block_sparse_tests/test_performance/fwd_test/fig/hdim128_nheads32_bts1_fwd.png
block_sparse_tests/test_performance/fwd_test/fig/streaming/hdim128_nheads32_bts1_sink1_local3_fwd.png
block_sparse_tests/test_performance/fwd_test/old_fig/hdim128_nheads32_bts8_fwd.png
block_sparse_tests/test_performance/fwd_test/old_fig/hdim128_nheads32_bts1_fwd.png
*.xlsx
[submodule "csrc/cutlass"]
path = csrc/cutlass
url = https://github.com/NVIDIA/cutlass.git
BSD 3-Clause License
Copyright (c) 2024, MIT HAN Lab
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
recursive-include csrc *.cu
recursive-include csrc *.h
recursive-include csrc *.cuh
recursive-include csrc *.cpp
recursive-include csrc *.hpp
recursive-include block_flash_attn *.cu
recursive-include block_flash_attn *.h
recursive-include block_flash_attn *.cuh
recursive-include block_flash_attn *.cpp
recursive-include block_flash_attn *.hpp
clean_dist:
rm -rf dist/*
create_dist: clean_dist
python setup.py sdist
upload_package: create_dist
twine upload dist/*
# Block Sparse Attention
As prompt lengths continue to increase, the computational and memory-bandwidth demands of Large Language Models (LLMs) grow significantly, making efficient processing more challenging. However, by fully leveraging the inherent sparsity in attention patterns, we can effectively reduce the computational cost of inference. This approach not only enhances the efficiency of LLMs but also enables them to handle longer and more complex prompts without a proportional increase in resource consumption. To this end, we introduce Block Sparse Attention, a library of sparse attention kernels that supports various sparse patterns, including streaming attention with token granularity, streaming attention with block granularity, and block-sparse attention. By incorporating these patterns, Block Sparse Attention can significantly reduce the computational costs of LLMs, thereby enhancing their efficiency and scalability.
We release the implementation of Block Sparse Attention, which is modified based on [FlashAttention](https://github.com/Dao-AILab/flash-attention) 2.4.2.
![Sparse Patterns](assets/BlockSparseMaskDemo.jpeg)
## News
- [2024/10] We release both the forward and backward passes of Block Sparse Attention.
## Features
Block Sparse Attention supports four patterns:
1. Dense attention
Computes the full attention matrix.
2. Streaming attention with token granularity
Computes attention with a fixed number of sink tokens and local tokens. See [StreamingLLM](https://arxiv.org/abs/2309.17453) for more details.
3. Streaming attention with block granularity, block_size = 128
Computes attention with a fixed number of sink blocks and local blocks.
4. Blocksparse attention, block_size = 128
Takes a block mask and computes attention under that mask.
**Importantly, we support assigning different patterns to different heads.**
Use `head_mask_type` to specify the pattern for each head: a sequence with one integer per query head.
For a given head, `mask_type = 0` means dense attention, `mask_type = -1` means streaming attention (either block streaming or exact streaming), and `mask_type = 1` means blocksparse attention, in which case the head uses `basemask[mask_type - 1]` as its attention mask.
For example, if you have 8 heads and
```python
head_mask_type = [1, 1, 0, 0, 0, -1, 0, -1]
```
This means heads 0 and 1 use the blocksparse mask, heads 2 to 4 and head 6 use the dense mask, and heads 5 and 7 use the streaming mask. A complete usage sketch follows the interface listings below.
The interface is:
```python
from block_sparse_attn import block_sparse_attn_func
block_sparse_attn_func(
q_unpad, k_unpad, v_unpad,
cu_seqlens_q, cu_seqlens_k,
head_mask_type,
streaming_info,
base_blockmask,
max_seqlen_q_, max_seqlen_k_,
p_dropout,
deterministic=False,
softmax_scale=None,
is_causal=False,
exact_streaming=False,
return_attn_probs=False,
)
```
```python
from block_sparse_attn import block_streaming_attn_func
block_streaming_attn_func(
q_unpad, k_unpad, v_unpad,
cu_seqlens_q, cu_seqlens_k,
head_mask_type,
streaming_info,
max_seqlen_q, max_seqlen_k,
p_dropout,
deterministic=False,
softmax_scale=None,
is_causal=True,
return_attn_probs=False,
)
```
```python
from block_sparse_attn import token_streaming_attn_func
# bwd pass is not yet supported
token_streaming_attn_func(
q_unpad, k_unpad, v_unpad,
cu_seqlens_q, cu_seqlens_k,
head_mask_type,
streaming_info,
max_seqlen_q, max_seqlen_k,
deterministic=False,
softmax_scale=None,
return_attn_probs=False,
)
```
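Putting the pieces together, here is a minimal usage sketch for `block_sparse_attn_func`. The tensor shapes and the exact layouts of `streaming_info` and `base_blockmask` are illustrative assumptions inferred from the interfaces above, not a definitive specification:
```python
import torch
from block_sparse_attn import block_sparse_attn_func

# Hypothetical setup: 2 variable-length sequences packed into one tensor,
# 8 query heads, head_dim 128; the kernel's block size is 128.
nheads, head_dim, block_size = 8, 128, 128
seqlens = [1024, 2048]
total_tokens = sum(seqlens)
device, dtype = "cuda", torch.float16

q = torch.randn(total_tokens, nheads, head_dim, device=device, dtype=dtype)
k = torch.randn(total_tokens, nheads, head_dim, device=device, dtype=dtype)
v = torch.randn(total_tokens, nheads, head_dim, device=device, dtype=dtype)
cu_seqlens = torch.tensor([0, 1024, 3072], device=device, dtype=torch.int32)
max_seqlen = max(seqlens)

# Per-head pattern: heads 0-1 blocksparse (1), heads 2-4 and 6 dense (0),
# heads 5 and 7 streaming (-1).
head_mask_type = torch.tensor([1, 1, 0, 0, 0, -1, 0, -1], device=device, dtype=torch.int32)
# Assumed layout: (sink blocks, local blocks) per head, flattened.
streaming_info = torch.tensor([1, 3] * nheads, device=device, dtype=torch.int32)
# Assumed layout: one 0-1 block mask per blocksparse head,
# shape (batch, num_blocksparse_heads, seqlen_q // block_size, seqlen_k // block_size).
num_blocks = (max_seqlen + block_size - 1) // block_size
base_blockmask = torch.randint(
    0, 2, (len(seqlens), 2, num_blocks, num_blocks), device=device, dtype=torch.uint8
)

out = block_sparse_attn_func(
    q, k, v,
    cu_seqlens, cu_seqlens,
    head_mask_type,
    streaming_info,
    base_blockmask,
    max_seqlen, max_seqlen,
    p_dropout=0.0,
    deterministic=False,
    softmax_scale=None,
    is_causal=True,
    exact_streaming=False,
    return_attn_probs=False,
)  # out has the same packed layout as q
```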
## Performance
### Block Sparse Speedup
<div align=center><img src="assets/BlocksparseSpeedUp.jpeg"></div>
<div align=center><img src="assets/BlocksparseSpeedUpFwdBwd.jpeg"></div>
The figures above illustrate the speedup of Block Sparse Attention over dense FlashAttention-2 (v2.4.2). The speedup was measured on an A100 GPU with a head dimension of 128 and 32 attention heads.
### Dense & Streaming Hybrid Speedup
[Duo Attention](https://github.com/mit-han-lab/duo-attention) introduces a hybrid mask scenario, where half of the attention heads use a dense mask and the other half employ a streaming mask. This pattern has also been shown to be an accurate approach for LLM inference.
<div align=center><img src="assets/StreamingHybridSpeedUpRatio.jpeg"></div>
The graph above demonstrates the performance of our kernel for this specified workload. For token-level streaming masks, we allocate 64 sink tokens and 256 local tokens. For block-level streaming masks, we allocate 1 sink block and 3 local blocks, with each block consisting of 128 tokens. Speedup results were measured on an A100 GPU, using dense FlashAttention2 as the baseline, with a head dimension of 128, 32 attention heads, and a batch size of 1.
## Installation
Requirements:
- CUDA 11.6 and above.
- PyTorch 1.12 and above.
- Linux.
```sh
pip install packaging
pip install ninja
python setup.py install
```
Block Sparse Interface: `block_sparse_attn/block_sparse_attn_interface.py`
Block Sparse Attention currently supports:
1. Datatypes fp16 and bf16 (bf16 requires Ampere, Ada, or Hopper GPUs).
2. Head dimensions 32, 64, and 128.
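A quick post-install sanity check (a minimal sketch; assumes a CUDA GPU is visible):
```python
import torch
import block_sparse_attn

print(block_sparse_attn.__version__, torch.cuda.get_device_name(0))
```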
### Tests
To run the correctness tests:
```sh
pip install pytest
```
- For fwd only
```sh
cd ./block_sparse_tests/fwd/test_correctness
pytest full_test.py
```
- For fwd and bwd
```sh
cd ./block_sparse_tests/fwd_bwd/test_correctness
pytest full_test.py
```
To run the performance tests:
- For fwd only
```sh
cd ./block_sparse_tests/fwd/test_performance/
python token_streaming.py
python blocksparse.py
```
- For fwd and bwd
```sh
cd ./block_sparse_tests/fwd_bwd/test_performance/
python block_streaming.py
python blocksparse.py
```
## Team
- Junxian Guo, developer
## Acknowledgement
- [FlashAttention](https://github.com/Dao-AILab/flash-attention): the codebase we built upon. Thanks for their wonderful work. The design of block sparse attention in FlashAttention v1.0 is very inspiring.
- [FlashAttention](https://arxiv.org/abs/2205.14135), [FlashAttention-2](https://arxiv.org/abs/2307.08691), [Big Bird](https://arxiv.org/abs/2007.14062), [ETC](https://arxiv.org/abs/2004.08483): where we got the idea of block sparse attention and how it can be implemented.
- [StreamingLLM](https://arxiv.org/abs/2309.17453): where we got the idea of streaming attention.
- [Duo Attention](https://github.com/mit-han-lab/duo-attention), [MInference 1.0](https://arxiv.org/abs/2407.02490): where we got the idea of hybrid masks.
## Citation
__version__ = "0.0.1"
from block_sparse_attn.flash_attn_interface import (
flash_attn_func,
flash_attn_kvpacked_func,
flash_attn_qkvpacked_func,
flash_attn_varlen_func,
flash_attn_varlen_kvpacked_func,
flash_attn_varlen_qkvpacked_func,
flash_attn_with_kvcache,
)
from block_sparse_attn.block_sparse_attn_interface import (
block_sparse_attn_func,
token_streaming_attn_func,
block_streaming_attn_func,
)
# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py
# Get from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
class IndexFirstAxis(torch.autograd.Function):
@staticmethod
def forward(ctx, input, indices):
ctx.save_for_backward(indices)
assert input.ndim >= 2
ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
second_dim = other_shape.numel()
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
# return input[indices]
return torch.gather(
rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)
).reshape(-1, *other_shape)
@staticmethod
def backward(ctx, grad_output):
(indices,) = ctx.saved_tensors
assert grad_output.ndim >= 2
other_shape = grad_output.shape[1:]
grad_output = rearrange(grad_output, "b ... -> b (...)")
grad_input = torch.zeros(
[ctx.first_axis_dim, grad_output.shape[1]],
device=grad_output.device,
dtype=grad_output.dtype,
)
# TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
# grad_input[indices] = grad_output
grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output)
return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
index_first_axis = IndexFirstAxis.apply
class IndexPutFirstAxis(torch.autograd.Function):
@staticmethod
def forward(ctx, values, indices, first_axis_dim):
ctx.save_for_backward(indices)
assert indices.ndim == 1
assert values.ndim >= 2
output = torch.zeros(
first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype
)
# TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
output[indices] = values
# output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
return output
@staticmethod
def backward(ctx, grad_output):
(indices,) = ctx.saved_tensors
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
grad_values = grad_output[indices]
# grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))
return grad_values, None, None
index_put_first_axis = IndexPutFirstAxis.apply
class IndexFirstAxisResidual(torch.autograd.Function):
@staticmethod
def forward(ctx, input, indices):
ctx.save_for_backward(indices)
assert input.ndim >= 2
ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
second_dim = other_shape.numel()
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
output = input[indices]
# We don't want to reshape input (b ... -> b (...)) since it could change the channel_last
# memory format to channel_first. In other words, input might not be contiguous.
# If we don't detach, Pytorch complains about output being a view and is being modified inplace
return output, input.detach()
@staticmethod
def backward(ctx, grad_output, grad_residual):
(indices,) = ctx.saved_tensors
assert grad_output.ndim >= 2
other_shape = grad_output.shape[1:]
assert grad_residual.shape[1:] == other_shape
grad_input = grad_residual
# grad_input[indices] += grad_output
indices = indices.reshape(indices.shape[0], *((1,) * (grad_output.ndim - 1)))
indices = indices.expand_as(grad_output)
grad_input.scatter_add_(0, indices, grad_output)
return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
index_first_axis_residual = IndexFirstAxisResidual.apply
def unpad_input(hidden_states, attention_mask):
"""
Arguments:
hidden_states: (batch, seqlen, ...)
attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
Return:
hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
max_seqlen_in_batch: int
"""
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
# TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
# bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
# times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
# index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
# so we write custom forward and backward to make it a bit faster.
return (
index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
indices,
cu_seqlens,
max_seqlen_in_batch,
)
def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_length):
"""
Supports concatenating short samples into one sequence. The attention_mask_in_length is utilized to mask out the other short samples. It enables efficient training on variable-length samples (e.g., the supervised fine-tuning task for large language models).
The motivation for this function is explained [here](https://github.com/Dao-AILab/flash-attention/issues/432#issuecomment-1668822286).
For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is:
```
[
[2, 3, 0, 0, 0, 0],
[3, 2, 0, 0, 0, 0],
[6, 0, 0, 0, 0, 0]
]
```
, which refers to the 3D-attention mask:
```
[
[
[1, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0],
[0, 0, 1, 1, 0, 0],
[0, 0, 1, 1, 1, 0],
[0, 0, 0, 0, 0, 1]
],
[
[1, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[0, 0, 0, 1, 0, 0],
[0, 0, 0, 1, 1, 0],
[0, 0, 0, 0, 0, 1]
],
[
[1, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 1, 0, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 1]
]
]
```.
Arguments:
hidden_states: (batch, seqlen, ...)
attention_mask_in_length: (batch, seqlen), int, a nonzero number (e.g., 1, 2, 3, etc.) means length of concatenated sequence in b-th batch, and 0 means none.
Return:
hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
max_seqlen_in_batch: int
"""
length = attention_mask_in_length.sum(dim=-1)
seqlen = attention_mask_in_length.size(-1)
attention_mask_2d = torch.arange(seqlen, device=length.device, dtype=length.dtype).expand(len(length), seqlen) < length.unsqueeze(1)
real_indices_idx = torch.nonzero(attention_mask_in_length.flatten(), as_tuple=False).flatten()
seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx]
indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
# TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
# bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
# times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
# index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
# so we write custom forward and backward to make it a bit faster.
return (
index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
indices,
cu_seqlens,
max_seqlen_in_batch,
)
def pad_input(hidden_states, indices, batch, seqlen):
"""
Arguments:
hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
indices: (total_nnz)
Return:
hidden_states: (batch, seqlen, ...)
"""
dim = hidden_states.shape[-1]
# output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
# output[indices] = hidden_states
output = index_put_first_axis(hidden_states, indices, batch * seqlen)
return rearrange(output, "(b s) ... -> b s ...", b=batch)
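# Typical round trip with these helpers (illustrative sketch; `attn_varlen` is a
# placeholder for any varlen attention call, not part of this module):
#   hidden_unpad, indices, cu_seqlens, max_seqlen = unpad_input(hidden_states, attention_mask)
#   out_unpad = attn_varlen(hidden_unpad, cu_seqlens, max_seqlen)
#   out = pad_input(out_unpad, indices, batch, seqlen)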
# Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_blocksparse_attn_interface.py
import block_sparse_attn_cuda
import torch
import torch.nn as nn
def convert_blockmask(blockmask, causal):
"""Convert from the 0-1 format to the format used by the CUDA code.
0 means the block is skipped.
nonzero means the block is not skipped.
Argument:
blockmask: (row, col): a 0-1 tensor
Return:
blockmask_converted: (col, row), dtype torch.int32: for each column, it contains the row
indices of the nonzero blocks, padded with -1 to reach length @row.
The indices are multiplied by 4, with the smallest bit used to encode whether
it is the first nonzero in its row, and the 2nd smallest bit to encode whether it is
the last nonzero in its row.
"""
assert not causal
nrow, ncol = blockmask.shape
# Sort does not support bool on CUDA
blockmask = blockmask.to(dtype=torch.uint8)
nonzero_val, nonzero_sorted_rowidx = blockmask.sort(dim=0, stable=True, descending=True)
nonzero_unsorted_rowidx = nonzero_sorted_rowidx.argsort(dim=0)
last_nonzero_col_per_row = blockmask.sort(dim=-1, stable=True).indices[:, -1]
last_nonzero_col_per_row_after_sort = nonzero_unsorted_rowidx[
torch.arange(nrow, device=blockmask.device), last_nonzero_col_per_row
]
first_nonzero_col_per_row = blockmask.sort(dim=-1, stable=True, descending=True).indices[:, 0]
first_nonzero_col_per_row_after_sort = nonzero_unsorted_rowidx[
torch.arange(nrow, device=blockmask.device), first_nonzero_col_per_row
]
nonzero_idx = nonzero_sorted_rowidx * 4
nonzero_idx[last_nonzero_col_per_row_after_sort, last_nonzero_col_per_row] += 2
nonzero_idx[first_nonzero_col_per_row_after_sort, first_nonzero_col_per_row] += 1
nonzero_idx[nonzero_val == 0] = -1
return nonzero_idx.T.contiguous().to(dtype=torch.int32)
def convert_blockmask_row_reverse(blockmask, causal=False):
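"""For each row of a 0-1 block mask, return the column indices of its nonzero
blocks in descending order, padded with -1 (int32, same shape as the input).
Worked example (ours, for illustration): [[1, 0, 1], [0, 1, 1]] maps to
[[2, 0, -1], [2, 1, -1]].
"""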
# assert not causal
# nrow, ncol = blockmask.shape
# Sort does not support bool on CUDA
blockmask = blockmask.to(dtype=torch.uint8)
nonzero_val, nonzero_sorted_rowidx = blockmask.sort(dim=-1, stable=True, descending=False)
nonzero_idx = nonzero_sorted_rowidx
nonzero_idx[nonzero_val == 0] = -1
# print("nonzero_idx: ", nonzero_idx)
nonzero_idx = torch.flip(nonzero_idx, dims=[-1])
# print("nonzero_idx: ", nonzero_idx)
return nonzero_idx.contiguous().to(dtype=torch.int32)
def convert_blockmask_col_reverse(blockmask, causal=False):
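"""For each column of a 0-1 block mask, return the row indices of its nonzero
blocks in descending order, padded with -1; the result is transposed to shape
(col, row), int32. Worked example (ours, for illustration): [[1, 0, 1], [0, 1, 1]]
maps to [[0, -1], [1, -1], [1, 0]].
"""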
# assert not causal
# nrow, ncol = blockmask.shape
# Sort does not support bool on CUDA
blockmask = blockmask.to(dtype=torch.uint8)
nonzero_val, nonzero_sorted_rowidx = blockmask.sort(dim=-2, stable=True, descending=False)
nonzero_idx = nonzero_sorted_rowidx
nonzero_idx[nonzero_val == 0] = -1
nonzero_idx = torch.flip(nonzero_idx, dims=[-2])
nonzero_idx = torch.transpose(nonzero_idx, -1, -2)
return nonzero_idx.contiguous().to(dtype=torch.int32)
def replace_ones_with_count(tensor):
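"""Replace each 1 in the head-mask-type tensor with its running count (1, 2, 3, ...)
so that each blocksparse head gets its own index into base_blockmask; also return
the number of such heads. E.g. [1, 1, 0, -1, 1] -> ([1, 2, 0, -1, 3], 3).
"""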
ones_mask = tensor == 1
ones_num = ones_mask.sum()
count = torch.cumsum(ones_mask, dim=-1).to(tensor.dtype)
count = count * ones_mask
tensor = tensor.masked_scatter(ones_mask, count[ones_mask])
return tensor, ones_num
def _block_sparse_attn_forward(
q, k, v,
cu_seqlens_q, cu_seqlens_k,
m_block_dim, n_block_dim,
head_mask_type,
streaming_info,
row_blockmask,
max_seqlen_q_, max_seqlen_k_,
p_dropout,
softmax_scale,
is_causal,
exact_streaming,
return_softmax,
window_size_left,
window_size_right
):
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = block_sparse_attn_cuda.fwd_block(
q, k, v,
cu_seqlens_q, cu_seqlens_k,
m_block_dim, n_block_dim,
head_mask_type,
streaming_info,
row_blockmask,
max_seqlen_q_, max_seqlen_k_,
p_dropout,
softmax_scale,
is_causal,
exact_streaming,
return_softmax,
window_size_left,
window_size_right,
None
)
return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
def _block_sparse_attn_backward(
dout,
q, k, v,
out,
softmax_lse,
dq, dk, dv,
cu_seqlens_q, cu_seqlens_k,
m_block_dim, n_block_dim,
head_mask_type,
streaming_info,
col_blockmask,
max_seqlen_q_, max_seqlen_k_,
p_dropout,
softmax_scale,
zero_tensors,
is_causal,
window_size_left,
window_size_right,
deterministic,
rng_state=None,
):
dq, dk, dv, softmax_d = block_sparse_attn_cuda.bwd_block(
dout,
q, k, v,
out,
softmax_lse,
dq, dk, dv,
cu_seqlens_q, cu_seqlens_k,
m_block_dim, n_block_dim,
head_mask_type,
streaming_info,
col_blockmask,
max_seqlen_q_, max_seqlen_k_,
p_dropout,
softmax_scale,
zero_tensors,
is_causal,
window_size_left,
window_size_right,
deterministic,
None, rng_state
)
return dq, dk, dv, softmax_d
class BlockSparseAttnFun(torch.autograd.Function):
@staticmethod
def forward(ctx,
q, k, v,
cu_seqlens_q, cu_seqlens_k,
m_block_dim, n_block_dim,
head_mask_type,
streaming_info,
base_blockmask,
max_seqlen_q_, max_seqlen_k_,
p_dropout,
softmax_scale,
is_causal,
exact_streaming,
return_softmax,
window_size_left,
window_size_right, deterministic=False):
# Save rng_state because the backward pass will regenerate the dropout mask
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
if base_blockmask is not None:
row_blockmask = convert_blockmask_row_reverse(base_blockmask, is_causal)
else:
row_blockmask = None
if exact_streaming:
assert streaming_info is not None
assert is_causal
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _block_sparse_attn_forward(
q, k, v,
cu_seqlens_q, cu_seqlens_k,
m_block_dim, n_block_dim,
head_mask_type,
streaming_info,
row_blockmask,
max_seqlen_q_, max_seqlen_k_,
p_dropout,
softmax_scale,
is_causal,
exact_streaming,
return_softmax=False,
window_size_left=window_size_left,
window_size_right=window_size_right
)
ctx.save_for_backward(q, k, v,
out, S_dmask, softmax_lse,
cu_seqlens_q, cu_seqlens_k,
head_mask_type,
streaming_info,
base_blockmask,
rng_state)
# ctx.is_blocksparse = is_blocksparse
ctx.m_block_dim = m_block_dim
ctx.n_block_dim = n_block_dim
ctx.window_size_left = window_size_left
ctx.window_size_right = window_size_right
ctx.max_seqlen_q_ = max_seqlen_q_
ctx.max_seqlen_k_ = max_seqlen_k_
ctx.p_dropout = p_dropout
ctx.softmax_scale = softmax_scale
ctx.is_causal = is_causal
ctx.exact_streaming = exact_streaming
ctx.deterministic = deterministic
return out
@staticmethod
def backward(ctx, dout):
q, k, v, out, S_dmask, softmax_lse, cu_seqlens_q, cu_seqlens_k, head_mask_type, streaming_info, base_blockmask, rng_state = ctx.saved_tensors
dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
# S_dmask is None, temporarily use another tensor just to get it running
if base_blockmask is not None:
col_blockmask = convert_blockmask_col_reverse(base_blockmask, ctx.is_causal)
else:
col_blockmask = None
assert not ctx.exact_streaming, "Exact streaming not supported in backward pass"
_block_sparse_attn_backward(
dout,
q, k, v,
out,
softmax_lse,
dq, dk, dv,
cu_seqlens_q, cu_seqlens_k,
ctx.m_block_dim, ctx.n_block_dim,
head_mask_type,
streaming_info,
col_blockmask,
ctx.max_seqlen_q_, ctx.max_seqlen_k_,
ctx.p_dropout,
ctx.softmax_scale,
True, # zero_tensors
ctx.is_causal,
ctx.window_size_left,
ctx.window_size_right,
ctx.deterministic,
rng_state=rng_state
)
return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
# We duplicate code to return both the output and the softmax for testing
# Returning both makes backward a bit slower, so we want to keep using the other version for speed.
class BlockSparseAttnFunWithS(torch.autograd.Function):
@staticmethod
def forward(ctx,
q, k, v,
cu_seqlens_q, cu_seqlens_k,
m_block_dim, n_block_dim,
head_mask_type,
streaming_info,
base_blockmask,
max_seqlen_q_, max_seqlen_k_,
p_dropout,
softmax_scale,
is_causal,
exact_streaming,
return_softmax,
window_size_left,
window_size_right,
deterministic=False):
# Save rng_state because the backward pass will regenerate the dropout mask
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
if base_blockmask is not None:
row_blockmask = convert_blockmask_row_reverse(base_blockmask, is_causal)
else:
row_blockmask = None
if exact_streaming:
assert streaming_info is not None
print("is_causal: ", is_causal)
assert is_causal
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _block_sparse_attn_forward(
q, k, v,
cu_seqlens_q, cu_seqlens_k,
m_block_dim, n_block_dim,
head_mask_type,
streaming_info,
row_blockmask,
max_seqlen_q_, max_seqlen_k_,
p_dropout,
softmax_scale,
is_causal,
exact_streaming,
return_softmax=return_softmax and p_dropout > 0,
window_size_left=window_size_left,
window_size_right=window_size_right,
)
ctx.save_for_backward(q, k, v,
out, softmax_lse,
cu_seqlens_q, cu_seqlens_k,
head_mask_type,
streaming_info,
base_blockmask,
rng_state)
# ctx.is_blocksparse = is_blocksparse
ctx.m_block_dim = m_block_dim
ctx.n_block_dim = n_block_dim
ctx.window_size_left = window_size_left
ctx.window_size_right = window_size_right
ctx.max_seqlen_q_ = max_seqlen_q_
ctx.max_seqlen_k_ = max_seqlen_k_
ctx.p_dropout = p_dropout
ctx.softmax_scale = softmax_scale
ctx.is_causal = is_causal
ctx.exact_streaming = exact_streaming
ctx.deterministic = deterministic
return out, softmax_lse, S_dmask
@staticmethod
def backward(ctx, dout, *args):
q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, head_mask_type, streaming_info, base_blockmask, rng_state = ctx.saved_tensors
dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
# S_dmask is None, temporarily use another tensor just to get it running
if base_blockmask is not None:
col_blockmask = convert_blockmask_col_reverse(base_blockmask, ctx.is_causal)
else:
col_blockmask = None
assert not ctx.exact_streaming, "Exact streaming not supported in backward pass"
dq, dk, dv, _ = _block_sparse_attn_backward(
dout,
q, k, v,
out,
softmax_lse,
dq, dk, dv,
cu_seqlens_q, cu_seqlens_k,
ctx.m_block_dim, ctx.n_block_dim,
head_mask_type,
streaming_info,
col_blockmask,
ctx.max_seqlen_q_, ctx.max_seqlen_k_,
ctx.p_dropout,
ctx.softmax_scale,
True, # zero_tensors
ctx.is_causal,
ctx.window_size_left,
ctx.window_size_right,
ctx.deterministic,
rng_state=rng_state
)
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dk = dk[..., : dout.shape[-1]]
dv = dv[..., : dout.shape[-1]]
return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
def block_sparse_attn_func(
q, k, v,
cu_seqlens_q, cu_seqlens_k,
head_mask_type,
streaming_info,
base_blockmask,
max_seqlen_q_, max_seqlen_k_,
p_dropout,
deterministic=False,
softmax_scale=None,
is_causal=False,
exact_streaming=False,
return_attn_probs=False,
):
head_mask_type, blocksparse_head_num = replace_ones_with_count(head_mask_type)
if base_blockmask is not None:
assert base_blockmask.shape[1] == blocksparse_head_num
"""dropout_p should be set to 0.0 during evaluation"""
# print("is_causal0: ", is_causal)
func = BlockSparseAttnFun if not return_attn_probs else BlockSparseAttnFunWithS
return func.apply(
q, k, v,
cu_seqlens_q, cu_seqlens_k,
128, 128,
head_mask_type,
streaming_info,
base_blockmask,
max_seqlen_q_, max_seqlen_k_,
p_dropout,
softmax_scale,
is_causal,
exact_streaming,
return_attn_probs,
-1, -1,
deterministic
)
def token_streaming_attn_func(
q, k, v,
cu_seqlens_q, cu_seqlens_k,
head_mask_type,
streaming_info,
max_seqlen_q_, max_seqlen_k_,
deterministic=False,
softmax_scale=None,
return_attn_probs=False,
):
"""dropout_p should be set to 0.0 during evaluation"""
# print("is_causal0: ", is_causal)
func = BlockSparseAttnFun if not return_attn_probs else BlockSparseAttnFunWithS
return func.apply(
q, k, v,
cu_seqlens_q, cu_seqlens_k,
128, 128,
head_mask_type,
streaming_info,
None,
max_seqlen_q_, max_seqlen_k_,
0.0,
softmax_scale,
True,
True,
return_attn_probs,
-1, -1,
deterministic
)
def block_streaming_attn_func(
q, k, v,
cu_seqlens_q, cu_seqlens_k,
head_mask_type,
streaming_info,
max_seqlen_q_, max_seqlen_k_,
p_dropout,
deterministic=False,
softmax_scale=None,
is_causal=True,
return_attn_probs=False,
):
"""dropout_p should be set to 0.0 during evaluation"""
# print("is_causal0: ", is_causal)
func = BlockSparseAttnFun if not return_attn_probs else BlockSparseAttnFunWithS
return func.apply(
q, k, v,
cu_seqlens_q, cu_seqlens_k,
128, 128,
head_mask_type,
streaming_info,
None,
max_seqlen_q_, max_seqlen_k_,
p_dropout,
softmax_scale,
is_causal,
False,
return_attn_probs,
-1, -1,
deterministic
)
# Copyright (c) 2023, Tri Dao.
# Get from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_interface.py
from typing import Optional, Union
import torch
import torch.nn as nn
# isort: off
# We need to import the CUDA kernels after importing torch
import block_sparse_attn_cuda as flash_attn_cuda
# isort: on
def _get_block_size(device, head_dim, is_dropout, is_causal):
# This should match the block sizes in the CUDA kernel
assert head_dim <= 256
major, minor = torch.cuda.get_device_capability(device)
is_sm8x = major == 8 and minor > 0 # Only include sm86 and sm89, exclude sm80 (A100)
is_sm80 = major == 8 and minor == 0
is_sm90 = major == 9 and minor == 0
if head_dim <= 32:
return 128, 128
if head_dim <= 64:
return (128, 128) if not is_dropout else (128, 64)
elif head_dim <= 96:
return (64, 64) if (is_sm8x and is_causal) else (128, 64)
elif head_dim <= 128:
if is_sm8x:
return (64, 64) if (not is_dropout and is_causal) else (128, 32)
else:
return 128, (64 if not is_dropout else 32)
elif head_dim <= 160:
if is_sm8x:
return (128, 64) if not is_causal else (64, 64)
else:
return 128, 32
elif head_dim <= 192:
return (128, 64) if not is_dropout else (64, 64)
elif head_dim <= 224:
return (128, 64) if (is_sm80 or is_sm90) else (64, 64)
elif head_dim <= 256:
return (128, 64) if is_sm80 else (64, 64)
def _flash_attn_forward(
q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax
):
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
q,
k,
v,
None,
alibi_slopes,
dropout_p,
softmax_scale,
causal,
window_size[0],
window_size[1],
return_softmax,
None,
)
return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
def _flash_attn_varlen_forward(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
causal,
window_size,
alibi_slopes,
return_softmax,
):
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
q,
k,
v,
None,
cu_seqlens_q,
cu_seqlens_k,
None,
alibi_slopes,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
False,
causal,
window_size[0],
window_size[1],
return_softmax,
None,
)
# if out.isnan().any() or softmax_lse.isnan().any():
# breakpoint()
return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
def _flash_attn_backward(
dout,
q,
k,
v,
out,
softmax_lse,
dq,
dk,
dv,
dropout_p,
softmax_scale,
causal,
window_size,
alibi_slopes,
deterministic,
rng_state=None,
):
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
# dq, dk, dv are allocated by us so they should already be contiguous
dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
dq, dk, dv, softmax_d, = flash_attn_cuda.bwd(
dout,
q,
k,
v,
out,
softmax_lse,
dq,
dk,
dv,
alibi_slopes,
dropout_p,
softmax_scale,
causal,
window_size[0],
window_size[1],
deterministic,
None,
rng_state,
)
return dq, dk, dv, softmax_d
def _flash_attn_varlen_backward(
dout,
q,
k,
v,
out,
softmax_lse,
dq,
dk,
dv,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
causal,
window_size,
alibi_slopes,
deterministic,
rng_state=None,
):
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
# dq, dk, dv are allocated by us so they should already be contiguous
dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
dq, dk, dv, softmax_d, = flash_attn_cuda.varlen_bwd(
dout,
q,
k,
v,
out,
softmax_lse,
dq,
dk,
dv,
cu_seqlens_q,
cu_seqlens_k,
alibi_slopes,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
False,
causal,
window_size[0],
window_size[1],
deterministic,
None,
rng_state,
)
# if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
# breakpoint()
return dq, dk, dv, softmax_d
class FlashAttnQKVPackedFunc(torch.autograd.Function):
@staticmethod
def forward(
ctx,
qkv,
dropout_p,
softmax_scale,
causal,
window_size,
alibi_slopes,
deterministic,
return_softmax,
):
if softmax_scale is None:
softmax_scale = qkv.shape[-1] ** (-0.5)
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
qkv[:, :, 0],
qkv[:, :, 1],
qkv[:, :, 2],
dropout_p,
softmax_scale,
causal=causal,
window_size=window_size,
alibi_slopes=alibi_slopes,
return_softmax=return_softmax and dropout_p > 0,
)
ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
ctx.dropout_p = dropout_p
ctx.softmax_scale = softmax_scale
ctx.causal = causal
ctx.window_size = window_size
ctx.alibi_slopes = alibi_slopes
ctx.deterministic = deterministic
return out if not return_softmax else (out, softmax_lse, S_dmask)
@staticmethod
def backward(ctx, dout, *args):
q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
_flash_attn_backward(
dout,
q,
k,
v,
out,
softmax_lse,
dqkv[:, :, 0],
dqkv[:, :, 1],
dqkv[:, :, 2],
ctx.dropout_p,
ctx.softmax_scale,
ctx.causal,
ctx.window_size,
ctx.alibi_slopes,
ctx.deterministic,
rng_state=rng_state,
)
dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
return dqkv, None, None, None, None, None, None, None
class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
@staticmethod
def forward(
ctx,
qkv,
cu_seqlens,
max_seqlen,
dropout_p,
softmax_scale,
causal,
window_size,
alibi_slopes,
deterministic,
return_softmax,
):
if softmax_scale is None:
softmax_scale = qkv.shape[-1] ** (-0.5)
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
qkv[:, 0],
qkv[:, 1],
qkv[:, 2],
cu_seqlens,
cu_seqlens,
max_seqlen,
max_seqlen,
dropout_p,
softmax_scale,
causal=causal,
window_size=window_size,
alibi_slopes=alibi_slopes,
return_softmax=return_softmax and dropout_p > 0,
)
ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
ctx.dropout_p = dropout_p
ctx.max_seqlen = max_seqlen
ctx.softmax_scale = softmax_scale
ctx.causal = causal
ctx.window_size = window_size
ctx.alibi_slopes = alibi_slopes
ctx.deterministic = deterministic
return out if not return_softmax else (out, softmax_lse, S_dmask)
@staticmethod
def backward(ctx, dout, *args):
q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
_flash_attn_varlen_backward(
dout,
q,
k,
v,
out,
softmax_lse,
dqkv[:, 0],
dqkv[:, 1],
dqkv[:, 2],
cu_seqlens,
cu_seqlens,
ctx.max_seqlen,
ctx.max_seqlen,
ctx.dropout_p,
ctx.softmax_scale,
ctx.causal,
ctx.window_size,
ctx.alibi_slopes,
ctx.deterministic,
rng_state=rng_state,
)
dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
return dqkv, None, None, None, None, None, None, None, None, None
class FlashAttnKVPackedFunc(torch.autograd.Function):
@staticmethod
def forward(
ctx,
q,
kv,
dropout_p,
softmax_scale,
causal,
window_size,
alibi_slopes,
deterministic,
return_softmax,
):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
q,
kv[:, :, 0],
kv[:, :, 1],
dropout_p,
softmax_scale,
causal=causal,
window_size=window_size,
alibi_slopes=alibi_slopes,
return_softmax=return_softmax and dropout_p > 0,
)
ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
ctx.dropout_p = dropout_p
ctx.softmax_scale = softmax_scale
ctx.causal = causal
ctx.window_size = window_size
ctx.alibi_slopes = alibi_slopes
ctx.deterministic = deterministic
return out if not return_softmax else (out, softmax_lse, S_dmask)
@staticmethod
def backward(ctx, dout, *args):
q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
dq = torch.empty_like(q)
kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
_flash_attn_backward(
dout,
q,
k,
v,
out,
softmax_lse,
dq,
dkv[:, :, 0],
dkv[:, :, 1],
ctx.dropout_p,
ctx.softmax_scale,
ctx.causal,
ctx.window_size,
ctx.alibi_slopes,
ctx.deterministic,
rng_state=rng_state,
)
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dkv = dkv[..., : dout.shape[-1]]
return dq, dkv, None, None, None, None, None, None, None
class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
@staticmethod
def forward(
ctx,
q,
kv,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
causal,
window_size,
alibi_slopes,
deterministic,
return_softmax,
):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
q,
kv[:, 0],
kv[:, 1],
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
causal=causal,
window_size=window_size,
alibi_slopes=alibi_slopes,
return_softmax=return_softmax and dropout_p > 0,
)
ctx.save_for_backward(
q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
)
ctx.dropout_p = dropout_p
ctx.max_seqlen_q = max_seqlen_q
ctx.max_seqlen_k = max_seqlen_k
ctx.softmax_scale = softmax_scale
ctx.causal = causal
ctx.window_size = window_size
ctx.alibi_slopes = alibi_slopes
ctx.deterministic = deterministic
return out if not return_softmax else (out, softmax_lse, S_dmask)
@staticmethod
def backward(ctx, dout, *args):
q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
dq = torch.empty_like(q)
kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
_flash_attn_varlen_backward(
dout,
q,
k,
v,
out,
softmax_lse,
dq,
dkv[:, 0],
dkv[:, 1],
cu_seqlens_q,
cu_seqlens_k,
ctx.max_seqlen_q,
ctx.max_seqlen_k,
ctx.dropout_p,
ctx.softmax_scale,
ctx.causal,
ctx.window_size,
ctx.alibi_slopes,
ctx.deterministic,
rng_state=rng_state,
)
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dkv = dkv[..., : dout.shape[-1]]
return dq, dkv, None, None, None, None, None, None, None, None, None, None, None
class FlashAttnFunc(torch.autograd.Function):
@staticmethod
def forward(
ctx,
q,
k,
v,
dropout_p,
softmax_scale,
causal,
window_size,
alibi_slopes,
deterministic,
return_softmax,
):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
q,
k,
v,
dropout_p,
softmax_scale,
causal=causal,
window_size=window_size,
alibi_slopes=alibi_slopes,
return_softmax=return_softmax and dropout_p > 0,
)
ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
ctx.dropout_p = dropout_p
ctx.softmax_scale = softmax_scale
ctx.causal = causal
ctx.window_size = window_size
ctx.alibi_slopes = alibi_slopes
ctx.deterministic = deterministic
return out if not return_softmax else (out, softmax_lse, S_dmask)
@staticmethod
def backward(ctx, dout, *args):
q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
_flash_attn_backward(
dout,
q,
k,
v,
out,
softmax_lse,
dq,
dk,
dv,
ctx.dropout_p,
ctx.softmax_scale,
ctx.causal,
ctx.window_size,
ctx.alibi_slopes,
ctx.deterministic,
rng_state=rng_state,
)
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dk = dk[..., : dout.shape[-1]]
dv = dv[..., : dout.shape[-1]]
return dq, dk, dv, None, None, None, None, None, None, None
class FlashAttnVarlenFunc(torch.autograd.Function):
@staticmethod
def forward(
ctx,
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
causal,
window_size,
alibi_slopes,
deterministic,
return_softmax,
):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
causal=causal,
window_size=window_size,
alibi_slopes=alibi_slopes,
return_softmax=return_softmax and dropout_p > 0,
)
ctx.save_for_backward(
q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
)
ctx.dropout_p = dropout_p
ctx.max_seqlen_q = max_seqlen_q
ctx.max_seqlen_k = max_seqlen_k
ctx.softmax_scale = softmax_scale
ctx.causal = causal
ctx.window_size = window_size
ctx.alibi_slopes = alibi_slopes
ctx.deterministic = deterministic
return out if not return_softmax else (out, softmax_lse, S_dmask)
@staticmethod
def backward(ctx, dout, *args):
# print("varlen: dout all 1")
q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
# dout = torch.full_like(dout, 1)
dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
_flash_attn_varlen_backward(
dout,
q,
k,
v,
out,
softmax_lse,
dq,
dk,
dv,
cu_seqlens_q,
cu_seqlens_k,
ctx.max_seqlen_q,
ctx.max_seqlen_k,
ctx.dropout_p,
ctx.softmax_scale,
ctx.causal,
ctx.window_size,
ctx.alibi_slopes,
ctx.deterministic,
rng_state=rng_state,
)
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dk = dk[..., : dout.shape[-1]]
dv = dv[..., : dout.shape[-1]]
return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None
def flash_attn_qkvpacked_func(
qkv,
dropout_p=0.0,
softmax_scale=None,
causal=False,
window_size=(-1, -1), # -1 means infinite context window
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
):
"""dropout_p should be set to 0.0 during evaluation
If Q, K, V are already stacked into 1 tensor, this function will be faster than
calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
of the gradients of Q, K, V.
For multi-query and grouped-query attention (MQA/GQA), please see
flash_attn_kvpacked_func and flash_attn_func.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
Arguments:
qkv: (batch_size, seqlen, 3, nheads, headdim)
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
the attention score of query i and key j.
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
which is slightly slower and uses more memory. The forward pass is always deterministic.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
Return:
out: (batch_size, seqlen, nheads, headdim).
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
The output of softmax (possibly with different scaling). It also encodes the dropout
pattern (negative means that location was dropped, nonnegative means it was kept).
"""
return FlashAttnQKVPackedFunc.apply(
qkv,
dropout_p,
softmax_scale,
causal,
window_size,
alibi_slopes,
deterministic,
return_attn_probs,
)
def flash_attn_kvpacked_func(
q,
kv,
dropout_p=0.0,
softmax_scale=None,
causal=False,
window_size=(-1, -1), # -1 means infinite context window
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
):
"""dropout_p should be set to 0.0 during evaluation
If K, V are already stacked into 1 tensor, this function will be faster than
calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
of the gradients of K, V.
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.
If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1 1 1 1 0
1 1 1 1 1
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
0 0
0 0
0 0
1 0
1 1
If the row of the mask is all zero, the output will be zero.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between
[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
Arguments:
q: (batch_size, seqlen, nheads, headdim)
kv: (batch_size, seqlen, 2, nheads_k, headdim)
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
which is slightly slower and uses more memory. The forward pass is always deterministic.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
Return:
out: (batch_size, seqlen, nheads, headdim).
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
The output of softmax (possibly with different scaling). It also encodes the dropout
pattern (negative means that location was dropped, nonnegative means it was kept).
"""
return FlashAttnKVPackedFunc.apply(
q,
kv,
dropout_p,
softmax_scale,
causal,
window_size,
alibi_slopes,
deterministic,
return_attn_probs,
)
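def _example_flash_attn_kvpacked_func():
    # Illustrative usage sketch (hypothetical helper, not part of the original module).
    # Queries attend to a longer packed K/V; with causal=True the mask is aligned to the
    # bottom-right corner as described in the docstring above. Assumes a CUDA device,
    # fp16 tensors, and the compiled flash_attn_cuda extension.
    batch, seqlen_q, seqlen_k, nheads, headdim = 2, 2, 5, 8, 64
    q = torch.randn(batch, seqlen_q, nheads, headdim, device="cuda", dtype=torch.float16)
    kv = torch.randn(batch, seqlen_k, 2, nheads, headdim, device="cuda", dtype=torch.float16)
    out = flash_attn_kvpacked_func(q, kv, causal=True)
    return out  # (batch, seqlen_q, nheads, headdim)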
def flash_attn_func(
q,
k,
v,
dropout_p=0.0,
softmax_scale=None,
causal=False,
window_size=(-1, -1), # -1 means infinite context window
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
):
"""dropout_p should be set to 0.0 during evaluation
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.
If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1 1 1 1 0
1 1 1 1 1
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
0 0
0 0
0 0
1 0
1 1
If the row of the mask is all zero, the output will be zero.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between
[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
Arguments:
q: (batch_size, seqlen, nheads, headdim)
k: (batch_size, seqlen, nheads_k, headdim)
v: (batch_size, seqlen, nheads_k, headdim)
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
which is slightly slower and uses more memory. The forward pass is always deterministic.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
Return:
out: (batch_size, seqlen, nheads, headdim).
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
The output of softmax (possibly with different scaling). It also encodes the dropout
pattern (negative means that location was dropped, nonnegative means it was kept).
"""
return FlashAttnFunc.apply(
q,
k,
v,
dropout_p,
softmax_scale,
causal,
window_size,
alibi_slopes,
deterministic,
return_attn_probs,
)
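def _example_flash_attn_func_gqa():
    # Illustrative usage sketch (hypothetical helper, not part of the original module).
    # Grouped-query attention: 8 query heads share 2 KV heads, so heads 0-3 of Q attend
    # to KV head 0 and heads 4-7 attend to KV head 1. Assumes a CUDA device, fp16
    # tensors, and the compiled flash_attn_cuda extension.
    batch, seqlen, nheads, nheads_k, headdim = 2, 1024, 8, 2, 64
    q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
    k = torch.randn(batch, seqlen, nheads_k, headdim, device="cuda", dtype=torch.float16)
    v = torch.randn(batch, seqlen, nheads_k, headdim, device="cuda", dtype=torch.float16)
    out = flash_attn_func(q, k, v, causal=True)
    return out  # (batch, seqlen, nheads, headdim)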
def flash_attn_varlen_qkvpacked_func(
qkv,
cu_seqlens,
max_seqlen,
dropout_p=0.0,
softmax_scale=None,
causal=False,
window_size=(-1, -1), # -1 means infinite context window
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
):
"""dropout_p should be set to 0.0 during evaluation
If Q, K, V are already stacked into 1 tensor, this function will be faster than
calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation
of the gradients of Q, K, V.
For multi-query and grouped-query attention (MQA/GQA), please see
flash_attn_varlen_kvpacked_func and flash_attn_varlen_func.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
Arguments:
qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into qkv.
max_seqlen: int. Maximum sequence length in the batch.
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|)
is added to the attention score of query i and key j.
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
which is slightly slower and uses more memory. The forward pass is always deterministic.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
Return:
out: (total, nheads, headdim).
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
The output of softmax (possibly with different scaling). It also encodes the dropout
pattern (negative means that location was dropped, nonnegative means it was kept).
"""
return FlashAttnVarlenQKVPackedFunc.apply(
qkv,
cu_seqlens,
max_seqlen,
dropout_p,
softmax_scale,
causal,
window_size,
alibi_slopes,
deterministic,
return_attn_probs,
)
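def _example_flash_attn_varlen_qkvpacked_func():
    # Illustrative usage sketch (hypothetical helper, not part of the original module).
    # A batch of 4 equal-length sequences flattened into one (total, 3, nheads, headdim)
    # tensor, with cu_seqlens marking the sequence boundaries. Assumes a CUDA device,
    # fp16 tensors, and the compiled flash_attn_cuda extension.
    batch, seqlen, nheads, headdim = 4, 256, 8, 64
    qkv = torch.randn(batch * seqlen, 3, nheads, headdim, device="cuda", dtype=torch.float16)
    # Sequence boundaries at multiples of seqlen: [0, 256, 512, 768, 1024].
    cu_seqlens = torch.arange(
        0, (batch + 1) * seqlen, step=seqlen, dtype=torch.int32, device="cuda"
    )
    out = flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens, seqlen, causal=True)
    return out  # (batch * seqlen, nheads, headdim)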
def flash_attn_varlen_kvpacked_func(
q,
kv,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p=0.0,
softmax_scale=None,
causal=False,
window_size=(-1, -1), # -1 means infinite context window
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
):
"""dropout_p should be set to 0.0 during evaluation
If K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation
of the gradients of K, V.
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.
If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1 1 1 1 0
1 1 1 1 1
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
0 0
0 0
0 0
1 0
1 1
If the row of the mask is all zero, the output will be zero.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between
[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
Arguments:
q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch.
cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into q.
cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into kv.
max_seqlen_q: int. Maximum query sequence length in the batch.
max_seqlen_k: int. Maximum key sequence length in the batch.
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
which is slightly slower and uses more memory. The forward pass is always deterministic.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
Return:
        out: (total_q, nheads, headdim).
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
The output of softmax (possibly with different scaling). It also encodes the dropout
pattern (negative means that location was dropped, nonnegative means it was kept).
"""
return FlashAttnVarlenKVPackedFunc.apply(
q,
kv,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
causal,
window_size,
alibi_slopes,
deterministic,
return_attn_probs,
)
def flash_attn_varlen_func(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p=0.0,
softmax_scale=None,
causal=False,
window_size=(-1, -1), # -1 means infinite context window
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
):
"""dropout_p should be set to 0.0 during evaluation
Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.
If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1 1 1 1 0
1 1 1 1 1
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
0 0
0 0
0 0
1 0
1 1
If the row of the mask is all zero, the output will be zero.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between
[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
Arguments:
q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into q.
cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into k and v.
max_seqlen_q: int. Maximum query sequence length in the batch.
max_seqlen_k: int. Maximum key sequence length in the batch.
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
which is slightly slower and uses more memory. The forward pass is always deterministic.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
Return:
        out: (total_q, nheads, headdim).
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
The output of softmax (possibly with different scaling). It also encodes the dropout
pattern (negative means that location was dropped, nonnegative means it was kept).
"""
return FlashAttnVarlenFunc.apply(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
causal,
window_size,
alibi_slopes,
deterministic,
return_attn_probs,
)
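def _example_flash_attn_varlen_func():
    # Illustrative usage sketch (hypothetical helper, not part of the original module).
    # Two sequences of lengths 3 and 5 packed back-to-back without padding; cu_seqlens
    # holds the cumulative lengths [0, 3, 8]. Assumes a CUDA device, fp16 tensors, and
    # the compiled flash_attn_cuda extension.
    nheads, headdim = 8, 64
    total = 3 + 5
    q = torch.randn(total, nheads, headdim, device="cuda", dtype=torch.float16)
    k = torch.randn(total, nheads, headdim, device="cuda", dtype=torch.float16)
    v = torch.randn(total, nheads, headdim, device="cuda", dtype=torch.float16)
    cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32, device="cuda")
    out = flash_attn_varlen_func(
        q,
        k,
        v,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=5,
        max_seqlen_k=5,
        causal=True,
    )
    return out  # (total, nheads, headdim)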
def flash_attn_with_kvcache(
q,
k_cache,
v_cache,
k=None,
v=None,
rotary_cos=None,
rotary_sin=None,
    cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
cache_batch_idx: Optional[torch.Tensor] = None,
softmax_scale=None,
causal=False,
window_size=(-1, -1), # -1 means infinite context window
rotary_interleaved=True,
alibi_slopes=None,
num_splits=0,
):
"""
If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
the previous step, and update them with the new keys/values from the current step, and do
attention with the updated cache, all in 1 kernel.
If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
For example, the KV cache could be pre-allocated with the max sequence length, and you can use
cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.
If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1 1 1 1 0
1 1 1 1 1
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
0 0
0 0
0 0
1 0
1 1
If the row of the mask is all zero, the output will be zero.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between
[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
Note: Does not support backward pass.
Arguments:
q: (batch_size, seqlen, nheads, headdim)
k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim)
v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim)
k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
k with k_cache, starting at the indices specified by cache_seqlens.
v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
KV cache.
cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
If the indices are not distinct, and k and v are provided, the values updated in the cache
might come from any of the duplicate indices.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
(i.e. GPT-NeoX style).
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
to automatically determine the number of splits.
Don't change this unless you know what you are doing.
Return:
out: (batch_size, seqlen, nheads, headdim).
"""
assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
maybe_contiguous = lambda x: x.contiguous() if x is not None and x.stride(-1) != 1 else x
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
if cache_seqlens is not None and isinstance(cache_seqlens, int):
cache_seqlens = torch.full(
(k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
)
cache_seqlens = maybe_contiguous(cache_seqlens)
cache_batch_idx = maybe_contiguous(cache_batch_idx)
out, softmax_lse = flash_attn_cuda.fwd_kvcache(
q,
k_cache,
v_cache,
k,
v,
cache_seqlens,
rotary_cos,
rotary_sin,
cache_batch_idx,
alibi_slopes,
None,
softmax_scale,
causal,
window_size[0],
window_size[1],
rotary_interleaved,
num_splits,
)
return out
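def _example_flash_attn_with_kvcache_decode():
    # Illustrative usage sketch (hypothetical helper, not part of the original module).
    # Single-token incremental decoding: the new k/v are written into the pre-allocated
    # cache at position cache_seqlens and attention runs over the updated cache, all in
    # one kernel. Assumes a CUDA device, fp16 tensors, and the compiled flash_attn_cuda
    # extension.
    batch, max_seqlen, nheads, nheads_k, headdim = 2, 2048, 8, 8, 64
    k_cache = torch.zeros(
        batch, max_seqlen, nheads_k, headdim, device="cuda", dtype=torch.float16
    )
    v_cache = torch.zeros_like(k_cache)
    # Pretend 128 tokens are already cached for every sequence in the batch.
    cache_seqlens = torch.full((batch,), 128, dtype=torch.int32, device="cuda")
    q = torch.randn(batch, 1, nheads, headdim, device="cuda", dtype=torch.float16)
    k = torch.randn(batch, 1, nheads_k, headdim, device="cuda", dtype=torch.float16)
    v = torch.randn(batch, 1, nheads_k, headdim, device="cuda", dtype=torch.float16)
    out = flash_attn_with_kvcache(
        q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens, causal=True
    )
    return out  # (batch, 1, nheads, headdim)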
[tool.black]
line-length = 100
target-version = ['py38']
# Copyright (c) 2023, Tri Dao.
""" Useful functions for writing test code. """
import torch
import torch.utils.benchmark as benchmark
def benchmark_forward(
fn, *inputs, repeats=10, desc="", verbose=True, amp=False, amp_dtype=torch.float16, **kwinputs
):
"""Use Pytorch Benchmark on the forward pass of an arbitrary function."""
if verbose:
print(desc, "- Forward pass")
def amp_wrapper(*inputs, **kwinputs):
with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
fn(*inputs, **kwinputs)
t = benchmark.Timer(
stmt="fn_amp(*inputs, **kwinputs)",
globals={"fn_amp": amp_wrapper, "inputs": inputs, "kwinputs": kwinputs},
num_threads=torch.get_num_threads(),
)
m = t.timeit(repeats)
if verbose:
print(m)
return t, m
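def _example_benchmark_forward():
    # Illustrative usage sketch (hypothetical helper, not part of the original file).
    # Time a single matmul under fp16 autocast; assumes a CUDA device.
    x = torch.randn(4096, 4096, device="cuda")
    w = torch.randn(4096, 4096, device="cuda")
    t, m = benchmark_forward(torch.matmul, x, w, desc="matmul", amp=True)
    return m.mean  # average seconds per call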
def benchmark_backward(
fn,
*inputs,
grad=None,
repeats=10,
desc="",
verbose=True,
amp=False,
amp_dtype=torch.float16,
**kwinputs,
):
"""Use Pytorch Benchmark on the backward pass of an arbitrary function."""
if verbose:
print(desc, "- Backward pass")
with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
y = fn(*inputs, **kwinputs)
if type(y) is tuple:
y = y[0]
if grad is None:
grad = torch.randn_like(y)
else:
if grad.shape != y.shape:
raise RuntimeError("Grad shape does not match output shape")
def f(*inputs, y, grad):
        # Set .grad to None to avoid the extra gradient-accumulation step
for x in inputs:
if isinstance(x, torch.Tensor):
x.grad = None
y.backward(grad, retain_graph=True)
t = benchmark.Timer(
stmt="f(*inputs, y=y, grad=grad)",
globals={"f": f, "inputs": inputs, "y": y, "grad": grad},
num_threads=torch.get_num_threads(),
)
m = t.timeit(repeats)
if verbose:
print(m)
return t, m
def benchmark_combined(
fn,
*inputs,
grad=None,
repeats=10,
desc="",
verbose=True,
amp=False,
amp_dtype=torch.float16,
**kwinputs,
):
"""Use Pytorch Benchmark on the forward+backward pass of an arbitrary function."""
if verbose:
print(desc, "- Forward + Backward pass")
with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
y = fn(*inputs, **kwinputs)
if type(y) is tuple:
y = y[0]
if grad is None:
grad = torch.randn_like(y)
else:
if grad.shape != y.shape:
raise RuntimeError("Grad shape does not match output shape")
def f(grad, *inputs, **kwinputs):
for x in inputs:
if isinstance(x, torch.Tensor):
x.grad = None
with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
y = fn(*inputs, **kwinputs)
if type(y) is tuple:
y = y[0]
y.backward(grad, retain_graph=True)
t = benchmark.Timer(
stmt="f(grad, *inputs, **kwinputs)",
globals={"f": f, "fn": fn, "inputs": inputs, "grad": grad, "kwinputs": kwinputs},
num_threads=torch.get_num_threads(),
)
m = t.timeit(repeats)
if verbose:
print(m)
return t, m
def benchmark_fwd_bwd(
fn,
*inputs,
grad=None,
repeats=10,
desc="",
verbose=True,
amp=False,
amp_dtype=torch.float16,
**kwinputs,
):
"""Use Pytorch Benchmark on the forward+backward pass of an arbitrary function."""
return (
benchmark_forward(
fn,
*inputs,
repeats=repeats,
desc=desc,
verbose=verbose,
amp=amp,
amp_dtype=amp_dtype,
**kwinputs,
),
benchmark_backward(
fn,
*inputs,
grad=grad,
repeats=repeats,
desc=desc,
verbose=verbose,
amp=amp,
amp_dtype=amp_dtype,
**kwinputs,
),
)
def benchmark_all(
fn,
*inputs,
grad=None,
repeats=10,
desc="",
verbose=True,
amp=False,
amp_dtype=torch.float16,
**kwinputs,
):
"""Use Pytorch Benchmark on the forward+backward pass of an arbitrary function."""
return (
benchmark_forward(
fn,
*inputs,
repeats=repeats,
desc=desc,
verbose=verbose,
amp=amp,
amp_dtype=amp_dtype,
**kwinputs,
),
benchmark_backward(
fn,
*inputs,
grad=grad,
repeats=repeats,
desc=desc,
verbose=verbose,
amp=amp,
amp_dtype=amp_dtype,
**kwinputs,
),
benchmark_combined(
fn,
*inputs,
grad=grad,
repeats=repeats,
desc=desc,
verbose=verbose,
amp=amp,
amp_dtype=amp_dtype,
**kwinputs,
),
)
def pytorch_profiler(
fn,
*inputs,
trace_filename=None,
backward=False,
amp=False,
amp_dtype=torch.float16,
cpu=False,
verbose=True,
**kwinputs,
):
"""Wrap benchmark functions in Pytorch profiler to see CUDA information."""
if backward:
with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
out = fn(*inputs, **kwinputs)
if type(out) is tuple:
out = out[0]
g = torch.randn_like(out)
for _ in range(30): # Warm up
if backward:
for x in inputs:
if isinstance(x, torch.Tensor):
x.grad = None
with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
out = fn(*inputs, **kwinputs)
if type(out) is tuple:
out = out[0]
# Backward should be done outside autocast
if backward:
out.backward(g, retain_graph=True)
activities = ([torch.profiler.ProfilerActivity.CPU] if cpu else []) + [
torch.profiler.ProfilerActivity.CUDA
]
with torch.profiler.profile(
activities=activities,
record_shapes=True,
# profile_memory=True,
with_stack=True,
) as prof:
if backward:
for x in inputs:
if isinstance(x, torch.Tensor):
x.grad = None
with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
out = fn(*inputs, **kwinputs)
if type(out) is tuple:
out = out[0]
if backward:
out.backward(g, retain_graph=True)
if verbose:
# print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=50))
print(prof.key_averages().table(row_limit=50))
if trace_filename is not None:
prof.export_chrome_trace(trace_filename)
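def _example_pytorch_profiler():
    # Illustrative usage sketch (hypothetical helper, not part of the original file).
    # Profile forward + backward of a matmul and export a Chrome trace; assumes a CUDA
    # device, and the trace filename is an arbitrary choice for this example.
    x = torch.randn(2048, 2048, device="cuda", requires_grad=True)
    w = torch.randn(2048, 2048, device="cuda", requires_grad=True)
    pytorch_profiler(
        torch.matmul, x, w, trace_filename="matmul_trace.json", backward=True, amp=True
    )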
def benchmark_memory(fn, *inputs, desc="", verbose=True, **kwinputs):
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
torch.cuda.synchronize()
fn(*inputs, **kwinputs)
torch.cuda.synchronize()
mem = torch.cuda.max_memory_allocated() / ((2**20) * 1000)
if verbose:
print(f"{desc} max memory: {mem}GB")
torch.cuda.empty_cache()
return mem
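def _example_benchmark_all_and_memory():
    # Illustrative usage sketch (hypothetical helper, not part of the original file).
    # Benchmark forward, backward, and combined passes of a small MLP, then measure its
    # peak GPU memory allocation; assumes a CUDA device.
    model = torch.nn.Sequential(
        torch.nn.Linear(1024, 4096), torch.nn.GELU(), torch.nn.Linear(4096, 1024)
    ).cuda()
    x = torch.randn(8, 512, 1024, device="cuda", requires_grad=True)
    benchmark_all(model, x, desc="mlp", amp=True)
    mem_gb = benchmark_memory(model, x, desc="mlp")
    return mem_gb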
# Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/tests/test_flash_attn.py
import pytest
import torch
from einops import repeat
from block_sparse_attn import (
block_sparse_attn_func,
)
from utils import (
generate_random_padding_mask,
generate_base_sparsity_mask,
generate_qkv,
generate_streaming_mask,
prepare_mixed_exact_mask,
prepare_mixed_mask,
convert_flash_attn_S_to_softmax,
normalize_flash_attn_S,
get_dropout_fraction,
attention_blocksparse_ref
)
MAX_HEADDIM_SM8x = 192
block_size = 128
is_sm75 = torch.cuda.get_device_capability("cuda") == (7, 5)
is_sm8x = torch.cuda.get_device_capability("cuda")[0] == 8
is_sm80 = torch.cuda.get_device_capability("cuda") == (8, 0)
is_sm90 = torch.cuda.get_device_capability("cuda") == (9, 0)
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
@pytest.mark.parametrize("d", [32, 64, 128])
@pytest.mark.parametrize(
"seqlen_q,seqlen_k",
[
(113, 203),
(128, 217),
(113, 211),
(108, 256),
(256, 512),
(512, 256),
(1024, 1024),
(1023, 1024),
(1024, 1023),
(2048, 2048),
],
)
@pytest.mark.parametrize(
"causal, exact_streaming, sink_num, local_num",
[
(True, True, 1, 3),
(True, True, 64, 256),
(True, False, 1, 3),
(False, False, 1, 3),
]
)
@pytest.mark.parametrize("p_dropout", [0.17, 0.0])
@pytest.mark.parametrize("sparsity", [0, 0.1, 0.3, 0.7, 1.0])
@pytest.mark.parametrize("batch_size", [1, 2])
@pytest.mark.parametrize("nheads", [16, 32])
def test_flash_attn_varlen_block_output(
seqlen_q, seqlen_k, d, p_dropout, causal, exact_streaming, sink_num, local_num, mha_type, dtype, sparsity, batch_size, nheads
):
if (
max(seqlen_q, seqlen_k) >= 2048
and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
):
pytest.skip() # Reference implementation OOM
device = "cuda:0"
# set seed
torch.random.manual_seed(42)
nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 8)
assert nheads % nheads_k == 0
window_size = (-1, -1)
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
k = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True)
v = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True)
query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
alibi_slopes, attn_bias = None, None
(
q_unpad,
k_unpad,
v_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q,
k,
v,
output_pad_fn,
dq_pad_fn,
dk_pad_fn,
) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
num_streaming_heads = nheads // 3
num_blocksparse_heads = nheads // 3
num_dense_heads = nheads - num_streaming_heads - num_blocksparse_heads
sparsity_list = [sparsity] * num_blocksparse_heads
head_mask_type = torch.tensor([0] * num_dense_heads + [1] * num_blocksparse_heads + [-1] * num_streaming_heads, device=device, dtype=torch.int32)
base_blockmask = generate_base_sparsity_mask(max_seqlen_q, max_seqlen_k, block_size, block_size, block_size, batch_size, num_blocksparse_heads, sparsity_list, causal = causal, device=device)
streaming_info = torch.tensor([sink_num, local_num] * nheads, device=device, dtype=torch.int32)
streaming_mask = generate_streaming_mask(max_seqlen_q, max_seqlen_k, batch_size, nheads, cu_seqlens_q, cu_seqlens_k, block_size, block_size, block_size, streaming_info, causal=causal, device=device)
if exact_streaming:
assert causal
print(f"exact_streaming: {exact_streaming}")
if exact_streaming:
mixed_mask = prepare_mixed_exact_mask(base_blockmask, streaming_info, head_mask_type, batch_size, nheads, block_size, block_size, block_size, max_seqlen_q, max_seqlen_k, q.shape[1], k.shape[1], query_padding_mask, key_padding_mask, device=device)
else:
mixed_mask = prepare_mixed_mask(base_blockmask, streaming_mask, head_mask_type, batch_size, nheads, block_size, block_size, block_size, max_seqlen_q, max_seqlen_k, q.shape[1], k.shape[1], device=device)
out_unpad, sm_lse, S_dmask = block_sparse_attn_func(
q_unpad, k_unpad, v_unpad,
cu_seqlens_q, cu_seqlens_k,
head_mask_type,
streaming_info,
base_blockmask,
max_seqlen_q, max_seqlen_k,
p_dropout,
deterministic=True,
softmax_scale=None,
is_causal=causal,
exact_streaming=exact_streaming,
return_attn_probs=True,
)
out = output_pad_fn(out_unpad)
if p_dropout > 0.0:
assert S_dmask is not None
S_dmask_converted = convert_flash_attn_S_to_softmax(
S_dmask,
seqlen_q,
seqlen_k,
query_padding_mask,
key_padding_mask,
d,
p_dropout > 0.0,
causal=causal,
window_size=window_size,
)
dropout_mask = S_dmask_converted >= 0
attn_unnorm = S_dmask_converted.abs()
k_rep = repeat(k, "b s h d -> b s (h g) d", g=nheads // nheads_k)
v_rep = repeat(v, "b s h d -> b s (h g) d", g=nheads // nheads_k)
attn = normalize_flash_attn_S(
attn_unnorm,
q,
k_rep,
v_rep,
query_padding_mask,
key_padding_mask,
attn_bias,
p_dropout > 0.0,
causal=causal,
window_size=window_size,
)
dropout_fraction = get_dropout_fraction(
dropout_mask,
mixed_mask,
block_size, block_size,
query_padding_mask,
key_padding_mask,
causal=causal,
window_size=window_size,
).item()
print(f"Actual dropout fraction: {dropout_fraction}")
else:
dropout_mask = None
out_ref, attn_ref = attention_blocksparse_ref(
q,
k,
v,
mixed_mask,
block_size, block_size,
query_padding_mask,
key_padding_mask,
p_dropout,
dropout_mask,
causal=causal,
window_size=window_size,
)
out_pt, attn_pt = attention_blocksparse_ref(
q,
k,
v,
mixed_mask,
block_size, block_size,
query_padding_mask,
key_padding_mask,
p_dropout,
dropout_mask,
causal=causal,
window_size=window_size,
upcast=False,
reorder_ops=True,
)
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
# Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/tests/test_flash_attn.py
import math
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from block_sparse_attn.bert_padding import pad_input, unpad_input
from block_sparse_attn.flash_attn_interface import _get_block_size
torch.set_printoptions(profile="full")
def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"):
assert mode in ["full", "random", "third"]
if mode == "full":
lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)
elif mode == "random":
lengths = torch.randint(
max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device
)
elif mode == "third":
lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device)
padding_mask = (
repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths
)
return padding_mask
def construct_local_mask(
seqlen_q,
seqlen_k,
window_size=(-1, -1), # -1 means infinite window size
query_padding_mask=None,
key_padding_mask=None,
device=None,
):
row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
sk = (
seqlen_k
if key_padding_mask is None
else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
)
sq = (
seqlen_q
if query_padding_mask is None
else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
)
if window_size[0] < 0:
return col_idx > row_idx + sk - sq + window_size[1]
else:
sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
return torch.logical_or(
col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),
col_idx < row_idx + sk - sq - window_size[0],
)
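def _example_construct_local_mask_causal():
    # Illustrative sketch (hypothetical helper, not part of the original file).
    # Reproduces the bottom-right-aligned causal mask from the flash_attn docstrings for
    # seqlen_q=2, seqlen_k=5; construct_local_mask returns True where a position is
    # masked OUT (causal corresponds to window_size=(-1, 0)), so the kept positions are:
    #   1 1 1 1 0
    #   1 1 1 1 1
    mask = construct_local_mask(2, 5, window_size=(-1, 0), device="cpu")
    return ~mask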
def construct_exact_streaming_mask(
seqlen_q,
seqlen_k,
sink_size, # -1 means infinite window size
local_size,
query_padding_mask=None,
key_padding_mask=None,
device=None,
):
assert sink_size >= 0
assert local_size >= 1
row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
sk = (
seqlen_k
if key_padding_mask is None
else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
)
sq = (
seqlen_q
if query_padding_mask is None
else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
)
sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
mask = torch.logical_or(
col_idx > torch.minimum(row_idx + sk - sq, sk),
torch.logical_and(
col_idx < row_idx + sk - sq - (local_size-1), col_idx >= sink_size,
)
)
return mask
def generate_qkv(
q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False
):
"""
Arguments:
q: (batch_size, seqlen_q, nheads, d)
k: (batch_size, seqlen_k, nheads_k, d)
v: (batch_size, seqlen_k, nheads_k, d)
query_padding_mask: (batch_size, seqlen), bool
key_padding_mask: (batch_size, seqlen), bool
"""
assert not (kvpacked and qkvpacked)
batch_size, seqlen_q, nheads, d = q.shape
_, seqlen_k, nheads_k, _ = k.shape
assert k.shape == (batch_size, seqlen_k, nheads_k, d)
assert v.shape == (batch_size, seqlen_k, nheads_k, d)
if query_padding_mask is not None:
q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, query_padding_mask)
output_pad_fn = lambda output_unpad: pad_input(
output_unpad, indices_q, batch_size, seqlen_q
)
else:
q_unpad = rearrange(q, "b s h d -> (b s) h d")
cu_seqlens_q = torch.arange(
0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device
)
max_seqlen_q = seqlen_q
output_pad_fn = lambda output_unpad: rearrange(
output_unpad, "(b s) h d -> b s h d", b=batch_size
)
if key_padding_mask is not None:
k_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
else:
k_unpad = rearrange(k, "b s h d -> (b s) h d")
v_unpad = rearrange(v, "b s h d -> (b s) h d")
cu_seqlens_k = torch.arange(
0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device
)
max_seqlen_k = seqlen_k
if qkvpacked:
assert (query_padding_mask == key_padding_mask).all()
assert nheads == nheads_k
qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
qkv = torch.stack([q, k, v], dim=2)
if query_padding_mask is not None:
dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q)
else:
dqkv_pad_fn = lambda dqkv_unpad: rearrange(
dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
)
return (
qkv_unpad.detach().requires_grad_(),
cu_seqlens_q,
max_seqlen_q,
qkv.detach().requires_grad_(),
output_pad_fn,
dqkv_pad_fn,
)
elif kvpacked:
kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
kv = torch.stack([k, v], dim=2)
dq_pad_fn = output_pad_fn
if key_padding_mask is not None:
dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k)
else:
dkv_pad_fn = lambda dkv_unpad: rearrange(
dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
)
return (
q_unpad.detach().requires_grad_(),
kv_unpad.detach().requires_grad_(),
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q.detach().requires_grad_(),
kv.detach().requires_grad_(),
output_pad_fn,
dq_pad_fn,
dkv_pad_fn,
)
else:
dq_pad_fn = output_pad_fn
if key_padding_mask is not None:
dk_pad_fn = lambda dk_unpad: pad_input(dk_unpad, indices_k, batch_size, seqlen_k)
else:
dk_pad_fn = lambda dk_unpad: rearrange(dk_unpad, "(b s) h d -> b s h d", b=batch_size)
return (
q_unpad.detach().requires_grad_(),
k_unpad.detach().requires_grad_(),
v_unpad.detach().requires_grad_(),
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q.detach().requires_grad_(),
k.detach().requires_grad_(),
v.detach().requires_grad_(),
output_pad_fn,
dq_pad_fn,
dk_pad_fn,
)
def generate_base_sparsity_mask(max_seqlen_q, max_seqlen_k, round_base, m_block_dim, n_block_dim, batch_size, num_blocksparse_heads, sparsity_list, causal=False, device="cuda"):
assert len(sparsity_list) == num_blocksparse_heads
def round_to_multiple(x, base):
return ((x + base - 1) // base) * base
nrow, ncol = round_to_multiple(max_seqlen_q, round_base) // m_block_dim, round_to_multiple(max_seqlen_k, round_base) // n_block_dim
base_mask = torch.zeros(batch_size, num_blocksparse_heads, nrow, ncol, device=device, dtype=torch.bool)
for batch in range(batch_size):
for head_rank in range(num_blocksparse_heads):
sparsity = sparsity_list[head_rank]
if not sparsity == 0.0 and not sparsity == 1.0:
for i in range(nrow):
idx = nrow - i - 1
if causal:
available_col_num = max(0, ncol - i)
num_one = max(1, int(sparsity * available_col_num))
base_mask[batch][head_rank][idx, torch.randperm(available_col_num)[:num_one]] = True
else:
available_col_num = ncol
num_one = max(1, int(sparsity * available_col_num))
base_mask[batch][head_rank][idx, torch.randperm(available_col_num)[:num_one]] = True
elif sparsity == 1.0:
base_mask[batch][head_rank] = torch.ones_like(base_mask[batch][head_rank])
return base_mask
def generate_streaming_mask(max_seqlen_q, max_seqlen_k, batch_size, num_heads, cu_q_len_list, cu_k_len_list, round_base, m_block_dim, n_block_dim, streaming_info, causal=False, device="cuda"):
assert len(streaming_info) == 2 * num_heads
assert len(cu_q_len_list) == batch_size + 1
assert len(cu_k_len_list) == batch_size + 1
assert round_base == m_block_dim == n_block_dim == 128
def round_to_multiple(x, base):
return ((x + base - 1) // base) * base
def ceil_div(x, y):
return (x + y - 1) // y
nrow, ncol = round_to_multiple(max_seqlen_q, round_base) // m_block_dim, round_to_multiple(max_seqlen_k, round_base) // n_block_dim
base_mask = torch.zeros(batch_size, num_heads, nrow, ncol, device=device, dtype=torch.bool)
for batch in range(batch_size):
actual_q = cu_q_len_list[batch + 1] - cu_q_len_list[batch]
actual_k = cu_k_len_list[batch + 1] - cu_k_len_list[batch]
start_row_idx = max((actual_q - actual_k) // m_block_dim, 0) if causal else 0
for head_rank in range(num_heads):
sink_block_num, local_block_num = streaming_info[head_rank * 2], streaming_info[head_rank * 2 + 1]
for i in range(start_row_idx, nrow):
if causal:
max_row_block_num = ceil_div(max(actual_k - actual_q, 0), n_block_dim) + 1 + i - start_row_idx
else:
max_row_block_num = ncol
base_mask[batch, head_rank, i, min(max(max_row_block_num - local_block_num, 0), ncol):min(max_row_block_num, ncol)] = True
base_mask[batch, head_rank, i, :sink_block_num] = True
return base_mask
def replace_ones_with_count(tensor):
ones_mask = tensor == 1
count = torch.cumsum(ones_mask, dim=-1).to(tensor.dtype)
count = count * ones_mask
tensor = tensor.masked_scatter(ones_mask, count[ones_mask])
return tensor
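def _example_replace_ones_with_count():
    # Illustrative sketch (hypothetical helper, not part of the original file).
    # In the tests, head_mask_type is built with 0 for dense heads, 1 for block-sparse
    # heads, and -1 for streaming heads; this helper turns each 1 into its running count
    # so that prepare_mixed_mask (below) can use it as an index into the per-head
    # block-sparse masks.
    head_mask_type = torch.tensor([0, 0, 1, 1, -1, -1], dtype=torch.int32)
    out = replace_ones_with_count(head_mask_type)
    # out == tensor([0, 0, 1, 2, -1, -1], dtype=torch.int32)
    return out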
def prepare_mixed_mask(base_blocksparse_mask, base_streaming_mask, head_mask_type, batch_size, num_heads, round_base, m_block_dim, n_block_dim, max_seqlen_q, max_seqlen_k, actual_seqlen_q, actual_seqlen_k, device="cuda"):
def round_to_multiple(x, base):
return ((x + base - 1) // base) * base
nrow, ncol = round_to_multiple(max_seqlen_q, round_base) // m_block_dim, round_to_multiple(max_seqlen_k, round_base) // n_block_dim
mixed_mask = torch.zeros(batch_size, num_heads, nrow, ncol, device=device, dtype=torch.bool)
head_mask_type = replace_ones_with_count(head_mask_type)
for head_rank in range(num_heads):
mask_type = head_mask_type[head_rank]
if mask_type == 0:
mixed_mask[:, head_rank, :, :] = torch.ones_like(mixed_mask[:, head_rank, :, :])
elif mask_type > 0:
for i in range(batch_size):
mixed_mask[i, head_rank, :, :] = base_blocksparse_mask[i][mask_type - 1]
else:
for i in range(batch_size):
mixed_mask[i, head_rank, :, :] = base_streaming_mask[i, head_rank, :, :]
mixed_mask = repeat(mixed_mask, "b h s_m s_n -> b h (s_m d_m) (s_n d_n)", d_m=m_block_dim, d_n=n_block_dim)
mixed_mask = tailor_mixedmask_for_test(mixed_mask, actual_seqlen_q, actual_seqlen_k)
mixed_mask = ~mixed_mask
return mixed_mask
def prepare_mixed_exact_mask(base_blocksparse_mask, streaming_info, head_mask_type, batch_size, num_heads, round_base, m_block_dim, n_block_dim, max_seqlen_q, max_seqlen_k, actual_seqlen_q, actual_seqlen_k, query_padding_mask,
key_padding_mask, device="cuda"):
def round_to_multiple(x, base):
return ((x + base - 1) // base) * base
nrow, ncol = round_to_multiple(max_seqlen_q, round_base) // m_block_dim, round_to_multiple(max_seqlen_k, round_base) // n_block_dim
mixed_mask = torch.zeros(batch_size, num_heads, nrow, ncol, device=device, dtype=torch.bool)
head_mask_type = replace_ones_with_count(head_mask_type)
for head_rank in range(num_heads):
mask_type = head_mask_type[head_rank]
if mask_type == 0:
mixed_mask[:, head_rank, :, :] = torch.ones_like(mixed_mask[:, head_rank, :, :])
elif mask_type > 0:
for i in range(batch_size):
mixed_mask[i, head_rank, :, :] = base_blocksparse_mask[i][mask_type - 1]
mixed_mask = repeat(mixed_mask, "b h s_m s_n -> b h (s_m d_m) (s_n d_n)", d_m=m_block_dim, d_n=n_block_dim)
mixed_mask = tailor_mixedmask_for_test(mixed_mask, actual_seqlen_q, actual_seqlen_k)
mixed_mask = ~mixed_mask
for head_rank in range(num_heads):
mask_type = head_mask_type[head_rank]
if mask_type < 0:
exact_streaming_mask = construct_exact_streaming_mask(
actual_seqlen_q,
actual_seqlen_k,
streaming_info[head_rank * 2],
streaming_info[head_rank * 2 + 1],
query_padding_mask,
key_padding_mask,
device=device,
)
if exact_streaming_mask.dim() == 4:
for i in range(batch_size):
mixed_mask[i, head_rank, :, :] = exact_streaming_mask[i,0,:,:]
else:
for i in range(batch_size):
mixed_mask[i, head_rank, :, :] = exact_streaming_mask
return mixed_mask
def attention_blocksparse_ref(
q, k, v,
mixed_mask,
m_block_dim, n_block_dim,
query_padding_mask=None,
key_padding_mask=None,
p_dropout=0.0,
dropout_mask=None,
causal=False,
window_size=(-1, -1),
upcast=True,
reorder_ops=False,
):
# q, k, v = qkv.float().unbind(dim=2)
if causal:
window_size = (window_size[0], 0)
dtype_og = q.dtype
if upcast:
q, k, v = q.float(), k.float(), v.float()
seqlen_q, seqlen_k = q.shape[1], k.shape[1]
k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
d = q.shape[-1]
if not reorder_ops:
scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
else:
scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d))
if key_padding_mask is not None:
scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
# local mask
if window_size[0] >= 0 or window_size[1] >= 0:
local_mask = construct_local_mask(
seqlen_q,
seqlen_k,
window_size,
query_padding_mask,
key_padding_mask,
q.device,
)
scores.masked_fill_(local_mask, float("-inf"))
scores.masked_fill_(rearrange(mixed_mask, "b h t s -> b h t s"), float("-inf"))
# print("processed blockmask: ", rearrange(~base_blockmask, "h t s -> 1 h t s"))
attention = torch.softmax(scores, dim=-1).to(v.dtype)
if window_size[0] >= 0 or window_size[1] >= 0:
attention = attention.masked_fill(torch.all(torch.bitwise_or(local_mask, rearrange(mixed_mask, "b h t s -> b h t s")), dim=-1, keepdim=True), 0.0)
attention = attention.masked_fill(rearrange(mixed_mask, "b h t s -> b h t s"), 0.0)
if query_padding_mask is not None:
attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
dropout_scaling = 1.0 / (1 - p_dropout)
if dropout_mask is not None:
attention_drop = attention.masked_fill(~dropout_mask, 0.0)
else:
attention_drop = attention
output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
if query_padding_mask is not None:
output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
def convert_flash_attn_S_to_softmax(
S,
seqlen_q,
seqlen_k,
query_padding_mask,
key_padding_mask,
head_dim,
is_dropout,
causal=False,
window_size=(-1, -1), # -1 means infinite window size
):
"""FlashAttention stores the S matrix in a different way.
Arguments:
S: (batch_size, nheads, seqlen_q_rounded, seqlen_k_rounded)
query_padding_mask: (batch_size, seqlen_q_rounded)
key_padding_mask: (batch_size, seqlen_k_rounded)
"""
if causal:
window_size = (window_size[0], 0)
seqlen_q_rounded, seqlen_k_rounded = S.shape[-2:]
warps_n = 4
blocksize_m, blocksize_n = _get_block_size(S.device, head_dim, is_dropout, causal)
nblocks_n = (seqlen_k_rounded + blocksize_n - 1) // blocksize_n
nblocks_m = (seqlen_q_rounded + blocksize_m - 1) // blocksize_m
mmas_n = (blocksize_n + 16 - 1) // 16
S_flat = rearrange(
S,
"b h (nblocks_m blocksize_m) (nblocks_n blocksize_n) -> b h nblocks_m nblocks_n (blocksize_m blocksize_n)",
blocksize_m=blocksize_m,
blocksize_n=blocksize_n,
)
S_converted = rearrange(
S_flat,
"b h nblocks_m nblocks_n (mmas_n mmas_m warps_n eight four c2 c1 c0) -> b h (nblocks_m mmas_m warps_n c1 eight) (nblocks_n mmas_n c2 four c0)",
mmas_n=mmas_n,
warps_n=warps_n,
eight=8,
c0=2,
c1=2,
c2=2,
four=4,
)
if window_size[0] >= 0 or window_size[1] >= 0:
local_mask = construct_local_mask(
seqlen_q,
seqlen_k,
window_size,
query_padding_mask,
key_padding_mask,
S.device,
)
local_mask = F.pad(
local_mask,
(0, seqlen_k_rounded - seqlen_k, 0, seqlen_q_rounded - seqlen_q),
value=True,
)
S_converted.masked_fill_(local_mask, 0.0)
# Need to zero out things not in attention_mask in case S was initialized with random values
# and some of those values aren't overwritten.
seqlen_q_og = (
query_padding_mask.shape[-1] if query_padding_mask is not None else seqlen_q_rounded
)
if query_padding_mask is not None:
query_padding_mask = F.pad(query_padding_mask, (0, seqlen_q_rounded - seqlen_q_og))
S_converted = S_converted.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
seqlen_k_og = key_padding_mask.shape[-1] if key_padding_mask is not None else seqlen_k
if key_padding_mask is not None:
key_padding_mask = F.pad(key_padding_mask, (0, seqlen_k_rounded - seqlen_k_og))
S_converted = S_converted.masked_fill(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0)
S_converted = F.pad(S_converted, (0, 0, 0, seqlen_q_og - seqlen_q_rounded))
S_converted = F.pad(S_converted, (0, seqlen_k_og - seqlen_k_rounded))
return S_converted[:, :, :seqlen_q, :seqlen_k]
def normalize_flash_attn_S(
attn_unnorm,
q,
k,
v,
query_padding_mask=None,
key_padding_mask=None,
attn_bias=None,
is_dropout=False,
causal=False,
window_size=(-1, -1), # -1 means infinite window size
):
"""
Arguments:
q: (batch_size, seqlen_q, nheads, head_dim)
k, v: (batch_size, seqlen_k, nheads, head_dim)
key_padding_mask: (batch_size, seqlen_q)
attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
Output:
softmax_lse: (batch_size, nheads, seqlen_q)
softmax_max: (batch_size, nheads, seqlen_q)
"""
if causal:
window_size = (window_size[0], 0)
q, k, v = q.float(), k.float(), v.float()
_, seqlen_q, _, head_dim = q.shape
seqlen_k = k.shape[1]
scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(head_dim), k)
if key_padding_mask is not None:
scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
if window_size[0] >= 0 or window_size[1] >= 0:
local_mask = construct_local_mask(
seqlen_q,
seqlen_k,
window_size,
query_padding_mask,
key_padding_mask,
q.device,
)
scores.masked_fill_(local_mask, float("-inf"))
if attn_bias is not None:
scores = scores + attn_bias.to(dtype=scores.dtype)
_, block_size_n = _get_block_size(scores.device, head_dim, is_dropout, causal)
scores_block = scores.split(block_size_n, dim=-1)
lse_block = torch.stack([torch.logsumexp(s, dim=-1) for s in scores_block], dim=-1)
lse = torch.logsumexp(lse_block, dim=-1)
# lse could be -inf (i.e. all values in scores are -inf), and we want to set those to inf
# so that when we do torch.exp(m - lse), we get 0.0 instead of NaN.
lse[lse == float("-inf")] = float("inf")
scores_max_block = torch.stack([torch.amax(s, dim=-1) for s in scores_block], dim=-1)
cummax_block = torch.cummax(scores_max_block.flip(-1), dim=-1).values.flip(-1).unbind(dim=-1)
attn_unnorm_block = attn_unnorm.split(block_size_n, dim=-1)
attn_norm = torch.cat(
[
a * rearrange(torch.exp(m - lse), "b h s -> b h s 1")
for a, m in zip(attn_unnorm_block, cummax_block)
],
dim=-1,
)
if query_padding_mask is not None:
attn_norm.masked_fill_(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
return attn_norm.to(dtype=attn_unnorm.dtype)
def get_dropout_fraction(
dropout_mask,
mixed_mask,
m_block_dim, n_block_dim,
query_padding_mask=None,
key_padding_mask=None,
causal=False,
window_size=(-1, -1), # -1 means infinite window size
):
"""
dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k), bool. True means keep, False means drop.
query_padding_mask: (batch_size, seqlen_q)
key_padding_mask: (batch_size, seqlen_k)
"""
if causal:
window_size = (window_size[0], 0)
batch_size, nheads, seqlen_q, seqlen_k = dropout_mask.shape
dropped = ~dropout_mask
valid = torch.ones_like(dropout_mask)
if mixed_mask is not None:
# mixed_mask = repeat(mixed_mask, "b h s_m s_n -> b h (s_m d_m) (s_n d_n)", d_m=m_block_dim, d_n=n_block_dim)
# mixed_mask = tailor_mixedmask_for_test(mixed_mask, seqlen_q, seqlen_k)
dropped.masked_fill_(rearrange(mixed_mask, "b h t s -> b h t s"), False)
valid.masked_fill_(rearrange(mixed_mask, "b h t s -> b h t s"), False)
if query_padding_mask is not None:
dropped.masked_fill_(rearrange(~query_padding_mask, "b s -> b 1 s 1"), False)
valid.masked_fill_(rearrange(~query_padding_mask, "b s -> b 1 s 1"), False)
if key_padding_mask is not None:
dropped.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), False)
valid.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), False)
if window_size[0] >= 0 or window_size[1] >= 0:
local_mask = construct_local_mask(
seqlen_q,
seqlen_k,
window_size,
query_padding_mask,
key_padding_mask,
dropout_mask.device,
)
dropped.masked_fill_(local_mask, False)
valid.masked_fill_(local_mask, False)
    dropped_total = dropped.sum()
    return dropped_total / valid.sum()
def modified_check(a, a_ref, a_pt, rel_parameter):
assert a.shape == a_ref.shape
assert a.shape == a_pt.shape
left = (a - a_ref).abs().max().item()
right = (a_pt - a_ref).abs().max().item()
rtol = 1e-3
    if right != 0:
        assert left < rel_parameter * right or left < rtol * a_ref.abs().max().item()
else:
assert round(left, 4) == 0 or left < rtol * a_ref.abs().max().item()
def tailor_mixedmask_for_test(spanded_base_mixedmask, seqlen_q, seqlen_k):
batch_size = spanded_base_mixedmask.shape[0]
nheads = spanded_base_mixedmask.shape[1]
spanded_base_mixedmask = spanded_base_mixedmask[:, :, :seqlen_q, :seqlen_k]
pad_blockmask = torch.zeros(batch_size, nheads, seqlen_q, seqlen_k, dtype=torch.bool, device = spanded_base_mixedmask.device)
pad_blockmask[:, :, :spanded_base_mixedmask.shape[2], :spanded_base_mixedmask.shape[3]] = spanded_base_mixedmask
spanded_base_mixedmask = pad_blockmask
spanded_base_mixedmask = spanded_base_mixedmask.contiguous()
return spanded_base_mixedmask