"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "b95dfcf1103049212855261144c906dd140a4db5"
Unverified commit 57b6a80d, authored by Patrick von Platen, committed by GitHub

[Flax] Fix BERT initialization & token_type_ids default (#11695)



* fix some stuff

* fix roberta & electra as well

* del run bug
Co-authored-by: Patrick von Platen <patrick@huggingface.co>
parent daf0d6a9
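
This commit makes two fixes across the Flax BERT, ELECTRA, and RoBERTa models. First, the default `token_type_ids` changes from all ones to all zeros: BERT-style checkpoints use token type 0 for a single segment, so the old default silently selected the segment-B embedding for every position. Second, `self.module.init` is now called with `return_dict=False`. A minimal sketch of the user-visible effect (assuming a checkpoint such as `bert-base-uncased` with Flax weights is available; the token ids are toy values):

```python
import jax.numpy as jnp
from transformers import FlaxBertModel

model = FlaxBertModel.from_pretrained("bert-base-uncased")
input_ids = jnp.array([[101, 7592, 2088, 102]])  # toy [CLS] ... [SEP] ids

# After this fix, omitting token_type_ids is the same as passing all zeros
# (segment A); before, the Flax models silently filled in all ones (segment B).
out_default = model(input_ids).last_hidden_state
out_segment_a = model(input_ids, token_type_ids=jnp.zeros_like(input_ids)).last_hidden_state
assert jnp.allclose(out_default, out_segment_a)
```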
@@ -558,7 +558,9 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
         params_rng, dropout_rng = jax.random.split(rng)
         rngs = {"params": params_rng, "dropout": dropout_rng}
 
-        return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids)["params"]
+        return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids, return_dict=False)[
+            "params"
+        ]
 
     @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     def __call__(
@@ -587,7 +589,7 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
         # init input tensors if not passed
         if token_type_ids is None:
-            token_type_ids = jnp.ones_like(input_ids)
+            token_type_ids = jnp.zeros_like(input_ids)
         if position_ids is None:
             position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
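
Note that `return_dict=False` is forwarded to the module's `__call__` during shape inference, so the traced forward pass returns a plain tuple instead of a `ModelOutput`; the trailing `["params"]` indexes the variable dict returned by Flax's `Module.init`, not the model output. A hedged, self-contained restatement of the pattern (`module` stands for the model's `flax.linen` module):

```python
import jax
import jax.numpy as jnp

def init_weights(module, rng, input_shape):
    # Dummy inputs, used only to trace shapes during initialization.
    input_ids = jnp.zeros(input_shape, dtype="i4")
    token_type_ids = jnp.zeros_like(input_ids)  # segment 0: the post-fix default
    position_ids = jnp.broadcast_to(jnp.arange(input_shape[-1]), input_shape)
    attention_mask = jnp.ones_like(input_ids)

    params_rng, dropout_rng = jax.random.split(rng)
    rngs = {"params": params_rng, "dropout": dropout_rng}

    # init() returns a variable dict like {"params": ...}; return_dict=False
    # only affects the module's forward output (tuple vs. ModelOutput).
    return module.init(rngs, input_ids, attention_mask, token_type_ids,
                       position_ids, return_dict=False)["params"]
```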
@@ -502,14 +502,16 @@ class FlaxElectraPreTrainedModel(FlaxPreTrainedModel):
     def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
         # init input tensors
         input_ids = jnp.zeros(input_shape, dtype="i4")
-        token_type_ids = jnp.ones_like(input_ids)
+        token_type_ids = jnp.zeros_like(input_ids)
         position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
         attention_mask = jnp.ones_like(input_ids)
 
         params_rng, dropout_rng = jax.random.split(rng)
         rngs = {"params": params_rng, "dropout": dropout_rng}
 
-        return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids)["params"]
+        return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids, return_dict=False)[
+            "params"
+        ]
 
     @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     def __call__(
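
For concreteness, here is what the dummy tensors from the hunk above evaluate to for a toy `input_shape` of `(1, 4)` (values shown in comments):

```python
import jax.numpy as jnp

input_shape = (1, 4)
input_ids = jnp.zeros(input_shape, dtype="i4")  # [[0 0 0 0]]
token_type_ids = jnp.zeros_like(input_ids)      # [[0 0 0 0]]  (was all ones)
position_ids = jnp.broadcast_to(
    jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape
)                                               # [[0 1 2 3]]
attention_mask = jnp.ones_like(input_ids)       # [[1 1 1 1]]
```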
@@ -546,7 +546,9 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
         params_rng, dropout_rng = jax.random.split(rng)
         rngs = {"params": params_rng, "dropout": dropout_rng}
 
-        return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids)["params"]
+        return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids, return_dict=False)[
+            "params"
+        ]
 
     @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     def __call__(
@@ -575,7 +577,7 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
         # init input tensors if not passed
         if token_type_ids is None:
-            token_type_ids = jnp.ones_like(input_ids)
+            token_type_ids = jnp.zeros_like(input_ids)
         if position_ids is None:
             position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id)
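
RoBERTa's position-id default differs from BERT's and ELECTRA's plain `arange`: `create_position_ids_from_input_ids` starts numbering at `pad_token_id + 1` and leaves padding positions at `pad_token_id`. A simplified sketch covering the 2-D case (the library version also reshapes higher-rank inputs; the token ids below are toy values):

```python
import jax.numpy as jnp

def create_position_ids_from_input_ids(input_ids, padding_idx):
    # Non-pad tokens get positions padding_idx + 1, padding_idx + 2, ...;
    # pad tokens keep position padding_idx.
    mask = (input_ids != padding_idx).astype("i4")
    incremental_indices = jnp.cumsum(mask, axis=1) * mask
    return incremental_indices + padding_idx

ids = jnp.array([[0, 31414, 232, 2, 1, 1]])  # <s> ... </s> <pad> <pad>, pad id 1
print(create_position_ids_from_input_ids(ids, padding_idx=1))
# [[2 3 4 5 1 1]]
```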