OpenDAS / text-generation-inference · Commit ece7ffa4

Unverified commit ece7ffa4, authored Jun 19, 2023 by OlivierDehaene, committed by GitHub on Jun 19, 2023.
feat(server): improve flash attention import errors (#465)
@lewtun, is this enough? Closes #458 Closes #456
Parent commit: f59fb8b6
Showing 3 changed files with 46 additions and 36 deletions (+46 −36):

server/text_generation_server/models/__init__.py   +40 −33
server/text_generation_server/utils/convert.py     +2 −2
server/text_generation_server/utils/hub.py         +4 −1
server/text_generation_server/models/__init__.py

@@ -18,11 +18,43 @@ from text_generation_server.models.santacoder import SantaCoder
 from text_generation_server.models.t5 import T5Sharded
 from text_generation_server.models.gpt_neox import GPTNeoxSharded

+# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
+# in PyTorch 1.12 and later.
+torch.backends.cuda.matmul.allow_tf32 = True
+
+# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
+torch.backends.cudnn.allow_tf32 = True
+
+# Disable gradients
+torch.set_grad_enabled(False)
+
+__all__ = [
+    "Model",
+    "BLOOMSharded",
+    "CausalLM",
+    "FlashCausalLM",
+    "GalacticaSharded",
+    "Seq2SeqLM",
+    "SantaCoder",
+    "OPTSharded",
+    "T5Sharded",
+    "get_model",
+]
+
+FLASH_ATT_ERROR_MESSAGE = (
+    "{} requires CUDA and Flash Attention kernels to be installed.\n"
+    "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
+    "or install flash attention with `cd server && make install install-flash-attention`"
+)
+
 try:
-    if (
-        torch.cuda.is_available()
-        and not os.getenv("USE_FLASH_ATTENTION", "").lower() == "false"
-    ):
+    if not os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
+        if not torch.cuda.is_available():
+            FLASH_ATT_ERROR_MESSAGE = (
+                "{} requires CUDA. No compatible CUDA devices found."
+            )
+            raise ImportError("CUDA is not available")
+
         major, minor = torch.cuda.get_device_capability()
         is_sm75 = major == 7 and minor == 5
         is_sm8x = major == 8 and minor >= 0

@@ -30,6 +62,10 @@ try:
         supported = is_sm75 or is_sm8x or is_sm90
         if not supported:
+            FLASH_ATT_ERROR_MESSAGE = (
+                "{} requires a CUDA device with capability 7.5, > 8.0 or 9.0. "
+                "No compatible CUDA device found."
+            )
             raise ImportError(
                 f"GPU with CUDA capability {major} {minor} is not supported"
             )

@@ -52,41 +88,12 @@ except ImportError:
     )
     FLASH_ATTENTION = False

-__all__ = [
-    "Model",
-    "BLOOMSharded",
-    "CausalLM",
-    "FlashCausalLM",
-    "GalacticaSharded",
-    "Seq2SeqLM",
-    "SantaCoder",
-    "OPTSharded",
-    "T5Sharded",
-    "get_model",
-]
-
 if FLASH_ATTENTION:
     __all__.append(FlashNeoXSharded)
     __all__.append(FlashRWSharded)
     __all__.append(FlashSantacoderSharded)
     __all__.append(FlashLlama)

-FLASH_ATT_ERROR_MESSAGE = (
-    "{} requires Flash Attention CUDA kernels to be installed.\n"
-    "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
-    "or install flash attention with `cd server && make install install-flash-attention`"
-)
-
-# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
-# in PyTorch 1.12 and later.
-torch.backends.cuda.matmul.allow_tf32 = True
-
-# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
-torch.backends.cudnn.allow_tf32 = True
-
-# Disable gradients
-torch.set_grad_enabled(False)
-
 def get_model(
     model_id: str,
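Note: the net effect of this first file is that the CUDA-availability check and the compute-capability check each record a specific reason in FLASH_ATT_ERROR_MESSAGE before raising ImportError, while the TF32 flags, the gradient toggle, and __all__ are simply moved above the try block. A minimal sketch of how such a message is typically surfaced to the user; the loader function below is an illustrative stand-in, not the actual get_model() code:

# Illustrative sketch only: the loader below is an assumed caller, not a
# verbatim excerpt from get_model().
FLASH_ATTENTION = False  # as set by the `except ImportError` branch above
FLASH_ATT_ERROR_MESSAGE = (
    "{} requires a CUDA device with capability 7.5, > 8.0 or 9.0. "
    "No compatible CUDA device found."
)

def load_flash_model(model_description: str):
    if not FLASH_ATTENTION:
        # The placeholder is filled with a readable model name, so users see
        # the precise reason recorded at import time instead of a bare ImportError.
        raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(model_description))

# load_flash_model("Sharded Llama") would then fail with:
# "Sharded Llama requires a CUDA device with capability 7.5, > 8.0 or 9.0. ..."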
server/text_generation_server/utils/convert.py

@@ -16,9 +16,9 @@ def check_file_size(source_file: Path, target_file: Path):
     source_file_size = source_file.stat().st_size
     target_file_size = target_file.stat().st_size

-    if (source_file_size - target_file_size) / source_file_size > 0.01:
+    if (source_file_size - target_file_size) / source_file_size > 0.05:
         raise RuntimeError(
-            f"""The file size different is more than 1%:
+            f"""The file size different is more than 5%:
         - {source_file}: {source_file_size}
         - {target_file}: {target_file_size}
         """
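Note: the only change in convert.py is the tolerance of the post-conversion sanity check: the converted safetensors file may now be up to 5% smaller than the source file (previously 1%) before check_file_size raises. A quick worked example of the comparison, with made-up sizes:

# Worked example of the relaxed threshold (file sizes are hypothetical).
source_file_size = 10_000_000_000  # e.g. pytorch_model.bin, 10 GB
target_file_size = 9_600_000_000   # e.g. model.safetensors, 9.6 GB

shrink = (source_file_size - target_file_size) / source_file_size
print(f"shrink = {shrink:.2%}")    # 4.00%

# Old check: shrink > 0.01 is True  -> RuntimeError ("more than 1%")
# New check: shrink > 0.05 is False -> the conversion is accepted
assert not shrink > 0.05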
server/text_generation_server/utils/hub.py

@@ -26,7 +26,10 @@ def weight_hub_files(
     filenames = [
         s.rfilename
         for s in info.siblings
-        if s.rfilename.endswith(extension) and len(s.rfilename.split("/")) == 1
+        if s.rfilename.endswith(extension)
+        and len(s.rfilename.split("/")) == 1
+        and "arguments" not in s.rfilename
+        and "args" not in s.rfilename
     ]

     if not filenames:
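Note: weight_hub_files now also drops Hub files whose names contain "arguments" or "args", so artifacts such as a training_args.bin file (a hypothetical example) are no longer mistaken for weight shards when listing files by extension. A standalone sketch of the same filter against huggingface_hub; the model id and extension are placeholders, not values from this commit:

# Standalone sketch assuming huggingface_hub's HfApi.model_info(); the model id
# and extension below are placeholders, not taken from the commit.
from huggingface_hub import HfApi

def list_weight_files(model_id: str, extension: str = ".bin") -> list:
    info = HfApi().model_info(model_id)
    return [
        s.rfilename
        for s in info.siblings
        if s.rfilename.endswith(extension)
        and len(s.rfilename.split("/")) == 1   # top-level files only, no subfolders
        and "arguments" not in s.rfilename     # skip training-argument dumps
        and "args" not in s.rfilename          # e.g. training_args.bin
    ]

# Example: list_weight_files("bigscience/bloom-560m") lists the *.bin shards but
# would skip a training_args.bin file if the repo contained one.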