Unverified Commit 4ac205e3 authored by Roy Hvaara's avatar Roy Hvaara Committed by GitHub
Browse files

[JAX] Replace uses of `jnp.array` in types with `jnp.ndarray`. (#4719)

`jnp.array` is a function, not a type:
https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.array.html


so it never makes sense to use `jnp.array` in a type annotation.

Presumably the intent was to write `jnp.ndarray` aka `jax.Array`. Change uses of `jnp.array` to `jnp.ndarray`.
Co-authored-by: default avatarPeter Hawkins <phawkins@google.com>
parent ed2f9560
......@@ -238,14 +238,14 @@ class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline):
def _generate(
self,
prompt_ids: jnp.array,
image: jnp.array,
prompt_ids: jnp.ndarray,
image: jnp.ndarray,
params: Union[Dict, FrozenDict],
prng_seed: jax.Array,
num_inference_steps: int,
guidance_scale: float,
latents: Optional[jnp.array] = None,
neg_prompt_ids: Optional[jnp.array] = None,
latents: Optional[jnp.ndarray] = None,
neg_prompt_ids: Optional[jnp.ndarray] = None,
controlnet_conditioning_scale: float = 1.0,
):
height, width = image.shape[-2:]
......@@ -348,15 +348,15 @@ class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline):
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt_ids: jnp.array,
image: jnp.array,
prompt_ids: jnp.ndarray,
image: jnp.ndarray,
params: Union[Dict, FrozenDict],
prng_seed: jax.Array,
num_inference_steps: int = 50,
guidance_scale: Union[float, jnp.array] = 7.5,
latents: jnp.array = None,
neg_prompt_ids: jnp.array = None,
controlnet_conditioning_scale: Union[float, jnp.array] = 1.0,
guidance_scale: Union[float, jnp.ndarray] = 7.5,
latents: jnp.ndarray = None,
neg_prompt_ids: jnp.ndarray = None,
controlnet_conditioning_scale: Union[float, jnp.ndarray] = 1.0,
return_dict: bool = True,
jit: bool = False,
):
......@@ -364,13 +364,13 @@ class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline):
The call function to the pipeline for generation.
Args:
prompt_ids (`jnp.array`):
prompt_ids (`jnp.ndarray`):
The prompt or prompts to guide the image generation.
image (`jnp.array`):
image (`jnp.ndarray`):
Array representing the ControlNet input condition to provide guidance to the `unet` for generation.
params (`Dict` or `FrozenDict`):
Dictionary containing the model parameters/weights.
prng_seed (`jax.Array` or `jax.Array`):
prng_seed (`jax.Array`):
Array containing random number generator key.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
......@@ -378,11 +378,11 @@ class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline):
guidance_scale (`float`, *optional*, defaults to 7.5):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
latents (`jnp.array`, *optional*):
latents (`jnp.ndarray`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
array is generated by sampling using the supplied random `generator`.
controlnet_conditioning_scale (`float` or `jnp.array`, *optional*, defaults to 1.0):
controlnet_conditioning_scale (`float` or `jnp.ndarray`, *optional*, defaults to 1.0):
The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
to the residual in the original `unet`.
return_dict (`bool`, *optional*, defaults to `True`):
......
......@@ -220,8 +220,8 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
height: int,
width: int,
guidance_scale: float,
latents: Optional[jnp.array] = None,
neg_prompt_ids: Optional[jnp.array] = None,
latents: Optional[jnp.ndarray] = None,
neg_prompt_ids: Optional[jnp.ndarray] = None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
......@@ -316,9 +316,9 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
num_inference_steps: int = 50,
height: Optional[int] = None,
width: Optional[int] = None,
guidance_scale: Union[float, jnp.array] = 7.5,
latents: jnp.array = None,
neg_prompt_ids: jnp.array = None,
guidance_scale: Union[float, jnp.ndarray] = 7.5,
latents: jnp.ndarray = None,
neg_prompt_ids: jnp.ndarray = None,
return_dict: bool = True,
jit: bool = False,
):
......@@ -338,7 +338,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
guidance_scale (`float`, *optional*, defaults to 7.5):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
latents (`jnp.array`, *optional*):
latents (`jnp.ndarray`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
array is generated by sampling using the supplied random `generator`.
......
......@@ -232,8 +232,8 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
def _generate(
self,
prompt_ids: jnp.array,
image: jnp.array,
prompt_ids: jnp.ndarray,
image: jnp.ndarray,
params: Union[Dict, FrozenDict],
prng_seed: jax.Array,
start_timestep: int,
......@@ -241,8 +241,8 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
height: int,
width: int,
guidance_scale: float,
noise: Optional[jnp.array] = None,
neg_prompt_ids: Optional[jnp.array] = None,
noise: Optional[jnp.ndarray] = None,
neg_prompt_ids: Optional[jnp.ndarray] = None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
......@@ -337,17 +337,17 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt_ids: jnp.array,
image: jnp.array,
prompt_ids: jnp.ndarray,
image: jnp.ndarray,
params: Union[Dict, FrozenDict],
prng_seed: jax.Array,
strength: float = 0.8,
num_inference_steps: int = 50,
height: Optional[int] = None,
width: Optional[int] = None,
guidance_scale: Union[float, jnp.array] = 7.5,
noise: jnp.array = None,
neg_prompt_ids: jnp.array = None,
guidance_scale: Union[float, jnp.ndarray] = 7.5,
noise: jnp.ndarray = None,
neg_prompt_ids: jnp.ndarray = None,
return_dict: bool = True,
jit: bool = False,
):
......@@ -355,9 +355,9 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
The call function to the pipeline for generation.
Args:
prompt_ids (`jnp.array`):
prompt_ids (`jnp.ndarray`):
The prompt or prompts to guide image generation.
image (`jnp.array`):
image (`jnp.ndarray`):
Array representing an image batch to be used as the starting point.
params (`Dict` or `FrozenDict`):
Dictionary containing the model parameters/weights.
......@@ -379,7 +379,7 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
guidance_scale (`float`, *optional*, defaults to 7.5):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
noise (`jnp.array`, *optional*):
noise (`jnp.ndarray`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. The array is generated by
sampling using the supplied random `generator`.
......
......@@ -266,17 +266,17 @@ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline):
def _generate(
self,
prompt_ids: jnp.array,
mask: jnp.array,
masked_image: jnp.array,
prompt_ids: jnp.ndarray,
mask: jnp.ndarray,
masked_image: jnp.ndarray,
params: Union[Dict, FrozenDict],
prng_seed: jax.Array,
num_inference_steps: int,
height: int,
width: int,
guidance_scale: float,
latents: Optional[jnp.array] = None,
neg_prompt_ids: Optional[jnp.array] = None,
latents: Optional[jnp.ndarray] = None,
neg_prompt_ids: Optional[jnp.ndarray] = None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
......@@ -394,17 +394,17 @@ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline):
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt_ids: jnp.array,
mask: jnp.array,
masked_image: jnp.array,
prompt_ids: jnp.ndarray,
mask: jnp.ndarray,
masked_image: jnp.ndarray,
params: Union[Dict, FrozenDict],
prng_seed: jax.Array,
num_inference_steps: int = 50,
height: Optional[int] = None,
width: Optional[int] = None,
guidance_scale: Union[float, jnp.array] = 7.5,
latents: jnp.array = None,
neg_prompt_ids: jnp.array = None,
guidance_scale: Union[float, jnp.ndarray] = 7.5,
latents: jnp.ndarray = None,
neg_prompt_ids: jnp.ndarray = None,
return_dict: bool = True,
jit: bool = False,
):
......@@ -424,7 +424,7 @@ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline):
guidance_scale (`float`, *optional*, defaults to 7.5):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
latents (`jnp.array`, *optional*):
latents (`jnp.ndarray`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
array is generated by sampling using the supplied random `generator`.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment