Unverified commit fc639143 authored by Roy Hvaara, committed by GitHub

[JAX] Replace uses of `jnp.array` in types with `jnp.ndarray`. (#26703)

`jnp.array` is a function, not a type:
https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.array.html


so it never makes sense to use `jnp.array` in a type annotation. Presumably the intent was to write `jnp.ndarray` aka `jax.Array`.
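
For context, a minimal sketch of the distinction this change relies on (the `scale` function below is purely illustrative and not part of the diff): `jnp.array` is a constructor for array values, while `jnp.ndarray` (an alias of `jax.Array`) is the array type and is what belongs in annotations.

    from typing import Optional

    import jax.numpy as jnp


    def scale(x: jnp.ndarray, factor: Optional[jnp.ndarray] = None) -> jnp.ndarray:
        # jnp.array (the function) builds values; jnp.ndarray (the type) annotates them.
        if factor is None:
            factor = jnp.array(2.0)
        return x * factor


    print(scale(jnp.arange(3)))  # [0. 2. 4.]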
Co-authored-by: Peter Hawkins <phawkins@google.com>
parent 3eceaa36
@@ -256,7 +256,7 @@ class FlaxRobertaSelfAttention(nn.Module):
         hidden_states,
         attention_mask,
         layer_head_mask,
-        key_value_states: Optional[jnp.array] = None,
+        key_value_states: Optional[jnp.ndarray] = None,
         init_cache: bool = False,
         deterministic=True,
         output_attentions: bool = False,
@@ -258,7 +258,7 @@ class FlaxRobertaPreLayerNormSelfAttention(nn.Module):
         hidden_states,
         attention_mask,
         layer_head_mask,
-        key_value_states: Optional[jnp.array] = None,
+        key_value_states: Optional[jnp.ndarray] = None,
         init_cache: bool = False,
         deterministic=True,
         output_attentions: bool = False,
@@ -56,7 +56,7 @@ remat = nn_partitioning.remat
 # Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
-def shift_tokens_right(input_ids: jnp.array, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
+def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
     """
     Shift input ids one token to the right.
     """
@@ -266,7 +266,7 @@ class FlaxXLMRobertaSelfAttention(nn.Module):
         hidden_states,
         attention_mask,
         layer_head_mask,
-        key_value_states: Optional[jnp.array] = None,
+        key_value_states: Optional[jnp.ndarray] = None,
         init_cache: bool = False,
         deterministic=True,
         output_attentions: bool = False,
@@ -251,7 +251,7 @@ class Flax{{cookiecutter.camelcase_modelname}}SelfAttention(nn.Module):
         hidden_states,
         attention_mask,
         layer_head_mask,
-        key_value_states: Optional[jnp.array] = None,
+        key_value_states: Optional[jnp.ndarray] = None,
         init_cache: bool = False,
         deterministic=True,
         output_attentions: bool = False,