Commit ea2ed886 authored by Pierce Freeman

Refactor and cleanup of setup.py

parent 9fc9820a
@@ -150,6 +150,8 @@ jobs:
         pip install ninja packaging setuptools wheel twine
     - name: Build core package
+      env:
+        FLASH_ATTENTION_SKIP_CUDA_BUILD: "TRUE"
       run: |
         python setup.py sdist --dist-dir=dist
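The two added env: lines feed a new switch into setup.py so the sdist can be built on a CI runner with no CUDA toolchain. A minimal sketch of the gating pattern this enables (the flag name matches the diff; the snippet itself is illustrative):

import os

# Mirrors the flag this commit adds to setup.py: when CI exports
# FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE, the sdist step packages raw sources
# without touching nvcc or the torch CUDA extension machinery.
SKIP_CUDA_BUILD = os.getenv("FLASH_ATTENTION_SKIP_CUDA_BUILD", "FALSE") == "TRUE"

ext_modules = []
if not SKIP_CUDA_BUILD:
    # In the real setup.py this branch compiles flash_attn_cuda (see the
    # hunks below); skipping it keeps `python setup.py sdist` CUDA-free.
    pass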
@@ -6,8 +6,10 @@ import re
 import ast
 from pathlib import Path
 from packaging.version import parse, Version
+import platform
 from setuptools import setup, find_packages
+from setuptools.command.install import install
 import subprocess
 import urllib.request
@@ -24,60 +26,29 @@ with open("README.md", "r", encoding="utf-8") as fh:
 this_dir = os.path.dirname(os.path.abspath(__file__))
 
+# @pierce - TODO: Update for proper release
+BASE_WHEEL_URL = "https://github.com/piercefreeman/flash-attention/releases/download/{tag_name}/{wheel_name}"
+
+# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels
+# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
+FORCE_BUILD = os.getenv("FLASH_ATTENTION_FORCE_BUILD", "FALSE") == "TRUE"
+SKIP_CUDA_BUILD = os.getenv("FLASH_ATTENTION_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
+
 def get_platform():
     """
-    Returns the platform string.
+    Returns the platform name as used in wheel filenames.
     """
     if sys.platform.startswith('linux'):
         return 'linux_x86_64'
     elif sys.platform == 'darwin':
-        return 'macosx_10_9_x86_64'
+        mac_version = '.'.join(platform.mac_ver()[0].split('.')[:2])
+        return f'macosx_{mac_version}_x86_64'
     elif sys.platform == 'win32':
         return 'win_amd64'
     else:
         raise ValueError('Unsupported platform: {}'.format(sys.platform))
 
-from setuptools.command.install import install
-
-# @pierce - TODO: Remove for proper release
-BASE_WHEEL_URL = "https://github.com/piercefreeman/flash-attention/releases/download/{tag_name}/{wheel_name}"
-
-class CustomInstallCommand(install):
-    def run(self):
-        if os.getenv("FLASH_ATTENTION_FORCE_BUILD", "FALSE") == "TRUE":
-            return install.run(self)
-
-        raise_if_cuda_home_none("flash_attn")
-
-        # Determine the version numbers that will be used to determine the correct wheel
-        _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
-        torch_version_raw = parse(torch.__version__)
-        python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
-        platform_name = get_platform()
-        flash_version = get_package_version()
-        cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
-        torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}.{torch_version_raw.micro}"
-
-        # Determine wheel URL based on CUDA version, torch version, python version and OS
-        wheel_filename = f'flash_attn-{flash_version}+cu{cuda_version}torch{torch_version}-{python_version}-{python_version}-{platform_name}.whl'
-        wheel_url = BASE_WHEEL_URL.format(
-            #tag_name=f"v{flash_version}",
-            # HACK
-            tag_name=f"v0.0.5",
-            wheel_name=wheel_filename
-        )
-
-        print("Guessing wheel URL: ", wheel_url)
-
-        try:
-            urllib.request.urlretrieve(wheel_url, wheel_filename)
-            os.system(f'pip install {wheel_filename}')
-            os.remove(wheel_filename)
-        except urllib.error.HTTPError:
-            print("Precompiled wheel not found. Building from source...")
-            # If the wheel could not be downloaded, build from source
-            #install.run(self)
-            raise ValueError
-
 def get_cuda_bare_metal_version(cuda_dir):
     raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
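The darwin branch above no longer pins macosx_10_9; a standalone sketch of just that branch (macos_wheel_tag is a hypothetical helper name; the commit inlines this logic in get_platform()):

import platform

def macos_wheel_tag() -> str:
    # platform.mac_ver()[0] is e.g. '13.4.1'; keep only major.minor
    mac_version = '.'.join(platform.mac_ver()[0].split('.')[:2])
    return f'macosx_{mac_version}_x86_64'

print(macos_wheel_tag())  # e.g. 'macosx_13.4_x86_64' on macOS 13.4

One thing worth flagging: '.'.join keeps the dot, yielding e.g. 'macosx_13.4_x86_64', whereas wheel platform tags conventionally use underscores throughout (as in the removed 'macosx_10_9_x86_64'), so the generated name may not match conventionally published wheels.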
@@ -147,37 +118,37 @@ if not torch.cuda.is_available():
 else:
     os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"
 
-print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
-TORCH_MAJOR = int(torch.__version__.split(".")[0])
-TORCH_MINOR = int(torch.__version__.split(".")[1])
-
 cmdclass = {}
 ext_modules = []
 
-# Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h
-# See https://github.com/pytorch/pytorch/pull/70650
-generator_flag = []
-torch_dir = torch.__path__[0]
-if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
-    generator_flag = ["-DOLD_GENERATOR_PATH"]
-
-raise_if_cuda_home_none("flash_attn")
-# Check, if CUDA11 is installed for compute capability 8.0
-cc_flag = []
-_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
-if bare_metal_version < Version("11.0"):
-    raise RuntimeError("FlashAttention is only supported on CUDA 11 and above")
-cc_flag.append("-gencode")
-cc_flag.append("arch=compute_75,code=sm_75")
-cc_flag.append("-gencode")
-cc_flag.append("arch=compute_80,code=sm_80")
-if bare_metal_version >= Version("11.8"):
-    cc_flag.append("-gencode")
-    cc_flag.append("arch=compute_90,code=sm_90")
-
-subprocess.run(["git", "submodule", "update", "--init", "csrc/flash_attn/cutlass"])
-ext_modules.append(
-    CUDAExtension(
-        name="flash_attn_cuda",
-        sources=[
+if not SKIP_CUDA_BUILD:
+    print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
+    TORCH_MAJOR = int(torch.__version__.split(".")[0])
+    TORCH_MINOR = int(torch.__version__.split(".")[1])
+
+    # Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h
+    # See https://github.com/pytorch/pytorch/pull/70650
+    generator_flag = []
+    torch_dir = torch.__path__[0]
+    if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
+        generator_flag = ["-DOLD_GENERATOR_PATH"]
+
+    raise_if_cuda_home_none("flash_attn")
+    # Check, if CUDA11 is installed for compute capability 8.0
+    cc_flag = []
+    _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+    if bare_metal_version < Version("11.0"):
+        raise RuntimeError("FlashAttention is only supported on CUDA 11 and above")
+    cc_flag.append("-gencode")
+    cc_flag.append("arch=compute_75,code=sm_75")
+    cc_flag.append("-gencode")
+    cc_flag.append("arch=compute_80,code=sm_80")
+    if bare_metal_version >= Version("11.8"):
+        cc_flag.append("-gencode")
+        cc_flag.append("arch=compute_90,code=sm_90")
+
+    subprocess.run(["git", "submodule", "update", "--init", "csrc/flash_attn/cutlass"])
+    ext_modules.append(
+        CUDAExtension(
+            name="flash_attn_cuda",
+            sources=[
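The -gencode flags above are cumulative with the CUDA version. A hypothetical helper (select_gencode_flags is not in the commit) that reproduces the same selection logic in isolation:

from packaging.version import Version

def select_gencode_flags(cuda_version: Version) -> list:
    # Mirrors the cc_flag construction in the hunk above.
    if cuda_version < Version("11.0"):
        raise RuntimeError("FlashAttention is only supported on CUDA 11 and above")
    flags = [
        "-gencode", "arch=compute_75,code=sm_75",  # Turing
        "-gencode", "arch=compute_80,code=sm_80",  # Ampere
    ]
    if cuda_version >= Version("11.8"):
        flags += ["-gencode", "arch=compute_90,code=sm_90"]  # Hopper, CUDA 11.8+
    return flags

assert len(select_gencode_flags(Version("11.8"))) == 6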
@@ -217,7 +188,7 @@ ext_modules.append(
-            Path(this_dir) / 'csrc' / 'flash_attn' / 'cutlass' / 'include',
-        ],
-    )
-)
+                Path(this_dir) / 'csrc' / 'flash_attn' / 'cutlass' / 'include',
+            ],
+        )
+    )
 
 def get_package_version():
     with open(Path(this_dir) / "flash_attn" / "__init__.py", "r") as f:
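The next hunk reintroduces the wheel downloader as a documented CachedWheelsCommand class. For concreteness, a sketch of the filename it composes, with hypothetical version values (nothing below is taken from a real release):

# All values are made up, chosen only to show the resulting filename shape.
flash_version = "1.0.1"          # would come from get_package_version()
cuda_version = "118"             # CUDA 11.8 -> f"{major}{minor}"
torch_version = "2.0.1"          # torch major.minor.micro
python_version = "cp310"         # CPython 3.10
platform_name = "linux_x86_64"   # from get_platform()

wheel_filename = (
    f"flash_attn-{flash_version}+cu{cuda_version}torch{torch_version}"
    f"-{python_version}-{python_version}-{platform_name}.whl"
)
print(wheel_filename)
# flash_attn-1.0.1+cu118torch2.0.1-cp310-cp310-linux_x86_64.whl

Note that {python_version} appears twice because it fills both the python-tag and ABI-tag positions of the wheel name. Unlike the removed CustomInstallCommand, the new class formats tag_name from the real package version rather than the hard-coded v0.0.5, and its HTTPError fallback actually calls install.run(self) to build from source instead of raising ValueError.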
@@ -229,18 +200,63 @@ def get_package_version():
     else:
         return str(public_version)
 
+class CachedWheelsCommand(install):
+    """
+    Installer hook to scan for existing wheels that match the current platform environment.
+    Falls back to building from source if no wheel is found.
+    """
+    def run(self):
+        if FORCE_BUILD:
+            return install.run(self)
+
+        raise_if_cuda_home_none("flash_attn")
+
+        # Determine the version numbers that will be used to determine the correct wheel
+        _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
+        torch_version_raw = parse(torch.__version__)
+        python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
+        platform_name = get_platform()
+        flash_version = get_package_version()
+        cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
+        torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}.{torch_version_raw.micro}"
+
+        # Determine wheel URL based on CUDA version, torch version, python version and OS
+        wheel_filename = f'flash_attn-{flash_version}+cu{cuda_version}torch{torch_version}-{python_version}-{python_version}-{platform_name}.whl'
+        wheel_url = BASE_WHEEL_URL.format(
+            tag_name=f"v{flash_version}",
+            wheel_name=wheel_filename
+        )
+
+        print("Guessing wheel URL: ", wheel_url)
+
+        try:
+            urllib.request.urlretrieve(wheel_url, wheel_filename)
+            os.system(f'pip install {wheel_filename}')
+            os.remove(wheel_filename)
+        except urllib.error.HTTPError:
+            print("Precompiled wheel not found. Building from source...")
+            # If the wheel could not be downloaded, build from source
+            install.run(self)
+
 setup(
-    name="flash_attn",
+    # @pierce - TODO: Revert for official release
+    name="flash_attn_wheels",
     version=get_package_version(),
     packages=find_packages(
         exclude=("build", "csrc", "include", "tests", "dist", "docs", "benchmarks", "flash_attn.egg-info",)
     ),
-    author="Tri Dao",
-    author_email="trid@stanford.edu",
+    #author="Tri Dao",
+    #author_email="trid@stanford.edu",
+    # @pierce - TODO: Revert for official release
+    author="Pierce Freeman",
+    author_email="pierce@freeman.vc",
     description="Flash Attention: Fast and Memory-Efficient Exact Attention",
     long_description=long_description,
     long_description_content_type="text/markdown",
-    url="https://github.com/HazyResearch/flash-attention",
+    #url="https://github.com/HazyResearch/flash-attention",
+    url="https://github.com/piercefreeman/flash-attention",
     classifiers=[
         "Programming Language :: Python :: 3",
         "License :: OSI Approved :: BSD License",
@@ -248,10 +264,10 @@ setup(
     ],
     ext_modules=ext_modules,
     cmdclass={
-        'install': CustomInstallCommand,
+        'install': CachedWheelsCommand,
         "build_ext": BuildExtension
     } if ext_modules else {
-        'install': CustomInstallCommand,
+        'install': CachedWheelsCommand,
    },
    python_requires=">=3.7",
    install_requires=[
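Since both branches of the cmdclass expression register the command under the 'install' key, an install triggered by pip routes through CachedWheelsCommand whether or not CUDA extensions are being built. A minimal standalone sketch of that dispatch (NoopInstall is a hypothetical stand-in, not part of the commit):

from setuptools import setup
from setuptools.command.install import install

class NoopInstall(install):
    # Stand-in for CachedWheelsCommand: setuptools invokes run() at install time.
    def run(self):
        print("custom install hook ran")  # wheel lookup would happen here
        install.run(self)                 # fall through to the stock install

# Wiring mirrors the diff:
# setup(..., cmdclass={"install": NoopInstall})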