Commit 9fdbc14e authored by Patrick von Platen

Merge branch 'main' of https://github.com/huggingface/diffusers into main

parents ef044a72 d754ce5f
# References
[GLIDE: Towards Photorealistic Image Generation and Editing with Text-Guided Diffusion Models](https://arxiv.org/pdf/2112.10741.pdf)
[Diffusion Models Beat GANs on Image Synthesis](https://arxiv.org/pdf/2105.05233.pdf)
\ No newline at end of file
import argparse
import torch
from torch import nn

from diffusers import ClassifierFreeGuidanceScheduler, CLIPTextModel, UNetGLIDEModel
from modeling_glide import GLIDE
from transformers import CLIPTextConfig, GPT2Tokenizer
# wget https://openaipublic.blob.core.windows.net/diffusion/dec-2021/base.pt
state_dict = torch.load("base.pt", map_location="cpu")
state_dict = {k: nn.Parameter(v) for k, v in state_dict.items()}
### Convert the text encoder
config = CLIPTextConfig(
    vocab_size=50257,
    max_position_embeddings=128,
    hidden_size=512,
    intermediate_size=2048,
    num_hidden_layers=16,
    num_attention_heads=8,
    use_padding_embeddings=True,
)
model = CLIPTextModel(config).eval()
tokenizer = GPT2Tokenizer("./glide-base/tokenizer/vocab.json", "./glide-base/tokenizer/merges.txt", pad_token="<|endoftext|>")
tokenizer.save_pretrained("./glide-base")

hf_encoder = model.text_model
@@ -30,15 +35,8 @@ hf_encoder.final_layer_norm.bias = state_dict["final_ln.bias"]
for layer_idx in range(config.num_hidden_layers):
    hf_layer = hf_encoder.encoder.layers[layer_idx]
    hf_layer.self_attn.qkv_proj.weight = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_qkv.weight"]
    hf_layer.self_attn.qkv_proj.bias = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_qkv.bias"]

    hf_layer.self_attn.out_proj.weight = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_proj.weight"]
    hf_layer.self_attn.out_proj.bias = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_proj.bias"]
@@ -53,8 +51,28 @@ for layer_idx in range(config.num_hidden_layers):
    hf_layer.mlp.fc2.weight = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.weight"]
    hf_layer.mlp.fc2.bias = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.bias"]
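
# Optional sanity check (a minimal sketch, not part of the conversion itself): run the converted
# encoder on a padded prompt pair and confirm the forward pass has the expected shape. This assumes
# the custom CLIPTextModel returns a transformers-style output with `last_hidden_state`.
inputs = tokenizer(["an oil painting of a corgi", ""], padding="max_length", max_length=128, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # expected: (2, 128, 512) given the config above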
inputs = tokenizer(["an oil painting of a corgi", ""], padding="max_length", max_length=128, return_tensors="pt") ### Convert the UNet
with torch.no_grad():
outputs = model(**inputs) unet_model = UNetGLIDEModel(
in_channels=3,
model_channels=192,
out_channels=6,
num_res_blocks=3,
attention_resolutions=(2, 4, 8),
dropout=0.1,
channel_mult=(1, 2, 3, 4),
num_heads=1,
num_head_channels=64,
num_heads_upsample=1,
use_scale_shift_norm=True,
resblock_updown=True,
transformer_dim=512,
)
unet_model.load_state_dict(state_dict, strict=False)
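# Hedged sanity check: strict=False silently skips mismatched keys, so repeat the load while
# capturing the NamedTuple that torch.nn.Module.load_state_dict returns and confirm that the
# unexpected keys are only the text-transformer weights converted separately above.
load_result = unet_model.load_state_dict(state_dict, strict=False)
print(f"missing keys: {len(load_result.missing_keys)}, unexpected keys: {len(load_result.unexpected_keys)}")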
scheduler = ClassifierFreeGuidanceScheduler(timesteps=1000, beta_schedule="squaredcos_cap_v2")
glide = GLIDE(unet=unet_model, noise_scheduler=scheduler, text_encoder=model, tokenizer=tokenizer)
model.save_pretrained("./glide-base") glide.save_pretrained("./glide-base")
\ No newline at end of file
@@ -14,46 +14,154 @@
# limitations under the License.

import numpy as np
import torch
import tqdm

from diffusers import ClassifierFreeGuidanceScheduler, CLIPTextModel, DiffusionPipeline, UNetGLIDEModel
from transformers import GPT2Tokenizer
def _extract_into_tensor(arr, timesteps, broadcast_shape):
"""
Extract values from a 1-D numpy array for a batch of indices.
:param arr: the 1-D numpy array.
:param timesteps: a tensor of indices into the array to extract.
:param broadcast_shape: a larger shape of K dimensions with the batch
dimension equal to the length of timesteps.
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
"""
res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
while len(res.shape) < len(broadcast_shape):
res = res[..., None]
return res + torch.zeros(broadcast_shape, device=timesteps.device)
class GLIDE(DiffusionPipeline):
    def __init__(
self,
unet: UNetGLIDEModel,
noise_scheduler: ClassifierFreeGuidanceScheduler,
text_encoder: CLIPTextModel,
tokenizer: GPT2Tokenizer,
):
        super().__init__()
        self.register_modules(
unet=unet, noise_scheduler=noise_scheduler, text_encoder=text_encoder, tokenizer=tokenizer
)
def q_posterior_mean_variance(self, x_start, x_t, t):
"""
Compute the mean and variance of the diffusion posterior:
q(x_{t-1} | x_t, x_0)
"""
assert x_start.shape == x_t.shape
posterior_mean = (
_extract_into_tensor(self.noise_scheduler.posterior_mean_coef1, t, x_t.shape) * x_start
+ _extract_into_tensor(self.noise_scheduler.posterior_mean_coef2, t, x_t.shape) * x_t
)
posterior_variance = _extract_into_tensor(self.noise_scheduler.posterior_variance, t, x_t.shape)
posterior_log_variance_clipped = _extract_into_tensor(
self.noise_scheduler.posterior_log_variance_clipped, t, x_t.shape
)
assert (
posterior_mean.shape[0]
== posterior_variance.shape[0]
== posterior_log_variance_clipped.shape[0]
== x_start.shape[0]
)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, model, x, t, transformer_out, clip_denoised=True, model_kwargs=None):
"""
Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
the initial x, x_0.
:param model: the model, which takes a signal and a batch of timesteps
as input.
:param x: the [N x C x ...] tensor at time t.
:param t: a 1-D Tensor of timesteps.
:param clip_denoised: if True, clip the denoised signal into [-1, 1].
:param model_kwargs: if not None, a dict of extra keyword arguments to
pass to the model. This can be used for conditioning.
:return: a dict with the following keys:
- 'mean': the model mean output.
- 'variance': the model variance output.
- 'log_variance': the log of 'variance'.
- 'pred_xstart': the prediction for x_0.
"""
if model_kwargs is None:
model_kwargs = {}
        B, C = x.shape[:2]
assert t.shape == (B,)
model_output = model(x, t, transformer_out)
assert model_output.shape == (B, C * 2, *x.shape[2:])
model_output, model_var_values = torch.split(model_output, C, dim=1)
min_log = _extract_into_tensor(self.noise_scheduler.posterior_log_variance_clipped, t, x.shape)
max_log = _extract_into_tensor(np.log(self.noise_scheduler.betas), t, x.shape)
# The model_var_values is [-1, 1] for [min_var, max_var].
frac = (model_var_values + 1) / 2
model_log_variance = frac * max_log + (1 - frac) * min_log
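        # This is the "learned range" variance parameterization from Nichol & Dhariwal (2021):
        # the network predicts an interpolation coefficient between the clipped posterior
        # log-variance (min_log) and log(beta_t) (max_log) rather than the variance itself.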
model_variance = torch.exp(model_log_variance)
pred_xstart = self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
if clip_denoised:
pred_xstart = pred_xstart.clamp(-1, 1)
model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
return model_mean, model_variance, model_log_variance, pred_xstart
def _predict_xstart_from_eps(self, x_t, t, eps):
assert x_t.shape == eps.shape
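        # For reference: inverting the forward process x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
        # gives x_0 = sqrt(1 / alpha_bar_t) * x_t - sqrt(1 / alpha_bar_t - 1) * eps, which is what the
        # sqrt_recip_alphas_cumprod / sqrt_recipm1_alphas_cumprod arrays below encode.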
return (
_extract_into_tensor(self.noise_scheduler.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
- _extract_into_tensor(self.noise_scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
)
def __call__(self, prompt, generator=None, torch_device=None):
torch_device = "cuda" if torch.cuda.is_available() else "cpu" torch_device = "cuda" if torch.cuda.is_available() else "cpu"
self.unet.to(torch_device) self.unet.to(torch_device)
self.text_encoder.to(torch_device)
# Create a classifier-free guidance sampling function
guidance_scale = 3.0
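        # Hedged note on classifier-free guidance: the batch holds the conditional sample in the
        # first half and the unconditional (empty-prompt) sample in the second half, and model_fn
        # below forms the guided prediction eps = eps_uncond + guidance_scale * (eps_cond - eps_uncond).
        # A scale of 1.0 reduces to the purely conditional prediction; larger values trade sample
        # diversity for prompt fidelity.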
def model_fn(x_t, ts, transformer_out, **kwargs):
half = x_t[: len(x_t) // 2]
combined = torch.cat([half, half], dim=0)
model_out = self.unet(combined, ts, transformer_out, **kwargs)
eps, rest = model_out[:, :3], model_out[:, 3:]
cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
eps = torch.cat([half_eps, half_eps], dim=0)
return torch.cat([eps, rest], dim=1)
        # 1. Sample gaussian noise
        batch_size = 2  # second image is empty for classifier-free guidance
        image = self.noise_scheduler.sample_noise(
            (batch_size, self.unet.in_channels, 64, 64), device=torch_device, generator=generator
        )

        # 2. Encode tokens
        # an empty prompt is encoded alongside the real one to provide the unconditional branch
        # needed for classifier-free guidance
        inputs = self.tokenizer([prompt, ""], padding="max_length", max_length=128, return_tensors="pt")
        input_ids = inputs["input_ids"].to(torch_device)
        attention_mask = inputs["attention_mask"].to(torch_device)
        transformer_out = self.text_encoder(input_ids, attention_mask).last_hidden_state

        # 3. Run the reverse diffusion process
        num_timesteps = len(self.noise_scheduler)
        for i in tqdm.tqdm(reversed(range(num_timesteps)), total=num_timesteps):
            t = torch.tensor([i] * image.shape[0], device=torch_device)
            mean, variance, log_variance, pred_xstart = self.p_mean_variance(model_fn, image, t, transformer_out)
            noise = self.noise_scheduler.sample_noise(image.shape, device=torch_device, generator=generator)
            nonzero_mask = (t != 0).float().view(-1, *([1] * (len(image.shape) - 1)))  # no noise when t == 0
            image = mean + nonzero_mask * torch.exp(0.5 * log_variance) * noise

        return image
import torch

from modeling_glide import GLIDE

generator = torch.Generator()
generator = generator.manual_seed(0)

# 1. Load models
pipeline = GLIDE.from_pretrained("fusing/glide-base")

img = pipeline("an oil painting of a corgi", generator)

print(img)
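
# Hedged post-processing sketch: the pipeline returns a (2, 3, 64, 64) tensor roughly in [-1, 1]
# (conditional sample first, unconditional second); mapping it to uint8 and saving with PIL is one
# way to inspect the result. The file name "corgi.png" is only illustrative.
from PIL import Image

sample = img[0].detach().cpu().permute(1, 2, 0).numpy()  # first (conditional) image, CHW -> HWC
sample = ((sample + 1) * 127.5).round().clip(0, 255).astype("uint8")
Image.fromarray(sample).save("corgi.png")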
@@ -5,7 +5,10 @@
__version__ = "0.0.1"

from .modeling_utils import ModelMixin
from .models.clip_text_transformer import CLIPTextModel
from .models.unet import UNetModel
from .models.unet_glide import UNetGLIDEModel
from .models.unet_ldm import UNetLDMModel
from .pipeline_utils import DiffusionPipeline
from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler
from .schedulers.gaussian_ddpm import GaussianDDPMScheduler
@@ -89,7 +89,6 @@ class ConfigMixin:
        self.to_json_file(output_config_file)
        logger.info(f"Configuration saved in {output_config_file}")

    @classmethod
    def get_config_dict(
@@ -183,7 +182,7 @@ class ConfigMixin:
            logger.info(f"loading configuration file {config_file}")
        else:
            logger.info(f"loading configuration file {config_file} from cache at {resolved_config_file}")

        return config_dict

    @classmethod
@@ -199,9 +198,8 @@ class ConfigMixin:
                # use value from config dict
                init_dict[key] = config_dict.pop(key)

        unused_kwargs = config_dict.update(kwargs)

        passed_keys = set(init_dict.keys())
        if len(expected_keys - passed_keys) > 0:
            logger.warn(
@@ -212,9 +210,7 @@ class ConfigMixin:
    @classmethod
    def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs):
        config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
        init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs)
...
@@ -16,5 +16,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .clip_text_transformer import CLIPTextModel
from .unet import UNetModel
from .unet_glide import UNetGLIDEModel
from .unet_ldm import UNetLDMModel
@@ -435,7 +435,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        transformer_dim=512,
    ):
        super().__init__()
        self.register(
@@ -455,7 +455,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
            num_heads_upsample=num_heads_upsample,
            use_scale_shift_norm=use_scale_shift_norm,
            resblock_updown=resblock_updown,
            transformer_dim=transformer_dim,
        )

        if num_heads_upsample == -1:
@@ -470,7 +470,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.use_checkpoint = use_checkpoint
        # self.dtype = torch.float16 if use_fp16 else torch.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
@@ -482,6 +482,8 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
            linear(time_embed_dim, time_embed_dim),
        )

        self.transformer_proj = nn.Linear(transformer_dim, self.model_channels * 4)

        ch = input_ch = int(channel_mult[0] * model_channels)
        self.input_blocks = nn.ModuleList([TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))])
        self._feature_size = ch
@@ -508,7 +510,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
                        use_checkpoint=use_checkpoint,
                        num_heads=num_heads,
                        num_head_channels=num_head_channels,
                        encoder_channels=transformer_dim,
                    )
                )
            self.input_blocks.append(TimestepEmbedSequential(*layers))
@@ -551,7 +553,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
                num_head_channels=num_head_channels,
                encoder_channels=transformer_dim,
            ),
            ResBlock(
                ch,
@@ -587,7 +589,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
                        use_checkpoint=use_checkpoint,
                        num_heads=num_heads_upsample,
                        num_head_channels=num_head_channels,
                        encoder_channels=transformer_dim,
                    )
                )
                if level and i == num_res_blocks:
@@ -642,10 +644,6 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: an [N x C x ...] Tensor of outputs.
        """
        hs = []
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
@@ -653,13 +651,15 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
        transformer_proj = self.transformer_proj(transformer_out[:, -1])
        transformer_out = transformer_out.permute(0, 2, 1)  # NLC -> NCL

        emb = emb + transformer_proj.to(emb)

        h = x
        for module in self.input_blocks:
            h = module(h, emb, transformer_out)
            hs.append(h)
        h = self.middle_block(h, emb, transformer_out)
        for module in self.output_blocks:
            other = hs.pop()
            h = torch.cat([h, other], dim=1)
            h = module(h, emb, transformer_out)
        return self.out(h)
@@ -830,7 +830,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        self.dtype_ = torch.float16 if use_fp16 else torch.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
@@ -1060,7 +1060,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
            assert y.shape == (x.shape[0],)
            emb = emb + self.label_emb(y)

        h = x.type(self.dtype_)
        for module in self.input_blocks:
            h = module(h, emb, context)
            hs.append(h)
...
@@ -17,6 +17,7 @@
import importlib
import os
from typing import Optional, Union

from huggingface_hub import snapshot_download

# CHANGE to diffusers.utils
@@ -35,10 +36,12 @@ logger = logging.get_logger(__name__)
LOADABLE_CLASSES = {
    "diffusers": {
        "ModelMixin": ["save_pretrained", "from_pretrained"],
        "CLIPTextModel": ["save_pretrained", "from_pretrained"],  # TODO (Anton): move to transformers
        "GaussianDDPMScheduler": ["save_config", "from_config"],
        "ClassifierFreeGuidanceScheduler": ["save_config", "from_config"],
    },
    "transformers": {
        "GPT2Tokenizer": ["save_pretrained", "from_pretrained"],
    },
}
@@ -62,7 +65,7 @@ class DiffusionPipeline(ConfigMixin):
            # set models
            setattr(self, name, module)

        register_dict = {"_module": self.__module__.split(".")[-1] + ".py"}
        self.register(**register_dict)

    def save_pretrained(self, save_directory: Union[str, os.PathLike]):
...
@@ -16,4 +16,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .classifier_free_guidance import ClassifierFreeGuidanceScheduler
from .gaussian_ddpm import GaussianDDPMScheduler
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import numpy as np
import torch
from torch import nn
from ..configuration_utils import ConfigMixin
SAMPLING_CONFIG_NAME = "scheduler_config.json"
def linear_beta_schedule(timesteps, beta_start, beta_end):
return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
"""
Create a beta schedule that discretizes the given alpha_t_bar function,
which defines the cumulative product of (1-beta) over time from t = [0,1].
:param num_diffusion_timesteps: the number of betas to produce.
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
produces the cumulative product of (1-beta) up to that
part of the diffusion process.
:param max_beta: the maximum beta to use; use values lower than 1 to
prevent singularities.
"""
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return np.array(betas, dtype=np.float64)
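
# For the "squaredcos_cap_v2" schedule used below, alpha_bar(t) = cos((t + 0.008) / 1.008 * pi / 2) ** 2,
# so each beta_i = 1 - alpha_bar((i + 1) / T) / alpha_bar(i / T), capped at max_beta = 0.999. This is the
# cosine schedule from Nichol & Dhariwal (2021), which GLIDE also uses.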
class ClassifierFreeGuidanceScheduler(nn.Module, ConfigMixin):
config_name = SAMPLING_CONFIG_NAME
def __init__(
self,
timesteps=1000,
beta_schedule="squaredcos_cap_v2",
):
super().__init__()
self.register(
timesteps=timesteps,
beta_schedule=beta_schedule,
)
self.num_timesteps = int(timesteps)
if beta_schedule == "squaredcos_cap_v2":
# GLIDE cosine schedule
self.betas = betas_for_alpha_bar(
timesteps,
lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
)
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
alphas = 1.0 - self.betas
self.alphas_cumprod = np.cumprod(alphas, axis=0)
self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
# calculations for diffusion q(x_t | x_{t-1}) and others
self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
# calculations for posterior q(x_{t-1} | x_t, x_0)
self.posterior_variance = self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
self.posterior_log_variance_clipped = np.log(
np.append(self.posterior_variance[1], self.posterior_variance[1:])
)
self.posterior_mean_coef1 = self.betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
self.posterior_mean_coef2 = (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
def sample_noise(self, shape, device, generator=None):
# always sample on CPU to be deterministic
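        # (CPU and CUDA use different random streams, so drawing on CPU with the passed-in generator
        # keeps results reproducible regardless of the target device before moving the tensor over.)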
return torch.randn(shape, generator=generator).to(device)
def __len__(self):
return self.num_timesteps