Unverified Commit 49aff300 authored by gushiqiao's avatar gushiqiao Committed by GitHub
Browse files
parent c47dc6e8
......@@ -15,11 +15,12 @@
"cpu_offload": true,
"offload_granularity": "block",
"offload_ratio": 1,
"t5_cpu_offload": true,
"t5_offload_granularity": "model",
"t5_cpu_offload": false,
"t5_quantized": true,
"t5_quant_scheme": "fp8-q8f",
"clip_cpu_offload": false,
"clip_quantized": true,
"clip_quant_scheme": "fp8-q8f",
"audio_encoder_cpu_offload": false,
"audio_adapter_cpu_offload": false,
"adapter_quantized": true,
......
{
"infer_steps": 4,
"target_video_length": 81,
"text_len": 512,
"target_height": 720,
"target_width": 1280,
"self_attn_1_type": "sage_attn2",
"cross_attn_1_type": "sage_attn2",
"cross_attn_2_type": "sage_attn2",
"sample_guide_scale": [
3.5,
3.5
],
"sample_shift": 5.0,
"enable_cfg": false,
"cpu_offload": true,
"offload_granularity": "block",
"t5_cpu_offload": false,
"vae_cpu_offload": false,
"use_image_encoder": false,
"boundary_step_index": 2,
"denoising_step_list": [
1000,
750,
500,
250
],
"mm_config": {
"mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Q8F"
},
"t5_quantized": true,
"t5_quant_scheme": "fp8-q8f"
}
......@@ -38,7 +38,7 @@ class WanRunner(DefaultRunner):
super().__init__(config)
self.vae_cls = WanVAE
self.tiny_vae_cls = WanVAE_tiny
self.vae_name = "Wan2.1_VAE.pth"
self.vae_name = config.get("vae_name", "Wan2.1_VAE.pth")
self.tiny_vae_name = "taew2_1.pth"
def load_transformer(self):
......@@ -73,7 +73,7 @@ class WanRunner(DefaultRunner):
clip_quant_scheme = self.config.get("clip_quant_scheme", None)
assert clip_quant_scheme is not None
tmp_clip_quant_scheme = clip_quant_scheme.split("-")[0]
clip_model_name = f"clip-{tmp_clip_quant_scheme}.pth"
clip_model_name = f"models_clip_open-clip-xlm-roberta-large-vit-huge-14-{tmp_clip_quant_scheme}.pth"
clip_quantized_ckpt = find_torch_model_path(self.config, "clip_quantized_ckpt", clip_model_name)
clip_original_ckpt = None
else:
......@@ -154,6 +154,7 @@ class WanRunner(DefaultRunner):
"cpu_offload": vae_offload,
"dtype": GET_DTYPE(),
"load_from_rank0": self.config.get("load_from_rank0", False),
"use_lightvae": self.config.get("use_lightvae", False),
}
if self.config["task"] not in ["i2v", "flf2v", "animate", "vace", "s2v"]:
return None
......@@ -174,6 +175,7 @@ class WanRunner(DefaultRunner):
"parallel": self.config["parallel"],
"use_tiling": self.config.get("use_tiling_vae", False),
"cpu_offload": vae_offload,
"use_lightvae": self.config.get("use_lightvae", False),
"dtype": GET_DTYPE(),
"load_from_rank0": self.config.get("load_from_rank0", False),
}
......
......@@ -263,16 +263,7 @@ class AttentionBlock(nn.Module):
class Encoder3d(nn.Module):
def __init__(
self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, False],
dropout=0.0,
):
def __init__(self, dim=128, z_dim=4, dim_mult=[1, 2, 4, 4], num_res_blocks=2, attn_scales=[], temperal_downsample=[True, True, False], dropout=0.0, pruning_rate=0.0):
super().__init__()
self.dim = dim
self.z_dim = z_dim
......@@ -283,6 +274,7 @@ class Encoder3d(nn.Module):
# dimensions
dims = [dim * u for u in [1] + dim_mult]
dims = [int(d * (1 - pruning_rate)) for d in dims]
scale = 1.0
# init block
......@@ -375,16 +367,7 @@ class Encoder3d(nn.Module):
class Decoder3d(nn.Module):
def __init__(
self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_upsample=[False, True, True],
dropout=0.0,
):
def __init__(self, dim=128, z_dim=4, dim_mult=[1, 2, 4, 4], num_res_blocks=2, attn_scales=[], temperal_upsample=[False, True, True], dropout=0.0, pruning_rate=0.0):
super().__init__()
self.dim = dim
self.z_dim = z_dim
......@@ -395,6 +378,8 @@ class Decoder3d(nn.Module):
# dimensions
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
dims = [int(d * (1 - pruning_rate)) for d in dims]
scale = 1.0 / 2 ** (len(dim_mult) - 2)
# init block
......@@ -498,16 +483,7 @@ def count_conv3d(model):
class WanVAE_(nn.Module):
def __init__(
self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, False],
dropout=0.0,
):
def __init__(self, dim=128, z_dim=4, dim_mult=[1, 2, 4, 4], num_res_blocks=2, attn_scales=[], temperal_downsample=[True, True, False], dropout=0.0, pruning_rate=0.0):
super().__init__()
self.dim = dim
self.z_dim = z_dim
......@@ -534,6 +510,7 @@ class WanVAE_(nn.Module):
attn_scales,
self.temperal_downsample,
dropout,
pruning_rate,
)
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
......@@ -545,6 +522,7 @@ class WanVAE_(nn.Module):
attn_scales,
self.temperal_upsample,
dropout,
pruning_rate,
)
def forward(self, x):
......@@ -739,23 +717,6 @@ class WanVAE_(nn.Module):
self.clear_cache()
return out
def cached_decode(self, z, scale):
    """Decode a latent tensor one temporal slice at a time using the causal feature cache.

    Args:
        z: latent tensor laid out as [b, c, t, h, w].
        scale: two-element pair; the latent is divided by ``scale[1]`` and
            shifted by ``scale[0]`` before decoding (presumably a
            mean / inverse-std pair — confirm against the caller). Elements
            may be scalars or per-channel tensors.

    Returns:
        The decoded video tensor, with per-frame outputs concatenated
        along the temporal dimension (dim 2).
    """
    shift, div = scale[0], scale[1]
    # De-normalize; per-channel tensors are reshaped so they broadcast
    # over the [b, c, t, h, w] layout.
    if isinstance(shift, torch.Tensor):
        z = z / div.view(1, self.z_dim, 1, 1, 1) + shift.view(1, self.z_dim, 1, 1, 1)
    else:
        z = z / div + shift
    x = self.conv2(z)
    frames = []
    # Decode one frame per step so the decoder's feature cache
    # (self._feat_map) carries causal state across frames.
    for t in range(x.shape[2]):
        self._conv_idx = [0]  # reset the cache cursor for this frame
        frames.append(self.decoder(x[:, :, t : t + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx))
    # Single frame: return the decoder output as-is (matches original
    # control flow, which only concatenates from the second frame on).
    return frames[0] if len(frames) == 1 else torch.cat(frames, 2)
def reparameterize(self, mu, log_var):
std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)
......@@ -778,7 +739,7 @@ class WanVAE_(nn.Module):
self._enc_feat_map = [None] * self._enc_conv_num
def _video_vae(pretrained_path=None, z_dim=None, device="cpu", cpu_offload=False, dtype=torch.float, load_from_rank0=False, **kwargs):
def _video_vae(pretrained_path=None, z_dim=None, device="cpu", cpu_offload=False, dtype=torch.float, load_from_rank0=False, pruning_rate=0.0, **kwargs):
"""
Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
"""
......@@ -791,6 +752,7 @@ def _video_vae(pretrained_path=None, z_dim=None, device="cpu", cpu_offload=False
attn_scales=[],
temperal_downsample=[False, True, True],
dropout=0.0,
pruning_rate=pruning_rate,
)
cfg.update(**kwargs)
......@@ -820,6 +782,7 @@ class WanVAE:
cpu_offload=False,
use_2d_split=True,
load_from_rank0=False,
use_lightvae=False,
):
self.dtype = dtype
self.device = device
......@@ -827,6 +790,10 @@ class WanVAE:
self.use_tiling = use_tiling
self.cpu_offload = cpu_offload
self.use_2d_split = use_2d_split
if use_lightvae:
pruning_rate = 0.75 # 0.75
else:
pruning_rate = 0.0
mean = [
-0.7571,
......@@ -906,7 +873,13 @@ class WanVAE:
}
# init model
self.model = _video_vae(pretrained_path=vae_pth, z_dim=z_dim, cpu_offload=cpu_offload, dtype=dtype, load_from_rank0=load_from_rank0).eval().requires_grad_(False).to(device).to(dtype)
self.model = (
_video_vae(pretrained_path=vae_pth, z_dim=z_dim, cpu_offload=cpu_offload, dtype=dtype, load_from_rank0=load_from_rank0, pruning_rate=pruning_rate)
.eval()
.requires_grad_(False)
.to(device)
.to(dtype)
)
def _calculate_2d_grid(self, latent_height, latent_width, world_size):
if (latent_height, latent_width, world_size) in self.grid_table:
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment