Commit e29fc446 authored by Nathan Lambert
parents 7b4e049e 6e456b7a
@@ -48,7 +48,7 @@ The class provides functionality to compute previous image according to alpha, b
 **Diffusion Pipeline**: End-to-end pipeline that includes multiple diffusion models, possible text encoders, ...
-*Examples*: GLIDE, Latent-Diffusion, Imagen, DALL-E 2
+*Examples*: Glide, Latent-Diffusion, Imagen, DALL-E 2
 <p align="center">
     <img src="https://user-images.githubusercontent.com/10695622/174348898-481bd7c2-5457-4830-89bc-f0907756f64c.jpeg" width="550"/>
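To make the glossary entry above concrete, here is a minimal sketch of an end-to-end pipeline call using the `DDPMPipeline` name this commit introduces; the checkpoint id is a placeholder and the call signature is assumed from the usage examples later in this README:

```python
import torch

from diffusers import DDPMPipeline

# load model weights and scheduler configuration in one object
# (placeholder checkpoint id; substitute any DDPM checkpoint on the Hub)
ddpm = DDPMPipeline.from_pretrained("fusing/ddpm-lsun-bedroom")

# sample from pure noise; seeding the generator makes the output reproducible
generator = torch.manual_seed(42)
image = ddpm(generator=generator)
```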
@@ -190,7 +190,7 @@ image_pil.save("test.png")
 [Diffuser](https://diffusion-planning.github.io/) for planning in reinforcement learning: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1TmBmlYeKUZSkUZoJqfBmaicVTKx6nN1R?usp=sharing)
-### 2. `diffusers` as a collection of popular Diffusion systems (GLIDE, Dalle, ...)
+### 2. `diffusers` as a collection of popular Diffusion systems (Glide, Dalle, ...)
 For more examples see [pipelines](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines).
@@ -249,24 +249,24 @@ image_pil = PIL.Image.fromarray(image_processed[0])
 image_pil.save("test.png")
 ```
-#### **Text to speech with GradTTS and BDDM**
+#### **Text to speech with GradTTS and BDDMPipeline**
 ```python
 import torch
-from diffusers import BDDM, GradTTS
+from diffusers import BDDMPipeline, GradTTSPipeline
 torch_device = "cuda"
 # load grad tts and bddm pipelines
-grad_tts = GradTTS.from_pretrained("fusing/grad-tts-libri-tts")
-bddm = BDDM.from_pretrained("fusing/diffwave-vocoder-ljspeech")
+grad_tts = GradTTSPipeline.from_pretrained("fusing/grad-tts-libri-tts")
+bddm = BDDMPipeline.from_pretrained("fusing/diffwave-vocoder-ljspeech")
 text = "Hello world, I missed you so much."
 # generate mel spectograms using text
 mel_spec = grad_tts(text, torch_device=torch_device)
-# generate the speech by passing mel spectograms to BDDM pipeline
+# generate the speech by passing mel spectograms to BDDMPipeline pipeline
 generator = torch.manual_seed(42)
 audio = bddm(mel_spec, generator, torch_device=torch_device)
@@ -278,13 +278,14 @@ wavwrite("generated_audio.wav", sampling_rate, audio.squeeze().cpu().numpy())
 ## TODO
-- Create common API for models [ ]
-- Add tests for models [ ]
-- Adapt schedulers for training [ ]
-- Write google colab for training [ ]
-- Write docs / Think about how to structure docs [ ]
-- Add tests to circle ci [ ]
-- Add [Diffusion LM models](https://arxiv.org/pdf/2205.14217.pdf) [ ]
-- Add more vision models [ ]
-- Add more speech models [ ]
-- Add RL model [ ]
+- [ ] Create common API for models
+- [ ] Add tests for models
+- [ ] Adapt schedulers for training
+- [ ] Write google colab for training
+- [ ] Write docs / Think about how to structure docs
+- [ ] Add tests to circle ci
+- [ ] Add [Diffusion LM models](https://arxiv.org/pdf/2205.14217.pdf)
+- [ ] Add more vision models
+- [ ] Add more speech models
+- [ ] Add RL model
+- [ ] Add FID and KID metrics
@@ -10,7 +10,7 @@ python -m torch.distributed.launch \
   train_unconditional.py \
   --dataset="huggan/flowers-102-categories" \
   --resolution=64 \
-  --output_path="flowers-ddpm" \
+  --output_dir="flowers-ddpm" \
   --batch_size=16 \
   --num_epochs=100 \
   --gradient_accumulation_steps=1 \
@@ -34,7 +34,7 @@ python -m torch.distributed.launch \
   train_unconditional.py \
   --dataset="huggan/pokemon" \
   --resolution=64 \
-  --output_path="pokemon-ddpm" \
+  --output_dir="pokemon-ddpm" \
   --batch_size=16 \
   --num_epochs=100 \
   --gradient_accumulation_steps=1 \
...
import argparse
import os

import torch
import torch.nn.functional as F

import PIL.Image
from accelerate import Accelerator
from datasets import load_dataset
from diffusers import DDPM, DDPMScheduler, UNetLDMModel
from diffusers.hub_utils import init_git_repo, push_to_hub
from diffusers.modeling_utils import unwrap_model
from diffusers.optimization import get_scheduler
from diffusers.utils import logging
from torchvision.transforms import (
    CenterCrop,
    Compose,
    InterpolationMode,
    Lambda,
    RandomHorizontalFlip,
    Resize,
    ToTensor,
)
from tqdm.auto import tqdm


logger = logging.get_logger(__name__)
def main(args):
    accelerator = Accelerator(mixed_precision=args.mixed_precision)

    model = UNetLDMModel(
        attention_resolutions=[4, 2, 1],
        channel_mult=[1, 2, 4, 4],
        context_dim=1280,
        conv_resample=True,
        dims=2,
        dropout=0,
        image_size=32,
        in_channels=4,
        model_channels=320,
        num_heads=8,
        num_res_blocks=2,
        out_channels=4,
        resblock_updown=False,
        transformer_depth=1,
        use_new_attention_order=False,
        use_scale_shift_norm=False,
        use_spatial_transformer=True,
        legacy=False,
    )
    noise_scheduler = DDPMScheduler(timesteps=1000, tensor_format="pt")
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    augmentations = Compose(
        [
            Resize(args.resolution, interpolation=InterpolationMode.BILINEAR),
            CenterCrop(args.resolution),
            RandomHorizontalFlip(),
            ToTensor(),
            Lambda(lambda x: x * 2 - 1),
        ]
    )
    dataset = load_dataset(args.dataset, split="train")

    def transforms(examples):
        images = [augmentations(image.convert("RGB")) for image in examples["image"]]
        return {"input": images}

    dataset.set_transform(transforms)
    train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True)

    lr_scheduler = get_scheduler(
        "linear",
        optimizer=optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps,
    )

    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )

    if args.push_to_hub:
        repo = init_git_repo(args, at_init=True)

    # Train!
    is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized()
    world_size = torch.distributed.get_world_size() if is_distributed else 1
    total_train_batch_size = args.batch_size * args.gradient_accumulation_steps * world_size
    max_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_epochs
    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataloader.dataset)}")
    logger.info(f" Num Epochs = {args.num_epochs}")
    logger.info(f" Instantaneous batch size per device = {args.batch_size}")
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
    logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {max_steps}")
    for epoch in range(args.num_epochs):
        model.train()
        with tqdm(total=len(train_dataloader), unit="ba") as pbar:
            pbar.set_description(f"Epoch {epoch}")
            for step, batch in enumerate(train_dataloader):
                clean_images = batch["input"]
                noise_samples = torch.randn(clean_images.shape).to(clean_images.device)
                bsz = clean_images.shape[0]
                timesteps = torch.randint(0, noise_scheduler.timesteps, (bsz,), device=clean_images.device).long()

                # add noise onto the clean images according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_images = noise_scheduler.training_step(clean_images, noise_samples, timesteps)

                if step % args.gradient_accumulation_steps != 0:
                    # accumulate gradients locally without syncing across processes
                    with accelerator.no_sync(model):
                        output = model(noisy_images, timesteps)
                        # predict the noise residual
                        loss = F.mse_loss(output, noise_samples)
                        loss = loss / args.gradient_accumulation_steps
                        accelerator.backward(loss)
                else:
                    output = model(noisy_images, timesteps)
                    # predict the noise residual
                    loss = F.mse_loss(output, noise_samples)
                    loss = loss / args.gradient_accumulation_steps
                    accelerator.backward(loss)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad()
                pbar.update(1)
                pbar.set_postfix(loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"])
            # apply any gradients still accumulated when the epoch ends
            optimizer.step()

        if is_distributed:
            torch.distributed.barrier()

        # Generate a sample image for visual inspection
        if args.local_rank in [-1, 0]:
            model.eval()
            with torch.no_grad():
                pipeline = DDPM(unet=unwrap_model(model), noise_scheduler=noise_scheduler)

                generator = torch.manual_seed(0)
                # run pipeline in inference (sample random noise and denoise)
                image = pipeline(generator=generator)

                # process image to PIL
                image_processed = image.cpu().permute(0, 2, 3, 1)
                image_processed = (image_processed + 1.0) * 127.5
                image_processed = image_processed.type(torch.uint8).numpy()
                image_pil = PIL.Image.fromarray(image_processed[0])

                # save image
                test_dir = os.path.join(args.output_dir, "test_samples")
                os.makedirs(test_dir, exist_ok=True)
                image_pil.save(f"{test_dir}/{epoch:04d}.png")

                # save the model
                if args.push_to_hub:
                    push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False)
                else:
                    pipeline.save_pretrained(args.output_dir)
        if is_distributed:
            torch.distributed.barrier()
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument("--local_rank", type=int, default=-1)
    parser.add_argument("--dataset", type=str, default="huggan/flowers-102-categories")
    parser.add_argument("--output_dir", type=str, default="ddpm-model")
    parser.add_argument("--overwrite_output_dir", action="store_true")
    parser.add_argument("--resolution", type=int, default=64)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--num_epochs", type=int, default=100)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--warmup_steps", type=int, default=500)
    parser.add_argument("--push_to_hub", action="store_true")
    parser.add_argument("--hub_token", type=str, default=None)
    parser.add_argument("--hub_model_id", type=str, default=None)
    parser.add_argument("--hub_private_repo", action="store_true")
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose "
            "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10 "
            "and an Nvidia Ampere GPU."
        ),
    )
    args = parser.parse_args()

    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    main(args)
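Once training finishes, the pipeline saved by `save_pretrained` above can be reloaded for sampling. A minimal sketch, assuming the same `DDPM` pipeline class and the default `--output_dir` used by this script:

```python
import torch
import PIL.Image

from diffusers import DDPM

# reload the pipeline written by pipeline.save_pretrained(args.output_dir)
pipeline = DDPM.from_pretrained("ddpm-model")

# sample one image, seeding for reproducibility
generator = torch.manual_seed(0)
image = pipeline(generator=generator)

# map from [-1, 1] back to 8-bit pixels, mirroring the training loop above
image_processed = (image.cpu().permute(0, 2, 3, 1) + 1.0) * 127.5
PIL.Image.fromarray(image_processed.type(torch.uint8).numpy()[0]).save("sample.png")
```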
@@ -39,7 +39,7 @@ def main(args):
         resamp_with_conv=True,
         resolution=args.resolution,
     )
-    noise_scheduler = DDPMScheduler(timesteps=1000)
+    noise_scheduler = DDPMScheduler(timesteps=1000, tensor_format="pt")
     optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
     augmentations = Compose(
@@ -93,15 +93,13 @@ def main(args):
             pbar.set_description(f"Epoch {epoch}")
             for step, batch in enumerate(train_dataloader):
                 clean_images = batch["input"]
-                noisy_images = torch.empty_like(clean_images)
-                noise_samples = torch.empty_like(clean_images)
+                noise_samples = torch.randn(clean_images.shape).to(clean_images.device)
                 bsz = clean_images.shape[0]
                 timesteps = torch.randint(0, noise_scheduler.timesteps, (bsz,), device=clean_images.device).long()
-                for idx in range(bsz):
-                    noise = torch.randn(clean_images.shape[1:]).to(clean_images.device)
-                    noise_samples[idx] = noise
-                    noisy_images[idx] = noise_scheduler.forward_step(clean_images[idx], noise, timesteps[idx])
+                # add noise onto the clean images according to the noise magnitude at each timestep
+                # (this is the forward diffusion process)
+                noisy_images = noise_scheduler.training_step(clean_images, noise_samples, timesteps)
                 if step % args.gradient_accumulation_steps != 0:
                     with accelerator.no_sync(model):
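The hunk above replaces a per-sample Python loop with one vectorized call. Here is a sketch of what that noising step computes, assuming `DDPMScheduler.training_step` implements the standard DDPM forward process q(x_t | x_0); the exact broadcasting in the real scheduler may differ:

```python
import torch

def add_noise(clean_images, noise, timesteps, alphas_cumprod):
    # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps, batched over t
    sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
    sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
    # reshape the per-timestep scalars so they broadcast over (C, H, W)
    sqrt_alpha_prod = sqrt_alpha_prod.view(-1, 1, 1, 1)
    sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.view(-1, 1, 1, 1)
    return sqrt_alpha_prod * clean_images + sqrt_one_minus_alpha_prod * noise
```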
@@ -146,7 +144,7 @@ def main(args):
                     # save image
                     test_dir = os.path.join(args.output_dir, "test_samples")
                     os.makedirs(test_dir, exist_ok=True)
-                    image_pil.save(f"{test_dir}/{epoch}.png")
+                    image_pil.save(f"{test_dir}/{epoch:04d}.png")
                     # save the model
                     if args.push_to_hub:
...
 import torch
 from torch import nn
-from diffusers import ClassifierFreeGuidanceScheduler, DDIMScheduler, GLIDESuperResUNetModel, GLIDETextToImageUNetModel
-from diffusers.pipelines.pipeline_glide import GLIDE, CLIPTextModel
+from diffusers import ClassifierFreeGuidanceScheduler, DDIMScheduler, GlideSuperResUNetModel, GlideTextToImageUNetModel
+from diffusers.pipelines.pipeline_glide import Glide, CLIPTextModel
 from transformers import CLIPTextConfig, GPT2Tokenizer
@@ -55,7 +55,7 @@ for layer_idx in range(config.num_hidden_layers):
 ### Convert the Text-to-Image UNet
-text2im_model = GLIDETextToImageUNetModel(
+text2im_model = GlideTextToImageUNetModel(
     in_channels=3,
     model_channels=192,
     out_channels=6,
@@ -80,7 +80,7 @@ text_scheduler = ClassifierFreeGuidanceScheduler(timesteps=1000, beta_schedule="
 # wget https://openaipublic.blob.core.windows.net/diffusion/dec-2021/upsample.pt
 ups_state_dict = torch.load("upsample.pt", map_location="cpu")
-superres_model = GLIDESuperResUNetModel(
+superres_model = GlideSuperResUNetModel(
     in_channels=6,
     model_channels=192,
     out_channels=6,
@@ -101,7 +101,7 @@ upscale_scheduler = DDIMScheduler(
     timesteps=1000, beta_schedule="linear", beta_start=0.0001, beta_end=0.02, tensor_format="pt"
 )
-glide = GLIDE(
+glide = Glide(
     text_unet=text2im_model,
     text_noise_scheduler=text_scheduler,
     text_encoder=model,
...
 # flake8: noqa
 # There's no way to ignore "F401 '...' imported but unused" warnings in this
 # module, but to preserve other warnings. So, don't check this module at all.
-from .utils import is_transformers_available
+from .utils import is_inflect_available, is_transformers_available, is_unidecode_available
 __version__ = "0.0.4"
@@ -11,13 +11,19 @@ from .models.unet import UNetModel
 from .models.unet_ldm import UNetLDMModel
 from .models.unet_rl import TemporalUNet
 from .pipeline_utils import DiffusionPipeline
-from .pipelines import BDDM, DDIM, DDPM, PNDM
+from .pipelines import BDDMPipeline, DDIMPipeline, DDPMPipeline, PNDMPipeline
 from .schedulers import DDIMScheduler, DDPMScheduler, GradTTSScheduler, PNDMScheduler, SchedulerMixin
 if is_transformers_available():
-    from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel, GLIDEUNetModel
+    from .models.unet_glide import GlideSuperResUNetModel, GlideTextToImageUNetModel, GlideUNetModel
     from .models.unet_grad_tts import UNetGradTTSModel
-    from .pipelines import GLIDE, GradTTS, LatentDiffusion
+    from .pipelines import GlidePipeline, LatentDiffusionPipeline
 else:
     from .utils.dummy_transformers_objects import *
+if is_transformers_available() and is_inflect_available() and is_unidecode_available():
+    from .pipelines import GradTTSPipeline
+else:
+    from .utils.dummy_transformers_and_inflect_and_unidecode_objects import *
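The new `GradTTSPipeline` import is guarded by soft-dependency checks. A sketch of how such helpers are typically implemented, assuming the usual `importlib`-based pattern (the real functions live in `.utils`):

```python
import importlib.util

def is_inflect_available() -> bool:
    # soft dependency probe: true if the package can be imported
    return importlib.util.find_spec("inflect") is not None

def is_unidecode_available() -> bool:
    return importlib.util.find_spec("unidecode") is not None
```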
@@ -17,7 +17,7 @@
 # limitations under the License.
 from .unet import UNetModel
-from .unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel, GLIDEUNetModel
+from .unet_glide import GlideSuperResUNetModel, GlideTextToImageUNetModel, GlideUNetModel
 from .unet_grad_tts import UNetGradTTSModel
 from .unet_ldm import UNetLDMModel
 from .unet_rl import TemporalUNet
@@ -388,7 +388,7 @@ class QKVAttention(nn.Module):
         return a.reshape(bs, -1, length)
-class GLIDEUNetModel(ModelMixin, ConfigMixin):
+class GlideUNetModel(ModelMixin, ConfigMixin):
     """
     The full UNet model with attention and timestep embedding.
@@ -641,7 +641,7 @@ class GLIDEUNetModel(ModelMixin, ConfigMixin):
         return self.out(h)
-class GLIDETextToImageUNetModel(GLIDEUNetModel):
+class GlideTextToImageUNetModel(GlideUNetModel):
     """
     A UNetModel that performs super-resolution.
@@ -734,7 +734,7 @@ class GLIDETextToImageUNetModel(GLIDEUNetModel):
         return self.out(h)
-class GLIDESuperResUNetModel(GLIDEUNetModel):
+class GlideSuperResUNetModel(GlideUNetModel):
     """
     A UNetModel that performs super-resolution.
...
@@ -21,7 +21,6 @@ from typing import Optional, Union
 from huggingface_hub import snapshot_download
 from .configuration_utils import ConfigMixin
-from .dynamic_modules_utils import get_class_from_dynamic_module
 from .utils import DIFFUSERS_CACHE, logging
@@ -81,16 +80,13 @@ class DiffusionPipeline(ConfigMixin):
             # set models
             setattr(self, name, module)
-        register_dict = {"_module": self.__module__.split(".")[-1]}
-        self.register_to_config(**register_dict)
     def save_pretrained(self, save_directory: Union[str, os.PathLike]):
         self.save_config(save_directory)
         model_index_dict = dict(self.config)
         model_index_dict.pop("_class_name")
         model_index_dict.pop("_diffusers_version")
-        model_index_dict.pop("_module")
+        model_index_dict.pop("_module", None)
         for pipeline_component_name in model_index_dict.keys():
             sub_model = getattr(self, pipeline_component_name)
@@ -139,11 +135,7 @@ class DiffusionPipeline(ConfigMixin):
         config_dict = cls.get_config_dict(cached_folder)
-        # 2. Get class name and module candidates to load custom models
-        module_candidate_name = config_dict["_module"]
-        module_candidate = module_candidate_name + ".py"
-        # 3. Load the pipeline class, if using custom module then load it from the hub
+        # 2. Load the pipeline class, if using custom module then load it from the hub
         # if we load from explicit class, let's use it
         if cls != DiffusionPipeline:
             pipeline_class = cls
@@ -151,11 +143,6 @@ class DiffusionPipeline(ConfigMixin):
             diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
             pipeline_class = getattr(diffusers_module, config_dict["_class_name"])
-            # (TODO - we should allow to load custom pipelines
-            # else we need to load the correct module from the Hub
-            # module = module_candidate
-            # pipeline_class = get_class_from_dynamic_module(cached_folder, module, class_name_, cached_folder)
         init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
         init_kwargs = {}
@@ -163,7 +150,7 @@ class DiffusionPipeline(ConfigMixin):
         # import it here to avoid circular import
         from diffusers import pipelines
-        # 4. Load each module in the pipeline
+        # 3. Load each module in the pipeline
         for name, (library_name, class_name) in init_dict.items():
             is_pipeline_module = hasattr(pipelines, library_name)
             # if the model is in a pipeline module, then we load it from the pipeline
@@ -171,14 +158,7 @@ class DiffusionPipeline(ConfigMixin):
                 pipeline_module = getattr(pipelines, library_name)
                 class_obj = getattr(pipeline_module, class_name)
                 importable_classes = ALL_IMPORTABLE_CLASSES
-                class_candidates = {c: class_obj for c in ALL_IMPORTABLE_CLASSES.keys()}
-            elif library_name == module_candidate_name:
-                # if the model is not in diffusers or transformers, we need to load it from the hub
-                # assumes that it's a subclass of ModelMixin
-                class_obj = get_class_from_dynamic_module(cached_folder, module_candidate, class_name, cached_folder)
-                # since it's not from a library, we need to check class candidates for all importable classes
-                importable_classes = ALL_IMPORTABLE_CLASSES
-                class_candidates = {c: class_obj for c in ALL_IMPORTABLE_CLASSES.keys()}
+                class_candidates = {c: class_obj for c in importable_classes.keys()}
             else:
                 # else we just import it from the library.
                 library = importlib.import_module(library_name)
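With the dynamic-module branch removed, component classes now come from exactly two places. A condensed sketch of the remaining resolution logic (names simplified, not the verbatim implementation):

```python
import importlib

def resolve_class(pipelines_pkg, library_name, class_name):
    # components bundled with a pipeline module ship inside diffusers.pipelines
    if hasattr(pipelines_pkg, library_name):
        return getattr(getattr(pipelines_pkg, library_name), class_name)
    # everything else is imported from the named library (e.g. transformers)
    return getattr(importlib.import_module(library_name), class_name)
```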
...
@@ -15,5 +15,5 @@ TODO(Patrick, Anton, Suraj)
 - PNDM for unconditional image generation in [pipeline_pndm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py).
 - Latent diffusion for text to image generation / conditional image generation in [pipeline_latent_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_latent_diffusion.py).
 - Glide for text to image generation / conditional image generation in [pipeline_glide](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_glide.py).
-- BDDM for spectrogram-to-sound vocoding in [pipeline_bddm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_bddm.py).
+- BDDMPipeline for spectrogram-to-sound vocoding in [pipeline_bddm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_bddm.py).
 - Grad-TTS for text to audio generation / conditional audio generation in [pipeline_grad_tts](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_grad_tts.py).
-from ..utils import is_transformers_available
-from .pipeline_bddm import BDDM
-from .pipeline_ddim import DDIM
-from .pipeline_ddpm import DDPM
-from .pipeline_pndm import PNDM
+from ..utils import is_inflect_available, is_transformers_available, is_unidecode_available
+from .pipeline_bddm import BDDMPipeline
+from .pipeline_ddim import DDIMPipeline
+from .pipeline_ddpm import DDPMPipeline
+from .pipeline_pndm import PNDMPipeline
 if is_transformers_available():
-    from .pipeline_glide import GLIDE
-    from .pipeline_grad_tts import GradTTS
-    from .pipeline_latent_diffusion import LatentDiffusion
+    from .pipeline_glide import GlidePipeline
+    from .pipeline_latent_diffusion import LatentDiffusionPipeline
+if is_transformers_available() and is_unidecode_available() and is_inflect_available():
+    from .pipeline_grad_tts import GradTTSPipeline
@@ -6,20 +6,9 @@ from shutil import copyfile
 import torch
+import inflect
 from transformers import PreTrainedTokenizer
+from unidecode import unidecode
-try:
-    from unidecode import unidecode
-except:
-    print("unidecode is not installed")
-    pass
-try:
-    import inflect
-except:
-    print("inflect is not installed")
-    pass
 valid_symbols = [
@@ -234,12 +223,7 @@ def english_cleaners(text):
     return text
-try:
-    _inflect = inflect.engine()
-except:
-    print("inflect is not installed")
-    _inflect = None
+_inflect = inflect.engine()
 _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
 _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
 _pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
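Making `inflect` and `unidecode` hard imports lets the cleaners assume a working engine. A tiny stand-alone illustration of what `_inflect` is used for in number normalization (the `_expand_*` helpers themselves are outside this hunk):

```python
import inflect

_inflect = inflect.engine()
# e.g. turning digit matches from the regexes above into spoken words
print(_inflect.number_to_words(1934))  # "one thousand, nine hundred and thirty-four"
```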
...
@@ -271,7 +271,7 @@ class DiffWave(ModelMixin, ConfigMixin):
         return self.final_conv(x)
-class BDDM(DiffusionPipeline):
+class BDDMPipeline(DiffusionPipeline):
     def __init__(self, diffwave, noise_scheduler):
         super().__init__()
         noise_scheduler = noise_scheduler.set_format("pt")
...
@@ -21,7 +21,7 @@ import tqdm
 from ..pipeline_utils import DiffusionPipeline
-class DDIM(DiffusionPipeline):
+class DDIMPipeline(DiffusionPipeline):
     def __init__(self, unet, noise_scheduler):
         super().__init__()
         noise_scheduler = noise_scheduler.set_format("pt")
...
@@ -21,7 +21,7 @@ import tqdm
 from ..pipeline_utils import DiffusionPipeline
-class DDPM(DiffusionPipeline):
+class DDPMPipeline(DiffusionPipeline):
    def __init__(self, unet, noise_scheduler):
         super().__init__()
         noise_scheduler = noise_scheduler.set_format("pt")
...
@@ -30,7 +30,7 @@ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPo
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import ModelOutput, add_start_docstrings_to_model_forward, replace_return_docstrings
-from ..models import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
+from ..models import GlideSuperResUNetModel, GlideTextToImageUNetModel
 from ..pipeline_utils import DiffusionPipeline
 from ..schedulers import DDIMScheduler, DDPMScheduler
 from ..utils import logging
@@ -694,7 +694,6 @@ class CLIPTextModel(CLIPPreTrainedModel):
 # END OF THE CLIP MODEL COPY-PASTE
 #####################
-
 def _extract_into_tensor(arr, timesteps, broadcast_shape):
     """
     Extract values from a 1-D numpy array for a batch of indices.
@@ -711,14 +710,14 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
     return res + torch.zeros(broadcast_shape, device=timesteps.device)
-class GLIDE(DiffusionPipeline):
+class GlidePipeline(DiffusionPipeline):
     def __init__(
         self,
-        text_unet: GLIDETextToImageUNetModel,
+        text_unet: GlideTextToImageUNetModel,
         text_noise_scheduler: DDPMScheduler,
         text_encoder: CLIPTextModel,
         tokenizer: GPT2Tokenizer,
-        upscale_unet: GLIDESuperResUNetModel,
+        upscale_unet: GlideSuperResUNetModel,
         upscale_noise_scheduler: DDIMScheduler,
     ):
         super().__init__()
...
@@ -420,7 +420,7 @@ class TextEncoder(ModelMixin, ConfigMixin):
         return mu, logw, x_mask
-class GradTTS(DiffusionPipeline):
+class GradTTSPipeline(DiffusionPipeline):
     def __init__(self, unet, text_encoder, noise_scheduler, tokenizer):
         super().__init__()
         noise_scheduler = noise_scheduler.set_format("pt")
@@ -430,7 +430,14 @@ class GradTTS(DiffusionPipeline):
     @torch.no_grad()
     def __call__(
-        self, text, num_inference_steps=50, temperature=1.3, length_scale=0.91, speaker_id=15, torch_device=None
+        self,
+        text,
+        num_inference_steps=50,
+        temperature=1.3,
+        length_scale=0.91,
+        speaker_id=15,
+        torch_device=None,
+        generator=None,
     ):
         if torch_device is None:
             torch_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -464,7 +471,7 @@ class GradTTS(DiffusionPipeline):
         mu_y = mu_y.transpose(1, 2)
         # Sample latent representation from terminal distribution N(mu_y, I)
-        z = mu_y + torch.randn_like(mu_y, device=mu_y.device) / temperature
+        z = mu_y + torch.randn(mu_y.shape, device=mu_y.device, generator=generator) / temperature
         xt = z * y_mask
         h = 1.0 / num_inference_steps
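The switch from `torch.randn_like` to `torch.randn` exists so a `torch.Generator` can be threaded through for reproducible sampling. A minimal stand-alone illustration (the mel shape is a placeholder):

```python
import torch

mu_y = torch.zeros(1, 80, 100)  # placeholder mel-shaped mean
temperature = 1.3

# the same seed now yields the same latent z ~ N(mu_y, I / temperature^2)
generator = torch.manual_seed(0)
z = mu_y + torch.randn(mu_y.shape, generator=generator) / temperature
```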
...
@@ -860,7 +860,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
         return dec, posterior
-class LatentDiffusion(DiffusionPipeline):
+class LatentDiffusionPipeline(DiffusionPipeline):
     def __init__(self, vqvae, bert, tokenizer, unet, noise_scheduler):
         super().__init__()
         noise_scheduler = noise_scheduler.set_format("pt")
...
@@ -21,7 +21,7 @@ import tqdm
 from ..pipeline_utils import DiffusionPipeline
-class PNDM(DiffusionPipeline):
+class PNDMPipeline(DiffusionPipeline):
     def __init__(self, unet, noise_scheduler):
         super().__init__()
         noise_scheduler = noise_scheduler.set_format("pt")
...
@@ -73,7 +73,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         if beta_schedule == "linear":
             self.betas = np.linspace(beta_start, beta_end, timesteps, dtype=np.float32)
         elif beta_schedule == "squaredcos_cap_v2":
-            # GLIDE cosine schedule
+            # Glide cosine schedule
             self.betas = betas_for_alpha_bar(timesteps)
         else:
             raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
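For reference, a sketch of the `squaredcos_cap_v2` ("Glide cosine") schedule referenced above, assuming the alpha-bar definition from Nichol & Dhariwal (2021); the actual `betas_for_alpha_bar` helper may differ in details such as the `max_beta` cap:

```python
import math

import numpy as np

def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
    # alpha_bar(t) = cos((t + 0.008) / 1.008 * pi / 2) ** 2
    def alpha_bar(t):
        return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        # beta_t = 1 - alpha_bar(t) / alpha_bar(t - 1), capped to keep betas < 1
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas, dtype=np.float32)
```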
@@ -132,7 +132,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         std_dev_t = eta * variance ** (0.5)
         if use_clipped_residual:
-            # the residual is always re-derived from the clipped x_0 in GLIDE
+            # the residual is always re-derived from the clipped x_0 in Glide
             residual = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
         # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
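The clipped-residual branch feeds into the DDIM update of formula (12). A self-contained sketch of that step, assuming `residual` is the (optionally re-derived) noise prediction; the real `DDIMScheduler.step` handles more bookkeeping:

```python
import torch

def ddim_step(sample, residual, alpha_prod_t, alpha_prod_t_prev, eta, generator=None):
    # eq. (9): recover the predicted x_0 from the noise prediction
    pred_original_sample = (sample - (1 - alpha_prod_t) ** 0.5 * residual) / alpha_prod_t ** 0.5
    variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
    std_dev_t = eta * variance ** 0.5
    # eq. (12): "direction pointing to x_t"
    pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** 0.5 * residual
    prev_sample = alpha_prod_t_prev ** 0.5 * pred_original_sample + pred_sample_direction
    if eta > 0:
        # eta > 0 interpolates toward stochastic DDPM-style sampling
        noise = torch.randn(sample.shape, generator=generator)
        prev_sample = prev_sample + std_dev_t * noise
    return prev_sample
```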
...