"src/vscode:/vscode.git/clone" did not exist on "988c82227db1a41846a9aae5c83750dcfc334f66"
Unverified commit 5b972fbd authored by Michael Tkachuk, committed by GitHub

Enabling gradient checkpointing in eval() mode (#9878)

* refactored
parent 0be52c07
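
Across the hunks below, the gate for activation checkpointing changes from `self.training` to `torch.is_grad_enabled()`, so checkpointing also applies when an `eval()`-mode model runs under an active autograd graph (for example, when backpropagating through a frozen model), while `torch.no_grad()` inference is unaffected. A minimal sketch of the pattern, assuming a hypothetical `TinyStack` module rather than any actual diffusers class:

import torch
import torch.nn as nn


class TinyStack(nn.Module):
    # Hypothetical stand-in for a diffusers block stack, used only to illustrate the new gate.
    def __init__(self, dim: int = 16, depth: int = 3):
        super().__init__()
        self.gradient_checkpointing = True
        self.blocks = nn.ModuleList([nn.Linear(dim, dim) for _ in range(depth)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            # Old gate: self.training and self.gradient_checkpointing
            # New gate: checkpoint whenever autograd is recording, even in eval() mode.
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                x = torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x


model = TinyStack().eval()                      # eval()-mode model, e.g. a frozen backbone
x = torch.randn(2, 16, requires_grad=True)
model(x).sum().backward()                       # checkpointing active: grad is enabled
with torch.no_grad():
    _ = model(x)                                # plain inference: gate is False, no checkpointing
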
@@ -350,7 +350,7 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin):
         )
         for i, block in enumerate(self.transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
......
@@ -317,7 +317,7 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
         encoder_hidden_states = self.context_embedder(encoder_hidden_states)
         for index_block, block in enumerate(self.transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
......
@@ -340,7 +340,7 @@ class TransformerSpatioTemporalModel(nn.Module):
         # 2. Blocks
         for block, temporal_block in zip(self.transformer_blocks, self.temporal_transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 hidden_states = torch.utils.checkpoint.checkpoint(
                     block,
                     hidden_states,
......
@@ -859,7 +859,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
...
@@ -1257,7 +1257,7 @@ class CrossAttnDownBlock2D(nn.Module):
         blocks = list(zip(self.resnets, self.attentions))
         for i, (resnet, attn) in enumerate(blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
...
@@ -1371,7 +1371,7 @@ class DownBlock2D(nn.Module):
         output_states = ()
         for resnet in self.resnets:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
...
@@ -1859,7 +1859,7 @@ class ResnetDownsampleBlock2D(nn.Module):
         output_states = ()
         for resnet in self.resnets:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
...
@@ -2011,7 +2011,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
         mask = attention_mask
         for resnet, attn in zip(self.resnets, self.attentions):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
...
@@ -2106,7 +2106,7 @@ class KDownBlock2D(nn.Module):
         output_states = ()
         for resnet in self.resnets:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
...
@@ -2215,7 +2215,7 @@ class KCrossAttnDownBlock2D(nn.Module):
         output_states = ()
         for resnet, attn in zip(self.resnets, self.attentions):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
...
@@ -2520,7 +2520,7 @@ class CrossAttnUpBlock2D(nn.Module):
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
...
@@ -2653,7 +2653,7 @@ class UpBlock2D(nn.Module):
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
...
@@ -3183,7 +3183,7 @@ class ResnetUpsampleBlock2D(nn.Module):
             res_hidden_states_tuple = res_hidden_states_tuple[:-1]
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
...
@@ -3341,7 +3341,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
             res_hidden_states_tuple = res_hidden_states_tuple[:-1]
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
...
@@ -3444,7 +3444,7 @@ class KUpBlock2D(nn.Module):
             hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)
         for resnet in self.resnets:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
...
@@ -3572,7 +3572,7 @@ class KCrossAttnUpBlock2D(nn.Module):
             hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)
         for resnet, attn in zip(self.resnets, self.attentions):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
......
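
The `create_custom_forward` closures repeated in these hunks exist because the reentrant checkpoint implementation only forwards positional arguments, so keyword options such as `return_dict` are bound into a closure before the module is handed to `torch.utils.checkpoint.checkpoint`. A condensed sketch of that pattern combined with the new gate (`maybe_checkpoint` is an illustrative helper, not a diffusers function):

import torch


def create_custom_forward(module, return_dict=None):
    # Binds return_dict into a closure so checkpoint() only needs to pass tensors positionally.
    def custom_forward(*inputs):
        if return_dict is not None:
            return module(*inputs, return_dict=return_dict)
        return module(*inputs)

    return custom_forward


def maybe_checkpoint(module, *inputs, gradient_checkpointing=True, return_dict=None):
    # Grad-mode based gate introduced by this commit (previously: module.training).
    if torch.is_grad_enabled() and gradient_checkpointing:
        return torch.utils.checkpoint.checkpoint(
            create_custom_forward(module, return_dict), *inputs, use_reentrant=False
        )
    if return_dict is not None:
        return module(*inputs, return_dict=return_dict)
    return module(*inputs)
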
@@ -1078,7 +1078,7 @@ class UNetMidBlockSpatioTemporal(nn.Module):
         )
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            if self.training and self.gradient_checkpointing:  # TODO
+            if torch.is_grad_enabled() and self.gradient_checkpointing:  # TODO
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
...
@@ -1168,7 +1168,7 @@ class DownBlockSpatioTemporal(nn.Module):
     ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
         output_states = ()
         for resnet in self.resnets:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
...
@@ -1281,7 +1281,7 @@ class CrossAttnDownBlockSpatioTemporal(nn.Module):
         blocks = list(zip(self.resnets, self.attentions))
         for resnet, attn in blocks:
-            if self.training and self.gradient_checkpointing:  # TODO
+            if torch.is_grad_enabled() and self.gradient_checkpointing:  # TODO
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
...
@@ -1383,7 +1383,7 @@ class UpBlockSpatioTemporal(nn.Module):
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
...
@@ -1493,7 +1493,7 @@ class CrossAttnUpBlockSpatioTemporal(nn.Module):
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
-            if self.training and self.gradient_checkpointing:  # TODO
+            if torch.is_grad_enabled() and self.gradient_checkpointing:  # TODO
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
......
@@ -323,7 +323,7 @@ class DownBlockMotion(nn.Module):
         blocks = zip(self.resnets, self.motion_modules)
         for resnet, motion_module in blocks:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
...
@@ -513,7 +513,7 @@ class CrossAttnDownBlockMotion(nn.Module):
         blocks = list(zip(self.resnets, self.attentions, self.motion_modules))
         for i, (resnet, attn, motion_module) in enumerate(blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
...
@@ -732,7 +732,7 @@ class CrossAttnUpBlockMotion(nn.Module):
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
...
@@ -895,7 +895,7 @@ class UpBlockMotion(nn.Module):
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
...
@@ -1079,7 +1079,7 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
                 return_dict=False,
             )[0]
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
......
@@ -455,7 +455,7 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         level_outputs = []
         block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
             def create_custom_forward(module):
                 def custom_forward(*inputs):
...
@@ -504,7 +504,7 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         x = level_outputs[0]
         block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
             def create_custom_forward(module):
                 def custom_forward(*inputs):
......
@@ -181,7 +181,7 @@ class UVit2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         hidden_states = self.project_to_hidden(hidden_states)
         for layer in self.transformer_layers:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def layer_(*args):
                     return checkpoint(layer, *args)
......
@@ -1112,7 +1112,7 @@ class CrossAttnDownBlock2D(nn.Module):
         )
         for i in range(num_layers):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
...
@@ -1290,7 +1290,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
         )
         for i in range(len(self.resnets[1:])):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
...
@@ -1464,7 +1464,7 @@ class CrossAttnUpBlock2D(nn.Module):
             res_hidden_states_tuple = res_hidden_states_tuple[:-1]
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
......
@@ -167,7 +167,7 @@ class Blip2QFormerEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
             past_key_value = past_key_values[i] if past_key_values is not None else None
-            if getattr(self.config, "gradient_checkpointing", False) and self.training:
+            if getattr(self.config, "gradient_checkpointing", False) and torch.is_grad_enabled():
                 if use_cache:
                     logger.warning(
                         "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
......
@@ -1595,7 +1595,7 @@ class DownBlockFlat(nn.Module):
         output_states = ()
         for resnet in self.resnets:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
...
@@ -1732,7 +1732,7 @@ class CrossAttnDownBlockFlat(nn.Module):
         blocks = list(zip(self.resnets, self.attentions))
         for i, (resnet, attn) in enumerate(blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
...
@@ -1874,7 +1874,7 @@ class UpBlockFlat(nn.Module):
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
...
@@ -2033,7 +2033,7 @@ class CrossAttnUpBlockFlat(nn.Module):
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
...
@@ -2352,7 +2352,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
......
@@ -590,7 +590,7 @@ class GLMTransformer(torch.nn.Module):
         if not kv_caches:
             kv_caches = [None for _ in range(self.num_layers)]
         presents = () if use_cache else None
-        if self.gradient_checkpointing and self.training:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
             if use_cache:
                 logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
...
@@ -604,7 +604,7 @@ class GLMTransformer(torch.nn.Module):
                 all_hidden_states = all_hidden_states + (hidden_states,)
             layer = self._get_layer(index)
-            if self.gradient_checkpointing and self.training:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 layer_ret = torch.utils.checkpoint.checkpoint(
                     layer, hidden_states, attention_mask, rotary_pos_emb, kv_caches[index], use_cache
                 )
......
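
The Blip2QFormerEncoder and GLMTransformer hunks above also touch the `use_cache` path: with the new gate, checkpointing now wins over caching whenever autograd is recording, even for an `eval()`-mode model, so pure generation should run under `torch.no_grad()` to keep the cache. A rough sketch of that interaction (`run_layers` and its arguments are illustrative, not the actual module code):

import torch


def run_layers(layers, hidden_states, *, gradient_checkpointing, use_cache):
    presents = () if use_cache else None
    if torch.is_grad_enabled() and gradient_checkpointing and use_cache:
        # Mirrors the warning in the diff: caching is incompatible with checkpointing.
        use_cache = False
        presents = None
    for layer in layers:
        if torch.is_grad_enabled() and gradient_checkpointing:
            hidden_states = torch.utils.checkpoint.checkpoint(layer, hidden_states, use_reentrant=False)
        else:
            hidden_states = layer(hidden_states)
        if use_cache:
            presents = presents + (hidden_states,)
    return hidden_states, presents
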
@@ -675,7 +675,7 @@ class LDMBertEncoder(LDMBertPreTrainedModel):
         for idx, encoder_layer in enumerate(self.layers):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
-            if self.gradient_checkpointing and self.training:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
......
@@ -158,7 +158,7 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
         c_embed = self.cond_mapper(c)
         r_embed = self.gen_r_embedding(r)
-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
            def create_custom_forward(module):
                def custom_forward(*inputs):
......
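
End-to-end, the change means a frozen, `eval()`-mode diffusers model with `enable_gradient_checkpointing()` still trades compute for memory when gradients flow through it, for example when optimizing an input latent. A hedged usage sketch; the hub id is illustrative, substitute any UNet2DConditionModel checkpoint you have available:

import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"  # illustrative checkpoint id
)
unet.enable_gradient_checkpointing()
unet.requires_grad_(False)
unet.eval()

sample = torch.randn(1, 4, 64, 64, requires_grad=True)
timestep = torch.tensor([10])
encoder_hidden_states = torch.randn(1, 77, 768)

# Autograd is recording here, so the new torch.is_grad_enabled() gate fires and block
# activations are recomputed during backward, even though the UNet is frozen and in eval().
out = unet(sample, timestep, encoder_hidden_states).sample
out.mean().backward()
print(sample.grad.shape)

with torch.no_grad():
    _ = unet(sample, timestep, encoder_hidden_states).sample  # inference path, no checkpointing
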