import argparse
import datetime
import json
import os
import warnings
import omegaconf
import torch
import torch.distributed
from omegaconf import OmegaConf
from sat import mpu
from sat.arguments import (add_data_args, add_evaluation_args,
add_training_args, set_random_seed)
from sat.helpers import print_rank0
def add_model_config_args(parser):
"""Model arguments"""
group = parser.add_argument_group('model', 'model configuration')
group.add_argument('--base',
type=str,
nargs='*',
help='path(s) to the first-stage config file(s) used for input and saving')
group.add_argument('--second',
type=str,
nargs='*',
help='path(s) to the second-stage config file(s) used for input and saving')
group.add_argument(
'--model-parallel-size',
type=int,
default=1,
help='size of model parallelism; only change this if you are an expert.')
group.add_argument('--force-pretrain', action='store_true')
group.add_argument('--device', type=int, default=-1)
group.add_argument('--debug', action='store_true')
group.add_argument('--log-image', type=bool, default=True)
group.add_argument('--inf-ckpt', type=str, default=None)
group.add_argument('--inf-ckpt2', type=str, default=None)
group.add_argument('--skip-second', action='store_true')
group.add_argument('--first-stage-re',
type=int,
default=270,
help='resolution of first stage')
return parser
def add_sampling_config_args(parser):
"""Sampling configurations"""
group = parser.add_argument_group('sampling', 'Sampling Configurations')
group.add_argument('--output-dir', type=str, default='samples')
group.add_argument('--input-dir', type=str, default=None)
group.add_argument('--input-type', type=str, default='cli')
group.add_argument('--input-file', type=str, default='./ht100.txt')
group.add_argument('--final-size', type=int, default=2048)
group.add_argument('--sdedit', action='store_true')
group.add_argument('--grid-num-rows', type=int, default=1)
group.add_argument('--force-inference', action='store_true')
group.add_argument('--lcm_steps', type=int, default=None)
group.add_argument('--sampling-num-frames', type=int, default=32)
group.add_argument('--sampling-num-steps', type=int, default=30)
group.add_argument('--sampling-fps', type=int, default=8)
group.add_argument('--only-save-latents', type=bool, default=False)
group.add_argument('--only-log-video-latents', type=bool, default=False)
group.add_argument('--latent-channels', type=int, default=32)
group.add_argument('--image2video', action='store_true')
return parser
def add_extra_config_args(parser):
group = parser.add_argument_group('joint', 'joint training Configurations')
group.add_argument('--img-iter', type=int, default=0)
group.add_argument('--video-iter', type=int, default=0)
return parser
def get_args(args_list=None, parser=None):
"""Parse all the args."""
if parser is None:
parser = argparse.ArgumentParser(description='sat')
else:
assert isinstance(parser, argparse.ArgumentParser)
parser = add_model_config_args(parser)
parser = add_sampling_config_args(parser)
parser = add_training_args(parser)
parser = add_evaluation_args(parser)
parser = add_data_args(parser)
parser = add_extra_config_args(parser)
import deepspeed
parser = deepspeed.add_config_arguments(parser)
args = parser.parse_args(args_list)
args = process_config_to_args(args)
if not args.train_data:
print_rank0('No training data specified', level='WARNING')
assert (args.train_iters is None) or (
args.epochs is
None), 'only one of train_iters and epochs should be set.'
if args.train_iters is None and args.epochs is None:
args.train_iters = 10000 # default 10k iters
print_rank0(
'No train_iters (recommended) or epochs specified; using the default of 10k iters.',
level='WARNING')
args.cuda = torch.cuda.is_available()
args.rank = int(os.getenv('RANK', '0'))
args.world_size = int(os.getenv('WORLD_SIZE', '1'))
if args.local_rank is None:
args.local_rank = int(os.getenv('LOCAL_RANK', '0')) # torchrun
if args.device == -1:
if torch.cuda.device_count() == 0:
args.device = 'cpu'
elif args.local_rank is not None:
args.device = args.local_rank
else:
args.device = args.rank % torch.cuda.device_count()
if args.local_rank != args.device and args.mode != 'inference':
raise ValueError(
'LOCAL_RANK (default 0) and args.device inconsistent. '
'This can only happen in inference mode. '
'Please use CUDA_VISIBLE_DEVICES=x for single-GPU training. ')
if args.rank == 0:
print_rank0(f'using world size: {args.world_size}')
if args.train_data_weights is not None:
assert len(args.train_data_weights) == len(args.train_data)
if args.mode != 'inference': # training with deepspeed
args.deepspeed = True
if args.deepspeed_config is None: # not specified
deepspeed_config_path = os.path.join(
os.path.dirname(__file__), 'training',
f'deepspeed_zero{args.zero_stage}.json')
with open(deepspeed_config_path) as file:
args.deepspeed_config = json.load(file)
override_deepspeed_config = True
else:
override_deepspeed_config = False
assert not (args.fp16 and args.bf16), 'cannot specify both fp16 and bf16.'
if args.zero_stage > 0 and not args.fp16 and not args.bf16:
print_rank0('Automatically set fp16=True to use ZeRO.')
args.fp16 = True
args.bf16 = False
if args.deepspeed:
if args.checkpoint_activations:
args.deepspeed_activation_checkpointing = True
else:
args.deepspeed_activation_checkpointing = False
if args.deepspeed_config is not None:
deepspeed_config = args.deepspeed_config
if override_deepspeed_config:  # deepspeed_config not specified, fill it from args
if args.fp16:
deepspeed_config['fp16']['enabled'] = True
elif args.bf16:
deepspeed_config['bf16']['enabled'] = True
deepspeed_config['fp16']['enabled'] = False
else:
deepspeed_config['fp16']['enabled'] = False
deepspeed_config[
'train_micro_batch_size_per_gpu'] = args.batch_size
deepspeed_config[
'gradient_accumulation_steps'] = args.gradient_accumulation_steps
optimizer_params_config = deepspeed_config['optimizer']['params']
optimizer_params_config['lr'] = args.lr
optimizer_params_config['weight_decay'] = args.weight_decay
else: # override args with values in deepspeed_config
if args.rank == 0:
print_rank0(
'Will override arguments with manually specified deepspeed_config!'
)
if 'fp16' in deepspeed_config and deepspeed_config['fp16'][
'enabled']:
args.fp16 = True
else:
args.fp16 = False
if 'bf16' in deepspeed_config and deepspeed_config['bf16'][
'enabled']:
args.bf16 = True
else:
args.bf16 = False
if 'train_micro_batch_size_per_gpu' in deepspeed_config:
args.batch_size = deepspeed_config[
'train_micro_batch_size_per_gpu']
if 'gradient_accumulation_steps' in deepspeed_config:
args.gradient_accumulation_steps = deepspeed_config[
'gradient_accumulation_steps']
else:
args.gradient_accumulation_steps = None
if 'optimizer' in deepspeed_config:
optimizer_params_config = deepspeed_config['optimizer'].get(
'params', {})
args.lr = optimizer_params_config.get('lr', args.lr)
args.weight_decay = optimizer_params_config.get(
'weight_decay', args.weight_decay)
args.deepspeed_config = deepspeed_config
# initialize distributed and random seed because it always seems to be necessary.
initialize_distributed(args)
args.seed = args.seed + mpu.get_data_parallel_rank()
set_random_seed(args.seed)
return args
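# Illustrative sketch (not part of the original file): the minimal shape of a DeepSpeed
# config dict that the override logic in get_args() reads from and writes to. The field
# values below are placeholders, not recommendations.
_EXAMPLE_DEEPSPEED_CONFIG = {
    'train_micro_batch_size_per_gpu': 1,   # mirrored to/from args.batch_size
    'gradient_accumulation_steps': 1,      # mirrored to/from args.gradient_accumulation_steps
    'fp16': {'enabled': False},            # mirrored to/from args.fp16
    'bf16': {'enabled': True},             # mirrored to/from args.bf16
    'optimizer': {'params': {'lr': 1e-4, 'weight_decay': 1e-2}},  # mirrored to/from args.lr / args.weight_decay
}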
def initialize_distributed(args):
"""Initialize torch.distributed."""
if torch.distributed.is_initialized():
if mpu.model_parallel_is_initialized():
if args.model_parallel_size != mpu.get_model_parallel_world_size():
raise ValueError(
'model_parallel_size is inconsistent with prior configuration.'
' We currently do not support changing model_parallel_size.'
)
return False
else:
if args.model_parallel_size > 1:
warnings.warn(
'model_parallel_size > 1 but torch.distributed is not initialized via SAT.'
' Please carefully verify correctness on your own.')
mpu.initialize_model_parallel(args.model_parallel_size)
return True
# the automatic assignment of devices has been moved to arguments.py
if args.device == 'cpu':
pass
else:
torch.cuda.set_device(args.device)
# Call the init process
init_method = 'tcp://'
args.master_ip = os.getenv('MASTER_ADDR', 'localhost')
if args.world_size == 1:
from sat.helpers import get_free_port
default_master_port = str(get_free_port())
else:
default_master_port = '6000'
args.master_port = os.getenv('MASTER_PORT', default_master_port)
init_method = None
#init_method += args.master_ip + ":" + args.master_port
torch.distributed.init_process_group(
backend=args.distributed_backend,
world_size=args.world_size,
rank=args.rank,
init_method=init_method,
timeout=datetime.timedelta(seconds=1200))
# Set the model-parallel / data-parallel communicators.
mpu.initialize_model_parallel(args.model_parallel_size)
# Set vae context parallel group equal to model parallel group
from sgm.util import (initialize_context_parallel,
set_context_parallel_group)
if args.model_parallel_size <= 2:
set_context_parallel_group(args.model_parallel_size,
mpu.get_model_parallel_group())
else:
initialize_context_parallel(2)
# mpu.initialize_model_parallel(1)
# Optional DeepSpeed Activation Checkpointing Features
if args.deepspeed:
import deepspeed
deepspeed.init_distributed(dist_backend=args.distributed_backend,
world_size=args.world_size,
rank=args.rank,
init_method=init_method)
# # Configuring this even without activation checkpointing seems to have no negative effect.
# deepspeed.checkpointing.configure(mpu, deepspeed_config=args.deepspeed_config, num_checkpoints=args.num_layers)
else:
# In model-only mode we don't want to init deepspeed, but we still need to init the RNG tracker for model parallelism, because the seed is saved by default when using dropout.
try:
import deepspeed
from deepspeed.runtime.activation_checkpointing.checkpointing import (
_CUDA_RNG_STATE_TRACKER, _MODEL_PARALLEL_RNG_TRACKER_NAME)
_CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME,
1) # default seed 1
except Exception as e:
from sat.helpers import print_rank0
print_rank0(str(e), level='DEBUG')
return True
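# Illustrative sketch (not part of the original file): the environment variables that
# get_args()/initialize_distributed() read for a single-process, single-GPU run. torchrun
# sets these automatically in multi-GPU launches; the values below are placeholders.
def _example_single_process_env():
    os.environ.setdefault('RANK', '0')
    os.environ.setdefault('WORLD_SIZE', '1')
    os.environ.setdefault('LOCAL_RANK', '0')
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '12345')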
def process_config_to_args(args):
"""Fetch args from only --base"""
configs = [OmegaConf.load(cfg) for cfg in args.base]
config = OmegaConf.merge(*configs)
args_config = config.pop('args', OmegaConf.create())
for key in args_config:
if isinstance(args_config[key], omegaconf.DictConfig) or isinstance(
args_config[key], omegaconf.ListConfig):
arg = OmegaConf.to_object(args_config[key])
else:
arg = args_config[key]
if hasattr(args, key):
setattr(args, key, arg)
if 'model' in config:
model_config = config.pop('model', OmegaConf.create())
args.model_config = model_config
if 'deepspeed' in config:
deepspeed_config = config.pop('deepspeed', OmegaConf.create())
args.deepspeed_config = OmegaConf.to_object(deepspeed_config)
if 'data' in config:
data_config = config.pop('data', OmegaConf.create())
args.data_config = data_config
if 'img_data' in config:
img_data_config = config.pop('img_data', OmegaConf.create())
args.img_data_config = img_data_config
if 'video_data' in config:
video_data_config = config.pop('video_data', OmegaConf.create())
args.video_data_config = video_data_config
if 'trainable_params' in config:
trainable_params = config.pop('trainable_params', OmegaConf.create())
args.trainable_params_config = trainable_params
if 'share_cache_args' in config:
share_cache_args = config.pop('share_cache_args', OmegaConf.create())
args.share_cache_config = share_cache_args
if 'custom_args' in config:
custom_args = config.pop('custom_args', OmegaConf.create())
for k, v in custom_args.items():
setattr(args, k, v)
return args
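# Illustrative sketch (not part of the original file): how multiple --base configs are
# merged by process_config_to_args(). OmegaConf.merge gives later configs precedence for
# overlapping keys, so a stage-specific YAML can override shared defaults. The two inline
# configs here are hypothetical.
def _example_merge_base_configs():
    cfg_shared = OmegaConf.create({'args': {'lr': 1e-4, 'mode': 'finetune'}})
    cfg_stage = OmegaConf.create({'args': {'lr': 2e-5}})
    merged = OmegaConf.merge(cfg_shared, cfg_stage)
    assert merged.args.lr == 2e-5  # the later config wins
    return merged.args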
'''
@File : base_model.py
@Time : 2021/10/01 22:40:33
@Author : Ming Ding
@Contact : dm18@mails.tsinghua.edu.cn
'''
import argparse
import inspect
import math
import os
import random
import sys
import warnings
from functools import partial
import torch
from base_transformer import GCBaseTransformer
from sat.arguments import (overwrite_args_by_dict, reset_random_seed,
set_random_seed, update_args_with_file)
from sat.helpers import print_rank0
from sat.model.mixins import BaseMixin
from sat.model.registry import MetaModel, model_registry
from sat.model.transformer import standard_attention
from sat.mpu.initialize import (destroy_model_parallel,
get_model_parallel_rank, get_node_rank,
initialize_model_parallel)
from sat.mpu.operation import (mp_merge_model_rank0, mp_merge_model_send,
mp_split_model_rank0, mp_split_model_receive)
from sat.resources import auto_create
from sat.training.model_io import load_checkpoint
from sat.transformer_defaults import ARGS_DEFAULT, HOOKS_DEFAULT
class BaseModel(torch.nn.Module, metaclass=MetaModel):
def __init__(self,
args,
transformer=None,
params_dtype=torch.float,
**kwargs):
super().__init__()
self.mixins = torch.nn.ModuleDict()
self.collect_hooks_()
if transformer is not None:
self.transformer = transformer
else:
# check if model-only mode
from sat.arguments import _simple_init
success = _simple_init(
model_parallel_size=args.model_parallel_size,
seed=args.seed if hasattr(args, 'seed') else 1234)
args_dict = {
k: (getattr(args, v[0]) if hasattr(args, v[0]) else v[1])
for k, v in ARGS_DEFAULT.items()
}
self.transformer = GCBaseTransformer(
num_layers=args.num_layers,
vocab_size=args.vocab_size,
hidden_size=args.hidden_size,
num_attention_heads=args.num_attention_heads,
max_sequence_length=args.max_sequence_length,
layernorm_order=args.layernorm_order,
**args_dict,
hooks=self.hooks,
params_dtype=params_dtype,
skip_init=args.skip_init,
device=torch.cuda.current_device()
if hasattr(args, 'use_gpu_initialization')
and args.use_gpu_initialization else torch.device('cpu'),
**kwargs)
def reinit(self,
mixin_names=None
): # will be called when loading model, None means all
# if some mixins are loaded, override this function
for k, m in self.mixins.items():
if mixin_names is None or k in mixin_names:
m.reinit(self)
def add_mixin(self, name, new_mixin, reinit=False):
assert name not in self.mixins
assert isinstance(new_mixin, BaseMixin)
self.mixins[name] = new_mixin # will auto-register parameters
object.__setattr__(new_mixin, 'transformer',
self.transformer) # cannot use pytorch set_attr
self.collect_hooks_()
if reinit:
new_mixin.reinit(self) # also pass current mixins
def del_mixin(self, name):
assert name in self.mixins
del self.mixins[name]
self.collect_hooks_()
def get_mixin(self, name):
return self.mixins[name]
def forward(self, *args, **kwargs):
# update hooks to those of the current model (overridden forwards)
# Attention! the transformer might be shared by multiple models
self.transformer.hooks.clear()
self.transformer.hooks.update(self.hooks)
return self.transformer(*args, **kwargs)
def collect_hooks_(self):
names = list(HOOKS_DEFAULT.keys())
hooks = {}
hook_origins = {}
for name in names:
if hasattr(self, name):
hooks[name] = getattr(self, name)
hook_origins[name] = 'model'
for mixin_name, m in self.mixins.items():
if hasattr(m, name):
if hasattr(getattr(m, name), 'non_conflict'):
# check getattr(m, name), which must accept old_impl as an argument
signature = inspect.signature(getattr(m, name))
if 'old_impl' not in signature.parameters:
raise ValueError(
f'Hook {name} at {mixin_name} must accept old_impl as an argument.'
)
# -------------
if name in hooks:
old_impl = hooks[name]
elif name == 'attention_fn': # the only hook without self
old_impl = HOOKS_DEFAULT[name]
else:
old_impl = partial(
HOOKS_DEFAULT[name], self
) # relax! `partial` does not affect the signature
old_origin = hook_origins.get(name, 'default')
hooks[name] = partial(getattr(m, name),
old_impl=old_impl)
hook_origins[name] = mixin_name + ' -> ' + old_origin
elif name in hooks and not hasattr(
hooks[name], 'replacable'
): # if this hook name is already registered
raise ValueError(
f'Hook {name} conflicts at {mixin_name} and {hook_origins[name]}.'
)
else: # new hook
if name in hooks and hasattr(hooks[name],
'replacable'):
warnings.warn(
f'Hook {name} at {mixin_name} replaces {hook_origins[name]}.'
)
hooks[name] = getattr(m, name)
hook_origins[name] = mixin_name
self.hooks = hooks
self.hook_origins = hook_origins
return hooks
def disable_untrainable_params(self):
pass
@classmethod
def add_model_specific_args(cls, parser):
# recorded in arguments.py: add_model_config_args
return parser
@classmethod
def from_pretrained_base(cls,
name,
args=None,
*,
home_path=None,
url=None,
prefix='',
build_only=False,
overwrite_args={},
**kwargs):
'''Load a pretrained checkpoint of the current model.
Args:
name: The identifier of the pretrained model.
args: Namespace. The loaded args will be added to it. If None, a new model-only Namespace with defaults is created.
home_path: the parent folder of the existing `name` model. Default: SAT_HOME.
url: the url of the model. Default: SAT_URL.
prefix: the prefix of the checkpoint. Default: ''.
Returns:
model: the loaded model.
args: the loaded args.
'''
if os.path.exists(name) and os.path.isdir(name):
model_path = name
else:
model_path = auto_create(name, path=home_path, url=url)
# create a new args if not provided
if args is None:
args = cls.get_args()
args = update_args_with_file(args,
path=os.path.join(model_path,
'model_config.json'))
args = overwrite_args_by_dict(args, overwrite_args=overwrite_args)
specific_iteration = kwargs.pop('specific_iteration', None)
model = get_model(args, cls, **kwargs)
if not build_only:
load_checkpoint(model,
args,
load_path=model_path,
prefix=prefix,
specific_iteration=specific_iteration)
return model, args
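# Usage sketch (illustrative; the checkpoint path and overrides below are hypothetical):
#   model, args = MyModel.from_pretrained_base('/path/to/ckpt_dir',
#                                              overwrite_args={'fp16': True})
# If `name` is not an existing directory, it is resolved via auto_create() using
# SAT_HOME / the given url, as described in the docstring above.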
@classmethod
def from_pretrained(cls,
name,
args=None,
*,
home_path=None,
url=None,
prefix='',
build_only=False,
use_node_group=True,
overwrite_args={},
**kwargs):
if build_only or 'model_parallel_size' not in overwrite_args:
return cls.from_pretrained_base(name,
args=args,
home_path=home_path,
url=url,
prefix=prefix,
build_only=build_only,
overwrite_args=overwrite_args,
**kwargs)
else:
new_model_parallel_size = overwrite_args['model_parallel_size']
if new_model_parallel_size != 1 or new_model_parallel_size == 1 and args.model_parallel_size == 1:
model, model_args = cls.from_pretrained_base(
name,
args=args,
home_path=home_path,
url=url,
prefix=prefix,
build_only=True,
overwrite_args=overwrite_args,
**kwargs)
local_rank = get_node_rank(
) if use_node_group else get_model_parallel_rank()
world_size = torch.distributed.get_world_size()
assert world_size % new_model_parallel_size == 0, 'world size should be a multiple of the new model_parallel_size.'
destroy_model_parallel()
initialize_model_parallel(1)
if local_rank == 0:
args.skip_init = True
args.use_gpu_initialization = False
args.device = 'cpu'
overwrite_args.pop('model_parallel_size')
model_full, args_ = cls.from_pretrained_base(
name,
args=args,
home_path=home_path,
url=url,
prefix=prefix,
build_only=False,
overwrite_args=overwrite_args,
**kwargs)
if args_.model_parallel_size != 1:
raise Exception(
"We do not support overwriting model_parallel_size when original model_parallel_size != 1. Try merging the model using `from_pretrained(xxx,overwrite_args={'model_parallel_size':1})` first if you still want to change model_parallel_size!"
)
if hasattr(
args, 'mode'
) and args.mode == 'inference':  # For multi-node inference, prevent rank 0 from eagerly printing info.
torch.distributed.barrier()
destroy_model_parallel()
initialize_model_parallel(new_model_parallel_size)
if local_rank == 0:
mp_split_model_rank0(model,
model_full,
use_node_group=use_node_group)
del model_full
else:
mp_split_model_receive(model,
use_node_group=use_node_group)
reset_random_seed(6)
else:
overwrite_args.pop('model_parallel_size')
model, model_args = cls.from_pretrained_base(
name,
args=args,
home_path=home_path,
url=url,
prefix=prefix,
build_only=False,
overwrite_args=overwrite_args,
**kwargs)
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
assert world_size == model_args.model_parallel_size, 'world size should be equal to model_parallel_size.'
destroy_model_parallel()
initialize_model_parallel(1)
if rank == 0:
args.use_gpu_initialization = False
args.device = 'cpu'
overwrite_args['model_parallel_size'] = 1
model_full, args_ = cls.from_pretrained_base(
name,
args=args,
home_path=home_path,
url=url,
prefix=prefix,
build_only=True,
overwrite_args=overwrite_args,
**kwargs)
torch.distributed.barrier()
destroy_model_parallel()
initialize_model_parallel(model_args.model_parallel_size)
if rank == 0:
mp_merge_model_rank0(model, model_full)
model, model_args = model_full, args_
else:
mp_merge_model_send(model)
model_args.model_parallel_size = 1
destroy_model_parallel()
initialize_model_parallel(1)
return model, model_args
@classmethod
def list_avail_args(cls, print=True):
'''List all available args of the current model.'''
parser = argparse.ArgumentParser()
from sat.arguments import add_model_config_args
add_model_config_args(parser)
# add args of the current model
if hasattr(cls, 'add_model_specific_args'):
cls.add_model_specific_args(parser)
if print:
from sat.helpers import print_parser
print_parser(parser)
return parser
@classmethod
def get_args(cls, **kwargs):
'''Get the parsed args of the current model.
Args:
**kwargs: will override the default args.
Returns:
args: the parsed args.
'''
parser = cls.list_avail_args(print=False)
# use parser to parse kwargs
args = parser.parse_args([])
for k, v in kwargs.items():
if hasattr(args, k) or k in [
'fp16'
]: # non-arch args but affect building models
setattr(args, k, v)
else:
print_rank0(
f'warning: Unknown arg {k} for class {cls.__name__}.',
level='DEBUG')
setattr(args, k, v)
return args
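# Illustrative sketch (not part of the original file): a mixin hook that chains onto the
# previous implementation. collect_hooks_() above only requires that the hook function
# carries a `non_conflict` attribute and accepts `old_impl`; SAT's non_conflict decorator
# simply marks the function in this way. The scaling factor here is arbitrary.
class _ExampleScaledFinalForwardMixin(BaseMixin):

    def final_forward(self, logits, *args, old_impl=None, **kw_args):
        # call the previously registered final_forward, then rescale its output
        return 0.5 * old_impl(logits, *args, **kw_args)

    final_forward.non_conflict = True  # mark as chainable for collect_hooks_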
# rewritten, Copyright (c) 2021, Ming Ding. All rights reserved.
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformer."""
import copy
import math
import torch
import torch.nn.functional as F
from deepspeed.runtime.activation_checkpointing.checkpointing import \
non_reentrant_checkpoint as checkpoint
from sat import mpu
from sat.model.transformer import BaseTransformerLayer
from sat.mpu import (ColumnParallelLinear, RowParallelLinear,
VocabParallelEmbedding, copy_to_model_parallel_region,
gather_from_model_parallel_region,
get_model_parallel_world_size)
from sat.mpu.utils import (divide, gelu, scaled_init_method, sqrt,
unscaled_init_method)
from sat.ops.layernorm import LayerNorm
from sat.transformer_defaults import (HOOKS_DEFAULT,
split_tensor_along_last_dim,
standard_attention)
# checkpoint
class GCBaseTransformer(torch.nn.Module):
def __init__(self,
num_layers,
vocab_size,
hidden_size,
num_attention_heads,
max_sequence_length,
embedding_dropout_prob=0,
attention_dropout_prob=0,
output_dropout_prob=0,
drop_path=0,
checkpoint_activations=False,
checkpoint_num_layers=1,
checkpoint_skip_layers=0,
layernorm_epsilon=1.0e-5,
init_method_std=0.02,
inner_hidden_size=None,
hidden_size_per_attention_head=None,
cross_hidden_size_per_attention_head=None,
layernorm_order='pre',
parallel_output=False,
is_decoder=False,
cross_attn_hidden_size=None,
use_bias=True,
use_qkv_bias=False,
num_multi_query_heads=0,
cross_num_multi_query_heads=0,
row_parallel_linear_final_bias=True,
activation_func=gelu,
is_gated_mlp=False,
is_rotary_emb=False,
num_experts=1,
layernorm=LayerNorm,
init_method=None,
use_final_layernorm=True,
hooks={},
params_dtype=torch.float,
skip_init=False,
device=torch.device('cpu')):
super().__init__()
# recording parameters
self.hidden_size = hidden_size
self.inner_hidden_size = inner_hidden_size
self.hidden_size_per_attention_head = hidden_size_per_attention_head
self.cross_hidden_size_per_attention_head = cross_hidden_size_per_attention_head
self.is_decoder = is_decoder
self.cross_attn_hidden_size = cross_attn_hidden_size
self.cross_num_multi_query_heads = cross_num_multi_query_heads
if not is_decoder and cross_attn_hidden_size is not None:
print(
'warning: cross_attn_hidden_size is set but is_decoder is False'
)
self.use_bias = use_bias
self.use_qkv_bias = use_qkv_bias
self.num_multi_query_heads = num_multi_query_heads
self.is_gated_mlp = is_gated_mlp
self.is_rotary_emb = is_rotary_emb
self.num_experts = num_experts
self.use_final_layernorm = use_final_layernorm
self.layernorm_epsilon = layernorm_epsilon
self.parallel_output = parallel_output
self.checkpoint_activations = checkpoint_activations
self.checkpoint_num_layers = checkpoint_num_layers
self.checkpoint_skip_layers = checkpoint_skip_layers
assert checkpoint_skip_layers <= num_layers - checkpoint_num_layers, 'checkpoint_skip_layers too large. Please consider removing checkpoint_activations.'
self.max_sequence_length = max_sequence_length
self.layernorm_order = layernorm_order
self.row_parallel_linear_final_bias = row_parallel_linear_final_bias
self.hooks = copy.copy(hooks) # hooks will be updated each forward
object.__setattr__(
self, 'transformer',
self) # to give the default hooks the same api as outer hooks
# create embedding parameters
self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
if vocab_size < 1000:
self.word_embeddings = torch.nn.Embedding(vocab_size,
hidden_size,
dtype=params_dtype,
device=device)
torch.nn.init.normal_(self.word_embeddings.weight,
mean=0.0,
std=init_method_std)
else:
self.word_embeddings = VocabParallelEmbedding(
num_embeddings=vocab_size,
embedding_dim=hidden_size,
params_dtype=params_dtype,
skip_init=skip_init,
device=device)
if self.is_rotary_emb:
from sat.model.position_embedding.triton_rotary_embeddings import \
FastRotaryEmbedding
self.position_embeddings = FastRotaryEmbedding(hidden_size //
num_attention_heads)
else:
self.position_embeddings = torch.nn.Embedding(
max_sequence_length, hidden_size)
torch.nn.init.normal_(self.position_embeddings.weight,
mean=0.0,
std=init_method_std)
# create all layers
if init_method is None:
self.output_layer_init_method = scaled_init_method(
init_method_std, num_layers)
self.init_method = unscaled_init_method(init_method_std)
else:
self.output_layer_init_method = init_method
self.init_method = init_method
def get_layer(layer_id):
return BaseTransformerLayer(
hidden_size,
num_attention_heads,
attention_dropout_prob,
output_dropout_prob,
layernorm_epsilon,
self.init_method,
layer_id,
inner_hidden_size=inner_hidden_size,
hidden_size_per_attention_head=hidden_size_per_attention_head,
cross_hidden_size_per_attention_head=
cross_hidden_size_per_attention_head,
output_layer_init_method=self.output_layer_init_method,
is_decoder=self.is_decoder,
cross_attn_hidden_size=cross_attn_hidden_size,
layernorm_order=layernorm_order,
layernorm=layernorm,
use_bias=use_bias,
use_qkv_bias=use_qkv_bias,
num_multi_query_heads=num_multi_query_heads,
cross_num_multi_query_heads=cross_num_multi_query_heads,
row_parallel_linear_final_bias=row_parallel_linear_final_bias,
drop_path=drop_path,
activation_func=activation_func,
is_gated_mlp=is_gated_mlp,
num_experts=num_experts,
hooks=self.hooks,
transformer_pointer=self,
params_dtype=params_dtype,
skip_init=skip_init,
device=device)
self.layers = torch.nn.ModuleList(
[get_layer(layer_id) for layer_id in range(num_layers)])
# Final layer norm before output.
if use_final_layernorm:
self.final_layernorm = layernorm(hidden_size,
eps=layernorm_epsilon)
def forward(self,
input_ids,
position_ids,
attention_mask,
*,
output_hidden_states=False,
**kw_args):
# sanity check
assert len(input_ids.shape) >= 2
batch_size, query_length = input_ids.shape[:2]
if attention_mask is None:
# Definition: None means full attention
attention_mask = torch.ones(1, 1, device=input_ids.device)
elif isinstance(attention_mask, int) and (attention_mask < 0):
# Definition: -1 means lower triangular attention mask
attention_mask = torch.ones(query_length,
query_length,
device=input_ids.device).tril()
attention_mask = attention_mask.type_as(next(self.parameters()))
assert len(attention_mask.shape) == 2 or \
len(attention_mask.shape) == 4 and attention_mask.shape[1] == 1
# initial output_cross_layer might be generated by word/position_embedding_forward
output_cross_layer = {}
# embedding part
if 'word_embedding_forward' in self.hooks:
hidden_states = self.hooks['word_embedding_forward'](
input_ids, output_cross_layer=output_cross_layer, **kw_args)
else: # default
hidden_states = HOOKS_DEFAULT['word_embedding_forward'](
self,
input_ids,
output_cross_layer=output_cross_layer,
**kw_args)
# handle position embedding
if 'position_embedding_forward' in self.hooks:
position_embeddings = self.hooks['position_embedding_forward'](
position_ids, output_cross_layer=output_cross_layer, **kw_args)
else:
assert len(position_ids.shape) <= 2
assert position_ids.shape[-1] == hidden_states.shape[1], (
position_ids.shape, hidden_states.shape)
position_embeddings = HOOKS_DEFAULT['position_embedding_forward'](
self,
position_ids,
output_cross_layer=output_cross_layer,
**kw_args)
if position_embeddings is not None:
hidden_states = hidden_states + position_embeddings
hidden_states = self.embedding_dropout(hidden_states)
output_per_layers = []
if self.checkpoint_activations:
# define custom_forward for checkpointing
def custom(start, end, kw_args_index, cross_layer_index):
def custom_forward(*inputs):
layers_ = self.layers[start:end]
x_, mask = inputs[0], inputs[1]
# recover kw_args and output_cross_layer
flat_inputs = inputs[2:]
kw_args, output_cross_layer = {}, {}
for k, idx in kw_args_index.items():
kw_args[k] = flat_inputs[idx]
for k, idx in cross_layer_index.items():
output_cross_layer[k] = flat_inputs[idx]
# -----------------
output_per_layers_part = []
for i, layer in enumerate(layers_):
output_this_layer_obj, output_cross_layer_obj = {}, {}
if 'layer_forward' in self.hooks:
layer_ret = self.hooks['layer_forward'](
x_,
mask,
layer_id=layer.layer_id,
**kw_args,
position_ids=position_ids,
**output_cross_layer,
output_this_layer=output_this_layer_obj,
output_cross_layer=output_cross_layer_obj)
else:
layer_ret = layer(
x_,
mask,
layer_id=layer.layer_id,
**kw_args,
position_ids=position_ids,
**output_cross_layer,
output_this_layer=output_this_layer_obj,
output_cross_layer=output_cross_layer_obj)
if isinstance(layer_ret, tuple):
layer_ret = layer_ret[0] # for legacy API
x_, output_this_layer, output_cross_layer = layer_ret, output_this_layer_obj, output_cross_layer_obj
if output_hidden_states:
output_this_layer['hidden_states'] = x_
output_per_layers_part.append(output_this_layer)
# flatten keyword outputs for later re-aggregation
flat_outputs = []
for output_this_layer in output_per_layers_part:
for k in output_this_layer:
# TODO add warning for depth>=2 grad tensors
flat_outputs.append(output_this_layer[k])
output_this_layer[k] = len(flat_outputs) - 1
for k in output_cross_layer:
flat_outputs.append(output_cross_layer[k])
output_cross_layer[k] = len(flat_outputs) - 1
# --------------------
return (x_, output_per_layers_part, output_cross_layer,
*flat_outputs)
return custom_forward
# prevent losing requires_grad in checkpointing.
# To save memory when only finetuning the final layers, don't use checkpointing.
if self.training:
hidden_states.requires_grad_(True)
l, num_layers = 0, len(self.layers)
chunk_length = self.checkpoint_num_layers
output_this_layer = []
while l < num_layers:
args = [hidden_states, attention_mask]
# flatten kw_args and output_cross_layer
flat_inputs, kw_args_index, cross_layer_index = [], {}, {}
for k, v in kw_args.items():
flat_inputs.append(v)
kw_args_index[k] = len(flat_inputs) - 1
for k, v in output_cross_layer.items():
flat_inputs.append(v)
cross_layer_index[k] = len(flat_inputs) - 1
# --------------------
if l + self.checkpoint_skip_layers >= num_layers:
# no checkpointing
hidden_states, output_per_layers_part, output_cross_layer, *flat_outputs = \
custom(l, l + chunk_length, kw_args_index, cross_layer_index)(*args, *flat_inputs)
else:
hidden_states, output_per_layers_part, output_cross_layer, *flat_outputs = \
checkpoint(custom(l, l + chunk_length, kw_args_index, cross_layer_index), *args, *flat_inputs)
# recover output_per_layers_part, output_cross_layer
for output_this_layer in output_per_layers_part:
for k in output_this_layer:
output_this_layer[k] = flat_outputs[
output_this_layer[k]]
for k in output_cross_layer:
output_cross_layer[k] = flat_outputs[output_cross_layer[k]]
# --------------------
output_per_layers.extend(output_per_layers_part)
l += chunk_length
else:
output_this_layer = []
for i, layer in enumerate(self.layers):
args = [hidden_states, attention_mask]
output_this_layer_obj, output_cross_layer_obj = {}, {}
if 'layer_forward' in self.hooks: # customized layer_forward
layer_ret = self.hooks['layer_forward'](
*args,
layer_id=torch.tensor(i),
**kw_args,
position_ids=position_ids,
**output_cross_layer,
output_this_layer=output_this_layer_obj,
output_cross_layer=output_cross_layer_obj)
else:
layer_ret = layer(
*args,
layer_id=torch.tensor(i),
**kw_args,
position_ids=position_ids,
**output_cross_layer,
output_this_layer=output_this_layer_obj,
output_cross_layer=output_cross_layer_obj)
if isinstance(layer_ret, tuple):
layer_ret = layer_ret[0] # for legacy API
hidden_states, output_this_layer, output_cross_layer = layer_ret, output_this_layer_obj, output_cross_layer_obj
if output_hidden_states:
output_this_layer['hidden_states'] = hidden_states
output_per_layers.append(output_this_layer)
# Final layer norm.
if self.use_final_layernorm:
logits = self.final_layernorm(hidden_states)
else:
logits = hidden_states
logits = copy_to_model_parallel_region(logits)
if 'final_forward' in self.hooks:
logits_parallel = self.hooks['final_forward'](
logits, **kw_args, parallel_output=self.parallel_output)
else:
logits_parallel = HOOKS_DEFAULT['final_forward'](
self, logits, **kw_args, parallel_output=self.parallel_output)
outputs = [logits_parallel]
outputs.extend(output_per_layers)
return outputs
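# Illustrative sketch (not part of the original file): a tiny forward pass through
# GCBaseTransformer with toy shapes. attention_mask semantics follow the checks in
# forward(): None means full attention, a negative int means a causal (lower-triangular)
# mask. _simple_init is assumed to set up the single-process model-parallel state that
# the parallel linear layers expect, as BaseModel.__init__ does.
def _example_transformer_forward():
    from sat.arguments import _simple_init
    _simple_init(model_parallel_size=1)
    model = GCBaseTransformer(num_layers=2,
                              vocab_size=128,
                              hidden_size=64,
                              num_attention_heads=4,
                              max_sequence_length=32)
    input_ids = torch.randint(0, 128, (1, 8))
    position_ids = torch.arange(8)[None]
    logits, *per_layer_outputs = model(input_ids, position_ids, attention_mask=None)
    return logits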
share_cache_args:
disable_ref : True
num_vis_img: 4
vis_ddpm: True
eval_interval_list: [1, 50, 100, 1000]
save_interval_list: [2000]
args:
checkpoint_activations: False # whether to use gradient checkpointing
model_parallel_size: 1
experiment_name: lora-disney
mode: finetune
load: ""
no_load_rng: True
train_iters: 10000000000 # Suggest more than 1000 for LoRA; for SFT, 500 is enough
eval_iters: 100000000
eval_interval: 1000000000000
eval_batch_size: 1
save: ./
save_interval: 1000
log_interval: 20
train_data: [ "disney" ] # Train data path
valid_data: [ "disney" ] # Validation data path, can be the same as train_data(not recommended)
split: 1,0,0
num_workers: 2
force_train: True
only_log_video_latents: True
deepspeed:
# Minimum of 16 videos per batch across ALL GPUs; this setting is for 8 x A100 GPUs
train_micro_batch_size_per_gpu: 1
gradient_accumulation_steps: 1
steps_per_print: 50
gradient_clipping: 0.1
zero_optimization:
stage: 2
cpu_offload: true
contiguous_gradients: false
overlap_comm: true
reduce_scatter: true
reduce_bucket_size: 1000000000
allgather_bucket_size: 1000000000
load_from_fp32_weights: false
zero_allow_untested_optimizer: true
bf16:
enabled: True # Set to False for CogVideoX-2B and True for CogVideoX-5B
fp16:
enabled: False # Set to True for CogVideoX-2B and False for CogVideoX-5B
loss_scale: 0
loss_scale_window: 400
hysteresis: 2
min_loss_scale: 1
model:
scale_factor: 0.7
disable_first_stage_autocast: true
log_keys:
- txt
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
params:
num_idx: 1000
quantize_c_noise: False
weighting_config:
target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.VideoScaling
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
params:
shift_scale: 1.0 # different from cogvideox_2b_infer.yaml
network_config:
target: extra_models.dit_res_adapter.SMALLDiffusionTransformer
params:
time_embed_dim: 512
elementwise_affine: True
num_frames: 64
time_compressed_rate: 4
# latent_width: 90
# latent_height: 60
latent_width: 512
latent_height: 512
num_layers: 42 # different from cogvideox_2b_infer.yaml
patch_size: 2
in_channels: 16
out_channels: 16
hidden_size: 3072 # different from cogvideox_2b_infer.yaml
adm_in_channels: 256
num_attention_heads: 48 # different from cogvideox_2b_infer.yaml
transformer_args:
checkpoint_activations: False
vocab_size: 1
max_sequence_length: 64
layernorm_order: pre
skip_init: false
model_parallel_size: 1
is_decoder: false
modules:
pos_embed_config:
target: extra_models.dit_res_adapter.ScaleCropRotary3DPositionEmbeddingMixin # different from cogvideox_2b_infer.yaml
params:
hidden_size_head: 64
text_length: 226
lora_config:
target: extra_models.dit_res_adapter.ResLoraMixin
params:
r: 128
patch_embed_config:
target: dit_video_concat.ImagePatchEmbeddingMixin
params:
text_hidden_size: 4096
adaln_layer_config:
target: dit_video_concat.AdaLNMixin
params:
qk_ln: True
final_layer_config:
target: dit_video_concat.FinalLayerMixin
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
- is_trainable: false
input_key: txt
ucg_rate: 0.1
target: sgm.modules.encoders.modules.FrozenT5Embedder
params:
model_dir: "google/t5-v1_1-xxl"
max_length: 226
first_stage_config:
target: vae_modules.autoencoder.VideoAutoencoderInferenceWrapper
params:
cp_size: 1
ckpt_path: "checkpoints/3d-vae.pt"
ignore_keys: [ 'loss' ]
loss_config:
target: torch.nn.Identity
regularizer_config:
target: vae_modules.regularizers.DiagonalGaussianRegularizer
encoder_config:
target: vae_modules.cp_enc_dec.SlidingContextParallelEncoder3D
params:
double_z: true
z_channels: 16
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [ 1, 2, 2, 4 ]
attn_resolutions: [ ]
num_res_blocks: 3
dropout: 0.0
gather_norm: True
decoder_config:
target: vae_modules.cp_enc_dec.ContextParallelDecoder3D
params:
double_z: True
z_channels: 16
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [ 1, 2, 2, 4 ]
attn_resolutions: [ ]
num_res_blocks: 3
dropout: 0.0
gather_norm: False
loss_fn_config:
target: sgm.modules.diffusionmodules.loss.VideoDiffusionLoss
params:
offset_noise_level: 0
sigma_sampler_config:
target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
params:
uniform_sampling: True
num_idx: 1000
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
params:
shift_scale: 1.0 # different from cogvideox_2b_infer.yaml
sampler_config:
target: sgm.modules.diffusionmodules.sampling.VPSDEDPMPP2MSampler
params:
num_steps: 51
verbose: True
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
params:
shift_scale: 1.0 # different from cogvideox_2b_infer.yaml
guider_config:
target: sgm.modules.diffusionmodules.guiders.DynamicCFG
params:
# TODO check this cfg
scale: 8
exp: 5
num_steps: 51
custom_args:
reload: ""
share_cache_args:
sample_ref_noise_step: 675
time_size_embedding: True
args:
checkpoint_activations: True # whether to use gradient checkpointing
model_parallel_size: 1
experiment_name: lora-disney
mode: finetune
load: "" # This is for Full model without lora adapter " # This is for Full model without lora adapter
no_load_rng: True
train_iters: 100000 # Suggest more than 1000 for LoRA; for SFT, 500 is enough
eval_iters: 100000000
eval_interval: [1, 200]
eval_batch_size: 1
save:
# for debug
save_interval: 250
log_interval: 5
train_data: [ "disney" ] # Train data path
valid_data: [ "disney" ] # Validation data path, can be the same as train_data(not recommended)
split: 1,0,0
num_workers: 1
force_train: True
only_log_video_latents: True
deepspeed:
# Minimum of 16 videos per batch across ALL GPUs; this setting is for 8 x A100 GPUs
train_micro_batch_size_per_gpu: 1
gradient_accumulation_steps: 1
steps_per_print: 50
gradient_clipping: 0.1
zero_optimization:
stage: 2
cpu_offload: true
contiguous_gradients: false
overlap_comm: true
reduce_scatter: true
reduce_bucket_size: 1000000000
allgather_bucket_size: 1000000000
load_from_fp32_weights: false
zero_allow_untested_optimizer: true
bf16:
enabled: True # Set to False for CogVideoX-2B and True for CogVideoX-5B
fp16:
enabled: False # Set to True for CogVideoX-2B and False for CogVideoX-5B
loss_scale: 0
loss_scale_window: 400
hysteresis: 2
min_loss_scale: 1
model:
scale_factor: 1.15258426
disable_first_stage_autocast: true
log_keys:
- txt
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
params:
num_idx: 1000
quantize_c_noise: False
weighting_config:
target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.VideoScaling
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
params:
shift_scale: 3.0 # different from cogvideox_2b_infer.yaml
network_config:
target: extra_models.dit_res_adapter.SMALLDiffusionTransformer
params:
time_embed_dim: 512
elementwise_affine: True
num_frames: 64
time_compressed_rate: 4
# latent_width: 90
# latent_height: 60
latent_width: 512
latent_height: 512
num_layers: 30 # different from cogvideox_2b_infer.yaml
patch_size: 2
in_channels: 16
out_channels: 16
hidden_size: 1920 # different from cogvideox_2b_infer.yaml
adm_in_channels: 256
num_attention_heads: 30 # different from cogvideox_2b_infer.yaml
transformer_args:
checkpoint_activations: True
vocab_size: 1
max_sequence_length: 64
layernorm_order: pre
skip_init: false
model_parallel_size: 1
is_decoder: false
modules:
pos_embed_config:
target: extra_models.dit_res_adapter.ScaleCropRotary3DPositionEmbeddingMixin # different from cogvideox_2b_infer.yaml
params:
hidden_size_head: 64
text_length: 226
patch_embed_config:
target: dit_video_concat.ImagePatchEmbeddingMixin
params:
text_hidden_size: 4096
adaln_layer_config:
target: dit_video_concat.AdaLNMixin
params:
qk_ln: True
final_layer_config:
target: dit_video_concat.FinalLayerMixin
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
- is_trainable: false
input_key: txt
ucg_rate: 0.1
target: sgm.modules.encoders.modules.FrozenT5Embedder
params:
model_dir: "google/t5-v1_1-xxl"
max_length: 226
first_stage_config:
target: vae_modules.autoencoder.VideoAutoencoderInferenceWrapper
params:
cp_size: 1
ckpt_path: "checkpoints/3d-vae.pt"
ignore_keys: [ 'loss' ]
loss_config:
target: torch.nn.Identity
regularizer_config:
target: vae_modules.regularizers.DiagonalGaussianRegularizer
encoder_config:
target: vae_modules.cp_enc_dec.SlidingContextParallelEncoder3D
params:
double_z: true
z_channels: 16
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [ 1, 2, 2, 4 ]
attn_resolutions: [ ]
num_res_blocks: 3
dropout: 0.0
gather_norm: True
decoder_config:
target: vae_modules.cp_enc_dec.ContextParallelDecoder3D
params:
double_z: True
z_channels: 16
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [ 1, 2, 2, 4 ]
attn_resolutions: [ ]
num_res_blocks: 3
dropout: 0.0
gather_norm: False
loss_fn_config:
target: flow_video.FlowVideoDiffusionLoss
params:
offset_noise_level: 0
sigma_sampler_config:
target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
params:
uniform_sampling: False
num_idx: 1000
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
params:
shift_scale: 1.0 # different from cogvideox_2b_infer.yaml
sampler_config:
target: sgm.modules.diffusionmodules.sampling.CascadeVPSDEDPMPP2MSampler
params:
num_steps: 50
verbose: True
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
params:
shift_scale: 1.0 # different from cogvideox_2b_infer.yaml
guider_config:
target: sgm.modules.diffusionmodules.guiders.DynamicCFG
params:
scale: 6
exp: 5
num_steps: 50
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Stage-I ckpt path \n",
"stage_1_path = \"./checkpoints/stage1.pt\"\n",
"stage_2_path = \"./checkpoints/stage2.pt\"\n",
"save_dir=\"vis_270p_1080p\"\n",
"# 2 ~ 3\n",
"shift_t = 2.5\n",
"# 4 ~ 6\n",
"sample_step = 5\n",
"# 10 ~ 13\n",
"cfg_second = 13\n",
" # 650 ~ 750\n",
"deg_latent_strength=675\n",
"# stage_1_hw \n",
"\n",
"#TODO Stage I CFG here\n",
"cfg_first = 8\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"current_directory = os.getcwd()\n",
"os.chdir(os.path.dirname(current_directory))\n",
"new_directory = os.getcwd()\n",
"print(f\"working directory: {new_directory}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"import os\n",
"import argparse\n",
"import torch\n",
"import numpy as np\n",
"import copy\n",
"\n",
"from sat.model.base_model import get_model\n",
"from arguments import get_args\n",
"from torchvision.io.video import write_video\n",
"\n",
"from flow_video import FlowEngine\n",
"from diffusion_video import SATVideoDiffusionEngine\n",
"\n",
"import os\n",
"from utils import disable_all_init, decode, prepare_input, save_memory_encode_first_stage, save_mem_decode, seed_everything\n",
"disable_all_init()\n",
"\n",
"\n",
"def init_model(model, second_model, args, second_args):\n",
" share_cache = dict()\n",
" second_share_cache = dict()\n",
" if hasattr(args, 'share_cache_config'):\n",
" for k, v in args.share_cache_config.items():\n",
" share_cache[k] = v\n",
" if hasattr(second_args, 'share_cache_config'):\n",
" for k, v in second_args.share_cache_config.items():\n",
" second_share_cache[k] = v\n",
"\n",
" for n, m in model.named_modules():\n",
" m.share_cache = share_cache \n",
" if hasattr(m, \"register_new_modules\"):\n",
" m.register_new_modules()\n",
" for n, m in second_model.named_modules():\n",
" m.share_cache = second_share_cache \n",
" if hasattr(m, \"register_new_modules\"):\n",
" m.register_new_modules() \n",
"\n",
" weight_path = args.inf_ckpt\n",
" weight = torch.load(weight_path, map_location=\"cpu\")\n",
" if \"model.diffusion_model.mixins.pos_embed.freqs_sin\" in weight[\"module\"]:\n",
" del weight[\"module\"][\"model.diffusion_model.mixins.pos_embed.freqs_sin\"]\n",
" del weight[\"module\"][\"model.diffusion_model.mixins.pos_embed.freqs_cos\"]\n",
" msg = model.load_state_dict(weight[\"module\"], strict=False)\n",
" print(msg)\n",
" second_weight_path = args.inf_ckpt2\n",
" second_weight = torch.load(second_weight_path, map_location=\"cpu\")\n",
" if \"model.diffusion_model.mixins.pos_embed.freqs_sin\" in second_weight[\"module\"]:\n",
" del second_weight[\"module\"][\"model.diffusion_model.mixins.pos_embed.freqs_sin\"]\n",
" del second_weight[\"module\"][\"model.diffusion_model.mixins.pos_embed.freqs_cos\"]\n",
" second_msg = second_model.load_state_dict(second_weight[\"module\"], strict=False)\n",
" print(second_msg)\n",
"\n",
"def get_first_results(model, text, num_frames, H, W, neg_prompt=None):\n",
" \"\"\"Get first Stage results.\n",
"\n",
" Args:\n",
" model (nn.Module): first stage model.\n",
" text (str): text prompt\n",
" num_frames (int): number of frames\n",
" H (int): height of the first stage results\n",
" W (int): width of the first stage results\n",
" neg_prompt (str): negative prompt\n",
"\n",
" Returns:\n",
" Tensor: first stage video.\n",
" \"\"\"\n",
" device = 'cuda'\n",
" T = 1 + (num_frames - 1) // 4\n",
" F = 8\n",
" motion_text_prefix = [\n",
" 'very low motion,',\n",
" 'low motion,',\n",
" 'medium motion,',\n",
" 'high motion,',\n",
" 'very high motion,',\n",
" ]\n",
" pos_prompt = \"\"\n",
" if neg_prompt is None:\n",
" neg_prompt = \"\"\n",
" with torch.no_grad():\n",
" model.to('cuda')\n",
" input_negative_prompt = motion_text_prefix[\n",
" 0] + ', ' + motion_text_prefix[1] + neg_prompt\n",
" c, uc = prepare_input(text,\n",
" model,\n",
" T,\n",
" negative_prompt=input_negative_prompt,\n",
" pos_prompt=pos_prompt)\n",
" with torch.no_grad(), torch.amp.autocast(enabled=True,\n",
" device_type='cuda',\n",
" dtype=torch.bfloat16):\n",
" samples_z = model.sample(\n",
" c,\n",
" uc=uc,\n",
" batch_size=1,\n",
" shape=(T, 16, H // F, W // F),\n",
" num_steps=model.share_cache.get('first_sample_step', None),\n",
" )\n",
" samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous()\n",
"\n",
" model.to('cpu')\n",
" torch.cuda.empty_cache()\n",
" first_stage_model = model.first_stage_model\n",
" first_stage_model = first_stage_model.to(device)\n",
"\n",
" latent = 1.0 / model.scale_factor * samples_z\n",
"\n",
" samples = decode(first_stage_model, latent)\n",
" model.to('cpu')\n",
" return samples\n",
"def get_second_results(model, text, first_stage_samples, num_frames):\n",
" \"\"\"Get second Stage results.\n",
"\n",
" Args:\n",
" model (nn.Module): second stage model.\n",
" text (str): text prompt\n",
" first_stage_samples (Tensor): first stage results\n",
" num_frames (int): number of frames\n",
" Returns:\n",
" Tensor: second stage results.\n",
" \"\"\"\n",
"\n",
" t, h, w, c = first_stage_samples.shape\n",
" first_stage_samples = first_stage_samples[:num_frames]\n",
" first_stage_samples = (first_stage_samples / 255.)\n",
" first_stage_samples = (first_stage_samples - 0.5) / 0.5\n",
"\n",
" target_size = model.share_cache.get('target_size', None)\n",
" if target_size is None:\n",
" upscale_factor = model.share_cache.get('upscale_factor', 8)\n",
" H = int(h * upscale_factor) // 16 * 16\n",
" W = int(w * upscale_factor) // 16 * 16\n",
" else:\n",
" H, W = target_size\n",
" H = H // 16 * 16\n",
" W = W // 16 * 16\n",
"\n",
" first_stage_samples = first_stage_samples.permute(0, 3, 1, 2).to('cuda')\n",
"\n",
" ref_x = torch.nn.functional.interpolate(first_stage_samples,\n",
" size=(H, W),\n",
" mode='bilinear',\n",
" align_corners=False,\n",
" antialias=True)\n",
" ref_x = ref_x[:num_frames][None]\n",
"\n",
" ref_x = ref_x.permute(0, 2, 1, 3, 4).contiguous()\n",
"\n",
" first_stage_model = model.first_stage_model\n",
" print(f'start encoding first stage results to high resolution')\n",
" with torch.no_grad():\n",
" first_stage_dtype = next(model.first_stage_model.parameters()).dtype\n",
" model.first_stage_model.cuda()\n",
" ref_x = save_memory_encode_first_stage(\n",
" ref_x.contiguous().to(first_stage_dtype).cuda(), model)\n",
"\n",
" ref_x = ref_x.permute(0, 2, 1, 3, 4).contiguous()\n",
" ref_x = ref_x.to(model.dtype)\n",
" print(f'finish encoding first stage results, and starting stage II')\n",
"\n",
" device = 'cuda'\n",
"\n",
" model.to(device)\n",
"\n",
" pos_prompt = ''\n",
" input_negative_prompt = \"\"\n",
"\n",
" c, uc = prepare_input(text,\n",
" model,\n",
" num_frames,\n",
" negative_prompt=input_negative_prompt,\n",
" pos_prompt=pos_prompt)\n",
"\n",
" T = 1 + (num_frames - 1) // 4\n",
" F = 8\n",
" with torch.no_grad(), torch.amp.autocast(enabled=True,\n",
" device_type='cuda',\n",
" dtype=torch.bfloat16):\n",
" samples_z = model.sample(\n",
" ref_x,\n",
" c,\n",
" uc=uc,\n",
" batch_size=1,\n",
" shape=(T, 16, H // F, W // F),\n",
" num_steps=model.share_cache.get('sample_step', 5),\n",
" method='euler',\n",
" cfg=model.share_cache.get('cfg', 7.5),\n",
" )\n",
" samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous()\n",
"\n",
" model.to('cpu')\n",
" torch.cuda.empty_cache()\n",
" first_stage_model = model.first_stage_model\n",
" first_stage_model = first_stage_model.to(device)\n",
"\n",
" latent = 1.0 / model.scale_factor * samples_z\n",
" print(f'start spatiotemporal slice decoding')\n",
" samples = save_mem_decode(first_stage_model, latent)\n",
" print(f'finish spatiotemporal slice decoding')\n",
" model.to('cpu')\n",
" return samples\n",
" "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"\n",
"\n",
"\n",
"os.environ[\"LOCAL_RANK\"] = \"0\"\n",
"os.environ[\"WORLD_SIZE\"] = \"1\"\n",
"os.environ[\"RANK\"] = \"0\"\n",
"os.environ[\"MASTER_ADDR\"] = \"0.0.0.0\"\n",
"os.environ[\"MASTER_PORT\"] = \"12345\"\n",
"\n",
"py_parser = argparse.ArgumentParser(add_help=False)\n",
"args_list = [\n",
" \"--base\", \"flashvideo/configs/stage1.yaml\",\n",
" \"--second\", \"flashvideo/configs/stage2.yaml\",\n",
" \"--inf-ckpt\", stage_1_path,\n",
" \"--inf-ckpt2\", stage_2_path,\n",
"]\n",
"known, args_list = py_parser.parse_known_args(args=args_list)\n",
"second_args_list = copy.deepcopy(args_list)\n",
"\n",
"\n",
"args = get_args(args_list)\n",
"args = argparse.Namespace(**vars(args), **vars(known))\n",
"del args.deepspeed_config\n",
"args.model_config.first_stage_config.params.cp_size = 1\n",
"args.model_config.network_config.params.transformer_args.model_parallel_size = 1\n",
"args.model_config.network_config.params.transformer_args.checkpoint_activations = False\n",
"args.model_config.loss_fn_config.params.sigma_sampler_config.params.uniform_sampling = False\n",
"\n",
"second_args_list[1] = args.second[0]\n",
"second_args = get_args(second_args_list)\n",
"second_args = argparse.Namespace(**vars(second_args), **vars(known))\n",
"del second_args.deepspeed_config\n",
"second_args.model_config.first_stage_config.params.cp_size = 1\n",
"second_args.model_config.network_config.params.transformer_args.model_parallel_size = 1\n",
"second_args.model_config.network_config.params.transformer_args.checkpoint_activations = False\n",
"second_args.model_config.loss_fn_config.params.sigma_sampler_config.params.uniform_sampling = False\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model_cls=SATVideoDiffusionEngine\n",
"second_model_cls=FlowEngine\n",
"local_rank = int(os.environ.get(\"LOCAL_RANK\", 0))\n",
"torch.cuda.set_device(local_rank)\n",
"\n",
"second_model = get_model(second_args, second_model_cls)\n",
"\n",
"model = get_model(args, model_cls)\n",
" \n",
"init_model(model, second_model, args, second_args )\n",
" \n",
"model.eval()\n",
"second_model.eval()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for n, m in model.named_modules():\n",
" if hasattr(m, \"merge_lora\"):\n",
" m.merge_lora()\n",
" print(f\"merge lora of {n}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"num_frames = 49\n",
"second_num_frames = 49 \n",
"\n",
"stage_1_hw = (270, 480) \n",
"stage_2_hw = (1080, 1920) \n",
"\n",
"# make sure all can be divided by 16\n",
"stage_1_hw = (stage_1_hw[0] // 16 * 16, stage_1_hw[1] // 16 * 16)\n",
"stage_2_hw = (stage_2_hw[0] // 16 * 16, stage_2_hw[1] // 16 * 16)\n",
"\n",
"sample_func = model.sample\n",
"T, H, W, C, F = num_frames, stage_1_hw[0], stage_1_hw[1], args.latent_channels, 8\n",
"S_T, S_H, S_W, S_C, S_F = second_num_frames, stage_2_hw[0], stage_2_hw[1], args.latent_channels, 8\n",
"\n",
"\n",
" \n",
"seed_everything(0)\n",
"\n",
"text = \" Sunny day, The camera smoothly pushes in through an ornate garden archway, delicately adorned with climbing ivy. \\\n",
" Beyond the archway, a secret, tranquil garden is revealed, brimming with a vibrant array of blooming flowers \\\n",
" in a myriad of colors. A beautiful young woman with long wavy brown hair, she is smile to the camera , \\\n",
" wearing a red hat sits holding a dog , the red hat has rich fabric texture \\\n",
" wearing black pleated skirt and yellow sweater \"\n",
"\n",
"\n",
"neg_text = \"\"\n",
"\n",
"if os.path.exists(save_dir) is False:\n",
" os.makedirs(save_dir)\n",
"enu_index = \"1\"\n",
"model.share_cache[\"cfg\"] = cfg_first\n",
"\n",
"first_stage_samples = get_first_results(model, text, num_frames, H, W, neg_text)\n",
"\n",
"print(f\"save to {save_dir}/{enu_index}_num_frame_{num_frames}.mp4\")\n",
"write_video(filename=f'./{save_dir}/{enu_index}_num_frame_{num_frames}.mp4', \n",
" fps=8, \n",
" video_array= first_stage_samples, \n",
" options = { 'crf': '14' })\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"second_num_frames = 49\n",
"second_model.share_cache[\"ref_noise_step\"] = deg_latent_strength\n",
"second_model.share_cache[\"sample_ref_noise_step\"] = deg_latent_strength\n",
"second_model.share_cache.pop(\"ref_noise_step_range\", None)\n",
"second_model.share_cache[\"target_size\"] = stage_2_hw\n",
"second_model.share_cache[\"shift_t\"] = shift_t\n",
"second_model.share_cache[\"sample_step\"] = sample_step\n",
"second_model.share_cache[\"cfg\"] = cfg_second\n",
"post_fix = f'''noise_{second_model.share_cache[\"ref_noise_step\"]}_step_{second_model.share_cache[\"sample_step\"]}_cfg_{second_model.share_cache[\"cfg\"]}_shift_{second_model.share_cache[\"shift_t\"]}_size_{stage_2_hw[0]}x{stage_2_hw[1]}'''\n",
"second_model.share_cache[\"time_size_embedding\"] = True\n",
"second_stage_samples = get_second_results(second_model, \n",
" text, \n",
" first_stage_samples, \n",
" second_num_frames)\n",
"\n",
"print(f\"save to {save_dir}/{enu_index}_num_frame_{num_frames}_{post_fix}.mp4\")\n",
"write_video(filename=f'./{save_dir}/{enu_index}_num_frame_{num_frames}_{post_fix}_second.mp4', \n",
" fps=8, \n",
" video_array= second_stage_samples.cpu(), \n",
" options = { 'crf': '14' })\n",
"\n",
"\n",
"# save joint video \n",
"part_first_stage = first_stage_samples[:second_num_frames]\n",
"\n",
"target_h, target_w = second_stage_samples.shape[1], second_stage_samples.shape[2]\n",
"part_first_stage = torch.nn.functional.interpolate(part_first_stage.permute(0, 3, 1, 2).contiguous(),\n",
" size=(target_h, target_w),\n",
" mode=\"bilinear\",\n",
" align_corners=False, \n",
" antialias=True)\n",
"part_first_stage = part_first_stage.permute(0, 2, 3, 1).contiguous()\n",
"\n",
"\n",
"joint_video = torch.cat([part_first_stage.cpu(), second_stage_samples.cpu()], dim=-2)\n",
"print(f'./{save_dir}/{enu_index}_num_frame_{num_frames}_{post_fix}_joint.mp4')\n",
"write_video(filename=f'./{save_dir}/{enu_index}_num_frame_{num_frames}_{post_fix}_joint.mp4',\n",
" fps=8,\n",
" video_array=joint_video.cpu(),\n",
" options={'crf': '15'}) \n"
]
}
],
"metadata": {
"fileId": "c6eed2be-3101-492e-a984-783ecbc70a34",
"filePath": "/mnt/bn/foundation-ads/shilong/conda/code/cogvideo-5b/sat/demo_ab.ipynb",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
This diff is collapsed.
import argparse
import copy
import os
import numpy as np
import torch
from arguments import get_args
from diffusion_video import SATVideoDiffusionEngine
from flow_video import FlowEngine
from torchvision.io.video import write_video
from utils import (decode, disable_all_init, prepare_input, save_mem_decode,
save_memory_encode_first_stage, seed_everything)
from sat import mpu
from sat.model.base_model import get_model
disable_all_init()
def init_model(model, second_model, args, second_args):
share_cache = dict()
second_share_cache = dict()
if hasattr(args, 'share_cache_config'):
for k, v in args.share_cache_config.items():
share_cache[k] = v
if hasattr(second_args, 'share_cache_config'):
for k, v in second_args.share_cache_config.items():
second_share_cache[k] = v
for n, m in model.named_modules():
m.share_cache = share_cache
if hasattr(m, 'register_new_modules'):
m.register_new_modules()
for n, m in second_model.named_modules():
m.share_cache = second_share_cache
if hasattr(m, 'register_new_modules'):
m.register_new_modules()
if os.environ.get('SKIP_LOAD', None) is not None:
print('skip load for speed debug')
else:
weight_path = args.inf_ckpt
weight = torch.load(weight_path, map_location='cpu')
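        # checkpoints may carry precomputed RoPE sin/cos buffers; drop them here since the
        # rotary mixin in this commit rebuilds its tables online (see online_sin_cos)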
if 'model.diffusion_model.mixins.pos_embed.freqs_sin' in weight[
'module']:
del weight['module'][
'model.diffusion_model.mixins.pos_embed.freqs_sin']
del weight['module'][
'model.diffusion_model.mixins.pos_embed.freqs_cos']
msg = model.load_state_dict(weight['module'], strict=False)
print(msg)
second_weight_path = args.inf_ckpt2
second_weight = torch.load(second_weight_path, map_location='cpu')
if 'model.diffusion_model.mixins.pos_embed.freqs_sin' in second_weight[
'module']:
del second_weight['module'][
'model.diffusion_model.mixins.pos_embed.freqs_sin']
del second_weight['module'][
'model.diffusion_model.mixins.pos_embed.freqs_cos']
second_msg = second_model.load_state_dict(second_weight['module'],
strict=False)
print(second_msg)
for n, m in model.named_modules():
if hasattr(m, 'merge_lora'):
m.merge_lora()
print(f'merge lora of {n}')
def get_first_results(model, text, num_frames, H, W):
    """Get first-stage results.

    Args:
        model (nn.Module): first-stage model.
        text (str): text prompt.
        num_frames (int): number of frames.
        H (int): height of the first-stage results.
        W (int): width of the first-stage results.

    Returns:
        Tensor: first-stage video.
    """
device = 'cuda'
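    # T: number of latent frames (one leading frame plus one latent per 4 video frames);
    # F: spatial downsampling factor, so latents are sampled at (H // F, W // F)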
T = 1 + (num_frames - 1) // 4
F = 8
motion_text_prefix = [
'very low motion,',
'low motion,',
'medium motion,',
'high motion,',
'very high motion,',
]
neg_prompt = ''
pos_prompt = ''
with torch.no_grad():
model.to('cuda')
input_negative_prompt = motion_text_prefix[
0] + ', ' + motion_text_prefix[1] + neg_prompt
c, uc = prepare_input(text,
model,
T,
negative_prompt=input_negative_prompt,
pos_prompt=pos_prompt)
with torch.no_grad(), torch.amp.autocast(enabled=True,
device_type='cuda',
dtype=torch.bfloat16):
samples_z = model.sample(
c,
uc=uc,
batch_size=1,
shape=(T, 16, H // F, W // F),
num_steps=model.share_cache.get('first_sample_step', None),
)
samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous()
model.to('cpu')
torch.cuda.empty_cache()
first_stage_model = model.first_stage_model
first_stage_model = first_stage_model.to(device)
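        # undo the encode-time latent scaling before handing the latents to the VAE decoder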
latent = 1.0 / model.scale_factor * samples_z
samples = decode(first_stage_model, latent)
model.to('cpu')
return samples
def get_second_results(model, text, first_stage_samples, num_frames):
    """Get second-stage results.

    Args:
        model (nn.Module): second-stage model.
        text (str): text prompt.
        first_stage_samples (Tensor): first-stage results.
        num_frames (int): number of frames.

    Returns:
        Tensor: second-stage results.
    """
t, h, w, c = first_stage_samples.shape
first_stage_samples = first_stage_samples[:num_frames]
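    # map the first-stage RGB frames from [0, 255] to [-1, 1] before re-encoding them as the reference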
first_stage_samples = (first_stage_samples / 255.)
first_stage_samples = (first_stage_samples - 0.5) / 0.5
target_size = model.share_cache.get('target_size', None)
if target_size is None:
upscale_factor = model.share_cache.get('upscale_factor', 8)
H = int(h * upscale_factor) // 16 * 16
W = int(w * upscale_factor) // 16 * 16
else:
H, W = target_size
H = H // 16 * 16
W = W // 16 * 16
first_stage_samples = first_stage_samples.permute(0, 3, 1, 2).to('cuda')
ref_x = torch.nn.functional.interpolate(first_stage_samples,
size=(H, W),
mode='bilinear',
align_corners=False,
antialias=True)
ref_x = ref_x[:num_frames][None]
ref_x = ref_x.permute(0, 2, 1, 3, 4).contiguous()
first_stage_model = model.first_stage_model
    print('start encoding first-stage results at high resolution')
with torch.no_grad():
first_stage_dtype = next(model.first_stage_model.parameters()).dtype
model.first_stage_model.cuda()
ref_x = save_memory_encode_first_stage(
ref_x.contiguous().to(first_stage_dtype).cuda(), model)
ref_x = ref_x.permute(0, 2, 1, 3, 4).contiguous()
ref_x = ref_x.to(model.dtype)
    print('finished encoding first-stage results; starting stage II')
device = 'cuda'
model.to(device)
motion_text_prefix = [
'very low motion,',
'low motion,',
'medium motion,',
'high motion,',
'very high motion,',
]
pos_prompt = None
input_negative_prompt = None
text = 'medium motion,' + text
c, uc = prepare_input(text,
model,
num_frames,
negative_prompt=input_negative_prompt,
pos_prompt=pos_prompt)
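    # latent temporal length and spatial downsampling factor, mirroring the first stage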
T = 1 + (num_frames - 1) // 4
F = 8
with torch.no_grad(), torch.amp.autocast(enabled=True,
device_type='cuda',
dtype=torch.bfloat16):
samples_z = model.sample(
ref_x,
c,
uc=uc,
batch_size=1,
shape=(T, 16, H // F, W // F),
num_steps=model.share_cache.get('sample_step', 5),
method='euler',
cfg=model.share_cache.get('cfg', 7.5),
)
samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous()
model.to('cpu')
torch.cuda.empty_cache()
first_stage_model = model.first_stage_model
first_stage_model = first_stage_model.to(device)
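    # undo the latent scaling, then decode in spatiotemporal slices to keep peak memory bounded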
latent = 1.0 / model.scale_factor * samples_z
    print('start spatiotemporal slice decoding')
    samples = save_mem_decode(first_stage_model, latent)
    print('finished spatiotemporal slice decoding')
model.to('cpu')
return samples
def sampling_main(args, second_args, model_cls, second_model_cls):
local_rank = int(os.environ.get('LOCAL_RANK', 0))
torch.cuda.set_device(local_rank)
second_model = get_model(second_args, second_model_cls)
model = get_model(args, model_cls)
init_model(model, second_model, args, second_args)
model.eval()
second_model.eval()
    rank = mpu.get_data_parallel_rank()
    world_size = mpu.get_data_parallel_world_size()
print('rank and world_size', rank, world_size)
text_file = args.input_file
num_sample_perprompt = 1
with open(text_file) as fin:
all_prompt = []
for single_line in fin:
all_prompt.append(single_line.strip())
    print(f'loaded {len(all_prompt)} prompts from {text_file}')
image_size = [270, 480]
image_size = (image_size[0] // 16 * 16, image_size[1] // 16 * 16)
# second_img_size = [1080, 1920]
second_img_size = [270, 480]
second_img_size = (second_img_size[0] // 16 * 16,
second_img_size[1] // 16 * 16)
num_frames = 49
second_num_frames = 49
    # first-stage classifier-free guidance scale; typical range 6-8
    model.share_cache['cfg'] = 8
    second_model.share_cache['target_size'] = second_img_size
    # reference-latent noise step; typical range 650-750
    second_model.share_cache['ref_noise_step'] = 675
    second_model.share_cache['sample_ref_noise_step'] = 675
    # timestep shift; typical range 2-3.5
    second_model.share_cache['shift_t'] = 2.5
    # number of second-stage sampling steps; typical range 4-6
    second_model.share_cache['sample_step'] = 5
    # second-stage classifier-free guidance scale; typical range 10-13
    second_model.share_cache['cfg'] = 13
second_model.share_cache.pop('ref_noise_step_range', None)
second_model.share_cache['time_size_embedding'] = True
    H, W = image_size
    save_dir = ''
if args.output_dir:
save_dir = args.output_dir
print(save_dir)
os.makedirs(save_dir, exist_ok=True)
all_ids = range(len(all_prompt))
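    # stride the prompt indices across data-parallel ranks so each rank renders a disjoint subset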
local_ids = all_ids[rank::world_size]
for enu_index in local_ids:
text = all_prompt[enu_index]
print(f'rank {rank} processing {enu_index}')
for inter_index in range(num_sample_perprompt):
seed_everything(enu_index + inter_index * 1000)
seed = enu_index + inter_index * 1000
first_stage_samples = get_first_results(model, text, num_frames, H,
W)
file_name = f'{save_dir}/{enu_index}_{inter_index}_seed_{seed}.mp4'
second_file_name = f'{save_dir}/{enu_index}_{inter_index}_seed_{seed}_second.mp4'
joint_file_name = f'{save_dir}/{enu_index}_{inter_index}_seed_{seed}_joint.mp4'
print(f'save to {file_name}')
write_video(filename=file_name,
fps=8,
video_array=first_stage_samples,
options={'crf': '5'})
if not args.skip_second:
second_stage_samples = get_second_results(
second_model, text, first_stage_samples, second_num_frames)
write_video(filename=second_file_name,
fps=8,
video_array=second_stage_samples.cpu(),
options={'crf': '5'})
# save joint video
part_first_stage = first_stage_samples[:second_num_frames]
target_h, target_w = second_stage_samples.shape[
1], second_stage_samples.shape[2]
part_first_stage = torch.nn.functional.interpolate(
part_first_stage.permute(0, 3, 1, 2).contiguous(),
size=(target_h, target_w),
mode='bilinear',
align_corners=False,
antialias=True)
part_first_stage = part_first_stage.permute(0, 2, 3,
1).contiguous()
joint_video = torch.cat(
[part_first_stage.cpu(),
second_stage_samples.cpu()],
dim=-2)
write_video(filename=joint_file_name,
fps=8,
video_array=joint_video.cpu(),
options={'crf': '15'})
if __name__ == '__main__':
if 'OMPI_COMM_WORLD_LOCAL_RANK' in os.environ:
os.environ['LOCAL_RANK'] = os.environ['OMPI_COMM_WORLD_LOCAL_RANK']
os.environ['WORLD_SIZE'] = os.environ['OMPI_COMM_WORLD_SIZE']
os.environ['RANK'] = os.environ['OMPI_COMM_WORLD_RANK']
py_parser = argparse.ArgumentParser(add_help=False)
known, args_list = py_parser.parse_known_args()
second_args_list = copy.deepcopy(args_list)
args = get_args(args_list)
args = argparse.Namespace(**vars(args), **vars(known))
del args.deepspeed_config
args.model_config.first_stage_config.params.cp_size = 1
args.model_config.network_config.params.transformer_args.model_parallel_size = 1
args.model_config.network_config.params.transformer_args.checkpoint_activations = False
args.model_config.loss_fn_config.params.sigma_sampler_config.params.uniform_sampling = False
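    # reuse the first-stage command line for the second stage, but swap in the second config
    # (this assumes the argument list starts with '--base <config>')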
second_args_list[1] = args.second[0]
second_args = get_args(second_args_list)
second_args = argparse.Namespace(**vars(second_args), **vars(known))
del second_args.deepspeed_config
second_args.model_config.first_stage_config.params.cp_size = 1
second_args.model_config.network_config.params.transformer_args.model_parallel_size = 1
second_args.model_config.network_config.params.transformer_args.checkpoint_activations = False
second_args.model_config.loss_fn_config.params.sigma_sampler_config.params.uniform_sampling = False
sampling_main(args,
second_args,
model_cls=SATVideoDiffusionEngine,
second_model_cls=FlowEngine)
This diff is collapsed.
import copy
import random
import torch
from dit_video_concat import (DiffusionTransformer,
Rotary3DPositionEmbeddingMixin, broadcat,
rotate_half)
from einops import rearrange, repeat
from sgm.modules.diffusionmodules.util import timestep_embedding
from torch import nn
from sat.transformer_defaults import HOOKS_DEFAULT
class ScaleCropRotary3DPositionEmbeddingMixin(Rotary3DPositionEmbeddingMixin):
def __init__(
self,
height,
width,
compressed_num_frames,
hidden_size,
hidden_size_head,
text_length,
theta=10000,
h_interp_ratio=1.0,
w_interp_ratio=1.0,
t_interp_ratio=1.0,
rot_v=False,
learnable_pos_embed=False,
):
super(Rotary3DPositionEmbeddingMixin, self).__init__()
self.rot_v = rot_v
print(f'theta is {theta}')
dim_t = hidden_size_head // 4
dim_h = hidden_size_head // 8 * 3
dim_w = hidden_size_head // 8 * 3
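        # standard RoPE inverse frequencies, 1 / theta^(2i / dim), computed separately for the
        # temporal, height and width axes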
self.freqs_t = (1.0 / (theta**(
torch.arange(0, dim_t, 2)[:(dim_t // 2)].float() / dim_t))).cuda()
self.freqs_h = (1.0 / (theta**(
torch.arange(0, dim_h, 2)[:(dim_h // 2)].float() / dim_h))).cuda()
self.freqs_w = (1.0 / (theta**(
torch.arange(0, dim_w, 2)[:(dim_w // 2)].float() / dim_w))).cuda()
self.compressed_num_frames = compressed_num_frames
self.height = height
self.width = width
self.text_length = text_length
if learnable_pos_embed:
num_patches = height * width * compressed_num_frames + text_length
self.pos_embedding = nn.Parameter(torch.zeros(
1, num_patches, int(hidden_size)),
requires_grad=True)
else:
self.pos_embedding = None
def online_sin_cos(self, real_t, real_h, real_w, dy_interpolation=None):
grid_t = torch.arange(real_t, dtype=torch.float32, device='cuda')
grid_h = torch.arange(real_h, dtype=torch.float32, device='cuda')
grid_w = torch.arange(real_w, dtype=torch.float32, device='cuda')
freqs_t = self.freqs_t
freqs_h = self.freqs_h
freqs_w = self.freqs_w
freqs_t = torch.einsum('..., f -> ... f', grid_t, freqs_t)
freqs_h = torch.einsum('..., f -> ... f', grid_h, freqs_h)
freqs_w = torch.einsum('..., f -> ... f', grid_w, freqs_w)
freqs_t = repeat(freqs_t, '... n -> ... (n r)', r=2)
freqs_h = repeat(freqs_h, '... n -> ... (n r)', r=2)
freqs_w = repeat(freqs_w, '... n -> ... (n r)', r=2)
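        # broadcast the per-axis frequencies onto a (t, h, w) grid and flatten to sequence order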
freqs = broadcat((freqs_t[:, None, None, :], freqs_h[None, :, None, :],
freqs_w[None, None, :, :]),
dim=-1)
freqs = rearrange(freqs, 't h w d -> (t h w) d')
temp_layer_id = self.share_cache['temp_layer_id']
emb = self.share_cache['emb']
if f'rope_layer_{temp_layer_id}' in self.share_cache:
m = self.share_cache[f'rope_layer_{temp_layer_id}']
dy_interpolation = m(emb)
else:
dy_interpolation = freqs.new_ones(emb.shape[0], freqs.shape[-1])
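        # scale the rotary frequencies by the per-sample interpolation factors predicted from the
        # conditioning embedding; if the predictor covers fewer dims than the full rotary dim,
        # only the trailing dims are rescaled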
b, dim = dy_interpolation.shape
dy_interpolation = dy_interpolation[:, None]
freqs = freqs[None]
freqs = freqs.repeat(b, 1, 1)
if dy_interpolation.shape[-1] != freqs.shape[-1]:
freqs[...,-dy_interpolation.shape[-1]:] = \
freqs[...,-dy_interpolation.shape[-1]:] * dy_interpolation
else:
freqs = freqs * dy_interpolation
freqs_sin = torch.sin(freqs)
freqs_cos = torch.cos(freqs)
return freqs_cos, freqs_sin
def rotary(self, t, **kwargs):
if 'freqs_cos' in self.share_cache:
freqs_cos = self.share_cache['freqs_cos']
freqs_sin = self.share_cache['freqs_sin']
else:
real_t, real_h, real_w = self.share_cache['shape_info']
freqs_cos, freqs_sin = self.online_sin_cos(real_t,
real_h,
real_w,
dy_interpolation=None)
freqs_cos = freqs_cos.unsqueeze(1)
freqs_cos = freqs_cos.to(t.dtype)
freqs_sin = freqs_sin.unsqueeze(1)
freqs_sin = freqs_sin.to(t.dtype)
self.share_cache['freqs_cos'] = freqs_cos
self.share_cache['freqs_sin'] = freqs_sin
return t * freqs_cos + rotate_half(t) * freqs_sin
def attention_fn(
self,
query_layer,
key_layer,
value_layer,
attention_mask,
attention_dropout=None,
log_attention_weights=None,
scaling_attention_score=True,
**kwargs,
):
attention_fn_default = HOOKS_DEFAULT['attention_fn']
img_query_layer = self.rotary(query_layer[:, :, self.text_length:])
query_layer = torch.cat(
[query_layer[:, :, :self.text_length], img_query_layer], dim=2)
query_layer = query_layer.to(value_layer.dtype)
img_key_layer = self.rotary(key_layer[:, :, self.text_length:])
key_layer = torch.cat(
[key_layer[:, :, :self.text_length], img_key_layer], dim=2)
key_layer = key_layer.to(value_layer.dtype)
if self.rot_v:
value_layer[:, :, self.text_length:] = self.rotary(
value_layer[:, :, self.text_length:])
return attention_fn_default(
query_layer,
key_layer,
value_layer,
attention_mask,
attention_dropout=attention_dropout,
log_attention_weights=log_attention_weights,
scaling_attention_score=scaling_attention_score,
**kwargs,
)
from sat.model.finetune.lora2 import *
class ResLoraMixin(LoraMixin):
def reinit(self, parent_model):
for i in self.layer_range:
print_rank0(f'replacing layer {i} attention with lora')
parent_model.transformer.layers[
i].attention.dense = replace_linear_with_lora(
parent_model.transformer.layers[i].attention.dense,
1,
self.r,
self.lora_alpha,
self.lora_dropout,
qlora=self.qlora,
in_size=parent_model.transformer.hidden_size,
out_size=None)
parent_model.transformer.layers[
i].attention.query_key_value = replace_linear_with_lora(
parent_model.transformer.layers[i].attention.
query_key_value,
parent_model.transformer.layers[i].attention.stride,
self.r,
self.lora_alpha,
self.lora_dropout,
qlora=self.qlora,
in_size=parent_model.transformer.hidden_size,
out_size=None
if not parent_model.transformer.num_multi_query_heads else
parent_model.transformer.layers[i].attention.
inner_hidden_size + parent_model.transformer.layers[i].
attention.hidden_size_per_attention_head * parent_model.
transformer.layers[i].attention.num_multi_query_heads * 2)
if self.cross_attention and parent_model.transformer.layers[
i].is_decoder:
print_rank0(f'replacing layer {i} cross attention with lora')
kv_size = parent_model.transformer.layers[
i].cross_attention.inner_hidden_size * 2 if not parent_model.transformer.cross_num_multi_query_heads else parent_model.transformer.layers[
i].cross_attention.hidden_size_per_attention_head * parent_model.transformer.layers[
i].cross_attention.cross_num_multi_query_heads * 2
parent_model.transformer.layers[
i].cross_attention.dense = replace_linear_with_lora(
parent_model.transformer.layers[i].cross_attention.
dense,
1,
self.r,
self.lora_alpha,
self.lora_dropout,
qlora=self.qlora,
in_size=parent_model.transformer.layers[i].
cross_attention.inner_hidden_size,
out_size=parent_model.transformer.hidden_size)
parent_model.transformer.layers[
i].cross_attention.query = replace_linear_with_lora(
parent_model.transformer.layers[i].cross_attention.
query,
1,
self.r,
self.lora_alpha,
self.lora_dropout,
qlora=self.qlora,
in_size=parent_model.transformer.hidden_size,
out_size=parent_model.transformer.layers[i].
cross_attention.inner_hidden_size)
parent_model.transformer.layers[
i].cross_attention.key_value = replace_linear_with_lora(
parent_model.transformer.layers[i].cross_attention.
key_value,
2,
self.r,
self.lora_alpha,
self.lora_dropout,
qlora=self.qlora,
in_size=parent_model.transformer.layers[i].
cross_attention.cross_attn_hidden_size,
out_size=kv_size)
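        # also wrap the adaLN modulation projections with LoRA; the in/out sizes are hard-coded
        # for this model's adaLN configuration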
for m in parent_model.mixins.adaln_layer.adaLN_modulations:
m[1] = replace_linear_with_lora(m[1],
1,
self.r,
self.lora_alpha,
self.lora_dropout,
qlora=self.qlora,
in_size=512,
out_size=36864)
def merge_lora(self):
for i in self.layer_range:
print_rank0(f'merge layer {i} lora attention back to linear')
self.transformer.layers[i].attention.dense = merge_linear_lora(
self.transformer.layers[i].attention.dense)
self.transformer.layers[
i].attention.query_key_value = merge_linear_lora(
self.transformer.layers[i].attention.query_key_value)
if self.cross_attention and self.transformer.layers[i].is_decoder:
print_rank0(
f'merge layer {i} lora cross attention back to linear')
self.transformer.layers[
i].cross_attention.dense = merge_linear_lora(
self.transformer.layers[i].cross_attention.dense)
self.transformer.layers[
i].cross_attention.query = merge_linear_lora(
self.transformer.layers[i].cross_attention.query)
self.transformer.layers[
i].cross_attention.key_value = merge_linear_lora(
self.transformer.layers[i].cross_attention.key_value)
class SMALLDiffusionTransformer(DiffusionTransformer):
def register_new_modules(self):
if 'sample_ref_noise_step' in self.share_cache:
self.ref_step_time_embedding = copy.deepcopy(self.time_embed)
            # zero-init the last linear layer of self.ref_step_time_embedding so the
            # reference-step conditioning starts as a no-op
for n, p in self.ref_step_time_embedding[-1].named_parameters():
nn.init.constant_(p, 0)
p.requires_grad = True
if 'time_size_embedding' in self.share_cache:
self.time_size_embedding = copy.deepcopy(self.time_embed)
            # zero-init the entire time-size embedding MLP so it initially contributes nothing
for n, p in self.time_size_embedding.named_parameters():
nn.init.constant_(p, 0)
p.requires_grad = True
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
b, t, d, h, w = x.shape
if x.dtype != self.dtype:
x = x.to(self.dtype)
if 'ref_x' in self.share_cache:
ref_x = self.share_cache['ref_x']
if ref_x.dtype != self.dtype:
ref_x = ref_x.to(self.dtype)
self.share_cache['ref_x'] = ref_x
        # this branch is not used during inference
if 'concat_images' in kwargs and kwargs['concat_images'] is not None:
if kwargs['concat_images'].shape[0] != x.shape[0]:
concat_images = kwargs['concat_images'].repeat(2, 1, 1, 1, 1)
else:
concat_images = kwargs['concat_images']
x = torch.cat([x, concat_images], dim=2)
assert (y is not None) == (
self.num_classes is not None
), 'must specify y if and only if the model is class-conditional'
t_emb = timestep_embedding(timesteps,
self.model_channels,
repeat_only=False,
dtype=self.dtype)
emb = self.time_embed(t_emb)
if 'time_size_embedding' in self.share_cache:
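            # condition on the latent frame count via an extra (zero-initialized) timestep-style embedding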
num_t = torch.zeros_like(timesteps).fill_(int(t))
time_size_emb = timestep_embedding(num_t,
self.model_channels,
repeat_only=False,
dtype=self.dtype)
time_size_emb = self.time_size_embedding(time_size_emb)
emb = emb + time_size_emb
if 'sample_ref_noise_step' in self.share_cache:
print(
f'''sample_ref_noise_step {self.share_cache["sample_ref_noise_step"]}'''
)
# bf 16
ref_time_step = copy.deepcopy(timesteps).fill_(
self.share_cache['sample_ref_noise_step'])
ref_step_time_emb = timestep_embedding(ref_time_step,
self.model_channels,
repeat_only=False,
dtype=self.dtype)
ref_step_time_emb = self.ref_step_time_embedding(ref_step_time_emb)
if not self.training:
print(f'{ref_time_step} get {ref_step_time_emb.sum()}')
emb = emb + ref_step_time_emb
if self.num_classes is not None:
assert x.shape[0] % y.shape[0] == 0
y = y.repeat_interleave(x.shape[0] // y.shape[0], dim=0)
emb = emb + self.label_emb(y)
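        # stash the latent grid shape, current timestep and conditioning embedding in the shared
        # cache so the rotary position embedding mixin can rebuild its sin/cos tables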
self.share_cache['shape_info'] = (t, h // (self.patch_size),
w // (self.patch_size))
self.share_cache['timesteps'] = int(timesteps[0])
self.share_cache['emb'] = emb
kwargs['seq_length'] = t * h * w // (self.patch_size**2)
kwargs['images'] = x
kwargs['emb'] = emb
kwargs['encoder_outputs'] = context
kwargs['text_length'] = context.shape[1]
kwargs['input_ids'] = kwargs['position_ids'] = kwargs[
'attention_mask'] = torch.ones((1, 1)).to(x.dtype)
output = super(DiffusionTransformer, self).forward(**kwargs)[0]
return output
This diff is collapsed.
from .models import AutoencodingEngine
from .util import get_configs_path, instantiate_from_config
__version__ = '0.1.0'
This diff is collapsed.
from .autoencoder import AutoencodingEngine