Commit 5c241f86 authored by gushiqiao, committed by GitHub

Support running on low-memory machines, fix some bugs, and refactor quantization. (#61)



* Refactor quantization and fix a memory leak bug.

* Support lazy-load inference.

* Refactor quantization.

* Fix Hunyuan bugs.

* Delete temporary file.

---------
Co-authored-by: root <root@pt-c0b333b3a1834e81a0d4d5f412c6ffa1-worker-0.pt-c0b333b3a1834e81a0d4d5f412c6ffa1.ns-devsft-3460edd0.svc.cluster.local>
Co-authored-by: gushiqiao <gushqiaio@sensetime.com>
Co-authored-by: gushiqiao <gushiqiao@sensetime.com>
parent b7d2d43f
......@@ -14,5 +14,5 @@
"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm",
"quant_method": "smoothquant"
},
"quant_model_path": "/path/to/int8_model"
"dit_quantized_ckpt": "/path/to/dit_int8"
}
{
"infer_steps": 20,
"target_video_length": 33,
"target_height": 720,
"target_width": 1280,
"attention_type": "flash_attn3",
"seed": 42,
"mm_config": {
"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm"
},
"quant_model_path": "./hy_t2v_quant_model"
}
......@@ -3,15 +3,11 @@
"target_video_length": 81,
"target_height": 480,
"target_width": 832,
"attention_type": "flash_attn3",
"attention_type": "sage_attn2",
"seed": 42,
"sample_guide_scale": 5,
"sample_shift": 5,
"enable_cfg": true,
"cpu_offload": true,
"offload_granularity": "block",
"mm_config": {
"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F",
"weight_auto_quant": true
}
"offload_granularity": "block"
}
......@@ -10,12 +10,11 @@
"enable_cfg": true,
"cpu_offload": true,
"offload_granularity": "phase",
"t5_offload_granularity": "block",
"dit_quantized_ckpt": "/path/to/dit_int8",
"mm_config": {
"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F",
"weight_auto_quant": true
"weight_auto_quant": false
},
"use_tiling_vae": true,
"tiny_vae": true,
"tiny_vae_path": "/mnt/afs_2/gushiqiao/x2v_models/taew2_1.pth",
"text_encoder_offload_granularity": "block"
"use_tiling_vae": true
}
{
"infer_steps": 40,
"target_video_length": 81,
"target_height": 480,
"target_width": 832,
"attention_type": "sage_attn2",
"seed": 42,
"sample_guide_scale": 5,
"sample_shift": 5,
"enable_cfg": true,
"cpu_offload": true,
"offload_granularity": "phase",
"t5_offload_granularity": "block",
"dit_quantized_ckpt": "/path/to/dit_int8",
"mm_config": {
"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F"
},
"t5_quantized": true,
"t5_quantized_ckpt": "/path/to/models_t5_umt5-xxl-enc-int8.pth",
"t5_quant_scheme": "int8",
"clip_quantized": true,
"clip_quantized_ckpt": "/path/to/clip_int8.pth",
"clip_quant_scheme": "int8",
"use_tiling_vae": true,
"tiny_vae": true,
"tiny_vae_path": "/path/to/taew2_1.pth",
"lazy_load": true
}
......@@ -17,5 +17,5 @@
},
"tiny_vae": true,
"tiny_vae_path": "/mnt/afs_2/gushiqiao/x2v_models/taew2_1.pth",
"text_encoder_offload_granularity": "block"
"t5_offload_granularity": "block"
}
......@@ -4,8 +4,9 @@
"i2v_resolution": "720p",
"attention_type": "flash_attn3",
"seed": 0,
"dit_quantized_ckpt": "/mtc/gushiqiao/llmc_workspace/x2v_models/hunyuan/hunyuan_i2v_int8.pth",
"mm_config": {
"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm"
},
"quant_model_path": "./hy_i2v_quant_model"
"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm",
"weight_auto_quant": false
}
}
......@@ -3,14 +3,14 @@
"target_video_length": 81,
"target_height": 480,
"target_width": 832,
"attention_type": "flash_attn3",
"attention_type": "sage_attn2",
"seed": 42,
"sample_guide_scale": 5,
"sample_shift": 5,
"enable_cfg": true,
"cpu_offload": false,
"mm_config": {
"mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl"
},
"quant_model_path": "./wan_i2v_quant_model"
"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm",
"weight_auto_quant": true
}
}
{
"infer_steps": 50,
"infer_steps": 40,
"target_video_length": 81,
"text_len": 512,
"target_height": 480,
"target_width": 832,
"attention_type": "flash_attn3",
"attention_type": "sage_attn2",
"seed": 42,
"sample_guide_scale": 6,
"sample_shift": 8,
"sample_guide_scale": 5,
"sample_shift": 5,
"enable_cfg": true,
"cpu_offload": false,
"dit_quantized_ckpt": "/path/to/int8/model",
"mm_config": {
"mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl"
},
"quant_model_path": "./wan_t2v_quant_model"
"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm",
"weight_auto_quant": false
}
}
{
"infer_steps": 9,
"target_video_length": 81,
"target_height": 480,
"target_width": 832,
"attention_type": "flash_attn3",
"seed": 42,
"sample_guide_scale": 6,
"sample_shift": 8,
"enable_cfg": false,
"cpu_offload": false,
"mm_config": {
"mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl"
},
"num_fragments": 3,
"num_frames": 21,
"num_frame_per_block": 3,
"num_blocks": 7,
"frame_seq_length": 1560,
"denoising_step_list": [999, 934, 862, 756, 603, 410, 250, 140, 74]
}
# Quantization
lightx2v supports quantized inference for linear layers, supporting w8a8-int8 and w8a8-fp8 matrix multiplication.
lightx2v supports quantized inference for linear layers in **Dit**, enabling `w8a8-int8` and `w8a8-fp8` matrix multiplication.
## Generating Quantized Models
### Run Quantized Inference
### Automatic Quantization
```shell
# Modify the path in the script
bash scripts/run_wan_t2v_save_quant.sh
```
lightx2v supports automatic weight quantization during inference. Refer to the [configuration file](https://github.com/ModelTC/lightx2v/tree/main/configs/quantization/wan_i2v_quant_auto.json).
**Key configuration**:
Set `"mm_config": {"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm", "weight_auto_quant": true}`.
- `mm_type`: Specifies the quantized operator
- `weight_auto_quant: true`: Enables automatic model quantization (see the example config sketch below)
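As a minimal sketch (not an official script), the snippet below writes such an auto-quantization config from Python; only a few representative fields are shown, and the output filename `wan_i2v_quant_auto.json` is an assumed example name.
```python
import json

# Minimal auto-quantization config sketch: only representative fields are shown.
auto_quant_config = {
    "infer_steps": 40,
    "target_video_length": 81,
    "attention_type": "sage_attn2",
    "seed": 42,
    "mm_config": {
        # quantized matmul operator to use for the Dit linear layers
        "mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm",
        # quantize the weights automatically while loading the model
        "weight_auto_quant": True,
    },
}

with open("wan_i2v_quant_auto.json", "w") as f:
    json.dump(auto_quant_config, f, indent=4)
```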
## Quantized Inference
There are two execution commands in the script:
### Offline Quantization
#### Save Quantization Weights
lightx2v also supports direct loading of pre-quantized weights. For offline model quantization, refer to the [documentation](https://github.com/ModelTC/lightx2v/tree/main/tools/convert/readme.md).
Configure the [quantization file](https://github.com/ModelTC/lightx2v/tree/main/configs/quantization/wan_i2v_quant_offline.json):
1. Set `dit_quantized_ckpt` to the converted weight path
2. Set `weight_auto_quant` to `false` in `mm_config` (see the sketch below)
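For illustration only, the following sketch patches an existing config for offline quantization; the filenames and the checkpoint path are placeholders, not fixed conventions.
```python
import json

# Switch a config to offline (pre-quantized) weights: point dit_quantized_ckpt at the
# converted checkpoint and disable on-the-fly weight quantization.
# The paths below are placeholders for illustration.
with open("wan_i2v_quant_offline.json") as f:
    cfg = json.load(f)

cfg["dit_quantized_ckpt"] = "/path/to/dit_int8"    # converted weight path (step 1)
cfg["mm_config"]["weight_auto_quant"] = False      # weights are already quantized (step 2)

with open("wan_i2v_quant_offline.json", "w") as f:
    json.dump(cfg, f, indent=4)
```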
### Automatic Quantization
```shell
bash scripts/run_wan_i2v_quant_auto.sh
```
Set the `RUNNING_FLAG` environment variable to `save_naive_quant`, and set `--config_json` to the corresponding `json` file: `${lightx2v_path}/configs/wan_t2v_save_quant.json`. In this file, `quant_model_path` specifies the path to save the quantized model.
### Offline Quantization
```shell
bash scripts/run_wan_i2v_quant_offline.sh
```
#### Load Quantization Weights and Inference
Set the `RUNNING_FLAG` environment variable to `infer`, and set `--config_json` to the `json` file from the previous step.
## Launching Quantization Service
### Start Quantization Service
After saving the quantized weights, as in the previous loading step, set the `RUNNING_FLAG` environment variable to `infer`, and set `--config_json` to the `json` file from the first step.
After offline quantization, point `--config_json` to the offline quantization JSON file.
For example, modify the `scripts/start_server.sh` script as follows:
Example modification in `scripts/start_server.sh`:
```shell
export RUNNING_FLAG=infer
......@@ -33,6 +46,10 @@ python -m lightx2v.api_server \
--model_cls wan2.1 \
--task t2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/wan_t2v_save_quant.json \
--config_json ${lightx2v_path}/configs/quantization/wan_i2v_quant_offline.json \
--port 8000
```
## Advanced Quantization Features
Refer to the quantization tool [LLMC documentation](https://github.com/ModelTC/llmc/blob/main/docs/en/source/backend/lightx2v.md) for details.
# Quantization
lightx2v supports quantized inference for linear layers, supporting `w8a8-int8` and `w8a8-fp8` matrix multiplication.
lightx2v supports quantized inference for the linear layers in `Dit`, supporting `w8a8-int8` and `w8a8-fp8` matrix multiplication.
### Run Quantized Inference
## Generating Quantized Models
```shell
# Modify the path in the script
bash scripts/run_wan_t2v_save_quant.sh
```
### Automatic Quantization
lightx2v supports automatic weight quantization during inference; see the [configuration file](https://github.com/ModelTC/lightx2v/tree/main/configs/quantization/wan_i2v_quant_auto.json).
Note that `mm_config` must be set in the configuration file: `"mm_config": {"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm", "weight_auto_quant": true}`. `mm_type` specifies the quantized operator to use, and `weight_auto_quant: true` enables automatic conversion to a quantized model.
There are two execution commands in the script:
#### Save Quantization Weights
## Quantized Inference
Set the `RUNNING_FLAG` environment variable to `save_naive_quant` and point `--config_json` to the corresponding `json` file: `${lightx2v_path}/configs/wan_t2v_save_quant.json`, where `quant_model_path` specifies the path to save the quantized model.
### Offline Quantization
#### Load Quantization Weights and Inference
lightx2v also supports directly loading pre-quantized weights for inference; for offline model quantization, refer to the [documentation](https://github.com/ModelTC/lightx2v/tree/main/tools/convert/readme_zh.md).
Write the converted weight path into `dit_quantized_ckpt` in the [configuration file](https://github.com/ModelTC/lightx2v/tree/main/configs/quantization/wan_i2v_quant_offline.json), and set `weight_auto_quant` in `mm_config` to `false`.
Set the `RUNNING_FLAG` environment variable to `infer` and point `--config_json` to the `json` file from the first step.
### Start Quantization Service
### Automatic Quantization
```shell
bash scripts/run_wan_i2v_quant_auto.sh
```
### Offline Quantization
```shell
bash scripts/run_wan_i2v_quant_offline.sh
```
## Launching Quantization Service
After saving the quantized weights, as in the previous loading step, set the `RUNNING_FLAG` environment variable to `infer` and point `--config_json` to the `json` file from the first step.
After converting the quantized weights offline, it is recommended to point `--config_json` to the offline quantization `json` file.
For example, modify the `scripts/start_server.sh` script as follows:
......@@ -33,6 +41,10 @@ python -m lightx2v.api_server \
--model_cls wan2.1 \
--task t2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/wan_t2v_save_quant.json \
--config_json ${lightx2v_path}/configs/quantization/wan_i2v_quant_offline.json \
--port 8000
```
## Advanced Quantization Features
For details, refer to the quantization tool's [LLMC documentation](https://github.com/ModelTC/llmc/blob/main/docs/zh_cn/source/backend/lightx2v.md).
import os
import re
import glob
import json
import argparse
import torch
from safetensors import safe_open, torch as st
from loguru import logger
from tqdm import tqdm
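# Converts Wan model weights between lightx2v and Diffusers key naming ("forward" = lightx2v -> Diffusers,
# "backward" = the reverse), then saves the result as chunked safetensors files plus an index JSON.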
def get_key_mapping_rules(direction, model_type):
if model_type == "wan":
unified_rules = [
{"forward": (r"^head\.head$", "proj_out"), "backward": (r"^proj_out$", "head.head")},
{"forward": (r"^head\.modulation$", "scale_shift_table"), "backward": (r"^scale_shift_table$", "head.modulation")},
{"forward": (r"^text_embedding\.0\.", "condition_embedder.text_embedder.linear_1."), "backward": (r"^condition_embedder.text_embedder.linear_1\.", "text_embedding.0.")},
{"forward": (r"^text_embedding\.2\.", "condition_embedder.text_embedder.linear_2."), "backward": (r"^condition_embedder.text_embedder.linear_2\.", "text_embedding.2.")},
{"forward": (r"^time_embedding\.0\.", "condition_embedder.time_embedder.linear_1."), "backward": (r"^condition_embedder.time_embedder.linear_1\.", "time_embedding.0.")},
{"forward": (r"^time_embedding\.2\.", "condition_embedder.time_embedder.linear_2."), "backward": (r"^condition_embedder.time_embedder.linear_2\.", "time_embedding.2.")},
{"forward": (r"^time_projection\.1\.", "condition_embedder.time_proj."), "backward": (r"^condition_embedder.time_proj\.", "time_projection.1.")},
{"forward": (r"blocks\.(\d+)\.self_attn\.q\.", r"blocks.\1.attn1.to_q."), "backward": (r"blocks\.(\d+)\.attn1\.to_q\.", r"blocks.\1.self_attn.q.")},
{"forward": (r"blocks\.(\d+)\.self_attn\.k\.", r"blocks.\1.attn1.to_k."), "backward": (r"blocks\.(\d+)\.attn1\.to_k\.", r"blocks.\1.self_attn.k.")},
{"forward": (r"blocks\.(\d+)\.self_attn\.v\.", r"blocks.\1.attn1.to_v."), "backward": (r"blocks\.(\d+)\.attn1\.to_v\.", r"blocks.\1.self_attn.v.")},
{"forward": (r"blocks\.(\d+)\.self_attn\.o\.", r"blocks.\1.attn1.to_out.0."), "backward": (r"blocks\.(\d+)\.attn1\.to_out\.0\.", r"blocks.\1.self_attn.o.")},
{"forward": (r"blocks\.(\d+)\.cross_attn\.q\.", r"blocks.\1.attn2.to_q."), "backward": (r"blocks\.(\d+)\.attn2\.to_q\.", r"blocks.\1.cross_attn.q.")},
{"forward": (r"blocks\.(\d+)\.cross_attn\.k\.", r"blocks.\1.attn2.to_k."), "backward": (r"blocks\.(\d+)\.attn2\.to_k\.", r"blocks.\1.cross_attn.k.")},
{"forward": (r"blocks\.(\d+)\.cross_attn\.v\.", r"blocks.\1.attn2.to_v."), "backward": (r"blocks\.(\d+)\.attn2\.to_v\.", r"blocks.\1.cross_attn.v.")},
{"forward": (r"blocks\.(\d+)\.cross_attn\.o\.", r"blocks.\1.attn2.to_out.0."), "backward": (r"blocks\.(\d+)\.attn2\.to_out\.0\.", r"blocks.\1.cross_attn.o.")},
{"forward": (r"blocks\.(\d+)\.norm3\.", r"blocks.\1.norm2."), "backward": (r"blocks\.(\d+)\.norm2\.", r"blocks.\1.norm3.")},
{"forward": (r"blocks\.(\d+)\.ffn\.0\.", r"blocks.\1.ffn.net.0.proj."), "backward": (r"blocks\.(\d+)\.ffn\.net\.0\.proj\.", r"blocks.\1.ffn.0.")},
{"forward": (r"blocks\.(\d+)\.ffn\.2\.", r"blocks.\1.ffn.net.2."), "backward": (r"blocks\.(\d+)\.ffn\.net\.2\.", r"blocks.\1.ffn.2.")},
{"forward": (r"blocks\.(\d+)\.modulation\.", r"blocks.\1.scale_shift_table."), "backward": (r"blocks\.(\d+)\.scale_shift_table(?=\.|$)", r"blocks.\1.modulation")},
{"forward": (r"blocks\.(\d+)\.cross_attn\.k_img\.", r"blocks.\1.attn2.add_k_proj."), "backward": (r"blocks\.(\d+)\.attn2\.add_k_proj\.", r"blocks.\1.cross_attn.k_img.")},
{"forward": (r"blocks\.(\d+)\.cross_attn\.v_img\.", r"blocks.\1.attn2.add_v_proj."), "backward": (r"blocks\.(\d+)\.attn2\.add_v_proj\.", r"blocks.\1.cross_attn.v_img.")},
{
"forward": (r"blocks\.(\d+)\.cross_attn\.norm_k_img\.weight", r"blocks.\1.attn2.norm_added_k.weight"),
"backward": (r"blocks\.(\d+)\.attn2\.norm_added_k\.weight", r"blocks.\1.cross_attn.norm_k_img.weight"),
},
{"forward": (r"img_emb\.proj\.0\.", r"condition_embedder.image_embedder.norm1."), "backward": (r"condition_embedder\.image_embedder\.norm1\.", r"img_emb.proj.0.")},
{"forward": (r"img_emb\.proj\.1\.", r"condition_embedder.image_embedder.ff.net.0.proj."), "backward": (r"condition_embedder\.image_embedder\.ff\.net\.0\.proj\.", r"img_emb.proj.1.")},
{"forward": (r"img_emb\.proj\.3\.", r"condition_embedder.image_embedder.ff.net.2."), "backward": (r"condition_embedder\.image_embedder\.ff\.net\.2\.", r"img_emb.proj.3.")},
{"forward": (r"img_emb\.proj\.4\.", r"condition_embedder.image_embedder.norm2."), "backward": (r"condition_embedder\.image_embedder\.norm2\.", r"img_emb.proj.4.")},
{"forward": (r"blocks\.(\d+)\.self_attn\.norm_q\.weight", r"blocks.\1.attn1.norm_q.weight"), "backward": (r"blocks\.(\d+)\.attn1\.norm_q\.weight", r"blocks.\1.self_attn.norm_q.weight")},
{"forward": (r"blocks\.(\d+)\.self_attn\.norm_k\.weight", r"blocks.\1.attn1.norm_k.weight"), "backward": (r"blocks\.(\d+)\.attn1\.norm_k\.weight", r"blocks.\1.self_attn.norm_k.weight")},
{"forward": (r"blocks\.(\d+)\.cross_attn\.norm_q\.weight", r"blocks.\1.attn2.norm_q.weight"), "backward": (r"blocks\.(\d+)\.attn2\.norm_q\.weight", r"blocks.\1.cross_attn.norm_q.weight")},
{"forward": (r"blocks\.(\d+)\.cross_attn\.norm_k\.weight", r"blocks.\1.attn2.norm_k.weight"), "backward": (r"blocks\.(\d+)\.attn2\.norm_k\.weight", r"blocks.\1.cross_attn.norm_k.weight")},
# head projection mapping
{"forward": (r"^head\.head\.", "proj_out."), "backward": (r"^proj_out\.", "head.head.")},
]
if direction == "forward":
return [rule["forward"] for rule in unified_rules]
elif direction == "backward":
return [rule["backward"] for rule in unified_rules]
else:
raise ValueError(f"Invalid direction: {direction}")
else:
raise ValueError(f"Unsupported model type: {model_type}")
def convert_weights(args):
if os.path.isdir(args.source):
src_files = glob.glob(os.path.join(args.source, "*.safetensors"), recursive=True)
elif args.source.endswith((".pth", ".safetensors", ".pt")):
src_files = [args.source]
else:
raise ValueError("Invalid input path")
merged_weights = {}
logger.info(f"Processing source files: {src_files}")
for file_path in tqdm(src_files, desc="Loading weights"):
logger.info(f"Loading weights from: {file_path}")
if file_path.endswith(".pt") or file_path.endswith(".pth"):
weights = torch.load(file_path, map_location="cpu", weights_only=True)
elif file_path.endswith(".safetensors"):
with safe_open(file_path, framework="pt") as f:
weights = {k: f.get_tensor(k) for k in f.keys()}
duplicate_keys = set(weights.keys()) & set(merged_weights.keys())
if duplicate_keys:
raise ValueError(f"Duplicate keys found: {duplicate_keys} in file {file_path}")
merged_weights.update(weights)
rules = get_key_mapping_rules(args.direction, args.model_type)
converted_weights = {}
logger.info("Converting keys...")
for key in tqdm(merged_weights.keys(), desc="Converting keys"):
new_key = key
for pattern, replacement in rules:
new_key = re.sub(pattern, replacement, new_key)
converted_weights[new_key] = merged_weights[key]
os.makedirs(args.output, exist_ok=True)
base_name = os.path.splitext(os.path.basename(args.source))[0] if args.source.endswith((".pth", ".safetensors", ".pt")) else "converted_model"
index = {"metadata": {"total_size": 0}, "weight_map": {}}
chunk_idx = 0
current_chunk = {}
for idx, (k, v) in tqdm(enumerate(converted_weights.items()), desc="Saving chunks"):
current_chunk[k] = v
if args.chunk_size > 0 and (idx + 1) % args.chunk_size == 0:
output_filename = f"{base_name}_part{chunk_idx}.safetensors"
output_path = os.path.join(args.output, output_filename)
logger.info(f"Saving chunk to: {output_path}")
st.save_file(current_chunk, output_path)
for key in current_chunk:
index["weight_map"][key] = output_filename
index["metadata"]["total_size"] += os.path.getsize(output_path)
current_chunk = {}
chunk_idx += 1
if current_chunk:
output_filename = f"{base_name}_part{chunk_idx}.safetensors"
output_path = os.path.join(args.output, output_filename)
logger.info(f"Saving final chunk to: {output_path}")
st.save_file(current_chunk, output_path)
for key in current_chunk:
index["weight_map"][key] = output_filename
index["metadata"]["total_size"] += os.path.getsize(output_path)
# Save index file
index_path = os.path.join(args.output, "diffusion_pytorch_model.safetensors.index.json")
with open(index_path, "w", encoding="utf-8") as f:
json.dump(index, f, indent=2)
logger.info(f"Index file written to: {index_path}")
def main():
parser = argparse.ArgumentParser(description="Model weight format converter")
parser.add_argument("-s", "--source", required=True, help="Input path (file or directory)")
parser.add_argument("-o", "--output", required=True, help="Output directory path")
parser.add_argument("-d", "--direction", choices=["forward", "backward"], default="forward", help="Conversion direction: forward = 'lightx2v' -> 'Diffusers', backward = reverse")
parser.add_argument("-c", "--chunk-size", type=int, default=100, help="Chunk size for saving (only applies to forward), 0 = no chunking")
parser.add_argument("-t", "--model_type", choices=["wan"], default="wan", help="Model type")
args = parser.parse_args()
if os.path.isfile(args.output):
raise ValueError("Output path must be a directory, not a file")
logger.info("Starting model weight conversion...")
convert_weights(args)
logger.info(f"Conversion completed! Files saved to: {args.output}")
if __name__ == "__main__":
main()
......@@ -24,13 +24,42 @@ class WeightModule:
if hasattr(parameter, "load"):
parameter.load(weight_dict)
def calculate_size(self):
total_size = 0
for _, module in self._modules.items():
if hasattr(module, "_calculate_size"):
total_size += module._calculate_size()
for _, parameter in self._parameters.items():
if hasattr(parameter, "_calculate_size"):
total_size += parameter._calculate_size()
return total_size
def load_from_disk(self):
for _, module in self._modules.items():
if hasattr(module, "load_from_disk"):
module.load_from_disk()
for _, parameter in self._parameters.items():
if hasattr(parameter, "load_from_disk"):
parameter.load_from_disk()
def clear(self):
for _, module in self._modules.items():
if hasattr(module, "clear"):
module.clear()
for _, parameter in self._parameters.items():
if hasattr(parameter, "clear"):
parameter.clear()
def state_dict(self, destination=None):
if destination is None:
destination = {}
for name, param in self._parameters.items():
for _, param in self._parameters.items():
if param is not None:
param.state_dict(destination)
for name, module in self._modules.items():
for _, module in self._modules.items():
if module is not None:
module.state_dict(destination)
return destination
......@@ -58,6 +87,9 @@ class WeightModule:
for m in module[i]._modules.values():
if m is not None and hasattr(m, "to_cpu"):
m.to_cpu()
for m in module[i]._parameters.values():
if m is not None and hasattr(m, "to_cpu"):
m.to_cpu()
else:
if module is not None and hasattr(module, "to_cpu"):
module.to_cpu()
......@@ -76,6 +108,9 @@ class WeightModule:
for m in module[i]._modules.values():
if m is not None and hasattr(m, "to_cuda"):
m.to_cuda()
for m in module[i]._parameters.values():
if m is not None and hasattr(m, "to_cuda"):
m.to_cuda()
else:
if module is not None and hasattr(module, "to_cuda"):
module.to_cuda()
......@@ -95,6 +130,9 @@ class WeightModule:
for m in module[i]._modules.values():
if m is not None and hasattr(m, "to_cpu"):
m.to_cpu(non_blocking=True)
for m in module[i]._parameters.values():
if m is not None and hasattr(m, "to_cpu"):
m.to_cpu(non_blocking=True)
else:
if module is not None and hasattr(module, "to_cpu"):
module.to_cpu(non_blocking=True)
......@@ -113,6 +151,9 @@ class WeightModule:
for m in module[i]._modules.values():
if m is not None and hasattr(m, "to_cuda"):
m.to_cuda(non_blocking=True)
for m in module[i]._parameters.values():
if m is not None and hasattr(m, "to_cuda"):
m.to_cuda(non_blocking=True)
else:
if module is not None and hasattr(module, "to_cuda"):
module.to_cuda(non_blocking=True)
......
import torch
import threading
import queue
import time
from loguru import logger
from collections import OrderedDict
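# WeightAsyncStreamManager: overlaps computation with CPU<->GPU weight transfers by using a
# high-priority compute stream and separate CPU/CUDA load streams for block-wise offloading.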
class WeightAsyncStreamManager(object):
def __init__(self, blocks_num, offload_ratio=1, phases_num=1):
self.active_weights = [None for _ in range(3)]
self.compute_stream = torch.cuda.Stream(priority=-1)
self.cpu_load_stream = torch.cuda.Stream(priority=0)
self.cuda_load_stream = torch.cuda.Stream(priority=0)
self.offload_block_num = offload_ratio * blocks_num
self.offload_block_num = int(offload_ratio * blocks_num)
self.phases_num = phases_num
self.offload_phases_num = blocks_num * phases_num * offload_ratio
......@@ -47,3 +51,197 @@ class WeightAsyncStreamManager(object):
self.cpu_load_stream.synchronize()
self.cuda_load_stream.synchronize()
self.active_weights[0], self.active_weights[1] = self.active_weights[2], self.active_weights[0]
self.active_weights[2] = None
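# LazyWeightAsyncStreamManager: streams block/phase weights from disk on demand using background
# worker threads and a size-bounded pinned-memory buffer, so only a small working set of weights
# is resident at any time.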
class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
def __init__(self, blocks_num, offload_ratio=1, phases_num=1, num_disk_workers=1, max_memory=2):
super().__init__(blocks_num, offload_ratio, phases_num)
self.worker_stop_event = threading.Event()
self.pin_memory_buffer = MemoryBuffer(max_memory * (1024**3))
self.disk_task_queue = queue.PriorityQueue()
self.disk_workers = []
self.release_workers = []
self._start_disk_workers(num_disk_workers)
self.initial_prefetch_done = False
self.pending_tasks = {}
self.task_lock = threading.Lock()
self.last_used_time = {}
self.time_lock = threading.Lock()
def _start_disk_workers(self, num_workers):
for i in range(num_workers):
worker = threading.Thread(target=self._disk_worker_loop, daemon=True)
worker.start()
self.disk_workers.append(worker)
def _disk_worker_loop(self):
while not self.worker_stop_event.is_set():
try:
_, task = self.disk_task_queue.get(timeout=0.5)
if task is None:
break
block_idx, phase_idx, phase = task
phase.load_from_disk()
self.pin_memory_buffer.push((block_idx, phase_idx), phase)
with self.task_lock:
if (block_idx, phase_idx) in self.pending_tasks:
del self.pending_tasks[(block_idx, phase_idx)]
except queue.Empty:
continue
except Exception as e:
logger.error(f"Disk worker thread error: {e}")
def _async_prefetch_block(self, weights):
next_block_idx = self.pin_memory_buffer.get_max_block_index()
if next_block_idx < 0:
next_block_idx = 0
for phase_idx in range(self.phases_num):
obj_key = (next_block_idx, phase_idx)
if self.pin_memory_buffer.exists(obj_key) or (obj_key in self.pending_tasks):
continue
with self.task_lock:
self.pending_tasks[obj_key] = True
phase = weights.blocks[next_block_idx].compute_phases[phase_idx]
priority_key = (next_block_idx, phase_idx)
self.disk_task_queue.put((priority_key, (next_block_idx, phase_idx, phase)))
def _sync_prefetch_block(self, weights):
block_idx = 0
while not self.pin_memory_buffer.is_nearly_full():
for phase_idx in range(self.phases_num):
phase = weights.blocks[block_idx].compute_phases[phase_idx]
logger.info(f"Synchronous loading: block={block_idx}, phase={phase_idx}")
phase.load_from_disk()
self.pin_memory_buffer.push((block_idx, phase_idx), phase)
block_idx += 1
def prefetch_weights_from_disk(self, weights):
if self.initial_prefetch_done:
return
self._sync_prefetch_block(weights)
self.initial_prefetch_done = True
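# prefetch_phase: make sure the (block_idx, phase_idx) weights are in the pinned buffer (waiting for
# an in-flight disk load if necessary), stage them onto the GPU on cuda_load_stream, and evict the
# previously active phase from the pinned buffer on cpu_load_stream.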
def prefetch_phase(self, block_idx, phase_idx, blocks):
obj_key = (block_idx, phase_idx)
if not self.pin_memory_buffer.exists(obj_key):
is_loading = False
with self.task_lock:
if obj_key in self.pending_tasks:
is_loading = True
if is_loading:
start_time = time.time()
while not self.pin_memory_buffer.exists(obj_key):
time.sleep(0.001)
if time.time() - start_time > 5:
raise TimeoutError(f"Load timeout: block={block_idx}, phase={phase_idx}")
else:
logger.info("Not find prefetch block={block_idx}, phase={phase_idx} task. This is a bug.")
with torch.cuda.stream(self.cuda_load_stream):
phase = self.pin_memory_buffer.get(obj_key)
phase.to_cuda_async()
self.active_weights[2] = (obj_key, phase)
with torch.cuda.stream(self.cpu_load_stream):
if block_idx * self.phases_num + phase_idx < self.offload_phases_num:
if self.active_weights[1] is not None:
old_key, old_phase = self.active_weights[1]
if self.pin_memory_buffer.exists(old_key):
old_phase.to_cpu_async()
self.pin_memory_buffer.pop(old_key)
def shutdown(self):
self.worker_stop_event.set()
while not self.disk_task_queue.empty():
try:
self.disk_task_queue.get_nowait()
except queue.Empty:
continue
for _ in self.disk_workers:
self.disk_task_queue.put((0, None))
for worker in self.disk_workers:
worker.join(timeout=5)
for worker in self.release_workers:
worker.join(timeout=5)
logger.info("All worker threads have been closed")
class MemoryBuffer:
def __init__(self, max_memory_bytes=8 * (1024**3)):
self.cache = OrderedDict()
self.max_mem = max_memory_bytes
self.used_mem = 0
self.phases_size_map = {}
self.lock = threading.Lock()
self.insertion_order = []
self.insertion_index = 0
def push(self, key, phase_obj):
with self.lock:
if key in self.cache:
return
_, phase_idx = key
if phase_idx not in self.phases_size_map:
self.phases_size_map[phase_idx] = phase_obj.calculate_size()
size = self.phases_size_map[phase_idx]
self.cache[key] = (size, phase_obj, self.insertion_index)
self.insertion_order.append((key, self.insertion_index))
self.insertion_index += 1
self.used_mem += size
def _remove_key(self, key):
if key in self.cache:
size, phase, idx = self.cache.pop(key)
try:
phase.clear()
except Exception as e:
logger.info(f"Error clearing phase: {e}")
self.used_mem -= size
self.insertion_order = [(k, i) for (k, i) in self.insertion_order if k != key]
def get(self, key, default=None):
with self.lock:
if key in self.cache:
size, phase, idx = self.cache[key]
return phase
return default
def exists(self, key):
with self.lock:
return key in self.cache
def pop(self, key):
with self.lock:
if key in self.cache:
self._remove_key(key)
return True
return False
def is_nearly_full(self):
with self.lock:
return self.used_mem >= self.max_mem * 0.9
def get_max_block_index(self):
with self.lock:
if not self.cache:
return -1
return max((key[0] + 1) % 40 for key in self.cache.keys())
......@@ -27,7 +27,7 @@ if torch.cuda.get_device_capability(0) == (8, 9):
from sageattention import sageattn_qk_int8_pv_fp16_triton as sageattn
except ImportError:
print("sageattn not found, please install sageattention first")
sageattn = None, None
sageattn = None
else:
try:
from sageattention import sageattn
......
This diff is collapsed.
......@@ -4,15 +4,53 @@ from lightx2v.utils.registry_factory import LN_WEIGHT_REGISTER
class LNWeightTemplate(metaclass=ABCMeta):
def __init__(self, weight_name, bias_name, eps=1e-6):
def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None, eps=1e-6):
self.weight_name = weight_name
self.bias_name = bias_name
self.eps = eps
self.lazy_load = lazy_load
self.lazy_load_file = lazy_load_file
self.config = {}
def load_from_disk(self):
if self.weight_name is not None:
if not torch._dynamo.is_compiling():
self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(torch.bfloat16).pin_memory()
else:
self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(torch.bfloat16)
else:
self.weight = None
if self.bias_name is not None:
if not torch._dynamo.is_compiling():
self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(torch.bfloat16).pin_memory()
else:
self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(torch.bfloat16)
else:
self.bias = None
def load(self, weight_dict):
self.weight = weight_dict[self.weight_name].cuda() if self.weight_name is not None else None
self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None
if not self.lazy_load:
if self.weight_name is not None:
self.weight = weight_dict[self.weight_name]
self.pinned_weight = torch.empty(self.weight.shape, pin_memory=True, dtype=self.weight.dtype)
else:
self.weight = None
if self.bias_name is not None:
self.bias = weight_dict[self.bias_name]
self.pinned_bias = torch.empty(self.bias.shape, pin_memory=True, dtype=self.bias.dtype)
else:
self.bias = None
def _calculate_size(self):
if self.bias is not None:
return self.weight.numel() * self.weight.element_size() + self.bias.numel() * self.bias.element_size()
return self.weight.numel() * self.weight.element_size()
def clear(self):
del self.weight
if self.bias is not None:
del self.bias
@abstractmethod
def apply(self, input_tensor):
......@@ -23,10 +61,15 @@ class LNWeightTemplate(metaclass=ABCMeta):
self.config = config
def to_cpu(self, non_blocking=False):
if self.weight is not None:
self.weight = self.weight.to("cpu", non_blocking=non_blocking)
if self.bias is not None:
self.bias = self.bias.to("cpu", non_blocking=non_blocking)
if hasattr(self, "pinned_weight"):
self.weight = self.pinned_weight.copy_(self.weight, non_blocking=non_blocking).cpu()
if self.bias is not None:
self.bias = self.pinned_bias.copy_(self.bias, non_blocking=non_blocking).cpu()
else:
if self.weight is not None:
self.weight = self.weight.to("cpu", non_blocking=non_blocking)
if self.bias is not None:
self.bias = self.bias.to("cpu", non_blocking=non_blocking)
def to_cuda(self, non_blocking=False):
if self.weight is not None:
......@@ -46,8 +89,8 @@ class LNWeightTemplate(metaclass=ABCMeta):
@LN_WEIGHT_REGISTER("Default")
class LNWeight(LNWeightTemplate):
def __init__(self, weight_name, bias_name, eps=1e-6):
super().__init__(weight_name, bias_name, eps)
def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None, eps=1e-6):
super().__init__(weight_name, bias_name, lazy_load, lazy_load_file, eps)
def apply(self, input_tensor):
input_tensor = torch.nn.functional.layer_norm(input_tensor, (input_tensor.shape[-1],), self.weight, self.bias, self.eps)
......
......@@ -5,13 +5,26 @@ import sgl_kernel
class RMSWeightTemplate(metaclass=ABCMeta):
def __init__(self, weight_name, eps=1e-6):
def __init__(self, weight_name, lazy_load=False, lazy_load_file=None, eps=1e-6):
self.weight_name = weight_name
self.eps = eps
self.lazy_load = lazy_load
self.lazy_load_file = lazy_load_file
self.config = {}
def load_from_disk(self):
if not torch._dynamo.is_compiling():
self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(torch.bfloat16).pin_memory()
else:
self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(torch.bfloat16)
def load(self, weight_dict):
self.weight = weight_dict[self.weight_name].cuda()
if not self.lazy_load:
self.weight = weight_dict[self.weight_name]
self.pinned_weight = torch.empty(self.weight.shape, pin_memory=True, dtype=self.weight.dtype)
def clear(self):
del self.weight
@abstractmethod
def apply(self, input_tensor):
......@@ -22,16 +35,22 @@ class RMSWeightTemplate(metaclass=ABCMeta):
self.config = config
def to_cpu(self, non_blocking=False):
self.weight = self.weight.to("cpu", non_blocking=non_blocking)
if hasattr(self, "pinned_weight"):
self.weight = self.pinned_weight.copy_(self.weight, non_blocking=non_blocking).cpu()
else:
self.weight = self.weight.to("cpu", non_blocking=non_blocking)
def to_cuda(self, non_blocking=False):
self.weight = self.weight.cuda(non_blocking=non_blocking)
def _calculate_size(self):
return self.weight.numel() * self.weight.element_size()
@RMS_WEIGHT_REGISTER("Default")
class RMSWeight(RMSWeightTemplate):
def __init__(self, weight_name, eps=1e-6):
super().__init__(weight_name, eps)
def __init__(self, weight_name, lazy_load=False, lazy_load_file=None, eps=1e-6):
super().__init__(weight_name, lazy_load, lazy_load_file, eps)
def apply(self, input_tensor):
input_tensor = input_tensor * torch.rsqrt(input_tensor.pow(2).mean(-1, keepdim=True) + self.eps)
......@@ -47,8 +66,8 @@ class RMSWeight(RMSWeightTemplate):
@RMS_WEIGHT_REGISTER("FP32")
class RMSWeightFP32(RMSWeight):
def __init__(self, weight_name, eps=1e-6):
super().__init__(weight_name, eps)
def __init__(self, weight_name, lazy_load=False, lazy_load_file=None, eps=1e-6):
super().__init__(weight_name, lazy_load, lazy_load_file, eps)
def apply(self, input_tensor):
input_tensor = input_tensor.float()
......@@ -60,8 +79,8 @@ class RMSWeightFP32(RMSWeight):
@RMS_WEIGHT_REGISTER("sgl-kernel")
class RMSWeightSgl(RMSWeight):
def __init__(self, weight_name, eps=1e-6):
super().__init__(weight_name, eps)
def __init__(self, weight_name, lazy_load=False, lazy_load_file=None, eps=1e-6):
super().__init__(weight_name, lazy_load, lazy_load_file, eps)
def apply(self, input_tensor):
input_tensor = input_tensor.contiguous()
......
import torch
from lightx2v.utils.registry_factory import TENSOR_REGISTER
from safetensors import safe_open
@TENSOR_REGISTER("Default")
class DefaultTensor:
def __init__(self, tensor_name):
def __init__(self, tensor_name, lazy_load=False, lazy_load_file=None):
self.tensor_name = tensor_name
self.lazy_load = lazy_load
self.lazy_load_file = lazy_load_file
def load_from_disk(self):
if not torch._dynamo.is_compiling():
self.tensor = self.lazy_load_file.get_tensor(self.tensor_name).to(torch.bfloat16).pin_memory()
else:
self.tensor = self.lazy_load_file.get_tensor(self.tensor_name).to(torch.bfloat16)
def load(self, weight_dict):
self.tensor = weight_dict[self.tensor_name]
self.pinned_tensor = torch.empty(self.tensor.shape, pin_memory=True, dtype=self.tensor.dtype)
if not self.lazy_load:
self.tensor = weight_dict[self.tensor_name].to(torch.bfloat16)
self.pinned_tensor = torch.empty(self.tensor.shape, pin_memory=True, dtype=self.tensor.dtype)
def clear(self):
del self.tensor
def _calculate_size(self):
return self.tensor.numel() * self.tensor.element_size()
def to_cpu(self, non_blocking=False):
# self.tensor = self.tensor.to("cpu", non_blocking=non_blocking)
self.tensor = self.pinned_tensor.copy_(self.tensor, non_blocking=non_blocking).cpu()
if hasattr(self, "pinned_tensor"):
self.tensor = self.pinned_tensor.copy_(self.tensor, non_blocking=non_blocking).cpu()
else:
self.tensor = self.tensor.to("cpu", non_blocking=non_blocking)
def to_cuda(self, non_blocking=False):
self.tensor = self.tensor.cuda(non_blocking=non_blocking)
......