Unverified commit fc188391 authored by Will Berman, committed by GitHub

class labels timestep embeddings projection dtype cast (#3137)

This mimics the dtype cast already applied to the standard time embeddings (`t_emb`).
parent f0c74e9a
@@ -659,7 +659,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         t_emb = self.time_proj(timesteps)
 
-        # timesteps does not contain any weights and will always return f32 tensors
+        # `Timesteps` does not contain any weights and will always return f32 tensors
         # but time_embedding might actually be running in fp16. so we need to cast here.
         # there might be better ways to encapsulate this.
         t_emb = t_emb.to(dtype=self.dtype)
@@ -673,6 +673,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
             if self.config.class_embed_type == "timestep":
                 class_labels = self.time_proj(class_labels)
 
+                # `Timesteps` does not contain any weights and will always return f32 tensors
+                # there might be better ways to encapsulate this.
+                class_labels = class_labels.to(dtype=sample.dtype)
+
             class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
 
             if self.config.class_embeddings_concat:
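To make the change concrete, here is a minimal, self-contained sketch of the failure mode being fixed (hypothetical helper names, not the actual diffusers modules): a weight-free sinusoidal projection always returns float32, so feeding its output straight into an fp16 `class_embedding` raises a dtype mismatch, and casting first, as the added lines do, avoids it.

import math

import torch
import torch.nn as nn

def sinusoidal_proj(x: torch.Tensor, dim: int = 320) -> torch.Tensor:
    # Weight-free projection: the output dtype is float32 regardless of the
    # surrounding model's dtype, mirroring diffusers' `Timesteps` module.
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=torch.float32) / half)
    args = x.float()[:, None] * freqs[None, :]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

# Stand-in for the learned class embedding, running in fp16 (e.g. after
# `unet.half()`). On older CPU-only PyTorch builds fp16 matmul may be
# unsupported; run on a GPU there.
class_embedding = nn.Linear(320, 1280).half()

class_labels = sinusoidal_proj(torch.tensor([3, 7]))  # dtype: torch.float32
# Without the cast, the fp16 layer rejects the float32 input:
# class_embedding(class_labels)  -> RuntimeError (Half vs. Float mismatch)

# The fix: cast the projected labels to the sample's dtype first, exactly
# as is already done for `t_emb`.
class_labels = class_labels.to(dtype=torch.float16)
class_emb = class_embedding(class_labels)
assert class_emb.dtype == torch.float16

Note that in the actual change the projected labels are cast to `sample.dtype` (mirroring how `t_emb` is cast) while the embedding output is still cast to `self.dtype` afterwards, so the class embedding ends up in the same dtype as the time embedding.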
@@ -756,7 +756,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         t_emb = self.time_proj(timesteps)
 
-        # timesteps does not contain any weights and will always return f32 tensors
+        # `Timesteps` does not contain any weights and will always return f32 tensors
         # but time_embedding might actually be running in fp16. so we need to cast here.
         # there might be better ways to encapsulate this.
         t_emb = t_emb.to(dtype=self.dtype)
@@ -770,6 +770,10 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             if self.config.class_embed_type == "timestep":
                 class_labels = self.time_proj(class_labels)
 
+                # `Timesteps` does not contain any weights and will always return f32 tensors
+                # there might be better ways to encapsulate this.
+                class_labels = class_labels.to(dtype=sample.dtype)
+
             class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
 
             if self.config.class_embeddings_concat: