Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
a8440653
Unverified
Commit
a8440653
authored
Oct 09, 2023
by
Jake Vanderplas
Committed by
GitHub
Oct 09, 2023
Browse files
replace references to deprecated KeyArray & PRNGKeyArray (#5324)
parent
35952e61
Changes
15
Show whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
28 additions
and
26 deletions
+28
-26
setup.py
setup.py
+2
-2
src/diffusers/dependency_versions_table.py
src/diffusers/dependency_versions_table.py
+2
-2
src/diffusers/models/controlnet_flax.py
src/diffusers/models/controlnet_flax.py
+1
-1
src/diffusers/models/modeling_flax_utils.py
src/diffusers/models/modeling_flax_utils.py
+1
-1
src/diffusers/models/unet_2d_condition_flax.py
src/diffusers/models/unet_2d_condition_flax.py
+1
-1
src/diffusers/models/vae_flax.py
src/diffusers/models/vae_flax.py
+1
-1
src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py
...iffusers/pipelines/controlnet/pipeline_flax_controlnet.py
+3
-3
src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
...elines/stable_diffusion/pipeline_flax_stable_diffusion.py
+2
-2
src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py
...table_diffusion/pipeline_flax_stable_diffusion_img2img.py
+3
-3
src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py
...table_diffusion/pipeline_flax_stable_diffusion_inpaint.py
+2
-2
src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py
...ffusers/pipelines/stable_diffusion/safety_checker_flax.py
+1
-1
src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py
.../stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py
+2
-2
src/diffusers/schedulers/scheduling_ddpm_flax.py
src/diffusers/schedulers/scheduling_ddpm_flax.py
+2
-2
src/diffusers/schedulers/scheduling_karras_ve_flax.py
src/diffusers/schedulers/scheduling_karras_ve_flax.py
+2
-1
src/diffusers/schedulers/scheduling_sde_ve_flax.py
src/diffusers/schedulers/scheduling_sde_ve_flax.py
+3
-2
No files found.
setup.py
View file @
a8440653
...
@@ -102,8 +102,8 @@ _deps = [
...
@@ -102,8 +102,8 @@ _deps = [
"importlib_metadata"
,
"importlib_metadata"
,
"invisible-watermark>=0.2.0"
,
"invisible-watermark>=0.2.0"
,
"isort>=5.5.4"
,
"isort>=5.5.4"
,
"jax>=0.
2.8,!=0.3.2
"
,
"jax>=0.
4.1
"
,
"jaxlib>=0.
1.65
"
,
"jaxlib>=0.
4.1
"
,
"Jinja2"
,
"Jinja2"
,
"k-diffusion>=0.0.12"
,
"k-diffusion>=0.0.12"
,
"torchsde"
,
"torchsde"
,
...
...
src/diffusers/dependency_versions_table.py
View file @
a8440653
...
@@ -15,8 +15,8 @@ deps = {
...
@@ -15,8 +15,8 @@ deps = {
"importlib_metadata"
:
"importlib_metadata"
,
"importlib_metadata"
:
"importlib_metadata"
,
"invisible-watermark"
:
"invisible-watermark>=0.2.0"
,
"invisible-watermark"
:
"invisible-watermark>=0.2.0"
,
"isort"
:
"isort>=5.5.4"
,
"isort"
:
"isort>=5.5.4"
,
"jax"
:
"jax>=0.
2.8,!=0.3.2
"
,
"jax"
:
"jax>=0.
4.1
"
,
"jaxlib"
:
"jaxlib>=0.
1.65
"
,
"jaxlib"
:
"jaxlib>=0.
4.1
"
,
"Jinja2"
:
"Jinja2"
,
"Jinja2"
:
"Jinja2"
,
"k-diffusion"
:
"k-diffusion>=0.0.12"
,
"k-diffusion"
:
"k-diffusion>=0.0.12"
,
"torchsde"
:
"torchsde"
,
"torchsde"
:
"torchsde"
,
...
...
src/diffusers/models/controlnet_flax.py
View file @
a8440653
...
@@ -168,7 +168,7 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
...
@@ -168,7 +168,7 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
controlnet_conditioning_channel_order
:
str
=
"rgb"
controlnet_conditioning_channel_order
:
str
=
"rgb"
conditioning_embedding_out_channels
:
Tuple
[
int
]
=
(
16
,
32
,
96
,
256
)
conditioning_embedding_out_channels
:
Tuple
[
int
]
=
(
16
,
32
,
96
,
256
)
def
init_weights
(
self
,
rng
:
jax
.
random
.
Key
Array
)
->
FrozenDict
:
def
init_weights
(
self
,
rng
:
jax
.
Array
)
->
FrozenDict
:
# init input tensors
# init input tensors
sample_shape
=
(
1
,
self
.
in_channels
,
self
.
sample_size
,
self
.
sample_size
)
sample_shape
=
(
1
,
self
.
in_channels
,
self
.
sample_size
,
self
.
sample_size
)
sample
=
jnp
.
zeros
(
sample_shape
,
dtype
=
jnp
.
float32
)
sample
=
jnp
.
zeros
(
sample_shape
,
dtype
=
jnp
.
float32
)
...
...
src/diffusers/models/modeling_flax_utils.py
View file @
a8440653
...
@@ -192,7 +192,7 @@ class FlaxModelMixin(PushToHubMixin):
...
@@ -192,7 +192,7 @@ class FlaxModelMixin(PushToHubMixin):
```"""
```"""
return
self
.
_cast_floating_to
(
params
,
jnp
.
float16
,
mask
)
return
self
.
_cast_floating_to
(
params
,
jnp
.
float16
,
mask
)
def
init_weights
(
self
,
rng
:
jax
.
random
.
Key
Array
)
->
Dict
:
def
init_weights
(
self
,
rng
:
jax
.
Array
)
->
Dict
:
raise
NotImplementedError
(
f
"init_weights method has to be implemented for
{
self
}
"
)
raise
NotImplementedError
(
f
"init_weights method has to be implemented for
{
self
}
"
)
@
classmethod
@
classmethod
...
...
src/diffusers/models/unet_2d_condition_flax.py
View file @
a8440653
...
@@ -126,7 +126,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
...
@@ -126,7 +126,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
addition_embed_type_num_heads
:
int
=
64
addition_embed_type_num_heads
:
int
=
64
projection_class_embeddings_input_dim
:
Optional
[
int
]
=
None
projection_class_embeddings_input_dim
:
Optional
[
int
]
=
None
def
init_weights
(
self
,
rng
:
jax
.
random
.
Key
Array
)
->
FrozenDict
:
def
init_weights
(
self
,
rng
:
jax
.
Array
)
->
FrozenDict
:
# init input tensors
# init input tensors
sample_shape
=
(
1
,
self
.
in_channels
,
self
.
sample_size
,
self
.
sample_size
)
sample_shape
=
(
1
,
self
.
in_channels
,
self
.
sample_size
,
self
.
sample_size
)
sample
=
jnp
.
zeros
(
sample_shape
,
dtype
=
jnp
.
float32
)
sample
=
jnp
.
zeros
(
sample_shape
,
dtype
=
jnp
.
float32
)
...
...
src/diffusers/models/vae_flax.py
View file @
a8440653
...
@@ -817,7 +817,7 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
...
@@ -817,7 +817,7 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
dtype
=
self
.
dtype
,
dtype
=
self
.
dtype
,
)
)
def
init_weights
(
self
,
rng
:
jax
.
random
.
Key
Array
)
->
FrozenDict
:
def
init_weights
(
self
,
rng
:
jax
.
Array
)
->
FrozenDict
:
# init input tensors
# init input tensors
sample_shape
=
(
1
,
self
.
in_channels
,
self
.
sample_size
,
self
.
sample_size
)
sample_shape
=
(
1
,
self
.
in_channels
,
self
.
sample_size
,
self
.
sample_size
)
sample
=
jnp
.
zeros
(
sample_shape
,
dtype
=
jnp
.
float32
)
sample
=
jnp
.
zeros
(
sample_shape
,
dtype
=
jnp
.
float32
)
...
...
src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py
View file @
a8440653
...
@@ -241,7 +241,7 @@ class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline):
...
@@ -241,7 +241,7 @@ class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline):
prompt_ids
:
jnp
.
array
,
prompt_ids
:
jnp
.
array
,
image
:
jnp
.
array
,
image
:
jnp
.
array
,
params
:
Union
[
Dict
,
FrozenDict
],
params
:
Union
[
Dict
,
FrozenDict
],
prng_seed
:
jax
.
random
.
Key
Array
,
prng_seed
:
jax
.
Array
,
num_inference_steps
:
int
,
num_inference_steps
:
int
,
guidance_scale
:
float
,
guidance_scale
:
float
,
latents
:
Optional
[
jnp
.
array
]
=
None
,
latents
:
Optional
[
jnp
.
array
]
=
None
,
...
@@ -351,7 +351,7 @@ class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline):
...
@@ -351,7 +351,7 @@ class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline):
prompt_ids
:
jnp
.
array
,
prompt_ids
:
jnp
.
array
,
image
:
jnp
.
array
,
image
:
jnp
.
array
,
params
:
Union
[
Dict
,
FrozenDict
],
params
:
Union
[
Dict
,
FrozenDict
],
prng_seed
:
jax
.
random
.
Key
Array
,
prng_seed
:
jax
.
Array
,
num_inference_steps
:
int
=
50
,
num_inference_steps
:
int
=
50
,
guidance_scale
:
Union
[
float
,
jnp
.
array
]
=
7.5
,
guidance_scale
:
Union
[
float
,
jnp
.
array
]
=
7.5
,
latents
:
jnp
.
array
=
None
,
latents
:
jnp
.
array
=
None
,
...
@@ -370,7 +370,7 @@ class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline):
...
@@ -370,7 +370,7 @@ class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline):
Array representing the ControlNet input condition to provide guidance to the `unet` for generation.
Array representing the ControlNet input condition to provide guidance to the `unet` for generation.
params (`Dict` or `FrozenDict`):
params (`Dict` or `FrozenDict`):
Dictionary containing the model parameters/weights.
Dictionary containing the model parameters/weights.
prng_seed (`jax.
random.Key
Array` or `jax.Array`):
prng_seed (`jax.Array` or `jax.Array`):
Array containing random number generator key.
Array containing random number generator key.
num_inference_steps (`int`, *optional*, defaults to 50):
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
...
...
src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
View file @
a8440653
...
@@ -215,7 +215,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
...
@@ -215,7 +215,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
self
,
self
,
prompt_ids
:
jnp
.
array
,
prompt_ids
:
jnp
.
array
,
params
:
Union
[
Dict
,
FrozenDict
],
params
:
Union
[
Dict
,
FrozenDict
],
prng_seed
:
jax
.
random
.
Key
Array
,
prng_seed
:
jax
.
Array
,
num_inference_steps
:
int
,
num_inference_steps
:
int
,
height
:
int
,
height
:
int
,
width
:
int
,
width
:
int
,
...
@@ -312,7 +312,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
...
@@ -312,7 +312,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
self
,
self
,
prompt_ids
:
jnp
.
array
,
prompt_ids
:
jnp
.
array
,
params
:
Union
[
Dict
,
FrozenDict
],
params
:
Union
[
Dict
,
FrozenDict
],
prng_seed
:
jax
.
random
.
Key
Array
,
prng_seed
:
jax
.
Array
,
num_inference_steps
:
int
=
50
,
num_inference_steps
:
int
=
50
,
height
:
Optional
[
int
]
=
None
,
height
:
Optional
[
int
]
=
None
,
width
:
Optional
[
int
]
=
None
,
width
:
Optional
[
int
]
=
None
,
...
...
src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py
View file @
a8440653
...
@@ -235,7 +235,7 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
...
@@ -235,7 +235,7 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
prompt_ids
:
jnp
.
array
,
prompt_ids
:
jnp
.
array
,
image
:
jnp
.
array
,
image
:
jnp
.
array
,
params
:
Union
[
Dict
,
FrozenDict
],
params
:
Union
[
Dict
,
FrozenDict
],
prng_seed
:
jax
.
random
.
Key
Array
,
prng_seed
:
jax
.
Array
,
start_timestep
:
int
,
start_timestep
:
int
,
num_inference_steps
:
int
,
num_inference_steps
:
int
,
height
:
int
,
height
:
int
,
...
@@ -340,7 +340,7 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
...
@@ -340,7 +340,7 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
prompt_ids
:
jnp
.
array
,
prompt_ids
:
jnp
.
array
,
image
:
jnp
.
array
,
image
:
jnp
.
array
,
params
:
Union
[
Dict
,
FrozenDict
],
params
:
Union
[
Dict
,
FrozenDict
],
prng_seed
:
jax
.
random
.
Key
Array
,
prng_seed
:
jax
.
Array
,
strength
:
float
=
0.8
,
strength
:
float
=
0.8
,
num_inference_steps
:
int
=
50
,
num_inference_steps
:
int
=
50
,
height
:
Optional
[
int
]
=
None
,
height
:
Optional
[
int
]
=
None
,
...
@@ -361,7 +361,7 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
...
@@ -361,7 +361,7 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
Array representing an image batch to be used as the starting point.
Array representing an image batch to be used as the starting point.
params (`Dict` or `FrozenDict`):
params (`Dict` or `FrozenDict`):
Dictionary containing the model parameters/weights.
Dictionary containing the model parameters/weights.
prng_seed (`jax.
random.Key
Array` or `jax.Array`):
prng_seed (`jax.Array` or `jax.Array`):
Array containing random number generator key.
Array containing random number generator key.
strength (`float`, *optional*, defaults to 0.8):
strength (`float`, *optional*, defaults to 0.8):
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
...
...
src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py
View file @
a8440653
...
@@ -270,7 +270,7 @@ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline):
...
@@ -270,7 +270,7 @@ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline):
mask
:
jnp
.
array
,
mask
:
jnp
.
array
,
masked_image
:
jnp
.
array
,
masked_image
:
jnp
.
array
,
params
:
Union
[
Dict
,
FrozenDict
],
params
:
Union
[
Dict
,
FrozenDict
],
prng_seed
:
jax
.
random
.
Key
Array
,
prng_seed
:
jax
.
Array
,
num_inference_steps
:
int
,
num_inference_steps
:
int
,
height
:
int
,
height
:
int
,
width
:
int
,
width
:
int
,
...
@@ -398,7 +398,7 @@ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline):
...
@@ -398,7 +398,7 @@ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline):
mask
:
jnp
.
array
,
mask
:
jnp
.
array
,
masked_image
:
jnp
.
array
,
masked_image
:
jnp
.
array
,
params
:
Union
[
Dict
,
FrozenDict
],
params
:
Union
[
Dict
,
FrozenDict
],
prng_seed
:
jax
.
random
.
Key
Array
,
prng_seed
:
jax
.
Array
,
num_inference_steps
:
int
=
50
,
num_inference_steps
:
int
=
50
,
height
:
Optional
[
int
]
=
None
,
height
:
Optional
[
int
]
=
None
,
width
:
Optional
[
int
]
=
None
,
width
:
Optional
[
int
]
=
None
,
...
...
src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py
View file @
a8440653
...
@@ -87,7 +87,7 @@ class FlaxStableDiffusionSafetyChecker(FlaxPreTrainedModel):
...
@@ -87,7 +87,7 @@ class FlaxStableDiffusionSafetyChecker(FlaxPreTrainedModel):
module
=
self
.
module_class
(
config
=
config
,
dtype
=
dtype
,
**
kwargs
)
module
=
self
.
module_class
(
config
=
config
,
dtype
=
dtype
,
**
kwargs
)
super
().
__init__
(
config
,
module
,
input_shape
=
input_shape
,
seed
=
seed
,
dtype
=
dtype
,
_do_init
=
_do_init
)
super
().
__init__
(
config
,
module
,
input_shape
=
input_shape
,
seed
=
seed
,
dtype
=
dtype
,
_do_init
=
_do_init
)
def
init_weights
(
self
,
rng
:
jax
.
random
.
Key
Array
,
input_shape
:
Tuple
,
params
:
FrozenDict
=
None
)
->
FrozenDict
:
def
init_weights
(
self
,
rng
:
jax
.
Array
,
input_shape
:
Tuple
,
params
:
FrozenDict
=
None
)
->
FrozenDict
:
# init input tensor
# init input tensor
clip_input
=
jax
.
random
.
normal
(
rng
,
input_shape
)
clip_input
=
jax
.
random
.
normal
(
rng
,
input_shape
)
...
...
src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py
View file @
a8440653
...
@@ -89,7 +89,7 @@ class FlaxStableDiffusionXLPipeline(FlaxDiffusionPipeline):
...
@@ -89,7 +89,7 @@ class FlaxStableDiffusionXLPipeline(FlaxDiffusionPipeline):
self
,
self
,
prompt_ids
:
jax
.
Array
,
prompt_ids
:
jax
.
Array
,
params
:
Union
[
Dict
,
FrozenDict
],
params
:
Union
[
Dict
,
FrozenDict
],
prng_seed
:
jax
.
random
.
Key
Array
,
prng_seed
:
jax
.
Array
,
num_inference_steps
:
int
=
50
,
num_inference_steps
:
int
=
50
,
guidance_scale
:
Union
[
float
,
jax
.
Array
]
=
7.5
,
guidance_scale
:
Union
[
float
,
jax
.
Array
]
=
7.5
,
height
:
Optional
[
int
]
=
None
,
height
:
Optional
[
int
]
=
None
,
...
@@ -170,7 +170,7 @@ class FlaxStableDiffusionXLPipeline(FlaxDiffusionPipeline):
...
@@ -170,7 +170,7 @@ class FlaxStableDiffusionXLPipeline(FlaxDiffusionPipeline):
self
,
self
,
prompt_ids
:
jnp
.
array
,
prompt_ids
:
jnp
.
array
,
params
:
Union
[
Dict
,
FrozenDict
],
params
:
Union
[
Dict
,
FrozenDict
],
prng_seed
:
jax
.
random
.
Key
Array
,
prng_seed
:
jax
.
Array
,
num_inference_steps
:
int
,
num_inference_steps
:
int
,
height
:
int
,
height
:
int
,
width
:
int
,
width
:
int
,
...
...
src/diffusers/schedulers/scheduling_ddpm_flax.py
View file @
a8440653
...
@@ -198,7 +198,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
...
@@ -198,7 +198,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
model_output
:
jnp
.
ndarray
,
model_output
:
jnp
.
ndarray
,
timestep
:
int
,
timestep
:
int
,
sample
:
jnp
.
ndarray
,
sample
:
jnp
.
ndarray
,
key
:
Optional
[
jax
.
random
.
Key
Array
]
=
None
,
key
:
Optional
[
jax
.
Array
]
=
None
,
return_dict
:
bool
=
True
,
return_dict
:
bool
=
True
,
)
->
Union
[
FlaxDDPMSchedulerOutput
,
Tuple
]:
)
->
Union
[
FlaxDDPMSchedulerOutput
,
Tuple
]:
"""
"""
...
@@ -211,7 +211,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
...
@@ -211,7 +211,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
timestep (`int`): current discrete timestep in the diffusion chain.
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`jnp.ndarray`):
sample (`jnp.ndarray`):
current instance of sample being created by diffusion process.
current instance of sample being created by diffusion process.
key (`jax.
random.Key
Array`): a PRNG key.
key (`jax.Array`): a PRNG key.
return_dict (`bool`): option for returning tuple rather than FlaxDDPMSchedulerOutput class
return_dict (`bool`): option for returning tuple rather than FlaxDDPMSchedulerOutput class
Returns:
Returns:
...
...
src/diffusers/schedulers/scheduling_karras_ve_flax.py
View file @
a8440653
...
@@ -17,6 +17,7 @@ from dataclasses import dataclass
...
@@ -17,6 +17,7 @@ from dataclasses import dataclass
from
typing
import
Optional
,
Tuple
,
Union
from
typing
import
Optional
,
Tuple
,
Union
import
flax
import
flax
import
jax
import
jax.numpy
as
jnp
import
jax.numpy
as
jnp
from
jax
import
random
from
jax
import
random
...
@@ -139,7 +140,7 @@ class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin):
...
@@ -139,7 +140,7 @@ class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin):
state
:
KarrasVeSchedulerState
,
state
:
KarrasVeSchedulerState
,
sample
:
jnp
.
ndarray
,
sample
:
jnp
.
ndarray
,
sigma
:
float
,
sigma
:
float
,
key
:
random
.
Key
Array
,
key
:
jax
.
Array
,
)
->
Tuple
[
jnp
.
ndarray
,
float
]:
)
->
Tuple
[
jnp
.
ndarray
,
float
]:
"""
"""
Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a
Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a
...
...
src/diffusers/schedulers/scheduling_sde_ve_flax.py
View file @
a8440653
...
@@ -18,6 +18,7 @@ from dataclasses import dataclass
...
@@ -18,6 +18,7 @@ from dataclasses import dataclass
from
typing
import
Optional
,
Tuple
,
Union
from
typing
import
Optional
,
Tuple
,
Union
import
flax
import
flax
import
jax
import
jax.numpy
as
jnp
import
jax.numpy
as
jnp
from
jax
import
random
from
jax
import
random
...
@@ -169,7 +170,7 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
...
@@ -169,7 +170,7 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
model_output
:
jnp
.
ndarray
,
model_output
:
jnp
.
ndarray
,
timestep
:
int
,
timestep
:
int
,
sample
:
jnp
.
ndarray
,
sample
:
jnp
.
ndarray
,
key
:
random
.
Key
Array
,
key
:
jax
.
Array
,
return_dict
:
bool
=
True
,
return_dict
:
bool
=
True
,
)
->
Union
[
FlaxSdeVeOutput
,
Tuple
]:
)
->
Union
[
FlaxSdeVeOutput
,
Tuple
]:
"""
"""
...
@@ -228,7 +229,7 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
...
@@ -228,7 +229,7 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
state
:
ScoreSdeVeSchedulerState
,
state
:
ScoreSdeVeSchedulerState
,
model_output
:
jnp
.
ndarray
,
model_output
:
jnp
.
ndarray
,
sample
:
jnp
.
ndarray
,
sample
:
jnp
.
ndarray
,
key
:
random
.
Key
Array
,
key
:
jax
.
Array
,
return_dict
:
bool
=
True
,
return_dict
:
bool
=
True
,
)
->
Union
[
FlaxSdeVeOutput
,
Tuple
]:
)
->
Union
[
FlaxSdeVeOutput
,
Tuple
]:
"""
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment