Commit 1994ffb1 authored by gushiqiao, committed by GitHub

Support offload cache for wan2.2_vae
parents 64948a2e 83a73049
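The commit threads a new `vae_offload_cache` flag from the config files (reproduced below) through the runner into the Wan2.2 VAE: when CPU offload is active, the decoder's per-layer causal-convolution feature cache (`feat_cache`) is parked on the CPU after each layer touches it, so cache entries stop pinning GPU memory between temporal chunks. A minimal sketch of the pattern — the `layers` list and the `(input, cache) -> (output, cache)` layer interface are hypothetical stand-ins; the real change lives in `Decoder3d.forward` further down:

import torch

def decode_chunk(layers, x, feat_cache, offload_cache=False):
    # feat_cache keeps one tensor per causal conv; parked entries are
    # brought back to the compute device right before they are read.
    for idx, layer in enumerate(layers):
        cache = feat_cache[idx]
        if isinstance(cache, torch.Tensor) and cache.device != x.device:
            cache = cache.to(x.device)  # restore this layer's cache
        x, new_cache = layer(x, cache)  # hypothetical layer interface
        # park the refreshed entry on CPU so it frees GPU memory
        feat_cache[idx] = new_cache.cpu() if offload_cache else new_cache
    return x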
{
  "infer_steps": 50,
  "target_video_length": 121,
  "text_len": 512,
  "target_height": 704,
  "target_width": 1280,
  "num_channels_latents": 48,
  "vae_stride": [4, 16, 16],
  "self_attn_1_type": "flash_attn3",
  "cross_attn_1_type": "flash_attn3",
  "cross_attn_2_type": "flash_attn3",
  "seed": 42,
  "sample_guide_scale": 5.0,
  "sample_shift": 5.0,
  "enable_cfg": true,
  "fps": 24,
  "use_image_encoder": false,
  "cpu_offload": true,
  "offload_granularity": "model",
  "vae_offload_cache": true
}

{
  "infer_steps": 50,
  "target_video_length": 121,
  "text_len": 512,
  "target_height": 704,
  "target_width": 1280,
  "num_channels_latents": 48,
  "vae_stride": [4, 16, 16],
  "self_attn_1_type": "flash_attn3",
  "cross_attn_1_type": "flash_attn3",
  "cross_attn_2_type": "flash_attn3",
  "seed": 42,
  "sample_guide_scale": 5.0,
  "sample_shift": 5.0,
  "enable_cfg": true,
  "fps": 24,
  "cpu_offload": true,
  "offload_granularity": "model",
  "vae_offload_cache": true
}
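Both configs turn on global `cpu_offload` at model granularity plus the new `vae_offload_cache` flag (the first additionally disables the image encoder). As the runner and VAE changes below show, cache offload only takes effect when VAE offload itself resolves to true; condensed, the net flag resolution is:

# condensed from the runner/VAE changes below, not a function in the repo
vae_offload = config.get("vae_cpu_offload", config.get("cpu_offload"))  # per-module flag falls back to the global one
offload_cache = config.get("vae_offload_cache", False) if vae_offload else False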
@@ -68,13 +68,13 @@ class WanRunner(DefaultRunner):
             assert clip_quant_scheme is not None
             tmp_clip_quant_scheme = clip_quant_scheme.split("-")[0]
             clip_model_name = f"clip-{tmp_clip_quant_scheme}.pth"
-            clip_quantized_ckpt = find_torch_model_path(self.config, "clip_quantized_ckpt", clip_model_name, tmp_clip_quant_scheme)
+            clip_quantized_ckpt = find_torch_model_path(self.config, "clip_quantized_ckpt", clip_model_name)
             clip_original_ckpt = None
         else:
             clip_quantized_ckpt = None
             clip_quant_scheme = None
             clip_model_name = "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"
-            clip_original_ckpt = find_torch_model_path(self.config, "clip_original_ckpt", clip_model_name, "original")
+            clip_original_ckpt = find_torch_model_path(self.config, "clip_original_ckpt", clip_model_name)
         image_encoder = CLIPModel(
             dtype=torch.float16,
@@ -90,7 +90,7 @@ class WanRunner(DefaultRunner):
     def load_text_encoder(self):
         # offload config
-        t5_offload = self.config.get("t5_cpu_offload", False)
+        t5_offload = self.config.get("t5_cpu_offload", self.config.get("cpu_offload"))
         if t5_offload:
             t5_device = torch.device("cpu")
         else:
@@ -103,14 +103,14 @@ class WanRunner(DefaultRunner):
             assert t5_quant_scheme is not None
             tmp_t5_quant_scheme = t5_quant_scheme.split("-")[0]
             t5_model_name = f"models_t5_umt5-xxl-enc-{tmp_t5_quant_scheme}.pth"
-            t5_quantized_ckpt = find_torch_model_path(self.config, "t5_quantized_ckpt", t5_model_name, tmp_t5_quant_scheme)
+            t5_quantized_ckpt = find_torch_model_path(self.config, "t5_quantized_ckpt", t5_model_name)
             t5_original_ckpt = None
             tokenizer_path = os.path.join(os.path.dirname(t5_quantized_ckpt), "google/umt5-xxl")
         else:
             t5_quant_scheme = None
             t5_quantized_ckpt = None
             t5_model_name = "models_t5_umt5-xxl-enc-bf16.pth"
-            t5_original_ckpt = find_torch_model_path(self.config, "t5_original_ckpt", t5_model_name, "original")
+            t5_original_ckpt = find_torch_model_path(self.config, "t5_original_ckpt", t5_model_name)
             tokenizer_path = os.path.join(os.path.dirname(t5_original_ckpt), "google/umt5-xxl")
         text_encoder = T5EncoderModel(
@@ -121,7 +121,7 @@ class WanRunner(DefaultRunner):
             tokenizer_path=tokenizer_path,
             shard_fn=None,
             cpu_offload=t5_offload,
-            offload_granularity=self.config.get("t5_offload_granularity", "model"),
+            offload_granularity=self.config.get("t5_offload_granularity", "model"),  # support ['model', 'block']
             t5_quantized=t5_quantized,
             t5_quantized_ckpt=t5_quantized_ckpt,
             quant_scheme=t5_quant_scheme,
@@ -131,12 +131,20 @@ class WanRunner(DefaultRunner):
         return text_encoders

     def load_vae_encoder(self):
+        # offload config
+        vae_offload = self.config.get("vae_cpu_offload", self.config.get("cpu_offload"))
+        if vae_offload:
+            vae_device = torch.device("cpu")
+        else:
+            vae_device = torch.device("cuda")
         vae_config = {
             "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.1_VAE.pth"),
-            "device": self.init_device,
+            "device": vae_device,
             "parallel": self.config.parallel and self.config.parallel.get("vae_p_size", False) and self.config.parallel.vae_p_size > 1,
             "use_tiling": self.config.get("use_tiling_vae", False),
             "seq_p_group": self.seq_p_group,
+            "cpu_offload": vae_offload,
         }
         if self.config.task != "i2v":
             return None
@@ -144,11 +152,19 @@ class WanRunner(DefaultRunner):
         return WanVAE(**vae_config)

     def load_vae_decoder(self):
+        # offload config
+        vae_offload = self.config.get("vae_cpu_offload", self.config.get("cpu_offload"))
+        if vae_offload:
+            vae_device = torch.device("cpu")
+        else:
+            vae_device = torch.device("cuda")
         vae_config = {
             "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.1_VAE.pth"),
-            "device": self.init_device,
+            "device": vae_device,
             "parallel": self.config.parallel and self.config.parallel.get("vae_p_size", False) and self.config.parallel.vae_p_size > 1,
             "use_tiling": self.config.get("use_tiling_vae", False),
+            "cpu_offload": vae_offload,
         }
         if self.config.get("use_tiny_vae", False):
             tiny_vae_path = find_torch_model_path(self.config, "tiny_vae_path", "taew2_1.pth")
@@ -398,17 +414,33 @@ class Wan22DenseRunner(WanRunner):
         super().__init__(config)

     def load_vae_decoder(self):
+        # offload config
+        vae_offload = self.config.get("vae_cpu_offload", self.config.get("cpu_offload"))
+        if vae_offload:
+            vae_device = torch.device("cpu")
+        else:
+            vae_device = torch.device("cuda")
         vae_config = {
             "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.2_VAE.pth"),
-            "device": self.init_device,
+            "device": vae_device,
+            "cpu_offload": vae_offload,
+            "offload_cache": self.config.get("vae_offload_cache", False),
         }
         vae_decoder = Wan2_2_VAE(**vae_config)
         return vae_decoder

     def load_vae_encoder(self):
+        # offload config
+        vae_offload = self.config.get("vae_cpu_offload", self.config.get("cpu_offload"))
+        if vae_offload:
+            vae_device = torch.device("cpu")
+        else:
+            vae_device = torch.device("cuda")
         vae_config = {
             "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.2_VAE.pth"),
-            "device": self.init_device,
+            "device": vae_device,
+            "cpu_offload": vae_offload,
+            "offload_cache": self.config.get("vae_offload_cache", False),
         }
         if self.config.task != "i2v":
             return None
...
@@ -799,11 +799,13 @@ class WanVAE:
         parallel=False,
         use_tiling=False,
         seq_p_group=None,
+        cpu_offload=False,
     ):
         self.dtype = dtype
         self.device = device
         self.parallel = parallel
         self.use_tiling = use_tiling
+        self.cpu_offload = cpu_offload

         mean = [
             -0.7571,
@@ -940,8 +942,8 @@ class WanVAE:
         return images

-    def decode(self, zs, generator, config):
-        if config.cpu_offload:
+    def decode(self, zs, **args):
+        if self.cpu_offload:
             self.to_cuda()

         if self.parallel:
@@ -962,7 +964,7 @@ class WanVAE:
         else:
             images = self.model.decode(zs.unsqueeze(0), self.scale).float().clamp_(-1, 1)
-        if config.cpu_offload:
+        if self.cpu_offload:
             images = images.cpu().float()
             self.to_cpu()
...
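With the signature changed to `decode(self, zs, **args)`, callers no longer thread a `generator`/`config` pair through; the wrapper owns its offload policy from construction. A hedged usage sketch (`latents` is a placeholder tensor; constructor argument names are taken from `vae_config` above):

import torch

vae = WanVAE(vae_pth="Wan2.1_VAE.pth", device=torch.device("cpu"), cpu_offload=True)
images = vae.decode(latents)  # moves itself to CUDA, decodes, returns CPU float tensors, moves back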
@@ -619,7 +619,7 @@ class Decoder3d(nn.Module):
             CausalConv3d(out_dim, 12, 3, padding=1),
         )

-    def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
+    def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False, offload_cache=False):
         if feat_cache is not None:
             idx = feat_idx[0]
             cache_x = x[:, :, -CACHE_T:, :, :].clone()
@@ -639,14 +639,24 @@ class Decoder3d(nn.Module):
         for layer in self.middle:
             if isinstance(layer, ResidualBlock) and feat_cache is not None:
+                idx = feat_idx[0]
                 x = layer(x, feat_cache, feat_idx)
+                if offload_cache:
+                    for _idx in range(idx, feat_idx[0]):
+                        if isinstance(feat_cache[_idx], torch.Tensor):
+                            feat_cache[_idx] = feat_cache[_idx].cpu()
             else:
                 x = layer(x)

         ## upsamples
         for layer in self.upsamples:
             if feat_cache is not None:
+                idx = feat_idx[0]
                 x = layer(x, feat_cache, feat_idx, first_chunk)
+                if offload_cache:
+                    for _idx in range(idx, feat_idx[0]):
+                        if isinstance(feat_cache[_idx], torch.Tensor):
+                            feat_cache[_idx] = feat_cache[_idx].cpu()
             else:
                 x = layer(x)
@@ -664,7 +674,7 @@ class Decoder3d(nn.Module):
                         dim=2,
                     )
                 x = layer(x, feat_cache[idx])
-                feat_cache[idx] = cache_x
+                feat_cache[idx] = cache_x.cpu() if offload_cache else cache_x
                 feat_idx[0] += 1
             else:
                 x = layer(x)
@@ -755,7 +765,7 @@ class WanVAE_(nn.Module):
         self.clear_cache()
         return mu

-    def decode(self, z, scale):
+    def decode(self, z, scale, offload_cache=False):
         self.clear_cache()
         if isinstance(scale[0], torch.Tensor):
             z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1)
@@ -766,18 +776,9 @@ class WanVAE_(nn.Module):
             for i in range(iter_):
                 self._conv_idx = [0]
                 if i == 0:
-                    out = self.decoder(
-                        x[:, :, i : i + 1, :, :],
-                        feat_cache=self._feat_map,
-                        feat_idx=self._conv_idx,
-                        first_chunk=True,
-                    )
+                    out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=True, offload_cache=offload_cache)
                 else:
-                    out_ = self.decoder(
-                        x[:, :, i : i + 1, :, :],
-                        feat_cache=self._feat_map,
-                        feat_idx=self._conv_idx,
-                    )
+                    out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx, offload_cache=offload_cache)
                     out = torch.cat([out, out_], 2)
             out = unpatchify(out, patch_size=2)
             self.clear_cache()
@@ -830,18 +831,11 @@ def _video_vae(pretrained_path=None, z_dim=16, dim=160, device="cpu", **kwargs):

 class Wan2_2_VAE:
-    def __init__(
-        self,
-        z_dim=48,
-        c_dim=160,
-        vae_pth=None,
-        dim_mult=[1, 2, 4, 4],
-        temperal_downsample=[False, True, True],
-        dtype=torch.float,
-        device="cuda",
-    ):
+    def __init__(self, z_dim=48, c_dim=160, vae_pth=None, dim_mult=[1, 2, 4, 4], temperal_downsample=[False, True, True], dtype=torch.float, device="cuda", cpu_offload=False, offload_cache=False):
         self.dtype = dtype
         self.device = device
+        self.cpu_offload = cpu_offload
+        self.offload_cache = offload_cache

         self.mean = torch.tensor(
             [
@@ -991,11 +985,11 @@ class Wan2_2_VAE:
             self.to_cpu()
         return out

-    def decode(self, zs, generator, config):
-        if config.cpu_offload:
+    def decode(self, zs, **args):
+        if self.cpu_offload:
             self.to_cuda()
-        images = self.model.decode(zs.unsqueeze(0), self.scale).float().clamp_(-1, 1)
+        images = self.model.decode(zs.unsqueeze(0), self.scale, offload_cache=self.offload_cache if self.cpu_offload else False).float().clamp_(-1, 1)
-        if config.cpu_offload:
+        if self.cpu_offload:
             images = images.cpu().float()
             self.to_cpu()
         return images
@@ -258,7 +258,7 @@ def save_to_video(
         raise ValueError(f"Unknown save method: {method}")

-def find_torch_model_path(config, ckpt_config_key=None, filename=None, subdir=["original", "fp8", "int8"]):
+def find_torch_model_path(config, ckpt_config_key=None, filename=None, subdir=["original", "fp8", "int8", "distill_models", "distill_fp8", "distill_int8"]):
     if ckpt_config_key and config.get(ckpt_config_key, None) is not None:
         return config.get(ckpt_config_key)
@@ -277,7 +277,7 @@ def find_torch_model_path(config, ckpt_config_key=None, filename=None, subdir=["
     raise FileNotFoundError(f"PyTorch model file '{filename}' not found.\nPlease download the model from https://huggingface.co/lightx2v/ or specify the model path in the configuration file.")

-def find_hf_model_path(config, model_path, ckpt_config_key=None, subdir=["original", "fp8", "int8"]):
+def find_hf_model_path(config, model_path, ckpt_config_key=None, subdir=["original", "fp8", "int8", "distill_models", "distill_fp8", "distill_int8"]):
     if ckpt_config_key and config.get(ckpt_config_key, None) is not None:
         return config.get(ckpt_config_key)
...
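Widening the default `subdir` list lets quantized and distill checkpoints resolve without extra config keys. A condensed sketch of the lookup order — only the first and last steps appear verbatim in the hunk above; the directory scan in the middle, including the `config.model_path` root, is an assumption about the elided body:

import os

def find_torch_model_path(config, ckpt_config_key=None, filename=None,
                          subdir=["original", "fp8", "int8", "distill_models", "distill_fp8", "distill_int8"]):
    # 1) an explicit path in the config always wins
    if ckpt_config_key and config.get(ckpt_config_key, None) is not None:
        return config.get(ckpt_config_key)
    # 2) otherwise scan candidate subdirectories under the model root (assumed layout)
    for d in subdir:
        candidate = os.path.join(config.model_path, d, filename)
        if os.path.exists(candidate):
            return candidate
    raise FileNotFoundError(f"PyTorch model file '{filename}' not found.")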
@@ -5,7 +5,7 @@ lightx2v_path=
 model_path=

 export CUDA_VISIBLE_DEVICES=0
+export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

 # set environment variables
 source ${lightx2v_path}/scripts/base/base.sh
...
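`PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True` is a stock PyTorch CUDA caching-allocator option that lets allocation segments grow in place instead of reserving fixed-size blocks, which curbs fragmentation when weights and caches repeatedly migrate between CPU and GPU — exactly the traffic pattern offloading creates. It can equally be set from Python, as long as that happens before CUDA is initialized:

import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"  # must precede first CUDA allocation
import torch  # the caching allocator picks the option up from here on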