"git@developer.sourcefind.cn:gaoqiong/flash-attention.git" did not exist on "0c2fb252646f7bca12c645e4c79fb0515e0d99a5"
Commit 961c6977 authored by thomwolf's avatar thomwolf
Browse files

@julien-c proposal for TF/PT compat in hf_buckets

parent d311f87b
...@@ -303,11 +303,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): ...@@ -303,11 +303,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
elif os.path.isfile(pretrained_model_name_or_path + ".index"): elif os.path.isfile(pretrained_model_name_or_path + ".index"):
archive_file = pretrained_model_name_or_path + ".index" archive_file = pretrained_model_name_or_path + ".index"
else: else:
archive_file = hf_bucket_url(pretrained_model_name_or_path, postfix=TF2_WEIGHTS_NAME) archive_file = hf_bucket_url(pretrained_model_name_or_path, postfix=(WEIGHTS_NAME if from_pt else TF2_WEIGHTS_NAME))
if from_pt:
raise EnvironmentError(
"Loading a TF model from a PyTorch checkpoint is not supported when using a model identifier name."
)
# redirect to the cache, if necessary # redirect to the cache, if necessary
try: try:
......
...@@ -421,11 +421,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): ...@@ -421,11 +421,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
) )
archive_file = pretrained_model_name_or_path + ".index" archive_file = pretrained_model_name_or_path + ".index"
else: else:
archive_file = hf_bucket_url(pretrained_model_name_or_path, postfix=WEIGHTS_NAME) archive_file = hf_bucket_url(pretrained_model_name_or_path, postfix=(TF2_WEIGHTS_NAME if from_tf else WEIGHTS_NAME))
if from_tf:
raise EnvironmentError(
"Loading a PyTorch model from a TF checkpoint is not supported when using a model identifier name."
)
# redirect to the cache, if necessary # redirect to the cache, if necessary
try: try:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment