"vscode:/vscode.git/clone" did not exist on "550b6a306564a06b821d76bebe6768e3c707c766"
Commit cc2a283a authored by gushiqiao, committed by GitHub

Add torchao kernel

Dev quant
parents 29a90944 adf8df9d
{
"infer_steps": 4,
"target_video_length": 81,
"target_height": 480,
"target_width": 832,
"self_attn_1_type": "sage_attn2",
"cross_attn_1_type": "sage_attn2",
"cross_attn_2_type": "sage_attn2",
"seed": 42,
"sample_guide_scale": 5,
"sample_shift": 5,
"enable_cfg": false,
"cpu_offload": false,
"mm_config": {
"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Torchao"
},
"t5_quantized": true,
"t5_quant_scheme": "int8-torchao",
"clip_quantized": true,
"clip_quant_scheme": "int8-torchao"
}
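For orientation, the `mm_type` string above names the quantization scheme rather than a specific kernel: weights are quantized to int8 with one symmetric scale per output channel, and activations are quantized per token on the fly. Below is a minimal plain-PyTorch sketch of that arithmetic, independent of the torchao kernels this commit wires in; all shapes and tensor names are illustrative.

```python
import torch

# Illustrative shapes: (out_features, in_features) weight, (tokens, in_features) activation.
w = torch.randn(8192, 4096)
x = torch.randn(77, 4096)

# W-int8-channel-sym: one symmetric scale per output channel (per row of w).
w_scale = w.abs().amax(dim=1, keepdim=True) / 127.0           # (8192, 1)
w_int8 = (w / w_scale).round().clamp(-128, 127).to(torch.int8)

# A-int8-channel-sym-dynamic: one symmetric scale per token, computed at runtime.
x_scale = x.abs().amax(dim=1, keepdim=True) / 127.0           # (77, 1)
x_int8 = (x / x_scale).round().clamp(-128, 127).to(torch.int8)

# Integer matmul, then dequantize with both scales (float matmul here only keeps the sketch portable).
y = (x_int8.float() @ w_int8.float().t()) * x_scale * w_scale.t()
print(y.shape)                                                # torch.Size([77, 8192]); close to x @ w.t()
```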
......@@ -19,22 +19,6 @@ This document provides detailed instructions for deploying LightX2V locally on W
- **CUDA**: 12.4 or later
- **Dependencies**: Refer to the LightX2V project's requirements_win.txt
### Installation Steps
1. **Clone Project**
```cmd
git clone https://github.com/ModelTC/LightX2V.git
cd LightX2V
```
2. **Install Dependencies**
```cmd
pip install -r requirements_win.txt
```
3. **Download Models**
Refer to the [Model Download Guide](../getting_started/quickstart.md) to download the required models.
## 🎯 Usage Methods
### Method 1: Using Batch File Inference
......
......@@ -25,6 +25,11 @@ try:
except ImportError:
deep_gemm = None
try:
from torchao.quantization.utils import quant_int8_per_token_matmul, quantize_activation_per_token_absmax
except ModuleNotFoundError:
quant_int8_per_token_matmul, quantize_activation_per_token_absmax = None, None
class MMWeightTemplate(metaclass=ABCMeta):
def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
......@@ -232,6 +237,9 @@ class MMWeightQuantTemplate(MMWeightTemplate):
# =========================
# act quant kernels
# =========================
def act_quant_int8_perchannel_sym_torchao(self, x):
input_tensor_quant, input_tensor_scale = quantize_activation_per_token_absmax(x)
return input_tensor_quant, input_tensor_scale
def act_quant_fp8_perchannel_sym_vllm(self, x):
input_tensor_quant, input_tensor_scale = ops.scaled_fp8_quant(x, None, scale_ub=None, use_per_token_if_dynamic=True)
......@@ -624,6 +632,33 @@ class MMWeightWint8channelAint8channeldynamicSglActVllm(MMWeightQuantTemplate):
return output_tensor
@MM_WEIGHT_REGISTER("W-int8-channel-sym-A-int8-channel-sym-dynamic-Torchao")
class MMWeightWint8channelAint8channeldynamicTorchao(MMWeightQuantTemplate):
"""
Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Torchao
Quant MM:
Weight: int8 perchannel sym
Act: int8 perchannel dynamic sym
Kernel: Torchao
"""
def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
self.load_func = self.load_int8_perchannel_sym
self.weight_need_transpose = True
self.act_quant_func = self.act_quant_int8_perchannel_sym_torchao
def apply(self, input_tensor):
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
output_tensor = quant_int8_per_token_matmul(input_tensor_quant, input_tensor_scale, self.weight, self.weight_scale.t().float(), output_dtype=torch.bfloat16)
if self.bias is not None:
output_tensor = output_tensor + self.bias
return output_tensor
if __name__ == "__main__":
weight_dict = {
"xx.weight": torch.randn(8192, 4096).to(torch.float8_e4m3fn),
......
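Reviewer note: the new `apply` above delegates both steps to torchao. A standalone sketch of the same call pattern, assuming torchao is installed and exposes the two helpers imported at the top of this file, and that a CUDA device with int8 matmul support is available; the weight is already laid out as `(in_features, out_features)` because `weight_need_transpose = True`.

```python
import torch
from torchao.quantization.utils import (
    quant_int8_per_token_matmul,
    quantize_activation_per_token_absmax,
)

# Synthetic int8 weight in the layout the loader leaves behind: (in_features, out_features),
# with one float scale per output channel.
w_int8_t = torch.randint(-128, 128, (4096, 8192), dtype=torch.int8, device="cuda")
w_scale = torch.rand(1, 8192, device="cuda") * 0.01

x = torch.randn(77, 4096, device="cuda", dtype=torch.bfloat16)

# Per-token dynamic activation quantization, then int8 matmul + dequantization.
x_int8, x_scale = quantize_activation_per_token_absmax(x)
y = quant_int8_per_token_matmul(x_int8, x_scale, w_int8_t, w_scale.float(), output_dtype=torch.bfloat16)
print(y.shape, y.dtype)  # torch.Size([77, 8192]) torch.bfloat16
```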
import torch
import torch.nn as nn
from vllm import _custom_ops as ops
try:
from vllm import _custom_ops as ops
except ModuleNotFoundError:
ops = None
class QuantLinearInt8(nn.Module):
try:
from torchao.quantization.utils import quant_int8_per_token_matmul, quantize_activation_per_token_absmax
except ModuleNotFoundError:
quant_int8_per_token_matmul, quantize_activation_per_token_absmax = None, None
class VllmQuantLinearInt8(nn.Module):
def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16):
super().__init__()
self.in_features = in_features
......@@ -54,7 +63,7 @@ class QuantLinearInt8(nn.Module):
return self
class QuantLinearFp8(nn.Module):
class VllmQuantLinearFp8(nn.Module):
def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16):
super().__init__()
self.in_features = in_features
......@@ -101,3 +110,45 @@ class QuantLinearFp8(nn.Module):
self.weight_scale = maybe_cast(self.weight_scale)
self.bias = maybe_cast(self.bias)
return self
class TorchaoQuantLinearInt8(nn.Module):
def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.int8))
self.register_buffer("weight_scale", torch.empty((out_features, 1), dtype=torch.float32))
if bias:
self.register_buffer("bias", torch.empty(out_features, dtype=dtype))
else:
self.register_buffer("bias", None)
def act_quant_func(self, x):
input_tensor_quant, input_tensor_scale = quantize_activation_per_token_absmax(x)
return input_tensor_quant, input_tensor_scale
def forward(self, input_tensor):
input_tensor = input_tensor.squeeze(0)
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
output_tensor = quant_int8_per_token_matmul(input_tensor_quant, input_tensor_scale, self.weight.t(), self.weight_scale.t().float(), output_dtype=torch.bfloat16)
if self.bias is not None:
output_tensor = output_tensor + self.bias
return output_tensor.unsqueeze(0)
def _apply(self, fn):
for module in self.children():
module._apply(fn)
def maybe_cast(t):
if t is not None and t.device != fn(t).device:
return fn(t)
return t
self.weight = maybe_cast(self.weight)
self.weight_scale = maybe_cast(self.weight_scale)
self.bias = maybe_cast(self.bias)
return self
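A hypothetical usage sketch for the new `TorchaoQuantLinearInt8`; in the encoders below the int8 weight and per-channel scale come from the offline-quantized T5/CLIP checkpoints, so the synthetic quantization here is only for illustration.

```python
import torch

lin = TorchaoQuantLinearInt8(in_features=4096, out_features=8192, bias=True).cuda()

# Fill the buffers with a synthetic per-channel symmetric int8 quantization of a random weight.
w = torch.randn(8192, 4096, device="cuda")
scale = w.abs().amax(dim=1, keepdim=True) / 127.0                  # (out_features, 1), float32
lin.weight.copy_((w / scale).round().clamp(-128, 127).to(torch.int8))
lin.weight_scale.copy_(scale)
lin.bias.zero_()

x = torch.randn(1, 77, 4096, device="cuda", dtype=torch.bfloat16)  # (batch=1, tokens, in_features)
y = lin(x)                                                         # forward squeezes/restores the batch dim
print(y.shape, y.dtype)                                            # torch.Size([1, 77, 8192]) torch.bfloat16
```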
......@@ -9,7 +9,7 @@ import torch.nn.functional as F
from .tokenizer import HuggingfaceTokenizer
from loguru import logger
from lightx2v.models.input_encoders.hf.q_linear import QuantLinearInt8, QuantLinearFp8
from lightx2v.models.input_encoders.hf.q_linear import VllmQuantLinearInt8, VllmQuantLinearFp8, TorchaoQuantLinearInt8
__all__ = [
......@@ -83,9 +83,11 @@ class T5Attention(nn.Module):
if quantized:
if quant_scheme == "int8":
linear_cls = QuantLinearInt8
linear_cls = VllmQuantLinearInt8
elif quant_scheme == "fp8":
linear_cls = QuantLinearFp8
linear_cls = VllmQuantLinearFp8
elif quant_scheme == "int8-torchao":
linear_cls = TorchaoQuantLinearInt8
else:
linear_cls = nn.Linear
......@@ -144,9 +146,11 @@ class T5FeedForward(nn.Module):
if quantized:
if quant_scheme == "int8":
linear_cls = QuantLinearInt8
linear_cls = VllmQuantLinearInt8
elif quant_scheme == "fp8":
linear_cls = QuantLinearFp8
linear_cls = VllmQuantLinearFp8
elif quant_scheme == "int8-torchao":
linear_cls = TorchaoQuantLinearInt8
else:
linear_cls = nn.Linear
# layers
......
......@@ -10,7 +10,7 @@ import torchvision.transforms as T
from lightx2v.attentions import attention
from loguru import logger
from lightx2v.models.input_encoders.hf.q_linear import QuantLinearInt8, QuantLinearFp8
from lightx2v.models.input_encoders.hf.q_linear import VllmQuantLinearInt8, VllmQuantLinearFp8, TorchaoQuantLinearInt8
from einops import rearrange
from torch import Tensor
from transformers import CLIPVisionModel
......@@ -63,9 +63,11 @@ class SelfAttention(nn.Module):
# layers
if quantized:
if quant_scheme == "int8":
linear_cls = QuantLinearInt8
linear_cls = VllmQuantLinearInt8
elif quant_scheme == "fp8":
linear_cls = QuantLinearFp8
linear_cls = VllmQuantLinearFp8
elif quant_scheme == "int8-torchao":
linear_cls = TorchaoQuantLinearInt8
else:
linear_cls = nn.Linear
......@@ -135,9 +137,11 @@ class AttentionBlock(nn.Module):
# layers
if quantized:
if quant_scheme == "int8":
linear_cls = QuantLinearInt8
linear_cls = VllmQuantLinearInt8
elif quant_scheme == "fp8":
linear_cls = QuantLinearFp8
linear_cls = VllmQuantLinearFp8
elif quant_scheme == "int8-torchao":
linear_cls = TorchaoQuantLinearInt8
else:
linear_cls = nn.Linear
......
......@@ -11,6 +11,7 @@ from lightx2v.models.networks.wan.weights.transformer_weights import (
WanTransformerWeights,
)
from lightx2v.utils.envs import *
from loguru import logger
class WanDistillModel(WanModel):
......@@ -31,8 +32,11 @@ class WanDistillModel(WanModel):
return weight_dict
ckpt_path = os.path.join(self.model_path, f"{ckpt_folder}/distill_model.pt")
if os.path.exists(ckpt_path):
logger.info(f"Loading weights from {ckpt_path}")
weight_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
logger.debug(f"Loaded distill checkpoint keys: {list(weight_dict.keys())}")
weight_dict = {
key: (weight_dict[key].to(torch.bfloat16) if use_bf16 or all(s not in key for s in skip_bf16) else weight_dict[key]).pin_memory().to(self.device) for key in weight_dict.keys()
}
......
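The one-line dict comprehension above packs the dtype rule, pinning, and device move together; unrolled, it reads as below. The contents of `skip_bf16` are not visible in this hunk, so the substrings here are purely hypothetical.

```python
import torch

use_bf16 = False
skip_bf16 = ("norm", "head")  # hypothetical; the real skip list is defined earlier in this method

def cast_for_load(key: str, tensor: torch.Tensor, device: str = "cuda") -> torch.Tensor:
    # Cast to bfloat16 unless the key contains a skip substring
    # (use_bf16=True forces the cast for every tensor), then pin and move to the target device.
    if use_bf16 or all(s not in key for s in skip_bf16):
        tensor = tensor.to(torch.bfloat16)
    return tensor.pin_memory().to(device)
```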
......@@ -57,17 +57,19 @@ class WanRunner(DefaultRunner):
if clip_quantized:
clip_quant_scheme = self.config.get("clip_quant_scheme", None)
assert clip_quant_scheme is not None
tmp_clip_quant_scheme = clip_quant_scheme.split("-")[0]
clip_quantized_ckpt = self.config.get(
"clip_quantized_ckpt",
os.path.join(
os.path.join(self.config.model_path, clip_quant_scheme),
f"clip-{clip_quant_scheme}.pth",
os.path.join(self.config.model_path, tmp_clip_quant_scheme),
f"clip-{tmp_clip_quant_scheme}.pth",
),
)
else:
clip_quantized_ckpt = None
clip_quant_scheme = None
image_encoder = CLIPModel(
dtype=torch.float16,
device=self.init_device,
......@@ -93,18 +95,19 @@ class WanRunner(DefaultRunner):
t5_quantized = self.config.get("t5_quantized", False)
if t5_quantized:
t5_quant_scheme = self.config.get("t5_quant_scheme", None)
assert t5_quant_scheme is not None
tmp_t5_quant_scheme = t5_quant_scheme.split("-")[0]
t5_quantized_ckpt = self.config.get(
"t5_quantized_ckpt",
os.path.join(
os.path.join(self.config.model_path, t5_quant_scheme),
f"models_t5_umt5-xxl-enc-{t5_quant_scheme}.pth",
os.path.join(self.config.model_path, tmp_t5_quant_scheme),
f"models_t5_umt5-xxl-enc-{tmp_t5_quant_scheme}.pth",
),
)
else:
t5_quant_scheme = None
t5_quantized_ckpt = None
text_encoder = T5EncoderModel(
text_len=self.config["text_len"],
dtype=torch.bfloat16,
......
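For the new "int8-torchao" schemes, the `split("-")[0]` above means the default checkpoint paths still point at the plain `int8` directory and filenames; both can be overridden via `clip_quantized_ckpt` / `t5_quantized_ckpt` in the config. A small sketch of the derivation (the model path is illustrative):

```python
import os

model_path = "/path/to/Wan2.1-I2V-14B-480P"         # illustrative
clip_quant_scheme = "int8-torchao"
t5_quant_scheme = "int8-torchao"

tmp_clip = clip_quant_scheme.split("-")[0]          # "int8"
tmp_t5 = t5_quant_scheme.split("-")[0]              # "int8"

clip_ckpt = os.path.join(model_path, tmp_clip, f"clip-{tmp_clip}.pth")
t5_ckpt = os.path.join(model_path, tmp_t5, f"models_t5_umt5-xxl-enc-{tmp_t5}.pth")
# -> /path/to/Wan2.1-I2V-14B-480P/int8/clip-int8.pth
# -> /path/to/Wan2.1-I2V-14B-480P/int8/models_t5_umt5-xxl-enc-int8.pth
```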