Commit a395cc0a authored by gushiqiao, committed by GitHub

Fix offload bugs and support wan2.2_vae offload

parents b723fc89 aa88b371
@@ -37,6 +37,9 @@ class WanModel:
     def __init__(self, model_path, config, device):
         self.model_path = model_path
         self.config = config
+        self.cpu_offload = self.config.get("cpu_offload", False)
+        self.offload_granularity = self.config.get("offload_granularity", "block")
+        self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
         self.dit_quantized = self.config.mm_config.get("mm_type", "Default") != "Default"
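
Note: the three new attributes cache the offload settings once at construction time instead of re-reading the config on every step. A minimal sketch of the knobs involved (a plain dict standing in for the project's config object, which also supports `.get()`; values are illustrative only):

```python
# Illustrative config only; the real config object also carries mm_config etc.
config = {
    "cpu_offload": True,             # park weights on CPU between uses
    "offload_granularity": "block",  # "model": move the whole model once per run;
                                     # any other value: move pre/post weights every step
    "clean_cuda_cache": False,       # delete intermediates and empty the CUDA cache
}
```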
@@ -202,15 +205,18 @@ class WanModel:
     @torch.no_grad()
     def infer(self, inputs):
+        if self.cpu_offload:
+            if self.offload_granularity == "model" and self.scheduler.step_index == 0:
+                self.to_cuda()
+            elif self.offload_granularity != "model":
+                self.pre_weight.to_cuda()
+                self.post_weight.to_cuda()
         if self.transformer_infer.mask_map is None:
             _, c, h, w = self.scheduler.latents.shape
             video_token_num = c * (h // 2) * (w // 2)
             self.transformer_infer.mask_map = MaskMap(video_token_num, c)
-        if self.config.get("cpu_offload", False):
-            self.pre_weight.to_cuda()
-            self.post_weight.to_cuda()
         embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
         x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
         noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
@@ -228,14 +234,17 @@ class WanModel:
             self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (self.scheduler.noise_pred - noise_pred_uncond)
-        if self.config.get("cpu_offload", False):
-            if self.clean_cuda_cache:
-                del x, embed, pre_infer_out, noise_pred_uncond, grid_sizes
-                torch.cuda.empty_cache()
+        if self.cpu_offload:
+            if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1:
+                self.to_cpu()
+            elif self.offload_granularity != "model":
+                self.pre_weight.to_cpu()
+                self.post_weight.to_cpu()
+        if self.clean_cuda_cache:
+            del x, embed, pre_infer_out, noise_pred_uncond, grid_sizes
+            torch.cuda.empty_cache()
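
Note: taken together, the two hunks above give "model" granularity a clean lifecycle: the whole model moves to CUDA once before step 0 and back to CPU only after the final step, while any other granularity moves the pre/post weights around on every single step (the transformer blocks are presumably handled inside transformer_infer). A toy, self-contained sketch of that lifecycle; TinyModel and run_steps are illustrative stand-ins, not classes from this repo:

```python
import torch

class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 8)

def run_steps(model, infer_steps, offload_granularity="block"):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(1, 8)
    for step_index in range(infer_steps):
        if offload_granularity == "model":
            if step_index == 0:
                model.to(device)      # load once, before the first step
        else:
            model.to(device)          # load at every step...
        x = model.layer(x.to(device)).cpu()
        if offload_granularity == "model":
            if step_index == infer_steps - 1:
                model.to("cpu")       # unload once, after the last step
        else:
            model.to("cpu")           # ...and unload at every step
    return x

run_steps(TinyModel(), infer_steps=4)
```

This also shows one of the bugs the rewrite fixes: the old code had no model-level path at all, and its cache cleanup was nested under the offload check, whereas clean_cuda_cache now runs independently of cpu_offload.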
 class Wan22MoeModel(WanModel):
     def _load_ckpt(self, use_bf16, skip_bf16):
@@ -248,6 +257,10 @@ class Wan22MoeModel(WanModel):
     @torch.no_grad()
     def infer(self, inputs):
+        if self.cpu_offload and self.offload_granularity != "model":
+            self.pre_weight.to_cuda()
+            self.post_weight.to_cuda()
         embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
         x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
         noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
@@ -260,3 +273,7 @@ class Wan22MoeModel(WanModel):
             noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
             self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (self.scheduler.noise_pred - noise_pred_uncond)
+        if self.cpu_offload and self.offload_granularity != "model":
+            self.pre_weight.to_cpu()
+            self.post_weight.to_cpu()
@@ -418,5 +418,5 @@ class Wan22DenseRunner(WanRunner):
         return vae_encoder_out

     def get_vae_encoder_output(self, img):
-        z = self.vae_encoder.encode(img)
+        z = self.vae_encoder.encode(img, self.config)
         return z
@@ -844,7 +844,7 @@ class Wan2_2_VAE:
         self.dtype = dtype
         self.device = device
-        mean = torch.tensor(
+        self.mean = torch.tensor(
             [
                 -0.2289,
                 -0.0052,
@@ -898,7 +898,7 @@ class Wan2_2_VAE:
             dtype=dtype,
             device=device,
         )
-        std = torch.tensor(
+        self.std = torch.tensor(
             [
                 0.4765,
                 1.0364,
@@ -952,8 +952,8 @@ class Wan2_2_VAE:
             dtype=dtype,
             device=device,
         )
-        self.scale = [mean, 1.0 / std]
+        self.inv_std = 1.0 / self.std
+        self.scale = [self.mean, self.inv_std]
         # init model
         self.model = (
             _video_vae(
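
Note: promoting mean/std to attributes and precomputing inv_std is what makes the VAE offloadable at all. The old code captured mean and 1.0 / std into self.scale once, on the construction device, so moving the model later stranded the normalization tensors there. Assuming the usual channel-wise affine behind scale = [mean, inv_std] (my reading; this diff doesn't spell it out), a tiny sketch of the semantics, with 48 channels assumed to match the elided value lists:

```python
import torch

mean = torch.randn(48)                 # stand-ins for the hard-coded channel stats
inv_std = 1.0 / (torch.rand(48) + 0.5)
z = torch.randn(48, 4, 4)

z_norm = (z - mean[:, None, None]) * inv_std[:, None, None]     # applied on encode
z_back = z_norm / inv_std[:, None, None] + mean[:, None, None]  # inverted on decode
assert torch.allclose(z, z_back, atol=1e-5)
```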
@@ -968,25 +968,35 @@ class Wan2_2_VAE:
             .to(device)
         )

-    def encode(self, videos):
-        # try:
-        #     if not isinstance(videos, list):
-        #         raise TypeError("videos should be a list")
-        #     with amp.autocast(dtype=self.dtype):
-        #         return [
-        #             self.model.encode(u.unsqueeze(0),
-        #                               self.scale).float().squeeze(0)
-        #             for u in videos
-        #         ]
-        # except TypeError as e:
-        #     logging.info(e)
-        #     return None
-        # print(1111111)
-        # print(self.model.encode(videos.unsqueeze(0), self.scale).float().shape)
-        # exit()
-        return self.model.encode(videos.unsqueeze(0), self.scale).float().squeeze(0)

     def to_cpu(self):
-        self.model.encoder = self.model.encoder.to("cpu")
-        self.model.decoder = self.model.decoder.to("cpu")
+        self.model = self.model.to("cpu")
+        self.mean = self.mean.cpu()
+        self.inv_std = self.inv_std.cpu()
+        self.scale = [self.mean, self.inv_std]

     def to_cuda(self):
-        self.model.encoder = self.model.encoder.to("cuda")
-        self.model.decoder = self.model.decoder.to("cuda")
+        self.model = self.model.to("cuda")
+        self.mean = self.mean.cuda()
+        self.inv_std = self.inv_std.cuda()
+        self.scale = [self.mean, self.inv_std]

+    def encode(self, videos, args):
+        if hasattr(args, "cpu_offload") and args.cpu_offload:
+            self.to_cuda()
+        out = self.model.encode(videos.unsqueeze(0), self.scale).float().squeeze(0)
+        if hasattr(args, "cpu_offload") and args.cpu_offload:
+            self.to_cpu()
+        return out

     def decode(self, zs, generator, config):
-        return self.model.decode(zs.unsqueeze(0), self.scale).float().clamp_(-1, 1)
+        if config.cpu_offload:
+            self.to_cuda()
+        images = self.model.decode(zs.unsqueeze(0), self.scale).float().clamp_(-1, 1)
+        if config.cpu_offload:
+            images = images.cpu().float()
+            self.to_cpu()
+        return images
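
Note: with these changes the VAE manages its own placement: encode and decode hop to CUDA for the duration of a call and drop back to CPU afterwards, and decode additionally copies the result to CPU before unloading. One asymmetry worth flagging: encode guards with hasattr(args, "cpu_offload") while decode reads config.cpu_offload directly, so decode's config must define the attribute. A hedged usage sketch; SimpleNamespace stands in for the config object, and the Wan2_2_VAE constructor and tensor shapes are elided as illustrative:

```python
from types import SimpleNamespace

cfg = SimpleNamespace(cpu_offload=True)

# vae = Wan2_2_VAE(...)            # constructed as elsewhere in the repo
# z = vae.encode(img, cfg)         # to_cuda() -> encode -> to_cpu()
# out = vae.decode(z, None, cfg)   # to_cuda() -> decode -> out.cpu() -> to_cpu()
#
# Between calls the weights, mean, and inv_std all sit on the CPU, so GPU
# memory is held only while an encode or decode is actually in flight.
```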