Commit 7f2fbd00 authored by gushiqiao

Add torchao kernel

parent 29a90944
{
"infer_steps": 4,
"target_video_length": 81,
"target_height": 480,
"target_width": 832,
"self_attn_1_type": "sage_attn2",
"cross_attn_1_type": "sage_attn2",
"cross_attn_2_type": "sage_attn2",
"seed": 42,
"sample_guide_scale": 5,
"sample_shift": 5,
"enable_cfg": false,
"cpu_offload": false,
"mm_config": {
"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Torchao"
},
"t5_quantized": true,
"t5_quant_scheme": "int8-torchao",
"clip_quantized": true,
"clip_quant_scheme": "int8-torchao"
}
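For reference, a minimal sketch of consuming this config; the filename here is hypothetical, and the assertion simply ties the `mm_type` string to the kernel class registered later in this commit:

```python
import json

# Hypothetical filename; the key layout follows the config above.
with open("wan_i2v_distill_int8_torchao.json") as f:
    cfg = json.load(f)

# The matmul kernel is chosen by this string, which must match the name
# registered with @MM_WEIGHT_REGISTER below.
assert cfg["mm_config"]["mm_type"] == "W-int8-channel-sym-A-int8-channel-sym-dynamic-Torchao"

# T5 and CLIP encoders pick the torchao int8 linear layers via these fields.
print(cfg["t5_quant_scheme"], cfg["clip_quant_scheme"])  # int8-torchao int8-torchao
```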
......@@ -19,22 +19,6 @@ This document provides detailed instructions for deploying LightX2V locally on W
- **CUDA**: 12.4 or higher
- **Dependencies**: Refer to the LightX2V project's requirements_win.txt
### Installation Steps
1. **Clone Project**
```cmd
git clone https://github.com/ModelTC/LightX2V.git
cd LightX2V
```
2. **Install Dependencies**
```cmd
pip install -r requirements_win.txt
```
3. **Download Models**
Refer to [Model Download Guide](../getting_started/quickstart.md) to download required models
## 🎯 Usage Methods
### Method 1: Using Batch File Inference
......@@ -113,42 +97,4 @@ Double-click to run the `start_lightx2v.bat` file, the script will:
### Method 3: Using ComfyUI Inference
This guide explains how to download and use the portable Lightx2v-ComfyUI environment, which lets you skip manual environment configuration. It is suitable for users who want to quickly start generating accelerated videos with Lightx2v on Windows.
#### Download the Windows Portable Environment:
- [Baidu Cloud Download](https://pan.baidu.com/s/1FVlicTXjmXJA1tAVvNCrBw?pwd=wfid), extraction code: wfid
The portable environment packages all Python runtime dependencies, including the code and dependencies for ComfyUI and LightX2V. After downloading, simply extract it and it is ready to use.
After extraction, the directory structure is as follows:
```shell
lightx2v_env
├──📂 ComfyUI # ComfyUI code
├──📂 portable_python312_embed # Standalone Python environment
└── run_nvidia_gpu.bat # Windows startup script (double-click to start)
```
#### Start ComfyUI
Double-click the run_nvidia_gpu.bat file. A Command Prompt window will open and run the program. The first startup may take a while; please be patient. Once startup completes, your browser will open automatically and display the ComfyUI frontend interface.
![i2v example workflow](../../../../assets/figs/portabl_windows/pic1.png)
The plugin used by LightX2V-ComfyUI is [ComfyUI-Lightx2vWrapper](https://github.com/ModelTC/ComfyUI-Lightx2vWrapper). Example workflows can be obtained from this project.
#### Tested Graphics Cards (offload mode)
- Tested model: `Wan2.1-I2V-14B-480P`
| GPU Model | Task Type | VRAM Capacity | Actual Max VRAM Usage | Actual Max RAM Usage |
|:-----------|:------------|:--------------|:---------------------|:---------------------|
| 3090Ti | I2V | 24G | 6.1G | 7.1G |
| 3080Ti | I2V | 12G | 6.1G | 7.1G |
| 3060Ti | I2V | 8G | 6.1G | 7.1G |
#### Environment Packaging and Usage Reference
- [ComfyUI](https://github.com/comfyanonymous/ComfyUI)
- [Portable-Windows-ComfyUI-Docs](https://docs.comfy.org/zh-CN/installation/comfyui_portable_windows#portable-%E5%8F%8A%E8%87%AA%E9%83%A8%E7%BD%B2)
TODO: ComfyUI integration guide to be added.
......@@ -25,6 +25,11 @@ try:
except ImportError:
deep_gemm = None
try:
from torchao.quantization.utils import quant_int8_per_token_matmul, quantize_activation_per_token_absmax
except ModuleNotFoundError:
quant_int8_per_token_matmul, quantize_activation_per_token_absmax = None, None
class MMWeightTemplate(metaclass=ABCMeta):
def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
......@@ -232,6 +237,9 @@ class MMWeightQuantTemplate(MMWeightTemplate):
# =========================
# act quant kernels
# =========================
def act_quant_int8_perchannel_sym_torchao(self, x):
input_tensor_quant, input_tensor_scale = quantize_activation_per_token_absmax(x)
return input_tensor_quant, input_tensor_scale
def act_quant_fp8_perchannel_sym_vllm(self, x):
input_tensor_quant, input_tensor_scale = ops.scaled_fp8_quant(x, None, scale_ub=None, use_per_token_if_dynamic=True)
......@@ -624,6 +632,33 @@ class MMWeightWint8channelAint8channeldynamicSglActVllm(MMWeightQuantTemplate):
return output_tensor
@MM_WEIGHT_REGISTER("W-int8-channel-sym-A-int8-channel-sym-dynamic-Torchao")
class MMWeightWint8channelAint8channeldynamicTorchao(MMWeightQuantTemplate):
"""
Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Torchao
Quant MM:
Weight: int8 perchannel sym
Act: int8 perchannel dynamic sym
Kernel: Torchao
"""
def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
self.load_func = self.load_int8_perchannel_sym
self.weight_need_transpose = True
self.act_quant_func = self.act_quant_int8_perchannel_sym_torchao
def apply(self, input_tensor):
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
output_tensor = quant_int8_per_token_matmul(input_tensor_quant, input_tensor_scale, self.weight, self.weight_scale.t().float(), output_dtype=torch.bfloat16)
if self.bias is not None:
output_tensor = output_tensor + self.bias
return output_tensor
if __name__ == "__main__":
weight_dict = {
"xx.weight": torch.randn(8192, 4096).to(torch.float8_e4m3fn),
......
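For intuition, the sketch below reimplements the math behind the torchao path in plain PyTorch: per-token symmetric int8 quantization of the activation (absmax / 127 per row) followed by a matmul that is rescaled by the activation and weight scales. It is a reference for checking shapes and semantics under those assumptions, not the torchao kernel itself.

```python
import torch

def per_token_absmax_quant(x: torch.Tensor):
    # Symmetric int8 quantization with one scale per row
    # (a token for activations, an output channel for weights).
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
    return q, scale

def int8_per_token_matmul_ref(x_q, x_scale, w_q_t, w_scale_t, out_dtype=torch.bfloat16):
    # Accumulate, then rescale by the per-token and per-channel scales.
    # float32 accumulation here for simplicity; the real kernel accumulates in int32 on GPU.
    acc = x_q.float() @ w_q_t.float()
    return (acc * x_scale * w_scale_t).to(out_dtype)

x = torch.randn(16, 4096)       # (tokens, in_features)
w = torch.randn(1024, 4096)     # (out_features, in_features)
w_q, w_scale = per_token_absmax_quant(w)     # per-output-channel weight scales
x_q, x_scale = per_token_absmax_quant(x)     # per-token activation scales
y = int8_per_token_matmul_ref(x_q, x_scale, w_q.t(), w_scale.t())
print(y.shape, y.dtype)         # torch.Size([16, 1024]) torch.bfloat16
```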
import torch
import torch.nn as nn
from vllm import _custom_ops as ops
try:
from vllm import _custom_ops as ops
except ModuleNotFoundError:
ops = None
class QuantLinearInt8(nn.Module):
try:
from torchao.quantization.utils import quant_int8_per_token_matmul, quantize_activation_per_token_absmax
except ModuleNotFoundError:
quant_int8_per_token_matmul, quantize_activation_per_token_absmax = None, None
class VllmQuantLinearInt8(nn.Module):
def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16):
super().__init__()
self.in_features = in_features
......@@ -54,7 +63,7 @@ class QuantLinearInt8(nn.Module):
return self
class QuantLinearFp8(nn.Module):
class VllmQuantLinearFp8(nn.Module):
def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16):
super().__init__()
self.in_features = in_features
......@@ -101,3 +110,45 @@ class QuantLinearFp8(nn.Module):
self.weight_scale = maybe_cast(self.weight_scale)
self.bias = maybe_cast(self.bias)
return self
class TorchaoQuantLinearInt8(nn.Module):
def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.int8))
self.register_buffer("weight_scale", torch.empty((out_features, 1), dtype=torch.float32))
if bias:
self.register_buffer("bias", torch.empty(out_features, dtype=dtype))
else:
self.register_buffer("bias", None)
def act_quant_func(self, x):
input_tensor_quant, input_tensor_scale = quantize_activation_per_token_absmax(x)
return input_tensor_quant, input_tensor_scale
def forward(self, input_tensor):
input_tensor = input_tensor.squeeze(0)
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
output_tensor = quant_int8_per_token_matmul(input_tensor_quant, input_tensor_scale, self.weight.t(), self.weight_scale.t().float(), output_dtype=torch.bfloat16)
if self.bias is not None:
output_tensor = output_tensor + self.bias
return output_tensor.unsqueeze(0)
def _apply(self, fn):
for module in self.children():
module._apply(fn)
def maybe_cast(t):
if t is not None and t.device != fn(t).device:
return fn(t)
return t
self.weight = maybe_cast(self.weight)
self.weight_scale = maybe_cast(self.weight_scale)
self.bias = maybe_cast(self.bias)
return self
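A minimal usage sketch of the new layer, assuming torchao is installed and a CUDA GPU is available; the shapes and the weight-quantization step are illustrative, not taken from a real checkpoint:

```python
import torch
from lightx2v.models.input_encoders.hf.q_linear import TorchaoQuantLinearInt8

# Build the layer and fill its int8 weight / per-channel scale buffers from a
# random weight (illustrative; real checkpoints ship pre-quantized tensors).
layer = TorchaoQuantLinearInt8(in_features=4096, out_features=1024, bias=True)
w = torch.randn(1024, 4096)
scale = w.abs().amax(dim=1, keepdim=True) / 127.0
layer.weight.copy_(torch.clamp(torch.round(w / scale), -128, 127).to(torch.int8))
layer.weight_scale.copy_(scale)
layer.bias.zero_()
layer = layer.to("cuda")  # the torchao int8 matmul is intended to run on GPU

x = torch.randn(1, 8, 4096, dtype=torch.bfloat16, device="cuda")  # (batch, seq, hidden)
y = layer(x)  # forward squeezes the leading dim, runs the quantized matmul, restores it
print(y.shape, y.dtype)  # torch.Size([1, 8, 1024]) torch.bfloat16
```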
......@@ -9,7 +9,7 @@ import torch.nn.functional as F
from .tokenizer import HuggingfaceTokenizer
from loguru import logger
from lightx2v.models.input_encoders.hf.q_linear import QuantLinearInt8, QuantLinearFp8
from lightx2v.models.input_encoders.hf.q_linear import VllmQuantLinearInt8, VllmQuantLinearFp8, TorchaoQuantLinearInt8
__all__ = [
......@@ -83,9 +83,11 @@ class T5Attention(nn.Module):
if quantized:
if quant_scheme == "int8":
linear_cls = QuantLinearInt8
linear_cls = VllmQuantLinearInt8
elif quant_scheme == "fp8":
linear_cls = QuantLinearFp8
linear_cls = VllmQuantLinearFp8
elif quant_scheme == "int8-torchao":
linear_cls = TorchaoQuantLinearInt8
else:
linear_cls = nn.Linear
......@@ -144,9 +146,11 @@ class T5FeedForward(nn.Module):
if quantized:
if quant_scheme == "int8":
linear_cls = QuantLinearInt8
linear_cls = VllmQuantLinearInt8
elif quant_scheme == "fp8":
linear_cls = QuantLinearFp8
linear_cls = VllmQuantLinearFp8
elif quant_scheme == "int8-torchao":
linear_cls = TorchaoQuantLinearInt8
else:
linear_cls = nn.Linear
# layers
......
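The same scheme-to-class dispatch now appears in several encoder modules; a helper equivalent to those branches might look like the sketch below (pick_linear_cls is not a function in the repo, just an illustration of the mapping):

```python
import torch.nn as nn
from lightx2v.models.input_encoders.hf.q_linear import (
    VllmQuantLinearInt8,
    VllmQuantLinearFp8,
    TorchaoQuantLinearInt8,
)

def pick_linear_cls(quantized, quant_scheme=None):
    # Illustrative consolidation of the branches above; in this sketch an
    # unrecognized scheme falls back to a plain nn.Linear.
    if not quantized:
        return nn.Linear
    return {
        "int8": VllmQuantLinearInt8,
        "fp8": VllmQuantLinearFp8,
        "int8-torchao": TorchaoQuantLinearInt8,
    }.get(quant_scheme, nn.Linear)
```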
......@@ -10,7 +10,7 @@ import torchvision.transforms as T
from lightx2v.attentions import attention
from loguru import logger
from lightx2v.models.input_encoders.hf.q_linear import QuantLinearInt8, QuantLinearFp8
from lightx2v.models.input_encoders.hf.q_linear import VllmQuantLinearInt8, VllmQuantLinearFp8, TorchaoQuantLinearInt8
from einops import rearrange
from torch import Tensor
from transformers import CLIPVisionModel
......@@ -63,9 +63,11 @@ class SelfAttention(nn.Module):
# layers
if quantized:
if quant_scheme == "int8":
linear_cls = QuantLinearInt8
linear_cls = VllmQuantLinearInt8
elif quant_scheme == "fp8":
linear_cls = QuantLinearFp8
linear_cls = VllmQuantLinearFp8
elif quant_scheme == "int8-torchao":
linear_cls = TorchaoQuantLinearInt8
else:
linear_cls = nn.Linear
......@@ -135,9 +137,11 @@ class AttentionBlock(nn.Module):
# layers
if quantized:
if quant_scheme == "int8":
linear_cls = QuantLinearInt8
linear_cls = VllmQuantLinearInt8
elif quant_scheme == "fp8":
linear_cls = QuantLinearFp8
linear_cls = VllmQuantLinearFp8
elif quant_scheme == "int8-torchao":
linear_cls = TorchaoQuantLinearInt8
else:
linear_cls = nn.Linear
......
......@@ -11,6 +11,7 @@ from lightx2v.models.networks.wan.weights.transformer_weights import (
WanTransformerWeights,
)
from lightx2v.utils.envs import *
from loguru import logger
class WanDistillModel(WanModel):
......@@ -31,8 +32,11 @@ class WanDistillModel(WanModel):
return weight_dict
ckpt_path = os.path.join(self.model_path, f"{ckpt_folder}/distill_model.pt")
if os.path.exists(ckpt_path):
logger.info(f"Loading weights from {ckpt_path}")
weight_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
logger.debug(f"Distill checkpoint keys: {list(weight_dict.keys())}")
weight_dict = {
key: (weight_dict[key].to(torch.bfloat16) if use_bf16 or all(s not in key for s in skip_bf16) else weight_dict[key]).pin_memory().to(self.device) for key in weight_dict.keys()
}
......
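Spelled out, the dict comprehension above does the following; a standalone sketch of the same cast-pin-move logic, with names mirroring the variables in the hunk:

```python
import torch

def cast_pin_and_move(weight_dict, skip_bf16, use_bf16, device):
    # Cast each tensor to bf16 unless bf16 is disabled and the key matches a
    # skip pattern, then pin it in host memory and move it to the target device.
    out = {}
    for key, tensor in weight_dict.items():
        if use_bf16 or all(s not in key for s in skip_bf16):
            tensor = tensor.to(torch.bfloat16)
        out[key] = tensor.pin_memory().to(device)
    return out
```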
......@@ -57,17 +57,19 @@ class WanRunner(DefaultRunner):
if clip_quantized:
clip_quant_scheme = self.config.get("clip_quant_scheme", None)
assert clip_quant_scheme is not None
tmp_clip_quant_scheme = clip_quant_scheme.split("-")[0]
clip_quantized_ckpt = self.config.get(
"clip_quantized_ckpt",
os.path.join(
os.path.join(self.config.model_path, clip_quant_scheme),
f"clip-{clip_quant_scheme}.pth",
os.path.join(self.config.model_path, tmp_clip_quant_scheme),
f"clip-{tmp_clip_quant_scheme}.pth",
),
)
else:
clip_quantized_ckpt = None
clip_quant_scheme = None
print(f"clip_quant_scheme: {clip_quant_scheme}")
image_encoder = CLIPModel(
dtype=torch.float16,
device=self.init_device,
......@@ -93,18 +95,19 @@ class WanRunner(DefaultRunner):
t5_quantized = self.config.get("t5_quantized", False)
if t5_quantized:
t5_quant_scheme = self.config.get("t5_quant_scheme", None)
assert t5_quant_scheme is not None
tmp_t5_quant_scheme = t5_quant_scheme.split("-")[0]
t5_quantized_ckpt = self.config.get(
"t5_quantized_ckpt",
os.path.join(
os.path.join(self.config.model_path, t5_quant_scheme),
f"models_t5_umt5-xxl-enc-{t5_quant_scheme}.pth",
os.path.join(self.config.model_path, tmp_t5_quant_scheme),
f"models_t5_umt5-xxl-enc-{tmp_t5_quant_scheme}.pth",
),
)
else:
t5_quant_scheme = None
t5_quantized_ckpt = None
print(f"t5_quant_scheme: {t5_quant_scheme}")
text_encoder = T5EncoderModel(
text_len=self.config["text_len"],
dtype=torch.bfloat16,
......
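With the new `int8-torchao` scheme, only the backend-free prefix is used to build the default checkpoint paths; a sketch of the resulting locations (the model path is hypothetical):

```python
import os

model_path = "/models/Wan2.1-I2V-14B-480P"   # hypothetical
quant_scheme = "int8-torchao"
prefix = quant_scheme.split("-")[0]          # "int8"
for name, stem in [("clip", "clip"), ("t5", "models_t5_umt5-xxl-enc")]:
    ckpt = os.path.join(model_path, prefix, f"{stem}-{prefix}.pth")
    print(name, ckpt)
# clip /models/Wan2.1-I2V-14B-480P/int8/clip-int8.pth
# t5   /models/Wan2.1-I2V-14B-480P/int8/models_t5_umt5-xxl-enc-int8.pth
```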