Commit 92539ed8 authored by gushiqiao

Update gradio and offload

parent 8e941d39
......@@ -59,8 +59,6 @@ class QuantLinearFp8(nn.Module):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.float8_e4m3fn))
self.register_buffer("weight_scale", torch.empty((out_features, 1), dtype=torch.float32))
if bias:
self.register_buffer("bias", torch.empty(out_features, dtype=dtype))
......
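For reference, a minimal sketch of how an fp8 weight buffer and its per-row scale can be consumed at inference time, assuming a plain dequantize-then-matmul path; the class and method here are illustrative, not the repository's kernel:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleFp8Linear(nn.Module):
    """Illustrative fp8 linear: weights stored as float8_e4m3fn plus a
    per-output-channel float32 scale, dequantized for the matmul."""

    def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16):
        super().__init__()
        self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.float8_e4m3fn))
        self.register_buffer("weight_scale", torch.empty((out_features, 1), dtype=torch.float32))
        if bias:
            self.register_buffer("bias", torch.empty(out_features, dtype=dtype))
        else:
            self.bias = None

    def forward(self, x):
        # Dequantize: cast fp8 back to the activation dtype, rescale per row.
        w = self.weight.to(x.dtype) * self.weight_scale.to(x.dtype)
        return F.linear(x, w, self.bias)
```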
......@@ -3,6 +3,8 @@
import logging
import math
import os
from six import b
import torch
import torch.nn as nn
import torch.nn.functional as F
......@@ -27,6 +29,14 @@ def fp16_clamp(x):
return x
def optimize_memory_usage():
if torch.cuda.is_available():
torch.cuda.empty_cache()
import gc
gc.collect()
def init_weights(m):
if isinstance(m, T5LayerNorm):
nn.init.ones_(m.weight)
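The new `optimize_memory_usage` helper pairs `torch.cuda.empty_cache()` with `gc.collect()`; a short, self-contained restatement with a typical call site, assuming standard PyTorch caching-allocator behaviour (cached blocks are only released once no live tensor references them):

```python
import gc
import torch
import torch.nn as nn

def optimize_memory_usage():
    # Return unreferenced cached CUDA blocks to the driver ...
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    # ... and collect lingering Python garbage as well.
    gc.collect()

# Typical call site: right after a module has been moved off the GPU.
if torch.cuda.is_available():
    layer = nn.Linear(1024, 1024).cuda()
    layer = layer.cpu()
    optimize_memory_usage()
```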
......@@ -114,10 +124,14 @@ class T5Attention(nn.Module):
# compute attention (T5 does not use scaling)
attn = torch.einsum("binc,bjnc->bnij", q, k) + attn_bias
if hasattr(self, "cpu_offload") and self.cpu_offload:
del attn_bias
attn = F.softmax(attn.float(), dim=-1).to(torch.bfloat16)
x = torch.einsum("bnij,bjnc->binc", attn, v)
# output
if hasattr(self, "cpu_offload") and self.cpu_offload:
del attn
x = x.reshape(b, -1, n * c)
x = self.o(x)
x = self.dropout(x)
......@@ -144,7 +158,14 @@ class T5FeedForward(nn.Module):
self.dropout = nn.Dropout(dropout)
def forward(self, x):
x = self.fc1(x) * self.gate(x)
if hasattr(self, "cpu_offload") and self.cpu_offload:
gate_out = self.gate(x)
fc1_out = self.fc1(x)
x = fc1_out * gate_out
del gate_out, fc1_out
else:
x = self.fc1(x) * self.gate(x)
x = self.dropout(x)
x = self.fc2(x)
x = self.dropout(x)
......@@ -170,8 +191,19 @@ class T5SelfAttention(nn.Module):
def forward(self, x, mask=None, pos_bias=None):
e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1))
x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
x = fp16_clamp(x + self.ffn(self.norm2(x)))
if hasattr(self, "cpu_offload") and self.cpu_offload:
attn_out = self.attn(self.norm1(x), mask=mask, pos_bias=e)
x = fp16_clamp(x + attn_out)
del attn_out
ffn_out = self.ffn(self.norm2(x))
x = fp16_clamp(x + ffn_out)
del ffn_out
else:
x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
x = fp16_clamp(x + self.ffn(self.norm2(x)))
return x
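The offload branches above bind each sublayer output to a name and `del` it as soon as the residual add has materialized; a generic sketch of that pattern, with the helper name and signature being illustrative only:

```python
import torch

def residual_step(x, sublayer, norm, cpu_offload=False):
    """Compute x + sublayer(norm(x)); under offload, drop the intermediate
    reference early so the allocator can reuse its memory before the next
    sublayer allocates (illustrative helper, not repository code)."""
    if cpu_offload:
        out = sublayer(norm(x))
        x = x + out
        del out
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return x
    # Without offload, keep the original fused expression.
    return x + sublayer(norm(x))
```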
......@@ -270,6 +302,12 @@ class T5Encoder(nn.Module):
self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True, dtype=dtype) if shared_pos else None
self.dropout = nn.Dropout(dropout)
self.blocks = nn.ModuleList([T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout, quantized, quant_scheme, dtype) for _ in range(num_layers)])
if cpu_offload:
for block in self.blocks:
block.cpu_offload = cpu_offload
block.attn.cpu_offload = cpu_offload
block.ffn.cpu_offload = cpu_offload
self.norm = T5LayerNorm(dim, dtype=dtype)
# initialize weights
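A toy version of how an encoder can tag its blocks (and their submodules) with a `cpu_offload` attribute that the forward passes later branch on; the class below is an illustrative stand-in:

```python
import torch.nn as nn

class BlockStack(nn.Module):
    """Toy stack that pushes a cpu_offload flag down to every block."""

    def __init__(self, num_layers=4, dim=64, cpu_offload=False):
        super().__init__()
        self.blocks = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_layers))
        if cpu_offload:
            for block in self.blocks:
                # Plain attribute, checked later via hasattr(...) in forward.
                block.cpu_offload = True

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x
```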
......@@ -281,23 +319,32 @@ class T5Encoder(nn.Module):
x = self.token_embedding(ids)
if self.cpu_offload:
self.token_embedding = self.token_embedding.cpu()
optimize_memory_usage()
x = self.dropout(x)
if self.cpu_offload and self.pos_embedding is not None:
self.pos_embedding = self.pos_embedding.cuda()
e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None
if self.cpu_offload and self.pos_embedding is not None:
self.pos_embedding = self.pos_embedding.cpu()
for block in self.blocks:
optimize_memory_usage()
for i, block in enumerate(self.blocks):
if self.cpu_offload:
block = block.cuda()
x = block(x, mask, pos_bias=e)
if self.cpu_offload:
block = block.cpu()
del block
optimize_memory_usage()
if self.cpu_offload:
self.norm = self.norm.cuda()
x = self.norm(x)
if self.cpu_offload:
self.norm = self.norm.cpu()
optimize_memory_usage()
x = self.dropout(x)
return x.to(torch.bfloat16)
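The forward pass above keeps every block on the CPU and shuttles one block at a time onto the GPU. A minimal, generic sketch of that per-block offload loop (function and variable names are illustrative, not the encoder's exact code):

```python
import gc
import torch
import torch.nn as nn

@torch.no_grad()
def offloaded_forward(blocks, x):
    """Run CPU-resident blocks one at a time on the GPU; peak GPU memory
    is roughly a single block's weights plus the activations."""
    for block in blocks:
        block.cuda()                 # bring only this block's weights up
        x = block(x)
        block.cpu()                  # push the weights back out
        torch.cuda.empty_cache()
        gc.collect()
    return x

# Usage sketch:
if torch.cuda.is_available():
    blocks = nn.ModuleList(nn.Linear(256, 256) for _ in range(8))
    out = offloaded_forward(blocks, torch.randn(2, 256, device="cuda"))
```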
......@@ -529,6 +576,10 @@ class T5EncoderModel:
def to_cuda(self):
self.model = self.model.to("cuda")
def optimize_memory(self):
"""优化内存使用"""
optimize_memory_usage()
def infer(self, texts):
if self.cpu_offload and self.offload_granularity == "model":
self.to_cuda()
......@@ -537,10 +588,17 @@ class T5EncoderModel:
ids = ids.cuda()
mask = mask.cuda()
seq_lens = mask.gt(0).sum(dim=1).long()
context = self.model(ids, mask)
with torch.no_grad():
context = self.model(ids, mask)
if self.cpu_offload and self.offload_granularity == "model":
self.to_cpu()
optimize_memory_usage()
del ids, mask
if self.cpu_offload:
optimize_memory_usage()
return [u[:v] for u, v in zip(context, seq_lens)]
......
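At "model" offload granularity, `infer` moves the whole encoder onto the GPU for one call, wraps the forward in `torch.no_grad()`, and moves it back afterwards. A hedged sketch of that wrapper (names are illustrative):

```python
import gc
import torch

def infer_with_model_offload(model, *inputs):
    """Whole-model offload for a single inference call (illustrative)."""
    model.cuda()
    try:
        with torch.no_grad():        # inference only; skip the autograd graph
            out = model(*inputs)
    finally:
        model.cpu()                  # release weights even if the call fails
        torch.cuda.empty_cache()
        gc.collect()
    return out
```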
......@@ -24,6 +24,11 @@ class WanTransformerWeights(WeightModule):
self.blocks = WeightModuleList([WanTransformerAttentionBlock(i, self.task, self.mm_type, self.config) for i in range(self.blocks_num)])
self.add_module("blocks", self.blocks)
def clear(self):
for block in self.blocks:
for phase in block.compute_phases:
phase.clear()
class WanTransformerAttentionBlock(WeightModule):
def __init__(self, block_index, task, mm_type, config):
......
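The new `clear()` cascades through every block's compute phases to drop their weight references; a toy sketch of that shape, with all class names here being illustrative stand-ins:

```python
class ToyPhase:
    """Holds one compute phase's tensors (illustrative)."""
    def __init__(self, tensors=None):
        self.tensors = dict(tensors or {})

    def clear(self):
        # Drop references so the underlying memory can be reclaimed.
        self.tensors.clear()

class ToyBlock:
    def __init__(self, phases):
        self.compute_phases = list(phases)

class ToyTransformerWeights:
    def __init__(self, blocks):
        self.blocks = list(blocks)

    def clear(self):
        # Same cascading shape as the diff: every phase of every block.
        for block in self.blocks:
            for phase in block.compute_phases:
                phase.clear()
```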
......@@ -49,7 +49,7 @@ class DefaultRunner:
else:
self.run_input_encoder = self.run_input_encoder_server_t2v
else:
if not self.config.get("lazy_load", False):
if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False):
self.load_model()
self.run_dit = self.run_dit_local
self.run_vae_decoder = self.run_vae_decoder_local
......@@ -136,8 +136,13 @@ class DefaultRunner:
def end_run(self):
self.model.scheduler.clear()
del self.inputs, self.model.scheduler
if self.config.get("lazy_load", False):
self.model.transformer_infer.weights_stream_mgr.clear()
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
if hasattr(self.model.transformer_infer, "weights_stream_mgr"):
self.model.transformer_infer.weights_stream_mgr.clear()
if hasattr(self.model.transformer_weights, "clear"):
self.model.transformer_weights.clear()
self.model.pre_weight.clear()
self.model.post_weight.clear()
del self.model
torch.cuda.empty_cache()
gc.collect()
......@@ -163,7 +168,7 @@ class DefaultRunner:
@ProfilingContext("Run DiT")
async def run_dit_local(self, kwargs):
if self.config.get("lazy_load", False):
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
self.model = self.load_transformer()
self.init_scheduler()
self.model.scheduler.prepare(self.inputs["image_encoder_output"])
......@@ -173,10 +178,10 @@ class DefaultRunner:
@ProfilingContext("Run VAE Decoder")
async def run_vae_decoder_local(self, latents, generator):
if self.config.get("lazy_load", False):
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
self.vae_decoder = self.load_vae_decoder()
images = self.vae_decoder.decode(latents, generator=generator, config=self.config)
if self.config.get("lazy_load", False):
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
del self.vae_decoder
torch.cuda.empty_cache()
gc.collect()
......
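Several call sites above follow the same "load on demand, run, then unload" shape, now keyed on either `lazy_load` or the new `unload_modules` flag. A generic sketch of that gating (the helper and its parameters are illustrative, not part of the runner):

```python
import gc
import torch

def run_stage(config, load_fn, run_fn, cached_module=None):
    """Load a module just for this stage and drop it afterwards when
    lazy_load or unload_modules is enabled (illustrative helper)."""
    ephemeral = config.get("lazy_load", False) or config.get("unload_modules", False)
    module = load_fn() if (ephemeral or cached_module is None) else cached_module
    out = run_fn(module)
    if ephemeral:
        del module
        torch.cuda.empty_cache()
        gc.collect()
    return out
```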
......@@ -61,14 +61,19 @@ class WanRunner(DefaultRunner):
return image_encoder
def load_text_encoder(self):
t5_offload = self.config.get("t5_cpu_offload", False)
if t5_offload:
t5_device = torch.device("cpu")
else:
t5_device = torch.device("cuda")
text_encoder = T5EncoderModel(
text_len=self.config["text_len"],
dtype=torch.bfloat16,
device=self.init_device,
device=t5_device,
checkpoint_path=os.path.join(self.config.model_path, "models_t5_umt5-xxl-enc-bf16.pth"),
tokenizer_path=os.path.join(self.config.model_path, "google/umt5-xxl"),
shard_fn=None,
cpu_offload=self.config.cpu_offload,
cpu_offload=t5_offload,
offload_granularity=self.config.get("t5_offload_granularity", "model"),
t5_quantized=self.config.get("t5_quantized", False),
t5_quantized_ckpt=self.config.get("t5_quantized_ckpt", None),
......@@ -129,13 +134,13 @@ class WanRunner(DefaultRunner):
self.model.set_scheduler(scheduler)
def run_text_encoder(self, text, img):
if self.config.get("lazy_load", False):
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
self.text_encoders = self.load_text_encoder()
text_encoder_output = {}
n_prompt = self.config.get("negative_prompt", "")
context = self.text_encoders[0].infer([text])
context_null = self.text_encoders[0].infer([n_prompt if n_prompt else ""])
if self.config.get("lazy_load", False):
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
del self.text_encoders[0]
torch.cuda.empty_cache()
gc.collect()
......@@ -144,11 +149,11 @@ class WanRunner(DefaultRunner):
return text_encoder_output
def run_image_encoder(self, img):
if self.config.get("lazy_load", False):
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
self.image_encoder = self.load_image_encoder()
img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
clip_encoder_out = self.image_encoder.visual([img[:, None, :, :]], self.config).squeeze(0).to(torch.bfloat16)
if self.config.get("lazy_load", False):
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
del self.image_encoder
torch.cuda.empty_cache()
gc.collect()
......@@ -179,7 +184,7 @@ class WanRunner(DefaultRunner):
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
msk = msk.transpose(1, 2)[0]
if self.config.get("lazy_load", False):
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
self.vae_encoder = self.load_vae_encoder()
vae_encode_out = self.vae_encoder.encode(
[
......@@ -193,7 +198,7 @@ class WanRunner(DefaultRunner):
],
self.config,
)[0]
if self.config.get("lazy_load", False):
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
del self.vae_encoder
torch.cuda.empty_cache()
gc.collect()
......
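`load_text_encoder` now derives the T5 init device from a dedicated `t5_cpu_offload` flag instead of the global `cpu_offload` setting; a small sketch of that selection (the function name and config dict are illustrative):

```python
import torch

def pick_t5_device(config):
    """Start the T5 encoder on CPU when t5_cpu_offload is set, so its
    weights only visit the GPU while the encoder is actually running."""
    if config.get("t5_cpu_offload", False):
        return torch.device("cpu")
    return torch.device("cuda")

# Example:
device = pick_t5_device({"t5_cpu_offload": True})   # -> device(type='cpu')
```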