Commit 5e469978 authored by Pierce Freeman's avatar Pierce Freeman
Browse files

Allow fallback install

parent dab99053
...@@ -109,6 +109,7 @@ jobs: ...@@ -109,6 +109,7 @@ jobs:
- name: Build wheel - name: Build wheel
run: | run: |
export FLASH_ATTENTION_FORCE_BUILD="TRUE"
export FORCE_CUDA="1" export FORCE_CUDA="1"
export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH
export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
......
...@@ -44,6 +44,9 @@ BASE_WHEEL_URL = "https://github.com/piercefreeman/flash-attention/releases/down ...@@ -44,6 +44,9 @@ BASE_WHEEL_URL = "https://github.com/piercefreeman/flash-attention/releases/down
class CustomInstallCommand(install): class CustomInstallCommand(install):
def run(self): def run(self):
if os.getenv("FLASH_ATTENTION_FORCE_BUILD", "FALSE") == "TRUE":
return install.run(self)
raise_if_cuda_home_none("flash_attn") raise_if_cuda_home_none("flash_attn")
# Determine the version numbers that will be used to determine the correct wheel # Determine the version numbers that will be used to determine the correct wheel
...@@ -59,7 +62,7 @@ class CustomInstallCommand(install): ...@@ -59,7 +62,7 @@ class CustomInstallCommand(install):
wheel_url = BASE_WHEEL_URL.format( wheel_url = BASE_WHEEL_URL.format(
#tag_name=f"v{flash_version}", #tag_name=f"v{flash_version}",
# HACK # HACK
tag_name=f"v0.0.3", tag_name=f"v0.0.5",
wheel_name=wheel_filename wheel_name=wheel_filename
) )
print("Guessing wheel URL: ", wheel_url) print("Guessing wheel URL: ", wheel_url)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment