Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
2345481c
Unverified
Commit
2345481c
authored
Sep 20, 2022
by
Patrick von Platen
Committed by
GitHub
Sep 20, 2022
Browse files
[Flax] Fix unet and ddim scheduler (#594)
* [Flax] Fix unet and ddim scheduler * correct * finish
parent
d934d3d7
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
22 additions
and
14 deletions
+22
-14
src/diffusers/models/embeddings_flax.py
src/diffusers/models/embeddings_flax.py
+4
-3
src/diffusers/models/unet_2d_condition_flax.py
src/diffusers/models/unet_2d_condition_flax.py
+2
-1
src/diffusers/pipeline_flax_utils.py
src/diffusers/pipeline_flax_utils.py
+6
-8
src/diffusers/pipeline_utils.py
src/diffusers/pipeline_utils.py
+4
-0
src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
...elines/stable_diffusion/pipeline_flax_stable_diffusion.py
+0
-1
src/diffusers/schedulers/scheduling_ddim.py
src/diffusers/schedulers/scheduling_ddim.py
+1
-0
src/diffusers/schedulers/scheduling_ddim_flax.py
src/diffusers/schedulers/scheduling_ddim_flax.py
+5
-1
No files found.
src/diffusers/models/embeddings_flax.py
View file @
2345481c
...
...
@@ -19,7 +19,7 @@ import jax.numpy as jnp
# This is like models.embeddings.get_timestep_embedding (PyTorch) but
# less general (only handles the case we currently need).
def
get_sinusoidal_embeddings
(
timesteps
,
embedding_dim
):
def
get_sinusoidal_embeddings
(
timesteps
,
embedding_dim
,
freq_shift
:
float
=
1
):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
...
...
@@ -29,7 +29,7 @@ def get_sinusoidal_embeddings(timesteps, embedding_dim):
embeddings. :return: an [N x dim] tensor of positional embeddings.
"""
half_dim
=
embedding_dim
//
2
emb
=
math
.
log
(
10000
)
/
(
half_dim
-
1
)
emb
=
math
.
log
(
10000
)
/
(
half_dim
-
freq_shift
)
emb
=
jnp
.
exp
(
jnp
.
arange
(
half_dim
)
*
-
emb
)
emb
=
timesteps
[:,
None
]
*
emb
[
None
,
:]
emb
=
jnp
.
concatenate
([
jnp
.
cos
(
emb
),
jnp
.
sin
(
emb
)],
-
1
)
...
...
@@ -50,7 +50,8 @@ class FlaxTimestepEmbedding(nn.Module):
class
FlaxTimesteps
(
nn
.
Module
):
dim
:
int
=
32
freq_shift
:
float
=
1
@
nn
.
compact
def
__call__
(
self
,
timesteps
):
return
get_sinusoidal_embeddings
(
timesteps
,
self
.
dim
)
return
get_sinusoidal_embeddings
(
timesteps
,
self
.
dim
,
freq_shift
=
self
.
freq_shift
)
src/diffusers/models/unet_2d_condition_flax.py
View file @
2345481c
...
...
@@ -73,6 +73,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
cross_attention_dim
:
int
=
1280
dropout
:
float
=
0.0
dtype
:
jnp
.
dtype
=
jnp
.
float32
freq_shift
:
int
=
0
def
init_weights
(
self
,
rng
:
jax
.
random
.
PRNGKey
)
->
FrozenDict
:
# init input tensors
...
...
@@ -100,7 +101,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
)
# time
self
.
time_proj
=
FlaxTimesteps
(
block_out_channels
[
0
])
self
.
time_proj
=
FlaxTimesteps
(
block_out_channels
[
0
]
,
freq_shift
=
self
.
config
.
freq_shift
)
self
.
time_embedding
=
FlaxTimestepEmbedding
(
time_embed_dim
,
dtype
=
self
.
dtype
)
# down
...
...
src/diffusers/pipeline_flax_utils.py
View file @
2345481c
...
...
@@ -354,7 +354,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
# TODO(Patrick, Suraj) - delete later
if
class_name
==
"DummyChecker"
:
library_name
=
"stable_diffusion"
class_name
=
"StableDiffusionSafetyChecker"
class_name
=
"
Flax
StableDiffusionSafetyChecker"
is_pipeline_module
=
hasattr
(
pipelines
,
library_name
)
loaded_sub_model
=
None
...
...
@@ -421,16 +421,14 @@ class FlaxDiffusionPipeline(ConfigMixin):
loaded_sub_model
=
cached_folder
if
issubclass
(
class_obj
,
FlaxModelMixin
):
# TODO(Patrick, Suraj) - Fix this as soon as Safety checker is fixed here
if
name
==
"safety_checker"
:
loaded_sub_model
=
DummyChecker
()
loaded_params
=
DummyChecker
()
else
:
loaded_sub_model
,
loaded_params
=
load_method
(
loadable_folder
,
from_pt
=
from_pt
,
dtype
=
dtype
)
params
[
name
]
=
loaded_params
elif
is_transformers_available
()
and
issubclass
(
class_obj
,
FlaxPreTrainedModel
):
# make sure we don't initialize the weights to save time
if
from_pt
:
if
name
==
"safety_checker"
:
loaded_sub_model
=
DummyChecker
()
loaded_params
=
DummyChecker
()
elif
from_pt
:
# TODO(Suraj): Fix this in Transformers. We should be able to use `_do_init=False` here
loaded_sub_model
=
load_method
(
loadable_folder
,
from_pt
=
from_pt
)
loaded_params
=
loaded_sub_model
.
params
...
...
src/diffusers/pipeline_utils.py
View file @
2345481c
...
...
@@ -341,6 +341,10 @@ class DiffusionPipeline(ConfigMixin):
# 3. Load each module in the pipeline
for
name
,
(
library_name
,
class_name
)
in
init_dict
.
items
():
# 3.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
if
class_name
.
startswith
(
"Flax"
):
class_name
=
class_name
[
4
:]
is_pipeline_module
=
hasattr
(
pipelines
,
library_name
)
loaded_sub_model
=
None
...
...
src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
View file @
2345481c
...
...
@@ -178,7 +178,6 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
jnp
.
array
(
latents_input
),
jnp
.
array
(
timestep
,
dtype
=
jnp
.
int32
),
encoder_hidden_states
=
context
,
rngs
=
{},
).
sample
# perform guidance
noise_pred_uncond
,
noise_prediction_text
=
jnp
.
split
(
noise_pred
,
2
,
axis
=
0
)
...
...
src/diffusers/schedulers/scheduling_ddim.py
View file @
2345481c
...
...
@@ -222,6 +222,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
# 2. compute alphas, betas
alpha_prod_t
=
self
.
alphas_cumprod
[
timestep
]
alpha_prod_t_prev
=
self
.
alphas_cumprod
[
prev_timestep
]
if
prev_timestep
>=
0
else
self
.
final_alpha_cumprod
beta_prod_t
=
1
-
alpha_prod_t
# 3. compute predicted original sample from predicted noise also called
...
...
src/diffusers/schedulers/scheduling_ddim_flax.py
View file @
2345481c
...
...
@@ -216,6 +216,9 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
# - pred_sample_direction -> "direction pointing to x_t"
# - pred_prev_sample -> "x_t-1"
# TODO(Patrick) - eta is always 0.0 for now, allow to be set in step function
eta
=
0.0
# 1. get previous step value (=t-1)
prev_timestep
=
timestep
-
self
.
config
.
num_train_timesteps
//
state
.
num_inference_steps
...
...
@@ -224,6 +227,7 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
# 2. compute alphas, betas
alpha_prod_t
=
alphas_cumprod
[
timestep
]
alpha_prod_t_prev
=
jnp
.
where
(
prev_timestep
>=
0
,
alphas_cumprod
[
prev_timestep
],
self
.
final_alpha_cumprod
)
beta_prod_t
=
1
-
alpha_prod_t
# 3. compute predicted original sample from predicted noise also called
...
...
@@ -233,7 +237,7 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
# 4. compute variance: "sigma_t(η)" -> see formula (16)
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
variance
=
self
.
_get_variance
(
timestep
,
prev_timestep
,
alphas_cumprod
)
std_dev_t
=
variance
**
(
0.5
)
std_dev_t
=
eta
*
variance
**
(
0.5
)
# 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_sample_direction
=
(
1
-
alpha_prod_t_prev
-
std_dev_t
**
2
)
**
(
0.5
)
*
model_output
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment