Commit ed1a937b authored by gushiqiao, committed by GitHub

[Fix] Update load weights params list (#317)

parent 3b896f9c
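This commit threads a new load_from_rank0 flag (default False, so existing callers keep the old behavior) through every load_weights call site: the T5 text encoder, the CLIP image encoder, the audio adapter, and both VAE variants. The implementation of load_weights itself is not part of this diff. As orientation: a rank-0 loading path in a torch.distributed setup usually means only rank 0 reads the checkpoint from disk and then broadcasts the tensors to the other ranks, so a multi-GPU job does not have every process hit the filesystem at once. The sketch below illustrates that idea under those assumptions; the broadcast protocol and the safetensors usage are illustrative guesses, not the repository's actual code.

# Illustrative sketch only -- assumes a safetensors checkpoint and an
# initialized NCCL process group; not the repository's actual load_weights.
import torch
import torch.distributed as dist
from safetensors.torch import load_file

def load_weights(checkpoint_path, cpu_offload=False, remove_key=None, load_from_rank0=False):
    if load_from_rank0 and dist.is_initialized():
        # Only rank 0 reads the file; the other ranks receive tensors over the wire.
        if dist.get_rank() == 0:
            weights = load_file(checkpoint_path, device="cuda")
            meta = [{k: (v.shape, v.dtype) for k, v in weights.items()}]
        else:
            meta = [None]
        dist.broadcast_object_list(meta, src=0)  # share names/shapes/dtypes first
        if dist.get_rank() != 0:
            weights = {k: torch.empty(s, dtype=d, device="cuda") for k, (s, d) in meta[0].items()}
        for k in sorted(weights):  # deterministic tensor order on every rank
            dist.broadcast(weights[k], src=0)
        if cpu_offload:
            weights = {k: v.cpu() for k, v in weights.items()}
    else:
        weights = load_file(checkpoint_path, device="cpu" if cpu_offload else "cuda")
    if remove_key is not None:
        # Assumed semantics of the remove_key argument seen at the call sites below.
        weights = {k: v for k, v in weights.items() if remove_key not in k}
    return weights

With that mental model, the per-call-site changes below are purely mechanical: accept the flag and forward it to load_weights.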
@@ -540,6 +540,7 @@ class T5EncoderModel:
         t5_quantized=False,
         t5_quantized_ckpt=None,
         quant_scheme=None,
+        load_from_rank0=False,
     ):
         self.text_len = text_len
         self.dtype = dtype
@@ -570,7 +571,7 @@
             .requires_grad_(False)
         )
-        weights_dict = load_weights(self.checkpoint_path, cpu_offload=cpu_offload)
+        weights_dict = load_weights(self.checkpoint_path, cpu_offload=cpu_offload, load_from_rank0=load_from_rank0)
         model.load_state_dict(weights_dict)
         self.model = model
@@ -418,7 +418,7 @@ def clip_xlm_roberta_vit_h_14(pretrained=False, pretrained_name="open-clip-xlm-r
 class CLIPModel:
-    def __init__(self, dtype, device, checkpoint_path, clip_quantized, clip_quantized_ckpt, quant_scheme, cpu_offload=False, use_31_block=True):
+    def __init__(self, dtype, device, checkpoint_path, clip_quantized, clip_quantized_ckpt, quant_scheme, cpu_offload=False, use_31_block=True, load_from_rank0=False):
         self.dtype = dtype
         self.device = device
         self.quantized = clip_quantized
@@ -435,7 +435,7 @@ class CLIPModel:
             pretrained=False, return_transforms=True, return_tokenizer=False, dtype=dtype, device=device, quantized=self.quantized, quant_scheme=quant_scheme
         )
         self.model = self.model.eval().requires_grad_(False)
-        weight_dict = load_weights(self.checkpoint_path, cpu_offload=cpu_offload, remove_key="textual")
+        weight_dict = load_weights(self.checkpoint_path, cpu_offload=cpu_offload, remove_key="textual", load_from_rank0=load_from_rank0)
         self.model.load_state_dict(weight_dict)

     def visual(self, videos):
@@ -38,7 +38,8 @@ class WanAudioModel(WanModel):
         self.config.adapter_model_path = os.path.join(self.config.model_path, adapter_model_name)
         adapter_offload = self.config.get("cpu_offload", False)
-        self.adapter_weights_dict = load_weights(self.config.adapter_model_path, cpu_offload=adapter_offload, remove_key="audio")
+        load_from_rank0 = self.config.get("load_from_rank0", False)
+        self.adapter_weights_dict = load_weights(self.config.adapter_model_path, cpu_offload=adapter_offload, remove_key="audio", load_from_rank0=load_from_rank0)
         if not dist.is_initialized() and not adapter_offload:
             for key in self.adapter_weights_dict:
                 self.adapter_weights_dict[key] = self.adapter_weights_dict[key].cuda()
@@ -735,7 +735,8 @@ class WanAudioRunner(WanRunner): # type:ignore
         )
         audio_adapter.to(device)
-        weights_dict = load_weights(self.config.adapter_model_path, cpu_offload=audio_adapter_offload, remove_key="ca")
+        load_from_rank0 = self.config.get("load_from_rank0", False)
+        weights_dict = load_weights(self.config.adapter_model_path, cpu_offload=audio_adapter_offload, remove_key="ca", load_from_rank0=load_from_rank0)
         audio_adapter.load_state_dict(weights_dict, strict=False)
         return audio_adapter.to(dtype=GET_DTYPE())
@@ -89,6 +89,7 @@ class WanRunner(DefaultRunner):
             quant_scheme=clip_quant_scheme,
             cpu_offload=clip_offload,
             use_31_block=self.config.get("use_31_block", True),
+            load_from_rank0=self.config.get("load_from_rank0", False),
         )
         return image_encoder
@@ -130,6 +131,7 @@ class WanRunner(DefaultRunner):
             t5_quantized=t5_quantized,
             t5_quantized_ckpt=t5_quantized_ckpt,
             quant_scheme=t5_quant_scheme,
+            load_from_rank0=self.config.get("load_from_rank0", False),
         )
         text_encoders = [text_encoder]
         return text_encoders
@@ -149,6 +151,7 @@ class WanRunner(DefaultRunner):
             "use_tiling": self.config.get("use_tiling_vae", False),
             "cpu_offload": vae_offload,
             "dtype": GET_DTYPE(),
+            "load_from_rank0": self.config.get("load_from_rank0", False),
         }
         if self.config.task not in ["i2v", "flf2v", "vace"]:
             return None
@@ -170,6 +173,7 @@ class WanRunner(DefaultRunner):
             "use_tiling": self.config.get("use_tiling_vae", False),
             "cpu_offload": vae_offload,
             "dtype": GET_DTYPE(),
+            "load_from_rank0": self.config.get("load_from_rank0", False),
         }
         if self.config.get("use_tiny_vae", False):
             tiny_vae_path = find_torch_model_path(self.config, "tiny_vae_path", self.tiny_vae_name)
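On the runner side the flag is always derived as self.config.get("load_from_rank0", False), so configs that never mention the key keep the old loading behavior. A hypothetical config fragment that enables it (all keys besides the new one appear in the hunks above; the dict form is illustrative):

config = {
    "cpu_offload": False,
    "use_tiling_vae": False,
    "use_31_block": True,
    "load_from_rank0": True,  # new key introduced by this commit
}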
@@ -761,7 +761,7 @@ class WanVAE_(nn.Module):
         self._enc_feat_map = [None] * self._enc_conv_num

-def _video_vae(pretrained_path=None, z_dim=None, device="cpu", cpu_offload=False, dtype=torch.float, **kwargs):
+def _video_vae(pretrained_path=None, z_dim=None, device="cpu", cpu_offload=False, dtype=torch.float, load_from_rank0=False, **kwargs):
     """
     Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
     """
@@ -782,7 +782,7 @@ def _video_vae(pretrained_path=None, z_dim=None, device="cpu", cpu_offload=False
     model = WanVAE_(**cfg)

     # load checkpoint
-    weights_dict = load_weights(pretrained_path, cpu_offload=cpu_offload)
+    weights_dict = load_weights(pretrained_path, cpu_offload=cpu_offload, load_from_rank0=load_from_rank0)
     for k in weights_dict.keys():
         if weights_dict[k].dtype != dtype:
             weights_dict[k] = weights_dict[k].to(dtype)
@@ -802,6 +802,7 @@ class WanVAE:
         use_tiling=False,
         cpu_offload=False,
         use_2d_split=True,
+        load_from_rank0=False,
     ):
         self.dtype = dtype
         self.device = device
@@ -888,7 +889,7 @@
         }

         # init model
-        self.model = _video_vae(pretrained_path=vae_pth, z_dim=z_dim, cpu_offload=cpu_offload, dtype=dtype).eval().requires_grad_(False).to(device).to(dtype)
+        self.model = _video_vae(pretrained_path=vae_pth, z_dim=z_dim, cpu_offload=cpu_offload, dtype=dtype, load_from_rank0=load_from_rank0).eval().requires_grad_(False).to(device).to(dtype)

     def _calculate_2d_grid(self, latent_height, latent_width, world_size):
         if (latent_height, latent_width, world_size) in self.grid_table:
@@ -812,7 +812,7 @@ class WanVAE_(nn.Module):
         self._enc_feat_map = [None] * self._enc_conv_num

-def _video_vae(pretrained_path=None, z_dim=16, dim=160, device="cpu", cpu_offload=False, dtype=torch.float32, **kwargs):
+def _video_vae(pretrained_path=None, z_dim=16, dim=160, device="cpu", cpu_offload=False, dtype=torch.float32, load_from_rank0=False, **kwargs):
     # params
     cfg = dict(
         dim=dim,
@@ -831,7 +831,7 @@ def _video_vae(pretrained_path=None, z_dim=16, dim=160, device="cpu", cpu_offloa
     # load checkpoint
     logging.info(f"loading {pretrained_path}")
-    weights_dict = load_weights(pretrained_path, cpu_offload=cpu_offload)
+    weights_dict = load_weights(pretrained_path, cpu_offload=cpu_offload, load_from_rank0=load_from_rank0)
     for k in weights_dict.keys():
         if weights_dict[k].dtype != dtype:
             weights_dict[k] = weights_dict[k].to(dtype)
@@ -842,7 +842,18 @@ def _video_vae(pretrained_path=None, z_dim=16, dim=160, device="cpu", cpu_offloa

 class Wan2_2_VAE:
     def __init__(
-        self, z_dim=48, c_dim=160, vae_pth=None, dim_mult=[1, 2, 4, 4], temperal_downsample=[False, True, True], dtype=torch.float, device="cuda", cpu_offload=False, offload_cache=False, **kwargs
+        self,
+        z_dim=48,
+        c_dim=160,
+        vae_pth=None,
+        dim_mult=[1, 2, 4, 4],
+        temperal_downsample=[False, True, True],
+        dtype=torch.float,
+        device="cuda",
+        cpu_offload=False,
+        offload_cache=False,
+        load_from_rank0=False,
+        **kwargs,
     ):
         self.dtype = dtype
         self.device = device
@@ -961,7 +972,9 @@ class Wan2_2_VAE:
         self.scale = [self.mean, self.inv_std]

         # init model
         self.model = (
-            _video_vae(pretrained_path=vae_pth, z_dim=z_dim, dim=c_dim, dim_mult=dim_mult, temperal_downsample=temperal_downsample, cpu_offload=cpu_offload, dtype=dtype)
+            _video_vae(
+                pretrained_path=vae_pth, z_dim=z_dim, dim=c_dim, dim_mult=dim_mult, temperal_downsample=temperal_downsample, cpu_offload=cpu_offload, dtype=dtype, load_from_rank0=load_from_rank0
+            )
             .eval()
             .requires_grad_(False)
             .to(device)
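Assuming the Wan2_2_VAE constructor shown in the last hunk, enabling rank-0 loading from user code might look like this; the checkpoint path is a placeholder:

import torch

vae = Wan2_2_VAE(
    vae_pth="/path/to/wan2.2_vae.pth",  # placeholder path
    dtype=torch.float,
    device="cuda",
    load_from_rank0=True,  # rank 0 reads the checkpoint, other ranks receive broadcasts
)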