"...git@developer.sourcefind.cn:modelzoo/solov2-pytorch.git" did not exist on "97d08556373c0bc98badf4f2b17a114d97a70124"
Commit 84e756e9 authored by gushiqiao's avatar gushiqiao Committed by GitHub
Browse files

[Feat] Support vae bf16 and add 4090 best config (#256)



* [Feat] Support vae bf16 and add 4090 best config

* update dtype

---------
Co-authored-by: helloyongyang <yongyang1030@163.com>
parent c93c756c
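The change threads one configurable dtype through the whole VAE stack instead of hard-coding float32: the runners pass GET_DTYPE() into the VAE constructors, checkpoint weights are cast once at load time, and the per-call .float() upcasts are removed. Below is a minimal sketch of what such a helper could look like; the real GET_DTYPE() lives in lightx2v.utils.envs, and the DTYPE environment variable and mapping here are assumptions for illustration only:

import os

import torch

# Hypothetical dtype resolver: one process-wide dtype, defaulting to bf16.
_DTYPE_MAP = {
    "BF16": torch.bfloat16,
    "FP16": torch.float16,
    "FP32": torch.float32,
}


def GET_DTYPE():
    return _DTYPE_MAP.get(os.environ.get("DTYPE", "BF16").upper(), torch.bfloat16)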
{
    "infer_steps": 4,
    "target_fps": 16,
    "video_duration": 120,
    "audio_sr": 16000,
    "target_video_length": 81,
    "target_height": 720,
    "target_width": 1280,
    "self_attn_1_type": "sage_attn2",
    "cross_attn_1_type": "sage_attn2",
    "cross_attn_2_type": "sage_attn2",
    "seed": 42,
    "sample_guide_scale": 1,
    "sample_shift": 5,
    "enable_cfg": false,
    "adaptive_resize": true,
    "use_31_block": false,
    "cpu_offload": true,
    "offload_granularity": "block",
    "offload_ratio_val": 1,
    "t5_cpu_offload": true,
    "t5_offload_granularity": "model",
    "t5_quantized": true,
    "t5_quant_scheme": "fp8-q8f",
    "audio_encoder_cpu_offload": false,
    "audio_adapter_cpu_offload": false,
    "adapter_quantized": true,
    "adapter_quant_scheme": "fp8",
    "vae_cpu_offload": false,
    "use_tiling_vae": true,
    "mm_config": {
        "mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Q8F"
    }
}
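The JSON above is the new best-known config for a single RTX 4090: 4 inference steps, CFG disabled, sage_attn2 attention, fp8/Q8F quantization for the T5 encoder, audio adapter, and matmuls, block-granularity CPU offload for the DiT, and a tiled VAE kept on GPU. A hedged sketch of consuming it; the file name and key access are illustrative, not the lightx2v loader API:

import json

# Assumed file name for the config shown above.
with open("wan_audio_4090.json") as f:
    cfg = json.load(f)

# The quant-scheme fields decide which checkpoint variant and kernels get
# picked up, e.g. "fp8-q8f" routes linears through the Q8F fp8 path below.
print(cfg["t5_quant_scheme"], cfg["mm_config"]["mm_type"])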
@@ -247,6 +247,20 @@ class Q8FQuantLinearInt8(nn.Module):
         )
         return output_tensor
 
+    def _apply(self, fn):
+        for module in self.children():
+            module._apply(fn)
+
+        def maybe_cast(t):
+            if t is not None and t.device != fn(t).device:
+                return fn(t)
+            return t
+
+        self.weight = maybe_cast(self.weight)
+        self.weight_scale = maybe_cast(self.weight_scale)
+        self.bias = maybe_cast(self.bias)
+        return self
+
 
 class Q8FQuantLinearFp8(nn.Module):
     def __init__(self, in_features, out_features, bias=True, dtype=torch.float32):
@@ -277,3 +291,17 @@ class Q8FQuantLinearFp8(nn.Module):
             out_dtype=torch.bfloat16,
         )
         return output_tensor
+
+    def _apply(self, fn):
+        for module in self.children():
+            module._apply(fn)
+
+        def maybe_cast(t):
+            if t is not None and t.device != fn(t).device:
+                return fn(t)
+            return t
+
+        self.weight = maybe_cast(self.weight)
+        self.weight_scale = maybe_cast(self.weight_scale)
+        self.bias = maybe_cast(self.bias)
+        return self
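Both quantized linear classes need this override because weight, weight_scale, and bias are stored as plain tensors rather than nn.Parameter, and the stock nn.Module._apply only converts registered parameters and buffers, so .to()/.cuda() would otherwise leave them behind. A toy reproduction of the pattern, assuming nothing about the Q8F kernels themselves:

import torch
import torch.nn as nn


class RawTensorModule(nn.Module):
    def __init__(self):
        super().__init__()
        # Plain attribute, invisible to the default _apply machinery.
        self.weight = torch.zeros(4, dtype=torch.int8)

    def _apply(self, fn):
        super()._apply(fn)  # still handles children, parameters, buffers
        # Same device-only check as the Q8F code: re-assign only on moves.
        if self.weight is not None and self.weight.device != fn(self.weight).device:
            self.weight = fn(self.weight)
        return self


m = RawTensorModule()
if torch.cuda.is_available():
    m.cuda()
    assert m.weight.is_cuda  # fails without the _apply override

Note the guard compares devices only, so dtype-changing calls such as .half() deliberately leave the quantized tensors untouched; dtype is instead handled once when the checkpoint is loaded.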
@@ -231,7 +231,7 @@ class DefaultRunner(BaseRunner):
     def run_vae_decoder(self, latents):
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             self.vae_decoder = self.load_vae_decoder()
-        images = self.vae_decoder.decode(latents)
+        images = self.vae_decoder.decode(latents.to(GET_DTYPE()))
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             del self.vae_decoder
             torch.cuda.empty_cache()
@@ -301,7 +301,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
             self.vae_encoder = self.load_vae_encoder()
         img = rearrange(img, "1 C H W -> 1 C 1 H W")
-        vae_encoder_out = self.vae_encoder.encode(img.to(torch.float)).to(GET_DTYPE())
+        vae_encoder_out = self.vae_encoder.encode(img.to(GET_DTYPE()))
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             del self.vae_encoder
@@ -333,7 +333,6 @@ class WanAudioRunner(WanRunner):  # type:ignore
         """Prepare previous latents for conditioning"""
         device = torch.device("cuda")
         dtype = GET_DTYPE()
-        vae_dtype = torch.float
         tgt_h, tgt_w = self.config.tgt_h, self.config.tgt_w
         prev_frames = torch.zeros((1, 3, self.config.target_video_length, tgt_h, tgt_w), device=device)
@@ -354,12 +353,12 @@ class WanAudioRunner(WanRunner):  # type:ignore
         _, nframe, height, width = self.model.scheduler.latents.shape
         if self.config.model_cls == "wan2.2_audio":
             if prev_video is not None:
-                prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype)).to(dtype)
+                prev_latents = self.vae_encoder.encode(prev_frames.to(dtype))
             else:
                 prev_latents = None
             prev_mask = self.model.scheduler.mask
         else:
-            prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype)).to(dtype)
+            prev_latents = self.vae_encoder.encode(prev_frames.to(dtype))
             frames_n = (nframe - 1) * 4 + 1
             prev_mask = torch.ones((1, frames_n, height, width), device=device, dtype=dtype)
@@ -561,7 +560,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
             )
             audio_adapter.to(device)
         if self.config.get("adapter_quantized", False):
-            if self.config.get("adapter_quant_scheme", None) == "fp8":
+            if self.config.get("adapter_quant_scheme", None) in ["fp8", "fp8-q8f"]:
                 model_name = "audio_adapter_fp8.safetensors"
             elif self.config.get("adapter_quant_scheme", None) == "int8":
                 model_name = "audio_adapter_int8.safetensors"
@@ -138,6 +138,7 @@ class WanRunner(DefaultRunner):
             "parallel": self.config.parallel,
             "use_tiling": self.config.get("use_tiling_vae", False),
             "cpu_offload": vae_offload,
+            "dtype": GET_DTYPE(),
         }
         if self.config.task not in ["i2v", "flf2v", "vace"]:
             return None
@@ -158,6 +159,7 @@ class WanRunner(DefaultRunner):
             "parallel": self.config.parallel,
             "use_tiling": self.config.get("use_tiling_vae", False),
             "cpu_offload": vae_offload,
+            "dtype": GET_DTYPE(),
         }
         if self.config.get("use_tiny_vae", False):
             tiny_vae_path = find_torch_model_path(self.config, "tiny_vae_path", "taew2_1.pth")
@@ -309,7 +311,7 @@ class WanRunner(DefaultRunner):
                 dim=1,
             ).cuda()
-        vae_encoder_out = self.vae_encoder.encode(vae_input.unsqueeze(0))
+        vae_encoder_out = self.vae_encoder.encode(vae_input.unsqueeze(0).to(GET_DTYPE()))
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             del self.vae_encoder
@@ -444,6 +446,7 @@ class Wan22DenseRunner(WanRunner):
             "device": vae_device,
             "cpu_offload": vae_offload,
             "offload_cache": self.config.get("vae_offload_cache", False),
+            "dtype": GET_DTYPE(),
         }
         vae_decoder = Wan2_2_VAE(**vae_config)
         return vae_decoder
@@ -460,6 +463,7 @@ class Wan22DenseRunner(WanRunner):
             "device": vae_device,
             "cpu_offload": vae_offload,
             "offload_cache": self.config.get("vae_offload_cache", False),
+            "dtype": GET_DTYPE(),
         }
         if self.config.task not in ["i2v", "flf2v"]:
             return None
@@ -494,5 +498,5 @@ class Wan22DenseRunner(WanRunner):
         return vae_encoder_out
 
     def get_vae_encoder_output(self, img):
-        z = self.vae_encoder.encode(img)
+        z = self.vae_encoder.encode(img.to(GET_DTYPE()))
         return z
@@ -8,6 +8,7 @@ from PIL import Image
 from lightx2v.models.input_encoders.hf.vace.vace_processor import VaceVideoProcessor
 from lightx2v.models.networks.wan.vace_model import WanVaceModel
 from lightx2v.models.runners.wan.wan_runner import WanRunner
+from lightx2v.utils.envs import *
 from lightx2v.utils.profiler import ProfilingContext
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
@@ -96,22 +97,22 @@ class WanVaceRunner(WanRunner):
         assert len(frames) == len(ref_images)
 
         if masks is None:
-            latents = [self.vae_encoder.encode(frame.unsqueeze(0)) for frame in frames]
+            latents = [self.vae_encoder.encode(frame.unsqueeze(0).to(GET_DTYPE())) for frame in frames]
         else:
             masks = [torch.where(m > 0.5, 1.0, 0.0) for m in masks]
             inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)]
             reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)]
-            inactive = [self.vae_encoder.encode(inact.unsqueeze(0)) for inact in inactive]
-            reactive = [self.vae_encoder.encode(react.unsqueeze(0)) for react in reactive]
+            inactive = [self.vae_encoder.encode(inact.unsqueeze(0).to(GET_DTYPE())) for inact in inactive]
+            reactive = [self.vae_encoder.encode(react.unsqueeze(0).to(GET_DTYPE())) for react in reactive]
             latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)]
 
         cat_latents = []
         for latent, refs in zip(latents, ref_images):
             if refs is not None:
                 if masks is None:
-                    ref_latent = [self.vae_encoder.encode(ref.unsqueeze(0)) for ref in refs]
+                    ref_latent = [self.vae_encoder.encode(ref.unsqueeze(0).to(GET_DTYPE())) for ref in refs]
                 else:
-                    ref_latent = [self.vae_encoder.encode(ref.unsqueeze(0)) for ref in refs]
+                    ref_latent = [self.vae_encoder.encode(ref.unsqueeze(0).to(GET_DTYPE())) for ref in refs]
                     ref_latent = [torch.cat((u, torch.zeros_like(u)), dim=0) for u in ref_latent]
                 assert all([x.shape[1] == 1 for x in ref_latent])
                 latent = torch.cat([*ref_latent, latent], dim=1)
@@ -169,7 +170,7 @@ class WanVaceRunner(WanRunner):
         if refs is not None:
             latents = latents[:, len(refs) :, :, :]
-        images = self.vae_decoder.decode(latents)
+        images = self.vae_decoder.decode(latents.to(GET_DTYPE()))
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             del self.vae_decoder
@@ -64,7 +64,7 @@ class Upsample(nn.Upsample):
         """
         Fix bfloat16 support for nearest neighbor interpolation.
         """
-        return super().forward(x.float()).type_as(x)
+        return super().forward(x)
 
 
 class Resample(nn.Module):
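The x.float() round-trip in Upsample.forward was presumably a workaround for PyTorch builds where nearest-neighbor interpolation had no bfloat16 kernel; on the versions this commit targets the direct path works, so the upcast only cost memory bandwidth. A quick standalone check of that assumption:

import torch
import torch.nn.functional as F

# Nearest-neighbor upsampling directly in bf16, no fp32 round-trip.
x = torch.randn(1, 3, 8, 8, dtype=torch.bfloat16)
y = F.interpolate(x, scale_factor=2.0, mode="nearest")
assert y.dtype == torch.bfloat16 and y.shape == (1, 3, 16, 16)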
@@ -761,7 +761,7 @@ class WanVAE_(nn.Module):
         self._enc_feat_map = [None] * self._enc_conv_num
 
 
-def _video_vae(pretrained_path=None, z_dim=None, device="cpu", cpu_offload=False, **kwargs):
+def _video_vae(pretrained_path=None, z_dim=None, device="cpu", cpu_offload=False, dtype=torch.float, **kwargs):
     """
     Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
     """
@@ -783,6 +783,9 @@ def _video_vae(pretrained_path=None, z_dim=None, device="cpu", cpu_offload=False
     # load checkpoint
     weights_dict = load_weights(pretrained_path, cpu_offload=cpu_offload)
+    for k in weights_dict.keys():
+        if weights_dict[k].dtype != dtype:
+            weights_dict[k] = weights_dict[k].to(dtype)
     model.load_state_dict(weights_dict, assign=True)
 
     return model
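Casting the state dict before load_state_dict matters here because assign=True installs the checkpoint tensors as the module's parameters instead of copying them into existing storage, so they must already carry the target dtype. A toy illustration of that semantics, independent of the VAE:

import torch
import torch.nn as nn

model = nn.Linear(4, 4)  # parameters start out fp32
state = {k: v.to(torch.bfloat16) for k, v in model.state_dict().items()}
model.load_state_dict(state, assign=True)
# With assign=True the bf16 tensors are adopted as-is; a plain copy-based
# load would have cast them back to the existing fp32 parameter dtype.
assert model.weight.dtype == torch.bfloat16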
@@ -846,7 +849,7 @@ class WanVAE:
         self.scale = [self.mean, self.inv_std]
 
         # init model
-        self.model = _video_vae(pretrained_path=vae_pth, z_dim=z_dim, cpu_offload=cpu_offload).eval().requires_grad_(False).to(device)
+        self.model = _video_vae(pretrained_path=vae_pth, z_dim=z_dim, cpu_offload=cpu_offload, dtype=dtype).eval().requires_grad_(False).to(device).to(dtype)
 
     def current_device(self):
         return next(self.model.parameters()).device
@@ -902,9 +905,9 @@ class WanVAE:
             video_chunk = video[:, :, :, :, start_idx:end_idx].contiguous()
 
             if self.use_tiling:
-                encoded_chunk = self.model.tiled_encode(video_chunk, self.scale).float()
+                encoded_chunk = self.model.tiled_encode(video_chunk, self.scale)
             else:
-                encoded_chunk = self.model.encode(video_chunk, self.scale).float()
+                encoded_chunk = self.model.encode(video_chunk, self.scale)
 
             if cur_rank == 0:
                 if split_dim == 3:
@@ -951,14 +954,14 @@ class WanVAE:
             else:
                 logger.info("Fall back to naive encode mode")
                 if self.use_tiling:
-                    out = self.model.tiled_encode(video, self.scale).float().squeeze(0)
+                    out = self.model.tiled_encode(video, self.scale).squeeze(0)
                 else:
-                    out = self.model.encode(video, self.scale).float().squeeze(0)
+                    out = self.model.encode(video, self.scale).squeeze(0)
         else:
             if self.use_tiling:
-                out = self.model.tiled_encode(video, self.scale).float().squeeze(0)
+                out = self.model.tiled_encode(video, self.scale).squeeze(0)
             else:
-                out = self.model.encode(video, self.scale).float().squeeze(0)
+                out = self.model.encode(video, self.scale).squeeze(0)
 
         if self.cpu_offload:
             self.to_cpu()
@@ -986,7 +989,7 @@ class WanVAE:
             zs = zs[:, :, :, cur_rank * splited_chunk_len - padding_size : (cur_rank + 1) * splited_chunk_len + padding_size].contiguous()
 
             decode_func = self.model.tiled_decode if self.use_tiling else self.model.decode
-            images = decode_func(zs.unsqueeze(0), self.scale).float().clamp_(-1, 1)
+            images = decode_func(zs.unsqueeze(0), self.scale).clamp_(-1, 1)
 
             if cur_rank == 0:
                 if split_dim == 2:
images = self.decode_dist(zs, world_size, cur_rank, split_dim=2)
else:
logger.info("Fall back to naive decode mode")
images = self.model.decode(zs.unsqueeze(0), self.scale).float().clamp_(-1, 1)
images = self.model.decode(zs.unsqueeze(0), self.scale).clamp_(-1, 1)
else:
decode_func = self.model.tiled_decode if self.use_tiling else self.model.decode
images = decode_func(zs.unsqueeze(0), self.scale).float().clamp_(-1, 1)
images = decode_func(zs.unsqueeze(0), self.scale).clamp_(-1, 1)
if self.cpu_offload:
images = images.cpu().float()
images = images.cpu()
self.to_cpu()
return images
@@ -812,7 +812,7 @@ class WanVAE_(nn.Module):
         self._enc_feat_map = [None] * self._enc_conv_num
 
 
-def _video_vae(pretrained_path=None, z_dim=16, dim=160, device="cpu", cpu_offload=False, **kwargs):
+def _video_vae(pretrained_path=None, z_dim=16, dim=160, device="cpu", cpu_offload=False, dtype=torch.float32, **kwargs):
     # params
     cfg = dict(
         dim=dim,
@@ -832,6 +832,9 @@ def _video_vae(pretrained_path=None, z_dim=16, dim=160, device="cpu", cpu_offloa
     # load checkpoint
     logging.info(f"loading {pretrained_path}")
     weights_dict = load_weights(pretrained_path, cpu_offload=cpu_offload)
+    for k in weights_dict.keys():
+        if weights_dict[k].dtype != dtype:
+            weights_dict[k] = weights_dict[k].to(dtype)
     model.load_state_dict(weights_dict, assign=True)
 
     return model
@@ -956,17 +959,11 @@ class Wan2_2_VAE:
         self.scale = [self.mean, self.inv_std]
 
         # init model
         self.model = (
-            _video_vae(
-                pretrained_path=vae_pth,
-                z_dim=z_dim,
-                dim=c_dim,
-                dim_mult=dim_mult,
-                temperal_downsample=temperal_downsample,
-                cpu_offload=cpu_offload,
-            )
+            _video_vae(pretrained_path=vae_pth, z_dim=z_dim, dim=c_dim, dim_mult=dim_mult, temperal_downsample=temperal_downsample, cpu_offload=cpu_offload, dtype=dtype)
             .eval()
             .requires_grad_(False)
             .to(device)
+            .to(dtype)
         )
 
     def to_cpu(self):