Commit 319f1f41 authored by gushiqiao's avatar gushiqiao Committed by GitHub

[Fix] Fix A14B load quant model bug and tae bug (#277)

* [Fix] Fix A14B load quant model bug and tae bug

* [Fix] Fix A14B load quant model bug and tae bug
parent 375a6f77
@@ -53,7 +53,7 @@ class HunyuanModel:
         return weight_dict

     def _load_quant_ckpt(self):
-        ckpt_path = self.config.dit_quantized_ckpt
+        ckpt_path = self.dit_quantized_ckpt
         logger.info(f"Loading quant dit model from {ckpt_path}")
         if ckpt_path.endswith(".pth"):
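The one-line fix reads the quantized-checkpoint path from an attribute the model now owns directly, rather than from a config key that (per the WanModel hunk below) is no longer written back. A minimal sketch, under assumed names not shown in the hunk, of how the corrected load path fits together:

import torch
from loguru import logger  # assumption: the project's logger has this interface

class HunyuanModelSketch:
    def __init__(self, dit_quantized_ckpt: str):
        # resolved once at construction time; no config mutation needed later
        self.dit_quantized_ckpt = dit_quantized_ckpt

    def _load_quant_ckpt(self):
        ckpt_path = self.dit_quantized_ckpt  # was: self.config.dit_quantized_ckpt
        logger.info(f"Loading quant dit model from {ckpt_path}")
        if ckpt_path.endswith(".pth"):
            return torch.load(ckpt_path, map_location="cpu")
        raise ValueError(f"unsupported quant checkpoint format: {ckpt_path}")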
@@ -80,8 +80,6 @@ class WanModel:
             self.dit_quantized_ckpt = None
             assert not self.config.get("lazy_load", False)

-        self.config.dit_quantized_ckpt = self.dit_quantized_ckpt
-
         self.weight_auto_quant = self.config.mm_config.get("weight_auto_quant", False)
         if self.dit_quantized:
             assert self.weight_auto_quant or self.dit_quantized_ckpt is not None
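This is the companion change: WanModel stops copying `dit_quantized_ckpt` back into the shared config, so the instance attribute becomes the single source of truth. A hedged sketch of the resulting invariant (config keys assumed, surrounding branches abbreviated):

class WanModelSketch:
    def __init__(self, config: dict):
        self.config = config
        self.dit_quantized = config.get("dit_quantized", False)      # assumed key
        self.dit_quantized_ckpt = config.get("dit_quantized_ckpt")   # may be None
        # removed: self.config.dit_quantized_ckpt = self.dit_quantized_ckpt
        self.weight_auto_quant = config.get("mm_config", {}).get("weight_auto_quant", False)
        if self.dit_quantized:
            # quantized inference needs either on-the-fly weight quantization
            # or an already-quantized checkpoint to load from
            assert self.weight_auto_quant or self.dit_quantized_ckpt is not None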
@@ -386,6 +386,7 @@ class MultiModelStruct:
         if self.scheduler.timesteps[self.scheduler.step_index] >= self.boundary_timestep:
             logger.info(f"using - HIGH - noise model at step_index {self.scheduler.step_index + 1}")
             self.scheduler.sample_guide_scale = self.config.sample_guide_scale[0]
+            if self.config.get("cpu_offload", False) and self.config.get("offload_granularity", "block") == "model":
                 if self.cur_model_index == -1:
                     self.to_cuda(model_index=0)
                 elif self.cur_model_index == 1:  # 1 -> 0
@@ -395,6 +396,7 @@ class MultiModelStruct:
         else:
             logger.info(f"using - LOW - noise model at step_index {self.scheduler.step_index + 1}")
             self.scheduler.sample_guide_scale = self.config.sample_guide_scale[1]
+            if self.config.get("cpu_offload", False) and self.config.get("offload_granularity", "block") == "model":
                 if self.cur_model_index == -1:
                     self.to_cuda(model_index=1)
                 elif self.cur_model_index == 0:  # 0 -> 1
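Both branches of the A14B two-expert (HIGH/LOW noise) switch now move models between devices only when CPU offload is enabled at model granularity; without offload, both experts stay resident on the GPU and no swap is needed. A hedged sketch of the gating, where `to_cpu` and the bookkeeping for the elided `...` lines are assumptions:

def maybe_swap_expert(self, target_index: int) -> None:
    """Bring the requested noise expert onto the GPU, but only when
    model-level CPU offload is active."""
    cpu_offload = self.config.get("cpu_offload", False)
    granularity = self.config.get("offload_granularity", "block")
    if not (cpu_offload and granularity == "model"):
        return  # both experts already live on the GPU; nothing to swap
    if self.cur_model_index == -1:               # first denoising step
        self.to_cuda(model_index=target_index)
    elif self.cur_model_index != target_index:   # expert switch, e.g. 1 -> 0
        self.to_cpu(model_index=self.cur_model_index)  # assumed counterpart of to_cuda
        self.to_cuda(model_index=target_index)
    self.cur_model_index = target_index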
@@ -158,7 +158,7 @@ def apply_model_with_memblocks(model, x, parallel, show_progress_bar):


 class TAEHV(nn.Module):
-    def __init__(self, checkpoint_path="taehv.pth", decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True), patch_size=1, latent_channels=16):
+    def __init__(self, checkpoint_path="taehv.pth", decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True), patch_size=1, latent_channels=16, model_type="wan21"):
         """Initialize pretrained TAEHV from the given checkpoint.

         Arg:
@@ -173,7 +173,10 @@ class TAEHV(nn.Module):
         self.latent_channels = latent_channels
         self.image_channels = 3
         self.is_cogvideox = checkpoint_path is not None and "taecvx" in checkpoint_path
-        if checkpoint_path is not None and "taew2_2" in checkpoint_path:
+        # if checkpoint_path is not None and "taew2_2" in checkpoint_path:
+        #     self.patch_size, self.latent_channels = 2, 48
+        if model_type == "wan22":
             self.patch_size, self.latent_channels = 2, 48
         self.encoder = nn.Sequential(
             conv(self.image_channels * self.patch_size**2, 64),
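With the new `model_type` flag, the Wan2.2 geometry (patch size 2, 48 latent channels) is selected explicitly instead of being inferred from the checkpoint filename, so a renamed or re-exported checkpoint still builds the right network. Illustrative usage with placeholder paths:

tae_21 = TAEHV("taehv.pth")                         # defaults: patch_size=1, latent_channels=16
tae_22 = TAEHV("any_name.pth", model_type="wan22")  # patch_size=2, latent_channels=48
assert (tae_22.patch_size, tae_22.latent_channels) == (2, 48)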
@@ -841,7 +841,9 @@ def _video_vae(pretrained_path=None, z_dim=16, dim=160, device="cpu", cpu_offloa


 class Wan2_2_VAE:
-    def __init__(self, z_dim=48, c_dim=160, vae_pth=None, dim_mult=[1, 2, 4, 4], temperal_downsample=[False, True, True], dtype=torch.float, device="cuda", cpu_offload=False, offload_cache=False):
+    def __init__(
+        self, z_dim=48, c_dim=160, vae_pth=None, dim_mult=[1, 2, 4, 4], temperal_downsample=[False, True, True], dtype=torch.float, device="cuda", cpu_offload=False, offload_cache=False, **kwargs
+    ):
         self.dtype = dtype
         self.device = device
         self.cpu_offload = cpu_offload
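Accepting `**kwargs` lets callers splat a larger, shared config dict into the constructor without a TypeError from keys this class does not define. Illustrative call with a hypothetical extra key and placeholder path:

import torch

vae = Wan2_2_VAE(
    vae_pth="Wan2.2_VAE.pth",  # placeholder path
    dtype=torch.bfloat16,
    use_tiling=True,           # hypothetical extra key, silently absorbed by **kwargs
)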
@@ -80,7 +80,7 @@ class Wan2_2_VAE_tiny(nn.Module):
         super().__init__()
         self.dtype = dtype
         self.device = torch.device("cuda")
-        self.taehv = TAEHV(vae_pth).to(self.dtype)
+        self.taehv = TAEHV(vae_pth, model_type="wan22").to(self.dtype)
         self.need_scaled = need_scaled
         if self.need_scaled:
             self.latents_mean = [
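Finally, the tiny-VAE wrapper pins the variant explicitly, completing the tae fix: construction no longer depends on the checkpoint filename containing "taew2_2". A minimal check, with the constructor arguments assumed from the surrounding context:

import torch

tiny = Wan2_2_VAE_tiny(vae_pth="taew2_2.pth", dtype=torch.bfloat16)  # placeholder path
assert tiny.taehv.patch_size == 2 and tiny.taehv.latent_channels == 48  # Wan2.2 geometry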