chenpangpang / transformers

Commit 623281aa (unverified)
[Flax BERT/Roberta] few small fixes (#11558)

* small fixes
* style

Authored May 03, 2021 by Suraj Patil; committed via GitHub on May 03, 2021
Parent: a5d2967b
Changes: 2 changed files with 11 additions and 16 deletions (+11 -16)

src/transformers/models/bert/modeling_flax_bert.py       +6 -9
src/transformers/models/roberta/modeling_flax_roberta.py +5 -7
src/transformers/models/bert/modeling_flax_bert.py
@@ -25,7 +25,6 @@ import jaxlib.xla_extension as jax_xla
from flax.core.frozen_dict import FrozenDict
from flax.linen import dot_product_attention
from jax import lax
-from jax.random import PRNGKey

from ...file_utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_flax_outputs import (
@@ -92,9 +91,9 @@ BERT_START_DOCSTRING = r"""
    generic methods the library implements for all its model (such as downloading, saving and converting weights from
    PyTorch models)

-    This model is also a Flax Linen `flax.nn.Module
-    <https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html>`__ subclass. Use it as a regular Flax
-    Module and refer to the Flax documentation for all matter related to general usage and behavior.
+    This model is also a Flax Linen `flax.linen.Module
+    <https://flax.readthedocs.io/en/latest/flax.linen.html#module>`__ subclass. Use it as a regular Flax linen Module
+    and refer to the Flax documentation for all matter related to general usage and behavior.

    Finally, this model supports inherent JAX features such as:
@@ -106,8 +105,8 @@ BERT_START_DOCSTRING = r"""
    Parameters:
        config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
-            configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
-            weights.
+            configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
+            model weights.
"""

BERT_INPUTS_DOCSTRING = r"""
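The corrected docstring above points users at from_pretrained for loading weights. Not part of the diff, but as a minimal usage sketch (assuming the transformers and flax packages are installed and the bert-base-uncased checkpoint is reachable):

from transformers import BertTokenizerFast, FlaxBertModel

# Load the configuration together with pretrained weights, as the
# FlaxPreTrainedModel.from_pretrained docstring above describes.
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
model = FlaxBertModel.from_pretrained("bert-base-uncased")

# Tokenize into NumPy arrays, which Flax models accept directly.
inputs = tokenizer("Flax BERT in JAX.", return_tensors="np")
outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # (1, sequence_length, 768)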
@@ -173,7 +172,6 @@ class FlaxBertEmbeddings(nn.Module):
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

    def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
-        batch_size, sequence_length = input_ids.shape
        # Embed
        inputs_embeds = self.word_embeddings(input_ids.astype("i4"))
        position_embeds = self.position_embeddings(position_ids.astype("i4"))
@@ -181,7 +179,6 @@ class FlaxBertEmbeddings(nn.Module):

        # Sum all embeddings
        hidden_states = inputs_embeds + token_type_embeddings + position_embeds
-        # hidden_states = hidden_states.reshape((batch_size, sequence_length, -1))

        # Layer Norm
        hidden_states = self.LayerNorm(hidden_states)
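The two embedding hunks above drop an unused shape unpacking and a stale commented-out reshape; the remaining logic sums word, position, and token-type embeddings, then applies LayerNorm and dropout. A self-contained Flax sketch of that pattern follows, with illustrative names and sizes; it is not the library's FlaxBertEmbeddings class:

import jax
import jax.numpy as jnp
import flax.linen as nn


class ToyEmbeddings(nn.Module):
    """Hypothetical module mirroring the sum -> LayerNorm -> dropout pattern above."""

    vocab_size: int = 30522
    hidden_size: int = 64
    max_position_embeddings: int = 512
    type_vocab_size: int = 2
    hidden_dropout_prob: float = 0.1

    def setup(self):
        self.word_embeddings = nn.Embed(num_embeddings=self.vocab_size, features=self.hidden_size)
        self.position_embeddings = nn.Embed(num_embeddings=self.max_position_embeddings, features=self.hidden_size)
        self.token_type_embeddings = nn.Embed(num_embeddings=self.type_vocab_size, features=self.hidden_size)
        self.LayerNorm = nn.LayerNorm()
        self.dropout = nn.Dropout(rate=self.hidden_dropout_prob)

    def __call__(self, input_ids, token_type_ids, position_ids, deterministic: bool = True):
        # Embed each id stream with int32 ("i4") indices, as in the diff above
        inputs_embeds = self.word_embeddings(input_ids.astype("i4"))
        position_embeds = self.position_embeddings(position_ids.astype("i4"))
        token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4"))

        # Sum all embeddings, then LayerNorm and dropout
        hidden_states = inputs_embeds + token_type_embeddings + position_embeds
        hidden_states = self.LayerNorm(hidden_states)
        return self.dropout(hidden_states, deterministic=deterministic)


input_ids = jnp.ones((2, 8), dtype=jnp.int32)
token_type_ids = jnp.zeros_like(input_ids)
position_ids = jnp.broadcast_to(jnp.arange(8), (2, 8))

module = ToyEmbeddings()
variables = module.init(jax.random.PRNGKey(0), input_ids, token_type_ids, position_ids)
hidden_states = module.apply(variables, input_ids, token_type_ids, position_ids)  # (2, 8, 64)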
@@ -571,7 +568,7 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
        token_type_ids=None,
        position_ids=None,
        params: dict = None,
-        dropout_rng: PRNGKey = None,
+        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
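The hunk above re-annotates dropout_rng as jax.random.PRNGKey in the model's __call__ signature. A hedged sketch of how that argument is typically used (checkpoint name and shapes are illustrative assumptions): with the default train=False dropout stays off; for a training-style forward pass an explicit PRNGKey keeps dropout both enabled and reproducible.

import jax
import numpy as np
from transformers import FlaxBertModel

model = FlaxBertModel.from_pretrained("bert-base-uncased")
input_ids = np.ones((1, 8), dtype=np.int32)  # dummy batch

# Deterministic forward pass: dropout disabled (train defaults to False).
eval_out = model(input_ids)

# Training-style forward pass: dropout enabled, driven by an explicit PRNGKey.
dropout_rng = jax.random.PRNGKey(0)
train_out = model(input_ids, train=True, dropout_rng=dropout_rng)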
src/transformers/models/roberta/modeling_flax_roberta.py
@@ -59,9 +59,9 @@ ROBERTA_START_DOCSTRING = r"""
    generic methods the library implements for all its model (such as downloading, saving and converting weights from
    PyTorch models)

-    This model is also a Flax Linen `flax.nn.Module
-    <https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html>`__ subclass. Use it as a regular Flax
-    Module and refer to the Flax documentation for all matter related to general usage and behavior.
+    This model is also a Flax Linen `flax.linen.Module
+    <https://flax.readthedocs.io/en/latest/flax.linen.html#module>`__ subclass. Use it as a regular Flax linen Module
+    and refer to the Flax documentation for all matter related to general usage and behavior.

    Finally, this model supports inherent JAX features such as:
@@ -73,8 +73,8 @@ ROBERTA_START_DOCSTRING = r"""
    Parameters:
        config (:class:`~transformers.RobertaConfig`): Model configuration class with all the parameters of the
            model. Initializing with a config file does not load the weights associated with the model, only the
-            configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
-            weights.
+            configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
+            model weights.
"""

ROBERTA_INPUTS_DOCSTRING = r"""
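As the corrected parameter docstring above notes, building the model from a config alone defines the architecture with randomly initialized weights, while from_pretrained loads pretrained ones. A brief sketch of the distinction for RoBERTa (model size and checkpoint name are illustrative assumptions):

from transformers import FlaxRobertaModel, RobertaConfig

# Config-only construction: architecture defined, weights randomly
# initialized, nothing downloaded.
config = RobertaConfig(num_hidden_layers=2, hidden_size=128, num_attention_heads=2, intermediate_size=256)
random_model = FlaxRobertaModel(config)

# from_pretrained: downloads and loads the weights of the named checkpoint.
pretrained_model = FlaxRobertaModel.from_pretrained("roberta-base")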
@@ -140,7 +140,6 @@ class FlaxRobertaEmbeddings(nn.Module):
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

    def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
-        batch_size, sequence_length = input_ids.shape
        # Embed
        inputs_embeds = self.word_embeddings(input_ids.astype("i4"))
        position_embeds = self.position_embeddings(position_ids.astype("i4"))
@@ -148,7 +147,6 @@ class FlaxRobertaEmbeddings(nn.Module):

        # Sum all embeddings
        hidden_states = inputs_embeds + token_type_embeddings + position_embeds
-        # hidden_states = hidden_states.reshape((batch_size, sequence_length, -1))

        # Layer Norm
        hidden_states = self.LayerNorm(hidden_states)