Unverified Commit 8f8f8d99 authored by Funtowicz Morgan, committed by GitHub

Integrate Bert-like model on Flax runtime. (#3722)



* WIP flax bert

* Initial commit Bert Jax/Flax implementation.

* Embeddings working and equivalent to PyTorch.

* Move embeddings in its own module BertEmbeddings

* Added jax.jit annotation on forward call

* BertEncoder on par with PyTorch ! :D

* Add BertPooler on par with PyTorch !!

* Working Jax+Flax implementation of BertModel with < 1e-5 differences on the last layer.

* Fix pooled output to take only the first token of the sequence.

* Refactoring to use BertConfig from transformers.

* Renamed FXBertModel to FlaxBertModel

* Model is now initialized in FlaxBertModel constructor and reused.

* WIP JaxPreTrainedModel

* Cleaning up the code of FlaxBertModel

* Added ability to load Flax model saved through save_pretrained()

* Added ability to convert Pytorch Bert model to FlaxBert

* FlaxBert can now load every Pytorch Bert model with on-the-fly conversion

* Fix hardcoded shape values in conversion scripts.

* Improve the way we handle LayerNorm conversion from PyTorch to Flax.

* Added positional embeddings as parameter of BertModel with default to np.arange.

* Let's roll FlaxRoberta !

* Fix missing position_ids parameters on predict for Bert

* Flax backend now supports batched inputs
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Make it possible to load an msgpack-serialized model as a last resort when converting from PyTorch.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Moved save_pretrained to Jax base class along with more constructor parameters.

* Use specialized, model-dependent conversion function.

* Expose `is_flax_available` in file_utils.

* Added unittest for Flax models.

* Added run_tests_flax to the CI.

* Introduce FlaxAutoModel

* Added more unittests

* Flax models reference the _MODEL_ARCHIVE_MAP from the PyTorch models.

* Addressing review comments.

* Expose seed in both Bert and Roberta

* Fix typo suggested by @stefan-it
Co-Authored-By: Stefan Schweter <stefan@schweter.it>

* Attempt to make style

* Attempt to make style in tests too

* Added jax & jaxlib to the flax optional dependencies.

* Attempt to fix flake8 warnings ...

* Redo black again and again

* When black and flake8 fight each other for a space ... 💥 💥 💥

* Try removing trailing comma to make both black and flake happy!

* Fix invalid is_<framework>_available call, thanks @LysandreJik 🎉



* Fix another invalid import in flax_roberta test

* Bump and pin flax release to 0.1.0.

* Make flake8 happy, remove unused jax import

* Change the type of the catch for msgpack.

* Remove unused import.

* Put seed as optional constructor parameter.

* trigger ci again

* Fix too many parameters in BertAttention.

* Formatting.

* Simplify Flax unittests to avoid machine crashes.

* Fix invalid number of arguments when raising an issue for an unknown model.

* Address @bastings' comment in PR, moving the jax.jit-decorated function outside of __call__

* Fix incorrect path to require_flax/require_pytorch functions.
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Attempt to make style.
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Correct rebasing of circle-ci dependencies
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Fix import sorting.
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Fix unused imports.
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Again import sorting...
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Installing missing nlp dependency for flax unittests.
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Fix loading of model for Flax implementations.
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* jit the inner function call to make JAX-compatible
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Format!
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Flake one more time 🎶

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Rewrites BERT in Flax to the new Linen API (#7211)

* Rewrite Flax HuggingFace PR to Linen

* Some fixes

* Fix tests

* Fix CI with change of name of nlp (#7054)

* nlp -> datasets

* More nlp -> datasets

* Woopsie

* More nlp -> datasets

* One last

* Expose `is_flax_available` in file_utils.

* Added run_tests_flax to the CI.

* Attempt to make style

* trigger ci again

* Fix import sorting.
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Revert "Rewrites BERT in Flax to the new Linen API (#7211)"

This reverts commit 23703a5eb3364e26a1cbc3ee34b4710d86a674b0.

* Remove jnp.lax references
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Make style.
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Reintroduce Linen changes ...
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Make style.
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Use JAX's native gelu function.
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Renaming BertModel to BertModule to highlight the fact that this is the Flax Module object.
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Rewrite FlaxAutoModel test to not rely on pretrained_model_archive_map
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Remove unused variable in BertModule.
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Remove unused variable in BertModule again
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Attempt to have is_flax_available working again.
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Introduce JAX TensorType
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Improve ImportError message when trying to convert to various TensorType formats.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Make Flax models jittable.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Ensure flax models are jittable in unittests.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Remove unused imports.
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Ensure jax imports are guarded behind is_flax_available.
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Make style.
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Make style again
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Make style again again
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Make style again again again
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Update src/transformers/file_utils.py
Co-authored-by: Marc van Zee <marcvanzee@gmail.com>

* Bump flax to its latest version
Co-authored-by: Marc van Zee <marcvanzee@gmail.com>

* Bump jax version to at least 0.2.0
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Style.
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Update the unittest to use TensorType.JAX
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* isort imports in tests.
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Match new flax parameters name "params"
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Remove unused imports.
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Add flax models to transformers __init__
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Attempt to address all CI-related comments.
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Correct circle.yml indent.
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Correct circle.yml indent (2)
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Remove coverage from flax tests
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Addressing many naming suggestions from comments
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Simplify for-loop logic to iterate over layers in FlaxBertLayerCollection
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Use f-string syntax for formatting logs.
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Use config property from FlaxPreTrainedModel.
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Use "cls_token" instead of "first_token" variable name.
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Use "hidden_state" instead of "h" variable name.
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Correct class reference in docstring to link to Flax-related modules.
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Added HF + Google Flax team copyright.
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Make Roberta independent from Bert
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Move activation functions to flax_utils.
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Move activation functions to flax_utils for bert.
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Added docstring for BERT
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Update imports for Bert and Roberta tokenizers
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Make style.
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* fix-copies
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Correct FlaxRobertaLayer to match PyTorch.
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Use the same store_artifacts step for flax unittests
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Style.
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Make sure gradients are disabled only locally for flax unittests using torch equivalence.
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Use relative imports
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
Co-authored-by: Stefan Schweter <stefan@schweter.it>
Co-authored-by: Marc van Zee <marcvanzee@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 0193c829
......@@ -139,6 +139,31 @@ jobs:
- store_artifacts:
path: ~/transformers/output.txt
destination: test_output.txt
run_tests_flax:
working_directory: ~/transformers
docker:
- image: circleci/python:3.7
environment:
OMP_NUM_THREADS: 1
resource_class: xlarge
parallelism: 1
steps:
- checkout
- restore_cache:
keys:
- v0.3-flax-{{ checksum "setup.py" }}
- v0.3-{{ checksum "setup.py" }}
- run: pip install --upgrade pip
- run: pip install git+https://github.com/huggingface/datasets
- run: sudo pip install .[flax,sklearn,torch,testing]
- save_cache:
key: v0.3-flax-{{ checksum "setup.py" }}
paths:
- '~/.cache/pip'
- run: python -m pytest -n 8 --dist=loadfile -rA -s ./tests/ | tee output.txt
- store_artifacts:
path: ~/transformers/output.txt
destination: test_output.txt
run_tests_custom_tokenizers:
working_directory: ~/transformers
docker:
......@@ -305,6 +330,7 @@ workflows:
- run_tests_torch_and_tf
- run_tests_torch
- run_tests_tf
- run_tests_flax
- build_doc
- deploy_doc: *workflow_filters
tpu_testing_jobs:
......
......@@ -87,6 +87,7 @@ extras["tf-cpu"] = [
# "keras2onnx @ git+git://github.com/onnx/keras-onnx.git@cbdc75cb950b16db7f0a67be96a278f8d2953b48#egg=keras2onnx",
]
extras["torch"] = ["torch>=1.0"]
extras["flax"] = ["jaxlib==0.1.55", "jax>=0.2.0", "flax==0.2.2"]
extras["onnxruntime"] = ["onnxruntime>=1.4.0", "onnxruntime-tools>=1.4.2"]
extras["serving"] = ["pydantic", "uvicorn", "fastapi", "starlette"]
......
......@@ -103,6 +103,7 @@ from .file_utils import (
is_apex_available,
is_datasets_available,
is_faiss_available,
is_flax_available,
is_psutil_available,
is_py3nvml_available,
is_sentencepiece_available,
......@@ -817,6 +818,10 @@ else:
from .utils.dummy_tf_objects import *
if is_flax_available():
from .modeling_flax_bert import FlaxBertModel
from .modeling_flax_roberta import FlaxRobertaModel
if not is_tf_available() and not is_torch_available():
logger.warning(
"Neither PyTorch nor TensorFlow >= 2.0 have been found."
......
......@@ -34,10 +34,13 @@ from .utils import logging
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES"}
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})
try:
USE_TF = os.environ.get("USE_TF", "AUTO").upper()
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
if USE_TORCH in ("1", "ON", "YES", "AUTO") and USE_TF not in ("1", "ON", "YES"):
if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
import torch
_torch_available = True # pylint: disable=invalid-name
......@@ -52,7 +55,7 @@ try:
USE_TF = os.environ.get("USE_TF", "AUTO").upper()
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
if USE_TF in ("1", "ON", "YES", "AUTO") and USE_TORCH not in ("1", "ON", "YES"):
if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
import tensorflow as tf
assert hasattr(tf, "__version__") and int(tf.__version__[0]) >= 2
......@@ -65,6 +68,22 @@ except (ImportError, AssertionError):
_tf_available = False # pylint: disable=invalid-name
try:
USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()
if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
import flax
import jax
logger.info("JAX version {}, Flax: available".format(jax.__version__))
logger.info("Flax available: {}".format(flax))
_flax_available = True
else:
_flax_available = False
except ImportError:
_flax_available = False # pylint: disable=invalid-name
try:
import datasets # noqa: F401
......@@ -213,6 +232,10 @@ def is_tf_available():
return _tf_available
def is_flax_available():
return _flax_available
def is_torch_tpu_available():
return _torch_tpu_available
......
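Editorial sketch, not part of the diff: the new is_flax_available() helper mirrors the existing TensorFlow/PyTorch flags, so downstream code can guard optional JAX/Flax imports the same way. Only names introduced in the hunk above are assumed.
from transformers.file_utils import is_flax_available

if is_flax_available():
    import jax.numpy as jnp

def as_array(values):
    # Return a JAX array when Flax/JAX is installed, otherwise a plain Python list.
    return jnp.array(values) if is_flax_available() else list(values)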
# coding=utf-8
# Copyright 2018 The Google Flax Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Auto Model class. """
from collections import OrderedDict
from .configuration_auto import AutoConfig, BertConfig, RobertaConfig
from .configuration_utils import PretrainedConfig
from .modeling_flax_bert import FlaxBertModel
from .modeling_flax_roberta import FlaxRobertaModel
from .utils import logging
logger = logging.get_logger(__name__)
ALL_PRETRAINED_MODEL_ARCHIVE_MAP = dict(
(key, value)
for pretrained_map in [
FlaxBertModel.pretrained_model_archive_map,
FlaxRobertaModel.pretrained_model_archive_map,
]
for key, value in pretrained_map.items()
)
MODEL_MAPPING = OrderedDict(
[
(RobertaConfig, FlaxRobertaModel),
(BertConfig, FlaxBertModel),
]
)
class FlaxAutoModel(object):
r"""
:class:`~transformers.FlaxAutoModel` is a generic model class
that will be instantiated as one of the base model classes of the library
when created with the `FlaxAutoModel.from_pretrained(pretrained_model_name_or_path)`
or the `FlaxAutoModel.from_config(config)` class methods.
This class cannot be instantiated using `__init__()` (throws an error).
"""
def __init__(self):
raise EnvironmentError(
"FlaxAutoModel is designed to be instantiated "
"using the `FlaxAutoModel.from_pretrained(pretrained_model_name_or_path)` or "
"`FlaxAutoModel.from_config(config)` methods."
)
@classmethod
def from_config(cls, config):
r"""Instantiates one of the base model classes of the library
from a configuration.
Args:
config (:class:`~transformers.PretrainedConfig`):
The model class to instantiate is selected based on the configuration class:
- isInstance of `roberta` configuration class: :class:`~transformers.FlaxRobertaModel` (RoBERTa model)
- isInstance of `bert` configuration class: :class:`~transformers.FlaxBertModel` (Bert model)
Examples::
config = BertConfig.from_pretrained('bert-base-uncased')  # Download configuration from S3 and cache.
model = FlaxAutoModel.from_config(config)  # Instantiate a model from the configuration (no pretrained weights are loaded).
"""
for config_class, model_class in MODEL_MAPPING.items():
if isinstance(config, config_class):
return model_class(config)
raise ValueError(
f"Unrecognized configuration class {config.__class__} "
f"for this kind of FlaxAutoModel: {cls.__name__}.\n"
f"Model type should be one of {', '.join(c.__name__ for c in MODEL_MAPPING.keys())}."
)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
r"""Instantiates one of the base model classes of the library
from a pre-trained model configuration.
The `from_pretrained()` method takes care of returning the correct model class instance
based on the `model_type` property of the config object, or when it's missing,
falling back to using pattern matching on the `pretrained_model_name_or_path` string.
The base model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
- contains `roberta`: :class:`~transformers.FlaxRobertaModel` (RoBERTa model)
- contains `bert`: :class:`~transformers.FlaxBertModel` (Bert model)
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
To train the model, you should first set it back in training mode with `model.train()`
Args:
pretrained_model_name_or_path: either:
- a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
- a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
- a path to a `directory` containing model weights saved using :func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
- a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
model_args: (`optional`) Sequence of positional arguments:
All remaining positional arguments will be passed to the underlying model's ``__init__`` method
config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
- the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
- the model was saved using :func:`~transformers.FlaxPreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
- the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.
state_dict: (`optional`) dict:
an optional state dictionary for the model to use instead of a state dictionary loaded from the saved weights file.
This option can be used if you want to create a model from a pretrained configuration but load your own weights.
In this case though, you should check if using :func:`~transformers.FlaxPreTrainedModel.save_pretrained` and :func:`~transformers.FlaxPreTrainedModel.from_pretrained` is not a simpler option.
cache_dir: (`optional`) string:
Path to a directory in which a downloaded pre-trained model
configuration should be cached if the standard cache should not be used.
force_download: (`optional`) boolean, default False:
Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
resume_download: (`optional`) boolean, default False:
Do not delete incompletely received files. Attempt to resume the download if such a file exists.
proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
The proxies are used on each request.
output_loading_info: (`optional`) boolean:
Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.
kwargs: (`optional`) Remaining dictionary of keyword arguments:
These arguments will be passed to the configuration and the model.
Examples::
model = FlaxAutoModel.from_pretrained('bert-base-uncased') # Download model and configuration from S3 and cache.
model = FlaxAutoModel.from_pretrained('./test/bert_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')`
assert model.config.output_attentions == True
"""
config = kwargs.pop("config", None)
if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
for config_class, model_class in MODEL_MAPPING.items():
if isinstance(config, config_class):
return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
raise ValueError(
f"Unrecognized configuration class {config.__class__} "
f"for this kind of FlaxAutoModel: {cls.__name__}.\n"
f"Model type should be one of {', '.join(c.__name__ for c in MODEL_MAPPING.keys())}"
)
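An illustrative usage sketch for the new FlaxAutoModel (the checkpoint name mirrors the docstring examples above; everything else is assumed):
from transformers.modeling_flax_auto import FlaxAutoModel

# BertConfig is resolved from the checkpoint, so MODEL_MAPPING dispatches to FlaxBertModel.
model = FlaxAutoModel.from_pretrained("bert-base-uncased")
assert model.config.model_type == "bert"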
# coding=utf-8
# Copyright 2018 The Google Flax Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Dict
import numpy as np
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.linen import compact
from .configuration_bert import BertConfig
from .file_utils import add_start_docstrings
from .modeling_flax_utils import FlaxPreTrainedModel, gelu
from .utils import logging
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "BertConfig"
_TOKENIZER_FOR_DOC = "BertTokenizer"
BERT_START_DOCSTRING = r"""
This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
methods the library implements for all its model (such as downloading or saving, resizing the input embeddings,
pruning heads etc.)
This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__ subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general
usage and behavior.
Parameters:
config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the configuration.
Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""
BERT_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using :class:`~transformers.BertTokenizer`.
See :meth:`transformers.PreTrainedTokenizer.encode` and
:meth:`transformers.PreTrainedTokenizer.__call__` for details.
`What are input IDs? <../glossary.html#input-ids>`__
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
Mask to avoid performing attention on padding token indices.
Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
Segment token indices to indicate first and second portions of the inputs.
Indices are selected in ``[0, 1]``:
- 0 corresponds to a `sentence A` token,
- 1 corresponds to a `sentence B` token.
`What are token type IDs? <../glossary.html#token-type-ids>`_
position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
Indices of positions of each input sequence tokens in the position embeddings.
Selected in the range ``[0, config.max_position_embeddings - 1]``.
`What are position IDs? <../glossary.html#position-ids>`_
head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
vectors than the model's internal embedding lookup matrix.
output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`):
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
more detail.
return_dict (:obj:`bool`, `optional`):
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
"""
class FlaxBertLayerNorm(nn.Module):
"""Layer normalization (https://arxiv.org/abs/1607.06450).
Operates on the last axis of the input data.
"""
epsilon: float = 1e-6
dtype: jnp.dtype = jnp.float32
bias: bool = True
scale: bool = True
bias_init: jnp.ndarray = nn.initializers.zeros
scale_init: jnp.ndarray = nn.initializers.ones
@compact
def __call__(self, x):
"""Applies layer normalization on the input.
It normalizes the activations of the layer for each given example in a
batch independently, rather than across a batch like Batch Normalization.
i.e. applies a transformation that maintains the mean activation within
each example close to 0 and the activation standard deviation close to 1.
Args:
x: the inputs
epsilon: A small float added to variance to avoid dividing by zero.
dtype: the dtype of the computation (default: float32).
bias: If True, bias (beta) is added.
scale: If True, multiply by scale (gamma). When the next layer is linear
(also e.g. nn.relu), this can be disabled since the scaling will be done
by the next layer.
bias_init: Initializer for bias, by default, zero.
scale_init: Initializer for scale, by default, one.
Returns:
Normalized inputs (the same shape as inputs).
"""
features = x.shape[-1]
mean = jnp.mean(x, axis=-1, keepdims=True)
mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
var = mean2 - jax.lax.square(mean)
mul = jax.lax.rsqrt(var + self.epsilon)
if self.scale:
mul = mul * jnp.asarray(self.param("gamma", self.scale_init, (features,)), self.dtype)
y = (x - mean) * mul
if self.bias:
y = y + jnp.asarray(self.param("beta", self.bias_init, (features,)), self.dtype)
return y
class FlaxBertEmbedding(nn.Module):
"""
Specify a new class for doing the embedding lookup,
as Flax's own module uses 'embedding' for the parameter name
while PyTorch uses 'weight'
"""
vocab_size: int
hidden_size: int
emb_init: Callable[..., np.ndarray] = nn.initializers.normal(stddev=0.1)
@compact
def __call__(self, inputs):
embedding = self.param("weight", self.emb_init, (self.vocab_size, self.hidden_size))
return jnp.take(embedding, inputs, axis=0)
class FlaxBertEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings."""
vocab_size: int
hidden_size: int
type_vocab_size: int
max_length: int
@compact
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask):
# Embed
w_emb = FlaxBertEmbedding(self.vocab_size, self.hidden_size, name="word_embeddings")(
jnp.atleast_2d(input_ids.astype("i4"))
)
p_emb = FlaxBertEmbedding(self.max_length, self.hidden_size, name="position_embeddings")(
jnp.atleast_2d(position_ids.astype("i4"))
)
t_emb = FlaxBertEmbedding(self.type_vocab_size, self.hidden_size, name="token_type_embeddings")(
jnp.atleast_2d(token_type_ids.astype("i4"))
)
# Sum all embeddings
summed_emb = w_emb + jnp.broadcast_to(p_emb, w_emb.shape) + t_emb
# Layer Norm
layer_norm = FlaxBertLayerNorm(name="layer_norm")(summed_emb)
return layer_norm
class FlaxBertAttention(nn.Module):
num_heads: int
head_size: int
@compact
def __call__(self, hidden_state, attention_mask):
self_att = nn.attention.SelfAttention(num_heads=self.num_heads, qkv_features=self.head_size, name="self")(
hidden_state, attention_mask
)
layer_norm = FlaxBertLayerNorm(name="layer_norm")(self_att + hidden_state)
return layer_norm
class FlaxBertIntermediate(nn.Module):
output_size: int
@compact
def __call__(self, hidden_state):
# TODO: Add ACT2FN reference to change activation function
dense = nn.Dense(features=self.output_size, name="dense")(hidden_state)
return gelu(dense)
class FlaxBertOutput(nn.Module):
@compact
def __call__(self, intermediate_output, attention_output):
hidden_state = nn.Dense(attention_output.shape[-1], name="dense")(intermediate_output)
hidden_state = FlaxBertLayerNorm(name="layer_norm")(hidden_state + attention_output)
return hidden_state
class FlaxBertLayer(nn.Module):
num_heads: int
head_size: int
intermediate_size: int
@compact
def __call__(self, hidden_state, attention_mask):
attention = FlaxBertAttention(self.num_heads, self.head_size, name="attention")(hidden_state, attention_mask)
intermediate = FlaxBertIntermediate(self.intermediate_size, name="intermediate")(attention)
output = FlaxBertOutput(name="output")(intermediate, attention)
return output
class FlaxBertLayerCollection(nn.Module):
"""
Stores N BertLayer(s)
"""
num_layers: int
num_heads: int
head_size: int
intermediate_size: int
@compact
def __call__(self, inputs, attention_mask):
assert self.num_layers > 0, f"num_layers should be >= 1, got ({self.num_layers})"
# Initialize input / output
input_i = inputs
# Forward over all encoders
for i in range(self.num_layers):
layer = FlaxBertLayer(self.num_heads, self.head_size, self.intermediate_size, name=f"{i}")
input_i = layer(input_i, attention_mask)
return input_i
class FlaxBertEncoder(nn.Module):
num_layers: int
num_heads: int
head_size: int
intermediate_size: int
@compact
def __call__(self, hidden_state, attention_mask):
layer = FlaxBertLayerCollection(
self.num_layers, self.num_heads, self.head_size, self.intermediate_size, name="layer"
)(hidden_state, attention_mask)
return layer
class FlaxBertPooler(nn.Module):
@compact
def __call__(self, hidden_state):
cls_token = hidden_state[:, 0]
out = nn.Dense(hidden_state.shape[-1], name="dense")(cls_token)
return jax.lax.tanh(out)
class FlaxBertModule(nn.Module):
vocab_size: int
hidden_size: int
type_vocab_size: int
max_length: int
num_encoder_layers: int
num_heads: int
head_size: int
intermediate_size: int
@compact
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask):
# Embedding
embeddings = FlaxBertEmbeddings(
self.vocab_size, self.hidden_size, self.type_vocab_size, self.max_length, name="embeddings"
)(input_ids, token_type_ids, position_ids, attention_mask)
# N stacked encoding layers
encoder = FlaxBertEncoder(
self.num_encoder_layers, self.num_heads, self.head_size, self.intermediate_size, name="encoder"
)(embeddings, attention_mask)
pooled = FlaxBertPooler(name="pooler")(encoder)
return encoder, pooled
@add_start_docstrings(
"The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
BERT_START_DOCSTRING,
)
class FlaxBertModel(FlaxPreTrainedModel):
"""
The model can behave as an encoder (with only self-attention) as well
as a decoder, in which case a layer of cross-attention is added between
the self-attention layers, following the architecture described in `Attention is all you need
<https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones,
Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
"""
model_class = FlaxBertModule
config_class = BertConfig
base_model_prefix = "bert"
@staticmethod
def convert_from_pytorch(pt_state: Dict, config: BertConfig) -> Dict:
jax_state = dict(pt_state)
# Need to change some parameter names to match Flax names so that we don't have to fork any layer
for key, tensor in pt_state.items():
# Key parts
key_parts = set(key.split("."))
# Every dense layer has "kernel" parameters instead of "weight"
if "dense.weight" in key:
del jax_state[key]
key = key.replace("weight", "kernel")
jax_state[key] = tensor
# SelfAttention also needs to replace "weight" with "kernel"
if {"query", "key", "value"} & key_parts:
# Flax SelfAttention decomposes the heads (num_heads, size // num_heads)
if "bias" in key:
jax_state[key] = tensor.reshape((config.num_attention_heads, -1))
elif "weight" in key:
del jax_state[key]
key = key.replace("weight", "kernel")
tensor = tensor.reshape((config.num_attention_heads, -1, config.hidden_size)).transpose((2, 0, 1))
jax_state[key] = tensor
# SelfAttention output is not a separate layer, remove one nesting
if "attention.output.dense" in key:
del jax_state[key]
key = key.replace("attention.output.dense", "attention.self.out")
jax_state[key] = tensor
# SelfAttention output is not a separate layer, remove nesting on layer norm
if "attention.output.LayerNorm" in key:
del jax_state[key]
key = key.replace("attention.output.LayerNorm", "attention.LayerNorm")
jax_state[key] = tensor
# There are some transposed parameters w.r.t their PyTorch counterpart
if "intermediate.dense.kernel" in key or "output.dense.kernel" in key:
jax_state[key] = tensor.T
# Self Attention output projection needs to be transposed
if "out.kernel" in key:
jax_state[key] = tensor.reshape((config.hidden_size, config.num_attention_heads, -1)).transpose(
1, 2, 0
)
# Pooler needs to transpose its kernel
if "pooler.dense.kernel" in key:
jax_state[key] = tensor.T
# Handle LayerNorm conversion
if "LayerNorm" in key:
del jax_state[key]
# Replace LayerNorm by layer_norm
new_key = key.replace("LayerNorm", "layer_norm")
if "weight" in key:
new_key = new_key.replace("weight", "gamma")
elif "bias" in key:
new_key = new_key.replace("bias", "beta")
jax_state[new_key] = tensor
return jax_state
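# Editorial note (illustrative, assuming a bert-base sized config with hidden_size=768 and
# num_attention_heads=12): a PyTorch "query.weight" of shape (768, 768) is reshaped to
# (12, 64, 768) and transposed with (2, 0, 1) into (768, 12, 64), i.e.
# (hidden_size, num_heads, head_dim), the kernel layout flax.linen's SelfAttention expects;
# the matching "query.bias" of shape (768,) becomes (12, 64).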
def __init__(self, config: BertConfig, state: dict, seed: int = 0, **kwargs):
model = FlaxBertModule(
vocab_size=config.vocab_size,
hidden_size=config.hidden_size,
type_vocab_size=config.type_vocab_size,
max_length=config.max_position_embeddings,
num_encoder_layers=config.num_hidden_layers,
num_heads=config.num_attention_heads,
head_size=config.hidden_size,
intermediate_size=config.intermediate_size,
)
super().__init__(config, model, state, seed)
@property
def module(self) -> nn.Module:
return self._module
def __call__(self, input_ids, token_type_ids=None, position_ids=None, attention_mask=None):
if token_type_ids is None:
token_type_ids = jnp.ones_like(input_ids)
if position_ids is None:
position_ids = jnp.arange(jnp.atleast_2d(input_ids).shape[-1])
if attention_mask is None:
attention_mask = jnp.ones_like(input_ids)
return self.model.apply(
{"params": self.params},
jnp.array(input_ids, dtype="i4"),
jnp.array(token_type_ids, dtype="i4"),
jnp.array(position_ids, dtype="i4"),
jnp.array(attention_mask, dtype="i4"),
)
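An illustrative end-to-end sketch of the new wrapper (the checkpoint name and sentence are arbitrary; from_pretrained fetches the PyTorch weights and converts them on the fly via convert_from_pytorch):
from transformers import BertTokenizer
from transformers.modeling_flax_bert import FlaxBertModel

tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
model = FlaxBertModel.from_pretrained("bert-base-cased")

# Tokenize to NumPy arrays; the wrapper casts them to jnp arrays internally.
inputs = tokenizer("Hello, Flax!", return_tensors="np")
sequence_output, pooled_output = model(**inputs)
print(sequence_output.shape, pooled_output.shape)  # (1, seq_len, hidden_size) and (1, hidden_size)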
# coding=utf-8
# Copyright 2018 The Google Flax Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Dict
import numpy as np
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.linen import compact
from .configuration_roberta import RobertaConfig
from .file_utils import add_start_docstrings
from .modeling_flax_utils import FlaxPreTrainedModel, gelu
from .utils import logging
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "RobertaConfig"
_TOKENIZER_FOR_DOC = "RobertaTokenizer"
ROBERTA_START_DOCSTRING = r"""
This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
methods the library implements for all its model (such as downloading or saving, resizing the input embeddings,
pruning heads etc.)
This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__ subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general
usage and behavior.
Parameters:
config (:class:`~transformers.RobertaConfig`): Model configuration class with all the parameters of the
model. Initializing with a config file does not load the weights associated with the model, only the configuration.
Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""
ROBERTA_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using :class:`~transformers.RobertaTokenizer`.
See :meth:`transformers.PreTrainedTokenizer.encode` and
:meth:`transformers.PreTrainedTokenizer.__call__` for details.
`What are input IDs? <../glossary.html#input-ids>`__
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
Mask to avoid performing attention on padding token indices.
Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
Segment token indices to indicate first and second portions of the inputs.
Indices are selected in ``[0, 1]``:
- 0 corresponds to a `sentence A` token,
- 1 corresponds to a `sentence B` token.
`What are token type IDs? <../glossary.html#token-type-ids>`_
position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
Indices of positions of each input sequence tokens in the position embeddings.
Selected in the range ``[0, config.max_position_embeddings - 1]``.
`What are position IDs? <../glossary.html#position-ids>`_
head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
vectors than the model's internal embedding lookup matrix.
output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`):
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
more detail.
return_dict (:obj:`bool`, `optional`):
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
"""
# Copied from transformers.modeling_flax_bert.FlaxBertLayerNorm with Bert->Roberta
class FlaxRobertaLayerNorm(nn.Module):
"""Layer normalization (https://arxiv.org/abs/1607.06450).
Operates on the last axis of the input data.
"""
epsilon: float = 1e-6
dtype: jnp.dtype = jnp.float32
bias: bool = True
scale: bool = True
bias_init: jnp.ndarray = nn.initializers.zeros
scale_init: jnp.ndarray = nn.initializers.ones
@compact
def __call__(self, x):
"""Applies layer normalization on the input.
It normalizes the activations of the layer for each given example in a
batch independently, rather than across a batch like Batch Normalization.
i.e. applies a transformation that maintains the mean activation within
each example close to 0 and the activation standard deviation close to 1.
Args:
x: the inputs
epsilon: A small float added to variance to avoid dividing by zero.
dtype: the dtype of the computation (default: float32).
bias: If True, bias (beta) is added.
scale: If True, multiply by scale (gamma). When the next layer is linear
(also e.g. nn.relu), this can be disabled since the scaling will be done
by the next layer.
bias_init: Initializer for bias, by default, zero.
scale_init: Initializer for scale, by default, one.
Returns:
Normalized inputs (the same shape as inputs).
"""
features = x.shape[-1]
mean = jnp.mean(x, axis=-1, keepdims=True)
mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
var = mean2 - jax.lax.square(mean)
mul = jax.lax.rsqrt(var + self.epsilon)
if self.scale:
mul = mul * jnp.asarray(self.param("gamma", self.scale_init, (features,)), self.dtype)
y = (x - mean) * mul
if self.bias:
y = y + jnp.asarray(self.param("beta", self.bias_init, (features,)), self.dtype)
return y
# Copied from transformers.modeling_flax_bert.FlaxBertEmbedding with Bert->Roberta
class FlaxRobertaEmbedding(nn.Module):
"""
Specify a new class for doing the embedding lookup,
as Flax's own module uses 'embedding' for the parameter name
while PyTorch uses 'weight'
"""
vocab_size: int
hidden_size: int
emb_init: Callable[..., np.ndarray] = nn.initializers.normal(stddev=0.1)
@compact
def __call__(self, inputs):
embedding = self.param("weight", self.emb_init, (self.vocab_size, self.hidden_size))
return jnp.take(embedding, inputs, axis=0)
# Copied from transformers.modeling_flax_bert.FlaxBertEmbeddings with Bert->Roberta
class FlaxRobertaEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings."""
vocab_size: int
hidden_size: int
type_vocab_size: int
max_length: int
@compact
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask):
# Embed
w_emb = FlaxRobertaEmbedding(self.vocab_size, self.hidden_size, name="word_embeddings")(
jnp.atleast_2d(input_ids.astype("i4"))
)
p_emb = FlaxRobertaEmbedding(self.max_length, self.hidden_size, name="position_embeddings")(
jnp.atleast_2d(position_ids.astype("i4"))
)
t_emb = FlaxRobertaEmbedding(self.type_vocab_size, self.hidden_size, name="token_type_embeddings")(
jnp.atleast_2d(token_type_ids.astype("i4"))
)
# Sum all embeddings
summed_emb = w_emb + jnp.broadcast_to(p_emb, w_emb.shape) + t_emb
# Layer Norm
layer_norm = FlaxRobertaLayerNorm(name="layer_norm")(summed_emb)
return layer_norm
# Copied from transformers.modeling_flax_bert.FlaxBertAttention with Bert->Roberta
class FlaxRobertaAttention(nn.Module):
num_heads: int
head_size: int
@compact
def __call__(self, hidden_state, attention_mask):
self_att = nn.attention.SelfAttention(num_heads=self.num_heads, qkv_features=self.head_size, name="self")(
hidden_state, attention_mask
)
layer_norm = FlaxRobertaLayerNorm(name="layer_norm")(self_att + hidden_state)
return layer_norm
# Copied from transformers.modeling_flax_bert.FlaxBertIntermediate with Bert->Roberta
class FlaxRobertaIntermediate(nn.Module):
output_size: int
@compact
def __call__(self, hidden_state):
# TODO: Add ACT2FN reference to change activation function
dense = nn.Dense(features=self.output_size, name="dense")(hidden_state)
return gelu(dense)
# Copied from transformers.modeling_flax_bert.FlaxBertOutput with Bert->Roberta
class FlaxRobertaOutput(nn.Module):
@compact
def __call__(self, intermediate_output, attention_output):
hidden_state = nn.Dense(attention_output.shape[-1], name="dense")(intermediate_output)
hidden_state = FlaxRobertaLayerNorm(name="layer_norm")(hidden_state + attention_output)
return hidden_state
class FlaxRobertaLayer(nn.Module):
num_heads: int
head_size: int
intermediate_size: int
@compact
def __call__(self, hidden_state, attention_mask):
attention = FlaxRobertaAttention(self.num_heads, self.head_size, name="attention")(
hidden_state, attention_mask
)
intermediate = FlaxRobertaIntermediate(self.intermediate_size, name="intermediate")(attention)
output = FlaxRobertaOutput(name="output")(intermediate, attention)
return output
# Copied from transformers.modeling_flax_bert.FlaxBertLayerCollection with Bert->Roberta
class FlaxRobertaLayerCollection(nn.Module):
"""
Stores N RobertaLayer(s)
"""
num_layers: int
num_heads: int
head_size: int
intermediate_size: int
@compact
def __call__(self, inputs, attention_mask):
assert self.num_layers > 0, f"num_layers should be >= 1, got ({self.num_layers})"
# Initialize input / output
input_i = inputs
# Forward over all encoders
for i in range(self.num_layers):
layer = FlaxRobertaLayer(self.num_heads, self.head_size, self.intermediate_size, name=f"{i}")
input_i = layer(input_i, attention_mask)
return input_i
# Copied from transformers.modeling_flax_bert.FlaxBertEncoder with Bert->Roberta
class FlaxRobertaEncoder(nn.Module):
num_layers: int
num_heads: int
head_size: int
intermediate_size: int
@compact
def __call__(self, hidden_state, attention_mask):
layer = FlaxRobertaLayerCollection(
self.num_layers, self.num_heads, self.head_size, self.intermediate_size, name="layer"
)(hidden_state, attention_mask)
return layer
# Copied from transformers.modeling_flax_bert.FlaxBertPooler with Bert->Roberta
class FlaxRobertaPooler(nn.Module):
@compact
def __call__(self, hidden_state):
cls_token = hidden_state[:, 0]
out = nn.Dense(hidden_state.shape[-1], name="dense")(cls_token)
return jax.lax.tanh(out)
# Copied from transformers.modeling_flax_bert.FlaxBertModule with Bert->Roberta
class FlaxRobertaModule(nn.Module):
vocab_size: int
hidden_size: int
type_vocab_size: int
max_length: int
num_encoder_layers: int
num_heads: int
head_size: int
intermediate_size: int
@compact
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask):
# Embedding
embeddings = FlaxRobertaEmbeddings(
self.vocab_size, self.hidden_size, self.type_vocab_size, self.max_length, name="embeddings"
)(input_ids, token_type_ids, position_ids, attention_mask)
# N stacked encoding layers
encoder = FlaxRobertaEncoder(
self.num_encoder_layers, self.num_heads, self.head_size, self.intermediate_size, name="encoder"
)(embeddings, attention_mask)
pooled = FlaxRobertaPooler(name="pooler")(encoder)
return encoder, pooled
@add_start_docstrings(
"The bare RoBERTa Model transformer outputting raw hidden-states without any specific head on top.",
ROBERTA_START_DOCSTRING,
)
class FlaxRobertaModel(FlaxPreTrainedModel):
"""
The model can behave as an encoder (with only self-attention) as well
as a decoder, in which case a layer of cross-attention is added between
the self-attention layers, following the architecture described in `Attention is all you need`_ by Ashish Vaswani,
Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
"""
model_class = FlaxRobertaModule
config_class = RobertaConfig
base_model_prefix = "roberta"
@staticmethod
def convert_from_pytorch(pt_state: Dict, config: RobertaConfig) -> Dict:
jax_state = dict(pt_state)
# Need to change some parameter names to match Flax names so that we don't have to fork any layer
for key, tensor in pt_state.items():
# Key parts
key_parts = set(key.split("."))
# Every dense layer has "kernel" parameters instead of "weight"
if "dense.weight" in key:
del jax_state[key]
key = key.replace("weight", "kernel")
jax_state[key] = tensor
# SelfAttention also needs to replace "weight" with "kernel"
if {"query", "key", "value"} & key_parts:
# Flax SelfAttention decomposes the heads (num_heads, size // num_heads)
if "bias" in key:
jax_state[key] = tensor.reshape((config.num_attention_heads, -1))
elif "weight" in key:
del jax_state[key]
key = key.replace("weight", "kernel")
tensor = tensor.reshape((config.num_attention_heads, -1, config.hidden_size)).transpose((2, 0, 1))
jax_state[key] = tensor
# SelfAttention output is not a separate layer, remove one nesting
if "attention.output.dense" in key:
del jax_state[key]
key = key.replace("attention.output.dense", "attention.self.out")
jax_state[key] = tensor
# SelfAttention output is not a separate layer, remove nesting on layer norm
if "attention.output.LayerNorm" in key:
del jax_state[key]
key = key.replace("attention.output.LayerNorm", "attention.LayerNorm")
jax_state[key] = tensor
# There are some transposed parameters w.r.t their PyTorch counterpart
if "intermediate.dense.kernel" in key or "output.dense.kernel" in key:
jax_state[key] = tensor.T
# Self Attention output projection needs to be transposed
if "out.kernel" in key:
jax_state[key] = tensor.reshape((config.hidden_size, config.num_attention_heads, -1)).transpose(
1, 2, 0
)
# Pooler needs to transpose its kernel
if "pooler.dense.kernel" in key:
jax_state[key] = tensor.T
# Handle LayerNorm conversion
if "LayerNorm" in key:
del jax_state[key]
# Replace LayerNorm by layer_norm
new_key = key.replace("LayerNorm", "layer_norm")
if "weight" in key:
new_key = new_key.replace("weight", "gamma")
elif "bias" in key:
new_key = new_key.replace("bias", "beta")
jax_state[new_key] = tensor
return jax_state
def __init__(self, config: RobertaConfig, state: dict, seed: int = 0, **kwargs):
model = FlaxRobertaModule(
vocab_size=config.vocab_size,
hidden_size=config.hidden_size,
type_vocab_size=config.type_vocab_size,
max_length=config.max_position_embeddings,
num_encoder_layers=config.num_hidden_layers,
num_heads=config.num_attention_heads,
head_size=config.hidden_size,
intermediate_size=config.intermediate_size,
)
super().__init__(config, model, state, seed)
@property
def module(self) -> nn.Module:
return self._module
def __call__(self, input_ids, token_type_ids=None, position_ids=None, attention_mask=None):
if token_type_ids is None:
token_type_ids = jnp.ones_like(input_ids)
if position_ids is None:
position_ids = np.arange(
self.config.pad_token_id + 1, np.atleast_2d(input_ids).shape[-1] + self.config.pad_token_id + 1
)
if attention_mask is None:
attention_mask = jnp.ones_like(input_ids)
return self.model.apply(
{"params": self.params},
jnp.array(input_ids, dtype="i4"),
jnp.array(token_type_ids, dtype="i4"),
jnp.array(position_ids, dtype="i4"),
jnp.array(attention_mask, dtype="i4"),
)
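For reference, a small sketch (pad_token_id=1, RoBERTa's default, is assumed) of the default position ids the wrapper builds above when none are supplied; positions start after the padding index, matching the original fairseq convention:
import numpy as np

pad_token_id, seq_len = 1, 5  # assumed values for illustration
position_ids = np.arange(pad_token_id + 1, seq_len + pad_token_id + 1)
print(position_ids)  # [2 3 4 5 6]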
# coding=utf-8
# Copyright 2018 The Google Flax Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from abc import ABC, abstractmethod
from pickle import UnpicklingError
from typing import Dict
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.serialization import to_bytes
from flax.traverse_util import unflatten_dict
from jax.random import PRNGKey
from .configuration_utils import PretrainedConfig
from .file_utils import WEIGHTS_NAME, cached_path, hf_bucket_url, is_remote_url
from .utils import logging
logger = logging.get_logger(__name__)
@jax.jit
def gelu(x):
r"""Gaussian error linear unit activation function.
Computes the element-wise function:
.. math::
\mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{erf}\left(\frac{x}{\sqrt{2}}\right)\right)
This is the exact erf-based formulation, matching the implementation below, rather than the
tanh approximation. For more information, see `Gaussian Error Linear Units (GELUs)
<https://arxiv.org/abs/1606.08415>`_.
"""
return x * 0.5 * (1.0 + jax.lax.erf(x / jnp.sqrt(2.0)))
ACT2FN = {
"gelu": nn.gelu,
"relu": nn.relu,
"swish": nn.swish,
"gelu_new": gelu,
}
class FlaxPreTrainedModel(ABC):
config_class = None
pretrained_model_archive_map = {}
base_model_prefix = ""
model_class = None
def __init__(self, config: PretrainedConfig, module: nn.Module, params: Dict, seed: int = 0):
if config is None:
raise ValueError("config cannot be None")
if module is None:
raise ValueError("module cannot be None")
if params is None:
raise ValueError("params cannot be None")
# These are private, to be exposed as typed properties on derived classes.
self._config = config
self._module = module
# These are public, as their type is generic to every derived class.
self.key = PRNGKey(seed)
self.params = params
self.model = module
@property
def config(self) -> PretrainedConfig:
return self._config
@staticmethod
@abstractmethod
def convert_from_pytorch(pt_state: Dict, config: PretrainedConfig) -> Dict:
raise NotImplementedError()
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
r"""
Instantiate a pretrained Flax model from a pre-trained model configuration.
"""
config = kwargs.pop("config", None)
# state_dict = kwargs.pop("state_dict", None)
cache_dir = kwargs.pop("cache_dir", None)
# from_tf = kwargs.pop("from_tf", False)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
# output_loading_info = kwargs.pop("output_loading_info", False)
local_files_only = kwargs.pop("local_files_only", False)
use_cdn = kwargs.pop("use_cdn", True)
# Load config if we don't provide a configuration
if not isinstance(config, PretrainedConfig):
config_path = config if config is not None else pretrained_model_name_or_path
config, model_kwargs = cls.config_class.from_pretrained(
config_path,
*model_args,
cache_dir=cache_dir,
return_unused_kwargs=True,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
**kwargs,
)
else:
model_kwargs = kwargs
# Load model
if pretrained_model_name_or_path is not None:
if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
archive_file = pretrained_model_name_or_path
else:
archive_file = hf_bucket_url(pretrained_model_name_or_path, filename=WEIGHTS_NAME, use_cdn=use_cdn)
# redirect to the cache, if necessary
try:
resolved_archive_file = cached_path(
archive_file,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
)
except EnvironmentError:
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
msg = f"Couldn't reach server at '{archive_file}' to download pretrained weights."
else:
msg = (
f"Model name '{pretrained_model_name_or_path}' "
f"was not found in model name list ({', '.join(cls.pretrained_model_archive_map.keys())}). "
f"We assumed '{archive_file}' was a path or url to model weight files but "
"couldn't find any such file at this path or url."
)
raise EnvironmentError(msg)
if resolved_archive_file == archive_file:
logger.info(f"loading weights file {archive_file}")
else:
logger.info(f"loading weights file {archive_file} from cache at {resolved_archive_file}")
else:
resolved_archive_file = None
# Instantiate model.
        with open(resolved_archive_file, "rb") as state_f:
            try:
                from flax.serialization import from_bytes

                # Try to read the archive as a Flax msgpack checkpoint first.
                state = from_bytes(cls.model_class, state_f.read())
            except (TypeError, ValueError):
                # Fall back to a PyTorch checkpoint and convert it on the fly to the Flax parameter layout.
                try:
                    import torch

                    state_f.seek(0)
                    state = torch.load(state_f)
                    state = {k: v.numpy() for k, v in state.items()}
                    state = cls.convert_from_pytorch(state, config)
                    state = unflatten_dict({tuple(k.split(".")[1:]): v for k, v in state.items()})
                except UnpicklingError:
                    raise EnvironmentError(
                        f"Unable to convert model {archive_file} to a Flax deserializable object. "
                        "Supported formats are PyTorch archive or Flax msgpack."
                    )
return cls(config, state, *model_args, **model_kwargs)
def save_pretrained(self, folder):
folder_abs = os.path.abspath(folder)
if not os.path.exists(folder_abs):
os.mkdir(folder_abs)
        with open(os.path.join(folder_abs, f"{self._config.model_type}.flax"), "wb") as f:
model_bytes = to_bytes(self.params)
f.write(model_bytes)
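To make the base-class API above concrete, here is a minimal usage sketch, assuming the Flax extras (flax, jax, jaxlib) are installed; the checkpoint name and output folder are illustrative.
# Illustrative only, not part of the diff.
from transformers.modeling_flax_bert import FlaxBertModel

# Resolves the checkpoint, then either deserializes a Flax msgpack archive or
# converts a PyTorch state dict on the fly via convert_from_pytorch().
model = FlaxBertModel.from_pretrained("bert-base-cased")

# Serializes model.params with flax.serialization.to_bytes into
# "<model_type>.flax" inside the given folder.
model.save_pretrained("/tmp/flax-bert-base-cased")  # illustrative output folder
Note that, as written here, the Flax archive is saved under `{model_type}.flax` rather than `WEIGHTS_NAME`, so reloading it requires passing the file path directly to `from_pretrained`.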
@@ -13,6 +13,7 @@ from pathlib import Path
from .file_utils import (
_datasets_available,
_faiss_available,
_flax_available,
_sentencepiece_available,
_tf_available,
_tokenizers_available,
@@ -115,6 +116,18 @@ def require_tf(test_case):
return test_case
def require_flax(test_case):
"""
    Decorator marking a test that requires JAX & Flax.
    These tests are skipped when one or both are not installed.
"""
if not _flax_available:
test_case = unittest.skip("test requires JAX & Flax")(test_case)
return test_case
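A quick illustration of how the new decorator is meant to be used; the test class below is hypothetical, not part of this change, and the whole case is skipped on environments without JAX and Flax.
# Hypothetical usage example.
import unittest

from transformers.testing_utils import require_flax


@require_flax
class ExampleFlaxDependentTest(unittest.TestCase):  # illustrative name
    def test_needs_jax(self):
        import jax.numpy as jnp  # safe: only runs when JAX is installed

        self.assertEqual(jnp.zeros((2, 3)).shape, (2, 3))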
def require_sentencepiece(test_case):
"""
Decorator marking a test that requires SentencePiece.
......
@@ -33,6 +33,7 @@ from .file_utils import (
add_end_docstrings,
cached_path,
hf_bucket_url,
is_flax_available,
is_remote_url,
is_tf_available,
is_tokenizers_available,
@@ -47,6 +48,8 @@ if is_tf_available():
if is_torch_available():
import torch
if is_flax_available():
import jax.numpy as jnp
if is_tokenizers_available():
from tokenizers import AddedToken
@@ -143,6 +146,7 @@ class TensorType(ExplicitEnum):
PYTORCH = "pt"
TENSORFLOW = "tf"
NUMPY = "np"
JAX = "jax"
class CharSpan(NamedTuple):
@@ -559,18 +563,27 @@ class BatchEncoding(UserDict):
tensor_type = TensorType(tensor_type)
# Get a function reference for the correct framework
        if tensor_type == TensorType.TENSORFLOW:
            if not is_tf_available():
                raise ImportError(
                    "Unable to convert output to TensorFlow tensors format, TensorFlow is not installed."
                )
            as_tensor = tf.constant
        elif tensor_type == TensorType.PYTORCH:
            if not is_torch_available():
                raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.")
            as_tensor = torch.tensor
        elif tensor_type == TensorType.JAX:
            if not is_flax_available():
                raise ImportError("Unable to convert output to JAX tensors format, JAX is not installed.")
            as_tensor = jnp.array
        else:
            as_tensor = np.asarray
        # (mfuntowicz: This code is unreachable)
        # else:
        #     raise ImportError(
        #         "Unable to convert output to tensors format {}".format(tensor_type)
        #     )
# Do the tensor conversion in batch
for key, value in self.items():
......
import unittest
from transformers import AutoConfig, AutoTokenizer, BertConfig, RobertaConfig, TensorType, is_flax_available
from transformers.testing_utils import require_flax, slow
if is_flax_available():
import jax
from transformers.modeling_flax_auto import FlaxAutoModel
from transformers.modeling_flax_bert import FlaxBertModel
from transformers.modeling_flax_roberta import FlaxRobertaModel
@require_flax
class FlaxAutoModelTest(unittest.TestCase):
@slow
def test_bert_from_pretrained(self):
for model_name in ["bert-base-cased", "bert-large-uncased"]:
with self.subTest(model_name):
config = AutoConfig.from_pretrained(model_name)
self.assertIsNotNone(config)
self.assertIsInstance(config, BertConfig)
model = FlaxAutoModel.from_pretrained(model_name)
self.assertIsNotNone(model)
self.assertIsInstance(model, FlaxBertModel)
@slow
def test_roberta_from_pretrained(self):
for model_name in ["roberta-base-cased", "roberta-large-uncased"]:
with self.subTest(model_name):
config = AutoConfig.from_pretrained(model_name)
self.assertIsNotNone(config)
                self.assertIsInstance(config, RobertaConfig)
model = FlaxAutoModel.from_pretrained(model_name)
self.assertIsNotNone(model)
self.assertIsInstance(model, FlaxRobertaModel)
@slow
def test_bert_jax_jit(self):
for model_name in ["bert-base-cased", "bert-large-uncased"]:
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = FlaxBertModel.from_pretrained(model_name)
tokens = tokenizer("Do you support jax jitted function?", return_tensors=TensorType.JAX)
@jax.jit
def eval(**kwargs):
return model(**kwargs)
eval(**tokens).block_until_ready()
@slow
def test_roberta_jax_jit(self):
for model_name in ["roberta-base-cased", "roberta-large-uncased"]:
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = FlaxRobertaModel.from_pretrained(model_name)
tokens = tokenizer("Do you support jax jitted function?", return_tensors=TensorType.JAX)
@jax.jit
def eval(**kwargs):
return model(**kwargs)
eval(**tokens).block_until_ready()
import unittest
from numpy import ndarray
from transformers import TensorType, is_flax_available, is_torch_available
from transformers.testing_utils import require_flax, require_torch
from transformers.tokenization_bert_fast import BertTokenizerFast
if is_flax_available():
from transformers.modeling_flax_bert import FlaxBertModel
if is_torch_available():
import torch
from transformers.modeling_bert import BertModel
@require_flax
@require_torch
class FlaxBertModelTest(unittest.TestCase):
def test_from_pytorch(self):
with torch.no_grad():
with self.subTest("bert-base-cased"):
tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")
fx_model = FlaxBertModel.from_pretrained("bert-base-cased")
pt_model = BertModel.from_pretrained("bert-base-cased")
# Check for simple input
pt_inputs = tokenizer.encode_plus("This is a simple input", return_tensors=TensorType.PYTORCH)
fx_inputs = tokenizer.encode_plus("This is a simple input", return_tensors=TensorType.JAX)
pt_outputs = pt_model(**pt_inputs)
fx_outputs = fx_model(**fx_inputs)
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output, pt_output.numpy(), 5e-4)
    def assert_almost_equals(self, a: ndarray, b: ndarray, tol: float):
        # Use the largest absolute element-wise difference so errors of opposite sign cannot cancel out.
        diff = abs(a - b).max()
        self.assertLessEqual(diff, tol, "Difference between torch and flax is {} (>= {})".format(diff, tol))
import unittest
from numpy import ndarray
from transformers import TensorType, is_flax_available, is_torch_available
from transformers.testing_utils import require_flax, require_torch
from transformers.tokenization_roberta_fast import RobertaTokenizerFast
if is_flax_available():
from transformers.modeling_flax_roberta import FlaxRobertaModel
if is_torch_available():
import torch
from transformers.modeling_roberta import RobertaModel
@require_flax
@require_torch
class FlaxRobertaModelTest(unittest.TestCase):
def test_from_pytorch(self):
with torch.no_grad():
with self.subTest("roberta-base"):
tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
fx_model = FlaxRobertaModel.from_pretrained("roberta-base")
pt_model = RobertaModel.from_pretrained("roberta-base")
# Check for simple input
pt_inputs = tokenizer.encode_plus("This is a simple input", return_tensors=TensorType.PYTORCH)
fx_inputs = tokenizer.encode_plus("This is a simple input", return_tensors=TensorType.JAX)
pt_outputs = pt_model(**pt_inputs)
fx_outputs = fx_model(**fx_inputs)
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output, pt_output.numpy(), 5e-4)
    def assert_almost_equals(self, a: ndarray, b: ndarray, tol: float):
        # Use the largest absolute element-wise difference so errors of opposite sign cannot cancel out.
        diff = abs(a - b).max()
        self.assertLessEqual(diff, tol, "Difference between torch and flax is {} (>= {})".format(diff, tol))