Unverified Commit 57b6a80d authored by Patrick von Platen, committed by GitHub

[Flax] Fix BERT initialization & token_type_ids default (#11695)

* fix some stuff

* fix roberta & electra as well

* del run bug

Co-authored-by: Patrick von Platen <patrick@huggingface.co>
parent daf0d6a9
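
What the change does, as a minimal self-contained sketch: `token_type_ids` now default to zeros (segment id 0, the single-sentence segment in BERT, matching the PyTorch models), and the dummy forward pass used for weight initialization requests tuple outputs via `return_dict=False`, so `module.init` only has to trace plain JAX pytrees. The `ToyBert` module below is hypothetical, standing in for the real `FlaxBertModule`; shapes and names are illustrative, not the transformers internals.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class ToyBert(nn.Module):  # hypothetical stand-in for FlaxBertModule
    vocab_size: int = 100
    hidden_size: int = 8

    @nn.compact
    def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, return_dict=True):
        hidden = nn.Embed(num_embeddings=self.vocab_size, features=self.hidden_size)(input_ids)
        hidden += nn.Embed(num_embeddings=2, features=self.hidden_size)(token_type_ids)
        hidden = nn.Dense(features=self.hidden_size)(hidden)
        # A tuple output (return_dict=False) is a plain pytree, which is
        # all the shape inference in `init` needs to trace.
        return {"last_hidden_state": hidden} if return_dict else (hidden,)


# The initialization pattern this commit settles on:
input_shape = (1, 4)
input_ids = jnp.zeros(input_shape, dtype="i4")
token_type_ids = jnp.zeros_like(input_ids)  # was jnp.ones_like before this commit
position_ids = jnp.broadcast_to(jnp.arange(input_shape[-1]), input_shape)
attention_mask = jnp.ones_like(input_ids)

params_rng, dropout_rng = jax.random.split(jax.random.PRNGKey(0))
rngs = {"params": params_rng, "dropout": dropout_rng}
params = ToyBert().init(
    rngs, input_ids, attention_mask, token_type_ids, position_ids, return_dict=False
)["params"]
```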
@@ -558,7 +558,9 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
         params_rng, dropout_rng = jax.random.split(rng)
         rngs = {"params": params_rng, "dropout": dropout_rng}
 
-        return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids)["params"]
+        return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids, return_dict=False)[
+            "params"
+        ]
 
     @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     def __call__(
@@ -587,7 +589,7 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
         # init input tensors if not passed
         if token_type_ids is None:
-            token_type_ids = jnp.ones_like(input_ids)
+            token_type_ids = jnp.zeros_like(input_ids)
 
         if position_ids is None:
             position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
@@ -502,14 +502,16 @@ class FlaxElectraPreTrainedModel(FlaxPreTrainedModel):
     def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
         # init input tensors
         input_ids = jnp.zeros(input_shape, dtype="i4")
-        token_type_ids = jnp.ones_like(input_ids)
+        token_type_ids = jnp.zeros_like(input_ids)
         position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
         attention_mask = jnp.ones_like(input_ids)
 
         params_rng, dropout_rng = jax.random.split(rng)
         rngs = {"params": params_rng, "dropout": dropout_rng}
 
-        return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids)["params"]
+        return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids, return_dict=False)[
+            "params"
+        ]
 
     @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     def __call__(
@@ -546,7 +546,9 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
         params_rng, dropout_rng = jax.random.split(rng)
         rngs = {"params": params_rng, "dropout": dropout_rng}
 
-        return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids)["params"]
+        return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids, return_dict=False)[
+            "params"
+        ]
 
     @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     def __call__(
@@ -575,7 +577,7 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
         # init input tensors if not passed
         if token_type_ids is None:
-            token_type_ids = jnp.ones_like(input_ids)
+            token_type_ids = jnp.zeros_like(input_ids)
 
         if position_ids is None:
             position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id)
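
And a hedged sketch of the user-facing effect (the checkpoint name is illustrative and assumes Flax weights are published for it): omitting `token_type_ids` from a call now fills them with zeros rather than ones, in line with the PyTorch implementations.

```python
from transformers import BertTokenizerFast, FlaxBertModel

# Illustrative checkpoint; assumes Flax weights exist on the Hub.
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
model = FlaxBertModel.from_pretrained("bert-base-uncased")

inputs = tokenizer("Hello world", return_tensors="np")
# token_type_ids omitted on purpose: they now default to zeros.
outputs = model(inputs["input_ids"], attention_mask=inputs["attention_mask"])
print(outputs.last_hidden_state.shape)  # (1, sequence_length, hidden_size)
```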