Commit 7f2fbd00 authored by gushiqiao

Add torchao kernel

parent 29a90944
{
    "infer_steps": 4,
    "target_video_length": 81,
    "target_height": 480,
    "target_width": 832,
    "self_attn_1_type": "sage_attn2",
    "cross_attn_1_type": "sage_attn2",
    "cross_attn_2_type": "sage_attn2",
    "seed": 42,
    "sample_guide_scale": 5,
    "sample_shift": 5,
    "enable_cfg": false,
    "cpu_offload": false,
    "mm_config": {
        "mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Torchao"
    },
    "t5_quantized": true,
    "t5_quant_scheme": "int8-torchao",
    "clip_quantized": true,
    "clip_quant_scheme": "int8-torchao"
}
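This config drives the 4-step distill pipeline through the new Torchao int8 path: `mm_type` selects the quantized-matmul class out of the registry populated by `@MM_WEIGHT_REGISTER` (the kernel added below), and setting `t5_quant_scheme`/`clip_quant_scheme` to `int8-torchao` routes the text and image encoders to `TorchaoQuantLinearInt8`. A minimal sketch of such a registry lookup, with illustrative helper names only (`build_mm_weight` is not the project's actual loader):

```python
# Sketch only: how an mm_type string could map to a registered kernel class.
import json

MM_WEIGHT_CLASSES = {}

def MM_WEIGHT_REGISTER(name):
    # Decorator that records a kernel class under its mm_type string.
    def _wrap(cls):
        MM_WEIGHT_CLASSES[name] = cls
        return cls
    return _wrap

def build_mm_weight(config_path, weight_name, bias_name):
    # Hypothetical helper: read the config and instantiate the matching kernel.
    with open(config_path) as f:
        cfg = json.load(f)
    mm_type = cfg["mm_config"]["mm_type"]  # e.g. "W-int8-...-dynamic-Torchao"
    return MM_WEIGHT_CLASSES[mm_type](weight_name, bias_name)
```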
...@@ -19,22 +19,6 @@ This document provides detailed instructions for deploying LightX2V locally on W
- **CUDA**: 12.4 or higher
- **Dependencies**: refer to the LightX2V project's requirements_win.txt
### Installation Steps
1. **Clone Project**
```cmd
git clone https://github.com/ModelTC/LightX2V.git
cd LightX2V
```
2. **Install Dependencies**
```cmd
pip install -r requirements_win.txt
```
3. **Download Models**
Refer to [Model Download Guide](../getting_started/quickstart.md) to download required models
## 🎯 Usage Methods
### Method 1: Using Batch File Inference
...@@ -113,42 +97,4 @@ Double-click to run the `start_lightx2v.bat` file, the script will:
### Method 3: Using ComfyUI Inference
This guide explains how to download and use the portable Lightx2v-ComfyUI environment so that you can skip manual environment configuration; it is intended for users who want to quickly start experiencing accelerated video generation with Lightx2v on Windows. TODO: ComfyUI integration guide to be added.
#### Download the Windows Portable Environment:
- [Baidu Cloud Download](https://pan.baidu.com/s/1FVlicTXjmXJA1tAVvNCrBw?pwd=wfid), extraction code: wfid
The portable environment already packages all Python runtime dependencies, including the code and dependencies for ComfyUI and LightX2V. After downloading, simply extract to use.
After extraction, the directory structure is as follows:
```shell
lightx2v_env
├──📂 ComfyUI # ComfyUI code
├──📂 portable_python312_embed # Standalone Python environment
└── run_nvidia_gpu.bat # Windows startup script (double-click to start)
```
#### Start ComfyUI
Directly double-click the run_nvidia_gpu.bat file. The system will open a Command Prompt window and run the program. The first startup may take a while, please be patient. After startup is complete, the browser will automatically open and display the ComfyUI frontend interface.
![i2v example workflow](../../../../assets/figs/portabl_windows/pic1.png)
The plugin used by LightX2V-ComfyUI is [ComfyUI-Lightx2vWrapper](https://github.com/ModelTC/ComfyUI-Lightx2vWrapper). Example workflows can be obtained from this project.
#### Tested Graphics Cards (offload mode)
- Tested model: `Wan2.1-I2V-14B-480P`
| GPU Model | Task Type | VRAM Capacity | Actual Max VRAM Usage | Actual Max RAM Usage |
|:-----------|:------------|:--------------|:---------------------|:---------------------|
| 3090Ti | I2V | 24G | 6.1G | 7.1G |
| 3080Ti | I2V | 12G | 6.1G | 7.1G |
| 3060Ti | I2V | 8G | 6.1G | 7.1G |
#### Environment Packaging and Usage Reference
- [ComfyUI](https://github.com/comfyanonymous/ComfyUI)
- [Portable-Windows-ComfyUI-Docs](https://docs.comfy.org/zh-CN/installation/comfyui_portable_windows#portable-%E5%8F%8A%E8%87%AA%E9%83%A8%E7%BD%B2)
...@@ -25,6 +25,11 @@ try:
except ImportError:
    deep_gemm = None
try:
    from torchao.quantization.utils import quant_int8_per_token_matmul, quantize_activation_per_token_absmax
except ModuleNotFoundError:
    quant_int8_per_token_matmul, quantize_activation_per_token_absmax = None, None
class MMWeightTemplate(metaclass=ABCMeta):
    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
...@@ -232,6 +237,9 @@ class MMWeightQuantTemplate(MMWeightTemplate):
    # =========================
    # act quant kernels
    # =========================
    def act_quant_int8_perchannel_sym_torchao(self, x):
        input_tensor_quant, input_tensor_scale = quantize_activation_per_token_absmax(x)
        return input_tensor_quant, input_tensor_scale
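Despite the `perchannel` in the method name (matching the existing vLLM hooks in this file), `quantize_activation_per_token_absmax` performs per-token symmetric int8 quantization: each token row is scaled by its absolute maximum so it fits in [-127, 127]. A rough plain-PyTorch equivalent for intuition (a sketch, not a drop-in replacement for the torchao helper):

```python
import torch

def per_token_absmax_int8_quant(x: torch.Tensor):
    # One scale per token (reduce over the feature dim): scale = absmax / 127.
    absmax = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-5)
    scale = absmax / 127.0
    x_q = torch.clamp(torch.round(x / scale), -127, 127).to(torch.int8)
    return x_q, scale  # x is approximately x_q.float() * scale
```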
    def act_quant_fp8_perchannel_sym_vllm(self, x):
        input_tensor_quant, input_tensor_scale = ops.scaled_fp8_quant(x, None, scale_ub=None, use_per_token_if_dynamic=True)
...@@ -624,6 +632,33 @@ class MMWeightWint8channelAint8channeldynamicSglActVllm(MMWeightQuantTemplate):
        return output_tensor
@MM_WEIGHT_REGISTER("W-int8-channel-sym-A-int8-channel-sym-dynamic-Torchao")
class MMWeightWint8channelAint8channeldynamicTorchao(MMWeightQuantTemplate):
    """
    Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Torchao
    Quant MM:
        Weight: int8 perchannel sym
        Act: int8 perchannel dynamic sym
        Kernel: Torchao
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_int8_perchannel_sym
        self.weight_need_transpose = True
        self.act_quant_func = self.act_quant_int8_perchannel_sym_torchao

    def apply(self, input_tensor):
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        output_tensor = quant_int8_per_token_matmul(input_tensor_quant, input_tensor_scale, self.weight, self.weight_scale.t().float(), output_dtype=torch.bfloat16)
        if self.bias is not None:
            output_tensor = output_tensor + self.bias
        return output_tensor
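For reference, `quant_int8_per_token_matmul` is an int8 GEMM followed by rescaling with the per-token activation scales and the per-channel weight scales; since `weight_need_transpose = True`, `self.weight` is stored already transposed for that call. A rescaling sketch of the same contraction (illustrative only; the torchao kernel fuses these steps and uses a real int8 GEMM rather than a cast-to-int32 matmul):

```python
import torch

def int8_per_token_matmul_reference(x_q, x_scale, w_q_t, w_scale_row, output_dtype=torch.bfloat16):
    # Assumed shapes for this sketch:
    #   x_q: (tokens, in) int8, x_scale: (tokens, 1) float
    #   w_q_t: (in, out) int8, w_scale_row: (1, out) float
    acc = x_q.to(torch.int32) @ w_q_t.to(torch.int32)  # integer accumulation
    return (acc.float() * x_scale * w_scale_row).to(output_dtype)
```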
if __name__ == "__main__":
    weight_dict = {
        "xx.weight": torch.randn(8192, 4096).to(torch.float8_e4m3fn),
...
import torch
import torch.nn as nn
try:
    from vllm import _custom_ops as ops
except ModuleNotFoundError:
    ops = None

try:
    from torchao.quantization.utils import quant_int8_per_token_matmul, quantize_activation_per_token_absmax
except ModuleNotFoundError:
    quant_int8_per_token_matmul, quantize_activation_per_token_absmax = None, None


class VllmQuantLinearInt8(nn.Module):  # renamed from QuantLinearInt8
    def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16):
        super().__init__()
        self.in_features = in_features
...@@ -54,7 +63,7 @@ class QuantLinearInt8(nn.Module):
        return self
class VllmQuantLinearFp8(nn.Module):  # renamed from QuantLinearFp8
    def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16):
        super().__init__()
        self.in_features = in_features
...@@ -101,3 +110,45 @@ class QuantLinearFp8(nn.Module):
        self.weight_scale = maybe_cast(self.weight_scale)
        self.bias = maybe_cast(self.bias)
        return self
class TorchaoQuantLinearInt8(nn.Module):
    def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.int8))
        self.register_buffer("weight_scale", torch.empty((out_features, 1), dtype=torch.float32))
        if bias:
            self.register_buffer("bias", torch.empty(out_features, dtype=dtype))
        else:
            self.register_buffer("bias", None)

    def act_quant_func(self, x):
        input_tensor_quant, input_tensor_scale = quantize_activation_per_token_absmax(x)
        return input_tensor_quant, input_tensor_scale

    def forward(self, input_tensor):
        input_tensor = input_tensor.squeeze(0)
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        output_tensor = quant_int8_per_token_matmul(input_tensor_quant, input_tensor_scale, self.weight.t(), self.weight_scale.t().float(), output_dtype=torch.bfloat16)
        if self.bias is not None:
            output_tensor = output_tensor + self.bias
        return output_tensor.unsqueeze(0)

    def _apply(self, fn):
        for module in self.children():
            module._apply(fn)

        def maybe_cast(t):
            if t is not None and t.device != fn(t).device:
                return fn(t)
            return t

        self.weight = maybe_cast(self.weight)
        self.weight_scale = maybe_cast(self.weight_scale)
        self.bias = maybe_cast(self.bias)
        return self
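A hypothetical usage sketch for `TorchaoQuantLinearInt8` (shapes and the way the buffers are filled are illustrative assumptions; in LightX2V the buffers are populated from the pre-quantized T5/CLIP checkpoints):

```python
import torch

# Assumes torchao is installed so quantize_activation_per_token_absmax is available.
layer = TorchaoQuantLinearInt8(in_features=4096, out_features=8192, bias=False).cuda()
layer.weight.copy_(torch.randint(-127, 128, (8192, 4096), dtype=torch.int8))  # fake int8 weights
layer.weight_scale.copy_(torch.rand(8192, 1))                                 # fake per-channel scales

x = torch.randn(1, 512, 4096, device="cuda", dtype=torch.bfloat16)
with torch.no_grad():
    y = layer(x)   # forward squeezes/unsqueezes the leading batch dim
print(y.shape)     # torch.Size([1, 512, 8192])
```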
...@@ -9,7 +9,7 @@ import torch.nn.functional as F
from .tokenizer import HuggingfaceTokenizer
from loguru import logger
from lightx2v.models.input_encoders.hf.q_linear import VllmQuantLinearInt8, VllmQuantLinearFp8, TorchaoQuantLinearInt8
__all__ = [
...@@ -83,9 +83,11 @@ class T5Attention(nn.Module):
if quantized:
if quant_scheme == "int8":
linear_cls = VllmQuantLinearInt8
elif quant_scheme == "fp8":
linear_cls = VllmQuantLinearFp8
elif quant_scheme == "int8-torchao":
linear_cls = TorchaoQuantLinearInt8
else:
linear_cls = nn.Linear
...@@ -144,9 +146,11 @@ class T5FeedForward(nn.Module):
if quantized:
if quant_scheme == "int8":
linear_cls = VllmQuantLinearInt8
elif quant_scheme == "fp8":
linear_cls = VllmQuantLinearFp8
elif quant_scheme == "int8-torchao":
linear_cls = TorchaoQuantLinearInt8
else:
linear_cls = nn.Linear
# layers
...
...@@ -10,7 +10,7 @@ import torchvision.transforms as T
from lightx2v.attentions import attention
from loguru import logger
from lightx2v.models.input_encoders.hf.q_linear import VllmQuantLinearInt8, VllmQuantLinearFp8, TorchaoQuantLinearInt8
from einops import rearrange
from torch import Tensor
from transformers import CLIPVisionModel
...@@ -63,9 +63,11 @@ class SelfAttention(nn.Module):
# layers
if quantized:
if quant_scheme == "int8":
linear_cls = VllmQuantLinearInt8
elif quant_scheme == "fp8":
linear_cls = VllmQuantLinearFp8
elif quant_scheme == "int8-torchao":
linear_cls = TorchaoQuantLinearInt8
else:
linear_cls = nn.Linear
...@@ -135,9 +137,11 @@ class AttentionBlock(nn.Module):
# layers
if quantized:
if quant_scheme == "int8":
linear_cls = VllmQuantLinearInt8
elif quant_scheme == "fp8":
linear_cls = VllmQuantLinearFp8
elif quant_scheme == "int8-torchao":
linear_cls = TorchaoQuantLinearInt8
else:
linear_cls = nn.Linear
...
...@@ -11,6 +11,7 @@ from lightx2v.models.networks.wan.weights.transformer_weights import (
    WanTransformerWeights,
)
from lightx2v.utils.envs import *
from loguru import logger
class WanDistillModel(WanModel):
...@@ -31,8 +32,11 @@ class WanDistillModel(WanModel):
return weight_dict
ckpt_path = os.path.join(self.model_path, f"{ckpt_folder}/distill_model.pt")
if os.path.exists(ckpt_path):
logger.info(f"Loading weights from {ckpt_path}")
weight_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
logger.debug(f"Loaded distill checkpoint keys: {list(weight_dict.keys())}")
weight_dict = {
key: (weight_dict[key].to(torch.bfloat16) if use_bf16 or all(s not in key for s in skip_bf16) else weight_dict[key]).pin_memory().to(self.device) for key in weight_dict.keys()
}
...
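The dict comprehension above casts each tensor to bfloat16 unless its key matches one of the `skip_bf16` patterns, pins it in host memory, and then moves it to the target device. Spelled out as a helper (a sketch; `use_bf16`, `skip_bf16`, and the device come from the surrounding method):

```python
import torch

def _prepare_weight(tensor: torch.Tensor, key: str, use_bf16, skip_bf16, device):
    # Cast to bf16 unless the key matches a pattern that must keep its original dtype.
    if use_bf16 or all(s not in key for s in skip_bf16):
        tensor = tensor.to(torch.bfloat16)
    # Pin host memory so the host-to-device copy can be asynchronous, then move.
    return tensor.pin_memory().to(device)
```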
...@@ -57,17 +57,19 @@ class WanRunner(DefaultRunner):
if clip_quantized:
clip_quant_scheme = self.config.get("clip_quant_scheme", None)
assert clip_quant_scheme is not None
tmp_clip_quant_scheme = clip_quant_scheme.split("-")[0]
clip_quantized_ckpt = self.config.get(
"clip_quantized_ckpt",
os.path.join(
os.path.join(self.config.model_path, tmp_clip_quant_scheme),
f"clip-{tmp_clip_quant_scheme}.pth",
),
)
else:
clip_quantized_ckpt = None
clip_quant_scheme = None
print(clip_quant_scheme)
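Because the Torchao path reuses the same int8 checkpoints as the vLLM path, only the prefix before the first dash of the scheme is used when building the default checkpoint location. A worked example of the path constructed above (the model path is an illustrative assumption; an explicit `clip_quantized_ckpt` in the config would override it):

```python
import os

model_path = "/models/Wan2.1-I2V-14B-480P"  # example only
clip_quant_scheme = "int8-torchao"
tmp_clip_quant_scheme = clip_quant_scheme.split("-")[0]  # -> "int8"
ckpt = os.path.join(os.path.join(model_path, tmp_clip_quant_scheme), f"clip-{tmp_clip_quant_scheme}.pth")
print(ckpt)  # /models/Wan2.1-I2V-14B-480P/int8/clip-int8.pth
```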
image_encoder = CLIPModel(
dtype=torch.float16,
device=self.init_device,
...@@ -93,18 +95,19 @@ class WanRunner(DefaultRunner):
t5_quantized = self.config.get("t5_quantized", False)
if t5_quantized:
t5_quant_scheme = self.config.get("t5_quant_scheme", None)
assert t5_quant_scheme is not None
tmp_t5_quant_scheme = t5_quant_scheme.split("-")[0]
t5_quantized_ckpt = self.config.get(
"t5_quantized_ckpt",
os.path.join(
os.path.join(self.config.model_path, tmp_t5_quant_scheme),
f"models_t5_umt5-xxl-enc-{tmp_t5_quant_scheme}.pth",
),
)
else:
t5_quant_scheme = None
t5_quantized_ckpt = None
print(t5_quant_scheme)
text_encoder = T5EncoderModel(
text_len=self.config["text_len"],
dtype=torch.bfloat16,
...