"docker/diffusers-pytorch-compile-cuda/Dockerfile" did not exist on "58c6f9cb719cf6ee5fda9302801f3030c73b83a9"
Commit 1f5da520 authored by yangzhong's avatar yangzhong
Browse files

git init

parents
Pipeline #3144 failed with stages
in 0 seconds
import math
from contextlib import contextmanager
from typing import Any, Dict, List, Tuple, Union, Optional
from omegaconf import ListConfig, OmegaConf
from copy import deepcopy
import torch.nn.functional as F
from sat.helpers import print_rank0
import torch
from torch import nn
from sgm.modules import UNCONDITIONAL_CONFIG
from sgm.modules.autoencoding.temporal_ae import VideoDecoder
from sgm.modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
from sgm.util import (
default,
disabled_train,
get_obj_from_str,
instantiate_from_config,
log_txt_as_img,
)
import gc
from sat import mpu
import random
class SATVideoDiffusionEngine(nn.Module):
def __init__(self, args, **kwargs):
super().__init__()
model_config = args.model_config
# model args preprocess
log_keys = model_config.get("log_keys", None)
input_key = model_config.get("input_key", "mp4")
network_config = model_config.get("network_config", None)
network_wrapper = model_config.get("network_wrapper", None)
denoiser_config = model_config.get("denoiser_config", None)
sampler_config = model_config.get("sampler_config", None)
conditioner_config = model_config.get("conditioner_config", None)
first_stage_config = model_config.get("first_stage_config", None)
loss_fn_config = model_config.get("loss_fn_config", None)
scale_factor = model_config.get("scale_factor", 1.0)
latent_input = model_config.get("latent_input", False)
disable_first_stage_autocast = model_config.get("disable_first_stage_autocast", False)
        no_cond_log = model_config.get("no_cond_log", False)
not_trainable_prefixes = model_config.get("not_trainable_prefixes", ["first_stage_model", "conditioner"])
compile_model = model_config.get("compile_model", False)
en_and_decode_n_samples_a_time = model_config.get("en_and_decode_n_samples_a_time", None)
lr_scale = model_config.get("lr_scale", None)
lora_train = model_config.get("lora_train", False)
self.use_pd = model_config.get("use_pd", False) # progressive distillation
self.log_keys = log_keys
self.input_key = input_key
self.not_trainable_prefixes = not_trainable_prefixes
self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time
self.lr_scale = lr_scale
self.lora_train = lora_train
self.noised_image_input = model_config.get("noised_image_input", False)
self.noised_image_all_concat = model_config.get("noised_image_all_concat", False)
self.noised_image_dropout = model_config.get("noised_image_dropout", 0.0)
if args.fp16:
dtype = torch.float16
dtype_str = "fp16"
elif args.bf16:
dtype = torch.bfloat16
dtype_str = "bf16"
else:
dtype = torch.float32
dtype_str = "fp32"
self.dtype = dtype
self.dtype_str = dtype_str
network_config["params"]["dtype"] = dtype_str
model = instantiate_from_config(network_config)
self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))(
model, compile_model=compile_model, dtype=dtype
)
self.denoiser = instantiate_from_config(denoiser_config)
self.sampler = instantiate_from_config(sampler_config) if sampler_config is not None else None
self.conditioner = instantiate_from_config(default(conditioner_config, UNCONDITIONAL_CONFIG))
self._init_first_stage(first_stage_config)
self.loss_fn = instantiate_from_config(loss_fn_config) if loss_fn_config is not None else None
self.latent_input = latent_input
self.scale_factor = scale_factor
self.disable_first_stage_autocast = disable_first_stage_autocast
self.no_cond_log = no_cond_log
self.device = args.device
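    # For orientation, a schematic of the model_config block this __init__ reads
    # (key names come from the .get() calls above; the values shown are
    # placeholders, not settings shipped with this repository):
    #
    #   model_config:
    #     network_config:     {target: ..., params: {...}}
    #     denoiser_config:    {target: ..., params: {...}}
    #     sampler_config:     {target: ..., params: {...}}
    #     conditioner_config: {target: ..., params: {...}}
    #     first_stage_config: {target: ..., params: {...}}
    #     loss_fn_config:     {target: ..., params: {...}}
    #     scale_factor: 1.0                 # latent scaling, default 1.0
    #     latent_input: false
    #     disable_first_stage_autocast: false
    #     not_trainable_prefixes: ["first_stage_model", "conditioner"]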
def disable_untrainable_params(self):
total_trainable = 0
for n, p in self.named_parameters():
            if not p.requires_grad:
continue
flag = False
for prefix in self.not_trainable_prefixes:
if n.startswith(prefix) or prefix == "all":
flag = True
break
lora_prefix = ["matrix_A", "matrix_B", 'final_layer', 'proj_sr', 'local']
for prefix in lora_prefix:
if prefix in n:
flag = False
break
if flag:
p.requires_grad_(False)
else:
print(n)
total_trainable += p.numel()
print_rank0("***** Total trainable parameters: " + str(total_trainable / 1000000) + "M *****")
def reinit(self, parent_model=None):
# reload the initial params from previous trained modules
# you can also get access to other mixins through parent_model.get_mixin().
pass
def _init_first_stage(self, config):
model = instantiate_from_config(config).eval()
model.train = disabled_train
for param in model.parameters():
param.requires_grad = False
self.first_stage_model = model
def forward(self, x, hq_video, batch):
loss = self.loss_fn(self.model, self.denoiser, self.conditioner, x, batch, hq_video, self.decode_first_stage)
loss_mean = loss.mean()
loss_dict = {"loss": loss_mean}
return loss_mean, loss_dict
def shared_step(self, batch: Dict) -> Any:
x = self.get_input(batch)
if self.lr_scale is not None:
lr_x = F.interpolate(x, scale_factor=1 / self.lr_scale, mode="bilinear", align_corners=False)
lr_x = F.interpolate(lr_x, scale_factor=self.lr_scale, mode="bilinear", align_corners=False)
lr_z = self.encode_first_stage(lr_x, batch)
batch["lr_input"] = lr_z
x = x.permute(0, 2, 1, 3, 4).contiguous() # (B, T, C, H, W) -> (B, C, T, H, W)
hq_video = x # (B, C, T, H, W)
x = self.encode_first_stage(x, batch)
x = x.permute(0, 2, 1, 3, 4).contiguous() # (B, C, T, H, W) -> (B, T, C, H, W)
if 'lq' in batch.keys():
# print('LQ is NOT None')
lq = batch['lq'].to(self.dtype)
lq = lq.permute(0, 2, 1, 3, 4).contiguous()
lq = self.encode_first_stage(lq, batch)
lq = lq.permute(0, 2, 1, 3, 4).contiguous()
batch['lq'] = lq
# Uncomment for t2v training,
# batch['lq'] = None
gc.collect()
torch.cuda.empty_cache()
loss, loss_dict = self(x, hq_video, batch)
return loss, loss_dict
def get_input(self, batch):
return batch[self.input_key].to(self.dtype)
@torch.no_grad()
def decode_first_stage(self, z):
z = 1.0 / self.scale_factor * z
n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])
n_rounds = math.ceil(z.shape[0] / n_samples)
all_out = []
with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
for n in range(n_rounds):
if isinstance(self.first_stage_model.decoder, VideoDecoder):
kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])}
else:
kwargs = {}
use_cp = False
out = self.first_stage_model.decode(z[n * n_samples : (n + 1) * n_samples], **kwargs)
all_out.append(out)
out = torch.cat(all_out, dim=0)
return out
@torch.no_grad()
def encode_first_stage(self, x, batch=None):
frame = x.shape[2]
if frame > 1 and self.latent_input:
x = x.permute(0, 2, 1, 3, 4).contiguous()
return x * self.scale_factor # already encoded
use_cp = False
n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0])
n_rounds = math.ceil(x.shape[0] / n_samples)
all_out = []
with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
for n in range(n_rounds):
out = self.first_stage_model.encode(x[n * n_samples : (n + 1) * n_samples])
all_out.append(out)
z = torch.cat(all_out, dim=0)
z = self.scale_factor * z
return z
@torch.no_grad()
def sample(
self,
cond: Dict,
uc: Union[Dict, None] = None,
batch_size: int = 16,
shape: Union[None, Tuple, List] = None,
prefix=None,
concat_images=None,
**kwargs,
):
randn = torch.randn(batch_size, *shape).to(torch.float32).to(self.device)
if hasattr(self, "seeded_noise"):
randn = self.seeded_noise(randn)
if prefix is not None:
randn = torch.cat([prefix, randn[:, prefix.shape[1] :]], dim=1)
# broadcast noise
mp_size = mpu.get_model_parallel_world_size()
if mp_size > 1:
global_rank = torch.distributed.get_rank() // mp_size
src = global_rank * mp_size
torch.distributed.broadcast(randn, src=src, group=mpu.get_model_parallel_group())
scale = None
scale_emb = None
        denoiser = lambda input, sigma, c, **additional_model_inputs: self.denoiser(
            self.model, input, sigma, c, concat_images=concat_images, **additional_model_inputs
        )
samples = self.sampler(denoiser, randn, cond, uc=uc, scale=scale, scale_emb=scale_emb)
samples = samples.to(self.dtype)
return samples
@torch.no_grad()
def sample_sr(
self,
cond: Dict,
uc: Union[Dict, None] = None,
batch_size: int = 16,
shape: Union[None, Tuple, List] = None,
lq=None,
prefix=None,
concat_images=None,
**kwargs,
):
randn = torch.randn(batch_size, *shape).to(torch.float32).to(self.device)
if hasattr(self, "seeded_noise"):
randn = self.seeded_noise(randn)
if prefix is not None:
randn = torch.cat([prefix, randn[:, prefix.shape[1] :]], dim=1)
# broadcast noise
mp_size = mpu.get_model_parallel_world_size()
if mp_size > 1:
global_rank = torch.distributed.get_rank() // mp_size
src = global_rank * mp_size
torch.distributed.broadcast(randn, src=src, group=mpu.get_model_parallel_group())
scale = None
scale_emb = None
        denoiser = lambda input, sigma, c, **additional_model_inputs: self.denoiser(
            self.model, input, sigma, c, concat_images=concat_images, **additional_model_inputs
        )
# add lq condition (new)
lq = lq.to(randn.device, self.dtype)
lq = lq.permute(0, 2, 1, 3, 4).contiguous()
lq = self.encode_first_stage(lq)
lq = lq.permute(0, 2, 1, 3, 4).contiguous()
        lq = torch.cat((lq, lq), dim=0)  # duplicate along the batch dim so the conditional and unconditional CFG branches receive the same LQ latent
# For T2V
# lq = None
# print('randn shape:', randn.shape) # torch.Size([1, 8, 16, 60, 90])
# print('lq shape:', lq.shape) # torch.Size([1, 8, 16, 60, 90])
samples = self.sampler(denoiser, randn, cond, uc=uc, scale=scale, scale_emb=scale_emb, lq=lq)
samples = samples.to(self.dtype)
return samples
@torch.no_grad()
def log_conditionings(self, batch: Dict, n: int) -> Dict:
"""
Defines heuristics to log different conditionings.
These can be lists of strings (text-to-image), tensors, ints, ...
"""
image_h, image_w = batch[self.input_key].shape[3:]
log = dict()
for embedder in self.conditioner.embedders:
if ((self.log_keys is None) or (embedder.input_key in self.log_keys)) and not self.no_cond_log:
x = batch[embedder.input_key][:n]
if isinstance(x, torch.Tensor):
if x.dim() == 1:
# class-conditional, convert integer to string
x = [str(x[i].item()) for i in range(x.shape[0])]
xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4)
elif x.dim() == 2:
# size and crop cond and the like
x = ["x".join([str(xx) for xx in x[i].tolist()]) for i in range(x.shape[0])]
xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
else:
raise NotImplementedError()
elif isinstance(x, (List, ListConfig)):
if isinstance(x[0], str):
xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
else:
raise NotImplementedError()
else:
raise NotImplementedError()
log[embedder.input_key] = xc
return log
@torch.no_grad()
def log_video(
self,
batch: Dict,
N: int = 8,
ucg_keys: List[str] = None,
only_log_video_latents=False,
**kwargs,
) -> Dict:
conditioner_input_keys = [e.input_key for e in self.conditioner.embedders]
if ucg_keys:
            assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), (
                "Each defined ucg key for sampling must be in the provided conditioner input keys, "
                f"but we have {ucg_keys} vs. {conditioner_input_keys}"
            )
else:
ucg_keys = conditioner_input_keys
log = dict()
x = self.get_input(batch)
c, uc = self.conditioner.get_unconditional_conditioning(
batch,
force_uc_zero_embeddings=ucg_keys if len(self.conditioner.embedders) > 0 else [],
)
sampling_kwargs = {}
N = min(x.shape[0], N)
x = x.to(self.device)[:N]
if not self.latent_input:
log["inputs"] = x.to(torch.float32)
x = x.permute(0, 2, 1, 3, 4).contiguous()
z = self.encode_first_stage(x, batch)
if not only_log_video_latents:
log["reconstructions"] = self.decode_first_stage(z).to(torch.float32)
log["reconstructions"] = log["reconstructions"].permute(0, 2, 1, 3, 4).contiguous()
z = z.permute(0, 2, 1, 3, 4).contiguous()
log.update(self.log_conditionings(batch, N))
for k in c:
if isinstance(c[k], torch.Tensor):
c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc))
samples = self.sample(c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs) # b t c h w
samples = samples.permute(0, 2, 1, 3, 4).contiguous()
if only_log_video_latents:
latents = 1.0 / self.scale_factor * samples
log["latents"] = latents
else:
samples = self.decode_first_stage(samples).to(torch.float32)
samples = samples.permute(0, 2, 1, 3, 4).contiguous()
log["samples"] = samples
return log
#! /bin/bash
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
environs="WORLD_SIZE=1 RANK=0 LOCAL_RANK=0 LOCAL_WORLD_SIZE=1"
run_cmd="$environs python sample_sr.py --base configs/cogvideox_5b/cogvideox_5b_infer_sr.yaml"
echo ${run_cmd}
eval ${run_cmd}
echo "DONE on `hostname`"
SwissArmyTransformer==0.4.12
omegaconf==2.3.0
torch==2.4.0
torchvision==0.19.0
pytorch_lightning==2.3.3
kornia==0.7.3
beartype==0.18.5
numpy==2.0.1
fsspec==2024.5.0
safetensors==0.4.3
imageio-ffmpeg==0.5.1
imageio==2.34.2
# scipy==1.14.0
decord==0.6.0
wandb==0.17.5
deepspeed==0.14.4
import os
import math
import argparse
from typing import List, Union
from tqdm import tqdm
from omegaconf import ListConfig
import imageio
import torch
from einops import rearrange
import numpy as np
import torchvision.transforms as TT
from sat.model.base_model import get_model
from sat.training.model_io import load_checkpoint
from sat import mpu
from diffusion_video import SATVideoDiffusionEngine
from arguments import get_args
from torchvision.transforms.functional import center_crop, resize
from torchvision.transforms import InterpolationMode
from data_video import PairedCaptionDataset
from color_fix import adain_color_fix
def read_from_cli():
cnt = 0
try:
while True:
x = input("Please input English text (Ctrl-D quit): ")
yield x.strip(), cnt
cnt += 1
except EOFError as e:
pass
def read_from_file(p, rank=0, world_size=1):
with open(p, "r") as fin:
cnt = -1
for l in fin:
cnt += 1
if cnt % world_size != rank:
continue
yield l.strip(), cnt
def get_unique_embedder_keys_from_conditioner(conditioner):
return list(set([x.input_key for x in conditioner.embedders]))
def get_batch(keys, value_dict, N: Union[List, ListConfig], T=None, device="cuda"):
batch = {}
batch_uc = {}
for key in keys:
if key == "txt":
batch["txt"] = np.repeat([value_dict["prompt"]], repeats=math.prod(N)).reshape(N).tolist()
batch_uc["txt"] = np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)).reshape(N).tolist()
else:
batch[key] = value_dict[key]
if T is not None:
batch["num_video_frames"] = T
for key in batch.keys():
if key not in batch_uc and isinstance(batch[key], torch.Tensor):
batch_uc[key] = torch.clone(batch[key])
return batch, batch_uc
def save_video_as_grid_and_mp4(video_batch: torch.Tensor, save_path: str, fps: int = 5, args=None, key=None):
os.makedirs(save_path, exist_ok=True)
for i, vid in enumerate(video_batch):
gif_frames = []
for frame in vid:
frame = rearrange(frame, "c h w -> h w c")
frame = (255.0 * frame).cpu().numpy().astype(np.uint8)
gif_frames.append(frame)
now_save_path = os.path.join(save_path, f"{i:06d}.mp4")
with imageio.get_writer(now_save_path, fps=fps, quality=10) as writer:
for frame in gif_frames:
writer.append_data(frame)
def resize_for_rectangle_crop(arr, image_size, reshape_mode="random"):
if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
arr = resize(
arr,
size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
interpolation=InterpolationMode.BICUBIC,
)
else:
arr = resize(
arr,
size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
interpolation=InterpolationMode.BICUBIC,
)
h, w = arr.shape[2], arr.shape[3]
arr = arr.squeeze(0)
delta_h = h - image_size[0]
delta_w = w - image_size[1]
if reshape_mode == "random" or reshape_mode == "none":
top = np.random.randint(0, delta_h + 1)
left = np.random.randint(0, delta_w + 1)
elif reshape_mode == "center":
top, left = delta_h // 2, delta_w // 2
else:
raise NotImplementedError
arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
return arr
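# A worked example of the crop logic above (hypothetical shapes, not taken from
# the dataset): for a clip of shape (1, C, 720, 1280) and image_size=[480, 720],
# the aspect ratio 1280 / 720 exceeds 720 / 480, so the first branch resizes to
# height 480 and width int(1280 * 480 / 720) = 853; the crop then removes
# delta_w = 853 - 720 = 133 columns, yielding a 480 x 720 frame.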
def sampling_main(args, model_cls):
test_dataset = PairedCaptionDataset(data_dir='/mnt/bn/videodataset/VSR/dataset/VSRTest/cogvideox_test',
null_text_ratio=0, num_frames=25)
test_dataloader = torch.utils.data.DataLoader(
test_dataset,
num_workers=8,
batch_size=1,
shuffle=False
)
if isinstance(model_cls, type):
model = get_model(args, model_cls)
else:
model = model_cls
load_checkpoint(model, args)
model.eval()
if args.input_type == "cli":
data_iter = read_from_cli()
elif args.input_type == "txt":
rank, world_size = mpu.get_data_parallel_rank(), mpu.get_data_parallel_world_size()
print("rank and world_size", rank, world_size)
data_iter = read_from_file(args.input_file, rank=rank, world_size=world_size)
else:
raise NotImplementedError
image_size = [480, 720]
sample_func = model.sample_sr
T, H, W, C, F = args.sampling_num_frames, image_size[0], image_size[1], args.latent_channels, 8
num_samples = [1]
force_uc_zero_embeddings = ["txt"]
device = model.device
with torch.no_grad():
for step, batch in enumerate(test_dataloader):
cnt = step
gt = batch['mp4']
text = batch['txt']
lq = batch['lq']
fps = batch['fps']
# reload model on GPU
model.to(device)
print("rank:", rank, "start to process", text, cnt)
# TODO: broadcast image2video
value_dict = {
"prompt": text,
"negative_prompt": "",
"num_frames": torch.tensor(T).unsqueeze(0),
}
batch, batch_uc = get_batch(
get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, num_samples
)
for key in batch:
if isinstance(batch[key], torch.Tensor):
print(key, batch[key].shape)
elif isinstance(batch[key], list):
print(key, [len(l) for l in batch[key]])
else:
print(key, batch[key])
c, uc = model.conditioner.get_unconditional_conditioning(
batch,
batch_uc=batch_uc,
force_uc_zero_embeddings=force_uc_zero_embeddings,
)
for k in c:
if not k == "crossattn":
c[k], uc[k] = map(lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc))
for index in range(args.batch_size):
# reload model on GPU
model.to(device)
samples_z = sample_func(
c,
uc=uc,
batch_size=1,
shape=(T, C, H // F, W // F),
lq=lq,
)
samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous()
# print('max samples_z:', torch.max(samples_z)) # 3.0996
# print('min samples_z:', torch.min(samples_z)) # -3.0742
# Unload the model from GPU to save GPU memory
model.to("cpu")
torch.cuda.empty_cache()
first_stage_model = model.first_stage_model
first_stage_model = first_stage_model.to(device)
latent = 1.0 / model.scale_factor * samples_z
# Decode latent serial to save GPU memory
print('latent shape:', latent.shape)
recons = []
loop_num = (T - 1) // 2
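                # The latent is decoded in short temporal windows (latent frames
                # 0:3 first, then two frames at a time) so the VAE decoder never
                # has to hold the whole clip at once; clear_fake_cp_cache is only
                # set on the final window.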
for i in range(loop_num):
if i == 0:
start_frame, end_frame = 0, 3
else:
start_frame, end_frame = i * 2 + 1, i * 2 + 3
if i == loop_num - 1:
clear_fake_cp_cache = True
else:
clear_fake_cp_cache = False
with torch.no_grad():
recon = first_stage_model.decode(
latent[:, :, start_frame:end_frame].contiguous(), clear_fake_cp_cache=clear_fake_cp_cache
)
recons.append(recon)
recon = torch.cat(recons, dim=2).to(torch.float32)
samples_x = recon.permute(0, 2, 1, 3, 4).contiguous()
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu()
# Using color fix
samples = adain_color_fix(samples, gt) # samples,lq: (b, t, c, h, w)
save_path = os.path.join(
args.output_dir, str(cnt) + "_" + text[0].replace(" ", "_").replace("/", "")[:120]
)
save_path_gt = os.path.join(
args.output_dir, str(cnt) + "_gt_" + text[0].replace(" ", "_").replace("/", "")[:120]
)
save_path_lq = os.path.join(
args.output_dir, str(cnt) + "_lq_" + text[0].replace(" ", "_").replace("/", "")[:120]
)
if mpu.get_model_parallel_rank() == 0:
save_video_as_grid_and_mp4(samples, save_path, fps=float(fps))
# save_video_as_grid_and_mp4(torch.clamp((gt + 1.0) / 2.0, min=0.0, max=1.0).cpu(), save_path_gt, fps=float(fps))
# save_video_as_grid_and_mp4(torch.clamp((lq + 1.0) / 2.0, min=0.0, max=1.0).cpu(), save_path_lq, fps=float(fps))
if __name__ == "__main__":
if "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ:
os.environ["LOCAL_RANK"] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]
os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
py_parser = argparse.ArgumentParser(add_help=False)
known, args_list = py_parser.parse_known_args()
args = get_args(args_list)
args = argparse.Namespace(**vars(args), **vars(known))
del args.deepspeed_config
args.model_config.first_stage_config.params.cp_size = 1
args.model_config.network_config.params.transformer_args.model_parallel_size = 1
args.model_config.network_config.params.transformer_args.checkpoint_activations = False
args.model_config.loss_fn_config.params.sigma_sampler_config.params.uniform_sampling = False
sampling_main(args, model_cls=SATVideoDiffusionEngine)
from .models import AutoencodingEngine
from .util import get_configs_path, instantiate_from_config
__version__ = "0.1.0"
import numpy as np
class LambdaWarmUpCosineScheduler:
"""
note: use with a base_lr of 1.0
"""
def __init__(
self,
warm_up_steps,
lr_min,
lr_max,
lr_start,
max_decay_steps,
verbosity_interval=0,
):
self.lr_warm_up_steps = warm_up_steps
self.lr_start = lr_start
self.lr_min = lr_min
self.lr_max = lr_max
self.lr_max_decay_steps = max_decay_steps
self.last_lr = 0.0
self.verbosity_interval = verbosity_interval
def schedule(self, n, **kwargs):
if self.verbosity_interval > 0:
if n % self.verbosity_interval == 0:
print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
if n < self.lr_warm_up_steps:
lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
self.last_lr = lr
return lr
else:
t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
t = min(t, 1.0)
lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (1 + np.cos(t * np.pi))
self.last_lr = lr
return lr
def __call__(self, n, **kwargs):
return self.schedule(n, **kwargs)
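# Illustrative arithmetic for the schedule above (hypothetical settings, not
# taken from any config in this repo): with warm_up_steps=100, lr_start=0.0,
# lr_max=1.0, lr_min=0.1 and max_decay_steps=1000, step n=50 lies inside the
# warm-up, so the multiplier is (1.0 - 0.0) / 100 * 50 + 0.0 = 0.5; at n=1000,
# t = 1.0 and the cosine term vanishes, so the multiplier bottoms out at
# lr_min = 0.1.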
class LambdaWarmUpCosineScheduler2:
"""
supports repeated iterations, configurable via lists
note: use with a base_lr of 1.0.
"""
def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
self.lr_warm_up_steps = warm_up_steps
self.f_start = f_start
self.f_min = f_min
self.f_max = f_max
self.cycle_lengths = cycle_lengths
self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
self.last_f = 0.0
self.verbosity_interval = verbosity_interval
def find_in_interval(self, n):
interval = 0
for cl in self.cum_cycles[1:]:
if n <= cl:
return interval
interval += 1
def schedule(self, n, **kwargs):
cycle = self.find_in_interval(n)
n = n - self.cum_cycles[cycle]
if self.verbosity_interval > 0:
if n % self.verbosity_interval == 0:
print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " f"current cycle {cycle}")
if n < self.lr_warm_up_steps[cycle]:
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
self.last_f = f
return f
else:
t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
t = min(t, 1.0)
f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (1 + np.cos(t * np.pi))
self.last_f = f
return f
def __call__(self, n, **kwargs):
return self.schedule(n, **kwargs)
class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
def schedule(self, n, **kwargs):
cycle = self.find_in_interval(n)
n = n - self.cum_cycles[cycle]
if self.verbosity_interval > 0:
if n % self.verbosity_interval == 0:
print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " f"current cycle {cycle}")
if n < self.lr_warm_up_steps[cycle]:
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
self.last_f = f
return f
else:
f = (
self.f_min[cycle]
+ (self.f_max[cycle] - self.f_min[cycle])
* (self.cycle_lengths[cycle] - n)
/ (self.cycle_lengths[cycle])
)
self.last_f = f
return f
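# A minimal, hypothetical usage sketch (not part of the original module): the
# classes above return learning-rate *multipliers* (hence the "base_lr of 1.0"
# note), so they are typically wrapped in torch.optim.lr_scheduler.LambdaLR on
# top of the real base learning rate.
if __name__ == "__main__":
    import torch

    toy_model = torch.nn.Linear(4, 4)  # toy module, for illustration only
    optimizer = torch.optim.AdamW(toy_model.parameters(), lr=1e-4)
    # One cycle of 10_000 steps with a 500-step linear warm-up (assumed values).
    lr_lambda = LambdaLinearScheduler(
        warm_up_steps=[500], f_min=[0.0], f_max=[1.0], f_start=[1e-6], cycle_lengths=[10_000]
    )
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
    for step in range(3):
        optimizer.step()
        scheduler.step()
        print(step, scheduler.get_last_lr())  # base lr 1e-4 scaled by the warm-up multiplier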
from .autoencoder import AutoencodingEngine
from .encoders.modules import GeneralConditioner
UNCONDITIONAL_CONFIG = {
"target": "sgm.modules.GeneralConditioner",
"params": {"emb_models": []},
}
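# When no conditioner_config is supplied, SATVideoDiffusionEngine falls back to
# this config: instantiate_from_config(UNCONDITIONAL_CONFIG) builds a
# GeneralConditioner with an empty embedder list, i.e. a fully unconditional model.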