"...text-generation-inference.git" did not exist on "abd58ff82c37d5e4f131abdac3d298927a815604"
Commit f7ce79f8 authored by anton-l

+ cosine schedule and unet config

parent 111fa990
import torch
from .modeling_glide import GLIDE
from diffusers import UNetGLIDEModel, GaussianDDPMScheduler
generator = torch.Generator()
generator = generator.manual_seed(0)
# 1. Load models
scheduler = GaussianDDPMScheduler.from_config("fusing/glide-base")
model = UNetGLIDEModel.from_pretrained("fusing/glide-base")
pipeline = GLIDE(model, scheduler)
# 2. Sample an image
img = pipeline(generator)
print(img)
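
The script above only prints the raw pipeline output. As a rough post-processing sketch (not part of this commit), assuming img is a tensor in [-1, 1] with shape (batch, channels, height, width), it could be written out as a PNG:

# Hypothetical post-processing, not part of this commit: assumes `img` is a
# tensor in [-1, 1] of shape (batch, channels, height, width).
import numpy as np
from PIL import Image

img = (img.clamp(-1, 1) + 1) / 2
array = (img[0].permute(1, 2, 0).cpu().numpy() * 255).round().astype(np.uint8)
Image.fromarray(array).save("glide_sample.png")
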
 import math
 from abc import abstractmethod
-import torch as th
+import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from ..configuration_utils import Config
+from ..modeling_utils import PreTrainedModel
 def convert_module_to_f16(l):
     """
@@ -94,13 +97,13 @@ def timestep_embedding(timesteps, dim, max_period=10000):
     :return: an [N x dim] Tensor of positional embeddings.
     """
     half = dim // 2
-    freqs = th.exp(-math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half).to(
+    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
         device=timesteps.device
     )
     args = timesteps[:, None].float() * freqs[None]
-    embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
+    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
     if dim % 2:
-        embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
+        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
     return embedding
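
As a standalone illustration of the embedding above, the sketch below re-implements the same sinusoidal formula outside the model file; the helper name and the example timesteps are arbitrary.

import math
import torch

def sinusoidal_timestep_embedding(timesteps, dim, max_period=10000):
    # Half the channels are cosines, half sines, over geometrically spaced frequencies.
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half).to(timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding

emb = sinusoidal_timestep_embedding(torch.tensor([0, 250, 999]), dim=128)
print(emb.shape)  # torch.Size([3, 128])
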
@@ -298,7 +301,7 @@ class ResBlock(TimestepBlock):
             emb_out = emb_out[..., None]
         if self.use_scale_shift_norm:
             out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
-            scale, shift = th.chunk(emb_out, 2, dim=1)
+            scale, shift = torch.chunk(emb_out, 2, dim=1)
             h = out_norm(h) * (1 + scale) + shift
             h = out_rest(h)
         else:
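
The scale-shift branch above is essentially FiLM-style conditioning: the timestep embedding is projected to a per-channel (scale, shift) pair that modulates the normalized activations. A minimal sketch with made-up layer sizes:

import torch
import torch.nn as nn

channels, emb_dim = 64, 256
h = torch.randn(2, channels, 32, 32)            # feature map
emb = torch.randn(2, emb_dim)                   # timestep embedding

emb_layers = nn.Linear(emb_dim, 2 * channels)   # stands in for the block's emb_layers
out_norm = nn.GroupNorm(32, channels)           # stands in for out_layers[0]

emb_out = emb_layers(emb)[..., None, None]      # broadcast over the spatial dims
scale, shift = torch.chunk(emb_out, 2, dim=1)
h = out_norm(h) * (1 + scale) + shift
print(h.shape)  # torch.Size([2, 64, 32, 32])
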
@@ -376,16 +379,16 @@ class QKVAttention(nn.Module):
         if encoder_kv is not None:
             assert encoder_kv.shape[1] == self.n_heads * ch * 2
             ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
-            k = th.cat([ek, k], dim=-1)
-            v = th.cat([ev, v], dim=-1)
+            k = torch.cat([ek, k], dim=-1)
+            v = torch.cat([ev, v], dim=-1)
         scale = 1 / math.sqrt(math.sqrt(ch))
-        weight = th.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
-        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
-        a = th.einsum("bts,bcs->bct", weight, v)
+        weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
+        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+        a = torch.einsum("bts,bcs->bct", weight, v)
         return a.reshape(bs, -1, length)
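
For reference, the attention math in this hunk can be run standalone; the sizes below are arbitrary, and the scale is split between q and k as the inline comment notes, which is friendlier to fp16 than dividing the product afterwards.

import math
import torch

bs_heads, ch, length = 4, 32, 64                 # (batch * heads, head channels, tokens)
q = torch.randn(bs_heads, ch, length)
k = torch.randn(bs_heads, ch, length)
v = torch.randn(bs_heads, ch, length)

scale = 1 / math.sqrt(math.sqrt(ch))
weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
out = torch.einsum("bts,bcs->bct", weight, v)
print(out.shape)  # torch.Size([4, 32, 64])
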
-class UNetGLIDEModel(nn.Module):
+class UNetGLIDEModel(PreTrainedModel, Config):
     """
     The full UNet model with attention and timestep embedding.
@@ -435,6 +438,25 @@ class UNetGLIDEModel(nn.Module):
         encoder_channels=None,
     ):
         super().__init__()
+        self.register(
+            in_channels=in_channels,
+            model_channels=model_channels,
+            out_channels=out_channels,
+            num_res_blocks=num_res_blocks,
+            attention_resolutions=attention_resolutions,
+            dropout=dropout,
+            channel_mult=channel_mult,
+            conv_resample=conv_resample,
+            dims=dims,
+            use_checkpoint=use_checkpoint,
+            use_fp16=use_fp16,
+            num_heads=num_heads,
+            num_head_channels=num_head_channels,
+            num_heads_upsample=num_heads_upsample,
+            use_scale_shift_norm=use_scale_shift_norm,
+            resblock_updown=resblock_updown,
+            encoder_channels=encoder_channels,
+        )
         if num_heads_upsample == -1:
             num_heads_upsample = num_heads
@@ -448,7 +470,7 @@ class UNetGLIDEModel(nn.Module):
         self.channel_mult = channel_mult
         self.conv_resample = conv_resample
         self.use_checkpoint = use_checkpoint
-        self.dtype = th.float16 if use_fp16 else th.float32
+        self.dtype = torch.float16 if use_fp16 else torch.float32
         self.num_heads = num_heads
         self.num_head_channels = num_head_channels
         self.num_heads_upsample = num_heads_upsample
@@ -637,7 +659,7 @@ class UNetGLIDEModel(nn.Module):
             hs.append(h)
         h = self.middle_block(h, emb)
         for module in self.output_blocks:
-            h = th.cat([h, hs.pop()], dim=1)
+            h = torch.cat([h, hs.pop()], dim=1)
             h = module(h, emb)
         h = h.type(x.dtype)
         return self.out(h)
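
The forward pass in this hunk follows the standard U-Net pattern: each input block's activation is pushed onto a stack and concatenated back in front of the matching output block. A toy version of that control flow, with placeholder convolutions instead of the real blocks and no timestep embedding:

import torch
import torch.nn as nn

input_blocks = nn.ModuleList([nn.Conv2d(3, 8, 3, padding=1), nn.Conv2d(8, 8, 3, padding=1)])
middle_block = nn.Conv2d(8, 8, 3, padding=1)
# Each output block sees its own features concatenated with a popped skip connection.
output_blocks = nn.ModuleList([nn.Conv2d(16, 8, 3, padding=1), nn.Conv2d(16, 8, 3, padding=1)])

x = torch.randn(1, 3, 32, 32)
hs, h = [], x
for module in input_blocks:
    h = module(h)
    hs.append(h)
h = middle_block(h)
for module in output_blocks:
    h = torch.cat([h, hs.pop()], dim=1)
    h = module(h)
print(h.shape)  # torch.Size([1, 8, 32, 32])
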
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import torch
+import math
 from torch import nn
 from ..configuration_utils import Config
@@ -24,6 +25,26 @@ def linear_beta_schedule(timesteps, beta_start, beta_end):
     return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
+def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
+    """
+    Create a beta schedule that discretizes the given alpha_t_bar function,
+    which defines the cumulative product of (1-beta) over time from t = [0,1].
+    :param num_diffusion_timesteps: the number of betas to produce.
+    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
+                      produces the cumulative product of (1-beta) up to that
+                      part of the diffusion process.
+    :param max_beta: the maximum beta to use; use values lower than 1 to
+                     prevent singularities.
+    """
+    betas = []
+    for i in range(num_diffusion_timesteps):
+        t1 = i / num_diffusion_timesteps
+        t2 = (i + 1) / num_diffusion_timesteps
+        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+    return torch.tensor(betas, dtype=torch.float64)
 class GaussianDDPMScheduler(nn.Module, Config):
     config_name = SAMPLING_CONFIG_NAME
@@ -48,6 +69,12 @@ class GaussianDDPMScheduler(nn.Module, Config):
         if beta_schedule == "linear":
             betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
+        elif beta_schedule == "squaredcos_cap_v2":
+            # GLIDE cosine schedule
+            betas = betas_for_alpha_bar(
+                timesteps,
+                lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
+            )
         else:
             raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
......
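
To see what the new "squaredcos_cap_v2" branch produces, the helper and the GLIDE alpha_bar lambda can be evaluated on their own; this is only an inspection sketch, not part of the commit:

import math
import torch

def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    # Same discretization as the scheduler helper added in this commit.
    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float64)

betas = betas_for_alpha_bar(1000, lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
print(betas[0].item(), betas[-1].item())   # tiny noise at t=0, capped at max_beta near the end
print(alphas_cumprod[-1].item())           # close to 0: the final step is almost pure noise
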