"src/libtorchaudio/sox/effects.cpp" did not exist on "a051985f980a36e761e9653d876bfa81124d4516"
Commit 4f71a2b0 authored by mashun1's avatar mashun1
Browse files

wan2.1

parents
Pipeline #2434 failed with stages
in 0 seconds
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
# Wan2.1
## Paper
`Tech Report`
* https://wanxai.com/
## Model Architecture
The model uses the mainstream `Latent Diffusion` architecture: a `3D VAE` compresses and reconstructs the video data, a `DiT` module performs denoising, and text is processed with a `T5` encoder.
![alt text](readme_imgs/arch.png)
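Conceptually, the `T5` encoder turns the prompt into conditioning embeddings, the `DiT` iteratively denoises a latent video tensor under that conditioning, and the `3D VAE` decodes the final latents into frames. A sketch of this flow with stand-in callables and a hypothetical latent shape (illustrative only, not the actual Wan2.1 API):

```python
import torch

def generate_video(t5_encode, dit_denoise, vae_decode, prompt_ids, steps=50):
    context = t5_encode(prompt_ids)            # T5: text -> conditioning embeddings
    latents = torch.randn(1, 4, 8, 16, 16)     # hypothetical (B, C, T, H, W) latent grid
    for step in range(steps):                  # DiT progressively denoises the latents
        latents = latents - 0.02 * dit_denoise(latents, step, context)
    return vae_decode(latents)                 # 3D VAE: latents -> pixel-space frames

# Shape-checking stand-ins only; the real components are large neural networks.
frames = generate_video(lambda p: p, lambda x, s, c: x, lambda z: z, torch.zeros(1, 16))
```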
## Algorithm
Training uses the flow matching algorithm.
![alt text](readme_imgs/alg.png)
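In (rectified) flow matching, a training sample is linearly interpolated between data and Gaussian noise, and the network learns to predict the constant velocity of that path. A minimal sketch of the objective, assuming a toy `model(x_t, t)` that returns a velocity (illustrative, not Wan2.1's actual DiT code):

```python
import torch
import torch.nn.functional as F

def flow_matching_loss(model, x0):
    noise = torch.randn_like(x0)                          # endpoint of the noise path
    t = torch.rand(x0.shape[0], *([1] * (x0.dim() - 1)))  # random time in [0, 1]
    x_t = (1 - t) * x0 + t * noise                        # linear probability path
    v_target = noise - x0                                 # its constant velocity dx_t/dt
    return F.mse_loss(model(x_t, t.flatten()), v_target)

# Placeholder network for illustration; any module mapping (x_t, t) -> velocity works.
loss = flow_matching_loss(lambda x, t: torch.zeros_like(x), torch.randn(4, 3, 8, 8))
```

At inference time, the learned velocity field is integrated from noise back to data, e.g. with the UniPC or DPM++ solvers exposed by `--sample_solver` below.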
## Environment Setup
### Docker (Option 1)
```bash
docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.3.0-ubuntu22.04-dtk24.04.3-py3.10
docker run --shm-size 100g --network=host --name=wan --privileged --device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v <absolute path to project>:/home/ -v /opt/hyhal:/opt/hyhal:ro -it <your IMAGE ID> bash
pip install -r requirements.txt
pip install "xfuser>=0.4" --no-deps torch
bash modified/fix.sh
```
### Dockerfile (Option 2)
```bash
docker build -t <IMAGE_NAME>:<TAG> .
docker run --shm-size 100g --network=host --name=wan --privileged --device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v <absolute path to project>:/home/ -v /opt/hyhal:/opt/hyhal:ro -it <your IMAGE ID> bash
pip install -r requirements.txt
pip install "xfuser>=0.4" --no-deps torch
bash modified/fix.sh
```
### Anaconda (Option 3)
1. The special deep learning libraries this project requires for DCU GPUs can be downloaded and installed from the 光合 developer community: https://developer.hpccube.com/tool/
```
DTK driver: dtk24.04.3
python:python3.10
torch:2.3.0
torchvision:0.18.1
torchaudio:2.1.2
triton:2.1.0
vllm:0.6.2
flash-attn:2.6.1
deepspeed:0.14.2
apex:1.3.0
xformers:0.0.25
transformers:4.48.0
```
2. Install the remaining standard libraries directly from requirements.txt:
```
pip install -r requirements.txt
pip install "xfuser>=0.4" --no-deps torch
# You also need to modify the relevant code as described by the commands in modified/fix.sh
```
## Dataset
## Training
## Inference
### Text-to-Video Generation
1. Single GPU
<!-- # # The 14B model supports 480P/720P
# python generate.py --task t2v-14B --size 832*480 --ckpt_dir models/Wan2.1-T2V-14B --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." -->
```bash
# The 1.3B model supports 480P
python generate.py --task t2v-1.3B --size 832*480 --ckpt_dir models/Wan2.1-T2V-1.3B --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
```
Note: if you run out of GPU memory, try `--offload_model True` and `--t5_cpu`.
2. Multi-GPU
```bash
# 1.3B
torchrun --nproc_per_node=4 generate.py --task t2v-1.3B --size 832*480 --ckpt_dir models/Wan2.1-T2V-1.3B --dit_fsdp --t5_fsdp --ulysses_size 4 --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
# 14B
torchrun --nproc_per_node=4 generate.py --task t2v-14B --size 1280*720 --ckpt_dir models/Wan2.1-T2V-14B --dit_fsdp --t5_fsdp --ulysses_size 4 --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
```
Enable prompt extension:
```bash
<command> --use_prompt_extend --prompt_extend_method 'local_qwen' --prompt_extend_target_lang 'ch' --prompt_extend_model models/Qwen2.5-7B-Instruct
# example
python generate.py --task t2v-1.3B --size 832*480 --ckpt_dir models/Wan2.1-T2V-1.3B --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage" --use_prompt_extend --prompt_extend_method 'local_qwen' --prompt_extend_target_lang 'ch' --prompt_extend_model models/Qwen2.5-7B-Instruct
```
### Image-to-Video Generation
<!-- 1. Single GPU
```bash
python generate.py --task i2v-14B --size 1280*720 --ckpt_dir models/Wan2.1-I2V-14B-720P --image examples/i2v_input.JPG --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
```
Note: if you run out of GPU memory, try `--offload_model True` and `--t5_cpu`. -->
```bash
torchrun --nproc_per_node=4 generate.py --task i2v-14B --size 1280*720 --ckpt_dir models/Wan2.1-I2V-14B-720P --image examples/i2v_input.JPG --dit_fsdp --t5_fsdp --ulysses_size 4 --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
```
Enable prompt extension:
```bash
<command> --use_prompt_extend --prompt_extend_model Qwen/Qwen2.5-VL-7B-Instruct
# example
torchrun --nproc_per_node=4 generate.py --task i2v-14B --size 1280*720 --ckpt_dir models/Wan2.1-I2V-14B-720P --image examples/i2v_input.JPG --dit_fsdp --t5_fsdp --ulysses_size 4 --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." --use_prompt_extend --prompt_extend_model Qwen/Qwen2.5-VL-7B-Instruct
```
### Text-to-Image Generation
<!-- 1. Single GPU
```bash
python generate.py --task t2i-14B --size 1024*1024 --ckpt_dir models/Wan2.1-T2V-14B --prompt '一个朴素端庄的美人'
```
Note: if you run out of GPU memory, try `--offload_model True` and `--t5_cpu`. -->
```bash
torchrun --nproc_per_node=4 generate.py --dit_fsdp --t5_fsdp --ulysses_size 4 --base_seed 0 --frame_num 1 --task t2i-14B --size 1024*1024 --prompt '一个朴素端庄的美人' --ckpt_dir models/Wan2.1-T2V-14B
```
Enable prompt extension:
```bash
<command> --use_prompt_extend --prompt_extend_method 'local_qwen' --prompt_extend_target_lang 'ch' --prompt_extend_model models/Qwen2.5-7B-Instruct
# example
torchrun --nproc_per_node=4 generate.py --dit_fsdp --t5_fsdp --ulysses_size 4 --base_seed 0 --frame_num 1 --task t2i-14B --size 1024*1024 --prompt '一个朴素端庄的美人' --ckpt_dir models/Wan2.1-T2V-14B --use_prompt_extend --prompt_extend_method 'local_qwen' --prompt_extend_target_lang 'ch' --prompt_extend_model models/Qwen2.5-7B-Instruct
```
### WebUI
```bash
python gradio/t2v_1.3B_singleGPU.py --ckpt_dir models/Wan2.1-T2V-1.3B --prompt_extend_method 'local_qwen' --prompt_extend_model models/Qwen2.5-7B-Instruct
```
## Results
|model/task|t2v|i2v|t2i|
|:---:|:---:|:---:|:---:|
|T2V-14B|![](readme_imgs/t2v-14B.gif)||![](readme_imgs/t2i-14B.png)|
|T2V-1.3B|![](readme_imgs/t2v-1.3B.gif)|||
|I2V-14B-720P||![](readme_imgs/i2v-14B_720.gif)||
|I2V-14B-480P||![](readme_imgs/i2v-14B_480.gif)||
### Accuracy
## Application Scenarios
### Algorithm Category
`Video generation`
### Key Application Industries
`E-commerce, education, broadcast media`
## Pretrained Weights
Place the downloaded models in a `models` directory (create it yourself).
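The inference commands above assume a layout like the following (one subdirectory per row of the download table below):

```
models/
├── Wan2.1-T2V-1.3B
├── Wan2.1-T2V-14B
├── Wan2.1-I2V-14B-720P
├── Wan2.1-I2V-14B-480P
├── Qwen2.5-7B-Instruct
└── Qwen2.5-VL-7B-Instruct
```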
|Models|Download link|
|:---:|:---:|
|T2V-14B|[modelscope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B) \| [SCNet high-speed download](http://113.200.138.88:18080/aimodels/Wan-AI/Wan2.1-T2V-14B) |
|I2V-14B-720P|[modelscope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P) \| [SCNet high-speed download](http://113.200.138.88:18080/aimodels/Wan-AI/Wan2.1-I2V-14B-720P) |
|I2V-14B-480P|[modelscope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P) \| [SCNet high-speed download](http://113.200.138.88:18080/aimodels/Wan-AI/Wan2.1-I2V-14B-480P) |
|T2V-1.3B|[modelscope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) \| [SCNet high-speed download](http://113.200.138.88:18080/aimodels/Wan-AI/Wan2.1-T2V-1.3B) |
|Qwen2.5-7B-Instruct|[modelscope](https://www.modelscope.cn/models/Qwen/Qwen2.5-7B-Instruct) \| [SCNet high-speed download](http://113.200.138.88:18080/aimodels/qwen/Qwen2.5-7B-Instruct) |
|Qwen2.5-VL-7B-Instruct|[modelscope](https://www.modelscope.cn/models/Qwen/Qwen2.5-VL-7B-Instruct) \| [SCNet high-speed download](http://113.200.138.88:18080/aimodels/qwen/qwen2.5-vl-7b-instruct) |
## Source Repository & Issue Feedback
* https://developer.sourcefind.cn/codes/modelzoo/wan2.1_pytorch
## References
* https://github.com/Wan-Video/Wan2.1
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
from datetime import datetime
import logging
import os
import sys
import warnings
warnings.filterwarnings('ignore')
import torch, random
import torch.distributed as dist
from PIL import Image
import wan
from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS, SUPPORTED_SIZES
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.utils import cache_video, cache_image, str2bool
EXAMPLE_PROMPT = {
"t2v-1.3B": {
"prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
},
"t2v-14B": {
"prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
},
"t2i-14B": {
"prompt": "一个朴素端庄的美人",
},
"i2v-14B": {
"prompt":
"Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.",
"image":
"examples/i2v_input.JPG",
},
}
def _validate_args(args):
# Basic check
assert args.ckpt_dir is not None, "Please specify the checkpoint directory."
assert args.task in WAN_CONFIGS, f"Unsupported task: {args.task}"
assert args.task in EXAMPLE_PROMPT, f"Unsupported task: {args.task}"
# The default sampling steps are 40 for image-to-video tasks and 50 for text-to-video tasks.
if args.sample_steps is None:
args.sample_steps = 40 if "i2v" in args.task else 50
if args.sample_shift is None:
args.sample_shift = 5.0
if "i2v" in args.task and args.size in ["832*480", "480*832"]:
args.sample_shift = 3.0
# The default number of frames is 1 for text-to-image tasks and 81 for other tasks.
if args.frame_num is None:
args.frame_num = 1 if "t2i" in args.task else 81
# T2I frame_num check
if "t2i" in args.task:
assert args.frame_num == 1, f"Unsupported frame_num {args.frame_num} for task {args.task}"
args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint(
0, sys.maxsize)
# Size check
assert args.size in SUPPORTED_SIZES[args.task], \
f"Unsupported size {args.size} for task {args.task}, supported sizes are: {', '.join(SUPPORTED_SIZES[args.task])}"
def _parse_args():
parser = argparse.ArgumentParser(
description="Generate a image or video from a text prompt or image using Wan"
)
parser.add_argument(
"--task",
type=str,
default="t2v-14B",
choices=list(WAN_CONFIGS.keys()),
help="The task to run.")
parser.add_argument(
"--size",
type=str,
default="1280*720",
choices=list(SIZE_CONFIGS.keys()),
help="The area (width*height) of the generated video. For the I2V task, the aspect ratio of the output video will follow that of the input image."
)
parser.add_argument(
"--frame_num",
type=int,
default=None,
help="How many frames to sample from a image or video. The number should be 4n+1"
)
parser.add_argument(
"--ckpt_dir",
type=str,
default=None,
help="The path to the checkpoint directory.")
parser.add_argument(
"--offload_model",
type=str2bool,
default=None,
help="Whether to offload the model to CPU after each model forward, reducing GPU memory usage."
)
parser.add_argument(
"--ulysses_size",
type=int,
default=1,
help="The size of the ulysses parallelism in DiT.")
parser.add_argument(
"--ring_size",
type=int,
default=1,
help="The size of the ring attention parallelism in DiT.")
parser.add_argument(
"--t5_fsdp",
action="store_true",
default=False,
help="Whether to use FSDP for T5.")
parser.add_argument(
"--t5_cpu",
action="store_true",
default=False,
help="Whether to place T5 model on CPU.")
parser.add_argument(
"--dit_fsdp",
action="store_true",
default=False,
help="Whether to use FSDP for DiT.")
parser.add_argument(
"--save_file",
type=str,
default=None,
help="The file to save the generated image or video to.")
parser.add_argument(
"--prompt",
type=str,
default=None,
help="The prompt to generate the image or video from.")
parser.add_argument(
"--use_prompt_extend",
action="store_true",
default=False,
help="Whether to use prompt extend.")
parser.add_argument(
"--prompt_extend_method",
type=str,
default="local_qwen",
choices=["dashscope", "local_qwen"],
help="The prompt extend method to use.")
parser.add_argument(
"--prompt_extend_model",
type=str,
default=None,
help="The prompt extend model to use.")
parser.add_argument(
"--prompt_extend_target_lang",
type=str,
default="ch",
choices=["ch", "en"],
help="The target language of prompt extend.")
parser.add_argument(
"--base_seed",
type=int,
default=-1,
help="The seed to use for generating the image or video.")
parser.add_argument(
"--image",
type=str,
default=None,
help="The image to generate the video from.")
parser.add_argument(
"--sample_solver",
type=str,
default='unipc',
choices=['unipc', 'dpm++'],
help="The solver used to sample.")
parser.add_argument(
"--sample_steps", type=int, default=None, help="The sampling steps.")
parser.add_argument(
"--sample_shift",
type=float,
default=None,
help="Sampling shift factor for flow matching schedulers.")
parser.add_argument(
"--sample_guide_scale",
type=float,
default=5.0,
help="Classifier free guidance scale.")
args = parser.parse_args()
_validate_args(args)
return args
def _init_logging(rank):
# logging
if rank == 0:
# set format
logging.basicConfig(
level=logging.INFO,
format="[%(asctime)s] %(levelname)s: %(message)s",
handlers=[logging.StreamHandler(stream=sys.stdout)])
else:
logging.basicConfig(level=logging.ERROR)
def generate(args):
rank = int(os.getenv("RANK", 0))
world_size = int(os.getenv("WORLD_SIZE", 1))
local_rank = int(os.getenv("LOCAL_RANK", 0))
device = local_rank
_init_logging(rank)
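# By default, offload weights to CPU only on single-GPU runs; distributed runs keep them resident.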
if args.offload_model is None:
args.offload_model = False if world_size > 1 else True
logging.info(
f"offload_model is not specified, set to {args.offload_model}.")
if world_size > 1:
torch.cuda.set_device(local_rank)
dist.init_process_group(
backend="nccl",
init_method="env://",
rank=rank,
world_size=world_size)
else:
assert not (
args.t5_fsdp or args.dit_fsdp
), "t5_fsdp and dit_fsdp are not supported in non-distributed environments."
assert not (
args.ulysses_size > 1 or args.ring_size > 1
), "context parallelism is not supported in non-distributed environments."
if args.ulysses_size > 1 or args.ring_size > 1:
assert args.ulysses_size * args.ring_size == world_size, "The product of ulysses_size and ring_size must equal the world size."
from xfuser.core.distributed import (initialize_model_parallel,
init_distributed_environment)
init_distributed_environment(
rank=dist.get_rank(), world_size=dist.get_world_size())
initialize_model_parallel(
sequence_parallel_degree=dist.get_world_size(),
ring_degree=args.ring_size,
ulysses_degree=args.ulysses_size,
)
if args.use_prompt_extend:
if args.prompt_extend_method == "dashscope":
prompt_expander = DashScopePromptExpander(
model_name=args.prompt_extend_model, is_vl="i2v" in args.task)
elif args.prompt_extend_method == "local_qwen":
prompt_expander = QwenPromptExpander(
model_name=args.prompt_extend_model,
is_vl="i2v" in args.task,
device=rank)
else:
raise NotImplementedError(
f"Unsupport prompt_extend_method: {args.prompt_extend_method}")
cfg = WAN_CONFIGS[args.task]
if args.ulysses_size > 1:
assert cfg.num_heads % args.ulysses_size == 0, "`num_heads` must be divisible by `ulysses_size`."
logging.info(f"Generation job args: {args}")
logging.info(f"Generation model config: {cfg}")
if dist.is_initialized():
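# Broadcast the seed chosen on rank 0 so all ranks generate with the same randomness.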
base_seed = [args.base_seed] if rank == 0 else [None]
dist.broadcast_object_list(base_seed, src=0)
args.base_seed = base_seed[0]
if "t2v" in args.task or "t2i" in args.task:
if args.prompt is None:
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
logging.info(f"Input prompt: {args.prompt}")
if args.use_prompt_extend:
logging.info("Extending prompt ...")
if rank == 0:
prompt_output = prompt_expander(
args.prompt,
tar_lang=args.prompt_extend_target_lang,
seed=args.base_seed)
if not prompt_output.status:
logging.info(
f"Extending prompt failed: {prompt_output.message}")
logging.info("Falling back to original prompt.")
input_prompt = args.prompt
else:
input_prompt = prompt_output.prompt
input_prompt = [input_prompt]
else:
input_prompt = [None]
if dist.is_initialized():
dist.broadcast_object_list(input_prompt, src=0)
args.prompt = input_prompt[0]
logging.info(f"Extended prompt: {args.prompt}")
logging.info("Creating WanT2V pipeline.")
wan_t2v = wan.WanT2V(
config=cfg,
checkpoint_dir=args.ckpt_dir,
device_id=device,
rank=rank,
t5_fsdp=args.t5_fsdp,
dit_fsdp=args.dit_fsdp,
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
t5_cpu=args.t5_cpu,
)
logging.info(
f"Generating {'image' if 't2i' in args.task else 'video'} ...")
video = wan_t2v.generate(
args.prompt,
size=SIZE_CONFIGS[args.size],
frame_num=args.frame_num,
shift=args.sample_shift,
sample_solver=args.sample_solver,
sampling_steps=args.sample_steps,
guide_scale=args.sample_guide_scale,
seed=args.base_seed,
offload_model=args.offload_model)
else:
if args.prompt is None:
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
if args.image is None:
args.image = EXAMPLE_PROMPT[args.task]["image"]
logging.info(f"Input prompt: {args.prompt}")
logging.info(f"Input image: {args.image}")
img = Image.open(args.image).convert("RGB")
if args.use_prompt_extend:
logging.info("Extending prompt ...")
if rank == 0:
prompt_output = prompt_expander(
args.prompt,
tar_lang=args.prompt_extend_target_lang,
image=img,
seed=args.base_seed)
if not prompt_output.status:
logging.info(
f"Extending prompt failed: {prompt_output.message}")
logging.info("Falling back to original prompt.")
input_prompt = args.prompt
else:
input_prompt = prompt_output.prompt
input_prompt = [input_prompt]
else:
input_prompt = [None]
if dist.is_initialized():
dist.broadcast_object_list(input_prompt, src=0)
args.prompt = input_prompt[0]
logging.info(f"Extended prompt: {args.prompt}")
logging.info("Creating WanI2V pipeline.")
wan_i2v = wan.WanI2V(
config=cfg,
checkpoint_dir=args.ckpt_dir,
device_id=device,
rank=rank,
t5_fsdp=args.t5_fsdp,
dit_fsdp=args.dit_fsdp,
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
t5_cpu=args.t5_cpu,
)
logging.info("Generating video ...")
video = wan_i2v.generate(
args.prompt,
img,
max_area=MAX_AREA_CONFIGS[args.size],
frame_num=args.frame_num,
shift=args.sample_shift,
sample_solver=args.sample_solver,
sampling_steps=args.sample_steps,
guide_scale=args.sample_guide_scale,
seed=args.base_seed,
offload_model=args.offload_model)
if rank == 0:
if args.save_file is None:
formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S")
formatted_prompt = args.prompt.replace(" ", "_").replace("/", "_")[:50]
suffix = '.png' if "t2i" in args.task else '.mp4'
args.save_file = f"{args.task}_{args.size}_{args.ulysses_size}_{args.ring_size}_{formatted_prompt}_{formatted_time}" + suffix
if "t2i" in args.task:
logging.info(f"Saving generated image to {args.save_file}")
cache_image(
tensor=video.squeeze(1)[None],
save_file=args.save_file,
nrow=1,
normalize=True,
value_range=(-1, 1))
else:
logging.info(f"Saving generated video to {args.save_file}")
cache_video(
tensor=video[None],
save_file=args.save_file,
fps=cfg.sample_fps,
nrow=1,
normalize=True,
value_range=(-1, 1))
logging.info("Finished.")
if __name__ == "__main__":
args = _parse_args()
generate(args)
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import gc
import os.path as osp
import sys
import warnings
import gradio as gr
warnings.filterwarnings('ignore')
# Model
sys.path.insert(0, '/'.join(osp.realpath(__file__).split('/')[:-2]))
import wan
from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.utils import cache_video
# Global Var
prompt_expander = None
wan_i2v_480P = None
wan_i2v_720P = None
# Button Func
def load_model(value):
global wan_i2v_480P, wan_i2v_720P
if value == '------':
print("No model loaded")
return '------'
if value == '720P':
if args.ckpt_dir_720p is None:
print("Please specify the checkpoint directory for 720P model")
return '------'
if wan_i2v_720P is not None:
pass
else:
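# Release the 480P pipeline first so only one 14B checkpoint occupies GPU memory at a time.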
del wan_i2v_480P
gc.collect()
wan_i2v_480P = None
print("load 14B-720P i2v model...", end='', flush=True)
cfg = WAN_CONFIGS['i2v-14B']
wan_i2v_720P = wan.WanI2V(
config=cfg,
checkpoint_dir=args.ckpt_dir_720p,
device_id=0,
rank=0,
t5_fsdp=False,
dit_fsdp=False,
use_usp=False,
)
print("done", flush=True)
return '720P'
if value == '480P':
if args.ckpt_dir_480p is None:
print("Please specify the checkpoint directory for 480P model")
return '------'
if wan_i2v_480P is not None:
pass
else:
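# Release the 720P pipeline before loading the 480P one.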
del wan_i2v_720P
gc.collect()
wan_i2v_720P = None
print("load 14B-480P i2v model...", end='', flush=True)
cfg = WAN_CONFIGS['i2v-14B']
wan_i2v_480P = wan.WanI2V(
config=cfg,
checkpoint_dir=args.ckpt_dir_480p,
device_id=0,
rank=0,
t5_fsdp=False,
dit_fsdp=False,
use_usp=False,
)
print("done", flush=True)
return '480P'
def prompt_enc(prompt, img, tar_lang):
print('prompt extend...')
if img is None:
print('Please upload an image')
return prompt
global prompt_expander
prompt_output = prompt_expander(
prompt, image=img, tar_lang=tar_lang.lower())
if not prompt_output.status:
return prompt
else:
return prompt_output.prompt
def i2v_generation(img2vid_prompt, img2vid_image, resolution, sd_steps,
guide_scale, shift_scale, seed, n_prompt):
# print(f"{img2vid_prompt},{resolution},{sd_steps},{guide_scale},{shift_scale},{seed},{n_prompt}")
if resolution == '------':
print(
'Please specify at least one resolution ckpt dir or specify the resolution'
)
return None
else:
if resolution == '720P':
global wan_i2v_720P
video = wan_i2v_720P.generate(
img2vid_prompt,
img2vid_image,
max_area=MAX_AREA_CONFIGS['720*1280'],
shift=shift_scale,
sampling_steps=sd_steps,
guide_scale=guide_scale,
n_prompt=n_prompt,
seed=seed,
offload_model=True)
else:
global wan_i2v_480P
video = wan_i2v_480P.generate(
img2vid_prompt,
img2vid_image,
max_area=MAX_AREA_CONFIGS['480*832'],
shift=shift_scale,
sampling_steps=sd_steps,
guide_scale=guide_scale,
n_prompt=n_prompt,
seed=seed,
offload_model=True)
cache_video(
tensor=video[None],
save_file="example.mp4",
fps=16,
nrow=1,
normalize=True,
value_range=(-1, 1))
return "example.mp4"
# Interface
def gradio_interface():
with gr.Blocks() as demo:
gr.Markdown("""
<div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
Wan2.1 (I2V-14B)
</div>
<div style="text-align: center; font-size: 16px; font-weight: normal; margin-bottom: 20px;">
Wan: Open and Advanced Large-Scale Video Generative Models.
</div>
""")
with gr.Row():
with gr.Column():
resolution = gr.Dropdown(
label='Resolution',
choices=['------', '720P', '480P'],
value='------')
img2vid_image = gr.Image(
type="pil",
label="Upload Input Image",
elem_id="image_upload",
)
img2vid_prompt = gr.Textbox(
label="Prompt",
placeholder="Describe the video you want to generate",
)
tar_lang = gr.Radio(
choices=["CH", "EN"],
label="Target language of prompt enhance",
value="CH")
run_p_button = gr.Button(value="Prompt Enhance")
with gr.Accordion("Advanced Options", open=True):
with gr.Row():
sd_steps = gr.Slider(
label="Diffusion steps",
minimum=1,
maximum=1000,
value=50,
step=1)
guide_scale = gr.Slider(
label="Guide scale",
minimum=0,
maximum=20,
value=5.0,
step=1)
with gr.Row():
shift_scale = gr.Slider(
label="Shift scale",
minimum=0,
maximum=10,
value=5.0,
step=1)
seed = gr.Slider(
label="Seed",
minimum=-1,
maximum=2147483647,
step=1,
value=-1)
n_prompt = gr.Textbox(
label="Negative Prompt",
placeholder="Describe the negative prompt you want to add"
)
run_i2v_button = gr.Button("Generate Video")
with gr.Column():
result_gallery = gr.Video(
label='Generated Video', interactive=False, height=600)
resolution.input(
fn=load_model, inputs=[resolution], outputs=[resolution])
run_p_button.click(
fn=prompt_enc,
inputs=[img2vid_prompt, img2vid_image, tar_lang],
outputs=[img2vid_prompt])
run_i2v_button.click(
fn=i2v_generation,
inputs=[
img2vid_prompt, img2vid_image, resolution, sd_steps,
guide_scale, shift_scale, seed, n_prompt
],
outputs=[result_gallery],
)
return demo
# Main
def _parse_args():
parser = argparse.ArgumentParser(
description="Generate a video from a text prompt or image using Gradio")
parser.add_argument(
"--ckpt_dir_720p",
type=str,
default=None,
help="The path to the checkpoint directory.")
parser.add_argument(
"--ckpt_dir_480p",
type=str,
default=None,
help="The path to the checkpoint directory.")
parser.add_argument(
"--prompt_extend_method",
type=str,
default="local_qwen",
choices=["dashscope", "local_qwen"],
help="The prompt extend method to use.")
parser.add_argument(
"--prompt_extend_model",
type=str,
default=None,
help="The prompt extend model to use.")
args = parser.parse_args()
assert args.ckpt_dir_720p is not None or args.ckpt_dir_480p is not None, "Please specify at least one checkpoint directory."
return args
if __name__ == '__main__':
args = _parse_args()
print("Step1: Init prompt_expander...", end='', flush=True)
if args.prompt_extend_method == "dashscope":
prompt_expander = DashScopePromptExpander(
model_name=args.prompt_extend_model, is_vl=True)
elif args.prompt_extend_method == "local_qwen":
prompt_expander = QwenPromptExpander(
model_name=args.prompt_extend_model, is_vl=True, device=0)
else:
raise NotImplementedError(
f"Unsupport prompt_extend_method: {args.prompt_extend_method}")
print("done", flush=True)
demo = gradio_interface()
demo.launch(server_name="0.0.0.0", share=False, server_port=7860)
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import os.path as osp
import sys
import warnings
import gradio as gr
warnings.filterwarnings('ignore')
# Model
sys.path.insert(0, '/'.join(osp.realpath(__file__).split('/')[:-2]))
import wan
from wan.configs import WAN_CONFIGS
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.utils import cache_image
# Global Var
prompt_expander = None
wan_t2i = None
# Button Func
def prompt_enc(prompt, tar_lang):
global prompt_expander
prompt_output = prompt_expander(prompt, tar_lang=tar_lang.lower())
if not prompt_output.status:
return prompt
else:
return prompt_output.prompt
def t2i_generation(txt2img_prompt, resolution, sd_steps, guide_scale,
shift_scale, seed, n_prompt):
global wan_t2i
# print(f"{txt2img_prompt},{resolution},{sd_steps},{guide_scale},{shift_scale},{seed},{n_prompt}")
W = int(resolution.split("*")[0])
H = int(resolution.split("*")[1])
video = wan_t2i.generate(
txt2img_prompt,
size=(W, H),
frame_num=1,
shift=shift_scale,
sampling_steps=sd_steps,
guide_scale=guide_scale,
n_prompt=n_prompt,
seed=seed,
offload_model=True)
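# The t2i result is a single-frame video tensor; drop the frame axis before saving it as an image.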
cache_image(
tensor=video.squeeze(1)[None],
save_file="example.png",
nrow=1,
normalize=True,
value_range=(-1, 1))
return "example.png"
# Interface
def gradio_interface():
with gr.Blocks() as demo:
gr.Markdown("""
<div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
Wan2.1 (T2I-14B)
</div>
<div style="text-align: center; font-size: 16px; font-weight: normal; margin-bottom: 20px;">
Wan: Open and Advanced Large-Scale Video Generative Models.
</div>
""")
with gr.Row():
with gr.Column():
txt2img_prompt = gr.Textbox(
label="Prompt",
placeholder="Describe the image you want to generate",
)
tar_lang = gr.Radio(
choices=["CH", "EN"],
label="Target language of prompt enhance",
value="CH")
run_p_button = gr.Button(value="Prompt Enhance")
with gr.Accordion("Advanced Options", open=True):
resolution = gr.Dropdown(
label='Resolution(Width*Height)',
choices=[
'720*1280', '1280*720', '960*960', '1088*832',
'832*1088', '480*832', '832*480', '624*624',
'704*544', '544*704'
],
value='720*1280')
with gr.Row():
sd_steps = gr.Slider(
label="Diffusion steps",
minimum=1,
maximum=1000,
value=50,
step=1)
guide_scale = gr.Slider(
label="Guide scale",
minimum=0,
maximum=20,
value=5.0,
step=1)
with gr.Row():
shift_scale = gr.Slider(
label="Shift scale",
minimum=0,
maximum=10,
value=5.0,
step=1)
seed = gr.Slider(
label="Seed",
minimum=-1,
maximum=2147483647,
step=1,
value=-1)
n_prompt = gr.Textbox(
label="Negative Prompt",
placeholder="Describe the negative prompt you want to add"
)
run_t2i_button = gr.Button("Generate Image")
with gr.Column():
result_gallery = gr.Image(
label='Generated Image', interactive=False, height=600)
run_p_button.click(
fn=prompt_enc,
inputs=[txt2img_prompt, tar_lang],
outputs=[txt2img_prompt])
run_t2i_button.click(
fn=t2i_generation,
inputs=[
txt2img_prompt, resolution, sd_steps, guide_scale, shift_scale,
seed, n_prompt
],
outputs=[result_gallery],
)
return demo
# Main
def _parse_args():
parser = argparse.ArgumentParser(
description="Generate a image from a text prompt or image using Gradio")
parser.add_argument(
"--ckpt_dir",
type=str,
default="cache",
help="The path to the checkpoint directory.")
parser.add_argument(
"--prompt_extend_method",
type=str,
default="local_qwen",
choices=["dashscope", "local_qwen"],
help="The prompt extend method to use.")
parser.add_argument(
"--prompt_extend_model",
type=str,
default=None,
help="The prompt extend model to use.")
args = parser.parse_args()
return args
if __name__ == '__main__':
args = _parse_args()
print("Step1: Init prompt_expander...", end='', flush=True)
if args.prompt_extend_method == "dashscope":
prompt_expander = DashScopePromptExpander(
model_name=args.prompt_extend_model, is_vl=False)
elif args.prompt_extend_method == "local_qwen":
prompt_expander = QwenPromptExpander(
model_name=args.prompt_extend_model, is_vl=False, device=0)
else:
raise NotImplementedError(
f"Unsupport prompt_extend_method: {args.prompt_extend_method}")
print("done", flush=True)
print("Step2: Init 14B t2i model...", end='', flush=True)
cfg = WAN_CONFIGS['t2i-14B']
wan_t2i = wan.WanT2V(
config=cfg,
checkpoint_dir=args.ckpt_dir,
device_id=0,
rank=0,
t5_fsdp=False,
dit_fsdp=False,
use_usp=False,
)
print("done", flush=True)
demo = gradio_interface()
demo.launch(server_name="0.0.0.0", share=False, server_port=7860)
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import os.path as osp
import sys
import warnings
import gradio as gr
warnings.filterwarnings('ignore')
# Model
sys.path.insert(0, '/'.join(osp.realpath(__file__).split('/')[:-2]))
import wan
from wan.configs import WAN_CONFIGS
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.utils import cache_video
# Global Var
prompt_expander = None
wan_t2v = None
# Button Func
def prompt_enc(prompt, tar_lang):
global prompt_expander
prompt_output = prompt_expander(prompt, tar_lang=tar_lang.lower())
if not prompt_output.status:
return prompt
else:
return prompt_output.prompt
def t2v_generation(txt2vid_prompt, resolution, sd_steps, guide_scale,
shift_scale, seed, n_prompt):
global wan_t2v
# print(f"{txt2vid_prompt},{resolution},{sd_steps},{guide_scale},{shift_scale},{seed},{n_prompt}")
W = int(resolution.split("*")[0])
H = int(resolution.split("*")[1])
video = wan_t2v.generate(
txt2vid_prompt,
size=(W, H),
shift=shift_scale,
sampling_steps=sd_steps,
guide_scale=guide_scale,
n_prompt=n_prompt,
seed=seed,
offload_model=False)
cache_video(
tensor=video[None],
save_file="example.mp4",
fps=16,
nrow=1,
normalize=True,
value_range=(-1, 1))
return "example.mp4"
# Interface
def gradio_interface():
with gr.Blocks() as demo:
gr.Markdown("""
<div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
Wan2.1 (T2V-1.3B)
</div>
<div style="text-align: center; font-size: 16px; font-weight: normal; margin-bottom: 20px;">
Wan: Open and Advanced Large-Scale Video Generative Models.
</div>
""")
with gr.Row():
with gr.Column():
txt2vid_prompt = gr.Textbox(
label="Prompt",
placeholder="Describe the video you want to generate",
)
tar_lang = gr.Radio(
choices=["CH", "EN"],
label="Target language of prompt enhance",
value="CH")
run_p_button = gr.Button(value="Prompt Enhance")
with gr.Accordion("Advanced Options", open=True):
resolution = gr.Dropdown(
label='Resolution(Width*Height)',
choices=[
'480*832',
'832*480',
'624*624',
'704*544',
'544*704',
],
value='480*832')
with gr.Row():
sd_steps = gr.Slider(
label="Diffusion steps",
minimum=1,
maximum=1000,
value=50,
step=1)
guide_scale = gr.Slider(
label="Guide scale",
minimum=0,
maximum=20,
value=6.0,
step=1)
with gr.Row():
shift_scale = gr.Slider(
label="Shift scale",
minimum=0,
maximum=20,
value=8.0,
step=1)
seed = gr.Slider(
label="Seed",
minimum=-1,
maximum=2147483647,
step=1,
value=-1)
n_prompt = gr.Textbox(
label="Negative Prompt",
placeholder="Describe the negative prompt you want to add"
)
run_t2v_button = gr.Button("Generate Video")
with gr.Column():
result_gallery = gr.Video(
label='Generated Video', interactive=False, height=600)
run_p_button.click(
fn=prompt_enc,
inputs=[txt2vid_prompt, tar_lang],
outputs=[txt2vid_prompt])
run_t2v_button.click(
fn=t2v_generation,
inputs=[
txt2vid_prompt, resolution, sd_steps, guide_scale, shift_scale,
seed, n_prompt
],
outputs=[result_gallery],
)
return demo
# Main
def _parse_args():
parser = argparse.ArgumentParser(
description="Generate a video from a text prompt or image using Gradio")
parser.add_argument(
"--ckpt_dir",
type=str,
default="cache",
help="The path to the checkpoint directory.")
parser.add_argument(
"--prompt_extend_method",
type=str,
default="local_qwen",
choices=["dashscope", "local_qwen"],
help="The prompt extend method to use.")
parser.add_argument(
"--prompt_extend_model",
type=str,
default=None,
help="The prompt extend model to use.")
args = parser.parse_args()
return args
if __name__ == '__main__':
args = _parse_args()
print("Step1: Init prompt_expander...", end='', flush=True)
if args.prompt_extend_method == "dashscope":
prompt_expander = DashScopePromptExpander(
model_name=args.prompt_extend_model, is_vl=False)
elif args.prompt_extend_method == "local_qwen":
prompt_expander = QwenPromptExpander(
model_name=args.prompt_extend_model, is_vl=False, device=0)
else:
raise NotImplementedError(
f"Unsupport prompt_extend_method: {args.prompt_extend_method}")
print("done", flush=True)
print("Step2: Init 1.3B t2v model...", end='', flush=True)
cfg = WAN_CONFIGS['t2v-1.3B']
wan_t2v = wan.WanT2V(
config=cfg,
checkpoint_dir=args.ckpt_dir,
device_id=0,
rank=0,
t5_fsdp=False,
dit_fsdp=False,
use_usp=False,
)
print("done", flush=True)
demo = gradio_interface()
demo.launch(server_name="0.0.0.0", share=False, server_port=7860)
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import os.path as osp
import sys
import warnings
import gradio as gr
warnings.filterwarnings('ignore')
# Model
sys.path.insert(0, '/'.join(osp.realpath(__file__).split('/')[:-2]))
import wan
from wan.configs import WAN_CONFIGS
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.utils import cache_video
# Global Var
prompt_expander = None
wan_t2v = None
# Button Func
def prompt_enc(prompt, tar_lang):
global prompt_expander
prompt_output = prompt_expander(prompt, tar_lang=tar_lang.lower())
if not prompt_output.status:
return prompt
else:
return prompt_output.prompt
def t2v_generation(txt2vid_prompt, resolution, sd_steps, guide_scale,
shift_scale, seed, n_prompt):
global wan_t2v
# print(f"{txt2vid_prompt},{resolution},{sd_steps},{guide_scale},{shift_scale},{seed},{n_prompt}")
W = int(resolution.split("*")[0])
H = int(resolution.split("*")[1])
video = wan_t2v.generate(
txt2vid_prompt,
size=(W, H),
shift=shift_scale,
sampling_steps=sd_steps,
guide_scale=guide_scale,
n_prompt=n_prompt,
seed=seed,
offload_model=True)
cache_video(
tensor=video[None],
save_file="example.mp4",
fps=16,
nrow=1,
normalize=True,
value_range=(-1, 1))
return "example.mp4"
# Interface
def gradio_interface():
with gr.Blocks() as demo:
gr.Markdown("""
<div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
Wan2.1 (T2V-14B)
</div>
<div style="text-align: center; font-size: 16px; font-weight: normal; margin-bottom: 20px;">
Wan: Open and Advanced Large-Scale Video Generative Models.
</div>
""")
with gr.Row():
with gr.Column():
txt2vid_prompt = gr.Textbox(
label="Prompt",
placeholder="Describe the video you want to generate",
)
tar_lang = gr.Radio(
choices=["CH", "EN"],
label="Target language of prompt enhance",
value="CH")
run_p_button = gr.Button(value="Prompt Enhance")
with gr.Accordion("Advanced Options", open=True):
resolution = gr.Dropdown(
label='Resolution(Width*Height)',
choices=[
'720*1280', '1280*720', '960*960', '1088*832',
'832*1088', '480*832', '832*480', '624*624',
'704*544', '544*704'
],
value='720*1280')
with gr.Row():
sd_steps = gr.Slider(
label="Diffusion steps",
minimum=1,
maximum=1000,
value=50,
step=1)
guide_scale = gr.Slider(
label="Guide scale",
minimum=0,
maximum=20,
value=5.0,
step=1)
with gr.Row():
shift_scale = gr.Slider(
label="Shift scale",
minimum=0,
maximum=10,
value=5.0,
step=1)
seed = gr.Slider(
label="Seed",
minimum=-1,
maximum=2147483647,
step=1,
value=-1)
n_prompt = gr.Textbox(
label="Negative Prompt",
placeholder="Describe the negative prompt you want to add"
)
run_t2v_button = gr.Button("Generate Video")
with gr.Column():
result_gallery = gr.Video(
label='Generated Video', interactive=False, height=600)
run_p_button.click(
fn=prompt_enc,
inputs=[txt2vid_prompt, tar_lang],
outputs=[txt2vid_prompt])
run_t2v_button.click(
fn=t2v_generation,
inputs=[
txt2vid_prompt, resolution, sd_steps, guide_scale, shift_scale,
seed, n_prompt
],
outputs=[result_gallery],
)
return demo
# Main
def _parse_args():
parser = argparse.ArgumentParser(
description="Generate a video from a text prompt or image using Gradio")
parser.add_argument(
"--ckpt_dir",
type=str,
default="cache",
help="The path to the checkpoint directory.")
parser.add_argument(
"--prompt_extend_method",
type=str,
default="local_qwen",
choices=["dashscope", "local_qwen"],
help="The prompt extend method to use.")
parser.add_argument(
"--prompt_extend_model",
type=str,
default=None,
help="The prompt extend model to use.")
args = parser.parse_args()
return args
if __name__ == '__main__':
args = _parse_args()
print("Step1: Init prompt_expander...", end='', flush=True)
if args.prompt_extend_method == "dashscope":
prompt_expander = DashScopePromptExpander(
model_name=args.prompt_extend_model, is_vl=False)
elif args.prompt_extend_method == "local_qwen":
prompt_expander = QwenPromptExpander(
model_name=args.prompt_extend_model, is_vl=False, device=0)
else:
raise NotImplementedError(
f"Unsupport prompt_extend_method: {args.prompt_extend_method}")
print("done", flush=True)
print("Step2: Init 14B t2v model...", end='', flush=True)
cfg = WAN_CONFIGS['t2v-14B']
wan_t2v = wan.WanT2V(
config=cfg,
checkpoint_dir=args.ckpt_dir,
device_id=0,
rank=0,
t5_fsdp=False,
dit_fsdp=False,
use_usp=False,
)
print("done", flush=True)
demo = gradio_interface()
demo.launch(server_name="0.0.0.0", share=False, server_port=7860)
# Model code
modelCode=1437
# Model name
modelName=Wan2.1_pytorch
# Model description
modelDescription=Alibaba's open-source high-quality video generation model
# Application scenarios
appScenario=inference, video generation, e-commerce, education, broadcast media
# Framework type
frameType=pytorch
import os
import torch
import torch.distributed as dist
from packaging import version
from dataclasses import dataclass, fields
from xfuser.logger import init_logger
import xfuser.envs as envs
# from xfuser.envs import CUDA_VERSION, TORCH_VERSION, PACKAGES_CHECKER
from xfuser.envs import TORCH_VERSION, PACKAGES_CHECKER
logger = init_logger(__name__)
from typing import Union, Optional, List
env_info = PACKAGES_CHECKER.get_packages_info()
HAS_LONG_CTX_ATTN = env_info["has_long_ctx_attn"]
HAS_FLASH_ATTN = env_info["has_flash_attn"]
def check_packages():
import diffusers
if not version.parse(diffusers.__version__) > version.parse("0.30.2"):
raise RuntimeError(
"This project requires diffusers version > 0.30.2. Currently, a correct version cannot be installed via pip. "
"Please install it from source!"
)
def check_env():
# https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/cudagraph.html
#if CUDA_VERSION < version.parse("11.3"):
# raise RuntimeError("NCCL CUDA Graph support requires CUDA 11.3 or above")
if TORCH_VERSION < version.parse("2.2.0"):
# https://pytorch.org/blog/accelerating-pytorch-with-cuda-graphs/
raise RuntimeError(
"CUDAGraph with NCCL support requires PyTorch 2.2.0 or above. "
"If it is not released yet, please install nightly built PyTorch "
"with `pip3 install --pre torch torchvision torchaudio --index-url "
"https://download.pytorch.org/whl/nightly/cu121`"
)
@dataclass
class ModelConfig:
model: str
download_dir: Optional[str] = None
trust_remote_code: bool = False
@dataclass
class RuntimeConfig:
warmup_steps: int = 1
dtype: torch.dtype = torch.float16
use_cuda_graph: bool = False
use_parallel_vae: bool = False
use_profiler: bool = False
use_torch_compile: bool = False
use_onediff: bool = False
use_fp8_t5_encoder: bool = False
def __post_init__(self):
check_packages()
if self.use_cuda_graph:
check_env()
@dataclass
class FastAttnConfig:
use_fast_attn: bool = False
n_step: int = 20
n_calib: int = 8
threshold: float = 0.5
window_size: int = 64
coco_path: Optional[str] = None
use_cache: bool = False
def __post_init__(self):
assert self.n_calib > 0, "n_calib must be greater than 0"
assert self.threshold > 0.0, "threshold must be greater than 0"
@dataclass
class DataParallelConfig:
dp_degree: int = 1
use_cfg_parallel: bool = False
world_size: int = 1
def __post_init__(self):
assert self.dp_degree >= 1, "dp_degree must be greater than or equal to 1"
# Classifier-free guidance parallelism runs the conditional and unconditional batches on separate ranks, so its degree is 2 when enabled.
if self.use_cfg_parallel:
self.cfg_degree = 2
else:
self.cfg_degree = 1
assert self.dp_degree * self.cfg_degree <= self.world_size, (
"dp_degree * cfg_degree must be less than or equal to "
"world_size because of classifier free guidance"
)
assert (
self.world_size % (self.dp_degree * self.cfg_degree) == 0
), "world_size must be divisible by dp_degree * cfg_degree"
@dataclass
class SequenceParallelConfig:
ulysses_degree: Optional[int] = None
ring_degree: Optional[int] = None
world_size: int = 1
def __post_init__(self):
if self.ulysses_degree is None:
self.ulysses_degree = 1
logger.info(
f"Ulysses degree not set, " f"using default value {self.ulysses_degree}"
)
if self.ring_degree is None:
self.ring_degree = 1
logger.info(
f"Ring degree not set, " f"using default value {self.ring_degree}"
)
self.sp_degree = self.ulysses_degree * self.ring_degree
if not HAS_LONG_CTX_ATTN and self.sp_degree > 1:
raise ImportError(
f"Sequence Parallel kit 'yunchang' not found but "
f"sp_degree is {self.sp_degree}, please set it "
f"to 1 or install 'yunchang' to use it"
)
@dataclass
class TensorParallelConfig:
tp_degree: int = 1
split_scheme: Optional[str] = "row"
world_size: int = 1
def __post_init__(self):
assert self.tp_degree >= 1, "tp_degree must be greater than or equal to 1"
assert (
self.tp_degree <= self.world_size
), "tp_degree must be less than or equal to world_size"
@dataclass
class PipeFusionParallelConfig:
pp_degree: int = 1
num_pipeline_patch: Optional[int] = None
attn_layer_num_for_pp: Optional[List[int]] = None
world_size: int = 1
def __post_init__(self):
assert (
self.pp_degree is not None and self.pp_degree >= 1
), "pp_degree must be set and at least 1 to use PipeFusion"
assert (
self.pp_degree <= self.world_size
), "pipefusion_degree must be less than or equal to world_size"
if self.num_pipeline_patch is None:
self.num_pipeline_patch = self.pp_degree
logger.info(
f"Pipeline patch number not set, "
f"using default value {self.pp_degree}"
)
if self.attn_layer_num_for_pp is not None:
logger.info(
f"attn_layer_num_for_pp set, splitting attention layers"
f"to {self.attn_layer_num_for_pp}"
)
assert len(self.attn_layer_num_for_pp) == self.pp_degree, (
"attn_layer_num_for_pp must have the same "
"length as pp_degree if not None"
)
if self.pp_degree == 1 and self.num_pipeline_patch > 1:
logger.warning(
f"Pipefusion degree is 1, pipeline will not be used,"
f"num_pipeline_patch will be ignored"
)
self.num_pipeline_patch = 1
@dataclass
class ParallelConfig:
dp_config: DataParallelConfig
sp_config: SequenceParallelConfig
pp_config: PipeFusionParallelConfig
tp_config: TensorParallelConfig
world_size: int = 1 # FIXME: remove this
worker_cls: str = "xfuser.ray.worker.worker.Worker"
def __post_init__(self):
assert self.tp_config is not None, "tp_config must be set"
assert self.dp_config is not None, "dp_config must be set"
assert self.sp_config is not None, "sp_config must be set"
assert self.pp_config is not None, "pp_config must be set"
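# The parallel plan must tile the whole job: the product of all parallel degrees has to equal world_size.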
parallel_world_size = (
self.dp_config.dp_degree
* self.dp_config.cfg_degree
* self.sp_config.sp_degree
* self.tp_config.tp_degree
* self.pp_config.pp_degree
)
world_size = self.world_size
assert parallel_world_size == world_size, (
f"parallel_world_size {parallel_world_size} "
f"must be equal to world_size {self.world_size}"
)
assert (
world_size % (self.dp_config.dp_degree * self.dp_config.cfg_degree) == 0
), "world_size must be divisible by dp_degree * cfg_degree"
assert (
world_size % self.pp_config.pp_degree == 0
), "world_size must be divisible by pp_degree"
assert (
world_size % self.sp_config.sp_degree == 0
), "world_size must be divisible by sp_degree"
assert (
world_size % self.tp_config.tp_degree == 0
), "world_size must be divisible by tp_degree"
self.dp_degree = self.dp_config.dp_degree
self.cfg_degree = self.dp_config.cfg_degree
self.sp_degree = self.sp_config.sp_degree
self.pp_degree = self.pp_config.pp_degree
self.tp_degree = self.tp_config.tp_degree
self.ulysses_degree = self.sp_config.ulysses_degree
self.ring_degree = self.sp_config.ring_degree
@dataclass(frozen=True)
class EngineConfig:
model_config: ModelConfig
runtime_config: RuntimeConfig
parallel_config: ParallelConfig
fast_attn_config: FastAttnConfig
def __post_init__(self):
world_size = self.parallel_config.world_size
if self.fast_attn_config.use_fast_attn:
assert self.parallel_config.dp_degree == world_size, "world_size must be equal to dp_degree when using DiTFastAttn"
def to_dict(self):
"""Return the configs as a dictionary, for use in **kwargs."""
return dict((field.name, getattr(self, field.name)) for field in fields(self))
@dataclass
class InputConfig:
height: int = 1024
width: int = 1024
num_frames: int = 49
use_resolution_binning: bool = True
batch_size: Optional[int] = None
img_file_path: Optional[str] = None
prompt: Union[str, List[str]] = ""
negative_prompt: Union[str, List[str]] = ""
num_inference_steps: int = 20
max_sequence_length: int = 256
seed: int = 42
output_type: str = "pil"
def __post_init__(self):
if isinstance(self.prompt, list):
assert (
len(self.prompt) == len(self.negative_prompt)
or len(self.negative_prompt) == 0
), "prompts and negative_prompts must have the same quantities"
self.batch_size = self.batch_size or len(self.prompt)
else:
self.batch_size = self.batch_size or 1
assert self.output_type in [
"pil",
"latent",
"pt",
], "output_type must be one of 'pil', 'latent', or 'pt'"