Unverified Commit a216b0bb authored by Luo Chaofan's avatar Luo Chaofan Committed by GitHub
Browse files

fix: ValueError when using FromOriginalModelMixin in subclasses #8440 (#8454)



* fix: ValueError when using FromOriginalModelMixin in subclasses #8440

(cherry picked from commit 92859978436acf844760fc0e992165b489d0180a)

* Update src/diffusers/loaders/single_file_model.py
Co-authored-by: default avatarDhruv Nair <dhruv.nair@gmail.com>

* Update single_file_model.py

* Update single_file_model.py

---------
Co-authored-by: default avatarDhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent 150142c5
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import importlib
import inspect import inspect
import re import re
from contextlib import nullcontext from contextlib import nullcontext
...@@ -72,6 +73,17 @@ SINGLE_FILE_LOADABLE_CLASSES = { ...@@ -72,6 +73,17 @@ SINGLE_FILE_LOADABLE_CLASSES = {
} }
def _get_single_file_loadable_mapping_class(cls):
diffusers_module = importlib.import_module(__name__.split(".")[0])
for loadable_class_str in SINGLE_FILE_LOADABLE_CLASSES:
loadable_class = getattr(diffusers_module, loadable_class_str)
if issubclass(cls, loadable_class):
return loadable_class_str
return None
def _get_mapping_function_kwargs(mapping_fn, **kwargs): def _get_mapping_function_kwargs(mapping_fn, **kwargs):
parameters = inspect.signature(mapping_fn).parameters parameters = inspect.signature(mapping_fn).parameters
...@@ -149,8 +161,9 @@ class FromOriginalModelMixin: ...@@ -149,8 +161,9 @@ class FromOriginalModelMixin:
``` ```
""" """
class_name = cls.__name__ mapping_class_name = _get_single_file_loadable_mapping_class(cls)
if class_name not in SINGLE_FILE_LOADABLE_CLASSES: # if class_name not in SINGLE_FILE_LOADABLE_CLASSES:
if mapping_class_name is None:
raise ValueError( raise ValueError(
f"FromOriginalModelMixin is currently only compatible with {', '.join(SINGLE_FILE_LOADABLE_CLASSES.keys())}" f"FromOriginalModelMixin is currently only compatible with {', '.join(SINGLE_FILE_LOADABLE_CLASSES.keys())}"
) )
...@@ -195,7 +208,7 @@ class FromOriginalModelMixin: ...@@ -195,7 +208,7 @@ class FromOriginalModelMixin:
revision=revision, revision=revision,
) )
mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[class_name] mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[mapping_class_name]
checkpoint_mapping_fn = mapping_functions["checkpoint_mapping_fn"] checkpoint_mapping_fn = mapping_functions["checkpoint_mapping_fn"]
if original_config: if original_config:
...@@ -207,7 +220,7 @@ class FromOriginalModelMixin: ...@@ -207,7 +220,7 @@ class FromOriginalModelMixin:
if config_mapping_fn is None: if config_mapping_fn is None:
raise ValueError( raise ValueError(
( (
f"`original_config` has been provided for {class_name} but no mapping function" f"`original_config` has been provided for {mapping_class_name} but no mapping function"
"was found to convert the original config to a Diffusers config in" "was found to convert the original config to a Diffusers config in"
"`diffusers.loaders.single_file_utils`" "`diffusers.loaders.single_file_utils`"
) )
...@@ -267,7 +280,7 @@ class FromOriginalModelMixin: ...@@ -267,7 +280,7 @@ class FromOriginalModelMixin:
) )
if not diffusers_format_checkpoint: if not diffusers_format_checkpoint:
raise SingleFileComponentError( raise SingleFileComponentError(
f"Failed to load {class_name}. Weights for this component appear to be missing in the checkpoint." f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
) )
ctx = init_empty_weights if is_accelerate_available() else nullcontext ctx = init_empty_weights if is_accelerate_available() else nullcontext
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment