Unverified Commit f4f85413 authored by ethansmith2000, committed by GitHub

grad checkpointing (#4474)



* grad checkpointing

* fix make fix-copies

* fix

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent e1b5b8ba
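
What the diff below changes, in short: each cross-attention UNet block previously wrapped its attn call in torch.utils.checkpoint.checkpoint through a create_custom_forward closure, passing every argument positionally; after this commit the block calls attn(...) directly with keyword arguments, leaving gradient checkpointing of the attention path to the attention/transformer module itself. Below is a minimal, hedged sketch of the two call patterns; TinyAttn is a made-up stand-in module for illustration, not a diffusers class.

# Minimal sketch (hypothetical TinyAttn, not diffusers code) of the call-site change:
# the old path checkpoints the attention call through a positional-only closure,
# the new path calls the module directly with keyword arguments.
from typing import Any, Dict

import torch
import torch.nn as nn
import torch.utils.checkpoint


class TinyAttn(nn.Module):
    # Stand-in for a transformer/attention block that returns a tuple when return_dict=False.
    def __init__(self, dim: int = 32):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states, encoder_hidden_states=None, return_dict=False):
        out = self.proj(hidden_states)
        if encoder_hidden_states is not None:
            out = out + encoder_hidden_states.mean(dim=1, keepdim=True)
        return (out,)


def create_custom_forward(module, return_dict=None):
    # Same shape as the helper in the UNet blocks: a closure that
    # torch.utils.checkpoint can re-run during the backward pass.
    def custom_forward(*inputs):
        if return_dict is not None:
            return module(*inputs, return_dict=return_dict)
        return module(*inputs)

    return custom_forward


attn = TinyAttn()
hidden_states = torch.randn(2, 8, 32, requires_grad=True)
encoder_hidden_states = torch.randn(2, 4, 32)

# Old pattern: every argument passed positionally through the closure.
# use_reentrant=False needs torch >= 1.11, mirroring the ckpt_kwargs guard in the diff.
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False}
old_out = torch.utils.checkpoint.checkpoint(
    create_custom_forward(attn, return_dict=False),
    hidden_states,
    encoder_hidden_states,
    **ckpt_kwargs,
)[0]

# New pattern from this commit: direct call with keyword arguments; any
# checkpointing of the attention path now happens inside the module itself.
new_out = attn(
    hidden_states,
    encoder_hidden_states=encoder_hidden_states,
    return_dict=False,
)[0]

print(torch.allclose(old_out, new_out))  # True: same result, different recompute strategy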
@@ -648,16 +648,13 @@ class UNetMidBlock2DCrossAttn(nn.Module):
                     return custom_forward
 
                 ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
+                hidden_states = attn(
                     hidden_states,
-                    encoder_hidden_states,
-                    None,  # timestep
-                    None,  # class_labels
-                    cross_attention_kwargs,
-                    attention_mask,
-                    encoder_attention_mask,
-                    **ckpt_kwargs,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
                 )[0]
                 hidden_states = torch.utils.checkpoint.checkpoint(
                     create_custom_forward(resnet),
@@ -1035,16 +1032,13 @@ class CrossAttnDownBlock2D(nn.Module):
                     temb,
                     **ckpt_kwargs,
                 )
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
+                hidden_states = attn(
                     hidden_states,
-                    encoder_hidden_states,
-                    None,  # timestep
-                    None,  # class_labels
-                    cross_attention_kwargs,
-                    attention_mask,
-                    encoder_attention_mask,
-                    **ckpt_kwargs,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
                 )[0]
             else:
                 hidden_states = resnet(hidden_states, temb)
@@ -1711,13 +1705,12 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
                     return custom_forward
 
                 hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
+                hidden_states = attn(
                     hidden_states,
-                    encoder_hidden_states,
-                    mask,
-                    cross_attention_kwargs,
-                )[0]
+                    encoder_hidden_states=encoder_hidden_states,
+                    attention_mask=mask,
+                    **cross_attention_kwargs,
+                )
             else:
                 hidden_states = resnet(hidden_states, temb)
@@ -1912,15 +1905,13 @@ class KCrossAttnDownBlock2D(nn.Module):
                     temb,
                     **ckpt_kwargs,
                 )
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
+                hidden_states = attn(
                     hidden_states,
-                    encoder_hidden_states,
-                    temb,
-                    attention_mask,
-                    cross_attention_kwargs,
-                    encoder_attention_mask,
-                    **ckpt_kwargs,
+                    encoder_hidden_states=encoder_hidden_states,
+                    emb=temb,
+                    attention_mask=attention_mask,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    encoder_attention_mask=encoder_attention_mask,
                 )
             else:
                 hidden_states = resnet(hidden_states, temb)
@@ -2173,16 +2164,13 @@ class CrossAttnUpBlock2D(nn.Module):
                     temb,
                     **ckpt_kwargs,
                 )
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
+                hidden_states = attn(
                     hidden_states,
-                    encoder_hidden_states,
-                    None,  # timestep
-                    None,  # class_labels
-                    cross_attention_kwargs,
-                    attention_mask,
-                    encoder_attention_mask,
-                    **ckpt_kwargs,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
                 )[0]
             else:
                 hidden_states = resnet(hidden_states, temb)
@@ -2872,13 +2860,12 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
                     return custom_forward
 
                 hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
+                hidden_states = attn(
                     hidden_states,
-                    encoder_hidden_states,
-                    mask,
-                    cross_attention_kwargs,
-                )[0]
+                    encoder_hidden_states=encoder_hidden_states,
+                    attention_mask=mask,
+                    **cross_attention_kwargs,
+                )
             else:
                 hidden_states = resnet(hidden_states, temb)
@@ -3094,16 +3081,14 @@ class KCrossAttnUpBlock2D(nn.Module):
                     temb,
                     **ckpt_kwargs,
                 )
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
+                hidden_states = attn(
                     hidden_states,
-                    encoder_hidden_states,
-                    temb,
-                    attention_mask,
-                    cross_attention_kwargs,
-                    encoder_attention_mask,
-                    **ckpt_kwargs,
-                )[0]
+                    encoder_hidden_states=encoder_hidden_states,
+                    emb=temb,
+                    attention_mask=attention_mask,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    encoder_attention_mask=encoder_attention_mask,
+                )
             else:
                 hidden_states = resnet(hidden_states, temb)
                 hidden_states = attn(
......
@@ -1429,16 +1429,13 @@ class CrossAttnDownBlockFlat(nn.Module):
                     temb,
                     **ckpt_kwargs,
                 )
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
+                hidden_states = attn(
                     hidden_states,
-                    encoder_hidden_states,
-                    None,  # timestep
-                    None,  # class_labels
-                    cross_attention_kwargs,
-                    attention_mask,
-                    encoder_attention_mask,
-                    **ckpt_kwargs,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
                 )[0]
             else:
                 hidden_states = resnet(hidden_states, temb)
@@ -1668,16 +1665,13 @@ class CrossAttnUpBlockFlat(nn.Module):
                     temb,
                     **ckpt_kwargs,
                 )
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
+                hidden_states = attn(
                     hidden_states,
-                    encoder_hidden_states,
-                    None,  # timestep
-                    None,  # class_labels
-                    cross_attention_kwargs,
-                    attention_mask,
-                    encoder_attention_mask,
-                    **ckpt_kwargs,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
                 )[0]
             else:
                 hidden_states = resnet(hidden_states, temb)
@@ -1809,16 +1803,13 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
                     return custom_forward
 
                 ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
+                hidden_states = attn(
                     hidden_states,
-                    encoder_hidden_states,
-                    None,  # timestep
-                    None,  # class_labels
-                    cross_attention_kwargs,
-                    attention_mask,
-                    encoder_attention_mask,
-                    **ckpt_kwargs,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
                 )[0]
                 hidden_states = torch.utils.checkpoint.checkpoint(
                     create_custom_forward(resnet),
......
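
The new call sites only behave differently when gradient checkpointing is active during training. A hedged usage sketch follows, assuming the standard diffusers API; "runwayml/stable-diffusion-v1-5" is only an example model id.

# Usage sketch: enable gradient checkpointing on a UNet and run one training-style
# forward/backward so checkpointed blocks recompute activations during backward.
import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"  # example id; any UNet2DConditionModel works
)
unet.enable_gradient_checkpointing()  # sets gradient_checkpointing on the supporting sub-blocks
unet.train()

sample = torch.randn(1, 4, 64, 64)               # latent-sized input for this UNet
timestep = torch.tensor([10])
encoder_hidden_states = torch.randn(1, 77, 768)  # text-encoder hidden-state shape for SD 1.5

noise_pred = unet(sample, timestep, encoder_hidden_states=encoder_hidden_states).sample
noise_pred.mean().backward()  # activations are recomputed here, trading compute for lower peak memory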