chenpangpang / transformers / Commits / 50595a33

Unverified commit 50595a33, authored Apr 21, 2021 by Patrick von Platen, committed by GitHub on Apr 21, 2021.
Remove boiler plate code (#11340)
* remove boiler plate code
* adapt roberta
* correct docs
* finish refactor
parent ac588594

Showing 5 changed files with 133 additions and 363 deletions (+133 / -363):

* src/transformers/file_utils.py (+11 / -1)
* src/transformers/modeling_flax_utils.py (+21 / -3)
* src/transformers/models/auto/auto_factory.py (+1 / -10)
* src/transformers/models/bert/modeling_flax_bert.py (+68 / -310)
* src/transformers/models/roberta/modeling_flax_roberta.py (+32 / -39)
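Taken together, the five diffs below move the per-head __init__/__call__ boilerplate of the Flax BERT and RoBERTa files onto the shared pretrained-model base classes: each head class now only declares a module_class attribute, parameter initialization is renamed from init to init_weights, and the new overwrite_call_docstring helper patches a copied __call__ when a head needs its own docstring. As a minimal sketch of that pattern before reading the diffs (toy names only; ToyModule, ToyPreTrainedModel and ToyModel are not part of transformers):

    import jax
    import jax.numpy as jnp
    import flax.linen as nn


    class ToyModule(nn.Module):
        # stand-in for FlaxBertModule: the actual network definition
        @nn.compact
        def __call__(self, inputs):
            return nn.Dense(4)(inputs)


    class ToyPreTrainedModel:
        # stand-in for FlaxBertPreTrainedModel: owns construction and weight init once
        module_class: nn.Module = None

        def __init__(self, seed: int = 0, input_shape=(1, 1)):
            self.module = self.module_class()
            self.params = self.init_weights(jax.random.PRNGKey(seed), input_shape)

        def init_weights(self, rng, input_shape):
            dummy_inputs = jnp.zeros(input_shape, dtype=jnp.float32)
            return self.module.init(rng, dummy_inputs)["params"]

        def __call__(self, inputs):
            return self.module.apply({"params": self.params}, inputs)


    class ToyModel(ToyPreTrainedModel):
        # the whole head definition collapses to one attribute, mirroring
        # "class FlaxBertModel(FlaxBertPreTrainedModel): module_class = FlaxBertModule"
        module_class = ToyModule


    model = ToyModel()
    print(model(jnp.ones((2, 1))).shape)  # (2, 4)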
src/transformers/file_utils.py
@@ -15,9 +15,9 @@
 Utilities for working with the local dataset cache. Parts of this file is adapted from the AllenNLP library at
 https://github.com/allenai/allennlp.
 """
 import copy
 import fnmatch
+import functools
 import importlib.util
 import io
 import json

@@ -27,6 +27,7 @@ import shutil
 import sys
 import tarfile
 import tempfile
+import types
 from collections import OrderedDict, UserDict
 from contextlib import contextmanager
 from dataclasses import fields

@@ -1674,3 +1675,12 @@ class _BaseLazyModule(ModuleType):

     def _get_module(self, module_name: str) -> ModuleType:
         raise NotImplementedError
+
+
+def copy_func(f):
+    """ Returns a copy of a function f."""
+    # Based on http://stackoverflow.com/a/6528148/190597 (Glenn Maynard)
+    g = types.FunctionType(f.__code__, f.__globals__, name=f.__name__, argdefs=f.__defaults__, closure=f.__closure__)
+    g = functools.update_wrapper(g, f)
+    g.__kwdefaults__ = f.__kwdefaults__
+    return g
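The copy_func helper is moved here verbatim from auto_factory.py (see that diff below) so it can be shared. As a quick illustrative check that is not part of the commit (greet and documented are made-up names), copying matters because editing __doc__ on the copy leaves the original function, and any other class reusing it, untouched:

    import functools
    import types


    def copy_func(f):
        # same recipe as above: duplicate the code object, globals, defaults and closure
        g = types.FunctionType(f.__code__, f.__globals__, name=f.__name__, argdefs=f.__defaults__, closure=f.__closure__)
        g = functools.update_wrapper(g, f)
        g.__kwdefaults__ = f.__kwdefaults__
        return g


    def greet(name="world"):
        """Original docstring."""
        return f"hello {name}"


    documented = copy_func(greet)
    documented.__doc__ = "Replacement docstring."

    assert greet.__doc__ == "Original docstring."  # the original is untouched
    assert documented("flax") == "hello flax"      # behaviour is identical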
src/transformers/modeling_flax_utils.py
@@ -28,7 +28,16 @@ from flax.traverse_util import flatten_dict, unflatten_dict
 from jax.random import PRNGKey

 from .configuration_utils import PretrainedConfig
-from .file_utils import FLAX_WEIGHTS_NAME, WEIGHTS_NAME, cached_path, hf_bucket_url, is_offline_mode, is_remote_url
+from .file_utils import (
+    FLAX_WEIGHTS_NAME,
+    WEIGHTS_NAME,
+    add_start_docstrings_to_model_forward,
+    cached_path,
+    copy_func,
+    hf_bucket_url,
+    is_offline_mode,
+    is_remote_url,
+)
 from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_dict
 from .utils import logging

@@ -85,13 +94,13 @@ class FlaxPreTrainedModel(ABC):
         self.dtype = dtype

         # randomely initialized parameters
-        random_params = self.init(self.key, input_shape)
+        random_params = self.init_weights(self.key, input_shape)

         # save required_params as set
         self._required_params = set(flatten_dict(unfreeze(random_params)).keys())
         self.params = random_params

-    def init(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> Dict:
+    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> Dict:
         raise NotImplementedError(f"init method has to be implemented for {self}")

     @property

@@ -394,3 +403,12 @@ class FlaxPreTrainedModel(ABC):
         with open(os.path.join(save_directory, FLAX_WEIGHTS_NAME), "wb") as f:
             model_bytes = to_bytes(self.params)
             f.write(model_bytes)
+
+
+def overwrite_call_docstring(model_class, docstring):
+    # copy __call__ function to be sure docstring is changed only for this function
+    model_class.__call__ = copy_func(model_class.__call__)
+    # delete existing docstring
+    model_class.__call__.__doc__ = None
+    # set correct docstring
+    model_class.__call__ = add_start_docstrings_to_model_forward(docstring)(model_class.__call__)
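The new overwrite_call_docstring helper is what lets a subclass keep the shared __call__ implementation while presenting its own signature documentation. A rough usage sketch, assuming a transformers version that already contains this commit (MyFlaxHead and DOC_TEMPLATE are hypothetical names; the real call site is in modeling_flax_bert.py below):

    from transformers.modeling_flax_utils import overwrite_call_docstring

    DOC_TEMPLATE = r"""
        Args:
            input_ids of shape ({0}): token indices fed to the model.
    """


    class MyFlaxHead:
        # stand-in for a FlaxBertPreTrainedModel subclass that inherits the shared __call__
        def __call__(self, input_ids):
            return input_ids


    # only this class gets the docstring formatted for its own input shape;
    # the inherited __call__ on other classes keeps its original documentation
    overwrite_call_docstring(MyFlaxHead, DOC_TEMPLATE.format("batch_size, num_choices, sequence_length"))
    assert "num_choices" in MyFlaxHead.__call__.__doc__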
src/transformers/models/auto/auto_factory.py
@@ -14,10 +14,10 @@
 # limitations under the License.
 """Factory function to build auto-model classes."""
-import functools
 import types

 from ...configuration_utils import PretrainedConfig
+from ...file_utils import copy_func
 from .configuration_auto import AutoConfig, replace_list_option_in_docstrings

@@ -385,15 +385,6 @@ class _BaseAutoModelClass:
         )


-def copy_func(f):
-    """ Returns a copy of a function f."""
-    # Based on http://stackoverflow.com/a/6528148/190597 (Glenn Maynard)
-    g = types.FunctionType(f.__code__, f.__globals__, name=f.__name__, argdefs=f.__defaults__, closure=f.__closure__)
-    g = functools.update_wrapper(g, f)
-    g.__kwdefaults__ = f.__kwdefaults__
-    return g
-
-
 def insert_head_doc(docstring, head_doc=""):
     if len(head_doc) > 0:
         return docstring.replace(
src/transformers/models/bert/modeling_flax_bert.py
@@ -26,7 +26,7 @@ from jax import lax
 from jax.random import PRNGKey

 from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
-from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel
+from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, overwrite_call_docstring
 from ...utils import logging
 from .configuration_bert import BertConfig

@@ -91,6 +91,7 @@ BERT_INPUTS_DOCSTRING = r"""
             config.max_position_embeddings - 1]``.
         return_dict (:obj:`bool`, `optional`):
             Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
+
 """

@@ -477,49 +478,26 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
     config_class = BertConfig
     base_model_prefix = "bert"
+    module_class: nn.Module = None

-    def _check_inputs(self, input_ids, attention_mask, token_type_ids, position_ids):
-        if token_type_ids is None:
-            token_type_ids = jnp.ones_like(input_ids)
-
-        if position_ids is None:
-            position_ids = jnp.arange(jnp.atleast_2d(input_ids).shape[-1])
-
-        if attention_mask is None:
-            attention_mask = jnp.ones_like(input_ids)
-
-        return input_ids, attention_mask, token_type_ids, position_ids
-
-    def init(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
-        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
-            jnp.zeros(input_shape, dtype="i4"), None, None, None
-        )
+    def __init__(
+        self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
+    ):
+        module = self.module_class(config=config, dtype=dtype, **kwargs)
+        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
+
+    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
+        # init input tensors
+        input_ids = jnp.zeros(input_shape, dtype="i4")
+        token_type_ids = jnp.ones_like(input_ids)
+        position_ids = jnp.arange(jnp.atleast_2d(input_ids).shape[-1])
+        attention_mask = jnp.ones_like(input_ids)

         params_rng, dropout_rng = jax.random.split(rng)
         rngs = {"params": params_rng, "dropout": dropout_rng}

         return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids)["params"]

-
-@add_start_docstrings(
-    "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
-    BERT_START_DOCSTRING,
-)
-class FlaxBertModel(FlaxBertPreTrainedModel):
-    """
-    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
-    cross-attention is added between the self-attention layers, following the architecture described in `Attention is
-    all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
-    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
-    """
-
-    def __init__(
-        self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
-    ):
-        module = FlaxBertModule(config=config, dtype=dtype, **kwargs)
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
-
     @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     def __call__(
         self,

@@ -531,9 +509,15 @@ class FlaxBertModel(FlaxBertPreTrainedModel):
         dropout_rng: PRNGKey = None,
         train: bool = False,
     ):
-        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
-            input_ids, attention_mask, token_type_ids, position_ids
-        )
+        # init input tensors if not passed
+        if token_type_ids is None:
+            token_type_ids = jnp.ones_like(input_ids)
+
+        if position_ids is None:
+            position_ids = jnp.arange(jnp.atleast_2d(input_ids).shape[-1])
+
+        if attention_mask is None:
+            attention_mask = jnp.ones_like(input_ids)

         # Handle any PRNG if needed
         rngs = {}

@@ -576,49 +560,11 @@ class FlaxBertModule(nn.Module):

 @add_start_docstrings(
-    """
-    Bert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
-    sentence prediction (classification)` head.
-    """,
+    "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
     BERT_START_DOCSTRING,
 )
-class FlaxBertForPreTraining(FlaxBertPreTrainedModel):
-    def __init__(
-        self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
-    ):
-        module = FlaxBertForPreTrainingModule(config, **kwargs)
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
-
-    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
-    def __call__(
-        self,
-        input_ids,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        params: dict = None,
-        dropout_rng: PRNGKey = None,
-        train: bool = False,
-    ):
-        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
-            input_ids, attention_mask, token_type_ids, position_ids
-        )
-
-        # Handle any PRNG if needed
-        rngs = {}
-        if dropout_rng is not None:
-            rngs["dropout"] = dropout_rng
-
-        return self.module.apply(
-            {"params": params or self.params},
-            jnp.array(input_ids, dtype="i4"),
-            jnp.array(attention_mask, dtype="i4"),
-            jnp.array(token_type_ids, dtype="i4"),
-            jnp.array(position_ids, dtype="i4"),
-            not train,
-            rngs=rngs,
-        )
+class FlaxBertModel(FlaxBertPreTrainedModel):
+    module_class = FlaxBertModule


 class FlaxBertForPreTrainingModule(nn.Module):

@@ -641,44 +587,15 @@ class FlaxBertForPreTrainingModule(nn.Module):
         return (prediction_scores, seq_relationship_score)


-@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
-class FlaxBertForMaskedLM(FlaxBertPreTrainedModel):
-    def __init__(
-        self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
-    ):
-        module = FlaxBertForMaskedLMModule(config, **kwargs)
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
-
-    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
-    def __call__(
-        self,
-        input_ids,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        params: dict = None,
-        dropout_rng: PRNGKey = None,
-        train: bool = False,
-    ):
-        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
-            input_ids, attention_mask, token_type_ids, position_ids
-        )
-
-        # Handle any PRNG if needed
-        rngs = {}
-        if dropout_rng is not None:
-            rngs["dropout"] = dropout_rng
-
-        return self.module.apply(
-            {"params": params or self.params},
-            jnp.array(input_ids, dtype="i4"),
-            jnp.array(attention_mask, dtype="i4"),
-            jnp.array(token_type_ids, dtype="i4"),
-            jnp.array(position_ids, dtype="i4"),
-            not train,
-            rngs=rngs,
-        )
+@add_start_docstrings(
+    """
+    Bert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
+    sentence prediction (classification)` head.
+    """,
+    BERT_START_DOCSTRING,
+)
+class FlaxBertForPreTraining(FlaxBertPreTrainedModel):
+    module_class = FlaxBertForPreTrainingModule


 class FlaxBertForMaskedLMModule(nn.Module):

@@ -701,46 +618,9 @@ class FlaxBertForMaskedLMModule(nn.Module):
         return (logits,)


-@add_start_docstrings(
-    """Bert Model with a `next sentence prediction (classification)` head on top. """,
-    BERT_START_DOCSTRING,
-)
-class FlaxBertForNextSentencePrediction(FlaxBertPreTrainedModel):
-    def __init__(
-        self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
-    ):
-        module = FlaxBertForNextSentencePredictionModule(config, **kwargs)
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
-
-    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
-    def __call__(
-        self,
-        input_ids,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        params: dict = None,
-        dropout_rng: PRNGKey = None,
-        train: bool = False,
-    ):
-        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
-            input_ids, attention_mask, token_type_ids, position_ids
-        )
-
-        # Handle any PRNG if needed
-        rngs = {}
-        if dropout_rng is not None:
-            rngs["dropout"] = dropout_rng
-
-        return self.module.apply(
-            {"params": params or self.params},
-            jnp.array(input_ids, dtype="i4"),
-            jnp.array(attention_mask, dtype="i4"),
-            jnp.array(token_type_ids, dtype="i4"),
-            jnp.array(position_ids, dtype="i4"),
-            not train,
-            rngs=rngs,
-        )
+@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
+class FlaxBertForMaskedLM(FlaxBertPreTrainedModel):
+    module_class = FlaxBertForMaskedLMModule


 class FlaxBertForNextSentencePredictionModule(nn.Module):

@@ -764,48 +644,11 @@ class FlaxBertForNextSentencePredictionModule(nn.Module):

 @add_start_docstrings(
-    """
-    Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
-    output) e.g. for GLUE tasks.
-    """,
+    """Bert Model with a `next sentence prediction (classification)` head on top. """,
     BERT_START_DOCSTRING,
 )
-class FlaxBertForSequenceClassification(FlaxBertPreTrainedModel):
-    def __init__(
-        self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
-    ):
-        module = FlaxBertForSequenceClassificationModule(config, **kwargs)
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
-
-    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
-    def __call__(
-        self,
-        input_ids,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        params: dict = None,
-        dropout_rng: PRNGKey = None,
-        train: bool = False,
-    ):
-        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
-            input_ids, attention_mask, token_type_ids, position_ids
-        )
-
-        # Handle any PRNG if needed
-        rngs = {}
-        if dropout_rng is not None:
-            rngs["dropout"] = dropout_rng
-
-        return self.module.apply(
-            {"params": params or self.params},
-            jnp.array(input_ids, dtype="i4"),
-            jnp.array(attention_mask, dtype="i4"),
-            jnp.array(token_type_ids, dtype="i4"),
-            jnp.array(position_ids, dtype="i4"),
-            not train,
-            rngs=rngs,
-        )
+class FlaxBertForNextSentencePrediction(FlaxBertPreTrainedModel):
+    module_class = FlaxBertForNextSentencePredictionModule


 class FlaxBertForSequenceClassificationModule(nn.Module):

@@ -836,47 +679,13 @@ class FlaxBertForSequenceClassificationModule(nn.Module):

 @add_start_docstrings(
-    """
-    Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
-    softmax) e.g. for RocStories/SWAG tasks.
-    """,
+    """
+    Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
+    output) e.g. for GLUE tasks.
+    """,
     BERT_START_DOCSTRING,
 )
-class FlaxBertForMultipleChoice(FlaxBertPreTrainedModel):
-    def __init__(
-        self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
-    ):
-        module = FlaxBertForMultipleChoiceModule(config, **kwargs)
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
-
-    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
-    def __call__(
-        self,
-        input_ids,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        params: dict = None,
-        dropout_rng: PRNGKey = None,
-        train: bool = False,
-    ):
-        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
-            input_ids, attention_mask, token_type_ids, position_ids
-        )
-
-        # Handle any PRNG if needed
-        rngs = {}
-        if dropout_rng is not None:
-            rngs["dropout"] = dropout_rng
-
-        return self.module.apply(
-            {"params": params or self.params},
-            jnp.array(input_ids, dtype="i4"),
-            jnp.array(attention_mask, dtype="i4"),
-            jnp.array(token_type_ids, dtype="i4"),
-            jnp.array(position_ids, dtype="i4"),
-            not train,
-            rngs=rngs,
-        )
+class FlaxBertForSequenceClassification(FlaxBertPreTrainedModel):
+    module_class = FlaxBertForSequenceClassificationModule


 class FlaxBertForMultipleChoiceModule(nn.Module):

@@ -912,47 +721,19 @@ class FlaxBertForMultipleChoiceModule(nn.Module):

 @add_start_docstrings(
-    """
-    Bert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
-    Named-Entity-Recognition (NER) tasks.
-    """,
+    """
+    Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
+    softmax) e.g. for RocStories/SWAG tasks.
+    """,
     BERT_START_DOCSTRING,
 )
-class FlaxBertForTokenClassification(FlaxBertPreTrainedModel):
-    def __init__(
-        self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
-    ):
-        module = FlaxBertForTokenClassificationModule(config, **kwargs)
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
-
-    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
-    def __call__(
-        self,
-        input_ids,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        params: dict = None,
-        dropout_rng: PRNGKey = None,
-        train: bool = False,
-    ):
-        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
-            input_ids, attention_mask, token_type_ids, position_ids
-        )
-
-        # Handle any PRNG if needed
-        rngs = {}
-        if dropout_rng is not None:
-            rngs["dropout"] = dropout_rng
-
-        return self.module.apply(
-            {"params": params or self.params},
-            jnp.array(input_ids, dtype="i4"),
-            jnp.array(attention_mask, dtype="i4"),
-            jnp.array(token_type_ids, dtype="i4"),
-            jnp.array(position_ids, dtype="i4"),
-            not train,
-            rngs=rngs,
-        )
+class FlaxBertForMultipleChoice(FlaxBertPreTrainedModel):
+    module_class = FlaxBertForMultipleChoiceModule
+
+
+# adapt docstring slightly for FlaxBertForMultipleChoice
+overwrite_call_docstring(
+    FlaxBertForMultipleChoice, BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
+)


 class FlaxBertForTokenClassificationModule(nn.Module):

@@ -978,47 +759,13 @@ class FlaxBertForTokenClassificationModule(nn.Module):

 @add_start_docstrings(
-    """
-    Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
-    layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
-    """,
+    """
+    Bert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
+    Named-Entity-Recognition (NER) tasks.
+    """,
     BERT_START_DOCSTRING,
 )
-class FlaxBertForQuestionAnswering(FlaxBertPreTrainedModel):
-    def __init__(
-        self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
-    ):
-        module = FlaxBertForQuestionAnsweringModule(config, **kwargs)
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
-
-    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
-    def __call__(
-        self,
-        input_ids,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        params: dict = None,
-        dropout_rng: PRNGKey = None,
-        train: bool = False,
-    ):
-        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
-            input_ids, attention_mask, token_type_ids, position_ids
-        )
-
-        # Handle any PRNG if needed
-        rngs = {}
-        if dropout_rng is not None:
-            rngs["dropout"] = dropout_rng
-
-        return self.module.apply(
-            {"params": params or self.params},
-            jnp.array(input_ids, dtype="i4"),
-            jnp.array(attention_mask, dtype="i4"),
-            jnp.array(token_type_ids, dtype="i4"),
-            jnp.array(position_ids, dtype="i4"),
-            not train,
-            rngs=rngs,
-        )
+class FlaxBertForTokenClassification(FlaxBertPreTrainedModel):
+    module_class = FlaxBertForTokenClassificationModule


 class FlaxBertForQuestionAnsweringModule(nn.Module):

@@ -1041,3 +788,14 @@ class FlaxBertForQuestionAnsweringModule(nn.Module):
         end_logits = end_logits.squeeze(-1)

         return (start_logits, end_logits)
+
+
+@add_start_docstrings(
+    """
+    Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
+    layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
+    """,
+    BERT_START_DOCSTRING,
+)
+class FlaxBertForQuestionAnswering(FlaxBertPreTrainedModel):
+    module_class = FlaxBertForQuestionAnsweringModule
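The public behaviour is meant to be unchanged by the refactor. The following is a small usage sketch, not part of the commit, assuming a transformers install with Flax support that includes this change; the tiny config values are arbitrary, and at this point in the library's history the Flax models still return plain tuples, as the modules above do:

    import jax.numpy as jnp
    from transformers import BertConfig, FlaxBertModel

    # deliberately tiny config so the example initializes quickly
    config = BertConfig(
        vocab_size=100, hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64
    )

    model = FlaxBertModel(config)             # __init__ now lives on FlaxBertPreTrainedModel
    input_ids = jnp.ones((1, 8), dtype="i4")

    hidden_states, pooled = model(input_ids)  # attention_mask / token_type_ids default inside __call__
    print(hidden_states.shape, pooled.shape)  # (1, 8, 32) and (1, 32)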
src/transformers/models/roberta/modeling_flax_roberta.py
@@ -441,40 +441,7 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
     config_class = RobertaConfig
     base_model_prefix = "roberta"
+    module_class: nn.Module = None

-    def init(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
-        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
-            jnp.zeros(input_shape, dtype="i4"), None, None, None
-        )
-
-        params_rng, dropout_rng = jax.random.split(rng)
-        rngs = {"params": params_rng, "dropout": dropout_rng}
-
-        return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids)["params"]
-
-    def _check_inputs(self, input_ids, attention_mask, token_type_ids, position_ids):
-        if token_type_ids is None:
-            token_type_ids = jnp.ones_like(input_ids)
-
-        if position_ids is None:
-            position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id)
-
-        if attention_mask is None:
-            attention_mask = jnp.ones_like(input_ids)
-
-        return input_ids, attention_mask, token_type_ids, position_ids
-
-
-@add_start_docstrings(
-    "The bare RoBERTa Model transformer outputting raw hidden-states without any specific head on top.",
-    ROBERTA_START_DOCSTRING,
-)
-class FlaxRobertaModel(FlaxRobertaPreTrainedModel):
-    """
-    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
-    cross-attention is added between the self-attention layers, following the architecture described in `Attention is
-    all you need`_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz
-    Kaiser and Illia Polosukhin.
-    """
-
     def __init__(
         self,

@@ -484,23 +451,41 @@ class FlaxRobertaModel(FlaxRobertaPreTrainedModel):
         dtype: jnp.dtype = jnp.float32,
         **kwargs
     ):
-        module = FlaxRobertaModule(config, dtype=dtype, **kwargs)
+        module = self.module_class(config=config, dtype=dtype, **kwargs)
         super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)

+    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
+        # init input tensors
+        input_ids = jnp.zeros(input_shape, dtype="i4")
+        token_type_ids = jnp.ones_like(input_ids)
+        position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id)
+        attention_mask = jnp.ones_like(input_ids)
+
+        params_rng, dropout_rng = jax.random.split(rng)
+        rngs = {"params": params_rng, "dropout": dropout_rng}
+
+        return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids)["params"]
+
     @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     def __call__(
         self,
         input_ids,
-        token_type_ids=None,
         attention_mask=None,
+        token_type_ids=None,
         position_ids=None,
         params: dict = None,
         dropout_rng: PRNGKey = None,
         train: bool = False,
     ):
-        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
-            input_ids, attention_mask, token_type_ids, position_ids
-        )
+        # init input tensors if not passed
+        if token_type_ids is None:
+            token_type_ids = jnp.ones_like(input_ids)
+
+        if position_ids is None:
+            position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id)
+
+        if attention_mask is None:
+            attention_mask = jnp.ones_like(input_ids)

         # Handle any PRNG if needed
         rngs = {}

@@ -541,3 +526,11 @@ class FlaxRobertaModule(nn.Module):
         pooled = self.pooler(hidden_states)

         return hidden_states, pooled
+
+
+@add_start_docstrings(
+    "The bare RoBERTa Model transformer outputting raw hidden-states without any specific head on top.",
+    ROBERTA_START_DOCSTRING,
+)
+class FlaxRobertaModel(FlaxRobertaPreTrainedModel):
+    module_class = FlaxRobertaModule
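RoBERTa keeps its own init_weights and input-default logic because its position ids are derived from the input ids with a padding offset rather than a plain arange. A small numeric sketch of that behaviour, not from the commit, using a simplified re-implementation of the create_position_ids_from_input_ids helper already present in modeling_flax_roberta.py (RoBERTa's pad_token_id is 1):

    import jax.numpy as jnp


    def create_position_ids_from_input_ids(input_ids, padding_idx):
        # simplified version of the helper in modeling_flax_roberta.py:
        # count positions over non-padding tokens only, then offset by the padding index
        mask = (input_ids != padding_idx).astype("i4")
        incremental_indices = jnp.cumsum(mask, axis=1) * mask
        return incremental_indices.astype("i4") + padding_idx


    input_ids = jnp.array([[0, 31414, 232, 2, 1, 1]])  # last two tokens are padding
    print(create_position_ids_from_input_ids(input_ids, padding_idx=1))
    # [[2 3 4 5 1 1]] -- real tokens count up from pad_token_id + 1, padding positions stay at 1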