Commit 494b2aa4 authored by Pierce Freeman's avatar Pierce Freeman
Browse files

Add notes to github action workflow

parent 8d60c373
# This workflow will upload a Python Package to Release asset
# This workflow will:
# - Create a new Github release
# - Build wheels for supported architectures
# - Deploy the wheels to the Github release
# - Release the static code to PyPI
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
name: Python Package
name: Build wheels and deploy
on:
create:
......
......@@ -57,6 +57,14 @@ To install:
pip install flash-attn
```
If you see an error about `ModuleNotFoundError: No module named 'torch'`, it is likely caused by pip's build isolation, which hides your installed torch from the build environment.
To fix you can run:
```sh
pip install flash-attn --no-build-isolation
```
Alternatively you can compile from source:
```
python setup.py install
......
__version__ = "1.0.7"
__version__ = "1.0.8"
## flash-attn-builder
Basic build utilities for flash-attn.
import os
import sys
import urllib
import setuptools.build_meta
from setuptools.command.install import install
from packaging.version import parse, Version
# @pierce - TODO: Update for proper release
BASE_WHEEL_URL = "https://github.com/piercefreeman/flash-attention/releases/download/{tag_name}/{wheel_name}"
# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels
# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
FORCE_BUILD = os.getenv("FLASH_ATTENTION_FORCE_BUILD", "FALSE") == "TRUE"
# Custom PEP 517 backend: extends setuptools' in-tree backend so that
# build_wheel can try to fetch a prebuilt wheel before compiling from source.
# NOTE(review): indentation was flattened by the diff view; structure below is
# inferred from the visible statement order.
class CustomBuildBackend(setuptools.build_meta._BuildMetaBackend):
# Hook invoked by pip/build frontends (PEP 517 build_wheel signature).
def build_wheel(self, wheel_directory, config_settings=None, metadata_directory=None):
this_file_directory = os.path.dirname(os.path.abspath(__file__))
print(f'This file is located in: {this_file_directory}')
# Rebuild sys.argv the same way setuptools.build_meta does before run_setup,
# forwarding any --global-option / --build-option style config settings.
sys.argv = [
*sys.argv[:1],
*self._global_args(config_settings),
*self._arbitrary_args(config_settings),
]
with setuptools.build_meta.no_install_setup_requires():
self.run_setup()
# NOTE(review): debug output left in — reads FLASH_ATTENTION_WHEEL_URL from
# the environment and will KeyError if it is unset; confirm intent.
print("OS", os.environ["FLASH_ATTENTION_WHEEL_URL"])
print("config_settings", config_settings)
print("metadata_directory", metadata_directory)
# BUG(review): unconditional raise aborts every build here; everything below
# this line is unreachable. Looks like leftover WIP debugging — remove before release.
raise ValueError
# NOTE(review): `wheel_url` and `wheel_filename` are not defined anywhere in
# this fragment — the diff view presumably omitted the lines that compute them.
print("Guessing wheel URL: ", wheel_url)
try:
# Best-effort: download the prebuilt wheel and install it via a pip subshell.
# NOTE(review): os.system(pip install ...) inside a PEP 517 hook is fragile —
# it installs instead of placing the wheel in wheel_directory; verify.
urllib.request.urlretrieve(wheel_url, wheel_filename)
os.system(f'pip install {wheel_filename}')
os.remove(wheel_filename)
except urllib.error.HTTPError:
print("Precompiled wheel not found. Building from source...")
# If the wheel could not be downloaded, build from source
super().build_wheel(wheel_directory, config_settings, metadata_directory)
# Single shared backend instance; the module-level names below are the
# PEP 517 entry points a frontend resolves when this module is named as
# `build-backend` in pyproject.toml.
_BACKEND = CustomBuildBackend()  # noqa

# Re-export every hook by delegating to the shared instance.
(
    get_requires_for_build_wheel,
    get_requires_for_build_sdist,
    prepare_metadata_for_build_wheel,
    build_wheel,
    build_sdist,
) = (
    _BACKEND.get_requires_for_build_wheel,
    _BACKEND.get_requires_for_build_sdist,
    _BACKEND.prepare_metadata_for_build_wheel,
    _BACKEND.build_wheel,
    _BACKEND.build_sdist,
)
[tool.poetry]
name = "flash-attn-builder"
version = "0.1.0"
description = ""
authors = ["Pierce Freeman <pierce@freeman.vc>"]
readme = "README.md"
packages = [{include = "flash_attn_builder"}]
[tool.poetry.dependencies]
python = "^3.10"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
[build-system]
requires = ["ninja", "packaging", "setuptools", "wheel"]
build-backend = "setuptools.build_meta"
......@@ -9,13 +9,15 @@ from packaging.version import parse, Version
import platform
from setuptools import setup, find_packages
from setuptools.command.install import install
from setuptools.command.build import build
import subprocess
from setuptools.command.bdist_egg import bdist_egg
import urllib.request
import urllib.error
import torch
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
with open("README.md", "r", encoding="utf-8") as fh:
......@@ -25,6 +27,7 @@ with open("README.md", "r", encoding="utf-8") as fh:
# ninja build does not work unless include_dirs are abs path
this_dir = os.path.dirname(os.path.abspath(__file__))
PACKAGE_NAME = "flash_attn_wheels"
# @pierce - TODO: Update for proper release
BASE_WHEEL_URL = "https://github.com/piercefreeman/flash-attention/releases/download/{tag_name}/{wheel_name}"
......@@ -201,15 +204,17 @@ def get_package_version():
return str(public_version)
# NOTE(review): this span is a rendered diff — removed (`install`-based) and
# added (`_bdist_wheel`-based) lines are interleaved and `@@` hunk markers cut
# the body; treat as two versions of one class, not runnable code.
class CachedWheelsCommand(install):
class CachedWheelsCommand(_bdist_wheel):
"""
Installer hook to scan for existing wheels that match the current platform environment.
Falls back to building from source if no wheel is found.
The CachedWheelsCommand plugs into the default bdist wheel, which is run by pip when it cannot
find an existing wheel (which is currently the case for all flash attention installs). We use
the environment parameters to detect whether there is already a pre-built version of a compatible
wheel available and short-circuits the standard full build pipeline.
"""
def run(self):
# Escape hatch: FLASH_ATTENTION_FORCE_BUILD=TRUE skips the wheel lookup entirely.
if FORCE_BUILD:
return install.run(self)
# NOTE(review): the added line calls build.run(self) even though the new base
# class is _bdist_wheel — confirm `build` here is the intended command.
return build.run(self)
raise_if_cuda_home_none("flash_attn")
......@@ -223,7 +228,7 @@ class CachedWheelsCommand(install):
torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}.{torch_version_raw.micro}"
# Determine wheel URL based on CUDA version, torch version, python version and OS
wheel_filename = f'flash_attn-{flash_version}+cu{cuda_version}torch{torch_version}-{python_version}-{python_version}-{platform_name}.whl'
wheel_filename = f'{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}-{python_version}-{python_version}-{platform_name}.whl'
wheel_url = BASE_WHEEL_URL.format(
tag_name=f"v{flash_version}",
wheel_name=wheel_filename
......@@ -232,17 +237,28 @@ class CachedWheelsCommand(install):
try:
# Download the prebuilt wheel; HTTPError (404) below means it doesn't exist.
urllib.request.urlretrieve(wheel_url, wheel_filename)
os.system(f'pip install {wheel_filename}')
os.remove(wheel_filename)
# Make the archive
# Lifted from the root wheel processing command
# https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
if not os.path.exists(self.dist_dir):
os.makedirs(self.dist_dir)
# Rename the downloaded wheel to the tag-derived filename bdist_wheel would
# have produced, so downstream tooling finds it in dist_dir.
impl_tag, abi_tag, plat_tag = self.get_tag()
archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"
wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
print("Raw wheel path", wheel_path)
os.rename(wheel_filename, wheel_path)
except urllib.error.HTTPError:
print("Precompiled wheel not found. Building from source...")
# If the wheel could not be downloaded, build from source
install.run(self)
super().run()
# NOTE(review): rendered diff — duplicate old/new `name=` and `cmdclass` lines
# and an `@@` marker are interleaved below; the new version registers
# CachedWheelsCommand under 'bdist_wheel' instead of 'install'.
setup(
# @pierce - TODO: Revert for official release
name="flash_attn_wheels",
name=PACKAGE_NAME,
version=get_package_version(),
packages=find_packages(
exclude=("build", "csrc", "include", "tests", "dist", "docs", "benchmarks", "flash_attn.egg-info",)
......@@ -264,10 +280,10 @@ setup(
],
ext_modules=ext_modules,
cmdclass={
'install': CachedWheelsCommand,
'bdist_wheel': CachedWheelsCommand,
"build_ext": BuildExtension
} if ext_modules else {
'install': CachedWheelsCommand,
'bdist_wheel': CachedWheelsCommand,
},
python_requires=">=3.7",
install_requires=[
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment