Commit b959bfd9 authored by gushiqiao

Update q8-kernel

parent 3e4fe79b
......@@ -140,6 +140,18 @@ def is_fp8_supported_gpu():
    return (major == 8 and minor == 9) or (major >= 9)


def is_ada_architecture_gpu():
    # Detect Ada Lovelace (RTX 40-series) GPUs by device name; used below to prefer the q8f kernels.
    if not torch.cuda.is_available():
        return False
    try:
        gpu_name = torch.cuda.get_device_name(0).upper()
        ada_keywords = ["RTX 40", "RTX40", "4090", "4080", "4070", "4060"]
        return any(keyword in gpu_name for keyword in ada_keywords)
    except Exception as e:
        logger.warning(f"Failed to get GPU name: {e}")
        return False
global_runner = None
current_config = None
cur_dit_quant_scheme = None
......@@ -506,7 +518,11 @@ def auto_configure(enable_auto_config, resolution):
quant_type = "int8"
attn_priority = ["sage_attn2", "flash_attn3", "flash_attn2", "torch_sdpa"]
quant_op_priority = ["sgl", "vllm", "q8f"]
if is_ada_architecture_gpu():
quant_op_priority = ["q8f", "vllm", "sgl"]
else:
quant_op_priority = ["sgl", "vllm", "q8f"]
for op in attn_priority:
if dict(available_attn_ops).get(op):
......@@ -736,6 +752,30 @@ def main():
.warning { color: #ff6b6b; font-weight: bold; }
.advanced-options { background: #f9f9ff; border-radius: 10px; padding: 15px; }
.tab-button { font-size: 16px; padding: 10px 20px; }
.auto-config-title {
background: linear-gradient(45deg, #ff6b6b, #4ecdc4);
background-clip: text;
-webkit-background-clip: text;
color: transparent;
text-align: center;
margin: 0 !important;
padding: 8px;
border: 2px solid #4ecdc4;
border-radius: 8px;
background-color: #f0f8ff;
}
.auto-config-checkbox {
border: 2px solid #ff6b6b !important;
border-radius: 8px !important;
padding: 10px !important;
background: linear-gradient(135deg, #fff5f5, #f0fff0) !important;
box-shadow: 0 2px 8px rgba(255, 107, 107, 0.2) !important;
}
.auto-config-checkbox label {
font-size: 16px !important;
font-weight: bold !important;
color: #2c3e50 !important;
}
""",
) as demo:
gr.Markdown(f"# 🎬 {model_cls} Video Generator")
......@@ -800,11 +840,14 @@ def main():
                )
            with gr.Column():
                enable_auto_config = gr.Checkbox(
                    label="Auto-configure Inference Options",
                    value=False,
                    info="Automatically optimize GPU settings to match the current resolution. After changing the resolution, please re-check this option to prevent potential performance degradation or runtime errors.",
                )
                with gr.Group():
                    gr.Markdown("### 🚀 **Smart Configuration Recommendation**", elem_classes=["auto-config-title"])
                    enable_auto_config = gr.Checkbox(
                        label="🎯 **Auto-configure Inference Options**",
                        value=False,
                        info="💡 **Automatically optimize GPU settings to match the current resolution. After changing the resolution, please re-check this option to prevent potential performance degradation or runtime errors.**",
                        elem_classes=["auto-config-checkbox"],
                    )
            with gr.Column(scale=9):
                seed = gr.Slider(
                    label="Random Seed",
......
......@@ -142,6 +142,18 @@ def is_fp8_supported_gpu():
    return (major == 8 and minor == 9) or (major >= 9)


def is_ada_architecture_gpu():
    if not torch.cuda.is_available():
        return False
    try:
        gpu_name = torch.cuda.get_device_name(0).upper()
        ada_keywords = ["RTX 40", "RTX40", "4090", "4080", "4070", "4060"]
        return any(keyword in gpu_name for keyword in ada_keywords)
    except Exception as e:
        logger.warning(f"Failed to get GPU name: {e}")
        return False
global_runner = None
current_config = None
cur_dit_quant_scheme = None
......@@ -508,7 +520,11 @@ def auto_configure(enable_auto_config, resolution):
quant_type = "int8"
attn_priority = ["sage_attn2", "flash_attn3", "flash_attn2", "torch_sdpa"]
quant_op_priority = ["sgl", "vllm", "q8f"]
if is_ada_architecture_gpu():
quant_op_priority = ["q8f", "vllm", "sgl"]
else:
quant_op_priority = ["sgl", "vllm", "q8f"]
for op in attn_priority:
if dict(available_attn_ops).get(op):
......@@ -738,6 +754,30 @@ def main():
.warning { color: #ff6b6b; font-weight: bold; }
.advanced-options { background: #f9f9ff; border-radius: 10px; padding: 15px; }
.tab-button { font-size: 16px; padding: 10px 20px; }
.auto-config-title {
background: linear-gradient(45deg, #ff6b6b, #4ecdc4);
background-clip: text;
-webkit-background-clip: text;
color: transparent;
text-align: center;
margin: 0 !important;
padding: 8px;
border: 2px solid #4ecdc4;
border-radius: 8px;
background-color: #f0f8ff;
}
.auto-config-checkbox {
border: 2px solid #ff6b6b !important;
border-radius: 8px !important;
padding: 10px !important;
background: linear-gradient(135deg, #fff5f5, #f0fff0) !important;
box-shadow: 0 2px 8px rgba(255, 107, 107, 0.2) !important;
}
.auto-config-checkbox label {
font-size: 16px !important;
font-weight: bold !important;
color: #2c3e50 !important;
}
""",
) as demo:
gr.Markdown(f"# 🎬 {model_cls} 视频生成器")
......@@ -802,9 +842,14 @@ def main():
                )
            with gr.Column():
                enable_auto_config = gr.Checkbox(
                    label="自动配置推理选项", value=False, info="自动优化GPU设置以匹配当前分辨率。修改分辨率后,请重新勾选此选项,否则可能导致性能下降或运行失败。"
                )
                with gr.Group():
                    gr.Markdown("### 🚀 **智能配置推荐**", elem_classes=["auto-config-title"])
                    enable_auto_config = gr.Checkbox(
                        label="🎯 **自动配置推理选项**",
                        value=False,
                        info="💡 **智能优化GPU设置以匹配当前分辨率。修改分辨率后,请重新勾选此选项,否则可能导致性能下降或运行失败。**",
                        elem_classes=["auto-config-checkbox"],
                    )
            with gr.Column(scale=9):
                seed = gr.Slider(
                    label="随机种子",
......
......@@ -18,7 +18,7 @@ lightx2v_path=/path/to/lightx2v
# Model path configuration
# Image-to-video model path (for i2v tasks)
# Example: /path/to/Wan2.1-I2V-14B-720P-Lightx2v
i2v_model_path=/path/to/Wan2.1-I2V-14B-720P-Lightx2v-Step-Distill
i2v_model_path=/Wan_0718/Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-Lightx2v/
# Text-to-video model path (for t2v tasks)
# Example: /path/to/Wan2.1-T2V-1.3B
......@@ -222,7 +222,7 @@ fi
echo "🎬 Starting Gradio demo..."
echo "📱 Please access in browser: http://$server_name:$server_port"
echo "⏹️ Press Ctrl+C to stop service"
echo "🔄 First startup may take several minutes to load model..."
echo "🔄 First startup may take several minutes to load resources..."
echo "=========================================="
# Start Python demo
......
{
    "infer_steps": 4,
    "target_video_length": 81,
    "target_height": 480,
    "target_width": 832,
    "self_attn_1_type": "sage_attn2",
    "cross_attn_1_type": "sage_attn2",
    "cross_attn_2_type": "sage_attn2",
    "seed": 42,
    "sample_guide_scale": 5,
    "sample_shift": 5,
    "enable_cfg": false,
    "cpu_offload": false,
    "mm_config": {
        "mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F"
    },
    "t5_quantized": true,
    "t5_quant_scheme": "int8-q8f",
    "clip_quantized": true,
    "clip_quant_scheme": "int8-q8f"
}
......@@ -156,11 +156,13 @@ VAE (Variational Autoencoder) is a key component in video generation, and optimi
use_tiling_vae = True # Enable VAE chunked inference
```
#### [Lightweight VAE](https://github.com/madebyollin/taehv/blob/main/taew2_1.pth)
#### Lightweight VAE
You can download it here: https://github.com/madebyollin/taehv/blob/main/taew2_1.pth
```python
# VAE optimization configuration
use_tiny_vae = True # Use lightweight VAE
tiny_vae = True # Use lightweight VAE
```
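
For reference, a minimal config sketch (not taken from the repository) of how the two VAE options above might be combined; `tiny_vae_path` is optional and, per the runner change in this commit, falls back to `taew2_1.pth` under the model path when omitted:

```python
# Hypothetical config sketch: enable both VAE optimizations together.
config = {
    "use_tiling_vae": True,   # chunked VAE inference to lower peak memory
    "tiny_vae": True,         # use the lightweight taew2_1 VAE decoder
    "tiny_vae_path": "/path/to/taew2_1.pth",  # optional; defaults to <model_path>/taew2_1.pth
}
```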
**VAE Optimization Effects**:
......
......@@ -156,11 +156,13 @@ VAE (变分自编码器) 是视频生成的关键组件,优化VAE可以显著
use_tiling_vae = True # 启用VAE分块推理
```
#### [轻量级VAE](https://github.com/madebyollin/taehv/blob/main/taew2_1.pth)
#### 轻量级VAE
可以在这里下载:https://github.com/madebyollin/taehv/blob/main/taew2_1.pth
```python
# VAE优化配置
use_tiny_vae = True # 使用轻量级VAE
tiny_vae = True # 使用轻量级VAE
```
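
As a reference, a small sketch (an assumption, mirroring the fallback added to `WanRunner` later in this commit) of how the tiny-VAE weight path is resolved when `tiny_vae_path` is not set:

```python
import os

def resolve_tiny_vae_path(config: dict) -> str:
    # Assumed helper: fall back to <model_path>/taew2_1.pth when tiny_vae_path is absent.
    return config.get("tiny_vae_path", os.path.join(config["model_path"], "taew2_1.pth"))
```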
**VAE优化效果**:
......
......@@ -11,6 +11,11 @@ try:
except ModuleNotFoundError:
    quant_int8_per_token_matmul, quantize_activation_per_token_absmax = None, None

# q8_kernels (Q8F) provides the int8/fp8 linear kernels used by the Q8F* layers below; treated as optional.
try:
    import q8_kernels.functional as Q8F
except ImportError:
    Q8F = None


class VllmQuantLinearInt8(nn.Module):
    def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16):
......@@ -152,3 +157,66 @@ class TorchaoQuantLinearInt8(nn.Module):
        self.weight_scale = maybe_cast(self.weight_scale)
        self.bias = maybe_cast(self.bias)
        return self
class Q8FQuantLinearInt8(nn.Module):
    """Linear layer with int8 weights and per-output-channel scales, executed via the q8_kernels (Q8F) int8 GEMM."""

    def __init__(self, in_features, out_features, bias=True, dtype=torch.float32):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.int8))
        self.register_buffer("weight_scale", torch.empty((out_features, 1), dtype=torch.float32))
        if bias:
            self.register_buffer("bias", torch.empty(out_features, dtype=torch.float32))
        else:
            self.register_buffer("bias", None)

    def act_quant_func(self, x):
        # Quantize activations to int8 on the fly (dynamic, symmetric).
        input_tensor_quant, input_tensor_scale, _ = ops.scaled_int8_quant(x, scale=None, azp=None, symmetric=True)
        return input_tensor_quant, input_tensor_scale

    def forward(self, x):
        input_tensor_quant, input_tensor_scale = self.act_quant_func(x)
        output_tensor = Q8F.linear.q8_linear(
            input_tensor_quant,
            self.weight,
            self.bias if self.bias is not None else None,
            input_tensor_scale,
            self.weight_scale.float(),
            fuse_gelu=False,
            out_dtype=torch.bfloat16,
        )
        return output_tensor
class Q8FQuantLinearFp8(nn.Module):
    """Linear layer with fp8 (e4m3) weights and per-output-channel scales, executed via the q8_kernels (Q8F) fp8 GEMM."""

    def __init__(self, in_features, out_features, bias=True, dtype=torch.float32):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.float8_e4m3fn))
        self.register_buffer("weight_scale", torch.empty((out_features, 1), dtype=torch.float32))
        if bias:
            self.register_buffer("bias", torch.empty(out_features, dtype=torch.float32))
        else:
            self.register_buffer("bias", None)

    def act_quant_func(self, x):
        # Quantize activations to fp8 on the fly (dynamic, per-token).
        input_tensor_quant, input_tensor_scale = ops.scaled_fp8_quant(x.squeeze(0), None, scale_ub=None, use_per_token_if_dynamic=True)
        return input_tensor_quant, input_tensor_scale

    def forward(self, x):
        input_tensor_quant, input_tensor_scale = self.act_quant_func(x)
        output_tensor = Q8F.linear.fp8_linear(
            input_tensor_quant,
            self.weight,
            self.bias if self.bias is not None else None,
            input_tensor_scale,
            self.weight_scale,
            out_dtype=torch.bfloat16,
        )
        return output_tensor
......@@ -9,7 +9,7 @@ import torch.nn.functional as F
from .tokenizer import HuggingfaceTokenizer
from loguru import logger
from lightx2v.models.input_encoders.hf.q_linear import VllmQuantLinearInt8, VllmQuantLinearFp8, TorchaoQuantLinearInt8
from lightx2v.models.input_encoders.hf.q_linear import VllmQuantLinearInt8, VllmQuantLinearFp8, TorchaoQuantLinearInt8, Q8FQuantLinearInt8, Q8FQuantLinearFp8
__all__ = [
......@@ -88,6 +88,10 @@ class T5Attention(nn.Module):
            linear_cls = VllmQuantLinearFp8
        elif quant_scheme == "int8-torchao":
            linear_cls = TorchaoQuantLinearInt8
        elif quant_scheme == "int8-q8f":
            linear_cls = Q8FQuantLinearInt8
        elif quant_scheme == "fp8-q8f":
            linear_cls = Q8FQuantLinearFp8
        else:
            linear_cls = nn.Linear
......@@ -151,6 +155,10 @@ class T5FeedForward(nn.Module):
            linear_cls = VllmQuantLinearFp8
        elif quant_scheme == "int8-torchao":
            linear_cls = TorchaoQuantLinearInt8
        elif quant_scheme == "int8-q8f":
            linear_cls = Q8FQuantLinearInt8
        elif quant_scheme == "fp8-q8f":
            linear_cls = Q8FQuantLinearFp8
        else:
            linear_cls = nn.Linear
        # layers
......
......@@ -10,7 +10,7 @@ import torchvision.transforms as T
from lightx2v.attentions import attention
from loguru import logger
from lightx2v.models.input_encoders.hf.q_linear import VllmQuantLinearInt8, VllmQuantLinearFp8, TorchaoQuantLinearInt8
from lightx2v.models.input_encoders.hf.q_linear import VllmQuantLinearInt8, VllmQuantLinearFp8, TorchaoQuantLinearInt8, Q8FQuantLinearInt8, Q8FQuantLinearFp8
from einops import rearrange
from torch import Tensor
from transformers import CLIPVisionModel
......@@ -68,6 +68,10 @@ class SelfAttention(nn.Module):
            linear_cls = VllmQuantLinearFp8
        elif quant_scheme == "int8-torchao":
            linear_cls = TorchaoQuantLinearInt8
        elif quant_scheme == "int8-q8f":
            linear_cls = Q8FQuantLinearInt8
        elif quant_scheme == "fp8-q8f":
            linear_cls = Q8FQuantLinearFp8
        else:
            linear_cls = nn.Linear
......@@ -142,6 +146,10 @@ class AttentionBlock(nn.Module):
            linear_cls = VllmQuantLinearFp8
        elif quant_scheme == "int8-torchao":
            linear_cls = TorchaoQuantLinearInt8
        elif quant_scheme == "int8-q8f":
            linear_cls = Q8FQuantLinearInt8
        elif quant_scheme == "fp8-q8f":
            linear_cls = Q8FQuantLinearFp8
        else:
            linear_cls = nn.Linear
......
......@@ -41,6 +41,7 @@ class WanModel:
            self.dit_quantized_ckpt = self.config.get("dit_quantized_ckpt", os.path.join(model_path, dit_quant_scheme))
        else:
            self.dit_quantized_ckpt = None
        self.config.dit_quantized_ckpt = self.dit_quantized_ckpt
        self.weight_auto_quant = self.config.mm_config.get("weight_auto_quant", False)
        if self.dit_quantized:
            assert self.weight_auto_quant or self.dit_quantized_ckpt is not None
......
......@@ -143,8 +143,9 @@ class WanRunner(DefaultRunner):
"use_tiling": self.config.get("use_tiling_vae", False),
}
if self.config.get("tiny_vae", False):
tiny_vae_path = self.config.get("tiny_vae_path", os.path.join(self.config.model_path, "taew2_1.pth"))
vae_decoder = WanVAE_tiny(
vae_pth=self.config.tiny_vae_path,
vae_pth=tiny_vae_path,
device=self.init_device,
).to("cuda")
else:
......