Commit c07946d8 authored by hepj's avatar hepj
Browse files

dit & video

parents
import os
import imageio
import numpy as np
import torch
class VideoProcessor:
    """Saves generated video tensors to disk as encoded video files.

    Outputs are written under `save_path`; an optional `name_suffix` is
    appended to every output file name.
    """

    def __init__(self, save_path: str = './results', name_suffix: str = ''):
        self.save_path = save_path
        os.makedirs(self.save_path, exist_ok=True)
        self.name_suffix = name_suffix

    def crop2standard540p(self, vid_array: np.ndarray) -> np.ndarray:
        """Center-crop a (frames, H, W, C) array to 960x540 (landscape) or 540x960 (portrait).

        Square frames are returned unchanged.
        NOTE(review): assumes the input is at least 540p in each direction;
        smaller inputs would produce negative slice bounds — confirm upstream.
        """
        _, height, width, _ = vid_array.shape
        height_center = height // 2
        width_center = width // 2
        if width_center > height_center:  # landscape mode: crop to 960x540
            return vid_array[:, height_center - 270:height_center + 270,
                             width_center - 480:width_center + 480]
        elif width_center < height_center:  # portrait mode: crop to 540x960
            return vid_array[:, height_center - 480:height_center + 480,
                             width_center - 270:width_center + 270]
        else:
            return vid_array

    def save_imageio_video(self, video_array: np.ndarray, output_filename: str, fps=25, codec='libx264'):
        """Encode a (frames, H, W, C) uint8 array to `output_filename` via imageio/ffmpeg."""
        ffmpeg_params = [
            "-vf",
            "atadenoise=0a=0.1:0b=0.1:1a=0.1:1b=0.1",  # denoise
        ]
        with imageio.get_writer(output_filename, fps=fps, codec=codec, ffmpeg_params=ffmpeg_params) as vid_writer:
            for img_array in video_array:
                vid_writer.append_data(img_array)

    def postprocess_video(self, video_tensor, output_file_name='', output_type="mp4", crop2standard540p=True):
        """Convert a batch of [-1, 1] video tensors to uint8 frames and save them.

        Per-sample tensors are concatenated along the height axis, rescaled to
        [0, 255], optionally center-cropped to 540p, then encoded to disk.
        """
        if self.name_suffix:
            video_path = os.path.join(self.save_path, f"{output_file_name}-{self.name_suffix}.{output_type}")
        else:
            video_path = os.path.join(self.save_path, f"{output_file_name}.{output_type}")
        # Stack samples vertically per frame (concat along the height axis).
        video_tensor = torch.cat(list(video_tensor), dim=-2)
        video_tensor = (video_tensor.cpu().clamp(-1, 1) + 1) * 127.5
        video_array = video_tensor.clamp(0, 255).to(torch.uint8).numpy().transpose(0, 2, 3, 1)
        if crop2standard540p:
            video_array = self.crop2standard540p(video_array)
        self.save_imageio_video(video_array, video_path)
        print(f"Saved the generated video in {video_path}")
# Copyright 2025 StepFun Inc. All Rights Reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# ==============================================================================
import torch
from einops import rearrange
from torch import nn
from torch.nn import functional as F
from fastvideo.models.stepvideo.utils import with_empty_init
def base_group_norm(x, norm_layer, act_silu=False, channel_last=False):
    """Apply GroupNorm (optionally followed by SiLU) using `norm_layer`'s parameters.

    Args:
        x: input tensor; NCHW, or NHWC when `channel_last` is set. When the
           module-level flag `base_group_norm.spatial` is truthy, x is 5D
           (B, T, H, W, C) and is normalized per (B*T) slice.
        norm_layer: an `nn.GroupNorm` supplying groups/weight/bias/eps.
        act_silu: fuse a SiLU activation after the normalization.
        channel_last: input/output use channel-last layout.

    The original implementation duplicated the whole permute/norm/permute
    sequence in both branches; only the flatten/unflatten differs, so the
    common path is shared here.
    """
    spatial = getattr(base_group_norm, 'spatial', False)
    if spatial:
        assert channel_last
        x_shape = x.shape
        # Fold (B, T) together so normalization statistics are per-frame.
        x = x.flatten(0, 1)
    if channel_last:
        x = x.permute(0, 3, 1, 2)  # NHWC -> NCHW
    out = F.group_norm(x.contiguous(), norm_layer.num_groups, norm_layer.weight, norm_layer.bias, norm_layer.eps)
    if act_silu:
        out = F.silu(out)
    if channel_last:
        out = out.permute(0, 2, 3, 1)  # NCHW -> NHWC
    if spatial:
        # Restore the original (B, T, H, W, C) shape.
        out = out.view(x_shape)
    return out
def base_conv2d(x, conv_layer, channel_last=False, residual=None):
    """2D convolution using `conv_layer`'s weights, with optional NHWC layout
    and a fused residual addition.

    `residual` (same layout as `x`'s output) is added in place to the result.
    """
    if channel_last:
        x = x.permute(0, 3, 1, 2)  # NHWC -> NCHW
    out = F.conv2d(x, conv_layer.weight, conv_layer.bias,
                   stride=conv_layer.stride, padding=conv_layer.padding)
    if residual is not None:
        out += residual.permute(0, 3, 1, 2) if channel_last else residual
    return out.permute(0, 2, 3, 1) if channel_last else out
def base_conv3d(x, conv_layer, channel_last=False, residual=None, only_return_output=False):
    """3D convolution using `conv_layer`'s weights, with optional NDHWC layout
    and a fused residual addition.

    When `only_return_output` is set no convolution runs: an uninitialized
    tensor with the channel-last output shape is returned instead (used as a
    preallocated buffer by the chunked path).
    """
    if only_return_output:
        out_shape = cal_outsize(x.shape, conv_layer.weight.shape, conv_layer.stride, conv_layer.padding)
        return torch.empty(out_shape, device=x.device, dtype=x.dtype)
    if channel_last:
        x = x.permute(0, 4, 1, 2, 3)  # NDHWC -> NCDHW
    out = F.conv3d(x, conv_layer.weight, conv_layer.bias,
                   stride=conv_layer.stride, padding=conv_layer.padding)
    if residual is not None:
        out += residual.permute(0, 4, 1, 2, 3) if channel_last else residual
    return out.permute(0, 2, 3, 4, 1) if channel_last else out
def cal_outsize(input_sizes, kernel_sizes, stride, padding):
    """Compute the channel-last (B, D, H, W, C_out) output shape of a 3D conv.

    Args:
        input_sizes: channel-last input shape (B, D, H, W, C).
        kernel_sizes: conv weight shape (C_out, C_in, kD, kH, kW).
        stride, padding: per-dimension (d, h, w) tuples.

    Dilation is assumed to be 1 in every dimension.
    """
    out_shape = [input_sizes[0]]
    for in_size, kernel, s, p in zip(input_sizes[1:4], kernel_sizes[2:5], stride, padding):
        out_shape.append(calc_out_(in_size, p, 1, kernel, s))
    out_shape.append(kernel_sizes[0])  # output channels go last
    return out_shape


def calc_out_(in_size, padding, dilation, kernel, stride):
    """Standard convolution output-size formula for a single dimension."""
    return (in_size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1
def base_conv3d_channel_last(x, conv_layer, residual=None):
    """Channel-last (B, T, H, W, C) 3D convolution wrapper around base_conv3d.

    For very large activations (input or output >= 2**30 elements) the time
    axis is processed in 4 chunks to bound peak memory; each chunk's result is
    written in place into a preallocated output buffer (or into `residual`,
    which doubles as the output buffer when supplied).
    """
    in_numel = x.numel()
    # Output element count scales with the channel ratio; the chunked path
    # below asserts time stride == 1 so the temporal size is unchanged.
    out_numel = int(x.numel() * conv_layer.out_channels / conv_layer.in_channels)
    if (in_numel >= 2**30) or (out_numel >= 2**30):
        assert conv_layer.stride[0] == 1, "time split asks time stride = 1"
        B, T, H, W, C = x.shape
        K = conv_layer.kernel_size[0]  # temporal kernel extent
        chunks = 4
        chunk_size = T // chunks
        if residual is None:
            # Allocate the output buffer without running the conv (shape only).
            out_nhwc = base_conv3d(x, conv_layer, channel_last=True, residual=residual, only_return_output=True)
        else:
            # Reuse the residual tensor as the output buffer; each chunk's
            # residual slice is read before its output slice is overwritten.
            out_nhwc = residual
        assert B == 1
        for i in range(chunks):
            if i == chunks - 1:
                # Last chunk takes the remainder of the time axis.
                xi = x[:1, chunk_size * i:]
                out_nhwci = out_nhwc[:1, chunk_size * i:]
            else:
                # Input chunk is extended by K - 1 frames so the chunked
                # outputs match the unchunked convolution at the boundary.
                xi = x[:1, chunk_size * i:chunk_size * (i + 1) + K - 1]
                out_nhwci = out_nhwc[:1, chunk_size * i:chunk_size * (i + 1)]
            if residual is not None:
                if i == chunks - 1:
                    ri = residual[:1, chunk_size * i:]
                else:
                    ri = residual[:1, chunk_size * i:chunk_size * (i + 1)]
            else:
                ri = None
            out_nhwci.copy_(base_conv3d(xi, conv_layer, channel_last=True, residual=ri))
    else:
        out_nhwc = base_conv3d(x, conv_layer, channel_last=True, residual=residual)
    return out_nhwc
class Upsample2D(nn.Module):
    """Nearest-neighbor 2x (or explicit-size) upsampling followed by a 3x3 conv.

    Operates on channel-last (N, H, W, C) inputs; only the `use_conv=True`
    configuration is supported.
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        if use_conv:
            self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)
        else:
            # The original code had `assert "Not Supported"`, which is a no-op
            # (a non-empty string is always truthy) and silently fell through
            # to a ConvTranspose2d that the channel-last forward below cannot
            # use. Fail loudly instead.
            raise NotImplementedError("Upsample2D only supports use_conv=True")

    def forward(self, x, output_size=None):
        assert x.shape[-1] == self.channels
        if self.use_conv_transpose:
            return self.conv(x)
        # Interpolate in NCHW (channels_last memory format), then return to NHWC.
        if output_size is None:
            x = F.interpolate(x.permute(0, 3, 1, 2).to(memory_format=torch.channels_last),
                              scale_factor=2.0,
                              mode='nearest').permute(0, 2, 3, 1).contiguous()
        else:
            x = F.interpolate(x.permute(0, 3, 1, 2).to(memory_format=torch.channels_last),
                              size=output_size,
                              mode='nearest').permute(0, 2, 3, 1).contiguous()
        x = base_conv2d(x, self.conv, channel_last=True)
        return x
class Downsample2D(nn.Module):
    """Stride-2 spatial downsampling on channel-last (N, H, W, C) inputs.

    Uses a strided 3x3 conv when `use_conv` is set; otherwise 2x2 average
    pooling, which requires matching in/out channel counts.
    """

    def __init__(self, channels, use_conv=False, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        stride = 2
        if use_conv:
            self.conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            assert self.channels == self.out_channels
            self.conv = nn.AvgPool2d(kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[-1] == self.channels
        if self.use_conv and self.padding == 0:
            # Asymmetric bottom/right padding (NHWC: F.pad's dims are C, W, H).
            x = F.pad(x, (0, 0, 0, 1, 0, 1), mode="constant", value=0)
        assert x.shape[-1] == self.channels
        return base_conv2d(x, self.conv, channel_last=True)
class CausalConv(nn.Module):
    """3D convolution made causal in time via asymmetric front padding.

    Spatial dims receive symmetric "same" padding; the time dim is padded only
    at the front by `dilation * (kT - 1)` so no output frame sees the future.
    Calling `forward(..., is_init=False)` skips the temporal padding, for use
    when temporal context from a previous chunk is already prepended.
    """

    def __init__(self, chan_in, chan_out, kernel_size, **kwargs):
        super().__init__()
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, ) * 3
        t_kernel, h_kernel, w_kernel = kernel_size
        self.dilation = kwargs.pop('dilation', 1)
        self.stride = kwargs.pop('stride', 1)
        if isinstance(self.stride, int):
            self.stride = (self.stride, 1, 1)
        t_pad = self.dilation * (t_kernel - 1) + max(1 - self.stride[0], 0)
        h_pad = h_kernel // 2
        w_pad = w_kernel // 2
        # F.pad order: (W_left, W_right, H_top, H_bottom, T_front, T_back).
        self.time_causal_padding = (w_pad, w_pad, h_pad, h_pad, t_pad, 0)
        self.time_uncausal_padding = (w_pad, w_pad, h_pad, h_pad, 0, 0)
        self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=self.stride, dilation=self.dilation, **kwargs)
        self.is_first_run = True

    def forward(self, x, is_init=True, residual=None):
        pad = self.time_causal_padding if is_init else self.time_uncausal_padding
        out = self.conv(nn.functional.pad(x, pad))
        if residual is not None:
            out.add_(residual)
        return out
class ChannelDuplicatingPixelUnshuffleUpSampleLayer3D(nn.Module):
    """Parameter-free 3D pixel-shuffle upsampling by channel duplication.

    Each input channel is repeated so a factor**3 pixel-shuffle yields
    `out_channels`; the first `factor - 1` output frames are then dropped to
    keep the time axis causal.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        factor: int,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.factor = factor
        assert out_channels * factor**3 % in_channels == 0
        self.repeats = out_channels * factor**3 // in_channels

    def forward(self, x: torch.Tensor, is_init=True) -> torch.Tensor:
        f = self.factor
        b, _, d, h, w = x.shape
        x = x.repeat_interleave(self.repeats, dim=1)
        x = x.view(b, self.out_channels, f, f, f, d, h, w)
        # Interleave each factor offset with its corresponding t/h/w dim.
        x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
        x = x.view(b, self.out_channels, d * f, h * f, w * f)
        # Drop the duplicated leading frames to preserve temporal causality.
        return x[:, :, f - 1:, :, :]
class ConvPixelShuffleUpSampleLayer3D(nn.Module):
    """Causal conv expanding channels by factor**3, then 3D pixel-shuffle upsampling."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        factor: int,
    ):
        super().__init__()
        self.factor = factor
        self.conv = CausalConv(in_channels, out_channels * factor**3, kernel_size=kernel_size)

    def forward(self, x: torch.Tensor, is_init=True) -> torch.Tensor:
        return self.pixel_shuffle_3d(self.conv(x, is_init), self.factor)

    @staticmethod
    def pixel_shuffle_3d(x: torch.Tensor, factor: int) -> torch.Tensor:
        """Rearrange (B, C*f^3, D, H, W) -> (B, C, D*f - (f-1), H*f, W*f).

        The leading `factor - 1` frames are dropped to keep time causal.
        """
        b, c, d, h, w = x.size()
        c_out = c // factor**3
        x = x.view(b, c_out, factor, factor, factor, d, h, w)
        x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
        x = x.view(b, c_out, d * factor, h * factor, w * factor)
        return x[:, :, factor - 1:, :, :]
class ConvPixelUnshuffleDownSampleLayer3D(nn.Module):
    """Causal conv shrinking channels by factor**3, then 3D pixel-unshuffle downsampling."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        factor: int,
    ):
        super().__init__()
        self.factor = factor
        assert out_channels % factor**3 == 0
        self.conv = CausalConv(in_channels, out_channels // factor**3, kernel_size=kernel_size)

    def forward(self, x: torch.Tensor, is_init=True) -> torch.Tensor:
        return self.pixel_unshuffle_3d(self.conv(x, is_init), self.factor)

    @staticmethod
    def pixel_unshuffle_3d(x: torch.Tensor, factor: int) -> torch.Tensor:
        """Rearrange (B, C, D, H, W) -> (B, C*f^3, ceil(D/f), H/f, W/f).

        The front of the time axis is zero-padded so D becomes divisible by
        `factor` (keeping the padding causal).
        """
        x = F.pad(x, (0, 0, 0, 0, factor - 1, 0))  # (left, right, top, bottom, front, back)
        b, c, d, h, w = x.shape
        x = x.view(b, c, d // factor, factor, h // factor, factor, w // factor, factor)
        x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
        return x.view(b, c * factor**3, d // factor, h // factor, w // factor)
class PixelUnshuffleChannelAveragingDownSampleLayer3D(nn.Module):
    """Parameter-free 3D pixel-unshuffle downsampling with channel-group averaging.

    Spatial/temporal resolution is folded into channels (with causal front
    padding in time), then groups of `group_size` channels are averaged down
    to `out_channels`.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        factor: int,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.factor = factor
        assert in_channels * factor**3 % out_channels == 0
        self.group_size = in_channels * factor**3 // out_channels

    def forward(self, x: torch.Tensor, is_init=True) -> torch.Tensor:
        f = self.factor
        # Zero-pad the front of the time axis so D becomes divisible by f.
        x = F.pad(x, (0, 0, 0, 0, f - 1, 0))
        b, c, d, h, w = x.shape
        x = x.view(b, c, d // f, f, h // f, f, w // f, f)
        x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
        x = x.view(b, self.out_channels, self.group_size, d // f, h // f, w // f)
        return x.mean(dim=2)
def base_group_norm_with_zero_pad(x, norm_layer, act_silu=True, pad_size=2):
    """GroupNorm (+ optional SiLU) on channel-last input, writing the result
    into a buffer pre-padded with `pad_size` zero slices along dim 1.

    Used to fuse the causal temporal zero-padding with normalization so the
    following `CausalConvAfterNorm` needs no explicit F.pad.
    """
    padded_shape = list(x.shape)
    padded_shape[1] += pad_size
    out = torch.empty(padded_shape, dtype=x.dtype, device=x.device)
    out[:, :pad_size] = 0  # causal zero pad at the front of dim 1
    out[:, pad_size:] = base_group_norm(x, norm_layer, act_silu=act_silu, channel_last=True)
    return out
class CausalConvChannelLast(CausalConv):
    """CausalConv variant operating on channel-last (B, T, H, W, C) tensors."""

    def __init__(self, chan_in, chan_out, kernel_size, **kwargs):
        super().__init__(chan_in, chan_out, kernel_size, **kwargs)
        # Prepend a no-op pad for the trailing channel dim (F.pad pads the
        # last dimension first).
        self.time_causal_padding = (0, 0) + self.time_causal_padding
        self.time_uncausal_padding = (0, 0) + self.time_uncausal_padding

    def forward(self, x, is_init=True, residual=None):
        if self.is_first_run:
            self.is_first_run = False
        pad = self.time_causal_padding if is_init else self.time_uncausal_padding
        padded = nn.functional.pad(x, pad)
        return base_conv3d_channel_last(padded, self.conv, residual=residual)
class CausalConvAfterNorm(CausalConv):
    """CausalConv whose temporal zero padding is expected to be supplied by the
    preceding `base_group_norm_with_zero_pad` call.

    For the common 3x3x3 / stride-1 case (causal padding (1, 1, 1, 1, 2, 0))
    the spatial padding is folded into the Conv3d itself (padding=(0, 1, 1))
    and forward performs no explicit F.pad — the two leading zero frames must
    already be present in the input.
    """

    def __init__(self, chan_in, chan_out, kernel_size, **kwargs):
        super().__init__(chan_in, chan_out, kernel_size, **kwargs)
        if self.time_causal_padding == (1, 1, 1, 1, 2, 0):
            # Spatial "same" padding handled by the conv; temporal padding is
            # provided externally (see base_group_norm_with_zero_pad).
            self.conv = nn.Conv3d(chan_in,
                                  chan_out,
                                  kernel_size,
                                  stride=self.stride,
                                  dilation=self.dilation,
                                  padding=(0, 1, 1),
                                  **kwargs)
        else:
            self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=self.stride, dilation=self.dilation, **kwargs)
        self.is_first_run = True

    def forward(self, x, is_init=True, residual=None):
        # Input is channel-last (B, T, H, W, C); conv weights remain NCDHW.
        if self.is_first_run:
            self.is_first_run = False
        if self.time_causal_padding == (1, 1, 1, 1, 2, 0):
            # Padding already fused: conv spatial padding + caller's time pad.
            pass
        else:
            # NOTE(review): this pad tuple targets the last three dims, which
            # for channel-last input are (H, W, C) rather than (T, H, W) —
            # confirm whether this non-3x3x3 path is ever exercised.
            x = nn.functional.pad(x, self.time_causal_padding).contiguous()
        x = base_conv3d_channel_last(x, self.conv, residual=residual)
        return x
class AttnBlock(nn.Module):
    """Single-head self-attention over all (t, h, w) positions of a video tensor.

    Input/output are channel-first (B, C, T, H, W); computation happens in
    channel-last layout with 1x1x1 causal convs for q/k/v/projection.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels)
        self.q = CausalConvChannelLast(in_channels, in_channels, kernel_size=1)
        self.k = CausalConvChannelLast(in_channels, in_channels, kernel_size=1)
        self.v = CausalConvChannelLast(in_channels, in_channels, kernel_size=1)
        self.proj_out = CausalConvChannelLast(in_channels, in_channels, kernel_size=1)

    def attention(self, x, is_init=True):
        """Attention over a channel-last (B, T, H, W, C) tensor."""
        normed = base_group_norm(x, self.norm, act_silu=False, channel_last=True)
        query = self.q(normed, is_init)
        key = self.k(normed, is_init)
        value = self.v(normed, is_init)
        b, t, h, w, c = query.shape
        query, key, value = (rearrange(v, "b t h w c -> b 1 (t h w) c") for v in (query, key, value))
        # NOTE(review): is_causal masks over the flattened (t h w) token
        # order, not just over time — confirm this masking is intended.
        out = nn.functional.scaled_dot_product_attention(query, key, value, is_causal=True)
        return rearrange(out, "b 1 (t h w) c -> b t h w c", t=t, h=h, w=w)

    def forward(self, x):
        x = x.permute(0, 2, 3, 4, 1).contiguous()  # to channel-last
        attn_out = self.attention(x)
        # proj_out fuses the residual addition with the projection conv.
        x = self.proj_out(attn_out, residual=x)
        return x.permute(0, 4, 1, 2, 3)  # back to channel-first
class Resnet3DBlock(nn.Module):
    """3D ResNet block: two (norm + SiLU + causal conv) stages with a residual shortcut.

    Works internally in channel-last layout; input/output are channel-first
    (B, C, T, H, W). Causal temporal padding is supplied by
    base_group_norm_with_zero_pad (2 leading zero frames) rather than F.pad.
    """

    def __init__(
        self,
        in_channels,
        out_channels=None,
        temb_channels=512,
        conv_shortcut=False,
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels)
        self.conv1 = CausalConvAfterNorm(in_channels, out_channels, kernel_size=3)
        if temb_channels > 0:
            # Projection for an optional timestep embedding.
            self.temb_proj = nn.Linear(temb_channels, out_channels)
        self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels)
        self.conv2 = CausalConvAfterNorm(out_channels, out_channels, kernel_size=3)
        assert conv_shortcut is False  # conv shortcut path is not supported
        self.use_conv_shortcut = conv_shortcut
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = CausalConvAfterNorm(in_channels, out_channels, kernel_size=3)
            else:
                # 1x1x1 projection to match channel counts on the shortcut.
                self.nin_shortcut = CausalConvAfterNorm(in_channels, out_channels, kernel_size=1)

    def forward(self, x, temb=None, is_init=True):
        # To channel-last for the fused norm/conv helpers.
        x = x.permute(0, 2, 3, 4, 1).contiguous()
        h = base_group_norm_with_zero_pad(x, self.norm1, act_silu=True, pad_size=2)
        h = self.conv1(h)
        if temb is not None:
            # NOTE(review): the [:, :, None, None] broadcast looks shaped for
            # 4D activations, but h is 5D channel-last here; temb is always
            # None in this file — confirm before enabling timestep embeddings.
            h = h + self.temb_proj(nn.functional.silu(temb))[:, :, None, None]
        x = self.nin_shortcut(x) if self.in_channels != self.out_channels else x
        h = base_group_norm_with_zero_pad(h, self.norm2, act_silu=True, pad_size=2)
        # conv2 adds the shortcut in place via its residual argument.
        x = self.conv2(h, residual=x)
        x = x.permute(0, 4, 1, 2, 3)
        return x
class Downsample3D(nn.Module):
    """Temporal+spatial downsampling: strided CausalConv when `with_conv`,
    otherwise 2x2x2 average pooling."""

    def __init__(self, in_channels, with_conv, stride):
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            self.conv = CausalConv(in_channels, in_channels, kernel_size=3, stride=stride)

    def forward(self, x, is_init=True):
        if not self.with_conv:
            return nn.functional.avg_pool3d(x, kernel_size=2, stride=2)
        return self.conv(x, is_init)
class VideoEncoder(nn.Module):
    """Causal 3D VAE encoder: conv stem, per-resolution ResNet stages with
    2D/3D downsampling, an attention-equipped middle section, and a GroupNorm
    + conv head producing (optionally doubled, for mean+logvar) latent
    channels. Input is (B, C, T, H, W); output is (B, T', C_z, H', W').
    """

    def __init__(
        self,
        ch=32,
        ch_mult=(4, 8, 16, 16),
        num_res_blocks=2,
        in_channels=3,
        z_channels=16,
        double_z=True,
        down_sampling_layer=[1, 2],
        resamp_with_conv=True,
        version=1,
    ):
        super().__init__()
        temb_ch = 0  # no timestep embedding in the encoder
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        # downsampling
        self.conv_in = CausalConv(in_channels, ch, kernel_size=3)
        self.down_sampling_layer = down_sampling_layer
        in_ch_mult = (1, ) + tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(Resnet3DBlock(in_channels=block_in, out_channels=block_out, temb_channels=temb_ch))
                block_in = block_out
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                # Levels listed in down_sampling_layer also downsample time.
                if i_level in self.down_sampling_layer:
                    down.downsample = Downsample3D(block_in, resamp_with_conv, stride=(2, 2, 2))
                else:
                    down.downsample = Downsample2D(block_in, resamp_with_conv, padding=0)  #DIFF
            self.down.append(down)
        # middle
        self.mid = nn.Module()
        self.mid.block_1 = Resnet3DBlock(in_channels=block_in, out_channels=block_in, temb_channels=temb_ch)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = Resnet3DBlock(in_channels=block_in, out_channels=block_in, temb_channels=temb_ch)
        # end
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in)
        self.version = version
        if version == 2:
            # v2 head: pixel-unshuffle patchify plus averaged shortcuts
            # around both the patchify conv and the output conv.
            channels = 4 * z_channels * 2**3
            self.conv_patchify = ConvPixelUnshuffleDownSampleLayer3D(block_in, channels, kernel_size=3, factor=2)
            self.shortcut_pathify = PixelUnshuffleChannelAveragingDownSampleLayer3D(block_in, channels, 2)
            self.shortcut_out = PixelUnshuffleChannelAveragingDownSampleLayer3D(
                channels, 2 * z_channels if double_z else z_channels, 1)
            self.conv_out = CausalConvChannelLast(channels, 2 * z_channels if double_z else z_channels, kernel_size=3)
        else:
            self.conv_out = CausalConvAfterNorm(block_in, 2 * z_channels if double_z else z_channels, kernel_size=3)

    @torch.inference_mode()
    def forward(self, x, video_frame_num, is_init=True):
        # timestep embedding (unused; blocks are built with temb_ch=0)
        temb = None
        t = video_frame_num
        # downsampling
        h = self.conv_in(x, is_init)
        # make it real channel last, but behave like normal layout
        h = h.permute(0, 2, 3, 4, 1).contiguous().permute(0, 4, 1, 2, 3)
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](h, temb, is_init)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
            if i_level != self.num_resolutions - 1:
                if isinstance(self.down[i_level].downsample, Downsample2D):
                    # Spatial-only downsample: fold time into the batch dim.
                    _, _, t, _, _ = h.shape
                    h = rearrange(h, "b c t h w -> (b t) h w c", t=t)
                    h = self.down[i_level].downsample(h)
                    h = rearrange(h, "(b t) h w c -> b c t h w", t=t)
                else:
                    h = self.down[i_level].downsample(h, is_init)
        h = self.mid.block_1(h, temb, is_init)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb, is_init)
        h = h.permute(0, 2, 3, 4, 1).contiguous()  # b c l h w -> b l h w c
        if self.version == 2:
            h = base_group_norm(h, self.norm_out, act_silu=True, channel_last=True)
            h = h.permute(0, 4, 1, 2, 3).contiguous()
            shortcut = self.shortcut_pathify(h, is_init)
            h = self.conv_patchify(h, is_init)
            h = h.add_(shortcut)
            shortcut = self.shortcut_out(h, is_init).permute(0, 2, 3, 4, 1)
            h = self.conv_out(h.permute(0, 2, 3, 4, 1).contiguous(), is_init)
            h = h.add_(shortcut)
        else:
            h = base_group_norm_with_zero_pad(h, self.norm_out, act_silu=True, pad_size=2)
            h = self.conv_out(h, is_init)
        h = h.permute(0, 4, 1, 2, 3)  # b l h w c -> b c l h w
        h = rearrange(h, "b c t h w -> b t c h w")
        return h
class Res3DBlockUpsample(nn.Module):
    """Residual 3D conv block used inside Upsample3D.

    Two (causal conv + GroupNorm) stages with a 1x1x1 conv + norm shortcut
    when the channel count or stride changes. Input/output are channel-first
    (B, C, T, H, W); computation is channel-last internally.
    """

    def __init__(self, input_filters, num_filters, down_sampling_stride, down_sampling=False):
        super().__init__()
        self.input_filters = input_filters
        self.num_filters = num_filters
        self.act_ = nn.SiLU(inplace=True)
        self.conv1 = CausalConvChannelLast(num_filters, num_filters, kernel_size=[3, 3, 3])
        self.norm1 = nn.GroupNorm(32, num_filters)
        self.conv2 = CausalConvChannelLast(num_filters, num_filters, kernel_size=[3, 3, 3])
        self.norm2 = nn.GroupNorm(32, num_filters)
        self.down_sampling = down_sampling
        if down_sampling:
            self.down_sampling_stride = down_sampling_stride
        else:
            self.down_sampling_stride = [1, 1, 1]
        if num_filters != input_filters or down_sampling:
            # Shortcut projection to the new channel count / stride.
            self.conv3 = CausalConvChannelLast(input_filters,
                                               num_filters,
                                               kernel_size=[1, 1, 1],
                                               stride=self.down_sampling_stride)
            self.norm3 = nn.GroupNorm(32, num_filters)

    def forward(self, x, is_init=False):
        x = x.permute(0, 2, 3, 4, 1).contiguous()  # to channel-last
        residual = x
        h = self.conv1(x, is_init)
        h = base_group_norm(h, self.norm1, act_silu=True, channel_last=True)
        h = self.conv2(h, is_init)
        h = base_group_norm(h, self.norm2, act_silu=False, channel_last=True)
        if self.down_sampling or self.num_filters != self.input_filters:
            # Project the shortcut before adding it.
            x = self.conv3(x, is_init)
            x = base_group_norm(x, self.norm3, act_silu=False, channel_last=True)
        h.add_(x)
        h = self.act_(h)
        # NOTE(review): the pre-projection input is added AGAIN after the
        # activation, on top of the shortcut added above (a double residual)
        # — confirm this is intentional. `residual` is never None here.
        if residual is not None:
            h.add_(residual)
        h = h.permute(0, 4, 1, 2, 3)  # back to channel-first
        return h
class Upsample3D(nn.Module):
    """Nearest-neighbor upsampling of all of (T, H, W) followed by a residual 3D conv block."""

    def __init__(self, in_channels, scale_factor=2):
        super().__init__()
        self.scale_factor = scale_factor
        self.conv3d = Res3DBlockUpsample(input_filters=in_channels,
                                         num_filters=in_channels,
                                         down_sampling_stride=(1, 1, 1),
                                         down_sampling=False)

    def forward(self, x, is_init=True, is_split=True):
        channels = x.size(1)
        if is_split:
            # Interpolate in 8 channel groups to cap peak memory usage.
            pieces = torch.split(x, channels // 8, dim=1)
            upsampled = [nn.functional.interpolate(p, scale_factor=self.scale_factor) for p in pieces]
            x = torch.cat(upsampled, dim=1)
        else:
            x = nn.functional.interpolate(x, scale_factor=self.scale_factor)
        return self.conv3d(x, is_init)
class VideoDecoder(nn.Module):
    """Causal 3D VAE decoder mirroring VideoEncoder: conv stem (with v2
    unpatchify shortcuts), attention-equipped middle section, per-resolution
    ResNet stages with 2D/3D upsampling, and a GroupNorm + conv RGB head.
    Input latents are (B, T, C_z, H, W); output is (B, C, T', H', W').
    """

    def __init__(
        self,
        ch=128,
        z_channels=16,
        out_channels=3,
        ch_mult=(1, 2, 4, 4),
        num_res_blocks=2,
        temporal_up_layers=[2, 3],
        temporal_downsample=4,
        resamp_with_conv=True,
        version=1,
    ):
        super().__init__()
        temb_ch = 0  # no timestep embedding in the decoder
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.temporal_downsample = temporal_downsample
        block_in = ch * ch_mult[self.num_resolutions - 1]
        self.version = version
        if version == 2:
            # v2 stem: conv with a channel-duplicating shortcut, followed by
            # pixel-shuffle unpatchify with its own shortcut.
            channels = 4 * z_channels * 2**3
            self.conv_in = CausalConv(z_channels, channels, kernel_size=3)
            self.shortcut_in = ChannelDuplicatingPixelUnshuffleUpSampleLayer3D(z_channels, channels, 1)
            self.conv_unpatchify = ConvPixelShuffleUpSampleLayer3D(channels, block_in, kernel_size=3, factor=2)
            self.shortcut_unpathify = ChannelDuplicatingPixelUnshuffleUpSampleLayer3D(channels, block_in, 2)
        else:
            self.conv_in = CausalConv(z_channels, block_in, kernel_size=3)
        # middle
        self.mid = nn.Module()
        self.mid.block_1 = Resnet3DBlock(in_channels=block_in, out_channels=block_in, temb_channels=temb_ch)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = Resnet3DBlock(in_channels=block_in, out_channels=block_in, temb_channels=temb_ch)
        # upsampling
        self.up_id = len(temporal_up_layers)
        self.video_frame_num = 1
        self.cur_video_frame_num = self.video_frame_num // 2**self.up_id + 1
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(Resnet3DBlock(in_channels=block_in, out_channels=block_out, temb_channels=temb_ch))
                block_in = block_out
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                # Levels listed in temporal_up_layers also upsample time.
                if i_level in temporal_up_layers:
                    up.upsample = Upsample3D(block_in)
                    self.cur_video_frame_num = self.cur_video_frame_num * 2
                else:
                    up.upsample = Upsample2D(block_in, resamp_with_conv)
            self.up.insert(0, up)  # prepend to get consistent order
        # end
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in)
        self.conv_out = CausalConvAfterNorm(block_in, out_channels, kernel_size=3)

    @torch.inference_mode()
    def forward(self, z, is_init=True):
        z = rearrange(z, "b t c h w -> b c t h w")
        h = self.conv_in(z, is_init=is_init)
        if self.version == 2:
            shortcut = self.shortcut_in(z, is_init=is_init)
            h = h.add_(shortcut)
            shortcut = self.shortcut_unpathify(h, is_init=is_init)
            h = self.conv_unpatchify(h, is_init=is_init)
            h = h.add_(shortcut)
        temb = None
        # make it real channel last, but behave like normal layout
        h = h.permute(0, 2, 3, 4, 1).contiguous().permute(0, 4, 1, 2, 3)
        h = self.mid.block_1(h, temb, is_init=is_init)
        h = self.mid.attn_1(h)
        h = h.permute(0, 2, 3, 4, 1).contiguous().permute(0, 4, 1, 2, 3)
        h = self.mid.block_2(h, temb, is_init=is_init)
        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = h.permute(0, 2, 3, 4, 1).contiguous().permute(0, 4, 1, 2, 3)
                h = self.up[i_level].block[i_block](h, temb, is_init=is_init)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                if isinstance(self.up[i_level].upsample, Upsample2D):
                    # Spatial-only upsample: fold time into the batch dim.
                    B = h.size(0)
                    h = h.permute(0, 2, 3, 4, 1).flatten(0, 1)
                    h = self.up[i_level].upsample(h)
                    h = h.unflatten(0, (B, -1)).permute(0, 4, 1, 2, 3)
                else:
                    h = self.up[i_level].upsample(h, is_init=is_init)
        # end
        h = h.permute(0, 2, 3, 4, 1)  # b c l h w -> b l h w c
        h = base_group_norm_with_zero_pad(h, self.norm_out, act_silu=True, pad_size=2)
        h = self.conv_out(h)
        h = h.permute(0, 4, 1, 2, 3)
        if is_init:
            # Drop the temporally padded leading frames on the first chunk.
            h = h[:, :, (self.temporal_downsample - 1):]
        return h
def rms_norm(input, normalized_shape, eps=1e-6):
    """RMS-normalize `input` over its trailing `normalized_shape` dims.

    The computation runs in float32 and the result is cast back to the
    original dtype.
    """
    orig_dtype = input.dtype
    x = input.to(torch.float32)
    n_dims = len(normalized_shape)
    # Mean of squares over the normalized dims, broadcastable back onto x.
    mean_sq = x.pow(2).flatten(-n_dims).mean(-1)
    mean_sq = mean_sq[(..., ) + (None, ) * n_dims]
    x = x * torch.rsqrt(mean_sq + eps)
    return x.to(orig_dtype)
class DiagonalGaussianDistribution(object):
    """Diagonal Gaussian over latents parameterized by (mean, logvar) stacked on dim -3."""

    def __init__(self, parameters, deterministic=False, rms_norm_mean=False, only_return_mean=False):
        self.parameters = parameters
        # Split N,[X],C,H,W parameters into mean / log-variance halves along C.
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=-3)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        self.deterministic = deterministic
        if self.deterministic:
            # Degenerate distribution: zero variance, sampling yields the mean.
            zeros = torch.zeros_like(self.mean,
                                     device=self.parameters.device,
                                     dtype=self.parameters.dtype)
            self.var = self.std = zeros
        if rms_norm_mean:
            self.mean = rms_norm(self.mean, self.mean.size()[1:])
        self.only_return_mean = only_return_mean

    def sample(self, generator=None):
        """Draw a reparameterized sample (or just the mean when only_return_mean)."""
        # Noise is created on the parameters' device with their dtype.
        noise = torch.randn(self.mean.shape, generator=generator, device=self.parameters.device)
        noise = noise.to(dtype=self.parameters.dtype)
        drawn = self.mean + self.std * noise
        return self.mean if self.only_return_mean else drawn
class AutoencoderKL(nn.Module):
    """Causal video VAE (KL autoencoder) wrapping VideoEncoder/VideoDecoder.

    Videos are encoded in fixed-length frame chunks (`frame_len`) and latents
    decoded in `latent_len` chunks, optionally sharding decode work across
    `world_size` distributed ranks.
    """

    @with_empty_init
    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        z_channels=16,
        num_res_blocks=2,
        model_path=None,
        weight_dict={},  # NOTE(review): mutable default argument (shared across calls)
        world_size=1,
        version=1,
    ):
        super().__init__()
        # Frames per encode chunk / latents per decode chunk.
        self.frame_len = 17
        self.latent_len = 3 if version == 2 else 5
        # v2 enables per-(batch*time)-slice normalization in base_group_norm.
        base_group_norm.spatial = True if version == 2 else False
        self.encoder = VideoEncoder(
            in_channels=in_channels,
            z_channels=z_channels,
            num_res_blocks=num_res_blocks,
            version=version,
        )
        self.decoder = VideoDecoder(
            z_channels=z_channels,
            out_channels=out_channels,
            num_res_blocks=num_res_blocks,
            version=version,
        )
        if model_path is not None:
            weight_dict = self.init_from_ckpt(model_path)
        if len(weight_dict) != 0:
            self.load_from_dict(weight_dict)
        self.convert_channel_last()
        self.world_size = world_size

    def init_from_ckpt(self, model_path):
        """Load a safetensors checkpoint, renaming decoder.conv_out.* keys onto the wrapped conv."""
        from safetensors import safe_open
        p = {}
        with safe_open(model_path, framework="pt", device="cpu") as f:
            for k in f.keys():
                tensor = f.get_tensor(k)
                if k.startswith("decoder.conv_out."):
                    k = k.replace("decoder.conv_out.", "decoder.conv_out.conv.")
                p[k] = tensor
        return p

    def load_from_dict(self, p):
        self.load_state_dict(p)

    def convert_channel_last(self):
        #Conv2d NCHW->NHWC
        pass

    def naive_encode(self, x, is_init_image=True):
        # NOTE(review): `len` shadows the builtin here; it is the chunk's frame count.
        b, len, c, h, w = x.size()
        x = rearrange(x, 'b l c h w -> b c l h w').contiguous()
        z = self.encoder(x, len, True)  # downsampling [1, 4, 8, 16, 16]
        return z

    @torch.inference_mode()
    def encode(self, x):
        # b (nc cf) c h w -> (b nc) cf c h w -> encode -> (b nc) cf c h w -> b (nc cf) c h w
        chunks = list(x.split(self.frame_len, dim=1))
        for i in range(len(chunks)):
            chunks[i] = self.naive_encode(chunks[i], True)
        z = torch.cat(chunks, dim=1)
        posterior = DiagonalGaussianDistribution(z)
        return posterior.sample()

    def decode_naive(self, z, is_init=True):
        # Match the decoder's parameter dtype before decoding.
        z = z.to(next(self.decoder.parameters()).dtype)
        dec = self.decoder(z, is_init)
        return dec

    @torch.inference_mode()
    def decode(self, z):
        """Decode latents chunk by chunk; with world_size > 1, chunks are
        sharded across ranks and gathered back with all_gather."""
        # b (nc cf) c h w -> (b nc) cf c h w -> decode -> (b nc) c cf h w -> b (nc cf) c h w
        chunks = list(z.split(self.latent_len, dim=1))
        if self.world_size > 1:
            chunks_total_num = len(chunks)
            max_num_per_rank = (chunks_total_num + self.world_size - 1) // self.world_size
            rank = torch.distributed.get_rank()
            chunks_ = chunks[max_num_per_rank * rank:max_num_per_rank * (rank + 1)]
            if len(chunks_) < max_num_per_rank:
                # Pad short ranks with wrap-around chunks so every rank decodes
                # the same number (required for all_gather_into_tensor).
                chunks_.extend(chunks[:max_num_per_rank - len(chunks_)])
            chunks = chunks_
        for i in range(len(chunks)):
            chunks[i] = self.decode_naive(chunks[i], True).permute(0, 2, 1, 3, 4)
        x = torch.cat(chunks, dim=1)
        if self.world_size > 1:
            x_ = torch.empty([x.size(0), (self.world_size * max_num_per_rank) * self.frame_len, *x.shape[2:]],
                             dtype=x.dtype,
                             device=x.device)
            torch.distributed.all_gather_into_tensor(x_, x)
            # Trim the wrap-around padding added above.
            x = x_[:, :chunks_total_num * self.frame_len]
        x = self.mix(x)
        return x

    def mix(self, x):
        """Cross-fade the boundary frames of adjacent decoded chunks in place.

        NOTE(review): the second assignment reads the already-mixed `back`
        frames, so the two updates are order-dependent — presumably intended.
        """
        remain_scale = 0.6
        mix_scale = 1. - remain_scale
        # Last frame of each chunk / first frame of the next chunk.
        front = slice(self.frame_len - 1, x.size(1) - 1, self.frame_len)
        back = slice(self.frame_len, x.size(1), self.frame_len)
        x[:, back] = x[:, back] * remain_scale + x[:, front] * mix_scale
        x[:, front] = x[:, front] * remain_scale + x[:, back] * mix_scale
        return x
import argparse
import os
import pickle
import threading
import torch
from flask import Blueprint, Flask, Response, request
from flask_restful import Api, Resource
# Serve all models on the last visible GPU, in bfloat16.
device = f'cuda:{torch.cuda.device_count()-1}'
dtype = torch.bfloat16
def parsed_args():
    """Parse the StepVideo API server's command-line arguments."""
    parser = argparse.ArgumentParser(description="StepVideo API Functions")
    parser.add_argument('--model_dir', type=str)
    parser.add_argument('--clip_dir', type=str, default='hunyuan_clip')
    parser.add_argument('--llm_dir', type=str, default='step_llm')
    parser.add_argument('--vae_dir', type=str, default='vae')
    parser.add_argument('--port', type=str, default='8080')
    return parser.parse_args()
class StepVaePipeline(Resource):
    """Flask-RESTful resource wrapping the StepVideo VAE for latent decoding."""

    def __init__(self, vae_dir, version=2):
        self.vae = self.build_vae(vae_dir, version)
        self.scale_factor = 1.0

    def build_vae(self, vae_dir, version=2):
        """Load the AutoencoderKL checkpoint matching `version` onto the module-level device/dtype."""
        from fastvideo.models.stepvideo.vae.vae import AutoencoderKL
        (model_name, z_channels) = ("vae_v2.safetensors", 64) if version == 2 else ("vae.safetensors", 16)
        model_path = os.path.join(vae_dir, model_name)
        model = AutoencoderKL(
            z_channels=z_channels,
            model_path=model_path,
            version=version,
        ).to(dtype).to(device).eval()
        print("Initialized vae...")
        return model

    def decode(self, samples, *args, **kwargs):
        """Decode latents to video frames.

        Returns:
            The decoded tensor, or None if decoding failed (best-effort
            contract, e.g. on CUDA OOM the cache is cleared and None returned).
        """
        with torch.no_grad():
            try:
                vae_dtype = next(self.vae.parameters()).dtype
                vae_device = next(self.vae.parameters()).device
                samples = self.vae.decode(samples.to(vae_dtype).to(vae_device) / self.scale_factor)
                if hasattr(samples, 'sample'):
                    samples = samples.sample
                return samples
            except Exception as e:
                # The original bare `except:` also swallowed KeyboardInterrupt
                # and SystemExit silently; keep the best-effort behavior but
                # only catch ordinary exceptions, and log the failure.
                print("VAE decode failed:", e)
                torch.cuda.empty_cache()
                return None
# Module-level lock serialising requests to the GPU-backed endpoints.
# NOTE(review): this name is rebound to a new Lock further down in the file;
# at call time both APIs resolve the module global, so they end up sharing the
# later instance — confirm that serialising VAE and caption work together is intended.
lock = threading.Lock()
class VAEapi(Resource):
    """GET endpoint: unpickles a latent payload, decodes it, returns pickled frames."""

    def __init__(self, vae_pipeline):
        self.vae_pipeline = vae_pipeline

    def get(self):
        # Serialise requests: decoding is stateful GPU work.
        with lock:
            try:
                # NOTE: pickle on request data is only safe with trusted clients.
                feature = pickle.loads(request.get_data())
                feature['api'] = 'vae'
                feature = {k: v for k, v in feature.items() if v is not None}
                video_latents = self.vae_pipeline.decode(**feature)
                response = pickle.dumps(video_latents)
            except Exception as e:
                print("Caught Exception: ", e)
                # Bug fix: Response(e) passed an Exception object as the body,
                # which itself raises TypeError inside Flask; return the message
                # text with an explicit 500 status instead.
                return Response(str(e), status=500)
            return Response(response)
class CaptionPipeline(Resource):
    """Flask-RESTful resource bundling the Step-LLM text encoder and Hunyuan CLIP."""

    def __init__(self, llm_dir, clip_dir):
        self.text_encoder = self.build_llm(llm_dir)
        self.clip = self.build_clip(clip_dir)

    def build_llm(self, model_dir):
        """Load the STEP1 text encoder (max 320 tokens) onto the serving device."""
        from fastvideo.models.stepvideo.text_encoder.stepllm import STEP1TextEncoder
        encoder = STEP1TextEncoder(model_dir, max_length=320).to(dtype).to(device).eval()
        print("Initialized text encoder...")
        return encoder

    def build_clip(self, model_dir):
        """Load the Hunyuan CLIP encoder (max 77 tokens) onto the serving device."""
        from fastvideo.models.stepvideo.text_encoder.clip import HunyuanClip
        encoder = HunyuanClip(model_dir, max_length=77).to(device).eval()
        print("Initialized clip encoder...")
        return encoder

    def embedding(self, prompts, *args, **kwargs):
        """Encode prompts with both encoders; returns a CPU-tensor dict, or None on error."""
        with torch.no_grad():
            try:
                y, y_mask = self.text_encoder(prompts)
                clip_embedding, _ = self.clip(prompts)
                # Left-pad the LLM attention mask so it also covers the
                # prepended CLIP tokens.
                clip_len = clip_embedding.shape[1]
                padded_mask = torch.nn.functional.pad(y_mask, (clip_len, 0), value=1)
                return {
                    'y': y.detach().cpu(),
                    'y_mask': padded_mask.detach().cpu(),
                    'clip_embedding': clip_embedding.to(torch.bfloat16).detach().cpu(),
                }
            except Exception as err:
                print(f"{err}")
                return None
# NOTE(review): this rebinds the module-level `lock` created earlier; because
# both API classes resolve the global at call time, VAE and caption requests
# end up serialised on this single Lock — confirm that is intended.
lock = threading.Lock()
class Captionapi(Resource):
    """GET endpoint: unpickles a prompt payload, encodes it, returns pickled embeddings."""

    def __init__(self, caption_pipeline):
        self.caption_pipeline = caption_pipeline

    def get(self):
        # Serialise requests: encoding is stateful GPU work.
        with lock:
            try:
                # NOTE: pickle on request data is only safe with trusted clients.
                feature = pickle.loads(request.get_data())
                feature['api'] = 'caption'
                feature = {k: v for k, v in feature.items() if v is not None}
                embeddings = self.caption_pipeline.embedding(**feature)
                response = pickle.dumps(embeddings)
            except Exception as e:
                print("Caught Exception: ", e)
                # Bug fix: Response(e) passed an Exception object as the body,
                # which itself raises TypeError inside Flask; return the message
                # text with an explicit 500 status instead.
                return Response(str(e), status=500)
            return Response(response)
class RemoteServer(object):
    """Flask app exposing /vae-api and /caption-api over pickled payloads."""

    def __init__(self, args) -> None:
        self.app = Flask(__name__)
        self.app.register_blueprint(Blueprint("root", __name__))
        api = Api(self.app)

        # VAE decode endpoint.
        self.vae_pipeline = StepVaePipeline(vae_dir=os.path.join(args.model_dir, args.vae_dir))
        api.add_resource(
            VAEapi,
            "/vae-api",
            resource_class_args=[self.vae_pipeline],
        )

        # Text-embedding endpoint (Step-LLM + CLIP).
        self.caption_pipeline = CaptionPipeline(
            llm_dir=os.path.join(args.model_dir, args.llm_dir),
            clip_dir=os.path.join(args.model_dir, args.clip_dir),
        )
        api.add_resource(
            Captionapi,
            "/caption-api",
            resource_class_args=[self.caption_pipeline],
        )

    def run(self, host="0.0.0.0", port=8080):
        """Start the blocking Flask development server (threaded, no debug)."""
        self.app.run(host, port=port, threaded=True, debug=False)
if __name__ == "__main__":
args = parsed_args()
flask_server = RemoteServer(args)
flask_server.run(host="0.0.0.0", port=args.port)
import argparse
import json
import os
import torch
import torch.distributed as dist
from diffusers.utils import export_to_video
from fastvideo.models.mochi_hf.pipeline_mochi import MochiPipeline
def generate_video_and_latent(pipe, prompt, height, width, num_frames, num_inference_steps, guidance_scale):
    """Run one text-to-video sampling pass and return its intermediate tensors.

    Returns (noise, video, latent, prompt_embed, prompt_attention_mask) for
    the first sample of the batch. The embeds are taken at index 1 because
    index 0 holds the negative prompt.
    """
    # Fixed CUDA seed so every prompt is sampled reproducibly.
    rng = torch.Generator("cuda").manual_seed(12345)
    outputs = pipe(
        prompt=prompt,
        height=height,
        width=width,
        num_frames=num_frames,
        generator=rng,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        output_type="latent_and_video",
    )
    noise, video, latent, prompt_embed, prompt_attention_mask = outputs
    return noise[0], video[0], latent[0], prompt_embed[1], prompt_attention_mask[1]
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--num_frames", type=int, default=163)
parser.add_argument("--height", type=int, default=480)
parser.add_argument("--width", type=int, default=848)
parser.add_argument("--num_inference_steps", type=int, default=64)
parser.add_argument("--guidance_scale", type=float, default=4.5)
parser.add_argument("--model_path", type=str, default="data/mochi")
parser.add_argument("--prompt_path", type=str, default="data/dummyVid/videos2caption.json")
parser.add_argument("--dataset_output_dir", type=str, default="data/dummySynthetic")
args = parser.parse_args()
local_rank = int(os.getenv("RANK", 0))
world_size = int(os.getenv("WORLD_SIZE", 1))
print("world_size", world_size, "local rank", local_rank)
torch.cuda.set_device(local_rank)
dist.init_process_group(backend="nccl", init_method="env://", world_size=world_size, rank=local_rank)
if not isinstance(args.prompt_path, list):
args.prompt_path = [args.prompt_path]
if len(args.prompt_path) == 1 and args.prompt_path[0].endswith("txt"):
text_prompt = open(args.prompt_path[0], "r").readlines()
text_prompt = [i.strip() for i in text_prompt]
pipe = MochiPipeline.from_pretrained(args.model_path, torch_dtype=torch.bfloat16)
pipe.enable_vae_tiling()
pipe.enable_model_cpu_offload(gpu_id=local_rank)
# make dir if not exist
os.makedirs(args.dataset_output_dir, exist_ok=True)
os.makedirs(os.path.join(args.dataset_output_dir, "noise"), exist_ok=True)
os.makedirs(os.path.join(args.dataset_output_dir, "video"), exist_ok=True)
os.makedirs(os.path.join(args.dataset_output_dir, "latent"), exist_ok=True)
os.makedirs(os.path.join(args.dataset_output_dir, "prompt_embed"), exist_ok=True)
os.makedirs(os.path.join(args.dataset_output_dir, "prompt_attention_mask"), exist_ok=True)
data = []
for i, prompt in enumerate(text_prompt):
if i % world_size != local_rank:
continue
(
noise,
video,
latent,
prompt_embed,
prompt_attention_mask,
) = generate_video_and_latent(
pipe,
prompt,
args.height,
args.width,
args.num_frames,
args.num_inference_steps,
args.guidance_scale,
)
# save latent
video_name = str(i)
noise_path = os.path.join(args.dataset_output_dir, "noise", video_name + ".pt")
latent_path = os.path.join(args.dataset_output_dir, "latent", video_name + ".pt")
prompt_embed_path = os.path.join(args.dataset_output_dir, "prompt_embed", video_name + ".pt")
video_path = os.path.join(args.dataset_output_dir, "video", video_name + ".mp4")
prompt_attention_mask_path = os.path.join(args.dataset_output_dir, "prompt_attention_mask", video_name + ".pt")
# save latent
torch.save(noise, noise_path)
torch.save(latent, latent_path)
torch.save(prompt_embed, prompt_embed_path)
torch.save(prompt_attention_mask, prompt_attention_mask_path)
export_to_video(video, video_path, fps=30)
item = {}
item["cap"] = prompt
item["video"] = video_name + ".mp4"
item["noise"] = video_name + ".pt"
item["latent_path"] = video_name + ".pt"
item["prompt_embed_path"] = video_name + ".pt"
item["prompt_attention_mask"] = video_name + ".pt"
data.append(item)
dist.barrier()
local_data = data
gathered_data = [None] * world_size
dist.all_gather_object(gathered_data, local_data)
# save json
if local_rank == 0:
all_data = [item for sublist in gathered_data for item in sublist]
with open(os.path.join(args.dataset_output_dir, "videos2caption.json"), "w") as f:
json.dump(all_data, f, indent=4)
import argparse
import os
from pathlib import Path
import imageio
import numpy as np
import torch
import torch.distributed as dist
import torchvision
from einops import rearrange
from fastvideo.models.hunyuan.inference import HunyuanVideoSampler
from fastvideo.utils.parallel_states import initialize_sequence_parallel_state, nccl_info
def initialize_distributed():
    """Initialise the NCCL process group and sequence-parallel state from env vars."""
    rank = int(os.getenv("RANK", 0))
    num_ranks = int(os.getenv("WORLD_SIZE", 1))
    print("world_size", num_ranks)
    torch.cuda.set_device(rank)
    dist.init_process_group(backend="nccl", init_method="env://", world_size=num_ranks, rank=rank)
    initialize_sequence_parallel_state(num_ranks)
def main(args):
    """Sample one video per prompt with HunyuanVideoSampler and write mp4s."""
    initialize_distributed()
    print(nccl_info.sp_size)
    print(args)
    models_root_path = Path(args.model_path)
    if not models_root_path.exists():
        raise ValueError(f"`models_root` not exists: {models_root_path}")
    # Create save folder to save the samples
    save_path = args.output_path
    # NOTE(review): this creates only the *parent* of output_path, while the
    # mimsave below joins files under output_path itself (also re-created later);
    # confirm output_path is meant to be a directory.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    # Load models
    hunyuan_video_sampler = HunyuanVideoSampler.from_pretrained(models_root_path, args=args)
    # Get the updated args
    args = hunyuan_video_sampler.args
    # A .txt prompt argument is treated as a file with one prompt per line.
    if args.prompt.endswith('.txt'):
        with open(args.prompt) as f:
            prompts = [line.strip() for line in f.readlines()]
    else:
        prompts = [args.prompt]
    for prompt in prompts:
        outputs = hunyuan_video_sampler.predict(
            prompt=prompt,
            height=args.height,
            width=args.width,
            video_length=args.num_frames,
            seed=args.seed,
            negative_prompt=args.neg_prompt,
            infer_steps=args.num_inference_steps,
            guidance_scale=args.guidance_scale,
            num_videos_per_prompt=args.num_videos,
            flow_shift=args.flow_shift,
            batch_size=args.batch_size,
            embedded_guidance_scale=args.embedded_cfg_scale,
        )
        # Lay the batch out as a frame-major grid and convert to uint8 frames.
        videos = rearrange(outputs["samples"], "b c t h w -> t b c h w")
        outputs = []
        for x in videos:
            x = torchvision.utils.make_grid(x, nrow=6)
            x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
            outputs.append((x * 255).numpy().astype(np.uint8))
        os.makedirs(os.path.dirname(args.output_path), exist_ok=True)
        # File name is the first 100 chars of the prompt.
        imageio.mimsave(os.path.join(args.output_path, f"{prompt[:100]}.mp4"), outputs, fps=args.fps)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Basic parameters
parser.add_argument("--prompt", type=str, help="prompt file for inference")
parser.add_argument("--num_frames", type=int, default=16)
parser.add_argument("--height", type=int, default=256)
parser.add_argument("--width", type=int, default=256)
parser.add_argument("--num_inference_steps", type=int, default=50)
parser.add_argument("--model_path", type=str, default="data/hunyuan")
parser.add_argument("--output_path", type=str, default="./outputs/video")
parser.add_argument("--fps", type=int, default=24)
# Additional parameters
parser.add_argument(
"--denoise-type",
type=str,
default="flow",
help="Denoise type for noised inputs.",
)
parser.add_argument("--seed", type=int, default=None, help="Seed for evaluation.")
parser.add_argument("--neg_prompt", type=str, default=None, help="Negative prompt for sampling.")
parser.add_argument(
"--guidance_scale",
type=float,
default=1.0,
help="Classifier free guidance scale.",
)
parser.add_argument(
"--embedded_cfg_scale",
type=float,
default=6.0,
help="Embedded classifier free guidance scale.",
)
parser.add_argument("--flow_shift", type=int, default=7, help="Flow shift parameter.")
parser.add_argument("--batch_size", type=int, default=1, help="Batch size for inference.")
parser.add_argument(
"--num_videos",
type=int,
default=1,
help="Number of videos to generate per prompt.",
)
parser.add_argument(
"--load-key",
type=str,
default="module",
help="Key to load the model states. 'module' for the main model, 'ema' for the EMA model.",
)
parser.add_argument(
"--use-cpu-offload",
action="store_true",
help="Use CPU offload for the model load.",
)
parser.add_argument(
"--dit-weight",
type=str,
default="data/hunyuan/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt",
)
parser.add_argument(
"--reproduce",
action="store_true",
help="Enable reproducibility by setting random seeds and deterministic algorithms.",
)
parser.add_argument(
"--disable-autocast",
action="store_true",
help="Disable autocast for denoising loop and vae decoding in pipeline sampling.",
)
# Flow Matching
parser.add_argument(
"--flow-reverse",
action="store_true",
help="If reverse, learning/sampling from t=1 -> t=0.",
)
parser.add_argument("--flow-solver", type=str, default="euler", help="Solver for flow matching.")
parser.add_argument(
"--use-linear-quadratic-schedule",
action="store_true",
help=
"Use linear quadratic schedule for flow matching. Following MovieGen (https://ai.meta.com/static-resource/movie-gen-research-paper)",
)
parser.add_argument(
"--linear-schedule-end",
type=int,
default=25,
help="End step for linear quadratic schedule for flow matching.",
)
# Model parameters
parser.add_argument("--model", type=str, default="HYVideo-T/2-cfgdistill")
parser.add_argument("--latent-channels", type=int, default=16)
parser.add_argument("--precision", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
parser.add_argument("--rope-theta", type=int, default=256, help="Theta used in RoPE.")
parser.add_argument("--vae", type=str, default="884-16c-hy")
parser.add_argument("--vae-precision", type=str, default="fp16", choices=["fp32", "fp16", "bf16"])
parser.add_argument("--vae-tiling", action="store_true", default=True)
parser.add_argument("--vae-sp", action="store_true", default=False)
parser.add_argument("--text-encoder", type=str, default="llm")
parser.add_argument(
"--text-encoder-precision",
type=str,
default="fp16",
choices=["fp32", "fp16", "bf16"],
)
parser.add_argument("--text-states-dim", type=int, default=4096)
parser.add_argument("--text-len", type=int, default=256)
parser.add_argument("--tokenizer", type=str, default="llm")
parser.add_argument("--prompt-template", type=str, default="dit-llm-encode")
parser.add_argument("--prompt-template-video", type=str, default="dit-llm-encode-video")
parser.add_argument("--hidden-state-skip-layer", type=int, default=2)
parser.add_argument("--apply-final-norm", action="store_true")
parser.add_argument("--text-encoder-2", type=str, default="clipL")
parser.add_argument(
"--text-encoder-precision-2",
type=str,
default="fp16",
choices=["fp32", "fp16", "bf16"],
)
parser.add_argument(
"--enable_torch_compile",
action="store_true",
help="Use torch.compile for speeding up STA inference without teacache",
)
parser.add_argument("--text-states-dim-2", type=int, default=768)
parser.add_argument("--tokenizer-2", type=str, default="clipL")
parser.add_argument("--text-len-2", type=int, default=77)
args = parser.parse_args()
# process for vae sequence parallel
if args.vae_sp and not args.vae_tiling:
raise ValueError("Currently enabling vae_sp requires enabling vae_tiling, please set --vae-tiling to True.")
main(args)
import argparse
import json
import os
from pathlib import Path
from typing import Any, Dict, Optional, Union
import imageio
import numpy as np
import torch
import torch.distributed as dist
import torchvision
from einops import rearrange
from fastvideo.models.hunyuan.inference import HunyuanVideoSampler
from fastvideo.models.hunyuan.modules.modulate_layers import modulate
from fastvideo.utils.parallel_states import initialize_sequence_parallel_state, nccl_info
def teacache_forward(
    self,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    timestep: torch.LongTensor,
    encoder_attention_mask: torch.Tensor,
    mask_strategy=None,
    output_features=False,
    output_features_stride=8,
    attention_kwargs: Optional[Dict[str, Any]] = None,
    return_dict: bool = False,
    guidance=None,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
    """HunyuanVideo transformer forward pass with TeaCache step skipping.

    When ``self.enable_teacache`` is set, a polynomially rescaled relative-L1
    distance between the current and previous modulated inputs of the first
    double block is accumulated; while the accumulation stays below
    ``self.rel_l1_thresh`` the cached residual from the last fully computed
    step is reused instead of running all transformer blocks.

    Returns ``(img, features_list)``: the unpatchified output and, when
    ``output_features`` is set, a stacked tensor of intermediate single-block
    outputs sampled every ``output_features_stride`` blocks (else ``None``).
    """
    if guidance is None:
        # Default embedded-CFG strength when the caller does not supply one.
        guidance = torch.tensor([6016.0], device=hidden_states.device, dtype=torch.bfloat16)
    img = x = hidden_states
    text_mask = encoder_attention_mask
    t = timestep
    # Position 0 of encoder_hidden_states carries the pooled text vector used
    # for modulation; the remaining positions are token embeddings.
    txt = encoder_hidden_states[:, 1:]
    text_states_2 = encoder_hidden_states[:, 0, :self.config.text_states_dim_2]
    _, _, ot, oh, ow = x.shape  # codespell:ignore
    tt, th, tw = (
        ot // self.patch_size[0],  # codespell:ignore
        oh // self.patch_size[1],  # codespell:ignore
        ow // self.patch_size[2],  # codespell:ignore
    )
    # RoPE is computed over the full (un-sharded) temporal length.
    original_tt = nccl_info.sp_size * tt
    freqs_cos, freqs_sin = self.get_rotary_pos_embed((original_tt, th, tw))
    # Prepare modulation vectors.
    vec = self.time_in(t)
    # text modulation
    vec = vec + self.vector_in(text_states_2)
    # guidance modulation
    if self.guidance_embed:
        # guidance was defaulted above, so this cannot trigger today; kept as a
        # guard in case the default is ever removed.
        if guidance is None:
            raise ValueError("Didn't get guidance strength for guidance distilled model.")
        # our timestep_embedding is merged into guidance_in(TimestepEmbedder)
        vec = vec + self.guidance_in(guidance)
    # Embed image and text.
    img = self.img_in(img)
    if self.text_projection == "linear":
        txt = self.txt_in(txt)
    elif self.text_projection == "single_refiner":
        txt = self.txt_in(txt, t, text_mask if self.use_attention_mask else None)
    else:
        raise NotImplementedError(f"Unsupported text_projection: {self.text_projection}")
    txt_seq_len = txt.shape[1]
    img_seq_len = img.shape[1]
    freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
    if self.enable_teacache:
        # Probe the first double block's modulated input to decide whether the
        # full forward can be skipped for this step.
        inp = img.clone()
        vec_ = vec.clone()
        (
            img_mod1_shift,
            img_mod1_scale,
            img_mod1_gate,
            img_mod2_shift,
            img_mod2_scale,
            img_mod2_gate,
        ) = self.double_blocks[0].img_mod(vec_).chunk(6, dim=-1)
        normed_inp = self.double_blocks[0].img_norm1(inp)
        modulated_inp = modulate(normed_inp, shift=img_mod1_shift, scale=img_mod1_scale)
        if self.cnt == 0 or self.cnt == self.num_steps - 1:
            # Always compute the first and last steps to anchor the cache.
            should_calc = True
            self.accumulated_rel_l1_distance = 0
        else:
            # Empirical polynomial rescaling of the relative-L1 change.
            coefficients = [7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02]
            rescale_func = np.poly1d(coefficients)
            self.accumulated_rel_l1_distance += rescale_func(
                ((modulated_inp - self.previous_modulated_input).abs().mean() /
                 self.previous_modulated_input.abs().mean()).cpu().item())
            if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
                should_calc = False
            else:
                should_calc = True
                self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = modulated_inp
        self.cnt += 1
        if self.cnt == self.num_steps:
            self.cnt = 0
    if self.enable_teacache:
        if not should_calc:
            # Cheap step: reuse the residual cached by the last full pass.
            # NOTE(review): if output_features is set on a skipped step,
            # features_list is never built and the stack below raises —
            # confirm callers never combine teacache with output_features.
            img += self.previous_residual
        else:
            ori_img = img.clone()
            # --------------------- Pass through DiT blocks ------------------------
            for index, block in enumerate(self.double_blocks):
                double_block_args = [img, txt, vec, freqs_cis, text_mask, mask_strategy[index]]
                img, txt = block(*double_block_args)
            # Merge txt and img to pass through single stream blocks.
            x = torch.cat((img, txt), 1)
            if output_features:
                features_list = []
            if len(self.single_blocks) > 0:
                for index, block in enumerate(self.single_blocks):
                    single_block_args = [
                        x,
                        vec,
                        txt_seq_len,
                        (freqs_cos, freqs_sin),
                        text_mask,
                        mask_strategy[index + len(self.double_blocks)],
                    ]
                    x = block(*single_block_args)
                    # Bug fix: the stride check previously used the stale `_`
                    # left over from the shape unpack above; it must use the
                    # loop index so features are sampled every
                    # output_features_stride blocks.
                    if output_features and index % output_features_stride == 0:
                        features_list.append(x[:, :img_seq_len, ...])
            img = x[:, :img_seq_len, ...]
            # Cache the residual so subsequent cheap steps can reuse it.
            self.previous_residual = img - ori_img
    else:
        # --------------------- Pass through DiT blocks ------------------------
        for index, block in enumerate(self.double_blocks):
            double_block_args = [img, txt, vec, freqs_cis, text_mask, mask_strategy[index]]
            img, txt = block(*double_block_args)
        # Merge txt and img to pass through single stream blocks.
        x = torch.cat((img, txt), 1)
        if output_features:
            features_list = []
        if len(self.single_blocks) > 0:
            for index, block in enumerate(self.single_blocks):
                single_block_args = [
                    x,
                    vec,
                    txt_seq_len,
                    (freqs_cos, freqs_sin),
                    text_mask,
                    mask_strategy[index + len(self.double_blocks)],
                ]
                x = block(*single_block_args)
                # Bug fix: same stale `_` replaced by the loop index (see above).
                if output_features and index % output_features_stride == 0:
                    features_list.append(x[:, :img_seq_len, ...])
        img = x[:, :img_seq_len, ...]
    # ---------------------------- Final layer ------------------------------
    img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
    img = self.unpatchify(img, tt, th, tw)
    assert not return_dict, "return_dict is not supported."
    if output_features:
        features_list = torch.stack(features_list, dim=0)
    else:
        features_list = None
    return (img, features_list)
def initialize_distributed():
    """Set up the NCCL process group and sequence-parallel state from env vars."""
    rank = int(os.getenv("RANK", 0))
    num_ranks = int(os.getenv("WORLD_SIZE", 1))
    print("world_size", num_ranks)
    torch.cuda.set_device(rank)
    dist.init_process_group(backend="nccl", init_method="env://", world_size=num_ranks, rank=rank)
    initialize_sequence_parallel_state(num_ranks)
def main(args):
    """Sample videos with a TeaCache-patched HunyuanVideoSampler, one per prompt."""
    initialize_distributed()
    print(nccl_info.sp_size)
    print(args)
    models_root_path = Path(args.model_path)
    if not models_root_path.exists():
        raise ValueError(f"`models_root` not exists: {models_root_path}")
    # Create save folder to save the samples
    save_path = args.output_path
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    # Load models
    hunyuan_video_sampler = HunyuanVideoSampler.from_pretrained(models_root_path, args=args)
    # Get the updated args
    args = hunyuan_video_sampler.args
    # teacache
    # TeaCache state is monkey-patched onto the transformer *class*, so it is
    # shared by every instance of that class in this process.
    hunyuan_video_sampler.pipeline.transformer.__class__.enable_teacache = args.enable_teacache
    hunyuan_video_sampler.pipeline.transformer.__class__.cnt = 0
    hunyuan_video_sampler.pipeline.transformer.__class__.num_steps = args.num_inference_steps
    hunyuan_video_sampler.pipeline.transformer.__class__.rel_l1_thresh = args.rel_l1_thresh  # 0.1 for 1.6x speedup, 0.15 for 2.1x speedup
    hunyuan_video_sampler.pipeline.transformer.__class__.accumulated_rel_l1_distance = 0
    hunyuan_video_sampler.pipeline.transformer.__class__.previous_modulated_input = None
    hunyuan_video_sampler.pipeline.transformer.__class__.previous_residual = None
    hunyuan_video_sampler.pipeline.transformer.__class__.forward = teacache_forward
    # Per-block sliding-tile attention strategy loaded from disk.
    with open(args.mask_strategy_file_path, 'r') as f:
        mask_strategy = json.load(f)
    # A .txt prompt argument is treated as a file with one prompt per line.
    if args.prompt.endswith('.txt'):
        with open(args.prompt) as f:
            prompts = [line.strip() for line in f.readlines()]
    else:
        prompts = [args.prompt]
    for prompt in prompts:
        outputs = hunyuan_video_sampler.predict(
            prompt=prompt,
            height=args.height,
            width=args.width,
            video_length=args.num_frames,
            seed=args.seed,
            negative_prompt=args.neg_prompt,
            infer_steps=args.num_inference_steps,
            guidance_scale=args.guidance_scale,
            num_videos_per_prompt=args.num_videos,
            flow_shift=args.flow_shift,
            batch_size=args.batch_size,
            embedded_guidance_scale=args.embedded_cfg_scale,
            mask_strategy=mask_strategy,
        )
        # Lay the batch out as a frame-major grid and convert to uint8 frames.
        videos = rearrange(outputs["samples"], "b c t h w -> t b c h w")
        outputs = []
        for x in videos:
            x = torchvision.utils.make_grid(x, nrow=6)
            x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
            outputs.append((x * 255).numpy().astype(np.uint8))
        os.makedirs(os.path.dirname(args.output_path), exist_ok=True)
        # File name is the first 100 chars of the prompt.
        imageio.mimsave(os.path.join(args.output_path, f"{prompt[:100]}.mp4"), outputs, fps=args.fps)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Basic parameters
parser.add_argument("--prompt", type=str, help="prompt file for inference")
parser.add_argument("--num_frames", type=int, default=16)
parser.add_argument("--height", type=int, default=256)
parser.add_argument("--width", type=int, default=256)
parser.add_argument("--num_inference_steps", type=int, default=50)
parser.add_argument("--model_path", type=str, default="data/hunyuan")
parser.add_argument("--output_path", type=str, default="./outputs/video")
parser.add_argument("--fps", type=int, default=24)
# Additional parameters
parser.add_argument(
"--sliding_block_size",
type=str,
default="8,6,10",
help="Sliding block size for sliding block attention.",
)
parser.add_argument(
"--denoise-type",
type=str,
default="flow",
help="Denoise type for noised inputs.",
)
parser.add_argument("--seed", type=int, default=None, help="Seed for evaluation.")
parser.add_argument("--neg_prompt", type=str, default=None, help="Negative prompt for sampling.")
parser.add_argument(
"--guidance_scale",
type=float,
default=1.0,
help="Classifier free guidance scale.",
)
parser.add_argument(
"--embedded_cfg_scale",
type=float,
default=6.0,
help="Embedded classifier free guidance scale.",
)
parser.add_argument("--flow_shift", type=int, default=7, help="Flow shift parameter.")
parser.add_argument("--batch_size", type=int, default=1, help="Batch size for inference.")
parser.add_argument(
"--num_videos",
type=int,
default=1,
help="Number of videos to generate per prompt.",
)
parser.add_argument(
"--load-key",
type=str,
default="module",
help="Key to load the model states. 'module' for the main model, 'ema' for the EMA model.",
)
parser.add_argument(
"--use-cpu-offload",
action="store_true",
help="Use CPU offload for the model load.",
)
parser.add_argument(
"--dit-weight",
type=str,
default="data/hunyuan/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt",
)
parser.add_argument(
"--reproduce",
action="store_true",
help="Enable reproducibility by setting random seeds and deterministic algorithms.",
)
parser.add_argument(
"--disable-autocast",
action="store_true",
help="Disable autocast for denoising loop and vae decoding in pipeline sampling.",
)
# Flow Matching
parser.add_argument(
"--flow-reverse",
action="store_true",
help="If reverse, learning/sampling from t=1 -> t=0.",
)
parser.add_argument("--flow-solver", type=str, default="euler", help="Solver for flow matching.")
parser.add_argument(
"--use-linear-quadratic-schedule",
action="store_true",
help=
"Use linear quadratic schedule for flow matching. Following MovieGen (https://ai.meta.com/static-resource/movie-gen-research-paper)",
)
parser.add_argument(
"--linear-schedule-end",
type=int,
default=25,
help="End step for linear quadratic schedule for flow matching.",
)
# Model parameters
parser.add_argument("--model", type=str, default="HYVideo-T/2-cfgdistill")
parser.add_argument("--latent-channels", type=int, default=16)
parser.add_argument("--precision", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
parser.add_argument("--rope-theta", type=int, default=256, help="Theta used in RoPE.")
parser.add_argument("--vae", type=str, default="884-16c-hy")
parser.add_argument("--vae-precision", type=str, default="fp16", choices=["fp32", "fp16", "bf16"])
parser.add_argument("--vae-tiling", action="store_true", default=True)
parser.add_argument("--vae-sp", action="store_true", default=False)
parser.add_argument("--text-encoder", type=str, default="llm")
parser.add_argument(
"--text-encoder-precision",
type=str,
default="fp16",
choices=["fp32", "fp16", "bf16"],
)
parser.add_argument("--text-states-dim", type=int, default=4096)
parser.add_argument("--text-len", type=int, default=256)
parser.add_argument("--tokenizer", type=str, default="llm")
parser.add_argument("--prompt-template", type=str, default="dit-llm-encode")
parser.add_argument("--prompt-template-video", type=str, default="dit-llm-encode-video")
parser.add_argument("--hidden-state-skip-layer", type=int, default=2)
parser.add_argument("--apply-final-norm", action="store_true")
parser.add_argument("--text-encoder-2", type=str, default="clipL")
parser.add_argument(
"--text-encoder-precision-2",
type=str,
default="fp16",
choices=["fp32", "fp16", "bf16"],
)
parser.add_argument("--text-states-dim-2", type=int, default=768)
parser.add_argument("--tokenizer-2", type=str, default="clipL")
parser.add_argument("--text-len-2", type=int, default=77)
parser.add_argument("--skip_time_steps", type=int, default=10)
parser.add_argument(
"--mask_strategy_selected",
type=lambda x: [int(i) for i in x.strip('[]').split(',')], # Convert string to list of integers
default=[1, 2, 6], # Now can be directly set as a list
help="order of candidates")
parser.add_argument(
"--rel_l1_thresh",
type=float,
default=0.15,
help="0.1 for 1.6x speedup, 0.15 for 2.1x speedup",
)
parser.add_argument(
"--enable_teacache",
action="store_true",
help="Use teacache for speeding up inference",
)
parser.add_argument(
"--enable_torch_compile",
action="store_true",
help="Use torch.compile for speeding up STA inference without teacache",
)
parser.add_argument("--mask_strategy_file_path", type=str, default="assets/mask_strategy.json")
args = parser.parse_args()
# process for vae sequence parallel
if args.vae_sp and not args.vae_tiling:
raise ValueError("Currently enabling vae_sp requires enabling vae_tiling, please set --vae-tiling to True.")
if args.enable_teacache and args.enable_torch_compile:
raise ValueError(
"--enable_teacache and --enable_torch_compile cannot be used simultaneously. Please enable only one of these options."
)
main(args)
import argparse
import json
import os
import time
import torch
import torch.distributed as dist
from diffusers import BitsAndBytesConfig
from diffusers.utils import export_to_video
from fastvideo.models.hunyuan_hf.modeling_hunyuan import HunyuanVideoTransformer3DModel
from fastvideo.models.hunyuan_hf.pipeline_hunyuan import HunyuanVideoPipeline
from fastvideo.utils.parallel_states import initialize_sequence_parallel_state, nccl_info
def initialize_distributed():
    """Initialise NCCL distributed state and sequence parallelism for this process."""
    # Avoid tokenizer fork-related parallelism warnings in multi-process runs.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    rank = int(os.getenv("RANK", 0))
    num_ranks = int(os.getenv("WORLD_SIZE", 1))
    print("world_size", num_ranks)
    torch.cuda.set_device(rank)
    dist.init_process_group(backend="nccl", init_method="env://", world_size=num_ranks, rank=rank)
    initialize_sequence_parallel_state(num_ranks)
def inference(args):
    """Run HunyuanVideoPipeline inference from prompts or precomputed embeddings.

    Prompt sources, in priority order: --prompt_embed_path (precomputed
    embeddings), --prompt_path (text file, one prompt per line), --prompts.
    Videos are exported to args.output_path by rank 0 only.
    """
    initialize_distributed()
    print(nccl_info.sp_size)
    device = torch.cuda.current_device()
    # Peiyuan: GPU seed will cause A100 and H100 to produce different results .....
    weight_dtype = torch.bfloat16
    # An explicit transformer checkpoint overrides the one inside model_path.
    if args.transformer_path is not None:
        transformer = HunyuanVideoTransformer3DModel.from_pretrained(args.transformer_path)
    else:
        transformer = HunyuanVideoTransformer3DModel.from_pretrained(args.model_path,
                                                                     subfolder="transformer/",
                                                                     torch_dtype=weight_dtype)
    pipe = HunyuanVideoPipeline.from_pretrained(args.model_path, transformer=transformer, torch_dtype=weight_dtype)
    pipe.enable_vae_tiling()
    if args.lora_checkpoint_dir is not None:
        print(f"Loading LoRA weights from {args.lora_checkpoint_dir}")
        config_path = os.path.join(args.lora_checkpoint_dir, "lora_config.json")
        with open(config_path, "r") as f:
            lora_config_dict = json.load(f)
        rank = lora_config_dict["lora_params"]["lora_rank"]
        lora_alpha = lora_config_dict["lora_params"]["lora_alpha"]
        # Standard LoRA scaling: alpha / rank.
        lora_scaling = lora_alpha / rank
        pipe.load_lora_weights(args.lora_checkpoint_dir, adapter_name="default")
        pipe.set_adapters(["default"], [lora_scaling])
        print(f"Successfully Loaded LoRA weights from {args.lora_checkpoint_dir}")
    if args.cpu_offload:
        pipe.enable_model_cpu_offload(device)
    else:
        pipe.to(device)
    # Generate videos from the input prompt
    if args.prompt_embed_path is not None:
        prompt_embeds = (torch.load(args.prompt_embed_path, map_location="cpu",
                                    weights_only=True).to(device).unsqueeze(0))
        encoder_attention_mask = (torch.load(args.encoder_attention_mask_path, map_location="cpu",
                                             weights_only=True).to(device).unsqueeze(0))
        prompts = None
    elif args.prompt_path is not None:
        prompts = [line.strip() for line in open(args.prompt_path, "r")]
        prompt_embeds = None
        encoder_attention_mask = None
    else:
        prompts = args.prompts
        prompt_embeds = None
        encoder_attention_mask = None
    if prompts is not None:
        with torch.autocast("cuda", dtype=torch.bfloat16):
            for prompt in prompts:
                # CPU generator for cross-hardware reproducibility (see note above).
                generator = torch.Generator("cpu").manual_seed(args.seed)
                video = pipe(
                    prompt=[prompt],
                    height=args.height,
                    width=args.width,
                    num_frames=args.num_frames,
                    num_inference_steps=args.num_inference_steps,
                    generator=generator,
                ).frames
                if nccl_info.global_rank <= 0:
                    os.makedirs(args.output_path, exist_ok=True)
                    # NOTE(review): file name is the prompt text before the first
                    # "." — prompts sharing that prefix overwrite each other; confirm.
                    suffix = prompt.split(".")[0]
                    export_to_video(
                        video[0],
                        os.path.join(args.output_path, f"{suffix}.mp4"),
                        fps=24,
                    )
    else:
        with torch.autocast("cuda", dtype=torch.bfloat16):
            generator = torch.Generator("cpu").manual_seed(args.seed)
            videos = pipe(
                prompt_embeds=prompt_embeds,
                prompt_attention_mask=encoder_attention_mask,
                height=args.height,
                width=args.width,
                num_frames=args.num_frames,
                num_inference_steps=args.num_inference_steps,
                generator=generator,
            ).frames
            if nccl_info.global_rank <= 0:
                export_to_video(videos[0], args.output_path + ".mp4", fps=24)
def inference_quantization(args):
    """Run HunyuanVideo text-to-video inference, optionally with bitsandbytes
    nf4/int8 quantization of the transformer, reporting peak VRAM at each stage.

    Prompts come from ``args.prompt`` — either a ``.txt`` file (one prompt per
    line) or a single prompt string. One mp4 is written per prompt under
    ``args.output_path``.
    """
    torch.manual_seed(args.seed)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_id = args.model_path
    # Load the transformer under the requested quantization scheme. The
    # branches are mutually exclusive; an unrecognized non-empty value is
    # rejected up front (previously it fell through all branches and crashed
    # later with a NameError on `transformer`).
    if args.quantization == "nf4":
        quantization_config = BitsAndBytesConfig(load_in_4bit=True,
                                                 bnb_4bit_compute_dtype=torch.bfloat16,
                                                 bnb_4bit_quant_type="nf4",
                                                 llm_int8_skip_modules=["proj_out", "norm_out"])
        transformer = HunyuanVideoTransformer3DModel.from_pretrained(model_id,
                                                                     subfolder="transformer/",
                                                                     torch_dtype=torch.bfloat16,
                                                                     quantization_config=quantization_config)
    elif args.quantization == "int8":
        quantization_config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_skip_modules=["proj_out", "norm_out"])
        transformer = HunyuanVideoTransformer3DModel.from_pretrained(model_id,
                                                                     subfolder="transformer/",
                                                                     torch_dtype=torch.bfloat16,
                                                                     quantization_config=quantization_config)
    elif not args.quantization:
        transformer = HunyuanVideoTransformer3DModel.from_pretrained(model_id,
                                                                     subfolder="transformer/",
                                                                     torch_dtype=torch.bfloat16).to(device)
    else:
        raise ValueError(f"Unsupported quantization mode: {args.quantization!r} (expected 'nf4', 'int8' or empty)")
    print("Max vram for read transformer:", round(torch.cuda.max_memory_allocated(device="cuda") / 1024**3, 3), "GiB")
    torch.cuda.reset_max_memory_allocated(device)
    # With CPU offload the pipeline must not be moved to the GPU eagerly.
    if not args.cpu_offload:
        pipe = HunyuanVideoPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(device)
        pipe.transformer = transformer
    else:
        pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
    torch.cuda.reset_max_memory_allocated(device)
    pipe.scheduler._shift = args.flow_shift
    pipe.vae.enable_tiling()
    if args.cpu_offload:
        pipe.enable_model_cpu_offload()
    print("Max vram for init pipeline:", round(torch.cuda.max_memory_allocated(device="cuda") / 1024**3, 3), "GiB")
    # One prompt per line when a .txt file is given, otherwise a single prompt.
    if args.prompt.endswith('.txt'):
        with open(args.prompt) as f:
            prompts = [line.strip() for line in f.readlines()]
    else:
        prompts = [args.prompt]
    generator = torch.Generator("cpu").manual_seed(args.seed)
    os.makedirs(os.path.dirname(args.output_path), exist_ok=True)
    torch.cuda.reset_max_memory_allocated(device)
    for prompt in prompts:
        start_time = time.perf_counter()
        output = pipe(
            prompt=prompt,
            height=args.height,
            width=args.width,
            num_frames=args.num_frames,
            num_inference_steps=args.num_inference_steps,
            generator=generator,
        ).frames[0]
        # Truncate the prompt so it stays a reasonable file name.
        export_to_video(output, os.path.join(args.output_path, f"{prompt[:100]}.mp4"), fps=args.fps)
        print("Time:", round(time.perf_counter() - start_time, 2), "seconds")
        print("Max vram for denoise:", round(torch.cuda.max_memory_allocated(device="cuda") / 1024**3, 3), "GiB")
if __name__ == "__main__":
    # CLI for the Hunyuan inference scripts. NOTE: flags below mix dash-style
    # (e.g. --load-key) and underscore-style names; argparse exposes both as
    # underscore attributes (args.load_key).
    parser = argparse.ArgumentParser()
    # Basic parameters
    parser.add_argument("--prompt", type=str, help="prompt file for inference")
    parser.add_argument("--prompt_embed_path", type=str, default=None)
    parser.add_argument("--prompt_path", type=str, default=None)
    parser.add_argument("--num_frames", type=int, default=16)
    parser.add_argument("--height", type=int, default=256)
    parser.add_argument("--width", type=int, default=256)
    parser.add_argument("--num_inference_steps", type=int, default=50)
    parser.add_argument("--model_path", type=str, default="data/hunyuan")
    parser.add_argument("--transformer_path", type=str, default=None)
    parser.add_argument("--output_path", type=str, default="./outputs/video")
    parser.add_argument("--fps", type=int, default=24)
    # None/empty => full-precision path; "nf4"/"int8" => bitsandbytes path.
    parser.add_argument("--quantization", type=str, default=None)
    parser.add_argument("--cpu_offload", action="store_true")
    parser.add_argument(
        "--lora_checkpoint_dir",
        type=str,
        default=None,
        help="Path to the directory containing LoRA checkpoints",
    )
    # Additional parameters
    parser.add_argument(
        "--denoise-type",
        type=str,
        default="flow",
        help="Denoise type for noised inputs.",
    )
    parser.add_argument("--seed", type=int, default=None, help="Seed for evaluation.")
    parser.add_argument("--neg_prompt", type=str, default=None, help="Negative prompt for sampling.")
    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=1.0,
        help="Classifier free guidance scale.",
    )
    parser.add_argument(
        "--embedded_cfg_scale",
        type=float,
        default=6.0,
        help="Embedded classifier free guidance scale.",
    )
    parser.add_argument("--flow_shift", type=int, default=7, help="Flow shift parameter.")
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size for inference.")
    parser.add_argument(
        "--num_videos",
        type=int,
        default=1,
        help="Number of videos to generate per prompt.",
    )
    parser.add_argument(
        "--load-key",
        type=str,
        default="module",
        help="Key to load the model states. 'module' for the main model, 'ema' for the EMA model.",
    )
    parser.add_argument(
        "--dit-weight",
        type=str,
        default="data/hunyuan/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt",
    )
    parser.add_argument(
        "--reproduce",
        action="store_true",
        help="Enable reproducibility by setting random seeds and deterministic algorithms.",
    )
    parser.add_argument(
        "--disable-autocast",
        action="store_true",
        help="Disable autocast for denoising loop and vae decoding in pipeline sampling.",
    )
    # Flow Matching
    parser.add_argument(
        "--flow-reverse",
        action="store_true",
        help="If reverse, learning/sampling from t=1 -> t=0.",
    )
    parser.add_argument("--flow-solver", type=str, default="euler", help="Solver for flow matching.")
    parser.add_argument(
        "--use-linear-quadratic-schedule",
        action="store_true",
        help=
        "Use linear quadratic schedule for flow matching. Following MovieGen (https://ai.meta.com/static-resource/movie-gen-research-paper)",
    )
    parser.add_argument(
        "--linear-schedule-end",
        type=int,
        default=25,
        help="End step for linear quadratic schedule for flow matching.",
    )
    # Model parameters
    parser.add_argument("--model", type=str, default="HYVideo-T/2-cfgdistill")
    parser.add_argument("--latent-channels", type=int, default=16)
    parser.add_argument("--precision", type=str, default="bf16", choices=["fp32", "fp16", "bf16", "fp8"])
    parser.add_argument("--rope-theta", type=int, default=256, help="Theta used in RoPE.")
    parser.add_argument("--vae", type=str, default="884-16c-hy")
    parser.add_argument("--vae-precision", type=str, default="fp16", choices=["fp32", "fp16", "bf16"])
    parser.add_argument("--vae-tiling", action="store_true", default=True)
    parser.add_argument("--text-encoder", type=str, default="llm")
    parser.add_argument(
        "--text-encoder-precision",
        type=str,
        default="fp16",
        choices=["fp32", "fp16", "bf16"],
    )
    parser.add_argument("--text-states-dim", type=int, default=4096)
    parser.add_argument("--text-len", type=int, default=256)
    parser.add_argument("--tokenizer", type=str, default="llm")
    parser.add_argument("--prompt-template", type=str, default="dit-llm-encode")
    parser.add_argument("--prompt-template-video", type=str, default="dit-llm-encode-video")
    parser.add_argument("--hidden-state-skip-layer", type=int, default=2)
    parser.add_argument("--apply-final-norm", action="store_true")
    parser.add_argument("--text-encoder-2", type=str, default="clipL")
    parser.add_argument(
        "--text-encoder-precision-2",
        type=str,
        default="fp16",
        choices=["fp32", "fp16", "bf16"],
    )
    parser.add_argument("--text-states-dim-2", type=int, default=768)
    parser.add_argument("--tokenizer-2", type=str, default="clipL")
    parser.add_argument("--text-len-2", type=int, default=77)
    args = parser.parse_args()
    # Dispatch: quantized path when --quantization is set, otherwise the
    # default bf16 path (`inference` is defined earlier in this file).
    if args.quantization:
        inference_quantization(args)
    else:
        inference(args)
import argparse
import json
import os
import torch
import torch.distributed as dist
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import export_to_video
from fastvideo.distill.solver import PCMFMScheduler
from fastvideo.models.mochi_hf.modeling_mochi import MochiTransformer3DModel
from fastvideo.models.mochi_hf.pipeline_mochi import MochiPipeline
from fastvideo.utils.parallel_states import initialize_sequence_parallel_state, nccl_info
def initialize_distributed():
    """Join the NCCL process group and set up sequence parallelism.

    Rank and world size are taken from the RANK / WORLD_SIZE environment
    variables (as set by the launcher).
    """
    rank = int(os.getenv("RANK", 0))
    n_ranks = int(os.getenv("WORLD_SIZE", 1))
    print("world_size", n_ranks)
    # Pin this process to its GPU before creating the process group.
    torch.cuda.set_device(rank)
    dist.init_process_group(backend="nccl", init_method="env://", world_size=n_ranks, rank=rank)
    initialize_sequence_parallel_state(n_ranks)
def main(args):
    """Distributed Mochi text-to-video inference entry point.

    Prompt sources, in priority order: a precomputed prompt embedding
    (``--prompt_embed_path`` plus ``--encoder_attention_mask_path``), a prompt
    file (one prompt per line), or the literal prompts in ``args.prompts``.
    Videos are exported as mp4 by rank 0 only.
    """
    initialize_distributed()
    print(nccl_info.sp_size)
    device = torch.cuda.current_device()
    # Peiyuan: GPU seed will cause A100 and H100 to produce different results .....
    if args.scheduler_type == "euler":
        scheduler = FlowMatchEulerDiscreteScheduler()
    else:
        linear_quadratic = True if "linear_quadratic" in args.scheduler_type else False
        scheduler = PCMFMScheduler(
            1000,
            args.shift,
            args.num_euler_timesteps,
            linear_quadratic,
            args.linear_threshold,
            args.linear_range,
        )
    if args.transformer_path is not None:
        transformer = MochiTransformer3DModel.from_pretrained(args.transformer_path)
    else:
        transformer = MochiTransformer3DModel.from_pretrained(args.model_path, subfolder="transformer/")
    pipe = MochiPipeline.from_pretrained(args.model_path, transformer=transformer, scheduler=scheduler)
    pipe.enable_vae_tiling()
    if args.lora_checkpoint_dir is not None:
        print(f"Loading LoRA weights from {args.lora_checkpoint_dir}")
        config_path = os.path.join(args.lora_checkpoint_dir, "lora_config.json")
        with open(config_path, "r") as f:
            lora_config_dict = json.load(f)
        rank = lora_config_dict["lora_params"]["lora_rank"]
        lora_alpha = lora_config_dict["lora_params"]["lora_alpha"]
        # PEFT scaling convention: alpha / rank.
        lora_scaling = lora_alpha / rank
        pipe.load_lora_weights(args.lora_checkpoint_dir, adapter_name="default")
        pipe.set_adapters(["default"], [lora_scaling])
        print(f"Successfully Loaded LoRA weights from {args.lora_checkpoint_dir}")
    # pipe.to(device)
    pipe.enable_model_cpu_offload(device)
    # Generate videos from the input prompt
    if args.prompt_embed_path is not None:
        prompt_embeds = (torch.load(args.prompt_embed_path, map_location="cpu",
                                    weights_only=True).to(device).unsqueeze(0))
        encoder_attention_mask = (torch.load(args.encoder_attention_mask_path, map_location="cpu",
                                             weights_only=True).to(device).unsqueeze(0))
        prompts = None
    elif args.prompt_path is not None:
        # Close the prompt file promptly (the original leaked the handle).
        with open(args.prompt_path, "r") as f:
            prompts = [line.strip() for line in f]
        prompt_embeds = None
        encoder_attention_mask = None
    else:
        prompts = args.prompts
        prompt_embeds = None
        encoder_attention_mask = None
    if prompts is not None:
        with torch.autocast("cuda", dtype=torch.bfloat16):
            for prompt in prompts:
                # Re-seed per prompt so every prompt is independently reproducible.
                generator = torch.Generator("cpu").manual_seed(args.seed)
                video = pipe(
                    prompt=[prompt],
                    height=args.height,
                    width=args.width,
                    num_frames=args.num_frames,
                    num_inference_steps=args.num_inference_steps,
                    guidance_scale=args.guidance_scale,
                    generator=generator,
                ).frames
                if nccl_info.global_rank <= 0:
                    os.makedirs(args.output_path, exist_ok=True)
                    # First sentence of the prompt doubles as the file name.
                    suffix = prompt.split(".")[0]
                    export_to_video(
                        video[0],
                        os.path.join(args.output_path, f"{suffix}.mp4"),
                        fps=30,
                    )
    else:
        with torch.autocast("cuda", dtype=torch.bfloat16):
            generator = torch.Generator("cpu").manual_seed(args.seed)
            videos = pipe(
                prompt_embeds=prompt_embeds,
                prompt_attention_mask=encoder_attention_mask,
                height=args.height,
                width=args.width,
                num_frames=args.num_frames,
                num_inference_steps=args.num_inference_steps,
                guidance_scale=args.guidance_scale,
                generator=generator,
            ).frames
        if nccl_info.global_rank <= 0:
            export_to_video(videos[0], args.output_path + ".mp4", fps=30)
if __name__ == "__main__":
    # arg parse
    # CLI for distributed Mochi inference; defaults match Mochi's native
    # 480x848 resolution and 163-frame output.
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompts", nargs="+", default=[])
    parser.add_argument("--num_frames", type=int, default=163)
    parser.add_argument("--height", type=int, default=480)
    parser.add_argument("--width", type=int, default=848)
    parser.add_argument("--num_inference_steps", type=int, default=64)
    parser.add_argument("--guidance_scale", type=float, default=4.5)
    parser.add_argument("--model_path", type=str, default="data/mochi")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--output_path", type=str, default="./outputs.mp4")
    parser.add_argument("--transformer_path", type=str, default=None)
    # Precomputed text-embedding inputs (used together, bypass the text encoder).
    parser.add_argument("--prompt_embed_path", type=str, default=None)
    parser.add_argument("--prompt_path", type=str, default=None)
    parser.add_argument("--scheduler_type", type=str, default="euler")
    parser.add_argument("--encoder_attention_mask_path", type=str, default=None)
    parser.add_argument(
        "--lora_checkpoint_dir",
        type=str,
        default=None,
        help="Path to the directory containing LoRA checkpoints",
    )
    # PCMFM scheduler parameters (ignored for --scheduler_type euler).
    parser.add_argument("--shift", type=float, default=8.0)
    parser.add_argument("--num_euler_timesteps", type=int, default=100)
    parser.add_argument("--linear_threshold", type=float, default=0.025)
    parser.add_argument("--linear_range", type=float, default=0.5)
    args = parser.parse_args()
    main(args)
import argparse
import torch
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import export_to_video
from fastvideo.models.mochi_hf.modeling_mochi import MochiTransformer3DModel
from fastvideo.models.mochi_hf.pipeline_mochi import MochiPipeline
def main(args):
    """Single-process Mochi T2V: render every CLI prompt and export to mp4."""
    # Set the random seed for reproducibility
    generator = torch.Generator("cuda").manual_seed(args.seed)
    # do not invert
    scheduler = FlowMatchEulerDiscreteScheduler()
    # Prefer an explicit transformer checkpoint over the bundled subfolder.
    transformer = (MochiTransformer3DModel.from_pretrained(args.transformer_path)
                   if args.transformer_path is not None else
                   MochiTransformer3DModel.from_pretrained(args.model_path, subfolder="transformer/"))
    pipe = MochiPipeline.from_pretrained(args.model_path, transformer=transformer, scheduler=scheduler)
    pipe.enable_vae_tiling()
    # pipe.to("cuda:1")
    pipe.enable_model_cpu_offload()
    # Generate videos from the input prompt
    with torch.autocast("cuda", dtype=torch.bfloat16):
        videos = pipe(
            prompt=args.prompts,
            height=args.height,
            width=args.width,
            num_frames=args.num_frames,
            generator=generator,
            num_inference_steps=args.num_inference_steps,
            guidance_scale=args.guidance_scale,
        ).frames
    # One output file per prompt, named after the prompt text.
    for caption, frames_seq in zip(args.prompts, videos):
        export_to_video(frames_seq, args.output_path + f"_{caption}.mp4", fps=30)
if __name__ == "__main__":
    # arg parse
    # CLI for single-process Mochi inference (no distributed setup).
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompts", nargs="+", default=[])
    parser.add_argument("--num_frames", type=int, default=163)
    parser.add_argument("--height", type=int, default=480)
    parser.add_argument("--width", type=int, default=848)
    parser.add_argument("--num_inference_steps", type=int, default=64)
    parser.add_argument("--guidance_scale", type=float, default=4.5)
    parser.add_argument("--model_path", type=str, default="data/mochi")
    parser.add_argument("--seed", type=int, default=12345)
    parser.add_argument("--transformer_path", type=str, default=None)
    parser.add_argument("--output_path", type=str, default="./outputs.mp4")
    args = parser.parse_args()
    main(args)
import argparse
import os
import torch
import torch.distributed as dist
import torch.nn as nn
from fastvideo.models.stepvideo.diffusion.scheduler import FlowMatchDiscreteScheduler
from fastvideo.models.stepvideo.diffusion.video_pipeline import StepVideoPipeline
from fastvideo.models.stepvideo.modules.model import StepVideoModel
from fastvideo.models.stepvideo.utils import setup_seed
from fastvideo.models.stepvideo.utils.quantization import convert_fp8_linear, fp8_linear_forward
from fastvideo.utils.logging_ import main_print
from fastvideo.utils.parallel_states import initialize_sequence_parallel_state, nccl_info
def initialize_distributed():
    """Create the NCCL process group and sequence-parallel state for StepVideo."""
    # Tokenizer fork warnings are noisy under multiprocessing launchers.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    proc_rank = int(os.getenv("RANK", 0))
    total_procs = int(os.getenv("WORLD_SIZE", 1))
    print("world_size", total_procs)
    # Bind this rank to its GPU before the collective backend comes up.
    torch.cuda.set_device(proc_rank)
    dist.init_process_group(backend="nccl", init_method="env://", world_size=total_procs, rank=proc_rank)
    initialize_sequence_parallel_state(total_procs)
def parse_args(namespace=None):
    """Build the StepVideo CLI parser from its argument groups and parse argv."""
    parser = argparse.ArgumentParser(description="StepVideo inference script")
    # Each helper registers one argument group and returns the parser.
    for register in (add_extra_models_args, add_denoise_schedule_args, add_inference_args):
        parser = register(parser)
    return parser.parse_args(namespace=namespace)
def add_extra_models_args(parser: argparse.ArgumentParser):
    """Register the remote VAE / caption service endpoints on *parser*."""
    group = parser.add_argument_group(title="Extra models args, including vae, text encoders and tokenizers)")
    group.add_argument("--vae_url", type=str, default='127.0.0.1', help="vae url.")
    group.add_argument("--caption_url", type=str, default='127.0.0.1', help="caption url.")
    return parser
def add_denoise_schedule_args(parser: argparse.ArgumentParser):
    """Register flow-matching denoise-schedule options on *parser*."""
    group = parser.add_argument_group(title="Denoise schedule args")
    # Flow Matching
    group.add_argument("--time_shift",
                       type=float,
                       default=13,
                       help="Shift factor for flow matching schedulers.")
    group.add_argument("--flow_reverse",
                       action="store_true",
                       help="If reverse, learning/sampling from t=1 -> t=0.")
    group.add_argument("--flow_solver",
                       type=str,
                       default="euler",
                       help="Solver for flow matching.")
    return parser
def add_inference_args(parser: argparse.ArgumentParser):
    """Register model-loading, sampling-size, prompt and CFG options on *parser*."""
    group = parser.add_argument_group(title="Inference args")
    # ======================== Model loads ========================
    group.add_argument("--model_dir", type=str, default="./ckpts",
                       help="Root path of all the models, including t2v models and extra models.")
    group.add_argument("--model_resolution", type=str, default="540p", choices=["540p"],
                       help="Root path of all the models, including t2v models and extra models.")
    group.add_argument("--use-cpu-offload", action="store_true",
                       help="Use CPU offload for the model load.")
    group.add_argument("--use-fp8", action="store_true",
                       help="FP8 Quantization for single GPU support.")
    # ======================== Inference general setting ========================
    group.add_argument("--batch_size", type=int, default=1,
                       help="Batch size for inference and evaluation.")
    group.add_argument("--infer_steps", type=int, default=50,
                       help="Number of denoising steps for inference.")
    group.add_argument("--save_path", type=str, default="./results",
                       help="Path to save the generated samples.")
    group.add_argument("--name_suffix", type=str, default="",
                       help="Suffix for the names of saved samples.")
    group.add_argument("--num_videos", type=int, default=1,
                       help="Number of videos to generate for each prompt.")
    # ---sample size---
    group.add_argument("--num_frames", type=int, default=204,
                       help="How many frames to sample from a video. ")
    group.add_argument("--height", type=int, default=768,
                       help="The height of video sample")
    group.add_argument("--width", type=int, default=768,
                       help="The width of video sample")
    # --- prompt ---
    group.add_argument("--prompt", type=str, default=None,
                       help="Prompt for sampling during evaluation.")
    group.add_argument("--seed", type=int, default=1234, help="Seed for evaluation.")
    # Classifier-Free Guidance
    group.add_argument("--pos_magic",
                       type=str,
                       default="超高清、HDR 视频、环境光、杜比全景声、画面稳定、流畅动作、逼真的细节、专业级构图、超现实主义、自然、生动、超细节、清晰。",
                       help="Positive magic prompt for sampling.")
    group.add_argument("--neg_magic",
                       type=str,
                       default="画面暗、低分辨率、不良手、文本、缺少手指、多余的手指、裁剪、低质量、颗粒状、签名、水印、用户名、模糊。",
                       help="Negative magic prompt for sampling.")
    group.add_argument("--cfg_scale", type=float, default=9.0, help="Classifier free guidance scale.")
    return parser
if __name__ == "__main__":
    # StepVideo inference driver with optional FP8 weight quantization.
    args = parse_args()
    initialize_distributed()
    main_print(f"sequence parallel size: {nccl_info.sp_size}")
    device = torch.cuda.current_device()
    setup_seed(args.seed)
    main_print("Loading model, this might take a while...")
    scheduler = FlowMatchDiscreteScheduler()
    if args.use_fp8:
        # FP8 path only supports a single GPU.
        assert int(os.getenv("WORLD_SIZE", 1)) == 1
        transformer = StepVideoModel.from_pretrained(os.path.join(args.model_dir, "transformer"),
                                                     torch_dtype=torch.bfloat16,
                                                     device="cpu")
        # Build (or reuse) an FP8 copy of the weights plus per-layer scales.
        if not os.path.exists(args.model_dir + "/fp8_transformer.pth"):
            print("no_fp8 weight, creating...")
            scale_dict = convert_fp8_linear(transformer, torch.bfloat16)
            torch.save(transformer.state_dict(), args.model_dir + "/fp8_transformer.pth")
            torch.save(scale_dict, args.model_dir + "/fp8_scale_dict.pth")
        else:
            transformer.load_state_dict(torch.load(args.model_dir + "/fp8_transformer.pth"))
            scale_dict = torch.load(args.model_dir + "/fp8_scale_dict.pth")
        original_dtype = torch.bfloat16
        # Patch each scaled Linear inside the transformer blocks: cast its
        # weight to float8_e4m3fn and route forward through fp8_linear_forward.
        for key, layer in transformer.named_modules():
            if isinstance(layer, nn.Linear) and 'transformer_blocks' in key and key in scale_dict:
                layer.weight.data = layer.weight.data.to(torch.float8_e4m3fn)
                print(f"{key}, layer.weight.dtype: {layer.weight.dtype}")
                original_forward = layer.forward
                scale = scale_dict[key]
                setattr(layer, "fp8_scale", scale.to(dtype=original_dtype))
                setattr(layer, "original_forward", original_forward)
                # m=layer binds the current layer eagerly (avoids the
                # late-binding-closure pitfall inside this loop).
                setattr(layer, "forward", lambda input, m=layer: fp8_linear_forward(m, original_dtype, input))
    else:
        transformer = StepVideoModel.from_pretrained(os.path.join(args.model_dir, "transformer"),
                                                     torch_dtype=torch.bfloat16,
                                                     device=device)
    transformer = transformer.to(device)
    pipeline = StepVideoPipeline(transformer, scheduler, save_path=args.save_path)
    # VAE decode and captioning run as remote services.
    pipeline.setup_api(
        vae_url=args.vae_url,
        caption_url=args.caption_url,
    )
    # NOTE(review): --prompt defaults to None; this crashes on .endswith if
    # the flag is omitted — confirm callers always pass it.
    if args.prompt.endswith('.txt'):
        with open(args.prompt) as f:
            prompts = [line.strip() for line in f.readlines()]
    else:
        prompts = [args.prompt]
    for prompt in prompts:
        videos = pipeline(prompt=prompt,
                          num_frames=args.num_frames,
                          height=args.height,
                          width=args.width,
                          num_inference_steps=args.infer_steps,
                          guidance_scale=args.cfg_scale,
                          time_shift=args.time_shift,
                          pos_magic=args.pos_magic,
                          neg_magic=args.neg_magic,
                          output_file_name=prompt[:50])
    dist.destroy_process_group()
import argparse
import json
import os
import types
from typing import Dict, Optional
import numpy as np
import torch
import torch.distributed as dist
from einops import rearrange, repeat
from fastvideo.models.stepvideo.diffusion.scheduler import FlowMatchDiscreteScheduler
from fastvideo.models.stepvideo.diffusion.video_pipeline import StepVideoPipeline
from fastvideo.models.stepvideo.modules.model import StepVideoModel
from fastvideo.models.stepvideo.utils import setup_seed
from fastvideo.utils.logging_ import main_print
from fastvideo.utils.parallel_states import initialize_sequence_parallel_state, nccl_info
def initialize_distributed():
    """Bring up NCCL and sequence-parallel state from RANK/WORLD_SIZE env vars."""
    # Silence tokenizer fork warnings before any worker processes spawn.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    this_rank, world_size = int(os.getenv("RANK", 0)), int(os.getenv("WORLD_SIZE", 1))
    main_print(f"world_size: {world_size}")
    torch.cuda.set_device(this_rank)
    dist.init_process_group(backend="nccl", init_method="env://", world_size=world_size, rank=this_rank)
    initialize_sequence_parallel_state(world_size)
def parse_args(namespace=None):
    """Parse StepVideo CLI arguments (extra models, denoise schedule, inference)."""
    parser = argparse.ArgumentParser(description="StepVideo inference script")
    parser = add_extra_models_args(parser)
    parser = add_denoise_schedule_args(parser)
    parser = add_inference_args(parser)
    return parser.parse_args(namespace=namespace)
def add_extra_models_args(parser: argparse.ArgumentParser):
    """Add the endpoints of the remote VAE and caption services."""
    group = parser.add_argument_group(title="Extra models args, including vae, text encoders and tokenizers)")
    # Both services default to localhost.
    for flag, doc in (("--vae_url", "vae url."), ("--caption_url", "caption url.")):
        group.add_argument(flag, type=str, default='127.0.0.1', help=doc)
    return parser
def add_denoise_schedule_args(parser: argparse.ArgumentParser):
    """Add flow-matching schedule flags (shift, direction, solver)."""
    group = parser.add_argument_group(title="Denoise schedule args")
    # Flow Matching
    group.add_argument("--time_shift", type=float, default=13, help="Shift factor for flow matching schedulers.")
    group.add_argument("--flow_reverse", action="store_true", help="If reverse, learning/sampling from t=1 -> t=0.")
    group.add_argument("--flow_solver", type=str, default="euler", help="Solver for flow matching.")
    return parser
def add_inference_args(parser: argparse.ArgumentParser):
    """Register inference options: model loading, sampling sizes, prompts,
    CFG, sliding-tile attention mask strategy, and TeaCache speedup flags.

    Returns the same parser for chaining.
    """
    group = parser.add_argument_group(title="Inference args")
    # ======================== Model loads ========================
    group.add_argument(
        "--model_dir",
        type=str,
        default="./ckpts",
        help="Root path of all the models, including t2v models and extra models.",
    )
    group.add_argument(
        "--model_resolution",
        type=str,
        default="540p",
        choices=["540p"],
        help="Root path of all the models, including t2v models and extra models.",
    )
    group.add_argument(
        "--use-cpu-offload",
        action="store_true",
        help="Use CPU offload for the model load.",
    )
    # ======================== Inference general setting ========================
    group.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="Batch size for inference and evaluation.",
    )
    group.add_argument(
        "--infer_steps",
        type=int,
        default=50,
        help="Number of denoising steps for inference.",
    )
    group.add_argument(
        "--save_path",
        type=str,
        default="./results",
        help="Path to save the generated samples.",
    )
    group.add_argument(
        "--name_suffix",
        type=str,
        default="",
        help="Suffix for the names of saved samples.",
    )
    group.add_argument(
        "--num_videos",
        type=int,
        default=1,
        help="Number of videos to generate for each prompt.",
    )
    # ---sample size---
    group.add_argument(
        "--num_frames",
        type=int,
        default=204,
        help="How many frames to sample from a video. ",
    )
    group.add_argument(
        "--height",
        type=int,
        default=768,
        help="The height of video sample",
    )
    group.add_argument(
        "--width",
        type=int,
        default=768,
        help="The width of video sample",
    )
    # --- prompt ---
    group.add_argument(
        "--prompt",
        type=str,
        default=None,
        help="Prompt for sampling during evaluation.",
    )
    group.add_argument("--seed", type=int, default=1234, help="Seed for evaluation.")
    # Classifier-Free Guidance
    group.add_argument("--pos_magic",
                       type=str,
                       default="超高清、HDR 视频、环境光、杜比全景声、画面稳定、流畅动作、逼真的细节、专业级构图、超现实主义、自然、生动、超细节、清晰。",
                       help="Positive magic prompt for sampling.")
    group.add_argument("--neg_magic",
                       type=str,
                       default="画面暗、低分辨率、不良手、文本、缺少手指、多余的手指、裁剪、低质量、颗粒状、签名、水印、用户名、模糊。",
                       help="Negative magic prompt for sampling.")
    group.add_argument("--cfg_scale", type=float, default=9.0, help="Classifier free guidance scale.")
    group.add_argument("--mask_search_files_path", type=str, default="assets/mask_strategy.json")
    group.add_argument("--mask_strategy_file_path", type=str, default="assets/mask_strategy_stepvideo.json")
    group.add_argument("--skip_time_steps", type=int, default=10)
    group.add_argument(
        "--mask_strategy_selected",
        type=lambda x: [int(i) for i in x.strip('[]').split(',')],  # Convert string to list of integers
        default=[1, 2, 6],  # Now can be directly set as a list
        help="order of candidates")
    # Registered on `group` like everything else (the original used
    # parser.add_argument here, which bypassed the argument group).
    group.add_argument(
        "--rel_l1_thresh",
        type=float,
        default=0,
        help="0.22 for 1.67x speedup, 0.23 for 2.1x speedup",
    )
    group.add_argument(
        "--enable_teacache",
        action="store_true",
        help="Use teacache for speeding up inference",
    )
    return parser
def teacache_forward(
    self,
    hidden_states: torch.Tensor,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    encoder_hidden_states_2: Optional[torch.Tensor] = None,
    timestep: Optional[torch.LongTensor] = None,
    added_cond_kwargs: Dict[str, torch.Tensor] = None,
    encoder_attention_mask: Optional[torch.Tensor] = None,
    fps: torch.Tensor = None,
    return_dict: bool = True,
    mask_strategy=None,
):
    """StepVideoModel forward with TeaCache step skipping.

    Drop-in replacement for the model's forward (bound via types.MethodType).
    When ``self.enable_teacache`` is set, it estimates the relative change of
    the modulated input between denoising steps and, when the accumulated
    change is below ``self.rel_l1_thresh``, reuses the residual cached from the
    previous step instead of running the DiT blocks.

    ``hidden_states`` is expected as (bsz, f, ch, h, w); returns
    ``{'x': output}`` with output (b, f, c, h, w) when ``return_dict`` is True.
    """
    # The original placed the message on its own line, making it a no-op
    # string statement; it belongs on the assert.
    assert hidden_states.ndim == 5, "hidden_states's shape should be (bsz, f, ch, h ,w)"
    bsz, frame, _, height, width = hidden_states.shape
    height, width = height // self.patch_size, width // self.patch_size
    hidden_states = self.patchfy(hidden_states)
    len_frame = hidden_states.shape[1]
    if self.use_additional_conditions:
        added_cond_kwargs = {
            "resolution": torch.tensor([(height, width)] * bsz, device=hidden_states.device, dtype=hidden_states.dtype),
            "nframe": torch.tensor([frame] * bsz, device=hidden_states.device, dtype=hidden_states.dtype),
            "fps": fps
        }
    else:
        added_cond_kwargs = {}
    timestep, embedded_timestep = self.adaln_single(timestep, added_cond_kwargs=added_cond_kwargs)
    encoder_hidden_states = self.caption_projection(self.caption_norm(encoder_hidden_states))
    # Optional CLIP conditioning is prepended to the caption tokens.
    if encoder_hidden_states_2 is not None and hasattr(self, 'clip_projection'):
        clip_embedding = self.clip_projection(encoder_hidden_states_2)
        encoder_hidden_states = torch.cat([clip_embedding, encoder_hidden_states], dim=1)
    hidden_states = rearrange(hidden_states, '(b f) l d-> b (f l) d', b=bsz, f=frame, l=len_frame).contiguous()
    embedded_timestep = repeat(embedded_timestep, 'b d -> (b f) d', f=frame).contiguous()
    shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
    encoder_hidden_states, attn_mask = self.prepare_attn_mask(encoder_attention_mask,
                                                              encoder_hidden_states,
                                                              q_seqlen=frame * len_frame)
    if self.enable_teacache:
        # Probe: modulate a clone through the first block's norm to measure
        # how much this step's input differs from the previous one.
        hidden_states_ = hidden_states.clone()
        normed_hidden_states = self.transformer_blocks[0].norm1(hidden_states_)
        normed_hidden_states = rearrange(normed_hidden_states, 'b (f l) d -> (b f) l d', b=bsz, f=frame, l=len_frame)
        modulated_inp = normed_hidden_states * (1 + scale) + shift
        # Always compute on the first and last denoising step.
        if self.cnt == 0 or self.cnt == self.num_steps - 1:
            should_calc = True
            self.accumulated_rel_l1_distance = 0
        else:
            # Polynomial rescaling of the raw relative-L1 distance
            # (coefficients presumably fit offline for StepVideo — TODO confirm).
            coefficients = [6.74352814e+03, -2.22814115e+03, 2.55029094e+02, -1.12338285e+01, 2.84921593e-01]
            rescale_func = np.poly1d(coefficients)
            self.accumulated_rel_l1_distance += rescale_func(
                ((modulated_inp - self.previous_modulated_input).abs().mean() /
                 self.previous_modulated_input.abs().mean()).cpu().item())
            if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
                # print(f"accumulated_rel_l1_distance: {self.accumulated_rel_l1_distance}")
                should_calc = False
            else:
                # print(f"accumulated_rel_l1_distance: {self.accumulated_rel_l1_distance}")
                should_calc = True
                self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = modulated_inp
        self.cnt += 1
        if self.cnt == self.num_steps:
            self.cnt = 0
    if self.enable_teacache:
        if not should_calc:
            # Skip the DiT blocks: replay the residual cached last step.
            hidden_states += self.previous_residual
        else:
            # Full computation; cache the residual for potential reuse.
            ori_hidden_states = hidden_states.clone()
            hidden_states = self.block_forward(hidden_states,
                                               encoder_hidden_states,
                                               timestep=timestep,
                                               rope_positions=[frame, height, width],
                                               attn_mask=attn_mask,
                                               parallel=self.parallel,
                                               mask_strategy=mask_strategy)
            self.previous_residual = hidden_states - ori_hidden_states
    else:
        # --------------------- Pass through DiT blocks ------------------------
        hidden_states = self.block_forward(hidden_states,
                                           encoder_hidden_states,
                                           timestep=timestep,
                                           rope_positions=[frame, height, width],
                                           attn_mask=attn_mask,
                                           parallel=self.parallel,
                                           mask_strategy=mask_strategy)
    # ---------------------------- Final layer ------------------------------
    hidden_states = rearrange(hidden_states, 'b (f l) d -> (b f) l d', b=bsz, f=frame, l=len_frame)
    hidden_states = self.norm_out(hidden_states)
    # Modulation
    hidden_states = hidden_states * (1 + scale) + shift
    hidden_states = self.proj_out(hidden_states)
    # unpatchify
    hidden_states = hidden_states.reshape(shape=(-1, height, width, self.patch_size, self.patch_size,
                                                 self.out_channels))
    hidden_states = rearrange(hidden_states, 'n h w p q c -> n c h p w q')
    output = hidden_states.reshape(shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size))
    output = rearrange(output, '(b f) c h w -> b f c h w', f=frame)
    if return_dict:
        return {'x': output}
    return output
if __name__ == "__main__":
    # StepVideo inference driver with TeaCache step skipping and a sliding-tile
    # attention mask strategy loaded from JSON.
    args = parse_args()
    initialize_distributed()
    main_print(f"sequence parallel size: {nccl_info.sp_size}")
    device = torch.cuda.current_device()
    setup_seed(args.seed)
    main_print("Loading model, this might take a while...")
    transformer = StepVideoModel.from_pretrained(os.path.join(args.model_dir, "transformer"),
                                                 torch_dtype=torch.bfloat16,
                                                 device_map=device)
    # Rebind the model's forward to the TeaCache variant defined above.
    if args.enable_teacache:
        transformer.forward = types.MethodType(teacache_forward, transformer)
    scheduler = FlowMatchDiscreteScheduler()
    pipeline = StepVideoPipeline(transformer, scheduler, save_path=args.save_path)
    pipeline.setup_api(
        vae_url=args.vae_url,
        caption_url=args.caption_url,
    )
    # TeaCache
    # NOTE(review): these are class-level attributes, so they apply to every
    # instance, and they are set even when --enable_teacache is off (the stock
    # forward ignores them) — confirm this is intended.
    pipeline.transformer.__class__.enable_teacache = True
    pipeline.transformer.__class__.cnt = 0
    pipeline.transformer.__class__.num_steps = args.infer_steps
    pipeline.transformer.__class__.rel_l1_thresh = args.rel_l1_thresh  # 0.1 for 1.6x speedup, 0.15 for 2.1x speedup
    pipeline.transformer.__class__.accumulated_rel_l1_distance = 0
    pipeline.transformer.__class__.previous_modulated_input = None
    pipeline.transformer.__class__.previous_residual = None
    with open(args.mask_strategy_file_path, 'r') as f:
        mask_strategy = json.load(f)
    # One prompt per line when a .txt file is given, otherwise a single prompt.
    if args.prompt.endswith('.txt'):
        with open(args.prompt) as f:
            prompts = [line.strip() for line in f.readlines()]
    else:
        prompts = [args.prompt]
    for prompt in prompts:
        main_print(f"Generating video for prompt: {prompt}")
        videos = pipeline(prompt=prompt,
                          num_frames=args.num_frames,
                          height=args.height,
                          width=args.width,
                          num_inference_steps=args.infer_steps,
                          guidance_scale=args.cfg_scale,
                          time_shift=args.time_shift,
                          pos_magic=args.pos_magic,
                          neg_magic=args.neg_magic,
                          output_file_name=prompt[:150],
                          mask_strategy=mask_strategy)
    dist.destroy_process_group()
# !/bin/python3
# isort: skip_file
import argparse
import math
import os
import time
from collections import deque
import torch
import torch.distributed as dist
import wandb
from accelerate.utils import set_seed
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, convert_unet_state_dict_to_peft
from peft import LoraConfig, set_peft_model_state_dict
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from tqdm.auto import tqdm
from fastvideo.dataset.latent_datasets import (LatentDataset, latent_collate_function)
from fastvideo.models.mochi_hf.mochi_latents_utils import normalize_dit_input
from fastvideo.models.mochi_hf.pipeline_mochi import MochiPipeline
from fastvideo.models.hunyuan_hf.pipeline_hunyuan import HunyuanVideoPipeline
from fastvideo.utils.checkpoint import (resume_lora_optimizer, save_checkpoint, save_lora_checkpoint)
from fastvideo.utils.communications import (broadcast, sp_parallel_dataloader_wrapper)
from fastvideo.utils.dataset_utils import LengthGroupedSampler
from fastvideo.utils.fsdp_util import (apply_fsdp_checkpointing, get_dit_fsdp_kwargs)
from fastvideo.utils.load import load_transformer
from fastvideo.utils.logging_ import main_print
from fastvideo.utils.parallel_states import (destroy_sequence_parallel_group, get_sequence_parallel_state,
initialize_sequence_parallel_state)
from fastvideo.utils.validation import log_validation
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.31.0")
torch._dynamo.config.capture_scalar_outputs = True
def compute_density_for_timestep_sampling(
    weighting_scheme: str,
    batch_size: int,
    generator,
    logit_mean: float = None,
    logit_std: float = None,
    mode_scale: float = None,
):
    """
    Sample a batch of timestep densities u in [0, 1] for SD3-style training.

    Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
    SD3 paper reference: https://arxiv.org/abs/2403.03206v1.

    "logit_normal" draws a Gaussian and squashes it through a sigmoid (see 3.1
    in the SD3 paper, $rf/lognorm(0.00,1.00)$); "mode" reweights a uniform draw
    with `mode_scale`; anything else is plain uniform.
    """
    if weighting_scheme == "logit_normal":
        gauss = torch.normal(
            mean=logit_mean,
            std=logit_std,
            size=(batch_size, ),
            device="cpu",
            generator=generator,
        )
        return torch.nn.functional.sigmoid(gauss)
    u = torch.rand(size=(batch_size, ), device="cpu", generator=generator)
    if weighting_scheme == "mode":
        u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2)**2 - 1 + u)
    return u
def get_sigmas(noise_scheduler, device, timesteps, n_dim=4, dtype=torch.float32):
    """Look up the scheduler sigma for each timestep and broadcast to `n_dim` dims.

    Each entry of `timesteps` is matched against `noise_scheduler.timesteps`
    (assumes every value is present exactly once), and the selected sigmas are
    returned with trailing singleton dimensions so they broadcast over latents.
    """
    all_sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
    schedule = noise_scheduler.timesteps.to(device)
    ts = timesteps.to(device)
    positions = [(schedule == t).nonzero().item() for t in ts]
    sigma = all_sigmas[positions].flatten()
    missing = n_dim - sigma.dim()
    if missing > 0:
        sigma = sigma.reshape(sigma.shape + (1,) * missing)
    return sigma
def train_one_step(
    transformer,
    model_type,
    optimizer,
    lr_scheduler,
    noise_scheduler,
    noise_random_generator,
    gradient_accumulation_steps,
    sp_size,
    precondition_outputs,
    max_grad_norm,
    weighting_scheme,
    logit_mean,
    logit_std,
    mode_scale,
):
    """Run one optimizer step of flow-matching training.

    Accumulates gradients over `gradient_accumulation_steps` micro-batches
    drawn from `loader`, then clips and applies them once.

    Returns:
        (total_loss, grad_norm): the loss summed over micro-batches (each term
        already divided by `gradient_accumulation_steps` and averaged across
        ranks — logging only), and the pre-clip gradient norm as a float.
    """
    total_loss = 0.0
    optimizer.zero_grad()
    for _ in range(gradient_accumulation_steps):
        (
            latents,
            encoder_hidden_states,
            latents_attention_mask,
            encoder_attention_mask,
        ) = next(loader)
        latents = normalize_dit_input(model_type, latents)
        batch_size = latents.shape[0]
        noise = torch.randn_like(latents)
        # Sample timestep density u in [0, 1] according to the weighting scheme.
        u = compute_density_for_timestep_sampling(
            weighting_scheme=weighting_scheme,
            batch_size=batch_size,
            generator=noise_random_generator,
            logit_mean=logit_mean,
            logit_std=logit_std,
            mode_scale=mode_scale,
        )
        indices = (u * noise_scheduler.config.num_train_timesteps).long()
        timesteps = noise_scheduler.timesteps[indices].to(device=latents.device)
        if sp_size > 1:
            # Make sure that the timesteps are the same across all sp processes.
            broadcast(timesteps)
        sigmas = get_sigmas(
            noise_scheduler,
            latents.device,
            timesteps,
            n_dim=latents.ndim,
            dtype=latents.dtype,
        )
        # Linear interpolation between data and noise at level sigma.
        noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise
        with torch.autocast("cuda", dtype=torch.bfloat16):
            input_kwargs = {
                "hidden_states": noisy_model_input,
                "encoder_hidden_states": encoder_hidden_states,
                "timestep": timesteps,
                "encoder_attention_mask": encoder_attention_mask,  # B, L
                "return_dict": False,
            }
            if 'hunyuan' in model_type:
                # Extra guidance conditioning expected by hunyuan variants; fixed at 1000.0 here.
                input_kwargs["guidance"] = torch.tensor([1000.0], device=noisy_model_input.device, dtype=torch.bfloat16)
            model_pred = transformer(**input_kwargs)[0]
        if precondition_outputs:
            # Convert the prediction back to data space so the target is the clean latent.
            model_pred = noisy_model_input - model_pred * sigmas
        if precondition_outputs:
            target = latents
        else:
            target = noise - latents
        # Pre-divide so the accumulated gradient equals the mean over micro-batches.
        loss = (torch.mean((model_pred.float() - target.float())**2) / gradient_accumulation_steps)
        loss.backward()
        # All-reduce a detached copy purely for logging; gradients sync via FSDP.
        avg_loss = loss.detach().clone()
        dist.all_reduce(avg_loss, op=dist.ReduceOp.AVG)
        total_loss += avg_loss.item()
    grad_norm = transformer.clip_grad_norm_(max_grad_norm)
    optimizer.step()
    lr_scheduler.step()
    return total_loss, grad_norm.item()
def main(args):
    """Finetune a video diffusion transformer with FSDP (optionally LoRA + sequence parallel).

    Expects torchrun-style environment variables (LOCAL_RANK, RANK, WORLD_SIZE);
    initializes the NCCL process group, builds model/optimizer/dataloader, runs
    the training loop, and periodically saves (LoRA) checkpoints.
    """
    torch.backends.cuda.matmul.allow_tf32 = True
    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    dist.init_process_group("nccl")
    torch.cuda.set_device(local_rank)
    device = torch.cuda.current_device()
    initialize_sequence_parallel_state(args.sp_size)
    # If passed along, set the training seed now; offset by rank so each
    # process draws different noise within a global batch.
    if args.seed is not None:
        # TODO: t within the same seq parallel group should be the same. Noise should be different.
        set_seed(args.seed + rank)
    noise_random_generator = None
    # Handle the output directory creation (rank 0 only).
    if rank <= 0 and args.output_dir is not None:
        os.makedirs(args.output_dir, exist_ok=True)
    # Create model; keep the master weights in fp32 unless bf16 was requested.
    main_print(f"--> loading model from {args.pretrained_model_name_or_path}")
    print("<"*50)
    transformer = load_transformer(
        args.model_type,
        args.dit_model_name_or_path,
        args.pretrained_model_name_or_path,
        torch.float32 if args.master_weight_type == "fp32" else torch.bfloat16,
    )
    print(">"*50)
    # Pipeline class used only for LoRA state-dict (de)serialization. Initialized
    # to None so the LoRA checkpoint paths below never hit a NameError when
    # use_lora is set with a model_type that has no pipeline mapping.
    pipe = None
    if args.use_lora:
        assert args.model_type != "hunyuan", "LoRA is only supported for huggingface model. Please use hunyuan_hf for lora finetuning"
        if args.model_type == "mochi":
            pipe = MochiPipeline
        elif args.model_type == "hunyuan_hf":
            pipe = HunyuanVideoPipeline
        # Freeze the base weights; only the injected LoRA adapters train.
        transformer.requires_grad_(False)
        transformer_lora_config = LoraConfig(
            r=args.lora_rank,
            lora_alpha=args.lora_alpha,
            init_lora_weights=True,
            target_modules=["to_k", "to_q", "to_v", "to_out.0"],
        )
        transformer.add_adapter(transformer_lora_config)
        if args.resume_from_lora_checkpoint:
            lora_state_dict = pipe.lora_state_dict(args.resume_from_lora_checkpoint)
            transformer_state_dict = {
                f'{k.replace("transformer.", "")}': v
                for k, v in lora_state_dict.items() if k.startswith("transformer.")
            }
            transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
            incompatible_keys = set_peft_model_state_dict(transformer, transformer_state_dict, adapter_name="default")
            if incompatible_keys is not None:
                # check only for unexpected keys
                unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
                if unexpected_keys:
                    main_print(f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
                               f" {unexpected_keys}. ")
    main_print(
        f" Total training parameters = {sum(p.numel() for p in transformer.parameters() if p.requires_grad) / 1e6} M")
    main_print(f"--> Initializing FSDP with sharding strategy: {args.fsdp_sharding_startegy}")
    fsdp_kwargs, no_split_modules = get_dit_fsdp_kwargs(
        transformer,
        args.fsdp_sharding_startegy,
        args.use_lora,
        args.use_cpu_offload,
        args.master_weight_type,
    )
    if args.use_lora:
        # Record LoRA hyper-parameters on the config so they survive checkpointing.
        transformer.config.lora_rank = args.lora_rank
        transformer.config.lora_alpha = args.lora_alpha
        transformer.config.lora_target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
        transformer._no_split_modules = [no_split_module.__name__ for no_split_module in no_split_modules]
        fsdp_kwargs["auto_wrap_policy"] = fsdp_kwargs["auto_wrap_policy"](transformer)
        # Mixed trainable/frozen params require use_orig_params under FSDP.
        fsdp_kwargs['use_orig_params'] = True
    transformer = FSDP(
        transformer,
        **fsdp_kwargs,
    )
    # transformer = torch.compile(transformer)
    main_print("--> model loaded")
    if args.gradient_checkpointing:
        apply_fsdp_checkpointing(transformer, no_split_modules, args.selective_checkpointing)
    main_print(transformer)
    # Set model as trainable.
    transformer.train()
    noise_scheduler = FlowMatchEulerDiscreteScheduler()
    params_to_optimize = transformer.parameters()
    params_to_optimize = list(filter(lambda p: p.requires_grad, params_to_optimize))
    optimizer = torch.optim.AdamW(
        params_to_optimize,
        lr=args.learning_rate,
        betas=(0.9, 0.999),
        weight_decay=args.weight_decay,
        eps=1e-8,
    )
    init_steps = 0
    if args.resume_from_lora_checkpoint:
        transformer, optimizer, init_steps = resume_lora_optimizer(transformer, args.resume_from_lora_checkpoint,
                                                                   optimizer)
    main_print(f"optimizer: {optimizer}")
    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps,
        num_training_steps=args.max_train_steps,
        num_cycles=args.lr_num_cycles,
        power=args.lr_power,
        last_epoch=init_steps - 1,
    )
    train_dataset = LatentDataset(args.data_json_path, args.num_latent_t, args.cfg)
    # Group by length/resolution when requested; otherwise plain distributed sharding.
    sampler = (LengthGroupedSampler(
        args.train_batch_size,
        rank=rank,
        world_size=world_size,
        lengths=train_dataset.lengths,
        group_frame=args.group_frame,
        group_resolution=args.group_resolution,
    ) if (args.group_frame or args.group_resolution) else DistributedSampler(
        train_dataset, rank=rank, num_replicas=world_size, shuffle=False))
    train_dataloader = DataLoader(
        train_dataset,
        sampler=sampler,
        collate_fn=latent_collate_function,
        pin_memory=True,
        batch_size=args.train_batch_size,
        num_workers=args.dataloader_num_workers,
        drop_last=True,
    )
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / args.gradient_accumulation_steps * args.sp_size / args.train_sp_batch_size)
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
    # if rank <= 0:
    #     project = args.tracker_project_name or "fastvideo"
    #     wandb.init(project=project, config=args)
    # Train!
    total_batch_size = (world_size * args.gradient_accumulation_steps / args.sp_size * args.train_sp_batch_size)
    main_print("***** Running training *****")
    main_print(f"  Num examples = {len(train_dataset)}")
    main_print(f"  Dataloader size = {len(train_dataloader)}")
    main_print(f"  Num Epochs = {args.num_train_epochs}")
    main_print(f"  Resume training from step {init_steps}")
    main_print(f"  Instantaneous batch size per device = {args.train_batch_size}")
    main_print(f"  Total train batch size (w. data & sequence parallel, accumulation) = {total_batch_size}")
    main_print(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    main_print(f"  Total optimization steps = {args.max_train_steps}")
    main_print(
        f"  Total training parameters per FSDP shard = {sum(p.numel() for p in transformer.parameters() if p.requires_grad) / 1e9} B"
    )
    # print dtype
    main_print(f"  Master weight dtype: {next(transformer.parameters()).dtype}")
    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        # BUG FIX: this was `assert NotImplementedError(...)`, which always
        # passes (an exception instance is truthy) and silently ignored the
        # flag. Raise so the unsupported option fails loudly.
        raise NotImplementedError("resume_from_checkpoint is not supported now.")
        # TODO
    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=init_steps,
        desc="Steps",
        # Only show the progress bar once on each machine.
        disable=local_rank > 0,
    )
    loader = sp_parallel_dataloader_wrapper(
        train_dataloader,
        device,
        args.train_batch_size,
        args.sp_size,
        args.train_sp_batch_size,
    )
    step_times = deque(maxlen=100)
    # Skip batches already consumed by a resumed run.
    for i in range(init_steps):
        next(loader)
    for step in range(init_steps + 1, args.max_train_steps + 1):
        start_time = time.time()
        #if False:
        if step == 100:
            # One-shot profiling pass at step 100.
            from torch.profiler import profile, record_function, ProfilerActivity
            with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU,torch.profiler.ProfilerActivity.CUDA,],
                                        record_shapes=True, profile_memory=False, with_stack=False) as prof:
                loss, grad_norm = train_one_step(
                    transformer,
                    args.model_type,
                    optimizer,
                    lr_scheduler,
                    loader,
                    noise_scheduler,
                    noise_random_generator,
                    args.gradient_accumulation_steps,
                    args.sp_size,
                    args.precondition_outputs,
                    args.max_grad_norm,
                    args.weighting_scheme,
                    args.logit_mean,
                    args.logit_std,
                    args.mode_scale,
                )
            print(prof.key_averages().table(sort_by="self_cuda_time_total"))
            # NOTE(review): hard-coded, user-specific output path — parameterize before reuse.
            prof.export_chrome_trace(f"/public/home/wuxk/code/modelzoo/FastVideo-main/scripts/finetune/prof/bw_fv_trace_ge_{dist.get_rank()}.json")
            # torch.cuda.synchronize()
        else:
            loss, grad_norm = train_one_step(
                transformer,
                args.model_type,
                optimizer,
                lr_scheduler,
                loader,
                noise_scheduler,
                noise_random_generator,
                args.gradient_accumulation_steps,
                args.sp_size,
                args.precondition_outputs,
                args.max_grad_norm,
                args.weighting_scheme,
                args.logit_mean,
                args.logit_std,
                args.mode_scale,
            )
        step_time = time.time() - start_time
        step_times.append(step_time)
        avg_step_time = sum(step_times) / len(step_times)
        progress_bar.set_postfix({
            "loss": f"{loss:.4f}",
            "step_time": f"{step_time:.2f}s",
            "grad_norm": grad_norm,
        })
        progress_bar.update(1)
        # if rank <= 0:
        #     wandb.log(
        #         {
        #             "train_loss": loss,
        #             "learning_rate": lr_scheduler.get_last_lr()[0],
        #             "step_time": step_time,
        #             "avg_step_time": avg_step_time,
        #             "grad_norm": grad_norm,
        #         },
        #         step=step,
        #     )
        main_print(f"zll step_time: {step_time:.2f}s avg_step_time: {sum(step_times) / len(step_times)}")
        if step % args.checkpointing_steps == 0:
            if args.use_lora:
                # Save LoRA weights
                save_lora_checkpoint(transformer, optimizer, rank, args.output_dir, step, pipe)
            else:
                # Full-model checkpoint.
                save_checkpoint(transformer, rank, args.output_dir, step)
            dist.barrier()
        if args.log_validation and step % args.validation_steps == 0:
            log_validation(args, transformer, device, torch.bfloat16, step, shift=args.shift)
    # Final checkpoint after the last step.
    if args.use_lora:
        save_lora_checkpoint(transformer, optimizer, rank, args.output_dir, args.max_train_steps, pipe)
    else:
        save_checkpoint(transformer, rank, args.output_dir, args.max_train_steps)
    if get_sequence_parallel_state():
        destroy_sequence_parallel_group()
if __name__ == "__main__":
    # CLI for the finetuning entry point; all flags are forwarded to main(args).
    parser = argparse.ArgumentParser()
    # NOTE(review): "Currentlt" typo in help text below; also
    # --fsdp_sharding_startegy is misspelled but is the public flag name used
    # by launch scripts, so neither string is changed here.
    parser.add_argument("--model_type",
                        type=str,
                        default="mochi",
                        help="The type of model to train. Currentlt support [mochi, hunyuan_hf, hunyuan]")
    # dataset & dataloader
    parser.add_argument("--data_json_path", type=str, required=True)
    parser.add_argument("--num_height", type=int, default=480)
    parser.add_argument("--num_width", type=int, default=848)
    parser.add_argument("--num_frames", type=int, default=163)
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=10,
        help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
    )
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=16,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument("--num_latent_t", type=int, default=28, help="Number of latent timesteps.")
    parser.add_argument("--group_frame", action="store_true")  # TODO
    parser.add_argument("--group_resolution", action="store_true")  # TODO
    # text encoder & vae & diffusion model
    parser.add_argument("--pretrained_model_name_or_path", type=str)
    parser.add_argument("--dit_model_name_or_path", type=str, default=None)
    parser.add_argument("--cache_dir", type=str, default="./cache_dir")
    # diffusion setting
    parser.add_argument("--ema_decay", type=float, default=0.999)
    parser.add_argument("--ema_start_step", type=int, default=0)
    parser.add_argument("--cfg", type=float, default=0.1)
    parser.add_argument(
        "--precondition_outputs",
        action="store_true",
        help="Whether to precondition the outputs of the model.",
    )
    # validation & logs
    parser.add_argument("--validation_prompt_dir", type=str)
    parser.add_argument("--uncond_prompt_dir", type=str)
    parser.add_argument(
        "--validation_sampling_steps",
        type=str,
        default="64",
        help="use ',' to split multi sampling steps",
    )
    parser.add_argument(
        "--validation_guidance_scale",
        type=str,
        default="4.5",
        help="use ',' to split multi scale",
    )
    parser.add_argument("--validation_steps", type=int, default=50)
    parser.add_argument("--log_validation", action="store_true")
    parser.add_argument("--tracker_project_name", type=str, default=None)
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help=("Max number of checkpoints to store."),
    )
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help=("Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
              " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
              " training using `--resume_from_checkpoint`."),
    )
    parser.add_argument("--shift", type=float, default=1.0, help=("Set shift to 7 for hunyuan model."))
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=("Whether training should be resumed from a previous checkpoint. Use a path saved by"
              ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'),
    )
    parser.add_argument(
        "--resume_from_lora_checkpoint",
        type=str,
        default=None,
        help=("Whether training should be resumed from a previous lora checkpoint. Use a path saved by"
              ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'),
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
              " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."),
    )
    # optimizer & scheduler & Training
    parser.add_argument("--num_train_epochs", type=int, default=100)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_warmup_steps",
        type=int,
        default=10,
        help="Number of steps for the warmup in the lr scheduler.",
    )
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument("--selective_checkpointing", type=float, default=1.0)
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=("Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
              " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."),
    )
    parser.add_argument(
        "--use_cpu_offload",
        action="store_true",
        help="Whether to use CPU offload for param & gradient & optimizer states.",
    )
    parser.add_argument("--sp_size", type=int, default=1, help="For sequence parallel")
    parser.add_argument(
        "--train_sp_batch_size",
        type=int,
        default=1,
        help="Batch size for sequence parallel training",
    )
    parser.add_argument(
        "--use_lora",
        action="store_true",
        default=False,
        help="Whether to use LoRA for finetuning.",
    )
    parser.add_argument("--lora_alpha", type=int, default=256, help="Alpha parameter for LoRA.")
    parser.add_argument("--lora_rank", type=int, default=128, help="LoRA rank parameter. ")
    parser.add_argument("--fsdp_sharding_startegy", default="full")
    # timestep-density weighting scheme (see compute_density_for_timestep_sampling)
    parser.add_argument(
        "--weighting_scheme",
        type=str,
        default="uniform",
        choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "uniform"],
    )
    parser.add_argument(
        "--logit_mean",
        type=float,
        default=0.0,
        help="mean to use when using the `'logit_normal'` weighting scheme.",
    )
    parser.add_argument(
        "--logit_std",
        type=float,
        default=1.0,
        help="std to use when using the `'logit_normal'` weighting scheme.",
    )
    parser.add_argument(
        "--mode_scale",
        type=float,
        default=1.29,
        help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
    )
    # lr_scheduler
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=('The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
              ' "constant", "constant_with_warmup"]'),
    )
    parser.add_argument(
        "--lr_num_cycles",
        type=int,
        default=1,
        help="Number of cycles in the learning rate scheduler.",
    )
    parser.add_argument(
        "--lr_power",
        type=float,
        default=1.0,
        help="Power factor of the polynomial scheduler.",
    )
    parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay to apply.")
    parser.add_argument(
        "--master_weight_type",
        type=str,
        default="fp32",
        help="Weight type to use - fp32 or bf16.",
    )
    args = parser.parse_args()
    main(args)
# !/bin/python3
# isort: skip_file
import argparse
import math
import os
import time
from collections import deque
import torch
import torch.distributed as dist
import wandb
from accelerate.utils import set_seed
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, convert_unet_state_dict_to_peft
from peft import LoraConfig, set_peft_model_state_dict
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from tqdm.auto import tqdm
from fastvideo.dataset.latent_datasets import (LatentDataset, latent_collate_function)
from fastvideo.models.mochi_hf.mochi_latents_utils import normalize_dit_input
from fastvideo.models.mochi_hf.pipeline_mochi import MochiPipeline
from fastvideo.models.hunyuan_hf.pipeline_hunyuan import HunyuanVideoPipeline
from fastvideo.utils.checkpoint import (resume_lora_optimizer, save_checkpoint, save_lora_checkpoint)
from fastvideo.utils.communications import (broadcast, sp_parallel_dataloader_wrapper)
from fastvideo.utils.dataset_utils import LengthGroupedSampler
from fastvideo.utils.fsdp_util import (apply_fsdp_checkpointing, get_dit_fsdp_kwargs)
from fastvideo.utils.load import load_transformer
from fastvideo.utils.logging_ import main_print
from fastvideo.utils.parallel_states import (destroy_sequence_parallel_group, get_sequence_parallel_state,
initialize_sequence_parallel_state)
from fastvideo.utils.validation import log_validation
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.31.0")
def compute_density_for_timestep_sampling(
    weighting_scheme: str,
    batch_size: int,
    generator,
    logit_mean: float = None,
    logit_std: float = None,
    mode_scale: float = None,
):
    """
    Draw `batch_size` timestep densities in [0, 1] for SD3-style training.

    Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
    SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
    """
    if weighting_scheme == "logit_normal":
        # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
        normal_sample = torch.normal(
            mean=logit_mean,
            std=logit_std,
            size=(batch_size, ),
            device="cpu",
            generator=generator,
        )
        u = torch.nn.functional.sigmoid(normal_sample)
    else:
        u = torch.rand(size=(batch_size, ), device="cpu", generator=generator)
        if weighting_scheme == "mode":
            cos_sq = torch.cos(math.pi * u / 2)**2
            u = 1 - u - mode_scale * (cos_sq - 1 + u)
    return u
def get_sigmas(noise_scheduler, device, timesteps, n_dim=4, dtype=torch.float32):
    """Select the sigma matching each timestep and pad with trailing singleton dims.

    Timesteps are located by exact match against `noise_scheduler.timesteps`;
    the result broadcasts against an `n_dim`-dimensional latent tensor.
    """
    sigmas_on_device = noise_scheduler.sigmas.to(device=device, dtype=dtype)
    schedule_ts = noise_scheduler.timesteps.to(device)
    lookup = [(schedule_ts == t).nonzero().item() for t in timesteps.to(device)]
    sigma = sigmas_on_device[lookup].flatten()
    while sigma.dim() < n_dim:
        sigma = sigma[..., None]
    return sigma
def train_one_step(
    transformer,
    model_type,
    optimizer,
    lr_scheduler,
    loader,
    noise_scheduler,
    noise_random_generator,
    gradient_accumulation_steps,
    sp_size,
    precondition_outputs,
    max_grad_norm,
    weighting_scheme,
    logit_mean,
    logit_std,
    mode_scale,
):
    """Run one optimizer step of flow-matching training.

    Accumulates gradients over `gradient_accumulation_steps` micro-batches
    drawn from `loader`, then clips and applies them once.

    Returns:
        (total_loss, grad_norm): the loss summed over micro-batches (each term
        already divided by `gradient_accumulation_steps` and averaged across
        ranks — logging only), and the pre-clip gradient norm as a float.
    """
    total_loss = 0.0
    optimizer.zero_grad()
    for _ in range(gradient_accumulation_steps):
        (
            latents,
            encoder_hidden_states,
            latents_attention_mask,
            encoder_attention_mask,
        ) = next(loader)
        latents = normalize_dit_input(model_type, latents)
        batch_size = latents.shape[0]
        noise = torch.randn_like(latents)
        # Sample timestep density u in [0, 1] according to the weighting scheme.
        u = compute_density_for_timestep_sampling(
            weighting_scheme=weighting_scheme,
            batch_size=batch_size,
            generator=noise_random_generator,
            logit_mean=logit_mean,
            logit_std=logit_std,
            mode_scale=mode_scale,
        )
        indices = (u * noise_scheduler.config.num_train_timesteps).long()
        timesteps = noise_scheduler.timesteps[indices].to(device=latents.device)
        if sp_size > 1:
            # Make sure that the timesteps are the same across all sp processes.
            broadcast(timesteps)
        sigmas = get_sigmas(
            noise_scheduler,
            latents.device,
            timesteps,
            n_dim=latents.ndim,
            dtype=latents.dtype,
        )
        # Linear interpolation between data and noise at level sigma.
        noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise
        with torch.autocast("cuda", dtype=torch.bfloat16):
            input_kwargs = {
                "hidden_states": noisy_model_input,
                "encoder_hidden_states": encoder_hidden_states,
                "timestep": timesteps,
                "encoder_attention_mask": encoder_attention_mask,  # B, L
                "return_dict": False,
            }
            if 'hunyuan' in model_type:
                # Extra guidance conditioning expected by hunyuan variants; fixed at 1000.0 here.
                input_kwargs["guidance"] = torch.tensor([1000.0], device=noisy_model_input.device, dtype=torch.bfloat16)
            model_pred = transformer(**input_kwargs)[0]
        if precondition_outputs:
            # Convert the prediction back to data space so the target is the clean latent.
            model_pred = noisy_model_input - model_pred * sigmas
        if precondition_outputs:
            target = latents
        else:
            target = noise - latents
        # Pre-divide so the accumulated gradient equals the mean over micro-batches.
        loss = (torch.mean((model_pred.float() - target.float())**2) / gradient_accumulation_steps)
        loss.backward()
        # All-reduce a detached copy purely for logging; gradients sync via FSDP.
        avg_loss = loss.detach().clone()
        dist.all_reduce(avg_loss, op=dist.ReduceOp.AVG)
        total_loss += avg_loss.item()
    grad_norm = transformer.clip_grad_norm_(max_grad_norm)
    optimizer.step()
    lr_scheduler.step()
    return total_loss, grad_norm.item()
def main(args):
    """Distributed training entry point for the DiT flow-matching trainer.

    Sets up NCCL process groups and sequence parallelism, loads the
    transformer (optionally wrapping it with LoRA adapters and FSDP),
    builds the latent dataloader, then runs the training loop with
    periodic checkpointing and optional validation.

    Expects torchrun-style environment variables: LOCAL_RANK, RANK,
    WORLD_SIZE. Requires CUDA + NCCL.
    """
    torch.backends.cuda.matmul.allow_tf32 = True
    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    dist.init_process_group("nccl")
    torch.cuda.set_device(local_rank)
    device = torch.cuda.current_device()
    initialize_sequence_parallel_state(args.sp_size)
    # If passed along, set the training seed now, offset by rank so each
    # process draws different noise within a global batch.
    if args.seed is not None:
        # TODO: t within the same seq parallel group should be the same. Noise should be different.
        set_seed(args.seed + rank)
    noise_random_generator = None
    # Handle the output directory creation (done once, on rank 0).
    if rank <= 0 and args.output_dir is not None:
        os.makedirs(args.output_dir, exist_ok=True)
    # Create model; keep the master weights in float32 unless bf16 requested.
    main_print(f"--> loading model from {args.pretrained_model_name_or_path}")
    print("<"*50)
    transformer = load_transformer(
        args.model_type,
        args.dit_model_name_or_path,
        args.pretrained_model_name_or_path,
        torch.float32 if args.master_weight_type == "fp32" else torch.bfloat16,
    )
    print(">"*50)
    if args.use_lora:
        assert args.model_type != "hunyuan", "LoRA is only supported for huggingface model. Please use hunyuan_hf for lora finetuning"
        # NOTE(review): `pipe` is only bound for mochi / hunyuan_hf; any other
        # model type with --use_lora would hit a NameError at checkpoint time.
        if args.model_type == "mochi":
            pipe = MochiPipeline
        elif args.model_type == "hunyuan_hf":
            pipe = HunyuanVideoPipeline
        # Freeze the base weights; only LoRA adapter params stay trainable.
        transformer.requires_grad_(False)
        transformer_lora_config = LoraConfig(
            r=args.lora_rank,
            lora_alpha=args.lora_alpha,
            init_lora_weights=True,
            target_modules=["to_k", "to_q", "to_v", "to_out.0"],
        )
        transformer.add_adapter(transformer_lora_config)
    if args.resume_from_lora_checkpoint:
        lora_state_dict = pipe.lora_state_dict(args.resume_from_lora_checkpoint)
        transformer_state_dict = {
            f'{k.replace("transformer.", "")}': v
            for k, v in lora_state_dict.items() if k.startswith("transformer.")
        }
        transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
        incompatible_keys = set_peft_model_state_dict(transformer, transformer_state_dict, adapter_name="default")
        if incompatible_keys is not None:
            # check only for unexpected keys
            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
            if unexpected_keys:
                main_print(f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
                           f" {unexpected_keys}. ")
    main_print(
        f"  Total training parameters = {sum(p.numel() for p in transformer.parameters() if p.requires_grad) / 1e6} M")
    main_print(f"--> Initializing FSDP with sharding strategy: {args.fsdp_sharding_startegy}")
    fsdp_kwargs, no_split_modules = get_dit_fsdp_kwargs(
        transformer,
        args.fsdp_sharding_startegy,
        args.use_lora,
        args.use_cpu_offload,
        args.master_weight_type,
    )
    if args.use_lora:
        transformer.config.lora_rank = args.lora_rank
        transformer.config.lora_alpha = args.lora_alpha
        transformer.config.lora_target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
        transformer._no_split_modules = [no_split_module.__name__ for no_split_module in no_split_modules]
        fsdp_kwargs["auto_wrap_policy"] = fsdp_kwargs["auto_wrap_policy"](transformer)
    transformer = FSDP(
        transformer,
        **fsdp_kwargs,
    )
    main_print("--> model loaded")
    if args.gradient_checkpointing:
        apply_fsdp_checkpointing(transformer, no_split_modules, args.selective_checkpointing)
    # Set model as trainable.
    transformer.train()
    noise_scheduler = FlowMatchEulerDiscreteScheduler()
    params_to_optimize = transformer.parameters()
    params_to_optimize = list(filter(lambda p: p.requires_grad, params_to_optimize))
    optimizer = torch.optim.AdamW(
        params_to_optimize,
        lr=args.learning_rate,
        betas=(0.9, 0.999),
        weight_decay=args.weight_decay,
        eps=1e-8,
    )
    init_steps = 0
    if args.resume_from_lora_checkpoint:
        transformer, optimizer, init_steps = resume_lora_optimizer(transformer, args.resume_from_lora_checkpoint,
                                                                   optimizer)
    main_print(f"optimizer: {optimizer}")
    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps,
        num_training_steps=args.max_train_steps,
        num_cycles=args.lr_num_cycles,
        power=args.lr_power,
        last_epoch=init_steps - 1,
    )
    train_dataset = LatentDataset(args.data_json_path, args.num_latent_t, args.cfg)
    sampler = (LengthGroupedSampler(
        args.train_batch_size,
        rank=rank,
        world_size=world_size,
        lengths=train_dataset.lengths,
        group_frame=args.group_frame,
        group_resolution=args.group_resolution,
    ) if (args.group_frame or args.group_resolution) else DistributedSampler(
        train_dataset, rank=rank, num_replicas=world_size, shuffle=False))
    train_dataloader = DataLoader(
        train_dataset,
        sampler=sampler,
        collate_fn=latent_collate_function,
        pin_memory=True,
        batch_size=args.train_batch_size,
        num_workers=args.dataloader_num_workers,
        drop_last=True,
    )
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / args.gradient_accumulation_steps * args.sp_size / args.train_sp_batch_size)
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
    # if rank <= 0:
    #     project = args.tracker_project_name or "fastvideo"
    #     wandb.init(project=project, config=args)
    # Train!
    total_batch_size = (world_size * args.gradient_accumulation_steps / args.sp_size * args.train_sp_batch_size)
    main_print("***** Running training *****")
    main_print(f"  Num examples = {len(train_dataset)}")
    main_print(f"  Dataloader size = {len(train_dataloader)}")
    main_print(f"  Num Epochs = {args.num_train_epochs}")
    main_print(f"  Resume training from step {init_steps}")
    main_print(f"  Instantaneous batch size per device = {args.train_batch_size}")
    main_print(f"  Total train batch size (w. data & sequence parallel, accumulation) = {total_batch_size}")
    main_print(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    main_print(f"  Total optimization steps = {args.max_train_steps}")
    main_print(
        f"  Total training parameters per FSDP shard = {sum(p.numel() for p in transformer.parameters() if p.requires_grad) / 1e9} B"
    )
    # print dtype
    main_print(f"  Master weight dtype: {next(transformer.parameters()).dtype}")
    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        # Fix: the original `assert NotImplementedError(...)` always passed
        # (an exception instance is truthy) and silently ignored the flag.
        raise NotImplementedError("resume_from_checkpoint is not supported now.")
    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=init_steps,
        desc="Steps",
        # Only show the progress bar once on each machine.
        disable=local_rank > 0,
    )
    loader = sp_parallel_dataloader_wrapper(
        train_dataloader,
        device,
        args.train_batch_size,
        args.sp_size,
        args.train_sp_batch_size,
    )
    step_times = deque(maxlen=100)
    # Fast-forward the dataloader past already-consumed steps when resuming.
    for _ in range(init_steps):
        next(loader)
    for step in range(init_steps + 1, args.max_train_steps + 1):
        start_time = time.time()
        loss, grad_norm = train_one_step(
            transformer,
            args.model_type,
            optimizer,
            lr_scheduler,
            loader,
            noise_scheduler,
            noise_random_generator,
            args.gradient_accumulation_steps,
            args.sp_size,
            args.precondition_outputs,
            args.max_grad_norm,
            args.weighting_scheme,
            args.logit_mean,
            args.logit_std,
            args.mode_scale,
        )
        step_time = time.time() - start_time
        step_times.append(step_time)
        avg_step_time = sum(step_times) / len(step_times)
        progress_bar.set_postfix({
            "loss": f"{loss:.4f}",
            "step_time": f"{step_time:.2f}s",
            "grad_norm": grad_norm,
        })
        progress_bar.update(1)
        # if rank <= 0:
        #     wandb.log(
        #         {
        #             "train_loss": loss,
        #             "learning_rate": lr_scheduler.get_last_lr()[0],
        #             "step_time": step_time,
        #             "avg_step_time": avg_step_time,
        #             "grad_norm": grad_norm,
        #         },
        #         step=step,
        #     )
        if step % args.checkpointing_steps == 0:
            if args.use_lora:
                # Save LoRA weights
                save_lora_checkpoint(transformer, optimizer, rank, args.output_dir, step, pipe)
            else:
                # Save a full (sharded) model checkpoint.
                save_checkpoint(transformer, rank, args.output_dir, step)
            dist.barrier()
        if args.log_validation and step % args.validation_steps == 0:
            log_validation(args, transformer, device, torch.bfloat16, step, shift=args.shift)
    # Final save after the last optimization step.
    if args.use_lora:
        save_lora_checkpoint(transformer, optimizer, rank, args.output_dir, args.max_train_steps, pipe)
    else:
        save_checkpoint(transformer, rank, args.output_dir, args.max_train_steps)
    if get_sequence_parallel_state():
        destroy_sequence_parallel_group()
if __name__ == "__main__":
    # CLI for the DiT flow-matching trainer: parse arguments and launch main().
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_type",
                        type=str,
                        default="mochi",
                        help="The type of model to train. Currently support [mochi, hunyuan_hf, hunyuan]")
    # dataset & dataloader
    parser.add_argument("--data_json_path", type=str, required=True)
    parser.add_argument("--num_height", type=int, default=480)
    parser.add_argument("--num_width", type=int, default=848)
    parser.add_argument("--num_frames", type=int, default=163)
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=10,
        help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
    )
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=16,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument("--num_latent_t", type=int, default=28, help="Number of latent timesteps.")
    parser.add_argument("--group_frame", action="store_true")  # TODO
    parser.add_argument("--group_resolution", action="store_true")  # TODO
    # text encoder & vae & diffusion model
    parser.add_argument("--pretrained_model_name_or_path", type=str)
    parser.add_argument("--dit_model_name_or_path", type=str, default=None)
    parser.add_argument("--cache_dir", type=str, default="./cache_dir")
    # diffusion setting
    parser.add_argument("--ema_decay", type=float, default=0.999)
    parser.add_argument("--ema_start_step", type=int, default=0)
    parser.add_argument("--cfg", type=float, default=0.1)
    parser.add_argument(
        "--precondition_outputs",
        action="store_true",
        help="Whether to precondition the outputs of the model.",
    )
    # validation & logs
    parser.add_argument("--validation_prompt_dir", type=str)
    parser.add_argument("--uncond_prompt_dir", type=str)
    parser.add_argument(
        "--validation_sampling_steps",
        type=str,
        default="64",
        help="use ',' to split multi sampling steps",
    )
    parser.add_argument(
        "--validation_guidance_scale",
        type=str,
        default="4.5",
        help="use ',' to split multi scale",
    )
    parser.add_argument("--validation_steps", type=int, default=50)
    parser.add_argument("--log_validation", action="store_true")
    parser.add_argument("--tracker_project_name", type=str, default=None)
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help=("Max number of checkpoints to store."),
    )
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help=("Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
              " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
              " training using `--resume_from_checkpoint`."),
    )
    parser.add_argument("--shift", type=float, default=1.0, help=("Set shift to 7 for hunyuan model."))
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=("Whether training should be resumed from a previous checkpoint. Use a path saved by"
              ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'),
    )
    parser.add_argument(
        "--resume_from_lora_checkpoint",
        type=str,
        default=None,
        help=("Whether training should be resumed from a previous lora checkpoint. Use a path saved by"
              ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'),
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
              " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."),
    )
    # optimizer & scheduler & Training
    parser.add_argument("--num_train_epochs", type=int, default=100)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_warmup_steps",
        type=int,
        default=10,
        help="Number of steps for the warmup in the lr scheduler.",
    )
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument("--selective_checkpointing", type=float, default=1.0)
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=("Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
              " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."),
    )
    parser.add_argument(
        "--use_cpu_offload",
        action="store_true",
        help="Whether to use CPU offload for param & gradient & optimizer states.",
    )
    parser.add_argument("--sp_size", type=int, default=1, help="For sequence parallel")
    parser.add_argument(
        "--train_sp_batch_size",
        type=int,
        default=1,
        help="Batch size for sequence parallel training",
    )
    parser.add_argument(
        "--use_lora",
        action="store_true",
        default=False,
        help="Whether to use LoRA for finetuning.",
    )
    parser.add_argument("--lora_alpha", type=int, default=256, help="Alpha parameter for LoRA.")
    parser.add_argument("--lora_rank", type=int, default=128, help="LoRA rank parameter. ")
    # NOTE(review): flag name keeps the historical misspelling ("startegy")
    # because main() and launch scripts read args.fsdp_sharding_startegy.
    parser.add_argument("--fsdp_sharding_startegy", default="full")
    parser.add_argument(
        "--weighting_scheme",
        type=str,
        default="uniform",
        choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "uniform"],
    )
    parser.add_argument(
        "--logit_mean",
        type=float,
        default=0.0,
        help="mean to use when using the `'logit_normal'` weighting scheme.",
    )
    parser.add_argument(
        "--logit_std",
        type=float,
        default=1.0,
        help="std to use when using the `'logit_normal'` weighting scheme.",
    )
    parser.add_argument(
        "--mode_scale",
        type=float,
        default=1.29,
        help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
    )
    # lr_scheduler
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=('The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
              ' "constant", "constant_with_warmup"]'),
    )
    parser.add_argument(
        "--lr_num_cycles",
        type=int,
        default=1,
        help="Number of cycles in the learning rate scheduler.",
    )
    parser.add_argument(
        "--lr_power",
        type=float,
        default=1.0,
        help="Power factor of the polynomial scheduler.",
    )
    parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay to apply.")
    parser.add_argument(
        "--master_weight_type",
        type=str,
        default="fp32",
        help="Weight type to use - fp32 or bf16.",
    )
    args = parser.parse_args()
    main(args)
# !/bin/python3
# isort: skip_file
import argparse
import math
import os
import time
from collections import deque
import torch
import torch.distributed as dist
import wandb
from accelerate.utils import set_seed
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, convert_unet_state_dict_to_peft
from peft import LoraConfig, set_peft_model_state_dict
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from tqdm.auto import tqdm
from fastvideo.dataset.latent_datasets import (LatentDataset, latent_collate_function)
from fastvideo.models.mochi_hf.mochi_latents_utils import normalize_dit_input
from fastvideo.models.mochi_hf.pipeline_mochi import MochiPipeline
from fastvideo.models.hunyuan_hf.pipeline_hunyuan import HunyuanVideoPipeline
from fastvideo.utils.checkpoint import (resume_lora_optimizer, save_checkpoint, save_lora_checkpoint)
from fastvideo.utils.communications import (broadcast, sp_parallel_dataloader_wrapper)
from fastvideo.utils.dataset_utils import LengthGroupedSampler
from fastvideo.utils.fsdp_util import (apply_fsdp_checkpointing, get_dit_fsdp_kwargs)
from fastvideo.utils.load import load_transformer
from fastvideo.utils.logging_ import main_print
from fastvideo.utils.parallel_states import (destroy_sequence_parallel_group, get_sequence_parallel_state,
initialize_sequence_parallel_state)
from fastvideo.utils.validation import log_validation
import torch
# Allow dynamo/torch.compile graphs to capture scalar outputs (e.g. Tensor.item()).
torch._dynamo.config.capture_scalar_outputs = True
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.31.0")
def compute_density_for_timestep_sampling(
    weighting_scheme: str,
    batch_size: int,
    generator,
    logit_mean: float = None,
    logit_std: float = None,
    mode_scale: float = None,
):
    """
    Compute the density for sampling the timesteps when doing SD3 training.

    Draws `batch_size` samples in [0, 1] on the CPU using `generator`,
    according to `weighting_scheme` ("logit_normal", "mode", or uniform
    otherwise).

    Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
    SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
    """
    size = (batch_size, )
    if weighting_scheme == "logit_normal":
        # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$): sample a normal
        # variate and squash it into (0, 1) with a sigmoid.
        raw = torch.normal(
            mean=logit_mean,
            std=logit_std,
            size=size,
            device="cpu",
            generator=generator,
        )
        density = torch.sigmoid(raw)
    elif weighting_scheme == "mode":
        # Mode-concentrating transform of a uniform sample (SD3 paper).
        base = torch.rand(size=size, device="cpu", generator=generator)
        density = 1 - base - mode_scale * (torch.cos(math.pi * base / 2)**2 - 1 + base)
    else:
        # Uniform sampling is the default scheme.
        density = torch.rand(size=size, device="cpu", generator=generator)
    return density
def get_sigmas(noise_scheduler, device, timesteps, n_dim=4, dtype=torch.float32):
    """Look up the sigma for each timestep and shape it for broadcasting.

    Finds each entry of `timesteps` in the scheduler's timestep table, gathers
    the matching sigmas, and appends singleton dimensions until the result has
    `n_dim` dimensions so it broadcasts against a latent tensor of that rank.
    """
    all_sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
    schedule_timesteps = noise_scheduler.timesteps.to(device)
    timesteps = timesteps.to(device)
    # Position of each requested timestep within the schedule (exact match).
    positions = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
    sigma = all_sigmas[positions].flatten()
    # Pad with trailing singleton dims up to n_dim for broadcasting.
    sigma = sigma.reshape(sigma.shape + (1, ) * max(n_dim - sigma.dim(), 0))
    return sigma
def train_one_step(
    transformer,
    model_type,
    optimizer,
    lr_scheduler,
    loader,
    noise_scheduler,
    noise_random_generator,
    gradient_accumulation_steps,
    sp_size,
    precondition_outputs,
    max_grad_norm,
    weighting_scheme,
    logit_mean,
    logit_std,
    mode_scale,
):
    """Run one optimization step of flow-matching training.

    Accumulates gradients over `gradient_accumulation_steps` micro-batches
    pulled from `loader`, clips gradients, then advances the optimizer and
    LR scheduler once.

    Returns:
        (total_loss, grad_norm): the rank-averaged loss summed over the
        accumulation micro-batches, and the gradient norm returned by the
        FSDP-wrapped model's ``clip_grad_norm_``.
    """
    total_loss = 0.0
    optimizer.zero_grad()
    for _ in range(gradient_accumulation_steps):
        (
            latents,
            encoder_hidden_states,
            latents_attention_mask,  # NOTE(review): unused below — confirm intentionally ignored
            encoder_attention_mask,
        ) = next(loader)
        latents = normalize_dit_input(model_type, latents)
        batch_size = latents.shape[0]
        noise = torch.randn_like(latents)
        # Sample u in [0, 1] according to the configured timestep weighting scheme.
        u = compute_density_for_timestep_sampling(
            weighting_scheme=weighting_scheme,
            batch_size=batch_size,
            generator=noise_random_generator,
            logit_mean=logit_mean,
            logit_std=logit_std,
            mode_scale=mode_scale,
        )
        # Map the continuous samples onto discrete scheduler timesteps.
        indices = (u * noise_scheduler.config.num_train_timesteps).long()
        timesteps = noise_scheduler.timesteps[indices].to(device=latents.device)
        if sp_size > 1:
            # Make sure that the timesteps are the same across all sp processes.
            broadcast(timesteps)
        sigmas = get_sigmas(
            noise_scheduler,
            latents.device,
            timesteps,
            n_dim=latents.ndim,
            dtype=latents.dtype,
        )
        # Linear interpolation between clean latents and noise (rectified flow).
        noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise
        with torch.autocast("cuda", dtype=torch.bfloat16):
            input_kwargs = {
                "hidden_states": noisy_model_input,
                "encoder_hidden_states": encoder_hidden_states,
                "timestep": timesteps,
                "encoder_attention_mask": encoder_attention_mask,  # B, L
                "return_dict": False,
            }
            if 'hunyuan' in model_type:
                # Hunyuan variants take an extra distilled-guidance embedding input.
                input_kwargs["guidance"] = torch.tensor([1000.0], device=noisy_model_input.device, dtype=torch.bfloat16)
            model_pred = transformer(**input_kwargs)[0]
        if precondition_outputs:
            # Reconstruct a sample prediction from the model's output.
            model_pred = noisy_model_input - model_pred * sigmas
        if precondition_outputs:
            # Preconditioned: regress the clean latents directly.
            target = latents
        else:
            # Otherwise regress the flow velocity (noise - data).
            target = noise - latents
        # Divide by accumulation steps so summed gradients match a full batch.
        loss = (torch.mean((model_pred.float() - target.float())**2) / gradient_accumulation_steps)
        loss.backward()
        # Average the detached loss across ranks for logging only.
        avg_loss = loss.detach().clone()
        dist.all_reduce(avg_loss, op=dist.ReduceOp.AVG)
        total_loss += avg_loss.item()
    grad_norm = transformer.clip_grad_norm_(max_grad_norm)
    optimizer.step()
    lr_scheduler.step()
    return total_loss, grad_norm.item()
def main(args):
    """Distributed training entry point for the DiT flow-matching trainer.

    Sets up NCCL process groups and sequence parallelism, loads the
    transformer (optionally wrapping it with LoRA adapters and FSDP),
    builds the latent dataloader, then runs the training loop with
    periodic checkpointing and optional validation.

    Expects torchrun-style environment variables: LOCAL_RANK, RANK,
    WORLD_SIZE. Requires CUDA + NCCL.
    """
    torch.backends.cuda.matmul.allow_tf32 = True
    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    dist.init_process_group("nccl")
    torch.cuda.set_device(local_rank)
    device = torch.cuda.current_device()
    initialize_sequence_parallel_state(args.sp_size)
    # If passed along, set the training seed now, offset by rank so each
    # process draws different noise within a global batch.
    if args.seed is not None:
        # TODO: t within the same seq parallel group should be the same. Noise should be different.
        set_seed(args.seed + rank)
    noise_random_generator = None
    # Handle the output directory creation (done once, on rank 0).
    if rank <= 0 and args.output_dir is not None:
        os.makedirs(args.output_dir, exist_ok=True)
    # Create model; keep the master weights in float32 unless bf16 requested.
    main_print(f"--> loading model from {args.pretrained_model_name_or_path}")
    print("<"*50)
    transformer = load_transformer(
        args.model_type,
        args.dit_model_name_or_path,
        args.pretrained_model_name_or_path,
        torch.float32 if args.master_weight_type == "fp32" else torch.bfloat16,
    )
    print(">"*50)
    if args.use_lora:
        assert args.model_type != "hunyuan", "LoRA is only supported for huggingface model. Please use hunyuan_hf for lora finetuning"
        # NOTE(review): `pipe` is only bound for mochi / hunyuan_hf; any other
        # model type with --use_lora would hit a NameError at checkpoint time.
        if args.model_type == "mochi":
            pipe = MochiPipeline
        elif args.model_type == "hunyuan_hf":
            pipe = HunyuanVideoPipeline
        # Freeze the base weights; only LoRA adapter params stay trainable.
        transformer.requires_grad_(False)
        transformer_lora_config = LoraConfig(
            r=args.lora_rank,
            lora_alpha=args.lora_alpha,
            init_lora_weights=True,
            target_modules=["to_k", "to_q", "to_v", "to_out.0"],
        )
        transformer.add_adapter(transformer_lora_config)
    if args.resume_from_lora_checkpoint:
        lora_state_dict = pipe.lora_state_dict(args.resume_from_lora_checkpoint)
        transformer_state_dict = {
            f'{k.replace("transformer.", "")}': v
            for k, v in lora_state_dict.items() if k.startswith("transformer.")
        }
        transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
        incompatible_keys = set_peft_model_state_dict(transformer, transformer_state_dict, adapter_name="default")
        if incompatible_keys is not None:
            # check only for unexpected keys
            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
            if unexpected_keys:
                main_print(f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
                           f" {unexpected_keys}. ")
    main_print(
        f"  Total training parameters = {sum(p.numel() for p in transformer.parameters() if p.requires_grad) / 1e6} M")
    main_print(f"--> Initializing FSDP with sharding strategy: {args.fsdp_sharding_startegy}")
    fsdp_kwargs, no_split_modules = get_dit_fsdp_kwargs(
        transformer,
        args.fsdp_sharding_startegy,
        args.use_lora,
        args.use_cpu_offload,
        args.master_weight_type,
    )
    if args.use_lora:
        transformer.config.lora_rank = args.lora_rank
        transformer.config.lora_alpha = args.lora_alpha
        transformer.config.lora_target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
        transformer._no_split_modules = [no_split_module.__name__ for no_split_module in no_split_modules]
        fsdp_kwargs["auto_wrap_policy"] = fsdp_kwargs["auto_wrap_policy"](transformer)
    # use_orig_params keeps the original parameter objects visible through FSDP.
    fsdp_kwargs['use_orig_params'] = True
    transformer = FSDP(
        transformer,
        **fsdp_kwargs,
    )
    main_print("--> model loaded")
    if args.gradient_checkpointing:
        apply_fsdp_checkpointing(transformer, no_split_modules, args.selective_checkpointing)
    #transformer = torch.compile(transformer)
    # Set model as trainable.
    transformer.train()
    noise_scheduler = FlowMatchEulerDiscreteScheduler()
    params_to_optimize = transformer.parameters()
    params_to_optimize = list(filter(lambda p: p.requires_grad, params_to_optimize))
    optimizer = torch.optim.AdamW(
        params_to_optimize,
        lr=args.learning_rate,
        betas=(0.9, 0.999),
        weight_decay=args.weight_decay,
        eps=1e-8,
    )
    init_steps = 0
    if args.resume_from_lora_checkpoint:
        transformer, optimizer, init_steps = resume_lora_optimizer(transformer, args.resume_from_lora_checkpoint,
                                                                   optimizer)
    main_print(f"optimizer: {optimizer}")
    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps,
        num_training_steps=args.max_train_steps,
        num_cycles=args.lr_num_cycles,
        power=args.lr_power,
        last_epoch=init_steps - 1,
    )
    train_dataset = LatentDataset(args.data_json_path, args.num_latent_t, args.cfg)
    sampler = (LengthGroupedSampler(
        args.train_batch_size,
        rank=rank,
        world_size=world_size,
        lengths=train_dataset.lengths,
        group_frame=args.group_frame,
        group_resolution=args.group_resolution,
    ) if (args.group_frame or args.group_resolution) else DistributedSampler(
        train_dataset, rank=rank, num_replicas=world_size, shuffle=False))
    train_dataloader = DataLoader(
        train_dataset,
        sampler=sampler,
        collate_fn=latent_collate_function,
        pin_memory=True,
        batch_size=args.train_batch_size,
        num_workers=args.dataloader_num_workers,
        drop_last=True,
    )
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / args.gradient_accumulation_steps * args.sp_size / args.train_sp_batch_size)
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
    # if rank <= 0:
    #     project = args.tracker_project_name or "fastvideo"
    #     wandb.init(project=project, config=args)
    # Train!
    total_batch_size = (world_size * args.gradient_accumulation_steps / args.sp_size * args.train_sp_batch_size)
    main_print("***** Running training *****")
    main_print(f"  Num examples = {len(train_dataset)}")
    main_print(f"  Dataloader size = {len(train_dataloader)}")
    main_print(f"  Num Epochs = {args.num_train_epochs}")
    main_print(f"  Resume training from step {init_steps}")
    main_print(f"  Instantaneous batch size per device = {args.train_batch_size}")
    main_print(f"  Total train batch size (w. data & sequence parallel, accumulation) = {total_batch_size}")
    main_print(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    main_print(f"  Total optimization steps = {args.max_train_steps}")
    main_print(
        f"  Total training parameters per FSDP shard = {sum(p.numel() for p in transformer.parameters() if p.requires_grad) / 1e9} B"
    )
    # print dtype
    main_print(f"  Master weight dtype: {next(transformer.parameters()).dtype}")
    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        # Fix: the original `assert NotImplementedError(...)` always passed
        # (an exception instance is truthy) and silently ignored the flag.
        raise NotImplementedError("resume_from_checkpoint is not supported now.")
    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=init_steps,
        desc="Steps",
        # Only show the progress bar once on each machine.
        disable=local_rank > 0,
    )
    loader = sp_parallel_dataloader_wrapper(
        train_dataloader,
        device,
        args.train_batch_size,
        args.sp_size,
        args.train_sp_batch_size,
    )
    step_times = deque(maxlen=100)
    # Fast-forward the dataloader past already-consumed steps when resuming.
    for _ in range(init_steps):
        next(loader)
    for step in range(init_steps + 1, args.max_train_steps + 1):
        start_time = time.time()
        loss, grad_norm = train_one_step(
            transformer,
            args.model_type,
            optimizer,
            lr_scheduler,
            loader,
            noise_scheduler,
            noise_random_generator,
            args.gradient_accumulation_steps,
            args.sp_size,
            args.precondition_outputs,
            args.max_grad_norm,
            args.weighting_scheme,
            args.logit_mean,
            args.logit_std,
            args.mode_scale,
        )
        step_time = time.time() - start_time
        step_times.append(step_time)
        avg_step_time = sum(step_times) / len(step_times)
        progress_bar.set_postfix({
            "loss": f"{loss:.4f}",
            "step_time": f"{step_time:.2f}s",
            "grad_norm": grad_norm,
        })
        progress_bar.update(1)
        # if rank <= 0:
        #     wandb.log(
        #         {
        #             "train_loss": loss,
        #             "learning_rate": lr_scheduler.get_last_lr()[0],
        #             "step_time": step_time,
        #             "avg_step_time": avg_step_time,
        #             "grad_norm": grad_norm,
        #         },
        #         step=step,
        #     )
        # NOTE(review): debug timing print left in place to preserve output.
        main_print(f"zll step_time: {step_time:.2f}s avg_step_time: {sum(step_times) / len(step_times)}")
        if step % args.checkpointing_steps == 0:
            if args.use_lora:
                # Save LoRA weights
                save_lora_checkpoint(transformer, optimizer, rank, args.output_dir, step, pipe)
            else:
                # Save a full (sharded) model checkpoint.
                save_checkpoint(transformer, rank, args.output_dir, step)
            dist.barrier()
        if args.log_validation and step % args.validation_steps == 0:
            log_validation(args, transformer, device, torch.bfloat16, step, shift=args.shift)
    # Final save after the last optimization step.
    if args.use_lora:
        save_lora_checkpoint(transformer, optimizer, rank, args.output_dir, args.max_train_steps, pipe)
    else:
        save_checkpoint(transformer, rank, args.output_dir, args.max_train_steps)
    if get_sequence_parallel_state():
        destroy_sequence_parallel_group()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model_type",
