Unverified Commit 6062ef24 authored by gushiqiao's avatar gushiqiao Committed by GitHub
Browse files

[Feat] Enable T5 inference and offload overlap for improved efficiency (#423)


Co-authored-by: default avatargushiqiao <975033167@qq.com>
parent d0a5c78d
......@@ -10,7 +10,6 @@
"sample_shift": 5,
"enable_cfg": true,
"t5_cpu_offload": true,
"t5_offload_granularity": "block",
"t5_quantized": true,
"t5_quant_scheme": "fp8-sgl",
"unload_modules": false,
......
......@@ -11,7 +11,6 @@
"enable_cfg": true,
"cpu_offload": true,
"offload_granularity": "phase",
"t5_offload_granularity": "block",
"dit_quantized_ckpt": "/path/to/dit_quant_model",
"dit_quantized": true,
"dit_quant_scheme": "fp8-vllm",
......
......@@ -11,7 +11,6 @@
"enable_cfg": false,
"cpu_offload": true,
"offload_granularity": "phase",
"t5_offload_granularity": "block",
"dit_quantized_ckpt": "/path/to/dit_quant_model",
"dit_quantized": true,
"dit_quant_scheme": "fp8-vllm",
......
{
"infer_steps": 4,
"target_fps": 16,
"video_duration": 5,
"audio_sr": 16000,
"target_video_length": 81,
"resize_mode": "adaptive",
"self_attn_1_type": "sage_attn3",
"cross_attn_1_type": "sage_attn3",
"cross_attn_2_type": "sage_attn3",
"sample_guide_scale": 1,
"sample_shift": 5,
"enable_cfg": false,
"use_31_block": false,
"cpu_offload": true,
"offload_granularity": "block",
"offload_ratio": 1,
"t5_cpu_offload": true,
"clip_cpu_offload": false,
"audio_encoder_cpu_offload": false,
"audio_adapter_cpu_offload": false,
"vae_cpu_offload": false
}
......@@ -20,7 +20,6 @@
"adapter_quant_scheme": "fp8",
"cpu_offload": false,
"t5_cpu_offload": true,
"t5_offload_granularity": "block",
"clip_cpu_offload": true,
"vae_cpu_offload": true,
"audio_encoder_cpu_offload": true,
......
......@@ -14,7 +14,6 @@
"use_31_block": false,
"cpu_offload": false,
"t5_cpu_offload": true,
"t5_offload_granularity": "block",
"clip_cpu_offload": true,
"vae_cpu_offload": true,
"audio_encoder_cpu_offload": true,
......
......@@ -20,7 +20,6 @@
"adapter_quant_scheme": "fp8",
"cpu_offload": false,
"t5_cpu_offload": true,
"t5_offload_granularity": "block",
"clip_cpu_offload": true,
"vae_cpu_offload": true,
"audio_encoder_cpu_offload": true,
......
......@@ -14,7 +14,6 @@
"use_31_block": false,
"cpu_offload": false,
"t5_cpu_offload": true,
"t5_offload_granularity": "block",
"clip_cpu_offload": true,
"vae_cpu_offload": true,
"audio_encoder_cpu_offload": true,
......
......@@ -20,7 +20,6 @@
"adapter_quant_scheme": "fp8",
"cpu_offload": false,
"t5_cpu_offload": true,
"t5_offload_granularity": "block",
"clip_cpu_offload": true,
"vae_cpu_offload": true,
"audio_encoder_cpu_offload": true,
......
......@@ -14,7 +14,6 @@
"use_31_block": false,
"cpu_offload": false,
"t5_cpu_offload": true,
"t5_offload_granularity": "block",
"clip_cpu_offload": true,
"vae_cpu_offload": true,
"audio_encoder_cpu_offload": true,
......
......@@ -20,7 +20,6 @@
"adapter_quant_scheme": "fp8",
"cpu_offload": false,
"t5_cpu_offload": true,
"t5_offload_granularity": "block",
"clip_cpu_offload": true,
"vae_cpu_offload": true,
"audio_encoder_cpu_offload": true,
......
......@@ -16,7 +16,6 @@
"offload_granularity": "block",
"offload_ratio": 1,
"t5_cpu_offload": true,
"t5_offload_granularity": "model",
"t5_quantized": true,
"t5_quant_scheme": "fp8-sgl",
"clip_cpu_offload": false,
......
......@@ -22,7 +22,6 @@
"clip_cpu_offload": false,
"vae_cpu_offload": false,
"offload_ratio": 1,
"t5_offload_granularity": "block",
"use_tiling_vae": true,
"audio_encoder_cpu_offload": true,
"audio_adapter_cpu_offload": false
......
from .attn import *
from .conv import *
from .embedding import *
from .mm import *
from .norm import *
from .tensor import *
from .embedding_weight import *
from abc import ABCMeta
import torch
import torch.nn.functional as F
from lightx2v.utils.registry_factory import EMBEDDING_WEIGHT_REGISTER
class EmbeddingWeightTemplate(metaclass=ABCMeta):
    """Base wrapper around a single embedding weight tensor.

    Manages taking ownership of the tensor from a state-dict-like mapping
    and moving it between host and device, optionally staging transfers
    through a pinned-memory buffer so ``non_blocking`` copies can overlap
    with compute.
    """

    def __init__(self, weight_name, lazy_load=False, lazy_load_file=None):
        # Key of this weight inside the mapping passed to load(); may be None.
        self.weight_name = weight_name
        # When True, load() is a no-op; the weight is expected to be
        # fetched later from lazy_load_file by the caller.
        self.lazy_load = lazy_load
        self.lazy_load_file = lazy_load_file
        self.config = {}

    def load(self, weight_dict):
        """Take ownership of the named weight from ``weight_dict``.

        The entry is removed from ``weight_dict`` so the caller can release
        its reference. BUG FIX: the original deleted
        ``weight_dict[self.weight_name]`` unconditionally, which raised
        ``KeyError: None`` whenever ``weight_name`` is None; the delete is
        now scoped to the branch where the key actually exists.
        """
        if not self.lazy_load:
            if self.weight_name is not None:
                self.weight = weight_dict[self.weight_name]
                # Pinned host buffer enables asynchronous device<->host copies.
                # NOTE(review): pin_memory=True requires a CUDA-capable build/runtime.
                self.pinned_weight = torch.empty(self.weight.shape, pin_memory=True, dtype=self.weight.dtype)
                del weight_dict[self.weight_name]
            else:
                self.weight = None

    def to_cpu(self, non_blocking=False):
        """Move the weight to host memory, via the pinned buffer when one was allocated."""
        if hasattr(self, "pinned_weight"):
            # copy_ into pinned memory may overlap with compute when non_blocking=True.
            self.weight = self.pinned_weight.copy_(self.weight, non_blocking=non_blocking).cpu()
        else:
            self.weight = self.weight.to("cpu", non_blocking=non_blocking)

    def to_cuda(self, non_blocking=False):
        """Move the weight to the current CUDA device."""
        self.weight = self.weight.cuda(non_blocking=non_blocking)
@EMBEDDING_WEIGHT_REGISTER("Default")
class EmbeddingWeight(EmbeddingWeightTemplate):
    """Plain (non-quantized) embedding weight; lookup is a direct F.embedding call."""

    def __init__(self, weight_name=None, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, lazy_load, lazy_load_file)

    def apply(self, input_indices):
        """Gather rows of the weight table selected by ``input_indices``.

        All optional ``F.embedding`` arguments are left at their defaults
        (padding_idx=None, max_norm=None, norm_type=2.0,
        scale_grad_by_freq=False, sparse=False), which matches the values
        the original call spelled out explicitly.
        """
        return F.embedding(input_indices, self.weight)
......@@ -112,6 +112,8 @@ class MMWeight(MMWeightTemplate):
self.weight = weight_dict[self.weight_name].t()
if self.bias_name is not None:
self.bias = weight_dict[self.bias_name]
else:
self.bias = None
elif device.type == "cpu":
weight_shape = weight_dict[self.weight_name].t().shape
weight_dtype = weight_dict[self.weight_name].dtype
......@@ -124,6 +126,7 @@ class MMWeight(MMWeightTemplate):
self.pin_bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
self.pin_bias.copy_(weight_dict[self.bias_name])
else:
self.bias = None
self.pin_bias = None
del weight_dict[self.weight_name]
else:
......
......@@ -128,7 +128,6 @@ class WanRunner(DefaultRunner):
tokenizer_path=tokenizer_path,
shard_fn=None,
cpu_offload=t5_offload,
offload_granularity=self.config.get("t5_offload_granularity", "model"), # support ['model', 'block']
t5_quantized=t5_quantized,
t5_quantized_ckpt=t5_quantized_ckpt,
quant_scheme=t5_quant_scheme,
......
......@@ -52,5 +52,5 @@ CONV3D_WEIGHT_REGISTER = Register()
CONV2D_WEIGHT_REGISTER = Register()
TENSOR_REGISTER = Register()
CONVERT_WEIGHT_REGISTER = Register()
EMBEDDING_WEIGHT_REGISTER = Register()
RUNNER_REGISTER = Register()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment