"torchvision/vscode:/vscode.git/clone" did not exist on "6c56029046cffecf8d2b0e6094507fbee0c0daa2"
Commit 494b2aa4 authored by Pierce Freeman's avatar Pierce Freeman
Browse files

Add notes to github action workflow

parent 8d60c373
# This workflow will upload a Python Package to Release asset # This workflow will:
# - Create a new GitHub release
# - Build wheels for supported architectures
# - Upload the wheels to the GitHub release
# - Publish the source distribution (sdist) to PyPI
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
name: Build wheels and deploy
name: Python Package
on: on:
create: create:
......
...@@ -57,6 +57,14 @@ To install: ...@@ -57,6 +57,14 @@ To install:
pip install flash-attn pip install flash-attn
``` ```
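If you see `ModuleNotFoundError: No module named 'torch'`, it is most likely caused by pip's build isolation: the package is built in a clean environment in which `torch` is not installed.
To fix this, make sure `torch` is already installed in your environment and run: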
```sh
pip install flash-attn --no-build-isolation
```
Alternatively you can compile from source: Alternatively you can compile from source:
``` ```
python setup.py install python setup.py install
......
__version__ = "1.0.7" __version__ = "1.0.8"
## flash-attn-builder
Basic build utilities for flash-attn.
import os
import sys
import urllib
import setuptools.build_meta
from setuptools.command.install import install
from packaging.version import parse, Version
# Download URL template for prebuilt wheels; {tag_name} and {wheel_name} are
# placeholders to be filled in by the caller.
# @pierce - TODO: Update for proper release
BASE_WHEEL_URL = "https://github.com/piercefreeman/flash-attention/releases/download/{tag_name}/{wheel_name}"
# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels.
# It is enabled only when the FLASH_ATTENTION_FORCE_BUILD environment variable is exactly
# "TRUE" (the comparison is case-sensitive); unset or any other value leaves it False.
# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
# NOTE(review): no SKIP_CUDA_BUILD variable is defined in the visible part of this file - confirm where it lives.
FORCE_BUILD = os.getenv("FLASH_ATTENTION_FORCE_BUILD", "FALSE") == "TRUE"
class CustomBuildBackend(setuptools.build_meta._BuildMetaBackend):
    """PEP 517 build backend that prefers a prebuilt wheel over a source build.

    ``build_wheel`` first tries to download a precompiled wheel from the URL
    in the ``FLASH_ATTENTION_WHEEL_URL`` environment variable and only falls
    back to the regular setuptools build when no such wheel can be obtained.
    """

    def build_wheel(self, wheel_directory, config_settings=None, metadata_directory=None):
        """Produce a wheel in ``wheel_directory`` and return its basename.

        Args:
            wheel_directory: Directory the resulting ``.whl`` file is written to.
            config_settings: Optional settings passed through by the build
                frontend; forwarded to setuptools unchanged.
            metadata_directory: Optional directory holding metadata prepared
                earlier; forwarded to setuptools unchanged.

        Returns:
            The basename of the wheel placed in ``wheel_directory``, as the
            PEP 517 ``build_wheel`` hook requires.
        """
        if FORCE_BUILD:
            # The user explicitly asked for a local build: skip the lookup.
            return super().build_wheel(wheel_directory, config_settings, metadata_directory)

        this_file_directory = os.path.dirname(os.path.abspath(__file__))
        print(f'This file is located in: {this_file_directory}')

        sys.argv = [
            *sys.argv[:1],
            *self._global_args(config_settings),
            *self._arbitrary_args(config_settings),
        ]
        with setuptools.build_meta.no_install_setup_requires():
            self.run_setup()

        # NOTE(review): the variable is read only after setup.py has run, so
        # setup.py presumably exports it - confirm. ``get`` is used because the
        # variable may legitimately be absent, which means "no prebuilt wheel".
        wheel_url = os.environ.get("FLASH_ATTENTION_WHEEL_URL")
        if wheel_url:
            print("Guessing wheel URL: ", wheel_url)
            wheel_basename = self._download_prebuilt_wheel(wheel_url, wheel_directory)
            if wheel_basename is not None:
                return wheel_basename

        print("Precompiled wheel not found. Building from source...")
        # If the wheel could not be downloaded, build from source
        return super().build_wheel(wheel_directory, config_settings, metadata_directory)

    @staticmethod
    def _download_prebuilt_wheel(wheel_url, wheel_directory):
        """Download ``wheel_url`` into ``wheel_directory``.

        Returns the downloaded wheel's basename, or ``None`` when the URL does
        not name a ``.whl`` file or the download fails.
        """
        # ``import urllib`` alone does not guarantee that the submodules are
        # loaded, so import the ones used here explicitly.
        import urllib.error
        import urllib.parse
        import urllib.request

        # Local version tags contain "+", which may arrive percent-encoded.
        url_path = urllib.parse.urlparse(wheel_url).path
        wheel_filename = os.path.basename(urllib.parse.unquote(url_path))
        if not wheel_filename.endswith(".whl"):
            return None

        os.makedirs(wheel_directory, exist_ok=True)
        wheel_path = os.path.join(wheel_directory, wheel_filename)
        try:
            # The build frontend installs whatever lands in wheel_directory;
            # the backend must not invoke pip itself.
            urllib.request.urlretrieve(wheel_url, wheel_path)
        except urllib.error.URLError:
            # HTTPError (e.g. 404) is a URLError subclass. Remove any partial
            # download so that a broken wheel is never left behind.
            try:
                os.remove(wheel_path)
            except FileNotFoundError:
                pass
            return None
        return wheel_filename
# Single backend instance whose bound methods are re-exported below as
# module-level names.
_BACKEND = CustomBuildBackend() # noqa
# Module-level aliases for the backend hooks. build_wheel resolves to the
# CustomBuildBackend override; the other names are not overridden in the
# visible class and therefore come from its setuptools base class.
get_requires_for_build_wheel = _BACKEND.get_requires_for_build_wheel
get_requires_for_build_sdist = _BACKEND.get_requires_for_build_sdist
prepare_metadata_for_build_wheel = _BACKEND.prepare_metadata_for_build_wheel
build_wheel = _BACKEND.build_wheel
build_sdist = _BACKEND.build_sdist
[tool.poetry]
name = "flash-attn-builder"
version = "0.1.0"
description = ""
authors = ["Pierce Freeman <pierce@freeman.vc>"]
readme = "README.md"
packages = [{include = "flash_attn_builder"}]
[tool.poetry.dependencies]
python = "^3.10"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
[build-system]
requires = ["ninja", "packaging", "setuptools", "wheel"]
build-backend = "setuptools.build_meta"
...@@ -9,13 +9,15 @@ from packaging.version import parse, Version ...@@ -9,13 +9,15 @@ from packaging.version import parse, Version
import platform import platform
from setuptools import setup, find_packages from setuptools import setup, find_packages
from setuptools.command.install import install from setuptools.command.build import build
import subprocess import subprocess
from setuptools.command.bdist_egg import bdist_egg
import urllib.request import urllib.request
import urllib.error import urllib.error
import torch import torch
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
with open("README.md", "r", encoding="utf-8") as fh: with open("README.md", "r", encoding="utf-8") as fh:
...@@ -25,6 +27,7 @@ with open("README.md", "r", encoding="utf-8") as fh: ...@@ -25,6 +27,7 @@ with open("README.md", "r", encoding="utf-8") as fh:
# ninja build does not work unless include_dirs are abs path # ninja build does not work unless include_dirs are abs path
this_dir = os.path.dirname(os.path.abspath(__file__)) this_dir = os.path.dirname(os.path.abspath(__file__))
PACKAGE_NAME = "flash_attn_wheels"
# @pierce - TODO: Update for proper release # @pierce - TODO: Update for proper release
BASE_WHEEL_URL = "https://github.com/piercefreeman/flash-attention/releases/download/{tag_name}/{wheel_name}" BASE_WHEEL_URL = "https://github.com/piercefreeman/flash-attention/releases/download/{tag_name}/{wheel_name}"
...@@ -201,15 +204,17 @@ def get_package_version(): ...@@ -201,15 +204,17 @@ def get_package_version():
return str(public_version) return str(public_version)
class CachedWheelsCommand(install): class CachedWheelsCommand(_bdist_wheel):
""" """
Installer hook to scan for existing wheels that match the current platform environment. The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot
Falls back to building from source if no wheel is found. find an existing wheel (which is currently the case for all flash attention installs). We use
the environment parameters to detect whether there is already a pre-built version of a compatible
wheel available and short-circuits the standard full build pipeline.
""" """
def run(self): def run(self):
if FORCE_BUILD: if FORCE_BUILD:
return install.run(self) return build.run(self)
raise_if_cuda_home_none("flash_attn") raise_if_cuda_home_none("flash_attn")
...@@ -223,7 +228,7 @@ class CachedWheelsCommand(install): ...@@ -223,7 +228,7 @@ class CachedWheelsCommand(install):
torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}.{torch_version_raw.micro}" torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}.{torch_version_raw.micro}"
# Determine wheel URL based on CUDA version, torch version, python version and OS # Determine wheel URL based on CUDA version, torch version, python version and OS
wheel_filename = f'flash_attn-{flash_version}+cu{cuda_version}torch{torch_version}-{python_version}-{python_version}-{platform_name}.whl' wheel_filename = f'{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}-{python_version}-{python_version}-{platform_name}.whl'
wheel_url = BASE_WHEEL_URL.format( wheel_url = BASE_WHEEL_URL.format(
tag_name=f"v{flash_version}", tag_name=f"v{flash_version}",
wheel_name=wheel_filename wheel_name=wheel_filename
...@@ -232,17 +237,28 @@ class CachedWheelsCommand(install): ...@@ -232,17 +237,28 @@ class CachedWheelsCommand(install):
try: try:
urllib.request.urlretrieve(wheel_url, wheel_filename) urllib.request.urlretrieve(wheel_url, wheel_filename)
os.system(f'pip install {wheel_filename}')
os.remove(wheel_filename) # Make the archive
# Lifted from the root wheel processing command
# https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
if not os.path.exists(self.dist_dir):
os.makedirs(self.dist_dir)
impl_tag, abi_tag, plat_tag = self.get_tag()
archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"
wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
print("Raw wheel path", wheel_path)
os.rename(wheel_filename, wheel_path)
except urllib.error.HTTPError: except urllib.error.HTTPError:
print("Precompiled wheel not found. Building from source...") print("Precompiled wheel not found. Building from source...")
# If the wheel could not be downloaded, build from source # If the wheel could not be downloaded, build from source
install.run(self) super().run()
setup( setup(
# @pierce - TODO: Revert for official release # @pierce - TODO: Revert for official release
name="flash_attn_wheels", name=PACKAGE_NAME,
version=get_package_version(), version=get_package_version(),
packages=find_packages( packages=find_packages(
exclude=("build", "csrc", "include", "tests", "dist", "docs", "benchmarks", "flash_attn.egg-info",) exclude=("build", "csrc", "include", "tests", "dist", "docs", "benchmarks", "flash_attn.egg-info",)
...@@ -264,10 +280,10 @@ setup( ...@@ -264,10 +280,10 @@ setup(
], ],
ext_modules=ext_modules, ext_modules=ext_modules,
cmdclass={ cmdclass={
'install': CachedWheelsCommand, 'bdist_wheel': CachedWheelsCommand,
"build_ext": BuildExtension "build_ext": BuildExtension
} if ext_modules else { } if ext_modules else {
'install': CachedWheelsCommand, 'bdist_wheel': CachedWheelsCommand,
}, },
python_requires=">=3.7", python_requires=">=3.7",
install_requires=[ install_requires=[
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment