chenpangpang / transformers · Commit 623281aa (unverified)
Authored May 03, 2021 by Suraj Patil; committed via GitHub on May 03, 2021

[Flax BERT/Roberta] few small fixes (#11558)

* small fixes
* style
Parent: a5d2967b

Showing 2 changed files with 11 additions and 16 deletions (+11, -16)
src/transformers/models/bert/modeling_flax_bert.py (+6, -9)
src/transformers/models/roberta/modeling_flax_roberta.py (+5, -7)
src/transformers/models/bert/modeling_flax_bert.py
@@ -25,7 +25,6 @@ import jaxlib.xla_extension as jax_xla
 from flax.core.frozen_dict import FrozenDict
 from flax.linen import dot_product_attention
 from jax import lax
-from jax.random import PRNGKey
 from ...file_utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward
 from ...modeling_flax_outputs import (
@@ -92,9 +91,9 @@ BERT_START_DOCSTRING = r"""
     generic methods the library implements for all its model (such as downloading, saving and converting weights from
     PyTorch models)

-    This model is also a Flax Linen `flax.nn.Module
-    <https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html>`__ subclass. Use it as a regular Flax
-    Module and refer to the Flax documentation for all matter related to general usage and behavior.
+    This model is also a Flax Linen `flax.linen.Module
+    <https://flax.readthedocs.io/en/latest/flax.linen.html#module>`__ subclass. Use it as a regular Flax linen Module
+    and refer to the Flax documentation for all matter related to general usage and behavior.

     Finally, this model supports inherent JAX features such as:
@@ -106,8 +105,8 @@ BERT_START_DOCSTRING = r"""
     Parameters:
         config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model.
             Initializing with a config file does not load the weights associated with the model, only the
-            configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
-            weights.
+            configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
+            model weights.
 """

 BERT_INPUTS_DOCSTRING = r"""
@@ -173,7 +172,6 @@ class FlaxBertEmbeddings(nn.Module):
         self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

     def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
-        batch_size, sequence_length = input_ids.shape
         # Embed
         inputs_embeds = self.word_embeddings(input_ids.astype("i4"))
         position_embeds = self.position_embeddings(position_ids.astype("i4"))
@@ -181,7 +179,6 @@ class FlaxBertEmbeddings(nn.Module):
         # Sum all embeddings
         hidden_states = inputs_embeds + token_type_embeddings + position_embeds

-        # hidden_states = hidden_states.reshape((batch_size, sequence_length, -1))
         # Layer Norm
         hidden_states = self.LayerNorm(hidden_states)
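
The two deletions above remove dead code: batch_size and sequence_length were only consumed by the already-commented-out reshape, and no reshape is needed because the embedding lookups already return (batch, seq_len, hidden_size). A self-contained sketch of the surviving embed -> sum -> LayerNorm -> dropout pattern (toy sizes and the ToyEmbeddings name are assumptions for illustration, not the library code):

import flax.linen as nn
import jax
import jax.numpy as jnp

class ToyEmbeddings(nn.Module):
    vocab_size: int = 100
    max_positions: int = 32
    hidden_size: int = 16

    @nn.compact
    def __call__(self, input_ids, position_ids, deterministic: bool = True):
        # Each nn.Embed lookup already yields (batch, seq_len, hidden_size).
        inputs_embeds = nn.Embed(self.vocab_size, self.hidden_size)(input_ids.astype("i4"))
        position_embeds = nn.Embed(self.max_positions, self.hidden_size)(position_ids.astype("i4"))
        # Sum all embeddings, then layer-normalize -- no reshape required.
        hidden_states = inputs_embeds + position_embeds
        hidden_states = nn.LayerNorm()(hidden_states)
        return nn.Dropout(rate=0.1)(hidden_states, deterministic=deterministic)

ids = jnp.zeros((2, 8), dtype=jnp.int32)
pos = jnp.broadcast_to(jnp.arange(8), (2, 8))
params = ToyEmbeddings().init(jax.random.PRNGKey(0), ids, pos)
out = ToyEmbeddings().apply(params, ids, pos)  # shape (2, 8, 16)
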
@@ -571,7 +568,7 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
         token_type_ids=None,
         position_ids=None,
         params: dict = None,
-        dropout_rng: PRNGKey = None,
+        dropout_rng: jax.random.PRNGKey = None,
         train: bool = False,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
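
Qualifying the annotation as jax.random.PRNGKey is what makes the `from jax.random import PRNGKey` import in the first hunk unnecessary. A hedged sketch of when dropout_rng actually matters (checkpoint name illustrative):

import jax
from transformers import BertTokenizerFast, FlaxBertModel

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
model = FlaxBertModel.from_pretrained("bert-base-uncased")
inputs = tokenizer("Hello, Flax!", return_tensors="np")

# Deterministic inference: dropout is disabled, no RNG required.
outputs = model(**inputs)

# Training-mode forward pass: dropout is active, so an explicit
# jax.random.PRNGKey must be supplied.
outputs = model(**inputs, train=True, dropout_rng=jax.random.PRNGKey(0))
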
src/transformers/models/roberta/modeling_flax_roberta.py
@@ -59,9 +59,9 @@ ROBERTA_START_DOCSTRING = r"""
     generic methods the library implements for all its model (such as downloading, saving and converting weights from
     PyTorch models)

-    This model is also a Flax Linen `flax.nn.Module
-    <https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html>`__ subclass. Use it as a regular Flax
-    Module and refer to the Flax documentation for all matter related to general usage and behavior.
+    This model is also a Flax Linen `flax.linen.Module
+    <https://flax.readthedocs.io/en/latest/flax.linen.html#module>`__ subclass. Use it as a regular Flax linen Module
+    and refer to the Flax documentation for all matter related to general usage and behavior.

     Finally, this model supports inherent JAX features such as:
@@ -73,8 +73,8 @@ ROBERTA_START_DOCSTRING = r"""
     Parameters:
         config (:class:`~transformers.RobertaConfig`): Model configuration class with all the parameters of the
             model. Initializing with a config file does not load the weights associated with the model, only the
-            configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
-            weights.
+            configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
+            model weights.
 """

 ROBERTA_INPUTS_DOCSTRING = r"""
@@ -140,7 +140,6 @@ class FlaxRobertaEmbeddings(nn.Module):
         self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

     def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
-        batch_size, sequence_length = input_ids.shape
         # Embed
         inputs_embeds = self.word_embeddings(input_ids.astype("i4"))
         position_embeds = self.position_embeddings(position_ids.astype("i4"))
@@ -148,7 +147,6 @@ class FlaxRobertaEmbeddings(nn.Module):
         # Sum all embeddings
         hidden_states = inputs_embeds + token_type_embeddings + position_embeds

-        # hidden_states = hidden_states.reshape((batch_size, sequence_length, -1))
         # Layer Norm
         hidden_states = self.LayerNorm(hidden_states)
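
The Roberta edits mirror the BERT ones exactly: the same dead-code removal in the embeddings module and the same docstring corrections. As a complement rather than a repeat, here is a hedged sketch of the "inherent JAX features" the corrected docstrings advertise, using just-in-time compilation (checkpoint name illustrative):

import jax
from transformers import FlaxRobertaModel, RobertaTokenizerFast

tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
model = FlaxRobertaModel.from_pretrained("roberta-base")
inputs = tokenizer("Flax params are plain pytrees.", return_tensors="np")

# jit-compile the deterministic forward pass; repeat calls with the same
# input shapes reuse the compiled computation.
@jax.jit
def forward(input_ids, attention_mask):
    return model(input_ids, attention_mask=attention_mask).last_hidden_state

hidden = forward(inputs["input_ids"], inputs["attention_mask"])
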