Commit 0b755a97 authored by gushiqiao's avatar gushiqiao Committed by GitHub
Browse files

Fix offload bug in new dist infer

Fix offload bug in new dist infer
parents 88433448 ff66b814
...@@ -18,7 +18,7 @@ class WanAudioPostInfer(WanPostInfer): ...@@ -18,7 +18,7 @@ class WanAudioPostInfer(WanPostInfer):
self.scheduler = scheduler self.scheduler = scheduler
@torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE()) @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
def infer(self, weights, x, pre_infer_out): def infer(self, x, pre_infer_out):
x = x[:, : pre_infer_out.valid_patch_length] x = x[:, : pre_infer_out.valid_patch_length]
x = self.unpatchify(x, pre_infer_out.grid_sizes) x = self.unpatchify(x, pre_infer_out.grid_sizes)
......
...@@ -15,7 +15,7 @@ class WanPostInfer: ...@@ -15,7 +15,7 @@ class WanPostInfer:
self.scheduler = scheduler self.scheduler = scheduler
@torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE()) @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
def infer(self, weights, x, pre_infer_out): def infer(self, x, pre_infer_out):
x = self.unpatchify(x, pre_infer_out.grid_sizes) x = self.unpatchify(x, pre_infer_out.grid_sizes)
if self.clean_cuda_cache: if self.clean_cuda_cache:
......
...@@ -39,8 +39,8 @@ class WanTransformerInfer(BaseTransformerInfer): ...@@ -39,8 +39,8 @@ class WanTransformerInfer(BaseTransformerInfer):
self.seq_p_group = None self.seq_p_group = None
if self.config.get("cpu_offload", False): if self.config.get("cpu_offload", False):
if torch.cuda.get_device_capability(0) == (9, 0): # if torch.cuda.get_device_capability(0) == (9, 0):
assert self.config["self_attn_1_type"] != "sage_attn2" # assert self.config["self_attn_1_type"] != "sage_attn2"
if "offload_ratio" in self.config: if "offload_ratio" in self.config:
offload_ratio = self.config["offload_ratio"] offload_ratio = self.config["offload_ratio"]
else: else:
......
...@@ -225,12 +225,10 @@ class WanModel: ...@@ -225,12 +225,10 @@ class WanModel:
# Initialize weight containers # Initialize weight containers
self.pre_weight = self.pre_weight_class(self.config) self.pre_weight = self.pre_weight_class(self.config)
self.post_weight = self.post_weight_class(self.config)
self.transformer_weights = self.transformer_weight_class(self.config) self.transformer_weights = self.transformer_weight_class(self.config)
# Load weights into containers # Load weights into containers
self.pre_weight.load(self.original_weight_dict) self.pre_weight.load(self.original_weight_dict)
self.post_weight.load(self.original_weight_dict)
self.transformer_weights.load(self.original_weight_dict) self.transformer_weights.load(self.original_weight_dict)
def _load_weights_distribute(self, weight_dict, is_weight_loader): def _load_weights_distribute(self, weight_dict, is_weight_loader):
...@@ -303,12 +301,10 @@ class WanModel: ...@@ -303,12 +301,10 @@ class WanModel:
def to_cpu(self): def to_cpu(self):
self.pre_weight.to_cpu() self.pre_weight.to_cpu()
self.post_weight.to_cpu()
self.transformer_weights.to_cpu() self.transformer_weights.to_cpu()
def to_cuda(self): def to_cuda(self):
self.pre_weight.to_cuda() self.pre_weight.to_cuda()
self.post_weight.to_cuda()
self.transformer_weights.to_cuda() self.transformer_weights.to_cuda()
@torch.no_grad() @torch.no_grad()
...@@ -318,7 +314,7 @@ class WanModel: ...@@ -318,7 +314,7 @@ class WanModel:
self.to_cuda() self.to_cuda()
elif self.offload_granularity != "model": elif self.offload_granularity != "model":
self.pre_weight.to_cuda() self.pre_weight.to_cuda()
self.post_weight.to_cuda() self.transformer_weights.post_weights_to_cuda()
if self.transformer_infer.mask_map is None: if self.transformer_infer.mask_map is None:
_, c, h, w = self.scheduler.latents.shape _, c, h, w = self.scheduler.latents.shape
...@@ -356,7 +352,7 @@ class WanModel: ...@@ -356,7 +352,7 @@ class WanModel:
self.to_cpu() self.to_cpu()
elif self.offload_granularity != "model": elif self.offload_granularity != "model":
self.pre_weight.to_cpu() self.pre_weight.to_cpu()
self.post_weight.to_cpu() self.transformer_weights.post_weights_to_cpu()
@torch.no_grad() @torch.no_grad()
def _infer_cond_uncond(self, inputs, positive=True): def _infer_cond_uncond(self, inputs, positive=True):
...@@ -370,7 +366,7 @@ class WanModel: ...@@ -370,7 +366,7 @@ class WanModel:
if self.config["seq_parallel"]: if self.config["seq_parallel"]:
x = self._seq_parallel_post_process(x) x = self._seq_parallel_post_process(x)
noise_pred = self.post_infer.infer(self.post_weight, x, pre_infer_out)[0] noise_pred = self.post_infer.infer(x, pre_infer_out)[0]
if self.clean_cuda_cache: if self.clean_cuda_cache:
del x, pre_infer_out del x, pre_infer_out
......
...@@ -36,6 +36,16 @@ class WanTransformerWeights(WeightModule): ...@@ -36,6 +36,16 @@ class WanTransformerWeights(WeightModule):
for phase in block.compute_phases: for phase in block.compute_phases:
phase.clear() phase.clear()
def post_weights_to_cuda(self):
    """Move only the post-processing weights (final norm, output head, and
    head modulation) onto the GPU.

    Used by the CPU-offload path so the post-infer step can run on device
    without paging in the full transformer weight set.
    """
    # NOTE(review): assumes each submodule exposes a to_cuda() method — the
    # WeightModule convention visible elsewhere in this diff; confirm upstream.
    for module in (self.norm, self.head, self.head_modulation):
        module.to_cuda()
def post_weights_to_cpu(self):
    """Offload the post-processing weights (final norm, output head, and
    head modulation) back to host memory.

    Counterpart of post_weights_to_cuda(); called after the post-infer step
    when CPU offloading is enabled so GPU memory is freed between stages.
    """
    # NOTE(review): assumes each submodule exposes a to_cpu() method — the
    # WeightModule convention visible elsewhere in this diff; confirm upstream.
    for module in (self.norm, self.head, self.head_modulation):
        module.to_cpu()
class WanTransformerAttentionBlock(WeightModule): class WanTransformerAttentionBlock(WeightModule):
def __init__(self, block_index, task, mm_type, config): def __init__(self, block_index, task, mm_type, config):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment