Commit 95b58beb authored by helloyongyang

update parallel

parent f05a99da
import glob
import os
import torch
from lightx2v.common.ops.attn.radial_attn import MaskMap
from lightx2v.models.networks.wan.infer.audio.post_wan_audio_infer import WanAudioPostInfer
from lightx2v.models.networks.wan.infer.audio.pre_wan_audio_infer import WanAudioPreInfer
from lightx2v.models.networks.wan.model import WanModel
@@ -27,87 +24,6 @@ class WanAudioModel(WanModel):
self.pre_infer_class = WanAudioPreInfer
self.post_infer_class = WanAudioPostInfer
@torch.no_grad()
def infer(self, inputs):
if self.config["cpu_offload"]:
self.pre_weight.to_cuda()
self.post_weight.to_cuda()
if self.transformer_infer.mask_map is None:
_, c, h, w = self.scheduler.latents.shape
num_frame = c + 1 # for r2v
video_token_num = num_frame * (h // 2) * (w // 2)
self.transformer_infer.mask_map = MaskMap(video_token_num, num_frame)
embed, grid_sizes, pre_infer_out, valid_patch_length = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes, valid_patch_length)[0]
if self.config["feature_caching"] == "Tea":
self.scheduler.cnt += 1
if self.scheduler.cnt >= self.scheduler.num_steps:
self.scheduler.cnt = 0
self.scheduler.noise_pred = noise_pred_cond
if self.config["enable_cfg"]:
embed, grid_sizes, pre_infer_out, valid_patch_length = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes, valid_patch_length)[0]
if self.config["feature_caching"] == "Tea":
self.scheduler.cnt += 1
if self.scheduler.cnt >= self.scheduler.num_steps:
self.scheduler.cnt = 0
self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (noise_pred_cond - noise_pred_uncond)
if self.config["cpu_offload"]:
self.pre_weight.to_cpu()
self.post_weight.to_cpu()
@torch.no_grad()
def infer_wo_cfg_parallel(self, inputs):
if self.cpu_offload:
if self.offload_granularity == "model" and self.scheduler.step_index == 0:
self.to_cuda()
elif self.offload_granularity != "model":
self.pre_weight.to_cuda()
self.post_weight.to_cuda()
if self.transformer_infer.mask_map is None:
_, c, h, w = self.scheduler.latents.shape
num_frame = c + 1 # for r2v
video_token_num = num_frame * (h // 2) * (w // 2)
self.transformer_infer.mask_map = MaskMap(video_token_num, num_frame)
embed, grid_sizes, pre_infer_out, valid_patch_length = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes, valid_patch_length)[0]
self.scheduler.noise_pred = noise_pred_cond
if self.clean_cuda_cache:
del x, embed, pre_infer_out, noise_pred_cond, grid_sizes
torch.cuda.empty_cache()
if self.config["enable_cfg"]:
embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (self.scheduler.noise_pred - noise_pred_uncond)
if self.clean_cuda_cache:
del x, embed, pre_infer_out, noise_pred_uncond, grid_sizes
torch.cuda.empty_cache()
if self.cpu_offload:
if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1:
self.to_cpu()
elif self.offload_granularity != "model":
self.pre_weight.to_cpu()
self.post_weight.to_cpu()
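The branches in these infer paths key off a handful of config switches. A minimal, hypothetical config sketch; only the key names are taken from the code in this commit, the values are illustrative:

# Hypothetical config values; only the key names come from the code in this diff.
config = {
    "cpu_offload": False,            # move pre/post weights (or the whole model) between CPU and GPU
    "offload_granularity": "model",  # the code only distinguishes "model" from any other value
    "clean_cuda_cache": False,       # delete intermediates and call torch.cuda.empty_cache() after each pass
    "enable_cfg": True,              # run an extra unconditional pass and blend with sample_guide_scale
    "feature_caching": "NoCaching",  # or "Tea", "TaylorSeer", "Ada", "Custom", "FirstBlock", ...
    "seq_parallel": False,           # shard the token dimension across the seq_p process group
    "cfg_parallel": False,           # run the cond/uncond passes on the two ranks of the cfg_p group
    # "device_mesh": ...,            # required when seq_parallel or cfg_parallel is enabled
}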
class Wan22MoeAudioModel(WanAudioModel):
def _load_ckpt(self, unified_dtype, sensitive_layer):
@@ -18,30 +18,11 @@ class WanAudioPostInfer(WanPostInfer):
self.scheduler = scheduler
@torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
def infer(self, weights, x, e, grid_sizes, valid_patch_length):
if e.dim() == 2:
modulation = weights.head_modulation.tensor # 1, 2, dim
e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
elif e.dim() == 3: # For Diffusion forcing
modulation = weights.head_modulation.tensor.unsqueeze(2) # 1, 2, seq, dim
e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
e = [ei.squeeze(1) for ei in e]
x = weights.norm.apply(x)
if self.sensitive_layer_dtype != self.infer_dtype:
x = x.to(self.sensitive_layer_dtype)
x.mul_(1 + e[1].squeeze()).add_(e[0].squeeze())
if self.sensitive_layer_dtype != self.infer_dtype:
x = x.to(self.infer_dtype)
x = weights.head.apply(x)
x = x[:, :valid_patch_length]
x = self.unpatchify(x, grid_sizes)
def infer(self, weights, x, pre_infer_out):
x = x[:, : pre_infer_out.valid_patch_length]
x = self.unpatchify(x, pre_infer_out.grid_sizes)
if self.clean_cuda_cache:
del e, grid_sizes
torch.cuda.empty_cache()
return [u.float() for u in x]
@@ -3,6 +3,7 @@ import torch
from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
from lightx2v.utils.envs import *
from ..module_io import WanPreInferModuleOutput
from ..utils import rope_params, sinusoidal_embedding_1d
@@ -126,4 +127,14 @@ class WanAudioPreInfer(WanPreInfer):
del context_clip
torch.cuda.empty_cache()
return (embed, x_grid_sizes, (x.squeeze(0), embed0.squeeze(0), seq_lens, self.freqs, context, audio_dit_blocks), valid_patch_length)
return WanPreInferModuleOutput(
embed=embed,
grid_sizes=x_grid_sizes,
x=x.squeeze(0),
embed0=embed0.squeeze(0),
seq_lens=seq_lens,
freqs=self.freqs,
context=context,
audio_dit_blocks=audio_dit_blocks,
valid_patch_length=valid_patch_length,
)
from dataclasses import dataclass
from typing import List
import torch
@dataclass
class WanPreInferModuleOutput:
embed: torch.Tensor
grid_sizes: torch.Tensor
x: torch.Tensor
embed0: torch.Tensor
seq_lens: torch.Tensor
freqs: torch.Tensor
context: torch.Tensor
audio_dit_blocks: List = None
valid_patch_length: int = None
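For reference, a minimal usage sketch of the new dataclass; the tensor shapes below are placeholders, not real model dimensions:

import torch

# Sketch only: dummy tensors standing in for the real pre-infer outputs.
dummy = WanPreInferModuleOutput(
    embed=torch.randn(1, 1536),
    grid_sizes=torch.tensor([[4, 30, 52]]),
    x=torch.randn(4 * 30 * 52, 1536),
    embed0=torch.randn(6, 1536),
    seq_lens=torch.tensor([4 * 30 * 52]),
    freqs=torch.randn(1024, 64),
    context=torch.randn(512, 1536),
)
# Downstream consumers in this commit read fields by name instead of unpacking a tuple, e.g.:
#   x = self.transformer_infer.infer(self.transformer_weights, pre_infer_out)
#   noise_pred = self.post_infer.infer(self.post_weight, x, pre_infer_out)[0]
print(dummy.grid_sizes.shape, dummy.valid_patch_length)  # audio-only fields default to None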
@@ -10,35 +10,15 @@ class WanPostInfer:
self.out_dim = config["out_dim"]
self.patch_size = (1, 2, 2)
self.clean_cuda_cache = config.get("clean_cuda_cache", False)
self.infer_dtype = GET_DTYPE()
self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()
def set_scheduler(self, scheduler):
self.scheduler = scheduler
@torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
def infer(self, weights, x, e, grid_sizes):
if e.dim() == 2:
modulation = weights.head_modulation.tensor # 1, 2, dim
e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
elif e.dim() == 3: # For Diffusion forcing
modulation = weights.head_modulation.tensor.unsqueeze(2) # 1, 2, seq, dim
e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
e = [ei.squeeze(1) for ei in e]
x = weights.norm.apply(x)
if self.sensitive_layer_dtype != self.infer_dtype:
x = x.to(self.sensitive_layer_dtype)
x.mul_(1 + e[1].squeeze()).add_(e[0].squeeze())
if self.sensitive_layer_dtype != self.infer_dtype:
x = x.to(self.infer_dtype)
x = weights.head.apply(x)
x = self.unpatchify(x, grid_sizes)
def infer(self, weights, x, pre_infer_out):
x = self.unpatchify(x, pre_infer_out.grid_sizes)
if self.clean_cuda_cache:
del e, grid_sizes
torch.cuda.empty_cache()
return [u.float() for u in x]
@@ -2,6 +2,7 @@ import torch
from lightx2v.utils.envs import *
from .module_io import WanPreInferModuleOutput
from .utils import guidance_scale_embedding, rope_params, sinusoidal_embedding_1d
@@ -132,8 +133,13 @@ class WanPreInfer:
if self.config.get("use_image_encoder", True):
del context_clip
torch.cuda.empty_cache()
return (
embed,
grid_sizes,
(x.squeeze(0), embed0.squeeze(0), seq_lens, self.freqs, context),
return WanPreInferModuleOutput(
embed=embed,
grid_sizes=grid_sizes,
x=x.squeeze(0),
embed0=embed0.squeeze(0),
seq_lens=seq_lens,
freqs=self.freqs,
context=context,
)
@@ -9,7 +9,7 @@ from lightx2v.common.offload.manager import (
from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer
from lightx2v.utils.envs import *
from .utils import apply_rotary_emb, apply_rotary_emb_chunk, compute_freqs, compute_freqs_audio
from .utils import apply_rotary_emb, apply_rotary_emb_chunk, compute_freqs, compute_freqs_audio, compute_freqs_audio_dist, compute_freqs_dist
class WanTransformerInfer(BaseTransformerInfer):
@@ -33,7 +33,11 @@ class WanTransformerInfer(BaseTransformerInfer):
self.infer_dtype = GET_DTYPE()
self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()
self.seq_p_group = None
if self.config["seq_parallel"]:
self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p")
else:
self.seq_p_group = None
if self.config.get("cpu_offload", False):
if torch.cuda.get_device_capability(0) == (9, 0):
assert self.config["self_attn_1_type"] != "sage_attn2"
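The seq_p group read from config["device_mesh"] above comes from a torch.distributed DeviceMesh built at launch time. A minimal sketch assuming PyTorch's init_device_mesh and a 4-GPU job split 2 x 2 over cfg_p and seq_p; the mesh layout is an assumption, only the dim names appear in this diff:

import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

# Assumed launch: torchrun --nproc_per_node=4, i.e. world size 4.
device_mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("cfg_p", "seq_p"))
seq_p_group = device_mesh.get_group(mesh_dim="seq_p")  # what config["device_mesh"].get_group(...) returns above
print(dist.get_rank(seq_p_group), dist.get_world_size(seq_p_group))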
@@ -86,15 +90,56 @@
return cu_seqlens_q, cu_seqlens_k
def compute_freqs(self, q, grid_sizes, freqs):
if "audio" in self.config.get("model_cls", ""):
freqs_i = compute_freqs_audio(q.size(2) // 2, grid_sizes, freqs)
if self.config["seq_parallel"]:
if "audio" in self.config.get("model_cls", ""):
freqs_i = compute_freqs_audio_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs, self.seq_p_group)
else:
freqs_i = compute_freqs_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs, self.seq_p_group)
else:
freqs_i = compute_freqs(q.size(2) // 2, grid_sizes, freqs)
if "audio" in self.config.get("model_cls", ""):
freqs_i = compute_freqs_audio(q.size(2) // 2, grid_sizes, freqs)
else:
freqs_i = compute_freqs(q.size(2) // 2, grid_sizes, freqs)
return freqs_i
@torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
return self.infer_func(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks)
def infer(self, weights, pre_infer_out):
x = self.infer_func(
weights,
pre_infer_out.grid_sizes,
pre_infer_out.embed,
pre_infer_out.x,
pre_infer_out.embed0,
pre_infer_out.seq_lens,
pre_infer_out.freqs,
pre_infer_out.context,
pre_infer_out.audio_dit_blocks,
)
return self._infer_post_blocks(weights, x, pre_infer_out.embed)
def _infer_post_blocks(self, weights, x, e):
if e.dim() == 2:
modulation = weights.head_modulation.tensor # 1, 2, dim
e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
elif e.dim() == 3: # For Diffusion forcing
modulation = weights.head_modulation.tensor.unsqueeze(2) # 1, 2, seq, dim
e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
e = [ei.squeeze(1) for ei in e]
x = weights.norm.apply(x)
if self.sensitive_layer_dtype != self.infer_dtype:
x = x.to(self.sensitive_layer_dtype)
x.mul_(1 + e[1].squeeze()).add_(e[0].squeeze())
if self.sensitive_layer_dtype != self.infer_dtype:
x = x.to(self.infer_dtype)
x = weights.head.apply(x)
if self.clean_cuda_cache:
del e
torch.cuda.empty_cache()
return x
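For reference, the head-modulation arithmetic in _infer_post_blocks for the e.dim() == 2 branch, spelled out on dummy tensors (toy shapes, norm and head projection omitted):

import torch

seq, dim = 5, 8                                    # toy sizes
e = torch.randn(seq, dim)                          # per-token embedding, dim() == 2 case
head_modulation = torch.randn(1, 2, dim)           # stands in for weights.head_modulation.tensor
shift, scale = (head_modulation + e.unsqueeze(1)).chunk(2, dim=1)  # each (seq, 1, dim)
x = torch.randn(seq, dim)                          # normalized hidden states
x = x * (1 + scale.squeeze(1)) + shift.squeeze(1)  # same as x.mul_(1 + e[1]...).add_(e[0]...)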
def _infer_without_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
for block_idx in range(self.blocks_num):
import torch
import torch.distributed as dist
from lightx2v.utils.envs import *
@@ -39,6 +40,52 @@ def compute_freqs_audio(c, grid_sizes, freqs):
return freqs_i
def compute_freqs_dist(s, c, grid_sizes, freqs, seq_p_group):
world_size = dist.get_world_size(seq_p_group)
cur_rank = dist.get_rank(seq_p_group)
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
f, h, w = grid_sizes[0]
seq_len = f * h * w
freqs_i = torch.cat(
[
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
],
dim=-1,
).reshape(seq_len, 1, -1)
freqs_i = pad_freqs(freqs_i, s * world_size)
s_per_rank = s
freqs_i_rank = freqs_i[(cur_rank * s_per_rank) : ((cur_rank + 1) * s_per_rank), :, :]
return freqs_i_rank
def compute_freqs_audio_dist(s, c, grid_sizes, freqs, seq_p_group):
world_size = dist.get_world_size(seq_p_group)
cur_rank = dist.get_rank(seq_p_group)
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
f, h, w = grid_sizes[0]
valid_token_length = f * h * w
f = f + 1
seq_len = f * h * w
freqs_i = torch.cat(
[
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
],
dim=-1,
).reshape(seq_len, 1, -1)
freqs_i[valid_token_length:, :, :f] = 0
freqs_i = pad_freqs(freqs_i, s * world_size)
s_per_rank = s
freqs_i_rank = freqs_i[(cur_rank * s_per_rank) : ((cur_rank + 1) * s_per_rank), :, :]
return freqs_i_rank
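Both *_dist helpers shard the RoPE frequency table by rows: the table is padded to s * world_size rows (via pad_freqs) and rank r keeps rows [r * s, (r + 1) * s). A tiny arithmetic sketch with toy sizes (pad_freqs itself is not reproduced here):

# Toy sizes: a (2, 4, 4) latent grid gives 32 RoPE rows; with 4 ranks holding
# s = 9 local tokens each, the table is padded to 36 rows before slicing.
f, h, w = 2, 4, 4
seq_len = f * h * w                    # 32 rows in freqs_i before padding
world_size, s = 4, 9
padded_len = s * world_size            # 36 rows after pad_freqs
for cur_rank in range(world_size):
    lo, hi = cur_rank * s, (cur_rank + 1) * s
    print(f"rank {cur_rank} keeps freqs_i[{lo}:{hi}]")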
def compute_freqs_causvid(c, grid_sizes, freqs, start_frame=0):
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
f, h, w = grid_sizes[0]
@@ -7,7 +7,6 @@ from loguru import logger
from safetensors import safe_open
from lightx2v.common.ops.attn import MaskMap
from lightx2v.models.networks.wan.infer.dist_infer.transformer_infer import WanTransformerDistInfer
from lightx2v.models.networks.wan.infer.feature_caching.transformer_infer import (
WanTransformerInferAdaCaching,
WanTransformerInferCustomCaching,
@@ -83,27 +82,25 @@ class WanModel:
def _init_infer_class(self):
self.pre_infer_class = WanPreInfer
self.post_infer_class = WanPostInfer
if self.seq_p_group is not None:
self.transformer_infer_class = WanTransformerDistInfer
if self.config["feature_caching"] == "NoCaching":
self.transformer_infer_class = WanTransformerInfer
elif self.config["feature_caching"] == "Tea":
self.transformer_infer_class = WanTransformerInferTeaCaching
elif self.config["feature_caching"] == "TaylorSeer":
self.transformer_infer_class = WanTransformerInferTaylorCaching
elif self.config["feature_caching"] == "Ada":
self.transformer_infer_class = WanTransformerInferAdaCaching
elif self.config["feature_caching"] == "Custom":
self.transformer_infer_class = WanTransformerInferCustomCaching
elif self.config["feature_caching"] == "FirstBlock":
self.transformer_infer_class = WanTransformerInferFirstBlock
elif self.config["feature_caching"] == "DualBlock":
self.transformer_infer_class = WanTransformerInferDualBlock
elif self.config["feature_caching"] == "DynamicBlock":
self.transformer_infer_class = WanTransformerInferDynamicBlock
else:
if self.config["feature_caching"] == "NoCaching":
self.transformer_infer_class = WanTransformerInfer
elif self.config["feature_caching"] == "Tea":
self.transformer_infer_class = WanTransformerInferTeaCaching
elif self.config["feature_caching"] == "TaylorSeer":
self.transformer_infer_class = WanTransformerInferTaylorCaching
elif self.config["feature_caching"] == "Ada":
self.transformer_infer_class = WanTransformerInferAdaCaching
elif self.config["feature_caching"] == "Custom":
self.transformer_infer_class = WanTransformerInferCustomCaching
elif self.config["feature_caching"] == "FirstBlock":
self.transformer_infer_class = WanTransformerInferFirstBlock
elif self.config["feature_caching"] == "DualBlock":
self.transformer_infer_class = WanTransformerInferDualBlock
elif self.config["feature_caching"] == "DynamicBlock":
self.transformer_infer_class = WanTransformerInferDynamicBlock
else:
raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")
raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")
def _should_load_weights(self):
"""Determine if current rank should load weights from disk."""
@@ -296,16 +293,7 @@ class WanModel:
def _init_infer(self):
self.pre_infer = self.pre_infer_class(self.config)
self.post_infer = self.post_infer_class(self.config)
if self.seq_p_group is not None:
self.transformer_infer = self.transformer_infer_class(self.config, self.seq_p_group)
else:
self.transformer_infer = self.transformer_infer_class(self.config)
if self.config["cfg_parallel"]:
self.infer_func = self.infer_with_cfg_parallel
else:
self.infer_func = self.infer_wo_cfg_parallel
self.transformer_infer = self.transformer_infer_class(self.config)
def set_scheduler(self, scheduler):
self.scheduler = scheduler
@@ -325,10 +313,6 @@
@torch.no_grad()
def infer(self, inputs):
return self.infer_func(inputs)
@torch.no_grad()
def infer_wo_cfg_parallel(self, inputs):
if self.cpu_offload:
if self.offload_granularity == "model" and self.scheduler.step_index == 0:
self.to_cuda()
@@ -341,26 +325,31 @@
video_token_num = c * (h // 2) * (w // 2)
self.transformer_infer.mask_map = MaskMap(video_token_num, c)
embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
self.scheduler.noise_pred = noise_pred_cond
if self.clean_cuda_cache:
del x, embed, pre_infer_out, noise_pred_cond, grid_sizes
torch.cuda.empty_cache()
if self.config["enable_cfg"]:
embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
if self.config["cfg_parallel"]:
# ==================== CFG Parallel Processing ====================
cfg_p_group = self.config["device_mesh"].get_group(mesh_dim="cfg_p")
assert dist.get_world_size(cfg_p_group) == 2, "cfg_p_world_size must be equal to 2"
cfg_p_rank = dist.get_rank(cfg_p_group)
if cfg_p_rank == 0:
noise_pred = self._infer_cond_uncond(inputs, positive=True)
else:
noise_pred = self._infer_cond_uncond(inputs, positive=False)
self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (self.scheduler.noise_pred - noise_pred_uncond)
noise_pred_list = [torch.zeros_like(noise_pred) for _ in range(2)]
dist.all_gather(noise_pred_list, noise_pred, group=cfg_p_group)
noise_pred_cond = noise_pred_list[0] # cfg_p_rank == 0
noise_pred_uncond = noise_pred_list[1] # cfg_p_rank == 1
else:
# ==================== CFG Processing ====================
noise_pred_cond = self._infer_cond_uncond(inputs, positive=True)
noise_pred_uncond = self._infer_cond_uncond(inputs, positive=False)
if self.clean_cuda_cache:
del x, embed, pre_infer_out, noise_pred_uncond, grid_sizes
torch.cuda.empty_cache()
self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (noise_pred_cond - noise_pred_uncond)
else:
# ==================== No CFG ====================
self.scheduler.noise_pred = self._infer_cond_uncond(inputs, positive=True)
if self.cpu_offload:
if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1:
@@ -370,24 +359,62 @@
self.post_weight.to_cpu()
@torch.no_grad()
def infer_with_cfg_parallel(self, inputs):
assert self.config["enable_cfg"], "enable_cfg must be True"
cfg_p_group = self.config["device_mesh"].get_group(mesh_dim="cfg_p")
assert dist.get_world_size(cfg_p_group) == 2, f"cfg_p_world_size must be equal to 2"
cfg_p_rank = dist.get_rank(cfg_p_group)
if cfg_p_rank == 0:
embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
noise_pred = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
else:
embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
noise_pred = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
def _infer_cond_uncond(self, inputs, positive=True):
pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=positive)
if self.config["seq_parallel"]:
pre_infer_out = self._seq_parallel_pre_process(pre_infer_out)
x = self.transformer_infer.infer(self.transformer_weights, pre_infer_out)
if self.config["seq_parallel"]:
x = self._seq_parallel_post_process(x)
noise_pred = self.post_infer.infer(self.post_weight, x, pre_infer_out)[0]
if self.clean_cuda_cache:
del x, pre_infer_out
torch.cuda.empty_cache()
return noise_pred
@torch.no_grad()
def _seq_parallel_pre_process(self, pre_infer_out):
embed, x, embed0 = pre_infer_out.embed, pre_infer_out.x, pre_infer_out.embed0
world_size = dist.get_world_size(self.seq_p_group)
cur_rank = dist.get_rank(self.seq_p_group)
padding_size = (world_size - (x.shape[0] % world_size)) % world_size
if padding_size > 0:
# pad the first (token) dimension with F.pad
x = F.pad(x, (0, 0, 0, padding_size)) # pad spec runs from the last dim to the first
x = torch.chunk(x, world_size, dim=0)[cur_rank]
if self.config["model_cls"].startswith("wan2.2"):
padding_size = (world_size - (embed0.shape[0] % world_size)) % world_size
if padding_size > 0:
embed0 = F.pad(embed0, (0, 0, 0, 0, 0, padding_size)) # pad spec runs from the last dim to the first
embed = F.pad(embed, (0, 0, 0, padding_size))
pre_infer_out.x = x
pre_infer_out.embed = embed
pre_infer_out.embed0 = embed0
return pre_infer_out
@torch.no_grad()
def _seq_parallel_post_process(self, x):
world_size = dist.get_world_size(self.seq_p_group)
# buffers to collect the output from every rank
gathered_x = [torch.empty_like(x) for _ in range(world_size)]
# gather the local output from all ranks in the sequence-parallel group
dist.all_gather(gathered_x, x, group=self.seq_p_group)
noise_pred_list = [torch.zeros_like(noise_pred) for _ in range(2)]
dist.all_gather(noise_pred_list, noise_pred, group=cfg_p_group)
# concatenate the gathered outputs along the sequence dimension
combined_output = torch.cat(gathered_x, dim=0)
noise_pred_cond = noise_pred_list[0] # cfg_p_rank == 0
noise_pred_uncond = noise_pred_list[1] # cfg_p_rank == 1
self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (noise_pred_cond - noise_pred_uncond)
return combined_output # return the merged output
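Taken together, _seq_parallel_pre_process and _seq_parallel_post_process implement a pad -> chunk -> (transformer blocks) -> all_gather -> cat round trip over the token dimension. A single-process sketch of the shape bookkeeping (no process group involved):

import torch
import torch.nn.functional as F

x = torch.randn(10, 8)                                               # 10 tokens, hidden size 8
world_size = 4
padding_size = (world_size - x.shape[0] % world_size) % world_size   # -> 2
x_padded = F.pad(x, (0, 0, 0, padding_size))                         # (12, 8): token dim padded
chunks = list(torch.chunk(x_padded, world_size, dim=0))              # 4 chunks of (3, 8), one per rank
# each rank would run the transformer blocks on its own chunk, then:
combined = torch.cat(chunks, dim=0)                                  # what dist.all_gather + torch.cat restores
assert combined.shape == x_padded.shape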
from lightx2v.common.modules.weight_module import WeightModule
from lightx2v.utils.registry_factory import (
LN_WEIGHT_REGISTER,
MM_WEIGHT_REGISTER,
TENSOR_REGISTER,
)
class WanPostWeights(WeightModule):
def __init__(self, config):
super().__init__()
self.config = config
self.register_parameter(
"norm",
LN_WEIGHT_REGISTER["Default"](),
)
self.add_module("head", MM_WEIGHT_REGISTER["Default"]("head.head.weight", "head.head.bias"))
self.register_parameter("head_modulation", TENSOR_REGISTER["Default"]("head.modulation"))
@@ -26,6 +26,11 @@ class WanTransformerWeights(WeightModule):
self.blocks = WeightModuleList([WanTransformerAttentionBlock(i, self.task, self.mm_type, self.config) for i in range(self.blocks_num)])
self.add_module("blocks", self.blocks)
# post blocks weights
self.register_parameter("norm", LN_WEIGHT_REGISTER["Default"]())
self.add_module("head", MM_WEIGHT_REGISTER["Default"]("head.head.weight", "head.head.bias"))
self.register_parameter("head_modulation", TENSOR_REGISTER["Default"]("head.modulation"))
def clear(self):
for block in self.blocks:
for phase in block.compute_phases: