renzhc / diffusers_dcu · Commits · 2345481c
Unverified commit 2345481c, authored Sep 20, 2022 by Patrick von Platen and committed via GitHub on Sep 20, 2022.
[Flax] Fix unet and ddim scheduler (#594)

* [Flax] Fix unet and ddim scheduler
* correct
* finish

Parent: d934d3d7
Showing 7 changed files with 22 additions and 14 deletions (+22 −14):
* src/diffusers/models/embeddings_flax.py (+4 −3)
* src/diffusers/models/unet_2d_condition_flax.py (+2 −1)
* src/diffusers/pipeline_flax_utils.py (+6 −8)
* src/diffusers/pipeline_utils.py (+4 −0)
* src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py (+0 −1)
* src/diffusers/schedulers/scheduling_ddim.py (+1 −0)
* src/diffusers/schedulers/scheduling_ddim_flax.py (+5 −1)
src/diffusers/models/embeddings_flax.py

@@ -19,7 +19,7 @@ import jax.numpy as jnp
 # This is like models.embeddings.get_timestep_embedding (PyTorch) but
 # less general (only handles the case we currently need).
-def get_sinusoidal_embeddings(timesteps, embedding_dim):
+def get_sinusoidal_embeddings(timesteps, embedding_dim, freq_shift: float = 1):
     """
     This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.

@@ -29,7 +29,7 @@ def get_sinusoidal_embeddings(timesteps, embedding_dim):
     embeddings. :return: an [N x dim] tensor of positional embeddings.
     """
     half_dim = embedding_dim // 2
-    emb = math.log(10000) / (half_dim - 1)
+    emb = math.log(10000) / (half_dim - freq_shift)
     emb = jnp.exp(jnp.arange(half_dim) * -emb)
     emb = timesteps[:, None] * emb[None, :]
     emb = jnp.concatenate([jnp.cos(emb), jnp.sin(emb)], -1)

@@ -50,7 +50,8 @@ class FlaxTimestepEmbedding(nn.Module):
 class FlaxTimesteps(nn.Module):
     dim: int = 32
+    freq_shift: float = 1

     @nn.compact
     def __call__(self, timesteps):
-        return get_sinusoidal_embeddings(timesteps, self.dim)
+        return get_sinusoidal_embeddings(timesteps, self.dim, freq_shift=self.freq_shift)
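For context on the `freq_shift` change above, here is a minimal standalone sketch of what the patched helper computes (same math as the diff, written as an independent function; the demo timestep values are made up). With `freq_shift=1` the log-frequency spacing divides by `half_dim - 1`, the previously hard-coded behavior, while `freq_shift=0` divides by `half_dim`, which is what the Flax UNet now requests by default.

```python
import math

import jax.numpy as jnp


def sinusoidal_embeddings(timesteps, embedding_dim, freq_shift: float = 1.0):
    """Standalone re-implementation of the patched helper, for illustration only."""
    half_dim = embedding_dim // 2
    # freq_shift moves the denominator of the log-frequency spacing:
    # 1.0 reproduces the old behavior (half_dim - 1), 0.0 divides by half_dim.
    emb = math.log(10000) / (half_dim - freq_shift)
    emb = jnp.exp(jnp.arange(half_dim) * -emb)
    emb = timesteps[:, None] * emb[None, :]
    return jnp.concatenate([jnp.cos(emb), jnp.sin(emb)], -1)


# Hypothetical example values: a batch of 4 timesteps embedded into 32 dimensions.
t = jnp.array([0.0, 10.0, 500.0, 999.0])
print(sinusoidal_embeddings(t, 32, freq_shift=1.0).shape)  # (4, 32)
print(sinusoidal_embeddings(t, 32, freq_shift=0.0).shape)  # (4, 32)
```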
src/diffusers/models/unet_2d_condition_flax.py

@@ -73,6 +73,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
     cross_attention_dim: int = 1280
     dropout: float = 0.0
     dtype: jnp.dtype = jnp.float32
+    freq_shift: int = 0

     def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict:
         # init input tensors

@@ -100,7 +101,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
         )

         # time
-        self.time_proj = FlaxTimesteps(block_out_channels[0])
+        self.time_proj = FlaxTimesteps(block_out_channels[0], freq_shift=self.config.freq_shift)
         self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)

         # down
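The two hunks above add `freq_shift` as a module attribute and forward it to `FlaxTimesteps`. Below is a minimal, hypothetical sketch of that plumbing pattern in Flax (toy classes `TinyTimesteps` and `TinyUNetStub`, not the real diffusers modules); note that the real `FlaxUNet2DConditionModel` reads the value back through `self.config.freq_shift` because its attributes are registered via `ConfigMixin`, whereas the sketch just forwards a plain attribute.

```python
import math

import flax.linen as nn
import jax
import jax.numpy as jnp


class TinyTimesteps(nn.Module):
    # stand-ins for FlaxTimesteps' attributes
    dim: int = 32
    freq_shift: float = 1.0

    @nn.compact
    def __call__(self, timesteps):
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - self.freq_shift)
        emb = jnp.exp(jnp.arange(half_dim) * -emb)
        emb = timesteps[:, None] * emb[None, :]
        return jnp.concatenate([jnp.cos(emb), jnp.sin(emb)], -1)


class TinyUNetStub(nn.Module):
    # mirrors how the UNet forwards its own freq_shift attribute to the time projection
    freq_shift: int = 0

    def setup(self):
        self.time_proj = TinyTimesteps(dim=32, freq_shift=self.freq_shift)

    def __call__(self, timesteps):
        return self.time_proj(timesteps)


model = TinyUNetStub(freq_shift=0)
variables = model.init(jax.random.PRNGKey(0), jnp.array([1.0, 2.0]))
print(model.apply(variables, jnp.array([1.0, 2.0])).shape)  # (2, 32)
```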
src/diffusers/pipeline_flax_utils.py

@@ -354,7 +354,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
             # TODO(Patrick, Suraj) - delete later
             if class_name == "DummyChecker":
                 library_name = "stable_diffusion"
-                class_name = "StableDiffusionSafetyChecker"
+                class_name = "FlaxStableDiffusionSafetyChecker"

             is_pipeline_module = hasattr(pipelines, library_name)
             loaded_sub_model = None

@@ -421,16 +421,14 @@ class FlaxDiffusionPipeline(ConfigMixin):
                     loaded_sub_model = cached_folder

                 if issubclass(class_obj, FlaxModelMixin):
-                    # TODO(Patrick, Suraj) - Fix this as soon as Safety checker is fixed here
+                    loaded_sub_model, loaded_params = load_method(loadable_folder, from_pt=from_pt, dtype=dtype)
+                    params[name] = loaded_params
+                elif is_transformers_available() and issubclass(class_obj, FlaxPreTrainedModel):
+                    # make sure we don't initialize the weights to save time
                     if name == "safety_checker":
                         loaded_sub_model = DummyChecker()
                         loaded_params = DummyChecker()
-                    else:
-                        loaded_sub_model, loaded_params = load_method(loadable_folder, from_pt=from_pt, dtype=dtype)
-                    params[name] = loaded_params
-                elif is_transformers_available() and issubclass(class_obj, FlaxPreTrainedModel):
-                    # make sure we don't initialize the weights to save time
-                    if from_pt:
+                    elif from_pt:
                         # TODO(Suraj): Fix this in Transformers. We should be able to use `_do_init=False` here
                         loaded_sub_model = load_method(loadable_folder, from_pt=from_pt)
                         loaded_params = loaded_sub_model.params
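To make the restructured branch above easier to follow, here is a toy sketch of the dispatch order after the change (all class and function names below are made up, and the `from_pt` conversion path is omitted): plain Flax modules load their weights directly, while the transformers branch now owns the `DummyChecker` placeholder for the not-yet-ported safety checker.

```python
# Toy stand-ins for the real mixins; only the branch structure matters here.
class ToyFlaxModelMixin: ...
class ToyFlaxPreTrainedModel: ...
class ToyUNet(ToyFlaxModelMixin): ...
class ToySafetyChecker(ToyFlaxPreTrainedModel): ...


class DummyChecker:
    """Placeholder object, as in the diff, so no safety-checker weights are loaded."""


def load_submodule(name, class_obj, load_method, folder, from_pt=False, dtype=None):
    """Sketch of the branch order after the fix (not the library's actual code)."""
    if issubclass(class_obj, ToyFlaxModelMixin):
        # plain Flax modules: load weights and keep the params separately
        loaded_sub_model, loaded_params = load_method(folder, from_pt=from_pt, dtype=dtype)
    elif issubclass(class_obj, ToyFlaxPreTrainedModel):
        # transformers models: the safety checker is stubbed out for now
        if name == "safety_checker":
            loaded_sub_model, loaded_params = DummyChecker(), DummyChecker()
        else:
            loaded_sub_model, loaded_params = load_method(folder, from_pt=from_pt, dtype=dtype)
    else:
        raise TypeError(f"unsupported class {class_obj!r}")
    return loaded_sub_model, loaded_params


# Hypothetical loader that just reports what it would have done.
fake_load = lambda folder, **kw: (f"model from {folder}", {"weights": "..."})
print(load_submodule("unet", ToyUNet, fake_load, "unet/"))
print(load_submodule("safety_checker", ToySafetyChecker, fake_load, "safety_checker/"))
```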
src/diffusers/pipeline_utils.py

@@ -341,6 +341,10 @@ class DiffusionPipeline(ConfigMixin):
         # 3. Load each module in the pipeline
         for name, (library_name, class_name) in init_dict.items():
+            # 3.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
+            if class_name.startswith("Flax"):
+                class_name = class_name[4:]
+
             is_pipeline_module = hasattr(pipelines, library_name)
             loaded_sub_model = None
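The added lines above let the PyTorch `DiffusionPipeline` loader accept a `model_index.json` written by a Flax pipeline: any class name starting with "Flax" is mapped back to its PyTorch counterpart by dropping the 4-character prefix. A quick sketch of that normalization (the example names are plausible entries, not taken from a real index file):

```python
def normalize_class_name(class_name: str) -> str:
    """Drop a leading "Flax" so Flax-written class names resolve to PyTorch classes
    (mirrors the `class_name = class_name[4:]` line in the diff)."""
    return class_name[4:] if class_name.startswith("Flax") else class_name


for name in ["FlaxUNet2DConditionModel", "UNet2DConditionModel", "FlaxAutoencoderKL"]:
    print(name, "->", normalize_class_name(name))
# FlaxUNet2DConditionModel -> UNet2DConditionModel
# UNet2DConditionModel -> UNet2DConditionModel
# FlaxAutoencoderKL -> AutoencoderKL
```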
src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py

@@ -178,7 +178,6 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
             jnp.array(latents_input),
             jnp.array(timestep, dtype=jnp.int32),
             encoder_hidden_states=context,
-            rngs={},
         ).sample
         # perform guidance
         noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0)
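The context around the removed `rngs={}` argument shows the classifier-free-guidance setup: the UNet is run on the concatenated unconditional and text-conditioned batch, and its prediction is split in two along the batch axis. A small self-contained sketch of that split and the usual guidance combination that follows it (the combination formula and `guidance_scale` value are standard classifier-free guidance, not lines shown in this diff):

```python
import jax.numpy as jnp

# Toy stand-in for the UNet output on a doubled batch: 2 prompts, so 4 rows total
# (first half unconditional, second half text-conditioned), 3 features each.
noise_pred = jnp.arange(4 * 3, dtype=jnp.float32).reshape(4, 3)

# Same split as in the pipeline code above.
noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0)

# Standard classifier-free guidance combination; 7.5 is just an example scale.
guidance_scale = 7.5
guided = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
print(guided.shape)  # (2, 3)
```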
src/diffusers/schedulers/scheduling_ddim.py

@@ -222,6 +222,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         # 2. compute alphas, betas
         alpha_prod_t = self.alphas_cumprod[timestep]
         alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+
         beta_prod_t = 1 - alpha_prod_t

         # 3. compute predicted original sample from predicted noise also called
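The unchanged context shown above is the step where the scheduler picks `alpha_prod_t_prev`, falling back to `final_alpha_cumprod` once `prev_timestep` goes negative on the last denoising step. A tiny illustration with made-up schedule values:

```python
import jax.numpy as jnp

# Made-up cumulative alpha schedule, just to exercise the selection logic.
alphas_cumprod = jnp.array([0.99, 0.95, 0.90, 0.80])
final_alpha_cumprod = jnp.array(1.0)  # fallback used once prev_timestep < 0

def alpha_prod_prev(prev_timestep: int):
    # Same selection as the scheduler's context line above.
    return alphas_cumprod[prev_timestep] if prev_timestep >= 0 else final_alpha_cumprod

print(alpha_prod_prev(2))   # 0.9 -> a regular step
print(alpha_prod_prev(-1))  # 1.0 -> final step, "previous" alpha is the fallback
```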
src/diffusers/schedulers/scheduling_ddim_flax.py

@@ -216,6 +216,9 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
         # - pred_sample_direction -> "direction pointing to x_t"
         # - pred_prev_sample -> "x_t-1"

+        # TODO(Patrick) - eta is always 0.0 for now, allow to be set in step function
+        eta = 0.0
+
         # 1. get previous step value (=t-1)
         prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps

@@ -224,6 +227,7 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
         # 2. compute alphas, betas
         alpha_prod_t = alphas_cumprod[timestep]
         alpha_prod_t_prev = jnp.where(prev_timestep >= 0, alphas_cumprod[prev_timestep], self.final_alpha_cumprod)
+
         beta_prod_t = 1 - alpha_prod_t

         # 3. compute predicted original sample from predicted noise also called

@@ -233,7 +237,7 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
         # 4. compute variance: "sigma_t(η)" -> see formula (16)
         # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
         variance = self._get_variance(timestep, prev_timestep, alphas_cumprod)
-        std_dev_t = variance ** (0.5)
+        std_dev_t = eta * variance ** (0.5)

         # 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
         pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output
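Putting the Flax scheduler hunks together: `eta` is pinned to 0.0 for now, so the variance term `std_dev_t = eta * variance ** 0.5` vanishes and the update reduces to the deterministic DDIM step. Below is a sketch of that arithmetic using the formulas referenced in the comments (eq. 12 and 16 of https://arxiv.org/pdf/2010.02502.pdf); it is an illustration with made-up inputs, not the scheduler's actual `step` implementation.

```python
import jax.numpy as jnp


def ddim_step_sketch(sample, model_output, alpha_prod_t, alpha_prod_t_prev, eta=0.0):
    """Deterministic DDIM update (noise injection omitted since eta defaults to 0.0)."""
    beta_prod_t = 1 - alpha_prod_t

    # predicted x_0 recovered from the noise prediction
    pred_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5

    # sigma_t(eta) from formula (16); zero whenever eta is 0.0, as in the diff
    variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
    std_dev_t = eta * variance ** 0.5

    # "direction pointing to x_t" plus the rescaled x_0 prediction, formula (12)
    pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t ** 2) ** 0.5 * model_output
    return alpha_prod_t_prev ** 0.5 * pred_original_sample + pred_sample_direction


# Made-up scalars and a tiny latent just to exercise the arithmetic.
x = jnp.ones((1, 4))
eps = 0.1 * jnp.ones((1, 4))
print(ddim_step_sketch(x, eps, alpha_prod_t=0.8, alpha_prod_t_prev=0.9))
```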