Commit b959bfd9 authored by gushiqiao

Update q8-kernel

parent 3e4fe79b
@@ -140,6 +140,18 @@ def is_fp8_supported_gpu():
     return (major == 8 and minor == 9) or (major >= 9)


+def is_ada_architecture_gpu():
+    if not torch.cuda.is_available():
+        return False
+    try:
+        gpu_name = torch.cuda.get_device_name(0).upper()
+        ada_keywords = ["RTX 40", "RTX40", "4090", "4080", "4070", "4060"]
+        return any(keyword in gpu_name for keyword in ada_keywords)
+    except Exception as e:
+        logger.warning(f"Failed to get GPU name: {e}")
+        return False
+
+
 global_runner = None
 current_config = None
 cur_dit_quant_scheme = None
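A note on the detection above: the keyword list only matches GeForce 40-series names, so Ada workstation or datacenter parts (e.g. RTX 6000 Ada, L40) would fall through to the non-Ada branch. A hedged alternative sketch, not what this commit does, is to key off the CUDA compute capability, since Ada Lovelace reports 8.9, the same pair `is_fp8_supported_gpu()` already checks:

```python
import torch

# Hedged alternative (not the commit's implementation): identify Ada Lovelace
# GPUs by compute capability 8.9 instead of matching marketing names.
def is_ada_by_compute_capability(device: int = 0) -> bool:
    if not torch.cuda.is_available():
        return False
    major, minor = torch.cuda.get_device_capability(device)
    return major == 8 and minor == 9
```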
@@ -506,7 +518,11 @@ def auto_configure(enable_auto_config, resolution):
         quant_type = "int8"

     attn_priority = ["sage_attn2", "flash_attn3", "flash_attn2", "torch_sdpa"]
-    quant_op_priority = ["sgl", "vllm", "q8f"]
+    if is_ada_architecture_gpu():
+        quant_op_priority = ["q8f", "vllm", "sgl"]
+    else:
+        quant_op_priority = ["sgl", "vllm", "q8f"]

     for op in attn_priority:
         if dict(available_attn_ops).get(op):
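The branch above means the q8_kernels backend ("q8f") is preferred on Ada GPUs, while sgl/vllm keep priority elsewhere. A minimal sketch of the first-available selection pattern `auto_configure` applies to these priority lists (the availability pairs here are hypothetical, mirroring `available_attn_ops`):

```python
# Minimal sketch of first-available selection (hypothetical availability map).
available_quant_ops = [("sgl", False), ("vllm", True), ("q8f", True)]
quant_op_priority = ["q8f", "vllm", "sgl"]  # Ada ordering after this commit

selected = next((op for op in quant_op_priority if dict(available_quant_ops).get(op)), None)
print(selected)  # -> "q8f"
```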
@@ -736,6 +752,30 @@ def main():
             .warning { color: #ff6b6b; font-weight: bold; }
             .advanced-options { background: #f9f9ff; border-radius: 10px; padding: 15px; }
             .tab-button { font-size: 16px; padding: 10px 20px; }
+            .auto-config-title {
+                background: linear-gradient(45deg, #ff6b6b, #4ecdc4);
+                background-clip: text;
+                -webkit-background-clip: text;
+                color: transparent;
+                text-align: center;
+                margin: 0 !important;
+                padding: 8px;
+                border: 2px solid #4ecdc4;
+                border-radius: 8px;
+                background-color: #f0f8ff;
+            }
+            .auto-config-checkbox {
+                border: 2px solid #ff6b6b !important;
+                border-radius: 8px !important;
+                padding: 10px !important;
+                background: linear-gradient(135deg, #fff5f5, #f0fff0) !important;
+                box-shadow: 0 2px 8px rgba(255, 107, 107, 0.2) !important;
+            }
+            .auto-config-checkbox label {
+                font-size: 16px !important;
+                font-weight: bold !important;
+                color: #2c3e50 !important;
+            }
         """,
     ) as demo:
         gr.Markdown(f"# 🎬 {model_cls} Video Generator")
@@ -800,11 +840,14 @@ def main():
                     )
                 with gr.Column():
-                    enable_auto_config = gr.Checkbox(
-                        label="Auto-configure Inference Options",
-                        value=False,
-                        info="Automatically optimize GPU settings to match the current resolution. After changing the resolution, please re-check this option to prevent potential performance degradation or runtime errors.",
-                    )
+                    with gr.Group():
+                        gr.Markdown("### 🚀 **Smart Configuration Recommendation**", elem_classes=["auto-config-title"])
+                        enable_auto_config = gr.Checkbox(
+                            label="🎯 **Auto-configure Inference Options**",
+                            value=False,
+                            info="💡 **Automatically optimize GPU settings to match the current resolution. After changing the resolution, please re-check this option to prevent potential performance degradation or runtime errors.**",
+                            elem_classes=["auto-config-checkbox"],
+                        )
                 with gr.Column(scale=9):
                     seed = gr.Slider(
                         label="Random Seed",
...
@@ -142,6 +142,18 @@ def is_fp8_supported_gpu():
     return (major == 8 and minor == 9) or (major >= 9)


+def is_ada_architecture_gpu():
+    if not torch.cuda.is_available():
+        return False
+    try:
+        gpu_name = torch.cuda.get_device_name(0).upper()
+        ada_keywords = ["RTX 40", "RTX40", "4090", "4080", "4070", "4060"]
+        return any(keyword in gpu_name for keyword in ada_keywords)
+    except Exception as e:
+        logger.warning(f"Failed to get GPU name: {e}")
+        return False
+
+
 global_runner = None
 current_config = None
 cur_dit_quant_scheme = None
@@ -508,7 +520,11 @@ def auto_configure(enable_auto_config, resolution):
         quant_type = "int8"

     attn_priority = ["sage_attn2", "flash_attn3", "flash_attn2", "torch_sdpa"]
-    quant_op_priority = ["sgl", "vllm", "q8f"]
+    if is_ada_architecture_gpu():
+        quant_op_priority = ["q8f", "vllm", "sgl"]
+    else:
+        quant_op_priority = ["sgl", "vllm", "q8f"]

     for op in attn_priority:
         if dict(available_attn_ops).get(op):
@@ -738,6 +754,30 @@ def main():
             .warning { color: #ff6b6b; font-weight: bold; }
             .advanced-options { background: #f9f9ff; border-radius: 10px; padding: 15px; }
             .tab-button { font-size: 16px; padding: 10px 20px; }
+            .auto-config-title {
+                background: linear-gradient(45deg, #ff6b6b, #4ecdc4);
+                background-clip: text;
+                -webkit-background-clip: text;
+                color: transparent;
+                text-align: center;
+                margin: 0 !important;
+                padding: 8px;
+                border: 2px solid #4ecdc4;
+                border-radius: 8px;
+                background-color: #f0f8ff;
+            }
+            .auto-config-checkbox {
+                border: 2px solid #ff6b6b !important;
+                border-radius: 8px !important;
+                padding: 10px !important;
+                background: linear-gradient(135deg, #fff5f5, #f0fff0) !important;
+                box-shadow: 0 2px 8px rgba(255, 107, 107, 0.2) !important;
+            }
+            .auto-config-checkbox label {
+                font-size: 16px !important;
+                font-weight: bold !important;
+                color: #2c3e50 !important;
+            }
         """,
     ) as demo:
         gr.Markdown(f"# 🎬 {model_cls} 视频生成器")
@@ -802,9 +842,14 @@ def main():
                     )
                 with gr.Column():
-                    enable_auto_config = gr.Checkbox(
-                        label="自动配置推理选项", value=False, info="自动优化GPU设置以匹配当前分辨率。修改分辨率后,请重新勾选此选项,否则可能导致性能下降或运行失败。"
-                    )
+                    with gr.Group():
+                        gr.Markdown("### 🚀 **智能配置推荐**", elem_classes=["auto-config-title"])
+                        enable_auto_config = gr.Checkbox(
+                            label="🎯 **自动配置推理选项**",
+                            value=False,
+                            info="💡 **智能优化GPU设置以匹配当前分辨率。修改分辨率后,请重新勾选此选项,否则可能导致性能下降或运行失败。**",
+                            elem_classes=["auto-config-checkbox"],
+                        )
                 with gr.Column(scale=9):
                     seed = gr.Slider(
                         label="随机种子",
...
@@ -18,7 +18,7 @@ lightx2v_path=/path/to/lightx2v
 # Model path configuration
 # Image-to-video model path (for i2v tasks)
 # Example: /path/to/Wan2.1-I2V-14B-720P-Lightx2v
-i2v_model_path=/path/to/Wan2.1-I2V-14B-720P-Lightx2v-Step-Distill
+i2v_model_path=/Wan_0718/Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-Lightx2v/

 # Text-to-video model path (for t2v tasks)
 # Example: /path/to/Wan2.1-T2V-1.3B
@@ -222,7 +222,7 @@ fi
 echo "🎬 Starting Gradio demo..."
 echo "📱 Please access in browser: http://$server_name:$server_port"
 echo "⏹️ Press Ctrl+C to stop service"
-echo "🔄 First startup may take several minutes to load model..."
+echo "🔄 First startup may take several minutes to load resources..."
 echo "=========================================="

 # Start Python demo
...
{
"infer_steps": 4,
"target_video_length": 81,
"target_height": 480,
"target_width": 832,
"self_attn_1_type": "sage_attn2",
"cross_attn_1_type": "sage_attn2",
"cross_attn_2_type": "sage_attn2",
"seed": 42,
"sample_guide_scale": 5,
"sample_shift": 5,
"enable_cfg": false,
"cpu_offload": false,
"mm_config": {
"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F"
},
"t5_quantized": true,
"t5_quant_scheme": "int8-q8f",
"clip_quantized": true,
"clip_quant_scheme": "int8-q8f"
}
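This appears to be the Q8F int8 inference config added alongside the kernel change. A small hedged sketch (the file name is hypothetical) that loads such a config and checks its scheme strings line up with the "int8-q8f" / "fp8-q8f" dispatch branches added further down:

```python
import json

# Hedged sketch: verify a Q8F config uses the scheme strings this commit wires up.
# The file name is hypothetical.
with open("wan_i2v_q8f_int8.json") as f:
    cfg = json.load(f)

assert cfg["mm_config"]["mm_type"].endswith("Q8F")
assert cfg["t5_quant_scheme"] in ("int8-q8f", "fp8-q8f")
assert cfg["clip_quant_scheme"] in ("int8-q8f", "fp8-q8f")
```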
@@ -156,11 +156,13 @@ VAE (Variational Autoencoder) is a key component in video generation, and optimizing VAE can significantly ...
 use_tiling_vae = True # Enable VAE chunked inference
 ```

-#### [Lightweight VAE](https://github.com/madebyollin/taehv/blob/main/taew2_1.pth)
+#### Lightweight VAE
+
+You can download it here: https://github.com/madebyollin/taehv/blob/main/taew2_1.pth

 ```python
 # VAE optimization configuration
-use_tiny_vae = True # Use lightweight VAE
+tiny_vae = True # Use lightweight VAE
 ```

 **VAE Optimization Effects**:
...
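For context on the `tiny_vae` flag documented above: after this commit the runner falls back to `<model_path>/taew2_1.pth` when `tiny_vae_path` is not set (see the WanRunner hunk near the end of this diff). A minimal sketch of that fallback, with a hypothetical model path:

```python
import os

# Minimal sketch of the tiny-VAE checkpoint fallback introduced in this commit.
config = {
    "tiny_vae": True,
    "model_path": "/path/to/Wan2.1-I2V-14B-720P-Lightx2v",  # hypothetical
}
tiny_vae_path = config.get("tiny_vae_path", os.path.join(config["model_path"], "taew2_1.pth"))
print(tiny_vae_path)  # /path/to/Wan2.1-I2V-14B-720P-Lightx2v/taew2_1.pth
```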
@@ -156,11 +156,13 @@ VAE (变分自编码器) 是视频生成的关键组件,优化VAE可以显著 ...
 use_tiling_vae = True # 启用VAE分块推理
 ```

-#### [轻量级VAE](https://github.com/madebyollin/taehv/blob/main/taew2_1.pth)
+#### 轻量级VAE
+
+可以在这里下载:https://github.com/madebyollin/taehv/blob/main/taew2_1.pth

 ```python
 # VAE优化配置
-use_tiny_vae = True # 使用轻量级VAE
+tiny_vae = True # 使用轻量级VAE
 ```

 **VAE优化效果**:
...
@@ -11,6 +11,11 @@ try:
 except ModuleNotFoundError:
     quant_int8_per_token_matmul, quantize_activation_per_token_absmax = None, None

+try:
+    import q8_kernels.functional as Q8F
+except ImportError:
+    Q8F = None
+
 class VllmQuantLinearInt8(nn.Module):
     def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16):
@@ -152,3 +157,66 @@ class TorchaoQuantLinearInt8(nn.Module):
         self.weight_scale = maybe_cast(self.weight_scale)
         self.bias = maybe_cast(self.bias)
         return self
+
+
+class Q8FQuantLinearInt8(nn.Module):
+    def __init__(self, in_features, out_features, bias=True, dtype=torch.float32):
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.int8))
+        self.register_buffer("weight_scale", torch.empty((out_features, 1), dtype=torch.float32))
+        if bias:
+            self.register_buffer("bias", torch.empty(out_features, dtype=torch.float32))
+        else:
+            self.register_buffer("bias", None)
+
+    def act_quant_func(self, x):
+        input_tensor_quant, input_tensor_scale, _ = ops.scaled_int8_quant(x, scale=None, azp=None, symmetric=True)
+        return input_tensor_quant, input_tensor_scale
+
+    def forward(self, x):
+        input_tensor_quant, input_tensor_scale = self.act_quant_func(x)
+        output_tensor = Q8F.linear.q8_linear(
+            input_tensor_quant,
+            self.weight,
+            self.bias if self.bias is not None else None,
+            input_tensor_scale,
+            self.weight_scale.float(),
+            fuse_gelu=False,
+            out_dtype=torch.bfloat16,
+        )
+        return output_tensor
+
+
+class Q8FQuantLinearFp8(nn.Module):
+    def __init__(self, in_features, out_features, bias=True, dtype=torch.float32):
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.float8_e4m3fn))
+        self.register_buffer("weight_scale", torch.empty((out_features, 1), dtype=torch.float32))
+        if bias:
+            self.register_buffer("bias", torch.empty(out_features, dtype=torch.float32))
+        else:
+            self.register_buffer("bias", None)
+
+    def act_quant_func(self, x):
+        input_tensor_quant, input_tensor_scale = ops.scaled_fp8_quant(x.squeeze(0), None, scale_ub=None, use_per_token_if_dynamic=True)
+        return input_tensor_quant, input_tensor_scale
+
+    def forward(self, x):
+        input_tensor_quant, input_tensor_scale = self.act_quant_func(x)
+        output_tensor = Q8F.linear.fp8_linear(
+            input_tensor_quant,
+            self.weight,
+            self.bias if self.bias is not None else None,
+            input_tensor_scale,
+            self.weight_scale,
+            out_dtype=torch.bfloat16,
+        )
+        return output_tensor
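The two classes above only register int8/fp8 weight buffers and per-channel scales; the quantized values themselves presumably come from the checkpoint loader. As a hedged illustration (not the repo's loader), per-channel symmetric int8 quantization that produces tensors with the shapes and dtypes `Q8FQuantLinearInt8` expects:

```python
import torch

# Hedged sketch (not the repo's loader): per-channel symmetric int8 weight
# quantization matching the W-int8-channel-sym layout used by Q8FQuantLinearInt8.
def quantize_weight_int8_per_channel(w: torch.Tensor):
    # w: (out_features, in_features), floating point
    scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
    w_q = torch.round(w / scale).clamp(-128, 127).to(torch.int8)
    return w_q, scale.to(torch.float32)  # -> the "weight" and "weight_scale" buffers

w = torch.randn(128, 256)
w_q, w_scale = quantize_weight_int8_per_channel(w)
assert w_q.dtype == torch.int8 and w_scale.shape == (128, 1)
```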
@@ -9,7 +9,7 @@ import torch.nn.functional as F

 from .tokenizer import HuggingfaceTokenizer
 from loguru import logger
-from lightx2v.models.input_encoders.hf.q_linear import VllmQuantLinearInt8, VllmQuantLinearFp8, TorchaoQuantLinearInt8
+from lightx2v.models.input_encoders.hf.q_linear import VllmQuantLinearInt8, VllmQuantLinearFp8, TorchaoQuantLinearInt8, Q8FQuantLinearInt8, Q8FQuantLinearFp8

 __all__ = [
@@ -88,6 +88,10 @@ class T5Attention(nn.Module):
             linear_cls = VllmQuantLinearFp8
         elif quant_scheme == "int8-torchao":
             linear_cls = TorchaoQuantLinearInt8
+        elif quant_scheme == "int8-q8f":
+            linear_cls = Q8FQuantLinearInt8
+        elif quant_scheme == "fp8-q8f":
+            linear_cls = Q8FQuantLinearFp8
         else:
             linear_cls = nn.Linear
@@ -151,6 +155,10 @@ class T5FeedForward(nn.Module):
             linear_cls = VllmQuantLinearFp8
         elif quant_scheme == "int8-torchao":
             linear_cls = TorchaoQuantLinearInt8
+        elif quant_scheme == "int8-q8f":
+            linear_cls = Q8FQuantLinearInt8
+        elif quant_scheme == "fp8-q8f":
+            linear_cls = Q8FQuantLinearFp8
         else:
             linear_cls = nn.Linear
         # layers
...
@@ -10,7 +10,7 @@ import torchvision.transforms as T

 from lightx2v.attentions import attention
 from loguru import logger
-from lightx2v.models.input_encoders.hf.q_linear import VllmQuantLinearInt8, VllmQuantLinearFp8, TorchaoQuantLinearInt8
+from lightx2v.models.input_encoders.hf.q_linear import VllmQuantLinearInt8, VllmQuantLinearFp8, TorchaoQuantLinearInt8, Q8FQuantLinearInt8, Q8FQuantLinearFp8
 from einops import rearrange
 from torch import Tensor
 from transformers import CLIPVisionModel
@@ -68,6 +68,10 @@ class SelfAttention(nn.Module):
             linear_cls = VllmQuantLinearFp8
         elif quant_scheme == "int8-torchao":
             linear_cls = TorchaoQuantLinearInt8
+        elif quant_scheme == "int8-q8f":
+            linear_cls = Q8FQuantLinearInt8
+        elif quant_scheme == "fp8-q8f":
+            linear_cls = Q8FQuantLinearFp8
         else:
             linear_cls = nn.Linear
@@ -142,6 +146,10 @@ class AttentionBlock(nn.Module):
             linear_cls = VllmQuantLinearFp8
         elif quant_scheme == "int8-torchao":
             linear_cls = TorchaoQuantLinearInt8
+        elif quant_scheme == "int8-q8f":
+            linear_cls = Q8FQuantLinearInt8
+        elif quant_scheme == "fp8-q8f":
+            linear_cls = Q8FQuantLinearFp8
         else:
             linear_cls = nn.Linear
...
@@ -41,6 +41,7 @@ class WanModel:
             self.dit_quantized_ckpt = self.config.get("dit_quantized_ckpt", os.path.join(model_path, dit_quant_scheme))
         else:
             self.dit_quantized_ckpt = None
+        self.config.dit_quantized_ckpt = self.dit_quantized_ckpt
         self.weight_auto_quant = self.config.mm_config.get("weight_auto_quant", False)
         if self.dit_quantized:
             assert self.weight_auto_quant or self.dit_quantized_ckpt is not None
...
@@ -143,8 +143,9 @@ class WanRunner(DefaultRunner):
             "use_tiling": self.config.get("use_tiling_vae", False),
         }
         if self.config.get("tiny_vae", False):
+            tiny_vae_path = self.config.get("tiny_vae_path", os.path.join(self.config.model_path, "taew2_1.pth"))
             vae_decoder = WanVAE_tiny(
-                vae_pth=self.config.tiny_vae_path,
+                vae_pth=tiny_vae_path,
                 device=self.init_device,
             ).to("cuda")
         else:
...