Commit 498cd8c3 authored by Woosuk Kwon

flash-attn -> vllm-flash-attn

parent ae856f3a
@@ -32,7 +32,7 @@ with open("README.md", "r", encoding="utf-8") as fh:
# ninja build does not work unless include_dirs are abs path
this_dir = os.path.dirname(os.path.abspath(__file__))
PACKAGE_NAME = "flash_attn"
PACKAGE_NAME = "vllm_flash_attn"
BASE_WHEEL_URL = (
"https://github.com/Dao-AILab/flash-attention/releases/download/{tag_name}/{wheel_name}"
@@ -106,7 +106,7 @@ if not SKIP_CUDA_BUILD:
if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
generator_flag = ["-DOLD_GENERATOR_PATH"]
check_if_cuda_home_none("flash_attn")
check_if_cuda_home_none(PACKAGE_NAME)
# Check, if CUDA11 is installed for compute capability 8.0
cc_flag = []
if CUDA_HOME is not None:
@@ -132,7 +132,7 @@ if not SKIP_CUDA_BUILD:
torch._C._GLIBCXX_USE_CXX11_ABI = True
ext_modules.append(
CUDAExtension(
name="flash_attn_2_cuda",
name="vllm_flash_attn_2_cuda",
sources=[
"csrc/flash_attn/flash_api.cpp",
"csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu",
@@ -215,7 +215,7 @@ if not SKIP_CUDA_BUILD:
def get_package_version():
with open(Path(this_dir) / "flash_attn" / "__init__.py", "r") as f:
with open(Path(this_dir) / PACKAGE_NAME / "__init__.py", "r") as f:
version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
public_version = ast.literal_eval(version_match.group(1))
local_version = os.environ.get("FLASH_ATTN_LOCAL_VERSION")
@@ -225,29 +225,6 @@ def get_package_version():
return str(public_version)
-def get_wheel_url():
-# Determine the version numbers that will be used to determine the correct wheel
-# We're using the CUDA version used to build torch, not the one currently installed
-# _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
-torch_cuda_version = parse(torch.version.cuda)
-torch_version_raw = parse(torch.__version__)
-# For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.2
-# to save CI time. Minor versions should be compatible.
-torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.2")
-python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
-platform_name = get_platform()
-flash_version = get_package_version()
-# cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
-cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"
-torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}"
-cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()
-# Determine wheel URL based on CUDA version, torch version, python version and OS
-wheel_filename = f"{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
-wheel_url = BASE_WHEEL_URL.format(tag_name=f"v{flash_version}", wheel_name=wheel_filename)
-return wheel_url, wheel_filename
class CachedWheelsCommand(_bdist_wheel):
"""
The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot
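
Note: the hunk above removes get_wheel_url(), which guessed the name of a prebuilt wheel from the torch, CUDA, Python, and platform combination. For reference, here is a minimal sketch of the filename pattern that function produced; the concrete values (torch 2.1, CUDA 12.2, CPython 3.8, linux_x86_64) are illustrative assumptions, not the output of a real build.

# Sketch: reproduces the f-string from the removed get_wheel_url().
# Every concrete value below is an example assumption.
package_name = "vllm_flash_attn"
flash_version = "2.5.6"          # matches the __version__ shown later in this commit
cuda_version = "122"             # CUDA 12.x builds were pinned to 12.2
torch_version = "2.1"
cxx11_abi = "FALSE"              # str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()
python_version = "cp38"
platform_name = "linux_x86_64"

wheel_filename = (
    f"{package_name}-{flash_version}+cu{cuda_version}torch{torch_version}"
    f"cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
)
print(wheel_filename)
# vllm_flash_attn-2.5.6+cu122torch2.1cxx11abiFALSE-cp38-cp38-linux_x86_64.whl

No such vllm_flash_attn wheels are published under the upstream Dao-AILab/flash-attention release URL, which is presumably why the cached-wheel lookup is dropped rather than renamed.
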
@@ -260,28 +237,6 @@ class CachedWheelsCommand(_bdist_wheel):
if FORCE_BUILD:
return super().run()
-wheel_url, wheel_filename = get_wheel_url()
-print("Guessing wheel URL: ", wheel_url)
-try:
-urllib.request.urlretrieve(wheel_url, wheel_filename)
-# Make the archive
-# Lifted from the root wheel processing command
-# https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
-if not os.path.exists(self.dist_dir):
-os.makedirs(self.dist_dir)
-impl_tag, abi_tag, plat_tag = self.get_tag()
-archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"
-wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
-print("Raw wheel path", wheel_path)
-os.rename(wheel_filename, wheel_path)
-except urllib.error.HTTPError:
-print("Precompiled wheel not found. Building from source...")
-# If the wheel could not be downloaded, build from source
-super().run()
class NinjaBuildExtension(BuildExtension):
def __init__(self, *args, **kwargs) -> None:
@@ -304,7 +259,7 @@ class NinjaBuildExtension(BuildExtension):
setup(
-name=PACKAGE_NAME,
+name="vllm-flash-attn",
version=get_package_version(),
packages=find_packages(
exclude=(
@@ -315,15 +270,13 @@ setup(
"dist",
"docs",
"benchmarks",
"flash_attn.egg-info",
f"{PACKAGE_NAME}.egg-info",
)
),
author="Tri Dao",
author_email="trid@cs.stanford.edu",
description="Flash Attention: Fast and Memory-Efficient Exact Attention",
long_description=long_description,
long_description_content_type="text/markdown",
url="https://github.com/Dao-AILab/flash-attention",
author="vLLM Team",
description="Forward-only flash-attn",
long_description="Forward-only flash-attn package built for PyTorch 2.1.2 and CUDA 12.1",
url="https://github.com/vllm-project/flash-attention.git",
classifiers=[
"Programming Language :: Python :: 3",
"License :: OSI Approved :: BSD License",
@@ -335,14 +288,7 @@ setup(
else {
"bdist_wheel": CachedWheelsCommand,
},
python_requires=">=3.7",
install_requires=[
"torch",
"einops",
"packaging",
"ninja",
],
setup_requires=[
"psutil"
],
)
\ No newline at end of file
python_requires=">=3.8",
install_requires=["torch == 2.1.2"],
setup_requires=["psutil"],
)
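
With the cached-wheel path gone, installation always compiles the extension locally. A quick post-install sanity check could look like the sketch below; it assumes the build succeeded and simply imports the renamed modules from this commit.

# Sketch: verify the renamed Python package and CUDA extension are importable.
import vllm_flash_attn
import vllm_flash_attn_2_cuda  # compiled extension, renamed from flash_attn_2_cuda

print(vllm_flash_attn.__version__)  # expected "2.5.6" per the package __init__ below
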
__version__ = "2.5.6"
-from flash_attn.flash_attn_interface import (
+from vllm_flash_attn.flash_attn_interface import (
flash_attn_func,
flash_attn_kvpacked_func,
flash_attn_qkvpacked_func,
......
@@ -7,7 +7,7 @@ import torch.nn as nn
# isort: off
# We need to import the CUDA kernels after importing torch
-import flash_attn_2_cuda as flash_attn_cuda
+import vllm_flash_attn_2_cuda as flash_attn_cuda
# isort: on
......
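
Downstream code now imports everything through the renamed package. A minimal usage sketch follows; the shapes, dtype, and the causal flag are illustrative assumptions, and it assumes a CUDA GPU with fp16 support.

import torch

# flash_attn_func is re-exported by vllm_flash_attn/__init__.py via
# vllm_flash_attn.flash_attn_interface, as shown in the hunks above.
from vllm_flash_attn import flash_attn_func

# (batch, seqlen, nheads, headdim); fp16 or bf16 tensors on a CUDA device.
q = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)

out = flash_attn_func(q, k, v, causal=True)  # forward-only attention output
print(out.shape)  # torch.Size([2, 1024, 8, 64])
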