Unverified Commit 2b7d4a5c authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[DeviceMap] Make sure stable diffusion can be loaded from older trans… (#860)

[DeviceMap] Make sure stable diffusion can be loaded from older transformers versiosn
parent 93a81a3f
......@@ -26,6 +26,7 @@ import torch
import diffusers
import PIL
from huggingface_hub import snapshot_download
from packaging import version
from PIL import Image
from tqdm.auto import tqdm
......@@ -45,6 +46,7 @@ from .utils import (
if is_transformers_available():
import transformers
from transformers import PreTrainedModel
......@@ -505,11 +507,14 @@ class DiffusionPipeline(ConfigMixin):
loading_kwargs["provider"] = provider
loading_kwargs["sess_options"] = sess_options
if (
issubclass(class_obj, diffusers.ModelMixin)
or is_transformers_available()
is_diffusers_model = issubclass(class_obj, diffusers.ModelMixin)
is_transformers_model = (
is_transformers_available()
and issubclass(class_obj, PreTrainedModel)
):
and version.parse(version.parse(transformers.__version__).base_version) >= version.parse("4.20.0")
)
if is_diffusers_model or is_transformers_model:
loading_kwargs["device_map"] = device_map
# check if the module is in a subdirectory
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment