Unverified Commit 5366db5d authored by kesimeg's avatar kesimeg Committed by GitHub
Browse files

fix une2td ignoring class_labels (#5401)



* fix une2td ignoring class_labels

* unet2.py error message updated

* style and quality changes

---------
Co-authored-by: default avatarDhruv Nair <dhruv.nair@gmail.com>
parent e5168588
...@@ -291,6 +291,8 @@ class UNet2DModel(ModelMixin, ConfigMixin): ...@@ -291,6 +291,8 @@ class UNet2DModel(ModelMixin, ConfigMixin):
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
emb = emb + class_emb emb = emb + class_emb
elif self.class_embedding is None and class_labels is not None:
raise ValueError("class_embedding needs to be initialized in order to use class conditioning")
# 2. pre-process # 2. pre-process
skip_sample = sample skip_sample = sample
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment