Commit fb9e37ad authored by Patrick von Platen's avatar Patrick von Platen
Browse files

correct logging

parent 273f9fee
...@@ -9,10 +9,10 @@ from accelerate import Accelerator ...@@ -9,10 +9,10 @@ from accelerate import Accelerator
from datasets import load_dataset from datasets import load_dataset
from diffusers import DDPM, DDPMScheduler, UNetModel from diffusers import DDPM, DDPMScheduler, UNetModel
from torchvision.transforms import ( from torchvision.transforms import (
CenterCrop,
Compose, Compose,
InterpolationMode, InterpolationMode,
Lambda, Lambda,
CenterCrop,
RandomHorizontalFlip, RandomHorizontalFlip,
Resize, Resize,
ToTensor, ToTensor,
......
from .pipeline_bddm import BDDM from .pipeline_bddm import BDDM
from .pipeline_ddim import DDIM from .pipeline_ddim import DDIM
from .pipeline_ddpm import DDPM from .pipeline_ddpm import DDPM
try: try:
from .pipeline_glide import GLIDE from .pipeline_glide import GLIDE
except ImportError: except ImportError:
class GLIDE: class GLIDE:
pass pass
from .pipeline_latent_diffusion import LatentDiffusion from .pipeline_latent_diffusion import LatentDiffusion
from .pipeline_pndm import PNDM from .pipeline_pndm import PNDM
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
""" PyTorch CLIP model.""" """ PyTorch CLIP model."""
import math import math
import logging
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Optional, Tuple, Union from typing import Any, Optional, Tuple, Union
...@@ -25,12 +24,19 @@ import torch.utils.checkpoint ...@@ -25,12 +24,19 @@ import torch.utils.checkpoint
from torch import nn from torch import nn
import tqdm import tqdm
try: try:
from transformers import CLIPConfig, CLIPModel, CLIPTextConfig, CLIPVisionConfig, GPT2Tokenizer from transformers import CLIPConfig, CLIPModel, CLIPTextConfig, CLIPVisionConfig, GPT2Tokenizer
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from transformers.modeling_utils import PreTrainedModel from transformers.modeling_utils import PreTrainedModel
from transformers.utils import ModelOutput, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from transformers.utils import (
ModelOutput,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
except: except:
print("Transformers is not installed") print("Transformers is not installed")
pass pass
...@@ -38,6 +44,7 @@ except: ...@@ -38,6 +44,7 @@ except:
from ..models import GLIDESuperResUNetModel, GLIDETextToImageUNetModel from ..models import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
from ..schedulers import ClassifierFreeGuidanceScheduler, DDIMScheduler from ..schedulers import ClassifierFreeGuidanceScheduler, DDIMScheduler
from ..utils import logging
##################### #####################
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment