import functools
import importlib
import os
from functools import partial
from inspect import isfunction
import fsspec
import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont
from safetensors.torch import load_file as load_safetensors
import torch.distributed
_CONTEXT_PARALLEL_GROUP = None
_CONTEXT_PARALLEL_SIZE = None
def is_context_parallel_initialized():
if _CONTEXT_PARALLEL_GROUP is None:
return False
else:
return True
def set_context_parallel_group(size, group):
global _CONTEXT_PARALLEL_GROUP
global _CONTEXT_PARALLEL_SIZE
_CONTEXT_PARALLEL_GROUP = group
_CONTEXT_PARALLEL_SIZE = size
def initialize_context_parallel(context_parallel_size):
global _CONTEXT_PARALLEL_GROUP
global _CONTEXT_PARALLEL_SIZE
assert _CONTEXT_PARALLEL_GROUP is None, "context parallel group is already initialized"
_CONTEXT_PARALLEL_SIZE = context_parallel_size
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
for i in range(0, world_size, context_parallel_size):
ranks = range(i, i + context_parallel_size)
group = torch.distributed.new_group(ranks)
if rank in ranks:
_CONTEXT_PARALLEL_GROUP = group
break
def get_context_parallel_group():
assert _CONTEXT_PARALLEL_GROUP is not None, "context parallel group is not initialized"
return _CONTEXT_PARALLEL_GROUP
def get_context_parallel_world_size():
assert _CONTEXT_PARALLEL_SIZE is not None, "context parallel size is not initialized"
return _CONTEXT_PARALLEL_SIZE
def get_context_parallel_rank():
assert _CONTEXT_PARALLEL_SIZE is not None, "context parallel size is not initialized"
rank = torch.distributed.get_rank()
cp_rank = rank % _CONTEXT_PARALLEL_SIZE
return cp_rank
def get_context_parallel_group_rank():
assert _CONTEXT_PARALLEL_SIZE is not None, "context parallel size is not initialized"
rank = torch.distributed.get_rank()
cp_group_rank = rank // _CONTEXT_PARALLEL_SIZE
return cp_group_rank
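# Illustrative sketch (not part of the original code): with 8 ranks and
# context_parallel_size = 2, initialize_context_parallel(2) builds the groups
# {0,1}, {2,3}, {4,5}, {6,7}. For global rank 5:
#   get_context_parallel_rank()       -> 5 % 2  == 1   (position inside its group)
#   get_context_parallel_group_rank() -> 5 // 2 == 2   (index of the group itself)
# All helpers assume torch.distributed has already been initialized elsewhere.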
class SafeConv3d(torch.nn.Conv3d):
def forward(self, input):
memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024**3
if memory_count > 2:
# print(f"WARNING: Conv3d with {memory_count:.2f}GB")
kernel_size = self.kernel_size[0]
part_num = int(memory_count / 2) + 1
input_chunks = torch.chunk(input, part_num, dim=2) # NCTHW
if kernel_size > 1:
input_chunks = [input_chunks[0]] + [
torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2)
for i in range(1, len(input_chunks))
]
output_chunks = []
for input_chunk in input_chunks:
output_chunks.append(super(SafeConv3d, self).forward(input_chunk))
output = torch.cat(output_chunks, dim=2)
return output
else:
return super(SafeConv3d, self).forward(input)
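# Behaviour sketch (illustrative; assumes fp16 activations, which is where the
# `* 2` bytes-per-element factor above comes from): an input whose NCTHW tensor would
# exceed ~2 GB is chunked along the temporal axis (dim=2), and each chunk after the
# first is prepended with the last `kernel_size - 1` frames of its predecessor, so
# concatenating the per-chunk outputs matches a single full forward pass of the
# underlying Conv3d.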
def disabled_train(self, mode=True):
"""Overwrite model.train with this function to make sure train/eval mode
does not change anymore."""
return self
def get_string_from_tuple(s):
try:
# Check if the string starts and ends with parentheses
if s[0] == "(" and s[-1] == ")":
# Convert the string to a tuple
t = eval(s)
# Check if the type of t is tuple
if type(t) == tuple:
return t[0]
else:
pass
except:
pass
return s
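# Minimal usage sketch (illustrative, never called at import time): tuple-looking
# strings yield their first element; anything else is returned unchanged.
def _get_string_from_tuple_example():
    assert get_string_from_tuple("('a', 'b')") == "a"
    assert get_string_from_tuple("plain string") == "plain string"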
def is_power_of_two(n):
"""
chat.openai.com/chat
Return True if n is a power of 2, otherwise return False.
The function is_power_of_two takes an integer n as input and returns True if n is a power of 2, otherwise it returns False.
The function works by first checking if n is less than or equal to 0. If n is less than or equal to 0, it can't be a power of 2, so the function returns False.
If n is greater than 0, the function checks whether n is a power of 2 by using a bitwise AND operation between n and n-1. If n is a power of 2, then it will have only one bit set to 1 in its binary representation. When we subtract 1 from a power of 2, all the bits to the right of that bit become 1, and the bit itself becomes 0. So, when we perform a bitwise AND between n and n-1, we get 0 if n is a power of 2, and a non-zero value otherwise.
Thus, if the result of the bitwise AND operation is 0, then n is a power of 2 and the function returns True. Otherwise, the function returns False.
"""
if n <= 0:
return False
return (n & (n - 1)) == 0
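# Minimal usage sketch (illustrative, never called at import time).
def _is_power_of_two_example():
    assert is_power_of_two(8)       # 0b1000 & 0b0111 == 0
    assert not is_power_of_two(12)  # 0b1100 & 0b1011 == 0b1000 != 0
    assert not is_power_of_two(0)   # non-positive inputs are rejected early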
def autocast(f, enabled=True):
def do_autocast(*args, **kwargs):
with torch.cuda.amp.autocast(
enabled=enabled,
dtype=torch.get_autocast_gpu_dtype(),
cache_enabled=torch.is_autocast_cache_enabled(),
):
return f(*args, **kwargs)
return do_autocast
def load_partial_from_config(config):
return partial(get_obj_from_str(config["target"]), **config.get("params", dict()))
def log_txt_as_img(wh, xc, size=10):
# wh a tuple of (width, height)
# xc a list of captions to plot
b = len(xc)
txts = list()
for bi in range(b):
txt = Image.new("RGB", wh, color="white")
draw = ImageDraw.Draw(txt)
font = ImageFont.truetype("data/DejaVuSans.ttf", size=size)
nc = int(40 * (wh[0] / 256))
if isinstance(xc[bi], list):
text_seq = xc[bi][0]
else:
text_seq = xc[bi]
lines = "\n".join(text_seq[start : start + nc] for start in range(0, len(text_seq), nc))
try:
draw.text((0, 0), lines, fill="black", font=font)
except UnicodeEncodeError:
print("Cant encode string for logging. Skipping.")
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
txts.append(txt)
txts = np.stack(txts)
txts = torch.tensor(txts)
return txts
def partialclass(cls, *args, **kwargs):
class NewCls(cls):
__init__ = functools.partialmethod(cls.__init__, *args, **kwargs)
return NewCls
def make_path_absolute(path):
fs, p = fsspec.core.url_to_fs(path)
if fs.protocol == "file":
return os.path.abspath(p)
return path
def ismap(x):
if not isinstance(x, torch.Tensor):
return False
return (len(x.shape) == 4) and (x.shape[1] > 3)
def isimage(x):
if not isinstance(x, torch.Tensor):
return False
return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
def isheatmap(x):
if not isinstance(x, torch.Tensor):
return False
return x.ndim == 2
def isneighbors(x):
if not isinstance(x, torch.Tensor):
return False
return x.ndim == 5 and (x.shape[2] == 3 or x.shape[2] == 1)
def exists(x):
return x is not None
def expand_dims_like(x, y):
while x.dim() != y.dim():
x = x.unsqueeze(-1)
return x
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
def mean_flat(tensor):
"""
https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
Take the mean over all non-batch dimensions.
"""
return tensor.mean(dim=list(range(1, len(tensor.shape))))
def count_params(model, verbose=False):
total_params = sum(p.numel() for p in model.parameters())
if verbose:
print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
return total_params
def instantiate_from_config(config, **extra_kwargs):
if not "target" in config:
if config == "__is_first_stage__":
return None
elif config == "__is_unconditional__":
return None
raise KeyError("Expected key `target` to instantiate.")
return get_obj_from_str(config["target"])(**config.get("params", dict()), **extra_kwargs)
def get_obj_from_str(string, reload=False, invalidate_cache=True):
module, cls = string.rsplit(".", 1)
if invalidate_cache:
importlib.invalidate_caches()
if reload:
module_imp = importlib.import_module(module)
importlib.reload(module_imp)
return getattr(importlib.import_module(module, package=None), cls)
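# Config sketch (illustrative; the target below is just an example class): a config like
#   {"target": "torch.nn.Linear", "params": {"in_features": 8, "out_features": 4}}
# is resolved by get_obj_from_str and instantiated with its params, i.e.
# instantiate_from_config(cfg) behaves like torch.nn.Linear(in_features=8, out_features=4).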
def append_zero(x):
return torch.cat([x, x.new_zeros([1])])
def append_dims(x, target_dims):
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
dims_to_append = target_dims - x.ndim
if dims_to_append < 0:
raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
return x[(...,) + (None,) * dims_to_append]
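# Usage sketch (illustrative, never called at import time): append_dims is typically
# used to broadcast per-sample scalars (e.g. diffusion sigmas) against image tensors.
def _append_dims_example():
    sigmas = torch.ones(4)                    # shape (4,)
    x = torch.zeros(4, 3, 32, 32)             # shape (4, 3, 32, 32)
    scaled = x * append_dims(sigmas, x.ndim)  # sigmas broadcast as (4, 1, 1, 1)
    assert scaled.shape == x.shape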
def load_model_from_config(config, ckpt, verbose=True, freeze=True):
print(f"Loading model from {ckpt}")
if ckpt.endswith("ckpt"):
pl_sd = torch.load(ckpt, map_location="cpu")
if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
sd = pl_sd["state_dict"]
elif ckpt.endswith("safetensors"):
sd = load_safetensors(ckpt)
else:
raise NotImplementedError
model = instantiate_from_config(config.model)
m, u = model.load_state_dict(sd, strict=False)
if len(m) > 0 and verbose:
print("missing keys:")
print(m)
if len(u) > 0 and verbose:
print("unexpected keys:")
print(u)
if freeze:
for param in model.parameters():
param.requires_grad = False
model.eval()
return model
def get_configs_path() -> str:
"""
Get the `configs` directory.
For a working copy, this is the one in the root of the repository,
but for an installed copy, it's in the `sgm` package (see pyproject.toml).
"""
this_dir = os.path.dirname(__file__)
candidates = (
os.path.join(this_dir, "configs"),
os.path.join(this_dir, "..", "configs"),
)
for candidate in candidates:
candidate = os.path.abspath(candidate)
if os.path.isdir(candidate):
return candidate
raise FileNotFoundError(f"Could not find SGM configs in {candidates}")
def get_nested_attribute(obj, attribute_path, depth=None, return_key=False):
"""
Will return the result of a recursive get attribute call.
E.g.:
a.b.c
= getattr(getattr(a, "b"), "c")
= get_nested_attribute(a, "b.c")
If any part of the attribute call is an integer x with current obj a, will
try to call a[x] instead of a.x first.
"""
attributes = attribute_path.split(".")
if depth is not None and depth > 0:
attributes = attributes[:depth]
assert len(attributes) > 0, "At least one attribute should be selected"
current_attribute = obj
current_key = None
for level, attribute in enumerate(attributes):
current_key = ".".join(attributes[: level + 1])
try:
id_ = int(attribute)
current_attribute = current_attribute[id_]
except ValueError:
current_attribute = getattr(current_attribute, attribute)
return (current_attribute, current_key) if return_key else current_attribute
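# Usage sketch (illustrative, never called at import time): integer path components
# index into sequences instead of using getattr, so "0.weight"-style paths work on
# module containers as well.
def _get_nested_attribute_example():
    model = torch.nn.Sequential(torch.nn.Linear(2, 2), torch.nn.ReLU())
    w = get_nested_attribute(model, "0.weight")  # same object as model[0].weight
    w2, key = get_nested_attribute(model, "0.weight", return_key=True)
    assert w is w2 and key == "0.weight"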
from math import sqrt
class SeededNoise:
def __init__(self, seeds, weights):
self.seeds = seeds
self.weights = weights
weight_square_sum = 0
for weight in weights:
weight_square_sum += weight**2
self.weight_square_sum_sqrt = sqrt(weight_square_sum)
self.cnt = 0
def __call__(self, x):
self.cnt += 1
randn_combined = torch.zeros_like(x)
for seed, weight in zip(self.seeds, self.weights):
randn = np.random.RandomState(seed + self.cnt).randn(*x.shape)
randn = torch.from_numpy(randn).to(dtype=x.dtype, device=x.device)
randn_combined += randn * weight
randn_combined /= self.weight_square_sum_sqrt
return randn_combined
import sys
import io
import os
import re
import json
import tarfile
from functools import partial
import webdataset as wds
from webdataset import ResampledShards, DataPipeline, tarfile_to_samples
from webdataset.filters import pipelinefilter
from webdataset.tariterators import url_opener, group_by_keys
from webdataset.handlers import reraise_exception
from webdataset.gopen import gopen_schemes, gopen
def pytorch_worker_info(group=None): # sourcery skip: use-contextlib-suppress
"""Return node and worker info for PyTorch and some distributed environments."""
rank = 0
world_size = 1
worker = 0
num_workers = 1
try:
import torch.distributed
if torch.distributed.is_available() and torch.distributed.is_initialized():
group = group or torch.distributed.group.WORLD
rank = torch.distributed.get_rank(group=group)
world_size = torch.distributed.get_world_size(group=group)
except ModuleNotFoundError:
pass
try:
import torch.utils.data
worker_info = torch.utils.data.get_worker_info()
if worker_info is not None:
worker = worker_info.id
num_workers = worker_info.num_workers
except ModuleNotFoundError:
pass
return rank, world_size, worker, num_workers
def pytorch_worker_seed(group=None):
"""Compute a distinct, deterministic RNG seed for each worker and node."""
rank, world_size, worker, num_workers = pytorch_worker_info(group=group)
return rank * 1000 + worker
def worker_seed_sat(group=None, seed=0):
return pytorch_worker_seed(group=group) + seed * 23
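# Seed layout sketch (illustrative): with 2 ranks and 2 dataloader workers each,
# pytorch_worker_seed yields 0, 1 (rank 0) and 1000, 1001 (rank 1), and
# worker_seed_sat shifts the whole layout by `seed * 23`, keeping shard resampling
# deterministic but distinct per worker, rank, and configured seed.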
class ConfiguredResampledShards(ResampledShards):
def __init__(self, urls, seed, nshards=sys.maxsize, deterministic=True):
from sat.helpers import print_rank0
try:
from megatron.core.parallel_state import get_data_parallel_group
group = get_data_parallel_group()
print_rank0("Using megatron data parallel group.")
except:
from sat.mpu import get_data_parallel_group
try:
group = get_data_parallel_group()
print_rank0("Using sat data parallel group.")
except AssertionError:
group = None
print_rank0("No data parallel group is specified!")
worker_seed_sat_this = partial(worker_seed_sat, group=group, seed=seed)
super().__init__(urls, nshards, worker_seed_sat_this, deterministic)
class SimpleDistributedWebDataset(DataPipeline):
def __init__(self, path, process_fn, seed, *, shuffle_buffer=1000):
# set shuffle_buffer = 1 to disable shuffling; under model parallelism the buffer is forced to 1, since shuffling would desynchronize the ranks
try:
from sat.mpu import get_model_parallel_world_size
if get_model_parallel_world_size() > 1:
shuffle_buffer = 1
except Exception:
pass
super().__init__(
ConfiguredResampledShards(path, seed),  # many shards are recommended; otherwise samples may not be spread evenly across workers
tarfile_to_samples(),
wds.shuffle(shuffle_buffer),
process_fn,
)
def tar_file_iterator_with_meta(
fileobj, meta_names, skip_meta=r"__[^/]*__($|/)", suffix=None, handler=reraise_exception, meta_stream=None
):
"""Iterate over tar file, yielding filename, content pairs for the given tar stream.
:param fileobj: byte stream suitable for tarfile
:param meta_names: key of different items in meta file
:param skip_meta: regexp for keys that are skipped entirely (Default value = r"__[^/]*__($|/)")
"""
stream = tarfile.open(fileobj=fileobj, mode="r|*")
data_dir, filename = fileobj.name.rsplit("/", 1)
meta_data = {} # {id: {meta_name: meta_value, meta_name2: meta_value2, ...}}
if meta_stream is None:
meta_file_name = filename.split(".")[0] + ".meta.jsonl"
meta_path = os.path.join(data_dir, meta_file_name)
if os.path.exists(meta_path):
meta_stream = open(meta_path, "r")
else:
meta_file_name = meta_stream.name
if meta_stream is not None:
for lineno, line in enumerate(meta_stream):
meta_list = []
try:
meta_list.append(json.loads(line))
except Exception as exn:
from sat.helpers import print_rank0
print_rank0(f"Error in loading jsonl {meta_file_name}, lineno {lineno}: {line}", level="DEBUG")
continue
for item in meta_list:
if not item["key"] in meta_data:
meta_data[item["key"]] = {}
for meta_name in meta_names:
if meta_name in item:
meta_data[item["key"]][meta_name] = item[meta_name]
meta_stream.close()
try:
for tarinfo in stream:
fname = tarinfo.name
try:
if not tarinfo.isreg():
continue
if fname is None:
continue
if "/" not in fname and fname.startswith("__") and fname.endswith("__"):
# skipping metadata for now
continue
if skip_meta is not None and re.match(skip_meta, fname):
continue
if fname.endswith(".txt") and suffix is not None:
data = (stream.extractfile(tarinfo).read().decode() + suffix).encode()
else:
data = stream.extractfile(tarinfo).read()
result = dict(fname=fname, data=data)
yield result
if fname.endswith(".id"):
fid = fname.split(".")[0]
if "-$#%@&" in fid:
sfid = fid.split("-$#%@&")[0]
else:
sfid = fid
meta_data_fid = meta_data.get(sfid, {})
for meta_name in meta_names:
meta_fname = fid + "." + meta_name
meta = meta_data_fid.get(meta_name, None)
yield dict(fname=meta_fname, data=meta)
stream.members = []
except Exception as exn:
if hasattr(exn, "args") and len(exn.args) > 0:
exn.args = (exn.args[0] + " @ " + str(fileobj),) + exn.args[1:]
if handler(exn):
continue
else:
break
except Exception as exn:
print(exn)
del stream
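# Yield sketch (illustrative): for a tar member "000123.id" and meta_names=["caption"],
# the iterator yields {"fname": "000123.id", "data": b"..."} followed by
# {"fname": "000123.caption", "data": <value from the '.meta.jsonl'>}, so that
# group_by_keys later folds the metadata into the same sample.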
def tar_file_expander_with_meta(data, meta_names, handler=reraise_exception):
"""Expand a stream of open tar files into a stream of tar file contents.
This returns an iterator over (filename, file_contents).
"""
for source in data:
url = source["url"]
try:
assert isinstance(source, dict)
assert "stream" in source
for sample in tar_file_iterator_with_meta(source["stream"], meta_names, meta_stream=source["meta_stream"]):
assert isinstance(sample, dict) and "data" in sample and "fname" in sample
sample["__url__"] = url
yield sample
except Exception as exn:
exn.args = exn.args + (source.get("stream"), source.get("url"))
if handler(exn):
continue
else:
break
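# Note: the url_opener defined below intentionally shadows the one imported from
# webdataset.tariterators above; it additionally carries a `meta_stream` alongside
# the tar stream whenever the gopen backend provides one.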
def url_opener(
data,
handler,
**kw,
):
"""Open URLs and yield a stream of url+stream pairs.
Args:
data: iterator over dict(url=...)
handler: exception handler.
kw: keyword arguments for gopen.gopen.
Yields:
a stream of url+stream pairs.
"""
for sample in data:
assert isinstance(sample, dict), sample
assert "url" in sample
url = sample["url"]
try:
stream = gopen(url, **kw)
if hasattr(stream, "meta_stream"):
meta_stream = stream.meta_stream
del stream.meta_stream
else:
meta_stream = None
sample.update(stream=stream, meta_stream=meta_stream)
yield sample
except Exception as exn:
exn.args = exn.args + (url,)
if handler(exn):
continue
else:
break
def tarfile_samples_with_meta(src, meta_names, handler=reraise_exception):
streams = url_opener(src, handler=handler)
files = tar_file_expander_with_meta(streams, meta_names, handler)
samples = group_by_keys(files, handler=handler)
return samples
class MetaDistributedWebDataset(DataPipeline):
"""WebDataset with meta information files
Extra Format:
in webdataset (tar), for each sample there is a '.id';
for each tar file, there is a '.meta.jsonl' file with the same name;
The '.meta.jsonl' file contains lines of json objects, each with a 'key' field to match '.id'.
"""
def __init__(
self, path, process_fn, seed, *, meta_names=[], nshards=sys.maxsize, shuffle_buffer=1000, include_dirs=None
):
# os.environ['WDS_SHOW_SEED'] = '1'
import torch
if torch.distributed.get_rank() == 0:
if include_dirs is not None: # /webdatasets/A,/webdatasets/C
other_paths = []
include_dirs = include_dirs.split(",")
for include_dir in include_dirs:
if "*" in include_dir:
include_dir, n = include_dir.split("*")
n = int(n)
else:
n = 1
for cur_dir, dirs, files in os.walk(include_dir):
for f in files:
if f.endswith("tar") and os.path.getsize(os.path.join(cur_dir, f)) > 0:
# other_paths.append(os.path.join(cur_dir,f))
other_paths.extend([os.path.join(cur_dir, f)] * n)
# print(f'Adding dataset paths {",".join(other_paths)}')
from braceexpand import braceexpand
if len(path) > 0: # not ""
path = list(braceexpand(path)) + other_paths
else:
path = other_paths
path = [path]
else:
path = [
None,
]
torch.distributed.broadcast_object_list(path, src=0)
path = path[0]
tarfile_samples = partial(tarfile_samples_with_meta, meta_names=meta_names)
tarfile_to_samples = pipelinefilter(tarfile_samples)
# if model parallel, shuffle_buffer should be 1 to disable shuffling
try:
from sat.mpu import get_model_parallel_world_size
if get_model_parallel_world_size() > 1:
shuffle_buffer = 1
except Exception:
pass
super().__init__(
ConfiguredResampledShards(path, seed, nshards=nshards),
tarfile_to_samples(),
wds.shuffle(shuffle_buffer),
process_fn,
)
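# Example of one '.meta.jsonl' line (illustrative; field names other than 'key' are
# hypothetical and depend on the meta_names passed in):
#   {"key": "000000123", "caption": "a cat on a sofa", "width": 512}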
# rclone support
from webdataset.gopen import Pipe
def gopen_rclone(url, mode="rb", bufsize=1024 * 1024 * 32):
"""Open a URL with `curl`.
:param url: rclone url, e.g. data:bucket1/foo.tar. data should be configured.
:param mode: file mode
:param bufsize: buffer size
"""
url = url.replace("rclone://", "")
if mode[0] == "r":
cmd = f"rclone cat '{url}'"
return Pipe(
cmd,
mode=mode,
shell=True,
bufsize=bufsize,
ignore_status=[141, 23],
) # skipcq: BAN-B604
elif mode[0] == "w":
cmd = f"rclone cp - '{url}'"
return Pipe(
cmd,
mode=mode,
shell=True,
bufsize=bufsize,
ignore_status=[141, 26],
) # skipcq: BAN-B604
else:
raise ValueError(f"{mode}: unknown mode")
def gopen_boto3(url, mode="rb", bufsize=8192 * 2):
"""Open a URL with boto3 API.
:param url: boto3 url, e.g. boto3://bucket1/foo.tar. data should be configured.
:param mode: file mode
:param bufsize: buffer size
"""
import boto3
# boto3.set_stream_logger('botocore', level='DEBUG')
if url.startswith("boto3://"):
url = url.replace("boto3://", "")
need_meta = False
else:
url = url.replace("metaboto3://", "")
need_meta = True
endpoint_url = os.environ.get("S3_ENDPOINT_URL", None)
access_key = os.environ.get("S3_ACCESS_KEY_ID", None)
secret_key = os.environ.get("S3_SECRET_ACCESS_KEY", None)
if mode[0] == "r":
s3_client = boto3.client(
"s3", endpoint_url=endpoint_url, aws_access_key_id=access_key, aws_secret_access_key=secret_key
)
bucket, key = url.split("/", 1)
if need_meta:
# download a meta json
meta_file_key = key.split(".")[0] + ".meta.jsonl"
meta_stream = io.BytesIO()
s3_client.download_fileobj(bucket, meta_file_key, meta_stream)
meta_stream.seek(0)
meta_stream.name = meta_file_key
else:
meta_stream = None
# data tar stream
response = s3_client.get_object(Bucket=bucket, Key=key) # Range optional
response["Body"].name = key # actually not used
response["Body"].meta_stream = meta_stream
return response["Body"]
else:
raise ValueError(f"{mode}: unknown mode")
gopen_schemes["rclone"] = gopen_rclone
gopen_schemes["boto3"] = gopen_boto3
gopen_schemes["metaboto3"] = gopen_boto3
import math
from inspect import isfunction
from typing import Any, Optional
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from packaging import version
from torch import nn
if version.parse(torch.__version__) >= version.parse("2.0.0"):
SDP_IS_AVAILABLE = True
from torch.backends.cuda import SDPBackend, sdp_kernel
BACKEND_MAP = {
SDPBackend.MATH: {
"enable_math": True,
"enable_flash": False,
"enable_mem_efficient": False,
},
SDPBackend.FLASH_ATTENTION: {
"enable_math": False,
"enable_flash": True,
"enable_mem_efficient": False,
},
SDPBackend.EFFICIENT_ATTENTION: {
"enable_math": False,
"enable_flash": False,
"enable_mem_efficient": True,
},
None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},
}
else:
from contextlib import nullcontext
SDP_IS_AVAILABLE = False
sdp_kernel = nullcontext
BACKEND_MAP = {}
print(
f"No SDP backend available, likely because you are running in pytorch versions < 2.0. In fact, "
f"you are using PyTorch {torch.__version__}. You might want to consider upgrading."
)
try:
import xformers
import xformers.ops
XFORMERS_IS_AVAILABLE = True
except:
XFORMERS_IS_AVAILABLE = False
print("no module 'xformers'. Processing without...")
from modules.utils import checkpoint
def exists(val):
return val is not None
def uniq(arr):
return {el: True for el in arr}.keys()
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
def max_neg_value(t):
return -torch.finfo(t.dtype).max
def init_(tensor):
dim = tensor.shape[-1]
std = 1 / math.sqrt(dim)
tensor.uniform_(-std, std)
return tensor
# feedforward
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * F.gelu(gate)
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
def forward(self, x):
return self.net(x)
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
def Normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
class LinearAttention(nn.Module):
def __init__(self, dim, heads=4, dim_head=32):
super().__init__()
self.heads = heads
hidden_dim = dim_head * heads
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
def forward(self, x):
b, c, h, w = x.shape
qkv = self.to_qkv(x)
q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
k = k.softmax(dim=-1)
context = torch.einsum("bhdn,bhen->bhde", k, v)
out = torch.einsum("bhde,bhdn->bhen", context, q)
out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
return self.to_out(out)
class SpatialSelfAttention(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q.shape
q = rearrange(q, "b c h w -> b (h w) c")
k = rearrange(k, "b c h w -> b c (h w)")
w_ = torch.einsum("bij,bjk->bik", q, k)
w_ = w_ * (int(c) ** (-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
v = rearrange(v, "b c h w -> b c (h w)")
w_ = rearrange(w_, "b i j -> b j i")
h_ = torch.einsum("bij,bjk->bik", v, w_)
h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
h_ = self.proj_out(h_)
return x + h_
class CrossAttention(nn.Module):
def __init__(
self,
query_dim,
context_dim=None,
heads=8,
dim_head=64,
dropout=0.0,
backend=None,
):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.scale = dim_head**-0.5
self.heads = heads
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
self.backend = backend
def forward(
self,
x,
context=None,
mask=None,
additional_tokens=None,
n_times_crossframe_attn_in_self=0,
):
h = self.heads
if additional_tokens is not None:
# get the number of masked tokens at the beginning of the output sequence
n_tokens_to_mask = additional_tokens.shape[1]
# add additional token
x = torch.cat([additional_tokens, x], dim=1)
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
if n_times_crossframe_attn_in_self:
# reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
assert x.shape[0] % n_times_crossframe_attn_in_self == 0
n_cp = x.shape[0] // n_times_crossframe_attn_in_self
k = repeat(k[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp)
v = repeat(v[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp)
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
## old
"""
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
del q, k
if exists(mask):
mask = rearrange(mask, 'b ... -> b (...)')
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h=h)
sim.masked_fill_(~mask, max_neg_value)
# attention, what we cannot get enough of
sim = sim.softmax(dim=-1)
out = einsum('b i j, b j d -> b i d', sim, v)
"""
## new
with sdp_kernel(**BACKEND_MAP[self.backend]):
# print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask) # scale is dim_head ** -0.5 per default
del q, k, v
out = rearrange(out, "b h n d -> b n (h d)", h=h)
if additional_tokens is not None:
# remove additional token
out = out[:, n_tokens_to_mask:]
return self.to_out(out)
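# Cross-frame sketch (illustrative): with n_times_crossframe_attn_in_self = 2 and a
# batch of 4 flattened frames, keys/values are taken from frames 0 and 2 only and each
# is repeated n_cp = 2 times, so frames (0, 1) attend to frame 0's k/v and frames
# (2, 3) attend to frame 2's k/v, as in https://arxiv.org/abs/2303.13439.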
class MemoryEfficientCrossAttention(nn.Module):
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs):
super().__init__()
print(
f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
f"{heads} heads with a dimension of {dim_head}."
)
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.heads = heads
self.dim_head = dim_head
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
self.attention_op: Optional[Any] = None
def forward(
self,
x,
context=None,
mask=None,
additional_tokens=None,
n_times_crossframe_attn_in_self=0,
):
if additional_tokens is not None:
# get the number of masked tokens at the beginning of the output sequence
n_tokens_to_mask = additional_tokens.shape[1]
# add additional token
x = torch.cat([additional_tokens, x], dim=1)
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
if n_times_crossframe_attn_in_self:
# reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
assert x.shape[0] % n_times_crossframe_attn_in_self == 0
# n_cp = x.shape[0]//n_times_crossframe_attn_in_self
k = repeat(
k[::n_times_crossframe_attn_in_self],
"b ... -> (b n) ...",
n=n_times_crossframe_attn_in_self,
)
v = repeat(
v[::n_times_crossframe_attn_in_self],
"b ... -> (b n) ...",
n=n_times_crossframe_attn_in_self,
)
b, _, _ = q.shape
q, k, v = map(
lambda t: t.unsqueeze(3)
.reshape(b, t.shape[1], self.heads, self.dim_head)
.permute(0, 2, 1, 3)
.reshape(b * self.heads, t.shape[1], self.dim_head)
.contiguous(),
(q, k, v),
)
# actually compute the attention, what we cannot get enough of
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
# TODO: Use this directly in the attention operation, as a bias
if exists(mask):
raise NotImplementedError
out = (
out.unsqueeze(0)
.reshape(b, self.heads, out.shape[1], self.dim_head)
.permute(0, 2, 1, 3)
.reshape(b, out.shape[1], self.heads * self.dim_head)
)
if additional_tokens is not None:
# remove additional token
out = out[:, n_tokens_to_mask:]
return self.to_out(out)
class BasicTransformerBlock(nn.Module):
ATTENTION_MODES = {
"softmax": CrossAttention, # vanilla attention
"softmax-xformers": MemoryEfficientCrossAttention, # ampere
}
def __init__(
self,
dim,
n_heads,
d_head,
dropout=0.0,
context_dim=None,
gated_ff=True,
checkpoint=True,
disable_self_attn=False,
attn_mode="softmax",
sdp_backend=None,
):
super().__init__()
assert attn_mode in self.ATTENTION_MODES
if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
print(
f"Attention mode '{attn_mode}' is not available. Falling back to native attention. "
f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}"
)
attn_mode = "softmax"
elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
print("We do not support vanilla attention anymore, as it is too expensive. Sorry.")
if not XFORMERS_IS_AVAILABLE:
assert False, "Please install xformers via e.g. 'pip install xformers==0.0.16'"
else:
print("Falling back to xformers efficient attention.")
attn_mode = "softmax-xformers"
attn_cls = self.ATTENTION_MODES[attn_mode]
if version.parse(torch.__version__) >= version.parse("2.0.0"):
assert sdp_backend is None or isinstance(sdp_backend, SDPBackend)
else:
assert sdp_backend is None
self.disable_self_attn = disable_self_attn
self.attn1 = attn_cls(
query_dim=dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
context_dim=context_dim if self.disable_self_attn else None,
backend=sdp_backend,
) # is a self-attention if not self.disable_self_attn
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.attn2 = attn_cls(
query_dim=dim,
context_dim=context_dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
backend=sdp_backend,
) # is self-attn if context is none
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
self.checkpoint = checkpoint
if self.checkpoint:
print(f"{self.__class__.__name__} is using checkpointing")
def forward(self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
kwargs = {"x": x}
if context is not None:
kwargs.update({"context": context})
if additional_tokens is not None:
kwargs.update({"additional_tokens": additional_tokens})
if n_times_crossframe_attn_in_self:
kwargs.update({"n_times_crossframe_attn_in_self": n_times_crossframe_attn_in_self})
# return mixed_checkpoint(self._forward, kwargs, self.parameters(), self.checkpoint)
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
def _forward(self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
x = (
self.attn1(
self.norm1(x),
context=context if self.disable_self_attn else None,
additional_tokens=additional_tokens,
n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self if not self.disable_self_attn else 0,
)
+ x
)
x = self.attn2(self.norm2(x), context=context, additional_tokens=additional_tokens) + x
x = self.ff(self.norm3(x)) + x
return x
class BasicTransformerSingleLayerBlock(nn.Module):
ATTENTION_MODES = {
"softmax": CrossAttention, # vanilla attention
"softmax-xformers": MemoryEfficientCrossAttention, # on the A100s not quite as fast as the above version
# (todo might depend on head_dim, check, falls back to semi-optimized kernels for dim!=[16,32,64,128])
}
def __init__(
self,
dim,
n_heads,
d_head,
dropout=0.0,
context_dim=None,
gated_ff=True,
checkpoint=True,
attn_mode="softmax",
):
super().__init__()
assert attn_mode in self.ATTENTION_MODES
attn_cls = self.ATTENTION_MODES[attn_mode]
self.attn1 = attn_cls(
query_dim=dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
context_dim=context_dim,
)
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.checkpoint = checkpoint
def forward(self, x, context=None):
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
def _forward(self, x, context=None):
x = self.attn1(self.norm1(x), context=context) + x
x = self.ff(self.norm2(x)) + x
return x
class SpatialTransformer(nn.Module):
"""
Transformer block for image-like data.
First, project the input (aka embedding)
and reshape to b, t, d.
Then apply standard transformer action.
Finally, reshape to image
NEW: use_linear for more efficiency instead of the 1x1 convs
"""
def __init__(
self,
in_channels,
n_heads,
d_head,
depth=1,
dropout=0.0,
context_dim=None,
disable_self_attn=False,
use_linear=False,
attn_type="softmax",
use_checkpoint=True,
# sdp_backend=SDPBackend.FLASH_ATTENTION
sdp_backend=None,
):
super().__init__()
print(f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads")
from omegaconf import ListConfig
if exists(context_dim) and not isinstance(context_dim, (list, ListConfig)):
context_dim = [context_dim]
if exists(context_dim) and isinstance(context_dim, list):
if depth != len(context_dim):
print(
f"WARNING: {self.__class__.__name__}: Found context dims {context_dim} of depth {len(context_dim)}, "
f"which does not match the specified 'depth' of {depth}. Setting context_dim to {depth * [context_dim[0]]} now."
)
# depth does not match context dims.
assert all(
map(lambda x: x == context_dim[0], context_dim)
), "need homogenous context_dim to match depth automatically"
context_dim = depth * [context_dim[0]]
elif context_dim is None:
context_dim = [None] * depth
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = Normalize(in_channels)
if not use_linear:
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
else:
self.proj_in = nn.Linear(in_channels, inner_dim)
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
inner_dim,
n_heads,
d_head,
dropout=dropout,
context_dim=context_dim[d],
disable_self_attn=disable_self_attn,
attn_mode=attn_type,
checkpoint=use_checkpoint,
sdp_backend=sdp_backend,
)
for d in range(depth)
]
)
if not use_linear:
self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
else:
# self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
self.use_linear = use_linear
def forward(self, x, context=None):
# note: if no context is given, cross-attention defaults to self-attention
if not isinstance(context, list):
context = [context]
b, c, h, w = x.shape
x_in = x
x = self.norm(x)
if not self.use_linear:
x = self.proj_in(x)
x = rearrange(x, "b c h w -> b (h w) c").contiguous()
if self.use_linear:
x = self.proj_in(x)
for i, block in enumerate(self.transformer_blocks):
if i > 0 and len(context) == 1:
i = 0 # use same context for each block
x = block(x, context=context[i])
if self.use_linear:
x = self.proj_out(x)
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
if not self.use_linear:
x = self.proj_out(x)
return x + x_in
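# Shape sketch (illustrative): a (b, c, h, w) feature map is normalized, projected to
# inner_dim = n_heads * d_head, flattened to (b, h*w, inner_dim) tokens for the
# transformer blocks, projected back to c channels, reshaped to (b, c, h, w), and
# added to the residual input x_in.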
import logging
import math
import re
import random
from abc import abstractmethod
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import pytorch_lightning as pl
import torch
import torch.distributed
import torch.nn as nn
from einops import rearrange
from packaging import version
from vae_modules.ema import LitEma
from sgm.util import (
instantiate_from_config,
get_obj_from_str,
default,
is_context_parallel_initialized,
initialize_context_parallel,
get_context_parallel_group,
get_context_parallel_group_rank,
)
from vae_modules.cp_enc_dec import _conv_split, _conv_gather
logpy = logging.getLogger(__name__)
class AbstractAutoencoder(pl.LightningModule):
"""
This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,
unCLIP models, etc. Hence, it is fairly general, and specific features
(e.g. discriminator training, encoding, decoding) must be implemented in subclasses.
"""
def __init__(
self,
ema_decay: Union[None, float] = None,
monitor: Union[None, str] = None,
input_key: str = "jpg",
):
super().__init__()
self.input_key = input_key
self.use_ema = ema_decay is not None
if monitor is not None:
self.monitor = monitor
if self.use_ema:
self.model_ema = LitEma(self, decay=ema_decay)
logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
if version.parse(torch.__version__) >= version.parse("2.0.0"):
self.automatic_optimization = False
# def apply_ckpt(self, ckpt: Union[None, str, dict]):
# if ckpt is None:
# return
# if isinstance(ckpt, str):
# ckpt = {
# "target": "sgm.modules.checkpoint.CheckpointEngine",
# "params": {"ckpt_path": ckpt},
# }
# engine = instantiate_from_config(ckpt)
# engine(self)
def apply_ckpt(self, ckpt: Union[None, str, dict]):
if ckpt is None:
return
self.init_from_ckpt(ckpt)
def init_from_ckpt(self, path, ignore_keys=list()):
sd = torch.load(path, map_location="cpu")["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
missing_keys, unexpected_keys = self.load_state_dict(sd, strict=False)
print("Missing keys: ", missing_keys)
print("Unexpected keys: ", unexpected_keys)
print(f"Restored from {path}")
@abstractmethod
def get_input(self, batch) -> Any:
raise NotImplementedError()
def on_train_batch_end(self, *args, **kwargs):
# for EMA computation
if self.use_ema:
self.model_ema(self)
@contextmanager
def ema_scope(self, context=None):
if self.use_ema:
self.model_ema.store(self.parameters())
self.model_ema.copy_to(self)
if context is not None:
logpy.info(f"{context}: Switched to EMA weights")
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.parameters())
if context is not None:
logpy.info(f"{context}: Restored training weights")
@abstractmethod
def encode(self, *args, **kwargs) -> torch.Tensor:
raise NotImplementedError("encode()-method of abstract base class called")
@abstractmethod
def decode(self, *args, **kwargs) -> torch.Tensor:
raise NotImplementedError("decode()-method of abstract base class called")
def instantiate_optimizer_from_config(self, params, lr, cfg):
logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
return get_obj_from_str(cfg["target"])(params, lr=lr, **cfg.get("params", dict()))
def configure_optimizers(self) -> Any:
raise NotImplementedError()
class AutoencodingEngine(AbstractAutoencoder):
"""
Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL
(we also restore them explicitly as special cases for legacy reasons).
Regularizations such as KL or VQ are moved to the regularizer class.
"""
def __init__(
self,
*args,
encoder_config: Dict,
decoder_config: Dict,
loss_config: Dict,
regularizer_config: Dict,
optimizer_config: Union[Dict, None] = None,
lr_g_factor: float = 1.0,
trainable_ae_params: Optional[List[List[str]]] = None,
ae_optimizer_args: Optional[List[dict]] = None,
trainable_disc_params: Optional[List[List[str]]] = None,
disc_optimizer_args: Optional[List[dict]] = None,
disc_start_iter: int = 0,
diff_boost_factor: float = 3.0,
ckpt_engine: Union[None, str, dict] = None,
ckpt_path: Optional[str] = None,
additional_decode_keys: Optional[List[str]] = None,
**kwargs,
):
super().__init__(*args, **kwargs)
self.automatic_optimization = False # pytorch lightning
self.encoder = instantiate_from_config(encoder_config)
self.decoder = instantiate_from_config(decoder_config)
self.loss = instantiate_from_config(loss_config)
self.regularization = instantiate_from_config(regularizer_config)
self.optimizer_config = default(optimizer_config, {"target": "torch.optim.Adam"})
self.diff_boost_factor = diff_boost_factor
self.disc_start_iter = disc_start_iter
self.lr_g_factor = lr_g_factor
self.trainable_ae_params = trainable_ae_params
if self.trainable_ae_params is not None:
self.ae_optimizer_args = default(
ae_optimizer_args,
[{} for _ in range(len(self.trainable_ae_params))],
)
assert len(self.ae_optimizer_args) == len(self.trainable_ae_params)
else:
self.ae_optimizer_args = [{}] # makes type consistent
self.trainable_disc_params = trainable_disc_params
if self.trainable_disc_params is not None:
self.disc_optimizer_args = default(
disc_optimizer_args,
[{} for _ in range(len(self.trainable_disc_params))],
)
assert len(self.disc_optimizer_args) == len(self.trainable_disc_params)
else:
self.disc_optimizer_args = [{}] # makes type consistent
if ckpt_path is not None:
assert ckpt_engine is None, "Can't set ckpt_engine and ckpt_path"
logpy.warn("Checkpoint path is deprecated, use `checkpoint_egnine` instead")
self.apply_ckpt(default(ckpt_path, ckpt_engine))
self.additional_decode_keys = set(default(additional_decode_keys, []))
def get_input(self, batch: Dict) -> torch.Tensor:
# assuming unified data format, dataloader returns a dict.
# image tensors should be scaled to -1 ... 1 and in channels-first
# format (e.g., bchw instead of bhwc)
return batch[self.input_key]
def get_autoencoder_params(self) -> list:
params = []
if hasattr(self.loss, "get_trainable_autoencoder_parameters"):
params += list(self.loss.get_trainable_autoencoder_parameters())
if hasattr(self.regularization, "get_trainable_parameters"):
params += list(self.regularization.get_trainable_parameters())
params = params + list(self.encoder.parameters())
params = params + list(self.decoder.parameters())
return params
def get_discriminator_params(self) -> list:
if hasattr(self.loss, "get_trainable_parameters"):
params = list(self.loss.get_trainable_parameters()) # e.g., discriminator
else:
params = []
return params
def get_last_layer(self):
return self.decoder.get_last_layer()
def encode(
self,
x: torch.Tensor,
return_reg_log: bool = False,
unregularized: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
z = self.encoder(x)
if unregularized:
return z, dict()
z, reg_log = self.regularization(z)
if return_reg_log:
return z, reg_log
return z
def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
x = self.decoder(z, **kwargs)
return x
def forward(self, x: torch.Tensor, **additional_decode_kwargs) -> Tuple[torch.Tensor, torch.Tensor, dict]:
z, reg_log = self.encode(x, return_reg_log=True)
dec = self.decode(z, **additional_decode_kwargs)
return z, dec, reg_log
def inner_training_step(self, batch: dict, batch_idx: int, optimizer_idx: int = 0) -> torch.Tensor:
x = self.get_input(batch)
additional_decode_kwargs = {key: batch[key] for key in self.additional_decode_keys.intersection(batch)}
z, xrec, regularization_log = self(x, **additional_decode_kwargs)
if hasattr(self.loss, "forward_keys"):
extra_info = {
"z": z,
"optimizer_idx": optimizer_idx,
"global_step": self.global_step,
"last_layer": self.get_last_layer(),
"split": "train",
"regularization_log": regularization_log,
"autoencoder": self,
}
extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
else:
extra_info = dict()
if optimizer_idx == 0:
# autoencode
out_loss = self.loss(x, xrec, **extra_info)
if isinstance(out_loss, tuple):
aeloss, log_dict_ae = out_loss
else:
# simple loss function
aeloss = out_loss
log_dict_ae = {"train/loss/rec": aeloss.detach()}
self.log_dict(
log_dict_ae,
prog_bar=False,
logger=True,
on_step=True,
on_epoch=True,
sync_dist=False,
)
self.log(
"loss",
aeloss.mean().detach(),
prog_bar=True,
logger=False,
on_epoch=False,
on_step=True,
)
return aeloss
elif optimizer_idx == 1:
# discriminator
discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
# -> discriminator always needs to return a tuple
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
return discloss
else:
raise NotImplementedError(f"Unknown optimizer {optimizer_idx}")
def training_step(self, batch: dict, batch_idx: int):
opts = self.optimizers()
if not isinstance(opts, list):
# Non-adversarial case
opts = [opts]
optimizer_idx = batch_idx % len(opts)
if self.global_step < self.disc_start_iter:
optimizer_idx = 0
opt = opts[optimizer_idx]
opt.zero_grad()
with opt.toggle_model():
loss = self.inner_training_step(batch, batch_idx, optimizer_idx=optimizer_idx)
self.manual_backward(loss)
opt.step()
def validation_step(self, batch: dict, batch_idx: int) -> Dict:
log_dict = self._validation_step(batch, batch_idx)
with self.ema_scope():
log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
log_dict.update(log_dict_ema)
return log_dict
def _validation_step(self, batch: dict, batch_idx: int, postfix: str = "") -> Dict:
x = self.get_input(batch)
z, xrec, regularization_log = self(x)
if hasattr(self.loss, "forward_keys"):
extra_info = {
"z": z,
"optimizer_idx": 0,
"global_step": self.global_step,
"last_layer": self.get_last_layer(),
"split": "val" + postfix,
"regularization_log": regularization_log,
"autoencoder": self,
}
extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
else:
extra_info = dict()
out_loss = self.loss(x, xrec, **extra_info)
if isinstance(out_loss, tuple):
aeloss, log_dict_ae = out_loss
else:
# simple loss function
aeloss = out_loss
log_dict_ae = {f"val{postfix}/loss/rec": aeloss.detach()}
full_log_dict = log_dict_ae
if "optimizer_idx" in extra_info:
extra_info["optimizer_idx"] = 1
discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
full_log_dict.update(log_dict_disc)
self.log(
f"val{postfix}/loss/rec",
log_dict_ae[f"val{postfix}/loss/rec"],
sync_dist=True,
)
self.log_dict(full_log_dict, sync_dist=True)
return full_log_dict
def get_param_groups(
self, parameter_names: List[List[str]], optimizer_args: List[dict]
) -> Tuple[List[Dict[str, Any]], int]:
groups = []
num_params = 0
for names, args in zip(parameter_names, optimizer_args):
params = []
for pattern_ in names:
pattern_params = []
pattern = re.compile(pattern_)
for p_name, param in self.named_parameters():
if re.match(pattern, p_name):
pattern_params.append(param)
num_params += param.numel()
if len(pattern_params) == 0:
logpy.warn(f"Did not find parameters for pattern {pattern_}")
params.extend(pattern_params)
groups.append({"params": params, **args})
return groups, num_params
def configure_optimizers(self) -> List[torch.optim.Optimizer]:
if self.trainable_ae_params is None:
ae_params = self.get_autoencoder_params()
else:
ae_params, num_ae_params = self.get_param_groups(self.trainable_ae_params, self.ae_optimizer_args)
logpy.info(f"Number of trainable autoencoder parameters: {num_ae_params:,}")
if self.trainable_disc_params is None:
disc_params = self.get_discriminator_params()
else:
disc_params, num_disc_params = self.get_param_groups(self.trainable_disc_params, self.disc_optimizer_args)
logpy.info(f"Number of trainable discriminator parameters: {num_disc_params:,}")
opt_ae = self.instantiate_optimizer_from_config(
ae_params,
default(self.lr_g_factor, 1.0) * self.learning_rate,
self.optimizer_config,
)
opts = [opt_ae]
if len(disc_params) > 0:
opt_disc = self.instantiate_optimizer_from_config(disc_params, self.learning_rate, self.optimizer_config)
opts.append(opt_disc)
return opts
@torch.no_grad()
def log_images(self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs) -> dict:
log = dict()
additional_decode_kwargs = {}
x = self.get_input(batch)
additional_decode_kwargs.update({key: batch[key] for key in self.additional_decode_keys.intersection(batch)})
_, xrec, _ = self(x, **additional_decode_kwargs)
log["inputs"] = x
log["reconstructions"] = xrec
diff = 0.5 * torch.abs(torch.clamp(xrec, -1.0, 1.0) - x)
diff.clamp_(0, 1.0)
log["diff"] = 2.0 * diff - 1.0
# diff_boost shows location of small errors, by boosting their
# brightness.
log["diff_boost"] = 2.0 * torch.clamp(self.diff_boost_factor * diff, 0.0, 1.0) - 1
if hasattr(self.loss, "log_images"):
log.update(self.loss.log_images(x, xrec))
with self.ema_scope():
_, xrec_ema, _ = self(x, **additional_decode_kwargs)
log["reconstructions_ema"] = xrec_ema
diff_ema = 0.5 * torch.abs(torch.clamp(xrec_ema, -1.0, 1.0) - x)
diff_ema.clamp_(0, 1.0)
log["diff_ema"] = 2.0 * diff_ema - 1.0
log["diff_boost_ema"] = 2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1
if additional_log_kwargs:
additional_decode_kwargs.update(additional_log_kwargs)
_, xrec_add, _ = self(x, **additional_decode_kwargs)
log_str = "reconstructions-" + "-".join(
[f"{key}={additional_log_kwargs[key]}" for key in additional_log_kwargs]
)
log[log_str] = xrec_add
return log
class AutoencodingEngineLegacy(AutoencodingEngine):
def __init__(self, embed_dim: int, **kwargs):
self.max_batch_size = kwargs.pop("max_batch_size", None)
ddconfig = kwargs.pop("ddconfig")
ckpt_path = kwargs.pop("ckpt_path", None)
ckpt_engine = kwargs.pop("ckpt_engine", None)
super().__init__(
encoder_config={
"target": "sgm.modules.diffusionmodules.model.Encoder",
"params": ddconfig,
},
decoder_config={
"target": "sgm.modules.diffusionmodules.model.Decoder",
"params": ddconfig,
},
**kwargs,
)
self.quant_conv = torch.nn.Conv2d(
(1 + ddconfig["double_z"]) * ddconfig["z_channels"],
(1 + ddconfig["double_z"]) * embed_dim,
1,
)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
self.embed_dim = embed_dim
self.apply_ckpt(default(ckpt_path, ckpt_engine))
def get_autoencoder_params(self) -> list:
params = super().get_autoencoder_params()
return params
def encode(self, x: torch.Tensor, return_reg_log: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
if self.max_batch_size is None:
z = self.encoder(x)
z = self.quant_conv(z)
else:
N = x.shape[0]
bs = self.max_batch_size
n_batches = int(math.ceil(N / bs))
z = list()
for i_batch in range(n_batches):
z_batch = self.encoder(x[i_batch * bs : (i_batch + 1) * bs])
z_batch = self.quant_conv(z_batch)
z.append(z_batch)
z = torch.cat(z, 0)
z, reg_log = self.regularization(z)
if return_reg_log:
return z, reg_log
return z
def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
if self.max_batch_size is None:
dec = self.post_quant_conv(z)
dec = self.decoder(dec, **decoder_kwargs)
else:
N = z.shape[0]
bs = self.max_batch_size
n_batches = int(math.ceil(N / bs))
dec = list()
for i_batch in range(n_batches):
dec_batch = self.post_quant_conv(z[i_batch * bs : (i_batch + 1) * bs])
dec_batch = self.decoder(dec_batch, **decoder_kwargs)
dec.append(dec_batch)
dec = torch.cat(dec, 0)
return dec
class AutoencoderKL(AutoencodingEngineLegacy):
def __init__(self, **kwargs):
if "lossconfig" in kwargs:
kwargs["loss_config"] = kwargs.pop("lossconfig")
super().__init__(
regularizer_config={"target": ("sgm.modules.autoencoding.regularizers" ".DiagonalGaussianRegularizer")},
**kwargs,
)
class IdentityFirstStage(AbstractAutoencoder):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def get_input(self, x: Any) -> Any:
return x
def encode(self, x: Any, *args, **kwargs) -> Any:
return x
def decode(self, x: Any, *args, **kwargs) -> Any:
return x
class VideoAutoencodingEngine(AutoencodingEngine):
def __init__(
self,
ckpt_path: Union[None, str] = None,
ignore_keys: Union[Tuple, list] = (),
image_video_weights=[1, 1],
only_train_decoder=False,
context_parallel_size=0,
**kwargs,
):
super().__init__(**kwargs)
self.context_parallel_size = context_parallel_size
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
def log_videos(self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs) -> dict:
return self.log_images(batch, additional_log_kwargs, **kwargs)
def get_input(self, batch: dict) -> torch.Tensor:
if self.context_parallel_size > 0:
if not is_context_parallel_initialized():
initialize_context_parallel(self.context_parallel_size)
batch = batch[self.input_key]
global_src_rank = get_context_parallel_group_rank() * self.context_parallel_size
torch.distributed.broadcast(batch, src=global_src_rank, group=get_context_parallel_group())
batch = _conv_split(batch, dim=2, kernel_size=1)
return batch
return batch[self.input_key]
def apply_ckpt(self, ckpt: Union[None, str, dict]):
if ckpt is None:
return
self.init_from_ckpt(ckpt)
def init_from_ckpt(self, path, ignore_keys=list()):
sd = torch.load(path, map_location="cpu")["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
missing_keys, unexpected_keys = self.load_state_dict(sd, strict=False)
print("Missing keys: ", missing_keys)
print("Unexpected keys: ", unexpected_keys)
print(f"Restored from {path}")
class VideoAutoencoderInferenceWrapper(VideoAutoencodingEngine):
def __init__(
self,
cp_size=0,
*args,
**kwargs,
):
self.cp_size = cp_size
super().__init__(*args, **kwargs)
def encode(
self,
x: torch.Tensor,
return_reg_log: bool = False,
unregularized: bool = False,
input_cp: bool = False,
output_cp: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
if self.cp_size > 0 and not input_cp:
if not is_context_parallel_initialized():
initialize_context_parallel(self.cp_size)
global_src_rank = get_context_parallel_group_rank() * self.cp_size
torch.distributed.broadcast(x, src=global_src_rank, group=get_context_parallel_group())
x = _conv_split(x, dim=2, kernel_size=1)
if return_reg_log:
z, reg_log = super().encode(x, return_reg_log, unregularized)
else:
z = super().encode(x, return_reg_log, unregularized)
if self.cp_size > 0 and not output_cp:
z = _conv_gather(z, dim=2, kernel_size=1)
if return_reg_log:
return z, reg_log
return z
def decode(
self,
z: torch.Tensor,
input_cp: bool = False,
output_cp: bool = False,
split_kernel_size: int = 1,
**kwargs,
):
if self.cp_size > 0 and not input_cp:
if not is_context_parallel_initialized():
initialize_context_parallel(self.cp_size)
global_src_rank = get_context_parallel_group_rank() * self.cp_size
torch.distributed.broadcast(z, src=global_src_rank, group=get_context_parallel_group())
z = _conv_split(z, dim=2, kernel_size=split_kernel_size)
x = super().decode(z, **kwargs)
if self.cp_size > 0 and not output_cp:
x = _conv_gather(x, dim=2, kernel_size=split_kernel_size)
return x
def forward(
self,
x: torch.Tensor,
input_cp: bool = False,
latent_cp: bool = False,
output_cp: bool = False,
**additional_decode_kwargs,
) -> Tuple[torch.Tensor, torch.Tensor, dict]:
z, reg_log = self.encode(x, return_reg_log=True, input_cp=input_cp, output_cp=latent_cp)
dec = self.decode(z, input_cp=latent_cp, output_cp=output_cp, **additional_decode_kwargs)
return z, dec, reg_log
import math
import torch
import torch.distributed
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from beartype import beartype
from beartype.typing import Union, Tuple, Optional, List
from einops import rearrange
from sgm.util import (
get_context_parallel_group,
get_context_parallel_rank,
get_context_parallel_world_size,
get_context_parallel_group_rank,
)
# try:
from vae_modules.utils import SafeConv3d as Conv3d
# except:
# # Degrade to normal Conv3d if SafeConv3d is not available
# from torch.nn import Conv3d
def cast_tuple(t, length=1):
return t if isinstance(t, tuple) else ((t,) * length)
def divisible_by(num, den):
return (num % den) == 0
def is_odd(n):
return not divisible_by(n, 2)
def exists(v):
return v is not None
def pair(t):
return t if isinstance(t, tuple) else (t, t)
def get_timestep_embedding(timesteps, embedding_dim):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models:
From Fairseq.
Build sinusoidal embeddings.
This matches the implementation in tensor2tensor, but differs slightly
from the description in Section 3.5 of "Attention Is All You Need".
"""
assert len(timesteps.shape) == 1
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
emb = emb.to(device=timesteps.device)
emb = timesteps.float()[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
return emb
def nonlinearity(x):
# swish
return x * torch.sigmoid(x)
def leaky_relu(p=0.1):
return nn.LeakyReLU(p)
def _split(input_, dim):
cp_world_size = get_context_parallel_world_size()
if cp_world_size == 1:
return input_
cp_rank = get_context_parallel_rank()
# print('in _split, cp_rank:', cp_rank, 'input_size:', input_.shape)
input_first_frame_ = input_.transpose(0, dim)[:1].transpose(0, dim).contiguous()
input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous()
dim_size = input_.size()[dim] // cp_world_size
input_list = torch.split(input_, dim_size, dim=dim)
output = input_list[cp_rank]
if cp_rank == 0:
output = torch.cat([input_first_frame_, output], dim=dim)
output = output.contiguous()
# print('out _split, cp_rank:', cp_rank, 'output_size:', output.shape)
return output
def _gather(input_, dim):
cp_world_size = get_context_parallel_world_size()
# Bypass the function if context parallel is 1
if cp_world_size == 1:
return input_
group = get_context_parallel_group()
cp_rank = get_context_parallel_rank()
# print('in _gather, cp_rank:', cp_rank, 'input_size:', input_.shape)
input_first_frame_ = input_.transpose(0, dim)[:1].transpose(0, dim).contiguous()
if cp_rank == 0:
input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous()
tensor_list = [torch.empty_like(torch.cat([input_first_frame_, input_], dim=dim))] + [
torch.empty_like(input_) for _ in range(cp_world_size - 1)
]
if cp_rank == 0:
input_ = torch.cat([input_first_frame_, input_], dim=dim)
tensor_list[cp_rank] = input_
torch.distributed.all_gather(tensor_list, input_, group=group)
output = torch.cat(tensor_list, dim=dim).contiguous()
# print('out _gather, cp_rank:', cp_rank, 'output_size:', output.shape)
return output
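# Splits `input_` along `dim` (the temporal axis) across the context-parallel
# group. Rank 0 keeps the leading `kernel_size` frames in addition to its chunk;
# every other rank takes a chunk offset by `kernel_size`, so the chunks tile the
# sequence without overlap (assuming the length minus kernel_size divides evenly
# across ranks).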
def _conv_split(input_, dim, kernel_size):
cp_world_size = get_context_parallel_world_size()
# Bypass the function if context parallel is 1
if cp_world_size == 1:
return input_
# print('in _conv_split, cp_rank:', cp_rank, 'input_size:', input_.shape)
cp_rank = get_context_parallel_rank()
dim_size = (input_.size()[dim] - kernel_size) // cp_world_size
if cp_rank == 0:
output = input_.transpose(dim, 0)[: dim_size + kernel_size].transpose(dim, 0)
else:
# output = input_.transpose(dim, 0)[cp_rank * dim_size + 1:(cp_rank + 1) * dim_size + kernel_size].transpose(dim, 0)
output = input_.transpose(dim, 0)[
cp_rank * dim_size + kernel_size : (cp_rank + 1) * dim_size + kernel_size
].transpose(dim, 0)
output = output.contiguous()
# print('out _conv_split, cp_rank:', cp_rank, 'input_size:', output.shape)
return output
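# Inverse of _conv_split: all-gathers the per-rank temporal chunks (dropping the
# first kernel_size - 1 frames of each non-zero rank's local chunk) and
# concatenates them back into the full sequence on every rank.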
def _conv_gather(input_, dim, kernel_size):
cp_world_size = get_context_parallel_world_size()
# Bypass the function if context parallel is 1
if cp_world_size == 1:
return input_
group = get_context_parallel_group()
cp_rank = get_context_parallel_rank()
# print('in _conv_gather, cp_rank:', cp_rank, 'input_size:', input_.shape)
input_first_kernel_ = input_.transpose(0, dim)[:kernel_size].transpose(0, dim).contiguous()
if cp_rank == 0:
input_ = input_.transpose(0, dim)[kernel_size:].transpose(0, dim).contiguous()
else:
input_ = input_.transpose(0, dim)[max(kernel_size - 1, 0) :].transpose(0, dim).contiguous()
tensor_list = [torch.empty_like(torch.cat([input_first_kernel_, input_], dim=dim))] + [
torch.empty_like(input_) for _ in range(cp_world_size - 1)
]
if cp_rank == 0:
input_ = torch.cat([input_first_kernel_, input_], dim=dim)
tensor_list[cp_rank] = input_
torch.distributed.all_gather(tensor_list, input_, group=group)
# Note: torch.cat already creates a contiguous tensor.
output = torch.cat(tensor_list, dim=dim).contiguous()
# print('out _conv_gather, cp_rank:', cp_rank, 'input_size:', output.shape)
return output
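# Halo exchange for a causal temporal convolution under context parallelism:
# every rank except the last asynchronously sends its trailing kernel_size - 1
# frames to the next rank, every rank except rank 0 prepends the frames received
# from the previous rank, and rank 0 pads by replicating its first frame
# kernel_size - 1 times.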
def _pass_from_previous_rank(input_, dim, kernel_size):
# Bypass the function if kernel size is 1
if kernel_size == 1:
return input_
group = get_context_parallel_group()
cp_rank = get_context_parallel_rank()
cp_group_rank = get_context_parallel_group_rank()
cp_world_size = get_context_parallel_world_size()
# print('in _pass_from_previous_rank, cp_rank:', cp_rank, 'input_size:', input_.shape)
global_rank = torch.distributed.get_rank()
global_world_size = torch.distributed.get_world_size()
input_ = input_.transpose(0, dim)
# pass from last rank
send_rank = global_rank + 1
recv_rank = global_rank - 1
if send_rank % cp_world_size == 0:
send_rank -= cp_world_size
if recv_rank % cp_world_size == cp_world_size - 1:
recv_rank += cp_world_size
if cp_rank < cp_world_size - 1:
req_send = torch.distributed.isend(input_[-kernel_size + 1 :].contiguous(), send_rank, group=group)
if cp_rank > 0:
recv_buffer = torch.empty_like(input_[-kernel_size + 1 :]).contiguous()
req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group)
if cp_rank == 0:
input_ = torch.cat([input_[:1]] * (kernel_size - 1) + [input_], dim=0)
else:
req_recv.wait()
input_ = torch.cat([recv_buffer, input_], dim=0)
input_ = input_.transpose(0, dim).contiguous()
# print('out _pass_from_previous_rank, cp_rank:', cp_rank, 'input_size:', input_.shape)
return input_
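# Variant of _pass_from_previous_rank used for chunked ("fake" context-parallel)
# inference: rank 0 can prepend `cache_padding` (the trailing frames cached from
# the previous chunk) instead of replicating its first frame, so consecutive
# chunks stay temporally causal across calls.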
def _fake_cp_pass_from_previous_rank(input_, dim, kernel_size, cache_padding=None):
# Bypass the function if kernel size is 1
if kernel_size == 1:
return input_
group = get_context_parallel_group()
cp_rank = get_context_parallel_rank()
cp_group_rank = get_context_parallel_group_rank()
cp_world_size = get_context_parallel_world_size()
# print('in _pass_from_previous_rank, cp_rank:', cp_rank, 'input_size:', input_.shape)
global_rank = torch.distributed.get_rank()
global_world_size = torch.distributed.get_world_size()
input_ = input_.transpose(0, dim)
# pass from last rank
send_rank = global_rank + 1
recv_rank = global_rank - 1
if send_rank % cp_world_size == 0:
send_rank -= cp_world_size
if recv_rank % cp_world_size == cp_world_size - 1:
recv_rank += cp_world_size
# req_send = torch.distributed.isend(input_[-kernel_size + 1:].contiguous(), send_rank, group=group)
# recv_buffer = torch.empty_like(input_[-kernel_size + 1:]).contiguous()
# req_recv = torch.distributed.recv(recv_buffer, recv_rank, group=group)
# req_recv.wait()
recv_buffer = torch.empty_like(input_[-kernel_size + 1 :]).contiguous()
if cp_rank < cp_world_size - 1:
req_send = torch.distributed.isend(input_[-kernel_size + 1 :].contiguous(), send_rank, group=group)
if cp_rank > 0:
req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group)
# req_send = torch.distributed.isend(input_[-kernel_size + 1:].contiguous(), send_rank, group=group)
# req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group)
if cp_rank == 0:
if cache_padding is not None:
input_ = torch.cat([cache_padding.transpose(0, dim).to(input_.device), input_], dim=0)
else:
input_ = torch.cat([input_[:1]] * (kernel_size - 1) + [input_], dim=0)
else:
req_recv.wait()
input_ = torch.cat([recv_buffer, input_], dim=0)
input_ = input_.transpose(0, dim).contiguous()
return input_
def _drop_from_previous_rank(input_, dim, kernel_size):
input_ = input_.transpose(0, dim)[kernel_size - 1 :].transpose(0, dim)
return input_
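# Autograd wrappers: scattering to the context-parallel region backpropagates as
# a gather (and vice versa), and the pass-from-previous-rank op backpropagates by
# simply dropping the gradient of the frames that were prepended from the
# previous rank.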
class _ConvolutionScatterToContextParallelRegion(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, dim, kernel_size):
ctx.dim = dim
ctx.kernel_size = kernel_size
return _conv_split(input_, dim, kernel_size)
@staticmethod
def backward(ctx, grad_output):
return _conv_gather(grad_output, ctx.dim, ctx.kernel_size), None, None
class _ConvolutionGatherFromContextParallelRegion(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, dim, kernel_size):
ctx.dim = dim
ctx.kernel_size = kernel_size
return _conv_gather(input_, dim, kernel_size)
@staticmethod
def backward(ctx, grad_output):
return _conv_split(grad_output, ctx.dim, ctx.kernel_size), None, None
class _ConvolutionPassFromPreviousRank(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, dim, kernel_size):
ctx.dim = dim
ctx.kernel_size = kernel_size
return _pass_from_previous_rank(input_, dim, kernel_size)
@staticmethod
def backward(ctx, grad_output):
return _drop_from_previous_rank(grad_output, ctx.dim, ctx.kernel_size), None, None
class _FakeCPConvolutionPassFromPreviousRank(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, dim, kernel_size, cache_padding):
ctx.dim = dim
ctx.kernel_size = kernel_size
return _fake_cp_pass_from_previous_rank(input_, dim, kernel_size, cache_padding)
@staticmethod
def backward(ctx, grad_output):
return _drop_from_previous_rank(grad_output, ctx.dim, ctx.kernel_size), None, None, None
def conv_scatter_to_context_parallel_region(input_, dim, kernel_size):
return _ConvolutionScatterToContextParallelRegion.apply(input_, dim, kernel_size)
def conv_gather_from_context_parallel_region(input_, dim, kernel_size):
return _ConvolutionGatherFromContextParallelRegion.apply(input_, dim, kernel_size)
def conv_pass_from_last_rank(input_, dim, kernel_size):
return _ConvolutionPassFromPreviousRank.apply(input_, dim, kernel_size)
def fake_cp_pass_from_previous_rank(input_, dim, kernel_size, cache_padding):
return _FakeCPConvolutionPassFromPreviousRank.apply(input_, dim, kernel_size, cache_padding)
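# Causal 3D convolution: the temporal dimension is padded on the left with
# time_kernel_size - 1 frames taken from the previous rank / cached chunk (or a
# replicated first frame), height and width are zero-padded symmetrically, and
# the result is fed through SafeConv3d. With clear_cache=False the trailing
# frames are kept locally (or sent from the last rank to the first rank of the
# group) as cache_padding for the next chunk.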
class ContextParallelCausalConv3d(nn.Module):
def __init__(self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], stride=1, **kwargs):
super().__init__()
kernel_size = cast_tuple(kernel_size, 3)
time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
assert is_odd(height_kernel_size) and is_odd(width_kernel_size)
time_pad = time_kernel_size - 1
height_pad = height_kernel_size // 2
width_pad = width_kernel_size // 2
self.height_pad = height_pad
self.width_pad = width_pad
self.time_pad = time_pad
self.time_kernel_size = time_kernel_size
self.temporal_dim = 2
stride = (stride, stride, stride)
dilation = (1, 1, 1)
self.conv = Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
self.cache_padding = None
def forward(self, input_, clear_cache=True):
# if input_.shape[2] == 1: # handle image
# # first frame padding
# input_parallel = torch.cat([input_] * self.time_kernel_size, dim=2)
# else:
# input_parallel = conv_pass_from_last_rank(input_, self.temporal_dim, self.time_kernel_size)
# padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
# input_parallel = F.pad(input_parallel, padding_2d, mode = 'constant', value = 0)
# output_parallel = self.conv(input_parallel)
# output = output_parallel
# return output
input_parallel = fake_cp_pass_from_previous_rank(
input_, self.temporal_dim, self.time_kernel_size, self.cache_padding
)
del self.cache_padding
self.cache_padding = None
if not clear_cache:
cp_rank, cp_world_size = get_context_parallel_rank(), get_context_parallel_world_size()
global_rank = torch.distributed.get_rank()
if cp_world_size == 1:
self.cache_padding = (
input_parallel[:, :, -self.time_kernel_size + 1 :].contiguous().detach().clone().cpu()
)
else:
if cp_rank == cp_world_size - 1:
torch.distributed.isend(
input_parallel[:, :, -self.time_kernel_size + 1 :].contiguous(),
global_rank + 1 - cp_world_size,
group=get_context_parallel_group(),
)
if cp_rank == 0:
recv_buffer = torch.empty_like(input_parallel[:, :, -self.time_kernel_size + 1 :]).contiguous()
torch.distributed.recv(
recv_buffer, global_rank - 1 + cp_world_size, group=get_context_parallel_group()
)
self.cache_padding = recv_buffer.contiguous().detach().clone().cpu()
padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
input_parallel = F.pad(input_parallel, padding_2d, mode="constant", value=0)
output_parallel = self.conv(input_parallel)
output = output_parallel
return output
class ContextParallelGroupNorm(torch.nn.GroupNorm):
def forward(self, input_):
gather_flag = input_.shape[2] > 1
if gather_flag:
input_ = conv_gather_from_context_parallel_region(input_, dim=2, kernel_size=1)
output = super().forward(input_)
if gather_flag:
output = conv_scatter_to_context_parallel_region(output, dim=2, kernel_size=1)
return output
def Normalize(in_channels, gather=False, **kwargs): # same for 3D and 2D
if gather:
return ContextParallelGroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
else:
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
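# Spatially conditioned normalization (SpatialNorm-style conditional GroupNorm):
# `f` is group-normalized and then modulated as norm(f) * conv_y(zq) + conv_b(zq),
# with zq first resized to f's shape (the first frame is interpolated separately
# when the temporal length is odd).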
class SpatialNorm3D(nn.Module):
def __init__(
self,
f_channels,
zq_channels,
freeze_norm_layer=False,
add_conv=False,
pad_mode="constant",
gather=False,
**norm_layer_params,
):
super().__init__()
if gather:
self.norm_layer = ContextParallelGroupNorm(num_channels=f_channels, **norm_layer_params)
else:
self.norm_layer = torch.nn.GroupNorm(num_channels=f_channels, **norm_layer_params)
# self.norm_layer = norm_layer(num_channels=f_channels, **norm_layer_params)
if freeze_norm_layer:
for p in self.norm_layer.parameters():
p.requires_grad = False
self.add_conv = add_conv
if add_conv:
self.conv = ContextParallelCausalConv3d(
chan_in=zq_channels,
chan_out=zq_channels,
kernel_size=3,
)
self.conv_y = ContextParallelCausalConv3d(
chan_in=zq_channels,
chan_out=f_channels,
kernel_size=1,
)
self.conv_b = ContextParallelCausalConv3d(
chan_in=zq_channels,
chan_out=f_channels,
kernel_size=1,
)
def forward(self, f, zq, clear_fake_cp_cache=True):
if f.shape[2] > 1 and f.shape[2] % 2 == 1:
f_first, f_rest = f[:, :, :1], f[:, :, 1:]
f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
zq_first, zq_rest = zq[:, :, :1], zq[:, :, 1:]
zq_first = torch.nn.functional.interpolate(zq_first, size=f_first_size, mode="nearest")
zq_rest = torch.nn.functional.interpolate(zq_rest, size=f_rest_size, mode="nearest")
zq = torch.cat([zq_first, zq_rest], dim=2)
else:
zq = torch.nn.functional.interpolate(zq, size=f.shape[-3:], mode="nearest")
if self.add_conv:
zq = self.conv(zq, clear_cache=clear_fake_cp_cache)
# f = conv_gather_from_context_parallel_region(f, dim=2, kernel_size=1)
norm_f = self.norm_layer(f)
# norm_f = conv_scatter_to_context_parallel_region(norm_f, dim=2, kernel_size=1)
new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
return new_f
def Normalize3D(
in_channels,
zq_ch,
add_conv,
gather=False,
):
return SpatialNorm3D(
in_channels,
zq_ch,
gather=gather,
freeze_norm_layer=False,
add_conv=add_conv,
num_groups=32,
eps=1e-6,
affine=True,
)
class Upsample3D(nn.Module):
def __init__(
self,
in_channels,
with_conv,
compress_time=False,
):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
self.compress_time = compress_time
def forward(self, x):
if self.compress_time and x.shape[2] > 1:
if x.shape[2] % 2 == 1:
# split first frame
x_first, x_rest = x[:, :, 0], x[:, :, 1:]
x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0, mode="nearest")
x_rest = torch.nn.functional.interpolate(x_rest, scale_factor=2.0, mode="nearest")
x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2)
else:
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
else:
# only interpolate 2D
t = x.shape[2]
x = rearrange(x, "b c t h w -> (b t) c h w")
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
if self.with_conv:
t = x.shape[2]
x = rearrange(x, "b c t h w -> (b t) c h w")
x = self.conv(x)
x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
return x
class DownSample3D(nn.Module):
def __init__(self, in_channels, with_conv, compress_time=False, out_channels=None):
super().__init__()
self.with_conv = with_conv
if out_channels is None:
out_channels = in_channels
if self.with_conv:
# no asymmetric padding in torch conv, must do it ourselves
self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)
self.compress_time = compress_time
def forward(self, x):
if self.compress_time and x.shape[2] > 1:
h, w = x.shape[-2:]
x = rearrange(x, "b c t h w -> (b h w) c t")
if x.shape[-1] % 2 == 1:
# split first frame
x_first, x_rest = x[..., 0], x[..., 1:]
if x_rest.shape[-1] > 0:
x_rest = torch.nn.functional.avg_pool1d(x_rest, kernel_size=2, stride=2)
x = torch.cat([x_first[..., None], x_rest], dim=-1)
x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w)
else:
x = torch.nn.functional.avg_pool1d(x, kernel_size=2, stride=2)
x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w)
if self.with_conv:
pad = (0, 1, 0, 1)
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
t = x.shape[2]
x = rearrange(x, "b c t h w -> (b t) c h w")
x = self.conv(x)
x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
else:
t = x.shape[2]
x = rearrange(x, "b c t h w -> (b t) c h w")
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
return x
class ContextParallelResnetBlock3D(nn.Module):
def __init__(
self,
*,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout,
temb_channels=512,
zq_ch=None,
add_conv=False,
gather_norm=False,
normalization=Normalize,
):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.norm1 = normalization(
in_channels,
zq_ch=zq_ch,
add_conv=add_conv,
gather=gather_norm,
)
self.conv1 = ContextParallelCausalConv3d(
chan_in=in_channels,
chan_out=out_channels,
kernel_size=3,
)
if temb_channels > 0:
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
self.norm2 = normalization(
out_channels,
zq_ch=zq_ch,
add_conv=add_conv,
gather=gather_norm,
)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = ContextParallelCausalConv3d(
chan_in=out_channels,
chan_out=out_channels,
kernel_size=3,
)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = ContextParallelCausalConv3d(
chan_in=in_channels,
chan_out=out_channels,
kernel_size=3,
)
else:
self.nin_shortcut = Conv3d(
in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0,
)
def forward(self, x, temb, zq=None, clear_fake_cp_cache=True):
h = x
# if isinstance(self.norm1, torch.nn.GroupNorm):
# h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
if zq is not None:
h = self.norm1(h, zq, clear_fake_cp_cache=clear_fake_cp_cache)
else:
h = self.norm1(h)
# if isinstance(self.norm1, torch.nn.GroupNorm):
# h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
h = nonlinearity(h)
h = self.conv1(h, clear_cache=clear_fake_cp_cache)
if temb is not None:
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None, None]
# if isinstance(self.norm2, torch.nn.GroupNorm):
# h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
if zq is not None:
h = self.norm2(h, zq, clear_fake_cp_cache=clear_fake_cp_cache)
else:
h = self.norm2(h)
# if isinstance(self.norm2, torch.nn.GroupNorm):
# h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
h = nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h, clear_cache=clear_fake_cp_cache)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x, clear_cache=clear_fake_cp_cache)
else:
x = self.nin_shortcut(x)
return x + h
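# Encoder: a stack of causal-3D ResNet blocks with a spatial downsample after
# every resolution level except the last; temporal downsampling is only enabled
# at the first log2(temporal_compress_times) levels. The final conv emits
# 2 * z_channels channels (interpreted downstream as mean and log-variance)
# when double_z is set.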
class ContextParallelEncoder3D(nn.Module):
def __init__(
self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
z_channels,
double_z=True,
pad_mode="first",
temporal_compress_times=4,
gather_norm=False,
**ignore_kwargs,
):
super().__init__()
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
# log2 of temporal_compress_times
self.temporal_compress_level = int(np.log2(temporal_compress_times))
self.conv_in = ContextParallelCausalConv3d(
chan_in=in_channels,
chan_out=self.ch,
kernel_size=3,
)
curr_res = resolution
in_ch_mult = (1,) + tuple(ch_mult)
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(
ContextParallelResnetBlock3D(
in_channels=block_in,
out_channels=block_out,
dropout=dropout,
temb_channels=self.temb_ch,
gather_norm=gather_norm,
)
)
block_in = block_out
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
if i_level < self.temporal_compress_level:
down.downsample = DownSample3D(block_in, resamp_with_conv, compress_time=True)
else:
down.downsample = DownSample3D(block_in, resamp_with_conv, compress_time=False)
curr_res = curr_res // 2
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = ContextParallelResnetBlock3D(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
gather_norm=gather_norm,
)
self.mid.block_2 = ContextParallelResnetBlock3D(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
gather_norm=gather_norm,
)
# end
self.norm_out = Normalize(block_in, gather=gather_norm)
self.conv_out = ContextParallelCausalConv3d(
chan_in=block_in,
chan_out=2 * z_channels if double_z else z_channels,
kernel_size=3,
)
def forward(self, x):
# timestep embedding
temb = None
# downsampling
h = self.conv_in(x)
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](h, temb)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
if i_level != self.num_resolutions - 1:
h = self.down[i_level].downsample(h)
# middle
h = self.mid.block_1(h, temb)
h = self.mid.block_2(h, temb)
# end
# h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
h = self.norm_out(h)
# h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
h = nonlinearity(h)
h = self.conv_out(h)
return h
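# Decoder: mirrors the encoder, with every normalization layer conditioned on
# the latent via Normalize3D (SpatialNorm3D). Spatial upsampling happens at
# every level except the highest-resolution one, and temporal upsampling is
# enabled on log2(temporal_compress_times) of those stages so the encoder's
# temporal compression is undone.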
class ContextParallelDecoder3D(nn.Module):
def __init__(
self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
z_channels,
give_pre_end=False,
zq_ch=None,
add_conv=False,
pad_mode="first",
temporal_compress_times=4,
gather_norm=False,
**ignorekwargs,
):
super().__init__()
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.give_pre_end = give_pre_end
# log2 of temporal_compress_times
self.temporal_compress_level = int(np.log2(temporal_compress_times))
if zq_ch is None:
zq_ch = z_channels
# compute in_ch_mult, block_in and curr_res at lowest res
in_ch_mult = (1,) + tuple(ch_mult)
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
self.conv_in = ContextParallelCausalConv3d(
chan_in=z_channels,
chan_out=block_in,
kernel_size=3,
)
# middle
self.mid = nn.Module()
self.mid.block_1 = ContextParallelResnetBlock3D(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
zq_ch=zq_ch,
add_conv=add_conv,
normalization=Normalize3D,
gather_norm=gather_norm,
)
self.mid.block_2 = ContextParallelResnetBlock3D(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
zq_ch=zq_ch,
add_conv=add_conv,
normalization=Normalize3D,
gather_norm=gather_norm,
)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
block.append(
ContextParallelResnetBlock3D(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
zq_ch=zq_ch,
add_conv=add_conv,
normalization=Normalize3D,
gather_norm=gather_norm,
)
)
block_in = block_out
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
if i_level < self.num_resolutions - self.temporal_compress_level:
up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=False)
else:
up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=True)
self.up.insert(0, up)
self.norm_out = Normalize3D(block_in, zq_ch, add_conv=add_conv, gather=gather_norm)
self.conv_out = ContextParallelCausalConv3d(
chan_in=block_in,
chan_out=out_ch,
kernel_size=3,
)
def forward(self, z, clear_fake_cp_cache=True):
self.last_z_shape = z.shape
# timestep embedding
temb = None
t = z.shape[2]
# z to block_in
zq = z
h = self.conv_in(z, clear_cache=clear_fake_cp_cache)
# middle
h = self.mid.block_1(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache)
h = self.mid.block_2(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h, zq)
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
if self.give_pre_end:
return h
h = self.norm_out(h, zq, clear_fake_cp_cache=clear_fake_cp_cache)
h = nonlinearity(h)
h = self.conv_out(h, clear_cache=clear_fake_cp_cache)
return h
def get_last_layer(self):
return self.conv_out.conv.weight
import torch
from torch import nn
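# Exponential moving average of a model's trainable parameters, stored as
# buffers (parameter names have their dots stripped). The effective decay is
# warmed up as (1 + n) / (10 + n) until it reaches `decay`; copy_to writes the
# EMA weights into a model, and store/restore let you temporarily swap them in
# for evaluation.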
class LitEma(nn.Module):
def __init__(self, model, decay=0.9999, use_num_upates=True):
super().__init__()
if decay < 0.0 or decay > 1.0:
raise ValueError("Decay must be between 0 and 1")
self.m_name2s_name = {}
self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
self.register_buffer(
"num_updates",
torch.tensor(0, dtype=torch.int) if use_num_upates else torch.tensor(-1, dtype=torch.int),
)
for name, p in model.named_parameters():
if p.requires_grad:
# remove as '.'-character is not allowed in buffers
s_name = name.replace(".", "")
self.m_name2s_name.update({name: s_name})
self.register_buffer(s_name, p.clone().detach().data)
self.collected_params = []
def reset_num_updates(self):
del self.num_updates
self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int))
def forward(self, model):
decay = self.decay
if self.num_updates >= 0:
self.num_updates += 1
decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
one_minus_decay = 1.0 - decay
with torch.no_grad():
m_param = dict(model.named_parameters())
shadow_params = dict(self.named_buffers())
for key in m_param:
if m_param[key].requires_grad:
sname = self.m_name2s_name[key]
shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
else:
assert not key in self.m_name2s_name
def copy_to(self, model):
m_param = dict(model.named_parameters())
shadow_params = dict(self.named_buffers())
for key in m_param:
if m_param[key].requires_grad:
m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
else:
assert not key in self.m_name2s_name
def store(self, parameters):
"""
Save the current parameters for restoring later.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
temporarily stored.
"""
self.collected_params = [param.clone() for param in parameters]
def restore(self, parameters):
"""
Restore the parameters stored with the `store` method.
Useful to validate the model with EMA parameters without affecting the
original optimization process. Store the parameters before the
`copy_to` method. After validation (or model saving), use this to
restore the former parameters.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
updated with the stored parameters.
"""
for c_param, param in zip(self.collected_params, parameters):
param.data.copy_(c_param.data)
from abc import abstractmethod
from typing import Any, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
class DiagonalGaussianDistribution(object):
def __init__(self, parameters, deterministic=False):
self.parameters = parameters
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
self.deterministic = deterministic
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
if self.deterministic:
self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
def sample(self):
# x = self.mean + self.std * torch.randn(self.mean.shape).to(
# device=self.parameters.device
# )
x = self.mean + self.std * torch.randn_like(self.mean)
return x
def kl(self, other=None):
if self.deterministic:
return torch.Tensor([0.0])
else:
if other is None:
return 0.5 * torch.sum(
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
dim=[1, 2, 3],
)
else:
return 0.5 * torch.sum(
torch.pow(self.mean - other.mean, 2) / other.var
+ self.var / other.var
- 1.0
- self.logvar
+ other.logvar,
dim=[1, 2, 3],
)
def nll(self, sample, dims=[1, 2, 3]):
if self.deterministic:
return torch.Tensor([0.0])
logtwopi = np.log(2.0 * np.pi)
return 0.5 * torch.sum(
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
dim=dims,
)
def mode(self):
return self.mean
class AbstractRegularizer(nn.Module):
def __init__(self):
super().__init__()
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
raise NotImplementedError()
@abstractmethod
def get_trainable_parameters(self) -> Any:
raise NotImplementedError()
class IdentityRegularizer(AbstractRegularizer):
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
return z, dict()
def get_trainable_parameters(self) -> Any:
yield from ()
def measure_perplexity(predicted_indices: torch.Tensor, num_centroids: int) -> Tuple[torch.Tensor, torch.Tensor]:
# src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
# eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
encodings = F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids)
avg_probs = encodings.mean(0)
perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
cluster_use = torch.sum(avg_probs > 0)
return perplexity, cluster_use
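# Wraps the encoder output (concatenated mean and log-variance) in a
# DiagonalGaussianDistribution, returns a sample (or the mode) as the latent,
# and logs the batch-averaged KL divergence to a standard normal as "kl_loss".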
class DiagonalGaussianRegularizer(AbstractRegularizer):
def __init__(self, sample: bool = True):
super().__init__()
self.sample = sample
def get_trainable_parameters(self) -> Any:
yield from ()
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
log = dict()
posterior = DiagonalGaussianDistribution(z)
if self.sample:
z = posterior.sample()
else:
z = posterior.mode()
kl_loss = posterior.kl()
kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
log["kl_loss"] = kl_loss
return z, log
import functools
import importlib
import os
from functools import partial
from inspect import isfunction
import fsspec
import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont
from safetensors.torch import load_file as load_safetensors
import torch.distributed
_CONTEXT_PARALLEL_GROUP = None
_CONTEXT_PARALLEL_SIZE = None
def is_context_parallel_initialized():
if _CONTEXT_PARALLEL_GROUP is None:
return False
else:
return True
def initialize_context_parallel(context_parallel_size):
global _CONTEXT_PARALLEL_GROUP
global _CONTEXT_PARALLEL_SIZE
assert _CONTEXT_PARALLEL_GROUP is None, "context parallel group is already initialized"
_CONTEXT_PARALLEL_SIZE = context_parallel_size
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
for i in range(0, world_size, context_parallel_size):
ranks = range(i, i + context_parallel_size)
group = torch.distributed.new_group(ranks)
if rank in ranks:
_CONTEXT_PARALLEL_GROUP = group
break
def get_context_parallel_group():
assert _CONTEXT_PARALLEL_GROUP is not None, "context parallel group is not initialized"
return _CONTEXT_PARALLEL_GROUP
def get_context_parallel_world_size():
assert _CONTEXT_PARALLEL_SIZE is not None, "context parallel size is not initialized"
return _CONTEXT_PARALLEL_SIZE
def get_context_parallel_rank():
assert _CONTEXT_PARALLEL_SIZE is not None, "context parallel size is not initialized"
rank = torch.distributed.get_rank()
cp_rank = rank % _CONTEXT_PARALLEL_SIZE
return cp_rank
def get_context_parallel_group_rank():
assert _CONTEXT_PARALLEL_SIZE is not None, "context parallel size is not initialized"
rank = torch.distributed.get_rank()
cp_group_rank = rank // _CONTEXT_PARALLEL_SIZE
return cp_group_rank
class SafeConv3d(torch.nn.Conv3d):
def forward(self, input):
memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024**3
if memory_count > 2:
kernel_size = self.kernel_size[0]
part_num = int(memory_count / 2) + 1
input_chunks = torch.chunk(input, part_num, dim=2) # NCTHW
if kernel_size > 1:
input_chunks = [input_chunks[0]] + [
torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2)
for i in range(1, len(input_chunks))
]
output_chunks = []
for input_chunk in input_chunks:
output_chunks.append(super(SafeConv3d, self).forward(input_chunk))
output = torch.cat(output_chunks, dim=2)
return output
else:
return super(SafeConv3d, self).forward(input)
def disabled_train(self, mode=True):
"""Overwrite model.train with this function to make sure train/eval mode
does not change anymore."""
return self
def get_string_from_tuple(s):
try:
# Check if the string starts and ends with parentheses
if s[0] == "(" and s[-1] == ")":
# Convert the string to a tuple
t = eval(s)
# Check if the type of t is tuple
if type(t) == tuple:
return t[0]
else:
pass
except:
pass
return s
def is_power_of_two(n):
"""
chat.openai.com/chat
Return True if n is a power of 2, otherwise return False.
The function is_power_of_two takes an integer n as input and returns True if n is a power of 2, otherwise it returns False.
The function works by first checking if n is less than or equal to 0. If n is less than or equal to 0, it can't be a power of 2, so the function returns False.
If n is greater than 0, the function checks whether n is a power of 2 by using a bitwise AND operation between n and n-1. If n is a power of 2, then it will have only one bit set to 1 in its binary representation. When we subtract 1 from a power of 2, all the bits to the right of that bit become 1, and the bit itself becomes 0. So, when we perform a bitwise AND between n and n-1, we get 0 if n is a power of 2, and a non-zero value otherwise.
Thus, if the result of the bitwise AND operation is 0, then n is a power of 2 and the function returns True. Otherwise, the function returns False.
"""
if n <= 0:
return False
return (n & (n - 1)) == 0
def autocast(f, enabled=True):
def do_autocast(*args, **kwargs):
with torch.cuda.amp.autocast(
enabled=enabled,
dtype=torch.get_autocast_gpu_dtype(),
cache_enabled=torch.is_autocast_cache_enabled(),
):
return f(*args, **kwargs)
return do_autocast
def load_partial_from_config(config):
return partial(get_obj_from_str(config["target"]), **config.get("params", dict()))
def log_txt_as_img(wh, xc, size=10):
# wh a tuple of (width, height)
# xc a list of captions to plot
b = len(xc)
txts = list()
for bi in range(b):
txt = Image.new("RGB", wh, color="white")
draw = ImageDraw.Draw(txt)
font = ImageFont.truetype("data/DejaVuSans.ttf", size=size)
nc = int(40 * (wh[0] / 256))
if isinstance(xc[bi], list):
text_seq = xc[bi][0]
else:
text_seq = xc[bi]
lines = "\n".join(text_seq[start : start + nc] for start in range(0, len(text_seq), nc))
try:
draw.text((0, 0), lines, fill="black", font=font)
except UnicodeEncodeError:
print("Cant encode string for logging. Skipping.")
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
txts.append(txt)
txts = np.stack(txts)
txts = torch.tensor(txts)
return txts
def partialclass(cls, *args, **kwargs):
class NewCls(cls):
__init__ = functools.partialmethod(cls.__init__, *args, **kwargs)
return NewCls
def make_path_absolute(path):
fs, p = fsspec.core.url_to_fs(path)
if fs.protocol == "file":
return os.path.abspath(p)
return path
def ismap(x):
if not isinstance(x, torch.Tensor):
return False
return (len(x.shape) == 4) and (x.shape[1] > 3)
def isimage(x):
if not isinstance(x, torch.Tensor):
return False
return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
def isheatmap(x):
if not isinstance(x, torch.Tensor):
return False
return x.ndim == 2
def isneighbors(x):
if not isinstance(x, torch.Tensor):
return False
return x.ndim == 5 and (x.shape[2] == 3 or x.shape[2] == 1)
def exists(x):
return x is not None
def expand_dims_like(x, y):
while x.dim() != y.dim():
x = x.unsqueeze(-1)
return x
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
def mean_flat(tensor):
"""
https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
Take the mean over all non-batch dimensions.
"""
return tensor.mean(dim=list(range(1, len(tensor.shape))))
def count_params(model, verbose=False):
total_params = sum(p.numel() for p in model.parameters())
if verbose:
print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
return total_params
def instantiate_from_config(config):
if not "target" in config:
if config == "__is_first_stage__":
return None
elif config == "__is_unconditional__":
return None
raise KeyError("Expected key `target` to instantiate.")
return get_obj_from_str(config["target"])(**config.get("params", dict()))
def get_obj_from_str(string, reload=False, invalidate_cache=True):
module, cls = string.rsplit(".", 1)
if invalidate_cache:
importlib.invalidate_caches()
if reload:
module_imp = importlib.import_module(module)
importlib.reload(module_imp)
return getattr(importlib.import_module(module, package=None), cls)
def append_zero(x):
return torch.cat([x, x.new_zeros([1])])
def append_dims(x, target_dims):
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
dims_to_append = target_dims - x.ndim
if dims_to_append < 0:
raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
return x[(...,) + (None,) * dims_to_append]
def load_model_from_config(config, ckpt, verbose=True, freeze=True):
print(f"Loading model from {ckpt}")
if ckpt.endswith("ckpt"):
pl_sd = torch.load(ckpt, map_location="cpu")
if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
sd = pl_sd["state_dict"]
elif ckpt.endswith("safetensors"):
sd = load_safetensors(ckpt)
else:
raise NotImplementedError
model = instantiate_from_config(config.model)
m, u = model.load_state_dict(sd, strict=False)
if len(m) > 0 and verbose:
print("missing keys:")
print(m)
if len(u) > 0 and verbose:
print("unexpected keys:")
print(u)
if freeze:
for param in model.parameters():
param.requires_grad = False
model.eval()
return model
def get_configs_path() -> str:
"""
Get the `configs` directory.
For a working copy, this is the one in the root of the repository,
but for an installed copy, it's in the `sgm` package (see pyproject.toml).
"""
this_dir = os.path.dirname(__file__)
candidates = (
os.path.join(this_dir, "configs"),
os.path.join(this_dir, "..", "configs"),
)
for candidate in candidates:
candidate = os.path.abspath(candidate)
if os.path.isdir(candidate):
return candidate
raise FileNotFoundError(f"Could not find SGM configs in {candidates}")
def get_nested_attribute(obj, attribute_path, depth=None, return_key=False):
"""
Will return the result of a recursive get attribute call.
E.g.:
a.b.c
= getattr(getattr(a, "b"), "c")
= get_nested_attribute(a, "b.c")
If any part of the attribute call is an integer x with current obj a, will
try to call a[x] instead of a.x first.
"""
attributes = attribute_path.split(".")
if depth is not None and depth > 0:
attributes = attributes[:depth]
assert len(attributes) > 0, "At least one attribute should be selected"
current_attribute = obj
current_key = None
for level, attribute in enumerate(attributes):
current_key = ".".join(attributes[: level + 1])
try:
id_ = int(attribute)
current_attribute = current_attribute[id_]
except ValueError:
current_attribute = getattr(current_attribute, attribute)
return (current_attribute, current_key) if return_key else current_attribute
def checkpoint(func, inputs, params, flag):
"""
Evaluate a function without caching intermediate activations, allowing for
reduced memory at the expense of extra compute in the backward pass.
:param func: the function to evaluate.
:param inputs: the argument sequence to pass to `func`.
:param params: a sequence of parameters `func` depends on but does not
explicitly take as arguments.
:param flag: if False, disable gradient checkpointing.
"""
if flag:
args = tuple(inputs) + tuple(params)
return CheckpointFunction.apply(func, len(inputs), *args)
else:
return func(*inputs)
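# Custom autograd Function backing `checkpoint`: the forward pass runs under
# torch.no_grad() and stores only the inputs; the backward pass re-runs the
# function under torch.enable_grad() (restoring the original autocast state)
# and differentiates through the recomputed graph.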
class CheckpointFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, run_function, length, *args):
ctx.run_function = run_function
ctx.input_tensors = list(args[:length])
ctx.input_params = list(args[length:])
ctx.gpu_autocast_kwargs = {
"enabled": torch.is_autocast_enabled(),
"dtype": torch.get_autocast_gpu_dtype(),
"cache_enabled": torch.is_autocast_cache_enabled(),
}
with torch.no_grad():
output_tensors = ctx.run_function(*ctx.input_tensors)
return output_tensors
@staticmethod
def backward(ctx, *output_grads):
ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
# Fixes a bug where the first op in run_function modifies the
# Tensor storage in place, which is not allowed for detach()'d
# Tensors.
shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
output_tensors = ctx.run_function(*shallow_copies)
input_grads = torch.autograd.grad(
output_tensors,
ctx.input_tensors + ctx.input_params,
output_grads,
allow_unused=True,
)
del ctx.input_tensors
del ctx.input_params
del output_tensors
return (None, None) + input_grads
# Video Caption
Typically, most video data does not come with corresponding descriptive text, so it is necessary to convert the video
data into textual descriptions to provide the essential training data for text-to-video models.
## Video Caption via CogVLM2-Video
<p align="center">
🤗 <a href="https://huggingface.co/THUDM/cogvlm2-video-llama3-chat">Hugging Face</a>&nbsp;&nbsp; | &nbsp;&nbsp;🤖 <a href="https://modelscope.cn/models/ZhipuAI/cogvlm2-video-llama3-chat">ModelScope</a>&nbsp;&nbsp; | &nbsp;&nbsp;📑 <a href="https://cogvlm2-video.github.io/">Blog</a>&nbsp;&nbsp; | &nbsp;&nbsp;<a href="http://cogvlm2-online.cogviewai.cn:7868/">💬 Online Demo</a>
</p>
CogVLM2-Video is a versatile video understanding model equipped with timestamp-based question answering capabilities.
Users can input prompts such as `Please describe this video in detail.` to the model to obtain a detailed video caption:
<div align="center">
<a href="https://cogvlm2-video.github.io/"><img width="600px" height="auto" src="./assests/cogvlm2-video-example.png"></a>
</div>
Users can use the provided [code](https://github.com/THUDM/CogVLM2/tree/main/video_demo) to load the model or configure a RESTful API to generate video captions.
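If the model is deployed as a service, batch captioning reduces to a simple HTTP call per video. The sketch below is only an illustration under assumed names: the endpoint URL, request fields, and response schema are placeholders rather than the actual interface of the linked `video_demo` code, and should be adapted to however the service is actually exposed.

```python
# Hypothetical sketch: request a caption for one video from a self-hosted
# CogVLM2-Video service. Endpoint, payload fields, and response schema are
# placeholders; adapt them to your own deployment.
import base64
import requests


def caption_video(path: str, endpoint: str = "http://localhost:8000/video_caption") -> str:
    # Encode the video file so it can travel in a JSON payload.
    with open(path, "rb") as f:
        video_b64 = base64.b64encode(f.read()).decode("utf-8")
    resp = requests.post(
        endpoint,
        json={
            "video": video_b64,  # placeholder field name
            "prompt": "Please describe this video in detail.",
        },
        timeout=300,
    )
    resp.raise_for_status()
    return resp.json().get("caption", "")  # placeholder response field


if __name__ == "__main__":
    print(caption_video("example.mp4"))
```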
# Video Caption
Typically, most video data does not come with corresponding descriptive text, so the videos need to be converted into textual descriptions to provide the training data required by text-to-video models.
## Generating Video Captions with CogVLM2-Video
🤗 [Hugging Face](https://huggingface.co/THUDM/cogvlm2-video-llama3-chat) | 🤖 [ModelScope](https://modelscope.cn/models/ZhipuAI/cogvlm2-video-llama3-chat) | 📑 [Blog](https://cogvlm2-video.github.io/) | 💬 [Online Demo](http://cogvlm2-online.cogviewai.cn:7868/)
CogVLM2-Video is a versatile video understanding model with timestamp-based question answering capabilities. Users can provide a prompt such as `Please describe this video in detail.` to obtain a detailed video caption:
<div align="center">
<a href="https://cogvlm2-video.github.io/"><img width="600px" height="auto" src="./assests/cogvlm2-video-example.png"></a>
</div>
Users can use the provided [code](https://github.com/THUDM/CogVLM2/tree/main/video_demo) to load the model or set up a RESTful API to generate video captions.
"""
This script demonstrates how to convert and generate video from a text prompt using CogVideoX with 🤗Huggingface Diffusers Pipeline.
Note:
This script requires the `diffusers>=0.30.0` library to be installed.
Run the script:
$ python convert_and_generate.py --transformer_ckpt_path <path_to_transformer_checkpoint> --vae_ckpt_path <path_to_vae_checkpoint> --output_path <path_to_output_directory> --text_encoder_path <path_to_t5>
Functions:
- reassign_query_key_value_inplace: Reassigns the query, key, and value weights in-place.
- reassign_query_key_layernorm_inplace: Reassigns layer normalization for query and key in-place.
- reassign_adaln_norm_inplace: Reassigns adaptive layer normalization in-place.
- remove_keys_inplace: Removes specified keys from the state_dict in-place.
- replace_up_keys_inplace: Replaces keys in the "up" block in-place.
- get_state_dict: Extracts the state_dict from a saved checkpoint.
- update_state_dict_inplace: Updates the state_dict with new key assignments in-place.
- convert_transformer: Converts a transformer checkpoint to the CogVideoX format.
- convert_vae: Converts a VAE checkpoint to the CogVideoX format.
- get_args: Parses command-line arguments for the script.
- generate_video: Generates a video from a text prompt using the CogVideoX pipeline.
"""
import argparse
from typing import Any, Dict
import torch
from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel
from transformers import T5EncoderModel, T5Tokenizer
# Function to reassign the query, key, and value weights in-place
def reassign_query_key_value_inplace(key: str, state_dict: Dict[str, Any]):
to_q_key = key.replace("query_key_value", "to_q")
to_k_key = key.replace("query_key_value", "to_k")
to_v_key = key.replace("query_key_value", "to_v")
to_q, to_k, to_v = torch.chunk(state_dict[key], chunks=3, dim=0)
state_dict[to_q_key] = to_q
state_dict[to_k_key] = to_k
state_dict[to_v_key] = to_v
state_dict.pop(key)
# Function to reassign layer normalization for query and key in-place
def reassign_query_key_layernorm_inplace(key: str, state_dict: Dict[str, Any]):
layer_id, weight_or_bias = key.split(".")[-2:]
if "query" in key:
new_key = f"transformer_blocks.{layer_id}.attn1.norm_q.{weight_or_bias}"
elif "key" in key:
new_key = f"transformer_blocks.{layer_id}.attn1.norm_k.{weight_or_bias}"
state_dict[new_key] = state_dict.pop(key)
# Function to reassign adaptive layer normalization in-place
def reassign_adaln_norm_inplace(key: str, state_dict: Dict[str, Any]):
layer_id, _, weight_or_bias = key.split(".")[-3:]
weights_or_biases = state_dict[key].chunk(12, dim=0)
norm1_weights_or_biases = torch.cat(weights_or_biases[0:3] + weights_or_biases[6:9])
norm2_weights_or_biases = torch.cat(weights_or_biases[3:6] + weights_or_biases[9:12])
norm1_key = f"transformer_blocks.{layer_id}.norm1.linear.{weight_or_bias}"
state_dict[norm1_key] = norm1_weights_or_biases
norm2_key = f"transformer_blocks.{layer_id}.norm2.linear.{weight_or_bias}"
state_dict[norm2_key] = norm2_weights_or_biases
state_dict.pop(key)
# Function to remove keys from state_dict in-place
def remove_keys_inplace(key: str, state_dict: Dict[str, Any]):
state_dict.pop(key)
# Function to replace keys in the "up" block in-place
def replace_up_keys_inplace(key: str, state_dict: Dict[str, Any]):
key_split = key.split(".")
layer_index = int(key_split[2])
replace_layer_index = 4 - 1 - layer_index
key_split[1] = "up_blocks"
key_split[2] = str(replace_layer_index)
new_key = ".".join(key_split)
state_dict[new_key] = state_dict.pop(key)
# Dictionary for renaming transformer keys
TRANSFORMER_KEYS_RENAME_DICT = {
"transformer.final_layernorm": "norm_final",
"transformer": "transformer_blocks",
"attention": "attn1",
"mlp": "ff.net",
"dense_h_to_4h": "0.proj",
"dense_4h_to_h": "2",
".layers": "",
"dense": "to_out.0",
"input_layernorm": "norm1.norm",
"post_attn1_layernorm": "norm2.norm",
"time_embed.0": "time_embedding.linear_1",
"time_embed.2": "time_embedding.linear_2",
"mixins.patch_embed": "patch_embed",
"mixins.final_layer.norm_final": "norm_out.norm",
"mixins.final_layer.linear": "proj_out",
"mixins.final_layer.adaLN_modulation.1": "norm_out.linear",
}
# Dictionary for handling special keys in transformer
TRANSFORMER_SPECIAL_KEYS_REMAP = {
"query_key_value": reassign_query_key_value_inplace,
"query_layernorm_list": reassign_query_key_layernorm_inplace,
"key_layernorm_list": reassign_query_key_layernorm_inplace,
"adaln_layer.adaLN_modulations": reassign_adaln_norm_inplace,
"embed_tokens": remove_keys_inplace,
}
# Dictionary for renaming VAE keys
VAE_KEYS_RENAME_DICT = {
"block.": "resnets.",
"down.": "down_blocks.",
"downsample": "downsamplers.0",
"upsample": "upsamplers.0",
"nin_shortcut": "conv_shortcut",
"encoder.mid.block_1": "encoder.mid_block.resnets.0",
"encoder.mid.block_2": "encoder.mid_block.resnets.1",
"decoder.mid.block_1": "decoder.mid_block.resnets.0",
"decoder.mid.block_2": "decoder.mid_block.resnets.1",
}
# Dictionary for handling special keys in VAE
VAE_SPECIAL_KEYS_REMAP = {
"loss": remove_keys_inplace,
"up.": replace_up_keys_inplace,
}
# Maximum length of the tokenizer (Must be 226)
TOKENIZER_MAX_LENGTH = 226
# Function to extract the state_dict from a saved checkpoint
def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
state_dict = saved_dict
if "model" in saved_dict.keys():
state_dict = state_dict["model"]
if "module" in saved_dict.keys():
state_dict = state_dict["module"]
if "state_dict" in saved_dict.keys():
state_dict = state_dict["state_dict"]
return state_dict
# Function to update the state_dict with new key assignments in-place
def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
state_dict[new_key] = state_dict.pop(old_key)
# Function to convert a transformer checkpoint to the CogVideoX format
def convert_transformer(ckpt_path: str):
PREFIX_KEY = "model.diffusion_model."
original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
transformer = CogVideoXTransformer3DModel()
for key in list(original_state_dict.keys()):
new_key = key[len(PREFIX_KEY) :]
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(original_state_dict, key, new_key)
for key in list(original_state_dict.keys()):
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)
transformer.load_state_dict(original_state_dict, strict=True)
return transformer
# Function to convert a VAE checkpoint to the CogVideoX format
def convert_vae(ckpt_path: str):
original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
vae = AutoencoderKLCogVideoX()
for key in list(original_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(original_state_dict, key, new_key)
for key in list(original_state_dict.keys()):
for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)
vae.load_state_dict(original_state_dict, strict=True)
return vae
# Function to parse command-line arguments for the script
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
)
parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
parser.add_argument(
"--text_encoder_path",
type=str,
required=True,
default="google/t5-v1_1-xxl",
help="Path where converted model should be saved",
)
parser.add_argument(
"--text_encoder_cache_dir",
type=str,
default=None,
help="Path to text encoder cache directory. Not needed if text_encoder_path is in your local.",
)
parser.add_argument("--fp16", action="store_true", default=True, help="Whether to save the model weights in fp16")
parser.add_argument(
"--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving"
)
return parser.parse_args()
if __name__ == "__main__":
args = get_args()
transformer = None
vae = None
if args.transformer_ckpt_path is not None:
transformer = convert_transformer(args.transformer_ckpt_path)
if args.vae_ckpt_path is not None:
vae = convert_vae(args.vae_ckpt_path)
tokenizer = T5Tokenizer.from_pretrained(args.text_encoder_path, model_max_length=TOKENIZER_MAX_LENGTH)
text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_path, cache_dir=args.text_encoder_cache_dir)
scheduler = CogVideoXDDIMScheduler.from_config(
{
"snr_shift_scale": 3.0,
"beta_end": 0.012,
"beta_schedule": "scaled_linear",
"beta_start": 0.00085,
"clip_sample": False,
"num_train_timesteps": 1000,
"prediction_type": "v_prediction",
"rescale_betas_zero_snr": True,
"set_alpha_to_one": True,
"timestep_spacing": "linspace",
}
)
pipe = CogVideoXPipeline(
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
if args.fp16:
pipe = pipe.to(dtype=torch.float16)
pipe.save_pretrained(args.output_path, safe_serialization=True, push_to_hub=args.push_to_hub)
# coding=utf-8
# rewritten, Copyright (c) 2021, Ming Ding. All rights reserved.
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformer."""
import math
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from sat import mpu
from sat.mpu import get_model_parallel_world_size, ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding, gather_from_model_parallel_region, copy_to_model_parallel_region, checkpoint
from sat.mpu.utils import divide, sqrt, scaled_init_method, unscaled_init_method, gelu
from sat.ops.layernorm import LayerNorm
from sat.transformer_defaults import HOOKS_DEFAULT, standard_attention, split_tensor_along_last_dim
class SelfAttention(torch.nn.Module):
def __init__(self, hidden_size, num_attention_heads,
attention_dropout_prob, output_dropout_prob,
init_method, layer_id, hidden_size_per_attention_head=None, output_layer_init_method=None, bias=True, qkv_bias=False, num_multi_query_heads=0, row_parallel_linear_final_bias=True,
hooks={}, transformer_pointer=None, params_dtype=torch.float, skip_init=False, device=torch.device('cpu')):
super(SelfAttention, self).__init__()
# Set output layer initialization if not provided.
if output_layer_init_method is None:
output_layer_init_method = init_method
self.hooks = hooks
self.layer_id = layer_id
# Per attention head and per partition values.
world_size = get_model_parallel_world_size()
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
self.num_multi_query_heads = num_multi_query_heads
if hidden_size_per_attention_head is None:
self.hidden_size_per_attention_head = divide(hidden_size, num_attention_heads)
else:
self.hidden_size_per_attention_head = hidden_size_per_attention_head
self.num_attention_heads_per_partition = divide(num_attention_heads, world_size)
self.num_multi_query_heads_per_partition = divide(num_multi_query_heads, world_size)
self.inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head
self.hidden_size_per_partition = self.hidden_size_per_attention_head * self.num_attention_heads_per_partition
# Strided linear layer.
if num_multi_query_heads == 0:
qkv_size = 3 * self.inner_hidden_size
self.stride = 3
else: # multi-query
qkv_size = self.inner_hidden_size + self.hidden_size_per_attention_head * self.num_multi_query_heads * 2
self.stride = [self.num_attention_heads_per_partition, self.num_multi_query_heads_per_partition, self.num_multi_query_heads_per_partition]
self.query_key_value = ColumnParallelLinear(
hidden_size,
qkv_size,
stride=self.stride,
gather_output=False,
init_method=init_method,
bias=bias or qkv_bias,
params_dtype=params_dtype,
module=self,
name="query_key_value",
skip_init=skip_init,
device=device
)
self.attention_dropout = torch.nn.Dropout(attention_dropout_prob)
self.dense = RowParallelLinear(
self.inner_hidden_size,
hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method,
bias=bias,
params_dtype=params_dtype,
module=self,
name="dense",
skip_init=skip_init,
device=device,
final_bias=row_parallel_linear_final_bias
)
self.output_dropout = torch.nn.Dropout(output_dropout_prob)
object.__setattr__(self, 'transformer', transformer_pointer)
assert transformer_pointer is not None
def _transpose_for_scores(self, tensor):
"""Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with
size [b, np, s, hn].
"""
new_tensor_shape = tensor.size()[:-1] + \
(-1, # flexible for multi-query
self.hidden_size_per_attention_head)
tensor = tensor.view(*new_tensor_shape)
return tensor.permute(0, 2, 1, 3)
def forward(self, hidden_states, mask, *args, **kw_args):
if 'attention_forward' in self.hooks:
return self.hooks['attention_forward'](hidden_states, mask, **kw_args)
else:
return HOOKS_DEFAULT['attention_forward'](self, hidden_states, mask, **kw_args)
def repartition(self):
world_size = get_model_parallel_world_size()
self.num_attention_heads_per_partition = divide(self.num_attention_heads, world_size)
self.hidden_size_per_partition = self.hidden_size_per_attention_head * self.num_attention_heads_per_partition
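# Illustrative shape sketch (added comment, hypothetical sizes): _transpose_for_scores maps
# [b, s, np*hn] -> [b, np, s, hn]; the head count is left as -1 so the same helper also covers
# multi-query projections, which carry fewer key/value heads. For example, with b=2, s=16,
# np=8, hn=64:
#   x = torch.randn(2, 16, 8 * 64)
#   x = x.view(2, 16, -1, 64).permute(0, 2, 1, 3)   # -> shape [2, 8, 16, 64]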
class CrossAttention(torch.nn.Module):
"""Parallel cross-attention layer for Transformer"""
def __init__(self, hidden_size, num_attention_heads, attention_dropout_prob, output_dropout_prob, init_method,
layer_id, hidden_size_per_attention_head=None, output_layer_init_method=None, bias=True, cross_num_multi_query_heads=0, row_parallel_linear_final_bias=True, hooks={},
cross_attn_hidden_size=None, transformer_pointer=None, params_dtype=torch.float, skip_init=False, device=torch.device('cpu')):
super().__init__()
# Set output layer initialization if not provided.
if output_layer_init_method is None:
output_layer_init_method = init_method
self.hooks = hooks
self.layer_id = layer_id
self.num_attention_heads = num_attention_heads
self.hidden_size = hidden_size
# Per attention head and per partition values.
world_size = get_model_parallel_world_size()
if hidden_size_per_attention_head is None:
self.hidden_size_per_attention_head = divide(hidden_size, num_attention_heads)
else:
self.hidden_size_per_attention_head = hidden_size_per_attention_head
self.num_attention_heads_per_partition = divide(num_attention_heads, world_size)
self.inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head
self.hidden_size_per_partition = self.hidden_size_per_attention_head * self.num_attention_heads_per_partition
self.cross_num_multi_query_heads = cross_num_multi_query_heads
# Strided linear layer.
if cross_num_multi_query_heads == 0:
kv_size = 2 * self.inner_hidden_size
else: # multi-query
kv_size = self.hidden_size_per_attention_head * self.cross_num_multi_query_heads * 2
self.query = ColumnParallelLinear(hidden_size, self.inner_hidden_size,
gather_output=False,
init_method=init_method, bias=bias, params_dtype=params_dtype, module=self, name="query", skip_init=skip_init, device=device)
if cross_attn_hidden_size is None:
cross_attn_hidden_size = hidden_size
self.cross_attn_hidden_size = cross_attn_hidden_size
self.key_value = ColumnParallelLinear(cross_attn_hidden_size, kv_size,
stride=2,
gather_output=False,
init_method=init_method, bias=bias, params_dtype=params_dtype, module=self, name="key_value",
skip_init=skip_init, device=device)
# Dropout. Note that for a single iteration, this layer will generate
# different outputs on different numbers of parallel partitions, but
# on average it should not be partition dependent.
self.attention_dropout = torch.nn.Dropout(attention_dropout_prob)
# Output.
self.dense = RowParallelLinear(
self.inner_hidden_size,
hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method, bias=bias, params_dtype=params_dtype, module=self, name="dense", skip_init=skip_init,
device=device, final_bias=row_parallel_linear_final_bias)
self.output_dropout = torch.nn.Dropout(output_dropout_prob)
object.__setattr__(self, 'transformer', transformer_pointer)
assert transformer_pointer is not None
def _transpose_for_scores(self, tensor):
"""Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with
size [b, np, s, hn].
"""
new_tensor_shape = tensor.size()[:-1] + \
(-1, # flexible for multi-query
self.hidden_size_per_attention_head)
tensor = tensor.view(*new_tensor_shape)
return tensor.permute(0, 2, 1, 3)
def forward(self, hidden_states, cross_attention_mask, encoder_outputs, **kw_args):
# hidden_states: [b, s, h]
if 'cross_attention_forward' in self.hooks:
return self.hooks['cross_attention_forward'](hidden_states, cross_attention_mask, encoder_outputs, **kw_args)
else:
return HOOKS_DEFAULT['cross_attention_forward'](self, hidden_states, cross_attention_mask, encoder_outputs, **kw_args)
def repartition(self):
world_size = get_model_parallel_world_size()
self.num_attention_heads_per_partition = divide(self.num_attention_heads, world_size)
self.hidden_size_per_partition = self.hidden_size_per_attention_head * self.num_attention_heads_per_partition
class MLP(torch.nn.Module):
def __init__(self, hidden_size, output_dropout_prob, init_method, inner_hidden_size=None,
output_layer_init_method=None, layer_id=None, row_parallel_linear_final_bias=True, hooks={}, bias=True, activation_func=gelu, transformer_pointer=None, is_gated_mlp=False, num_experts=1,
params_dtype=torch.float, skip_init=False, device=torch.device('cpu')):
super(MLP, self).__init__()
self.layer_id = layer_id
self.activation_func = activation_func
# Set output layer initialization if not provided.
if output_layer_init_method is None:
output_layer_init_method = init_method
self.hooks = hooks
# Project to 4h.
self.hidden_size = hidden_size
if inner_hidden_size is None:
inner_hidden_size = 4 * hidden_size
self.inner_hidden_size = inner_hidden_size
self.dense_h_to_4h = ColumnParallelLinear(
self.hidden_size,
self.inner_hidden_size,
gather_output=False,
init_method=init_method,
bias=bias,
params_dtype=params_dtype,
module=self,
name="dense_h_to_4h",
skip_init=skip_init,
device=device
)
# Project back to h.
self.dense_4h_to_h = RowParallelLinear(
self.inner_hidden_size,
self.hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method,
bias=bias,
params_dtype=params_dtype,
module=self,
name="dense_4h_to_h",
skip_init=skip_init,
device=device,
final_bias=row_parallel_linear_final_bias
)
self.is_gated_mlp = is_gated_mlp
if is_gated_mlp:
self.dense_h_to_4h_gate = ColumnParallelLinear(
self.hidden_size,
self.inner_hidden_size,
gather_output=False,
init_method=init_method,
bias=False,
params_dtype=params_dtype,
module=self,
name="dense_h_to_4h_gate",
skip_init=skip_init,
device=device
)
self.num_experts = num_experts
for i in range(1, num_experts):
self.register_module(f"dense_h_to_4h_{i}", ColumnParallelLinear(
self.hidden_size,
self.inner_hidden_size,
gather_output=False,
init_method=init_method,
bias=bias,
params_dtype=params_dtype,
module=self,
name=f"dense_h_to_4h_{i}",
skip_init=skip_init,
device=device
))
# Project back to h.
self.register_module(f"dense_4h_to_h_{i}", RowParallelLinear(
self.inner_hidden_size,
self.hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method,
bias=bias,
params_dtype=params_dtype,
module=self,
name=f"dense_4h_to_h_{i}",
skip_init=skip_init,
device=device,
final_bias=row_parallel_linear_final_bias
))
if is_gated_mlp:
self.register_module(f"dense_h_to_4h_gate_{i}", ColumnParallelLinear(
self.hidden_size,
self.inner_hidden_size,
gather_output=False,
init_method=init_method,
bias=False,
params_dtype=params_dtype,
module=self,
name=f"dense_h_to_4h_gate_{i}",
skip_init=skip_init,
device=device
))
self.dropout = torch.nn.Dropout(output_dropout_prob)
object.__setattr__(self, 'transformer', transformer_pointer)
assert transformer_pointer is not None
def forward(self, hidden_states, **kw_args):
if 'mlp_forward' in self.hooks:
output = self.hooks['mlp_forward'](hidden_states, **kw_args)
else:
output = HOOKS_DEFAULT['mlp_forward'](self, hidden_states, **kw_args)
if self.training:
output = self.dropout(output)
return output
# Spatial LIEM
class SpatialAttention(nn.Module): # b c h w
def __init__(self):
super(SpatialAttention, self).__init__()
self.conv1 = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=7, padding=7 // 2, bias=False)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
max_out, _ = torch.max(x, dim=1, keepdim=True)
avg_out = torch.mean(x, dim=1, keepdim=True)
weight = torch.cat([max_out, avg_out], dim=1)
weight = self.conv1(weight)
out = self.sigmoid(weight) * x
return out
# Temporal LIEM
class TemporalLocalAttention(nn.Module): # b t c
def __init__(self):
super(TemporalLocalAttention, self).__init__()
self.conv1 = nn.Linear(in_features=2, out_features=1, bias=False)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
max_out, _ = torch.max(x, dim=-1, keepdim=True)
avg_out = torch.mean(x, dim=-1, keepdim=True)
weight = torch.cat([max_out, avg_out], dim=-1)
weight = self.conv1(weight)
out = self.sigmoid(weight) * x
return out
# Spatial-Temporal LIEM
class LocalAttention(nn.Module): # b c t h w
def __init__(self):
super(LocalAttention, self).__init__()
self.conv1 = nn.Conv3d(in_channels=2, out_channels=1, kernel_size=7, padding=7//2, bias=False)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
max_out, _ = torch.max(x, dim=1, keepdim=True)
avg_out = torch.mean(x, dim=1, keepdim=True)
weight = torch.cat([max_out, avg_out], dim=1)
weight = self.conv1(weight)
out = self.sigmoid(weight) * x
return out
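# Illustrative usage sketch for the LIEM blocks above (added comment, assumed shapes):
# SpatialAttention expects [b, c, h, w], TemporalLocalAttention expects [b, t, c], and
# LocalAttention expects [b, c, t, h, w]; each returns a tensor of the same shape, re-weighted
# by a sigmoid gate built from channel-wise max and mean statistics.
#   spa = SpatialAttention();        y = spa(torch.randn(1, 64, 32, 32))      # [1, 64, 32, 32]
#   tmp = TemporalLocalAttention();  y = tmp(torch.randn(1, 8, 64))           # [1, 8, 64]
#   loc = LocalAttention();          y = loc(torch.randn(1, 64, 8, 32, 32))   # [1, 64, 8, 32, 32]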
class BaseTransformerLayer(torch.nn.Module):
def __init__(
self,
hidden_size,
num_attention_heads,
attention_dropout_prob,
output_dropout_prob,
layernorm_epsilon,
init_method,
layer_id,
inner_hidden_size=None,
hidden_size_per_attention_head=None,
cross_hidden_size_per_attention_head=None,
output_layer_init_method=None,
layernorm_order='pre',
layernorm=LayerNorm,
is_decoder=False,
cross_attn_hidden_size=None,
use_bias=True,
use_qkv_bias=False,
num_multi_query_heads=0,
cross_num_multi_query_heads=0,
row_parallel_linear_final_bias=True,
drop_path=0,
activation_func=gelu,
is_gated_mlp=False,
num_experts=1,
hooks={},
transformer_pointer=None,
params_dtype=torch.float,
skip_init=False,
device=torch.device('cpu')
):
super(BaseTransformerLayer, self).__init__()
# Set output layer initialization if not provided.
if output_layer_init_method is None:
output_layer_init_method = init_method
self.layer_id = layer_id
self.is_decoder = is_decoder[layer_id] if type(is_decoder) is list else is_decoder
self.layernorm_order = layernorm_order
self.drop_path = drop_path
self.hooks = hooks
object.__setattr__(self, 'transformer', transformer_pointer)
assert transformer_pointer is not None
# Layernorm on the input data.
self.input_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)
# Self attention.
self.attention = SelfAttention(
hidden_size,
num_attention_heads,
attention_dropout_prob,
output_dropout_prob,
init_method,
layer_id,
hidden_size_per_attention_head=hidden_size_per_attention_head,
output_layer_init_method=output_layer_init_method,
bias=use_bias,
qkv_bias=use_qkv_bias,
num_multi_query_heads=num_multi_query_heads,
row_parallel_linear_final_bias=row_parallel_linear_final_bias,
hooks=hooks,
transformer_pointer=transformer_pointer,
params_dtype=params_dtype,
skip_init=skip_init,
device=device
)
# Layernorm on the attention output.
self.post_attention_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)
if self.layernorm_order == 'sandwich':
self.third_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)
self.fourth_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)
# Cross attention.
if self.is_decoder:
self.cross_attention = CrossAttention(
hidden_size,
num_attention_heads,
attention_dropout_prob,
output_dropout_prob,
init_method,
layer_id,
hidden_size_per_attention_head=cross_hidden_size_per_attention_head,
output_layer_init_method=output_layer_init_method,
cross_attn_hidden_size=cross_attn_hidden_size,
bias=use_bias,
cross_num_multi_query_heads=cross_num_multi_query_heads,
row_parallel_linear_final_bias=row_parallel_linear_final_bias,
hooks=hooks,
transformer_pointer=transformer_pointer,
params_dtype=params_dtype
)
self.post_cross_attention_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)
# MLP
self.mlp = MLP(
hidden_size,
output_dropout_prob,
init_method,
inner_hidden_size=inner_hidden_size,
output_layer_init_method=output_layer_init_method,
bias=use_bias,
layer_id=layer_id,
activation_func=activation_func,
row_parallel_linear_final_bias=row_parallel_linear_final_bias,
hooks=hooks,
transformer_pointer=transformer_pointer,
is_gated_mlp=is_gated_mlp,
num_experts=num_experts,
params_dtype=params_dtype,
skip_init=skip_init,
device=device
)
# Spatial & Temporal LIEM
self.spa_local = SpatialAttention()
self.temp_local = TemporalLocalAttention()
# self.liem = LocalAttention()
def forward(self, hidden_states, mask, *args, **kw_args):
return HOOKS_DEFAULT['layer_forward'](self, hidden_states, mask, *args, **kw_args)
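# Note on layernorm_order (added comment, based on the __init__ above): 'pre' applies
# input_layernorm / post_attention_layernorm before each sub-block, while 'sandwich'
# additionally creates third_layernorm and fourth_layernorm so the attention and MLP outputs
# can be normalized again before the residual addition performed inside the default
# HOOKS_DEFAULT['layer_forward'] hook (Sandwich-LN).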
class BaseTransformer(torch.nn.Module):
def __init__(self,
num_layers,
vocab_size,
hidden_size,
num_attention_heads,
max_sequence_length,
embedding_dropout_prob=0,
attention_dropout_prob=0,
output_dropout_prob=0,
drop_path=0,
checkpoint_activations=False,
checkpoint_num_layers=1,
checkpoint_skip_layers=0,
layernorm_epsilon=1.0e-5,
init_method_std=0.02,
inner_hidden_size=None,
hidden_size_per_attention_head=None,
cross_hidden_size_per_attention_head=None,
layernorm_order='pre',
parallel_output=False,
is_decoder=False,
cross_attn_hidden_size=None,
use_bias=True,
use_qkv_bias=False,
num_multi_query_heads=0,
cross_num_multi_query_heads=0,
row_parallel_linear_final_bias=True,
activation_func=gelu,
is_gated_mlp=False,
is_rotary_emb=False,
num_experts=1,
layernorm=LayerNorm,
init_method=None,
use_final_layernorm=True,
hooks={},
params_dtype=torch.float,
skip_init=False,
device=torch.device('cpu')
):
super(BaseTransformer, self).__init__()
# recording parameters
self.hidden_size = hidden_size
self.inner_hidden_size = inner_hidden_size
self.hidden_size_per_attention_head = hidden_size_per_attention_head
self.cross_hidden_size_per_attention_head = cross_hidden_size_per_attention_head
self.is_decoder = is_decoder
self.cross_attn_hidden_size = cross_attn_hidden_size
self.cross_num_multi_query_heads = cross_num_multi_query_heads
if not is_decoder and cross_attn_hidden_size is not None:
print('warning: cross_attn_hidden_size is set but is_decoder is False')
self.use_bias = use_bias
self.use_qkv_bias = use_qkv_bias
self.num_multi_query_heads = num_multi_query_heads
self.is_gated_mlp = is_gated_mlp
self.is_rotary_emb = is_rotary_emb
self.num_experts = num_experts
self.use_final_layernorm = use_final_layernorm
self.layernorm_epsilon = layernorm_epsilon
self.parallel_output = parallel_output
self.checkpoint_activations = checkpoint_activations
self.checkpoint_num_layers = checkpoint_num_layers
self.checkpoint_skip_layers = checkpoint_skip_layers
assert checkpoint_skip_layers <= num_layers - checkpoint_num_layers, 'checkpoint_skip_layers too large. Please consider removing checkpoint_activations.'
self.max_sequence_length = max_sequence_length
self.layernorm_order = layernorm_order
self.row_parallel_linear_final_bias = row_parallel_linear_final_bias
self.hooks = copy.copy(hooks) # hooks will be updated each forward
object.__setattr__(self, 'transformer', self) # to give the default hooks the same api as outer hooks
# create embedding parameters
self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
if vocab_size < 1000:
self.word_embeddings = torch.nn.Embedding(vocab_size, hidden_size, dtype=params_dtype, device=device)
torch.nn.init.normal_(self.word_embeddings.weight, mean=0.0, std=init_method_std)
else:
self.word_embeddings = VocabParallelEmbedding(
num_embeddings=vocab_size, embedding_dim=hidden_size,
params_dtype=params_dtype, skip_init=skip_init, device=device)
if self.is_rotary_emb:
from sat.model.position_embedding.triton_rotary_embeddings import FastRotaryEmbedding
self.position_embeddings = FastRotaryEmbedding(hidden_size // num_attention_heads)
else:
self.position_embeddings = torch.nn.Embedding(max_sequence_length, hidden_size)
torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std)
# create all layers
if init_method is None:
self.output_layer_init_method = scaled_init_method(init_method_std, num_layers)
self.init_method = unscaled_init_method(init_method_std)
else:
self.output_layer_init_method = init_method
self.init_method = init_method
def get_layer(layer_id):
return BaseTransformerLayer(
hidden_size,
num_attention_heads,
attention_dropout_prob,
output_dropout_prob,
layernorm_epsilon,
self.init_method,
layer_id,
inner_hidden_size=inner_hidden_size,
hidden_size_per_attention_head=hidden_size_per_attention_head,
cross_hidden_size_per_attention_head=cross_hidden_size_per_attention_head,
output_layer_init_method=self.output_layer_init_method,
is_decoder=self.is_decoder,
cross_attn_hidden_size=cross_attn_hidden_size,
layernorm_order=layernorm_order,
layernorm=layernorm,
use_bias=use_bias,
use_qkv_bias=use_qkv_bias,
num_multi_query_heads=num_multi_query_heads,
cross_num_multi_query_heads=cross_num_multi_query_heads,
row_parallel_linear_final_bias=row_parallel_linear_final_bias,
drop_path=drop_path,
activation_func=activation_func,
is_gated_mlp=is_gated_mlp,
num_experts=num_experts,
hooks=self.hooks,
transformer_pointer=self,
params_dtype=params_dtype,
skip_init=skip_init,
device=device
)
self.layers = torch.nn.ModuleList(
[get_layer(layer_id) for layer_id in range(num_layers)])
# Final layer norm before output.
if use_final_layernorm:
self.final_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)
def forward(self, input_ids, position_ids, attention_mask, *,
output_hidden_states=False, **kw_args):
# sanity check
assert len(input_ids.shape) >= 2
batch_size, query_length = input_ids.shape[:2]
if attention_mask is None:
# Definition: None means full attention
attention_mask = torch.ones(1, 1, device=input_ids.device)
elif isinstance(attention_mask, int) and (attention_mask < 0):
# Definition: -1 means lower triangular attention mask
attention_mask = torch.ones(query_length, query_length,
device=input_ids.device).tril()
attention_mask = attention_mask.type_as(
next(self.parameters())
)
assert len(attention_mask.shape) == 2 or \
len(attention_mask.shape) == 4 and attention_mask.shape[1] == 1
# initial output_cross_layer might be generated by word/position_embedding_forward
output_cross_layer = {}
# embedding part
if 'word_embedding_forward' in self.hooks:
hidden_states = self.hooks['word_embedding_forward'](input_ids, output_cross_layer=output_cross_layer, **kw_args)
else: # default
hidden_states = HOOKS_DEFAULT['word_embedding_forward'](self, input_ids, output_cross_layer=output_cross_layer,**kw_args)
# handle position embedding
if 'position_embedding_forward' in self.hooks:
position_embeddings = self.hooks['position_embedding_forward'](position_ids, output_cross_layer=output_cross_layer, **kw_args)
else:
assert len(position_ids.shape) <= 2
assert position_ids.shape[-1] == hidden_states.shape[1], (position_ids.shape, hidden_states.shape)
position_embeddings = HOOKS_DEFAULT['position_embedding_forward'](self, position_ids, output_cross_layer=output_cross_layer, **kw_args)
if position_embeddings is not None:
hidden_states = hidden_states + position_embeddings
hidden_states = self.embedding_dropout(hidden_states)
output_per_layers = []
if self.checkpoint_activations:
# define custom_forward for checkpointing
def custom(start, end, kw_args_index, cross_layer_index):
def custom_forward(*inputs):
layers_ = self.layers[start:end]
x_, mask = inputs[0], inputs[1]
# recover kw_args and output_cross_layer
flat_inputs = inputs[2:]
kw_args, output_cross_layer = {}, {}
for k, idx in kw_args_index.items():
kw_args[k] = flat_inputs[idx]
for k, idx in cross_layer_index.items():
output_cross_layer[k] = flat_inputs[idx]
# -----------------
output_per_layers_part = []
for i, layer in enumerate(layers_):
output_this_layer_obj, output_cross_layer_obj = {}, {}
if 'layer_forward' in self.hooks:
layer_ret = self.hooks['layer_forward'](
x_, mask, layer_id=layer.layer_id,
**kw_args, position_ids=position_ids, **output_cross_layer,
output_this_layer=output_this_layer_obj,
output_cross_layer=output_cross_layer_obj
)
else:
layer_ret = layer(
x_, mask, layer_id=layer.layer_id,
**kw_args, position_ids=position_ids, **output_cross_layer,
output_this_layer=output_this_layer_obj,
output_cross_layer=output_cross_layer_obj
)
if isinstance(layer_ret, tuple):
layer_ret = layer_ret[0] # for legacy API
x_, output_this_layer, output_cross_layer = layer_ret, output_this_layer_obj, output_cross_layer_obj
if output_hidden_states:
output_this_layer['hidden_states'] = x_
output_per_layers_part.append(output_this_layer)
# flatten so keyword outputs can be re-aggregated after checkpointing
flat_outputs = []
for output_this_layer in output_per_layers_part:
for k in output_this_layer:
# TODO add warning for depth>=2 grad tensors
flat_outputs.append(output_this_layer[k])
output_this_layer[k] = len(flat_outputs) - 1
for k in output_cross_layer:
flat_outputs.append(output_cross_layer[k])
output_cross_layer[k] = len(flat_outputs) - 1
# --------------------
return (x_, output_per_layers_part, output_cross_layer, *flat_outputs)
return custom_forward
# prevent losing requires_grad in checkpointing.
# To save memory when only finetuning the final layers, don't use checkpointing.
if self.training:
hidden_states.requires_grad_(True)
l, num_layers = 0, len(self.layers)
chunk_length = self.checkpoint_num_layers
output_this_layer = []
while l < num_layers:
args = [hidden_states, attention_mask]
# flatten kw_args and output_cross_layer
flat_inputs, kw_args_index, cross_layer_index = [], {}, {}
for k, v in kw_args.items():
flat_inputs.append(v)
kw_args_index[k] = len(flat_inputs) - 1
for k, v in output_cross_layer.items():
flat_inputs.append(v)
cross_layer_index[k] = len(flat_inputs) - 1
# --------------------
if l + self.checkpoint_skip_layers >= num_layers:
# no checkpointing
hidden_states, output_per_layers_part, output_cross_layer, *flat_outputs = \
custom(l, l + chunk_length, kw_args_index, cross_layer_index)(*args, *flat_inputs)
else:
hidden_states, output_per_layers_part, output_cross_layer, *flat_outputs = \
checkpoint(custom(l, l + chunk_length, kw_args_index, cross_layer_index), *args, *flat_inputs)
# recover output_per_layers_part, output_cross_layer
for output_this_layer in output_per_layers_part:
for k in output_this_layer:
output_this_layer[k] = flat_outputs[output_this_layer[k]]
for k in output_cross_layer:
output_cross_layer[k] = flat_outputs[output_cross_layer[k]]
# --------------------
output_per_layers.extend(output_per_layers_part)
l += chunk_length
else:
output_this_layer = []
for i, layer in enumerate(self.layers):
args = [hidden_states, attention_mask]
output_this_layer_obj, output_cross_layer_obj = {}, {}
if 'layer_forward' in self.hooks: # customized layer_forward
layer_ret = self.hooks['layer_forward'](*args,
layer_id=torch.tensor(i),
**kw_args,
position_ids=position_ids,
**output_cross_layer,
output_this_layer=output_this_layer_obj, output_cross_layer=output_cross_layer_obj
)
else:
layer_ret = layer(*args, layer_id=torch.tensor(i), **kw_args, position_ids=position_ids, **output_cross_layer,
output_this_layer=output_this_layer_obj, output_cross_layer=output_cross_layer_obj)
if isinstance(layer_ret, tuple):
layer_ret = layer_ret[0] # for legacy API
hidden_states, output_this_layer, output_cross_layer = layer_ret, output_this_layer_obj, output_cross_layer_obj
if output_hidden_states:
output_this_layer['hidden_states'] = hidden_states
output_per_layers.append(output_this_layer)
# Final layer norm.
if self.use_final_layernorm:
logits = self.final_layernorm(hidden_states)
else:
logits = hidden_states
logits = copy_to_model_parallel_region(logits)
if 'final_forward' in self.hooks:
logits_parallel = self.hooks['final_forward'](logits, **kw_args, parallel_output=self.parallel_output)
else:
logits_parallel = HOOKS_DEFAULT['final_forward'](self, logits, **kw_args, parallel_output=self.parallel_output)
outputs = [logits_parallel]
outputs.extend(output_per_layers)
return outputs
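# Attention-mask conventions in BaseTransformer.forward (added comment, illustrative only):
# passing None selects full attention (a [1, 1] mask of ones), while passing a negative int
# selects a lower-triangular (causal) mask. A standalone sketch of the causal case:
#   q_len = 4
#   mask = torch.ones(q_len, q_len).tril()   # 2D mask, broadcast inside the layers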
import os
os.environ['CURL_CA_BUNDLE'] = ''
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
from huggingface_hub import hf_hub_download, snapshot_download
snapshot_download(repo_id="SherryX/STAR", local_dir='pretrained_weight')
import os
import subprocess
import tempfile
import cv2
import torch
from PIL import Image
from typing import Mapping
from einops import rearrange
import numpy as np
import torchvision.transforms.functional as transforms_F
from video_to_video.utils.logger import get_logger
logger = get_logger()
def tensor2vid(video, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]):
mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1)
std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1)
video = video.mul_(std).add_(mean)
video.clamp_(0, 1)
video = video * 255.0
images = rearrange(video, 'b c f h w -> b f h w c')[0]
return images
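# Illustrative check (added comment, assumed sizes): tensor2vid de-normalizes a [b, c, f, h, w]
# tensor with the given mean/std, scales it to [0, 255], and returns the first sample
# rearranged to frame-major [f, h, w, c]. Note that mul_/add_ modify the input in place.
#   vid = torch.rand(1, 3, 8, 64, 64) * 2 - 1      # values in [-1, 1]
#   frames = tensor2vid(vid)                       # shape [8, 64, 64, 3], values in [0, 255]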
def preprocess(input_frames):
out_frame_list = []
for pointer in range(len(input_frames)):
frame = input_frames[pointer]
frame = frame[:, :, ::-1]
frame = Image.fromarray(frame.astype('uint8')).convert('RGB')
frame = transforms_F.to_tensor(frame)
out_frame_list.append(frame)
out_frames = torch.stack(out_frame_list, dim=0)
out_frames.clamp_(0, 1)
mean = out_frames.new_tensor([0.5, 0.5, 0.5]).view(-1)
std = out_frames.new_tensor([0.5, 0.5, 0.5]).view(-1)
out_frames.sub_(mean.view(1, -1, 1, 1)).div_(std.view(1, -1, 1, 1))
return out_frames
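# Illustrative note (added comment): preprocess takes a list of BGR uint8 frames (as returned by
# load_video / cv2), flips them to RGB, stacks them into an [n, 3, h, w] float tensor and
# normalizes with mean=std=0.5, i.e. pixel values end up in [-1, 1]:
#   frames = [np.zeros((64, 64, 3), dtype=np.uint8)] * 4
#   batch = preprocess(frames)                     # shape [4, 3, 64, 64], values in [-1, 1]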
def adjust_resolution(h, w, up_scale):
if h*up_scale < 720:
up_s = 720/h
target_h = int(up_s*h//2*2)
target_w = int(up_s*w//2*2)
elif h*w*up_scale*up_scale > 1280*2048:
up_s = np.sqrt(1280*2048/(h*w))
target_h = int(up_s*h//2*2)
target_w = int(up_s*w//2*2)
else:
target_h = int(up_scale*h//2*2)
target_w = int(up_scale*w//2*2)
return (target_h, target_w)
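# Worked example (added comment): adjust_resolution scales (h, w) by up_scale, but raises the
# scale so the output height reaches 720 when it would fall short, caps the total pixel count
# at roughly 1280*2048, and rounds both sides down to even values.
#   adjust_resolution(128, 256, 4) -> (720, 1440)   # 128*4 < 720, so the scale is raised
#   adjust_resolution(360, 640, 2) -> (720, 1280)   # regular case: plain up_scale, kept even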
def make_mask_cond(in_f_num, interp_f_num):
mask_cond = []
interp_cond = [-1 for _ in range(interp_f_num)]
for i in range(in_f_num):
mask_cond.append(i)
if i != in_f_num - 1:
mask_cond += interp_cond
return mask_cond
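# Illustrative example (added comment): make_mask_cond interleaves interp_f_num placeholder
# entries (-1) between consecutive input-frame indices, e.g.
#   make_mask_cond(3, 2) -> [0, -1, -1, 1, -1, -1, 2]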
def load_video(vid_path):
capture = cv2.VideoCapture(vid_path)
_fps = capture.get(cv2.CAP_PROP_FPS)
_total_frame_num = capture.get(cv2.CAP_PROP_FRAME_COUNT)
pointer = 0
frame_list = []
stride = 1
while len(frame_list) < _total_frame_num:
ret, frame = capture.read()
pointer += 1
if (not ret) or (frame is None):
break
if pointer >= _total_frame_num + 1:
break
if pointer % stride == 0:
frame_list.append(frame)
capture.release()
return frame_list, _fps
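# Illustrative note (added comment, hypothetical path): load_video returns the decoded frames
# as a list of BGR uint8 arrays plus the container's FPS, so it pairs directly with
# preprocess() above:
#   frames, fps = load_video('input.mp4')
#   batch = preprocess(frames)                     # [n, 3, h, w] in [-1, 1]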
def save_video(video, save_dir, file_name, fps=16.0):
output_path = os.path.join(save_dir, file_name)
images = [(img.numpy()).astype('uint8') for img in video]
temp_dir = tempfile.mkdtemp()
for fid, frame in enumerate(images):
tpth = os.path.join(temp_dir, '%06d.png' % (fid + 1))
cv2.imwrite(tpth, frame[:, :, ::-1])
tmp_path = os.path.join(save_dir, 'tmp.mp4')
cmd = f'ffmpeg -y -f image2 -framerate {fps} -i {temp_dir}/%06d.png \
-vcodec libx264 -preset ultrafast -crf 0 -pix_fmt yuv420p {tmp_path}'
status, output = subprocess.getstatusoutput(cmd)
if status != 0:
logger.error('Error saving video: {}'.format(output))
os.system(f'rm -rf {temp_dir}')
os.rename(tmp_path, output_path)
def collate_fn(data, device):
"""Prepare the input just before the forward function.
This method will move the tensors to the right device.
Usually this method does not need to be overridden.
Args:
data: The data out of the dataloader.
device: The device to move data to.
Returns: The processed data.
"""
from torch.utils.data.dataloader import default_collate
def get_class_name(obj):
return obj.__class__.__name__
if isinstance(data, dict) or isinstance(data, Mapping):
return type(data)({
k: collate_fn(v, device) if k != 'img_metas' else v
for k, v in data.items()
})
elif isinstance(data, (tuple, list)):
if 0 == len(data):
return torch.Tensor([])
if isinstance(data[0], (int, float)):
return default_collate(data).to(device)
else:
return type(data)(collate_fn(v, device) for v in data)
elif isinstance(data, np.ndarray):
if data.dtype.type is np.str_:
return data
else:
return collate_fn(torch.from_numpy(data), device)
elif isinstance(data, torch.Tensor):
return data.to(device)
elif isinstance(data, (bytes, str, int, float, bool, type(None))):
return data
else:
raise ValueError(f'Unsupported data type {type(data)}')
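# Hedged usage sketch (added comment, assumed keys): collate_fn recursively moves tensors in
# nested dicts/lists/tuples to the target device, converts numeric numpy arrays to tensors,
# and passes strings and anything under the 'img_metas' key through unchanged:
#   batch = {'video': np.zeros((3, 8, 64, 64), dtype=np.float32), 'img_metas': ['clip_0001']}
#   batch = collate_fn(batch, torch.device('cpu'))  # 'video' becomes a torch.Tensor on CPU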