Unverified Commit a8440653 authored by Jake Vanderplas's avatar Jake Vanderplas Committed by GitHub
Browse files

replace references to deprecated KeyArray & PRNGKeyArray (#5324)

parent 35952e61
...@@ -102,8 +102,8 @@ _deps = [ ...@@ -102,8 +102,8 @@ _deps = [
"importlib_metadata", "importlib_metadata",
"invisible-watermark>=0.2.0", "invisible-watermark>=0.2.0",
"isort>=5.5.4", "isort>=5.5.4",
"jax>=0.2.8,!=0.3.2", "jax>=0.4.1",
"jaxlib>=0.1.65", "jaxlib>=0.4.1",
"Jinja2", "Jinja2",
"k-diffusion>=0.0.12", "k-diffusion>=0.0.12",
"torchsde", "torchsde",
......
...@@ -15,8 +15,8 @@ deps = { ...@@ -15,8 +15,8 @@ deps = {
"importlib_metadata": "importlib_metadata", "importlib_metadata": "importlib_metadata",
"invisible-watermark": "invisible-watermark>=0.2.0", "invisible-watermark": "invisible-watermark>=0.2.0",
"isort": "isort>=5.5.4", "isort": "isort>=5.5.4",
"jax": "jax>=0.2.8,!=0.3.2", "jax": "jax>=0.4.1",
"jaxlib": "jaxlib>=0.1.65", "jaxlib": "jaxlib>=0.4.1",
"Jinja2": "Jinja2", "Jinja2": "Jinja2",
"k-diffusion": "k-diffusion>=0.0.12", "k-diffusion": "k-diffusion>=0.0.12",
"torchsde": "torchsde", "torchsde": "torchsde",
......
...@@ -168,7 +168,7 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin): ...@@ -168,7 +168,7 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
controlnet_conditioning_channel_order: str = "rgb" controlnet_conditioning_channel_order: str = "rgb"
conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256) conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256)
def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict: def init_weights(self, rng: jax.Array) -> FrozenDict:
# init input tensors # init input tensors
sample_shape = (1, self.in_channels, self.sample_size, self.sample_size) sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
sample = jnp.zeros(sample_shape, dtype=jnp.float32) sample = jnp.zeros(sample_shape, dtype=jnp.float32)
......
...@@ -192,7 +192,7 @@ class FlaxModelMixin(PushToHubMixin): ...@@ -192,7 +192,7 @@ class FlaxModelMixin(PushToHubMixin):
```""" ```"""
return self._cast_floating_to(params, jnp.float16, mask) return self._cast_floating_to(params, jnp.float16, mask)
def init_weights(self, rng: jax.random.KeyArray) -> Dict: def init_weights(self, rng: jax.Array) -> Dict:
raise NotImplementedError(f"init_weights method has to be implemented for {self}") raise NotImplementedError(f"init_weights method has to be implemented for {self}")
@classmethod @classmethod
......
...@@ -126,7 +126,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): ...@@ -126,7 +126,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
addition_embed_type_num_heads: int = 64 addition_embed_type_num_heads: int = 64
projection_class_embeddings_input_dim: Optional[int] = None projection_class_embeddings_input_dim: Optional[int] = None
def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict: def init_weights(self, rng: jax.Array) -> FrozenDict:
# init input tensors # init input tensors
sample_shape = (1, self.in_channels, self.sample_size, self.sample_size) sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
sample = jnp.zeros(sample_shape, dtype=jnp.float32) sample = jnp.zeros(sample_shape, dtype=jnp.float32)
......
...@@ -817,7 +817,7 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin): ...@@ -817,7 +817,7 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
dtype=self.dtype, dtype=self.dtype,
) )
def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict: def init_weights(self, rng: jax.Array) -> FrozenDict:
# init input tensors # init input tensors
sample_shape = (1, self.in_channels, self.sample_size, self.sample_size) sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
sample = jnp.zeros(sample_shape, dtype=jnp.float32) sample = jnp.zeros(sample_shape, dtype=jnp.float32)
......
...@@ -241,7 +241,7 @@ class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline): ...@@ -241,7 +241,7 @@ class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline):
prompt_ids: jnp.array, prompt_ids: jnp.array,
image: jnp.array, image: jnp.array,
params: Union[Dict, FrozenDict], params: Union[Dict, FrozenDict],
prng_seed: jax.random.KeyArray, prng_seed: jax.Array,
num_inference_steps: int, num_inference_steps: int,
guidance_scale: float, guidance_scale: float,
latents: Optional[jnp.array] = None, latents: Optional[jnp.array] = None,
...@@ -351,7 +351,7 @@ class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline): ...@@ -351,7 +351,7 @@ class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline):
prompt_ids: jnp.array, prompt_ids: jnp.array,
image: jnp.array, image: jnp.array,
params: Union[Dict, FrozenDict], params: Union[Dict, FrozenDict],
prng_seed: jax.random.KeyArray, prng_seed: jax.Array,
num_inference_steps: int = 50, num_inference_steps: int = 50,
guidance_scale: Union[float, jnp.array] = 7.5, guidance_scale: Union[float, jnp.array] = 7.5,
latents: jnp.array = None, latents: jnp.array = None,
...@@ -370,7 +370,7 @@ class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline): ...@@ -370,7 +370,7 @@ class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline):
Array representing the ControlNet input condition to provide guidance to the `unet` for generation. Array representing the ControlNet input condition to provide guidance to the `unet` for generation.
params (`Dict` or `FrozenDict`): params (`Dict` or `FrozenDict`):
Dictionary containing the model parameters/weights. Dictionary containing the model parameters/weights.
prng_seed (`jax.random.KeyArray` or `jax.Array`): prng_seed (`jax.Array` or `jax.Array`):
Array containing random number generator key. Array containing random number generator key.
num_inference_steps (`int`, *optional*, defaults to 50): num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the The number of denoising steps. More denoising steps usually lead to a higher quality image at the
......
...@@ -215,7 +215,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): ...@@ -215,7 +215,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
self, self,
prompt_ids: jnp.array, prompt_ids: jnp.array,
params: Union[Dict, FrozenDict], params: Union[Dict, FrozenDict],
prng_seed: jax.random.KeyArray, prng_seed: jax.Array,
num_inference_steps: int, num_inference_steps: int,
height: int, height: int,
width: int, width: int,
...@@ -312,7 +312,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): ...@@ -312,7 +312,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
self, self,
prompt_ids: jnp.array, prompt_ids: jnp.array,
params: Union[Dict, FrozenDict], params: Union[Dict, FrozenDict],
prng_seed: jax.random.KeyArray, prng_seed: jax.Array,
num_inference_steps: int = 50, num_inference_steps: int = 50,
height: Optional[int] = None, height: Optional[int] = None,
width: Optional[int] = None, width: Optional[int] = None,
......
...@@ -235,7 +235,7 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline): ...@@ -235,7 +235,7 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
prompt_ids: jnp.array, prompt_ids: jnp.array,
image: jnp.array, image: jnp.array,
params: Union[Dict, FrozenDict], params: Union[Dict, FrozenDict],
prng_seed: jax.random.KeyArray, prng_seed: jax.Array,
start_timestep: int, start_timestep: int,
num_inference_steps: int, num_inference_steps: int,
height: int, height: int,
...@@ -340,7 +340,7 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline): ...@@ -340,7 +340,7 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
prompt_ids: jnp.array, prompt_ids: jnp.array,
image: jnp.array, image: jnp.array,
params: Union[Dict, FrozenDict], params: Union[Dict, FrozenDict],
prng_seed: jax.random.KeyArray, prng_seed: jax.Array,
strength: float = 0.8, strength: float = 0.8,
num_inference_steps: int = 50, num_inference_steps: int = 50,
height: Optional[int] = None, height: Optional[int] = None,
...@@ -361,7 +361,7 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline): ...@@ -361,7 +361,7 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
Array representing an image batch to be used as the starting point. Array representing an image batch to be used as the starting point.
params (`Dict` or `FrozenDict`): params (`Dict` or `FrozenDict`):
Dictionary containing the model parameters/weights. Dictionary containing the model parameters/weights.
prng_seed (`jax.random.KeyArray` or `jax.Array`): prng_seed (`jax.Array` or `jax.Array`):
Array containing random number generator key. Array containing random number generator key.
strength (`float`, *optional*, defaults to 0.8): strength (`float`, *optional*, defaults to 0.8):
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
......
...@@ -270,7 +270,7 @@ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline): ...@@ -270,7 +270,7 @@ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline):
mask: jnp.array, mask: jnp.array,
masked_image: jnp.array, masked_image: jnp.array,
params: Union[Dict, FrozenDict], params: Union[Dict, FrozenDict],
prng_seed: jax.random.KeyArray, prng_seed: jax.Array,
num_inference_steps: int, num_inference_steps: int,
height: int, height: int,
width: int, width: int,
...@@ -398,7 +398,7 @@ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline): ...@@ -398,7 +398,7 @@ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline):
mask: jnp.array, mask: jnp.array,
masked_image: jnp.array, masked_image: jnp.array,
params: Union[Dict, FrozenDict], params: Union[Dict, FrozenDict],
prng_seed: jax.random.KeyArray, prng_seed: jax.Array,
num_inference_steps: int = 50, num_inference_steps: int = 50,
height: Optional[int] = None, height: Optional[int] = None,
width: Optional[int] = None, width: Optional[int] = None,
......
...@@ -87,7 +87,7 @@ class FlaxStableDiffusionSafetyChecker(FlaxPreTrainedModel): ...@@ -87,7 +87,7 @@ class FlaxStableDiffusionSafetyChecker(FlaxPreTrainedModel):
module = self.module_class(config=config, dtype=dtype, **kwargs) module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
def init_weights(self, rng: jax.random.KeyArray, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: def init_weights(self, rng: jax.Array, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensor # init input tensor
clip_input = jax.random.normal(rng, input_shape) clip_input = jax.random.normal(rng, input_shape)
......
...@@ -89,7 +89,7 @@ class FlaxStableDiffusionXLPipeline(FlaxDiffusionPipeline): ...@@ -89,7 +89,7 @@ class FlaxStableDiffusionXLPipeline(FlaxDiffusionPipeline):
self, self,
prompt_ids: jax.Array, prompt_ids: jax.Array,
params: Union[Dict, FrozenDict], params: Union[Dict, FrozenDict],
prng_seed: jax.random.KeyArray, prng_seed: jax.Array,
num_inference_steps: int = 50, num_inference_steps: int = 50,
guidance_scale: Union[float, jax.Array] = 7.5, guidance_scale: Union[float, jax.Array] = 7.5,
height: Optional[int] = None, height: Optional[int] = None,
...@@ -170,7 +170,7 @@ class FlaxStableDiffusionXLPipeline(FlaxDiffusionPipeline): ...@@ -170,7 +170,7 @@ class FlaxStableDiffusionXLPipeline(FlaxDiffusionPipeline):
self, self,
prompt_ids: jnp.array, prompt_ids: jnp.array,
params: Union[Dict, FrozenDict], params: Union[Dict, FrozenDict],
prng_seed: jax.random.KeyArray, prng_seed: jax.Array,
num_inference_steps: int, num_inference_steps: int,
height: int, height: int,
width: int, width: int,
......
...@@ -198,7 +198,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin): ...@@ -198,7 +198,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
model_output: jnp.ndarray, model_output: jnp.ndarray,
timestep: int, timestep: int,
sample: jnp.ndarray, sample: jnp.ndarray,
key: Optional[jax.random.KeyArray] = None, key: Optional[jax.Array] = None,
return_dict: bool = True, return_dict: bool = True,
) -> Union[FlaxDDPMSchedulerOutput, Tuple]: ) -> Union[FlaxDDPMSchedulerOutput, Tuple]:
""" """
...@@ -211,7 +211,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin): ...@@ -211,7 +211,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
timestep (`int`): current discrete timestep in the diffusion chain. timestep (`int`): current discrete timestep in the diffusion chain.
sample (`jnp.ndarray`): sample (`jnp.ndarray`):
current instance of sample being created by diffusion process. current instance of sample being created by diffusion process.
key (`jax.random.KeyArray`): a PRNG key. key (`jax.Array`): a PRNG key.
return_dict (`bool`): option for returning tuple rather than FlaxDDPMSchedulerOutput class return_dict (`bool`): option for returning tuple rather than FlaxDDPMSchedulerOutput class
Returns: Returns:
......
...@@ -17,6 +17,7 @@ from dataclasses import dataclass ...@@ -17,6 +17,7 @@ from dataclasses import dataclass
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Union
import flax import flax
import jax
import jax.numpy as jnp import jax.numpy as jnp
from jax import random from jax import random
...@@ -139,7 +140,7 @@ class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin): ...@@ -139,7 +140,7 @@ class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin):
state: KarrasVeSchedulerState, state: KarrasVeSchedulerState,
sample: jnp.ndarray, sample: jnp.ndarray,
sigma: float, sigma: float,
key: random.KeyArray, key: jax.Array,
) -> Tuple[jnp.ndarray, float]: ) -> Tuple[jnp.ndarray, float]:
""" """
Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a
......
...@@ -18,6 +18,7 @@ from dataclasses import dataclass ...@@ -18,6 +18,7 @@ from dataclasses import dataclass
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Union
import flax import flax
import jax
import jax.numpy as jnp import jax.numpy as jnp
from jax import random from jax import random
...@@ -169,7 +170,7 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin): ...@@ -169,7 +170,7 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
model_output: jnp.ndarray, model_output: jnp.ndarray,
timestep: int, timestep: int,
sample: jnp.ndarray, sample: jnp.ndarray,
key: random.KeyArray, key: jax.Array,
return_dict: bool = True, return_dict: bool = True,
) -> Union[FlaxSdeVeOutput, Tuple]: ) -> Union[FlaxSdeVeOutput, Tuple]:
""" """
...@@ -228,7 +229,7 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin): ...@@ -228,7 +229,7 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
state: ScoreSdeVeSchedulerState, state: ScoreSdeVeSchedulerState,
model_output: jnp.ndarray, model_output: jnp.ndarray,
sample: jnp.ndarray, sample: jnp.ndarray,
key: random.KeyArray, key: jax.Array,
return_dict: bool = True, return_dict: bool = True,
) -> Union[FlaxSdeVeOutput, Tuple]: ) -> Union[FlaxSdeVeOutput, Tuple]:
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment