Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
a8440653
Unverified
Commit
a8440653
authored
Oct 09, 2023
by
Jake Vanderplas
Committed by
GitHub
Oct 09, 2023
Browse files
replace references to deprecated KeyArray & PRNGKeyArray (#5324)
parent
35952e61
Changes
15
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
28 additions
and
26 deletions
+28
-26
setup.py
setup.py
+2
-2
src/diffusers/dependency_versions_table.py
src/diffusers/dependency_versions_table.py
+2
-2
src/diffusers/models/controlnet_flax.py
src/diffusers/models/controlnet_flax.py
+1
-1
src/diffusers/models/modeling_flax_utils.py
src/diffusers/models/modeling_flax_utils.py
+1
-1
src/diffusers/models/unet_2d_condition_flax.py
src/diffusers/models/unet_2d_condition_flax.py
+1
-1
src/diffusers/models/vae_flax.py
src/diffusers/models/vae_flax.py
+1
-1
src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py
...iffusers/pipelines/controlnet/pipeline_flax_controlnet.py
+3
-3
src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
...elines/stable_diffusion/pipeline_flax_stable_diffusion.py
+2
-2
src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py
...table_diffusion/pipeline_flax_stable_diffusion_img2img.py
+3
-3
src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py
...table_diffusion/pipeline_flax_stable_diffusion_inpaint.py
+2
-2
src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py
...ffusers/pipelines/stable_diffusion/safety_checker_flax.py
+1
-1
src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py
.../stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py
+2
-2
src/diffusers/schedulers/scheduling_ddpm_flax.py
src/diffusers/schedulers/scheduling_ddpm_flax.py
+2
-2
src/diffusers/schedulers/scheduling_karras_ve_flax.py
src/diffusers/schedulers/scheduling_karras_ve_flax.py
+2
-1
src/diffusers/schedulers/scheduling_sde_ve_flax.py
src/diffusers/schedulers/scheduling_sde_ve_flax.py
+3
-2
No files found.
setup.py
View file @
a8440653
...
...
@@ -102,8 +102,8 @@ _deps = [
"importlib_metadata"
,
"invisible-watermark>=0.2.0"
,
"isort>=5.5.4"
,
"jax>=0.
2.8,!=0.3.2
"
,
"jaxlib>=0.
1.65
"
,
"jax>=0.
4.1
"
,
"jaxlib>=0.
4.1
"
,
"Jinja2"
,
"k-diffusion>=0.0.12"
,
"torchsde"
,
...
...
src/diffusers/dependency_versions_table.py
View file @
a8440653
...
...
@@ -15,8 +15,8 @@ deps = {
"importlib_metadata"
:
"importlib_metadata"
,
"invisible-watermark"
:
"invisible-watermark>=0.2.0"
,
"isort"
:
"isort>=5.5.4"
,
"jax"
:
"jax>=0.
2.8,!=0.3.2
"
,
"jaxlib"
:
"jaxlib>=0.
1.65
"
,
"jax"
:
"jax>=0.
4.1
"
,
"jaxlib"
:
"jaxlib>=0.
4.1
"
,
"Jinja2"
:
"Jinja2"
,
"k-diffusion"
:
"k-diffusion>=0.0.12"
,
"torchsde"
:
"torchsde"
,
...
...
src/diffusers/models/controlnet_flax.py
View file @
a8440653
...
...
@@ -168,7 +168,7 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
controlnet_conditioning_channel_order
:
str
=
"rgb"
conditioning_embedding_out_channels
:
Tuple
[
int
]
=
(
16
,
32
,
96
,
256
)
def
init_weights
(
self
,
rng
:
jax
.
random
.
Key
Array
)
->
FrozenDict
:
def
init_weights
(
self
,
rng
:
jax
.
Array
)
->
FrozenDict
:
# init input tensors
sample_shape
=
(
1
,
self
.
in_channels
,
self
.
sample_size
,
self
.
sample_size
)
sample
=
jnp
.
zeros
(
sample_shape
,
dtype
=
jnp
.
float32
)
...
...
src/diffusers/models/modeling_flax_utils.py
View file @
a8440653
...
...
@@ -192,7 +192,7 @@ class FlaxModelMixin(PushToHubMixin):
```"""
return
self
.
_cast_floating_to
(
params
,
jnp
.
float16
,
mask
)
def
init_weights
(
self
,
rng
:
jax
.
random
.
Key
Array
)
->
Dict
:
def
init_weights
(
self
,
rng
:
jax
.
Array
)
->
Dict
:
raise
NotImplementedError
(
f
"init_weights method has to be implemented for
{
self
}
"
)
@
classmethod
...
...
src/diffusers/models/unet_2d_condition_flax.py
View file @
a8440653
...
...
@@ -126,7 +126,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
addition_embed_type_num_heads
:
int
=
64
projection_class_embeddings_input_dim
:
Optional
[
int
]
=
None
def
init_weights
(
self
,
rng
:
jax
.
random
.
Key
Array
)
->
FrozenDict
:
def
init_weights
(
self
,
rng
:
jax
.
Array
)
->
FrozenDict
:
# init input tensors
sample_shape
=
(
1
,
self
.
in_channels
,
self
.
sample_size
,
self
.
sample_size
)
sample
=
jnp
.
zeros
(
sample_shape
,
dtype
=
jnp
.
float32
)
...
...
src/diffusers/models/vae_flax.py
View file @
a8440653
...
...
@@ -817,7 +817,7 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
dtype
=
self
.
dtype
,
)
def
init_weights
(
self
,
rng
:
jax
.
random
.
Key
Array
)
->
FrozenDict
:
def
init_weights
(
self
,
rng
:
jax
.
Array
)
->
FrozenDict
:
# init input tensors
sample_shape
=
(
1
,
self
.
in_channels
,
self
.
sample_size
,
self
.
sample_size
)
sample
=
jnp
.
zeros
(
sample_shape
,
dtype
=
jnp
.
float32
)
...
...
src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py
View file @
a8440653
...
...
@@ -241,7 +241,7 @@ class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline):
prompt_ids
:
jnp
.
array
,
image
:
jnp
.
array
,
params
:
Union
[
Dict
,
FrozenDict
],
prng_seed
:
jax
.
random
.
Key
Array
,
prng_seed
:
jax
.
Array
,
num_inference_steps
:
int
,
guidance_scale
:
float
,
latents
:
Optional
[
jnp
.
array
]
=
None
,
...
...
@@ -351,7 +351,7 @@ class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline):
prompt_ids
:
jnp
.
array
,
image
:
jnp
.
array
,
params
:
Union
[
Dict
,
FrozenDict
],
prng_seed
:
jax
.
random
.
Key
Array
,
prng_seed
:
jax
.
Array
,
num_inference_steps
:
int
=
50
,
guidance_scale
:
Union
[
float
,
jnp
.
array
]
=
7.5
,
latents
:
jnp
.
array
=
None
,
...
...
@@ -370,7 +370,7 @@ class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline):
Array representing the ControlNet input condition to provide guidance to the `unet` for generation.
params (`Dict` or `FrozenDict`):
Dictionary containing the model parameters/weights.
prng_seed (`jax.
random.Key
Array` or `jax.Array`):
prng_seed (`jax.Array` or `jax.Array`):
Array containing random number generator key.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
...
...
src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
View file @
a8440653
...
...
@@ -215,7 +215,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
self
,
prompt_ids
:
jnp
.
array
,
params
:
Union
[
Dict
,
FrozenDict
],
prng_seed
:
jax
.
random
.
Key
Array
,
prng_seed
:
jax
.
Array
,
num_inference_steps
:
int
,
height
:
int
,
width
:
int
,
...
...
@@ -312,7 +312,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
self
,
prompt_ids
:
jnp
.
array
,
params
:
Union
[
Dict
,
FrozenDict
],
prng_seed
:
jax
.
random
.
Key
Array
,
prng_seed
:
jax
.
Array
,
num_inference_steps
:
int
=
50
,
height
:
Optional
[
int
]
=
None
,
width
:
Optional
[
int
]
=
None
,
...
...
src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py
View file @
a8440653
...
...
@@ -235,7 +235,7 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
prompt_ids
:
jnp
.
array
,
image
:
jnp
.
array
,
params
:
Union
[
Dict
,
FrozenDict
],
prng_seed
:
jax
.
random
.
Key
Array
,
prng_seed
:
jax
.
Array
,
start_timestep
:
int
,
num_inference_steps
:
int
,
height
:
int
,
...
...
@@ -340,7 +340,7 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
prompt_ids
:
jnp
.
array
,
image
:
jnp
.
array
,
params
:
Union
[
Dict
,
FrozenDict
],
prng_seed
:
jax
.
random
.
Key
Array
,
prng_seed
:
jax
.
Array
,
strength
:
float
=
0.8
,
num_inference_steps
:
int
=
50
,
height
:
Optional
[
int
]
=
None
,
...
...
@@ -361,7 +361,7 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
Array representing an image batch to be used as the starting point.
params (`Dict` or `FrozenDict`):
Dictionary containing the model parameters/weights.
prng_seed (`jax.
random.Key
Array` or `jax.Array`):
prng_seed (`jax.Array` or `jax.Array`):
Array containing random number generator key.
strength (`float`, *optional*, defaults to 0.8):
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
...
...
src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py
View file @
a8440653
...
...
@@ -270,7 +270,7 @@ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline):
mask
:
jnp
.
array
,
masked_image
:
jnp
.
array
,
params
:
Union
[
Dict
,
FrozenDict
],
prng_seed
:
jax
.
random
.
Key
Array
,
prng_seed
:
jax
.
Array
,
num_inference_steps
:
int
,
height
:
int
,
width
:
int
,
...
...
@@ -398,7 +398,7 @@ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline):
mask
:
jnp
.
array
,
masked_image
:
jnp
.
array
,
params
:
Union
[
Dict
,
FrozenDict
],
prng_seed
:
jax
.
random
.
Key
Array
,
prng_seed
:
jax
.
Array
,
num_inference_steps
:
int
=
50
,
height
:
Optional
[
int
]
=
None
,
width
:
Optional
[
int
]
=
None
,
...
...
src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py
View file @
a8440653
...
...
@@ -87,7 +87,7 @@ class FlaxStableDiffusionSafetyChecker(FlaxPreTrainedModel):
module
=
self
.
module_class
(
config
=
config
,
dtype
=
dtype
,
**
kwargs
)
super
().
__init__
(
config
,
module
,
input_shape
=
input_shape
,
seed
=
seed
,
dtype
=
dtype
,
_do_init
=
_do_init
)
def
init_weights
(
self
,
rng
:
jax
.
random
.
Key
Array
,
input_shape
:
Tuple
,
params
:
FrozenDict
=
None
)
->
FrozenDict
:
def
init_weights
(
self
,
rng
:
jax
.
Array
,
input_shape
:
Tuple
,
params
:
FrozenDict
=
None
)
->
FrozenDict
:
# init input tensor
clip_input
=
jax
.
random
.
normal
(
rng
,
input_shape
)
...
...
src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py
View file @
a8440653
...
...
@@ -89,7 +89,7 @@ class FlaxStableDiffusionXLPipeline(FlaxDiffusionPipeline):
self
,
prompt_ids
:
jax
.
Array
,
params
:
Union
[
Dict
,
FrozenDict
],
prng_seed
:
jax
.
random
.
Key
Array
,
prng_seed
:
jax
.
Array
,
num_inference_steps
:
int
=
50
,
guidance_scale
:
Union
[
float
,
jax
.
Array
]
=
7.5
,
height
:
Optional
[
int
]
=
None
,
...
...
@@ -170,7 +170,7 @@ class FlaxStableDiffusionXLPipeline(FlaxDiffusionPipeline):
self
,
prompt_ids
:
jnp
.
array
,
params
:
Union
[
Dict
,
FrozenDict
],
prng_seed
:
jax
.
random
.
Key
Array
,
prng_seed
:
jax
.
Array
,
num_inference_steps
:
int
,
height
:
int
,
width
:
int
,
...
...
src/diffusers/schedulers/scheduling_ddpm_flax.py
View file @
a8440653
...
...
@@ -198,7 +198,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
model_output
:
jnp
.
ndarray
,
timestep
:
int
,
sample
:
jnp
.
ndarray
,
key
:
Optional
[
jax
.
random
.
Key
Array
]
=
None
,
key
:
Optional
[
jax
.
Array
]
=
None
,
return_dict
:
bool
=
True
,
)
->
Union
[
FlaxDDPMSchedulerOutput
,
Tuple
]:
"""
...
...
@@ -211,7 +211,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`jnp.ndarray`):
current instance of sample being created by diffusion process.
key (`jax.
random.Key
Array`): a PRNG key.
key (`jax.Array`): a PRNG key.
return_dict (`bool`): option for returning tuple rather than FlaxDDPMSchedulerOutput class
Returns:
...
...
src/diffusers/schedulers/scheduling_karras_ve_flax.py
View file @
a8440653
...
...
@@ -17,6 +17,7 @@ from dataclasses import dataclass
from
typing
import
Optional
,
Tuple
,
Union
import
flax
import
jax
import
jax.numpy
as
jnp
from
jax
import
random
...
...
@@ -139,7 +140,7 @@ class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin):
state
:
KarrasVeSchedulerState
,
sample
:
jnp
.
ndarray
,
sigma
:
float
,
key
:
random
.
Key
Array
,
key
:
jax
.
Array
,
)
->
Tuple
[
jnp
.
ndarray
,
float
]:
"""
Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a
...
...
src/diffusers/schedulers/scheduling_sde_ve_flax.py
View file @
a8440653
...
...
@@ -18,6 +18,7 @@ from dataclasses import dataclass
from
typing
import
Optional
,
Tuple
,
Union
import
flax
import
jax
import
jax.numpy
as
jnp
from
jax
import
random
...
...
@@ -169,7 +170,7 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
model_output
:
jnp
.
ndarray
,
timestep
:
int
,
sample
:
jnp
.
ndarray
,
key
:
random
.
Key
Array
,
key
:
jax
.
Array
,
return_dict
:
bool
=
True
,
)
->
Union
[
FlaxSdeVeOutput
,
Tuple
]:
"""
...
...
@@ -228,7 +229,7 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
state
:
ScoreSdeVeSchedulerState
,
model_output
:
jnp
.
ndarray
,
sample
:
jnp
.
ndarray
,
key
:
random
.
Key
Array
,
key
:
jax
.
Array
,
return_dict
:
bool
=
True
,
)
->
Union
[
FlaxSdeVeOutput
,
Tuple
]:
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment