transformers · commit 50595a33

Unverified commit 50595a33, authored Apr 21, 2021 by Patrick von Platen, committed by GitHub on Apr 21, 2021.

Remove boiler plate code (#11340)

* remove boiler plate code
* adapt roberta
* correct docs
* finish refactor

Parent: ac588594
Showing 5 changed files with 133 additions and 363 deletions (+133, -363)
src/transformers/file_utils.py                              +11    -1
src/transformers/modeling_flax_utils.py                     +21    -3
src/transformers/models/auto/auto_factory.py                 +1   -10
src/transformers/models/bert/modeling_flax_bert.py          +68  -310
src/transformers/models/roberta/modeling_flax_roberta.py    +32   -39
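
For orientation before the per-file diffs: the boilerplate being removed is the identical __init__/__call__ wrapper that every Flax head class carried, and its replacement is a single module_class attribute consumed by the shared base class. A minimal, dependency-free sketch of that shape (the class names below are illustrative placeholders, not the real transformers classes):

# Before the refactor, every head class repeated the same constructor:
#
#     class MaskedLMModel(BasePreTrainedModel):
#         def __init__(self, config, **kwargs):
#             module = MaskedLMModule(config, **kwargs)
#             super().__init__(config, module, **kwargs)
#
# After it, the shared base class builds the module from a class attribute,
# so declaring a new head collapses to one line.


class BasePreTrainedModel:
    # each subclass points this at the module class it wraps (placeholder name)
    module_class = None

    def __init__(self, config, **kwargs):
        # shared constructor: instantiate whatever module the subclass declared
        self.config = config
        self.module = self.module_class(config=config, **kwargs)


class MaskedLMModule:
    # stand-in for a Flax nn.Module; it only stores the config here
    def __init__(self, config, **kwargs):
        self.config = config


class MaskedLMModel(BasePreTrainedModel):
    module_class = MaskedLMModule  # all the removed per-class boilerplate collapses to this


model = MaskedLMModel(config={"vocab_size": 30522})
assert isinstance(model.module, MaskedLMModule)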
src/transformers/file_utils.py
@@ -15,9 +15,9 @@
 Utilities for working with the local dataset cache. Parts of this file is adapted from the AllenNLP library at
 https://github.com/allenai/allennlp.
 """
 import copy
 import fnmatch
+import functools
 import importlib.util
 import io
 import json
@@ -27,6 +27,7 @@ import shutil
 import sys
 import tarfile
 import tempfile
+import types
 from collections import OrderedDict, UserDict
 from contextlib import contextmanager
 from dataclasses import fields
@@ -1674,3 +1675,12 @@ class _BaseLazyModule(ModuleType):
     def _get_module(self, module_name: str) -> ModuleType:
         raise NotImplementedError
+
+
+def copy_func(f):
+    """ Returns a copy of a function f."""
+    # Based on http://stackoverflow.com/a/6528148/190597 (Glenn Maynard)
+    g = types.FunctionType(f.__code__, f.__globals__, name=f.__name__, argdefs=f.__defaults__, closure=f.__closure__)
+    g = functools.update_wrapper(g, f)
+    g.__kwdefaults__ = f.__kwdefaults__
+    return g
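
copy_func moves into file_utils.py so other modules (the auto factory below, and the Flax utilities) can duplicate a function before editing its __doc__, leaving the original untouched. A small standalone sketch of that behaviour; greet is a made-up function used purely for illustration:

import functools
import types


def copy_func(f):
    # same helper as added above: copy code, globals, name, defaults and closure
    g = types.FunctionType(f.__code__, f.__globals__, name=f.__name__, argdefs=f.__defaults__, closure=f.__closure__)
    g = functools.update_wrapper(g, f)
    g.__kwdefaults__ = f.__kwdefaults__
    return g


def greet(name="world"):
    """Original docstring."""
    return f"hello {name}"


copy_of_greet = copy_func(greet)
copy_of_greet.__doc__ = "Replaced docstring."   # only the copy is touched
assert greet.__doc__ == "Original docstring."   # the original keeps its docstring
assert copy_of_greet() == greet() == "hello world"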
src/transformers/modeling_flax_utils.py
@@ -28,7 +28,16 @@ from flax.traverse_util import flatten_dict, unflatten_dict
 from jax.random import PRNGKey
 
 from .configuration_utils import PretrainedConfig
-from .file_utils import FLAX_WEIGHTS_NAME, WEIGHTS_NAME, cached_path, hf_bucket_url, is_offline_mode, is_remote_url
+from .file_utils import (
+    FLAX_WEIGHTS_NAME,
+    WEIGHTS_NAME,
+    add_start_docstrings_to_model_forward,
+    cached_path,
+    copy_func,
+    hf_bucket_url,
+    is_offline_mode,
+    is_remote_url,
+)
 from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_dict
 from .utils import logging
@@ -85,13 +94,13 @@ class FlaxPreTrainedModel(ABC):
         self.dtype = dtype
 
         # randomely initialized parameters
-        random_params = self.init(self.key, input_shape)
+        random_params = self.init_weights(self.key, input_shape)
 
         # save required_params as set
         self._required_params = set(flatten_dict(unfreeze(random_params)).keys())
         self.params = random_params
 
-    def init(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> Dict:
+    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> Dict:
         raise NotImplementedError(f"init method has to be implemented for {self}")
 
     @property
@@ -394,3 +403,12 @@ class FlaxPreTrainedModel(ABC):
         with open(os.path.join(save_directory, FLAX_WEIGHTS_NAME), "wb") as f:
             model_bytes = to_bytes(self.params)
             f.write(model_bytes)
+
+
+def overwrite_call_docstring(model_class, docstring):
+    # copy __call__ function to be sure docstring is changed only for this function
+    model_class.__call__ = copy_func(model_class.__call__)
+    # delete existing docstring
+    model_class.__call__.__doc__ = None
+    # set correct docstring
+    model_class.__call__ = add_start_docstrings_to_model_forward(docstring)(model_class.__call__)
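
overwrite_call_docstring copies __call__ onto the target class first, so the docstring swap cannot leak to the shared base-class __call__ that every head inherits. A rough self-contained sketch of the same mechanism; prepend_docstring is a simplified stand-in for add_start_docstrings_to_model_forward, and the two classes are made up for illustration:

import functools
import types


def copy_func(f):
    # same copy helper as in file_utils.py
    g = types.FunctionType(f.__code__, f.__globals__, name=f.__name__, argdefs=f.__defaults__, closure=f.__closure__)
    g = functools.update_wrapper(g, f)
    g.__kwdefaults__ = f.__kwdefaults__
    return g


def prepend_docstring(docstring):
    # stand-in for add_start_docstrings_to_model_forward: prepend text to fn.__doc__
    def decorator(fn):
        fn.__doc__ = docstring + (fn.__doc__ or "")
        return fn

    return decorator


def overwrite_call_docstring(model_class, docstring):
    # copy __call__ so only this class's docstring changes, then re-decorate it
    model_class.__call__ = copy_func(model_class.__call__)
    model_class.__call__.__doc__ = None
    model_class.__call__ = prepend_docstring(docstring)(model_class.__call__)


class BaseModel:
    def __call__(self, input_ids):
        """Shared docstring inherited by every subclass."""
        return input_ids


class MultipleChoiceModel(BaseModel):
    pass


overwrite_call_docstring(MultipleChoiceModel, "Inputs have shape (batch_size, num_choices, seq_len). ")
assert "num_choices" in MultipleChoiceModel.__call__.__doc__
assert "Shared docstring" in BaseModel.__call__.__doc__  # the base class is untouched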
src/transformers/models/auto/auto_factory.py
@@ -14,10 +14,10 @@
 # limitations under the License.
 """Factory function to build auto-model classes."""
-import functools
 import types
 
 from ...configuration_utils import PretrainedConfig
+from ...file_utils import copy_func
 from .configuration_auto import AutoConfig, replace_list_option_in_docstrings
@@ -385,15 +385,6 @@ class _BaseAutoModelClass:
         )
 
 
-def copy_func(f):
-    """ Returns a copy of a function f."""
-    # Based on http://stackoverflow.com/a/6528148/190597 (Glenn Maynard)
-    g = types.FunctionType(f.__code__, f.__globals__, name=f.__name__, argdefs=f.__defaults__, closure=f.__closure__)
-    g = functools.update_wrapper(g, f)
-    g.__kwdefaults__ = f.__kwdefaults__
-    return g
-
-
 def insert_head_doc(docstring, head_doc=""):
     if len(head_doc) > 0:
         return docstring.replace(
src/transformers/models/bert/modeling_flax_bert.py
@@ -26,7 +26,7 @@ from jax import lax
 from jax.random import PRNGKey
 
 from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
-from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel
+from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, overwrite_call_docstring
 from ...utils import logging
 from .configuration_bert import BertConfig
@@ -91,6 +91,7 @@ BERT_INPUTS_DOCSTRING = r"""
             config.max_position_embeddings - 1]``.
         return_dict (:obj:`bool`, `optional`):
             Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
+
 """
@@ -477,49 +478,26 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
     config_class = BertConfig
     base_model_prefix = "bert"
+    module_class: nn.Module = None
 
-    def _check_inputs(self, input_ids, attention_mask, token_type_ids, position_ids):
-        if token_type_ids is None:
-            token_type_ids = jnp.ones_like(input_ids)
-        if position_ids is None:
-            position_ids = jnp.arange(jnp.atleast_2d(input_ids).shape[-1])
-        if attention_mask is None:
-            attention_mask = jnp.ones_like(input_ids)
-        return input_ids, attention_mask, token_type_ids, position_ids
-
-    def init(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
-        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
-            jnp.zeros(input_shape, dtype="i4"), None, None, None
-        )
+    def __init__(
+        self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
+    ):
+        module = self.module_class(config=config, dtype=dtype, **kwargs)
+        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
+
+    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
+        # init input tensors
+        input_ids = jnp.zeros(input_shape, dtype="i4")
+        token_type_ids = jnp.ones_like(input_ids)
+        position_ids = jnp.arange(jnp.atleast_2d(input_ids).shape[-1])
+        attention_mask = jnp.ones_like(input_ids)
 
         params_rng, dropout_rng = jax.random.split(rng)
         rngs = {"params": params_rng, "dropout": dropout_rng}
 
         return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids)["params"]
 
-
-@add_start_docstrings(
-    "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
-    BERT_START_DOCSTRING,
-)
-class FlaxBertModel(FlaxBertPreTrainedModel):
-    """
-    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
-    cross-attention is added between the self-attention layers, following the architecture described in `Attention is
-    all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
-    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
-    """
-
-    def __init__(
-        self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
-    ):
-        module = FlaxBertModule(config=config, dtype=dtype, **kwargs)
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
-
     @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     def __call__(
         self,
@@ -531,9 +509,15 @@ class FlaxBertModel(FlaxBertPreTrainedModel):
         dropout_rng: PRNGKey = None,
         train: bool = False,
     ):
-        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
-            input_ids, attention_mask, token_type_ids, position_ids
-        )
+        # init input tensors if not passed
+        if token_type_ids is None:
+            token_type_ids = jnp.ones_like(input_ids)
+
+        if position_ids is None:
+            position_ids = jnp.arange(jnp.atleast_2d(input_ids).shape[-1])
+
+        if attention_mask is None:
+            attention_mask = jnp.ones_like(input_ids)
 
         # Handle any PRNG if needed
         rngs = {}
@@ -576,49 +560,11 @@ class FlaxBertModule(nn.Module):
 @add_start_docstrings(
-    """
-    Bert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
-    sentence prediction (classification)` head.
-    """,
+    "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
     BERT_START_DOCSTRING,
 )
-class FlaxBertForPreTraining(FlaxBertPreTrainedModel):
-    def __init__(
-        self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
-    ):
-        module = FlaxBertForPreTrainingModule(config, **kwargs)
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
-
-    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
-    def __call__(
-        self,
-        input_ids,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        params: dict = None,
-        dropout_rng: PRNGKey = None,
-        train: bool = False,
-    ):
-        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
-            input_ids, attention_mask, token_type_ids, position_ids
-        )
-
-        # Handle any PRNG if needed
-        rngs = {}
-        if dropout_rng is not None:
-            rngs["dropout"] = dropout_rng
-
-        return self.module.apply(
-            {"params": params or self.params},
-            jnp.array(input_ids, dtype="i4"),
-            jnp.array(attention_mask, dtype="i4"),
-            jnp.array(token_type_ids, dtype="i4"),
-            jnp.array(position_ids, dtype="i4"),
-            not train,
-            rngs=rngs,
-        )
+class FlaxBertModel(FlaxBertPreTrainedModel):
+    module_class = FlaxBertModule
 
 
 class FlaxBertForPreTrainingModule(nn.Module):
@@ -641,44 +587,15 @@ class FlaxBertForPreTrainingModule(nn.Module):
         return (prediction_scores, seq_relationship_score)
 
 
-@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
-class FlaxBertForMaskedLM(FlaxBertPreTrainedModel):
-    def __init__(
-        self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
-    ):
-        module = FlaxBertForMaskedLMModule(config, **kwargs)
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
-
-    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
-    def __call__(
-        self,
-        input_ids,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        params: dict = None,
-        dropout_rng: PRNGKey = None,
-        train: bool = False,
-    ):
-        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
-            input_ids, attention_mask, token_type_ids, position_ids
-        )
-
-        # Handle any PRNG if needed
-        rngs = {}
-        if dropout_rng is not None:
-            rngs["dropout"] = dropout_rng
-
-        return self.module.apply(
-            {"params": params or self.params},
-            jnp.array(input_ids, dtype="i4"),
-            jnp.array(attention_mask, dtype="i4"),
-            jnp.array(token_type_ids, dtype="i4"),
-            jnp.array(position_ids, dtype="i4"),
-            not train,
-            rngs=rngs,
-        )
+@add_start_docstrings(
+    """
+    Bert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
+    sentence prediction (classification)` head.
+    """,
+    BERT_START_DOCSTRING,
+)
+class FlaxBertForPreTraining(FlaxBertPreTrainedModel):
+    module_class = FlaxBertForPreTrainingModule
 
 
 class FlaxBertForMaskedLMModule(nn.Module):
@@ -701,46 +618,9 @@ class FlaxBertForMaskedLMModule(nn.Module):
         return (logits,)
 
 
-@add_start_docstrings(
-    """Bert Model with a `next sentence prediction (classification)` head on top. """,
-    BERT_START_DOCSTRING,
-)
-class FlaxBertForNextSentencePrediction(FlaxBertPreTrainedModel):
-    def __init__(
-        self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
-    ):
-        module = FlaxBertForNextSentencePredictionModule(config, **kwargs)
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
-
-    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
-    def __call__(
-        self,
-        input_ids,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        params: dict = None,
-        dropout_rng: PRNGKey = None,
-        train: bool = False,
-    ):
-        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
-            input_ids, attention_mask, token_type_ids, position_ids
-        )
-
-        # Handle any PRNG if needed
-        rngs = {}
-        if dropout_rng is not None:
-            rngs["dropout"] = dropout_rng
-
-        return self.module.apply(
-            {"params": params or self.params},
-            jnp.array(input_ids, dtype="i4"),
-            jnp.array(attention_mask, dtype="i4"),
-            jnp.array(token_type_ids, dtype="i4"),
-            jnp.array(position_ids, dtype="i4"),
-            not train,
-            rngs=rngs,
-        )
+@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
+class FlaxBertForMaskedLM(FlaxBertPreTrainedModel):
+    module_class = FlaxBertForMaskedLMModule
 
 
 class FlaxBertForNextSentencePredictionModule(nn.Module):
@@ -764,48 +644,11 @@ class FlaxBertForNextSentencePredictionModule(nn.Module):
 @add_start_docstrings(
-    """
-    Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
-    output) e.g. for GLUE tasks.
-    """,
+    """Bert Model with a `next sentence prediction (classification)` head on top. """,
     BERT_START_DOCSTRING,
 )
-class FlaxBertForSequenceClassification(FlaxBertPreTrainedModel):
-    def __init__(
-        self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
-    ):
-        module = FlaxBertForSequenceClassificationModule(config, **kwargs)
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
-
-    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
-    def __call__(
-        self,
-        input_ids,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        params: dict = None,
-        dropout_rng: PRNGKey = None,
-        train: bool = False,
-    ):
-        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
-            input_ids, attention_mask, token_type_ids, position_ids
-        )
-
-        # Handle any PRNG if needed
-        rngs = {}
-        if dropout_rng is not None:
-            rngs["dropout"] = dropout_rng
-
-        return self.module.apply(
-            {"params": params or self.params},
-            jnp.array(input_ids, dtype="i4"),
-            jnp.array(attention_mask, dtype="i4"),
-            jnp.array(token_type_ids, dtype="i4"),
-            jnp.array(position_ids, dtype="i4"),
-            not train,
-            rngs=rngs,
-        )
+class FlaxBertForNextSentencePrediction(FlaxBertPreTrainedModel):
+    module_class = FlaxBertForNextSentencePredictionModule
 
 
 class FlaxBertForSequenceClassificationModule(nn.Module):
@@ -836,47 +679,13 @@ class FlaxBertForSequenceClassificationModule(nn.Module):
 @add_start_docstrings(
     """
-    Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
-    softmax) e.g. for RocStories/SWAG tasks.
+    Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
+    output) e.g. for GLUE tasks.
     """,
     BERT_START_DOCSTRING,
 )
-class FlaxBertForMultipleChoice(FlaxBertPreTrainedModel):
-    def __init__(
-        self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
-    ):
-        module = FlaxBertForMultipleChoiceModule(config, **kwargs)
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
-
-    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
-    def __call__(
-        self,
-        input_ids,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        params: dict = None,
-        dropout_rng: PRNGKey = None,
-        train: bool = False,
-    ):
-        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
-            input_ids, attention_mask, token_type_ids, position_ids
-        )
-
-        # Handle any PRNG if needed
-        rngs = {}
-        if dropout_rng is not None:
-            rngs["dropout"] = dropout_rng
-
-        return self.module.apply(
-            {"params": params or self.params},
-            jnp.array(input_ids, dtype="i4"),
-            jnp.array(attention_mask, dtype="i4"),
-            jnp.array(token_type_ids, dtype="i4"),
-            jnp.array(position_ids, dtype="i4"),
-            not train,
-            rngs=rngs,
-        )
+class FlaxBertForSequenceClassification(FlaxBertPreTrainedModel):
+    module_class = FlaxBertForSequenceClassificationModule
 
 
 class FlaxBertForMultipleChoiceModule(nn.Module):
@@ -912,47 +721,19 @@ class FlaxBertForMultipleChoiceModule(nn.Module):
 @add_start_docstrings(
     """
-    Bert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
-    Named-Entity-Recognition (NER) tasks.
+    Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
+    softmax) e.g. for RocStories/SWAG tasks.
     """,
     BERT_START_DOCSTRING,
 )
-class FlaxBertForTokenClassification(FlaxBertPreTrainedModel):
-    def __init__(
-        self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
-    ):
-        module = FlaxBertForTokenClassificationModule(config, **kwargs)
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
-
-    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
-    def __call__(
-        self,
-        input_ids,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        params: dict = None,
-        dropout_rng: PRNGKey = None,
-        train: bool = False,
-    ):
-        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
-            input_ids, attention_mask, token_type_ids, position_ids
-        )
-
-        # Handle any PRNG if needed
-        rngs = {}
-        if dropout_rng is not None:
-            rngs["dropout"] = dropout_rng
-
-        return self.module.apply(
-            {"params": params or self.params},
-            jnp.array(input_ids, dtype="i4"),
-            jnp.array(attention_mask, dtype="i4"),
-            jnp.array(token_type_ids, dtype="i4"),
-            jnp.array(position_ids, dtype="i4"),
-            not train,
-            rngs=rngs,
-        )
+class FlaxBertForMultipleChoice(FlaxBertPreTrainedModel):
+    module_class = FlaxBertForMultipleChoiceModule
+
+
+# adapt docstring slightly for FlaxBertForMultipleChoice
+overwrite_call_docstring(
+    FlaxBertForMultipleChoice, BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
+)
 
 
 class FlaxBertForTokenClassificationModule(nn.Module):
@@ -978,47 +759,13 @@ class FlaxBertForTokenClassificationModule(nn.Module):
 @add_start_docstrings(
     """
-    Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
-    layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
+    Bert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
+    Named-Entity-Recognition (NER) tasks.
     """,
     BERT_START_DOCSTRING,
 )
-class FlaxBertForQuestionAnswering(FlaxBertPreTrainedModel):
-    def __init__(
-        self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
-    ):
-        module = FlaxBertForQuestionAnsweringModule(config, **kwargs)
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
-
-    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
-    def __call__(
-        self,
-        input_ids,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        params: dict = None,
-        dropout_rng: PRNGKey = None,
-        train: bool = False,
-    ):
-        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
-            input_ids, attention_mask, token_type_ids, position_ids
-        )
-
-        # Handle any PRNG if needed
-        rngs = {}
-        if dropout_rng is not None:
-            rngs["dropout"] = dropout_rng
-
-        return self.module.apply(
-            {"params": params or self.params},
-            jnp.array(input_ids, dtype="i4"),
-            jnp.array(attention_mask, dtype="i4"),
-            jnp.array(token_type_ids, dtype="i4"),
-            jnp.array(position_ids, dtype="i4"),
-            not train,
-            rngs=rngs,
-        )
+class FlaxBertForTokenClassification(FlaxBertPreTrainedModel):
+    module_class = FlaxBertForTokenClassificationModule
 
 
 class FlaxBertForQuestionAnsweringModule(nn.Module):
@@ -1041,3 +788,14 @@ class FlaxBertForQuestionAnsweringModule(nn.Module):
         end_logits = end_logits.squeeze(-1)
 
         return (start_logits, end_logits)
+
+
+@add_start_docstrings(
+    """
+    Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
+    layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
+    """,
+    BERT_START_DOCSTRING,
+)
+class FlaxBertForQuestionAnswering(FlaxBertPreTrainedModel):
+    module_class = FlaxBertForQuestionAnsweringModule
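
The other half of the BERT change is that weight initialization now goes through init_weights, which the shared base constructor calls on whatever subclass is being built (see the modeling_flax_utils.py hunk above). A dependency-free sketch of that control flow under simplified assumptions; there is no real JAX PRNG key or Flax module here, and the class names are placeholders:

from typing import Dict, Tuple


class FlaxPreTrainedModelSketch:
    # stand-in for FlaxPreTrainedModel; "rng" is a plain int instead of a PRNGKey
    def __init__(self, config, input_shape: Tuple = (1, 1), seed: int = 0):
        self.config = config
        self.key = seed
        # the shared constructor drives initialization through the subclass hook
        self.params = self.init_weights(self.key, input_shape)

    def init_weights(self, rng, input_shape: Tuple) -> Dict:
        raise NotImplementedError(f"init_weights has to be implemented for {self}")


class BertLikeModel(FlaxPreTrainedModelSketch):
    def init_weights(self, rng, input_shape: Tuple) -> Dict:
        # a real model would build dummy input_ids/attention_mask here and call module.init
        return {"dummy_embedding": [0.0] * input_shape[-1]}


model = BertLikeModel(config=None, input_shape=(1, 8))
assert list(model.params) == ["dummy_embedding"]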
src/transformers/models/roberta/modeling_flax_roberta.py
@@ -441,40 +441,7 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
     config_class = RobertaConfig
     base_model_prefix = "roberta"
-
-    def init(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
-        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
-            jnp.zeros(input_shape, dtype="i4"), None, None, None
-        )
-
-        params_rng, dropout_rng = jax.random.split(rng)
-        rngs = {"params": params_rng, "dropout": dropout_rng}
-
-        return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids)["params"]
-
-    def _check_inputs(self, input_ids, attention_mask, token_type_ids, position_ids):
-        if token_type_ids is None:
-            token_type_ids = jnp.ones_like(input_ids)
-        if position_ids is None:
-            position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id)
-        if attention_mask is None:
-            attention_mask = jnp.ones_like(input_ids)
-        return input_ids, attention_mask, token_type_ids, position_ids
-
-
-@add_start_docstrings(
-    "The bare RoBERTa Model transformer outputting raw hidden-states without any specific head on top.",
-    ROBERTA_START_DOCSTRING,
-)
-class FlaxRobertaModel(FlaxRobertaPreTrainedModel):
-    """
-    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
-    cross-attention is added between the self-attention layers, following the architecture described in `Attention is
-    all you need`_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz
-    Kaiser and Illia Polosukhin.
-    """
+    module_class: nn.Module = None
 
     def __init__(
         self,
@@ -484,23 +451,41 @@ class FlaxRobertaModel(FlaxRobertaPreTrainedModel):
         dtype: jnp.dtype = jnp.float32,
         **kwargs
     ):
-        module = FlaxRobertaModule(config, dtype=dtype, **kwargs)
+        module = self.module_class(config=config, dtype=dtype, **kwargs)
         super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
 
+    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
+        # init input tensors
+        input_ids = jnp.zeros(input_shape, dtype="i4")
+        token_type_ids = jnp.ones_like(input_ids)
+        position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id)
+        attention_mask = jnp.ones_like(input_ids)
+
+        params_rng, dropout_rng = jax.random.split(rng)
+        rngs = {"params": params_rng, "dropout": dropout_rng}
+
+        return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids)["params"]
+
     @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     def __call__(
         self,
         input_ids,
-        token_type_ids=None,
         attention_mask=None,
+        token_type_ids=None,
         position_ids=None,
         params: dict = None,
         dropout_rng: PRNGKey = None,
         train: bool = False,
     ):
-        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
-            input_ids, attention_mask, token_type_ids, position_ids
-        )
+        # init input tensors if not passed
+        if token_type_ids is None:
+            token_type_ids = jnp.ones_like(input_ids)
+
+        if position_ids is None:
+            position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id)
+
+        if attention_mask is None:
+            attention_mask = jnp.ones_like(input_ids)
 
         # Handle any PRNG if needed
         rngs = {}
@@ -541,3 +526,11 @@ class FlaxRobertaModule(nn.Module):
         pooled = self.pooler(hidden_states)
 
         return hidden_states, pooled
+
+
+@add_start_docstrings(
+    "The bare RoBERTa Model transformer outputting raw hidden-states without any specific head on top.",
+    ROBERTA_START_DOCSTRING,
+)
+class FlaxRobertaModel(FlaxRobertaPreTrainedModel):
+    module_class = FlaxRobertaModule