Unverified Commit c6f8c310 authored by 小咩Goat's avatar 小咩Goat Committed by GitHub
Browse files

Fix forward pass in UNetMotionModel when gradient checkpoint is enabled (#6744)



fix #6742
Co-authored-by: default avatarDhruv Nair <dhruv.nair@gmail.com>
parent 64909f17
......@@ -1031,12 +1031,6 @@ class DownBlockMotion(nn.Module):
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, scale
)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(motion_module),
hidden_states.requires_grad_(),
temb,
num_frames,
)
else:
hidden_states = resnet(hidden_states, temb, scale=scale)
......@@ -1563,11 +1557,6 @@ class UpBlockMotion(nn.Module):
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb
)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
)
else:
hidden_states = resnet(hidden_states, temb, scale=scale)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment