"vscode:/vscode.git/clone" did not exist on "35fd84be2745a1d1d288b16dae8baf766c7af294"
Commit 0e7769c8 authored by Pierce Freeman's avatar Pierce Freeman
Browse files

Guessing wheel URL

parent e1faefce
...@@ -47,18 +47,22 @@ class CustomInstallCommand(install): ...@@ -47,18 +47,22 @@ class CustomInstallCommand(install):
raise_if_cuda_home_none("flash_attn") raise_if_cuda_home_none("flash_attn")
# Determine the version numbers that will be used to determine the correct wheel # Determine the version numbers that will be used to determine the correct wheel
_, cuda_version = get_cuda_bare_metal_version(CUDA_HOME) _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
torch_version = torch.__version__ torch_version = torch.__version__
python_version = f"cp{sys.version_info.major}{sys.version_info.minor}" python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
platform_name = get_platform() platform_name = get_platform()
flash_version = get_package_version() flash_version = get_package_version()
cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
# Determine wheel URL based on CUDA version, torch version, python version and OS # Determine wheel URL based on CUDA version, torch version, python version and OS
wheel_filename = f'flash_attn-{flash_version}+cu{cuda_version}torch{torch_version}-{python_version}-{python_version}-{platform_name}.whl' wheel_filename = f'flash_attn-{flash_version}+cu{cuda_version}torch{torch_version}-{python_version}-{python_version}-{platform_name}.whl'
wheel_url = BASE_WHEEL_URL.format( wheel_url = BASE_WHEEL_URL.format(
tag_name=f"v{flash_version}", #tag_name=f"v{flash_version}",
# HACK
tag_name=f"v0.0.3",
wheel_name=wheel_filename wheel_name=wheel_filename
) )
print("Guessing wheel URL: ", wheel_url)
try: try:
urllib.request.urlretrieve(wheel_url, wheel_filename) urllib.request.urlretrieve(wheel_url, wheel_filename)
...@@ -70,8 +74,6 @@ class CustomInstallCommand(install): ...@@ -70,8 +74,6 @@ class CustomInstallCommand(install):
#install.run(self) #install.run(self)
raise ValueError raise ValueError
raise ValueError
def get_cuda_bare_metal_version(cuda_dir): def get_cuda_bare_metal_version(cuda_dir):
raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment