"vscode:/vscode.git/clone" did not exist on "d726857f7e7f35aa8c1f3d031048ba6c7cb069f3"
Commit 498cd8c3 authored by Woosuk Kwon's avatar Woosuk Kwon
Browse files

flash-attn -> vllm-flash-attn

parent ae856f3a
@@ -32,7 +32,7 @@ with open("README.md", "r", encoding="utf-8") as fh:
 # ninja build does not work unless include_dirs are abs path
 this_dir = os.path.dirname(os.path.abspath(__file__))
-PACKAGE_NAME = "flash_attn"
+PACKAGE_NAME = "vllm_flash_attn"
 BASE_WHEEL_URL = (
     "https://github.com/Dao-AILab/flash-attention/releases/download/{tag_name}/{wheel_name}"
@@ -106,7 +106,7 @@ if not SKIP_CUDA_BUILD:
     if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
         generator_flag = ["-DOLD_GENERATOR_PATH"]
-    check_if_cuda_home_none("flash_attn")
+    check_if_cuda_home_none(PACKAGE_NAME)
     # Check, if CUDA11 is installed for compute capability 8.0
     cc_flag = []
     if CUDA_HOME is not None:
@@ -132,7 +132,7 @@ if not SKIP_CUDA_BUILD:
         torch._C._GLIBCXX_USE_CXX11_ABI = True
     ext_modules.append(
         CUDAExtension(
-            name="flash_attn_2_cuda",
+            name="vllm_flash_attn_2_cuda",
             sources=[
                 "csrc/flash_attn/flash_api.cpp",
                 "csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu",
@@ -215,7 +215,7 @@ if not SKIP_CUDA_BUILD:
 def get_package_version():
-    with open(Path(this_dir) / "flash_attn" / "__init__.py", "r") as f:
+    with open(Path(this_dir) / PACKAGE_NAME / "__init__.py", "r") as f:
         version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
     public_version = ast.literal_eval(version_match.group(1))
     local_version = os.environ.get("FLASH_ATTN_LOCAL_VERSION")
@@ -225,29 +225,6 @@ def get_package_version():
     return str(public_version)
-def get_wheel_url():
-    # Determine the version numbers that will be used to determine the correct wheel
-    # We're using the CUDA version used to build torch, not the one currently installed
-    # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
-    torch_cuda_version = parse(torch.version.cuda)
-    torch_version_raw = parse(torch.__version__)
-    # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.2
-    # to save CI time. Minor versions should be compatible.
-    torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.2")
-    python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
-    platform_name = get_platform()
-    flash_version = get_package_version()
-    # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
-    cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"
-    torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}"
-    cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()
-    # Determine wheel URL based on CUDA version, torch version, python version and OS
-    wheel_filename = f"{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
-    wheel_url = BASE_WHEEL_URL.format(tag_name=f"v{flash_version}", wheel_name=wheel_filename)
-    return wheel_url, wheel_filename
 class CachedWheelsCommand(_bdist_wheel):
     """
     The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot
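
For context, the deleted get_wheel_url() assembled the name of a prebuilt wheel from the CUDA, torch, and Python versions. A small sketch of the format it produced, with hypothetical placeholder values (none of the concrete values below come from this commit):

```python
# Illustrative reconstruction of the wheel name built by the removed
# get_wheel_url(); every concrete value here is a hypothetical example.
package_name = "flash_attn"    # pre-rename PACKAGE_NAME
flash_version = "2.5.6"        # package __version__
cuda_version = "122"           # torch's build CUDA, rounded to 11.8 or 12.2
torch_version = "2.1"
cxx11_abi = "FALSE"            # str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()
python_version = "cp310"
platform_name = "linux_x86_64"

wheel_filename = (
    f"{package_name}-{flash_version}+cu{cuda_version}torch{torch_version}"
    f"cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
)
print(wheel_filename)
# flash_attn-2.5.6+cu122torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
```

Since this commit removes both the helper and its only caller (the download path in CachedWheelsCommand, deleted in the next hunk), no prebuilt-wheel lookup of this form happens anymore.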
@@ -260,28 +237,6 @@ class CachedWheelsCommand(_bdist_wheel):
         if FORCE_BUILD:
             return super().run()
-        wheel_url, wheel_filename = get_wheel_url()
-        print("Guessing wheel URL: ", wheel_url)
-        try:
-            urllib.request.urlretrieve(wheel_url, wheel_filename)
-            # Make the archive
-            # Lifted from the root wheel processing command
-            # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
-            if not os.path.exists(self.dist_dir):
-                os.makedirs(self.dist_dir)
-            impl_tag, abi_tag, plat_tag = self.get_tag()
-            archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"
-            wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
-            print("Raw wheel path", wheel_path)
-            os.rename(wheel_filename, wheel_path)
-        except urllib.error.HTTPError:
-            print("Precompiled wheel not found. Building from source...")
-            # If the wheel could not be downloaded, build from source
-            super().run()
 class NinjaBuildExtension(BuildExtension):
     def __init__(self, *args, **kwargs) -> None:
@@ -304,7 +259,7 @@ class NinjaBuildExtension(BuildExtension):
 setup(
-    name=PACKAGE_NAME,
+    name="vllm-flash-attn",
     version=get_package_version(),
     packages=find_packages(
         exclude=(
@@ -315,15 +270,13 @@ setup(
             "dist",
             "docs",
             "benchmarks",
-            "flash_attn.egg-info",
+            f"{PACKAGE_NAME}.egg-info",
         )
     ),
-    author="Tri Dao",
-    author_email="trid@cs.stanford.edu",
-    description="Flash Attention: Fast and Memory-Efficient Exact Attention",
-    long_description=long_description,
-    long_description_content_type="text/markdown",
-    url="https://github.com/Dao-AILab/flash-attention",
+    author="vLLM Team",
+    description="Forward-only flash-attn",
+    long_description="Forward-only flash-attn package built for PyTorch 2.1.2 and CUDA 12.1",
+    url="https://github.com/vllm-project/flash-attention.git",
     classifiers=[
         "Programming Language :: Python :: 3",
         "License :: OSI Approved :: BSD License",
@@ -335,14 +288,7 @@ setup(
     else {
         "bdist_wheel": CachedWheelsCommand,
     },
-    python_requires=">=3.7",
-    install_requires=[
-        "torch",
-        "einops",
-        "packaging",
-        "ninja",
-    ],
-    setup_requires=[
-        "psutil"
-    ],
-)
+    python_requires=">=3.8",
+    install_requires=["torch == 2.1.2"],
+    setup_requires=["psutil"],
+)
\ No newline at end of file
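
Taken together, the setup.py changes pin torch to 2.1.2, raise python_requires to 3.8, and drop the prebuilt-wheel download path, so installs now compile the CUDA extension locally. A minimal post-install smoke test, as a sketch only (it assumes the package was built and installed from this tree, e.g. with `pip install .` on a CUDA machine):

```python
# Sketch of a post-install check for the renamed modules; not part of the commit.
import torch
import vllm_flash_attn          # renamed Python package
import vllm_flash_attn_2_cuda   # renamed CUDAExtension module (imported after torch)

print(torch.__version__)             # expected 2.1.2 per install_requires
print(vllm_flash_attn.__version__)   # "2.5.6" per the package __init__ below
```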
 __version__ = "2.5.6"
-from flash_attn.flash_attn_interface import (
+from vllm_flash_attn.flash_attn_interface import (
     flash_attn_func,
     flash_attn_kvpacked_func,
     flash_attn_qkvpacked_func,
...
@@ -7,7 +7,7 @@ import torch.nn as nn
 # isort: off
 # We need to import the CUDA kernels after importing torch
-import flash_attn_2_cuda as flash_attn_cuda
+import vllm_flash_attn_2_cuda as flash_attn_cuda
 # isort: on
...
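
With both the Python package and the compiled CUDA extension renamed, downstream code uses the same forward-only kernels under the new name. A minimal usage sketch, assuming the renamed package keeps the upstream flash_attn_func signature (fp16/bf16 CUDA tensors of shape [batch, seqlen, nheads, headdim]); the shapes below are hypothetical:

```python
import torch
from vllm_flash_attn import flash_attn_func  # re-exported by the package __init__ above

# Hypothetical shapes: batch=2, seqlen=128, nheads=8, headdim=64.
q = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)

# Forward-only attention with a causal mask, as in decoder-style inference.
out = flash_attn_func(q, k, v, causal=True)
print(out.shape)  # torch.Size([2, 128, 8, 64])
```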