Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
0e7769c8
Commit
0e7769c8
authored
Jun 02, 2023
by
Pierce Freeman
Browse files
Guessing wheel URL
parent
e1faefce
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
6 additions
and
4 deletions
+6
-4
setup.py
setup.py
+6
-4
No files found.
setup.py
View file @
0e7769c8
...
@@ -47,18 +47,22 @@ class CustomInstallCommand(install):
...
@@ -47,18 +47,22 @@ class CustomInstallCommand(install):
raise_if_cuda_home_none
(
"flash_attn"
)
raise_if_cuda_home_none
(
"flash_attn"
)
# Determine the version numbers that will be used to determine the correct wheel
# Determine the version numbers that will be used to determine the correct wheel
_
,
cuda_version
=
get_cuda_bare_metal_version
(
CUDA_HOME
)
_
,
cuda_version
_raw
=
get_cuda_bare_metal_version
(
CUDA_HOME
)
torch_version
=
torch
.
__version__
torch_version
=
torch
.
__version__
python_version
=
f
"cp
{
sys
.
version_info
.
major
}{
sys
.
version_info
.
minor
}
"
python_version
=
f
"cp
{
sys
.
version_info
.
major
}{
sys
.
version_info
.
minor
}
"
platform_name
=
get_platform
()
platform_name
=
get_platform
()
flash_version
=
get_package_version
()
flash_version
=
get_package_version
()
cuda_version
=
f
"
{
cuda_version_raw
.
major
}{
cuda_version_raw
.
minor
}
"
# Determine wheel URL based on CUDA version, torch version, python version and OS
# Determine wheel URL based on CUDA version, torch version, python version and OS
wheel_filename
=
f
'flash_attn-
{
flash_version
}
+cu
{
cuda_version
}
torch
{
torch_version
}
-
{
python_version
}
-
{
python_version
}
-
{
platform_name
}
.whl'
wheel_filename
=
f
'flash_attn-
{
flash_version
}
+cu
{
cuda_version
}
torch
{
torch_version
}
-
{
python_version
}
-
{
python_version
}
-
{
platform_name
}
.whl'
wheel_url
=
BASE_WHEEL_URL
.
format
(
wheel_url
=
BASE_WHEEL_URL
.
format
(
tag_name
=
f
"v
{
flash_version
}
"
,
#tag_name=f"v{flash_version}",
# HACK
tag_name
=
f
"v0.0.3"
,
wheel_name
=
wheel_filename
wheel_name
=
wheel_filename
)
)
print
(
"Guessing wheel URL: "
,
wheel_url
)
try
:
try
:
urllib
.
request
.
urlretrieve
(
wheel_url
,
wheel_filename
)
urllib
.
request
.
urlretrieve
(
wheel_url
,
wheel_filename
)
...
@@ -70,8 +74,6 @@ class CustomInstallCommand(install):
...
@@ -70,8 +74,6 @@ class CustomInstallCommand(install):
#install.run(self)
#install.run(self)
raise
ValueError
raise
ValueError
raise
ValueError
def
get_cuda_bare_metal_version
(
cuda_dir
):
def
get_cuda_bare_metal_version
(
cuda_dir
):
raw_output
=
subprocess
.
check_output
([
cuda_dir
+
"/bin/nvcc"
,
"-V"
],
universal_newlines
=
True
)
raw_output
=
subprocess
.
check_output
([
cuda_dir
+
"/bin/nvcc"
,
"-V"
],
universal_newlines
=
True
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment