Commit 9067043e authored by helloyongyang

update wan2.2 moe parallel

parent dd958c79
configs/dist_infer/wan22_moe_i2v_cfg.json
{
    "infer_steps": 40,
    "target_video_length": 81,
    "text_len": 512,
    "target_height": 720,
    "target_width": 1280,
    "self_attn_1_type": "flash_attn3",
    "cross_attn_1_type": "flash_attn3",
    "cross_attn_2_type": "flash_attn3",
    "seed": 42,
    "sample_guide_scale": [3.5, 3.5],
    "sample_shift": 5.0,
    "enable_cfg": true,
    "cpu_offload": true,
    "offload_granularity": "model",
    "boundary": 0.900,
    "use_image_encoder": false,
    "parallel": {
        "cfg_p_size": 2
    }
}
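With enable_cfg on, "cfg_p_size": 2 runs the conditional and unconditional guidance branches on separate ranks, which is why the matching launch script below uses two GPUs. A minimal sketch of the idea (helper name and layout are hypothetical, not lightx2v's actual implementation):

import torch
import torch.distributed as dist

def cfg_parallel_step(model, latents, cond, uncond, guide_scale, cfg_group):
    # Hypothetical sketch: each of the two ranks in the CFG group
    # evaluates one guidance branch.
    rank = dist.get_rank(cfg_group)
    ctx = cond if rank == 0 else uncond
    noise = model(latents, ctx)

    # Exchange branch outputs so both ranks can apply guidance locally.
    gathered = [torch.empty_like(noise) for _ in range(2)]
    dist.all_gather(gathered, noise, group=cfg_group)
    noise_cond, noise_uncond = gathered
    return noise_uncond + guide_scale * (noise_cond - noise_uncond)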
configs/dist_infer/wan22_moe_i2v_cfg_ulysses.json
{
    "infer_steps": 40,
    "target_video_length": 81,
    "text_len": 512,
    "target_height": 720,
    "target_width": 1280,
    "self_attn_1_type": "flash_attn3",
    "cross_attn_1_type": "flash_attn3",
    "cross_attn_2_type": "flash_attn3",
    "seed": 42,
    "sample_guide_scale": [3.5, 3.5],
    "sample_shift": 5.0,
    "enable_cfg": true,
    "cpu_offload": true,
    "offload_granularity": "model",
    "boundary": 0.900,
    "use_image_encoder": false,
    "parallel": {
        "seq_p_size": 4,
        "seq_p_attn_type": "ulysses",
        "cfg_p_size": 2
    }
}
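This config composes both schemes, so the world size must be seq_p_size * cfg_p_size = 8, matching the 8-GPU script below. A hedged sketch of how the two process groups could be carved out of the world (the rank layout is an assumption; lightx2v's actual grouping may differ):

import torch.distributed as dist

def build_parallel_groups(seq_p_size: int, cfg_p_size: int):
    # Assumed layout: consecutive blocks of seq_p_size ranks form a
    # sequence-parallel group; ranks at the same offset inside their
    # block pair up for CFG parallelism.
    world_size = dist.get_world_size()
    assert world_size == seq_p_size * cfg_p_size
    rank = dist.get_rank()

    seq_group = cfg_group = None
    # Sequence-parallel groups: [0..3], [4..7] for the 4x2 case.
    for c in range(cfg_p_size):
        ranks = list(range(c * seq_p_size, (c + 1) * seq_p_size))
        g = dist.new_group(ranks)  # every rank must call new_group
        if rank in ranks:
            seq_group = g
    # CFG groups: (0,4), (1,5), (2,6), (3,7) for the 4x2 case.
    for s in range(seq_p_size):
        ranks = [c * seq_p_size + s for c in range(cfg_p_size)]
        g = dist.new_group(ranks)
        if rank in ranks:
            cfg_group = g
    return seq_group, cfg_group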
configs/dist_infer/wan22_moe_i2v_ulysses.json
{
    "infer_steps": 40,
    "target_video_length": 81,
    "text_len": 512,
    "target_height": 720,
    "target_width": 1280,
    "self_attn_1_type": "flash_attn3",
    "cross_attn_1_type": "flash_attn3",
    "cross_attn_2_type": "flash_attn3",
    "seed": 42,
    "sample_guide_scale": [3.5, 3.5],
    "sample_shift": 5.0,
    "enable_cfg": true,
    "cpu_offload": true,
    "offload_granularity": "model",
    "boundary": 0.900,
    "use_image_encoder": false,
    "parallel": {
        "seq_p_size": 4,
        "seq_p_attn_type": "ulysses"
    }
}
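"seq_p_attn_type": "ulysses" shards the token sequence across 4 ranks and uses all-to-all exchanges so each rank attends over the full sequence for a subset of heads. A simplified sketch of the Ulysses pattern (shapes and helper names are illustrative, not the repo's implementation):

import torch
import torch.distributed as dist

def ulysses_attention(q, k, v, seq_p_group):
    # q/k/v: [seq_local, heads, dim] with seq_local = seq_len / p.
    p = dist.get_world_size(seq_p_group)

    def a2a(x, scatter_dim, gather_dim):
        # All-to-all: scatter one dim across ranks, gather the other.
        chunks = [c.contiguous() for c in x.chunk(p, dim=scatter_dim)]
        out = [torch.empty_like(chunks[0]) for _ in range(p)]
        dist.all_to_all(out, chunks, group=seq_p_group)
        return torch.cat(out, dim=gather_dim)

    # Trade the sequence shard for a head shard: [s/p, h, d] -> [s, h/p, d].
    q, k, v = (a2a(t, scatter_dim=1, gather_dim=0) for t in (q, k, v))
    # Ordinary attention over the full sequence, heads as the batch dim.
    o = torch.nn.functional.scaled_dot_product_attention(
        q.transpose(0, 1), k.transpose(0, 1), v.transpose(0, 1)
    ).transpose(0, 1)
    # Restore the sequence-sharded, full-head layout: [s, h/p, d] -> [s/p, h, d].
    return a2a(o, scatter_dim=0, gather_dim=1)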
@@ -7,14 +7,12 @@ import os
 import safetensors
 import torch
+import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F
 from diffusers.models.embeddings import TimestepEmbedding, Timesteps
 from einops import rearrange
 from transformers import AutoModel
-import torch.distributed as dist
 from loguru import logger
 from lightx2v.utils.envs import *
......
@@ -13,8 +13,6 @@ from lightx2v.models.networks.wan.weights.transformer_weights import (
     WanTransformerWeights,
 )
 from loguru import logger


 class WanAudioModel(WanModel):
     pre_weight_class = WanPreWeights
......
@@ -55,7 +55,7 @@ class WanTransformerDistInfer(WanTransformerInfer):
         world_size = dist.get_world_size(self.seq_p_group)
         cur_rank = dist.get_rank(self.seq_p_group)
         freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
-        f, h, w = grid_sizes[0].tolist()
+        f, h, w = grid_sizes[0]
         seq_len = f * h * w
         freqs_i = torch.cat(
             [
@@ -75,7 +75,7 @@ class WanTransformerDistInfer(WanTransformerInfer):
         world_size = dist.get_world_size(self.seq_p_group)
         cur_rank = dist.get_rank(self.seq_p_group)
         freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
-        f, h, w = grid_sizes[0].tolist()
+        f, h, w = grid_sizes[0]
         valid_token_length = f * h * w
         f = f + 1
         seq_len = f * h * w
......
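Both hunks above drop .tolist(), implying grid_sizes[0] now arrives as a plain (f, h, w) triple rather than a tensor row. For context, a hedged sketch of the per-rank RoPE slice implied by world_size, cur_rank, and seq_len (the actual freqs_i construction is elided in the diff):

def shard_freqs(freqs_i, seq_len, world_size, cur_rank):
    # Hypothetical helper: keep only this rank's contiguous slice of the
    # RoPE table; assumes seq_len is divisible by world_size.
    chunk = seq_len // world_size
    return freqs_i[cur_rank * chunk : (cur_rank + 1) * chunk]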
@@ -11,6 +11,7 @@ from lightx2v.utils.envs import *
 from .utils import apply_rotary_emb, apply_rotary_emb_chunk, compute_freqs, compute_freqs_audio


 class WanTransformerInfer(BaseTransformerInfer):
     def __init__(self, config):
         self.config = config
......
@@ -113,7 +113,7 @@ class DefaultRunner(BaseRunner):
         with ProfilingContext4Debug("step_pre"):
             self.model.scheduler.step_pre(step_index=step_index)

-        with ProfilingContext4Debug("infer"):
+        with ProfilingContext4Debug("infer_main"):
             self.model.infer(self.inputs)

         with ProfilingContext4Debug("step_post"):
@@ -239,6 +239,7 @@ class DefaultRunner(BaseRunner):
             save_to_video(images, self.config.save_video_path, fps=fps, method="ffmpeg")  # type: ignore
         else:
             cache_video(tensor=images, save_file=self.config.save_video_path, fps=fps, nrow=1, normalize=True, value_range=(-1, 1))
+        logger.info("Video saved successfully.")
         del latents, generator
         torch.cuda.empty_cache()
......
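The runner hunks rename the profiling scope "infer" to "infer_main" and log once the video is written. ProfilingContext4Debug itself is not shown in this diff; a minimal stand-in illustrating the pattern (hypothetical implementation, the real one lives elsewhere in the repo):

import time
from contextlib import contextmanager
from loguru import logger

@contextmanager
def profiling_context(name: str):
    # Stand-in for ProfilingContext4Debug: log wall time per named scope.
    start = time.perf_counter()
    try:
        yield
    finally:
        logger.debug(f"[{name}] took {time.perf_counter() - start:.3f}s")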
#!/bin/bash
# set lightx2v_path and model_path first
lightx2v_path=
model_path=
export CUDA_VISIBLE_DEVICES=0,1
# set environment variables
source ${lightx2v_path}/scripts/base/base.sh
torchrun --nproc_per_node=2 -m lightx2v.infer \
--model_cls wan2.2_moe \
--task i2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/dist_infer/wan22_moe_i2v_cfg.json \
--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \
--negative_prompt "Vivid colors, overexposed, static, blurred details, subtitles, style, artwork, painting, frame, stillness, overall gray, worst quality, low quality, JPEG compression artifacts, ugly, mutilated, extra fingers, poorly drawn hands, poorly drawn face, deformed, disfigured, malformed limbs, fused fingers, motionless frame, cluttered background, three legs, many people in the background, walking backwards" \
--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan22_moe_i2v_parallel_cfg.mp4
#!/bin/bash
# set lightx2v_path and model_path first
lightx2v_path=
model_path=
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
# set environment variables
source ${lightx2v_path}/scripts/base/base.sh
torchrun --nproc_per_node=8 -m lightx2v.infer \
--model_cls wan2.2_moe \
--task i2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/dist_infer/wan22_moe_i2v_cfg_ulysses.json \
--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \
--negative_prompt "Vivid colors, overexposed, static, blurred details, subtitles, style, artwork, painting, frame, stillness, overall gray, worst quality, low quality, JPEG compression artifacts, ugly, mutilated, extra fingers, poorly drawn hands, poorly drawn face, deformed, disfigured, malformed limbs, fused fingers, motionless frame, cluttered background, three legs, many people in the background, walking backwards" \
--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan22_moe_i2v_parallel_cfg_ulysses.mp4
#!/bin/bash
# set lightx2v_path and model_path first
lightx2v_path=
model_path=
export CUDA_VISIBLE_DEVICES=0,1,2,3
# set environment variables
source ${lightx2v_path}/scripts/base/base.sh
torchrun --nproc_per_node=4 -m lightx2v.infer \
--model_cls wan2.2_moe \
--task i2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/dist_infer/wan22_moe_i2v_ulysses.json \
--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \
--negative_prompt "Vivid colors, overexposed, static, blurred details, subtitles, style, artwork, painting, frame, stillness, overall gray, worst quality, low quality, JPEG compression artifacts, ugly, mutilated, extra fingers, poorly drawn hands, poorly drawn face, deformed, disfigured, malformed limbs, fused fingers, motionless frame, cluttered background, three legs, many people in the background, walking backwards" \
--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan22_moe_i2v_parallel_ulysses.mp4
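Across the three scripts, --nproc_per_node always equals the product of the parallel degrees in the referenced config (2 = cfg 2; 8 = seq 4 x cfg 2; 4 = seq 4). A quick self-check of that invariant (assumption: lightx2v requires it to hold):

def expected_nproc(parallel: dict) -> int:
    # Total ranks needed = sequence-parallel degree x CFG-parallel degree.
    return parallel.get("seq_p_size", 1) * parallel.get("cfg_p_size", 1)

assert expected_nproc({"cfg_p_size": 2}) == 2                   # first script
assert expected_nproc({"seq_p_size": 4, "cfg_p_size": 2}) == 8  # second script
assert expected_nproc({"seq_p_size": 4}) == 4                   # third script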