type=str,
default="mochi",
help="The type of model to train. Currentlt support [mochi, hunyuan_hf, hunyuan]")
# dataset & dataloader
parser.add_argument("--data_json_path", type=str, required=True)
parser.add_argument("--num_height", type=int, default=480)
parser.add_argument("--num_width", type=int, default=848)
parser.add_argument("--num_frames", type=int, default=163)
parser.add_argument(
"--dataloader_num_workers",
type=int,
default=10,
help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
)
parser.add_argument(
"--train_batch_size",
type=int,
default=16,
help="Batch size (per device) for the training dataloader.",
)
parser.add_argument("--num_latent_t", type=int, default=28, help="Number of latent timesteps.")
parser.add_argument("--group_frame", action="store_true") # TODO
parser.add_argument("--group_resolution", action="store_true") # TODO
# text encoder & vae & diffusion model
parser.add_argument("--pretrained_model_name_or_path", type=str)
parser.add_argument("--dit_model_name_or_path", type=str, default=None)
parser.add_argument("--cache_dir", type=str, default="./cache_dir")
# diffusion setting
parser.add_argument("--ema_decay", type=float, default=0.999)
parser.add_argument("--ema_start_step", type=int, default=0)
parser.add_argument("--cfg", type=float, default=0.1)
parser.add_argument(
"--precondition_outputs",
action="store_true",
help="Whether to precondition the outputs of the model.",
)
# validation & logs
parser.add_argument("--validation_prompt_dir", type=str)
parser.add_argument("--uncond_prompt_dir", type=str)
parser.add_argument(
"--validation_sampling_steps",
type=str,
default="64",
help="use ',' to split multi sampling steps",
)
parser.add_argument(
"--validation_guidance_scale",
type=str,
default="4.5",
help="use ',' to split multi scale",
)
parser.add_argument("--validation_steps", type=int, default=50)
parser.add_argument("--log_validation", action="store_true")
parser.add_argument("--tracker_project_name", type=str, default=None)
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
parser.add_argument(
"--output_dir",
type=str,
default=None,
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument(
"--checkpoints_total_limit",
type=int,
default=None,
help=("Max number of checkpoints to store."),
)
parser.add_argument(
"--checkpointing_steps",
type=int,
default=500,
help=("Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
" training using `--resume_from_checkpoint`."),
)
parser.add_argument("--shift", type=float, default=1.0, help=("Set shift to 7 for hunyuan model."))
parser.add_argument(
"--resume_from_checkpoint",
type=str,
default=None,
help=("Whether training should be resumed from a previous checkpoint. Use a path saved by"
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'),
)
parser.add_argument(
"--resume_from_lora_checkpoint",
type=str,
default=None,
help=("Whether training should be resumed from a previous lora checkpoint. Use a path saved by"
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'),
)
parser.add_argument(
"--logging_dir",
type=str,
default="logs",
help=("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."),
)
# optimizer & scheduler & Training
parser.add_argument("--num_train_epochs", type=int, default=100)
parser.add_argument(
"--max_train_steps",
type=int,
default=None,
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
)
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument(
"--learning_rate",
type=float,
default=1e-4,
help="Initial learning rate (after the potential warmup period) to use.",
)
parser.add_argument(
"--scale_lr",
action="store_true",
default=False,
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
)
parser.add_argument(
"--lr_warmup_steps",
type=int,
default=10,
help="Number of steps for the warmup in the lr scheduler.",
)
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument(
"--gradient_checkpointing",
action="store_true",
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
)
parser.add_argument("--selective_checkpointing", type=float, default=1.0)
parser.add_argument(
"--allow_tf32",
action="store_true",
help=("Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"),
)
parser.add_argument(
"--mixed_precision",
type=str,
default=None,
choices=["no", "fp16", "bf16"],
help=(
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."),
)
parser.add_argument(
"--use_cpu_offload",
action="store_true",
help="Whether to use CPU offload for param & gradient & optimizer states.",
)
parser.add_argument("--sp_size", type=int, default=1, help="For sequence parallel")
parser.add_argument(
"--train_sp_batch_size",
type=int,
default=1,
help="Batch size for sequence parallel training",
)
parser.add_argument(
"--use_lora",
action="store_true",
default=False,
help="Whether to use LoRA for finetuning.",
)
parser.add_argument("--lora_alpha", type=int, default=256, help="Alpha parameter for LoRA.")
parser.add_argument("--lora_rank", type=int, default=128, help="LoRA rank parameter. ")
parser.add_argument("--fsdp_sharding_startegy", default="full")
parser.add_argument(
"--weighting_scheme",
type=str,
default="uniform",
choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "uniform"],
)
parser.add_argument(
"--logit_mean",
type=float,
default=0.0,
help="mean to use when using the `'logit_normal'` weighting scheme.",
)
parser.add_argument(
"--logit_std",
type=float,
default=1.0,
help="std to use when using the `'logit_normal'` weighting scheme.",
)
parser.add_argument(
"--mode_scale",
type=float,
default=1.29,
help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
)
# lr_scheduler
parser.add_argument(
"--lr_scheduler",
type=str,
default="constant",
help=('The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
' "constant", "constant_with_warmup"]'),
)
parser.add_argument(
"--lr_num_cycles",
type=int,
default=1,
help="Number of cycles in the learning rate scheduler.",
)
parser.add_argument(
"--lr_power",
type=float,
default=1.0,
help="Power factor of the polynomial scheduler.",
)
parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay to apply.")
parser.add_argument(
"--master_weight_type",
type=str,
default="fp32",
help="Weight type to use - fp32 or bf16.",
)
args = parser.parse_args()
main(args)
# !/bin/python3
# isort: skip_file
import argparse
import math
import os
import time
from collections import deque
import torch
import torch.distributed as dist
import wandb
from accelerate.utils import set_seed
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, convert_unet_state_dict_to_peft
from peft import LoraConfig, set_peft_model_state_dict
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from tqdm.auto import tqdm
from fastvideo.dataset.latent_datasets import (LatentDataset, latent_collate_function)
from fastvideo.models.mochi_hf.mochi_latents_utils import normalize_dit_input
from fastvideo.models.mochi_hf.pipeline_mochi import MochiPipeline
from fastvideo.models.hunyuan_hf.pipeline_hunyuan import HunyuanVideoPipeline
from fastvideo.utils.checkpoint import (resume_lora_optimizer, save_checkpoint, save_lora_checkpoint)
from fastvideo.utils.communications import (broadcast, sp_parallel_dataloader_wrapper)
from fastvideo.utils.dataset_utils import LengthGroupedSampler
from fastvideo.utils.fsdp_util import (apply_fsdp_checkpointing, get_dit_fsdp_kwargs)
from fastvideo.utils.load import load_transformer
from fastvideo.utils.logging_ import main_print
from fastvideo.utils.parallel_states import (destroy_sequence_parallel_group, get_sequence_parallel_state,
initialize_sequence_parallel_state)
from fastvideo.utils.validation import log_validation
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.31.0")
# Profiler imports used by trace_handler/main below (record_function is currently unused).
from torch.profiler import profile, record_function, ProfilerActivity
def trace_handler(p):
    """torch.profiler ``on_trace_ready`` callback.

    Prints the top-50 ops by self CUDA time and exports one Chrome trace
    file per rank and profiler step.

    The dump directory used to be a hard-coded absolute path that only
    exists on one specific machine; it is now overridable via the
    ``PROFILER_TRACE_DIR`` environment variable (same default), and the
    directory is created if missing so ``export_chrome_trace`` cannot fail
    on a fresh machine.
    """
    output = p.key_averages().table(sort_by="self_cuda_time_total", row_limit=50)
    print(output)
    rank = dist.get_rank()
    trace_dir = os.environ.get("PROFILER_TRACE_DIR",
                               "/public/hy-code/FastVideo-main/scripts/finetune/prof")
    os.makedirs(trace_dir, exist_ok=True)
    # One file per rank per profiler step, e.g. BW_amd0_16.json.
    p.export_chrome_trace(os.path.join(trace_dir, "BW_amd" + str(rank) + "_" + str(p.step_num) + ".json"))
def compute_density_for_timestep_sampling(
    weighting_scheme: str,
    batch_size: int,
    generator,
    logit_mean: float = None,
    logit_std: float = None,
    mode_scale: float = None,
):
    """Sample a density ``u`` in [0, 1] per batch element for timestep selection.

    Implements the SD3 timestep-sampling weighting schemes
    (https://arxiv.org/abs/2403.03206v1); contributed upstream by Rafie
    Walker in https://github.com/huggingface/diffusers/pull/8528.

    Sampling always happens on CPU with the provided ``generator`` so it is
    reproducible regardless of the training device.
    """
    if weighting_scheme == "logit_normal":
        # Section 3.1 of the SD3 paper: rf/lognorm(0.00, 1.00).
        gaussian = torch.normal(
            mean=logit_mean,
            std=logit_std,
            size=(batch_size, ),
            device="cpu",
            generator=generator,
        )
        return torch.nn.functional.sigmoid(gaussian)

    uniform = torch.rand(size=(batch_size, ), device="cpu", generator=generator)
    if weighting_scheme == "mode":
        # "Mode" scheme: warp the uniform sample toward the distribution mode.
        return 1 - uniform - mode_scale * (torch.cos(math.pi * uniform / 2)**2 - 1 + uniform)
    # Default ("uniform" and any unrecognized scheme): plain uniform samples.
    return uniform
def get_sigmas(noise_scheduler, device, timesteps, n_dim=4, dtype=torch.float32):
    """Look up the scheduler sigma for each timestep, broadcastable to latents.

    Finds each entry of ``timesteps`` in the scheduler's timestep table by
    exact match, gathers the corresponding sigma, and right-pads the result
    with singleton dimensions until it has ``n_dim`` dims so it broadcasts
    against a latent tensor of that rank.
    """
    sigma_table = noise_scheduler.sigmas.to(device=device, dtype=dtype)
    schedule_timesteps = noise_scheduler.timesteps.to(device)
    timesteps = timesteps.to(device)
    # One exact-match lookup per requested timestep (assumes each value
    # occurs exactly once in the schedule).
    positions = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
    sigma = sigma_table[positions].flatten()
    while sigma.ndim < n_dim:
        sigma = sigma[..., None]
    return sigma
def train_one_step(
    transformer,
    model_type,
    optimizer,
    lr_scheduler,
    noise_scheduler,
    noise_random_generator,
    gradient_accumulation_steps,
    sp_size,
    precondition_outputs,
    max_grad_norm,
    weighting_scheme,
    logit_mean,
    logit_std,
    mode_scale,
) if False else None  # placeholder — see real signature below
def main(args):
    """Distributed training entry point.

    Expects to be launched under torchrun (reads LOCAL_RANK / RANK /
    WORLD_SIZE from the environment). Loads the transformer, optionally
    attaches LoRA adapters, wraps the model in FSDP, then runs the
    profiler-instrumented training loop with periodic checkpointing and
    validation.
    """
    torch.backends.cuda.matmul.allow_tf32 = True

    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    dist.init_process_group("nccl")
    torch.cuda.set_device(local_rank)
    device = torch.cuda.current_device()
    initialize_sequence_parallel_state(args.sp_size)

    # If passed along, set the training seed now. On GPU...
    if args.seed is not None:
        # TODO: t within the same seq parallel group should be the same. Noise should be different.
        set_seed(args.seed + rank)
    # We use different seeds for the noise generation in each process to ensure that the noise is different in a batch.
    noise_random_generator = None

    # Handle the repository creation
    if rank <= 0 and args.output_dir is not None:
        os.makedirs(args.output_dir, exist_ok=True)

    # For mixed precision training we cast all non-trainable weights to half-precision
    # as these weights are only used for inference, keeping weights in full precision is not required.

    # Create model:
    main_print(f"--> loading model from {args.pretrained_model_name_or_path}")
    # keep the master weight to float32
    print("<"*50)
    transformer = load_transformer(
        args.model_type,
        args.dit_model_name_or_path,
        args.pretrained_model_name_or_path,
        torch.float32 if args.master_weight_type == "fp32" else torch.bfloat16,
    )
    print(">"*50)

    if args.use_lora:
        assert args.model_type != "hunyuan", "LoRA is only supported for huggingface model. Please use hunyuan_hf for lora finetuning"
        # NOTE(review): `pipe` stays unbound if model_type is neither
        # "mochi" nor "hunyuan_hf", which would NameError at checkpoint
        # time — verify against the supported model list.
        if args.model_type == "mochi":
            pipe = MochiPipeline
        elif args.model_type == "hunyuan_hf":
            pipe = HunyuanVideoPipeline
        transformer.requires_grad_(False)
        transformer_lora_config = LoraConfig(
            r=args.lora_rank,
            lora_alpha=args.lora_alpha,
            init_lora_weights=True,
            target_modules=["to_k", "to_q", "to_v", "to_out.0"],
        )
        transformer.add_adapter(transformer_lora_config)

    if args.resume_from_lora_checkpoint:
        # Restore LoRA adapter weights saved by a previous run.
        lora_state_dict = pipe.lora_state_dict(args.resume_from_lora_checkpoint)
        transformer_state_dict = {
            f'{k.replace("transformer.", "")}': v
            for k, v in lora_state_dict.items() if k.startswith("transformer.")
        }
        transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
        incompatible_keys = set_peft_model_state_dict(transformer, transformer_state_dict, adapter_name="default")
        if incompatible_keys is not None:
            # check only for unexpected keys
            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
            if unexpected_keys:
                main_print(f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
                           f" {unexpected_keys}. ")

    main_print(
        f"  Total training parameters = {sum(p.numel() for p in transformer.parameters() if p.requires_grad) / 1e6} M")
    main_print(f"--> Initializing FSDP with sharding strategy: {args.fsdp_sharding_startegy}")
    fsdp_kwargs, no_split_modules = get_dit_fsdp_kwargs(
        transformer,
        args.fsdp_sharding_startegy,
        args.use_lora,
        args.use_cpu_offload,
        args.master_weight_type,
    )

    if args.use_lora:
        # Record LoRA hyper-parameters on the config so they are saved with
        # checkpoints, and materialize the wrap policy for the adapted model.
        transformer.config.lora_rank = args.lora_rank
        transformer.config.lora_alpha = args.lora_alpha
        transformer.config.lora_target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
        transformer._no_split_modules = [no_split_module.__name__ for no_split_module in no_split_modules]
        fsdp_kwargs["auto_wrap_policy"] = fsdp_kwargs["auto_wrap_policy"](transformer)

    transformer = FSDP(
        transformer,
        **fsdp_kwargs,
    )
    main_print("--> model loaded")

    if args.gradient_checkpointing:
        apply_fsdp_checkpointing(transformer, no_split_modules, args.selective_checkpointing)

    # Set model as trainable.
    transformer.train()

    noise_scheduler = FlowMatchEulerDiscreteScheduler()

    params_to_optimize = transformer.parameters()
    params_to_optimize = list(filter(lambda p: p.requires_grad, params_to_optimize))

    optimizer = torch.optim.AdamW(
        params_to_optimize,
        lr=args.learning_rate,
        betas=(0.9, 0.999),
        weight_decay=args.weight_decay,
        eps=1e-8,
    )

    init_steps = 0
    if args.resume_from_lora_checkpoint:
        transformer, optimizer, init_steps = resume_lora_optimizer(transformer, args.resume_from_lora_checkpoint,
                                                                   optimizer)
    main_print(f"optimizer: {optimizer}")

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps,
        num_training_steps=args.max_train_steps,
        num_cycles=args.lr_num_cycles,
        power=args.lr_power,
        last_epoch=init_steps - 1,
    )

    train_dataset = LatentDataset(args.data_json_path, args.num_latent_t, args.cfg)
    # Length-grouped sampling only when grouping by frame count/resolution is requested.
    sampler = (LengthGroupedSampler(
        args.train_batch_size,
        rank=rank,
        world_size=world_size,
        lengths=train_dataset.lengths,
        group_frame=args.group_frame,
        group_resolution=args.group_resolution,
    ) if (args.group_frame or args.group_resolution) else DistributedSampler(
        train_dataset, rank=rank, num_replicas=world_size, shuffle=False))

    train_dataloader = DataLoader(
        train_dataset,
        sampler=sampler,
        collate_fn=latent_collate_function,
        pin_memory=True,
        batch_size=args.train_batch_size,
        num_workers=args.dataloader_num_workers,
        drop_last=True,
    )

    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / args.gradient_accumulation_steps * args.sp_size / args.train_sp_batch_size)
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # if rank <= 0:
    #     project = args.tracker_project_name or "fastvideo"
    #     wandb.init(project=project, config=args)

    # Train!
    total_batch_size = (world_size * args.gradient_accumulation_steps / args.sp_size * args.train_sp_batch_size)
    main_print("***** Running training *****")
    main_print(f"  Num examples = {len(train_dataset)}")
    main_print(f"  Dataloader size = {len(train_dataloader)}")
    main_print(f"  Num Epochs = {args.num_train_epochs}")
    main_print(f"  Resume training from step {init_steps}")
    main_print(f"  Instantaneous batch size per device = {args.train_batch_size}")
    main_print(f"  Total train batch size (w. data & sequence parallel, accumulation) = {total_batch_size}")
    main_print(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    main_print(f"  Total optimization steps = {args.max_train_steps}")
    main_print(
        f"  Total training parameters per FSDP shard = {sum(p.numel() for p in transformer.parameters() if p.requires_grad) / 1e9} B"
    )
    # print dtype
    main_print(f"  Master weight dtype: {transformer.parameters().__next__().dtype}")

    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        # BUGFIX: was `assert NotImplementedError(...)` — an exception
        # *instance* is always truthy, so the assert never fired and the
        # flag was silently ignored. Fail loudly instead.
        raise NotImplementedError("resume_from_checkpoint is not supported now.")
        # TODO

    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=init_steps,
        desc="Steps",
        # Only show the progress bar once on each machine.
        disable=local_rank > 0,
    )

    loader = sp_parallel_dataloader_wrapper(
        train_dataloader,
        device,
        args.train_batch_size,
        args.sp_size,
        args.train_sp_batch_size,
    )

    step_times = deque(maxlen=100)

    # todo future
    # Fast-forward the dataloader so a resumed run does not repeat batches.
    for i in range(init_steps):
        next(loader)

    # Profile with torch.profiler: skip 10 steps, warm up 5, capture 1
    # (trace_handler prints + dumps a Chrome trace each capture).
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                 schedule=torch.profiler.schedule(wait=10, warmup=5, active=1),
                 on_trace_ready=trace_handler) as p:
        for step in range(init_steps + 1, args.max_train_steps + 1):
            start_time = time.time()
            loss, grad_norm = train_one_step(
                transformer,
                args.model_type,
                optimizer,
                lr_scheduler,
                loader,
                noise_scheduler,
                noise_random_generator,
                args.gradient_accumulation_steps,
                args.sp_size,
                args.precondition_outputs,
                args.max_grad_norm,
                args.weighting_scheme,
                args.logit_mean,
                args.logit_std,
                args.mode_scale,
            )
            p.step()

            step_time = time.time() - start_time
            step_times.append(step_time)
            avg_step_time = sum(step_times) / len(step_times)

            progress_bar.set_postfix({
                "loss": f"{loss:.4f}",
                "step_time": f"{step_time:.2f}s",
                "grad_norm": grad_norm,
            })
            progress_bar.update(1)
            # if rank <= 0:
            #     wandb.log(
            #         {
            #             "train_loss": loss,
            #             "learning_rate": lr_scheduler.get_last_lr()[0],
            #             "step_time": step_time,
            #             "avg_step_time": avg_step_time,
            #             "grad_norm": grad_norm,
            #         },
            #         step=step,
            #     )
            if step % args.checkpointing_steps == 0:
                if args.use_lora:
                    # Save LoRA weights
                    save_lora_checkpoint(transformer, optimizer, rank, args.output_dir, step, pipe)
                else:
                    # Your existing checkpoint saving code
                    save_checkpoint(transformer, rank, args.output_dir, step)
                dist.barrier()
            if args.log_validation and step % args.validation_steps == 0:
                log_validation(args, transformer, device, torch.bfloat16, step, shift=args.shift)

    # Final checkpoint after the loop completes.
    if args.use_lora:
        save_lora_checkpoint(transformer, optimizer, rank, args.output_dir, args.max_train_steps, pipe)
    else:
        save_checkpoint(transformer, rank, args.output_dir, args.max_train_steps)

    if get_sequence_parallel_state():
        destroy_sequence_parallel_group()
if __name__ == "__main__":
    # CLI for the FSDP DiT fine-tuning script; parsed args are handed to main().
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_type",
                        type=str,
                        default="mochi",
                        # help-text typo fixed: "Currentlt" -> "Currently"
                        help="The type of model to train. Currently support [mochi, hunyuan_hf, hunyuan]")
    # dataset & dataloader
    parser.add_argument("--data_json_path", type=str, required=True)
    parser.add_argument("--num_height", type=int, default=480)
    parser.add_argument("--num_width", type=int, default=848)
    parser.add_argument("--num_frames", type=int, default=163)
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=10,
        help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
    )
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=16,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument("--num_latent_t", type=int, default=28, help="Number of latent timesteps.")
    parser.add_argument("--group_frame", action="store_true")  # TODO
    parser.add_argument("--group_resolution", action="store_true")  # TODO
    # text encoder & vae & diffusion model
    parser.add_argument("--pretrained_model_name_or_path", type=str)
    parser.add_argument("--dit_model_name_or_path", type=str, default=None)
    parser.add_argument("--cache_dir", type=str, default="./cache_dir")
    # diffusion setting
    parser.add_argument("--ema_decay", type=float, default=0.999)
    parser.add_argument("--ema_start_step", type=int, default=0)
    parser.add_argument("--cfg", type=float, default=0.1)
    parser.add_argument(
        "--precondition_outputs",
        action="store_true",
        help="Whether to precondition the outputs of the model.",
    )
    # validation & logs
    parser.add_argument("--validation_prompt_dir", type=str)
    parser.add_argument("--uncond_prompt_dir", type=str)
    parser.add_argument(
        "--validation_sampling_steps",
        type=str,
        default="64",
        help="use ',' to split multi sampling steps",
    )
    parser.add_argument(
        "--validation_guidance_scale",
        type=str,
        default="4.5",
        help="use ',' to split multi scale",
    )
    parser.add_argument("--validation_steps", type=int, default=50)
    parser.add_argument("--log_validation", action="store_true")
    parser.add_argument("--tracker_project_name", type=str, default=None)
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help=("Max number of checkpoints to store."),
    )
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help=("Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
              " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
              " training using `--resume_from_checkpoint`."),
    )
    parser.add_argument("--shift", type=float, default=1.0, help=("Set shift to 7 for hunyuan model."))
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=("Whether training should be resumed from a previous checkpoint. Use a path saved by"
              ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'),
    )
    parser.add_argument(
        "--resume_from_lora_checkpoint",
        type=str,
        default=None,
        help=("Whether training should be resumed from a previous lora checkpoint. Use a path saved by"
              ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'),
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
              " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."),
    )
    # optimizer & scheduler & Training
    parser.add_argument("--num_train_epochs", type=int, default=100)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_warmup_steps",
        type=int,
        default=10,
        help="Number of steps for the warmup in the lr scheduler.",
    )
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument("--selective_checkpointing", type=float, default=1.0)
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=("Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
              " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."),
    )
    parser.add_argument(
        "--use_cpu_offload",
        action="store_true",
        help="Whether to use CPU offload for param & gradient & optimizer states.",
    )
    parser.add_argument("--sp_size", type=int, default=1, help="For sequence parallel")
    parser.add_argument(
        "--train_sp_batch_size",
        type=int,
        default=1,
        help="Batch size for sequence parallel training",
    )
    parser.add_argument(
        "--use_lora",
        action="store_true",
        default=False,
        help="Whether to use LoRA for finetuning.",
    )
    parser.add_argument("--lora_alpha", type=int, default=256, help="Alpha parameter for LoRA.")
    parser.add_argument("--lora_rank", type=int, default=128, help="LoRA rank parameter. ")
    # NOTE: the flag name keeps its historical misspelling ("startegy") for
    # backward compatibility — existing launch scripts pass it verbatim.
    parser.add_argument("--fsdp_sharding_startegy", default="full")
    parser.add_argument(
        "--weighting_scheme",
        type=str,
        default="uniform",
        choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "uniform"],
    )
    parser.add_argument(
        "--logit_mean",
        type=float,
        default=0.0,
        help="mean to use when using the `'logit_normal'` weighting scheme.",
    )
    parser.add_argument(
        "--logit_std",
        type=float,
        default=1.0,
        help="std to use when using the `'logit_normal'` weighting scheme.",
    )
    parser.add_argument(
        "--mode_scale",
        type=float,
        default=1.29,
        help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
    )
    # lr_scheduler
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=('The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
              ' "constant", "constant_with_warmup"]'),
    )
    parser.add_argument(
        "--lr_num_cycles",
        type=int,
        default=1,
        help="Number of cycles in the learning rate scheduler.",
    )
    parser.add_argument(
        "--lr_power",
        type=float,
        default=1.0,
        help="Power factor of the polynomial scheduler.",
    )
    parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay to apply.")
    parser.add_argument(
        "--master_weight_type",
        type=str,
        default="fp32",
        help="Weight type to use - fp32 or bf16.",
    )
    args = parser.parse_args()
    main(args)
# import
import json
import os
import torch
import torch.distributed.checkpoint as dist_cp
from peft import get_peft_model_state_dict
from safetensors.torch import load_file, save_file
from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner, DefaultSavePlanner
from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
from torch.distributed.fsdp import FullOptimStateDictConfig, FullStateDictConfig
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType
from fastvideo.utils.logging_ import main_print
def save_checkpoint_optimizer(model, optimizer, rank, output_dir, step, discriminator=False):
    """Save full (unsharded) weights and optimizer state of an FSDP model.

    Writes into ``output_dir/checkpoint-{step}``:
      * generator mode (``discriminator=False``, rank 0):
        ``diffusion_pytorch_model.safetensors``, ``config.json`` and
        ``optimizer.pt``;
      * discriminator mode: ``discriminator_pytorch_model.safetensors`` and
        ``discriminator_optimizer.pt``.

    All ranks must call this function — gathering the FSDP state dict is a
    collective operation even though only rank 0 materializes the result.
    """
    with FSDP.state_dict_type(
            model,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
            FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
    ):
        cpu_state = model.state_dict()
        optim_state = FSDP.optim_state_dict(
            model,
            optimizer,
        )
    # todo move to get_state_dict
    save_dir = os.path.join(output_dir, f"checkpoint-{step}")
    os.makedirs(save_dir, exist_ok=True)
    # save using safetensors
    if rank <= 0 and not discriminator:
        weight_path = os.path.join(save_dir, "diffusion_pytorch_model.safetensors")
        save_file(cpu_state, weight_path)
        config_dict = dict(model.config)
        # BUGFIX: tolerate a missing 'dtype' key instead of raising KeyError,
        # matching the `if "dtype" in config_dict` guard used by save_checkpoint().
        config_dict.pop('dtype', None)
        config_path = os.path.join(save_dir, "config.json")
        # save dict as json
        with open(config_path, "w") as f:
            json.dump(config_dict, f, indent=4)
        optimizer_path = os.path.join(save_dir, "optimizer.pt")
        torch.save(optim_state, optimizer_path)
    else:
        # NOTE(review): in generator mode this branch also runs on every
        # non-zero rank (whose gathered state is empty with rank0_only=True),
        # so multiple ranks may write the same discriminator files — verify
        # this is intended before relying on it.
        weight_path = os.path.join(save_dir, "discriminator_pytorch_model.safetensors")
        save_file(cpu_state, weight_path)
        optimizer_path = os.path.join(save_dir, "discriminator_optimizer.pt")
        torch.save(optim_state, optimizer_path)
    main_print(f"--> checkpoint saved at step {step}")
def save_checkpoint(transformer, rank, output_dir, step):
    """Save a full (unsharded) copy of the transformer weights and config.

    Gathers the FSDP state dict onto rank 0 (CPU-offloaded) and writes
    ``diffusion_pytorch_model.safetensors`` plus ``config.json`` under
    ``output_dir/checkpoint-{step}``. Every rank must call this — the
    state-dict gather is collective — but only rank 0 touches disk.
    """
    main_print(f"--> saving checkpoint at step {step}")
    with FSDP.state_dict_type(
            transformer,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
    ):
        full_state = transformer.state_dict()
    # todo move to get_state_dict
    if rank <= 0:
        checkpoint_dir = os.path.join(output_dir, f"checkpoint-{step}")
        os.makedirs(checkpoint_dir, exist_ok=True)
        # save using safetensors
        save_file(full_state, os.path.join(checkpoint_dir, "diffusion_pytorch_model.safetensors"))
        config_dict = dict(transformer.config)
        # 'dtype' is dropped (when present) before serializing the config.
        if "dtype" in config_dict:
            del config_dict["dtype"]  # TODO
        # save dict as json
        with open(os.path.join(checkpoint_dir, "config.json"), "w") as config_file:
            json.dump(config_dict, config_file, indent=4)
    main_print(f"--> checkpoint saved at step {step}")
def save_checkpoint_generator_discriminator(
    model,
    optimizer,
    discriminator,
    discriminator_optimizer,
    rank,
    output_dir,
    step,
):
    """Save a combined generator/discriminator training checkpoint.

    Layout under ``output_dir/checkpoint-{step}``:
      * ``hf_weights/`` — full (rank-0, CPU-offloaded) generator weights
        in safetensors format plus ``config.json``;
      * ``model_weights_state/`` — sharded generator weights via
        ``torch.distributed.checkpoint``;
      * ``model_optimizer_state/`` — sharded generator optimizer state;
      * ``discriminator_fsdp_state/`` — full discriminator model +
        optimizer state, written by rank 0 as ``discriminator_state.pt``.

    All ranks must enter this function: the FSDP state-dict gathers and the
    ``dist_cp.save_state_dict`` calls are collective operations.
    """
    # Full generator state dict, gathered to CPU on rank 0 only.
    with FSDP.state_dict_type(
            model,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
    ):
        cpu_state = model.state_dict()
    # todo move to get_state_dict
    save_dir = os.path.join(output_dir, f"checkpoint-{step}")
    os.makedirs(save_dir, exist_ok=True)
    hf_weight_dir = os.path.join(save_dir, "hf_weights")
    os.makedirs(hf_weight_dir, exist_ok=True)
    # save using safetensors
    if rank <= 0:
        config_dict = dict(model.config)
        config_path = os.path.join(hf_weight_dir, "config.json")
        # save dict as json
        with open(config_path, "w") as f:
            json.dump(config_dict, f, indent=4)
        weight_path = os.path.join(hf_weight_dir, "diffusion_pytorch_model.safetensors")
        save_file(cpu_state, weight_path)
    main_print(f"--> saved HF weight checkpoint at path {hf_weight_dir}")
    model_weight_dir = os.path.join(save_dir, "model_weights_state")
    os.makedirs(model_weight_dir, exist_ok=True)
    model_optimizer_dir = os.path.join(save_dir, "model_optimizer_state")
    os.makedirs(model_optimizer_dir, exist_ok=True)
    # Sharded (per-rank) generator weights + optimizer state via dist_cp.
    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        optim_state = FSDP.optim_state_dict(model, optimizer)
        model_state = model.state_dict()
        weight_state_dict = {"model": model_state}
        dist_cp.save_state_dict(
            state_dict=weight_state_dict,
            storage_writer=dist_cp.FileSystemWriter(model_weight_dir),
            planner=DefaultSavePlanner(),
        )
        optimizer_state_dict = {"optimizer": optim_state}
        dist_cp.save_state_dict(
            state_dict=optimizer_state_dict,
            storage_writer=dist_cp.FileSystemWriter(model_optimizer_dir),
            planner=DefaultSavePlanner(),
        )
    discriminator_fsdp_state_dir = os.path.join(save_dir, "discriminator_fsdp_state")
    os.makedirs(discriminator_fsdp_state_dir, exist_ok=True)
    # Full discriminator model + optimizer state, gathered to rank 0.
    with FSDP.state_dict_type(
            discriminator,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
            FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
    ):
        optim_state = FSDP.optim_state_dict(discriminator, discriminator_optimizer)
        model_state = discriminator.state_dict()
    state_dict = {"optimizer": optim_state, "model": model_state}
    if rank <= 0:
        discriminator_fsdp_state_fil = os.path.join(discriminator_fsdp_state_dir, "discriminator_state.pt")
        torch.save(state_dict, discriminator_fsdp_state_fil)
    main_print("--> saved FSDP state checkpoint")
def load_sharded_model(model, optimizer, model_dir, optimizer_dir):
    """Load sharded (per-rank) FSDP model and optimizer state saved with
    torch.distributed.checkpoint.

    Args:
        model: FSDP-wrapped model to restore into.
        optimizer: optimizer paired with ``model``.
        model_dir: directory holding the sharded model weights.
        optimizer_dir: directory holding the sharded optimizer state.

    Returns:
        (model, optimizer) with state loaded in place.
    """
    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        # The current (sharded) state dict supplies the layout metadata that
        # load_sharded_optimizer_state_dict needs to resolve the shards.
        weight_state_dict = {"model": model.state_dict()}
        optim_state = load_sharded_optimizer_state_dict(
            model_state_dict=weight_state_dict["model"],
            optimizer_key="optimizer",
            storage_reader=dist_cp.FileSystemReader(optimizer_dir),
        )
        optim_state = optim_state["optimizer"]
        # Convert checkpoint-format optimizer state into the flattened form
        # the FSDP-wrapped optimizer expects.
        flattened_osd = FSDP.optim_state_dict_to_load(model=model, optim=optimizer, optim_state_dict=optim_state)
        optimizer.load_state_dict(flattened_osd)
        # Read the sharded weights in place into weight_state_dict, then load.
        dist_cp.load_state_dict(
            state_dict=weight_state_dict,
            storage_reader=dist_cp.FileSystemReader(model_dir),
            planner=DefaultLoadPlanner(),
        )
        model_state = weight_state_dict["model"]
        model.load_state_dict(model_state)
    main_print(f"--> loaded model and optimizer from path {model_dir}")
    return model, optimizer
def load_full_state_model(model, optimizer, checkpoint_file, rank):
    """Load a full (unsharded) checkpoint into an FSDP-wrapped model and its
    optimizer. Used for the discriminator here.

    Args:
        model: FSDP-wrapped model.
        optimizer: optimizer paired with ``model``.
        checkpoint_file: path to a torch.save'd {"model", "optimizer"} dict.
        rank: global rank; only rank 0 feeds the real optimizer state —
            other ranks pass None, matching the rank0_only config above.

    Returns:
        (model, optimizer) with state loaded in place.
    """
    with FSDP.state_dict_type(
            model,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
            FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
    ):
        # NOTE(review): every rank reads the full checkpoint file here; only
        # the optimizer state is restricted to rank 0 below.
        discriminator_state = torch.load(checkpoint_file)
        model_state = discriminator_state["model"]
        if rank <= 0:
            optim_state = discriminator_state["optimizer"]
        else:
            optim_state = None
        model.load_state_dict(model_state)
        # Re-flatten the rank-0 full optimizer state for the FSDP optimizer.
        discriminator_optim_state = FSDP.optim_state_dict_to_load(model=model,
                                                                  optim=optimizer,
                                                                  optim_state_dict=optim_state)
        optimizer.load_state_dict(discriminator_optim_state)
    main_print(f"--> loaded discriminator and discriminator optimizer from path {checkpoint_file}")
    return model, optimizer
def resume_training_generator_discriminator(model, optimizer, discriminator, discriminator_optimizer, checkpoint_dir,
                                            rank):
    """Restore generator (sharded) and discriminator (full-state) training
    state from a checkpoint directory named ``...-<step>``.

    Returns:
        (model, optimizer, discriminator, discriminator_optimizer, step)
    """
    # The trailing integer of the directory name encodes the training step.
    step = int(checkpoint_dir.split("-")[-1])
    weights_dir = os.path.join(checkpoint_dir, "model_weights_state")
    optim_dir = os.path.join(checkpoint_dir, "model_optimizer_state")
    model, optimizer = load_sharded_model(model, optimizer, weights_dir, optim_dir)
    disc_state_file = os.path.join(checkpoint_dir, "discriminator_fsdp_state", "discriminator_state.pt")
    discriminator, discriminator_optimizer = load_full_state_model(discriminator, discriminator_optimizer,
                                                                   disc_state_file, rank)
    return model, optimizer, discriminator, discriminator_optimizer, step
def resume_training(model, optimizer, checkpoint_dir, discriminator=False):
    """Resume training from a full-state checkpoint directory named
    ``...-<step>``.

    Args:
        model: FSDP-wrapped model to restore.
        optimizer: its optimizer.
        checkpoint_dir: checkpoint directory; the trailing "-<step>" suffix
            encodes the resume step.
        discriminator: when True, read the discriminator weight/optimizer
            files instead of the diffusion-model ones.

    Returns:
        (model, optimizer, step)
    """
    weight_path = os.path.join(checkpoint_dir, "diffusion_pytorch_model.safetensors")
    if discriminator:
        weight_path = os.path.join(checkpoint_dir, "discriminator_pytorch_model.safetensors")
    model_weights = load_file(weight_path)
    with FSDP.state_dict_type(
            model,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
            FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
    ):
        # Overlay the checkpoint tensors on the current state so keys missing
        # from the checkpoint keep their current values (hence strict=False).
        current_state = model.state_dict()
        current_state.update(model_weights)
        model.load_state_dict(current_state, strict=False)
    if discriminator:
        optim_path = os.path.join(checkpoint_dir, "discriminator_optimizer.pt")
    else:
        optim_path = os.path.join(checkpoint_dir, "optimizer.pt")
    # NOTE(review): weights_only=False unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    optimizer_state_dict = torch.load(optim_path, weights_only=False)
    optim_state = FSDP.optim_state_dict_to_load(model=model, optim=optimizer, optim_state_dict=optimizer_state_dict)
    optimizer.load_state_dict(optim_state)
    step = int(checkpoint_dir.split("-")[-1])
    return model, optimizer, step
def save_lora_checkpoint(transformer, optimizer, rank, output_dir, step, pipeline):
    """Save LoRA adapter weights, optimizer state and LoRA config at ``step``.

    Gathers the full (unsharded) state, then on rank 0 extracts the LoRA
    layers via peft and writes everything under
    ``output_dir/lora-checkpoint-<step>``.

    Args:
        transformer: FSDP-wrapped transformer carrying the LoRA adapters.
        optimizer: optimizer over the LoRA parameters.
        rank: global rank; only rank 0 writes to disk.
        output_dir: root directory for checkpoints.
        step: training step recorded in the directory name and config file.
        pipeline: diffusers-style pipeline providing ``save_lora_weights``.
    """
    with FSDP.state_dict_type(
            transformer,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
    ):
        full_state_dict = transformer.state_dict()
        lora_optim_state = FSDP.optim_state_dict(
            transformer,
            optimizer,
        )
    if rank <= 0:
        save_dir = os.path.join(output_dir, f"lora-checkpoint-{step}")
        os.makedirs(save_dir, exist_ok=True)
        # save optimizer
        optim_path = os.path.join(save_dir, "lora_optimizer.pt")
        torch.save(lora_optim_state, optim_path)
        # save lora weight
        main_print(f"--> saving LoRA checkpoint at step {step}")
        # Filter the full state dict down to just the LoRA adapter tensors.
        transformer_lora_layers = get_peft_model_state_dict(model=transformer, state_dict=full_state_dict)
        pipeline.save_lora_weights(
            save_directory=save_dir,
            transformer_lora_layers=transformer_lora_layers,
            is_main_process=True,
        )
        # save config
        lora_config = {
            "step": step,
            "lora_params": {
                "lora_rank": transformer.config.lora_rank,
                "lora_alpha": transformer.config.lora_alpha,
                "target_modules": transformer.config.lora_target_modules,
            },
        }
        config_path = os.path.join(save_dir, "lora_config.json")
        with open(config_path, "w") as f:
            json.dump(lora_config, f, indent=4)
    main_print(f"--> LoRA checkpoint saved at step {step}")
def resume_lora_optimizer(transformer, checkpoint_dir, optimizer):
    """Reload the LoRA optimizer state (and resume step) written by
    ``save_lora_checkpoint``.

    Returns:
        (transformer, optimizer, step)
    """
    with open(os.path.join(checkpoint_dir, "lora_config.json"), "r") as f:
        lora_config = json.load(f)
    saved_optim_state = torch.load(os.path.join(checkpoint_dir, "lora_optimizer.pt"), weights_only=False)
    # Re-flatten the checkpoint-format state for the FSDP-wrapped model.
    loadable_state = FSDP.optim_state_dict_to_load(model=transformer,
                                                   optim=optimizer,
                                                   optim_state_dict=saved_optim_state)
    optimizer.load_state_dict(loadable_state)
    step = lora_config["step"]
    main_print(f"--> Successfully resuming LoRA optimizer from step {step}")
    return transformer, optimizer, step
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from typing import Any, Tuple
import torch
import torch.distributed as dist
from torch import Tensor
from fastvideo.utils.parallel_states import nccl_info
def broadcast(input_: torch.Tensor):
    """Broadcast ``input_`` in place from the first rank of the current
    sequence-parallel group to every other rank of that group."""
    source_rank = nccl_info.group_id * nccl_info.sp_size
    dist.broadcast(input_, src=source_rank, group=nccl_info.group)
def _all_to_all_4D(input: torch.tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.tensor:
    """
    all-to-all for QKV

    Redistributes a 4D activation tensor across the sequence-parallel group:
    scatter_idx=2/gather_idx=1 scatters heads and gathers sequence;
    scatter_idx=1/gather_idx=2 scatters sequence and gathers heads.

    Args:
        input (torch.tensor): a tensor sharded along dim scatter dim
        scatter_idx (int): default 1
        gather_idx (int): default 2
        group : torch process group

    Returns:
        torch.tensor: resharded tensor (bs, seqlen/P, hc, hs)
    """
    assert (input.dim() == 4), f"input must be 4D tensor, got {input.dim()} and shape {input.shape}"

    seq_world_size = dist.get_world_size(group)

    if scatter_idx == 2 and gather_idx == 1:
        # input (torch.tensor): a tensor sharded along dim 1 (bs, seqlen/P, hc, hs) output: (bs, seqlen, hc/P, hs)
        bs, shard_seqlen, hc, hs = input.shape
        seqlen = shard_seqlen * seq_world_size
        shard_hc = hc // seq_world_size

        # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them!
        # (bs, seqlen/P, hc, hs) -reshape-> (bs, seq_len/P, P, hc/P, hs) -transpose(0,2)-> (P, seq_len/P, bs, hc/P, hs)
        input_t = (input.reshape(bs, shard_seqlen, seq_world_size, shard_hc, hs).transpose(0, 2).contiguous())

        output = torch.empty_like(input_t)
        # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single
        # (P, seq_len/P, bs, hc/P, hs) scatter seqlen -all2all-> (P, seq_len/P, bs, hc/P, hs) scatter head
        if seq_world_size > 1:
            # BUGFIX: this call previously used async_op=True and discarded the
            # returned Work handle; use the blocking form (consistent with the
            # scatter_idx == 1 branch below) so completion is well defined.
            dist.all_to_all_single(output, input_t, group=group)
            torch.cuda.synchronize()
        else:
            output = input_t

        # if scattering the seq-dim, transpose the heads back to the original dimension
        output = output.reshape(seqlen, bs, shard_hc, hs)

        # (seq_len, bs, hc/P, hs) -reshape-> (bs, seq_len, hc/P, hs)
        output = output.transpose(0, 1).contiguous().reshape(bs, seqlen, shard_hc, hs)

        return output

    elif scatter_idx == 1 and gather_idx == 2:
        # input (torch.tensor): a tensor sharded along dim 1 (bs, seqlen, hc/P, hs) output: (bs, seqlen/P, hc, hs)
        bs, seqlen, shard_hc, hs = input.shape
        hc = shard_hc * seq_world_size
        shard_seqlen = seqlen // seq_world_size
        # NOTE: the original recomputed seq_world_size here redundantly; it is
        # already available from the single computation above.

        # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them!
        # (bs, seqlen, hc/P, hs) -reshape-> (bs, P, seq_len/P, hc/P, hs) -transpose(0, 3)-> (hc/P, P, seqlen/P, bs, hs) -transpose(0, 1) -> (P, hc/P, seqlen/P, bs, hs)
        input_t = (input.reshape(bs, seq_world_size, shard_seqlen, shard_hc,
                                 hs).transpose(0,
                                               3).transpose(0,
                                                            1).contiguous().reshape(seq_world_size, shard_hc,
                                                                                    shard_seqlen, bs, hs))

        output = torch.empty_like(input_t)
        # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single
        # (P, bs x hc/P, seqlen/P, hs) scatter seqlen -all2all-> (P, bs x seq_len/P, hc/P, hs) scatter head
        if seq_world_size > 1:
            dist.all_to_all_single(output, input_t, group=group)
            torch.cuda.synchronize()
        else:
            output = input_t

        # if scattering the seq-dim, transpose the heads back to the original dimension
        output = output.reshape(hc, shard_seqlen, bs, hs)

        # (hc, seqlen/N, bs, hs) -tranpose(0,2)-> (bs, seqlen/N, hc, hs)
        output = output.transpose(0, 2).contiguous().reshape(bs, shard_seqlen, hc, hs)

        return output
    else:
        raise RuntimeError("scatter_idx must be 1 or 2 and gather_idx must be 1 or 2")
class SeqAllToAll4D(torch.autograd.Function):
    """Autograd wrapper around ``_all_to_all_4D``: the backward pass is the
    same all-to-all with the scatter and gather dims swapped."""

    @staticmethod
    def forward(
        ctx: Any,
        group: dist.ProcessGroup,
        input: Tensor,
        scatter_idx: int,
        gather_idx: int,
    ) -> Tensor:
        # Stash the communication layout for the backward pass.
        ctx.group = group
        ctx.scatter_idx = scatter_idx
        ctx.gather_idx = gather_idx
        return _all_to_all_4D(input, scatter_idx, gather_idx, group=group)

    @staticmethod
    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]:
        # Inverse communication: swap gather_idx/scatter_idx. The None slots
        # correspond to the non-tensor forward arguments (group and the dims).
        return (
            None,
            SeqAllToAll4D.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx),
            None,
            None,
        )
def all_to_all_4D(
    input_: torch.Tensor,
    scatter_dim: int = 2,
    gather_dim: int = 1,
):
    """Autograd-aware 4D all-to-all over the global sequence-parallel group."""
    sp_group = nccl_info.group
    return SeqAllToAll4D.apply(sp_group, input_, scatter_dim, gather_dim)
def _all_to_all(
    input_: torch.Tensor,
    world_size: int,
    group: dist.ProcessGroup,
    scatter_dim: int,
    gather_dim: int,
):
    """Split ``input_`` into ``world_size`` chunks along ``scatter_dim``,
    exchange them across ``group``, and concatenate the received chunks
    along ``gather_dim``."""
    send_chunks = [chunk.contiguous() for chunk in torch.tensor_split(input_, world_size, scatter_dim)]
    recv_chunks = [torch.empty_like(send_chunks[0]) for _ in range(world_size)]
    dist.all_to_all(recv_chunks, send_chunks, group=group)
    return torch.cat(recv_chunks, dim=gather_dim).contiguous()
class _AllToAll(torch.autograd.Function):
    """All-to-all communication.

    Args:
        input_: input matrix
        process_group: communication group
        scatter_dim: scatter dimension
        gather_dim: gather dimension
    """

    @staticmethod
    def forward(ctx, input_, process_group, scatter_dim, gather_dim):
        # Stash the layout so backward can invert the communication.
        ctx.process_group = process_group
        ctx.scatter_dim = scatter_dim
        ctx.gather_dim = gather_dim
        ctx.world_size = dist.get_world_size(process_group)
        output = _all_to_all(input_, ctx.world_size, process_group, scatter_dim, gather_dim)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # Gradient of an all-to-all is the all-to-all with scatter and gather
        # dimensions swapped.
        grad_output = _all_to_all(
            grad_output,
            ctx.world_size,
            ctx.process_group,
            ctx.gather_dim,
            ctx.scatter_dim,
        )
        # One gradient slot per forward argument; only input_ receives one.
        return (
            grad_output,
            None,
            None,
            None,
        )
def all_to_all(
    input_: torch.Tensor,
    scatter_dim: int = 2,
    gather_dim: int = 1,
):
    """Autograd-aware all-to-all over the global sequence-parallel group."""
    sp_group = nccl_info.group
    return _AllToAll.apply(input_, sp_group, scatter_dim, gather_dim)
class _AllGather(torch.autograd.Function):
    """All-gather communication with autograd support.

    Args:
        input_: input tensor
        dim: dimension along which to concatenate
    """

    @staticmethod
    def forward(ctx, input_, dim):
        ctx.dim = dim
        world_size = nccl_info.sp_size
        group = nccl_info.group
        input_size = list(input_.size())
        # Remember the local shard length so backward can split the gradient.
        # NOTE(review): backward assumes every rank's shard has this same
        # length along `dim` — confirm for uneven inputs.
        ctx.input_size = input_size[dim]
        tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
        input_ = input_.contiguous()
        dist.all_gather(tensor_list, input_, group=group)
        output = torch.cat(tensor_list, dim=dim)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # Each rank keeps only the gradient slice matching its own shard.
        world_size = nccl_info.sp_size
        rank = nccl_info.rank_within_group
        dim = ctx.dim
        input_size = ctx.input_size
        sizes = [input_size] * world_size
        grad_input_list = torch.split(grad_output, sizes, dim=dim)
        grad_input = grad_input_list[rank]
        return grad_input, None
def all_gather(input_: torch.Tensor, dim: int = 1):
    """Gather ``input_`` from every rank of the sequence-parallel group and
    concatenate the shards along ``dim``, with autograd support.

    Args:
        input_ (torch.Tensor): Local shard, e.g. of shape [B, H, S, D].
        dim (int, optional): Concatenation dimension. Defaults to 1.

    Returns:
        torch.Tensor: Tensor gathered across the group along ``dim``.
    """
    return _AllGather.apply(input_, dim)
def prepare_sequence_parallel_data(hidden_states, encoder_hidden_states, attention_mask, encoder_attention_mask):
    """Reshard a batch for sequence parallelism: the frame dim of
    ``hidden_states`` is scattered across the SP group while the batch dim
    is gathered, so each rank ends up with sp_size x the batch but only
    1/sp_size of the frames. Conditioning tensors are repeated to match.

    Returns the four tensors unchanged when sp_size == 1.
    """
    if nccl_info.sp_size == 1:
        return (
            hidden_states,
            encoder_hidden_states,
            attention_mask,
            encoder_attention_mask,
        )

    def prepare(hidden_states, encoder_hidden_states, attention_mask, encoder_attention_mask):
        # Scatter along the frame dim (dim 2 of hidden_states, dim 1 of the
        # others) and gather along the batch dim (dim 0).
        hidden_states = all_to_all(hidden_states, scatter_dim=2, gather_dim=0)
        encoder_hidden_states = all_to_all(encoder_hidden_states, scatter_dim=1, gather_dim=0)
        attention_mask = all_to_all(attention_mask, scatter_dim=1, gather_dim=0)
        encoder_attention_mask = all_to_all(encoder_attention_mask, scatter_dim=1, gather_dim=0)
        return (
            hidden_states,
            encoder_hidden_states,
            attention_mask,
            encoder_attention_mask,
        )

    sp_size = nccl_info.sp_size
    frame = hidden_states.shape[2]
    assert frame % sp_size == 0, "frame should be a multiple of sp_size"

    # Conditioning is replicated sp_size times so that after the batch-dim
    # gather every sample still carries its matching conditioning.
    (
        hidden_states,
        encoder_hidden_states,
        attention_mask,
        encoder_attention_mask,
    ) = prepare(
        hidden_states,
        encoder_hidden_states.repeat(1, sp_size, 1),
        attention_mask.repeat(1, sp_size, 1, 1),
        encoder_attention_mask.repeat(1, sp_size),
    )

    return hidden_states, encoder_hidden_states, attention_mask, encoder_attention_mask
def sp_parallel_dataloader_wrapper(dataloader, device, train_batch_size, sp_size, train_sp_batch_size):
    """Endlessly iterate ``dataloader``, resharding each video batch for
    sequence parallelism and re-slicing it into sequence-parallel
    micro-batches.

    Args:
        dataloader: iterable yielding (latents, cond, attn_mask, cond_mask).
        device: device the four tensors are moved to.
        train_batch_size: per-rank batch size produced by ``dataloader``.
        sp_size: sequence-parallel world size.
        train_sp_batch_size: batch size of each yielded micro-batch.

    Yields:
        (latents, encoder_hidden_states, attention_mask, encoder_attention_mask)
    """
    while True:  # restart the dataloader forever; the caller bounds step count
        for data_item in dataloader:
            latents, cond, attn_mask, cond_mask = data_item
            latents = latents.to(device)
            cond = cond.to(device)
            attn_mask = attn_mask.to(device)
            cond_mask = cond_mask.to(device)
            frame = latents.shape[2]
            if frame == 1:
                # Single-frame (image) batch: no temporal dim to shard.
                yield latents, cond, attn_mask, cond_mask
            else:
                latents, cond, attn_mask, cond_mask = prepare_sequence_parallel_data(
                    latents, cond, attn_mask, cond_mask)
                assert (train_batch_size * sp_size >=
                        train_sp_batch_size), "train_batch_size * sp_size should be greater than train_sp_batch_size"
                # BUGFIX: loop variable renamed from `iter`, which shadowed
                # the builtin of the same name.
                for batch_idx in range(train_batch_size * sp_size // train_sp_batch_size):
                    st_idx = batch_idx * train_sp_batch_size
                    ed_idx = (batch_idx + 1) * train_sp_batch_size
                    encoder_hidden_states = cond[st_idx:ed_idx]
                    attention_mask = attn_mask[st_idx:ed_idx]
                    encoder_attention_mask = cond_mask[st_idx:ed_idx]
                    yield (
                        latents[st_idx:ed_idx],
                        encoder_hidden_states,
                        attention_mask,
                        encoder_attention_mask,
                    )
import math
import random
from collections import Counter
from typing import List, Optional
import decord
import torch
import torch.utils
import torch.utils.data
from torch.nn import functional as F
from torch.utils.data import Sampler
# Recognized (case-sensitive) image filename extensions.
IMG_EXTENSIONS = [".jpg", ".JPG", ".jpeg", ".JPEG", ".png", ".PNG"]


def is_image_file(filename):
    """Return True if ``filename`` ends with a recognized image extension."""
    # str.endswith accepts a tuple of suffixes, replacing the per-extension
    # any(...) loop with a single C-level check.
    return filename.endswith(tuple(IMG_EXTENSIONS))
class DecordInit(object):
    """Using Decord(https://github.com/dmlc/decord) to initialize the video_reader."""

    def __init__(self, num_threads=1):
        # Decoding runs on CPU context 0; num_threads controls decord's
        # decoder thread count.
        self.num_threads = num_threads
        self.ctx = decord.cpu(0)

    def __call__(self, filename):
        """Perform the Decord initialization.

        Args:
            filename (str): path of the video file to open.

        Returns:
            decord.VideoReader: reader over the given file.
        """
        reader = decord.VideoReader(filename, ctx=self.ctx, num_threads=self.num_threads)
        return reader

    def __repr__(self):
        # BUGFIX: the original interpolated the nonexistent attribute
        # `self.sr`, so repr() raised AttributeError.
        repr_str = (f"{self.__class__.__name__}("
                    f"num_threads={self.num_threads})")
        return repr_str
def pad_to_multiple(number, ds_stride):
    """Round ``number`` up to the nearest multiple of ``ds_stride``."""
    # (-number) % ds_stride is exactly the padding needed to reach the next
    # multiple, and is zero when `number` is already aligned.
    return number + (-number % ds_stride)
# TODO
class Collate:
    """Collate function that pads a batch of variable-shape (c, t, h, w)
    clips to one shared, stride-aligned shape and builds the matching
    latent-space attention mask."""

    def __init__(self, args):
        self.batch_size = args.train_batch_size
        self.group_frame = args.group_frame
        self.group_resolution = args.group_resolution
        self.max_height = args.max_height
        self.max_width = args.max_width
        self.ae_stride = args.ae_stride
        self.ae_stride_t = args.ae_stride_t
        # Autoencoder downsampling strides as (t, h, w).
        self.ae_stride_thw = (self.ae_stride_t, self.ae_stride, self.ae_stride)
        self.patch_size = args.patch_size
        self.patch_size_t = args.patch_size_t
        self.num_frames = args.num_frames
        self.use_image_num = args.use_image_num
        # Padding target used when the batch is not shape-grouped.
        self.max_thw = (self.num_frames, self.max_height, self.max_width)

    def package(self, batch):
        # Unzip the list of per-sample dicts into parallel lists.
        batch_tubes = [i["pixel_values"] for i in batch]  # b [c t h w]
        input_ids = [i["input_ids"] for i in batch]  # b [1 l]
        cond_mask = [i["cond_mask"] for i in batch]  # b [1 l]
        return batch_tubes, input_ids, cond_mask

    def __call__(self, batch):
        batch_tubes, input_ids, cond_mask = self.package(batch)

        # Combined AE + patchify strides the padded sizes must align to.
        ds_stride = self.ae_stride * self.patch_size
        t_ds_stride = self.ae_stride_t * self.patch_size_t

        pad_batch_tubes, attention_mask, input_ids, cond_mask = self.process(
            batch_tubes,
            input_ids,
            cond_mask,
            t_ds_stride,
            ds_stride,
            self.max_thw,
            self.ae_stride_thw,
        )
        assert not torch.any(torch.isnan(pad_batch_tubes)), "after pad_batch_tubes"
        return pad_batch_tubes, attention_mask, input_ids, cond_mask

    def process(
        self,
        batch_tubes,
        input_ids,
        cond_mask,
        t_ds_stride,
        ds_stride,
        max_thw,
        ae_stride_thw,
    ):
        """Pad every clip to a shared, stride-aligned (t, h, w) and return the
        stacked batch plus the latent-space attention mask."""
        # pad to max multiple of ds_stride
        batch_input_size = [i.shape for i in batch_tubes]  # [(c t h w), (c t h w)]
        assert len(batch_input_size) == self.batch_size
        if self.group_frame or self.group_resolution or self.batch_size == 1:  #
            # Grouped batches should already share one shape; if not, keep the
            # most frequent shape and resample the minority samples from it.
            len_each_batch = batch_input_size
            idx_length_dict = dict([*zip(list(range(self.batch_size)), len_each_batch)])
            count_dict = Counter(len_each_batch)
            if len(count_dict) != 1:
                sorted_by_value = sorted(count_dict.items(), key=lambda item: item[1])
                pick_length = sorted_by_value[-1][0]  # the highest frequency
                candidate_batch = [idx for idx, length in idx_length_dict.items() if length == pick_length]
                random_select_batch = [
                    random.choice(candidate_batch) for _ in range(len(len_each_batch) - len(candidate_batch))
                ]
                print(
                    batch_input_size,
                    idx_length_dict,
                    count_dict,
                    sorted_by_value,
                    pick_length,
                    candidate_batch,
                    random_select_batch,
                )
                pick_idx = candidate_batch + random_select_batch

                batch_tubes = [batch_tubes[i] for i in pick_idx]
                batch_input_size = [i.shape for i in batch_tubes]  # [(c t h w), (c t h w)]
                input_ids = [input_ids[i] for i in pick_idx]  # b [1, l]
                cond_mask = [cond_mask[i] for i in pick_idx]  # b [1, l]

            for i in range(1, self.batch_size):
                assert batch_input_size[0] == batch_input_size[i]

            max_t = max([i[1] for i in batch_input_size])
            max_h = max([i[2] for i in batch_input_size])
            max_w = max([i[3] for i in batch_input_size])
        else:
            max_t, max_h, max_w = max_thw

        # Align the padded size to the combined strides; the temporal axis is
        # shifted by (ae_stride_t - 1) to account for the first frame.
        pad_max_t, pad_max_h, pad_max_w = (
            pad_to_multiple(max_t - 1 + self.ae_stride_t, t_ds_stride),
            pad_to_multiple(max_h, ds_stride),
            pad_to_multiple(max_w, ds_stride),
        )
        pad_max_t = pad_max_t + 1 - self.ae_stride_t
        each_pad_t_h_w = [[pad_max_t - i.shape[1], pad_max_h - i.shape[2], pad_max_w - i.shape[3]] for i in batch_tubes]
        # Zero-pad each clip at the trailing end of the t, h and w axes.
        pad_batch_tubes = [
            F.pad(im, (0, pad_w, 0, pad_h, 0, pad_t), value=0)
            for (pad_t, pad_h, pad_w), im in zip(each_pad_t_h_w, batch_tubes)
        ]
        pad_batch_tubes = torch.stack(pad_batch_tubes, dim=0)

        # Latent sizes after AE downsampling; temporal uses 1 + (t - 1) / s.
        max_tube_size = [pad_max_t, pad_max_h, pad_max_w]
        max_latent_size = [
            ((max_tube_size[0] - 1) // ae_stride_thw[0] + 1),
            max_tube_size[1] // ae_stride_thw[1],
            max_tube_size[2] // ae_stride_thw[2],
        ]
        valid_latent_size = [[
            int(math.ceil((i[1] - 1) / ae_stride_thw[0])) + 1,
            int(math.ceil(i[2] / ae_stride_thw[1])),
            int(math.ceil(i[3] / ae_stride_thw[2])),
        ] for i in batch_input_size]
        # 1 where a latent position comes from real content, 0 where padded.
        attention_mask = [
            F.pad(
                torch.ones(i, dtype=pad_batch_tubes.dtype),
                (
                    0,
                    max_latent_size[2] - i[2],
                    0,
                    max_latent_size[1] - i[1],
                    0,
                    max_latent_size[0] - i[0],
                ),
                value=0,
            ) for i in valid_latent_size
        ]
        attention_mask = torch.stack(attention_mask)  # b t h w

        if self.batch_size == 1 or self.group_frame or self.group_resolution:
            # Shape-aligned batches must have no padded latent positions.
            assert torch.all(attention_mask.bool())

        input_ids = torch.stack(input_ids)  # b 1 l
        cond_mask = torch.stack(cond_mask)  # b 1 l

        return pad_batch_tubes, attention_mask, input_ids, cond_mask
def split_to_even_chunks(indices, lengths, num_chunks, batch_size):
    """
    Split a list of indices into `chunks` chunks of roughly equal lengths,
    padding short chunks up to `batch_size` with random re-draws.
    """
    if len(indices) % num_chunks != 0:
        # Not evenly divisible: round-robin by stride.
        chunks = [indices[offset::num_chunks] for offset in range(num_chunks)]
    else:
        # Greedy balancing: always drop the next index into the chunk with the
        # smallest accumulated length; full chunks are frozen via +inf.
        per_chunk = len(indices) // num_chunks
        chunks = [[] for _ in range(num_chunks)]
        load = [0 for _ in range(num_chunks)]
        for idx in indices:
            target = load.index(min(load))
            chunks[target].append(idx)
            load[target] += lengths[idx]
            if len(chunks[target]) == per_chunk:
                load[target] = float("inf")

    # Pad undersized chunks so every chunk ends up with batch_size entries.
    padded = []
    for pos, chunk in enumerate(chunks):
        if len(chunk) != batch_size:
            assert batch_size > len(chunk)
            if chunk:
                chunk = chunk + [random.choice(chunk) for _ in range(batch_size - len(chunk))]
            else:
                chunk = random.choice(padded)
            print(chunks[pos], "->", chunk)
        padded.append(chunk)
    return padded
def group_frame_fun(indices, lengths):
    """Sort ``indices`` in place so higher ``lengths`` values come first,
    and return the same list."""
    indices.sort(key=lengths.__getitem__, reverse=True)
    return indices
def megabatch_frame_alignment(megabatches, lengths):
    """Make every megabatch frame-homogeneous: samples whose frame count is
    not the most frequent one within their megabatch are replaced by random
    re-draws from the majority-length samples."""
    aligned = []
    for megabatch in megabatches:
        assert len(megabatch) != 0
        frame_lengths = [lengths[i] for i in megabatch]
        length_by_idx = dict(zip(megabatch, frame_lengths))
        freq = Counter(frame_lengths)
        # Already homogeneous: keep as-is.
        if len(freq) == 1:
            aligned.append(megabatch)
            continue
        # Keep the most frequent frame length and resample the rest from it.
        dominant_length = sorted(freq.items(), key=lambda item: item[1])[-1][0]
        keepers = [idx for idx, length in length_by_idx.items() if length == dominant_length]
        fillers = [random.choice(keepers) for _ in range(len(length_by_idx) - len(keepers))]
        aligned.append(keepers + fillers)
    return aligned
def get_length_grouped_indices(
    lengths,
    batch_size,
    world_size,
    generator=None,
    group_frame=False,
    group_resolution=False,
    seed=42,
):
    """Build a shuffled index order where each world_size * batch_size
    "megabatch" holds samples of one frame length, balanced across ranks.

    Args:
        lengths: per-sample frame counts, indexed by dataset position.
        batch_size: per-rank batch size.
        world_size: number of data-parallel ranks.
        generator: optional torch.Generator; when None a generator seeded
            with ``seed`` is used so every rank produces the same order.
        group_frame / group_resolution: accepted but currently unused here.
        seed: fallback RNG seed.

    Returns:
        Flat list of dataset indices.
    """
    # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
    if generator is None:
        generator = torch.Generator().manual_seed(seed)  # every rank will generate a fixed order but random index
    indices = torch.randperm(len(lengths), generator=generator).tolist()

    # sort dataset according to frame
    indices = group_frame_fun(indices, lengths)

    # chunk dataset to megabatches
    megabatch_size = world_size * batch_size
    megabatches = [indices[i:i + megabatch_size] for i in range(0, len(lengths), megabatch_size)]

    # make sure the length in each magabatch is align with each other
    megabatches = megabatch_frame_alignment(megabatches, lengths)

    # aplit aligned megabatch into batches
    megabatches = [split_to_even_chunks(megabatch, lengths, world_size, batch_size) for megabatch in megabatches]

    # random megabatches to do video-image mix training
    indices = torch.randperm(len(megabatches), generator=generator).tolist()
    shuffled_megabatches = [megabatches[i] for i in indices]

    # expand indices and return
    return [i for megabatch in shuffled_megabatches for batch in megabatch for i in batch]
class LengthGroupedSampler(Sampler):
    r"""
    Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
    keeping a bit of randomness.
    """

    def __init__(
        self,
        batch_size: int,
        rank: int,
        world_size: int,
        lengths: Optional[List[int]] = None,
        group_frame=False,
        group_resolution=False,
        generator=None,
    ):
        # `lengths` holds the per-sample frame counts used for grouping.
        if lengths is None:
            raise ValueError("Lengths must be provided.")

        self.batch_size = batch_size
        self.rank = rank
        self.world_size = world_size
        self.lengths = lengths
        self.group_frame = group_frame
        self.group_resolution = group_resolution
        self.generator = generator

    def __len__(self):
        # Length of the full (un-partitioned) dataset.
        return len(self.lengths)

    def __iter__(self):
        # Global order shared by all ranks (same seed inside the helper).
        indices = get_length_grouped_indices(
            self.lengths,
            self.batch_size,
            self.world_size,
            group_frame=self.group_frame,
            group_resolution=self.group_resolution,
            generator=self.generator,
        )

        def distributed_sampler(lst, rank, batch_size, world_size):
            # Take this rank's batch-sized stripes from the global order.
            result = []
            index = rank * batch_size
            while index < len(lst):
                result.extend(lst[index:index + batch_size])
                index += batch_size * world_size
            return result

        indices = distributed_sampler(indices, self.rank, self.batch_size, self.world_size)
        return iter(indices)
import platform
import accelerate
import peft
import torch
import transformers
from transformers.utils import is_torch_cuda_available, is_torch_npu_available
# Package version reported by the environment dump below.
VERSION = "1.2.0"

if __name__ == "__main__":
    # Collect version/platform information (e.g. for bug reports).
    info = {
        "FastVideo version": VERSION,
        "Platform": platform.platform(),
        "Python version": platform.python_version(),
        "PyTorch version": torch.__version__,
        "Transformers version": transformers.__version__,
        "Accelerate version": accelerate.__version__,
        "PEFT version": peft.__version__,
    }

    if is_torch_cuda_available():
        info["PyTorch version"] += " (GPU)"
        info["GPU type"] = torch.cuda.get_device_name()

    if is_torch_npu_available():
        info["PyTorch version"] += " (NPU)"
        info["NPU type"] = torch.npu.get_device_name()
        info["CANN version"] = torch.version.cann  # codespell:ignore

    # bitsandbytes is optional; report its version only when importable.
    try:
        import bitsandbytes

        info["Bitsandbytes version"] = bitsandbytes.__version__
    except Exception:
        pass

    print("\n" + "\n".join([f"- {key}: {value}" for key, value in info.items()]) + "\n")
# ruff: noqa: E731
import functools
from functools import partial
import torch
from peft.utils.other import fsdp_auto_wrap_policy
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (CheckpointImpl, apply_activation_checkpointing,
checkpoint_wrapper)
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from fastvideo.models.mochi_hf.modeling_mochi import MochiTransformerBlock
from fastvideo.utils.load import get_no_split_modules
# Activation-checkpoint wrapper configured for the non-reentrant autograd
# implementation.
non_reentrant_wrapper = partial(
    checkpoint_wrapper,
    checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)

# Default predicate: checkpoint every Mochi transformer block.
check_fn = lambda submodule: isinstance(submodule, MochiTransformerBlock)
def apply_fsdp_checkpointing(model, no_split_modules, p=1):
    # https://github.com/foundation-model-stack/fms-fsdp/blob/408c7516d69ea9b6bcd4c0f5efab26c0f64b3c2d/fms_fsdp/policies/ac_handler.py#L16
    """apply activation checkpointing to model
    returns None as model is updated directly

    Roughly a fraction ``p`` of the ``no_split_modules`` blocks is
    checkpointed, selected evenly across the model depth.
    """
    print("--> applying fdsp activation checkpointing...")
    block_idx = 0
    cut_off = 1 / 2
    # when passing p as a fraction number (e.g. 1/3), it will be interpreted
    # as a string in argv, thus we need eval("1/3") here for fractions.
    # NOTE(review): eval on a string executes arbitrary code; acceptable only
    # because `p` comes from the launch command line itself, not user data.
    p = eval(p) if isinstance(p, str) else p

    def selective_checkpointing(submodule):
        # Counts matching blocks; returns True whenever the running total
        # block_idx * p crosses the next half-offset threshold, which selects
        # ~p of all blocks, evenly spaced.
        nonlocal block_idx
        nonlocal cut_off

        if isinstance(submodule, no_split_modules):
            block_idx += 1
            if block_idx * p >= cut_off:
                cut_off += 1
                return True
        return False

    apply_activation_checkpointing(
        model,
        checkpoint_wrapper_fn=non_reentrant_wrapper,
        check_fn=selective_checkpointing,
    )
def get_mixed_precision(master_weight_type="fp32"):
    """Build an FSDP MixedPrecision config in which parameters, gradient
    reduction and buffers all share one dtype: float32 when
    ``master_weight_type`` is "fp32", bfloat16 otherwise."""
    if master_weight_type == "fp32":
        dtype = torch.float32
    else:
        dtype = torch.bfloat16
    return MixedPrecision(
        param_dtype=dtype,
        # Gradient communication precision.
        reduce_dtype=dtype,
        # Buffer precision.
        buffer_dtype=dtype,
        cast_forward_inputs=False,
    )
def get_dit_fsdp_kwargs(
    transformer,
    sharding_strategy,
    use_lora=False,
    cpu_offload=False,
    master_weight_type="fp32",
):
    """Assemble the FSDP constructor kwargs for the DiT transformer.

    Args:
        transformer: model whose no-split block classes define wrap points.
        sharding_strategy: "full", "hybrid_full", "none" or "hybrid_zero2".
        use_lora: use peft's wrap policy and LoRA-specific FSDP settings.
        cpu_offload: enable parameter CPU offloading.
        master_weight_type: "fp32" keeps fp32 master weights, else bf16.

    Returns:
        (fsdp_kwargs, no_split_modules)
    """
    no_split_modules = get_no_split_modules(transformer)

    # LoRA relies on peft's policy; otherwise wrap each transformer block.
    if use_lora:
        wrap_policy = fsdp_auto_wrap_policy
    else:
        wrap_policy = functools.partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls=no_split_modules,
        )

    # we use float32 for fsdp but autocast during training
    mixed_precision = get_mixed_precision(master_weight_type)

    strategy_map = {
        "full": ShardingStrategy.FULL_SHARD,
        "hybrid_full": ShardingStrategy.HYBRID_SHARD,
        "none": ShardingStrategy.NO_SHARD,
        "hybrid_zero2": ShardingStrategy._HYBRID_SHARD_ZERO2,
    }
    if sharding_strategy in strategy_map:
        if sharding_strategy == "none":
            # No sharding means no wrapping is needed either.
            wrap_policy = None
        sharding_strategy = strategy_map[sharding_strategy]

    fsdp_kwargs = {
        "auto_wrap_policy": wrap_policy,
        "mixed_precision": mixed_precision,
        "sharding_strategy": sharding_strategy,
        "device_id": torch.cuda.current_device(),
        "limit_all_gathers": True,
        "cpu_offload": torch.distributed.fsdp.CPUOffload(offload_params=True) if cpu_offload else None,
    }

    # Add LoRA-specific settings when LoRA is enabled
    if use_lora:
        fsdp_kwargs["use_orig_params"] = False  # Required for LoRA memory savings
        fsdp_kwargs["sync_module_states"] = True

    return fsdp_kwargs, no_split_modules
def get_discriminator_fsdp_kwargs(master_weight_type="fp32"):
    """FSDP constructor kwargs for the discriminator: no sharding and no
    auto-wrap, reusing the shared mixed-precision settings."""
    return {
        "auto_wrap_policy": None,
        "mixed_precision": get_mixed_precision(master_weight_type),
        "sharding_strategy": ShardingStrategy.NO_SHARD,
        "device_id": torch.cuda.current_device(),
        "limit_all_gathers": True,
    }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment