Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
a8440653
Unverified
Commit
a8440653
authored
Oct 09, 2023
by
Jake Vanderplas
Committed by
GitHub
Oct 09, 2023
Browse files
replace references to deprecated KeyArray & PRNGKeyArray (#5324)
parent
35952e61
Changes
15
Show whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
28 additions
and
26 deletions
+28
-26
setup.py
setup.py
+2
-2
src/diffusers/dependency_versions_table.py
src/diffusers/dependency_versions_table.py
+2
-2
src/diffusers/models/controlnet_flax.py
src/diffusers/models/controlnet_flax.py
+1
-1
src/diffusers/models/modeling_flax_utils.py
src/diffusers/models/modeling_flax_utils.py
+1
-1
src/diffusers/models/unet_2d_condition_flax.py
src/diffusers/models/unet_2d_condition_flax.py
+1
-1
src/diffusers/models/vae_flax.py
src/diffusers/models/vae_flax.py
+1
-1
src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py
...iffusers/pipelines/controlnet/pipeline_flax_controlnet.py
+3
-3
src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
...elines/stable_diffusion/pipeline_flax_stable_diffusion.py
+2
-2
src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py
...table_diffusion/pipeline_flax_stable_diffusion_img2img.py
+3
-3
src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py
...table_diffusion/pipeline_flax_stable_diffusion_inpaint.py
+2
-2
src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py
...ffusers/pipelines/stable_diffusion/safety_checker_flax.py
+1
-1
src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py
.../stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py
+2
-2
src/diffusers/schedulers/scheduling_ddpm_flax.py
src/diffusers/schedulers/scheduling_ddpm_flax.py
+2
-2
src/diffusers/schedulers/scheduling_karras_ve_flax.py
src/diffusers/schedulers/scheduling_karras_ve_flax.py
+2
-1
src/diffusers/schedulers/scheduling_sde_ve_flax.py
src/diffusers/schedulers/scheduling_sde_ve_flax.py
+3
-2
No files found.
setup.py
View file @
a8440653
...
...
@@ -102,8 +102,8 @@ _deps = [
"importlib_metadata"
,
"invisible-watermark>=0.2.0"
,
"isort>=5.5.4"
,
"jax>=0.
2.8,!=0.3.2
"
,
"jaxlib>=0.
1.65
"
,
"jax>=0.
4.1
"
,
"jaxlib>=0.
4.1
"
,
"Jinja2"
,
"k-diffusion>=0.0.12"
,
"torchsde"
,
...
...
src/diffusers/dependency_versions_table.py
View file @
a8440653
...
...
@@ -15,8 +15,8 @@ deps = {
"importlib_metadata"
:
"importlib_metadata"
,
"invisible-watermark"
:
"invisible-watermark>=0.2.0"
,
"isort"
:
"isort>=5.5.4"
,
"jax"
:
"jax>=0.
2.8,!=0.3.2
"
,
"jaxlib"
:
"jaxlib>=0.
1.65
"
,
"jax"
:
"jax>=0.
4.1
"
,
"jaxlib"
:
"jaxlib>=0.
4.1
"
,
"Jinja2"
:
"Jinja2"
,
"k-diffusion"
:
"k-diffusion>=0.0.12"
,
"torchsde"
:
"torchsde"
,
...
...
src/diffusers/models/controlnet_flax.py
View file @
a8440653
...
...
@@ -168,7 +168,7 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
controlnet_conditioning_channel_order
:
str
=
"rgb"
conditioning_embedding_out_channels
:
Tuple
[
int
]
=
(
16
,
32
,
96
,
256
)
def
init_weights
(
self
,
rng
:
jax
.
random
.
Key
Array
)
->
FrozenDict
:
def
init_weights
(
self
,
rng
:
jax
.
Array
)
->
FrozenDict
:
# init input tensors
sample_shape
=
(
1
,
self
.
in_channels
,
self
.
sample_size
,
self
.
sample_size
)
sample
=
jnp
.
zeros
(
sample_shape
,
dtype
=
jnp
.
float32
)
...
...
src/diffusers/models/modeling_flax_utils.py
View file @
a8440653
...
...
@@ -192,7 +192,7 @@ class FlaxModelMixin(PushToHubMixin):
```"""
return
self
.
_cast_floating_to
(
params
,
jnp
.
float16
,
mask
)
def
init_weights
(
self
,
rng
:
jax
.
random
.
Key
Array
)
->
Dict
:
def
init_weights
(
self
,
rng
:
jax
.
Array
)
->
Dict
:
raise
NotImplementedError
(
f
"init_weights method has to be implemented for
{
self
}
"
)
@
classmethod
...
...
src/diffusers/models/unet_2d_condition_flax.py
View file @
a8440653
...
...
@@ -126,7 +126,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
addition_embed_type_num_heads
:
int
=
64
projection_class_embeddings_input_dim
:
Optional
[
int
]
=
None
def
init_weights
(
self
,
rng
:
jax
.
random
.
Key
Array
)
->
FrozenDict
:
def
init_weights
(
self
,
rng
:
jax
.
Array
)
->
FrozenDict
:
# init input tensors
sample_shape
=
(
1
,
self
.
in_channels
,
self
.
sample_size
,
self
.
sample_size
)
sample
=
jnp
.
zeros
(
sample_shape
,
dtype
=
jnp
.
float32
)
...
...
src/diffusers/models/vae_flax.py
View file @
a8440653
...
...
@@ -817,7 +817,7 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
dtype
=
self
.
dtype
,
)
def
init_weights
(
self
,
rng
:
jax
.
random
.
Key
Array
)
->
FrozenDict
:
def
init_weights
(
self
,
rng
:
jax
.
Array
)
->
FrozenDict
:
# init input tensors
sample_shape
=
(
1
,
self
.
in_channels
,
self
.
sample_size
,
self
.
sample_size
)
sample
=
jnp
.
zeros
(
sample_shape
,
dtype
=
jnp
.
float32
)
...
...
src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py
View file @
a8440653
...
...
@@ -241,7 +241,7 @@ class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline):
prompt_ids
:
jnp
.
array
,
image
:
jnp
.
array
,
params
:
Union
[
Dict
,
FrozenDict
],
prng_seed
:
jax
.
random
.
Key
Array
,
prng_seed
:
jax
.
Array
,
num_inference_steps
:
int
,
guidance_scale
:
float
,
latents
:
Optional
[
jnp
.
array
]
=
None
,
...
...
@@ -351,7 +351,7 @@ class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline):
prompt_ids
:
jnp
.
array
,
image
:
jnp
.
array
,
params
:
Union
[
Dict
,
FrozenDict
],
prng_seed
:
jax
.
random
.
Key
Array
,
prng_seed
:
jax
.
Array
,
num_inference_steps
:
int
=
50
,
guidance_scale
:
Union
[
float
,
jnp
.
array
]
=
7.5
,
latents
:
jnp
.
array
=
None
,
...
...
@@ -370,7 +370,7 @@ class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline):
Array representing the ControlNet input condition to provide guidance to the `unet` for generation.
params (`Dict` or `FrozenDict`):
Dictionary containing the model parameters/weights.
prng_seed (`jax.
random.Key
Array` or `jax.Array`):
prng_seed (`jax.Array` or `jax.Array`):
Array containing random number generator key.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
...
...
src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
View file @
a8440653
...
...
@@ -215,7 +215,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
self
,
prompt_ids
:
jnp
.
array
,
params
:
Union
[
Dict
,
FrozenDict
],
prng_seed
:
jax
.
random
.
Key
Array
,
prng_seed
:
jax
.
Array
,
num_inference_steps
:
int
,
height
:
int
,
width
:
int
,
...
...
@@ -312,7 +312,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
self
,
prompt_ids
:
jnp
.
array
,
params
:
Union
[
Dict
,
FrozenDict
],
prng_seed
:
jax
.
random
.
Key
Array
,
prng_seed
:
jax
.
Array
,
num_inference_steps
:
int
=
50
,
height
:
Optional
[
int
]
=
None
,
width
:
Optional
[
int
]
=
None
,
...
...
src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py
View file @
a8440653
...
...
@@ -235,7 +235,7 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
prompt_ids
:
jnp
.
array
,
image
:
jnp
.
array
,
params
:
Union
[
Dict
,
FrozenDict
],
prng_seed
:
jax
.
random
.
Key
Array
,
prng_seed
:
jax
.
Array
,
start_timestep
:
int
,
num_inference_steps
:
int
,
height
:
int
,
...
...
@@ -340,7 +340,7 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
prompt_ids
:
jnp
.
array
,
image
:
jnp
.
array
,
params
:
Union
[
Dict
,
FrozenDict
],
prng_seed
:
jax
.
random
.
Key
Array
,
prng_seed
:
jax
.
Array
,
strength
:
float
=
0.8
,
num_inference_steps
:
int
=
50
,
height
:
Optional
[
int
]
=
None
,
...
...
@@ -361,7 +361,7 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
Array representing an image batch to be used as the starting point.
params (`Dict` or `FrozenDict`):
Dictionary containing the model parameters/weights.
prng_seed (`jax.
random.Key
Array` or `jax.Array`):
prng_seed (`jax.Array` or `jax.Array`):
Array containing random number generator key.
strength (`float`, *optional*, defaults to 0.8):
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
...
...
src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py
View file @
a8440653
...
...
@@ -270,7 +270,7 @@ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline):
mask
:
jnp
.
array
,
masked_image
:
jnp
.
array
,
params
:
Union
[
Dict
,
FrozenDict
],
prng_seed
:
jax
.
random
.
Key
Array
,
prng_seed
:
jax
.
Array
,
num_inference_steps
:
int
,
height
:
int
,
width
:
int
,
...
...
@@ -398,7 +398,7 @@ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline):
mask
:
jnp
.
array
,
masked_image
:
jnp
.
array
,
params
:
Union
[
Dict
,
FrozenDict
],
prng_seed
:
jax
.
random
.
Key
Array
,
prng_seed
:
jax
.
Array
,
num_inference_steps
:
int
=
50
,
height
:
Optional
[
int
]
=
None
,
width
:
Optional
[
int
]
=
None
,
...
...
src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py
View file @
a8440653
...
...
@@ -87,7 +87,7 @@ class FlaxStableDiffusionSafetyChecker(FlaxPreTrainedModel):
module
=
self
.
module_class
(
config
=
config
,
dtype
=
dtype
,
**
kwargs
)
super
().
__init__
(
config
,
module
,
input_shape
=
input_shape
,
seed
=
seed
,
dtype
=
dtype
,
_do_init
=
_do_init
)
def
init_weights
(
self
,
rng
:
jax
.
random
.
Key
Array
,
input_shape
:
Tuple
,
params
:
FrozenDict
=
None
)
->
FrozenDict
:
def
init_weights
(
self
,
rng
:
jax
.
Array
,
input_shape
:
Tuple
,
params
:
FrozenDict
=
None
)
->
FrozenDict
:
# init input tensor
clip_input
=
jax
.
random
.
normal
(
rng
,
input_shape
)
...
...
src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py
View file @
a8440653
...
...
@@ -89,7 +89,7 @@ class FlaxStableDiffusionXLPipeline(FlaxDiffusionPipeline):
self
,
prompt_ids
:
jax
.
Array
,
params
:
Union
[
Dict
,
FrozenDict
],
prng_seed
:
jax
.
random
.
Key
Array
,
prng_seed
:
jax
.
Array
,
num_inference_steps
:
int
=
50
,
guidance_scale
:
Union
[
float
,
jax
.
Array
]
=
7.5
,
height
:
Optional
[
int
]
=
None
,
...
...
@@ -170,7 +170,7 @@ class FlaxStableDiffusionXLPipeline(FlaxDiffusionPipeline):
self
,
prompt_ids
:
jnp
.
array
,
params
:
Union
[
Dict
,
FrozenDict
],
prng_seed
:
jax
.
random
.
Key
Array
,
prng_seed
:
jax
.
Array
,
num_inference_steps
:
int
,
height
:
int
,
width
:
int
,
...
...
src/diffusers/schedulers/scheduling_ddpm_flax.py
View file @
a8440653
...
...
@@ -198,7 +198,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
model_output
:
jnp
.
ndarray
,
timestep
:
int
,
sample
:
jnp
.
ndarray
,
key
:
Optional
[
jax
.
random
.
Key
Array
]
=
None
,
key
:
Optional
[
jax
.
Array
]
=
None
,
return_dict
:
bool
=
True
,
)
->
Union
[
FlaxDDPMSchedulerOutput
,
Tuple
]:
"""
...
...
@@ -211,7 +211,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`jnp.ndarray`):
current instance of sample being created by diffusion process.
key (`jax.
random.Key
Array`): a PRNG key.
key (`jax.Array`): a PRNG key.
return_dict (`bool`): option for returning tuple rather than FlaxDDPMSchedulerOutput class
Returns:
...
...
src/diffusers/schedulers/scheduling_karras_ve_flax.py
View file @
a8440653
...
...
@@ -17,6 +17,7 @@ from dataclasses import dataclass
from
typing
import
Optional
,
Tuple
,
Union
import
flax
import
jax
import
jax.numpy
as
jnp
from
jax
import
random
...
...
@@ -139,7 +140,7 @@ class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin):
state
:
KarrasVeSchedulerState
,
sample
:
jnp
.
ndarray
,
sigma
:
float
,
key
:
random
.
Key
Array
,
key
:
jax
.
Array
,
)
->
Tuple
[
jnp
.
ndarray
,
float
]:
"""
Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a
...
...
src/diffusers/schedulers/scheduling_sde_ve_flax.py
View file @
a8440653
...
...
@@ -18,6 +18,7 @@ from dataclasses import dataclass
from
typing
import
Optional
,
Tuple
,
Union
import
flax
import
jax
import
jax.numpy
as
jnp
from
jax
import
random
...
...
@@ -169,7 +170,7 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
model_output
:
jnp
.
ndarray
,
timestep
:
int
,
sample
:
jnp
.
ndarray
,
key
:
random
.
Key
Array
,
key
:
jax
.
Array
,
return_dict
:
bool
=
True
,
)
->
Union
[
FlaxSdeVeOutput
,
Tuple
]:
"""
...
...
@@ -228,7 +229,7 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
state
:
ScoreSdeVeSchedulerState
,
model_output
:
jnp
.
ndarray
,
sample
:
jnp
.
ndarray
,
key
:
random
.
Key
Array
,
key
:
jax
.
Array
,
return_dict
:
bool
=
True
,
)
->
Union
[
FlaxSdeVeOutput
,
Tuple
]:
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment