Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
56b83cf0
Unverified
Commit
56b83cf0
authored
Jun 22, 2022
by
Arthur
Committed by
GitHub
Jun 22, 2022
Browse files
initial commit (#17818)
parent
13570381
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
1 addition
and
106 deletions
+1
-106
src/transformers/modeling_utils.py
src/transformers/modeling_utils.py
+1
-106
No files found.
src/transformers/modeling_utils.py
View file @
56b83cf0
...
...
@@ -32,6 +32,7 @@ from torch import Tensor, device, nn
from
torch.nn
import
CrossEntropyLoss
from
requests
import
HTTPError
from
transformers.utils.hub
import
convert_file_size_to_int
,
get_checkpoint_shard_files
from
.activations
import
get_activation
from
.configuration_utils
import
PretrainedConfig
...
...
@@ -205,40 +206,6 @@ def get_state_dict_dtype(state_dict):
return
next
(
state_dict
.
values
()).
dtype
def
convert_file_size_to_int
(
size
:
Union
[
int
,
str
]):
"""
Converts a size expressed as a string with digits an unit (like `"5MB"`) to an integer (in bytes).
Args:
size (`int` or `str`): The size to convert. Will be directly returned if an `int`.
Example:
```py
>>> convert_file_size_to_int("1MiB")
1048576
```
"""
if
isinstance
(
size
,
int
):
return
size
if
size
.
upper
().
endswith
(
"GIB"
):
return
int
(
size
[:
-
3
])
*
(
2
**
30
)
if
size
.
upper
().
endswith
(
"MIB"
):
return
int
(
size
[:
-
3
])
*
(
2
**
20
)
if
size
.
upper
().
endswith
(
"KIB"
):
return
int
(
size
[:
-
3
])
*
(
2
**
10
)
if
size
.
upper
().
endswith
(
"GB"
):
int_size
=
int
(
size
[:
-
2
])
*
(
10
**
9
)
return
int_size
//
8
if
size
.
endswith
(
"b"
)
else
int_size
if
size
.
upper
().
endswith
(
"MB"
):
int_size
=
int
(
size
[:
-
2
])
*
(
10
**
6
)
return
int_size
//
8
if
size
.
endswith
(
"b"
)
else
int_size
if
size
.
upper
().
endswith
(
"KB"
):
int_size
=
int
(
size
[:
-
2
])
*
(
10
**
3
)
return
int_size
//
8
if
size
.
endswith
(
"b"
)
else
int_size
raise
ValueError
(
"`size` is not in a valid format. Use an integer followed by the unit, e.g., '5GB'."
)
def
dtype_byte_size
(
dtype
):
"""
Returns the size (in bytes) occupied by one parameter of type `dtype`.
...
...
@@ -324,78 +291,6 @@ def shard_checkpoint(state_dict: Dict[str, torch.Tensor], max_shard_size: Union[
return
shards
,
index
def
get_checkpoint_shard_files
(
pretrained_model_name_or_path
,
index_filename
,
cache_dir
=
None
,
force_download
=
False
,
proxies
=
None
,
resume_download
=
False
,
local_files_only
=
False
,
use_auth_token
=
None
,
user_agent
=
None
,
revision
=
None
,
mirror
=
None
,
):
"""
For a given model:
- download and cache all the shards of a sharded checkpoint if `pretrained_model_name_or_path` is a model ID on the
Hub
- returns the list of paths to all the shards, as well as some metadata.
For the description of each arg, see [`PreTrainedModel.from_pretrained`]. `index_filename` is the full path to the
index (downloaded and cached if `pretrained_model_name_or_path` is a model ID on the Hub).
"""
with
open
(
index_filename
,
"r"
)
as
f
:
index
=
json
.
loads
(
f
.
read
())
shard_filenames
=
sorted
(
list
(
set
(
index
[
"weight_map"
].
values
())))
sharded_metadata
=
index
[
"metadata"
]
sharded_metadata
[
"all_checkpoint_keys"
]
=
list
(
index
[
"weight_map"
].
keys
())
# First, let's deal with local folder.
if
os
.
path
.
isdir
(
pretrained_model_name_or_path
):
shard_filenames
=
[
os
.
path
.
join
(
pretrained_model_name_or_path
,
f
)
for
f
in
shard_filenames
]
return
shard_filenames
,
sharded_metadata
# At this stage pretrained_model_name_or_path is a model identifier on the Hub
cached_filenames
=
[]
for
shard_filename
in
shard_filenames
:
shard_url
=
hf_bucket_url
(
pretrained_model_name_or_path
,
filename
=
shard_filename
,
revision
=
revision
,
mirror
=
mirror
)
try
:
# Load from URL
cached_filename
=
cached_path
(
shard_url
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
proxies
=
proxies
,
resume_download
=
resume_download
,
local_files_only
=
local_files_only
,
use_auth_token
=
use_auth_token
,
user_agent
=
user_agent
,
)
# We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
# we don't have to catch them here.
except
EntryNotFoundError
:
raise
EnvironmentError
(
f
"
{
pretrained_model_name_or_path
}
does not appear to have a file named
{
shard_filename
}
which is "
"required according to the checkpoint index."
)
except
HTTPError
:
raise
EnvironmentError
(
f
"We couldn't connect to '
{
HUGGINGFACE_CO_RESOLVE_ENDPOINT
}
' to load
{
shard_filename
}
. You should try"
" again after checking your internet connection."
)
cached_filenames
.
append
(
cached_filename
)
return
cached_filenames
,
sharded_metadata
def
load_sharded_checkpoint
(
model
,
folder
,
strict
=
True
):
"""
This is the same as
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment