Unverified Commit 623281aa authored by Suraj Patil, committed by GitHub

[Flax BERT/Roberta] few small fixes (#11558)

* small fixes

* style
parent a5d2967b
@@ -25,7 +25,6 @@ import jaxlib.xla_extension as jax_xla
 from flax.core.frozen_dict import FrozenDict
 from flax.linen import dot_product_attention
 from jax import lax
-from jax.random import PRNGKey

 from ...file_utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward
 from ...modeling_flax_outputs import (
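Note: the dropped import is redundant rather than load-bearing. The module already goes through the top-level jax namespace (which the updated type annotation further down also uses), so a dropout key is still built the usual way. A minimal illustrative sketch, not part of the diff:

import jax

rng = jax.random.PRNGKey(0)               # same constructor, fully qualified
rng, dropout_rng = jax.random.split(rng)  # separate key reserved for dropout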
@@ -92,9 +91,9 @@ BERT_START_DOCSTRING = r"""
     generic methods the library implements for all its model (such as downloading, saving and converting weights from
     PyTorch models)

-    This model is also a Flax Linen `flax.nn.Module
-    <https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html>`__ subclass. Use it as a regular Flax
-    Module and refer to the Flax documentation for all matter related to general usage and behavior.
+    This model is also a Flax Linen `flax.linen.Module
+    <https://flax.readthedocs.io/en/latest/flax.linen.html#module>`__ subclass. Use it as a regular Flax linen Module
+    and refer to the Flax documentation for all matter related to general usage and behavior.

     Finally, this model supports inherent JAX features such as:
@@ -106,8 +105,8 @@ BERT_START_DOCSTRING = r"""
     Parameters:
         config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model.
             Initializing with a config file does not load the weights associated with the model, only the
-            configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
-            weights.
+            configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
+            model weights.
 """

 BERT_INPUTS_DOCSTRING = r"""
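Note: the corrected cross-reference points at FlaxPreTrainedModel, the base class that actually implements weight loading for the Flax models. A hedged sketch of the distinction the docstring draws; the checkpoint name is only an example, and from_pt=True may be needed if a checkpoint ships only PyTorch weights:

from transformers import BertConfig, FlaxBertModel

config_only = FlaxBertModel(BertConfig())                        # config only, weights randomly initialized
pretrained = FlaxBertModel.from_pretrained("bert-base-uncased")  # loads the pretrained weights as well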
@@ -173,7 +172,6 @@ class FlaxBertEmbeddings(nn.Module):
         self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

     def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
-        batch_size, sequence_length = input_ids.shape
         # Embed
         inputs_embeds = self.word_embeddings(input_ids.astype("i4"))
         position_embeds = self.position_embeddings(position_ids.astype("i4"))
@@ -181,7 +179,6 @@ class FlaxBertEmbeddings(nn.Module):

         # Sum all embeddings
         hidden_states = inputs_embeds + token_type_embeddings + position_embeds
-        # hidden_states = hidden_states.reshape((batch_size, sequence_length, -1))

         # Layer Norm
         hidden_states = self.LayerNorm(hidden_states)
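Note: the deleted unpacking and the commented-out reshape were dead code. flax.linen.Embed maps a (batch, seq) integer array directly to (batch, seq, features), so the summed embeddings already have the layout LayerNorm expects. A toy shape check with made-up sizes, assuming only standard Flax behaviour:

import jax
import jax.numpy as jnp
import flax.linen as nn

embed = nn.Embed(num_embeddings=30522, features=768)   # illustrative vocab/hidden sizes
tokens = jnp.ones((2, 16), dtype="i4")                  # (batch, seq)
params = embed.init(jax.random.PRNGKey(0), tokens)
print(embed.apply(params, tokens).shape)                # (2, 16, 768) -- no reshape needed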
@@ -571,7 +568,7 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
         token_type_ids=None,
         position_ids=None,
         params: dict = None,
-        dropout_rng: PRNGKey = None,
+        dropout_rng: jax.random.PRNGKey = None,
         train: bool = False,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
...
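Note: the annotation change spells out that dropout_rng expects a JAX PRNG key. A usage sketch with a deliberately tiny, randomly initialized configuration (all sizes are made up for illustration); dropout only runs when train=True, and passing the key explicitly keeps the stochastic pass reproducible:

import jax
import jax.numpy as jnp
from transformers import BertConfig, FlaxBertModel

model = FlaxBertModel(BertConfig(hidden_size=64, num_hidden_layers=2,
                                 num_attention_heads=2, intermediate_size=128))
input_ids = jnp.ones((1, 8), dtype="i4")
outputs = model(input_ids, train=True, dropout_rng=jax.random.PRNGKey(0))
print(outputs.last_hidden_state.shape)  # (1, 8, 64)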
@@ -59,9 +59,9 @@ ROBERTA_START_DOCSTRING = r"""
     generic methods the library implements for all its model (such as downloading, saving and converting weights from
     PyTorch models)

-    This model is also a Flax Linen `flax.nn.Module
-    <https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html>`__ subclass. Use it as a regular Flax
-    Module and refer to the Flax documentation for all matter related to general usage and behavior.
+    This model is also a Flax Linen `flax.linen.Module
+    <https://flax.readthedocs.io/en/latest/flax.linen.html#module>`__ subclass. Use it as a regular Flax linen Module
+    and refer to the Flax documentation for all matter related to general usage and behavior.

     Finally, this model supports inherent JAX features such as:
@@ -73,8 +73,8 @@ ROBERTA_START_DOCSTRING = r"""
     Parameters:
         config (:class:`~transformers.RobertaConfig`): Model configuration class with all the parameters of the
             model. Initializing with a config file does not load the weights associated with the model, only the
-            configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
-            weights.
+            configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
+            model weights.
 """

 ROBERTA_INPUTS_DOCSTRING = r"""
@@ -140,7 +140,6 @@ class FlaxRobertaEmbeddings(nn.Module):
         self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

     def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
-        batch_size, sequence_length = input_ids.shape
         # Embed
         inputs_embeds = self.word_embeddings(input_ids.astype("i4"))
         position_embeds = self.position_embeddings(position_ids.astype("i4"))
@@ -148,7 +147,6 @@ class FlaxRobertaEmbeddings(nn.Module):

         # Sum all embeddings
         hidden_states = inputs_embeds + token_type_embeddings + position_embeds
-        # hidden_states = hidden_states.reshape((batch_size, sequence_length, -1))

         # Layer Norm
         hidden_states = self.LayerNorm(hidden_states)
...