Unverified Commit 6062ef24 authored by gushiqiao, committed by GitHub

[Feat] Enable T5 inference and offload overlap for improved efficiency (#423)


Co-authored-by: gushiqiao <975033167@qq.ocm>
parent d0a5c78d
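
For orientation, here is a minimal, self-contained sketch of the compute/transfer overlap idea this commit applies to the T5 encoder: while block i runs on a compute stream, block i+1's weights are prefetched host-to-device on a separate transfer stream, and finished blocks are offloaded back to the CPU. The names below (TinyBlock, run_blocks_overlapped) are illustrative only and not part of the lightx2v API; the real implementation goes through WeightAsyncStreamManager and pinned host buffers.

import torch


class TinyBlock(torch.nn.Module):
    # Stand-in for one T5 encoder block; illustrative only.
    def __init__(self, dim):
        super().__init__()
        self.linear = torch.nn.Linear(dim, dim)

    def forward(self, x):
        return torch.relu(self.linear(x))


def run_blocks_overlapped(blocks, x):
    # Two streams: one for compute, one for weight transfers, so H2D copies of the
    # next block hide behind the current block's GPU work. The real manager also
    # keeps pinned host buffers so the copies can be truly asynchronous.
    compute = torch.cuda.Stream()
    transfer = torch.cuda.Stream()
    compute.wait_stream(torch.cuda.current_stream())  # x was produced on the default stream
    with torch.cuda.stream(transfer):
        blocks[0].to("cuda", non_blocking=True)  # preload the first block
    for i, block in enumerate(blocks):
        compute.wait_stream(transfer)  # weights must be resident before use
        with torch.cuda.stream(compute):
            x = block(x)
        if i + 1 < len(blocks):
            with torch.cuda.stream(transfer):  # prefetch the next block during compute
                blocks[i + 1].to("cuda", non_blocking=True)
        transfer.wait_stream(compute)  # do not evict weights that are still in use
        with torch.cuda.stream(transfer):
            block.to("cpu")  # offload the finished block
    torch.cuda.current_stream().wait_stream(compute)
    return x


if __name__ == "__main__" and torch.cuda.is_available():
    out = run_blocks_overlapped([TinyBlock(64) for _ in range(4)], torch.randn(2, 64, device="cuda"))
    print(out.shape)
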
......@@ -10,7 +10,6 @@
"sample_shift": 5,
"enable_cfg": true,
"t5_cpu_offload": true,
"t5_offload_granularity": "block",
"t5_quantized": true,
"t5_quant_scheme": "fp8-sgl",
"unload_modules": false,
......
......@@ -11,7 +11,6 @@
"enable_cfg": true,
"cpu_offload": true,
"offload_granularity": "phase",
"t5_offload_granularity": "block",
"dit_quantized_ckpt": "/path/to/dit_quant_model",
"dit_quantized": true,
"dit_quant_scheme": "fp8-vllm",
......
......@@ -11,7 +11,6 @@
"enable_cfg": false,
"cpu_offload": true,
"offload_granularity": "phase",
"t5_offload_granularity": "block",
"dit_quantized_ckpt": "/path/to/dit_quant_model",
"dit_quantized": true,
"dit_quant_scheme": "fp8-vllm",
......
{
"infer_steps": 4,
"target_fps": 16,
"video_duration": 5,
"audio_sr": 16000,
"target_video_length": 81,
"resize_mode": "adaptive",
"self_attn_1_type": "sage_attn3",
"cross_attn_1_type": "sage_attn3",
"cross_attn_2_type": "sage_attn3",
"sample_guide_scale": 1,
"sample_shift": 5,
"enable_cfg": false,
"use_31_block": false,
"cpu_offload": true,
"offload_granularity": "block",
"offload_ratio": 1,
"t5_cpu_offload": true,
"clip_cpu_offload": false,
"audio_encoder_cpu_offload": false,
"audio_adapter_cpu_offload": false,
"vae_cpu_offload": false
}
......@@ -20,7 +20,6 @@
"adapter_quant_scheme": "fp8",
"cpu_offload": false,
"t5_cpu_offload": true,
"t5_offload_granularity": "block",
"clip_cpu_offload": true,
"vae_cpu_offload": true,
"audio_encoder_cpu_offload": true,
......
......@@ -14,7 +14,6 @@
"use_31_block": false,
"cpu_offload": false,
"t5_cpu_offload": true,
"t5_offload_granularity": "block",
"clip_cpu_offload": true,
"vae_cpu_offload": true,
"audio_encoder_cpu_offload": true,
......
......@@ -20,7 +20,6 @@
"adapter_quant_scheme": "fp8",
"cpu_offload": false,
"t5_cpu_offload": true,
"t5_offload_granularity": "block",
"clip_cpu_offload": true,
"vae_cpu_offload": true,
"audio_encoder_cpu_offload": true,
......
......@@ -14,7 +14,6 @@
"use_31_block": false,
"cpu_offload": false,
"t5_cpu_offload": true,
"t5_offload_granularity": "block",
"clip_cpu_offload": true,
"vae_cpu_offload": true,
"audio_encoder_cpu_offload": true,
......
......@@ -20,7 +20,6 @@
"adapter_quant_scheme": "fp8",
"cpu_offload": false,
"t5_cpu_offload": true,
"t5_offload_granularity": "block",
"clip_cpu_offload": true,
"vae_cpu_offload": true,
"audio_encoder_cpu_offload": true,
......
......@@ -14,7 +14,6 @@
"use_31_block": false,
"cpu_offload": false,
"t5_cpu_offload": true,
"t5_offload_granularity": "block",
"clip_cpu_offload": true,
"vae_cpu_offload": true,
"audio_encoder_cpu_offload": true,
......
......@@ -20,7 +20,6 @@
"adapter_quant_scheme": "fp8",
"cpu_offload": false,
"t5_cpu_offload": true,
"t5_offload_granularity": "block",
"clip_cpu_offload": true,
"vae_cpu_offload": true,
"audio_encoder_cpu_offload": true,
......
......@@ -16,7 +16,6 @@
"offload_granularity": "block",
"offload_ratio": 1,
"t5_cpu_offload": true,
"t5_offload_granularity": "model",
"t5_quantized": true,
"t5_quant_scheme": "fp8-sgl",
"clip_cpu_offload": false,
......
......@@ -22,7 +22,6 @@
"clip_cpu_offload": false,
"vae_cpu_offload": false,
"offload_ratio": 1,
"t5_offload_granularity": "block",
"use_tiling_vae": true,
"audio_encoder_cpu_offload": true,
"audio_adapter_cpu_offload": false
......
from .attn import *
from .conv import *
from .embedding import *
from .mm import *
from .norm import *
from .tensor import *
from .embedding_weight import *
from abc import ABCMeta
import torch
import torch.nn.functional as F
from lightx2v.utils.registry_factory import EMBEDDING_WEIGHT_REGISTER
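# Weight wrapper for embedding tables used by the offload path: it can keep a pinned
# host copy of the table so CPU<->GPU moves can run with non_blocking=True, and it
# exposes apply() so lookups go through F.embedding against the managed weight.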
class EmbeddingWeightTemplate(metaclass=ABCMeta):
def __init__(self, weight_name, lazy_load=False, lazy_load_file=None):
self.weight_name = weight_name
self.lazy_load = lazy_load
self.lazy_load_file = lazy_load_file
self.config = {}
def load(self, weight_dict):
if not self.lazy_load:
if self.weight_name is not None:
self.weight = weight_dict[self.weight_name]
self.pinned_weight = torch.empty(self.weight.shape, pin_memory=True, dtype=self.weight.dtype)
else:
self.weight = None
del weight_dict[self.weight_name]
def to_cpu(self, non_blocking=False):
if hasattr(self, "pinned_weight"):
self.weight = self.pinned_weight.copy_(self.weight, non_blocking=non_blocking).cpu()
else:
self.weight = self.weight.to("cpu", non_blocking=non_blocking)
def to_cuda(self, non_blocking=False):
self.weight = self.weight.cuda(non_blocking=non_blocking)
@EMBEDDING_WEIGHT_REGISTER("Default")
class EmbeddingWeight(EmbeddingWeightTemplate):
def __init__(self, weight_name=None, lazy_load=False, lazy_load_file=None):
super().__init__(weight_name, lazy_load, lazy_load_file)
def apply(self, input_indices):
output = F.embedding(input=input_indices, weight=self.weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False)
return output
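
As a quick orientation for the new wrapper above, a hedged usage sketch (the weight name and shapes are made up for illustration and are not part of the committed file):

if __name__ == "__main__" and torch.cuda.is_available():
    # Illustrative only: exercise EmbeddingWeight the way the offload path does.
    weights = {"pos_embedding.weight": torch.randn(32, 64, dtype=torch.bfloat16)}
    emb = EmbeddingWeight("pos_embedding.weight")
    emb.load(weights)  # stores the tensor and allocates a pinned host buffer for async moves
    emb.to_cuda()
    print(emb.apply(torch.randint(0, 32, (4, 4), device="cuda")).shape)  # -> torch.Size([4, 4, 64])
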
......@@ -112,6 +112,8 @@ class MMWeight(MMWeightTemplate):
self.weight = weight_dict[self.weight_name].t()
if self.bias_name is not None:
self.bias = weight_dict[self.bias_name]
else:
self.bias = None
elif device.type == "cpu":
weight_shape = weight_dict[self.weight_name].t().shape
weight_dtype = weight_dict[self.weight_name].dtype
......@@ -124,6 +126,7 @@ class MMWeight(MMWeightTemplate):
self.pin_bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
self.pin_bias.copy_(weight_dict[self.bias_name])
else:
self.bias = None
self.pin_bias = None
del weight_dict[self.weight_name]
else:
......
# Modified from transformers.models.t5.modeling_t5
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
# 1. Standard library imports
import gc
import math
import os
import sys
from pathlib import Path
# 2. Third-party imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from loguru import logger
from lightx2v.models.input_encoders.hf.q_linear import Q8FQuantLinearFp8, Q8FQuantLinearInt8, SglQuantLinearFp8, TorchaoQuantLinearInt8, VllmQuantLinearInt8
from lightx2v.utils.envs import *
from lightx2v.utils.utils import load_weights
from .tokenizer import HuggingfaceTokenizer
current_dir = Path(__file__).resolve().parent
project_root = current_dir.parent.parent.parent.parent.parent.parent
if str(project_root) not in sys.path:
sys.path.insert(0, str(project_root))
from lightx2v.common.modules.weight_module import WeightModule, WeightModuleList # noqa E402
from lightx2v.common.offload.manager import WeightAsyncStreamManager # noqa E402
from lightx2v.common.ops import * # noqa E402
from lightx2v.models.input_encoders.hf.q_linear import ( # noqa E402
Q8FQuantLinearFp8, # noqa E402
Q8FQuantLinearInt8, # noqa E402
SglQuantLinearFp8, # noqa E402
TorchaoQuantLinearInt8, # noqa E402
VllmQuantLinearInt8, # noqa E402
)
from lightx2v.models.input_encoders.hf.wan.t5.tokenizer import HuggingfaceTokenizer # noqa E402
from lightx2v.utils.envs import * # noqa E402
from lightx2v.utils.registry_factory import ( # noqa E402
EMBEDDING_WEIGHT_REGISTER, # noqa E402
MM_WEIGHT_REGISTER, # noqa E402
RMS_WEIGHT_REGISTER, # noqa E402
)
from lightx2v.utils.utils import load_weights # noqa E402
__all__ = [
"T5Model",
......@@ -22,6 +44,103 @@ __all__ = [
]
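# Offload-mode weight containers: T5OffloadBlocksWeights holds one T5OffloadSelfAttention
# per encoder layer, and each layer groups its projections into two compute phases
# (T5OffloadAttention, T5OffloadFeedForward) whose weights are streamed to the GPU and
# prefetched one block ahead by WeightAsyncStreamManager.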
class T5OffloadBlocksWeights(WeightModule):
def __init__(self, block_nums, mm_type):
super().__init__()
self.block_nums = block_nums
self.blocks = WeightModuleList([T5OffloadSelfAttention(i, mm_type) for i in range(block_nums)])
self.add_module("blocks", self.blocks)
class T5OffloadSelfAttention(WeightModule):
def __init__(self, block_index, mm_type, block_prefix="blocks"):
super().__init__()
self.block_index = block_index
if mm_type is None:
mm_type = "Default"
self.mm_type = mm_type
self.add_module(
"norm1",
RMS_WEIGHT_REGISTER["sgl-kernel"](
f"{block_prefix}.{self.block_index}.norm1.weight",
),
)
self.add_module(
"norm2",
RMS_WEIGHT_REGISTER["sgl-kernel"](
f"{block_prefix}.{self.block_index}.norm2.weight",
),
)
self.add_module(
"pos_embedding",
EMBEDDING_WEIGHT_REGISTER["Default"](
f"{block_prefix}.{self.block_index}.pos_embedding.embedding.weight",
),
)
self.compute_phases = WeightModuleList(
[
T5OffloadAttention(
block_index,
block_prefix,
mm_type,
),
T5OffloadFeedForward(
block_index,
block_prefix,
mm_type,
),
]
)
self.add_module("compute_phases", self.compute_phases)
class T5OffloadAttention(WeightModule):
def __init__(self, block_index, block_prefix, mm_type):
super().__init__()
self.block_index = block_index
self.mm_type = mm_type
self.add_module(
"attn_q",
MM_WEIGHT_REGISTER[self.mm_type](f"{block_prefix}.{self.block_index}.attn.q.weight", None),
)
self.add_module(
"attn_k",
MM_WEIGHT_REGISTER[self.mm_type](f"{block_prefix}.{self.block_index}.attn.k.weight", None),
)
self.add_module(
"attn_v",
MM_WEIGHT_REGISTER[self.mm_type](f"{block_prefix}.{self.block_index}.attn.v.weight", None),
)
self.add_module(
"attn_o",
MM_WEIGHT_REGISTER[self.mm_type](f"{block_prefix}.{self.block_index}.attn.o.weight", None),
)
class T5OffloadFeedForward(WeightModule):
def __init__(self, block_index, block_prefix, mm_type):
super().__init__()
self.block_index = block_index
self.mm_type = mm_type
self.add_module(
"ffn_fc1",
MM_WEIGHT_REGISTER[self.mm_type](f"{block_prefix}.{self.block_index}.ffn.fc1.weight", None),
)
self.add_module(
"ffn_fc2",
MM_WEIGHT_REGISTER[self.mm_type](f"{block_prefix}.{self.block_index}.ffn.fc2.weight", None),
)
self.add_module(
"ffn_gate_0",
MM_WEIGHT_REGISTER[self.mm_type](f"{block_prefix}.{self.block_index}.ffn.gate.0.weight", None),
)
self.gelu = GELU()
def fp16_clamp(x):
if x.dtype == torch.float16 and torch.isinf(x).any():
clamp = torch.finfo(x.dtype).max - 1000
......@@ -29,14 +148,6 @@ def fp16_clamp(x):
return x
def optimize_memory_usage():
if torch.cuda.is_available():
torch.cuda.empty_cache()
import gc
gc.collect()
def init_weights(m):
if isinstance(m, T5LayerNorm):
nn.init.ones_(m.weight)
......@@ -75,7 +186,16 @@ class T5LayerNorm(nn.Module):
class T5Attention(nn.Module):
def __init__(self, dim, dim_attn, num_heads, dropout=0.1, quantized=False, quant_scheme=None, dtype=torch.bfloat16):
def __init__(
self,
dim,
dim_attn,
num_heads,
dropout=0.1,
quantized=False,
quant_scheme=None,
dtype=torch.bfloat16,
):
assert dim_attn % num_heads == 0
super(T5Attention, self).__init__()
self.dim = dim
......@@ -132,22 +252,24 @@ class T5Attention(nn.Module):
# compute attention (T5 does not use scaling)
attn = torch.einsum("binc,bjnc->bnij", q, k) + attn_bias
if hasattr(self, "cpu_offload") and self.cpu_offload:
del attn_bias
attn = F.softmax(attn.float(), dim=-1).type_as(attn)
x = torch.einsum("bnij,bjnc->binc", attn, v)
if hasattr(self, "cpu_offload") and self.cpu_offload:
del attn
x = x.reshape(b, -1, n * c)
x = self.o(x)
x = self.dropout(x)
return x
class T5FeedForward(nn.Module):
def __init__(self, dim, dim_ffn, dropout=0.1, quantized=False, quant_scheme=None, dtype=torch.bfloat16):
def __init__(
self,
dim,
dim_ffn,
dropout=0.1,
quantized=False,
quant_scheme=None,
dtype=torch.bfloat16,
):
super(T5FeedForward, self).__init__()
self.dim = dim
self.dim_ffn = dim_ffn
......@@ -169,19 +291,13 @@ class T5FeedForward(nn.Module):
linear_cls = nn.Linear
# layers
self.gate = nn.Sequential(linear_cls(dim, dim_ffn, bias=False, dtype=dtype), GELU())
self.fc1 = linear_cls(dim, dim_ffn, bias=False, dtype=dtype)
self.fc2 = linear_cls(dim_ffn, dim, bias=False, dtype=dtype)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
if hasattr(self, "cpu_offload") and self.cpu_offload:
gate_out = self.gate(x)
fc1_out = self.fc1(x)
x = fc1_out * gate_out
del gate_out, fc1_out
else:
x = self.fc1(x) * self.gate(x)
x = self.fc1(x) * self.gate(x)
x = self.dropout(x)
x = self.fc2(x)
x = self.dropout(x)
......@@ -189,7 +305,19 @@ class T5FeedForward(nn.Module):
class T5SelfAttention(nn.Module):
def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.1, quantized=False, quant_scheme=None, dtype=torch.bfloat16):
def __init__(
self,
dim,
dim_attn,
dim_ffn,
num_heads,
num_buckets,
shared_pos=True,
dropout=0.1,
quantized=False,
quant_scheme=None,
dtype=torch.bfloat16,
):
super(T5SelfAttention, self).__init__()
self.dim = dim
self.dim_attn = dim_attn
......@@ -207,18 +335,8 @@ class T5SelfAttention(nn.Module):
def forward(self, x, mask=None, pos_bias=None):
e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1))
if hasattr(self, "cpu_offload") and self.cpu_offload:
attn_out = self.attn(self.norm1(x), mask=mask, pos_bias=e)
x = fp16_clamp(x + attn_out)
del attn_out
ffn_out = self.ffn(self.norm2(x))
x = fp16_clamp(x + ffn_out)
del ffn_out
else:
x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
x = fp16_clamp(x + self.ffn(self.norm2(x)))
x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
x = fp16_clamp(x + self.ffn(self.norm2(x)))
return x
......@@ -276,7 +394,9 @@ class T5RelativeEmbedding(nn.Module):
# torch.arange(lq).unsqueeze(1).to(device)
rel_pos = torch.arange(lk, device=device).unsqueeze(0) - torch.arange(lq, device=device).unsqueeze(1)
rel_pos = self._relative_position_bucket(rel_pos)
rel_pos_embeds = self.embedding(rel_pos)
rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(0) # [1, N, Lq, Lk]
return rel_pos_embeds.contiguous()
......@@ -300,9 +420,23 @@ class T5RelativeEmbedding(nn.Module):
class T5Encoder(nn.Module):
def __init__(self, dtype, vocab, dim, dim_attn, dim_ffn, num_heads, num_layers, num_buckets, shared_pos=True, dropout=0.1, cpu_offload=False, quantized=False, quant_scheme=None):
def __init__(
self,
dtype,
vocab,
dim,
dim_attn,
dim_ffn,
num_heads,
num_layers,
num_buckets,
shared_pos=True,
dropout=0.1,
cpu_offload=False,
quantized=False,
quant_scheme=None,
):
super(T5Encoder, self).__init__()
self.cpu_offload = cpu_offload
self.dim = dim
self.dim_attn = dim_attn
......@@ -317,53 +451,126 @@ class T5Encoder(nn.Module):
self.token_embedding = vocab.to(dtype) if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim, dtype=dtype)
self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True, dtype=dtype) if shared_pos else None
self.dropout = nn.Dropout(dropout)
self.blocks = nn.ModuleList([T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout, quantized, quant_scheme, dtype) for _ in range(num_layers)])
if cpu_offload:
for block in self.blocks:
block.cpu_offload = cpu_offload
block.attn.cpu_offload = cpu_offload
block.ffn.cpu_offload = cpu_offload
self.norm = T5LayerNorm(dim, dtype=dtype)
self.weights_stream_mgr = WeightAsyncStreamManager(blocks_num=num_layers)
self.blocks_weights = T5OffloadBlocksWeights(num_layers, quant_scheme)
self.blocks = self.blocks_weights.blocks
else:
self.blocks = nn.ModuleList(
[
T5SelfAttention(
dim,
dim_attn,
dim_ffn,
num_heads,
num_buckets,
shared_pos,
dropout,
quantized,
quant_scheme,
dtype,
)
for _ in range(num_layers)
]
)
# initialize weights
# self.apply(init_weights)
self.norm = T5LayerNorm(dim, dtype=dtype)
def forward(self, ids, mask=None):
if self.cpu_offload:
self.token_embedding = self.token_embedding.cuda()
def forward_without_offload(self, ids, mask=None):
x = self.token_embedding(ids)
if self.cpu_offload:
self.token_embedding = self.token_embedding.cpu()
optimize_memory_usage()
x = self.dropout(x)
if self.cpu_offload and self.pos_embedding is not None:
self.pos_embedding = self.pos_embedding.cuda()
e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None
if self.cpu_offload and self.pos_embedding is not None:
self.pos_embedding = self.pos_embedding.cpu()
optimize_memory_usage()
for i, block in enumerate(self.blocks):
if self.cpu_offload:
block = block.cuda()
x = block(x, mask, pos_bias=e)
if self.cpu_offload:
block = block.cpu()
del block
optimize_memory_usage()
if self.cpu_offload:
self.norm = self.norm.cuda()
x = self.norm(x)
if self.cpu_offload:
self.norm = self.norm.cpu()
optimize_memory_usage()
x = self.dropout(x)
return x.to(GET_DTYPE())
def forward_attn_with_offload(self, x, attn_phase, context=None, mask=None, pos_bias=None):
context = x if context is None else context
b, n, c = x.size(0), self.num_heads, self.dim_attn // self.num_heads
# compute query, key, value
q = attn_phase.attn_q.apply(x.squeeze(0)).view(b, -1, n, c)
k = attn_phase.attn_k.apply(context.squeeze(0)).view(b, -1, n, c)
v = attn_phase.attn_v.apply(context.squeeze(0)).view(b, -1, n, c)
# attention bias
attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
if pos_bias is not None:
attn_bias += pos_bias
if mask is not None:
assert mask.ndim in [2, 3]
mask = mask.view(b, 1, 1, -1) if mask.ndim == 2 else mask.unsqueeze(1)
attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)
# compute attention (T5 does not use scaling)
attn = torch.einsum("binc,bjnc->bnij", q, k) + attn_bias
attn = F.softmax(attn.float(), dim=-1).type_as(attn)
x = torch.einsum("bnij,bjnc->binc", attn, v)
x = x.reshape(b, -1, n * c)
x = attn_phase.attn_o.apply(x.squeeze(0)).unsqueeze(0)
return x
def forward_ffn_with_offload(self, x, ffn_phase):
x = x.squeeze(0)
x = ffn_phase.ffn_fc1.apply(x) * ffn_phase.gelu(ffn_phase.ffn_gate_0.apply(x))
x = ffn_phase.ffn_fc2.apply(x)
return x.unsqueeze(0)
def forward_block_with_offload(self, block, x, mask=None, pos_bias=None):
if self.shared_pos:
e = pos_bias
else:
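# shared_pos is False: rebuild the T5 relative-position buckets inline (the bidirectional
# bucketing used by T5RelativeEmbedding, with max_distance=128), then look them up
# through the offloaded per-block embedding weight.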
lq, lk = x.size(1), x.size(1)
rel_pos = torch.arange(lk, device="cuda").unsqueeze(0) - torch.arange(lq, device="cuda").unsqueeze(1)
num_buckets = block.pos_embedding.weight.shape[0] // 2
rel_buckets = (rel_pos > 0).long() * num_buckets
rel_pos = torch.abs(rel_pos)
max_exact = num_buckets // 2
rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) / math.log(128 / max_exact) * (num_buckets - max_exact)).long()
rel_pos_large = torch.min(rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
e = block.pos_embedding.apply(rel_buckets).permute(2, 0, 1).unsqueeze(0).contiguous()
norm1_out = block.norm1.apply(x)
x = fp16_clamp(x + self.forward_attn_with_offload(norm1_out, block.compute_phases[0], mask=mask, pos_bias=e))
norm2_out = block.norm2.apply(x)
x = fp16_clamp(x + self.forward_ffn_with_offload(norm2_out, block.compute_phases[1]))
return x
def forward_with_offload(self, ids, mask=None):
self.token_embedding = self.token_embedding.to("cuda")
self.pos_embedding = self.pos_embedding.to("cuda") if self.pos_embedding is not None else None
x = self.token_embedding(ids)
x = self.dropout(x)
e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None
self.norm = self.norm.to("cuda")
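# Pipeline the encoder blocks: block 0 is moved to the GPU up front, then each iteration
# prefetches block_idx + 1 on the transfer stream while the current block runs on the
# compute stream; swap_weights() rotates the active/prefetched buffers so copies overlap compute.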
for block_idx in range(len(self.blocks)):
self.block_idx = block_idx
if block_idx == 0:
self.weights_stream_mgr.active_weights[0] = self.blocks[0]
self.weights_stream_mgr.active_weights[0].to_cuda()
if block_idx < len(self.blocks) - 1:
self.weights_stream_mgr.prefetch_weights(block_idx + 1, self.blocks)
with torch.cuda.stream(self.weights_stream_mgr.compute_stream):
x = self.forward_block_with_offload(self.blocks[block_idx], x, mask, pos_bias=e)
self.weights_stream_mgr.swap_weights()
x = self.norm(x)
x = self.dropout(x)
return x.to(GET_DTYPE())
def forward(self, ids, mask=None):
if self.cpu_offload:
return self.forward_with_offload(ids, mask)
else:
return self.forward_without_offload(ids, mask)
class T5Decoder(nn.Module):
def __init__(
......@@ -513,6 +720,15 @@ def _t5(
return model
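# When cpu_offload is enabled, the per-block weights ("blocks.*") are popped out of the
# state dict and loaded into T5OffloadBlocksWeights, while the remaining tensors
# (token_embedding, shared pos_embedding, final norm) still go through load_state_dict.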
def split_block_weights(weights):
block_weights = {}
all_keys = list(weights.keys())
for key in all_keys:
if key.startswith(("blocks.")):
block_weights[key] = weights.pop(key)
return block_weights
def umt5_xxl(**kwargs):
cfg = dict(
vocab_size=256384,
......@@ -540,7 +756,6 @@ class T5EncoderModel:
tokenizer_path=None,
shard_fn=None,
cpu_offload=False,
offload_granularity="model",
t5_quantized=False,
t5_quantized_ckpt=None,
quant_scheme=None,
......@@ -554,12 +769,9 @@ class T5EncoderModel:
else:
self.checkpoint_path = checkpoint_path
self.tokenizer_path = tokenizer_path
self.offload_granularity = offload_granularity
# sync cpu offload
self.cpu_offload = cpu_offload
if self.cpu_offload:
assert self.offload_granularity in ["block", "model"]
model = (
umt5_xxl(
......@@ -567,7 +779,7 @@ class T5EncoderModel:
return_tokenizer=False,
dtype=dtype,
device=device,
cpu_offload=(cpu_offload if self.offload_granularity == "block" else False),
cpu_offload=cpu_offload,
quantized=t5_quantized,
quant_scheme=quant_scheme,
)
......@@ -575,9 +787,21 @@ class T5EncoderModel:
.requires_grad_(False)
)
weights_dict = load_weights(self.checkpoint_path, cpu_offload=cpu_offload, load_from_rank0=load_from_rank0)
model.load_state_dict(weights_dict)
weights_dict = load_weights(
self.checkpoint_path,
cpu_offload=cpu_offload,
load_from_rank0=load_from_rank0,
)
if cpu_offload:
block_weights_dict = split_block_weights(weights_dict)
model.blocks_weights.load(block_weights_dict)
del block_weights_dict
gc.collect()
model.load_state_dict(weights_dict)
del weights_dict
gc.collect()
self.model = model
if shard_fn is not None:
self.model = shard_fn(self.model, sync_module_states=False)
......@@ -586,20 +810,7 @@ class T5EncoderModel:
# init tokenizer
self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=text_len, clean="whitespace")
def to_cpu(self):
self.model = self.model.to("cpu")
def to_cuda(self):
self.model = self.model.to("cuda")
def optimize_memory(self):
"""优化内存使用"""
optimize_memory_usage()
def infer(self, texts):
if self.cpu_offload and self.offload_granularity == "model":
self.to_cuda()
ids, mask = self.tokenizer(texts, return_mask=True, add_special_tokens=True)
ids = ids.cuda()
mask = mask.cuda()
......@@ -608,29 +819,39 @@ class T5EncoderModel:
with torch.no_grad():
context = self.model(ids, mask)
if self.cpu_offload and self.offload_granularity == "model":
self.to_cpu()
optimize_memory_usage()
del ids, mask
if self.cpu_offload:
optimize_memory_usage()
return [u[:v] for u, v in zip(context, seq_lens)]
if __name__ == "__main__":
import time
checkpoint_dir = ""
t5_checkpoint = "models_t5_umt5-xxl-enc-bf16.pth"
t5_tokenizer = "google/umt5-xxl"
t5_checkpoint = "./models_t5_umt5-xxl-enc-bf16.pth"
t5_tokenizer = "./google/umt5-xxl"
cpu_offload = True
if cpu_offload:
device = torch.device("cpu")
else:
device = torch.device("cuda")
model = T5EncoderModel(
text_len=512,
dtype=torch.bfloat16,
device=torch.device("cuda"),
device=device,
checkpoint_path=os.path.join(checkpoint_dir, t5_checkpoint),
tokenizer_path=os.path.join(checkpoint_dir, t5_tokenizer),
shard_fn=None,
cpu_offload=cpu_offload,
)
text = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
torch.cuda.synchronize()
s_t = time.time()
outputs = model.infer(text)
torch.cuda.synchronize()
e_t = time.time()
logger.info(e_t - s_t)
logger.info(outputs)
......@@ -128,7 +128,6 @@ class WanRunner(DefaultRunner):
tokenizer_path=tokenizer_path,
shard_fn=None,
cpu_offload=t5_offload,
offload_granularity=self.config.get("t5_offload_granularity", "model"), # support ['model', 'block']
t5_quantized=t5_quantized,
t5_quantized_ckpt=t5_quantized_ckpt,
quant_scheme=t5_quant_scheme,
......
......@@ -52,5 +52,5 @@ CONV3D_WEIGHT_REGISTER = Register()
CONV2D_WEIGHT_REGISTER = Register()
TENSOR_REGISTER = Register()
CONVERT_WEIGHT_REGISTER = Register()
EMBEDDING_WEIGHT_REGISTER = Register()
RUNNER_REGISTER = Register()