Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
797b290e
Unverified
Commit
797b290e
authored
Oct 11, 2022
by
Suraj Patil
Committed by
GitHub
Oct 11, 2022
Browse files
support bf16 for stable diffusion (#792)
* support bf16 for stable diffusion * fix typo * address review comments
parent
81bdbb5e
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
17 additions
and
3 deletions
+17
-3
src/diffusers/models/resnet.py
src/diffusers/models/resnet.py
+11
-0
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
...s/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+3
-1
src/diffusers/pipelines/stable_diffusion/safety_checker.py
src/diffusers/pipelines/stable_diffusion/safety_checker.py
+3
-2
No files found.
src/diffusers/models/resnet.py
View file @
797b290e
...
...
@@ -41,6 +41,13 @@ class Upsample2D(nn.Module):
if
self
.
use_conv_transpose
:
return
self
.
conv
(
hidden_states
)
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
# TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
# https://github.com/pytorch/pytorch/issues/86679
dtype
=
hidden_states
.
dtype
if
dtype
==
torch
.
bfloat16
:
hidden_states
=
hidden_states
.
to
(
torch
.
float32
)
# if `output_size` is passed we force the interpolation output
# size and do not make use of `scale_factor=2`
if
output_size
is
None
:
...
...
@@ -48,6 +55,10 @@ class Upsample2D(nn.Module):
else
:
hidden_states
=
F
.
interpolate
(
hidden_states
,
size
=
output_size
,
mode
=
"nearest"
)
# If the input is bfloat16, we cast back to bfloat16
if
dtype
==
torch
.
bfloat16
:
hidden_states
=
hidden_states
.
to
(
dtype
)
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
if
self
.
use_conv
:
if
self
.
name
==
"conv"
:
...
...
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
View file @
797b290e
...
...
@@ -327,7 +327,9 @@ class StableDiffusionPipeline(DiffusionPipeline):
image
=
self
.
vae
.
decode
(
latents
).
sample
image
=
(
image
/
2
+
0.5
).
clamp
(
0
,
1
)
image
=
image
.
cpu
().
permute
(
0
,
2
,
3
,
1
).
numpy
()
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
image
=
image
.
cpu
().
permute
(
0
,
2
,
3
,
1
).
float
().
numpy
()
safety_checker_input
=
self
.
feature_extractor
(
self
.
numpy_to_pil
(
image
),
return_tensors
=
"pt"
).
to
(
self
.
device
)
image
,
has_nsfw_concept
=
self
.
safety_checker
(
...
...
src/diffusers/pipelines/stable_diffusion/safety_checker.py
View file @
797b290e
...
...
@@ -38,8 +38,9 @@ class StableDiffusionSafetyChecker(PreTrainedModel):
pooled_output
=
self
.
vision_model
(
clip_input
)[
1
]
# pooled_output
image_embeds
=
self
.
visual_projection
(
pooled_output
)
special_cos_dist
=
cosine_distance
(
image_embeds
,
self
.
special_care_embeds
).
cpu
().
numpy
()
cos_dist
=
cosine_distance
(
image_embeds
,
self
.
concept_embeds
).
cpu
().
numpy
()
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
special_cos_dist
=
cosine_distance
(
image_embeds
,
self
.
special_care_embeds
).
cpu
().
float
().
numpy
()
cos_dist
=
cosine_distance
(
image_embeds
,
self
.
concept_embeds
).
cpu
().
float
().
numpy
()
result
=
[]
batch_size
=
image_embeds
.
shape
[
0
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment