Unverified Commit a6e6b1c6 authored by mariecwhite, committed by GitHub

Remove jnp.DeviceArray since it is deprecated. (#24875)

* Remove jnp.DeviceArray since it is deprecated.

* Replace all instances of jnp.DeviceArray with jax.Array

* Update src/transformers/models/bert/modeling_flax_bert.py

---------
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
parent fdd81aea
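
For context: `jnp.DeviceArray` was the legacy concrete array type exposed through `jax.numpy`. JAX 0.4 introduced the unified `jax.Array` type, deprecated the old alias, and later removed it, so annotations must migrate for the code to keep importing against current JAX. A minimal sketch of the annotation change applied throughout this commit (the function below is illustrative, not taken from the diff):

```python
from typing import Optional

import jax
import jax.numpy as jnp


# Annotations previously written as Optional[jnp.DeviceArray] now use
# Optional[jax.Array]; the old alias is gone from recent JAX releases.
def extend_attention_mask(input_ids: jax.Array, attention_mask: Optional[jax.Array] = None) -> jax.Array:
    # jax.Array is the type of every concrete array jax.numpy produces.
    if attention_mask is None:
        attention_mask = jnp.ones_like(input_ids)
    return attention_mask
```

Since these are only type hints, the change is behavior-preserving.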
@@ -1448,8 +1448,8 @@ class FlaxWhisperForConditionalGeneration(FlaxWhisperPreTrainedModel):
         self,
         decoder_input_ids,
         max_length,
-        attention_mask: Optional[jnp.DeviceArray] = None,
-        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+        attention_mask: Optional[jax.Array] = None,
+        decoder_attention_mask: Optional[jax.Array] = None,
         encoder_outputs=None,
         **kwargs,
     ):
...
@@ -766,7 +766,7 @@ class FlaxXGLMForCausalLMModule(nn.Module):
 class FlaxXGLMForCausalLM(FlaxXGLMPreTrainedModel):
     module_class = FlaxXGLMForCausalLMModule

-    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
         # initializing the cache
         batch_size, seq_length = input_ids.shape
...
@@ -1469,7 +1469,7 @@ class FlaxXLMRobertaForCausalLMModule(nn.Module):
 class FlaxXLMRobertaForCausalLM(FlaxXLMRobertaPreTrainedModel):
     module_class = FlaxXLMRobertaForCausalLMModule

-    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
         # initializing the cache
         batch_size, seq_length = input_ids.shape
...
@@ -1469,7 +1469,7 @@ class Flax{{cookiecutter.camelcase_modelname}}ForCausalLMModule(nn.Module):
 class Flax{{cookiecutter.camelcase_modelname}}ForCausalLM(Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel):
     module_class = Flax{{cookiecutter.camelcase_modelname}}ForCausalLMModule

-    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
         # initializing the cache
         batch_size, seq_length = input_ids.shape
@@ -2969,8 +2969,8 @@ class Flax{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel):
         self,
         decoder_input_ids,
         max_length,
-        attention_mask: Optional[jnp.DeviceArray] = None,
-        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+        attention_mask: Optional[jax.Array] = None,
+        decoder_attention_mask: Optional[jax.Array] = None,
         encoder_outputs=None,
         **kwargs
     ):
...
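
At runtime the two names refer to compatible things, and `jax.Array` also works in `isinstance` checks. A quick sanity check (illustrative, not part of the commit):

```python
import jax
import jax.numpy as jnp

x = jnp.ones((2, 3))
# Every concrete array created via jax.numpy is an instance of jax.Array.
assert isinstance(x, jax.Array)
```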