Commit 7367d6c8 authored by helloyongyang

remove unused seq_p_group

parent 99a6f046
@@ -540,7 +540,6 @@ class T5EncoderModel:
         t5_quantized=False,
         t5_quantized_ckpt=None,
         quant_scheme=None,
-        seq_p_group=None,
     ):
         self.text_len = text_len
         self.dtype = dtype
......
@@ -418,13 +418,12 @@ def clip_xlm_roberta_vit_h_14(pretrained=False, pretrained_name="open-clip-xlm-r
 class CLIPModel:
-    def __init__(self, dtype, device, checkpoint_path, clip_quantized, clip_quantized_ckpt, quant_scheme, cpu_offload=False, use_31_block=True, seq_p_group=None):
+    def __init__(self, dtype, device, checkpoint_path, clip_quantized, clip_quantized_ckpt, quant_scheme, cpu_offload=False, use_31_block=True):
         self.dtype = dtype
         self.device = device
         self.quantized = clip_quantized
         self.cpu_offload = cpu_offload
         self.use_31_block = use_31_block
-        self.seq_p_group = seq_p_group
         if self.quantized:
             self.checkpoint_path = clip_quantized_ckpt
......
@@ -16,8 +16,8 @@ class WanAudioModel(WanModel):
     post_weight_class = WanPostWeights
     transformer_weight_class = WanTransformerWeights

-    def __init__(self, model_path, config, device, seq_p_group=None):
-        super().__init__(model_path, config, device, seq_p_group)
+    def __init__(self, model_path, config, device):
+        super().__init__(model_path, config, device)

     def _init_infer_class(self):
         super()._init_infer_class()
......
@@ -23,8 +23,8 @@ class WanCausVidModel(WanModel):
     post_weight_class = WanPostWeights
     transformer_weight_class = WanTransformerWeights

-    def __init__(self, model_path, config, device, seq_p_group=None):
-        super().__init__(model_path, config, device, seq_p_group)
+    def __init__(self, model_path, config, device):
+        super().__init__(model_path, config, device)

     def _init_infer_class(self):
         self.pre_infer_class = WanPreInfer
......
@@ -19,8 +19,8 @@ class WanDistillModel(WanModel):
     post_weight_class = WanPostWeights
     transformer_weight_class = WanTransformerWeights

-    def __init__(self, model_path, config, device, seq_p_group=None):
-        super().__init__(model_path, config, device, seq_p_group)
+    def __init__(self, model_path, config, device):
+        super().__init__(model_path, config, device)

     def _load_ckpt(self, unified_dtype, sensitive_layer):
         # For the old t2v distill model: https://huggingface.co/lightx2v/Wan2.1-T2V-14B-StepDistill-CfgDistill
......
@@ -74,10 +74,6 @@ class WanPreInfer:
         x = x.flatten(2).transpose(1, 2).contiguous()
         seq_lens = torch.tensor(x.size(1), dtype=torch.long).cuda().unsqueeze(0)
-        # wan2.2_moe expands t here. We found it makes little difference either way, and the expansion
-        # adds latency; for now we stay faithful to the original code, and may remove this later.
-        if self.config["model_cls"] == "wan2.2_moe":
-            t = t.expand(seq_lens[0])
         embed = sinusoidal_embedding_1d(self.freq_dim, t.flatten())

         if self.enable_dynamic_cfg:
             s = torch.tensor([self.cfg_scale], dtype=torch.float32).to(x.device)
......
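For reference, sinusoidal_embedding_1d used above is the standard 1-D transformer timestep embedding. A minimal sketch of the usual formulation (an assumption; the repo's exact implementation is not shown in this diff):

import torch

def sinusoidal_embedding_1d(dim, position):
    # Half the channels carry cos, half carry sin, with frequencies
    # spaced geometrically between 1 and 1/10000.
    assert dim % 2 == 0
    half = dim // 2
    position = position.type(torch.float64)
    freqs = torch.outer(position, torch.pow(10000, -torch.arange(half, dtype=torch.float64) / half))
    return torch.cat([torch.cos(freqs), torch.sin(freqs)], dim=1)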
@@ -41,12 +41,16 @@ class WanModel:
     post_weight_class = WanPostWeights
     transformer_weight_class = WanTransformerWeights

-    def __init__(self, model_path, config, device, seq_p_group=None):
+    def __init__(self, model_path, config, device):
         self.model_path = model_path
         self.config = config
         self.cpu_offload = self.config.get("cpu_offload", False)
         self.offload_granularity = self.config.get("offload_granularity", "block")
-        self.seq_p_group = seq_p_group
+        if self.config["seq_parallel"]:
+            self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p")
+        else:
+            self.seq_p_group = None
         self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
         self.dit_quantized = self.config.mm_config.get("mm_type", "Default") != "Default"
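WanModel now resolves the group itself from config["device_mesh"] instead of receiving it as a constructor argument. For context, a minimal sketch of how such a mesh and group are typically created with PyTorch's DeviceMesh API (the 2x4 shape and the "dp" dimension name are assumptions; "seq_p" comes from the diff):

import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

dist.init_process_group(backend="nccl")
# Assumed 8-GPU layout: 2-way data parallel x 4-way sequence parallel.
mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "seq_p"))

# The same lookup the model now performs internally:
seq_p_group = mesh.get_group(mesh_dim="seq_p")
world_size = dist.get_world_size(seq_p_group)
cur_rank = dist.get_rank(seq_p_group)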
@@ -390,11 +394,11 @@ class WanModel:
                 x = F.pad(x, (0, 0, 0, padding_size))  # F.pad pairs pad the last dim first, earlier dims after
             x = torch.chunk(x, world_size, dim=0)[cur_rank]
-            if self.config["model_cls"].startswith("wan2.2"):
-                padding_size = (world_size - (embed0.shape[0] % world_size)) % world_size
-                if padding_size > 0:
-                    embed0 = F.pad(embed0, (0, 0, 0, 0, 0, padding_size))  # F.pad pairs pad the last dim first, earlier dims after
-                    embed = F.pad(embed, (0, 0, 0, padding_size))
+            # if self.config["model_cls"] == "wan2.2":
+            #     padding_size = (world_size - (embed0.shape[0] % world_size)) % world_size
+            #     if padding_size > 0:
+            #         embed0 = F.pad(embed0, (0, 0, 0, 0, 0, padding_size))  # F.pad pairs pad the last dim first, earlier dims after
+            #         embed = F.pad(embed, (0, 0, 0, padding_size))
             pre_infer_out.x = x
             pre_infer_out.embed = embed
......
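The padding logic above is the usual sequence-parallel sharding step: pad the split dimension up to a multiple of the group size, then keep only this rank's chunk. A self-contained sketch of the pattern (tensor names and shapes are illustrative, not from the repo):

import torch
import torch.nn.functional as F

def shard_for_seq_parallel(x: torch.Tensor, world_size: int, cur_rank: int) -> torch.Tensor:
    # Pad dim 0 up to the next multiple of world_size so every rank gets an equal chunk.
    padding_size = (world_size - (x.shape[0] % world_size)) % world_size
    if padding_size > 0:
        # F.pad pairs run from the last dim backward; (0, 0, 0, padding_size)
        # pads only the end of dim 0 of a 2-D tensor.
        x = F.pad(x, (0, 0, 0, padding_size))
    return torch.chunk(x, world_size, dim=0)[cur_rank]

# e.g. a 10-token sequence split across 4 ranks -> each rank holds 3 rows (with padding)
tokens = torch.randn(10, 8)
local = shard_for_seq_parallel(tokens, world_size=4, cur_rank=1)
print(local.shape)  # torch.Size([3, 8])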
@@ -435,8 +435,8 @@ class WanAudioRunner(WanRunner):  # type:ignore
         device = torch.device("cuda")
         audio_encoder_repo = self.config["model_path"] + "/audio_encoder"
-        if self.model.transformer_infer.seq_p_group is not None:
-            seq_p_group = self.model.transformer_infer.seq_p_group
+        if self.config["seq_parallel"]:
+            seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p")
         else:
             seq_p_group = None
@@ -619,7 +619,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
     def load_transformer(self):
         """Load transformer with LoRA support"""
-        base_model = WanAudioModel(self.config.model_path, self.config, self.init_device, self.seq_p_group)
+        base_model = WanAudioModel(self.config.model_path, self.config, self.init_device)
         logger.info(f"Loaded base model: {self.config.model_path}")

         if self.config.get("lora_configs") and self.config.lora_configs:
             assert not self.config.get("dit_quantized", False) or self.config.mm_config.get("weight_auto_quant", False)
......
@@ -29,7 +29,6 @@ class WanCausVidRunner(WanRunner):
             self.config.model_path,
             self.config,
             self.init_device,
-            self.seq_p_group,
         )
         lora_wrapper = WanLoraWrapper(model)
         for lora_config in self.config.lora_configs:
......
@@ -21,7 +21,6 @@ class WanDistillRunner(WanRunner):
             self.config.model_path,
             self.config,
             self.init_device,
-            self.seq_p_group,
         )
         lora_wrapper = WanLoraWrapper(model)
         for lora_config in self.config.lora_configs:
@@ -91,7 +90,6 @@ class Wan22MoeDistillRunner(WanDistillRunner):
             os.path.join(self.config.model_path, "high_noise_model"),
             self.config,
             self.init_device,
-            self.seq_p_group,
         )
         high_lora_wrapper = WanLoraWrapper(high_noise_model)
         for lora_config in self.config.lora_configs:
@@ -106,7 +104,6 @@ class Wan22MoeDistillRunner(WanDistillRunner):
             os.path.join(self.config.model_path, "distill_models", "high_noise_model"),
             self.config,
             self.init_device,
-            self.seq_p_group,
         )
         if use_low_lora:
......
@@ -34,18 +34,12 @@ from lightx2v.utils.utils import best_output_size, cache_video
 class WanRunner(DefaultRunner):
     def __init__(self, config):
         super().__init__(config)
-        device_mesh = self.config.get("device_mesh")
-        if device_mesh is not None:
-            self.seq_p_group = device_mesh.get_group(mesh_dim="seq_p")
-        else:
-            self.seq_p_group = None

     def load_transformer(self):
         model = WanModel(
             self.config.model_path,
             self.config,
             self.init_device,
-            self.seq_p_group,
         )
         if self.config.get("lora_configs") and self.config.lora_configs:
             assert not self.config.get("dit_quantized", False) or self.config.mm_config.get("weight_auto_quant", False)
@@ -83,7 +77,6 @@ class WanRunner(DefaultRunner):
             clip_quantized=clip_quantized,
             clip_quantized_ckpt=clip_quantized_ckpt,
             quant_scheme=clip_quant_scheme,
-            seq_p_group=self.seq_p_group,
             cpu_offload=self.config.get("clip_cpu_offload", self.config.get("cpu_offload", False)),
             use_31_block=self.config.get("use_31_block", True),
         )
@@ -127,7 +120,6 @@ class WanRunner(DefaultRunner):
             t5_quantized=t5_quantized,
             t5_quantized_ckpt=t5_quantized_ckpt,
             quant_scheme=t5_quant_scheme,
-            seq_p_group=self.seq_p_group,
         )
         text_encoders = [text_encoder]
         return text_encoders
@@ -145,7 +137,6 @@ class WanRunner(DefaultRunner):
             "device": vae_device,
             "parallel": self.config.parallel and self.config.parallel.get("vae_p_size", False) and self.config.parallel.vae_p_size > 1,
             "use_tiling": self.config.get("use_tiling_vae", False),
-            "seq_p_group": self.seq_p_group,
             "cpu_offload": vae_offload,
         }
         if self.config.task != "i2v":
......
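Aside: the "parallel" entry above enables VAE parallelism only when the config carries a vae_p_size greater than 1. A runnable illustration of that truthiness chain (the Cfg class is a stand-in for the repo's attribute-access config objects):

class Cfg(dict):
    # Minimal stand-in: attribute access falls back to dict.get.
    __getattr__ = dict.get

parallel_cfg = Cfg(vae_p_size=4)
use_parallel = bool(parallel_cfg and parallel_cfg.get("vae_p_size", False) and parallel_cfg.vae_p_size > 1)
print(use_parallel)  # True; becomes False if vae_p_size is absent, 0, or 1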
@@ -759,7 +759,7 @@ class WanVAE_(nn.Module):
         self._enc_feat_map = [None] * self._enc_conv_num


-def _video_vae(pretrained_path=None, z_dim=None, device="cpu", seq_p_group=None, cpu_offload=False, **kwargs):
+def _video_vae(pretrained_path=None, z_dim=None, device="cpu", cpu_offload=False, **kwargs):
     """
     Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
     """
@@ -795,7 +795,6 @@ class WanVAE:
         device="cuda",
         parallel=False,
         use_tiling=False,
-        seq_p_group=None,
         cpu_offload=False,
     ):
         self.dtype = dtype
@@ -845,7 +844,7 @@ class WanVAE:
         self.scale = [self.mean, self.inv_std]

         # init model
-        self.model = _video_vae(pretrained_path=vae_pth, z_dim=z_dim, seq_p_group=seq_p_group, cpu_offload=cpu_offload).eval().requires_grad_(False).to(device)
+        self.model = _video_vae(pretrained_path=vae_pth, z_dim=z_dim, cpu_offload=cpu_offload).eval().requires_grad_(False).to(device)

     def current_device(self):
         return next(self.model.parameters()).device
......
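The chained construction on the last changed line is a standard PyTorch inference-setup idiom: build the module, switch to eval mode, freeze the weights, then place them on the target device. A toy, runnable version of the same chain:

import torch.nn as nn

# eval() disables dropout/batch-norm updates, requires_grad_(False) freezes
# every parameter, and .to(...) moves the weights to the target device.
model = nn.Linear(4, 4).eval().requires_grad_(False).to("cpu")
print(next(model.parameters()).requires_grad)  # False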