"...git@developer.sourcefind.cn:wqshmzh/ktransformers.git" did not exist on "877aec858e3885c956955b9aade3d38bd0ee6214"
Unverified commit 98e409ab, authored by Kamal Raj and committed by GitHub

albert flax (#13294)

* albert flax

* year -> 2021

* docstring updated for flax

* removed head_mask

* removed from_pt

* removed passing attention_mask to embedding layer
parent ee5b2457
@@ -321,7 +321,7 @@ Flax), PyTorch, and/or TensorFlow.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Model                       | Tokenizer slow | Tokenizer fast | PyTorch support | TensorFlow support | Flax Support |
+=============================+================+================+=================+====================+==============+
| ALBERT                      | ✅             | ✅             | ✅              | ✅                 | ✅           |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| BART                        | ✅             | ✅             | ✅              | ✅                 | ✅           |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
...
@@ -43,7 +43,8 @@ Tips:
similar to a BERT-like architecture with the same number of hidden layers as it has to iterate through the same
number of (repeating) layers.

This model was contributed by `lysandre <https://huggingface.co/lysandre>`__. The jax version of this model was
contributed by `kamalkraj <https://huggingface.co/kamalkraj>`__. The original code can be found `here
<https://github.com/google-research/ALBERT>`__.
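To make the tip above concrete, here is a small illustrative sketch (it assumes a working jax/flax install and uses toy config values, neither of which is part of this change): ALBERT's parameter count does not grow with depth because the layers share one parameter group, yet inference still iterates over every layer.

import jax
from transformers import AlbertConfig, FlaxAlbertModel

def n_params(config):
    # Random initialization is enough for counting weights.
    model = FlaxAlbertModel(config)
    return sum(leaf.size for leaf in jax.tree_util.tree_leaves(model.params))

# Two toy configs that differ only in depth (values are illustrative).
shallow = AlbertConfig(hidden_size=128, intermediate_size=256, num_attention_heads=4, num_hidden_layers=4)
deep = AlbertConfig(hidden_size=128, intermediate_size=256, num_attention_heads=4, num_hidden_layers=24)

# With the default single parameter group, both counts match, but the deeper
# model still executes 24 sequential layer passes at inference time.
print(n_params(shallow), n_params(deep))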
AlbertConfig
@@ -174,3 +175,52 @@ TFAlbertForQuestionAnswering
.. autoclass:: transformers.TFAlbertForQuestionAnswering
:members: call
FlaxAlbertModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxAlbertModel
:members: __call__
FlaxAlbertForPreTraining
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxAlbertForPreTraining
:members: __call__
FlaxAlbertForMaskedLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxAlbertForMaskedLM
:members: __call__
FlaxAlbertForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxAlbertForSequenceClassification
:members: __call__
FlaxAlbertForMultipleChoice
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxAlbertForMultipleChoice
:members: __call__
FlaxAlbertForTokenClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxAlbertForTokenClassification
:members: __call__
FlaxAlbertForQuestionAnswering
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxAlbertForQuestionAnswering
:members: __call__
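For orientation, a minimal usage sketch of the classes documented above, assuming a jax/flax (and sentencepiece) environment; albert-base-v2 is the checkpoint the new tests load:

from transformers import AlbertTokenizer, FlaxAlbertModel

tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
model = FlaxAlbertModel.from_pretrained("albert-base-v2")

# Flax models consume NumPy/JAX arrays rather than torch tensors.
inputs = tokenizer("Hello, ALBERT in Flax!", return_tensors="np")
outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # (batch_size, sequence_length, hidden_size)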
@@ -1642,6 +1642,18 @@ if is_flax_available():
"FlaxTopPLogitsWarper",
]
_import_structure["modeling_flax_utils"] = ["FlaxPreTrainedModel"]
_import_structure["models.albert"].extend(
[
"FlaxAlbertForMaskedLM",
"FlaxAlbertForMultipleChoice",
"FlaxAlbertForPreTraining",
"FlaxAlbertForQuestionAnswering",
"FlaxAlbertForSequenceClassification",
"FlaxAlbertForTokenClassification",
"FlaxAlbertModel",
"FlaxAlbertPreTrainedModel",
]
)
_import_structure["models.auto"].extend( _import_structure["models.auto"].extend(
[ [
"FLAX_MODEL_FOR_CAUSAL_LM_MAPPING", "FLAX_MODEL_FOR_CAUSAL_LM_MAPPING",
@@ -3152,6 +3164,16 @@ if TYPE_CHECKING:
FlaxTopPLogitsWarper,
)
from .modeling_flax_utils import FlaxPreTrainedModel
from .models.albert import (
FlaxAlbertForMaskedLM,
FlaxAlbertForMultipleChoice,
FlaxAlbertForPreTraining,
FlaxAlbertForQuestionAnswering,
FlaxAlbertForSequenceClassification,
FlaxAlbertForTokenClassification,
FlaxAlbertModel,
FlaxAlbertPreTrainedModel,
)
from .models.auto import (
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
...
@@ -20,6 +20,7 @@ from typing import TYPE_CHECKING
from ...file_utils import (
_LazyModule,
is_flax_available,
is_sentencepiece_available,
is_tf_available,
is_tokenizers_available,
@@ -65,6 +66,17 @@ if is_tf_available():
"TFAlbertPreTrainedModel",
]
if is_flax_available():
_import_structure["modeling_flax_albert"] = [
"FlaxAlbertForMaskedLM",
"FlaxAlbertForMultipleChoice",
"FlaxAlbertForPreTraining",
"FlaxAlbertForQuestionAnswering",
"FlaxAlbertForSequenceClassification",
"FlaxAlbertForTokenClassification",
"FlaxAlbertModel",
"FlaxAlbertPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig, AlbertOnnxConfig
@@ -103,6 +115,17 @@ if TYPE_CHECKING:
TFAlbertPreTrainedModel,
)
if is_flax_available():
from .modeling_flax_albert import (
FlaxAlbertForMaskedLM,
FlaxAlbertForMultipleChoice,
FlaxAlbertForPreTraining,
FlaxAlbertForQuestionAnswering,
FlaxAlbertForSequenceClassification,
FlaxAlbertForTokenClassification,
FlaxAlbertModel,
FlaxAlbertPreTrainedModel,
)
else:
import sys
...
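The is_flax_available() guard added above is the same pattern downstream code can use to degrade gracefully when JAX/Flax is missing; a small illustrative sketch (not part of the diff):

from transformers import is_flax_available

if is_flax_available():
    # Real Flax classes are exported; safe to build a JAX model.
    from transformers import FlaxAlbertForMaskedLM
    model = FlaxAlbertForMaskedLM.from_pretrained("albert-base-v2")
else:
    # Without flax, only dummy placeholders are exported (see the dummy objects further down).
    print("Flax/JAX not installed; use the PyTorch or TensorFlow ALBERT classes instead.")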
This diff is collapsed.
@@ -29,6 +29,7 @@ FLAX_MODEL_MAPPING_NAMES = OrderedDict(
[
# Base model mapping
("distilbert", "FlaxDistilBertModel"),
("albert", "FlaxAlbertModel"),
("roberta", "FlaxRobertaModel"), ("roberta", "FlaxRobertaModel"),
("bert", "FlaxBertModel"), ("bert", "FlaxBertModel"),
("big_bird", "FlaxBigBirdModel"), ("big_bird", "FlaxBigBirdModel"),
@@ -49,6 +50,7 @@ FLAX_MODEL_MAPPING_NAMES = OrderedDict(
FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
[
# Model for pre-training mapping
("albert", "FlaxAlbertForPreTraining"),
("roberta", "FlaxRobertaForMaskedLM"), ("roberta", "FlaxRobertaForMaskedLM"),
("bert", "FlaxBertForPreTraining"), ("bert", "FlaxBertForPreTraining"),
("big_bird", "FlaxBigBirdForPreTraining"), ("big_bird", "FlaxBigBirdForPreTraining"),
@@ -65,6 +67,7 @@ FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Masked LM mapping
("distilbert", "FlaxDistilBertForMaskedLM"),
("albert", "FlaxAlbertForMaskedLM"),
("roberta", "FlaxRobertaForMaskedLM"), ("roberta", "FlaxRobertaForMaskedLM"),
("bert", "FlaxBertForMaskedLM"), ("bert", "FlaxBertForMaskedLM"),
("big_bird", "FlaxBigBirdForMaskedLM"), ("big_bird", "FlaxBigBirdForMaskedLM"),
@@ -104,6 +107,7 @@ FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Sequence Classification mapping
("distilbert", "FlaxDistilBertForSequenceClassification"),
("albert", "FlaxAlbertForSequenceClassification"),
("roberta", "FlaxRobertaForSequenceClassification"), ("roberta", "FlaxRobertaForSequenceClassification"),
("bert", "FlaxBertForSequenceClassification"), ("bert", "FlaxBertForSequenceClassification"),
("big_bird", "FlaxBigBirdForSequenceClassification"), ("big_bird", "FlaxBigBirdForSequenceClassification"),
@@ -117,6 +121,7 @@ FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
[
# Model for Question Answering mapping
("distilbert", "FlaxDistilBertForQuestionAnswering"),
("albert", "FlaxAlbertForQuestionAnswering"),
("roberta", "FlaxRobertaForQuestionAnswering"), ("roberta", "FlaxRobertaForQuestionAnswering"),
("bert", "FlaxBertForQuestionAnswering"), ("bert", "FlaxBertForQuestionAnswering"),
("big_bird", "FlaxBigBirdForQuestionAnswering"), ("big_bird", "FlaxBigBirdForQuestionAnswering"),
@@ -130,6 +135,7 @@ FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Token Classification mapping
("distilbert", "FlaxDistilBertForTokenClassification"),
("albert", "FlaxAlbertForTokenClassification"),
("roberta", "FlaxRobertaForTokenClassification"), ("roberta", "FlaxRobertaForTokenClassification"),
("bert", "FlaxBertForTokenClassification"), ("bert", "FlaxBertForTokenClassification"),
("big_bird", "FlaxBigBirdForTokenClassification"), ("big_bird", "FlaxBigBirdForTokenClassification"),
@@ -141,6 +147,7 @@ FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
[
# Model for Multiple Choice mapping
("distilbert", "FlaxDistilBertForMultipleChoice"),
("albert", "FlaxAlbertForMultipleChoice"),
("roberta", "FlaxRobertaForMultipleChoice"), ("roberta", "FlaxRobertaForMultipleChoice"),
("bert", "FlaxBertForMultipleChoice"), ("bert", "FlaxBertForMultipleChoice"),
("big_bird", "FlaxBigBirdForMultipleChoice"), ("big_bird", "FlaxBigBirdForMultipleChoice"),
...
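With the new entries in these mapping tables, the Flax auto classes can resolve an ALBERT checkpoint to the right head. A hedged sketch (it assumes the existing FlaxAutoModelForMaskedLM auto class, which is fed by FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES above):

from transformers import FlaxAutoModelForMaskedLM

# The auto class reads model_type == "albert" from the checkpoint's config and,
# through the mapping registered above, instantiates FlaxAlbertForMaskedLM.
model = FlaxAutoModelForMaskedLM.from_pretrained("albert-base-v2")
print(type(model).__name__)  # expected: FlaxAlbertForMaskedLM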
@@ -76,6 +76,74 @@ class FlaxPreTrainedModel:
requires_backends(cls, ["flax"])
class FlaxAlbertForMaskedLM:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxAlbertForMultipleChoice:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxAlbertForPreTraining:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxAlbertForQuestionAnswering:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxAlbertForSequenceClassification:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxAlbertForTokenClassification:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxAlbertModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxAlbertPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING = None
...
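The dummy classes above keep "from transformers import FlaxAlbertModel" importable even when Flax is not installed; the failure is deferred to first use, roughly as in this illustrative sketch:

# Environment without jax/flax: the import below resolves to the dummy class.
from transformers import FlaxAlbertModel

try:
    FlaxAlbertModel()  # requires_backends(self, ["flax"]) raises here
except ImportError as err:
    print(err)  # the message points to the Flax installation instructions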
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from transformers import AlbertConfig, is_flax_available
from transformers.testing_utils import require_flax, slow
from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
if is_flax_available():
import jax.numpy as jnp
from transformers.models.albert.modeling_flax_albert import (
FlaxAlbertForMaskedLM,
FlaxAlbertForMultipleChoice,
FlaxAlbertForPreTraining,
FlaxAlbertForQuestionAnswering,
FlaxAlbertForSequenceClassification,
FlaxAlbertForTokenClassification,
FlaxAlbertModel,
)
class FlaxAlbertModelTester(unittest.TestCase):
def __init__(
self,
parent,
batch_size=13,
seq_length=7,
is_training=True,
use_attention_mask=True,
use_token_type_ids=True,
use_labels=True,
vocab_size=99,
hidden_size=32,
num_hidden_layers=5,
num_attention_heads=4,
intermediate_size=37,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
type_sequence_label_size=2,
initializer_range=0.02,
num_choices=4,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_attention_mask = use_attention_mask
self.use_token_type_ids = use_token_type_ids
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.num_choices = num_choices
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
attention_mask = None
if self.use_attention_mask:
attention_mask = random_attention_mask([self.batch_size, self.seq_length])
token_type_ids = None
if self.use_token_type_ids:
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
config = AlbertConfig(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
hidden_act=self.hidden_act,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings,
type_vocab_size=self.type_vocab_size,
is_decoder=False,
initializer_range=self.initializer_range,
)
return config, input_ids, token_type_ids, attention_mask
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, input_ids, token_type_ids, attention_mask = config_and_inputs
inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": attention_mask}
return config, inputs_dict
@require_flax
class FlaxAlbertModelTest(FlaxModelTesterMixin, unittest.TestCase):
all_model_classes = (
(
FlaxAlbertModel,
FlaxAlbertForPreTraining,
FlaxAlbertForMaskedLM,
FlaxAlbertForMultipleChoice,
FlaxAlbertForQuestionAnswering,
FlaxAlbertForSequenceClassification,
FlaxAlbertForTokenClassification,
)
if is_flax_available()
else ()
)
def setUp(self):
self.model_tester = FlaxAlbertModelTester(self)
@slow
def test_model_from_pretrained(self):
for model_class_name in self.all_model_classes:
model = model_class_name.from_pretrained("albert-base-v2")
outputs = model(np.ones((1, 1)))
self.assertIsNotNone(outputs)
@require_flax
class FlaxAlbertModelIntegrationTest(unittest.TestCase):
@slow
def test_inference_no_head_absolute_embedding(self):
model = FlaxAlbertModel.from_pretrained("albert-base-v2")
input_ids = np.array([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 2]])
attention_mask = np.array([[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
output = model(input_ids, attention_mask=attention_mask)[0]
expected_shape = (1, 11, 768)
self.assertEqual(output.shape, expected_shape)
expected_slice = np.array(
[[[-0.6513, 1.5035, -0.2766], [-0.6515, 1.5046, -0.2780], [-0.6512, 1.5049, -0.2784]]]
)
self.assertTrue(jnp.allclose(output[:, 1:4, 1:4], expected_slice, atol=1e-4))
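One possible way to run just this test module; the file path follows the repository's naming convention for Flax test files and is assumed here, since paths are not shown in this view. Setting RUN_SLOW opts in to the @slow cases.

import os
import pytest

os.environ["RUN_SLOW"] = "1"  # enable the @slow from_pretrained and integration tests
pytest.main(["-q", "tests/test_modeling_flax_albert.py"])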