Unverified Commit cb1b8b21 authored by Cheng Jin's avatar Cheng Jin Committed by GitHub
Browse files

Resolve stride mismatch in UNet's ResNet to support Torch DDP (#11098)

Modify UNet's ResNet implementation to resolve stride mismatch in Torch's DDP
parent 27916822
......@@ -366,7 +366,7 @@ class ResnetBlock2D(nn.Module):
hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None:
input_tensor = self.conv_shortcut(input_tensor)
input_tensor = self.conv_shortcut(input_tensor.contiguous())
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment