"tests/python/common/test_heterograph.py" did not exist on "51651ecadc70cb4b254881e1211a92dea9174cdb"
Unverified Commit c6f8c310 authored by 小咩Goat's avatar 小咩Goat Committed by GitHub
Browse files

Fix forward pass in UNetMotionModel when gradient checkpoint is enabled (#6744)



fix #6742
Co-authored-by: default avatarDhruv Nair <dhruv.nair@gmail.com>
parent 64909f17
...@@ -1031,12 +1031,6 @@ class DownBlockMotion(nn.Module): ...@@ -1031,12 +1031,6 @@ class DownBlockMotion(nn.Module):
hidden_states = torch.utils.checkpoint.checkpoint( hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, scale create_custom_forward(resnet), hidden_states, temb, scale
) )
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(motion_module),
hidden_states.requires_grad_(),
temb,
num_frames,
)
else: else:
hidden_states = resnet(hidden_states, temb, scale=scale) hidden_states = resnet(hidden_states, temb, scale=scale)
...@@ -1563,11 +1557,6 @@ class UpBlockMotion(nn.Module): ...@@ -1563,11 +1557,6 @@ class UpBlockMotion(nn.Module):
hidden_states = torch.utils.checkpoint.checkpoint( hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb create_custom_forward(resnet), hidden_states, temb
) )
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
)
else: else:
hidden_states = resnet(hidden_states, temb, scale=scale) hidden_states = resnet(hidden_states, temb, scale=scale)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment