Unverified Commit a3faf3f2 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[Core] fix: legacy model mapping (#8416)

* fix: legacy model mapping

* remove print
parent 867a2b0c
...@@ -1057,6 +1057,9 @@ class LegacyModelMixin(ModelMixin): ...@@ -1057,6 +1057,9 @@ class LegacyModelMixin(ModelMixin):
# To prevent depedency import problem. # To prevent depedency import problem.
from .model_loading_utils import _fetch_remapped_cls_from_config from .model_loading_utils import _fetch_remapped_cls_from_config
# Create a copy of the kwargs so that we don't mess with the keyword arguments in the downstream calls.
kwargs_copy = kwargs.copy()
cache_dir = kwargs.pop("cache_dir", None) cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False) force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", None) resume_download = kwargs.pop("resume_download", None)
...@@ -1094,4 +1097,4 @@ class LegacyModelMixin(ModelMixin): ...@@ -1094,4 +1097,4 @@ class LegacyModelMixin(ModelMixin):
# resolve remapping # resolve remapping
remapped_class = _fetch_remapped_cls_from_config(config, cls) remapped_class = _fetch_remapped_cls_from_config(config, cls)
return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs) return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs_copy)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment