"mmdet3d/evaluation/functional/indoor_eval.py" did not exist on "e5b1ec11199b18b64a2d75d2e79ace41ec788d81"
Commit 9b03e5c1 authored by helloyongyang's avatar helloyongyang
Browse files

add some docs and update configs and readme

parent fe13f4db
# 准备环境
我们推荐使用docker环境,这是lightx2v的[dockerhub](https://hub.docker.com/r/lightx2v/lightx2v/tags),请选择一个最新日期的tag,比如25042502
```shell
docker pull lightx2v/lightx2v:25042502
docker run --gpus all -itd --ipc=host --name [容器名] -v [挂载设置] --entrypoint /bin/bash [镜像id]
```
如果你想使用conda自己搭建环境,可以参考如下步骤:
```shell
# 下载github代码
git clone https://github.com/ModelTC/lightx2v.git lightx2v && cd lightx2v
git submodule update --init --recursive
conda create -n lightx2v python=3.11 && conda activate lightx2v
pip install -r requirements.txt
# 单独重新安装transformers,避免pip的冲突检查
# 混元模型需要在4.45.2版本的transformers下运行,如果不需要跑混元模型,可以忽略
pip install transformers==4.45.2
# 安装 flash-attention 2
cd lightx2v/3rd/flash-attention && pip install --no-cache-dir -v -e .
# 安装 flash-attention 3, 用于 hopper 显卡
cd lightx2v/3rd/flash-attention/hopper && pip install --no-cache-dir -v -e .
```
# 推理
```shell
# 修改脚本中的路径
bash scripts/run_wan_t2v.sh
```
除了脚本中已有的输入参数,`--config_json`指向的`${lightx2v_path}/configs/wan_t2v.json`中也会存在一些必要的参数,可以根据需要,自行修改。
# 如何启动服务
lightx2v提供了异步服务功能,代码入口处在[这里](https://github.com/ModelTC/lightx2v/blob/main/lightx2v/api_server.py)
### 启动服务
```shell
# 修改脚本中的路径
bash scripts/start_server.sh
```
其中的`--port 8000`表示服务绑定在本机的`8000`端口上,可以自行修改
### 客户端发送请求
```shell
python scripts/post.py
```
服务的接口是:`/v1/local/video/generate`
`scripts/post.py`中的`message`参数如下:
```python
message = {
"task_id": generate_task_id(),
"task_id_must_unique": True,
"prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
"negative_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
"image_path": "",
"save_video_path": "./output_lightx2v_wan_t2v_t02.mp4",
}
```
1. `prompt`, `negative_prompt`, `image_path`是一些基础的视频生成的输入,`image_path`可以为空字符,表示不需要图片输入
2. `save_video_path`表示服务端生成的视频的路径,相对路径是相对服务端的启动路径,建议根据你自己的环境,设置一个绝对路径。
3. `task_id`表示该任务的id,格式是一个字符串。可以自定义个字符串,也可以调用`generate_task_id()`函数生成一个随机的字符串。任务的id用来区分不同的视频生成任务。
4. `task_id_must_unique`表示是否要求每个`task_id`是独一无二的,即不能发有重复的`task_id`。如果是`False`,就没有这个强制要求,此时如果发送了重复的`task_id`,服务端的`task`记录将会被相同`task_id`的较新的`task`覆盖掉。如果不需要记录所有的`task`以用于查询,那这里就可以设置成`False`
### 客户端获取服务端的状态
```shell
python scripts/check_status.py
```
其中服务的接口有:
1. `/v1/local/video/generate/service_status`用于检查服务的状态,可以返回得到服务是`busy`还是`idle`,只有在`idle`状态,该服务才会接收新的请求。
2. `/v1/local/video/generate/get_all_tasks`用于获取服务端接收到的且已完成的所有的任务。
3. `/v1/local/video/generate/task_status`用于获取指定`task_id`的状态,可以返回得到该任务是`processing`还是`completed`
### 客户端随时终止服务端当前的任务
```shell
python scripts/stop_running_task.py
```
服务的接口是:`/v1/local/video/generate/stop_running_task`
终止了任务之后,服务端并不会退出服务,而是回到等待接收新请求的状态。
# 量化
lightx2v支持对linear进行量化推理,支持w8a8和fp8的矩阵乘法。
### 运行量化推理
```shell
# 修改脚本中的路径
bash scripts/run_wan_t2v_save_quant.sh
```
脚本中,有两个执行命令:
#### save quantization weight
`RUNNING_FLAG`环境变量设置成`save_naive_quant``--config_json`指向到该`json`文件: `${lightx2v_path}/configs/wan_t2v_save_quant.json`,其中`quant_model_path`会保存下量化的模型的路径
#### load quantization weight and inference
`RUNNING_FLAG`环境变量设置成`infer``--config_json`指向到第一步中的`json`文件
### 启动量化服务
在存好量化权重之后,和上一步加载步骤一样,将`RUNNING_FLAG`环境变量设置成`infer``--config_json`指向到第一步中的`json`文件
比如,将`scripts/start_server.sh`脚本进行如下改动:
```shell
export RUNNING_FLAG=infer
python -m lightx2v.api_server \
--model_cls wan2.1 \
--task t2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/wan_t2v_save_quant.json \
--port 8000
```
...@@ -28,8 +28,8 @@ class HunyuanModel: ...@@ -28,8 +28,8 @@ class HunyuanModel:
self._init_infer_class() self._init_infer_class()
self._init_weights() self._init_weights()
if GET_RUNNING_FLAG() == "save_naive_quant": if GET_RUNNING_FLAG() == "save_naive_quant":
assert self.config.get("naive_quant_path") is not None, "naive_quant_path is None" assert self.config.get("quant_model_path") is not None, "quant_model_path is None"
self.save_weights(self.config.naive_quant_path) self.save_weights(self.config.quant_model_path)
sys.exit(0) sys.exit(0)
self._init_infer() self._init_infer()
...@@ -66,9 +66,9 @@ class HunyuanModel: ...@@ -66,9 +66,9 @@ class HunyuanModel:
return weight_dict return weight_dict
def _load_ckpt_quant_model(self): def _load_ckpt_quant_model(self):
assert self.config.get("naive_quant_path") is not None, "naive_quant_path is None" assert self.config.get("quant_model_path") is not None, "quant_model_path is None"
logger.info(f"Loading quant model from {self.config.naive_quant_path}") logger.info(f"Loading quant model from {self.config.quant_model_path}")
quant_weights_path = os.path.join(self.config.naive_quant_path, "quant_weights.pth") quant_weights_path = os.path.join(self.config.quant_model_path, "quant_weights.pth")
weight_dict = torch.load(quant_weights_path, map_location=self.device, weights_only=True) weight_dict = torch.load(quant_weights_path, map_location=self.device, weights_only=True)
return weight_dict return weight_dict
......
...@@ -35,8 +35,8 @@ class WanModel: ...@@ -35,8 +35,8 @@ class WanModel:
self._init_infer_class() self._init_infer_class()
self._init_weights() self._init_weights()
if GET_RUNNING_FLAG() == "save_naive_quant": if GET_RUNNING_FLAG() == "save_naive_quant":
assert self.config.get("naive_quant_path") is not None, "naive_quant_path is None" assert self.config.get("quant_model_path") is not None, "quant_model_path is None"
self.save_weights(self.config.naive_quant_path) self.save_weights(self.config.quant_model_path)
sys.exit(0) sys.exit(0)
self._init_infer() self._init_infer()
...@@ -85,8 +85,8 @@ class WanModel: ...@@ -85,8 +85,8 @@ class WanModel:
return weight_dict return weight_dict
def _load_quant_ckpt(self): def _load_quant_ckpt(self):
assert self.config.get("naive_quant_path") is not None, "naive_quant_path is None" assert self.config.get("quant_model_path") is not None, "quant_model_path is None"
ckpt_path = self.config.naive_quant_path ckpt_path = self.config.quant_model_path
logger.info(f"Loading quant model from {ckpt_path}") logger.info(f"Loading quant model from {ckpt_path}")
quant_pth_file = os.path.join(ckpt_path, "quant_weights.pth") quant_pth_file = os.path.join(ckpt_path, "quant_weights.pth")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment