Commit 4f83cf8f authored by Junxian

[release] v0.0.1
# This workflow will:
# - Create a new Github release
# - Build wheels for supported architectures
# - Deploy the wheels to the Github release
# - Release the static code to PyPI
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
name: Build wheels and deploy
on:
create:
tags:
- v*
jobs:
setup_release:
name: Create Release
runs-on: ubuntu-latest
steps:
- name: Get the tag version
id: extract_branch
run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/}
shell: bash
- name: Create Release
id: create_release
uses: actions/create-release@v1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
tag_name: ${{ steps.extract_branch.outputs.branch }}
release_name: ${{ steps.extract_branch.outputs.branch }}
build_wheels:
name: Build Wheel
needs: setup_release
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
# Using ubuntu-20.04 instead of 22.04 for more compatibility (glibc). Ideally we'd use the
# manylinux docker image, but I haven't figured out how to install CUDA on manylinux.
os: [ubuntu-20.04]
python-version: ['3.7', '3.8', '3.9', '3.10', '3.11']
torch-version: ['1.12.1', '1.13.1', '2.0.1', '2.1.1', '2.2.0.dev20231106']
cuda-version: ['11.8.0', '12.2.0']
# We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not.
# Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI.
# Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs)
# when building without C++11 ABI and using it on nvcr images.
cxx11_abi: ['FALSE', 'TRUE']
exclude:
# Pytorch <= 1.12 does not support Python 3.11
- torch-version: '1.12.1'
python-version: '3.11'
# Pytorch >= 2.0 only supports Python >= 3.8
- torch-version: '2.0.1'
python-version: '3.7'
- torch-version: '2.1.1'
python-version: '3.7'
- torch-version: '2.2.0.dev20231106'
python-version: '3.7'
# Pytorch <= 2.0 only supports CUDA <= 11.8
- torch-version: '1.12.1'
cuda-version: '12.2.0'
- torch-version: '1.13.1'
cuda-version: '12.2.0'
- torch-version: '2.0.1'
cuda-version: '12.2.0'
steps:
- name: Checkout
uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Set CUDA and PyTorch versions
run: |
echo "MATRIX_CUDA_VERSION=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV
echo "MATRIX_TORCH_VERSION=$(echo ${{ matrix.torch-version }} | awk -F \. {'print $1 "." $2'})" >> $GITHUB_ENV
- name: Free up disk space
if: ${{ runner.os == 'Linux' }}
# https://github.com/easimon/maximize-build-space/blob/master/action.yml
# https://github.com/easimon/maximize-build-space/tree/test-report
run: |
sudo rm -rf /usr/share/dotnet
sudo rm -rf /opt/ghc
sudo rm -rf /opt/hostedtoolcache/CodeQL
- name: Set up swap space
if: runner.os == 'Linux'
uses: pierotofy/set-swap-space@v1.0
with:
swap-size-gb: 10
- name: Install CUDA ${{ matrix.cuda-version }}
if: ${{ matrix.cuda-version != 'cpu' }}
uses: Jimver/cuda-toolkit@v0.2.11
id: cuda-toolkit
with:
cuda: ${{ matrix.cuda-version }}
linux-local-args: '["--toolkit"]'
# default method is "local", and we're hitting some error with caching for CUDA 11.8 and 12.1
# method: ${{ (matrix.cuda-version == '11.8.0' || matrix.cuda-version == '12.1.0') && 'network' || 'local' }}
method: 'network'
# We need the cuda libraries (e.g. cuSparse, cuSolver) for compiling PyTorch extensions,
# not just nvcc
# sub-packages: '["nvcc"]'
- name: Install PyTorch ${{ matrix.torch-version }}+cu${{ matrix.cuda-version }}
run: |
pip install --upgrade pip
# If we don't install lit before installing Pytorch, we get an error for torch 2.0.1
# ERROR: Could not find a version that satisfies the requirement setuptools>=40.8.0 (from versions: none)
pip install lit
# We want to figure out the CUDA version to download pytorch
# e.g. we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116
# This code is ugly, maybe there's a better way to do this.
export TORCH_CUDA_VERSION=$(python -c "import os; minv = {'1.12': 113, '1.13': 116, '2.0': 117, '2.1': 118, '2.2': 118}[os.environ['MATRIX_TORCH_VERSION']]; maxv = {'1.12': 116, '1.13': 117, '2.0': 118, '2.1': 121, '2.2': 121}[os.environ['MATRIX_TORCH_VERSION']]; print(max(min(int(os.environ['MATRIX_CUDA_VERSION']), maxv), minv))")
if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then
pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}
else
pip install --no-cache-dir torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION}
fi
nvcc --version
python --version
python -c "import torch; print('PyTorch:', torch.__version__)"
python -c "import torch; print('CUDA:', torch.version.cuda)"
python -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)"
shell:
bash
- name: Build wheel
run: |
# We want setuptools >= 49.6.0 otherwise we can't compile the extension if system CUDA version is 11.7 and pytorch cuda version is 11.6
# https://github.com/pytorch/pytorch/blob/664058fa83f1d8eede5d66418abff6e20bd76ca8/torch/utils/cpp_extension.py#L810
# However this still fails so I'm using a newer version of setuptools
pip install setuptools==68.0.0
pip install ninja packaging wheel
export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH
export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
# Limit MAX_JOBS otherwise the github runner goes OOM
MAX_JOBS=2 FLASH_ATTENTION_FORCE_BUILD="TRUE" FLASH_ATTENTION_FORCE_CXX11_ABI=${{ matrix.cxx11_abi}} python setup.py bdist_wheel --dist-dir=dist
tmpname=cu${MATRIX_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ matrix.cxx11_abi }}
wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2")
ls dist/*whl |xargs -I {} mv {} dist/${wheel_name}
echo "wheel_name=${wheel_name}" >> $GITHUB_ENV
- name: Log Built Wheels
run: |
ls dist
- name: Get the tag version
id: extract_branch
run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/}
- name: Get Release with tag
id: get_current_release
uses: joutvhu/get-release@v1
with:
tag_name: ${{ steps.extract_branch.outputs.branch }}
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: Upload Release Asset
id: upload_release_asset
uses: actions/upload-release-asset@v1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
upload_url: ${{ steps.get_current_release.outputs.upload_url }}
asset_path: ./dist/${{env.wheel_name}}
asset_name: ${{env.wheel_name}}
asset_content_type: application/*
publish_package:
name: Publish package
needs: [build_wheels]
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: '3.10'
- name: Install dependencies
run: |
pip install ninja packaging setuptools wheel twine
# We don't want to download anything CUDA-related here
pip install torch --index-url https://download.pytorch.org/whl/cpu
- name: Build core package
env:
FLASH_ATTENTION_SKIP_CUDA_BUILD: "TRUE"
run: |
python setup.py sdist --dist-dir=dist
- name: Deploy
env:
TWINE_USERNAME: "__token__"
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
run: |
python -m twine upload dist/*
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
# C extensions
*.so
# Distribution / packaging
bin/
build/
develop-eggs/
dist/
eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg
# IDE-related
.idea/
# Dev
venv
block_sparse_tests/test_correctness/bwd_test/log/*
block_sparse_tests/test_correctness/fwd_test/log/*
block_sparse_tests/test_correctness/full_test/log/*
block_sparse_tests/test_correctness/bwd_test/log/
block_sparse_tests/test_correctness/fwd_test/log/
block_sparse_tests/test_correctness/full_test/log/
block_sparse_tests/test_correctness/full_test/log3.txt
block_sparse_tests/test_correctness/full_test/test
block_sparse_tests/test_correctness/full_test/test.cpp
block_sparse_tests/test_correctness/full_test/tmp.txt
block_sparse_tests/test_correctness/full_test/tmp2.txt
block_sparse_tests/test_correctness/full_test/arxiv/block_bwd.cu
block_sparse_tests/test_correctness/full_test/arxiv/block.cu
block_sparse_tests/test_correctness/full_test/arxiv/mask_compress.ipynb
block_sparse_tests/test_correctness/full_test/arxiv/tmp.py
block_sparse_tests/test_performance/fwd_test/log_0013.txt
block_sparse_tests/test_performance/fwd_test/log_split_mask_process_all_sparse_a11.txt
block_sparse_tests/test_performance/fwd_test/log_split_mask_process_all_sparse_l21.txt
block_sparse_tests/test_performance/fwd_test/log_split_mask_process_l21.txt
block_sparse_tests/test_performance/fwd_test/log_using_olog_all_dense_l21.txt
block_sparse_tests/test_performance/fwd_test/log_using_olog_l21.txt
block_sparse_tests/test_performance/fwd_test/fig/hdim128_nheads32_bts1_fwd.png
block_sparse_tests/test_performance/fwd_test/fig/streaming/hdim128_nheads32_bts1_sink1_local3_fwd.png
block_sparse_tests/test_performance/fwd_test/old_fig/hdim128_nheads32_bts8_fwd.png
block_sparse_tests/test_performance/fwd_test/old_fig/hdim128_nheads32_bts1_fwd.png
*.xlsx
[submodule "csrc/cutlass"]
path = csrc/cutlass
url = https://github.com/NVIDIA/cutlass.git
BSD 3-Clause License
Copyright (c) 2024, MIT HAN Lab
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
recursive-include csrc *.cu
recursive-include csrc *.h
recursive-include csrc *.cuh
recursive-include csrc *.cpp
recursive-include csrc *.hpp
recursive-include block_flash_attn *.cu
recursive-include block_flash_attn *.h
recursive-include block_flash_attn *.cuh
recursive-include block_flash_attn *.cpp
recursive-include block_flash_attn *.hpp
clean_dist:
rm -rf dist/*
create_dist: clean_dist
python setup.py sdist
upload_package: create_dist
twine upload dist/*
# Block Sparse Attention
As prompt lengths continue to increase, the computational and memory-bandwidth demands of Large Language Models (LLMs) grow significantly, making efficient processing more challenging. However, by fully leveraging the inherent sparsity in attention patterns, we can optimize model performance and effectively reduce the computational cost of inference. This approach not only enhances the efficiency of LLMs but also enables them to handle longer and more complex prompts without a proportional increase in resource consumption. To this end, we introduce Block Sparse Attention, a library of sparse attention kernels that supports various sparse patterns, including streaming attention with token granularity, streaming attention with block granularity, and block-sparse attention. By incorporating these patterns, Block Sparse Attention can significantly reduce the computational costs of LLMs, thereby enhancing their efficiency and scalability.
We release the implementation of Block Sparse Attention, which is modified based on [FlashAttention](https://github.com/Dao-AILab/flash-attention) 2.4.2.
![Sparse Patterns](assets/BlockSparseMaskDemo.jpeg)
## News
- [2024/10] We release both fwd pass and bwd pass of Block Sparse Attention.
## Features
Block Sparse Attention supports four patterns:
1. dense attention
Calculate the full attention matrix.
2. streaming attention with token granularity
Calculate the attention with a fixed number of sink tokens and local tokens. You can refer to [StreamingLLM](https://arxiv.org/abs/2309.17453) for more details.
3. streaming attention with block granularity, block_size = 128
Calculate the attention with a fixed number of sink blocks and local blocks.
4. blocksparse attention, block_size = 128
Take in a block mask and calculate the attention with the block mask.
**Importantly, we support assigning different patterns for different heads.**
You can use `head_mask_type` to specify the pattern for each head. It is a list of integers with one entry per query head.
For each head, `mask_type = 0` means dense attention, `mask_type = -1` means streaming attention (either block streaming or exact streaming), and `mask_type = 1` means blocksparse attention, in which case the head uses `basemask[mask_type - 1]` as its attention mask.
For example, if you have 8 heads and
```python
head_mask_type = [1, 1, 0, 0, 0, -1, 0, -1]
```
This means head 0 and head 1 use the blocksparse mask, heads 2-4 and head 6 use the dense mask, and heads 5 and 7 use the streaming mask.
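In tensor form, the same example might be set up as follows. This is a minimal sketch: the `int32` dtype and the `[sink, local]` per-head layout of `streaming_info` follow the correctness tests under `block_sparse_tests`, while the exact `base_blockmask` shape shown here is an assumption, not a documented contract.
```python
import torch

device = "cuda"
nheads, batch, seqlen, block_size = 8, 1, 1024, 128

# Per-head pattern selection: 0 = dense, -1 = streaming, 1 = blocksparse.
head_mask_type = torch.tensor([1, 1, 0, 0, 0, -1, 0, -1],
                              device=device, dtype=torch.int32)

# One (sink, local) pair per head; layout taken from the correctness tests.
sink_num, local_num = 1, 3
streaming_info = torch.tensor([sink_num, local_num] * nheads,
                              device=device, dtype=torch.int32)

# Block mask for the blocksparse heads only. The shape below
# (batch, num_blocksparse_heads, q_blocks, k_blocks) is an assumption.
num_blocksparse_heads = int((head_mask_type == 1).sum())
nblocks = (seqlen + block_size - 1) // block_size
base_blockmask = torch.ones(batch, num_blocksparse_heads, nblocks, nblocks,
                            device=device, dtype=torch.bool)
```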
The interfaces are:
```python
from block_sparse_attn import block_sparse_attn_func
block_sparse_attn_func(
q_unpad, k_unpad, v_unpad,
cu_seqlens_q, cu_seqlens_k,
head_mask_type,
streaming_info,
base_blockmask,
max_seqlen_q_, max_seqlen_k_,
p_dropout,
deterministic=False,
softmax_scale=None,
is_causal=False,
exact_streaming=False,
return_attn_probs=False,
)
```
```python
from block_sparse_attn import block_streaming_attn_func
block_streaming_attn_func(
q_unpad, k_unpad, v_unpad,
cu_seqlens_q, cu_seqlens_k,
head_mask_type,
streaming_info,
max_seqlen_q, max_seqlen_k,
p_dropout,
deterministic=False,
softmax_scale=None,
is_causal=True,
return_attn_probs=False,
)
```
```python
from block_sparse_attn import token_streaming_attn_func
# bwd pass is not yet supported
token_streaming_attn_func(
q_unpad, k_unpad, v_unpad,
cu_seqlens_q, cu_seqlens_k,
head_mask_type,
streaming_info,
max_seqlen_q, max_seqlen_k,
deterministic=False,
softmax_scale=None,
return_attn_probs=False,
)
```
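Putting the pieces together, a minimal single-sequence forward call in the varlen (unpadded) layout might look like the sketch below. It requires a CUDA build of the package, and the `(total_tokens, nheads, head_dim)` input layout, the `cu_seqlens` construction, and the `base_blockmask` shape are assumptions inferred from the correctness tests rather than a documented contract.
```python
import torch
from block_sparse_attn import block_sparse_attn_func

device, dtype = "cuda", torch.float16
batch, seqlen, nheads, head_dim, block_size = 1, 1024, 8, 128, 128

# Unpadded ("varlen") inputs: all sequences concatenated along the first dimension.
q = torch.randn(batch * seqlen, nheads, head_dim, device=device, dtype=dtype)
k = torch.randn(batch * seqlen, nheads, head_dim, device=device, dtype=dtype)
v = torch.randn(batch * seqlen, nheads, head_dim, device=device, dtype=dtype)
cu_seqlens = torch.arange(0, (batch + 1) * seqlen, seqlen, device=device, dtype=torch.int32)

# Per-head patterns and auxiliary inputs as sketched above.
head_mask_type = torch.tensor([1, 1, 0, 0, 0, -1, 0, -1], device=device, dtype=torch.int32)
streaming_info = torch.tensor([1, 3] * nheads, device=device, dtype=torch.int32)
nblocks = (seqlen + block_size - 1) // block_size
base_blockmask = torch.ones(batch, int((head_mask_type == 1).sum()), nblocks, nblocks,
                            device=device, dtype=torch.bool)

out = block_sparse_attn_func(
    q, k, v,
    cu_seqlens, cu_seqlens,
    head_mask_type,
    streaming_info,
    base_blockmask,
    seqlen, seqlen,
    p_dropout=0.0,
    deterministic=False,
    softmax_scale=None,
    is_causal=True,
    exact_streaming=False,
    return_attn_probs=False,
)
print(out.shape)  # should match q's unpadded shape: (batch * seqlen, nheads, head_dim)
```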
## Performance
### Block Sparse Speedup
<div align=center><img src="assets/BlocksparseSpeedUp.jpeg"></div>
<div align=center><img src="assets/BlocksparseSpeedUpFwdBwd.jpeg"></div>
The figures above illustrate the speedup gained by using Block Sparse Attention compared to dense FlashAttention-2 v2.4.2. The speedup was measured on an A100 GPU with a head dimension of 128 and 32 attention heads.
### Dense & Streaming Hybrid Speedup
[Duo Attention](https://github.com/mit-han-lab/duo-attention) introduces a hybrid mask scenario in which half of the attention heads use a dense mask and the other half use a streaming mask. This pattern has also been shown to be an accurate approach for LLM inference.
<div align=center><img src="assets/StreamingHybridSpeedUpRatio.jpeg"></div>
The graph above demonstrates the performance of our kernel for this specified workload. For token-level streaming masks, we allocate 64 sink tokens and 256 local tokens. For block-level streaming masks, we allocate 1 sink block and 3 local blocks, with each block consisting of 128 tokens. Speedup results were measured on an A100 GPU, using dense FlashAttention2 as the baseline, with a head dimension of 128, 32 attention heads, and a batch size of 1.
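For illustration, here is a hedged sketch of the per-head configuration behind this hybrid measurement (half dense, half token-granularity streaming with 64 sink and 256 local tokens); the tensor layouts are taken from the correctness tests, not from a documented API.
```python
import torch

# Duo-Attention-style hybrid: half of the heads dense, half token-granularity streaming
# with 64 sink tokens and 256 local tokens (the setting measured above).
nheads, device = 32, "cuda"
head_mask_type = torch.tensor([0] * (nheads // 2) + [-1] * (nheads // 2),
                              device=device, dtype=torch.int32)
streaming_info = torch.tensor([64, 256] * nheads, device=device, dtype=torch.int32)
# head_mask_type and streaming_info are then passed to token_streaming_attn_func
# (forward only) together with the unpadded q/k/v and cu_seqlens, as shown earlier.
```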
## Installation
Requirements:
- CUDA 11.6 and above.
- PyTorch 1.12 and above.
- Linux.
```sh
pip install packaging
pip install ninja
python setup.py install
```
Block Sparse Interface: `block_sparse_attn/block_sparse_attn_interface.py`
Block Sparse Attention currently supports:
1. Datatypes fp16 and bf16 (bf16 requires Ampere, Ada, or Hopper GPUs).
2. Head dimension 32, 64, 128.
### Tests
To run the correctness tests:
```sh
pip install pytest
```
- For fwd only
```sh
cd ./block_sparse_tests/fwd/test_correctness
pytest full_test.py
```
- For fwd and bwd
```sh
cd ./block_sparse_tests/fwd_bwd/test_correctness
pytest full_test.py
```
To run the performance tests:
- For fwd only
```sh
cd ./block_sparse_tests/fwd/test_performance/
python token_streaming.py
python blocksparse.py
```
- For fwd and bwd
```sh
cd ./block_sparse_tests/fwd_bwd/test_performance/
python block_streaming.py
python blocksparse.py
```
## Team
- Junxian Guo, developer
## Acknowledgement
- [FlashAttention](https://github.com/Dao-AILab/flash-attention): the codebase we built upon. Thanks for their wonderful work. The design of block sparse attention in FlashAttention v1.0 is very inspiring.
- [FlashAttention](https://arxiv.org/abs/2205.14135), [FlashAttention-2](https://arxiv.org/abs/2307.08691), [Big Bird](https://arxiv.org/abs/2007.14062), [ETC](https://arxiv.org/abs/2004.08483): where we got the idea of block sparse attention and how it can be implemented.
- [StreamingLLM](https://arxiv.org/abs/2309.17453): where we got the idea of streaming attention.
- [Duo Attention](https://github.com/mit-han-lab/duo-attention), [MInference 1.0](https://arxiv.org/abs/2407.02490): where we got the idea of hybrid masks.
## Citation
__version__ = "0.0.1"
from block_sparse_attn.flash_attn_interface import (
flash_attn_func,
flash_attn_kvpacked_func,
flash_attn_qkvpacked_func,
flash_attn_varlen_func,
flash_attn_varlen_kvpacked_func,
flash_attn_varlen_qkvpacked_func,
flash_attn_with_kvcache,
)
from block_sparse_attn.block_sparse_attn_interface import (
block_sparse_attn_func,
token_streaming_attn_func,
block_streaming_attn_func,
)
# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py
# Get from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
class IndexFirstAxis(torch.autograd.Function):
@staticmethod
def forward(ctx, input, indices):
ctx.save_for_backward(indices)
assert input.ndim >= 2
ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
second_dim = other_shape.numel()
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
# return input[indices]
return torch.gather(
rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)
).reshape(-1, *other_shape)
@staticmethod
def backward(ctx, grad_output):
(indices,) = ctx.saved_tensors
assert grad_output.ndim >= 2
other_shape = grad_output.shape[1:]
grad_output = rearrange(grad_output, "b ... -> b (...)")
grad_input = torch.zeros(
[ctx.first_axis_dim, grad_output.shape[1]],
device=grad_output.device,
dtype=grad_output.dtype,
)
# TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
# grad_input[indices] = grad_output
grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output)
return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
index_first_axis = IndexFirstAxis.apply
class IndexPutFirstAxis(torch.autograd.Function):
@staticmethod
def forward(ctx, values, indices, first_axis_dim):
ctx.save_for_backward(indices)
assert indices.ndim == 1
assert values.ndim >= 2
output = torch.zeros(
first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype
)
# TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
output[indices] = values
# output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
return output
@staticmethod
def backward(ctx, grad_output):
(indices,) = ctx.saved_tensors
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
grad_values = grad_output[indices]
# grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))
return grad_values, None, None
index_put_first_axis = IndexPutFirstAxis.apply
class IndexFirstAxisResidual(torch.autograd.Function):
@staticmethod
def forward(ctx, input, indices):
ctx.save_for_backward(indices)
assert input.ndim >= 2
ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
second_dim = other_shape.numel()
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
output = input[indices]
# We don't want to reshape input (b ... -> b (...)) since it could change the channel_last
# memory format to channel_first. In other words, input might not be contiguous.
# If we don't detach, Pytorch complains about output being a view and is being modified inplace
return output, input.detach()
@staticmethod
def backward(ctx, grad_output, grad_residual):
(indices,) = ctx.saved_tensors
assert grad_output.ndim >= 2
other_shape = grad_output.shape[1:]
assert grad_residual.shape[1:] == other_shape
grad_input = grad_residual
# grad_input[indices] += grad_output
indices = indices.reshape(indices.shape[0], *((1,) * (grad_output.ndim - 1)))
indices = indices.expand_as(grad_output)
grad_input.scatter_add_(0, indices, grad_output)
return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
index_first_axis_residual = IndexFirstAxisResidual.apply
def unpad_input(hidden_states, attention_mask):
"""
Arguments:
hidden_states: (batch, seqlen, ...)
attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
Return:
hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
max_seqlen_in_batch: int
"""
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
# TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
# bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
# times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
# index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
# so we write custom forward and backward to make it a bit faster.
return (
index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
indices,
cu_seqlens,
max_seqlen_in_batch,
)
def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_length):
"""
Supports concatenating short samples in one sequence. The attention_mask_in_length is utilized to mask other short samples. It helps efficient training of variant lengths-based samples (e.g., the supervised fine-tuning task in large language model).
The motivation for this function is explained [here](https://github.com/Dao-AILab/flash-attention/issues/432#issuecomment-1668822286).
For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is:
```
[
[2, 3, 0, 0, 0, 0],
[3, 2, 0, 0, 0, 0],
[6, 0, 0, 0, 0, 0]
]
```
, which refers to the 3D-attention mask:
```
[
[
[1, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0],
[0, 0, 1, 1, 0, 0],
[0, 0, 1, 1, 1, 0],
[0, 0, 0, 0, 0, 1]
],
[
[1, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[0, 0, 0, 1, 0, 0],
[0, 0, 0, 1, 1, 0],
[0, 0, 0, 0, 0, 1]
],
[
[1, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 1, 0, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 1]
]
]
```.
Arguments:
hidden_states: (batch, seqlen, ...)
attention_mask_in_length: (batch, seqlen), int, a nonzero number (e.g., 1, 2, 3, etc.) means length of concatenated sequence in b-th batch, and 0 means none.
Return:
hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
max_seqlen_in_batch: int
"""
length = attention_mask_in_length.sum(dim=-1)
seqlen = attention_mask_in_length.size(-1)
attention_mask_2d = torch.arange(seqlen, device=length.device, dtype=length.dtype).expand(len(length), seqlen) < length.unsqueeze(1)
real_indices_idx = torch.nonzero(attention_mask_in_length.flatten(), as_tuple=False).flatten()
seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx]
indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
# TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
# bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
# times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
# index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
# so we write custom forward and backward to make it a bit faster.
return (
index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
indices,
cu_seqlens,
max_seqlen_in_batch,
)
def pad_input(hidden_states, indices, batch, seqlen):
"""
Arguments:
hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
indices: (total_nnz)
Return:
hidden_states: (batch, seqlen, ...)
"""
dim = hidden_states.shape[-1]
# output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
# output[indices] = hidden_states
output = index_put_first_axis(hidden_states, indices, batch * seqlen)
return rearrange(output, "(b s) ... -> b s ...", b=batch)
# Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_blocksparse_attn_interface.py
import block_sparse_attn_cuda
import torch
import torch.nn as nn
def convert_blockmask(blockmask, causal):
"""Convert from the 0-1 format to the format used by the CUDA code.
0 means the block is skipped.
nonzero means the block is not skipped.
Argument:
blockmask: (row, col): a 0-1 tensor
Return:
blockmask_converted: (col, row), dtype torch.int32: for each column, it contains the row
indices of the nonzero blocks, padded with -1 to reach length @row.
The indices are multiplied by 4, with the smallest bit used to encode whether
it is the first nonzero in its row, and the 2nd smallest bit to encode whether it is
the last nonzero in its row..
"""
assert not causal
nrow, ncol = blockmask.shape
# Sort does not support bool on CUDA
blockmask = blockmask.to(dtype=torch.uint8)
nonzero_val, nonzero_sorted_rowidx = blockmask.sort(dim=0, stable=True, descending=True)
nonzero_unsorted_rowidx = nonzero_sorted_rowidx.argsort(dim=0)
last_nonzero_col_per_row = blockmask.sort(dim=-1, stable=True).indices[:, -1]
last_nonzero_col_per_row_after_sort = nonzero_unsorted_rowidx[
torch.arange(nrow, device=blockmask.device), last_nonzero_col_per_row
]
first_nonzero_col_per_row = blockmask.sort(dim=-1, stable=True, descending=True).indices[:, 0]
first_nonzero_col_per_row_after_sort = nonzero_unsorted_rowidx[
torch.arange(nrow, device=blockmask.device), first_nonzero_col_per_row
]
nonzero_idx = nonzero_sorted_rowidx * 4
nonzero_idx[last_nonzero_col_per_row_after_sort, last_nonzero_col_per_row] += 2
nonzero_idx[first_nonzero_col_per_row_after_sort, first_nonzero_col_per_row] += 1
nonzero_idx[nonzero_val == 0] = -1
return nonzero_idx.T.contiguous().to(dtype=torch.int32)
def convert_blockmask_row_reverse(blockmask, causal=False):
# assert not causal
# nrow, ncol = blockmask.shape
# Sort does not support bool on CUDA
blockmask = blockmask.to(dtype=torch.uint8)
nonzero_val, nonzero_sorted_rowidx = blockmask.sort(dim=-1, stable=True, descending=False)
nonzero_idx = nonzero_sorted_rowidx
nonzero_idx[nonzero_val == 0] = -1
# print("nonzero_idx: ", nonzero_idx)
nonzero_idx = torch.flip(nonzero_idx, dims=[-1])
# print("nonzero_idx: ", nonzero_idx)
return nonzero_idx.contiguous().to(dtype=torch.int32)
def convert_blockmask_col_reverse(blockmask, causal=False):
# assert not causal
# nrow, ncol = blockmask.shape
# Sort does not support bool on CUDA
blockmask = blockmask.to(dtype=torch.uint8)
nonzero_val, nonzero_sorted_rowidx = blockmask.sort(dim=-2, stable=True, descending=False)
nonzero_idx = nonzero_sorted_rowidx
nonzero_idx[nonzero_val == 0] = -1
nonzero_idx = torch.flip(nonzero_idx, dims=[-2])
nonzero_idx = torch.transpose(nonzero_idx, -1, -2)
return nonzero_idx.contiguous().to(dtype=torch.int32)
# Replace each 1 in head_mask_type with its running index (1, 2, 3, ...), so that each
# blocksparse head selects its own slice of base_blockmask, and return the number of such heads.
def replace_ones_with_count(tensor):
ones_mask = tensor == 1
ones_num = ones_mask.sum()
count = torch.cumsum(ones_mask, dim=-1).to(tensor.dtype)
count = count * ones_mask
tensor = tensor.masked_scatter(ones_mask, count[ones_mask])
return tensor, ones_num
def _block_sparse_attn_forward(
q, k, v,
cu_seqlens_q, cu_seqlens_k,
m_block_dim, n_block_dim,
head_mask_type,
streaming_info,
row_blockmask,
max_seqlen_q_, max_seqlen_k_,
p_dropout,
softmax_scale,
is_causal,
exact_streaming,
return_softmax,
window_size_left,
window_size_right
):
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = block_sparse_attn_cuda.fwd_block(
q, k, v,
cu_seqlens_q, cu_seqlens_k,
m_block_dim, n_block_dim,
head_mask_type,
streaming_info,
row_blockmask,
max_seqlen_q_, max_seqlen_k_,
p_dropout,
softmax_scale,
is_causal,
exact_streaming,
return_softmax,
window_size_left,
window_size_right,
None
)
return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
def _block_sparse_attn_backward(
dout,
q, k, v,
out,
softmax_lse,
dq, dk, dv,
cu_seqlens_q, cu_seqlens_k,
m_block_dim, n_block_dim,
head_mask_type,
streaming_info,
col_blockmask,
max_seqlen_q_, max_seqlen_k_,
p_dropout,
softmax_scale,
zero_tensors,
is_causal,
window_size_left,
window_size_right,
deterministic,
rng_state=None,
):
dq, dk, dv, softmax_d = block_sparse_attn_cuda.bwd_block(
dout,
q, k, v,
out,
softmax_lse,
dq, dk, dv,
cu_seqlens_q, cu_seqlens_k,
m_block_dim, n_block_dim,
head_mask_type,
streaming_info,
col_blockmask,
max_seqlen_q_, max_seqlen_k_,
p_dropout,
softmax_scale,
zero_tensors,
is_causal,
window_size_left,
window_size_right,
deterministic,
None, rng_state
)
return dq, dk, dv, softmax_d
class BlockSparseAttnFun(torch.autograd.Function):
@staticmethod
def forward(ctx,
q, k, v,
cu_seqlens_q, cu_seqlens_k,
m_block_dim, n_block_dim,
head_mask_type,
streaming_info,
base_blockmask,
max_seqlen_q_, max_seqlen_k_,
p_dropout,
softmax_scale,
is_causal,
exact_streaming,
return_softmax,
window_size_left,
window_size_right, deterministic=False):
# Save rng_state because the backward pass will regenerate the dropout mask
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
if base_blockmask is not None:
row_blockmask = convert_blockmask_row_reverse(base_blockmask, is_causal)
else:
row_blockmask = None
if exact_streaming:
assert streaming_info is not None
assert is_causal
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _block_sparse_attn_forward(
q, k, v,
cu_seqlens_q, cu_seqlens_k,
m_block_dim, n_block_dim,
head_mask_type,
streaming_info,
row_blockmask,
max_seqlen_q_, max_seqlen_k_,
p_dropout,
softmax_scale,
is_causal,
exact_streaming,
return_softmax=False,
window_size_left=window_size_left,
window_size_right=window_size_right
)
ctx.save_for_backward(q, k, v,
out, S_dmask, softmax_lse,
cu_seqlens_q, cu_seqlens_k,
head_mask_type,
streaming_info,
base_blockmask,
rng_state)
# ctx.is_blocksparse = is_blocksparse
ctx.m_block_dim = m_block_dim
ctx.n_block_dim = n_block_dim
ctx.window_size_left = window_size_left
ctx.window_size_right = window_size_right
ctx.max_seqlen_q_ = max_seqlen_q_
ctx.max_seqlen_k_ = max_seqlen_k_
ctx.p_dropout = p_dropout
ctx.softmax_scale = softmax_scale
ctx.is_causal = is_causal
ctx.exact_streaming = exact_streaming
ctx.deterministic = deterministic
return out
@staticmethod
def backward(ctx, dout):
q, k, v, out, S_dmask, softmax_lse, cu_seqlens_q, cu_seqlens_k, head_mask_type, streaming_info, base_blockmask, rng_state = ctx.saved_tensors
dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
# S_dmask is None, temporarily use another tensor just to get it running
if base_blockmask is not None:
col_blockmask = convert_blockmask_col_reverse(base_blockmask, ctx.is_causal)
else:
col_blockmask = None
assert not ctx.exact_streaming, "Exact streaming not supported in backward pass"
_block_sparse_attn_backward(
dout,
q, k, v,
out,
softmax_lse,
dq, dk, dv,
cu_seqlens_q, cu_seqlens_k,
ctx.m_block_dim, ctx.n_block_dim,
head_mask_type,
streaming_info,
col_blockmask,
ctx.max_seqlen_q_, ctx.max_seqlen_k_,
ctx.p_dropout,
ctx.softmax_scale,
True, # zero_tensors
ctx.is_causal,
ctx.window_size_left,
ctx.window_size_right,
ctx.deterministic,
rng_state=rng_state
)
return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
# We duplicate code to return both the output and the softmax for testing
# Returning both makes backward a bit slower, so we want to keep using the other version for speed.
class BlockSparseAttnFunWithS(torch.autograd.Function):
@staticmethod
def forward(ctx,
q, k, v,
cu_seqlens_q, cu_seqlens_k,
m_block_dim, n_block_dim,
head_mask_type,
streaming_info,
base_blockmask,
max_seqlen_q_, max_seqlen_k_,
p_dropout,
softmax_scale,
is_causal,
exact_streaming,
return_softmax,
window_size_left,
window_size_right,
deterministic=False):
# Save rng_state because the backward pass will regenerate the dropout mask
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
if base_blockmask is not None:
row_blockmask = convert_blockmask_row_reverse(base_blockmask, is_causal)
else:
row_blockmask = None
if exact_streaming:
assert streaming_info is not None
print("is_causal: ", is_causal)
assert is_causal
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _block_sparse_attn_forward(
q, k, v,
cu_seqlens_q, cu_seqlens_k,
m_block_dim, n_block_dim,
head_mask_type,
streaming_info,
row_blockmask,
max_seqlen_q_, max_seqlen_k_,
p_dropout,
softmax_scale,
is_causal,
exact_streaming,
return_softmax=return_softmax and p_dropout > 0,
window_size_left=window_size_left,
window_size_right=window_size_right,
)
ctx.save_for_backward(q, k, v,
out, softmax_lse,
cu_seqlens_q, cu_seqlens_k,
head_mask_type,
streaming_info,
base_blockmask,
rng_state)
# ctx.is_blocksparse = is_blocksparse
ctx.m_block_dim = m_block_dim
ctx.n_block_dim = n_block_dim
ctx.window_size_left = window_size_left
ctx.window_size_right = window_size_right
ctx.max_seqlen_q_ = max_seqlen_q_
ctx.max_seqlen_k_ = max_seqlen_k_
ctx.p_dropout = p_dropout
ctx.softmax_scale = softmax_scale
ctx.is_causal = is_causal
ctx.exact_streaming = exact_streaming
ctx.deterministic = deterministic
return out, softmax_lse, S_dmask
@staticmethod
def backward(ctx, dout, *args):
q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, head_mask_type, streaming_info, base_blockmask, rng_state = ctx.saved_tensors
dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
# S_dmask is None, temporarily use another tensor just to get it running
if base_blockmask is not None:
col_blockmask = convert_blockmask_col_reverse(base_blockmask, ctx.is_causal)
else:
col_blockmask = None
assert not ctx.exact_streaming, "Exact streaming not supported in backward pass"
dq, dk, dv, _ = _block_sparse_attn_backward(
dout,
q, k, v,
out,
softmax_lse,
dq, dk, dv,
cu_seqlens_q, cu_seqlens_k,
ctx.m_block_dim, ctx.n_block_dim,
head_mask_type,
streaming_info,
col_blockmask,
ctx.max_seqlen_q_, ctx.max_seqlen_k_,
ctx.p_dropout,
ctx.softmax_scale,
True, # zero_tensors
ctx.is_causal,
ctx.window_size_left,
ctx.window_size_right,
ctx.deterministic,
rng_state=rng_state
)
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dk = dk[..., : dout.shape[-1]]
dv = dv[..., : dout.shape[-1]]
return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
def block_sparse_attn_func(
q, k, v,
cu_seqlens_q, cu_seqlens_k,
head_mask_type,
streaming_info,
base_blockmask,
max_seqlen_q_, max_seqlen_k_,
p_dropout,
deterministic=False,
softmax_scale=None,
is_causal=False,
exact_streaming=False,
return_attn_probs=False,
):
head_mask_type, blocksparse_head_num = replace_ones_with_count(head_mask_type)
if base_blockmask is not None:
assert base_blockmask.shape[1] == blocksparse_head_num
"""dropout_p should be set to 0.0 during evaluation"""
# print("is_causal0: ", is_causal)
func = BlockSparseAttnFun if not return_attn_probs else BlockSparseAttnFunWithS
return func.apply(
q, k, v,
cu_seqlens_q, cu_seqlens_k,
128, 128,
head_mask_type,
streaming_info,
base_blockmask,
max_seqlen_q_, max_seqlen_k_,
p_dropout,
softmax_scale,
is_causal,
exact_streaming,
return_attn_probs,
-1, -1,
deterministic
)
def token_streaming_attn_func(
q, k, v,
cu_seqlens_q, cu_seqlens_k,
head_mask_type,
streaming_info,
max_seqlen_q_, max_seqlen_k_,
deterministic=False,
softmax_scale=None,
return_attn_probs=False,
):
"""dropout_p should be set to 0.0 during evaluation"""
# print("is_causal0: ", is_causal)
func = BlockSparseAttnFun if not return_attn_probs else BlockSparseAttnFunWithS
return func.apply(
q, k, v,
cu_seqlens_q, cu_seqlens_k,
128, 128,
head_mask_type,
streaming_info,
None,
max_seqlen_q_, max_seqlen_k_,
0.0,
softmax_scale,
True,
True,
return_attn_probs,
-1, -1,
deterministic
)
def block_streaming_attn_func(
q, k, v,
cu_seqlens_q, cu_seqlens_k,
head_mask_type,
streaming_info,
max_seqlen_q_, max_seqlen_k_,
p_dropout,
deterministic=False,
softmax_scale=None,
is_causal=True,
return_attn_probs=False,
):
"""dropout_p should be set to 0.0 during evaluation"""
# print("is_causal0: ", is_causal)
func = BlockSparseAttnFun if not return_attn_probs else BlockSparseAttnFunWithS
return func.apply(
q, k, v,
cu_seqlens_q, cu_seqlens_k,
128, 128,
head_mask_type,
streaming_info,
None,
max_seqlen_q_, max_seqlen_k_,
p_dropout,
softmax_scale,
is_causal,
False,
return_attn_probs,
-1, -1,
deterministic
)
[tool.black]
line-length = 100
target-version = ['py38']
# Copyright (c) 2023, Tri Dao.
""" Useful functions for writing test code. """
import torch
import torch.utils.benchmark as benchmark
def benchmark_forward(
fn, *inputs, repeats=10, desc="", verbose=True, amp=False, amp_dtype=torch.float16, **kwinputs
):
"""Use Pytorch Benchmark on the forward pass of an arbitrary function."""
if verbose:
print(desc, "- Forward pass")
def amp_wrapper(*inputs, **kwinputs):
with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
fn(*inputs, **kwinputs)
t = benchmark.Timer(
stmt="fn_amp(*inputs, **kwinputs)",
globals={"fn_amp": amp_wrapper, "inputs": inputs, "kwinputs": kwinputs},
num_threads=torch.get_num_threads(),
)
m = t.timeit(repeats)
if verbose:
print(m)
return t, m
def benchmark_backward(
fn,
*inputs,
grad=None,
repeats=10,
desc="",
verbose=True,
amp=False,
amp_dtype=torch.float16,
**kwinputs,
):
"""Use Pytorch Benchmark on the backward pass of an arbitrary function."""
if verbose:
print(desc, "- Backward pass")
with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
y = fn(*inputs, **kwinputs)
if type(y) is tuple:
y = y[0]
if grad is None:
grad = torch.randn_like(y)
else:
if grad.shape != y.shape:
raise RuntimeError("Grad shape does not match output shape")
def f(*inputs, y, grad):
# Set .grad to None to avoid extra operation of gradient accumulation
for x in inputs:
if isinstance(x, torch.Tensor):
x.grad = None
y.backward(grad, retain_graph=True)
t = benchmark.Timer(
stmt="f(*inputs, y=y, grad=grad)",
globals={"f": f, "inputs": inputs, "y": y, "grad": grad},
num_threads=torch.get_num_threads(),
)
m = t.timeit(repeats)
if verbose:
print(m)
return t, m
def benchmark_combined(
fn,
*inputs,
grad=None,
repeats=10,
desc="",
verbose=True,
amp=False,
amp_dtype=torch.float16,
**kwinputs,
):
"""Use Pytorch Benchmark on the forward+backward pass of an arbitrary function."""
if verbose:
print(desc, "- Forward + Backward pass")
with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
y = fn(*inputs, **kwinputs)
if type(y) is tuple:
y = y[0]
if grad is None:
grad = torch.randn_like(y)
else:
if grad.shape != y.shape:
raise RuntimeError("Grad shape does not match output shape")
def f(grad, *inputs, **kwinputs):
for x in inputs:
if isinstance(x, torch.Tensor):
x.grad = None
with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
y = fn(*inputs, **kwinputs)
if type(y) is tuple:
y = y[0]
y.backward(grad, retain_graph=True)
t = benchmark.Timer(
stmt="f(grad, *inputs, **kwinputs)",
globals={"f": f, "fn": fn, "inputs": inputs, "grad": grad, "kwinputs": kwinputs},
num_threads=torch.get_num_threads(),
)
m = t.timeit(repeats)
if verbose:
print(m)
return t, m
def benchmark_fwd_bwd(
fn,
*inputs,
grad=None,
repeats=10,
desc="",
verbose=True,
amp=False,
amp_dtype=torch.float16,
**kwinputs,
):
"""Use Pytorch Benchmark on the forward+backward pass of an arbitrary function."""
return (
benchmark_forward(
fn,
*inputs,
repeats=repeats,
desc=desc,
verbose=verbose,
amp=amp,
amp_dtype=amp_dtype,
**kwinputs,
),
benchmark_backward(
fn,
*inputs,
grad=grad,
repeats=repeats,
desc=desc,
verbose=verbose,
amp=amp,
amp_dtype=amp_dtype,
**kwinputs,
),
)
def benchmark_all(
fn,
*inputs,
grad=None,
repeats=10,
desc="",
verbose=True,
amp=False,
amp_dtype=torch.float16,
**kwinputs,
):
"""Use Pytorch Benchmark on the forward+backward pass of an arbitrary function."""
return (
benchmark_forward(
fn,
*inputs,
repeats=repeats,
desc=desc,
verbose=verbose,
amp=amp,
amp_dtype=amp_dtype,
**kwinputs,
),
benchmark_backward(
fn,
*inputs,
grad=grad,
repeats=repeats,
desc=desc,
verbose=verbose,
amp=amp,
amp_dtype=amp_dtype,
**kwinputs,
),
benchmark_combined(
fn,
*inputs,
grad=grad,
repeats=repeats,
desc=desc,
verbose=verbose,
amp=amp,
amp_dtype=amp_dtype,
**kwinputs,
),
)
def pytorch_profiler(
fn,
*inputs,
trace_filename=None,
backward=False,
amp=False,
amp_dtype=torch.float16,
cpu=False,
verbose=True,
**kwinputs,
):
"""Wrap benchmark functions in Pytorch profiler to see CUDA information."""
if backward:
with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
out = fn(*inputs, **kwinputs)
if type(out) is tuple:
out = out[0]
g = torch.randn_like(out)
for _ in range(30): # Warm up
if backward:
for x in inputs:
if isinstance(x, torch.Tensor):
x.grad = None
with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
out = fn(*inputs, **kwinputs)
if type(out) is tuple:
out = out[0]
# Backward should be done outside autocast
if backward:
out.backward(g, retain_graph=True)
activities = ([torch.profiler.ProfilerActivity.CPU] if cpu else []) + [
torch.profiler.ProfilerActivity.CUDA
]
with torch.profiler.profile(
activities=activities,
record_shapes=True,
# profile_memory=True,
with_stack=True,
) as prof:
if backward:
for x in inputs:
if isinstance(x, torch.Tensor):
x.grad = None
with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
out = fn(*inputs, **kwinputs)
if type(out) is tuple:
out = out[0]
if backward:
out.backward(g, retain_graph=True)
if verbose:
# print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=50))
print(prof.key_averages().table(row_limit=50))
if trace_filename is not None:
prof.export_chrome_trace(trace_filename)
def benchmark_memory(fn, *inputs, desc="", verbose=True, **kwinputs):
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
torch.cuda.synchronize()
fn(*inputs, **kwinputs)
torch.cuda.synchronize()
mem = torch.cuda.max_memory_allocated() / ((2**20) * 1000)
if verbose:
print(f"{desc} max memory: {mem}GB")
torch.cuda.empty_cache()
return mem
# Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/tests/test_flash_attn.py
import pytest
import torch
from einops import repeat
from block_sparse_attn import (
block_sparse_attn_func,
)
from utils import (
generate_random_padding_mask,
generate_base_sparsity_mask,
generate_qkv,
generate_streaming_mask,
prepare_mixed_exact_mask,
prepare_mixed_mask,
convert_flash_attn_S_to_softmax,
normalize_flash_attn_S,
get_dropout_fraction,
attention_blocksparse_ref
)
MAX_HEADDIM_SM8x = 192
block_size = 128
is_sm75 = torch.cuda.get_device_capability("cuda") == (7, 5)
is_sm8x = torch.cuda.get_device_capability("cuda")[0] == 8
is_sm80 = torch.cuda.get_device_capability("cuda") == (8, 0)
is_sm90 = torch.cuda.get_device_capability("cuda") == (9, 0)
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
@pytest.mark.parametrize("d", [32, 64, 128])
@pytest.mark.parametrize(
"seqlen_q,seqlen_k",
[
(113, 203),
(128, 217),
(113, 211),
(108, 256),
(256, 512),
(512, 256),
(1024, 1024),
(1023, 1024),
(1024, 1023),
(2048, 2048),
],
)
@pytest.mark.parametrize(
"causal, exact_streaming, sink_num, local_num",
[
(True, True, 1, 3),
(True, True, 64, 256),
(True, False, 1, 3),
(False, False, 1, 3),
]
)
@pytest.mark.parametrize("p_dropout", [0.17, 0.0])
@pytest.mark.parametrize("sparsity", [0, 0.1, 0.3, 0.7, 1.0])
@pytest.mark.parametrize("batch_size", [1, 2])
@pytest.mark.parametrize("nheads", [16, 32])
def test_flash_attn_varlen_block_output(
seqlen_q, seqlen_k, d, p_dropout, causal, exact_streaming, sink_num, local_num, mha_type, dtype, sparsity, batch_size, nheads
):
if (
max(seqlen_q, seqlen_k) >= 2048
and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
):
pytest.skip() # Reference implementation OOM
device = "cuda:0"
# set seed
torch.random.manual_seed(42)
nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 8)
assert nheads % nheads_k == 0
window_size = (-1, -1)
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
k = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True)
v = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True)
query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
alibi_slopes, attn_bias = None, None
(
q_unpad,
k_unpad,
v_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q,
k,
v,
output_pad_fn,
dq_pad_fn,
dk_pad_fn,
) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
num_streaming_heads = nheads // 3
num_blocksparse_heads = nheads // 3
num_dense_heads = nheads - num_streaming_heads - num_blocksparse_heads
sparsity_list = [sparsity] * num_blocksparse_heads
head_mask_type = torch.tensor([0] * num_dense_heads + [1] * num_blocksparse_heads + [-1] * num_streaming_heads, device=device, dtype=torch.int32)
base_blockmask = generate_base_sparsity_mask(max_seqlen_q, max_seqlen_k, block_size, block_size, block_size, batch_size, num_blocksparse_heads, sparsity_list, causal = causal, device=device)
streaming_info = torch.tensor([sink_num, local_num] * nheads, device=device, dtype=torch.int32)
streaming_mask = generate_streaming_mask(max_seqlen_q, max_seqlen_k, batch_size, nheads, cu_seqlens_q, cu_seqlens_k, block_size, block_size, block_size, streaming_info, causal=causal, device=device)
if exact_streaming:
assert causal
print(f"exact_streaming: {exact_streaming}")
if exact_streaming:
mixed_mask = prepare_mixed_exact_mask(base_blockmask, streaming_info, head_mask_type, batch_size, nheads, block_size, block_size, block_size, max_seqlen_q, max_seqlen_k, q.shape[1], k.shape[1], query_padding_mask, key_padding_mask, device=device)
else:
mixed_mask = prepare_mixed_mask(base_blockmask, streaming_mask, head_mask_type, batch_size, nheads, block_size, block_size, block_size, max_seqlen_q, max_seqlen_k, q.shape[1], k.shape[1], device=device)
out_unpad, sm_lse, S_dmask = block_sparse_attn_func(
q_unpad, k_unpad, v_unpad,
cu_seqlens_q, cu_seqlens_k,
head_mask_type,
streaming_info,
base_blockmask,
max_seqlen_q, max_seqlen_k,
p_dropout,
deterministic=True,
softmax_scale=None,
is_causal=causal,
exact_streaming=exact_streaming,
return_attn_probs=True,
)
out = output_pad_fn(out_unpad)
if p_dropout > 0.0:
assert S_dmask is not None
S_dmask_converted = convert_flash_attn_S_to_softmax(
S_dmask,
seqlen_q,
seqlen_k,
query_padding_mask,
key_padding_mask,
d,
p_dropout > 0.0,
causal=causal,
window_size=window_size,
)
dropout_mask = S_dmask_converted >= 0
attn_unnorm = S_dmask_converted.abs()
k_rep = repeat(k, "b s h d -> b s (h g) d", g=nheads // nheads_k)
v_rep = repeat(v, "b s h d -> b s (h g) d", g=nheads // nheads_k)
attn = normalize_flash_attn_S(
attn_unnorm,
q,
k_rep,
v_rep,
query_padding_mask,
key_padding_mask,
attn_bias,
p_dropout > 0.0,
causal=causal,
window_size=window_size,
)
dropout_fraction = get_dropout_fraction(
dropout_mask,
mixed_mask,
block_size, block_size,
query_padding_mask,
key_padding_mask,
causal=causal,
window_size=window_size,
).item()
print(f"Actual dropout fraction: {dropout_fraction}")
else:
dropout_mask = None
out_ref, attn_ref = attention_blocksparse_ref(
q,
k,
v,
mixed_mask,
block_size, block_size,
query_padding_mask,
key_padding_mask,
p_dropout,
dropout_mask,
causal=causal,
window_size=window_size,
)
out_pt, attn_pt = attention_blocksparse_ref(
q,
k,
v,
mixed_mask,
block_size, block_size,
query_padding_mask,
key_padding_mask,
p_dropout,
dropout_mask,
causal=causal,
window_size=window_size,
upcast=False,
reorder_ops=True,
)
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()