"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "0afe4a90f9d9a9ea811b3343427cf3dd3bc26ad8"
Unverified Commit 623281aa authored by Suraj Patil, committed by GitHub

[Flax BERT/Roberta] few small fixes (#11558)

* small fixes

* style
parent a5d2967b
@@ -25,7 +25,6 @@ import jaxlib.xla_extension as jax_xla
 from flax.core.frozen_dict import FrozenDict
 from flax.linen import dot_product_attention
 from jax import lax
-from jax.random import PRNGKey
 from ...file_utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward
 from ...modeling_flax_outputs import (
@@ -92,9 +91,9 @@ BERT_START_DOCSTRING = r"""
     generic methods the library implements for all its model (such as downloading, saving and converting weights from
     PyTorch models)
-    This model is also a Flax Linen `flax.nn.Module
-    <https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html>`__ subclass. Use it as a regular Flax
-    Module and refer to the Flax documentation for all matter related to general usage and behavior.
+    This model is also a Flax Linen `flax.linen.Module
+    <https://flax.readthedocs.io/en/latest/flax.linen.html#module>`__ subclass. Use it as a regular Flax linen Module
+    and refer to the Flax documentation for all matter related to general usage and behavior.
     Finally, this model supports inherent JAX features such as:
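For readers unfamiliar with the class the corrected link now points to, here is a minimal sketch of a `flax.linen.Module` subclass. It is purely illustrative (the `ToyDense` layer is not part of this commit or of transformers), just the standard Flax pattern of dataclass-style hyperparameters plus `init`/`apply`:

```python
# Hypothetical flax.linen.Module subclass, for illustration only.
import jax
import jax.numpy as jnp
import flax.linen as nn


class ToyDense(nn.Module):
    features: int  # hyperparameters are declared as dataclass fields

    @nn.compact
    def __call__(self, x):
        # with @nn.compact, submodules are defined inline at call time
        return nn.Dense(self.features)(x)


layer = ToyDense(features=8)
params = layer.init(jax.random.PRNGKey(0), jnp.ones((1, 4)))  # build parameters
out = layer.apply(params, jnp.ones((1, 4)))                   # run the forward pass
```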
@@ -106,8 +105,8 @@ BERT_START_DOCSTRING = r"""
     Parameters:
         config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model.
             Initializing with a config file does not load the weights associated with the model, only the
-            configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
-            weights.
+            configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
+            model weights.
 """
 BERT_INPUTS_DOCSTRING = r"""
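The cross-reference now points at `FlaxPreTrainedModel.from_pretrained`, which is the method that actually loads weights for the Flax classes. A hedged usage sketch (the checkpoint name is only an example; a checkpoint without Flax weights may additionally need `from_pt=True`):

```python
from transformers import BertTokenizer, FlaxBertModel

# Download config + weights; "bert-base-uncased" is only an example checkpoint.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = FlaxBertModel.from_pretrained("bert-base-uncased")

# Flax models accept NumPy arrays from the tokenizer.
inputs = tokenizer("Hello, Flax BERT!", return_tensors="np")
outputs = model(**inputs)
last_hidden_state = outputs.last_hidden_state
```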
@@ -173,7 +172,6 @@ class FlaxBertEmbeddings(nn.Module):
         self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
     def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
-        batch_size, sequence_length = input_ids.shape
         # Embed
         inputs_embeds = self.word_embeddings(input_ids.astype("i4"))
         position_embeds = self.position_embeddings(position_ids.astype("i4"))
@@ -181,7 +179,6 @@ class FlaxBertEmbeddings(nn.Module):
         # Sum all embeddings
         hidden_states = inputs_embeds + token_type_embeddings + position_embeds
-        # hidden_states = hidden_states.reshape((batch_size, sequence_length, -1))
         # Layer Norm
         hidden_states = self.LayerNorm(hidden_states)
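The two removals above are dead code cleanup: once the commented-out reshape is gone, `batch_size` and `sequence_length` are never used, and the summed embeddings already have shape `(batch, seq, hidden)`. A simplified, hypothetical sketch of the embed → sum → LayerNorm → dropout pattern the surrounding code follows (not the file's actual class):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class ToyEmbeddings(nn.Module):
    """Hypothetical, simplified embeddings module for illustration only."""
    vocab_size: int
    max_positions: int
    type_vocab_size: int
    hidden_size: int
    dropout_rate: float = 0.1

    def setup(self):
        self.word_embeddings = nn.Embed(self.vocab_size, self.hidden_size)
        self.position_embeddings = nn.Embed(self.max_positions, self.hidden_size)
        self.token_type_embeddings = nn.Embed(self.type_vocab_size, self.hidden_size)
        self.LayerNorm = nn.LayerNorm()
        self.dropout = nn.Dropout(rate=self.dropout_rate)

    def __call__(self, input_ids, token_type_ids, position_ids, deterministic: bool = True):
        # Embed each input stream and sum; the result is already
        # (batch, seq, hidden), so no reshape is needed.
        hidden_states = (
            self.word_embeddings(input_ids.astype("i4"))
            + self.token_type_embeddings(token_type_ids.astype("i4"))
            + self.position_embeddings(position_ids.astype("i4"))
        )
        hidden_states = self.LayerNorm(hidden_states)
        return self.dropout(hidden_states, deterministic=deterministic)


# Illustrative initialization (deterministic path, so no dropout rng is needed).
emb = ToyEmbeddings(vocab_size=100, max_positions=64, type_vocab_size=2, hidden_size=16)
ids = jnp.zeros((2, 8), dtype="i4")
positions = jnp.broadcast_to(jnp.arange(8), (2, 8))
variables = emb.init(jax.random.PRNGKey(0), ids, ids, positions)
```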
@@ -571,7 +568,7 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
         token_type_ids=None,
         position_ids=None,
         params: dict = None,
-        dropout_rng: PRNGKey = None,
+        dropout_rng: jax.random.PRNGKey = None,
         train: bool = False,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
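With the `PRNGKey` import removed, the annotation is spelled out as `jax.random.PRNGKey`. A hedged sketch of how that argument is used at call time, reusing the `model` and `inputs` from the loading example above:

```python
import jax

# When train=True, dropout is active and a PRNG key must be supplied;
# the key value and the call itself are illustrative.
dropout_rng = jax.random.PRNGKey(0)
outputs = model(**inputs, train=True, dropout_rng=dropout_rng)
```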
@@ -59,9 +59,9 @@ ROBERTA_START_DOCSTRING = r"""
     generic methods the library implements for all its model (such as downloading, saving and converting weights from
     PyTorch models)
-    This model is also a Flax Linen `flax.nn.Module
-    <https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html>`__ subclass. Use it as a regular Flax
-    Module and refer to the Flax documentation for all matter related to general usage and behavior.
+    This model is also a Flax Linen `flax.linen.Module
+    <https://flax.readthedocs.io/en/latest/flax.linen.html#module>`__ subclass. Use it as a regular Flax linen Module
+    and refer to the Flax documentation for all matter related to general usage and behavior.
     Finally, this model supports inherent JAX features such as:
@@ -73,8 +73,8 @@ ROBERTA_START_DOCSTRING = r"""
     Parameters:
         config (:class:`~transformers.RobertaConfig`): Model configuration class with all the parameters of the
             model. Initializing with a config file does not load the weights associated with the model, only the
-            configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
-            weights.
+            configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
+            model weights.
 """
 ROBERTA_INPUTS_DOCSTRING = r"""
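The RoBERTa docstrings get the same fixes, and loading works analogously (checkpoint name again only an example):

```python
from transformers import RobertaTokenizer, FlaxRobertaModel

tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
model = FlaxRobertaModel.from_pretrained("roberta-base")

outputs = model(**tokenizer("Hello, Flax RoBERTa!", return_tensors="np"))
```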
@@ -140,7 +140,6 @@ class FlaxRobertaEmbeddings(nn.Module):
         self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
     def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
-        batch_size, sequence_length = input_ids.shape
         # Embed
         inputs_embeds = self.word_embeddings(input_ids.astype("i4"))
         position_embeds = self.position_embeddings(position_ids.astype("i4"))
@@ -148,7 +147,6 @@ class FlaxRobertaEmbeddings(nn.Module):
         # Sum all embeddings
         hidden_states = inputs_embeds + token_type_embeddings + position_embeds
-        # hidden_states = hidden_states.reshape((batch_size, sequence_length, -1))
         # Layer Norm
         hidden_states = self.LayerNorm(hidden_states)