gaoqiong / flash-attention, commit dc08ea1c

Support H100 for other CUDA extensions

Authored Mar 15, 2023 by Tri Dao. Parent commit: 1b18f1b7.
Showing 7 changed files with 122 additions and 88 deletions (+122, -88):
csrc/ft_attention/setup.py       +31  -22
csrc/fused_dense_lib/setup.py     +5   -6
csrc/layer_norm/setup.py         +26  -19
csrc/rotary/setup.py             +26  -19
csrc/xentropy/setup.py           +26  -19
flash_attn/ops/fused_dense.py     +7   -2
setup.py                          +1   -1
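Every setup.py in the list gets the same two-part change: the nvcc release reported by the toolkit is parsed into a packaging.version.Version instead of being handled as major/minor strings, and an extra sm_90 gencode pair is emitted when the toolkit is CUDA 11.8 or newer, so the extensions also build for H100. A minimal sketch of that shared pattern follows; get_cuda_bare_metal_version mirrors the new code in the diffs, while gencode_flags is an illustrative helper of mine, not a function in the repo.

# Sketch of the version-detection pattern this commit applies to each extension.
# get_cuda_bare_metal_version matches the diffs below; gencode_flags is only an
# illustration of how the new flags are gated, not code from the repo.
import subprocess
from packaging.version import parse, Version


def get_cuda_bare_metal_version(cuda_dir):
    # nvcc -V prints e.g. "Cuda compilation tools, release 11.8, V11.8.89";
    # take the token after "release" and parse "11.8" into a Version.
    raw_output = subprocess.check_output(
        [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
    )
    output = raw_output.split()
    release_idx = output.index("release") + 1
    bare_metal_version = parse(output[release_idx].split(",")[0])
    return raw_output, bare_metal_version


def gencode_flags(bare_metal_version):
    # sm_70 and sm_80 are always compiled; sm_90 (H100) needs CUDA >= 11.8.
    flags = ["-gencode", "arch=compute_70,code=sm_70",
             "-gencode", "arch=compute_80,code=sm_80"]
    if bare_metal_version >= Version("11.8"):
        flags += ["-gencode", "arch=compute_90,code=sm_90"]
    return flags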
csrc/ft_attention/setup.py
 # Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py
-import torch
-from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
-from setuptools import setup, find_packages
-import subprocess
 import sys
 import warnings
 import os
+from packaging.version import parse, Version
+from setuptools import setup, find_packages
+import subprocess
+import torch
+from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME

 # ninja build does not work unless include_dirs are abs path
 this_dir = os.path.dirname(os.path.abspath(__file__))
...
@@ -16,22 +19,19 @@ def get_cuda_bare_metal_version(cuda_dir):
     raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
     output = raw_output.split()
     release_idx = output.index("release") + 1
-    release = output[release_idx].split(".")
-    bare_metal_major = release[0]
-    bare_metal_minor = release[1][0]
-    return raw_output, bare_metal_major, bare_metal_minor
+    bare_metal_version = parse(output[release_idx].split(",")[0])
+    return raw_output, bare_metal_version


 def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
-    raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
-    torch_binary_major = torch.version.cuda.split(".")[0]
-    torch_binary_minor = torch.version.cuda.split(".")[1]
+    raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir)
+    torch_binary_version = parse(torch.version.cuda)

     print("\nCompiling cuda extensions with")
     print(raw_output + "from " + cuda_dir + "/bin\n")

-    if (bare_metal_major != torch_binary_major) or (bare_metal_minor != torch_binary_minor):
+    if (bare_metal_version != torch_binary_version):
         raise RuntimeError(
             "Cuda extensions are being compiled with a version of Cuda that does "
             "not match the version used to compile Pytorch binaries. "
...
@@ -53,8 +53,8 @@ def raise_if_cuda_home_none(global_option: str) -> None:

 def append_nvcc_threads(nvcc_extra_args):
-    _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
-    if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
+    _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+    if bare_metal_version >= Version("11.2"):
         return nvcc_extra_args + ["--threads", "4"]
     return nvcc_extra_args
...
@@ -72,15 +72,18 @@ if not torch.cuda.is_available():
         "If you wish to cross-compile for a single specific architecture,\n"
         'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n',
     )
-    if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
-        _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
-        if int(bare_metal_major) == 11:
-            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
-            if int(bare_metal_minor) > 0:
-                os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
+    if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None:
+        _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+        if bare_metal_version >= Version("11.8"):
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0"
+        elif bare_metal_version >= Version("11.1"):
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
+        elif bare_metal_version == Version("11.0"):
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
         else:
             os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"

 print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
 TORCH_MAJOR = int(torch.__version__.split(".")[0])
 TORCH_MINOR = int(torch.__version__.split(".")[1])
...
@@ -98,10 +101,16 @@ if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.
 raise_if_cuda_home_none("--ft_attention")
 # Check, if CUDA11 is installed for compute capability 8.0
 cc_flag = []
-# cc_flag.append("-gencode")
-# cc_flag.append("arch=compute_70,code=sm_70")
+_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+if bare_metal_version < Version("11.0"):
+    raise RuntimeError("ft_attention is only supported on CUDA 11 and above")
+cc_flag.append("-gencode")
+cc_flag.append("arch=compute_70,code=sm_70")
 cc_flag.append("-gencode")
 cc_flag.append("arch=compute_80,code=sm_80")
+if bare_metal_version >= Version("11.8"):
+    cc_flag.append("-gencode")
+    cc_flag.append("arch=compute_90,code=sm_90")

 ext_modules.append(
     CUDAExtension(
...
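The switch to parsed versions in append_nvcc_threads is not just cosmetic: the old check int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2 would reject a CUDA 12.0 toolkit because its minor version is 0, while the Version comparison orders releases correctly. A small illustration; the assert lines are mine, not part of the commit.

# Old vs. new CUDA version gate for the nvcc --threads flag; asserts are mine.
from packaging.version import parse, Version


def old_check(major, minor):
    # Pre-commit logic: compare major and minor separately as integers.
    return int(major) >= 11 and int(minor) >= 2


def new_check(version_str):
    # Post-commit logic: compare the whole parsed version.
    return parse(version_str) >= Version("11.2")


assert old_check("11", "2") and new_check("11.2")   # both accept CUDA 11.2
assert not old_check("12", "0")                     # old logic wrongly rejects 12.0
assert new_check("12.0")                            # new logic accepts it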
csrc/fused_dense_lib/setup.py
 import os
 import subprocess
+from packaging.version import parse, Version
 import torch
 from setuptools import setup
...
@@ -10,16 +11,14 @@ def get_cuda_bare_metal_version(cuda_dir):
     raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
     output = raw_output.split()
     release_idx = output.index("release") + 1
-    release = output[release_idx].split(".")
-    bare_metal_major = release[0]
-    bare_metal_minor = release[1][0]
-    return raw_output, bare_metal_major, bare_metal_minor
+    bare_metal_version = parse(output[release_idx].split(",")[0])
+    return raw_output, bare_metal_version


 def append_nvcc_threads(nvcc_extra_args):
-    _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
-    if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
+    _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+    if bare_metal_version >= Version("11.2"):
         return nvcc_extra_args + ["--threads", "4"]
     return nvcc_extra_args
...
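For context, the arguments returned by append_nvcc_threads end up as the "nvcc" entry of a CUDAExtension's extra_compile_args; --threads 4 lets nvcc compile the per-architecture passes in parallel, which is why it is gated on CUDA 11.2 and above. The extension name and source files below are placeholders, not the ones used by fused_dense_lib.

# Sketch of where append_nvcc_threads is consumed; the module name and sources
# are placeholders rather than the actual fused_dense_lib configuration.
from torch.utils.cpp_extension import CUDAExtension


def append_nvcc_threads(nvcc_extra_args):
    # Same shape as the helper in these setup.py files (version gate omitted here).
    return nvcc_extra_args + ["--threads", "4"]


ext = CUDAExtension(
    name="example_cuda_ext",                        # placeholder module name
    sources=["example.cpp", "example_kernel.cu"],   # placeholder sources
    extra_compile_args={
        "cxx": ["-O3"],
        "nvcc": append_nvcc_threads(["-O3"]),
    },
)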
csrc/layer_norm/setup.py
 # Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py
+import sys
+import warnings
+import os
+from packaging.version import parse, Version
 import torch
 from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
 from setuptools import setup, find_packages
 import subprocess
-import sys
-import warnings
-import os

 # ninja build does not work unless include_dirs are abs path
 this_dir = os.path.dirname(os.path.abspath(__file__))
...
@@ -16,22 +17,19 @@ def get_cuda_bare_metal_version(cuda_dir):
     raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
     output = raw_output.split()
     release_idx = output.index("release") + 1
-    release = output[release_idx].split(".")
-    bare_metal_major = release[0]
-    bare_metal_minor = release[1][0]
-    return raw_output, bare_metal_major, bare_metal_minor
+    bare_metal_version = parse(output[release_idx].split(",")[0])
+    return raw_output, bare_metal_version


 def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
-    raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
-    torch_binary_major = torch.version.cuda.split(".")[0]
-    torch_binary_minor = torch.version.cuda.split(".")[1]
+    raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir)
+    torch_binary_version = parse(torch.version.cuda)

     print("\nCompiling cuda extensions with")
     print(raw_output + "from " + cuda_dir + "/bin\n")

-    if (bare_metal_major != torch_binary_major) or (bare_metal_minor != torch_binary_minor):
+    if (bare_metal_version != torch_binary_version):
         raise RuntimeError(
             "Cuda extensions are being compiled with a version of Cuda that does "
             "not match the version used to compile Pytorch binaries. "
...
@@ -53,8 +51,8 @@ def raise_if_cuda_home_none(global_option: str) -> None:

 def append_nvcc_threads(nvcc_extra_args):
-    _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
-    if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
+    _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+    if bare_metal_version >= Version("11.2"):
         return nvcc_extra_args + ["--threads", "4"]
     return nvcc_extra_args
...
@@ -72,15 +70,18 @@ if not torch.cuda.is_available():
         "If you wish to cross-compile for a single specific architecture,\n"
         'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n',
     )
-    if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
-        _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
-        if int(bare_metal_major) == 11:
-            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
-            if int(bare_metal_minor) > 0:
-                os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
+    if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None:
+        _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+        if bare_metal_version >= Version("11.8"):
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0"
+        elif bare_metal_version >= Version("11.1"):
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
+        elif bare_metal_version == Version("11.0"):
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
         else:
             os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"

 print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
 TORCH_MAJOR = int(torch.__version__.split(".")[0])
 TORCH_MINOR = int(torch.__version__.split(".")[1])
...
@@ -98,10 +99,16 @@ if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.
 raise_if_cuda_home_none("--fast_layer_norm")
 # Check, if CUDA11 is installed for compute capability 8.0
 cc_flag = []
+_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+if bare_metal_version < Version("11.0"):
+    raise RuntimeError("dropout_layer_norm is only supported on CUDA 11 and above")
 cc_flag.append("-gencode")
 cc_flag.append("arch=compute_70,code=sm_70")
 cc_flag.append("-gencode")
 cc_flag.append("arch=compute_80,code=sm_80")
+if bare_metal_version >= Version("11.8"):
+    cc_flag.append("-gencode")
+    cc_flag.append("arch=compute_90,code=sm_90")

 ext_modules.append(
     CUDAExtension(
...
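The TORCH_CUDA_ARCH_LIST fallback that layer_norm and the other extensions now share picks the newest architecture list the detected toolkit can actually target, adding 8.6 from CUDA 11.1 and 9.0 (Hopper) from CUDA 11.8. The branch logic below matches the diff; wrapping it in a standalone function is my framing, since the repo sets os.environ directly.

# Default TORCH_CUDA_ARCH_LIST selection; the function wrapper is my framing,
# the thresholds and architecture lists come from the diff above.
from packaging.version import Version, parse


def default_arch_list(bare_metal_version):
    if bare_metal_version >= Version("11.8"):
        return "6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0"   # adds Hopper (sm_90)
    elif bare_metal_version >= Version("11.1"):
        return "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
    elif bare_metal_version == Version("11.0"):
        return "6.0;6.1;6.2;7.0;7.5;8.0"
    else:
        return "6.0;6.1;6.2;7.0;7.5"


assert default_arch_list(parse("12.1")).endswith("9.0")
assert default_arch_list(parse("11.3")).endswith("8.6")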
csrc/rotary/setup.py
 # Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py
+import sys
+import warnings
+import os
+from packaging.version import parse, Version
 import torch
 from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
 from setuptools import setup, find_packages
 import subprocess
-import sys
-import warnings
-import os

 # ninja build does not work unless include_dirs are abs path
 this_dir = os.path.dirname(os.path.abspath(__file__))
...
@@ -16,22 +17,19 @@ def get_cuda_bare_metal_version(cuda_dir):
     raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
     output = raw_output.split()
     release_idx = output.index("release") + 1
-    release = output[release_idx].split(".")
-    bare_metal_major = release[0]
-    bare_metal_minor = release[1][0]
-    return raw_output, bare_metal_major, bare_metal_minor
+    bare_metal_version = parse(output[release_idx].split(",")[0])
+    return raw_output, bare_metal_version


 def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
-    raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
-    torch_binary_major = torch.version.cuda.split(".")[0]
-    torch_binary_minor = torch.version.cuda.split(".")[1]
+    raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir)
+    torch_binary_version = parse(torch.version.cuda)

     print("\nCompiling cuda extensions with")
     print(raw_output + "from " + cuda_dir + "/bin\n")

-    if (bare_metal_major != torch_binary_major) or (bare_metal_minor != torch_binary_minor):
+    if (bare_metal_version != torch_binary_version):
         raise RuntimeError(
             "Cuda extensions are being compiled with a version of Cuda that does "
             "not match the version used to compile Pytorch binaries. "
...
@@ -53,8 +51,8 @@ def raise_if_cuda_home_none(global_option: str) -> None:

 def append_nvcc_threads(nvcc_extra_args):
-    _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
-    if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
+    _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+    if bare_metal_version >= Version("11.2"):
         return nvcc_extra_args + ["--threads", "4"]
     return nvcc_extra_args
...
@@ -72,15 +70,18 @@ if not torch.cuda.is_available():
         "If you wish to cross-compile for a single specific architecture,\n"
         'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n',
     )
-    if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
-        _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
-        if int(bare_metal_major) == 11:
-            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
-            if int(bare_metal_minor) > 0:
-                os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
+    if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None:
+        _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+        if bare_metal_version >= Version("11.8"):
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0"
+        elif bare_metal_version >= Version("11.1"):
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
+        elif bare_metal_version == Version("11.0"):
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
         else:
             os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"

 print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
 TORCH_MAJOR = int(torch.__version__.split(".")[0])
 TORCH_MINOR = int(torch.__version__.split(".")[1])
...
@@ -91,10 +92,16 @@ ext_modules = []
 raise_if_cuda_home_none("rotary_emb")
 # Check, if CUDA11 is installed for compute capability 8.0
 cc_flag = []
+_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+if bare_metal_version < Version("11.0"):
+    raise RuntimeError("rotary_emb is only supported on CUDA 11 and above")
 cc_flag.append("-gencode")
 cc_flag.append("arch=compute_70,code=sm_70")
 cc_flag.append("-gencode")
 cc_flag.append("arch=compute_80,code=sm_80")
+if bare_metal_version >= Version("11.8"):
+    cc_flag.append("-gencode")
+    cc_flag.append("arch=compute_90,code=sm_90")

 ext_modules.append(
     CUDAExtension(
...
csrc/xentropy/setup.py
 # Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py
+import sys
+import warnings
+import os
+from packaging.version import parse, Version
 import torch
 from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
 from setuptools import setup, find_packages
 import subprocess
-import sys
-import warnings
-import os

 # ninja build does not work unless include_dirs are abs path
 this_dir = os.path.dirname(os.path.abspath(__file__))
...
@@ -16,22 +17,19 @@ def get_cuda_bare_metal_version(cuda_dir):
     raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
     output = raw_output.split()
     release_idx = output.index("release") + 1
-    release = output[release_idx].split(".")
-    bare_metal_major = release[0]
-    bare_metal_minor = release[1][0]
-    return raw_output, bare_metal_major, bare_metal_minor
+    bare_metal_version = parse(output[release_idx].split(",")[0])
+    return raw_output, bare_metal_version


 def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
-    raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
-    torch_binary_major = torch.version.cuda.split(".")[0]
-    torch_binary_minor = torch.version.cuda.split(".")[1]
+    raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir)
+    torch_binary_version = parse(torch.version.cuda)

     print("\nCompiling cuda extensions with")
     print(raw_output + "from " + cuda_dir + "/bin\n")

-    if (bare_metal_major != torch_binary_major) or (bare_metal_minor != torch_binary_minor):
+    if (bare_metal_version != torch_binary_version):
         raise RuntimeError(
             "Cuda extensions are being compiled with a version of Cuda that does "
             "not match the version used to compile Pytorch binaries. "
...
@@ -53,8 +51,8 @@ def raise_if_cuda_home_none(global_option: str) -> None:

 def append_nvcc_threads(nvcc_extra_args):
-    _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
-    if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
+    _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+    if bare_metal_version >= Version("11.2"):
         return nvcc_extra_args + ["--threads", "4"]
     return nvcc_extra_args
...
@@ -72,15 +70,18 @@ if not torch.cuda.is_available():
         "If you wish to cross-compile for a single specific architecture,\n"
         'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n',
     )
-    if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
-        _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
-        if int(bare_metal_major) == 11:
-            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
-            if int(bare_metal_minor) > 0:
-                os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
+    if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None:
+        _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+        if bare_metal_version >= Version("11.8"):
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0"
+        elif bare_metal_version >= Version("11.1"):
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
+        elif bare_metal_version == Version("11.0"):
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
         else:
             os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"

 print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
 TORCH_MAJOR = int(torch.__version__.split(".")[0])
 TORCH_MINOR = int(torch.__version__.split(".")[1])
...
@@ -98,10 +99,16 @@ if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.
 raise_if_cuda_home_none("--xentropy")
 # Check, if CUDA11 is installed for compute capability 8.0
 cc_flag = []
+_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+if bare_metal_version < Version("11.0"):
+    raise RuntimeError("xentropy is only supported on CUDA 11 and above")
 cc_flag.append("-gencode")
 cc_flag.append("arch=compute_70,code=sm_70")
 cc_flag.append("-gencode")
 cc_flag.append("arch=compute_80,code=sm_80")
+if bare_metal_version >= Version("11.8"):
+    cc_flag.append("-gencode")
+    cc_flag.append("arch=compute_90,code=sm_90")

 ext_modules.append(
     CUDAExtension(
...
flash_attn/ops/fused_dense.py
...
@@ -421,6 +421,8 @@ class FusedMLP(nn.Module):
         'auto': heuristic will be picked automatically:
             For CUDA >= 11.8, we set heuristic=0 for both fp16 and bf16 for best perf.
             For CUDA <= 11.7, we set heuristic=1 for fp16 and heuristic=-1 for bf16.
+            For H100, we set heuristic=-1 for both fp16 and bf16 as the fused cuBlasLt implementation
+            is slower than the unfused version.
         return_residual: whether to return the input x along with the output. This is for
             performance reason: for post-norm architecture, returning the input allows us
             to fuse the backward of nn.Linear with the residual connection.
...
@@ -442,6 +444,9 @@ class FusedMLP(nn.Module):
         dtype = x.dtype if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype()
         if self.heuristic == 'auto':
             if self.activation == 'gelu_approx':
-                cuda_ver = tuple(map(int, torch.version.cuda.split('.')))
-                heuristic = 0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1)
+                if torch.cuda.get_device_capability('cuda') == (9, 0):
+                    heuristic = -1
+                else:
+                    cuda_ver = tuple(map(int, torch.version.cuda.split('.')))
+                    heuristic = 0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1)
             else:
...
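The FusedMLP change adds an H100 special case to the 'auto' heuristic: on devices with compute capability (9, 0) the fused cuBlasLt epilogue is slower than the unfused path, so heuristic is forced to -1 regardless of dtype. A standalone sketch of the selection logic follows; the free-function form and explicit arguments are my framing, while in the repo this runs inline in FusedMLP.forward for the 'gelu_approx' activation.

# Sketch of the post-commit 'auto' heuristic choice; the function wrapper and
# its arguments are my framing, the decision rules come from the diff above.
import torch


def pick_heuristic(dtype, device_capability, cuda_version):
    if device_capability == (9, 0):
        # H100: fused cuBlasLt implementation is slower than the unfused version.
        return -1
    cuda_ver = tuple(map(int, cuda_version.split('.')))
    if cuda_ver >= (11, 8):
        return 0                                   # best perf for fp16 and bf16
    return 1 if dtype == torch.float16 else -1     # CUDA <= 11.7


# Example: an A100 (8, 0) on CUDA 11.7 with bf16 falls back to heuristic=-1.
print(pick_heuristic(torch.bfloat16, (8, 0), "11.7"))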
setup.py
...
@@ -108,7 +108,7 @@ raise_if_cuda_home_none("flash_attn")
 cc_flag = []
 _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
 if bare_metal_version < Version("11.0"):
-    raise RuntimeError("FlashAttention is only supported on CUDA 11")
+    raise RuntimeError("FlashAttention is only supported on CUDA 11 and above")
 cc_flag.append("-gencode")
 cc_flag.append("arch=compute_75,code=sm_75")
 cc_flag.append("-gencode")
...