Commit a395cc0a authored by gushiqiao, committed by GitHub

Fix offload bugs and support wan2.2_vae offload
parents b723fc89 aa88b371
@@ -37,6 +37,9 @@ class WanModel:
     def __init__(self, model_path, config, device):
         self.model_path = model_path
         self.config = config
+        self.cpu_offload = self.config.get("cpu_offload", False)
+        self.offload_granularity = self.config.get("offload_granularity", "block")
         self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
         self.dit_quantized = self.config.mm_config.get("mm_type", "Default") != "Default"
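The two new attributes cache offload-related config keys next to the existing `clean_cuda_cache` flag. A minimal sketch of a config that exercises them; only the three key names are taken from this diff, the values and surrounding structure are illustrative:

```python
# Illustrative config only: the key names come from the diff above,
# the values are assumptions.
config = {
    "cpu_offload": True,             # keep weights on CPU between uses
    "offload_granularity": "model",  # "model" moves the whole DiT at once;
                                     # anything else (default "block") moves
                                     # only pre/post weights around each step
    "clean_cuda_cache": True,        # free cached CUDA memory after each step
}
```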
@@ -202,15 +205,18 @@ class WanModel:
     @torch.no_grad()
     def infer(self, inputs):
+        if self.cpu_offload:
+            if self.offload_granularity == "model" and self.scheduler.step_index == 0:
+                self.to_cuda()
+            elif self.offload_granularity != "model":
+                self.pre_weight.to_cuda()
+                self.post_weight.to_cuda()
         if self.transformer_infer.mask_map is None:
             _, c, h, w = self.scheduler.latents.shape
             video_token_num = c * (h // 2) * (w // 2)
             self.transformer_infer.mask_map = MaskMap(video_token_num, c)
-        if self.config.get("cpu_offload", False):
-            self.pre_weight.to_cuda()
-            self.post_weight.to_cuda()
         embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
         x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
         noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
@@ -228,14 +234,17 @@ class WanModel:
         self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (self.scheduler.noise_pred - noise_pred_uncond)
-        if self.config.get("cpu_offload", False):
-            self.pre_weight.to_cpu()
-            self.post_weight.to_cpu()
-        if self.clean_cuda_cache:
-            del x, embed, pre_infer_out, noise_pred_uncond, grid_sizes
-            torch.cuda.empty_cache()
+        if self.clean_cuda_cache:
+            del x, embed, pre_infer_out, noise_pred_uncond, grid_sizes
+            torch.cuda.empty_cache()
+        if self.cpu_offload:
+            if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1:
+                self.to_cpu()
+            elif self.offload_granularity != "model":
+                self.pre_weight.to_cpu()
+                self.post_weight.to_cpu()
 class Wan22MoeModel(WanModel):
     def _load_ckpt(self, use_bf16, skip_bf16):
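Taken together, the two hunks above make the offload schedule depend on granularity: with `offload_granularity == "model"` the whole model is onloaded once when `step_index == 0` and offloaded once after the final step, while any other granularity (the default is "block") moves only the pre/post weights on and off around every step, with the transformer blocks presumably streamed inside `transformer_infer`. A self-contained sketch of that schedule; `_Model`, `_Weights`, and `run_steps` are stand-ins, not code from this repository:

```python
# Stand-in sketch of the two offload schedules; only the timing mirrors the diff.
class _Weights:
    def to_cuda(self): print("pre/post weights -> cuda")
    def to_cpu(self): print("pre/post weights -> cpu")

class _Model:
    def __init__(self):
        self.pre_weight, self.post_weight = _Weights(), _Weights()
    def to_cuda(self): print("whole model -> cuda")
    def to_cpu(self): print("whole model -> cpu")

def run_steps(model, n_steps, granularity):
    for step in range(n_steps):
        if granularity == "model":
            if step == 0:
                model.to_cuda()          # onload once, before the first step
        else:
            model.pre_weight.to_cuda()   # onload the small weights every step
            model.post_weight.to_cuda()
        # ... one denoising step runs here ...
        if granularity == "model":
            if step == n_steps - 1:
                model.to_cpu()           # offload once, after the last step
        else:
            model.pre_weight.to_cpu()
            model.post_weight.to_cpu()

run_steps(_Model(), 3, "model")
```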
@@ -248,6 +257,10 @@ class Wan22MoeModel(WanModel):
     @torch.no_grad()
     def infer(self, inputs):
+        if self.cpu_offload and self.offload_granularity != "model":
+            self.pre_weight.to_cuda()
+            self.post_weight.to_cuda()
+
         embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
         x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
         noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
@@ -260,3 +273,7 @@ class Wan22MoeModel(WanModel):
         noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
         self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (self.scheduler.noise_pred - noise_pred_uncond)
+
+        if self.cpu_offload and self.offload_granularity != "model":
+            self.pre_weight.to_cpu()
+            self.post_weight.to_cpu()
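`Wan22MoeModel.infer` toggles only the pre/post weights and skips the `"model"` granularity branch, presumably because whole-model moves for the MoE variant are coordinated where its experts are swapped in and out. The repeated guard can be read as a single predicate; `should_toggle_pre_post` below is a hypothetical name, not part of the diff:

```python
# Hypothetical helper expressing the MoE guard above.
def should_toggle_pre_post(cpu_offload: bool, offload_granularity: str) -> bool:
    # True only when offload is on and granularity is finer than "model".
    return cpu_offload and offload_granularity != "model"
```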
@@ -418,5 +418,5 @@ class Wan22DenseRunner(WanRunner):
         return vae_encoder_out
 
     def get_vae_encoder_output(self, img):
-        z = self.vae_encoder.encode(img)
+        z = self.vae_encoder.encode(img, self.config)
         return z
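The runner now threads its config into the VAE so that `encode` can honour `cpu_offload` (see the new `Wan2_2_VAE.encode` signature further down). A hedged call-site sketch, where `runner` and `img` are assumed to exist:

```python
# Assumed call site: img is the conditioning image/video tensor and
# runner.config carries the cpu_offload flag read by Wan2_2_VAE.encode.
z = runner.vae_encoder.encode(img, runner.config)
```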
@@ -844,7 +844,7 @@ class Wan2_2_VAE:
         self.dtype = dtype
         self.device = device
 
-        mean = torch.tensor(
+        self.mean = torch.tensor(
             [
                 -0.2289,
                 -0.0052,
@@ -898,7 +898,7 @@ class Wan2_2_VAE:
             dtype=dtype,
             device=device,
         )
-        std = torch.tensor(
+        self.std = torch.tensor(
             [
                 0.4765,
                 1.0364,
@@ -952,8 +952,8 @@ class Wan2_2_VAE:
             dtype=dtype,
             device=device,
         )
-        self.scale = [mean, 1.0 / std]
+        self.inv_std = 1.0 / self.std
+        self.scale = [self.mean, self.inv_std]
 
         # init model
         self.model = (
             _video_vae(
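Promoting `mean` and `std` from locals to attributes, plus precomputing `self.inv_std`, is what lets the new `to_cpu`/`to_cuda` methods below move the normalization stats together with the weights and rebuild `self.scale` on the right device. For reference, a `[mean, 1/std]` pair is typically applied to latents as sketched here; the actual application inside `_video_vae` is not shown in this diff:

```python
# Typical latent normalization with a [mean, 1/std] pair; shapes assume
# (batch, channels, frames, height, width) latents with per-channel stats.
def normalize(z, scale):
    mean, inv_std = scale
    return (z - mean.view(1, -1, 1, 1, 1)) * inv_std.view(1, -1, 1, 1, 1)

def denormalize(z, scale):
    mean, inv_std = scale
    return z / inv_std.view(1, -1, 1, 1, 1) + mean.view(1, -1, 1, 1, 1)
```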
@@ -968,25 +968,35 @@ class Wan2_2_VAE:
         .to(device)
     )
 
-    def encode(self, videos):
-        # try:
-        #     if not isinstance(videos, list):
-        #         raise TypeError("videos should be a list")
-        #     with amp.autocast(dtype=self.dtype):
-        #         return [
-        #             self.model.encode(u.unsqueeze(0),
-        #                               self.scale).float().squeeze(0)
-        #             for u in videos
-        #         ]
-        # except TypeError as e:
-        #     logging.info(e)
-        #     return None
-
-        # print(1111111)
-        # print(self.model.encode(videos.unsqueeze(0), self.scale).float().shape)
-        # exit()
-        return self.model.encode(videos.unsqueeze(0), self.scale).float().squeeze(0)
+    def to_cpu(self):
+        self.model.encoder = self.model.encoder.to("cpu")
+        self.model.decoder = self.model.decoder.to("cpu")
+        self.model = self.model.to("cpu")
+        self.mean = self.mean.cpu()
+        self.inv_std = self.inv_std.cpu()
+        self.scale = [self.mean, self.inv_std]
+
+    def to_cuda(self):
+        self.model.encoder = self.model.encoder.to("cuda")
+        self.model.decoder = self.model.decoder.to("cuda")
+        self.model = self.model.to("cuda")
+        self.mean = self.mean.cuda()
+        self.inv_std = self.inv_std.cuda()
+        self.scale = [self.mean, self.inv_std]
+
+    def encode(self, videos, args):
+        if hasattr(args, "cpu_offload") and args.cpu_offload:
+            self.to_cuda()
+        out = self.model.encode(videos.unsqueeze(0), self.scale).float().squeeze(0)
+        if hasattr(args, "cpu_offload") and args.cpu_offload:
+            self.to_cpu()
+        return out
 
     def decode(self, zs, generator, config):
-        return self.model.decode(zs.unsqueeze(0), self.scale).float().clamp_(-1, 1)
+        if config.cpu_offload:
+            self.to_cuda()
+        images = self.model.decode(zs.unsqueeze(0), self.scale).float().clamp_(-1, 1)
+        if config.cpu_offload:
+            images = images.cpu().float()
+            self.to_cpu()
+        return images
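With these methods the VAE follows the same onload-compute-offload pattern as the DiT: weights and stats stay on CPU, move to CUDA only around `encode`/`decode`, and `decode` additionally copies its output to CPU before offloading. A round-trip sketch under the assumption that `vae`, `video`, and `cfg` exist; `generator` is passed as `None` because the code shown does not use it:

```python
# Assumed round trip with offload enabled; cfg.cpu_offload is the flag the
# new encode/decode paths check.
latents = vae.encode(video, cfg)         # onload -> encode -> offload
images = vae.decode(latents, None, cfg)  # returns a CPU float tensor when
                                         # cfg.cpu_offload is set
```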