Commit 27039cd3 authored by Patrick von Platen
parents 8841d0d1 1f4d817c
@@ -3,12 +3,11 @@ from torch import nn
 from diffusers import (
     ClassifierFreeGuidanceScheduler,
-    CLIPTextModel,
     GlideDDIMScheduler,
     GLIDESuperResUNetModel,
     GLIDETextToImageUNetModel,
 )
-from modeling_glide import GLIDE
+from modeling_glide import GLIDE, CLIPTextModel
 from transformers import CLIPTextConfig, GPT2Tokenizer
@@ -13,9 +13,10 @@ model_id = "fusing/glide-base"
 pipeline = DiffusionPipeline.from_pretrained(model_id)
 # run inference (text-conditioned denoising + upscaling)
-img = pipeline("a clip art of a hugging face", generator)
+img = pipeline("a crayon drawing of a corgi", generator)
 # process image to PIL
+img = img.squeeze(0)
 img = ((img + 1) * 127.5).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
 image_pil = PIL.Image.fromarray(img)
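For context, the README snippet above (after this change) roughly expands to the following. The imports, the `generator` setup, and the final `save` call are assumptions filled in around the lines shown in the diff, not part of the commit:

```python
import torch
import PIL.Image

from diffusers import DiffusionPipeline

model_id = "fusing/glide-base"

# load model and scheduler
pipeline = DiffusionPipeline.from_pretrained(model_id)

# fixed seed so the sample is reproducible (assumed setup for `generator`)
generator = torch.manual_seed(0)

# run inference (text-conditioned denoising + upscaling)
img = pipeline("a crayon drawing of a corgi", generator)

# process image to PIL: drop the batch dim and map [-1, 1] floats to uint8
img = img.squeeze(0)
img = ((img + 1) * 127.5).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
image_pil = PIL.Image.fromarray(img)
image_pil.save("corgi.png")
```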
@@ -5,11 +5,9 @@
 __version__ = "0.0.1"
 from .modeling_utils import ModelMixin
-from .models.clip_text_transformer import CLIPTextModel
 from .models.unet import UNetModel
 from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
 from .models.unet_ldm import UNetLDMModel
 from .models.vqvae import VQModel
 from .pipeline_utils import DiffusionPipeline
 from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler
 from .schedulers.gaussian_ddpm import GaussianDDPMScheduler
@@ -16,8 +16,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from .clip_text_transformer import CLIPTextModel
 from .unet import UNetModel
 from .unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
 from .unet_ldm import UNetLDMModel
 from .vqvae import VQModel
@@ -34,13 +34,13 @@ logger = logging.get_logger(__name__)
 LOADABLE_CLASSES = {
     "diffusers": {
         "ModelMixin": ["save_pretrained", "from_pretrained"],
-        "CLIPTextModel": ["save_pretrained", "from_pretrained"],  # TODO (Anton): move to transformers
         "GaussianDDPMScheduler": ["save_config", "from_config"],
         "ClassifierFreeGuidanceScheduler": ["save_config", "from_config"],
         "GlideDDIMScheduler": ["save_config", "from_config"],
     },
     "transformers": {
         "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
+        "PreTrainedModel": ["save_pretrained", "from_pretrained"],
     },
 }
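Each entry in LOADABLE_CLASSES maps a base class to a `[save_method, load_method]` pair, so callers index the pair by position. A small illustrative check (not part of the commit; it assumes the dict is importable from `diffusers.pipeline_utils`, the module being edited here):

```python
from diffusers.pipeline_utils import LOADABLE_CLASSES

# first entry is the save method, second is the load method
save_name, load_name = LOADABLE_CLASSES["transformers"]["PreTrainedModel"]
assert (save_name, load_name) == ("save_pretrained", "from_pretrained")

# schedulers save/load plain configs instead of weights
save_name, load_name = LOADABLE_CLASSES["diffusers"]["GaussianDDPMScheduler"]
assert (save_name, load_name) == ("save_config", "from_config")
```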
@@ -83,24 +83,25 @@ class DiffusionPipeline(ConfigMixin):
         model_index_dict.pop("_diffusers_version")
         model_index_dict.pop("_module")
-        for name, (library_name, class_name) in model_index_dict.items():
-            importable_classes = LOADABLE_CLASSES[library_name]
-            # TODO: Suraj
-            if library_name == self.__module__:
-                library_name = self
-            library = importlib.import_module(library_name)
-            class_obj = getattr(library, class_name)
-            class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
+        for pipeline_component_name in model_index_dict.keys():
+            sub_model = getattr(self, pipeline_component_name)
+            model_cls = sub_model.__class__
             save_method_name = None
-            for class_name, class_candidate in class_candidates.items():
-                if issubclass(class_obj, class_candidate):
-                    save_method_name = importable_classes[class_name][0]
-            save_method = getattr(getattr(self, name), save_method_name)
-            save_method(os.path.join(save_directory, name))
+            # search for the model's base class in LOADABLE_CLASSES
+            for library_name, library_classes in LOADABLE_CLASSES.items():
+                library = importlib.import_module(library_name)
+                for base_class, save_load_methods in library_classes.items():
+                    class_candidate = getattr(library, base_class)
+                    if issubclass(model_cls, class_candidate):
+                        # if we found a suitable base class in LOADABLE_CLASSES then grab its save method
+                        save_method_name = save_load_methods[0]
+                        break
+                if save_method_name is not None:
+                    break
+            save_method = getattr(sub_model, save_method_name)
+            save_method(os.path.join(save_directory, pipeline_component_name))

     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
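The new save loop no longer keys the lookup on the (library, class) names stored in model_index.json; it walks LOADABLE_CLASSES and matches on the sub-model's actual base class. A standalone sketch of that lookup, with a hypothetical `find_save_method_name` helper (everything except the dict shape is an assumption, and it presumes each named base class exists in the installed library):

```python
import importlib


def find_save_method_name(sub_model, loadable_classes):
    """Hypothetical helper mirroring the loop above: return the save method
    registered for the first base class of `sub_model` found in
    `loadable_classes` (e.g. the LOADABLE_CLASSES dict), or None."""
    for library_name, library_classes in loadable_classes.items():
        library = importlib.import_module(library_name)
        for base_class, save_load_methods in library_classes.items():
            class_candidate = getattr(library, base_class)
            if issubclass(sub_model.__class__, class_candidate):
                # first entry of the [save, load] pair is the save method
                return save_load_methods[0]
    return None
```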
@@ -113,6 +114,7 @@ class DiffusionPipeline(ConfigMixin):
         local_files_only = kwargs.pop("local_files_only", False)
         use_auth_token = kwargs.pop("use_auth_token", None)
+        # 1. Download the checkpoints and configs
         # use snapshot download here to get it working from from_pretrained
         if not os.path.isdir(pretrained_model_name_or_path):
             cached_folder = snapshot_download(
@@ -128,11 +130,12 @@
         config_dict = cls.get_config_dict(cached_folder)
-        module = config_dict["_module"]
+        # 2. Get class name and module candidates to load custom models
         class_name_ = config_dict["_class_name"]
         module_candidate = config_dict["_module"]
         module_candidate_name = module_candidate.replace(".py", "")
+        # 3. Load the pipeline class, if using custom module then load it from the hub
         # if we load from explicit class, let's use it
         if cls != DiffusionPipeline:
             pipeline_class = cls
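Step 3 has two paths: when from_pretrained is called on a concrete subclass, that class is used directly; otherwise the class named in model_index.json has to be looked up. A hedged sketch of that decision (the `resolve_pipeline_class` helper and the diffusers-only fallback are assumptions; the hub custom-module path built from `module_candidate_name` is not shown in this excerpt and is omitted):

```python
import importlib

from diffusers import DiffusionPipeline


def resolve_pipeline_class(cls, class_name_):
    """Hypothetical helper: pick the concrete pipeline class to instantiate."""
    # if from_pretrained was called on an explicit subclass, use it as-is
    if cls != DiffusionPipeline:
        return cls
    # otherwise look the class up by the name stored in model_index.json;
    # assumption: the class is importable from the diffusers package
    diffusers_module = importlib.import_module("diffusers")
    return getattr(diffusers_module, class_name_)
```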
@@ -146,6 +149,7 @@
         init_kwargs = {}
+        # 4. Load each module in the pipeline
         for name, (library_name, class_name) in init_dict.items():
             # if the model is not in diffusers or transformers, we need to load it from the hub
             # assumes that it's a subclass of ModelMixin
@@ -155,6 +159,7 @@
                 importable_classes = ALL_IMPORTABLE_CLASSES
                 class_candidates = {c: class_obj for c in ALL_IMPORTABLE_CLASSES.keys()}
             else:
+                # else we just import it from the library.
                 library = importlib.import_module(library_name)
                 class_obj = getattr(library, class_name)
                 importable_classes = LOADABLE_CLASSES[library_name]
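Step 4 distinguishes components that ship with a known library (diffusers, transformers) from custom components downloaded alongside the checkpoint; for library components the class and its candidate base classes come straight from the imported module. A small sketch of the library branch only (the function name and the `class_candidates` line are assumptions inferred from the surrounding code, not shown in this hunk):

```python
import importlib


def import_library_component(library_name, class_name, loadable_classes):
    """Hypothetical helper for the `else` branch above: import the component
    class from its library and collect the base-class candidates used later
    to pick the save/load methods."""
    library = importlib.import_module(library_name)
    class_obj = getattr(library, class_name)
    importable_classes = loadable_classes[library_name]
    # assumed to mirror the if-branch: one candidate per registered base class
    class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
    return class_obj, importable_classes, class_candidates
```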
@@ -167,12 +172,15 @@
             load_method = getattr(class_obj, load_method_name)
+            # check if the module is in a subdirectory
             if os.path.isdir(os.path.join(cached_folder, name)):
                 loaded_sub_model = load_method(os.path.join(cached_folder, name))
             else:
+                # else load from the root directory
                 loaded_sub_model = load_method(cached_folder)
             init_kwargs[name] = loaded_sub_model  # UNet(...), # DiffusionSchedule(...)

+        # 5. Instantiate the pipeline
         model = pipeline_class(**init_kwargs)
         return model
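Taken together, the two methods form a round trip; a minimal usage sketch (the local directory name is arbitrary, and `fusing/glide-base` is the repo id used in the README hunk above):

```python
from diffusers import DiffusionPipeline

# download model_index.json plus every component and assemble the pipeline
pipeline = DiffusionPipeline.from_pretrained("fusing/glide-base")

# write each component back out with the save method found via LOADABLE_CLASSES,
# one subdirectory per component next to a fresh model_index.json
pipeline.save_pretrained("./glide-base-local")

# the local copy loads the same way as the hub repo
pipeline = DiffusionPipeline.from_pretrained("./glide-base-local")
```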