Unverified Commit 6062ef24 authored by gushiqiao's avatar gushiqiao Committed by GitHub
Browse files

[Feat] Enable T5 inference and offload overlap for improved efficiency (#423)


Co-authored-by: default avatargushiqiao <975033167@qq.com>
parent d0a5c78d
...@@ -10,7 +10,6 @@ ...@@ -10,7 +10,6 @@
"sample_shift": 5, "sample_shift": 5,
"enable_cfg": true, "enable_cfg": true,
"t5_cpu_offload": true, "t5_cpu_offload": true,
"t5_offload_granularity": "block",
"t5_quantized": true, "t5_quantized": true,
"t5_quant_scheme": "fp8-sgl", "t5_quant_scheme": "fp8-sgl",
"unload_modules": false, "unload_modules": false,
......
...@@ -11,7 +11,6 @@ ...@@ -11,7 +11,6 @@
"enable_cfg": true, "enable_cfg": true,
"cpu_offload": true, "cpu_offload": true,
"offload_granularity": "phase", "offload_granularity": "phase",
"t5_offload_granularity": "block",
"dit_quantized_ckpt": "/path/to/dit_quant_model", "dit_quantized_ckpt": "/path/to/dit_quant_model",
"dit_quantized": true, "dit_quantized": true,
"dit_quant_scheme": "fp8-vllm", "dit_quant_scheme": "fp8-vllm",
......
...@@ -11,7 +11,6 @@ ...@@ -11,7 +11,6 @@
"enable_cfg": false, "enable_cfg": false,
"cpu_offload": true, "cpu_offload": true,
"offload_granularity": "phase", "offload_granularity": "phase",
"t5_offload_granularity": "block",
"dit_quantized_ckpt": "/path/to/dit_quant_model", "dit_quantized_ckpt": "/path/to/dit_quant_model",
"dit_quantized": true, "dit_quantized": true,
"dit_quant_scheme": "fp8-vllm", "dit_quant_scheme": "fp8-vllm",
......
{
"infer_steps": 4,
"target_fps": 16,
"video_duration": 5,
"audio_sr": 16000,
"target_video_length": 81,
"resize_mode": "adaptive",
"self_attn_1_type": "sage_attn3",
"cross_attn_1_type": "sage_attn3",
"cross_attn_2_type": "sage_attn3",
"sample_guide_scale": 1,
"sample_shift": 5,
"enable_cfg": false,
"use_31_block": false,
"cpu_offload": true,
"offload_granularity": "block",
"offload_ratio": 1,
"t5_cpu_offload": true,
"clip_cpu_offload": false,
"audio_encoder_cpu_offload": false,
"audio_adapter_cpu_offload": false,
"vae_cpu_offload": false
}
...@@ -20,7 +20,6 @@ ...@@ -20,7 +20,6 @@
"adapter_quant_scheme": "fp8", "adapter_quant_scheme": "fp8",
"cpu_offload": false, "cpu_offload": false,
"t5_cpu_offload": true, "t5_cpu_offload": true,
"t5_offload_granularity": "block",
"clip_cpu_offload": true, "clip_cpu_offload": true,
"vae_cpu_offload": true, "vae_cpu_offload": true,
"audio_encoder_cpu_offload": true, "audio_encoder_cpu_offload": true,
......
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
"use_31_block": false, "use_31_block": false,
"cpu_offload": false, "cpu_offload": false,
"t5_cpu_offload": true, "t5_cpu_offload": true,
"t5_offload_granularity": "block",
"clip_cpu_offload": true, "clip_cpu_offload": true,
"vae_cpu_offload": true, "vae_cpu_offload": true,
"audio_encoder_cpu_offload": true, "audio_encoder_cpu_offload": true,
......
...@@ -20,7 +20,6 @@ ...@@ -20,7 +20,6 @@
"adapter_quant_scheme": "fp8", "adapter_quant_scheme": "fp8",
"cpu_offload": false, "cpu_offload": false,
"t5_cpu_offload": true, "t5_cpu_offload": true,
"t5_offload_granularity": "block",
"clip_cpu_offload": true, "clip_cpu_offload": true,
"vae_cpu_offload": true, "vae_cpu_offload": true,
"audio_encoder_cpu_offload": true, "audio_encoder_cpu_offload": true,
......
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
"use_31_block": false, "use_31_block": false,
"cpu_offload": false, "cpu_offload": false,
"t5_cpu_offload": true, "t5_cpu_offload": true,
"t5_offload_granularity": "block",
"clip_cpu_offload": true, "clip_cpu_offload": true,
"vae_cpu_offload": true, "vae_cpu_offload": true,
"audio_encoder_cpu_offload": true, "audio_encoder_cpu_offload": true,
......
...@@ -20,7 +20,6 @@ ...@@ -20,7 +20,6 @@
"adapter_quant_scheme": "fp8", "adapter_quant_scheme": "fp8",
"cpu_offload": false, "cpu_offload": false,
"t5_cpu_offload": true, "t5_cpu_offload": true,
"t5_offload_granularity": "block",
"clip_cpu_offload": true, "clip_cpu_offload": true,
"vae_cpu_offload": true, "vae_cpu_offload": true,
"audio_encoder_cpu_offload": true, "audio_encoder_cpu_offload": true,
......
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
"use_31_block": false, "use_31_block": false,
"cpu_offload": false, "cpu_offload": false,
"t5_cpu_offload": true, "t5_cpu_offload": true,
"t5_offload_granularity": "block",
"clip_cpu_offload": true, "clip_cpu_offload": true,
"vae_cpu_offload": true, "vae_cpu_offload": true,
"audio_encoder_cpu_offload": true, "audio_encoder_cpu_offload": true,
......
...@@ -20,7 +20,6 @@ ...@@ -20,7 +20,6 @@
"adapter_quant_scheme": "fp8", "adapter_quant_scheme": "fp8",
"cpu_offload": false, "cpu_offload": false,
"t5_cpu_offload": true, "t5_cpu_offload": true,
"t5_offload_granularity": "block",
"clip_cpu_offload": true, "clip_cpu_offload": true,
"vae_cpu_offload": true, "vae_cpu_offload": true,
"audio_encoder_cpu_offload": true, "audio_encoder_cpu_offload": true,
......
...@@ -16,7 +16,6 @@ ...@@ -16,7 +16,6 @@
"offload_granularity": "block", "offload_granularity": "block",
"offload_ratio": 1, "offload_ratio": 1,
"t5_cpu_offload": true, "t5_cpu_offload": true,
"t5_offload_granularity": "model",
"t5_quantized": true, "t5_quantized": true,
"t5_quant_scheme": "fp8-sgl", "t5_quant_scheme": "fp8-sgl",
"clip_cpu_offload": false, "clip_cpu_offload": false,
......
...@@ -22,7 +22,6 @@ ...@@ -22,7 +22,6 @@
"clip_cpu_offload": false, "clip_cpu_offload": false,
"vae_cpu_offload": false, "vae_cpu_offload": false,
"offload_ratio": 1, "offload_ratio": 1,
"t5_offload_granularity": "block",
"use_tiling_vae": true, "use_tiling_vae": true,
"audio_encoder_cpu_offload": true, "audio_encoder_cpu_offload": true,
"audio_adapter_cpu_offload": false "audio_adapter_cpu_offload": false
......
from .attn import * from .attn import *
from .conv import * from .conv import *
from .embedding import *
from .mm import * from .mm import *
from .norm import * from .norm import *
from .tensor import * from .tensor import *
from .embedding_weight import *
from abc import ABCMeta
import torch
import torch.nn.functional as F
from lightx2v.utils.registry_factory import EMBEDDING_WEIGHT_REGISTER
class EmbeddingWeightTemplate(metaclass=ABCMeta):
    """Base holder for an embedding weight tensor with CPU<->GPU offload support.

    Keeps a pre-allocated pinned-memory staging buffer (when a weight is
    loaded eagerly) so that offloading the weight back to CPU can use fast,
    optionally non-blocking, host copies.
    """

    def __init__(self, weight_name, lazy_load=False, lazy_load_file=None):
        # Key of this weight inside the checkpoint weight dict; may be None
        # when no tensor is associated with this entry.
        self.weight_name = weight_name
        self.lazy_load = lazy_load
        self.lazy_load_file = lazy_load_file
        self.config = {}

    def load(self, weight_dict):
        """Take ownership of this weight from ``weight_dict``.

        The entry is removed from ``weight_dict`` after loading so the source
        dict can be freed incrementally. No-op when ``lazy_load`` is set.
        """
        if not self.lazy_load:
            if self.weight_name is not None:
                self.weight = weight_dict[self.weight_name]
                # Pre-allocate the pinned host buffer once; reused by to_cpu().
                self.pinned_weight = torch.empty(self.weight.shape, pin_memory=True, dtype=self.weight.dtype)
                # Bug fix: the delete must happen only in this branch. The
                # original ran `del weight_dict[self.weight_name]`
                # unconditionally, raising KeyError when weight_name is None.
                del weight_dict[self.weight_name]
            else:
                self.weight = None

    def to_cpu(self, non_blocking=False):
        """Move the weight to CPU, staging through pinned memory if available."""
        if hasattr(self, "pinned_weight"):
            # copy_ returns the pinned (already host-resident) buffer; .cpu()
            # is then a cheap no-op that just normalizes the device.
            self.weight = self.pinned_weight.copy_(self.weight, non_blocking=non_blocking).cpu()
        else:
            self.weight = self.weight.to("cpu", non_blocking=non_blocking)

    def to_cuda(self, non_blocking=False):
        """Move the weight to the current CUDA device."""
        self.weight = self.weight.cuda(non_blocking=non_blocking)
@EMBEDDING_WEIGHT_REGISTER("Default")
class EmbeddingWeight(EmbeddingWeightTemplate):
    """Default embedding weight: a plain ``F.embedding`` table lookup."""

    def __init__(self, weight_name=None, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, lazy_load, lazy_load_file)

    def apply(self, input_indices):
        """Look up rows of the weight table for the given index tensor."""
        # Default embedding semantics: no padding index, no norm clipping,
        # dense gradients.
        return F.embedding(
            input_indices,
            self.weight,
            padding_idx=None,
            max_norm=None,
            norm_type=2.0,
            scale_grad_by_freq=False,
            sparse=False,
        )
...@@ -112,6 +112,8 @@ class MMWeight(MMWeightTemplate): ...@@ -112,6 +112,8 @@ class MMWeight(MMWeightTemplate):
self.weight = weight_dict[self.weight_name].t() self.weight = weight_dict[self.weight_name].t()
if self.bias_name is not None: if self.bias_name is not None:
self.bias = weight_dict[self.bias_name] self.bias = weight_dict[self.bias_name]
else:
self.bias = None
elif device.type == "cpu": elif device.type == "cpu":
weight_shape = weight_dict[self.weight_name].t().shape weight_shape = weight_dict[self.weight_name].t().shape
weight_dtype = weight_dict[self.weight_name].dtype weight_dtype = weight_dict[self.weight_name].dtype
...@@ -124,6 +126,7 @@ class MMWeight(MMWeightTemplate): ...@@ -124,6 +126,7 @@ class MMWeight(MMWeightTemplate):
self.pin_bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype) self.pin_bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
self.pin_bias.copy_(weight_dict[self.bias_name]) self.pin_bias.copy_(weight_dict[self.bias_name])
else: else:
self.bias = None
self.pin_bias = None self.pin_bias = None
del weight_dict[self.weight_name] del weight_dict[self.weight_name]
else: else:
......
...@@ -128,7 +128,6 @@ class WanRunner(DefaultRunner): ...@@ -128,7 +128,6 @@ class WanRunner(DefaultRunner):
tokenizer_path=tokenizer_path, tokenizer_path=tokenizer_path,
shard_fn=None, shard_fn=None,
cpu_offload=t5_offload, cpu_offload=t5_offload,
offload_granularity=self.config.get("t5_offload_granularity", "model"), # support ['model', 'block']
t5_quantized=t5_quantized, t5_quantized=t5_quantized,
t5_quantized_ckpt=t5_quantized_ckpt, t5_quantized_ckpt=t5_quantized_ckpt,
quant_scheme=t5_quant_scheme, quant_scheme=t5_quant_scheme,
......
...@@ -52,5 +52,5 @@ CONV3D_WEIGHT_REGISTER = Register() ...@@ -52,5 +52,5 @@ CONV3D_WEIGHT_REGISTER = Register()
CONV2D_WEIGHT_REGISTER = Register() CONV2D_WEIGHT_REGISTER = Register()
TENSOR_REGISTER = Register() TENSOR_REGISTER = Register()
CONVERT_WEIGHT_REGISTER = Register() CONVERT_WEIGHT_REGISTER = Register()
EMBEDDING_WEIGHT_REGISTER = Register()
RUNNER_REGISTER = Register() RUNNER_REGISTER = Register()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment