Unverified commit e87505f3 authored by Patrick von Platen, committed by GitHub

[Flax] Add other BERT classes (#10977)

* add first code structures

* add all bert models

* add to init and docs

* correct docs

* make style
parent e031162a
@@ -209,8 +209,50 @@ FlaxBertModel
:members: __call__
FlaxBertForPreTraining
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxBertForPreTraining
:members: __call__
FlaxBertForMaskedLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxBertForMaskedLM
:members: __call__
FlaxBertForNextSentencePrediction
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxBertForNextSentencePrediction
:members: __call__
FlaxBertForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxBertForSequenceClassification
:members: __call__
FlaxBertForMultipleChoice
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxBertForMultipleChoice
:members: __call__
FlaxBertForTokenClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxBertForTokenClassification
:members: __call__
FlaxBertForQuestionAnswering
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxBertForQuestionAnswering
:members: __call__
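
For reference, a minimal usage sketch of one of the classes documented above (a sketch only: it assumes the bert-base-uncased checkpoint and that Flax weights are available for it, otherwise pass from_pt=True to convert the PyTorch weights):

import jax.numpy as jnp
from transformers import BertTokenizerFast, FlaxBertForMaskedLM

# Load the tokenizer and the Flax masked-LM model.
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
model = FlaxBertForMaskedLM.from_pretrained("bert-base-uncased")

# Encode a sentence containing a [MASK] token and run a forward pass.
inputs = tokenizer("Paris is the [MASK] of France.", return_tensors="jax")
outputs = model(**inputs)
logits = outputs[0]  # (batch_size, sequence_length, vocab_size)

# Greedy prediction for every position, including the masked one.
predicted_ids = jnp.argmax(logits, axis=-1)
print(tokenizer.decode(predicted_ids[0].tolist()))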
@@ -1290,7 +1290,19 @@ else:
if is_flax_available():
_import_structure["modeling_flax_utils"] = ["FlaxPreTrainedModel"]
_import_structure["models.auto"].extend(["FLAX_MODEL_MAPPING", "FlaxAutoModel"])
_import_structure["models.bert"].extend(["FlaxBertForMaskedLM", "FlaxBertModel"]) _import_structure["models.bert"].extend(
[
"FlaxBertForMaskedLM",
"FlaxBertForMultipleChoice",
"FlaxBertForNextSentencePrediction",
"FlaxBertForPreTraining",
"FlaxBertForQuestionAnswering",
"FlaxBertForSequenceClassification",
"FlaxBertForTokenClassification",
"FlaxBertModel",
"FlaxBertPreTrainedModel",
]
)
_import_structure["models.roberta"].append("FlaxRobertaModel") _import_structure["models.roberta"].append("FlaxRobertaModel")
else: else:
from .utils import dummy_flax_objects from .utils import dummy_flax_objects
...@@ -2372,7 +2384,17 @@ if TYPE_CHECKING: ...@@ -2372,7 +2384,17 @@ if TYPE_CHECKING:
if is_flax_available(): if is_flax_available():
from .modeling_flax_utils import FlaxPreTrainedModel from .modeling_flax_utils import FlaxPreTrainedModel
from .models.auto import FLAX_MODEL_MAPPING, FlaxAutoModel from .models.auto import FLAX_MODEL_MAPPING, FlaxAutoModel
from .models.bert import (
FlaxBertForMaskedLM,
FlaxBertForMultipleChoice,
FlaxBertForNextSentencePrediction,
FlaxBertForPreTraining,
FlaxBertForQuestionAnswering,
FlaxBertForSequenceClassification,
FlaxBertForTokenClassification,
FlaxBertModel,
FlaxBertPreTrainedModel,
)
from .models.roberta import FlaxRobertaModel
else:
# Import the same objects as dummies to get them in the namespace.
...
@@ -70,8 +70,17 @@ if is_tf_available():
]
if is_flax_available():
_import_structure["modeling_flax_bert"] = ["FlaxBertForMaskedLM", "FlaxBertModel"] _import_structure["modeling_flax_bert"] = [
"FlaxBertForMaskedLM",
"FlaxBertForMultipleChoice",
"FlaxBertForNextSentencePrediction",
"FlaxBertForPreTraining",
"FlaxBertForQuestionAnswering",
"FlaxBertForSequenceClassification",
"FlaxBertForTokenClassification",
"FlaxBertModel",
"FlaxBertPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
@@ -115,7 +124,17 @@ if TYPE_CHECKING:
)
if is_flax_available():
from .modeling_flax_bert import (
FlaxBertForMaskedLM,
FlaxBertForMultipleChoice,
FlaxBertForNextSentencePrediction,
FlaxBertForPreTraining,
FlaxBertForQuestionAnswering,
FlaxBertForSequenceClassification,
FlaxBertForTokenClassification,
FlaxBertModel,
FlaxBertPreTrainedModel,
)
else:
import importlib
...
@@ -445,6 +445,30 @@ class FlaxBertOnlyMLMHead(nn.Module):
return hidden_states
class FlaxBertOnlyNSPHead(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
self.seq_relationship = nn.Dense(2, dtype=self.dtype)
def __call__(self, pooled_output):
return self.seq_relationship(pooled_output)
class FlaxBertPreTrainingHeads(nn.Module):
config: BertConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.predictions = FlaxBertLMPredictionHead(self.config, dtype=self.dtype)
self.seq_relationship = nn.Dense(2, dtype=self.dtype)
def __call__(self, hidden_states, pooled_output):
prediction_scores = self.predictions(hidden_states)
seq_relationship_score = self.seq_relationship(pooled_output)
return prediction_scores, seq_relationship_score
class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
@@ -551,6 +575,73 @@ class FlaxBertModule(nn.Module):
return hidden_states, pooled
@add_start_docstrings(
"""
Bert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
sentence prediction (classification)` head.
""",
BERT_START_DOCSTRING,
)
class FlaxBertForPreTraining(FlaxBertPreTrainedModel):
def __init__(
self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
):
module = FlaxBertForPreTrainingModule(config, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
def __call__(
self,
input_ids,
attention_mask=None,
token_type_ids=None,
position_ids=None,
params: dict = None,
dropout_rng: PRNGKey = None,
train: bool = False,
):
input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
input_ids, attention_mask, token_type_ids, position_ids
)
# Handle any PRNG if needed
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
return self.module.apply(
{"params": params or self.params},
jnp.array(input_ids, dtype="i4"),
jnp.array(attention_mask, dtype="i4"),
jnp.array(token_type_ids, dtype="i4"),
jnp.array(position_ids, dtype="i4"),
not train,
rngs=rngs,
)
class FlaxBertForPreTrainingModule(nn.Module):
config: BertConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.bert = FlaxBertModule(config=self.config, dtype=self.dtype)
self.cls = FlaxBertPreTrainingHeads(config=self.config, dtype=self.dtype)
def __call__(
self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True
):
# Model
hidden_states, pooled_output = self.bert(
input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic
)
prediction_scores, seq_relationship_score = self.cls(hidden_states, pooled_output)
return (prediction_scores, seq_relationship_score)
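
For orientation, a small sketch of what the two pretraining heads return (a sketch under the same checkpoint assumptions as above; in this implementation the output is a plain tuple):

from transformers import BertTokenizerFast, FlaxBertForPreTraining

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
model = FlaxBertForPreTraining.from_pretrained("bert-base-uncased")

inputs = tokenizer("The cat sat on the mat.", return_tensors="jax")
prediction_scores, seq_relationship_score = model(**inputs)

# MLM head: a score over the vocabulary at every position -> (batch, seq_len, vocab_size)
print(prediction_scores.shape)
# NSP head: a two-way score over the pooled [CLS] state -> (batch, 2)
print(seq_relationship_score.shape)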
@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
class FlaxBertForMaskedLM(FlaxBertPreTrainedModel):
def __init__(
self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
@@ -559,6 +650,7 @@ class FlaxBertForMaskedLM(FlaxBertPreTrainedModel):
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
def __call__(
self,
input_ids,
@@ -594,24 +686,358 @@ class FlaxBertForMaskedLMModule(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
self.bert = FlaxBertModule(config=self.config, add_pooling_layer=False, dtype=self.dtype)
self.cls = FlaxBertOnlyMLMHead(config=self.config, dtype=self.dtype)
def __call__(
self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True
):
# Model
hidden_states = self.bert(input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic)
# Compute the prediction scores
logits = self.cls(hidden_states)
return (logits,)
@add_start_docstrings(
"""Bert Model with a `next sentence prediction (classification)` head on top. """,
BERT_START_DOCSTRING,
)
class FlaxBertForNextSentencePrediction(FlaxBertPreTrainedModel):
def __init__(
self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
):
module = FlaxBertForNextSentencePredictionModule(config, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
def __call__(
self,
input_ids,
attention_mask=None,
token_type_ids=None,
position_ids=None,
params: dict = None,
dropout_rng: PRNGKey = None,
train: bool = False,
):
input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
input_ids, attention_mask, token_type_ids, position_ids
)
# Handle any PRNG if needed
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
return self.module.apply(
{"params": params or self.params},
jnp.array(input_ids, dtype="i4"),
jnp.array(attention_mask, dtype="i4"),
jnp.array(token_type_ids, dtype="i4"),
jnp.array(position_ids, dtype="i4"),
not train,
rngs=rngs,
)
class FlaxBertForNextSentencePredictionModule(nn.Module):
config: BertConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.bert = FlaxBertModule(config=self.config, dtype=self.dtype)
self.cls = FlaxBertOnlyNSPHead(dtype=self.dtype)
def __call__(
self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True
):
# Model
_, pooled_output = self.bert(
input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic
)
seq_relationship_scores = self.cls(pooled_output)
return (seq_relationship_scores,)
@add_start_docstrings(
"""
Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
output) e.g. for GLUE tasks.
""",
BERT_START_DOCSTRING,
)
class FlaxBertForSequenceClassification(FlaxBertPreTrainedModel):
def __init__(
self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
):
module = FlaxBertForSequenceClassificationModule(config, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
def __call__(
self,
input_ids,
attention_mask=None,
token_type_ids=None,
position_ids=None,
params: dict = None,
dropout_rng: PRNGKey = None,
train: bool = False,
):
input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
input_ids, attention_mask, token_type_ids, position_ids
)
# Handle any PRNG if needed
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
return self.module.apply(
{"params": params or self.params},
jnp.array(input_ids, dtype="i4"),
jnp.array(attention_mask, dtype="i4"),
jnp.array(token_type_ids, dtype="i4"),
jnp.array(position_ids, dtype="i4"),
not train,
rngs=rngs,
)
class FlaxBertForSequenceClassificationModule(nn.Module):
config: BertConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.bert = FlaxBertModule(config=self.config, dtype=self.dtype)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
self.classifier = nn.Dense(
self.config.num_labels,
dtype=self.dtype,
)
def __call__(
self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True
):
# Model
_, pooled_output = self.bert(
input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic
)
pooled_output = self.dropout(pooled_output, deterministic=deterministic)
logits = self.classifier(pooled_output)
return (logits,)
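
A quick illustration of the sequence-classification head (a sketch only: with bert-base-uncased the classifier layer is freshly initialized, so the logits are meaningless until the model is fine-tuned; the point is the call API and output shape):

import jax
from transformers import BertTokenizerFast, FlaxBertForSequenceClassification

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
model = FlaxBertForSequenceClassification.from_pretrained("bert-base-uncased")

inputs = tokenizer("A thoroughly enjoyable read.", return_tensors="jax")
(logits,) = model(**inputs)              # (batch_size, config.num_labels)
probs = jax.nn.softmax(logits, axis=-1)  # probabilities over the labels
print(probs)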
@add_start_docstrings(
"""
Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
softmax) e.g. for RocStories/SWAG tasks.
""",
BERT_START_DOCSTRING,
)
class FlaxBertForMultipleChoice(FlaxBertPreTrainedModel):
def __init__(
self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
):
module = FlaxBertForMultipleChoiceModule(config, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
def __call__(
self,
input_ids,
attention_mask=None,
token_type_ids=None,
position_ids=None,
params: dict = None,
dropout_rng: PRNGKey = None,
train: bool = False,
):
input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
input_ids, attention_mask, token_type_ids, position_ids
)
# Handle any PRNG if needed
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
return self.module.apply(
{"params": params or self.params},
jnp.array(input_ids, dtype="i4"),
jnp.array(attention_mask, dtype="i4"),
jnp.array(token_type_ids, dtype="i4"),
jnp.array(position_ids, dtype="i4"),
not train,
rngs=rngs,
)
class FlaxBertForMultipleChoiceModule(nn.Module):
config: BertConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.bert = FlaxBertModule(config=self.config, dtype=self.dtype)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
self.classifier = nn.Dense(1, dtype=self.dtype)
def __call__(
self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True
):
num_choices = input_ids.shape[1]
input_ids = input_ids.reshape(-1, input_ids.shape[-1]) if input_ids is not None else None
attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1]) if attention_mask is not None else None
token_type_ids = token_type_ids.reshape(-1, token_type_ids.shape[-1]) if token_type_ids is not None else None
position_ids = position_ids.reshape(-1, position_ids.shape[-1]) if position_ids is not None else None
# Model
_, pooled_output = self.bert(
input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic
)
pooled_output = self.dropout(pooled_output, deterministic=deterministic)
logits = self.classifier(pooled_output)
reshaped_logits = logits.reshape(-1, num_choices)
return (reshaped_logits,)
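
The multiple-choice module flattens the (batch, num_choices, seq_len) inputs before the encoder and folds the per-choice scores back at the end; a shape-only sketch with made-up sizes:

import jax.numpy as jnp

batch_size, num_choices, seq_len = 2, 4, 16

# Inputs arrive with an extra choices axis ...
input_ids = jnp.zeros((batch_size, num_choices, seq_len), dtype="i4")
# ... and are flattened to (batch * num_choices, seq_len) for the encoder.
flat_ids = input_ids.reshape(-1, input_ids.shape[-1])

# The Dense(1) classifier yields one score per flattened row ...
logits = jnp.zeros((batch_size * num_choices, 1))
# ... which is reshaped to (batch, num_choices) so a softmax can rank the choices.
reshaped_logits = logits.reshape(-1, num_choices)

print(flat_ids.shape, reshaped_logits.shape)  # (8, 16) (2, 4)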
@add_start_docstrings(
"""
Bert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
Named-Entity-Recognition (NER) tasks.
""",
BERT_START_DOCSTRING,
)
class FlaxBertForTokenClassification(FlaxBertPreTrainedModel):
def __init__(
self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
):
module = FlaxBertForTokenClassificationModule(config, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
def __call__(
self,
input_ids,
attention_mask=None,
token_type_ids=None,
position_ids=None,
params: dict = None,
dropout_rng: PRNGKey = None,
train: bool = False,
):
input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
input_ids, attention_mask, token_type_ids, position_ids
)
# Handle any PRNG if needed
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
return self.module.apply(
{"params": params or self.params},
jnp.array(input_ids, dtype="i4"),
jnp.array(attention_mask, dtype="i4"),
jnp.array(token_type_ids, dtype="i4"),
jnp.array(position_ids, dtype="i4"),
not train,
rngs=rngs,
)
class FlaxBertForTokenClassificationModule(nn.Module):
config: BertConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.bert = FlaxBertModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)
def __call__(
self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True
):
# Model
hidden_states = self.bert(input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic)
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
logits = self.classifier(hidden_states)
return (logits,)
@add_start_docstrings(
"""
Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
""",
BERT_START_DOCSTRING,
)
class FlaxBertForQuestionAnswering(FlaxBertPreTrainedModel):
def __init__(
self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
):
module = FlaxBertForQuestionAnsweringModule(config, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
def __call__(
self,
input_ids,
attention_mask=None,
token_type_ids=None,
position_ids=None,
params: dict = None,
dropout_rng: PRNGKey = None,
train: bool = False,
):
input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
input_ids, attention_mask, token_type_ids, position_ids
)
# Handle any PRNG if needed
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
return self.module.apply(
{"params": params or self.params},
jnp.array(input_ids, dtype="i4"),
jnp.array(attention_mask, dtype="i4"),
jnp.array(token_type_ids, dtype="i4"),
jnp.array(position_ids, dtype="i4"),
not train,
rngs=rngs,
)
class FlaxBertForQuestionAnsweringModule(nn.Module):
config: BertConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.bert = FlaxBertModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype)
def __call__(
self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True
):
# Model
hidden_states = self.bert(input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic)
logits = self.qa_outputs(hidden_states)
start_logits, end_logits = logits.split(self.config.num_labels, axis=-1)
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)
return (start_logits, end_logits)
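
The question-answering head produces two logits per token and splits them into span-start and span-end scores; a minimal shape sketch of that split (made-up sizes, num_labels is 2 for extractive QA):

import jax.numpy as jnp

batch_size, seq_len, num_labels = 2, 16, 2

# Stand-in for the output of the Dense(num_labels) qa_outputs layer.
logits = jnp.zeros((batch_size, seq_len, num_labels))

# Split the last axis into two equal parts and drop the trailing dimension.
start_logits, end_logits = jnp.split(logits, num_labels, axis=-1)
start_logits = start_logits.squeeze(-1)  # (batch, seq_len)
end_logits = end_logits.squeeze(-1)      # (batch, seq_len)

# The predicted answer span is then the argmax of each score vector.
print(start_logits.shape, end_logits.shape)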
@@ -32,6 +32,52 @@ class FlaxBertForMaskedLM:
requires_flax(self)
class FlaxBertForMultipleChoice:
def __init__(self, *args, **kwargs):
requires_flax(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_flax(self)
class FlaxBertForNextSentencePrediction:
def __init__(self, *args, **kwargs):
requires_flax(self)
class FlaxBertForPreTraining:
def __init__(self, *args, **kwargs):
requires_flax(self)
class FlaxBertForQuestionAnswering:
def __init__(self, *args, **kwargs):
requires_flax(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_flax(self)
class FlaxBertForSequenceClassification:
def __init__(self, *args, **kwargs):
requires_flax(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_flax(self)
class FlaxBertForTokenClassification:
def __init__(self, *args, **kwargs):
requires_flax(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_flax(self)
class FlaxBertModel:
def __init__(self, *args, **kwargs):
requires_flax(self)
@@ -41,6 +87,15 @@ class FlaxBertModel:
requires_flax(self)
class FlaxBertPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_flax(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_flax(self)
class FlaxRobertaModel:
def __init__(self, *args, **kwargs):
requires_flax(self)
...
@@ -23,7 +23,15 @@ from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_
if is_flax_available():
from transformers.models.bert.modeling_flax_bert import (
FlaxBertForMaskedLM,
FlaxBertForMultipleChoice,
FlaxBertForNextSentencePrediction,
FlaxBertForPreTraining,
FlaxBertForQuestionAnswering,
FlaxBertForSequenceClassification,
FlaxBertForTokenClassification,
FlaxBertModel,
)
class FlaxBertModelTester(unittest.TestCase):
@@ -48,6 +56,7 @@ class FlaxBertModelTester(unittest.TestCase):
type_vocab_size=16,
type_sequence_label_size=2,
initializer_range=0.02,
num_choices=4,
):
self.parent = parent
self.batch_size = batch_size
@@ -68,6 +77,7 @@ class FlaxBertModelTester(unittest.TestCase):
self.type_vocab_size = type_vocab_size
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.num_choices = num_choices
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
@@ -107,7 +117,20 @@ class FlaxBertModelTester(unittest.TestCase):
@require_flax
class FlaxBertModelTest(FlaxModelTesterMixin, unittest.TestCase):
all_model_classes = (
(
FlaxBertModel,
FlaxBertForPreTraining,
FlaxBertForMaskedLM,
FlaxBertForMultipleChoice,
FlaxBertForQuestionAnswering,
FlaxBertForNextSentencePrediction,
FlaxBertForTokenClassification,
FlaxBertForSequenceClassification,
)
if is_flax_available()
else ()
)
def setUp(self):
self.model_tester = FlaxBertModelTester(self)
...
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import random
import tempfile
@@ -65,6 +66,18 @@ class FlaxModelTesterMixin:
model_tester = None
all_model_classes = ()
def _prepare_for_class(self, inputs_dict, model_class):
inputs_dict = copy.deepcopy(inputs_dict)
# hack for now until we have AutoModel classes
if "ForMultipleChoice" in model_class.__name__:
inputs_dict = {
k: jnp.broadcast_to(v[:, None], (v.shape[0], self.model_tester.num_choices, v.shape[-1]))
for k, v in inputs_dict.items()
}
return inputs_dict
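
The broadcast in _prepare_for_class simply repeats each 2-D input along a new choices axis so the multiple-choice model receives (batch, num_choices, seq_len) tensors; a tiny standalone sketch with illustrative sizes:

import jax.numpy as jnp

num_choices = 4
input_ids = jnp.arange(6).reshape(2, 3)  # (batch=2, seq_len=3)

# Insert a choices axis and repeat the same sequence for every choice.
expanded = jnp.broadcast_to(input_ids[:, None], (2, num_choices, 3))
print(expanded.shape)  # (2, 4, 3)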
def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
diff = np.abs((a - b)).max()
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
@@ -75,6 +88,7 @@ class FlaxModelTesterMixin:
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
pt_model_class_name = model_class.__name__[4:]  # Skip the "Flax" at the beginning
pt_model_class = getattr(transformers, pt_model_class_name)
pt_model = pt_model_class(config).eval()
@@ -83,12 +97,12 @@ class FlaxModelTesterMixin:
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
fx_model.params = fx_state
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
fx_outputs = fx_model(**prepared_inputs_dict)
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output, pt_output.numpy(), 2e-3)
@@ -97,7 +111,7 @@ class FlaxModelTesterMixin:
pt_model.save_pretrained(tmpdirname)
fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True)
fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict)
self.assertEqual(
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
)
@@ -111,13 +125,14 @@ class FlaxModelTesterMixin:
with self.subTest(model_class.__name__):
model = model_class(config)
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
outputs = model(**prepared_inputs_dict)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_loaded = model_class.from_pretrained(tmpdirname)
outputs_loaded = model_loaded(**prepared_inputs_dict)
for output_loaded, output in zip(outputs_loaded, outputs):
self.assert_almost_equals(output_loaded, output, 5e-3)
@@ -126,6 +141,7 @@ class FlaxModelTesterMixin:
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config)
@jax.jit
@@ -134,10 +150,10 @@ class FlaxModelTesterMixin:
with self.subTest("JIT Disabled"):
with jax.disable_jit():
outputs = model_jitted(**prepared_inputs_dict)
with self.subTest("JIT Enabled"):
jitted_outputs = model_jitted(**prepared_inputs_dict)
self.assertEqual(len(outputs), len(jitted_outputs))
for jitted_output, output in zip(jitted_outputs, outputs):
...