Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
text-generation-inference
Commits
b4aa87db
Unverified
Commit
b4aa87db
authored
May 05, 2023
by
Nicolas Patry
Committed by
GitHub
May 05, 2023
Browse files
fea(server): decrease convert RAM requirements (#286)
parent
3314a46d
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
15 additions
and
14 deletions
+15
-14
server/text_generation_server/utils/convert.py
server/text_generation_server/utils/convert.py
+15
-14
No files found.
server/text_generation_server/utils/convert.py
View file @
b4aa87db
...
...
@@ -9,6 +9,7 @@ from datetime import timedelta
from
loguru
import
logger
from
pathlib
import
Path
from
safetensors.torch
import
load_file
,
save_file
from
safetensors
import
safe_open
from
typing
import
Dict
,
List
...
...
@@ -46,11 +47,11 @@ def remove_shared_pointers(tensors: Dict[str, torch.Tensor]):
tensors
.
pop
(
name
)
def
convert_file
(
pt_file
:
Path
,
s
t
_file
:
Path
):
def
convert_file
(
pt_file
:
Path
,
s
f
_file
:
Path
):
"""
Convert a pytorch file to a safetensors file
"""
logger
.
info
(
f
"Convert
{
pt_file
}
to
{
s
t
_file
}
."
)
logger
.
info
(
f
"Convert
{
pt_file
}
to
{
s
f
_file
}
."
)
pt_state
=
torch
.
load
(
pt_file
,
map_location
=
"cpu"
)
if
"state_dict"
in
pt_state
:
...
...
@@ -61,28 +62,28 @@ def convert_file(pt_file: Path, st_file: Path):
# Tensors need to be contiguous
pt_state
=
{
k
:
v
.
contiguous
()
for
k
,
v
in
pt_state
.
items
()}
s
t
_file
.
parent
.
mkdir
(
parents
=
True
,
exist_ok
=
True
)
save_file
(
pt_state
,
str
(
s
t
_file
),
metadata
=
{
"format"
:
"pt"
})
s
f
_file
.
parent
.
mkdir
(
parents
=
True
,
exist_ok
=
True
)
save_file
(
pt_state
,
str
(
s
f
_file
),
metadata
=
{
"format"
:
"pt"
})
# Check that both files are close in size
check_file_size
(
pt_file
,
s
t
_file
)
check_file_size
(
pt_file
,
s
f
_file
)
# Load safetensors state
st_state
=
load_file
(
str
(
st_file
))
for
k
in
st_state
:
for
k
in
pt_state
:
pt_tensor
=
pt_state
[
k
]
st_tensor
=
st_state
[
k
]
if
not
torch
.
equal
(
pt_tensor
,
st_tensor
):
raise
RuntimeError
(
f
"The output tensors do not match for key
{
k
}
"
)
with
safe_open
(
sf_file
,
framework
=
"pt"
)
as
f
:
sf_tensor
=
f
.
get_tensor
(
k
)
if
not
torch
.
equal
(
pt_tensor
,
sf_tensor
):
raise
RuntimeError
(
f
"The output tensors do not match for key
{
k
}
"
)
def
convert_files
(
pt_files
:
List
[
Path
],
s
t
_files
:
List
[
Path
]):
assert
len
(
pt_files
)
==
len
(
s
t
_files
)
def
convert_files
(
pt_files
:
List
[
Path
],
s
f
_files
:
List
[
Path
]):
assert
len
(
pt_files
)
==
len
(
s
f
_files
)
N
=
len
(
pt_files
)
# We do this instead of using tqdm because we want to parse the logs with the launcher
for
i
,
(
pt_file
,
sf_file
)
in
enumerate
(
zip
(
pt_files
,
s
t
_files
)):
for
i
,
(
pt_file
,
sf_file
)
in
enumerate
(
zip
(
pt_files
,
s
f
_files
)):
start
=
datetime
.
datetime
.
now
()
convert_file
(
pt_file
,
sf_file
)
elapsed
=
datetime
.
datetime
.
now
()
-
start
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment