Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
fc188391
Unverified
Commit
fc188391
authored
Apr 18, 2023
by
Will Berman
Committed by
GitHub
Apr 18, 2023
Browse files
class labels timestep embeddings projection dtype cast (#3137)
This mimics the dtype cast for the standard time embeddings
parent
f0c74e9a
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
10 additions
and
2 deletions
+10
-2
src/diffusers/models/unet_2d_condition.py
src/diffusers/models/unet_2d_condition.py
+5
-1
src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
...users/pipelines/versatile_diffusion/modeling_text_unet.py
+5
-1
No files found.
src/diffusers/models/unet_2d_condition.py
View file @
fc188391
...
@@ -659,7 +659,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
...
@@ -659,7 +659,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
t_emb
=
self
.
time_proj
(
timesteps
)
t_emb
=
self
.
time_proj
(
timesteps
)
#
t
imesteps does not contain any weights and will always return f32 tensors
#
`T
imesteps
`
does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
# there might be better ways to encapsulate this.
t_emb
=
t_emb
.
to
(
dtype
=
self
.
dtype
)
t_emb
=
t_emb
.
to
(
dtype
=
self
.
dtype
)
...
@@ -673,6 +673,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
...
@@ -673,6 +673,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
if
self
.
config
.
class_embed_type
==
"timestep"
:
if
self
.
config
.
class_embed_type
==
"timestep"
:
class_labels
=
self
.
time_proj
(
class_labels
)
class_labels
=
self
.
time_proj
(
class_labels
)
# `Timesteps` does not contain any weights and will always return f32 tensors
# there might be better ways to encapsulate this.
class_labels
=
class_labels
.
to
(
dtype
=
sample
.
dtype
)
class_emb
=
self
.
class_embedding
(
class_labels
).
to
(
dtype
=
self
.
dtype
)
class_emb
=
self
.
class_embedding
(
class_labels
).
to
(
dtype
=
self
.
dtype
)
if
self
.
config
.
class_embeddings_concat
:
if
self
.
config
.
class_embeddings_concat
:
...
...
src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
View file @
fc188391
...
@@ -756,7 +756,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
...
@@ -756,7 +756,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
t_emb
=
self
.
time_proj
(
timesteps
)
t_emb
=
self
.
time_proj
(
timesteps
)
#
t
imesteps does not contain any weights and will always return f32 tensors
#
`T
imesteps
`
does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
# there might be better ways to encapsulate this.
t_emb
=
t_emb
.
to
(
dtype
=
self
.
dtype
)
t_emb
=
t_emb
.
to
(
dtype
=
self
.
dtype
)
...
@@ -770,6 +770,10 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
...
@@ -770,6 +770,10 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
if
self
.
config
.
class_embed_type
==
"timestep"
:
if
self
.
config
.
class_embed_type
==
"timestep"
:
class_labels
=
self
.
time_proj
(
class_labels
)
class_labels
=
self
.
time_proj
(
class_labels
)
# `Timesteps` does not contain any weights and will always return f32 tensors
# there might be better ways to encapsulate this.
class_labels
=
class_labels
.
to
(
dtype
=
sample
.
dtype
)
class_emb
=
self
.
class_embedding
(
class_labels
).
to
(
dtype
=
self
.
dtype
)
class_emb
=
self
.
class_embedding
(
class_labels
).
to
(
dtype
=
self
.
dtype
)
if
self
.
config
.
class_embeddings_concat
:
if
self
.
config
.
class_embeddings_concat
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment