Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Lmdeploy
Commits
b7e7e668
Unverified
Commit
b7e7e668
authored
Aug 07, 2023
by
LZHgrla
Committed by
GitHub
Aug 07, 2023
Browse files
[Feature] Add script to split HuggingFace model to the smallest sharded checkpoints (#199)
* add get_small_sharded_hf.py * fix pre-commit
parent
0ed1e4d4
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
63 additions
and
0 deletions
+63
-0
lmdeploy/lite/apis/get_small_sharded_hf.py
lmdeploy/lite/apis/get_small_sharded_hf.py
+63
-0
No files found.
lmdeploy/lite/apis/get_small_sharded_hf.py
0 → 100644
View file @
b7e7e668
# Copyright (c) OpenMMLab. All rights reserved.
import
argparse
import
copy
import
json
import
os
import
shutil
import
torch
from
mmengine.utils
import
mkdir_or_exist
def parse_args(argv=None):
    """Parse the command-line arguments of the shard-splitting script.

    Args:
        argv (list[str] | None): Argument strings to parse. Defaults to
            ``None``, in which case ``argparse`` reads ``sys.argv[1:]``,
            which is what the zero-argument call has always done.

    Returns:
        argparse.Namespace: Namespace with two attributes, ``src_dir`` (the
            directory of the model) and ``dst_dir`` (the directory to save
            the new model).
    """
    parser = argparse.ArgumentParser(
        description='Convert a hugging face model to the smallest sharded one'
    )
    parser.add_argument('src_dir', help='the directory of the model')
    parser.add_argument('dst_dir', help='the directory to save the new model')
    # Passing ``None`` makes argparse fall back to ``sys.argv[1:]``.
    args = parser.parse_args(argv)
    return args
def main():
    """Split a sharded HuggingFace checkpoint into one file per tensor.

    Reads ``pytorch_model.bin.index.json`` from ``src_dir``, saves every
    tensor of every checkpoint referenced by its ``weight_map`` into its own
    ``pytorch_model-XXXXX-of-XXXXX.bin`` file under ``dst_dir``, and writes a
    new index mapping each tensor name to its file. Every other entry of
    ``src_dir`` whose name does not start with ``pytorch_model`` or ``.`` is
    copied to ``dst_dir`` unchanged.

    Raises:
        AssertionError: If the tensor names found in the checkpoints differ
            from the names listed in the source ``weight_map``. The new
            index file is not written in that case.
    """
    args = parse_args()
    mkdir_or_exist(args.dst_dir)

    # Copy everything that is neither a weight file nor a hidden entry.
    for name in os.listdir(args.src_dir):
        if name.startswith(('pytorch_model', '.')):
            continue
        src_path = os.path.join(args.src_dir, name)
        dst_path = os.path.join(args.dst_dir, name)
        if os.path.isdir(src_path):
            # `shutil.copy` raises on directories; copy them recursively.
            shutil.copytree(src_path, dst_path, dirs_exist_ok=True)
        else:
            shutil.copy(src_path, dst_path)

    index_name = 'pytorch_model.bin.index.json'
    with open(os.path.join(args.src_dir, index_name), encoding='utf-8') as f:
        index = json.load(f)

    # One output file per tensor, so the shard total is the tensor count.
    n_shard = len(index['weight_map'])
    new_index = copy.deepcopy(index)
    new_index['weight_map'] = {}

    # Keep loading on the GPU when one exists, but fall back to the CPU so
    # the script also runs on machines without CUDA.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    cnt = 1
    # Sorted so that the shard numbering is reproducible across runs; plain
    # set iteration order depends on string hashing.
    checkpoints = sorted(set(index['weight_map'].values()))
    for ckpt in checkpoints:
        # NOTE: `torch.load` unpickles the file, so only run this script on
        # checkpoints from a trusted source.
        state_dict = torch.load(
            os.path.join(args.src_dir, ckpt), map_location=device)
        for k in sorted(state_dict):
            new_state_dict_name = 'pytorch_model-{:05d}-of-{:05d}.bin'.format(
                cnt, n_shard)
            new_index['weight_map'][k] = new_state_dict_name
            new_state_dict = {k: state_dict[k]}
            torch.save(new_state_dict,
                       os.path.join(args.dst_dir, new_state_dict_name))
            cnt += 1
        # Release the loaded checkpoint before reading the next one.
        del state_dict
        if device == 'cuda':
            torch.cuda.empty_cache()

    # Validate before writing the index so an inconsistent index never lands
    # on disk. Raised explicitly (not via `assert`) so the check survives
    # `python -O`.
    if new_index['weight_map'].keys() != index['weight_map'].keys():
        raise AssertionError('Mismatch on `weight_map`!')

    with open(os.path.join(args.dst_dir, index_name), 'w',
              encoding='utf-8') as f:
        json.dump(new_index, f)
# Run the conversion only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == '__main__':
    main()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment