Unverified Commit 4eef5889 authored by Teven's avatar Teven Committed by GitHub
Browse files

Adding performer fine-tuning research example (#9239)

* added run_mlm_performer.py research example

* make style

* make style

* Added a README !
parent 9a12b969
# Performer fine-tuning
Example authors: @TevenLeScao, @Patrickvonplaten
Paper authors: Krzysztof Choromanski, Valerii Likhosherstov, David Dohan, Xingyou Song, Andreea Gane, Tamas Sarlos, Peter Hawkins, Jared Davis, Afroz Mohiuddin, Lukasz Kaiser, David Belanger, Lucy Colwell, Adrian Weller
## Requirements
`datasets`, `flax` and `jax`. `wandb` integration is built-in if you want to use it.
## Examples
`sanity_script.sh` will launch performer fine-tuning from the bert-base-cased checkpoint on the Simple Wikipedia dataset (a small, easy-language English Wikipedia) from `datasets`.
`full_script.sh` will launch performer fine-tuning from the bert-large-cased checkpoint on the English Wikipedia dataset from `datasets`.
Here are a few key arguments:
- Remove the `--performer` argument to use a standard Bert model.
- Add `--reinitialize` to start from a blank model rather than a Bert checkpoint.
- You may change the Bert size by passing a different [checkpoint](https://huggingface.co/transformers/pretrained_models.html) to the `--model_name_or_path` argument.
- Passing your user name to the `--wandb_user_name` argument will trigger weights and biases logging.
- You can choose a dataset with `--dataset_name` and `--dataset_config_name`. Our [viewer](https://huggingface.co/datasets/viewer/) will help you find what you need.
\ No newline at end of file
# Launch masked-LM fine-tuning with the Performer attention enabled (--performer),
# starting from the bert-large-cased checkpoint on the `wikipedia` / `20200501.en` dataset.
TOKENIZERS_PARALLELISM=true python run_mlm_performer.py \
    --output_dir experiments \
    --dataset_name wikipedia \
    --dataset_config_name 20200501.en \
    --model_name_or_path bert-large-cased \
    --tokenizer_name bert-large-cased \
    --do_train \
    --overwrite_output_dir \
    --per_device_train_batch_size 4 \
    --learning_rate 5e-4 \
    --warmup_steps 100 \
    --num_train_epochs 3 \
    --performer
\ No newline at end of file
# coding=utf-8
# Copyright 2018 The Google Flax Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Dict, Tuple
import numpy as np
import flax.linen as nn
import jax
import jax.numpy as jnp
from jax.random import PRNGKey
from modeling_flax_performer_utils import make_fast_softmax_attention
from transformers.file_utils import add_start_docstrings
from transformers.modeling_flax_utils import ACT2FN
from transformers.models.bert.configuration_bert import BertConfig
from transformers.models.bert.modeling_flax_bert import FlaxBertOnlyMLMHead, FlaxBertPreTrainedModel
from transformers.utils import logging
# Module-level logger, obtained through the `transformers.utils.logging` helper imported above.
logger = logging.get_logger(__name__)
# Class names kept as plain strings; presumably consumed by docstring helpers -- they are not
# referenced by any code visible in this file.
_CONFIG_FOR_DOC = "BertConfig"
_TOKENIZER_FOR_DOC = "BertTokenizer"
# Shared introductory documentation, passed to `add_start_docstrings` on `FlaxPerformerModel` below.
# NOTE(review): the text describes a PyTorch `torch.nn.Module`, while the classes in this file are
# Flax (`flax.linen`) modules -- it looks copied from the PyTorch BERT implementation; confirm and reword.
BERT_START_DOCSTRING = r"""
This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
methods the library implements for all its model (such as downloading or saving, resizing the input embeddings,
pruning heads etc.)
This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
general usage and behavior.
Parameters:
config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
weights.
"""
# Documentation template for model inputs; `{0}` is a placeholder for the input shape.
# NOTE(review): not referenced by any code visible in this file, and it lists arguments
# (`head_mask`, `inputs_embeds`, `output_attentions`, ...) that the `__call__` methods below do not accept.
BERT_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using :class:`~transformers.BertTokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
details.
`What are input IDs? <../glossary.html#input-ids>`__
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
1]``:
- 0 corresponds to a `sentence A` token,
- 1 corresponds to a `sentence B` token.
`What are token type IDs? <../glossary.html#token-type-ids>`_
position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
config.max_position_embeddings - 1]``.
`What are position IDs? <../glossary.html#position-ids>`_
head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
vectors than the model's internal embedding lookup matrix.
output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`):
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
more detail.
return_dict (:obj:`bool`, `optional`):
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
"""
class FlaxPerformerLayerNorm(nn.Module):
    """
    Layer normalization (https://arxiv.org/abs/1607.06450) computed over the last axis of the input.
    """

    epsilon: float = 1e-6
    dtype: jnp.dtype = jnp.float32  # dtype the learned parameters are cast to before use
    bias: bool = True  # when True, a learned offset ("beta") is added
    scale: bool = True  # when True, a learned gain ("gamma") is applied; can be turned off when the
    # following layer is linear (also e.g. nn.relu), since that layer can absorb the scaling
    bias_init: jnp.ndarray = nn.initializers.zeros
    scale_init: jnp.ndarray = nn.initializers.ones

    @nn.compact
    def __call__(self, x):
        """
        Normalize ``x`` along its last axis, independently for every position of the leading axes (as opposed to
        batch normalization, which normalizes across the batch).

        Args:
            x: the inputs

        Returns:
            The normalized inputs, with the same shape as ``x``.
        """
        n_features = x.shape[-1]

        # Variance obtained from the first two moments: E[x^2] - E[x]^2.
        first_moment = jnp.mean(x, axis=-1, keepdims=True)
        second_moment = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
        variance = second_moment - jax.lax.square(first_moment)

        # The gain is folded into the inverse standard deviation so a single multiply is needed.
        inv_std = jax.lax.rsqrt(variance + self.epsilon)
        if self.scale:
            gamma = self.param("gamma", self.scale_init, (n_features,))
            inv_std = inv_std * jnp.asarray(gamma, self.dtype)

        normalized = (x - first_moment) * inv_std
        if not self.bias:
            return normalized

        beta = self.param("beta", self.bias_init, (n_features,))
        return normalized + jnp.asarray(beta, self.dtype)
class FlaxPerformerEmbedding(nn.Module):
    """
    Minimal embedding lookup whose table is stored under the parameter name 'weight' (the name PyTorch uses),
    rather than Flax's 'embedding'.
    """

    vocab_size: int
    hidden_size: int
    emb_init: Callable[..., np.ndarray] = nn.initializers.normal(stddev=0.1)

    @nn.compact
    def __call__(self, inputs):
        """Return the rows of the (vocab_size, hidden_size) table selected by the indices in ``inputs``."""
        table_shape = (self.vocab_size, self.hidden_size)
        table = self.param("weight", self.emb_init, table_shape)
        return jnp.take(table, inputs, axis=0)
class FlaxPerformerEmbeddings(nn.Module):
    """Sum of word, position and token-type embeddings, followed by layer normalization."""

    vocab_size: int
    hidden_size: int
    type_vocab_size: int
    max_length: int

    @nn.compact
    def __call__(self, input_ids, token_type_ids, position_ids, attention_mask):
        """
        Embed the three id inputs and combine them.

        ``attention_mask`` is accepted for signature compatibility with the caller but is not used here.
        """

        def as_int_matrix(ids):
            # Indices are cast to int32 and promoted to at least two dimensions before the lookup.
            return jnp.atleast_2d(ids.astype("i4"))

        words = FlaxPerformerEmbedding(self.vocab_size, self.hidden_size, name="word_embeddings")(
            as_int_matrix(input_ids)
        )
        positions = FlaxPerformerEmbedding(self.max_length, self.hidden_size, name="position_embeddings")(
            as_int_matrix(position_ids)
        )
        token_types = FlaxPerformerEmbedding(self.type_vocab_size, self.hidden_size, name="token_type_embeddings")(
            as_int_matrix(token_type_ids)
        )

        # Position embeddings are broadcast to the shape of the word embeddings before summing.
        combined = words + jnp.broadcast_to(positions, words.shape) + token_types
        return FlaxPerformerLayerNorm(name="layer_norm")(combined)
class FlaxPerformerAttention(nn.Module):
    """
    Self-attention sub-layer that plugs the fast softmax attention function into Flax's ``SelfAttention``, then
    applies a residual connection and layer normalization.

    ``head_size`` is used as the total query/key/value feature size; the per-head dimension handed to the fast
    attention factory is ``head_size // num_heads``.
    """

    num_heads: int
    head_size: int

    @nn.compact
    def __call__(self, hidden_state, attention_mask):
        per_head_dim = self.head_size // self.num_heads
        attention_fn = make_fast_softmax_attention(qkv_dim=per_head_dim)

        self_attention = nn.attention.SelfAttention(
            num_heads=self.num_heads,
            qkv_features=self.head_size,
            name="self",
            attention_fn=attention_fn,
        )
        attended = self_attention(hidden_state, attention_mask)

        # Residual connection around the attention, then normalization.
        return FlaxPerformerLayerNorm(name="layer_norm")(attended + hidden_state)
class FlaxPerformerIntermediate(nn.Module):
    """Dense projection to ``output_size`` features followed by the activation named by ``hidden_act``."""

    output_size: int
    hidden_act: str = "gelu"

    @nn.compact
    def __call__(self, hidden_state):
        projected = nn.Dense(features=self.output_size, name="dense")(hidden_state)
        # The activation is looked up by name in the ACT2FN table.
        activation = ACT2FN[self.hidden_act]
        return activation(projected)
class FlaxPerformerOutput(nn.Module):
    """Project the intermediate activations back to the attention output width, add the residual and normalize."""

    @nn.compact
    def __call__(self, intermediate_output, attention_output):
        # The output width is taken from the last axis of the attention output so the residual sum is well-formed.
        out_features = attention_output.shape[-1]
        projected = nn.Dense(out_features, name="dense")(intermediate_output)
        return FlaxPerformerLayerNorm(name="layer_norm")(projected + attention_output)
class FlaxPerformerLayer(nn.Module):
    """One encoder layer: attention sub-layer, intermediate projection and output sub-layer."""

    num_heads: int
    head_size: int
    intermediate_size: int
    hidden_act: str = "gelu"

    @nn.compact
    def __call__(self, hidden_state, attention_mask):
        attention_block = FlaxPerformerAttention(self.num_heads, self.head_size, name="attention")
        attended = attention_block(hidden_state, attention_mask)

        intermediate_block = FlaxPerformerIntermediate(
            self.intermediate_size, name="intermediate", hidden_act=self.hidden_act
        )
        expanded = intermediate_block(attended)

        # The output sub-layer receives both the expanded activations and the attention output (residual input).
        return FlaxPerformerOutput(name="output")(expanded, attended)
class FlaxPerformerLayerCollection(nn.Module):
    """
    Stack of ``num_layers`` ``FlaxPerformerLayer`` modules applied one after another; layer ``i`` is named ``"i"``.
    """

    num_layers: int
    num_heads: int
    head_size: int
    intermediate_size: int
    hidden_act: str = "gelu"

    @nn.compact
    def __call__(self, inputs, attention_mask):
        assert self.num_layers > 0, f"num_layers should be >= 1, got ({self.num_layers})"

        # Thread the hidden state through every layer in order; the same mask is given to each layer.
        hidden = inputs
        for index in range(self.num_layers):
            hidden = FlaxPerformerLayer(
                self.num_heads,
                self.head_size,
                self.intermediate_size,
                hidden_act=self.hidden_act,
                name=str(index),
            )(hidden, attention_mask)
        return hidden
class FlaxPerformerEncoder(nn.Module):
    """Thin wrapper placing the layer stack under the sub-module name ``"layer"``."""

    num_layers: int
    num_heads: int
    head_size: int
    intermediate_size: int
    hidden_act: str = "gelu"

    @nn.compact
    def __call__(self, hidden_state, attention_mask):
        layer_stack = FlaxPerformerLayerCollection(
            self.num_layers,
            self.num_heads,
            self.head_size,
            self.intermediate_size,
            name="layer",
            hidden_act=self.hidden_act,
        )
        return layer_stack(hidden_state, attention_mask)
class FlaxPerformerPooler(nn.Module):
    """Dense layer plus tanh applied to the element at index 0 of the second axis of the hidden state."""

    @nn.compact
    def __call__(self, hidden_state):
        first_token = hidden_state[:, 0]
        width = hidden_state.shape[-1]
        pooled = nn.Dense(width, name="dense")(first_token)
        return jax.lax.tanh(pooled)
class FlaxPerformerModule(nn.Module):
    """
    Embeddings followed by the Performer encoder stack and, optionally, a pooler.

    Returns the encoder output alone when ``add_pooling_layer`` is False, otherwise the tuple
    ``(encoder_output, pooled_output)``.
    """

    vocab_size: int
    hidden_size: int
    type_vocab_size: int
    max_length: int
    num_encoder_layers: int
    num_heads: int
    head_size: int
    intermediate_size: int
    hidden_act: str = "gelu"
    add_pooling_layer: bool = True

    @nn.compact
    def __call__(self, input_ids, token_type_ids, position_ids, attention_mask):
        # Input embeddings
        embedder = FlaxPerformerEmbeddings(
            self.vocab_size, self.hidden_size, self.type_vocab_size, self.max_length, name="embeddings"
        )
        embedded = embedder(input_ids, token_type_ids, position_ids, attention_mask)

        # Stack of encoder layers
        encoder_stack = FlaxPerformerEncoder(
            self.num_encoder_layers,
            self.num_heads,
            self.head_size,
            self.intermediate_size,
            hidden_act=self.hidden_act,
            name="encoder",
        )
        sequence_output = encoder_stack(embedded, attention_mask)

        if self.add_pooling_layer:
            pooled_output = FlaxPerformerPooler(name="pooler")(sequence_output)
            return sequence_output, pooled_output
        return sequence_output
@add_start_docstrings(
    "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
    BERT_START_DOCSTRING,
)
class FlaxPerformerModel(FlaxBertPreTrainedModel):
    """
    Wrapper that builds a :class:`FlaxPerformerModule` from a :class:`~transformers.BertConfig`, converts PyTorch
    BERT state dicts to the parameter layout used by the Flax modules of this file, and runs the module on
    (input ids, token type ids, position ids, attention mask).
    """

    model_class = FlaxPerformerModule
    config_class = BertConfig
    base_model_prefix = "bert"

    @staticmethod
    def convert_from_pytorch(pt_state: Dict, config: BertConfig) -> Dict:
        """
        Rename and reshape the entries of a PyTorch state dict so they match the Flax parameter names.

        Args:
            pt_state: mapping from PyTorch parameter names to tensors (the tensors must support ``reshape``,
                ``transpose`` and ``.T``).
            config: model configuration; ``num_attention_heads`` and ``hidden_size`` drive the reshapes.

        Returns:
            A new dict holding the converted entries; ``pt_state`` itself is not modified.
        """
        jax_state = dict(pt_state)

        # Need to change some parameters name to match Flax names so that we don't have to fork any layer.
        # The checks below are order-dependent: later checks look at the key as renamed by earlier ones.
        for key, tensor in pt_state.items():
            # Key parts
            key_parts = set(key.split("."))

            # Every dense layer has "kernel" parameters instead of "weight"
            if "dense.weight" in key:
                del jax_state[key]
                key = key.replace("weight", "kernel")
                jax_state[key] = tensor

            # SelfAttention needs also to replace "weight" by "kernel"
            if {"query", "key", "value"} & key_parts:
                # Flax SelfAttention decomposes the heads (num_head, size // num_heads)
                if "bias" in key:
                    jax_state[key] = tensor.reshape((config.num_attention_heads, -1))
                elif "weight" in key:  # was `elif "weight":`, a constant-true condition
                    del jax_state[key]
                    key = key.replace("weight", "kernel")
                    tensor = tensor.reshape((config.num_attention_heads, -1, config.hidden_size)).transpose((2, 0, 1))
                    jax_state[key] = tensor

            # SelfAttention output is not a separate layer, remove one nesting
            if "attention.output.dense" in key:
                del jax_state[key]
                key = key.replace("attention.output.dense", "attention.self.out")
                jax_state[key] = tensor

            # SelfAttention output is not a separate layer, remove nesting on layer norm
            if "attention.output.LayerNorm" in key:
                del jax_state[key]
                key = key.replace("attention.output.LayerNorm", "attention.LayerNorm")
                jax_state[key] = tensor

            # There are some transposed parameters w.r.t their PyTorch counterpart
            if "intermediate.dense.kernel" in key or "output.dense.kernel" in key:
                jax_state[key] = tensor.T

            # Self Attention output projection needs to be transposed
            if "out.kernel" in key:
                jax_state[key] = tensor.reshape((config.hidden_size, config.num_attention_heads, -1)).transpose(
                    1, 2, 0
                )

            # Pooler needs to transpose its kernel
            if "pooler.dense.kernel" in key:
                jax_state[key] = tensor.T

            # Handle LayerNorm conversion
            if "LayerNorm" in key:
                del jax_state[key]

                # Replace LayerNorm by layer_norm
                new_key = key.replace("LayerNorm", "layer_norm")
                if "weight" in key:
                    new_key = new_key.replace("weight", "gamma")
                elif "bias" in key:
                    new_key = new_key.replace("bias", "beta")
                jax_state[new_key] = tensor

        return jax_state

    def __init__(
        self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
    ):
        """
        Build the underlying module from ``config`` and hand it to the parent class.

        ``**kwargs`` is accepted for interface compatibility and is not forwarded to the module.
        """
        # `FlaxPerformerModule` declares no `dropout_rate` field, so `config.hidden_dropout_prob` must not be
        # passed to it (doing so makes the module constructor reject the unknown keyword).
        module = FlaxPerformerModule(
            vocab_size=config.vocab_size,
            hidden_size=config.hidden_size,
            type_vocab_size=config.type_vocab_size,
            max_length=config.max_position_embeddings,
            num_encoder_layers=config.num_hidden_layers,
            num_heads=config.num_attention_heads,
            head_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
        )
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)

    @property
    def module(self) -> nn.Module:
        """The wrapped Flax module."""
        return self._module

    def __call__(
        self, input_ids, token_type_ids=None, position_ids=None, dropout_rng: PRNGKey = None, attention_mask=None
    ):
        """
        Apply the module to the given inputs using ``self.params``.

        Missing optional inputs are filled in by ``self._check_inputs``. When ``dropout_rng`` is given it is
        supplied to the module as the ``"dropout"`` random stream.
        """
        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
            input_ids, attention_mask, token_type_ids, position_ids
        )

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        # `apply` takes the random streams through its `rngs` keyword (as `FlaxPerformerForMaskedLM` does);
        # any other keyword would be forwarded to the module's `__call__`, which does not accept it.
        return self.module.apply(
            {"params": self.params},
            jnp.array(input_ids, dtype="i4"),
            jnp.array(token_type_ids, dtype="i4"),
            jnp.array(position_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            rngs=rngs,
        )
class FlaxPerformerForMaskedLM(FlaxBertPreTrainedModel):
    """Wrapper building a :class:`FlaxPerformerForMaskedLMModule` from a :class:`~transformers.BertConfig`."""

    def __init__(
        self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
    ):
        """Create the masked-LM module from ``config``; extra ``**kwargs`` are forwarded to the module."""
        config_fields = dict(
            vocab_size=config.vocab_size,
            type_vocab_size=config.type_vocab_size,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            head_size=config.hidden_size,
            num_heads=config.num_attention_heads,
            num_encoder_layers=config.num_hidden_layers,
            max_length=config.max_position_embeddings,
            hidden_act=config.hidden_act,
        )
        module = FlaxPerformerForMaskedLMModule(**config_fields, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)

    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        params: dict = None,
        train: bool = False,
        dropout_rng: PRNGKey = None,
    ):
        """
        Apply the masked-LM module.

        ``params`` overrides ``self.params`` when it is truthy. ``train`` is inverted and passed to the module as
        its ``deterministic`` flag; ``dropout_rng``, when given, becomes the ``"dropout"`` random stream.
        """
        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
            input_ids, attention_mask, token_type_ids, position_ids
        )

        # Only register a dropout stream when a key was actually supplied.
        rngs = {} if dropout_rng is None else {"dropout": dropout_rng}

        variables = {"params": params or self.params}
        int_inputs = [
            jnp.array(tensor, dtype="i4") for tensor in (input_ids, attention_mask, token_type_ids, position_ids)
        ]
        deterministic = not train
        return self.module.apply(variables, *int_inputs, deterministic, rngs=rngs)
class FlaxPerformerForMaskedLMModule(nn.Module):
    """
    Performer encoder (without pooler, registered under the name ``"bert"``) followed by dropout and a masked
    language modeling head (registered under the name ``"cls"``).
    """

    vocab_size: int
    hidden_size: int
    intermediate_size: int
    head_size: int
    num_heads: int
    num_encoder_layers: int
    type_vocab_size: int
    max_length: int
    hidden_act: str
    dropout_rate: float = 0.0
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(
        self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True
    ):
        """
        Compute masked-LM logits.

        Args:
            input_ids, attention_mask, token_type_ids, position_ids: model inputs, forwarded to the encoder.
            deterministic: forwarded to the dropout layer.

        Returns:
            A one-element tuple ``(logits,)``.
        """
        # Model
        encoder_module = FlaxPerformerModule(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            type_vocab_size=self.type_vocab_size,
            max_length=self.max_length,
            num_encoder_layers=self.num_encoder_layers,
            num_heads=self.num_heads,
            # Honor the declared `head_size` field (it was previously ignored in favor of `hidden_size`).
            head_size=self.head_size,
            intermediate_size=self.intermediate_size,
            hidden_act=self.hidden_act,
            add_pooling_layer=False,
            name="bert",
        )
        # `FlaxPerformerModule.__call__` takes (input_ids, token_type_ids, position_ids, attention_mask), which is
        # a different order than this method's own signature -- pass the arguments in the encoder's order so the
        # mask, token types and positions are not mixed up.
        encoder = encoder_module(input_ids, token_type_ids, position_ids, attention_mask)

        # Compute the prediction scores
        encoder = nn.Dropout(rate=self.dropout_rate)(encoder, deterministic=deterministic)
        logits = FlaxBertOnlyMLMHead(
            vocab_size=self.vocab_size, hidden_act=self.hidden_act, name="cls", dtype=self.dtype
        )(encoder)

        return (logits,)
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
IMPORTANT:
This code was copied from
https://github.com/google-research/google-research/blob/master/performer/fast_self_attention/fast_self_attention.py on
6/11/2020. This is very new code, so it might be prone to change soon -> make sure to check the original code and
update accordingly
Core Fast Attention Module for Flax. Implementation of the approximate fast softmax and generalized attention mechanism
leveraging structured random feature maps [RFM] techniques and low rank decomposition of the attention matrix.
"""
# pylint: disable=invalid-name, missing-function-docstring, line-too-long
import abc
import functools
from collections.abc import Iterable # pylint: disable=g-importing-member
import numpy as onp
from absl import logging
import jax
import jax.numpy as jnp
from jax import lax, random
def nonnegative_softmax_kernel_feature_creator(
    data, projection_matrix, attention_dims_t, batch_dims_t, precision, is_query, normalize_data=True, eps=0.0001
):
    """
    Build the nonnegative random features used by fast softmax attention.

    Args:
      data: input for which the features are computed
      projection_matrix: random matrix used to compute the features
      attention_dims_t: tuple of attention dimensions (unused by this feature map)
      batch_dims_t: tuple of batch dimensions
      precision: precision parameter forwarded to `lax.dot_general`
      is_query: whether `data` holds queries (True) or keys (False); it selects how the stabilizing maximum is
        taken
      normalize_data: whether the data is rescaled by d^(-1/4), d being the size of its last axis
      eps: numerical stabilizer added to every feature

    Returns:
      Random features for fast softmax attention.
    """
    del attention_dims_t

    # With w_norm = w * input_scale for w in {q, k}: e^{q k^T / sqrt(d)} = e^{q_norm k_norm^T}.
    input_scale = 1.0 / (jnp.sqrt(jnp.sqrt(data.shape[-1]))) if normalize_data else 1.0
    feature_scale = 1.0 / jnp.sqrt(projection_matrix.shape[0])
    last_axis = data.ndim - 1

    # Replicate the projection matrix along the batch dimensions so `dot_general` can treat them as batch axes.
    batch_shape = data.shape[0 : len(batch_dims_t)]
    batched_projection = jnp.zeros(batch_shape + projection_matrix.shape) + projection_matrix
    contracting_dims = ((last_axis,), (batched_projection.ndim - 1,))
    projected = lax.dot_general(
        input_scale * data,
        batched_projection,
        (contracting_dims, (batch_dims_t, batch_dims_t)),
        precision=precision,
    )

    # Half of the squared norm of the (rescaled) data, kept broadcastable against `projected`.
    half_sq_norm = jnp.sum(jnp.square(data), axis=last_axis)
    half_sq_norm = (half_sq_norm / 2.0) * input_scale * input_scale
    half_sq_norm = jnp.expand_dims(half_sq_norm, axis=last_axis)

    # Subtract a maximum before exponentiating: per position along the feature axis for queries, one global
    # maximum for keys.
    if is_query:
        feature_axes = (len(projected.shape) - 1,)
        stabilizer = jnp.max(projected, axis=feature_axes, keepdims=True)
    else:
        stabilizer = jnp.max(projected)

    return feature_scale * (jnp.exp(projected - half_sq_norm - stabilizer) + eps)
def sincos_softmax_kernel_feature_creator(
    data, projection_matrix, attention_dims_t, batch_dims_t, precision, normalize_data=True
):
    """
    Build the sin-cos random features used by fast softmax attention.

    Args:
      data: input for which the features are computed
      projection_matrix: random matrix used to compute the features
      attention_dims_t: tuple of attention dimensions (axes over which the renormalizing maximum is taken)
      batch_dims_t: tuple of batch dimensions
      precision: precision parameter forwarded to `lax.dot_general`
      normalize_data: whether the data is rescaled by d^(-1/4), d being the size of its last axis

    Returns:
      Random features for fast softmax attention; the last axis holds the cosine features followed by the sine
      features.
    """
    # We have: exp(qk^T/sqrt{d}) = exp(|q|^2/2sqrt{d}) * exp(|k|^2/2sqrt{d}) *
    # exp(-(|q*c-k*c|^2)/2), where c = 1.0 / sqrt{sqrt{d}}.
    input_scale = 1.0 / (jnp.sqrt(jnp.sqrt(data.shape[-1]))) if normalize_data else 1.0
    feature_scale = 1.0 / jnp.sqrt(projection_matrix.shape[0])
    last_axis = data.ndim - 1

    # Replicate the projection matrix along the batch dimensions so `dot_general` can treat them as batch axes.
    batch_shape = data.shape[0 : len(batch_dims_t)]
    batched_projection = jnp.zeros(batch_shape + projection_matrix.shape) + projection_matrix
    contracting_dims = ((last_axis,), (batched_projection.ndim - 1,))
    projected = lax.dot_general(
        input_scale * data,
        batched_projection,
        (contracting_dims, (batch_dims_t, batch_dims_t)),
        precision=precision,
    )
    trig_features = jnp.concatenate(
        (feature_scale * jnp.cos(projected), feature_scale * jnp.sin(projected)), axis=-1
    )

    # Constructing D_data and data^{'}
    half_sq_norm = jnp.sum(jnp.square(data), axis=last_axis)
    half_sq_norm = (half_sq_norm / 2.0) * input_scale * input_scale
    half_sq_norm = jnp.expand_dims(half_sq_norm, axis=last_axis)

    # Additional renormalization for numerical stability
    renormalizer = jnp.max(half_sq_norm, attention_dims_t, keepdims=True)
    norm_factor = jnp.exp(half_sq_norm - renormalizer)

    return trig_features * norm_factor
def generalized_kernel_feature_creator(
    data, projection_matrix, batch_dims_t, precision, kernel_fn, kernel_epsilon, normalize_data
):
    """
    Build the kernel features used by fast generalized attention.

    Args:
      data: input for which the features are computed
      projection_matrix: matrix used to compute the features; when None, the kernel is applied directly to the
        (optionally rescaled) data
      batch_dims_t: tuple of batch dimensions
      precision: precision parameter forwarded to `lax.dot_general`
      kernel_fn: kernel function used
      kernel_epsilon: additive positive term added to every feature for numerical stability
      normalize_data: whether the data is rescaled by d^(-1/4), d being the size of its last axis

    Returns:
      Random features for fast generalized attention.
    """
    input_scale = 1.0 / (jnp.sqrt(jnp.sqrt(data.shape[-1]))) if normalize_data else 1.0
    scaled_data = input_scale * data

    # Deterministic features: no projection at all.
    if projection_matrix is None:
        return kernel_fn(scaled_data) + kernel_epsilon

    # Replicate the projection matrix along the batch dimensions so `dot_general` can treat them as batch axes.
    batch_shape = data.shape[0 : len(batch_dims_t)]
    batched_projection = jnp.zeros(batch_shape + projection_matrix.shape) + projection_matrix
    contracting_dims = ((data.ndim - 1,), (batched_projection.ndim - 1,))
    projected = lax.dot_general(
        scaled_data,
        batched_projection,
        (contracting_dims, (batch_dims_t, batch_dims_t)),
        precision=precision,
    )
    return kernel_fn(projected) + kernel_epsilon
def make_fast_softmax_attention(
    qkv_dim,
    renormalize_attention=True,
    numerical_stabilizer=0.000001,
    nb_features=256,
    ortho_features=True,
    ortho_scaling=0.0,
    redraw_features=True,
    unidirectional=False,
    nonnegative_features=True,
    lax_scan_unroll=1,
):
    """
    Construct a fast softmax attention method.

    The returned callable is the `dot_product_attention` method of a `FastAttentionviaLowRankDecomposition`
    configured with a random-matrix factory (orthogonal or unstructured Gaussian, `nb_features` x `qkv_dim`) and
    a feature map (nonnegative or sin-cos).
    """
    logging.info(
        "Fast softmax attention: %s features and orthogonal=%s, renormalize=%s",
        nb_features,
        ortho_features,
        renormalize_attention,
    )

    # Factory for the projection matrix; the PRNG key is the only argument left to supply.
    if ortho_features:
        matrix_creator = functools.partial(GaussianOrthogonalRandomMatrix, nb_features, qkv_dim, scaling=ortho_scaling)
    else:
        matrix_creator = functools.partial(GaussianUnstructuredRandomMatrix, nb_features, qkv_dim)

    def kernel_feature_creator(
        data, projection_matrix, attention_dims_t, batch_dims_t, precision, is_query, normalize_data=True
    ):
        # Nonnegative features distinguish queries from keys and use the numerical stabilizer as `eps`.
        if nonnegative_features:
            return nonnegative_softmax_kernel_feature_creator(
                data,
                projection_matrix,
                attention_dims_t,
                batch_dims_t,
                precision,
                is_query,
                normalize_data,
                numerical_stabilizer,
            )
        # The sin-cos feature map treats queries and keys alike, so `is_query` is simply not forwarded.
        return sincos_softmax_kernel_feature_creator(
            data, projection_matrix, attention_dims_t, batch_dims_t, precision, normalize_data
        )

    fast_attention = FastAttentionviaLowRankDecomposition(
        matrix_creator,
        kernel_feature_creator,
        renormalize_attention=renormalize_attention,
        numerical_stabilizer=numerical_stabilizer,
        redraw_features=redraw_features,
        unidirectional=unidirectional,
        lax_scan_unroll=lax_scan_unroll,
    )
    return fast_attention.dot_product_attention
def make_fast_generalized_attention(
    qkv_dim,
    renormalize_attention=True,
    numerical_stabilizer=0.0,
    nb_features=256,
    features_type="deterministic",
    kernel_fn=jax.nn.relu,
    kernel_epsilon=0.001,
    redraw_features=False,
    unidirectional=False,
    lax_scan_unroll=1,
):
    """
    Construct a fast generalized attention method.

    `features_type` selects the projection: "ortho" (orthogonal Gaussian), "iid" (unstructured Gaussian) or
    "deterministic" (no projection matrix). Any other value raises a ValueError.
    """
    logging.info("Fast generalized attention.: %s features and renormalize=%s", nb_features, renormalize_attention)

    if features_type == "deterministic":
        matrix_creator = None
    elif features_type == "iid":
        matrix_creator = functools.partial(GaussianUnstructuredRandomMatrix, nb_features, qkv_dim)
    elif features_type == "ortho":
        matrix_creator = functools.partial(GaussianOrthogonalRandomMatrix, nb_features, qkv_dim, scaling=False)
    else:
        raise ValueError("Unknown feature value type")

    def kernel_feature_creator(
        data, projection_matrix, attention_dims_t, batch_dims_t, precision, is_query, normalize_data=False
    ):
        # The generalized feature map needs neither the attention dimensions nor the query/key distinction.
        del attention_dims_t, is_query
        return generalized_kernel_feature_creator(
            data, projection_matrix, batch_dims_t, precision, kernel_fn, kernel_epsilon, normalize_data
        )

    fast_attention = FastAttentionviaLowRankDecomposition(
        matrix_creator,
        kernel_feature_creator,
        renormalize_attention=renormalize_attention,
        numerical_stabilizer=numerical_stabilizer,
        redraw_features=redraw_features,
        unidirectional=unidirectional,
        lax_scan_unroll=lax_scan_unroll,
    )
    return fast_attention.dot_product_attention
class RandomMatrix(abc.ABC):
    r"""
    Abstract class providing a method for constructing 2D random arrays. Class is responsible for constructing 2D
    random arrays.

    Subclasses must implement :meth:`get_2d_array`. The class derives from ``abc.ABC`` so that
    ``@abc.abstractmethod`` is actually enforced (a ``__metaclass__`` class attribute has no effect in Python 3).
    """

    @abc.abstractmethod
    def get_2d_array(self):
        """Return the 2D random array."""
        raise NotImplementedError("Abstract method")
class GaussianUnstructuredRandomMatrix(RandomMatrix):
    """Random matrix of shape (nb_rows, nb_columns) drawn with `jax.random.normal` from the given key."""

    def __init__(self, nb_rows, nb_columns, key):
        self.nb_rows, self.nb_columns = nb_rows, nb_columns
        self.key = key

    def get_2d_array(self):
        """Return the (nb_rows, nb_columns) array of normal samples."""
        shape = (self.nb_rows, self.nb_columns)
        return random.normal(self.key, shape)
class GaussianOrthogonalRandomMatrix(RandomMatrix):
    r"""
    Class providing a method to create Gaussian orthogonal matrix. Class is responsible for constructing 2D Gaussian
    orthogonal arrays.

    The (nb_rows, nb_columns) result is built from stacked square blocks whose rows are orthonormal (obtained by a
    QR decomposition of a Gaussian block); rows are then rescaled according to ``scaling``:

    - ``scaling == 0``: each row is multiplied by the norm of an independent Gaussian row,
    - ``scaling == 1``: each row is multiplied by ``sqrt(nb_columns)``.
    """

    def __init__(self, nb_rows, nb_columns, key, scaling=0):
        self.nb_rows = nb_rows
        self.nb_columns = nb_columns
        self.key = key
        self.scaling = scaling

    def get_2d_array(self):
        """
        Return the (nb_rows, nb_columns) random matrix.

        Raises:
            ValueError: if ``scaling`` is neither 0 nor 1.
        """
        # Validate before doing any work. The check used to run last and referenced a non-existent
        # `self._scaling`, so an invalid value surfaced as an AttributeError instead of this ValueError.
        if not (self.scaling == 0 or self.scaling == 1):
            raise ValueError("Scaling must be one of {0, 1}. Was %s" % (self.scaling,))

        # One square block per `nb_columns` rows, plus a truncated block for the remainder.
        nb_full_blocks, remaining_rows = divmod(self.nb_rows, self.nb_columns)
        rows_per_block = [self.nb_columns] * nb_full_blocks
        if remaining_rows > 0:
            rows_per_block.append(remaining_rows)

        rng = self.key
        block_list = []
        for nb_block_rows in rows_per_block:
            rng, rng_input = jax.random.split(rng)
            unstructured_block = random.normal(rng_input, (self.nb_columns, self.nb_columns))
            q, _ = jnp.linalg.qr(unstructured_block)
            # The columns of q are orthonormal; transpose so that the rows are.
            block_list.append(jnp.transpose(q)[0:nb_block_rows])
        final_matrix = jnp.vstack(block_list)

        if self.scaling == 0:
            multiplier = jnp.linalg.norm(random.normal(self.key, (self.nb_rows, self.nb_columns)), axis=1)
        else:
            multiplier = jnp.sqrt(float(self.nb_columns)) * jnp.ones((self.nb_rows))

        return jnp.matmul(jnp.diag(multiplier), final_matrix)
class FastAttention(abc.ABC):
    r"""
    Abstract class providing a method for fast attention. Class is responsible for providing a method
    <dot_product_attention> for fast approximate attention.

    The class derives from ``abc.ABC`` so that ``@abc.abstractmethod`` is actually enforced (a ``__metaclass__``
    class attribute has no effect in Python 3).
    """

    @abc.abstractmethod
    def dot_product_attention(
        self,
        query,
        key,
        value,
        dtype=jnp.float32,
        bias=None,
        axis=None,
        broadcast_dropout=True,
        dropout_rng=None,
        dropout_rate=0.0,
        deterministic=False,
        precision=None,
    ):
        """
        Computes dot-product attention given query, key, and value. This is the core function for applying fast
        approximate dot-product attention. It calculates the attention weights given query and key and combines the
        values using the attention weights. This function supports multi-dimensional inputs

        Args:
          query: queries for calculating attention with shape of [batch_size, dim1,
            dim2, ..., dimN, num_heads, mem_channels].
          key: keys for calculating attention with shape of [batch_size, dim1, dim2,
            ..., dimN, num_heads, mem_channels].
          value: values to be used in attention with shape of [batch_size, dim1,
            dim2,..., dimN, num_heads, value_channels].
          dtype: the dtype of the computation (default: float32)
          bias: bias for the attention weights. This can be used for incorporating
            autoregressive mask, padding mask, proximity bias.
          axis: axes over which the attention is applied.
          broadcast_dropout: bool: use a broadcasted dropout along batch dims.
          dropout_rng: JAX PRNGKey: to be used for dropout.
          dropout_rate: dropout rate.
          deterministic: bool, deterministic or not (to apply dropout).
          precision: numerical precision of the computation see `jax.lax.Precision`
            for details

        Returns:
          Output of shape [bs, dim1, dim2, ..., dimN, num_heads, value_channels].
        """
        raise NotImplementedError("Abstract method")
def _numerator(z_slice_shape, precision, unroll=1):
def fwd(qs, ks, vs):
def body(p, qkv):
(q, k, v) = qkv
p += jnp.einsum("...m,...d->...md", k, v, precision=precision)
X_slice = jnp.einsum("...m,...md->...d", q, p, precision=precision)
return p, X_slice
init_value = jnp.zeros(z_slice_shape)
p, W = lax.scan(body, init_value, (qs, ks, vs), unroll=unroll)
return W, (p, qs, ks, vs)
def bwd(pqkv, W_ct):
def body(carry, qkv_xct):
p, p_ct = carry
q, k, v, x_ct = qkv_xct
q_ct = jnp.einsum("...d,...md->...m", x_ct, p, precision=precision)
p_ct += jnp.einsum("...d,...m->...md", x_ct, q, precision=precision)
k_ct = jnp.einsum("...md,...d->...m", p_ct, v, precision=precision)
v_ct = jnp.einsum("...md,...m->...d", p_ct, k, precision=precision)
p -= jnp.einsum("...m,...d->...md", k, v, precision=precision)
return (p, p_ct), (q_ct, k_ct, v_ct)
p, qs, ks, vs = pqkv
_, (qs_ct, ks_ct, vs_ct) = lax.scan(
body, (p, jnp.zeros_like(p)), (qs, ks, vs, W_ct), reverse=True, unroll=unroll
)
return qs_ct, ks_ct, vs_ct
@jax.custom_vjp
def _numerator_impl(qs, ks, vs):
W, _ = fwd(qs, ks, vs)
return W
_numerator_impl.defvjp(fwd, bwd)
return _numerator_impl
def _denominator(t_slice_shape, precision, unroll=1):
def fwd(qs, ks):
def body(p, qk):
q, k = qk
p += k
x = jnp.einsum("...m,...m->...", q, p, precision=precision)
return p, x
p = jnp.zeros(t_slice_shape)
p, R = lax.scan(body, p, (qs, ks), unroll=unroll)
return R, (qs, ks, p)
def bwd(qkp, R_ct):
def body(carry, qkx):
p, p_ct = carry
q, k, x_ct = qkx
q_ct = jnp.einsum("...,...m->...m", x_ct, p, precision=precision)
p_ct += jnp.einsum("...,...m->...m", x_ct, q, precision=precision)
k_ct = p_ct
p -= k
return (p, p_ct), (q_ct, k_ct)
qs, ks, p = qkp
_, (qs_ct, ks_ct) = lax.scan(body, (p, jnp.zeros_like(p)), (qs, ks, R_ct), reverse=True, unroll=unroll)
return (qs_ct, ks_ct)
@jax.custom_vjp
def _denominator_impl(qs, ks):
R, _ = fwd(qs, ks)
return R
_denominator_impl.defvjp(fwd, bwd)
return _denominator_impl
class FastAttentionviaLowRankDecomposition(FastAttention):
    r"""
    Fast dot-product attention via a low rank decomposition (e.g. with random feature maps).

    Queries and keys are mapped through ``kernel_feature_creator`` to feature tensors Q' and K', and
    attention is computed as Q'((K')^T V), optionally divided by the normaliser Q'((K')^T 1), instead
    of materialising the full attention matrix.
    """

    def __init__(
        self,
        matrix_creator,
        kernel_feature_creator,
        renormalize_attention,
        numerical_stabilizer,
        redraw_features,
        unidirectional,
        lax_scan_unroll=1,
    ):  # For optimal GPU performance, set to 16.
        """
        Args:
            matrix_creator: ``None`` (no projection matrix), or a callable invoked as
                ``matrix_creator(key=rng)`` whose result exposes ``get_2d_array()``.
            kernel_feature_creator: callable invoked as ``kernel_feature_creator(data,
                projection_matrix, attention_dims, batch_dims, precision, is_query)``.
            renormalize_attention: when true, the result is divided by the attention normaliser.
            numerical_stabilizer: normaliser entries whose absolute value is at most this threshold
                are shifted by twice the threshold before taking the reciprocal.
            redraw_features: when true, the projection matrix is redrawn on every call to
                :meth:`dot_product_attention`, seeded from the sum of the query tensor.
            unidirectional: when true, prefix sums over the first attention axis are used
                (causal attention); otherwise every position attends to all positions.
            lax_scan_unroll: ``unroll`` value handed to the ``lax.scan`` calls of the
                unidirectional path.
        """
        rng = random.PRNGKey(0)
        self.matrix_creator = matrix_creator
        self.projection_matrix = self.draw_weights(rng)
        self.kernel_feature_creator = kernel_feature_creator
        self.renormalize_attention = renormalize_attention
        self.numerical_stabilizer = numerical_stabilizer
        self.redraw_features = redraw_features
        self.unidirectional = unidirectional
        self.lax_scan_unroll = lax_scan_unroll

    def draw_weights(self, key):
        """
        Draw a projection matrix from ``self.matrix_creator`` using a sub-key split off ``key``.

        Returns ``None`` when no matrix creator was configured.
        """
        if self.matrix_creator is None:
            return None
        matrixrng, _ = random.split(key)
        return self.matrix_creator(key=matrixrng).get_2d_array()

    def dot_product_attention(
        self,
        query,
        key,
        value,
        dtype=jnp.float32,
        bias=None,
        axis=None,
        broadcast_dropout=True,
        dropout_rng=None,
        dropout_rate=0.0,
        deterministic=False,
        precision=None,
    ):
        """
        Compute fast (low-rank) dot-product attention.

        ``query``, ``key`` and ``value`` are laid out as [batch, dim1, ..., dimN, num_heads, channels].
        ``dtype``, ``bias``, ``broadcast_dropout``, ``dropout_rng``, ``dropout_rate`` and
        ``deterministic`` are accepted for interface compatibility but are not used by this
        implementation.

        Args:
            query: query tensor.
            key: key tensor; same rank as ``query``, same first and last dimension sizes.
            value: value tensor; same shape as ``key`` except possibly for the last axis.
            axis: attention axis or axes (an int or an iterable of ints). Defaults to every axis
                between the batch axis and the last two axes.
            precision: forwarded to the kernel feature creator and the contraction primitives.

        Returns:
            Tensor in the original axis order, [batch, dim1, ..., dimN, num_heads, value_channels].

        Raises:
            ValueError: if an attention axis is not strictly between the batch axis and the last
                two axes.
            AssertionError: if the input shapes are inconsistent with each other.
        """
        assert key.shape[:-1] == value.shape[:-1]
        assert query.shape[0:1] == key.shape[0:1] and query.shape[-1] == key.shape[-1]
        if axis is None:
            axis = tuple(range(1, key.ndim - 2))
        elif isinstance(axis, Iterable):
            # Lists and other iterables must become tuples: the permutations below are built by
            # tuple concatenation.
            axis = tuple(axis)
        else:
            axis = (axis,)
        assert key.ndim == query.ndim
        assert key.ndim == value.ndim
        for ax in axis:
            if not (query.ndim >= 3 and 1 <= ax < query.ndim - 2):
                raise ValueError("Attention axis must be between the batch axis and the last-two axes.")
        n = key.ndim

        # Constructing projection tensor.
        if self.redraw_features:
            # TODO(kchoro): Get rid of the constant below.
            query_seed = lax.convert_element_type(jnp.ceil(jnp.sum(query) * 10000000.0), jnp.int32)
            rng = random.PRNGKey(query_seed)
            self.projection_matrix = self.draw_weights(rng)

        # batch_dims: every axis that is neither an attention axis nor the channel axis
        # (batch, non-attention dims, num_heads).
        batch_dims = tuple(onp.delete(range(n), axis + (n - 1,)))
        # q, k and v all move to (batch dims..., attention dims..., channels).
        qk_perm = batch_dims + axis + (n - 1,)
        key = key.transpose(qk_perm)
        query = query.transpose(qk_perm)
        value = value.transpose(qk_perm)
        batch_dims_t = tuple(range(len(batch_dims)))
        attention_dims_t = tuple(range(len(batch_dims), len(batch_dims) + len(axis)))

        # Constructing tensors Q^{'} and K^{'}.
        query_prime = self.kernel_feature_creator(
            query, self.projection_matrix, attention_dims_t, batch_dims_t, precision, True
        )
        key_prime = self.kernel_feature_creator(
            key, self.projection_matrix, attention_dims_t, batch_dims_t, precision, False
        )

        if self.unidirectional:
            # The scans run along the first attention axis, moved to the front.
            index = attention_dims_t[0]
            z_slice_shape = key_prime.shape[0 : len(batch_dims_t)] + (key_prime.shape[-1],) + (value.shape[-1],)
            numerator_fn = _numerator(z_slice_shape, precision, self.lax_scan_unroll)
            W = numerator_fn(
                jnp.moveaxis(query_prime, index, 0), jnp.moveaxis(key_prime, index, 0), jnp.moveaxis(value, index, 0)
            )
            # Constructing W = (Q^{'}(K^{'})^{T})_{masked}V
            W = jnp.moveaxis(W, 0, index)
            if not self.renormalize_attention:
                # Unidirectional, not-normalized attention.
                return W.transpose(_invert_perm(qk_perm))
            # Unidirectional, normalized attention.
            t_slice_shape = key_prime.shape[0 : len(batch_dims_t)] + (key_prime.shape[-1],)
            denominator_fn = _denominator(t_slice_shape, precision, self.lax_scan_unroll)
            R = denominator_fn(jnp.moveaxis(query_prime, index, 0), jnp.moveaxis(key_prime, index, 0))
            R = jnp.moveaxis(R, 0, index)
        else:
            contract_query = tuple(range(len(batch_dims) + len(axis), len(batch_dims) + len(axis) + 1))
            contract_z = tuple(range(len(batch_dims), len(batch_dims) + 1))
            # Constructing Z = (K^{'})^{T}V
            # Z (batch dims..., channels_m, channels_v)
            Z = lax.dot_general(
                key_prime,
                value,
                ((attention_dims_t, attention_dims_t), (batch_dims_t, batch_dims_t)),
                precision=precision,
            )
            # Constructing W = Q^{'}Z = Q^{'}(K^{'})^{T}V
            # q (batch dims..., attention dims..., channels_m)
            # Z (batch dims..., channels_m, channels_v)
            # W (batch dims..., attention dims..., channels_v)
            W = lax.dot_general(
                query_prime, Z, ((contract_query, contract_z), (batch_dims_t, batch_dims_t)), precision=precision
            )
            if not self.renormalize_attention:
                # Bidirectional, not-normalized attention.
                return W.transpose(_invert_perm(qk_perm))
            # Bidirectional, normalized attention.
            # All-ones tensor over (batch dims..., attention dims...), used to sum K' over the
            # attention axes through dot_general.
            thick_all_ones = jnp.ones(key.shape[0:-1])
            contract_key = tuple(range(len(batch_dims), len(batch_dims) + len(axis)))
            contract_thick_all_ones = tuple(range(thick_all_ones.ndim - len(axis), thick_all_ones.ndim))
            # Construct T = (K^{'})^{T} 1_L
            # k (batch dims..., attention dims..., channels)
            T = lax.dot_general(
                key_prime,
                thick_all_ones,
                ((contract_key, contract_thick_all_ones), (batch_dims_t, batch_dims_t)),
                precision=precision,
            )
            # Construct partition function: R = Q^{'} T = Q^{'}(K^{'})^{T} 1_L
            # q_p (batch dims..., attention dims..., channels_m)
            # T   (batch dims..., channels_m)
            R = lax.dot_general(
                query_prime,
                T,
                (((query_prime.ndim - 1,), (T.ndim - 1,)), (batch_dims_t, tuple(range(T.ndim - 1)))),
                precision=precision,
            )

        # Keep the normaliser away from zero before inverting it.
        R = R + 2 * self.numerical_stabilizer * (jnp.abs(R) <= self.numerical_stabilizer)
        R = jnp.reciprocal(R)
        R = jnp.expand_dims(R, len(R.shape))
        # W (batch dims..., attention dims..., channels_v)
        # R (batch dims..., attention dims..., extra_channel)
        result = W * R
        # back to (bs, dim1, dim2, ..., dimN, num_heads, channels)
        return result.transpose(_invert_perm(qk_perm))
def _invert_perm(perm):
perm_inv = [0] * len(perm)
for i, j in enumerate(perm):
perm_inv[j] = i
return tuple(perm_inv)
# coding=utf-8
# Copyright 2020 The HuggingFace Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning the library models for masked language modeling (BERT, ALBERT, RoBERTa...) with whole word masking on a
text file or a dataset.
Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
https://huggingface.co/models?filter=masked-lm
"""
import logging
import os
import sys
from dataclasses import dataclass, field
# You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import numpy as np
from datasets import load_dataset
from tqdm import tqdm
import jax
import jax.numpy as jnp
from flax import jax_utils
from flax.optim import Adam
from flax.training import common_utils
from flax.training.common_utils import get_metrics
from jax.nn import log_softmax
from modeling_flax_performer import FlaxPerformerForMaskedLM
from transformers import (
MODEL_FOR_MASKED_LM_MAPPING,
AutoTokenizer,
BertConfig,
FlaxBertForMaskedLM,
HfArgumentParser,
PreTrainedTokenizerBase,
TensorType,
TrainingArguments,
is_tensorboard_available,
set_seed,
)
# Probe once for TensorBoard support and remember the answer for the rest of the script.
has_tensorboard = is_tensorboard_available()
if not has_tensorboard:
    print(
        "Unable to display metrics through TensorBoard because the package is not installed: "
        "Please run pip install tensorboard to enable."
    )
else:
    # TensorBoard itself is present; the Flax summary writer may still fail to import.
    try:
        from flax.metrics.tensorboard import SummaryWriter
    except ImportError as ie:
        has_tensorboard = False
        print(f"Unable to display metrics through TensorBoard because some package are not installed: {ie}")

# Config classes that have a masked-LM head, and their model type identifiers.
MODEL_CONFIG_CLASSES = [config_class for config_class in MODEL_FOR_MASKED_LM_MAPPING.keys()]
MODEL_TYPES = tuple(config_class.model_type for config_class in MODEL_CONFIG_CLASSES)
@dataclass
class WandbArguments:
    """
    Command-line options for optional Weights & Biases logging.
    """

    # W&B user name; per its help text, leaving this at None means no logging.
    wandb_user_name: Optional[str] = field(
        metadata={"help": "The WandB user name for potential logging. If left None, no logging"},
        default=None,
    )
    # W&B project name.
    wandb_project_name: Optional[str] = field(
        metadata={"help": "The WandB project name for potential logging"},
        default="performer-experiments",
    )
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
    """

    # Checkpoint name or path used for weight initialization.
    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            # The two literals are concatenated, so the first must end with a space.
            "help": "The model checkpoint for weights initialization. "
            "Don't set if you want to train a model from scratch."
        },
    )
    # Selects FAVOR+ (Performer) attention.
    performer: bool = field(
        default=False,
        metadata={"help": "Whether to use FAVOR+ attention"},
    )
    # Starts from a blank model instead of pretrained weights.
    reinitialize: bool = field(
        default=False,
        metadata={"help": "Whether to use a blank model without pretraining"},
    )
    # Tokenizer name or path, when different from the model's.
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    # Whether to use a fast tokenizer.
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    # Directory in which downloaded pretrained models are stored.
    cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
    )
@dataclass
class DataTrainingArguments:
    """
    Options describing the data fed to the model for training and evaluation.

    Construction fails unless at least one of ``dataset_name``, ``train_file`` or ``validation_file``
    is given; file arguments must end in a csv, json or txt extension.
    """

    dataset_name: Optional[str] = field(
        metadata={"help": "The name of the dataset to use (via the datasets library)."},
        default=None,
    )
    dataset_config_name: Optional[str] = field(
        metadata={"help": "The configuration name of the dataset to use (via the datasets library)."},
        default=None,
    )
    train_file: Optional[str] = field(
        metadata={"help": "The input training data file (a text file)."},
        default=None,
    )
    validation_file: Optional[str] = field(
        metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
        default=None,
    )
    train_ref_file: Optional[str] = field(
        metadata={"help": "An optional input train ref data file for whole word masking in Chinese."},
        default=None,
    )
    validation_ref_file: Optional[str] = field(
        metadata={"help": "An optional input validation ref data file for whole word masking in Chinese."},
        default=None,
    )
    overwrite_cache: bool = field(
        metadata={"help": "Overwrite the cached training and evaluation sets"},
        default=False,
    )
    validation_split_percentage: Optional[int] = field(
        metadata={
            "help": "The percentage of the train set used as validation set in case there's no validation split"
        },
        default=5,
    )
    max_seq_length: Optional[int] = field(
        metadata={
            "help": "The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated. Default to the max input length of the model."
        },
        default=None,
    )
    preprocessing_num_workers: Optional[int] = field(
        metadata={"help": "The number of processes to use for the preprocessing."},
        default=None,
    )
    mlm_probability: float = field(
        metadata={"help": "Ratio of tokens to mask for masked language modeling loss"},
        default=0.15,
    )
    pad_to_max_length: bool = field(
        metadata={
            "help": "Whether to pad all samples to `max_seq_length`. "
            "If False, will pad the samples dynamically when batching to the maximum length in the batch."
        },
        default=False,
    )

    def __post_init__(self):
        """Validate that a data source was given and that file extensions are supported."""
        no_source = self.dataset_name is None and self.train_file is None and self.validation_file is None
        if no_source:
            raise ValueError("Need either a dataset name or a training/validation file.")

        supported = ["csv", "json", "txt"]
        if self.train_file is not None:
            train_extension = self.train_file.split(".")[-1]
            assert train_extension in supported, "`train_file` should be a csv, a json or a txt file."
        if self.validation_file is not None:
            validation_extension = self.validation_file.split(".")[-1]
            assert validation_extension in supported, "`validation_file` should be a csv, a json or a txt file."
# Adapted from transformers/data/data_collator.py
# Kept in this script for now; where it should ultimately live is still to be discussed.
@dataclass
class FlaxDataCollatorForLanguageModeling:
    """
    Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they
    are not all of the same length.

    Args:
        tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
            The tokenizer used for encoding the data.
        mlm (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to use masked language modeling. If set to :obj:`False`, the labels are the same as the
            inputs with the padding tokens ignored (by setting them to -100). Otherwise, the labels are -100 for
            non-masked tokens and the value to predict for the masked token.
        mlm_probability (:obj:`float`, `optional`, defaults to 0.15):
            The probability with which to (randomly) mask tokens in the input, when :obj:`mlm` is set to :obj:`True`.

    .. note::

        For best performance, this data collator should be used with a dataset having items that are dictionaries or
        BatchEncoding, with the :obj:`"special_tokens_mask"` key, as returned by a
        :class:`~transformers.PreTrainedTokenizer` or a :class:`~transformers.PreTrainedTokenizerFast` with the
        argument :obj:`return_special_tokens_mask=True`.
    """

    tokenizer: PreTrainedTokenizerBase
    mlm: bool = True
    mlm_probability: float = 0.15

    def __post_init__(self):
        """Reject masked language modeling when the tokenizer has no mask token."""
        if self.mlm and self.tokenizer.mask_token is None:
            raise ValueError(
                "This tokenizer does not have a mask token which is necessary for masked language modeling. "
                "You should pass `mlm=False` to train on causal language modeling instead."
            )

    def __call__(
        self, examples: List[Dict[str, np.ndarray]], pad_to_multiple_of: Optional[int] = None
    ) -> Dict[str, np.ndarray]:
        """
        Pad ``examples`` into a NumPy batch and attach a ``"labels"`` entry.

        Args:
            examples: tokenized samples to collate.
            pad_to_multiple_of: forwarded to ``tokenizer.pad``; ``None`` applies no multiple constraint.

        Returns:
            The padded batch with ``"labels"`` added (and ``"input_ids"`` masked when ``mlm`` is set).
        """
        # Handle dict or lists with proper padding and conversion to tensor.
        batch = self.tokenizer.pad(examples, pad_to_multiple_of=pad_to_multiple_of, return_tensors=TensorType.NUMPY)

        # If special token mask has been preprocessed, pop it from the dict.
        special_tokens_mask = batch.pop("special_tokens_mask", None)
        if self.mlm:
            batch["input_ids"], batch["labels"] = self.mask_tokens(
                batch["input_ids"], special_tokens_mask=special_tokens_mask
            )
        else:
            labels = batch["input_ids"].copy()
            if self.tokenizer.pad_token_id is not None:
                labels[labels == self.tokenizer.pad_token_id] = -100
            batch["labels"] = labels
        return batch

    def mask_tokens(
        self, inputs: np.ndarray, special_tokens_mask: Optional[np.ndarray] = None
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.

        Args:
            inputs: token ids; modified in place and also returned.
            special_tokens_mask: marks positions that must never be selected for masking. When ``None``, the
                mask is computed from ``inputs`` with ``tokenizer.get_special_tokens_mask``.

        Returns:
            ``(inputs, labels)`` where ``labels`` is -100 everywhere except at the selected positions.
        """
        labels = inputs.copy()
        # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
        probability_matrix = np.full(labels.shape, self.mlm_probability)
        if special_tokens_mask is None:
            # The dataset was not tokenized with `return_special_tokens_mask=True`; rebuild the mask row by row.
            special_tokens_mask = [
                self.tokenizer.get_special_tokens_mask(row, already_has_special_tokens=True)
                for row in labels.tolist()
            ]
        special_tokens_mask = np.asarray(special_tokens_mask).astype("bool")

        probability_matrix[special_tokens_mask] = 0.0
        masked_indices = np.random.binomial(1, probability_matrix).astype("bool")
        labels[~masked_indices] = -100  # We only compute loss on masked tokens

        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = np.random.binomial(1, np.full(labels.shape, 0.8)).astype("bool") & masked_indices
        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

        # 10% of the time, we replace masked input tokens with random word
        # (half of the 20% of selected positions that were not replaced by the mask token)
        indices_random = np.random.binomial(1, np.full(labels.shape, 0.5)).astype("bool")
        indices_random &= masked_indices & ~indices_replaced
        random_words = np.random.randint(self.tokenizer.vocab_size, size=labels.shape, dtype="i4")
        inputs[indices_random] = random_words[indices_random]

        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels
def create_learning_rate_scheduler(
    factors="constant * linear_warmup * rsqrt_decay",
    base_learning_rate=0.5,
    warmup_steps=1000,
    decay_factor=0.5,
    steps_per_decay=20000,
    steps_per_cycle=100000,
):
    """Build a step -> learning rate function from a product of named factors.

    Recognised factor names:

    * constant: multiply by base_learning_rate,
    * linear_warmup: multiply by min(1, step / warmup_steps),
    * rsqrt_decay: divide by sqrt(max(step, warmup_steps)),
    * rsqrt_normalized_decay: multiply by sqrt(warmup_steps), then divide by sqrt(max(step, warmup_steps)),
    * decay_every: multiply by decay_factor ** (step // steps_per_decay),
    * cosine_decay: cyclic cosine decay driven by steps_per_cycle, starting after warmup_steps.

    Args:
      factors: string, factor names separated by "*"; applied left to right.
      base_learning_rate: float, value of the "constant" factor.
      warmup_steps: int, warmup length used by the warmup and rsqrt factors.
      decay_factor: float, multiplier applied by "decay_every".
      steps_per_decay: int, period of "decay_every".
      steps_per_cycle: int, cycle length of "cosine_decay".

    Returns:
      A function mapping a step to the learning rate as a float32 array. An unknown factor
      name raises ValueError when that function is called.
    """

    def _constant(value, step):
        return value * base_learning_rate

    def _linear_warmup(value, step):
        return value * jnp.minimum(1.0, step / warmup_steps)

    def _rsqrt_decay(value, step):
        return value / jnp.sqrt(jnp.maximum(step, warmup_steps))

    def _rsqrt_normalized_decay(value, step):
        value = value * jnp.sqrt(warmup_steps)
        return value / jnp.sqrt(jnp.maximum(step, warmup_steps))

    def _decay_every(value, step):
        return value * decay_factor ** (step // steps_per_decay)

    def _cosine_decay(value, step):
        progress = jnp.maximum(0.0, (step - warmup_steps) / float(steps_per_cycle))
        return value * jnp.maximum(0.0, 0.5 * (1.0 + jnp.cos(jnp.pi * (progress % 1.0))))

    appliers = {
        "constant": _constant,
        "linear_warmup": _linear_warmup,
        "rsqrt_decay": _rsqrt_decay,
        "rsqrt_normalized_decay": _rsqrt_normalized_decay,
        "decay_every": _decay_every,
        "cosine_decay": _cosine_decay,
    }
    factor_names = [piece.strip() for piece in factors.split("*")]

    def step_fn(step):
        """Step to learning rate function."""
        rate = 1.0
        for name in factor_names:
            apply_factor = appliers.get(name)
            if apply_factor is None:
                raise ValueError("Unknown factor %s." % name)
            rate = apply_factor(rate, step)
        return jnp.asarray(rate, dtype=jnp.float32)

    return step_fn
def compute_metrics(logits, labels, weights, label_smoothing=0.0):
    """
    Compute summed loss, accuracy and normalizer, added up across the "batch" mapped axis.

    The loss and normalizer come from ``cross_entropy`` and the accuracy count from ``accuracy``;
    the resulting dictionary is reduced with ``jax.lax.psum`` over the axis named "batch", so this
    must run inside a mapped computation that defines that axis.
    """
    summed_loss, weight_total = cross_entropy(logits, labels, weights, label_smoothing)
    correct_total = accuracy(logits, labels, weights)[0]
    local_metrics = {"loss": summed_loss, "accuracy": correct_total, "normalizer": weight_total}
    return jax.lax.psum(local_metrics, axis_name="batch")
def accuracy(logits, targets, weights=None):
    """Compute the (optionally weighted) number of correct predictions.

    Args:
     logits: [batch, length, num_classes] float array.
     targets: categorical targets [batch, length] int array.
     weights: None or array of shape [batch, length]. When None, every position counts
       with weight one.

    Returns:
      Tuple of the summed (weighted) count of correct predictions and the normalizing factor
      (the sum of the weights, or the number of target positions when weights is None).

    Raises:
      ValueError: if logits does not have exactly one more dimension than targets.
    """
    if logits.ndim != targets.ndim + 1:
        raise ValueError(
            "Incorrect shapes. Got shape %s logits and %s targets" % (str(logits.shape), str(targets.shape))
        )
    correct = jnp.equal(jnp.argmax(logits, axis=-1), targets)
    if weights is None:
        # Mirror cross_entropy: without weights, normalize by the total number of positions.
        return correct.sum(), np.prod(targets.shape)
    correct = correct * weights
    return correct.sum(), weights.sum()
def cross_entropy(logits, targets, weights=None, label_smoothing=0.0):
    """Compute the summed (optionally weighted, label-smoothed) cross entropy.

    The targets are turned into soft targets with ``1 - label_smoothing`` on the true class and the
    remaining mass spread evenly over the other classes. The entropy of that soft distribution is
    subtracted from every per-position loss.

    Args:
     logits: [batch, length, num_classes] float array.
     targets: categorical targets [batch, length] int array.
     weights: None or array of shape [batch, length]
     label_smoothing: label smoothing constant, used to determine the on and off values.

    Returns:
      Tuple of scalar summed loss and batch normalizing factor (the sum of the weights, or the
      number of target positions when weights is None).

    Raises:
      ValueError: if logits does not have exactly one more dimension than targets.
    """
    if logits.ndim != targets.ndim + 1:
        raise ValueError(
            "Incorrect shapes. Got shape %s logits and %s targets" % (str(logits.shape), str(targets.shape))
        )
    num_classes = logits.shape[-1]
    on_value = 1.0 - label_smoothing
    off_value = (1.0 - on_value) / (num_classes - 1)
    # Entropy of the smoothed target distribution; the 1e-20 keeps log() finite when off_value is 0.
    target_entropy = -(
        on_value * jnp.log(on_value) + (num_classes - 1) * off_value * jnp.log(off_value + 1e-20)
    )
    smoothed_targets = common_utils.onehot(targets, num_classes, on_value=on_value, off_value=off_value)
    per_position = -jnp.sum(smoothed_targets * log_softmax(logits), axis=-1) - target_entropy

    if weights is None:
        return per_position.sum(), np.prod(targets.shape)
    per_position = per_position * weights
    return per_position.sum(), weights.sum()
def training_step(optimizer, batch, dropout_rng):
    """
    Run one optimisation step and return ``(loss, new_optimizer, new_dropout_rng)``.

    The incoming ``dropout_rng`` is split in two: one half drives dropout for this step, the other
    is returned for the next step. Gradients are averaged over the "batch" mapped axis before being
    applied. Relies on the module-level ``model`` and ``lr_scheduler_fn``.
    """
    step_rng, next_dropout_rng = jax.random.split(dropout_rng)

    def masked_lm_loss(params):
        # "labels" is removed from the batch so that the remaining entries can be passed to the model.
        labels = batch.pop("labels")
        # Only positions with a strictly positive label contribute to the loss.
        # NOTE(review): this also drops positions whose label is token id 0 -- confirm that is intended.
        loss_weights = jnp.where(labels > 0, 1.0, 0.0)
        outputs = model(**batch, params=params, dropout_rng=step_rng, train=True)
        summed_loss, weight_total = cross_entropy(outputs[0], labels, loss_weights)
        return summed_loss / weight_total

    learning_rate = lr_scheduler_fn(optimizer.state.step)
    loss_value, gradients = jax.value_and_grad(masked_lm_loss)(optimizer.target)
    gradients = jax.lax.pmean(gradients, "batch")
    updated_optimizer = optimizer.apply_gradient(gradients, learning_rate=learning_rate)
    return loss_value, updated_optimizer, next_dropout_rng
def eval_step(params, batch):
    """
    Compute evaluation metrics for one batch.

    "labels" is removed from ``batch``, the module-level ``model`` is run without training mode on
    the remaining entries, and the result is summarised by ``compute_metrics``.
    """
    labels = batch.pop("labels")
    # Only positions with a strictly positive label are counted in the metrics.
    loss_weights = jnp.where(labels > 0, 1.0, 0.0)
    outputs = model(**batch, params=params, train=False)
    return compute_metrics(outputs[0], labels, loss_weights)
def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> List[np.ndarray]:
    """
    Split sample indices into consecutive full batches, dropping the incomplete remainder.

    Args:
        samples_idx: one-dimensional array of sample indices.
        batch_size: number of indices per batch.

    Returns:
        A list of index arrays of length ``batch_size`` each, as produced by ``np.split``. The list
        is empty when there are fewer than ``batch_size`` samples.
    """
    nb_samples = len(samples_idx)
    nb_batches = nb_samples // batch_size
    if nb_batches == 0:
        # np.split rejects a zero section count, so handle "not even one full batch" explicitly.
        return []
    # Keep only the indices that fill complete batches.
    samples_idx = samples_idx[: nb_batches * batch_size]
    return np.split(samples_idx, nb_batches)
if __name__ == "__main__":
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
# We now keep distinct sets of args, for a cleaner separation of concerns.
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments, WandbArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args, data_args, training_args, wandb_args = parser.parse_json_file(
json_file=os.path.abspath(sys.argv[1])
)
else:
model_args, data_args, training_args, wandb_args = parser.parse_args_into_dataclasses()
if (
os.path.exists(training_args.output_dir)
and os.listdir(training_args.output_dir)
and training_args.do_train
and not training_args.overwrite_output_dir
):
raise ValueError(
f"Output directory ({training_args.output_dir}) already exists and is not empty."
"Use --overwrite_output_dir to overcome."
)
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
level="NOTSET",
datefmt="[%X]",
)
# Log on each process the small summary:
logger = logging.getLogger(__name__)
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
)
# Set the verbosity to info of the Transformers logger (on main process only):
logger.info("Training/evaluation parameters %s", training_args)
# Set seed before initializing model.
set_seed(training_args.seed)
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
# (the dataset will be downloaded automatically from the datasets Hub).
#
# For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
# 'text' is found. You can easily tweak this behavior (see below).
#
# In distributed training, the load_dataset function guarantees that only one local process can concurrently
# download the dataset.
if data_args.dataset_name is not None:
# Downloading and loading a dataset from the hub.
datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name)
if "validation" not in datasets.keys():
datasets["validation"] = load_dataset(
data_args.dataset_name,
data_args.dataset_config_name,
split=f"train[:{data_args.validation_split_percentage}%]",
)
datasets["train"] = load_dataset(
data_args.dataset_name,
data_args.dataset_config_name,
split=f"train[{data_args.validation_split_percentage}%:]",
)
else:
data_files = {}
if data_args.train_file is not None:
data_files["train"] = data_args.train_file
if data_args.validation_file is not None:
data_files["validation"] = data_args.validation_file
extension = data_args.train_file.split(".")[-1]
if extension == "txt":
extension = "text"
datasets = load_dataset(extension, data_files=data_files)
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
# https://huggingface.co/docs/datasets/loading_datasets.html.
# Load pretrained model and tokenizer
# Distributed training:
# The .from_pretrained methods guarantee that only one local process can concurrently
# download model & vocab.
rng = jax.random.PRNGKey(training_args.seed)
dropout_rngs = jax.random.split(rng, jax.local_device_count())
config = BertConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
lm_class = FlaxPerformerForMaskedLM if model_args.performer else FlaxBertForMaskedLM
if model_args.reinitialize:
model = lm_class(config=BertConfig.from_pretrained(model_args.model_name_or_path))
else:
model = lm_class.from_pretrained(
model_args.model_name_or_path,
dtype=jnp.float32,
input_shape=(training_args.train_batch_size, config.max_position_embeddings),
seed=training_args.seed,
dropout_rate=0.1,
)
if model_args.tokenizer_name:
tokenizer = AutoTokenizer.from_pretrained(
model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
)
elif model_args.model_name_or_path:
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
)
else:
raise ValueError(
"You are instantiating a new tokenizer from scratch. This is not supported by this script."
"You can do it from another script, save it, and load it from here, using --tokenizer_name."
)
# Preprocessing the datasets.
# First we tokenize all the texts.
if training_args.do_train:
column_names = datasets["train"].column_names
else:
column_names = datasets["validation"].column_names
text_column_name = "text" if "text" in column_names else column_names[0]
padding = "max_length" if data_args.pad_to_max_length else False
def tokenize_function(examples):
# Remove empty lines
examples = [line for line in examples if len(line) > 0 and not line.isspace()]
return tokenizer(
examples,
return_special_tokens_mask=True,
padding=padding,
truncation=True,
max_length=data_args.max_seq_length,
)
tokenized_datasets = datasets.map(
tokenize_function,
input_columns=[text_column_name],
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns=column_names,
load_from_cache_file=not data_args.overwrite_cache,
)
# Enable tensorboard only on the master node
# (`summary_writer` is therefore only defined when tensorboard is available and this is host 0).
if has_tensorboard and jax.host_id() == 0:
    summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir).joinpath("logs").as_posix())

# Data collator
# This one will take care of randomly masking the tokens.
data_collator = FlaxDataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)

# Setup optimizer
# The optimizer is created around the model parameters; they are later read back through `optimizer.target`.
optimizer = Adam(
    learning_rate=training_args.learning_rate,
    weight_decay=training_args.weight_decay,
    beta1=training_args.adam_beta1,
    beta2=training_args.adam_beta2,
).create(model.params)

# Create learning rate scheduler
# The number of warmup steps is clamped to at least 1.
lr_scheduler_fn = create_learning_rate_scheduler(
    base_learning_rate=training_args.learning_rate, warmup_steps=max(training_args.warmup_steps, 1)
)

# Create parallel version of the training and evaluation steps
# The training step may donate its first argument (the optimizer): each call
# returns the updated optimizer, which replaces the one that was passed in.
p_training_step = jax.pmap(training_step, "batch", donate_argnums=(0,))
# The evaluation step must NOT donate its first argument: the same parameters
# (`optimizer.target`) are passed again for every evaluation batch and are
# still needed by the following training epoch, and donated buffers must not
# be reused after the call.
p_eval_step = jax.pmap(eval_step, "batch")

# Replicate the optimizer on each device
optimizer = jax_utils.replicate(optimizer)

# Store some constant
nb_epochs = int(training_args.num_train_epochs)
batch_size = int(training_args.train_batch_size)
eval_batch_size = int(training_args.eval_batch_size)

# Weights & Biases logging is opt-in: wandb is only imported and initialised
# when a user name was provided.
if wandb_args.wandb_user_name is not None:
    import wandb

    wandb.init(project=wandb_args.wandb_project_name, entity=wandb_args.wandb_user_name)
# Evaluate the Weights & Biases switch once instead of re-testing the user name at every logging site.
log_to_wandb = wandb_args.wandb_user_name is not None

epochs = tqdm(range(nb_epochs), desc=f"Epoch ... (1/{nb_epochs})", position=0)
for epoch in epochs:

    # ======================== Training ================================
    # Create sampling rng
    # Only the training key is consumed below; the split stays three-way so
    # that the sequence of keys derived from `rng` is unchanged.
    rng, training_rng, _ = jax.random.split(rng, 3)

    # Generate an epoch by shuffling sampling indices from the train dataset
    nb_training_samples = len(tokenized_datasets["train"])
    training_samples_idx = jax.random.permutation(training_rng, jnp.arange(nb_training_samples))
    training_batch_idx = generate_batch_splits(training_samples_idx, batch_size)

    # Stays None when the epoch yields no training batch, so that the report
    # after the loop cannot fail on an unbound name.
    loss = None

    # Gather the indexes for creating the batch and do a training step
    for batch_idx in tqdm(training_batch_idx, desc="Training...", position=1):
        samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
        model_inputs = data_collator(samples, pad_to_multiple_of=16)

        # Model forward
        # The collated batch is sharded across the local devices before the parallel step.
        model_inputs = common_utils.shard(model_inputs.data)
        loss, optimizer, dropout_rngs = p_training_step(optimizer, model_inputs, dropout_rngs)

        if log_to_wandb:
            wandb.log({"Training loss": np.array(loss).mean()})

    # Report the loss of the last training step of the epoch, if any step ran
    if loss is not None:
        epochs.write(f"Loss: {loss}")

    # ======================== Evaluating ==============================
    # The validation samples are visited in order (no shuffling).
    nb_eval_samples = len(tokenized_datasets["validation"])
    eval_samples_idx = jnp.arange(nb_eval_samples)
    eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)

    eval_metrics = []
    for batch_idx in tqdm(eval_batch_idx, desc="Evaluating ...", position=2):
        samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
        model_inputs = data_collator(samples, pad_to_multiple_of=16)

        # Model forward
        model_inputs = common_utils.shard(model_inputs.data)
        metrics = p_eval_step(optimizer.target, model_inputs)
        eval_metrics.append(metrics)

    # Reduce every collected metric to a single sum, then divide each of them
    # by the summed "normalizer" entry.
    eval_metrics_np = get_metrics(eval_metrics)
    eval_metrics_np = jax.tree_map(jnp.sum, eval_metrics_np)
    eval_normalizer = eval_metrics_np.pop("normalizer")
    eval_summary = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics_np)

    # Update progress bar
    # `set_description` is used rather than assigning `desc` directly so that
    # the bar is refreshed and keeps its separator after the description.
    epochs.set_description(
        f"Epoch... ({epoch + 1}/{nb_epochs} | Loss: {eval_summary['loss']}, Acc: {eval_summary['accuracy']})"
    )

    if log_to_wandb:
        wandb.log({"Eval loss": np.array(eval_summary["loss"]).mean()})

    # Save metrics
    # Same condition as the one under which `summary_writer` is created.
    if has_tensorboard and jax.host_id() == 0:
        for name, value in eval_summary.items():
            summary_writer.scalar(name, value, epoch)
# Sanity run: Performer fine-tuning from the bert-base-cased checkpoint on the Simple English Wikipedia configuration (20200501.simple).
TOKENIZERS_PARALLELISM=true python run_mlm_performer.py --output_dir experiments --dataset_name wikipedia --dataset_config_name 20200501.simple --model_name_or_path bert-base-cased --tokenizer_name bert-base-cased --do_train --overwrite_output_dir --per_device_train_batch_size 4 --learning_rate 5e-4 --warmup_steps 100 --num_train_epochs 3 --performer
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment