Unverified Commit 50595a33 authored by Patrick von Platen, committed by GitHub

Remove boiler plate code (#11340)

* remove boiler plate code

* adapt roberta

* correct docs

* finish refactor
parent ac588594
@@ -15,9 +15,9 @@
 Utilities for working with the local dataset cache. Parts of this file is adapted from the AllenNLP library at
 https://github.com/allenai/allennlp.
 """
 import copy
 import fnmatch
+import functools
 import importlib.util
 import io
 import json
@@ -27,6 +27,7 @@ import shutil
 import sys
 import tarfile
 import tempfile
+import types
 from collections import OrderedDict, UserDict
 from contextlib import contextmanager
 from dataclasses import fields
@@ -1674,3 +1675,12 @@ class _BaseLazyModule(ModuleType):
 
     def _get_module(self, module_name: str) -> ModuleType:
         raise NotImplementedError
+
+
+def copy_func(f):
+    """ Returns a copy of a function f."""
+    # Based on http://stackoverflow.com/a/6528148/190597 (Glenn Maynard)
+    g = types.FunctionType(f.__code__, f.__globals__, name=f.__name__, argdefs=f.__defaults__, closure=f.__closure__)
+    g = functools.update_wrapper(g, f)
+    g.__kwdefaults__ = f.__kwdefaults__
+    return g
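To make the intent of the new `copy_func` helper concrete, here is a small self-contained sketch (illustrative only, not part of the diff; `scale` is a made-up function): the copy keeps the code object, globals, positional defaults, closure, and keyword-only defaults of the original, so it behaves identically, but its metadata (such as `__doc__`) can be edited without touching the original function.

```python
import functools
import types


def copy_func(f):
    """Return an independent copy of function f (same approach as the helper added above)."""
    g = types.FunctionType(f.__code__, f.__globals__, name=f.__name__, argdefs=f.__defaults__, closure=f.__closure__)
    g = functools.update_wrapper(g, f)
    g.__kwdefaults__ = f.__kwdefaults__
    return g


def scale(x, factor=2, *, offset=0):
    """Multiply x by factor and add offset."""
    return x * factor + offset


copy_of_scale = copy_func(scale)
copy_of_scale.__doc__ = "Edited docstring."

print(copy_of_scale(3))             # 6 -- positional default survives the copy
print(copy_of_scale(3, offset=1))   # 7 -- keyword-only default copied via __kwdefaults__
print(scale.__doc__)                # original docstring is untouched
```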
@@ -28,7 +28,16 @@ from flax.traverse_util import flatten_dict, unflatten_dict
 from jax.random import PRNGKey
 
 from .configuration_utils import PretrainedConfig
-from .file_utils import FLAX_WEIGHTS_NAME, WEIGHTS_NAME, cached_path, hf_bucket_url, is_offline_mode, is_remote_url
+from .file_utils import (
+    FLAX_WEIGHTS_NAME,
+    WEIGHTS_NAME,
+    add_start_docstrings_to_model_forward,
+    cached_path,
+    copy_func,
+    hf_bucket_url,
+    is_offline_mode,
+    is_remote_url,
+)
 from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_dict
 from .utils import logging
@@ -85,13 +94,13 @@ class FlaxPreTrainedModel(ABC):
         self.dtype = dtype
 
         # randomely initialized parameters
-        random_params = self.init(self.key, input_shape)
+        random_params = self.init_weights(self.key, input_shape)
 
         # save required_params as set
         self._required_params = set(flatten_dict(unfreeze(random_params)).keys())
         self.params = random_params
 
-    def init(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> Dict:
+    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> Dict:
         raise NotImplementedError(f"init method has to be implemented for {self}")
 
     @property
@@ -394,3 +403,12 @@ class FlaxPreTrainedModel(ABC):
         with open(os.path.join(save_directory, FLAX_WEIGHTS_NAME), "wb") as f:
             model_bytes = to_bytes(self.params)
             f.write(model_bytes)
+
+
+def overwrite_call_docstring(model_class, docstring):
+    # copy __call__ function to be sure docstring is changed only for this function
+    model_class.__call__ = copy_func(model_class.__call__)
+    # delete existing docstring
+    model_class.__call__.__doc__ = None
+    # set correct docstring
+    model_class.__call__ = add_start_docstrings_to_model_forward(docstring)(model_class.__call__)
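Why the "copy `__call__` first" step in `overwrite_call_docstring` matters: a subclass inherits the *same* function object as its base class, so editing `__doc__` in place would change the docstring seen on the base class and every sibling model. A minimal, self-contained illustration of that pitfall (throwaway class names, not transformers code):

```python
class BasePreTrained:
    def __call__(self, input_ids):
        """Inputs of shape (batch_size, sequence_length)."""
        return input_ids


class MultipleChoiceHead(BasePreTrained):
    pass


# Naive in-place edit: both classes now show the multiple-choice docstring,
# because MultipleChoiceHead.__call__ resolves to the very same function object.
MultipleChoiceHead.__call__.__doc__ = "Inputs of shape (batch_size, num_choices, sequence_length)."
print(BasePreTrained.__call__.__doc__)  # also changed -- the bug overwrite_call_docstring avoids
```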
@@ -14,10 +14,10 @@
 # limitations under the License.
 """Factory function to build auto-model classes."""
 
-import functools
 import types
 
 from ...configuration_utils import PretrainedConfig
+from ...file_utils import copy_func
 from .configuration_auto import AutoConfig, replace_list_option_in_docstrings
@@ -385,15 +385,6 @@ class _BaseAutoModelClass:
     )
 
 
-def copy_func(f):
-    """ Returns a copy of a function f."""
-    # Based on http://stackoverflow.com/a/6528148/190597 (Glenn Maynard)
-    g = types.FunctionType(f.__code__, f.__globals__, name=f.__name__, argdefs=f.__defaults__, closure=f.__closure__)
-    g = functools.update_wrapper(g, f)
-    g.__kwdefaults__ = f.__kwdefaults__
-    return g
-
-
 def insert_head_doc(docstring, head_doc=""):
     if len(head_doc) > 0:
         return docstring.replace(
@@ -26,7 +26,7 @@ from jax import lax
 from jax.random import PRNGKey
 
 from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
-from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel
+from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, overwrite_call_docstring
 from ...utils import logging
 from .configuration_bert import BertConfig
@@ -91,6 +91,7 @@ BERT_INPUTS_DOCSTRING = r"""
             config.max_position_embeddings - 1]``.
         return_dict (:obj:`bool`, `optional`):
             Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
+
 """
@@ -477,49 +478,26 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
     config_class = BertConfig
     base_model_prefix = "bert"
+    module_class: nn.Module = None
 
-    def _check_inputs(self, input_ids, attention_mask, token_type_ids, position_ids):
-        if token_type_ids is None:
-            token_type_ids = jnp.ones_like(input_ids)
-
-        if position_ids is None:
-            position_ids = jnp.arange(jnp.atleast_2d(input_ids).shape[-1])
-
-        if attention_mask is None:
-            attention_mask = jnp.ones_like(input_ids)
-
-        return input_ids, attention_mask, token_type_ids, position_ids
-
-    def init(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
-        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
-            jnp.zeros(input_shape, dtype="i4"), None, None, None
-        )
+    def __init__(
+        self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
+    ):
+        module = self.module_class(config=config, dtype=dtype, **kwargs)
+        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
+
+    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
+        # init input tensors
+        input_ids = jnp.zeros(input_shape, dtype="i4")
+        token_type_ids = jnp.ones_like(input_ids)
+        position_ids = jnp.arange(jnp.atleast_2d(input_ids).shape[-1])
+        attention_mask = jnp.ones_like(input_ids)
 
         params_rng, dropout_rng = jax.random.split(rng)
         rngs = {"params": params_rng, "dropout": dropout_rng}
 
         return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids)["params"]
 
-
-@add_start_docstrings(
-    "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
-    BERT_START_DOCSTRING,
-)
-class FlaxBertModel(FlaxBertPreTrainedModel):
-    """
-    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
-    cross-attention is added between the self-attention layers, following the architecture described in `Attention is
-    all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
-    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
-    """
-
-    def __init__(
-        self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
-    ):
-        module = FlaxBertModule(config=config, dtype=dtype, **kwargs)
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
 
     @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     def __call__(
         self,
@@ -531,9 +509,15 @@ class FlaxBertModel(FlaxBertPreTrainedModel):
         dropout_rng: PRNGKey = None,
         train: bool = False,
     ):
-        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
-            input_ids, attention_mask, token_type_ids, position_ids
-        )
+        # init input tensors if not passed
+        if token_type_ids is None:
+            token_type_ids = jnp.ones_like(input_ids)
+
+        if position_ids is None:
+            position_ids = jnp.arange(jnp.atleast_2d(input_ids).shape[-1])
+
+        if attention_mask is None:
+            attention_mask = jnp.ones_like(input_ids)
 
         # Handle any PRNG if needed
         rngs = {}
@@ -576,49 +560,11 @@ class FlaxBertModule(nn.Module):
 
 @add_start_docstrings(
-    """
-    Bert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
-    sentence prediction (classification)` head.
-    """,
+    "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
     BERT_START_DOCSTRING,
 )
-class FlaxBertForPreTraining(FlaxBertPreTrainedModel):
-    def __init__(
-        self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
-    ):
-        module = FlaxBertForPreTrainingModule(config, **kwargs)
-
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
-
-    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
-    def __call__(
-        self,
-        input_ids,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        params: dict = None,
-        dropout_rng: PRNGKey = None,
-        train: bool = False,
-    ):
-        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
-            input_ids, attention_mask, token_type_ids, position_ids
-        )
-
-        # Handle any PRNG if needed
-        rngs = {}
-        if dropout_rng is not None:
-            rngs["dropout"] = dropout_rng
-
-        return self.module.apply(
-            {"params": params or self.params},
-            jnp.array(input_ids, dtype="i4"),
-            jnp.array(attention_mask, dtype="i4"),
-            jnp.array(token_type_ids, dtype="i4"),
-            jnp.array(position_ids, dtype="i4"),
-            not train,
-            rngs=rngs,
-        )
+class FlaxBertModel(FlaxBertPreTrainedModel):
+    module_class = FlaxBertModule
 
 
 class FlaxBertForPreTrainingModule(nn.Module):
@@ -641,44 +587,15 @@ class FlaxBertForPreTrainingModule(nn.Module):
         return (prediction_scores, seq_relationship_score)
 
 
-@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
-class FlaxBertForMaskedLM(FlaxBertPreTrainedModel):
-    def __init__(
-        self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
-    ):
-        module = FlaxBertForMaskedLMModule(config, **kwargs)
-
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
-
-    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
-    def __call__(
-        self,
-        input_ids,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        params: dict = None,
-        dropout_rng: PRNGKey = None,
-        train: bool = False,
-    ):
-        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
-            input_ids, attention_mask, token_type_ids, position_ids
-        )
-
-        # Handle any PRNG if needed
-        rngs = {}
-        if dropout_rng is not None:
-            rngs["dropout"] = dropout_rng
-
-        return self.module.apply(
-            {"params": params or self.params},
-            jnp.array(input_ids, dtype="i4"),
-            jnp.array(attention_mask, dtype="i4"),
-            jnp.array(token_type_ids, dtype="i4"),
-            jnp.array(position_ids, dtype="i4"),
-            not train,
-            rngs=rngs,
-        )
+@add_start_docstrings(
+    """
+    Bert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
+    sentence prediction (classification)` head.
+    """,
+    BERT_START_DOCSTRING,
+)
+class FlaxBertForPreTraining(FlaxBertPreTrainedModel):
+    module_class = FlaxBertForPreTrainingModule
 
 
 class FlaxBertForMaskedLMModule(nn.Module):
@@ -701,46 +618,9 @@ class FlaxBertForMaskedLMModule(nn.Module):
         return (logits,)
 
 
-@add_start_docstrings(
-    """Bert Model with a `next sentence prediction (classification)` head on top. """,
-    BERT_START_DOCSTRING,
-)
-class FlaxBertForNextSentencePrediction(FlaxBertPreTrainedModel):
-    def __init__(
-        self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
-    ):
-        module = FlaxBertForNextSentencePredictionModule(config, **kwargs)
-
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
-
-    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
-    def __call__(
-        self,
-        input_ids,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        params: dict = None,
-        dropout_rng: PRNGKey = None,
-        train: bool = False,
-    ):
-        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
-            input_ids, attention_mask, token_type_ids, position_ids
-        )
-
-        # Handle any PRNG if needed
-        rngs = {}
-        if dropout_rng is not None:
-            rngs["dropout"] = dropout_rng
-
-        return self.module.apply(
-            {"params": params or self.params},
-            jnp.array(input_ids, dtype="i4"),
-            jnp.array(attention_mask, dtype="i4"),
-            jnp.array(token_type_ids, dtype="i4"),
-            jnp.array(position_ids, dtype="i4"),
-            not train,
-            rngs=rngs,
-        )
+@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
+class FlaxBertForMaskedLM(FlaxBertPreTrainedModel):
+    module_class = FlaxBertForMaskedLMModule
 
 
 class FlaxBertForNextSentencePredictionModule(nn.Module):
@@ -764,48 +644,11 @@ class FlaxBertForNextSentencePredictionModule(nn.Module):
 
 @add_start_docstrings(
-    """
-    Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
-    output) e.g. for GLUE tasks.
-    """,
+    """Bert Model with a `next sentence prediction (classification)` head on top. """,
     BERT_START_DOCSTRING,
 )
-class FlaxBertForSequenceClassification(FlaxBertPreTrainedModel):
-    def __init__(
-        self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
-    ):
-        module = FlaxBertForSequenceClassificationModule(config, **kwargs)
-
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
-
-    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
-    def __call__(
-        self,
-        input_ids,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        params: dict = None,
-        dropout_rng: PRNGKey = None,
-        train: bool = False,
-    ):
-        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
-            input_ids, attention_mask, token_type_ids, position_ids
-        )
-
-        # Handle any PRNG if needed
-        rngs = {}
-        if dropout_rng is not None:
-            rngs["dropout"] = dropout_rng
-
-        return self.module.apply(
-            {"params": params or self.params},
-            jnp.array(input_ids, dtype="i4"),
-            jnp.array(attention_mask, dtype="i4"),
-            jnp.array(token_type_ids, dtype="i4"),
-            jnp.array(position_ids, dtype="i4"),
-            not train,
-            rngs=rngs,
-        )
+class FlaxBertForNextSentencePrediction(FlaxBertPreTrainedModel):
+    module_class = FlaxBertForNextSentencePredictionModule
 
 
 class FlaxBertForSequenceClassificationModule(nn.Module):
@@ -836,47 +679,13 @@ class FlaxBertForSequenceClassificationModule(nn.Module):
 
 @add_start_docstrings(
     """
-    Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
-    softmax) e.g. for RocStories/SWAG tasks.
+    Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
+    output) e.g. for GLUE tasks.
     """,
     BERT_START_DOCSTRING,
 )
-class FlaxBertForMultipleChoice(FlaxBertPreTrainedModel):
-    def __init__(
-        self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
-    ):
-        module = FlaxBertForMultipleChoiceModule(config, **kwargs)
-
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
-
-    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
-    def __call__(
-        self,
-        input_ids,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        params: dict = None,
-        dropout_rng: PRNGKey = None,
-        train: bool = False,
-    ):
-        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
-            input_ids, attention_mask, token_type_ids, position_ids
-        )
-
-        # Handle any PRNG if needed
-        rngs = {}
-        if dropout_rng is not None:
-            rngs["dropout"] = dropout_rng
-
-        return self.module.apply(
-            {"params": params or self.params},
-            jnp.array(input_ids, dtype="i4"),
-            jnp.array(attention_mask, dtype="i4"),
-            jnp.array(token_type_ids, dtype="i4"),
-            jnp.array(position_ids, dtype="i4"),
-            not train,
-            rngs=rngs,
-        )
+class FlaxBertForSequenceClassification(FlaxBertPreTrainedModel):
+    module_class = FlaxBertForSequenceClassificationModule
 
 
 class FlaxBertForMultipleChoiceModule(nn.Module):
@@ -912,47 +721,19 @@ class FlaxBertForMultipleChoiceModule(nn.Module):
 
 @add_start_docstrings(
     """
-    Bert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
-    Named-Entity-Recognition (NER) tasks.
+    Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
+    softmax) e.g. for RocStories/SWAG tasks.
     """,
     BERT_START_DOCSTRING,
 )
-class FlaxBertForTokenClassification(FlaxBertPreTrainedModel):
-    def __init__(
-        self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
-    ):
-        module = FlaxBertForTokenClassificationModule(config, **kwargs)
-
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
-
-    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
-    def __call__(
-        self,
-        input_ids,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        params: dict = None,
-        dropout_rng: PRNGKey = None,
-        train: bool = False,
-    ):
-        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
-            input_ids, attention_mask, token_type_ids, position_ids
-        )
-
-        # Handle any PRNG if needed
-        rngs = {}
-        if dropout_rng is not None:
-            rngs["dropout"] = dropout_rng
-
-        return self.module.apply(
-            {"params": params or self.params},
-            jnp.array(input_ids, dtype="i4"),
-            jnp.array(attention_mask, dtype="i4"),
-            jnp.array(token_type_ids, dtype="i4"),
-            jnp.array(position_ids, dtype="i4"),
-            not train,
-            rngs=rngs,
-        )
+class FlaxBertForMultipleChoice(FlaxBertPreTrainedModel):
+    module_class = FlaxBertForMultipleChoiceModule
+
+
+# adapt docstring slightly for FlaxBertForMultipleChoice
+overwrite_call_docstring(
+    FlaxBertForMultipleChoice, BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
+)
 
 
 class FlaxBertForTokenClassificationModule(nn.Module):
@@ -978,47 +759,13 @@ class FlaxBertForTokenClassificationModule(nn.Module):
 
 @add_start_docstrings(
     """
-    Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
-    layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
+    Bert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
+    Named-Entity-Recognition (NER) tasks.
     """,
     BERT_START_DOCSTRING,
 )
-class FlaxBertForQuestionAnswering(FlaxBertPreTrainedModel):
-    def __init__(
-        self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
-    ):
-        module = FlaxBertForQuestionAnsweringModule(config, **kwargs)
-
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
-
-    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
-    def __call__(
-        self,
-        input_ids,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        params: dict = None,
-        dropout_rng: PRNGKey = None,
-        train: bool = False,
-    ):
-        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
-            input_ids, attention_mask, token_type_ids, position_ids
-        )
-
-        # Handle any PRNG if needed
-        rngs = {}
-        if dropout_rng is not None:
-            rngs["dropout"] = dropout_rng
-
-        return self.module.apply(
-            {"params": params or self.params},
-            jnp.array(input_ids, dtype="i4"),
-            jnp.array(attention_mask, dtype="i4"),
-            jnp.array(token_type_ids, dtype="i4"),
-            jnp.array(position_ids, dtype="i4"),
-            not train,
-            rngs=rngs,
-        )
+class FlaxBertForTokenClassification(FlaxBertPreTrainedModel):
+    module_class = FlaxBertForTokenClassificationModule
 
 
 class FlaxBertForQuestionAnsweringModule(nn.Module):
@@ -1041,3 +788,14 @@ class FlaxBertForQuestionAnsweringModule(nn.Module):
         end_logits = end_logits.squeeze(-1)
 
         return (start_logits, end_logits)
+
+
+@add_start_docstrings(
+    """
+    Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
+    layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
+    """,
+    BERT_START_DOCSTRING,
+)
+class FlaxBertForQuestionAnswering(FlaxBertPreTrainedModel):
+    module_class = FlaxBertForQuestionAnsweringModule
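The pattern behind the BERT changes above, in a framework-free sketch (illustrative names only, not transformers code): the pretrained base class implements `__init__` and `__call__` once and builds whatever module its subclass points at, so each head class shrinks to a single `module_class = ...` line instead of repeating the constructor and forward-pass boilerplate.

```python
class BasePreTrainedSketch:
    module_class = None  # each subclass points this at its module class

    def __init__(self, config, **kwargs):
        # one shared constructor instead of a copy per head class
        self.module = self.module_class(config=config, **kwargs)

    def __call__(self, input_ids):
        # one shared forward pass delegating to the concrete module
        return self.module.apply(input_ids)


class SequenceClassificationModuleSketch:
    def __init__(self, config, **kwargs):
        self.config = config

    def apply(self, input_ids):
        return ("logits for", input_ids)


class ForSequenceClassificationSketch(BasePreTrainedSketch):
    # the whole head class after the refactor
    module_class = SequenceClassificationModuleSketch


model = ForSequenceClassificationSketch(config={"num_labels": 2})
print(model([101, 2023, 102]))  # ('logits for', [101, 2023, 102])
```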
@@ -441,40 +441,7 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
     config_class = RobertaConfig
     base_model_prefix = "roberta"
+    module_class: nn.Module = None
 
-    def init(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
-        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
-            jnp.zeros(input_shape, dtype="i4"), None, None, None
-        )
-
-        params_rng, dropout_rng = jax.random.split(rng)
-        rngs = {"params": params_rng, "dropout": dropout_rng}
-
-        return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids)["params"]
-
-    def _check_inputs(self, input_ids, attention_mask, token_type_ids, position_ids):
-        if token_type_ids is None:
-            token_type_ids = jnp.ones_like(input_ids)
-
-        if position_ids is None:
-            position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id)
-
-        if attention_mask is None:
-            attention_mask = jnp.ones_like(input_ids)
-
-        return input_ids, attention_mask, token_type_ids, position_ids
-
-
-@add_start_docstrings(
-    "The bare RoBERTa Model transformer outputting raw hidden-states without any specific head on top.",
-    ROBERTA_START_DOCSTRING,
-)
-class FlaxRobertaModel(FlaxRobertaPreTrainedModel):
-    """
-    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
-    cross-attention is added between the self-attention layers, following the architecture described in `Attention is
-    all you need`_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz
-    Kaiser and Illia Polosukhin.
-    """
 
     def __init__(
         self,
@@ -484,23 +451,41 @@ class FlaxRobertaModel(FlaxRobertaPreTrainedModel):
         dtype: jnp.dtype = jnp.float32,
         **kwargs
     ):
-        module = FlaxRobertaModule(config, dtype=dtype, **kwargs)
+        module = self.module_class(config=config, dtype=dtype, **kwargs)
         super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
 
+    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
+        # init input tensors
+        input_ids = jnp.zeros(input_shape, dtype="i4")
+        token_type_ids = jnp.ones_like(input_ids)
+        position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id)
+        attention_mask = jnp.ones_like(input_ids)
+
+        params_rng, dropout_rng = jax.random.split(rng)
+        rngs = {"params": params_rng, "dropout": dropout_rng}
+
+        return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids)["params"]
+
     @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     def __call__(
         self,
         input_ids,
-        token_type_ids=None,
         attention_mask=None,
+        token_type_ids=None,
         position_ids=None,
         params: dict = None,
         dropout_rng: PRNGKey = None,
         train: bool = False,
     ):
-        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
-            input_ids, attention_mask, token_type_ids, position_ids
-        )
+        # init input tensors if not passed
+        if token_type_ids is None:
+            token_type_ids = jnp.ones_like(input_ids)
+
+        if position_ids is None:
+            position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id)
+
+        if attention_mask is None:
+            attention_mask = jnp.ones_like(input_ids)
 
         # Handle any PRNG if needed
         rngs = {}
@@ -541,3 +526,11 @@ class FlaxRobertaModule(nn.Module):
         pooled = self.pooler(hidden_states)
 
         return hidden_states, pooled
+
+
+@add_start_docstrings(
+    "The bare RoBERTa Model transformer outputting raw hidden-states without any specific head on top.",
+    ROBERTA_START_DOCSTRING,
+)
+class FlaxRobertaModel(FlaxRobertaPreTrainedModel):
+    module_class = FlaxRobertaModule