Unverified Commit 56c00370 authored by Pedro Cuenca's avatar Pedro Cuenca Committed by GitHub
Browse files

Use `expand` instead of ones to broadcast tensor (#373)

Use `expand` instead of ones to broadcast tensor.

As suggested by @bes-dev. According the documentation this shouldn't
take any memory - it just plays with the strides.
parent 7a1229fa
...@@ -152,7 +152,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): ...@@ -152,7 +152,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
timesteps = timesteps[None].to(sample.device) timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device) timesteps = timesteps.expand(sample.shape[0])
t_emb = self.time_proj(timesteps) t_emb = self.time_proj(timesteps)
emb = self.time_embedding(t_emb) emb = self.time_embedding(t_emb)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment