Commit 17c574a1 authored by Patrick von Platen's avatar Patrick von Platen

remove torchvision dependency

parent f84bbd35
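
Note: besides dropping torchvision from setup.py and the dependency table, this commit deletes the example Dataset/EMA/Trainer utilities that lived next to UNetModel, which were the only code paths importing torchvision.transforms and torchvision.utils. The commit does not ship a replacement; if the deleted preprocessing is still needed without torchvision, a rough PIL/torch-only sketch could look like the following (the helper name load_and_transform and the exact resampling choices are illustrative assumptions, not part of this commit):

import random

import numpy as np
import torch
from PIL import Image


def load_and_transform(path, image_size):
    # Hypothetical stand-in for the deleted torchvision Compose pipeline
    # (Resize -> RandomHorizontalFlip -> CenterCrop -> ToTensor).
    img = Image.open(path).convert("RGB")
    # Resize the shorter side to image_size, like transforms.Resize(image_size)
    w, h = img.size
    scale = image_size / min(w, h)
    img = img.resize((round(w * scale), round(h * scale)), Image.BILINEAR)
    # Random horizontal flip, like transforms.RandomHorizontalFlip()
    if random.random() < 0.5:
        img = img.transpose(Image.FLIP_LEFT_RIGHT)
    # Center crop to image_size x image_size, like transforms.CenterCrop(image_size)
    w, h = img.size
    left, top = (w - image_size) // 2, (h - image_size) // 2
    img = img.crop((left, top, left + image_size, top + image_size))
    # Scale to [0, 1] float and move channels first, like transforms.ToTensor()
    arr = np.asarray(img, dtype=np.float32) / 255.0
    return torch.from_numpy(arr).permute(2, 0, 1)

Pillow stays in install_requires in the diff below, so a helper like this would not reintroduce the removed dependency.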
@@ -144,9 +144,11 @@ if __name__ == "__main__":
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
-        help="Whether to use mixed precision. Choose"
-        "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
-        "and an Nvidia Ampere GPU.",
+        help=(
+            "Whether to use mixed precision. Choose"
+            "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
+            "and an Nvidia Ampere GPU."
+        ),
    )
    args = parser.parse_args()
...
@@ -87,7 +87,6 @@ _deps = [
    "regex!=2019.12.17",
    "requests",
    "torch>=1.4",
-    "torchvision",
]
# this is a lookup table with items like:
@@ -172,7 +171,6 @@ install_requires = [
    deps["regex"],
    deps["requests"],
    deps["torch"],
-    deps["torchvision"],
    deps["Pillow"],
]
...
@@ -6,10 +6,10 @@ __version__ = "0.0.3"
from .modeling_utils import ModelMixin
from .models.unet import UNetModel
-from .models.unet_glide import GLIDEUNetModel, GLIDESuperResUNetModel, GLIDETextToImageUNetModel
-from .models.unet_ldm import UNetLDMModel
+from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel, GLIDEUNetModel
from .models.unet_grad_tts import UNetGradTTSModel
+from .models.unet_ldm import UNetLDMModel
from .pipeline_utils import DiffusionPipeline
-from .pipelines import DDIM, DDPM, GLIDE, LatentDiffusion, PNDM, BDDM
+from .pipelines import BDDM, DDIM, DDPM, GLIDE, PNDM, LatentDiffusion
-from .schedulers import DDIMScheduler, DDPMScheduler, SchedulerMixin, PNDMScheduler
+from .schedulers import DDIMScheduler, DDPMScheduler, PNDMScheduler, SchedulerMixin
from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler
@@ -226,7 +226,7 @@ class ConfigMixin:
        return json.loads(text)

    def __repr__(self):
        return f"{self.__class__.__name__} {self.to_json_string()}"

    @property
    def config(self) -> Dict[str, Any]:
...
@@ -13,5 +13,4 @@ deps = {
    "regex": "regex!=2019.12.17",
    "requests": "requests",
    "torch": "torch>=1.4",
-    "torchvision": "torchvision",
}
@@ -17,6 +17,6 @@
# limitations under the License.
from .unet import UNetModel
-from .unet_glide import GLIDEUNetModel, GLIDESuperResUNetModel, GLIDETextToImageUNetModel
+from .unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel, GLIDEUNetModel
+from .unet_grad_tts import UNetGradTTSModel
from .unet_ldm import UNetLDMModel
-from .unet_grad_tts import UNetGradTTSModel
\ No newline at end of file
@@ -26,7 +26,6 @@ from torch.optim import Adam
from torch.utils import data
from PIL import Image
-from torchvision import transforms, utils
from tqdm import tqdm

from ..configuration_utils import ConfigMixin
@@ -331,171 +330,3 @@ class UNetModel(ModelMixin, ConfigMixin):
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
# dataset classes
class Dataset(data.Dataset):
def __init__(self, folder, image_size, exts=["jpg", "jpeg", "png"]):
super().__init__()
self.folder = folder
self.image_size = image_size
self.paths = [p for ext in exts for p in Path(f"{folder}").glob(f"**/*.{ext}")]
self.transform = transforms.Compose(
[
transforms.Resize(image_size),
transforms.RandomHorizontalFlip(),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
]
)
def __len__(self):
return len(self.paths)
def __getitem__(self, index):
path = self.paths[index]
img = Image.open(path)
return self.transform(img)
# trainer class
class EMA:
def __init__(self, beta):
super().__init__()
self.beta = beta
def update_model_average(self, ma_model, current_model):
for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
old_weight, up_weight = ma_params.data, current_params.data
ma_params.data = self.update_average(old_weight, up_weight)
def update_average(self, old, new):
if old is None:
return new
return old * self.beta + (1 - self.beta) * new
def cycle(dl):
while True:
for data_dl in dl:
yield data_dl
def num_to_groups(num, divisor):
groups = num // divisor
remainder = num % divisor
arr = [divisor] * groups
if remainder > 0:
arr.append(remainder)
return arr
class Trainer(object):
def __init__(
self,
diffusion_model,
folder,
*,
ema_decay=0.995,
image_size=128,
train_batch_size=32,
train_lr=1e-4,
train_num_steps=100000,
gradient_accumulate_every=2,
amp=False,
step_start_ema=2000,
update_ema_every=10,
save_and_sample_every=1000,
results_folder="./results",
):
super().__init__()
self.model = diffusion_model
self.ema = EMA(ema_decay)
self.ema_model = copy.deepcopy(self.model)
self.update_ema_every = update_ema_every
self.step_start_ema = step_start_ema
self.save_and_sample_every = save_and_sample_every
self.batch_size = train_batch_size
self.image_size = diffusion_model.image_size
self.gradient_accumulate_every = gradient_accumulate_every
self.train_num_steps = train_num_steps
self.ds = Dataset(folder, image_size)
self.dl = cycle(data.DataLoader(self.ds, batch_size=train_batch_size, shuffle=True, pin_memory=True))
self.opt = Adam(diffusion_model.parameters(), lr=train_lr)
self.step = 0
self.amp = amp
self.scaler = GradScaler(enabled=amp)
self.results_folder = Path(results_folder)
self.results_folder.mkdir(exist_ok=True)
self.reset_parameters()
def reset_parameters(self):
self.ema_model.load_state_dict(self.model.state_dict())
def step_ema(self):
if self.step < self.step_start_ema:
self.reset_parameters()
return
self.ema.update_model_average(self.ema_model, self.model)
def save(self, milestone):
data = {
"step": self.step,
"model": self.model.state_dict(),
"ema": self.ema_model.state_dict(),
"scaler": self.scaler.state_dict(),
}
torch.save(data, str(self.results_folder / f"model-{milestone}.pt"))
def load(self, milestone):
data = torch.load(str(self.results_folder / f"model-{milestone}.pt"))
self.step = data["step"]
self.model.load_state_dict(data["model"])
self.ema_model.load_state_dict(data["ema"])
self.scaler.load_state_dict(data["scaler"])
def train(self):
with tqdm(initial=self.step, total=self.train_num_steps) as pbar:
while self.step < self.train_num_steps:
for i in range(self.gradient_accumulate_every):
data = next(self.dl).cuda()
with autocast(enabled=self.amp):
loss = self.model(data)
self.scaler.scale(loss / self.gradient_accumulate_every).backward()
pbar.set_description(f"loss: {loss.item():.4f}")
self.scaler.step(self.opt)
self.scaler.update()
self.opt.zero_grad()
if self.step % self.update_ema_every == 0:
self.step_ema()
if self.step != 0 and self.step % self.save_and_sample_every == 0:
self.ema_model.eval()
milestone = self.step // self.save_and_sample_every
batches = num_to_groups(36, self.batch_size)
all_images_list = list(map(lambda n: self.ema_model.sample(batch_size=n), batches))
all_images = torch.cat(all_images_list, dim=0)
utils.save_image(all_images, str(self.results_folder / f"sample-{milestone}.png"), nrow=6)
self.save(milestone)
self.step += 1
pbar.update(1)
print("training complete")
@@ -2,6 +2,7 @@ import math
import torch

try:
    from einops import rearrange, repeat
except:
@@ -11,6 +12,7 @@ except:
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin


class Mish(torch.nn.Module):
    def forward(self, x):
        return x * torch.tanh(torch.nn.functional.softplus(x))
@@ -47,9 +49,9 @@ class Rezero(torch.nn.Module):
class Block(torch.nn.Module):
    def __init__(self, dim, dim_out, groups=8):
        super(Block, self).__init__()
-        self.block = torch.nn.Sequential(torch.nn.Conv2d(dim, dim_out, 3,
-            padding=1), torch.nn.GroupNorm(
-            groups, dim_out), Mish())
+        self.block = torch.nn.Sequential(
+            torch.nn.Conv2d(dim, dim_out, 3, padding=1), torch.nn.GroupNorm(groups, dim_out), Mish()
+        )

    def forward(self, x, mask):
        output = self.block(x * mask)
@@ -59,8 +61,7 @@ class Block(torch.nn.Module):
class ResnetBlock(torch.nn.Module):
    def __init__(self, dim, dim_out, time_emb_dim, groups=8):
        super(ResnetBlock, self).__init__()
-        self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim,
-            dim_out))
+        self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, dim_out))
        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
@@ -83,18 +84,16 @@ class LinearAttention(torch.nn.Module):
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = torch.nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = torch.nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x)
-        q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)',
-            heads = self.heads, qkv=3)
+        q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
        k = k.softmax(dim=-1)
-        context = torch.einsum('bhdn,bhen->bhde', k, v)
-        out = torch.einsum('bhde,bhdn->bhen', context, q)
-        out = rearrange(out, 'b heads c (h w) -> b (heads c) h w',
-            heads=self.heads, h=h, w=w)
+        context = torch.einsum("bhdn,bhen->bhde", k, v)
+        out = torch.einsum("bhde,bhdn->bhen", context, q)
+        out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
        return self.to_out(out)
@@ -124,16 +123,7 @@ class SinusoidalPosEmb(torch.nn.Module):
class UNetGradTTSModel(ModelMixin, ConfigMixin):
-    def __init__(
-        self,
-        dim,
-        dim_mults=(1, 2, 4),
-        groups=8,
-        n_spks=None,
-        spk_emb_dim=64,
-        n_feats=80,
-        pe_scale=1000
-    ):
+    def __init__(self, dim, dim_mults=(1, 2, 4), groups=8, n_spks=None, spk_emb_dim=64, n_feats=80, pe_scale=1000):
        super(UNetGradTTSModel, self).__init__()
        self.register(
@@ -143,22 +133,22 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
            n_spks=n_spks,
            spk_emb_dim=spk_emb_dim,
            n_feats=n_feats,
-            pe_scale=pe_scale
+            pe_scale=pe_scale,
        )
        self.dim = dim
        self.dim_mults = dim_mults
        self.groups = groups
        self.n_spks = n_spks if not isinstance(n_spks, type(None)) else 1
        self.spk_emb_dim = spk_emb_dim
        self.pe_scale = pe_scale

        if n_spks > 1:
-            self.spk_mlp = torch.nn.Sequential(torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(),
-                torch.nn.Linear(spk_emb_dim * 4, n_feats))
+            self.spk_mlp = torch.nn.Sequential(
+                torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(), torch.nn.Linear(spk_emb_dim * 4, n_feats)
+            )
        self.time_pos_emb = SinusoidalPosEmb(dim)
-        self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(),
-            torch.nn.Linear(dim * 4, dim))
+        self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(), torch.nn.Linear(dim * 4, dim))

        dims = [2 + (1 if n_spks > 1 else 0), *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))
@@ -168,11 +158,16 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)
-            self.downs.append(torch.nn.ModuleList([
-                ResnetBlock(dim_in, dim_out, time_emb_dim=dim),
-                ResnetBlock(dim_out, dim_out, time_emb_dim=dim),
-                Residual(Rezero(LinearAttention(dim_out))),
-                Downsample(dim_out) if not is_last else torch.nn.Identity()]))
+            self.downs.append(
+                torch.nn.ModuleList(
+                    [
+                        ResnetBlock(dim_in, dim_out, time_emb_dim=dim),
+                        ResnetBlock(dim_out, dim_out, time_emb_dim=dim),
+                        Residual(Rezero(LinearAttention(dim_out))),
+                        Downsample(dim_out) if not is_last else torch.nn.Identity(),
+                    ]
+                )
+            )

        mid_dim = dims[-1]
        self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)
@@ -180,18 +175,23 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
        self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
-            self.ups.append(torch.nn.ModuleList([
-                ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim),
-                ResnetBlock(dim_in, dim_in, time_emb_dim=dim),
-                Residual(Rezero(LinearAttention(dim_in))),
-                Upsample(dim_in)]))
+            self.ups.append(
+                torch.nn.ModuleList(
+                    [
+                        ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim),
+                        ResnetBlock(dim_in, dim_in, time_emb_dim=dim),
+                        Residual(Rezero(LinearAttention(dim_in))),
+                        Upsample(dim_in),
+                    ]
+                )
+            )

        self.final_block = Block(dim, dim)
        self.final_conv = torch.nn.Conv2d(dim, 1, 1)

    def forward(self, x, mask, mu, t, spk=None):
        if not isinstance(spk, type(None)):
            s = self.spk_mlp(spk)

        t = self.time_pos_emb(t, scale=self.pe_scale)
        t = self.mlp(t)
@@ -230,4 +230,4 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
        x = self.final_block(x, mask)
        output = self.final_conv(x * mask)
        return (output * mask).squeeze(1)
\ No newline at end of file
@@ -57,14 +57,14 @@ class DiffusionPipeline(ConfigMixin):
    def register_modules(self, **kwargs):
        # import it here to avoid circular import
        from diffusers import pipelines

        for name, module in kwargs.items():
            # check if the module is a pipeline module
            is_pipeline_module = hasattr(pipelines, module.__module__.split(".")[-1])
            # retrive library
            library = module.__module__.split(".")[0]
            # if library is not in LOADABLE_CLASSES, then it is a custom module.
            # Or if it's a pipeline module, then the module is inside the pipeline
            # so we set the library to module name.
@@ -160,10 +160,10 @@ class DiffusionPipeline(ConfigMixin):
        init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)

        init_kwargs = {}

        # import it here to avoid circular import
        from diffusers import pipelines

        # 4. Load each module in the pipeline
        for name, (library_name, class_name) in init_dict.items():
            is_pipeline_module = hasattr(pipelines, library_name)
...
+from .pipeline_bddm import BDDM
from .pipeline_ddim import DDIM
from .pipeline_ddpm import DDPM
-from .pipeline_pndm import PNDM
from .pipeline_glide import GLIDE
from .pipeline_latent_diffusion import LatentDiffusion
-from .pipeline_bddm import BDDM
+from .pipeline_pndm import PNDM
@@ -283,7 +283,7 @@ class BDDM(DiffusionPipeline):
        torch_device = "cuda" if torch.cuda.is_available() else "cpu"
        self.diffwave.to(torch_device)
        mel_spectrogram = mel_spectrogram.to(torch_device)
        audio_length = mel_spectrogram.size(-1) * 256
        audio_size = (1, 1, audio_length)
...
@@ -832,9 +832,7 @@ class GLIDE(DiffusionPipeline):
        # 1. Sample gaussian noise
        batch_size = 2  # second image is empty for classifier-free guidance
-        image = torch.randn(
-            (batch_size, self.text_unet.in_channels, 64, 64), generator=generator
-        ).to(torch_device)
+        image = torch.randn((batch_size, self.text_unet.in_channels, 64, 64), generator=generator).to(torch_device)

        # 2. Encode tokens
        # an empty input is needed to guide the model away from it
...
@@ -39,14 +39,13 @@ def generate_path(duration, mask):
    cum_duration_flat = cum_duration.view(b * t_x)
    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
    path = path.view(b, t_x, t_y)
-    path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0],
-        [1, 0], [0, 0]]))[:, :-1]
+    path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
    path = path * mask
    return path


def duration_loss(logw, logw_, lengths):
-    loss = torch.sum((logw - logw_)**2) / torch.sum(lengths)
+    loss = torch.sum((logw - logw_) ** 2) / torch.sum(lengths)
    return loss
@@ -62,7 +61,7 @@ class LayerNorm(nn.Module):
    def forward(self, x):
        n_dims = len(x.shape)
        mean = torch.mean(x, 1, keepdim=True)
-        variance = torch.mean((x - mean)**2, 1, keepdim=True)
+        variance = torch.mean((x - mean) ** 2, 1, keepdim=True)

        x = (x - mean) * torch.rsqrt(variance + self.eps)
@@ -72,8 +71,7 @@ class LayerNorm(nn.Module):
class ConvReluNorm(nn.Module):
-    def __init__(self, in_channels, hidden_channels, out_channels, kernel_size,
-        n_layers, p_dropout):
+    def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
        super(ConvReluNorm, self).__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
@@ -84,13 +82,13 @@ class ConvReluNorm(nn.Module):
        self.conv_layers = torch.nn.ModuleList()
        self.norm_layers = torch.nn.ModuleList()
-        self.conv_layers.append(torch.nn.Conv1d(in_channels, hidden_channels,
-            kernel_size, padding=kernel_size//2))
+        self.conv_layers.append(torch.nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
        self.norm_layers.append(LayerNorm(hidden_channels))
        self.relu_drop = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Dropout(p_dropout))
        for _ in range(n_layers - 1):
-            self.conv_layers.append(torch.nn.Conv1d(hidden_channels, hidden_channels,
-                kernel_size, padding=kernel_size//2))
+            self.conv_layers.append(
+                torch.nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2)
+            )
            self.norm_layers.append(LayerNorm(hidden_channels))
        self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1)
        self.proj.weight.data.zero_()
@@ -114,11 +112,9 @@ class DurationPredictor(nn.Module):
        self.p_dropout = p_dropout

        self.drop = torch.nn.Dropout(p_dropout)
-        self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels,
-            kernel_size, padding=kernel_size//2)
+        self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
        self.norm_1 = LayerNorm(filter_channels)
-        self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels,
-            kernel_size, padding=kernel_size//2)
+        self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
        self.norm_2 = LayerNorm(filter_channels)
        self.proj = torch.nn.Conv1d(filter_channels, 1, 1)
@@ -136,9 +132,17 @@ class DurationPredictor(nn.Module):
class MultiHeadAttention(nn.Module):
-    def __init__(self, channels, out_channels, n_heads, window_size=None,
-        heads_share=True, p_dropout=0.0, proximal_bias=False,
-        proximal_init=False):
+    def __init__(
+        self,
+        channels,
+        out_channels,
+        n_heads,
+        window_size=None,
+        heads_share=True,
+        p_dropout=0.0,
+        proximal_bias=False,
+        proximal_init=False,
+    ):
        super(MultiHeadAttention, self).__init__()
        assert channels % n_heads == 0
@@ -158,10 +162,12 @@ class MultiHeadAttention(nn.Module):
        if window_size is not None:
            n_heads_rel = 1 if heads_share else n_heads
            rel_stddev = self.k_channels**-0.5
-            self.emb_rel_k = torch.nn.Parameter(torch.randn(n_heads_rel,
-                window_size * 2 + 1, self.k_channels) * rel_stddev)
-            self.emb_rel_v = torch.nn.Parameter(torch.randn(n_heads_rel,
-                window_size * 2 + 1, self.k_channels) * rel_stddev)
+            self.emb_rel_k = torch.nn.Parameter(
+                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev
+            )
+            self.emb_rel_v = torch.nn.Parameter(
+                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev
+            )

        self.conv_o = torch.nn.Conv1d(channels, out_channels, 1)
        self.drop = torch.nn.Dropout(p_dropout)
@@ -171,12 +177,12 @@ class MultiHeadAttention(nn.Module):
            self.conv_k.weight.data.copy_(self.conv_q.weight.data)
            self.conv_k.bias.data.copy_(self.conv_q.bias.data)
        torch.nn.init.xavier_uniform_(self.conv_v.weight)

    def forward(self, x, c, attn_mask=None):
        q = self.conv_q(x)
        k = self.conv_k(c)
        v = self.conv_v(c)

        x, self.attn = self.attention(q, k, v, mask=attn_mask)

        x = self.conv_o(x)
@@ -198,8 +204,7 @@ class MultiHeadAttention(nn.Module):
            scores = scores + scores_local
        if self.proximal_bias:
            assert t_s == t_t, "Proximal bias is only available for self-attention."
-            scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device,
-                dtype=scores.dtype)
+            scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e4)
        p_attn = torch.nn.functional.softmax(scores, dim=-1)
@@ -208,8 +213,7 @@ class MultiHeadAttention(nn.Module):
        if self.window_size is not None:
            relative_weights = self._absolute_position_to_relative_position(p_attn)
            value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
-            output = output + self._matmul_with_relative_values(relative_weights,
-                value_relative_embeddings)
+            output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
        output = output.transpose(2, 3).contiguous().view(b, d, t_t)
        return output, p_attn
@@ -227,28 +231,27 @@ class MultiHeadAttention(nn.Module):
        slice_end_position = slice_start_position + 2 * length - 1
        if pad_length > 0:
            padded_relative_embeddings = torch.nn.functional.pad(
-                relative_embeddings, convert_pad_shape([[0, 0],
-                [pad_length, pad_length], [0, 0]]))
+                relative_embeddings, convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]])
+            )
        else:
            padded_relative_embeddings = relative_embeddings
-        used_relative_embeddings = padded_relative_embeddings[:,
-            slice_start_position:slice_end_position]
+        used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position]
        return used_relative_embeddings

    def _relative_position_to_absolute_position(self, x):
        batch, heads, length, _ = x.size()
-        x = torch.nn.functional.pad(x, convert_pad_shape([[0,0],[0,0],[0,0],[0,1]]))
+        x = torch.nn.functional.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
        x_flat = x.view([batch, heads, length * 2 * length])
-        x_flat = torch.nn.functional.pad(x_flat, convert_pad_shape([[0,0],[0,0],[0,length-1]]))
+        x_flat = torch.nn.functional.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
-        x_final = x_flat.view([batch, heads, length+1, 2*length-1])[:, :, :length, length-1:]
+        x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1 :]
        return x_final

    def _absolute_position_to_relative_position(self, x):
        batch, heads, length, _ = x.size()
-        x = torch.nn.functional.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length-1]]))
+        x = torch.nn.functional.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
-        x_flat = x.view([batch, heads, length**2 + length*(length - 1)])
+        x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
        x_flat = torch.nn.functional.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
-        x_final = x_flat.view([batch, heads, length, 2*length])[:,:,:,1:]
+        x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
        return x_final

    def _attention_bias_proximal(self, length):
@@ -258,8 +261,7 @@ class MultiHeadAttention(nn.Module):
class FFN(nn.Module):
-    def __init__(self, in_channels, out_channels, filter_channels, kernel_size,
-        p_dropout=0.0):
+    def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0):
        super(FFN, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
@@ -267,10 +269,8 @@ class FFN(nn.Module):
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout

-        self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size,
-            padding=kernel_size//2)
-        self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size,
-            padding=kernel_size//2)
+        self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
+        self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size, padding=kernel_size // 2)
        self.drop = torch.nn.Dropout(p_dropout)

    def forward(self, x, x_mask):
@@ -282,8 +282,17 @@ class FFN(nn.Module):
class Encoder(nn.Module):
-    def __init__(self, hidden_channels, filter_channels, n_heads, n_layers,
-        kernel_size=1, p_dropout=0.0, window_size=None, **kwargs):
+    def __init__(
+        self,
+        hidden_channels,
+        filter_channels,
+        n_heads,
+        n_layers,
+        kernel_size=1,
+        p_dropout=0.0,
+        window_size=None,
+        **kwargs,
+    ):
        super(Encoder, self).__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
@@ -299,11 +308,15 @@ class Encoder(nn.Module):
        self.ffn_layers = torch.nn.ModuleList()
        self.norm_layers_2 = torch.nn.ModuleList()
        for _ in range(self.n_layers):
-            self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels,
-                n_heads, window_size=window_size, p_dropout=p_dropout))
+            self.attn_layers.append(
+                MultiHeadAttention(
+                    hidden_channels, hidden_channels, n_heads, window_size=window_size, p_dropout=p_dropout
+                )
+            )
            self.norm_layers_1.append(LayerNorm(hidden_channels))
-            self.ffn_layers.append(FFN(hidden_channels, hidden_channels,
-                filter_channels, kernel_size, p_dropout=p_dropout))
+            self.ffn_layers.append(
+                FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout)
+            )
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask):
@@ -321,9 +334,21 @@ class Encoder(nn.Module):
class TextEncoder(ModelMixin, ConfigMixin):
-    def __init__(self, n_vocab, n_feats, n_channels, filter_channels,
-        filter_channels_dp, n_heads, n_layers, kernel_size,
-        p_dropout, window_size=None, spk_emb_dim=64, n_spks=1):
+    def __init__(
+        self,
+        n_vocab,
+        n_feats,
+        n_channels,
+        filter_channels,
+        filter_channels_dp,
+        n_heads,
+        n_layers,
+        kernel_size,
+        p_dropout,
+        window_size=None,
+        spk_emb_dim=64,
+        n_spks=1,
+    ):
        super(TextEncoder, self).__init__()
        self.register(
@@ -338,10 +363,9 @@ class TextEncoder(ModelMixin, ConfigMixin):
            p_dropout=p_dropout,
            window_size=window_size,
            spk_emb_dim=spk_emb_dim,
-            n_spks=n_spks
+            n_spks=n_spks,
        )

        self.n_vocab = n_vocab
        self.n_feats = n_feats
        self.n_channels = n_channels
@@ -358,15 +382,22 @@ class TextEncoder(ModelMixin, ConfigMixin):
        self.emb = torch.nn.Embedding(n_vocab, n_channels)
        torch.nn.init.normal_(self.emb.weight, 0.0, n_channels**-0.5)

-        self.prenet = ConvReluNorm(n_channels, n_channels, n_channels,
-            kernel_size=5, n_layers=3, p_dropout=0.5)
+        self.prenet = ConvReluNorm(n_channels, n_channels, n_channels, kernel_size=5, n_layers=3, p_dropout=0.5)

-        self.encoder = Encoder(n_channels + (spk_emb_dim if n_spks > 1 else 0), filter_channels, n_heads, n_layers,
-            kernel_size, p_dropout, window_size=window_size)
+        self.encoder = Encoder(
+            n_channels + (spk_emb_dim if n_spks > 1 else 0),
+            filter_channels,
+            n_heads,
+            n_layers,
+            kernel_size,
+            p_dropout,
+            window_size=window_size,
+        )

        self.proj_m = torch.nn.Conv1d(n_channels + (spk_emb_dim if n_spks > 1 else 0), n_feats, 1)
-        self.proj_w = DurationPredictor(n_channels + (spk_emb_dim if n_spks > 1 else 0), filter_channels_dp,
-            kernel_size, p_dropout)
+        self.proj_w = DurationPredictor(
+            n_channels + (spk_emb_dim if n_spks > 1 else 0), filter_channels_dp, kernel_size, p_dropout
+        )

    def forward(self, x, x_lengths, spk=None):
        x = self.emb(x) * math.sqrt(self.n_channels)
...
@@ -44,7 +44,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
            clip_predicted_image=clip_predicted_image,
        )
        self.timesteps = int(timesteps)
        self.timestep_values = timestep_values  # save the fixed timestep values for BDDM
        self.clip_image = clip_predicted_image
        self.variance_type = variance_type
...
@@ -84,7 +84,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
        inference_step_times = list(range(0, self.timesteps, self.timesteps // num_inference_steps))

-        warmup_time_steps = np.array(inference_step_times[-self.pndm_order:]).repeat(2) + np.tile(np.array([0, self.timesteps // num_inference_steps // 2]), self.pndm_order)
+        warmup_time_steps = np.array(inference_step_times[-self.pndm_order :]).repeat(2) + np.tile(
+            np.array([0, self.timesteps // num_inference_steps // 2]), self.pndm_order
+        )
        self.warmup_time_steps[num_inference_steps] = list(reversed(warmup_time_steps[:-1].repeat(2)[1:-1]))

        return self.warmup_time_steps[num_inference_steps]
@@ -137,7 +139,10 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
        at = alphas_cump[t + 1].view(-1, 1, 1, 1)
        at_next = alphas_cump[t_next + 1].view(-1, 1, 1, 1)

-        x_delta = (at_next - at) * ((1 / (at.sqrt() * (at.sqrt() + at_next.sqrt()))) * x - 1 / (at.sqrt() * (((1 - at_next) * at).sqrt() + ((1 - at) * at_next).sqrt())) * et)
+        x_delta = (at_next - at) * (
+            (1 / (at.sqrt() * (at.sqrt() + at_next.sqrt()))) * x
+            - 1 / (at.sqrt() * (((1 - at_next) * at).sqrt() + ((1 - at) * at_next).sqrt())) * et
+        )

        x_next = x + x_delta
        return x_next
...
@@ -19,7 +19,18 @@ import unittest
import torch

-from diffusers import DDIM, DDPM, PNDM, GLIDE, BDDM, DDIMScheduler, DDPMScheduler, LatentDiffusion, PNDMScheduler, UNetModel
+from diffusers import (
+    BDDM,
+    DDIM,
+    DDPM,
+    GLIDE,
+    PNDM,
+    DDIMScheduler,
+    DDPMScheduler,
+    LatentDiffusion,
+    PNDMScheduler,
+    UNetModel,
+)
from diffusers.configuration_utils import ConfigMixin
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.pipeline_bddm import DiffWave
...