Commit 5a61cb77 authored by Tri Dao

Rename src -> flash_attn

parent c41479d6
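In practice the rename means call sites import the same modules under the `flash_attn` package instead of the old in-repo `src` package. A minimal before/after sketch of the import change (the module paths are taken verbatim from the hunks below; nothing else is assumed):

```python
# Before this commit: modules were imported from the in-repo "src" package.
# from src.flash_attn_interface import flash_attn_func
# from src.bert_padding import unpad_input, pad_input

# After this commit: the same modules are importable as "flash_attn".
from flash_attn.flash_attn_interface import flash_attn_func
from flash_attn.bert_padding import unpad_input, pad_input
```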
@@ -11,7 +11,6 @@ Paper: https://arxiv.org/abs/2205.14135
To compile (requiring CUDA 11, NVCC, and an Ampere GPU):
```
-cd csrc/flash_attn
python setup.py install
```
......
@@ -7,8 +7,8 @@ import torch.nn.functional as F
from einops import rearrange, repeat
from benchmarks.utils import benchmark_all, benchmark_forward, benchmark_backward, benchmark_combined
-from src.bert_padding import unpad_input, pad_input
-from src.flash_attn_interface import flash_attn_func
+from flash_attn.bert_padding import unpad_input, pad_input
+from flash_attn.flash_attn_interface import flash_attn_func
def attention_ref(qkv, attn_mask, dropout_p, upcast=False, causal=False):
......
@@ -86,20 +86,34 @@ def benchmark_all(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, **kwinputs):
    )


-def pytorch_profiler(fn, *inputs, repeats=10):
+def pytorch_profiler(fn, *inputs, trace_filename=None, backward=False, amp=False, verbose=True):
    """ Wrap benchmark functions in Pytorch profiler to see CUDA information. """
+    if backward:
+        g = torch.randn_like(fn(*inputs))
+    for _ in range(10): # Warm up
+        with torch.autocast(device_type='cuda', enabled=amp):
+            if backward:
+                for x in inputs:
+                    if isinstance(x, torch.Tensor):
+                        x.grad = None
+            fn(*inputs) if not backward else fn(*inputs).backward(g)
    with torch.profiler.profile(
-        activities=[
-            torch.profiler.ProfilerActivity.CPU,
-            torch.profiler.ProfilerActivity.CUDA,
-        ],
-        record_shapes=True,
-        profile_memory=True,
-        with_stack=True,
-    ) as p:
-        # benchmark_forward(repeats, fn, *inputs)
-        fn(*inputs)
-    print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
+        # activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA,],
+        activities=[torch.profiler.ProfilerActivity.CUDA,],
+        record_shapes=True,
+        # profile_memory=True,
+        with_stack=True,
+    ) as prof:
+        with torch.autocast(device_type='cuda', enabled=amp):
+            if backward:
+                for x in inputs:
+                    if isinstance(x, torch.Tensor):
+                        x.grad = None
+            fn(*inputs) if not backward else fn(*inputs).backward(g)
+    if verbose:
+        print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=50))
+    if trace_filename is not None:
+        prof.export_chrome_trace(trace_filename)


def benchmark_memory(fn, *inputs, desc='', verbose=True, **kwinputs):
......
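As a usage sketch (not part of the commit), the reworked pytorch_profiler above could be driven as follows. The toy matmul workload and trace file name are illustrative assumptions; the keyword arguments mirror the new signature, and the helper is assumed importable from benchmarks.utils alongside the other benchmark functions.

```python
import torch
from benchmarks.utils import pytorch_profiler  # assumed location, same module as benchmark_all

# Hypothetical workload: a plain CUDA matmul with gradients enabled.
def matmul_fn(a, b):
    return a @ b

a = torch.randn(1024, 1024, device='cuda', requires_grad=True)
b = torch.randn(1024, 1024, device='cuda', requires_grad=True)

# Runs 10 warm-up iterations, profiles forward + backward, prints the CUDA
# kernel table, and writes a Chrome trace viewable in chrome://tracing.
pytorch_profiler(matmul_fn, a, b, trace_filename='matmul_trace.json',
                 backward=True, amp=False, verbose=True)
```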
@@ -4,9 +4,9 @@ import torch.nn as nn
from einops import rearrange
-from src.rotary import RotaryEmbedding, RotaryEmbedding2D
-from src.flash_attn_interface import flash_attn_func
-from src.bert_padding import unpad_input, pad_input, index_first_axis
+from flash_attn.rotary import RotaryEmbedding, RotaryEmbedding2D
+from flash_attn.flash_attn_interface import flash_attn_func
+from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis
class FlashAttention(nn.Module):
......
@@ -6,9 +6,9 @@ from einops import rearrange
import hydra
-from src.flash_blocksparse_attn_interface import flash_blocksparse_attn_func
-from src.flash_blocksparse_attn_interface import convert_blockmask
-from src.bert_padding import unpad_input, pad_input, index_first_axis
+from flash_attn.flash_blocksparse_attn_interface import flash_blocksparse_attn_func
+from flash_attn.flash_blocksparse_attn_interface import convert_blockmask
+from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis
class FlashBlocksparseAttention(nn.Module):
......
# Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py
-import torch
-from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
-from setuptools import setup, find_packages
-import subprocess
import sys
import warnings
import os
+from pathlib import Path
+from setuptools import setup, find_packages
+import subprocess
+import torch
+from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
+
+with open("README.md", "r", encoding="utf-8") as fh:
+    long_description = fh.read()
# ninja build does not work unless include_dirs are abs path
this_dir = os.path.dirname(os.path.abspath(__file__))
@@ -66,8 +73,8 @@ if not torch.cuda.is_available():
    print(
        "\nWarning: Torch did not find available GPUs on this system.\n",
        "If your intention is to cross-compile, this is not an error.\n"
-        "By default, Apex will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n"
-        "Volta (compute capability 7.0), Turing (compute capability 7.5),\n"
+        "By default, We cross-compile for Volta (compute capability 7.0), "
+        "Turing (compute capability 7.5),\n"
        "and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n"
        "If you wish to cross-compile for a single specific architecture,\n"
        'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n',
@@ -75,11 +82,11 @@ if not torch.cuda.is_available():
    if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
        _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
        if int(bare_metal_major) == 11:
-            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "7.0;7.5;8.0"
            if int(bare_metal_minor) > 0:
-                os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
+                os.environ["TORCH_CUDA_ARCH_LIST"] = "7.0;7.5;8.0;8.6"
        else:
-            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "7.0;7.5"

print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
TORCH_MAJOR = int(torch.__version__.split(".")[0])
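The warning above asks cross-compiling users to set TORCH_CUDA_ARCH_LIST themselves. A sketch of a single-architecture build, assuming an Ampere-only (compute capability 8.0) target and assuming setup.py is invoked from the repository root where it now lives:

```python
import os
import subprocess

# Assumption: only SM80 (Ampere) kernels are wanted; any semicolon-separated
# list of compute capabilities, e.g. "7.0;7.5;8.0", works the same way.
os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0"

# The child process inherits the environment variable set above.
subprocess.run(["python", "setup.py", "install"], check=True)
```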
@@ -95,7 +102,7 @@ torch_dir = torch.__path__[0]
if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
    generator_flag = ["-DOLD_GENERATOR_PATH"]

-raise_if_cuda_home_none("--flashattn")
+raise_if_cuda_home_none("flash_attn")
# Check, if CUDA11 is installed for compute capability 8.0
cc_flag = []
_, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
@@ -108,11 +115,11 @@ ext_modules.append(
    CUDAExtension(
        name="flash_attn_cuda",
        sources=[
-            "fmha_api.cpp",
-            "src/fmha_fprop_fp16_kernel.sm80.cu",
-            "src/fmha_dgrad_fp16_kernel_loop.sm80.cu",
-            "src/fmha_block_fprop_fp16_kernel.sm80.cu",
-            "src/fmha_block_dgrad_fp16_kernel_loop.sm80.cu",
+            "csrc/flash_attn/fmha_api.cpp",
+            "csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu",
+            "csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu",
+            "csrc/flash_attn/src/fmha_block_fprop_fp16_kernel.sm80.cu",
+            "csrc/flash_attn/src/fmha_block_dgrad_fp16_kernel_loop.sm80.cu",
        ],
        extra_compile_args={
            "cxx": ["-O3"] + generator_flag,
@@ -132,16 +139,30 @@ ext_modules.append(
            ),
        },
        include_dirs=[
-            this_dir,
-            os.path.join(this_dir, "src"),
+            Path(this_dir) / 'csrc' / 'flash_attn',
+            Path(this_dir) / 'csrc' / 'flash_attn' / 'src',
        ],
    )
)

setup(
-    name="flash_attn_cuda",
+    name="flash_attn",
    version="0.1",
-    description="Flash Attention",
+    packages=find_packages(
+        exclude=("build", "csrc", "include", "tests", "dist", "docs", "benchmarks", "flash_attn.egg-info",)
+    ),
+    author="Tri Dao",
+    author_email="trid@stanford.edu",
+    description="Flash Attention: Fast and Memory-Efficient Exact Attention",
+    long_description=long_description,
+    long_description_content_type="text/markdown",
+    url="https://github.com/HazyResearch/flash-attention",
+    classifiers=[
+        "Programming Language :: Python :: 3",
+        "License :: Apache 2.0",
+        "Operating System :: Linux",
+    ],
    ext_modules=ext_modules,
    cmdclass={"build_ext": BuildExtension} if ext_modules else {},
+    python_requires=">=3.7"
)
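A hedged post-install smoke test (not part of the commit): after `python setup.py install`, both the renamed Python package and the compiled extension declared in the CUDAExtension above should import cleanly on a CUDA-capable machine. The print line is illustrative.

```python
# Verify the renamed package and the compiled CUDA extension are importable.
import flash_attn_cuda  # built from CUDAExtension(name="flash_attn_cuda", ...)
from flash_attn.flash_attn_interface import flash_attn_func
from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis
from flash_attn.rotary import RotaryEmbedding, RotaryEmbedding2D

print(flash_attn_func.__module__)  # expected: flash_attn.flash_attn_interface
```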