Commit ef6d8c75 authored by Tri Dao
Browse files

[GPT] Fix loading weights from HF hub

parent a8c35b4f
...@@ -44,6 +44,14 @@ def state_dict_from_pretrained(model_name, device=None, dtype=None):
)
is_sharded = True
load_safe = True
else: # Try loading from HF hub instead of from local files
resolved_archive_file = cached_file(model_name, WEIGHTS_NAME,
_raise_exceptions_for_missing_entries=False)
if resolved_archive_file is None:
resolved_archive_file = cached_file(model_name, WEIGHTS_INDEX_NAME,
_raise_exceptions_for_missing_entries=False)
if resolved_archive_file is not None:
is_sharded = True
if resolved_archive_file is None:
raise EnvironmentError(f"Model name {model_name} was not found.")
......
...@@ -43,8 +43,7 @@ def get_hf_models(model_name, config, dtype):
return model_hf
@pytest.mark.parametrize("model_name", ["bert-base-uncased", "bert-large-uncased"]) @pytest.mark.parametrize('model_name', ["bert-base-uncased"])
# @pytest.mark.parametrize('model_name', ["bert-base-uncased"])
def test_bert_non_optimized(model_name):
"""Check that our implementation of BERT (without any optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment