Commit 3b804999 authored by chenzk

v1.0
import argparse
import datetime
import json
import os
import warnings
import omegaconf
import torch
import torch.distributed
from omegaconf import OmegaConf
from sat import mpu
from sat.arguments import (add_data_args, add_evaluation_args,
add_training_args, set_random_seed)
from sat.helpers import print_rank0
def add_model_config_args(parser):
"""Model arguments"""
group = parser.add_argument_group('model', 'model configuration')
group.add_argument('--base',
type=str,
nargs='*',
help='config for input and saving')
group.add_argument('--second',
type=str,
nargs='*',
help='config for input and saving')
group.add_argument(
'--model-parallel-size',
type=int,
default=1,
help='size of the model parallel. only use if you are an expert.')
group.add_argument('--force-pretrain', action='store_true')
group.add_argument('--device', type=int, default=-1)
group.add_argument('--debug', action='store_true')
group.add_argument('--log-image', type=bool, default=True)
group.add_argument('--inf-ckpt', type=str, default=None)
group.add_argument('--inf-ckpt2', type=str, default=None)
group.add_argument('--skip-second', action='store_true')
group.add_argument('--first-stage-re',
type=int,
default=270,
help='resolution of first stage')
return parser
def add_sampling_config_args(parser):
"""Sampling configurations"""
group = parser.add_argument_group('sampling', 'Sampling Configurations')
group.add_argument('--output-dir', type=str, default='samples')
group.add_argument('--input-dir', type=str, default=None)
group.add_argument('--input-type', type=str, default='cli')
group.add_argument('--input-file', type=str, default='./ht100.txt')
group.add_argument('--final-size', type=int, default=2048)
group.add_argument('--sdedit', action='store_true')
group.add_argument('--grid-num-rows', type=int, default=1)
group.add_argument('--force-inference', action='store_true')
group.add_argument('--lcm_steps', type=int, default=None)
group.add_argument('--sampling-num-frames', type=int, default=32)
group.add_argument('--sampling-num-steps', type=int, default=30)
group.add_argument('--sampling-fps', type=int, default=8)
group.add_argument('--only-save-latents', type=bool, default=False)
group.add_argument('--only-log-video-latents', type=bool, default=False)
group.add_argument('--latent-channels', type=int, default=32)
group.add_argument('--image2video', action='store_true')
return parser
def add_extra_config_args(parser):
group = parser.add_argument_group('joint', 'joint training Configurations')
group.add_argument('--img-iter', type=int, default=0)
group.add_argument('--video-iter', type=int, default=0)
return parser
def get_args(args_list=None, parser=None):
"""Parse all the args."""
if parser is None:
parser = argparse.ArgumentParser(description='sat')
else:
assert isinstance(parser, argparse.ArgumentParser)
parser = add_model_config_args(parser)
parser = add_sampling_config_args(parser)
parser = add_training_args(parser)
parser = add_evaluation_args(parser)
parser = add_data_args(parser)
parser = add_extra_config_args(parser)
import deepspeed
parser = deepspeed.add_config_arguments(parser)
args = parser.parse_args(args_list)
args = process_config_to_args(args)
if not args.train_data:
print_rank0('No training data specified', level='WARNING')
assert (args.train_iters is None) or (
args.epochs is
None), 'only one of train_iters and epochs should be set.'
if args.train_iters is None and args.epochs is None:
args.train_iters = 10000 # default 10k iters
print_rank0(
'No train_iters (recommended) or epochs specified, use default 10k iters.',
level='WARNING')
args.cuda = torch.cuda.is_available()
args.rank = int(os.getenv('RANK', '0'))
args.world_size = int(os.getenv('WORLD_SIZE', '1'))
if args.local_rank is None:
args.local_rank = int(os.getenv('LOCAL_RANK', '0')) # torchrun
if args.device == -1:
if torch.cuda.device_count() == 0:
args.device = 'cpu'
elif args.local_rank is not None:
args.device = args.local_rank
else:
args.device = args.rank % torch.cuda.device_count()
if args.local_rank != args.device and args.mode != 'inference':
raise ValueError(
'LOCAL_RANK (default 0) and args.device inconsistent. '
'This can only happen in inference mode. '
'Please use CUDA_VISIBLE_DEVICES=x for single-GPU training. ')
if args.rank == 0:
print_rank0(f'using world size: {args.world_size}')
if args.train_data_weights is not None:
assert len(args.train_data_weights) == len(args.train_data)
if args.mode != 'inference': # training with deepspeed
args.deepspeed = True
if args.deepspeed_config is None: # not specified
deepspeed_config_path = os.path.join(
os.path.dirname(__file__), 'training',
f'deepspeed_zero{args.zero_stage}.json')
with open(deepspeed_config_path) as file:
args.deepspeed_config = json.load(file)
override_deepspeed_config = True
else:
override_deepspeed_config = False
assert not (args.fp16 and args.bf16), 'cannot specify both fp16 and bf16.'
if args.zero_stage > 0 and not args.fp16 and not args.bf16:
print_rank0('Automatically set fp16=True to use ZeRO.')
args.fp16 = True
args.bf16 = False
if args.deepspeed:
if args.checkpoint_activations:
args.deepspeed_activation_checkpointing = True
else:
args.deepspeed_activation_checkpointing = False
if args.deepspeed_config is not None:
deepspeed_config = args.deepspeed_config
if override_deepspeed_config:  # deepspeed_config not specified; fill it from args
if args.fp16:
deepspeed_config['fp16']['enabled'] = True
elif args.bf16:
deepspeed_config['bf16']['enabled'] = True
deepspeed_config['fp16']['enabled'] = False
else:
deepspeed_config['fp16']['enabled'] = False
deepspeed_config[
'train_micro_batch_size_per_gpu'] = args.batch_size
deepspeed_config[
'gradient_accumulation_steps'] = args.gradient_accumulation_steps
optimizer_params_config = deepspeed_config['optimizer']['params']
optimizer_params_config['lr'] = args.lr
optimizer_params_config['weight_decay'] = args.weight_decay
else: # override args with values in deepspeed_config
if args.rank == 0:
print_rank0(
'Will override arguments with manually specified deepspeed_config!'
)
if 'fp16' in deepspeed_config and deepspeed_config['fp16'][
'enabled']:
args.fp16 = True
else:
args.fp16 = False
if 'bf16' in deepspeed_config and deepspeed_config['bf16'][
'enabled']:
args.bf16 = True
else:
args.bf16 = False
if 'train_micro_batch_size_per_gpu' in deepspeed_config:
args.batch_size = deepspeed_config[
'train_micro_batch_size_per_gpu']
if 'gradient_accumulation_steps' in deepspeed_config:
args.gradient_accumulation_steps = deepspeed_config[
'gradient_accumulation_steps']
else:
args.gradient_accumulation_steps = None
if 'optimizer' in deepspeed_config:
optimizer_params_config = deepspeed_config['optimizer'].get(
'params', {})
args.lr = optimizer_params_config.get('lr', args.lr)
args.weight_decay = optimizer_params_config.get(
'weight_decay', args.weight_decay)
args.deepspeed_config = deepspeed_config
# initialize distributed and random seed because it always seems to be necessary.
initialize_distributed(args)
args.seed = args.seed + mpu.get_data_parallel_rank()
set_random_seed(args.seed)
return args
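# For reference: when args.deepspeed_config is not given, the code above loads
# training/deepspeed_zero{zero_stage}.json and then overwrites the fields it
# touches (precision flags, micro batch size, gradient accumulation, optimizer
# lr / weight_decay). A config going through that path therefore needs at least
# a shape like the following; the concrete values are illustrative, not the
# shipped defaults:
#
#   {
#     "train_micro_batch_size_per_gpu": 1,
#     "gradient_accumulation_steps": 1,
#     "fp16": {"enabled": false},
#     "bf16": {"enabled": false},
#     "optimizer": {"type": "Adam", "params": {"lr": 1e-4, "weight_decay": 1e-2}}
#   }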
def initialize_distributed(args):
"""Initialize torch.distributed."""
if torch.distributed.is_initialized():
if mpu.model_parallel_is_initialized():
if args.model_parallel_size != mpu.get_model_parallel_world_size():
raise ValueError(
'model_parallel_size is inconsistent with prior configuration. '
'We currently do not support changing model_parallel_size.'
)
return False
else:
if args.model_parallel_size > 1:
warnings.warn(
'model_parallel_size > 1 but torch.distributed is not initialized via SAT. '
'Please carefully verify correctness on your own.')
mpu.initialize_model_parallel(args.model_parallel_size)
return True
# the automatic assignment of devices has been moved to arguments.py
if args.device == 'cpu':
pass
else:
torch.cuda.set_device(args.device)
# Call the init process
init_method = 'tcp://'
args.master_ip = os.getenv('MASTER_ADDR', 'localhost')
if args.world_size == 1:
from sat.helpers import get_free_port
default_master_port = str(get_free_port())
else:
default_master_port = '6000'
args.master_port = os.getenv('MASTER_PORT', default_master_port)
init_method = None
#init_method += args.master_ip + ":" + args.master_port
torch.distributed.init_process_group(
backend=args.distributed_backend,
world_size=args.world_size,
rank=args.rank,
init_method=init_method,
timeout=datetime.timedelta(seconds=1200))
# Set the model-parallel / data-parallel communicators.
mpu.initialize_model_parallel(args.model_parallel_size)
# Set vae context parallel group equal to model parallel group
from sgm.util import (initialize_context_parallel,
set_context_parallel_group)
if args.model_parallel_size <= 2:
set_context_parallel_group(args.model_parallel_size,
mpu.get_model_parallel_group())
else:
initialize_context_parallel(2)
# mpu.initialize_model_parallel(1)
# Optional DeepSpeed Activation Checkpointing Features
if args.deepspeed:
import deepspeed
deepspeed.init_distributed(dist_backend=args.distributed_backend,
world_size=args.world_size,
rank=args.rank,
init_method=init_method)
# # It seems that it has no negative influence to configure it even without using checkpointing.
# deepspeed.checkpointing.configure(mpu, deepspeed_config=args.deepspeed_config, num_checkpoints=args.num_layers)
else:
# in model-only mode, we don't want to init deepspeed, but we still need to init the rng tracker for model parallelism, because the seed is tracked by default when applying dropout.
try:
import deepspeed
from deepspeed.runtime.activation_checkpointing.checkpointing import (
_CUDA_RNG_STATE_TRACKER, _MODEL_PARALLEL_RNG_TRACKER_NAME)
_CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME,
1) # default seed 1
except Exception as e:
from sat.helpers import print_rank0
print_rank0(str(e), level='DEBUG')
return True
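# For reference, the distributed bootstrap above consumes RANK / WORLD_SIZE /
# LOCAL_RANK (read in get_args) and MASTER_ADDR / MASTER_PORT (read here). A
# typical single-node launch that sets all of them would look roughly like the
# following; the script and config names are placeholders, not files verified
# to exist in this repo:
#
#   torchrun --nproc_per_node=8 train_video.py --base configs/stage1.yaml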
def process_config_to_args(args):
"""Fetch args from only --base"""
configs = [OmegaConf.load(cfg) for cfg in args.base]
config = OmegaConf.merge(*configs)
args_config = config.pop('args', OmegaConf.create())
for key in args_config:
if isinstance(args_config[key], omegaconf.DictConfig) or isinstance(
args_config[key], omegaconf.ListConfig):
arg = OmegaConf.to_object(args_config[key])
else:
arg = args_config[key]
if hasattr(args, key):
setattr(args, key, arg)
if 'model' in config:
model_config = config.pop('model', OmegaConf.create())
args.model_config = model_config
if 'deepspeed' in config:
deepspeed_config = config.pop('deepspeed', OmegaConf.create())
args.deepspeed_config = OmegaConf.to_object(deepspeed_config)
if 'data' in config:
data_config = config.pop('data', OmegaConf.create())
args.data_config = data_config
if 'img_data' in config:
img_data_config = config.pop('img_data', OmegaConf.create())
args.img_data_config = img_data_config
if 'video_data' in config:
video_data_config = config.pop('video_data', OmegaConf.create())
args.video_data_config = video_data_config
if 'trainable_params' in config:
trainable_params = config.pop('trainable_params', OmegaConf.create())
args.trainable_params_config = trainable_params
if 'share_cache_args' in config:
share_cache_args = config.pop('share_cache_args', OmegaConf.create())
args.share_cache_config = share_cache_args
if 'custom_args' in config:
custom_args = config.pop('custom_args', OmegaConf.create())
for k, v in custom_args.items():
setattr(args, k, v)
return args
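# Illustrative sketch only (not called by the training code): a tiny demo of
# the OmegaConf behaviour process_config_to_args relies on, using made-up
# values. Later --base configs override earlier ones, and the `args:` section
# is converted to plain Python objects before being copied onto the namespace.
def _demo_config_merge():
    base = OmegaConf.create({'args': {'batch_size': 2, 'lr': 1e-4}})
    override = OmegaConf.create({'args': {'batch_size': 4}})
    merged = OmegaConf.merge(base, override)
    args_section = merged.pop('args', OmegaConf.create())
    # -> {'batch_size': 4, 'lr': 0.0001}
    return OmegaConf.to_object(args_section)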
'''
@File : base_model.py
@Time : 2021/10/01 22:40:33
@Author : Ming Ding
@Contact : dm18@mails.tsinghua.edu.cn
'''
import argparse
import inspect
import math
import os
import random
import sys
import warnings
# here put the import lib
from functools import partial
import torch
from base_transformer import GCBaseTransformer
from sat.arguments import (overwrite_args_by_dict, reset_random_seed,
set_random_seed, update_args_with_file)
from sat.helpers import print_rank0
from sat.model.mixins import BaseMixin
from sat.model.registry import MetaModel, model_registry
from sat.model.transformer import standard_attention
from sat.mpu.initialize import (destroy_model_parallel,
get_model_parallel_rank, get_node_rank,
initialize_model_parallel)
from sat.mpu.operation import (mp_merge_model_rank0, mp_merge_model_send,
mp_split_model_rank0, mp_split_model_receive)
from sat.resources import auto_create
from sat.training.model_io import load_checkpoint
from sat.transformer_defaults import ARGS_DEFAULT, HOOKS_DEFAULT
class BaseModel(torch.nn.Module, metaclass=MetaModel):
def __init__(self,
args,
transformer=None,
params_dtype=torch.float,
**kwargs):
super().__init__()
self.mixins = torch.nn.ModuleDict()
self.collect_hooks_()
if transformer is not None:
self.transformer = transformer
else:
# check if model-only mode
from sat.arguments import _simple_init
success = _simple_init(
model_parallel_size=args.model_parallel_size,
seed=args.seed if hasattr(args, 'seed') else 1234)
args_dict = {
k: (getattr(args, v[0]) if hasattr(args, v[0]) else v[1])
for k, v in ARGS_DEFAULT.items()
}
self.transformer = GCBaseTransformer(
num_layers=args.num_layers,
vocab_size=args.vocab_size,
hidden_size=args.hidden_size,
num_attention_heads=args.num_attention_heads,
max_sequence_length=args.max_sequence_length,
layernorm_order=args.layernorm_order,
**args_dict,
hooks=self.hooks,
params_dtype=params_dtype,
skip_init=args.skip_init,
device=torch.cuda.current_device()
if hasattr(args, 'use_gpu_initialization')
and args.use_gpu_initialization else torch.device('cpu'),
**kwargs)
def reinit(self,
mixin_names=None
): # will be called when loading model, None means all
# if some mixins are loaded, override this function
for k, m in self.mixins.items():
if mixin_names is None or k in mixin_names:
m.reinit(self)
def add_mixin(self, name, new_mixin, reinit=False):
assert name not in self.mixins
assert isinstance(new_mixin, BaseMixin)
self.mixins[name] = new_mixin # will auto-register parameters
object.__setattr__(new_mixin, 'transformer',
self.transformer) # cannot use pytorch set_attr
self.collect_hooks_()
if reinit:
new_mixin.reinit(self) # also pass current mixins
def del_mixin(self, name):
assert name in self.mixins
del self.mixins[name]
self.collect_hooks_()
def get_mixin(self, name):
return self.mixins[name]
def forward(self, *args, **kwargs):
# update hooks to reflect the current model (overridden forwards)
# Attention! the transformer might be shared by multiple models
self.transformer.hooks.clear()
self.transformer.hooks.update(self.hooks)
return self.transformer(*args, **kwargs)
def collect_hooks_(self):
names = list(HOOKS_DEFAULT.keys())
hooks = {}
hook_origins = {}
for name in names:
if hasattr(self, name):
hooks[name] = getattr(self, name)
hook_origins[name] = 'model'
for mixin_name, m in self.mixins.items():
if hasattr(m, name):
if hasattr(getattr(m, name), 'non_conflict'):
# check getattr(m, name), which must accept old_impl as an argument
signature = inspect.signature(getattr(m, name))
if 'old_impl' not in signature.parameters:
raise ValueError(
f'Hook {name} at {mixin_name} must accept old_impl as an argument.'
)
# -------------
if name in hooks:
old_impl = hooks[name]
elif name == 'attention_fn': # the only hook without self
old_impl = HOOKS_DEFAULT[name]
else:
old_impl = partial(
HOOKS_DEFAULT[name], self
) # relax! `partial` does not affect the signature
old_origin = hook_origins.get(name, 'default')
hooks[name] = partial(getattr(m, name),
old_impl=old_impl)
hook_origins[name] = mixin_name + ' -> ' + old_origin
elif name in hooks and not hasattr(
hooks[name], 'replacable'
): # if this hook name is already registered
raise ValueError(
f'Hook {name} conflicts at {mixin_name} and {hook_origins[name]}.'
)
else: # new hook
if name in hooks and hasattr(hooks[name],
'replacable'):
warnings.warn(
f'Hook {name} at {mixin_name} replaces {hook_origins[name]}.'
)
hooks[name] = getattr(m, name)
hook_origins[name] = mixin_name
self.hooks = hooks
self.hook_origins = hook_origins
return hooks
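# Illustration of the stacking convention enforced above (the mixin and hook
# body are hypothetical; only the `non_conflict` attribute check and the
# `old_impl` keyword come from collect_hooks_ itself):
#
#   class MyMixin(BaseMixin):
#       def layer_forward(self, hidden_states, mask, *, old_impl, **kw_args):
#           hidden_states = 2.0 * hidden_states  # placeholder pre-processing
#           return old_impl(hidden_states, mask, **kw_args)
#       layer_forward.non_conflict = True  # mark the hook as stackable
#
# Hooks without that attribute replace the default implementation and raise on
# conflicts unless they carry a `replacable` attribute.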
def disable_untrainable_params(self):
pass
@classmethod
def add_model_specific_args(cls, parser):
# recorded in arguments.py: add_model_config_args
return parser
@classmethod
def from_pretrained_base(cls,
name,
args=None,
*,
home_path=None,
url=None,
prefix='',
build_only=False,
overwrite_args={},
**kwargs):
'''Load a pretrained checkpoint of the current model.
Args:
name: The identifier of the pretrained model.
args: Namespace. The loaded args will be added into it. None creates a new model-only Namespace with defaults.
home_path: the parent folder of the existing `name` model. Default: SAT_HOME.
url: the url of the model. Default: SAT_URL.
prefix: the prefix of the checkpoint. Default: ''.
Returns:
model: the loaded model.
args: the loaded args.
'''
if os.path.exists(name) and os.path.isdir(name):
model_path = name
else:
model_path = auto_create(name, path=home_path, url=url)
# create a new args if not provided
if args is None:
args = cls.get_args()
args = update_args_with_file(args,
path=os.path.join(model_path,
'model_config.json'))
args = overwrite_args_by_dict(args, overwrite_args=overwrite_args)
specific_iteration = kwargs.pop('specific_iteration', None)
model = get_model(args, cls, **kwargs)
if not build_only:
load_checkpoint(model,
args,
load_path=model_path,
prefix=prefix,
specific_iteration=specific_iteration)
return model, args
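# Hedged usage sketch (the checkpoint directory is a placeholder, not a
# shipped artifact); it mirrors the docstring above:
#
#   model, args = SomeModel.from_pretrained_base(
#       '/path/to/checkpoint_dir',  # or a resource name resolved by auto_create
#       args=None,                  # None: build args from model_config.json
#       build_only=False)           # False: also load the checkpoint weights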
@classmethod
def from_pretrained(cls,
name,
args=None,
*,
home_path=None,
url=None,
prefix='',
build_only=False,
use_node_group=True,
overwrite_args={},
**kwargs):
if build_only or 'model_parallel_size' not in overwrite_args:
return cls.from_pretrained_base(name,
args=args,
home_path=home_path,
url=url,
prefix=prefix,
build_only=build_only,
overwrite_args=overwrite_args,
**kwargs)
else:
new_model_parallel_size = overwrite_args['model_parallel_size']
if new_model_parallel_size != 1 or new_model_parallel_size == 1 and args.model_parallel_size == 1:
model, model_args = cls.from_pretrained_base(
name,
args=args,
home_path=home_path,
url=url,
prefix=prefix,
build_only=True,
overwrite_args=overwrite_args,
**kwargs)
local_rank = get_node_rank(
) if use_node_group else get_model_parallel_rank()
world_size = torch.distributed.get_world_size()
assert world_size % new_model_parallel_size == 0, 'world size should be a multiple of the new model_parallel_size.'
destroy_model_parallel()
initialize_model_parallel(1)
if local_rank == 0:
args.skip_init = True
args.use_gpu_initialization = False
args.device = 'cpu'
overwrite_args.pop('model_parallel_size')
model_full, args_ = cls.from_pretrained_base(
name,
args=args,
home_path=home_path,
url=url,
prefix=prefix,
build_only=False,
overwrite_args=overwrite_args,
**kwargs)
if args_.model_parallel_size != 1:
raise Exception(
"We do not support overwriting model_parallel_size when original model_parallel_size != 1. Try merging the model using `from_pretrained(xxx,overwrite_args={'model_parallel_size':1})` first if you still want to change model_parallel_size!"
)
if hasattr(
args, 'mode'
) and args.mode == 'inference':  # For multi-node inference, we should prevent rank 0 from eagerly printing some info.
torch.distributed.barrier()
destroy_model_parallel()
initialize_model_parallel(new_model_parallel_size)
if local_rank == 0:
mp_split_model_rank0(model,
model_full,
use_node_group=use_node_group)
del model_full
else:
mp_split_model_receive(model,
use_node_group=use_node_group)
reset_random_seed(6)
else:
overwrite_args.pop('model_parallel_size')
model, model_args = cls.from_pretrained_base(
name,
args=args,
home_path=home_path,
url=url,
prefix=prefix,
build_only=False,
overwrite_args=overwrite_args,
**kwargs)
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
assert world_size == model_args.model_parallel_size, 'world size should be equal to model_parallel_size.'
destroy_model_parallel()
initialize_model_parallel(1)
if rank == 0:
args.use_gpu_initialization = False
args.device = 'cpu'
overwrite_args['model_parallel_size'] = 1
model_full, args_ = cls.from_pretrained_base(
name,
args=args,
home_path=home_path,
url=url,
prefix=prefix,
build_only=True,
overwrite_args=overwrite_args,
**kwargs)
torch.distributed.barrier()
destroy_model_parallel()
initialize_model_parallel(model_args.model_parallel_size)
if rank == 0:
mp_merge_model_rank0(model, model_full)
model, model_args = model_full, args_
else:
mp_merge_model_send(model)
model_args.model_parallel_size = 1
destroy_model_parallel()
initialize_model_parallel(1)
return model, model_args
@classmethod
def list_avail_args(cls, print=True):
'''List all available args of the current model.'''
parser = argparse.ArgumentParser()
from sat.arguments import add_model_config_args
add_model_config_args(parser)
# add args of the current model
if hasattr(cls, 'add_model_specific_args'):
cls.add_model_specific_args(parser)
if print:
from sat.helpers import print_parser
print_parser(parser)
return parser
@classmethod
def get_args(cls, **kwargs):
'''Get the parsed args of the current model.
Args:
**kwargs: will override the default args.
Returns:
args: the parsed args.
'''
parser = cls.list_avail_args(print=False)
# use parser to parse kwargs
args = parser.parse_args([])
for k, v in kwargs.items():
if hasattr(args, k) or k in [
'fp16'
]: # non-arch args but affect building models
setattr(args, k, v)
else:
print_rank0(
f'warning: Unknown arg {k} for class {cls.__name__}.',
level='DEBUG')
setattr(args, k, v)
return args
# rewritten, Copyright (c) 2021, Ming Ding. All rights reserved.
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformer."""
import copy
import math
import torch
import torch.nn.functional as F
from deepspeed.runtime.activation_checkpointing.checkpointing import \
non_reentrant_checkpoint as checkpoint
from sat import mpu
from sat.model.transformer import BaseTransformerLayer
from sat.mpu import (ColumnParallelLinear, RowParallelLinear,
VocabParallelEmbedding, copy_to_model_parallel_region,
gather_from_model_parallel_region,
get_model_parallel_world_size)
from sat.mpu.utils import (divide, gelu, scaled_init_method, sqrt,
unscaled_init_method)
from sat.ops.layernorm import LayerNorm
from sat.transformer_defaults import (HOOKS_DEFAULT,
split_tensor_along_last_dim,
standard_attention)
# checkpoint
class GCBaseTransformer(torch.nn.Module):
def __init__(self,
num_layers,
vocab_size,
hidden_size,
num_attention_heads,
max_sequence_length,
embedding_dropout_prob=0,
attention_dropout_prob=0,
output_dropout_prob=0,
drop_path=0,
checkpoint_activations=False,
checkpoint_num_layers=1,
checkpoint_skip_layers=0,
layernorm_epsilon=1.0e-5,
init_method_std=0.02,
inner_hidden_size=None,
hidden_size_per_attention_head=None,
cross_hidden_size_per_attention_head=None,
layernorm_order='pre',
parallel_output=False,
is_decoder=False,
cross_attn_hidden_size=None,
use_bias=True,
use_qkv_bias=False,
num_multi_query_heads=0,
cross_num_multi_query_heads=0,
row_parallel_linear_final_bias=True,
activation_func=gelu,
is_gated_mlp=False,
is_rotary_emb=False,
num_experts=1,
layernorm=LayerNorm,
init_method=None,
use_final_layernorm=True,
hooks={},
params_dtype=torch.float,
skip_init=False,
device=torch.device('cpu')):
super().__init__()
# recording parameters
self.hidden_size = hidden_size
self.inner_hidden_size = inner_hidden_size
self.hidden_size_per_attention_head = hidden_size_per_attention_head
self.cross_hidden_size_per_attention_head = cross_hidden_size_per_attention_head
self.is_decoder = is_decoder
self.cross_attn_hidden_size = cross_attn_hidden_size
self.cross_num_multi_query_heads = cross_num_multi_query_heads
if not is_decoder and cross_attn_hidden_size is not None:
print(
'warning: cross_attn_hidden_size is set but is_decoder is False'
)
self.use_bias = use_bias
self.use_qkv_bias = use_qkv_bias
self.num_multi_query_heads = num_multi_query_heads
self.is_gated_mlp = is_gated_mlp
self.is_rotary_emb = is_rotary_emb
self.num_experts = num_experts
self.use_final_layernorm = use_final_layernorm
self.layernorm_epsilon = layernorm_epsilon
self.parallel_output = parallel_output
self.checkpoint_activations = checkpoint_activations
self.checkpoint_num_layers = checkpoint_num_layers
self.checkpoint_skip_layers = checkpoint_skip_layers
assert checkpoint_skip_layers <= num_layers - checkpoint_num_layers, 'checkpoint_skip_layers too large. Please consider removing checkpoint_activations.'
self.max_sequence_length = max_sequence_length
self.layernorm_order = layernorm_order
self.row_parallel_linear_final_bias = row_parallel_linear_final_bias
self.hooks = copy.copy(hooks) # hooks will be updated each forward
object.__setattr__(
self, 'transformer',
self) # to give the default hooks the same api as outer hooks
# create embedding parameters
self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
if vocab_size < 1000:
self.word_embeddings = torch.nn.Embedding(vocab_size,
hidden_size,
dtype=params_dtype,
device=device)
torch.nn.init.normal_(self.word_embeddings.weight,
mean=0.0,
std=init_method_std)
else:
self.word_embeddings = VocabParallelEmbedding(
num_embeddings=vocab_size,
embedding_dim=hidden_size,
params_dtype=params_dtype,
skip_init=skip_init,
device=device)
if self.is_rotary_emb:
from sat.model.position_embedding.triton_rotary_embeddings import \
FastRotaryEmbedding
self.position_embeddings = FastRotaryEmbedding(hidden_size //
num_attention_heads)
else:
self.position_embeddings = torch.nn.Embedding(
max_sequence_length, hidden_size)
torch.nn.init.normal_(self.position_embeddings.weight,
mean=0.0,
std=init_method_std)
# create all layers
if init_method is None:
self.output_layer_init_method = scaled_init_method(
init_method_std, num_layers)
self.init_method = unscaled_init_method(init_method_std)
else:
self.output_layer_init_method = init_method
self.init_method = init_method
def get_layer(layer_id):
return BaseTransformerLayer(
hidden_size,
num_attention_heads,
attention_dropout_prob,
output_dropout_prob,
layernorm_epsilon,
self.init_method,
layer_id,
inner_hidden_size=inner_hidden_size,
hidden_size_per_attention_head=hidden_size_per_attention_head,
cross_hidden_size_per_attention_head=
cross_hidden_size_per_attention_head,
output_layer_init_method=self.output_layer_init_method,
is_decoder=self.is_decoder,
cross_attn_hidden_size=cross_attn_hidden_size,
layernorm_order=layernorm_order,
layernorm=layernorm,
use_bias=use_bias,
use_qkv_bias=use_qkv_bias,
num_multi_query_heads=num_multi_query_heads,
cross_num_multi_query_heads=cross_num_multi_query_heads,
row_parallel_linear_final_bias=row_parallel_linear_final_bias,
drop_path=drop_path,
activation_func=activation_func,
is_gated_mlp=is_gated_mlp,
num_experts=num_experts,
hooks=self.hooks,
transformer_pointer=self,
params_dtype=params_dtype,
skip_init=skip_init,
device=device)
self.layers = torch.nn.ModuleList(
[get_layer(layer_id) for layer_id in range(num_layers)])
# Final layer norm before output.
if use_final_layernorm:
self.final_layernorm = layernorm(hidden_size,
eps=layernorm_epsilon)
def forward(self,
input_ids,
position_ids,
attention_mask,
*,
output_hidden_states=False,
**kw_args):
# sanity check
assert len(input_ids.shape) >= 2
batch_size, query_length = input_ids.shape[:2]
if attention_mask is None:
# Definition: None means full attention
attention_mask = torch.ones(1, 1, device=input_ids.device)
elif isinstance(attention_mask, int) and (attention_mask < 0):
# Definition: -1 means lower triangular attention mask
attention_mask = torch.ones(query_length,
query_length,
device=input_ids.device).tril()
attention_mask = attention_mask.type_as(next(self.parameters()))
assert len(attention_mask.shape) == 2 or \
len(attention_mask.shape) == 4 and attention_mask.shape[1] == 1
# initial output_cross_layer might be generated by word/position_embedding_forward
output_cross_layer = {}
# embedding part
if 'word_embedding_forward' in self.hooks:
hidden_states = self.hooks['word_embedding_forward'](
input_ids, output_cross_layer=output_cross_layer, **kw_args)
else: # default
hidden_states = HOOKS_DEFAULT['word_embedding_forward'](
self,
input_ids,
output_cross_layer=output_cross_layer,
**kw_args)
# handle position embedding
if 'position_embedding_forward' in self.hooks:
position_embeddings = self.hooks['position_embedding_forward'](
position_ids, output_cross_layer=output_cross_layer, **kw_args)
else:
assert len(position_ids.shape) <= 2
assert position_ids.shape[-1] == hidden_states.shape[1], (
position_ids.shape, hidden_states.shape)
position_embeddings = HOOKS_DEFAULT['position_embedding_forward'](
self,
position_ids,
output_cross_layer=output_cross_layer,
**kw_args)
if position_embeddings is not None:
hidden_states = hidden_states + position_embeddings
hidden_states = self.embedding_dropout(hidden_states)
output_per_layers = []
if self.checkpoint_activations:
# define custom_forward for checkpointing
def custom(start, end, kw_args_index, cross_layer_index):
def custom_forward(*inputs):
layers_ = self.layers[start:end]
x_, mask = inputs[0], inputs[1]
# recover kw_args and output_cross_layer
flat_inputs = inputs[2:]
kw_args, output_cross_layer = {}, {}
for k, idx in kw_args_index.items():
kw_args[k] = flat_inputs[idx]
for k, idx in cross_layer_index.items():
output_cross_layer[k] = flat_inputs[idx]
# -----------------
output_per_layers_part = []
for i, layer in enumerate(layers_):
output_this_layer_obj, output_cross_layer_obj = {}, {}
if 'layer_forward' in self.hooks:
layer_ret = self.hooks['layer_forward'](
x_,
mask,
layer_id=layer.layer_id,
**kw_args,
position_ids=position_ids,
**output_cross_layer,
output_this_layer=output_this_layer_obj,
output_cross_layer=output_cross_layer_obj)
else:
layer_ret = layer(
x_,
mask,
layer_id=layer.layer_id,
**kw_args,
position_ids=position_ids,
**output_cross_layer,
output_this_layer=output_this_layer_obj,
output_cross_layer=output_cross_layer_obj)
if isinstance(layer_ret, tuple):
layer_ret = layer_ret[0] # for legacy API
x_, output_this_layer, output_cross_layer = layer_ret, output_this_layer_obj, output_cross_layer_obj
if output_hidden_states:
output_this_layer['hidden_states'] = x_
output_per_layers_part.append(output_this_layer)
# flatten keyword outputs for re-aggregation
flat_outputs = []
for output_this_layer in output_per_layers_part:
for k in output_this_layer:
# TODO add warning for depth>=2 grad tensors
flat_outputs.append(output_this_layer[k])
output_this_layer[k] = len(flat_outputs) - 1
for k in output_cross_layer:
flat_outputs.append(output_cross_layer[k])
output_cross_layer[k] = len(flat_outputs) - 1
# --------------------
return (x_, output_per_layers_part, output_cross_layer,
*flat_outputs)
return custom_forward
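# Note on the flattening above: the activation-checkpoint wrapper only tracks
# gradients through the flat tuple of tensors it receives and returns, so the
# dict-valued kw_args / per-layer outputs are spread into flat_inputs /
# flat_outputs with their positions recorded, then re-assembled after the
# checkpointed call below.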
# prevent losing requires_grad in checkpointing.
# To save memory when only finetuning the final layers, don't use checkpointing.
if self.training:
hidden_states.requires_grad_(True)
l, num_layers = 0, len(self.layers)
chunk_length = self.checkpoint_num_layers
output_this_layer = []
while l < num_layers:
args = [hidden_states, attention_mask]
# flatten kw_args and output_cross_layer
flat_inputs, kw_args_index, cross_layer_index = [], {}, {}
for k, v in kw_args.items():
flat_inputs.append(v)
kw_args_index[k] = len(flat_inputs) - 1
for k, v in output_cross_layer.items():
flat_inputs.append(v)
cross_layer_index[k] = len(flat_inputs) - 1
# --------------------
if l + self.checkpoint_skip_layers >= num_layers:
# no checkpointing
hidden_states, output_per_layers_part, output_cross_layer, *flat_outputs = \
custom(l, l + chunk_length, kw_args_index, cross_layer_index)(*args, *flat_inputs)
else:
hidden_states, output_per_layers_part, output_cross_layer, *flat_outputs = \
checkpoint(custom(l, l + chunk_length, kw_args_index, cross_layer_index), *args, *flat_inputs)
# recover output_per_layers_part, output_cross_layer
for output_this_layer in output_per_layers_part:
for k in output_this_layer:
output_this_layer[k] = flat_outputs[
output_this_layer[k]]
for k in output_cross_layer:
output_cross_layer[k] = flat_outputs[output_cross_layer[k]]
# --------------------
output_per_layers.extend(output_per_layers_part)
l += chunk_length
else:
output_this_layer = []
for i, layer in enumerate(self.layers):
args = [hidden_states, attention_mask]
output_this_layer_obj, output_cross_layer_obj = {}, {}
if 'layer_forward' in self.hooks: # customized layer_forward
layer_ret = self.hooks['layer_forward'](
*args,
layer_id=torch.tensor(i),
**kw_args,
position_ids=position_ids,
**output_cross_layer,
output_this_layer=output_this_layer_obj,
output_cross_layer=output_cross_layer_obj)
else:
layer_ret = layer(
*args,
layer_id=torch.tensor(i),
**kw_args,
position_ids=position_ids,
**output_cross_layer,
output_this_layer=output_this_layer_obj,
output_cross_layer=output_cross_layer_obj)
if isinstance(layer_ret, tuple):
layer_ret = layer_ret[0] # for legacy API
hidden_states, output_this_layer, output_cross_layer = layer_ret, output_this_layer_obj, output_cross_layer_obj
if output_hidden_states:
output_this_layer['hidden_states'] = hidden_states
output_per_layers.append(output_this_layer)
# Final layer norm.
if self.use_final_layernorm:
logits = self.final_layernorm(hidden_states)
else:
logits = hidden_states
logits = copy_to_model_parallel_region(logits)
if 'final_forward' in self.hooks:
logits_parallel = self.hooks['final_forward'](
logits, **kw_args, parallel_output=self.parallel_output)
else:
logits_parallel = HOOKS_DEFAULT['final_forward'](
self, logits, **kw_args, parallel_output=self.parallel_output)
outputs = [logits_parallel]
outputs.extend(output_per_layers)
return outputs
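# Hedged usage note for the forward above (assumes SAT's model-parallel state
# has already been initialized; shapes and sizes are made-up examples):
# attention_mask=None selects full attention, a negative int selects a
# lower-triangular (causal) mask.
#
#   model = GCBaseTransformer(num_layers=2, vocab_size=100, hidden_size=64,
#                             num_attention_heads=4, max_sequence_length=128)
#   ids = torch.randint(0, 100, (1, 16))
#   pos = torch.arange(16)[None]
#   logits, *per_layer = model(ids, pos, None)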
share_cache_args:
disable_ref : True
num_vis_img: 4
vis_ddpm: True
eval_interval_list: [1, 50, 100, 1000]
save_interval_list: [2000]
args:
checkpoint_activations: False # whether to use gradient checkpointing
model_parallel_size: 1
experiment_name: lora-disney
mode: finetune
load: ""
no_load_rng: True
train_iters: 10000000000 # Suggest more than 1000 iterations for LoRA; for SFT, 500 is enough
eval_iters: 100000000
eval_interval: 1000000000000
eval_batch_size: 1
save: ./
save_interval: 1000
log_interval: 20
train_data: [ "disney" ] # Train data path
valid_data: [ "disney" ] # Validation data path, can be the same as train_data(not recommended)
split: 1,0,0
num_workers: 2
force_train: True
only_log_video_latents: True
deepspeed:
# Minimum of 16 videos per batch across ALL GPUs; this setting is for 8 x A100 GPUs
train_micro_batch_size_per_gpu: 1
gradient_accumulation_steps: 1
steps_per_print: 50
gradient_clipping: 0.1
zero_optimization:
stage: 2
cpu_offload: true
contiguous_gradients: false
overlap_comm: true
reduce_scatter: true
reduce_bucket_size: 1000000000
allgather_bucket_size: 1000000000
load_from_fp32_weights: false
zero_allow_untested_optimizer: true
bf16:
enabled: True # For CogVideoX-2B set to False; for CogVideoX-5B set to True
fp16:
enabled: False # For CogVideoX-2B set to True; for CogVideoX-5B set to False
loss_scale: 0
loss_scale_window: 400
hysteresis: 2
min_loss_scale: 1
model:
scale_factor: 0.7
disable_first_stage_autocast: true
log_keys:
- txt
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
params:
num_idx: 1000
quantize_c_noise: False
weighting_config:
target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.VideoScaling
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
params:
shift_scale: 1.0 # different from cogvideox_2b_infer.yaml
network_config:
target: extra_models.dit_res_adapter.SMALLDiffusionTransformer
params:
time_embed_dim: 512
elementwise_affine: True
num_frames: 64
time_compressed_rate: 4
# latent_width: 90
# latent_height: 60
latent_width: 512
latent_height: 512
num_layers: 42 # different from cogvideox_2b_infer.yaml
patch_size: 2
in_channels: 16
out_channels: 16
hidden_size: 3072 # different from cogvideox_2b_infer.yaml
adm_in_channels: 256
num_attention_heads: 48 # different from cogvideox_2b_infer.yaml
transformer_args:
checkpoint_activations: False
vocab_size: 1
max_sequence_length: 64
layernorm_order: pre
skip_init: false
model_parallel_size: 1
is_decoder: false
modules:
pos_embed_config:
target: extra_models.dit_res_adapter.ScaleCropRotary3DPositionEmbeddingMixin # different from cogvideox_2b_infer.yaml
params:
hidden_size_head: 64
text_length: 226
lora_config:
target: extra_models.dit_res_adapter.ResLoraMixin
params:
r: 128
patch_embed_config:
target: dit_video_concat.ImagePatchEmbeddingMixin
params:
text_hidden_size: 4096
adaln_layer_config:
target: dit_video_concat.AdaLNMixin
params:
qk_ln: True
final_layer_config:
target: dit_video_concat.FinalLayerMixin
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
- is_trainable: false
input_key: txt
ucg_rate: 0.1
target: sgm.modules.encoders.modules.FrozenT5Embedder
params:
model_dir: "google/t5-v1_1-xxl"
max_length: 226
first_stage_config:
target: vae_modules.autoencoder.VideoAutoencoderInferenceWrapper
params:
cp_size: 1
ckpt_path: "checkpoints/3d-vae.pt"
ignore_keys: [ 'loss' ]
loss_config:
target: torch.nn.Identity
regularizer_config:
target: vae_modules.regularizers.DiagonalGaussianRegularizer
encoder_config:
target: vae_modules.cp_enc_dec.SlidingContextParallelEncoder3D
params:
double_z: true
z_channels: 16
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [ 1, 2, 2, 4 ]
attn_resolutions: [ ]
num_res_blocks: 3
dropout: 0.0
gather_norm: True
decoder_config:
target: vae_modules.cp_enc_dec.ContextParallelDecoder3D
params:
double_z: True
z_channels: 16
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [ 1, 2, 2, 4 ]
attn_resolutions: [ ]
num_res_blocks: 3
dropout: 0.0
gather_norm: False
loss_fn_config:
target: sgm.modules.diffusionmodules.loss.VideoDiffusionLoss
params:
offset_noise_level: 0
sigma_sampler_config:
target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
params:
uniform_sampling: True
num_idx: 1000
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
params:
shift_scale: 1.0 # different from cogvideox_2b_infer.yaml
sampler_config:
target: sgm.modules.diffusionmodules.sampling.VPSDEDPMPP2MSampler
params:
num_steps: 51
verbose: True
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
params:
shift_scale: 1.0 # different from cogvideox_2b_infer.yaml
guider_config:
target: sgm.modules.diffusionmodules.guiders.DynamicCFG
params:
# TODO check this cfg
scale: 8
exp: 5
num_steps: 51
custom_args:
reload: ""
share_cache_args:
sample_ref_noise_step: 675
time_size_embedding: True
args:
checkpoint_activations: True # using gradient checkpointing
model_parallel_size: 1
experiment_name: lora-disney
mode: finetune
load: "" # This is for Full model without lora adapter " # This is for Full model without lora adapter
no_load_rng: True
train_iters: 100000 # Suggest more than 1000 iterations for LoRA; for SFT, 500 is enough
eval_iters: 100000000
eval_interval: [1, 200]
eval_batch_size: 1
save:
# for debug
save_interval: 250
log_interval: 5
train_data: [ "disney" ] # Train data path
valid_data: [ "disney" ] # Validation data path, can be the same as train_data(not recommended)
split: 1,0,0
num_workers: 1
force_train: True
only_log_video_latents: True
deepspeed:
# Minimum of 16 videos per batch across ALL GPUs; this setting is for 8 x A100 GPUs
train_micro_batch_size_per_gpu: 1
gradient_accumulation_steps: 1
steps_per_print: 50
gradient_clipping: 0.1
zero_optimization:
stage: 2
cpu_offload: true
contiguous_gradients: false
overlap_comm: true
reduce_scatter: true
reduce_bucket_size: 1000000000
allgather_bucket_size: 1000000000
load_from_fp32_weights: false
zero_allow_untested_optimizer: true
bf16:
enabled: True # For CogVideoX-2B set to False; for CogVideoX-5B set to True
fp16:
enabled: False # For CogVideoX-2B set to True; for CogVideoX-5B set to False
loss_scale: 0
loss_scale_window: 400
hysteresis: 2
min_loss_scale: 1
model:
scale_factor: 1.15258426
disable_first_stage_autocast: true
log_keys:
- txt
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
params:
num_idx: 1000
quantize_c_noise: False
weighting_config:
target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.VideoScaling
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
params:
shift_scale: 3.0 # different from cogvideox_2b_infer.yaml
network_config:
target: extra_models.dit_res_adapter.SMALLDiffusionTransformer
params:
time_embed_dim: 512
elementwise_affine: True
num_frames: 64
time_compressed_rate: 4
# latent_width: 90
# latent_height: 60
latent_width: 512
latent_height: 512
num_layers: 30 # different from cogvideox_2b_infer.yaml
patch_size: 2
in_channels: 16
out_channels: 16
hidden_size: 1920 # different from cogvideox_2b_infer.yaml
adm_in_channels: 256
num_attention_heads: 30 # different from cogvideox_2b_infer.yaml
transformer_args:
checkpoint_activations: True
vocab_size: 1
max_sequence_length: 64
layernorm_order: pre
skip_init: false
model_parallel_size: 1
is_decoder: false
modules:
pos_embed_config:
target: extra_models.dit_res_adapter.ScaleCropRotary3DPositionEmbeddingMixin # different from cogvideox_2b_infer.yaml
params:
hidden_size_head: 64
text_length: 226
patch_embed_config:
target: dit_video_concat.ImagePatchEmbeddingMixin
params:
text_hidden_size: 4096
adaln_layer_config:
target: dit_video_concat.AdaLNMixin
params:
qk_ln: True
final_layer_config:
target: dit_video_concat.FinalLayerMixin
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
- is_trainable: false
input_key: txt
ucg_rate: 0.1
target: sgm.modules.encoders.modules.FrozenT5Embedder
params:
model_dir: "google/t5-v1_1-xxl"
max_length: 226
first_stage_config:
target: vae_modules.autoencoder.VideoAutoencoderInferenceWrapper
params:
cp_size: 1
ckpt_path: "checkpoints/3d-vae.pt"
ignore_keys: [ 'loss' ]
loss_config:
target: torch.nn.Identity
regularizer_config:
target: vae_modules.regularizers.DiagonalGaussianRegularizer
encoder_config:
target: vae_modules.cp_enc_dec.SlidingContextParallelEncoder3D
params:
double_z: true
z_channels: 16
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [ 1, 2, 2, 4 ]
attn_resolutions: [ ]
num_res_blocks: 3
dropout: 0.0
gather_norm: True
decoder_config:
target: vae_modules.cp_enc_dec.ContextParallelDecoder3D
params:
double_z: True
z_channels: 16
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [ 1, 2, 2, 4 ]
attn_resolutions: [ ]
num_res_blocks: 3
dropout: 0.0
gather_norm: False
loss_fn_config:
target: flow_video.FlowVideoDiffusionLoss
params:
offset_noise_level: 0
sigma_sampler_config:
target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
params:
uniform_sampling: False
num_idx: 1000
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
params:
shift_scale: 1.0 # different from cogvideox_2b_infer.yaml
sampler_config:
target: sgm.modules.diffusionmodules.sampling.CascadeVPSDEDPMPP2MSampler
params:
num_steps: 50
verbose: True
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
params:
shift_scale: 1.0 # different from cogvideox_2b_infer.yaml
guider_config:
target: sgm.modules.diffusionmodules.guiders.DynamicCFG
params:
scale: 6
exp: 5
num_steps: 50
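# For reference: configs like the two above are consumed through the --base and
# --second options of arguments.py; the demo notebook below passes, for example,
#   --base flashvideo/configs/stage1.yaml --second flashvideo/configs/stage2.yaml
# together with --inf-ckpt / --inf-ckpt2 for the two stage checkpoints.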
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Stage-I ckpt path \n",
"stage_1_path = \"./checkpoints/stage1.pt\"\n",
"stage_2_path = \"./checkpoints/stage2.pt\"\n",
"save_dir=\"vis_270p_1080p\"\n",
"# 2 ~ 3\n",
"shift_t = 2.5\n",
"# 4 ~ 6\n",
"sample_step = 5\n",
"# 10 ~ 13\n",
"cfg_second = 13\n",
" # 650 ~ 750\n",
"deg_latent_strength=675\n",
"# stage_1_hw \n",
"\n",
"#TODO Stage I CFG here\n",
"cfg_first = 8\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"current_directory = os.getcwd()\n",
"os.chdir(os.path.dirname(current_directory))\n",
"new_directory = os.getcwd()\n",
"print(f\"working directory: {new_directory}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"import os\n",
"import argparse\n",
"import torch\n",
"import numpy as np\n",
"import copy\n",
"\n",
"from sat.model.base_model import get_model\n",
"from arguments import get_args\n",
"from torchvision.io.video import write_video\n",
"\n",
"from flow_video import FlowEngine\n",
"from diffusion_video import SATVideoDiffusionEngine\n",
"\n",
"import os\n",
"from utils import disable_all_init, decode, prepare_input, save_memory_encode_first_stage, save_mem_decode, seed_everything\n",
"disable_all_init()\n",
"\n",
"\n",
"def init_model(model, second_model, args, second_args):\n",
" share_cache = dict()\n",
" second_share_cache = dict()\n",
" if hasattr(args, 'share_cache_config'):\n",
" for k, v in args.share_cache_config.items():\n",
" share_cache[k] = v\n",
" if hasattr(second_args, 'share_cache_config'):\n",
" for k, v in second_args.share_cache_config.items():\n",
" second_share_cache[k] = v\n",
"\n",
" for n, m in model.named_modules():\n",
" m.share_cache = share_cache \n",
" if hasattr(m, \"register_new_modules\"):\n",
" m.register_new_modules()\n",
" for n, m in second_model.named_modules():\n",
" m.share_cache = second_share_cache \n",
" if hasattr(m, \"register_new_modules\"):\n",
" m.register_new_modules() \n",
"\n",
" weight_path = args.inf_ckpt\n",
" weight = torch.load(weight_path, map_location=\"cpu\")\n",
" if \"model.diffusion_model.mixins.pos_embed.freqs_sin\" in weight[\"module\"]:\n",
" del weight[\"module\"][\"model.diffusion_model.mixins.pos_embed.freqs_sin\"]\n",
" del weight[\"module\"][\"model.diffusion_model.mixins.pos_embed.freqs_cos\"]\n",
" msg = model.load_state_dict(weight[\"module\"], strict=False)\n",
" print(msg)\n",
" second_weight_path = args.inf_ckpt2\n",
" second_weight = torch.load(second_weight_path, map_location=\"cpu\")\n",
" if \"model.diffusion_model.mixins.pos_embed.freqs_sin\" in second_weight[\"module\"]:\n",
" del second_weight[\"module\"][\"model.diffusion_model.mixins.pos_embed.freqs_sin\"]\n",
" del second_weight[\"module\"][\"model.diffusion_model.mixins.pos_embed.freqs_cos\"]\n",
" second_msg = second_model.load_state_dict(second_weight[\"module\"], strict=False)\n",
" print(second_msg)\n",
"\n",
"def get_first_results(model, text, num_frames, H, W, neg_prompt=None):\n",
" \"\"\"Get first Stage results.\n",
"\n",
" Args:\n",
" model (nn.Module): first stage model.\n",
" text (str): text prompt\n",
" num_frames (int): number of frames\n",
" H (int): height of the first stage results\n",
" W (int): width of the first stage results\n",
" neg_prompt (str): negative prompt\n",
"\n",
" Returns:\n",
" Tensor: first stage video.\n",
" \"\"\"\n",
" device = 'cuda'\n",
" T = 1 + (num_frames - 1) // 4\n",
" F = 8\n",
" motion_text_prefix = [\n",
" 'very low motion,',\n",
" 'low motion,',\n",
" 'medium motion,',\n",
" 'high motion,',\n",
" 'very high motion,',\n",
" ]\n",
" pos_prompt = \"\"\n",
" if neg_prompt is None:\n",
" neg_prompt = \"\"\n",
" with torch.no_grad():\n",
" model.to('cuda')\n",
" input_negative_prompt = motion_text_prefix[\n",
" 0] + ', ' + motion_text_prefix[1] + neg_prompt\n",
" c, uc = prepare_input(text,\n",
" model,\n",
" T,\n",
" negative_prompt=input_negative_prompt,\n",
" pos_prompt=pos_prompt)\n",
" with torch.no_grad(), torch.amp.autocast(enabled=True,\n",
" device_type='cuda',\n",
" dtype=torch.bfloat16):\n",
" samples_z = model.sample(\n",
" c,\n",
" uc=uc,\n",
" batch_size=1,\n",
" shape=(T, 16, H // F, W // F),\n",
" num_steps=model.share_cache.get('first_sample_step', None),\n",
" )\n",
" samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous()\n",
"\n",
" model.to('cpu')\n",
" torch.cuda.empty_cache()\n",
" first_stage_model = model.first_stage_model\n",
" first_stage_model = first_stage_model.to(device)\n",
"\n",
" latent = 1.0 / model.scale_factor * samples_z\n",
"\n",
" samples = decode(first_stage_model, latent)\n",
" model.to('cpu')\n",
" return samples\n",
"def get_second_results(model, text, first_stage_samples, num_frames):\n",
" \"\"\"Get second Stage results.\n",
"\n",
" Args:\n",
" model (nn.Module): second stage model.\n",
" text (str): text prompt\n",
" first_stage_samples (Tensor): first stage results\n",
" num_frames (int): number of frames\n",
" Returns:\n",
" Tensor: second stage results.\n",
" \"\"\"\n",
"\n",
" t, h, w, c = first_stage_samples.shape\n",
" first_stage_samples = first_stage_samples[:num_frames]\n",
" first_stage_samples = (first_stage_samples / 255.)\n",
" first_stage_samples = (first_stage_samples - 0.5) / 0.5\n",
"\n",
" target_size = model.share_cache.get('target_size', None)\n",
" if target_size is None:\n",
" upscale_factor = model.share_cache.get('upscale_factor', 8)\n",
" H = int(h * upscale_factor) // 16 * 16\n",
" W = int(w * upscale_factor) // 16 * 16\n",
" else:\n",
" H, W = target_size\n",
" H = H // 16 * 16\n",
" W = W // 16 * 16\n",
"\n",
" first_stage_samples = first_stage_samples.permute(0, 3, 1, 2).to('cuda')\n",
"\n",
" ref_x = torch.nn.functional.interpolate(first_stage_samples,\n",
" size=(H, W),\n",
" mode='bilinear',\n",
" align_corners=False,\n",
" antialias=True)\n",
" ref_x = ref_x[:num_frames][None]\n",
"\n",
" ref_x = ref_x.permute(0, 2, 1, 3, 4).contiguous()\n",
"\n",
" first_stage_model = model.first_stage_model\n",
" print(f'start encoding first stage results to high resolution')\n",
" with torch.no_grad():\n",
" first_stage_dtype = next(model.first_stage_model.parameters()).dtype\n",
" model.first_stage_model.cuda()\n",
" ref_x = save_memory_encode_first_stage(\n",
" ref_x.contiguous().to(first_stage_dtype).cuda(), model)\n",
"\n",
" ref_x = ref_x.permute(0, 2, 1, 3, 4).contiguous()\n",
" ref_x = ref_x.to(model.dtype)\n",
" print(f'finish encoding first stage results, and starting stage II')\n",
"\n",
" device = 'cuda'\n",
"\n",
" model.to(device)\n",
"\n",
" pos_prompt = ''\n",
" input_negative_prompt = \"\"\n",
"\n",
" c, uc = prepare_input(text,\n",
" model,\n",
" num_frames,\n",
" negative_prompt=input_negative_prompt,\n",
" pos_prompt=pos_prompt)\n",
"\n",
" T = 1 + (num_frames - 1) // 4\n",
" F = 8\n",
" with torch.no_grad(), torch.amp.autocast(enabled=True,\n",
" device_type='cuda',\n",
" dtype=torch.bfloat16):\n",
" samples_z = model.sample(\n",
" ref_x,\n",
" c,\n",
" uc=uc,\n",
" batch_size=1,\n",
" shape=(T, 16, H // F, W // F),\n",
" num_steps=model.share_cache.get('sample_step', 5),\n",
" method='euler',\n",
" cfg=model.share_cache.get('cfg', 7.5),\n",
" )\n",
" samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous()\n",
"\n",
" model.to('cpu')\n",
" torch.cuda.empty_cache()\n",
" first_stage_model = model.first_stage_model\n",
" first_stage_model = first_stage_model.to(device)\n",
"\n",
" latent = 1.0 / model.scale_factor * samples_z\n",
" print(f'start spatiotemporal slice decoding')\n",
" samples = save_mem_decode(first_stage_model, latent)\n",
" print(f'finish spatiotemporal slice decoding')\n",
" model.to('cpu')\n",
" return samples\n",
" "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"\n",
"\n",
"\n",
"os.environ[\"LOCAL_RANK\"] = \"0\"\n",
"os.environ[\"WORLD_SIZE\"] = \"1\"\n",
"os.environ[\"RANK\"] = \"0\"\n",
"os.environ[\"MASTER_ADDR\"] = \"0.0.0.0\"\n",
"os.environ[\"MASTER_PORT\"] = \"12345\"\n",
"\n",
"py_parser = argparse.ArgumentParser(add_help=False)\n",
"args_list = [\n",
" \"--base\", \"flashvideo/configs/stage1.yaml\",\n",
" \"--second\", \"flashvideo/configs/stage2.yaml\",\n",
" \"--inf-ckpt\", stage_1_path,\n",
" \"--inf-ckpt2\", stage_2_path,\n",
"]\n",
"known, args_list = py_parser.parse_known_args(args=args_list)\n",
"second_args_list = copy.deepcopy(args_list)\n",
"\n",
"\n",
"args = get_args(args_list)\n",
"args = argparse.Namespace(**vars(args), **vars(known))\n",
"del args.deepspeed_config\n",
"args.model_config.first_stage_config.params.cp_size = 1\n",
"args.model_config.network_config.params.transformer_args.model_parallel_size = 1\n",
"args.model_config.network_config.params.transformer_args.checkpoint_activations = False\n",
"args.model_config.loss_fn_config.params.sigma_sampler_config.params.uniform_sampling = False\n",
"\n",
"second_args_list[1] = args.second[0]\n",
"second_args = get_args(second_args_list)\n",
"second_args = argparse.Namespace(**vars(second_args), **vars(known))\n",
"del second_args.deepspeed_config\n",
"second_args.model_config.first_stage_config.params.cp_size = 1\n",
"second_args.model_config.network_config.params.transformer_args.model_parallel_size = 1\n",
"second_args.model_config.network_config.params.transformer_args.checkpoint_activations = False\n",
"second_args.model_config.loss_fn_config.params.sigma_sampler_config.params.uniform_sampling = False\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model_cls=SATVideoDiffusionEngine\n",
"second_model_cls=FlowEngine\n",
"local_rank = int(os.environ.get(\"LOCAL_RANK\", 0))\n",
"torch.cuda.set_device(local_rank)\n",
"\n",
"second_model = get_model(second_args, second_model_cls)\n",
"\n",
"model = get_model(args, model_cls)\n",
" \n",
"init_model(model, second_model, args, second_args )\n",
" \n",
"model.eval()\n",
"second_model.eval()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for n, m in model.named_modules():\n",
" if hasattr(m, \"merge_lora\"):\n",
" m.merge_lora()\n",
" print(f\"merge lora of {n}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"num_frames = 49\n",
"second_num_frames = 49 \n",
"\n",
"stage_1_hw = (270, 480) \n",
"stage_2_hw = (1080, 1920) \n",
"\n",
"# make sure all can be divided by 16\n",
"stage_1_hw = (stage_1_hw[0] // 16 * 16, stage_1_hw[1] // 16 * 16)\n",
"stage_2_hw = (stage_2_hw[0] // 16 * 16, stage_2_hw[1] // 16 * 16)\n",
"\n",
"sample_func = model.sample\n",
"T, H, W, C, F = num_frames, stage_1_hw[0], stage_1_hw[1], args.latent_channels, 8\n",
"S_T, S_H, S_W, S_C, S_F = second_num_frames, stage_2_hw[0], stage_2_hw[1], args.latent_channels, 8\n",
"\n",
"\n",
" \n",
"seed_everything(0)\n",
"\n",
"text = \" Sunny day, The camera smoothly pushes in through an ornate garden archway, delicately adorned with climbing ivy. \\\n",
" Beyond the archway, a secret, tranquil garden is revealed, brimming with a vibrant array of blooming flowers \\\n",
" in a myriad of colors. A beautiful young woman with long wavy brown hair, she is smile to the camera , \\\n",
" wearing a red hat sits holding a dog , the red hat has rich fabric texture \\\n",
" wearing black pleated skirt and yellow sweater \"\n",
"\n",
"\n",
"neg_text = \"\"\n",
"\n",
"if os.path.exists(save_dir) is False:\n",
" os.makedirs(save_dir)\n",
"enu_index = \"1\"\n",
"model.share_cache[\"cfg\"] = cfg_first\n",
"\n",
"first_stage_samples = get_first_results(model, text, num_frames, H, W, neg_text)\n",
"\n",
"print(f\"save to {save_dir}/{enu_index}_num_frame_{num_frames}.mp4\")\n",
"write_video(filename=f'./{save_dir}/{enu_index}_num_frame_{num_frames}.mp4', \n",
" fps=8, \n",
" video_array= first_stage_samples, \n",
" options = { 'crf': '14' })\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"second_num_frames = 49\n",
"second_model.share_cache[\"ref_noise_step\"] = deg_latent_strength\n",
"second_model.share_cache[\"sample_ref_noise_step\"] = deg_latent_strength\n",
"second_model.share_cache.pop(\"ref_noise_step_range\", None)\n",
"second_model.share_cache[\"target_size\"] = stage_2_hw\n",
"second_model.share_cache[\"shift_t\"] = shift_t\n",
"second_model.share_cache[\"sample_step\"] = sample_step\n",
"second_model.share_cache[\"cfg\"] = cfg_second\n",
"post_fix = f'''noise_{second_model.share_cache[\"ref_noise_step\"]}_step_{second_model.share_cache[\"sample_step\"]}_cfg_{second_model.share_cache[\"cfg\"]}_shift_{second_model.share_cache[\"shift_t\"]}_size_{stage_2_hw[0]}x{stage_2_hw[1]}'''\n",
"second_model.share_cache[\"time_size_embedding\"] = True\n",
"second_stage_samples = get_second_results(second_model, \n",
" text, \n",
" first_stage_samples, \n",
" second_num_frames)\n",
"\n",
"print(f\"save to {save_dir}/{enu_index}_num_frame_{num_frames}_{post_fix}.mp4\")\n",
"write_video(filename=f'./{save_dir}/{enu_index}_num_frame_{num_frames}_{post_fix}_second.mp4', \n",
" fps=8, \n",
" video_array= second_stage_samples.cpu(), \n",
" options = { 'crf': '14' })\n",
"\n",
"\n",
"# save joint video \n",
"part_first_stage = first_stage_samples[:second_num_frames]\n",
"\n",
"target_h, target_w = second_stage_samples.shape[1], second_stage_samples.shape[2]\n",
"part_first_stage = torch.nn.functional.interpolate(part_first_stage.permute(0, 3, 1, 2).contiguous(),\n",
" size=(target_h, target_w),\n",
" mode=\"bilinear\",\n",
" align_corners=False, \n",
" antialias=True)\n",
"part_first_stage = part_first_stage.permute(0, 2, 3, 1).contiguous()\n",
"\n",
"\n",
"joint_video = torch.cat([part_first_stage.cpu(), second_stage_samples.cpu()], dim=-2)\n",
"print(f'./{save_dir}/{enu_index}_num_frame_{num_frames}_{post_fix}_joint.mp4')\n",
"write_video(filename=f'./{save_dir}/{enu_index}_num_frame_{num_frames}_{post_fix}_joint.mp4',\n",
" fps=8,\n",
" video_array=joint_video.cpu(),\n",
" options={'crf': '15'}) \n"
]
}
],
"metadata": {
"fileId": "c6eed2be-3101-492e-a984-783ecbc70a34",
"filePath": "/mnt/bn/foundation-ads/shilong/conda/code/cogvideo-5b/sat/demo_ab.ipynb",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
import gc
import math
import os
import random
from typing import Any, Dict, List, Tuple, Union
import torch
import torch.nn.functional as F
from omegaconf import ListConfig
from sgm.modules import UNCONDITIONAL_CONFIG
from sgm.modules.autoencoding.temporal_ae import VideoDecoder
from sgm.modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
from sgm.util import (default, disabled_train, get_obj_from_str,
instantiate_from_config, log_txt_as_img)
from torch import nn
from sat import mpu
from sat.helpers import print_rank0
from sat.model.finetune.lora2 import merge_linear_lora
class SATVideoDiffusionEngine(nn.Module):
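    """Video diffusion engine: builds the DiT network, denoiser, sampler, conditioner,
    loss function and first-stage VAE from ``args.model_config`` (sgm-style configs)."""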
def __init__(self, args, **kwargs):
super().__init__()
model_config = args.model_config
# model args preprocess
log_keys = model_config.get('log_keys', None)
input_key = model_config.get('input_key', 'mp4')
network_config = model_config.get('network_config', None)
network_wrapper = model_config.get('network_wrapper', None)
denoiser_config = model_config.get('denoiser_config', None)
sampler_config = model_config.get('sampler_config', None)
conditioner_config = model_config.get('conditioner_config', None)
first_stage_config = model_config.get('first_stage_config', None)
loss_fn_config = model_config.get('loss_fn_config', None)
scale_factor = model_config.get('scale_factor', 1.0)
latent_input = model_config.get('latent_input', False)
disable_first_stage_autocast = model_config.get(
'disable_first_stage_autocast', False)
        no_cond_log = model_config.get('no_cond_log', False)
not_trainable_prefixes = model_config.get(
'not_trainable_prefixes', ['first_stage_model', 'conditioner'])
compile_model = model_config.get('compile_model', False)
en_and_decode_n_samples_a_time = model_config.get(
'en_and_decode_n_samples_a_time', None)
lr_scale = model_config.get('lr_scale', None)
lora_train = model_config.get('lora_train', False)
self.use_pd = model_config.get('use_pd',
False) # progressive distillation
self.log_keys = log_keys
self.input_key = input_key
self.not_trainable_prefixes = not_trainable_prefixes
self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time
self.lr_scale = lr_scale
self.lora_train = lora_train
self.noised_image_input = model_config.get('noised_image_input', False)
self.noised_image_all_concat = model_config.get(
'noised_image_all_concat', False)
self.noised_image_dropout = model_config.get('noised_image_dropout',
0.0)
if args.fp16:
dtype = torch.float16
dtype_str = 'fp16'
elif args.bf16:
dtype = torch.bfloat16
dtype_str = 'bf16'
else:
dtype = torch.float32
dtype_str = 'fp32'
self.dtype = dtype
self.dtype_str = dtype_str
network_config['params']['dtype'] = dtype_str
model = instantiate_from_config(network_config)
self.model = get_obj_from_str(
default(network_wrapper,
OPENAIUNETWRAPPER))(model,
compile_model=compile_model,
dtype=dtype)
self.denoiser = instantiate_from_config(denoiser_config)
self.sampler = instantiate_from_config(
sampler_config) if sampler_config is not None else None
self.conditioner = instantiate_from_config(
default(conditioner_config, UNCONDITIONAL_CONFIG))
self._init_first_stage(first_stage_config)
self.loss_fn = instantiate_from_config(
loss_fn_config) if loss_fn_config is not None else None
self.latent_input = latent_input
self.scale_factor = scale_factor
self.disable_first_stage_autocast = disable_first_stage_autocast
self.no_cond_log = no_cond_log
self.device = args.device
def disable_untrainable_params(self):
total_trainable = 0
for n, p in self.named_parameters():
            if not p.requires_grad:
continue
flag = False
for prefix in self.not_trainable_prefixes:
if n.startswith(prefix) or prefix == 'all':
flag = True
break
lora_prefix = ['matrix_A', 'matrix_B']
for prefix in lora_prefix:
if prefix in n:
flag = False
break
if flag:
p.requires_grad_(False)
else:
total_trainable += p.numel()
print_rank0('***** Total trainable parameters: ' +
str(total_trainable) + ' *****')
def reinit(self, parent_model=None):
# reload the initial params from previous trained modules
# you can also get access to other mixins through parent_model.get_mixin().
pass
def merge_lora(self):
for m in self.model.diffusion_model.mixins.adaln_layer.adaLN_modulations:
m[1] = merge_linear_lora(m[1])
def _init_first_stage(self, config):
model = instantiate_from_config(config).eval()
model.train = disabled_train
for param in model.parameters():
param.requires_grad = False
self.first_stage_model = model
def forward(self, x, batch):
loss = self.loss_fn(self.model, self.denoiser, self.conditioner, x,
batch)
loss_mean = loss.mean()
loss_dict = {'loss': loss_mean}
return loss_mean, loss_dict
def add_noise_to_first_frame(self, image):
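        # perturb the conditioning frame with noise whose per-sample std is drawn from a
        # log-normal distribution: sigma = exp(N(-3.0, 0.5))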
sigma = torch.normal(mean=-3.0, std=0.5,
size=(image.shape[0], )).to(self.device)
sigma = torch.exp(sigma).to(image.dtype)
image_noise = torch.randn_like(image) * sigma[:, None, None, None,
None]
image = image + image_noise
return image
@torch.no_grad()
def save_memory_encode_first_stage(self, x, batch):
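        # encode a 49-frame clip in temporal chunks of [13, 12, 12, 12] frames to bound
        # VAE peak memory; the hard-coded 1.15258426 below is presumably the first-stage
        # latent scale factor used for this checkpoint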
num_frames = x.shape[2]
splits_x = torch.split(x, [13, 12, 12, 12], dim=2)
all_out = []
with torch.autocast('cuda', enabled=False):
for idx, input_x in enumerate(splits_x):
if idx == len(splits_x) - 1:
clear_fake_cp_cache = True
else:
clear_fake_cp_cache = False
out = self.first_stage_model.encode(
input_x.contiguous(),
clear_fake_cp_cache=clear_fake_cp_cache)
all_out.append(out)
z = torch.cat(all_out, dim=2)
z = 1.15258426 * z
return z
def shared_step(self, batch: Dict) -> Any:
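        # Training step: optionally resize the clip by a random factor broadcast across the
        # data-parallel group ('train_size_range'), optionally build a low-res conditioning
        # latent ('lr_scale') and a noised first-frame condition, encode to latents, cache a
        # reference clip ('ref_mp4' -> 'ref_x') if provided, then compute the diffusion loss.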
x = self.get_input(batch)
# print(f"this is iteration {self.share_cache['iteration']}", flush=True)
# print(f'''{"train_size_range" in self.share_cache}''', flush=True)
if 'train_size_range' in self.share_cache:
train_size_range = self.share_cache.get('train_size_range')
size_factor = random.uniform(*train_size_range)
# broadcast the size factor from rank 0
size_factor = torch.tensor(size_factor).to(self.device)
torch.distributed.broadcast(size_factor,
src=0,
group=mpu.get_data_parallel_group())
# print(f"size_factor: {size_factor} at rank : {torch.distributed.get_rank()}", flush=True)
target_size = (int(x.shape[3] * size_factor),
int(x.shape[4] * size_factor))
# print(target_size)
# make sure it can be divided by 16
b, t, c, h, w = x.shape
# reshape to b * t, c, h, w
x = x.reshape(b * t, c, h, w)
target_size = (target_size[0] // 16 * 16,
target_size[1] // 16 * 16)
x = F.interpolate(x,
size=target_size,
mode='bilinear',
align_corners=False,
antialias=True)
# reshape back to b, t, c, h, w
x = x.reshape(b, t, c, target_size[0], target_size[1])
if self.lr_scale is not None:
lr_x = F.interpolate(x,
scale_factor=1 / self.lr_scale,
mode='bilinear',
align_corners=False)
lr_x = F.interpolate(lr_x,
scale_factor=self.lr_scale,
mode='bilinear',
align_corners=False)
lr_z = self.encode_first_stage(lr_x, batch)
batch['lr_input'] = lr_z
x = x.permute(0, 2, 1, 3, 4).contiguous()
if self.noised_image_input:
image = x[:, :, 0:1]
image = self.add_noise_to_first_frame(image)
image = self.encode_first_stage(image, batch)
b, c, t, h, w = x.shape
if t == 49 and (h * w) > 480 * 720:
if os.environ.get('DEBUGINFO', None) is not None:
print(
f'save memory encode first stage with in shape {x.shape}, {x.mean()}'
)
x = self.save_memory_encode_first_stage(x, batch)
else:
x = self.encode_first_stage(x, batch)
# x = self.encode_first_stage(x, batch)
x = x.permute(0, 2, 1, 3, 4).contiguous()
if 'ref_mp4' in self.share_cache:
            if 'disable_ref' not in self.share_cache:
ref_mp4 = self.share_cache.pop('ref_mp4')
ref_mp4 = ref_mp4.to(self.dtype).to(self.device)
ref_mp4 = ref_mp4.permute(0, 2, 1, 3, 4).contiguous()
ref_x = self.encode_first_stage(ref_mp4, batch)
ref_x = ref_x.permute(0, 2, 1, 3, 4).contiguous()
self.share_cache['ref_x'] = ref_x
if self.noised_image_input:
image = image.permute(0, 2, 1, 3, 4).contiguous()
if self.noised_image_all_concat:
image = image.repeat(1, x.shape[1], 1, 1, 1)
else:
image = torch.concat([image, torch.zeros_like(x[:, 1:])],
dim=1)
if random.random() < self.noised_image_dropout:
image = torch.zeros_like(image)
batch['concat_images'] = image
# gc.collect()
# torch.cuda.empty_cache()
loss, loss_dict = self(x, batch)
return loss, loss_dict
def get_input(self, batch):
return batch[self.input_key].to(self.dtype)
@torch.no_grad()
def decode_first_stage(self, z):
z = 1.0 / self.scale_factor * z
n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])
n_rounds = math.ceil(z.shape[0] / n_samples)
all_out = []
with torch.autocast('cuda',
enabled=not self.disable_first_stage_autocast):
for n in range(n_rounds):
if isinstance(self.first_stage_model.decoder, VideoDecoder):
kwargs = {
'timesteps': len(z[n * n_samples:(n + 1) * n_samples])
}
else:
kwargs = {}
out = self.first_stage_model.decode(
z[n * n_samples:(n + 1) * n_samples], **kwargs)
all_out.append(out)
out = torch.cat(all_out, dim=0)
return out
@torch.no_grad()
def encode_first_stage(self, x, batch):
frame = x.shape[2]
if frame > 1 and self.latent_input:
x = x.permute(0, 2, 1, 3, 4).contiguous()
return x * self.scale_factor # already encoded
n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0])
n_rounds = math.ceil(x.shape[0] / n_samples)
all_out = []
with torch.autocast('cuda',
enabled=not self.disable_first_stage_autocast):
for n in range(n_rounds):
out = self.first_stage_model.encode(x[n * n_samples:(n + 1) *
n_samples])
all_out.append(out)
z = torch.cat(all_out, dim=0)
z = self.scale_factor * z
return z
@torch.no_grad()
def sample(
self,
cond: Dict,
uc: Union[Dict, None] = None,
batch_size: int = 16,
shape: Union[None, Tuple, List] = None,
prefix=None,
concat_images=None,
**kwargs,
):
randn = torch.randn(batch_size,
*shape).to(torch.float32).to(self.device)
if hasattr(self, 'seeded_noise'):
randn = self.seeded_noise(randn)
if prefix is not None:
randn = torch.cat([prefix, randn[:, prefix.shape[1]:]], dim=1)
# broadcast noise
mp_size = mpu.get_model_parallel_world_size()
if mp_size > 1:
global_rank = torch.distributed.get_rank() // mp_size
src = global_rank * mp_size
torch.distributed.broadcast(randn,
src=src,
group=mpu.get_model_parallel_group())
scale = None
scale_emb = None
        denoiser = lambda input, sigma, c, **additional_model_inputs: self.denoiser(
            self.model,
            input,
            sigma,
            c,
            concat_images=concat_images,
            **additional_model_inputs)
if 'cfg' in self.share_cache:
self.sampler.guider.scale = self.share_cache['cfg']
print('overwrite cfg scale in config of stage-1')
samples = self.sampler(denoiser,
randn,
cond,
uc=uc,
scale=scale,
scale_emb=scale_emb,
num_steps=kwargs.get('num_steps', None))
samples = samples.to(self.dtype)
return samples
@torch.no_grad()
def log_conditionings(self, batch: Dict, n: int) -> Dict:
"""
Defines heuristics to log different conditionings.
These can be lists of strings (text-to-image), tensors, ints, ...
"""
image_h, image_w = batch[self.input_key].shape[3:]
log = dict()
for embedder in self.conditioner.embedders:
if ((self.log_keys is None) or
(embedder.input_key
in self.log_keys)) and not self.no_cond_log:
x = batch[embedder.input_key][:n]
if isinstance(x, torch.Tensor):
if x.dim() == 1:
# class-conditional, convert integer to string
x = [str(x[i].item()) for i in range(x.shape[0])]
xc = log_txt_as_img((image_h, image_w),
x,
size=image_h // 4)
elif x.dim() == 2:
# size and crop cond and the like
x = [
'x'.join([str(xx) for xx in x[i].tolist()])
for i in range(x.shape[0])
]
xc = log_txt_as_img((image_h, image_w),
x,
size=image_h // 20)
else:
raise NotImplementedError()
elif isinstance(x, (List, ListConfig)):
if isinstance(x[0], str):
xc = log_txt_as_img((image_h, image_w),
x,
size=image_h // 20)
else:
raise NotImplementedError()
else:
raise NotImplementedError()
log[embedder.input_key] = xc
return log
@torch.no_grad()
def log_video(
self,
batch: Dict,
N: int = 8,
ucg_keys: List[str] = None,
only_log_video_latents=False,
**kwargs,
) -> Dict:
conditioner_input_keys = [
e.input_key for e in self.conditioner.embedders
]
if ucg_keys:
assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), (
'Each defined ucg key for sampling must be in the provided conditioner input keys,'
f'but we have {ucg_keys} vs. {conditioner_input_keys}')
else:
ucg_keys = conditioner_input_keys
log = dict()
x = self.get_input(batch)
c, uc = self.conditioner.get_unconditional_conditioning(
batch,
force_uc_zero_embeddings=ucg_keys
if len(self.conditioner.embedders) > 0 else [],
)
sampling_kwargs = {}
N = min(x.shape[0], N)
x = x.to(self.device)[:N]
if not self.latent_input:
log['inputs'] = x.to(torch.float32)
x = x.permute(0, 2, 1, 3, 4).contiguous()
z = self.encode_first_stage(x, batch)
if not only_log_video_latents:
log['reconstructions'] = self.decode_first_stage(z).to(
torch.float32)
log['reconstructions'] = log['reconstructions'].permute(
0, 2, 1, 3, 4).contiguous()
z = z.permute(0, 2, 1, 3, 4).contiguous()
log.update(self.log_conditionings(batch, N))
for k in c:
if isinstance(c[k], torch.Tensor):
c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc))
if self.noised_image_input:
image = x[:, :, 0:1]
image = self.add_noise_to_first_frame(image)
image = self.encode_first_stage(image, batch)
image = image.permute(0, 2, 1, 3, 4).contiguous()
image = torch.concat([image, torch.zeros_like(z[:, 1:])], dim=1)
c['concat'] = image
uc['concat'] = image
samples = self.sample(c,
shape=z.shape[1:],
uc=uc,
batch_size=N,
**sampling_kwargs) # b t c h w
samples = samples.permute(0, 2, 1, 3, 4).contiguous()
if only_log_video_latents:
latents = 1.0 / self.scale_factor * samples
log['latents'] = latents
else:
samples = self.decode_first_stage(samples).to(torch.float32)
samples = samples.permute(0, 2, 1, 3, 4).contiguous()
log['samples'] = samples
else:
samples = self.sample(c,
shape=z.shape[1:],
uc=uc,
batch_size=N,
**sampling_kwargs) # b t c h w
samples = samples.permute(0, 2, 1, 3, 4).contiguous()
if only_log_video_latents:
latents = 1.0 / self.scale_factor * samples
log['latents'] = latents
else:
samples = self.decode_first_stage(samples).to(torch.float32)
samples = samples.permute(0, 2, 1, 3, 4).contiguous()
log['samples'] = samples
return log
class SATUpscalerEngine(SATVideoDiffusionEngine):
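    """Upscaler variant. Its shared_step bypasses the diffusion loss and returns an MSE
    between the encoded input and the decoded reference clip; this appears to serve as a
    reconstruction/debugging objective rather than a regular training loss."""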
def shared_step(self, batch: Dict) -> Any:
x = self.get_input(batch)
if self.lr_scale is not None:
lr_x = F.interpolate(x,
scale_factor=1 / self.lr_scale,
mode='bilinear',
align_corners=False)
lr_x = F.interpolate(lr_x,
scale_factor=self.lr_scale,
mode='bilinear',
align_corners=False)
lr_z = self.encode_first_stage(lr_x, batch)
batch['lr_input'] = lr_z
x = x.permute(0, 2, 1, 3, 4).contiguous()
if self.noised_image_input:
image = x[:, :, 0:1]
image = self.add_noise_to_first_frame(image)
image = self.encode_first_stage(image, batch)
x = self.encode_first_stage(x, batch)
x = x.permute(0, 2, 1, 3, 4).contiguous()
if 'ref_mp4' in self.share_cache:
ref_mp4 = self.share_cache.pop('ref_mp4')
ref_mp4 = ref_mp4.to(self.dtype).to(self.device)
ref_mp4 = ref_mp4.permute(0, 2, 1, 3, 4).contiguous()
ref_x = self.encode_first_stage(ref_mp4, batch)
ref_x = ref_x.permute(0, 2, 1, 3, 4).contiguous()
self.share_cache['ref_x'] = ref_x
if self.noised_image_input:
image = image.permute(0, 2, 1, 3, 4).contiguous()
if self.noised_image_all_concat:
image = image.repeat(1, x.shape[1], 1, 1, 1)
else:
image = torch.concat([image, torch.zeros_like(x[:, 1:])],
dim=1)
if random.random() < self.noised_image_dropout:
image = torch.zeros_like(image)
batch['concat_images'] = image
ref_x = ref_x.permute(0, 2, 1, 3, 4).contiguous()
ref_x = self.first_stage_model.decoder(ref_x)
ref_x = ref_x.permute(0, 2, 1, 3, 4).contiguous()
loss_mean = torch.mean(((x - ref_x)**2).reshape(x.shape[0], -1), 1)
loss_mean = loss_mean.mean()
loss_dict = {'loss': loss_mean}
return loss_mean, loss_dict
def disable_untrainable_params(self):
pass
# def forward(self, x, batch):
# loss = self.loss_fn(self.model, self.denoiser, self.conditioner, x, batch)
# loss_mean = loss.mean()
# loss_dict = {"loss": loss_mean}
# return loss_mean, loss_dict
import argparse
import copy
import os
import numpy as np
import torch
from arguments import get_args
from diffusion_video import SATVideoDiffusionEngine
from flow_video import FlowEngine
from torchvision.io.video import write_video
from utils import (decode, disable_all_init, prepare_input, save_mem_decode,
save_memory_encode_first_stage, seed_everything)
from sat import mpu
from sat.model.base_model import get_model
disable_all_init()
def init_model(model, second_model, args, second_args):
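    """Attach a shared cache to every sub-module of both models, run their
    register_new_modules hooks, load the two inference checkpoints (dropping any cached
    RoPE sin/cos buffers; skipped when the SKIP_LOAD env var is set), and merge LoRA
    weights into the first model."""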
share_cache = dict()
second_share_cache = dict()
if hasattr(args, 'share_cache_config'):
for k, v in args.share_cache_config.items():
share_cache[k] = v
if hasattr(second_args, 'share_cache_config'):
for k, v in second_args.share_cache_config.items():
second_share_cache[k] = v
for n, m in model.named_modules():
m.share_cache = share_cache
if hasattr(m, 'register_new_modules'):
m.register_new_modules()
for n, m in second_model.named_modules():
m.share_cache = second_share_cache
if hasattr(m, 'register_new_modules'):
m.register_new_modules()
if os.environ.get('SKIP_LOAD', None) is not None:
print('skip load for speed debug')
else:
weight_path = args.inf_ckpt
weight = torch.load(weight_path, map_location='cpu')
if 'model.diffusion_model.mixins.pos_embed.freqs_sin' in weight[
'module']:
del weight['module'][
'model.diffusion_model.mixins.pos_embed.freqs_sin']
del weight['module'][
'model.diffusion_model.mixins.pos_embed.freqs_cos']
msg = model.load_state_dict(weight['module'], strict=False)
print(msg)
second_weight_path = args.inf_ckpt2
second_weight = torch.load(second_weight_path, map_location='cpu')
if 'model.diffusion_model.mixins.pos_embed.freqs_sin' in second_weight[
'module']:
del second_weight['module'][
'model.diffusion_model.mixins.pos_embed.freqs_sin']
del second_weight['module'][
'model.diffusion_model.mixins.pos_embed.freqs_cos']
second_msg = second_model.load_state_dict(second_weight['module'],
strict=False)
print(second_msg)
for n, m in model.named_modules():
if hasattr(m, 'merge_lora'):
m.merge_lora()
print(f'merge lora of {n}')
def get_first_results(model, text, num_frames, H, W):
"""Get first Stage results.
Args:
model (nn.Module): first stage model.
text (str): text prompt
num_frames (int): number of frames
H (int): height of the first stage results
W (int): width of the first stage results
Returns:
Tensor: first stage video.
"""
device = 'cuda'
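    # the first-stage VAE compresses time 4x (e.g. 49 frames -> 13 latent frames) and space 8x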
T = 1 + (num_frames - 1) // 4
F = 8
motion_text_prefix = [
'very low motion,',
'low motion,',
'medium motion,',
'high motion,',
'very high motion,',
]
neg_prompt = ''
pos_prompt = ''
with torch.no_grad():
model.to('cuda')
input_negative_prompt = motion_text_prefix[
0] + ', ' + motion_text_prefix[1] + neg_prompt
c, uc = prepare_input(text,
model,
T,
negative_prompt=input_negative_prompt,
pos_prompt=pos_prompt)
with torch.no_grad(), torch.amp.autocast(enabled=True,
device_type='cuda',
dtype=torch.bfloat16):
samples_z = model.sample(
c,
uc=uc,
batch_size=1,
shape=(T, 16, H // F, W // F),
num_steps=model.share_cache.get('first_sample_step', None),
)
samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous()
model.to('cpu')
torch.cuda.empty_cache()
first_stage_model = model.first_stage_model
first_stage_model = first_stage_model.to(device)
latent = 1.0 / model.scale_factor * samples_z
samples = decode(first_stage_model, latent)
model.to('cpu')
return samples
def get_second_results(model, text, first_stage_samples, num_frames):
"""Get second Stage results.
Args:
model (nn.Module): second stage model.
text (str): text prompt
first_stage_samples (Tensor): first stage results
num_frames (int): number of frames
Returns:
Tensor: second stage results.
"""
t, h, w, c = first_stage_samples.shape
first_stage_samples = first_stage_samples[:num_frames]
first_stage_samples = (first_stage_samples / 255.)
first_stage_samples = (first_stage_samples - 0.5) / 0.5
target_size = model.share_cache.get('target_size', None)
if target_size is None:
upscale_factor = model.share_cache.get('upscale_factor', 8)
H = int(h * upscale_factor) // 16 * 16
W = int(w * upscale_factor) // 16 * 16
else:
H, W = target_size
H = H // 16 * 16
W = W // 16 * 16
first_stage_samples = first_stage_samples.permute(0, 3, 1, 2).to('cuda')
ref_x = torch.nn.functional.interpolate(first_stage_samples,
size=(H, W),
mode='bilinear',
align_corners=False,
antialias=True)
ref_x = ref_x[:num_frames][None]
ref_x = ref_x.permute(0, 2, 1, 3, 4).contiguous()
first_stage_model = model.first_stage_model
    print('start encoding the first-stage results at high resolution')
with torch.no_grad():
first_stage_dtype = next(model.first_stage_model.parameters()).dtype
model.first_stage_model.cuda()
ref_x = save_memory_encode_first_stage(
ref_x.contiguous().to(first_stage_dtype).cuda(), model)
ref_x = ref_x.permute(0, 2, 1, 3, 4).contiguous()
ref_x = ref_x.to(model.dtype)
    print('finished encoding the first-stage results, starting stage II')
device = 'cuda'
model.to(device)
motion_text_prefix = [
'very low motion,',
'low motion,',
'medium motion,',
'high motion,',
'very high motion,',
]
pos_prompt = None
input_negative_prompt = None
text = 'medium motion,' + text
c, uc = prepare_input(text,
model,
num_frames,
negative_prompt=input_negative_prompt,
pos_prompt=pos_prompt)
T = 1 + (num_frames - 1) // 4
F = 8
with torch.no_grad(), torch.amp.autocast(enabled=True,
device_type='cuda',
dtype=torch.bfloat16):
samples_z = model.sample(
ref_x,
c,
uc=uc,
batch_size=1,
shape=(T, 16, H // F, W // F),
num_steps=model.share_cache.get('sample_step', 5),
method='euler',
cfg=model.share_cache.get('cfg', 7.5),
)
samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous()
model.to('cpu')
torch.cuda.empty_cache()
first_stage_model = model.first_stage_model
first_stage_model = first_stage_model.to(device)
latent = 1.0 / model.scale_factor * samples_z
    print('start spatiotemporal slice decoding')
    samples = save_mem_decode(first_stage_model, latent)
    print('finished spatiotemporal slice decoding')
model.to('cpu')
return samples
def sampling_main(args, second_args, model_cls, second_model_cls):
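    """Distributed two-stage sampling: each data-parallel rank takes a strided slice of the
    prompt list, generates the low-resolution first-stage clip, optionally refines it with
    the second-stage model, and writes the first/second/joint videos to args.output_dir."""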
local_rank = int(os.environ.get('LOCAL_RANK', 0))
torch.cuda.set_device(local_rank)
second_model = get_model(second_args, second_model_cls)
model = get_model(args, model_cls)
init_model(model, second_model, args, second_args)
model.eval()
second_model.eval()
    rank = mpu.get_data_parallel_rank()
    world_size = mpu.get_data_parallel_world_size()
print('rank and world_size', rank, world_size)
text_file = args.input_file
num_sample_perprompt = 1
with open(text_file) as fin:
all_prompt = []
for single_line in fin:
all_prompt.append(single_line.strip())
    print(f'loaded {len(all_prompt)} prompts from {text_file}')
image_size = [270, 480]
image_size = (image_size[0] // 16 * 16, image_size[1] // 16 * 16)
# second_img_size = [1080, 1920]
second_img_size = [270, 480]
second_img_size = (second_img_size[0] // 16 * 16,
second_img_size[1] // 16 * 16)
num_frames = 49
second_num_frames = 49
# 6-8
model.share_cache['cfg'] = 8
second_model.share_cache['target_size'] = second_img_size
# range from 650 to 750
second_model.share_cache['ref_noise_step'] = 675
second_model.share_cache['sample_ref_noise_step'] = 675
# range from 2 to 3.5
second_model.share_cache['shift_t'] = 2.5
# range from 4 to 6
second_model.share_cache['sample_step'] = 5
# range from 10 to 13
second_model.share_cache['cfg'] = 13
second_model.share_cache.pop('ref_noise_step_range', None)
second_model.share_cache['time_size_embedding'] = True
    H, W = image_size
    save_dir = ''
if args.output_dir:
save_dir = args.output_dir
print(save_dir)
os.makedirs(save_dir, exist_ok=True)
all_ids = range(len(all_prompt))
local_ids = all_ids[rank::world_size]
for enu_index in local_ids:
text = all_prompt[enu_index]
print(f'rank {rank} processing {enu_index}')
for inter_index in range(num_sample_perprompt):
seed_everything(enu_index + inter_index * 1000)
seed = enu_index + inter_index * 1000
first_stage_samples = get_first_results(model, text, num_frames, H,
W)
file_name = f'{save_dir}/{enu_index}_{inter_index}_seed_{seed}.mp4'
second_file_name = f'{save_dir}/{enu_index}_{inter_index}_seed_{seed}_second.mp4'
joint_file_name = f'{save_dir}/{enu_index}_{inter_index}_seed_{seed}_joint.mp4'
print(f'save to {file_name}')
write_video(filename=file_name,
fps=8,
video_array=first_stage_samples,
options={'crf': '5'})
if not args.skip_second:
second_stage_samples = get_second_results(
second_model, text, first_stage_samples, second_num_frames)
write_video(filename=second_file_name,
fps=8,
video_array=second_stage_samples.cpu(),
options={'crf': '5'})
# save joint video
part_first_stage = first_stage_samples[:second_num_frames]
target_h, target_w = second_stage_samples.shape[
1], second_stage_samples.shape[2]
part_first_stage = torch.nn.functional.interpolate(
part_first_stage.permute(0, 3, 1, 2).contiguous(),
size=(target_h, target_w),
mode='bilinear',
align_corners=False,
antialias=True)
part_first_stage = part_first_stage.permute(0, 2, 3,
1).contiguous()
joint_video = torch.cat(
[part_first_stage.cpu(),
second_stage_samples.cpu()],
dim=-2)
write_video(filename=joint_file_name,
fps=8,
video_array=joint_video.cpu(),
options={'crf': '15'})
if __name__ == '__main__':
if 'OMPI_COMM_WORLD_LOCAL_RANK' in os.environ:
os.environ['LOCAL_RANK'] = os.environ['OMPI_COMM_WORLD_LOCAL_RANK']
os.environ['WORLD_SIZE'] = os.environ['OMPI_COMM_WORLD_SIZE']
os.environ['RANK'] = os.environ['OMPI_COMM_WORLD_RANK']
py_parser = argparse.ArgumentParser(add_help=False)
known, args_list = py_parser.parse_known_args()
second_args_list = copy.deepcopy(args_list)
args = get_args(args_list)
args = argparse.Namespace(**vars(args), **vars(known))
del args.deepspeed_config
args.model_config.first_stage_config.params.cp_size = 1
args.model_config.network_config.params.transformer_args.model_parallel_size = 1
args.model_config.network_config.params.transformer_args.checkpoint_activations = False
args.model_config.loss_fn_config.params.sigma_sampler_config.params.uniform_sampling = False
second_args_list[1] = args.second[0]
second_args = get_args(second_args_list)
second_args = argparse.Namespace(**vars(second_args), **vars(known))
del second_args.deepspeed_config
second_args.model_config.first_stage_config.params.cp_size = 1
second_args.model_config.network_config.params.transformer_args.model_parallel_size = 1
second_args.model_config.network_config.params.transformer_args.checkpoint_activations = False
second_args.model_config.loss_fn_config.params.sigma_sampler_config.params.uniform_sampling = False
sampling_main(args,
second_args,
model_cls=SATVideoDiffusionEngine,
second_model_cls=FlowEngine)
from functools import partial
import numpy as np
import torch
import torch.nn.functional as F
from base_model import BaseModel
from einops import rearrange, repeat
from sgm.modules.diffusionmodules.openaimodel import Timestep
from sgm.modules.diffusionmodules.util import linear, timestep_embedding
from sgm.util import instantiate_from_config
from torch import nn
from sat.model.base_model import non_conflict
from sat.model.mixins import BaseMixin
from sat.mpu.layers import ColumnParallelLinear
from sat.ops.layernorm import LayerNorm, RMSNorm
from sat.transformer_defaults import HOOKS_DEFAULT, attention_fn_default
class ImagePatchEmbeddingMixin(BaseMixin):
def __init__(
self,
in_channels,
hidden_size,
patch_size,
bias=True,
text_hidden_size=None,
):
super().__init__()
self.proj = nn.Conv2d(in_channels,
hidden_size,
kernel_size=patch_size,
stride=patch_size,
bias=bias)
if text_hidden_size is not None:
self.text_proj = nn.Linear(text_hidden_size, hidden_size)
else:
self.text_proj = None
def word_embedding_forward(self, input_ids, **kwargs):
# now is 3d patch
images = kwargs['images'] # (b,t,c,h,w)
B, T = images.shape[:2]
emb = images.view(-1, *images.shape[2:])
emb = self.proj(emb) # ((b t),d,h/2,w/2)
emb = emb.view(B, T, *emb.shape[1:])
emb = emb.flatten(3).transpose(2, 3) # (b,t,n,d)
emb = rearrange(emb, 'b t n d -> b (t n) d')
if self.text_proj is not None:
text_emb = self.text_proj(kwargs['encoder_outputs'])
emb = torch.cat((text_emb, emb), dim=1) # (b,n_t+t*n_i,d)
emb = emb.contiguous()
return emb # (b,n_t+t*n_i,d)
def reinit(self, parent_model=None):
w = self.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.proj.bias, 0)
del self.transformer.word_embeddings
import copy
import os
def get_3d_sincos_pos_embed(
embed_dim,
grid_height,
grid_width,
t_size,
cls_token=False,
height_interpolation=1.0,
width_interpolation=1.0,
time_interpolation=1.0,
):
"""
grid_size: int of the grid height and width
t_size: int of the temporal size
return:
pos_embed: [t_size*grid_size*grid_size, embed_dim] or [1+t_size*grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
assert embed_dim % 4 == 0
embed_dim_spatial = embed_dim // 4 * 3
embed_dim_temporal = embed_dim // 4
# spatial
grid_h = np.arange(grid_height, dtype=np.float32) / height_interpolation
grid_w = np.arange(grid_width, dtype=np.float32) / width_interpolation
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_height, grid_width])
pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(
embed_dim_spatial, grid)
# temporal
grid_t = np.arange(t_size, dtype=np.float32) / time_interpolation
pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(
embed_dim_temporal, grid_t)
    # concatenate in [T, H, W] order
pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :]
pos_embed_temporal = np.repeat(pos_embed_temporal,
grid_height * grid_width,
axis=1) # [T, H*W, D // 4]
pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :]
pos_embed_spatial = np.repeat(pos_embed_spatial, t_size,
axis=0) # [T, H*W, D // 4 * 3]
pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial],
axis=-1)
# pos_embed = pos_embed.reshape([-1, embed_dim]) # [T*H*W, D]
return pos_embed # [T, H*W, D]
def get_2d_sincos_pos_embed(embed_dim,
grid_height,
grid_width,
cls_token=False,
extra_tokens=0):
"""
grid_size: int of the grid height and width
return:
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
grid_h = np.arange(grid_height, dtype=np.float32)
grid_w = np.arange(grid_width, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_height, grid_width])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token and extra_tokens > 0:
pos_embed = np.concatenate(
[np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2,
grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2,
grid[1]) # (H*W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,)
out: (M, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float64)
omega /= embed_dim / 2.0
omega = 1.0 / 10000**omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
class Basic2DPositionEmbeddingMixin(BaseMixin):
def __init__(self,
height,
width,
compressed_num_frames,
hidden_size,
text_length=0):
super().__init__()
self.height = height
self.width = width
self.spatial_length = height * width
self.pos_embedding = nn.Parameter(torch.zeros(
1, int(text_length + self.spatial_length), int(hidden_size)),
requires_grad=False)
def position_embedding_forward(self, position_ids, **kwargs):
return self.pos_embedding
def reinit(self, parent_model=None):
del self.transformer.position_embeddings
pos_embed = get_2d_sincos_pos_embed(self.pos_embedding.shape[-1],
self.height, self.width)
self.pos_embedding.data[:, -self.spatial_length:].copy_(
torch.from_numpy(pos_embed).float().unsqueeze(0))
class Basic3DPositionEmbeddingMixin(BaseMixin):
def __init__(
self,
height,
width,
compressed_num_frames,
hidden_size,
text_length=0,
height_interpolation=1.0,
width_interpolation=1.0,
time_interpolation=1.0,
):
super().__init__()
self.height = height
self.width = width
self.text_length = text_length
self.compressed_num_frames = compressed_num_frames
self.spatial_length = height * width
self.num_patches = height * width * compressed_num_frames
self.pos_embedding = nn.Parameter(torch.zeros(
1, int(text_length + self.num_patches), int(hidden_size)),
requires_grad=False)
self.height_interpolation = height_interpolation
self.width_interpolation = width_interpolation
self.time_interpolation = time_interpolation
def position_embedding_forward(self, position_ids, **kwargs):
if kwargs['images'].shape[1] == 1:
return self.pos_embedding[:, :self.text_length +
self.spatial_length]
return self.pos_embedding[:, :self.text_length + kwargs['seq_length']]
def reinit(self, parent_model=None):
del self.transformer.position_embeddings
pos_embed = get_3d_sincos_pos_embed(
self.pos_embedding.shape[-1],
self.height,
self.width,
self.compressed_num_frames,
height_interpolation=self.height_interpolation,
width_interpolation=self.width_interpolation,
time_interpolation=self.time_interpolation,
)
pos_embed = torch.from_numpy(pos_embed).float()
pos_embed = rearrange(pos_embed, 't n d -> (t n) d')
self.pos_embedding.data[:, -self.num_patches:].copy_(pos_embed)
def broadcat(tensors, dim=-1):
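    # concatenate along `dim`, broadcasting every other dimension to its max size across the inputs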
num_tensors = len(tensors)
shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
assert len(
shape_lens) == 1, 'tensors must all have the same number of dimensions'
shape_len = list(shape_lens)[0]
dim = (dim + shape_len) if dim < 0 else dim
dims = list(zip(*map(lambda t: list(t.shape), tensors)))
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
    assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)
                ]), 'invalid dimensions for broadcastable concatenation'
max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
expanded_dims = list(
map(lambda t: (t[0], (t[1], ) * num_tensors), max_dims))
expanded_dims.insert(dim, (dim, dims[dim]))
expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
tensors = list(
map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
return torch.cat(tensors, dim=dim)
def rotate_half(x):
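    # rotate adjacent pairs on the last dim: (x1, x2) -> (-x2, x1), as used by rotary embeddings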
x = rearrange(x, '... (d r) -> ... d r', r=2)
x1, x2 = x.unbind(dim=-1)
x = torch.stack((-x2, x1), dim=-1)
return rearrange(x, '... d r -> ... (d r)')
class Rotary3DPositionEmbeddingMixin(BaseMixin):
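    """3D rotary position embedding over (t, h, w): 1/4 of the head dim encodes time and
    3/8 each encode height and width; the rotation is applied to image tokens only
    (the first ``text_length`` tokens are left unrotated)."""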
def __init__(
self,
height,
width,
compressed_num_frames,
hidden_size,
hidden_size_head,
text_length,
theta=10000,
h_interp_ratio=1.0,
w_interp_ratio=1.0,
t_interp_ratio=1.0,
rot_v=False,
learnable_pos_embed=False,
):
super().__init__()
self.rot_v = rot_v
print(f'theta is {theta}')
dim_t = hidden_size_head // 4
dim_h = hidden_size_head // 8 * 3
dim_w = hidden_size_head // 8 * 3
freqs_t = 1.0 / (theta**(
torch.arange(0, dim_t, 2)[:(dim_t // 2)].float() / dim_t))
freqs_h = 1.0 / (theta**(
torch.arange(0, dim_h, 2)[:(dim_h // 2)].float() / dim_h))
freqs_w = 1.0 / (theta**(
torch.arange(0, dim_w, 2)[:(dim_w // 2)].float() / dim_w))
self.compressed_num_frames = compressed_num_frames
self.height = height
self.width = width
grid_t = torch.arange(compressed_num_frames, dtype=torch.float32)
grid_h = torch.arange(height, dtype=torch.float32)
grid_w = torch.arange(width, dtype=torch.float32)
if t_interp_ratio > 1.0:
print(f't_interp_ratio is {t_interp_ratio}')
grid_t = grid_t / t_interp_ratio
if h_interp_ratio > 1.0:
print(f'h_interp_ratio is {h_interp_ratio}')
grid_h = grid_h / h_interp_ratio
if w_interp_ratio > 1.0:
print(f'w_interp_ratio is {w_interp_ratio}')
grid_w = grid_w / w_interp_ratio
freqs_t = torch.einsum('..., f -> ... f', grid_t, freqs_t)
freqs_h = torch.einsum('..., f -> ... f', grid_h, freqs_h)
freqs_w = torch.einsum('..., f -> ... f', grid_w, freqs_w)
freqs_t = repeat(freqs_t, '... n -> ... (n r)', r=2)
freqs_h = repeat(freqs_h, '... n -> ... (n r)', r=2)
freqs_w = repeat(freqs_w, '... n -> ... (n r)', r=2)
freqs = broadcat((freqs_t[:, None, None, :], freqs_h[None, :, None, :],
freqs_w[None, None, :, :]),
dim=-1)
freqs = rearrange(freqs, 't h w d -> (t h w) d')
freqs = freqs.contiguous()
freqs_sin = freqs.sin()
freqs_cos = freqs.cos()
self.register_buffer('freqs_sin', freqs_sin)
self.register_buffer('freqs_cos', freqs_cos)
self.text_length = text_length
if learnable_pos_embed:
num_patches = height * width * compressed_num_frames + text_length
self.pos_embedding = nn.Parameter(torch.zeros(
1, num_patches, int(hidden_size)),
requires_grad=True)
else:
self.pos_embedding = None
def rotary(self, t, **kwargs):
seq_len = t.shape[2]
freqs_cos = self.freqs_cos[:seq_len].unsqueeze(0).unsqueeze(0)
freqs_sin = self.freqs_sin[:seq_len].unsqueeze(0).unsqueeze(0)
return t * freqs_cos + rotate_half(t) * freqs_sin
def position_embedding_forward(self, position_ids, **kwargs):
return None
def attention_fn(
self,
query_layer,
key_layer,
value_layer,
attention_mask,
attention_dropout=None,
log_attention_weights=None,
scaling_attention_score=True,
**kwargs,
):
attention_fn_default = HOOKS_DEFAULT['attention_fn']
query_layer[:, :, self.text_length:] = self.rotary(
query_layer[:, :, self.text_length:])
key_layer[:, :, self.text_length:] = self.rotary(
key_layer[:, :, self.text_length:])
if self.rot_v:
value_layer[:, :, self.text_length:] = self.rotary(
value_layer[:, :, self.text_length:])
return attention_fn_default(
query_layer,
key_layer,
value_layer,
attention_mask,
attention_dropout=attention_dropout,
log_attention_weights=log_attention_weights,
scaling_attention_score=scaling_attention_score,
**kwargs,
)
def modulate(x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
def unpatchify(x, c, p, w, h, rope_position_ids=None, **kwargs):
"""
x: (N, T/2 * S, patch_size**3 * C)
imgs: (N, T, H, W, C)
"""
    if rope_position_ids is not None:
        # pix2struct-style unpatchify; this branch appears unused and has not been verified
L = x.shape[1]
x = x.reshape(shape=(x.shape[0], L, p, p, c))
x = torch.einsum('nlpqc->ncplq', x)
imgs = x.reshape(shape=(x.shape[0], c, p, L * p))
else:
b = x.shape[0]
imgs = rearrange(x,
'b (t h w) (c p q) -> b t c (h p) (w q)',
b=b,
h=h,
w=w,
c=c,
p=p,
q=p)
return imgs
class FinalLayerMixin(BaseMixin):
def __init__(
self,
hidden_size,
time_embed_dim,
patch_size,
out_channels,
latent_width,
latent_height,
elementwise_affine,
):
super().__init__()
self.hidden_size = hidden_size
self.patch_size = patch_size
self.out_channels = out_channels
self.norm_final = nn.LayerNorm(hidden_size,
elementwise_affine=elementwise_affine,
eps=1e-6)
self.linear = nn.Linear(hidden_size,
patch_size * patch_size * out_channels,
bias=True)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(), nn.Linear(time_embed_dim, 2 * hidden_size, bias=True))
self.spatial_length = latent_width * latent_height // patch_size**2
self.latent_width = latent_width
self.latent_height = latent_height
def final_forward(self, logits, **kwargs):
x, emb = logits[:, kwargs['text_length']:, :], kwargs[
'emb'] # x:(b,(t n),d)
shift, scale = self.adaLN_modulation(emb).chunk(2, dim=1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
if hasattr(self, 'share_cache') and 'mode' in self.share_cache:
mode = self.share_cache['mode']
if mode == 'r':
t, h, w = self.share_cache['shape_info']
target_h = h
target_w = w
else:
assert mode == 'w'
t, h, w = self.share_cache['ref_shape_info']
target_h = h
target_w = w
elif hasattr(self, 'share_cache') and 'shape_info' in self.share_cache:
t, h, w = self.share_cache['shape_info']
target_h = h
target_w = w
else:
target_h = self.latent_height // self.patch_size
target_w = self.latent_width // self.patch_size
return unpatchify(
x,
c=self.out_channels,
p=self.patch_size,
w=target_w,
h=target_h,
rope_position_ids=kwargs.get('rope_position_ids', None),
**kwargs,
)
def reinit(self, parent_model=None):
nn.init.xavier_uniform_(self.linear.weight)
nn.init.constant_(self.linear.bias, 0)
class SwiGLUMixin(BaseMixin):
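    """Adds a per-layer gate projection (w2) and replaces the MLP forward with SwiGLU:
    dense_4h_to_h(activation(w2(x)) * dense_h_to_4h(x))."""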
def __init__(self, num_layers, in_features, hidden_features, bias=False):
super().__init__()
self.w2 = nn.ModuleList([
ColumnParallelLinear(
in_features,
hidden_features,
gather_output=False,
bias=bias,
module=self,
name='dense_h_to_4h_gate',
) for i in range(num_layers)
])
def mlp_forward(self, hidden_states, **kw_args):
x = hidden_states
origin = self.transformer.layers[kw_args['layer_id']].mlp
x1 = origin.dense_h_to_4h(x)
x2 = self.w2[kw_args['layer_id']](x)
hidden = origin.activation_func(x2) * x1
x = origin.dense_4h_to_h(hidden)
return x
class AdaLNMixin(BaseMixin):
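    """Per-layer adaLN modulation: each block predicts 12 shift/scale/gate vectors from the
    timestep embedding and applies them separately to text and image tokens around attention
    and MLP; optionally applies per-head LayerNorm to queries and keys."""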
def __init__(
self,
width,
height,
hidden_size,
num_layers,
time_embed_dim,
compressed_num_frames,
qk_ln=True,
hidden_size_head=None,
elementwise_affine=True,
):
super().__init__()
self.num_layers = num_layers
self.width = width
self.height = height
self.compressed_num_frames = compressed_num_frames
self.adaLN_modulations = nn.ModuleList([
nn.Sequential(nn.SiLU(), nn.Linear(time_embed_dim,
12 * hidden_size))
for _ in range(num_layers)
])
self.qk_ln = qk_ln
if qk_ln:
self.query_layernorm_list = nn.ModuleList([
LayerNorm(hidden_size_head,
eps=1e-6,
elementwise_affine=elementwise_affine)
for _ in range(num_layers)
])
self.key_layernorm_list = nn.ModuleList([
LayerNorm(hidden_size_head,
eps=1e-6,
elementwise_affine=elementwise_affine)
for _ in range(num_layers)
])
def layer_forward(
self,
hidden_states,
mask,
*args,
**kwargs,
):
# spatial attn here
text_length = kwargs['text_length']
# hidden_states (b,(n_t+t*n_i),d)
text_hidden_states = hidden_states[:, :text_length] # (b,n,d)
img_hidden_states = hidden_states[:, text_length:] # (b,(t n),d)
layer = self.transformer.layers[kwargs['layer_id']]
# if os.environ.get("DEBUGINFO", False):
# print(f"in forward layer_id is {kwargs['layer_id']}", flush=True)
adaLN_modulation = self.adaLN_modulations[kwargs['layer_id']]
emb = kwargs['emb']
# if "size_emb" in self.share_cache:
# size_emb = self.share_cache["size_emb"]
# emb = emb + size_emb
(
shift_msa,
scale_msa,
gate_msa,
shift_mlp,
scale_mlp,
gate_mlp,
text_shift_msa,
text_scale_msa,
text_gate_msa,
text_shift_mlp,
text_scale_mlp,
text_gate_mlp,
) = adaLN_modulation(emb).chunk(12, dim=1)
gate_msa, gate_mlp, text_gate_msa, text_gate_mlp = (
gate_msa.unsqueeze(1),
gate_mlp.unsqueeze(1),
text_gate_msa.unsqueeze(1),
text_gate_mlp.unsqueeze(1),
)
# self full attention (b,(t n),d)
img_attention_input = layer.input_layernorm(img_hidden_states)
text_attention_input = layer.input_layernorm(text_hidden_states)
img_attention_input = modulate(img_attention_input, shift_msa,
scale_msa)
text_attention_input = modulate(text_attention_input, text_shift_msa,
text_scale_msa)
attention_input = torch.cat(
(text_attention_input, img_attention_input),
dim=1) # (b,n_t+t*n_i,d)
attention_output = layer.attention(attention_input, mask, **kwargs)
text_attention_output = attention_output[:, :text_length] # (b,n,d)
img_attention_output = attention_output[:, text_length:] # (b,(t n),d)
if self.transformer.layernorm_order == 'sandwich':
text_attention_output = layer.third_layernorm(
text_attention_output)
img_attention_output = layer.third_layernorm(img_attention_output)
img_hidden_states = img_hidden_states + gate_msa * img_attention_output # (b,(t n),d)
text_hidden_states = text_hidden_states + text_gate_msa * text_attention_output # (b,n,d)
# mlp (b,(t n),d)
img_mlp_input = layer.post_attention_layernorm(
img_hidden_states) # vision (b,(t n),d)
text_mlp_input = layer.post_attention_layernorm(
text_hidden_states) # language (b,n,d)
img_mlp_input = modulate(img_mlp_input, shift_mlp, scale_mlp)
text_mlp_input = modulate(text_mlp_input, text_shift_mlp,
text_scale_mlp)
mlp_input = torch.cat((text_mlp_input, img_mlp_input),
dim=1) # (b,(n_t+t*n_i),d
mlp_output = layer.mlp(mlp_input, **kwargs)
img_mlp_output = mlp_output[:, text_length:] # vision (b,(t n),d)
text_mlp_output = mlp_output[:, :text_length] # language (b,n,d)
if self.transformer.layernorm_order == 'sandwich':
text_mlp_output = layer.fourth_layernorm(text_mlp_output)
img_mlp_output = layer.fourth_layernorm(img_mlp_output)
img_hidden_states = img_hidden_states + gate_mlp * img_mlp_output # vision (b,(t n),d)
text_hidden_states = text_hidden_states + text_gate_mlp * text_mlp_output # language (b,n,d)
hidden_states = torch.cat((text_hidden_states, img_hidden_states),
dim=1)
if 'scale_embedding' in self.share_cache:
scaler_norm_layer = self.share_cache[
f'''scale_norm_layer_{kwargs["layer_id"]}''']
hidden_states = scaler_norm_layer(hidden_states, emb)
# (b,(n_t+t*n_i),d)
return hidden_states
def reinit(self, parent_model=None):
for layer in self.adaLN_modulations:
nn.init.constant_(layer[-1].weight, 0)
nn.init.constant_(layer[-1].bias, 0)
@non_conflict
def attention_fn(
self,
query_layer,
key_layer,
value_layer,
attention_mask,
attention_dropout=None,
log_attention_weights=None,
scaling_attention_score=True,
old_impl=attention_fn_default,
**kwargs,
):
self.share_cache['temp_layer_id'] = kwargs['layer_id']
if self.qk_ln:
query_layernorm = self.query_layernorm_list[kwargs['layer_id']]
key_layernorm = self.key_layernorm_list[kwargs['layer_id']]
query_layer = query_layernorm(query_layer)
key_layer = key_layernorm(key_layer)
# old_impl is attention_fn of Rotary3DPositionEmbeddingMixin
return old_impl(
query_layer,
key_layer,
value_layer,
attention_mask,
attention_dropout=attention_dropout,
log_attention_weights=log_attention_weights,
scaling_attention_score=scaling_attention_score,
**kwargs,
)
str_to_dtype = {
'fp32': torch.float32,
'fp16': torch.float16,
'bf16': torch.bfloat16
}
class DiffusionTransformer(BaseModel):
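    """DiT backbone assembled from SAT mixins: patch embedding, positional embedding,
    adaLN transformer layers and a final unpatchify layer, with optional SwiGLU MLPs
    and LoRA."""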
def __init__(
self,
transformer_args,
num_frames,
time_compressed_rate,
latent_width,
latent_height,
patch_size,
in_channels,
out_channels,
hidden_size,
num_layers,
num_attention_heads,
elementwise_affine,
time_embed_dim=None,
num_classes=None,
modules={},
input_time='adaln',
adm_in_channels=None,
parallel_output=True,
height_interpolation=1.0,
width_interpolation=1.0,
time_interpolation=1.0,
use_SwiGLU=False,
use_RMSNorm=False,
zero_init_y_embed=False,
**kwargs,
):
self.latent_width = latent_width
self.latent_height = latent_height
self.patch_size = patch_size
self.num_frames = num_frames
self.time_compressed_rate = time_compressed_rate
self.spatial_length = latent_width * latent_height // patch_size**2
self.in_channels = in_channels
self.out_channels = out_channels
self.hidden_size = hidden_size
self.model_channels = hidden_size
self.time_embed_dim = time_embed_dim if time_embed_dim is not None else hidden_size
self.num_classes = num_classes
self.adm_in_channels = adm_in_channels
self.input_time = input_time
self.num_layers = num_layers
self.num_attention_heads = num_attention_heads
self.is_decoder = transformer_args.is_decoder
self.elementwise_affine = elementwise_affine
self.height_interpolation = height_interpolation
self.width_interpolation = width_interpolation
self.time_interpolation = time_interpolation
self.inner_hidden_size = hidden_size * 4
self.zero_init_y_embed = zero_init_y_embed
        try:
            self.dtype = str_to_dtype[kwargs.pop('dtype')]
        except KeyError:
            self.dtype = torch.float32
if use_SwiGLU:
kwargs['activation_func'] = F.silu
elif 'activation_func' not in kwargs:
approx_gelu = nn.GELU(approximate='tanh')
kwargs['activation_func'] = approx_gelu
if use_RMSNorm:
kwargs['layernorm'] = RMSNorm
else:
kwargs['layernorm'] = partial(
LayerNorm, elementwise_affine=elementwise_affine, eps=1e-6)
transformer_args.num_layers = num_layers
transformer_args.hidden_size = hidden_size
transformer_args.num_attention_heads = num_attention_heads
transformer_args.parallel_output = parallel_output
super().__init__(args=transformer_args, transformer=None, **kwargs)
module_configs = modules
self._build_modules(module_configs)
if use_SwiGLU:
self.add_mixin('swiglu',
SwiGLUMixin(num_layers,
hidden_size,
self.inner_hidden_size,
bias=False),
reinit=True)
def _build_modules(self, module_configs):
model_channels = self.hidden_size
# time_embed_dim = model_channels * 4
time_embed_dim = self.time_embed_dim
self.time_embed = nn.Sequential(
linear(model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
if self.num_classes is not None:
if isinstance(self.num_classes, int):
self.label_emb = nn.Embedding(self.num_classes, time_embed_dim)
elif self.num_classes == 'continuous':
print('setting up linear c_adm embedding layer')
self.label_emb = nn.Linear(1, time_embed_dim)
elif self.num_classes == 'timestep':
self.label_emb = nn.Sequential(
Timestep(model_channels),
nn.Sequential(
linear(model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
),
)
elif self.num_classes == 'sequential':
assert self.adm_in_channels is not None
self.label_emb = nn.Sequential(
nn.Sequential(
linear(self.adm_in_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
))
if self.zero_init_y_embed:
nn.init.constant_(self.label_emb[0][2].weight, 0)
nn.init.constant_(self.label_emb[0][2].bias, 0)
else:
raise ValueError()
pos_embed_config = module_configs['pos_embed_config']
self.add_mixin(
'pos_embed',
instantiate_from_config(
pos_embed_config,
height=self.latent_height // self.patch_size,
width=self.latent_width // self.patch_size,
compressed_num_frames=(self.num_frames - 1) //
self.time_compressed_rate + 1,
hidden_size=self.hidden_size,
),
reinit=True,
)
patch_embed_config = module_configs['patch_embed_config']
self.add_mixin(
'patch_embed',
instantiate_from_config(
patch_embed_config,
patch_size=self.patch_size,
hidden_size=self.hidden_size,
in_channels=self.in_channels,
),
reinit=True,
)
if self.input_time == 'adaln':
adaln_layer_config = module_configs['adaln_layer_config']
self.add_mixin(
'adaln_layer',
instantiate_from_config(
adaln_layer_config,
height=self.latent_height // self.patch_size,
width=self.latent_width // self.patch_size,
hidden_size=self.hidden_size,
num_layers=self.num_layers,
compressed_num_frames=(self.num_frames - 1) //
self.time_compressed_rate + 1,
hidden_size_head=self.hidden_size //
self.num_attention_heads,
time_embed_dim=self.time_embed_dim,
elementwise_affine=self.elementwise_affine,
),
)
else:
raise NotImplementedError
final_layer_config = module_configs['final_layer_config']
self.add_mixin(
'final_layer',
instantiate_from_config(
final_layer_config,
hidden_size=self.hidden_size,
patch_size=self.patch_size,
out_channels=self.out_channels,
time_embed_dim=self.time_embed_dim,
latent_width=self.latent_width,
latent_height=self.latent_height,
elementwise_affine=self.elementwise_affine,
),
reinit=True,
)
if 'lora_config' in module_configs:
lora_config = module_configs['lora_config']
self.add_mixin('lora',
instantiate_from_config(lora_config,
layer_num=self.num_layers),
reinit=True)
return
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
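        # Pack everything the mixins need into kwargs/share_cache (latent shape, timesteps,
        # timestep/label embedding, text context) and run the SAT transformer; input_ids,
        # position_ids and attention_mask are dummies required by the SAT interface.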
b, t, d, h, w = x.shape
if x.dtype != self.dtype:
x = x.to(self.dtype)
if 'ref_x' in self.share_cache:
ref_x = self.share_cache['ref_x']
if ref_x.dtype != self.dtype:
ref_x = ref_x.to(self.dtype)
self.share_cache['ref_x'] = ref_x
        # This is not used during inference
if 'concat_images' in kwargs and kwargs['concat_images'] is not None:
if kwargs['concat_images'].shape[0] != x.shape[0]:
concat_images = kwargs['concat_images'].repeat(2, 1, 1, 1, 1)
else:
concat_images = kwargs['concat_images']
x = torch.cat([x, concat_images], dim=2)
assert (y is not None) == (
self.num_classes is not None
), 'must specify y if and only if the model is class-conditional'
t_emb = timestep_embedding(timesteps,
self.model_channels,
repeat_only=False,
dtype=self.dtype)
emb = self.time_embed(t_emb)
if self.num_classes is not None:
# assert y.shape[0] == x.shape[0]
assert x.shape[0] % y.shape[0] == 0
y = y.repeat_interleave(x.shape[0] // y.shape[0], dim=0)
emb = emb + self.label_emb(y)
self.share_cache['shape_info'] = (t, h // (self.patch_size),
w // (self.patch_size))
self.share_cache['timesteps'] = timesteps
kwargs['seq_length'] = t * h * w // (self.patch_size**2)
kwargs['images'] = x
kwargs['emb'] = emb
kwargs['encoder_outputs'] = context
kwargs['text_length'] = context.shape[1]
kwargs['input_ids'] = kwargs['position_ids'] = kwargs[
'attention_mask'] = torch.ones((1, 1)).to(x.dtype)
output = super().forward(**kwargs)[0]
return output
class RefDiffusionTransformer(DiffusionTransformer):
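    """Reference-conditioned DiT: register_new_modules clones each layer's attention
    projections into ref_* copies; forward first runs a 'write' pass ('w') over the
    reference latents at timestep 0, then a 'read' pass ('r') over the noisy input,
    with the two passes coordinated through share_cache['mode']."""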
def register_new_modules(self):
all_layers = []
for n, m in self.named_modules():
if hasattr(m, 'attention'):
all_layers.append(m)
for m in all_layers:
m.ref_query_key_value = copy.deepcopy(m.attention.query_key_value)
m.ref_dense = copy.deepcopy(m.attention.dense)
m.ref_attention_dropout = copy.deepcopy(
m.attention.attention_dropout)
m.ref_output_dropout = copy.deepcopy(m.attention.output_dropout)
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
b, t, d, h, w = x.shape
if x.dtype != self.dtype:
x = x.to(self.dtype)
assert (y is not None) == (
self.num_classes is not None
), 'must specify y if and only if the model is class-conditional'
t_emb = timestep_embedding(timesteps,
self.model_channels,
repeat_only=False,
dtype=self.dtype)
ref_t_emb = timestep_embedding(torch.zeros_like(timesteps),
self.model_channels,
repeat_only=False,
dtype=self.dtype)
emb = self.time_embed(t_emb)
ref_t_emb = self.time_embed(ref_t_emb)
assert self.num_classes is None
self.share_cache['timesteps'] = timesteps
self.share_cache['shape_info'] = (t, h // (self.patch_size),
w // (self.patch_size))
ref_x = self.share_cache['ref_x']
ref_b, ref_t, ref_d, ref_h, ref_w = ref_x.shape
self.share_cache['ref_shape_info'] = (ref_t,
ref_h // (self.patch_size),
ref_w // (self.patch_size))
idx = kwargs.pop('idx')
kwargs['seq_length'] = t * h * w // (self.patch_size**2)
kwargs['images'] = x
kwargs['emb'] = emb
kwargs['encoder_outputs'] = context
kwargs['text_length'] = context.shape[1]
kwargs['input_ids'] = kwargs['position_ids'] = kwargs[
'attention_mask'] = torch.ones((1, 1)).to(x.dtype)
ref_kwargs = dict()
ref_kwargs['seq_length'] = ref_t * ref_h * ref_w // (self.patch_size**
2)
ref_kwargs['images'] = ref_x
ref_kwargs['emb'] = ref_t_emb
ref_kwargs['encoder_outputs'] = context
ref_kwargs['text_length'] = context.shape[1]
ref_kwargs['input_ids'] = ref_kwargs['position_ids'] = ref_kwargs[
'attention_mask'] = torch.ones((1, 1)).to(x.dtype)
self.share_cache['mode'] = 'w'
super(DiffusionTransformer, self).forward(**ref_kwargs)[0]
self.share_cache['mode'] = 'r'
output = super(DiffusionTransformer, self).forward(**kwargs)[0]
return output
import copy
import random
import torch
from dit_video_concat import (DiffusionTransformer,
Rotary3DPositionEmbeddingMixin, broadcat,
rotate_half)
from einops import rearrange, repeat
from sgm.modules.diffusionmodules.util import timestep_embedding
from torch import nn
from sat.transformer_defaults import HOOKS_DEFAULT
class ScaleCropRotary3DPositionEmbeddingMixin(Rotary3DPositionEmbeddingMixin):
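    """Rotary 3D position embedding whose sin/cos tables are built on the fly from the
    current latent shape (share_cache['shape_info']); per-layer 'rope_layer_{id}' modules
    in share_cache can predict a per-sample frequency scaling from the cached embedding.
    The computed tables are cached in share_cache['freqs_cos' / 'freqs_sin']."""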
def __init__(
self,
height,
width,
compressed_num_frames,
hidden_size,
hidden_size_head,
text_length,
theta=10000,
h_interp_ratio=1.0,
w_interp_ratio=1.0,
t_interp_ratio=1.0,
rot_v=False,
learnable_pos_embed=False,
):
super(Rotary3DPositionEmbeddingMixin, self).__init__()
self.rot_v = rot_v
print(f'theta is {theta}')
# Split the per-head dimension across the three RoPE axes:
# 1/4 for time and 3/8 each for height and width.
dim_t = hidden_size_head // 4
dim_h = hidden_size_head // 8 * 3
dim_w = hidden_size_head // 8 * 3
self.freqs_t = (1.0 / (theta**(
torch.arange(0, dim_t, 2)[:(dim_t // 2)].float() / dim_t))).cuda()
self.freqs_h = (1.0 / (theta**(
torch.arange(0, dim_h, 2)[:(dim_h // 2)].float() / dim_h))).cuda()
self.freqs_w = (1.0 / (theta**(
torch.arange(0, dim_w, 2)[:(dim_w // 2)].float() / dim_w))).cuda()
self.compressed_num_frames = compressed_num_frames
self.height = height
self.width = width
self.text_length = text_length
if learnable_pos_embed:
num_patches = height * width * compressed_num_frames + text_length
self.pos_embedding = nn.Parameter(torch.zeros(
1, num_patches, int(hidden_size)),
requires_grad=True)
else:
self.pos_embedding = None
def online_sin_cos(self, real_t, real_h, real_w, dy_interpolation=None):
grid_t = torch.arange(real_t, dtype=torch.float32, device='cuda')
grid_h = torch.arange(real_h, dtype=torch.float32, device='cuda')
grid_w = torch.arange(real_w, dtype=torch.float32, device='cuda')
freqs_t = self.freqs_t
freqs_h = self.freqs_h
freqs_w = self.freqs_w
freqs_t = torch.einsum('..., f -> ... f', grid_t, freqs_t)
freqs_h = torch.einsum('..., f -> ... f', grid_h, freqs_h)
freqs_w = torch.einsum('..., f -> ... f', grid_w, freqs_w)
freqs_t = repeat(freqs_t, '... n -> ... (n r)', r=2)
freqs_h = repeat(freqs_h, '... n -> ... (n r)', r=2)
freqs_w = repeat(freqs_w, '... n -> ... (n r)', r=2)
freqs = broadcat((freqs_t[:, None, None, :], freqs_h[None, :, None, :],
freqs_w[None, None, :, :]),
dim=-1)
freqs = rearrange(freqs, 't h w d -> (t h w) d')
temp_layer_id = self.share_cache['temp_layer_id']
emb = self.share_cache['emb']
# Per-layer dynamic frequency scaling: if a module named
# 'rope_layer_{temp_layer_id}' is registered in the shared cache, let it
# predict a scale from the conditioning embedding; otherwise use ones,
# i.e. leave the frequencies unscaled.
if f'rope_layer_{temp_layer_id}' in self.share_cache:
m = self.share_cache[f'rope_layer_{temp_layer_id}']
dy_interpolation = m(emb)
else:
dy_interpolation = freqs.new_ones(emb.shape[0], freqs.shape[-1])
b, dim = dy_interpolation.shape
dy_interpolation = dy_interpolation[:, None]
freqs = freqs[None]
freqs = freqs.repeat(b, 1, 1)
if dy_interpolation.shape[-1] != freqs.shape[-1]:
freqs[...,-dy_interpolation.shape[-1]:] = \
freqs[...,-dy_interpolation.shape[-1]:] * dy_interpolation
else:
freqs = freqs * dy_interpolation
freqs_sin = torch.sin(freqs)
freqs_cos = torch.cos(freqs)
return freqs_cos, freqs_sin
def rotary(self, t, **kwargs):
# Reuse the cos/sin tables cached in the shared cache when present;
# otherwise build them from the current shape info and cache them below.
if 'freqs_cos' in self.share_cache:
freqs_cos = self.share_cache['freqs_cos']
freqs_sin = self.share_cache['freqs_sin']
else:
real_t, real_h, real_w = self.share_cache['shape_info']
freqs_cos, freqs_sin = self.online_sin_cos(real_t,
real_h,
real_w,
dy_interpolation=None)
freqs_cos = freqs_cos.unsqueeze(1)
freqs_cos = freqs_cos.to(t.dtype)
freqs_sin = freqs_sin.unsqueeze(1)
freqs_sin = freqs_sin.to(t.dtype)
self.share_cache['freqs_cos'] = freqs_cos
self.share_cache['freqs_sin'] = freqs_sin
return t * freqs_cos + rotate_half(t) * freqs_sin
def attention_fn(
self,
query_layer,
key_layer,
value_layer,
attention_mask,
attention_dropout=None,
log_attention_weights=None,
scaling_attention_score=True,
**kwargs,
):
attention_fn_default = HOOKS_DEFAULT['attention_fn']
# Apply rotary embedding to the vision tokens only; the text prefix of
# length self.text_length is left unrotated.
img_query_layer = self.rotary(query_layer[:, :, self.text_length:])
query_layer = torch.cat(
[query_layer[:, :, :self.text_length], img_query_layer], dim=2)
query_layer = query_layer.to(value_layer.dtype)
img_key_layer = self.rotary(key_layer[:, :, self.text_length:])
key_layer = torch.cat(
[key_layer[:, :, :self.text_length], img_key_layer], dim=2)
key_layer = key_layer.to(value_layer.dtype)
if self.rot_v:
value_layer[:, :, self.text_length:] = self.rotary(
value_layer[:, :, self.text_length:])
return attention_fn_default(
query_layer,
key_layer,
value_layer,
attention_mask,
attention_dropout=attention_dropout,
log_attention_weights=log_attention_weights,
scaling_attention_score=scaling_attention_score,
**kwargs,
)
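# Illustrative sketch (not from the original model code): a standalone recap of
# the 3D rotary embedding used above. The per-head dimension is split
# 1/4 : 3/8 : 3/8 over (time, height, width), per-axis angles are built as
# position * frequency, broadcast to a (t, h, w) grid, and applied as
# q' = q * cos + rotate_half(q) * sin. All sizes below are made-up examples.
import torch


def _rope_freqs(dim, theta=10000.0):
    return 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))


def _rotate_half(x):
    # pairwise rotation: (x1, x2) -> (-x2, x1)
    x1, x2 = x[..., 0::2], x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)


_head_dim = 64
_dim_t, _dim_h, _dim_w = _head_dim // 4, _head_dim // 8 * 3, _head_dim // 8 * 3
assert _dim_t + _dim_h + _dim_w == _head_dim

_t, _h, _w = 2, 3, 3
_ang_t = torch.arange(_t)[:, None] * _rope_freqs(_dim_t)[None]
_ang_h = torch.arange(_h)[:, None] * _rope_freqs(_dim_h)[None]
_ang_w = torch.arange(_w)[:, None] * _rope_freqs(_dim_w)[None]
# repeat each angle twice, then broadcast the three axes over a (t, h, w) grid
_ang = torch.cat([
    _ang_t.repeat_interleave(2, -1)[:, None, None].expand(_t, _h, _w, _dim_t),
    _ang_h.repeat_interleave(2, -1)[None, :, None].expand(_t, _h, _w, _dim_h),
    _ang_w.repeat_interleave(2, -1)[None, None, :].expand(_t, _h, _w, _dim_w),
], dim=-1).reshape(_t * _h * _w, _head_dim)

_q = torch.randn(_t * _h * _w, _head_dim)
_q_rot = _q * _ang.cos() + _rotate_half(_q) * _ang.sin()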
from sat.model.finetune.lora2 import *
class ResLoraMixin(LoraMixin):
def reinit(self, parent_model):
for i in self.layer_range:
print_rank0(f'replacing layer {i} attention with lora')
parent_model.transformer.layers[
i].attention.dense = replace_linear_with_lora(
parent_model.transformer.layers[i].attention.dense,
1,
self.r,
self.lora_alpha,
self.lora_dropout,
qlora=self.qlora,
in_size=parent_model.transformer.hidden_size,
out_size=None)
parent_model.transformer.layers[
i].attention.query_key_value = replace_linear_with_lora(
parent_model.transformer.layers[i].attention.
query_key_value,
parent_model.transformer.layers[i].attention.stride,
self.r,
self.lora_alpha,
self.lora_dropout,
qlora=self.qlora,
in_size=parent_model.transformer.hidden_size,
out_size=None
if not parent_model.transformer.num_multi_query_heads else
parent_model.transformer.layers[i].attention.
inner_hidden_size + parent_model.transformer.layers[i].
attention.hidden_size_per_attention_head * parent_model.
transformer.layers[i].attention.num_multi_query_heads * 2)
if self.cross_attention and parent_model.transformer.layers[
i].is_decoder:
print_rank0(f'replacing layer {i} cross attention with lora')
kv_size = parent_model.transformer.layers[
i].cross_attention.inner_hidden_size * 2 if not parent_model.transformer.cross_num_multi_query_heads else parent_model.transformer.layers[
i].cross_attention.hidden_size_per_attention_head * parent_model.transformer.layers[
i].cross_attention.cross_num_multi_query_heads * 2
parent_model.transformer.layers[
i].cross_attention.dense = replace_linear_with_lora(
parent_model.transformer.layers[i].cross_attention.
dense,
1,
self.r,
self.lora_alpha,
self.lora_dropout,
qlora=self.qlora,
in_size=parent_model.transformer.layers[i].
cross_attention.inner_hidden_size,
out_size=parent_model.transformer.hidden_size)
parent_model.transformer.layers[
i].cross_attention.query = replace_linear_with_lora(
parent_model.transformer.layers[i].cross_attention.
query,
1,
self.r,
self.lora_alpha,
self.lora_dropout,
qlora=self.qlora,
in_size=parent_model.transformer.hidden_size,
out_size=parent_model.transformer.layers[i].
cross_attention.inner_hidden_size)
parent_model.transformer.layers[
i].cross_attention.key_value = replace_linear_with_lora(
parent_model.transformer.layers[i].cross_attention.
key_value,
2,
self.r,
self.lora_alpha,
self.lora_dropout,
qlora=self.qlora,
in_size=parent_model.transformer.layers[i].
cross_attention.cross_attn_hidden_size,
out_size=kv_size)
for m in parent_model.mixins.adaln_layer.adaLN_modulations:
m[1] = replace_linear_with_lora(m[1],
1,
self.r,
self.lora_alpha,
self.lora_dropout,
qlora=self.qlora,
in_size=512,
out_size=36864)
def merge_lora(self):
for i in self.layer_range:
print_rank0(f'merge layer {i} lora attention back to linear')
self.transformer.layers[i].attention.dense = merge_linear_lora(
self.transformer.layers[i].attention.dense)
self.transformer.layers[
i].attention.query_key_value = merge_linear_lora(
self.transformer.layers[i].attention.query_key_value)
if self.cross_attention and self.transformer.layers[i].is_decoder:
print_rank0(
f'merge layer {i} lora cross attention back to linear')
self.transformer.layers[
i].cross_attention.dense = merge_linear_lora(
self.transformer.layers[i].cross_attention.dense)
self.transformer.layers[
i].cross_attention.query = merge_linear_lora(
self.transformer.layers[i].cross_attention.query)
self.transformer.layers[
i].cross_attention.key_value = merge_linear_lora(
self.transformer.layers[i].cross_attention.key_value)
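# Illustrative sketch (not from the original model code): what the LoRA
# replace / merge pair above boils down to, written against a plain nn.Linear
# rather than sat's replace_linear_with_lora / merge_linear_lora helpers.
# The frozen base weight W receives a low-rank update (alpha / r) * B @ A;
# merging folds that update back into W so inference needs no extra modules.
import torch
from torch import nn


class _ToyLoraLinear(nn.Module):

    def __init__(self, base: nn.Linear, r=8, alpha=16):
        super().__init__()
        self.base = base
        self.base.weight.requires_grad_(False)
        self.scale = alpha / r
        self.lora_a = nn.Parameter(torch.randn(r, base.in_features) * 0.01)
        self.lora_b = nn.Parameter(torch.zeros(base.out_features, r))

    def forward(self, x):
        return self.base(x) + self.scale * (x @ self.lora_a.T @ self.lora_b.T)

    def merge(self) -> nn.Linear:
        # W <- W + (alpha / r) * B @ A
        merged = nn.Linear(self.base.in_features, self.base.out_features)
        merged.load_state_dict(self.base.state_dict())
        with torch.no_grad():
            merged.weight += self.scale * (self.lora_b @ self.lora_a)
        return merged


_lora = _ToyLoraLinear(nn.Linear(16, 16))
with torch.no_grad():
    _lora.lora_b.normal_(std=0.02)   # pretend some training happened
_x = torch.randn(2, 16)
assert torch.allclose(_lora(_x), _lora.merge()(_x), atol=1e-5)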
class SMALLDiffusionTransformer(DiffusionTransformer):
def register_new_modules(self):
if 'sample_ref_noise_step' in self.share_cache:
self.ref_step_time_embedding = copy.deepcopy(self.time_embed)
# Zero-init the last linear layer of self.ref_step_time_embedding so the
# extra embedding is a no-op at the start of training.
for n, p in self.ref_step_time_embedding[-1].named_parameters():
nn.init.constant_(p, 0)
p.requires_grad = True
if 'time_size_embedding' in self.share_cache:
self.time_size_embedding = copy.deepcopy(self.time_embed)
# Zero-init the whole size-embedding MLP so it contributes nothing until
# it is trained.
for n, p in self.time_size_embedding.named_parameters():
nn.init.constant_(p, 0)
p.requires_grad = True
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
b, t, d, h, w = x.shape
if x.dtype != self.dtype:
x = x.to(self.dtype)
if 'ref_x' in self.share_cache:
ref_x = self.share_cache['ref_x']
if ref_x.dtype != self.dtype:
ref_x = ref_x.to(self.dtype)
self.share_cache['ref_x'] = ref_x
# This branch is not used during inference.
if 'concat_images' in kwargs and kwargs['concat_images'] is not None:
if kwargs['concat_images'].shape[0] != x.shape[0]:
concat_images = kwargs['concat_images'].repeat(2, 1, 1, 1, 1)
else:
concat_images = kwargs['concat_images']
x = torch.cat([x, concat_images], dim=2)
assert (y is not None) == (
self.num_classes is not None
), 'must specify y if and only if the model is class-conditional'
t_emb = timestep_embedding(timesteps,
self.model_channels,
repeat_only=False,
dtype=self.dtype)
emb = self.time_embed(t_emb)
if 'time_size_embedding' in self.share_cache:
# Embed the number of latent frames with the same sinusoidal scheme as the
# diffusion timestep and add it to the conditioning embedding.
num_t = torch.zeros_like(timesteps).fill_(int(t))
time_size_emb = timestep_embedding(num_t,
self.model_channels,
repeat_only=False,
dtype=self.dtype)
time_size_emb = self.time_size_embedding(time_size_emb)
emb = emb + time_size_emb
if 'sample_ref_noise_step' in self.share_cache:
print(
f'''sample_ref_noise_step {self.share_cache["sample_ref_noise_step"]}'''
)
# Constant timestep tensor filled with the reference-noise step; deepcopy
# keeps the dtype/device of `timesteps` (e.g. bf16).
ref_time_step = copy.deepcopy(timesteps).fill_(
self.share_cache['sample_ref_noise_step'])
ref_step_time_emb = timestep_embedding(ref_time_step,
self.model_channels,
repeat_only=False,
dtype=self.dtype)
ref_step_time_emb = self.ref_step_time_embedding(ref_step_time_emb)
if not self.training:
print(f'{ref_time_step} get {ref_step_time_emb.sum()}')
emb = emb + ref_step_time_emb
if self.num_classes is not None:
assert x.shape[0] % y.shape[0] == 0
y = y.repeat_interleave(x.shape[0] // y.shape[0], dim=0)
emb = emb + self.label_emb(y)
self.share_cache['shape_info'] = (t, h // (self.patch_size),
w // (self.patch_size))
self.share_cache['timesteps'] = int(timesteps[0])
self.share_cache['emb'] = emb
kwargs['seq_length'] = t * h * w // (self.patch_size**2)
kwargs['images'] = x
kwargs['emb'] = emb
kwargs['encoder_outputs'] = context
kwargs['text_length'] = context.shape[1]
kwargs['input_ids'] = kwargs['position_ids'] = kwargs[
'attention_mask'] = torch.ones((1, 1)).to(x.dtype)
output = super(DiffusionTransformer, self).forward(**kwargs)[0]
return output
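# Illustrative sketch (not from the original model code): why the extra
# embedding branches registered above are zero-initialised. With the final
# linear (or the whole MLP) set to zero, the branch outputs exactly zero, so
# `emb + extra` reproduces the pretrained behaviour at step 0 and the new
# conditioning (frame count / reference noise step) is picked up gradually.
# Sizes below are made-up example values.
import torch
from torch import nn

_extra_mlp = nn.Sequential(nn.Linear(32, 64), nn.SiLU(), nn.Linear(64, 64))
nn.init.zeros_(_extra_mlp[-1].weight)
nn.init.zeros_(_extra_mlp[-1].bias)
assert torch.all(_extra_mlp(torch.randn(4, 32)) == 0)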
import math
import time
from functools import partial
import torch
import torch.nn as nn
from sgm.modules import UNCONDITIONAL_CONFIG
from sgm.modules.autoencoding.temporal_ae import VideoDecoder
from sgm.modules.diffusionmodules.loss import StandardDiffusionLoss
from sgm.modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
from sgm.util import (append_dims, default, disabled_train, get_obj_from_str,
instantiate_from_config)
from torch import nn
from torchdiffeq import odeint
class FlowEngine(nn.Module):
def __init__(self, args, **kwargs):
super().__init__()
model_config = args.model_config
log_keys = model_config.get('log_keys', None)
input_key = model_config.get('input_key', 'mp4')
network_config = model_config.get('network_config', None)
network_wrapper = model_config.get('network_wrapper', None)
denoiser_config = model_config.get('denoiser_config', None)
sampler_config = model_config.get('sampler_config', None)
conditioner_config = model_config.get('conditioner_config', None)
first_stage_config = model_config.get('first_stage_config', None)
loss_fn_config = model_config.get('loss_fn_config', None)
scale_factor = model_config.get('scale_factor', 1.0)
latent_input = model_config.get('latent_input', False)
disable_first_stage_autocast = model_config.get(
'disable_first_stage_autocast', False)
no_cond_log = model_config.get('no_cond_log', False)
not_trainable_prefixes = model_config.get(
'not_trainable_prefixes', ['first_stage_model', 'conditioner'])
compile_model = model_config.get('compile_model', False)
en_and_decode_n_samples_a_time = model_config.get(
'en_and_decode_n_samples_a_time', None)
lr_scale = model_config.get('lr_scale', None)
lora_train = model_config.get('lora_train', False)
self.use_pd = model_config.get('use_pd', False)
self.log_keys = log_keys
self.input_key = input_key
self.not_trainable_prefixes = not_trainable_prefixes
self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time
self.lr_scale = lr_scale
self.lora_train = lora_train
self.noised_image_input = model_config.get('noised_image_input', False)
self.noised_image_all_concat = model_config.get(
'noised_image_all_concat', False)
self.noised_image_dropout = model_config.get('noised_image_dropout',
0.0)
if args.fp16:
dtype = torch.float16
dtype_str = 'fp16'
elif args.bf16:
dtype = torch.bfloat16
dtype_str = 'bf16'
else:
dtype = torch.float32
dtype_str = 'fp32'
self.dtype = dtype
self.dtype_str = dtype_str
network_config['params']['dtype'] = dtype_str
model = instantiate_from_config(network_config)
self.model = get_obj_from_str(
default(network_wrapper,
OPENAIUNETWRAPPER))(model,
compile_model=compile_model,
dtype=dtype)
self.denoiser = instantiate_from_config(denoiser_config)
self.sampler = instantiate_from_config(
sampler_config) if sampler_config is not None else None
self.conditioner = instantiate_from_config(
default(conditioner_config, UNCONDITIONAL_CONFIG))
self._init_first_stage(first_stage_config)
self.loss_fn = instantiate_from_config(
loss_fn_config) if loss_fn_config is not None else None
self.latent_input = latent_input
self.scale_factor = scale_factor
self.disable_first_stage_autocast = disable_first_stage_autocast
self.no_cond_log = no_cond_log
self.device = args.device
def disable_untrainable_params(self):
pass
def reinit(self, parent_model=None):
pass
def _init_first_stage(self, config):
model = instantiate_from_config(config).eval()
model.train = disabled_train
for param in model.parameters():
param.requires_grad = False
self.first_stage_model = model
def get_input(self, batch):
return batch[self.input_key].to(self.dtype)
@torch.no_grad()
def decode_first_stage(self, z):
z = 1.0 / self.scale_factor * z
n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])
n_rounds = math.ceil(z.shape[0] / n_samples)
all_out = []
with torch.autocast('cuda',
enabled=not self.disable_first_stage_autocast):
for n in range(n_rounds):
if isinstance(self.first_stage_model.decoder, VideoDecoder):
kwargs = {
'timesteps': len(z[n * n_samples:(n + 1) * n_samples])
}
else:
kwargs = {}
out = self.first_stage_model.decode(
z[n * n_samples:(n + 1) * n_samples], **kwargs)
all_out.append(out)
out = torch.cat(all_out, dim=0)
return out
@torch.no_grad()
def encode_first_stage(self, x, batch):
frame = x.shape[2]
if frame > 1 and self.latent_input:
x = x.permute(0, 2, 1, 3, 4).contiguous()
return x * self.scale_factor # already encoded
n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0])
n_rounds = math.ceil(x.shape[0] / n_samples)
all_out = []
with torch.autocast('cuda',
enabled=not self.disable_first_stage_autocast):
for n in range(n_rounds):
out = self.first_stage_model.encode(x[n * n_samples:(n + 1) *
n_samples])
all_out.append(out)
z = torch.cat(all_out, dim=0)
z = self.scale_factor * z
return z
@torch.no_grad()
def save_memory_encode_first_stage(self, x, batch):
# Encode in chunks along the frame axis to save memory; the [13, 12, 12, 12]
# split assumes a 49-frame input.
splits_x = torch.split(x, [13, 12, 12, 12], dim=2)
all_out = []
with torch.autocast('cuda', enabled=False):
for idx, input_x in enumerate(splits_x):
if idx == len(splits_x) - 1:
clear_fake_cp_cache = True
else:
clear_fake_cp_cache = False
out = self.first_stage_model.encode(
input_x.contiguous(),
clear_fake_cp_cache=clear_fake_cp_cache)
all_out.append(out)
z = torch.cat(all_out, dim=2)
z = self.scale_factor * z
return z
def single_function_evaluation(self,
t,
x,
cond=None,
uc=None,
cfg=1,
**kwargs):
start_time = time.time()
# Classifier-free guidance: run the conditional and unconditional branches
# in a single forward pass by duplicating the batch.
x = torch.cat([x] * 2)
t = t.reshape(1).to(x.dtype).to(x.device)
t = torch.cat([t] * 2)
# Map flow time t in [0, 1] to the (descending) diffusion timestep index.
idx = 1000 - (t * 1000)
real_cond = dict()
for k, v in cond.items():
uncond_v = uc[k]
real_cond[k] = torch.cat([v, uncond_v])
vt = self.model(x, t=idx, c=real_cond, idx=idx)
vt, uc_vt = vt.chunk(2)
# CFG combination: v = v_uncond + cfg * (v_cond - v_uncond)
vt = uc_vt + cfg * (vt - uc_vt)
end_time = time.time()
print(f'single_function_evaluation time at {t}', end_time - start_time)
return vt
@torch.no_grad()
def sample(
self,
ref_x,
cond,
uc,
**sample_kwargs,
):
"""Stage 2 Sampling, start from the first stage results `ref_x`
Args:
ref_x (_type_): Stage1 low resolution video
cond (dict): Dict contains condtion embeddings
uc (dict): Dict contains uncondition embedding
Returns:
Tensor: Secondary stage results
"""
sample_kwargs = sample_kwargs or {}
print('sample_kwargs', sample_kwargs)
# timesteps
num_steps = sample_kwargs.get('num_steps', 4)
t = torch.linspace(0, 1, num_steps + 1,
dtype=ref_x.dtype).to(ref_x.device)
print(self.share_cache['shift_t'])
shift_t = float(self.share_cache['shift_t'])
# Timestep shifting: warp the uniform grid while keeping the endpoints 0 and
# 1 fixed; shift_t = 1 leaves it unchanged, shift_t > 1 concentrates steps
# near t = 0.
t = 1 - shift_t * (1 - t) / (1 + (shift_t - 1) * (1 - t))
print('sample:', t)
single_function_evaluation = partial(self.single_function_evaluation,
cond=cond,
uc=uc,
cfg=sample_kwargs.get('cfg', 1))
ref_noise_step = self.share_cache['sample_ref_noise_step']
print(f'ref_noise_step : {ref_noise_step}')
# Look up sqrt(alpha_bar) for the chosen reference-noise step; the stage-1
# latents are diffused to this level before being used as the ODE start point.
ref_alphas_cumprod_sqrt = self.loss_fn.sigma_sampler.idx_to_sigma(
torch.zeros(ref_x.shape[0]).fill_(ref_noise_step).long())
ref_alphas_cumprod_sqrt = ref_alphas_cumprod_sqrt.to(ref_x.device)
ori_dtype = ref_x.dtype
ref_noise = torch.randn_like(ref_x)
print('weight', ref_alphas_cumprod_sqrt, flush=True)
ref_noised_input = ref_x * append_dims(ref_alphas_cumprod_sqrt, ref_x.ndim) \
+ ref_noise * append_dims(
(1 - ref_alphas_cumprod_sqrt**2) ** 0.5, ref_x.ndim
)
ref_x = ref_noised_input.to(ori_dtype)
self.share_cache['ref_x'] = ref_x
results = odeint(single_function_evaluation,
ref_x,
t,
method=sample_kwargs.get('method', 'rk4'),
atol=1e-6,
rtol=1e-3)[-1]
return results
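# Illustrative sketch (not from the original model code): two pieces of
# FlowEngine.sample with toy numbers.
# (1) The shifted time grid t' = 1 - s(1 - t) / (1 + (s - 1)(1 - t)) keeps the
#     endpoints 0 and 1 fixed; s = 1 is the identity, while s > 1 packs steps
#     near t = 0.
# (2) The reference latents are noised as x = w * x0 + sqrt(1 - w^2) * eps,
#     with w = sqrt(alpha_bar) at `sample_ref_noise_step` (here a made-up
#     value), before serving as the ODE start point.
import torch

_t = torch.linspace(0, 1, 5)
_s = 3.0
_t_shift = 1 - _s * (1 - _t) / (1 + (_s - 1) * (1 - _t))
assert torch.allclose(_t_shift[0], torch.tensor(0.0))
assert torch.allclose(_t_shift[-1], torch.tensor(1.0))

_ref = torch.randn(1, 4, 8, 8)
_w = torch.tensor(0.8)                  # assumed sqrt(alpha_bar)
_ref_noised = _ref * _w + torch.randn_like(_ref) * (1 - _w ** 2) ** 0.5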
class FlowVideoDiffusionLoss(StandardDiffusionLoss):
def __init__(self,
block_scale=None,
block_size=None,
min_snr_value=None,
fixed_frames=0,
**kwargs):
self.fixed_frames = fixed_frames
self.block_scale = block_scale
self.block_size = block_size
self.min_snr_value = min_snr_value
self.schedule = None
super().__init__(**kwargs)
def __call__(self, network, denoiser, conditioner, input, batch):
pass
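# Illustrative sketch (not from the original model code): the loss __call__
# above is left as a stub. Purely for orientation, a generic flow-matching
# objective that matches the sampler's convention (t = 0 noisy, t = 1 clean)
# interpolates x_t = t * x0 + (1 - t) * eps and regresses the network output
# onto the straight-line velocity (x0 - eps). This is a common formulation,
# not necessarily the loss this repository was trained with.
import torch


def _toy_flow_matching_loss(net, x0):
    eps = torch.randn_like(x0)
    t = torch.rand(x0.shape[0], *([1] * (x0.dim() - 1)), device=x0.device)
    x_t = t * x0 + (1 - t) * eps
    v_target = x0 - eps
    return ((net(x_t, t) - v_target) ** 2).mean()


_loss = _toy_flow_matching_loss(lambda x, t: torch.zeros_like(x),
                                torch.randn(2, 4, 8, 8))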
from .models import AutoencodingEngine
from .util import get_configs_path, instantiate_from_config
__version__ = '0.1.0'
import numpy as np
class LambdaWarmUpCosineScheduler:
"""
note: use with a base_lr of 1.0
"""
def __init__(
self,
warm_up_steps,
lr_min,
lr_max,
lr_start,
max_decay_steps,
verbosity_interval=0,
):
self.lr_warm_up_steps = warm_up_steps
self.lr_start = lr_start
self.lr_min = lr_min
self.lr_max = lr_max
self.lr_max_decay_steps = max_decay_steps
self.last_lr = 0.0
self.verbosity_interval = verbosity_interval
def schedule(self, n, **kwargs):
if self.verbosity_interval > 0:
if n % self.verbosity_interval == 0:
print(
f'current step: {n}, recent lr-multiplier: {self.last_lr}')
if n < self.lr_warm_up_steps:
lr = (self.lr_max -
self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
self.last_lr = lr
return lr
else:
t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps -
self.lr_warm_up_steps)
t = min(t, 1.0)
lr = self.lr_min + 0.5 * (self.lr_max -
self.lr_min) * (1 + np.cos(t * np.pi))
self.last_lr = lr
return lr
def __call__(self, n, **kwargs):
return self.schedule(n, **kwargs)
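# Illustrative sketch (not from the original model code): the scheduler returns
# the learning rate directly (hence the "base_lr of 1.0" note), so one way to
# plug it into PyTorch is via LambdaLR with a base lr of 1.0. The
# hyperparameters below are made up.
import torch

_sched_fn = LambdaWarmUpCosineScheduler(warm_up_steps=100,
                                        lr_min=0.1,
                                        lr_max=1.0,
                                        lr_start=0.0,
                                        max_decay_steps=1000)
_opt = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=1.0)
_lr_sched = torch.optim.lr_scheduler.LambdaLR(_opt, lr_lambda=_sched_fn)
for _ in range(5):
    _opt.step()
    _lr_sched.step()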
class LambdaWarmUpCosineScheduler2:
"""
supports repeated iterations, configurable via lists
note: use with a base_lr of 1.0.
"""
def __init__(self,
warm_up_steps,
f_min,
f_max,
f_start,
cycle_lengths,
verbosity_interval=0):
assert len(warm_up_steps) == len(f_min) == len(f_max) == len(
f_start) == len(cycle_lengths)
self.lr_warm_up_steps = warm_up_steps
self.f_start = f_start
self.f_min = f_min
self.f_max = f_max
self.cycle_lengths = cycle_lengths
self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
self.last_f = 0.0
self.verbosity_interval = verbosity_interval
def find_in_interval(self, n):
interval = 0
for cl in self.cum_cycles[1:]:
if n <= cl:
return interval
interval += 1
def schedule(self, n, **kwargs):
cycle = self.find_in_interval(n)
n = n - self.cum_cycles[cycle]
if self.verbosity_interval > 0:
if n % self.verbosity_interval == 0:
print(
f'current step: {n}, recent lr-multiplier: {self.last_f}, '
f'current cycle {cycle}')
if n < self.lr_warm_up_steps[cycle]:
f = (self.f_max[cycle] - self.f_start[cycle]
) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
self.last_f = f
return f
else:
t = (n - self.lr_warm_up_steps[cycle]) / (
self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
t = min(t, 1.0)
f = self.f_min[cycle] + 0.5 * (
self.f_max[cycle] - self.f_min[cycle]) * (1 +
np.cos(t * np.pi))
self.last_f = f
return f
def __call__(self, n, **kwargs):
return self.schedule(n, **kwargs)
class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
def schedule(self, n, **kwargs):
cycle = self.find_in_interval(n)
n = n - self.cum_cycles[cycle]
if self.verbosity_interval > 0:
if n % self.verbosity_interval == 0:
print(
f'current step: {n}, recent lr-multiplier: {self.last_f}, '
f'current cycle {cycle}')
if n < self.lr_warm_up_steps[cycle]:
f = (self.f_max[cycle] - self.f_start[cycle]
) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
self.last_f = f
return f
else:
f = (self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) *
(self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]))
self.last_f = f
return f
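# Illustrative sketch (not from the original model code): the cycle-based
# schedulers take one entry per cycle. A single-cycle configuration with
# made-up numbers, evaluated at a few steps to show warm-up followed by decay:
_cos_sched = LambdaWarmUpCosineScheduler2(warm_up_steps=[10],
                                          f_min=[0.1],
                                          f_max=[1.0],
                                          f_start=[0.0],
                                          cycle_lengths=[100])
_lin_sched = LambdaLinearScheduler(warm_up_steps=[10],
                                   f_min=[0.0],
                                   f_max=[1.0],
                                   f_start=[0.0],
                                   cycle_lengths=[100])
_multipliers = [(_cos_sched(n), _lin_sched(n)) for n in (0, 10, 55, 100)]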
from .autoencoder import AutoencodingEngine