"...git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "c6eba8c0a1bd47a77cb225482383bd079cdaa66a"
Unverified commit fc639143 authored by Roy Hvaara, committed by GitHub

[JAX] Replace uses of `jnp.array` in types with `jnp.ndarray`. (#26703)

`jnp.array` is a function, not a type:
https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.array.html


so it never makes sense to use `jnp.array` in a type annotation. Presumably the intent was to write `jnp.ndarray` aka `jax.Array`.
Co-authored-by: Peter Hawkins <phawkins@google.com>
parent 3eceaa36
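
To see why the old annotation was wrong: `jnp.array` names the array-construction function, while `jnp.ndarray` (an alias of `jax.Array`) names the type of the arrays it returns. A minimal illustration of the distinction; the `double` function below is hypothetical, written only for this example:

import jax.numpy as jnp

# jnp.array is a function: calling it constructs an array.
x = jnp.array([1.0, 2.0, 3.0])

# jnp.ndarray is the type of that result, so it is what belongs in annotations.
def double(v: jnp.ndarray) -> jnp.ndarray:
    return 2 * v

# The value built by jnp.array is an instance of jnp.ndarray.
assert isinstance(x, jnp.ndarray)
assert isinstance(double(x), jnp.ndarray)

Annotating with `jnp.array` never raised an error only because Python accepts any object as an annotation value at runtime; static type checkers reject a function where a type is expected.
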
@@ -381,7 +381,7 @@ def write_metric(summary_writer, metrics, train_time, step, metric_key_prefix="t
 def create_learning_rate_fn(
     train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
-) -> Callable[[int], jnp.array]:
+) -> Callable[[int], jnp.ndarray]:
     """Returns a linear warmup, linear_decay learning rate function."""
     steps_per_epoch = train_ds_size // train_batch_size
     num_train_steps = steps_per_epoch * num_train_epochs
@@ -326,7 +326,7 @@ def write_eval_metric(summary_writer, eval_metrics, step):
 def create_learning_rate_fn(
     train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
-) -> Callable[[int], jnp.array]:
+) -> Callable[[int], jnp.ndarray]:
     """Returns a linear warmup, linear_decay learning rate function."""
     steps_per_epoch = train_ds_size // train_batch_size
     num_train_steps = steps_per_epoch * num_train_epochs
@@ -389,7 +389,7 @@ def create_train_state(
 # region Create learning rate function
 def create_learning_rate_fn(
     train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
-) -> Callable[[int], jnp.array]:
+) -> Callable[[int], jnp.ndarray]:
     """Returns a linear warmup, linear_decay learning rate function."""
     steps_per_epoch = train_ds_size // train_batch_size
     num_train_steps = steps_per_epoch * num_train_epochs
@@ -360,7 +360,7 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
 def create_learning_rate_fn(
     num_train_steps: int, num_warmup_steps: int, learning_rate: float
-) -> Callable[[int], jnp.array]:
+) -> Callable[[int], jnp.ndarray]:
     """Returns a linear warmup, linear_decay learning rate function."""
     warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
     decay_fn = optax.linear_schedule(
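
The hunk above truncates the function body. Based on the optax calls that are visible, the warmup and decay pieces are presumably joined roughly like this; a sketch, not the verbatim source of the patched file:

from typing import Callable

import jax.numpy as jnp
import optax


def create_learning_rate_fn(
    num_train_steps: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.ndarray]:
    """Returns a linear warmup, linear_decay learning rate function."""
    # Ramp from 0 up to the peak learning rate over the warmup steps...
    warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
    # ...then decay linearly back to 0 over the remaining training steps.
    decay_fn = optax.linear_schedule(
        init_value=learning_rate, end_value=0.0, transition_steps=num_train_steps - num_warmup_steps
    )
    # Assumed continuation: join_schedules switches from warmup to decay
    # at step num_warmup_steps.
    return optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])

Calling the returned schedule with a step index yields the learning rate for that step, which is how optax optimizers consume it; the `Callable[[int], jnp.ndarray]` annotation describes exactly that shape.
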
@@ -409,7 +409,7 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
 def create_learning_rate_fn(
     train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
-) -> Callable[[int], jnp.array]:
+) -> Callable[[int], jnp.ndarray]:
     """Returns a linear warmup, linear_decay learning rate function."""
     steps_per_epoch = train_ds_size // train_batch_size
     num_train_steps = steps_per_epoch * num_train_epochs
@@ -288,7 +288,7 @@ def create_train_state(
 def create_learning_rate_fn(
     train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
-) -> Callable[[int], jnp.array]:
+) -> Callable[[int], jnp.ndarray]:
     """Returns a linear warmup, linear_decay learning rate function."""
     steps_per_epoch = train_ds_size // train_batch_size
     num_train_steps = steps_per_epoch * num_train_epochs
@@ -340,7 +340,7 @@ def create_train_state(
 def create_learning_rate_fn(
     train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
-) -> Callable[[int], jnp.array]:
+) -> Callable[[int], jnp.ndarray]:
     """Returns a linear warmup, linear_decay learning rate function."""
     steps_per_epoch = train_ds_size // train_batch_size
     num_train_steps = steps_per_epoch * num_train_epochs
@@ -249,7 +249,7 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
 def create_learning_rate_fn(
     train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
-) -> Callable[[int], jnp.array]:
+) -> Callable[[int], jnp.ndarray]:
     """Returns a linear warmup, linear_decay learning rate function."""
     steps_per_epoch = train_ds_size // train_batch_size
     num_train_steps = steps_per_epoch * num_train_epochs
@@ -283,7 +283,7 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
 def create_learning_rate_fn(
     train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
-) -> Callable[[int], jnp.array]:
+) -> Callable[[int], jnp.ndarray]:
     """Returns a linear warmup, linear_decay learning rate function."""
     steps_per_epoch = train_ds_size // train_batch_size
     num_train_steps = steps_per_epoch * num_train_epochs
@@ -214,7 +214,7 @@ def write_eval_metric(summary_writer, eval_metrics, step):
 def create_learning_rate_fn(
     train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
-) -> Callable[[int], jnp.array]:
+) -> Callable[[int], jnp.ndarray]:
     """Returns a linear warmup, linear_decay learning rate function."""
     steps_per_epoch = train_ds_size // train_batch_size
     num_train_steps = steps_per_epoch * num_train_epochs
@@ -217,7 +217,7 @@ BART_DECODE_INPUTS_DOCSTRING = r"""
 """
-def shift_tokens_right(input_ids: jnp.array, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
+def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
     """
     Shift input ids one token to the right.
     """
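
The hunk elides the body of `shift_tokens_right`. For reference, the Flax BART version of this helper in transformers works roughly as follows; a sketch from memory of the library, not quoted from the patched file:

import jax.numpy as jnp


def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
    """Shift input ids one token to the right."""
    shifted_input_ids = jnp.zeros_like(input_ids)
    # Every token moves one slot to the right; slot 0 receives the decoder start token.
    shifted_input_ids = shifted_input_ids.at[:, 1:].set(input_ids[:, :-1])
    shifted_input_ids = shifted_input_ids.at[:, 0].set(decoder_start_token_id)
    # Label tensors mark ignored positions with -100; map those to the pad token.
    return jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
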
@@ -295,7 +295,7 @@ class FlaxBertSelfAttention(nn.Module):
         hidden_states,
         attention_mask,
         layer_head_mask,
-        key_value_states: Optional[jnp.array] = None,
+        key_value_states: Optional[jnp.ndarray] = None,
         init_cache: bool = False,
         deterministic=True,
         output_attentions: bool = False,
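
For context on the changed annotation: `key_value_states` is `Optional[jnp.ndarray]` because the same attention module handles both self-attention and cross-attention. A simplified, hypothetical sketch of that dispatch (not the module's actual code):

from typing import Optional

import jax.numpy as jnp


def select_kv_source(hidden_states: jnp.ndarray, key_value_states: Optional[jnp.ndarray] = None) -> jnp.ndarray:
    # Self-attention: keys and values are projected from hidden_states itself.
    # Cross-attention: they are projected from the encoder output passed in
    # as key_value_states.
    return hidden_states if key_value_states is None else key_value_states
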
@@ -316,7 +316,7 @@ class FlaxBigBirdSelfAttention(nn.Module):
         hidden_states,
         attention_mask,
         layer_head_mask,
-        key_value_states: Optional[jnp.array] = None,
+        key_value_states: Optional[jnp.ndarray] = None,
         init_cache: bool = False,
         deterministic=True,
         output_attentions: bool = False,
@@ -204,7 +204,7 @@ BLENDERBOT_DECODE_INPUTS_DOCSTRING = r"""
 # Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
-def shift_tokens_right(input_ids: jnp.array, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
+def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
     """
     Shift input ids one token to the right.
     """
@@ -216,7 +216,7 @@ BLENDERBOT_SMALL_DECODE_INPUTS_DOCSTRING = r"""
 # Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
-def shift_tokens_right(input_ids: jnp.array, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
+def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
     """
     Shift input ids one token to the right.
     """
@@ -263,7 +263,7 @@ class FlaxElectraSelfAttention(nn.Module):
         hidden_states,
         attention_mask,
         layer_head_mask,
-        key_value_states: Optional[jnp.array] = None,
+        key_value_states: Optional[jnp.ndarray] = None,
         init_cache: bool = False,
         deterministic=True,
         output_attentions: bool = False,
@@ -1228,13 +1228,13 @@ class FlaxElectraSequenceSummary(nn.Module):
         Compute a single vector summary of a sequence hidden states.
         Args:
-            hidden_states (`jnp.array` of shape `[batch_size, seq_len, hidden_size]`):
+            hidden_states (`jnp.ndarray` of shape `[batch_size, seq_len, hidden_size]`):
                 The hidden states of the last layer.
-            cls_index (`jnp.array` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
+            cls_index (`jnp.ndarray` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
                 Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token.
         Returns:
-            `jnp.array`: The summary of the sequence hidden states.
+            `jnp.ndarray`: The summary of the sequence hidden states.
         """
         # NOTE: this doest "first" type summary always
         output = hidden_states[:, 0]
@@ -56,7 +56,7 @@ remat = nn_partitioning.remat
 # Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
-def shift_tokens_right(input_ids: jnp.array, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
+def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
     """
     Shift input ids one token to the right.
     """
@@ -227,7 +227,7 @@ def create_sinusoidal_positions(n_pos, dim):
 # Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
-def shift_tokens_right(input_ids: jnp.array, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
+def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
     """
     Shift input ids one token to the right.
     """
@@ -27,7 +27,7 @@ _CONFIG_FOR_DOC = "T5Config"
 # Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
-def shift_tokens_right(input_ids: jnp.array, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
+def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
     """
     Shift input ids one token to the right.
     """
@@ -210,7 +210,7 @@ PEGASUS_DECODE_INPUTS_DOCSTRING = r"""
 # Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
-def shift_tokens_right(input_ids: jnp.array, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
+def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
     """
     Shift input ids one token to the right.
     """