Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
50896ec5
Unverified
Commit
50896ec5
authored
Mar 14, 2024
by
Chirag Jain
Committed by
GitHub
Mar 13, 2024
Browse files
Make nvcc threads configurable via environment variable (#885)
parent
6c9e60de
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
14 additions
and
7 deletions
+14
-7
csrc/ft_attention/setup.py
csrc/ft_attention/setup.py
+2
-1
csrc/fused_dense_lib/setup.py
csrc/fused_dense_lib/setup.py
+2
-1
csrc/fused_softmax/setup.py
csrc/fused_softmax/setup.py
+2
-1
csrc/layer_norm/setup.py
csrc/layer_norm/setup.py
+2
-1
csrc/rotary/setup.py
csrc/rotary/setup.py
+2
-1
csrc/xentropy/setup.py
csrc/xentropy/setup.py
+2
-1
setup.py
setup.py
+2
-1
No files found.
csrc/ft_attention/setup.py
View file @
50896ec5
...
@@ -55,7 +55,8 @@ def raise_if_cuda_home_none(global_option: str) -> None:
...
@@ -55,7 +55,8 @@ def raise_if_cuda_home_none(global_option: str) -> None:
def
append_nvcc_threads
(
nvcc_extra_args
):
def
append_nvcc_threads
(
nvcc_extra_args
):
_
,
bare_metal_version
=
get_cuda_bare_metal_version
(
CUDA_HOME
)
_
,
bare_metal_version
=
get_cuda_bare_metal_version
(
CUDA_HOME
)
if
bare_metal_version
>=
Version
(
"11.2"
):
if
bare_metal_version
>=
Version
(
"11.2"
):
return
nvcc_extra_args
+
[
"--threads"
,
"4"
]
nvcc_threads
=
os
.
getenv
(
"NVCC_THREADS"
)
or
"4"
return
nvcc_extra_args
+
[
"--threads"
,
nvcc_threads
]
return
nvcc_extra_args
return
nvcc_extra_args
...
...
csrc/fused_dense_lib/setup.py
View file @
50896ec5
...
@@ -19,7 +19,8 @@ def get_cuda_bare_metal_version(cuda_dir):
...
@@ -19,7 +19,8 @@ def get_cuda_bare_metal_version(cuda_dir):
def
append_nvcc_threads
(
nvcc_extra_args
):
def
append_nvcc_threads
(
nvcc_extra_args
):
_
,
bare_metal_version
=
get_cuda_bare_metal_version
(
CUDA_HOME
)
_
,
bare_metal_version
=
get_cuda_bare_metal_version
(
CUDA_HOME
)
if
bare_metal_version
>=
Version
(
"11.2"
):
if
bare_metal_version
>=
Version
(
"11.2"
):
return
nvcc_extra_args
+
[
"--threads"
,
"4"
]
nvcc_threads
=
os
.
getenv
(
"NVCC_THREADS"
)
or
"4"
return
nvcc_extra_args
+
[
"--threads"
,
nvcc_threads
]
return
nvcc_extra_args
return
nvcc_extra_args
...
...
csrc/fused_softmax/setup.py
View file @
50896ec5
...
@@ -22,7 +22,8 @@ def get_cuda_bare_metal_version(cuda_dir):
...
@@ -22,7 +22,8 @@ def get_cuda_bare_metal_version(cuda_dir):
def
append_nvcc_threads
(
nvcc_extra_args
):
def
append_nvcc_threads
(
nvcc_extra_args
):
_
,
bare_metal_major
,
bare_metal_minor
=
get_cuda_bare_metal_version
(
CUDA_HOME
)
_
,
bare_metal_major
,
bare_metal_minor
=
get_cuda_bare_metal_version
(
CUDA_HOME
)
if
int
(
bare_metal_major
)
>=
11
and
int
(
bare_metal_minor
)
>=
2
:
if
int
(
bare_metal_major
)
>=
11
and
int
(
bare_metal_minor
)
>=
2
:
return
nvcc_extra_args
+
[
"--threads"
,
"4"
]
nvcc_threads
=
os
.
getenv
(
"NVCC_THREADS"
)
or
"4"
return
nvcc_extra_args
+
[
"--threads"
,
nvcc_threads
]
return
nvcc_extra_args
return
nvcc_extra_args
...
...
csrc/layer_norm/setup.py
View file @
50896ec5
...
@@ -53,7 +53,8 @@ def raise_if_cuda_home_none(global_option: str) -> None:
...
@@ -53,7 +53,8 @@ def raise_if_cuda_home_none(global_option: str) -> None:
def
append_nvcc_threads
(
nvcc_extra_args
):
def
append_nvcc_threads
(
nvcc_extra_args
):
_
,
bare_metal_version
=
get_cuda_bare_metal_version
(
CUDA_HOME
)
_
,
bare_metal_version
=
get_cuda_bare_metal_version
(
CUDA_HOME
)
if
bare_metal_version
>=
Version
(
"11.2"
):
if
bare_metal_version
>=
Version
(
"11.2"
):
return
nvcc_extra_args
+
[
"--threads"
,
"4"
]
nvcc_threads
=
os
.
getenv
(
"NVCC_THREADS"
)
or
"4"
return
nvcc_extra_args
+
[
"--threads"
,
nvcc_threads
]
return
nvcc_extra_args
return
nvcc_extra_args
...
...
csrc/rotary/setup.py
View file @
50896ec5
...
@@ -53,7 +53,8 @@ def raise_if_cuda_home_none(global_option: str) -> None:
...
@@ -53,7 +53,8 @@ def raise_if_cuda_home_none(global_option: str) -> None:
def
append_nvcc_threads
(
nvcc_extra_args
):
def
append_nvcc_threads
(
nvcc_extra_args
):
_
,
bare_metal_version
=
get_cuda_bare_metal_version
(
CUDA_HOME
)
_
,
bare_metal_version
=
get_cuda_bare_metal_version
(
CUDA_HOME
)
if
bare_metal_version
>=
Version
(
"11.2"
):
if
bare_metal_version
>=
Version
(
"11.2"
):
return
nvcc_extra_args
+
[
"--threads"
,
"4"
]
nvcc_threads
=
os
.
getenv
(
"NVCC_THREADS"
)
or
"4"
return
nvcc_extra_args
+
[
"--threads"
,
nvcc_threads
]
return
nvcc_extra_args
return
nvcc_extra_args
...
...
csrc/xentropy/setup.py
View file @
50896ec5
...
@@ -53,7 +53,8 @@ def raise_if_cuda_home_none(global_option: str) -> None:
...
@@ -53,7 +53,8 @@ def raise_if_cuda_home_none(global_option: str) -> None:
def
append_nvcc_threads
(
nvcc_extra_args
):
def
append_nvcc_threads
(
nvcc_extra_args
):
_
,
bare_metal_version
=
get_cuda_bare_metal_version
(
CUDA_HOME
)
_
,
bare_metal_version
=
get_cuda_bare_metal_version
(
CUDA_HOME
)
if
bare_metal_version
>=
Version
(
"11.2"
):
if
bare_metal_version
>=
Version
(
"11.2"
):
return
nvcc_extra_args
+
[
"--threads"
,
"4"
]
nvcc_threads
=
os
.
getenv
(
"NVCC_THREADS"
)
or
"4"
return
nvcc_extra_args
+
[
"--threads"
,
nvcc_threads
]
return
nvcc_extra_args
return
nvcc_extra_args
...
...
setup.py
View file @
50896ec5
...
@@ -83,7 +83,8 @@ def check_if_cuda_home_none(global_option: str) -> None:
...
@@ -83,7 +83,8 @@ def check_if_cuda_home_none(global_option: str) -> None:
def
append_nvcc_threads
(
nvcc_extra_args
):
def
append_nvcc_threads
(
nvcc_extra_args
):
return
nvcc_extra_args
+
[
"--threads"
,
"4"
]
nvcc_threads
=
os
.
getenv
(
"NVCC_THREADS"
)
or
"4"
return
nvcc_extra_args
+
[
"--threads"
,
nvcc_threads
]
cmdclass
=
{}
cmdclass
=
{}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment