Commit e29fc446 authored by Nathan Lambert's avatar Nathan Lambert
Browse files
parents 7b4e049e 6e456b7a
......@@ -48,7 +48,7 @@ The class provides functionality to compute previous image according to alpha, b
**Diffusion Pipeline**: End-to-end pipeline that includes multiple diffusion models, possible text encoders, ...
*Examples*: GLIDE, Latent-Diffusion, Imagen, DALL-E 2
*Examples*: Glide, Latent-Diffusion, Imagen, DALL-E 2
<p align="center">
<img src="https://user-images.githubusercontent.com/10695622/174348898-481bd7c2-5457-4830-89bc-f0907756f64c.jpeg" width="550"/>
......@@ -190,7 +190,7 @@ image_pil.save("test.png")
[Diffuser](https://diffusion-planning.github.io/) for planning in reinforcement learning: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1TmBmlYeKUZSkUZoJqfBmaicVTKx6nN1R?usp=sharing)
### 2. `diffusers` as a collection of popular Diffusion systems (GLIDE, Dalle, ...)
### 2. `diffusers` as a collection of popular Diffusion systems (Glide, Dalle, ...)
For more examples see [pipelines](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines).
......@@ -249,24 +249,24 @@ image_pil = PIL.Image.fromarray(image_processed[0])
image_pil.save("test.png")
```
#### **Text to speech with GradTTS and BDDM**
#### **Text to speech with GradTTS and BDDMPipeline**
```python
import torch
from diffusers import BDDM, GradTTS
from diffusers import BDDMPipeline, GradTTSPipeline
torch_device = "cuda"
# load grad tts and bddm pipelines
grad_tts = GradTTS.from_pretrained("fusing/grad-tts-libri-tts")
bddm = BDDM.from_pretrained("fusing/diffwave-vocoder-ljspeech")
grad_tts = GradTTSPipeline.from_pretrained("fusing/grad-tts-libri-tts")
bddm = BDDMPipeline.from_pretrained("fusing/diffwave-vocoder-ljspeech")
text = "Hello world, I missed you so much."
# generate mel spectograms using text
mel_spec = grad_tts(text, torch_device=torch_device)
# generate the speech by passing mel spectograms to BDDM pipeline
# generate the speech by passing mel spectograms to BDDMPipeline pipeline
generator = torch.manual_seed(42)
audio = bddm(mel_spec, generator, torch_device=torch_device)
......@@ -278,13 +278,14 @@ wavwrite("generated_audio.wav", sampling_rate, audio.squeeze().cpu().numpy())
## TODO
- Create common API for models [ ]
- Add tests for models [ ]
- Adapt schedulers for training [ ]
- Write google colab for training [ ]
- Write docs / Think about how to structure docs [ ]
- Add tests to circle ci [ ]
- Add [Diffusion LM models](https://arxiv.org/pdf/2205.14217.pdf) [ ]
- Add more vision models [ ]
- Add more speech models [ ]
- Add RL model [ ]
- [ ] Create common API for models
- [ ] Add tests for models
- [ ] Adapt schedulers for training
- [ ] Write google colab for training
- [ ] Write docs / Think about how to structure docs
- [ ] Add tests to circle ci
- [ ] Add [Diffusion LM models](https://arxiv.org/pdf/2205.14217.pdf)
- [ ] Add more vision models
- [ ] Add more speech models
- [ ] Add RL model
- [ ] Add FID and KID metrics
......@@ -10,7 +10,7 @@ python -m torch.distributed.launch \
train_unconditional.py \
--dataset="huggan/flowers-102-categories" \
--resolution=64 \
--output_path="flowers-ddpm" \
--output_dir="flowers-ddpm" \
--batch_size=16 \
--num_epochs=100 \
--gradient_accumulation_steps=1 \
......@@ -34,7 +34,7 @@ python -m torch.distributed.launch \
train_unconditional.py \
--dataset="huggan/pokemon" \
--resolution=64 \
--output_path="pokemon-ddpm" \
--output_dir="pokemon-ddpm" \
--batch_size=16 \
--num_epochs=100 \
--gradient_accumulation_steps=1 \
......
import argparse
import os
import torch
import torch.nn.functional as F
import PIL.Image
from accelerate import Accelerator
from datasets import load_dataset
from diffusers import DDPM, DDPMScheduler, UNetLDMModel
from diffusers.hub_utils import init_git_repo, push_to_hub
from diffusers.modeling_utils import unwrap_model
from diffusers.optimization import get_scheduler
from diffusers.utils import logging
from torchvision.transforms import (
CenterCrop,
Compose,
InterpolationMode,
Lambda,
RandomHorizontalFlip,
Resize,
ToTensor,
)
from tqdm.auto import tqdm
logger = logging.get_logger(__name__)
def main(args):
accelerator = Accelerator(mixed_precision=args.mixed_precision)
model = UNetLDMModel(
attention_resolutions=[4, 2, 1],
channel_mult=[1, 2, 4, 4],
context_dim=1280,
conv_resample=True,
dims=2,
dropout=0,
image_size=32,
in_channels=4,
model_channels=320,
num_heads=8,
num_res_blocks=2,
out_channels=4,
resblock_updown=False,
transformer_depth=1,
use_new_attention_order=False,
use_scale_shift_norm=False,
use_spatial_transformer=True,
legacy=False,
)
noise_scheduler = DDPMScheduler(timesteps=1000, tensor_format="pt")
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
augmentations = Compose(
[
Resize(args.resolution, interpolation=InterpolationMode.BILINEAR),
CenterCrop(args.resolution),
RandomHorizontalFlip(),
ToTensor(),
Lambda(lambda x: x * 2 - 1),
]
)
dataset = load_dataset(args.dataset, split="train")
def transforms(examples):
images = [augmentations(image.convert("RGB")) for image in examples["image"]]
return {"input": images}
dataset.set_transform(transforms)
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
lr_scheduler = get_scheduler(
"linear",
optimizer=optimizer,
num_warmup_steps=args.warmup_steps,
num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps,
)
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
model, optimizer, train_dataloader, lr_scheduler
)
if args.push_to_hub:
repo = init_git_repo(args, at_init=True)
# Train!
is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized()
world_size = torch.distributed.get_world_size() if is_distributed else 1
total_train_batch_size = args.batch_size * args.gradient_accumulation_steps * world_size
max_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_epochs
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataloader.dataset)}")
logger.info(f" Num Epochs = {args.num_epochs}")
logger.info(f" Instantaneous batch size per device = {args.batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {max_steps}")
for epoch in range(args.num_epochs):
model.train()
with tqdm(total=len(train_dataloader), unit="ba") as pbar:
pbar.set_description(f"Epoch {epoch}")
for step, batch in enumerate(train_dataloader):
clean_images = batch["input"]
noise_samples = torch.randn(clean_images.shape).to(clean_images.device)
bsz = clean_images.shape[0]
timesteps = torch.randint(0, noise_scheduler.timesteps, (bsz,), device=clean_images.device).long()
# add noise onto the clean images according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_images = noise_scheduler.training_step(clean_images, noise_samples, timesteps)
if step % args.gradient_accumulation_steps != 0:
with accelerator.no_sync(model):
output = model(noisy_images, timesteps)
# predict the noise residual
loss = F.mse_loss(output, noise_samples)
loss = loss / args.gradient_accumulation_steps
accelerator.backward(loss)
else:
output = model(noisy_images, timesteps)
# predict the noise residual
loss = F.mse_loss(output, noise_samples)
loss = loss / args.gradient_accumulation_steps
accelerator.backward(loss)
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
pbar.update(1)
pbar.set_postfix(loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"])
optimizer.step()
if is_distributed:
torch.distributed.barrier()
# Generate a sample image for visual inspection
if args.local_rank in [-1, 0]:
model.eval()
with torch.no_grad():
pipeline = DDPM(unet=unwrap_model(model), noise_scheduler=noise_scheduler)
generator = torch.manual_seed(0)
# run pipeline in inference (sample random noise and denoise)
image = pipeline(generator=generator)
# process image to PIL
image_processed = image.cpu().permute(0, 2, 3, 1)
image_processed = (image_processed + 1.0) * 127.5
image_processed = image_processed.type(torch.uint8).numpy()
image_pil = PIL.Image.fromarray(image_processed[0])
# save image
test_dir = os.path.join(args.output_dir, "test_samples")
os.makedirs(test_dir, exist_ok=True)
image_pil.save(f"{test_dir}/{epoch:04d}.png")
# save the model
if args.push_to_hub:
push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False)
else:
pipeline.save_pretrained(args.output_dir)
if is_distributed:
torch.distributed.barrier()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument("--local_rank", type=int, default=-1)
parser.add_argument("--dataset", type=str, default="huggan/flowers-102-categories")
parser.add_argument("--output_dir", type=str, default="ddpm-model")
parser.add_argument("--overwrite_output_dir", action="store_true")
parser.add_argument("--resolution", type=int, default=64)
parser.add_argument("--batch_size", type=int, default=16)
parser.add_argument("--num_epochs", type=int, default=100)
parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
parser.add_argument("--lr", type=float, default=1e-4)
parser.add_argument("--warmup_steps", type=int, default=500)
parser.add_argument("--push_to_hub", action="store_true")
parser.add_argument("--hub_token", type=str, default=None)
parser.add_argument("--hub_model_id", type=str, default=None)
parser.add_argument("--hub_private_repo", action="store_true")
parser.add_argument(
"--mixed_precision",
type=str,
default="no",
choices=["no", "fp16", "bf16"],
help=(
"Whether to use mixed precision. Choose"
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
"and an Nvidia Ampere GPU."
),
)
args = parser.parse_args()
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != args.local_rank:
args.local_rank = env_local_rank
main(args)
......@@ -39,7 +39,7 @@ def main(args):
resamp_with_conv=True,
resolution=args.resolution,
)
noise_scheduler = DDPMScheduler(timesteps=1000)
noise_scheduler = DDPMScheduler(timesteps=1000, tensor_format="pt")
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
augmentations = Compose(
......@@ -93,15 +93,13 @@ def main(args):
pbar.set_description(f"Epoch {epoch}")
for step, batch in enumerate(train_dataloader):
clean_images = batch["input"]
noisy_images = torch.empty_like(clean_images)
noise_samples = torch.empty_like(clean_images)
noise_samples = torch.randn(clean_images.shape).to(clean_images.device)
bsz = clean_images.shape[0]
timesteps = torch.randint(0, noise_scheduler.timesteps, (bsz,), device=clean_images.device).long()
for idx in range(bsz):
noise = torch.randn(clean_images.shape[1:]).to(clean_images.device)
noise_samples[idx] = noise
noisy_images[idx] = noise_scheduler.forward_step(clean_images[idx], noise, timesteps[idx])
# add noise onto the clean images according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_images = noise_scheduler.training_step(clean_images, noise_samples, timesteps)
if step % args.gradient_accumulation_steps != 0:
with accelerator.no_sync(model):
......@@ -146,7 +144,7 @@ def main(args):
# save image
test_dir = os.path.join(args.output_dir, "test_samples")
os.makedirs(test_dir, exist_ok=True)
image_pil.save(f"{test_dir}/{epoch}.png")
image_pil.save(f"{test_dir}/{epoch:04d}.png")
# save the model
if args.push_to_hub:
......
import torch
from torch import nn
from diffusers import ClassifierFreeGuidanceScheduler, DDIMScheduler, GLIDESuperResUNetModel, GLIDETextToImageUNetModel
from diffusers.pipelines.pipeline_glide import GLIDE, CLIPTextModel
from diffusers import ClassifierFreeGuidanceScheduler, DDIMScheduler, GlideSuperResUNetModel, GlideTextToImageUNetModel
from diffusers.pipelines.pipeline_glide import Glide, CLIPTextModel
from transformers import CLIPTextConfig, GPT2Tokenizer
......@@ -55,7 +55,7 @@ for layer_idx in range(config.num_hidden_layers):
### Convert the Text-to-Image UNet
text2im_model = GLIDETextToImageUNetModel(
text2im_model = GlideTextToImageUNetModel(
in_channels=3,
model_channels=192,
out_channels=6,
......@@ -80,7 +80,7 @@ text_scheduler = ClassifierFreeGuidanceScheduler(timesteps=1000, beta_schedule="
# wget https://openaipublic.blob.core.windows.net/diffusion/dec-2021/upsample.pt
ups_state_dict = torch.load("upsample.pt", map_location="cpu")
superres_model = GLIDESuperResUNetModel(
superres_model = GlideSuperResUNetModel(
in_channels=6,
model_channels=192,
out_channels=6,
......@@ -101,7 +101,7 @@ upscale_scheduler = DDIMScheduler(
timesteps=1000, beta_schedule="linear", beta_start=0.0001, beta_end=0.02, tensor_format="pt"
)
glide = GLIDE(
glide = Glide(
text_unet=text2im_model,
text_noise_scheduler=text_scheduler,
text_encoder=model,
......
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
from .utils import is_transformers_available
from .utils import is_inflect_available, is_transformers_available, is_unidecode_available
__version__ = "0.0.4"
......@@ -11,13 +11,19 @@ from .models.unet import UNetModel
from .models.unet_ldm import UNetLDMModel
from .models.unet_rl import TemporalUNet
from .pipeline_utils import DiffusionPipeline
from .pipelines import BDDM, DDIM, DDPM, PNDM
from .pipelines import BDDMPipeline, DDIMPipeline, DDPMPipeline, PNDMPipeline
from .schedulers import DDIMScheduler, DDPMScheduler, GradTTSScheduler, PNDMScheduler, SchedulerMixin
if is_transformers_available():
from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel, GLIDEUNetModel
from .models.unet_glide import GlideSuperResUNetModel, GlideTextToImageUNetModel, GlideUNetModel
from .models.unet_grad_tts import UNetGradTTSModel
from .pipelines import GLIDE, GradTTS, LatentDiffusion
from .pipelines import GlidePipeline, LatentDiffusionPipeline
else:
from .utils.dummy_transformers_objects import *
if is_transformers_available() and is_inflect_available() and is_unidecode_available():
from .pipelines import GradTTSPipeline
else:
from .utils.dummy_transformers_and_inflect_and_unidecode_objects import *
......@@ -17,7 +17,7 @@
# limitations under the License.
from .unet import UNetModel
from .unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel, GLIDEUNetModel
from .unet_glide import GlideSuperResUNetModel, GlideTextToImageUNetModel, GlideUNetModel
from .unet_grad_tts import UNetGradTTSModel
from .unet_ldm import UNetLDMModel
from .unet_rl import TemporalUNet
......@@ -388,7 +388,7 @@ class QKVAttention(nn.Module):
return a.reshape(bs, -1, length)
class GLIDEUNetModel(ModelMixin, ConfigMixin):
class GlideUNetModel(ModelMixin, ConfigMixin):
"""
The full UNet model with attention and timestep embedding.
......@@ -641,7 +641,7 @@ class GLIDEUNetModel(ModelMixin, ConfigMixin):
return self.out(h)
class GLIDETextToImageUNetModel(GLIDEUNetModel):
class GlideTextToImageUNetModel(GlideUNetModel):
"""
A UNetModel that performs super-resolution.
......@@ -734,7 +734,7 @@ class GLIDETextToImageUNetModel(GLIDEUNetModel):
return self.out(h)
class GLIDESuperResUNetModel(GLIDEUNetModel):
class GlideSuperResUNetModel(GlideUNetModel):
"""
A UNetModel that performs super-resolution.
......
......@@ -21,7 +21,6 @@ from typing import Optional, Union
from huggingface_hub import snapshot_download
from .configuration_utils import ConfigMixin
from .dynamic_modules_utils import get_class_from_dynamic_module
from .utils import DIFFUSERS_CACHE, logging
......@@ -81,16 +80,13 @@ class DiffusionPipeline(ConfigMixin):
# set models
setattr(self, name, module)
register_dict = {"_module": self.__module__.split(".")[-1]}
self.register_to_config(**register_dict)
def save_pretrained(self, save_directory: Union[str, os.PathLike]):
self.save_config(save_directory)
model_index_dict = dict(self.config)
model_index_dict.pop("_class_name")
model_index_dict.pop("_diffusers_version")
model_index_dict.pop("_module")
model_index_dict.pop("_module", None)
for pipeline_component_name in model_index_dict.keys():
sub_model = getattr(self, pipeline_component_name)
......@@ -139,11 +135,7 @@ class DiffusionPipeline(ConfigMixin):
config_dict = cls.get_config_dict(cached_folder)
# 2. Get class name and module candidates to load custom models
module_candidate_name = config_dict["_module"]
module_candidate = module_candidate_name + ".py"
# 3. Load the pipeline class, if using custom module then load it from the hub
# 2. Load the pipeline class, if using custom module then load it from the hub
# if we load from explicit class, let's use it
if cls != DiffusionPipeline:
pipeline_class = cls
......@@ -151,11 +143,6 @@ class DiffusionPipeline(ConfigMixin):
diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
pipeline_class = getattr(diffusers_module, config_dict["_class_name"])
# (TODO - we should allow to load custom pipelines
# else we need to load the correct module from the Hub
# module = module_candidate
# pipeline_class = get_class_from_dynamic_module(cached_folder, module, class_name_, cached_folder)
init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
init_kwargs = {}
......@@ -163,7 +150,7 @@ class DiffusionPipeline(ConfigMixin):
# import it here to avoid circular import
from diffusers import pipelines
# 4. Load each module in the pipeline
# 3. Load each module in the pipeline
for name, (library_name, class_name) in init_dict.items():
is_pipeline_module = hasattr(pipelines, library_name)
# if the model is in a pipeline module, then we load it from the pipeline
......@@ -171,14 +158,7 @@ class DiffusionPipeline(ConfigMixin):
pipeline_module = getattr(pipelines, library_name)
class_obj = getattr(pipeline_module, class_name)
importable_classes = ALL_IMPORTABLE_CLASSES
class_candidates = {c: class_obj for c in ALL_IMPORTABLE_CLASSES.keys()}
elif library_name == module_candidate_name:
# if the model is not in diffusers or transformers, we need to load it from the hub
# assumes that it's a subclass of ModelMixin
class_obj = get_class_from_dynamic_module(cached_folder, module_candidate, class_name, cached_folder)
# since it's not from a library, we need to check class candidates for all importable classes
importable_classes = ALL_IMPORTABLE_CLASSES
class_candidates = {c: class_obj for c in ALL_IMPORTABLE_CLASSES.keys()}
class_candidates = {c: class_obj for c in importable_classes.keys()}
else:
# else we just import it from the library.
library = importlib.import_module(library_name)
......
......@@ -15,5 +15,5 @@ TODO(Patrick, Anton, Suraj)
- PNDM for unconditional image generation in [pipeline_pndm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py).
- Latent diffusion for text to image generation / conditional image generation in [pipeline_latent_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_latent_diffusion.py).
- Glide for text to image generation / conditional image generation in [pipeline_glide](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_glide.py).
- BDDM for spectrogram-to-sound vocoding in [pipeline_bddm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_bddm.py).
- BDDMPipeline for spectrogram-to-sound vocoding in [pipeline_bddm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_bddm.py).
- Grad-TTS for text to audio generation / conditional audio generation in [pipeline_grad_tts](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_grad_tts.py).
from ..utils import is_transformers_available
from .pipeline_bddm import BDDM
from .pipeline_ddim import DDIM
from .pipeline_ddpm import DDPM
from .pipeline_pndm import PNDM
from ..utils import is_inflect_available, is_transformers_available, is_unidecode_available
from .pipeline_bddm import BDDMPipeline
from .pipeline_ddim import DDIMPipeline
from .pipeline_ddpm import DDPMPipeline
from .pipeline_pndm import PNDMPipeline
if is_transformers_available():
from .pipeline_glide import GLIDE
from .pipeline_grad_tts import GradTTS
from .pipeline_latent_diffusion import LatentDiffusion
from .pipeline_glide import GlidePipeline
from .pipeline_latent_diffusion import LatentDiffusionPipeline
if is_transformers_available() and is_unidecode_available() and is_inflect_available():
from .pipeline_grad_tts import GradTTSPipeline
......@@ -6,20 +6,9 @@ from shutil import copyfile
import torch
import inflect
from transformers import PreTrainedTokenizer
try:
from unidecode import unidecode
except:
print("unidecode is not installed")
pass
try:
import inflect
except:
print("inflect is not installed")
pass
from unidecode import unidecode
valid_symbols = [
......@@ -234,12 +223,7 @@ def english_cleaners(text):
return text
try:
_inflect = inflect.engine()
except:
print("inflect is not installed")
_inflect = None
_inflect = inflect.engine()
_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
......
......@@ -271,7 +271,7 @@ class DiffWave(ModelMixin, ConfigMixin):
return self.final_conv(x)
class BDDM(DiffusionPipeline):
class BDDMPipeline(DiffusionPipeline):
def __init__(self, diffwave, noise_scheduler):
super().__init__()
noise_scheduler = noise_scheduler.set_format("pt")
......
......@@ -21,7 +21,7 @@ import tqdm
from ..pipeline_utils import DiffusionPipeline
class DDIM(DiffusionPipeline):
class DDIMPipeline(DiffusionPipeline):
def __init__(self, unet, noise_scheduler):
super().__init__()
noise_scheduler = noise_scheduler.set_format("pt")
......
......@@ -21,7 +21,7 @@ import tqdm
from ..pipeline_utils import DiffusionPipeline
class DDPM(DiffusionPipeline):
class DDPMPipeline(DiffusionPipeline):
def __init__(self, unet, noise_scheduler):
super().__init__()
noise_scheduler = noise_scheduler.set_format("pt")
......
......@@ -30,7 +30,7 @@ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPo
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import ModelOutput, add_start_docstrings_to_model_forward, replace_return_docstrings
from ..models import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
from ..models import GlideSuperResUNetModel, GlideTextToImageUNetModel
from ..pipeline_utils import DiffusionPipeline
from ..schedulers import DDIMScheduler, DDPMScheduler
from ..utils import logging
......@@ -694,7 +694,6 @@ class CLIPTextModel(CLIPPreTrainedModel):
# END OF THE CLIP MODEL COPY-PASTE
#####################
def _extract_into_tensor(arr, timesteps, broadcast_shape):
"""
Extract values from a 1-D numpy array for a batch of indices.
......@@ -711,14 +710,14 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
return res + torch.zeros(broadcast_shape, device=timesteps.device)
class GLIDE(DiffusionPipeline):
class GlidePipeline(DiffusionPipeline):
def __init__(
self,
text_unet: GLIDETextToImageUNetModel,
text_unet: GlideTextToImageUNetModel,
text_noise_scheduler: DDPMScheduler,
text_encoder: CLIPTextModel,
tokenizer: GPT2Tokenizer,
upscale_unet: GLIDESuperResUNetModel,
upscale_unet: GlideSuperResUNetModel,
upscale_noise_scheduler: DDIMScheduler,
):
super().__init__()
......
......@@ -420,7 +420,7 @@ class TextEncoder(ModelMixin, ConfigMixin):
return mu, logw, x_mask
class GradTTS(DiffusionPipeline):
class GradTTSPipeline(DiffusionPipeline):
def __init__(self, unet, text_encoder, noise_scheduler, tokenizer):
super().__init__()
noise_scheduler = noise_scheduler.set_format("pt")
......@@ -430,7 +430,14 @@ class GradTTS(DiffusionPipeline):
@torch.no_grad()
def __call__(
self, text, num_inference_steps=50, temperature=1.3, length_scale=0.91, speaker_id=15, torch_device=None
self,
text,
num_inference_steps=50,
temperature=1.3,
length_scale=0.91,
speaker_id=15,
torch_device=None,
generator=None,
):
if torch_device is None:
torch_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
......@@ -464,7 +471,7 @@ class GradTTS(DiffusionPipeline):
mu_y = mu_y.transpose(1, 2)
# Sample latent representation from terminal distribution N(mu_y, I)
z = mu_y + torch.randn_like(mu_y, device=mu_y.device) / temperature
z = mu_y + torch.randn(mu_y.shape, device=mu_y.device, generator=generator) / temperature
xt = z * y_mask
h = 1.0 / num_inference_steps
......
......@@ -860,7 +860,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
return dec, posterior
class LatentDiffusion(DiffusionPipeline):
class LatentDiffusionPipeline(DiffusionPipeline):
def __init__(self, vqvae, bert, tokenizer, unet, noise_scheduler):
super().__init__()
noise_scheduler = noise_scheduler.set_format("pt")
......
......@@ -21,7 +21,7 @@ import tqdm
from ..pipeline_utils import DiffusionPipeline
class PNDM(DiffusionPipeline):
class PNDMPipeline(DiffusionPipeline):
def __init__(self, unet, noise_scheduler):
super().__init__()
noise_scheduler = noise_scheduler.set_format("pt")
......
......@@ -73,7 +73,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
if beta_schedule == "linear":
self.betas = np.linspace(beta_start, beta_end, timesteps, dtype=np.float32)
elif beta_schedule == "squaredcos_cap_v2":
# GLIDE cosine schedule
# Glide cosine schedule
self.betas = betas_for_alpha_bar(timesteps)
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
......@@ -132,7 +132,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
std_dev_t = eta * variance ** (0.5)
if use_clipped_residual:
# the residual is always re-derived from the clipped x_0 in GLIDE
# the residual is always re-derived from the clipped x_0 in Glide
residual = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
# 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment