"git@developer.sourcefind.cn:OpenDAS/pytorch3d.git" did not exist on "cd7b885169b5550019f68547e90913f75791ee96"
Unverified Commit 5366db5d authored by kesimeg's avatar kesimeg Committed by GitHub
Browse files

fix une2td ignoring class_labels (#5401)



* fix une2td ignoring class_labels

* unet2.py error message updated

* style and quality changes

---------
Co-authored-by: default avatarDhruv Nair <dhruv.nair@gmail.com>
parent e5168588
......@@ -291,6 +291,8 @@ class UNet2DModel(ModelMixin, ConfigMixin):
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
emb = emb + class_emb
elif self.class_embedding is None and class_labels is not None:
raise ValueError("class_embedding needs to be initialized in order to use class conditioning")
# 2. pre-process
skip_sample = sample
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment