Unverified Commit 1f196a09 authored by Juan Carrasquilla's avatar Juan Carrasquilla Committed by GitHub
Browse files

Changed variable name from "h" to "hidden_states" (#285)



* Changed variable name from "h" to "hidden_states"

Per issue #198 , changed variable name from "h" to "hidden_states" in the forward function only. I am happy to change any other variable names, please advise recommended new names.

* Update src/diffusers/models/resnet.py
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent 034673bb
......@@ -328,39 +328,39 @@ class ResnetBlock2D(nn.Module):
if self.use_nin_shortcut:
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x, temb, hey=False):
h = x
def forward(self, x, temb):
hidden_states = x
# make sure hidden states is in float32
# when running in half-precision
h = self.norm1(h.float()).type(h.dtype)
h = self.nonlinearity(h)
hidden_states = self.norm1(hidden_states.float()).type(hidden_states.dtype)
hidden_states = self.nonlinearity(hidden_states)
if self.upsample is not None:
x = self.upsample(x)
h = self.upsample(h)
hidden_states = self.upsample(hidden_states)
elif self.downsample is not None:
x = self.downsample(x)
h = self.downsample(h)
hidden_states = self.downsample(hidden_states)
h = self.conv1(h)
hidden_states = self.conv1(hidden_states)
if temb is not None:
temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
h = h + temb
hidden_states = hidden_states + temb
# make sure hidden states is in float32
# when running in half-precision
h = self.norm2(h.float()).type(h.dtype)
h = self.nonlinearity(h)
hidden_states = self.norm2(hidden_states.float()).type(hidden_states.dtype)
hidden_states = self.nonlinearity(hidden_states)
h = self.dropout(h)
h = self.conv2(h)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None:
x = self.conv_shortcut(x)
out = (x + h) / self.output_scale_factor
out = (x + hidden_states) / self.output_scale_factor
return out
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment