Commit 9067043e authored by helloyongyang's avatar helloyongyang
Browse files

update wan2.2 moe parallel

parent dd958c79
{
"infer_steps": 40,
"target_video_length": 81,
"text_len": 512,
"target_height": 720,
"target_width": 1280,
"self_attn_1_type": "flash_attn3",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"seed": 42,
"sample_guide_scale": [3.5, 3.5],
"sample_shift": 5.0,
"enable_cfg": true,
"cpu_offload": true,
"offload_granularity": "model",
"boundary": 0.900,
"use_image_encoder": false,
"parallel": {
"cfg_p_size": 2
}
}
{
"infer_steps": 40,
"target_video_length": 81,
"text_len": 512,
"target_height": 720,
"target_width": 1280,
"self_attn_1_type": "flash_attn3",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"seed": 42,
"sample_guide_scale": [3.5, 3.5],
"sample_shift": 5.0,
"enable_cfg": true,
"cpu_offload": true,
"offload_granularity": "model",
"boundary": 0.900,
"use_image_encoder": false,
"parallel": {
"seq_p_size": 4,
"seq_p_attn_type": "ulysses",
"cfg_p_size": 2
}
}
{
"infer_steps": 40,
"target_video_length": 81,
"text_len": 512,
"target_height": 720,
"target_width": 1280,
"self_attn_1_type": "flash_attn3",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"seed": 42,
"sample_guide_scale": [3.5, 3.5],
"sample_shift": 5.0,
"enable_cfg": true,
"cpu_offload": true,
"offload_granularity": "model",
"boundary": 0.900,
"use_image_encoder": false,
"parallel": {
"seq_p_size": 4,
"seq_p_attn_type": "ulysses"
}
}
...@@ -7,14 +7,12 @@ import os ...@@ -7,14 +7,12 @@ import os
import safetensors import safetensors
import torch import torch
import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from diffusers.models.embeddings import TimestepEmbedding, Timesteps from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from einops import rearrange from einops import rearrange
from transformers import AutoModel from transformers import AutoModel
import torch.distributed as dist
from loguru import logger
from lightx2v.utils.envs import * from lightx2v.utils.envs import *
......
...@@ -13,8 +13,6 @@ from lightx2v.models.networks.wan.weights.transformer_weights import ( ...@@ -13,8 +13,6 @@ from lightx2v.models.networks.wan.weights.transformer_weights import (
WanTransformerWeights, WanTransformerWeights,
) )
from loguru import logger
class WanAudioModel(WanModel): class WanAudioModel(WanModel):
pre_weight_class = WanPreWeights pre_weight_class = WanPreWeights
......
...@@ -55,7 +55,7 @@ class WanTransformerDistInfer(WanTransformerInfer): ...@@ -55,7 +55,7 @@ class WanTransformerDistInfer(WanTransformerInfer):
world_size = dist.get_world_size(self.seq_p_group) world_size = dist.get_world_size(self.seq_p_group)
cur_rank = dist.get_rank(self.seq_p_group) cur_rank = dist.get_rank(self.seq_p_group)
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
f, h, w = grid_sizes[0].tolist() f, h, w = grid_sizes[0]
seq_len = f * h * w seq_len = f * h * w
freqs_i = torch.cat( freqs_i = torch.cat(
[ [
...@@ -75,7 +75,7 @@ class WanTransformerDistInfer(WanTransformerInfer): ...@@ -75,7 +75,7 @@ class WanTransformerDistInfer(WanTransformerInfer):
world_size = dist.get_world_size(self.seq_p_group) world_size = dist.get_world_size(self.seq_p_group)
cur_rank = dist.get_rank(self.seq_p_group) cur_rank = dist.get_rank(self.seq_p_group)
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
f, h, w = grid_sizes[0].tolist() f, h, w = grid_sizes[0]
valid_token_length = f * h * w valid_token_length = f * h * w
f = f + 1 f = f + 1
seq_len = f * h * w seq_len = f * h * w
......
...@@ -11,6 +11,7 @@ from lightx2v.utils.envs import * ...@@ -11,6 +11,7 @@ from lightx2v.utils.envs import *
from .utils import apply_rotary_emb, apply_rotary_emb_chunk, compute_freqs, compute_freqs_audio from .utils import apply_rotary_emb, apply_rotary_emb_chunk, compute_freqs, compute_freqs_audio
class WanTransformerInfer(BaseTransformerInfer): class WanTransformerInfer(BaseTransformerInfer):
def __init__(self, config): def __init__(self, config):
self.config = config self.config = config
......
...@@ -113,7 +113,7 @@ class DefaultRunner(BaseRunner): ...@@ -113,7 +113,7 @@ class DefaultRunner(BaseRunner):
with ProfilingContext4Debug("step_pre"): with ProfilingContext4Debug("step_pre"):
self.model.scheduler.step_pre(step_index=step_index) self.model.scheduler.step_pre(step_index=step_index)
with ProfilingContext4Debug("infer"): with ProfilingContext4Debug("infer_main"):
self.model.infer(self.inputs) self.model.infer(self.inputs)
with ProfilingContext4Debug("step_post"): with ProfilingContext4Debug("step_post"):
...@@ -239,6 +239,7 @@ class DefaultRunner(BaseRunner): ...@@ -239,6 +239,7 @@ class DefaultRunner(BaseRunner):
save_to_video(images, self.config.save_video_path, fps=fps, method="ffmpeg") # type: ignore save_to_video(images, self.config.save_video_path, fps=fps, method="ffmpeg") # type: ignore
else: else:
cache_video(tensor=images, save_file=self.config.save_video_path, fps=fps, nrow=1, normalize=True, value_range=(-1, 1)) cache_video(tensor=images, save_file=self.config.save_video_path, fps=fps, nrow=1, normalize=True, value_range=(-1, 1))
logger.info(f"Video saved successfully.")
del latents, generator del latents, generator
torch.cuda.empty_cache() torch.cuda.empty_cache()
......
#!/bin/bash
# Launch Wan2.2 MoE image-to-video inference on 2 GPUs using CFG parallelism
# (config: wan22_moe_i2v_cfg.json, which sets parallel.cfg_p_size=2).
# Fill in lightx2v_path (repo root) and model_path (checkpoint dir) before running.
lightx2v_path=
model_path=
export CUDA_VISIBLE_DEVICES=0,1
# Load shared environment setup (e.g. PYTHONPATH) from the repo's base script.
source ${lightx2v_path}/scripts/base/base.sh
torchrun --nproc_per_node=2 -m lightx2v.infer \
--model_cls wan2.2_moe \
--task i2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/dist_infer/wan22_moe_i2v_cfg.json \
--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \
--negative_prompt "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan22_moe_i2v_parallel_cfg.mp4
#!/bin/bash
# Launch Wan2.2 MoE image-to-video inference on 8 GPUs combining Ulysses
# sequence parallelism with CFG parallelism
# (config: wan22_moe_i2v_cfg_ulysses.json, seq_p_size=4 x cfg_p_size=2).
# Fill in lightx2v_path (repo root) and model_path (checkpoint dir) before running.
lightx2v_path=
model_path=
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
# Load shared environment setup (e.g. PYTHONPATH) from the repo's base script.
source ${lightx2v_path}/scripts/base/base.sh
torchrun --nproc_per_node=8 -m lightx2v.infer \
--model_cls wan2.2_moe \
--task i2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/dist_infer/wan22_moe_i2v_cfg_ulysses.json \
--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \
--negative_prompt "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan22_moe_i2v_parallel_cfg_ulysses.mp4
#!/bin/bash
# Launch Wan2.2 MoE image-to-video inference on 4 GPUs using Ulysses
# sequence parallelism only
# (config: wan22_moe_i2v_ulysses.json, which sets parallel.seq_p_size=4).
# Fill in lightx2v_path (repo root) and model_path (checkpoint dir) before running.
lightx2v_path=
model_path=
export CUDA_VISIBLE_DEVICES=0,1,2,3
# Load shared environment setup (e.g. PYTHONPATH) from the repo's base script.
source ${lightx2v_path}/scripts/base/base.sh
torchrun --nproc_per_node=4 -m lightx2v.infer \
--model_cls wan2.2_moe \
--task i2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/dist_infer/wan22_moe_i2v_ulysses.json \
--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \
--negative_prompt "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan22_moe_i2v_parallel_ulysses.mp4
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment