Unverified Commit bec2d8ea authored by YiYi Xu, committed by GitHub

Fix: Add _skip_keys for AutoencoderKLWan (#12523)

Register feat_cache and feat_idx as _skip_keys on AutoencoderKLWan so that accelerate's AlignDevicesHook leaves these shared, in-place-mutated cache inputs alone when moving tensors between devices, and pass them as explicit keyword arguments throughout the Wan VAE blocks.
parent a0a51eb0
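For context (not part of the commit): the Wan VAE threads a feature cache through every block, using a one-element list as a cursor that each call advances in place. A minimal sketch of that pattern follows; block_forward and the cache layout are illustrative, not the diffusers implementation.

```python
# Minimal sketch of the shared-mutable-state pattern used by the Wan VAE
# blocks (illustrative names, not the diffusers implementation).
def block_forward(x, feat_cache=None, feat_idx=[0]):
    if feat_cache is not None:
        idx = feat_idx[0]
        feat_cache[idx] = x   # stash this block's feature for the next chunk
        feat_idx[0] += 1      # advance the shared cursor in place
    return x + 1

feat_cache, feat_idx = [None, None], [0]
x = 0
for _ in range(2):
    x = block_forward(x, feat_cache=feat_cache, feat_idx=feat_idx)
assert feat_idx[0] == 2  # both calls mutated the same list object

# If a device-alignment hook copied feat_idx when moving inputs, each block
# would see a stale cursor and the cache would desynchronize; that is why
# these keys must be skipped.
```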
@@ -453,14 +453,14 @@ class WanMidBlock(nn.Module):
     def forward(self, x, feat_cache=None, feat_idx=[0]):
         # First residual block
-        x = self.resnets[0](x, feat_cache, feat_idx)
+        x = self.resnets[0](x, feat_cache=feat_cache, feat_idx=feat_idx)

         # Process through attention and residual blocks
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
             if attn is not None:
                 x = attn(x)
-            x = resnet(x, feat_cache, feat_idx)
+            x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx)

         return x
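A note on why the call sites switch from positional to keyword arguments (my reading of the change, not stated in the commit): accelerate's hooks move positional args and keyword args separately, and skip_keys is matched against dictionary keys, so only keyword-passed inputs can be skipped. A toy sketch of that matching rule:

```python
# Toy model of the skip rule (illustrative; not accelerate's actual code):
# tuples of positional args are moved wholesale, while dict entries whose
# key appears in skip_keys are passed through untouched.
def move_inputs(args, kwargs, skip_keys):
    moved_args = tuple(("moved", a) for a in args)  # no names, nothing to skip
    moved_kwargs = {k: v if k in skip_keys else ("moved", v) for k, v in kwargs.items()}
    return moved_args, moved_kwargs

args, kwargs = (1,), {"feat_idx": [0]}
_, moved_kwargs = move_inputs(args, kwargs, skip_keys=["feat_idx"])
assert moved_kwargs["feat_idx"] is kwargs["feat_idx"]  # skipped by name
```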
@@ -494,9 +494,9 @@ class WanResidualDownBlock(nn.Module):
     def forward(self, x, feat_cache=None, feat_idx=[0]):
         x_copy = x.clone()
         for resnet in self.resnets:
-            x = resnet(x, feat_cache, feat_idx)
+            x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx)
         if self.downsampler is not None:
-            x = self.downsampler(x, feat_cache, feat_idx)
+            x = self.downsampler(x, feat_cache=feat_cache, feat_idx=feat_idx)

         return x + self.avg_shortcut(x_copy)
@@ -598,12 +598,12 @@ class WanEncoder3d(nn.Module):
         ## downsamples
         for layer in self.down_blocks:
             if feat_cache is not None:
-                x = layer(x, feat_cache, feat_idx)
+                x = layer(x, feat_cache=feat_cache, feat_idx=feat_idx)
             else:
                 x = layer(x)

         ## middle
-        x = self.mid_block(x, feat_cache, feat_idx)
+        x = self.mid_block(x, feat_cache=feat_cache, feat_idx=feat_idx)

         ## head
         x = self.norm_out(x)
@@ -694,13 +694,13 @@ class WanResidualUpBlock(nn.Module):
         for resnet in self.resnets:
             if feat_cache is not None:
-                x = resnet(x, feat_cache, feat_idx)
+                x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx)
             else:
                 x = resnet(x)

         if self.upsampler is not None:
             if feat_cache is not None:
-                x = self.upsampler(x, feat_cache, feat_idx)
+                x = self.upsampler(x, feat_cache=feat_cache, feat_idx=feat_idx)
             else:
                 x = self.upsampler(x)
@@ -767,13 +767,13 @@ class WanUpBlock(nn.Module):
         """
         for resnet in self.resnets:
             if feat_cache is not None:
-                x = resnet(x, feat_cache, feat_idx)
+                x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx)
             else:
                 x = resnet(x)

         if self.upsamplers is not None:
             if feat_cache is not None:
-                x = self.upsamplers[0](x, feat_cache, feat_idx)
+                x = self.upsamplers[0](x, feat_cache=feat_cache, feat_idx=feat_idx)
             else:
                 x = self.upsamplers[0](x)
         return x
@@ -885,11 +885,11 @@ class WanDecoder3d(nn.Module):
         x = self.conv_in(x)

         ## middle
-        x = self.mid_block(x, feat_cache, feat_idx)
+        x = self.mid_block(x, feat_cache=feat_cache, feat_idx=feat_idx)

         ## upsamples
         for up_block in self.up_blocks:
-            x = up_block(x, feat_cache, feat_idx, first_chunk=first_chunk)
+            x = up_block(x, feat_cache=feat_cache, feat_idx=feat_idx, first_chunk=first_chunk)

         ## head
         x = self.norm_out(x)
@@ -961,6 +961,9 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
     """
     _supports_gradient_checkpointing = False
+    # keys to ignore when AlignDevicesHook moves inputs/outputs between devices
+    # these are shared mutable state modified in-place
+    _skip_keys = ["feat_cache", "feat_idx"]

     @register_to_config
     def __init__(
......
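What _skip_keys buys, assuming accelerate's send_to_device utility (which accepts a skip_keys argument): skipped entries keep their identity instead of being copied to the target device.

```python
# Hedged sketch: send_to_device recursively moves tensors in nested inputs,
# except mapping entries whose key is listed in skip_keys.
import torch
from accelerate.utils import send_to_device

inputs = {"x": torch.ones(2), "feat_cache": [torch.ones(2)], "feat_idx": [0]}
moved = send_to_device(inputs, "cpu", skip_keys=["feat_cache", "feat_idx"])
assert moved["feat_idx"] is inputs["feat_idx"]  # same list object, still shared
```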
@@ -251,6 +251,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
     _repeated_blocks = []
     _parallel_config = None
     _cp_plan = None
+    _skip_keys = None

     def __init__(self):
         super().__init__()
......
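The new class attribute defaults to None on ModelMixin, so existing models are unaffected; a model opts in by overriding it, roughly like this (MyCachedVAE is a hypothetical example):

```python
from diffusers import ConfigMixin, ModelMixin

class MyCachedVAE(ModelMixin, ConfigMixin):
    # hypothetical model that threads mutable caches through forward();
    # listing their kwarg names here keeps device-alignment hooks off them
    _skip_keys = ["feat_cache", "feat_idx"]
```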
@@ -866,6 +866,9 @@ def load_sub_model(
         # remove hooks
         remove_hook_from_module(loaded_sub_model, recurse=True)
         needs_offloading_to_cpu = device_map[""] == "cpu"
+        skip_keys = None
+        if hasattr(loaded_sub_model, "_skip_keys") and loaded_sub_model._skip_keys is not None:
+            skip_keys = loaded_sub_model._skip_keys

         if needs_offloading_to_cpu:
             dispatch_model(
@@ -874,9 +877,10 @@ def load_sub_model(
                 device_map=device_map,
                 force_hooks=True,
                 main_device=0,
+                skip_keys=skip_keys,
             )
         else:
-            dispatch_model(loaded_sub_model, device_map=device_map, force_hooks=True)
+            dispatch_model(loaded_sub_model, device_map=device_map, force_hooks=True, skip_keys=skip_keys)

         return loaded_sub_model
......
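Downstream, this path is hit whenever a pipeline is loaded with a device map; a hedged usage sketch (checkpoint id and device_map value are illustrative):

```python
import torch
from diffusers import WanPipeline

# Loading with a device map routes each sub-model through load_sub_model ->
# dispatch_model, which now forwards the model's _skip_keys so the VAE's
# feat_cache/feat_idx stay on the objects the caller passed in.
pipe = WanPipeline.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",  # assumed checkpoint id
    torch_dtype=torch.bfloat16,
    device_map="balanced",
)
```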