Commit 09d54a38 authored by luopl's avatar luopl
Browse files

"Initial commit"

parents
Pipeline #2695 failed with stages
in 0 seconds
# Copyright 2025 StepFun Inc. All Rights Reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# ==============================================================================
import torch.nn as nn
import torch
from typing import List
class LLaMaEmbedding(nn.Module):
"""Language model embeddings.
Arguments:
hidden_size: hidden size
vocab_size: vocabulary size
max_sequence_length: maximum size of sequence. This
is used for positional embedding
embedding_dropout_prob: dropout probability for embeddings
init_method: weight initialization method
num_tokentypes: size of the token-type embeddings. 0 value
will ignore this embedding
"""
def __init__(self,
cfg,
):
super().__init__()
self.hidden_size = cfg.hidden_size
self.params_dtype = cfg.params_dtype
self.fp32_residual_connection = cfg.fp32_residual_connection
self.embedding_weights_in_fp32 = cfg.embedding_weights_in_fp32
self.word_embeddings = torch.nn.Embedding(
cfg.padded_vocab_size, self.hidden_size,
)
self.embedding_dropout = torch.nn.Dropout(cfg.hidden_dropout)
def forward(self, input_ids):
# Embeddings.
if self.embedding_weights_in_fp32:
self.word_embeddings = self.word_embeddings.to(torch.float32)
embeddings = self.word_embeddings(input_ids)
if self.embedding_weights_in_fp32:
embeddings = embeddings.to(self.params_dtype)
self.word_embeddings = self.word_embeddings.to(self.params_dtype)
# Data format change to avoid explicit tranposes : [b s h] --> [s b h].
embeddings = embeddings.transpose(0, 1).contiguous()
# If the input flag for fp32 residual connection is set, convert for float.
if self.fp32_residual_connection:
embeddings = embeddings.float()
# Dropout.
embeddings = self.embedding_dropout(embeddings)
return embeddings
class StepChatTokenizer:
"""Step Chat Tokenizer"""
def __init__(
self, model_file, name="StepChatTokenizer",
bot_token="<|BOT|>", # Begin of Turn
eot_token="<|EOT|>", # End of Turn
call_start_token="<|CALL_START|>", # Call Start
call_end_token="<|CALL_END|>", # Call End
think_start_token="<|THINK_START|>", # Think Start
think_end_token="<|THINK_END|>", # Think End
mask_start_token="<|MASK_1e69f|>", # Mask start
mask_end_token="<|UNMASK_1e69f|>", # Mask end
):
import sentencepiece
self._tokenizer = sentencepiece.SentencePieceProcessor(model_file=model_file)
self._vocab = {}
self._inv_vocab = {}
self._special_tokens = {}
self._inv_special_tokens = {}
self._t5_tokens = []
for idx in range(self._tokenizer.get_piece_size()):
text = self._tokenizer.id_to_piece(idx)
self._inv_vocab[idx] = text
self._vocab[text] = idx
if self._tokenizer.is_control(idx) or self._tokenizer.is_unknown(idx):
self._special_tokens[text] = idx
self._inv_special_tokens[idx] = text
self._unk_id = self._tokenizer.unk_id()
self._bos_id = self._tokenizer.bos_id()
self._eos_id = self._tokenizer.eos_id()
for token in [
bot_token, eot_token, call_start_token, call_end_token,
think_start_token, think_end_token
]:
assert token in self._vocab, f"Token '{token}' not found in tokenizer"
assert token in self._special_tokens, f"Token '{token}' is not a special token"
for token in [mask_start_token, mask_end_token]:
assert token in self._vocab, f"Token '{token}' not found in tokenizer"
self._bot_id = self._tokenizer.piece_to_id(bot_token)
self._eot_id = self._tokenizer.piece_to_id(eot_token)
self._call_start_id = self._tokenizer.piece_to_id(call_start_token)
self._call_end_id = self._tokenizer.piece_to_id(call_end_token)
self._think_start_id = self._tokenizer.piece_to_id(think_start_token)
self._think_end_id = self._tokenizer.piece_to_id(think_end_token)
self._mask_start_id = self._tokenizer.piece_to_id(mask_start_token)
self._mask_end_id = self._tokenizer.piece_to_id(mask_end_token)
self._underline_id = self._tokenizer.piece_to_id("\u2581")
@property
def vocab(self):
return self._vocab
@property
def inv_vocab(self):
return self._inv_vocab
@property
def vocab_size(self):
return self._tokenizer.vocab_size()
def tokenize(self, text: str) -> List[int]:
return self._tokenizer.encode_as_ids(text)
def detokenize(self, token_ids: List[int]) -> str:
return self._tokenizer.decode_ids(token_ids)
class Tokens:
def __init__(self, input_ids, cu_input_ids, attention_mask, cu_seqlens, max_seq_len) -> None:
self.input_ids = input_ids
self.attention_mask = attention_mask
self.cu_input_ids = cu_input_ids
self.cu_seqlens = cu_seqlens
self.max_seq_len = max_seq_len
def to(self, device):
self.input_ids = self.input_ids.to(device)
self.attention_mask = self.attention_mask.to(device)
self.cu_input_ids = self.cu_input_ids.to(device)
self.cu_seqlens = self.cu_seqlens.to(device)
return self
class Wrapped_StepChatTokenizer(StepChatTokenizer):
def __call__(self, text, max_length=320, padding="max_length", truncation=True, return_tensors="pt"):
# [bos, ..., eos, pad, pad, ..., pad]
self.BOS = 1
self.EOS = 2
self.PAD = 2
out_tokens = []
attn_mask = []
if len(text) == 0:
part_tokens = [self.BOS] + [self.EOS]
valid_size = len(part_tokens)
if len(part_tokens) < max_length:
part_tokens += [self.PAD] * (max_length - valid_size)
out_tokens.append(part_tokens)
attn_mask.append([1]*valid_size+[0]*(max_length-valid_size))
else:
for part in text:
part_tokens = self.tokenize(part)
part_tokens = part_tokens[:(max_length - 2)] # leave 2 space for bos and eos
part_tokens = [self.BOS] + part_tokens + [self.EOS]
valid_size = len(part_tokens)
if len(part_tokens) < max_length:
part_tokens += [self.PAD] * (max_length - valid_size)
out_tokens.append(part_tokens)
attn_mask.append([1]*valid_size+[0]*(max_length-valid_size))
out_tokens = torch.tensor(out_tokens, dtype=torch.long)
attn_mask = torch.tensor(attn_mask, dtype=torch.long)
# padding y based on tp size
padded_len = 0
padded_flag = True if padded_len > 0 else False
if padded_flag:
pad_tokens = torch.tensor([[self.PAD] * max_length], device=out_tokens.device)
pad_attn_mask = torch.tensor([[1]*padded_len+[0]*(max_length-padded_len)], device=attn_mask.device)
out_tokens = torch.cat([out_tokens, pad_tokens], dim=0)
attn_mask = torch.cat([attn_mask, pad_attn_mask], dim=0)
# cu_seqlens
cu_out_tokens = out_tokens.masked_select(attn_mask != 0).unsqueeze(0)
seqlen = attn_mask.sum(dim=1).tolist()
cu_seqlens = torch.cumsum(torch.tensor([0]+seqlen), 0).to(device=out_tokens.device,dtype=torch.int32)
max_seq_len = max(seqlen)
return Tokens(out_tokens, cu_out_tokens, attn_mask, cu_seqlens, max_seq_len)
\ No newline at end of file
from .utils import *
from .video_process import *
\ No newline at end of file
import numpy as np
import random
import torch
from functools import wraps
import torch.utils._device
def setup_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
class EmptyInitOnDevice(torch.overrides.TorchFunctionMode):
def __init__(self, device=None):
self.device = device
def __torch_function__(self, func, types, args=(), kwargs=None):
kwargs = kwargs or {}
if getattr(func, '__module__', None) == 'torch.nn.init':
if 'tensor' in kwargs:
return kwargs['tensor']
else:
return args[0]
if self.device is not None and func in torch.utils._device._device_constructors() and kwargs.get('device') is None:
kwargs['device'] = self.device
return func(*args, **kwargs)
def with_empty_init(func):
@wraps(func)
def wrapper(*args, **kwargs):
with EmptyInitOnDevice('cpu'):
return func(*args, **kwargs)
return wrapper
def culens2mask(
cu_seqlens=None,
cu_seqlens_kv=None,
max_seqlen=None,
max_seqlen_kv=None,
is_causal=False
):
assert len(cu_seqlens) == len(cu_seqlens_kv); "q k v should have same bsz..."
bsz = len(cu_seqlens) - 1
seqlens = cu_seqlens[1:]-cu_seqlens[:-1]
seqlens_kv = cu_seqlens_kv[1:]-cu_seqlens_kv[:-1]
attn_mask = torch.zeros(bsz, max_seqlen, max_seqlen_kv, dtype=torch.bool)
for i, (seq_len, seq_len_kv) in enumerate(zip(seqlens, seqlens_kv)):
if is_causal:
attn_mask[i, :seq_len, :seq_len_kv] = torch.triu(torch.ones(seq_len, seq_len_kv), diagonal=1).bool()
else:
attn_mask[i, :seq_len, :seq_len_kv] = torch.ones([seq_len, seq_len_kv], dtype=torch.bool)
return attn_mask
import numpy as np
import datetime
import torch
import os
import imageio
class VideoProcessor:
def __init__(self, save_path: str='./results', name_suffix: str=''):
self.save_path = save_path
os.makedirs(self.save_path, exist_ok=True)
self.name_suffix = name_suffix
def crop2standard540p(self, vid_array):
_, height, width, _ = vid_array.shape
height_center = height//2
width_center = width//2
if width_center>height_center: ## horizon mode
return vid_array[:, height_center-270:height_center+270, width_center-480:width_center+480]
elif width_center<height_center: ## portrait mode
return vid_array[:, height_center-480:height_center+480, width_center-270:width_center+270]
else:
return vid_array
def save_imageio_video(self, video_array: np.array, output_filename: str, fps=25, codec='libx264'):
ffmpeg_params = [
"-vf", "atadenoise=0a=0.1:0b=0.1:1a=0.1:1b=0.1", # denoise
]
with imageio.get_writer(output_filename, fps=fps, codec=codec, ffmpeg_params=ffmpeg_params) as vid_writer:
for img_array in video_array:
vid_writer.append_data(img_array)
def postprocess_video(self, video_tensor, output_file_name='', output_type="mp4", crop2standard540p=True):
if len(self.name_suffix) == 0:
video_path = os.path.join(self.save_path, f"{output_file_name}-{str(datetime.datetime.now())}.{output_type}")
else:
video_path = os.path.join(self.save_path, f"{output_file_name}-{self.name_suffix}.{output_type}")
video_tensor = (video_tensor.cpu().clamp(-1, 1)+1)*127.5
video_tensor = torch.cat([t for t in video_tensor], dim=-2)
video_array = video_tensor.clamp(0, 255).to(torch.uint8).numpy().transpose(0,2,3,1)
if crop2standard540p:
video_array = self.crop2standard540p(video_array)
self.save_imageio_video(video_array, video_path)
print(f"Saved the generated video in {video_path}")
\ No newline at end of file
# Copyright 2025 StepFun Inc. All Rights Reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# ==============================================================================
import torch
from einops import rearrange
from torch import nn
from torch.nn import functional as F
from stepvideo.utils import with_empty_init
def base_group_norm(x, norm_layer, act_silu=False, channel_last=False):
if hasattr(base_group_norm, 'spatial') and base_group_norm.spatial:
assert channel_last == True
x_shape = x.shape
x = x.flatten(0, 1)
if channel_last:
# Permute to NCHW format
x = x.permute(0, 3, 1, 2)
out = F.group_norm(x.contiguous(), norm_layer.num_groups, norm_layer.weight, norm_layer.bias, norm_layer.eps)
if act_silu:
out = F.silu(out)
if channel_last:
# Permute back to NHWC format
out = out.permute(0, 2, 3, 1)
out = out.view(x_shape)
else:
if channel_last:
# Permute to NCHW format
x = x.permute(0, 4, 1, 2, 3)
out = F.group_norm(x.contiguous(), norm_layer.num_groups, norm_layer.weight, norm_layer.bias, norm_layer.eps)
if act_silu:
out = F.silu(out)
if channel_last:
# Permute back to NHWC format
out = out.permute(0, 2, 3, 4, 1)
return out
def base_conv2d(x, conv_layer, channel_last=False, residual=None):
if channel_last:
x = x.permute(0, 3, 1, 2) # NHWC to NCHW
out = F.conv2d(x, conv_layer.weight, conv_layer.bias, stride=conv_layer.stride, padding=conv_layer.padding)
if residual is not None:
if channel_last:
residual = residual.permute(0, 3, 1, 2) # NHWC to NCHW
out += residual
if channel_last:
out = out.permute(0, 2, 3, 1) # NCHW to NHWC
return out
def base_conv3d(x, conv_layer, channel_last=False, residual=None, only_return_output=False):
if only_return_output:
size = cal_outsize(x.shape, conv_layer.weight.shape, conv_layer.stride, conv_layer.padding)
return torch.empty(size, device=x.device, dtype=x.dtype)
if channel_last:
x = x.permute(0, 4, 1, 2, 3) # NDHWC to NCDHW
out = F.conv3d(x, conv_layer.weight, conv_layer.bias, stride=conv_layer.stride, padding=conv_layer.padding)
if residual is not None:
if channel_last:
residual = residual.permute(0, 4, 1, 2, 3) # NDHWC to NCDHW
out += residual
if channel_last:
out = out.permute(0, 2, 3, 4, 1) # NCDHW to NDHWC
return out
def cal_outsize(input_sizes, kernel_sizes, stride, padding):
stride_d, stride_h, stride_w = stride
padding_d, padding_h, padding_w = padding
dilation_d, dilation_h, dilation_w = 1, 1, 1
in_d = input_sizes[1]
in_h = input_sizes[2]
in_w = input_sizes[3]
in_channel = input_sizes[4]
kernel_d = kernel_sizes[2]
kernel_h = kernel_sizes[3]
kernel_w = kernel_sizes[4]
out_channels = kernel_sizes[0]
out_d = calc_out_(in_d, padding_d, dilation_d, kernel_d, stride_d)
out_h = calc_out_(in_h, padding_h, dilation_h, kernel_h, stride_h)
out_w = calc_out_(in_w, padding_w, dilation_w, kernel_w, stride_w)
size = [input_sizes[0], out_d, out_h, out_w, out_channels]
return size
def calc_out_(in_size, padding, dilation, kernel, stride):
return (in_size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1
def base_conv3d_channel_last(x, conv_layer, residual=None):
in_numel = x.numel()
out_numel = int(x.numel() * conv_layer.out_channels / conv_layer.in_channels)
if (in_numel >= 2**30) or (out_numel >= 2**30):
assert conv_layer.stride[0] == 1, "time split asks time stride = 1"
B,T,H,W,C = x.shape
K = conv_layer.kernel_size[0]
chunks = 4
chunk_size = T // chunks
if residual is None:
out_nhwc = base_conv3d(x, conv_layer, channel_last=True, residual=residual, only_return_output=True)
else:
out_nhwc = residual
assert B == 1
outs = []
for i in range(chunks):
if i == chunks-1:
xi = x[:1,chunk_size*i:]
out_nhwci = out_nhwc[:1,chunk_size*i:]
else:
xi = x[:1,chunk_size*i:chunk_size*(i+1)+K-1]
out_nhwci = out_nhwc[:1,chunk_size*i:chunk_size*(i+1)]
if residual is not None:
if i == chunks-1:
ri = residual[:1,chunk_size*i:]
else:
ri = residual[:1,chunk_size*i:chunk_size*(i+1)]
else:
ri = None
out_nhwci.copy_(base_conv3d(xi, conv_layer, channel_last=True, residual=ri))
else:
out_nhwc = base_conv3d(x, conv_layer, channel_last=True, residual=residual)
return out_nhwc
class Upsample2D(nn.Module):
def __init__(self,
channels,
use_conv=False,
use_conv_transpose=False,
out_channels=None):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.use_conv_transpose = use_conv_transpose
if use_conv:
self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)
else:
assert "Not Supported"
self.conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
def forward(self, x, output_size=None):
assert x.shape[-1] == self.channels
if self.use_conv_transpose:
return self.conv(x)
if output_size is None:
x = F.interpolate(
x.permute(0,3,1,2).to(memory_format=torch.channels_last),
scale_factor=2.0, mode='nearest').permute(0,2,3,1).contiguous()
else:
x = F.interpolate(
x.permute(0,3,1,2).to(memory_format=torch.channels_last),
size=output_size, mode='nearest').permute(0,2,3,1).contiguous()
# x = self.conv(x)
x = base_conv2d(x, self.conv, channel_last=True)
return x
class Downsample2D(nn.Module):
def __init__(self, channels, use_conv=False, out_channels=None, padding=1):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.padding = padding
stride = 2
if use_conv:
self.conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
else:
assert self.channels == self.out_channels
self.conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
def forward(self, x):
assert x.shape[-1] == self.channels
if self.use_conv and self.padding == 0:
pad = (0, 0, 0, 1, 0, 1)
x = F.pad(x, pad, mode="constant", value=0)
assert x.shape[-1] == self.channels
# x = self.conv(x)
x = base_conv2d(x, self.conv, channel_last=True)
return x
class CausalConv(nn.Module):
def __init__(self,
chan_in,
chan_out,
kernel_size,
**kwargs
):
super().__init__()
if isinstance(kernel_size, int):
kernel_size = kernel_size if isinstance(kernel_size, tuple) else ((kernel_size,) * 3)
time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
self.dilation = kwargs.pop('dilation', 1)
self.stride = kwargs.pop('stride', 1)
if isinstance(self.stride, int):
self.stride = (self.stride, 1, 1)
time_pad = self.dilation * (time_kernel_size - 1) + max((1 - self.stride[0]), 0)
height_pad = height_kernel_size // 2
width_pad = width_kernel_size // 2
self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)
self.time_uncausal_padding = (width_pad, width_pad, height_pad, height_pad, 0, 0)
self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=self.stride, dilation=self.dilation, **kwargs)
self.is_first_run = True
def forward(self, x, is_init=True, residual=None):
x = nn.functional.pad(x,
self.time_causal_padding if is_init else self.time_uncausal_padding)
x = self.conv(x)
if residual is not None:
x.add_(residual)
return x
class ChannelDuplicatingPixelUnshuffleUpSampleLayer3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
factor: int,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.factor = factor
assert out_channels * factor**3 % in_channels == 0
self.repeats = out_channels * factor**3 // in_channels
def forward(self, x: torch.Tensor, is_init=True) -> torch.Tensor:
x = x.repeat_interleave(self.repeats, dim=1)
x = x.view(x.size(0), self.out_channels, self.factor, self.factor, self.factor, x.size(2), x.size(3), x.size(4))
x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
x = x.view(x.size(0), self.out_channels, x.size(2)*self.factor, x.size(4)*self.factor, x.size(6)*self.factor)
x = x[:, :, self.factor - 1:, :, :]
return x
class ConvPixelShuffleUpSampleLayer3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
factor: int,
):
super().__init__()
self.factor = factor
out_ratio = factor**3
self.conv = CausalConv(
in_channels,
out_channels * out_ratio,
kernel_size=kernel_size
)
def forward(self, x: torch.Tensor, is_init=True) -> torch.Tensor:
x = self.conv(x, is_init)
x = self.pixel_shuffle_3d(x, self.factor)
return x
@staticmethod
def pixel_shuffle_3d(x: torch.Tensor, factor: int) -> torch.Tensor:
batch_size, channels, depth, height, width = x.size()
new_channels = channels // (factor ** 3)
new_depth = depth * factor
new_height = height * factor
new_width = width * factor
x = x.view(batch_size, new_channels, factor, factor, factor, depth, height, width)
x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
x = x.view(batch_size, new_channels, new_depth, new_height, new_width)
x = x[:, :, factor - 1:, :, :]
return x
class ConvPixelUnshuffleDownSampleLayer3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
factor: int,
):
super().__init__()
self.factor = factor
out_ratio = factor**3
assert out_channels % out_ratio == 0
self.conv = CausalConv(
in_channels,
out_channels // out_ratio,
kernel_size=kernel_size
)
def forward(self, x: torch.Tensor, is_init=True) -> torch.Tensor:
x = self.conv(x, is_init)
x = self.pixel_unshuffle_3d(x, self.factor)
return x
@staticmethod
def pixel_unshuffle_3d(x: torch.Tensor, factor: int) -> torch.Tensor:
pad = (0, 0, 0, 0, factor-1, 0) # (left, right, top, bottom, front, back)
x = F.pad(x, pad)
B, C, D, H, W = x.shape
x = x.view(B, C, D // factor, factor, H // factor, factor, W // factor, factor)
x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
x = x.view(B, C * factor**3, D // factor, H // factor, W // factor)
return x
class PixelUnshuffleChannelAveragingDownSampleLayer3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
factor: int,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.factor = factor
assert in_channels * factor**3 % out_channels == 0
self.group_size = in_channels * factor**3 // out_channels
def forward(self, x: torch.Tensor, is_init=True) -> torch.Tensor:
pad = (0, 0, 0, 0, self.factor-1, 0) # (left, right, top, bottom, front, back)
x = F.pad(x, pad)
B, C, D, H, W = x.shape
x = x.view(B, C, D // self.factor, self.factor, H // self.factor, self.factor, W // self.factor, self.factor)
x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
x = x.view(B, C * self.factor**3, D // self.factor, H // self.factor, W // self.factor)
x = x.view(B, self.out_channels, self.group_size, D // self.factor, H // self.factor, W // self.factor)
x = x.mean(dim=2)
return x
def __init__(
self,
in_channels: int,
out_channels: int,
factor: int,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.factor = factor
assert in_channels * factor**3 % out_channels == 0
self.group_size = in_channels * factor**3 // out_channels
def forward(self, x: torch.Tensor, is_init=True) -> torch.Tensor:
pad = (0, 0, 0, 0, self.factor-1, 0) # (left, right, top, bottom, front, back)
x = F.pad(x, pad)
B, C, D, H, W = x.shape
x = x.view(B, C, D // self.factor, self.factor, H // self.factor, self.factor, W // self.factor, self.factor)
x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
x = x.view(B, C * self.factor**3, D // self.factor, H // self.factor, W // self.factor)
x = x.view(B, self.out_channels, self.group_size, D // self.factor, H // self.factor, W // self.factor)
x = x.mean(dim=2)
return x
def base_group_norm_with_zero_pad(x, norm_layer, act_silu=True, pad_size=2):
out_shape = list(x.shape)
out_shape[1] += pad_size
out = torch.empty(out_shape, dtype=x.dtype, device=x.device)
out[:, pad_size:] = base_group_norm(x, norm_layer, act_silu=act_silu, channel_last=True)
out[:, :pad_size] = 0
return out
class CausalConvChannelLast(CausalConv):
def __init__(self,
chan_in,
chan_out,
kernel_size,
**kwargs
):
super().__init__(
chan_in, chan_out, kernel_size, **kwargs)
self.time_causal_padding = (0, 0) + self.time_causal_padding
self.time_uncausal_padding = (0, 0) + self.time_uncausal_padding
def forward(self, x, is_init=True, residual=None):
if self.is_first_run:
self.is_first_run = False
# self.conv.weight = nn.Parameter(self.conv.weight.permute(0,2,3,4,1).contiguous())
x = nn.functional.pad(x,
self.time_causal_padding if is_init else self.time_uncausal_padding)
x = base_conv3d_channel_last(x, self.conv, residual=residual)
return x
class CausalConvAfterNorm(CausalConv):
def __init__(self,
chan_in,
chan_out,
kernel_size,
**kwargs
):
super().__init__(
chan_in, chan_out, kernel_size, **kwargs)
if self.time_causal_padding == (1, 1, 1, 1, 2, 0):
self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=self.stride, dilation=self.dilation, padding=(0, 1, 1), **kwargs)
else:
self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=self.stride, dilation=self.dilation, **kwargs)
self.is_first_run = True
def forward(self, x, is_init=True, residual=None):
if self.is_first_run:
self.is_first_run = False
if self.time_causal_padding == (1, 1, 1, 1, 2, 0):
pass
else:
x = nn.functional.pad(x, self.time_causal_padding).contiguous()
x = base_conv3d_channel_last(x, self.conv, residual=residual)
return x
class AttnBlock(nn.Module):
def __init__(self,
in_channels
):
super().__init__()
self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels)
self.q = CausalConvChannelLast(in_channels, in_channels, kernel_size=1)
self.k = CausalConvChannelLast(in_channels, in_channels, kernel_size=1)
self.v = CausalConvChannelLast(in_channels, in_channels, kernel_size=1)
self.proj_out = CausalConvChannelLast(in_channels, in_channels, kernel_size=1)
def attention(self, x, is_init=True):
x = base_group_norm(x, self.norm, act_silu=False, channel_last=True)
q = self.q(x, is_init)
k = self.k(x, is_init)
v = self.v(x, is_init)
b, t, h, w, c = q.shape
q, k, v = map(lambda x: rearrange(x, "b t h w c -> b 1 (t h w) c"), (q, k, v))
x = nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
x = rearrange(x, "b 1 (t h w) c -> b t h w c", t=t, h=h, w=w)
return x
def forward(self, x):
x = x.permute(0,2,3,4,1).contiguous()
h = self.attention(x)
x = self.proj_out(h, residual=x)
x = x.permute(0,4,1,2,3)
return x
class Resnet3DBlock(nn.Module):
def __init__(self,
in_channels,
out_channels=None,
temb_channels=512,
conv_shortcut=False,
):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels)
self.conv1 = CausalConvAfterNorm(in_channels, out_channels, kernel_size=3)
if temb_channels > 0:
self.temb_proj = nn.Linear(temb_channels, out_channels)
self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels)
self.conv2 = CausalConvAfterNorm(out_channels, out_channels, kernel_size=3)
assert conv_shortcut is False
self.use_conv_shortcut = conv_shortcut
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = CausalConvAfterNorm(in_channels, out_channels, kernel_size=3)
else:
self.nin_shortcut = CausalConvAfterNorm(in_channels, out_channels, kernel_size=1)
def forward(self, x, temb=None, is_init=True):
x = x.permute(0,2,3,4,1).contiguous()
h = base_group_norm_with_zero_pad(x, self.norm1, act_silu=True, pad_size=2)
h = self.conv1(h)
if temb is not None:
h = h + self.temb_proj(nn.functional.silu(temb))[:, :, None, None]
x = self.nin_shortcut(x) if self.in_channels != self.out_channels else x
h = base_group_norm_with_zero_pad(h, self.norm2, act_silu=True, pad_size=2)
x = self.conv2(h, residual=x)
x = x.permute(0,4,1,2,3)
return x
class Downsample3D(nn.Module):
def __init__(self,
in_channels,
with_conv,
stride
):
super().__init__()
self.with_conv = with_conv
if with_conv:
self.conv = CausalConv(in_channels, in_channels, kernel_size=3, stride=stride)
def forward(self, x, is_init=True):
if self.with_conv:
x = self.conv(x, is_init)
else:
x = nn.functional.avg_pool3d(x, kernel_size=2, stride=2)
return x
class VideoEncoder(nn.Module):
def __init__(self,
ch=32,
ch_mult=(4, 8, 16, 16),
num_res_blocks=2,
in_channels=3,
z_channels=16,
double_z=True,
down_sampling_layer=[1, 2],
resamp_with_conv=True,
version=1,
):
super().__init__()
temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
# downsampling
self.conv_in = CausalConv(in_channels, ch, kernel_size=3)
self.down_sampling_layer = down_sampling_layer
in_ch_mult = (1,) + tuple(ch_mult)
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(
Resnet3DBlock(in_channels=block_in, out_channels=block_out, temb_channels=temb_ch))
block_in = block_out
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
if i_level in self.down_sampling_layer:
down.downsample = Downsample3D(block_in, resamp_with_conv, stride=(2, 2, 2))
else:
down.downsample = Downsample2D(block_in, resamp_with_conv, padding=0) #DIFF
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = Resnet3DBlock(in_channels=block_in, out_channels=block_in, temb_channels=temb_ch)
self.mid.attn_1 = AttnBlock(block_in)
self.mid.block_2 = Resnet3DBlock(in_channels=block_in, out_channels=block_in, temb_channels=temb_ch)
# end
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in)
self.version = version
if version == 2:
channels = 4 * z_channels * 2 ** 3
self.conv_patchify = ConvPixelUnshuffleDownSampleLayer3D(block_in, channels, kernel_size=3, factor=2)
self.shortcut_pathify = PixelUnshuffleChannelAveragingDownSampleLayer3D(block_in, channels, 2)
self.shortcut_out = PixelUnshuffleChannelAveragingDownSampleLayer3D(channels, 2 * z_channels if double_z else z_channels, 1)
self.conv_out = CausalConvChannelLast(channels, 2 * z_channels if double_z else z_channels, kernel_size=3)
else:
self.conv_out = CausalConvAfterNorm(block_in, 2 * z_channels if double_z else z_channels, kernel_size=3)
@torch.inference_mode()
def forward(self, x, video_frame_num, is_init=True):
# timestep embedding
temb = None
t = video_frame_num
# downsampling
h = self.conv_in(x, is_init)
# make it real channel last, but behave like normal layout
h = h.permute(0,2,3,4,1).contiguous().permute(0,4,1,2,3)
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](h, temb, is_init)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
if i_level != self.num_resolutions - 1:
if isinstance(self.down[i_level].downsample, Downsample2D):
_, _, t, _, _ = h.shape
h = rearrange(h, "b c t h w -> (b t) h w c", t=t)
h = self.down[i_level].downsample(h)
h = rearrange(h, "(b t) h w c -> b c t h w", t=t)
else:
h = self.down[i_level].downsample(h, is_init)
h = self.mid.block_1(h, temb, is_init)
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb, is_init)
h = h.permute(0,2,3,4,1).contiguous() # b c l h w -> b l h w c
if self.version == 2:
h = base_group_norm(h, self.norm_out, act_silu=True, channel_last=True)
h = h.permute(0,4,1,2,3).contiguous()
shortcut = self.shortcut_pathify(h, is_init)
h = self.conv_patchify(h, is_init)
h = h.add_(shortcut)
shortcut = self.shortcut_out(h, is_init).permute(0,2,3,4,1)
h = self.conv_out(h.permute(0,2,3,4,1).contiguous(), is_init)
h = h.add_(shortcut)
else:
h = base_group_norm_with_zero_pad(h, self.norm_out, act_silu=True, pad_size=2)
h = self.conv_out(h, is_init)
h = h.permute(0,4,1,2,3) # b l h w c -> b c l h w
h = rearrange(h, "b c t h w -> b t c h w")
return h
class Res3DBlockUpsample(nn.Module):
def __init__(self,
input_filters,
num_filters,
down_sampling_stride,
down_sampling=False
):
super().__init__()
self.input_filters = input_filters
self.num_filters = num_filters
self.act_ = nn.SiLU(inplace=True)
self.conv1 = CausalConvChannelLast(num_filters, num_filters, kernel_size=[3, 3, 3])
self.norm1 = nn.GroupNorm(32, num_filters)
self.conv2 = CausalConvChannelLast(num_filters, num_filters, kernel_size=[3, 3, 3])
self.norm2 = nn.GroupNorm(32, num_filters)
self.down_sampling = down_sampling
if down_sampling:
self.down_sampling_stride = down_sampling_stride
else:
self.down_sampling_stride = [1, 1, 1]
if num_filters != input_filters or down_sampling:
self.conv3 = CausalConvChannelLast(input_filters, num_filters, kernel_size=[1, 1, 1], stride=self.down_sampling_stride)
self.norm3 = nn.GroupNorm(32, num_filters)
def forward(self, x, is_init=False):
x = x.permute(0,2,3,4,1).contiguous()
residual = x
h = self.conv1(x, is_init)
h = base_group_norm(h, self.norm1, act_silu=True, channel_last=True)
h = self.conv2(h, is_init)
h = base_group_norm(h, self.norm2, act_silu=False, channel_last=True)
if self.down_sampling or self.num_filters != self.input_filters:
x = self.conv3(x, is_init)
x = base_group_norm(x, self.norm3, act_silu=False, channel_last=True)
h.add_(x)
h = self.act_(h)
if residual is not None:
h.add_(residual)
h = h.permute(0,4,1,2,3)
return h
class Upsample3D(nn.Module):
def __init__(self,
in_channels,
scale_factor=2
):
super().__init__()
self.scale_factor = scale_factor
self.conv3d = Res3DBlockUpsample(input_filters=in_channels,
num_filters=in_channels,
down_sampling_stride=(1, 1, 1),
down_sampling=False)
def forward(self, x, is_init=True, is_split=True):
b, c, t, h, w = x.shape
# x = x.permute(0,2,3,4,1).contiguous().permute(0,4,1,2,3).to(memory_format=torch.channels_last_3d)
if is_split:
split_size = c // 8
x_slices = torch.split(x, split_size, dim=1)
x = [nn.functional.interpolate(x, scale_factor=self.scale_factor) for x in x_slices]
x = torch.cat(x, dim=1)
else:
x = nn.functional.interpolate(x, scale_factor=self.scale_factor)
x = self.conv3d(x, is_init)
return x
class VideoDecoder(nn.Module):
def __init__(self,
ch=128,
z_channels=16,
out_channels=3,
ch_mult=(1, 2, 4, 4),
num_res_blocks=2,
temporal_up_layers=[2, 3],
temporal_downsample=4,
resamp_with_conv=True,
version=1,
):
super().__init__()
temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.temporal_downsample = temporal_downsample
block_in = ch * ch_mult[self.num_resolutions - 1]
self.version = version
if version == 2:
channels = 4 * z_channels * 2 ** 3
self.conv_in = CausalConv(z_channels, channels, kernel_size=3)
self.shortcut_in = ChannelDuplicatingPixelUnshuffleUpSampleLayer3D(z_channels, channels, 1)
self.conv_unpatchify = ConvPixelShuffleUpSampleLayer3D(channels, block_in, kernel_size=3, factor=2)
self.shortcut_unpathify = ChannelDuplicatingPixelUnshuffleUpSampleLayer3D(channels, block_in, 2)
else:
self.conv_in = CausalConv(z_channels, block_in, kernel_size=3)
# middle
self.mid = nn.Module()
self.mid.block_1 = Resnet3DBlock(in_channels=block_in, out_channels=block_in, temb_channels=temb_ch)
self.mid.attn_1 = AttnBlock(block_in)
self.mid.block_2 = Resnet3DBlock(in_channels=block_in, out_channels=block_in, temb_channels=temb_ch)
# upsampling
self.up_id = len(temporal_up_layers)
self.video_frame_num = 1
self.cur_video_frame_num = self.video_frame_num // 2 ** self.up_id + 1
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
block.append(
Resnet3DBlock(in_channels=block_in, out_channels=block_out, temb_channels=temb_ch))
block_in = block_out
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
if i_level in temporal_up_layers:
up.upsample = Upsample3D(block_in)
self.cur_video_frame_num = self.cur_video_frame_num * 2
else:
up.upsample = Upsample2D(block_in, resamp_with_conv)
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in)
self.conv_out = CausalConvAfterNorm(block_in, out_channels, kernel_size=3)
@torch.inference_mode()
def forward(self, z, is_init=True):
z = rearrange(z, "b t c h w -> b c t h w")
h = self.conv_in(z, is_init=is_init)
if self.version == 2:
shortcut = self.shortcut_in(z, is_init=is_init)
h = h.add_(shortcut)
shortcut = self.shortcut_unpathify(h, is_init=is_init)
h = self.conv_unpatchify(h, is_init=is_init)
h = h.add_(shortcut)
temb = None
h = h.permute(0,2,3,4,1).contiguous().permute(0,4,1,2,3)
h = self.mid.block_1(h, temb, is_init=is_init)
h = self.mid.attn_1(h)
h = h.permute(0,2,3,4,1).contiguous().permute(0,4,1,2,3)
h = self.mid.block_2(h, temb, is_init=is_init)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = h.permute(0,2,3,4,1).contiguous().permute(0,4,1,2,3)
h = self.up[i_level].block[i_block](h, temb, is_init=is_init)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
if i_level != 0:
if isinstance(self.up[i_level].upsample, Upsample2D):
B = h.size(0)
h = h.permute(0,2,3,4,1).flatten(0,1)
h = self.up[i_level].upsample(h)
h = h.unflatten(0, (B, -1)).permute(0,4,1,2,3)
else:
h = self.up[i_level].upsample(h, is_init=is_init)
# end
h = h.permute(0,2,3,4,1) # b c l h w -> b l h w c
h = base_group_norm_with_zero_pad(h, self.norm_out, act_silu=True, pad_size=2)
h = self.conv_out(h)
h = h.permute(0,4,1,2,3)
if is_init:
h = h[:, :, (self.temporal_downsample - 1):]
return h
def rms_norm(input, normalized_shape, eps=1e-6):
dtype = input.dtype
input = input.to(torch.float32)
variance = input.pow(2).flatten(-len(normalized_shape)).mean(-1)[(...,) + (None,) * len(normalized_shape)]
input = input * torch.rsqrt(variance + eps)
return input.to(dtype)
class DiagonalGaussianDistribution(object):
def __init__(self, parameters, deterministic=False, rms_norm_mean=False, only_return_mean=False):
self.parameters = parameters
self.mean, self.logvar = torch.chunk(parameters, 2, dim=-3) #N,[X],C,H,W
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
self.deterministic = deterministic
if self.deterministic:
self.var = self.std = torch.zeros_like(
self.mean,
device=self.parameters.device,
dtype=self.parameters.dtype)
if rms_norm_mean:
self.mean = rms_norm(self.mean, self.mean.size()[1:])
self.only_return_mean = only_return_mean
def sample(self, generator=None):
# make sure sample is on the same device
# as the parameters and has same dtype
sample = torch.randn(
self.mean.shape, generator=generator, device=self.parameters.device)
sample = sample.to(dtype=self.parameters.dtype)
x = self.mean + self.std * sample
if self.only_return_mean:
return self.mean
else:
return x
class AutoencoderKL(nn.Module):
@with_empty_init
def __init__(self,
in_channels=3,
out_channels=3,
z_channels=16,
num_res_blocks=2,
model_path=None,
weight_dict={},
world_size=1,
version=1,
):
super().__init__()
self.frame_len = 17
self.latent_len = 3 if version == 2 else 5
base_group_norm.spatial = True if version == 2 else False
self.encoder = VideoEncoder(
in_channels=in_channels,
z_channels=z_channels,
num_res_blocks=num_res_blocks,
version=version,
)
self.decoder = VideoDecoder(
z_channels=z_channels,
out_channels=out_channels,
num_res_blocks=num_res_blocks,
version=version,
)
if model_path is not None:
weight_dict = self.init_from_ckpt(model_path)
if len(weight_dict) != 0:
self.load_from_dict(weight_dict)
self.convert_channel_last()
self.world_size = world_size
def init_from_ckpt(self, model_path):
from safetensors import safe_open
p = {}
with safe_open(model_path, framework="pt", device="cpu") as f:
for k in f.keys():
tensor = f.get_tensor(k)
if k.startswith("decoder.conv_out."):
k = k.replace("decoder.conv_out.", "decoder.conv_out.conv.")
p[k] = tensor
return p
def load_from_dict(self, p):
self.load_state_dict(p)
def convert_channel_last(self):
#Conv2d NCHW->NHWC
pass
def naive_encode(self, x, is_init_image=True):
b, l, c, h, w = x.size()
x = rearrange(x, 'b l c h w -> b c l h w').contiguous()
z = self.encoder(x, l, True) # 下采样[1, 4, 8, 16, 16]
return z
@torch.inference_mode()
def encode(self, x):
# b (nc cf) c h w -> (b nc) cf c h w -> encode -> (b nc) cf c h w -> b (nc cf) c h w
chunks = list(x.split(self.frame_len, dim=1))
for i in range(len(chunks)):
chunks[i] = self.naive_encode(chunks[i], True)
z = torch.cat(chunks, dim=1)
posterior = DiagonalGaussianDistribution(z)
return posterior.sample()
def decode_naive(self, z, is_init=True):
z = z.to(next(self.decoder.parameters()).dtype)
dec = self.decoder(z, is_init)
return dec
@torch.inference_mode()
def decode(self, z):
# b (nc cf) c h w -> (b nc) cf c h w -> decode -> (b nc) c cf h w -> b (nc cf) c h w
chunks = list(z.split(self.latent_len, dim=1))
if self.world_size > 1:
chunks_total_num = len(chunks)
max_num_per_rank = (chunks_total_num + self.world_size - 1) // self.world_size
rank = torch.distributed.get_rank()
chunks_ = chunks[max_num_per_rank * rank : max_num_per_rank * (rank + 1)]
if len(chunks_) < max_num_per_rank:
chunks_.extend(chunks[:max_num_per_rank-len(chunks_)])
chunks = chunks_
for i in range(len(chunks)):
chunks[i] = self.decode_naive(chunks[i], True).permute(0,2,1,3,4)
x = torch.cat(chunks, dim=1)
if self.world_size > 1:
x_ = torch.empty([x.size(0), (self.world_size * max_num_per_rank) * self.frame_len, *x.shape[2:]], dtype=x.dtype, device=x.device)
torch.distributed.all_gather_into_tensor(x_, x)
x = x_[:, : chunks_total_num * self.frame_len]
x = self.mix(x)
return x
def mix(self, x):
remain_scale = 0.6
mix_scale = 1. - remain_scale
front = slice(self.frame_len - 1, x.size(1) - 1, self.frame_len)
back = slice(self.frame_len, x.size(1), self.frame_len)
x[:, back] = x[:, back] * remain_scale + x[:, front] * mix_scale
x[:, front] = x[:, front] * remain_scale + x[:, back] * mix_scale
return x
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment