Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
56b83cf0
Unverified
Commit
56b83cf0
authored
Jun 22, 2022
by
Arthur
Committed by
GitHub
Jun 22, 2022
Browse files
initial commit (#17818)
parent
13570381
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
1 addition
and
106 deletions
+1
-106
src/transformers/modeling_utils.py
src/transformers/modeling_utils.py
+1
-106
No files found.
src/transformers/modeling_utils.py
View file @
56b83cf0
...
...
@@ -32,6 +32,7 @@ from torch import Tensor, device, nn
from
torch.nn
import
CrossEntropyLoss
from
requests
import
HTTPError
from
transformers.utils.hub
import
convert_file_size_to_int
,
get_checkpoint_shard_files
from
.activations
import
get_activation
from
.configuration_utils
import
PretrainedConfig
...
...
@@ -205,40 +206,6 @@ def get_state_dict_dtype(state_dict):
return
next
(
state_dict
.
values
()).
dtype
def convert_file_size_to_int(size: Union[int, str]):
    """
    Converts a size expressed as a string with digits and a unit (like `"5MB"`) to an integer (in bytes).

    Args:
        size (`int` or `str`): The size to convert. Will be directly returned if an `int`.

    Example:

    ```py
    >>> convert_file_size_to_int("1MiB")
    1048576
    ```
    """
    # Integers are already a byte count.
    if isinstance(size, int):
        return size

    normalized = size.upper()

    # Binary (IEC) units first, so "GiB" is not mistaken for "GB". These are
    # always byte counts — the bits-vs-bytes distinction below does not apply.
    for suffix, multiplier in (("GIB", 2**30), ("MIB", 2**20), ("KIB", 2**10)):
        if normalized.endswith(suffix):
            return int(size[:-3]) * multiplier

    # Decimal (SI) units. A trailing lowercase "b" denotes bits, hence the
    # division by 8 to get bytes.
    for suffix, multiplier in (("GB", 10**9), ("MB", 10**6), ("KB", 10**3)):
        if normalized.endswith(suffix):
            num_bytes = int(size[:-2]) * multiplier
            return num_bytes // 8 if size.endswith("b") else num_bytes

    raise ValueError("`size` is not in a valid format. Use an integer followed by the unit, e.g., '5GB'.")
def
dtype_byte_size
(
dtype
):
"""
Returns the size (in bytes) occupied by one parameter of type `dtype`.
...
...
@@ -324,78 +291,6 @@ def shard_checkpoint(state_dict: Dict[str, torch.Tensor], max_shard_size: Union[
return
shards
,
index
def get_checkpoint_shard_files(
    pretrained_model_name_or_path,
    index_filename,
    cache_dir=None,
    force_download=False,
    proxies=None,
    resume_download=False,
    local_files_only=False,
    use_auth_token=None,
    user_agent=None,
    revision=None,
    mirror=None,
):
    """
    For a given model:

    - download and cache all the shards of a sharded checkpoint if `pretrained_model_name_or_path` is a model ID on the
      Hub
    - returns the list of paths to all the shards, as well as some metadata.

    For the description of each arg, see [`PreTrainedModel.from_pretrained`]. `index_filename` is the full path to the
    index (downloaded and cached if `pretrained_model_name_or_path` is a model ID on the Hub).
    """
    # Parse the index file describing the sharded checkpoint.
    with open(index_filename, "r") as index_file:
        index = json.load(index_file)

    # Several weights can live in the same shard file; deduplicate then sort
    # for a deterministic ordering.
    shard_filenames = sorted(set(index["weight_map"].values()))
    sharded_metadata = index["metadata"]
    sharded_metadata["all_checkpoint_keys"] = list(index["weight_map"].keys())

    # Local folder: no downloading needed, just resolve the shard paths.
    if os.path.isdir(pretrained_model_name_or_path):
        return [os.path.join(pretrained_model_name_or_path, name) for name in shard_filenames], sharded_metadata

    # At this stage pretrained_model_name_or_path is a model identifier on the Hub:
    # fetch every shard, surfacing download problems as EnvironmentError.
    cached_filenames = []
    for shard_filename in shard_filenames:
        url = hf_bucket_url(pretrained_model_name_or_path, filename=shard_filename, revision=revision, mirror=mirror)
        try:
            # Load from URL
            resolved_path = cached_path(
                url,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                user_agent=user_agent,
            )
        # We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
        # we don't have to catch them here.
        except EntryNotFoundError:
            raise EnvironmentError(
                f"{pretrained_model_name_or_path} does not appear to have a file named {shard_filename} which is "
                "required according to the checkpoint index."
            )
        except HTTPError:
            raise EnvironmentError(
                f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {shard_filename}. You should try"
                " again after checking your internet connection."
            )
        cached_filenames.append(resolved_path)

    return cached_filenames, sharded_metadata
def
load_sharded_checkpoint
(
model
,
folder
,
strict
=
True
):
"""
This is the same as
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment