"tasks/dialctrl/data.py" did not exist on "3c363d5709143f1e3a210f9e54bd80a7356d9e14"
Commit 0b755a97 authored by gushiqiao, committed by GitHub

Fix offload bug in new dist infer

Fix offload bug in new dist infer: the post-infer methods no longer take a weights argument, and the post-processing weights (norm, head, head_modulation) are now moved between CPU and GPU through the transformer weight container via post_weights_to_cuda()/post_weights_to_cpu() instead of a separate post_weight container.

parents 88433448 ff66b814
@@ -18,7 +18,7 @@ class WanAudioPostInfer(WanPostInfer):
         self.scheduler = scheduler

     @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
-    def infer(self, weights, x, pre_infer_out):
+    def infer(self, x, pre_infer_out):
         x = x[:, : pre_infer_out.valid_patch_length]
         x = self.unpatchify(x, pre_infer_out.grid_sizes)

@@ -15,7 +15,7 @@ class WanPostInfer:
         self.scheduler = scheduler

     @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
-    def infer(self, weights, x, pre_infer_out):
+    def infer(self, x, pre_infer_out):
         x = self.unpatchify(x, pre_infer_out.grid_sizes)
         if self.clean_cuda_cache:

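The two hunks above drop the `weights` parameter from the post-infer entry points. Below is a minimal, hedged sketch of the resulting compile-gated, weight-free interface. `CHECK_ENABLE_GRAPH_MODE` exists in the source, but its environment-flag implementation here is an assumption, and `unpatchify` is a placeholder rather than the project's real method.

```python
import os
import torch


def CHECK_ENABLE_GRAPH_MODE() -> bool:
    # Assumed helper: gate graph/compile mode on an environment flag.
    return os.environ.get("ENABLE_GRAPH_MODE", "0") == "1"


class PostInferSketch:
    def __init__(self, scheduler=None, clean_cuda_cache=False):
        self.scheduler = scheduler
        self.clean_cuda_cache = clean_cuda_cache

    @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
    def infer(self, x, pre_infer_out):
        # The weights argument is gone: post-processing works from the latent
        # tokens and the metadata produced by the pre-infer stage.
        x = self.unpatchify(x, pre_infer_out.grid_sizes)
        if self.clean_cuda_cache:
            torch.cuda.empty_cache()
        return x

    def unpatchify(self, x, grid_sizes):
        # Placeholder: the real method folds patch tokens back into a dense
        # latent grid using grid_sizes.
        return x
```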
@@ -39,8 +39,8 @@ class WanTransformerInfer(BaseTransformerInfer):
             self.seq_p_group = None

         if self.config.get("cpu_offload", False):
-            if torch.cuda.get_device_capability(0) == (9, 0):
-                assert self.config["self_attn_1_type"] != "sage_attn2"
+            # if torch.cuda.get_device_capability(0) == (9, 0):
+            #     assert self.config["self_attn_1_type"] != "sage_attn2"
             if "offload_ratio" in self.config:
                 offload_ratio = self.config["offload_ratio"]
             else:

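This hunk comments out the Hopper-only guard that forbade `sage_attn2` together with CPU offload. A hedged sketch of the remaining config gate is below; the `1.0` default for `offload_ratio` is an assumption, since the diff cuts off before the original else-branch, and the ratio semantics in the example are likewise assumed.

```python
def resolve_offload_settings(config: dict) -> tuple[bool, float]:
    """Return (cpu_offload_enabled, offload_ratio) from a model config."""
    if not config.get("cpu_offload", False):
        return False, 0.0
    # The old code asserted that sage_attn2 could not be combined with CPU
    # offload on Hopper (compute capability 9.0); the commit comments that
    # guard out, so the combination is now permitted.
    offload_ratio = config.get("offload_ratio", 1.0)
    return True, offload_ratio


# Example: request offload for half of the weights (how offload_ratio is
# consumed is not shown in the diff).
enabled, ratio = resolve_offload_settings({"cpu_offload": True, "offload_ratio": 0.5})
```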
@@ -225,12 +225,10 @@ class WanModel:
         # Initialize weight containers
         self.pre_weight = self.pre_weight_class(self.config)
-        self.post_weight = self.post_weight_class(self.config)
         self.transformer_weights = self.transformer_weight_class(self.config)

         # Load weights into containers
         self.pre_weight.load(self.original_weight_dict)
-        self.post_weight.load(self.original_weight_dict)
         self.transformer_weights.load(self.original_weight_dict)

     def _load_weights_distribute(self, weight_dict, is_weight_loader):

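After this hunk, WanModel keeps two top-level containers (pre_weight and transformer_weights), each built from the config and loaded from the same shared state dict. The sketch below illustrates that container pattern under stated assumptions; `WeightContainer` and its prefix-based `load` are stand-ins, not the project's WeightModule implementation.

```python
import torch


class WeightContainer:
    """Stand-in for the project's weight containers (not the real WeightModule)."""

    def __init__(self, config: dict, prefix: str = ""):
        self.config = config
        self.prefix = prefix
        self.tensors: dict[str, torch.Tensor] = {}

    def load(self, weight_dict: dict) -> None:
        # Pull this container's entries out of the single shared state dict.
        self.tensors = {k: v for k, v in weight_dict.items() if k.startswith(self.prefix)}

    def to_cpu(self) -> None:
        self.tensors = {k: v.cpu() for k, v in self.tensors.items()}

    def to_cuda(self) -> None:
        self.tensors = {k: v.cuda() for k, v in self.tensors.items()}
```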
@@ -303,12 +301,10 @@ class WanModel:
     def to_cpu(self):
         self.pre_weight.to_cpu()
-        self.post_weight.to_cpu()
         self.transformer_weights.to_cpu()

     def to_cuda(self):
         self.pre_weight.to_cuda()
-        self.post_weight.to_cuda()
         self.transformer_weights.to_cuda()

     @torch.no_grad()

@@ -318,7 +314,7 @@ class WanModel:
             self.to_cuda()
         elif self.offload_granularity != "model":
             self.pre_weight.to_cuda()
-            self.post_weight.to_cuda()
+            self.transformer_weights.post_weights_to_cuda()

         if self.transformer_infer.mask_map is None:
             _, c, h, w = self.scheduler.latents.shape

@@ -356,7 +352,7 @@ class WanModel:
             self.to_cpu()
         elif self.offload_granularity != "model":
             self.pre_weight.to_cpu()
-            self.post_weight.to_cpu()
+            self.transformer_weights.post_weights_to_cpu()

     @torch.no_grad()
     def _infer_cond_uncond(self, inputs, positive=True):

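Taken together, the two hunks above define the offload lifecycle after the fix: with "model" granularity the whole model moves at once, otherwise only the small pre weights and the post weights are hoisted onto the GPU while transformer blocks follow their own offload schedule. A hedged sketch of that lifecycle is below; method names mirror the diff, but the surrounding class and the exact branching condition are stand-ins, not WanModel's actual code.

```python
class OffloadLifecycleSketch:
    def __init__(self, pre_weight, transformer_weights, offload_granularity="block"):
        self.pre_weight = pre_weight
        self.transformer_weights = transformer_weights
        self.offload_granularity = offload_granularity

    def before_infer(self):
        if self.offload_granularity == "model":
            self.to_cuda()  # move everything at once
        else:
            self.pre_weight.to_cuda()
            # Post weights (norm/head/head_modulation) now ride along with the
            # transformer weight container instead of a separate post_weight.
            self.transformer_weights.post_weights_to_cuda()

    def after_infer(self):
        if self.offload_granularity == "model":
            self.to_cpu()
        else:
            self.pre_weight.to_cpu()
            self.transformer_weights.post_weights_to_cpu()

    def to_cuda(self): ...
    def to_cpu(self): ...
```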
@@ -370,7 +366,7 @@ class WanModel:
         if self.config["seq_parallel"]:
             x = self._seq_parallel_post_process(x)

-        noise_pred = self.post_infer.infer(self.post_weight, x, pre_infer_out)[0]
+        noise_pred = self.post_infer.infer(x, pre_infer_out)[0]

         if self.clean_cuda_cache:
             del x, pre_infer_out

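The hunk above is the caller-side counterpart of the signature change. A hedged sketch of the tail of a denoising step after the fix follows; attribute names mirror the diff, `model` stands in for WanModel, and the `empty_cache()` call is an added assumption (the diff only shows the `del`).

```python
import torch


def finish_denoise_step(model, x, pre_infer_out):
    if model.config["seq_parallel"]:
        # Gather sequence-parallel shards back into a single tensor.
        x = model._seq_parallel_post_process(x)

    # post_infer no longer receives model.post_weight; it works purely from
    # the latent tokens and the pre-infer metadata.
    noise_pred = model.post_infer.infer(x, pre_infer_out)[0]

    if model.clean_cuda_cache:
        # Dropping the references, as the diff does, lets the allocator reuse
        # the memory.
        del x, pre_infer_out
        torch.cuda.empty_cache()

    return noise_pred
```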
@@ -36,6 +36,16 @@ class WanTransformerWeights(WeightModule):
             for phase in block.compute_phases:
                 phase.clear()

+    def post_weights_to_cuda(self):
+        self.norm.to_cuda()
+        self.head.to_cuda()
+        self.head_modulation.to_cuda()
+
+    def post_weights_to_cpu(self):
+        self.norm.to_cpu()
+        self.head.to_cpu()
+        self.head_modulation.to_cpu()
+

 class WanTransformerAttentionBlock(WeightModule):
     def __init__(self, block_index, task, mm_type, config):

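The new methods above let the output-side weights (norm, head, head_modulation) be moved on their own while the per-block weights keep their separate offload schedule. Below is a hedged, self-contained sketch of that selective-offload hook; the `Param` wrapper and the tensor shapes are assumptions, not the project's WeightModule leaves.

```python
import torch


class Param:
    def __init__(self, tensor: torch.Tensor):
        self.tensor = tensor

    def to_cuda(self):
        self.tensor = self.tensor.cuda()

    def to_cpu(self):
        self.tensor = self.tensor.cpu()


class TransformerWeightsSketch:
    def __init__(self, dim: int, out_dim: int):
        self.norm = Param(torch.ones(dim))
        self.head = Param(torch.zeros(out_dim, dim))
        self.head_modulation = Param(torch.zeros(2, dim))
        self.blocks = []  # per-block weights, streamed separately under offload

    def post_weights_to_cuda(self):
        self.norm.to_cuda()
        self.head.to_cuda()
        self.head_modulation.to_cuda()

    def post_weights_to_cpu(self):
        self.norm.to_cpu()
        self.head.to_cpu()
        self.head_modulation.to_cpu()
```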