Commit eef2327a authored by patil-suraj's avatar patil-suraj
Browse files

update input names

parent 7dc71897
...@@ -287,14 +287,14 @@ class UNetModel(ModelMixin, ConfigMixin): ...@@ -287,14 +287,14 @@ class UNetModel(ModelMixin, ConfigMixin):
self.norm_out = Normalize(block_in) self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
def forward(self, x, t): def forward(self, x, timesteps):
assert x.shape[2] == x.shape[3] == self.resolution assert x.shape[2] == x.shape[3] == self.resolution
if not torch.is_tensor(t): if not torch.is_tensor(timesteps):
t = torch.tensor([t], dtype=torch.long, device=x.device) timesteps = torch.tensor([timesteps], dtype=torch.long, device=x.device)
# timestep embedding # timestep embedding
temb = get_timestep_embedding(t, self.ch) temb = get_timestep_embedding(timesteps, self.ch)
temb = self.temb.dense[0](temb) temb = self.temb.dense[0](temb)
temb = nonlinearity(temb) temb = nonlinearity(temb)
temb = self.temb.dense[1](temb) temb = self.temb.dense[1](temb)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment