Commit 4f71a2b0 authored by mashun1

wan2.1
import os
import torch
import diffusers
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
from packaging import version
from xfuser.logger import init_logger
logger = init_logger(__name__)
if TYPE_CHECKING:
MASTER_ADDR: str = ""
MASTER_PORT: Optional[int] = None
CUDA_HOME: Optional[str] = None
LOCAL_RANK: int = 0
CUDA_VISIBLE_DEVICES: Optional[str] = None
XDIT_LOGGING_LEVEL: str = "INFO"
    CUDA_VERSION: str  # hardcoded to "gfx928" in this build (see below)
TORCH_VERSION: version.Version
environment_variables: Dict[str, Callable[[], Any]] = {
# ================== Runtime Env Vars ==================
# used in distributed environment to determine the master address
"MASTER_ADDR": lambda: os.getenv("MASTER_ADDR", ""),
# used in distributed environment to manually set the communication port
"MASTER_PORT": lambda: (
int(os.getenv("MASTER_PORT", "0")) if "MASTER_PORT" in os.environ else None
),
# path to cudatoolkit home directory, under which should be bin, include,
# and lib directories.
"CUDA_HOME": lambda: os.environ.get("CUDA_HOME", None),
# local rank of the process in the distributed setting, used to determine
# the GPU device id
"LOCAL_RANK": lambda: int(os.environ.get("LOCAL_RANK", "0")),
# used to control the visible devices in the distributed setting
"CUDA_VISIBLE_DEVICES": lambda: os.environ.get("CUDA_VISIBLE_DEVICES", None),
# this is used for configuring the default logging level
"XDIT_LOGGING_LEVEL": lambda: os.getenv("XDIT_LOGGING_LEVEL", "INFO"),
}
variables: Dict[str, Callable[[], Any]] = {
# ================== Other Vars ==================
# used in version checking
# "CUDA_VERSION": lambda: version.parse(torch.version.cuda),
"CUDA_VERSION": "gfx928",
"TORCH_VERSION": lambda: version.parse(
version.parse(torch.__version__).base_version
),
}
class PackagesEnvChecker:
_instance = None
def __new__(cls):
if cls._instance is None:
cls._instance = super(PackagesEnvChecker, cls).__new__(cls)
cls._instance.initialize()
return cls._instance
def initialize(self):
self.packages_info = {
"has_flash_attn": self.check_flash_attn(),
"has_long_ctx_attn": self.check_long_ctx_attn(),
"diffusers_version": self.check_diffusers_version(),
}
    def check_flash_attn(self):
        try:
            # flash_attn needs a CUDA device; on CPU-only hosts
            # torch.cuda.get_device_name() would raise, so bail out early
            if not torch.cuda.is_available():
                return False
            gpu_name = torch.cuda.get_device_name(torch.device("cuda"))
            if "Turing" in gpu_name or "Tesla" in gpu_name or "T4" in gpu_name:
                return False
            else:
                from flash_attn import flash_attn_func
                from flash_attn import __version__

                # compare parsed versions; a raw string comparison would sort
                # "2.10.0" before "2.6.0"
                if version.parse(__version__) < version.parse("2.6.0"):
                    raise ImportError("install flash_attn >= 2.6.0")
                return True
        except ImportError:
            logger.warning(
                'Flash Attention library "flash_attn" not found, '
                "using pytorch attention implementation"
            )
            return False
def check_long_ctx_attn(self):
try:
from yunchang import (
set_seq_parallel_pg,
ring_flash_attn_func,
UlyssesAttention,
LongContextAttention,
LongContextAttentionQKVPacked,
)
return True
        except ImportError:
            logger.warning(
                'Ring Flash Attention library "yunchang" not found, '
                "using pytorch attention implementation"
            )
            return False
    def check_diffusers_version(self):
        installed_version = version.parse(
            version.parse(diffusers.__version__).base_version
        )
        if installed_version < version.parse("0.30.0"):
            raise RuntimeError(
                f"Diffusers version {installed_version} is not supported, "
                f"please upgrade to version >= 0.30.0"
            )
        return installed_version
def get_packages_info(self):
return self.packages_info
PACKAGES_CHECKER = PackagesEnvChecker()
def __getattr__(name):
# lazy evaluation of environment variables
if name in environment_variables:
return environment_variables[name]()
if name in variables:
return variables[name]()
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
def __dir__():
    # advertise everything that __getattr__ can resolve, not just env vars
    return list(environment_variables.keys()) + list(variables.keys())
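For reference, a minimal usage sketch of the module above, assuming it is installed as `xfuser/envs.py` (which is what the patch script below does); attribute reads go through the lazy `__getattr__`:

```python
# minimal sketch, assuming the file above is installed as xfuser/envs.py
import xfuser.envs as envs

# each attribute access re-reads the environment via __getattr__,
# so variables exported after import are still picked up
print(envs.LOCAL_RANK)      # 0 unless LOCAL_RANK is set
print(envs.CUDA_VERSION)    # "gfx928" in this hardcoded build

# PACKAGES_CHECKER is a module-level singleton shared by all importers
info = envs.PACKAGES_CHECKER.get_packages_info()
print(info["has_flash_attn"], info["diffusers_version"])
```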
#!/bin/bash
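# overwrite the pip-installed xfuser files with the locally patched versions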
cp modified/config.py /usr/local/lib/python3.10/site-packages/xfuser/config/
cp modified/envs.py /usr/local/lib/python3.10/site-packages/xfuser/
# torch>=2.4.0
# torchvision>=0.19.0
opencv-python>=4.9.0.80
diffusers>=0.31.0
transformers>=4.49.0
tokenizers>=0.20.3
accelerate>=1.1.1
tqdm
imageio
easydict
ftfy
dashscope
imageio-ffmpeg
# flash_attn
gradio>=5.0.0
numpy==1.24.4
yunchang
DistVAE
Put all the models (Wan2.1-T2V-1.3B, Wan2.1-T2V-14B, Wan2.1-I2V-14B-480P, Wan2.1-I2V-14B-720P) in one folder, then pass that folder and the maximum number of GPUs you want to use:
```bash
bash ./test.sh <local model dir> <gpu number>
```
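The script expects each checkpoint in its own subdirectory of `<local model dir>`, named exactly as listed above. A concrete invocation might look like this (the `/data/wan_models` path is only an example):

```bash
# example layout (placeholder path):
#   /data/wan_models/Wan2.1-T2V-1.3B
#   /data/wan_models/Wan2.1-T2V-14B
#   /data/wan_models/Wan2.1-I2V-14B-480P
#   /data/wan_models/Wan2.1-I2V-14B-720P
bash ./test.sh /data/wan_models 8
```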
#!/bin/bash
if [ "$#" -eq 2 ]; then
MODEL_DIR=$(realpath "$1")
GPUS=$2
else
echo "Usage: $0 <local model dir> <gpu number>"
exit 1
fi
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
REPO_ROOT="$(dirname "$SCRIPT_DIR")"
cd "$REPO_ROOT" || exit 1
PY_FILE=./generate.py
function t2v_1_3B() {
    T2V_1_3B_CKPT_DIR="$MODEL_DIR/Wan2.1-T2V-1.3B"
    # 1-GPU Test
    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_1_3B 1-GPU Test: "
    python "$PY_FILE" --task t2v-1.3B --size "480*832" --ckpt_dir "$T2V_1_3B_CKPT_DIR"
    # Multiple GPU Test
    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_1_3B Multiple GPU Test: "
    torchrun --nproc_per_node="$GPUS" "$PY_FILE" --task t2v-1.3B --ckpt_dir "$T2V_1_3B_CKPT_DIR" --size "832*480" --dit_fsdp --t5_fsdp --ulysses_size "$GPUS"
    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_1_3B Multiple GPU, prompt extend local_qwen: "
    torchrun --nproc_per_node="$GPUS" "$PY_FILE" --task t2v-1.3B --ckpt_dir "$T2V_1_3B_CKPT_DIR" --size "832*480" --dit_fsdp --t5_fsdp --ulysses_size "$GPUS" --use_prompt_extend --prompt_extend_model "Qwen/Qwen2.5-3B-Instruct" --prompt_extend_target_lang "en"
    if [ -n "${DASH_API_KEY+x}" ]; then
        echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_1_3B Multiple GPU, prompt extend dashscope: "
        torchrun --nproc_per_node="$GPUS" "$PY_FILE" --task t2v-1.3B --ckpt_dir "$T2V_1_3B_CKPT_DIR" --size "832*480" --dit_fsdp --t5_fsdp --ulysses_size "$GPUS" --use_prompt_extend --prompt_extend_method "dashscope"
    else
        echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> No DASH_API_KEY found, skipping the dashscope extend test."
    fi
}
function t2v_14B() {
    T2V_14B_CKPT_DIR="$MODEL_DIR/Wan2.1-T2V-14B"
    # 1-GPU Test
    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_14B 1-GPU Test: "
    python "$PY_FILE" --task t2v-14B --size "480*832" --ckpt_dir "$T2V_14B_CKPT_DIR"
    # Multiple GPU Test
    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_14B Multiple GPU Test: "
    torchrun --nproc_per_node="$GPUS" "$PY_FILE" --task t2v-14B --ckpt_dir "$T2V_14B_CKPT_DIR" --size "832*480" --dit_fsdp --t5_fsdp --ulysses_size "$GPUS"
    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_14B Multiple GPU, prompt extend local_qwen: "
    torchrun --nproc_per_node="$GPUS" "$PY_FILE" --task t2v-14B --ckpt_dir "$T2V_14B_CKPT_DIR" --size "832*480" --dit_fsdp --t5_fsdp --ulysses_size "$GPUS" --use_prompt_extend --prompt_extend_model "Qwen/Qwen2.5-3B-Instruct" --prompt_extend_target_lang "en"
}
function t2i_14B() {
    T2V_14B_CKPT_DIR="$MODEL_DIR/Wan2.1-T2V-14B"
    # 1-GPU Test
    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2i_14B 1-GPU Test: "
    python "$PY_FILE" --task t2i-14B --size "480*832" --ckpt_dir "$T2V_14B_CKPT_DIR"
    # Multiple GPU Test
    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2i_14B Multiple GPU Test: "
    torchrun --nproc_per_node="$GPUS" "$PY_FILE" --task t2i-14B --ckpt_dir "$T2V_14B_CKPT_DIR" --size "832*480" --dit_fsdp --t5_fsdp --ulysses_size "$GPUS"
    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2i_14B Multiple GPU, prompt extend local_qwen: "
    torchrun --nproc_per_node="$GPUS" "$PY_FILE" --task t2i-14B --ckpt_dir "$T2V_14B_CKPT_DIR" --size "832*480" --dit_fsdp --t5_fsdp --ulysses_size "$GPUS" --use_prompt_extend --prompt_extend_model "Qwen/Qwen2.5-3B-Instruct" --prompt_extend_target_lang "en"
}
function i2v_14B_480p() {
    I2V_14B_CKPT_DIR="$MODEL_DIR/Wan2.1-I2V-14B-480P"
    # 1-GPU Test
    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B_480p 1-GPU Test: "
    python "$PY_FILE" --task i2v-14B --size "832*480" --ckpt_dir "$I2V_14B_CKPT_DIR"
    # Multiple GPU Test
    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B_480p Multiple GPU Test: "
    torchrun --nproc_per_node="$GPUS" "$PY_FILE" --task i2v-14B --ckpt_dir "$I2V_14B_CKPT_DIR" --size "832*480" --dit_fsdp --t5_fsdp --ulysses_size "$GPUS"
    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B_480p Multiple GPU, prompt extend local_qwen: "
    torchrun --nproc_per_node="$GPUS" "$PY_FILE" --task i2v-14B --ckpt_dir "$I2V_14B_CKPT_DIR" --size "832*480" --dit_fsdp --t5_fsdp --ulysses_size "$GPUS" --use_prompt_extend --prompt_extend_model "Qwen/Qwen2.5-VL-3B-Instruct" --prompt_extend_target_lang "en"
    if [ -n "${DASH_API_KEY+x}" ]; then
        echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B_480p Multiple GPU, prompt extend dashscope: "
        torchrun --nproc_per_node="$GPUS" "$PY_FILE" --task i2v-14B --ckpt_dir "$I2V_14B_CKPT_DIR" --size "832*480" --dit_fsdp --t5_fsdp --ulysses_size "$GPUS" --use_prompt_extend --prompt_extend_method "dashscope"
    else
        echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> No DASH_API_KEY found, skipping the dashscope extend test."
    fi
}
function i2v_14B_720p() {
    I2V_14B_CKPT_DIR="$MODEL_DIR/Wan2.1-I2V-14B-720P"
    # 1-GPU Test
    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B_720p 1-GPU Test: "
    python "$PY_FILE" --task i2v-14B --size "720*1280" --ckpt_dir "$I2V_14B_CKPT_DIR"
    # Multiple GPU Test
    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B_720p Multiple GPU Test: "
    torchrun --nproc_per_node="$GPUS" "$PY_FILE" --task i2v-14B --ckpt_dir "$I2V_14B_CKPT_DIR" --size "720*1280" --dit_fsdp --t5_fsdp --ulysses_size "$GPUS"
}
t2i_14B
t2v_1_3B
t2v_14B
i2v_14B_480p
i2v_14B_720p
from . import configs, distributed, modules
from .image2video import WanI2V
from .text2video import WanT2V
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import copy
import os
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
from .wan_i2v_14B import i2v_14B
from .wan_t2v_1_3B import t2v_1_3B
from .wan_t2v_14B import t2v_14B
# the config of t2i_14B is the same as t2v_14B
t2i_14B = copy.deepcopy(t2v_14B)
t2i_14B.__name__ = 'Config: Wan T2I 14B'
WAN_CONFIGS = {
't2v-14B': t2v_14B,
't2v-1.3B': t2v_1_3B,
'i2v-14B': i2v_14B,
't2i-14B': t2i_14B,
}
SIZE_CONFIGS = {
'720*1280': (720, 1280),
'1280*720': (1280, 720),
'480*832': (480, 832),
'832*480': (832, 480),
'1024*1024': (1024, 1024),
}
MAX_AREA_CONFIGS = {
'720*1280': 720 * 1280,
'1280*720': 1280 * 720,
'480*832': 480 * 832,
'832*480': 832 * 480,
}
SUPPORTED_SIZES = {
't2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
't2v-1.3B': ('480*832', '832*480'),
'i2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
't2i-14B': tuple(SIZE_CONFIGS.keys()),
}
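A minimal sketch of how these tables are meant to be consulted together; the actual argument handling in `generate.py` is assumed, not reproduced here:

```python
# minimal sketch, assuming the dictionaries defined above are in scope
task, size = "i2v-14B", "832*480"

# reject size strings the task was not trained for
assert size in SUPPORTED_SIZES[task], f"{size} is not supported for {task}"

dims = SIZE_CONFIGS[size]          # (832, 480)
max_area = MAX_AREA_CONFIGS[size]  # 832 * 480 = 399360 pixels
cfg = WAN_CONFIGS[task]            # per-task model configuration
```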