Commit 1ad55bb4 authored by mashun1

i2vgen-xl
import math
import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR
__all__ = ['Adafactor']
class Adafactor(Optimizer):
"""
    AdaFactor PyTorch implementation that can be used as a drop-in replacement for Adam. Original fairseq code:
    https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py
    Paper: *Adafactor: Adaptive Learning Rates with Sublinear Memory Cost* (https://arxiv.org/abs/1804.04235). Note that
this optimizer internally adjusts the learning rate depending on the `scale_parameter`, `relative_step` and
`warmup_init` options. To use a manual (external) learning rate schedule you should set `scale_parameter=False` and
`relative_step=False`.
Arguments:
params (`Iterable[nn.parameter.Parameter]`):
Iterable of parameters to optimize or dictionaries defining parameter groups.
lr (`float`, *optional*):
The external learning rate.
eps (`Tuple[float, float]`, *optional*, defaults to (1e-30, 1e-3)):
            Regularization constants for the squared gradient and the parameter scale, respectively.
        clip_threshold (`float`, *optional*, defaults to 1.0):
            Threshold of the root mean square of the final gradient update.
decay_rate (`float`, *optional*, defaults to -0.8):
            Coefficient used to compute running averages of the squared gradient.
beta1 (`float`, *optional*):
            Coefficient used for computing running averages of the gradient.
weight_decay (`float`, *optional*, defaults to 0):
Weight decay (L2 penalty)
scale_parameter (`bool`, *optional*, defaults to `True`):
If True, learning rate is scaled by root mean square
relative_step (`bool`, *optional*, defaults to `True`):
If True, time-dependent learning rate is computed instead of external learning rate
warmup_init (`bool`, *optional*, defaults to `False`):
Time-dependent learning rate computation depends on whether warm-up initialization is being used
    This implementation handles low-precision (FP16, bfloat16) values, but it has not been thoroughly tested.
Recommended T5 finetuning settings (https://discuss.huggingface.co/t/t5-finetuning-tips/684/3):
- Training without LR warmup or clip_threshold is not recommended.
    - Use scheduled LR warm-up to a fixed LR
- use clip_threshold=1.0 (https://arxiv.org/abs/1804.04235)
- Disable relative updates
- Use scale_parameter=False
- Additional optimizer operations like gradient clipping should not be used alongside Adafactor
Example:
```python
Adafactor(model.parameters(), scale_parameter=False, relative_step=False, warmup_init=False, lr=1e-3)
```
Others reported the following combination to work well:
```python
Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
```
    When using `lr=None` with [`Trainer`] you will most likely need to use the [`~optimization.AdafactorSchedule`]
    scheduler as follows:
```python
from transformers.optimization import Adafactor, AdafactorSchedule
optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
lr_scheduler = AdafactorSchedule(optimizer)
trainer = Trainer(..., optimizers=(optimizer, lr_scheduler))
```
Usage:
```python
# replace AdamW with Adafactor
optimizer = Adafactor(
model.parameters(),
lr=1e-3,
eps=(1e-30, 1e-3),
clip_threshold=1.0,
decay_rate=-0.8,
beta1=None,
weight_decay=0.0,
relative_step=False,
scale_parameter=False,
warmup_init=False,
)
```"""
def __init__(
self,
params,
lr=None,
eps=(1e-30, 1e-3),
clip_threshold=1.0,
decay_rate=-0.8,
beta1=None,
weight_decay=0.0,
scale_parameter=True,
relative_step=True,
warmup_init=False,
):
r"""require_version("torch>=1.5.0") # add_ with alpha
"""
if lr is not None and relative_step:
raise ValueError("Cannot combine manual `lr` and `relative_step=True` options")
if warmup_init and not relative_step:
raise ValueError("`warmup_init=True` requires `relative_step=True`")
defaults = dict(
lr=lr,
eps=eps,
clip_threshold=clip_threshold,
decay_rate=decay_rate,
beta1=beta1,
weight_decay=weight_decay,
scale_parameter=scale_parameter,
relative_step=relative_step,
warmup_init=warmup_init,
)
super().__init__(params, defaults)
@staticmethod
def _get_lr(param_group, param_state):
rel_step_sz = param_group["lr"]
if param_group["relative_step"]:
min_step = 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2
rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"]))
param_scale = 1.0
if param_group["scale_parameter"]:
param_scale = max(param_group["eps"][1], param_state["RMS"])
return param_scale * rel_step_sz
@staticmethod
def _get_options(param_group, param_shape):
factored = len(param_shape) >= 2
use_first_moment = param_group["beta1"] is not None
return factored, use_first_moment
@staticmethod
def _rms(tensor):
return tensor.norm(2) / (tensor.numel() ** 0.5)
@staticmethod
def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
# copy from fairseq's adafactor implementation:
# https://github.com/huggingface/transformers/blob/8395f14de6068012787d83989c3627c3df6a252b/src/transformers/optimization.py#L505
r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
return torch.mul(r_factor, c_factor)
def step(self, closure=None):
"""
Performs a single optimization step
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad.data
if grad.dtype in {torch.float16, torch.bfloat16}:
grad = grad.float()
if grad.is_sparse:
raise RuntimeError("Adafactor does not support sparse gradients.")
state = self.state[p]
grad_shape = grad.shape
factored, use_first_moment = self._get_options(group, grad_shape)
# State Initialization
if len(state) == 0:
state["step"] = 0
if use_first_moment:
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(grad)
if factored:
state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad)
state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
else:
state["exp_avg_sq"] = torch.zeros_like(grad)
state["RMS"] = 0
else:
if use_first_moment:
state["exp_avg"] = state["exp_avg"].to(grad)
if factored:
state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
else:
state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)
p_data_fp32 = p.data
if p.data.dtype in {torch.float16, torch.bfloat16}:
p_data_fp32 = p_data_fp32.float()
state["step"] += 1
state["RMS"] = self._rms(p_data_fp32)
lr = self._get_lr(group, state)
beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
update = (grad**2) + group["eps"][0]
if factored:
exp_avg_sq_row = state["exp_avg_sq_row"]
exp_avg_sq_col = state["exp_avg_sq_col"]
exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))
exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))
# Approximation of exponential moving average of square of gradient
update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
update.mul_(grad)
else:
exp_avg_sq = state["exp_avg_sq"]
exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
update = exp_avg_sq.rsqrt().mul_(grad)
update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
update.mul_(lr)
if use_first_moment:
exp_avg = state["exp_avg"]
exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"]))
update = exp_avg
if group["weight_decay"] != 0:
p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr))
p_data_fp32.add_(-update)
if p.data.dtype in {torch.float16, torch.bfloat16}:
p.data.copy_(p_data_fp32)
return loss
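# A minimal sketch (not part of the original file) of a single training step
# with Adafactor in its internal-schedule mode (lr=None, relative_step=True).
# The linear model and random data below are illustrative placeholders.
if __name__ == "__main__":
    import torch.nn as nn
    model = nn.Linear(16, 4)
    optimizer = Adafactor(model.parameters(), scale_parameter=True,
                          relative_step=True, warmup_init=True, lr=None)
    x, y = torch.randn(8, 16), torch.randn(8, 4)
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()      # lr is derived internally from the step count and parameter RMS
    optimizer.zero_grad()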
import math
from torch.optim.lr_scheduler import _LRScheduler
__all__ = ['AnnealingLR']
class AnnealingLR(_LRScheduler):
def __init__(self, optimizer, base_lr, warmup_steps, total_steps, decay_mode='cosine', min_lr=0.0, last_step=-1):
assert decay_mode in ['linear', 'cosine', 'none']
self.optimizer = optimizer
self.base_lr = base_lr
self.warmup_steps = warmup_steps
self.total_steps = total_steps
self.decay_mode = decay_mode
self.min_lr = min_lr
self.current_step = last_step + 1
self.step(self.current_step)
def get_lr(self):
if self.warmup_steps > 0 and self.current_step <= self.warmup_steps:
return self.base_lr * self.current_step / self.warmup_steps
else:
ratio = (self.current_step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
ratio = min(1.0, max(0.0, ratio))
if self.decay_mode == 'linear':
return self.base_lr * (1 - ratio)
elif self.decay_mode == 'cosine':
return self.base_lr * (math.cos(math.pi * ratio) + 1.0) / 2.0
else:
return self.base_lr
def step(self, current_step=None):
if current_step is None:
current_step = self.current_step + 1
self.current_step = current_step
new_lr = max(self.min_lr, self.get_lr())
if isinstance(self.optimizer, list):
for o in self.optimizer:
for group in o.param_groups:
group['lr'] = new_lr
else:
for group in self.optimizer.param_groups:
group['lr'] = new_lr
def state_dict(self):
return {
'base_lr': self.base_lr,
'warmup_steps': self.warmup_steps,
'total_steps': self.total_steps,
'decay_mode': self.decay_mode,
'current_step': self.current_step}
def load_state_dict(self, state_dict):
self.base_lr = state_dict['base_lr']
self.warmup_steps = state_dict['warmup_steps']
self.total_steps = state_dict['total_steps']
self.decay_mode = state_dict['decay_mode']
self.current_step = state_dict['current_step']
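# A small usage sketch (not part of the original file): AnnealingLR driving a
# plain SGD optimizer with linear warmup followed by cosine decay to min_lr.
if __name__ == "__main__":
    import torch
    params = [torch.nn.Parameter(torch.zeros(1))]
    opt = torch.optim.SGD(params, lr=0.1)
    sched = AnnealingLR(opt, base_lr=0.1, warmup_steps=10, total_steps=100,
                        decay_mode='cosine', min_lr=1e-5)
    for step in range(1, 101):
        opt.step()
        sched.step(step)  # lr ramps to 0.1 by step 10, then follows a half cosine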
# Copyright 2021 Alibaba Group Holding Limited. All Rights Reserved.
# Registry class & build_from_config function partially modified from
# https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils/registry.py
# Copyright 2018-2020 Open-MMLab. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import inspect
import warnings
def build_from_config(cfg, registry, **kwargs):
""" Default builder function.
Args:
        cfg (dict): A dict containing parameters passed to the target class or function.
            Must contain the key 'type', which indicates the target class or function name.
        registry (Registry): A registry to search for the target class or function.
        kwargs (dict, optional): Other parameters not in the config dict.
Returns:
Target class object or object returned by invoking function.
Raises:
TypeError:
KeyError:
Exception:
"""
if not isinstance(cfg, dict):
raise TypeError(f"config must be type dict, got {type(cfg)}")
if "type" not in cfg:
raise KeyError(f"config must contain key type, got {cfg}")
if not isinstance(registry, Registry):
raise TypeError(f"registry must be type Registry, got {type(registry)}")
cfg = copy.deepcopy(cfg)
req_type = cfg.pop("type")
req_type_entry = req_type
if isinstance(req_type, str):
req_type_entry = registry.get(req_type)
if req_type_entry is None:
raise KeyError(f"{req_type} not found in {registry.name} registry")
    if kwargs:
        cfg.update(kwargs)
if inspect.isclass(req_type_entry):
try:
return req_type_entry(**cfg)
except Exception as e:
raise Exception(f"Failed to init class {req_type_entry}, with {e}")
elif inspect.isfunction(req_type_entry):
try:
return req_type_entry(**cfg)
except Exception as e:
raise Exception(f"Failed to invoke function {req_type_entry}, with {e}")
else:
raise TypeError(f"type must be str or class, got {type(req_type_entry)}")
class Registry(object):
""" A registry maps key to classes or functions.
Example:
>>> MODELS = Registry('MODELS')
>>> @MODELS.register_class()
>>> class ResNet(object):
>>> pass
>>> resnet = MODELS.build(dict(type="ResNet"))
>>>
>>> import torchvision
>>> @MODELS.register_function("InceptionV3")
>>> def get_inception_v3(pretrained=False, progress=True):
>>> return torchvision.models.inception_v3(pretrained=pretrained, progress=progress)
>>> inception_v3 = MODELS.build(dict(type='InceptionV3', pretrained=True))
Args:
name (str): Registry name.
        build_func (func, optional): Instance construction function. Defaults to build_from_config.
allow_types (tuple): Indicates how to construct the instance, by constructing class or invoking function.
"""
def __init__(self, name, build_func=None, allow_types=("class", "function")):
self.name = name
self.allow_types = allow_types
self.class_map = {}
self.func_map = {}
self.build_func = build_func or build_from_config
def get(self, req_type):
return self.class_map.get(req_type) or self.func_map.get(req_type)
def build(self, *args, **kwargs):
return self.build_func(*args, **kwargs, registry=self)
def register_class(self, name=None):
def _register(cls):
if not inspect.isclass(cls):
raise TypeError(f"Module must be type class, got {type(cls)}")
if "class" not in self.allow_types:
raise TypeError(f"Register {self.name} only allows type {self.allow_types}, got class")
module_name = name or cls.__name__
if module_name in self.class_map:
warnings.warn(f"Class {module_name} already registered by {self.class_map[module_name]}, "
f"will be replaced by {cls}")
self.class_map[module_name] = cls
return cls
return _register
def register_function(self, name=None):
def _register(func):
if not inspect.isfunction(func):
raise TypeError(f"Registry must be type function, got {type(func)}")
if "function" not in self.allow_types:
raise TypeError(f"Registry {self.name} only allows type {self.allow_types}, got function")
func_name = name or func.__name__
            if func_name in self.func_map:
warnings.warn(f"Function {func_name} already registered by {self.func_map[func_name]}, "
f"will be replaced by {func}")
self.func_map[func_name] = func
return func
return _register
def _list(self):
keys = sorted(list(self.class_map.keys()) + list(self.func_map.keys()))
descriptions = []
for key in keys:
if key in self.class_map:
descriptions.append(f"{key}: {self.class_map[key]}")
else:
descriptions.append(
f"{key}: <function '{self.func_map[key].__module__}.{self.func_map[key].__name__}'>")
return "\n".join(descriptions)
def __repr__(self):
description = self._list()
description = '\n'.join(['\t' + s for s in description.split('\n')])
return f"{self.__class__.__name__} [{self.name}], \n" + description
from .registry import Registry, build_from_config
def build_func(cfg, registry, **kwargs):
"""
Except for config, if passing a list of dataset config, then return the concat type of it
"""
return build_from_config(cfg, registry, **kwargs)
AUTO_ENCODER = Registry("AUTO_ENCODER", build_func=build_func)
DATASETS = Registry("DATASETS", build_func=build_func)
DIFFUSION = Registry("DIFFUSION", build_func=build_func)
DISTRIBUTION = Registry("DISTRIBUTION", build_func=build_func)
EMBEDDER = Registry("EMBEDDER", build_func=build_func)
ENGINE = Registry("ENGINE", build_func=build_func)
INFER_ENGINE = Registry("INFER_ENGINE", build_func=build_func)
MODEL = Registry("MODEL", build_func=build_func)
PRETRAIN = Registry("PRETRAIN", build_func=build_func)
VISUAL = Registry("VISUAL", build_func=build_func)
import torch
import random
import numpy as np
def setup_seed(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
import torch
import torchvision.transforms.functional as F
import random
import math
import numpy as np
from PIL import Image, ImageFilter
__all__ = ['Compose', 'Resize', 'Rescale', 'CenterCrop', 'CenterCropV2', 'CenterCropWide', 'RandomCrop', 'RandomCropV2', 'RandomHFlip',\
'GaussianBlur', 'ColorJitter', 'RandomGray', 'ToTensor', 'Normalize', "ResizeRandomCrop", "ExtractResizeRandomCrop", "ExtractResizeAssignCrop"]
class Compose(object):
def __init__(self, transforms):
self.transforms = transforms
def __getitem__(self, index):
if isinstance(index, slice):
return Compose(self.transforms[index])
else:
return self.transforms[index]
def __len__(self):
return len(self.transforms)
def __call__(self, rgb):
for t in self.transforms:
rgb = t(rgb)
return rgb
class Resize(object):
def __init__(self, size=256):
if isinstance(size, int):
size = (size, size)
self.size = size
def __call__(self, rgb):
if isinstance(rgb, list):
rgb = [u.resize(self.size, Image.BILINEAR) for u in rgb]
else:
rgb = rgb.resize(self.size, Image.BILINEAR)
return rgb
class Rescale(object):
def __init__(self, size=256, interpolation=Image.BILINEAR):
self.size = size
self.interpolation = interpolation
def __call__(self, rgb):
w, h = rgb[0].size
scale = self.size / min(w, h)
out_w, out_h = int(round(w * scale)), int(round(h * scale))
rgb = [u.resize((out_w, out_h), self.interpolation) for u in rgb]
return rgb
class CenterCrop(object):
def __init__(self, size=224):
self.size = size
def __call__(self, rgb):
w, h = rgb[0].size
assert min(w, h) >= self.size
x1 = (w - self.size) // 2
y1 = (h - self.size) // 2
rgb = [u.crop((x1, y1, x1 + self.size, y1 + self.size)) for u in rgb]
return rgb
class ResizeRandomCrop(object):
def __init__(self, size=256, size_short=292):
self.size = size
# self.min_area = min_area
self.size_short = size_short
def __call__(self, rgb):
# consistent crop between rgb and m
while min(rgb[0].size) >= 2 * self.size_short:
rgb = [u.resize((u.width // 2, u.height // 2), resample=Image.BOX) for u in rgb]
scale = self.size_short / min(rgb[0].size)
rgb = [u.resize((round(scale * u.width), round(scale * u.height)), resample=Image.BICUBIC) for u in rgb]
out_w = self.size
out_h = self.size
w, h = rgb[0].size # (518, 292)
x1 = random.randint(0, w - out_w)
y1 = random.randint(0, h - out_h)
rgb = [u.crop((x1, y1, x1 + out_w, y1 + out_h)) for u in rgb]
# rgb = [u.resize((self.size, self.size), Image.BILINEAR) for u in rgb]
# # center crop
# x1 = (img[0].width - self.size) // 2
# y1 = (img[0].height - self.size) // 2
# img = [u.crop((x1, y1, x1 + self.size, y1 + self.size)) for u in img]
return rgb
class ExtractResizeRandomCrop(object):
def __init__(self, size=256, size_short=292):
self.size = size
# self.min_area = min_area
self.size_short = size_short
def __call__(self, rgb):
# consistent crop between rgb and m
while min(rgb[0].size) >= 2 * self.size_short:
rgb = [u.resize((u.width // 2, u.height // 2), resample=Image.BOX) for u in rgb]
scale = self.size_short / min(rgb[0].size)
rgb = [u.resize((round(scale * u.width), round(scale * u.height)), resample=Image.BICUBIC) for u in rgb]
out_w = self.size
out_h = self.size
w, h = rgb[0].size # (518, 292)
x1 = random.randint(0, w - out_w)
y1 = random.randint(0, h - out_h)
rgb = [u.crop((x1, y1, x1 + out_w, y1 + out_h)) for u in rgb]
wh = [x1, y1, x1 + out_w, y1 + out_h]
return rgb, wh
class ExtractResizeAssignCrop(object):
def __init__(self, size=256, size_short=292):
self.size = size
# self.min_area = min_area
self.size_short = size_short
def __call__(self, rgb, wh):
# consistent crop between rgb and m
while min(rgb[0].size) >= 2 * self.size_short:
rgb = [u.resize((u.width // 2, u.height // 2), resample=Image.BOX) for u in rgb]
scale = self.size_short / min(rgb[0].size)
rgb = [u.resize((round(scale * u.width), round(scale * u.height)), resample=Image.BICUBIC) for u in rgb]
rgb = [u.crop(wh) for u in rgb]
rgb = [u.resize((self.size, self.size), Image.BILINEAR) for u in rgb]
return rgb
class CenterCropV2(object):
def __init__(self, size):
self.size = size
def __call__(self, img):
# fast resize
while min(img[0].size) >= 2 * self.size:
img = [u.resize((u.width // 2, u.height // 2), resample=Image.BOX) for u in img]
scale = self.size / min(img[0].size)
img = [u.resize((round(scale * u.width), round(scale * u.height)), resample=Image.BICUBIC) for u in img]
# center crop
x1 = (img[0].width - self.size) // 2
y1 = (img[0].height - self.size) // 2
img = [u.crop((x1, y1, x1 + self.size, y1 + self.size)) for u in img]
return img
class CenterCropWide(object):
def __init__(self, size):
self.size = size
def __call__(self, img):
if isinstance(img, list):
scale = min(img[0].size[0]/self.size[0], img[0].size[1]/self.size[1])
img = [u.resize((round(u.width // scale), round(u.height // scale)), resample=Image.BOX) for u in img]
# center crop
x1 = (img[0].width - self.size[0]) // 2
y1 = (img[0].height - self.size[1]) // 2
img = [u.crop((x1, y1, x1 + self.size[0], y1 + self.size[1])) for u in img]
return img
else:
scale = min(img.size[0]/self.size[0], img.size[1]/self.size[1])
img = img.resize((round(img.width // scale), round(img.height // scale)), resample=Image.BOX)
x1 = (img.width - self.size[0]) // 2
y1 = (img.height - self.size[1]) // 2
img = img.crop((x1, y1, x1 + self.size[0], y1 + self.size[1]))
return img
class RandomCrop(object):
def __init__(self, size=224, min_area=0.4):
self.size = size
self.min_area = min_area
def __call__(self, rgb):
# consistent crop between rgb and m
w, h = rgb[0].size
area = w * h
out_w, out_h = float('inf'), float('inf')
while out_w > w or out_h > h:
target_area = random.uniform(self.min_area, 1.0) * area
aspect_ratio = random.uniform(3. / 4., 4. / 3.)
out_w = int(round(math.sqrt(target_area * aspect_ratio)))
out_h = int(round(math.sqrt(target_area / aspect_ratio)))
x1 = random.randint(0, w - out_w)
y1 = random.randint(0, h - out_h)
rgb = [u.crop((x1, y1, x1 + out_w, y1 + out_h)) for u in rgb]
rgb = [u.resize((self.size, self.size), Image.BILINEAR) for u in rgb]
return rgb
class RandomCropV2(object):
def __init__(self, size=224, min_area=0.4, ratio=(3. / 4., 4. / 3.)):
if isinstance(size, (tuple, list)):
self.size = size
else:
self.size = (size, size)
self.min_area = min_area
self.ratio = ratio
def _get_params(self, img):
width, height = img.size
area = height * width
for _ in range(10):
target_area = random.uniform(self.min_area, 1.0) * area
log_ratio = (math.log(self.ratio[0]), math.log(self.ratio[1]))
aspect_ratio = math.exp(random.uniform(*log_ratio))
w = int(round(math.sqrt(target_area * aspect_ratio)))
h = int(round(math.sqrt(target_area / aspect_ratio)))
if 0 < w <= width and 0 < h <= height:
i = random.randint(0, height - h)
j = random.randint(0, width - w)
return i, j, h, w
# Fallback to central crop
in_ratio = float(width) / float(height)
if (in_ratio < min(self.ratio)):
w = width
h = int(round(w / min(self.ratio)))
elif (in_ratio > max(self.ratio)):
h = height
w = int(round(h * max(self.ratio)))
else: # whole image
w = width
h = height
i = (height - h) // 2
j = (width - w) // 2
return i, j, h, w
def __call__(self, rgb):
i, j, h, w = self._get_params(rgb[0])
rgb = [F.resized_crop(u, i, j, h, w, self.size) for u in rgb]
return rgb
class RandomHFlip(object):
def __init__(self, p=0.5):
self.p = p
def __call__(self, rgb):
if random.random() < self.p:
rgb = [u.transpose(Image.FLIP_LEFT_RIGHT) for u in rgb]
return rgb
class GaussianBlur(object):
def __init__(self, sigmas=[0.1, 2.0], p=0.5):
self.sigmas = sigmas
self.p = p
def __call__(self, rgb):
if random.random() < self.p:
sigma = random.uniform(*self.sigmas)
rgb = [u.filter(ImageFilter.GaussianBlur(radius=sigma)) for u in rgb]
return rgb
class ColorJitter(object):
def __init__(self, brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1, p=0.5):
self.brightness = brightness
self.contrast = contrast
self.saturation = saturation
self.hue = hue
self.p = p
def __call__(self, rgb):
if random.random() < self.p:
brightness, contrast, saturation, hue = self._random_params()
transforms = [
lambda f: F.adjust_brightness(f, brightness),
lambda f: F.adjust_contrast(f, contrast),
lambda f: F.adjust_saturation(f, saturation),
lambda f: F.adjust_hue(f, hue)]
random.shuffle(transforms)
for t in transforms:
rgb = [t(u) for u in rgb]
return rgb
def _random_params(self):
brightness = random.uniform(
max(0, 1 - self.brightness), 1 + self.brightness)
contrast = random.uniform(
max(0, 1 - self.contrast), 1 + self.contrast)
saturation = random.uniform(
max(0, 1 - self.saturation), 1 + self.saturation)
hue = random.uniform(-self.hue, self.hue)
return brightness, contrast, saturation, hue
class RandomGray(object):
def __init__(self, p=0.2):
self.p = p
def __call__(self, rgb):
if random.random() < self.p:
rgb = [u.convert('L').convert('RGB') for u in rgb]
return rgb
class ToTensor(object):
def __call__(self, rgb):
if isinstance(rgb, list):
rgb = torch.stack([F.to_tensor(u) for u in rgb], dim=0)
else:
rgb = F.to_tensor(rgb)
return rgb
class Normalize(object):
def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
self.mean = mean
self.std = std
def __call__(self, rgb):
rgb = rgb.clone()
rgb.clamp_(0, 1)
if not isinstance(self.mean, torch.Tensor):
self.mean = rgb.new_tensor(self.mean).view(-1)
if not isinstance(self.std, torch.Tensor):
self.std = rgb.new_tensor(self.std).view(-1)
if rgb.dim() == 4:
rgb.sub_(self.mean.view(1, -1, 1, 1)).div_(self.std.view(1, -1, 1, 1))
elif rgb.dim() == 3:
rgb.sub_(self.mean.view(-1, 1, 1)).div_(self.std.view(-1, 1, 1))
return rgb
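# A hedged example (not part of the original file): composing these transforms
# into a typical clip-preprocessing pipeline. The blank frames stand in for a
# decoded video clip.
if __name__ == "__main__":
    frames = [Image.new('RGB', (320, 240)) for _ in range(4)]
    pipeline = Compose([
        CenterCropV2(224),      # fast halving resize, then center crop
        RandomHFlip(p=0.5),     # flips all frames of the clip together
        ToTensor(),             # list of PIL images -> FxCxHxW tensor
        Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    clip = pipeline(frames)
    print(clip.shape)           # torch.Size([4, 3, 224, 224])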
import torch
def to_device(batch, device, non_blocking=False):
if isinstance(batch, (list, tuple)):
return type(batch)([
to_device(u, device, non_blocking)
for u in batch])
elif isinstance(batch, dict):
return type(batch)([
(k, to_device(v, device, non_blocking))
for k, v in batch.items()])
    elif isinstance(batch, torch.Tensor) and batch.device != device:
        batch = batch.to(device, non_blocking=non_blocking)
    return batch
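# Illustrative only (not part of the original file): to_device recurses through
# nested lists, tuples, and dicts, moving tensors and leaving other values as-is.
if __name__ == "__main__":
    batch = {'video': torch.zeros(2, 3),
             'meta': {'caption': 'a cat', 'ids': [torch.ones(2)]}}
    moved = to_device(batch, torch.device('cpu'))
    print(moved['meta']['caption'], moved['video'].device)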
import os
import os.path as osp
import sys
import cv2
import glob
import math
import torch
import gzip
import copy
import time
import json
import pickle
import base64
import imageio
import hashlib
import requests
import binascii
import zipfile
# import skvideo.io
import numpy as np
from io import BytesIO
import urllib.request
import torch.nn.functional as F
import torchvision.utils as tvutils
from multiprocessing.pool import ThreadPool as Pool
from einops import rearrange
from PIL import Image, ImageDraw, ImageFont
def gen_text_image(captions, text_size):
    num_char = int(38 * (text_size / 256))  # characters per line, scaled from a 256px reference tile
font_size = int(text_size / 20)
font = ImageFont.truetype('data/font/DejaVuSans.ttf', size=font_size)
text_image_list = []
for text in captions:
txt_img = Image.new("RGB", (text_size, text_size), color="white")
draw = ImageDraw.Draw(txt_img)
lines = "\n".join(text[start:start + num_char] for start in range(0, len(text), num_char))
draw.text((0, 0), lines, fill="black", font=font)
txt_img = np.array(txt_img)
text_image_list.append(txt_img)
text_images = np.stack(text_image_list, axis=0)
text_images = torch.from_numpy(text_images)
return text_images
@torch.no_grad()
def save_video_refimg_and_text(
local_path,
ref_frame,
gen_video,
captions,
mean=[0.5, 0.5, 0.5],
std=[0.5, 0.5, 0.5],
text_size=256,
nrow=4,
save_fps=8,
retry=5):
'''
gen_video: BxCxFxHxW
'''
nrow = max(int(gen_video.size(0) / 2), 1)
vid_mean = torch.tensor(mean, device=gen_video.device).view(1, -1, 1, 1, 1) #ncfhw
vid_std = torch.tensor(std, device=gen_video.device).view(1, -1, 1, 1, 1) #ncfhw
text_images = gen_text_image(captions, text_size) # Tensor 8x256x256x3
text_images = text_images.unsqueeze(1) # Tensor 8x1x256x256x3
text_images = text_images.repeat_interleave(repeats=gen_video.size(2), dim=1) # 8x16x256x256x3
ref_frame = ref_frame.unsqueeze(2)
ref_frame = ref_frame.mul_(vid_std).add_(vid_mean)
ref_frame = ref_frame.repeat_interleave(repeats=gen_video.size(2), dim=2) # 8x16x256x256x3
ref_frame.clamp_(0, 1)
ref_frame = ref_frame * 255.0
ref_frame = rearrange(ref_frame, 'b c f h w -> b f h w c')
gen_video = gen_video.mul_(vid_std).add_(vid_mean) # 8x3x16x256x384
gen_video.clamp_(0, 1)
gen_video = gen_video * 255.0
images = rearrange(gen_video, 'b c f h w -> b f h w c')
images = torch.cat([ref_frame, images, text_images], dim=3)
images = rearrange(images, '(r j) f h w c -> f (r h) (j w) c', r=nrow)
images = [(img.numpy()).astype('uint8') for img in images]
    exception = None
    for _ in range(retry):
try:
if len(images) == 1:
local_path = local_path + '.png'
cv2.imwrite(local_path, images[0][:,:,::-1], [int(cv2.IMWRITE_JPEG_QUALITY), 100])
else:
local_path = local_path + '.mp4'
frame_dir = os.path.join(os.path.dirname(local_path), '%s_frames' % (os.path.basename(local_path)))
os.system(f'rm -rf {frame_dir}'); os.makedirs(frame_dir, exist_ok=True)
for fid, frame in enumerate(images):
tpth = os.path.join(frame_dir, '%04d.png' % (fid+1))
cv2.imwrite(tpth, frame[:,:,::-1], [int(cv2.IMWRITE_JPEG_QUALITY), 100])
cmd = f'ffmpeg -y -f image2 -loglevel quiet -framerate {save_fps} -i {frame_dir}/%04d.png -vcodec libx264 -crf 17 -pix_fmt yuv420p {local_path}'
os.system(cmd); os.system(f'rm -rf {frame_dir}')
# os.system(f'rm -rf {local_path}')
exception = None
break
        except Exception as e:
            exception = e
            continue
    if exception is not None:
        raise exception
@torch.no_grad()
def save_i2vgen_video(
local_path,
image_id,
gen_video,
captions,
mean=[0.5, 0.5, 0.5],
std=[0.5, 0.5, 0.5],
text_size=256,
retry=5,
save_fps = 8
):
'''
Save both the generated video and the input conditions.
'''
vid_mean = torch.tensor(mean, device=gen_video.device).view(1, -1, 1, 1, 1) #ncfhw
vid_std = torch.tensor(std, device=gen_video.device).view(1, -1, 1, 1, 1) #ncfhw
text_images = gen_text_image(captions, text_size) # Tensor 1x256x256x3
text_images = text_images.unsqueeze(1) # Tensor 1x1x256x256x3
text_images = text_images.repeat_interleave(repeats=gen_video.size(2), dim=1) # 1x16x256x256x3
image_id = image_id.unsqueeze(2) # B, C, F, H, W
image_id = image_id.repeat_interleave(repeats=gen_video.size(2), dim=2) # 1x3x32x256x448
image_id = image_id.mul_(vid_std).add_(vid_mean) # 32x3x256x448
image_id.clamp_(0, 1)
image_id = image_id * 255.0
image_id = rearrange(image_id, 'b c f h w -> b f h w c')
gen_video = gen_video.mul_(vid_std).add_(vid_mean) # 8x3x16x256x384
gen_video.clamp_(0, 1)
gen_video = gen_video * 255.0
images = rearrange(gen_video, 'b c f h w -> b f h w c')
images = torch.cat([image_id, images, text_images], dim=3)
images = images[0]
images = [(img.numpy()).astype('uint8') for img in images]
exception = None
    for _ in range(retry):
try:
frame_dir = os.path.join(os.path.dirname(local_path), '%s_frames' % (os.path.basename(local_path)))
os.system(f'rm -rf {frame_dir}'); os.makedirs(frame_dir, exist_ok=True)
for fid, frame in enumerate(images):
tpth = os.path.join(frame_dir, '%04d.png' % (fid+1))
cv2.imwrite(tpth, frame[:,:,::-1], [int(cv2.IMWRITE_JPEG_QUALITY), 100])
cmd = f'ffmpeg -y -f image2 -loglevel quiet -framerate {save_fps} -i {frame_dir}/%04d.png -vcodec libx264 -crf 17 -pix_fmt yuv420p {local_path}'
os.system(cmd); os.system(f'rm -rf {frame_dir}')
break
except Exception as e:
exception = e
continue
if exception is not None:
raise exception
@torch.no_grad()
def save_i2vgen_video_safe(
local_path,
gen_video,
captions,
mean=[0.5, 0.5, 0.5],
std=[0.5, 0.5, 0.5],
text_size=256,
retry=5,
save_fps = 8
):
'''
Save only the generated video, do not save the related reference conditions, and at the same time perform anomaly detection on the last frame.
'''
vid_mean = torch.tensor(mean, device=gen_video.device).view(1, -1, 1, 1, 1) #ncfhw
vid_std = torch.tensor(std, device=gen_video.device).view(1, -1, 1, 1, 1) #ncfhw
gen_video = gen_video.mul_(vid_std).add_(vid_mean) # 8x3x16x256x384
gen_video.clamp_(0, 1)
gen_video = gen_video * 255.0
images = rearrange(gen_video, 'b c f h w -> b f h w c')
images = images[0]
images = [(img.numpy()).astype('uint8') for img in images]
num_image = len(images)
exception = None
    for _ in range(retry):
try:
if num_image == 1:
local_path = local_path + '.png'
cv2.imwrite(local_path, images[0][:,:,::-1], [int(cv2.IMWRITE_JPEG_QUALITY), 100])
else:
writer = imageio.get_writer(local_path, fps=save_fps, codec='libx264', quality=8)
for fid, frame in enumerate(images):
                    if fid == num_image - 1:  # drop the last frame if it is mostly gray (a known generation artifact)
ratio = (np.sum((frame >= 117) & (frame <= 137)))/(frame.size)
if ratio > 0.4: continue
writer.append_data(frame)
writer.close()
break
except Exception as e:
exception = e
continue
if exception is not None:
raise exception
@torch.no_grad()
def save_t2vhigen_video_safe(
local_path,
gen_video,
captions,
mean=[0.5, 0.5, 0.5],
std=[0.5, 0.5, 0.5],
text_size=256,
retry=5,
save_fps = 8
):
'''
Save only the generated video, do not save the related reference conditions, and at the same time perform anomaly detection on the last frame.
'''
vid_mean = torch.tensor(mean, device=gen_video.device).view(1, -1, 1, 1, 1) #ncfhw
vid_std = torch.tensor(std, device=gen_video.device).view(1, -1, 1, 1, 1) #ncfhw
gen_video = gen_video.mul_(vid_std).add_(vid_mean) # 8x3x16x256x384
gen_video.clamp_(0, 1)
gen_video = gen_video * 255.0
images = rearrange(gen_video, 'b c f h w -> b f h w c')
images = images[0]
images = [(img.numpy()).astype('uint8') for img in images]
num_image = len(images)
exception = None
    for _ in range(retry):
try:
if num_image == 1:
local_path = local_path + '.png'
cv2.imwrite(local_path, images[0][:,:,::-1], [int(cv2.IMWRITE_JPEG_QUALITY), 100])
else:
frame_dir = os.path.join(os.path.dirname(local_path), '%s_frames' % (os.path.basename(local_path)))
os.system(f'rm -rf {frame_dir}'); os.makedirs(frame_dir, exist_ok=True)
for fid, frame in enumerate(images):
                    if fid == num_image - 1:  # drop the last frame if it is mostly gray (a known generation artifact)
ratio = (np.sum((frame >= 117) & (frame <= 137)))/(frame.size)
if ratio > 0.4: continue
tpth = os.path.join(frame_dir, '%04d.png' % (fid+1))
cv2.imwrite(tpth, frame[:,:,::-1], [int(cv2.IMWRITE_JPEG_QUALITY), 100])
cmd = f'ffmpeg -y -f image2 -loglevel quiet -framerate {save_fps} -i {frame_dir}/%04d.png -vcodec libx264 -crf 17 -pix_fmt yuv420p {local_path}'
os.system(cmd)
os.system(f'rm -rf {frame_dir}')
break
except Exception as e:
exception = e
continue
if exception is not None:
raise exception
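# A minimal sketch (not part of the original file) of saving a generated clip
# with save_i2vgen_video_safe. The random tensor stands in for a denoised
# sample in the [-1, 1] range implied by mean=std=0.5.
if __name__ == "__main__":
    fake_video = torch.rand(1, 3, 16, 256, 256) * 2 - 1   # BxCxFxHxW
    save_i2vgen_video_safe('demo_output.mp4', fake_video, captions=['a demo clip'])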