Commit 0b755a97 authored by gushiqiao, committed by GitHub

Fix offload bug in new dist infer

parents 88433448 ff66b814
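
Summary, as reflected in the hunks below: the separate post_weight container is dropped from WanModel; the post-processing weights (norm, head, head_modulation) are reached through WanTransformerWeights, which gains post_weights_to_cuda()/post_weights_to_cpu() helpers so the offload path can move them explicitly; WanPostInfer.infer() and WanAudioPostInfer.infer() lose their weights parameter accordingly; and a Hopper-only assertion that forbade sage_attn2 together with cpu_offload is commented out.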
@@ -18,7 +18,7 @@ class WanAudioPostInfer(WanPostInfer):
         self.scheduler = scheduler

     @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
-    def infer(self, weights, x, pre_infer_out):
+    def infer(self, x, pre_infer_out):
         x = x[:, : pre_infer_out.valid_patch_length]
         x = self.unpatchify(x, pre_infer_out.grid_sizes)
......
@@ -15,7 +15,7 @@ class WanPostInfer:
         self.scheduler = scheduler

     @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
-    def infer(self, weights, x, pre_infer_out):
+    def infer(self, x, pre_infer_out):
         x = self.unpatchify(x, pre_infer_out.grid_sizes)
         if self.clean_cuda_cache:
......
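
For context, here is a minimal sketch of the signature change, using hypothetical stand-in classes rather than the repository's real ones: the post-infer object now owns (or can reach) its weights, so callers pass only activations.

import torch

# Hedged sketch of the call-signature change; PostInfer, head, and
# norm are illustrative stand-ins, not the repository's real classes.
class PostInfer:
    def __init__(self, head: torch.nn.Linear, norm: torch.nn.LayerNorm):
        # Weights are bound once at construction instead of being
        # passed into every infer() call.
        self.head = head
        self.norm = norm

    @torch.no_grad()
    def infer(self, x: torch.Tensor, pre_infer_out=None) -> torch.Tensor:
        # Old: infer(self, weights, x, pre_infer_out)
        # New: infer(self, x, pre_infer_out)
        return self.head(self.norm(x))

dim, out_dim = 16, 8
post = PostInfer(torch.nn.Linear(dim, out_dim), torch.nn.LayerNorm(dim))
y = post.infer(torch.randn(2, 4, dim))
print(y.shape)  # torch.Size([2, 4, 8])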
@@ -39,8 +39,8 @@ class WanTransformerInfer(BaseTransformerInfer):
             self.seq_p_group = None

         if self.config.get("cpu_offload", False):
-            if torch.cuda.get_device_capability(0) == (9, 0):
-                assert self.config["self_attn_1_type"] != "sage_attn2"
+            # if torch.cuda.get_device_capability(0) == (9, 0):
+            #     assert self.config["self_attn_1_type"] != "sage_attn2"
             if "offload_ratio" in self.config:
                 offload_ratio = self.config["offload_ratio"]
             else:
......
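
The assertion disabled here had blocked sage_attn2 whenever CPU offload ran on a compute-capability-(9, 0) device (Hopper-class, e.g. H100). A standalone sketch of that kind of guard, assuming a CUDA device is visible; the function name and config values are illustrative:

import torch

def check_offload_attention_combo(config: dict) -> None:
    # Sketch of the now-disabled guard: on SM 9.0 (Hopper) GPUs,
    # cpu_offload combined with sage_attn2 was previously rejected.
    if not torch.cuda.is_available():
        return
    if config.get("cpu_offload", False) and torch.cuda.get_device_capability(0) == (9, 0):
        assert config.get("self_attn_1_type") != "sage_attn2", (
            "sage_attn2 with cpu_offload was blocked on SM 9.0 devices"
        )

check_offload_attention_combo({"cpu_offload": True, "self_attn_1_type": "flash_attn3"})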
@@ -225,12 +225,10 @@ class WanModel:
         # Initialize weight containers
         self.pre_weight = self.pre_weight_class(self.config)
-        self.post_weight = self.post_weight_class(self.config)
         self.transformer_weights = self.transformer_weight_class(self.config)

         # Load weights into containers
         self.pre_weight.load(self.original_weight_dict)
-        self.post_weight.load(self.original_weight_dict)
         self.transformer_weights.load(self.original_weight_dict)

     def _load_weights_distribute(self, weight_dict, is_weight_loader):
@@ -303,12 +301,10 @@ class WanModel:
     def to_cpu(self):
         self.pre_weight.to_cpu()
-        self.post_weight.to_cpu()
         self.transformer_weights.to_cpu()

     def to_cuda(self):
         self.pre_weight.to_cuda()
-        self.post_weight.to_cuda()
         self.transformer_weights.to_cuda()

     @torch.no_grad()
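
The to_cpu()/to_cuda() pair fans out to per-group weight containers. A hedged sketch of the pattern with stand-in classes (WeightContainer and Model are illustrative, not the repository's types); note that after this commit the post weights travel inside the transformer container:

import torch

class WeightContainer:
    # Illustrative stand-in for the repository's weight containers.
    def __init__(self, *tensors: torch.Tensor):
        self.tensors = list(tensors)

    def to_cpu(self):
        self.tensors = [t.to("cpu") for t in self.tensors]

    def to_cuda(self):
        # Requires a visible CUDA device; non_blocking pairs with
        # pinned host memory when available.
        self.tensors = [t.to("cuda", non_blocking=True) for t in self.tensors]

class Model:
    def __init__(self):
        self.pre_weight = WeightContainer(torch.randn(4, 4))
        # After this commit the post weights (norm / head /
        # head_modulation) live inside the transformer container
        # instead of a separate post_weight container.
        self.transformer_weights = WeightContainer(torch.randn(4, 4))

    def to_cpu(self):
        self.pre_weight.to_cpu()
        self.transformer_weights.to_cpu()

    def to_cuda(self):
        self.pre_weight.to_cuda()
        self.transformer_weights.to_cuda()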
@@ -318,7 +314,7 @@ class WanModel:
             self.to_cuda()
         elif self.offload_granularity != "model":
             self.pre_weight.to_cuda()
-            self.post_weight.to_cuda()
+            self.transformer_weights.post_weights_to_cuda()

         if self.transformer_infer.mask_map is None:
             _, c, h, w = self.scheduler.latents.shape
@@ -356,7 +352,7 @@ class WanModel:
             self.to_cpu()
         elif self.offload_granularity != "model":
             self.pre_weight.to_cpu()
-            self.post_weight.to_cpu()
+            self.transformer_weights.post_weights_to_cpu()

     @torch.no_grad()
     def _infer_cond_uncond(self, inputs, positive=True):
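
Putting the two hunks above together, the movement around one forward pass looks roughly like this sketch (run_with_offload is hypothetical; it assumes containers shaped like the stand-ins sketched earlier plus the new helper pair):

def run_with_offload(model, inputs, offload_granularity: str):
    # Hypothetical driver mirroring the two hunks above; `model` is
    # assumed to expose pre_weight, transformer_weights, and the new
    # post_weights_to_cuda()/post_weights_to_cpu() helpers.
    if offload_granularity == "model":
        model.to_cuda()                  # move everything up front
    else:
        model.pre_weight.to_cuda()       # small, always-needed pieces
        model.transformer_weights.post_weights_to_cuda()
    try:
        return model.infer(inputs)
    finally:
        if offload_granularity == "model":
            model.to_cpu()
        else:
            model.pre_weight.to_cpu()
            model.transformer_weights.post_weights_to_cpu()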
@@ -370,7 +366,7 @@ class WanModel:
         if self.config["seq_parallel"]:
            x = self._seq_parallel_post_process(x)

-        noise_pred = self.post_infer.infer(self.post_weight, x, pre_infer_out)[0]
+        noise_pred = self.post_infer.infer(x, pre_infer_out)[0]

         if self.clean_cuda_cache:
             del x, pre_infer_out
......
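
Note the trailing [0] on both sides of the hunk above: infer() still returns a sequence whose first element is the noise prediction. Only the weights argument was dropped; the return convention is unchanged.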
@@ -36,6 +36,16 @@ class WanTransformerWeights(WeightModule):
             for phase in block.compute_phases:
                 phase.clear()

+    def post_weights_to_cuda(self):
+        self.norm.to_cuda()
+        self.head.to_cuda()
+        self.head_modulation.to_cuda()
+
+    def post_weights_to_cpu(self):
+        self.norm.to_cpu()
+        self.head.to_cpu()
+        self.head_modulation.to_cpu()
+

 class WanTransformerAttentionBlock(WeightModule):
     def __init__(self, block_index, task, mm_type, config):
......
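
Finally, a self-contained sketch of what the two new helpers do, with TensorHolder standing in for the repository's per-weight wrapper objects (names and shapes here are illustrative):

import torch

class TensorHolder:
    # Stand-in for a per-weight wrapper exposing to_cuda()/to_cpu().
    def __init__(self, t: torch.Tensor):
        self.t = t

    def to_cuda(self):
        self.t = self.t.to("cuda", non_blocking=True)

    def to_cpu(self):
        self.t = self.t.to("cpu")

class TransformerWeightsSketch:
    def __init__(self, dim: int, out_dim: int):
        self.norm = TensorHolder(torch.ones(dim))
        self.head = TensorHolder(torch.randn(out_dim, dim))
        self.head_modulation = TensorHolder(torch.randn(2, dim))

    # Mirrors the pair added by this commit: move only the
    # post-processing weights, leaving the transformer blocks alone.
    def post_weights_to_cuda(self):
        self.norm.to_cuda()
        self.head.to_cuda()
        self.head_modulation.to_cuda()

    def post_weights_to_cpu(self):
        self.norm.to_cpu()
        self.head.to_cpu()
        self.head_modulation.to_cpu()

weights = TransformerWeightsSketch(dim=16, out_dim=8)
if torch.cuda.is_available():
    weights.post_weights_to_cuda()   # before the post-processing step
    weights.post_weights_to_cpu()    # release device memory afterwards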