gaoqiong / flash-attention

Commit ef6d8c75, authored Aug 21, 2023 by Tri Dao
Parent: a8c35b4f

[GPT] Fix loading weights from HF hub
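In practice, the fix below lets state_dict_from_pretrained pull a GPT checkpoint straight from the HF hub when no local checkpoint directory exists. A minimal usage sketch, assuming network access and using "gpt2" purely as an illustrative model name:

```python
import torch

from flash_attn.utils.pretrained import state_dict_from_pretrained

# With the hub fallback added in this commit, the call succeeds even if "gpt2"
# has never been downloaded to a local directory: the weights are resolved on
# the HF hub and returned as a regular PyTorch state dict.
state_dict = state_dict_from_pretrained("gpt2", dtype=torch.float16)
print(f"loaded {len(state_dict)} tensors")
```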
Showing 2 changed files, with 9 additions and 2 deletions:
  flash_attn/utils/pretrained.py   +8 / -0
  tests/models/test_bert.py        +1 / -2
flash_attn/utils/pretrained.py  (view file @ ef6d8c75)

@@ -44,6 +44,14 @@ def state_dict_from_pretrained(model_name, device=None, dtype=None):
         )
         is_sharded = True
         load_safe = True
+    else:  # Try loading from HF hub instead of from local files
+        resolved_archive_file = cached_file(model_name, WEIGHTS_NAME,
+                                            _raise_exceptions_for_missing_entries=False)
+        if resolved_archive_file is None:
+            resolved_archive_file = cached_file(model_name, WEIGHTS_INDEX_NAME,
+                                                _raise_exceptions_for_missing_entries=False)
+            if resolved_archive_file is not None:
+                is_sharded = True
 
     if resolved_archive_file is None:
         raise EnvironmentError(f"Model name {model_name} was not found.")
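The new branch tries the hub in two steps: first the single-file checkpoint (WEIGHTS_NAME, i.e. pytorch_model.bin), then the sharded-checkpoint index (WEIGHTS_INDEX_NAME), marking the result as sharded if the index is what was found. The sketch below restates that fallback as a standalone helper built on transformers' cached_file; the helper name resolve_checkpoint is illustrative and not part of flash-attention's API.

```python
from transformers.utils import WEIGHTS_INDEX_NAME, WEIGHTS_NAME, cached_file


def resolve_checkpoint(model_name):
    """Resolve a checkpoint on the HF hub the way the patched code does."""
    is_sharded = False
    # Step 1: look for a single-file checkpoint (pytorch_model.bin).
    resolved = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
    if resolved is None:
        # Step 2: fall back to the sharded-checkpoint index (pytorch_model.bin.index.json).
        resolved = cached_file(
            model_name, WEIGHTS_INDEX_NAME, _raise_exceptions_for_missing_entries=False
        )
        if resolved is not None:
            is_sharded = True
    if resolved is None:
        # Mirrors the error raised in the patched function.
        raise EnvironmentError(f"Model name {model_name} was not found.")
    return resolved, is_sharded
```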
tests/models/test_bert.py  (view file @ ef6d8c75)

@@ -43,8 +43,7 @@ def get_hf_models(model_name, config, dtype):
     return model_hf
 
 
-@pytest.mark.parametrize("model_name", ["bert-base-uncased", "bert-large-uncased"])
-# @pytest.mark.parametrize('model_name', ["bert-base-uncased"])
+@pytest.mark.parametrize('model_name', ["bert-base-uncased"])
 def test_bert_non_optimized(model_name):
     """Check that our implementation of BERT (without any optimizations enabled) matches the
     HF implementation: the output of our forward pass in fp16 should be around the same as the HF
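The parametrization now covers only bert-base-uncased. One illustrative way to run just this test programmatically (equivalent to invoking pytest on the same node id from a shell):

```python
# Illustrative invocation; assumes it is run from the repository root so that
# tests/models/test_bert.py resolves, and that flash-attention's usual test
# requirements (e.g. a CUDA GPU) are met.
import pytest

pytest.main(["tests/models/test_bert.py::test_bert_non_optimized", "-v"])
```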