Unverified Commit 4ac205e3 authored by Roy Hvaara's avatar Roy Hvaara Committed by GitHub
Browse files

[JAX] Replace uses of `jnp.array` in types with `jnp.ndarray`. (#4719)

`jnp.array` is a function, not a type:
https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.array.html


so it never makes sense to use `jnp.array` in a type annotation.

Presumably the intent was to write `jnp.ndarray` aka `jax.Array`. Change uses of `jnp.array` to `jnp.ndarray`.
Co-authored-by: default avatarPeter Hawkins <phawkins@google.com>
parent ed2f9560
...@@ -238,14 +238,14 @@ class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline): ...@@ -238,14 +238,14 @@ class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline):
def _generate( def _generate(
self, self,
prompt_ids: jnp.array, prompt_ids: jnp.ndarray,
image: jnp.array, image: jnp.ndarray,
params: Union[Dict, FrozenDict], params: Union[Dict, FrozenDict],
prng_seed: jax.Array, prng_seed: jax.Array,
num_inference_steps: int, num_inference_steps: int,
guidance_scale: float, guidance_scale: float,
latents: Optional[jnp.array] = None, latents: Optional[jnp.ndarray] = None,
neg_prompt_ids: Optional[jnp.array] = None, neg_prompt_ids: Optional[jnp.ndarray] = None,
controlnet_conditioning_scale: float = 1.0, controlnet_conditioning_scale: float = 1.0,
): ):
height, width = image.shape[-2:] height, width = image.shape[-2:]
...@@ -348,15 +348,15 @@ class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline): ...@@ -348,15 +348,15 @@ class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline):
@replace_example_docstring(EXAMPLE_DOC_STRING) @replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__( def __call__(
self, self,
prompt_ids: jnp.array, prompt_ids: jnp.ndarray,
image: jnp.array, image: jnp.ndarray,
params: Union[Dict, FrozenDict], params: Union[Dict, FrozenDict],
prng_seed: jax.Array, prng_seed: jax.Array,
num_inference_steps: int = 50, num_inference_steps: int = 50,
guidance_scale: Union[float, jnp.array] = 7.5, guidance_scale: Union[float, jnp.ndarray] = 7.5,
latents: jnp.array = None, latents: jnp.ndarray = None,
neg_prompt_ids: jnp.array = None, neg_prompt_ids: jnp.ndarray = None,
controlnet_conditioning_scale: Union[float, jnp.array] = 1.0, controlnet_conditioning_scale: Union[float, jnp.ndarray] = 1.0,
return_dict: bool = True, return_dict: bool = True,
jit: bool = False, jit: bool = False,
): ):
...@@ -364,13 +364,13 @@ class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline): ...@@ -364,13 +364,13 @@ class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline):
The call function to the pipeline for generation. The call function to the pipeline for generation.
Args: Args:
prompt_ids (`jnp.array`): prompt_ids (`jnp.ndarray`):
The prompt or prompts to guide the image generation. The prompt or prompts to guide the image generation.
image (`jnp.array`): image (`jnp.ndarray`):
Array representing the ControlNet input condition to provide guidance to the `unet` for generation. Array representing the ControlNet input condition to provide guidance to the `unet` for generation.
params (`Dict` or `FrozenDict`): params (`Dict` or `FrozenDict`):
Dictionary containing the model parameters/weights. Dictionary containing the model parameters/weights.
prng_seed (`jax.Array` or `jax.Array`): prng_seed (`jax.Array`):
Array containing random number generator key. Array containing random number generator key.
num_inference_steps (`int`, *optional*, defaults to 50): num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the The number of denoising steps. More denoising steps usually lead to a higher quality image at the
...@@ -378,11 +378,11 @@ class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline): ...@@ -378,11 +378,11 @@ class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline):
guidance_scale (`float`, *optional*, defaults to 7.5): guidance_scale (`float`, *optional*, defaults to 7.5):
A higher guidance scale value encourages the model to generate images closely linked to the text A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
latents (`jnp.array`, *optional*): latents (`jnp.ndarray`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
array is generated by sampling using the supplied random `generator`. array is generated by sampling using the supplied random `generator`.
controlnet_conditioning_scale (`float` or `jnp.array`, *optional*, defaults to 1.0): controlnet_conditioning_scale (`float` or `jnp.ndarray`, *optional*, defaults to 1.0):
The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
to the residual in the original `unet`. to the residual in the original `unet`.
return_dict (`bool`, *optional*, defaults to `True`): return_dict (`bool`, *optional*, defaults to `True`):
......
...@@ -220,8 +220,8 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): ...@@ -220,8 +220,8 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
height: int, height: int,
width: int, width: int,
guidance_scale: float, guidance_scale: float,
latents: Optional[jnp.array] = None, latents: Optional[jnp.ndarray] = None,
neg_prompt_ids: Optional[jnp.array] = None, neg_prompt_ids: Optional[jnp.ndarray] = None,
): ):
if height % 8 != 0 or width % 8 != 0: if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
...@@ -316,9 +316,9 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): ...@@ -316,9 +316,9 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
num_inference_steps: int = 50, num_inference_steps: int = 50,
height: Optional[int] = None, height: Optional[int] = None,
width: Optional[int] = None, width: Optional[int] = None,
guidance_scale: Union[float, jnp.array] = 7.5, guidance_scale: Union[float, jnp.ndarray] = 7.5,
latents: jnp.array = None, latents: jnp.ndarray = None,
neg_prompt_ids: jnp.array = None, neg_prompt_ids: jnp.ndarray = None,
return_dict: bool = True, return_dict: bool = True,
jit: bool = False, jit: bool = False,
): ):
...@@ -338,7 +338,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): ...@@ -338,7 +338,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
guidance_scale (`float`, *optional*, defaults to 7.5): guidance_scale (`float`, *optional*, defaults to 7.5):
A higher guidance scale value encourages the model to generate images closely linked to the text A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
latents (`jnp.array`, *optional*): latents (`jnp.ndarray`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
array is generated by sampling using the supplied random `generator`. array is generated by sampling using the supplied random `generator`.
......
...@@ -232,8 +232,8 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline): ...@@ -232,8 +232,8 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
def _generate( def _generate(
self, self,
prompt_ids: jnp.array, prompt_ids: jnp.ndarray,
image: jnp.array, image: jnp.ndarray,
params: Union[Dict, FrozenDict], params: Union[Dict, FrozenDict],
prng_seed: jax.Array, prng_seed: jax.Array,
start_timestep: int, start_timestep: int,
...@@ -241,8 +241,8 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline): ...@@ -241,8 +241,8 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
height: int, height: int,
width: int, width: int,
guidance_scale: float, guidance_scale: float,
noise: Optional[jnp.array] = None, noise: Optional[jnp.ndarray] = None,
neg_prompt_ids: Optional[jnp.array] = None, neg_prompt_ids: Optional[jnp.ndarray] = None,
): ):
if height % 8 != 0 or width % 8 != 0: if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
...@@ -337,17 +337,17 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline): ...@@ -337,17 +337,17 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
@replace_example_docstring(EXAMPLE_DOC_STRING) @replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__( def __call__(
self, self,
prompt_ids: jnp.array, prompt_ids: jnp.ndarray,
image: jnp.array, image: jnp.ndarray,
params: Union[Dict, FrozenDict], params: Union[Dict, FrozenDict],
prng_seed: jax.Array, prng_seed: jax.Array,
strength: float = 0.8, strength: float = 0.8,
num_inference_steps: int = 50, num_inference_steps: int = 50,
height: Optional[int] = None, height: Optional[int] = None,
width: Optional[int] = None, width: Optional[int] = None,
guidance_scale: Union[float, jnp.array] = 7.5, guidance_scale: Union[float, jnp.ndarray] = 7.5,
noise: jnp.array = None, noise: jnp.ndarray = None,
neg_prompt_ids: jnp.array = None, neg_prompt_ids: jnp.ndarray = None,
return_dict: bool = True, return_dict: bool = True,
jit: bool = False, jit: bool = False,
): ):
...@@ -355,9 +355,9 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline): ...@@ -355,9 +355,9 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
The call function to the pipeline for generation. The call function to the pipeline for generation.
Args: Args:
prompt_ids (`jnp.array`): prompt_ids (`jnp.ndarray`):
The prompt or prompts to guide image generation. The prompt or prompts to guide image generation.
image (`jnp.array`): image (`jnp.ndarray`):
Array representing an image batch to be used as the starting point. Array representing an image batch to be used as the starting point.
params (`Dict` or `FrozenDict`): params (`Dict` or `FrozenDict`):
Dictionary containing the model parameters/weights. Dictionary containing the model parameters/weights.
...@@ -379,7 +379,7 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline): ...@@ -379,7 +379,7 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
guidance_scale (`float`, *optional*, defaults to 7.5): guidance_scale (`float`, *optional*, defaults to 7.5):
A higher guidance scale value encourages the model to generate images closely linked to the text A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
noise (`jnp.array`, *optional*): noise (`jnp.ndarray`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution to be used as inputs for image Pre-generated noisy latents sampled from a Gaussian distribution to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. The array is generated by generation. Can be used to tweak the same generation with different prompts. The array is generated by
sampling using the supplied random `generator`. sampling using the supplied random `generator`.
......
...@@ -266,17 +266,17 @@ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline): ...@@ -266,17 +266,17 @@ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline):
def _generate( def _generate(
self, self,
prompt_ids: jnp.array, prompt_ids: jnp.ndarray,
mask: jnp.array, mask: jnp.ndarray,
masked_image: jnp.array, masked_image: jnp.ndarray,
params: Union[Dict, FrozenDict], params: Union[Dict, FrozenDict],
prng_seed: jax.Array, prng_seed: jax.Array,
num_inference_steps: int, num_inference_steps: int,
height: int, height: int,
width: int, width: int,
guidance_scale: float, guidance_scale: float,
latents: Optional[jnp.array] = None, latents: Optional[jnp.ndarray] = None,
neg_prompt_ids: Optional[jnp.array] = None, neg_prompt_ids: Optional[jnp.ndarray] = None,
): ):
if height % 8 != 0 or width % 8 != 0: if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
...@@ -394,17 +394,17 @@ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline): ...@@ -394,17 +394,17 @@ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline):
@replace_example_docstring(EXAMPLE_DOC_STRING) @replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__( def __call__(
self, self,
prompt_ids: jnp.array, prompt_ids: jnp.ndarray,
mask: jnp.array, mask: jnp.ndarray,
masked_image: jnp.array, masked_image: jnp.ndarray,
params: Union[Dict, FrozenDict], params: Union[Dict, FrozenDict],
prng_seed: jax.Array, prng_seed: jax.Array,
num_inference_steps: int = 50, num_inference_steps: int = 50,
height: Optional[int] = None, height: Optional[int] = None,
width: Optional[int] = None, width: Optional[int] = None,
guidance_scale: Union[float, jnp.array] = 7.5, guidance_scale: Union[float, jnp.ndarray] = 7.5,
latents: jnp.array = None, latents: jnp.ndarray = None,
neg_prompt_ids: jnp.array = None, neg_prompt_ids: jnp.ndarray = None,
return_dict: bool = True, return_dict: bool = True,
jit: bool = False, jit: bool = False,
): ):
...@@ -424,7 +424,7 @@ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline): ...@@ -424,7 +424,7 @@ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline):
guidance_scale (`float`, *optional*, defaults to 7.5): guidance_scale (`float`, *optional*, defaults to 7.5):
A higher guidance scale value encourages the model to generate images closely linked to the text A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
latents (`jnp.array`, *optional*): latents (`jnp.ndarray`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
array is generated by sampling using the supplied random `generator`. array is generated by sampling using the supplied random `generator`.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment