Commit 92539ed8 authored by gushiqiao

Update gradio and offload

parent 8e941d39
@@ -59,8 +59,6 @@ class QuantLinearFp8(nn.Module):
         super().__init__()
         self.in_features = in_features
         self.out_features = out_features
-        self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.float8_e4m3fn))
-        self.register_buffer("weight_scale", torch.empty((out_features, 1), dtype=torch.float32))
         if bias:
             self.register_buffer("bias", torch.empty(out_features, dtype=dtype))
......
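The hunk above removes the eager allocation of the weight and weight_scale buffers from QuantLinearFp8.__init__ (they can be materialized later, e.g. when a quantized checkpoint is loaded). For orientation, this is roughly how an FP8 weight plus a per-output-row scale is consumed at run time; the dequantize-then-matmul fallback and the function name are assumptions, not the repository's actual kernel:

import torch

def fp8_linear_sketch(x, weight_fp8, weight_scale, bias=None):
    # weight_fp8: (out_features, in_features), dtype torch.float8_e4m3fn
    # weight_scale: (out_features, 1), float32 scale per output row
    w = weight_fp8.to(torch.bfloat16) * weight_scale.to(torch.bfloat16)
    out = x.to(torch.bfloat16) @ w.t()
    return out if bias is None else out + bias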
@@ -3,6 +3,8 @@
 import logging
 import math
 import os
+
+from six import b
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -27,6 +29,14 @@ def fp16_clamp(x):
     return x


+def optimize_memory_usage():
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+    import gc
+    gc.collect()
+
+
 def init_weights(m):
     if isinstance(m, T5LayerNorm):
         nn.init.ones_(m.weight)
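optimize_memory_usage() is the helper the rest of this commit calls after each offload step: torch.cuda.empty_cache() hands cached-but-unused allocator blocks back to the driver, and gc.collect() drops unreachable Python objects. It does not shrink memory held by live tensors. A quick, illustrative way to observe its effect:

import torch

before = torch.cuda.memory_reserved()
optimize_memory_usage()  # empty_cache() + gc.collect(), as defined above
after = torch.cuda.memory_reserved()
print(f"reserved by the CUDA allocator: {before / 2**20:.0f} MiB -> {after / 2**20:.0f} MiB")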
@@ -114,10 +124,14 @@ class T5Attention(nn.Module):
         # compute attention (T5 does not use scaling)
         attn = torch.einsum("binc,bjnc->bnij", q, k) + attn_bias
+        if hasattr(self, "cpu_offload") and self.cpu_offload:
+            del attn_bias
         attn = F.softmax(attn.float(), dim=-1).to(torch.bfloat16)
         x = torch.einsum("bnij,bjnc->binc", attn, v)
+        if hasattr(self, "cpu_offload") and self.cpu_offload:
+            del attn

         # output
         x = x.reshape(b, -1, n * c)
         x = self.o(x)
         x = self.dropout(x)
@@ -144,7 +158,14 @@ class T5FeedForward(nn.Module):
         self.dropout = nn.Dropout(dropout)

     def forward(self, x):
-        x = self.fc1(x) * self.gate(x)
+        if hasattr(self, "cpu_offload") and self.cpu_offload:
+            gate_out = self.gate(x)
+            fc1_out = self.fc1(x)
+            x = fc1_out * gate_out
+            del gate_out, fc1_out
+        else:
+            x = self.fc1(x) * self.gate(x)
         x = self.dropout(x)
         x = self.fc2(x)
         x = self.dropout(x)
@@ -170,8 +191,19 @@ class T5SelfAttention(nn.Module):
     def forward(self, x, mask=None, pos_bias=None):
         e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1))
-        x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
-        x = fp16_clamp(x + self.ffn(self.norm2(x)))
+        if hasattr(self, "cpu_offload") and self.cpu_offload:
+            attn_out = self.attn(self.norm1(x), mask=mask, pos_bias=e)
+            x = fp16_clamp(x + attn_out)
+            del attn_out
+            ffn_out = self.ffn(self.norm2(x))
+            x = fp16_clamp(x + ffn_out)
+            del ffn_out
+        else:
+            x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
+            x = fp16_clamp(x + self.ffn(self.norm2(x)))
         return x
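The cpu_offload branches above lower the peak rather than the total: attn_out and ffn_out are released as soon as the residual additions have consumed them, and attn_bias/attn are dropped inside the attention itself. A way to check that such a change actually moves the high-water mark, assuming a block and its inputs are already on the GPU (illustrative, not repository code):

import torch

torch.cuda.reset_peak_memory_stats()
y = block(x, mask, pos_bias=e)  # one T5 self-attention block
peak_mib = torch.cuda.max_memory_allocated() / 2**20
print(f"peak memory allocated inside the block: {peak_mib:.0f} MiB")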
@@ -270,6 +302,12 @@ class T5Encoder(nn.Module):
         self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True, dtype=dtype) if shared_pos else None
         self.dropout = nn.Dropout(dropout)
         self.blocks = nn.ModuleList([T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout, quantized, quant_scheme, dtype) for _ in range(num_layers)])
+        if cpu_offload:
+            for block in self.blocks:
+                block.cpu_offload = cpu_offload
+                block.attn.cpu_offload = cpu_offload
+                block.ffn.cpu_offload = cpu_offload
         self.norm = T5LayerNorm(dim, dtype=dtype)

         # initialize weights
@@ -281,23 +319,32 @@ class T5Encoder(nn.Module):
         x = self.token_embedding(ids)
         if self.cpu_offload:
             self.token_embedding = self.token_embedding.cpu()
+            optimize_memory_usage()
         x = self.dropout(x)
         if self.cpu_offload and self.pos_embedding is not None:
             self.pos_embedding = self.pos_embedding.cuda()
         e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None
         if self.cpu_offload and self.pos_embedding is not None:
             self.pos_embedding = self.pos_embedding.cpu()
-        for block in self.blocks:
+            optimize_memory_usage()
+        for i, block in enumerate(self.blocks):
             if self.cpu_offload:
                 block = block.cuda()
             x = block(x, mask, pos_bias=e)
             if self.cpu_offload:
                 block = block.cpu()
+                del block
+                optimize_memory_usage()
         if self.cpu_offload:
             self.norm = self.norm.cuda()
         x = self.norm(x)
         if self.cpu_offload:
             self.norm = self.norm.cpu()
+            optimize_memory_usage()
         x = self.dropout(x)
         return x.to(torch.bfloat16)
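The forward pass above is the heart of block-level offloading: only one transformer block is resident on the GPU at a time, and the embedding table, positional embedding, and final norm are shuttled on and off around it. Stripped of the T5 specifics, the loop pattern looks roughly like this (a sketch; run_blocks_with_offload is not a function in the repository):

import gc
import torch

def run_blocks_with_offload(blocks, x, cpu_offload=True):
    for block in blocks:
        if cpu_offload:
            block = block.cuda()  # nn.Module.cuda() moves parameters in place and returns self
        x = block(x)
        if cpu_offload:
            block = block.cpu()  # hand the weights back to host memory
            torch.cuda.empty_cache()  # same effect as optimize_memory_usage() above
            gc.collect()
    return x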
@@ -529,6 +576,10 @@ class T5EncoderModel:
     def to_cuda(self):
         self.model = self.model.to("cuda")

+    def optimize_memory(self):
+        """Optimize memory usage."""
+        optimize_memory_usage()
+
     def infer(self, texts):
         if self.cpu_offload and self.offload_granularity == "model":
             self.to_cuda()
@@ -537,10 +588,17 @@ class T5EncoderModel:
         ids = ids.cuda()
         mask = mask.cuda()
         seq_lens = mask.gt(0).sum(dim=1).long()
-        context = self.model(ids, mask)
+        with torch.no_grad():
+            context = self.model(ids, mask)
         if self.cpu_offload and self.offload_granularity == "model":
             self.to_cpu()
+            optimize_memory_usage()
+
+        del ids, mask
+        if self.cpu_offload:
+            optimize_memory_usage()
         return [u[:v] for u, v in zip(context, seq_lens)]
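With offload_granularity == "model", the whole encoder is moved to the GPU, run once under torch.no_grad(), and moved back, which is what the infer hunks above implement. A condensed sketch of that flow; encode_prompts and its signature are made up for illustration:

import gc
import torch

def encode_prompts(text_encoder, ids, mask, cpu_offload=True):
    if cpu_offload:
        text_encoder.model = text_encoder.model.to("cuda")
    with torch.no_grad():  # inference only, so skip autograd bookkeeping
        context = text_encoder.model(ids.cuda(), mask.cuda())
    if cpu_offload:
        text_encoder.model = text_encoder.model.to("cpu")
        torch.cuda.empty_cache()
        gc.collect()
    return context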
......
@@ -24,6 +24,11 @@ class WanTransformerWeights(WeightModule):
         self.blocks = WeightModuleList([WanTransformerAttentionBlock(i, self.task, self.mm_type, self.config) for i in range(self.blocks_num)])
         self.add_module("blocks", self.blocks)

+    def clear(self):
+        for block in self.blocks:
+            for phase in block.compute_phases:
+                phase.clear()
+

 class WanTransformerAttentionBlock(WeightModule):
     def __init__(self, block_index, task, mm_type, config):
......
@@ -49,7 +49,7 @@ class DefaultRunner:
             else:
                 self.run_input_encoder = self.run_input_encoder_server_t2v
         else:
-            if not self.config.get("lazy_load", False):
+            if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False):
                 self.load_model()
             self.run_dit = self.run_dit_local
             self.run_vae_decoder = self.run_vae_decoder_local
@@ -136,8 +136,13 @@ class DefaultRunner:
     def end_run(self):
         self.model.scheduler.clear()
         del self.inputs, self.model.scheduler
-        if self.config.get("lazy_load", False):
-            self.model.transformer_infer.weights_stream_mgr.clear()
+        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
+            if hasattr(self.model.transformer_infer, "weights_stream_mgr"):
+                self.model.transformer_infer.weights_stream_mgr.clear()
+            if hasattr(self.model.transformer_weights, "clear"):
+                self.model.transformer_weights.clear()
+            self.model.pre_weight.clear()
+            self.model.post_weight.clear()
             del self.model
         torch.cuda.empty_cache()
         gc.collect()
@@ -163,7 +168,7 @@ class DefaultRunner:
     @ProfilingContext("Run DiT")
     async def run_dit_local(self, kwargs):
-        if self.config.get("lazy_load", False):
+        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             self.model = self.load_transformer()
         self.init_scheduler()
         self.model.scheduler.prepare(self.inputs["image_encoder_output"])
@@ -173,10 +178,10 @@ class DefaultRunner:
     @ProfilingContext("Run VAE Decoder")
     async def run_vae_decoder_local(self, latents, generator):
-        if self.config.get("lazy_load", False):
+        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             self.vae_decoder = self.load_vae_decoder()
         images = self.vae_decoder.decode(latents, generator=generator, config=self.config)
-        if self.config.get("lazy_load", False):
+        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             del self.vae_decoder
             torch.cuda.empty_cache()
             gc.collect()
......
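Every runner change in this file follows the same gate: a module is loaded right before use and torn down right after whenever lazy_load or unload_modules is set. A generic version of that guard, mirroring run_vae_decoder_local; should_unload and decode_with_optional_unload are hypothetical helpers, not repository APIs:

import gc
import torch

def should_unload(config) -> bool:
    return config.get("lazy_load", False) or config.get("unload_modules", False)

def decode_with_optional_unload(runner, latents, generator):
    if should_unload(runner.config):
        runner.vae_decoder = runner.load_vae_decoder()
    images = runner.vae_decoder.decode(latents, generator=generator, config=runner.config)
    if should_unload(runner.config):
        del runner.vae_decoder
        torch.cuda.empty_cache()
        gc.collect()
    return images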
@@ -61,14 +61,19 @@ class WanRunner(DefaultRunner):
         return image_encoder

     def load_text_encoder(self):
+        t5_offload = self.config.get("t5_cpu_offload", False)
+        if t5_offload:
+            t5_device = torch.device("cpu")
+        else:
+            t5_device = torch.device("cuda")
         text_encoder = T5EncoderModel(
             text_len=self.config["text_len"],
             dtype=torch.bfloat16,
-            device=self.init_device,
+            device=t5_device,
             checkpoint_path=os.path.join(self.config.model_path, "models_t5_umt5-xxl-enc-bf16.pth"),
             tokenizer_path=os.path.join(self.config.model_path, "google/umt5-xxl"),
             shard_fn=None,
-            cpu_offload=self.config.cpu_offload,
+            cpu_offload=t5_offload,
             offload_granularity=self.config.get("t5_offload_granularity", "model"),
             t5_quantized=self.config.get("t5_quantized", False),
             t5_quantized_ckpt=self.config.get("t5_quantized_ckpt", None),
@@ -129,13 +134,13 @@ class WanRunner(DefaultRunner):
         self.model.set_scheduler(scheduler)

     def run_text_encoder(self, text, img):
-        if self.config.get("lazy_load", False):
+        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             self.text_encoders = self.load_text_encoder()
         text_encoder_output = {}
         n_prompt = self.config.get("negative_prompt", "")
         context = self.text_encoders[0].infer([text])
         context_null = self.text_encoders[0].infer([n_prompt if n_prompt else ""])
-        if self.config.get("lazy_load", False):
+        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             del self.text_encoders[0]
             torch.cuda.empty_cache()
             gc.collect()
@@ -144,11 +149,11 @@ class WanRunner(DefaultRunner):
         return text_encoder_output

     def run_image_encoder(self, img):
-        if self.config.get("lazy_load", False):
+        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             self.image_encoder = self.load_image_encoder()
         img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
         clip_encoder_out = self.image_encoder.visual([img[:, None, :, :]], self.config).squeeze(0).to(torch.bfloat16)
-        if self.config.get("lazy_load", False):
+        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             del self.image_encoder
             torch.cuda.empty_cache()
             gc.collect()
@@ -179,7 +184,7 @@ class WanRunner(DefaultRunner):
         msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
         msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
         msk = msk.transpose(1, 2)[0]
-        if self.config.get("lazy_load", False):
+        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             self.vae_encoder = self.load_vae_encoder()
         vae_encode_out = self.vae_encoder.encode(
             [
@@ -193,7 +198,7 @@ class WanRunner(DefaultRunner):
             ],
             self.config,
         )[0]
-        if self.config.get("lazy_load", False):
+        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             del self.vae_encoder
             torch.cuda.empty_cache()
             gc.collect()
......
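load_text_encoder now derives both the T5 encoder's initial device and its offload flag from a dedicated t5_cpu_offload setting instead of the global cpu_offload. The selection logic, isolated as a standalone sketch (pick_t5_placement is hypothetical; the defaults mirror the hunk above):

import torch

def pick_t5_placement(config):
    t5_offload = config.get("t5_cpu_offload", False)
    # Offloaded weights start on the CPU and are streamed to the GPU
    # per block or per model, depending on t5_offload_granularity.
    device = torch.device("cpu") if t5_offload else torch.device("cuda")
    granularity = config.get("t5_offload_granularity", "model")
    return device, t5_offload, granularity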