Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
c82f7baf
Unverified
Commit
c82f7baf
authored
Sep 26, 2023
by
Patrick von Platen
Committed by
GitHub
Sep 26, 2023
Browse files
[SDXL Flax] fix SDXL flax init (#5187)
* fix SDXL flax init * finish * Fix
parent
d9e7857a
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
15 additions
and
4 deletions
+15
-4
src/diffusers/models/unet_2d_condition_flax.py
src/diffusers/models/unet_2d_condition_flax.py
+12
-2
src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py
.../stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py
+3
-2
No files found.
src/diffusers/models/unet_2d_condition_flax.py
View file @
c82f7baf
...
@@ -134,8 +134,18 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
...
@@ -134,8 +134,18 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
added_cond_kwargs
=
None
added_cond_kwargs
=
None
if
self
.
addition_embed_type
==
"text_time"
:
if
self
.
addition_embed_type
==
"text_time"
:
# TODO: how to get this from the config? It's no longer cross_attention_dim
# we retrieve the expected `text_embeds_dim` by first checking if the architecture is a refiner
text_embeds_dim
=
1280
# or non-refiner architecture and then by "reverse-computing" from `projection_class_embeddings_input_dim`
is_refiner
=
(
5
*
self
.
config
.
addition_time_embed_dim
+
self
.
config
.
cross_attention_dim
==
self
.
config
.
projection_class_embeddings_input_dim
)
num_micro_conditions
=
5
if
is_refiner
else
6
text_embeds_dim
=
self
.
config
.
projection_class_embeddings_input_dim
-
(
num_micro_conditions
*
self
.
config
.
addition_time_embed_dim
)
time_ids_channels
=
self
.
projection_class_embeddings_input_dim
-
text_embeds_dim
time_ids_channels
=
self
.
projection_class_embeddings_input_dim
-
text_embeds_dim
time_ids_dims
=
time_ids_channels
//
self
.
addition_time_embed_dim
time_ids_dims
=
time_ids_channels
//
self
.
addition_time_embed_dim
added_cond_kwargs
=
{
added_cond_kwargs
=
{
...
...
src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py
View file @
c82f7baf
...
@@ -215,14 +215,15 @@ class FlaxStableDiffusionXLPipeline(FlaxDiffusionPipeline):
...
@@ -215,14 +215,15 @@ class FlaxStableDiffusionXLPipeline(FlaxDiffusionPipeline):
else
:
else
:
if
latents
.
shape
!=
latents_shape
:
if
latents
.
shape
!=
latents_shape
:
raise
ValueError
(
f
"Unexpected latents shape, got
{
latents
.
shape
}
, expected
{
latents_shape
}
"
)
raise
ValueError
(
f
"Unexpected latents shape, got
{
latents
.
shape
}
, expected
{
latents_shape
}
"
)
# scale the initial noise by the standard deviation required by the scheduler
latents
=
latents
*
params
[
"scheduler"
].
init_noise_sigma
# Prepare scheduler state
# Prepare scheduler state
scheduler_state
=
self
.
scheduler
.
set_timesteps
(
scheduler_state
=
self
.
scheduler
.
set_timesteps
(
params
[
"scheduler"
],
num_inference_steps
=
num_inference_steps
,
shape
=
latents
.
shape
params
[
"scheduler"
],
num_inference_steps
=
num_inference_steps
,
shape
=
latents
.
shape
)
)
# scale the initial noise by the standard deviation required by the scheduler
latents
=
latents
*
scheduler_state
.
init_noise_sigma
added_cond_kwargs
=
{
"text_embeds"
:
add_text_embeds
,
"time_ids"
:
add_time_ids
}
added_cond_kwargs
=
{
"text_embeds"
:
add_text_embeds
,
"time_ids"
:
add_time_ids
}
# Denoising loop
# Denoising loop
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment