Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
797b290e
Unverified
Commit
797b290e
authored
Oct 11, 2022
by
Suraj Patil
Committed by
GitHub
Oct 11, 2022
Browse files
support bf16 for stable diffusion (#792)
* support bf16 for stable diffusion * fix typo * address review comments
parent
81bdbb5e
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
17 additions
and
3 deletions
+17
-3
src/diffusers/models/resnet.py
src/diffusers/models/resnet.py
+11
-0
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
...s/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+3
-1
src/diffusers/pipelines/stable_diffusion/safety_checker.py
src/diffusers/pipelines/stable_diffusion/safety_checker.py
+3
-2
No files found.
src/diffusers/models/resnet.py
View file @
797b290e
...
@@ -41,6 +41,13 @@ class Upsample2D(nn.Module):
...
@@ -41,6 +41,13 @@ class Upsample2D(nn.Module):
if
self
.
use_conv_transpose
:
if
self
.
use_conv_transpose
:
return
self
.
conv
(
hidden_states
)
return
self
.
conv
(
hidden_states
)
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
# TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
# https://github.com/pytorch/pytorch/issues/86679
dtype
=
hidden_states
.
dtype
if
dtype
==
torch
.
bfloat16
:
hidden_states
=
hidden_states
.
to
(
torch
.
float32
)
# if `output_size` is passed we force the interpolation output
# if `output_size` is passed we force the interpolation output
# size and do not make use of `scale_factor=2`
# size and do not make use of `scale_factor=2`
if
output_size
is
None
:
if
output_size
is
None
:
...
@@ -48,6 +55,10 @@ class Upsample2D(nn.Module):
...
@@ -48,6 +55,10 @@ class Upsample2D(nn.Module):
else
:
else
:
hidden_states
=
F
.
interpolate
(
hidden_states
,
size
=
output_size
,
mode
=
"nearest"
)
hidden_states
=
F
.
interpolate
(
hidden_states
,
size
=
output_size
,
mode
=
"nearest"
)
# If the input is bfloat16, we cast back to bfloat16
if
dtype
==
torch
.
bfloat16
:
hidden_states
=
hidden_states
.
to
(
dtype
)
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
if
self
.
use_conv
:
if
self
.
use_conv
:
if
self
.
name
==
"conv"
:
if
self
.
name
==
"conv"
:
...
...
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
View file @
797b290e
...
@@ -327,7 +327,9 @@ class StableDiffusionPipeline(DiffusionPipeline):
...
@@ -327,7 +327,9 @@ class StableDiffusionPipeline(DiffusionPipeline):
image
=
self
.
vae
.
decode
(
latents
).
sample
image
=
self
.
vae
.
decode
(
latents
).
sample
image
=
(
image
/
2
+
0.5
).
clamp
(
0
,
1
)
image
=
(
image
/
2
+
0.5
).
clamp
(
0
,
1
)
image
=
image
.
cpu
().
permute
(
0
,
2
,
3
,
1
).
numpy
()
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
image
=
image
.
cpu
().
permute
(
0
,
2
,
3
,
1
).
float
().
numpy
()
safety_checker_input
=
self
.
feature_extractor
(
self
.
numpy_to_pil
(
image
),
return_tensors
=
"pt"
).
to
(
self
.
device
)
safety_checker_input
=
self
.
feature_extractor
(
self
.
numpy_to_pil
(
image
),
return_tensors
=
"pt"
).
to
(
self
.
device
)
image
,
has_nsfw_concept
=
self
.
safety_checker
(
image
,
has_nsfw_concept
=
self
.
safety_checker
(
...
...
src/diffusers/pipelines/stable_diffusion/safety_checker.py
View file @
797b290e
...
@@ -38,8 +38,9 @@ class StableDiffusionSafetyChecker(PreTrainedModel):
...
@@ -38,8 +38,9 @@ class StableDiffusionSafetyChecker(PreTrainedModel):
pooled_output
=
self
.
vision_model
(
clip_input
)[
1
]
# pooled_output
pooled_output
=
self
.
vision_model
(
clip_input
)[
1
]
# pooled_output
image_embeds
=
self
.
visual_projection
(
pooled_output
)
image_embeds
=
self
.
visual_projection
(
pooled_output
)
special_cos_dist
=
cosine_distance
(
image_embeds
,
self
.
special_care_embeds
).
cpu
().
numpy
()
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
cos_dist
=
cosine_distance
(
image_embeds
,
self
.
concept_embeds
).
cpu
().
numpy
()
special_cos_dist
=
cosine_distance
(
image_embeds
,
self
.
special_care_embeds
).
cpu
().
float
().
numpy
()
cos_dist
=
cosine_distance
(
image_embeds
,
self
.
concept_embeds
).
cpu
().
float
().
numpy
()
result
=
[]
result
=
[]
batch_size
=
image_embeds
.
shape
[
0
]
batch_size
=
image_embeds
.
shape
[
0
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment