Unverified Commit a6e6b1c6 authored by mariecwhite, committed by GitHub

Remove jnp.DeviceArray since it is deprecated. (#24875)



* Remove jnp.DeviceArray since it is deprecated.

* Replace all instances of jnp.DeviceArray with jax.Array

* Update src/transformers/models/bert/modeling_flax_bert.py

---------
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
parent fdd81aea
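For context: `jnp.DeviceArray` was deprecated and later removed upstream in favor of `jax.Array`, the unified array type that recent JAX releases return from all operations regardless of backend. A minimal sketch of the annotation style this commit adopts (the `prepare_mask` helper is illustrative only, not part of this diff; assumes a JAX version where `jax.Array` exists):

from typing import Optional

import jax
import jax.numpy as jnp


def prepare_mask(input_ids: jax.Array, attention_mask: Optional[jax.Array] = None) -> jax.Array:
    # Mirror the pattern in the diff: default to attending over every position
    # when no mask is supplied.
    if attention_mask is None:
        attention_mask = jnp.ones_like(input_ids)
    return attention_mask


x = jnp.ones((2, 3), dtype=jnp.int32)
print(isinstance(x, jax.Array))  # True: every jnp array is a jax.Array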
@@ -1448,8 +1448,8 @@ class FlaxWhisperForConditionalGeneration(FlaxWhisperPreTrainedModel):
         self,
         decoder_input_ids,
         max_length,
-        attention_mask: Optional[jnp.DeviceArray] = None,
-        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+        attention_mask: Optional[jax.Array] = None,
+        decoder_attention_mask: Optional[jax.Array] = None,
         encoder_outputs=None,
         **kwargs,
     ):
@@ -766,7 +766,7 @@ class FlaxXGLMForCausalLMModule(nn.Module):
 class FlaxXGLMForCausalLM(FlaxXGLMPreTrainedModel):
     module_class = FlaxXGLMForCausalLMModule
 
-    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
         # initializing the cache
         batch_size, seq_length = input_ids.shape
@@ -1469,7 +1469,7 @@ class FlaxXLMRobertaForCausalLMModule(nn.Module):
 class FlaxXLMRobertaForCausalLM(FlaxXLMRobertaPreTrainedModel):
     module_class = FlaxXLMRobertaForCausalLMModule
 
-    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
         # initializing the cache
         batch_size, seq_length = input_ids.shape
@@ -1469,7 +1469,7 @@ class Flax{{cookiecutter.camelcase_modelname}}ForCausalLMModule(nn.Module):
 class Flax{{cookiecutter.camelcase_modelname}}ForCausalLM(Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel):
     module_class = Flax{{cookiecutter.camelcase_modelname}}ForCausalLMModule
 
-    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
         # initializing the cache
         batch_size, seq_length = input_ids.shape
@@ -2969,8 +2969,8 @@ class Flax{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel):
         self,
         decoder_input_ids,
         max_length,
-        attention_mask: Optional[jnp.DeviceArray] = None,
-        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+        attention_mask: Optional[jax.Array] = None,
+        decoder_attention_mask: Optional[jax.Array] = None,
         encoder_outputs=None,
         **kwargs
     ):