Unverified Commit fc188391 authored by Will Berman, committed by GitHub

class labels timestep embeddings projection dtype cast (#3137)

This mimics the dtype cast for the standard time embeddings
parent f0c74e9a
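
The motivation, in isolation: the sinusoidal projection always returns fp32 tensors, while the embedding MLP may have been cast to fp16, and PyTorch will not silently mix the two dtypes in a matmul. Below is a minimal sketch of that mismatch in plain PyTorch; the layer and variable names are illustrative stand-ins, not the diffusers internals.

```python
import torch
import torch.nn as nn

# fp16 matmuls are universally supported on CUDA; recent PyTorch builds also
# support them on CPU, so fall back to CPU only if no GPU is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Stand-in for the class embedding MLP after the model has been cast to fp16.
class_embedding = nn.Linear(320, 1280).to(device=device, dtype=torch.float16)

# Stand-in for the output of the sinusoidal `Timesteps` projection: always fp32.
projected_labels = torch.randn(2, 320, device=device)

# Feeding the fp32 projection straight into the fp16 layer raises a dtype
# mismatch error; casting first, as this patch does, makes it run.
projected_labels = projected_labels.to(dtype=class_embedding.weight.dtype)
class_emb = class_embedding(projected_labels)
print(class_emb.dtype)  # torch.float16
```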
@@ -659,7 +659,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         t_emb = self.time_proj(timesteps)
-        # timesteps does not contain any weights and will always return f32 tensors
+        # `Timesteps` does not contain any weights and will always return f32 tensors
         # but time_embedding might actually be running in fp16. so we need to cast here.
         # there might be better ways to encapsulate this.
         t_emb = t_emb.to(dtype=self.dtype)
@@ -673,6 +673,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             if self.config.class_embed_type == "timestep":
                 class_labels = self.time_proj(class_labels)
+                # `Timesteps` does not contain any weights and will always return f32 tensors
+                # there might be better ways to encapsulate this.
+                class_labels = class_labels.to(dtype=sample.dtype)
             class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
             if self.config.class_embeddings_concat:
@@ -756,7 +756,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin)
         t_emb = self.time_proj(timesteps)
-        # timesteps does not contain any weights and will always return f32 tensors
+        # `Timesteps` does not contain any weights and will always return f32 tensors
         # but time_embedding might actually be running in fp16. so we need to cast here.
         # there might be better ways to encapsulate this.
         t_emb = t_emb.to(dtype=self.dtype)
@@ -770,6 +770,10 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin)
             if self.config.class_embed_type == "timestep":
                 class_labels = self.time_proj(class_labels)
+                # `Timesteps` does not contain any weights and will always return f32 tensors
+                # there might be better ways to encapsulate this.
+                class_labels = class_labels.to(dtype=sample.dtype)
             class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
             if self.config.class_embeddings_concat:
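
Both hunks apply the same fix: once in `UNet2DConditionModel` and once in the `UNetFlatConditionModel` copy used by the Versatile Diffusion pipeline. The sketch below replays the patched class-label path in isolation using the embedding modules the UNet composes; the import path, constructor arguments, and channel sizes are assumptions based on the diffusers version around this commit, so check them against your installed version.

```python
import torch
from diffusers.models.embeddings import Timesteps, TimestepEmbedding

# Use a GPU if one is available; fp16 matmuls on CPU need a recent PyTorch build.
device = "cuda" if torch.cuda.is_available() else "cpu"

# With class_embed_type="timestep", the UNet reuses its sinusoidal projection
# (`Timesteps`, weight-free, always fp32) followed by a TimestepEmbedding MLP.
time_proj = Timesteps(320, True, 0)  # num_channels, flip_sin_to_cos, downscale_freq_shift
class_embedding = TimestepEmbedding(320, 1280).to(device=device, dtype=torch.float16)

sample = torch.randn(1, 4, 32, 32, device=device, dtype=torch.float16)
class_labels = torch.tensor([3, 7], device=device)

class_labels = time_proj(class_labels)              # fp32, regardless of model dtype
class_labels = class_labels.to(dtype=sample.dtype)  # the cast added by this commit
class_emb = class_embedding(class_labels).to(dtype=torch.float16)
print(class_emb.dtype, class_emb.shape)             # torch.float16, (2, 1280)
```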