Commit acacd26f authored by helloyongyang

Update configs & use sgl-fp8 as the default

parent 6de0a3b4
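In short: every fp8 config and every fp8 branch of the quantized-linear dispatch in the encoder modules now resolves to SGLang's fp8 channel-wise kernel instead of vLLM's. A minimal sketch of the dispatch after this commit (class names are taken from the hunks below; the q8f branch bodies are assumptions inferred from the imports, since the diffs truncate before them):

```python
from lightx2v.models.input_encoders.hf.q_linear import (
    Q8FQuantLinearFp8,
    Q8FQuantLinearInt8,
    SglQuantLinearFp8,
    TorchaoQuantLinearInt8,
    VllmQuantLinearInt8,
)

def select_linear_cls(quant_scheme: str):
    """Mirror of the if/elif chain repeated in T5Attention, T5FeedForward,
    SelfAttention and AttentionBlock after this commit."""
    if quant_scheme == "int8":
        return VllmQuantLinearInt8
    elif quant_scheme == "fp8":
        # Previously VllmQuantLinearFp8; the SGLang fp8 kernel is now the default.
        return SglQuantLinearFp8
    elif quant_scheme == "int8-torchao":
        return TorchaoQuantLinearInt8
    elif quant_scheme == "int8-q8f":
        # Assumed from the imports; the hunks truncate before this branch body.
        return Q8FQuantLinearInt8
    elif quant_scheme == "fp8-q8f":
        # Assumed scheme key, not shown in the hunks; labeled hypothetical.
        return Q8FQuantLinearFp8
    raise ValueError(f"Unsupported quant_scheme: {quant_scheme}")
```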
@@ -15,7 +15,7 @@
     "cpu_offload": false,
     "use_31_block": false,
     "mm_config": {
-        "mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm"
+        "mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl"
     },
     "adapter_quantized": true,
     "adapter_quant_scheme": "fp8",
@@ -19,7 +19,7 @@
         "seq_p_attn_type": "ulysses"
     },
     "mm_config": {
-        "mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm"
+        "mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl"
     },
     "adapter_quantized": true,
     "adapter_quant_scheme": "fp8",
@@ -27,6 +27,6 @@
     "vae_cpu_offload": false,
     "use_tiling_vae": true,
     "mm_config": {
-        "mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm"
+        "mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl"
     }
 }
@@ -19,7 +19,7 @@
         "seq_p_attn_type": "ulysses"
     },
     "mm_config": {
-        "mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm"
+        "mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl"
     },
     "adapter_quantized": true,
     "adapter_quant_scheme": "fp8",
@@ -23,7 +23,7 @@
         "seq_p_attn_type": "ulysses"
     },
     "mm_config": {
-        "mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm"
+        "mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl"
     },
     "adapter_quantized": true,
     "adapter_quant_scheme": "fp8",
@@ -24,7 +24,7 @@
         "seq_p_attn_type": "ulysses"
     },
     "mm_config": {
-        "mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm"
+        "mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl"
     },
     "adapter_quantized": true,
     "adapter_quant_scheme": "fp8",
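On the config side, all six JSON files change only the backend suffix of `mm_type`. A hedged reading of that key's naming convention, inferred from the string itself rather than documented by this commit:

```python
# Inferred decomposition of the mm_type key; the field meanings are an
# assumption based on the naming, not confirmed anywhere in this diff.
mm_type = "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl"
# "W-fp8-channel-sym"          weights: fp8, per-channel scales, symmetric
# "A-fp8-channel-sym-dynamic"  activations: fp8, per-channel, symmetric, dynamic scales
# "Sgl"                        matmul kernel backend (SGLang); was "Vllm" before
```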
@@ -8,7 +8,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from loguru import logger
-from lightx2v.models.input_encoders.hf.q_linear import Q8FQuantLinearFp8, Q8FQuantLinearInt8, TorchaoQuantLinearInt8, VllmQuantLinearFp8, VllmQuantLinearInt8
+from lightx2v.models.input_encoders.hf.q_linear import Q8FQuantLinearFp8, Q8FQuantLinearInt8, SglQuantLinearFp8, TorchaoQuantLinearInt8, VllmQuantLinearInt8
 from lightx2v.utils.envs import *
 from lightx2v.utils.utils import load_weights
@@ -87,7 +87,7 @@ class T5Attention(nn.Module):
         if quant_scheme == "int8":
             linear_cls = VllmQuantLinearInt8
         elif quant_scheme == "fp8":
-            linear_cls = VllmQuantLinearFp8
+            linear_cls = SglQuantLinearFp8
         elif quant_scheme == "int8-torchao":
             linear_cls = TorchaoQuantLinearInt8
         elif quant_scheme == "int8-q8f":
@@ -154,7 +154,7 @@ class T5FeedForward(nn.Module):
         if quant_scheme == "int8":
             linear_cls = VllmQuantLinearInt8
         elif quant_scheme == "fp8":
-            linear_cls = VllmQuantLinearFp8
+            linear_cls = SglQuantLinearFp8
         elif quant_scheme == "int8-torchao":
             linear_cls = TorchaoQuantLinearInt8
         elif quant_scheme == "int8-q8f":
@@ -10,7 +10,7 @@ from loguru import logger
 # from lightx2v.attentions import attention
 from lightx2v.common.ops.attn import TorchSDPAWeight
-from lightx2v.models.input_encoders.hf.q_linear import Q8FQuantLinearFp8, Q8FQuantLinearInt8, TorchaoQuantLinearInt8, VllmQuantLinearFp8, VllmQuantLinearInt8
+from lightx2v.models.input_encoders.hf.q_linear import Q8FQuantLinearFp8, Q8FQuantLinearInt8, SglQuantLinearFp8, TorchaoQuantLinearInt8, VllmQuantLinearInt8
 from lightx2v.utils.utils import load_weights
 __all__ = [
@@ -62,7 +62,7 @@ class SelfAttention(nn.Module):
         if quant_scheme == "int8":
             linear_cls = VllmQuantLinearInt8
         elif quant_scheme == "fp8":
-            linear_cls = VllmQuantLinearFp8
+            linear_cls = SglQuantLinearFp8
         elif quant_scheme == "int8-torchao":
             linear_cls = TorchaoQuantLinearInt8
         elif quant_scheme == "int8-q8f":
@@ -140,7 +140,7 @@ class AttentionBlock(nn.Module):
         if quant_scheme == "int8":
             linear_cls = VllmQuantLinearInt8
         elif quant_scheme == "fp8":
-            linear_cls = VllmQuantLinearFp8
+            linear_cls = SglQuantLinearFp8
         elif quant_scheme == "int8-torchao":
             linear_cls = TorchaoQuantLinearInt8
         elif quant_scheme == "int8-q8f":
@@ -217,7 +217,8 @@ class DefaultRunner(BaseRunner):
     def run_main(self, total_steps=None):
         self.init_run()
         for segment_idx in range(self.video_segment_num):
-            with ProfilingContext4Debug(f"segment end2end {segment_idx}"):
+            logger.info(f"🔄 segment_idx: {segment_idx + 1}/{self.video_segment_num}")
+            with ProfilingContext(f"segment end2end {segment_idx}"):
                 # 1. default do nothing
                 self.init_run_segment(segment_idx)
                 # 2. main inference loop
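The runner now logs per-segment progress and swaps the debug-gated profiler for the unconditional one, so segment end-to-end timings appear in normal runs. As a rough stand-in for what such a context manager does (illustrative only, not the actual lightx2v.utils implementation):

```python
import time
from contextlib import ContextDecorator

from loguru import logger

class ProfilingContext(ContextDecorator):
    """Illustrative stand-in: times a block and always logs the elapsed time.
    Presumably the *4Debug variant it replaces only logs when debug profiling
    is enabled; that gating is an assumption, not shown in this diff."""

    def __init__(self, name: str):
        self.name = name

    def __enter__(self):
        self.start = time.perf_counter()
        return self

    def __exit__(self, *exc):
        logger.info(f"[Profiling] {self.name} cost {time.perf_counter() - self.start:.4f} seconds")
        return False
```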
@@ -456,7 +456,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
         self.config.seed = self.config.seed + segment_idx
         torch.manual_seed(self.config.seed)
-        logger.info(f"Processing segment {segment_idx + 1}/{self.video_segment_num}, seed: {self.config.seed}")
+        # logger.info(f"Processing segment {segment_idx + 1}/{self.video_segment_num}, seed: {self.config.seed}")
         if (self.config.get("lazy_load", False) or self.config.get("unload_modules", False)) and not hasattr(self, "audio_encoder"):
             self.audio_encoder = self.load_audio_encoder()
@@ -893,7 +893,7 @@ class WanVAE:
     def _calculate_2d_grid(self, latent_height, latent_width, world_size):
         if (latent_height, latent_width, world_size) in self.grid_table:
             best_h, best_w = self.grid_table[(latent_height, latent_width, world_size)]
-            logger.info(f"Vae using cached 2D grid: {best_h}x{best_w} grid for {latent_height}x{latent_width} latent")
+            # logger.info(f"Vae using cached 2D grid: {best_h}x{best_w} grid for {latent_height}x{latent_width} latent")
             return best_h, best_w
         best_h, best_w = 1, world_size
@@ -908,7 +908,7 @@ class WanVAE:
                 if aspect_diff < min_aspect_diff:
                     min_aspect_diff = aspect_diff
                     best_h, best_w = h, w
-        logger.info(f"Vae using 2D grid & Update cache: {best_h}x{best_w} grid for {latent_height}x{latent_width} latent")
+        # logger.info(f"Vae using 2D grid & Update cache: {best_h}x{best_w} grid for {latent_height}x{latent_width} latent")
         self.grid_table[(latent_height, latent_width, world_size)] = (best_h, best_w)
         return best_h, best_w