Commit acacd26f authored by helloyongyang

Update configs & use sgl-fp8 as default

parent 6de0a3b4
@@ -15,7 +15,7 @@
     "cpu_offload": false,
     "use_31_block": false,
     "mm_config": {
-        "mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm"
+        "mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl"
     },
     "adapter_quantized": true,
     "adapter_quant_scheme": "fp8",
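The mm_type string names the quantized-matmul recipe for the weight ("W-") and activation ("A-") sides plus a kernel-backend suffix; every config hunk in this commit changes only that suffix, from the vLLM kernels to the sglang ones. A rough reading of the fields, inferred from the string layout alone (not a lightx2v API):

```python
mm_type = "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl"

# Hypothetical field-by-field decomposition, inferred from the string's
# structure only; these names are illustrative, not lightx2v identifiers.
weight_spec = ("fp8", "channel", "sym")          # W-...: fp8 weights, per-channel symmetric scales
act_spec = ("fp8", "channel", "sym", "dynamic")  # A-...-dynamic: activation scales computed at runtime
backend = "Sgl"                                  # kernel backend suffix; this commit swaps "Vllm" -> "Sgl"
```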
@@ -19,7 +19,7 @@
         "seq_p_attn_type": "ulysses"
     },
     "mm_config": {
-        "mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm"
+        "mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl"
     },
     "adapter_quantized": true,
     "adapter_quant_scheme": "fp8",
@@ -27,6 +27,6 @@
     "vae_cpu_offload": false,
     "use_tiling_vae": true,
     "mm_config": {
-        "mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm"
+        "mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl"
     }
 }
@@ -19,7 +19,7 @@
         "seq_p_attn_type": "ulysses"
     },
     "mm_config": {
-        "mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm"
+        "mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl"
     },
     "adapter_quantized": true,
     "adapter_quant_scheme": "fp8",
@@ -23,7 +23,7 @@
         "seq_p_attn_type": "ulysses"
     },
     "mm_config": {
-        "mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm"
+        "mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl"
     },
     "adapter_quantized": true,
     "adapter_quant_scheme": "fp8",
@@ -24,7 +24,7 @@
         "seq_p_attn_type": "ulysses"
     },
     "mm_config": {
-        "mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm"
+        "mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl"
     },
     "adapter_quantized": true,
     "adapter_quant_scheme": "fp8",
@@ -8,7 +8,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from loguru import logger
-from lightx2v.models.input_encoders.hf.q_linear import Q8FQuantLinearFp8, Q8FQuantLinearInt8, TorchaoQuantLinearInt8, VllmQuantLinearFp8, VllmQuantLinearInt8
+from lightx2v.models.input_encoders.hf.q_linear import Q8FQuantLinearFp8, Q8FQuantLinearInt8, SglQuantLinearFp8, TorchaoQuantLinearInt8, VllmQuantLinearInt8
 from lightx2v.utils.envs import *
 from lightx2v.utils.utils import load_weights
@@ -87,7 +87,7 @@ class T5Attention(nn.Module):
         if quant_scheme == "int8":
             linear_cls = VllmQuantLinearInt8
         elif quant_scheme == "fp8":
-            linear_cls = VllmQuantLinearFp8
+            linear_cls = SglQuantLinearFp8
         elif quant_scheme == "int8-torchao":
             linear_cls = TorchaoQuantLinearInt8
         elif quant_scheme == "int8-q8f":
@@ -154,7 +154,7 @@ class T5FeedForward(nn.Module):
         if quant_scheme == "int8":
             linear_cls = VllmQuantLinearInt8
         elif quant_scheme == "fp8":
-            linear_cls = VllmQuantLinearFp8
+            linear_cls = SglQuantLinearFp8
         elif quant_scheme == "int8-torchao":
             linear_cls = TorchaoQuantLinearInt8
         elif quant_scheme == "int8-q8f":
@@ -10,7 +10,7 @@ from loguru import logger
 # from lightx2v.attentions import attention
 from lightx2v.common.ops.attn import TorchSDPAWeight
-from lightx2v.models.input_encoders.hf.q_linear import Q8FQuantLinearFp8, Q8FQuantLinearInt8, TorchaoQuantLinearInt8, VllmQuantLinearFp8, VllmQuantLinearInt8
+from lightx2v.models.input_encoders.hf.q_linear import Q8FQuantLinearFp8, Q8FQuantLinearInt8, SglQuantLinearFp8, TorchaoQuantLinearInt8, VllmQuantLinearInt8
 from lightx2v.utils.utils import load_weights

 __all__ = [
@@ -62,7 +62,7 @@ class SelfAttention(nn.Module):
         if quant_scheme == "int8":
             linear_cls = VllmQuantLinearInt8
         elif quant_scheme == "fp8":
-            linear_cls = VllmQuantLinearFp8
+            linear_cls = SglQuantLinearFp8
         elif quant_scheme == "int8-torchao":
             linear_cls = TorchaoQuantLinearInt8
         elif quant_scheme == "int8-q8f":
@@ -140,7 +140,7 @@ class AttentionBlock(nn.Module):
         if quant_scheme == "int8":
             linear_cls = VllmQuantLinearInt8
         elif quant_scheme == "fp8":
-            linear_cls = VllmQuantLinearFp8
+            linear_cls = SglQuantLinearFp8
         elif quant_scheme == "int8-torchao":
             linear_cls = TorchaoQuantLinearInt8
         elif quant_scheme == "int8-q8f":
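Across the T5 modules (T5Attention, T5FeedForward) and the CLIP-side modules (SelfAttention, AttentionBlock), the commit reroutes only the "fp8" branch of the quant-scheme dispatch from the vLLM linear class to the sglang one; the other branches are untouched. Condensed into table form, this is equivalent to the if/elif chains above. Note the q8f entries are assumptions based on the imported class names, since those branch bodies are elided from the diff:

```python
from lightx2v.models.input_encoders.hf.q_linear import (
    Q8FQuantLinearFp8,
    Q8FQuantLinearInt8,
    SglQuantLinearFp8,
    TorchaoQuantLinearInt8,
    VllmQuantLinearInt8,
)

# Table form of the dispatch after this commit. Only the int8, fp8, and
# int8-torchao bodies are visible in the diff; the q8f mappings are guesses.
QUANT_LINEAR_CLS = {
    "int8": VllmQuantLinearInt8,
    "fp8": SglQuantLinearFp8,  # previously VllmQuantLinearFp8
    "int8-torchao": TorchaoQuantLinearInt8,
    "int8-q8f": Q8FQuantLinearInt8,  # assumed: branch body elided
    "fp8-q8f": Q8FQuantLinearFp8,    # assumed: key and body not shown
}

def pick_linear_cls(quant_scheme: str):
    return QUANT_LINEAR_CLS[quant_scheme]
```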
@@ -217,7 +217,8 @@ class DefaultRunner(BaseRunner):
     def run_main(self, total_steps=None):
         self.init_run()
         for segment_idx in range(self.video_segment_num):
-            with ProfilingContext4Debug(f"segment end2end {segment_idx}"):
+            logger.info(f"🔄 segment_idx: {segment_idx + 1}/{self.video_segment_num}")
+            with ProfilingContext(f"segment end2end {segment_idx}"):
                 # 1. default do nothing
                 self.init_run_segment(segment_idx)
                 # 2. main inference loop
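Two behavioral changes here: run_main now logs segment progress unconditionally, and the timing wrapper switches from ProfilingContext4Debug to ProfilingContext, which by its name suggests the per-segment end-to-end timing is now always reported rather than debug-only. The actual ProfilingContext class is not shown in this diff; a minimal sketch of what such a timing context manager looks like, for orientation only:

```python
import time
from contextlib import contextmanager

from loguru import logger

@contextmanager
def profiling_context(name):
    # Minimal stand-in for lightx2v's ProfilingContext (real class not shown
    # in this diff): unconditionally logs wall-clock time of the wrapped block.
    start = time.perf_counter()
    try:
        yield
    finally:
        logger.info(f"[Profile] {name} took {time.perf_counter() - start:.3f}s")

# Usage mirroring the runner loop:
with profiling_context("segment end2end 0"):
    time.sleep(0.1)  # stand-in for init_run_segment + the inference loop
```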
@@ -456,7 +456,7 @@ class WanAudioRunner(WanRunner): # type:ignore
         self.config.seed = self.config.seed + segment_idx
         torch.manual_seed(self.config.seed)
-        logger.info(f"Processing segment {segment_idx + 1}/{self.video_segment_num}, seed: {self.config.seed}")
+        # logger.info(f"Processing segment {segment_idx + 1}/{self.video_segment_num}, seed: {self.config.seed}")
         if (self.config.get("lazy_load", False) or self.config.get("unload_modules", False)) and not hasattr(self, "audio_encoder"):
             self.audio_encoder = self.load_audio_encoder()
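The unchanged context above shows the per-segment seeding scheme this runner keeps: config.seed is reassigned in place with the segment index added, so the offsets accumulate across segments while staying reproducible across runs. A standalone illustration of that exact pattern:

```python
import torch

seed = 42  # illustrative base value; in the runner it comes from self.config.seed
for segment_idx in range(3):
    seed = seed + segment_idx  # mirrors: self.config.seed = self.config.seed + segment_idx
    torch.manual_seed(seed)    # segments get seeds 42, 43, 45 (cumulative offsets)
    noise = torch.randn(2)     # distinct per segment, identical across reruns
```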
@@ -893,7 +893,7 @@
     def _calculate_2d_grid(self, latent_height, latent_width, world_size):
         if (latent_height, latent_width, world_size) in self.grid_table:
             best_h, best_w = self.grid_table[(latent_height, latent_width, world_size)]
-            logger.info(f"Vae using cached 2D grid: {best_h}x{best_w} grid for {latent_height}x{latent_width} latent")
+            # logger.info(f"Vae using cached 2D grid: {best_h}x{best_w} grid for {latent_height}x{latent_width} latent")
             return best_h, best_w
         best_h, best_w = 1, world_size
@@ -908,7 +908,7 @@
             if aspect_diff < min_aspect_diff:
                 min_aspect_diff = aspect_diff
                 best_h, best_w = h, w
-        logger.info(f"Vae using 2D grid & Update cache: {best_h}x{best_w} grid for {latent_height}x{latent_width} latent")
+        # logger.info(f"Vae using 2D grid & Update cache: {best_h}x{best_w} grid for {latent_height}x{latent_width} latent")
         self.grid_table[(latent_height, latent_width, world_size)] = (best_h, best_w)
         return best_h, best_w
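For context, these last two hunks only silence logging inside _calculate_2d_grid, which selects a 2D grid (best_h x best_w with best_h * best_w == world_size) whose cells best match the latent's shape, and memoizes the result in grid_table. A standalone reconstruction under stated assumptions: the factor enumeration and the aspect_diff formula fall between the visible lines and are guesses, as is the function name:

```python
def calculate_2d_grid(latent_height, latent_width, world_size, grid_table):
    """Sketch of WanVAE._calculate_2d_grid; the middle section is assumed."""
    key = (latent_height, latent_width, world_size)
    if key in grid_table:
        return grid_table[key]
    best_h, best_w = 1, world_size
    min_aspect_diff = float("inf")
    for h in range(1, world_size + 1):
        if world_size % h != 0:
            continue  # only factor pairs with h * w == world_size
        w = world_size // h
        # How unbalanced each grid cell would be (assumed metric, not shown in diff)
        aspect_diff = abs((latent_height / h) - (latent_width / w))
        if aspect_diff < min_aspect_diff:
            min_aspect_diff = aspect_diff
            best_h, best_w = h, w
    grid_table[key] = (best_h, best_w)
    return best_h, best_w

# Example: splitting a 60x104 latent across 4 ranks
print(calculate_2d_grid(60, 104, 4, {}))  # -> (2, 2), i.e. 30x52 cells
```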