Unverified Commit a6e6b1c6 authored by mariecwhite, committed by GitHub

Remove jnp.DeviceArray since it is deprecated. (#24875)

* Remove jnp.DeviceArray since it is deprecated.

* Replace all instances of jnp.DeviceArray with jax.Array

* Update src/transformers/models/bert/modeling_flax_bert.py

---------
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
parent fdd81aea
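For context, `jnp.DeviceArray` was deprecated and then removed from JAX in favor of the unified `jax.Array` type, which is both the annotation and the concrete runtime class returned by `jax.numpy` operations. The snippet below is a minimal sketch of the pattern this commit rewrites, assuming a JAX release that ships `jax.Array` (>= 0.4); the function body is simplified for illustration and is not code from this commit.

```python
# Minimal sketch of the annotation migration (assumes JAX >= 0.4).
from typing import Optional

import jax
import jax.numpy as jnp


def prepare_inputs_for_generation(
    input_ids: jax.Array,
    max_length: int,
    attention_mask: Optional[jax.Array] = None,  # previously Optional[jnp.DeviceArray]
):
    # Simplified body for illustration only.
    batch_size, _ = input_ids.shape
    if attention_mask is None:
        attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
    return {"attention_mask": attention_mask}


# jax.Array is also the concrete runtime type, so the annotations stay truthful:
x = jnp.zeros((2, 8), dtype="i4")
assert isinstance(x, jax.Array)
```

Since the diff below only touches type annotations, no runtime behavior of the models changes.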
src/transformers/models/bart/modeling_flax_bart.py
@@ -1467,8 +1467,8 @@ class FlaxBartForConditionalGeneration(FlaxBartPreTrainedModel):
         self,
         decoder_input_ids,
         max_length,
-        attention_mask: Optional[jnp.DeviceArray] = None,
-        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+        attention_mask: Optional[jax.Array] = None,
+        decoder_attention_mask: Optional[jax.Array] = None,
         encoder_outputs=None,
         **kwargs,
     ):
@@ -1960,7 +1960,7 @@ class FlaxBartForCausalLMModule(nn.Module):
 class FlaxBartForCausalLM(FlaxBartDecoderPreTrainedModel):
     module_class = FlaxBartForCausalLMModule
 
-    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
         # initializing the cache
         batch_size, seq_length = input_ids.shape
src/transformers/models/bert/modeling_flax_bert.py
@@ -1677,7 +1677,7 @@ class FlaxBertForCausalLMModule(nn.Module):
 class FlaxBertForCausalLM(FlaxBertPreTrainedModel):
     module_class = FlaxBertForCausalLMModule
 
-    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
         # initializing the cache
         batch_size, seq_length = input_ids.shape
src/transformers/models/big_bird/modeling_flax_big_bird.py
@@ -2599,7 +2599,7 @@ class FlaxBigBirdForCausalLMModule(nn.Module):
 class FlaxBigBirdForCausalLM(FlaxBigBirdPreTrainedModel):
     module_class = FlaxBigBirdForCausalLMModule
 
-    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
         # initializing the cache
         batch_size, seq_length = input_ids.shape
src/transformers/models/blenderbot/modeling_flax_blenderbot.py
@@ -1443,8 +1443,8 @@ class FlaxBlenderbotForConditionalGeneration(FlaxBlenderbotPreTrainedModel):
         self,
         decoder_input_ids,
         max_length,
-        attention_mask: Optional[jnp.DeviceArray] = None,
-        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+        attention_mask: Optional[jax.Array] = None,
+        decoder_attention_mask: Optional[jax.Array] = None,
         encoder_outputs=None,
         **kwargs,
     ):
src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py
@@ -1441,8 +1441,8 @@ class FlaxBlenderbotSmallForConditionalGeneration(FlaxBlenderbotSmallPreTrainedModel):
         self,
         decoder_input_ids,
         max_length,
-        attention_mask: Optional[jnp.DeviceArray] = None,
-        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+        attention_mask: Optional[jax.Array] = None,
+        decoder_attention_mask: Optional[jax.Array] = None,
         encoder_outputs=None,
         **kwargs,
     ):
src/transformers/models/electra/modeling_flax_electra.py
@@ -1565,7 +1565,7 @@ class FlaxElectraForCausalLMModule(nn.Module):
 class FlaxElectraForCausalLM(FlaxElectraPreTrainedModel):
     module_class = FlaxElectraForCausalLMModule
 
-    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
         # initializing the cache
         batch_size, seq_length = input_ids.shape
src/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py
@@ -722,8 +722,8 @@ class FlaxEncoderDecoderModel(FlaxPreTrainedModel):
         self,
         decoder_input_ids,
         max_length,
-        attention_mask: Optional[jnp.DeviceArray] = None,
-        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+        attention_mask: Optional[jax.Array] = None,
+        decoder_attention_mask: Optional[jax.Array] = None,
         encoder_outputs=None,
         **kwargs,
     ):
src/transformers/models/gpt2/modeling_flax_gpt2.py
@@ -742,7 +742,7 @@ class FlaxGPT2LMHeadModule(nn.Module):
 class FlaxGPT2LMHeadModel(FlaxGPT2PreTrainedModel):
     module_class = FlaxGPT2LMHeadModule
 
-    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
         # initializing the cache
         batch_size, seq_length = input_ids.shape
src/transformers/models/gpt_neo/modeling_flax_gpt_neo.py
@@ -654,7 +654,7 @@ class FlaxGPTNeoForCausalLMModule(nn.Module):
 class FlaxGPTNeoForCausalLM(FlaxGPTNeoPreTrainedModel):
     module_class = FlaxGPTNeoForCausalLMModule
 
-    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
         # initializing the cache
         batch_size, seq_length = input_ids.shape
src/transformers/models/gptj/modeling_flax_gptj.py
@@ -683,7 +683,7 @@ class FlaxGPTJForCausalLMModule(nn.Module):
 class FlaxGPTJForCausalLM(FlaxGPTJPreTrainedModel):
     module_class = FlaxGPTJForCausalLMModule
 
-    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
         # initializing the cache
         batch_size, seq_length = input_ids.shape
src/transformers/models/longt5/modeling_flax_longt5.py
@@ -2388,8 +2388,8 @@ class FlaxLongT5ForConditionalGeneration(FlaxLongT5PreTrainedModel):
         self,
         decoder_input_ids,
         max_length,
-        attention_mask: Optional[jnp.DeviceArray] = None,
-        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+        attention_mask: Optional[jax.Array] = None,
+        decoder_attention_mask: Optional[jax.Array] = None,
         encoder_outputs=None,
         **kwargs,
     ):
src/transformers/models/marian/modeling_flax_marian.py
@@ -1436,8 +1436,8 @@ class FlaxMarianMTModel(FlaxMarianPreTrainedModel):
         self,
         decoder_input_ids,
         max_length,
-        attention_mask: Optional[jnp.DeviceArray] = None,
-        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+        attention_mask: Optional[jax.Array] = None,
+        decoder_attention_mask: Optional[jax.Array] = None,
         encoder_outputs=None,
         **kwargs,
     ):
src/transformers/models/mbart/modeling_flax_mbart.py
@@ -1502,8 +1502,8 @@ class FlaxMBartForConditionalGeneration(FlaxMBartPreTrainedModel):
         self,
         decoder_input_ids,
         max_length,
-        attention_mask: Optional[jnp.DeviceArray] = None,
-        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+        attention_mask: Optional[jax.Array] = None,
+        decoder_attention_mask: Optional[jax.Array] = None,
         encoder_outputs=None,
         **kwargs,
     ):
src/transformers/models/opt/modeling_flax_opt.py
@@ -763,7 +763,7 @@ class FlaxOPTForCausalLMModule(nn.Module):
 class FlaxOPTForCausalLM(FlaxOPTPreTrainedModel):
     module_class = FlaxOPTForCausalLMModule
 
-    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
         # initializing the cache
         batch_size, seq_length = input_ids.shape
src/transformers/models/pegasus/modeling_flax_pegasus.py
@@ -1450,8 +1450,8 @@ class FlaxPegasusForConditionalGeneration(FlaxPegasusPreTrainedModel):
         self,
         decoder_input_ids,
         max_length,
-        attention_mask: Optional[jnp.DeviceArray] = None,
-        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+        attention_mask: Optional[jax.Array] = None,
+        decoder_attention_mask: Optional[jax.Array] = None,
         encoder_outputs=None,
         **kwargs,
     ):
src/transformers/models/roberta/modeling_flax_roberta.py
@@ -1452,7 +1452,7 @@ class FlaxRobertaForCausalLMModule(nn.Module):
 class FlaxRobertaForCausalLM(FlaxRobertaPreTrainedModel):
     module_class = FlaxRobertaForCausalLMModule
 
-    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
         # initializing the cache
         batch_size, seq_length = input_ids.shape
src/transformers/models/roberta_prelayernorm/modeling_flax_roberta_prelayernorm.py
@@ -1478,7 +1478,7 @@ class FlaxRobertaPreLayerNormForCausalLMModule(nn.Module):
 class FlaxRobertaPreLayerNormForCausalLM(FlaxRobertaPreLayerNormPreTrainedModel):
     module_class = FlaxRobertaPreLayerNormForCausalLMModule
 
-    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
         # initializing the cache
         batch_size, seq_length = input_ids.shape
src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py
@@ -745,8 +745,8 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
         self,
         decoder_input_ids,
         max_length,
-        attention_mask: Optional[jnp.DeviceArray] = None,
-        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+        attention_mask: Optional[jax.Array] = None,
+        decoder_attention_mask: Optional[jax.Array] = None,
         encoder_outputs=None,
         **kwargs,
     ):
src/transformers/models/t5/modeling_flax_t5.py
@@ -1740,8 +1740,8 @@ class FlaxT5ForConditionalGeneration(FlaxT5PreTrainedModel):
         self,
         decoder_input_ids,
         max_length,
-        attention_mask: Optional[jnp.DeviceArray] = None,
-        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+        attention_mask: Optional[jax.Array] = None,
+        decoder_attention_mask: Optional[jax.Array] = None,
         encoder_outputs=None,
         **kwargs,
     ):
src/transformers/models/vision_encoder_decoder/modeling_flax_vision_encoder_decoder.py
@@ -688,7 +688,7 @@ class FlaxVisionEncoderDecoderModel(FlaxPreTrainedModel):
         self,
         decoder_input_ids,
         max_length,
-        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+        decoder_attention_mask: Optional[jax.Array] = None,
         encoder_outputs=None,
         **kwargs,
     ):