Unverified Commit 640e6fe1 authored by Patrick von Platen, committed by GitHub

[Flax] Align FlaxBertForMaskedLM with BertForMaskedLM, implement from_pretrained, init (#9054)



* save intermediate

* save intermediate

* save intermediate

* correct flax bert model file

* new module / model naming

* make style

* almost finish BERT

* finish roberta

* make fix-copies

* delete keys file

* last refactor

* fixes in run_mlm_flax.py

* remove pooled from run_mlm_flax.py

* fix gelu | gelu_new

* remove Module from inits

* splits

* dirty print

* preventing warmup_steps == 0

* smaller splits

* make fix-copies

* dirty print

* dirty print

* initial_evaluation argument

* declaration order fix

* proper model initialization/loading

* proper initialization

* run_mlm_flax improvements: improper model inputs bugfix + automatic dataset splitting + tokenizers parallelism warning + avoiding warmup_steps=0 bug

* removed tokenizers warning hack, fixed model re-initialization

* reverted training_args.py changes

* fix flax from pretrained

* improve test in flax

* apply sylvains tips

* update init

* make 0.3.0 compatible

* revert tevens changes

* revert tevens changes 2

* finalize revert

* fix bug

* add docs

* add pretrained to init

* Update src/transformers/modeling_flax_utils.py

* fix copies

* final improvements
Co-authored-by: TevenLeScao <teven.lescao@gmail.com>
parent 51adb97c
......@@ -13,9 +13,10 @@
Models
-----------------------------------------------------------------------------------------------------------------------
The base classes :class:`~transformers.PreTrainedModel` and :class:`~transformers.TFPreTrainedModel` implement the
common methods for loading/saving a model either from a local file or directory, or from a pretrained model
configuration provided by the library (downloaded from HuggingFace's AWS S3 repository).
The base classes :class:`~transformers.PreTrainedModel`, :class:`~transformers.TFPreTrainedModel`, and
:class:`~transformers.FlaxPreTrainedModel` implement the common methods for loading/saving a model either from a local
file or directory, or from a pretrained model configuration provided by the library (downloaded from HuggingFace's AWS
S3 repository).
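As a minimal sketch of these shared loading/saving methods (assuming the Flax backend is installed and the ``bert-base-cased`` checkpoint is reachable), all three base classes expose the same ``from_pretrained``/``save_pretrained`` entry points:
# Minimal sketch of the shared loading/saving API (assumes flax is installed).
from transformers import FlaxBertModel
model = FlaxBertModel.from_pretrained("bert-base-cased")   # download and cache the weights
model.save_pretrained("./my_model_directory")              # writes config.json and flax_model.msgpack
reloaded = FlaxBertModel.from_pretrained("./my_model_directory")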
:class:`~transformers.PreTrainedModel` and :class:`~transformers.TFPreTrainedModel` also implement a few methods which
are common among all the models to:
......@@ -57,6 +58,13 @@ TFModelUtilsMixin
:members:
FlaxPreTrainedModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxPreTrainedModel
:members:
Generation
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......
......@@ -385,7 +385,7 @@ def training_step(optimizer, batch, dropout_rng):
# Hide away tokens which don't participate in the optimization
token_mask = jnp.where(targets > 0, 1.0, 0.0)
pooled, logits = model(**batch, params=params, dropout_rng=dropout_rng, train=True)
logits = model(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
loss, weight_sum = cross_entropy(logits, targets, token_mask)
return loss / weight_sum
......@@ -407,7 +407,7 @@ def eval_step(params, batch):
# Hide away tokens which don't participate in the optimization
token_mask = jnp.where(targets > 0, 1.0, 0.0)
_, logits = model(**batch, params=params, train=False)
logits = model(**batch, params=params, train=False)[0]
return compute_metrics(logits, targets, token_mask)
......@@ -572,8 +572,13 @@ if __name__ == "__main__":
rng = jax.random.PRNGKey(training_args.seed)
dropout_rngs = jax.random.split(rng, jax.local_device_count())
model = FlaxBertForMaskedLM.from_pretrained("bert-base-cased", dtype=jnp.float32, dropout_rate=0.1)
model.init(jax.random.PRNGKey(training_args.seed), (training_args.train_batch_size, model.config.max_length))
model = FlaxBertForMaskedLM.from_pretrained(
"bert-base-cased",
dtype=jnp.float32,
input_shape=(training_args.train_batch_size, config.max_position_embeddings),
seed=training_args.seed,
dropout_rate=0.1,
)
# Setup optimizer
optimizer = Adam(
......
......@@ -97,7 +97,7 @@ _deps = [
"fastapi",
"filelock",
"flake8>=3.8.3",
"flax==0.2.2",
"flax>=0.2.2",
"fugashi>=1.0",
"ipadic>=1.0.0,<2.0",
"isort>=5.5.4",
......@@ -175,7 +175,7 @@ class DepsTableUpdateCommand(Command):
"deps = {",
entries,
"}",
""
"",
]
target = "src/transformers/dependency_versions_table.py"
print(f"updating {target}")
......
......@@ -945,6 +945,7 @@ else:
if is_flax_available():
from .modeling_flax_utils import FlaxPreTrainedModel
from .models.auto import FLAX_MODEL_MAPPING, FlaxAutoModel
from .models.bert import FlaxBertForMaskedLM, FlaxBertModel
from .models.roberta import FlaxRobertaModel
......
......@@ -10,7 +10,7 @@ deps = {
"fastapi": "fastapi",
"filelock": "filelock",
"flake8": "flake8>=3.8.3",
"flax": "flax==0.2.2",
"flax": "flax>=0.2.2",
"fugashi": "fugashi>=1.0",
"ipadic": "ipadic>=1.0.0,<2.0",
"isort": "isort>=5.5.4",
......@@ -40,8 +40,8 @@ deps = {
"sphinx-rtd-theme": "sphinx-rtd-theme==0.4.3",
"sphinx": "sphinx==3.2.1",
"starlette": "starlette",
"tensorflow-cpu": "tensorflow-cpu>=2.0",
"tensorflow": "tensorflow>=2.0",
"tensorflow-cpu": "tensorflow-cpu>=2.0,<2.4",
"tensorflow": "tensorflow>=2.0,<2.4",
"timeout-decorator": "timeout-decorator",
"tokenizers": "tokenizers==0.9.4",
"torch": "torch>=1.0",
......
......@@ -270,6 +270,7 @@ TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE)
WEIGHTS_NAME = "pytorch_model.bin"
TF2_WEIGHTS_NAME = "tf_model.h5"
TF_WEIGHTS_NAME = "model.ckpt"
FLAX_WEIGHTS_NAME = "flax_model.msgpack"
CONFIG_NAME = "config.json"
MODEL_CARD_NAME = "modelcard.json"
......
......@@ -15,64 +15,65 @@
import os
from abc import ABC, abstractmethod
from functools import partial
from pickle import UnpicklingError
from typing import Dict
from typing import Dict, Set, Tuple, Union
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.serialization import to_bytes
from flax.traverse_util import unflatten_dict
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.serialization import from_bytes, to_bytes
from flax.traverse_util import flatten_dict, unflatten_dict
from jax.random import PRNGKey
from .configuration_utils import PretrainedConfig
from .file_utils import WEIGHTS_NAME, cached_path, hf_bucket_url, is_remote_url
from .file_utils import FLAX_WEIGHTS_NAME, WEIGHTS_NAME, cached_path, hf_bucket_url, is_remote_url
from .utils import logging
logger = logging.get_logger(__name__)
@jax.jit
def gelu(x):
r"""
Gaussian error linear unit activation function.
Computes the element-wise function:
.. math::
\mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{erf} \left( \frac{x}{\sqrt{2}} \right) \right)
We explicitly use the exact erf-based formulation here; the tanh approximation is exposed separately as ``gelu_new``. For more
information, see `Gaussian Error Linear Units (GELUs) <https://arxiv.org/abs/1606.08415>`_, section 2.
"""
return x * 0.5 * (1.0 + jax.lax.erf(x / jnp.sqrt(2.0)))
ACT2FN = {
"gelu": nn.gelu,
"relu": nn.relu,
"silu": nn.swish,
"swish": nn.swish,
"gelu_new": gelu,
"gelu_new": partial(nn.gelu, approximate=True),
}
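A hedged sketch of the distinction encoded here: the custom ``gelu`` defined above uses the exact erf formulation, while ``gelu_new`` maps to the tanh approximation. Both variants are available directly through ``jax.nn.gelu``:
# Hedged sketch: compare the exact (erf-based) GELU with the tanh approximation.
import jax.numpy as jnp
from jax.nn import gelu as jax_gelu
x = jnp.linspace(-3.0, 3.0, 7)
exact = jax_gelu(x, approximate=False)   # erf formulation, as in the custom gelu above
approx = jax_gelu(x, approximate=True)   # tanh approximation, used for "gelu_new"
print(jnp.max(jnp.abs(exact - approx)))  # the two agree to within ~1e-3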
class FlaxPreTrainedModel(ABC):
r"""
Base class for all models.
:class:`~transformers.FlaxPreTrainedModel` takes care of storing the configuration of the models and handles
methods for loading, downloading and saving models.
Class attributes (overridden by derived classes):
- **config_class** (:class:`~transformers.PretrainedConfig`) -- A subclass of
:class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
- **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
derived classes of the same architecture adding modules on top of the base model.
"""
config_class = None
pretrained_model_archive_map = {}
base_model_prefix = ""
model_class = None
def __init__(
self, config: PretrainedConfig, module: nn.Module, params: Dict, seed: int = 0, dtype: jnp.dtype = jnp.float32
self,
config: PretrainedConfig,
module: nn.Module,
input_shape: Tuple = (1, 1),
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
):
if config is None:
raise ValueError("config cannot be None")
if params is None:
raise ValueError("state cannot be None")
if module is None:
raise ValueError("module cannot be None")
# Those are private to be exposed as typed property on derived classes.
self._config = config
......@@ -80,9 +81,18 @@ class FlaxPreTrainedModel(ABC):
# Those are public as their type is generic to every derived classes.
self.key = PRNGKey(seed)
self.params = params
self.dtype = dtype
# randomly initialized parameters
random_params = self.init(self.key, input_shape)
# save required_params as set
self._required_params = set(flatten_dict(unfreeze(random_params)).keys())
self.params = random_params
def init(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> Dict:
raise NotImplementedError(f"init method has to be implemented for {self}")
@property
def config(self) -> PretrainedConfig:
return self._config
......@@ -91,24 +101,130 @@ class FlaxPreTrainedModel(ABC):
def module(self) -> nn.Module:
return self._module
@property
def params(self) -> Union[Dict, FrozenDict]:
return self._params
@property
def required_params(self) -> Set:
return self._required_params
@params.setter
def params(self, params: Union[Dict, FrozenDict]):
if isinstance(params, FrozenDict):
params = unfreeze(params)
param_keys = set(flatten_dict(params).keys())
if len(self.required_params - param_keys) > 0:
raise ValueError(
"Some parameters are missing. Make sure that `params` include the following "
f"parameters {self.required_params - param_keys}"
)
self._params = freeze(params)
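A hedged sketch of how this setter behaves (the top-level key names are illustrative and depend on the architecture): a complete parameter tree is re-frozen and stored, while a tree with missing keys raises a ValueError listing them.
# Hedged sketch: the setter rejects incomplete parameter trees.
from flax.core.frozen_dict import unfreeze
from transformers import FlaxBertModel
model = FlaxBertModel.from_pretrained("bert-base-cased")
full_params = unfreeze(model.params)
model.params = full_params                                    # ok, every required key is present
try:
    model.params = {"embeddings": full_params["embeddings"]}  # encoder/pooler keys missing
except ValueError as err:
    print(err)                                                # lists the missing parameter keys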
@staticmethod
@abstractmethod
def convert_from_pytorch(pt_state: Dict, config: PretrainedConfig) -> Dict:
raise NotImplementedError()
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, dtype: jnp.dtype = jnp.float32, *model_args, **kwargs):
def from_pretrained(
cls,
pretrained_model_name_or_path: Union[str, os.PathLike],
dtype: jnp.dtype = jnp.float32,
*model_args,
**kwargs
):
r"""
Instantiate a pretrained Flax model from a pre-trained model configuration.
Instantiate a pretrained flax model from a pre-trained model configuration.
The warning `Weights from XXX not initialized from pretrained model` means that the weights of XXX do not come
pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
task.
The warning `Weights from XXX not used in YYY` means that the layer XXX is not used by YYY, therefore those
weights are discarded.
Parameters:
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
Can be either:
- A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
a user or organization name, like ``dbmdz/bert-base-german-cased``.
- A path to a `directory` containing model weights saved using
:func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
- A path or url to a `pt index checkpoint file` (e.g, ``./tf_model/model.ckpt.index``). In this
case, ``from_pt`` should be set to :obj:`True`.
model_args (sequence of positional arguments, `optional`):
All remaining positional arguments will be passed to the underlying model's ``__init__`` method.
config (:obj:`Union[PretrainedConfig, str, os.PathLike]`, `optional`):
Can be either:
- an instance of a class derived from :class:`~transformers.PretrainedConfig`,
- a string or path valid as input to :func:`~transformers.PretrainedConfig.from_pretrained`.
Configuration for the model to use instead of an automatically loaded configuration. Configuration can
be automatically loaded when:
- The model is a model provided by the library (loaded with the `model id` string of a pretrained
model).
- The model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded
by supplying the save directory.
- The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
configuration JSON file named `config.json` is found in the directory.
cache_dir (:obj:`Union[str, os.PathLike]`, `optional`):
Path to a directory in which a downloaded pretrained model configuration should be cached if the
standard cache should not be used.
from_pt (:obj:`bool`, `optional`, defaults to :obj:`False`):
Load the model weights from a PyTorch checkpoint save file (see docstring of
``pretrained_model_name_or_path`` argument).
force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
file exists.
proxies (:obj:`Dict[str, str]`, `optional`):
A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to only look at local files (i.e., do not try to download the model).
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
identifier allowed by git.
kwargs (remaining dictionary of keyword arguments, `optional`):
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
:obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
automatically loaded:
- If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the
underlying model's ``__init__`` method (we assume all relevant updates to the configuration have
already been done)
- If a configuration is not provided, ``kwargs`` will be first passed to the configuration class
initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of
``kwargs`` that corresponds to a configuration attribute will be used to override said attribute
with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration
attribute will be passed to the underlying model's ``__init__`` function.
Examples::
>>> from transformers import BertConfig, FlaxBertModel
>>> # Download model and configuration from huggingface.co and cache.
>>> model = FlaxBertModel.from_pretrained('bert-base-cased')
>>> # Model was saved using `save_pretrained('./test/saved_model/')` (for example purposes, not runnable).
>>> model = FlaxBertModel.from_pretrained('./test/saved_model/')
>>> # Loading from a PyTorch checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable).
>>> config = BertConfig.from_json_file('./pt_model/config.json')
>>> model = FlaxBertModel.from_pretrained('./pt_model/pytorch_model.bin', from_pt=True, config=config)
"""
config = kwargs.pop("config", None)
# state_dict = kwargs.pop("state_dict", None)
cache_dir = kwargs.pop("cache_dir", None)
# from_tf = kwargs.pop("from_tf", False)
from_pt = kwargs.pop("from_pt", False)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
# output_loading_info = kwargs.pop("output_loading_info", False)
local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None)
......@@ -135,10 +251,28 @@ class FlaxPreTrainedModel(ABC):
# Load model
if pretrained_model_name_or_path is not None:
if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
if os.path.isdir(pretrained_model_name_or_path):
if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
# Load from a PyTorch checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
# Load from a Flax checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
else:
raise EnvironmentError(
"Error no file named {} found in directory {} or `from_pt` set to False".format(
[FLAX_WEIGHTS_NAME, WEIGHTS_NAME],
pretrained_model_name_or_path,
)
)
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
archive_file = pretrained_model_name_or_path
else:
archive_file = hf_bucket_url(pretrained_model_name_or_path, filename=WEIGHTS_NAME, revision=revision)
archive_file = hf_bucket_url(
pretrained_model_name_or_path,
filename=WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME,
revision=revision,
)
# redirect to the cache, if necessary
try:
......@@ -169,31 +303,96 @@ class FlaxPreTrainedModel(ABC):
# Instantiate model.
with open(resolved_archive_file, "rb") as state_f:
try:
from flax.serialization import from_bytes
state = from_bytes(cls.model_class, state_f)
except TypeError:
try:
if from_pt:
import torch
state = torch.load(state_f)
state = {k: v.numpy() for k, v in state.items()}
state = cls.convert_from_pytorch(state, config)
state = unflatten_dict({tuple(k.split(".")[1:]): v for k, v in state.items()})
state = convert_state_dict_from_pt(cls, state, config)
else:
state = from_bytes(cls, state_f.read())
except UnpicklingError:
raise EnvironmentError(
f"Unable to convert model {archive_file} to Flax deserializable object. "
"Supported format are PyTorch archive or Flax msgpack"
f"Unable to convert pytorch model {archive_file} to Flax deserializable object. "
)
# init random models
model = cls(config, *model_args, **model_kwargs)
# if the target model is the base model, keep only the checkpoint's base_model_prefix sub-tree
if cls.base_model_prefix not in dict(model.params) and cls.base_model_prefix in state:
state = state[cls.base_model_prefix]
# flatten dicts
state = flatten_dict(state)
random_state = flatten_dict(unfreeze(model.params))
missing_keys = model.required_params - set(state.keys())
unexpected_keys = set(state.keys()) - model.required_params
# add missing keys as random parameters
for missing_key in missing_keys:
state[missing_key] = random_state[missing_key]
if len(unexpected_keys) > 0:
logger.warning(
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n"
f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
)
else:
logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
return cls(config, state, *model_args, **model_kwargs)
if len(missing_keys) > 0:
logger.warning(
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
f"and are newly initialized: {missing_keys}\n"
f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
)
else:
logger.info(
f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
f"If your task is similar to the task the model of the checkpoint was trained on, "
f"you can already use {model.__class__.__name__} for predictions without further training."
)
def save_pretrained(self, folder):
folder_abs = os.path.abspath(folder)
# set correct parameters
model.params = unflatten_dict(state)
return model
if not os.path.exists(folder_abs):
os.mkdir(folder_abs)
def save_pretrained(self, save_directory: Union[str, os.PathLike]):
"""
Save a model and its configuration file to a directory, so that it can be re-loaded using the
:func:`~transformers.FlaxPreTrainedModel.from_pretrained` class method.
with open(os.path.join(folder_abs, f"{self._config.model_type}.flax", "wb")) as f:
Arguments:
save_directory (:obj:`str` or :obj:`os.PathLike`):
Directory to which to save. Will be created if it doesn't exist.
"""
if os.path.isfile(save_directory):
logger.error("Provided path ({}) should be a directory, not a file".format(save_directory))
return
os.makedirs(save_directory, exist_ok=True)
# get abs dir
save_directory = os.path.abspath(save_directory)
# save config as well
self.config.save_pretrained(save_directory)
# save model
with open(os.path.join(save_directory, FLAX_WEIGHTS_NAME), "wb") as f:
model_bytes = to_bytes(self.params)
f.write(model_bytes)
def convert_state_dict_from_pt(model_class: ABC, state: Dict, config: PretrainedConfig):
"""
Converts a PyTorch parameter state dict to an equivalent Flax parameter state dict
"""
state = {k: v.numpy() for k, v in state.items()}
state = model_class.convert_from_pytorch(state, config)
state = unflatten_dict({tuple(k.split(".")): v for k, v in state.items()})
return state
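A hedged usage sketch of this helper, mirroring what the common tests below do (assumes torch and the PyTorch ``BertModel`` are available; the helper is importable from ``transformers.modeling_flax_utils``):
# Hedged sketch: turn a PyTorch state dict into a Flax parameter tree and attach it.
from transformers import BertConfig, BertModel, FlaxBertModel
from transformers.modeling_flax_utils import convert_state_dict_from_pt
config = BertConfig()
pt_model = BertModel(config).eval()
fx_state = convert_state_dict_from_pt(FlaxBertModel, pt_model.state_dict(), config)
fx_model = FlaxBertModel(config)
fx_model.params = fx_state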
......@@ -27,15 +27,6 @@ from .configuration_auto import AutoConfig, BertConfig, RobertaConfig
logger = logging.get_logger(__name__)
ALL_PRETRAINED_MODEL_ARCHIVE_MAP = dict(
(key, value)
for pretrained_map in [
FlaxBertModel.pretrained_model_archive_map,
FlaxRobertaModel.pretrained_model_archive_map,
]
for key, value, in pretrained_map.items()
)
FLAX_MODEL_MAPPING = OrderedDict(
[
(RobertaConfig, FlaxRobertaModel),
......@@ -114,10 +105,9 @@ class FlaxAutoModel(object):
organization name, like ``dbmdz/bert-base-german-cased``.
- a path to a `directory` containing model weights saved using
:func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
- a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this
case, ``from_tf`` should be set to True and a configuration object should be provided as ``config``
argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model
using the provided conversion scripts and loading the PyTorch model afterwards.
- a path or url to a `pytorch index checkpoint file` (e.g. `./pt_model/pytorch_model.bin`). In this
case, ``from_pt`` should be set to True and a configuration object should be provided as ``config``
argument.
model_args: (`optional`) Sequence of positional arguments:
All remaining positional arguments will be passed to the underlying model's ``__init__`` method
......@@ -133,13 +123,6 @@ class FlaxAutoModel(object):
- the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
configuration JSON file named `config.json` is found in the directory.
state_dict: (`optional`) dict:
an optional state dictionary for the model to use instead of a state dictionary loaded from saved
weights file. This option can be used if you want to create a model from a pretrained configuration but
load your own weights. In this case though, you should check if using
:func:`~transformers.FlaxPreTrainedModel.save_pretrained` and
:func:`~transformers.FlaxPreTrainedModel.from_pretrained` is not a simpler option.
cache_dir: (`optional`) string:
Path to a directory in which a downloaded pre-trained model configuration should be cached if the
standard cache should not be used.
......
......@@ -19,10 +19,11 @@ import numpy as np
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from jax.random import PRNGKey
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_flax_utils import FlaxPreTrainedModel, gelu
from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel
from ...utils import logging
from .configuration_roberta import RobertaConfig
......@@ -33,6 +34,23 @@ _CONFIG_FOR_DOC = "RobertaConfig"
_TOKENIZER_FOR_DOC = "RobertaTokenizer"
def create_position_ids_from_input_ids(input_ids, padding_idx):
"""
Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
are ignored. This is modified from fairseq's `utils.make_positions`.
Args:
input_ids: jnp.ndarray
padding_idx: int
Returns: jnp.ndarray
"""
# The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
mask = (input_ids != padding_idx).astype("i4")
incremental_indices = jnp.cumsum(mask, axis=1).astype("i4") * mask
return incremental_indices.astype("i4") + padding_idx
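A hedged worked example (``padding_idx=1``, the usual RoBERTa padding id): padding positions keep ``padding_idx`` and real tokens are numbered from ``padding_idx + 1``.
# Hedged worked example of the position-id construction above.
import jax.numpy as jnp
input_ids = jnp.array([[1, 31, 7, 42, 1, 1]])        # 1 is the padding idx
print(create_position_ids_from_input_ids(input_ids, padding_idx=1))
# [[1 2 3 4 1 1]]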
ROBERTA_START_DOCSTRING = r"""
This model inherits from :class:`~transformers.FlaxPreTrainedModel`. Check the superclass documentation for the
......@@ -208,7 +226,7 @@ class FlaxRobertaAttention(nn.Module):
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def __call__(self, hidden_state, attention_mask, deterministic: bool = True):
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
# Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
# FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
# with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
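A hedged sketch of the reshape described in the comment above, broadcasting a (batch, kv_length) mask to (batch, 1, 1, kv_length):
# Hedged sketch: make a 2D attention mask broadcastable against attention weights.
import jax.numpy as jnp
attention_mask = jnp.ones((2, 16))                      # (*batch_sizes, kv_length)
broadcastable_mask = attention_mask[:, None, None, :]   # (*batch_sizes, 1, 1, kv_length)
print(broadcastable_mask.shape)                         # (2, 1, 1, 16)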
......@@ -222,28 +240,29 @@ class FlaxRobertaAttention(nn.Module):
bias_init=jax.nn.initializers.zeros,
name="self",
dtype=self.dtype,
)(hidden_state, attention_mask)
)(hidden_states, attention_mask)
layer_norm = FlaxRobertaLayerNorm(name="layer_norm", dtype=self.dtype)(self_att + hidden_state)
layer_norm = FlaxRobertaLayerNorm(name="layer_norm", dtype=self.dtype)(self_att + hidden_states)
return layer_norm
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertIntermediate with Bert->Roberta
class FlaxRobertaIntermediate(nn.Module):
output_size: int
hidden_act: str = "gelu"
kernel_init_scale: float = 0.2
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def __call__(self, hidden_state):
# TODO: Add ACT2FN reference to change activation function
dense = nn.Dense(
def __call__(self, hidden_states):
hidden_states = nn.Dense(
features=self.output_size,
kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
name="dense",
dtype=self.dtype,
)(hidden_state)
return gelu(dense)
)(hidden_states)
hidden_states = ACT2FN[self.hidden_act](hidden_states)
return hidden_states
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertOutput with Bert->Roberta
......@@ -254,27 +273,28 @@ class FlaxRobertaOutput(nn.Module):
@nn.compact
def __call__(self, intermediate_output, attention_output, deterministic: bool = True):
hidden_state = nn.Dense(
hidden_states = nn.Dense(
attention_output.shape[-1],
kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
name="dense",
dtype=self.dtype,
)(intermediate_output)
hidden_state = nn.Dropout(rate=self.dropout_rate)(hidden_state, deterministic=deterministic)
hidden_state = FlaxRobertaLayerNorm(name="layer_norm", dtype=self.dtype)(hidden_state + attention_output)
return hidden_state
hidden_states = nn.Dropout(rate=self.dropout_rate)(hidden_states, deterministic=deterministic)
hidden_states = FlaxRobertaLayerNorm(name="layer_norm", dtype=self.dtype)(hidden_states + attention_output)
return hidden_states
class FlaxRobertaLayer(nn.Module):
num_heads: int
head_size: int
intermediate_size: int
hidden_act: str = "gelu"
dropout_rate: float = 0.0
kernel_init_scale: float = 0.2
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def __call__(self, hidden_state, attention_mask, deterministic: bool = True):
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
attention = FlaxRobertaAttention(
self.num_heads,
self.head_size,
......@@ -282,10 +302,11 @@ class FlaxRobertaLayer(nn.Module):
dropout_rate=self.dropout_rate,
name="attention",
dtype=self.dtype,
)(hidden_state, attention_mask, deterministic=deterministic)
)(hidden_states, attention_mask, deterministic=deterministic)
intermediate = FlaxRobertaIntermediate(
self.intermediate_size,
kernel_init_scale=self.kernel_init_scale,
hidden_act=self.hidden_act,
name="intermediate",
dtype=self.dtype,
)(attention)
......@@ -306,6 +327,7 @@ class FlaxRobertaLayerCollection(nn.Module):
num_heads: int
head_size: int
intermediate_size: int
hidden_act: str = "gelu"
dropout_rate: float = 0.0
kernel_init_scale: float = 0.2
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
......@@ -325,6 +347,7 @@ class FlaxRobertaLayerCollection(nn.Module):
self.intermediate_size,
kernel_init_scale=self.kernel_init_scale,
dropout_rate=self.dropout_rate,
hidden_act=self.hidden_act,
name=f"{i}",
dtype=self.dtype,
)
......@@ -338,22 +361,24 @@ class FlaxRobertaEncoder(nn.Module):
num_heads: int
head_size: int
intermediate_size: int
hidden_act: str = "gelu"
dropout_rate: float = 0.0
kernel_init_scale: float = 0.2
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def __call__(self, hidden_state, attention_mask, deterministic: bool = True):
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
layer = FlaxRobertaLayerCollection(
self.num_layers,
self.num_heads,
self.head_size,
self.intermediate_size,
hidden_act=self.hidden_act,
kernel_init_scale=self.kernel_init_scale,
dropout_rate=self.dropout_rate,
name="layer",
dtype=self.dtype,
)(hidden_state, attention_mask, deterministic=deterministic)
)(hidden_states, attention_mask, deterministic=deterministic)
return layer
......@@ -363,10 +388,10 @@ class FlaxRobertaPooler(nn.Module):
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def __call__(self, hidden_state):
cls_token = hidden_state[:, 0]
def __call__(self, hidden_states):
cls_token = hidden_states[:, 0]
out = nn.Dense(
hidden_state.shape[-1],
hidden_states.shape[-1],
kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
name="dense",
dtype=self.dtype,
......@@ -374,64 +399,12 @@ class FlaxRobertaPooler(nn.Module):
return nn.tanh(out)
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertModule with Bert->Roberta
class FlaxRobertaModule(nn.Module):
vocab_size: int
hidden_size: int
type_vocab_size: int
max_length: int
num_encoder_layers: int
num_heads: int
head_size: int
intermediate_size: int
dropout_rate: float = 0.0
kernel_init_scale: float = 0.2
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True):
# Embedding
embeddings = FlaxRobertaEmbeddings(
self.vocab_size,
self.hidden_size,
self.type_vocab_size,
self.max_length,
kernel_init_scale=self.kernel_init_scale,
dropout_rate=self.dropout_rate,
name="embeddings",
dtype=self.dtype,
)(input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic)
# N stacked encoding layers
encoder = FlaxRobertaEncoder(
self.num_encoder_layers,
self.num_heads,
self.head_size,
self.intermediate_size,
kernel_init_scale=self.kernel_init_scale,
dropout_rate=self.dropout_rate,
name="encoder",
dtype=self.dtype,
)(embeddings, attention_mask, deterministic=deterministic)
pooled = FlaxRobertaPooler(kernel_init_scale=self.kernel_init_scale, name="pooler", dtype=self.dtype)(encoder)
return encoder, pooled
@add_start_docstrings(
"The bare RoBERTa Model transformer outputting raw hidden-states without any specific head on top.",
ROBERTA_START_DOCSTRING,
)
class FlaxRobertaModel(FlaxPreTrainedModel):
class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
"""
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
cross-attention is added between the self-attention layers, following the architecture described in `Attention is
all you need`_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz
Kaiser and Illia Polosukhin.
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
model_class = FlaxRobertaModule
config_class = RobertaConfig
base_model_prefix = "roberta"
......@@ -504,7 +477,49 @@ class FlaxRobertaModel(FlaxPreTrainedModel):
return jax_state
def __init__(self, config: RobertaConfig, state: dict, seed: int = 0, dtype: jnp.dtype = jnp.float32):
def init(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
jnp.zeros(input_shape, dtype="i4"), None, None, None
)
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids)["params"]
def _check_inputs(self, input_ids, attention_mask, token_type_ids, position_ids):
if token_type_ids is None:
token_type_ids = jnp.ones_like(input_ids)
if position_ids is None:
position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id)
if attention_mask is None:
attention_mask = jnp.ones_like(input_ids)
return input_ids, attention_mask, token_type_ids, position_ids
@add_start_docstrings(
"The bare RoBERTa Model transformer outputting raw hidden-states without any specific head on top.",
ROBERTA_START_DOCSTRING,
)
class FlaxRobertaModel(FlaxRobertaPreTrainedModel):
"""
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
cross-attention is added between the self-attention layers, following the architecture described in `Attention is
all you need`_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz
Kaiser and Illia Polosukhin.
"""
def __init__(
self,
config: RobertaConfig,
input_shape: Tuple = (1, 1),
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
**kwargs
):
module = FlaxRobertaModule(
vocab_size=config.vocab_size,
hidden_size=config.hidden_size,
......@@ -513,12 +528,14 @@ class FlaxRobertaModel(FlaxPreTrainedModel):
num_encoder_layers=config.num_hidden_layers,
num_heads=config.num_attention_heads,
head_size=config.hidden_size,
hidden_act=config.hidden_act,
intermediate_size=config.intermediate_size,
dropout_rate=config.hidden_dropout_prob,
dtype=dtype,
**kwargs,
)
super().__init__(config, module, state, seed)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
@add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
def __call__(
......@@ -550,42 +567,53 @@ class FlaxRobertaModel(FlaxPreTrainedModel):
rngs=rngs,
)
def init(self, rng: jax.random.PRNGKey, input_shape: Tuple):
input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
jnp.zeros(input_shape, dtype="i4"), None, None, None
)
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
self.params = self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids)["params"]
def _check_inputs(self, input_ids, attention_mask, token_type_ids, position_ids):
if token_type_ids is None:
token_type_ids = jnp.ones_like(input_ids)
if position_ids is None:
position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id)
if attention_mask is None:
attention_mask = jnp.ones_like(input_ids)
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertModule with Bert->Roberta
class FlaxRobertaModule(nn.Module):
vocab_size: int
hidden_size: int
type_vocab_size: int
max_length: int
num_encoder_layers: int
num_heads: int
head_size: int
intermediate_size: int
hidden_act: str = "gelu"
dropout_rate: float = 0.0
kernel_init_scale: float = 0.2
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
add_pooling_layer: bool = True
return input_ids, attention_mask, token_type_ids, position_ids
@nn.compact
def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True):
# Embedding
embeddings = FlaxRobertaEmbeddings(
self.vocab_size,
self.hidden_size,
self.type_vocab_size,
self.max_length,
kernel_init_scale=self.kernel_init_scale,
dropout_rate=self.dropout_rate,
name="embeddings",
dtype=self.dtype,
)(input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic)
def create_position_ids_from_input_ids(input_ids, padding_idx):
"""
Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
are ignored. This is modified from fairseq's `utils.make_positions`.
# N stacked encoding layers
encoder = FlaxRobertaEncoder(
self.num_encoder_layers,
self.num_heads,
self.head_size,
self.intermediate_size,
kernel_init_scale=self.kernel_init_scale,
dropout_rate=self.dropout_rate,
hidden_act=self.hidden_act,
name="encoder",
dtype=self.dtype,
)(embeddings, attention_mask, deterministic=deterministic)
Args:
input_ids: jnp.ndarray
padding_idx: int
if not self.add_pooling_layer:
return encoder
Returns: jnp.ndarray
"""
# The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
mask = (input_ids != padding_idx).astype("i4")
incremental_indices = jnp.cumsum(mask, axis=1).astype("i4") * mask
return incremental_indices.astype("i4") + padding_idx
pooled = FlaxRobertaPooler(kernel_init_scale=self.kernel_init_scale, name="pooler", dtype=self.dtype)(encoder)
return encoder, pooled
......@@ -2,6 +2,15 @@
from ..file_utils import requires_flax
class FlaxPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_flax(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_flax(self)
FLAX_MODEL_MAPPING = None
......
......@@ -14,14 +14,16 @@
import unittest
import numpy as np
from transformers import BertConfig, is_flax_available
from transformers.testing_utils import require_flax
from transformers.testing_utils import require_flax, slow
from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
if is_flax_available():
from transformers.models.bert.modeling_flax_bert import FlaxBertModel
from transformers.models.bert.modeling_flax_bert import FlaxBertForMaskedLM, FlaxBertModel
class FlaxBertModelTester(unittest.TestCase):
......@@ -105,7 +107,14 @@ class FlaxBertModelTester(unittest.TestCase):
@require_flax
class FlaxBertModelTest(FlaxModelTesterMixin, unittest.TestCase):
all_model_classes = (FlaxBertModel,) if is_flax_available() else ()
all_model_classes = (FlaxBertModel, FlaxBertForMaskedLM) if is_flax_available() else ()
def setUp(self):
self.model_tester = FlaxBertModelTester(self)
@slow
def test_model_from_pretrained(self):
for model_class_name in self.all_model_classes:
model = model_class_name.from_pretrained("bert-base-cased")
outputs = model(np.ones((1, 1)))
self.assertIsNotNone(outputs)
......@@ -13,6 +13,7 @@
# limitations under the License.
import random
import tempfile
import numpy as np
......@@ -26,7 +27,7 @@ if is_flax_available():
import jax
import jax.numpy as jnp
from flax.traverse_util import unflatten_dict
from transformers.modeling_flax_utils import convert_state_dict_from_pt
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8
......@@ -59,21 +60,13 @@ def random_attention_mask(shape, rng=None):
return attn_mask
def convert_pt_model_to_flax(pt_model, config, flax_model_cls):
state = pt_model.state_dict()
state = {k: v.numpy() for k, v in state.items()}
state = flax_model_cls.convert_from_pytorch(state, config)
state = unflatten_dict({tuple(k.split(".")): v for k, v in state.items()})
return flax_model_cls(config, state, dtype=jnp.float32)
@require_flax
class FlaxModelTesterMixin:
model_tester = None
all_model_classes = ()
def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
diff = np.abs((a - b)).sum()
diff = np.abs((a - b)).max()
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
@require_torch
......@@ -86,30 +79,54 @@ class FlaxModelTesterMixin:
pt_model_class = getattr(transformers, pt_model_class_name)
pt_model = pt_model_class(config).eval()
fx_model = convert_pt_model_to_flax(pt_model, config, model_class)
fx_state = convert_state_dict_from_pt(model_class, pt_model.state_dict(), config)
fx_model = model_class(config, dtype=jnp.float32)
fx_model.params = fx_state
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in inputs_dict.items()}
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
fx_outputs = fx_model(**inputs_dict)
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output, pt_output.numpy(), 5e-3)
self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-3)
@require_torch
def test_jit_compilation(self):
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True)
fx_outputs_loaded = fx_model_loaded(**inputs_dict)
self.assertEqual(
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
)
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 5e-3)
def test_from_pretrained_save_pretrained(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
model = model_class(config)
# TODO later: have some way to initialize easily a Flax model from config, for now I go through PT
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
pt_model_class = getattr(transformers, pt_model_class_name)
pt_model = pt_model_class(config).eval()
outputs = model(**inputs_dict)
model = convert_pt_model_to_flax(pt_model, config, model_class)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_loaded = model_class.from_pretrained(tmpdirname)
outputs_loaded = model_loaded(**inputs_dict)
for output_loaded, output in zip(outputs_loaded, outputs):
self.assert_almost_equals(output_loaded, output, 5e-3)
def test_jit_compilation(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
model = model_class(config)
@jax.jit
def model_jitted(input_ids, attention_mask=None, token_type_ids=None):
......@@ -125,3 +142,14 @@ class FlaxModelTesterMixin:
self.assertEqual(len(outputs), len(jitted_outputs))
for jitted_output, output in zip(jitted_outputs, outputs):
self.assertEqual(jitted_output.shape, output.shape)
def test_naming_convention(self):
for model_class in self.all_model_classes:
model_class_name = model_class.__name__
module_class_name = (
model_class_name[:-5] + "Module" if model_class_name[-5:] == "Model" else model_class_name + "Module"
)
bert_modeling_flax_module = __import__(model_class.__module__, fromlist=[module_class_name])
module_cls = getattr(bert_modeling_flax_module, module_class_name)
self.assertIsNotNone(module_cls)
......@@ -14,8 +14,10 @@
import unittest
import numpy as np
from transformers import RobertaConfig, is_flax_available
from transformers.testing_utils import require_flax
from transformers.testing_utils import require_flax, slow
from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
......@@ -109,3 +111,10 @@ class FlaxRobertaModelTest(FlaxModelTesterMixin, unittest.TestCase):
def setUp(self):
self.model_tester = FlaxRobertaModelTester(self)
@slow
def test_model_from_pretrained(self):
for model_class_name in self.all_model_classes:
model = model_class_name.from_pretrained("roberta-base")
outputs = model(np.ones((1, 1)))
self.assertIsNotNone(outputs)